diff --git a/.circleci/cimodel/data/binary_build_data.py b/.circleci/cimodel/data/binary_build_data.py index 21b6eebef5a17..bbd237d4cf0b6 100644 --- a/.circleci/cimodel/data/binary_build_data.py +++ b/.circleci/cimodel/data/binary_build_data.py @@ -30,12 +30,12 @@ def get_processor_arch_name(gpu_version): "cu" + gpu_version.strip("cuda") if gpu_version.startswith("cuda") else gpu_version ) - LINUX_PACKAGE_VARIANTS = OrderedDict( manywheel=[ "3.6m", "3.7m", "3.8m", + "3.9m" ], conda=dimensions.STANDARD_PYTHON_VERSIONS, libtorch=[ diff --git a/.circleci/cimodel/data/dimensions.py b/.circleci/cimodel/data/dimensions.py index 1f83cd61b13cd..89c0d4f6641c5 100644 --- a/.circleci/cimodel/data/dimensions.py +++ b/.circleci/cimodel/data/dimensions.py @@ -1,15 +1,14 @@ PHASES = ["build", "test"] CUDA_VERSIONS = [ - "92", "101", "102", "110", ] ROCM_VERSIONS = [ - "3.7", - "3.8", + "3.10", + "4.0", ] ROCM_VERSION_LABELS = ["rocm" + v for v in ROCM_VERSIONS] @@ -20,4 +19,5 @@ "3.6", "3.7", "3.8", + "3.9" ] diff --git a/.circleci/cimodel/data/pytorch_build_data.py b/.circleci/cimodel/data/pytorch_build_data.py index ebaddda7ca265..8b129533765ab 100644 --- a/.circleci/cimodel/data/pytorch_build_data.py +++ b/.circleci/cimodel/data/pytorch_build_data.py @@ -18,7 +18,11 @@ ("clang", [ ("5", [ ("3.6", [ - ("asan", [XImportant(True)]), + ("asan", [ + (True, [ + ("shard_test", [XImportant(True)]), + ]), + ]), ]), ]), ("7", [ @@ -45,14 +49,22 @@ ]), ("10.2", [ ("3.6", [ - ("important", [X(True)]), - ("libtorch", [X(True)]), + ("shard_test", [XImportant(True)]), + ("libtorch", [ + (True, [ + ('build_only', [X(True)]), + ]), + ]), ]), ]), - ("11.0", [ + ("11.1", [ ("3.8", [ X(True), - ("libtorch", [XImportant(True)]) + ("libtorch", [ + (True, [ + ('build_only', [XImportant(True)]), + ]), + ]), ]), ]), ]), @@ -72,12 +84,16 @@ ("gcc", [ ("9", [ ("3.8", [ - ("coverage", [XImportant(True)]), + ("coverage", [ + (True, [ + ("shard_test", [XImportant(True)]), + ]), + ]), ]), ]), ]), ("rocm", [ - ("3.7", [ + ("3.9", [ ("3.6", [ ('build_only', [XImportant(True)]), ]), @@ -158,6 +174,7 @@ def child_constructor(self): "libtorch": LibTorchConfigNode, "important": ImportantConfigNode, "build_only": BuildOnlyConfigNode, + "shard_test": ShardTestConfigNode, "cuda_gcc_override": CudaGccOverrideConfigNode, "coverage": CoverageConfigNode, "pure_torch": PureTorchConfigNode, @@ -195,7 +212,7 @@ def init2(self, node_name): self.props["is_asan"] = node_name def child_constructor(self): - return ImportantConfigNode + return ExperimentalFeatureConfigNode class ONNXConfigNode(TreeConfigNode): @@ -250,7 +267,7 @@ def init2(self, node_name): self.props["is_libtorch"] = node_name def child_constructor(self): - return ImportantConfigNode + return ExperimentalFeatureConfigNode class CudaGccOverrideConfigNode(TreeConfigNode): @@ -260,8 +277,8 @@ def init2(self, node_name): def child_constructor(self): return ExperimentalFeatureConfigNode -class BuildOnlyConfigNode(TreeConfigNode): +class BuildOnlyConfigNode(TreeConfigNode): def init2(self, node_name): self.props["build_only"] = node_name @@ -269,8 +286,15 @@ def child_constructor(self): return ExperimentalFeatureConfigNode -class CoverageConfigNode(TreeConfigNode): +class ShardTestConfigNode(TreeConfigNode): + def init2(self, node_name): + self.props["shard_test"] = node_name + + def child_constructor(self): + return ImportantConfigNode + +class CoverageConfigNode(TreeConfigNode): def init2(self, node_name): self.props["is_coverage"] = node_name @@ -290,7 +314,6 @@ def get_children(self): class XenialCompilerConfigNode(TreeConfigNode): - def modify_label(self, label): return label or "" @@ -304,7 +327,6 @@ def child_constructor(self): class BionicCompilerConfigNode(TreeConfigNode): - def modify_label(self, label): return label or "" diff --git a/.circleci/cimodel/data/pytorch_build_definitions.py b/.circleci/cimodel/data/pytorch_build_definitions.py index d582348b00c85..75b0e8812e1b3 100644 --- a/.circleci/cimodel/data/pytorch_build_definitions.py +++ b/.circleci/cimodel/data/pytorch_build_definitions.py @@ -6,7 +6,7 @@ import cimodel.lib.conf_tree as conf_tree import cimodel.lib.miniutils as miniutils from cimodel.data.pytorch_build_data import CONFIG_TREE_DATA, TopLevelNode -from cimodel.data.simple.util.branch_filters import gen_filter_dict +from cimodel.data.simple.util.branch_filters import gen_filter_dict, RC_PATTERN from cimodel.data.simple.util.docker_constants import gen_docker_image @@ -110,6 +110,8 @@ def gen_workflow_params(self, phase): parameters["resource_class"] = resource_class if phase == "build" and self.rocm_version is not None: parameters["resource_class"] = "xlarge" + if hasattr(self, 'filters'): + parameters['filters'] = self.filters return parameters def gen_workflow_job(self, phase): @@ -139,14 +141,16 @@ def gen_workflow_job(self, phase): # TODO This is a hack to special case some configs just for the workflow list class HiddenConf(object): - def __init__(self, name, parent_build=None): + def __init__(self, name, parent_build=None, filters=None): self.name = name self.parent_build = parent_build + self.filters = filters def gen_workflow_job(self, phase): return { self.gen_build_name(phase): { - "requires": [self.parent_build.gen_build_name("build")] + "requires": [self.parent_build.gen_build_name("build")], + "filters": self.filters, } } @@ -166,7 +170,8 @@ def gen_workflow_job(self, phase): "branch": self.branch, "requires": [self.parent_build], "context": "org-member", - "filters": gen_filter_dict(branches_list=["nightly"]) + "filters": gen_filter_dict(branches_list=["nightly"], + tags_list=RC_PATTERN) } } @@ -205,7 +210,9 @@ def gen_docs_configs(xenial_parent_config): configs.append( HiddenConf( "pytorch_python_doc_build", - parent_build=xenial_parent_config + parent_build=xenial_parent_config, + filters=gen_filter_dict(branches_list=r"/.*/", + tags_list=RC_PATTERN), ) ) configs.append( @@ -219,7 +226,9 @@ def gen_docs_configs(xenial_parent_config): configs.append( HiddenConf( "pytorch_cpp_doc_build", - parent_build=xenial_parent_config + parent_build=xenial_parent_config, + filters=gen_filter_dict(branches_list=r"/.*/", + tags_list=RC_PATTERN), ) ) configs.append( @@ -263,6 +272,7 @@ def instantiate_configs(): compiler_version = fc.find_prop("compiler_version") is_xla = fc.find_prop("is_xla") or False is_asan = fc.find_prop("is_asan") or False + is_coverage = fc.find_prop("is_coverage") or False is_onnx = fc.find_prop("is_onnx") or False is_pure_torch = fc.find_prop("is_pure_torch") or False is_vulkan = fc.find_prop("is_vulkan") or False @@ -301,7 +311,10 @@ def instantiate_configs(): parms_list.append("asan") python_version = fc.find_prop("pyver") parms_list[0] = fc.find_prop("abbreviated_pyver") - restrict_phases = ["build", "test1", "test2"] + + if is_coverage: + parms_list_ignored_for_docker_image.append("coverage") + python_version = fc.find_prop("pyver") if is_onnx: parms_list.append("onnx") @@ -317,13 +330,13 @@ def instantiate_configs(): is_important = fc.find_prop("is_important") or False parallel_backend = fc.find_prop("parallel_backend") or None build_only = fc.find_prop("build_only") or False - is_coverage = fc.find_prop("is_coverage") or False + shard_test = fc.find_prop("shard_test") or False # TODO: fix pure_torch python test packaging issue. + if shard_test: + restrict_phases = ["build"] if restrict_phases is None else restrict_phases + restrict_phases.extend(["test1", "test2"]) if build_only or is_pure_torch: restrict_phases = ["build"] - if is_coverage and restrict_phases is None: - restrict_phases = ["build", "coverage_test"] - gpu_resource = None if cuda_version and cuda_version != "10": @@ -348,6 +361,8 @@ def instantiate_configs(): # run docs builds on "pytorch-linux-xenial-py3.6-gcc5.4". Docs builds # should run on a CPU-only build that runs on all PRs. + # XXX should this be updated to a more modern build? Projects are + # beginning to drop python3.6 if ( distro_name == "xenial" and fc.find_prop("pyver") == "3.6" @@ -358,6 +373,8 @@ def instantiate_configs(): and compiler_name == "gcc" and fc.find_prop("compiler_version") == "5.4" ): + c.filters = gen_filter_dict(branches_list=r"/.*/", + tags_list=RC_PATTERN) c.dependent_tests = gen_docs_configs(c) if cuda_version == "10.2" and python_version == "3.6" and not is_libtorch: diff --git a/.circleci/cimodel/data/simple/docker_definitions.py b/.circleci/cimodel/data/simple/docker_definitions.py index 90d7763116019..960bd2fbff851 100644 --- a/.circleci/cimodel/data/simple/docker_definitions.py +++ b/.circleci/cimodel/data/simple/docker_definitions.py @@ -1,21 +1,24 @@ from collections import OrderedDict from cimodel.lib.miniutils import quote +from cimodel.data.simple.util.branch_filters import gen_filter_dict, RC_PATTERN # TODO: make this generated from a matrix rather than just a static list IMAGE_NAMES = [ + "pytorch-linux-bionic-cuda11.1-cudnn8-py3.6-gcc9", + "pytorch-linux-bionic-cuda11.1-cudnn8-py3.8-gcc9", "pytorch-linux-bionic-cuda11.0-cudnn8-py3.6-gcc9", "pytorch-linux-bionic-cuda11.0-cudnn8-py3.8-gcc9", "pytorch-linux-bionic-cuda10.2-cudnn7-py3.8-gcc9", "pytorch-linux-bionic-py3.6-clang9", "pytorch-linux-bionic-cuda10.2-cudnn7-py3.6-clang9", "pytorch-linux-bionic-py3.8-gcc9", - "pytorch-linux-bionic-rocm3.5.1-py3.6", "pytorch-linux-xenial-cuda10-cudnn7-py3-gcc7", "pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7", "pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7", "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7", + "pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7", "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc5.4", "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7", "pytorch-linux-xenial-py3-clang5-android-ndk-r19c", @@ -23,27 +26,30 @@ "pytorch-linux-xenial-py3-clang7-onnx", "pytorch-linux-xenial-py3.8", "pytorch-linux-xenial-py3.6-clang7", - "pytorch-linux-xenial-py3.6-gcc4.8", - "pytorch-linux-xenial-py3.6-gcc5.4", + "pytorch-linux-xenial-py3.6-gcc5.4", # this one is used in doc builds "pytorch-linux-xenial-py3.6-gcc7.2", "pytorch-linux-xenial-py3.6-gcc7", - "pytorch-linux-bionic-rocm3.7-py3.6", - "pytorch-linux-bionic-rocm3.8-py3.6", + "pytorch-linux-bionic-rocm3.9-py3.6", + "pytorch-linux-bionic-rocm3.10-py3.6", ] def get_workflow_jobs(): """Generates a list of docker image build definitions""" - return [ - OrderedDict( + ret = [] + for image_name in IMAGE_NAMES: + parameters = OrderedDict({ + "name": quote(f"docker-{image_name}"), + "image_name": quote(image_name), + }) + if image_name == "pytorch-linux-xenial-py3.6-gcc5.4": + # pushing documentation on tags requires CircleCI to also + # build all the dependencies on tags, including this docker image + parameters['filters'] = gen_filter_dict(branches_list=r"/.*/", + tags_list=RC_PATTERN) + ret.append(OrderedDict( { - "docker_build_job": OrderedDict( - { - "name": quote(f"docker-{image_name}"), - "image_name": quote(image_name), - } - ) + "docker_build_job": parameters } - ) - for image_name in IMAGE_NAMES - ] + )) + return ret diff --git a/.circleci/cimodel/data/simple/ge_config_tests.py b/.circleci/cimodel/data/simple/ge_config_tests.py index 2f2dbf0027dc3..603966c860b86 100644 --- a/.circleci/cimodel/data/simple/ge_config_tests.py +++ b/.circleci/cimodel/data/simple/ge_config_tests.py @@ -61,41 +61,16 @@ def gen_tree(self): MultiPartVersion([3, 6], "py"), MultiPartVersion([5, 4], "gcc"), None, - ["ge_config_legacy", "test"], + ["jit_legacy", "test"], ["pytorch_linux_xenial_py3_6_gcc5_4_build"]), - GeConfigTestJob( - MultiPartVersion([3, 6], "py"), - MultiPartVersion([5, 4], "gcc"), - None, - ["ge_config_profiling", "test"], - ["pytorch_linux_xenial_py3_6_gcc5_4_build"]), - GeConfigTestJob( - MultiPartVersion([3, 6], "py"), - MultiPartVersion([5, 4], "gcc"), - None, - ["ge_config_simple", "test"], - ["pytorch_linux_xenial_py3_6_gcc5_4_build"], - ), - GeConfigTestJob( - None, - None, - CudaVersion(10, 2), - ["cudnn7", "py3", "ge_config_legacy", "test"], - ["pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build"], - use_cuda_docker=True, - # TODO Why does the build environment specify cuda10.1, while the - # job name is cuda10_2? - build_env_override="pytorch-linux-xenial-cuda10.1-cudnn7-ge_config_legacy-test"), GeConfigTestJob( None, None, CudaVersion(10, 2), - ["cudnn7", "py3", "ge_config_profiling", "test"], + ["cudnn7", "py3", "jit_legacy", "test"], ["pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build"], use_cuda_docker=True, - # TODO Why does the build environment specify cuda10.1, while the - # job name is cuda10_2? - build_env_override="pytorch-linux-xenial-cuda10.1-cudnn7-ge_config_profiling-test"), + ), ] diff --git a/.circleci/cimodel/data/simple/ios_definitions.py b/.circleci/cimodel/data/simple/ios_definitions.py index 4446fa24fc283..7af6e36300ae8 100644 --- a/.circleci/cimodel/data/simple/ios_definitions.py +++ b/.circleci/cimodel/data/simple/ios_definitions.py @@ -1,16 +1,16 @@ from cimodel.data.simple.util.versions import MultiPartVersion +import cimodel.lib.miniutils as miniutils - -IOS_VERSION = MultiPartVersion([11, 2, 1]) +XCODE_VERSION = MultiPartVersion([12, 0, 0]) class ArchVariant: - def __init__(self, name, is_custom=False): + def __init__(self, name, custom_build_name=""): self.name = name - self.is_custom = is_custom + self.custom_build_name = custom_build_name def render(self): - extra_parts = ["custom"] if self.is_custom else [] + extra_parts = [self.custom_build_name] if len(self.custom_build_name) > 0 else [] return "_".join([self.name] + extra_parts) @@ -19,15 +19,15 @@ def get_platform(arch_variant_name): class IOSJob: - def __init__(self, ios_version, arch_variant, is_org_member_context=True, extra_props=None): - self.ios_version = ios_version + def __init__(self, xcode_version, arch_variant, is_org_member_context=True, extra_props=None): + self.xcode_version = xcode_version self.arch_variant = arch_variant self.is_org_member_context = is_org_member_context self.extra_props = extra_props def gen_name_parts(self, with_version_dots): - version_parts = self.ios_version.render_dots_or_parts(with_version_dots) + version_parts = self.xcode_version.render_dots_or_parts(with_version_dots) build_variant_suffix = "_".join([self.arch_variant.render(), "build"]) return [ @@ -61,9 +61,10 @@ def gen_tree(self): WORKFLOW_DATA = [ - IOSJob(IOS_VERSION, ArchVariant("x86_64"), is_org_member_context=False), - # IOSJob(IOS_VERSION, ArchVariant("arm64")), - # IOSJob(IOS_VERSION, ArchVariant("arm64", True), extra_props={"op_list": "mobilenetv2.yaml"}), + IOSJob(XCODE_VERSION, ArchVariant("x86_64"), is_org_member_context=False), + IOSJob(XCODE_VERSION, ArchVariant("arm64")), + IOSJob(XCODE_VERSION, ArchVariant("arm64", "metal"), extra_props={"use_metal": miniutils.quote(str(int(True)))}), + IOSJob(XCODE_VERSION, ArchVariant("arm64", "custom"), extra_props={"op_list": "mobilenetv2.yaml"}), ] diff --git a/.circleci/cimodel/data/simple/nightly_ios.py b/.circleci/cimodel/data/simple/nightly_ios.py index 6c01479dde80c..8e21e4c3782f7 100644 --- a/.circleci/cimodel/data/simple/nightly_ios.py +++ b/.circleci/cimodel/data/simple/nightly_ios.py @@ -18,7 +18,7 @@ def get_common_name_pieces(self, with_version_dots): common_name_pieces = [ "ios", - ] + ios_definitions.IOS_VERSION.render_dots_or_parts(with_version_dots) + [ + ] + ios_definitions.XCODE_VERSION.render_dots_or_parts(with_version_dots) + [ "nightly", self.variant, "build", @@ -60,7 +60,7 @@ def gen_tree(self): WORKFLOW_DATA = BUILD_CONFIGS + [ - # IOSNightlyJob("binary", is_upload=True), + IOSNightlyJob("binary", is_upload=True), ] diff --git a/.circleci/cimodel/data/simple/util/versions.py b/.circleci/cimodel/data/simple/util/versions.py index cc23cadf480cb..53d3a837248c1 100644 --- a/.circleci/cimodel/data/simple/util/versions.py +++ b/.circleci/cimodel/data/simple/util/versions.py @@ -9,7 +9,7 @@ def prefixed_parts(self): with the prefix string. """ if self.parts: - return [self.prefix + str(self.parts[0])] + list(map(str, self.parts[1:])) + return [self.prefix + str(self.parts[0])] + [str(part) for part in self.parts[1:]] else: return [self.prefix] @@ -29,3 +29,6 @@ def __init__(self, major, minor): self.minor = minor super().__init__([self.major, self.minor], "cuda") + + def __str__(self): + return f"{self.major}.{self.minor}" diff --git a/.circleci/cimodel/data/windows_build_definitions.py b/.circleci/cimodel/data/windows_build_definitions.py index 8be63529d740d..c0e828eaab5e1 100644 --- a/.circleci/cimodel/data/windows_build_definitions.py +++ b/.circleci/cimodel/data/windows_build_definitions.py @@ -86,10 +86,11 @@ def gen_tree(self): props_dict["executor"] = "windows-with-nvidia-gpu" props_dict["cuda_version"] = ( - miniutils.quote(str(self.cuda_version.major)) + miniutils.quote(str(self.cuda_version)) if self.cuda_version else "cpu" ) + props_dict["name"] = "_".join(name_parts) return [{key_name: props_dict}] @@ -131,10 +132,10 @@ def TruePred(_): WindowsJob(None, _VC2019, CudaVersion(10, 1)), WindowsJob(1, _VC2019, CudaVersion(10, 1)), WindowsJob(2, _VC2019, CudaVersion(10, 1)), - # VS2019 CUDA-11.0 - WindowsJob(None, _VC2019, CudaVersion(11, 0)), - WindowsJob(1, _VC2019, CudaVersion(11, 0), master_only_pred=TruePred), - WindowsJob(2, _VC2019, CudaVersion(11, 0), master_only_pred=TruePred), + # VS2019 CUDA-11.1 + WindowsJob(None, _VC2019, CudaVersion(11, 1)), + WindowsJob(1, _VC2019, CudaVersion(11, 1), master_only_pred=TruePred), + WindowsJob(2, _VC2019, CudaVersion(11, 1), master_only_pred=TruePred), # VS2019 CPU-only WindowsJob(None, _VC2019, None), WindowsJob(1, _VC2019, None, master_only_pred=TruePred), diff --git a/.circleci/config.yml b/.circleci/config.yml index 700a4155441db..4cf4dc4e2c6ad 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -11,6 +11,9 @@ parameters: run_binary_tests: type: boolean default: false + run_build: + type: boolean + default: true docker_config_defaults: &docker_config_defaults user: jenkins @@ -142,7 +145,7 @@ commands: name: (Optional) Merge target branch no_output_timeout: "10m" command: | - if [ -n "$CIRCLE_PULL_REQUEST" ]; then + if [[ -n "$CIRCLE_PULL_REQUEST" && "$CIRCLE_BRANCH" != "nightly" ]]; then PR_NUM=$(basename $CIRCLE_PULL_REQUEST) CIRCLE_PR_BASE_BRANCH=$(curl -s https://api.github.com/repos/$CIRCLE_PROJECT_USERNAME/$CIRCLE_PROJECT_REPONAME/pulls/$PR_NUM | jq -r '.base.ref') if [[ "${BUILD_ENVIRONMENT}" == *"xla"* || "${BUILD_ENVIRONMENT}" == *"gcc5"* ]] ; then @@ -302,11 +305,15 @@ pytorch_ios_params: &pytorch_ios_params op_list: type: string default: "" + use_metal: + type: string + default: "0" environment: BUILD_ENVIRONMENT: << parameters.build_environment >> IOS_ARCH: << parameters.ios_arch >> IOS_PLATFORM: << parameters.ios_platform >> SELECTED_OP_LIST: << parameters.op_list >> + USE_PYTORCH_METAL: << parameters.use_metal >> pytorch_windows_params: &pytorch_windows_params parameters: @@ -321,7 +328,7 @@ pytorch_windows_params: &pytorch_windows_params default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" @@ -449,12 +456,8 @@ jobs: no_output_timeout: "1h" command: | set -e - # TODO: Remove this after we figure out why rocm tests are failing - if [[ "${DOCKER_IMAGE}" == *rocm3.5* ]]; then - export DOCKER_TAG="ab1632df-fa59-40e6-8c23-98e004f61148" - fi - if [[ "${DOCKER_IMAGE}" == *rocm3.7* ]]; then - export DOCKER_TAG="1045c7b891104cb4fd23399eab413b6213e48aeb" + if [[ "${DOCKER_IMAGE}" == *rocm3.9* ]]; then + export DOCKER_TAG="f3d89a32912f62815e4feaeed47e564e887dffd6" fi if [[ ${BUILD_ENVIRONMENT} == *"pure_torch"* ]]; then echo 'BUILD_CAFFE2=OFF' >> "${BASH_ENV}" @@ -486,7 +489,7 @@ jobs: if [ -z "${BUILD_ONLY}" ]; then # Note [Special build images] # The xla build uses the same docker image as - # pytorch-linux-trusty-py3.6-gcc5.4-build. In the push step, we have to + # pytorch_linux_bionic_py3_6_clang9_build. In the push step, we have to # distinguish between them so the test can pick up the correct image. output_image=${DOCKER_IMAGE}:${DOCKER_TAG}-${CIRCLE_SHA1} if [[ ${BUILD_ENVIRONMENT} == *"xla"* ]]; then @@ -534,12 +537,8 @@ jobs: command: | set -e export PYTHONUNBUFFERED=1 - # TODO: Remove this after we figure out why rocm tests are failing - if [[ "${DOCKER_IMAGE}" == *rocm3.5* ]]; then - export DOCKER_TAG="ab1632df-fa59-40e6-8c23-98e004f61148" - fi - if [[ "${DOCKER_IMAGE}" == *rocm3.7* ]]; then - export DOCKER_TAG="1045c7b891104cb4fd23399eab413b6213e48aeb" + if [[ "${DOCKER_IMAGE}" == *rocm3.9* ]]; then + export DOCKER_TAG="f3d89a32912f62815e4feaeed47e564e887dffd6" fi # See Note [Special build images] output_image=${DOCKER_IMAGE}:${DOCKER_TAG}-${CIRCLE_SHA1} @@ -617,7 +616,7 @@ jobs: echo ".jenkins/pytorch/multigpu-test.sh" >> docker_commands.sh elif [[ ${BUILD_ENVIRONMENT} == *onnx* ]]; then echo "pip install click mock tabulate networkx==2.0" >> docker_commands.sh - echo "pip -q install --user -b /tmp/pip_install_onnx \"file:///var/lib/jenkins/workspace/third_party/onnx#egg=onnx\"" >> docker_commands.sh + echo "pip -q install --user \"file:///var/lib/jenkins/workspace/third_party/onnx#egg=onnx\"" >> docker_commands.sh echo ".jenkins/caffe2/test.sh" >> docker_commands.sh else echo ".jenkins/pytorch/test.sh" >> docker_commands.sh @@ -640,8 +639,10 @@ jobs: export CIRCLE_SHA1="$CIRCLE_SHA1" export CIRCLE_PR_NUMBER="${CIRCLE_PR_NUMBER:-}" export CIRCLE_BRANCH="$CIRCLE_BRANCH" + export CIRCLE_JOB="$CIRCLE_JOB" + export CIRCLE_WORKFLOW_ID="$CIRCLE_WORKFLOW_ID" cd workspace - python test/print_test_stats.py test + python test/print_test_stats.py --upload-to-s3 test EOL echo "(cat docker_commands.sh | docker exec -u jenkins -i "$id" bash) 2>&1" > command.sh unbuffer bash command.sh | ts @@ -649,7 +650,11 @@ jobs: echo "Retrieving test reports" docker cp $id:/var/lib/jenkins/workspace/test/test-reports ./ || echo 'No test reports found!' if [[ ${BUILD_ENVIRONMENT} == *"coverage"* ]]; then - echo "Retrieving coverage report" + echo "Retrieving C++ coverage report" + docker cp $id:/var/lib/jenkins/workspace/build/coverage.info ./test + fi + if [[ ${BUILD_ENVIRONMENT} == *"coverage"* || ${BUILD_ENVIRONMENT} == *"onnx"* ]]; then + echo "Retrieving Python coverage report" docker cp $id:/var/lib/jenkins/workspace/test/.coverage ./test docker cp $id:/var/lib/jenkins/workspace/test/coverage.xml ./test python3 -mpip install codecov @@ -673,7 +678,7 @@ jobs: default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" @@ -735,7 +740,7 @@ jobs: default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" @@ -764,9 +769,6 @@ jobs: if [[ "${CUDA_VERSION}" != "10" || "${JOB_EXECUTOR}" != "windows-with-nvidia-gpu" ]]; then .circleci/scripts/windows_cuda_install.sh fi - if [[ "${CUDA_VERSION}" != "10" && "${JOB_EXECUTOR}" == "windows-with-nvidia-gpu" ]]; then - .circleci/scripts/driver_update.bat - fi fi - run: name: Install Cudnn @@ -779,7 +781,7 @@ jobs: no_output_timeout: "30m" command: | set -e - export IN_CIRCLECI=1 + export IN_CI=1 set +x export AWS_ACCESS_KEY_ID=${CIRCLECI_AWS_ACCESS_KEY_FOR_WIN_BUILD_V1} export AWS_SECRET_ACCESS_KEY=${CIRCLECI_AWS_SECRET_KEY_FOR_WIN_BUILD_V1} @@ -924,7 +926,7 @@ jobs: smoke_mac_test: <<: *binary_linux_test_upload_params macos: - xcode: "11.2.1" + xcode: "12.0" steps: - checkout - run: @@ -949,7 +951,7 @@ jobs: binary_mac_build: <<: *binary_mac_params macos: - xcode: "11.2.1" + xcode: "12.0" steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - checkout @@ -963,7 +965,7 @@ jobs: - run: name: Build - no_output_timeout: "1h" + no_output_timeout: "90m" command: | # Do not set -u here; there is some problem with CircleCI # variable expansion with PROMPT_COMMAND @@ -990,7 +992,7 @@ jobs: binary_ios_build: <<: *pytorch_ios_params macos: - xcode: "11.2.1" + xcode: "12.0" steps: - attach_workspace: at: ~/workspace @@ -1017,7 +1019,7 @@ jobs: binary_ios_upload: <<: *pytorch_ios_params macos: - xcode: "11.2.1" + xcode: "12.0" steps: - attach_workspace: at: ~/workspace @@ -1187,10 +1189,13 @@ jobs: set -ex export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}:${DOCKER_TAG}-${CIRCLE_SHA1} echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE} + tag=${CIRCLE_TAG:1:5} + target=${tag:-master} + echo "building for ${target}" time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null export id=$(docker run --env-file "${BASH_ENV}" --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE}) - export COMMAND='((echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/master master site") | docker exec -u jenkins -i "$id" bash) 2>&1' + export COMMAND='((echo "sudo chown -R jenkins workspace && cd workspace && '"export CIRCLE_SHA1='$CIRCLE_SHA1'"' && . ./.circleci/scripts/python_doc_push_script.sh docs/'$target' '$target' site") | docker exec -u jenkins -i "$id" bash) 2>&1' echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts @@ -1229,10 +1234,13 @@ jobs: set -ex export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}:${DOCKER_TAG}-${CIRCLE_SHA1} echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE} + tag=${CIRCLE_TAG:1:5} + target=${tag:-master} + echo "building for ${target}" time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null export id=$(docker run --env-file "${BASH_ENV}" --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE}) - export COMMAND='((echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/cpp_doc_push_script.sh docs/master master") | docker exec -u jenkins -i "$id" bash) 2>&1' + export COMMAND='((echo "sudo chown -R jenkins workspace && cd workspace && '"export CIRCLE_SHA1='$CIRCLE_SHA1'"' && . ./.circleci/scripts/cpp_doc_push_script.sh docs/"$target" master") | docker exec -u jenkins -i "$id" bash) 2>&1' echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts @@ -1253,7 +1261,7 @@ jobs: environment: BUILD_ENVIRONMENT: pytorch-macos-10.13-py3-build macos: - xcode: "11.2.1" + xcode: "12.0" steps: - checkout - run_brew_for_macos_build @@ -1262,7 +1270,7 @@ jobs: no_output_timeout: "1h" command: | set -e - export IN_CIRCLECI=1 + export IN_CI=1 # Install sccache sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache --output /usr/local/bin/sccache @@ -1287,7 +1295,7 @@ jobs: environment: BUILD_ENVIRONMENT: pytorch-macos-10.13-py3-test macos: - xcode: "11.2.1" + xcode: "12.0" steps: - checkout - attach_workspace: @@ -1298,7 +1306,7 @@ jobs: no_output_timeout: "1h" command: | set -e - export IN_CIRCLECI=1 + export IN_CI=1 chmod a+x .jenkins/pytorch/macos-test.sh unbuffer .jenkins/pytorch/macos-test.sh 2>&1 | ts @@ -1397,22 +1405,22 @@ jobs: pytorch_android_publish_snapshot: environment: BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-publish-snapshot - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:ab1632df-fa59-40e6-8c23-98e004f61148" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c" PYTHON_VERSION: "3.6" resource_class: large machine: image: ubuntu-1604:202007-01 steps: - checkout + - calculate_docker_image_tag - setup_linux_system_environment - - checkout - setup_ci_environment - run: name: pytorch android gradle build no_output_timeout: "1h" command: | set -eux - docker_image_commit=${DOCKER_IMAGE}-${CIRCLE_SHA1} + docker_image_commit=${DOCKER_IMAGE}:${DOCKER_TAG}-${CIRCLE_SHA1} docker_image_libtorch_android_x86_32_gradle=${docker_image_commit}-android-x86_32-gradle @@ -1515,7 +1523,7 @@ jobs: pytorch_ios_build: <<: *pytorch_ios_params macos: - xcode: "11.2.1" + xcode: "12.0" steps: - checkout - run_brew_for_ios_build @@ -1534,7 +1542,7 @@ jobs: rm cert.txt bundle exec fastlane install_cert # install the provisioning profile - PROFILE=TestApp_CI.mobileprovision + PROFILE=PyTorch_CI_2021.mobileprovision PROVISIONING_PROFILES=~/Library/MobileDevice/Provisioning\ Profiles mkdir -pv "${PROVISIONING_PROFILES}" cd "${PROVISIONING_PROFILES}" @@ -1546,7 +1554,7 @@ jobs: no_output_timeout: "1h" command: | set -e - export IN_CIRCLECI=1 + export IN_CI=1 WORKSPACE=/Users/distiller/workspace PROJ_ROOT=/Users/distiller/project export TCLLIBPATH="/usr/local/lib" @@ -1563,7 +1571,7 @@ jobs: $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) } - retry conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests --yes + retry conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi requests --yes # sync submodules cd ${PROJ_ROOT} @@ -1577,6 +1585,7 @@ jobs: chmod a+x ${PROJ_ROOT}/scripts/build_ios.sh echo "IOS_ARCH: ${IOS_ARCH}" echo "IOS_PLATFORM: ${IOS_PLATFORM}" + echo "USE_PYTORCH_METAL": "${USE_METAL}" #check the custom build flag echo "SELECTED_OP_LIST: ${SELECTED_OP_LIST}" @@ -1585,6 +1594,9 @@ jobs: fi export IOS_ARCH=${IOS_ARCH} export IOS_PLATFORM=${IOS_PLATFORM} + if [ ${IOS_PLATFORM} != "SIMULATOR" ]; then + export USE_PYTORCH_METAL=${USE_METAL} + fi unbuffer ${PROJ_ROOT}/scripts/build_ios.sh 2>&1 | ts - run: name: Run Build Test @@ -1592,7 +1604,7 @@ jobs: command: | set -e PROJ_ROOT=/Users/distiller/project - PROFILE=TestApp_CI + PROFILE=PyTorch_CI_2021 # run the ruby build script if ! [ -x "$(command -v xcodebuild)" ]; then echo 'Error: xcodebuild is not installed.' @@ -1966,8 +1978,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/manylinux-cuda102" - binary_linux_build: - name: binary_linux_manywheel_3_6m_cu92_devtoolset7_nightly_build - build_environment: "manywheel 3.6m cu92 devtoolset7" + name: binary_linux_manywheel_3_9m_cpu_devtoolset7_nightly_build + build_environment: "manywheel 3.9m cpu devtoolset7" filters: branches: only: @@ -1975,10 +1987,10 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - docker_image: "pytorch/manylinux-cuda92" + docker_image: "pytorch/manylinux-cuda102" - binary_linux_build: - name: binary_linux_manywheel_3_7m_cu92_devtoolset7_nightly_build - build_environment: "manywheel 3.7m cu92 devtoolset7" + name: binary_linux_manywheel_3_6m_cu101_devtoolset7_nightly_build + build_environment: "manywheel 3.6m cu101 devtoolset7" filters: branches: only: @@ -1986,10 +1998,10 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - docker_image: "pytorch/manylinux-cuda92" + docker_image: "pytorch/manylinux-cuda101" - binary_linux_build: - name: binary_linux_manywheel_3_8m_cu92_devtoolset7_nightly_build - build_environment: "manywheel 3.8m cu92 devtoolset7" + name: binary_linux_manywheel_3_7m_cu101_devtoolset7_nightly_build + build_environment: "manywheel 3.7m cu101 devtoolset7" filters: branches: only: @@ -1997,10 +2009,10 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - docker_image: "pytorch/manylinux-cuda92" + docker_image: "pytorch/manylinux-cuda101" - binary_linux_build: - name: binary_linux_manywheel_3_6m_cu101_devtoolset7_nightly_build - build_environment: "manywheel 3.6m cu101 devtoolset7" + name: binary_linux_manywheel_3_8m_cu101_devtoolset7_nightly_build + build_environment: "manywheel 3.8m cu101 devtoolset7" filters: branches: only: @@ -2010,8 +2022,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/manylinux-cuda101" - binary_linux_build: - name: binary_linux_manywheel_3_7m_cu101_devtoolset7_nightly_build - build_environment: "manywheel 3.7m cu101 devtoolset7" + name: binary_linux_manywheel_3_9m_cu101_devtoolset7_nightly_build + build_environment: "manywheel 3.9m cu101 devtoolset7" filters: branches: only: @@ -2021,8 +2033,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/manylinux-cuda101" - binary_linux_build: - name: binary_linux_manywheel_3_8m_cu101_devtoolset7_nightly_build - build_environment: "manywheel 3.8m cu101 devtoolset7" + name: binary_linux_manywheel_3_6m_cu102_devtoolset7_nightly_build + build_environment: "manywheel 3.6m cu102 devtoolset7" filters: branches: only: @@ -2030,10 +2042,10 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - docker_image: "pytorch/manylinux-cuda101" + docker_image: "pytorch/manylinux-cuda102" - binary_linux_build: - name: binary_linux_manywheel_3_6m_cu102_devtoolset7_nightly_build - build_environment: "manywheel 3.6m cu102 devtoolset7" + name: binary_linux_manywheel_3_7m_cu102_devtoolset7_nightly_build + build_environment: "manywheel 3.7m cu102 devtoolset7" filters: branches: only: @@ -2043,8 +2055,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/manylinux-cuda102" - binary_linux_build: - name: binary_linux_manywheel_3_7m_cu102_devtoolset7_nightly_build - build_environment: "manywheel 3.7m cu102 devtoolset7" + name: binary_linux_manywheel_3_8m_cu102_devtoolset7_nightly_build + build_environment: "manywheel 3.8m cu102 devtoolset7" filters: branches: only: @@ -2054,8 +2066,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/manylinux-cuda102" - binary_linux_build: - name: binary_linux_manywheel_3_8m_cu102_devtoolset7_nightly_build - build_environment: "manywheel 3.8m cu102 devtoolset7" + name: binary_linux_manywheel_3_9m_cu102_devtoolset7_nightly_build + build_environment: "manywheel 3.9m cu102 devtoolset7" filters: branches: only: @@ -2098,8 +2110,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/manylinux-cuda110" - binary_linux_build: - name: binary_linux_manywheel_3_6m_rocm3_7_devtoolset7_nightly_build - build_environment: "manywheel 3.6m rocm3.7 devtoolset7" + name: binary_linux_manywheel_3_9m_cu110_devtoolset7_nightly_build + build_environment: "manywheel 3.9m cu110 devtoolset7" filters: branches: only: @@ -2107,10 +2119,10 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - docker_image: "pytorch/manylinux-rocm:3.7" + docker_image: "pytorch/manylinux-cuda110" - binary_linux_build: - name: binary_linux_manywheel_3_7m_rocm3_7_devtoolset7_nightly_build - build_environment: "manywheel 3.7m rocm3.7 devtoolset7" + name: binary_linux_manywheel_3_6m_rocm3_10_devtoolset7_nightly_build + build_environment: "manywheel 3.6m rocm3.10 devtoolset7" filters: branches: only: @@ -2118,10 +2130,10 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - docker_image: "pytorch/manylinux-rocm:3.7" + docker_image: "pytorch/manylinux-rocm:3.10" - binary_linux_build: - name: binary_linux_manywheel_3_8m_rocm3_7_devtoolset7_nightly_build - build_environment: "manywheel 3.8m rocm3.7 devtoolset7" + name: binary_linux_manywheel_3_7m_rocm3_10_devtoolset7_nightly_build + build_environment: "manywheel 3.7m rocm3.10 devtoolset7" filters: branches: only: @@ -2129,10 +2141,10 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - docker_image: "pytorch/manylinux-rocm:3.7" + docker_image: "pytorch/manylinux-rocm:3.10" - binary_linux_build: - name: binary_linux_manywheel_3_6m_rocm3_8_devtoolset7_nightly_build - build_environment: "manywheel 3.6m rocm3.8 devtoolset7" + name: binary_linux_manywheel_3_8m_rocm3_10_devtoolset7_nightly_build + build_environment: "manywheel 3.8m rocm3.10 devtoolset7" filters: branches: only: @@ -2140,10 +2152,10 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - docker_image: "pytorch/manylinux-rocm:3.8" + docker_image: "pytorch/manylinux-rocm:3.10" - binary_linux_build: - name: binary_linux_manywheel_3_7m_rocm3_8_devtoolset7_nightly_build - build_environment: "manywheel 3.7m rocm3.8 devtoolset7" + name: binary_linux_manywheel_3_9m_rocm3_10_devtoolset7_nightly_build + build_environment: "manywheel 3.9m rocm3.10 devtoolset7" filters: branches: only: @@ -2151,10 +2163,10 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - docker_image: "pytorch/manylinux-rocm:3.8" + docker_image: "pytorch/manylinux-rocm:3.10" - binary_linux_build: - name: binary_linux_manywheel_3_8m_rocm3_8_devtoolset7_nightly_build - build_environment: "manywheel 3.8m rocm3.8 devtoolset7" + name: binary_linux_manywheel_3_6m_rocm4_0_devtoolset7_nightly_build + build_environment: "manywheel 3.6m rocm4.0 devtoolset7" filters: branches: only: @@ -2162,10 +2174,10 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - docker_image: "pytorch/manylinux-rocm:3.8" + docker_image: "pytorch/manylinux-rocm:4.0" - binary_linux_build: - name: binary_linux_conda_3_6_cpu_devtoolset7_nightly_build - build_environment: "conda 3.6 cpu devtoolset7" + name: binary_linux_manywheel_3_7m_rocm4_0_devtoolset7_nightly_build + build_environment: "manywheel 3.7m rocm4.0 devtoolset7" filters: branches: only: @@ -2173,10 +2185,10 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - docker_image: "pytorch/conda-cuda" + docker_image: "pytorch/manylinux-rocm:4.0" - binary_linux_build: - name: binary_linux_conda_3_7_cpu_devtoolset7_nightly_build - build_environment: "conda 3.7 cpu devtoolset7" + name: binary_linux_manywheel_3_8m_rocm4_0_devtoolset7_nightly_build + build_environment: "manywheel 3.8m rocm4.0 devtoolset7" filters: branches: only: @@ -2184,10 +2196,10 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - docker_image: "pytorch/conda-cuda" + docker_image: "pytorch/manylinux-rocm:4.0" - binary_linux_build: - name: binary_linux_conda_3_8_cpu_devtoolset7_nightly_build - build_environment: "conda 3.8 cpu devtoolset7" + name: binary_linux_manywheel_3_9m_rocm4_0_devtoolset7_nightly_build + build_environment: "manywheel 3.9m rocm4.0 devtoolset7" filters: branches: only: @@ -2195,10 +2207,10 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - docker_image: "pytorch/conda-cuda" + docker_image: "pytorch/manylinux-rocm:4.0" - binary_linux_build: - name: binary_linux_conda_3_6_cu92_devtoolset7_nightly_build - build_environment: "conda 3.6 cu92 devtoolset7" + name: binary_linux_conda_3_6_cpu_devtoolset7_nightly_build + build_environment: "conda 3.6 cpu devtoolset7" filters: branches: only: @@ -2208,8 +2220,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/conda-cuda" - binary_linux_build: - name: binary_linux_conda_3_7_cu92_devtoolset7_nightly_build - build_environment: "conda 3.7 cu92 devtoolset7" + name: binary_linux_conda_3_7_cpu_devtoolset7_nightly_build + build_environment: "conda 3.7 cpu devtoolset7" filters: branches: only: @@ -2219,8 +2231,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/conda-cuda" - binary_linux_build: - name: binary_linux_conda_3_8_cu92_devtoolset7_nightly_build - build_environment: "conda 3.8 cu92 devtoolset7" + name: binary_linux_conda_3_8_cpu_devtoolset7_nightly_build + build_environment: "conda 3.8 cpu devtoolset7" filters: branches: only: @@ -2230,8 +2242,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/conda-cuda" - binary_linux_build: - name: binary_linux_conda_3_6_cu101_devtoolset7_nightly_build - build_environment: "conda 3.6 cu101 devtoolset7" + name: binary_linux_conda_3_9_cpu_devtoolset7_nightly_build + build_environment: "conda 3.9 cpu devtoolset7" filters: branches: only: @@ -2241,8 +2253,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/conda-cuda" - binary_linux_build: - name: binary_linux_conda_3_7_cu101_devtoolset7_nightly_build - build_environment: "conda 3.7 cu101 devtoolset7" + name: binary_linux_conda_3_6_cu101_devtoolset7_nightly_build + build_environment: "conda 3.6 cu101 devtoolset7" filters: branches: only: @@ -2252,8 +2264,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/conda-cuda" - binary_linux_build: - name: binary_linux_conda_3_8_cu101_devtoolset7_nightly_build - build_environment: "conda 3.8 cu101 devtoolset7" + name: binary_linux_conda_3_7_cu101_devtoolset7_nightly_build + build_environment: "conda 3.7 cu101 devtoolset7" filters: branches: only: @@ -2263,8 +2275,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/conda-cuda" - binary_linux_build: - name: binary_linux_conda_3_6_cu102_devtoolset7_nightly_build - build_environment: "conda 3.6 cu102 devtoolset7" + name: binary_linux_conda_3_8_cu101_devtoolset7_nightly_build + build_environment: "conda 3.8 cu101 devtoolset7" filters: branches: only: @@ -2274,8 +2286,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/conda-cuda" - binary_linux_build: - name: binary_linux_conda_3_7_cu102_devtoolset7_nightly_build - build_environment: "conda 3.7 cu102 devtoolset7" + name: binary_linux_conda_3_9_cu101_devtoolset7_nightly_build + build_environment: "conda 3.9 cu101 devtoolset7" filters: branches: only: @@ -2285,8 +2297,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/conda-cuda" - binary_linux_build: - name: binary_linux_conda_3_8_cu102_devtoolset7_nightly_build - build_environment: "conda 3.8 cu102 devtoolset7" + name: binary_linux_conda_3_6_cu102_devtoolset7_nightly_build + build_environment: "conda 3.6 cu102 devtoolset7" filters: branches: only: @@ -2296,8 +2308,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/conda-cuda" - binary_linux_build: - name: binary_linux_conda_3_6_cu110_devtoolset7_nightly_build - build_environment: "conda 3.6 cu110 devtoolset7" + name: binary_linux_conda_3_7_cu102_devtoolset7_nightly_build + build_environment: "conda 3.7 cu102 devtoolset7" filters: branches: only: @@ -2307,8 +2319,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/conda-cuda" - binary_linux_build: - name: binary_linux_conda_3_7_cu110_devtoolset7_nightly_build - build_environment: "conda 3.7 cu110 devtoolset7" + name: binary_linux_conda_3_8_cu102_devtoolset7_nightly_build + build_environment: "conda 3.8 cu102 devtoolset7" filters: branches: only: @@ -2318,8 +2330,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/conda-cuda" - binary_linux_build: - name: binary_linux_conda_3_8_cu110_devtoolset7_nightly_build - build_environment: "conda 3.8 cu110 devtoolset7" + name: binary_linux_conda_3_9_cu102_devtoolset7_nightly_build + build_environment: "conda 3.9 cu102 devtoolset7" filters: branches: only: @@ -2329,8 +2341,8 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/conda-cuda" - binary_linux_build: - name: binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_shared-with-deps_build - build_environment: "libtorch 3.7m cpu devtoolset7" + name: binary_linux_conda_3_6_cu110_devtoolset7_nightly_build + build_environment: "conda 3.6 cu110 devtoolset7" filters: branches: only: @@ -2338,11 +2350,10 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - libtorch_variant: "shared-with-deps" - docker_image: "pytorch/manylinux-cuda102" + docker_image: "pytorch/conda-cuda" - binary_linux_build: - name: binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_shared-without-deps_build - build_environment: "libtorch 3.7m cpu devtoolset7" + name: binary_linux_conda_3_7_cu110_devtoolset7_nightly_build + build_environment: "conda 3.7 cu110 devtoolset7" filters: branches: only: @@ -2350,11 +2361,10 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - libtorch_variant: "shared-without-deps" - docker_image: "pytorch/manylinux-cuda102" + docker_image: "pytorch/conda-cuda" - binary_linux_build: - name: binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_static-with-deps_build - build_environment: "libtorch 3.7m cpu devtoolset7" + name: binary_linux_conda_3_8_cu110_devtoolset7_nightly_build + build_environment: "conda 3.8 cu110 devtoolset7" filters: branches: only: @@ -2362,11 +2372,10 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - libtorch_variant: "static-with-deps" - docker_image: "pytorch/manylinux-cuda102" + docker_image: "pytorch/conda-cuda" - binary_linux_build: - name: binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_static-without-deps_build - build_environment: "libtorch 3.7m cpu devtoolset7" + name: binary_linux_conda_3_9_cu110_devtoolset7_nightly_build + build_environment: "conda 3.9 cu110 devtoolset7" filters: branches: only: @@ -2374,11 +2383,10 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - libtorch_variant: "static-without-deps" - docker_image: "pytorch/manylinux-cuda102" + docker_image: "pytorch/conda-cuda" - binary_linux_build: - name: binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_shared-with-deps_build - build_environment: "libtorch 3.7m cu92 devtoolset7" + name: binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_shared-with-deps_build + build_environment: "libtorch 3.7m cpu devtoolset7" filters: branches: only: @@ -2387,10 +2395,10 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ libtorch_variant: "shared-with-deps" - docker_image: "pytorch/manylinux-cuda92" + docker_image: "pytorch/manylinux-cuda102" - binary_linux_build: - name: binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_shared-without-deps_build - build_environment: "libtorch 3.7m cu92 devtoolset7" + name: binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_shared-without-deps_build + build_environment: "libtorch 3.7m cpu devtoolset7" filters: branches: only: @@ -2399,10 +2407,10 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ libtorch_variant: "shared-without-deps" - docker_image: "pytorch/manylinux-cuda92" + docker_image: "pytorch/manylinux-cuda102" - binary_linux_build: - name: binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_static-with-deps_build - build_environment: "libtorch 3.7m cu92 devtoolset7" + name: binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_static-with-deps_build + build_environment: "libtorch 3.7m cpu devtoolset7" filters: branches: only: @@ -2411,10 +2419,10 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ libtorch_variant: "static-with-deps" - docker_image: "pytorch/manylinux-cuda92" + docker_image: "pytorch/manylinux-cuda102" - binary_linux_build: - name: binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_static-without-deps_build - build_environment: "libtorch 3.7m cu92 devtoolset7" + name: binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_static-without-deps_build + build_environment: "libtorch 3.7m cpu devtoolset7" filters: branches: only: @@ -2423,7 +2431,7 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ libtorch_variant: "static-without-deps" - docker_image: "pytorch/manylinux-cuda92" + docker_image: "pytorch/manylinux-cuda102" - binary_linux_build: name: binary_linux_libtorch_3_7m_cu101_devtoolset7_nightly_shared-with-deps_build build_environment: "libtorch 3.7m cu101 devtoolset7" @@ -2616,54 +2624,6 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ libtorch_variant: "static-without-deps" docker_image: "pytorch/pytorch-binary-docker-image-ubuntu16.04:latest" - - binary_linux_build: - name: binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_shared-with-deps_build - build_environment: "libtorch 3.7m cu92 gcc5.4_cxx11-abi" - filters: - branches: - only: - - /.*/ - tags: - only: - - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - libtorch_variant: "shared-with-deps" - docker_image: "pytorch/pytorch-binary-docker-image-ubuntu16.04:latest" - - binary_linux_build: - name: binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_shared-without-deps_build - build_environment: "libtorch 3.7m cu92 gcc5.4_cxx11-abi" - filters: - branches: - only: - - /.*/ - tags: - only: - - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - libtorch_variant: "shared-without-deps" - docker_image: "pytorch/pytorch-binary-docker-image-ubuntu16.04:latest" - - binary_linux_build: - name: binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_static-with-deps_build - build_environment: "libtorch 3.7m cu92 gcc5.4_cxx11-abi" - filters: - branches: - only: - - /.*/ - tags: - only: - - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - libtorch_variant: "static-with-deps" - docker_image: "pytorch/pytorch-binary-docker-image-ubuntu16.04:latest" - - binary_linux_build: - name: binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_static-without-deps_build - build_environment: "libtorch 3.7m cu92 gcc5.4_cxx11-abi" - filters: - branches: - only: - - /.*/ - tags: - only: - - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - libtorch_variant: "static-without-deps" - docker_image: "pytorch/pytorch-binary-docker-image-ubuntu16.04:latest" - binary_linux_build: name: binary_linux_libtorch_3_7m_cu101_gcc5_4_cxx11-abi_nightly_shared-with-deps_build build_environment: "libtorch 3.7m cu101 gcc5.4_cxx11-abi" @@ -2838,6 +2798,16 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + - binary_mac_build: + name: binary_macos_wheel_3_9_cpu_nightly_build + build_environment: "wheel 3.9 cpu" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - binary_mac_build: name: binary_macos_conda_3_6_cpu_nightly_build build_environment: "conda 3.6 cpu" @@ -2868,6 +2838,16 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + - binary_mac_build: + name: binary_macos_conda_3_9_cpu_nightly_build + build_environment: "conda 3.9 cpu" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - binary_mac_build: name: binary_macos_libtorch_3_7_cpu_nightly_build build_environment: "libtorch 3.7 cpu" @@ -2908,6 +2888,16 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + - binary_windows_build: + name: binary_windows_wheel_3_9_cpu_nightly_build + build_environment: "wheel 3.9 cpu" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - binary_windows_build: name: binary_windows_wheel_3_6_cu101_nightly_build build_environment: "wheel 3.6 cu101" @@ -2938,6 +2928,16 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + - binary_windows_build: + name: binary_windows_wheel_3_9_cu101_nightly_build + build_environment: "wheel 3.9 cu101" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - binary_windows_build: name: binary_windows_wheel_3_6_cu102_nightly_build build_environment: "wheel 3.6 cu102" @@ -2968,6 +2968,16 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + - binary_windows_build: + name: binary_windows_wheel_3_9_cu102_nightly_build + build_environment: "wheel 3.9 cu102" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - binary_windows_build: name: binary_windows_wheel_3_6_cu110_nightly_build build_environment: "wheel 3.6 cu110" @@ -2998,6 +3008,16 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + - binary_windows_build: + name: binary_windows_wheel_3_9_cu110_nightly_build + build_environment: "wheel 3.9 cu110" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - binary_windows_build: name: binary_windows_conda_3_6_cpu_nightly_build build_environment: "conda 3.6 cpu" @@ -3028,6 +3048,16 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + - binary_windows_build: + name: binary_windows_conda_3_9_cpu_nightly_build + build_environment: "conda 3.9 cpu" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - binary_windows_build: name: binary_windows_conda_3_6_cu101_nightly_build build_environment: "conda 3.6 cu101" @@ -3058,6 +3088,16 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + - binary_windows_build: + name: binary_windows_conda_3_9_cu101_nightly_build + build_environment: "conda 3.9 cu101" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - binary_windows_build: name: binary_windows_conda_3_6_cu102_nightly_build build_environment: "conda 3.6 cu102" @@ -3088,6 +3128,16 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + - binary_windows_build: + name: binary_windows_conda_3_9_cu102_nightly_build + build_environment: "conda 3.9 cu102" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - binary_windows_build: name: binary_windows_conda_3_6_cu110_nightly_build build_environment: "conda 3.6 cu110" @@ -3118,6 +3168,16 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + - binary_windows_build: + name: binary_windows_conda_3_9_cu110_nightly_build + build_environment: "conda 3.9 cu110" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - binary_windows_build: name: binary_windows_libtorch_3_7_cpu_debug_nightly_build build_environment: "libtorch 3.7 cpu debug" @@ -3238,8 +3298,8 @@ workflows: - binary_linux_manywheel_3_8m_cpu_devtoolset7_nightly_build docker_image: "pytorch/manylinux-cuda102" - binary_linux_test: - name: binary_linux_manywheel_3_6m_cu92_devtoolset7_nightly_test - build_environment: "manywheel 3.6m cu92 devtoolset7" + name: binary_linux_manywheel_3_9m_cpu_devtoolset7_nightly_test + build_environment: "manywheel 3.9m cpu devtoolset7" filters: branches: only: @@ -3248,13 +3308,11 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_manywheel_3_6m_cu92_devtoolset7_nightly_build - docker_image: "pytorch/manylinux-cuda92" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium + - binary_linux_manywheel_3_9m_cpu_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-cuda102" - binary_linux_test: - name: binary_linux_manywheel_3_7m_cu92_devtoolset7_nightly_test - build_environment: "manywheel 3.7m cu92 devtoolset7" + name: binary_linux_manywheel_3_6m_cu101_devtoolset7_nightly_test + build_environment: "manywheel 3.6m cu101 devtoolset7" filters: branches: only: @@ -3263,13 +3321,13 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_manywheel_3_7m_cu92_devtoolset7_nightly_build - docker_image: "pytorch/manylinux-cuda92" + - binary_linux_manywheel_3_6m_cu101_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-cuda101" use_cuda_docker_runtime: "1" resource_class: gpu.medium - binary_linux_test: - name: binary_linux_manywheel_3_8m_cu92_devtoolset7_nightly_test - build_environment: "manywheel 3.8m cu92 devtoolset7" + name: binary_linux_manywheel_3_7m_cu101_devtoolset7_nightly_test + build_environment: "manywheel 3.7m cu101 devtoolset7" filters: branches: only: @@ -3278,13 +3336,13 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_manywheel_3_8m_cu92_devtoolset7_nightly_build - docker_image: "pytorch/manylinux-cuda92" + - binary_linux_manywheel_3_7m_cu101_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-cuda101" use_cuda_docker_runtime: "1" resource_class: gpu.medium - binary_linux_test: - name: binary_linux_manywheel_3_6m_cu101_devtoolset7_nightly_test - build_environment: "manywheel 3.6m cu101 devtoolset7" + name: binary_linux_manywheel_3_8m_cu101_devtoolset7_nightly_test + build_environment: "manywheel 3.8m cu101 devtoolset7" filters: branches: only: @@ -3293,13 +3351,13 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_manywheel_3_6m_cu101_devtoolset7_nightly_build + - binary_linux_manywheel_3_8m_cu101_devtoolset7_nightly_build docker_image: "pytorch/manylinux-cuda101" use_cuda_docker_runtime: "1" resource_class: gpu.medium - binary_linux_test: - name: binary_linux_manywheel_3_7m_cu101_devtoolset7_nightly_test - build_environment: "manywheel 3.7m cu101 devtoolset7" + name: binary_linux_manywheel_3_9m_cu101_devtoolset7_nightly_test + build_environment: "manywheel 3.9m cu101 devtoolset7" filters: branches: only: @@ -3308,13 +3366,13 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_manywheel_3_7m_cu101_devtoolset7_nightly_build + - binary_linux_manywheel_3_9m_cu101_devtoolset7_nightly_build docker_image: "pytorch/manylinux-cuda101" use_cuda_docker_runtime: "1" resource_class: gpu.medium - binary_linux_test: - name: binary_linux_manywheel_3_8m_cu101_devtoolset7_nightly_test - build_environment: "manywheel 3.8m cu101 devtoolset7" + name: binary_linux_manywheel_3_6m_cu102_devtoolset7_nightly_test + build_environment: "manywheel 3.6m cu102 devtoolset7" filters: branches: only: @@ -3323,13 +3381,13 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_manywheel_3_8m_cu101_devtoolset7_nightly_build - docker_image: "pytorch/manylinux-cuda101" + - binary_linux_manywheel_3_6m_cu102_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-cuda102" use_cuda_docker_runtime: "1" resource_class: gpu.medium - binary_linux_test: - name: binary_linux_manywheel_3_6m_cu102_devtoolset7_nightly_test - build_environment: "manywheel 3.6m cu102 devtoolset7" + name: binary_linux_manywheel_3_7m_cu102_devtoolset7_nightly_test + build_environment: "manywheel 3.7m cu102 devtoolset7" filters: branches: only: @@ -3338,13 +3396,13 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_manywheel_3_6m_cu102_devtoolset7_nightly_build + - binary_linux_manywheel_3_7m_cu102_devtoolset7_nightly_build docker_image: "pytorch/manylinux-cuda102" use_cuda_docker_runtime: "1" resource_class: gpu.medium - binary_linux_test: - name: binary_linux_manywheel_3_7m_cu102_devtoolset7_nightly_test - build_environment: "manywheel 3.7m cu102 devtoolset7" + name: binary_linux_manywheel_3_8m_cu102_devtoolset7_nightly_test + build_environment: "manywheel 3.8m cu102 devtoolset7" filters: branches: only: @@ -3353,13 +3411,13 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_manywheel_3_7m_cu102_devtoolset7_nightly_build + - binary_linux_manywheel_3_8m_cu102_devtoolset7_nightly_build docker_image: "pytorch/manylinux-cuda102" use_cuda_docker_runtime: "1" resource_class: gpu.medium - binary_linux_test: - name: binary_linux_manywheel_3_8m_cu102_devtoolset7_nightly_test - build_environment: "manywheel 3.8m cu102 devtoolset7" + name: binary_linux_manywheel_3_9m_cu102_devtoolset7_nightly_test + build_environment: "manywheel 3.9m cu102 devtoolset7" filters: branches: only: @@ -3368,7 +3426,7 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_manywheel_3_8m_cu102_devtoolset7_nightly_build + - binary_linux_manywheel_3_9m_cu102_devtoolset7_nightly_build docker_image: "pytorch/manylinux-cuda102" use_cuda_docker_runtime: "1" resource_class: gpu.medium @@ -3418,8 +3476,8 @@ workflows: use_cuda_docker_runtime: "1" resource_class: gpu.medium - binary_linux_test: - name: binary_linux_manywheel_3_6m_rocm3_7_devtoolset7_nightly_test - build_environment: "manywheel 3.6m rocm3.7 devtoolset7" + name: binary_linux_manywheel_3_9m_cu110_devtoolset7_nightly_test + build_environment: "manywheel 3.9m cu110 devtoolset7" filters: branches: only: @@ -3428,13 +3486,13 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_manywheel_3_6m_rocm3_7_devtoolset7_nightly_build - docker_image: "pytorch/manylinux-rocm:3.7" + - binary_linux_manywheel_3_9m_cu110_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-cuda110" use_cuda_docker_runtime: "1" resource_class: gpu.medium - binary_linux_test: - name: binary_linux_manywheel_3_7m_rocm3_7_devtoolset7_nightly_test - build_environment: "manywheel 3.7m rocm3.7 devtoolset7" + name: binary_linux_manywheel_3_6m_rocm3_10_devtoolset7_nightly_test + build_environment: "manywheel 3.6m rocm3.10 devtoolset7" filters: branches: only: @@ -3443,13 +3501,13 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_manywheel_3_7m_rocm3_7_devtoolset7_nightly_build - docker_image: "pytorch/manylinux-rocm:3.7" + - binary_linux_manywheel_3_6m_rocm3_10_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-rocm:3.10" use_cuda_docker_runtime: "1" resource_class: gpu.medium - binary_linux_test: - name: binary_linux_manywheel_3_8m_rocm3_7_devtoolset7_nightly_test - build_environment: "manywheel 3.8m rocm3.7 devtoolset7" + name: binary_linux_manywheel_3_7m_rocm3_10_devtoolset7_nightly_test + build_environment: "manywheel 3.7m rocm3.10 devtoolset7" filters: branches: only: @@ -3458,13 +3516,13 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_manywheel_3_8m_rocm3_7_devtoolset7_nightly_build - docker_image: "pytorch/manylinux-rocm:3.7" + - binary_linux_manywheel_3_7m_rocm3_10_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-rocm:3.10" use_cuda_docker_runtime: "1" resource_class: gpu.medium - binary_linux_test: - name: binary_linux_manywheel_3_6m_rocm3_8_devtoolset7_nightly_test - build_environment: "manywheel 3.6m rocm3.8 devtoolset7" + name: binary_linux_manywheel_3_8m_rocm3_10_devtoolset7_nightly_test + build_environment: "manywheel 3.8m rocm3.10 devtoolset7" filters: branches: only: @@ -3473,13 +3531,13 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_manywheel_3_6m_rocm3_8_devtoolset7_nightly_build - docker_image: "pytorch/manylinux-rocm:3.8" + - binary_linux_manywheel_3_8m_rocm3_10_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-rocm:3.10" use_cuda_docker_runtime: "1" resource_class: gpu.medium - binary_linux_test: - name: binary_linux_manywheel_3_7m_rocm3_8_devtoolset7_nightly_test - build_environment: "manywheel 3.7m rocm3.8 devtoolset7" + name: binary_linux_manywheel_3_9m_rocm3_10_devtoolset7_nightly_test + build_environment: "manywheel 3.9m rocm3.10 devtoolset7" filters: branches: only: @@ -3488,13 +3546,13 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_manywheel_3_7m_rocm3_8_devtoolset7_nightly_build - docker_image: "pytorch/manylinux-rocm:3.8" + - binary_linux_manywheel_3_9m_rocm3_10_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-rocm:3.10" use_cuda_docker_runtime: "1" resource_class: gpu.medium - binary_linux_test: - name: binary_linux_manywheel_3_8m_rocm3_8_devtoolset7_nightly_test - build_environment: "manywheel 3.8m rocm3.8 devtoolset7" + name: binary_linux_manywheel_3_6m_rocm4_0_devtoolset7_nightly_test + build_environment: "manywheel 3.6m rocm4.0 devtoolset7" filters: branches: only: @@ -3503,13 +3561,13 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_manywheel_3_8m_rocm3_8_devtoolset7_nightly_build - docker_image: "pytorch/manylinux-rocm:3.8" + - binary_linux_manywheel_3_6m_rocm4_0_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-rocm:4.0" use_cuda_docker_runtime: "1" resource_class: gpu.medium - binary_linux_test: - name: binary_linux_conda_3_6_cpu_devtoolset7_nightly_test - build_environment: "conda 3.6 cpu devtoolset7" + name: binary_linux_manywheel_3_7m_rocm4_0_devtoolset7_nightly_test + build_environment: "manywheel 3.7m rocm4.0 devtoolset7" filters: branches: only: @@ -3518,11 +3576,13 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_conda_3_6_cpu_devtoolset7_nightly_build - docker_image: "pytorch/conda-cuda" + - binary_linux_manywheel_3_7m_rocm4_0_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-rocm:4.0" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium - binary_linux_test: - name: binary_linux_conda_3_7_cpu_devtoolset7_nightly_test - build_environment: "conda 3.7 cpu devtoolset7" + name: binary_linux_manywheel_3_8m_rocm4_0_devtoolset7_nightly_test + build_environment: "manywheel 3.8m rocm4.0 devtoolset7" filters: branches: only: @@ -3531,11 +3591,13 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_conda_3_7_cpu_devtoolset7_nightly_build - docker_image: "pytorch/conda-cuda" + - binary_linux_manywheel_3_8m_rocm4_0_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-rocm:4.0" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium - binary_linux_test: - name: binary_linux_conda_3_8_cpu_devtoolset7_nightly_test - build_environment: "conda 3.8 cpu devtoolset7" + name: binary_linux_manywheel_3_9m_rocm4_0_devtoolset7_nightly_test + build_environment: "manywheel 3.9m rocm4.0 devtoolset7" filters: branches: only: @@ -3544,11 +3606,26 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_conda_3_8_cpu_devtoolset7_nightly_build + - binary_linux_manywheel_3_9m_rocm4_0_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-rocm:4.0" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_conda_3_6_cpu_devtoolset7_nightly_test + build_environment: "conda 3.6 cpu devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_linux_conda_3_6_cpu_devtoolset7_nightly_build docker_image: "pytorch/conda-cuda" - binary_linux_test: - name: binary_linux_conda_3_6_cu92_devtoolset7_nightly_test - build_environment: "conda 3.6 cu92 devtoolset7" + name: binary_linux_conda_3_7_cpu_devtoolset7_nightly_test + build_environment: "conda 3.7 cpu devtoolset7" filters: branches: only: @@ -3557,13 +3634,11 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_conda_3_6_cu92_devtoolset7_nightly_build + - binary_linux_conda_3_7_cpu_devtoolset7_nightly_build docker_image: "pytorch/conda-cuda" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - binary_linux_test: - name: binary_linux_conda_3_7_cu92_devtoolset7_nightly_test - build_environment: "conda 3.7 cu92 devtoolset7" + name: binary_linux_conda_3_8_cpu_devtoolset7_nightly_test + build_environment: "conda 3.8 cpu devtoolset7" filters: branches: only: @@ -3572,13 +3647,11 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_conda_3_7_cu92_devtoolset7_nightly_build + - binary_linux_conda_3_8_cpu_devtoolset7_nightly_build docker_image: "pytorch/conda-cuda" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - binary_linux_test: - name: binary_linux_conda_3_8_cu92_devtoolset7_nightly_test - build_environment: "conda 3.8 cu92 devtoolset7" + name: binary_linux_conda_3_9_cpu_devtoolset7_nightly_test + build_environment: "conda 3.9 cpu devtoolset7" filters: branches: only: @@ -3587,10 +3660,8 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - - binary_linux_conda_3_8_cu92_devtoolset7_nightly_build + - binary_linux_conda_3_9_cpu_devtoolset7_nightly_build docker_image: "pytorch/conda-cuda" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - binary_linux_test: name: binary_linux_conda_3_6_cu101_devtoolset7_nightly_test build_environment: "conda 3.6 cu101 devtoolset7" @@ -3636,6 +3707,21 @@ workflows: docker_image: "pytorch/conda-cuda" use_cuda_docker_runtime: "1" resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_conda_3_9_cu101_devtoolset7_nightly_test + build_environment: "conda 3.9 cu101 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_linux_conda_3_9_cu101_devtoolset7_nightly_build + docker_image: "pytorch/conda-cuda" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium - binary_linux_test: name: binary_linux_conda_3_6_cu102_devtoolset7_nightly_test build_environment: "conda 3.6 cu102 devtoolset7" @@ -3681,6 +3767,21 @@ workflows: docker_image: "pytorch/conda-cuda" use_cuda_docker_runtime: "1" resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_conda_3_9_cu102_devtoolset7_nightly_test + build_environment: "conda 3.9 cu102 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_linux_conda_3_9_cu102_devtoolset7_nightly_build + docker_image: "pytorch/conda-cuda" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium - binary_linux_test: name: binary_linux_conda_3_6_cu110_devtoolset7_nightly_test build_environment: "conda 3.6 cu110 devtoolset7" @@ -3726,6 +3827,21 @@ workflows: docker_image: "pytorch/conda-cuda" use_cuda_docker_runtime: "1" resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_conda_3_9_cu110_devtoolset7_nightly_test + build_environment: "conda 3.9 cu110 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_linux_conda_3_9_cu110_devtoolset7_nightly_build + docker_image: "pytorch/conda-cuda" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium - binary_linux_test: name: binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_shared-with-deps_test build_environment: "libtorch 3.7m cpu devtoolset7" @@ -3782,70 +3898,6 @@ workflows: requires: - binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_static-without-deps_build docker_image: "pytorch/manylinux-cuda102" - - binary_linux_test: - name: binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_shared-with-deps_test - build_environment: "libtorch 3.7m cu92 devtoolset7" - filters: - branches: - only: - - /.*/ - tags: - only: - - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - libtorch_variant: "shared-with-deps" - requires: - - binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_shared-with-deps_build - docker_image: "pytorch/manylinux-cuda92" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - binary_linux_test: - name: binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_shared-without-deps_test - build_environment: "libtorch 3.7m cu92 devtoolset7" - filters: - branches: - only: - - /.*/ - tags: - only: - - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - libtorch_variant: "shared-without-deps" - requires: - - binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_shared-without-deps_build - docker_image: "pytorch/manylinux-cuda92" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - binary_linux_test: - name: binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_static-with-deps_test - build_environment: "libtorch 3.7m cu92 devtoolset7" - filters: - branches: - only: - - /.*/ - tags: - only: - - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - libtorch_variant: "static-with-deps" - requires: - - binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_static-with-deps_build - docker_image: "pytorch/manylinux-cuda92" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - binary_linux_test: - name: binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_static-without-deps_test - build_environment: "libtorch 3.7m cu92 devtoolset7" - filters: - branches: - only: - - /.*/ - tags: - only: - - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - libtorch_variant: "static-without-deps" - requires: - - binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_static-without-deps_build - docker_image: "pytorch/manylinux-cuda92" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - binary_linux_test: name: binary_linux_libtorch_3_7m_cu101_devtoolset7_nightly_shared-with-deps_test build_environment: "libtorch 3.7m cu101 devtoolset7" @@ -4094,70 +4146,6 @@ workflows: requires: - binary_linux_libtorch_3_7m_cpu_gcc5_4_cxx11-abi_nightly_static-without-deps_build docker_image: "pytorch/pytorch-binary-docker-image-ubuntu16.04:latest" - - binary_linux_test: - name: binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_shared-with-deps_test - build_environment: "libtorch 3.7m cu92 gcc5.4_cxx11-abi" - filters: - branches: - only: - - /.*/ - tags: - only: - - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - libtorch_variant: "shared-with-deps" - requires: - - binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_shared-with-deps_build - docker_image: "pytorch/pytorch-binary-docker-image-ubuntu16.04:latest" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - binary_linux_test: - name: binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_shared-without-deps_test - build_environment: "libtorch 3.7m cu92 gcc5.4_cxx11-abi" - filters: - branches: - only: - - /.*/ - tags: - only: - - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - libtorch_variant: "shared-without-deps" - requires: - - binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_shared-without-deps_build - docker_image: "pytorch/pytorch-binary-docker-image-ubuntu16.04:latest" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - binary_linux_test: - name: binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_static-with-deps_test - build_environment: "libtorch 3.7m cu92 gcc5.4_cxx11-abi" - filters: - branches: - only: - - /.*/ - tags: - only: - - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - libtorch_variant: "static-with-deps" - requires: - - binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_static-with-deps_build - docker_image: "pytorch/pytorch-binary-docker-image-ubuntu16.04:latest" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - binary_linux_test: - name: binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_static-without-deps_test - build_environment: "libtorch 3.7m cu92 gcc5.4_cxx11-abi" - filters: - branches: - only: - - /.*/ - tags: - only: - - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - libtorch_variant: "static-without-deps" - requires: - - binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_static-without-deps_build - docker_image: "pytorch/pytorch-binary-docker-image-ubuntu16.04:latest" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - binary_linux_test: name: binary_linux_libtorch_3_7m_cu101_gcc5_4_cxx11-abi_nightly_shared-with-deps_test build_environment: "libtorch 3.7m cu101 gcc5.4_cxx11-abi" @@ -4386,6 +4374,18 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - binary_windows_wheel_3_8_cpu_nightly_build + - binary_windows_test: + name: binary_windows_wheel_3_9_cpu_nightly_test + build_environment: "wheel 3.9 cpu" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_windows_wheel_3_9_cpu_nightly_build - binary_windows_test: name: binary_windows_wheel_3_6_cu101_nightly_test build_environment: "wheel 3.6 cu101" @@ -4425,6 +4425,19 @@ workflows: requires: - binary_windows_wheel_3_8_cu101_nightly_build executor: windows-with-nvidia-gpu + - binary_windows_test: + name: binary_windows_wheel_3_9_cu101_nightly_test + build_environment: "wheel 3.9 cu101" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_windows_wheel_3_9_cu101_nightly_build + executor: windows-with-nvidia-gpu - binary_windows_test: name: binary_windows_wheel_3_6_cu102_nightly_test build_environment: "wheel 3.6 cu102" @@ -4464,6 +4477,19 @@ workflows: requires: - binary_windows_wheel_3_8_cu102_nightly_build executor: windows-with-nvidia-gpu + - binary_windows_test: + name: binary_windows_wheel_3_9_cu102_nightly_test + build_environment: "wheel 3.9 cu102" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_windows_wheel_3_9_cu102_nightly_build + executor: windows-with-nvidia-gpu - binary_windows_test: name: binary_windows_wheel_3_6_cu110_nightly_test build_environment: "wheel 3.6 cu110" @@ -4503,6 +4529,19 @@ workflows: requires: - binary_windows_wheel_3_8_cu110_nightly_build executor: windows-with-nvidia-gpu + - binary_windows_test: + name: binary_windows_wheel_3_9_cu110_nightly_test + build_environment: "wheel 3.9 cu110" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_windows_wheel_3_9_cu110_nightly_build + executor: windows-with-nvidia-gpu - binary_windows_test: name: binary_windows_conda_3_6_cpu_nightly_test build_environment: "conda 3.6 cpu" @@ -4539,6 +4578,18 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - binary_windows_conda_3_8_cpu_nightly_build + - binary_windows_test: + name: binary_windows_conda_3_9_cpu_nightly_test + build_environment: "conda 3.9 cpu" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_windows_conda_3_9_cpu_nightly_build - binary_windows_test: name: binary_windows_conda_3_6_cu101_nightly_test build_environment: "conda 3.6 cu101" @@ -4578,6 +4629,19 @@ workflows: requires: - binary_windows_conda_3_8_cu101_nightly_build executor: windows-with-nvidia-gpu + - binary_windows_test: + name: binary_windows_conda_3_9_cu101_nightly_test + build_environment: "conda 3.9 cu101" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_windows_conda_3_9_cu101_nightly_build + executor: windows-with-nvidia-gpu - binary_windows_test: name: binary_windows_conda_3_6_cu102_nightly_test build_environment: "conda 3.6 cu102" @@ -4617,6 +4681,19 @@ workflows: requires: - binary_windows_conda_3_8_cu102_nightly_build executor: windows-with-nvidia-gpu + - binary_windows_test: + name: binary_windows_conda_3_9_cu102_nightly_test + build_environment: "conda 3.9 cu102" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_windows_conda_3_9_cu102_nightly_build + executor: windows-with-nvidia-gpu - binary_windows_test: name: binary_windows_conda_3_6_cu110_nightly_test build_environment: "conda 3.6 cu110" @@ -4656,6 +4733,19 @@ workflows: requires: - binary_windows_conda_3_8_cu110_nightly_build executor: windows-with-nvidia-gpu + - binary_windows_test: + name: binary_windows_conda_3_9_cu110_nightly_test + build_environment: "conda 3.9 cu110" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_windows_conda_3_9_cu110_nightly_build + executor: windows-with-nvidia-gpu - binary_windows_test: name: binary_windows_libtorch_3_7_cpu_debug_nightly_test build_environment: "libtorch 3.7 cpu debug" @@ -4801,10 +4891,10 @@ workflows: package_type: manywheel upload_subfolder: cpu - binary_upload: - name: binary_linux_manywheel_3_6m_cu92_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_9m_cpu_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_manywheel_3_6m_cu92_devtoolset7_nightly_test + - binary_linux_manywheel_3_9m_cpu_devtoolset7_nightly_test filters: branches: only: @@ -4813,12 +4903,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: manywheel - upload_subfolder: cu92 + upload_subfolder: cpu - binary_upload: - name: binary_linux_manywheel_3_7m_cu92_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_6m_cu101_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_manywheel_3_7m_cu92_devtoolset7_nightly_test + - binary_linux_manywheel_3_6m_cu101_devtoolset7_nightly_test filters: branches: only: @@ -4827,12 +4917,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: manywheel - upload_subfolder: cu92 + upload_subfolder: cu101 - binary_upload: - name: binary_linux_manywheel_3_8m_cu92_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_7m_cu101_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_manywheel_3_8m_cu92_devtoolset7_nightly_test + - binary_linux_manywheel_3_7m_cu101_devtoolset7_nightly_test filters: branches: only: @@ -4841,12 +4931,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: manywheel - upload_subfolder: cu92 + upload_subfolder: cu101 - binary_upload: - name: binary_linux_manywheel_3_6m_cu101_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_8m_cu101_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_manywheel_3_6m_cu101_devtoolset7_nightly_test + - binary_linux_manywheel_3_8m_cu101_devtoolset7_nightly_test filters: branches: only: @@ -4857,10 +4947,10 @@ workflows: package_type: manywheel upload_subfolder: cu101 - binary_upload: - name: binary_linux_manywheel_3_7m_cu101_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_9m_cu101_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_manywheel_3_7m_cu101_devtoolset7_nightly_test + - binary_linux_manywheel_3_9m_cu101_devtoolset7_nightly_test filters: branches: only: @@ -4871,10 +4961,10 @@ workflows: package_type: manywheel upload_subfolder: cu101 - binary_upload: - name: binary_linux_manywheel_3_8m_cu101_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_6m_cu102_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_manywheel_3_8m_cu101_devtoolset7_nightly_test + - binary_linux_manywheel_3_6m_cu102_devtoolset7_nightly_test filters: branches: only: @@ -4883,12 +4973,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: manywheel - upload_subfolder: cu101 + upload_subfolder: cu102 - binary_upload: - name: binary_linux_manywheel_3_6m_cu102_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_7m_cu102_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_manywheel_3_6m_cu102_devtoolset7_nightly_test + - binary_linux_manywheel_3_7m_cu102_devtoolset7_nightly_test filters: branches: only: @@ -4899,10 +4989,10 @@ workflows: package_type: manywheel upload_subfolder: cu102 - binary_upload: - name: binary_linux_manywheel_3_7m_cu102_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_8m_cu102_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_manywheel_3_7m_cu102_devtoolset7_nightly_test + - binary_linux_manywheel_3_8m_cu102_devtoolset7_nightly_test filters: branches: only: @@ -4913,10 +5003,10 @@ workflows: package_type: manywheel upload_subfolder: cu102 - binary_upload: - name: binary_linux_manywheel_3_8m_cu102_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_9m_cu102_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_manywheel_3_8m_cu102_devtoolset7_nightly_test + - binary_linux_manywheel_3_9m_cu102_devtoolset7_nightly_test filters: branches: only: @@ -4969,10 +5059,10 @@ workflows: package_type: manywheel upload_subfolder: cu110 - binary_upload: - name: binary_linux_manywheel_3_6m_rocm3_7_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_9m_cu110_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_manywheel_3_6m_rocm3_7_devtoolset7_nightly_test + - binary_linux_manywheel_3_9m_cu110_devtoolset7_nightly_test filters: branches: only: @@ -4981,12 +5071,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: manywheel - upload_subfolder: rocm3.7 + upload_subfolder: cu110 - binary_upload: - name: binary_linux_manywheel_3_7m_rocm3_7_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_6m_rocm3_10_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_manywheel_3_7m_rocm3_7_devtoolset7_nightly_test + - binary_linux_manywheel_3_6m_rocm3_10_devtoolset7_nightly_test filters: branches: only: @@ -4995,12 +5085,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: manywheel - upload_subfolder: rocm3.7 + upload_subfolder: rocm3.10 - binary_upload: - name: binary_linux_manywheel_3_8m_rocm3_7_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_7m_rocm3_10_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_manywheel_3_8m_rocm3_7_devtoolset7_nightly_test + - binary_linux_manywheel_3_7m_rocm3_10_devtoolset7_nightly_test filters: branches: only: @@ -5009,12 +5099,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: manywheel - upload_subfolder: rocm3.7 + upload_subfolder: rocm3.10 - binary_upload: - name: binary_linux_manywheel_3_6m_rocm3_8_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_8m_rocm3_10_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_manywheel_3_6m_rocm3_8_devtoolset7_nightly_test + - binary_linux_manywheel_3_8m_rocm3_10_devtoolset7_nightly_test filters: branches: only: @@ -5023,12 +5113,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: manywheel - upload_subfolder: rocm3.8 + upload_subfolder: rocm3.10 - binary_upload: - name: binary_linux_manywheel_3_7m_rocm3_8_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_9m_rocm3_10_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_manywheel_3_7m_rocm3_8_devtoolset7_nightly_test + - binary_linux_manywheel_3_9m_rocm3_10_devtoolset7_nightly_test filters: branches: only: @@ -5037,12 +5127,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: manywheel - upload_subfolder: rocm3.8 + upload_subfolder: rocm3.10 - binary_upload: - name: binary_linux_manywheel_3_8m_rocm3_8_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_6m_rocm4_0_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_manywheel_3_8m_rocm3_8_devtoolset7_nightly_test + - binary_linux_manywheel_3_6m_rocm4_0_devtoolset7_nightly_test filters: branches: only: @@ -5051,12 +5141,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: manywheel - upload_subfolder: rocm3.8 + upload_subfolder: rocm4.0 - binary_upload: - name: binary_linux_conda_3_6_cpu_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_7m_rocm4_0_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_conda_3_6_cpu_devtoolset7_nightly_test + - binary_linux_manywheel_3_7m_rocm4_0_devtoolset7_nightly_test filters: branches: only: @@ -5064,13 +5154,13 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - package_type: conda - upload_subfolder: cpu + package_type: manywheel + upload_subfolder: rocm4.0 - binary_upload: - name: binary_linux_conda_3_7_cpu_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_8m_rocm4_0_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_conda_3_7_cpu_devtoolset7_nightly_test + - binary_linux_manywheel_3_8m_rocm4_0_devtoolset7_nightly_test filters: branches: only: @@ -5078,13 +5168,27 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - package_type: conda - upload_subfolder: cpu + package_type: manywheel + upload_subfolder: rocm4.0 - binary_upload: - name: binary_linux_conda_3_8_cpu_devtoolset7_nightly_upload + name: binary_linux_manywheel_3_9m_rocm4_0_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_conda_3_8_cpu_devtoolset7_nightly_test + - binary_linux_manywheel_3_9m_rocm4_0_devtoolset7_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: manywheel + upload_subfolder: rocm4.0 + - binary_upload: + name: binary_linux_conda_3_6_cpu_devtoolset7_nightly_upload + context: org-member + requires: + - binary_linux_conda_3_6_cpu_devtoolset7_nightly_test filters: branches: only: @@ -5095,10 +5199,10 @@ workflows: package_type: conda upload_subfolder: cpu - binary_upload: - name: binary_linux_conda_3_6_cu92_devtoolset7_nightly_upload + name: binary_linux_conda_3_7_cpu_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_conda_3_6_cu92_devtoolset7_nightly_test + - binary_linux_conda_3_7_cpu_devtoolset7_nightly_test filters: branches: only: @@ -5107,12 +5211,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: conda - upload_subfolder: cu92 + upload_subfolder: cpu - binary_upload: - name: binary_linux_conda_3_7_cu92_devtoolset7_nightly_upload + name: binary_linux_conda_3_8_cpu_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_conda_3_7_cu92_devtoolset7_nightly_test + - binary_linux_conda_3_8_cpu_devtoolset7_nightly_test filters: branches: only: @@ -5121,12 +5225,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: conda - upload_subfolder: cu92 + upload_subfolder: cpu - binary_upload: - name: binary_linux_conda_3_8_cu92_devtoolset7_nightly_upload + name: binary_linux_conda_3_9_cpu_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_conda_3_8_cu92_devtoolset7_nightly_test + - binary_linux_conda_3_9_cpu_devtoolset7_nightly_test filters: branches: only: @@ -5135,7 +5239,7 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: conda - upload_subfolder: cu92 + upload_subfolder: cpu - binary_upload: name: binary_linux_conda_3_6_cu101_devtoolset7_nightly_upload context: org-member @@ -5179,10 +5283,10 @@ workflows: package_type: conda upload_subfolder: cu101 - binary_upload: - name: binary_linux_conda_3_6_cu102_devtoolset7_nightly_upload + name: binary_linux_conda_3_9_cu101_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_conda_3_6_cu102_devtoolset7_nightly_test + - binary_linux_conda_3_9_cu101_devtoolset7_nightly_test filters: branches: only: @@ -5191,12 +5295,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: conda - upload_subfolder: cu102 + upload_subfolder: cu101 - binary_upload: - name: binary_linux_conda_3_7_cu102_devtoolset7_nightly_upload + name: binary_linux_conda_3_6_cu102_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_conda_3_7_cu102_devtoolset7_nightly_test + - binary_linux_conda_3_6_cu102_devtoolset7_nightly_test filters: branches: only: @@ -5207,10 +5311,10 @@ workflows: package_type: conda upload_subfolder: cu102 - binary_upload: - name: binary_linux_conda_3_8_cu102_devtoolset7_nightly_upload + name: binary_linux_conda_3_7_cu102_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_conda_3_8_cu102_devtoolset7_nightly_test + - binary_linux_conda_3_7_cu102_devtoolset7_nightly_test filters: branches: only: @@ -5221,10 +5325,10 @@ workflows: package_type: conda upload_subfolder: cu102 - binary_upload: - name: binary_linux_conda_3_6_cu110_devtoolset7_nightly_upload + name: binary_linux_conda_3_8_cu102_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_conda_3_6_cu110_devtoolset7_nightly_test + - binary_linux_conda_3_8_cu102_devtoolset7_nightly_test filters: branches: only: @@ -5233,12 +5337,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: conda - upload_subfolder: cu110 + upload_subfolder: cu102 - binary_upload: - name: binary_linux_conda_3_7_cu110_devtoolset7_nightly_upload + name: binary_linux_conda_3_9_cu102_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_conda_3_7_cu110_devtoolset7_nightly_test + - binary_linux_conda_3_9_cu102_devtoolset7_nightly_test filters: branches: only: @@ -5247,12 +5351,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: conda - upload_subfolder: cu110 + upload_subfolder: cu102 - binary_upload: - name: binary_linux_conda_3_8_cu110_devtoolset7_nightly_upload + name: binary_linux_conda_3_6_cu110_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_conda_3_8_cu110_devtoolset7_nightly_test + - binary_linux_conda_3_6_cu110_devtoolset7_nightly_test filters: branches: only: @@ -5263,24 +5367,10 @@ workflows: package_type: conda upload_subfolder: cu110 - binary_upload: - name: binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_shared-with-deps_upload - context: org-member - requires: - - binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_shared-with-deps_test - filters: - branches: - only: - - nightly - tags: - only: - - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - package_type: libtorch - upload_subfolder: cpu - - binary_upload: - name: binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_shared-without-deps_upload + name: binary_linux_conda_3_7_cu110_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_shared-without-deps_test + - binary_linux_conda_3_7_cu110_devtoolset7_nightly_test filters: branches: only: @@ -5288,13 +5378,13 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - package_type: libtorch - upload_subfolder: cpu + package_type: conda + upload_subfolder: cu110 - binary_upload: - name: binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_static-with-deps_upload + name: binary_linux_conda_3_8_cu110_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_static-with-deps_test + - binary_linux_conda_3_8_cu110_devtoolset7_nightly_test filters: branches: only: @@ -5302,13 +5392,13 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - package_type: libtorch - upload_subfolder: cpu + package_type: conda + upload_subfolder: cu110 - binary_upload: - name: binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_static-without-deps_upload + name: binary_linux_conda_3_9_cu110_devtoolset7_nightly_upload context: org-member requires: - - binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_static-without-deps_test + - binary_linux_conda_3_9_cu110_devtoolset7_nightly_test filters: branches: only: @@ -5316,13 +5406,13 @@ workflows: tags: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - package_type: libtorch - upload_subfolder: cpu + package_type: conda + upload_subfolder: cu110 - binary_upload: - name: binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_shared-with-deps_upload + name: binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_shared-with-deps_upload context: org-member requires: - - binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_shared-with-deps_test + - binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_shared-with-deps_test filters: branches: only: @@ -5331,12 +5421,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: libtorch - upload_subfolder: cu92 + upload_subfolder: cpu - binary_upload: - name: binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_shared-without-deps_upload + name: binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_shared-without-deps_upload context: org-member requires: - - binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_shared-without-deps_test + - binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_shared-without-deps_test filters: branches: only: @@ -5345,12 +5435,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: libtorch - upload_subfolder: cu92 + upload_subfolder: cpu - binary_upload: - name: binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_static-with-deps_upload + name: binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_static-with-deps_upload context: org-member requires: - - binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_static-with-deps_test + - binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_static-with-deps_test filters: branches: only: @@ -5359,12 +5449,12 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: libtorch - upload_subfolder: cu92 + upload_subfolder: cpu - binary_upload: - name: binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_static-without-deps_upload + name: binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_static-without-deps_upload context: org-member requires: - - binary_linux_libtorch_3_7m_cu92_devtoolset7_nightly_static-without-deps_test + - binary_linux_libtorch_3_7m_cpu_devtoolset7_nightly_static-without-deps_test filters: branches: only: @@ -5373,7 +5463,7 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: libtorch - upload_subfolder: cu92 + upload_subfolder: cpu - binary_upload: name: binary_linux_libtorch_3_7m_cu101_devtoolset7_nightly_shared-with-deps_upload context: org-member @@ -5598,62 +5688,6 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: libtorch upload_subfolder: cpu - - binary_upload: - name: binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_shared-with-deps_upload - context: org-member - requires: - - binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_shared-with-deps_test - filters: - branches: - only: - - nightly - tags: - only: - - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - package_type: libtorch - upload_subfolder: cu92 - - binary_upload: - name: binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_shared-without-deps_upload - context: org-member - requires: - - binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_shared-without-deps_test - filters: - branches: - only: - - nightly - tags: - only: - - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - package_type: libtorch - upload_subfolder: cu92 - - binary_upload: - name: binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_static-with-deps_upload - context: org-member - requires: - - binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_static-with-deps_test - filters: - branches: - only: - - nightly - tags: - only: - - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - package_type: libtorch - upload_subfolder: cu92 - - binary_upload: - name: binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_static-without-deps_upload - context: org-member - requires: - - binary_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_static-without-deps_test - filters: - branches: - only: - - nightly - tags: - only: - - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - package_type: libtorch - upload_subfolder: cu92 - binary_upload: name: binary_linux_libtorch_3_7m_cu101_gcc5_4_cxx11-abi_nightly_shared-with-deps_upload context: org-member @@ -5864,6 +5898,20 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: wheel upload_subfolder: cpu + - binary_upload: + name: binary_macos_wheel_3_9_cpu_nightly_upload + context: org-member + requires: + - binary_macos_wheel_3_9_cpu_nightly_build + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: wheel + upload_subfolder: cpu - binary_upload: name: binary_macos_conda_3_6_cpu_nightly_upload context: org-member @@ -5906,6 +5954,20 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: conda upload_subfolder: cpu + - binary_upload: + name: binary_macos_conda_3_9_cpu_nightly_upload + context: org-member + requires: + - binary_macos_conda_3_9_cpu_nightly_build + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: conda + upload_subfolder: cpu - binary_upload: name: binary_macos_libtorch_3_7_cpu_nightly_upload context: org-member @@ -5962,6 +6024,20 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: wheel upload_subfolder: cpu + - binary_upload: + name: binary_windows_wheel_3_9_cpu_nightly_upload + context: org-member + requires: + - binary_windows_wheel_3_9_cpu_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: wheel + upload_subfolder: cpu - binary_upload: name: binary_windows_wheel_3_6_cu101_nightly_upload context: org-member @@ -6004,6 +6080,20 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: wheel upload_subfolder: cu101 + - binary_upload: + name: binary_windows_wheel_3_9_cu101_nightly_upload + context: org-member + requires: + - binary_windows_wheel_3_9_cu101_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: wheel + upload_subfolder: cu101 - binary_upload: name: binary_windows_wheel_3_6_cu102_nightly_upload context: org-member @@ -6046,6 +6136,20 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: wheel upload_subfolder: cu102 + - binary_upload: + name: binary_windows_wheel_3_9_cu102_nightly_upload + context: org-member + requires: + - binary_windows_wheel_3_9_cu102_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: wheel + upload_subfolder: cu102 - binary_upload: name: binary_windows_wheel_3_6_cu110_nightly_upload context: org-member @@ -6088,6 +6192,20 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: wheel upload_subfolder: cu110 + - binary_upload: + name: binary_windows_wheel_3_9_cu110_nightly_upload + context: org-member + requires: + - binary_windows_wheel_3_9_cu110_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: wheel + upload_subfolder: cu110 - binary_upload: name: binary_windows_conda_3_6_cpu_nightly_upload context: org-member @@ -6130,6 +6248,20 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: conda upload_subfolder: cpu + - binary_upload: + name: binary_windows_conda_3_9_cpu_nightly_upload + context: org-member + requires: + - binary_windows_conda_3_9_cpu_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: conda + upload_subfolder: cpu - binary_upload: name: binary_windows_conda_3_6_cu101_nightly_upload context: org-member @@ -6172,6 +6304,20 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: conda upload_subfolder: cu101 + - binary_upload: + name: binary_windows_conda_3_9_cu101_nightly_upload + context: org-member + requires: + - binary_windows_conda_3_9_cu101_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: conda + upload_subfolder: cu101 - binary_upload: name: binary_windows_conda_3_6_cu102_nightly_upload context: org-member @@ -6214,6 +6360,20 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: conda upload_subfolder: cu102 + - binary_upload: + name: binary_windows_conda_3_9_cu102_nightly_upload + context: org-member + requires: + - binary_windows_conda_3_9_cu102_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: conda + upload_subfolder: cu102 - binary_upload: name: binary_windows_conda_3_6_cu110_nightly_upload context: org-member @@ -6256,6 +6416,20 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: conda upload_subfolder: cu110 + - binary_upload: + name: binary_windows_conda_3_9_cu110_nightly_upload + context: org-member + requires: + - binary_windows_conda_3_9_cu110_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: conda + upload_subfolder: cu110 - binary_upload: name: binary_windows_libtorch_3_7_cpu_debug_nightly_upload context: org-member @@ -6371,6 +6545,12 @@ workflows: when: << pipeline.parameters.run_binary_tests >> build: jobs: + - docker_build_job: + name: "docker-pytorch-linux-bionic-cuda11.1-cudnn8-py3.6-gcc9" + image_name: "pytorch-linux-bionic-cuda11.1-cudnn8-py3.6-gcc9" + - docker_build_job: + name: "docker-pytorch-linux-bionic-cuda11.1-cudnn8-py3.8-gcc9" + image_name: "pytorch-linux-bionic-cuda11.1-cudnn8-py3.8-gcc9" - docker_build_job: name: "docker-pytorch-linux-bionic-cuda11.0-cudnn8-py3.6-gcc9" image_name: "pytorch-linux-bionic-cuda11.0-cudnn8-py3.6-gcc9" @@ -6389,9 +6569,6 @@ workflows: - docker_build_job: name: "docker-pytorch-linux-bionic-py3.8-gcc9" image_name: "pytorch-linux-bionic-py3.8-gcc9" - - docker_build_job: - name: "docker-pytorch-linux-bionic-rocm3.5.1-py3.6" - image_name: "pytorch-linux-bionic-rocm3.5.1-py3.6" - docker_build_job: name: "docker-pytorch-linux-xenial-cuda10-cudnn7-py3-gcc7" image_name: "pytorch-linux-xenial-cuda10-cudnn7-py3-gcc7" @@ -6404,6 +6581,9 @@ workflows: - docker_build_job: name: "docker-pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" image_name: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" + - docker_build_job: + name: "docker-pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7" + image_name: "pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7" - docker_build_job: name: "docker-pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc5.4" image_name: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc5.4" @@ -6425,12 +6605,14 @@ workflows: - docker_build_job: name: "docker-pytorch-linux-xenial-py3.6-clang7" image_name: "pytorch-linux-xenial-py3.6-clang7" - - docker_build_job: - name: "docker-pytorch-linux-xenial-py3.6-gcc4.8" - image_name: "pytorch-linux-xenial-py3.6-gcc4.8" - docker_build_job: name: "docker-pytorch-linux-xenial-py3.6-gcc5.4" image_name: "pytorch-linux-xenial-py3.6-gcc5.4" + filters: + branches: + only: /.*/ + tags: + only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - docker_build_job: name: "docker-pytorch-linux-xenial-py3.6-gcc7.2" image_name: "pytorch-linux-xenial-py3.6-gcc7.2" @@ -6438,17 +6620,22 @@ workflows: name: "docker-pytorch-linux-xenial-py3.6-gcc7" image_name: "pytorch-linux-xenial-py3.6-gcc7" - docker_build_job: - name: "docker-pytorch-linux-bionic-rocm3.7-py3.6" - image_name: "pytorch-linux-bionic-rocm3.7-py3.6" + name: "docker-pytorch-linux-bionic-rocm3.9-py3.6" + image_name: "pytorch-linux-bionic-rocm3.9-py3.6" - docker_build_job: - name: "docker-pytorch-linux-bionic-rocm3.8-py3.6" - image_name: "pytorch-linux-bionic-rocm3.8-py3.6" + name: "docker-pytorch-linux-bionic-rocm3.10-py3.6" + image_name: "pytorch-linux-bionic-rocm3.10-py3.6" - pytorch_linux_build: name: pytorch_linux_xenial_py3_6_gcc5_4_build requires: - "docker-pytorch-linux-xenial-py3.6-gcc5.4" build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4" + filters: + branches: + only: /.*/ + tags: + only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - pytorch_linux_test: name: pytorch_linux_xenial_py3_6_gcc5_4_test requires: @@ -6456,7 +6643,17 @@ workflows: build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-test" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4" resource_class: large + filters: + branches: + only: /.*/ + tags: + only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ - pytorch_python_doc_build: + filters: + branches: + only: /.*/ + tags: + only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - pytorch_linux_xenial_py3_6_gcc5_4_build - pytorch_doc_push: @@ -6466,10 +6663,17 @@ workflows: branches: only: - nightly + tags: + only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ name: pytorch_python_doc_push requires: - pytorch_python_doc_build - pytorch_cpp_doc_build: + filters: + branches: + only: /.*/ + tags: + only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: - pytorch_linux_xenial_py3_6_gcc5_4_build - pytorch_doc_push: @@ -6479,6 +6683,8 @@ workflows: branches: only: - nightly + tags: + only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ name: pytorch_cpp_doc_push requires: - pytorch_cpp_doc_build @@ -6670,10 +6876,18 @@ workflows: build_environment: "pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7" - pytorch_linux_test: - name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test + name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test1 + requires: + - pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build + build_environment: "pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test1" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - pytorch_linux_test: + name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test2 requires: - pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build - build_environment: "pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test" + build_environment: "pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test2" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7" use_cuda_docker_runtime: "1" resource_class: gpu.medium @@ -6743,60 +6957,38 @@ workflows: - /release\/.*/ build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7" - - pytorch_linux_test: - name: pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test - requires: - - pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build - filters: - branches: - only: - - master - - /ci-all\/.*/ - - /release\/.*/ - build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - pytorch_linux_build: - name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build + name: pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build requires: - - "docker-pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" + - "docker-pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7" filters: branches: only: - master - /ci-all\/.*/ - /release\/.*/ - build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" + build_environment: "pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7-build" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7" - pytorch_linux_test: - name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test + name: pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_test requires: - - pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build + - pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build filters: branches: only: - master - /ci-all\/.*/ - /release\/.*/ - build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" + build_environment: "pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7-test" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7" use_cuda_docker_runtime: "1" resource_class: gpu.medium - pytorch_linux_build: - name: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build + name: pytorch_libtorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build requires: - - "docker-pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" - build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" - - pytorch_linux_test: - name: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test - requires: - - pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build - build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium + - "docker-pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7" + build_environment: "pytorch-libtorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7-build" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7" - pytorch_linux_build: name: pytorch_linux_bionic_py3_6_clang9_build requires: @@ -6837,24 +7029,31 @@ workflows: docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.6-clang9" resource_class: large - pytorch_linux_build: - name: pytorch_linux_bionic_py3_8_gcc9_build + name: pytorch_linux_bionic_py3_8_gcc9_coverage_build requires: - "docker-pytorch-linux-bionic-py3.8-gcc9" - build_environment: "pytorch-linux-bionic-py3.8-gcc9-build" + build_environment: "pytorch-linux-bionic-py3.8-gcc9-coverage-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.8-gcc9" - pytorch_linux_test: - name: pytorch_linux_bionic_py3_8_gcc9_coverage_test + name: pytorch_linux_bionic_py3_8_gcc9_coverage_test1 + requires: + - pytorch_linux_bionic_py3_8_gcc9_coverage_build + build_environment: "pytorch-linux-bionic-py3.8-gcc9-coverage-test1" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.8-gcc9" + resource_class: large + - pytorch_linux_test: + name: pytorch_linux_bionic_py3_8_gcc9_coverage_test2 requires: - - pytorch_linux_bionic_py3_8_gcc9_build - build_environment: "pytorch-linux-bionic-py3.8-gcc9-coverage_test" + - pytorch_linux_bionic_py3_8_gcc9_coverage_build + build_environment: "pytorch-linux-bionic-py3.8-gcc9-coverage-test2" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.8-gcc9" resource_class: large - pytorch_linux_build: - name: pytorch_linux_bionic_rocm3_7_py3_6_build + name: pytorch_linux_bionic_rocm3_9_py3_6_build requires: - - "docker-pytorch-linux-bionic-rocm3.7-py3.6" - build_environment: "pytorch-linux-bionic-rocm3.7-py3.6-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-rocm3.7-py3.6" + - "docker-pytorch-linux-bionic-rocm3.9-py3.6" + build_environment: "pytorch-linux-bionic-rocm3.9-py3.6-build" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-rocm3.9-py3.6" resource_class: xlarge - pytorch_macos_10_13_py3_build: name: pytorch_macos_10_13_py3_build @@ -6942,10 +7141,30 @@ workflows: - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v7a_build - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v8a_build - pytorch_ios_build: - build_environment: pytorch-ios-11.2.1-x86_64_build + build_environment: pytorch-ios-12.0.0-x86_64_build ios_arch: x86_64 ios_platform: SIMULATOR - name: pytorch_ios_11_2_1_x86_64_build + name: pytorch_ios_12_0_0_x86_64_build + - pytorch_ios_build: + build_environment: pytorch-ios-12.0.0-arm64_build + context: org-member + ios_arch: arm64 + ios_platform: OS + name: pytorch_ios_12_0_0_arm64_build + - pytorch_ios_build: + build_environment: pytorch-ios-12.0.0-arm64_metal_build + context: org-member + ios_arch: arm64 + ios_platform: OS + name: pytorch_ios_12_0_0_arm64_metal_build + use_metal: "1" + - pytorch_ios_build: + build_environment: pytorch-ios-12.0.0-arm64_custom_build + context: org-member + ios_arch: arm64 + ios_platform: OS + name: pytorch_ios_12_0_0_arm64_custom_build + op_list: mobilenetv2.yaml - pytorch_linux_build: build_environment: pytorch-linux-xenial-py3-clang5-mobile-build build_only: "1" @@ -6974,38 +7193,16 @@ workflows: requires: - docker-pytorch-linux-xenial-py3-clang5-android-ndk-r19c - pytorch_linux_test: - build_environment: pytorch-linux-xenial-py3.6-gcc5.4-ge_config_legacy-test - docker_image: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4 - name: pytorch_linux_xenial_py3_6_gcc5_4_ge_config_legacy_test - requires: - - pytorch_linux_xenial_py3_6_gcc5_4_build - resource_class: large - - pytorch_linux_test: - build_environment: pytorch-linux-xenial-py3.6-gcc5.4-ge_config_profiling-test - docker_image: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4 - name: pytorch_linux_xenial_py3_6_gcc5_4_ge_config_profiling_test - requires: - - pytorch_linux_xenial_py3_6_gcc5_4_build - resource_class: large - - pytorch_linux_test: - build_environment: pytorch-linux-xenial-py3.6-gcc5.4-ge_config_simple-test + build_environment: pytorch-linux-xenial-py3.6-gcc5.4-jit_legacy-test docker_image: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4 - name: pytorch_linux_xenial_py3_6_gcc5_4_ge_config_simple_test + name: pytorch_linux_xenial_py3_6_gcc5_4_jit_legacy_test requires: - pytorch_linux_xenial_py3_6_gcc5_4_build resource_class: large - pytorch_linux_test: - build_environment: pytorch-linux-xenial-cuda10.1-cudnn7-ge_config_legacy-test - docker_image: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7 - name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_ge_config_legacy_test - requires: - - pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build - resource_class: gpu.medium - use_cuda_docker_runtime: "1" - - pytorch_linux_test: - build_environment: pytorch-linux-xenial-cuda10.1-cudnn7-ge_config_profiling-test + build_environment: pytorch-linux-xenial-cuda10.2-cudnn7-py3-jit_legacy-test docker_image: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7 - name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_ge_config_profiling_test + name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_jit_legacy_test requires: - pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build resource_class: gpu.medium @@ -7127,23 +7324,32 @@ workflows: requires: - binary_linux_libtorch_3_7m_cpu_gcc5_4_cxx11-abi_shared-with-deps_build - binary_ios_build: - build_environment: libtorch-ios-11.2.1-nightly-x86_64-build + build_environment: libtorch-ios-12.0.0-nightly-x86_64-build context: org-member filters: branches: only: nightly ios_arch: x86_64 ios_platform: SIMULATOR - name: pytorch_ios_11_2_1_nightly_x86_64_build + name: pytorch_ios_12_0_0_nightly_x86_64_build - binary_ios_build: - build_environment: libtorch-ios-11.2.1-nightly-arm64-build + build_environment: libtorch-ios-12.0.0-nightly-arm64-build context: org-member filters: branches: only: nightly ios_arch: arm64 ios_platform: OS - name: pytorch_ios_11_2_1_nightly_arm64_build + name: pytorch_ios_12_0_0_nightly_arm64_build + - binary_ios_upload: + build_environment: libtorch-ios-12.0.0-nightly-binary-build-upload + context: org-member + filters: + branches: + only: nightly + requires: + - pytorch_ios_12_0_0_nightly_x86_64_build + - pytorch_ios_12_0_0_nightly_arm64_build - pytorch_linux_build: build_environment: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_32 docker_image: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c @@ -7218,7 +7424,7 @@ workflows: - postnightly - pytorch_windows_build: build_environment: pytorch-win-vs2019-cuda10-cudnn7-py3 - cuda_version: "10" + cuda_version: "10.1" name: pytorch_windows_vs2019_py36_cuda10.1_build python_version: "3.6" use_cuda: "1" @@ -7227,7 +7433,7 @@ workflows: vc_year: "2019" - pytorch_windows_test: build_environment: pytorch-win-vs2019-cuda10-cudnn7-py3 - cuda_version: "10" + cuda_version: "10.1" executor: windows-with-nvidia-gpu name: pytorch_windows_vs2019_py36_cuda10.1_test1 python_version: "3.6" @@ -7240,7 +7446,7 @@ workflows: vc_year: "2019" - pytorch_windows_test: build_environment: pytorch-win-vs2019-cuda10-cudnn7-py3 - cuda_version: "10" + cuda_version: "10.1" executor: windows-with-nvidia-gpu name: pytorch_windows_vs2019_py36_cuda10.1_test2 python_version: "3.6" @@ -7253,8 +7459,8 @@ workflows: vc_year: "2019" - pytorch_windows_build: build_environment: pytorch-win-vs2019-cuda11-cudnn8-py3 - cuda_version: "11" - name: pytorch_windows_vs2019_py36_cuda11.0_build + cuda_version: "11.1" + name: pytorch_windows_vs2019_py36_cuda11.1_build python_version: "3.6" use_cuda: "1" vc_product: Community @@ -7262,7 +7468,7 @@ workflows: vc_year: "2019" - pytorch_windows_test: build_environment: pytorch-win-vs2019-cuda11-cudnn8-py3 - cuda_version: "11" + cuda_version: "11.1" executor: windows-with-nvidia-gpu filters: branches: @@ -7270,10 +7476,10 @@ workflows: - master - /ci-all\/.*/ - /release\/.*/ - name: pytorch_windows_vs2019_py36_cuda11.0_test1 + name: pytorch_windows_vs2019_py36_cuda11.1_test1 python_version: "3.6" requires: - - pytorch_windows_vs2019_py36_cuda11.0_build + - pytorch_windows_vs2019_py36_cuda11.1_build test_name: pytorch-windows-test1 use_cuda: "1" vc_product: Community @@ -7281,7 +7487,7 @@ workflows: vc_year: "2019" - pytorch_windows_test: build_environment: pytorch-win-vs2019-cuda11-cudnn8-py3 - cuda_version: "11" + cuda_version: "11.1" executor: windows-with-nvidia-gpu filters: branches: @@ -7289,10 +7495,10 @@ workflows: - master - /ci-all\/.*/ - /release\/.*/ - name: pytorch_windows_vs2019_py36_cuda11.0_test2 + name: pytorch_windows_vs2019_py36_cuda11.1_test2 python_version: "3.6" requires: - - pytorch_windows_vs2019_py36_cuda11.0_build + - pytorch_windows_vs2019_py36_cuda11.1_build test_name: pytorch-windows-test2 use_cuda: "1" vc_product: Community @@ -7345,7 +7551,7 @@ workflows: vc_year: "2019" - pytorch_windows_test: build_environment: pytorch-win-vs2019-cuda10-cudnn7-py3 - cuda_version: "10" + cuda_version: "10.1" filters: branches: only: @@ -7399,44 +7605,42 @@ workflows: - postnightly docker_image: "pytorch/manylinux-cuda102" - smoke_linux_test: - name: smoke_linux_manywheel_3_6m_cu92_devtoolset7_nightly - build_environment: "manywheel 3.6m cu92 devtoolset7" + name: smoke_linux_manywheel_3_9m_cpu_devtoolset7_nightly + build_environment: "manywheel 3.9m cpu devtoolset7" requires: - update_s3_htmls filters: branches: only: - postnightly - docker_image: "pytorch/manylinux-cuda92" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium + docker_image: "pytorch/manylinux-cuda102" - smoke_linux_test: - name: smoke_linux_manywheel_3_7m_cu92_devtoolset7_nightly - build_environment: "manywheel 3.7m cu92 devtoolset7" + name: smoke_linux_manywheel_3_6m_cu101_devtoolset7_nightly + build_environment: "manywheel 3.6m cu101 devtoolset7" requires: - update_s3_htmls filters: branches: only: - postnightly - docker_image: "pytorch/manylinux-cuda92" + docker_image: "pytorch/manylinux-cuda101" use_cuda_docker_runtime: "1" resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_manywheel_3_8m_cu92_devtoolset7_nightly - build_environment: "manywheel 3.8m cu92 devtoolset7" + name: smoke_linux_manywheel_3_7m_cu101_devtoolset7_nightly + build_environment: "manywheel 3.7m cu101 devtoolset7" requires: - update_s3_htmls filters: branches: only: - postnightly - docker_image: "pytorch/manylinux-cuda92" + docker_image: "pytorch/manylinux-cuda101" use_cuda_docker_runtime: "1" resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_manywheel_3_6m_cu101_devtoolset7_nightly - build_environment: "manywheel 3.6m cu101 devtoolset7" + name: smoke_linux_manywheel_3_8m_cu101_devtoolset7_nightly + build_environment: "manywheel 3.8m cu101 devtoolset7" requires: - update_s3_htmls filters: @@ -7447,8 +7651,8 @@ workflows: use_cuda_docker_runtime: "1" resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_manywheel_3_7m_cu101_devtoolset7_nightly - build_environment: "manywheel 3.7m cu101 devtoolset7" + name: smoke_linux_manywheel_3_9m_cu101_devtoolset7_nightly + build_environment: "manywheel 3.9m cu101 devtoolset7" requires: - update_s3_htmls filters: @@ -7459,20 +7663,20 @@ workflows: use_cuda_docker_runtime: "1" resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_manywheel_3_8m_cu101_devtoolset7_nightly - build_environment: "manywheel 3.8m cu101 devtoolset7" + name: smoke_linux_manywheel_3_6m_cu102_devtoolset7_nightly + build_environment: "manywheel 3.6m cu102 devtoolset7" requires: - update_s3_htmls filters: branches: only: - postnightly - docker_image: "pytorch/manylinux-cuda101" + docker_image: "pytorch/manylinux-cuda102" use_cuda_docker_runtime: "1" resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_manywheel_3_6m_cu102_devtoolset7_nightly - build_environment: "manywheel 3.6m cu102 devtoolset7" + name: smoke_linux_manywheel_3_7m_cu102_devtoolset7_nightly + build_environment: "manywheel 3.7m cu102 devtoolset7" requires: - update_s3_htmls filters: @@ -7483,8 +7687,8 @@ workflows: use_cuda_docker_runtime: "1" resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_manywheel_3_7m_cu102_devtoolset7_nightly - build_environment: "manywheel 3.7m cu102 devtoolset7" + name: smoke_linux_manywheel_3_8m_cu102_devtoolset7_nightly + build_environment: "manywheel 3.8m cu102 devtoolset7" requires: - update_s3_htmls filters: @@ -7495,8 +7699,8 @@ workflows: use_cuda_docker_runtime: "1" resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_manywheel_3_8m_cu102_devtoolset7_nightly - build_environment: "manywheel 3.8m cu102 devtoolset7" + name: smoke_linux_manywheel_3_9m_cu102_devtoolset7_nightly + build_environment: "manywheel 3.9m cu102 devtoolset7" requires: - update_s3_htmls filters: @@ -7543,100 +7747,116 @@ workflows: use_cuda_docker_runtime: "1" resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_manywheel_3_6m_rocm3_7_devtoolset7_nightly - build_environment: "manywheel 3.6m rocm3.7 devtoolset7" + name: smoke_linux_manywheel_3_9m_cu110_devtoolset7_nightly + build_environment: "manywheel 3.9m cu110 devtoolset7" requires: - update_s3_htmls filters: branches: only: - postnightly - docker_image: "pytorch/manylinux-rocm:3.7" + docker_image: "pytorch/manylinux-cuda110" use_cuda_docker_runtime: "1" resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_manywheel_3_7m_rocm3_7_devtoolset7_nightly - build_environment: "manywheel 3.7m rocm3.7 devtoolset7" + name: smoke_linux_manywheel_3_6m_rocm3_10_devtoolset7_nightly + build_environment: "manywheel 3.6m rocm3.10 devtoolset7" requires: - update_s3_htmls filters: branches: only: - postnightly - docker_image: "pytorch/manylinux-rocm:3.7" + docker_image: "pytorch/manylinux-rocm:3.10" use_cuda_docker_runtime: "1" resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_manywheel_3_8m_rocm3_7_devtoolset7_nightly - build_environment: "manywheel 3.8m rocm3.7 devtoolset7" + name: smoke_linux_manywheel_3_7m_rocm3_10_devtoolset7_nightly + build_environment: "manywheel 3.7m rocm3.10 devtoolset7" requires: - update_s3_htmls filters: branches: only: - postnightly - docker_image: "pytorch/manylinux-rocm:3.7" + docker_image: "pytorch/manylinux-rocm:3.10" use_cuda_docker_runtime: "1" resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_manywheel_3_6m_rocm3_8_devtoolset7_nightly - build_environment: "manywheel 3.6m rocm3.8 devtoolset7" + name: smoke_linux_manywheel_3_8m_rocm3_10_devtoolset7_nightly + build_environment: "manywheel 3.8m rocm3.10 devtoolset7" requires: - update_s3_htmls filters: branches: only: - postnightly - docker_image: "pytorch/manylinux-rocm:3.8" + docker_image: "pytorch/manylinux-rocm:3.10" use_cuda_docker_runtime: "1" resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_manywheel_3_7m_rocm3_8_devtoolset7_nightly - build_environment: "manywheel 3.7m rocm3.8 devtoolset7" + name: smoke_linux_manywheel_3_9m_rocm3_10_devtoolset7_nightly + build_environment: "manywheel 3.9m rocm3.10 devtoolset7" requires: - update_s3_htmls filters: branches: only: - postnightly - docker_image: "pytorch/manylinux-rocm:3.8" + docker_image: "pytorch/manylinux-rocm:3.10" use_cuda_docker_runtime: "1" resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_manywheel_3_8m_rocm3_8_devtoolset7_nightly - build_environment: "manywheel 3.8m rocm3.8 devtoolset7" + name: smoke_linux_manywheel_3_6m_rocm4_0_devtoolset7_nightly + build_environment: "manywheel 3.6m rocm4.0 devtoolset7" requires: - update_s3_htmls filters: branches: only: - postnightly - docker_image: "pytorch/manylinux-rocm:3.8" + docker_image: "pytorch/manylinux-rocm:4.0" use_cuda_docker_runtime: "1" resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_conda_3_6_cpu_devtoolset7_nightly - build_environment: "conda 3.6 cpu devtoolset7" + name: smoke_linux_manywheel_3_7m_rocm4_0_devtoolset7_nightly + build_environment: "manywheel 3.7m rocm4.0 devtoolset7" requires: - update_s3_htmls filters: branches: only: - postnightly - docker_image: "pytorch/conda-cuda" + docker_image: "pytorch/manylinux-rocm:4.0" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_conda_3_7_cpu_devtoolset7_nightly - build_environment: "conda 3.7 cpu devtoolset7" + name: smoke_linux_manywheel_3_8m_rocm4_0_devtoolset7_nightly + build_environment: "manywheel 3.8m rocm4.0 devtoolset7" requires: - update_s3_htmls filters: branches: only: - postnightly - docker_image: "pytorch/conda-cuda" + docker_image: "pytorch/manylinux-rocm:4.0" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_conda_3_8_cpu_devtoolset7_nightly - build_environment: "conda 3.8 cpu devtoolset7" + name: smoke_linux_manywheel_3_9m_rocm4_0_devtoolset7_nightly + build_environment: "manywheel 3.9m rocm4.0 devtoolset7" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + docker_image: "pytorch/manylinux-rocm:4.0" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_conda_3_6_cpu_devtoolset7_nightly + build_environment: "conda 3.6 cpu devtoolset7" requires: - update_s3_htmls filters: @@ -7645,8 +7865,8 @@ workflows: - postnightly docker_image: "pytorch/conda-cuda" - smoke_linux_test: - name: smoke_linux_conda_3_6_cu92_devtoolset7_nightly - build_environment: "conda 3.6 cu92 devtoolset7" + name: smoke_linux_conda_3_7_cpu_devtoolset7_nightly + build_environment: "conda 3.7 cpu devtoolset7" requires: - update_s3_htmls filters: @@ -7654,11 +7874,9 @@ workflows: only: - postnightly docker_image: "pytorch/conda-cuda" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_conda_3_7_cu92_devtoolset7_nightly - build_environment: "conda 3.7 cu92 devtoolset7" + name: smoke_linux_conda_3_8_cpu_devtoolset7_nightly + build_environment: "conda 3.8 cpu devtoolset7" requires: - update_s3_htmls filters: @@ -7666,11 +7884,9 @@ workflows: only: - postnightly docker_image: "pytorch/conda-cuda" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - smoke_linux_test: - name: smoke_linux_conda_3_8_cu92_devtoolset7_nightly - build_environment: "conda 3.8 cu92 devtoolset7" + name: smoke_linux_conda_3_9_cpu_devtoolset7_nightly + build_environment: "conda 3.9 cpu devtoolset7" requires: - update_s3_htmls filters: @@ -7678,8 +7894,6 @@ workflows: only: - postnightly docker_image: "pytorch/conda-cuda" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - smoke_linux_test: name: smoke_linux_conda_3_6_cu101_devtoolset7_nightly build_environment: "conda 3.6 cu101 devtoolset7" @@ -7716,6 +7930,18 @@ workflows: docker_image: "pytorch/conda-cuda" use_cuda_docker_runtime: "1" resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_conda_3_9_cu101_devtoolset7_nightly + build_environment: "conda 3.9 cu101 devtoolset7" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + docker_image: "pytorch/conda-cuda" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium - smoke_linux_test: name: smoke_linux_conda_3_6_cu102_devtoolset7_nightly build_environment: "conda 3.6 cu102 devtoolset7" @@ -7752,6 +7978,18 @@ workflows: docker_image: "pytorch/conda-cuda" use_cuda_docker_runtime: "1" resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_conda_3_9_cu102_devtoolset7_nightly + build_environment: "conda 3.9 cu102 devtoolset7" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + docker_image: "pytorch/conda-cuda" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium - smoke_linux_test: name: smoke_linux_conda_3_6_cu110_devtoolset7_nightly build_environment: "conda 3.6 cu110 devtoolset7" @@ -7788,6 +8026,18 @@ workflows: docker_image: "pytorch/conda-cuda" use_cuda_docker_runtime: "1" resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_conda_3_9_cu110_devtoolset7_nightly + build_environment: "conda 3.9 cu110 devtoolset7" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + docker_image: "pytorch/conda-cuda" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium - smoke_linux_test: name: smoke_linux_libtorch_3_7m_cpu_devtoolset7_nightly_shared-with-deps build_environment: "libtorch 3.7m cpu devtoolset7" @@ -7832,58 +8082,6 @@ workflows: - postnightly libtorch_variant: "static-without-deps" docker_image: "pytorch/manylinux-cuda102" - - smoke_linux_test: - name: smoke_linux_libtorch_3_7m_cu92_devtoolset7_nightly_shared-with-deps - build_environment: "libtorch 3.7m cu92 devtoolset7" - requires: - - update_s3_htmls - filters: - branches: - only: - - postnightly - libtorch_variant: "shared-with-deps" - docker_image: "pytorch/manylinux-cuda92" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - smoke_linux_test: - name: smoke_linux_libtorch_3_7m_cu92_devtoolset7_nightly_shared-without-deps - build_environment: "libtorch 3.7m cu92 devtoolset7" - requires: - - update_s3_htmls - filters: - branches: - only: - - postnightly - libtorch_variant: "shared-without-deps" - docker_image: "pytorch/manylinux-cuda92" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - smoke_linux_test: - name: smoke_linux_libtorch_3_7m_cu92_devtoolset7_nightly_static-with-deps - build_environment: "libtorch 3.7m cu92 devtoolset7" - requires: - - update_s3_htmls - filters: - branches: - only: - - postnightly - libtorch_variant: "static-with-deps" - docker_image: "pytorch/manylinux-cuda92" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - smoke_linux_test: - name: smoke_linux_libtorch_3_7m_cu92_devtoolset7_nightly_static-without-deps - build_environment: "libtorch 3.7m cu92 devtoolset7" - requires: - - update_s3_htmls - filters: - branches: - only: - - postnightly - libtorch_variant: "static-without-deps" - docker_image: "pytorch/manylinux-cuda92" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - smoke_linux_test: name: smoke_linux_libtorch_3_7m_cu101_devtoolset7_nightly_shared-with-deps build_environment: "libtorch 3.7m cu101 devtoolset7" @@ -8084,58 +8282,6 @@ workflows: - postnightly libtorch_variant: "static-without-deps" docker_image: "pytorch/pytorch-binary-docker-image-ubuntu16.04:latest" - - smoke_linux_test: - name: smoke_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_shared-with-deps - build_environment: "libtorch 3.7m cu92 gcc5.4_cxx11-abi" - requires: - - update_s3_htmls - filters: - branches: - only: - - postnightly - libtorch_variant: "shared-with-deps" - docker_image: "pytorch/pytorch-binary-docker-image-ubuntu16.04:latest" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - smoke_linux_test: - name: smoke_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_shared-without-deps - build_environment: "libtorch 3.7m cu92 gcc5.4_cxx11-abi" - requires: - - update_s3_htmls - filters: - branches: - only: - - postnightly - libtorch_variant: "shared-without-deps" - docker_image: "pytorch/pytorch-binary-docker-image-ubuntu16.04:latest" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - smoke_linux_test: - name: smoke_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_static-with-deps - build_environment: "libtorch 3.7m cu92 gcc5.4_cxx11-abi" - requires: - - update_s3_htmls - filters: - branches: - only: - - postnightly - libtorch_variant: "static-with-deps" - docker_image: "pytorch/pytorch-binary-docker-image-ubuntu16.04:latest" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - smoke_linux_test: - name: smoke_linux_libtorch_3_7m_cu92_gcc5_4_cxx11-abi_nightly_static-without-deps - build_environment: "libtorch 3.7m cu92 gcc5.4_cxx11-abi" - requires: - - update_s3_htmls - filters: - branches: - only: - - postnightly - libtorch_variant: "static-without-deps" - docker_image: "pytorch/pytorch-binary-docker-image-ubuntu16.04:latest" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - smoke_linux_test: name: smoke_linux_libtorch_3_7m_cu101_gcc5_4_cxx11-abi_nightly_shared-with-deps build_environment: "libtorch 3.7m cu101 gcc5.4_cxx11-abi" @@ -8319,6 +8465,15 @@ workflows: branches: only: - postnightly + - smoke_mac_test: + name: smoke_macos_wheel_3_9_cpu_nightly + build_environment: "wheel 3.9 cpu" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly - smoke_mac_test: name: smoke_macos_conda_3_6_cpu_nightly build_environment: "conda 3.6 cpu" @@ -8346,6 +8501,15 @@ workflows: branches: only: - postnightly + - smoke_mac_test: + name: smoke_macos_conda_3_9_cpu_nightly + build_environment: "conda 3.9 cpu" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly - smoke_mac_test: name: smoke_macos_libtorch_3_7_cpu_nightly build_environment: "libtorch 3.7 cpu" @@ -8382,6 +8546,15 @@ workflows: branches: only: - postnightly + - smoke_windows_test: + name: smoke_windows_wheel_3_9_cpu_nightly + build_environment: "wheel 3.9 cpu" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly - smoke_windows_test: name: smoke_windows_wheel_3_6_cu101_nightly build_environment: "wheel 3.6 cu101" @@ -8412,6 +8585,16 @@ workflows: only: - postnightly executor: windows-with-nvidia-gpu + - smoke_windows_test: + name: smoke_windows_wheel_3_9_cu101_nightly + build_environment: "wheel 3.9 cu101" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + executor: windows-with-nvidia-gpu - smoke_windows_test: name: smoke_windows_wheel_3_6_cu102_nightly build_environment: "wheel 3.6 cu102" @@ -8442,6 +8625,16 @@ workflows: only: - postnightly executor: windows-with-nvidia-gpu + - smoke_windows_test: + name: smoke_windows_wheel_3_9_cu102_nightly + build_environment: "wheel 3.9 cu102" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + executor: windows-with-nvidia-gpu - smoke_windows_test: name: smoke_windows_wheel_3_6_cu110_nightly build_environment: "wheel 3.6 cu110" @@ -8472,6 +8665,16 @@ workflows: only: - postnightly executor: windows-with-nvidia-gpu + - smoke_windows_test: + name: smoke_windows_wheel_3_9_cu110_nightly + build_environment: "wheel 3.9 cu110" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + executor: windows-with-nvidia-gpu - smoke_windows_test: name: smoke_windows_conda_3_6_cpu_nightly build_environment: "conda 3.6 cpu" @@ -8499,6 +8702,15 @@ workflows: branches: only: - postnightly + - smoke_windows_test: + name: smoke_windows_conda_3_9_cpu_nightly + build_environment: "conda 3.9 cpu" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly - smoke_windows_test: name: smoke_windows_conda_3_6_cu101_nightly build_environment: "conda 3.6 cu101" @@ -8529,6 +8741,16 @@ workflows: only: - postnightly executor: windows-with-nvidia-gpu + - smoke_windows_test: + name: smoke_windows_conda_3_9_cu101_nightly + build_environment: "conda 3.9 cu101" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + executor: windows-with-nvidia-gpu - smoke_windows_test: name: smoke_windows_conda_3_6_cu102_nightly build_environment: "conda 3.6 cu102" @@ -8559,6 +8781,16 @@ workflows: only: - postnightly executor: windows-with-nvidia-gpu + - smoke_windows_test: + name: smoke_windows_conda_3_9_cu102_nightly + build_environment: "conda 3.9 cu102" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + executor: windows-with-nvidia-gpu - smoke_windows_test: name: smoke_windows_conda_3_6_cu110_nightly build_environment: "conda 3.6 cu110" @@ -8589,6 +8821,16 @@ workflows: only: - postnightly executor: windows-with-nvidia-gpu + - smoke_windows_test: + name: smoke_windows_conda_3_9_cu110_nightly + build_environment: "conda 3.9 cu110" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + executor: windows-with-nvidia-gpu - smoke_windows_test: name: smoke_windows_libtorch_3_7_cpu_debug_nightly build_environment: "libtorch 3.7 cpu debug" @@ -8667,6 +8909,7 @@ workflows: only: - postnightly executor: windows-with-nvidia-gpu + when: << pipeline.parameters.run_build >> ecr_gc: triggers: - schedule: diff --git a/.circleci/docker/build.sh b/.circleci/docker/build.sh index 0afc1b33c59e9..e01ca37d471de 100755 --- a/.circleci/docker/build.sh +++ b/.circleci/docker/build.sh @@ -40,9 +40,7 @@ function extract_all_from_image_name() { done } -if [[ "$image" == *-trusty* ]]; then - UBUNTU_VERSION=14.04 -elif [[ "$image" == *-xenial* ]]; then +if [[ "$image" == *-xenial* ]]; then UBUNTU_VERSION=16.04 elif [[ "$image" == *-artful* ]]; then UBUNTU_VERSION=17.10 @@ -79,19 +77,10 @@ TRAVIS_DL_URL_PREFIX="https://s3.amazonaws.com/travis-python-archives/binaries/u # from scratch case "$image" in pytorch-linux-xenial-py3.8) - # TODO: This is a hack, get rid of this as soon as you get rid of the travis downloads - TRAVIS_DL_URL_PREFIX="https://s3.amazonaws.com/travis-python-archives/binaries/ubuntu/16.04/x86_64" - TRAVIS_PYTHON_VERSION=3.8 + ANACONDA_PYTHON_VERSION=3.8 GCC_VERSION=7 # Do not install PROTOBUF, DB, and VISION as a test ;; - pytorch-linux-xenial-py3.6-gcc4.8) - ANACONDA_PYTHON_VERSION=3.6 - GCC_VERSION=4.8 - PROTOBUF=yes - DB=yes - VISION=yes - ;; pytorch-linux-xenial-py3.6-gcc5.4) ANACONDA_PYTHON_VERSION=3.6 GCC_VERSION=5 @@ -169,6 +158,16 @@ case "$image" in VISION=yes KATEX=yes ;; + pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7) + CUDA_VERSION=11.1 + CUDNN_VERSION=8 + ANACONDA_PYTHON_VERSION=3.6 + GCC_VERSION=7 + PROTOBUF=yes + DB=yes + VISION=yes + KATEX=yes + ;; pytorch-linux-xenial-py3-clang5-asan) ANACONDA_PYTHON_VERSION=3.6 CLANG_VERSION=5.0 @@ -255,19 +254,39 @@ case "$image" in VISION=yes KATEX=yes ;; - pytorch-linux-bionic-rocm3.7-py3.6) + pytorch-linux-bionic-cuda11.1-cudnn8-py3.6-gcc9) + CUDA_VERSION=11.1 + CUDNN_VERSION=8 ANACONDA_PYTHON_VERSION=3.6 + GCC_VERSION=9 PROTOBUF=yes DB=yes VISION=yes - ROCM_VERSION=3.7 + KATEX=yes ;; - pytorch-linux-bionic-rocm3.8-py3.6) + pytorch-linux-bionic-cuda11.1-cudnn8-py3.8-gcc9) + CUDA_VERSION=11.1 + CUDNN_VERSION=8 + ANACONDA_PYTHON_VERSION=3.8 + GCC_VERSION=9 + PROTOBUF=yes + DB=yes + VISION=yes + KATEX=yes + ;; + pytorch-linux-bionic-rocm3.9-py3.6) ANACONDA_PYTHON_VERSION=3.6 PROTOBUF=yes DB=yes VISION=yes - ROCM_VERSION=3.8 + ROCM_VERSION=3.9 + ;; + pytorch-linux-bionic-rocm3.10-py3.6) + ANACONDA_PYTHON_VERSION=3.6 + PROTOBUF=yes + DB=yes + VISION=yes + ROCM_VERSION=3.10 ;; *) # Catch-all for builds that are not hardcoded. @@ -334,7 +353,6 @@ docker build \ --build-arg "GLIBC_VERSION=${GLIBC_VERSION}" \ --build-arg "CLANG_VERSION=${CLANG_VERSION}" \ --build-arg "ANACONDA_PYTHON_VERSION=${ANACONDA_PYTHON_VERSION}" \ - --build-arg "TRAVIS_PYTHON_VERSION=${TRAVIS_PYTHON_VERSION}" \ --build-arg "GCC_VERSION=${GCC_VERSION}" \ --build-arg "CUDA_VERSION=${CUDA_VERSION}" \ --build-arg "CUDNN_VERSION=${CUDNN_VERSION}" \ @@ -377,19 +395,6 @@ if [[ "$OS" == "ubuntu" ]]; then fi fi -if [ -n "$TRAVIS_PYTHON_VERSION" ]; then - if [[ "$TRAVIS_PYTHON_VERSION" != nightly ]]; then - if !(drun python --version 2>&1 | grep -qF "Python $TRAVIS_PYTHON_VERSION"); then - echo "TRAVIS_PYTHON_VERSION=$TRAVIS_PYTHON_VERSION, but:" - drun python --version - exit 1 - fi - else - echo "Please manually check nightly is OK:" - drun python --version - fi -fi - if [ -n "$ANACONDA_PYTHON_VERSION" ]; then if !(drun python --version 2>&1 | grep -qF "Python $ANACONDA_PYTHON_VERSION"); then echo "ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION, but:" diff --git a/.circleci/docker/centos-rocm/Dockerfile b/.circleci/docker/centos-rocm/Dockerfile index 1bc7b0deea322..dcaf2b0b50e05 100644 --- a/.circleci/docker/centos-rocm/Dockerfile +++ b/.circleci/docker/centos-rocm/Dockerfile @@ -27,7 +27,7 @@ RUN rm install_glibc.sh ADD ./common/install_user.sh install_user.sh RUN bash ./install_user.sh && rm install_user.sh -# Install conda +# Install conda and other packages (e.g., numpy, coverage, pytest) ENV PATH /opt/conda/bin:$PATH ARG ANACONDA_PYTHON_VERSION ADD ./common/install_conda.sh install_conda.sh @@ -64,7 +64,6 @@ ENV PATH /opt/rocm/hcc/bin:$PATH ENV PATH /opt/rocm/hip/bin:$PATH ENV PATH /opt/rocm/opencl/bin:$PATH ENV PATH /opt/rocm/llvm/bin:$PATH -ENV HIP_PLATFORM hcc ENV LANG en_US.utf8 ENV LC_ALL en_US.utf8 diff --git a/.circleci/docker/common/install_base.sh b/.circleci/docker/common/install_base.sh index 5e8173a436271..191b4732452d8 100755 --- a/.circleci/docker/common/install_base.sh +++ b/.circleci/docker/common/install_base.sh @@ -18,7 +18,6 @@ install_ubuntu() { # Install common dependencies apt-get update # TODO: Some of these may not be necessary - # TODO: libiomp also gets installed by conda, aka there's a conflict ccache_deps="asciidoc docbook-xml docbook-xsl xsltproc" numpy_deps="gfortran" apt-get install -y --no-install-recommends \ @@ -40,21 +39,11 @@ install_ubuntu() { libjpeg-dev \ libasound2-dev \ libsndfile-dev \ - python \ - python-dev \ - python-setuptools \ - python-wheel \ software-properties-common \ sudo \ wget \ vim - # TODO: THIS IS A HACK!!! - # distributed nccl(2) tests are a bit busted, see https://github.com/pytorch/pytorch/issues/5877 - if dpkg -s libnccl-dev; then - apt-get remove -y libnccl-dev libnccl2 --allow-change-held-packages - fi - # Cleanup package manager apt-get autoclean && apt-get clean rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* @@ -131,4 +120,3 @@ sudo make install cd ../../ rm -rf valgrind_build alias valgrind="/usr/local/bin/valgrind" - diff --git a/.circleci/docker/common/install_cache.sh b/.circleci/docker/common/install_cache.sh index f1066519cd70f..fc1630272472e 100644 --- a/.circleci/docker/common/install_cache.sh +++ b/.circleci/docker/common/install_cache.sh @@ -2,6 +2,28 @@ set -ex +install_ubuntu() { + echo "Preparing to build sccache from source" + apt-get update + apt-get install -y cargo pkg-config libssl-dev + echo "Checking out sccache repo" + git clone https://github.com/pytorch/sccache + cd sccache + echo "Building sccache" + cargo build --release + cp target/release/sccache /opt/cache/bin + echo "Cleaning up" + cd .. + rm -rf sccache + apt-get remove -y cargo rustc + apt-get autoclean && apt-get clean +} + +install_binary() { + echo "Downloading sccache binary from S3 repo" + curl --retry 3 https://s3.amazonaws.com/ossci-linux/sccache -o /opt/cache/bin/sccache +} + mkdir -p /opt/cache/bin mkdir -p /opt/cache/lib sed -e 's|PATH="\(.*\)"|PATH="/opt/cache/bin:\1"|g' -i /etc/environment @@ -11,12 +33,20 @@ export PATH="/opt/cache/bin:$PATH" if [ -n "$ROCM_VERSION" ]; then curl --retry 3 http://repo.radeon.com/misc/.sccache_amd/sccache -o /opt/cache/bin/sccache else - curl --retry 3 https://s3.amazonaws.com/ossci-linux/sccache -o /opt/cache/bin/sccache + ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') + case "$ID" in + ubuntu) + install_ubuntu + ;; + *) + install_binary + ;; + esac fi chmod a+x /opt/cache/bin/sccache function write_sccache_stub() { - printf "#!/bin/sh\nexec sccache $(which $1) \$*" > "/opt/cache/bin/$1" + printf "#!/bin/sh\nif [ \$(ps -p \$PPID -o comm=) != sccache ]; then\n exec sccache $(which $1) \"\$@\"\nelse\n exec $(which $1) \"\$@\"\nfi" > "/opt/cache/bin/$1" chmod a+x "/opt/cache/bin/$1" } @@ -38,8 +68,8 @@ if [ -n "$CUDA_VERSION" ]; then # where CUDA is installed. Instead, we install an nvcc symlink outside # of the PATH, and set CUDA_NVCC_EXECUTABLE so that we make use of it. - printf "#!/bin/sh\nexec sccache $(which nvcc) \"\$@\"" > /opt/cache/lib/nvcc - chmod a+x /opt/cache/lib/nvcc + write_sccache_stub nvcc + mv /opt/cache/bin/nvcc /opt/cache/lib/ fi if [ -n "$ROCM_VERSION" ]; then @@ -57,8 +87,8 @@ if [ -n "$ROCM_VERSION" ]; then TOPDIR=$(dirname $OLDCOMP) WRAPPED="$TOPDIR/original/$COMPNAME" mv "$OLDCOMP" "$WRAPPED" - printf "#!/bin/sh\nexec sccache $WRAPPED \$*" > "$OLDCOMP" - chmod a+x "$1" + printf "#!/bin/sh\nexec sccache $WRAPPED \"\$@\"" > "$OLDCOMP" + chmod a+x "$OLDCOMP" } if [[ -e "/opt/rocm/hcc/bin/hcc" ]]; then diff --git a/.circleci/docker/common/install_conda.sh b/.circleci/docker/common/install_conda.sh index 2a1c2bd0ea8fa..d54a743319e72 100755 --- a/.circleci/docker/common/install_conda.sh +++ b/.circleci/docker/common/install_conda.sh @@ -72,11 +72,13 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then # DO NOT install cmake here as it would install a version newer than 3.5, but # we want to pin to version 3.5. if [ "$ANACONDA_PYTHON_VERSION" = "3.8" ]; then - # DO NOT install typing if installing python-3.8, since its part of python-3.8 core packages # Install llvm-8 as it is required to compile llvmlite-0.30.0 from source - conda_install numpy=1.18.5 pyyaml mkl mkl-include setuptools cffi future six llvmdev=8.0.0 dataclasses + conda_install numpy=1.18.5 pyyaml mkl mkl-include setuptools cffi future six llvmdev=8.0.0 + elif [ "$ANACONDA_PYTHON_VERSION" = "3.7" ]; then + # DO NOT install dataclasses if installing python-3.7, since its part of python-3.7 core packages + conda_install numpy=1.18.5 pyyaml mkl mkl-include setuptools cffi future six else - conda_install numpy=1.18.5 pyyaml mkl mkl-include setuptools cffi typing future six dataclasses + conda_install numpy=1.18.5 pyyaml mkl mkl-include setuptools cffi future six dataclasses fi if [[ "$CUDA_VERSION" == 9.2* ]]; then conda_install magma-cuda92 -c pytorch @@ -88,18 +90,38 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then conda_install magma-cuda102 -c pytorch elif [[ "$CUDA_VERSION" == 11.0* ]]; then conda_install magma-cuda110 -c pytorch + elif [[ "$CUDA_VERSION" == 11.1* ]]; then + conda_install magma-cuda111 -c pytorch fi # TODO: This isn't working atm conda_install nnpack -c killeent - # Install some other packages + # Install some other packages, including those needed for Python test reporting # TODO: Why is scipy pinned - # numba & llvmlite is pinned because of https://github.com/numba/numba/issues/4368 - # scikit-learn is pinned because of - # https://github.com/scikit-learn/scikit-learn/issues/14485 (affects gcc 5.5 - # only) - as_jenkins pip install --progress-bar off pytest scipy==1.1.0 scikit-learn==0.20.3 scikit-image librosa>=0.6.2 psutil numba==0.46.0 llvmlite==0.30.0 + # Pin MyPy version because new errors are likely to appear with each release + # Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136 + as_jenkins pip install --progress-bar off pytest \ + scipy==1.1.0 \ + scikit-image \ + librosa>=0.6.2 \ + psutil \ + numba \ + llvmlite \ + unittest-xml-reporting \ + boto3==1.16.34 \ + coverage \ + hypothesis==4.53.2 \ + mypy==0.770 \ + tb-nightly + + # Update scikit-learn to a python-3.8 compatible version + if [[ $(python -c "import sys; print(int(sys.version_info >= (3, 8)))") == "1" ]]; then + as_jenkins pip install --progress-bar off -U scikit-learn + else + # Pinned scikit-learn due to https://github.com/scikit-learn/scikit-learn/issues/14485 (affects gcc 5.5 only) + as_jenkins pip install --progress-bar off scikit-learn==0.20.3 + fi popd fi diff --git a/.circleci/docker/common/install_gcc.sh b/.circleci/docker/common/install_gcc.sh index 48f17989f9788..0e86df1c778c5 100644 --- a/.circleci/docker/common/install_gcc.sh +++ b/.circleci/docker/common/install_gcc.sh @@ -15,6 +15,7 @@ if [ -n "$GCC_VERSION" ]; then update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-"$GCC_VERSION" 50 update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-"$GCC_VERSION" 50 + update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-"$GCC_VERSION" 50 # Cleanup package manager apt-get autoclean && apt-get clean diff --git a/.circleci/docker/common/install_lcov.sh b/.circleci/docker/common/install_lcov.sh new file mode 100644 index 0000000000000..b4364698318ab --- /dev/null +++ b/.circleci/docker/common/install_lcov.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +set -ex + +git clone --branch v1.15 https://github.com/linux-test-project/lcov.git +pushd lcov +sudo make install # will be installed in /usr/local/bin/lcov +popd diff --git a/.circleci/docker/common/install_nccl.sh b/.circleci/docker/common/install_nccl.sh new file mode 100644 index 0000000000000..594a227f61913 --- /dev/null +++ b/.circleci/docker/common/install_nccl.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +sudo apt-get -qq update +sudo apt-get -qq install --allow-downgrades --allow-change-held-packages libnccl-dev=2.5.6-1+cuda10.1 libnccl2=2.5.6-1+cuda10.1 diff --git a/.circleci/docker/common/install_openmpi.sh b/.circleci/docker/common/install_openmpi.sh new file mode 100644 index 0000000000000..7bd32c71f16fb --- /dev/null +++ b/.circleci/docker/common/install_openmpi.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +sudo apt-get update +sudo apt-get install -y --allow-downgrades --allow-change-held-packages openmpi-bin libopenmpi-dev diff --git a/.circleci/docker/common/install_rocm.sh b/.circleci/docker/common/install_rocm.sh index 4a60bd70e7798..ca79b15af2162 100644 --- a/.circleci/docker/common/install_rocm.sh +++ b/.circleci/docker/common/install_rocm.sh @@ -16,10 +16,9 @@ install_ubuntu() { apt-get install -y libc++1 apt-get install -y libc++abi1 - DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/${ROCM_VERSION} # Add rocm repository - wget -qO - $DEB_ROCM_REPO/rocm.gpg.key | apt-key add - - echo "deb [arch=amd64] $DEB_ROCM_REPO xenial main" > /etc/apt/sources.list.d/rocm.list + wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - + echo "deb [arch=amd64] http://repo.radeon.com/rocm/apt/${ROCM_VERSION} xenial main" > /etc/apt/sources.list.d/rocm.list apt-get update --allow-insecure-repositories DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ diff --git a/.circleci/docker/common/install_travis_python.sh b/.circleci/docker/common/install_travis_python.sh deleted file mode 100755 index 41ad2dd32eb43..0000000000000 --- a/.circleci/docker/common/install_travis_python.sh +++ /dev/null @@ -1,79 +0,0 @@ -#!/bin/bash - -set -ex - -as_jenkins() { - # NB: Preserve PATH and LD_LIBRARY_PATH changes - sudo -H -u jenkins env "PATH=$PATH" "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" $* -} - -if [ -n "$TRAVIS_PYTHON_VERSION" ]; then - - mkdir -p /opt/python - chown jenkins:jenkins /opt/python - - # Download Python binary from Travis - pushd tmp - as_jenkins wget --quiet ${TRAVIS_DL_URL_PREFIX}/python-$TRAVIS_PYTHON_VERSION.tar.bz2 - # NB: The tarball also comes with /home/travis virtualenv that we - # don't care about. (Maybe we should, but we've worked around the - # "how do I install to python" issue by making this entire directory - # user-writable "lol") - # NB: Relative ordering of opt/python and flags matters - as_jenkins tar xjf python-$TRAVIS_PYTHON_VERSION.tar.bz2 --strip-components=2 --directory /opt/python opt/python - popd - - echo "/opt/python/$TRAVIS_PYTHON_VERSION/lib" > /etc/ld.so.conf.d/travis-python.conf - ldconfig - sed -e 's|PATH="\(.*\)"|PATH="/opt/python/'"$TRAVIS_PYTHON_VERSION"'/bin:\1"|g' -i /etc/environment - export PATH="/opt/python/$TRAVIS_PYTHON_VERSION/bin:$PATH" - - python --version - pip --version - - # Install pip from source. - # The python-pip package on Ubuntu Trusty is old - # and upon install numpy doesn't use the binary - # distribution, and fails to compile it from source. - pushd tmp - as_jenkins curl -L -O https://pypi.python.org/packages/11/b6/abcb525026a4be042b486df43905d6893fb04f05aac21c32c638e939e447/pip-9.0.1.tar.gz - as_jenkins tar zxf pip-9.0.1.tar.gz - pushd pip-9.0.1 - as_jenkins python setup.py install - popd - rm -rf pip-9.0.1* - popd - - # Install pip packages - as_jenkins pip install --upgrade pip - - pip --version - - as_jenkins pip install numpy pyyaml - - as_jenkins pip install \ - future \ - hypothesis \ - protobuf \ - pytest \ - pillow \ - typing \ - dataclasses - - as_jenkins pip install mkl mkl-devel - - # SciPy does not support Python 3.7 or Python 2.7.9 - if [[ "$TRAVIS_PYTHON_VERSION" != nightly ]] && [[ "$TRAVIS_PYTHON_VERSION" != "2.7.9" ]]; then - as_jenkins pip install scipy==1.1.0 scikit-image librosa>=0.6.2 - fi - - # Install psutil for dataloader tests - as_jenkins pip install psutil - - # Install dill for serialization tests - as_jenkins pip install "dill>=0.3.1" - - # Cleanup package manager - apt-get autoclean && apt-get clean - rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* -fi diff --git a/.circleci/docker/ubuntu-cuda/Dockerfile b/.circleci/docker/ubuntu-cuda/Dockerfile index d9e23475881e6..f512180f16169 100644 --- a/.circleci/docker/ubuntu-cuda/Dockerfile +++ b/.circleci/docker/ubuntu-cuda/Dockerfile @@ -24,7 +24,7 @@ ARG KATEX ADD ./common/install_katex.sh install_katex.sh RUN bash ./install_katex.sh && rm install_katex.sh -# Install conda +# Install conda and other packages (e.g., numpy, coverage, pytest) ENV PATH /opt/conda/bin:$PATH ARG ANACONDA_PYTHON_VERSION ADD ./common/install_conda.sh install_conda.sh @@ -40,12 +40,6 @@ ARG CLANG_VERSION ADD ./common/install_clang.sh install_clang.sh RUN bash ./install_clang.sh && rm install_clang.sh -# Install non-standard Python versions (via Travis binaries) -ARG TRAVIS_PYTHON_VERSION -ENV PATH /opt/python/$TRAVIS_PYTHON_VERSION/bin:$PATH -ADD ./common/install_travis_python.sh install_travis_python.sh -RUN bash ./install_travis_python.sh && rm install_travis_python.sh - # (optional) Install protobuf for ONNX ARG PROTOBUF ADD ./common/install_protobuf.sh install_protobuf.sh @@ -78,6 +72,16 @@ ADD ./common/install_jni.sh install_jni.sh ADD ./java/jni.h jni.h RUN bash ./install_jni.sh && rm install_jni.sh +# Install NCCL for when CUDA is version 10.1 +ADD ./common/install_nccl.sh install_nccl.sh +RUN if [ "${CUDA_VERSION}" = 10.1 ]; then bash ./install_nccl.sh; fi +RUN rm install_nccl.sh + +# Install Open MPI for CUDA +ADD ./common/install_openmpi.sh install_openmpi.sh +RUN if [ -n "${CUDA_VERSION}" ]; then bash install_openmpi.sh; fi +RUN rm install_openmpi.sh + # Include BUILD_ENVIRONMENT environment variable in image ARG BUILD_ENVIRONMENT ENV BUILD_ENVIRONMENT ${BUILD_ENVIRONMENT} diff --git a/.circleci/docker/ubuntu-rocm/Dockerfile b/.circleci/docker/ubuntu-rocm/Dockerfile index 5fd133d08245d..d9b189bcf16c2 100644 --- a/.circleci/docker/ubuntu-rocm/Dockerfile +++ b/.circleci/docker/ubuntu-rocm/Dockerfile @@ -21,7 +21,7 @@ RUN bash ./install_clang.sh && rm install_clang.sh ADD ./common/install_user.sh install_user.sh RUN bash ./install_user.sh && rm install_user.sh -# Install conda +# Install conda and other packages (e.g., numpy, coverage, pytest) ENV PATH /opt/conda/bin:$PATH ARG ANACONDA_PYTHON_VERSION ADD ./common/install_conda.sh install_conda.sh @@ -58,7 +58,6 @@ ENV PATH /opt/rocm/hcc/bin:$PATH ENV PATH /opt/rocm/hip/bin:$PATH ENV PATH /opt/rocm/opencl/bin:$PATH ENV PATH /opt/rocm/llvm/bin:$PATH -ENV HIP_PLATFORM hcc ENV LANG C.UTF-8 ENV LC_ALL C.UTF-8 diff --git a/.circleci/docker/ubuntu/Dockerfile b/.circleci/docker/ubuntu/Dockerfile index 19cf4d1093582..72f2c108ff11e 100644 --- a/.circleci/docker/ubuntu/Dockerfile +++ b/.circleci/docker/ubuntu/Dockerfile @@ -33,7 +33,7 @@ ARG KATEX ADD ./common/install_katex.sh install_katex.sh RUN bash ./install_katex.sh && rm install_katex.sh -# Install conda +# Install conda and other packages (e.g., numpy, coverage, pytest) ENV PATH /opt/conda/bin:$PATH ARG ANACONDA_PYTHON_VERSION ADD ./common/install_conda.sh install_conda.sh @@ -44,12 +44,9 @@ ARG GCC_VERSION ADD ./common/install_gcc.sh install_gcc.sh RUN bash ./install_gcc.sh && rm install_gcc.sh -# Install non-standard Python versions (via Travis binaries) -ARG TRAVIS_PYTHON_VERSION -ARG TRAVIS_DL_URL_PREFIX -ENV PATH /opt/python/$TRAVIS_PYTHON_VERSION/bin:$PATH -ADD ./common/install_travis_python.sh install_travis_python.sh -RUN bash ./install_travis_python.sh && rm install_travis_python.sh +# Install lcov for C++ code coverage +ADD ./common/install_lcov.sh install_lcov.sh +RUN bash ./install_lcov.sh && rm install_lcov.sh # (optional) Install protobuf for ONNX ARG PROTOBUF diff --git a/.circleci/generate_config_yml.py b/.circleci/generate_config_yml.py index f1af924bd3e2b..a836d2e510a6f 100755 --- a/.circleci/generate_config_yml.py +++ b/.circleci/generate_config_yml.py @@ -112,7 +112,10 @@ def gen_build_workflows_tree(): "when": r"<< pipeline.parameters.run_binary_tests >>", "jobs": [f() for f in binary_build_functions], }, - "build": {"jobs": [f() for f in build_workflows_functions]}, + "build": { + "when": r"<< pipeline.parameters.run_build >>", + "jobs": [f() for f in build_workflows_functions] + }, } } diff --git a/.circleci/scripts/binary_checkout.sh b/.circleci/scripts/binary_checkout.sh index 93c0f92bf9bff..17f947d740246 100755 --- a/.circleci/scripts/binary_checkout.sh +++ b/.circleci/scripts/binary_checkout.sh @@ -33,6 +33,11 @@ else export BUILDER_ROOT="$workdir/builder" fi +# Try to extract PR number from branch if not already set +if [[ -z "${CIRCLE_PR_NUMBER:-}" ]]; then + CIRCLE_PR_NUMBER="$(echo ${CIRCLE_BRANCH} | sed -E -n 's/pull\/([0-9]*).*/\1/p')" +fi + # Clone the Pytorch branch retry git clone https://github.com/pytorch/pytorch.git "$PYTORCH_ROOT" pushd "$PYTORCH_ROOT" diff --git a/.circleci/scripts/binary_ios_build.sh b/.circleci/scripts/binary_ios_build.sh index efab1e5ded3ab..4cfe778e5134a 100644 --- a/.circleci/scripts/binary_ios_build.sh +++ b/.circleci/scripts/binary_ios_build.sh @@ -15,7 +15,8 @@ export PATH="~/anaconda/bin:${PATH}" source ~/anaconda/bin/activate # Install dependencies -conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests --yes +conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi requests --yes +conda install -c conda-forge valgrind --yes export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} # sync submodules diff --git a/.circleci/scripts/binary_ios_test.sh b/.circleci/scripts/binary_ios_test.sh index be281120016ad..863b21724a5d8 100644 --- a/.circleci/scripts/binary_ios_test.sh +++ b/.circleci/scripts/binary_ios_test.sh @@ -13,7 +13,7 @@ base64 --decode cert.txt -o Certificates.p12 rm cert.txt bundle exec fastlane install_cert # install the provisioning profile -PROFILE=TestApp_CI.mobileprovision +PROFILE=PyTorch_CI_2021.mobileprovision PROVISIONING_PROFILES=~/Library/MobileDevice/Provisioning\ Profiles mkdir -pv "${PROVISIONING_PROFILES}" cd "${PROVISIONING_PROFILES}" @@ -25,5 +25,5 @@ if ! [ -x "$(command -v xcodebuild)" ]; then echo 'Error: xcodebuild is not installed.' exit 1 fi -PROFILE=TestApp_CI +PROFILE=PyTorch_CI_2021 ruby ${PROJ_ROOT}/scripts/xcode_build.rb -i ${PROJ_ROOT}/build_ios/install -x ${PROJ_ROOT}/ios/TestApp/TestApp.xcodeproj -p ${IOS_PLATFORM} -c ${PROFILE} -t ${IOS_DEV_TEAM_ID} diff --git a/.circleci/scripts/binary_ios_upload.sh b/.circleci/scripts/binary_ios_upload.sh index b530521f7f2d0..f1022e113fa4a 100644 --- a/.circleci/scripts/binary_ios_upload.sh +++ b/.circleci/scripts/binary_ios_upload.sh @@ -34,7 +34,13 @@ touch version.txt echo $(date +%s) > version.txt zip -r ${ZIPFILE} install src version.txt LICENSE # upload to aws -brew install awscli +# Install conda then 'conda install' awscli +curl --retry 3 -o ~/conda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh +chmod +x ~/conda.sh +/bin/bash ~/conda.sh -b -p ~/anaconda +export PATH="~/anaconda/bin:${PATH}" +source ~/anaconda/bin/activate +conda install -c conda-forge awscli --yes set +x export AWS_ACCESS_KEY_ID=${AWS_S3_ACCESS_KEY_FOR_PYTORCH_BINARY_UPLOAD} export AWS_SECRET_ACCESS_KEY=${AWS_S3_ACCESS_SECRET_FOR_PYTORCH_BINARY_UPLOAD} diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index b0d7385d07ee6..d1e218cac5061 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -5,12 +5,18 @@ cat >/home/circleci/project/ci_test_script.sh <= 1.6.1 that makes archives + # above a certain size fail out when attempting to extract + # see: https://github.com/conda/conda-package-handling/issues/71 + conda install -y conda-package-handling=1.6.0 retry conda create -qyn testenv python="$DESIRED_PYTHON" source activate testenv >/dev/null elif [[ "$PACKAGE_TYPE" != libtorch ]]; then - python_nodot="\$(echo $DESIRED_PYTHON | tr -d m.u)" python_path="/opt/python/cp\$python_nodot-cp\${python_nodot}" # Prior to Python 3.8 paths were suffixed with an 'm' if [[ -d "\${python_path}/bin" ]]; then @@ -20,6 +26,11 @@ elif [[ "$PACKAGE_TYPE" != libtorch ]]; then fi fi +EXTRA_CONDA_FLAGS="" +if [[ "\$python_nodot" = *39* ]]; then + EXTRA_CONDA_FLAGS="-c=conda-forge" +fi + # Install the package # These network calls should not have 'retry's because they are installing # locally and aren't actually network calls @@ -28,20 +39,27 @@ fi # conda build scripts themselves. These should really be consolidated pkg="/final_pkgs/\$(ls /final_pkgs)" if [[ "$PACKAGE_TYPE" == conda ]]; then - conda install -y "\$pkg" --offline - if [[ "$DESIRED_CUDA" == 'cpu' ]]; then - retry conda install -y cpuonly -c pytorch - fi - retry conda install -yq future numpy protobuf six - if [[ "$DESIRED_CUDA" != 'cpu' ]]; then - # DESIRED_CUDA is in format cu90 or cu102 - if [[ "${#DESIRED_CUDA}" == 4 ]]; then - cu_ver="${DESIRED_CUDA:2:1}.${DESIRED_CUDA:3}" - else - cu_ver="${DESIRED_CUDA:2:2}.${DESIRED_CUDA:4}" + ( + # For some reason conda likes to re-activate the conda environment when attempting this install + # which means that a deactivate is run and some variables might not exist when that happens, + # namely CONDA_MKL_INTERFACE_LAYER_BACKUP from libblas so let's just ignore unbound variables when + # it comes to the conda installation commands + set +u + conda install \${EXTRA_CONDA_FLAGS} -y "\$pkg" --offline + if [[ "$DESIRED_CUDA" == 'cpu' ]]; then + retry conda install \${EXTRA_CONDA_FLAGS} -y cpuonly -c pytorch fi - retry conda install -yq -c nvidia -c pytorch "cudatoolkit=\${cu_ver}" - fi + retry conda install \${EXTRA_CONDA_FLAGS} -yq future numpy protobuf six + if [[ "$DESIRED_CUDA" != 'cpu' ]]; then + # DESIRED_CUDA is in format cu90 or cu102 + if [[ "${#DESIRED_CUDA}" == 4 ]]; then + cu_ver="${DESIRED_CUDA:2:1}.${DESIRED_CUDA:3}" + else + cu_ver="${DESIRED_CUDA:2:2}.${DESIRED_CUDA:4}" + fi + retry conda install \${EXTRA_CONDA_FLAGS} -yq -c nvidia -c pytorch "cudatoolkit=\${cu_ver}" + fi + ) elif [[ "$PACKAGE_TYPE" != libtorch ]]; then pip install "\$pkg" retry pip install -q future numpy protobuf six diff --git a/.circleci/scripts/binary_macos_test.sh b/.circleci/scripts/binary_macos_test.sh index 30682519cd8af..c36bfc28f67bb 100755 --- a/.circleci/scripts/binary_macos_test.sh +++ b/.circleci/scripts/binary_macos_test.sh @@ -20,9 +20,9 @@ if [[ "$PACKAGE_TYPE" == libtorch ]]; then unzip "$pkg" -d /tmp cd /tmp/libtorch elif [[ "$PACKAGE_TYPE" == conda ]]; then - conda install -y "$pkg" --offline + conda install -y "$pkg" else - pip install "$pkg" --no-index --no-dependencies -v + pip install "$pkg" -v fi # Test diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh index d4c31cefc7e5b..8934f030b7780 100755 --- a/.circleci/scripts/binary_populate_env.sh +++ b/.circleci/scripts/binary_populate_env.sh @@ -73,7 +73,7 @@ PIP_UPLOAD_FOLDER='nightly/' # We put this here so that OVERRIDE_PACKAGE_VERSION below can read from it export DATE="$(date -u +%Y%m%d)" #TODO: We should be pulling semver version from the base version.txt -BASE_BUILD_VERSION="1.7.0.dev$DATE" +BASE_BUILD_VERSION="1.8.0.dev$DATE" # Change BASE_BUILD_VERSION to git tag when on a git tag # Use 'git -C' to make doubly sure we're in the correct directory for checking # the git tag @@ -100,8 +100,14 @@ if [[ "$PACKAGE_TYPE" == libtorch ]]; then POSSIBLE_JAVA_HOMES+=(/usr/local) POSSIBLE_JAVA_HOMES+=(/usr/lib/jvm/java-8-openjdk-amd64) POSSIBLE_JAVA_HOMES+=(/Library/Java/JavaVirtualMachines/*.jdk/Contents/Home) + # Add the Windows-specific JNI path + POSSIBLE_JAVA_HOMES+=("$PWD/.circleci/windows-jni/") for JH in "${POSSIBLE_JAVA_HOMES[@]}" ; do if [[ -e "$JH/include/jni.h" ]] ; then + # Skip if we're not on Windows but haven't found a JAVA_HOME + if [[ "$JH" == "$PWD/.circleci/windows-jni/" && "$OSTYPE" != "msys" ]] ; then + break + fi echo "Found jni.h under $JH" JAVA_HOME="$JH" BUILD_JNI=ON @@ -130,7 +136,7 @@ if [[ "${BUILD_FOR_SYSTEM:-}" == "windows" ]]; then fi export DATE="$DATE" -export NIGHTLIES_DATE_PREAMBLE=1.7.0.dev +export NIGHTLIES_DATE_PREAMBLE=1.8.0.dev export PYTORCH_BUILD_VERSION="$PYTORCH_BUILD_VERSION" export PYTORCH_BUILD_NUMBER="$PYTORCH_BUILD_NUMBER" export OVERRIDE_PACKAGE_VERSION="$PYTORCH_BUILD_VERSION" @@ -161,6 +167,7 @@ export CIRCLE_TAG="${CIRCLE_TAG:-}" export CIRCLE_SHA1="$CIRCLE_SHA1" export CIRCLE_PR_NUMBER="${CIRCLE_PR_NUMBER:-}" export CIRCLE_BRANCH="$CIRCLE_BRANCH" +export CIRCLE_WORKFLOW_ID="$CIRCLE_WORKFLOW_ID" # =================== The above code will be executed inside Docker container =================== EOL diff --git a/.circleci/scripts/cpp_doc_push_script.sh b/.circleci/scripts/cpp_doc_push_script.sh index 618b64c7f12af..c6b4f00a06f0f 100755 --- a/.circleci/scripts/cpp_doc_push_script.sh +++ b/.circleci/scripts/cpp_doc_push_script.sh @@ -57,6 +57,7 @@ cp torch/_utils_internal.py tools/shared # Generate PyTorch files time python tools/setup_helpers/generate_code.py \ --declarations-path build/aten/src/ATen/Declarations.yaml \ + --native-functions-path aten/src/ATen/native/native_functions.yaml \ --nn-path aten/src/ # Build the docs @@ -87,7 +88,7 @@ git status git config user.email "soumith+bot@pytorch.org" git config user.name "pytorchbot" # If there aren't changes, don't make a commit; push is no-op -git commit -m "Automatic sync on $(date)" || true +git commit -m "Generate C++ docs from pytorch/pytorch@$CIRCLE_SHA1" || true git status popd diff --git a/.circleci/scripts/driver_update.bat b/.circleci/scripts/driver_update.bat index 9fc33445dfb2f..46c05475cdba8 100644 --- a/.circleci/scripts/driver_update.bat +++ b/.circleci/scripts/driver_update.bat @@ -1,8 +1,8 @@ -set "DRIVER_DOWNLOAD_LINK=https://s3.amazonaws.com/ossci-windows/451.82-tesla-desktop-winserver-2019-2016-international.exe" -curl --retry 3 -kL %DRIVER_DOWNLOAD_LINK% --output 451.82-tesla-desktop-winserver-2019-2016-international.exe +set "DRIVER_DOWNLOAD_LINK=https://s3.amazonaws.com/ossci-windows/452.39-data-center-tesla-desktop-win10-64bit-international.exe" +curl --retry 3 -kL %DRIVER_DOWNLOAD_LINK% --output 452.39-data-center-tesla-desktop-win10-64bit-international.exe if errorlevel 1 exit /b 1 -start /wait 451.82-tesla-desktop-winserver-2019-2016-international.exe -s -noreboot +start /wait 452.39-data-center-tesla-desktop-win10-64bit-international.exe -s -noreboot if errorlevel 1 exit /b 1 -del 451.82-tesla-desktop-winserver-2019-2016-international.exe || ver > NUL +del 452.39-data-center-tesla-desktop-win10-64bit-international.exe || ver > NUL diff --git a/.circleci/scripts/python_doc_push_script.sh b/.circleci/scripts/python_doc_push_script.sh index 4da8d546d36a3..9061eb9d1d85c 100755 --- a/.circleci/scripts/python_doc_push_script.sh +++ b/.circleci/scripts/python_doc_push_script.sh @@ -107,7 +107,7 @@ git status git config user.email "soumith+bot@pytorch.org" git config user.name "pytorchbot" # If there aren't changes, don't make a commit; push is no-op -git commit -m "auto-generating sphinx docs" || true +git commit -m "Generate Python docs from pytorch/pytorch@$CIRCLE_SHA1" || true git status popd diff --git a/.circleci/scripts/setup_ci_environment.sh b/.circleci/scripts/setup_ci_environment.sh index f6c398aafd92d..7d1f0d6c5b756 100755 --- a/.circleci/scripts/setup_ci_environment.sh +++ b/.circleci/scripts/setup_ci_environment.sh @@ -54,7 +54,7 @@ add_to_env_file() { echo "${content}" >> "${BASH_ENV:-/tmp/env}" } -add_to_env_file "IN_CIRCLECI=1" +add_to_env_file "IN_CI=1" add_to_env_file "COMMIT_SOURCE=${CIRCLE_BRANCH:-}" add_to_env_file "BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" add_to_env_file "CIRCLE_PULL_REQUEST=${CIRCLE_PULL_REQUEST}" diff --git a/.circleci/scripts/upload_binary_size_to_scuba.py b/.circleci/scripts/upload_binary_size_to_scuba.py index a46b0392fc8aa..0c8992453ef63 100644 --- a/.circleci/scripts/upload_binary_size_to_scuba.py +++ b/.circleci/scripts/upload_binary_size_to_scuba.py @@ -41,6 +41,7 @@ def build_message(size): "build_num": os.environ.get("CIRCLE_BUILD_NUM"), "sha1": os.environ.get("CIRCLE_SHA1"), "branch": os.environ.get("CIRCLE_BRANCH"), + "workflow_id": os.environ.get("CIRCLE_WORKFLOW_ID"), }, "int": { "time": int(time.time()), @@ -115,6 +116,7 @@ def gen_messages(): "build_num": os.environ.get("CIRCLE_BUILD_NUM"), "sha1": os.environ.get("CIRCLE_SHA1"), "branch": os.environ.get("CIRCLE_BRANCH"), + "workflow_id": os.environ.get("CIRCLE_WORKFLOW_ID"), }, "int": { "time": int(time.time()), diff --git a/.circleci/scripts/windows_cuda_install.sh b/.circleci/scripts/windows_cuda_install.sh index 4557717528cac..b73b5b0a3c677 100644 --- a/.circleci/scripts/windows_cuda_install.sh +++ b/.circleci/scripts/windows_cuda_install.sh @@ -1,21 +1,25 @@ #!/bin/bash set -eux -o pipefail -if [[ "$CUDA_VERSION" == "10" ]]; then - cuda_complete_version="10.1" +cuda_major_version=${CUDA_VERSION%.*} + +if [[ "$cuda_major_version" == "10" ]]; then cuda_installer_name="cuda_10.1.243_426.00_win10" msbuild_project_dir="CUDAVisualStudioIntegration/extras/visual_studio_integration/MSBuildExtensions" cuda_install_packages="nvcc_10.1 cuobjdump_10.1 nvprune_10.1 cupti_10.1 cublas_10.1 cublas_dev_10.1 cudart_10.1 cufft_10.1 cufft_dev_10.1 curand_10.1 curand_dev_10.1 cusolver_10.1 cusolver_dev_10.1 cusparse_10.1 cusparse_dev_10.1 nvgraph_10.1 nvgraph_dev_10.1 npp_10.1 npp_dev_10.1 nvrtc_10.1 nvrtc_dev_10.1 nvml_dev_10.1" -elif [[ "$CUDA_VERSION" == "11" ]]; then - cuda_complete_version="11.0" - cuda_installer_name="cuda_11.0.2_451.48_win10" +elif [[ "$cuda_major_version" == "11" ]]; then + cuda_installer_name="cuda_11.1.0_456.43_win10" msbuild_project_dir="visual_studio_integration/CUDAVisualStudioIntegration/extras/visual_studio_integration/MSBuildExtensions" - cuda_install_packages="nvcc_11.0 cuobjdump_11.0 nvprune_11.0 nvprof_11.0 cupti_11.0 cublas_11.0 cublas_dev_11.0 cudart_11.0 cufft_11.0 cufft_dev_11.0 curand_11.0 curand_dev_11.0 cusolver_11.0 cusolver_dev_11.0 cusparse_11.0 cusparse_dev_11.0 npp_11.0 npp_dev_11.0 nvrtc_11.0 nvrtc_dev_11.0 nvml_dev_11.0" + cuda_install_packages="nvcc_11.1 cuobjdump_11.1 nvprune_11.1 nvprof_11.1 cupti_11.1 cublas_11.1 cublas_dev_11.1 cudart_11.1 cufft_11.1 cufft_dev_11.1 curand_11.1 curand_dev_11.1 cusolver_11.1 cusolver_dev_11.1 cusparse_11.1 cusparse_dev_11.1 npp_11.1 npp_dev_11.1 nvrtc_11.1 nvrtc_dev_11.1 nvml_dev_11.1" else echo "CUDA_VERSION $CUDA_VERSION is not supported yet" exit 1 fi +if [[ "$cuda_major_version" == "11" && "${JOB_EXECUTOR}" == "windows-with-nvidia-gpu" ]]; then + cuda_install_packages="${cuda_install_packages} Display.Driver" +fi + cuda_installer_link="https://ossci-windows.s3.amazonaws.com/${cuda_installer_name}.exe" curl --retry 3 -kLO $cuda_installer_link @@ -44,7 +48,7 @@ then export NVTOOLSEXT_PATH="C:\\Program Files\\NVIDIA Corporation\\NvToolsExt\\" fi -if ! ls "/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${cuda_complete_version}/bin/nvcc.exe" +if ! ls "/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${CUDA_VERSION}/bin/nvcc.exe" then echo "CUDA installation failed" mkdir -p /c/w/build-results diff --git a/.circleci/scripts/windows_cudnn_install.sh b/.circleci/scripts/windows_cudnn_install.sh index 2725150bd0bf2..c133506290c1b 100644 --- a/.circleci/scripts/windows_cudnn_install.sh +++ b/.circleci/scripts/windows_cudnn_install.sh @@ -1,12 +1,12 @@ #!/bin/bash set -eux -o pipefail -if [[ "$CUDA_VERSION" == "10" ]]; then - cuda_complete_version="10.1" - cudnn_installer_name="cudnn-10.1-windows10-x64-v7.6.4.38" -elif [[ "$CUDA_VERSION" == "11" ]]; then - cuda_complete_version="11.0" - cudnn_installer_name="cudnn-11.0-windows-x64-v8.0.2.39" +cuda_major_version=${CUDA_VERSION%.*} + +if [[ "$cuda_major_version" == "10" ]]; then + cudnn_installer_name="cudnn-${CUDA_VERSION}-windows10-x64-v7.6.4.38" +elif [[ "$cuda_major_version" == "11" ]]; then + cudnn_installer_name="cudnn-${CUDA_VERSION}-windows-x64-v8.0.5.39" else echo "CUDNN for CUDA_VERSION $CUDA_VERSION is not supported yet" exit 1 @@ -16,6 +16,6 @@ cudnn_installer_link="https://ossci-windows.s3.amazonaws.com/${cudnn_installer_n curl --retry 3 -O $cudnn_installer_link 7z x ${cudnn_installer_name}.zip -ocudnn -cp -r cudnn/cuda/* "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${cuda_complete_version}/" +cp -r cudnn/cuda/* "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${CUDA_VERSION}/" rm -rf cudnn rm -f ${cudnn_installer_name}.zip diff --git a/.circleci/verbatim-sources/build-parameters/pytorch-build-params.yml b/.circleci/verbatim-sources/build-parameters/pytorch-build-params.yml index 41f72a1baa986..c912a4fb690bf 100644 --- a/.circleci/verbatim-sources/build-parameters/pytorch-build-params.yml +++ b/.circleci/verbatim-sources/build-parameters/pytorch-build-params.yml @@ -36,11 +36,15 @@ pytorch_ios_params: &pytorch_ios_params op_list: type: string default: "" + use_metal: + type: string + default: "0" environment: BUILD_ENVIRONMENT: << parameters.build_environment >> IOS_ARCH: << parameters.ios_arch >> IOS_PLATFORM: << parameters.ios_platform >> SELECTED_OP_LIST: << parameters.op_list >> + USE_PYTORCH_METAL: << parameters.use_metal >> pytorch_windows_params: &pytorch_windows_params parameters: @@ -55,7 +59,7 @@ pytorch_windows_params: &pytorch_windows_params default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" diff --git a/.circleci/verbatim-sources/commands.yml b/.circleci/verbatim-sources/commands.yml index cec3b3c588e2d..dfa4ee4d16ded 100644 --- a/.circleci/verbatim-sources/commands.yml +++ b/.circleci/verbatim-sources/commands.yml @@ -103,7 +103,7 @@ commands: name: (Optional) Merge target branch no_output_timeout: "10m" command: | - if [ -n "$CIRCLE_PULL_REQUEST" ]; then + if [[ -n "$CIRCLE_PULL_REQUEST" && "$CIRCLE_BRANCH" != "nightly" ]]; then PR_NUM=$(basename $CIRCLE_PULL_REQUEST) CIRCLE_PR_BASE_BRANCH=$(curl -s https://api.github.com/repos/$CIRCLE_PROJECT_USERNAME/$CIRCLE_PROJECT_REPONAME/pulls/$PR_NUM | jq -r '.base.ref') if [[ "${BUILD_ENVIRONMENT}" == *"xla"* || "${BUILD_ENVIRONMENT}" == *"gcc5"* ]] ; then diff --git a/.circleci/verbatim-sources/header-section.yml b/.circleci/verbatim-sources/header-section.yml index 26205a0cccbaa..43d4c94ee5ed1 100644 --- a/.circleci/verbatim-sources/header-section.yml +++ b/.circleci/verbatim-sources/header-section.yml @@ -11,6 +11,9 @@ parameters: run_binary_tests: type: boolean default: false + run_build: + type: boolean + default: true docker_config_defaults: &docker_config_defaults user: jenkins diff --git a/.circleci/verbatim-sources/job-specs/binary-job-specs.yml b/.circleci/verbatim-sources/job-specs/binary-job-specs.yml index 7e635f42bce49..f8d1dde4e5adb 100644 --- a/.circleci/verbatim-sources/job-specs/binary-job-specs.yml +++ b/.circleci/verbatim-sources/job-specs/binary-job-specs.yml @@ -135,7 +135,7 @@ smoke_mac_test: <<: *binary_linux_test_upload_params macos: - xcode: "11.2.1" + xcode: "12.0" steps: - checkout - run: @@ -160,7 +160,7 @@ binary_mac_build: <<: *binary_mac_params macos: - xcode: "11.2.1" + xcode: "12.0" steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - checkout @@ -174,7 +174,7 @@ - run: name: Build - no_output_timeout: "1h" + no_output_timeout: "90m" command: | # Do not set -u here; there is some problem with CircleCI # variable expansion with PROMPT_COMMAND @@ -201,7 +201,7 @@ binary_ios_build: <<: *pytorch_ios_params macos: - xcode: "11.2.1" + xcode: "12.0" steps: - attach_workspace: at: ~/workspace @@ -228,7 +228,7 @@ binary_ios_upload: <<: *pytorch_ios_params macos: - xcode: "11.2.1" + xcode: "12.0" steps: - attach_workspace: at: ~/workspace diff --git a/.circleci/verbatim-sources/job-specs/job-specs-custom.yml b/.circleci/verbatim-sources/job-specs/job-specs-custom.yml index 9cc75136cfddf..b372f45907b36 100644 --- a/.circleci/verbatim-sources/job-specs/job-specs-custom.yml +++ b/.circleci/verbatim-sources/job-specs/job-specs-custom.yml @@ -43,10 +43,13 @@ set -ex export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}:${DOCKER_TAG}-${CIRCLE_SHA1} echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE} + tag=${CIRCLE_TAG:1:5} + target=${tag:-master} + echo "building for ${target}" time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null export id=$(docker run --env-file "${BASH_ENV}" --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE}) - export COMMAND='((echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/python_doc_push_script.sh docs/master master site") | docker exec -u jenkins -i "$id" bash) 2>&1' + export COMMAND='((echo "sudo chown -R jenkins workspace && cd workspace && '"export CIRCLE_SHA1='$CIRCLE_SHA1'"' && . ./.circleci/scripts/python_doc_push_script.sh docs/'$target' '$target' site") | docker exec -u jenkins -i "$id" bash) 2>&1' echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts @@ -85,10 +88,13 @@ set -ex export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}:${DOCKER_TAG}-${CIRCLE_SHA1} echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE} + tag=${CIRCLE_TAG:1:5} + target=${tag:-master} + echo "building for ${target}" time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null export id=$(docker run --env-file "${BASH_ENV}" --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE}) - export COMMAND='((echo "sudo chown -R jenkins workspace && cd workspace && . ./.circleci/scripts/cpp_doc_push_script.sh docs/master master") | docker exec -u jenkins -i "$id" bash) 2>&1' + export COMMAND='((echo "sudo chown -R jenkins workspace && cd workspace && '"export CIRCLE_SHA1='$CIRCLE_SHA1'"' && . ./.circleci/scripts/cpp_doc_push_script.sh docs/"$target" master") | docker exec -u jenkins -i "$id" bash) 2>&1' echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts @@ -109,7 +115,7 @@ environment: BUILD_ENVIRONMENT: pytorch-macos-10.13-py3-build macos: - xcode: "11.2.1" + xcode: "12.0" steps: - checkout - run_brew_for_macos_build @@ -118,7 +124,7 @@ no_output_timeout: "1h" command: | set -e - export IN_CIRCLECI=1 + export IN_CI=1 # Install sccache sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache --output /usr/local/bin/sccache @@ -143,7 +149,7 @@ environment: BUILD_ENVIRONMENT: pytorch-macos-10.13-py3-test macos: - xcode: "11.2.1" + xcode: "12.0" steps: - checkout - attach_workspace: @@ -154,7 +160,7 @@ no_output_timeout: "1h" command: | set -e - export IN_CIRCLECI=1 + export IN_CI=1 chmod a+x .jenkins/pytorch/macos-test.sh unbuffer .jenkins/pytorch/macos-test.sh 2>&1 | ts @@ -253,22 +259,22 @@ pytorch_android_publish_snapshot: environment: BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-publish-snapshot - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:ab1632df-fa59-40e6-8c23-98e004f61148" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c" PYTHON_VERSION: "3.6" resource_class: large machine: image: ubuntu-1604:202007-01 steps: - checkout + - calculate_docker_image_tag - setup_linux_system_environment - - checkout - setup_ci_environment - run: name: pytorch android gradle build no_output_timeout: "1h" command: | set -eux - docker_image_commit=${DOCKER_IMAGE}-${CIRCLE_SHA1} + docker_image_commit=${DOCKER_IMAGE}:${DOCKER_TAG}-${CIRCLE_SHA1} docker_image_libtorch_android_x86_32_gradle=${docker_image_commit}-android-x86_32-gradle @@ -371,7 +377,7 @@ pytorch_ios_build: <<: *pytorch_ios_params macos: - xcode: "11.2.1" + xcode: "12.0" steps: - checkout - run_brew_for_ios_build @@ -390,7 +396,7 @@ rm cert.txt bundle exec fastlane install_cert # install the provisioning profile - PROFILE=TestApp_CI.mobileprovision + PROFILE=PyTorch_CI_2021.mobileprovision PROVISIONING_PROFILES=~/Library/MobileDevice/Provisioning\ Profiles mkdir -pv "${PROVISIONING_PROFILES}" cd "${PROVISIONING_PROFILES}" @@ -402,7 +408,7 @@ no_output_timeout: "1h" command: | set -e - export IN_CIRCLECI=1 + export IN_CI=1 WORKSPACE=/Users/distiller/workspace PROJ_ROOT=/Users/distiller/project export TCLLIBPATH="/usr/local/lib" @@ -419,7 +425,7 @@ $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) } - retry conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests --yes + retry conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi requests --yes # sync submodules cd ${PROJ_ROOT} @@ -433,6 +439,7 @@ chmod a+x ${PROJ_ROOT}/scripts/build_ios.sh echo "IOS_ARCH: ${IOS_ARCH}" echo "IOS_PLATFORM: ${IOS_PLATFORM}" + echo "USE_PYTORCH_METAL": "${USE_METAL}" #check the custom build flag echo "SELECTED_OP_LIST: ${SELECTED_OP_LIST}" @@ -441,6 +448,9 @@ fi export IOS_ARCH=${IOS_ARCH} export IOS_PLATFORM=${IOS_PLATFORM} + if [ ${IOS_PLATFORM} != "SIMULATOR" ]; then + export USE_PYTORCH_METAL=${USE_METAL} + fi unbuffer ${PROJ_ROOT}/scripts/build_ios.sh 2>&1 | ts - run: name: Run Build Test @@ -448,7 +458,7 @@ command: | set -e PROJ_ROOT=/Users/distiller/project - PROFILE=TestApp_CI + PROFILE=PyTorch_CI_2021 # run the ruby build script if ! [ -x "$(command -v xcodebuild)" ]; then echo 'Error: xcodebuild is not installed.' diff --git a/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml b/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml index 0f0dd76636b43..8cbb9a4e3f40b 100644 --- a/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml +++ b/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml @@ -15,12 +15,8 @@ jobs: no_output_timeout: "1h" command: | set -e - # TODO: Remove this after we figure out why rocm tests are failing - if [[ "${DOCKER_IMAGE}" == *rocm3.5* ]]; then - export DOCKER_TAG="ab1632df-fa59-40e6-8c23-98e004f61148" - fi - if [[ "${DOCKER_IMAGE}" == *rocm3.7* ]]; then - export DOCKER_TAG="1045c7b891104cb4fd23399eab413b6213e48aeb" + if [[ "${DOCKER_IMAGE}" == *rocm3.9* ]]; then + export DOCKER_TAG="f3d89a32912f62815e4feaeed47e564e887dffd6" fi if [[ ${BUILD_ENVIRONMENT} == *"pure_torch"* ]]; then echo 'BUILD_CAFFE2=OFF' >> "${BASH_ENV}" @@ -52,7 +48,7 @@ jobs: if [ -z "${BUILD_ONLY}" ]; then # Note [Special build images] # The xla build uses the same docker image as - # pytorch-linux-trusty-py3.6-gcc5.4-build. In the push step, we have to + # pytorch_linux_bionic_py3_6_clang9_build. In the push step, we have to # distinguish between them so the test can pick up the correct image. output_image=${DOCKER_IMAGE}:${DOCKER_TAG}-${CIRCLE_SHA1} if [[ ${BUILD_ENVIRONMENT} == *"xla"* ]]; then @@ -100,12 +96,8 @@ jobs: command: | set -e export PYTHONUNBUFFERED=1 - # TODO: Remove this after we figure out why rocm tests are failing - if [[ "${DOCKER_IMAGE}" == *rocm3.5* ]]; then - export DOCKER_TAG="ab1632df-fa59-40e6-8c23-98e004f61148" - fi - if [[ "${DOCKER_IMAGE}" == *rocm3.7* ]]; then - export DOCKER_TAG="1045c7b891104cb4fd23399eab413b6213e48aeb" + if [[ "${DOCKER_IMAGE}" == *rocm3.9* ]]; then + export DOCKER_TAG="f3d89a32912f62815e4feaeed47e564e887dffd6" fi # See Note [Special build images] output_image=${DOCKER_IMAGE}:${DOCKER_TAG}-${CIRCLE_SHA1} @@ -183,7 +175,7 @@ jobs: echo ".jenkins/pytorch/multigpu-test.sh" >> docker_commands.sh elif [[ ${BUILD_ENVIRONMENT} == *onnx* ]]; then echo "pip install click mock tabulate networkx==2.0" >> docker_commands.sh - echo "pip -q install --user -b /tmp/pip_install_onnx \"file:///var/lib/jenkins/workspace/third_party/onnx#egg=onnx\"" >> docker_commands.sh + echo "pip -q install --user \"file:///var/lib/jenkins/workspace/third_party/onnx#egg=onnx\"" >> docker_commands.sh echo ".jenkins/caffe2/test.sh" >> docker_commands.sh else echo ".jenkins/pytorch/test.sh" >> docker_commands.sh @@ -206,8 +198,10 @@ jobs: export CIRCLE_SHA1="$CIRCLE_SHA1" export CIRCLE_PR_NUMBER="${CIRCLE_PR_NUMBER:-}" export CIRCLE_BRANCH="$CIRCLE_BRANCH" + export CIRCLE_JOB="$CIRCLE_JOB" + export CIRCLE_WORKFLOW_ID="$CIRCLE_WORKFLOW_ID" cd workspace - python test/print_test_stats.py test + python test/print_test_stats.py --upload-to-s3 test EOL echo "(cat docker_commands.sh | docker exec -u jenkins -i "$id" bash) 2>&1" > command.sh unbuffer bash command.sh | ts @@ -215,7 +209,11 @@ jobs: echo "Retrieving test reports" docker cp $id:/var/lib/jenkins/workspace/test/test-reports ./ || echo 'No test reports found!' if [[ ${BUILD_ENVIRONMENT} == *"coverage"* ]]; then - echo "Retrieving coverage report" + echo "Retrieving C++ coverage report" + docker cp $id:/var/lib/jenkins/workspace/build/coverage.info ./test + fi + if [[ ${BUILD_ENVIRONMENT} == *"coverage"* || ${BUILD_ENVIRONMENT} == *"onnx"* ]]; then + echo "Retrieving Python coverage report" docker cp $id:/var/lib/jenkins/workspace/test/.coverage ./test docker cp $id:/var/lib/jenkins/workspace/test/coverage.xml ./test python3 -mpip install codecov @@ -239,7 +237,7 @@ jobs: default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" @@ -301,7 +299,7 @@ jobs: default: "" cuda_version: type: string - default: "10" + default: "10.1" python_version: type: string default: "3.6" @@ -330,9 +328,6 @@ jobs: if [[ "${CUDA_VERSION}" != "10" || "${JOB_EXECUTOR}" != "windows-with-nvidia-gpu" ]]; then .circleci/scripts/windows_cuda_install.sh fi - if [[ "${CUDA_VERSION}" != "10" && "${JOB_EXECUTOR}" == "windows-with-nvidia-gpu" ]]; then - .circleci/scripts/driver_update.bat - fi fi - run: name: Install Cudnn @@ -345,7 +340,7 @@ jobs: no_output_timeout: "30m" command: | set -e - export IN_CIRCLECI=1 + export IN_CI=1 set +x export AWS_ACCESS_KEY_ID=${CIRCLECI_AWS_ACCESS_KEY_FOR_WIN_BUILD_V1} export AWS_SECRET_ACCESS_KEY=${CIRCLECI_AWS_SECRET_KEY_FOR_WIN_BUILD_V1} diff --git a/.circleci/windows-jni/include/jni.h b/.circleci/windows-jni/include/jni.h new file mode 100644 index 0000000000000..f793148c1df0f --- /dev/null +++ b/.circleci/windows-jni/include/jni.h @@ -0,0 +1,1132 @@ +/* + * Copyright (C) 2006 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * JNI specification, as defined by Sun: + * http://java.sun.com/javase/6/docs/technotes/guides/jni/spec/jniTOC.html + * + * Everything here is expected to be VM-neutral. + */ + +#ifndef JNI_H_ +#define JNI_H_ + +#include +#include + +/* Primitive types that match up with Java equivalents. */ +typedef uint8_t jboolean; /* unsigned 8 bits */ +typedef int8_t jbyte; /* signed 8 bits */ +typedef uint16_t jchar; /* unsigned 16 bits */ +typedef int16_t jshort; /* signed 16 bits */ +typedef int32_t jint; /* signed 32 bits */ +typedef int64_t jlong; /* signed 64 bits */ +typedef float jfloat; /* 32-bit IEEE 754 */ +typedef double jdouble; /* 64-bit IEEE 754 */ + +/* "cardinal indices and sizes" */ +typedef jint jsize; + +#ifdef __cplusplus +/* + * Reference types, in C++ + */ +class _jobject {}; +class _jclass : public _jobject {}; +class _jstring : public _jobject {}; +class _jarray : public _jobject {}; +class _jobjectArray : public _jarray {}; +class _jbooleanArray : public _jarray {}; +class _jbyteArray : public _jarray {}; +class _jcharArray : public _jarray {}; +class _jshortArray : public _jarray {}; +class _jintArray : public _jarray {}; +class _jlongArray : public _jarray {}; +class _jfloatArray : public _jarray {}; +class _jdoubleArray : public _jarray {}; +class _jthrowable : public _jobject {}; + +typedef _jobject* jobject; +typedef _jclass* jclass; +typedef _jstring* jstring; +typedef _jarray* jarray; +typedef _jobjectArray* jobjectArray; +typedef _jbooleanArray* jbooleanArray; +typedef _jbyteArray* jbyteArray; +typedef _jcharArray* jcharArray; +typedef _jshortArray* jshortArray; +typedef _jintArray* jintArray; +typedef _jlongArray* jlongArray; +typedef _jfloatArray* jfloatArray; +typedef _jdoubleArray* jdoubleArray; +typedef _jthrowable* jthrowable; +typedef _jobject* jweak; + + +#else /* not __cplusplus */ + +/* + * Reference types, in C. + */ +typedef void* jobject; +typedef jobject jclass; +typedef jobject jstring; +typedef jobject jarray; +typedef jarray jobjectArray; +typedef jarray jbooleanArray; +typedef jarray jbyteArray; +typedef jarray jcharArray; +typedef jarray jshortArray; +typedef jarray jintArray; +typedef jarray jlongArray; +typedef jarray jfloatArray; +typedef jarray jdoubleArray; +typedef jobject jthrowable; +typedef jobject jweak; + +#endif /* not __cplusplus */ + +struct _jfieldID; /* opaque structure */ +typedef struct _jfieldID* jfieldID; /* field IDs */ + +struct _jmethodID; /* opaque structure */ +typedef struct _jmethodID* jmethodID; /* method IDs */ + +struct JNIInvokeInterface; + +typedef union jvalue { + jboolean z; + jbyte b; + jchar c; + jshort s; + jint i; + jlong j; + jfloat f; + jdouble d; + jobject l; +} jvalue; + +typedef enum jobjectRefType { + JNIInvalidRefType = 0, + JNILocalRefType = 1, + JNIGlobalRefType = 2, + JNIWeakGlobalRefType = 3 +} jobjectRefType; + +typedef struct { + const char* name; + const char* signature; + void* fnPtr; +} JNINativeMethod; + +struct _JNIEnv; +struct _JavaVM; +typedef const struct JNINativeInterface* C_JNIEnv; + +#if defined(__cplusplus) +typedef _JNIEnv JNIEnv; +typedef _JavaVM JavaVM; +#else +typedef const struct JNINativeInterface* JNIEnv; +typedef const struct JNIInvokeInterface* JavaVM; +#endif + +/* + * Table of interface function pointers. + */ +struct JNINativeInterface { + void* reserved0; + void* reserved1; + void* reserved2; + void* reserved3; + + jint (*GetVersion)(JNIEnv *); + + jclass (*DefineClass)(JNIEnv*, const char*, jobject, const jbyte*, + jsize); + jclass (*FindClass)(JNIEnv*, const char*); + + jmethodID (*FromReflectedMethod)(JNIEnv*, jobject); + jfieldID (*FromReflectedField)(JNIEnv*, jobject); + /* spec doesn't show jboolean parameter */ + jobject (*ToReflectedMethod)(JNIEnv*, jclass, jmethodID, jboolean); + + jclass (*GetSuperclass)(JNIEnv*, jclass); + jboolean (*IsAssignableFrom)(JNIEnv*, jclass, jclass); + + /* spec doesn't show jboolean parameter */ + jobject (*ToReflectedField)(JNIEnv*, jclass, jfieldID, jboolean); + + jint (*Throw)(JNIEnv*, jthrowable); + jint (*ThrowNew)(JNIEnv *, jclass, const char *); + jthrowable (*ExceptionOccurred)(JNIEnv*); + void (*ExceptionDescribe)(JNIEnv*); + void (*ExceptionClear)(JNIEnv*); + void (*FatalError)(JNIEnv*, const char*); + + jint (*PushLocalFrame)(JNIEnv*, jint); + jobject (*PopLocalFrame)(JNIEnv*, jobject); + + jobject (*NewGlobalRef)(JNIEnv*, jobject); + void (*DeleteGlobalRef)(JNIEnv*, jobject); + void (*DeleteLocalRef)(JNIEnv*, jobject); + jboolean (*IsSameObject)(JNIEnv*, jobject, jobject); + + jobject (*NewLocalRef)(JNIEnv*, jobject); + jint (*EnsureLocalCapacity)(JNIEnv*, jint); + + jobject (*AllocObject)(JNIEnv*, jclass); + jobject (*NewObject)(JNIEnv*, jclass, jmethodID, ...); + jobject (*NewObjectV)(JNIEnv*, jclass, jmethodID, va_list); + jobject (*NewObjectA)(JNIEnv*, jclass, jmethodID, jvalue*); + + jclass (*GetObjectClass)(JNIEnv*, jobject); + jboolean (*IsInstanceOf)(JNIEnv*, jobject, jclass); + jmethodID (*GetMethodID)(JNIEnv*, jclass, const char*, const char*); + + jobject (*CallObjectMethod)(JNIEnv*, jobject, jmethodID, ...); + jobject (*CallObjectMethodV)(JNIEnv*, jobject, jmethodID, va_list); + jobject (*CallObjectMethodA)(JNIEnv*, jobject, jmethodID, jvalue*); + jboolean (*CallBooleanMethod)(JNIEnv*, jobject, jmethodID, ...); + jboolean (*CallBooleanMethodV)(JNIEnv*, jobject, jmethodID, va_list); + jboolean (*CallBooleanMethodA)(JNIEnv*, jobject, jmethodID, jvalue*); + jbyte (*CallByteMethod)(JNIEnv*, jobject, jmethodID, ...); + jbyte (*CallByteMethodV)(JNIEnv*, jobject, jmethodID, va_list); + jbyte (*CallByteMethodA)(JNIEnv*, jobject, jmethodID, jvalue*); + jchar (*CallCharMethod)(JNIEnv*, jobject, jmethodID, ...); + jchar (*CallCharMethodV)(JNIEnv*, jobject, jmethodID, va_list); + jchar (*CallCharMethodA)(JNIEnv*, jobject, jmethodID, jvalue*); + jshort (*CallShortMethod)(JNIEnv*, jobject, jmethodID, ...); + jshort (*CallShortMethodV)(JNIEnv*, jobject, jmethodID, va_list); + jshort (*CallShortMethodA)(JNIEnv*, jobject, jmethodID, jvalue*); + jint (*CallIntMethod)(JNIEnv*, jobject, jmethodID, ...); + jint (*CallIntMethodV)(JNIEnv*, jobject, jmethodID, va_list); + jint (*CallIntMethodA)(JNIEnv*, jobject, jmethodID, jvalue*); + jlong (*CallLongMethod)(JNIEnv*, jobject, jmethodID, ...); + jlong (*CallLongMethodV)(JNIEnv*, jobject, jmethodID, va_list); + jlong (*CallLongMethodA)(JNIEnv*, jobject, jmethodID, jvalue*); + jfloat (*CallFloatMethod)(JNIEnv*, jobject, jmethodID, ...); + jfloat (*CallFloatMethodV)(JNIEnv*, jobject, jmethodID, va_list); + jfloat (*CallFloatMethodA)(JNIEnv*, jobject, jmethodID, jvalue*); + jdouble (*CallDoubleMethod)(JNIEnv*, jobject, jmethodID, ...); + jdouble (*CallDoubleMethodV)(JNIEnv*, jobject, jmethodID, va_list); + jdouble (*CallDoubleMethodA)(JNIEnv*, jobject, jmethodID, jvalue*); + void (*CallVoidMethod)(JNIEnv*, jobject, jmethodID, ...); + void (*CallVoidMethodV)(JNIEnv*, jobject, jmethodID, va_list); + void (*CallVoidMethodA)(JNIEnv*, jobject, jmethodID, jvalue*); + + jobject (*CallNonvirtualObjectMethod)(JNIEnv*, jobject, jclass, + jmethodID, ...); + jobject (*CallNonvirtualObjectMethodV)(JNIEnv*, jobject, jclass, + jmethodID, va_list); + jobject (*CallNonvirtualObjectMethodA)(JNIEnv*, jobject, jclass, + jmethodID, jvalue*); + jboolean (*CallNonvirtualBooleanMethod)(JNIEnv*, jobject, jclass, + jmethodID, ...); + jboolean (*CallNonvirtualBooleanMethodV)(JNIEnv*, jobject, jclass, + jmethodID, va_list); + jboolean (*CallNonvirtualBooleanMethodA)(JNIEnv*, jobject, jclass, + jmethodID, jvalue*); + jbyte (*CallNonvirtualByteMethod)(JNIEnv*, jobject, jclass, + jmethodID, ...); + jbyte (*CallNonvirtualByteMethodV)(JNIEnv*, jobject, jclass, + jmethodID, va_list); + jbyte (*CallNonvirtualByteMethodA)(JNIEnv*, jobject, jclass, + jmethodID, jvalue*); + jchar (*CallNonvirtualCharMethod)(JNIEnv*, jobject, jclass, + jmethodID, ...); + jchar (*CallNonvirtualCharMethodV)(JNIEnv*, jobject, jclass, + jmethodID, va_list); + jchar (*CallNonvirtualCharMethodA)(JNIEnv*, jobject, jclass, + jmethodID, jvalue*); + jshort (*CallNonvirtualShortMethod)(JNIEnv*, jobject, jclass, + jmethodID, ...); + jshort (*CallNonvirtualShortMethodV)(JNIEnv*, jobject, jclass, + jmethodID, va_list); + jshort (*CallNonvirtualShortMethodA)(JNIEnv*, jobject, jclass, + jmethodID, jvalue*); + jint (*CallNonvirtualIntMethod)(JNIEnv*, jobject, jclass, + jmethodID, ...); + jint (*CallNonvirtualIntMethodV)(JNIEnv*, jobject, jclass, + jmethodID, va_list); + jint (*CallNonvirtualIntMethodA)(JNIEnv*, jobject, jclass, + jmethodID, jvalue*); + jlong (*CallNonvirtualLongMethod)(JNIEnv*, jobject, jclass, + jmethodID, ...); + jlong (*CallNonvirtualLongMethodV)(JNIEnv*, jobject, jclass, + jmethodID, va_list); + jlong (*CallNonvirtualLongMethodA)(JNIEnv*, jobject, jclass, + jmethodID, jvalue*); + jfloat (*CallNonvirtualFloatMethod)(JNIEnv*, jobject, jclass, + jmethodID, ...); + jfloat (*CallNonvirtualFloatMethodV)(JNIEnv*, jobject, jclass, + jmethodID, va_list); + jfloat (*CallNonvirtualFloatMethodA)(JNIEnv*, jobject, jclass, + jmethodID, jvalue*); + jdouble (*CallNonvirtualDoubleMethod)(JNIEnv*, jobject, jclass, + jmethodID, ...); + jdouble (*CallNonvirtualDoubleMethodV)(JNIEnv*, jobject, jclass, + jmethodID, va_list); + jdouble (*CallNonvirtualDoubleMethodA)(JNIEnv*, jobject, jclass, + jmethodID, jvalue*); + void (*CallNonvirtualVoidMethod)(JNIEnv*, jobject, jclass, + jmethodID, ...); + void (*CallNonvirtualVoidMethodV)(JNIEnv*, jobject, jclass, + jmethodID, va_list); + void (*CallNonvirtualVoidMethodA)(JNIEnv*, jobject, jclass, + jmethodID, jvalue*); + + jfieldID (*GetFieldID)(JNIEnv*, jclass, const char*, const char*); + + jobject (*GetObjectField)(JNIEnv*, jobject, jfieldID); + jboolean (*GetBooleanField)(JNIEnv*, jobject, jfieldID); + jbyte (*GetByteField)(JNIEnv*, jobject, jfieldID); + jchar (*GetCharField)(JNIEnv*, jobject, jfieldID); + jshort (*GetShortField)(JNIEnv*, jobject, jfieldID); + jint (*GetIntField)(JNIEnv*, jobject, jfieldID); + jlong (*GetLongField)(JNIEnv*, jobject, jfieldID); + jfloat (*GetFloatField)(JNIEnv*, jobject, jfieldID); + jdouble (*GetDoubleField)(JNIEnv*, jobject, jfieldID); + + void (*SetObjectField)(JNIEnv*, jobject, jfieldID, jobject); + void (*SetBooleanField)(JNIEnv*, jobject, jfieldID, jboolean); + void (*SetByteField)(JNIEnv*, jobject, jfieldID, jbyte); + void (*SetCharField)(JNIEnv*, jobject, jfieldID, jchar); + void (*SetShortField)(JNIEnv*, jobject, jfieldID, jshort); + void (*SetIntField)(JNIEnv*, jobject, jfieldID, jint); + void (*SetLongField)(JNIEnv*, jobject, jfieldID, jlong); + void (*SetFloatField)(JNIEnv*, jobject, jfieldID, jfloat); + void (*SetDoubleField)(JNIEnv*, jobject, jfieldID, jdouble); + + jmethodID (*GetStaticMethodID)(JNIEnv*, jclass, const char*, const char*); + + jobject (*CallStaticObjectMethod)(JNIEnv*, jclass, jmethodID, ...); + jobject (*CallStaticObjectMethodV)(JNIEnv*, jclass, jmethodID, va_list); + jobject (*CallStaticObjectMethodA)(JNIEnv*, jclass, jmethodID, jvalue*); + jboolean (*CallStaticBooleanMethod)(JNIEnv*, jclass, jmethodID, ...); + jboolean (*CallStaticBooleanMethodV)(JNIEnv*, jclass, jmethodID, + va_list); + jboolean (*CallStaticBooleanMethodA)(JNIEnv*, jclass, jmethodID, + jvalue*); + jbyte (*CallStaticByteMethod)(JNIEnv*, jclass, jmethodID, ...); + jbyte (*CallStaticByteMethodV)(JNIEnv*, jclass, jmethodID, va_list); + jbyte (*CallStaticByteMethodA)(JNIEnv*, jclass, jmethodID, jvalue*); + jchar (*CallStaticCharMethod)(JNIEnv*, jclass, jmethodID, ...); + jchar (*CallStaticCharMethodV)(JNIEnv*, jclass, jmethodID, va_list); + jchar (*CallStaticCharMethodA)(JNIEnv*, jclass, jmethodID, jvalue*); + jshort (*CallStaticShortMethod)(JNIEnv*, jclass, jmethodID, ...); + jshort (*CallStaticShortMethodV)(JNIEnv*, jclass, jmethodID, va_list); + jshort (*CallStaticShortMethodA)(JNIEnv*, jclass, jmethodID, jvalue*); + jint (*CallStaticIntMethod)(JNIEnv*, jclass, jmethodID, ...); + jint (*CallStaticIntMethodV)(JNIEnv*, jclass, jmethodID, va_list); + jint (*CallStaticIntMethodA)(JNIEnv*, jclass, jmethodID, jvalue*); + jlong (*CallStaticLongMethod)(JNIEnv*, jclass, jmethodID, ...); + jlong (*CallStaticLongMethodV)(JNIEnv*, jclass, jmethodID, va_list); + jlong (*CallStaticLongMethodA)(JNIEnv*, jclass, jmethodID, jvalue*); + jfloat (*CallStaticFloatMethod)(JNIEnv*, jclass, jmethodID, ...); + jfloat (*CallStaticFloatMethodV)(JNIEnv*, jclass, jmethodID, va_list); + jfloat (*CallStaticFloatMethodA)(JNIEnv*, jclass, jmethodID, jvalue*); + jdouble (*CallStaticDoubleMethod)(JNIEnv*, jclass, jmethodID, ...); + jdouble (*CallStaticDoubleMethodV)(JNIEnv*, jclass, jmethodID, va_list); + jdouble (*CallStaticDoubleMethodA)(JNIEnv*, jclass, jmethodID, jvalue*); + void (*CallStaticVoidMethod)(JNIEnv*, jclass, jmethodID, ...); + void (*CallStaticVoidMethodV)(JNIEnv*, jclass, jmethodID, va_list); + void (*CallStaticVoidMethodA)(JNIEnv*, jclass, jmethodID, jvalue*); + + jfieldID (*GetStaticFieldID)(JNIEnv*, jclass, const char*, + const char*); + + jobject (*GetStaticObjectField)(JNIEnv*, jclass, jfieldID); + jboolean (*GetStaticBooleanField)(JNIEnv*, jclass, jfieldID); + jbyte (*GetStaticByteField)(JNIEnv*, jclass, jfieldID); + jchar (*GetStaticCharField)(JNIEnv*, jclass, jfieldID); + jshort (*GetStaticShortField)(JNIEnv*, jclass, jfieldID); + jint (*GetStaticIntField)(JNIEnv*, jclass, jfieldID); + jlong (*GetStaticLongField)(JNIEnv*, jclass, jfieldID); + jfloat (*GetStaticFloatField)(JNIEnv*, jclass, jfieldID); + jdouble (*GetStaticDoubleField)(JNIEnv*, jclass, jfieldID); + + void (*SetStaticObjectField)(JNIEnv*, jclass, jfieldID, jobject); + void (*SetStaticBooleanField)(JNIEnv*, jclass, jfieldID, jboolean); + void (*SetStaticByteField)(JNIEnv*, jclass, jfieldID, jbyte); + void (*SetStaticCharField)(JNIEnv*, jclass, jfieldID, jchar); + void (*SetStaticShortField)(JNIEnv*, jclass, jfieldID, jshort); + void (*SetStaticIntField)(JNIEnv*, jclass, jfieldID, jint); + void (*SetStaticLongField)(JNIEnv*, jclass, jfieldID, jlong); + void (*SetStaticFloatField)(JNIEnv*, jclass, jfieldID, jfloat); + void (*SetStaticDoubleField)(JNIEnv*, jclass, jfieldID, jdouble); + + jstring (*NewString)(JNIEnv*, const jchar*, jsize); + jsize (*GetStringLength)(JNIEnv*, jstring); + const jchar* (*GetStringChars)(JNIEnv*, jstring, jboolean*); + void (*ReleaseStringChars)(JNIEnv*, jstring, const jchar*); + jstring (*NewStringUTF)(JNIEnv*, const char*); + jsize (*GetStringUTFLength)(JNIEnv*, jstring); + /* JNI spec says this returns const jbyte*, but that's inconsistent */ + const char* (*GetStringUTFChars)(JNIEnv*, jstring, jboolean*); + void (*ReleaseStringUTFChars)(JNIEnv*, jstring, const char*); + jsize (*GetArrayLength)(JNIEnv*, jarray); + jobjectArray (*NewObjectArray)(JNIEnv*, jsize, jclass, jobject); + jobject (*GetObjectArrayElement)(JNIEnv*, jobjectArray, jsize); + void (*SetObjectArrayElement)(JNIEnv*, jobjectArray, jsize, jobject); + + jbooleanArray (*NewBooleanArray)(JNIEnv*, jsize); + jbyteArray (*NewByteArray)(JNIEnv*, jsize); + jcharArray (*NewCharArray)(JNIEnv*, jsize); + jshortArray (*NewShortArray)(JNIEnv*, jsize); + jintArray (*NewIntArray)(JNIEnv*, jsize); + jlongArray (*NewLongArray)(JNIEnv*, jsize); + jfloatArray (*NewFloatArray)(JNIEnv*, jsize); + jdoubleArray (*NewDoubleArray)(JNIEnv*, jsize); + + jboolean* (*GetBooleanArrayElements)(JNIEnv*, jbooleanArray, jboolean*); + jbyte* (*GetByteArrayElements)(JNIEnv*, jbyteArray, jboolean*); + jchar* (*GetCharArrayElements)(JNIEnv*, jcharArray, jboolean*); + jshort* (*GetShortArrayElements)(JNIEnv*, jshortArray, jboolean*); + jint* (*GetIntArrayElements)(JNIEnv*, jintArray, jboolean*); + jlong* (*GetLongArrayElements)(JNIEnv*, jlongArray, jboolean*); + jfloat* (*GetFloatArrayElements)(JNIEnv*, jfloatArray, jboolean*); + jdouble* (*GetDoubleArrayElements)(JNIEnv*, jdoubleArray, jboolean*); + + void (*ReleaseBooleanArrayElements)(JNIEnv*, jbooleanArray, + jboolean*, jint); + void (*ReleaseByteArrayElements)(JNIEnv*, jbyteArray, + jbyte*, jint); + void (*ReleaseCharArrayElements)(JNIEnv*, jcharArray, + jchar*, jint); + void (*ReleaseShortArrayElements)(JNIEnv*, jshortArray, + jshort*, jint); + void (*ReleaseIntArrayElements)(JNIEnv*, jintArray, + jint*, jint); + void (*ReleaseLongArrayElements)(JNIEnv*, jlongArray, + jlong*, jint); + void (*ReleaseFloatArrayElements)(JNIEnv*, jfloatArray, + jfloat*, jint); + void (*ReleaseDoubleArrayElements)(JNIEnv*, jdoubleArray, + jdouble*, jint); + + void (*GetBooleanArrayRegion)(JNIEnv*, jbooleanArray, + jsize, jsize, jboolean*); + void (*GetByteArrayRegion)(JNIEnv*, jbyteArray, + jsize, jsize, jbyte*); + void (*GetCharArrayRegion)(JNIEnv*, jcharArray, + jsize, jsize, jchar*); + void (*GetShortArrayRegion)(JNIEnv*, jshortArray, + jsize, jsize, jshort*); + void (*GetIntArrayRegion)(JNIEnv*, jintArray, + jsize, jsize, jint*); + void (*GetLongArrayRegion)(JNIEnv*, jlongArray, + jsize, jsize, jlong*); + void (*GetFloatArrayRegion)(JNIEnv*, jfloatArray, + jsize, jsize, jfloat*); + void (*GetDoubleArrayRegion)(JNIEnv*, jdoubleArray, + jsize, jsize, jdouble*); + + /* spec shows these without const; some jni.h do, some don't */ + void (*SetBooleanArrayRegion)(JNIEnv*, jbooleanArray, + jsize, jsize, const jboolean*); + void (*SetByteArrayRegion)(JNIEnv*, jbyteArray, + jsize, jsize, const jbyte*); + void (*SetCharArrayRegion)(JNIEnv*, jcharArray, + jsize, jsize, const jchar*); + void (*SetShortArrayRegion)(JNIEnv*, jshortArray, + jsize, jsize, const jshort*); + void (*SetIntArrayRegion)(JNIEnv*, jintArray, + jsize, jsize, const jint*); + void (*SetLongArrayRegion)(JNIEnv*, jlongArray, + jsize, jsize, const jlong*); + void (*SetFloatArrayRegion)(JNIEnv*, jfloatArray, + jsize, jsize, const jfloat*); + void (*SetDoubleArrayRegion)(JNIEnv*, jdoubleArray, + jsize, jsize, const jdouble*); + + jint (*RegisterNatives)(JNIEnv*, jclass, const JNINativeMethod*, + jint); + jint (*UnregisterNatives)(JNIEnv*, jclass); + jint (*MonitorEnter)(JNIEnv*, jobject); + jint (*MonitorExit)(JNIEnv*, jobject); + jint (*GetJavaVM)(JNIEnv*, JavaVM**); + + void (*GetStringRegion)(JNIEnv*, jstring, jsize, jsize, jchar*); + void (*GetStringUTFRegion)(JNIEnv*, jstring, jsize, jsize, char*); + + void* (*GetPrimitiveArrayCritical)(JNIEnv*, jarray, jboolean*); + void (*ReleasePrimitiveArrayCritical)(JNIEnv*, jarray, void*, jint); + + const jchar* (*GetStringCritical)(JNIEnv*, jstring, jboolean*); + void (*ReleaseStringCritical)(JNIEnv*, jstring, const jchar*); + + jweak (*NewWeakGlobalRef)(JNIEnv*, jobject); + void (*DeleteWeakGlobalRef)(JNIEnv*, jweak); + + jboolean (*ExceptionCheck)(JNIEnv*); + jobject (*NewDirectByteBuffer)(JNIEnv*, void*, jlong); + + void* (*GetDirectBufferAddress)(JNIEnv*, jobject); + jlong (*GetDirectBufferCapacity)(JNIEnv*, jobject); + + /* added in JNI 1.6 */ + jobjectRefType (*GetObjectRefType)(JNIEnv*, jobject); +}; + +/* + * C++ object wrapper. + * + * This is usually overlaid on a C struct whose first element is a + * JNINativeInterface*. We rely somewhat on compiler behavior. + */ +struct _JNIEnv { + /* do not rename this; it does not seem to be entirely opaque */ + const struct JNINativeInterface* functions; + +#if defined(__cplusplus) + jint GetVersion() + { return functions->GetVersion(this); } + + jclass DefineClass(const char *name, jobject loader, const jbyte* buf, + jsize bufLen) + { return functions->DefineClass(this, name, loader, buf, bufLen); } + + jclass FindClass(const char* name) + { return functions->FindClass(this, name); } + + jmethodID FromReflectedMethod(jobject method) + { return functions->FromReflectedMethod(this, method); } + + jfieldID FromReflectedField(jobject field) + { return functions->FromReflectedField(this, field); } + + jobject ToReflectedMethod(jclass cls, jmethodID methodID, jboolean isStatic) + { return functions->ToReflectedMethod(this, cls, methodID, isStatic); } + + jclass GetSuperclass(jclass clazz) + { return functions->GetSuperclass(this, clazz); } + + jboolean IsAssignableFrom(jclass clazz1, jclass clazz2) + { return functions->IsAssignableFrom(this, clazz1, clazz2); } + + jobject ToReflectedField(jclass cls, jfieldID fieldID, jboolean isStatic) + { return functions->ToReflectedField(this, cls, fieldID, isStatic); } + + jint Throw(jthrowable obj) + { return functions->Throw(this, obj); } + + jint ThrowNew(jclass clazz, const char* message) + { return functions->ThrowNew(this, clazz, message); } + + jthrowable ExceptionOccurred() + { return functions->ExceptionOccurred(this); } + + void ExceptionDescribe() + { functions->ExceptionDescribe(this); } + + void ExceptionClear() + { functions->ExceptionClear(this); } + + void FatalError(const char* msg) + { functions->FatalError(this, msg); } + + jint PushLocalFrame(jint capacity) + { return functions->PushLocalFrame(this, capacity); } + + jobject PopLocalFrame(jobject result) + { return functions->PopLocalFrame(this, result); } + + jobject NewGlobalRef(jobject obj) + { return functions->NewGlobalRef(this, obj); } + + void DeleteGlobalRef(jobject globalRef) + { functions->DeleteGlobalRef(this, globalRef); } + + void DeleteLocalRef(jobject localRef) + { functions->DeleteLocalRef(this, localRef); } + + jboolean IsSameObject(jobject ref1, jobject ref2) + { return functions->IsSameObject(this, ref1, ref2); } + + jobject NewLocalRef(jobject ref) + { return functions->NewLocalRef(this, ref); } + + jint EnsureLocalCapacity(jint capacity) + { return functions->EnsureLocalCapacity(this, capacity); } + + jobject AllocObject(jclass clazz) + { return functions->AllocObject(this, clazz); } + + jobject NewObject(jclass clazz, jmethodID methodID, ...) + { + va_list args; + va_start(args, methodID); + jobject result = functions->NewObjectV(this, clazz, methodID, args); + va_end(args); + return result; + } + + jobject NewObjectV(jclass clazz, jmethodID methodID, va_list args) + { return functions->NewObjectV(this, clazz, methodID, args); } + + jobject NewObjectA(jclass clazz, jmethodID methodID, jvalue* args) + { return functions->NewObjectA(this, clazz, methodID, args); } + + jclass GetObjectClass(jobject obj) + { return functions->GetObjectClass(this, obj); } + + jboolean IsInstanceOf(jobject obj, jclass clazz) + { return functions->IsInstanceOf(this, obj, clazz); } + + jmethodID GetMethodID(jclass clazz, const char* name, const char* sig) + { return functions->GetMethodID(this, clazz, name, sig); } + +#define CALL_TYPE_METHOD(_jtype, _jname) \ + _jtype Call##_jname##Method(jobject obj, jmethodID methodID, ...) \ + { \ + _jtype result; \ + va_list args; \ + va_start(args, methodID); \ + result = functions->Call##_jname##MethodV(this, obj, methodID, \ + args); \ + va_end(args); \ + return result; \ + } + +#define CALL_TYPE_METHODV(_jtype, _jname) \ + _jtype Call##_jname##MethodV(jobject obj, jmethodID methodID, \ + va_list args) \ + { return functions->Call##_jname##MethodV(this, obj, methodID, args); } + +#define CALL_TYPE_METHODA(_jtype, _jname) \ + _jtype Call##_jname##MethodA(jobject obj, jmethodID methodID, \ + jvalue* args) \ + { return functions->Call##_jname##MethodA(this, obj, methodID, args); } + +#define CALL_TYPE(_jtype, _jname) \ + CALL_TYPE_METHOD(_jtype, _jname) \ + CALL_TYPE_METHODV(_jtype, _jname) \ + CALL_TYPE_METHODA(_jtype, _jname) + CALL_TYPE(jobject, Object) + CALL_TYPE(jboolean, Boolean) + CALL_TYPE(jbyte, Byte) + CALL_TYPE(jchar, Char) + CALL_TYPE(jshort, Short) + CALL_TYPE(jint, Int) + CALL_TYPE(jlong, Long) + CALL_TYPE(jfloat, Float) + CALL_TYPE(jdouble, Double) + + void CallVoidMethod(jobject obj, jmethodID methodID, ...) + { + va_list args; + va_start(args, methodID); + functions->CallVoidMethodV(this, obj, methodID, args); + va_end(args); + } + + void CallVoidMethodV(jobject obj, jmethodID methodID, va_list args) + { functions->CallVoidMethodV(this, obj, methodID, args); } + + void CallVoidMethodA(jobject obj, jmethodID methodID, jvalue* args) + { functions->CallVoidMethodA(this, obj, methodID, args); } +#define CALL_NONVIRT_TYPE_METHOD(_jtype, _jname) \ + _jtype CallNonvirtual##_jname##Method(jobject obj, jclass clazz, \ + jmethodID methodID, ...) \ + { \ + _jtype result; \ + va_list args; \ + va_start(args, methodID); \ + result = functions->CallNonvirtual##_jname##MethodV(this, obj, \ + clazz, methodID, args); \ + va_end(args); \ + return result; \ + } +#define CALL_NONVIRT_TYPE_METHODV(_jtype, _jname) \ + _jtype CallNonvirtual##_jname##MethodV(jobject obj, jclass clazz, \ + jmethodID methodID, va_list args) \ + { return functions->CallNonvirtual##_jname##MethodV(this, obj, clazz, \ + methodID, args); } +#define CALL_NONVIRT_TYPE_METHODA(_jtype, _jname) \ + _jtype CallNonvirtual##_jname##MethodA(jobject obj, jclass clazz, \ + jmethodID methodID, jvalue* args) \ + { return functions->CallNonvirtual##_jname##MethodA(this, obj, clazz, \ + methodID, args); } +#define CALL_NONVIRT_TYPE(_jtype, _jname) \ + CALL_NONVIRT_TYPE_METHOD(_jtype, _jname) \ + CALL_NONVIRT_TYPE_METHODV(_jtype, _jname) \ + CALL_NONVIRT_TYPE_METHODA(_jtype, _jname) + CALL_NONVIRT_TYPE(jobject, Object) + CALL_NONVIRT_TYPE(jboolean, Boolean) + CALL_NONVIRT_TYPE(jbyte, Byte) + CALL_NONVIRT_TYPE(jchar, Char) + CALL_NONVIRT_TYPE(jshort, Short) + CALL_NONVIRT_TYPE(jint, Int) + CALL_NONVIRT_TYPE(jlong, Long) + CALL_NONVIRT_TYPE(jfloat, Float) + CALL_NONVIRT_TYPE(jdouble, Double) + void CallNonvirtualVoidMethod(jobject obj, jclass clazz, + jmethodID methodID, ...) + { + va_list args; + va_start(args, methodID); + functions->CallNonvirtualVoidMethodV(this, obj, clazz, methodID, args); + va_end(args); + } + void CallNonvirtualVoidMethodV(jobject obj, jclass clazz, + jmethodID methodID, va_list args) + { functions->CallNonvirtualVoidMethodV(this, obj, clazz, methodID, args); } + void CallNonvirtualVoidMethodA(jobject obj, jclass clazz, + jmethodID methodID, jvalue* args) + { functions->CallNonvirtualVoidMethodA(this, obj, clazz, methodID, args); } + jfieldID GetFieldID(jclass clazz, const char* name, const char* sig) + { return functions->GetFieldID(this, clazz, name, sig); } + jobject GetObjectField(jobject obj, jfieldID fieldID) + { return functions->GetObjectField(this, obj, fieldID); } + jboolean GetBooleanField(jobject obj, jfieldID fieldID) + { return functions->GetBooleanField(this, obj, fieldID); } + jbyte GetByteField(jobject obj, jfieldID fieldID) + { return functions->GetByteField(this, obj, fieldID); } + jchar GetCharField(jobject obj, jfieldID fieldID) + { return functions->GetCharField(this, obj, fieldID); } + jshort GetShortField(jobject obj, jfieldID fieldID) + { return functions->GetShortField(this, obj, fieldID); } + jint GetIntField(jobject obj, jfieldID fieldID) + { return functions->GetIntField(this, obj, fieldID); } + jlong GetLongField(jobject obj, jfieldID fieldID) + { return functions->GetLongField(this, obj, fieldID); } + jfloat GetFloatField(jobject obj, jfieldID fieldID) + { return functions->GetFloatField(this, obj, fieldID); } + jdouble GetDoubleField(jobject obj, jfieldID fieldID) + { return functions->GetDoubleField(this, obj, fieldID); } + void SetObjectField(jobject obj, jfieldID fieldID, jobject value) + { functions->SetObjectField(this, obj, fieldID, value); } + void SetBooleanField(jobject obj, jfieldID fieldID, jboolean value) + { functions->SetBooleanField(this, obj, fieldID, value); } + void SetByteField(jobject obj, jfieldID fieldID, jbyte value) + { functions->SetByteField(this, obj, fieldID, value); } + void SetCharField(jobject obj, jfieldID fieldID, jchar value) + { functions->SetCharField(this, obj, fieldID, value); } + void SetShortField(jobject obj, jfieldID fieldID, jshort value) + { functions->SetShortField(this, obj, fieldID, value); } + void SetIntField(jobject obj, jfieldID fieldID, jint value) + { functions->SetIntField(this, obj, fieldID, value); } + void SetLongField(jobject obj, jfieldID fieldID, jlong value) + { functions->SetLongField(this, obj, fieldID, value); } + void SetFloatField(jobject obj, jfieldID fieldID, jfloat value) + { functions->SetFloatField(this, obj, fieldID, value); } + void SetDoubleField(jobject obj, jfieldID fieldID, jdouble value) + { functions->SetDoubleField(this, obj, fieldID, value); } + jmethodID GetStaticMethodID(jclass clazz, const char* name, const char* sig) + { return functions->GetStaticMethodID(this, clazz, name, sig); } + +#define CALL_STATIC_TYPE_METHOD(_jtype, _jname) \ + _jtype CallStatic##_jname##Method(jclass clazz, jmethodID methodID, \ + ...) \ + { \ + _jtype result; \ + va_list args; \ + va_start(args, methodID); \ + result = functions->CallStatic##_jname##MethodV(this, clazz, \ + methodID, args); \ + va_end(args); \ + return result; \ + } +#define CALL_STATIC_TYPE_METHODV(_jtype, _jname) \ + _jtype CallStatic##_jname##MethodV(jclass clazz, jmethodID methodID, \ + va_list args) \ + { return functions->CallStatic##_jname##MethodV(this, clazz, methodID, \ + args); } +#define CALL_STATIC_TYPE_METHODA(_jtype, _jname) \ + _jtype CallStatic##_jname##MethodA(jclass clazz, jmethodID methodID, \ + jvalue* args) \ + { return functions->CallStatic##_jname##MethodA(this, clazz, methodID, \ + args); } + +#define CALL_STATIC_TYPE(_jtype, _jname) \ + CALL_STATIC_TYPE_METHOD(_jtype, _jname) \ + CALL_STATIC_TYPE_METHODV(_jtype, _jname) \ + CALL_STATIC_TYPE_METHODA(_jtype, _jname) + CALL_STATIC_TYPE(jobject, Object) + CALL_STATIC_TYPE(jboolean, Boolean) + CALL_STATIC_TYPE(jbyte, Byte) + CALL_STATIC_TYPE(jchar, Char) + CALL_STATIC_TYPE(jshort, Short) + CALL_STATIC_TYPE(jint, Int) + CALL_STATIC_TYPE(jlong, Long) + CALL_STATIC_TYPE(jfloat, Float) + CALL_STATIC_TYPE(jdouble, Double) + void CallStaticVoidMethod(jclass clazz, jmethodID methodID, ...) + { + va_list args; + va_start(args, methodID); + functions->CallStaticVoidMethodV(this, clazz, methodID, args); + va_end(args); + } + void CallStaticVoidMethodV(jclass clazz, jmethodID methodID, va_list args) + { functions->CallStaticVoidMethodV(this, clazz, methodID, args); } + void CallStaticVoidMethodA(jclass clazz, jmethodID methodID, jvalue* args) + { functions->CallStaticVoidMethodA(this, clazz, methodID, args); } + + jfieldID GetStaticFieldID(jclass clazz, const char* name, const char* sig) + { return functions->GetStaticFieldID(this, clazz, name, sig); } + + jobject GetStaticObjectField(jclass clazz, jfieldID fieldID) + { return functions->GetStaticObjectField(this, clazz, fieldID); } + jboolean GetStaticBooleanField(jclass clazz, jfieldID fieldID) + { return functions->GetStaticBooleanField(this, clazz, fieldID); } + jbyte GetStaticByteField(jclass clazz, jfieldID fieldID) + { return functions->GetStaticByteField(this, clazz, fieldID); } + jchar GetStaticCharField(jclass clazz, jfieldID fieldID) + { return functions->GetStaticCharField(this, clazz, fieldID); } + jshort GetStaticShortField(jclass clazz, jfieldID fieldID) + { return functions->GetStaticShortField(this, clazz, fieldID); } + jint GetStaticIntField(jclass clazz, jfieldID fieldID) + { return functions->GetStaticIntField(this, clazz, fieldID); } + jlong GetStaticLongField(jclass clazz, jfieldID fieldID) + { return functions->GetStaticLongField(this, clazz, fieldID); } + jfloat GetStaticFloatField(jclass clazz, jfieldID fieldID) + { return functions->GetStaticFloatField(this, clazz, fieldID); } + jdouble GetStaticDoubleField(jclass clazz, jfieldID fieldID) + { return functions->GetStaticDoubleField(this, clazz, fieldID); } + + void SetStaticObjectField(jclass clazz, jfieldID fieldID, jobject value) + { functions->SetStaticObjectField(this, clazz, fieldID, value); } + void SetStaticBooleanField(jclass clazz, jfieldID fieldID, jboolean value) + { functions->SetStaticBooleanField(this, clazz, fieldID, value); } + void SetStaticByteField(jclass clazz, jfieldID fieldID, jbyte value) + { functions->SetStaticByteField(this, clazz, fieldID, value); } + void SetStaticCharField(jclass clazz, jfieldID fieldID, jchar value) + { functions->SetStaticCharField(this, clazz, fieldID, value); } + void SetStaticShortField(jclass clazz, jfieldID fieldID, jshort value) + { functions->SetStaticShortField(this, clazz, fieldID, value); } + void SetStaticIntField(jclass clazz, jfieldID fieldID, jint value) + { functions->SetStaticIntField(this, clazz, fieldID, value); } + void SetStaticLongField(jclass clazz, jfieldID fieldID, jlong value) + { functions->SetStaticLongField(this, clazz, fieldID, value); } + void SetStaticFloatField(jclass clazz, jfieldID fieldID, jfloat value) + { functions->SetStaticFloatField(this, clazz, fieldID, value); } + void SetStaticDoubleField(jclass clazz, jfieldID fieldID, jdouble value) + { functions->SetStaticDoubleField(this, clazz, fieldID, value); } + + jstring NewString(const jchar* unicodeChars, jsize len) + { return functions->NewString(this, unicodeChars, len); } + + jsize GetStringLength(jstring string) + { return functions->GetStringLength(this, string); } + + const jchar* GetStringChars(jstring string, jboolean* isCopy) + { return functions->GetStringChars(this, string, isCopy); } + + void ReleaseStringChars(jstring string, const jchar* chars) + { functions->ReleaseStringChars(this, string, chars); } + + jstring NewStringUTF(const char* bytes) + { return functions->NewStringUTF(this, bytes); } + + jsize GetStringUTFLength(jstring string) + { return functions->GetStringUTFLength(this, string); } + + const char* GetStringUTFChars(jstring string, jboolean* isCopy) + { return functions->GetStringUTFChars(this, string, isCopy); } + + void ReleaseStringUTFChars(jstring string, const char* utf) + { functions->ReleaseStringUTFChars(this, string, utf); } + + jsize GetArrayLength(jarray array) + { return functions->GetArrayLength(this, array); } + + jobjectArray NewObjectArray(jsize length, jclass elementClass, + jobject initialElement) + { return functions->NewObjectArray(this, length, elementClass, + initialElement); } + + jobject GetObjectArrayElement(jobjectArray array, jsize index) + { return functions->GetObjectArrayElement(this, array, index); } + + void SetObjectArrayElement(jobjectArray array, jsize index, jobject value) + { functions->SetObjectArrayElement(this, array, index, value); } + + jbooleanArray NewBooleanArray(jsize length) + { return functions->NewBooleanArray(this, length); } + jbyteArray NewByteArray(jsize length) + { return functions->NewByteArray(this, length); } + jcharArray NewCharArray(jsize length) + { return functions->NewCharArray(this, length); } + jshortArray NewShortArray(jsize length) + { return functions->NewShortArray(this, length); } + jintArray NewIntArray(jsize length) + { return functions->NewIntArray(this, length); } + jlongArray NewLongArray(jsize length) + { return functions->NewLongArray(this, length); } + jfloatArray NewFloatArray(jsize length) + { return functions->NewFloatArray(this, length); } + jdoubleArray NewDoubleArray(jsize length) + { return functions->NewDoubleArray(this, length); } + + jboolean* GetBooleanArrayElements(jbooleanArray array, jboolean* isCopy) + { return functions->GetBooleanArrayElements(this, array, isCopy); } + jbyte* GetByteArrayElements(jbyteArray array, jboolean* isCopy) + { return functions->GetByteArrayElements(this, array, isCopy); } + jchar* GetCharArrayElements(jcharArray array, jboolean* isCopy) + { return functions->GetCharArrayElements(this, array, isCopy); } + jshort* GetShortArrayElements(jshortArray array, jboolean* isCopy) + { return functions->GetShortArrayElements(this, array, isCopy); } + jint* GetIntArrayElements(jintArray array, jboolean* isCopy) + { return functions->GetIntArrayElements(this, array, isCopy); } + jlong* GetLongArrayElements(jlongArray array, jboolean* isCopy) + { return functions->GetLongArrayElements(this, array, isCopy); } + jfloat* GetFloatArrayElements(jfloatArray array, jboolean* isCopy) + { return functions->GetFloatArrayElements(this, array, isCopy); } + jdouble* GetDoubleArrayElements(jdoubleArray array, jboolean* isCopy) + { return functions->GetDoubleArrayElements(this, array, isCopy); } + + void ReleaseBooleanArrayElements(jbooleanArray array, jboolean* elems, + jint mode) + { functions->ReleaseBooleanArrayElements(this, array, elems, mode); } + void ReleaseByteArrayElements(jbyteArray array, jbyte* elems, + jint mode) + { functions->ReleaseByteArrayElements(this, array, elems, mode); } + void ReleaseCharArrayElements(jcharArray array, jchar* elems, + jint mode) + { functions->ReleaseCharArrayElements(this, array, elems, mode); } + void ReleaseShortArrayElements(jshortArray array, jshort* elems, + jint mode) + { functions->ReleaseShortArrayElements(this, array, elems, mode); } + void ReleaseIntArrayElements(jintArray array, jint* elems, + jint mode) + { functions->ReleaseIntArrayElements(this, array, elems, mode); } + void ReleaseLongArrayElements(jlongArray array, jlong* elems, + jint mode) + { functions->ReleaseLongArrayElements(this, array, elems, mode); } + void ReleaseFloatArrayElements(jfloatArray array, jfloat* elems, + jint mode) + { functions->ReleaseFloatArrayElements(this, array, elems, mode); } + void ReleaseDoubleArrayElements(jdoubleArray array, jdouble* elems, + jint mode) + { functions->ReleaseDoubleArrayElements(this, array, elems, mode); } + + void GetBooleanArrayRegion(jbooleanArray array, jsize start, jsize len, + jboolean* buf) + { functions->GetBooleanArrayRegion(this, array, start, len, buf); } + void GetByteArrayRegion(jbyteArray array, jsize start, jsize len, + jbyte* buf) + { functions->GetByteArrayRegion(this, array, start, len, buf); } + void GetCharArrayRegion(jcharArray array, jsize start, jsize len, + jchar* buf) + { functions->GetCharArrayRegion(this, array, start, len, buf); } + void GetShortArrayRegion(jshortArray array, jsize start, jsize len, + jshort* buf) + { functions->GetShortArrayRegion(this, array, start, len, buf); } + void GetIntArrayRegion(jintArray array, jsize start, jsize len, + jint* buf) + { functions->GetIntArrayRegion(this, array, start, len, buf); } + void GetLongArrayRegion(jlongArray array, jsize start, jsize len, + jlong* buf) + { functions->GetLongArrayRegion(this, array, start, len, buf); } + void GetFloatArrayRegion(jfloatArray array, jsize start, jsize len, + jfloat* buf) + { functions->GetFloatArrayRegion(this, array, start, len, buf); } + void GetDoubleArrayRegion(jdoubleArray array, jsize start, jsize len, + jdouble* buf) + { functions->GetDoubleArrayRegion(this, array, start, len, buf); } + + void SetBooleanArrayRegion(jbooleanArray array, jsize start, jsize len, + const jboolean* buf) + { functions->SetBooleanArrayRegion(this, array, start, len, buf); } + void SetByteArrayRegion(jbyteArray array, jsize start, jsize len, + const jbyte* buf) + { functions->SetByteArrayRegion(this, array, start, len, buf); } + void SetCharArrayRegion(jcharArray array, jsize start, jsize len, + const jchar* buf) + { functions->SetCharArrayRegion(this, array, start, len, buf); } + void SetShortArrayRegion(jshortArray array, jsize start, jsize len, + const jshort* buf) + { functions->SetShortArrayRegion(this, array, start, len, buf); } + void SetIntArrayRegion(jintArray array, jsize start, jsize len, + const jint* buf) + { functions->SetIntArrayRegion(this, array, start, len, buf); } + void SetLongArrayRegion(jlongArray array, jsize start, jsize len, + const jlong* buf) + { functions->SetLongArrayRegion(this, array, start, len, buf); } + void SetFloatArrayRegion(jfloatArray array, jsize start, jsize len, + const jfloat* buf) + { functions->SetFloatArrayRegion(this, array, start, len, buf); } + void SetDoubleArrayRegion(jdoubleArray array, jsize start, jsize len, + const jdouble* buf) + { functions->SetDoubleArrayRegion(this, array, start, len, buf); } + + jint RegisterNatives(jclass clazz, const JNINativeMethod* methods, + jint nMethods) + { return functions->RegisterNatives(this, clazz, methods, nMethods); } + + jint UnregisterNatives(jclass clazz) + { return functions->UnregisterNatives(this, clazz); } + + jint MonitorEnter(jobject obj) + { return functions->MonitorEnter(this, obj); } + + jint MonitorExit(jobject obj) + { return functions->MonitorExit(this, obj); } + + jint GetJavaVM(JavaVM** vm) + { return functions->GetJavaVM(this, vm); } + + void GetStringRegion(jstring str, jsize start, jsize len, jchar* buf) + { functions->GetStringRegion(this, str, start, len, buf); } + + void GetStringUTFRegion(jstring str, jsize start, jsize len, char* buf) + { return functions->GetStringUTFRegion(this, str, start, len, buf); } + + void* GetPrimitiveArrayCritical(jarray array, jboolean* isCopy) + { return functions->GetPrimitiveArrayCritical(this, array, isCopy); } + + void ReleasePrimitiveArrayCritical(jarray array, void* carray, jint mode) + { functions->ReleasePrimitiveArrayCritical(this, array, carray, mode); } + + const jchar* GetStringCritical(jstring string, jboolean* isCopy) + { return functions->GetStringCritical(this, string, isCopy); } + + void ReleaseStringCritical(jstring string, const jchar* carray) + { functions->ReleaseStringCritical(this, string, carray); } + + jweak NewWeakGlobalRef(jobject obj) + { return functions->NewWeakGlobalRef(this, obj); } + + void DeleteWeakGlobalRef(jweak obj) + { functions->DeleteWeakGlobalRef(this, obj); } + + jboolean ExceptionCheck() + { return functions->ExceptionCheck(this); } + + jobject NewDirectByteBuffer(void* address, jlong capacity) + { return functions->NewDirectByteBuffer(this, address, capacity); } + + void* GetDirectBufferAddress(jobject buf) + { return functions->GetDirectBufferAddress(this, buf); } + + jlong GetDirectBufferCapacity(jobject buf) + { return functions->GetDirectBufferCapacity(this, buf); } + + /* added in JNI 1.6 */ + jobjectRefType GetObjectRefType(jobject obj) + { return functions->GetObjectRefType(this, obj); } +#endif /*__cplusplus*/ +}; + + +/* + * JNI invocation interface. + */ +struct JNIInvokeInterface { + void* reserved0; + void* reserved1; + void* reserved2; + jint (*DestroyJavaVM)(JavaVM*); + jint (*AttachCurrentThread)(JavaVM*, JNIEnv**, void*); + jint (*DetachCurrentThread)(JavaVM*); + jint (*GetEnv)(JavaVM*, void**, jint); + jint (*AttachCurrentThreadAsDaemon)(JavaVM*, JNIEnv**, void*); +}; + +/* + * C++ version. + */ +struct _JavaVM { + const struct JNIInvokeInterface* functions; + +#if defined(__cplusplus) + jint DestroyJavaVM() + { return functions->DestroyJavaVM(this); } + jint AttachCurrentThread(JNIEnv** p_env, void* thr_args) + { return functions->AttachCurrentThread(this, p_env, thr_args); } + jint DetachCurrentThread() + { return functions->DetachCurrentThread(this); } + jint GetEnv(void** env, jint version) + { return functions->GetEnv(this, env, version); } + jint AttachCurrentThreadAsDaemon(JNIEnv** p_env, void* thr_args) + { return functions->AttachCurrentThreadAsDaemon(this, p_env, thr_args); } +#endif /*__cplusplus*/ +}; + +struct JavaVMAttachArgs { + jint version; /* must be >= JNI_VERSION_1_2 */ + const char* name; /* NULL or name of thread as modified UTF-8 str */ + jobject group; /* global ref of a ThreadGroup object, or NULL */ +}; +typedef struct JavaVMAttachArgs JavaVMAttachArgs; + +/* + * JNI 1.2+ initialization. (As of 1.6, the pre-1.2 structures are no + * longer supported.) + */ +typedef struct JavaVMOption { + const char* optionString; + void* extraInfo; +} JavaVMOption; + +typedef struct JavaVMInitArgs { + jint version; /* use JNI_VERSION_1_2 or later */ + jint nOptions; + JavaVMOption* options; + jboolean ignoreUnrecognized; +} JavaVMInitArgs; + +#ifdef __cplusplus +extern "C" { +#endif +/* + * VM initialization functions. + * + * Note these are the only symbols exported for JNI by the VM. + */ +jint JNI_GetDefaultJavaVMInitArgs(void*); +jint JNI_CreateJavaVM(JavaVM**, JNIEnv**, void*); +jint JNI_GetCreatedJavaVMs(JavaVM**, jsize, jsize*); + +#define JNIIMPORT +// To match the JNIEXPORT on Windows +#define JNIEXPORT __declspec(dllexport) +#define JNICALL + +/* + * Prototypes for functions exported by loadable shared libs. These are + * called by JNI, not provided by JNI. + */ +JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void* reserved); +JNIEXPORT void JNI_OnUnload(JavaVM* vm, void* reserved); + +#ifdef __cplusplus +} +#endif + +/* + * Manifest constants. + */ +#define JNI_FALSE 0 +#define JNI_TRUE 1 + +#define JNI_VERSION_1_1 0x00010001 +#define JNI_VERSION_1_2 0x00010002 +#define JNI_VERSION_1_4 0x00010004 +#define JNI_VERSION_1_6 0x00010006 + +#define JNI_OK (0) /* no error */ +#define JNI_ERR (-1) /* generic error */ +#define JNI_EDETACHED (-2) /* thread detached from the VM */ +#define JNI_EVERSION (-3) /* JNI version error */ + +#define JNI_COMMIT 1 /* copy content, do not free buffer */ +#define JNI_ABORT 2 /* free buffer w/o copying back */ + +#endif /* JNI_H_ */ + diff --git a/.clang-tidy b/.clang-tidy index a540d67a130eb..38cc67747c326 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,6 +1,7 @@ --- # NOTE there must be no spaces before the '-', so put the comma last. -Checks: '-*, +InheritParentConfig: true +Checks: ' bugprone-*, -bugprone-forward-declaration-namespace, -bugprone-macro-parentheses, @@ -17,9 +18,11 @@ cppcoreguidelines-*, -cppcoreguidelines-pro-type-union-access, -cppcoreguidelines-pro-type-vararg, -cppcoreguidelines-special-member-functions, +-facebook-hte-RelativeInclude, hicpp-exception-baseclass, hicpp-avoid-goto, modernize-*, +-modernize-concat-nested-namespaces, -modernize-return-braced-init-list, -modernize-use-auto, -modernize-use-default-member-init, @@ -27,7 +30,7 @@ modernize-*, -modernize-use-trailing-return-type, performance-*, -performance-noexcept-move-constructor, - ' +' HeaderFilterRegex: 'torch/csrc/.*' AnalyzeTemporaryDtors: false CheckOptions: diff --git a/.flake8 b/.flake8 index 7ecc6df31754c..0b058328f55e1 100644 --- a/.flake8 +++ b/.flake8 @@ -12,5 +12,22 @@ ignore = B007,B008, # these ignores are from flake8-comprehensions; please fix! C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415 -per-file-ignores = __init__.py: F401 -exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,torch/lib/include,torch/lib/tmp_install,build,torch/include,*.pyi,.git,build,build_test_custom_build,build_code_analyzer +per-file-ignores = __init__.py: F401 torch/utils/cpp_extension.py: B950 +exclude = + docs/src, + docs/cpp/src, + venv, + third_party, + caffe2, + scripts, + docs/caffe2, + torch/lib/include, + torch/lib/tmp_install, + build, + torch/include, + *.pyi, + .git, + build, + build_test_custom_build, + build_code_analyzer, + test/generated_type_hints_smoketest.py diff --git a/.github/pytorch-circleci-labels.yml b/.github/pytorch-circleci-labels.yml index ccdf2e876af10..3a9eeca0abcce 100644 --- a/.github/pytorch-circleci-labels.yml +++ b/.github/pytorch-circleci-labels.yml @@ -9,3 +9,5 @@ labels_to_circle_params: - release/.* tags: - v[0-9]+(\.[0-9]+)*-rc[0-9]+ + set_to_false: + - run_build diff --git a/.github/workflows/jit_triage.yml b/.github/workflows/jit_triage.yml index af59d2160ec67..1fb967e8ffb8d 100644 --- a/.github/workflows/jit_triage.yml +++ b/.github/workflows/jit_triage.yml @@ -19,7 +19,7 @@ jobs: // - io: A reference to the @actions/io package // Check if issue has a JIT label. - const kJitLabel = "jit"; + const kJitLabel = "oncall: jit"; issue = await github.issues.get({ owner: context.issue.owner, diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index b1b539788ba4c..54acbe7b1c6a1 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -17,13 +17,28 @@ jobs: architecture: x64 - name: Checkout PyTorch uses: actions/checkout@v1 + - name: Checkout PR tip + run: | + set -eux + if [[ "${{ github.event_name }}" == "pull_request" ]]; then + # We are on a PR, so actions/checkout leaves us on a merge commit. + # Check out the actual tip of the branch. + git checkout ${{ github.event.pull_request.head.sha }} + fi + echo ::set-output name=commit_sha::$(git rev-parse HEAD) + id: get_pr_tip - name: Ensure consistent CircleCI YAML config run: | pip install -r requirements.txt cd .circleci && ./ensure-consistency.py - name: Shellcheck Jenkins scripts + # https://github.com/koalaman/shellcheck#installing-a-pre-compiled-binary run: | - sudo apt-get install -y shellcheck + scversion="stable" + wget -qO- "https://github.com/koalaman/shellcheck/releases/download/${scversion?}/shellcheck-${scversion?}.linux.x86_64.tar.xz" | tar -xJv + sudo cp "shellcheck-${scversion}/shellcheck" /usr/bin/ + rm -r "shellcheck-${scversion}" + shellcheck --version .jenkins/run-shellcheck.sh - name: Ensure no tabs run: | @@ -31,6 +46,10 @@ jobs: - name: Ensure canonical include run: | (! git grep -I -l $'#include "' -- ./c10 ./aten ./torch/csrc ':(exclude)aten/src/ATen/native/quantized/cpu/qnnpack/**' || (echo "The above files have include with quotes; please convert them to #include "; false)) + # note that this next step depends on a clean shallow checkout; + # if you run it locally in a deep checkout then it will complain + # about android/libs/fbjni/gradlew (in a submodule), + # as well as all the generated files in torch/test - name: Ensure C++ source files are not executable run: | (! find . \( -path ./third_party -o -path ./.git -o -path ./torch/bin -o -path ./build \) -prune -o -type f -executable -regextype posix-egrep -not -regex '.+(\.(bash|sh|py|so)|git-pre-commit|git-clang-format)$' -print | grep . || (echo 'The above files have executable permission; please remove their executable permission by using `chmod -x`'; false)) @@ -38,6 +57,10 @@ jobs: run: | sudo apt-get install -y doxygen && pip install -r requirements.txt cd docs/cpp/source && ./check-doxygen.sh + - name: CUDA kernel launch check + run: | + set -eux + python torch/testing/check_kernel_launches.py |& tee ${GITHUB_WORKSPACE}/cuda_kernel_launch_checks.txt flake8-py3: runs-on: ubuntu-latest @@ -62,10 +85,9 @@ jobs: - name: Run flake8 run: | set -eux - pip install flake8==3.8.2 flake8-mypy flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0 + pip install -r requirements-flake8.txt flake8 --version - flake8 --exit-zero > ${GITHUB_WORKSPACE}/flake8-output.txt - cat ${GITHUB_WORKSPACE}/flake8-output.txt + flake8 | tee ${GITHUB_WORKSPACE}/flake8-output.txt - name: Add annotations uses: pytorch/add-annotations-github-action@master with: @@ -110,10 +132,10 @@ jobs: # Install dependencies pip install pyyaml wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - - sudo apt-add-repository "deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-8 main" + sudo apt-add-repository "deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-11 main" sudo apt-get update - sudo apt-get install -y clang-tidy-8 - sudo update-alternatives --install /usr/bin/clang-tidy clang-tidy /usr/bin/clang-tidy-8 1000 + sudo apt-get install -y clang-tidy-11 + sudo update-alternatives --install /usr/bin/clang-tidy clang-tidy /usr/bin/clang-tidy-11 1000 - name: Run clang-tidy run: | set -eux @@ -138,6 +160,7 @@ jobs: # Generate PyTorch files. time python tools/setup_helpers/generate_code.py \ --declarations-path build/aten/src/ATen/Declarations.yaml \ + --native-functions-path aten/src/ATen/native/native_functions.yaml \ --nn-path aten/src fi @@ -146,6 +169,7 @@ jobs: # caffe2_pb.h, otherwise we'd have to build protos as part of this CI job. # FunctionsManual.cpp is excluded to keep this diff clean. It will be fixed # in a follow up PR. + # /torch/csrc/generic/*.cpp is excluded because those files aren't actually built. python tools/clang_tidy.py \ --verbose \ --paths torch/csrc/ \ @@ -160,6 +184,8 @@ jobs: -g"-torch/csrc/cuda/nccl.*" \ -g"-torch/csrc/cuda/python_nccl.cpp" \ -g"-torch/csrc/autograd/FunctionsManual.cpp" \ + -g"-torch/csrc/generic/*.cpp" \ + -g"-torch/csrc/jit/codegen/cuda/runtime/*" \ "$@" > ${GITHUB_WORKSPACE}/clang-tidy-output.txt cat ${GITHUB_WORKSPACE}/clang-tidy-output.txt diff --git a/.github/workflows/quantization_triage.yml b/.github/workflows/quantization_triage.yml new file mode 100644 index 0000000000000..ac337a066873a --- /dev/null +++ b/.github/workflows/quantization_triage.yml @@ -0,0 +1,78 @@ +name: quantization-triage + +on: + issues: + types: [labeled] + +jobs: + welcome: + runs-on: ubuntu-latest + steps: + - uses: actions/github-script@v2 + with: + github-token: ${{secrets.GITHUB_TOKEN}} + script: | + // Arguments available: + // - github: A pre-authenticated octokit/rest.js client + // - context: An object containing the context of the workflow run + // - core: A reference to the @actions/core package + // - io: A reference to the @actions/io package + + // Check if issue has a Quantization label. + const kQuantizationLabel = "oncall: quantization"; + + issue = await github.issues.get({ + owner: context.issue.owner, + repo: context.issue.repo, + issue_number: context.issue.number, + }) + + const hasQuantizationLabel = issue.data.labels.filter(label => label.name == kQuantizationLabel).length > 0; + + if (!hasQuantizationLabel) { + core.debug("Issue " + issue.data.title + " does not have Quantization label"); + return; + } + + // Get project column ID. + const kProjectName = "Quantization Triage"; + const kColumnName = "Need Triage"; + + // Query all projects in the repository. + // TODO: Support pagination once there are > 30 projects. + const projects = await github.projects.listForRepo({ + owner: context.issue.owner, + repo: context.issue.repo, + }); + + // Filter out unwanted projects and get the ID for the Quantization Triage project. + const filteredProjects = projects.data.filter(project => project.name == kProjectName); + + if (filteredProjects.length != 1) { + core.setFailed("Unable to find a project named " + kProjectName); + return; + } + + const projectId = filteredProjects[0].id; + // First, query all columns in the project. + // TODO: Support pagination once there are > 30 columns. + const columns = await github.projects.listColumns({ + project_id: projectId, + }); + + // Filter out unwanted projects and get the ID for the Need triage column. + const filteredColumns = columns.data.filter(column => column.name == kColumnName); + + if (filteredColumns.length != 1) { + core.setFailed("Unable to find a column named " + kColumnName); + return; + } + + const columnId = filteredColumns[0].id; + + // Create a project card for this new issue. + await github.projects.createCard({ + column_id: columnId, + content_id: issue.data.id, + content_type: "Issue", + }) diff --git a/.github/workflows/update_s3_htmls.yml b/.github/workflows/update_s3_htmls.yml new file mode 100644 index 0000000000000..f2320ce2fcbf5 --- /dev/null +++ b/.github/workflows/update_s3_htmls.yml @@ -0,0 +1,23 @@ +name: Update S3 HTML indices for download.pytorch.org +on: + schedule: + # Update the indices every 30 minutes + - cron: "*/30 * * * *" + # Have the ability to trigger this job manually using the API as well + workflow_dispatch: + +jobs: + update-html: + runs-on: ubuntu-latest + if: ${{ github.repository_owner == 'pytorch' }} + strategy: + matrix: + prefix: ["whl", "whl/test", "whl/nightly"] + steps: + - name: Run updater image + env: + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_UPDATE_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_UPDATE_SECRET_ACCESS_KEY }} + uses: docker://pytorch/manage_s3_html + with: + args: ${{ matrix.prefix }} diff --git a/.gitignore b/.gitignore index e908b405a6629..0a16dd79f2215 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ .coverage coverage.xml +.dmypy.json .gradle .hypothesis .mypy_cache @@ -33,6 +34,8 @@ docs/cpp/src docs/src/**/* docs/cpp/build docs/cpp/source/api +docs/cpp/source/html/ +docs/cpp/source/latex/ docs/source/generated/ log test/.coverage @@ -74,6 +77,7 @@ torch/lib/*.exe* torch/lib/*.dylib* torch/lib/*.h torch/lib/*.lib +torch/lib/*.pdb torch/lib/*.so* torch/lib/protobuf*.pc torch/lib/build @@ -91,6 +95,8 @@ torch/lib64 torch/include/ torch/share/ torch/test/ +torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h +torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h torch/version.py # Root level file used in CI to specify certain env configs. # E.g., see .circleci/config.yaml @@ -188,6 +194,7 @@ build_ios /build_* .build_debug/* .build_release/* +.build_profile/* distribute/* *.testbin *.bin diff --git a/.gitmodules b/.gitmodules index 509ab94f1cf44..cd2e81e07a554 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,7 @@ [submodule "third_party/pybind11"] ignore = dirty path = third_party/pybind11 - url = https://github.com/pybind/pybind11.git + url = https://github.com/seemethere/pybind11.git [submodule "third_party/cub"] ignore = dirty path = third_party/cub @@ -124,9 +124,12 @@ url = https://github.com/google/XNNPACK.git [submodule "third_party/fmt"] ignore = dirty - path = third_party/fmt - url = https://github.com/fmtlib/fmt.git + path = third_party/fmt + url = https://github.com/fmtlib/fmt.git [submodule "third_party/tensorpipe"] ignore = dirty path = third_party/tensorpipe url = https://github.com/pytorch/tensorpipe.git +[submodule "third_party/kineto"] + path = third_party/kineto + url = https://github.com/pytorch/kineto diff --git a/.jenkins/caffe2/build.sh b/.jenkins/caffe2/build.sh index bba8aa0e03655..0a4d1166bd059 100755 --- a/.jenkins/caffe2/build.sh +++ b/.jenkins/caffe2/build.sh @@ -18,49 +18,6 @@ build_to_cmake () { SCCACHE="$(which sccache)" -if [ "$(which gcc)" != "/root/sccache/gcc" ]; then - # Setup SCCACHE - ############################################################################### - # Setup sccache if SCCACHE_BUCKET is set - if [ -n "${SCCACHE_BUCKET}" ]; then - mkdir -p ./sccache - - SCCACHE="$(which sccache)" - if [ -z "${SCCACHE}" ]; then - echo "Unable to find sccache..." - exit 1 - fi - - # Setup wrapper scripts - wrapped="cc c++ gcc g++ x86_64-linux-gnu-gcc" - if [[ "${BUILD_ENVIRONMENT}" == *-cuda* ]]; then - wrapped="$wrapped nvcc" - fi - for compiler in $wrapped; do - ( - echo "#!/bin/sh" - - # TODO: if/when sccache gains native support for an - # SCCACHE_DISABLE flag analogous to ccache's CCACHE_DISABLE, - # this can be removed. Alternatively, this can be removed when - # https://github.com/pytorch/pytorch/issues/13362 is fixed. - # - # NOTE: carefully quoted - we want `which compiler` to be - # resolved as we execute the script, but SCCACHE_DISABLE and - # $@ to be evaluated when we execute the script - echo 'test $SCCACHE_DISABLE && exec '"$(which $compiler)"' "$@"' - - echo "exec $SCCACHE $(which $compiler) \"\$@\"" - ) > "./sccache/$compiler" - chmod +x "./sccache/$compiler" - done - - export CACHE_WRAPPER_DIR="$PWD/sccache" - - # CMake must find these wrapper scripts - export PATH="$CACHE_WRAPPER_DIR:$PATH" - fi -fi # Setup ccache if configured to use it (and not sccache) if [ -z "${SCCACHE}" ] && which ccache > /dev/null; then @@ -161,6 +118,11 @@ if [[ $BUILD_ENVIRONMENT == *cuda* ]]; then export PATH="/usr/local/cuda/bin:$PATH" fi if [[ $BUILD_ENVIRONMENT == *rocm* ]]; then + if [[ -n "$IN_CI" ]]; then + # Set ROCM_ARCH to gfx900 and gfx906 for CI builds + echo "Limiting PYTORCH_ROCM_ARCH to gfx90[06] for CI builds" + export PYTORCH_ROCM_ARCH="gfx900;gfx906" + fi # This is needed to enable ImageInput operator in resnet50_trainer build_args+=("USE_OPENCV=ON") # This is needed to read datasets from https://download.caffe2.ai/databases/resnet_trainer.zip @@ -260,6 +222,21 @@ fi ############################################################################### # Install ONNX into a local directory -pip install --user -b /tmp/pip_install_onnx "file://${ROOT_DIR}/third_party/onnx#egg=onnx" +pip install --user "file://${ROOT_DIR}/third_party/onnx#egg=onnx" report_compile_cache_stats + +if [[ $BUILD_ENVIRONMENT == *rocm* ]]; then + # remove sccache wrappers post-build; runtime compilation of MIOpen kernels does not yet fully support them + sudo rm -f /opt/cache/bin/cc + sudo rm -f /opt/cache/bin/c++ + sudo rm -f /opt/cache/bin/gcc + sudo rm -f /opt/cache/bin/g++ + pushd /opt/rocm/llvm/bin + if [[ -d original ]]; then + sudo mv original/clang . + sudo mv original/clang++ . + fi + sudo rm -rf original + popd +fi diff --git a/.jenkins/caffe2/common.sh b/.jenkins/caffe2/common.sh index 1f2cdc2f06d67..026cb8349d3d9 100644 --- a/.jenkins/caffe2/common.sh +++ b/.jenkins/caffe2/common.sh @@ -2,9 +2,9 @@ set -ex LOCAL_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) ROOT_DIR=$(cd "$LOCAL_DIR"/../.. && pwd) -TEST_DIR="$ROOT_DIR/caffe2_tests" -gtest_reports_dir="${TEST_DIR}/cpp" -pytest_reports_dir="${TEST_DIR}/python" +TEST_DIR="$ROOT_DIR/test" +gtest_reports_dir="${TEST_DIR}/test-reports/cpp" +pytest_reports_dir="${TEST_DIR}/test-reports/python" # Figure out which Python to use PYTHON="$(which python)" @@ -13,6 +13,8 @@ if [[ "${BUILD_ENVIRONMENT}" =~ py((2|3)\.?[0-9]?\.?[0-9]?) ]]; then fi if [[ "${BUILD_ENVIRONMENT}" == *rocm* ]]; then + # HIP_PLATFORM is auto-detected by hipcc; unset to avoid build errors + unset HIP_PLATFORM if which sccache > /dev/null; then # Save sccache logs to file sccache --stop-server || true diff --git a/.jenkins/caffe2/test.sh b/.jenkins/caffe2/test.sh index 61fb7de08fe54..e6f43b6452cf8 100755 --- a/.jenkins/caffe2/test.sh +++ b/.jenkins/caffe2/test.sh @@ -88,19 +88,16 @@ if [[ "$PIP_USER" = root ]]; then MAYBE_SUDO=sudo fi -# if [[ "$BUILD_ENVIRONMENT" == *ubuntu14.04* ]]; then - # Hotfix, use hypothesis 3.44.6 on Ubuntu 14.04 - # See comments on - # https://github.com/HypothesisWorks/hypothesis-python/commit/eadd62e467d6cee6216e71b391951ec25b4f5830 - $MAYBE_SUDO pip -q uninstall -y hypothesis - # "pip install hypothesis==3.44.6" from official server is unreliable on - # CircleCI, so we host a copy on S3 instead - $MAYBE_SUDO pip -q install attrs==18.1.0 -f https://s3.amazonaws.com/ossci-linux/wheels/attrs-18.1.0-py2.py3-none-any.whl - $MAYBE_SUDO pip -q install coverage==4.5.1 -f https://s3.amazonaws.com/ossci-linux/wheels/coverage-4.5.1-cp36-cp36m-macosx_10_12_x86_64.whl - $MAYBE_SUDO pip -q install hypothesis==3.44.6 -f https://s3.amazonaws.com/ossci-linux/wheels/hypothesis-3.44.6-py3-none-any.whl -# else -# pip install --user --no-cache-dir hypothesis==3.59.0 -# fi +# Uninstall pre-installed hypothesis and coverage to use an older version as newer +# versions remove the timeout parameter from settings which ideep/conv_transpose_test.py uses +$MAYBE_SUDO pip -q uninstall -y hypothesis +$MAYBE_SUDO pip -q uninstall -y coverage + +# "pip install hypothesis==3.44.6" from official server is unreliable on +# CircleCI, so we host a copy on S3 instead +$MAYBE_SUDO pip -q install attrs==18.1.0 -f https://s3.amazonaws.com/ossci-linux/wheels/attrs-18.1.0-py2.py3-none-any.whl +$MAYBE_SUDO pip -q install coverage==4.5.1 -f https://s3.amazonaws.com/ossci-linux/wheels/coverage-4.5.1-cp36-cp36m-macosx_10_12_x86_64.whl +$MAYBE_SUDO pip -q install hypothesis==3.44.6 -f https://s3.amazonaws.com/ossci-linux/wheels/hypothesis-3.44.6-py3-none-any.whl # Collect additional tests to run (outside caffe2/python) EXTRA_TESTS=() @@ -163,15 +160,12 @@ pip install --user pytest-sugar if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then # Check out torch/vision at Jun 11 2020 commit # This hash must match one in .jenkins/pytorch/test.sh - pip install -q --user git+https://github.com/pytorch/vision.git@c2e8a00885e68ae1200eb6440f540e181d9125de + pip install -q --user git+https://github.com/pytorch/vision.git@e70c91a9ff9b8a20e05c133aec6ec3ed538c32fb pip install -q --user ninja # JIT C++ extensions require ninja, so put it into PATH. export PATH="/var/lib/jenkins/.local/bin:$PATH" if [[ "$BUILD_ENVIRONMENT" == *py3* ]]; then - # default pip version is too old(9.0.2), unable to support tag `manylinux2010`. - # Fix the pip error: Couldn't find a version that satisfies the requirement - pip install --upgrade pip - pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.5.0.dev202009182 + pip install -q --user onnxruntime==1.6.0 fi "$ROOT_DIR/scripts/onnx/test.sh" fi diff --git a/.jenkins/pytorch/.shellcheckrc b/.jenkins/pytorch/.shellcheckrc new file mode 100644 index 0000000000000..ff96b057e50a3 --- /dev/null +++ b/.jenkins/pytorch/.shellcheckrc @@ -0,0 +1,6 @@ +disable=SC2086 +disable=SC1091 +disable=SC2155 +disable=SC1090 +disable=SC2164 +disable=SC1003 diff --git a/.jenkins/pytorch/README.md b/.jenkins/pytorch/README.md index ea6c6dd40f68b..9fd68ecf7f153 100644 --- a/.jenkins/pytorch/README.md +++ b/.jenkins/pytorch/README.md @@ -10,9 +10,9 @@ it is very easy to run these tests yourself: ``registry.pytorch.org/pytorch/pytorch-$BUILD_ENVIRONMENT:$DOCKER_VERSION``, where ``$BUILD_ENVIRONMENT`` is one of the build environments enumerated in - [pytorch-dockerfiles](https://github.com/pietern/pytorch-dockerfiles/blob/master/build.sh) + [pytorch-dockerfiles](https://github.com/pytorch/pytorch/blob/master/.circleci/docker/build.sh). The dockerfile used by jenkins can be found under the `.circle` [directory](https://github.com/pytorch/pytorch/blob/master/.circleci/docker) -2. Run ``docker -it -u jenkins $DOCKER_IMAGE``, clone PyTorch and +2. Run ``docker run -it -u jenkins $DOCKER_IMAGE``, clone PyTorch and run one of the scripts in this directory. The Docker images are designed so that any "reasonable" build commands @@ -38,5 +38,5 @@ mechanisms we use: build scripts. - We reroute well known paths like `/usr/bin/gcc` to alternate - implementations with `update-alternatives, instead of setting + implementations with `update-alternatives`, instead of setting `CC` and `CXX` in our implementations. diff --git a/.jenkins/pytorch/build-mobile-code-analysis.sh b/.jenkins/pytorch/build-mobile-code-analysis.sh index 982ab257f84dc..0e6d4be88be37 100755 --- a/.jenkins/pytorch/build-mobile-code-analysis.sh +++ b/.jenkins/pytorch/build-mobile-code-analysis.sh @@ -5,6 +5,7 @@ set -eu -o pipefail # This script builds and runs code analyzer tool to generate aten op dependency # graph for custom mobile build. +# shellcheck disable=SC2034 COMPACT_JOB_NAME="${BUILD_ENVIRONMENT}" source "$(dirname "${BASH_SOURCE[0]}")/common.sh" diff --git a/.jenkins/pytorch/build-mobile.sh b/.jenkins/pytorch/build-mobile.sh index b1234f2728130..3ffec50741711 100755 --- a/.jenkins/pytorch/build-mobile.sh +++ b/.jenkins/pytorch/build-mobile.sh @@ -6,6 +6,7 @@ set -eu -o pipefail # build & test mobile libtorch without having to setup Android/iOS # toolchain/simulator. +# shellcheck disable=SC2034 COMPACT_JOB_NAME="${BUILD_ENVIRONMENT}" source "$(dirname "${BASH_SOURCE[0]}")/common.sh" diff --git a/.jenkins/pytorch/build.sh b/.jenkins/pytorch/build.sh index b4c6e923808a8..55b63d2144d04 100755 --- a/.jenkins/pytorch/build.sh +++ b/.jenkins/pytorch/build.sh @@ -1,5 +1,7 @@ #!/bin/bash +set -ex + # Required environment variable: $BUILD_ENVIRONMENT # (This is set by default in the Docker images we build, so you don't # need to set it yourself. @@ -7,37 +9,8 @@ # shellcheck disable=SC2034 COMPACT_JOB_NAME="${BUILD_ENVIRONMENT}" -# Temp: use new sccache -if [[ -n "$IN_CIRCLECI" && "$BUILD_ENVIRONMENT" == *rocm* ]]; then - # Download customized sccache - sudo curl --retry 3 http://repo.radeon.com/misc/.sccache_amd/sccache -o /opt/cache/bin/sccache - sudo chmod 755 /opt/cache/bin/sccache -fi - source "$(dirname "${BASH_SOURCE[0]}")/common.sh" -# For distributed, four environmental configs: -# (1) build with only NCCL -# (2) build with NCCL and MPI -# (3) build with only MPI -# (4) build with neither -if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda10.1-* ]]; then - # TODO: move this to Docker - sudo apt-get -qq update - sudo apt-get -qq install --allow-downgrades --allow-change-held-packages libnccl-dev=2.5.6-1+cuda10.1 libnccl2=2.5.6-1+cuda10.1 -fi - -if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda9*gcc7* ]] || [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda9*gcc5* ]] || [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda10.1-* ]] || [[ "$BUILD_ENVIRONMENT" == *-trusty-py2.7.9* ]]; then - # TODO: move this to Docker - sudo apt-get -qq update - if [[ "$BUILD_ENVIRONMENT" == *-trusty-py2.7.9* ]]; then - sudo apt-get -qq install openmpi-bin libopenmpi-dev - else - sudo apt-get -qq install --allow-downgrades --allow-change-held-packages openmpi-bin libopenmpi-dev - fi - sudo mkdir -p /var/run/sshd -fi - if [[ "$BUILD_ENVIRONMENT" == *-linux-xenial-py3-clang5-asan* ]]; then exec "$(dirname "${BASH_SOURCE[0]}")/build-asan.sh" "$@" fi @@ -64,6 +37,11 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then nvcc --version fi +if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then + # enable build option in CMake + export USE_CPP_CODE_COVERAGE=ON +fi + # TODO: Don't run this... pip_install -r requirements.txt || true @@ -89,8 +67,14 @@ if [[ "$BUILD_ENVIRONMENT" == *libtorch* ]]; then POSSIBLE_JAVA_HOMES+=(/usr/local) POSSIBLE_JAVA_HOMES+=(/usr/lib/jvm/java-8-openjdk-amd64) POSSIBLE_JAVA_HOMES+=(/Library/Java/JavaVirtualMachines/*.jdk/Contents/Home) + # Add the Windows-specific JNI + POSSIBLE_JAVA_HOMES+=("$PWD/.circleci/windows-jni/") for JH in "${POSSIBLE_JAVA_HOMES[@]}" ; do if [[ -e "$JH/include/jni.h" ]] ; then + # Skip if we're not on Windows but haven't found a JAVA_HOME + if [[ "$JH" == "$PWD/.circleci/windows-jni/" && "$OSTYPE" != "msys" ]] ; then + break + fi echo "Found jni.h under $JH" export JAVA_HOME="$JH" export BUILD_JNI=ON @@ -135,48 +119,35 @@ if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then export MAX_JOBS=$(($(nproc) - 1)) fi - # ROCm CI is using Caffe2 docker images, which needs these wrapper - # scripts to correctly use sccache. - if [[ -n "${SCCACHE_BUCKET}" && -z "$IN_CIRCLECI" ]]; then - mkdir -p ./sccache - - SCCACHE="$(which sccache)" - if [ -z "${SCCACHE}" ]; then - echo "Unable to find sccache..." - exit 1 - fi - - # Setup wrapper scripts - for compiler in cc c++ gcc g++ clang clang++; do - ( - echo "#!/bin/sh" - echo "exec $SCCACHE $(which $compiler) \"\$@\"" - ) > "./sccache/$compiler" - chmod +x "./sccache/$compiler" - done - - export CACHE_WRAPPER_DIR="$PWD/sccache" - - # CMake must find these wrapper scripts - export PATH="$CACHE_WRAPPER_DIR:$PATH" - fi - - if [[ -n "$IN_CIRCLECI" ]]; then - # Set ROCM_ARCH to gtx900 and gtx906 in CircleCI - echo "Limiting PYTORCH_ROCM_ARCH to gfx90[06] for CircleCI builds" + if [[ -n "$IN_CI" ]]; then + # Set ROCM_ARCH to gfx900 and gfx906 for CI builds + echo "Limiting PYTORCH_ROCM_ARCH to gfx90[06] for CI builds" export PYTORCH_ROCM_ARCH="gfx900;gfx906" fi python tools/amd_build/build_amd.py python setup.py install --user + # remove sccache wrappers post-build; runtime compilation of MIOpen kernels does not yet fully support them + sudo rm -f /opt/cache/bin/cc + sudo rm -f /opt/cache/bin/c++ + sudo rm -f /opt/cache/bin/gcc + sudo rm -f /opt/cache/bin/g++ + pushd /opt/rocm/llvm/bin + if [[ -d original ]]; then + sudo mv original/clang . + sudo mv original/clang++ . + fi + sudo rm -rf original + popd + exit 0 fi # sccache will fail for CUDA builds if all cores are used for compiling # gcc 7 with sccache seems to have intermittent OOM issue if all cores are used if [ -z "$MAX_JOBS" ]; then - if ([[ "$BUILD_ENVIRONMENT" == *cuda* ]] || [[ "$BUILD_ENVIRONMENT" == *gcc7* ]]) && which sccache > /dev/null; then + if { [[ "$BUILD_ENVIRONMENT" == *cuda* ]] || [[ "$BUILD_ENVIRONMENT" == *gcc7* ]]; } && which sccache > /dev/null; then export MAX_JOBS=$(($(nproc) - 1)) fi fi diff --git a/.jenkins/pytorch/codegen-test.sh b/.jenkins/pytorch/codegen-test.sh new file mode 100755 index 0000000000000..47d13f2908d04 --- /dev/null +++ b/.jenkins/pytorch/codegen-test.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash + +# This script can also be used to test whether your diff changes any codegen output. +# +# Run it before and after your change: +# .jenkins/pytorch/codegen-test.sh +# .jenkins/pytorch/codegen-test.sh +# +# Then run diff to compare the generated files: +# diff -Naur + +set -eu -o pipefail + +if [ "$#" -eq 0 ]; then + # shellcheck disable=SC2034 + COMPACT_JOB_NAME="${BUILD_ENVIRONMENT}" + source "$(dirname "${BASH_SOURCE[0]}")/common.sh" + OUT="$(dirname "${BASH_SOURCE[0]}")/../../codegen_result" +else + OUT=$1 +fi + +set -x + +rm -rf "$OUT" + +# aten codegen +python -m tools.codegen.gen \ + -d "$OUT"/torch/share/ATen + +# torch codegen +python -m tools.setup_helpers.generate_code \ + --declarations-path "$OUT"/torch/share/ATen/Declarations.yaml \ + --install_dir "$OUT" + +# pyi codegen +mkdir -p "$OUT"/pyi/torch/_C +mkdir -p "$OUT"/pyi/torch/nn +python -m tools.pyi.gen_pyi \ + --native-functions-path aten/src/ATen/native/native_functions.yaml \ + --deprecated-functions-path tools/autograd/deprecated.yaml \ + --out "$OUT"/pyi + +# autograd codegen (called by torch codegen but can run independently) +python -m tools.autograd.gen_autograd \ + "$OUT"/torch/share/ATen/Declarations.yaml \ + aten/src/ATen/native/native_functions.yaml \ + "$OUT"/autograd \ + tools/autograd + +# annotated_fn_args codegen (called by torch codegen but can run independently) +mkdir -p "$OUT"/annotated_fn_args +python -m tools.autograd.gen_annotated_fn_args \ + aten/src/ATen/native/native_functions.yaml \ + "$OUT"/annotated_fn_args \ + tools/autograd diff --git a/.jenkins/pytorch/common.sh b/.jenkins/pytorch/common.sh index 559009c15957b..7ae9470c5c7c0 100644 --- a/.jenkins/pytorch/common.sh +++ b/.jenkins/pytorch/common.sh @@ -12,11 +12,13 @@ SCRIPT_DIR="$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )" # Figure out which Python to use for ROCm if [[ "${BUILD_ENVIRONMENT}" == *rocm* ]] && [[ "${BUILD_ENVIRONMENT}" =~ py((2|3)\.?[0-9]?\.?[0-9]?) ]]; then + # HIP_PLATFORM is auto-detected by hipcc; unset to avoid build errors + unset HIP_PLATFORM PYTHON=$(which "python${BASH_REMATCH[1]}") # non-interactive bashs do not expand aliases by default shopt -s expand_aliases export PYTORCH_TEST_WITH_ROCM=1 - alias python="$PYTHON" + alias python='$PYTHON' # temporary to locate some kernel issues on the CI nodes export HSAKMT_DEBUG_LEVEL=4 fi @@ -43,7 +45,7 @@ fatal() { error "$@"; exit 1; } # - remaining args: names of traps to modify # trap_add() { - trap_add_cmd=$1; shift || fatal "${FUNCNAME} usage error" + trap_add_cmd=$1; shift || fatal "${FUNCNAME[0]} usage error" for trap_add_name in "$@"; do trap -- "$( # helper fn to get existing trap command from output @@ -114,6 +116,7 @@ if [[ "$BUILD_ENVIRONMENT" == *pytorch-linux-xenial-cuda10.1-cudnn7-py3* ]] || \ [[ "$BUILD_ENVIRONMENT" == *pytorch_macos* ]]; then BUILD_TEST_LIBTORCH=1 else + # shellcheck disable=SC2034 BUILD_TEST_LIBTORCH=0 fi @@ -126,6 +129,7 @@ fi if [[ "$BUILD_ENVIRONMENT" == *pytorch-xla-linux-bionic* ]] || \ [[ "$BUILD_ENVIRONMENT" == *pytorch-linux-xenial-cuda9-cudnn7-py2* ]] || \ [[ "$BUILD_ENVIRONMENT" == *pytorch-linux-xenial-cuda10.1-cudnn7-py3* ]] || \ + [[ "$BUILD_ENVIRONMENT" == *pytorch-*centos* ]] || \ [[ "$BUILD_ENVIRONMENT" == *pytorch-linux-bionic* ]]; then if ! which conda; then echo "Expected ${BUILD_ENVIRONMENT} to use conda, but 'which conda' returns empty" @@ -133,9 +137,12 @@ if [[ "$BUILD_ENVIRONMENT" == *pytorch-xla-linux-bionic* ]] || \ else conda install -q -y cmake fi + if [[ "$BUILD_ENVIRONMENT" == *pytorch-*centos* ]]; then + # cmake3 package will conflict with conda cmake + sudo yum -y remove cmake3 || true + fi fi retry () { - $* || (sleep 1 && $*) || (sleep 2 && $*) + "$@" || (sleep 1 && "$@") || (sleep 2 && "$@") } - diff --git a/.jenkins/pytorch/common_utils.sh b/.jenkins/pytorch/common_utils.sh index 682dd29b4cff9..b28dcb2f41d8d 100644 --- a/.jenkins/pytorch/common_utils.sh +++ b/.jenkins/pytorch/common_utils.sh @@ -18,7 +18,7 @@ function cleanup { function assert_git_not_dirty() { # TODO: we should add an option to `build_amd.py` that reverts the repo to # an unmodified state. - if ([[ "$BUILD_ENVIRONMENT" != *rocm* ]] && [[ "$BUILD_ENVIRONMENT" != *xla* ]]) ; then + if [[ "$BUILD_ENVIRONMENT" != *rocm* ]] && [[ "$BUILD_ENVIRONMENT" != *xla* ]] ; then git_status=$(git status --porcelain) if [[ $git_status ]]; then echo "Build left local git repository checkout dirty" @@ -66,7 +66,7 @@ function get_bazel() { chmod +x tools/bazel } -TORCHVISION_COMMIT=c2e8a00885e68ae1200eb6440f540e181d9125de +TORCHVISION_COMMIT=e70c91a9ff9b8a20e05c133aec6ec3ed538c32fb function install_torchvision() { # Check out torch/vision at Jun 11 2020 commit diff --git a/.jenkins/pytorch/macos-build.sh b/.jenkins/pytorch/macos-build.sh index a27278c51ee57..25bf368e86ef1 100755 --- a/.jenkins/pytorch/macos-build.sh +++ b/.jenkins/pytorch/macos-build.sh @@ -8,48 +8,26 @@ git submodule update --init --recursive export CMAKE_PREFIX_PATH=${WORKSPACE_DIR}/miniconda3/ # Build PyTorch -if [[ "${BUILD_ENVIRONMENT}" == *cuda9.2* ]]; then - export CUDA_VERSION=9.2 - export TORCH_CUDA_ARCH_LIST=5.2 - export PATH=/Developer/NVIDIA/CUDA-${CUDA_VERSION}/bin${PATH:+:${PATH}} - export DYLD_LIBRARY_PATH=/Developer/NVIDIA/CUDA-${CUDA_VERSION}/lib${DYLD_LIBRARY_PATH:+:${DYLD_LIBRARY_PATH}} - export CUDA_HOME=/Developer/NVIDIA/CUDA-${CUDA_VERSION} - export USE_CUDA=1 - - if [ -z "${IN_CIRCLECI}" ]; then - # Eigen gives "explicit specialization of class must precede its first use" error - # when compiling with Xcode 9.1 toolchain, so we have to use Xcode 8.2 toolchain instead. - export DEVELOPER_DIR=/Library/Developer/CommandLineTools - fi -else - if [ -z "${IN_CIRCLECI}" ]; then - export DEVELOPER_DIR=/Applications/Xcode9.app/Contents/Developer - fi +if [ -z "${IN_CI}" ]; then + export DEVELOPER_DIR=/Applications/Xcode9.app/Contents/Developer fi if which sccache > /dev/null; then - printf "#!/bin/sh\nexec sccache $(which clang++) \$*" > "${WORKSPACE_DIR}/clang++" + printf "#!/bin/sh\nexec sccache %s \$*" "$(which clang++)" > "${WORKSPACE_DIR}/clang++" chmod a+x "${WORKSPACE_DIR}/clang++" - printf "#!/bin/sh\nexec sccache $(which clang) \$*" > "${WORKSPACE_DIR}/clang" + printf "#!/bin/sh\nexec sccache %s \$*" "$(which clang)" > "${WORKSPACE_DIR}/clang" chmod a+x "${WORKSPACE_DIR}/clang" - if [[ "${BUILD_ENVIRONMENT}" == *cuda* ]]; then - printf "#!/bin/sh\nexec sccache $(which nvcc) \$*" > "${WORKSPACE_DIR}/nvcc" - chmod a+x "${WORKSPACE_DIR}/nvcc" - export CUDA_NVCC_EXECUTABLE="${WORKSPACE_DIR}/nvcc" - fi - export PATH="${WORKSPACE_DIR}:$PATH" fi -# If we run too many parallel jobs, we will OOM -MAX_JOBS=2 USE_DISTRIBUTED=1 python setup.py install +USE_DISTRIBUTED=1 python setup.py install assert_git_not_dirty # Upload torch binaries when the build job is finished -if [ -z "${IN_CIRCLECI}" ]; then +if [ -z "${IN_CI}" ]; then 7z a ${IMAGE_COMMIT_TAG}.7z ${WORKSPACE_DIR}/miniconda3/lib/python3.6/site-packages/torch* aws s3 cp ${IMAGE_COMMIT_TAG}.7z s3://ossci-macos-build/pytorch/${IMAGE_COMMIT_TAG}.7z --acl public-read fi diff --git a/.jenkins/pytorch/macos-test.sh b/.jenkins/pytorch/macos-test.sh index 213750ba7280d..24ec02c76df53 100755 --- a/.jenkins/pytorch/macos-test.sh +++ b/.jenkins/pytorch/macos-test.sh @@ -7,14 +7,9 @@ conda install -y six pip install -q hypothesis "librosa>=0.6.2" "numba<=0.49.1" psutil # TODO move this to docker -pip install unittest-xml-reporting +pip install unittest-xml-reporting pytest -# faulthandler become built-in since 3.3 -if [[ ! $(python -c "import sys; print(int(sys.version_info >= (3, 3)))") == "1" ]]; then - pip install -q faulthandler -fi - -if [ -z "${IN_CIRCLECI}" ]; then +if [ -z "${IN_CI}" ]; then rm -rf ${WORKSPACE_DIR}/miniconda3/lib/python3.6/site-packages/torch* fi @@ -23,7 +18,7 @@ git submodule update --init --recursive export CMAKE_PREFIX_PATH=${WORKSPACE_DIR}/miniconda3/ # Test PyTorch -if [ -z "${IN_CIRCLECI}" ]; then +if [ -z "${IN_CI}" ]; then if [[ "${BUILD_ENVIRONMENT}" == *cuda9.2* ]]; then # Eigen gives "explicit specialization of class must precede its first use" error # when compiling with Xcode 9.1 toolchain, so we have to use Xcode 8.2 toolchain instead. @@ -34,7 +29,7 @@ if [ -z "${IN_CIRCLECI}" ]; then fi # Download torch binaries in the test jobs -if [ -z "${IN_CIRCLECI}" ]; then +if [ -z "${IN_CI}" ]; then rm -rf ${WORKSPACE_DIR}/miniconda3/lib/python3.6/site-packages/torch* aws s3 cp s3://ossci-macos-build/pytorch/${IMAGE_COMMIT_TAG}.7z ${IMAGE_COMMIT_TAG}.7z 7z x ${IMAGE_COMMIT_TAG}.7z -o"${WORKSPACE_DIR}/miniconda3/lib/python3.6/site-packages" @@ -63,7 +58,7 @@ test_python_all() { # Increase default limit on open file handles from 256 to 1024 ulimit -n 1024 - python test/run_test.py --verbose --exclude test_jit_cuda_fuser_profiling test_jit_cuda_fuser_legacy test_jit_legacy test_jit_fuser_legacy --determine-from="$DETERMINE_FROM" + python test/run_test.py --verbose --exclude-jit-executor --determine-from="$DETERMINE_FROM" assert_git_not_dirty } diff --git a/.jenkins/pytorch/multigpu-test.sh b/.jenkins/pytorch/multigpu-test.sh index efc58f0daed62..fdf3c03e7f679 100755 --- a/.jenkins/pytorch/multigpu-test.sh +++ b/.jenkins/pytorch/multigpu-test.sh @@ -10,26 +10,16 @@ COMPACT_JOB_NAME="${BUILD_ENVIRONMENT}" source "$(dirname "${BASH_SOURCE[0]}")/common.sh" echo "Testing pytorch (distributed only)" -if [ -n "${IN_CIRCLECI}" ]; then +if [ -n "${IN_CI}" ]; then # TODO move this to docker pip_install unittest-xml-reporting - - if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda10.1-* ]]; then - # TODO: move this to Docker - sudo apt-get update - sudo apt-get install -y --allow-downgrades --allow-change-held-packages libnccl-dev=2.5.6-1+cuda10.1 libnccl2=2.5.6-1+cuda10.1 - fi - - if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda10.1-cudnn7-py3* ]]; then - # TODO: move this to Docker - sudo apt-get update - sudo apt-get install -y --allow-downgrades --allow-change-held-packages openmpi-bin libopenmpi-dev - fi fi python tools/download_mnist.py --quiet -d test/cpp/api/mnist OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="test/cpp/api/mnist" build/bin/test_api +time python test/run_test.py --verbose -i distributed/test_jit_c10d time python test/run_test.py --verbose -i distributed/test_distributed_fork time python test/run_test.py --verbose -i distributed/test_c10d time python test/run_test.py --verbose -i distributed/test_c10d_spawn +time python test/run_test.py --verbose -i distributed/rpc/test_tensorpipe_agent assert_git_not_dirty diff --git a/.jenkins/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh b/.jenkins/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh index 795251fc86257..4f86eb88fe0cf 100644 --- a/.jenkins/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh +++ b/.jenkins/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh @@ -21,7 +21,7 @@ test_cpu_speed_mini_sequence_labeler () { for (( i=1; i<=NUM_RUNS; i++ )) do runtime=$(get_runtime_of_command python main.py) - SAMPLE_ARRAY+=(${runtime}) + SAMPLE_ARRAY+=("${runtime}") done cd ../../.. diff --git a/.jenkins/pytorch/perf_test/test_cpu_speed_mnist.sh b/.jenkins/pytorch/perf_test/test_cpu_speed_mnist.sh index 29086fbc9976e..e284bb3aa6cca 100644 --- a/.jenkins/pytorch/perf_test/test_cpu_speed_mnist.sh +++ b/.jenkins/pytorch/perf_test/test_cpu_speed_mnist.sh @@ -23,7 +23,7 @@ test_cpu_speed_mnist () { for (( i=1; i<=NUM_RUNS; i++ )) do runtime=$(get_runtime_of_command python main.py --epochs 1 --no-log) echo $runtime - SAMPLE_ARRAY+=(${runtime}) + SAMPLE_ARRAY+=("${runtime}") done cd ../.. diff --git a/.jenkins/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh b/.jenkins/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh index 667cfba617fca..25109fdb8428f 100644 --- a/.jenkins/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh +++ b/.jenkins/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh @@ -22,7 +22,7 @@ test_gpu_speed_cudnn_lstm () { for (( i=1; i<=NUM_RUNS; i++ )) do runtime=$(get_runtime_of_command python cudnn_lstm.py --skip-cpu-governor-check) echo $runtime - SAMPLE_ARRAY+=(${runtime}) + SAMPLE_ARRAY+=("${runtime}") done cd ../.. diff --git a/.jenkins/pytorch/perf_test/test_gpu_speed_lstm.sh b/.jenkins/pytorch/perf_test/test_gpu_speed_lstm.sh index ea220b33ac7cc..e0f629cde86ed 100644 --- a/.jenkins/pytorch/perf_test/test_gpu_speed_lstm.sh +++ b/.jenkins/pytorch/perf_test/test_gpu_speed_lstm.sh @@ -22,7 +22,7 @@ test_gpu_speed_lstm () { for (( i=1; i<=NUM_RUNS; i++ )) do runtime=$(get_runtime_of_command python lstm.py --skip-cpu-governor-check) echo $runtime - SAMPLE_ARRAY+=(${runtime}) + SAMPLE_ARRAY+=("${runtime}") done cd ../.. diff --git a/.jenkins/pytorch/perf_test/test_gpu_speed_mlstm.sh b/.jenkins/pytorch/perf_test/test_gpu_speed_mlstm.sh index 62b94a7b21d18..46bb2b5ba2e3d 100644 --- a/.jenkins/pytorch/perf_test/test_gpu_speed_mlstm.sh +++ b/.jenkins/pytorch/perf_test/test_gpu_speed_mlstm.sh @@ -22,7 +22,7 @@ test_gpu_speed_mlstm () { for (( i=1; i<=NUM_RUNS; i++ )) do runtime=$(get_runtime_of_command python mlstm.py --skip-cpu-governor-check) echo $runtime - SAMPLE_ARRAY+=(${runtime}) + SAMPLE_ARRAY+=("${runtime}") done cd ../.. diff --git a/.jenkins/pytorch/perf_test/test_gpu_speed_mnist.sh b/.jenkins/pytorch/perf_test/test_gpu_speed_mnist.sh index 2453f3d70e708..2868cfca30c39 100644 --- a/.jenkins/pytorch/perf_test/test_gpu_speed_mnist.sh +++ b/.jenkins/pytorch/perf_test/test_gpu_speed_mnist.sh @@ -26,7 +26,7 @@ test_gpu_speed_mnist () { for (( i=1; i<=NUM_RUNS; i++ )) do runtime=$(get_runtime_of_command python main.py --epochs 1 --no-log) echo $runtime - SAMPLE_ARRAY+=(${runtime}) + SAMPLE_ARRAY+=("${runtime}") done cd ../.. diff --git a/.jenkins/pytorch/perf_test/test_gpu_speed_word_language_model.sh b/.jenkins/pytorch/perf_test/test_gpu_speed_word_language_model.sh index b1dea09c7c347..d0ae3160a22b5 100644 --- a/.jenkins/pytorch/perf_test/test_gpu_speed_word_language_model.sh +++ b/.jenkins/pytorch/perf_test/test_gpu_speed_word_language_model.sh @@ -31,7 +31,7 @@ test_gpu_speed_word_language_model () { for (( i=1; i<=NUM_RUNS; i++ )) do runtime=$(get_runtime_of_command python main.py --cuda --epochs 1) echo $runtime - SAMPLE_ARRAY+=(${runtime}) + SAMPLE_ARRAY+=("${runtime}") done cd ../.. diff --git a/.jenkins/pytorch/short-perf-test-cpu.sh b/.jenkins/pytorch/short-perf-test-cpu.sh index ae838276bd4d9..a77e6245e3554 100755 --- a/.jenkins/pytorch/short-perf-test-cpu.sh +++ b/.jenkins/pytorch/short-perf-test-cpu.sh @@ -27,13 +27,12 @@ fi git remote add upstream https://github.com/pytorch/pytorch.git git fetch upstream IFS=$'\n' -master_commit_ids=($(git rev-list upstream/master)) -for commit_id in "${master_commit_ids[@]}"; do +while IFS='' read -r commit_id; do if aws s3 ls s3://ossci-perf-test/pytorch/cpu_runtime/${commit_id}.json; then LATEST_TESTED_COMMIT=${commit_id} break fi -done +done < <(git rev-list upstream/master) aws s3 cp s3://ossci-perf-test/pytorch/cpu_runtime/${LATEST_TESTED_COMMIT}.json cpu_runtime.json if [[ "$COMMIT_SOURCE" == master ]]; then diff --git a/.jenkins/pytorch/short-perf-test-gpu.sh b/.jenkins/pytorch/short-perf-test-gpu.sh index 8fd701e197200..ec445409390b3 100755 --- a/.jenkins/pytorch/short-perf-test-gpu.sh +++ b/.jenkins/pytorch/short-perf-test-gpu.sh @@ -26,13 +26,12 @@ fi git remote add upstream https://github.com/pytorch/pytorch.git git fetch upstream IFS=$'\n' -master_commit_ids=($(git rev-list upstream/master)) -for commit_id in "${master_commit_ids[@]}"; do +while IFS='' read -r commit_id; do if aws s3 ls s3://ossci-perf-test/pytorch/gpu_runtime/${commit_id}.json; then LATEST_TESTED_COMMIT=${commit_id} break fi -done +done < <(git rev-list upstream/master) aws s3 cp s3://ossci-perf-test/pytorch/gpu_runtime/${LATEST_TESTED_COMMIT}.json gpu_runtime.json if [[ "$COMMIT_SOURCE" == master ]]; then diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 7e85039a72d1c..8e9afd5c9bc30 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -11,34 +11,20 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh" echo "Testing pytorch" -if [ -n "${IN_CIRCLECI}" ]; then - # TODO move this to docker - pip_install unittest-xml-reporting coverage - - if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda10.1-* ]]; then - # TODO: move this to Docker - sudo apt-get -qq update - sudo apt-get -qq install --allow-downgrades --allow-change-held-packages libnccl-dev=2.5.6-1+cuda10.1 libnccl2=2.5.6-1+cuda10.1 - fi +export LANG=C.UTF-8 - if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda10.1-cudnn7-py3* ]]; then - # TODO: move this to Docker - sudo apt-get -qq update - sudo apt-get -qq install --allow-downgrades --allow-change-held-packages openmpi-bin libopenmpi-dev - fi +if [[ "$BUILD_ENVIRONMENT" == *-slow-* ]]; then + export PYTORCH_TEST_WITH_SLOW=1 + export PYTORCH_TEST_SKIP_FAST=1 +fi - if [[ "$BUILD_ENVIRONMENT" == *-slow-* ]]; then - export PYTORCH_TEST_WITH_SLOW=1 - export PYTORCH_TEST_SKIP_FAST=1 - fi - if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then - export PYTORCH_COLLECT_COVERAGE=1 - fi +if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then + export PYTORCH_COLLECT_COVERAGE=1 fi if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then # Print GPU info - rocminfo | egrep 'Name:.*\sgfx|Marketing' + rocminfo | grep -E 'Name:.*\sgfx|Marketing' fi # --user breaks ppc64le builds and these packages are already in ppc64le docker @@ -48,18 +34,6 @@ if [[ "$BUILD_ENVIRONMENT" != *ppc64le* ]] && [[ "$BUILD_ENVIRONMENT" != *-bazel # ninja is installed in $HOME/.local/bin, e.g., /var/lib/jenkins/.local/bin for CI user jenkins # but this script should be runnable by any user, including root export PATH="$HOME/.local/bin:$PATH" - - # TODO: Please move this to Docker - # The version is fixed to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136 - pip_install --user "hypothesis==4.53.2" - # Pin MyPy version because new errors are likely to appear with each release - pip_install --user "mypy==0.770" - # Update scikit-learn to a python-3.8 compatible version - if [[ $(python -c "import sys; print(int(sys.version_info >= (3, 8)))") == "1" ]]; then - pip_install -U scikit-learn - fi - - pip_install --user tb-nightly fi # DANGER WILL ROBINSON. The LD_PRELOAD here could cause you problems @@ -121,28 +95,23 @@ elif [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX2-* ]]; then export ATEN_CPU_CAPABILITY=avx fi -if ([ -n "$CIRCLE_PULL_REQUEST" ] && [[ "$BUILD_ENVIRONMENT" != *coverage* ]]); then +if [ -n "$CIRCLE_PULL_REQUEST" ] && [[ "$BUILD_ENVIRONMENT" != *coverage* ]]; then DETERMINE_FROM=$(mktemp) file_diff_from_base "$DETERMINE_FROM" fi -test_python_nn() { - time python test/run_test.py --include test_nn --verbose --determine-from="$DETERMINE_FROM" - assert_git_not_dirty -} - -test_python_ge_config_profiling() { - time python test/run_test.py --include test_jit_cuda_fuser_profiling test_jit_profiling test_jit_fuser_te test_tensorexpr --verbose --determine-from="$DETERMINE_FROM" +test_python_legacy_jit() { + time python test/run_test.py --include test_jit_legacy test_jit_fuser_legacy --verbose --determine-from="$DETERMINE_FROM" assert_git_not_dirty } -test_python_ge_config_legacy() { - time python test/run_test.py --include test_jit_cuda_fuser_legacy test_jit_legacy test_jit_fuser_legacy --verbose --determine-from="$DETERMINE_FROM" +test_python_shard1() { + time python test/run_test.py --exclude-jit-executor --shard 1 2 --verbose --determine-from="$DETERMINE_FROM" assert_git_not_dirty } -test_python_all_except_nn_and_cpp_extensions() { - time python test/run_test.py --exclude test_jit_cuda_fuser_profiling test_jit_cuda_fuser_legacy test_nn test_jit_profiling test_jit_legacy test_jit_fuser_legacy test_jit_fuser_te test_tensorexpr --verbose --determine-from="$DETERMINE_FROM" +test_python_shard2() { + time python test/run_test.py --exclude-jit-executor --shard 2 2 --verbose --determine-from="$DETERMINE_FROM" assert_git_not_dirty } @@ -150,7 +119,7 @@ test_aten() { # Test ATen # The following test(s) of ATen have already been skipped by caffe2 in rocm environment: # scalar_tensor_test, basic, native_test - if ([[ "$BUILD_ENVIRONMENT" != *asan* ]] && [[ "$BUILD_ENVIRONMENT" != *rocm* ]]); then + if [[ "$BUILD_ENVIRONMENT" != *asan* ]] && [[ "$BUILD_ENVIRONMENT" != *rocm* ]]; then echo "Running ATen tests with pytorch lib" TORCH_LIB_PATH=$(python -c "import site; print(site.getsitepackages()[0])")/torch/lib # NB: the ATen test binaries don't have RPATH set, so it's necessary to @@ -288,7 +257,7 @@ test_torch_function_benchmark() { test_xla() { export XLA_USE_XRT=1 XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0" # Issue #30717: randomize the port of XLA/gRPC workers is listening on to reduce flaky tests. - XLA_PORT=`shuf -i 40701-40999 -n 1` + XLA_PORT=$(shuf -i 40701-40999 -n 1) export XRT_WORKERS="localservice:0;grpc://localhost:$XLA_PORT" pushd xla echo "Running Python Tests" @@ -304,7 +273,7 @@ test_xla() { assert_git_not_dirty } -# Do NOT run this test before any other tests, like test_python_nn, etc. +# Do NOT run this test before any other tests, like test_python_shard1, etc. # Because this function uninstalls the torch built from branch, and install # nightly version. test_backward_compatibility() { @@ -338,6 +307,8 @@ test_benchmarks() { pip_install --user "requests" BENCHMARK_DATA="benchmarks/.data" mkdir -p ${BENCHMARK_DATA} + pytest benchmarks/fastrnns/test_bench.py --benchmark-sort=Name --benchmark-json=${BENCHMARK_DATA}/fastrnns_default.json --fuser=default --executor=default + python benchmarks/upload_scribe.py --pytest_bench_json ${BENCHMARK_DATA}/fastrnns_default.json pytest benchmarks/fastrnns/test_bench.py --benchmark-sort=Name --benchmark-json=${BENCHMARK_DATA}/fastrnns_legacy_old.json --fuser=old --executor=legacy python benchmarks/upload_scribe.py --pytest_bench_json ${BENCHMARK_DATA}/fastrnns_legacy_old.json pytest benchmarks/fastrnns/test_bench.py --benchmark-sort=Name --benchmark-json=${BENCHMARK_DATA}/fastrnns_profiling_te.json --fuser=te --executor=profiling @@ -379,19 +350,17 @@ if [[ "${BUILD_ENVIRONMENT}" == *backward* ]]; then elif [[ "${BUILD_ENVIRONMENT}" == *xla* || "${JOB_BASE_NAME}" == *xla* ]]; then install_torchvision test_xla -elif [[ "${BUILD_ENVIRONMENT}" == *ge_config_legacy* || "${JOB_BASE_NAME}" == *ge_config_legacy* ]]; then - test_python_ge_config_legacy -elif [[ "${BUILD_ENVIRONMENT}" == *ge_config_profiling* || "${JOB_BASE_NAME}" == *ge_config_profiling* ]]; then - test_python_ge_config_profiling +elif [[ "${BUILD_ENVIRONMENT}" == *jit_legacy-test || "${JOB_BASE_NAME}" == *jit_legacy-test ]]; then + test_python_legacy_jit elif [[ "${BUILD_ENVIRONMENT}" == *libtorch* ]]; then # TODO: run some C++ tests echo "no-op at the moment" elif [[ "${BUILD_ENVIRONMENT}" == *-test1 || "${JOB_BASE_NAME}" == *-test1 ]]; then - test_python_nn - test_cpp_extensions + install_torchvision + test_python_shard1 elif [[ "${BUILD_ENVIRONMENT}" == *-test2 || "${JOB_BASE_NAME}" == *-test2 ]]; then install_torchvision - test_python_all_except_nn_and_cpp_extensions + test_python_shard2 test_aten test_libtorch test_custom_script_ops @@ -407,9 +376,8 @@ elif [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc5.4 test_cpp_extensions else install_torchvision - test_python_nn - test_python_all_except_nn_and_cpp_extensions - test_cpp_extensions + test_python_shard1 + test_python_shard2 test_aten test_vec256 test_libtorch @@ -419,10 +387,15 @@ else test_distributed test_benchmarks test_rpc - if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then - pushd test - echo "Generating XML coverage report" - time python -mcoverage xml - popd - fi +fi + +if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then + pushd test + echo "Generating XML coverage report" + time python -mcoverage xml + popd + pushd build + echo "Generating lcov coverage report for C++ sources" + time lcov --capture --directory . --output-file coverage.info + popd fi diff --git a/.jenkins/pytorch/win-build.sh b/.jenkins/pytorch/win-build.sh index 6df36ce4d021f..f8d289f3336d9 100755 --- a/.jenkins/pytorch/win-build.sh +++ b/.jenkins/pytorch/win-build.sh @@ -15,7 +15,7 @@ COMPACT_JOB_NAME=pytorch-win-ws2019-cuda10-cudnn7-py3-build SCRIPT_PARENT_DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd ) source "$SCRIPT_PARENT_DIR/common.sh" -export IMAGE_COMMIT_ID=`git rev-parse HEAD` +export IMAGE_COMMIT_ID=$(git rev-parse HEAD) export IMAGE_COMMIT_TAG=${BUILD_ENVIRONMENT}-${IMAGE_COMMIT_ID} if [[ ${JOB_NAME} == *"develop"* ]]; then export IMAGE_COMMIT_TAG=develop-${IMAGE_COMMIT_TAG} @@ -38,6 +38,21 @@ fi export SCRIPT_HELPERS_DIR=$SCRIPT_PARENT_DIR/win-test-helpers +set +ex +grep -E -R 'PyLong_(From|As)(Unsigned|)Long\(' --exclude=python_numbers.h torch/ +PYLONG_API_CHECK=$? +if [[ $PYLONG_API_CHECK == 0 ]]; then + echo "Usage of PyLong_{From,As}{Unsigned}Long API may lead to overflow errors on Windows" + echo "because \`sizeof(long) == 4\` and \`sizeof(unsigned long) == 4\`." + echo "Please include \"torch/csrc/python_numbers.h\" and use the correspoding APIs instead." + echo "PyLong_FromLong -> THPUtils_packInt32 / THPUtils_packInt64" + echo "PyLong_AsLong -> THPUtils_unpackInt (32-bit) / THPUtils_unpackLong (64-bit)" + echo "PyLong_FromUnsignedLong -> THPUtils_packUInt32 / THPUtils_packUInt64" + echo "PyLong_AsUnsignedLong -> THPUtils_unpackUInt32 / THPUtils_unpackUInt64" + exit 1 +fi +set -ex + $SCRIPT_HELPERS_DIR/build_pytorch.bat assert_git_not_dirty diff --git a/.jenkins/pytorch/win-test-helpers/build_pytorch.bat b/.jenkins/pytorch/win-test-helpers/build_pytorch.bat index 0ddf3b4b462c0..7165f75a0e418 100644 --- a/.jenkins/pytorch/win-test-helpers/build_pytorch.bat +++ b/.jenkins/pytorch/win-test-helpers/build_pytorch.bat @@ -37,33 +37,19 @@ if "%VC_VERSION%" == "" ( @echo on popd -if "%CUDA_VERSION%" == "9" goto cuda_build_9 -if "%CUDA_VERSION%" == "10" goto cuda_build_10 -if "%CUDA_VERSION%" == "11" goto cuda_build_11 -goto cuda_build_end +if not "%USE_CUDA%"=="1" goto cuda_build_end -:cuda_build_9 +set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION% -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.2 -set CUDA_PATH_V9_2=%CUDA_PATH% +rem version transformer, for example 10.1 to 10_1. +set VERSION_SUFFIX=%CUDA_VERSION:.=_% +set CUDA_PATH_V%VERSION_SUFFIX%=%CUDA_PATH% -goto cuda_build_common - -:cuda_build_10 - -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1 -set CUDA_PATH_V10_1=%CUDA_PATH% - -goto cuda_build_common - -:cuda_build_11 - -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0 -set CUDA_PATH_V11_0=%CUDA_PATH% - -goto cuda_build_common - -:cuda_build_common +set CUDNN_LIB_DIR=%CUDA_PATH%\lib\x64 +set CUDA_TOOLKIT_ROOT_DIR=%CUDA_PATH% +set CUDNN_ROOT_DIR=%CUDA_PATH% +set NVTOOLSEXT_PATH=C:\Program Files\NVIDIA Corporation\NvToolsExt +set PATH=%CUDA_PATH%\bin;%CUDA_PATH%\libnvvp;%PATH% set CUDNN_LIB_DIR=%CUDA_PATH%\lib\x64 set CUDA_TOOLKIT_ROOT_DIR=%CUDA_PATH% @@ -95,7 +81,7 @@ if "%USE_CUDA%"=="1" ( copy %TMP_DIR_WIN%\bin\sccache.exe %TMP_DIR_WIN%\bin\nvcc.exe :: randomtemp is used to resolve the intermittent build error related to CUDA. - :: code: https://github.com/peterjc123/randomtemp + :: code: https://github.com/peterjc123/randomtemp-rust :: issue: https://github.com/pytorch/pytorch/issues/25393 :: :: Previously, CMake uses CUDA_NVCC_EXECUTABLE for finding nvcc and then @@ -103,7 +89,7 @@ if "%USE_CUDA%"=="1" ( :: in PATH, and then pass the arguments to it. :: Currently, randomtemp is placed before sccache (%TMP_DIR_WIN%\bin\nvcc) :: so we are actually pretending sccache instead of nvcc itself. - curl -kL https://github.com/peterjc123/randomtemp/releases/download/v0.3/randomtemp.exe --output %TMP_DIR_WIN%\bin\randomtemp.exe + curl -kL https://github.com/peterjc123/randomtemp-rust/releases/download/v0.3/randomtemp.exe --output %TMP_DIR_WIN%\bin\randomtemp.exe set RANDOMTEMP_EXECUTABLE=%TMP_DIR_WIN%\bin\nvcc.exe set CUDA_NVCC_EXECUTABLE=%TMP_DIR_WIN%\bin\randomtemp.exe set RANDOMTEMP_BASEDIR=%TMP_DIR_WIN%\bin diff --git a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat index c7d60bedafd71..ab102a0ea4233 100644 --- a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat +++ b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat @@ -1,18 +1,18 @@ -if "%CUDA_VERSION%" == "9" set CUDA_SUFFIX=cuda92 -if "%CUDA_VERSION%" == "10" set CUDA_SUFFIX=cuda101 -if "%CUDA_VERSION%" == "11" set CUDA_SUFFIX=cuda110 +rem remove dot in cuda_version, fox example 11.1 to 111 +set VERSION_SUFFIX=%CUDA_VERSION:.=% +set CUDA_SUFFIX=cuda%VERSION_SUFFIX% if "%CUDA_SUFFIX%" == "" ( - echo unknown CUDA version, please set `CUDA_VERSION` to 9, 10 or 11. + echo unknown CUDA version, please set `CUDA_VERSION` higher than 9.2 exit /b 1 ) if "%REBUILD%"=="" ( if "%BUILD_ENVIRONMENT%"=="" ( - curl --retry 3 -k https://s3.amazonaws.com/ossci-windows/magma_2.5.3_%CUDA_SUFFIX%_%BUILD_TYPE%.7z --output %TMP_DIR_WIN%\magma_2.5.3_%CUDA_SUFFIX%_%BUILD_TYPE%.7z + curl --retry 3 -k https://s3.amazonaws.com/ossci-windows/magma_2.5.4_%CUDA_SUFFIX%_%BUILD_TYPE%.7z --output %TMP_DIR_WIN%\magma_2.5.4_%CUDA_SUFFIX%_%BUILD_TYPE%.7z ) else ( - aws s3 cp s3://ossci-windows/magma_2.5.3_%CUDA_SUFFIX%_%BUILD_TYPE%.7z %TMP_DIR_WIN%\magma_2.5.3_%CUDA_SUFFIX%_%BUILD_TYPE%.7z --quiet + aws s3 cp s3://ossci-windows/magma_2.5.4_%CUDA_SUFFIX%_%BUILD_TYPE%.7z %TMP_DIR_WIN%\magma_2.5.4_%CUDA_SUFFIX%_%BUILD_TYPE%.7z --quiet ) - 7z x -aoa %TMP_DIR_WIN%\magma_2.5.3_%CUDA_SUFFIX%_%BUILD_TYPE%.7z -o%TMP_DIR_WIN%\magma + 7z x -aoa %TMP_DIR_WIN%\magma_2.5.4_%CUDA_SUFFIX%_%BUILD_TYPE%.7z -o%TMP_DIR_WIN%\magma ) set MAGMA_HOME=%TMP_DIR_WIN%\magma diff --git a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_miniconda3.bat b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_miniconda3.bat index a66ef4b651c5d..7669f6cfd91ee 100644 --- a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_miniconda3.bat +++ b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_miniconda3.bat @@ -12,4 +12,11 @@ call %CONDA_PARENT_DIR%\Miniconda3\Scripts\activate.bat %CONDA_PARENT_DIR%\Minic if "%REBUILD%"=="" ( call conda install -y -q python=%PYTHON_VERSION% numpy cffi pyyaml boto3 call conda install -y -q -c conda-forge cmake + call conda install -y -q -c conda-forge libuv=1.39 ) + +:: Get installed libuv path +@echo off +set libuv_ROOT=%CONDA_PARENT_DIR%\Miniconda3\Library +@echo on +echo libuv_ROOT=%libuv_ROOT% diff --git a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_mkl.bat b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_mkl.bat index 1fa84920cd701..656a5494ea3f8 100644 --- a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_mkl.bat +++ b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_mkl.bat @@ -7,4 +7,4 @@ if "%REBUILD%"=="" ( 7z x -aoa %TMP_DIR_WIN%\mkl.7z -o%TMP_DIR_WIN%\mkl ) set CMAKE_INCLUDE_PATH=%TMP_DIR_WIN%\mkl\include -set LIB=%TMP_DIR_WIN%\mkl\lib;%LIB +set LIB=%TMP_DIR_WIN%\mkl\lib;%LIB% diff --git a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat index 17a3d39d076d4..ed64828909932 100644 --- a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat +++ b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat @@ -39,40 +39,18 @@ if %errorlevel% neq 0 ( exit /b %errorlevel% ) popd :: The version is fixed to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136 -pip install "ninja==1.10.0.post1" future "hypothesis==4.53.2" "librosa>=0.6.2" psutil pillow unittest-xml-reporting +pip install "ninja==1.10.0.post1" future "hypothesis==4.53.2" "librosa>=0.6.2" psutil pillow unittest-xml-reporting pytest coverage if %errorlevel% neq 0 ( exit /b %errorlevel% ) -:: No need to install faulthandler since we only test Python >= 3.6 on Windows -:: faulthandler is builtin since Python 3.3 set DISTUTILS_USE_SDK=1 -if "%CUDA_VERSION%" == "9" goto cuda_build_9 -if "%CUDA_VERSION%" == "10" goto cuda_build_10 -if "%CUDA_VERSION%" == "11" goto cuda_build_11 -goto cuda_build_end +if not "%USE_CUDA%"=="1" goto cuda_build_end -:cuda_build_9 +set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION% -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.2 -set CUDA_PATH_V9_2=%CUDA_PATH% - -goto cuda_build_common - -:cuda_build_10 - -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1 -set CUDA_PATH_V10_1=%CUDA_PATH% - -goto cuda_build_common - -:cuda_build_11 - -set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0 -set CUDA_PATH_V11_0=%CUDA_PATH% - -goto cuda_build_common - -:cuda_build_common +rem version transformer, for example 10.1 to 10_1. +set VERSION_SUFFIX=%CUDA_VERSION:.=_% +set CUDA_PATH_V%VERSION_SUFFIX%=%CUDA_PATH% set CUDNN_LIB_DIR=%CUDA_PATH%\lib\x64 set CUDA_TOOLKIT_ROOT_DIR=%CUDA_PATH% diff --git a/.jenkins/pytorch/win-test-helpers/test_python_all_except_nn.bat b/.jenkins/pytorch/win-test-helpers/test_python_all_except_nn.bat index 4bfb5bc85e668..d76637dd0db7c 100644 --- a/.jenkins/pytorch/win-test-helpers/test_python_all_except_nn.bat +++ b/.jenkins/pytorch/win-test-helpers/test_python_all_except_nn.bat @@ -1,3 +1,3 @@ call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat -cd test && python run_test.py --exclude test_jit_cuda_fuser_profiling test_jit_cuda_fuser_legacy test_jit_profiling test_jit_legacy test_jit_fuser_legacy test_jit_fuser_te test_tensorexpr --verbose --determine-from="%1" && cd .. +cd test && python run_test.py --exclude-jit-executor --verbose --determine-from="%1" && cd .. if ERRORLEVEL 1 exit /b 1 diff --git a/.jenkins/pytorch/win-test-helpers/test_python_jit_profiling.bat b/.jenkins/pytorch/win-test-helpers/test_python_jit_legacy.bat similarity index 51% rename from .jenkins/pytorch/win-test-helpers/test_python_jit_profiling.bat rename to .jenkins/pytorch/win-test-helpers/test_python_jit_legacy.bat index e437833d8c624..a9168644f4711 100644 --- a/.jenkins/pytorch/win-test-helpers/test_python_jit_profiling.bat +++ b/.jenkins/pytorch/win-test-helpers/test_python_jit_legacy.bat @@ -3,9 +3,7 @@ call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat pushd test echo Run jit_profiling tests -python run_test.py --include test_jit_profiling test_jit_fuser_te test_tensorexpr --verbose --determine-from="%1" +python run_test.py --include test_jit_legacy test_jit_fuser_legacy --verbose --determine-from="%1" if ERRORLEVEL 1 exit /b 1 popd - - diff --git a/.jenkins/pytorch/win-test.sh b/.jenkins/pytorch/win-test.sh index 0b0159d04a50b..c1c49cd711b8b 100755 --- a/.jenkins/pytorch/win-test.sh +++ b/.jenkins/pytorch/win-test.sh @@ -1,12 +1,12 @@ -#!/bin/bash -ex - +#!/bin/bash +set -ex # shellcheck disable=SC2034 COMPACT_JOB_NAME=pytorch-win-ws2019-cuda10-cudnn7-py3-test SCRIPT_PARENT_DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd ) source "$SCRIPT_PARENT_DIR/common.sh" -export IMAGE_COMMIT_ID=`git rev-parse HEAD` +export IMAGE_COMMIT_ID=$(git rev-parse HEAD) export IMAGE_COMMIT_TAG=${BUILD_ENVIRONMENT}-${IMAGE_COMMIT_ID} if [[ ${JOB_NAME} == *"develop"* ]]; then export IMAGE_COMMIT_TAG=develop-${IMAGE_COMMIT_TAG} @@ -14,6 +14,10 @@ fi export TMP_DIR="${PWD}/build/win_tmp" export TMP_DIR_WIN=$(cygpath -w "${TMP_DIR}") +export PROJECT_DIR="${PWD}" +export PROJECT_DIR_WIN=$(cygpath -w "${PROJECT_DIR}") +export TEST_DIR="${PWD}/test" +export TEST_DIR_WIN=$(cygpath -w "${TEST_DIR}") export PYTORCH_FINAL_PACKAGE_DIR="/c/users/circleci/workspace/build-results" export PYTORCH_FINAL_PACKAGE_DIR_WIN=$(cygpath -w "${PYTORCH_FINAL_PACKAGE_DIR}") @@ -38,24 +42,40 @@ fi run_tests() { if [ -z "${JOB_BASE_NAME}" ] || [[ "${JOB_BASE_NAME}" == *-test ]]; then - $SCRIPT_HELPERS_DIR/test_python_nn.bat "$DETERMINE_FROM" && \ - $SCRIPT_HELPERS_DIR/test_python_all_except_nn.bat "$DETERMINE_FROM" && \ - $SCRIPT_HELPERS_DIR/test_custom_script_ops.bat && \ - $SCRIPT_HELPERS_DIR/test_custom_backend.bat && \ + $SCRIPT_HELPERS_DIR/test_python_nn.bat "$DETERMINE_FROM" + $SCRIPT_HELPERS_DIR/test_python_all_except_nn.bat "$DETERMINE_FROM" + $SCRIPT_HELPERS_DIR/test_custom_script_ops.bat + $SCRIPT_HELPERS_DIR/test_custom_backend.bat $SCRIPT_HELPERS_DIR/test_libtorch.bat else if [[ "${JOB_BASE_NAME}" == *-test1 ]]; then - $SCRIPT_HELPERS_DIR/test_python_nn.bat "$DETERMINE_FROM" && \ + export PYTORCH_COLLECT_COVERAGE=1 + $SCRIPT_HELPERS_DIR/test_python_nn.bat "$DETERMINE_FROM" $SCRIPT_HELPERS_DIR/test_libtorch.bat if [[ "${USE_CUDA}" == "1" ]]; then - $SCRIPT_HELPERS_DIR/test_python_jit_profiling.bat "$DETERMINE_FROM" + $SCRIPT_HELPERS_DIR/test_python_jit_legacy.bat "$DETERMINE_FROM" fi elif [[ "${JOB_BASE_NAME}" == *-test2 ]]; then - $SCRIPT_HELPERS_DIR/test_python_all_except_nn.bat "$DETERMINE_FROM" && \ - $SCRIPT_HELPERS_DIR/test_custom_backend.bat && \ + $SCRIPT_HELPERS_DIR/test_python_all_except_nn.bat "$DETERMINE_FROM" + $SCRIPT_HELPERS_DIR/test_custom_backend.bat $SCRIPT_HELPERS_DIR/test_custom_script_ops.bat fi fi } -run_tests && assert_git_not_dirty && echo "TEST PASSED" +run_tests +assert_git_not_dirty +echo "TEST PASSED" + +if [[ "${BUILD_ENVIRONMENT}" == "pytorch-win-vs2019-cuda10-cudnn7-py3" ]] && [[ "${JOB_BASE_NAME}" == *-test1 ]]; then + pushd $TEST_DIR + python -mpip install coverage + echo "Generating XML coverage report" + time python -mcoverage xml + popd + + pushd $PROJECT_DIR + python -mpip install codecov + python -mcodecov + popd +fi diff --git a/.jenkins/run-shellcheck.sh b/.jenkins/run-shellcheck.sh index 1333e9ab6f49d..5c64655b578fd 100755 --- a/.jenkins/run-shellcheck.sh +++ b/.jenkins/run-shellcheck.sh @@ -5,6 +5,4 @@ # .jenkins/run-shellcheck.sh --color=always | less -R -EXCLUSIONS=SC2086,SC1091,SC2155,SC1090,SC2164,SC1003 - -find .jenkins/pytorch -name *.sh | xargs shellcheck --exclude=$EXCLUSIONS --external-sources "$@" || true +find .jenkins/pytorch -name *.sh | xargs shellcheck --external-sources "$@" diff --git a/test/cpp/tensorexpr/__init__.py b/.nojekyll similarity index 100% rename from test/cpp/tensorexpr/__init__.py rename to .nojekyll diff --git a/.travis.aten.yml b/.travis.aten.yml deleted file mode 100644 index 2425845496258..0000000000000 --- a/.travis.aten.yml +++ /dev/null @@ -1,31 +0,0 @@ -# https://travis-ci.org/zdevito/ATen -language: python -python: - - 2.7 - - 3.6 - -dist: trusty - -before_install: - - sudo apt-get install -qq valgrind - -install: - - travis_retry pip install pyyaml typing - -script: - - cd aten - - mkdir build install - - cd build - - cmake .. -DUSE_CUDA=OFF -DCMAKE_INSTALL_PREFIX=../install - - make install - - ../tools/run_tests.sh . - - cd .. - - tools/test_install.sh $(pwd)/install $(pwd) - -matrix: - fast_finish: true - include: - env: LINT_CHECK - python: "2.7" - install: pip install flake8-mypy - script: flake8 diff --git a/BUILD.bazel b/BUILD.bazel index 016863ff09586..2b4636d850c99 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -55,6 +55,7 @@ cc_library( "c10/cuda/*.h", "c10/cuda/impl/*.h", "c10/macros/*.h", + "c10/mobile/*.h", "c10/util/*.h", "c10/util/*.hpp", ]), @@ -71,6 +72,7 @@ cc_library( srcs = glob([ "c10/core/*.cpp", "c10/core/impl/*.cpp", + "c10/mobile/*.cpp", "c10/util/*.cpp", ]) + if_cuda( glob([ @@ -123,20 +125,19 @@ genrule( ] + glob(["aten/src/ATen/templates/**"]), outs = [ "aten/src/ATen/Declarations.yaml", - "aten/src/ATen/BackendSelectRegister.cpp", - "aten/src/ATen/CPUType.h", - "aten/src/ATen/CPUType.cpp", + "aten/src/ATen/RegisterBackendSelect.cpp", + "aten/src/ATen/RegisterCPU.cpp", + "aten/src/ATen/RegisterMkldnnCPU.cpp", + "aten/src/ATen/RegisterQuantizedCPU.cpp", + "aten/src/ATen/RegisterSparseCPU.cpp", + "aten/src/ATen/RegisterMath.cpp", + "aten/src/ATen/RegisterMeta.cpp", + "aten/src/ATen/RegisterDefaultBackend.cpp", + "aten/src/ATen/RegisterSchema.cpp", "aten/src/ATen/Functions.h", "aten/src/ATen/Functions.cpp", "aten/src/ATen/NativeFunctions.h", - "aten/src/ATen/MkldnnCPUType.h", - "aten/src/ATen/MkldnnCPUType.cpp", - "aten/src/ATen/QuantizedCPUType.h", - "aten/src/ATen/QuantizedCPUType.cpp", - "aten/src/ATen/SparseCPUType.h", - "aten/src/ATen/SparseCPUType.cpp", - "aten/src/ATen/TypeDefault.h", - "aten/src/ATen/TypeDefault.cpp", + "aten/src/ATen/MetaFunctions.h", "aten/src/ATen/core/TensorBody.h", "aten/src/ATen/core/TensorMethods.cpp", "aten/src/ATen/core/ATenOpList.cpp", @@ -189,13 +190,9 @@ libtorch_cpp_generated_sources = [ "torch/csrc/autograd/generated/TraceType_3.cpp", "torch/csrc/autograd/generated/TraceType_4.cpp", # "torch/csrc/autograd/generated/TraceTypeEverything.cpp", - "torch/csrc/autograd/generated/RegistrationDeclarations.h", "torch/csrc/autograd/generated/Functions.h", "torch/csrc/autograd/generated/Functions.cpp", "torch/csrc/autograd/generated/variable_factories.h", - "torch/csrc/jit/generated/generated_unboxing_wrappers_0.cpp", - "torch/csrc/jit/generated/generated_unboxing_wrappers_1.cpp", - "torch/csrc/jit/generated/generated_unboxing_wrappers_2.cpp", ] libtorch_python_generated_sources = [ @@ -212,9 +209,10 @@ genrule( name = "all_generated_code", srcs = [ "aten/src/ATen/Declarations.yaml", + "aten/src/ATen/native/native_functions.yaml", ], outs = libtorch_cpp_generated_sources + libtorch_python_generated_sources, - cmd = "$(location :generate_code) --install_dir `dirname $(location torch/csrc/autograd/generated/variable_factories.h)`/../.. --declarations-path $(location aten/src/ATen/Declarations.yaml) --nn-path aten/src", + cmd = "$(location :generate_code) --install_dir `dirname $(location torch/csrc/autograd/generated/variable_factories.h)`/../.. --declarations-path $(location aten/src/ATen/Declarations.yaml) --native-functions-path $(location aten/src/ATen/native/native_functions.yaml) --nn-path aten/src", tools = [":generate_code"], ) @@ -296,6 +294,11 @@ filegroup( srcs = glob(["aten/src/ATen/vulkan/*.cpp"]), ) +filegroup( + name = "aten_base_metal", + srcs = glob(["aten/src/ATen/metal/*.cpp"]), +) + filegroup( name = "ATen_QUANTIZED_SRCS", srcs = glob( @@ -333,7 +336,8 @@ filegroup( "aten/src/ATen/cuda/CUDABlas.cpp", "aten/src/ATen/cuda/CUDASolver.cpp", "aten/src/ATen/cuda/CUDAContext.cpp", - "aten/src/ATen/cuda/CUDAGenerator.cpp", + "aten/src/ATen/cuda/CUDAGeneratorImpl.cpp", + "aten/src/ATen/cuda/CUDAGraph.cpp", "aten/src/ATen/cuda/CuSparseHandlePool.cpp", "aten/src/ATen/cuda/CublasHandlePool.cpp", "aten/src/ATen/cuda/CusolverDnHandlePool.cpp", @@ -366,7 +370,6 @@ filegroup( filegroup( name = "thc_srcs_cu", srcs = [ - "aten/src/THC/THCBlas.cu.cc", "aten/src/THC/THCReduceApplyUtils.cu.cc", "aten/src/THC/THCSleep.cu.cc", "aten/src/THC/THCSortUtils.cu.cc", @@ -376,7 +379,6 @@ filegroup( "aten/src/THC/THCTensorCopy.cu.cc", "aten/src/THC/THCTensorIndex.cu.cc", "aten/src/THC/THCTensorMath.cu.cc", - "aten/src/THC/THCTensorMathBlas.cu.cc", "aten/src/THC/THCTensorMathMagma.cu.cc", "aten/src/THC/THCTensorMathPairwise.cu.cc", "aten/src/THC/THCTensorMathReduce.cu.cc", @@ -453,6 +455,7 @@ filegroup( name = "aten_srcs_cu", srcs = [ "aten/src/ATen/cuda/detail/IndexUtils.cu.cc", + "aten/src/ATen/cuda/detail/CUDAGraphsUtils.cu.cc", "aten/src/ATen/native/cuda/Activation.cu.cc", "aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu.cc", "aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu.cc", @@ -538,6 +541,7 @@ header_template_rule( substitutions = { "@AT_MKLDNN_ENABLED@": "1", "@AT_MKL_ENABLED@": "0", + "@AT_FFTW_ENABLED@": "0", "@AT_NNPACK_ENABLED@": "0", "@CAFFE2_STATIC_LINK_CUDA_INT@": "0", "@USE_BLAS@": "1", @@ -593,7 +597,10 @@ cc_library( "aten/src/THC/generic/*.cu.cc", "aten/src/THCUNN/*.cuh", "aten/src/THCUNN/generic/*.cu.cc", - ]) + [ + ], + exclude = [ + "aten/src/ATen/Config.h", + ],) + [ ":generated_cpp", ":aten_src_ATen_config", ], @@ -648,6 +655,7 @@ cc_library( ":ATen_CORE_SRCS", ":ATen_QUANTIZED_SRCS", ":aten_base_cpp", + ":aten_base_metal", ":aten_base_vulkan", ":aten_native_cpp", ":aten_native_mkl_cpp", @@ -721,6 +729,7 @@ torch_cuda_half_options = [ "-DCUDA_HAS_FP16=1", "-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", "-D__CUDA_NO_HALF2_OPERATORS__", ] @@ -1873,6 +1882,10 @@ cc_library( exclude = [ "torch/lib/c10d/ProcessGroupMPI.hpp", "torch/lib/c10d/ProcessGroupNCCL.hpp", + "torch/csrc/autograd/generated/VariableType.h", + "torch/csrc/autograd/generated/RegistrationDeclarations.h", + "torch/csrc/autograd/generated/variable_factories.h", + "torch/csrc/autograd/generated/Functions.h", ] + torch_cuda_headers, ) + [":cpp_generated_code"], includes = [ diff --git a/CMakeLists.txt b/CMakeLists.txt index 826c187b602e8..b84a30469cb53 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -30,8 +30,16 @@ endif() set(CMAKE_INSTALL_MESSAGE NEVER) +# check and set CMAKE_CXX_STANDARD +string(FIND "${CMAKE_CXX_FLAGS}" "-std=c++" env_cxx_standard) +if(env_cxx_standard GREATER -1) + message( + WARNING "C++ standard version definition detected in environment variable." + "PyTorch requires -std=c++14. Please remove -std=c++ settings in your environment.") +endif() set(CMAKE_CXX_STANDARD 14) set(CMAKE_C_STANDARD 11) + if(DEFINED GLIBCXX_USE_CXX11_ABI) if(${GLIBCXX_USE_CXX11_ABI} EQUAL 1) set(CXX_STANDARD_REQUIRED ON) @@ -83,6 +91,9 @@ if(APPLE) set(CMAKE_MACOSX_RPATH ON) endif() +set(CPU_AARCH64 OFF) +set(CPU_INTEL OFF) + if(WIN32) # On Windows, CMAKE_HOST_SYSTEM_PROCESSOR is calculated through `PROCESSOR_ARCHITECTURE`, # which only has the value of `x86` or `AMD64`. We cannot infer whether it's a Intel CPU @@ -93,17 +104,17 @@ if(WIN32) set(CPU_INTEL OFF) endif() else() - if(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "(x86_64|i[3-6]+86)") + if(CMAKE_SYSTEM_PROCESSOR MATCHES "(x86_64|i[3-6]+86)") set(CPU_INTEL ON) - else() - set(CPU_INTEL OFF) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64)") + set(CPU_AARCH64 ON) endif() endif() # For non-supported platforms, turn USE_DISTRIBUTED off by default. # It is not tested and likely won't work without additional changes. -if(NOT LINUX) +if(NOT LINUX AND NOT WIN32) set(USE_DISTRIBUTED OFF CACHE STRING "Use distributed") # On macOS, if USE_DISTRIBUTED is enabled (specified by the user), # then make Gloo build with the libuv transport. @@ -137,7 +148,8 @@ cmake_dependent_option( "NOT BUILD_SHARED_LIBS" OFF) option(BUILD_TEST "Build C++ test binaries (need gtest and gbenchmark)" OFF) option(BUILD_STATIC_RUNTIME_BENCHMARK "Build C++ binaries for static runtime benchmarks (need gbenchmark)" OFF) -option(BUILD_MOBILE_BENCHMARKS "Build C++ test binaries for mobile (ARM) targets(need gtest and gbenchmark)" OFF) +option(BUILD_TENSOREXPR_BENCHMARK "Build C++ binaries for tensorexpr benchmarks (need gbenchmark)" OFF) +option(BUILD_MOBILE_BENCHMARK "Build C++ test binaries for mobile (ARM) targets(need gtest and gbenchmark)" OFF) option(BUILD_MOBILE_TEST "Build C++ test binaries for mobile (ARM) targets(need gtest and gbenchmark)" OFF) option(BUILD_JNI "Build JNI bindings" OFF) option(BUILD_MOBILE_AUTOGRAD "Build autograd function in mobile build (in development)" OFF) @@ -158,6 +170,7 @@ cmake_dependent_option( USE_STATIC_CUDNN "Use cuDNN static libraries" OFF "USE_CUDNN" OFF) option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON) +option(USE_KINETO "Use Kineto profiling library" OFF) option(USE_FAKELOWP "Use FakeLowp operators" OFF) option(USE_FFMPEG "Use ffmpeg" OFF) option(USE_GFLAGS "Use GFLAGS" OFF) @@ -165,11 +178,14 @@ option(USE_GLOG "Use GLOG" OFF) option(USE_LEVELDB "Use LEVELDB" OFF) option(USE_LITE_PROTO "Use lite protobuf instead of full." OFF) option(USE_LMDB "Use LMDB" OFF) -option(USE_METAL "Use Metal for iOS build" ON) +option(USE_METAL "Use Metal for Caffe2 iOS build" ON) +option(USE_PYTORCH_METAL "Use Metal for PyTorch iOS build" OFF) option(USE_NATIVE_ARCH "Use -march=native" OFF) cmake_dependent_option( USE_NCCL "Use NCCL" ON "USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) +cmake_dependent_option(USE_RCCL "Use RCCL" ON + USE_NCCL OFF) cmake_dependent_option( USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF) @@ -198,16 +214,22 @@ option(USE_SNPE "Use Qualcomm's SNPE library" OFF) option(USE_SYSTEM_EIGEN_INSTALL "Use system Eigen instead of the one under third_party" OFF) option(USE_TENSORRT "Using Nvidia TensorRT library" OFF) +cmake_dependent_option( + USE_VALGRIND "Use Valgrind. Only available on Linux." ON + "LINUX" OFF) option(USE_VULKAN "Use Vulkan GPU backend" OFF) -option(USE_VULKAN_WRAPPER "Use Vulkan wrapper" ON) -option(USE_VULKAN_SHADERC_RUNTIME "Use Vulkan Shader compilation runtime(Needs shaderc lib)" OFF) -option(USE_VULKAN_RELAXED_PRECISION "Use Vulkan relaxed precision(mediump)" OFF) +option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference even on fp32 tensors" OFF) +option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF) +option(USE_VULKAN_SHADERC_RUNTIME "Vulkan - Use runtime shader compilation (needs libshaderc)" OFF) +option(USE_VULKAN_WRAPPER "Vulkan - Dynamically load Vulkan functions" ON) option(USE_XNNPACK "Use XNNPACK" ON) option(USE_ZMQ "Use ZMQ" OFF) option(USE_ZSTD "Use ZSTD" OFF) +# Ensure that an MKLDNN build is the default for x86 CPUs +# but optional for AArch64 (dependent on -DUSE_MKLDNN). cmake_dependent_option( - USE_MKLDNN "Use MKLDNN. Only available on x86 and x86_64." ON - "CPU_INTEL" OFF) + USE_MKLDNN "Use MKLDNN. Only available on x86, x86_64, and AArch64." ON + "CPU_INTEL OR CPU_AARCH64 AND USE_MKLDNN" OFF) set(MKLDNN_ENABLE_CONCURRENT_EXEC ${USE_MKLDNN}) cmake_dependent_option( USE_MKLDNN_CBLAS "Use CBLAS in MKLDNN" OFF @@ -226,6 +248,32 @@ option(USE_TBB "Use TBB" OFF) option(ONNX_ML "Enable traditional ONNX ML API." ON) option(HAVE_SOVERSION "Whether to add SOVERSION to the shared objects" OFF) +# Since TensorPipe does not support Windows, set it to OFF when WIN32 detected +# On Windows platform, if user does not install libuv in build conda env and +# does not set libuv_ROOT environment variable. Set USE_DISTRIBUTED to OFF. +if(WIN32) + set(USE_TENSORPIPE OFF) + message(WARNING "TensorPipe cannot be used on Windows. Set it to OFF") + + if(USE_DISTRIBUTED AND NOT DEFINED ENV{libuv_ROOT}) + find_library( + libuv_tmp_LIBRARY + NAMES uv libuv + HINTS $ENV{CONDA_PREFIX}\\Library $ENV{PREFIX}\\Library + PATH_SUFFIXES lib + NO_DEFAULT_PATH) + if(NOT libuv_tmp_LIBRARY) + set(USE_DISTRIBUTED OFF) + set(USE_GLOO OFF) + message( + WARNING "Libuv is not installed in current conda env. Set USE_DISTRIBUTED to OFF. " + "Please run command 'conda install -c conda-forge libuv=1.39' to install libuv.") + else() + set(ENV{libuv_ROOT} ${libuv_tmp_LIBRARY}/../../) + endif() + endif() +endif() + # Linux distributions do not want too many embedded sources, in that sense we # need to be able to build pytorch with an (almost) empty third_party # directory. @@ -283,7 +331,7 @@ set(OP_DEPENDENCY "" CACHE STRING # symbol lookup error: miniconda3/envs/pytorch-py3.7/lib/libmkl_intel_lp64.so: undefined symbol: mkl_blas_dsyrk # https://software.intel.com/en-us/articles/symbol-lookup-error-when-linking-intel-mkl-with-gcc-on-ubuntu if(LINUX) - set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--no-as-needed") + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--no-as-needed") endif() if(MSVC) @@ -424,8 +472,6 @@ if(INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE) endif() # ---[ Utils -# TODO: merge the following 3 files into cmake/public/utils.cmake. -include(cmake/Utils.cmake) include(cmake/public/utils.cmake) # ---[ Version numbers for generated libraries @@ -472,6 +518,20 @@ if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build from: Debug Release RelWithDebInfo MinSizeRel Coverage." FORCE) endif() +# The below means we are cross compiling for arm64 or x86_64 on MacOSX +if(NOT IOS AND CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64)$") + set(CROSS_COMPILING_MACOSX TRUE) + # We need to compile a universal protoc to not fail protobuf build + execute_process(COMMAND ./scripts/build_host_protoc.sh --other-flags "-DCMAKE_OSX_ARCHITECTURES=x86_64;arm64" + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + RESULT_VARIABLE BUILD_HOST_PROTOC_RESULT) + if(NOT BUILD_HOST_PROTOC_RESULT EQUAL "0") + message(FATAL_ERROR "Could not compile universal protoc.") + endif() + set(PROTOBUF_PROTOC_EXECUTABLE "${PROJECT_SOURCE_DIR}/build_host_protoc/bin/protoc") + set(CAFFE2_CUSTOM_PROTOC_EXECUTABLE "${PROJECT_SOURCE_DIR}/build_host_protoc/bin/protoc") +endif() + # ---[ Misc checks to cope with various compiler modes include(cmake/MiscCheck.cmake) @@ -484,12 +544,31 @@ if(USE_FBGEMM AND ((CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND CMAKE_SIZEOF_VO set(USE_FBGEMM OFF) endif() +if(USE_KINETO AND INTERN_BUILD_MOBILE) + message(STATUS "Not using libkineto in a mobile build.") + set(USE_KINETO OFF) +endif() + +if(USE_KINETO AND (NOT USE_CUDA)) + message(STATUS "Not using libkineto in a non-CUDA build.") + set(USE_KINETO OFF) +endif() + +if(USE_KINETO AND MSVC) + message(STATUS "Not using libkineto in a Windows build.") + set(USE_KINETO OFF) +endif() + include(cmake/Dependencies.cmake) if(USE_FBGEMM) string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM") endif() +if(USE_KINETO) + string(APPEND CMAKE_CXX_FLAGS " -DUSE_KINETO") +endif() + if(USE_QNNPACK) string(APPEND CMAKE_CXX_FLAGS " -DUSE_QNNPACK") endif() @@ -504,18 +583,27 @@ endif() if(USE_VULKAN) string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN") -endif() + string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_API") -if(USE_VULKAN_WRAPPER) - string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_WRAPPER") -endif() + if(USE_VULKAN_FP16_INFERENCE) + string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_FP16_INFERENCE") + endif() + + if(USE_VULKAN_RELAXED_PRECISION) + string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_RELAXED_PRECISION") + endif() + + if(USE_VULKAN_SHADERC_RUNTIME) + string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_SHADERC_RUNTIME") + endif() -if(USE_VULKAN_SHADERC_RUNTIME) - string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_SHADERC_RUNTIME") + if(USE_VULKAN_WRAPPER) + string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_WRAPPER") + endif() endif() -if(USE_VULKAN_RELAXED_PRECISION) - string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_RELAXED_PRECISION") +if(USE_PYTORCH_METAL) + string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_METAL") endif() # ---[ Allowlist file if allowlist is specified @@ -608,6 +696,10 @@ if(NOT MSVC) if(HAS_WERROR_FORMAT) string(APPEND CMAKE_CXX_FLAGS " -Werror=format") endif() + check_cxx_compiler_flag("-Werror=cast-function-type" HAS_WERROR_CAST_FUNCTION_TYPE) + if(HAS_WERROR_CAST_FUNCTION_TYPE) + string(APPEND CMAKE_CXX_FLAGS " -Werror=cast-function-type") + endif() endif() if(USE_ASAN) @@ -621,8 +713,8 @@ if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") int main() { float a[] = {1.0, 1.0}; float32x4x2_t v; - v.val[0] = vcombine_f32 (vcreate_f32 (__AARCH64_UINT64_C (0)), vcreate_f32 (__AARCH64_UINT64_C (0))); - v.val[1] = vcombine_f32 (vcreate_f32 (__AARCH64_UINT64_C (0)), vcreate_f32 (__AARCH64_UINT64_C (0))); + v.val[0] = vcombine_f32 (vcreate_f32 (0UL), vcreate_f32 (0UL)); + v.val[1] = vcombine_f32 (vcreate_f32 (0UL), vcreate_f32 (0UL)); vst1q_f32_x2(a, v); return 0; }" HAS_VST1) @@ -679,6 +771,8 @@ endif() if(ANDROID AND (NOT ANDROID_DEBUG_SYMBOLS)) if(CMAKE_COMPILER_IS_GNUCXX) string(APPEND CMAKE_CXX_FLAGS " -s") + elseif("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + string(APPEND CMAKE_CXX_FLAGS " -g0") else() string(APPEND CMAKE_EXE_LINKER_FLAGS " -s") endif() diff --git a/CODEOWNERS b/CODEOWNERS index 77b8d2cbcb365..a0e4814ce4cd8 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,20 +1,15 @@ # This is a comment. # Each line is a file pattern followed by one or more owners. -/docs/cpp @goldsborough @ebetica @yf225 -/torch/csrc/api/ @ebetica @goldsborough @yf225 -/test/cpp/api/ @ebetica @goldsborough @yf225 -/torch/utils/cpp_extension.py @goldsborough @fmassa @soumith @ezyang +/docs/cpp @glaringlee +/torch/csrc/api/ @glaringlee +/test/cpp/api/ @glaringlee +/torch/utils/cpp_extension.py @fmassa @soumith @ezyang # Not there to strictly require the approval, but to be tagged as a reviewer # on the PRs to push them into a high priority inbox. -/torch/csrc/api/data/ @apaszke -/torch/csrc/autograd/ @apaszke @albanD -/torch/csrc/jit/ @apaszke -/torch/nn/ @apaszke -/torch/autograd/ @apaszke @albanD -/torch/jit/ @apaszke -/torch/utils/data/ @apaszke +/torch/csrc/autograd/ @albanD +/torch/autograd/ @albanD # Tensorpipe RPC Agent. /torch/csrc/distributed/rpc/tensorpipe_agent.cpp @jiayisuse @osalpekar @lw @beauby @@ -23,9 +18,9 @@ # Distributed package # This list is mostly if you'd like to be tagged as reviewer, feel free to add # or remove yourself from it. -/torch/lib/c10d/ @pietern @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma -/torch/csrc/distributed/ @pietern @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma -/torch/distributed/ @apaszke @pietern @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma +/torch/lib/c10d/ @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @mingzhe09088 +/torch/csrc/distributed/ @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @mingzhe09088 +/torch/distributed/ @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @mingzhe09088 # Distributed tests # This list is mostly if you'd like to be tagged as reviewer, feel free to add diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 03ad14dd843e4..98e53c378564a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,10 +2,12 @@ - [Contributing to PyTorch](#contributing-to-pytorch) - [Developing PyTorch](#developing-pytorch) - - [Nightly Checkout & Pull](#nightly-checkout--pull) + - [Tips and Debugging](#tips-and-debugging) +- [Nightly Checkout & Pull](#nightly-checkout--pull) - [Codebase structure](#codebase-structure) - [Unit testing](#unit-testing) - [Better local unit tests with pytest](#better-local-unit-tests-with-pytest) + - [Running `mypy`](#running-mypy) - [Writing documentation](#writing-documentation) - [Building documentation](#building-documentation) - [Tips](#tips) @@ -118,11 +120,37 @@ For example: - modify your Python file `torch/__init__.py` - test functionality -You do not need to repeatedly install after modifying Python files. +You do not need to repeatedly install after modifying Python files (`.py`). However, you would need to reinstall +if you modify Python interface (`.pyi`, `.pyi.in`) or non-Python files (`.cpp`, `.cc`, `.cu`, `.h`, ...). In case you want to reinstall, make sure that you uninstall PyTorch first by running `pip uninstall torch` and `python setup.py clean`. Then you can install in `develop` mode again. +### Tips and Debugging +* A prerequisite to installing PyTorch is CMake. We recommend installing it with [Homebrew](https://brew.sh/) +with `brew install cmake` if you are developing on MacOS or Linux system. +* Our `setup.py` requires Python >= 3.6 +* If you run into errors when running `python setup.py develop`, here are some debugging steps: + 1. Run `printf '#include \nint main() { printf("Hello World");}'|clang -x c -; ./a.out` to make sure + your CMake works and can compile this simple Hello World program without errors. + 2. Nuke your `build` directory. The `setup.py` script compiles binaries into the `build` folder and caches many + details along the way, which saves time the next time you build. If you're running into issues, you can always + `rm -rf build` from the toplevel `pytorch` directory and start over. + 3. If you have made edits to the PyTorch repo, commit any change you'd like to keep and clean the repo with the + following commands (note that clean _really_ removes all untracked files and changes.): + ```bash + git submodule deinit -f . + git clean -xdf + python setup.py clean + git submodule update --init --recursive # very important to sync the submodules + python setup.py develop # then try running the command again + ``` + 4. The main step within `python setup.py develop` is running `make` from the `build` directory. If you want to + experiment with some environment variables, you can pass them into the command: + ```bash + ENV_KEY1=ENV_VAL1[, ENV_KEY2=ENV_VAL2]* python setup.py develop + ``` + ## Nightly Checkout & Pull The `tools/nightly.py` script is provided to ease pure Python development of @@ -265,6 +293,17 @@ pytest test/test_nn.py -k Loss -v The above is an example of testing a change to Loss functions: this command runs tests such as `TestNN.test_BCELoss` and `TestNN.test_MSELoss` and can be useful to save keystrokes. +### Running `mypy` + +One of the test suites runs `mypy` on the codebase: +```bash +python test/test_type_hints.py +``` +See [Guide for adding type annotations to +PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch) +for more information on how to set up `mypy` and tackle type annotation +tasks, as well as other ways to run `mypy` besides running that test suite. + ## Writing documentation PyTorch uses [Google style](http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) @@ -287,6 +326,8 @@ pip install -r requirements.txt # npm install -g katex # Or if you prefer an uncontaminated global executable environment or do not want to go through the node configuration: # npm install katex && export PATH="$PATH:$(pwd)/node_modules/.bin" +# If you're a Facebook employee using a devserver, yarn may be more convenient: +# yarn global add katex ``` 3. Generate the documentation HTML files. The generated files will be in `docs/build/html`. @@ -327,7 +368,7 @@ information on the documentation syntax. We run Doxygen in CI (Travis) to verify that you do not use invalid Doxygen commands. To run this check locally, run `./check-doxygen.sh` from inside -`docs/cpp`. +`docs/cpp/source`. To build the documentation, follow the same steps as above, but run them from `docs/cpp` instead of `docs`. @@ -352,6 +393,14 @@ et my_machine -t="8000:8000" Then navigate to `localhost:8000` in your web browser. +Alternatively, you can run `rsync` on your local machine to copy the files from +your remote machine: +```bash +mkdir -p build cpp/build +rsync -az me@my_machine:/path/to/pytorch/docs/build/html build +rsync -az me@my_machine:/path/to/pytorch/docs/cpp/build/html cpp/build +``` + #### Submitting changes for review It is helpful when submitting a PR that changes the docs to provide a rendered @@ -489,8 +538,7 @@ only interested in a specific component. - Working on a test binary? Run `(cd build && ninja bin/test_binary_name)` to rebuild only that test binary (without rerunning cmake). (Replace `ninja` with `make` if you don't have ninja installed). -- Don't need Caffe2? Pass `BUILD_CAFFE2_OPS=0` to disable build of - Caffe2 operators. +- Don't need Caffe2? Pass `BUILD_CAFFE2=0` to disable Caffe2 build. On the initial build, you can also speed things up with the environment variables `DEBUG`, `USE_DISTRIBUTED`, `USE_MKLDNN`, `USE_CUDA`, `BUILD_TEST`, `USE_FBGEMM`, `USE_NNPACK` and `USE_QNNPACK`. @@ -639,7 +687,7 @@ ccache -M 25Gi ``` To check this is working, do two clean builds of pytorch in a row. The second -build should be substantially and noticeably faster than the first build. +build should be substantially and noticeably faster than the first build. If this doesn't seem to be the case, check that each of the symlinks above actually link to your installation of `ccache`. For example, if you followed the first option and installed `ccache` from source on a Linux machine, running `readlink -e $(which g++)` should return `~/ccache/bin/ccache`. #### Use a faster linker @@ -719,7 +767,7 @@ than Linux, which are worth keeping in mind when fixing these problems. 1. Symbols are NOT exported by default on Windows; instead, you have to explicitly mark a symbol as exported/imported in a header file with `__declspec(dllexport)` / `__declspec(dllimport)`. We have codified this pattern into a set of macros - which follow the convention `*_API`, e.g., `CAFFE2_API` inside Caffe2 and ATen. + which follow the convention `*_API`, e.g., `TORCH_API` inside Caffe2, Aten and Torch. (Every separate shared library needs a unique macro name, because symbol visibility is on a per shared library basis. See c10/macros/Macros.h for more details.) @@ -856,7 +904,7 @@ which is in PyTorch's `requirements.txt`. ## Pre-commit tidy/linting hook We use clang-tidy and flake8 (installed with flake8-bugbear, -flake8-comprehensions, flake8-mypy, and flake8-pyi) to perform additional +flake8-comprehensions, flake8-pyi, and others) to perform additional formatting and semantic checking of code. We provide a pre-commit git hook for performing these checks, before a commit is created: @@ -868,6 +916,16 @@ You'll need to install an appropriately configured flake8; see [Lint as you type](https://github.com/pytorch/pytorch/wiki/Lint-as-you-type) for documentation on how to do this. +If you haven't set up the pre-commit hook and have already committed files and +CI reports `flake8` errors, you can run the check locally in your PR branch with: + + ```bash + flake8 $(git diff --name-only $(git merge-base --fork-point master)) + ``` + +fix the code so that no errors are reported when you re-run the above check again, +and then commit the fix. + ## Building PyTorch with ASAN [ASAN](https://github.com/google/sanitizers/wiki/AddressSanitizer) is very diff --git a/Dockerfile b/Dockerfile index 5bae3ec14ea6c..cbaa85597ad95 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,7 +8,7 @@ # For reference: # https://docs.docker.com/develop/develop-images/build_enhancements/ ARG BASE_IMAGE=ubuntu:18.04 -ARG PYTHON_VERSION=3.7 +ARG PYTHON_VERSION=3.8 FROM ${BASE_IMAGE} as dev-base RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ @@ -27,6 +27,7 @@ RUN mkdir /opt/ccache && ccache --set-config=cache_dir=/opt/ccache ENV PATH /opt/conda/bin:$PATH FROM dev-base as conda +ARG PYTHON_VERSION=3.8 RUN curl -fsSL -v -o ~/miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ chmod +x ~/miniconda.sh && \ ~/miniconda.sh -b -p /opt/conda && \ @@ -49,11 +50,16 @@ RUN --mount=type=cache,target=/opt/ccache \ python setup.py install FROM conda as conda-installs +ARG PYTHON_VERSION=3.8 +ARG CUDA_VERSION=11.0 ARG INSTALL_CHANNEL=pytorch-nightly -RUN /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -y pytorch torchvision cudatoolkit=11.0.221 && \ +ENV CONDA_OVERRIDE_CUDA=${CUDA_VERSION} +RUN /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -y python=${PYTHON_VERSION} pytorch torchvision torchtext "cudatoolkit=${CUDA_VERSION}" && \ /opt/conda/bin/conda clean -ya +RUN /opt/conda/bin/pip install torchelastic FROM ${BASE_IMAGE} as official +ARG PYTORCH_VERSION LABEL com.nvidia.volumes.needed="nvidia_driver" RUN --mount=type=cache,id=apt-final,target=/var/cache/apt \ apt-get update && apt-get install -y --no-install-recommends \ @@ -66,6 +72,7 @@ ENV PATH /opt/conda/bin:$PATH ENV NVIDIA_VISIBLE_DEVICES all ENV NVIDIA_DRIVER_CAPABILITIES compute,utility ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64 +ENV PYTORCH_VERSION ${PYTORCH_VERSION} WORKDIR /workspace FROM official as dev diff --git a/LICENSE b/LICENSE index 4167b929cc741..9cb8cbef5a9f8 100644 --- a/LICENSE +++ b/LICENSE @@ -16,23 +16,26 @@ Copyright (c) 2016-present, Facebook Inc. All rights reserved. All contributions by Facebook: Copyright (c) 2016 Facebook Inc. - + All contributions by Google: Copyright (c) 2015 Google Inc. All rights reserved. - + All contributions by Yangqing Jia: Copyright (c) 2015 Yangqing Jia All rights reserved. - + +All contributions by Kakao Brain: +Copyright 2019-2020 Kakao Brain + All contributions from Caffe: Copyright(c) 2013, 2014, 2015, the respective contributors All rights reserved. - + All other contributions: Copyright(c) 2015, 2016 the respective contributors All rights reserved. - + Caffe2 uses a copyright model similar to Caffe: each contributor holds copyright over their contributions to Caffe2. The project versioning records all such contribution and copyright details. If a contributor wants to further diff --git a/NOTICE b/NOTICE index a346cb891713d..5abaac479a752 100644 --- a/NOTICE +++ b/NOTICE @@ -22,6 +22,9 @@ All contributions by Yangqing Jia: Copyright (c) 2015 Yangqing Jia All rights reserved. +All contributions by Kakao Brain: +Copyright 2019-2020 Kakao Brain + All other contributions: Copyright(c) 2015, 2016 the respective contributors All rights reserved. @@ -281,6 +284,112 @@ Apache License Version 2.0: incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. +======================================================================= +Cephes's 3-Clause BSD License +======================================================================= + +Code derived from implementations in the Cephes Math Library should mention +its derivation and reference the following license: + + 3-Clause BSD License for the Cephes Math Library + Copyright (c) 2018, Steven Moshier + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +======================================================================= +SciPy's 3-Clause BSD License +======================================================================= + +Code derived from implementations in SciPy should mention its derivation +and reference the following license: + + Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +======================================================================= +Boost's 1.0 Software License +======================================================================= + +Code derived from implementations in Boost 1.0 should mention its +derivation and reference the following license: + + Boost Software License - Version 1.0 - August 17th, 2003 + + Permission is hereby granted, free of charge, to any person or organization + obtaining a copy of the software and accompanying documentation covered by + this license (the "Software") to use, reproduce, display, distribute, + execute, and transmit the Software, and to prepare derivative works of the + Software, and to permit third-parties to whom the Software is furnished to + do so, all subject to the following: + + The copyright notices in the Software and this entire statement, including + the above license grant, this restriction and the following disclaimer, + must be included in all copies of the Software, in whole or in part, and + all derivative works of the Software, unless such copies or derivative + works are solely in the form of machine-executable object code generated by + a source language processor. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT + SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE + FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + DEALINGS IN THE SOFTWARE. + END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. diff --git a/README.md b/README.md index 6191cabcb6854..d29eacc28664a 100644 --- a/README.md +++ b/README.md @@ -47,10 +47,10 @@ At a granular level, PyTorch is a library that consists of the following compone Usually, PyTorch is used either as: -- a replacement for NumPy to use the power of GPUs. -- a deep learning research platform that provides maximum flexibility and speed. +- A replacement for NumPy to use the power of GPUs. +- A deep learning research platform that provides maximum flexibility and speed. -Elaborating further: +Elaborating Further: ### A GPU-Ready Tensor Library @@ -149,7 +149,7 @@ They require JetPack 4.2 and above, and [@dusty-nv](https://github.com/dusty-nv) ### From Source -If you are installing from source, you will need Python 3.6 or later and a C++14 compiler. Also, we highly recommend installing an [Anaconda](https://www.anaconda.com/distribution/#download-section) environment. +If you are installing from source, you will need Python 3.6.2 or later and a C++14 compiler. Also, we highly recommend installing an [Anaconda](https://www.anaconda.com/distribution/#download-section) environment. You will get a high-quality BLAS library (MKL) and you get controlled dependency versions regardless of your Linux distro. Once you have [Anaconda](https://www.anaconda.com/distribution/#download-section) installed, here are the instructions. @@ -158,6 +158,7 @@ If you want to compile with CUDA support, install - [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 9.2 or above - [NVIDIA cuDNN](https://developer.nvidia.com/cudnn) v7 or above - [Compiler](https://gist.github.com/ax3l/9489132) compatible with CUDA +Note: You could refer to the [cuDNN Support Matrix](https://docs.nvidia.com/deeplearning/cudnn/pdf/cuDNN-Support-Matrix.pdf) for cuDNN versions with the various supported CUDA, CUDA driver and NVIDIA hardwares If you want to disable CUDA support, export environment variable `USE_CUDA=0`. Other potentially useful environment variables may be found in `setup.py`. @@ -175,7 +176,7 @@ conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_ex On Linux ```bash # Add LAPACK support for the GPU if needed -conda install -c pytorch magma-cuda102 # or [ magma-cuda101 | magma-cuda100 | magma-cuda92 ] depending on your cuda version +conda install -c pytorch magma-cuda110 # or the magma-cuda* that matches your CUDA version from https://anaconda.org/pytorch/repo ``` On MacOS @@ -184,6 +185,13 @@ On MacOS conda install pkg-config libuv ``` +On Windows +```bash +# Add these packages if torch.distributed is needed. +# Distributed package support on Windows is a prototype feature and is subject to changes. +conda install -c conda-forge libuv=1.39 +``` + #### Get the PyTorch Source ```bash git clone --recursive https://github.com/pytorch/pytorch @@ -200,6 +208,16 @@ export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} python setup.py install ``` +Note that if you are using [Anaconda](https://www.anaconda.com/distribution/#download-section), you may experience an error caused by the linker: + +```plaintext +build/temp.linux-x86_64-3.7/torch/csrc/stub.o: file not recognized: file format not recognized +collect2: error: ld returned 1 exit status +error: command 'g++' failed with exit status 1 +``` + +This is caused by `ld` from Conda environment shadowing the system `ld`. You should use a newer version of Python that fixes this issue. The recommended Python version is 3.6.10+, 3.7.6+ and 3.8.1+. + On macOS ```bash export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} @@ -216,44 +234,52 @@ Each CUDA version only supports one particular XCode version. The following comb On Windows -At least Visual Studio 2017 version 15.6 with the toolset 14.13 and [NVTX](https://docs.nvidia.com/gameworks/content/gameworkslibrary/nvtx/nvidia_tools_extension_library_nvtx.htm) are needed. +Build with CPU + +It's fairly easy to build with CPU. Visual Studio 2019 version 16.7.6 (MSVC toolchain version 14.27) or higher is recommended. -If the version of Visual Studio 2017 is higher than 15.6, installing of "VC++ 2017 version 15.6 v14.13 toolset" is strongly recommended. -
If the version of Visual Studio 2017 is lesser than 15.6, please update Visual Studio 2017 to the latest version along with installing "VC++ 2017 version 15.6 v14.13 toolset". -
There is no guarantee of the correct building with VC++ 2017 toolsets, others than version 15.6 v14.13. -
"VC++ 2017 version 15.6 v14.13 toolset" might be installed onto already installed Visual Studio 2017 by running its installation once again and checking the corresponding checkbox under "Individual components"/"Compilers, build tools, and runtimes". +Build with CUDA +[NVTX](https://docs.nvidia.com/gameworks/content/gameworkslibrary/nvtx/nvidia_tools_extension_library_nvtx.htm) is needed to build Pytorch with CUDA. NVTX is a part of CUDA distributive, where it is called "Nsight Compute". To install it onto already installed CUDA run CUDA installation once again and check the corresponding checkbox. -Be sure that CUDA with Nsight Compute is installed after Visual Studio 2017. +Make sure that CUDA with Nsight Compute is installed after Visual Studio. + +Currently, VS 2017 / 2019, and Ninja are supported as the generator of CMake. If `ninja.exe` is detected in `PATH`, then Ninja will be used as the default generator, otherwise, it will use VS 2017 / 2019. +
If Ninja is selected as the generator, the latest MSVC will get selected as the underlying toolchain. + +CUDA, MSVC, and PyTorch versions are interdependent; please install matching versions from this table: +| CUDA version | Newest supported VS version | PyTorch version | +| ------------ | ------------------------------------------------------- | --------------- | +| 9.2 | Visual Studio 2017 Update 5 (15.5) (`_MSC_VER` <= 1912) | 0.4.1 ~ 1.5.1 | +| 10.1 | Visual Studio 2019 (16.X) (`_MSC_VER` < 1930) | 1.3.0 ~ 1.7.0 | +| 10.2 | Visual Studio 2019 (16.X) (`_MSC_VER` < 1930) | 1.5.0 ~ 1.7.0 | +| 11.0 | Visual Studio 2019 (16.X) (`_MSC_VER` < 1930) | 1.7.0 | + +Note: There's a [compilation issue](https://github.com/oneapi-src/oneDNN/issues/812) in several Visual Studio 2019 versions since 16.7.1, so please make sure your Visual Studio 2019 version is not in 16.7.1 ~ 16.7.5 -Currently, VS 2017, VS 2019, and Ninja are supported as the generator of CMake. If `ninja.exe` is detected in `PATH`, then Ninja will be used as the default generator, otherwise, it will use VS 2017. -
If Ninja is selected as the generator, the latest MSVC which is newer than VS 2015 (14.0) will get selected as the underlying toolchain. If you use CMake <= 3.14.2 and has VS 2019 installed, then even if you specify VS 2017 as the generator, VS 2019 will get selected as the generator. +Additional libraries such as +[Magma](https://developer.nvidia.com/magma), [oneDNN, a.k.a MKLDNN or DNNL](https://github.com/oneapi-src/oneDNN), and [Sccache](https://github.com/mozilla/sccache) are often needed. Please refer to the [installation-helper](https://github.com/pytorch/pytorch/tree/master/.jenkins/pytorch/win-test-helpers/installation-helpers) to install them. -CUDA and MSVC have strong version dependencies, so even if you use VS 2017 / 2019, you will get build errors like `nvcc fatal : Host compiler targets unsupported OS`. For this kind of problem, please install the corresponding VS toolchain in the table below, and then you can either specify the toolset during activation (recommended) or set `CUDAHOSTCXX` to override the Cuda host compiler (not recommended if there are big version differences). +You can refer to the [build_pytorch.bat](https://github.com/pytorch/pytorch/blob/master/.jenkins/pytorch/win-test-helpers/build_pytorch.bat) script for some other environment variables configurations -| CUDA version | Newest supported VS version | -| ------------ | ------------------------------------------------------- | -| 9.2 | Visual Studio 2017 Update 5 (15.5) (`_MSC_VER` <= 1912) | -| 10.0 | Visual Studio 2017 (15.X) (`_MSC_VER` < 1920) | -| 10.1 | Visual Studio 2019 (16.X) (`_MSC_VER` < 1930) | ```cmd cmd -:: [Optional] If you want to build with VS 2019 generator, please change the value in the next line to `Visual Studio 16 2019`. +:: [Optional] If you want to build with the VS 2017 generator for old CUDA and PyTorch, please change the value in the next line to `Visual Studio 15 2017`. :: Note: This value is useless if Ninja is detected. However, you can force that by using `set USE_NINJA=OFF`. -set CMAKE_GENERATOR=Visual Studio 15 2017 +set CMAKE_GENERATOR=Visual Studio 16 2019 :: Read the content in the previous section carefully before you proceed. :: [Optional] If you want to override the underlying toolset used by Ninja and Visual Studio with CUDA, please run the following script block. -:: "Visual Studio 2017 Developer Command Prompt" will be run automatically. +:: "Visual Studio 2019 Developer Command Prompt" will be run automatically. :: Make sure you have CMake >= 3.12 before you do this when you use the Visual Studio generator. -set CMAKE_GENERATOR_TOOLSET_VERSION=14.11 +set CMAKE_GENERATOR_TOOLSET_VERSION=14.27 set DISTUTILS_USE_SDK=1 for /f "usebackq tokens=*" %i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -version [15^,16^) -products * -latest -property installationPath`) do call "%i\VC\Auxiliary\Build\vcvarsall.bat" x64 -vcvars_ver=%CMAKE_GENERATOR_TOOLSET_VERSION% -:: [Optional] If you want to override the Cuda host compiler -set CUDAHOSTCXX=C:\Program Files (x86)\Microsoft Visual Studio\2017\Enterprise\VC\Tools\MSVC\14.11.25503\bin\HostX64\x64\cl.exe +:: [Optional] If you want to override the CUDA host compiler +set CUDAHOSTCXX=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.27.29110\bin\HostX64\x64\cl.exe python setup.py install @@ -323,7 +349,7 @@ If you get a katex error run ```npm install katex```. If it persists, try ### Previous Versions Installation instructions and binaries for previous PyTorch versions may be found -on [our website](https://pytorch.org/previous-versions). +on [Our Website](https://pytorch.org/previous-versions). ## Getting Started @@ -348,12 +374,12 @@ Three-pointers to get you started: * [PyTorch YouTube](https://www.youtube.com/channel/UCWXI5YeOsh03QvJ59PMaXFw) ## Communication -* forums: discuss implementations, research, etc. https://discuss.pytorch.org -* GitHub issues: bug reports, feature requests, install issues, RFCs, thoughts, etc. +* Forums: Discuss implementations, research, etc. https://discuss.pytorch.org +* GitHub Issues: Bug reports, feature requests, install issues, RFCs, thoughts, etc. * Slack: The [PyTorch Slack](https://pytorch.slack.com/) hosts a primary audience of moderate to experienced PyTorch users and developers for general chat, online discussions, collaboration, etc. If you are a beginner looking for help, the primary medium is [PyTorch Forums](https://discuss.pytorch.org). If you need a slack invite, please fill this form: https://goo.gl/forms/PP1AGvNHpSaJP8to1 -* newsletter: no-noise, a one-way email newsletter with important announcements about PyTorch. You can sign-up here: https://eepurl.com/cbG0rv -* Facebook page: important announcements about PyTorch. https://www.facebook.com/pytorch -* for brand guidelines, please visit our website at [pytorch.org](https://pytorch.org/) +* Newsletter: No-noise, a one-way email newsletter with important announcements about PyTorch. You can sign-up here: https://eepurl.com/cbG0rv +* Facebook Page: Important announcements about PyTorch. https://www.facebook.com/pytorch +* For brand guidelines, please visit our website at [pytorch.org](https://pytorch.org/) ## Releases and Contributing @@ -373,8 +399,8 @@ PyTorch is a community-driven project with several skillful engineers and resear PyTorch is currently maintained by [Adam Paszke](https://apaszke.github.io/), [Sam Gross](https://github.com/colesbury), [Soumith Chintala](http://soumith.ch) and [Gregory Chanan](https://github.com/gchanan) with major contributions coming from hundreds of talented individuals in various forms and means. A non-exhaustive but growing list needs to mention: Trevor Killeen, Sasank Chilamkurthy, Sergey Zagoruyko, Adam Lerer, Francisco Massa, Alykhan Tejani, Luca Antiga, Alban Desmaison, Andreas Koepf, James Bradbury, Zeming Lin, Yuandong Tian, Guillaume Lample, Marat Dukhan, Natalia Gimelshein, Christian Sarofeen, Martin Raison, Edward Yang, Zachary Devito. -Note: this project is unrelated to [hughperkins/pytorch](https://github.com/hughperkins/pytorch) with the same name. Hugh is a valuable contributor to the Torch community and has helped with many things Torch and PyTorch. +Note: This project is unrelated to [hughperkins/pytorch](https://github.com/hughperkins/pytorch) with the same name. Hugh is a valuable contributor to the Torch community and has helped with many things Torch and PyTorch. ## License -PyTorch is a BSD-style licensed, as found in the [LICENSE](LICENSE) file. +PyTorch has a BSD-style license, as found in the [LICENSE](LICENSE) file. diff --git a/android/README.md b/android/README.md index bf5fa02e6cf47..e67b2e6ec0711 100644 --- a/android/README.md +++ b/android/README.md @@ -15,8 +15,8 @@ repositories { } dependencies { - implementation 'org.pytorch:pytorch_android:1.5.0' - implementation 'org.pytorch:pytorch_android_torchvision:1.5.0' + implementation 'org.pytorch:pytorch_android:1.6.0' + implementation 'org.pytorch:pytorch_android_torchvision:1.6.0' } ``` @@ -34,12 +34,12 @@ repositories { dependencies { ... - implementation 'org.pytorch:pytorch_android:1.7.0-SNAPSHOT' - implementation 'org.pytorch:pytorch_android_torchvision:1.7.0-SNAPSHOT' + implementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT' + implementation 'org.pytorch:pytorch_android_torchvision:1.8.0-SNAPSHOT' ... } ``` -The current nightly(snapshots) version is the value of `VERSION_NAME` in `gradle.properties` in current folder, at this moment it is `1.7.0-SNAPSHOT`. +The current nightly(snapshots) version is the value of `VERSION_NAME` in `gradle.properties` in current folder, at this moment it is `1.8.0-SNAPSHOT`. ## Building PyTorch Android from Source diff --git a/android/gradle.properties b/android/gradle.properties index 6e0dc0ac86b04..0ab42c56396d0 100644 --- a/android/gradle.properties +++ b/android/gradle.properties @@ -1,6 +1,6 @@ ABI_FILTERS=armeabi-v7a,arm64-v8a,x86,x86_64 -VERSION_NAME=1.7.0-SNAPSHOT +VERSION_NAME=1.8.0-SNAPSHOT GROUP=org.pytorch MAVEN_GROUP=org.pytorch POM_URL=https://github.com/pytorch/pytorch/tree/master/android diff --git a/android/gradle/android_tasks.gradle b/android/gradle/android_tasks.gradle index ca188ac72d078..0d5932559e470 100644 --- a/android/gradle/android_tasks.gradle +++ b/android/gradle/android_tasks.gradle @@ -1,4 +1,3 @@ - import java.nio.file.Files import java.nio.file.Paths import java.io.FileOutputStream diff --git a/android/libs/fbjni b/android/libs/fbjni index f908b58be4828..b592c5591345a 160000 --- a/android/libs/fbjni +++ b/android/libs/fbjni @@ -1 +1 @@ -Subproject commit f908b58be482874137fa4c0e71333e4eca481706 +Subproject commit b592c5591345a05341ed6cd31d214e71e8bf4229 diff --git a/android/pytorch_android/CMakeLists.txt b/android/pytorch_android/CMakeLists.txt index f81b1bf05527b..290d5aba93034 100644 --- a/android/pytorch_android/CMakeLists.txt +++ b/android/pytorch_android/CMakeLists.txt @@ -1,5 +1,8 @@ cmake_minimum_required(VERSION 3.4.1) project(pytorch_jni CXX) + +include(GNUInstallDirs) + set(CMAKE_CXX_STANDARD 14) set(CMAKE_VERBOSE_MAKEFILE ON) message(STATUS "ANDROID_STL:${ANDROID_STL}") @@ -68,8 +71,8 @@ target_compile_options(pytorch_jni PRIVATE -fexceptions ) -target_include_directories(pytorch_jni PUBLIC - ${libtorch_include_DIR} +target_include_directories(pytorch_jni BEFORE + PUBLIC $ ) set(fbjni_DIR ${CMAKE_CURRENT_LIST_DIR}/../libs/fbjni/) @@ -128,13 +131,26 @@ else() torch torch_cpu c10 - nnpack - XNNPACK - pytorch_qnnpack - pthreadpool cpuinfo clog ) + + if(USE_NNPACK) + list(APPEND pytorch_jni_LIBS nnpack) + endif() + + if(USE_XNNPACK) + list(APPEND pytorch_jni_LIBS XNNPACK) + endif() + + if(USE_SYSTEM_PTHREADPOOL) + list(APPEND pytorch_jni_LIBS pthreadpool) + endif() + + if(USE_PYTORCH_QNNPACK) + list(APPEND pytorch_jni_LIBS pytorch_qnnpack) + endif() + endif() if(USE_VULKAN) @@ -142,3 +158,13 @@ if(USE_VULKAN) endif() target_link_libraries(pytorch_jni ${pytorch_jni_LIBS}) + +install(TARGETS pytorch_jni + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) #For windows + +if(MSVC) + install(FILES $ DESTINATION ${CMAKE_INSTALL_LIBDIR} OPTIONAL) + install(TARGETS pytorch_jni DESTINATION ${CMAKE_INSTALL_LIBDIR}) +endif() diff --git a/android/pytorch_android/generate_test_torchscripts.py b/android/pytorch_android/generate_test_torchscripts.py index 6384d588e9aa0..8b41fefc246eb 100644 --- a/android/pytorch_android/generate_test_torchscripts.py +++ b/android/pytorch_android/generate_test_torchscripts.py @@ -20,92 +20,77 @@ def forward(self, input): return None @torch.jit.script_method - def eqBool(self, input): - # type: (bool) -> bool + def eqBool(self, input: bool) -> bool: return input @torch.jit.script_method - def eqInt(self, input): - # type: (int) -> int + def eqInt(self, input: int) -> int: return input @torch.jit.script_method - def eqFloat(self, input): - # type: (float) -> float + def eqFloat(self, input: float) -> float: return input @torch.jit.script_method - def eqStr(self, input): - # type: (str) -> str + def eqStr(self, input: str) -> str: return input @torch.jit.script_method - def eqTensor(self, input): - # type: (Tensor) -> Tensor + def eqTensor(self, input: Tensor) -> Tensor: return input @torch.jit.script_method - def eqDictStrKeyIntValue(self, input): - # type: (Dict[str, int]) -> Dict[str, int] + def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]: return input @torch.jit.script_method - def eqDictIntKeyIntValue(self, input): - # type: (Dict[int, int]) -> Dict[int, int] + def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]: return input @torch.jit.script_method - def eqDictFloatKeyIntValue(self, input): - # type: (Dict[float, int]) -> Dict[float, int] + def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]: return input @torch.jit.script_method - def listIntSumReturnTuple(self, input): - # type: (List[int]) -> Tuple[List[int], int] + def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]: sum = 0 for x in input: sum += x return (input, sum) @torch.jit.script_method - def listBoolConjunction(self, input): - # type: (List[bool]) -> bool + def listBoolConjunction(self, input: List[bool]) -> bool: res = True for x in input: res = res and x return res @torch.jit.script_method - def listBoolDisjunction(self, input): - # type: (List[bool]) -> bool + def listBoolDisjunction(self, input: List[bool]) -> bool: res = False for x in input: res = res or x return res @torch.jit.script_method - def tupleIntSumReturnTuple(self, input): - # type: (Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int] + def tupleIntSumReturnTuple(self, input: Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]: sum = 0 for x in input: sum += x return (input, sum) @torch.jit.script_method - def optionalIntIsNone(self, input): - # type: (Optional[int]) -> bool + def optionalIntIsNone(self, input: Optional[int]) -> bool: return input is None @torch.jit.script_method - def intEq0None(self, input): - # type: (int) -> Optional[int] + def intEq0None(self, input: int) -> Optional[int]: if input == 0: return None return input @torch.jit.script_method - def str3Concat(self, input): - # type: (str) -> str + def str3Concat(self, input: str) -> str: return input + input + input @torch.jit.script_method @@ -113,8 +98,7 @@ def newEmptyShapeWithItem(self, input): return torch.tensor([int(input.item())])[0] @torch.jit.script_method - def testAliasWithOffset(self): - # type: () -> List[Tensor] + def testAliasWithOffset(self) -> List[Tensor]: x = torch.tensor([100, 200]) a = [x[0], x[1]] return a @@ -128,8 +112,7 @@ def testNonContiguous(self): return x @torch.jit.script_method - def conv2d(self, x, w, toChannelsLast): - # type: (Tensor, Tensor, bool) -> Tensor + def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor: r = torch.nn.functional.conv2d(x, w) if (toChannelsLast): r = r.contiguous(memory_format=torch.channels_last) @@ -138,18 +121,15 @@ def conv2d(self, x, w, toChannelsLast): return r @torch.jit.script_method - def contiguous(self, x): - # type: (Tensor) -> Tensor + def contiguous(self, x: Tensor) -> Tensor: return x.contiguous() @torch.jit.script_method - def contiguousChannelsLast(self, x): - # type: (Tensor) -> Tensor + def contiguousChannelsLast(self, x: Tensor) -> Tensor: return x.contiguous(memory_format=torch.channels_last) @torch.jit.script_method - def contiguousChannelsLast3d(self, x): - # type: (Tensor) -> Tensor + def contiguousChannelsLast3d(self, x: Tensor) -> Tensor: return x.contiguous(memory_format=torch.channels_last_3d) scriptAndSave(Test(), "test.pt") diff --git a/android/pytorch_android/host/build.gradle b/android/pytorch_android/host/build.gradle index a808ae882ce4e..fe30660929b92 100644 --- a/android/pytorch_android/host/build.gradle +++ b/android/pytorch_android/host/build.gradle @@ -38,4 +38,3 @@ dependencies { } apply from: rootProject.file('gradle/release.gradle') - diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp index 11696daf43a20..fed6170c2bf30 100644 --- a/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp +++ b/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp @@ -18,6 +18,17 @@ namespace pytorch_jni { +c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode) { + if (deviceJniCode == kDeviceCPU) { + return at::kCPU; + } else if (deviceJniCode == kDeviceVulkan) { + return at::kVulkan; + } + + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, "Unknown device"); +} + bool Trace::is_initialized_ = false; #if defined(TRACE_ENABLED) && defined(__ANDROID__) diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni_common.h b/android/pytorch_android/src/main/cpp/pytorch_jni_common.h index fb974d4ad702e..9b4e7e5f84a10 100644 --- a/android/pytorch_android/src/main/cpp/pytorch_jni_common.h +++ b/android/pytorch_android/src/main/cpp/pytorch_jni_common.h @@ -1,3 +1,5 @@ +#pragma once + #include #include @@ -18,6 +20,11 @@ namespace pytorch_jni { +constexpr static int kDeviceCPU = 1; +constexpr static int kDeviceVulkan = 2; + +c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode); + class Trace { public: #if defined(TRACE_ENABLED) && defined(__ANDROID__) diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp index b05c19665f20b..9cc71f117d935 100644 --- a/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp +++ b/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp @@ -67,33 +67,36 @@ class PytorchJni : public facebook::jni::HybridClass { private: friend HybridBase; torch::jit::Module module_; + c10::DeviceType deviceType_; public: constexpr static auto kJavaDescriptor = "Lorg/pytorch/NativePeer;"; static facebook::jni::local_ref initHybrid( facebook::jni::alias_ref, - facebook::jni::alias_ref modelPath) { - return makeCxxInstance(modelPath); + facebook::jni::alias_ref modelPath, + jint device) { + return makeCxxInstance(modelPath, device); } #ifdef __ANDROID__ static facebook::jni::local_ref initHybridAndroidAsset( facebook::jni::alias_ref, facebook::jni::alias_ref assetName, - facebook::jni::alias_ref assetManager) { - return makeCxxInstance(assetName, assetManager); + facebook::jni::alias_ref assetManager, + jint device) { + return makeCxxInstance(assetName, assetManager, device); } #endif #ifdef TRACE_ENABLED - static bool onFunctionEnter( + static std::unique_ptr onFunctionEnter( const at::RecordFunction& fn) { Trace::beginSection(fn.name().str()); - return true; + return nullptr; } - static void onFunctionExit(const at::RecordFunction&) { + static void onFunctionExit(const at::RecordFunction&, at::ObserverContext*) { Trace::endSection(); } #endif @@ -127,17 +130,19 @@ class PytorchJni : public facebook::jni::HybridClass { ((void)once); } - PytorchJni(facebook::jni::alias_ref modelPath) { + PytorchJni(facebook::jni::alias_ref modelPath, jint device) { preModuleLoadSetup(); JITCallGuard guard; module_ = torch::jit::load(std::move(modelPath->toStdString())); module_.eval(); + deviceType_ = deviceJniCodeToDeviceType(device); } #ifdef __ANDROID__ PytorchJni( facebook::jni::alias_ref assetName, - facebook::jni::alias_ref assetManager) { + facebook::jni::alias_ref assetManager, + jint device) { preModuleLoadSetup(); JNIEnv* env = facebook::jni::Environment::current(); AAssetManager* mgr = AAssetManager_fromJava(env, assetManager.get()); @@ -166,6 +171,7 @@ class PytorchJni : public facebook::jni::HybridClass { assetBuffer, AAsset_getLength(asset))); AAsset_close(asset); module_.eval(); + deviceType_ = deviceJniCodeToDeviceType(device); } #endif @@ -191,7 +197,14 @@ class PytorchJni : public facebook::jni::HybridClass { inputs.reserve(n); for (size_t i = 0; i < n; i++) { at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i)); - inputs.push_back(std::move(atIValue)); + if (at::kVulkan == deviceType_) { + inputs.push_back( + atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()} + : std::move(atIValue)); + } else { + TORCH_CHECK(at::kCPU == deviceType_); + inputs.push_back(std::move(atIValue)); + } } auto output = [&]() { JITCallGuard guard; @@ -212,7 +225,14 @@ class PytorchJni : public facebook::jni::HybridClass { inputs.reserve(n); for (size_t i = 0; i < n; i++) { at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i)); - inputs.push_back(std::move(atIValue)); + if (at::kVulkan == deviceType_) { + inputs.push_back( + atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()} + : std::move(atIValue)); + } else { + TORCH_CHECK(at::kCPU == deviceType_); + inputs.push_back(std::move(atIValue)); + } } if (auto method = module_.find_method(methodName)) { auto output = [&]() { diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni_lite.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni_lite.cpp index 061b85221fe9e..8a96e395f267a 100644 --- a/android/pytorch_android/src/main/cpp/pytorch_jni_lite.cpp +++ b/android/pytorch_android/src/main/cpp/pytorch_jni_lite.cpp @@ -30,9 +30,6 @@ struct LiteJITCallGuard { } // namespace class PytorchJni : public facebook::jni::HybridClass { - constexpr static int kDeviceCPU = 1; - constexpr static int kDeviceVulkan = 2; - private: friend HybridBase; torch::jit::mobile::Module module_; @@ -51,15 +48,7 @@ class PytorchJni : public facebook::jni::HybridClass { PytorchJni(facebook::jni::alias_ref modelPath, jint device) { LiteJITCallGuard guard; module_ = torch::jit::_load_for_mobile(std::move(modelPath->toStdString())); - if (device == kDeviceCPU) { - deviceType_ = at::kCPU; - } else if (device == kDeviceVulkan) { - deviceType_ = at::kVulkan; - } else { - facebook::jni::throwNewJavaException( - facebook::jni::gJavaLangIllegalArgumentException, - "Unknown device specified"); - } + deviceType_ = deviceJniCodeToDeviceType(device); } static void registerNatives() { @@ -108,7 +97,14 @@ class PytorchJni : public facebook::jni::HybridClass { inputs.reserve(n); for (size_t i = 0; i < n; i++) { at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i)); - inputs.push_back(std::move(atIValue)); + if (at::kVulkan == deviceType_) { + inputs.push_back( + atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()} + : std::move(atIValue)); + } else { + TORCH_CHECK(at::kCPU == deviceType_); + inputs.push_back(std::move(atIValue)); + } } if (auto method = module_.find_method(methodName)) { auto output = [&]() { diff --git a/android/pytorch_android/src/main/java/org/pytorch/Module.java b/android/pytorch_android/src/main/java/org/pytorch/Module.java index 9dafc687f9938..62db7042d57bf 100644 --- a/android/pytorch_android/src/main/java/org/pytorch/Module.java +++ b/android/pytorch_android/src/main/java/org/pytorch/Module.java @@ -11,16 +11,28 @@ public class Module { private INativePeer mNativePeer; /** - * Loads a serialized TorchScript module from the specified path on the disk. + * Loads a serialized TorchScript module from the specified path on the disk to run on specified + * device. * * @param modelPath path to file that contains the serialized TorchScript module. + * @param device {@link org.pytorch.Device} to use for running specified module. * @return new {@link org.pytorch.Module} object which owns torch::jit::Module. */ - public static Module load(final String modelPath) { + public static Module load(final String modelPath, final Device device) { if (!NativeLoader.isInitialized()) { NativeLoader.init(new SystemDelegate()); } - return new Module(new NativePeer(modelPath)); + return new Module(new NativePeer(modelPath, device)); + } + + /** + * Loads a serialized TorchScript module from the specified path on the disk to run on CPU. + * + * @param modelPath path to file that contains the serialized TorchScript module. + * @return new {@link org.pytorch.Module} object which owns torch::jit::Module. + */ + public static Module load(final String modelPath) { + return load(modelPath, Device.CPU); } Module(INativePeer nativePeer) { diff --git a/android/pytorch_android/src/main/java/org/pytorch/NativePeer.java b/android/pytorch_android/src/main/java/org/pytorch/NativePeer.java index 5c6ef31061ae2..76c0c6226755b 100644 --- a/android/pytorch_android/src/main/java/org/pytorch/NativePeer.java +++ b/android/pytorch_android/src/main/java/org/pytorch/NativePeer.java @@ -13,18 +13,23 @@ class NativePeer implements INativePeer { private final HybridData mHybridData; @DoNotStrip - private static native HybridData initHybrid(String moduleAbsolutePath); + private static native HybridData initHybrid(String moduleAbsolutePath, int deviceJniCode); @DoNotStrip private static native HybridData initHybridAndroidAsset( - String assetName, /* android.content.res.AssetManager */ Object androidAssetManager); + String assetName, /* android.content.res.AssetManager */ + Object androidAssetManager, + int deviceJniCode); - NativePeer(String moduleAbsolutePath) { - mHybridData = initHybrid(moduleAbsolutePath); + NativePeer(String moduleAbsolutePath, Device device) { + mHybridData = initHybrid(moduleAbsolutePath, device.jniCode); } - NativePeer(String assetName, /* android.content.res.AssetManager */ Object androidAssetManager) { - mHybridData = initHybridAndroidAsset(assetName, androidAssetManager); + NativePeer( + String assetName, /* android.content.res.AssetManager */ + Object androidAssetManager, + Device device) { + mHybridData = initHybridAndroidAsset(assetName, androidAssetManager, device.jniCode); } public void resetNative() { diff --git a/android/pytorch_android/src/main/java/org/pytorch/PyTorchAndroid.java b/android/pytorch_android/src/main/java/org/pytorch/PyTorchAndroid.java index 15664dd040eaf..b775c2bb2e2c6 100644 --- a/android/pytorch_android/src/main/java/org/pytorch/PyTorchAndroid.java +++ b/android/pytorch_android/src/main/java/org/pytorch/PyTorchAndroid.java @@ -21,9 +21,14 @@ public final class PyTorchAndroid { * *

This method is meant to use in tests and demos. */ + public static Module loadModuleFromAsset( + final AssetManager assetManager, final String assetName, final Device device) { + return new Module(new NativePeer(assetName, assetManager, device)); + } + public static Module loadModuleFromAsset( final AssetManager assetManager, final String assetName) { - return new Module(new NativePeer(assetName, assetManager)); + return new Module(new NativePeer(assetName, assetManager, Device.CPU)); } /** diff --git a/android/pytorch_android/test_asset.jit b/android/pytorch_android/test_asset.jit index 49a41eff36a6b..3bd9037da4ee6 100644 --- a/android/pytorch_android/test_asset.jit +++ b/android/pytorch_android/test_asset.jit @@ -1,85 +1,69 @@ def forward(self, input): return None -def eqBool(self, input): - # type: (bool) -> bool +def eqBool(self, input: bool) -> bool: return input -def eqInt(self, input): - # type: (int) -> int +def eqInt(self, input: int) -> int: return input -def eqFloat(self, input): - # type: (float) -> float +def eqFloat(self, input: float) -> float: return input -def eqStr(self, input): - # type: (str) -> str +def eqStr(self, input: str) -> str: return input -def eqTensor(self, input): - # type: (Tensor) -> Tensor +def eqTensor(self, input: Tensor) -> Tensor: return input -def eqDictStrKeyIntValue(self, input): - # type: (Dict[str, int]) -> Dict[str, int] +def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]: return input -def eqDictIntKeyIntValue(self, input): - # type: (Dict[int, int]) -> Dict[int, int] +def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]: return input -def eqDictFloatKeyIntValue(self, input): - # type: (Dict[float, int]) -> Dict[float, int] +def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]: return input -def listIntSumReturnTuple(self, input): - # type: (List[int]) -> Tuple[List[int], int] +def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]: sum = 0 for x in input: sum += x return (input, sum) -def listBoolConjunction(self, input): - # type: (List[bool]) -> bool +def listBoolConjunction(self, input: List[bool]) -> bool: res = True for x in input: res = res and x return res -def listBoolDisjunction(self, input): - # type: (List[bool]) -> bool +def listBoolDisjunction(self, input: List[bool]) -> bool: res = False for x in input: res = res or x return res -def tupleIntSumReturnTuple(self, input): - # type: (Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int] +def tupleIntSumReturnTuple(self, input: Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]: sum = 0 for x in input: sum += x return (input, sum) -def optionalIntIsNone(self, input): - # type: (Optional[int]) -> bool +def optionalIntIsNone(self, input: Optional[int]) -> bool: return input is None -def intEq0None(self, input): - # type: (int) -> Optional[int] +def intEq0None(self, input: int) -> Optional[int]: if input == 0: return None return input -def str3Concat(self, input): - # type: (str) -> str +def str3Concat(self, input: str) -> str: return input + input + input def newEmptyShapeWithItem(self, input): return torch.tensor([int(input.item())])[0] -def testAliasWithOffset(self): - # type: () -> List[Tensor] +def testAliasWithOffset(self) -> List[Tensor]: x = torch.tensor([100, 200]) a = [x[0], x[1]] return a @@ -91,8 +75,7 @@ def testNonContiguous(self): assert x[1] == 300 return x -def conv2d(self, x, w, toChannelsLast): - # type: (Tensor, Tensor, bool) -> Tensor +def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor: r = torch.conv2d(x, w) if (toChannelsLast): # memory_format=torch.channels_last @@ -101,16 +84,13 @@ def conv2d(self, x, w, toChannelsLast): r = r.contiguous() return r -def contiguous(self, x): - # type: (Tensor) -> Tensor +def contiguous(self, x: Tensor) -> Tensor: return x.contiguous() -def contiguousChannelsLast(self, x): - # type: (Tensor) -> Tensor +def contiguousChannelsLast(self, x: Tensor) -> Tensor: # memory_format=torch.channels_last return x.contiguous(memory_format=2) -def contiguousChannelsLast3d(self, x): - # type: (Tensor) -> Tensor +def contiguousChannelsLast3d(self, x: Tensor) -> Tensor: # memory_format=torch.channels_last_3d return x.contiguous(memory_format=3) diff --git a/android/settings.gradle b/android/settings.gradle index 09473fa342812..743f388b65075 100644 --- a/android/settings.gradle +++ b/android/settings.gradle @@ -4,4 +4,3 @@ project(':pytorch_android_torchvision').projectDir = file('pytorch_android_torch project(':pytorch_host').projectDir = file('pytorch_android/host') project(':test_app').projectDir = file('test_app/app') - diff --git a/android/test_app/app/build.gradle b/android/test_app/app/build.gradle index c592728ce9f4e..df7b758e3b31c 100644 --- a/android/test_app/app/build.gradle +++ b/android/test_app/app/build.gradle @@ -40,6 +40,7 @@ android { buildConfigField("String", "LOGCAT_TAG", "@string/app_name") buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}") buildConfigField("boolean", "NATIVE_BUILD", 'false') + buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false') addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"]) } buildTypes { @@ -59,16 +60,24 @@ android { //} flavorDimensions "model", "build", "activity" productFlavors { - mbq { + mnet { dimension "model" - applicationIdSuffix ".mbq" - buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2q.pt\"") - addManifestPlaceholders([APP_NAME: "MBQ"]) - buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbq\"") + applicationIdSuffix ".mnet" + buildConfigField("String", "MODULE_ASSET_NAME", "\"mnet.pt\"") + addManifestPlaceholders([APP_NAME: "MNET"]) + buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet\"") + } + mnetVulkan { + dimension "model" + applicationIdSuffix ".mnet_vulkan" + buildConfigField("String", "MODULE_ASSET_NAME", "\"mnet_vulkan.pt\"") + buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true') + addManifestPlaceholders([APP_NAME: "MNET_VULKAN"]) + buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet-vulkan\"") } resnet18 { dimension "model" - applicationIdSuffix ".resneti18" + applicationIdSuffix ".resnet18" buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.pt\"") addManifestPlaceholders([APP_NAME: "RN18"]) buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"") @@ -122,7 +131,7 @@ android { tasks.all { task -> // Disable externalNativeBuild for all but nativeBuild variant - if (task.name.startsWith('externalNativeBuild') + if (task.name.startsWith('externalNativeBuild') && !task.name.contains('NativeBuild')) { task.enabled = false } @@ -140,8 +149,8 @@ dependencies { //nativeBuildImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar') //extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar') - nightlyImplementation 'org.pytorch:pytorch_android:1.7.0-SNAPSHOT' - nightlyImplementation 'org.pytorch:pytorch_android_torchvision:1.7.0-SNAPSHOT' + nightlyImplementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT' + nightlyImplementation 'org.pytorch:pytorch_android_torchvision:1.8.0-SNAPSHOT' aarImplementation(name:'pytorch_android', ext:'aar') aarImplementation(name:'pytorch_android_torchvision', ext:'aar') diff --git a/android/test_app/app/src/main/AndroidManifest.xml b/android/test_app/app/src/main/AndroidManifest.xml index a83bf223bdaf2..abdd9a8d986ae 100644 --- a/android/test_app/app/src/main/AndroidManifest.xml +++ b/android/test_app/app/src/main/AndroidManifest.xml @@ -18,4 +18,10 @@ + + + + diff --git a/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java b/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java index 5cc233011c8a2..bd7469950f875 100644 --- a/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java +++ b/android/test_app/app/src/main/java/org/pytorch/testapp/MainActivity.java @@ -17,6 +17,7 @@ import java.io.InputStream; import java.io.OutputStream; import java.nio.FloatBuffer; +import org.pytorch.Device; import org.pytorch.IValue; import org.pytorch.Module; import org.pytorch.PyTorchAndroid; @@ -126,7 +127,11 @@ protected Result doModuleForward() { mInputTensorBuffer = Tensor.allocateFloatBuffer((int) numElements); mInputTensor = Tensor.fromBlob(mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE); PyTorchAndroid.setNumThreads(1); - mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME); + mModule = + BuildConfig.USE_VULKAN_DEVICE + ? PyTorchAndroid.loadModuleFromAsset( + getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN) + : PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME); } final long startTime = SystemClock.elapsedRealtime(); diff --git a/aten/conda/meta.yaml b/aten/conda/meta.yaml index d8096fc73a0fa..a502690a5447a 100644 --- a/aten/conda/meta.yaml +++ b/aten/conda/meta.yaml @@ -24,7 +24,7 @@ requirements: - mkl # [not osx] about: - home: https://github.com/zdevito/ATen + home: https://github.com/pytorch/pytorch license: BSD summary: A TENsor library for C++14 diff --git a/aten/src/ATen/ATen.h b/aten/src/ATen/ATen.h index ae95ef43f21c2..8d29a92044203 100644 --- a/aten/src/ATen/ATen.h +++ b/aten/src/ATen/ATen.h @@ -31,3 +31,4 @@ #include #include #include +#include diff --git a/aten/src/ATen/BatchedFallback.cpp b/aten/src/ATen/BatchedFallback.cpp index 1d214081b43e2..c06c1579b28d9 100644 --- a/aten/src/ATen/BatchedFallback.cpp +++ b/aten/src/ATen/BatchedFallback.cpp @@ -1,6 +1,7 @@ #include #include #include +#include namespace at { @@ -34,6 +35,195 @@ static bool areAnyArgumentsTensorList(const FunctionSchema& schema) { [] (const Argument& arg) { return arg.type()->isSubtypeOf(ListType::ofTensors()); }); } +// Returns if an operator is in-place. An operator is inplace if: +// 1. The first argument is a Tensor and it is being written to +// 2. The first argument is being returned +// 3. No other arguments are aliased +// Here is an example of an in-place operator: +// add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) +static bool isInplaceOp(const c10::FunctionSchema& schema) { + if (!schema.is_mutable() || schema.returns().size() != 1) { + return false; + } + // Check that the first argument is being written to + const auto& first_arg_alias_info = schema.arguments().begin()->alias_info(); + if (!first_arg_alias_info || !first_arg_alias_info.value().isWrite()) { + return false; + } + // Check that none of the other args are being aliased + for (auto it = schema.arguments().begin() + 1; it != schema.arguments().end(); ++it) { + const auto& alias_info = it->alias_info(); + if (alias_info) { + return false; + } + } + // Check that the first tensor is being returned (i.e., output has a (a!)) + const auto& return_alias_info = schema.returns()[0].alias_info(); + return return_alias_info && return_alias_info.value().isWrite(); +} + +static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace) { + auto uses_stack = is_inplace ? "" : " and stack"; + TORCH_WARN("Batching rule not implemented for ", schema.operator_name(), " falling back " + "to slow (for loop", uses_stack, ") implementation"); +} + +// The general flow of the algorithm is as follows. +// - First, we figure out which arguments are BatchedTensors and save them +// to a vector. We also store a vector of which index of the arguments list +// each BatchedTensor appears in. This will be useful for bookkeeping later. +// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors. +// This returns a vector of VmapPhysicalView that hold tensors that contain +// all of the collective batch dimensions at the front of the tensors. +// - Then, we attempt to call `op` once per slice of the inputs. To do this, +// we repeatedly we slice the input arguments (if they are BatchedTensors), +// put the sliced (or a not-sliced) version of the input onto the stack, invoke +// the operator, and then pop the results off the stack. +void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + const auto num_returns = schema.returns().size(); + warnFallback(schema, /*in_place*/true); + + const auto num_arguments = schema.arguments().size(); + const auto arguments = torch::jit::last(stack, num_arguments); + const auto arguments_begin = stack->size() - num_arguments; + + // `self` is the Tensor being modified in-place + Tensor self = arguments[0].toTensor(); + const auto* self_impl = maybeGetBatchedImpl(self); + std::bitset self_vmap_levels; + if (self_impl) { + self_vmap_levels = createVmapLevelsBitset(self_impl->bdims()); + } + + // Figure out which arguments are BatchedTensor. Save them to a vector. + // For each BatchedTensor, also record what position of `arguments` they came from. + SmallVector batched_tensor_inputs; + VmapDimVector batched_tensor_inputs_position; + for (int64_t idx = 0; idx < arguments.size(); ++idx) { + const auto& ivalue = arguments[idx]; + if (!ivalue.isTensor()) { + continue; + } + const auto& tensor = ivalue.toTensor(); + if (!tensor.defined()) { + continue; + } + const auto* batched = maybeGetBatchedImpl(tensor); + if (!batched) { + continue; + } + + // NOTE: [vmap-incompatible in-place operations] + // In-place operations on `self` are not possible if there exists some vmap + // level `l` such that `self` is not being vmapped on that level but another + // argument is. For example, let B0 be a batch dim inside vmap and consider + // vmap(Tensor.add_, in_dims=(None, 0))(torch.ones(3), torch.ones(B0, 3)) + // - self is torch.ones(3) and does not participate in this vmap + // - other is BatchedTensor(torch.ones(B0, 3)) + // There's no way to do self.add_(other) because `other` has more elements + // elements than `self` due to being vmapped over. + // + // In the vmap fallback, we should error out when we detect this. + auto other_vmap_levels = createVmapLevelsBitset(batched->bdims()); + if (self_vmap_levels != (self_vmap_levels | other_vmap_levels)) { + // Find one vmap level to complain about + auto additional_bdims = (self_vmap_levels | other_vmap_levels) ^ self_vmap_levels; + auto offending_level = llvm::findLastSet(additional_bdims.to_ulong()); + // The following prints out "vmap: aten::add_(tensor, ...) is not possible", + // but it would be better to print out "tensor.add_(...) is not possible". + // Afaict there's no official way to get the add_ and there is no way to + // tell if an operator has method or function variants. + TORCH_CHECK(false, + "vmap: ", schema.name(), "(self, *extra_args) is not possible because ", + "there exists a Tensor `other` in extra_args that has more elements ", + "than `self`. This happened due to `other` being vmapped over but ", + "`self` not being vmapped over at level ", offending_level, ". ", + "Please try to use out-of-place operators instead of ", schema.name(), ". ", + "If said operator is being called inside the PyTorch framework, ", + "please file a bug report instead."); + } + batched_tensor_inputs.push_back(tensor); + batched_tensor_inputs_position.push_back(idx); + } + TORCH_INTERNAL_ASSERT(batched_tensor_inputs.size() > 0); + + // MultiBatchVmapTransform the BatchedTensor arguments. This returns + // VmapPhysicalViews that contain all of the batch dimensions. + const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical( + batched_tensor_inputs); + + // Compute the total number of batches + auto num_batch_dims = input_physical_views.front().numBatchDims(); + auto first_physical_view_sizes = input_physical_views.front().tensor().sizes(); + auto batch_sizes = ArrayRef( + first_physical_view_sizes.begin(), first_physical_view_sizes.begin() + num_batch_dims); + const auto num_batches = prod_intlist(batch_sizes); + // Without a shape-checking API, we're unable to compute the correct shape of + // the output so we just error out. + TORCH_CHECK(num_batches > 0, + "Batching rule not implemented for ", schema.operator_name(), ". ", + "The fallback path does not support vmap over dims of size 0."); + + // Strategy: For each batch, we are going to push slices (where applicable) + // of the arguments onto `stack`, and call `op`. + for (int64_t linear_idx = 0; linear_idx < num_batches; ++linear_idx) { + auto index = computeIndex(linear_idx, batch_sizes); + auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin(); + auto input_physical_views_iter = input_physical_views.begin(); + for (int64_t arg_idx = 0; arg_idx < num_arguments; ++arg_idx) { + // We assume that torch::jit::Stack is backed by vector for + // simplicity. When that is not the case, this code should be updated. + const auto& argument = (*stack)[arguments_begin + arg_idx]; + if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end() + || arg_idx != *batched_tensor_inputs_pos_iter) { + // argument isn't a BatchedTensor + torch::jit::push(stack, argument); + continue; + } + // argument is a BatchedTensor + TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end()); + const auto& physical_view_for_argument = *input_physical_views_iter; + torch::jit::push(stack, physical_view_for_argument.tensor().index(index)); + batched_tensor_inputs_pos_iter++; + input_physical_views_iter++; + } + + op.callBoxed(stack); + torch::jit::drop(stack, 1); + } + + // Return the tensor that was written to in-place + torch::jit::drop(stack, num_arguments); + torch::jit::push(stack, self); +} + +static Tensor safeStack(TensorList tensors) { + auto is_defined = [](const Tensor& t) { return t.defined(); }; + if (std::all_of(tensors.begin(), tensors.end(), is_defined)) { + return at::stack(tensors); + } + // NOTE [vmap through backward and undefined grad] + // While vmapping through backward functions (to compute batched grad), it + // is possible for the backward function to return an undefined grad for some + // grad_input for each example. In that case, we return an undefined grad. + // + // It is theoretically posssible for *some* of the examples to produce an + // undefined grad (a kernel could peek at the gradient values and return an + // undefined tensor if it determines the gradient is full of zeros). We + // could handle this by treating the undefined grad as a zero-filled tensor + // of the correct shape while stacking the tensors together. However I expect + // this to happen very rarely (I have not been able to find an example in our + // codebase) so we just error out in this case. + if (std::none_of(tensors.begin(), tensors.end(), is_defined)) { + return Tensor(); + } + TORCH_CHECK(false, + "vmap: slow fallback received a mix of undefined and defined tensors ", + "as the result of an operation. This is not supported, please file us ", + "an issue on github."); +} + // The general flow of the algorithm is as follows. // - First, we figure out which arguments are BatchedTensors and save them // to a vector. We also store a vector of which index of the arguments list @@ -50,17 +240,21 @@ static bool areAnyArgumentsTensorList(const FunctionSchema& schema) { void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { const auto& schema = op.schema(); const auto num_returns = schema.returns().size(); + + if (isInplaceOp(schema)) { + batchedTensorInplaceForLoopFallback(op, stack); + return; + } TORCH_CHECK(!schema.is_mutable() && !schema.hasAnyAliasInfo(), "Batching rule not implemented for ", schema.operator_name(), "; ", - "the fallback path doesn't work on in-place or view ops."); + "the fallback path doesn't work on out= or view ops."); TORCH_CHECK(areAllReturnsTensors(schema) && !areAnyArgumentsTensorList(schema), "Batching rule not implemented for ", schema.operator_name(), ". ", "We could not generate a fallback."); TORCH_CHECK(num_returns >= 1, "Batching rule not implemented for ", schema.operator_name(), ". ", "The fallback path does not support operations with no returns."); - TORCH_WARN("Batching rule not implemented for ", schema.operator_name(), " falling back " - "to slow (for loop and stack) implementation"); + warnFallback(schema, /*in_place*/false); const auto num_arguments = schema.arguments().size(); const auto arguments = torch::jit::last(stack, num_arguments); @@ -97,11 +291,12 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta auto num_batch_dims = input_physical_views.front().numBatchDims(); auto some_sizes = input_physical_views.front().tensor().sizes(); auto batch_sizes = ArrayRef(some_sizes.begin(), some_sizes.begin() + num_batch_dims); - auto num_batches = std::accumulate( - batch_sizes.begin(), - batch_sizes.end(), - 1, - std::multiplies()); + const auto num_batches = prod_intlist(batch_sizes); + // Without a shape-checking API, we're unable to compute the correct shape of + // the output so we just error out. + TORCH_CHECK(num_batches > 0, + "Batching rule not implemented for ", schema.operator_name(), ". ", + "The fallback path does not support vmap over dims of size 0."); // Strategy: For each batch, we are going to push slices (where applicable) // of the arguments onto `stack`, call `op`, and store the result in @@ -153,7 +348,12 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta auto output_shards_chunks = MatrixRef(output_shards, num_batches); for (int64_t return_idx = 0; return_idx < num_returns; ++return_idx) { auto shards = output_shards_chunks[return_idx]; - auto flat_output = at::stack(shards); + auto flat_output = safeStack(shards); + // See NOTE [vmap through backward and undefined grad] + if (!flat_output.defined()) { + torch::jit::push(stack, flat_output); + continue; + } VmapDimVector output_sizes(batch_sizes); output_sizes.insert( output_sizes.end(), @@ -161,7 +361,7 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta flat_output.sizes().end()); torch::jit::push( stack, - input_physical_views.front().newLogicalFromPhysical(flat_output.view(output_sizes))); + input_physical_views.front().getPhysicalToLogicalMap().apply(flat_output.view(output_sizes))); } } diff --git a/aten/src/ATen/BatchedTensorImpl.cpp b/aten/src/ATen/BatchedTensorImpl.cpp index 3c2ce5b9a6712..9dbf9ea78f4b6 100644 --- a/aten/src/ATen/BatchedTensorImpl.cpp +++ b/aten/src/ATen/BatchedTensorImpl.cpp @@ -19,18 +19,20 @@ BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims) const auto public_dims = value_.dim() - bdims_.size(); const auto value_sizes = value_.sizes(); - sizes_.clear(); - sizes_.reserve(public_dims); + const auto value_strides = value_.strides(); + sizes_and_strides_.resize(public_dims); for (int64_t dim = 0; dim < public_dims; dim++) { auto actual_dim = actualDim(dim, /*wrap_dim=*/false); - sizes_.push_back(value_sizes.at(actual_dim)); + sizes_and_strides_.size_at_unchecked(dim) = value_sizes.at(actual_dim); + sizes_and_strides_.stride_at_unchecked(dim) = value_strides.at(actual_dim); } refresh_numel(); + refresh_contiguous(); } int64_t BatchedTensorImpl::actualDim(int64_t dim, bool wrap_dim) const { if (wrap_dim) { - const auto ndim = sizes_.size(); + const auto ndim = sizes_and_strides_.size(); dim = maybe_wrap_dim(dim, ndim); } auto is_bdim = createBatchDimBitset(bdims_); @@ -71,15 +73,13 @@ void BatchedTensorImpl::checkInvariants() const { } // The following are publically exposed as methods of Tensor -IntArrayRef BatchedTensorImpl::strides() const { - TORCH_CHECK(false, "NYI: Getting tensor strides inside of vmap"); -} -int64_t BatchedTensorImpl::stride(int64_t d) const { - TORCH_CHECK(false, "NYI: Getting tensor strides inside of vmap"); -} bool BatchedTensorImpl::is_contiguous(at::MemoryFormat memory_format) const { - TORCH_CHECK(false, "NYI: querying is_contiguous inside of vmap"); + TORCH_CHECK(memory_format == MemoryFormat::Contiguous, + "NYI: querying is_contiguous inside of vmap for memory_format ", + "other than torch.contiguous_format"); + return is_contiguous_; } + const Storage& BatchedTensorImpl::storage() const { TORCH_CHECK(false, "Due to limitations, we cannot access the storage() of a tensor from inside of vmap."); } @@ -129,4 +129,19 @@ Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim) { return makeBatched(batched->value(), std::move(new_bdims)); } +bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other) { + const auto* other_batched = maybeGetBatchedImpl(other); + if (!other_batched) { + return true; + } + const auto* self_batched = maybeGetBatchedImpl(self); + if (!self_batched) { + // self is not batched but other is batched + return false; + } + auto self_levels = createVmapLevelsBitset(self_batched->bdims()); + auto other_levels = createVmapLevelsBitset(other_batched->bdims()); + return self_levels == (self_levels | other_levels); +} + } // namespace at diff --git a/aten/src/ATen/BatchedTensorImpl.h b/aten/src/ATen/BatchedTensorImpl.h index 0586e9caeec0a..7fdef64146fde 100644 --- a/aten/src/ATen/BatchedTensorImpl.h +++ b/aten/src/ATen/BatchedTensorImpl.h @@ -74,8 +74,6 @@ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl { // Override a bunch of methods inherited from TensorImpl to return error messages. bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override; - IntArrayRef strides() const override; - int64_t stride(int64_t d) const override; void set_size(int64_t dim, int64_t new_size) override; void set_stride(int64_t dim, int64_t new_stride) override; void set_storage_offset(int64_t storage_offset) override; @@ -123,6 +121,15 @@ inline std::bitset createBatchDimBitset(BatchDimsRef bdims) return is_bdim; } +// Creates a bitset for all of the levels present in `bdims` +inline std::bitset createVmapLevelsBitset(BatchDimsRef bdims) { + std::bitset result; + for (const auto& bdim : bdims) { + result.set(bdim.level()); + } + return result; +} + inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) { out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")"; return out; @@ -134,5 +141,9 @@ TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims); // Adds a batch dim to `tensor`, returning a BatchedTensor TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim); +// Checks if an inplace operation on self and other is "vmap compatible". +// See NOTE: [vmap-incompatible in-place operations] for the definition of this. +TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other); + } diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp index e930ffd7e2eaa..ff05b02cc7597 100644 --- a/aten/src/ATen/BatchingRegistrations.cpp +++ b/aten/src/ATen/BatchingRegistrations.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include namespace at { @@ -48,11 +49,25 @@ namespace at { // if not use the same mechanism. In order to accomplish that we might have to // do some refactoring. +// PyTorch allows operations to specify dim 0 and dim -1 on a scalar tensor. +static bool is_allowed_dim_on_scalar_tensor(int64_t dim) { + return dim == 0 || dim == -1; +} + Tensor sum_batching_rule(const Tensor& self, IntArrayRef dims, bool keepdim, optional dtype) { + // PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail + // and instead returns a new scalar tensor (this also happens for dim=-1) + // If the following happens: + // >>> x = torch.randn(B0) # the per-examples are all scalars + // >>> vmap(partial(torch.sum, dim=0), x) + // then we replicate the behavior of sum(scalar_tensor, dim=0). + if (/*logical*/self.dim() == 0 && dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])) { + return self.clone(); + } auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dims_physical = self_physical.getPhysicalDims(dims); auto result = at::sum(self_physical.tensor(), dims_physical, keepdim, dtype); - return self_physical.newLogicalFromPhysical(result); + return self_physical.getPhysicalToLogicalMap().apply(result); } bool isPhysicalScalarTensor(const Tensor& logical_tensor) { @@ -72,17 +87,17 @@ Tensor binary_pointwise_batching_rule( if (self.dim() > 0 && other.dim() > 0) { auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other}); auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...); - return physical_args[0].newLogicalFromPhysical(result); + return physical_args[0].getPhysicalToLogicalMap().apply(result); } if (isPhysicalScalarTensor(self)) { auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other); auto result = Func(self, other_physical.tensor(), args...); - return other_physical.newLogicalFromPhysical(result); + return other_physical.getPhysicalToLogicalMap().apply(result); } if (isPhysicalScalarTensor(other)) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto result = Func(self_physical.tensor(), other, args...); - return self_physical.newLogicalFromPhysical(result); + return self_physical.getPhysicalToLogicalMap().apply(result); } // At this point, we know at least one of the operands is a logical Scalar tensor. @@ -123,7 +138,7 @@ Tensor binary_pointwise_batching_rule( auto physical_args = BroadcastingVmapTransform::logicalToPhysical( {logical_self, logical_other}); auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...); - return physical_args[0].newLogicalFromPhysical(result); + return physical_args[0].getPhysicalToLogicalMap().apply(result); } Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit) { @@ -138,7 +153,7 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit) if (self_physical_dim == size_physical.size()) { auto result = self_physical.tensor().expand(size_physical, implicit); - return self_physical.newLogicalFromPhysical(result); + return self_physical.getPhysicalToLogicalMap().apply(result); } TORCH_INTERNAL_ASSERT(self_physical_dim < size_physical.size()); @@ -161,14 +176,48 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit) self_physical_size.end(), view_shape.begin() + self_physical.numBatchDims() + extra_dims); auto result = self_physical.tensor().view(view_shape).expand(size_physical, implicit); - return self_physical.newLogicalFromPhysical(result); + return self_physical.getPhysicalToLogicalMap().apply(result); } std::vector chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dim_physical = self_physical.getPhysicalDim(dim); auto result = at::chunk(self_physical.tensor(), chunks, dim_physical); - self_physical.makeLogicalFromPhysicalListInplace(result); + self_physical.getPhysicalToLogicalMap().applyInplace(result); + return result; +} + +Tensor clamp_batching_rule(const Tensor& self, optional min, optional max) { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + auto result = at::clamp(self_physical.tensor(), min, max); + return self_physical.getPhysicalToLogicalMap().apply(result); +} + +Tensor clamp_min_batching_rule(const Tensor& self, Scalar min) { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + auto result = at::clamp_min(self_physical.tensor(), min); + return self_physical.getPhysicalToLogicalMap().apply(result); +} + +Tensor clamp_max_batching_rule(const Tensor& self, Scalar max) { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + auto result = at::clamp_max(self_physical.tensor(), max); + return self_physical.getPhysicalToLogicalMap().apply(result); +} + +std::vector tensor_split_sections_batching_rule(const Tensor& self, int64_t sections, int64_t dim) { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + auto dim_physical = self_physical.getPhysicalDim(dim); + auto result = at::tensor_split(self_physical.tensor(), sections, dim_physical); + self_physical.getPhysicalToLogicalMap().applyInplace(result); + return result; +} + +std::vector tensor_split_indices_batching_rule(const Tensor& self, IntArrayRef indices, int64_t dim) { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + auto dim_physical = self_physical.getPhysicalDim(dim); + auto result = at::tensor_split(self_physical.tensor(), indices, dim_physical); + self_physical.getPhysicalToLogicalMap().applyInplace(result); return result; } @@ -181,22 +230,97 @@ Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) { auto dim_physical = self_physical.numBatchDims() + maybe_wrap_dim(dim, /*logical_dim*/self.dim() + 1); auto result = self_physical.tensor().unsqueeze(dim_physical); - return self_physical.newLogicalFromPhysical(result); + return self_physical.getPhysicalToLogicalMap().apply(result); +} + +Tensor& fill_inplace_scalar_batching_rule(Tensor& self, Scalar value) { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + self_physical.tensor().fill_(value); + return self; +} + +Tensor& fill_inplace_tensor_batching_rule(Tensor& self, const Tensor& value) { + auto value_batched = isBatchedTensor(value); + + if (value_batched) { + auto physical_args = + BroadcastingVmapTransform::logicalToPhysical({self, value}); + physical_args[0].tensor().copy_(physical_args[1].tensor()); + } else { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + self_physical.tensor().fill_(value); + } + return self; +} + +Tensor& zero_inplace_batching_rule(Tensor &self) { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + self_physical.tensor().zero_(); + return self; +} + +Tensor squeeze_batching_rule(const Tensor& self) { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + auto physical_sizes = self_physical.tensor().sizes(); + + // Don't squeeze the batch dims! + VmapDimVector squeezed_sizes; + int64_t num_batch_dims = self_physical.numBatchDims(); + squeezed_sizes.insert( + squeezed_sizes.end(), + physical_sizes.begin(), + physical_sizes.begin() + num_batch_dims); + for (auto it = physical_sizes.begin() + num_batch_dims; it != physical_sizes.end(); ++it) { + if (*it != 1) { + squeezed_sizes.push_back(*it); + } + } + + auto result = self_physical.tensor().view(squeezed_sizes); + return self_physical.getPhysicalToLogicalMap().apply(result); } Tensor squeeze_dim_batching_rule(const Tensor& self, int64_t dim) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dim_physical = self_physical.getPhysicalDim(dim); auto result = self_physical.tensor().squeeze(dim_physical); - return self_physical.newLogicalFromPhysical(result); + return self_physical.getPhysicalToLogicalMap().apply(result); +} + +Tensor trace_batching_rule(const Tensor& self) { + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + // Batched Diagonal View + auto self_diag = at::diagonal(self_physical.tensor(), /*offset*/0, /*dim1*/-2, /*dim2*/-1); + auto result = at::sum(self_diag, -1); + return self_physical.getPhysicalToLogicalMap().apply(result); +} + +Tensor trace_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes) { + auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad); + auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options()); + // Batched Diagonal View + auto grad_input_diag = at::diagonal(grad_input, /*offset*/0, /*dim1*/-2, /*dim2*/-1); + // Append a dimension of size one to the grad output + auto grad_physical_tensor = grad_physical.tensor().unsqueeze(-1); + grad_input_diag.copy_(grad_physical_tensor); + return grad_physical.getPhysicalToLogicalMap().apply(grad_input); } Tensor transpose_int_batching_rule(const Tensor& self, int64_t dim0, int64_t dim1) { + // PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works + // for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens: + // >>> x = torch.randn(B0) # the per-examples are all scalars + // >>> vmap(lambda x: x.transpose(0, -1), x) + // then we replicate this behavior. + if (/*logical*/self.dim() == 0 && is_allowed_dim_on_scalar_tensor(dim0) && + is_allowed_dim_on_scalar_tensor(dim1)) { + return self; + } auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dim0_physical = self_physical.getPhysicalDim(dim0); auto dim1_physical = self_physical.getPhysicalDim(dim1); auto result = self_physical.tensor().transpose(dim0_physical, dim1_physical); - return self_physical.newLogicalFromPhysical(result); + return self_physical.getPhysicalToLogicalMap().apply(result); } Tensor permute_batching_rule(const Tensor& self, IntArrayRef dims) { @@ -206,21 +330,21 @@ Tensor permute_batching_rule(const Tensor& self, IntArrayRef dims) { VmapDimVector all_dims_physical; all_dims_physical.reserve(self_physical.tensor().dim()); for (int64_t bdim = 0; bdim < self_physical.numBatchDims(); bdim++) { - all_dims_physical.push_back(bdim); + all_dims_physical.push_back(bdim); } all_dims_physical.insert( all_dims_physical.end(), dims_physical.begin(), dims_physical.end()); auto result = self_physical.tensor().permute(all_dims_physical); - return self_physical.newLogicalFromPhysical(result); + return self_physical.getPhysicalToLogicalMap().apply(result); } Tensor select_batching_rule(const Tensor& self, int64_t dim, int64_t index) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dim_physical = self_physical.getPhysicalDim(dim); auto result = self_physical.tensor().select(dim_physical, index); - return self_physical.newLogicalFromPhysical(result); + return self_physical.getPhysicalToLogicalMap().apply(result); } static int64_t getGradInputPhysicalDim(int64_t dim, IntArrayRef input_sizes, int64_t num_batch_dims) { @@ -232,14 +356,19 @@ Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options()); auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims()); grad_input.select(physical_dim, index).copy_(grad_physical.tensor()); - return grad_physical.newLogicalFromPhysical(grad_input); + return grad_physical.getPhysicalToLogicalMap().apply(grad_input); } -Tensor slice_batching_rule(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) { +Tensor slice_batching_rule( + const Tensor& self, + int64_t dim, + c10::optional start, + c10::optional end, + int64_t step) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dim_physical = self_physical.getPhysicalDim(dim); auto result = self_physical.tensor().slice(dim_physical, start, end, step); - return self_physical.newLogicalFromPhysical(result); + return self_physical.getPhysicalToLogicalMap().apply(result); } Tensor slice_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { @@ -247,7 +376,7 @@ Tensor slice_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options()); auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims()); grad_input.slice(physical_dim, start, end, step).copy_(grad_physical.tensor()); - return grad_physical.newLogicalFromPhysical(grad_input); + return grad_physical.getPhysicalToLogicalMap().apply(grad_input); } Tensor diagonal_batching_rule(const Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) { @@ -255,7 +384,7 @@ Tensor diagonal_batching_rule(const Tensor& self, int64_t offset, int64_t dim1, auto dim1_physical = self_physical.getPhysicalDim(dim1); auto dim2_physical = self_physical.getPhysicalDim(dim2); auto result = at::diagonal(self_physical.tensor(), offset, dim1_physical, dim2_physical); - return self_physical.newLogicalFromPhysical(result); + return self_physical.getPhysicalToLogicalMap().apply(result); } Tensor diagonal_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { @@ -264,7 +393,7 @@ Tensor diagonal_backward_batching_rule(const Tensor& grad, IntArrayRef input_siz auto dim1_physical = getGradInputPhysicalDim(dim1, input_sizes, grad_physical.numBatchDims()); auto dim2_physical = getGradInputPhysicalDim(dim2, input_sizes, grad_physical.numBatchDims()); grad_input.diagonal(offset, dim1_physical, dim2_physical).copy_(grad_physical.tensor()); - return grad_physical.newLogicalFromPhysical(grad_input); + return grad_physical.getPhysicalToLogicalMap().apply(grad_input); } Tensor movedim_batching_rule(const Tensor& self, IntArrayRef source, IntArrayRef destination) { @@ -272,21 +401,21 @@ Tensor movedim_batching_rule(const Tensor& self, IntArrayRef source, IntArrayRef auto source_physical = self_physical.getPhysicalDims(source); auto destination_physical = self_physical.getPhysicalDims(destination); auto result = at::movedim(self_physical.tensor(), source_physical, destination_physical); - return self_physical.newLogicalFromPhysical(result); + return self_physical.getPhysicalToLogicalMap().apply(result); } Tensor reshape_batching_rule(const Tensor& self, IntArrayRef shape) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto shape_physical = self_physical.getPhysicalShape(shape); auto result = self_physical.tensor().reshape(shape_physical); - return self_physical.newLogicalFromPhysical(result); + return self_physical.getPhysicalToLogicalMap().apply(result); } std::vector split_batching_rule(const Tensor& self, int64_t split_size, int64_t dim) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dim_physical = self_physical.getPhysicalDim(dim); auto result = at::split(self_physical.tensor(), split_size, dim_physical); - self_physical.makeLogicalFromPhysicalListInplace(result); + self_physical.getPhysicalToLogicalMap().applyInplace(result); return result; } @@ -294,7 +423,7 @@ std::vector split_with_sizes_batching_rule(const Tensor& self, IntArrayR auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dim_physical = self_physical.getPhysicalDim(dim); auto result = at::split_with_sizes(self_physical.tensor(), split_sizes, dim_physical); - self_physical.makeLogicalFromPhysicalListInplace(result); + self_physical.getPhysicalToLogicalMap().applyInplace(result); return result; } @@ -302,7 +431,7 @@ std::vector unbind_batching_rule(const Tensor& self, int64_t dim) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dim_physical = self_physical.getPhysicalDim(dim); auto result = at::unbind(self_physical.tensor(), dim_physical); - self_physical.makeLogicalFromPhysicalListInplace(result); + self_physical.getPhysicalToLogicalMap().applyInplace(result); return result; } @@ -310,18 +439,261 @@ Tensor unfold_batching_rule(const Tensor& self, int64_t dim, int64_t size, int64 auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dim_physical = self_physical.getPhysicalDim(dim); auto result = self_physical.tensor().unfold(dim_physical, size, step); - return self_physical.newLogicalFromPhysical(result); + return self_physical.getPhysicalToLogicalMap().apply(result); +} + +Tensor contiguous_batching_rule(const Tensor& self, MemoryFormat memory_format) { + TORCH_CHECK(memory_format == MemoryFormat::Contiguous, + "NYI: Tensor.contiguous(...) inside of vmap for memory_format other ", + "than torch.contiguous_format"); + auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self); + auto result = physical_view.tensor().contiguous(memory_format); + return physical_view.getPhysicalToLogicalMap().apply(result); } Tensor view_batching_rule(const Tensor& self, IntArrayRef size) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto size_physical = self_physical.getPhysicalShape(size); auto result = self_physical.tensor().view(size_physical); - return self_physical.newLogicalFromPhysical(result); + return self_physical.getPhysicalToLogicalMap().apply(result); } +Tensor view_as_complex_batching_rule(const Tensor& self) { + // guard against the user passing in a batch of scalar tensors with batch + // size equal to 2. + TORCH_CHECK(self.sizes().size() != 0, "Input tensor must have one or more dimensions"); + auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); + auto result = at::view_as_complex(self_physical.tensor()); + return self_physical.getPhysicalToLogicalMap().apply(result); +} + +// Checks that the smallest batch stride is greater than the largest example +// stride. This is something we can support but we choose not to because it's +// potentially error prone. +static void checkBatchDimsAtFrontInLayout(IntArrayRef physical_strides, int64_t num_batch_dims) { + auto smallest_batch_stride = std::min_element( + physical_strides.begin(), physical_strides.begin() + num_batch_dims); + auto largest_example_stride = std::max_element( + physical_strides.begin() + num_batch_dims, physical_strides.end()); + if (largest_example_stride == physical_strides.end()) { + // No example dimensions + return; + } + TORCH_CHECK(*smallest_batch_stride >= *largest_example_stride, + "vmap: Calling Tensor.as_strided is not supported unless the batch dims being ", + "vmapped over are at the front of the tensor (in memory layout). When they are ", + "not at the front of the tensor this operation can be error prone so we " + "actively discourage it; please file us a bug report and/or try to ", + "express the as_strided operation in terms of PyTorch view operations"); +} + +// given (sizes, strides, storage_offset) returns the maximum location that +// can be indexed (or nullopt if such a location doesn't exist, e.g., tensors +// with zero-size dims). +static optional maximum_indexable_location( + IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) { + auto result = native::storage_size_for(sizes, strides); + if (result == 0) { + return nullopt; + } + return result + storage_offset; +} + +// Let x be the "first slice" of physical_tensor. +// This checks that the range of possible memory locations accessible by +// x.as_strided(sizes, strides, maybe_storage_offset) +// are within the bounds of possible memory locations accessible by x. +static void checkBasicAsStridedValidForSlice( + const Tensor& physical_tensor, + int64_t num_batch_dims, + IntArrayRef sizes, + IntArrayRef strides, + optional maybe_storage_offset) { + auto slice_sizes = physical_tensor.sizes().slice(num_batch_dims); + auto slice_strides = physical_tensor.strides().slice(num_batch_dims); + auto base_offset = physical_tensor.storage_offset(); + + auto storage_offset = maybe_storage_offset.value_or(base_offset); + + auto max_as_strided_loc = maximum_indexable_location(sizes, strides, storage_offset); + auto max_slice_loc = maximum_indexable_location(slice_sizes, slice_strides, base_offset); + + if (!max_as_strided_loc.has_value()) { + return; + } + if (!max_slice_loc.has_value()) { + TORCH_CHECK(false, + "result = tensor.as_strided(", sizes, ",", strides, ",", storage_offset, ")", + "can access memory outside of `tensor`. `tensor` has no storage but the ", + "passed-in (size, stride, storage_offset) imply a result with some storage. ", + "This is not supported inside of vmap, please try to rewrite the ", + "`as_strided` call as a sequence of PyTorch view operations"); + } + + TORCH_CHECK( + *max_as_strided_loc <= *max_slice_loc && base_offset <= storage_offset, + "result = tensor.as_strided(", sizes, ",", strides, ",", storage_offset, ")", + "can access memory outside of `tensor`. `result` can access some", + "memory in range [", storage_offset, ", ", *max_as_strided_loc, "], but ", + "`tensor` can only access some memory in range [", base_offset, ", ", + *max_slice_loc, "]. This is not supported inside of vmap, please try to", + "rewrite the `as_strided` call as a sequence of PyTorch view operations"); +} + +// What are the semantics of as_strided inside of vmap? +// y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs) +// This returns a view on `x`, `y`, such that each y[i] has: +// - sizes: `sizes` +// - strides: `strides` +// - storage_offset: offset + i * x.stride(batch_dim) +// +// In other words, it is as if we had treated each x[i] as having storage +// offset equal to xs.offset() and called as_strided(sizes, sizes, offset). +// (that is equivalent to x[i].as_strided( +// sizes, sizes, offset + x[i].storage_offset() - xs.offset()) for all i) +// +// Note that this *may* be different from actually running as_strided +// in a for-loop. This is due to how as_strided takes in `offset` to be +// an *absolute* offset. As an example, consider: +// >>> x = torch.tensor([0., 1., 2., 3., 4.]).as_strided([4], [1], 1) +// >>> z = [x[i].as_strided([1], [1], 1) for i in range(4)] +// Each z[i] is actually the same view on x (z[i] == torch.tensor([1.]))! +// However, we consider the above for-loop comprehension to be a user error: +// a user should have written the following if they wanted to use as_strided +// in a per-sample way: +// >>> z = [x[i].as_strided([1], [1], 1 + x[i].storage_offset() - 1) for i in range(4)] +Tensor as_strided_batching_rule( + const Tensor& tensor, + IntArrayRef sizes, + IntArrayRef strides, + optional storage_offset) { + auto physical_view = at::MultiBatchVmapTransform::logicalToPhysical(tensor); + auto num_batch_dims = physical_view.numBatchDims(); + auto physical_sizes = physical_view.getPhysicalShape(sizes); + const auto& physical_tensor = physical_view.tensor(); + + // We can't rely on the physical as_strided call to do this for us because + // we do some sanity checks on the size/strides before calling into as_strided. + TORCH_CHECK(sizes.size() == strides.size(), + "Tensor.as_strided(size, stride, ...): size and stride must have the ", + "same length! Got size ", sizes, " and stride ", strides); + + // Sanity checks: + // 1. All batch dims are at the front in memory layout (not necessary for + // correctness, but we are worried the user might be doing crazy things) + // 2. as_strided(sizes, strides, storage_offset + tensor[i].offset() - tensor.offset()) + // is valid for a slice of the input tensor. + // See Note: [When will the as_strided batching rule fail?] for details. + checkBatchDimsAtFrontInLayout(physical_tensor.strides(), num_batch_dims); + checkBasicAsStridedValidForSlice( + physical_tensor, num_batch_dims, sizes, strides, storage_offset); + + // physical_strides = physical tensor's batch strides + (logical) strides + auto batch_strides = physical_tensor.strides().slice(0, num_batch_dims); + at::VmapDimVector physical_strides; + physical_strides.reserve(num_batch_dims + strides.size()); + physical_strides.insert( + physical_strides.end(), batch_strides.begin(), batch_strides.end()); + physical_strides.insert( + physical_strides.end(), strides.begin(), strides.end()); + + // If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) + // is valid for all i, then it turns out that + // xs.as_strided(physical_sizes, physical_strides, offset) always succeeds + // and creates a tensor y such that each y[i] references the same memory + // locations as zi. See NOTE: [When will the as_strided batching rule fail?] + auto result = physical_view.tensor().as_strided( + physical_sizes, physical_strides, storage_offset); + return physical_view.getPhysicalToLogicalMap().apply(result); +} + +// NOTE: [When will the as_strided batching rule fail?] +// If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) +// is valid for all i, then it turns out that +// xs.as_strided(physical_sizes, physical_strides, offset) always succeeds and +// creates a tensor y such that each y[i] refers to the same memory as zi. +// +// Let's say we have xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()). +// Furthermore, let's say that as a part of being "valid" this as_strided call +// does not return a result that can index memory not indexable by xs[i]. +// +// WLOG, assume that there's only one batch dim and it is at the front of the +// `xs` tensor. Let B be the batch size and S be the stride of the batch dim. +// - If the batch dim isn't at the front of the tensor, then we can just move it +// to the front with movedim/permute. This is always valid because it just swaps +// some strides around. +// - This proof also works for tensors with multiple batch dims. We just have to +// do a little accounting: +// - instead of [B], we'd have [B0, B1, ..., Bk]. +// - instead of [S], we'd have [S0, S1, ..., Sk]. +// - instead of i, we'd have a list of indices [I0, I1, ..., Ik] +// - instead of S * I, we'd have \sum_{i=0}^k S_i * I_i +// +// [Equation 1] +// xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) has: +// - sizes: sizes +// - strides: strides +// - offset: offset + S * i +// +// x.as_strided itself checks that: +// - (sizes, strides, offset) are in bounds for `x`'s storage. +// - strides are positive +// - offset is positive +// +// Claim 1: if xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) +// is valid, then +// ([B] + sizes, [S] + strides, offset + xs.offset()) are in bounds for `xs`'s storage. +// +// If we have the claim, then xs.as_strided([B] + sizes, [S] + strides, offset) +// won't error out. So all we need to check is that the memory locations are +// what we expected. See [Hand-wavy proof of Claim 1] for proof (it's not very important) +// +// xs.as_strided(physical_sizes, physical_strides, offset) is equivalent to +// xs.as_strided([B] + sizes, [S] + strides, offset) +// +// xs.as_strided([B] + sizes, [S] + strides, offset) has: +// - sizes: [B] + sizes +// - strides: [S] + strides +// - offset: offset +// +// xs.as_strided([B] + sizes, [S] + strides, offset)[i] has: +// - sizes: sizes +// - strides: strides +// - offset: offset + S * i +// These memory locations are exactly the same as what we got for [Equation 1], +// so the xs.as_strided([B] + sizes, [S] + strides, offset) is valid. +// +// [Hand-wavy proof of Claim 1] +// Part of our definition of being valid is that xs[i].as_strided(...) +// must return a tensor that only uses memory indexable by xs[i]. +// This means that (sizes, strides, offset + xs[i].offset() - xs.offset()) satisfies: +// offset + xs[i].offset() - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j] +// <= xs[i].offset() + 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j) +// (the largest-index memory location of xs[i].as_strided(...) must be \leq +// the largest-index memory location of xs[i]) +// +// Fiddling that inequality gives us: +// offset - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j] +// <= 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j) +// +// offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j] +// <= 1 + (B-1)*S + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j) +// +// offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j] +// <= 1 + \sum_j (xs.size(j) - 1) * xs.stride(j) +// +// offset + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j] +// <= xs.offset() + 1 + \sum_j (xs.size(j) - 1) * xs.stride(j) +// (the largest-index memory location of xs.as_strided(size, stride, offset) +// is \leq than the largest-index memory location of xs) +// Under the assumptions we've made, the lower bound (lowest indexed memory) +// is trivially within the storage. +// +// Therefore ([B] + sizes, [S] + strides, offset) are in bounds for +// `xs`'s storage. + template -Tensor unary_pointwise_batching_rule(const Tensor& input, ExtraArgs... args) { +Tensor unwrap_and_call(const Tensor& input, ExtraArgs... args) { auto* input_batched = unsafeGetBatchedImpl(input); auto output_physical = Func(input_batched->value(), args...); auto old_bdims = input_batched->bdims(); @@ -329,7 +701,7 @@ Tensor unary_pointwise_batching_rule(const Tensor& input, ExtraArgs... args) { } template -Tensor unary_pointwise_method_batching_rule(const Tensor& input, ExtraArgs... extra_args) { +Tensor unwrap_and_call_method(const Tensor& input, ExtraArgs... extra_args) { auto* input_batched = unsafeGetBatchedImpl(input); auto output_physical = (input_batched->value().*Func)(extra_args...); auto old_bdims = input_batched->bdims(); @@ -343,6 +715,42 @@ Tensor pow_scalar_Tensor_batching_rule(Scalar other, const Tensor& self) { return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); } +Tensor clone_batching_rule(const Tensor& self, optional memory_format) { + // Memory format support is a little tricky because vmap is allowed to move + // around batch dimensions and some memory formats are rank-dependent. + // Another weird case is: + // - a tensor with MemoryFormat::ChannelsLast MUST have 4 dimensions. Do we + // allow the user to clone a Tensor with 3 logical dimensions and 1 batch + // dim into a ChannelsLast Tensor? What about a Tensor with 3 logical dims + // and N>1 batch dims? + TORCH_CHECK(!memory_format.has_value() || memory_format == MemoryFormat::Preserve + || memory_format == MemoryFormat::Contiguous, + "NYI: Tensor.clone(memory_format) inside vmap is only supported with ", + "memory_format torch.preserve_format or torch.contiguous_format (got ", + *memory_format, ")"); + + if (memory_format == MemoryFormat::Contiguous) { + // There is an ambiguity here when the batch dims are not at the front of + // the tensor. + // >>> x = torch.randn(3, B0, 5) + // >>> y = vmap(lambda x: x.clone(torch.contiguous_format), in_dims=1, out_dims=0)(x) + // >>> y[0].is_contiguous() + // ??? + // Should we make the whole tensor contiguous, or should we + // make the non-batch dims contiguous? We've chosen the latter because + // philosophically vmap hides the batch dims and operates on a per-sample level. + auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self); + auto output_physical = at::clone(physical_view.tensor(), memory_format); + return physical_view.getPhysicalToLogicalMap().apply(output_physical); + } + + TORCH_INTERNAL_ASSERT(!memory_format.has_value() || memory_format == MemoryFormat::Preserve); + auto* self_batched = unsafeGetBatchedImpl(self); + auto output_physical = at::clone(self_batched->value(), memory_format); + auto old_bdims = self_batched->bdims(); + return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); +} + // Note [Batching rules for matmul-like operators] // at::matmul doesn't "de-expand" arguments to get better performance (maybe // it should). In the batching rules for matmul-like operators (dot, mv, mm), @@ -363,7 +771,7 @@ Tensor mv_batching_rule(const Tensor& self, const Tensor& other) { if (self_batched && !other_batched) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto result = at::matmul(self_physical.tensor(), other); - return self_physical.newLogicalFromPhysical(result); + return self_physical.getPhysicalToLogicalMap().apply(result); } if (!self_batched && other_batched) { // self_physical: [L, K], other_physical: [..., K] @@ -371,7 +779,7 @@ Tensor mv_batching_rule(const Tensor& self, const Tensor& other) { // a tensor of size [..., L, 1], and unsqueeze the last dim. auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other); auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1)); - return other_physical.newLogicalFromPhysical(result.squeeze(-1)); + return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1)); } if (self_batched && other_batched) { // self_physical: [..., L, K], other_physical: [..., K] @@ -381,7 +789,7 @@ Tensor mv_batching_rule(const Tensor& self, const Tensor& other) { auto result = at::matmul( physical_args[0].tensor(), physical_args[1].tensor().unsqueeze(-1)); - return physical_args[0].newLogicalFromPhysical(result.squeeze(-1)); + return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1)); } TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor"); } @@ -401,14 +809,14 @@ Tensor dot_batching_rule(const Tensor& self, const Tensor& other) { // View the tensors as [..., 1, K] and [K], perform matmul, and unsqueeze. auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto result = at::matmul(self_physical.tensor().unsqueeze(-2), other); - return self_physical.newLogicalFromPhysical(result.squeeze(-1)); + return self_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1)); } if (!self_batched && other_batched) { // self_physical: [K], other_physical: [..., K] // View the tensors as [K] and [..., K, 1], perform matmul, and unsqueeze. auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other); auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1)); - return other_physical.newLogicalFromPhysical(result.squeeze(-1)); + return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1)); } if (self_batched && other_batched) { // self_physical: [..., K], other_physical: [..., K] @@ -417,7 +825,7 @@ Tensor dot_batching_rule(const Tensor& self, const Tensor& other) { auto result = at::matmul( physical_args[0].tensor().unsqueeze(-2), physical_args[1].tensor().unsqueeze(-1)); - return physical_args[0].newLogicalFromPhysical(result.squeeze(-1).squeeze(-1)); + return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1).squeeze(-1)); } TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor"); } @@ -430,7 +838,7 @@ Tensor bmm_batching_rule(const Tensor& self, const Tensor& other) { auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other}); auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor()); - return physical_args[0].newLogicalFromPhysical(result); + return physical_args[0].getPhysicalToLogicalMap().apply(result); } Tensor mm_batching_rule(const Tensor& self, const Tensor& other) { @@ -446,17 +854,17 @@ Tensor mm_batching_rule(const Tensor& self, const Tensor& other) { if (self_batched && !other_batched) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto result = at::matmul(self_physical.tensor(), other); - return self_physical.newLogicalFromPhysical(result); + return self_physical.getPhysicalToLogicalMap().apply(result); } if (!self_batched && other_batched) { auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other); auto result = at::matmul(self, other_physical.tensor()); - return other_physical.newLogicalFromPhysical(result); + return other_physical.getPhysicalToLogicalMap().apply(result); } if (self_batched && other_batched) { auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other}); auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor()); - return physical_args[0].newLogicalFromPhysical(result.squeeze(-1).squeeze(-1)); + return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1).squeeze(-1)); } TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor"); } @@ -468,7 +876,7 @@ Tensor cat_batching_rule(TensorList tensors, int64_t dim) { TORCH_INTERNAL_ASSERT( tensors.size() > 0, "The dispatcher should not have dispatched here otherwise."); auto result = at::cat(physical_tensors, physical_views[0].getPhysicalDim(dim)); - return physical_views[0].newLogicalFromPhysical(result); + return physical_views[0].getPhysicalToLogicalMap().apply(result); } Tensor stack_batching_rule(TensorList tensors, int64_t dim) { @@ -482,13 +890,13 @@ Tensor stack_batching_rule(TensorList tensors, int64_t dim) { auto dim_physical = physical_views[0].numBatchDims() + maybe_wrap_dim(dim, /*logical*/tensors[0].dim() + 1); auto result = at::stack(physical_tensors, dim_physical); - return physical_views[0].newLogicalFromPhysical(result); + return physical_views[0].getPhysicalToLogicalMap().apply(result); } // I am quite sad that we need to register operators with exploded TensorOptions, // even though the native:: implementations can use TensorOptions&. // This also makes it hard to metaprogram: i.e., we can't use -// unary_pointwise_batching_rule<..., at::to> because at::to takes TensorOptions& (!!) +// unwrap_and_call<..., at::to> because at::to takes TensorOptions& (!!) Tensor to_dtype_layout_batching_rule( const Tensor& self, optional dtype, @@ -508,6 +916,103 @@ Tensor to_dtype_layout_batching_rule( return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); } +Tensor new_zeros_batching_rule( + const Tensor& self, + IntArrayRef size, + optional dtype, + optional layout, + optional device, + optional pin_memory) { + auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self); + auto physical_size = physical_view.getPhysicalShape(size); + auto options = TensorOptions() + .dtype(dtype) + .layout(layout) + .device(device) + .pinned_memory(pin_memory); + auto result = physical_view.tensor().new_zeros(physical_size, options); + return physical_view.getPhysicalToLogicalMap().apply(result); +} + +Tensor new_empty_batching_rule( + const Tensor& self, + IntArrayRef size, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { + auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self); + auto physical_size = physical_view.getPhysicalShape(size); + auto result = physical_view.tensor().new_empty(physical_size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory)); + return physical_view.getPhysicalToLogicalMap().apply(result); +} + +Tensor new_empty_strided_batching_rule( + const Tensor& self, + IntArrayRef size, + IntArrayRef stride, + optional dtype, + optional layout, + optional device, + optional pin_memory) { + auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self); + auto physical_size = physical_view.getPhysicalShape(size); + + // Let [B0, B1, B2] be the shape of the batch dims. We're going to create + // the batch dimensions at the front of the tensor (in memory layout), + // irrespective of whether or not they are actually at the front (in memory layout) + // in the original `self` tensor. This is because when a user calls + // `new_empty_strided` in general, the `strides` they provide are for a new + // tensor and have no relation to the strides of the original tensor. + // + // So, the physical shape of the result should be ([B0, B1, B2] + size), + // but what about the physical strides? + // + // We're actually free to pick whatever stride we want: + // e.g., for size=[5, 3], stride=[0, 1], we could decide to + // use + // - physical size: [B0, B1, B2, 5, 3] + // - physical stride: [9999*B1*B2, 9999*B2, 9999, 0, 1] + // + // Let's select some reasonable strides such that: + // - The batch dims are "contiguous" with respect to each other + // - if empty_strided(size, stride) would have created a contiguous Tensor, + // then this new physical Tensor (with batch dims) is also contiguous + // + // Let S be the size of the storage if one were to construct a tensor + // with `size` and `stride` via empty_strided(size, stride). + // Then the physical sizes/strides should be: + // - physical size: [B0, B1, B2, 5, 3] + // - physical stride: [B1 * B2 * S, B2 * S, S, 0, 1] + auto batch_shape = IntArrayRef( + physical_view.tensor().sizes().begin(), physical_view.numBatchDims()); + + // physical_strides = [B1 * B2 * S, B2 * S, S] + auto physical_strides = at::detail::defaultStrides(batch_shape); + TORCH_CHECK(size.size() == stride.size(), + "new_empty_strided(sizes, strides): dimensionality of sizes (", + size.size(), ") must match dimensionality of strides (", + stride.size(), ")"); + auto storage_size = native::storage_size_for(size, stride); + for (auto& physical_stride : physical_strides) { + physical_stride *= storage_size; + } + + // physical_strides = [B1 * B2 * S, B2 * S, S] + strides + physical_strides.insert(physical_strides.end(), stride.begin(), stride.end()); + + auto result = physical_view.tensor().new_empty_strided( + physical_size, physical_strides, dtype, layout, device, pin_memory); + return physical_view.getPhysicalToLogicalMap().apply(result); +} + +template +Tensor comparison_pointwise_batching_rule(const Tensor& self, const Tensor& other) { + auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other}); + auto result = Func(physical_args[0].tensor(), physical_args[1].tensor()); + return physical_args[0].getPhysicalToLogicalMap().apply(result); +} + TORCH_LIBRARY_IMPL(_, Batched, m) { m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>()); } @@ -521,12 +1026,20 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { m.impl("_add_batch_dim", native::_add_batch_dim); m.impl("_remove_batch_dim", native::_remove_batch_dim); - m.impl_UNBOXED("sum.dim_IntList", sum_batching_rule); + m.impl("sum.dim_IntList", sum_batching_rule); m.impl("is_complex", native::is_complex); m.impl("conj", native::conj); + // inplace operations + m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule); + m.impl("fill_.Tensor", fill_inplace_tensor_batching_rule); + m.impl("zero_", zero_inplace_batching_rule); + // view operations + m.impl("as_strided", as_strided_batching_rule); m.impl("chunk", chunk_batching_rule); + m.impl("tensor_split.sections", tensor_split_sections_batching_rule); + m.impl("tensor_split.indices", tensor_split_indices_batching_rule); m.impl("diagonal", diagonal_batching_rule); m.impl("expand", expand_batching_rule); m.impl("expand_as", native::expand_as); // composite wrt autograd @@ -543,8 +1056,10 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { m.impl("slice.Tensor", slice_batching_rule); m.impl("split.Tensor", split_batching_rule); m.impl("split_with_sizes", split_with_sizes_batching_rule); + m.impl("squeeze", squeeze_batching_rule); m.impl("squeeze.dim", squeeze_dim_batching_rule); m.impl("t", native::t); // composite wrt autograd + m.impl("trace", trace_batching_rule); m.impl("transpose.int", transpose_int_batching_rule); m.impl("unbind.int", unbind_batching_rule); m.impl("unfold", unfold_batching_rule); @@ -552,9 +1067,14 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { m.impl("view", view_batching_rule); m.impl("view_as", native::view_as); // composite wrt autograd + // clamp operations + m.impl("clamp", clamp_batching_rule); + m.impl("clamp_min", clamp_min_batching_rule); + m.impl("clamp_max", clamp_max_batching_rule); + // unary pointwise, out-of-place, no additional arguments. #define UNARY_POINTWISE(op) m.impl(#op, \ - unary_pointwise_batching_rule); + unwrap_and_call); UNARY_POINTWISE(abs); UNARY_POINTWISE(acos); UNARY_POINTWISE(asin); @@ -590,7 +1110,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { #define TO_BATCHING_RULE(name, ...) \ { \ using to_type = Tensor(Tensor::*)(__VA_ARGS__) const; \ - m.impl(name, unary_pointwise_method_batching_rule< \ + m.impl(name, unwrap_and_call_method< \ to_type, &Tensor::to, __VA_ARGS__>);\ } TO_BATCHING_RULE("to.device", Device, ScalarType, bool, bool, optional) @@ -598,19 +1118,21 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { TO_BATCHING_RULE("to.other", const Tensor&, bool, bool, optional) m.impl("to.dtype_layout", to_dtype_layout_batching_rule); #undef TO_BATCHING_RULE + m.impl("clone", clone_batching_rule); + using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, Scalar); using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&); using TensorScalarType = Tensor (*)(const Tensor&, Scalar); #define BINARY_POINTWISE(op) \ m.impl(#op".Tensor", binary_pointwise_batching_rule); \ - m.impl(#op".Scalar", unary_pointwise_batching_rule); + m.impl(#op".Scalar", unwrap_and_call); #define BINARY_POINTWISE_VA(op, ...) \ { \ using Binop = Tensor (*)(const Tensor&, const Tensor&, __VA_ARGS__); \ using Unop = Tensor (*)(const Tensor&, Scalar, __VA_ARGS__); \ m.impl(#op".Tensor", binary_pointwise_batching_rule); \ - m.impl(#op".Scalar", unary_pointwise_batching_rule); \ + m.impl(#op".Scalar", unwrap_and_call); \ } BINARY_POINTWISE_VA(add, Scalar); @@ -621,10 +1143,16 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { // at::pow has three out-of-place overloads m.impl("pow.Tensor_Tensor", binary_pointwise_batching_rule); - m.impl("pow.Tensor_Scalar", unary_pointwise_batching_rule); + m.impl("pow.Tensor_Scalar", unwrap_and_call); m.impl("pow.Scalar", pow_scalar_Tensor_batching_rule); m.impl("sigmoid_backward", binary_pointwise_batching_rule); + m.impl( + "threshold_backward", + binary_pointwise_batching_rule< + TensorTensorScalarType, + at::threshold_backward, + Scalar>); // for at::result_type, call the native::result_type implementation. // We don't have to do anything special because native::result_type operates @@ -637,6 +1165,16 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { #undef BINARY_POINTWISE_VA #undef BINARY_POINTWISE + +#define TRIVIAL_OP(op) m.impl(#op, \ + unwrap_and_call); + // complex number view operators + TRIVIAL_OP(imag) + TRIVIAL_OP(real); + TRIVIAL_OP(view_as_real); + m.impl("view_as_complex", view_as_complex_batching_rule); +#undef TRIVIAL + // matmul-like operators m.impl("mv", mv_batching_rule); m.impl("dot", dot_batching_rule); @@ -650,7 +1188,29 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { // backward operators m.impl("select_backward", select_backward_batching_rule); m.impl("slice_backward", slice_backward_batching_rule); + m.impl("trace_backward", trace_backward_batching_rule); m.impl("diagonal_backward", diagonal_backward_batching_rule); + + // Tensor.new_* operators + m.impl("new_empty", new_empty_batching_rule); + m.impl("new_empty_strided", new_empty_strided_batching_rule); + m.impl("new_zeros", new_zeros_batching_rule); + + m.impl("contiguous", contiguous_batching_rule); + + // Comparison ops +#define COMPARISON_POINTWISE(op) \ + m.impl(#op".Tensor", comparison_pointwise_batching_rule); \ + m.impl(#op".Scalar", unwrap_and_call); + + COMPARISON_POINTWISE(eq); + COMPARISON_POINTWISE(gt); + COMPARISON_POINTWISE(ge); + COMPARISON_POINTWISE(le); + COMPARISON_POINTWISE(lt); + COMPARISON_POINTWISE(ne); + +#undef COMPARISON_POINTWISE } } // namespace at diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 5ec9d24eea390..6fedef185b214 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -51,6 +51,7 @@ file(GLOB cudnn_cpp "cudnn/*.cpp") file(GLOB hip_h "hip/*.h" "hip/detail/*.h" "hip/*.cuh" "hip/detail/*.cuh" "hip/impl/*.h") file(GLOB hip_cpp "hip/*.cpp" "hip/detail/*.cpp" "hip/impl/*.cpp") +list(REMOVE_ITEM hip_cpp "${CMAKE_CURRENT_SOURCE_DIR}/hip/detail/LazyNVRTC.cpp") file(GLOB hip_hip "hip/*.hip" "hip/detail/*.hip" "hip/impl/*.hip") file(GLOB hip_nvrtc_stub_h "hip/nvrtc_stub/*.h") file(GLOB hip_nvrtc_stub_cpp "hip/nvrtc_stub/*.cpp") @@ -64,7 +65,18 @@ file(GLOB native_cpp "native/*.cpp") file(GLOB native_mkl_cpp "native/mkl/*.cpp") file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp") file(GLOB vulkan_cpp "vulkan/*.cpp") -file(GLOB native_vulkan_cpp "native/vulkan/api/*.cpp" "native/vulkan/*.cpp") +file(GLOB native_vulkan_cpp "native/vulkan/*.cpp" "native/vulkan/api/*.cpp" "native/vulkan/ops/*.cpp") + +# Metal +file(GLOB metal_h "metal/*.h") +file(GLOB metal_cpp "metal/*.cpp") +file(GLOB_RECURSE native_metal_h "native/metal/*.h") +file(GLOB metal_test_srcs "native/metal/mpscnn/tests/*.mm") +file(GLOB_RECURSE native_metal_srcs "native/metal/*.mm" "native/metal/*.cpp") +EXCLUDE(native_metal_srcs "${native_metal_srcs}" ${metal_test_srcs}) +file(GLOB metal_prepack_h "native/metal/MetalPrepackOpContext.h") +file(GLOB metal_prepack_cpp "native/metal/MetalPrepackOpRegister.cpp") + file(GLOB native_sparse_cpp "native/sparse/*.cpp") file(GLOB native_quantized_cpp "native/quantized/*.cpp" @@ -103,7 +115,8 @@ append_filelist("jit_core_headers" ATen_CORE_HEADERS) append_filelist("jit_core_sources" ATen_CORE_SRCS) add_subdirectory(quantized) -set(all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_sparse_cpp} ${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp} ${native_utils_cpp} ${native_xnnpack} ${generated_cpp} ${core_generated_cpp} ${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${cpu_kernel_cpp}) +add_subdirectory(nnapi) +set(all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_sparse_cpp} ${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp} ${native_utils_cpp} ${native_xnnpack} ${generated_cpp} ${core_generated_cpp} ${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${ATen_NNAPI_SRCS} ${cpu_kernel_cpp}) if(AT_MKL_ENABLED) set(all_cpu_cpp ${all_cpu_cpp} ${mkl_cpp}) endif() @@ -116,6 +129,18 @@ else() set(all_cpu_cpp ${all_cpu_cpp} ${vulkan_cpp}) endif() +# Metal +if(USE_PYTORCH_METAL) + if(APPLE) + set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp} ${native_metal_srcs}) + else() + # Add files needed from optimized_for_mobile + set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp} ${metal_prepack_cpp}) + endif() +else() + set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp}) +endif() + if(USE_CUDA AND USE_ROCM) message(FATAL_ERROR "ATen doesn't not currently support simultaneously building with CUDA and ROCM") endif() @@ -374,6 +399,14 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake" set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS}) if(NOT INTERN_BUILD_MOBILE) list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${miopen_h}) +else() + if(USE_PYTORCH_METAL) + if(IOS) + list(APPEND INSTALL_HEADERS ${metal_h} ${native_metal_h}) + else() + list(APPEND INSTALL_HEADERS ${metal_h} ${metal_prepack_h}) + endif() + endif() endif() # https://stackoverflow.com/questions/11096471/how-can-i-install-a-hierarchy-of-files-using-cmake @@ -407,6 +440,8 @@ endif() list(APPEND ATen_MOBILE_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/benchmarks/tensor_add.cpp) +list(APPEND ATen_MOBILE_BENCHMARK_SRCS + ${CMAKE_CURRENT_SOURCE_DIR}/benchmarks/quantize_per_channel.cpp) list(APPEND ATen_MOBILE_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/benchmarks/stateful_conv1d.cpp) diff --git a/aten/src/ATen/CPUGeneratorImpl.cpp b/aten/src/ATen/CPUGeneratorImpl.cpp index bfa4a2a8f72f6..ff4a2f1c61e27 100644 --- a/aten/src/ATen/CPUGeneratorImpl.cpp +++ b/aten/src/ATen/CPUGeneratorImpl.cpp @@ -1,4 +1,6 @@ #include +#include +#include #include #include @@ -6,6 +8,42 @@ namespace at { namespace detail { +/** + * CPUGeneratorImplStateLegacy is a POD class needed for memcpys + * in torch.get_rng_state() and torch.set_rng_state(). + * It is a legacy class and even though it is replaced with + * at::CPUGeneratorImpl, we need this class and some of its fields + * to support backward compatibility on loading checkpoints. + */ +struct CPUGeneratorImplStateLegacy { + /* The initial seed. */ + uint64_t the_initial_seed; + int left; /* = 1; */ + int seeded; /* = 0; */ + uint64_t next; + uint64_t state[at::MERSENNE_STATE_N]; /* the array for the state vector */ + + /********************************/ + + /* For normal distribution */ + double normal_x; + double normal_y; + double normal_rho; + int normal_is_valid; /* = 0; */ +}; + +/** + * CPUGeneratorImplState is a POD class containing + * new data introduced in at::CPUGeneratorImpl and the legacy state. It is used + * as a helper for torch.get_rng_state() and torch.set_rng_state() + * functions. + */ +struct CPUGeneratorImplState { + CPUGeneratorImplStateLegacy legacy_pod; + float next_float_normal_sample; + bool is_next_float_normal_sample_valid; +}; + /** * PyTorch maintains a collection of default generators that get * initialized once. The purpose of these default generators is to @@ -75,6 +113,128 @@ uint64_t CPUGeneratorImpl::seed() { return random; } +/** + * Sets the internal state of CPUGeneratorImpl. The new internal state + * must be a strided CPU byte tensor and of the same size as either + * CPUGeneratorImplStateLegacy (for legacy CPU generator state) or + * CPUGeneratorImplState (for new state). + * + * FIXME: Remove support of the legacy state in the future? + */ +void CPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) { + using detail::CPUGeneratorImplState; + using detail::CPUGeneratorImplStateLegacy; + + static_assert(std::is_pod::value, "CPUGeneratorImplStateLegacy is not a PODType"); + static_assert(std::is_pod::value, "CPUGeneratorImplState is not a PODType"); + + static const size_t size_legacy = sizeof(CPUGeneratorImplStateLegacy); + static const size_t size_current = sizeof(CPUGeneratorImplState); + static_assert(size_legacy != size_current, "CPUGeneratorImplStateLegacy and CPUGeneratorImplState can't be of the same size"); + + detail::check_rng_state(new_state); + + at::mt19937 engine; + auto float_normal_sample = c10::optional(); + auto double_normal_sample = c10::optional(); + + // Construct the state of at::CPUGeneratorImpl based on input byte tensor size. + CPUGeneratorImplStateLegacy* legacy_pod; + auto new_state_size = new_state.numel(); + if (new_state_size == size_legacy) { + legacy_pod = (CPUGeneratorImplStateLegacy*)new_state.data(); + // Note that in CPUGeneratorImplStateLegacy, we didn't have float version + // of normal sample and hence we leave the c10::optional as is + + // Update next_double_normal_sample. + // Note that CPUGeneratorImplStateLegacy stores two uniform values (normal_x, normal_y) + // and a rho value (normal_rho). These three values were redundant and in the new + // DistributionsHelper.h, we store the actual extra normal sample, rather than three + // intermediate values. + if (legacy_pod->normal_is_valid) { + auto r = legacy_pod->normal_rho; + auto theta = 2.0 * M_PI * legacy_pod->normal_x; + // we return the sin version of the normal sample when in caching mode + double_normal_sample = c10::optional(r * ::sin(theta)); + } + } else if (new_state_size == size_current) { + auto rng_state = (CPUGeneratorImplState*)new_state.data(); + legacy_pod = &rng_state->legacy_pod; + // update next_float_normal_sample + if (rng_state->is_next_float_normal_sample_valid) { + float_normal_sample = c10::optional(rng_state->next_float_normal_sample); + } + + // Update next_double_normal_sample. + // Note that in getRNGState, we now return the actual normal sample in normal_y + // and if it's valid in normal_is_valid. The redundant normal_x and normal_rho + // are squashed to 0.0. + if (legacy_pod->normal_is_valid) { + double_normal_sample = c10::optional(legacy_pod->normal_y); + } + } else { + AT_ERROR("Expected either a CPUGeneratorImplStateLegacy of size ", size_legacy, + " or a CPUGeneratorImplState of size ", size_current, + " but found the input RNG state size to be ", new_state_size); + } + + // construct engine_ + // Note that CPUGeneratorImplStateLegacy stored a state array of 64 bit uints, whereas in our + // redefined mt19937, we have changed to a state array of 32 bit uints. Hence, we are + // doing a std::copy. + at::mt19937_data_pod rng_data; + std::copy(std::begin(legacy_pod->state), std::end(legacy_pod->state), rng_data.state_.begin()); + rng_data.seed_ = legacy_pod->the_initial_seed; + rng_data.left_ = legacy_pod->left; + rng_data.seeded_ = legacy_pod->seeded; + rng_data.next_ = static_cast(legacy_pod->next); + engine.set_data(rng_data); + TORCH_CHECK(engine.is_valid(), "Invalid mt19937 state"); + this->engine_ = engine; + this->next_float_normal_sample_ = float_normal_sample; + this->next_double_normal_sample_ = double_normal_sample; +} + +/** + * Gets the current internal state of CPUGeneratorImpl. The internal + * state is returned as a CPU byte tensor. + */ +c10::intrusive_ptr CPUGeneratorImpl::get_state() const { + using detail::CPUGeneratorImplState; + + static const size_t size = sizeof(CPUGeneratorImplState); + static_assert(std::is_pod::value, "CPUGeneratorImplState is not a PODType"); + + auto state_tensor = at::detail::empty_cpu({(int64_t)size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); + auto rng_state = state_tensor.data_ptr(); + + // accumulate generator data to be copied into byte tensor + auto accum_state = std::make_unique(); + auto rng_data = this->engine_.data(); + accum_state->legacy_pod.the_initial_seed = rng_data.seed_; + accum_state->legacy_pod.left = rng_data.left_; + accum_state->legacy_pod.seeded = rng_data.seeded_; + accum_state->legacy_pod.next = rng_data.next_; + std::copy(rng_data.state_.begin(), rng_data.state_.end(), std::begin(accum_state->legacy_pod.state)); + accum_state->legacy_pod.normal_x = 0.0; // we don't use it anymore and this is just a dummy + accum_state->legacy_pod.normal_rho = 0.0; // we don't use it anymore and this is just a dummy + accum_state->legacy_pod.normal_is_valid = false; + accum_state->legacy_pod.normal_y = 0.0; + accum_state->next_float_normal_sample = 0.0f; + accum_state->is_next_float_normal_sample_valid = false; + if (this->next_double_normal_sample_) { + accum_state->legacy_pod.normal_is_valid = true; + accum_state->legacy_pod.normal_y = *(this->next_double_normal_sample_); + } + if (this->next_float_normal_sample_) { + accum_state->is_next_float_normal_sample_valid = true; + accum_state->next_float_normal_sample = *(this->next_float_normal_sample_); + } + + memcpy(rng_state, accum_state.get(), size); + return state_tensor.getIntrusivePtr(); +} + /** * Gets the DeviceType of CPUGeneratorImpl. * Used for type checking during run time. diff --git a/aten/src/ATen/CPUGeneratorImpl.h b/aten/src/ATen/CPUGeneratorImpl.h index 04119d121b244..f8b43a04c73c0 100644 --- a/aten/src/ATen/CPUGeneratorImpl.h +++ b/aten/src/ATen/CPUGeneratorImpl.h @@ -7,7 +7,7 @@ namespace at { -struct CAFFE2_API CPUGeneratorImpl : public c10::GeneratorImpl { +struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl { // Constructors CPUGeneratorImpl(uint64_t seed_in = default_rng_seed_val); ~CPUGeneratorImpl() = default; @@ -17,6 +17,8 @@ struct CAFFE2_API CPUGeneratorImpl : public c10::GeneratorImpl { void set_current_seed(uint64_t seed) override; uint64_t current_seed() const override; uint64_t seed() override; + void set_state(const c10::TensorImpl& new_state) override; + c10::intrusive_ptr get_state() const override; static DeviceType device_type(); uint32_t random(); uint64_t random64(); @@ -36,8 +38,8 @@ struct CAFFE2_API CPUGeneratorImpl : public c10::GeneratorImpl { namespace detail { -CAFFE2_API const Generator& getDefaultCPUGenerator(); -CAFFE2_API Generator createCPUGenerator(uint64_t seed_val = default_rng_seed_val); +TORCH_API const Generator& getDefaultCPUGenerator(); +TORCH_API Generator createCPUGenerator(uint64_t seed_val = default_rng_seed_val); } // namespace detail diff --git a/aten/src/ATen/CUDAGeneratorImpl.h b/aten/src/ATen/CUDAGeneratorImpl.h index 57ace5f63bcc8..1179a049aa081 100644 --- a/aten/src/ATen/CUDAGeneratorImpl.h +++ b/aten/src/ATen/CUDAGeneratorImpl.h @@ -2,10 +2,122 @@ #include #include +#include +#include +#include // TODO: this file should be in ATen/cuda, not top level namespace at { +/** + * Note [CUDA Graph-safe RNG states] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * + * Strategy: + * ~~~~~~~~~ + * A CUDA graph containing multiple RNG ops behaves like a + * single giant kernel from the perspective of ops external + * to the graph. During graph capture, logic below records + * the total of all offset increments that occur in the graphed + * region, and records the final total as the offset for the + * entire graph. + * + * When the graph reruns, the logic that reruns it + * increments this device's CUDA generator's offset + * by that total. + * + * Meanwhile, within the graph, at capture time, instead of + * populating PhiloxCudaStates with the uint64_t offset pulled + * directly from the global state, PhiloxCudaState instead + * holds a pointer to one-element stream-local int64_t device tensor + * holding an initial offset value, and a uint64_t holding an + * intra-graph offset. (The intra-graph offset starts from zero + * when capture begins.) In each consumer kernel, + * at::cuda::philox::unpack computes the offset to use for this kernel + * as intra-graph offset + *initial offset. + * + * When the graph reruns, the logic that reruns it first + * fill_s the initial offset tensor with this device's + * CUDA generator's current offset. + * + * The control flow above ensures graphed execution is bitwise + * identical to eager execution as long as RNG ops are enqueued + * from a single thread, even if RNG ops and graphs containing + * RNG ops are enqueued and run simultaneously on multiple streams. + * + * Usage: + * ~~~~~~ + * PhiloxCudaState in this file, and unpack() in + * cuda/CUDAGraphsUtils.cuh allow non-divergent use of + * CUDAGeneratorImpl whether graph capture is underway or not. + * + * Each PhiloxCudaState instance should be used for one and only one + * consumer kernel. + * + * Example (see e.g. native/cuda/Dropout.cu): + * + * #include + * #include + * + * __global__ void kernel(..., PhiloxCudaState philox_args) { + * auto seeds = at::cuda::philox::unpack(philox_args); + * IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; + * curandStatePhilox4_32_10_t state; + * curand_init(std::get<0>(seeds), // seed + * idx, // per-thread subsequence + * std::get<1>(seeds), // offset in subsequence + * &state); + * ... + * } + * + * host_caller(...) { + * PhiloxCudaState rng_engine_inputs; + * { + * // See Note [Acquire lock when using random generators] + * std::lock_guard lock(gen->mutex_); + * + * // gen could be HostState or DevState here! No divergent code needed! + * rng_engine_inputs = gen->philox_cuda_state(offset_increment); + * } + * kernel<<<...>>>(..., rng_engine_inputs); + * } + * + */ + + +// Stores state values. Passed as a kernel argument. See "Usage:" above. +struct PhiloxCudaState { + PhiloxCudaState() = default; + PhiloxCudaState(const PhiloxCudaState&) = default; + // Called if graph capture is not underway + PhiloxCudaState(uint64_t seed, + uint64_t offset) { + seed_ = seed; + offset_.val = offset; + } + // Called if graph capture is underway + PhiloxCudaState(uint64_t seed, + int64_t* offset_extragraph, + uint32_t offset_intragraph) { + seed_ = seed; + offset_.ptr = offset_extragraph; + offset_intragraph_ = offset_intragraph; + captured_ = true; + } + + // Public members, directly accessible by at::cuda::philox::unpack. + // If we made them private with getters/setters, the getters/setters + // would have to be __device__, and we can't declare __device__ in ATen. + union Payload { + uint64_t val; + int64_t* ptr; + }; + + uint64_t seed_; + Payload offset_; + uint32_t offset_intragraph_; + bool captured_ = false; +}; struct TORCH_CUDA_API CUDAGeneratorImpl : public c10::GeneratorImpl { // Constructors @@ -17,15 +129,27 @@ struct TORCH_CUDA_API CUDAGeneratorImpl : public c10::GeneratorImpl { void set_current_seed(uint64_t seed) override; uint64_t current_seed() const override; uint64_t seed() override; + void set_state(const c10::TensorImpl& new_state) override; + c10::intrusive_ptr get_state() const override; void set_philox_offset_per_thread(uint64_t offset); - uint64_t philox_offset_per_thread(); + uint64_t philox_offset_per_thread() const; + void capture_prologue(int64_t* offset_extragraph); + uint64_t capture_epilogue(); + PhiloxCudaState philox_cuda_state(uint64_t increment); + + // Temporarily accommodates call sites that use philox_engine_inputs. + // Allows incremental refactor of call sites to use philox_cuda_state. std::pair philox_engine_inputs(uint64_t increment); + static DeviceType device_type(); private: CUDAGeneratorImpl* clone_impl() const override; uint64_t seed_ = default_rng_seed_val; uint64_t philox_offset_per_thread_ = 0; + int64_t* offset_extragraph_; + uint32_t offset_intragraph_ = 0; + bool graph_expects_this_gen_ = false; }; namespace cuda { diff --git a/aten/src/ATen/Config.h.in b/aten/src/ATen/Config.h.in index 58c06c63535dc..38326491bed83 100644 --- a/aten/src/ATen/Config.h.in +++ b/aten/src/ATen/Config.h.in @@ -8,6 +8,7 @@ #define AT_MKLDNN_ENABLED() @AT_MKLDNN_ENABLED@ #define AT_MKL_ENABLED() @AT_MKL_ENABLED@ +#define AT_FFTW_ENABLED() @AT_FFTW_ENABLED@ #define AT_NNPACK_ENABLED() @AT_NNPACK_ENABLED@ #define CAFFE2_STATIC_LINK_CUDA() @CAFFE2_STATIC_LINK_CUDA_INT@ #define AT_BUILD_WITH_BLAS() @USE_BLAS@ diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 18673877c2192..e17322e1681dc 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -64,6 +65,11 @@ bool Context::deterministic() const { } void Context::setDeterministic(bool b) { + if (b) { + TORCH_WARN_ONCE("torch.set_deterministic is in beta, and its design and " + " functionality may change in the future."); + } + _deterministic = b; } @@ -227,7 +233,7 @@ bool Context::setFlushDenormal(bool on) { } Allocator* getCPUAllocator() { - return getTHDefaultAllocator(); + return c10::GetCPUAllocator(); } // override_allow_tf32_flag = true diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index fed5e88e53144..276bf16a2a53a 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -21,7 +21,7 @@ namespace at { class Tensor; -class CAFFE2_API Context { +class TORCH_API Context { public: Context(); @@ -225,13 +225,13 @@ class CAFFE2_API Context { std::unique_ptr thh_state; }; -CAFFE2_API Context& globalContext(); +TORCH_API Context& globalContext(); static inline void init() { globalContext(); } -CAFFE2_API Allocator* getCPUAllocator(); +TORCH_API Allocator* getCPUAllocator(); static inline DeprecatedTypeProperties& getDeprecatedTypeProperties(Backend p, ScalarType s) { return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( @@ -314,10 +314,12 @@ static inline void manual_seed(uint64_t seed) { } // NB: Sometimes we build with CUDA, but we don't have any GPUs // available. In that case, we must not seed CUDA; it will fail! - int num_gpus = detail::getCUDAHooks().getNumGPUs(); + const auto num_gpus = detail::getCUDAHooks().getNumGPUs(); if (hasCUDA() && num_gpus > 0) { for (int i = 0; i < num_gpus; i++) { - auto cuda_gen = globalContext().defaultGenerator(Device(at::kCUDA, i)); + auto cuda_gen = globalContext().defaultGenerator( + Device(at::kCUDA, static_cast(i)) + ); { // See Note [Acquire lock when using random generators] std::lock_guard lock(cuda_gen.mutex()); diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index 786fe6214dc3c..fd045960b52c0 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -43,13 +43,10 @@ DLDataType getDLDataType(const Tensor& t) { throw std::logic_error("BFloat16 is not supported by dlpack"); break; case ScalarType::QInt8: - throw std::logic_error("QInt8 is not supported by dlpack"); - break; case ScalarType::QUInt8: - throw std::logic_error("QUInt8 is not supported by dlpack"); - break; case ScalarType::QInt32: - throw std::logic_error("QInt32 is not supported by dlpack"); + case ScalarType::QUInt4x2: + throw std::logic_error("QUInt/QInt types are not supported by dlpack"); break; case ScalarType::ComplexHalf: throw std::logic_error("ComplexHalf is not supported by dlpack"); diff --git a/aten/src/ATen/DLConvertor.h b/aten/src/ATen/DLConvertor.h index 8458e6ec2d6b6..a34d4b3e7a4df 100644 --- a/aten/src/ATen/DLConvertor.h +++ b/aten/src/ATen/DLConvertor.h @@ -10,10 +10,10 @@ namespace at { -CAFFE2_API ScalarType toScalarType(const DLDataType& dtype); -CAFFE2_API DLManagedTensor* toDLPack(const Tensor& src); -CAFFE2_API Tensor fromDLPack(const DLManagedTensor* src); -CAFFE2_API DLDataType getDLDataType(const Tensor& t); -CAFFE2_API DLContext getDLContext(const Tensor& tensor, const int64_t& device_id); +TORCH_API ScalarType toScalarType(const DLDataType& dtype); +TORCH_API DLManagedTensor* toDLPack(const Tensor& src); +TORCH_API Tensor fromDLPack(const DLManagedTensor* src); +TORCH_API DLDataType getDLDataType(const Tensor& t); +TORCH_API DLContext getDLContext(const Tensor& tensor, const int64_t& device_id); } //namespace at diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index 53a22db6ff9c9..341e20cab1f3d 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -2,17 +2,63 @@ #include #include +#include #include #include #include +#include #include +#include -#define AT_PRIVATE_CASE_TYPE(enum_type, type, ...) \ - case enum_type: { \ - using scalar_t = type; \ - return __VA_ARGS__(); \ +#ifdef XPLAT_MOBILE_BUILD +#include +#else +namespace at { +/** + * The method should_include_kernel_dtype() returns true/false + * based on whether the switching code for a specific dtype should be + * included based on build time constants generated from tracing model + * execution. This method will be implmeneted via code-generation and + * included in this file when code-gen is ready. + */ +inline constexpr bool should_include_kernel_dtype( + const char *kernel_tag_str, + at::ScalarType scalar_type +) { + return true; +} +} +#endif + +/** + * In the Facebook internal build (using BUCK), this macro is enabled by + * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer + * binary. + */ +#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE +#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \ + {RECORD_FUNCTION_WITH_SCOPE( \ + at::RecordScope::KERNEL_FUNCTION_DTYPE, \ + std::string(NAME) + "$" + toString(enum_type), \ + {});} +#else +#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) +#endif + +#define AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, HINT, ...) \ + case enum_type: { \ + at::guts::if_constexpr<(!at::should_include_kernel_dtype(NAME, enum_type))>( \ + [&] { \ + AT_ERROR("dtype '", toString(enum_type), "' not selected for kernel tag ", #NAME); \ + } \ + ); \ + using HINT = type; \ + return __VA_ARGS__(); \ } +#define AT_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \ + AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, scalar_t, __VA_ARGS__) + // Workaround for C10_UNUSED because CUDA 10.1 and below fails to handle unused // attribute in the type aliasing context. Keep name long and verbose to avoid // macro collisions. @@ -31,28 +77,25 @@ const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \ const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \ toUnderlying(enum_type); \ + (void)SCALAR_TYPE; /* Suppress unused-var compiler warning */ \ + /* TODO: Use [[maybe-unused]] when C++17 becomes the standard */ \ return __VA_ARGS__(); \ } -// This macro should be used to skip bfloat16 dispatch on non-ROCm platforms and -// should be removed once the bfloat16 bringup is complete on other platforms. -// This is supposed to be used as a wrapper around the lambda function passed to -// the dispatch macro and will conditionally dispatch ops with bfloat16 type -// only on ROCm. -#if !defined(__HIP_PLATFORM_HCC__) -#define AT_SKIP_BFLOAT16_IF_NOT_ROCM(SCALARTYPE, NAME, ...) \ - if (std::is_same::value) { \ - AT_ERROR( \ - #NAME, \ - " not implemented for '", \ - toString(at::ScalarType::BFloat16), \ - "'"); \ - } else { \ - return __VA_ARGS__(); \ +#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + enum_type, type, underlying_type, bitwidth, qmin, qmax, ...) \ + case enum_type: { \ + using scalar_t = type; \ + using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \ + scalar_t::underlying; \ + const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \ + const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \ + toUnderlying(enum_type); \ + int bit_width = bitwidth; \ + int64_t quant_min = qmin; \ + int64_t quant_max = qmax; \ + return __VA_ARGS__(); \ } -#else -#define AT_SKIP_BFLOAT16_IF_NOT_ROCM(SCALARTYPE, NAME, ...) return __VA_ARGS__() -#endif namespace detail { @@ -126,6 +169,21 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} // 4. Should complex be supported? The answer is almost always no, // unless you are working on "generic" code that should work on // all dtypes. +// +// Parameters: +// ----------- +// +// 1. The NAME argument is a "tag" that is used to trace and then +// conditionally compile fragments of the case statements such +// that the kernel functions are specialized only for the dtypes +// that are needed. The NAME parameter *must* be a build time +// cons char* (can't be std::string, etc...) +// +// Please ensure that the NAME is unique for every implementation +// or you run the risk of over-including code for the kernel +// functions. There is no risk of missing out on any code, so +// it's mostly a risk of a Type-2 error, and not a Type-1 error. +// // NB: the the_type variable is not used, but we have kept it for // backwards compatibility. It's probably not used by anyone though; @@ -137,26 +195,28 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() -#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ +#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op */ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Half, at::Half, __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ + } \ }() #define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ @@ -164,10 +224,11 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, \ SCALARTYPE, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ @@ -182,14 +243,17 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ @@ -203,13 +267,20 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexDouble, \ + c10::complex, \ + __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexFloat, \ + c10::complex, \ + __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ @@ -221,14 +292,18 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ @@ -242,19 +317,28 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} [&] { \ const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, \ + at::ScalarType::ComplexDouble, \ + c10::complex, \ + __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, \ + at::ScalarType::ComplexFloat, \ + c10::complex, \ + __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ @@ -268,31 +352,36 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() -#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ - [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ +#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op */ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, \ SCALARTYPE, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() @@ -301,17 +390,18 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ - switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ - } \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ + } \ }() #define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \ @@ -319,11 +409,18 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexFloat, \ + c10::complex, \ + __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexDouble, \ + c10::complex, \ + __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ @@ -334,6 +431,7 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ AT_QINT_PRIVATE_CASE_TYPE( \ at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__) \ @@ -346,22 +444,43 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} } \ }() +#define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op */ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQInt8, at::qint8, int8_t, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQUInt8, at::quint8, uint8_t, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQInt32, at::qint32, int, CHAR_BIT * sizeof(int), INT_MIN, INT_MAX, __VA_ARGS__) \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQUInt4x2, at::quint4x2, uint8_t, 4, 0, 15, __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } \ + }() + #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \ [&] { \ const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op*/ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, \ at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ + AT_PRIVATE_CASE_TYPE(NAME, \ at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ @@ -370,154 +489,210 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} #define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexFloat, \ + c10::complex, \ + __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ + NAME, \ + at::ScalarType::ComplexDouble, \ + c10::complex, \ + __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() #define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \ SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", TYPE, "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() -#define AT_DISPATCH_ALL_TYPES_AND3( \ - SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ - [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ +#define AT_DISPATCH_ALL_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE3, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op*/ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ + switch (_st) { \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, at::ScalarType::ComplexFloat, c10::complex, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + NAME, at::ScalarType::ComplexDouble, c10::complex, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE1, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE2, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE( \ + NAME, \ SCALARTYPE3, \ decltype(c10::impl::ScalarTypeToCPPType::t), \ __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", TYPE, "'"); \ + AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ + } \ + }() + +#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& the_index_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op */ \ + at::ScalarType _it = ::detail::scalar_type(the_index_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _it) \ + switch (_it) { \ + AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, at::ScalarType::Int, int32_t, index_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT(NAME, at::ScalarType::Long, int64_t, index_t, __VA_ARGS__)\ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(_it), "'"); \ } \ }() @@ -531,15 +706,16 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \ switch (_st) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Half, at::Half, __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ diff --git a/aten/src/ATen/DynamicLibrary.cpp b/aten/src/ATen/DynamicLibrary.cpp index d47d2bf7c5cb9..09ce15eac33b0 100644 --- a/aten/src/ATen/DynamicLibrary.cpp +++ b/aten/src/ATen/DynamicLibrary.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -6,7 +7,7 @@ #include #include #else -#include +#include #endif namespace at { @@ -48,10 +49,11 @@ DynamicLibrary::DynamicLibrary(const char* name) { // NOLINTNEXTLINE(hicpp-signed-bitwise) HMODULE theModule; bool reload = true; + auto wname = c10::u8u16(name); // Check if LOAD_LIBRARY_SEARCH_DEFAULT_DIRS is supported - if (GetProcAddress(GetModuleHandle("KERNEL32.DLL"), "AddDllDirectory") != NULL) { - theModule = LoadLibraryExA( - name, + if (GetProcAddress(GetModuleHandleW(L"KERNEL32.DLL"), "AddDllDirectory") != NULL) { + theModule = LoadLibraryExW( + wname.c_str(), NULL, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS); if (theModule != NULL || (GetLastError() != ERROR_MOD_NOT_FOUND)) { @@ -60,7 +62,7 @@ DynamicLibrary::DynamicLibrary(const char* name) { } if (reload) { - theModule = LoadLibraryA(name); + theModule = LoadLibraryW(wname.c_str()); } if (theModule) { diff --git a/aten/src/ATen/DynamicLibrary.h b/aten/src/ATen/DynamicLibrary.h index ea919a79d318b..089503cb9c0cc 100644 --- a/aten/src/ATen/DynamicLibrary.h +++ b/aten/src/ATen/DynamicLibrary.h @@ -8,11 +8,11 @@ namespace at { struct DynamicLibrary { AT_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary); - CAFFE2_API DynamicLibrary(const char* name); + TORCH_API DynamicLibrary(const char* name); - CAFFE2_API void* sym(const char* name); + TORCH_API void* sym(const char* name); - CAFFE2_API ~DynamicLibrary(); + TORCH_API ~DynamicLibrary(); private: void* handle = nullptr; diff --git a/aten/src/ATen/ExpandUtils.cpp b/aten/src/ATen/ExpandUtils.cpp index 0f5cfaec85742..ce21a6ff46e1e 100644 --- a/aten/src/ATen/ExpandUtils.cpp +++ b/aten/src/ATen/ExpandUtils.cpp @@ -85,4 +85,97 @@ std::tuple, std::vector> inferExpandGeometry( expandedSizes, expandedStrides); } + +// This function returns a dense and non-overlapping strides, which keeps the same layout permutation +// as the input `tensor_strides`, computed based on the input `tensor_sizes`. +// Note: +// 1. This function expects the inputs `tensor_strides` and `tensor_sizes` are non-dense or overlapping, +// If the inputs are densed and non-overlapping, the output strides will be the same as `tensor_strides`. +// However, this function won't check whether inputs are dense or overlapping, so the whole function will +// still be executed even the inputs are already dense and non-overlapping, this will cause slowness. +// +// Please verify whether the inputs are non-dense or overlapping before calling this function if possible, +// if the inputs come from a tensor, you can check this through `is_non_overlapping_and_dense()` +// +// 2. The strides propagation rule that is used in this function is exactily the same as what is being used in +// TensorIterator. Please refer to https://github.com/pytorch/pytorch/pull/42922 for more details + +std::vector infer_dense_strides(IntArrayRef tensor_sizes, IntArrayRef tensor_strides) { + + TORCH_CHECK(tensor_sizes.size() == tensor_strides.size(), + "Input sizes and strides should have same size but got ", tensor_sizes.size(), " and ", tensor_strides.size()); + + size_t ndim = tensor_sizes.size(); + if (ndim == 0) { + return {}; + } + if (ndim == 1) { + return {1}; + } + + std::vector perm(ndim); + // initialize perm with n-1, n-2, ..., 1, 0 + std::iota(perm.rbegin(), perm.rend(), 0); + + // The following sorting algorithm has exactly the same behavior as TensorIterator + // This is to make sure we have the same stride propagation everywhere. + + // return -1 if dim0 should come before dim1 + // return 1 if dim0 should come after dim1 + // return 0 if comparison is ambiguous + auto should_swap = [&](size_t dim0, size_t dim1) { + int64_t stride0 = tensor_strides[dim0]; + int64_t stride1 = tensor_strides[dim1]; + + // if any stride is 0, treat it as ambiguous comparison to + // keep the same behavior as TensorIterator + if (stride0 == 0 || stride1 == 0) { + return 0; + } + if (stride0 < stride1) { + return -1; + } + if (stride0 > stride1) { + return 1; + } + // for equal strides, the dimension with smaller size goes front + if (tensor_sizes[dim0] > tensor_sizes[dim1]) { + return 1; + } + return 0; + }; + + // Insertion sort (stable) indices in `perm` based on input tensor's stride and shape, + // all dimensions with 0 stride won't move. This is the same behavior as TensorIterator. + // eg. Given tensor with size/stride (6, 5, 4, 3, 2)/(6, 0, 120, 0, 1), the initial `perm` + // is (4, 3, 2, 1, 0) and the sorted `perm` will be (4, 3, 0, 1, 2) + for (int i = 1; i < ndim; ++i) { + int dim1 = i; + for (int dim0 = i - 1; dim0 >= 0; --dim0) { + int comparison = should_swap(perm[dim0], perm[dim1]); + if (comparison > 0) { + std::swap(perm[dim0], perm[dim1]); + dim1 = dim0; + } + else if (comparison < 0) { + break; + } + } + } + + // compute output strides which preserves the input tensor's memory layout + std::vector out_strides(ndim); + int64_t curr_stride = 1; + for (size_t i = 0; i < ndim; ++i) { + int64_t idx = perm[i]; + out_strides[idx] = curr_stride; + // Note: for size 0, we simply treated it as 1, it really doesn't matter here + // since the total number of element is 0. + if (tensor_sizes[idx] > 1) { + curr_stride *= tensor_sizes[idx]; + } + } + return out_strides; +} + } // namespace at diff --git a/aten/src/ATen/ExpandUtils.h b/aten/src/ATen/ExpandUtils.h index 303456a5b2f91..b03c293c17be7 100644 --- a/aten/src/ATen/ExpandUtils.h +++ b/aten/src/ATen/ExpandUtils.h @@ -9,13 +9,17 @@ namespace at { -CAFFE2_API std::vector infer_size(IntArrayRef a, IntArrayRef b); -CAFFE2_API std::tuple, std::vector> +TORCH_API std::vector infer_size(IntArrayRef a, IntArrayRef b); +TORCH_API std::tuple, std::vector> inferExpandGeometry( IntArrayRef tensor_sizes, IntArrayRef tensor_strides, IntArrayRef sizes); +TORCH_API std::vector infer_dense_strides( + IntArrayRef tensor_sizes, + IntArrayRef tensor_strides); + // True if input shapes are expandable // NOTE: infer_size did a similar check, please keep them sync if change is needed inline bool are_expandable(IntArrayRef shape1, IntArrayRef shape2) { diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.cpp b/aten/src/ATen/LegacyTHFunctionsCPU.cpp index 9136027c4c1e6..364b70ee5775c 100644 --- a/aten/src/ATen/LegacyTHFunctionsCPU.cpp +++ b/aten/src/ATen/LegacyTHFunctionsCPU.cpp @@ -1,7 +1,5 @@ #include -// @generated by aten/src/ATen/gen.py from LegacyTHFunctions.cpp - #include #include #include @@ -39,7 +37,7 @@ namespace { Tensor & _th_masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & source) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - + switch (dispatch_scalar_type) { case ScalarType::Bool: { auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CPU, dispatch_scalar_type); @@ -112,7 +110,7 @@ Tensor & _th_masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & Tensor & _th_masked_scatter_bool_(Tensor & self, const Tensor & mask, const Tensor & source) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - + switch (dispatch_scalar_type) { case ScalarType::Bool: { auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CPU, dispatch_scalar_type); @@ -185,7 +183,7 @@ Tensor & _th_masked_scatter_bool_(Tensor & self, const Tensor & mask, const Tens Tensor & _th_nonzero_out(Tensor & result, const Tensor & self) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - + switch (dispatch_scalar_type) { case ScalarType::Bool: { auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long); @@ -316,7 +314,7 @@ Tensor _th_nonzero(const Tensor & self) { Tensor & _th_index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - + switch (dispatch_scalar_type) { case ScalarType::Bool: { auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_index_copy_", false, DeviceType::CPU, dispatch_scalar_type); @@ -379,135 +377,10 @@ Tensor & _th_index_copy_(Tensor & self, int64_t dim, const Tensor & index, const } return self; } -Tensor & _th_take_out(Tensor & result, const Tensor & self, const Tensor & index) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Bool: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CPU, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CPU, ScalarType::Long); - THBoolTensor_take(result_, self_, index_); - break; - } - case ScalarType::Byte: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CPU, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CPU, ScalarType::Long); - THByteTensor_take(result_, self_, index_); - break; - } - case ScalarType::Char: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CPU, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CPU, ScalarType::Long); - THCharTensor_take(result_, self_, index_); - break; - } - case ScalarType::Double: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CPU, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CPU, ScalarType::Long); - THDoubleTensor_take(result_, self_, index_); - break; - } - case ScalarType::Float: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CPU, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CPU, ScalarType::Long); - THFloatTensor_take(result_, self_, index_); - break; - } - case ScalarType::Int: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CPU, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CPU, ScalarType::Long); - THIntTensor_take(result_, self_, index_); - break; - } - case ScalarType::Long: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CPU, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CPU, ScalarType::Long); - THLongTensor_take(result_, self_, index_); - break; - } - case ScalarType::Short: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CPU, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CPU, ScalarType::Long); - THShortTensor_take(result_, self_, index_); - break; - } - default: - AT_ERROR("_th_take_out not supported on CPUType for ", dispatch_scalar_type); - } - return result; -} -Tensor _th_take(const Tensor & self, const Tensor & index) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - auto result_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CPU, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto result = Tensor(c10::intrusive_ptr::reclaim(result_)); - switch (dispatch_scalar_type) { - case ScalarType::Bool: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CPU, ScalarType::Long); - THBoolTensor_take(result_, self_, index_); - break; - } - case ScalarType::Byte: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CPU, ScalarType::Long); - THByteTensor_take(result_, self_, index_); - break; - } - case ScalarType::Char: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CPU, ScalarType::Long); - THCharTensor_take(result_, self_, index_); - break; - } - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CPU, ScalarType::Long); - THDoubleTensor_take(result_, self_, index_); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CPU, ScalarType::Long); - THFloatTensor_take(result_, self_, index_); - break; - } - case ScalarType::Int: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CPU, ScalarType::Long); - THIntTensor_take(result_, self_, index_); - break; - } - case ScalarType::Long: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CPU, ScalarType::Long); - THLongTensor_take(result_, self_, index_); - break; - } - case ScalarType::Short: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CPU, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CPU, ScalarType::Long); - THShortTensor_take(result_, self_, index_); - break; - } - default: - AT_ERROR("_th_take not supported on CPUType for ", dispatch_scalar_type); - } - return result; -} Tensor & _th_put_(Tensor & self, const Tensor & index, const Tensor & source, bool accumulate) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - + switch (dispatch_scalar_type) { case ScalarType::Bool: { auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_put_", false, DeviceType::CPU, dispatch_scalar_type); @@ -573,7 +446,7 @@ Tensor & _th_put_(Tensor & self, const Tensor & index, const Tensor & source, bo Tensor & _th_index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar value) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - + switch (dispatch_scalar_type) { case ScalarType::Bool: { auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_index_fill_", false, DeviceType::CPU, dispatch_scalar_type); @@ -639,7 +512,7 @@ Tensor & _th_index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scala std::tuple _th_mode_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t dim, bool keepdim) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - + switch (dispatch_scalar_type) { case ScalarType::Byte: { auto values_ = checked_dense_tensor_unwrap(values, "values", 0, "_th_mode_out", false, DeviceType::CPU, dispatch_scalar_type); @@ -746,7 +619,7 @@ std::tuple _th_mode(const Tensor & self, int64_t dim, bool keepdi Tensor _th_var(const Tensor & self, bool unbiased) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - + switch (dispatch_scalar_type) { case ScalarType::Double: { auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_var", false, DeviceType::CPU, dispatch_scalar_type); @@ -765,7 +638,7 @@ Tensor _th_var(const Tensor & self, bool unbiased) { Tensor _th_std(const Tensor & self, bool unbiased) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - + switch (dispatch_scalar_type) { case ScalarType::Double: { auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_std", false, DeviceType::CPU, dispatch_scalar_type); @@ -784,7 +657,7 @@ Tensor _th_std(const Tensor & self, bool unbiased) { Tensor & _th_renorm_out(Tensor & result, const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - + switch (dispatch_scalar_type) { case ScalarType::Double: { auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_renorm_out", false, DeviceType::CPU, dispatch_scalar_type); @@ -835,7 +708,7 @@ Tensor _th_renorm(const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) { Tensor & _th_renorm_(Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - + switch (dispatch_scalar_type) { case ScalarType::Double: { auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_renorm_", false, DeviceType::CPU, dispatch_scalar_type); @@ -859,7 +732,7 @@ Tensor & _th_renorm_(Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) { Tensor & _th_histc_out(Tensor & result, const Tensor & self, int64_t bins, Scalar min, Scalar max) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - + switch (dispatch_scalar_type) { case ScalarType::Double: { auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_histc_out", false, DeviceType::CPU, dispatch_scalar_type); @@ -907,54 +780,11 @@ Tensor _th_histc(const Tensor & self, int64_t bins, Scalar min, Scalar max) { } return result; } -Tensor _th_trace(const Tensor & self) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Byte: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_trace", false, DeviceType::CPU, dispatch_scalar_type); - return at::scalar_tensor(convert(THByteTensor_trace(self_)), options(ScalarType::Byte)); - break; - } - case ScalarType::Char: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_trace", false, DeviceType::CPU, dispatch_scalar_type); - return at::scalar_tensor(convert(THCharTensor_trace(self_)), options(ScalarType::Char)); - break; - } - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_trace", false, DeviceType::CPU, dispatch_scalar_type); - return at::scalar_tensor(convert(THDoubleTensor_trace(self_)), options(ScalarType::Double)); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_trace", false, DeviceType::CPU, dispatch_scalar_type); - return at::scalar_tensor(convert(THFloatTensor_trace(self_)), options(ScalarType::Float)); - break; - } - case ScalarType::Int: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_trace", false, DeviceType::CPU, dispatch_scalar_type); - return at::scalar_tensor(convert(THIntTensor_trace(self_)), options(ScalarType::Int)); - break; - } - case ScalarType::Long: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_trace", false, DeviceType::CPU, dispatch_scalar_type); - return at::scalar_tensor(convert(THLongTensor_trace(self_)), options(ScalarType::Long)); - break; - } - case ScalarType::Short: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_trace", false, DeviceType::CPU, dispatch_scalar_type); - return at::scalar_tensor(convert(THShortTensor_trace(self_)), options(ScalarType::Short)); - break; - } - default: - AT_ERROR("_th_trace not supported on CPUType for ", dispatch_scalar_type); - } -} + std::tuple _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - + switch (dispatch_scalar_type) { case ScalarType::Double: { auto res1_ = checked_dense_tensor_unwrap(res1, "res1", 0, "_th_gels_out", false, DeviceType::CPU, dispatch_scalar_type); @@ -1002,57 +832,10 @@ std::tuple _th_gels(const Tensor & self, const Tensor & A) { } return std::tuple(res1, res2); } -std::tuple _th_eig_out(Tensor & res1, Tensor & res2, const Tensor & self, bool eigenvectors) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto res1_ = checked_dense_tensor_unwrap(res1, "res1", 0, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type); - auto res2_ = checked_dense_tensor_unwrap(res2, "res2", 0, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type); - THDoubleTensor_geev(res1_, res2_, self_, eigenvectors); - break; - } - case ScalarType::Float: { - auto res1_ = checked_dense_tensor_unwrap(res1, "res1", 0, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type); - auto res2_ = checked_dense_tensor_unwrap(res2, "res2", 0, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type); - THFloatTensor_geev(res1_, res2_, self_, eigenvectors); - break; - } - default: - AT_ERROR("_th_eig_out not supported on CPUType for ", dispatch_scalar_type); - } - return std::tuple(res1, res2); -} -std::tuple _th_eig(const Tensor & self, bool eigenvectors) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - auto res1_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CPU, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto res1 = Tensor(c10::intrusive_ptr::reclaim(res1_)); - auto res2_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CPU, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto res2 = Tensor(c10::intrusive_ptr::reclaim(res2_)); - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig", false, DeviceType::CPU, dispatch_scalar_type); - THDoubleTensor_geev(res1_, res2_, self_, eigenvectors); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig", false, DeviceType::CPU, dispatch_scalar_type); - THFloatTensor_geev(res1_, res2_, self_, eigenvectors); - break; - } - default: - AT_ERROR("_th_eig not supported on CPUType for ", dispatch_scalar_type); - } - return std::tuple(res1, res2); -} Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - + switch (dispatch_scalar_type) { case ScalarType::Double: { auto output_ = checked_dense_tensor_unwrap(output, "output", 0, "_th_potri_out", false, DeviceType::CPU, dispatch_scalar_type); @@ -1095,7 +878,7 @@ Tensor _th_potri(const Tensor & self, bool upper) { std::tuple _th_geqrf_out(Tensor & res1, Tensor & res2, const Tensor & self) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - + switch (dispatch_scalar_type) { case ScalarType::Double: { auto res1_ = checked_dense_tensor_unwrap(res1, "res1", 0, "_th_geqrf_out", false, DeviceType::CPU, dispatch_scalar_type); @@ -1142,7 +925,7 @@ std::tuple _th_geqrf(const Tensor & self) { Tensor & _th_orgqr_out(Tensor & result, const Tensor & self, const Tensor & input2) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - + switch (dispatch_scalar_type) { case ScalarType::Double: { auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_orgqr_out", false, DeviceType::CPU, dispatch_scalar_type); @@ -1189,7 +972,7 @@ Tensor _th_orgqr(const Tensor & self, const Tensor & input2) { Tensor & _th_ormqr_out(Tensor & result, const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - + switch (dispatch_scalar_type) { case ScalarType::Double: { auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_ormqr_out", false, DeviceType::CPU, dispatch_scalar_type); @@ -1237,100 +1020,6 @@ Tensor _th_ormqr(const Tensor & self, const Tensor & input2, const Tensor & inpu } return result; } -std::tuple _th_multinomial_alias_setup_out(Tensor & J, Tensor & q, const Tensor & probs) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(J); - - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto probs_ = checked_dense_tensor_unwrap(probs, "probs", 1, "_th_multinomial_alias_setup_out", false, DeviceType::CPU, dispatch_scalar_type); - auto J_ = checked_dense_tensor_unwrap(J, "J", 1, "_th_multinomial_alias_setup_out", false, DeviceType::CPU, ScalarType::Long); - auto q_ = checked_dense_tensor_unwrap(q, "q", 1, "_th_multinomial_alias_setup_out", false, DeviceType::CPU, dispatch_scalar_type); - THDoubleTensor_multinomialAliasSetup(probs_, J_, q_); - break; - } - case ScalarType::Float: { - auto probs_ = checked_dense_tensor_unwrap(probs, "probs", 1, "_th_multinomial_alias_setup_out", false, DeviceType::CPU, dispatch_scalar_type); - auto J_ = checked_dense_tensor_unwrap(J, "J", 1, "_th_multinomial_alias_setup_out", false, DeviceType::CPU, ScalarType::Long); - auto q_ = checked_dense_tensor_unwrap(q, "q", 1, "_th_multinomial_alias_setup_out", false, DeviceType::CPU, dispatch_scalar_type); - THFloatTensor_multinomialAliasSetup(probs_, J_, q_); - break; - } - default: - AT_ERROR("_th_multinomial_alias_setup_out not supported on CPUType for ", dispatch_scalar_type); - } - return std::tuple(J, q); -} -std::tuple _th_multinomial_alias_setup(const Tensor & probs) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(probs); - auto J_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CPU, scalarTypeToTypeMeta(ScalarType::Long)).release(); - auto J = Tensor(c10::intrusive_ptr::reclaim(J_)); - auto q_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CPU, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto q = Tensor(c10::intrusive_ptr::reclaim(q_)); - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto probs_ = checked_dense_tensor_unwrap(probs, "probs", 1, "_th_multinomial_alias_setup", false, DeviceType::CPU, dispatch_scalar_type); - THDoubleTensor_multinomialAliasSetup(probs_, J_, q_); - break; - } - case ScalarType::Float: { - auto probs_ = checked_dense_tensor_unwrap(probs, "probs", 1, "_th_multinomial_alias_setup", false, DeviceType::CPU, dispatch_scalar_type); - THFloatTensor_multinomialAliasSetup(probs_, J_, q_); - break; - } - default: - AT_ERROR("_th_multinomial_alias_setup not supported on CPUType for ", dispatch_scalar_type); - } - return std::tuple(J, q); -} -Tensor & _th_multinomial_alias_draw_out(Tensor & result, const Tensor & q, const Tensor & J, int64_t num_samples, c10::optional generator) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(result); - - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_multinomial_alias_draw_out", false, DeviceType::CPU, ScalarType::Long); - auto q_ = checked_dense_tensor_unwrap(q, "q", 1, "_th_multinomial_alias_draw_out", false, DeviceType::CPU, dispatch_scalar_type); - auto J_ = checked_dense_tensor_unwrap(J, "J", 2, "_th_multinomial_alias_draw_out", false, DeviceType::CPU, ScalarType::Long); - THDoubleTensor_multinomialAliasDraw(result_, q_, J_, num_samples, generator); - break; - } - case ScalarType::Float: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_multinomial_alias_draw_out", false, DeviceType::CPU, ScalarType::Long); - auto q_ = checked_dense_tensor_unwrap(q, "q", 1, "_th_multinomial_alias_draw_out", false, DeviceType::CPU, dispatch_scalar_type); - auto J_ = checked_dense_tensor_unwrap(J, "J", 2, "_th_multinomial_alias_draw_out", false, DeviceType::CPU, ScalarType::Long); - THFloatTensor_multinomialAliasDraw(result_, q_, J_, num_samples, generator); - break; - } - default: - AT_ERROR("_th_multinomial_alias_draw_out not supported on CPUType for ", dispatch_scalar_type); - } - return result; -} -Tensor _th_multinomial_alias_draw(const Tensor & q, const Tensor & J, int64_t num_samples, c10::optional generator) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(q); - auto result_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CPU, scalarTypeToTypeMeta(ScalarType::Long)).release(); - auto result = Tensor(c10::intrusive_ptr::reclaim(result_)); - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto q_ = checked_dense_tensor_unwrap(q, "q", 1, "_th_multinomial_alias_draw", false, DeviceType::CPU, dispatch_scalar_type); - auto J_ = checked_dense_tensor_unwrap(J, "J", 2, "_th_multinomial_alias_draw", false, DeviceType::CPU, ScalarType::Long); - THDoubleTensor_multinomialAliasDraw(result_, q_, J_, num_samples, generator); - break; - } - case ScalarType::Float: { - auto q_ = checked_dense_tensor_unwrap(q, "q", 1, "_th_multinomial_alias_draw", false, DeviceType::CPU, dispatch_scalar_type); - auto J_ = checked_dense_tensor_unwrap(J, "J", 2, "_th_multinomial_alias_draw", false, DeviceType::CPU, ScalarType::Long); - THFloatTensor_multinomialAliasDraw(result_, q_, J_, num_samples, generator); - break; - } - default: - AT_ERROR("_th_multinomial_alias_draw not supported on CPUType for ", dispatch_scalar_type); - } - return result; -} } // namespace th } // namespace legacy diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.h b/aten/src/ATen/LegacyTHFunctionsCPU.h index 1bc9b66777bc4..9a2ec45efefa8 100644 --- a/aten/src/ATen/LegacyTHFunctionsCPU.h +++ b/aten/src/ATen/LegacyTHFunctionsCPU.h @@ -1,7 +1,5 @@ #pragma once -// @generated by aten/src/ATen/gen.py from LegacyTHFunctions.h - #include #include #include @@ -38,11 +36,8 @@ Tensor _th_renorm(const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm); Tensor & _th_renorm_(Tensor & self, Scalar p, int64_t dim, Scalar maxnorm); Tensor & _th_histc_out(Tensor & result, const Tensor & self, int64_t bins, Scalar min, Scalar max); Tensor _th_histc(const Tensor & self, int64_t bins, Scalar min, Scalar max); -Tensor _th_trace(const Tensor & self); std::tuple _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A); std::tuple _th_gels(const Tensor & self, const Tensor & A); -std::tuple _th_eig_out(Tensor & res1, Tensor & res2, const Tensor & self, bool eigenvectors); -std::tuple _th_eig(const Tensor & self, bool eigenvectors); Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper); Tensor _th_potri(const Tensor & self, bool upper); std::tuple _th_geqrf_out(Tensor & res1, Tensor & res2, const Tensor & self); @@ -51,10 +46,6 @@ Tensor & _th_orgqr_out(Tensor & result, const Tensor & self, const Tensor & inpu Tensor _th_orgqr(const Tensor & self, const Tensor & input2); Tensor & _th_ormqr_out(Tensor & result, const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose); Tensor _th_ormqr(const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose); -std::tuple _th_multinomial_alias_setup_out(Tensor & J, Tensor & q, const Tensor & probs); -std::tuple _th_multinomial_alias_setup(const Tensor & probs); -Tensor & _th_multinomial_alias_draw_out(Tensor & result, const Tensor & q, const Tensor & J, int64_t num_samples, c10::optional generator); -Tensor _th_multinomial_alias_draw(const Tensor & q, const Tensor & J, int64_t num_samples, c10::optional generator); } // namespace th } // namespace legacy diff --git a/aten/src/ATen/LegacyTHFunctionsCUDA.h b/aten/src/ATen/LegacyTHFunctionsCUDA.h index 20717ad43e6f4..a1076b782d2f0 100644 --- a/aten/src/ATen/LegacyTHFunctionsCUDA.h +++ b/aten/src/ATen/LegacyTHFunctionsCUDA.h @@ -1,7 +1,5 @@ #pragma once -// @generated by aten/src/ATen/gen.py from LegacyTHFunctions.h - #include #include #include @@ -31,37 +29,21 @@ Tensor & _th_put_(Tensor & self, const Tensor & index, const Tensor & source, bo Tensor & _th_index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar value); std::tuple _th_mode_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t dim, bool keepdim); std::tuple _th_mode(const Tensor & self, int64_t dim, bool keepdim); -std::tuple _th_sort_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t dim, bool descending); -std::tuple _th_sort(const Tensor & self, int64_t dim, bool descending); +std::tuple _th_sort_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t dim, bool descending, bool stable); +std::tuple _th_sort(const Tensor & self, int64_t dim, bool descending, bool stable); std::tuple _th_topk_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted); std::tuple _th_topk(const Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted); Tensor & _th_renorm_out(Tensor & result, const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm); Tensor _th_renorm(const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm); Tensor & _th_renorm_(Tensor & self, Scalar p, int64_t dim, Scalar maxnorm); -Tensor & _th_fmod_out(Tensor & result, const Tensor & self, Scalar other); -Tensor _th_fmod(const Tensor & self, Scalar other); -Tensor & _th_fmod_out(Tensor & result, const Tensor & self, const Tensor & other); -Tensor _th_fmod(const Tensor & self, const Tensor & other); -Tensor & _th_fmod_(Tensor & self, Scalar other); -Tensor & _th_fmod_(Tensor & self, const Tensor & other); Tensor & _th_cross_kernel_out(Tensor & result, const Tensor & self, const Tensor & other, int64_t dim); Tensor _th_cross_kernel(const Tensor & self, const Tensor & other, int64_t dim); -Tensor & _th_bmm_out(Tensor & result, const Tensor & self, const Tensor & mat2); -Tensor _th_bmm(const Tensor & self, const Tensor & mat2); -Tensor & _th_baddbmm_out(Tensor & result, const Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha); -Tensor _th_baddbmm(const Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha); std::tuple _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A); std::tuple _th_gels(const Tensor & self, const Tensor & A); -std::tuple _th_eig_out(Tensor & res1, Tensor & res2, const Tensor & self, bool eigenvectors); -std::tuple _th_eig(const Tensor & self, bool eigenvectors); Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper); Tensor _th_potri(const Tensor & self, bool upper); std::tuple _th_geqrf_out(Tensor & res1, Tensor & res2, const Tensor & self); std::tuple _th_geqrf(const Tensor & self); -std::tuple _th_multinomial_alias_setup_out(Tensor & J, Tensor & q, const Tensor & probs); -std::tuple _th_multinomial_alias_setup(const Tensor & probs); -Tensor & _th_multinomial_alias_draw_out(Tensor & result, const Tensor & q, const Tensor & J, int64_t num_samples, c10::optional generator); -Tensor _th_multinomial_alias_draw(const Tensor & q, const Tensor & J, int64_t num_samples, c10::optional generator); Tensor & _th_copy_ignoring_overlaps_(Tensor & self, const Tensor & src); Tensor & _thnn_multi_margin_loss_forward_out(Tensor & output, const Tensor & self, const Tensor & target, Scalar p, Scalar margin, const Tensor & weight, int64_t reduction); Tensor _thnn_multi_margin_loss_forward(const Tensor & self, const Tensor & target, Scalar p, Scalar margin, const Tensor & weight, int64_t reduction); @@ -89,7 +71,6 @@ Tensor & _thnn_log_sigmoid_backward_out(Tensor & grad_input, const Tensor & grad Tensor _thnn_log_sigmoid_backward(const Tensor & grad_output, const Tensor & self, const Tensor & buffer); Tensor & _thnn_rrelu_with_noise_forward_out(Tensor & output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training, c10::optional generator); Tensor _thnn_rrelu_with_noise_forward(const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training, c10::optional generator); -Tensor & _thnn_rrelu_with_noise_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training); Tensor _thnn_rrelu_with_noise_backward(const Tensor & grad_output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training); Tensor & _thnn_rrelu_with_noise_forward_(Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training, c10::optional generator); std::tuple _thnn_conv2d_forward_out(Tensor & output, Tensor & columns, Tensor & ones, const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const Tensor & bias, IntArrayRef stride, IntArrayRef padding); diff --git a/aten/src/ATen/MemoryOverlap.cpp b/aten/src/ATen/MemoryOverlap.cpp index 264271d352298..2269d9ae11dcf 100644 --- a/aten/src/ATen/MemoryOverlap.cpp +++ b/aten/src/ATen/MemoryOverlap.cpp @@ -48,6 +48,9 @@ MemOverlapStatus get_overlap_status(TensorImpl* a, TensorImpl* b) { if (!a->is_contiguous() || !b->is_contiguous()) { return MemOverlapStatus::TOO_HARD; } + if (!a->has_storage() || !b->has_storage()) { + return MemOverlapStatus::NO; + } if (a->storage().data() == b->storage().data()) { const auto a_begin = static_cast(a->data()); const auto a_end = a_begin + a->numel() * a->itemsize(); @@ -75,4 +78,16 @@ void assert_no_partial_overlap(TensorImpl* a, TensorImpl* b) { "Please clone() the tensor before performing the operation."); } +void assert_no_overlap(const Tensor& a, const Tensor& b) { + assert_no_overlap(a.unsafeGetTensorImpl(), b.unsafeGetTensorImpl()); +} + +void assert_no_overlap(TensorImpl* a, TensorImpl* b) { + const auto lap = get_overlap_status(a, b); + TORCH_CHECK(lap != MemOverlapStatus::PARTIAL && lap != MemOverlapStatus::FULL, + "unsupported operation: some elements of the input tensor and " + "the written-to tensor refer to a single memory location. " + "Please clone() the tensor before performing the operation."); +} + } diff --git a/aten/src/ATen/MemoryOverlap.h b/aten/src/ATen/MemoryOverlap.h index 67f63a64668c3..f7437c61136ca 100644 --- a/aten/src/ATen/MemoryOverlap.h +++ b/aten/src/ATen/MemoryOverlap.h @@ -15,16 +15,19 @@ enum class MemOverlap { NO, YES, TOO_HARD }; enum class MemOverlapStatus { FULL, PARTIAL, NO, TOO_HARD }; -CAFFE2_API MemOverlap has_internal_overlap(const Tensor& t); -CAFFE2_API MemOverlap has_internal_overlap(TensorImpl* t); +TORCH_API MemOverlap has_internal_overlap(const Tensor& t); +TORCH_API MemOverlap has_internal_overlap(TensorImpl* t); -CAFFE2_API void assert_no_internal_overlap(const Tensor& t); -CAFFE2_API void assert_no_internal_overlap(TensorImpl* t); +TORCH_API void assert_no_internal_overlap(const Tensor& t); +TORCH_API void assert_no_internal_overlap(TensorImpl* t); -CAFFE2_API MemOverlapStatus get_overlap_status(const Tensor& a, const Tensor& b); -CAFFE2_API MemOverlapStatus get_overlap_status(TensorImpl* a, TensorImpl* b); +TORCH_API MemOverlapStatus get_overlap_status(const Tensor& a, const Tensor& b); +TORCH_API MemOverlapStatus get_overlap_status(TensorImpl* a, TensorImpl* b); -CAFFE2_API void assert_no_partial_overlap(const Tensor& a, const Tensor& b); +TORCH_API void assert_no_partial_overlap(const Tensor& a, const Tensor& b); void assert_no_partial_overlap(TensorImpl* a, TensorImpl* b); +TORCH_API void assert_no_overlap(const Tensor& a, const Tensor& b); +TORCH_API void assert_no_overlap(TensorImpl* a, TensorImpl* b); + } diff --git a/aten/src/ATen/NamedTensorUtils.cpp b/aten/src/ATen/NamedTensorUtils.cpp index f59cbed39abb9..5f8de486dc785 100644 --- a/aten/src/ATen/NamedTensorUtils.cpp +++ b/aten/src/ATen/NamedTensorUtils.cpp @@ -264,11 +264,11 @@ static std::vector compute_dot_product_outnames( } std::vector outnames(num_outnames, Dimname::wildcard()); int64_t index = 0; - for (int64_t j = 0; j < tensor_names.size(); ++j) { + for (size_t j = 0; j < tensor_names.size(); ++j) { if (j == tensor_dotted_dim) continue; outnames[index++] = tensor_names[j]; } - for (int64_t j = 0; j < other_names.size(); ++j) { + for (size_t j = 0; j < other_names.size(); ++j) { if (j == other_dotted_dim) continue; outnames[index++] = other_names[j]; } @@ -517,17 +517,16 @@ std::vector compute_bmm_outnames( } std::vector compute_baddbmm_outnames( - TensorImpl* result, - TensorImpl* batch1, - TensorImpl* batch2, - TensorImpl* bias) { - if (!impl::has_names(result) && !impl::has_names(batch1) && - !impl::has_names(batch2) && !impl::has_names(bias)) { + Tensor& result, + const Tensor& self, + const Tensor& other, + const Tensor& bias) { + if (!result.has_names() && !self.has_names() + && !other.has_names() && !bias.has_names()) { return {}; } - auto bmm_names = compute_matmul_outnames( - impl::get_names(batch1), impl::get_names(batch2)); - auto baddbmm_names = unify_from_right(impl::get_names(bias), bmm_names); + auto bmm_names = compute_matmul_outnames(self.names(), other.names()); + auto baddbmm_names = unify_from_right(bias.names(), bmm_names); return baddbmm_names; } diff --git a/aten/src/ATen/NamedTensorUtils.h b/aten/src/ATen/NamedTensorUtils.h index 6777f39f7fcfc..af5584157550b 100644 --- a/aten/src/ATen/NamedTensorUtils.h +++ b/aten/src/ATen/NamedTensorUtils.h @@ -17,8 +17,8 @@ inline bool has_names(TensorList tensors) { // Converts dim to an positional index. Errors if `dim` cannot be used to // refer to any dimension of tensor. -CAFFE2_API int64_t dimname_to_position(const Tensor& tensor, Dimname dim); -CAFFE2_API std::vector dimnames_to_positions(const Tensor& tensor, DimnameList dims); +TORCH_API int64_t dimname_to_position(const Tensor& tensor, Dimname dim); +TORCH_API std::vector dimnames_to_positions(const Tensor& tensor, DimnameList dims); // Unifies two DimnameList to produce a third. This is useful for implementing // the named inference rule for binary broadcasting operations like add. @@ -28,7 +28,7 @@ CAFFE2_API std::vector dimnames_to_positions(const Tensor& tensor, Dimn // 2) Check misaligned: If a name `n` is in `names`, then it must appear at // the same index from the right in other. // 3) The output names are obtained by unifying the names individually from the right. -CAFFE2_API std::vector +TORCH_API std::vector unify_from_right(DimnameList names, DimnameList other, const char* action = "broadcast"); [[noreturn]] inline void reportNYIDimnameOverload(const char* op_name) { @@ -75,50 +75,50 @@ namespace namedinference { // `names` can be empty; see [NOTE] Writing name inference rules // If `names` is not empty, `names.size()` should equal `result.dim()`. // When in doubt, use this overload instead of the others. -CAFFE2_API Tensor& propagate_names_if_nonempty( +TORCH_API Tensor& propagate_names_if_nonempty( Tensor& result, DimnameList maybe_names, bool validate_names = false); // Propagates `names` to `result`. Only use this if we are certain that there are // names to propagate (that names is not empty). -CAFFE2_API Tensor& propagate_names( +TORCH_API Tensor& propagate_names( Tensor& result, DimnameList names, bool validate_names = false); // Propagates all names from src to result. -CAFFE2_API void propagate_names(Tensor& result, const Tensor& src); +TORCH_API void propagate_names(Tensor& result, const Tensor& src); // Propagates all names except for those at the excluded_idxs. -CAFFE2_API void propagate_names_except(Tensor& result, const Tensor& src, IntArrayRef excluded_idxs); +TORCH_API void propagate_names_except(Tensor& result, const Tensor& src, IntArrayRef excluded_idxs); // Used for reduction ops that have a `keepdim` arg. -CAFFE2_API void propagate_names_for_reduction(Tensor& result, const Tensor& src, IntArrayRef excluded_idxs, bool keepdim); +TORCH_API void propagate_names_for_reduction(Tensor& result, const Tensor& src, IntArrayRef excluded_idxs, bool keepdim); -CAFFE2_API void propagate_names_for_expand(Tensor& result, const Tensor& self); +TORCH_API void propagate_names_for_expand(Tensor& result, const Tensor& self); -CAFFE2_API std::vector compute_cat_outnames(TensorList tensors); +TORCH_API std::vector compute_cat_outnames(TensorList tensors); -CAFFE2_API std::vector compute_broadcast_outnames( +TORCH_API std::vector compute_broadcast_outnames( const Tensor& self, const Tensor& other); -CAFFE2_API std::vector broadcast_to_outnames( +TORCH_API std::vector broadcast_to_outnames( const Tensor& tensor, const Tensor& reference_tensor, const char* op_name); -CAFFE2_API std::vector compute_matmul_outnames(const Tensor& self, const Tensor& other); +TORCH_API std::vector compute_matmul_outnames(const Tensor& self, const Tensor& other); -CAFFE2_API std::vector compute_cdist_outnames(const Tensor& self, const Tensor& other); +TORCH_API std::vector compute_cdist_outnames(const Tensor& self, const Tensor& other); -CAFFE2_API std::vector compute_bmm_outnames( +TORCH_API std::vector compute_bmm_outnames( Tensor& result, const Tensor& self, const Tensor& other); -CAFFE2_API std::vector compute_squeeze_outnames(const Tensor& tensor); +TORCH_API std::vector compute_squeeze_outnames(const Tensor& tensor); std::vector compute_diagonal_outnames( const Tensor& tensor, @@ -127,40 +127,40 @@ std::vector compute_diagonal_outnames( // TensorImpl* overloads for Legacy TH/THC code. Use these sparingly. -CAFFE2_API TensorImpl* propagate_names_if_nonempty( +TORCH_API TensorImpl* propagate_names_if_nonempty( TensorImpl* result, DimnameList maybe_names, bool validate_names = false); -CAFFE2_API TensorImpl* propagate_names( +TORCH_API TensorImpl* propagate_names( TensorImpl* result, DimnameList names, bool validate_names = false); -CAFFE2_API void propagate_names(TensorImpl* result, /*const */TensorImpl* src); +TORCH_API void propagate_names(TensorImpl* result, /*const */TensorImpl* src); // result = m1 @ m2 + bias -CAFFE2_API void propagate_names_for_addmm( +TORCH_API void propagate_names_for_addmm( Tensor& result, const Tensor& m1, const Tensor& m2, const Tensor& bias); -CAFFE2_API void propagate_names_for_addmv( +TORCH_API void propagate_names_for_addmv( Tensor& result, const Tensor& mat, const Tensor& vec, const Tensor& bias); -CAFFE2_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2); +TORCH_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2); -CAFFE2_API std::vector compute_baddbmm_outnames( - TensorImpl* result, - TensorImpl* self, - TensorImpl* other, - TensorImpl* bias); +TORCH_API std::vector compute_baddbmm_outnames( + Tensor& result, + const Tensor& self, + const Tensor& other, + const Tensor& bias); -CAFFE2_API bool are_names_equal(TensorImpl* self, TensorImpl* other); +TORCH_API bool are_names_equal(TensorImpl* self, TensorImpl* other); } // namespace namedinference diff --git a/aten/src/ATen/NumericUtils.h b/aten/src/ATen/NumericUtils.h index 6cbd974f51dd7..e6726602bbd57 100644 --- a/aten/src/ATen/NumericUtils.h +++ b/aten/src/ATen/NumericUtils.h @@ -42,12 +42,18 @@ inline bool _isnan(T val) { template ::value, int>::type = 0> inline C10_HOST_DEVICE bool _isnan(T val) { - return at::_isnan(float(val)); + return at::_isnan(static_cast(val)); } +template ::value, int>::type = 0> +inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) { + return at::_isnan(static_cast(val)); +} + inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) { - return at::_isnan(float(val)); + return at::_isnan(static_cast(val)); } template diff --git a/aten/src/ATen/OpaqueTensorImpl.h b/aten/src/ATen/OpaqueTensorImpl.h index a4007c3115dcf..2072f549d0117 100644 --- a/aten/src/ATen/OpaqueTensorImpl.h +++ b/aten/src/ATen/OpaqueTensorImpl.h @@ -17,18 +17,20 @@ namespace at { // "shallow copy" in order to add support. template -struct CAFFE2_API OpaqueTensorImpl : public TensorImpl { +struct TORCH_API OpaqueTensorImpl : public TensorImpl { // public constructor for now... OpaqueTensorImpl( at::DispatchKeySet key_set, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, c10::Device device, OpaqueHandle opaque_handle, - c10::IntArrayRef sizes) + c10::IntArrayRef sizes, + bool is_non_overlapping_and_dense = true) : TensorImpl(key_set, data_type, device), opaque_handle_(std::move(opaque_handle)) { - sizes_ = sizes.vec(); + sizes_and_strides_.set_sizes(sizes); refresh_numel(); + is_non_overlapping_and_dense_ = is_non_overlapping_and_dense; } void release_resources() override { @@ -84,16 +86,36 @@ struct CAFFE2_API OpaqueTensorImpl : public TensorImpl { const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) const override { auto impl = c10::make_intrusive>( - key_set(), dtype(), device(), opaque_handle_, sizes_); + key_set(), dtype(), device(), opaque_handle_, sizes_and_strides_.sizes_arrayref()); copy_tensor_metadata( - /*src_impl=*/this, - /*dest_impl=*/impl.get(), + /*src_opaque_impl=*/this, + /*dest_opaque_impl=*/impl.get(), /*version_counter=*/version_counter, /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); impl->refresh_numel(); return impl; } + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override { + auto impl = c10::make_intrusive>( + key_set(), dtype(), device(), opaque_handle_, sizes_and_strides_.sizes_arrayref()); + copy_tensor_metadata( + /*src_opaque_impl=*/this, + /*dest_opaque_impl=*/impl.get(), + /*version_counter=*/std::move(version_counter), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + impl->refresh_numel(); + return impl; + } + /** * Shallow-copies data from another TensorImpl into this TensorImpl. * @@ -143,6 +165,21 @@ struct CAFFE2_API OpaqueTensorImpl : public TensorImpl { dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_; } + static void copy_tensor_metadata( + const OpaqueTensorImpl* src_opaque_impl, + OpaqueTensorImpl* dest_opaque_impl, + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) { + TensorImpl::copy_tensor_metadata( + src_opaque_impl, + dest_opaque_impl, + std::move(version_counter), + allow_tensor_metadata_change); + + // OpaqueTensorImpl-specific fields. + dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_; + } + private: OpaqueHandle opaque_handle_; }; diff --git a/aten/src/ATen/PTThreadPool.h b/aten/src/ATen/PTThreadPool.h index f5e8a1a18256a..7015f7cacc496 100644 --- a/aten/src/ATen/PTThreadPool.h +++ b/aten/src/ATen/PTThreadPool.h @@ -5,7 +5,7 @@ namespace at { -class CAFFE2_API PTThreadPool : public c10::ThreadPool { +class TORCH_API PTThreadPool : public c10::ThreadPool { public: explicit PTThreadPool( int pool_size, diff --git a/aten/src/ATen/Parallel.h b/aten/src/ATen/Parallel.h index 9e2f9be3e66e2..122b8ea7548bd 100644 --- a/aten/src/ATen/Parallel.h +++ b/aten/src/ATen/Parallel.h @@ -1,43 +1,34 @@ #pragma once -#include #include #include #include namespace at { -namespace internal { -// This parameter is heuristically chosen to determine the minimum number of -// work that warrants parallelism. For example, when summing an array, it is -// deemed inefficient to parallelise over arrays shorter than 32768. Further, -// no parallel algorithm (such as parallel_reduce) should split work into -// smaller than GRAIN_SIZE chunks. -constexpr int64_t GRAIN_SIZE = 32768; -} // namespace internal inline int64_t divup(int64_t x, int64_t y) { return (x + y - 1) / y; } // Called during new thread initialization -CAFFE2_API void init_num_threads(); +TORCH_API void init_num_threads(); // Sets the number of threads to be used in parallel region -CAFFE2_API void set_num_threads(int); +TORCH_API void set_num_threads(int); // Returns the maximum number of threads that may be used in a parallel region -CAFFE2_API int get_num_threads(); +TORCH_API int get_num_threads(); // Returns the current thread number (starting from 0) // in the current parallel region, or 0 in the sequential region -CAFFE2_API int get_thread_num(); +TORCH_API int get_thread_num(); // Checks whether the code runs in parallel region -CAFFE2_API bool in_parallel_region(); +TORCH_API bool in_parallel_region(); namespace internal { // Initialise num_threads lazily at first parallel call -inline CAFFE2_API void lazy_init_num_threads() { +inline TORCH_API void lazy_init_num_threads() { thread_local bool init = false; if (C10_UNLIKELY(!init)) { at::init_num_threads(); @@ -119,29 +110,29 @@ inline scalar_t parallel_reduce( const SF& sf); // Returns a detailed string describing parallelization settings -CAFFE2_API std::string get_parallel_info(); +TORCH_API std::string get_parallel_info(); // Sets number of threads used for inter-op parallelism -CAFFE2_API void set_num_interop_threads(int); +TORCH_API void set_num_interop_threads(int); // Returns the number of threads used for inter-op parallelism -CAFFE2_API int get_num_interop_threads(); +TORCH_API int get_num_interop_threads(); // Launches inter-op parallel task -CAFFE2_API void launch(std::function func); +TORCH_API void launch(std::function func); namespace internal { void launch_no_thread_state(std::function fn); } // namespace internal // Launches intra-op parallel task -CAFFE2_API void intraop_launch(std::function func); +TORCH_API void intraop_launch(std::function func); // Launches intra-op parallel task, returns a future -CAFFE2_API std::shared_ptr intraop_launch_future( +TORCH_API std::shared_ptr intraop_launch_future( std::function func); // Returns number of intra-op threads used by default -CAFFE2_API int intraop_default_num_threads(); +TORCH_API int intraop_default_num_threads(); } // namespace at diff --git a/aten/src/ATen/ParallelNative.h b/aten/src/ATen/ParallelNative.h index 58d3445cc5674..3a8d2633191c9 100644 --- a/aten/src/ATen/ParallelNative.h +++ b/aten/src/ATen/ParallelNative.h @@ -22,7 +22,7 @@ inline std::tuple calc_num_tasks_and_chunk_size( return std::make_tuple(num_tasks, chunk_size); } -CAFFE2_API void _parallel_run( +TORCH_API void _parallel_run( const int64_t begin, const int64_t end, const int64_t grain_size, diff --git a/aten/src/ATen/ParallelOpenMP.cpp b/aten/src/ATen/ParallelOpenMP.cpp index 9c074cfea410d..261f6cdd46b5f 100644 --- a/aten/src/ATen/ParallelOpenMP.cpp +++ b/aten/src/ATen/ParallelOpenMP.cpp @@ -1,4 +1,5 @@ #include +#include #if AT_PARALLEL_OPENMP #include @@ -8,6 +9,8 @@ #include #endif +#include + namespace at { namespace { @@ -49,6 +52,12 @@ void set_num_threads(int nthreads) { // See https://github.com/pytorch/pytorch/issues/13757 mkl_set_dynamic(false); #endif +#ifdef USE_PTHREADPOOL + // because PyTorch uses caffe2::pthreadpool() in QNNPACK + caffe2::PThreadPool* const pool = caffe2::pthreadpool(); + TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!"); + pool->set_thread_count(nthreads); +#endif } // Explicitly calling omp_get_max_threads() as the size of the parallel diff --git a/aten/src/ATen/ParallelOpenMP.h b/aten/src/ATen/ParallelOpenMP.h index 5e01d1de9d187..bbb369ba3d506 100644 --- a/aten/src/ATen/ParallelOpenMP.h +++ b/aten/src/ATen/ParallelOpenMP.h @@ -1,5 +1,4 @@ #pragma once -#include #include #include diff --git a/aten/src/ATen/ScalarOps.cpp b/aten/src/ATen/ScalarOps.cpp new file mode 100644 index 0000000000000..2510b5c808175 --- /dev/null +++ b/aten/src/ATen/ScalarOps.cpp @@ -0,0 +1,40 @@ +// FastPass +#ifdef _MSC_VER +#ifndef _USE_MATH_DEFINES +#define _USE_MATH_DEFINES +#endif +#include +#endif + +#include +#include +#include + +namespace at { +namespace { +template +inline void fill_inplace(Tensor& self, Scalar value_scalar) { + auto value = value_scalar.to(); + scalar_t* dptr = static_cast(self.data_ptr()); + *dptr = value; +} +} + +namespace detail { +Tensor& scalar_fill(Tensor& self, Scalar value) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + kHalf, kBool, kBFloat16, self.scalar_type(), "fill_out", [&]() { + fill_inplace(self, value); + }); + return self; +} + +Tensor scalar_tensor_static(Scalar s, c10::optional dtype_opt, c10::optional device_opt) { + at::tracer::impl::NoTracerDispatchMode tracer_guard; + at::AutoNonVariableTypeMode non_var_type_mode(true); + auto result = at::detail::empty_cpu({}, dtype_opt, c10::nullopt, device_opt, c10::nullopt, c10::nullopt); + scalar_fill(result, s); + return result; +} +} // namespace detail +} // namespace at diff --git a/aten/src/ATen/ScalarOps.h b/aten/src/ATen/ScalarOps.h index 8c07a9d618bcb..7f49aeb97cb43 100644 --- a/aten/src/ATen/ScalarOps.h +++ b/aten/src/ATen/ScalarOps.h @@ -4,6 +4,18 @@ #include #include +namespace at { +namespace detail { +// When filling a number to 1-element CPU tensor, we want to skip +// everything but manipulate data ptr directly. +// Ideally this fast pass should be implemented in TensorIterator, +// but we also want to skip compute_types which in not avoidable +// in TensorIterator for now. +Tensor& scalar_fill(Tensor& self, Scalar value); +TORCH_API Tensor scalar_tensor_static(Scalar s, c10::optional dtype_opt, c10::optional device_opt); +} // namespace detail +} // namespace at + // This is in the c10 namespace because we use ADL to find the functions in it. namespace c10 { @@ -13,14 +25,14 @@ inline at::Tensor scalar_to_tensor(Scalar s, const Device device = at::kCPU) { // This is the fast track we have for CPU scalar tensors. if (device == at::kCPU) { if (s.isFloatingPoint()) { - return at::native::scalar_tensor(s, at::device(at::kCPU).dtype(at::kDouble)); - } else if (s.isBoolean()) { - return at::native::scalar_tensor(s, at::device(at::kCPU).dtype(at::kBool)); + return at::detail::scalar_tensor_static(s, at::kDouble, at::kCPU); } else if (s.isComplex()) { - return at::native::scalar_tensor(s, at::device(at::kCPU).dtype(at::kComplexDouble)); + return at::detail::scalar_tensor_static(s, at::kComplexDouble, at::kCPU); + } else if (s.isBoolean()) { + return at::detail::scalar_tensor_static(s, at::kBool, at::kCPU); } else { AT_ASSERT(s.isIntegral(false)); - return at::native::scalar_tensor(s, at::device(at::kCPU).dtype(at::kLong)); + return at::detail::scalar_tensor_static(s, at::kLong, at::kCPU); } } if (s.isFloatingPoint()) { @@ -35,30 +47,4 @@ inline at::Tensor scalar_to_tensor(Scalar s, const Device device = at::kCPU) { } } -// The above function is useful for type promotion -// in Binary Ops where one argument is `Tensor` and other is `Scalar`. -// In the above function, we generate wrapped tensor to type with highest -// range and precision based on scalar's type (to support type promotion). -// Eg. Floating Point Types -> Double -// Complex Types -> Complex Double -// -// However for `Scalar-Scalar` Binary Op,we default the type of wrapped tensor -// to the default type corresponding to scalar's type. -inline at::Tensor scalar_to_tensor_default_dtype( - Scalar s, - const Device device = at::kCPU) { - if (s.isFloatingPoint()) { - return at::scalar_tensor( - s, at::device(device).dtype(at::get_default_dtype())); - } else if (s.isBoolean()) { - return at::scalar_tensor(s, at::device(device).dtype(at::kBool)); - } else if (s.isComplex()) { - return at::scalar_tensor( - s, at::device(device).dtype(at::get_default_complex_dtype())); - } else { - AT_ASSERT(s.isIntegral(false)); - return at::scalar_tensor(s, at::device(device).dtype(at::kLong)); - } -} - } diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index ee1ff71b54d20..0e18dca131a20 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -30,12 +30,12 @@ namespace { // // This means that we allocate a [1,0] size indices tensor and a [0] size // values tensor for such an empty tensor. -SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta& data_type) +SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta data_type) : SparseTensorImpl(key_set, data_type , at::empty({1, 0}, at::initialTensorOptions().device(sparseTensorSetToDeviceType(key_set)).dtype(ScalarType::Long)) , at::empty({0}, at::initialTensorOptions().device(sparseTensorSetToDeviceType(key_set)).dtype(data_type))) {} -SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta& data_type, at::Tensor indices, at::Tensor values) +SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta data_type, at::Tensor indices, at::Tensor values) : TensorImpl(key_set, data_type, values.device()) , sparse_dim_(1) , dense_dim_(0) @@ -46,6 +46,8 @@ SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::Typ AT_ASSERT(values_.sizes() == IntArrayRef({0})); AT_ASSERT(values_.device() == indices_.device()); AT_ASSERT(values_.device() == device()); + + is_non_overlapping_and_dense_ = false; } IntArrayRef SparseTensorImpl::strides() const { @@ -67,9 +69,6 @@ void SparseTensorImpl::set_storage_offset(int64_t storage_offset) { AT_ERROR("sparse tensors do not have set_storage_offset"); } -int64_t SparseTensorImpl::dim() const { - return sparse_dim_ + dense_dim_; -} bool SparseTensorImpl::has_storage() const { return false; } @@ -81,7 +80,6 @@ int64_t SparseTensorImpl::storage_offset() const { } void SparseTensorImpl::set_indices_and_values_unsafe(const Tensor& indices, const Tensor& values) { TORCH_CHECK(allow_tensor_metadata_change(), "set_indices_and_values_unsafe ", err_msg_tensor_metadata_change_not_allowed); - TORCH_INTERNAL_ASSERT(at::impl::variable_excluded_from_dispatch()); TORCH_CHECK(!indices.is_sparse(), "expected indices to be a dense tensor, but got indices of layout ", indices.layout()); TORCH_CHECK(!values.is_sparse(), "expected values to be a dense tensor, but got values of layout ", values.layout()); diff --git a/aten/src/ATen/SparseTensorImpl.h b/aten/src/ATen/SparseTensorImpl.h index bdccb540734f0..9daf21c15e568 100644 --- a/aten/src/ATen/SparseTensorImpl.h +++ b/aten/src/ATen/SparseTensorImpl.h @@ -5,7 +5,7 @@ #include namespace at { -struct CAFFE2_API SparseTensorImpl : public TensorImpl { +struct TORCH_API SparseTensorImpl : public TensorImpl { // Stored in COO format, indices + values. // INVARIANTS: @@ -31,7 +31,7 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { public: // Public for now... - explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta&); + explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta); int64_t nnz() const { return values_.size(0); } int64_t sparse_dim() const { return sparse_dim_; } @@ -47,7 +47,6 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { void set_stride(int64_t dim, int64_t new_stride) override; void set_storage_offset(int64_t storage_offset) override; - int64_t dim() const override; bool has_storage() const override; const Storage& storage() const override; int64_t storage_offset() const override; @@ -56,7 +55,7 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { // respect to indices and values void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) { TORCH_CHECK(allow_tensor_metadata_change(), "raw_resize_ ", err_msg_tensor_metadata_change_not_allowed); - sizes_ = size.vec(); + sizes_and_strides_.set_sizes(size); sparse_dim_ = sparse_dim; dense_dim_ = dense_dim; refresh_numel(); @@ -126,7 +125,8 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { "shrinking the size of dense dimensions (from ", dense_size_original, " to ", dense_size_new, ") on a non-empty sparse tensor is not supported.\n", alt_options_msg); } - if ((!size.equals(sizes_)) || (sparse_dim != sparse_dim_) || (dense_dim != dense_dim_)) { + const bool size_equals_sizes = std::equal(size.begin(), size.end(), sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); + if ((!size_equals_sizes) || (sparse_dim != sparse_dim_) || (dense_dim != dense_dim_)) { auto nnz = values().size(0); std::vector values_size = {nnz}; auto dense_size = size.slice(sparse_dim); @@ -135,7 +135,9 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { indices_.resize_({sparse_dim, nnz}); } - sizes_ = size.vec(); + if (!size_equals_sizes) { + sizes_and_strides_.set_sizes(size); + } sparse_dim_ = sparse_dim; dense_dim_ = dense_dim; refresh_numel(); @@ -146,7 +148,7 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { TORCH_CHECK(allow_tensor_metadata_change(), "resize_and_clear_ ", err_msg_tensor_metadata_change_not_allowed); TORCH_CHECK(sparse_dim + dense_dim == static_cast(size.size()), "number of dimensions must be sparse_dim (", sparse_dim, ") + dense_dim (", dense_dim, "), but got ", size.size()); - sizes_ = size.vec(); + sizes_and_strides_.set_sizes(size); sparse_dim_ = sparse_dim; dense_dim_ = dense_dim; @@ -200,6 +202,25 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { return impl; } + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override { + auto impl = c10::make_intrusive(key_set(), dtype()); + copy_tensor_metadata( + /*src_impl=*/this, + /*dest_impl=*/impl.get(), + /*version_counter=*/std::move(version_counter), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + impl->refresh_numel(); + return impl; + } + /** * Shallow-copies data from another TensorImpl into this TensorImpl. * @@ -217,7 +238,7 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { refresh_numel(); } private: - explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta&, at::Tensor indices, at::Tensor values); + explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta, at::Tensor indices, at::Tensor values); /** * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / storage_offset) diff --git a/aten/src/ATen/SparseTensorUtils.cpp b/aten/src/ATen/SparseTensorUtils.cpp new file mode 100644 index 0000000000000..bef33627c3ff4 --- /dev/null +++ b/aten/src/ATen/SparseTensorUtils.cpp @@ -0,0 +1,113 @@ +#include + +#include +#include +#include + +namespace at { namespace sparse { + +// NOTE [ Flatten Sparse Indices ] +// This helper function flattens a sparse indices tensor (a Tensor) into a 1D +// indices tensor. E.g., +// input = [[2, 4, 0], +// [3, 1, 10]] +// full_size = [2, 12] +// output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10] +// +// In other words, assuming that each `indices[i, :]` is a valid index to a +// tensor `t` of shape `full_size`. This returns the corresponding indices to +// the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`. +// if forceClone is true, the result will forced to be a clone of self. +// if force_clone is true, the result will forced to be a clone of self. +Tensor flatten_indices(const Tensor& indices, IntArrayRef full_size, bool force_clone /*= false*/) { + int64_t sparse_dim = indices.size(0); + if (sparse_dim == 1) { + if (force_clone) { + return indices.squeeze(0).clone(at::MemoryFormat::Contiguous); + } else { + return indices.squeeze(0); + } + } else { + std::vector indices_mult_cpu_vec; + indices_mult_cpu_vec.reserve(sparse_dim); + int64_t mult = 1; + for (int64_t i = sparse_dim - 1; i >= 0; i--) { + indices_mult_cpu_vec[i] = mult; + mult *= full_size[i]; + } + auto indices_mult_cpu = at::from_blob( + indices_mult_cpu_vec.data(), + /*size=*/{sparse_dim, 1}, + indices.options().device(kCPU)); + // NB: must be blocking because this blob may be freed after this closure, + // and non_blocking copy will see garbage. + auto indices_mult = indices_mult_cpu.to(indices.device(), /*non_blocking=*/false); + // Ideally we want matmul but matmul is slow on CPU Long and not implemented + // on CUDA Long. So mul is faster. + return indices.mul(indices_mult).sum(0); + } +} + +// Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten Sparse Indices ], +// except this one allows partial flatten: only flatten on specified dims. Note that +// the flatten indices might be uncoalesced if dims_to_flatten.size() < sparse_dim. +// Also if input indices is already coalesced, the flattened indices will also be sorted. +// +// args: +// indices: sparse tensor indices +// sizes: sparse tensor sizes +// dims_to_flatten: a list of dim index to flatten +// +// Ex1: +// indices = [[2, 4, 0], +// [3, 1, 3]] +// sizes = [2, 12] +// dims_to_flatten = [0, 1] +// new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3] +// +// Ex2: +// dims_to_flatten = [1] +// new_indices = [ 3, 1, 3 ] # uncoalesced +Tensor flatten_indices_by_dims(const Tensor& indices, const IntArrayRef& sizes, const IntArrayRef& dims_to_flatten){ + Tensor new_indices = at::zeros({indices.size(1)}, indices.options()); + for (auto d : dims_to_flatten) { + new_indices.mul_(sizes[d]); + new_indices.add_(indices.select(0, d)); + } + return new_indices; +} + +Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz) { + /* + Find the CSR representation for a row `indices` from the COO format + Inputs: + `indices` is the row pointer from COO indices + `dim` is the row dimensionality + `nnz` is the number of non-zeros + + Output: + `csr` is a compressed row array in a CSR format + */ + Tensor csr = at::zeros({dim + 1}, kLong); + + // TODO: eliminate this conditional when zero-size dims supported correctly + if (nnz > 0) { + auto csr_accessor = csr.accessor(); + // Convert the sparse matrix to CSR format + at::parallel_for(0, nnz, 10000, [&](int64_t start, int64_t end) { + int64_t h, hp0, hp1; + for (auto i = start; i < end; i++) { + hp0 = indices[i]; + hp1 = (i+1 == nnz) ? dim : indices[i+1]; + if (hp0 != hp1) { + for (h = hp0; h < hp1; h++) { + csr_accessor[h+1] = i+1; + } + } + } + }); + } + return csr; +} + +}} // namespace at::sparse diff --git a/aten/src/ATen/SparseTensorUtils.h b/aten/src/ATen/SparseTensorUtils.h index 2ed55881268e1..0327483607f17 100644 --- a/aten/src/ATen/SparseTensorUtils.h +++ b/aten/src/ATen/SparseTensorUtils.h @@ -2,15 +2,15 @@ #include #include +#include namespace at { namespace sparse { // Just for documentary purposes using SparseTensor = Tensor; -using LongTensor = Tensor; -using IntTensor = Tensor; using SparseType = Type; + // This is an internal utility function for getting at the SparseTensorImpl, // so that we can write sparse tensor specific accessors for special fields // in SparseTensor. You should only use this for writing low level @@ -18,21 +18,20 @@ using SparseType = Type; // the low level setters/getters that were implemented using this. // // This may be called repeatedly, so make sure it's pretty cheap. -inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) { - TORCH_INTERNAL_ASSERT(at::impl::variable_excluded_from_dispatch()); +inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) { AT_ASSERTM(self.is_sparse(), "_internal_get_SparseTensorImpl: not a sparse tensor"); return static_cast(self.unsafeGetTensorImpl()); } // Takes indices and values and directly puts them into the sparse tensor, no // copy. This used to be called THSTensor_(_move) -inline void alias_into_sparse(const SparseTensor& self, const LongTensor& indices, const Tensor& values) { +inline void alias_into_sparse(const SparseTensor& self, const Tensor& indices, const Tensor& values) { get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values); } // Take indices and values and makes a (data) copy of them to put into the sparse // indices/values. This used to be called THSTensor_(_set) -inline void copy_into_sparse(const SparseTensor& self, const LongTensor& indices, const Tensor& values, bool non_blocking) { +inline void copy_into_sparse(const SparseTensor& self, const Tensor& indices, const Tensor& values, bool non_blocking) { alias_into_sparse( self, indices.to(self._indices().options(), non_blocking, /*copy=*/true), @@ -59,7 +58,7 @@ inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) { } // NOTE [ Flatten Sparse Indices ] -// This helper function flattens a sparse indices tensor (a LongTensor) into a 1D +// This helper function flattens a sparse indices tensor (a Tensor) into a 1D // indices tensor. E.g., // input = [[2, 4, 0], // [3, 1, 10]] @@ -71,34 +70,7 @@ inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) { // the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`. // if forceClone is true, the result will forced to be a clone of self. // if force_clone is true, the result will forced to be a clone of self. -inline LongTensor flatten_indices(const Tensor& indices, IntArrayRef full_size, bool force_clone = false) { - int64_t sparse_dim = indices.size(0); - if (sparse_dim == 1) { - if (force_clone) { - return indices.squeeze(0).clone(at::MemoryFormat::Contiguous); - } else { - return indices.squeeze(0); - } - } else { - std::vector indices_mult_cpu_vec; - indices_mult_cpu_vec.reserve(sparse_dim); - int64_t mult = 1; - for (int64_t i = sparse_dim - 1; i >= 0; i--) { - indices_mult_cpu_vec[i] = mult; - mult *= full_size[i]; - } - auto indices_mult_cpu = at::from_blob( - indices_mult_cpu_vec.data(), - /*size=*/{sparse_dim, 1}, - indices.options().device(kCPU)); - // NB: must be blocking because this blob may be freed after this closure, - // and non_blocking copy will see garbage. - auto indices_mult = indices_mult_cpu.to(indices.device(), /*non_blocking=*/false); - // Ideally we want matmul but matmul is slow on CPU Long and not implemented - // on CUDA Long. So mul is faster. - return indices.mul(indices_mult).sum(0); - } -} +TORCH_API Tensor flatten_indices(const Tensor& indices, IntArrayRef full_size, bool force_clone = false); // Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten Sparse Indices ], // except this one allows partial flatten: only flatten on specified dims. Note that @@ -120,13 +92,9 @@ inline LongTensor flatten_indices(const Tensor& indices, IntArrayRef full_size, // Ex2: // dims_to_flatten = [1] // new_indices = [ 3, 1, 3 ] # uncoalesced -inline LongTensor flatten_indices_by_dims(const LongTensor& indices, const IntArrayRef& sizes, const IntArrayRef& dims_to_flatten){ - LongTensor new_indices = at::zeros({indices.size(1)}, indices.options()); - for (auto d : dims_to_flatten) { - new_indices.mul_(sizes[d]); - new_indices.add_(indices.select(0, d)); - } - return new_indices; -} +TORCH_API Tensor flatten_indices_by_dims(const Tensor& indices, const IntArrayRef& sizes, const IntArrayRef& dims_to_flatten); + +// Find the CSR representation for a row `indices` from the COO format +TORCH_API Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz); }} // namespace at::sparse diff --git a/aten/src/ATen/TensorGeometry.h b/aten/src/ATen/TensorGeometry.h index 291892a14d085..ad3e16da4a6a7 100644 --- a/aten/src/ATen/TensorGeometry.h +++ b/aten/src/ATen/TensorGeometry.h @@ -5,7 +5,7 @@ namespace at { -struct CAFFE2_API TensorGeometry { +struct TORCH_API TensorGeometry { TensorGeometry() : storage_offset_(0) {} explicit TensorGeometry(IntArrayRef sizes) diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h index 11149cdb0451d..abc0aae59e337 100644 --- a/aten/src/ATen/TensorIndexing.h +++ b/aten/src/ATen/TensorIndexing.h @@ -4,6 +4,13 @@ #include #include #include +#include + +// TODO: try to remove this +// There is some back story, see https://github.com/pytorch/pytorch/issues/48684 +#include + +#include namespace at { namespace indexing { @@ -13,12 +20,12 @@ const int64_t INDEX_MIN = std::numeric_limits::min(); enum class TensorIndexType { None, Ellipsis, Integer, Boolean, Slice, Tensor }; -constexpr c10::nullopt_t None{c10::nullopt_t::init()}; +constexpr c10::nullopt_t None = c10::nullopt; -struct CAFFE2_API EllipsisIndexType final { EllipsisIndexType() {} }; -CAFFE2_API extern const EllipsisIndexType Ellipsis; +struct TORCH_API EllipsisIndexType final { EllipsisIndexType() {} }; +TORCH_API extern const EllipsisIndexType Ellipsis; -struct CAFFE2_API Slice final { +struct TORCH_API Slice final { public: // This mirrors `__PySlice_Unpack` in torch/csrc/utils/python_compat.h Slice( @@ -68,7 +75,7 @@ struct CAFFE2_API Slice final { int64_t step_; }; -CAFFE2_API std::ostream& operator<<(std::ostream& stream, const Slice& slice); +TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice); // `at::indexing::TensorIndex` is used for converting C++ tensor indices such as // `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}` @@ -95,7 +102,7 @@ CAFFE2_API std::ostream& operator<<(std::ostream& stream, const Slice& slice); // `:3:2` | `Slice(None, 3, 2)` // `1:3:2` | `Slice(1, 3, 2)` // `torch.tensor([1, 2])`) | `torch::tensor({1, 2})` -struct CAFFE2_API TensorIndex final { +struct TORCH_API TensorIndex final { // Case 1: `at::indexing::None` TensorIndex(c10::nullopt_t) : type_(TensorIndexType::None) {} @@ -170,8 +177,8 @@ struct CAFFE2_API TensorIndex final { TensorIndexType type_; }; -CAFFE2_API std::ostream& operator<<(std::ostream& stream, const TensorIndex& tensor_index); -CAFFE2_API std::ostream& operator<<(std::ostream& stream, const std::vector& tensor_indices); +TORCH_API std::ostream& operator<<(std::ostream& stream, const TensorIndex& tensor_index); +TORCH_API std::ostream& operator<<(std::ostream& stream, const std::vector& tensor_indices); namespace impl { static inline Tensor applySlice( @@ -222,9 +229,9 @@ static inline Tensor applySelect( static inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) { // booleans add a dimension of size 1. true indexes this dimension as if 0:, false as empty. if (value) { - return at::native::zeros({1}, {}, self.options().dtype(kLong)); + return at::empty({1}, {}, self.options().dtype(kLong)).fill_(0.); } else { - return at::native::empty({0}, {}, self.options().dtype(kLong)); + return at::empty({0}, {}, self.options().dtype(kLong)); } } @@ -245,10 +252,6 @@ static inline Tensor boolToIndexingTensor(const Tensor& self, bool value, const } } -static inline Tensor scalarToTensorCPUOrCUDA(Scalar v, const TensorOptions& options) { - return at::native::scalar_tensor(v, options); -} - static inline Tensor scalarToTensorNonNativeDeviceType(Scalar v, const TensorOptions& options) { return at::scalar_tensor(v, options); } @@ -260,14 +263,15 @@ static inline void recordTensorIndex(const Tensor& tensor, std::vector& (*dim_ptr)++; }; -static inline std::vector typeConvertIndices(const Tensor& self, std::vector&& indices) { - std::vector converted_inds(indices.size()); +static inline c10::List> typeConvertIndices(const Tensor& self, std::vector&& indices) { + c10::List> converted_inds; + converted_inds.reserve(indices.size()); for (size_t i = 0; i < indices.size(); ++i) { const auto &ind = indices[i]; if (ind.defined()) { - converted_inds[i] = ind.to(ind.options().device(self.device())); + converted_inds.push_back(ind.to(ind.options().device(self.device()))); } else { - converted_inds[i] = std::move(indices[i]); + converted_inds.push_back(std::move(indices[i])); } } return converted_inds; @@ -316,8 +320,8 @@ static inline int64_t count_specified_dimensions(const ArrayRef& in // The rest of the functions are in `at::indexing::impl` namespace, signifying // that they shouldn't be used from Python indexing implementation. static inline Tensor scalarToTensor(Scalar v, const TensorOptions& options, const at::Device& self_device) { - if (self_device == at::kCPU || self_device == at::kCUDA) { - return impl::scalarToTensorCPUOrCUDA(v, options); + if (self_device == at::kCPU) { + return at::detail::scalar_tensor_static(v, options.dtype_opt()->toScalarType(), self_device); } else { return impl::scalarToTensorNonNativeDeviceType(v, options); } @@ -339,6 +343,13 @@ static inline IntArrayRef slicePrefix1sSize(const IntArrayRef& sizes) { } static inline void copy_to(const Tensor& dst, const Tensor& src) { + if (dst.sizes().equals(src.sizes())) { + // A shortcut to avoid generating hard-coded constant sizes during tracing. + // This is not a perfect solution: when src & dst have different shapes, constants will still + // appear. Users can workaround that case by dst[index..] = src.reshape(..) + dst.copy_(src); + return; + } Tensor b_src; std::tie(b_src) = expand_inplace(dst, src.view(slicePrefix1sSize(src.sizes())), "setitem"); dst.copy_(b_src); diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp similarity index 76% rename from aten/src/ATen/native/TensorIterator.cpp rename to aten/src/ATen/TensorIterator.cpp index 71b552718ab13..3f5f9280eb99a 100644 --- a/aten/src/ATen/native/TensorIterator.cpp +++ b/aten/src/ATen/TensorIterator.cpp @@ -6,14 +6,15 @@ #include #include #include +#include namespace at { -using DimMask = TensorIterator::DimMask; -using PtrVector = TensorIterator::PtrVector; -using loop_t = TensorIterator::loop_t; -using loop2d_t = TensorIterator::loop2d_t; -using StrideVector = TensorIterator::StrideVector; +using DimMask = TensorIteratorBase::DimMask; +using PtrVector = TensorIteratorBase::PtrVector; +using loop_t = TensorIteratorBase::loop_t; +using loop2d_t = TensorIteratorBase::loop2d_t; +using StrideVector = TensorIteratorBase::StrideVector; /// Construction TensorIteratorConfig& TensorIteratorConfig::add_output(const Tensor& output) { @@ -101,12 +102,14 @@ TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef sha return *this; } -TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef shape, const int64_t squash_dim) { +TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef shape, IntArrayRef squash_dims) { declare_static_shape(shape); if (!static_shape_->size()) return *this; - TORCH_CHECK(squash_dim >= 0 && squash_dim < static_cast(static_shape_->size()), - "squash_dim ", squash_dim, " must be in [0, ", static_shape_->size(), ")."); - (*static_shape_)[squash_dim] = 1; + for (const auto& squash_dim : squash_dims) { + TORCH_CHECK(squash_dim >= 0 && squash_dim < static_cast(static_shape_->size()), + "squash_dim ", squash_dim, " must be in [0, ", static_shape_->size(), ")."); + (*static_shape_)[squash_dim] = 1; + } return *this; } @@ -148,7 +151,7 @@ TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef sha // in the strides of trivial dimensions, so physical layout is unaffected but permutation information is lost) // We might change this behavior in future once performance considerations are resolved -void TensorIterator::reorder_dimensions(const TensorIteratorConfig& config) { +void TensorIteratorBase::reorder_dimensions() { // Sort the dimensions based on strides in ascending order with reduced dims // at the front. NOTE: that this inverts the order of C-contiguous tensors. // strides[0] is the fastest moving dimension instead of strides[ndim - 1]. @@ -166,7 +169,6 @@ void TensorIterator::reorder_dimensions(const TensorIteratorConfig& config) { // returns 1 if the dim0 should come after dim1, -1 if dim0 should come // before dim1, and 0 if the comparison is ambiguous. auto should_swap = [&](size_t dim0, size_t dim1) { - int ret = 0; for (int arg = 0; arg < ntensors(); arg++) { // ignore undefined or incorrectly sized tensors if (operands_[arg].stride_bytes.empty() || operands_[arg].will_resize) { @@ -200,7 +202,7 @@ void TensorIterator::reorder_dimensions(const TensorIteratorConfig& config) { } } } - return ret; + return 0; }; // insertion sort with support for ambiguous comparisons @@ -223,7 +225,7 @@ void TensorIterator::reorder_dimensions(const TensorIteratorConfig& config) { // Computes a common dtype using type promotion // See the [Common Dtype Computation] note -ScalarType TensorIterator::compute_common_dtype() { +ScalarType TensorIteratorBase::compute_common_dtype() { at::native::ResultTypeState state = {}; for (const auto& op : operands_) { if (op.is_output) { @@ -250,7 +252,7 @@ ScalarType TensorIterator::compute_common_dtype() { // NOTE: Checks for more specific behaviors (e.g. the first and second // inputs must share a dtype, but the third must have the long dtype) // should be implemented directly and outside of TensorIterator. -void TensorIterator::compute_types(const TensorIteratorConfig& config) { +void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) { // Reviews operands (1/2) // - validates that all input tensors are defined // - computes common device @@ -400,12 +402,28 @@ void TensorIterator::compute_types(const TensorIteratorConfig& config) { // TODO: reuse temporaries when possible (e.g. for inplace operations) if (common_device == kCPU) { // Casts to outputs by creating temporaries of the correct dtype (if needed) - if (config.cast_common_dtype_to_outputs_ && op.is_output && op.current_dtype != common_dtype_) { + // NB: we skip this on is_meta_, because the temporary allocation here is + // unnecessary if we aren't going to actually do the compute + if (config.cast_common_dtype_to_outputs_ && op.is_output && op.current_dtype != common_dtype_ && !is_meta_) { + TORCH_INTERNAL_ASSERT(op.tensor.defined()); + // Marker [Output original_tensor is set] op.original_tensor = op.tensor; + // NB: do NOT use set_output here, as the temporary is NOT a true output; + // op.tensor is the true output and it was pre-provided for us. + // TODO: The logic for cast_outputs will need to be handled by the + // structured kernels implementation. What probably should happen + // is that we pass in the inferred dtype into the out kernel, and + // then after calling the out kernel, do the conversion (which + // is cast_outputs here), but integrating this with existing + // TensorIterator will take a little doing op.tensor = at::empty_like(op.tensor, op.tensor.options().dtype(common_dtype_), LEGACY_CONTIGUOUS_MEMORY_FORMAT); + if (!names_.empty()) { + namedinference::propagate_names(op.tensor, names_); + } op.current_dtype = common_dtype_; + op.target_dtype = common_dtype_; } // Promotes inputs by creating temporaries of the correct dtype @@ -413,12 +431,13 @@ void TensorIterator::compute_types(const TensorIteratorConfig& config) { op.original_tensor = op.tensor; op.tensor = op.tensor.to(common_dtype_); op.current_dtype = common_dtype_; + op.target_dtype = common_dtype_; } } } } -StrideVector TensorIterator::compatible_stride(int element_size) const { +StrideVector TensorIteratorBase::compatible_stride(int element_size) const { auto stride = StrideVector(); int64_t next_stride = element_size; for (int dim = 0; dim < ndim(); dim++) { @@ -428,7 +447,7 @@ StrideVector TensorIterator::compatible_stride(int element_size) const { return stride; } -DimVector TensorIterator::invert_perm(IntArrayRef input) const { +DimVector TensorIteratorBase::invert_perm(IntArrayRef input) const { // Invert the permutation caused by reorder_dimensions. This is not valid // after coalesce_dimensions is called. TORCH_INTERNAL_ASSERT(!has_coalesced_dimensions_); @@ -440,7 +459,7 @@ DimVector TensorIterator::invert_perm(IntArrayRef input) const { return res; } -void TensorIterator::allocate_or_resize_outputs() { +void TensorIteratorBase::allocate_or_resize_outputs() { for (int i = 0; i < num_outputs_; i++) { auto& op = operands_[i]; if (!op.tensor.defined() || op.will_resize) { @@ -457,33 +476,27 @@ void TensorIterator::allocate_or_resize_outputs() { } auto tensor_shape = invert_perm(shape_); if (inverted) { - if (!op.tensor.defined()) { - // can just return contiguous output - // it is faster because it avoids allocating 0 size tensor and - // resizing and restriding it - op.tensor = at::empty(tensor_shape, op.options()); - } else { - at::native::resize_output(op.tensor, tensor_shape); - } + // can just return contiguous output + // it is faster because it avoids allocating 0 size tensor and + // resizing and restriding it + set_output(i, tensor_shape, {}, op.options(), names_); } else { auto tensor_stride = invert_perm(op.stride_bytes); for (int dim = 0; dim < ndim(); dim++) { tensor_stride[dim] /= element_size; } - if (!op.tensor.defined()) { - op.tensor = - at::empty_strided(tensor_shape, tensor_stride, op.options()); - } else { - at::native::resize_output(op.tensor, tensor_shape); - op.tensor.as_strided_(tensor_shape, tensor_stride); - } + set_output(i, tensor_shape, tensor_stride, op.options(), names_); } op.current_dtype = op.target_dtype; + } else if (op.tensor.defined()) { + // Even if we don't resize, we still need to tell set_output about + // the output, so that we properly set guard and propagate names + set_output(i, op.tensor.sizes(), {}, op.tensor.options(), names_); } } } -void TensorIterator::compute_names(const TensorIteratorConfig& config) { +void TensorIteratorBase::compute_names(const TensorIteratorConfig& config) { bool should_infer_names = std::any_of( operands_.begin(), operands_.end(), @@ -510,27 +523,7 @@ void TensorIterator::compute_names(const TensorIteratorConfig& config) { } } -void TensorIterator::propagate_names_to_outputs() { - // names_ can be empty for two reasons: - // 1. We were performing ops on scalar tensors. Then there should be no names. - // 2. All of the defined inputs/outputs had no names. Then we shouldn't - // run name inference. - if (names_.empty()) { - return; - } - - // propagate names - for (int i = 0; i < num_outputs_; i++) { - auto& op = operands_[i]; - // must call propagate_names_to_outputs after outputs have been allocated. - TORCH_INTERNAL_ASSERT(op.tensor.defined()); - if (!names_.empty()) { - namedinference::propagate_names(op.tensor, names_); - } - } -} - -void TensorIterator::coalesce_dimensions() { +void TensorIteratorBase::coalesce_dimensions() { if (ndim() <= 1) { return; } @@ -583,7 +576,7 @@ void TensorIterator::coalesce_dimensions() { has_coalesced_dimensions_ = true; } -int64_t TensorIterator::numel() const { +int64_t TensorIteratorBase::numel() const { int64_t numel = 1; for (int64_t size : shape_) { numel *= size; @@ -591,7 +584,7 @@ int64_t TensorIterator::numel() const { return numel; } -StrideVector TensorIterator::get_dim_strides(int dim) const { +StrideVector TensorIteratorBase::get_dim_strides(int dim) const { auto dims = ndim(); auto inner_strides = StrideVector(); for (auto& op : operands_) { @@ -600,7 +593,7 @@ StrideVector TensorIterator::get_dim_strides(int dim) const { return inner_strides; } -SmallVector TensorIterator::get_data_ptrs(ArrayRef base, IntArrayRef counter) const { +SmallVector TensorIteratorBase::get_data_ptrs(ArrayRef base, IntArrayRef counter) const { auto ptrs = SmallVector(base); for (int dim = 0; dim < ndim(); dim++) { int64_t value = counter[dim]; @@ -611,7 +604,7 @@ SmallVector TensorIterator::get_data_ptrs(ArrayRef base, IntArr return ptrs; } -SmallVector TensorIterator::get_base_ptrs() const { +SmallVector TensorIteratorBase::get_base_ptrs() const { auto ptrs = SmallVector(); for (int i = 0; i < ntensors(); i++) { ptrs.push_back((char*)data_ptr(i)); @@ -619,7 +612,7 @@ SmallVector TensorIterator::get_base_ptrs() const { return ptrs; } -bool TensorIterator::is_dim_reduced(int dim) const { +bool TensorIteratorBase::is_dim_reduced(int dim) const { for (auto& op : operands_) { if (op.is_output && op.stride_bytes[dim] == 0 && shape_[dim] > 1) { return true; @@ -628,7 +621,7 @@ bool TensorIterator::is_dim_reduced(int dim) const { return false; } -void TensorIterator::permute_dimensions(IntArrayRef perm) { +void TensorIteratorBase::permute_dimensions(IntArrayRef perm) { TORCH_INTERNAL_ASSERT(perm.size() == ndim()); auto reorder = [perm](IntArrayRef data) { @@ -648,7 +641,7 @@ void TensorIterator::permute_dimensions(IntArrayRef perm) { } } -int64_t TensorIterator::num_output_elements() const { +int64_t TensorIteratorBase::num_output_elements() const { int64_t elem = 1; for (int dim = 0; dim < ndim(); dim++) { if (operands_[0].stride_bytes[dim] != 0 || shape_[dim] == 0) { @@ -658,7 +651,7 @@ int64_t TensorIterator::num_output_elements() const { return elem; } -int TensorIterator::num_reduce_dims() const { +int TensorIteratorBase::num_reduce_dims() const { int count = 0; for (int dim = 0; dim < ndim(); dim++) { if (operands_[0].stride_bytes[dim] == 0) { @@ -683,11 +676,11 @@ int TensorIterator::num_reduce_dims() const { } \ } -void TensorIterator::for_each(loop_t loop, int64_t grain_size) { +void TensorIteratorBase::for_each(loop_t loop, int64_t grain_size) { for_each(LOOP_WRAPPER(ntensors(), loop), grain_size); } -void TensorIterator::for_each(loop2d_t loop, int64_t grain_size) { +void TensorIteratorBase::for_each(loop2d_t loop, int64_t grain_size) { int64_t numel = this->numel(); if (numel == 0) { return; @@ -700,7 +693,7 @@ void TensorIterator::for_each(loop2d_t loop, int64_t grain_size) { } } -StrideVector TensorIterator::get_strides() const { +StrideVector TensorIteratorBase::get_strides() const { StrideVector strides; for (int dim = 0; dim < ndim(); dim++) { for (int arg = 0; arg < ntensors(); arg++) { @@ -710,11 +703,11 @@ StrideVector TensorIterator::get_strides() const { return strides; } -void TensorIterator::serial_for_each(loop_t loop, Range range) const { +void TensorIteratorBase::serial_for_each(loop_t loop, Range range) const { serial_for_each(LOOP_WRAPPER(ntensors(), loop), range); } -void TensorIterator::serial_for_each(loop2d_t loop, Range range) const { +void TensorIteratorBase::serial_for_each(loop2d_t loop, Range range) const { if (range.size() == 0) { return; } @@ -738,12 +731,12 @@ void TensorIterator::serial_for_each(loop2d_t loop, Range range) const { } } -bool TensorIterator::is_trivial_1d() const { +bool TensorIteratorBase::is_trivial_1d() const { // TODO: check for casting once it's supported return ndim() == 1; } -bool TensorIterator::is_contiguous() const { +bool TensorIteratorBase::is_contiguous() const { if (numel() == 1) { return true; } @@ -754,7 +747,7 @@ bool TensorIterator::is_contiguous() const { } -bool TensorIterator::is_scalar(int arg) const { +bool TensorIteratorBase::is_scalar(int arg) const { const auto& stride = operands_[arg].stride_bytes; for (int i = 0; i < ndim(); i++) { if (stride[i] != 0 && shape_[i] != 1) { @@ -764,14 +757,16 @@ bool TensorIterator::is_scalar(int arg) const { return true; } -bool TensorIterator::is_cpu_scalar(int arg) const { +bool TensorIteratorBase::is_cpu_scalar(int arg) const { return is_scalar(arg) && device(arg).is_cpu(); } -void TensorIterator::cast_outputs() { +void TensorIteratorBase::cast_outputs() { for (auto& op : operands_) { if (op.is_output && op.original_tensor.defined() && op.original_tensor.scalar_type() != op.current_dtype) { + // TODO: Now that set_output resizes both the original_tensor + // and tensor, this condition should no longer ever be true if (op.original_tensor.sizes() != op.tensor.sizes()){ op.original_tensor.resize_as_(op.tensor).as_strided_(op.tensor.sizes(), op.tensor.strides()); } @@ -781,19 +776,19 @@ void TensorIterator::cast_outputs() { } } -void* TensorIterator::data_ptr(int arg) const { +void* TensorIteratorBase::data_ptr(int arg) const { return operands_[arg].data; } -void TensorIterator::remove_operand(int arg) { +void TensorIteratorBase::remove_operand(int arg) { operands_.erase(operands_.begin() + arg); } -void TensorIterator::unsafe_replace_operand(int arg, void* data) { +void TensorIteratorBase::unsafe_replace_operand(int arg, void* data) { operands_[arg].data = data; } -void TensorIterator::narrow(int dim, int64_t start, int64_t size) { +void TensorIteratorBase::narrow(int dim, int64_t start, int64_t size) { TORCH_INTERNAL_ASSERT(dim < ndim() && size >= 1); shape_[dim] = size; view_offsets_[dim] += start; @@ -805,7 +800,7 @@ void TensorIterator::narrow(int dim, int64_t start, int64_t size) { } } -void TensorIterator::select_all_keeping_dim(int start_dim, IntArrayRef indices) { +void TensorIteratorBase::select_all_keeping_dim(int start_dim, IntArrayRef indices) { TORCH_INTERNAL_ASSERT(start_dim <= ndim()); for (int i = start_dim; i < ndim(); ++i) { for (auto& op : operands_) { @@ -815,18 +810,22 @@ void TensorIterator::select_all_keeping_dim(int start_dim, IntArrayRef indices) } } -TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a, - const Tensor& b) { - return TensorIteratorConfig() - .set_check_mem_overlap(true) - .add_output(out) - .add_input(a) - .add_input(b) - .allow_cpu_scalars(true) - .promote_inputs_to_common_dtype(true) - .cast_common_dtype_to_outputs(true) - .enforce_safe_casting_to_output(true) - .build(); +void TensorIteratorBase::build_binary_op(const Tensor& out, const Tensor& a, const Tensor& b) { + build(TensorIteratorConfig() + .set_check_mem_overlap(true) + .add_output(out) + .add_input(a) + .add_input(b) + .allow_cpu_scalars(true) + .promote_inputs_to_common_dtype(true) + .cast_common_dtype_to_outputs(true) + .enforce_safe_casting_to_output(true)); +} + +TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a, const Tensor& b) { + TensorIterator iter; + iter.build_binary_op(out, a, b); + return iter; } // Helper to construct a binary op that promotes integer inputs to float. @@ -847,7 +846,15 @@ TensorIterator TensorIterator::binary_float_op(Tensor& out, const Tensor& a, TensorIterator TensorIterator::comparison_op(Tensor& out, const Tensor& a, const Tensor& b) { - return TensorIteratorConfig() + // Note [special-case bool outputs] + // We explicitly don't call `cast_common_dtype_to_outputs` when the output tensor + // has `bool` dtype. This is a performance optimization: the functional + // version of all comparison/logical ops uses a bool output tensor, and we'd like to + // avoid creating a temporary copy of the output. + // However, note that all kernels using this TensorIterator will need to special-case when + // the output tensor has bool dtype, and provide a lambda of type (scalar_t, scalar_t -> bool). + if (out.scalar_type() == kBool) { + return TensorIteratorConfig() .set_check_mem_overlap(true) .add_output(out) .add_input(a) @@ -855,6 +862,17 @@ TensorIterator TensorIterator::comparison_op(Tensor& out, const Tensor& a, .allow_cpu_scalars(true) .promote_inputs_to_common_dtype(true) .build(); + } else { + return TensorIteratorConfig() + .set_check_mem_overlap(true) + .add_output(out) + .add_input(a) + .add_input(b) + .allow_cpu_scalars(true) + .promote_inputs_to_common_dtype(true) + .cast_common_dtype_to_outputs(true) + .build(); + } } TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a) { @@ -868,6 +886,18 @@ TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a) { .build(); } +TensorIterator TensorIterator::unary_float_op(Tensor& out, const Tensor& a) { + return TensorIteratorConfig() + .set_check_mem_overlap(true) + .add_output(out) + .add_input(a) + .promote_inputs_to_common_dtype(true) + .cast_common_dtype_to_outputs(true) + .enforce_safe_casting_to_output(true) + .promote_integer_inputs_to_float(true) + .build(); +} + TensorIterator TensorIterator::nullary_op(Tensor& out) { return TensorIteratorConfig() .set_check_mem_overlap(true) @@ -914,14 +944,21 @@ TensorIterator TensorIterator::reduce_op(Tensor& out1, Tensor& out2, const Tenso .build(); } -void TensorIterator::populate_operands(TensorIteratorConfig& config) { - for (int i = 0; i < config.tensors_.size(); i++) { - operands_.emplace_back(std::move(config.tensors_[i])); +void TensorIteratorBase::populate_operands(TensorIteratorConfig& config) { + for (auto& tensor: config.tensors_) { + // If *any* of the arguments is a meta tensor, the overall + // computation is a meta computation (don't do any work, + // just compute output information). This aligns with + // our multiple dispatch semantics. + if (tensor.is_meta()) { + is_meta_ = true; + } + operands_.emplace_back(std::move(tensor)); } num_outputs_ = config.num_outputs_; } -void TensorIterator::mark_outputs() { +void TensorIteratorBase::mark_outputs() { // TODO: merge this into populate_operands for (int i = 0; i < num_outputs_; i++) { operands_[i].is_output = true; @@ -938,7 +975,7 @@ void TensorIterator::mark_outputs() { } } -void TensorIterator::mark_resize_outputs(const TensorIteratorConfig& config) { +void TensorIteratorBase::mark_resize_outputs(const TensorIteratorConfig& config) { // Outputs cannot be broadcasted. Check that the shape of the outputs matches // the inferred shape. There's an exception for write-only tensors to support // our legacy behavior that functions with `out=` arguments resize their @@ -960,10 +997,14 @@ void TensorIterator::mark_resize_outputs(const TensorIteratorConfig& config) { } } -void TensorIterator::compute_mem_overlaps(const TensorIteratorConfig& config) { +void TensorIteratorBase::compute_mem_overlaps(const TensorIteratorConfig& config) { if (!config.check_mem_overlap_) { return; } + if (is_meta_) { + // We don't have pointer addresses, cannot check for overlap! + return; + } for (int i = 0; i < num_outputs_; i++) { const auto& output = operands_[i].tensor; if (!output.defined()) continue; @@ -975,7 +1016,7 @@ void TensorIterator::compute_mem_overlaps(const TensorIteratorConfig& config) { } } -void TensorIterator::compute_shape(const TensorIteratorConfig& config) { +void TensorIteratorBase::compute_shape(const TensorIteratorConfig& config) { if (config.static_shape_.has_value()) { shape_ = *config.static_shape_; return; @@ -1011,7 +1052,7 @@ void TensorIterator::compute_shape(const TensorIteratorConfig& config) { } } -void TensorIterator::compute_strides(const TensorIteratorConfig& config) { +void TensorIteratorBase::compute_strides(const TensorIteratorConfig& config) { for (auto& op : operands_) { if (op.tensor.defined()) { IntArrayRef original_shape = config.static_shape_ ? shape_ : op.tensor.sizes(); @@ -1034,7 +1075,7 @@ void TensorIterator::compute_strides(const TensorIteratorConfig& config) { } } -bool TensorIterator::can_use_32bit_indexing() const { +bool TensorIteratorBase::can_use_32bit_indexing() const { int64_t max_value = std::numeric_limits::max(); if (numel() > max_value) { return false; @@ -1051,7 +1092,7 @@ bool TensorIterator::can_use_32bit_indexing() const { return true; } -std::unique_ptr TensorIterator::split(int dim) { +std::unique_ptr TensorIteratorBase::split(int dim) { TORCH_INTERNAL_ASSERT(dim >= 0 && dim < ndim() && shape()[dim] >= 2); std::unique_ptr copy(new TensorIterator(*this)); @@ -1067,7 +1108,7 @@ std::unique_ptr TensorIterator::split(int dim) { } -int TensorIterator::get_dim_to_split() const { +int TensorIteratorBase::get_dim_to_split() const { TORCH_INTERNAL_ASSERT(ndim() >= 1); int64_t max_extent = -1; int dim_to_split = -1; @@ -1088,7 +1129,7 @@ int TensorIterator::get_dim_to_split() const { return dim_to_split; } -bool TensorIterator::fast_set_up(const TensorIteratorConfig& config) { +bool TensorIteratorBase::fast_set_up(const TensorIteratorConfig& config) { // This function tries to do a fast setup to avoid needless reordering of dimensions and tracking output strides // Return true if it can do fast setup or false otherwise // TODO enable fast handling for reductions @@ -1105,11 +1146,8 @@ bool TensorIterator::fast_set_up(const TensorIteratorConfig& config) { auto& op = operands_[i]; if (!op.tensor.defined()) { TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i); - op.tensor = at::empty(shape_, op.options(), MemoryFormat::Contiguous); - op.current_dtype = op.target_dtype; - } else if (op.will_resize) { - at::native::resize_output(op.tensor, shape_); } + set_output(i, shape_, {}, op.options().memory_format(MemoryFormat::Contiguous), names_); } break; } @@ -1119,13 +1157,8 @@ bool TensorIterator::fast_set_up(const TensorIteratorConfig& config) { auto& op = operands_[i]; if (!op.tensor.defined()) { TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i); - op.tensor = at::empty(shape_, op.options(), MemoryFormat::ChannelsLast); - op.current_dtype = op.target_dtype; - } else if (op.will_resize) { - at::native::resize_output(op.tensor, shape_); - op.tensor.unsafeGetTensorImpl()->empty_tensor_restride( - MemoryFormat::ChannelsLast); } + set_output(i, shape_, {}, op.options().memory_format(MemoryFormat::ChannelsLast), names_); } break; } @@ -1141,12 +1174,8 @@ bool TensorIterator::fast_set_up(const TensorIteratorConfig& config) { auto& op = operands_[i]; if (!op.tensor.defined()) { TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i); - op.tensor = at::empty_strided(shape_, operands_[i_defined].tensor.strides(), op.options()); - op.current_dtype = op.target_dtype; - } else if (op.will_resize) { - at::native::resize_output(op.tensor, shape_); - op.tensor.as_strided_(shape_, operands_[i_defined].tensor.strides()); } + set_output(i, shape_, operands_[i_defined].tensor.strides(), op.options(), names_); } break; } @@ -1171,7 +1200,7 @@ bool TensorIterator::fast_set_up(const TensorIteratorConfig& config) { return true; } -FastSetupType TensorIterator::compute_fast_setup_type(const TensorIteratorConfig& config) { +FastSetupType TensorIteratorBase::compute_fast_setup_type(const TensorIteratorConfig& config) { if (is_reduction_ || !all_ops_same_shape_) { return FastSetupType::NONE; } @@ -1223,11 +1252,9 @@ FastSetupType TensorIterator::compute_fast_setup_type(const TensorIteratorConfig return FastSetupType::NONE; } -TensorIterator::TensorIterator(TensorIteratorConfig& config) { - build(config); -} +TensorIteratorBase::TensorIteratorBase() {} -void TensorIterator::build(TensorIteratorConfig& config) { +void TensorIteratorBase::build(TensorIteratorConfig& config) { // populate some persistent configuration fields is_reduction_ = config.is_reduction_; @@ -1251,14 +1278,14 @@ void TensorIterator::build(TensorIteratorConfig& config) { // compute each tensor's stride after broadcasting compute_strides(config); // re-order dimensions to improve coalescing - reorder_dimensions(config); + reorder_dimensions(); // allocate the output tensor if it's not provided allocate_or_resize_outputs(); // coalesce adjacent dimensions when possible - coalesce_dimensions(); + if (!is_meta_) coalesce_dimensions(); } - // perform name inference - propagate_names_to_outputs(); + + if (is_meta_) return; for (auto& op : operands_) { TORCH_INTERNAL_ASSERT(op.tensor.defined()); @@ -1273,14 +1300,125 @@ void TensorIterator::build(TensorIteratorConfig& config) { view_offsets_ = DimVector(ndim_offsets, 0); } -SplitUntil32Bit TensorIterator::with_32bit_indexing() const { +// This is the structured kernels implementation of set_output. It is +// NEVER actually called directly; instead, a subclass of TensorIteratorBase +// will override set_output to actually do the operation, and then call +// set_output on the TensorIteratorBase to setup TI's metadata. +// The precondition for this function is that maybe_get_output() now +// unconditionally returns a real Tensor (prior to output setting, +// this function may return an undefined tensor.) +void TensorIteratorBase::set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) { + auto& op = operands_[output_idx]; + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_); + const auto& t = maybe_get_output(output_idx); + TORCH_INTERNAL_ASSERT(t.defined()); + if (!op.tensor.defined()) { + op.tensor = t; + op.current_dtype = op.target_dtype; + } else if (op.will_resize) { + if (op.original_tensor.defined()) { + // OK, so this is pretty weird. To understand how we can end up in + // this situation, first look at Marker [Output original_tensor is set]. + // That is the sole site where original_tensor may be set on an + // output operand. Essentially, when we are given an explicit output + // tensor whose dtype doesn't match the computed common dtype from + // the input operands, we do a switcheroo: we replace the (incorrectly + // typed) output tensor with a correctly typed, *temporary* tensor, + // and remember the original tensor in original_tensor (which will + // then get written back to when we cast_outputs). + // + // Now, what if the given output tensor also happened to be zero + // size (meaning that we will_resize it)? Well, at the call site + // above, we don't necessarily(*) know what the correct shape should + // be, so we give the temporary tensor the same shape as the original. + // At the time of set_output is when we DO know what the correct size + // is, and the subclass's implementation of set_output in structured class + // responsible for resizing original_tensor. But we still have this + // incorrectly sized temporary output which the structured subclass + // knows nothing about, so we are obligated to also resize it here. + // + // This is a slight memory pessimization, because previously + // original_tensor only got resized at the end of the computation, rather + // than at the beginning (as happens here). However, the peak memory + // usage is the same, since you need to materialize both original tensor + // and temporary tensor to do the copy. + // + // (*) Actually, technically, we probably do know what the shape + // should be, since we do shape computation before dtype computation. + // So hypothetically we could figure out what the correct shape is + // at that point in time and directly allocate the temporary at + // the right size. + // + // But a better solution is to delay allocation of temporaries until + // after TensorIterator builder, waiting until we actually want + // to do the computation. That would also remove the necessity + // for the is_meta_ test. + TORCH_INTERNAL_ASSERT(op.original_tensor.is_same(t)); + TORCH_INTERNAL_ASSERT(!op.tensor.is_same(t)); + at::native::resize_output(op.tensor, sizes); + if (!strides.empty()) { + TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value()); + op.tensor.as_strided_(sizes, strides); + } else if (options.memory_format_opt().has_value()) { + op.tensor.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt()); + } + } + } +} + +// This is the "traditional" implementation of set_output. On TensorIterator +// instances, it is invoked directly from various call sites in this file. No +// funny business. +void TensorIterator::set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) { + // NB: intentionally no superclass call + auto& op = operands_[output_idx]; + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_); + if (!op.tensor.defined()) { + if (strides.empty()) { + if (is_meta_) { + op.tensor = at::empty_meta(sizes, options); + } else { + op.tensor = at::empty(sizes, options); + } + } else { + if (is_meta_) { + TORCH_INTERNAL_ASSERT(0, "meta strided not yet implemented"); + } else { + op.tensor = at::empty_strided(sizes, strides, options); + } + } + op.current_dtype = op.target_dtype; + } else if (op.will_resize) { + at::native::resize_output(op.tensor, sizes); + if (!strides.empty()) { + TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value()); + op.tensor.as_strided_(sizes, strides); + } else if (options.memory_format_opt().has_value()) { + op.tensor.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt()); + } + } + if (!names.empty()) { + TORCH_INTERNAL_ASSERT(op.tensor.defined()); + namedinference::propagate_names(op.tensor, names); + } +} + +// Not actually used by anything (TensorIterator subclass calls +// its own implementation of set_output which knows exactly where +// all the outputs are), but we have to provide all pure virtual methods +// for MetaBase +const Tensor& TensorIterator::maybe_get_output(int64_t output_idx) { + return operands_[output_idx].tensor; +} + +SplitUntil32Bit TensorIteratorBase::with_32bit_indexing() const { return SplitUntil32Bit(*this); } /// SplitUntil32Bit. Recursively splits an iterator into sub-iterators that /// can use 32-bit indexing. -SplitUntil32Bit::iterator::iterator(const TensorIterator& iter) { +SplitUntil32Bit::iterator::iterator(const TensorIteratorBase& iter) { vec.emplace_back(new TensorIterator(iter)); vec.emplace_back(nullptr); // ++ first pops the last element ++(*this); diff --git a/aten/src/ATen/TensorIterator.h b/aten/src/ATen/TensorIterator.h new file mode 100644 index 0000000000000..5132fb05dcc23 --- /dev/null +++ b/aten/src/ATen/TensorIterator.h @@ -0,0 +1,564 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +// TensorIterator is a helper class for element-wise operations, such as +// arithmetic, comparisons, and trigonometric functions. It handles +// broadcasting and type conversions of operands. +// +// This is inspired by NumPy's Array Iterator API (NpyIter). +// +// The files Loops.h and Loops.cuh provide functions to build kernels that +// use TensorIterator. +// +// Example: +// +// auto iter = TensorIteratorConfig() +// .add_output(output) +// .add_input(input) +// .build() +// +// [MyKernel.cpp / MyKernel.cu] +// cpu_kernel(iter, [](float a, float b) { +// return a + b; +// }); +// +// gpu_kernel(iter, []GPU_LAMBDA(float a, float b) -> float { +// return a + b; +// }); +// +// Note [Common Dtype Computation] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Some operations have a natural notion of a "common dtype" or +// "computation dtype" where all inputs are cast to one dtype, the +// operation is performed, and then the results are cast to all outputs. +// +// TensorIterator infers a common dtype if all inputs have the same dtype, +// and it computes one using type promotion rules on its inputs if +// promote_inputs_to_common_dtype_ is true. Attempting to query +// a common dtype otherwise will throw an exception. +// +// Note that the outputs are not considered when computing a common dtype. + +namespace at { + +namespace internal { +// This parameter is heuristically chosen to determine the minimum number of +// work that warrants parallelism. For example, when summing an array, it is +// deemed inefficient to parallelise over arrays shorter than 32768. Further, +// no parallel algorithm (such as parallel_reduce) should split work into +// smaller than GRAIN_SIZE chunks. +constexpr int64_t GRAIN_SIZE = 32768; +} // namespace internal + +struct DimCounter { + DimCounter(IntArrayRef shape, Range range); + + void increment(const std::array& step); + bool is_done() const; + std::array max_2d_step() const; + + IntArrayRef shape; + Range range; + DimVector values; + int64_t offset; +}; + +struct TORCH_API OperandInfo { + using StrideVector = SmallVector; + OperandInfo() {} + explicit OperandInfo(Tensor t) : tensor(std::move(t)) { + if (tensor.defined()) { + device = tensor.device(); + target_dtype = tensor.scalar_type(); + current_dtype = target_dtype; + } + validate(); + } + + /// Stride after broadcasting. The stride is in bytes, not number of elements. + StrideVector stride_bytes; + + /// The tensor operand. Note that the strides, data pointer, and + /// other attributes may differ due to dimension reordering and + /// coalescing. + Tensor tensor; + + // Save the original tensor operand in cases when an output is modified + // (e.g. if dtype is changed) + Tensor original_tensor; + + /// The desired device and type for the operand. For inputs, this specifies that + /// the input should be converted to this type if necessary. For outputs, this + /// specifies which type to allocate. target_dtype and device are initialized with the dtype and device of the tensor + /// but during type promotion target_dtype value can become different from tensor's dtype + /// also, during type promotion target_dtype and device can be set for an undefined tensor so that tensor can be properly + /// constructed later. + Device device = kCPU; + ScalarType target_dtype = ScalarType::Undefined; + // Caches dtype of the tensor, because scalar_type is an expensive operation + // If dtype of the tensor is changed (e.g. as a result of type promotion or in allocate_outputs), this + //value should be changed too. + ScalarType current_dtype = ScalarType::Undefined; + + bool is_type_defined() const { return target_dtype != ScalarType::Undefined; } + TensorOptions options() const { + return TensorOptions(target_dtype).device(device); + } + + /// The data pointer. This may be different from tensor.data_ptr() if the + /// iterator is split. + void* data = nullptr; + + bool is_output = false; + + bool will_resize = false; + + bool is_read_write = false; + + void validate() { + TORCH_CHECK( + !tensor.defined() || tensor.layout() == kStrided, + "unsupported tensor layout: ", tensor.layout()); + } +}; + +struct SplitUntil32Bit; + +enum class FastSetupType : uint8_t { + NONE, + CONTIGUOUS, + CHANNELS_LAST, + NON_OVERLAPPING_DENSE +}; + +class TensorIteratorConfig; +struct TensorIterator; + +struct TORCH_API TensorIteratorBase : public impl::MetaBase { + using DimMask = std::bitset<64>; + using PtrVector = SmallVector; + using StrideVector = SmallVector; + + TensorIteratorBase(); + void build(TensorIteratorConfig&); + + // The inner-loop function operates on the fastest moving dimension. It + // implements element-wise operations in terms of 1-d strided tensors. + // + // Arguments: + // data: data pointers for each operand (length `ntensors`) + // strides: stride for each operand (length `ntensors`) + // size: size of inner loop + // + // The `size` often matches shape[0], but may be smaller due to + // parallelization of the inner loop. + using loop_t = c10::function_ref; + using loop2d_t = c10::function_ref; + + using loop_subiter_t = c10::function_ref; + + void foreach_reduced_elt(loop_subiter_t loop, bool parallelize=true); + + int ndim() const { return shape_.size(); } + IntArrayRef shape() const { return shape_; } + int64_t numel() const; + int ntensors() const { return operands_.size(); } + int noutputs() const { return num_outputs_; } + int ninputs() const { return ntensors() - noutputs(); } + IntArrayRef view_offsets() const { return view_offsets_; } + + /// number of elements in the output operand. this is the same as numel() for + /// operations that are not reductions. + int64_t num_output_elements() const; + + /// number of reduced dimensions in a reduction operation + int num_reduce_dims() const; + + /// 1-dimensional iteration and no buffering or type conversion + bool is_trivial_1d() const; + /// Reducible to 1-dimensional and all operands are contiguous + bool is_contiguous() const; + bool is_dim_reduced(int dim) const; + + /// Accessors for each operand + IntArrayRef strides(int arg) const { return operands_[arg].stride_bytes; } + void* data_ptr(int arg) const; + ScalarType dtype(int arg=0) const { return operands_[arg].current_dtype; } + ScalarType common_dtype() const { + TORCH_INTERNAL_ASSERT(common_dtype_ != ScalarType::Undefined, "Queried for invalid common dtype!"); + return common_dtype_; + } + ScalarType input_dtype(int arg=0) const { return operands_[num_outputs_ + arg].current_dtype; } + Device device(int arg=0) const { return operands_[arg].device; } + DeviceType device_type(int arg=0) const { return device(arg).type(); } + int64_t element_size(int arg) const { return elementSize(dtype(arg)); } + bool is_scalar(int arg) const; + bool is_cpu_scalar(int arg) const; + + const Tensor& tensor(int arg) const { return operands_[arg].tensor; } + Tensor& tensor(int arg) { return operands_[arg].tensor; } + + Tensor output(int arg=0) const { + AT_ASSERT(arg < num_outputs_); + return operands_[arg].tensor; + } + + // Copies from temporary outputs back to the original outputs + // NOTE: only used on CPU + void cast_outputs(); + + Tensor input(int arg=0) const { + AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_); + return operands_[num_outputs_ + arg].tensor; + } + + /// Removes an operand from this iterator + void remove_operand(int arg); + /// Shrinks an iterated dimension + void narrow(int dim, int64_t start, int64_t size); + /// Narrows every dim after and including `start_dim` to size one. + void select_all_keeping_dim(int start_dim, IntArrayRef starts); + /// Replaces the data pointer for the operand at index `arg`. + /// The new pointer should have the same sizes, strides and dtype as the + /// original + void unsafe_replace_operand(int arg, void* data); + + /// Splits this TensorIterator into two iterators. Together they iterate over + /// the entire operation. Used by `with_32bit_indexing()`. + std::unique_ptr split(int dim); + + /// Returns the dimension with the largest extent: (size[dim]-1) * stride[dim] + int get_dim_to_split() const; + + template + T scalar_value(int arg) { + auto& op = operands_[arg]; + return c10::fetch_and_cast(op.tensor.scalar_type(), op.data); + } + + void for_each(loop_t loop, int64_t grain_size = at::internal::GRAIN_SIZE); + void for_each(loop2d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE); + + void parallel_reduce(loop2d_t loop); + + void serial_for_each(loop_t loop, Range range) const; + void serial_for_each(loop2d_t loop, Range range) const; + + /// Create a strides array for a Tensor with shape of this iterator. The + /// parameter `element_size` specifies the size of Tensor's data type in + /// bytes (e.g. `4` for `float`) + StrideVector compatible_stride(int element_size) const; + + /// Inverts the re-ordering done by reorder_dimensions. This can only be + /// called *before* coalesce_dimensions() is called. + DimVector invert_perm(IntArrayRef input) const; + + /// Reapply same re-ordering as it is done by reorder_dimensions. This can + /// only be called *before* coalesce_dimensions() is called. + DimVector apply_perm_and_mul(IntArrayRef input, int mul) const; + + /// Helper functions for CPU iteration + StrideVector get_dim_strides(int dim) const; + StrideVector get_strides() const; + StrideVector get_inner_strides() const { return get_dim_strides(0); } + PtrVector get_data_ptrs(ArrayRef base, IntArrayRef counter) const; + PtrVector get_base_ptrs() const; + + /// true if the stride computation can use 32-bit arithmetic. Used by GPU kernels + bool can_use_32bit_indexing() const; + + /// An "iteratable" object that recursively splits this iterator into sub-iterators + /// that can use 32-bit indexing. + SplitUntil32Bit with_32bit_indexing() const; + + /// If the kernel should accumulate into the output. Only relevant for CUDA + /// reductions. + bool should_accumulate() const { return accumulate_; } + + /// Whether this iterator produces the actual output, + /// as opposed to something that will be accumulated further. Only relevant for + /// CUDA reductions. + bool is_final_output() const { return final_output_; } + + bool has_contiguous_first_dim() const { + int num_tensors = ntensors(); + for (int i = 0; i < num_tensors; i++) { + if (strides(i)[0] != element_size(i)) { + return false; + } + } + return true; + } + + void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) override; + + void build_binary_op(const Tensor& out, const Tensor& a, const Tensor& b); + +protected: + // Mutable reference as it moves tensors out of TensorIteratorConfig + void populate_operands(TensorIteratorConfig&); + void mark_outputs(); + void mark_resize_outputs(const TensorIteratorConfig&); + void compute_mem_overlaps(const TensorIteratorConfig&); + void compute_shape(const TensorIteratorConfig&); + void compute_strides(const TensorIteratorConfig&); + void reorder_dimensions(); + void permute_dimensions(IntArrayRef perm); + void compute_types(const TensorIteratorConfig&); + ScalarType compute_common_dtype(); + void allocate_or_resize_outputs(); + bool fast_set_up(const TensorIteratorConfig&); + FastSetupType compute_fast_setup_type(const TensorIteratorConfig&); + void compute_names(const TensorIteratorConfig&); + void propagate_names_to_outputs(); + void coalesce_dimensions(); + +protected: + + /// Records the "computation" shape of the output tensor. The computation + /// shape is different from the regular shape in a few ways: + /// + /// - The shape may be permuted (via permute_dimensions) so that we + /// process the dimensions in the most computationally efficient order + /// (rather than the logical order given to us by the users.) + /// - The shape may have adjacent dimensions collapsed (via + /// coalesce_dimensions) so that we minimize the number of + /// dimensions we have to explicitly iterate over. For example, + /// a pointwise operation on a contiguous tensor "computationally" + /// consists of only a single dimension. + /// + /// In other words, the computation shape is the output shape as it + /// actually matters for implementing the kernel, but not necessarily the + /// output shape that the user will see in the end. + /// + /// The lifecycle of mutations to shape_ in TensorIterator: + /// - declare_static_shape() sets an initial shape explicitly + /// provided by user, otherwise + /// - compute_shape() computes the true (non-computational) shape + /// specified by the user. + /// - reorder_dimensions() reorders dimensions to improve coalescing. + /// - coalesce_dimensions() then coalesces adjacent dimensions when + /// possible. + /// + /// The shape may also be further modified if we create sub-TensorIterators, + /// e.g., via narrow or select_all_keeping_dim. + DimVector shape_; + + /// Temporarily records the permutation computed by reorder_dimensions. + /// This permutation maps the computation output dimension (dim) to + /// the original true output dimension (perm_[dim]). It is used by + /// invert_perm to undo the permutation. After coalesce_dimensions is + /// called, the permutation is no longer valid (as, in general, there + /// is no permutation that will make computation dimensions to + /// output dimensions); methods that manipulate perm_ are obligated + /// to test that !has_coalesced_dimensions + DimVector perm_; + + /// Has coalesce_dimensions() (or any moral equivalent, e.g., fast_build()) + /// been called? This is SOLELY used to check validity of perm_. + bool has_coalesced_dimensions_ = false; + + /// The index offsets into the original tensors for each dimension. + /// This is only non-zero when you narrow() a TensorIterator (e.g., + /// when you make sub-TensorIterators). + DimVector view_offsets_; + + /// The computed names of the output tensor. Computed by compute_names() + NameVector names_; + + /// The operands of the TensorIterator: both the inputs and outputs. The + /// outputs MUST come first in the operands_ list. There is always an + /// operand for each output of the TensorIterator, even if TensorIterator + /// will ultimately be responsible for allocating the output; in those + /// cases, tensor is simply undefined (and will be populated later + /// during build()). + /// + /// This list is initially populated prior to build(), but build() mutates + /// OperandInfo to populate more information. + SmallVector operands_; + + /// Number of outputs in operands_ (the length of the outputs prefix + /// in operands_). + int num_outputs_ = 0; + + /// Whether or not all operands have the same shape. Having all the same + /// shape affects whether or not the iterator is eligible for fast setup. + bool all_ops_same_shape_ = false; + + /// The "computation" dtype of TensorIterator, specifying what the dtype + /// we will do the internal computation in TensorIterator. Typically, + /// this matches the dtype of the output tensors, but not always! + ScalarType common_dtype_ = ScalarType::Undefined; + + /// Set by split(), see should_accumulate() and is_final_output() + bool accumulate_ = false; + bool final_output_ = true; + + // From TensorIteratorConfig + bool is_reduction_ = false; + + /// Set by populate_operands(), says if we're handling meta tensors + bool is_meta_ = false; +}; + +struct TORCH_API TensorIterator final : public TensorIteratorBase { + TensorIterator() : TensorIteratorBase() {} + // Slicing is OK, TensorIterator guaranteed NOT to have any fields + TensorIterator(const TensorIteratorBase& iter) : TensorIteratorBase(iter) {} + + static TensorIterator binary_float_op(Tensor& out, const Tensor& a, const Tensor& b); + static TensorIterator binary_op(Tensor& out, const Tensor& a, const Tensor& b); + static TensorIterator comparison_op(Tensor& out, const Tensor& a, const Tensor& b); + static TensorIterator unary_op(Tensor& out, const Tensor& a); + static TensorIterator unary_float_op(Tensor& out, const Tensor& a); + static TensorIterator nullary_op(Tensor& out); + static TensorIterator reduce_op(Tensor& out, const Tensor& a); + static TensorIterator reduce_op(Tensor& out1, Tensor& out2, const Tensor& a); + + const Tensor& maybe_get_output(int64_t output_idx) override; + void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) override; +}; + +class TORCH_API TensorIteratorConfig final { +public: + friend struct TensorIteratorBase; + friend struct TensorIterator; + + TensorIteratorConfig() {} + + C10_DISABLE_COPY_AND_ASSIGN(TensorIteratorConfig); + + /// Construction + TensorIteratorConfig& add_output(const Tensor& output); + TensorIteratorConfig& add_input(const Tensor& input); + + // Sets the check_mem_overlap_ flag, which is true by default. + // If true, inputs are checked for partial overlap with the outputs and + // outputs are checked for internal overlap (e.g. broadcasted views). An error + // is raised if unacceptable overlap is detected. + // If you're migrating an existing operator to using TensorIterator, please + // consider if the previous implementation checked memory overlap. If it did + // not, and if the operator is idempotent (for example, Tensor.fill_(0)), then + // checking memory overlap is BC-breaking. Please don't check memory overlap + // in that case. + TensorIteratorConfig& set_check_mem_overlap(bool check_mem_overlap); + + // Sets the check_all_same_dtype_ flag, which is true by default + // If true, checks that all inputs and defined outputs have the same dtype + // Setting either of promote_inputs_to_common_dtype_ + // or cast_common_dtype_to_outputs_ to true will set + // check_all_same_dtype_ to false. + TensorIteratorConfig& check_all_same_dtype(const bool _check_all_same_dtype); + + // Sets the check_all_same_device_ flag, which is true by default + // If true, all operands must be on the same device, with the possible + // exception of CPU scalars, which can be passed to some CUDA kernels + // as kernel arguments. + TensorIteratorConfig& check_all_same_device(const bool _check_all_same_device); + + // Sets the enforce_safe_casting_to_output_ flag, which is false by default + // If true, the iterator's "common dtype" must be computable + // (see the [Common Dtype Computation] note) and + // canCast(common dtype, output dtype) must be true for all outputs. + TensorIteratorConfig& enforce_safe_casting_to_output(const bool _enforce_safe_casting_to_output); + + // Sets the promote_inputs_to_common_dtype_ flag, which is false by default + // If true, the iterator's "common dtype" is always computed (see the + // [Common Dtype Computation] note) and, on the CPU, temporary copies of + // the inputs in the common dtype are passed as the actual inputs to + // the operation. + // Setting this flag to true sets check_all_same_dtype_ to false. + TensorIteratorConfig& promote_inputs_to_common_dtype(const bool _promote_inputs_to_common_dtype); + + // Sets the promote_integer_inputs_to_float_ flag, which is false by default + // NOTE: If set to true, the promote_inputs_to_common_dtype_ must also be true. + // If true, if the iterator's "common dtype" is an integral type (including bool) + // then it is changed to the default float scalar type. + TensorIteratorConfig& promote_integer_inputs_to_float(const bool _promote_integer_inputs_to_float); + TensorIteratorConfig& is_reduction(const bool _is_reduction); + TensorIteratorConfig& allow_cpu_scalars(const bool _allow_cpu_scalars); + + // Sets the cast_common_dtype_to_outputs_ flag, which is false by default + // If true, the iterator's "common dtype" must be computatable + // (see the [Common Dtype Computation] note) and, on the CPU, temporary + // copies of the outputs are passed as the actual output to the operation. + // These temporaries are then copied to the original outputs after + // the operation is performed (see cast_outputs()). + // Setting this flag to true sets check_all_same_dtype_ to false. + TensorIteratorConfig& cast_common_dtype_to_outputs(const bool _cast_common_dtype_to_outputs); + TensorIteratorConfig& resize_outputs(bool resize_outputs); + + // Bypass output dtype/device computation and fix the dtype/device as specified here. + TensorIteratorConfig& declare_static_dtype_and_device(ScalarType dtype, Device device); + TensorIteratorConfig& declare_static_shape(IntArrayRef shape); + TensorIteratorConfig& declare_static_shape(IntArrayRef shape, IntArrayRef squash_dims); + + // It would be better if this was && qualified, but this would be at the cost + // of a lot of boilerplate above + TensorIterator build() { + TensorIterator iter; + iter.build(*this); + return iter; + } + +private: + SmallVector tensors_; + int num_outputs_ = 0; + int num_inputs_ = 0; + + c10::optional static_shape_ = c10::nullopt; + c10::optional> static_dtype_and_device_ = c10::nullopt; + bool check_mem_overlap_ = true; + bool allow_cpu_scalars_ = false; + bool is_reduction_ = false; + bool resize_outputs_ = true; + bool check_all_same_dtype_ = true; + bool check_all_same_device_ = true; + bool enforce_safe_casting_to_output_ = false; + bool promote_inputs_to_common_dtype_ = false; + bool promote_integer_inputs_to_float_ = false; + bool cast_common_dtype_to_outputs_ = false; +}; + + + +/// A container-like struct that acts as if it contains splits of a +/// TensorIterator that can use 32-bit indexing. Taken together the splits cover +/// the original TensorIterator. +struct TORCH_API SplitUntil32Bit { + struct TORCH_API iterator { + iterator() {}; + iterator(const TensorIteratorBase& iter); + iterator(iterator&&) = default; + + // Guaranteed to be a TensorIterator proper! + TensorIterator& operator*() const; + iterator& operator++(); + bool operator==(const iterator& other) const { + // two iterators are equal if they are the same object or they're both empty + return this == &other || (vec.empty() && other.vec.empty()); + } + // needed for C++11 range-based for loop + bool operator!=(const iterator& other) const { return !(*this == other); } + + /// stack of TensorIterators to be split + std::vector> vec; + }; + + SplitUntil32Bit(const TensorIteratorBase& iter) : iter(iter) {} + + iterator begin() const; + iterator end() const; + +private: + const TensorIteratorBase& iter; +}; + +} // namespace at diff --git a/aten/src/ATen/TensorMeta.cpp b/aten/src/ATen/TensorMeta.cpp new file mode 100644 index 0000000000000..6f4d667d56531 --- /dev/null +++ b/aten/src/ATen/TensorMeta.cpp @@ -0,0 +1,5 @@ +#include + +namespace at { + +} // namespace at diff --git a/aten/src/ATen/TensorMeta.h b/aten/src/ATen/TensorMeta.h new file mode 100644 index 0000000000000..1b05b6943d271 --- /dev/null +++ b/aten/src/ATen/TensorMeta.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include +#include + +namespace at { + +class Tensor; + +namespace impl { + +// Use this to define the prototype for a meta function. There are two +// versions; one that takes one argument (just the operator name), or FUNC2 +// variant that takes two arguments (operator name and overload name). +// +// Example usage: +// +// TORCH_META_FUNC2(add, Tensor) ( +// const Tensor& self, const Tensor& other +// ) { +// ... compute sizes and options ... +// set_output(sizes, options); +// } +// +#define TORCH_META_FUNC(name) void name::meta +#define TORCH_META_FUNC2(name, overload) void name##_##overload::meta + +// Use this to define the prototype for an implementation. This takes only +// one argument, which is the name of the dispatch key entry you're +// implementing. +// +// Example usage: +// +// TORCH_IMPL_FUNC(add_cpu) ( +// Tensor& result, const Tensor& self, const Tensor& other +// ) { +// ... do the actual implementation ... +// } +// +#define TORCH_IMPL_FUNC(name) void structured_##name::impl + +// Base class for all structured kernel classes. The set_output virtual +// method is varied depending whether or not the operator is +// functional/out/inplace, and could also be specialized for CPU/CUDA/etc +// (although presently it isn't). +// +// A notable subclass of this interface is TensorIteratorBase. +struct TORCH_API MetaBase { + virtual void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) = 0; + virtual const Tensor& maybe_get_output(int64_t output_idx) = 0; + void set_output(IntArrayRef sizes, TensorOptions options) { + set_output(0, sizes, {}, options, {}); + } + // Returns a reference to an undefined tensor if there is no presupplied + // output + const Tensor& maybe_get_output() { return maybe_get_output(0); } + virtual ~MetaBase() {} +}; + +} // namespace impl + +} // namespace at diff --git a/aten/src/ATen/TensorNames.cpp b/aten/src/ATen/TensorNames.cpp index 844ff4ba2bad0..a7dc0bd680363 100644 --- a/aten/src/ATen/TensorNames.cpp +++ b/aten/src/ATen/TensorNames.cpp @@ -61,10 +61,10 @@ TensorNames::TensorNames(ArrayRef names, int64_t start, int64_t end) { } TensorNames& TensorNames::unifyFromRightInplace(const TensorNames& other, const char* op_name) { - int64_t size_diff = std::labs(names_.size() - other.names_.size()); + size_t size_diff = std::labs(names_.size() - other.names_.size()); if (names_.size() > other.names_.size()) { - for (int64_t idx = size_diff; idx < names_.size(); ++idx) { + for (size_t idx = size_diff; idx < names_.size(); ++idx) { names_[idx] = names_[idx].unify(other.names_[idx - size_diff], op_name); } } else { diff --git a/aten/src/ATen/TensorNames.h b/aten/src/ATen/TensorNames.h index eeb8ec1a2a290..64bad7c5d6c6a 100644 --- a/aten/src/ATen/TensorNames.h +++ b/aten/src/ATen/TensorNames.h @@ -26,7 +26,7 @@ namespace at { namespace namedinference { // None (in tensor) cannot match A (in other) because if the None were refined // to A, `tensor` would have duplicate names [A, A]. Therefore we need to check // tensor.names [A, None] for the existence of A. -struct CAFFE2_API TensorName { +struct TORCH_API TensorName { explicit TensorName(ArrayRef origin, int origin_idx) : origin_(origin), name_(origin[maybe_wrap_dim(origin_idx, origin.size())]), @@ -41,14 +41,14 @@ struct CAFFE2_API TensorName { Dimname name_; int origin_idx_; // A named tensor can have at most 64 dims. - CAFFE2_API friend std::ostream& operator<<( + TORCH_API friend std::ostream& operator<<( std::ostream& out, const TensorName& tensorname); }; using TensorNameVec = SmallVector; -struct CAFFE2_API TensorNames { +struct TORCH_API TensorNames { explicit TensorNames(ArrayRef names); // Create TensorNames from names[start:end]. Each individual TensorName stores diff --git a/aten/src/ATen/TensorUtils.cpp b/aten/src/ATen/TensorUtils.cpp index 626e0c73e45e0..cb06876635811 100644 --- a/aten/src/ATen/TensorUtils.cpp +++ b/aten/src/ATen/TensorUtils.cpp @@ -19,6 +19,25 @@ std::ostream& operator<<(std::ostream & out, TensorGeometryArg t) { return out; } +void checkDim( + CheckedFrom c, + const Tensor& tensor, + const char* name, + int pos, // 1-indexed + int64_t dim) { + TORCH_CHECK( + tensor.dim() == dim, + "Expected ", + dim, + "-dimensional tensor, but got ", + tensor.dim(), + "-dimensional tensor for ", + TensorGeometryArg(TensorArg({tensor, name, pos})), + " (while checking arguments for ", + c, + ")"); +} + void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim) { TORCH_CHECK(t->dim() == dim, "Expected ", dim, "-dimensional tensor, but got ", t->dim(), @@ -335,8 +354,7 @@ c10::optional> computeStride( // we use the stride as if it were computed via resize. // This could perhaps be combined with the below code, but the complexity // didn't seem worth it. - int64_t numel = std::accumulate(oldshape.begin(), oldshape.end(), 1, - std::multiplies()); + const int64_t numel = prod_intlist(oldshape); if (numel == 0 && oldshape.equals(newshape)) { return oldstride.vec(); } diff --git a/aten/src/ATen/TensorUtils.h b/aten/src/ATen/TensorUtils.h index 0882eb4cba471..ef9cb0fcbe32e 100644 --- a/aten/src/ATen/TensorUtils.h +++ b/aten/src/ATen/TensorUtils.h @@ -12,7 +12,7 @@ namespace at { // make sense. These are particularly useful for native functions, // which do NO argument checking by default. -struct CAFFE2_API TensorArg { +struct TORCH_API TensorArg { Tensor tensor; const char* name; int pos; // 1-indexed @@ -22,7 +22,7 @@ struct CAFFE2_API TensorArg { const Tensor& operator*() const { return tensor; } }; -struct CAFFE2_API TensorGeometryArg { +struct TORCH_API TensorGeometryArg { TensorGeometry tensor; const char* name; int pos; // 1-indexed @@ -49,104 +49,110 @@ using CheckedFrom = const char*; // not TensorGeometryArg, because the Tensor to TensorGeometry // conversion will blow up if you have undefined tensors. -CAFFE2_API std::ostream& operator<<(std::ostream& out, TensorGeometryArg t); -CAFFE2_API void checkDim( +TORCH_API std::ostream& operator<<(std::ostream& out, TensorGeometryArg t); +TORCH_API void checkDim( + CheckedFrom c, + const Tensor& tensor, + const char* name, + int pos, // 1-indexed + int64_t dim); +TORCH_API void checkDim( CheckedFrom c, const TensorGeometryArg& t, int64_t dim); // NB: this is an inclusive-exclusive range -CAFFE2_API void checkDimRange( +TORCH_API void checkDimRange( CheckedFrom c, const TensorGeometryArg& t, int64_t dim_start, int64_t dim_end); -CAFFE2_API void checkSameDim( +TORCH_API void checkSameDim( CheckedFrom c, const TensorGeometryArg& t1, const TensorGeometryArg& t2); -CAFFE2_API void checkContiguous(CheckedFrom c, const TensorGeometryArg& t); -CAFFE2_API void checkAllContiguous(CheckedFrom c, at::ArrayRef ts); -CAFFE2_API void checkSize( +TORCH_API void checkContiguous(CheckedFrom c, const TensorGeometryArg& t); +TORCH_API void checkAllContiguous(CheckedFrom c, at::ArrayRef ts); +TORCH_API void checkSize( CheckedFrom c, const TensorGeometryArg& t, IntArrayRef sizes); -CAFFE2_API void checkSize( +TORCH_API void checkSize( CheckedFrom c, const TensorGeometryArg& t, int64_t dim, int64_t size); -CAFFE2_API void checkNumel( +TORCH_API void checkNumel( CheckedFrom c, const TensorGeometryArg& t, int64_t numel); -CAFFE2_API void checkSameNumel( +TORCH_API void checkSameNumel( CheckedFrom c, const TensorGeometryArg& t1, const TensorGeometryArg& t2); -CAFFE2_API void checkAllSameNumel(CheckedFrom c, ArrayRef tensors); -CAFFE2_API void checkScalarType( +TORCH_API void checkAllSameNumel(CheckedFrom c, ArrayRef tensors); +TORCH_API void checkScalarType( CheckedFrom c, const TensorArg& t, ScalarType s); -CAFFE2_API void checkScalarTypes( +TORCH_API void checkScalarTypes( CheckedFrom c, const TensorArg& t, at::ArrayRef l); -CAFFE2_API void checkSameGPU( +TORCH_API void checkSameGPU( CheckedFrom c, const TensorArg& t1, const TensorArg& t2); -CAFFE2_API void checkAllSameGPU(CheckedFrom c, ArrayRef tensors); -CAFFE2_API void checkSameType( +TORCH_API void checkAllSameGPU(CheckedFrom c, ArrayRef tensors); +TORCH_API void checkSameType( CheckedFrom c, const TensorArg& t1, const TensorArg& t2); -CAFFE2_API void checkAllSameType(CheckedFrom c, ArrayRef tensors); -CAFFE2_API void checkSameSize( +TORCH_API void checkAllSameType(CheckedFrom c, ArrayRef tensors); +TORCH_API void checkSameSize( CheckedFrom c, const TensorArg& t1, const TensorArg& t2); -CAFFE2_API void checkDefined(CheckedFrom c, const TensorArg& t); -CAFFE2_API void checkAllDefined(CheckedFrom c, at::ArrayRef t); +TORCH_API void checkDefined(CheckedFrom c, const TensorArg& t); +TORCH_API void checkAllDefined(CheckedFrom c, at::ArrayRef t); // FixMe: does TensorArg slow things down? -CAFFE2_API void checkBackend( +TORCH_API void checkBackend( CheckedFrom c, at::ArrayRef t, at::Backend backend); -CAFFE2_API void checkDeviceType( +TORCH_API void checkDeviceType( CheckedFrom c, at::ArrayRef tensors, at::DeviceType device_type); -CAFFE2_API void checkLayout(CheckedFrom c, const Tensor& t, Layout layout); +TORCH_API void checkLayout(CheckedFrom c, const Tensor& t, Layout layout); -CAFFE2_API void checkLayout(CheckedFrom c, at::ArrayRef tensors, at::Layout layout); +TORCH_API void checkLayout(CheckedFrom c, at::ArrayRef tensors, at::Layout layout); // Methods for getting data_ptr if tensor is defined -CAFFE2_API void* maybe_data_ptr(const Tensor& tensor); -CAFFE2_API void* maybe_data_ptr(const TensorArg& tensor); +TORCH_API void* maybe_data_ptr(const Tensor& tensor); +TORCH_API void* maybe_data_ptr(const TensorArg& tensor); // Return if the tensor geometry represented by `sizes` and `strides` is contiguous // Although we cache is_contiguous in tensor now, this is till useful because it // allows checking if a particular geometry is contiguous without explicitly // constructing a tensor, e.g., when you want to choose a kernel strategy based // on whether a subgeometry is contiguous. -CAFFE2_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides); +TORCH_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides); // Correspond to THCUNN_check_dim_size/THNN_check_dim_size -CAFFE2_API void check_dim_size( +TORCH_API void check_dim_size( const Tensor& tensor, int64_t dim, int64_t dim_size, int64_t size); namespace detail { -CAFFE2_API std::vector defaultStrides(IntArrayRef sizes); -CAFFE2_API size_t +TORCH_API std::vector defaultStrides(IntArrayRef sizes); +TORCH_API size_t computeStorageNbytes(IntArrayRef sizes, IntArrayRef strides, size_t itemsize); -CAFFE2_API c10::optional> computeStride( +TORCH_API c10::optional> computeStride( IntArrayRef oldshape, IntArrayRef oldstride, IntArrayRef newshape); diff --git a/aten/src/ATen/ThreadLocalState.cpp b/aten/src/ATen/ThreadLocalState.cpp index 7ed7f66e25227..3c7b9b6ff5bc1 100644 --- a/aten/src/ATen/ThreadLocalState.cpp +++ b/aten/src/ATen/ThreadLocalState.cpp @@ -10,9 +10,8 @@ namespace at { ThreadLocalState::ThreadLocalState(bool keep_grad_mode) : dispatch_key_(c10::impl::tls_local_dispatch_key_set()), - debug_info_(c10::ThreadLocalDebugInfo::current()), - observers_enabled_(at::isRecordFunctionEnabled()) { - callbacks_ = _getTLSCallbacks(); + debug_info_(c10::ThreadLocalDebugInfo::current()) { + rf_tls_ = at::get_record_function_tls_(); #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) keep_grad_mode_ = keep_grad_mode; @@ -20,6 +19,7 @@ ThreadLocalState::ThreadLocalState(bool keep_grad_mode) grad_mode_enabled_ = GradMode::is_enabled(); } #endif + bumped_record_all_functions_ = at::checkRecordAllFunctions(); } /* static */ @@ -31,9 +31,7 @@ void ThreadLocalState::setThreadLocalState( } #endif - _setTLSCallbacks(state.callbacks_); - - at::enableRecordFunction(state.observers_enabled_); + at::set_record_function_tls_(state.rf_tls_); c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_); diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h index 186e521f01bd2..3c9b55b3d8d63 100644 --- a/aten/src/ATen/ThreadLocalState.h +++ b/aten/src/ATen/ThreadLocalState.h @@ -30,16 +30,17 @@ class TORCH_API ThreadLocalState { // with DebugInfoGuard std::shared_ptr debug_info_; - // RecordFunction TLS callbacks - RecordFunctionCallbacks callbacks_; - - bool observers_enabled_ = false; + // RecordFunction TLS + RecordFunctionTLS rf_tls_; #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) bool keep_grad_mode_ = true; bool grad_mode_enabled_; #endif + // Whether pre-sampling RecordFunction optimization was enabled + bool bumped_record_all_functions_ = false; + friend class ThreadLocalStateGuard; }; @@ -47,7 +48,21 @@ class TORCH_API ThreadLocalState { class TORCH_API ThreadLocalStateGuard { public: explicit ThreadLocalStateGuard(const ThreadLocalState& state) - : prev_state_(ThreadLocalState()) { + : prev_state_(ThreadLocalState()), + bumped_record_all_functions_(state.bumped_record_all_functions_) { + // Special handling of RecordFunction pre-sampling optimization: + // pre-samping is enabled (bumped) when there're non-sampled + // (or high-frequency) global or TLS callbacks. + // + // ThreadLocalStateGuard simply resets RecordFunction's TLS and + // hence its thread local callbacks. + // + // Checking if the pre-sampling was enabled and preserving it in the + // async task by calling bumpRecordAllFunctions() and the corresponding + // releaseRecordAllFunctions() + if (bumped_record_all_functions_) { + at::bumpRecordAllFunctions(); + } // set the given state across the thread boundary ThreadLocalState::setThreadLocalState(state); } @@ -55,10 +70,15 @@ class TORCH_API ThreadLocalStateGuard { ~ThreadLocalStateGuard() { // restore previously set variables ThreadLocalState::setThreadLocalState(prev_state_); + if (bumped_record_all_functions_) { + at::releaseRecordAllFunctions(); + } } private: const ThreadLocalState prev_state_; + // Whether pre-sampling RecordFunction optimization was enabled + bool bumped_record_all_functions_ = false; }; template diff --git a/aten/src/ATen/templates/TypeDefault.h b/aten/src/ATen/TypeDefault.h similarity index 86% rename from aten/src/ATen/templates/TypeDefault.h rename to aten/src/ATen/TypeDefault.h index fb62c7ba63542..7b5d77ba4d22c 100644 --- a/aten/src/ATen/templates/TypeDefault.h +++ b/aten/src/ATen/TypeDefault.h @@ -1,7 +1,5 @@ #pragma once -// ${generated_comment} - #include #include #include @@ -29,8 +27,4 @@ struct Quantizer; // to frontend using ConstQuantizerPtr = const c10::intrusive_ptr&; -namespace TypeDefault { - ${type_method_declarations} -} // namespace TypeDefault - } // namespace at diff --git a/aten/src/ATen/Utils.cpp b/aten/src/ATen/Utils.cpp index ccd4e4ba9f2f7..26fc7dabfd73a 100644 --- a/aten/src/ATen/Utils.cpp +++ b/aten/src/ATen/Utils.cpp @@ -1,8 +1,12 @@ #include +#include +#include +#include +#include #include +#include #include #include -#include namespace at { @@ -12,4 +16,113 @@ int _crash_if_asan(int arg) { return x[0]; } -} // at +namespace detail { +// empty_cpu is used in ScalarOps.h, which can be referenced by other ATen +// files. Since we want to decouple direct referencing native symbols and only +// access native symbols through dispatching, we move its implementation here. +Tensor empty_cpu( + IntArrayRef size, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt, + c10::optional memory_format_opt) { + Device device = device_or_default(device_opt); + + TORCH_CHECK(device.type() == DeviceType::CPU); + check_size_nonnegative(size); + + bool pin_memory = pinned_memory_or_default(pin_memory_opt); + c10::Allocator* allocator; + if (pin_memory) { + allocator = detail::getCUDAHooks().getPinnedMemoryAllocator(); + } else { + allocator = at::getCPUAllocator(); + } + + int64_t nelements = prod_intlist(size); + caffe2::TypeMeta dtype = scalarTypeToTypeMeta(dtype_or_default(dtype_opt)); + int64_t size_bytes = nelements * dtype.itemsize(); + auto storage_impl = c10::make_intrusive( + c10::StorageImpl::use_byte_size_t(), + size_bytes, + allocator->allocate(size_bytes), + allocator, + /*resizeable=*/true); + + auto tensor = detail::make_tensor( + std::move(storage_impl), at::DispatchKey::CPU, dtype); + // Default TensorImpl has size [0] + if (size.size() != 1 || size[0] != 0) { + tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); + } + + if (memory_format_opt.has_value()) { + // Restriding a just-created empty contiguous tensor does nothing. + if (*memory_format_opt != MemoryFormat::Contiguous) { + tensor.unsafeGetTensorImpl()->empty_tensor_restride(*memory_format_opt); + } + } + + return tensor; +} + +template +Tensor tensor_cpu(ArrayRef values, const TensorOptions& options) { + auto result = at::empty(values.size(), options); + AT_ASSERT(result.is_contiguous()); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX(result.scalar_type(), "tensor_cpu", [&] { + std::copy( + values.begin(), values.end(), result.template data_ptr()); + }); + return result; +} + +template +Tensor tensor_backend(ArrayRef values, const TensorOptions& options) { + auto cpu_tensor = tensor_cpu(values, options.device(DeviceType::CPU)); + return cpu_tensor.to(options.device()); +} + +template +Tensor tensor_complex_cpu(ArrayRef values, const TensorOptions& options) { + auto result = at::empty(values.size(), options); + AT_ASSERT(result.is_contiguous()); + AT_DISPATCH_COMPLEX_TYPES(result.scalar_type(), "tensor_cpu", [&] { + std::copy( + values.begin(), values.end(), result.template data_ptr()); + }); + return result; +} + +template +Tensor tensor_complex_backend( + ArrayRef values, + const TensorOptions& options) { + auto cpu_tensor = tensor_complex_cpu(values, options.device(DeviceType::CPU)); + return cpu_tensor.to(options.device()); +} +} // namespace detail + +#define TENSOR(T, _1) \ + Tensor tensor(ArrayRef values, const TensorOptions& options) { \ + if (options.device().type() != c10::DeviceType::CPU) { \ + return at::detail::tensor_backend(values, options); \ + } else { \ + return at::detail::tensor_cpu(values, options); \ + } \ + } +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) +#undef TENSOR + +#define TENSOR(T, _1) \ + Tensor tensor(ArrayRef values, const TensorOptions& options) { \ + if (options.device().type() != c10::DeviceType::CPU) { \ + return at::detail::tensor_complex_backend(values, options); \ + } else { \ + return at::detail::tensor_complex_cpu(values, options); \ + } \ + } +AT_FORALL_COMPLEX_TYPES(TENSOR) +#undef TENSOR +} // namespace at diff --git a/aten/src/ATen/Utils.h b/aten/src/ATen/Utils.h index df0e49920afa1..e100bb11f4451 100644 --- a/aten/src/ATen/Utils.h +++ b/aten/src/ATen/Utils.h @@ -22,7 +22,7 @@ namespace at { -CAFFE2_API int _crash_if_asan(int); +TORCH_API int _crash_if_asan(int); // TODO: This unwrapping code is ONLY used for TH bindings; once TH goes // away, we can delete this function @@ -93,10 +93,18 @@ inline int64_t sum_intlist(ArrayRef list) { return std::accumulate(list.begin(), list.end(), 0ll); } -inline int64_t prod_intlist(ArrayRef list) { - return std::accumulate(list.begin(), list.end(), 1ll, std::multiplies()); +//std::accumulate infers return type from `init` type, so if `init` type is not enough to hold the result, computation can overflow +//the next 2 functions set `init` type to int64_t to avoid overflow. +template::value, int>::type = 0> +inline int64_t prod_intlist(const C &container){ + return std::accumulate(container.begin(), container.end(), static_cast(1), std::multiplies()); } +template::value_type>::value, int>::type = 0> +inline int64_t prod_intlist(Iter begin, Iter end){ + return std::accumulate(begin, end, static_cast(1), std::multiplies()); +} /** * Utility function to static cast input Generator* to * the backend generator type (CPU/CUDAGeneratorImpl etc.) @@ -120,4 +128,33 @@ static inline T* get_generator_or_default(const c10::optional& gen, c return gen.has_value() && gen->defined() ? check_generator(gen) : check_generator(default_gen); } +inline void check_size_nonnegative(IntArrayRef size) { + for (auto x: size) { + TORCH_CHECK(x >= 0, "Trying to create tensor with negative dimension ", x, ": ", size); + } +} + +namespace detail { +TORCH_API +Tensor empty_cpu(IntArrayRef size, c10::optional dtype_opt, c10::optional layout_opt, + c10::optional device_opt, c10::optional pin_memory_opt, c10::optional memory_format_opt); + +template +TORCH_API +Tensor tensor_cpu(ArrayRef values, const TensorOptions& options); + +template +TORCH_API +Tensor tensor_backend(ArrayRef values, const TensorOptions& options); + +template +TORCH_API +Tensor tensor_complex_cpu(ArrayRef values, const TensorOptions& options); + +template +TORCH_API +Tensor tensor_complex_backend(ArrayRef values, const TensorOptions& options); +} // namespace detail + + } // at diff --git a/aten/src/ATen/Version.cpp b/aten/src/ATen/Version.cpp index 0fcf38470a785..ecc5070610176 100644 --- a/aten/src/ATen/Version.cpp +++ b/aten/src/ATen/Version.cpp @@ -97,6 +97,14 @@ std::string used_cpu_capability() { ss << "CPU capability usage: "; auto capability = native::get_cpu_capability(); switch (capability) { +#ifdef HAVE_VSX_CPU_DEFINITION + case native::CPUCapability::DEFAULT: + ss << "DEFAULT"; + break; + case native::CPUCapability::VSX: + ss << "VSX"; + break; +#else case native::CPUCapability::DEFAULT: ss << "NO AVX"; break; @@ -106,6 +114,7 @@ std::string used_cpu_capability() { case native::CPUCapability::AVX2: ss << "AVX2"; break; +#endif default: break; } @@ -114,7 +123,7 @@ std::string used_cpu_capability() { std::string show_config() { std::ostringstream ss; - ss << "PyTorch built with:\n"; // TODO add the version of PyTorch + ss << "PyTorch built with:\n"; // Reference: // https://blog.kowalczyk.info/article/j/guide-to-predefined-macros-in-c-compilers-gcc-clang-msvc-etc..html @@ -165,6 +174,10 @@ std::string show_config() { ss << " - NNPACK is enabled\n"; #endif +#ifdef CROSS_COMPILING_MACOSX + ss << " - Cross compiling on MacOSX\n"; +#endif + ss << " - "<< used_cpu_capability() << "\n"; if (hasCUDA()) { @@ -172,7 +185,7 @@ std::string show_config() { } ss << " - Build settings: "; - for (const std::pair& pair : caffe2::GetBuildOptions()) { + for (const auto& pair : caffe2::GetBuildOptions()) { if (!pair.second.empty()) { ss << pair.first << "=" << pair.second << ", "; } @@ -185,4 +198,15 @@ std::string show_config() { return ss.str(); } +std::string get_cxx_flags() { + #if defined(FBCODE_CAFFE2) + TORCH_CHECK( + false, + "Buck does not populate the `CXX_FLAGS` field of Caffe2 build options. " + "As a result, `get_cxx_flags` is OSS only." + ); + #endif + return caffe2::GetBuildOptions().at("CXX_FLAGS"); +} + } diff --git a/aten/src/ATen/Version.h b/aten/src/ATen/Version.h index 18fd31d3ed877..88d010c18da00 100644 --- a/aten/src/ATen/Version.h +++ b/aten/src/ATen/Version.h @@ -3,12 +3,14 @@ namespace at { /// Returns a detailed string describing the configuration PyTorch. -CAFFE2_API std::string show_config(); +TORCH_API std::string show_config(); -CAFFE2_API std::string get_mkl_version(); +TORCH_API std::string get_mkl_version(); -CAFFE2_API std::string get_mkldnn_version(); +TORCH_API std::string get_mkldnn_version(); -CAFFE2_API std::string get_openmp_version(); +TORCH_API std::string get_openmp_version(); + +TORCH_API std::string get_cxx_flags(); } // namespace at diff --git a/aten/src/ATen/VmapMode.h b/aten/src/ATen/VmapMode.h index 8e59aacfa9252..c50f57a8a927d 100644 --- a/aten/src/ATen/VmapMode.h +++ b/aten/src/ATen/VmapMode.h @@ -11,7 +11,7 @@ namespace impl { // // NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet. -struct CAFFE2_API VmapMode { +struct TORCH_API VmapMode { // Returns the vmap level, aka the count of how many nested vmaps we're in. static int64_t current_vmap_level(); diff --git a/aten/src/ATen/VmapModeRegistrations.cpp b/aten/src/ATen/VmapModeRegistrations.cpp index 771706bca6ff7..ab4556c8c4155 100644 --- a/aten/src/ATen/VmapModeRegistrations.cpp +++ b/aten/src/ATen/VmapModeRegistrations.cpp @@ -42,70 +42,70 @@ TORCH_LIBRARY_IMPL(aten, VmapMode, m) { #define TENSOROPTIONS c10::optional, c10::optional, c10::optional, c10::optional // random operations (out-of-place) - m.impl_UNBOXED("bernoulli", unsupportedRandomOp>); - m.impl_UNBOXED("bernoulli.out", unsupportedRandomOp_>); - m.impl_UNBOXED("bernoulli.p", unsupportedRandomOp>); - m.impl_UNBOXED("bernoulli_.Tensor", unsupportedRandomOp_>); - m.impl_UNBOXED("bernoulli_.float", unsupportedRandomOp_>); - - m.impl_UNBOXED("cauchy_", unsupportedRandomOp_>); - m.impl_UNBOXED("exponential_", unsupportedRandomOp_>); - m.impl_UNBOXED("geometric_", unsupportedRandomOp_>); - m.impl_UNBOXED("log_normal_", unsupportedRandomOp_>); - m.impl_UNBOXED("multinomial", unsupportedRandomOp>); - m.impl_UNBOXED("multinomial.out", unsupportedRandomOp_>); - - m.impl_UNBOXED("normal.Tensor_float", unsupportedRandomOp>); - m.impl_UNBOXED("normal.Tensor_float_out", unsupportedRandomOp_>); - m.impl_UNBOXED("normal.float_Tensor_out", unsupportedRandomOp_>); - m.impl_UNBOXED("normal.float_Tensor", unsupportedRandomOp>); - m.impl_UNBOXED("normal.Tensor_Tensor", unsupportedRandomOp>); - m.impl_UNBOXED("normal.Tensor_Tensor_out", unsupportedRandomOp_>); - m.impl_UNBOXED("normal.float_float", unsupportedRandomOp, const TensorOptions&>); - m.impl_UNBOXED("normal.float_float_out", unsupportedRandomOp_>); - m.impl_UNBOXED("normal_", unsupportedRandomOp_>); - - m.impl_UNBOXED("poisson", unsupportedRandomOp>); - - m.impl_UNBOXED("random_.from", unsupportedRandomOp_, optional>); - m.impl_UNBOXED("random_.to", unsupportedRandomOp_>); - m.impl_UNBOXED("random_", unsupportedRandomOp_>); - - m.impl_UNBOXED("rand_like", unsupportedRandomOp>); - m.impl_UNBOXED("randn_like", unsupportedRandomOp>); - - m.impl_UNBOXED("randint_like", unsupportedRandomOp>); - m.impl_UNBOXED("randint_like.low_dtype", unsupportedRandomOp>); + m.impl("bernoulli", unsupportedRandomOp>); + m.impl("bernoulli.out", unsupportedRandomOp_, Tensor&>); + m.impl("bernoulli.p", unsupportedRandomOp>); + m.impl("bernoulli_.Tensor", unsupportedRandomOp_>); + m.impl("bernoulli_.float", unsupportedRandomOp_>); + + m.impl("cauchy_", unsupportedRandomOp_>); + m.impl("exponential_", unsupportedRandomOp_>); + m.impl("geometric_", unsupportedRandomOp_>); + m.impl("log_normal_", unsupportedRandomOp_>); + m.impl("multinomial", unsupportedRandomOp>); + m.impl("multinomial.out", unsupportedRandomOp_, Tensor&>); + + m.impl("normal.Tensor_float", unsupportedRandomOp>); + m.impl("normal.Tensor_float_out", unsupportedRandomOp_, Tensor&>); + m.impl("normal.float_Tensor_out", unsupportedRandomOp_, Tensor&>); + m.impl("normal.float_Tensor", unsupportedRandomOp>); + m.impl("normal.Tensor_Tensor", unsupportedRandomOp>); + m.impl("normal.Tensor_Tensor_out", unsupportedRandomOp_, Tensor&>); + m.impl("normal.float_float", unsupportedRandomOp, TENSOROPTIONS>); + m.impl("normal.float_float_out", unsupportedRandomOp_, Tensor&>); + m.impl("normal_", unsupportedRandomOp_>); + + m.impl("poisson", unsupportedRandomOp>); + + m.impl("random_.from", unsupportedRandomOp_, optional>); + m.impl("random_.to", unsupportedRandomOp_>); + m.impl("random_", unsupportedRandomOp_>); + + m.impl("rand_like", unsupportedRandomOp>); + m.impl("randn_like", unsupportedRandomOp>); + + m.impl("randint_like", unsupportedRandomOp>); + m.impl("randint_like.low_dtype", unsupportedRandomOp>); m.impl("rand", unsupportedRandomOp); - m.impl_UNBOXED("rand.generator", unsupportedRandomOp, const TensorOptions&>); - m.impl_UNBOXED("rand.names", unsupportedRandomOp, const TensorOptions&>); - m.impl_UNBOXED("rand.generator_with_names", unsupportedRandomOp, optional, const TensorOptions&>); - m.impl_UNBOXED("rand.out", unsupportedRandomOp_); - m.impl_UNBOXED("rand.generator_out", unsupportedRandomOp_>); + m.impl("rand.generator", unsupportedRandomOp, TENSOROPTIONS>); + m.impl("rand.names", unsupportedRandomOp, TENSOROPTIONS>); + m.impl("rand.generator_with_names", unsupportedRandomOp, optional, TENSOROPTIONS>); + m.impl("rand.out", unsupportedRandomOp_); + m.impl("rand.generator_out", unsupportedRandomOp_, Tensor&>); m.impl("randn", unsupportedRandomOp); - m.impl_UNBOXED("randn.generator", unsupportedRandomOp, const TensorOptions&>); - m.impl_UNBOXED("randn.names", unsupportedRandomOp, const TensorOptions&>); - m.impl_UNBOXED("randn.generator_with_names", unsupportedRandomOp, optional, const TensorOptions&>); - m.impl_UNBOXED("randn.out", unsupportedRandomOp_); - m.impl_UNBOXED("randn.generator_out", unsupportedRandomOp_>); + m.impl("randn.generator", unsupportedRandomOp, TENSOROPTIONS>); + m.impl("randn.names", unsupportedRandomOp, TENSOROPTIONS>); + m.impl("randn.generator_with_names", unsupportedRandomOp, optional, TENSOROPTIONS>); + m.impl("randn.out", unsupportedRandomOp_); + m.impl("randn.generator_out", unsupportedRandomOp_, Tensor&>); m.impl("randperm", unsupportedRandomOp); - m.impl_UNBOXED("randperm.generator", unsupportedRandomOp, const TensorOptions&>); - m.impl_UNBOXED("randperm.out", unsupportedRandomOp_); - m.impl_UNBOXED("randperm.generator_out", unsupportedRandomOp_>); + m.impl("randperm.generator", unsupportedRandomOp, TENSOROPTIONS>); + m.impl("randperm.out", unsupportedRandomOp_); + m.impl("randperm.generator_out", unsupportedRandomOp_, Tensor&>); m.impl("randint", unsupportedRandomOp); - m.impl_UNBOXED("randint.generator", unsupportedRandomOp, const TensorOptions&>); + m.impl("randint.generator", unsupportedRandomOp, TENSOROPTIONS>); m.impl("randint.low", unsupportedRandomOp); - m.impl_UNBOXED("randint.low_generator", unsupportedRandomOp, const TensorOptions&>); - m.impl_UNBOXED("randint.out", unsupportedRandomOp_); - m.impl_UNBOXED("randint.generator_out", unsupportedRandomOp_>); - m.impl_UNBOXED("randint.low_out", unsupportedRandomOp_); - m.impl_UNBOXED("randint.low_generator_out", unsupportedRandomOp_>); + m.impl("randint.low_generator", unsupportedRandomOp, TENSOROPTIONS>); + m.impl("randint.out", unsupportedRandomOp_); + m.impl("randint.generator_out", unsupportedRandomOp_, Tensor&>); + m.impl("randint.low_out", unsupportedRandomOp_); + m.impl("randint.low_generator_out", unsupportedRandomOp_, Tensor&>); - m.impl_UNBOXED("uniform_", unsupportedRandomOp_>); + m.impl("uniform_", unsupportedRandomOp_>); #undef TENSOROPTIONS } diff --git a/aten/src/ATen/VmapTransforms.cpp b/aten/src/ATen/VmapTransforms.cpp index bd7abf2341efb..f86e4819d743a 100644 --- a/aten/src/ATen/VmapTransforms.cpp +++ b/aten/src/ATen/VmapTransforms.cpp @@ -3,15 +3,6 @@ namespace at { -// Creates a bitset for all of the levels present in `bdims`. -std::bitset createLevelsBitset(BatchDimsRef bdims) { - std::bitset result; - for (const auto& bdim : bdims) { - result.set(bdim.level()); - } - return result; -} - // Checks if the batch dims in `bdims` appear at the front of the tensor. static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) { for (int64_t idx = 0; idx < bdims.size(); idx++) { @@ -52,7 +43,7 @@ VmapPhysicalView MultiBatchVmapTransform::logicalToPhysical(const Tensor& logica TORCH_INTERNAL_ASSERT( batched, "logicalToPhysical(tensor) should only be passed a BatchedTensor"); - return { permuteBatchDimsToFront(batched), createLevelsBitset(batched->bdims()) }; + return { permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims()) }; } int64_t VmapPhysicalView::numBatchDims() const { @@ -100,16 +91,6 @@ static BatchDims computeFrontBatchDimsFromLevels(std::bitset lev return bdims; } -Tensor VmapPhysicalView::newLogicalFromPhysical(const Tensor& physical) const { - return makeBatched(physical, computeFrontBatchDimsFromLevels(levels_)); -} - -void VmapPhysicalView::makeLogicalFromPhysicalListInplace(std::vector& physical_tensors) const { - for (int64_t idx = 0; idx < physical_tensors.size(); ++idx) { - physical_tensors[idx] = newLogicalFromPhysical(physical_tensors[idx]); - } -} - // Given a Tensor or a BatchedTensor, returns the underlying physical tensor // with all vmapped dimensions permuted to the front, if they exist, and a // bitset of vmap levels that were present in the tensor. @@ -117,7 +98,7 @@ static std::pair> getPhysicalTensorAndLevels(const Tensor& self) { auto* batched = maybeGetBatchedImpl(self); if (batched) { - return {permuteBatchDimsToFront(batched), createLevelsBitset(batched->bdims())}; + return {permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims())}; } return {self, 0}; } @@ -202,7 +183,7 @@ MultiBatchVmapTransform::logicalToPhysical(TensorList logical_tensors) { for (const auto& logical_tensor : logical_tensors) { auto* batched = maybeGetBatchedImpl(logical_tensor); if (batched) { - collective_levels |= createLevelsBitset(batched->bdims()); + collective_levels |= createVmapLevelsBitset(batched->bdims()); } } @@ -252,7 +233,7 @@ getLevelsAndLargestLogicalDim(TensorList logical_tensors) { for (const auto& tensor : logical_tensors) { auto* batched = maybeGetBatchedImpl(tensor); if (batched) { - levels = levels | createLevelsBitset(batched->bdims()); + levels = levels | createVmapLevelsBitset(batched->bdims()); } auto tensor_logical_dim = /*logical dim*/tensor.dim(); if (tensor_logical_dim > largest_logical_dim) { @@ -290,4 +271,18 @@ VmapPhysicalViewVec BroadcastingVmapTransform::logicalToPhysical(TensorList logi return result; } +VmapPhysicalToLogicalMap VmapPhysicalView::getPhysicalToLogicalMap() const { + return VmapPhysicalToLogicalMap(levels_); +} + +Tensor VmapPhysicalToLogicalMap::apply(const Tensor& physical_tensor) const { + return makeBatched(physical_tensor, computeFrontBatchDimsFromLevels(levels_)); +} + +void VmapPhysicalToLogicalMap::applyInplace(std::vector& physical_tensors) const { + for (int64_t idx = 0; idx < physical_tensors.size(); ++idx) { + physical_tensors[idx] = apply(physical_tensors[idx]); + } +} + } // namespace at diff --git a/aten/src/ATen/VmapTransforms.h b/aten/src/ATen/VmapTransforms.h index b544618131ad4..8fa0852454593 100644 --- a/aten/src/ATen/VmapTransforms.h +++ b/aten/src/ATen/VmapTransforms.h @@ -79,6 +79,10 @@ struct TORCH_API BroadcastingVmapTransform { static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors); }; +// Forward declared, if you're reading this file head to toe, don't worry about +// it yet. +struct VmapPhysicalToLogicalMap; + // NOTE: [What is a VmapPhysicalView?] // VmapPhysicalView represents a physical view on a Tensor. // @@ -92,8 +96,17 @@ struct TORCH_API BroadcastingVmapTransform { // The levels bitset specifies which vmap levels correspond to the batch // dimensions at the front of the tensor. In particular, the number of set bits // corresponds to the number of batch dimensions on `tensor` and the rightmost -// bit of `levels` specifies the minimum number of nested vmaps we are in at +// bit of `levels` specifies the maximum number of nested vmaps we are in at // this point in time. +// For example, given: +// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3}) +// +// Rightmost bit of `levels` is 3 indicating the number of nested vmaps less +// than or equal to 3. +// bitset: 010100 +// ^ +// | +// levels: 012345 struct TORCH_API VmapPhysicalView { VmapPhysicalView(Tensor&& tensor, std::bitset levels) : levels_(levels), tensor_(tensor) { @@ -115,24 +128,14 @@ struct TORCH_API VmapPhysicalView { VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const; int64_t getPhysicalDim(int64_t logical_dim) const; + // Returns a VmapPhysicalToLogicalMap object. This can be used for + // mapping a physical tensor to a new logical tensor (BatchedTensor) + VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const; + // Maps a logical shape to a physical shape by pre-pending the batch // sizes to the logical shape. VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const; - // Maps a physical tensor to a new logical tensor (BatchedTensor), - // using the mapping info stored in this VmapPhysicalView. - // Assumes that all of the "batch dimensions" are at the front - // of the physical tensor. - Tensor newLogicalFromPhysical(const Tensor& physical) const; - - // Given a vector of physical tensors, - // 1. maps each tensor to a new logical tensor using the mapping info stored - // in this VmapPhysicalView. Assumes that all of the "batch dimensions" - // are at the front of the physical tensors. - // 2. stores the new logical tensors back into the passed-in vector. This is - // to avoid additional dynamic allocations. - void makeLogicalFromPhysicalListInplace(std::vector& physical_tensors) const; - int64_t numBatchDims() const; private: @@ -142,5 +145,31 @@ struct TORCH_API VmapPhysicalView { Tensor tensor_; }; +// Convenience struct used for mapping a physical tensor (a non-BatchedTensor) +// to a logical one (BatchedTensor). It holds some levels that are used to do the +// mapping and assumes that the batch dimensions in the physical tensor all +// occur at the front of the tensor. +struct TORCH_API VmapPhysicalToLogicalMap { + VmapPhysicalToLogicalMap(std::bitset levels): levels_(levels) {} + + // Maps a physical tensor to a new logical tensor (BatchedTensor). + // Assumes that all of the "batch dimensions" are at the front + // of the physical tensor. For example, given: + // - x = rank-4 Tensor with size 2, 3, 5, 7 + // - levels = (2, 4) + // Returns: + // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)]) + Tensor apply(const Tensor& physical_tensor) const; + + // Given a vector of physical tensors, + // 1. maps each tensor to a new logical tensor. Assumes that all of the + // "batch dimensions" are at the front of the physical tensors. + // 2. stores the new logical tensors back into the passed-in vector. This is + // to avoid additional dynamic allocations. + void applyInplace(std::vector& physical_tensors) const; + + std::bitset levels_; +}; + } // namespace at diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index cb1ea44d2e7de..9a2f34257c57b 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -68,7 +68,7 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg) { if (is_eligible(arg) && (arg.scalar_type() != to_type)) { // Heuristic: Do what Apex does, and cache fp16 casts of fp32 model weights (leaves). // See cached_casts declaration above for detailed strategy. - bool can_try_cache = (to_type == at::kHalf && arg.scalar_type() == at::kFloat && arg.requires_grad() && arg.is_leaf()); + bool can_try_cache = (to_type == at::kHalf && arg.scalar_type() == at::kFloat && arg.requires_grad() && arg.is_leaf() && !arg.is_view()); if (can_try_cache) { auto it = cached_casts.find(arg.unsafeGetTensorImpl()); if (it != cached_casts.end()) { @@ -239,13 +239,9 @@ Therefore, for the moment, this is all copy pasted in from VariableTypeEverythin m.impl(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \ &WrapFunction::type::call); -#define KERNEL_UNBOXED_ONLY(FUNC, REGISTER_NAME, SIGNATURE, POLICY) \ - m.impl_UNBOXED(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \ - &WrapFunction::type::call); - // Less-common but still useful case: redispatching to a function with a new signature (e.g. appending a dtype) -#define KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, POLICY) \ - m.impl_UNBOXED(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \ +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, POLICY) \ + m.impl(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \ &WrapFunction::type::call); /***************************************** @@ -341,8 +337,8 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { // The macro doesn't like this one (I think it chokes on commas inside <>) so write it manually m.impl(TORCH_SELECTIVE_NAME("aten::native_layer_norm"), TORCH_FN((&WrapFunction (const Tensor &, const c10::optional&, const c10::optional&, int64_t, int64_t, double), - std::tuple (const Tensor &, const c10::optional&, const c10::optional&, int64_t, int64_t, double), + std::tuple (const Tensor&, IntArrayRef, const c10::optional&, const c10::optional&, double), + std::tuple (const Tensor&, IntArrayRef, const c10::optional&, const c10::optional&, double), &ADD_NS(native_layer_norm)>::type::call))); KERNEL(ADD_NS(group_norm), "group_norm", Tensor (const Tensor &, int64_t, const c10::optional&, const c10::optional&, double, bool), fp32) KERNEL(ADD_NS(frobenius_norm), "frobenius_norm", Tensor (const Tensor &), fp32) @@ -357,7 +353,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { KERNEL(ADD_NS(hinge_embedding_loss), "hinge_embedding_loss", Tensor (const Tensor &, const Tensor &, double, int64_t), fp32) KERNEL(ADD_NS(kl_div), "kl_div", Tensor (const Tensor &, const Tensor &, int64_t, bool), fp32) KERNEL(ADD_NS(l1_loss), "l1_loss", Tensor (const Tensor &, const Tensor &, int64_t), fp32) - KERNEL(ADD_NS(smooth_l1_loss), "smooth_l1_loss", Tensor (const Tensor &, const Tensor &, int64_t), fp32) + KERNEL(ADD_NS(smooth_l1_loss), "smooth_l1_loss", Tensor (const Tensor &, const Tensor &, int64_t, double), fp32) KERNEL(ADD_NS(mse_loss), "mse_loss", Tensor (const Tensor &, const Tensor &, int64_t), fp32) KERNEL(ADD_NS(margin_ranking_loss), "margin_ranking_loss", Tensor (const Tensor &, const Tensor &, const Tensor &, double, int64_t), fp32) KERNEL(ADD_NS(multilabel_margin_loss), "multilabel_margin_loss", Tensor (const Tensor &, const Tensor &, int64_t), fp32) @@ -367,20 +363,20 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { KERNEL(ADD_NS(binary_cross_entropy_with_logits), "binary_cross_entropy_with_logits", Tensor (const Tensor &, const Tensor &, const c10::optional&, const c10::optional&, int64_t), fp32) KERNEL(ADD_NS(dist), "dist", Tensor (const Tensor &, const Tensor &, Scalar), fp32) KERNEL(ADD_NS(pdist), "pdist", Tensor (const Tensor &, double), fp32) - KERNEL_UNBOXED_ONLY(ADD_NS(cdist), "cdist", Tensor (const Tensor &, const Tensor &, double, c10::optional), fp32) + KERNEL(ADD_NS(cdist), "cdist", Tensor (const Tensor &, const Tensor &, double, c10::optional), fp32) KERNEL(ADD_NS(renorm), "renorm", Tensor (const Tensor &, Scalar, int64_t, Scalar), fp32) // fp32_set_opt_dtype KERNEL(ADD_NS(prod), "prod", Tensor (const Tensor &, c10::optional), fp32_set_opt_dtype) KERNEL(ADD_NS(prod), "prod.dim_int", Tensor (const Tensor &, int64_t, bool, c10::optional), fp32_set_opt_dtype) - KERNEL_UNBOXED_ONLY(ADD_NS(prod), "prod.dim_Dimname", Tensor (const Tensor &, Dimname, bool, c10::optional), fp32_set_opt_dtype) + KERNEL(ADD_NS(prod), "prod.dim_Dimname", Tensor (const Tensor &, Dimname, bool, c10::optional), fp32_set_opt_dtype) KERNEL(ADD_NS(softmax), "softmax.int", Tensor (const Tensor &, int64_t, c10::optional), fp32_set_opt_dtype) - KERNEL_UNBOXED_ONLY(ADD_NS(softmax), "softmax.Dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype) + KERNEL(ADD_NS(softmax), "softmax.Dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype) KERNEL(ADD_NS(log_softmax), "log_softmax.int", Tensor (const Tensor &, int64_t, c10::optional), fp32_set_opt_dtype) - KERNEL_UNBOXED_ONLY(ADD_NS(log_softmax), "log_softmax.Dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype) + KERNEL(ADD_NS(log_softmax), "log_softmax.Dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype) KERNEL(ADD_NS(cumprod), "cumprod", Tensor (const Tensor &, int64_t, c10::optional), fp32_set_opt_dtype) - KERNEL_UNBOXED_ONLY(ADD_NS(cumprod), "cumprod.dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype) + KERNEL(ADD_NS(cumprod), "cumprod.dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype) KERNEL(ADD_NS(cumsum), "cumsum", Tensor (const Tensor &, int64_t, c10::optional), fp32_set_opt_dtype) - KERNEL_UNBOXED_ONLY(ADD_NS(cumsum), "cumsum.dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype) + KERNEL(ADD_NS(cumsum), "cumsum.dimname", Tensor (const Tensor &, Dimname, c10::optional), fp32_set_opt_dtype) // commenting these out because they accept an explicit (not-optional) dtype, and we shouldn't try to flip that even // when autocasting. // KERNEL(ADD_NS(norm), "norm.ScalarOpt_dtype", Tensor (const Tensor &, c10::optional, ScalarType), fp32_set_opt_dtype) @@ -388,25 +384,25 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { // KERNEL(ADD_NS(norm), "norm.names_ScalarOpt_dim_dtype", Tensor (const Tensor &, c10::optional, DimnameList, bool, ScalarType), fp32_set_opt_dtype) KERNEL(ADD_NS(sum), "sum", Tensor (const Tensor &, c10::optional), fp32_set_opt_dtype) KERNEL(ADD_NS(sum), "sum.dim_IntList", Tensor (const Tensor &, IntArrayRef, bool, c10::optional), fp32_set_opt_dtype) - KERNEL_UNBOXED_ONLY(ADD_NS(sum), "sum.dim_DimnameList", Tensor (const Tensor &, DimnameList, bool, c10::optional), fp32_set_opt_dtype) + KERNEL(ADD_NS(sum), "sum.dim_DimnameList", Tensor (const Tensor &, DimnameList, bool, c10::optional), fp32_set_opt_dtype) // fp32_append_dtype // The fp32_append_dtype wrapper overrides implicit promotion behavior. // norm does not implicitly promote, but be aware when adding new ops to this policy. - KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.Scalar", Tensor (const Tensor &, Scalar), Tensor (const Tensor &, c10::optional, ScalarType), fp32_append_dtype) - KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, c10::optional, IntArrayRef, bool), Tensor (const Tensor &, c10::optional, IntArrayRef, bool, ScalarType), fp32_append_dtype) - KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.names_ScalarOpt_dim", Tensor (const Tensor &, c10::optional, DimnameList, bool), Tensor (const Tensor &, c10::optional, DimnameList, bool, ScalarType), fp32_append_dtype) + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.Scalar", Tensor (const Tensor &, Scalar), Tensor (const Tensor &, c10::optional, ScalarType), fp32_append_dtype) + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, c10::optional, IntArrayRef, bool), Tensor (const Tensor &, c10::optional, IntArrayRef, bool, ScalarType), fp32_append_dtype) + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.names_ScalarOpt_dim", Tensor (const Tensor &, c10::optional, DimnameList, bool), Tensor (const Tensor &, c10::optional, DimnameList, bool, ScalarType), fp32_append_dtype) // promote KERNEL(ADD_NS(addcdiv), "addcdiv", Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar), promote) KERNEL(ADD_NS(addcmul), "addcmul", Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar), promote) KERNEL(ADD_NS(atan2), "atan2", Tensor (const Tensor &, const Tensor &), promote) KERNEL(ADD_NS(bilinear), "bilinear", Tensor (const Tensor &, const Tensor &, const Tensor &, const c10::optional&), promote) KERNEL(ADD_NS(cat), "cat", Tensor (TensorList, int64_t), promote) - KERNEL_UNBOXED_ONLY(ADD_NS(cat), "cat.names", Tensor (TensorList, Dimname), promote) + KERNEL(ADD_NS(cat), "cat.names", Tensor (TensorList, Dimname), promote) KERNEL(ADD_NS(_cat), "_cat", Tensor (TensorList, int64_t), promote) KERNEL(ADD_NS(cross), "cross", Tensor (const Tensor &, const Tensor &, c10::optional), promote) KERNEL(ADD_NS(dot), "dot", Tensor (const Tensor &, const Tensor &), promote) KERNEL(ADD_NS(equal), "equal", bool (const Tensor &, const Tensor &), promote) - KERNEL_UNBOXED_ONLY(ADD_NS(index_put), "index_put", Tensor (const Tensor &, TensorList, const Tensor &, bool), promote) + KERNEL(ADD_NS(index_put), "index_put", Tensor (const Tensor &, const torch::List>&, const Tensor &, bool), promote) KERNEL(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote) KERNEL(ADD_NS(tensordot), "tensordot", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef), promote) diff --git a/aten/src/ATen/benchmarks/quantize_per_channel.cpp b/aten/src/ATen/benchmarks/quantize_per_channel.cpp new file mode 100644 index 0000000000000..b9a3565937067 --- /dev/null +++ b/aten/src/ATen/benchmarks/quantize_per_channel.cpp @@ -0,0 +1,85 @@ +#include +#include + +#include + +static void quantize_per_channel_4d_contiguous(benchmark::State& state) { + const size_t batches = static_cast(state.range(0)); + const size_t channels = static_cast(state.range(1)); + const size_t height = static_cast(state.range(2)); + const size_t width = static_cast(state.range(3)); + + at::Tensor a = at::rand({batches, channels, height, width}); + at::Tensor scales = at::rand({channels}); + at::Tensor zero_points = at::randint( + 0, 10, {channels}, at::TensorOptions().dtype(at::ScalarType::Int)); + + at::Tensor qa; + for (auto _ : state) { + qa = at::native::quantize_per_channel_cpu( + a, scales, zero_points, 1, at::ScalarType::QUInt8); + } +} + +static void quantize_per_channel_4d_channels_last(benchmark::State& state) { + const size_t batches = static_cast(state.range(0)); + const size_t channels = static_cast(state.range(1)); + const size_t height = static_cast(state.range(2)); + const size_t width = static_cast(state.range(3)); + + at::Tensor a = at::rand( + {batches, channels, height, width}, + at::TensorOptions().memory_format(at::MemoryFormat::ChannelsLast)); + at::Tensor scales = at::rand({channels}); + at::Tensor zero_points = at::randint( + 0, 10, {channels}, at::TensorOptions().dtype(at::ScalarType::Int)); + + at::Tensor qa; + for (auto _ : state) { + qa = at::native::quantize_per_channel_cpu( + a, scales, zero_points, 1, at::ScalarType::QUInt8); + } +} + +static void quantize_per_channel_2d(benchmark::State& state) { + const size_t channels = static_cast(state.range(0)); + const size_t nelem = static_cast(state.range(1)); + + at::Tensor a = at::rand({channels, nelem}); + at::Tensor scales = at::rand({channels}); + at::Tensor zero_points = at::randint( + 0, 10, {channels}, at::TensorOptions().dtype(at::ScalarType::Int)); + + at::Tensor qa; + for (auto _ : state) { + qa = at::native::quantize_per_channel_cpu( + a, scales, zero_points, 0, at::ScalarType::QUInt8); + } +} + +static void GenerateSizes4d(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "C", "H", "W"}); + + for (size_t n = 16; n < 256; n *= 2) { + for (size_t c = 4; c < 256; c *= 2) { + for (size_t hw = 4; hw < 256; hw *= 2) { + b->Args({n, c, hw, hw}); + } + } + } +} + +static void GenerateSizes2d(benchmark::internal::Benchmark* b) { + b->ArgNames({"C", "N"}); + + for (size_t c = 4; c < 512; c *= 2) { + for (size_t n = 4; n < 512; n *= 2) { + b->Args({c, n}); + } + } +} + +BENCHMARK(quantize_per_channel_2d)->Apply(GenerateSizes2d); +BENCHMARK(quantize_per_channel_4d_contiguous)->Apply(GenerateSizes4d); +BENCHMARK(quantize_per_channel_4d_channels_last)->Apply(GenerateSizes4d); +BENCHMARK_MAIN(); diff --git a/aten/src/ATen/core/ATenOpList.h b/aten/src/ATen/core/ATenOpList.h index 880a690fd11c8..1419376a9017d 100644 --- a/aten/src/ATen/core/ATenOpList.h +++ b/aten/src/ATen/core/ATenOpList.h @@ -9,5 +9,5 @@ struct OperatorName; namespace at { // check if an op is a custom op (i.e. did not come from native_functions.yaml) -CAFFE2_API bool is_custom_op(const c10::OperatorName& opName); +TORCH_API bool is_custom_op(const c10::OperatorName& opName); } diff --git a/aten/src/ATen/core/DeprecatedTypeProperties.h b/aten/src/ATen/core/DeprecatedTypeProperties.h index 719cd9a186382..0c3044470cca2 100644 --- a/aten/src/ATen/core/DeprecatedTypeProperties.h +++ b/aten/src/ATen/core/DeprecatedTypeProperties.h @@ -17,7 +17,7 @@ class Tensor; // serves as a replacement return value for Tensor::type(). Previously, // Tensor::type() returned Type&, but we are changing Type to not be // dtype-specific. -class CAFFE2_API DeprecatedTypeProperties { +class TORCH_API DeprecatedTypeProperties { public: DeprecatedTypeProperties(Backend backend, ScalarType scalar_type) : backend_(backend), scalar_type_(scalar_type) {} diff --git a/aten/src/ATen/core/DeprecatedTypePropertiesRegistry.h b/aten/src/ATen/core/DeprecatedTypePropertiesRegistry.h index d9b29a35b3847..a21f1abbe97f4 100644 --- a/aten/src/ATen/core/DeprecatedTypePropertiesRegistry.h +++ b/aten/src/ATen/core/DeprecatedTypePropertiesRegistry.h @@ -10,11 +10,11 @@ namespace at { class DeprecatedTypeProperties; -struct CAFFE2_API DeprecatedTypePropertiesDeleter { +struct TORCH_API DeprecatedTypePropertiesDeleter { void operator()(DeprecatedTypeProperties * ptr); }; -class CAFFE2_API DeprecatedTypePropertiesRegistry { +class TORCH_API DeprecatedTypePropertiesRegistry { public: DeprecatedTypePropertiesRegistry(); @@ -26,6 +26,6 @@ class CAFFE2_API DeprecatedTypePropertiesRegistry { [static_cast(ScalarType::NumOptions)]; }; -CAFFE2_API DeprecatedTypePropertiesRegistry& globalDeprecatedTypePropertiesRegistry(); +TORCH_API DeprecatedTypePropertiesRegistry& globalDeprecatedTypePropertiesRegistry(); } // namespace at diff --git a/aten/src/ATen/core/Dict_inl.h b/aten/src/ATen/core/Dict_inl.h index f5c997a55b747..6cfbea286f00c 100644 --- a/aten/src/ATen/core/Dict_inl.h +++ b/aten/src/ATen/core/Dict_inl.h @@ -38,7 +38,7 @@ namespace detail { inline size_t DictKeyHash::operator()(const IValue& ivalue) const { if (ivalue.isInt()) { - return std::hash()(ivalue.toInt()); + return std::hash()(ivalue.toInt()); } else if (ivalue.isString()) { return std::hash()(ivalue.toStringRef()); } else if (ivalue.isDouble()) { diff --git a/aten/src/ATen/core/Dimname.h b/aten/src/ATen/core/Dimname.h index d81cdfef34e7b..c68ee86733863 100644 --- a/aten/src/ATen/core/Dimname.h +++ b/aten/src/ATen/core/Dimname.h @@ -9,7 +9,7 @@ namespace at { enum class NameType: uint8_t { BASIC, WILDCARD }; -struct CAFFE2_API Dimname { +struct TORCH_API Dimname { static Dimname fromSymbol(Symbol name); static Dimname wildcard(); static bool isValidName(const std::string& name); @@ -21,7 +21,7 @@ struct CAFFE2_API Dimname { bool isWildcard() const { return type_ == NameType::WILDCARD; } bool matches(Dimname other) const; - optional unify(Dimname other) const; + c10::optional unify(Dimname other) const; private: Dimname(Symbol name) @@ -35,7 +35,7 @@ struct CAFFE2_API Dimname { using DimnameList = c10::ArrayRef; -CAFFE2_API std::ostream& operator<<(std::ostream& out, const Dimname& dimname); +TORCH_API std::ostream& operator<<(std::ostream& out, const Dimname& dimname); inline bool operator==(const Dimname& lhs, const Dimname& rhs) { return lhs.symbol() == rhs.symbol(); diff --git a/aten/src/ATen/core/DistributionsHelper.h b/aten/src/ATen/core/DistributionsHelper.h index 071af83aa4bde..d767a1ca52341 100644 --- a/aten/src/ATen/core/DistributionsHelper.h +++ b/aten/src/ATen/core/DistributionsHelper.h @@ -197,7 +197,7 @@ template struct normal_distribution { C10_HOST_DEVICE inline normal_distribution(T mean_in, T stdv_in) { - TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in > 0); + TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in >= 0, "stdv_in must be positive: ", stdv_in); mean = mean_in; stdv = stdv_in; } diff --git a/aten/src/ATen/core/Formatting.cpp b/aten/src/ATen/core/Formatting.cpp index 219dc857f2a11..eb124dab6874e 100644 --- a/aten/src/ATen/core/Formatting.cpp +++ b/aten/src/ATen/core/Formatting.cpp @@ -292,6 +292,11 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi stream << ", axis: " << tensor_.q_per_channel_axis(); } } + + auto& fw_grad = tensor.fw_grad(/* level */ 0); + if (fw_grad.defined()) { + stream << ", tangent:" << std::endl << fw_grad; + } stream << " ]"; } return stream; diff --git a/aten/src/ATen/core/Formatting.h b/aten/src/ATen/core/Formatting.h index 63c5e12e9630d..4a5545ab197a8 100644 --- a/aten/src/ATen/core/Formatting.h +++ b/aten/src/ATen/core/Formatting.h @@ -6,12 +6,12 @@ namespace c10 { -CAFFE2_API std::ostream& operator<<(std::ostream& out, Backend b); +TORCH_API std::ostream& operator<<(std::ostream& out, Backend b); } namespace at { -CAFFE2_API std::ostream& operator<<(std::ostream& out, const DeprecatedTypeProperties& t); -CAFFE2_API std::ostream& print( +TORCH_API std::ostream& operator<<(std::ostream& out, const DeprecatedTypeProperties& t); +TORCH_API std::ostream& print( std::ostream& stream, const Tensor& tensor, int64_t linesize); diff --git a/aten/src/ATen/core/Generator.cpp b/aten/src/ATen/core/Generator.cpp new file mode 100644 index 0000000000000..800f8c7c88ec6 --- /dev/null +++ b/aten/src/ATen/core/Generator.cpp @@ -0,0 +1,16 @@ +#include +#include +#include + +namespace at { + +void Generator::set_state(const at::Tensor& new_state) { + TORCH_CHECK(new_state.defined(), "Undefined tensor is not allowed"); + this->impl_->set_state(*new_state.unsafeGetTensorImpl()); +} + +at::Tensor Generator::get_state() const { + return at::Tensor::wrap_tensor_impl(this->impl_->get_state()); +} + +} // namespace at diff --git a/aten/src/ATen/core/Generator.h b/aten/src/ATen/core/Generator.h index 1228b42dfdd1d..b5bbb2fe3c747 100644 --- a/aten/src/ATen/core/Generator.h +++ b/aten/src/ATen/core/Generator.h @@ -13,6 +13,12 @@ #include #include #include + +// For the record I don't think this is a correct pimpl idiom. +// Including Impl header in interface header defeats the purpose +// because you can't change Impl private members without forcing +// everything that included the interface to rebuild. +// Impl should be forward-declared in the interface header instead. #include /** @@ -31,7 +37,6 @@ * * By default, there is one generator per device, and a device's generator is * lazily created. A user can use the torch.Generator() api to create their own generator. - * Currently torch.Generator() can only create a CPUGeneratorImpl. */ /** @@ -43,7 +48,7 @@ * Please use the public mutex_ when using any methods from these classes, except for the * read-only methods. You can learn about the usage by looking into the unittests * (aten/src/ATen/cpu_generator_test.cpp) and other places where we have used lock_guard. - * + * * TODO: Look into changing the threading semantics of Generators in ATen (e.g., making * them non-thread safe and instead making the generator state splittable, to accommodate * forks into other threads). @@ -51,7 +56,9 @@ namespace at { -struct CAFFE2_API Generator { +class Tensor; + +struct TORCH_API Generator { Generator() {} explicit Generator(c10::intrusive_ptr gen_impl) @@ -91,6 +98,12 @@ struct CAFFE2_API Generator { uint64_t seed() { return impl_->seed(); } + // Implementation not inlined to prevent cycle reference between + // `ATen/core/Generator.h` and `ATen/core/Tensor.h` + void set_state(const at::Tensor& new_state); + + at::Tensor get_state() const; + std::mutex& mutex() { return impl_->mutex_; } @@ -125,5 +138,24 @@ Generator make_generator(Args&&... args) { return Generator(c10::make_intrusive(std::forward(args)...)); } -} // namespace at +namespace detail { +/** + * Helper function for checking the validity of new random generator + * state. Right now following conditions are checked: + * + * - The new state tensor must be a torch.ByteTensor + * - Data of the new state tensor must be contiguous + */ +static inline void check_rng_state(const c10::TensorImpl& new_state) { + TORCH_CHECK_TYPE( + new_state.layout() == kStrided && new_state.device().type() == kCPU && new_state.dtype() == kByte, + "RNG state must be a torch.ByteTensor" + ); + + TORCH_CHECK(new_state.is_contiguous(), "RNG state must be contiguous"); +} + +} // namespace detail + +} // namespace at diff --git a/aten/src/ATen/core/LegacyTypeDispatch.h b/aten/src/ATen/core/LegacyTypeDispatch.h index 925a87a7c933c..85f771a1cbbb2 100644 --- a/aten/src/ATen/core/LegacyTypeDispatch.h +++ b/aten/src/ATen/core/LegacyTypeDispatch.h @@ -43,7 +43,7 @@ namespace at { // trace). To unify the two, we would first have to move profiling and tracing // out of VariableType. -struct CAFFE2_API AutoNonVariableTypeMode { +struct TORCH_API AutoNonVariableTypeMode { // NB: The enabled parameter must ALWAYS be black, as Henry Ford used to say. // TODO: Eliminate this parameter entirely AutoNonVariableTypeMode(bool enabled = true) : diff --git a/aten/src/ATen/core/List.h b/aten/src/ATen/core/List.h index 40f733784fe58..c76a1e0173be8 100644 --- a/aten/src/ATen/core/List.h +++ b/aten/src/ATen/core/List.h @@ -53,6 +53,28 @@ bool operator==(const ListElementReference& lhs, const T& rhs); template bool operator==(const T& lhs, const ListElementReference& rhs); +template +struct ListElementConstReferenceTraits { + // In the general case, we cannot expose a true const reference to + // the contents of an IValue, so we copy. + using const_reference = T; +}; + +template<> +struct ListElementConstReferenceTraits { + using const_reference = const std::string&; +}; + +template<> +struct ListElementConstReferenceTraits> { + using const_reference = c10::optional>; +}; + +template<> +struct ListElementConstReferenceTraits { + using const_reference = const at::Tensor&; +}; + template class ListElementReference final { public: @@ -65,17 +87,6 @@ class ListElementReference final { // assigning another ref to this assigns the underlying value ListElementReference& operator=(ListElementReference&& rhs) &&; - // returns the underlying std::string by reference (only enabled if this is from a torch::List). - template - std::enable_if_t::value && std::is_same<_T, T>::value, const std::string&> toStringRef() { - return iterator_->toStringRef(); - } - - template - std::enable_if_t, T>::value && std::is_same<_T, T>::value, c10::optional>> toOptionalStringRef() { - return iterator_->toOptionalStringRef(); - } - friend void swap(ListElementReference&& lhs, ListElementReference&& rhs); private: @@ -226,6 +237,7 @@ class List final { c10::intrusive_ptr impl_; using internal_reference_type = impl::ListElementReference; + using internal_const_reference_type = typename impl::ListElementConstReferenceTraits::const_reference; public: using value_type = T; @@ -243,7 +255,7 @@ class List final { * Example: * List a({2, 3, 4}); */ - explicit List(std::initializer_list initial_values); + List(std::initializer_list initial_values); explicit List(ArrayRef initial_values); /** @@ -289,7 +301,9 @@ class List final { * list[2] = 5; * int64_t v = list[1]; */ - internal_reference_type operator[](size_type pos) const; + internal_const_reference_type operator[](size_type pos) const; + + internal_reference_type operator[](size_type pos); /** * Assigns a new value to the element at location pos. diff --git a/aten/src/ATen/core/List_inl.h b/aten/src/ATen/core/List_inl.h index 3cbd7a310275a..76c445a7451ee 100644 --- a/aten/src/ATen/core/List_inl.h +++ b/aten/src/ATen/core/List_inl.h @@ -1,7 +1,7 @@ #pragma once +#include #include -#include namespace c10 { @@ -50,7 +50,17 @@ List::List(TypePtr elementType) namespace impl { template List toTypedList(impl::GenericList list) { - TORCH_INTERNAL_ASSERT(*getTypePtr() == *list.impl_->elementType, "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr()), ">. Types mismatch."); + // If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant + // because upcasting would allow people to add types into the new list that would break the old list. + // However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can + // allow upcasting. This can be a perf improvement since we can cast List to List> + // without having to copy it. This is also used to provide backwards compatibility with some old models + // that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_ + // as List before we changed that argument to be List>. When deserializing, we + // have list.use_count() == 1 and can deserialize the List directly as List>. + TORCH_CHECK(*list.impl_->elementType == *getTypePtr() + || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(getTypePtr())) + , "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr()), ">. Types mismatch."); return List(std::move(list.impl_)); } @@ -91,13 +101,23 @@ namespace detail { return std::move(element).template to(); } template - IValue list_element_from(const T& element) { - return element; - } - template - IValue list_element_from(T&& element) { - return std::move(element); - } + struct ListElementFrom { + static IValue from(const T& element) { + return element; + } + static IValue from(T&& element) { + return std::move(element); + } + }; + template<> + struct ListElementFrom { + static const IValue& from(const IValue& element) { + return element; + } + static IValue&& from(IValue&& element) { + return std::move(element); + } + }; } namespace impl { @@ -109,13 +129,13 @@ ListElementReference::operator T() const { template ListElementReference& ListElementReference::operator=(T&& new_value) && { - *iterator_ = c10::detail::list_element_from(std::move(new_value)); + *iterator_ = c10::detail::ListElementFrom::from(std::move(new_value)); return *this; } template ListElementReference& ListElementReference::operator=(const T& new_value) && { - *iterator_ = c10::detail::list_element_from(std::move(new_value)); + *iterator_ = c10::detail::ListElementFrom::from(std::move(new_value)); return *this; } @@ -140,16 +160,41 @@ template inline bool operator==(const T& lhs, const ListElementReference& rhs) { return rhs == lhs; } + +template +inline typename ListElementConstReferenceTraits::const_reference +list_element_to_const_ref(const IValue& element) { + return element.template to(); } +template<> +inline typename ListElementConstReferenceTraits::const_reference +list_element_to_const_ref(const IValue& element) { + return element.toStringRef(); +} + +template<> +inline typename ListElementConstReferenceTraits>::const_reference +list_element_to_const_ref>(const IValue& element) { + return element.toOptionalStringRef(); +} + +template<> +inline typename ListElementConstReferenceTraits::const_reference +list_element_to_const_ref(const IValue& element) { + return element.toTensor(); +} + +} // namespace impl + template void List::set(size_type pos, const value_type& value) const { - impl_->list.at(pos) = c10::detail::list_element_from(value); + impl_->list.at(pos) = c10::detail::ListElementFrom::from(value); } template void List::set(size_type pos, value_type&& value) const { - impl_->list.at(pos) = c10::detail::list_element_from(std::move(value)); + impl_->list.at(pos) = c10::detail::ListElementFrom::from(std::move(value)); } template @@ -158,7 +203,12 @@ typename List::value_type List::get(size_type pos) const { } template -typename List::internal_reference_type List::operator[](size_type pos) const { +typename List::internal_const_reference_type List::operator[](size_type pos) const { + return c10::impl::list_element_to_const_ref(impl_->list.at(pos)); +} + +template +typename List::internal_reference_type List::operator[](size_type pos) { static_cast(impl_->list.at(pos)); // Throw the exception if it is out of range. return {impl_->list.begin() + pos}; } @@ -168,7 +218,7 @@ typename List::value_type List::extract(size_type pos) const { auto& elem = impl_->list.at(pos); auto result = c10::detail::list_element_to(std::move(elem)); // Reset the list element to a T() instead of None to keep it correctly typed - elem = c10::detail::list_element_from(T{}); + elem = c10::detail::ListElementFrom::from(T{}); return result; } @@ -204,12 +254,12 @@ void List::clear() const { template typename List::iterator List::insert(iterator pos, const T& value) const { - return iterator { impl_->list.insert(pos.iterator_, c10::detail::list_element_from(value)) }; + return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom::from(value)) }; } template typename List::iterator List::insert(iterator pos, T&& value) const { - return iterator { impl_->list.insert(pos.iterator_, c10::detail::list_element_from(std::move(value))) }; + return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom::from(std::move(value))) }; } template @@ -221,12 +271,12 @@ typename List::iterator List::emplace(iterator pos, Args&&... value) const template void List::push_back(const T& value) const { - impl_->list.push_back(c10::detail::list_element_from(value)); + impl_->list.push_back(c10::detail::ListElementFrom::from(value)); } template void List::push_back(T&& value) const { - impl_->list.push_back(c10::detail::list_element_from(std::move(value))); + impl_->list.push_back(c10::detail::ListElementFrom::from(std::move(value))); } template @@ -312,3 +362,5 @@ void List::unsafeSetElementType(TypePtr t) { impl_->elementType = std::move(t); } } + +#include diff --git a/aten/src/ATen/core/List_test.cpp b/aten/src/ATen/core/List_test.cpp index 96af072862753..6dc4f53f07072 100644 --- a/aten/src/ATen/core/List_test.cpp +++ b/aten/src/ATen/core/List_test.cpp @@ -1085,14 +1085,35 @@ TEST(ListTest_NonIValueBasedList, sameValueDifferentStorage_thenIsReturnsFalse) TEST(ListTest, canAccessStringByReference) { List list({"one", "two"}); - const std::string& str = list[1].toStringRef(); + const auto& listRef = list; + static_assert(std::is_same::value, + "const List acccess should be by const reference"); + std::string str = list[1]; + const std::string& strRef = listRef[1]; EXPECT_EQ("two", str); + EXPECT_EQ("two", strRef); } TEST(ListTest, canAccessOptionalStringByReference) { List> list({"one", "two", c10::nullopt}); - c10::optional> str1 = list[1].toOptionalStringRef(); - c10::optional> str2 = list[2].toOptionalStringRef(); - EXPECT_EQ("two", str1.value().get()); + const auto& listRef = list; + static_assert( + std::is_same>>::value, + "List> acccess should be by const reference"); + c10::optional str1 = list[1]; + c10::optional str2 = list[2]; + decltype(auto) strRef1 = listRef[1]; + decltype(auto) strRef2 = listRef[2]; + EXPECT_EQ("two", str1.value()); EXPECT_FALSE(str2.has_value()); + EXPECT_EQ("two", strRef1.value().get()); + EXPECT_FALSE(strRef2.has_value()); +} + +TEST(ListTest, canAccessTensorByReference) { + List list; + const auto& listRef = list; + static_assert( + std::is_same::value, + "List access should be by const reference"); } diff --git a/aten/src/ATen/core/NamedRegistrations.cpp b/aten/src/ATen/core/NamedRegistrations.cpp index 6712be56ebb2a..d9a4979ff3c93 100644 --- a/aten/src/ATen/core/NamedRegistrations.cpp +++ b/aten/src/ATen/core/NamedRegistrations.cpp @@ -37,9 +37,9 @@ TORCH_LIBRARY_IMPL(aten, Named, m) { m.impl("add.out", CppFunction::makeFallthrough()); m.impl("add_.Scalar", CppFunction::makeFallthrough()); m.impl("add_.Tensor", CppFunction::makeFallthrough()); - m.impl("add_relu.Tensor", CppFunction::makeFallthrough()); - m.impl("add_relu.out", CppFunction::makeFallthrough()); - m.impl("add_relu_.Tensor", CppFunction::makeFallthrough()); + m.impl("_add_relu.Tensor", CppFunction::makeFallthrough()); + m.impl("_add_relu.out", CppFunction::makeFallthrough()); + m.impl("_add_relu_.Tensor", CppFunction::makeFallthrough()); m.impl("addcdiv", CppFunction::makeFallthrough()); m.impl("addcdiv.out", CppFunction::makeFallthrough()); m.impl("addcdiv_", CppFunction::makeFallthrough()); @@ -113,10 +113,6 @@ TORCH_LIBRARY_IMPL(aten, Named, m) { m.impl("conj.out", CppFunction::makeFallthrough()); m.impl("contiguous", CppFunction::makeFallthrough()); m.impl("copy_", CppFunction::makeFallthrough()); - m.impl("copy_imag", CppFunction::makeFallthrough()); - m.impl("copy_imag.out", CppFunction::makeFallthrough()); - m.impl("copy_real", CppFunction::makeFallthrough()); - m.impl("copy_real.out", CppFunction::makeFallthrough()); m.impl("cos", CppFunction::makeFallthrough()); m.impl("cos.out", CppFunction::makeFallthrough()); m.impl("cos_", CppFunction::makeFallthrough()); @@ -214,6 +210,12 @@ TORCH_LIBRARY_IMPL(aten, Named, m) { m.impl("i0", CppFunction::makeFallthrough()); m.impl("i0.out", CppFunction::makeFallthrough()); m.impl("i0_", CppFunction::makeFallthrough()); + m.impl("igamma", CppFunction::makeFallthrough()); + m.impl("igamma.out", CppFunction::makeFallthrough()); + m.impl("igamma_", CppFunction::makeFallthrough()); + m.impl("igammac", CppFunction::makeFallthrough()); + m.impl("igammac.out", CppFunction::makeFallthrough()); + m.impl("igammac_", CppFunction::makeFallthrough()); m.impl("imag", CppFunction::makeFallthrough()); m.impl("index_fill.Dimname_Scalar", CppFunction::makeFallthrough()); m.impl("index_fill.Dimname_Tensor", CppFunction::makeFallthrough()); @@ -317,6 +319,11 @@ TORCH_LIBRARY_IMPL(aten, Named, m) { m.impl("median.dim_values", CppFunction::makeFallthrough()); m.impl("median.names_dim", CppFunction::makeFallthrough()); m.impl("median.names_dim_values", CppFunction::makeFallthrough()); + m.impl("nanmedian", CppFunction::makeFallthrough()); + m.impl("nanmedian.dim", CppFunction::makeFallthrough()); + m.impl("nanmedian.dim_values", CppFunction::makeFallthrough()); + m.impl("nanmedian.names_dim", CppFunction::makeFallthrough()); + m.impl("nanmedian.names_dim_values", CppFunction::makeFallthrough()); m.impl("min", CppFunction::makeFallthrough()); m.impl("min.dim", CppFunction::makeFallthrough()); m.impl("min.dim_min", CppFunction::makeFallthrough()); @@ -453,6 +460,9 @@ TORCH_LIBRARY_IMPL(aten, Named, m) { m.impl("tanh", CppFunction::makeFallthrough()); m.impl("tanh.out", CppFunction::makeFallthrough()); m.impl("tanh_", CppFunction::makeFallthrough()); + m.impl("tensor_split.indices", CppFunction::makeFallthrough()); + m.impl("tensor_split.sections", CppFunction::makeFallthrough()); + m.impl("tensor_split.tensor_indices_or_sections", CppFunction::makeFallthrough()); m.impl("threshold", CppFunction::makeFallthrough()); m.impl("threshold.out", CppFunction::makeFallthrough()); m.impl("threshold_", CppFunction::makeFallthrough()); @@ -493,12 +503,12 @@ TORCH_LIBRARY_IMPL(aten, Named, m) { // supported because they were manually registered. I'm not sure // if these registrations are right or not, but they preserve old behavior // (and some of them are exercised by the test suite). - m.impl("backward", CppFunction::makeFallthrough()); + m.impl("_backward", CppFunction::makeFallthrough()); m.impl("set_data", CppFunction::makeFallthrough()); m.impl("data", CppFunction::makeFallthrough()); m.impl("is_leaf", CppFunction::makeFallthrough()); m.impl("_version", CppFunction::makeFallthrough()); m.impl("requires_grad_", CppFunction::makeFallthrough()); - m.impl("requires_grad", CppFunction::makeFallthrough()); m.impl("retain_grad", CppFunction::makeFallthrough()); + m.impl("_fw_primal", CppFunction::makeFallthrough()); } diff --git a/aten/src/ATen/core/NamedTensor.h b/aten/src/ATen/core/NamedTensor.h index 6efd0fe1f61a7..5b064ca70aec9 100644 --- a/aten/src/ATen/core/NamedTensor.h +++ b/aten/src/ATen/core/NamedTensor.h @@ -19,7 +19,7 @@ namespace at { // // This class has an important invariant: there must be at least ONE // non-wildcard -struct CAFFE2_API NamedTensorMeta final : public c10::NamedTensorMetaInterface { +struct TORCH_API NamedTensorMeta final : public c10::NamedTensorMetaInterface { // This enum is to remind people that the invariant on constructors is that // the list of dimnames must have at least one non-wildcard enum HAS_NON_WILDCARD { @@ -69,7 +69,7 @@ struct CAFFE2_API NamedTensorMeta final : public c10::NamedTensorMetaInterface { // When NamesMode is disabled, then all operations ignore tensors' names fields. // Concretely speaking, all tensors are treated as having nullopt names. -struct CAFFE2_API NamesMode { +struct TORCH_API NamesMode { static bool is_enabled(); static void set_enabled(bool enabled); }; @@ -77,7 +77,7 @@ struct CAFFE2_API NamesMode { // A RAII, thread local (!) guard that enables or disables names upon // construction, and sets it back to the original value upon destruction. -struct CAFFE2_API NoNamesGuard { +struct TORCH_API NoNamesGuard { NoNamesGuard() : prev_mode(NamesMode::is_enabled()), initialized(true) { NamesMode::set_enabled(false); } @@ -99,8 +99,8 @@ void check_names_valid_for(const Tensor& tensor, DimnameList names); void check_names_valid_for(size_t tensor_dim, DimnameList names); // Sets the names of `tensor` to be `names`. -CAFFE2_API Tensor& internal_set_names_inplace(Tensor& tensor, optional names); -CAFFE2_API Tensor& internal_set_names_inplace(Tensor& tensor, std::vector&& names, bool validate_names); +TORCH_API Tensor& internal_set_names_inplace(Tensor& tensor, c10::optional names); +TORCH_API Tensor& internal_set_names_inplace(Tensor& tensor, std::vector&& names, bool validate_names); constexpr size_t kMaxNamedTensorDim = 64; @@ -110,8 +110,8 @@ namespace impl { // Some helper functions on TensorImpl. Useful for working with names in TH. // XXX: Ideally these would exist as methods on TensorImpl -CAFFE2_API void internal_set_names_inplace(TensorImpl* impl, optional names, bool validate_names); -CAFFE2_API void internal_set_names_inplace(TensorImpl* impl, std::vector&& names, bool validate_names); +TORCH_API void internal_set_names_inplace(TensorImpl* impl, c10::optional names, bool validate_names); +TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::vector&& names, bool validate_names); void check_names_valid_for(TensorImpl* impl, DimnameList names); @@ -119,19 +119,19 @@ void check_names_valid_for(TensorImpl* impl, DimnameList names); // Returns false if the tensor's names don't exist (were not allocated), // or if all names are 'None'. // We treat not-allocated-names the same as allocated names that are all 'None'. -CAFFE2_API bool has_names(const TensorImpl* impl); +TORCH_API bool has_names(const TensorImpl* impl); // Returns the names of the tensor's dimensions. // Unnamed tensors are treated as having 'None' in all dimension; this method // would return a DimnameList of all 'None's for an unnamed tensor. -CAFFE2_API DimnameList get_names(const TensorImpl* impl); +TORCH_API DimnameList get_names(const TensorImpl* impl); // This is more of an implementation detail; one should use impl::get_names / // Tensor::names() whenever possible because it provides a cleaner API. // Returns the names of the tensor if they have been allocated; returns nullopt // instead if the haven't been. The names of a tensor are not allocated if a // tensor is constructed with names=None. -CAFFE2_API optional get_opt_names(const TensorImpl* impl); +TORCH_API c10::optional get_opt_names(const TensorImpl* impl); } // namespace impl diff --git a/aten/src/ATen/core/QuantizerBase.h b/aten/src/ATen/core/QuantizerBase.h index fa796e54ac429..0103c8161ea1c 100644 --- a/aten/src/ATen/core/QuantizerBase.h +++ b/aten/src/ATen/core/QuantizerBase.h @@ -32,7 +32,7 @@ using QuantizerPtr = c10::intrusive_ptr; * Quantized Tensor holds an intrusive_ptr to Quantizer, and multiple Tensor can * share the same Quantizer. Quantizer should be immutable. */ -struct CAFFE2_API Quantizer : public c10::intrusive_ptr_target { +struct TORCH_API Quantizer : public c10::intrusive_ptr_target { const ScalarType scalar_type_; explicit Quantizer(ScalarType scalar_type) : scalar_type_(scalar_type) {} virtual ~Quantizer(); diff --git a/aten/src/ATen/core/TransformationHelper.h b/aten/src/ATen/core/TransformationHelper.h index 72e3984a0540a..e8bafe3bcbad4 100644 --- a/aten/src/ATen/core/TransformationHelper.h +++ b/aten/src/ATen/core/TransformationHelper.h @@ -49,16 +49,16 @@ C10_HOST_DEVICE inline T uniform_int_full_range(V val) { /** * A transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`. + * In order to prevent compiler warnings reported in GitHub issue 46391, T can't be float or double + * in this overloaded version */ template -C10_HOST_DEVICE inline T uniform_int(V val) { +C10_HOST_DEVICE inline typename std::enable_if::value), T>::type uniform_int(V val) { if (std::is_same::value) { return static_cast(val & 1); - } else if (std::is_same::value) { - return static_cast(val % static_cast((1ULL << std::numeric_limits::digits) + 1)); } else if (std::is_same::value) { return static_cast(val % (static_cast(std::numeric_limits::max()) + 1)); - } else if (std::is_floating_point::value || std::is_same::value || std::is_same::value) { + } else if (std::is_same::value || std::is_same::value) { return static_cast(val % static_cast((1ULL << std::numeric_limits::digits) + 1)); } else if (std::is_integral::value) { return static_cast(val % (static_cast(std::numeric_limits::max()) + 1)); @@ -68,6 +68,15 @@ C10_HOST_DEVICE inline T uniform_int(V val) { } } +/** + * An overloaded transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`, + * added to fix compiler warnings reported in GitHub issue 46391. T is either float or double in this version. + */ +template +C10_HOST_DEVICE inline typename std::enable_if::value, T>::type uniform_int(V val) { + return static_cast(val % static_cast((1ULL << std::numeric_limits::digits) + 1)); +} + template C10_HOST_DEVICE inline dist_acctype uniform_real(V val, T from, T to) { constexpr auto MASK = static_cast((static_cast(1) << std::numeric_limits::digits) - 1); diff --git a/aten/src/ATen/core/VariableHooksInterface.h b/aten/src/ATen/core/VariableHooksInterface.h index e510471446ff7..3a85919375c48 100644 --- a/aten/src/ATen/core/VariableHooksInterface.h +++ b/aten/src/ATen/core/VariableHooksInterface.h @@ -16,7 +16,7 @@ // merge the libraries inside Facebook". Well, the problem is that there // are some downstream applications which are at binary size limit, and // incorporating all of the extra code from libtorch would push them -// over (admarket/adreview/service:adreviewservice, see also +// over (admarket/adreview/service:adreviewservice, see also // https://github.com/pytorch/pytorch/pull/29299) So if you want to do that, // we have to fix all of the services like this. // @@ -38,7 +38,7 @@ struct Node; namespace at { namespace impl { -struct CAFFE2_API VariableHooksInterface { +struct TORCH_API VariableHooksInterface { virtual ~VariableHooksInterface() = default; virtual Tensor tensor_data(const Tensor&) const = 0; virtual Tensor variable_data(const Tensor&) const = 0; @@ -50,10 +50,10 @@ struct CAFFE2_API VariableHooksInterface { virtual const std::string& name(const Tensor&) const = 0; }; -CAFFE2_API void SetVariableHooks(VariableHooksInterface* hooks); -CAFFE2_API VariableHooksInterface* GetVariableHooks(); +TORCH_API void SetVariableHooks(VariableHooksInterface* hooks); +TORCH_API VariableHooksInterface* GetVariableHooks(); -struct CAFFE2_API VariableHooksRegisterer { +struct TORCH_API VariableHooksRegisterer { explicit VariableHooksRegisterer(VariableHooksInterface* hooks) { SetVariableHooks(hooks); } diff --git a/aten/src/ATen/core/Variadic.h b/aten/src/ATen/core/Variadic.h index b49d94bba1c84..d33f3d575177c 100644 --- a/aten/src/ATen/core/Variadic.h +++ b/aten/src/ATen/core/Variadic.h @@ -6,6 +6,7 @@ #include #include +#include namespace at { @@ -56,6 +57,15 @@ struct IterArgs { } } + template + void operator()(const torch::List& args) { + for (const auto& arg : args) { + self()(arg); + if (self().short_circuit()) + return; + } + } + // NB: we need to specify std::vector manually as C++ won't // do an implicit conversion to make a template deduction go through. template diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index 4fa49302240bb..518e74b95d549 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -80,6 +80,7 @@ _(aten, _floor) \ _(aten, _fused_dropout) \ _(aten, _indexCopy) \ _(aten, _indices) \ +_(aten, _ldexp) \ _(aten, _linspace) \ _(aten, _local_scalar) \ _(aten, _local_scalar_dense) \ @@ -133,11 +134,8 @@ _(aten, _sum_cuda) \ _(aten, _tan) \ _(aten, _tanh) \ _(aten, _tanh_forward) \ -_(aten, _th_baddbmm) \ -_(aten, _th_bmm) \ _(aten, _th_get_device) \ _(aten, _th_kthvalue) \ -_(aten, _th_median) \ _(aten, _th_mode) \ _(aten, _th_prod) \ _(aten, _th_sigmoid) \ @@ -222,6 +220,7 @@ _(aten, blackman_window) \ _(aten, block_diag) \ _(aten, bmm) \ _(aten, broadcast_tensors) \ +_(aten, broadcast_to) \ _(aten, cartesian_prod) \ _(aten, cat) \ _(aten, cauchy) \ @@ -240,6 +239,7 @@ _(aten, combinations) \ _(aten, _conj) \ _(aten, conj) \ _(aten, complex) \ +_(aten, copysign) \ _(aten, polar) \ _(aten, constant_pad_nd) \ _(aten, contiguous) \ @@ -372,6 +372,10 @@ _(aten, hstack) \ _(aten, hypot) \ _(aten, i0) \ _(aten, i0_) \ +_(aten, igamma) \ +_(aten, igamma_) \ +_(aten, igammac) \ +_(aten, igammac_) \ _(aten, ifft) \ _(aten, index) \ _(aten, index_add) \ @@ -380,6 +384,7 @@ _(aten, index_fill) \ _(aten, index_put) \ _(aten, index_select) \ _(aten, indices) \ +_(aten, inner) \ _(aten, instance_norm) \ _(aten, inverse) \ _(aten, irfft) \ @@ -431,6 +436,7 @@ _(aten, logdet) \ _(aten, logit) \ _(aten, logspace) \ _(aten, logsumexp) \ +_(aten, xlogy) \ _(aten, lstm) \ _(aten, lstm_cell) \ _(aten, lstsq) \ @@ -463,6 +469,7 @@ _(aten, max_unpool3d_forward) \ _(aten, max_values) \ _(aten, mean) \ _(aten, median) \ +_(aten, nanmedian) \ _(aten, meshgrid) \ _(aten, min) \ _(aten, min_values) \ @@ -492,6 +499,7 @@ _(aten, mode) \ _(aten, mse_loss) \ _(aten, mse_loss_backward) \ _(aten, mse_loss_forward) \ +_(aten, msort) \ _(aten, multi_margin_loss) \ _(aten, multi_margin_loss_backward) \ _(aten, multi_margin_loss_forward) \ @@ -502,6 +510,7 @@ _(aten, multinomial) \ _(aten, mv) \ _(aten, mvlgamma) \ _(aten, nansum) \ +_(aten, nan_to_num) \ _(aten, narrow) \ _(aten, narrow_copy) \ _(aten, native_batch_norm) \ @@ -526,6 +535,7 @@ _(aten, nll_loss2d_forward) \ _(aten, nll_loss_backward) \ _(aten, nll_loss_forward) \ _(aten, nonzero) \ +_(aten, nonzero_numpy) \ _(aten, norm) \ _(aten, norm_except_dim) \ _(aten, normal) \ @@ -540,13 +550,14 @@ _(aten, _euclidean_dist) \ _(aten, pdist) \ _(aten, cdist) \ _(aten, permute) \ -_(aten, movedim) \ _(aten, pin_memory) \ _(aten, pinverse) \ _(aten, pixel_shuffle) \ +_(aten, pixel_unshuffle) \ _(aten, poisson) \ _(aten, polygamma) \ _(aten, pow) \ +_(aten, float_power) \ _(aten, prelu) \ _(aten, prelu_backward) \ _(aten, prod) \ @@ -564,6 +575,7 @@ _(aten, randn_like) \ _(aten, random) \ _(aten, randperm) \ _(aten, range) \ +_(aten, ravel) \ _(aten, reciprocal) \ _(aten, reflection_pad1d) \ _(aten, reflection_pad1d_backward) \ @@ -613,6 +625,7 @@ _(aten, signbit) \ _(aten, silu) \ _(aten, sgn) \ _(aten, sin) \ +_(aten, sinc) \ _(aten, sinh) \ _(aten, size) \ _(aten, sizes) \ @@ -663,7 +676,7 @@ _(aten, tan) \ _(aten, tanh) \ _(aten, tensor) \ _(aten, tensordot) \ -_(aten, th_addmm) \ +_(aten, tensor_split) \ _(aten, th_clone) \ _(aten, th_norm) \ _(aten, th_pow) \ @@ -673,6 +686,7 @@ _(aten, th_zero) \ _(aten, thnn_conv2d) \ _(aten, thnn_conv2d_backward) \ _(aten, thnn_conv2d_forward) \ +_(aten, tile) \ _(aten, slow_conv3d) \ _(aten, slow_conv3d_backward) \ _(aten, slow_conv3d_forward) \ @@ -694,7 +708,6 @@ _(aten, to_sparse) \ _(aten, to_dense) \ _(aten, topk) \ _(aten, trace) \ -_(aten, transpose) \ _(aten, triangular_solve) \ _(aten, tril) \ _(aten, triplet_margin_loss) \ @@ -733,7 +746,6 @@ _(aten, vander) \ _(aten, var) \ _(aten, view) \ _(aten, view_as) \ -_(aten, vstack) \ _(aten, where) \ _(aten, zero) \ _(aten, zeros) \ @@ -778,6 +790,7 @@ _(attr, ceil_mode) \ _(attr, checked_signal_sizes) \ _(attr, chunks) \ _(attr, columns) \ +_(attr, column_stack) \ _(attr, complex_input) \ _(attr, complex_output) \ _(attr, condition) \ @@ -900,6 +913,7 @@ _(attr, maxnorm) \ _(attr, maximum) \ _(attr, mean) \ _(attr, median) \ +_(attr, nanmedian) \ _(attr, min) \ _(attr, min_indices) \ _(attr, min_val) \ diff --git a/aten/src/ATen/core/blob.h b/aten/src/ATen/core/blob.h index 988e99b2395e9..1c59ac0aa8446 100644 --- a/aten/src/ATen/core/blob.h +++ b/aten/src/ATen/core/blob.h @@ -21,7 +21,7 @@ class Tensor; * properly when the blob is deallocated or re-allocated with a new type. A blob * could contain anything, although the most common case is to contain a Tensor. */ -class CAFFE2_API Blob final : public c10::intrusive_ptr_target { +class TORCH_API Blob final : public c10::intrusive_ptr_target { public: /** * Initializes an empty Blob. @@ -51,7 +51,7 @@ class CAFFE2_API Blob final : public c10::intrusive_ptr_target { /** * Returns the meta info of the blob. */ - const TypeMeta& meta() const noexcept { + const TypeMeta meta() const noexcept { return meta_; } @@ -155,7 +155,7 @@ class CAFFE2_API Blob final : public c10::intrusive_ptr_target { TypeMeta::Make::type>())); } - void* ShareExternal(void* allocated, const TypeMeta& meta) { + void* ShareExternal(void* allocated, const TypeMeta meta) { free_(); meta_ = meta; pointer_ = allocated; diff --git a/aten/src/ATen/core/boxing/KernelFunction.cpp b/aten/src/ATen/core/boxing/KernelFunction.cpp index f84352ebee1f1..260343ac3180a 100644 --- a/aten/src/ATen/core/boxing/KernelFunction.cpp +++ b/aten/src/ATen/core/boxing/KernelFunction.cpp @@ -24,7 +24,7 @@ void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHandle& op, S op.operator_name(), " has kernels registered to both Math and a backend mapped to AutogradOther. " "This makes the backend kernel unreachable (see Note [Ambiguity in AutogradOther kernel]). " "If it's intended to override Math kernel behavior, please open an issue to request a dedicated " - "Autograd dispatch key for the backend."); + "Autograd dispatch key for the backend.", "\nCanonical state\n~~~~~~~~~~~\n", op.dumpState(), "\n\n"); } void named_not_supported_kernel(OperatorKernel*, const OperatorHandle& op, Stack*) { @@ -57,25 +57,4 @@ bool KernelFunction::_equalsBoxedAndUnboxed(const KernelFunction& other) const { unboxed_kernel_func_ == other.unboxed_kernel_func_; } -void KernelFunction::checkBoxedKernel(const OperatorHandle& opHandle) const { - if (C10_UNLIKELY(boxed_kernel_func_ == nullptr)) { - if (unboxed_kernel_func_ == nullptr) { - TORCH_INTERNAL_ASSERT( - false, - "Tried to call KernelFunction::callBoxed() on an uninitialized KernelFunction.", - " opname: ", - opHandle.operator_name(), - " If you're using mobile selective build please make sure to include all ops exported from `torch.jit.export_opnames(model)`."); - } else { - // TODO We want to introduce the invariant that all kernels must be callable in a boxed way, then this case should be impossible. - TORCH_INTERNAL_ASSERT( - false, - "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call().", - " opname: ", - opHandle.operator_name(), - " If you're using mobile selective build please make sure to include all ops exported from `torch.jit.export_opnames(model)`."); - } - } -} - } // namespace c10 diff --git a/aten/src/ATen/core/boxing/KernelFunction.h b/aten/src/ATen/core/boxing/KernelFunction.h index 3f8c94ba481dc..ddbbd912777af 100644 --- a/aten/src/ATen/core/boxing/KernelFunction.h +++ b/aten/src/ATen/core/boxing/KernelFunction.h @@ -15,7 +15,7 @@ struct OperatorKernel; // no overhead to fallthrough to the next key. See cpp file for some more // implementation notes; notably, this does NOT actually go through the // boxing/unboxing codepath. -CAFFE2_API void fallthrough_kernel(OperatorKernel*, const OperatorHandle&, Stack*); +TORCH_API void fallthrough_kernel(OperatorKernel*, const OperatorHandle&, Stack*); // Note [Ambiguity in AutogradOther kernel] // This kernel implements reporting an error message when there're kernels registered @@ -27,7 +27,7 @@ CAFFE2_API void fallthrough_kernel(OperatorKernel*, const OperatorHandle&, Stack // See c10/core/DispatchKeySet.cpp for a list of backends mapped to AutogradOther. // Thus if backend extender indeed want to override Math kernel behavior, they should request // a dedicated Autograd key for their backend to resolve the ambiguity. -CAFFE2_API void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHandle&, Stack*); +TORCH_API void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHandle&, Stack*); // Note [named_not_supported_kernel] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -36,7 +36,7 @@ CAFFE2_API void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHa // cased in the dispatcher to be triggered before we attempt boxing (so we can // give a good error message in cases when boxing is not supported). When // boxing is universally supported this can be removed. -[[noreturn]] CAFFE2_API void named_not_supported_kernel(OperatorKernel*, const OperatorHandle&, Stack*); +[[noreturn]] TORCH_API void named_not_supported_kernel(OperatorKernel*, const OperatorHandle&, Stack*); /** * KernelFunction is similar to std::function but stores a kernel function. @@ -44,7 +44,7 @@ CAFFE2_API void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHa * and call it in a boxed or unboxed way. If the way it was created doesn't * match the way it was called, it will do boxing or unboxing as necessary. */ -class CAFFE2_API KernelFunction final { +class TORCH_API KernelFunction final { public: // This is how boxed kernels are actually stored using InternalBoxedKernelFunction = void(OperatorKernel*, const OperatorHandle&, Stack*); @@ -123,26 +123,6 @@ class CAFFE2_API KernelFunction final { template static KernelFunction makeFromUnboxedFunctor(std::unique_ptr kernelFunctor); - /** - * Create a KernelFunction from an unboxed functor and prevent creation of an - * unboxing-wrapper. This means that you cannot call this KernelFunction - * using KernelFunction::callBoxed() - * - * This is necessary because our unboxing wrappers don't work for all types - * yet, so if you want to use one of these types as function arguments, - * you need to use makeFromUnboxedOnlyFunctor. - * - * Example: - * - * > class MyFunctor final { - * > public: - * > Tensor operator()(Tensor a, Tensor b) {...} - * > }; - * > KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunctor(std::make_unique()); - */ - template - static KernelFunction makeFromUnboxedOnlyFunctor(std::unique_ptr kernelFunctor); - /** * Create a KernelFunction from an unboxed function. * This is usually better than KernelFunction::makeFromUnboxedRuntimeFunction @@ -158,23 +138,6 @@ class CAFFE2_API KernelFunction final { template static KernelFunction makeFromUnboxedFunction(FuncPtr); - /** - * Create a KernelFunction from an unboxed function and prevent creation of an - * unboxing-wrapper. This means that you cannot call this KernelFunction - * using KernelFunction::callBoxed() - * - * This is necessary because our unboxing wrappers don't work for all types - * yet, so if you want to use one of these types as function arguments, - * you need to use makeFromUnboxedOnlyFunctor. - * - * Example: - * - * > Tensor unboxed_func(Tensor a, Tensor b) {...} - * > KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunction(); - */ - template - static KernelFunction makeFromUnboxedOnlyFunction(FuncPtr); - /** * Create a KernelFunction from an unboxed function. * KernelFunction::makeFromUnboxedFunction is usually a better choice than @@ -189,9 +152,6 @@ class CAFFE2_API KernelFunction final { template static KernelFunction makeFromUnboxedRuntimeFunction(FuncType* func); - template - static KernelFunction makeFromUnboxedOnlyRuntimeFunction(FuncType* func); - static KernelFunction makeFallthrough(); static KernelFunction makeAmbiguousAutogradOther(); static KernelFunction makeNamedNotSupported(); @@ -205,18 +165,14 @@ class CAFFE2_API KernelFunction final { * > [] (Tensor a, bool b) -> Tensor {...}); */ template - static KernelFunction makeFromUnboxedLambda(Lambda&& lambda); + static std::enable_if_t>::value, KernelFunction> makeFromUnboxedLambda(Lambda&& lambda); + template + static std::enable_if_t>::value, KernelFunction> makeFromUnboxedLambda(Lambda&& lambda); std::string dumpState() const; // For testing internal invariants only bool _equalsBoxedAndUnboxed(const KernelFunction&) const; - // This function is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed - // unboxing wrapper for aten operators. We still need those for some operators because not all work - // with the templated unboxing logic yet. - // TODO Delete setManuallyBoxedKernel_ once all operators work with the templated boxing logic. This can be done once https://github.com/pytorch/pytorch/issues/32366 is fixed. - void setManuallyBoxedKernel_(InternalBoxedKernelFunction* func); - private: explicit KernelFunction(std::unique_ptr functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func); @@ -224,8 +180,6 @@ class CAFFE2_API KernelFunction final { template static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, Stack* stack); - void checkBoxedKernel(const OperatorHandle& opHandle) const; - OperatorKernel* getFunctor_() const; std::shared_ptr functor_; diff --git a/aten/src/ATen/core/boxing/KernelFunction_impl.h b/aten/src/ATen/core/boxing/KernelFunction_impl.h index bef5afae76faf..b248e54a6f948 100644 --- a/aten/src/ATen/core/boxing/KernelFunction_impl.h +++ b/aten/src/ATen/core/boxing/KernelFunction_impl.h @@ -23,8 +23,7 @@ inline void KernelFunction::make_boxed_function(OperatorKernel*, const OperatorH } inline bool KernelFunction::isValid() const { - // TODO We want to introduce the invariant that all kernels must be callable in a boxed way, then this should only check boxed_kernel_func_. - return boxed_kernel_func_ != nullptr || unboxed_kernel_func_ != nullptr; + return boxed_kernel_func_ != nullptr; } inline bool KernelFunction::isFallthrough() const { @@ -32,7 +31,10 @@ inline bool KernelFunction::isFallthrough() const { } inline void KernelFunction::callBoxed(const OperatorHandle& opHandle, Stack* stack) const { - checkBoxedKernel(opHandle); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + boxed_kernel_func_ != nullptr, + "Tried to call KernelFunction::callBoxed() on an uninitialized KernelFunction." + ); (*boxed_kernel_func_)(functor_.get(), opHandle, stack); } @@ -111,43 +113,22 @@ inline KernelFunction KernelFunction::makeFromUnboxedFunctor(std::unique_ptr -inline KernelFunction KernelFunction::makeFromUnboxedOnlyFunctor(std::unique_ptr kernelFunctor) { - // TODO We want to get rid of kernels that have only an unboxed function pointer. - // All kernels should have a boxed pointer. - - static_assert(guts::is_functor::value, "Tried to call KernelFunction::makeFromUnboxedFunctor but the argument is not a functor."); - static_assert(std::is_base_of::value, "Tried to call KernelFunction::makeFromUnboxedFunctor, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); - - return KernelFunction( - std::move(kernelFunctor), - nullptr, // Don't create a boxed kernel for this - reinterpret_cast(&impl::wrap_kernel_functor_unboxed::call) - ); -} - template -inline KernelFunction KernelFunction::makeFromUnboxedFunction(FuncPtr) { +inline KernelFunction KernelFunction::makeFromUnboxedFunction(FuncPtr func_ptr) { static_assert(is_compile_time_function_pointer::value, "Tried to call KernelFunction::makeFromUnboxedFunction with an invalid parameter. It must be a function pointer created with TORCH_FN."); static_assert(!std::is_same::value, "Tried to call KernelFunction::makeFromUnboxedFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead."); static_assert(FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr"); +#if !defined(C10_MOBILE) return makeFromUnboxedFunctor::type>( guts::make_unique_base::type>() ); -} - -template -inline KernelFunction KernelFunction::makeFromUnboxedOnlyFunction(FuncPtr) { - // TODO We want to get rid of kernels that have only an unboxed function pointer. - // All kernels should have a boxed pointer. - static_assert(is_compile_time_function_pointer::value, "Tried to call KernelFunction::makeFromUnboxedOnlyFunction with an invalid parameter. It must be a function pointer created with TORCH_FN."); - static_assert(!std::is_same::value, "Tried to call KernelFunction::makeFromUnboxedOnlyFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead."); - static_assert(FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr"); - - return makeFromUnboxedOnlyFunctor::type> ( - guts::make_unique_base::type>() - ); +#else + // On mobile, we rather want to optimize for binary size than for performance, + // so let's not inline the kernel into the wrapper but use makeFromUnboxedRuntimeFunction + // instead. + return makeFromUnboxedRuntimeFunction(func_ptr.func_ptr()); +#endif } template @@ -161,19 +142,25 @@ inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction(FuncType* f ); } -template -inline KernelFunction KernelFunction::makeFromUnboxedOnlyRuntimeFunction(FuncType* func) { - static_assert(guts::is_function_type::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type."); - static_assert(!std::is_same::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead."); - TORCH_INTERNAL_ASSERT(func != nullptr, "Kernel function cannot be nullptr"); +template +inline std::enable_if_t>::value, KernelFunction> KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) { + static_assert(guts::is_functor>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type."); - return makeFromUnboxedOnlyFunctor>>( - guts::make_unique_base>>(func) +#if !defined(C10_MOBILE) + return makeFromUnboxedFunctor>>( + guts::make_unique_base>>(std::forward(lambda)) ); +#else + // On mobile, we rather want to optimize for binary size than for performance, + // so let's not inline the kernel into the wrapper but use makeFromUnboxedRuntimeFunction + // instead. + using FuncType = typename guts::infer_function_traits_t>::func_type; + return makeFromUnboxedRuntimeFunction(lambda); +#endif } template -inline KernelFunction KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) { +inline std::enable_if_t>::value, KernelFunction> KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) { static_assert(guts::is_functor>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type."); return makeFromUnboxedFunctor>>( @@ -181,14 +168,4 @@ inline KernelFunction KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) { ); } -inline void KernelFunction::setManuallyBoxedKernel_(InternalBoxedKernelFunction* func) { - if (boxed_kernel_func_ == &fallthrough_kernel) { - // special case no-op - return; - } - TORCH_INTERNAL_ASSERT(boxed_kernel_func_ == nullptr, "Tried to set a manually boxed kernel for a kernel that already has a boxed kernel set."); - TORCH_INTERNAL_ASSERT(unboxed_kernel_func_ != nullptr, "Tried to set a manually boxed kernel for an invalid KernelFunction."); - boxed_kernel_func_ = func; -} - } diff --git a/aten/src/ATen/core/boxing/KernelFunction_test.cpp b/aten/src/ATen/core/boxing/KernelFunction_test.cpp index 87517afe27c6a..e17efab10ba5a 100644 --- a/aten/src/ATen/core/boxing/KernelFunction_test.cpp +++ b/aten/src/ATen/core/boxing/KernelFunction_test.cpp @@ -110,50 +110,92 @@ OperatorHandle makeDummyOperatorHandle() { // boxed kernels that return refs to tensor arguments, a la inplace/outplace kernels // -void boxed_func_with_tensor_ref_return(const OperatorHandle& /*opHandle*/, Stack* stack) { +void boxed_func_for_inplace_op(const OperatorHandle& /*opHandle*/, Stack* stack) { // (Tensor(a!), Scalar) -> Tensor(a!) EXPECT_EQ(2, stack->size()); ASSERT_TRUE(stack->at(0).isTensor()); - auto a = stack->at(0).toTensor(); + auto t = stack->at(0).toTensor(); ASSERT_TRUE(stack->at(1).isScalar()); - auto b = stack->at(1).toScalar(); + auto s = stack->at(1).toScalar(); - a.add_(b); + t.add_(s); stack->clear(); - torch::jit::push(stack, a); + torch::jit::push(stack, t); } -void boxed_func_with_multiple_tensor_ref_return(const OperatorHandle& /*opHandle*/, Stack* stack) { +void boxed_func_for_outofplace_op(const OperatorHandle& /*opHandle*/, Stack* stack) { + // (Scalar, Tensor(a!)) -> Tensor(a!) + EXPECT_EQ(2, stack->size()); + + ASSERT_TRUE(stack->at(0).isScalar()); + auto s = stack->at(0).toScalar(); + + ASSERT_TRUE(stack->at(1).isTensor()); + auto t = stack->at(1).toTensor(); + + t.add_(s); + + stack->clear(); + torch::jit::push(stack, t); +} + +void boxed_func_for_outofplace_multi_op(const OperatorHandle& /*opHandle*/, Stack* stack) { // (Tensor(a!), Tensor(b!), Scalar, Scalar) -> (Tensor(a!), Tensor(b!)) EXPECT_EQ(4, stack->size()); ASSERT_TRUE(stack->at(0).isTensor()); - auto a = stack->at(0).toTensor(); + auto t1 = stack->at(0).toTensor(); ASSERT_TRUE(stack->at(1).isTensor()); - auto b = stack->at(1).toTensor(); + auto t2 = stack->at(1).toTensor(); ASSERT_TRUE(stack->at(2).isScalar()); - auto c = stack->at(2).toScalar(); + auto s1 = stack->at(2).toScalar(); ASSERT_TRUE(stack->at(3).isScalar()); - auto d = stack->at(3).toScalar(); + auto s2 = stack->at(3).toScalar(); - a.add_(c); - b.add_(d); + t1.add_(s1); + t2.add_(s2); stack->clear(); - torch::jit::push(stack, a); - torch::jit::push(stack, b); + torch::jit::push(stack, t1); + torch::jit::push(stack, t2); +} + +void boxed_func_for_legacy_outofplace_multi_op(const OperatorHandle& /*opHandle*/, Stack* stack) { + // (Scalar, Scalar, Tensor(a!), Tensor(b!)) -> (Tensor(a!), Tensor(b!)) + EXPECT_EQ(4, stack->size()); + + ASSERT_TRUE(stack->at(0).isScalar()); + auto s1 = stack->at(0).toScalar(); + + ASSERT_TRUE(stack->at(1).isScalar()); + auto s2 = stack->at(1).toScalar(); + + ASSERT_TRUE(stack->at(2).isTensor()); + auto t1 = stack->at(2).toTensor(); + + ASSERT_TRUE(stack->at(3).isTensor()); + auto t2 = stack->at(3).toTensor(); + + t1.add_(s1); + t2.add_(s2); + + stack->clear(); + torch::jit::push(stack, t1); + torch::jit::push(stack, t2); } // // boxed calling tests: // +// functional + void expectBoxedCallingWithReturnWorks(const KernelFunction& func) { called_with_args = c10::nullopt; vector stack {3, 4}; @@ -198,50 +240,76 @@ void expectBoxedCallingWithMultiReturnWorks(const KernelFunction& func) { EXPECT_EQ(12, stack[1].toInt()); } -void expectBoxedCallingWithTensorRefReturnWorks(const KernelFunction& func) { - OperatorHandle dummy = makeDummyOperatorHandle(); +// in/out - auto a = at::zeros({1}); - auto b = 1.0f; - vector stack {a, b}; +void expectInPlaceBoxedCallingWorks(const KernelFunction& func) { + OperatorHandle dummy = makeDummyOperatorHandle(); + auto t = at::zeros({1}); + auto s = 1.0f; + vector stack {t, s}; func.callBoxed(dummy, &stack); - // kernel should have updated arg 0 - EXPECT_EQ(a.item().toFloat(), 1.0f); - - // and returned it on the stack + // kernel should have updated out arg and returned it + EXPECT_EQ(t.item().toFloat(), 1.0f); EXPECT_EQ(1, stack.size()); EXPECT_TRUE(stack[0].isTensor()); - auto t = stack[0].toTensor(); - EXPECT_EQ(t.item().toFloat(), 1.0f); + EXPECT_TRUE(stack[0].toTensor().is_same(t)); } -void expectBoxedCallingWithMultipleTensorRefReturnWorks(const KernelFunction& func) { +void expectOutOfPlaceBoxedCallingWorks(const KernelFunction& func) { OperatorHandle dummy = makeDummyOperatorHandle(); - auto a = at::zeros({1}); - auto b = at::zeros({1}); - auto c = 1.0f; - auto d = 2.0f; - vector stack {a, b, c, d}; - + auto s = 1.0f; + auto t = at::zeros({1}); + vector stack {s, t}; func.callBoxed(dummy, &stack); - // kernel should have updated args 0 and 1 - EXPECT_EQ(a.item().toFloat(), 1.0f); - EXPECT_EQ(b.item().toFloat(), 2.0f); + // kernel should have updated out arg and returned it on the stack + EXPECT_EQ(t.item().toFloat(), 1.0f); + EXPECT_EQ(1, stack.size()); + EXPECT_TRUE(stack[0].isTensor()); + EXPECT_TRUE(stack[0].toTensor().is_same(t)); +} + +void expectOutOfPlaceMultiBoxedCallingWorks(const KernelFunction& func) { + OperatorHandle dummy = makeDummyOperatorHandle(); - // and pushed them onto the stack - EXPECT_EQ(2, stack.size()); + auto t1 = at::zeros({1}); + auto t2 = at::zeros({1}); + auto s1 = 1.0f; + auto s2 = 2.0f; + vector stack {t1, t2, s1, s2}; + func.callBoxed(dummy, &stack); + // kernel should have updated output args and returned them on the stack + EXPECT_EQ(t1.item().toFloat(), 1.0f); + EXPECT_EQ(t2.item().toFloat(), 2.0f); + EXPECT_EQ(2, stack.size()); EXPECT_TRUE(stack[0].isTensor()); - auto ta = stack[0].toTensor(); - EXPECT_EQ(ta.item().toFloat(), 1.0f); + EXPECT_TRUE(stack[0].toTensor().is_same(t1)); + EXPECT_TRUE(stack[1].isTensor()); + EXPECT_TRUE(stack[1].toTensor().is_same(t2)); +} + +void expectLegacyOutOfPlaceMultiBoxedCallingWorks(const KernelFunction& func) { + OperatorHandle dummy = makeDummyOperatorHandle(); + auto s1 = 1.0f; + auto s2 = 2.0f; + auto t1 = at::zeros({1}); + auto t2 = at::zeros({1}); + vector stack {s1, s2, t1, t2}; + func.callBoxed(dummy, &stack); + + // kernel should have updated output args and returned them on the stack + EXPECT_EQ(t1.item().toFloat(), 1.0f); + EXPECT_EQ(t2.item().toFloat(), 2.0f); + EXPECT_EQ(2, stack.size()); + EXPECT_TRUE(stack[0].isTensor()); + EXPECT_TRUE(stack[0].toTensor().is_same(t1)); EXPECT_TRUE(stack[1].isTensor()); - auto tb = stack[1].toTensor(); - EXPECT_EQ(tb.item().toFloat(), 2.0f); + EXPECT_TRUE(stack[1].toTensor().is_same(t2)); } void expectBoxedCallingFailsWith(const KernelFunction& func, const char* errorMessage) { @@ -254,6 +322,12 @@ void expectBoxedCallingFailsWith(const KernelFunction& func, const char* errorMe }, errorMessage); } +// +// unboxed calling tests: +// + +// functional + // make an unboxed call to a kernel that returns a single value. // void expectUnboxedCallingWithReturnWorks(const KernelFunction& func) { @@ -294,57 +368,84 @@ void expectUnboxedCallingWithMultiReturnWorks(const KernelFunction& func) { EXPECT_EQ((tuple(7, 12)), result); } -// make an unboxed call to a kernel that modifies its first (Tensor) argument -// and returns a reference to it. -// -void expectUnboxedCallingWithTensorRefReturnWorks(const KernelFunction& func) { +// in/out + +void expectInPlaceUnboxedCallingWorks(const KernelFunction& func) { OperatorHandle dummy = makeDummyOperatorHandle(); - auto a = at::zeros({1}); + auto t = at::zeros({1}); + at::Tensor& t_out = func.call(dummy, t, 1.0f); + + // should have updated first arg and returned it + EXPECT_EQ(t.item().toFloat(), 1.0f); + EXPECT_EQ(&t, &t_out); +} + +void expectOutOfPlaceUnboxedCallingWorks(const KernelFunction& func) { + OperatorHandle dummy = makeDummyOperatorHandle(); - at::Tensor& t = func.call(dummy, a, 1.0f); + auto t = at::zeros({1}); + at::Tensor& t_out = func.call(dummy, 1.0f, t); - EXPECT_EQ(a.item().toFloat(), 1.0f); + // should have updated out arg and returned it EXPECT_EQ(t.item().toFloat(), 1.0f); + EXPECT_EQ(&t, &t_out); +} + +void expectOutOfPlaceMultiUnboxedCallingWorks(const KernelFunction& func) { + OperatorHandle dummy = makeDummyOperatorHandle(); + + auto t1 = at::zeros({1}); + auto t2 = at::zeros({1}); + auto s1 = 1.0f; + auto s2 = 2.0f; + + std::tuple tup = func.call< + std::tuple, at::Tensor&, at::Tensor&, at::Scalar, at::Scalar + >(dummy, t1, t2, s1, s2); - EXPECT_EQ(&a, &t); + // kernel should have updated out args and returned them in a tuple + EXPECT_EQ(t1.item().toFloat(), 1.0f); + EXPECT_EQ(t2.item().toFloat(), 2.0f); + + auto t1_out = std::get<0>(tup); + EXPECT_EQ(t1_out.item().toFloat(), 1.0f); + EXPECT_TRUE(t1_out.is_same(t1)); + + auto t2_out = std::get<1>(tup); + EXPECT_EQ(t2_out.item().toFloat(), 2.0f); + EXPECT_TRUE(t2_out.is_same(t2)); } -// make an unboxed call to a kernel that modifies its first two (Tensor) arguments -// and returns them. When calling unboxed, these are returned as a tuple. -// -void expectUnboxedCallingWithMultipleTensorRefReturnWorks(const KernelFunction& func) { +void expectLegacyOutOfPlaceMultiUnboxedCallingWorks(const KernelFunction& func) { OperatorHandle dummy = makeDummyOperatorHandle(); - auto a = at::zeros({1}); - auto b = at::zeros({1}); - auto c = 1.0f; - auto d = 2.0f; + auto s1 = 1.0f; + auto s2 = 2.0f; + auto t1 = at::zeros({1}); + auto t2 = at::zeros({1}); std::tuple tup = func.call< - std::tuple, - at::Tensor&, - at::Tensor&, - at::Scalar, - at::Scalar - >(dummy, a, b, c, d); + std::tuple, at::Scalar, at::Scalar, at::Tensor&, at::Tensor& + >(dummy, s1, s2, t1, t2); - // kernel should have updated args 0 and 1 - EXPECT_EQ(a.item().toFloat(), 1.0f); - EXPECT_EQ(b.item().toFloat(), 2.0f); + // kernel should have updated out args and returned them in a tuple + EXPECT_EQ(t1.item().toFloat(), 1.0f); + EXPECT_EQ(t2.item().toFloat(), 2.0f); - // and returned a tuple containing them - auto ta = std::get<0>(tup); - EXPECT_EQ(ta.item().toFloat(), 1.0f); - EXPECT_TRUE(a.is_same(ta)); + auto t1_out = std::get<0>(tup); + EXPECT_EQ(t1_out.item().toFloat(), 1.0f); + EXPECT_TRUE(t1_out.is_same(t1)); - auto tb = std::get<1>(tup); - EXPECT_EQ(tb.item().toFloat(), 2.0f); - EXPECT_TRUE(b.is_same(tb)); + auto t2_out = std::get<1>(tup); + EXPECT_EQ(t2_out.item().toFloat(), 2.0f); + EXPECT_TRUE(t2_out.is_same(t2)); } } +// functional, boxed calling + TEST(KernelFunctionTest, givenBoxedFunction_withReturn_whenCallingBoxed_thenWorks) { KernelFunction func = KernelFunction::makeFromBoxedFunction<&kernels::boxed_func_with_return>(); kernels::expectBoxedCallingWithReturnWorks(func); @@ -360,16 +461,30 @@ TEST(KernelFunctionTest, givenBoxedFunction_withMultiReturn_whenCallingBoxed_the kernels::expectBoxedCallingWithMultiReturnWorks(func); } -TEST(KernelFunctionTest, givenBoxedFunction_withTensorRefReturn_whenCallingBoxed_thenWorks) { - KernelFunction func = KernelFunction::makeFromBoxedFunction<&kernels::boxed_func_with_tensor_ref_return>(); - kernels::expectBoxedCallingWithTensorRefReturnWorks(func); +// in/out, boxed calling + +TEST(KernelFunctionTest, givenBoxedFunction_withInPlaceSignature_whenCallingBoxed_thenWorks) { + KernelFunction func = KernelFunction::makeFromBoxedFunction<&kernels::boxed_func_for_inplace_op>(); + kernels::expectInPlaceBoxedCallingWorks(func); +} + +TEST(KernelFunctionTest, givenBoxedFunction_withOutOfPlaceSignature_whenCallingBoxed_thenWorks) { + KernelFunction func = KernelFunction::makeFromBoxedFunction<&kernels::boxed_func_for_outofplace_op>(); + kernels::expectOutOfPlaceBoxedCallingWorks(func); +} + +TEST(KernelFunctionTest, givenBoxedFunction_withOutOfPlaceMultiSignature_whenCallingBoxed_thenWorks) { + KernelFunction func = KernelFunction::makeFromBoxedFunction<&kernels::boxed_func_for_outofplace_multi_op>(); + kernels::expectOutOfPlaceMultiBoxedCallingWorks(func); } -TEST(KernelFunctionTest, givenBoxedFunction_withMultipleTensorRefReturn_whenCallingBoxed_thenWorks) { - KernelFunction func = KernelFunction::makeFromBoxedFunction<&kernels::boxed_func_with_multiple_tensor_ref_return>(); - kernels::expectBoxedCallingWithMultipleTensorRefReturnWorks(func); +TEST(KernelFunctionTest, givenBoxedFunction_withLegacyOutOfPlaceMultiSignature_whenCallingBoxed_thenWorks) { + KernelFunction func = KernelFunction::makeFromBoxedFunction<&kernels::boxed_func_for_legacy_outofplace_multi_op>(); + kernels::expectLegacyOutOfPlaceMultiBoxedCallingWorks(func); } +// functional, unboxed calling + TEST(KernelFunctionTest, givenBoxedFunction_withReturn_whenCallingUnboxed_thenWorks) { KernelFunction func = KernelFunction::makeFromBoxedFunction<&kernels::boxed_func_with_return>(); kernels::expectUnboxedCallingWithReturnWorks(func); @@ -385,16 +500,30 @@ TEST(KernelFunctionTest, givenBoxedFunction_withMultiReturn_whenCallingUnboxed_t kernels::expectUnboxedCallingWithMultiReturnWorks(func); } -TEST(KernelFunctionTest, givenBoxedFunction_withTensorRefReturn_whenCallingUnboxed_thenWorks) { - KernelFunction func = KernelFunction::makeFromBoxedFunction<&kernels::boxed_func_with_tensor_ref_return>(); - kernels::expectUnboxedCallingWithTensorRefReturnWorks(func); +// in/out, unboxed calling + +TEST(KernelFunctionTest, givenBoxedFunction_withInPlaceSignature_whenCallingUnboxed_thenWorks) { + KernelFunction func = KernelFunction::makeFromBoxedFunction<&kernels::boxed_func_for_inplace_op>(); + kernels::expectInPlaceUnboxedCallingWorks(func); +} + +TEST(KernelFunctionTest, givenBoxedFunction_withOutOfPlaceSignature_whenCallingUnboxed_thenWorks) { + KernelFunction func = KernelFunction::makeFromBoxedFunction<&kernels::boxed_func_for_outofplace_op>(); + kernels::expectOutOfPlaceUnboxedCallingWorks(func); +} + +TEST(KernelFunctionTest, givenBoxedFunction_withOutOfPlaceMultiSignature_whenCallingUnboxed_thenWorks) { + KernelFunction func = KernelFunction::makeFromBoxedFunction<&kernels::boxed_func_for_outofplace_multi_op>(); + kernels::expectOutOfPlaceMultiUnboxedCallingWorks(func); } -TEST(KernelFunctionTest, givenBoxedFunction_withMultipleTensorRefReturn_whenCallingUnboxed_thenWorks) { - KernelFunction func = KernelFunction::makeFromBoxedFunction<&kernels::boxed_func_with_multiple_tensor_ref_return>(); - kernels::expectUnboxedCallingWithMultipleTensorRefReturnWorks(func); +TEST(KernelFunctionTest, givenBoxedFunction_withLegacyOutOfPlaceMultiSignature_whenCallingUnboxed_thenWorks) { + KernelFunction func = KernelFunction::makeFromBoxedFunction<&kernels::boxed_func_for_legacy_outofplace_multi_op>(); + kernels::expectLegacyOutOfPlaceMultiUnboxedCallingWorks(func); } +// functors etc. + TEST(KernelFunctionTest, givenUnboxedFunctor_withReturn_whenCallingBoxed_thenWorks) { KernelFunction func = KernelFunction::makeFromUnboxedFunctor(std::unique_ptr(std::make_unique())); kernels::expectBoxedCallingWithReturnWorks(func); @@ -415,26 +544,6 @@ TEST(KernelFunctionTest, givenUnboxedFunctor_withoutReturn_whenCallingUnboxed_th kernels::expectUnboxedCallingWithoutReturnWorks(func); } -TEST(KernelFunctionTest, givenUnboxedOnlyFunctor_withReturn_whenCallingBoxed_thenFails) { - KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunctor(std::unique_ptr(std::make_unique())); - kernels::expectBoxedCallingFailsWith(func, "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call()"); -} - -TEST(KernelFunctionTest, givenUnboxedOnlyFunctor_withoutReturn_whenCallingBoxed_thenFails) { - KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunctor(std::unique_ptr(std::make_unique())); - kernels::expectBoxedCallingFailsWith(func, "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call()"); -} - -TEST(KernelFunctionTest, givenUnboxedOnlyFunctor_withReturn_whenCallingUnboxed_thenWorks) { - KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunctor(std::unique_ptr(std::make_unique())); - kernels::expectUnboxedCallingWithReturnWorks(func); -} - -TEST(KernelFunctionTest, givenUnboxedOnlyFunctor_withoutReturn_whenCallingUnboxed_thenWorks) { - KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunctor(std::unique_ptr(std::make_unique())); - kernels::expectUnboxedCallingWithoutReturnWorks(func); -} - TEST(KernelFunctionTest, givenUnboxedFunction_withReturn_whenCallingBoxed_thenWorks) { KernelFunction func = KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernels::unboxed_function_with_return)); kernels::expectBoxedCallingWithReturnWorks(func); @@ -455,26 +564,6 @@ TEST(KernelFunctionTest, givenUnboxedFunction_withoutReturn_whenCallingUnboxed_t kernels::expectUnboxedCallingWithoutReturnWorks(func); } -TEST(KernelFunctionTest, givenUnboxedOnlyFunction_withReturn_whenCallingBoxed_thenFails) { - KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunction(TORCH_FN(kernels::unboxed_function_with_return)); - kernels::expectBoxedCallingFailsWith(func, "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call()"); -} - -TEST(KernelFunctionTest, givenUnboxedOnlyFunction_withoutReturn_whenCallingBoxed_thenFails) { - KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunction(TORCH_FN(kernels::unboxed_function_without_return)); - kernels::expectBoxedCallingFailsWith(func, "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call()"); -} - -TEST(KernelFunctionTest, givenUnboxedOnlyFunction_withReturn_whenCallingUnboxed_thenWorks) { - KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunction(TORCH_FN(kernels::unboxed_function_with_return)); - kernels::expectUnboxedCallingWithReturnWorks(func); -} - -TEST(KernelFunctionTest, givenUnboxedOnlyFunction_withoutReturn_whenCallingUnboxed_thenWorks) { - KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunction(TORCH_FN(kernels::unboxed_function_without_return)); - kernels::expectUnboxedCallingWithoutReturnWorks(func); -} - TEST(KernelFunctionTest, givenUnboxedRuntimeFunction_withReturn_whenCallingBoxed_thenWorks) { KernelFunction func = KernelFunction::makeFromUnboxedRuntimeFunction(&kernels::unboxed_function_with_return); kernels::expectBoxedCallingWithReturnWorks(func); diff --git a/aten/src/ATen/core/boxing/impl/boxing.h b/aten/src/ATen/core/boxing/impl/boxing.h index d40823555c653..4f9ae1fced708 100644 --- a/aten/src/ATen/core/boxing/impl/boxing.h +++ b/aten/src/ATen/core/boxing/impl/boxing.h @@ -71,44 +71,20 @@ using can_unbox = >; // -// BoxedKernelWrapper -// -// For a given function type FT, BoxedKernelWrapper implements -// -// 1. a `boxArgs` method that boxes the function's arguments - i.e., -// inserts each argument into an IValue that it pushes onto a -// torch::jit::Stack, which it returns -// -// 2. a `call` method that -// - takes a boxed kernel and unboxed arguments as specified by FT, -// - calls `boxArgs` to box the arguments -// - calls the boxed kernel -// - unboxes and returns the result -// -// The partial specializations below handle various cases: in -// particular, not all types appearing in op signatures are supported, -// and ops returning references have nonstandard wrapper implementations. -// - -// 1. The base specialization of BoxedKernelWrapper should never be instantiated. -// A "no call method defined on BoxedKernelWrapper" compile error means that -// an op signature has failed to trigger any of the partial specializations -// that follow this one. +// boxArgs - utility for pushing unboxed args onto IValue stack // -template -struct BoxedKernelWrapper { - static_assert(sizeof(FuncType) == -1, - "Function signature contains one or more unsupported parameter and/or return types. " - "Look for a nearby error like " - "\"'call' is not a member of 'c10::impl::BoxedKernelWrapper<(your function type), void>'\" " - "- (your function type) is the unsupported signature."); -}; +template +static torch::jit::Stack boxArgs(Args... args) { + // TODO Reuse stack vector instead of allocating? + torch::jit::Stack stack; + stack.reserve(sizeof...(Args)); + torch::jit::push(stack, std::forward(args)...); + return stack; +} // -// 2. Supported signatures, other than ref-passing. -// - -// helper class whose specializations handle single and multiple return values, respectively +// PopResult is a helper class whose specializations handle popping single and +// multiple return values, respectively. // template struct PopResult final { @@ -145,6 +121,46 @@ struct PopResult> final { } }; +// +// BoxedKernelWrapper +// +// For a given function type FT, BoxedKernelWrapper implements +// a `call` method that +// - takes a boxed kernel and unboxed arguments as specified by FT, +// - calls `boxArgs` to box the arguments +// - calls the boxed kernel +// - unboxes and returns the result +// +// The partial specializations below handle various cases: in +// particular, not all types appearing in op signatures are supported, +// and ops returning references have nonstandard wrapper implementations. +// + +// 1. The base specialization of BoxedKernelWrapper should never be instantiated. +// A "no call method defined on BoxedKernelWrapper" compile error means that +// an op signature has failed to trigger any of the partial specializations +// that follow this one. +// +template +struct BoxedKernelWrapper { + // The reason we're not just doing straight up static_assert(false, ...) here: + // Basically, the way to make sure a static_assert only fires if a template + // is actually instantiated (rather than every time the file is parsed) is to use + // template parameters in the expression, e.g. FuncType here. However, since + // `sizeof(FuncType) != sizeof(FuncType)` is always false, this has the same + // effect. + static_assert(sizeof(FuncType) != sizeof(FuncType), + "Function signature contains one or more unsupported parameter and/or return types. " + "Look for a nearby error like " + "\"'call' is not a member of 'c10::impl::BoxedKernelWrapper<(your function type), void>'\" " + "- (your function type) is the unsupported signature."); +}; + +// +// 2. Supported signatures, other than those involving non-const Tensor refs - +// i.e., "functional" ops. +// + template struct BoxedKernelWrapper< Result(Args...), @@ -153,14 +169,6 @@ struct BoxedKernelWrapper< void > > { - static torch::jit::Stack boxArgs(Args... args) { - // TODO Reuse stack vector instead of allocating? - torch::jit::Stack stack; - stack.reserve(sizeof...(Args)); - torch::jit::push(stack, std::forward(args)...); - return stack; - } - static Result call( KernelFunction::InternalBoxedKernelFunction* boxed_kernel_func, OperatorKernel* functor, @@ -188,12 +196,15 @@ struct BoxedKernelWrapper< }; // -// 3. signatures taking a single Tensor reference as their first argument, -// and also returning one. +// 3. in-place and legacy out-of-place ops take a single non-const Tensor +// reference as their first argument, and return it. +// +// Note: all signatures matching this pattern are are assumed to be for such ops. +// Because of this, the generated BoxedKernelWrapper specializations simply +// return the in-place argument. // -// Note that the passed kernels are assumed to be for inplace/outplace ops, -// and the generated BoxedKernelWrapper specializations will simply return -// the initial argument. +// TODO update comment when legacy out-of-place signatures no longer need +// to be supported, due to hacky_wrapper reordering // template @@ -201,21 +212,11 @@ struct BoxedKernelWrapper< at::Tensor&(at::Tensor&, OtherArgs...), std::enable_if_t::value, void> > { - static torch::jit::Stack boxArgs(at::Tensor& outArg, OtherArgs... otherArgs) { - // TODO Reuse stack vector instead of allocating? - torch::jit::Stack stack; - stack.reserve(1 + sizeof...(OtherArgs)); - torch::jit::push_one(stack, outArg); - torch::jit::push(stack, std::forward(otherArgs)...); - return stack; - } - static at::Tensor& call( KernelFunction::InternalBoxedKernelFunction* boxed_kernel_func, OperatorKernel* functor, const OperatorHandle& opHandle, - at::Tensor& outArg, - OtherArgs... otherArgs + at::Tensor& outArg, OtherArgs... otherArgs ) { torch::jit::Stack stack = boxArgs(outArg, otherArgs...); (*boxed_kernel_func)(functor, opHandle, &stack); @@ -230,30 +231,75 @@ struct BoxedKernelWrapper< }; // -// 4. signatures returning a tuple of Tensor references, and taking the same -// number of Tensor refs as their initial arguments. +// 4. out of place ops that take a single non-const Tensor reference as their +// final argument, and also return it. // -// Note that the passed kernels are assumed to be for inplace/outplace ops, -// and the generated BoxedKernelWrapper specializations will return a tuple -// of those initial arguments. +// Note: all signatures matching this pattern are are assumed to be for such ops. +// This assumption permits the generated BoxedKernelWrapper specializations to simply +// return out arguments. // +template +struct BoxedKernelWrapper< + at::Tensor&(FirstArg, RestArgs...), + std::enable_if_t< + can_box_all::value + // this skips over in-place (and legacy out-of-place) kernels with a non-const Tensor + // arg at the front, so those can unambiguously trigger the preceding specialization. + // TODO update comment when hacky_wrapper reorders legacy out-of-place signatures + && !is_mutable_tensor_ref::value, + void + > +> { + static at::Tensor& call( + KernelFunction::InternalBoxedKernelFunction* boxed_kernel_func, + OperatorKernel* functor, + const OperatorHandle& opHandle, + FirstArg firstArg, RestArgs... restArgs + ) { + torch::jit::Stack stack = boxArgs(firstArg, restArgs...); + (*boxed_kernel_func)(functor, opHandle, &stack); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stack.size() == 1, + "Boxed kernel was expected to return a single value on the stack, ", + "but instead returned ", stack.size(), " values." + ); + + return std::get(std::tuple{restArgs...}); + } +}; +// +// 5. out of place ops that take multiple non-const Tensor references as their +// final arguments, and return them in a std::tuple. +// +// Note: all signatures matching this pattern are are assumed to be for such ops. +// This assumption permits the generated BoxedKernelWrapper specializations to simply +// return the out arguments. +// template struct BoxedKernelWrapper< Result(Args...), std::enable_if_t< - can_box_all::value && is_tuple_of_mutable_tensor_refs::value, + can_box_all::value && is_tuple_of_mutable_tensor_refs::value + // this test skips over legacy kernels with out args at the front, so they can trigger + // the specialization that follows. + // note: this test is complicated by the fact that boolean value expressions in templates + // don't shortcut. some signatures have a result tuple that's wider than the arg list, and + // without the length limiting ternary these will cause a template evaluation error on this + // test, even if a length check precedes it in the conjunction. + // TODO remove when hacky_wrapper reorders legacy kernel out args + && !std::is_same< + Result, + guts::typelist::to_tuple_t< + guts::typelist::take_t< + guts::typelist::typelist, + sizeof...(Args) >= std::tuple_size::value ? std::tuple_size::value : sizeof...(Args) + > + > + >::value, void > > { - static torch::jit::Stack boxArgs(Args... args) { - // TODO Reuse stack vector instead of allocating? - torch::jit::Stack stack; - stack.reserve(sizeof...(Args)); - torch::jit::push(stack, std::forward(args)...); - return stack; - } - static Result call( KernelFunction::InternalBoxedKernelFunction* boxed_kernel_func, OperatorKernel* functor, @@ -271,15 +317,65 @@ struct BoxedKernelWrapper< "but instead returned ", stack.size(), " values." ); - auto result = guts::tuple_take(ArgTuple{args...}); + auto result = guts::tuple_take(ArgTuple{args...}); static_assert( std::is_same::value, "The parameter list of an op returning a tuple of Tensor references " - "must begin with an equal number of Tensor reference parameters." + "must end with an equal number of Tensor reference parameters." ); return result; } }; +// +// 6. legacy trap for old-school multi-return out functions with mutable args +// at start rather than end of arg list. +// TODO remove when hacky_wrapper reorders legacy kernel out args +// + +template +struct BoxedKernelWrapper< + Result(Args...), + std::enable_if_t< + can_box_all::value && is_tuple_of_mutable_tensor_refs::value + // this test fires passes for legacy kernels with out args at the front. + // note: this test is complicated by the fact that boolean value expressions in templates + // don't shortcut. some signatures have a result tuple that's wider than the arg list, and + // without the length limiting ternary these will cause a template evaluation error on this + // test, even if a length check precedes it in the conjunction. + && std::is_same< + Result, + guts::typelist::to_tuple_t< + guts::typelist::take_t< + guts::typelist::typelist, + sizeof...(Args) >= std::tuple_size::value ? std::tuple_size::value : sizeof...(Args) + > + > + >::value, + void + > +> { + static Result call( + KernelFunction::InternalBoxedKernelFunction* boxed_kernel_func, + OperatorKernel* functor, + const OperatorHandle& opHandle, + Args... args + ) { + using ArgTuple = std::tuple; + constexpr int RetCount = std::tuple_size(); + + torch::jit::Stack stack = boxArgs(args...); + (*boxed_kernel_func)(functor, opHandle, &stack); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stack.size() == RetCount, + "Boxed kernel was expected to return ", RetCount, " values on the stack, ", + "but instead returned ", stack.size(), " values." + ); + + auto legacy_result = guts::tuple_take(ArgTuple{args...}); + return legacy_result; + } +}; + } // impl } // c10 diff --git a/aten/src/ATen/core/boxing/impl/kernel_function_test.cpp b/aten/src/ATen/core/boxing/impl/kernel_function_test.cpp index 3671b686bc871..ab96d1ae03b11 100644 --- a/aten/src/ATen/core/boxing/impl/kernel_function_test.cpp +++ b/aten/src/ATen/core/boxing/impl/kernel_function_test.cpp @@ -75,26 +75,31 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenCatchAllKernel_whenRegis TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) { auto registrar = RegisterOperators() - .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU)) - .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDA)) - .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU)) - .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDA)); + .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)) + .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)); expectCallsIncrement(DispatchKey::CPU); } TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) { - auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU)); - auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDA)); - auto registrar3 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU)); - auto registrar4 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDA)); + auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)); + auto registrar2 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)); expectCallsIncrement(DispatchKey::CPU); } -TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) { + +TEST(NewOperatorRegistrationTest_FunctionBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) { { - auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU)); + auto m = MAKE_TORCH_LIBRARY(_test); + m.def("_test::my_op(Tensor dummy, int input) -> int"); + auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU); + m_cpu.impl("my_op", DispatchKey::CPU, TORCH_FN(incrementKernel)); { - auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDA)); + auto m_cuda = MAKE_TORCH_LIBRARY_IMPL(_test, CUDA); + m_cuda.impl("my_op", DispatchKey::CUDA, TORCH_FN(decrementKernel)); // assert that schema and cpu kernel are present expectCallsIncrement(DispatchKey::CPU); @@ -165,8 +170,8 @@ Tensor kernelWithTensorOutput(const Tensor& input) { TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::returning_tensor(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPU)) - .op("_test::returning_tensor(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CUDA)); + .op("_test::returning_tensor(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)); auto op = c10::Dispatcher::singleton().findSchema({"_test::returning_tensor", ""}); ASSERT_TRUE(op.has_value()); @@ -185,7 +190,7 @@ c10::List kernelWithTensorListOutput(const Tensor& input1, const Tensor& } TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) { - auto registrar = RegisterOperators() + auto registrar = RegisterOperators() .op("_test::list_output(Tensor input1, Tensor input2, Tensor input3) -> Tensor[]", RegisterOperators::options().kernel(DispatchKey::CUDA)); auto op = c10::Dispatcher::singleton().findSchema({"_test::list_output", ""}); @@ -262,8 +267,8 @@ Tensor kernelWithTensorInputByValueWithOutput(Tensor input1) { TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPU)) - .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CUDA)); + .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); @@ -279,8 +284,8 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByR TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPU)) - .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CUDA)); + .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); @@ -306,8 +311,8 @@ void kernelWithTensorInputByValueWithoutOutput(Tensor input1) { TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CPU)) - .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CUDA)); + .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); @@ -323,8 +328,8 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByR TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CPU)) - .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CUDA)); + .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); diff --git a/aten/src/ATen/core/boxing/impl/kernel_lambda_test.cpp b/aten/src/ATen/core/boxing/impl/kernel_lambda_test.cpp index 6000ca4188693..063379b15666c 100644 --- a/aten/src/ATen/core/boxing/impl/kernel_lambda_test.cpp +++ b/aten/src/ATen/core/boxing/impl/kernel_lambda_test.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include @@ -54,26 +55,30 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenOutOfLineKernel_whenRegist TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) { auto registrar = RegisterOperators() - .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU, [] (Tensor, int64_t i) {return i+1;})) - .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDA, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;})) - .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;})) - .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDA, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;})); + .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU, [] (Tensor, int64_t i) {return i+1;}) + .kernel(DispatchKey::CUDA, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;})) + .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}) + .kernel(DispatchKey::CUDA, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;})); expectCallsIncrement(DispatchKey::CPU); } TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) { - auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU, [] (Tensor, int64_t i) {return i+1;})); - auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDA, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;})); - auto registrar3 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;})); - auto registrar4 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDA, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;})); + auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU, [] (Tensor, int64_t i) {return i+1;}) + .kernel(DispatchKey::CUDA, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;})); + auto registrar3 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}) + .kernel(DispatchKey::CUDA, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;})); expectCallsIncrement(DispatchKey::CPU); } TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) { { - auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU, [] (Tensor, int64_t i) {return i+1;})); + auto m = MAKE_TORCH_LIBRARY(_test); + m.def("_test::my_op(Tensor dummy, int input) -> int"); + auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU); + m_cpu.impl("my_op", DispatchKey::CPU, [] (Tensor, int64_t i) {return i+1;}); { - auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDA, [] (Tensor, int64_t i) {return i-1;})); + auto m_cuda = MAKE_TORCH_LIBRARY_IMPL(_test, CUDA); + m_cuda.impl("my_op", DispatchKey::CUDA, [] (Tensor, int64_t i) {return i-1;}); // assert that schema and cpu kernel are present expectCallsIncrement(DispatchKey::CPU); @@ -132,9 +137,8 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntOutput_whenRe TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::returning_tensor(Tensor input) -> Tensor", - RegisterOperators::options().kernel(DispatchKey::CPU, [] (const Tensor& a) {return a;})) - .op("_test::returning_tensor(Tensor input) -> Tensor", - RegisterOperators::options().kernel(DispatchKey::CUDA, [] (const Tensor& a) {return a;})); + RegisterOperators::options().kernel(DispatchKey::CPU, [] (const Tensor& a) {return a;}) + .kernel(DispatchKey::CUDA, [] (const Tensor& a) {return a;})); auto op = c10::Dispatcher::singleton().findSchema({"_test::returning_tensor", ""}); ASSERT_TRUE(op.has_value()); @@ -216,9 +220,8 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithMultipleOutputs_ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::tensor_input(Tensor input) -> Tensor", - RegisterOperators::options().kernel(DispatchKey::CPU, [] (const Tensor& a) {return a;})) - .op("_test::tensor_input(Tensor input) -> Tensor", - RegisterOperators::options().kernel(DispatchKey::CUDA, [] (const Tensor& a) {return a;})); + RegisterOperators::options().kernel(DispatchKey::CPU, [] (const Tensor& a) {return a;}) + .kernel(DispatchKey::CUDA, [] (const Tensor& a) {return a;})); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); @@ -235,9 +238,8 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByRef TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::tensor_input(Tensor input) -> Tensor", - RegisterOperators::options().kernel(DispatchKey::CPU, [] (Tensor a) {return a;})) - .op("_test::tensor_input(Tensor input) -> Tensor", - RegisterOperators::options().kernel(DispatchKey::CUDA, [] (Tensor a) {return a;})); + RegisterOperators::options().kernel(DispatchKey::CPU, [] (Tensor a) {return a;}) + .kernel(DispatchKey::CUDA, [] (Tensor a) {return a;})); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); @@ -256,9 +258,8 @@ Tensor captured_input; TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::tensor_input(Tensor input) -> ()", - RegisterOperators::options().kernel(DispatchKey::CPU, [] (const Tensor& a) -> void {captured_input = a;})) - .op("_test::tensor_input(Tensor input) -> ()", - RegisterOperators::options().kernel(DispatchKey::CUDA, [] (const Tensor& a) -> void {captured_input = a;})); + RegisterOperators::options().kernel(DispatchKey::CPU, [] (const Tensor& a) -> void {captured_input = a;}) + .kernel(DispatchKey::CUDA, [] (const Tensor& a) -> void {captured_input = a;})); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); @@ -275,9 +276,8 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByRef TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::tensor_input(Tensor input) -> ()", - RegisterOperators::options().kernel(DispatchKey::CPU, [] (Tensor a) -> void {captured_input = a;})) - .op("_test::tensor_input(Tensor input) -> ()", - RegisterOperators::options().kernel(DispatchKey::CUDA, [] (Tensor a) -> void {captured_input = a;})); + RegisterOperators::options().kernel(DispatchKey::CPU, [] (Tensor a) -> void {captured_input = a;}) + .kernel(DispatchKey::CUDA, [] (Tensor a) -> void {captured_input = a;})); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); diff --git a/aten/src/ATen/core/boxing/impl/kernel_stackbased_test.cpp b/aten/src/ATen/core/boxing/impl/kernel_stackbased_test.cpp index 364345d9974ff..7e9c57585fa7f 100644 --- a/aten/src/ATen/core/boxing/impl/kernel_stackbased_test.cpp +++ b/aten/src/ATen/core/boxing/impl/kernel_stackbased_test.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include @@ -72,26 +73,30 @@ TEST(OperatorRegistrationTest_StackBasedKernel, givenKernel_whenRegistered_thenC TEST(OperatorRegistrationTest_StackBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) { auto registrar = RegisterOperators() - .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&incrementKernel>(DispatchKey::CPU)) - .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(DispatchKey::CUDA)) - .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(DispatchKey::CPU)) - .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(DispatchKey::CUDA)); + .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&incrementKernel>(DispatchKey::CPU) + .kernel<&errorKernel>(DispatchKey::CUDA)) + .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(DispatchKey::CPU) + .kernel<&errorKernel>(DispatchKey::CUDA)); expectCallsIncrement(DispatchKey::CPU); } TEST(OperatorRegistrationTest_StackBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) { - auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&incrementKernel>(DispatchKey::CPU)); - auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(DispatchKey::CUDA)); - auto registrar3 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(DispatchKey::CPU)); - auto registrar4 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(DispatchKey::CUDA)); + auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&incrementKernel>(DispatchKey::CPU) + .kernel<&errorKernel>(DispatchKey::CUDA)); + auto registrar2 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(DispatchKey::CPU) + .kernel<&errorKernel>(DispatchKey::CUDA)); expectCallsIncrement(DispatchKey::CPU); } TEST(OperatorRegistrationTest_StackBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) { { - auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&incrementKernel>(DispatchKey::CPU)); + auto m = MAKE_TORCH_LIBRARY(_test); + m.def("_test::my_op(Tensor dummy, int input) -> int"); + auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU); + m_cpu.impl("my_op", DispatchKey::CPU, torch::CppFunction::makeFromBoxedFunction()); { - auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&decrementKernel>(DispatchKey::CUDA)); + auto m_cuda = MAKE_TORCH_LIBRARY_IMPL(_test, CUDA); + m_cuda.impl("my_op", DispatchKey::CUDA, torch::CppFunction::makeFromBoxedFunction()); // assert that schema and cpu kernel are present expectCallsIncrement(DispatchKey::CPU); diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h index 7341dc7c9b0e3..b9f59b3192399 100644 --- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h @@ -26,7 +26,7 @@ class OperatorHandle; * * See below for how to register this kernel with PyTorch. */ -struct CAFFE2_API OperatorKernel { +struct TORCH_API OperatorKernel { virtual ~OperatorKernel() = default; }; @@ -119,14 +119,6 @@ namespace impl { "You tried to register a kernel with an unsupported input type: List. Please use List, List or Tensor instead."); }; - template - struct assert_is_valid_input_type, AllowDeprecatedTypes> - : assert_is_valid_input_type { - static_assert(!std::is_same::value, - "You tried to register a kernel with an unsupported input type: std::vector. Please use List, List or Tensor instead."); - // TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel with an unsupported input type: std::vector. Please use List instead."); - }; - template struct assert_is_valid_input_type, AllowDeprecatedTypes> : assert_is_valid_input_type { @@ -273,37 +265,36 @@ namespace impl { return ivalue_to_arg, AllowDeprecatedTypes>::call(std::move(v)); } }; - template - struct ivalue_to_arg>, AllowDeprecatedTypes> final { - // If an argument is optional>, convert the IValue to a optional> and pass that - // to the operator. - static OptionalArray call(IValue&& v) { - return std::move(v).toOptionalIntArray(); - } - }; - template - struct ivalue_to_arg>, AllowDeprecatedTypes> final { - // If an argument is optional>, convert the IValue to a optional> and pass that - // to the operator. - static OptionalArray call(IValue&& v) { - return std::move(v).toOptionalDoubleArray(); + template + struct ivalue_to_arg>, AllowDeprecatedTypes> final { + // If an argument is optional>, convert the IValue to an optional> and pass that + // to the operator. OptionalArray is basically a optional> but impliticly convertible + // to optional>. + static OptionalArray call(IValue&& v) { + return ivalue_to_arg, AllowDeprecatedTypes>::call(std::move(v)); } }; // return_to_ivalue + template + struct return_to_ivalue final {}; template - IValue return_to_ivalue(T&& v) { - assert_is_valid_output_type(); - return c10::ivalue::from(std::forward(v)); - } + struct return_to_ivalue::value>> final { + static IValue call(T&& v) { + assert_is_valid_output_type(); + return c10::ivalue::from(std::move(v)); + } + }; // Special case to allow kernels to return `Tensor&`. // TODO Delete this once kernels don't do that anymore - template<> - inline IValue return_to_ivalue(at::Tensor& v) { - return c10::ivalue::from(v); - } + template + struct return_to_ivalue final { + static IValue call(at::Tensor& v) { + return c10::ivalue::from(v); + } + }; // reference_cast allows casting references, e.g. T&& to T&: // T make_t() {} @@ -323,8 +314,6 @@ namespace impl { call_functor_with_args_from_stack_(Functor* functor, Stack* stack, std::index_sequence) { (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would be unused and we have to silence the compiler warning. - constexpr size_t num_ivalue_args = sizeof...(ivalue_arg_indices); - /* * For ops that take "Tensor&" as an argument, ivalue_to_arg would still return a "Tensor" by value * and C++ doesn't allow us to call (*functor) with a temporary "Tensor" when it expects "Tensor&". @@ -335,7 +324,7 @@ namespace impl { using ArgTypes = typename guts::infer_function_traits_t::parameter_types; return (*functor)(reference_cast>( ivalue_to_arg>, AllowDeprecatedTypes>::call( - std::move(torch::jit::peek(*stack, ivalue_arg_indices, num_ivalue_args)) + std::move(torch::jit::peek(*stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices))) ))...); } @@ -351,7 +340,7 @@ namespace impl { template struct push_outputs final { static void call(OutputType&& output, Stack* stack) { - torch::jit::push(*stack, return_to_ivalue(std::forward(output))); + torch::jit::push(*stack, return_to_ivalue::call(std::forward(output))); } }; template @@ -363,7 +352,7 @@ namespace impl { private: template static void call_(std::tuple&& output, Stack* stack, std::index_sequence) { - torch::jit::push(*stack, return_to_ivalue(std::move(std::get(output)))...); + torch::jit::push(*stack, return_to_ivalue::call(std::forward(std::get(output)))...); } }; template diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor_test.cpp b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor_test.cpp index c94c14325111b..b5c5e1415cb76 100644 --- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor_test.cpp +++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor_test.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include @@ -68,41 +69,21 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernel_whenRegistered_the TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) { auto registrar = RegisterOperators() - .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU)) - .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDA)) - .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU)) - .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDA)); + .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)) + .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)); expectCallsIncrement(DispatchKey::CPU); } TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) { - auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU)); - auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDA)); - auto registrar3 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU)); - auto registrar4 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDA)); + auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)); + auto registrar2 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)); expectCallsIncrement(DispatchKey::CPU); } -TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) { - { - auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU)); - { - auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDA)); - - // assert that schema and cpu kernel are present - expectCallsIncrement(DispatchKey::CPU); - expectCallsDecrement(DispatchKey::CUDA); - } - - // now registrar2 is destructed. Assert that schema is still present but cpu kernel is not - expectCallsIncrement(DispatchKey::CPU); - expectDoesntFindKernel("_test::my_op", DispatchKey::CUDA); - } - - // now both registrars are destructed. Assert that the whole schema is gone - expectDoesntFindOperator("_test::my_op"); -} - bool was_called = false; struct KernelWithoutOutput final : OperatorKernel { @@ -166,8 +147,8 @@ struct KernelWithTensorOutput final : OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::returning_tensor(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPU)) - .op("_test::returning_tensor(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CUDA)); + .op("_test::returning_tensor(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)); auto op = c10::Dispatcher::singleton().findSchema({"_test::returning_tensor", ""}); ASSERT_TRUE(op.has_value()); @@ -273,8 +254,8 @@ struct KernelWithTensorInputByValueWithOutput final : OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPU)) - .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CUDA)); + .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); @@ -290,8 +271,8 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByRe TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPU)) - .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CUDA)); + .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); @@ -321,8 +302,8 @@ struct KernelWithTensorInputByValueWithoutOutput final : OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CPU)) - .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CUDA)); + .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); @@ -338,8 +319,8 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByRe TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CPU)) - .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CUDA)); + .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CPU) + .kernel(DispatchKey::CUDA)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); @@ -609,8 +590,8 @@ class KernelWithConstructorArg final : public OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithConstructorArg_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::offset_op(Tensor tensor, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU, 2)) - .op("_test::offset_op(Tensor tensor, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDA, 4)); + .op("_test::offset_op(Tensor tensor, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU, 2) + .kernel(DispatchKey::CUDA, 4)); auto op = c10::Dispatcher::singleton().findSchema({"_test::offset_op", ""}); ASSERT_TRUE(op.has_value()); @@ -639,8 +620,8 @@ class KernelWithMultipleConstructorArgs final : public OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithMultipleConstructorArgs_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::offset_op(Tensor tensor, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU, 2, 3)) - .op("_test::offset_op(Tensor tensor, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDA, 4, 5)); + .op("_test::offset_op(Tensor tensor, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU, 2, 3) + .kernel(DispatchKey::CUDA, 4, 5)); auto op = c10::Dispatcher::singleton().findSchema({"_test::offset_op", ""}); ASSERT_TRUE(op.has_value()); diff --git a/aten/src/ATen/core/builtin_function.h b/aten/src/ATen/core/builtin_function.h index b4804cfebcbe4..de30f9b7e179f 100644 --- a/aten/src/ATen/core/builtin_function.h +++ b/aten/src/ATen/core/builtin_function.h @@ -1,7 +1,11 @@ #pragma once -#include #include +#include +#include +#include +#include +#include namespace torch { namespace jit { @@ -10,13 +14,19 @@ struct BuiltinOpFunction : public Function { BuiltinOpFunction( c10::QualifiedName qualname, c10::FunctionSchema schema, - std::function callable) + std::function callable, + std::string doc_string = "") : name_(std::move(qualname)), callable_(std::move(callable)), - schema_(std::move(schema)) { + schema_(std::move(schema)), + doc_string_(std::move(doc_string)) { TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1); } + const std::string& doc_string() const override { + return doc_string_; + } + bool isGraphFunction() const override { return false; } @@ -29,7 +39,9 @@ struct BuiltinOpFunction : public Function { callable_(stack); } - c10::intrusive_ptr runAsync(Stack& stack) override { + c10::intrusive_ptr runAsync( + Stack& stack, + TaskLauncher /* not used */) override { run(stack); auto res = c10::make_intrusive(stack.front().type()); res->markCompleted(std::move(stack.front())); @@ -93,8 +105,17 @@ struct BuiltinOpFunction : public Function { } std::string pretty_print_schema() const override { + #ifdef __NVCC__ + // Disable the "statement is unreachable" warning + #pragma diag_suppress code_is_unreachable + #endif + TORCH_INTERNAL_ASSERT(false); return ""; + + #ifdef __NVCC__ + #pragma diag_default code_is_unreachable + #endif } Function& setSchema(c10::FunctionSchema schema) override { @@ -110,6 +131,8 @@ struct BuiltinOpFunction : public Function { std::function callable_; c10::FunctionSchema schema_; + + std::string doc_string_; }; } // namespace jit diff --git a/aten/src/ATen/core/dispatch/CppSignature.h b/aten/src/ATen/core/dispatch/CppSignature.h index 9cfc7b33a4ab1..b5a41ca542356 100644 --- a/aten/src/ATen/core/dispatch/CppSignature.h +++ b/aten/src/ATen/core/dispatch/CppSignature.h @@ -10,7 +10,7 @@ namespace impl { // A CppSignature object holds RTTI information about a C++ function signature at runtime // and can compare them or get a debug-printable name. -class CAFFE2_API CppSignature final { +class TORCH_API CppSignature final { public: CppSignature(const CppSignature&) = default; CppSignature(CppSignature&&) noexcept = default; diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h index f8d401e454310..1bc1f1d819db0 100644 --- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h +++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h @@ -102,7 +102,7 @@ namespace detail { * varies from operator, as some operators may have overridden the * fallthrough with custom behavior. */ -struct CAFFE2_API DispatchKeyExtractor final { +struct TORCH_API DispatchKeyExtractor final { public: static DispatchKeyExtractor make(const FunctionSchema& schema) { return DispatchKeyExtractor(makeBitsetForDispatchArgs(schema)); diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp index 6b4774d8f675d..270cffaf6d1ff 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.cpp +++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp @@ -134,13 +134,11 @@ RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::strin OperatorName op_name = schema.operator_name(); auto op = findOrRegisterName_(op_name); - if (op.operatorIterator_->def_count == 0) { - // NB: registerSchema is not idempotent! Only do it once! - op.operatorIterator_->op.registerSchema(std::move(schema), std::move(debug)); - listeners_->callOnOperatorRegistered(op); - } else { - checkSchemaCompatibility(op, schema, debug); - } + TORCH_CHECK(op.operatorIterator_->def_count == 0, "Tried to register an operator (", schema, ") with the same name and overload name multiple times.", + " Each overload's schema should only be registered with a single call to def().", + " Duplicate registration: ", debug, ". Original registration: ", op.operatorIterator_->op.debug()); + op.operatorIterator_->op.registerSchema(std::move(schema), std::move(debug)); + listeners_->callOnOperatorRegistered(op); // NB: do not increment the counts until AFTER error checking ++op.operatorIterator_->def_count; @@ -151,25 +149,6 @@ RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::strin }); } -void Dispatcher::checkSchemaCompatibility(const OperatorHandle& op, const FunctionSchema& schema, const std::string& debug) { - TORCH_CHECK(op.schema() == schema, "Tried to register multiple operators with the same name and the same overload name but different schemas: ", schema, " (", debug, ") vs ", op.schema(), " (", op.debug(), ")"); - if (schema.isDefaultAliasAnalysisKind()) { - // [BACKWARDS COMPAT] If the *new* schema is the default alias analysis - // kind, for BC, we will accept it. If we don't accept it, most extensions - // that override existing operators will stop working (as they generally did - // not specify alias information). - } else if (op.schema().isDefaultAliasAnalysisKind()) { - // [BACKWARDS COMPAT] If you POST-FACTO specify a non-default alias analysis - // kind after we already have a schema for a function, bong it in for BC - // reasons. - op.operatorIterator_->op.updateSchemaAliasAnalysis(schema.aliasAnalysis()); - } else { - TORCH_CHECK(op.schema().aliasAnalysis() == schema.aliasAnalysis(), - "Tried to define the schema for ", toString(op.operator_name()), " with different alias analysis kinds: ", - toString(op.schema().aliasAnalysis()), " (", op.debug(), ") vs ", toString(schema.aliasAnalysis()), " (", debug, ")"); - } -} - void Dispatcher::deregisterDef_(const OperatorHandle& op, const OperatorName& op_name) { // we need a lock to avoid concurrent writes std::lock_guard lock(mutex_); @@ -316,10 +295,16 @@ void Dispatcher::checkInvariants() const { } } -void Dispatcher::setManuallyBoxedKernelFor_(const OperatorHandle& op, KernelFunction::InternalBoxedKernelFunction* func) { - std::lock_guard lock(mutex_); - op.operatorIterator_->op.setManuallyBoxedKernel_(*this, func); - // NB: Do not need to set manually boxed kernel for backend fallbacks +std::vector Dispatcher::findDanglingImpls() const { + return operatorLookupTable_.read([&] (const ska::flat_hash_map& operatorLookupTable) -> std::vector { + std::vector opsWithDanglingImpls; + for (const auto& op : operatorLookupTable) { + if (!op.second.hasSchema()) { + opsWithDanglingImpls.push_back(op.second); + } + } + return opsWithDanglingImpls; + }); } } diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 36eadd7c192d3..d83653f753631 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -16,7 +16,7 @@ namespace c10 { -class CAFFE2_API OperatorHandle; +class TORCH_API OperatorHandle; template class TypedOperatorHandle; /** @@ -27,7 +27,7 @@ template class TypedOperatorHandle; * NB: registration events only occur when a 'def' occurs; we don't trigger * on 'impl' or 'fallback' calls. */ -class CAFFE2_API OpRegistrationListener { +class TORCH_API OpRegistrationListener { public: virtual ~OpRegistrationListener(); @@ -45,7 +45,7 @@ class SchemaRegistrationHandleRAII; * Most end users shouldn't use this directly; if you're trying to register * ops look in op_registration */ -class CAFFE2_API Dispatcher final { +class TORCH_API Dispatcher final { private: // For direct access to backend fallback information friend class impl::OperatorEntry; @@ -182,12 +182,6 @@ class CAFFE2_API Dispatcher final { */ RegistrationHandleRAII registerLibrary(std::string ns, std::string debug); - // This function is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed - // unboxing wrapper for aten operators. We still need those for some operators because not all work - // with the templated unboxing logic yet. - // TODO Delete setBoxedKernelFor_ once all operators work with the templated boxing logic - void setManuallyBoxedKernelFor_(const OperatorHandle& op, KernelFunction::InternalBoxedKernelFunction* func); - // ------------------------------------------------------------------------ // // Listeners on registrations @@ -211,6 +205,28 @@ class CAFFE2_API Dispatcher final { return dispatch_key != DispatchKey::BackendSelect; } + // + // ------------------------------------------------------------------------ + // + // Assertions + // + // ------------------------------------------------------------------------ + + /** + * For testing purposes. + * Returns a list of all operators that were created through calls to registerImpl(), + * without any corresponding calls to registerDef(). After static initialization + * is done this is almost certainly a bug, as the created OperatorHandle won't have + * any schema associated with it and users calling the op through the dispatcher + * won't be able to access it + * + * Note that we cannot enforce this invariant "as we go" during static initialization, + * due to undefined static initialization order- we have no guarantees over the order + * in which .def() and .impl() calls are registered in the dispatcher at static + * initialization time. So this function should only be called after static initialization. + */ + std::vector findDanglingImpls() const; + private: Dispatcher(); @@ -245,7 +261,7 @@ class CAFFE2_API Dispatcher final { * This handle can be used to register kernels with the dispatcher or * to lookup a kernel for a certain set of arguments. */ -class CAFFE2_API OperatorHandle { +class TORCH_API OperatorHandle { public: OperatorHandle(OperatorHandle&&) noexcept = default; OperatorHandle& operator=(OperatorHandle&&) noexcept = default; @@ -288,7 +304,9 @@ class CAFFE2_API OperatorHandle { // smuggle in a kernel that is typed incorrectly). For everything // in core library this won't happen, because all the static registrations // will be done by the time a typed() handle is acquired. +#if !defined C10_MOBILE operatorIterator_->op.assertSignatureIsCorrect(); +#endif return TypedOperatorHandle(operatorIterator_); } @@ -349,28 +367,39 @@ inline Return Dispatcher::callWithDispatchKey(const TypedOperatorHandleop.lookup(dispatchKey); #ifndef PYTORCH_DISABLE_PER_OP_PROFILING - // Check if we need to run callbacks registered with RecordFunction - // If true and callbacks need inputs, we box the arguments and pass - // them into the callbacks and also into the kernel call - - // Note: for perf reasons we wouldn't want to pass arguments into - // the function call or prematurely box them - at::RecordFunction guard(at::RecordScope::FUNCTION); - if (C10_UNLIKELY(guard.active)) { - if (shouldRecord(dispatchKey) && op.operatorIterator_->op.isObserved()) { - int64_t seq_num = -1; - // Setting sequence number in the Autograd case to associate - // the forward range with the coresponding Autograd's node - if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) { - seq_num = at::sequence_number::peek(); - } - if (guard.needs_inputs) { - torch::jit::Stack stack = impl::BoxedKernelWrapper::boxArgs(args...); - guard.before(op.schema().name(), stack, seq_num); - } else { - guard.before(op.schema().name(), seq_num); + // By default, when there're no high-frequency or non-sampled callbacks, + // RecordFunction is pre-sampled as a perf optimization; + // shouldRunRecordFunction checks whether RecordFunction should be executed, + // and sets pre_sampled boolean argument value to whether pre-sampling was used - + // this boolean is passed into RecordFunction to adjust the sampling rates of + // the callbacks + bool pre_sampled = false; + if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) { + // Check if we need to run callbacks registered with RecordFunction + // If true and callbacks need inputs, we box the arguments and pass + // them into the callbacks and also into the kernel call + + // Note: for perf reasons we wouldn't want to pass arguments into + // the function call or prematurely box them + at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled); + if (C10_UNLIKELY(guard.isActive())) { + if (shouldRecord(dispatchKey) && op.operatorIterator_->op.isObserved()) { + int64_t seq_num = -1; + // Setting sequence number in the Autograd case to associate + // the forward range with the coresponding Autograd's node + if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) { + seq_num = at::sequence_number::peek(); + } + if (guard.needsInputs()) { + torch::jit::Stack stack = impl::boxArgs(args...); + guard.before(op, stack, seq_num); + } else { + guard.before(op, seq_num); + } } } + // keeping the guard alive while executing the kernel + return kernel.template call(op, std::forward(args)...); } #endif // PYTORCH_DISABLE_PER_OP_PROFILING return kernel.template call(op, std::forward(args)...); @@ -407,20 +436,26 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const const auto& kernel = entry.lookup(dispatchKey); #ifndef PYTORCH_DISABLE_PER_OP_PROFILING - // using already existing stack to record function execution in observers - at::RecordFunction guard(at::RecordScope::FUNCTION); - if (C10_UNLIKELY(guard.active)) { - if (shouldRecord(dispatchKey) && entry.isObserved()) { - int64_t seq_num = -1; - if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) { - seq_num = at::sequence_number::peek(); - } - if (guard.needs_inputs) { - guard.before(op.schema().name(), *stack, seq_num); - } else { - guard.before(op.schema().name(), seq_num); + bool pre_sampled = false; + if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) { + // using already existing stack to record function execution in observers + at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled); + if (C10_UNLIKELY(guard.isActive())) { + if (shouldRecord(dispatchKey) && entry.isObserved()) { + int64_t seq_num = -1; + if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) { + seq_num = at::sequence_number::peek(); + } + if (guard.needsInputs()) { + guard.before(op, *stack, seq_num); + } else { + guard.before(op, seq_num); + } } } + // keeping the guard alive while executing the kernel + kernel.callBoxed(op, stack); + return; } #endif // PYTORCH_DISABLE_PER_OP_PROFILING kernel.callBoxed(op, stack); diff --git a/aten/src/ATen/core/dispatch/ObservedOperators.h b/aten/src/ATen/core/dispatch/ObservedOperators.h index 45db9d126d761..b8919d06cdf62 100644 --- a/aten/src/ATen/core/dispatch/ObservedOperators.h +++ b/aten/src/ATen/core/dispatch/ObservedOperators.h @@ -4,7 +4,7 @@ namespace c10 { -struct CAFFE2_API ObservedOperators { +struct TORCH_API ObservedOperators { ObservedOperators() = delete; static bool isObserved(const OperatorName& name); diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index 0942659d29607..7c3698beeb064 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -21,9 +21,7 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name) , schema_() , dispatchTable_() , dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized()) -, manuallyBoxedKernel_() , kernels_() -, catchAllKernel_() , cpp_signature_() , is_observed_(ObservedOperators::isObserved(name_)) { @@ -37,9 +35,13 @@ namespace { c10::optional schema_difference = findSchemaDifferences(from_def, inferred); if (schema_difference.has_value()) { TORCH_CHECK(false, - "In registration for ", toString(name), ": expected schema of operator to be \"", toString(from_def), "\" (", from_def_debug, "), ", - "but got inferred schema \"", toString(inferred), "\" (", inferred_debug, "). ", - *schema_difference); + "Inferred operator schema for a C++ kernel function doesn't match the expected function schema.\n" + " operator: ", toString(name), "\n", + " expected schema: ", toString(from_def), "\n", + " ", from_def_debug, "\n", + " inferred schema: ", toString(inferred), "\n", + " ", inferred_debug, "\n", + " reason: ", *schema_difference); } } } // anonymous namespace @@ -56,11 +58,6 @@ void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug) } } } - for (auto j = catchAllKernel_.begin(); j != catchAllKernel_.end(); ++j) { - if (j->inferred_function_schema != nullptr) { - checkSchema(name_, schema, debug, *j->inferred_function_schema, j->debug); - } - } // NB: don't register schema until after we've checked everything! dispatchKeyExtractor_.registerSchema(schema); schema_ = AnnotatedSchema(std::move(schema), std::move(debug)); @@ -89,13 +86,19 @@ std::list::iterator OperatorEntry::registerKernel( // that would also invalidate the old TypedOperatorHandles. if (cpp_signature.has_value()) { if (cpp_signature_.has_value()) { - TORCH_INTERNAL_ASSERT(*cpp_signature == *cpp_signature_, - "Tried to register a kernel (", debug, ") for operator ", name_," for dispatch key ", toString(dispatch_key), - ", but the C++ function signature ", cpp_signature->name(), " mismatched with a previous kernel that had the signature ", - cpp_signature_->name() + TORCH_CHECK(*cpp_signature == cpp_signature_->signature, + "\nMismatch in kernel C++ signatures\n", + " operator: ", (this->schema_.has_value() ? toString(this->schema_->schema) : toString(name_)), "\n", + " ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n", + " kernel 1: ", cpp_signature_->signature.name(), "\n", + " dispatch key: ", toString(cpp_signature_->dispatch_key), "\n", + " ", cpp_signature_->debug, "\n", + " kernel 2: ", cpp_signature->name(), "\n", + " dispatch key: ", toString(dispatch_key), "\n", + " ", debug, "\n" ); } else { - cpp_signature_ = *cpp_signature; + cpp_signature_ = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key }; } } @@ -105,14 +108,17 @@ std::list::iterator OperatorEntry::registerKernel( // Add the kernel to the kernels list, // possibly creating the list if this is the first kernel. - auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : catchAllKernel_; + // Redirect catchAll registrations to Math. + auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::Math]; if (k.size() > 0) { - TORCH_WARN("Registering a kernel (", debug, ") for operator ", name_, " for dispatch key ", toString(dispatch_key), " that overwrote a previously registered kernel with the same dispatch key for the same operator."); - } - - if (manuallyBoxedKernel_.has_value()) { - kernel.setManuallyBoxedKernel_(*manuallyBoxedKernel_); + TORCH_WARN("Overriding a previously registered kernel for the same operator and the same dispatch key\n", + " operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n", + " ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n", + " dispatch key: ", toString(dispatch_key), "\n", + " previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : "no debug info"), "\n", + " new kernel: ", debug + ); } k.emplace_front(std::move(kernel), std::move(inferred_function_schema), std::move(debug)); @@ -132,20 +138,17 @@ void OperatorEntry::deregisterKernel_( c10::optional dispatch_key, std::list::iterator kernel ) { - if (dispatch_key.has_value()) { - auto found = kernels_.find(*dispatch_key); - TORCH_INTERNAL_ASSERT(found != kernels_.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator is ", toString(name_)); - auto& k = found->second; - k.erase(kernel); - if (k.empty()) { - // the invariant says we don't want empty lists but instead remove the list from the map - kernels_.erase(found); - } - updateDispatchTable_(dispatcher, *dispatch_key); - } else { - catchAllKernel_.erase(kernel); - updateDispatchTableFull_(dispatcher); + // Redirect catchAll deregistrations to Math. + DispatchKey dk = dispatch_key.has_value() ? *dispatch_key : DispatchKey::Math; + auto found = kernels_.find(dk); + TORCH_INTERNAL_ASSERT(found != kernels_.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator is ", toString(name_)); + auto& k = found->second; + k.erase(kernel); + if (k.empty()) { + // the invariant says we don't want empty lists but instead remove the list from the map + kernels_.erase(found); } + updateDispatchTable_(dispatcher, dk); } void OperatorEntry::updateFallback(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) { @@ -156,7 +159,7 @@ const KernelFunction& OperatorEntry::computeDispatchTableEntry(const c10::Dispat return computeDispatchTableEntryWithDebug(dispatcher, dispatch_key).first.kernel; } -bool OperatorEntry::hasKernelForDispatchKeySet(DispatchKeySet ks) const { +bool OperatorEntry::hasKernelForAnyDispatchKey(DispatchKeySet ks) const { TORCH_INTERNAL_ASSERT(kernels_.find(DispatchKey::Undefined) == kernels_.end()); for (auto& kv : kernels_) { if (ks.has(kv.first)) return true; @@ -175,95 +178,104 @@ c10::optional OperatorEntry::getKernelForDispatchKey(Dis } std::pair OperatorEntry::computeDispatchTableEntryWithDebug(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const { - auto dispatch_ix = static_cast(dispatch_key); // [Note] DispatchTable computation // dispatchTable contains entries for runtime dispatch keys. // For any dispatch key, it'll pick a kernel using the following order: // (1) Use kernel if it's directly registered to this key // (2) Handle runtime keys that have kernels available from alias keys - // (2.1) Use kernel from DispatchKey::Math if available. + // (2.1) Use kernel from DispatchKey::DefaultBackend if available. + // This is used to register a kernel that works for all backend in inference. But it requires + // separate registration for Autograd keys to support training. + // (2.2) Use kernel from DispatchKey::Math if available. // For autograd keys, we only use kernel from Math when there's no direct registration - // to its corresponding backend key. + // to its corresponding backend key or DefaultBackend. See Note [DefaultBackend and Math]. // For AutogradOther, we eagerly return ambiguousAutogradOtherKernel_ if there's registration to any of // its backends and ask backend extender to request a decicated Autograd key for the backend. // See Note [Ambiguity in AutogradOther kernel] for more details. - // (2.2) Use kernel from DispatchKey::Autograd if available - // (2.3) Special logic to handle catchAll for Autograd keys - // For autograd backend keys, we use kernel from alias Math key (catchAll will be moved to Math) - // if there's no direct registration to the backend key. - // Tensor factory functions used to have no registration to Autograd key but only to catchAll. - // In the past we directly call into backends(filled with catchAll) after BackendSelect. - // Now that we first call Autograd backend keys after BackendSelect, we should fill those - // with catchAll as well. - // The implementation of (2.1) & (2.3) relies on the invariant that for a given backend, + // A DefaultBackend kernel prevents Math kernel being used for Autograd keys, but it doesn't + // cause confusion for AutogradOther. It's pretty straightforward to use Autograd (if available) + // in this case. + // (2.3) Use kernel from DispatchKey::Autograd if available + // The implementation of (2.2) relies on the invariant that for a given backend, // `computeDispatchTableEntryWithDebug()` will be called for that backend's autograd key after the // backend key. See Note [Refresh Runtime Autograd entries in dispatchTable_] // (3) Use fallthrough kernel that are registered as fallback. - // (4) Use catchAll kernel if available // Alias Key Precedence: - // Math > Autograd + // DefaultBackend > Math > Autograd + // Note [DefaultBackend and Math] + // When there're registrations to both DefaultBackend & Math & Autograd, from (2.2) we know DefaultBackend + // and Autograd kernels will be picked up and Math is overriden. + // This is fine and in practice DefaultBackend and Math shouldn't co-exist for an op. // TODO: Update alias key precedence after we add new alias keys AutogradDispatchCPUOrCUDA . - // TODO: we can remove (2.3) and (4) after TypeDefault registrations are moved from catchAll to Math - // so that Math can populate to Autograd backend keys before fallback kernels. // 1. Operator registration if (auto direct_registration = getKernelForDispatchKey(dispatch_key)) { return {*direct_registration.value(), "kernel"}; } - bool is_autograd_key_with_backend_kernel = - hasKernelForDispatchKeySet(getBackendKeySetFromAutograd(dispatch_key)); - // 2.1. Use Math kernel if available. For autograd keys, we only use kernel from Math - // when there's no direct registration to its corresponding backend key. + // 2.1 Use DefaultBackend kernel if available. + // See Note [Undefined in dispatchTable_] for the special handling for Undefined. + if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::DefaultBackend)) { + if (auto default_backend_registration = getKernelForDispatchKey(DispatchKey::DefaultBackend)) { + return {*default_backend_registration.value(), "default backend kernel"}; + } + } + + // Note when there's direct registration to DefaultBackend, this code path will only be hit by + // non backend keys (e.g AutogradXXX, Batched etc) due to (2.1). + bool has_backend_kernel = + hasKernelForAnyDispatchKey(getBackendKeySetFromAutograd(dispatch_key).add(DispatchKey::DefaultBackend)); + + // 2.2. Use Math kernel if available. For autograd keys, we only use kernel from Math + // when there's no direct registration to its corresponding backend key or DefaultBackend. // For AutogradOther, we return ambiguousAutogradOtherKernel_ if there's registration // to any of its backends. - if (isIncludedInAlias(dispatch_key, DispatchKey::Math)) { + // See Note [Undefined in dispatchTable_] for the special handling for Undefined. + if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::Math)) { if (auto math_registration = getKernelForDispatchKey(DispatchKey::Math)) { - if (dispatch_key == DispatchKey::AutogradOther && is_autograd_key_with_backend_kernel) { + if (dispatch_key == DispatchKey::AutogradOther + && hasKernelForAnyDispatchKey(c10::autogradother_backends)) { return {ambiguousAutogradOtherKernel_, "ambiguous autogradother"}; - } else if (!is_autograd_key_with_backend_kernel) { + } else if (!has_backend_kernel) { return {*math_registration.value(), "math kernel"}; } } } - // 2.2. For autograd backend keys, use kernel from DispatchKey::Autograd if available + // 2.3. For autograd backend keys, use kernel from DispatchKey::Autograd if available if (isIncludedInAlias(dispatch_key, DispatchKey::Autograd)) { if (auto autograd_registration = getKernelForDispatchKey(DispatchKey::Autograd)) { return {*autograd_registration.value(), "autograd kernel"}; } } - // 2.3. For autograd backend keys, we use kernel from catchAll if there's no direct - // registration to the backend key. Once CatchAll is moved to Math, this should - // fit 2.1 and we can remove 2.3 entirely. - if (isIncludedInAlias(dispatch_key, DispatchKey::Autograd) - && !is_autograd_key_with_backend_kernel && !catchAllKernel_.empty()) { - TORCH_INTERNAL_ASSERT(catchAllKernel_.front().kernel.isValid()); - return {catchAllKernel_.front(), "catch all"}; - } - // 3. Backend fallback + auto dispatch_ix = static_cast(dispatch_key); if (dispatcher.backendFallbackKernels_[dispatch_ix].kernel.isValid()) { return {dispatcher.backendFallbackKernels_[dispatch_ix], "backend fallback"}; - - // 4. Catch all - } else if (!catchAllKernel_.empty()) { - TORCH_INTERNAL_ASSERT(catchAllKernel_.front().kernel.isValid()); - return {catchAllKernel_.front(), "catch all"}; - - // 5. Default to error - } else { - return {missingKernel_, "missing"}; } + + // 4. Default to error + return {missingKernel_, "missing"}; } +// synchronizes the dispatch table entry for a given dispatch key +// with the current state of kernel registrations in the dispatcher. +// note that this is not a complete update, due to relationships between +// dispatch keys (e.g. runtime keys and their associated autograd keys, +// or alias keys and their associated keysets). +// This function should be considered a private helper for updateDispatchTable_() void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) { auto dispatch_ix = static_cast(dispatch_key); dispatchTable_[dispatch_ix] = computeDispatchTableEntry(dispatcher, dispatch_key); dispatchKeyExtractor_.setOperatorHasFallthroughForKey(dispatch_key, dispatchTable_[dispatch_ix].isFallthrough()); } +// synchronizes the dispatch table entries for a given dispatch key *and its +// associated keys* with the current state of kernel registrations in the +// dispatcher. +// After a kernel has been registered to a dispatch key, a call to this +// function will synchronize the dispatcher state. See e.g. registerKernel() void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) { // Handle Undefined separately since it isn't a runtime key but we have an entry in dispatchTable_. // See Note [Undefined in dispatchTable_] @@ -274,41 +286,46 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp for (auto k : c10::getRuntimeDispatchKeySet(dispatch_key)) { updateDispatchTableEntry_(dispatcher, k); } + // Registration to DefaultBackend and Math should be populated to Undefined. + // We cannot do this above since Undefined cannot be represented in DispatchKeySet. + if (dispatch_key == DispatchKey::Math || dispatch_key == DispatchKey::DefaultBackend) { + updateDispatchTableEntry_(dispatcher, DispatchKey::Undefined); + } // Note [Refresh Runtime Autograd entries in dispatchTable_] // Registering to backend key might affect computed entry at its Autograd backend key due to (2.1) & (2.3). - DispatchKey autograd_key = getAutogradKeyFromBackend(dispatch_key); - updateDispatchTableEntry_(dispatcher, autograd_key); + if (c10::isBackendDispatchKey(dispatch_key)) { + DispatchKey autograd_key = getAutogradKeyFromBackend(dispatch_key); + updateDispatchTableEntry_(dispatcher, autograd_key); + } } +// does a complete update of the dispatch table, synchronizing all +// runtime dispatch keys with the current state of kernel registrations +// in the dispatcher. +// Note that we use updateDispatchTable_() to perform our per-key updating, +// even though that function is equipped to handle out-of-order updates and +// alias key updates, neither of which we send it. This is deliberate - the +// current design is more tractable with all updates funneled through a single +// per-key update mechanism, than with multiple variations that assume different +// invariants. +// void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher) { // Note [Undefined in dispatchTable_] + // DispatchKey Undefined is used in runtime: // (1) it gives people place to specify functionality that should run when there are no dispatch keys, - // e.g., an empty TensorList argument + // e.g., an op without Tensor inputs or empty TensorList arguments // (2) it would let us remove the explicit error checking code in the dispatch hotpath, and so when // no dispatch keys are available we just slide into the undefined handler which would then raise - // the error message./ + // the error message. + // In the old world of catchAll, the only way to "register" a kernel to Undefined is by registering it to + // catchAll. After catchAllKernel_ is removed, Undefined now can get a kernel from either DefaultBackend + // or Math alias key so that we don't break the support. Ideally isIncludedInAlias(Undefined, Math) + // should return true, it returns false because Undefined cannot be represented in a DispatchKeySet. for (uint8_t iter = 0; iter != static_cast(DispatchKey::NumDispatchKeys); ++iter) { updateDispatchTable_(dispatcher, static_cast(iter)); } } -void OperatorEntry::setManuallyBoxedKernel_(const c10::Dispatcher& dispatcher, KernelFunction::InternalBoxedKernelFunction* func) { - TORCH_INTERNAL_ASSERT(!manuallyBoxedKernel_); - manuallyBoxedKernel_ = func; - - for (auto& kv : kernels_) { - for (auto& k : kv.second) { - k.kernel.setManuallyBoxedKernel_(func); - } - } - for (auto& k : catchAllKernel_) { - k.kernel.setManuallyBoxedKernel_(func); - } - - // Refresh entries in dispatchTable_ - updateDispatchTableFull_(dispatcher); -} - void OperatorEntry::checkInvariants() const { if (schema_) { TORCH_INTERNAL_ASSERT(schema_->schema.operator_name() == name_, dumpState()); @@ -358,7 +375,11 @@ void OperatorEntry::reportError(DispatchKey dispatchKey) const { } TORCH_CHECK(false, "Could not run '", name_, "' with arguments", - " from the '", toString(dispatchKey), "' backend. '", + " from the '", toString(dispatchKey), "' backend. This could be because " + "the operator doesn't exist for this backend, or was omitted during ", + "the selective/custom build process (if using custom build). If you are a ", + "Facebook employee using PyTorch on mobile, please visit ", + "https://fburl.com/ptmfixes for possible resolutions. '", name_, "' is only available for these backends: ", listAllDispatchKeys(), ".\n\n", dumpComputedTable()); } @@ -423,7 +444,6 @@ std::string OperatorEntry::dumpState() const { print_kernel(toString(k), it->second, c10::isAliasDispatchKey(k)); } } - print_kernel("catchall", catchAllKernel_); return oss.str(); } diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h index 5e449b136dace..44b8fac5661e9 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.h +++ b/aten/src/ATen/core/dispatch/OperatorEntry.h @@ -61,7 +61,7 @@ struct AnnotatedSchema final { // Concurrent writes to OperatorEntry are protected by the GLOBAL Dispatcher // lock (this is important because some methods in OperatorEntry access // dispatcher state) -class CAFFE2_API OperatorEntry final { +class TORCH_API OperatorEntry final { public: explicit OperatorEntry(OperatorName&& operator_name); @@ -148,23 +148,18 @@ class CAFFE2_API OperatorEntry final { const DispatchKeyExtractor& dispatchKeyExtractor() const { return dispatchKeyExtractor_; } - // This function is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed - // unboxing wrapper for aten operators. We still need those for some operators because not all work - // with the templated unboxing logic yet. - // TODO Delete setManuallyBoxedKernel_ once all operators work with the templated boxing logic - void setManuallyBoxedKernel_(const c10::Dispatcher& dispatcher, KernelFunction::InternalBoxedKernelFunction* func); - // Asserts that the given FuncType is correct for calling this operator in an unboxed way. template void assertSignatureIsCorrect() { - TORCH_INTERNAL_ASSERT(!cpp_signature_.has_value() || (CppSignature::make() == *cpp_signature_), - "Tried to access operator ", name_, " with a wrong signature. Accessed with ", - CppSignature::make().name(), - " but the operator was registered with ", - cpp_signature_->name(), - " (", - (schema_.has_value() ? schema_->debug : "unknown debug info"), - ") This likely happened in a call to OperatorHandle::typed(). Please make sure that the function signature matches the signature in the operator registration call." + TORCH_CHECK(!cpp_signature_.has_value() || (CppSignature::make() == cpp_signature_->signature), + "\nTried to access or call an operator with a wrong signature.\n", + " operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n", + " ", (schema_.has_value() ? schema_->debug : "unknown debug info"), "\n", + " correct signature: ", cpp_signature_->signature.name(), "\n", + " ", cpp_signature_->debug, "\n", + " accessed/called as: ", CppSignature::make().name(), "\n", + "This likely happened in a call to OperatorHandle::typed(). ", + "Please make sure that the function signature matches the signature in the operator registration call." ); } @@ -188,12 +183,6 @@ class CAFFE2_API OperatorEntry final { std::array(DispatchKey::NumDispatchKeys)> dispatchTable_; DispatchKeyExtractor dispatchKeyExtractor_; - // This manuallyBoxedKernel_ member is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed - // unboxing wrapper for aten operators. We still need those for some operators because not all work - // with the templated unboxing logic yet. - // TODO Delete manuallyBoxedKernel_ once all operators work with the templated boxing logic - c10::optional manuallyBoxedKernel_; - // kernels_ stores all registered kernels for the corresponding dispatch key // and catchAllKernels_ stores the catch-all kernels. // If an operator library gets loaded that overwrites an already existing kernel, @@ -227,16 +216,21 @@ class CAFFE2_API OperatorEntry final { // currently not high-pri. ska::flat_hash_map> kernels_; - std::list catchAllKernel_; AnnotatedKernel missingKernel_; static const AnnotatedKernel ambiguousAutogradOtherKernel_; - // signature_hash_ is set to the hash of the function signature if any of + // cpp_signature_ stores function signature if any of // the kernels was created in a way that allowed us to know the function // signature (i.e. by supplying an unboxed C++ kernel function). - // If this is set, it will be used in unboxed function calls + // If this is set, it will be used to check that future kernel + // registrations match and it will be used in unboxed function calls // to verify their arguments against the known function signature. - c10::optional cpp_signature_; + struct CppSignatureWithDebug { + CppSignature signature; + std::string debug; + c10::optional dispatch_key; + }; + c10::optional cpp_signature_; // Whether this operator needs to be observed with RecordFunction const bool is_observed_; @@ -254,7 +248,7 @@ class CAFFE2_API OperatorEntry final { void updateDispatchTableFull_(const c10::Dispatcher& dispatcher); // Returns true if kernel_ has entry for any key in ks. - bool hasKernelForDispatchKeySet(DispatchKeySet ks) const; + bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const; // Retrieves a pointer to AnnotatedKernel at kernels_.at(dispatch_key).front(). c10::optional getKernelForDispatchKey(DispatchKey dispatch_key) const; }; diff --git a/aten/src/ATen/core/function.h b/aten/src/ATen/core/function.h index 8c705a49b9e23..75592bf823b8c 100644 --- a/aten/src/ATen/core/function.h +++ b/aten/src/ATen/core/function.h @@ -8,6 +8,10 @@ namespace c10 { struct FunctionSchema; }; +namespace at { +TORCH_API void launch(std::function func); +} + namespace torch { namespace jit { @@ -17,21 +21,29 @@ struct GraphExecutor; using Stack = std::vector; using Kwargs = std::unordered_map; struct RecursiveMethodCallError : public std::exception {}; +using TaskLauncher = std::function)>; TORCH_API void preoptimizeGraph(std::shared_ptr& graph); // A Function is a pure Graph with no implicit `self` object bound. -// It contains schema information, and the executor that manages the -// execution of the function. Method is a wrapper around a +// It contains schema information and the executor that manages the +// execution of the function. Method is a wrapper around an // underlying Function that also provides a `self` object. struct TORCH_API Function { + virtual const std::string& doc_string() const { + static const std::string no_doc_string = ""; + return no_doc_string; + } + virtual bool isGraphFunction() const = 0; virtual void run(Stack& stack) = 0; virtual void run(Stack&& stack) = 0; - virtual c10::intrusive_ptr runAsync(Stack& stack) = 0; + virtual c10::intrusive_ptr runAsync( + Stack& stack, + TaskLauncher taskLauncher = at::launch) = 0; virtual at::IValue operator()( std::vector stack, diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index a9182787d2e6b..624ded76ffda9 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -38,7 +38,7 @@ struct Argument { const std::string& name() const { return name_; } - TypePtr type() const { + const TypePtr& type() const { return type_; } c10::optional N() const { @@ -107,7 +107,7 @@ struct Argument { c10::optional N_; c10::optional default_value_; - // is this only specifyable as a keyword argument? + // is this only specifiable as a keyword argument? bool kwarg_only_; c10::optional alias_info_; }; @@ -156,18 +156,29 @@ struct FunctionSchema { checkSchema(); } - // check whether this schema is backward compatible with the old one. - // the following conditions are considered as this schema is backward - // compatible with old: - // 1) two schemas are equal - // 2) this schema has the same or more positional args than old, - // and any positional arg in this schema is backward compatible - // with the corresponding one in old schema, which could be an arg - // or a kwarg, if it has, or it must provide a default value - // 3) this schema has the same or more kwargs than old, and all the kwargs - // in old schema can find the corresponding kwarg in this schema which - // is backward compatible with the old kwarg, and the extra kwargs in - // this schema must provide default values. + // Checks whether this schema is backward compatible with the old one. + // The following conditions must be true: + // [Function structure] The new schema's name, overload-name, varargs, and + // return arity are the same. + // [Output Narrowing] The new schema's output type must be the same class + // or inherit from the old schema's output type. + // [Argument count] The new schema must have at least as many arguments as + // the old schema (considering the list of positional and kwargs). + // [Arg Compatibility] Every argument in the old schema has a corresponding + // argument in the new schema that: + // * is at the same position. + // * has the same name. + // * is either positional, or kwarg and the old argument was kwarg. + // * has the same type, or the old argument's type inherits from the + // new argument's type. + // [Default Values] Every new argument must have a default value. + // E.g. + // OK f_new(a, b, c=1) => f_old(a, b) + // NOK f_new(a, c=1, *, b) => f_old(a, *, b) + // OK f_new(a, b, *, c) => f_old(a, *, b, c) + // NOK f_new(a, *, b, c) -> f_old(a, b, *, c) + // NOK f_new(a, *, c, b) => f_old(a, *, b, c) + // OK f_new(a, *, b, c, d=1) => f_old(a, *, b, c) bool isBackwardCompatibleWith( const FunctionSchema& old, std::ostream* why_not = nullptr) const; diff --git a/aten/src/ATen/core/function_schema_inl.h b/aten/src/ATen/core/function_schema_inl.h index bc9a68fbad3f2..168ecb4f3dc17 100644 --- a/aten/src/ATen/core/function_schema_inl.h +++ b/aten/src/ATen/core/function_schema_inl.h @@ -111,69 +111,35 @@ inline bool FunctionSchema::isBackwardCompatibleWith( return false; } for (size_t i = 0; i < returns().size(); ++i) { - // functions are covariant in arguments but contravariant in returns + // Backwards compatibility requires covariance on argument types + // (i.e. more generic), and contravariance on return types (i.e. + // more specific). if (!old.returns().at(i).isBackwardCompatibleWith( returns().at(i), why_not)) { return false; } } - std::vector args, old_args; - std::map kwargs, old_kwargs; - auto split_func = [](const std::vector& arguments, - std::vector* positionals, - std::map* nameds) { - for (const Argument& arg : arguments) { - if (!arg.kwarg_only()) { - positionals->emplace_back(&arg); - } - nameds->emplace(arg.name(), &arg); - } - }; - // we split args into positional and keyward parts, - split_func(arguments(), &args, &kwargs); - split_func(old.arguments(), &old_args, &old_kwargs); - if (old_args.size() > args.size()) { - return false; - } - // make sure that all the old positional args have their corresponding - // backward compatible positional args in this schema - for (size_t i = 0; i < old_args.size(); ++i) { - if (!args.at(i)->isBackwardCompatibleWith( - *old_args.at(i), - why_not)) { + + // Make sure that all the old arguments have their corresponding backward + // compatible arguments in this schema. + for (size_t i = 0; i < old.arguments().size(); ++i) { + if (!arguments().at(i).isBackwardCompatibleWith( + old.arguments().at(i), why_not)) { return false; } } - // check the extra positional args in this schema either has corresponding - // backward compatible keyward args since positional args also can be used as - // a keyward arg, or provided default values - for (size_t i = old_args.size(); i < args.size(); ++i) { - if (!args.at(i)->default_value()) { - auto it = old_kwargs.find(args.at(i)->name()); - if (it == old_kwargs.end() || - !args.at(i)->isBackwardCompatibleWith( - *it->second, - why_not)) { - return false; + + // Validate that all new arguments provided a default value. + for (size_t i = old.arguments().size(); i < arguments().size(); ++i) { + if (!arguments().at(i).default_value()) { + if (why_not) { + *why_not + << "Function schema not backward compatible since the new argument '" + << arguments().at(i).name() << "' of type " + << arguments().at(i).type()->str() + << " did not provide a default value."; } - } - } - // make sure that all the keyword args in the old schema have their - // corresponding backward compatible keyward args in this schema - for (auto& kv : old_kwargs) { - auto it = kwargs.find(kv.first); - if (it == kwargs.end() || - !it->second->isBackwardCompatibleWith( - *kv.second, - why_not)) { - return false; - } - kwargs.erase(it); - } - // check all the extra keyword args in this schema provide default values - for (auto& kv : kwargs) { - if (!kv.second->default_value()) { return false; } } @@ -185,8 +151,11 @@ inline void FunctionSchema::checkArg( const IValue& value, const Argument& argument, optional pos) const { + if (value.isTensor() && argument.type() == TensorType::get()) { + // Fast-path for the common case + return; + } if (!value.type()->isSubtypeOf(argument.type())) { - std::string position = pos ? ::c10::str(" in position ", *pos) : ""; TORCH_CHECK( false, formatTypeMismatchMsg( @@ -323,12 +292,12 @@ inline bool FunctionSchema::isSubtypeOf( bool as_method, std::ostream* why_not) const { size_t start = as_method ? 1 : 0; - // functions are covariant in arguments but contravariant in returns + // functions are contravariant in arguments but covariant in returns return isSubtypeOfList( - ArrayRef(arguments()).slice(start), ArrayRef(rhs.arguments()).slice(start), + ArrayRef(arguments()).slice(start), why_not) && - isSubtypeOfList(rhs.returns(), returns(), why_not); + isSubtypeOfList(returns(), rhs.returns(), why_not); } } // namespace c10 diff --git a/aten/src/ATen/core/grad_mode.h b/aten/src/ATen/core/grad_mode.h index acd5fd09e5ffc..84f8c6dce14e6 100644 --- a/aten/src/ATen/core/grad_mode.h +++ b/aten/src/ATen/core/grad_mode.h @@ -4,14 +4,14 @@ namespace at { -struct CAFFE2_API GradMode { +struct TORCH_API GradMode { static bool is_enabled(); static void set_enabled(bool enabled); }; // A RAII, thread local (!) guard that enables or disables grad mode upon // construction, and sets it back to the original value upon destruction. -struct CAFFE2_API AutoGradMode { +struct TORCH_API AutoGradMode { AutoGradMode(bool enabled) : prev_mode(GradMode::is_enabled()) { GradMode::set_enabled(enabled); } @@ -23,7 +23,7 @@ struct CAFFE2_API AutoGradMode { // A RAII, thread local (!) guard that stops future operations from building // gradients. -struct CAFFE2_API NoGradGuard : public AutoGradMode { +struct TORCH_API NoGradGuard : public AutoGradMode { NoGradGuard() : AutoGradMode(/*enabled=*/false) {} }; diff --git a/aten/src/ATen/core/interned_strings.cpp b/aten/src/ATen/core/interned_strings.cpp index 2e1753167a2b9..ace56844fc6cb 100644 --- a/aten/src/ATen/core/interned_strings.cpp +++ b/aten/src/ATen/core/interned_strings.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -28,6 +29,9 @@ std::pair InternedStrings::string(Symbol sym) { // we can bypass the need to acquire a lock // to read the map for Builtins because we already // know their string value +#if defined C10_MOBILE + return customString(sym); +#else switch (sym) { #define DEFINE_CASE(ns, s) \ case static_cast(ns::s): \ @@ -37,9 +41,14 @@ std::pair InternedStrings::string(Symbol sym) { default: return customString(sym); } +#endif } Symbol InternedStrings::ns(Symbol sym) { +#if defined C10_MOBILE + std::lock_guard guard(mutex_); + return sym_to_info_.at(sym).ns; +#else switch (sym) { #define DEFINE_CASE(ns, s) \ case static_cast(ns::s): \ @@ -51,6 +60,7 @@ Symbol InternedStrings::ns(Symbol sym) { return sym_to_info_.at(sym).ns; } } +#endif } Symbol InternedStrings::_symbol(const std::string& s) { diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index b279a24003508..7155651b3006d 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -17,6 +17,7 @@ namespace c10 { #define FORALL_NS_SYMBOLS(_) \ _(namespaces, prim) \ _(namespaces, aten) \ + _(namespaces, cuda) \ _(namespaces, onnx) \ _(namespaces, attr) \ _(namespaces, scope) \ @@ -27,6 +28,7 @@ namespace c10 { _(prim, Assign) \ _(prim, BroadcastingChunk) \ _(prim, BroadcastSizes) \ + _(prim, ReductionSizes) \ _(prim, Constant) \ _(prim, ChunkSizes) \ _(prim, Drop) \ @@ -34,9 +36,11 @@ namespace c10 { _(prim, Expand) /* onnx */ \ _(prim, FusionGroup) \ _(prim, CudaFusionGroup) \ + _(prim, CudaFusionGuard) \ _(prim, FunctionalGraph) \ _(prim, DifferentiableGraph) \ _(prim, TensorExprGroup) \ + _(prim, StaticSubgraph) \ _(prim, If) \ _(prim, Jump) /* debug */ \ _(prim, JumpNZ) /* debug */ \ @@ -55,7 +59,7 @@ namespace c10 { _(prim, ReturnStmt) \ _(prim, BreakStmt) \ _(prim, ContinueStmt) \ - _(prim, LocalVariableScope) \ + _(prim, ComprehensionScope) \ _(prim, Store) \ _(prim, AutogradZero) \ _(prim, AutogradAnyNonZero) \ @@ -69,6 +73,7 @@ namespace c10 { _(prim, ListConstruct) \ _(prim, ListUnpack) \ _(prim, DictConstruct) \ + _(prim, ModuleDictIndex) \ _(prim, EnumName) \ _(prim, EnumValue) \ _(prim, StringIndex) \ @@ -98,11 +103,13 @@ namespace c10 { _(prim, Guard) \ _(prim, BailOut) \ _(prim, TypeCheck) \ + _(prim, RequiresGradCheck) \ _(prim, FallbackGraph) \ _(prim, FusedConcat) \ _(prim, ConstantChunk) \ _(prim, MMTreeReduce) \ _(prim, MMBatchSide) \ + _(prim, list) \ _(prim, min) \ _(prim, max) \ _(prim, abs) \ @@ -128,13 +135,14 @@ namespace c10 { _(prim, fork) \ _(prim, forkClosure) \ _(prim, RaiseException) \ - _(prim, Function) \ + _(prim, Closure) \ _(prim, CreateObject) \ _(prim, SetAttr) \ _(prim, GetAttr) \ _(prim, HasAttr) \ _(prim, profile) \ _(prim, profile_optional) \ + _(prim, profile_ivalue) \ _(prim, AddStatValue) \ _(prim, TimePoint) \ _(prim, CallFunction) \ @@ -185,6 +193,7 @@ namespace c10 { _(aten, append) \ _(aten, item) \ _(aten, format) \ + _(aten, percentFormat) \ _(aten, __not__) \ _(aten, __is__) \ _(aten, __isnot__) \ @@ -220,6 +229,7 @@ namespace c10 { _(aten, lt_) \ _(aten, less) \ _(aten, less_) \ + _(aten, isnan) \ _(aten, mul) \ _(aten, mul_) \ _(aten, multiply) \ @@ -231,6 +241,7 @@ namespace c10 { _(aten, _ger) \ _(aten, ger) \ _(aten, outer) \ + _(aten, transpose) \ _(aten, transpose_) \ _(aten, unsqueeze_) \ _(aten, __getitem__) \ @@ -266,12 +277,25 @@ namespace c10 { _(aten, bin) \ _(aten, pop) \ _(aten, insert) \ + _(aten, vstack) \ + _(aten, row_stack) \ _(prim, unchecked_unwrap_optional) \ _(aten, __contains__) \ _(prim, BailoutTemplate) \ _(prim, grad) \ _(aten, zero_) \ _(aten, fill_) \ + _(aten, masked_fill_) \ + _(cuda, _set_device) \ + _(cuda, set_stream) \ + _(cuda, _current_device) \ + _(aten, swapaxes) \ + _(aten, swapaxes_) \ + _(aten, swapdims) \ + _(aten, swapdims_) \ + _(aten, movedim) \ + _(aten, moveaxis) \ + _(aten, has_torch_function) \ FORALL_ATEN_BASE_SYMBOLS(_) \ _(onnx, Add) \ _(onnx, Concat) \ @@ -323,6 +347,7 @@ namespace c10 { _(onnx, ReduceL2) \ _(onnx, Conv) \ _(onnx, BatchNormalization) \ + _(onnx, ReduceProd) \ FORALL_ATTR_BASE_SYMBOLS(_) \ _(attr, Subgraph) \ _(attr, ReverseSubgraph) \ @@ -359,11 +384,13 @@ namespace c10 { _(attr, scope) \ _(attr, keepdims) \ _(attr, cache_id) \ - _(attr, new_axis) + _(attr, new_axis) \ + _(attr, warn_id) #else #define FORALL_NS_SYMBOLS(_) \ _(namespaces, prim) \ _(namespaces, aten) \ + _(namespaces, cuda) \ _(namespaces, onnx) \ _(namespaces, attr) \ _(namespaces, scope) \ @@ -416,7 +443,7 @@ const std::string& domain_prefix(); // A Symbol is like an interned string, but with a little extra // structure; it is namespaced via SymbolNamespace and the resulting // intern pointers support efficient namespace testing. -struct CAFFE2_API Symbol { +struct TORCH_API Symbol { explicit constexpr Symbol() : value(0) {}; explicit constexpr Symbol(unique_t uniq) : value(uniq) {} @@ -434,6 +461,7 @@ struct CAFFE2_API Symbol { // (and if it's not, you should add it to the built-ins list above.) static Symbol attr(const std::string & s); static Symbol aten(const std::string & s); + static Symbol cuda(const std::string & s); static Symbol onnx(const std::string & s); static Symbol prim(const std::string & s); static Symbol user(const std::string & s); @@ -444,6 +472,7 @@ struct CAFFE2_API Symbol { bool is_attr() const; bool is_aten() const; + bool is_cuda() const; bool is_prim() const; bool is_onnx() const; bool is_user() const; @@ -504,6 +533,7 @@ FORALL_NS_SYMBOLS(DEFINE_SYMBOL) inline Symbol Symbol::attr(const std::string & s) { return Symbol::fromQualString("attr::" + s); } inline Symbol Symbol::aten(const std::string & s) { return Symbol::fromQualString("aten::" + s); } +inline Symbol Symbol::cuda(const std::string & s) { return Symbol::fromQualString("cuda::" + s); } inline Symbol Symbol::onnx(const std::string & s) { return Symbol::fromQualString("onnx::" + s); } inline Symbol Symbol::prim(const std::string & s) { return Symbol::fromQualString("prim::" + s); } inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); } @@ -512,6 +542,7 @@ inline Symbol Symbol::caffe2(const std::string & s) { return Symbol::fromQualStr inline Symbol Symbol::dimname(const std::string & s) { return Symbol::fromQualString("dimname::" + s); } inline bool Symbol::is_attr() const { return ns() == namespaces::attr; } inline bool Symbol::is_aten() const { return ns() == namespaces::aten; } +inline bool Symbol::is_cuda() const { return ns() == namespaces::cuda; } inline bool Symbol::is_prim() const { return ns() == namespaces::prim; } inline bool Symbol::is_onnx() const { return ns() == namespaces::onnx; } inline bool Symbol::is_user() const { return ns() == namespaces::user; } diff --git a/aten/src/ATen/core/interned_strings_class.h b/aten/src/ATen/core/interned_strings_class.h index b13e3f18eba89..54303e0384d28 100644 --- a/aten/src/ATen/core/interned_strings_class.h +++ b/aten/src/ATen/core/interned_strings_class.h @@ -11,7 +11,7 @@ namespace c10 { -struct CAFFE2_API InternedStrings { +struct TORCH_API InternedStrings { InternedStrings(); Symbol symbol(const std::string& s); std::pair string(Symbol sym); diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index e786d36256430..6ff7a52fd9cc9 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include namespace c10 { @@ -21,7 +22,7 @@ namespace ivalue { // This is in ivalue.cpp because we need to access Type::annotation_str, which // is declared in jit_type.h -void checkCustomClassType(TypePtr expected_type, TypePtr actual_type) { +void checkCustomClassType(const Type* expected_type, const Type* actual_type) { // NB: doing pointer comparison here // If in the future there ever arises a need to call operator== on custom class // Type's, this needs to be changed! @@ -32,7 +33,7 @@ void checkCustomClassType(TypePtr expected_type, TypePtr actual_type) { expected_type->repr_str()); } -CAFFE2_API c10::intrusive_ptr ConstantString::create( +TORCH_API c10::intrusive_ptr ConstantString::create( std::string str_) { return c10::make_intrusive(std::move(str_)); } @@ -75,6 +76,8 @@ TypePtr IValue::type() const { return NoneType::get(); case Tag::Tensor: return TensorType::create(toTensor()); + case Tag::Storage: + return StorageType::get(); case Tag::Double: return FloatType::get(); case Tag::Int: @@ -97,6 +100,8 @@ TypePtr IValue::type() const { return RRefType::create(toRRef()->type()); case Tag::Device: return DeviceObjType::get(); + case Tag::Stream: + return StreamObjType::get(); case Tag::Object: return toObjectRef().type(); case Tag::PyObject: @@ -120,7 +125,7 @@ TypePtr IValue::type() const { void IValue::visit(const std::function& visitor) const { if (visitor(*this)) { - // Short cut. + // Shortcut return; } switch (this->tag) { @@ -153,6 +158,15 @@ void IValue::visit(const std::function& visitor) const { } break; } + case Tag::PyObject: { + c10::intrusive_ptr py_obj = toPyObjectHolder(); + auto match = py_obj->tryToInferType(); + if (match.success()) { + auto contained_value = py_obj->toIValue(match.type()); + contained_value.visit(visitor); + } + break; + } default: break; } @@ -196,9 +210,18 @@ void IValue::getSubValues(HashAliasedIValues& subValues) const { } break; } + case Tag::PyObject: { + subValues.insert(*this); + c10::intrusive_ptr py_obj = toPyObjectHolder(); + auto match = py_obj->tryToInferType(); + TORCH_INTERNAL_ASSERT(match.success(), + "Cannot infer type of ", py_obj->toStr(), "\n:", match.reason()); + auto contained_value = py_obj->toIValue(match.type()); + contained_value.getSubValues(subValues); + break; + } case Tag::Future: case Tag::Device: - case Tag::PyObject: case Tag::Uninitialized: case Tag::Capsule: TORCH_INTERNAL_ASSERT( @@ -242,7 +265,7 @@ bool IValue::ptrEqual(const IValue& lhs, const IValue& rhs) { TORCH_INTERNAL_ASSERT(lhs.is_intrusive_ptr); TORCH_INTERNAL_ASSERT(rhs.is_intrusive_ptr); return lhs.tag == rhs.tag && - lhs.payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr; + lhs.payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr; } IValue IValue::equals(const IValue& rhs) const { @@ -257,6 +280,8 @@ IValue IValue::equals(const IValue& rhs) const { return false; } return lhs.toTensor().eq(rhs.toTensor()); + case Tag::Storage: + return rhs.isStorage() && lhs.toStorage().unsafeGetStorageImpl() == rhs.toStorage().unsafeGetStorageImpl(); case Tag::Double: return rhs.isDouble() && lhs.toDouble() == rhs.toDouble(); case Tag::Int: @@ -269,6 +294,8 @@ IValue IValue::equals(const IValue& rhs) const { return rhs.isGenericDict() && lhs.toGenericDict() == rhs.toGenericDict(); case Tag::Tuple: return rhs.isTuple() && *lhs.toTuple() == *rhs.toTuple(); + case Tag::Stream: + return rhs.isStream() && lhs.toStream() == rhs.toStream(); case Tag::Device: return rhs.isDevice() && lhs.toDevice() == rhs.toDevice(); case Tag::GenericList: @@ -293,6 +320,48 @@ IValue IValue::equals(const IValue& rhs) const { TORCH_INTERNAL_ASSERT(false, "we should never reach here") } +size_t IValue::hash(const IValue& v) { + switch (v.tag) { + case Tag::None: + return 0; + case Tag::Bool: + return c10::get_hash(v.payload.u.as_bool); + case Tag::Double: + return c10::get_hash(v.payload.u.as_double); + case Tag::Tensor: + // Tensor __hash__ is equivalent to `id()`, so take the pointer value of + // the tensor to emulate it + return c10::get_hash(v.payload.as_tensor.unsafeGetTensorImpl()); + case Tag::Storage: + return c10::get_hash(v.payload.u.as_int); + case Tag::Int: + return c10::get_hash(v.payload.u.as_int); + case Tag::String: + return c10::get_hash(v.toStringRef()); + case Tag::Tuple: + return c10::get_hash(*v.toTuple()); + case Tag::Device: + return c10::get_hash(v.toDevice()); + case Tag::GenericDict: + case Tag::GenericList: + case Tag::Blob: + case Tag::Future: + case Tag::RRef: + case Tag::Object: + case Tag::PyObject: + case Tag::Capsule: + case Tag::Generator: + case Tag::Quantizer: + case Tag::Enum: + case Tag::Stream: + case Tag::Uninitialized: + throw std::runtime_error( + "unhashable type: '" + v.type()->repr_str() + "'"); + } + // the above switch should be exhaustive + TORCH_INTERNAL_ASSERT(false, "we should never reach here") +} + static bool isUndefinedTensor(const IValue& iv) { return iv.isTensor() && !iv.toTensor().defined(); } @@ -348,7 +417,7 @@ std::ostream& printMaybeAnnotatedList( std::ostream& out, const IValue& the_list, IValueFormatter formatter) { - auto list_elem_type = the_list.type()->expect()->getElementType(); + auto list_elem_type = the_list.type()->expectRef().getElementType(); if (the_list.toListRef().size() == 0 || !elementTypeCanBeInferredFromMembers(list_elem_type)) { out << "annotate(" << the_list.type()->annotation_str() << ", "; @@ -423,12 +492,16 @@ std::ostream& IValue::repr( if ((c == FP_NORMAL || c == FP_ZERO ) && std::abs(d) < 1e10) { int64_t i = int64_t(d); if (double(i) == d) { + // -0.0 (signed zero) needs to be parsed as -0. + if (i == 0 && std::signbit(d)) { + return out << "-" << i << "."; + } return out << i << "."; } } auto orig_prec = out.precision(); return out << std::setprecision(std::numeric_limits::max_digits10) - << v.toDouble() << std::setprecision(orig_prec); + << d << std::setprecision(orig_prec); } case IValue::Tag::Int: return out << v.toInt(); @@ -459,6 +532,9 @@ std::ostream& IValue::repr( return out << enum_holder->qualifiedClassName() << "." << enum_holder->name(); } + case IValue::Tag::Object: { + TORCH_INTERNAL_ASSERT(false, "repr() not defined on: ", v.tagKind(), ". Perhaps you've frozen a module with custom classes?"); + } default: TORCH_INTERNAL_ASSERT(false, "repr() not defined on: ", v.tagKind()); } @@ -595,6 +671,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) { return out << v.toNone(); case IValue::Tag::Tensor: return out << v.toTensor(); + case IValue::Tag::Storage: + return out << v.toStorage().unsafeGetStorageImpl(); case IValue::Tag::Double: { double d = v.toDouble(); int c = std::fpclassify(d); @@ -634,6 +712,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) { return out << "Uninitialized"; case IValue::Tag::Device: return out << v.toDevice(); + case IValue::Tag::Stream: + return out << v.toStream(); case IValue::Tag::GenericDict: return printDict(out, v.toGenericDict(), formatter); case IValue::Tag::PyObject: { @@ -825,7 +905,7 @@ getClassConverter() { return classConverter; } -CAFFE2_API intrusive_ptr collectAll( +TORCH_API intrusive_ptr collectAll( List> srcs) { struct Ctx { explicit Ctx(List> srcs) @@ -857,7 +937,7 @@ CAFFE2_API intrusive_ptr collectAll( return ctx->dstFuture; } -CAFFE2_API intrusive_ptr collectAny( +TORCH_API intrusive_ptr collectAny( List> srcs) { if (srcs.empty()) { auto res = make_intrusive(NoneType::get()); diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 101af2beb5184..5633e52446cb4 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -17,8 +17,10 @@ struct Module; } // namespace jit } // namespace torch namespace c10 { -template class Dict; -template class List; +template +class Dict; +template +class List; struct IValue; struct ClassType; struct Type; @@ -30,7 +32,9 @@ using ClassTypePtr = std::shared_ptr; TORCH_API bool _fastEqualsForContainer(const IValue& lhs, const IValue& rhs); -TORCH_API torch::jit::Function* checkObjectSortSchema(const c10::ClassTypePtr& t, std::stringstream& why_not); +TORCH_API torch::jit::Function* checkObjectSortSchema( + const c10::ClassTypePtr& t, + std::stringstream& why_not); // A comparator that checks ordering of two IValues of same type. typedef std::function IValueComparator; @@ -46,18 +50,18 @@ struct GenericDict; struct Object; struct PyObjectHolder; struct EnumHolder; -} +} // namespace ivalue // This is an owning wrapper for a c10::optional> // that can be implicitly converted to a (non-owning) optional>. // Its purpose is to be used in generated code to keep the vector alive // either until the end of a statement (as a temporary), or as a saved arg // in autograd. -template +template struct OptionalArray { c10::optional> list; - OptionalArray() {}; + OptionalArray(){} OptionalArray(std::vector val) : list(std::move(val)) {} // Used when saving an argument for the backwards pass. @@ -78,6 +82,19 @@ struct OptionalArray { } }; +// Capsule is an internal implementation detail of custom C++ classes. We +// define it as an owning wrapper for +// c10::intrusive_ptr This wrapper is here to serve as +// an abstraction of the type erased custom class object pointer. It also allow +// pybind11 to treat this as a standalone class to register as a separate type +// caster, instead of a custom pointer holder which the pointer holder type +// caster try to "unwrap" it automatically. +struct Capsule { + c10::intrusive_ptr obj_ptr; + explicit Capsule(c10::intrusive_ptr ptr) + : obj_ptr(std::move(ptr)) {} +}; + // IValue is the generic tagged union used by the interpreter to hold // all value types. // It is a 16-byte object with an 8-byte payload and an 8-byte tag. @@ -88,6 +105,7 @@ struct OptionalArray { #define TORCH_FORALL_TAGS(_) \ _(None) \ _(Tensor) \ + _(Storage) \ _(Double) \ _(Int) \ _(Bool) \ @@ -98,6 +116,7 @@ struct OptionalArray { _(GenericDict) \ _(Future) \ _(Device) \ + _(Stream) \ _(Object) \ _(PyObject) \ _(Uninitialized) \ @@ -105,18 +124,22 @@ struct OptionalArray { _(RRef) \ _(Quantizer) \ _(Generator) \ - _(Enum) \ + _(Enum) // [doxygen private] // These methods are not actually private but we don't want to document them, so // they are marked `@private`, which hides them on the doxygen documentation for // this page. - -/// IValue (Interpreter Value) is a tagged union over the types supported by the -/// TorchScript interpreter. IValues contain their values as an `IValue::Payload`, -/// which holds primitive types (`int64_t`, `bool`, `double`, `Device`), as -/// values and all other types as a `c10::intrusive_ptr`. +/// IValue (Interpreter Value) is a tagged union over the types +/// supported by the TorchScript interpreter. IValues contain their +/// values as an `IValue::Payload`, which holds primitive types +/// (`int64_t`, `bool`, `double`, `Device`) and `Tensor` as values, +/// and all other types as a `c10::intrusive_ptr`. In order to +/// optimize performance of the destructor and related operations by +/// making the `Tensor` and `c10::intrusive_ptr` paths generate the +/// same code, we represent a null `c10::intrusive_ptr` as +/// `UndefinedTensorImpl::singleton()`, *not* `nullptr`. /// /// IValues are used as inputs to and outputs from the TorchScript interpreter. /// To retrieve the value contained within an IValue, use the `.toX()` methods, @@ -139,47 +162,57 @@ struct OptionalArray { /// // `my_ivalue` is tagged as an int and cannot be used as another type /// torch::Tensor my_tensor = my_ivalue.toTensor() /// \endrst -struct CAFFE2_API IValue final { +struct TORCH_API IValue final { IValue(const IValue& rhs) : IValue(rhs.payload, rhs.tag, rhs.is_intrusive_ptr) { - if (is_intrusive_ptr) { - c10::raw::intrusive_ptr::incref(payload.as_intrusive_ptr); + if (is_intrusive_ptr && payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { + c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr); } } - IValue(IValue&& rhs) noexcept : IValue() { - swap(rhs); + + IValue(IValue&& rhs) noexcept : tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) { + moveFrom(std::move(rhs)); } + /// @private [doxygen private] ~IValue() { - if (is_intrusive_ptr) { - c10::raw::intrusive_ptr::decref(payload.as_intrusive_ptr); - } + destroy(); } - IValue& operator=(IValue&& rhs) & noexcept { - IValue(std::move(rhs)).swap(*this); // this also sets rhs to None + + C10_ALWAYS_INLINE IValue& operator=(IValue&& rhs) & noexcept { + if (&rhs == this) { + return *this; + } + + destroy(); + moveFrom(std::move(rhs)); return *this; } + IValue& operator=(IValue const& rhs) & { IValue(rhs).swap(*this); return *this; } + void dump() const; /** * Equality comparison. The semantics are the same as Python's `==`: * 1. Numerical types are compared by value. - * 2. Tensors compute element-wise equality, returning a BoolTensor (see: `torch.eq()`) + * 2. Tensors compute element-wise equality, returning a BoolTensor (see: + * `torch.eq()`) * 3. Strings are compared by value. * 4. Sequence types (list, tuple) are compared lexicographically by * comparing their elements. Different sequence types never compare equal. * 5. Mappings (dict) must have equal (key, value) pairs. - * 6. If not listed above, the default behavior for is to test identity equality - * (e.g. pointer equality). + * 6. If not listed above, the default behavior for is to test identity + * equality (e.g. pointer equality). * * Why does this return an IValue instead of a bool? Because in PyTorch, * `tensor1 == tensor2` returns a `BoolTensor`, not a bool. * - * NOTE: we (like Python) assume that identity equality implies value equality for efficiency. + * NOTE: we (like Python) assume that identity equality implies value equality + * for efficiency. * TODO: need to support customizing equality */ IValue equals(const IValue& rhs) const; @@ -200,6 +233,24 @@ struct CAFFE2_API IValue final { */ bool is(const IValue& rhs) const; + /** + * Hashing for IValues. Returns an IValue-boxed int. + * + * Some notes: + * - Like eager, Tensors are hashed by looking at the pointer. This is not + * strictly correct because two value-equal tensors with different tensor + * pointers will hash differently, but we choose to reproduce the eager + * semantics. + * - Hashing is not defined on all built-in IValue types (e.g. list and + * dict), following Python. Calling `hash()` on these types will throw. + */ + IValue hash() const { + return (int64_t)IValue::hash(*this); + } + // This is defined because `c10::hash` dispatches to a function of this + // signature. See the member function `hash()`. + static size_t hash(const IValue& iv); + /** * @private [doxygen private] * [container equality] @@ -207,10 +258,13 @@ struct CAFFE2_API IValue final { * identity equal themselves, for efficiency reasons. We primarily have this * for consistency, because Python does the same thing. This actually * provokes user-visible changes in behavior due to quirks in torch: - * [tensor1] == [tensor1] -> True (because container equality will first compare identity) - * [tensor1] == [tensor1_copy] -> RuntimeError: bool value of Tensor is ambiguous + * [tensor1] == [tensor1] -> True (because container equality will first + * compare identity) [tensor1] == [tensor1_copy] -> RuntimeError: bool value + * of Tensor is ambiguous */ - TORCH_API friend bool _fastEqualsForContainer(const IValue& lhs, const IValue& rhs); + TORCH_API friend bool _fastEqualsForContainer( + const IValue& lhs, + const IValue& rhs); /// @private [doxygen private] bool isAliasOf(const IValue& rhs) const { @@ -219,6 +273,13 @@ struct CAFFE2_API IValue final { return false; } + // Tensors should be compared based on internal storage + if (this->isTensor()) { + const auto& thisTensor = this->toTensor(); + const auto& rhsTensor = rhs.toTensor(); + return thisTensor.is_alias_of(rhsTensor); + } + if (!this->is_intrusive_ptr) { // Primitive types don't alias anything return false; @@ -226,29 +287,49 @@ struct CAFFE2_API IValue final { AT_ASSERT(rhs.is_intrusive_ptr); - // Tensors should be compared based on internal storage - if (this->isTensor()) { - const auto thisTensor = this->toTensor(); - const auto rhsTensor = rhs.toTensor(); - return thisTensor.is_alias_of(rhsTensor); - } - // Other types can be compared by their ptr value - return this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr; + return this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr; } /// @private [doxygen private] size_t use_count() const noexcept { + if (isTensor()) { + return payload.as_tensor.use_count(); + } + if (!is_intrusive_ptr) { return 1; } - return c10::raw::intrusive_ptr::use_count(payload.as_intrusive_ptr); + if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) { + return 0; + } + return c10::raw::intrusive_ptr::use_count(payload.u.as_intrusive_ptr); } /// @private [doxygen private] - void swap(IValue & rhs) noexcept { - std::swap(payload, rhs.payload); + void swap(IValue& rhs) noexcept { + if (isTensor() && rhs.isTensor()) { + std::swap(payload.as_tensor, rhs.payload.as_tensor); + } else if (isTensor()) { + at::Tensor t = std::move(payload.as_tensor); + // As far as I can tell, omitting the usual explicit destructor call + // is not UB in and of itself, and it's a slight perf win. The + // destructor is a no-op, because the moved-from Tensor is + // effectively an intrusive_ptr in the null state, so we don't need + // the behavior for correctness reasons either. Leaving this + // explanatory comment, including commented-out destructor call, to + // make this abundantly clear. + // + // payload.as_tensor.~Tensor(); + payload.u = rhs.payload.u; + new (&rhs.payload.as_tensor) at::Tensor(std::move(t)); + } else if (rhs.isTensor()) { + rhs.swap(*this); + return; + } else { + std::swap(payload.u, rhs.payload.u); + } std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr); std::swap(tag, rhs.tag); } @@ -257,21 +338,32 @@ struct CAFFE2_API IValue final { // While some of these accessors could be generated through templates, // we prefer to write them manually for clarity - IValue(at::Tensor t) - : tag(Tag::Tensor), is_intrusive_ptr(t.defined()) { + IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(false) { + new (&payload.as_tensor) at::Tensor(std::move(t)); + } + bool isTensor() const { + return Tag::Tensor == tag; + } + at::Tensor toTensor() &&; + at::Tensor& toTensor() &; + const at::Tensor& toTensor() const&; + at::TensorImpl* unsafeToTensorImpl() const { + return payload.as_tensor.unsafeGetTensorImpl(); + } + + IValue(at::Storage s) : tag(Tag::Storage), is_intrusive_ptr(static_cast(s)) { // Note: the undefined tensor is not refcounted, so while it // is tagged as a tensor, is_intrusive_ptr is set to false. // This is not an optional optimization: our incref call // *will not* do the right thing when called on an // undefined tensor. - payload.as_intrusive_ptr = t.unsafeReleaseTensorImpl(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(s.unsafeReleaseStorageImpl()); } - bool isTensor() const { return Tag::Tensor == tag; } - at::Tensor toTensor() &&; - at::Tensor toTensor() const &; - at::TensorImpl* unsafeToTensorImpl() const { - return static_cast(payload.as_intrusive_ptr); + bool isStorage() const { + return Tag::Storage == tag; } + c10::Storage toStorage() &&; + c10::Storage toStorage() const&; const IValue& toIValue() const { return *this; @@ -282,10 +374,10 @@ struct CAFFE2_API IValue final { /// @private [doxygen private] IValue(intrusive_ptr blob) - : tag(Tag::Blob), is_intrusive_ptr(true) { + : tag(Tag::Blob), is_intrusive_ptr(true) { // TODO (after Tensor merge) If we pass in a Blob holding a Tensor, extract // and store it as a Tensor instead. - payload.as_intrusive_ptr = blob.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release()); } /// @private [doxygen private] @@ -297,26 +389,30 @@ struct CAFFE2_API IValue final { c10::intrusive_ptr toBlob() &&; /// @private [doxygen private] - c10::intrusive_ptr toBlob() const &; + c10::intrusive_ptr toBlob() const&; - // Capsule. Capsule is an internal implementation detail - // of custom C++ classes. No new callsites of these APIs should + // Capsule. No new callsites of these APIs should // be introduced. - static inline IValue make_capsule(intrusive_ptr blob); + static inline IValue make_capsule( + intrusive_ptr blob); bool isCapsule() const { return Tag::Capsule == tag; } c10::intrusive_ptr toCapsule() &&; - c10::intrusive_ptr toCapsule() const &; + c10::intrusive_ptr toCapsule() const&; // Custom C++ classes - template ::value, int> = 0> + template < + typename T, + std::enable_if_t< + std::is_base_of::value, + int> = 0> IValue(intrusive_ptr custom_class); bool isCustomClass() const; template c10::intrusive_ptr toCustomClass() &&; template - c10::intrusive_ptr toCustomClass() const &; + c10::intrusive_ptr toCustomClass() const&; // Tuple IValue(c10::intrusive_ptr v); @@ -326,165 +422,177 @@ struct CAFFE2_API IValue final { std::enable_if_t< !guts::disjunction< std::is_lvalue_reference..., - guts::negation>...>:: - value, + guts::negation>...>::value, std::nullptr_t> = nullptr> IValue(const std::tuple& t); - bool isTuple() const { return Tag::Tuple == tag; } + bool isTuple() const { + return Tag::Tuple == tag; + } c10::intrusive_ptr toTuple() &&; - c10::intrusive_ptr toTuple() const &; + c10::intrusive_ptr toTuple() const&; // Double - IValue(double d) - : tag(Tag::Double), is_intrusive_ptr(false) { - payload.as_double = d; + IValue(double d) : tag(Tag::Double), is_intrusive_ptr(false) { + payload.u.as_double = d; + } + bool isDouble() const { + return Tag::Double == tag; } - bool isDouble() const { return Tag::Double == tag; } double toDouble() const { AT_ASSERT(isDouble()); - return payload.as_double; + return payload.u.as_double; } // Future IValue(c10::intrusive_ptr v); - bool isFuture() const { return Tag::Future == tag; } + bool isFuture() const { + return Tag::Future == tag; + } c10::intrusive_ptr toFuture() &&; - c10::intrusive_ptr toFuture() const &; + c10::intrusive_ptr toFuture() const&; // RRef IValue(c10::intrusive_ptr v); - bool isRRef() const { return Tag::RRef == tag; } + bool isRRef() const { + return Tag::RRef == tag; + } c10::intrusive_ptr toRRef() &&; - c10::intrusive_ptr toRRef() const &; + c10::intrusive_ptr toRRef() const&; // Quantizer IValue(c10::intrusive_ptr v); - bool isQuantizer() const { return Tag::Quantizer == tag; } + bool isQuantizer() const { + return Tag::Quantizer == tag; + } c10::intrusive_ptr toQuantizer() &&; - c10::intrusive_ptr toQuantizer() const &; + c10::intrusive_ptr toQuantizer() const&; // Int - IValue(int64_t i) - : tag(Tag::Int), is_intrusive_ptr(false) { - payload.as_int = i; + IValue(int64_t i) : tag(Tag::Int), is_intrusive_ptr(false) { + payload.u.as_int = i; } // allow you to pass literals (3, 4) without ambiguity - IValue(int32_t i) - : IValue(static_cast(i)) {} + IValue(int32_t i) : IValue(static_cast(i)) {} - bool isInt() const { return Tag::Int == tag; } + bool isInt() const { + return Tag::Int == tag; + } int64_t toInt() const { AT_ASSERT(isInt()); - return payload.as_int; + return payload.u.as_int; } // Bool - IValue(bool b) - : tag(Tag::Bool), is_intrusive_ptr(false) { + IValue(bool b) : tag(Tag::Bool), is_intrusive_ptr(false) { #if defined(__clang__) && defined(__x86_64__) // Initializing entire payload stops valgrind's from reporting // "jump or move depends on uninitialised value" in IValue copy constructor // See https://github.com/pytorch/pytorch/issues/37117 - payload.as_int = b; + payload.u.as_int = b; #else - payload.as_bool = b; + payload.u.as_bool = b; #endif } - bool isBool() const { return Tag::Bool == tag; } - bool toBool() const { + bool isBool() const { + return Tag::Bool == tag; + } + bool toBool() const { AT_ASSERT(isBool()); - return payload.as_bool; + return payload.u.as_bool; } // IntList bool isIntList() const; c10::List toIntList() &&; - c10::List toIntList() const &; + c10::List toIntList() const&; std::vector toIntVector() const; // ConstantString IValue(c10::intrusive_ptr v); IValue(std::string v); - IValue(const char* v): IValue(std::string(v)) {} - bool isString() const { return Tag::String == tag; } + IValue(const char* v) : IValue(std::string(v)) {} + bool isString() const { + return Tag::String == tag; + } c10::intrusive_ptr toString() &&; - c10::intrusive_ptr toString() const &; + c10::intrusive_ptr toString() const&; const std::string& toStringRef() const; - c10::optional> toOptionalStringRef() const; + c10::optional> toOptionalStringRef() + const; // DoubleList bool isDoubleList() const; c10::List toDoubleList() &&; - c10::List toDoubleList() const &; + c10::List toDoubleList() const&; std::vector toDoubleVector() const; // BoolList bool isBoolList() const; c10::List toBoolList() &&; - c10::List toBoolList() const &; + c10::List toBoolList() const&; // TensorList bool isTensorList() const; c10::List toTensorList() &&; - c10::List toTensorList() const &; + c10::List toTensorList() const&; std::vector toTensorVector() const; // GenericList IValue(c10::List v); - bool isList() const { return Tag::GenericList == tag; } + bool isList() const { + return Tag::GenericList == tag; + } c10::List toList() &&; - c10::List toList() const &; + c10::List toList() const&; c10::ArrayRef toListRef() const; // Some template constructors of IValue calls another constructor recursively. // This SNIFAEs the called constructor exists. - template + template using enable_if_ivalue_constructible = std::enable_if_t::value, std::nullptr_t>; - template < - class T, - enable_if_ivalue_constructible = nullptr> + template = nullptr> IValue(c10::List v); - template < - class T, - enable_if_ivalue_constructible = nullptr> + template = nullptr> IValue(at::ArrayRef v); - template < - class T, - enable_if_ivalue_constructible = nullptr> + template = nullptr> IValue(const std::vector& v); - template + template IValue(std::array v); // GenericDict IValue(c10::Dict v); - bool isGenericDict() const { return Tag::GenericDict == tag; } + bool isGenericDict() const { + return Tag::GenericDict == tag; + } c10::Dict toGenericDict() &&; - c10::Dict toGenericDict() const &; + c10::Dict toGenericDict() const&; - template + template IValue(c10::Dict v); - template - /// \cond DOXYGEN_CANNOT_HANDLE_CONSTRUCTORS_WITH_MACROS_SO_EXCLUDE_THIS_LINE_FROM_DOXYGEN - C10_DEPRECATED_MESSAGE("IValues based on std::unordered_map are slow and deprecated. Please use c10::Dict instead.") - /// \endcond - IValue(std::unordered_map v); + template + /// \cond + /// DOXYGEN_CANNOT_HANDLE_CONSTRUCTORS_WITH_MACROS_SO_EXCLUDE_THIS_LINE_FROM_DOXYGEN + C10_DEPRECATED_MESSAGE( + "IValues based on std::unordered_map are slow and deprecated. Please use c10::Dict instead.") + /// \endcond + IValue(std::unordered_map v); - template < - class T, - enable_if_ivalue_constructible = nullptr> + template = nullptr> IValue(c10::optional v); IValue(c10::nullopt_t); // ClassType IValue(c10::intrusive_ptr v); - bool isObject() const { return tag == Tag::Object; } + bool isObject() const { + return tag == Tag::Object; + } c10::intrusive_ptr toObject() &&; - c10::intrusive_ptr toObject() const & ; + c10::intrusive_ptr toObject() const&; const ivalue::Object& toObjectRef() const; torch::jit::Module toModule() const; @@ -492,19 +600,23 @@ struct CAFFE2_API IValue final { // PyObject IValue(c10::intrusive_ptr v); - bool isPyObject() const { return tag == Tag::PyObject; } + bool isPyObject() const { + return tag == Tag::PyObject; + } c10::intrusive_ptr toPyObjectHolder() &&; - c10::intrusive_ptr toPyObjectHolder() const &; + c10::intrusive_ptr toPyObjectHolder() const&; PyObject* toPyObject() const; // Enum explicit IValue(c10::intrusive_ptr v); - bool isEnum() const { return tag == Tag::Enum; } + bool isEnum() const { + return tag == Tag::Enum; + } c10::intrusive_ptr toEnumHolder() &&; - c10::intrusive_ptr toEnumHolder() const &; + c10::intrusive_ptr toEnumHolder() const&; // None - IValue() : payload{0}, tag(Tag::None), is_intrusive_ptr(false) {} + IValue() : tag(Tag::None), is_intrusive_ptr(false) {} bool isNone() const { return Tag::None == tag; } @@ -520,9 +632,8 @@ struct CAFFE2_API IValue final { } // Scalar, which gets encoded as either an Int or a Double - IValue(at::Scalar s) - : IValue() { - if(s.isFloatingPoint()) { + IValue(at::Scalar s) : IValue() { + if (s.isFloatingPoint()) { *this = s.toDouble(); } else { *this = s.toLong(); @@ -532,50 +643,59 @@ struct CAFFE2_API IValue final { return isDouble() || isInt(); } at::Scalar toScalar() const { - if(isDouble()) + if (isDouble()) return toDouble(); - else if(isInt()) + else if (isInt()) return toInt(); throw std::runtime_error("IValue is not a Scalar"); } // Device - IValue(c10::Device d) - : tag(Tag::Device), is_intrusive_ptr(false) { - payload.as_device.type = d.type(); - payload.as_device.index = d.index(); + IValue(c10::Device d) : tag(Tag::Device), is_intrusive_ptr(false) { + payload.u.as_device.type = d.type(); + payload.u.as_device.index = d.index(); + } + bool isDevice() const { + return Tag::Device == tag; } - bool isDevice() const { return Tag::Device == tag; } c10::Device toDevice() const { AT_ASSERT(isDevice()); - return c10::Device(payload.as_device.type, payload.as_device.index); + return c10::Device(payload.u.as_device.type, payload.u.as_device.index); + } + + //Stream + IValue(c10::Stream stream) + : tag(Tag::Stream), is_intrusive_ptr(false) { + payload.u.as_int = stream.pack(); } + c10::Stream toStream() &&; + c10::Stream toStream() const &; + bool isStream() const { return Tag::Stream == tag; } // ScalarType IValue(ScalarType t) - : IValue(static_cast::type>(t)) {} + : IValue(static_cast::type>(t)) {} at::ScalarType toScalarType() const { return static_cast(toInt()); } // Layout IValue(Layout l) - : IValue(static_cast::type>(l)) {} + : IValue(static_cast::type>(l)) {} at::Layout toLayout() const { return static_cast(toInt()); } // MemoryFormat IValue(MemoryFormat m) - : IValue(static_cast::type>(m)) {} + : IValue(static_cast::type>(m)) {} at::MemoryFormat toMemoryFormat() const { return static_cast(toInt()); } // QScheme - IValue(at::QScheme qscheme) - : tag(Tag::Int), is_intrusive_ptr(false) { - payload.as_int = static_cast(qscheme); + IValue(at::QScheme qscheme) : tag(Tag::Int), is_intrusive_ptr(false) { + payload.u.as_int = static_cast(qscheme); } at::QScheme toQScheme() const { @@ -583,33 +703,35 @@ struct CAFFE2_API IValue final { } // Dimname - IValue(at::Dimname dimname) - : IValue(dimname.symbol().toQualString()) {} + IValue(at::Dimname dimname) : IValue(dimname.symbol().toQualString()) {} at::Dimname toDimname() const { return at::Dimname::fromSymbol(Symbol::fromQualString(toStringRef())); } // Generator - IValue(at::Generator g) - : tag(Tag::Generator), is_intrusive_ptr(g.defined()) { + IValue(at::Generator g) : tag(Tag::Generator), is_intrusive_ptr(g.defined()) { // Note: the undefined generator is not refcounted, so while it // is tagged as a generator, is_intrusive_ptr is set to false. // This is not an optional optimization: our incref call // *will not* do the right thing when called on an // undefined generator. - payload.as_intrusive_ptr = g.unsafeReleaseGeneratorImpl(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(g.unsafeReleaseGeneratorImpl()); + } + bool isGenerator() const { + return Tag::Generator == tag; } - bool isGenerator() const { return Tag::Generator == tag; } at::Generator toGenerator() &&; - at::Generator toGenerator() const &; + at::Generator toGenerator() const&; // for debugging std::string tagKind() const { - switch(tag) { - #define DEFINE_CASE(x) case Tag::x: return #x; + switch (tag) { +#define DEFINE_CASE(x) \ + case Tag::x: \ + return #x; TORCH_FORALL_TAGS(DEFINE_CASE) - #undef DEFINE_CASE +#undef DEFINE_CASE } return "InvalidTag(" + c10::guts::to_string(static_cast(tag)) + ")"; } @@ -621,24 +743,20 @@ struct CAFFE2_API IValue final { // since they are simpler to understand // Note: if you get linker errors saying one of these is missing, - // change it to ... && = delete; and you will see better error messages for why - // However, we cannot commit this because some compiler versions barf on it. - template + // change it to ... && = delete; and you will see better error messages for + // why However, we cannot commit this because some compiler versions barf on + // it. + template T to() &&; - template - T to() const &; + template + T to() const&; - // ToOptional: convert a IValue to the Optional obj that accepts both T and None - template + // ToOptional: convert a IValue to the Optional obj that accepts both T and + // None + template optional toOptional(); - - /// @private [doxygen private] - /// Only for use in generated code. - OptionalArray toOptionalIntArray(); - - /// @private [doxygen private] - /// Only for use in generated code. - OptionalArray toOptionalDoubleArray(); + template + optional toOptional() const; /// @private [doxygen private] /// this is a shallow comparison of two IValues to test the object identity @@ -664,19 +782,24 @@ struct CAFFE2_API IValue final { // This is different from `repr()` in that there is no expectation that we can // exactly reconstruct an IValue from the output; feel free to use a // concise/pretty form - CAFFE2_API friend std::ostream& operator<<( + TORCH_API friend std::ostream& operator<<( std::ostream& out, const IValue& v); bool isPtrType() const { - return is_intrusive_ptr; + return (isTensor() && payload.as_tensor.defined()) || is_intrusive_ptr; } /// @private [doxygen private] const void* internalToPointer() const { TORCH_INTERNAL_ASSERT( isPtrType(), "Can only call internalToPointer() for pointer types"); - return payload.as_intrusive_ptr; + if (isTensor()) { + return payload.as_tensor.unsafeGetTensorImpl(); + } else { + return payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton() + ? payload.u.as_intrusive_ptr : nullptr; + } } TypePtr type() const; @@ -685,22 +808,25 @@ struct CAFFE2_API IValue final { struct HashAliasedIValue { size_t operator()(const IValue& val) const { if (val.isTensor()) { - return reinterpret_cast(val.toTensor().storage().unsafeGetStorageImpl()); + return reinterpret_cast( + val.toTensor().storage().unsafeGetStorageImpl()); } // If it is not a Tensor, then two mutable IValues alias each other only // if they are the same pointer. - return val.payload.as_int; + return val.payload.u.as_int; } }; struct CompAliasedIValues { bool operator()(const IValue& lhs, const IValue& rhs) const { - return lhs.isAliasOf(rhs); + return lhs.isAliasOf(rhs); } }; - using HashAliasedIValues = std::unordered_set; - using HashAliasedIValueMap = std::unordered_map; + using HashAliasedIValues = + std::unordered_set; + using HashAliasedIValueMap = + std::unordered_map; // Chechs if this and rhs has a subvalues in common. // [t1,t2] and [t2, t3] returns true. @@ -712,12 +838,15 @@ struct CAFFE2_API IValue final { // Apply visitor to every subvalue. // TODO: There are several places that recurse over IValue. This is fragile. // This visitor should be used to recurse over ivalues. - void visit(const std::function& visitor) const; + void visit(const std::function& visitor) const; IValue deepcopy() const; - IValue deepcopy( - HashAliasedIValueMap& memo) const; + IValue deepcopy(HashAliasedIValueMap& memo) const; private: + static c10::intrusive_ptr_target* null_to_undefined_tensor(c10::intrusive_ptr_target* p) { + return p ? p : static_cast(c10::UndefinedTensorImpl::singleton()); + } + static bool ptrEqual(const IValue& lhs, const IValue& rhs); // NOTE: IValue tags are intentionally private. In the future we may encode // this value different (e.g. using NaN boxing), and this would make it more @@ -731,30 +860,86 @@ struct CAFFE2_API IValue final { #undef DEFINE_TAG }; - template> + template < + class T, + class NullType = c10::detail::intrusive_target_default_null_type> c10::intrusive_ptr moveToIntrusivePtr(); - template> + template < + typename T, + class NullType = c10::detail::intrusive_target_default_null_type> c10::intrusive_ptr toIntrusivePtr() const; - void clearToNone() { - payload.as_int = 0; + void destroy() { + // We carefully construct this call to both 1) avoid UB by using + // the "wrong" one of as_tensor and as_intrusive_ptr and 2) enable + // the compiler to generate the same code for each case. It is + // surprisingly difficult to get this right. + if (isTensor() || is_intrusive_ptr) { + c10::intrusive_ptr_target* p = isTensor() ? payload.as_tensor.unsafeGetTensorImpl() : payload.u.as_intrusive_ptr; + c10::intrusive_ptr::reclaim(p); + // No need to make this destructor call! + // payload.as_tensor.~Tensor(); + } + } + + C10_ALWAYS_INLINE void moveFrom(IValue&& rhs) noexcept { + if (rhs.isTensor()) { + new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor)); + // As far as I can tell, omitting the usual explicit destructor call + // is not UB in and of itself, and it's a slight perf win. The + // destructor is a no-op, because the moved-from Tensor is + // effectively an intrusive_ptr in the null state, so we don't need + // the behavior for correctness reasons either. Leaving this + // explanatory comment, including commented-out destructor call, to + // make this abundantly clear. + // + // rhs.payload.as_tensor.~Tensor(); + } else { + payload.u = rhs.payload.u; + } + tag = rhs.tag; + is_intrusive_ptr = rhs.is_intrusive_ptr; + rhs.clearToNone(); + } + + void clearToNone() noexcept { + payload.u.as_int = 0; tag = Tag::None; is_intrusive_ptr = false; } union Payload { - int64_t as_int; - double as_double; - bool as_bool; - c10::intrusive_ptr_target* as_intrusive_ptr; - struct { - DeviceType type; - DeviceIndex index; - } as_device; + // We use a nested union here so that we can make the copy easy + // and efficient in the non-tensor (i.e., trivially copyable) + // case. Specifically, we do not have to do a switch-on-tag to + // figure out which union member to assign; we can just use + // TriviallyCopyablePayload::operator=. + union TriviallyCopyablePayload { + TriviallyCopyablePayload() : as_int(0) {} + int64_t as_int; + double as_double; + bool as_bool; + // Invariant: never nullptr; null state is represented as + // c10::UndefinedTensorImpl::singleton() for consistency of + // representation with Tensor. + c10::intrusive_ptr_target* as_intrusive_ptr; + struct { + DeviceType type; + DeviceIndex index; + } as_device; + } u; + at::Tensor as_tensor; + Payload() : u() {} + ~Payload() {} }; - IValue(Payload p, Tag t, bool i) - : payload(p), tag(t), is_intrusive_ptr(i) {} + IValue(const Payload& p, Tag t, bool i) : tag(t), is_intrusive_ptr(i) { + if (isTensor()) { + new (&payload.as_tensor) at::Tensor(p.as_tensor); + } else { + payload.u = p.u; + } + } Payload payload; Tag tag; @@ -762,45 +947,49 @@ struct CAFFE2_API IValue final { friend struct WeakIValue; }; -struct CAFFE2_API WeakIValue final { - WeakIValue() - : payload{0} - , tag(IValue::Tag::None) - , is_intrusive_ptr(false) {} +struct TORCH_API WeakIValue final { + WeakIValue() : tag(IValue::Tag::None), is_intrusive_ptr(false) {} WeakIValue(const WeakIValue& rhs) : payload(rhs.payload), tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) { - if (is_intrusive_ptr) { + if (is_intrusive_ptr && payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr); } } WeakIValue(const IValue& rhs) - : payload(rhs.payload), - tag(rhs.tag), + : tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) { + if (rhs.isTensor()) { + payload.as_intrusive_ptr = rhs.unsafeToTensorImpl(); + is_intrusive_ptr = true; + } else { + payload = rhs.payload.u; + } if (is_intrusive_ptr) { - c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr); + if (payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { + c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr); + } } } WeakIValue(WeakIValue&& rhs) noexcept : WeakIValue() { swap(rhs); } ~WeakIValue() { - if (is_intrusive_ptr) { + if (is_intrusive_ptr && payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { c10::raw::weak_intrusive_ptr::decref(payload.as_intrusive_ptr); } } - WeakIValue & operator=(WeakIValue && rhs) & noexcept { + WeakIValue& operator=(WeakIValue&& rhs) & noexcept { WeakIValue(std::move(rhs)).swap(*this); // this also sets rhs to None return *this; } - WeakIValue & operator=(WeakIValue const & rhs) & { + WeakIValue& operator=(WeakIValue const& rhs) & { WeakIValue(rhs).swap(*this); return *this; } - void swap(WeakIValue & rhs) noexcept { + void swap(WeakIValue& rhs) noexcept { std::swap(payload, rhs.payload); std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr); std::swap(tag, rhs.tag); @@ -813,17 +1002,33 @@ struct CAFFE2_API WeakIValue final { IValue lock() const { if (!is_intrusive_ptr) { - return IValue(payload, tag, false); + IValue::Payload newPayload; + newPayload.u = payload; + return IValue(newPayload, tag, false); } - auto temp = c10::weak_intrusive_ptr::reclaim( - payload.as_intrusive_ptr); - IValue::Payload pl; - pl.as_intrusive_ptr = temp.lock().release(); - temp.release(); - if (!pl.as_intrusive_ptr) { - return IValue(); + if (IValue::Tag::Tensor == tag) { + auto temp = c10::weak_intrusive_ptr::reclaim( + static_cast(payload.as_intrusive_ptr)); + c10::intrusive_ptr ip(temp.lock()); + temp.release(); + if (!ip) { + return IValue(); + } else { + return IValue(at::Tensor(std::move(ip))); + } } else { - return IValue(pl, tag, true); + auto temp = c10::weak_intrusive_ptr::reclaim( + payload.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton() + ? nullptr + : payload.as_intrusive_ptr); + IValue::Payload pl; + pl.u.as_intrusive_ptr = temp.lock().release(); + temp.release(); + if (!pl.u.as_intrusive_ptr) { + return IValue(); + } else { + return IValue(pl, tag, true); + } } } @@ -831,7 +1036,7 @@ struct CAFFE2_API WeakIValue final { if (!is_intrusive_ptr) { return 1; } - auto temp = c10::weak_intrusive_ptr::reclaim( + auto temp = c10::weak_intrusive_ptr::reclaim( payload.as_intrusive_ptr); size_t result = temp.use_count(); temp.release(); @@ -842,7 +1047,7 @@ struct CAFFE2_API WeakIValue final { if (!is_intrusive_ptr) { return 1; } - auto temp = c10::weak_intrusive_ptr::reclaim( + auto temp = c10::weak_intrusive_ptr::reclaim( payload.as_intrusive_ptr); size_t result = temp.weak_use_count(); temp.release(); @@ -852,8 +1057,9 @@ struct CAFFE2_API WeakIValue final { return payload.as_int; } -private: - IValue::Payload payload; + private: + using Payload = IValue::Payload::TriviallyCopyablePayload; + Payload payload; IValue::Tag tag; bool is_intrusive_ptr; }; @@ -870,11 +1076,12 @@ struct TORCH_API StrongTypePtr { std::shared_ptr type_; }; -TORCH_API ska::flat_hash_map& getCustomClassTypeMap(); +TORCH_API ska::flat_hash_map& +getCustomClassTypeMap(); -template -c10::ClassTypePtr getCustomClassType() { - auto tmap = c10::getCustomClassTypeMap(); +template +c10::ClassTypePtr getCustomClassTypeImpl() { + auto& tmap = c10::getCustomClassTypeMap(); auto res = tmap.find(std::type_index(typeid(T))); if (res == tmap.end()) { throw c10::Error("Can't find class id in custom class type map", ""); @@ -882,14 +1089,18 @@ c10::ClassTypePtr getCustomClassType() { return res->second; } -template -inline bool isCustomClassRegistered() { - auto tmap = c10::getCustomClassTypeMap(); - return tmap.find(std::type_index(typeid(T))) != tmap.end(); +template +const c10::ClassTypePtr& getCustomClassType() { + // Classes are never unregistered from getCustomClassTypeMap and the + // hash lookup can be a hot path, so just cache. + // For the same reason, it's fine If this ends up getting duplicated across + // DSO boundaries for whatever reason. + static c10::ClassTypePtr cache = getCustomClassTypeImpl(); + return cache; } TORCH_API std::unordered_map>& getClassConverter(); -} +} // namespace c10 #include diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index e0f143a030cf1..e4d68ccab5918 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -3,16 +3,18 @@ #include #include +#include +#include #include #include +#include +#include #include +#include #include #include #include -#include -#include -#include -#include +#include namespace torch { namespace jit { @@ -26,13 +28,14 @@ struct IValue; struct ClassType; struct TupleType; struct EnumType; +struct InferredType; // For custom class __init__ registration, we need to pass in a function // that looks like this: [](IValue x, args...) -// However, make_boxed_from_unboxed_functor.h automatically sets the input types of the function -// by introspecting the types of the functor (which is IValue in this case). -// However, we need the type it binds to be Foo. +// However, make_boxed_from_unboxed_functor.h automatically sets the input types +// of the function by introspecting the types of the functor (which is IValue in +// this case). However, we need the type it binds to be Foo. // Instead, we pass in a lambda [](ivalue_holder x, args...) from // which getTypePtr can recover the original class pointer. @@ -42,26 +45,32 @@ struct tagged_capsule { IValue ivalue; }; -template +template c10::intrusive_ptr IValue::moveToIntrusivePtr() { - auto t = c10::intrusive_ptr::reclaim(static_cast(payload.as_intrusive_ptr)); + auto t = c10::intrusive_ptr::reclaim( + payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton() + ? NullType::singleton() + : static_cast(payload.u.as_intrusive_ptr)); clearToNone(); return t; } -template +template c10::intrusive_ptr IValue::toIntrusivePtr() const { - auto r = c10::intrusive_ptr::reclaim(static_cast(payload.as_intrusive_ptr)); + auto r = c10::intrusive_ptr::reclaim( + payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton() + ? NullType::singleton() + : static_cast(payload.u.as_intrusive_ptr)); auto p = r; r.release(); return p; } -template +template intrusive_ptr static_intrusive_pointer_cast(intrusive_ptr r) { return intrusive_ptr::reclaim(static_cast(r.release())); } -template +template intrusive_ptr dynamic_intrusive_pointer_cast(intrusive_ptr r) { return intrusive_ptr::reclaim(dynamic_cast(r.release())); } @@ -70,7 +79,7 @@ inline c10::intrusive_ptr IValue::toFuture() && { AT_ASSERT(isFuture(), "Expected Future but got ", tagKind()); return moveToIntrusivePtr(); } -inline c10::intrusive_ptr IValue::toFuture() const & { +inline c10::intrusive_ptr IValue::toFuture() const& { AT_ASSERT(isFuture(), "Expected Future but got ", tagKind()); return toIntrusivePtr(); } @@ -78,7 +87,7 @@ inline c10::intrusive_ptr IValue::toRRef() && { AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind()); return moveToIntrusivePtr(); } -inline c10::intrusive_ptr IValue::toRRef() const & { +inline c10::intrusive_ptr IValue::toRRef() const& { AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind()); return toIntrusivePtr(); } @@ -86,7 +95,7 @@ inline c10::intrusive_ptr IValue::toQuantizer() && { AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind()); return moveToIntrusivePtr(); } -inline c10::intrusive_ptr IValue::toQuantizer() const & { +inline c10::intrusive_ptr IValue::toQuantizer() const& { AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind()); return toIntrusivePtr(); } @@ -94,7 +103,7 @@ inline c10::intrusive_ptr IValue::toString() && { AT_ASSERT(isString(), "Expected String but got ", tagKind()); return moveToIntrusivePtr(); } -inline c10::intrusive_ptr IValue::toString() const & { +inline c10::intrusive_ptr IValue::toString() const& { AT_ASSERT(isString(), "Expected String but got ", tagKind()); return toIntrusivePtr(); } @@ -102,15 +111,17 @@ inline c10::intrusive_ptr IValue::toObject() && { AT_ASSERT(isObject(), "Expected Object but got ", tagKind()); return moveToIntrusivePtr(); } -inline c10::intrusive_ptr IValue::toObject() const & { +inline c10::intrusive_ptr IValue::toObject() const& { AT_ASSERT(isObject(), "Expected Object but got ", tagKind()); return toIntrusivePtr(); } -inline c10::intrusive_ptr IValue::toPyObjectHolder() && { +inline c10::intrusive_ptr IValue:: + toPyObjectHolder() && { TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind()); return moveToIntrusivePtr(); } -inline c10::intrusive_ptr IValue::toPyObjectHolder() const & { +inline c10::intrusive_ptr IValue::toPyObjectHolder() + const& { TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind()); return toIntrusivePtr(); } @@ -118,31 +129,62 @@ inline c10::intrusive_ptr IValue::toEnumHolder() && { TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind()); return moveToIntrusivePtr(); } -inline c10::intrusive_ptr IValue::toEnumHolder() const & { +inline c10::intrusive_ptr IValue::toEnumHolder() const& { TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind()); return toIntrusivePtr(); } inline at::Tensor IValue::toTensor() && { AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind()); - return at::Tensor(moveToIntrusivePtr()); + auto result = std::move(payload.as_tensor); + // As far as I can tell, omitting the usual explicit destructor call + // is not UB in and of itself, and it's a slight perf win. The + // destructor is a no-op, because the moved-from Tensor is + // effectively an intrusive_ptr in the null state, so we don't need + // the behavior for correctness reasons either. Leaving this + // explanatory comment, including commented-out destructor call, to + // make this abundantly clear. + // + // payload.as_tensor.~Tensor(); + clearToNone(); + return result; +} +inline at::Tensor& IValue::toTensor() & { + AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind()); + return payload.as_tensor; } -inline at::Tensor IValue::toTensor() const & { +inline const at::Tensor& IValue::toTensor() const& { AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind()); - return at::Tensor(toIntrusivePtr()); + return payload.as_tensor; +} +inline c10::Storage IValue::toStorage() && { + AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind()); + return c10::Storage( + moveToIntrusivePtr()); +} +inline c10::Storage IValue::toStorage() const& { + AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind()); + return c10::Storage(toIntrusivePtr()); +} +inline c10::Stream IValue::toStream() && { + return c10::Stream::unpack(payload.u.as_int); +} +inline c10::Stream IValue::toStream() const& { + return c10::Stream::unpack(payload.u.as_int); } inline c10::intrusive_ptr IValue::toBlob() && { AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind()); return moveToIntrusivePtr(); } -inline c10::intrusive_ptr IValue::toBlob() const & { +inline c10::intrusive_ptr IValue::toBlob() const& { AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind()); - return toIntrusivePtr();; + return toIntrusivePtr(); + ; } inline c10::intrusive_ptr IValue::toCapsule() && { TORCH_INTERNAL_ASSERT(isCapsule()); return moveToIntrusivePtr(); } -inline c10::intrusive_ptr IValue::toCapsule() const & { +inline c10::intrusive_ptr IValue::toCapsule() const& { TORCH_INTERNAL_ASSERT(isCapsule()); return toIntrusivePtr(); } @@ -150,43 +192,45 @@ inline at::Generator IValue::toGenerator() && { AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind()); return at::Generator(moveToIntrusivePtr()); } -inline at::Generator IValue::toGenerator() const & { +inline at::Generator IValue::toGenerator() const& { AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind()); return at::Generator(toIntrusivePtr()); } namespace ivalue { -void CAFFE2_API checkCustomClassType(TypePtr expected_type, TypePtr actual_type); +void TORCH_API +checkCustomClassType(const Type* expected_type, const Type* actual_type); template using Shared = c10::intrusive_ptr; // string -struct CAFFE2_API ConstantString final : c10::intrusive_ptr_target { +struct TORCH_API ConstantString final : c10::intrusive_ptr_target { private: const std::string str_; + public: - ConstantString(std::string str) - : str_(std::move(str)) {} + ConstantString(std::string str) : str_(std::move(str)) {} static c10::intrusive_ptr create(std::string str_); - const std::string & string() const { + const std::string& string() const { return str_; } - operator const std::string & () const { + operator const std::string&() const { return string(); } - CAFFE2_API friend std::ostream& operator<<( + TORCH_API friend std::ostream& operator<<( std::ostream& out, const ConstantString& v); }; struct Future; -struct CAFFE2_API Tuple : c10::intrusive_ptr_target { +struct TORCH_API Tuple : c10::intrusive_ptr_target { private: std::vector elements_; - mutable std::shared_ptr type_; // lazily computed for unnamed tuples + mutable std::shared_ptr + type_; // lazily computed for unnamed tuples public: // named tuples have additional type information, so we @@ -202,10 +246,11 @@ struct CAFFE2_API Tuple : c10::intrusive_ptr_target { template static c10::intrusive_ptr create(Args... elements_) { - return c10::make_intrusive(std::vector{IValue(elements_)...}); + return c10::make_intrusive( + std::vector{IValue(elements_)...}); } - const std::vector& elements() const & { + const std::vector& elements() const& { return elements_; } operator const std::vector&() const { @@ -224,11 +269,17 @@ struct CAFFE2_API Tuple : c10::intrusive_ptr_target { } std::shared_ptr type() const; - CAFFE2_API friend bool operator==(const ivalue::Tuple& lhs, const ivalue::Tuple& rhs); + static size_t hash(const Tuple& t) { + return c10::get_hash(t.elements()); + } + + TORCH_API friend bool operator==( + const ivalue::Tuple& lhs, + const ivalue::Tuple& rhs); private: Tuple(std::vector elements, std::shared_ptr type = nullptr) - : elements_(std::move(elements)), type_(std::move(type)) {} + : elements_(std::move(elements)), type_(std::move(type)) {} friend class c10::intrusive_ptr; }; @@ -236,7 +287,7 @@ struct CAFFE2_API Tuple : c10::intrusive_ptr_target { struct Object; struct PyObjectHolder; struct EnumHolder; -} +} // namespace ivalue // Future struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { @@ -251,7 +302,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { public: explicit Future(TypePtr type) : type_(type) {} - struct CAFFE2_API FutureError final : public std::exception { + struct TORCH_API FutureError final : public std::exception { explicit FutureError(std::string&& error_msg_) : error_msg(std::move(error_msg_)) {} @@ -267,18 +318,22 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { /** * Wait on the future until it completes. */ - virtual void wait() { + void wait() { std::unique_lock lock(mutex_); while (!completed_) { finished_cv_.wait(lock); } + + if (!eptr_) { + postWaitHook(value_); + } } /** * Wait on the future until it completes and throw an * exception if an error exists. */ - virtual void waitAndThrow() { + void waitAndThrow() { std::unique_lock lock(mutex_); while (!completed_) { finished_cv_.wait(lock); @@ -287,12 +342,14 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { if (eptr_) { std::rethrow_exception(eptr_); } + + postWaitHook(value_); } /** * Explicitly mark the future as completed with the output value. */ - virtual void markCompleted(IValue value) { + void markCompleted(IValue value) { std::unique_lock lock(mutex_); TORCH_CHECK( !completed(), @@ -301,6 +358,8 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { completed_ = true; value_ = std::move(value); + postMarkCompletedHook(value_); + std::vector> cbs; cbs.swap(callbacks_); lock.unlock(); @@ -312,7 +371,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { } void markCompleted() { - markCompleted(IValue {}); + markCompleted(IValue{}); } void setError(std::exception_ptr eptr) { @@ -336,7 +395,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { } // Get the result of the current future. - virtual IValue value() { + IValue value() { std::unique_lock lock(mutex_); AT_ASSERT(completed()); if (eptr_) { @@ -347,7 +406,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { // This accessor should only be used if we know that the future is // completed() with no error. - virtual const IValue& constValue() { + const IValue& constValue() const { std::unique_lock lock(mutex_); AT_ASSERT(completed()); AT_ASSERT(!eptr_); @@ -360,8 +419,9 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { * If the future has already completed, * this function will execute the callback immediately. */ - virtual void addCallback(std::function callback) { + void addCallback(std::function callback) { std::unique_lock lock(mutex_); + callback = wrapCallback(std::move(callback)); if (completed()) { lock.unlock(); callback(); @@ -375,38 +435,34 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { * value of the callback. This is necessary when the callback provider needs * to know for sure when the callback has finished. */ - virtual c10::intrusive_ptr then( + c10::intrusive_ptr then( std::function callback, TypePtr type) { - auto fut = c10::make_intrusive(type); - // Cannot move capture std::function in lambda, because it cannot deduce - // the template type for std::function. Hence use std::bind to explicitly - // specify types. - addCallback(std::bind( - [fut](std::function cb) { + auto fut = createInstance(std::move(type)); + addCallback( + [fut, cb = std::move(callback)]() { try { fut->markCompleted(cb()); - } catch (std::exception& e) { + } catch (std::exception&) { fut->setError(std::current_exception()); } - }, - std::move(callback))); + }); return fut; } // Tries to retrieve the error message from std::exception_ptr. - std::string tryRetrieveErrorMessage() { + std::string tryRetrieveErrorMessage() const { TORCH_CHECK(hasError(), "No error present on the future."); std::unique_lock lock(mutex_); return tryRetrieveErrorMessageInternal(eptr_); } // Check if the current future has completed - virtual bool completed() const{ + bool completed() const { return completed_; } - virtual bool hasValue() const { + bool hasValue() const { std::unique_lock lock(mutex_); return completed_ && !eptr_; } @@ -421,7 +477,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { return eptr_; } - CAFFE2_API friend std::ostream& operator<<( + TORCH_API friend std::ostream& operator<<( std::ostream& out, const Future& v); @@ -429,6 +485,43 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { return type_; } + protected: + // This hook is called by this class's then() method when it prepares the + // instance it returns to the caller. It should be overridden by subclasses so + // that they can produce an instace of their own type. + virtual c10::intrusive_ptr createInstance(at::TypePtr type) { + return c10::make_intrusive(type); + } + + // This hook will be called by this class (the superclass) when the future is + // marked completed _with a value_ (hence not in case of error). This is done + // right away, while the mutex is still held, before any callbacks are run. + // It allows subclasses to further update their state if they so need. For + // example the CUDAFuture subclass uses it to determine what devices the value + // resides on and record an event in those devices' current streams. + virtual void postMarkCompletedHook(const at::IValue& value) {} + + // This hook will be called by the addCallback() and the then() methods before + // storing the callback for later execution (or before running it inline if + // the future is already complete). Note that this method could thus be called + // while the future is _not_ yet complete. By default this method does nothing + // but subclasses can override this method to add functionality. For example + // the CUDAFuture subclass ensures the callback runs with CUDA streams which + // are synchronized with the events recorded in the I/O streams. + virtual std::function wrapCallback( + std::function callback) { + return callback; + } + + // This hook will be called by this class after a user thread has completed + // waiting on a successful future. It will thus not be called if the future + // completes with an error. It will also not be called if the user accesses + // the future's value without synchronization. Subclasses can override this + // to add some synchronization to the wait. For example, the CUDAFuture + // subclass ensures the user's current CUDA streams synchronize with the I/O + // events stored by the future. + virtual void postWaitHook(const at::IValue& value) {} + private: void setErrorInternal( std::exception_ptr eptr, @@ -437,6 +530,8 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { completed_ = true; eptr_ = std::move(eptr); + // Do not call postMarkCompletedHook() here as there isn't any value. + std::vector> cbs; cbs.swap(callbacks_); lock.unlock(); @@ -448,7 +543,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { } // Tries to retrieve the error message from std::exception_ptr. - std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) { + std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) const { try { std::rethrow_exception(eptr); } catch (const std::exception& e) { @@ -470,11 +565,11 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { // Input is a list of Futures with the same target type. // Output is a Future to the List of completed Futures. -CAFFE2_API intrusive_ptr collectAll( +TORCH_API intrusive_ptr collectAll( c10::List> srcs); // Input is a List of Futures with the same target type. // Output is a Future that will be updated with a seen value. -CAFFE2_API intrusive_ptr collectAny( +TORCH_API intrusive_ptr collectAny( c10::List> srcs); // User-defined object. @@ -571,25 +666,33 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target { struct ivalue::PyObjectHolder : c10::intrusive_ptr_target { public: virtual PyObject* getPyObject() = 0; - virtual ~PyObjectHolder() {}; + virtual c10::InferredType tryToInferType() = 0; + virtual IValue toIValue(const TypePtr& type, c10::optional N = c10::nullopt) = 0; + virtual std::string toStr() = 0; + + virtual ~PyObjectHolder(){}; }; struct ivalue::EnumHolder : c10::intrusive_ptr_target { public: EnumHolder(std::shared_ptr type, std::string name, IValue value) - : type_(std::move(type)), name_(std::move(name)), value_(std::move(value)) {} + : type_(std::move(type)), + name_(std::move(name)), + value_(std::move(value)) {} bool is(const ivalue::EnumHolder& rhs) { return *this == rhs; } - friend bool operator==(const ivalue::EnumHolder&lhs, const ivalue::EnumHolder& rhs); + friend bool operator==( + const ivalue::EnumHolder& lhs, + const ivalue::EnumHolder& rhs); - CAFFE2_API friend std::ostream& operator<<( + TORCH_API friend std::ostream& operator<<( std::ostream& out, const EnumHolder& v); - CAFFE2_API const std::string qualifiedClassName() const; + TORCH_API const std::string qualifiedClassName() const; const std::string unqualifiedClassName() const; @@ -605,7 +708,7 @@ struct ivalue::EnumHolder : c10::intrusive_ptr_target { return type_; } -private: + private: std::shared_ptr type_; std::string name_; IValue value_; @@ -628,23 +731,27 @@ using _guarded_unsigned_long = std::conditional_t< inline const ivalue::Object& IValue::toObjectRef() const { AT_ASSERT(isObject(), "Expected Object but got ", tagKind()); - return *static_cast(payload.as_intrusive_ptr); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "Attempted to create null reference"); + return *static_cast(payload.u.as_intrusive_ptr); } // note: when adding a DEFINE_TO case here you should also add a // toX method to IValue. These named methods are much more discoverable // than the to templated function. -#define DEFINE_TO(type, method_name) \ -template<> \ -inline type IValue::to() && { \ - return std::move(*this).method_name(); \ -} \ -template<> \ -inline type IValue::to() const & { \ - return this->method_name(); \ -} +#define DEFINE_TO(type, method_name) \ + template <> \ + inline type IValue::to()&& { \ + return std::move(*this).method_name(); \ + } \ + template <> \ + inline type IValue::to() const& { \ + return this->method_name(); \ + } + DEFINE_TO(at::Tensor, toTensor) +DEFINE_TO(at::Storage, toStorage) +DEFINE_TO(c10::Stream, toStream) DEFINE_TO(float, toDouble) DEFINE_TO(double, toDouble) DEFINE_TO(unsigned char, toInt) @@ -693,11 +800,11 @@ struct _fake_type {}; // The _fake_type parameter allows us to overload // based on the return type. template -// TODO this is deprecated but we don't throw a warning because a lot of ops in native_functions.yaml still return std::vector. -//C10_DEPRECATED_MESSAGE("IValues based on std::vector are potentially slow and deprecated. Please use torch::List instead.") -std::vector generic_to( - IValue ivalue, - _fake_type>) { +// TODO this is deprecated but we don't throw a warning because a lot of ops in +// native_functions.yaml still return std::vector. +// C10_DEPRECATED_MESSAGE("IValues based on std::vector are potentially slow +// and deprecated. Please use torch::List instead.") +std::vector generic_to(IValue ivalue, _fake_type>) { // We need to do a deep copy of the vector because there might be other // references to this same IValue that also use the list. We can't just // move the elements out. @@ -712,56 +819,86 @@ std::vector generic_to( template c10::intrusive_ptr IValue::toCustomClass() && { - static_assert(std::is_base_of::value == true, - "toCustomClass requires that template parameter T must inherit " - "from torch::CustomClassHolder"); + static_assert( + std::is_base_of::value == true, + "toCustomClass requires that template parameter T must inherit " + "from torch::CustomClassHolder"); auto obj = toObject(); - TORCH_CHECK(obj->slots().size() == 1, - "Tried to cast IValue to custom class but it did " - "not contain a custom class!"); - auto expected_type = c10::getCustomClassType>(); - ivalue::checkCustomClassType(expected_type, type()); - auto userObj = c10::static_intrusive_pointer_cast(obj->getSlot(0).toCapsule()); + TORCH_CHECK( + obj->slots().size() == 1, + "Tried to cast IValue to custom class but it did " + "not contain a custom class!"); + const Type* expected_type = c10::getCustomClassType>().get(); + ivalue::checkCustomClassType(expected_type, type().get()); + auto userObj = + c10::static_intrusive_pointer_cast(obj->getSlot(0).toCapsule()); return userObj; } template -c10::intrusive_ptr IValue::toCustomClass() const & { - static_assert(std::is_base_of::value == true, - "toCustomClass requires that template parameter T must inherit " - "from torch::CustomClassHolder"); +c10::intrusive_ptr IValue::toCustomClass() const& { + static_assert( + std::is_base_of::value == true, + "toCustomClass requires that template parameter T must inherit " + "from torch::CustomClassHolder"); auto obj = toObject(); - TORCH_CHECK(obj->slots().size() == 1, - "Tried to cast IValue to custom class but it did " - "not contain a custom class!"); - auto expected_type = c10::getCustomClassType>(); - ivalue::checkCustomClassType(expected_type, type()); - auto userObj = c10::static_intrusive_pointer_cast(obj->getSlot(0).toCapsule()); + TORCH_CHECK( + obj->slots().size() == 1, + "Tried to cast IValue to custom class but it did " + "not contain a custom class!"); + const Type* expected_type = c10::getCustomClassType>().get(); + ivalue::checkCustomClassType(expected_type, type().get()); + auto userObj = + c10::static_intrusive_pointer_cast(obj->getSlot(0).toCapsule()); return userObj; } template -T generic_to( - IValue ivalue, - _fake_type) { - using ElemType = typename std::remove_pointer::type::element_type; - return std::move(ivalue).toCustomClass(); +T generic_to(IValue ivalue, _fake_type) { + using ElemType = typename std::remove_pointer::type::element_type; + return std::move(ivalue).toCustomClass(); } template -tagged_capsule generic_to( - IValue ivalue, - _fake_type>) { - return tagged_capsule{std::move(ivalue)}; +tagged_capsule generic_to(IValue ivalue, _fake_type>) { + return tagged_capsule{std::move(ivalue)}; } template -c10::List generic_to( - IValue ivalue, - _fake_type>) { +c10::List generic_to(IValue ivalue, _fake_type>) { return impl::toTypedList(std::move(ivalue).toList()); } +template +static std::vector createVectorFromList(const c10::detail::ListImpl* impl) { + std::vector result; + result.reserve(impl->list.size()); + for (size_t i = 0, N = impl->list.size(); i < N; ++i) { + result.push_back(impl->list[i].to()); + } + return result; +} + +template +static std::vector createVectorFromList(const c10::List& impl) { + std::vector result; + result.reserve(impl.size()); + for (size_t i = 0, N = impl.size(); i < N; ++i) { + result.push_back(impl[i]); + } + return result; +} + +template +OptionalArray generic_to(IValue ivalue, _fake_type>) { + if (ivalue.isNone()) { + return {}; + } + return createVectorFromList( + std::move(ivalue).to>() + ); +} + namespace detail { template std::array generic_to_array( @@ -772,10 +909,15 @@ std::array generic_to_array( // references to this same IValue that also use the list. We can't just // move the elements out. auto list = std::move(ivalue).to>(); - TORCH_CHECK(list.size() == sizeof...(I), "Tried to convert a List with ", list.size()," elements to a fixed-size array of size ", sizeof...(I)); + TORCH_CHECK( + list.size() == sizeof...(I), + "Tried to convert a List with ", + list.size(), + " elements to a fixed-size array of size ", + sizeof...(I)); return {list[I]...}; } -} +} // namespace detail template std::array generic_to( @@ -792,7 +934,8 @@ c10::Dict generic_to( } template -C10_DEPRECATED_MESSAGE("IValues based on std::unordered_map are slow and deprecated. Please use c10::Dict instead.") +C10_DEPRECATED_MESSAGE( + "IValues based on std::unordered_map are slow and deprecated. Please use c10::Dict instead.") std::unordered_map generic_to( IValue ivalue, _fake_type>) { @@ -806,9 +949,7 @@ std::unordered_map generic_to( } template -c10::optional generic_to( - IValue ivalue, - _fake_type>) { +c10::optional generic_to(IValue ivalue, _fake_type>) { if (ivalue.isNone()) { return c10::nullopt; } @@ -823,7 +964,7 @@ Tuple generic_to_tuple_impl( return std::make_tuple( t[INDEX].to::type>()...); } -} +} // namespace detail template < typename... Args, @@ -849,45 +990,43 @@ inline T IValue::to() const& { return generic_to(*this, _fake_type{}); } -template -static std::vector createVectorFromList(const c10::detail::ListImpl* impl) { - std::vector result; - result.reserve(impl->list.size()); - for (size_t i = 0, N = impl->list.size(); i < N; ++i) { - result.push_back(impl->list[i].to()); - } - return result; -} - inline c10::List IValue::toIntList() && { AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind()); return c10::List(moveToIntrusivePtr()); } -inline c10::List IValue::toIntList() const & { +inline c10::List IValue::toIntList() const& { AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind()); return c10::List(toIntrusivePtr()); } inline std::vector IValue::toIntVector() const { AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind()); - return createVectorFromList(static_cast(payload.as_intrusive_ptr)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toIntVector on null intrusive_ptr IValue"); + return createVectorFromList( + static_cast(payload.u.as_intrusive_ptr)); } inline c10::List IValue::toDoubleList() && { AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind()); return c10::List(moveToIntrusivePtr()); } -inline c10::List IValue::toDoubleList() const & { +inline c10::List IValue::toDoubleList() const& { AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind()); return c10::List(toIntrusivePtr()); } inline std::vector IValue::toDoubleVector() const { AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind()); - return createVectorFromList(static_cast(payload.as_intrusive_ptr)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toDoubleVector on null intrusive_ptr IValue"); + return createVectorFromList( + static_cast(payload.u.as_intrusive_ptr)); } inline c10::List IValue::toBoolList() && { AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind()); return c10::List(moveToIntrusivePtr()); } -inline c10::List IValue::toBoolList() const & { +inline c10::List IValue::toBoolList() const& { AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind()); return c10::List(toIntrusivePtr()); } @@ -895,31 +1034,39 @@ inline c10::List IValue::toTensorList() && { AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind()); return c10::List(moveToIntrusivePtr()); } -inline c10::List IValue::toTensorList() const & { +inline c10::List IValue::toTensorList() const& { AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind()); return c10::List(toIntrusivePtr()); } inline std::vector IValue::toTensorVector() const { AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind()); - return createVectorFromList(static_cast(payload.as_intrusive_ptr)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toTensorVector on null intrusive_ptr IValue"); + return createVectorFromList( + static_cast(payload.u.as_intrusive_ptr)); } inline c10::List IValue::toList() && { AT_ASSERT(isList(), "Expected GenericList but got ", tagKind()); return c10::List(moveToIntrusivePtr()); } -inline c10::List IValue::toList() const & { +inline c10::List IValue::toList() const& { AT_ASSERT(isList(), "Expected GenericList but got ", tagKind()); return c10::List(toIntrusivePtr()); } inline c10::ArrayRef IValue::toListRef() const { AT_ASSERT(isList(), "Expected GenericList but got ", tagKind()); - return static_cast(payload.as_intrusive_ptr)->list; + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toListRef on null intrusive_ptr IValue"); + return static_cast(payload.u.as_intrusive_ptr) + ->list; } inline c10::Dict IValue::toGenericDict() && { AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind()); return c10::Dict(moveToIntrusivePtr()); } -inline c10::Dict IValue::toGenericDict() const & { +inline c10::Dict IValue::toGenericDict() const& { AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind()); return c10::Dict(toIntrusivePtr()); } @@ -927,14 +1074,14 @@ inline c10::intrusive_ptr IValue::toTuple() && { AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind()); return moveToIntrusivePtr(); } -inline c10::intrusive_ptr IValue::toTuple() const & { +inline c10::intrusive_ptr IValue::toTuple() const& { AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind()); return toIntrusivePtr(); } inline IValue::IValue(c10::intrusive_ptr v) -: tag(Tag::Tuple), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + : tag(Tag::Tuple), is_intrusive_ptr(true) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } template < typename... Args, @@ -949,21 +1096,19 @@ inline IValue::IValue(const std::tuple& t) } inline IValue::IValue(c10::intrusive_ptr v) -: tag(Tag::String), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + : tag(Tag::String), is_intrusive_ptr(true) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline IValue::IValue(std::string v) -: IValue(ivalue::ConstantString::create(std::move(v))) {} - + : IValue(ivalue::ConstantString::create(std::move(v))) {} inline IValue::IValue(c10::impl::GenericList v) -: tag(Tag::GenericList), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.impl_.release(); + : tag(Tag::GenericList), is_intrusive_ptr(true) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release()); } template > -inline IValue::IValue(c10::List v) -: IValue(impl::toList(std::move(v))) {} +inline IValue::IValue(c10::List v) : IValue(impl::toList(std::move(v))) {} template > inline IValue::IValue(at::ArrayRef v) : IValue(c10::List()) { auto list = to>(); @@ -980,8 +1125,8 @@ inline IValue::IValue(const std::vector& v) : IValue(c10::List()) { list.push_back(e); } } -template inline IValue::IValue(std::array v) -: IValue(c10::List()) { +template +inline IValue::IValue(std::array v) : IValue(c10::List()) { auto list = to>(); list.reserve(v.size()); for (auto& e : v) { @@ -990,15 +1135,16 @@ template inline IValue::IValue(std::array v) } inline IValue::IValue(c10::impl::GenericDict v) -: tag(Tag::GenericDict), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.impl_.release(); + : tag(Tag::GenericDict), is_intrusive_ptr(true) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release()); } -template +template inline IValue::IValue(c10::Dict v) -: IValue(impl::toGenericDict(std::move(v))) {} + : IValue(impl::toGenericDict(std::move(v))) {} -template inline IValue::IValue(std::unordered_map v) -: IValue(Dict()) { +template +inline IValue::IValue(std::unordered_map v) + : IValue(Dict()) { auto dict = to>(); dict.reserve(v.size()); for (auto& e : v) { @@ -1013,80 +1159,97 @@ inline IValue::IValue(c10::optional v) : IValue() { } } -inline IValue::IValue(c10::nullopt_t): IValue() {} +inline IValue::IValue(c10::nullopt_t) : IValue() {} inline IValue::IValue(c10::intrusive_ptr v) -: tag(Tag::Object), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + : tag(Tag::Object), is_intrusive_ptr(true) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline IValue::IValue(c10::intrusive_ptr v) -: tag(Tag::PyObject), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + : tag(Tag::PyObject), is_intrusive_ptr(true) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline IValue::IValue(c10::intrusive_ptr v) -: tag(Tag::Enum), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + : tag(Tag::Enum), is_intrusive_ptr(true) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } -inline IValue IValue::make_capsule(intrusive_ptr blob) { +inline IValue IValue::make_capsule( + intrusive_ptr blob) { IValue iv; iv.tag = Tag::Capsule; iv.is_intrusive_ptr = true; - iv.payload.as_intrusive_ptr = blob.release(); + iv.payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release()); return iv; } -template ::value, int>> +template < + typename T, + std::enable_if_t::value, int>> IValue::IValue(c10::intrusive_ptr custom_class) { - if (!c10::isCustomClassRegistered>()) { - throw c10::Error( - "Trying to instantiate a class that isn't a registered custom class: " + + TypePtr classType = []() { + try { + return c10::getCustomClassType>(); + } catch (const c10::Error&) { + throw c10::Error( + "Trying to instantiate a class that isn't a registered custom class: " + std::string(c10::util::get_fully_qualified_type_name()), - ""); - } - auto classType = c10::getCustomClassType>(); + ""); + } + }(); auto ivalue_obj = c10::ivalue::Object::create( c10::StrongTypePtr(nullptr, classType), /*num_slots=*/1); ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class))); - payload.as_intrusive_ptr = ivalue_obj.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(ivalue_obj.release()); tag = Tag::Object; is_intrusive_ptr = true; } inline IValue::IValue(c10::intrusive_ptr v) -: tag(Tag::Future), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + : tag(Tag::Future), is_intrusive_ptr(true) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline IValue::IValue(c10::intrusive_ptr v) -: tag(Tag::RRef), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + : tag(Tag::RRef), is_intrusive_ptr(true) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline IValue::IValue(c10::intrusive_ptr v) -: tag(Tag::Quantizer), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + : tag(Tag::Quantizer), is_intrusive_ptr(true) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline const std::string& IValue::toStringRef() const { AT_ASSERT(isString(), "Expected String but got ", tagKind()); - return static_cast(payload.as_intrusive_ptr)->string(); -} -inline c10::optional> IValue::toOptionalStringRef() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toStringRef on null intrusive_ptr IValue"); + return static_cast( + payload.u.as_intrusive_ptr) + ->string(); +} +inline c10::optional> IValue:: + toOptionalStringRef() const { if (isNone()) { return c10::nullopt; } AT_ASSERT(isString(), "Expected optional but got ", tagKind()); - return std::reference_wrapper(static_cast(payload.as_intrusive_ptr)->string()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toOptionalStringRef on null intrusive_ptr IValue"); + return std::reference_wrapper( + static_cast(payload.u.as_intrusive_ptr) + ->string()); } inline PyObject* IValue::toPyObject() const { return toPyObjectHolder()->getPyObject(); } -template +template inline optional IValue::toOptional() { if (this->isNone()) { return nullopt; @@ -1094,18 +1257,12 @@ inline optional IValue::toOptional() { return this->to(); } -inline OptionalArray IValue::toOptionalIntArray() { - if (this->isNone()) { - return {}; - } - return this->toIntVector(); -} - -inline OptionalArray IValue::toOptionalDoubleArray() { +template +inline optional IValue::toOptional() const { if (this->isNone()) { - return {}; + return nullopt; } - return this->toDoubleVector(); + return this->to(); } inline bool IValue::isCustomClass() const { @@ -1113,13 +1270,16 @@ inline bool IValue::isCustomClass() const { } inline bool IValue::isSameIdentity(const IValue& rhs) const { - // We choose to not use memcmp for payload check due to potential random padding characters on union type + // We choose to not use memcmp for payload check due to potential random + // padding characters on union type // Semantics: - // 1. Immutable primitive values of the same type (Int, Double, None, Bool, Str) return value equality + // 1. Immutable primitive values of the same type (Int, Double, None, Bool, + // Str) return value equality // 2. If it is a tensor type, we need to take undefined tensor into account // 3. Undefined_tensor is None and vice versa should be true - // 4. If it is a reference type (i.e. is_intrusive_ptr), then is is True when the pointed-to object is the same. + // 4. If it is a reference type (i.e. is_intrusive_ptr), then is is True when + // the pointed-to object is the same. // 5. False for all other comparisons. if (this->isNone() && rhs.isNone()) { return true; @@ -1127,14 +1287,13 @@ inline bool IValue::isSameIdentity(const IValue& rhs) const { // for bool type, do equality check return this->toBool() == rhs.toBool(); } else if (this->isTensor() && rhs.isTensor()) { - // for tensor type, just check the as_intrusive_ptr since is_intrusive_ptr is false for undefined tensor - return this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr; + return this->payload.as_tensor.is_same(rhs.payload.as_tensor); } else if (this->isTensor() && rhs.isNone()) { // special case: undefined tensor and None are the same identity - return !this->is_intrusive_ptr; + return !this->payload.as_tensor.defined(); } else if (this->isNone() && rhs.isTensor()) { // special case: undefined tensor and None are the same identity - return !rhs.is_intrusive_ptr; + return !rhs.payload.as_tensor.defined(); } else if (this->isInt() && rhs.isInt()) { return this->toInt() == rhs.toInt(); } else if (this->isDouble() && rhs.isDouble()) { @@ -1142,9 +1301,10 @@ inline bool IValue::isSameIdentity(const IValue& rhs) const { } else if (this->isString() && rhs.isString()) { return this->toStringRef() == rhs.toStringRef(); } else { - // for objects holding in IValue, do shallow compare on pointer address to testify the identity - return this->is_intrusive_ptr && rhs.is_intrusive_ptr - && this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr; + // for objects holding in IValue, do shallow compare on pointer address to + // testify the identity + return this->is_intrusive_ptr && rhs.is_intrusive_ptr && + this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr; } } @@ -1161,15 +1321,18 @@ IValue from_(c10::intrusive_ptr x, std::false_type) { } template IValue from_(T x, std::false_type) { - static_assert(guts::false_t::value, "You are calling from with a type that it doesn't support, and isn't a potential custom class (ie: is an intrusive_ptr)"); + static_assert( + guts::false_t::value, + "You are calling from with a type that it doesn't support, and isn't a potential custom class (ie: is an intrusive_ptr)"); return IValue(); } -} +} // namespace detail template IValue from(T x) { - return detail::from_(std::move(x), typename std::is_constructible::type{}); + return detail::from_( + std::move(x), typename std::is_constructible::type{}); } -} +} // namespace ivalue } // namespace c10 diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 1c9d31dd630cc..7d3890f582b8f 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1,10 +1,11 @@ #pragma once +#include #include #include #include -#include #include +#include #include #include @@ -17,200 +18,22 @@ struct ClassType; namespace torch { namespace jit { struct CompilationUnit; +struct Function; } // namespace jit } // namespace torch namespace c10 { +struct IValue; struct FunctionSchema; struct NamedType; using OptNameList = c10::optional>; -#define C10_FORALL_TYPES(_) \ - _(AnyType) \ - _(EnumType) \ - _(AnyEnumType) \ - _(TensorType) \ - _(TupleType) \ - _(ListType) \ - _(DictType) \ - _(NumberType) \ - _(FloatType) \ - _(FutureType) \ - _(RRefType) \ - _(IntType) \ - _(NoneType) \ - _(StringType) \ - _(GeneratorType) \ - _(QuantizerType) \ - _(BoolType) \ - _(OptionalType) \ - _(VarType) \ - _(DeviceObjType) \ - _(FunctionType) \ - _(ClassType) \ - _(PyObjectType) \ - _(CapsuleType) \ - _(InterfaceType) \ - _(QSchemeType) \ - _(LayoutType) \ - _(ScalarTypeType) \ - _(AnyListType) \ - _(AnyTupleType) \ - _(AnyClassType) - -enum class TypeKind { -#define DEFINE_TYPE(T) T, - C10_FORALL_TYPES(DEFINE_TYPE) -#undef DEFINE_TYPE -}; - -CAFFE2_API const char* typeKindToString(TypeKind kind); - -struct Type; -using TypePtr = std::shared_ptr; -using ConstTypePtr = std::shared_ptr; - -// Use this to customize how a Type is printed using `annotation_str()`. If -// c10::nullopt is returned, `annotation_str()` falls through to its default -// implementation. -using TypePrinter = - std::function(const ConstTypePtr&)>; - -struct CAFFE2_API Type : std::enable_shared_from_this { - private: - TypeKind kind_; - - protected: - Type(TypeKind kind) : kind_(kind) {} - - virtual std::string annotation_str_impl(TypePrinter printer) const { - return str(); - } - - public: - virtual bool operator==(const Type& rhs) const = 0; - - // subtyping relation. By default, we return true for the case - // when the type is exactly equal or if this <: T where rhs = Optional[T] - - // if this returns false and the why_not stream is non-null, it contains - // additional details that describe why this is not a subtype of 'rhs'. - // This additional information should only contain details that are not obvious - // from the annotation_str() that describes the type. For instance it is clear that `int <: str` is false - // but not clear why `Foo <: InterfaceBar` might be false. - virtual bool isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const; - virtual bool is_module() const; - bool isSubtypeOf(const TypePtr rhs) const { - return isSubtypeOfExt(rhs, nullptr); - } - - // How this type will appear in FunctionSchema declarations - virtual std::string str() const = 0; - - // How this type will appear as if it were a type annotation in Python - // which is sometimes different than how it appears in declarations (e.g. - // int[] vs List[int]) - // - // Takes a custom printer that users can pass in to customize the output of - // this method. - std::string annotation_str(TypePrinter printer) const { - if (printer) { - // the printer can return nullopt to fall through to the default impl - if (auto renamed = printer(shared_from_this())) { - return *renamed; - } - } - return annotation_str_impl(printer); - } - std::string annotation_str() const { - // Overload instead of define a default value for `printer` to help - // debuggers out. - return annotation_str(nullptr); - } - - // Returns a human readable string that includes additional information like - // "type is inferred rather than explictly defined" to help construct more - // user-friendly messages. - virtual std::string repr_str() const { - return annotation_str(); - } - - TypeKind kind() const { - return kind_; - } - - virtual bool requires_grad() const { - for (const auto& ct : containedTypes()) { - if (ct->requires_grad()) { - return true; - } - } - return false; - } - - // Dynamically cast this object to the subclass indicated by the - // template variable, returning nullptr if the cast is invalid. - template - std::shared_ptr cast() { - if (T::Kind == kind()) { - return std::static_pointer_cast(shared_from_this()); - } - return nullptr; - } - template - std::shared_ptr cast() const { - if (T::Kind == kind()) { - return std::static_pointer_cast(shared_from_this()); - } - return nullptr; - } - template - std::shared_ptr expect() { - auto r = cast(); - AT_ASSERT(r); - return r; - } - template - std::shared_ptr expect() const { - auto r = cast(); - AT_ASSERT(r); - return r; - } - virtual ~Type() = default; - virtual bool hasFreeVariables() const { - return false; - } - // list of types this type contains, e.g. for a List then element type of a - // list for a tuple, the types of the tuple elements - virtual at::ArrayRef containedTypes() const { - return {}; - } - // create a new version of this type, replacing its contained types with - // contained_types - TypePtr withContained(std::vector contained_types) { - auto current_contained = containedTypes(); - AT_ASSERT(current_contained.size() == contained_types.size()); - if (current_contained.equals(contained_types)) { - return shared_from_this(); - } - return createWithContained(std::move(contained_types)); - } - // per-type constructor, you only need to override this if the - // containedTypes() is not empty - virtual TypePtr createWithContained( - std::vector contained_types) const { - AT_ERROR( - "type with contained types did not overload createWithContained: ", - str()); - } -}; - struct AnyType; using AnyTypePtr = std::shared_ptr; // Any is the top of the type hierarchy, all other types are subtypes // T <: Any, forall T -struct CAFFE2_API AnyType : public Type { +struct TORCH_API AnyType : public Type { static AnyTypePtr create() { return AnyTypePtr( new AnyType()); // NOLINT(modernize-make-shared) @@ -238,7 +61,7 @@ inline bool operator!=(const Type& lhs, const Type& rhs) { } // common base for all types that have a single sub element -// e.g. Future[T], Option[T], List[T] +// e.g. Future[T], Optional[T], List[T] template struct SingleElementType : public Type { static const TypeKind Kind = K; @@ -282,7 +105,7 @@ using OptionalTypePtr = std::shared_ptr; // 1. Optional[T] <: Optional[R] iff T <: R // 2. T <: Optional[R] if T <: R // 3. None <: Optional[T] for all T -struct CAFFE2_API OptionalType +struct TORCH_API OptionalType : public SingleElementType { static OptionalTypePtr create(TypePtr element) { TORCH_INTERNAL_ASSERT(element, "OptionalType requires valid TypePtr"); @@ -307,7 +130,7 @@ struct CAFFE2_API OptionalType return create(contained_types[0]); } - bool isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const override { + bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override { if (Type::isSubtypeOfExt(rhs, why_not)) { return true; } @@ -354,7 +177,7 @@ inline c10::optional merge_primitive( // `stride_indices` A contiguity marker on the smallest stride (c0) indicates // the stride is precisely 1, otherwise a contiguity marker means that $stride_n // = size_{n-1}*stride_{n-1}$ -struct CAFFE2_API Stride { +struct TORCH_API Stride { Stride() {} Stride( const c10::optional& stride_index, @@ -399,7 +222,7 @@ inline c10::optional merge_primitive( return r; } -struct CAFFE2_API ShapeSymbol { +struct TORCH_API ShapeSymbol { // needed for use in `std::map` ShapeSymbol() : value_(-1) {} // is this symbol a fixed/static dimension @@ -424,7 +247,7 @@ struct CAFFE2_API ShapeSymbol { static ShapeSymbol newSymbol() { return fromStaticSize(-static_cast(++num_symbols)); }; - friend CAFFE2_API std::ostream& operator<<( + friend TORCH_API std::ostream& operator<<( std::ostream& os, const ShapeSymbol& s); @@ -445,7 +268,7 @@ inline ShapeSymbol merge_primitive( // Shape of a Tensor represented with ShapeSymbol's. Unranked, ranked unknown // dims, partially known and fully known shapes are all supported. -struct CAFFE2_API SymbolicShape { +struct TORCH_API SymbolicShape { // Unranked shape constructor. SymbolicShape() : dims_(c10::nullopt) {} @@ -464,7 +287,7 @@ struct CAFFE2_API SymbolicShape { } // Mix of known and unknown ranks - SymbolicShape(const std::vector> dims) { + SymbolicShape(const std::vector>& dims) { std::vector shape_symbols; shape_symbols.reserve(dims.size()); for(c10::optional dim: dims) { @@ -477,7 +300,7 @@ struct CAFFE2_API SymbolicShape { dims_ = shape_symbols; } - SymbolicShape(const std::vector dims) : dims_(dims) {} + SymbolicShape(std::vector dims) : dims_(std::move(dims)) {} SymbolicShape(c10::IntArrayRef dims) { std::vector shape_symbols; @@ -488,6 +311,13 @@ struct CAFFE2_API SymbolicShape { dims_ = shape_symbols; } + ShapeSymbol operator[](size_t i) const { + if (!dims_) { + throw std::runtime_error("Rank isn't fixed"); + } + return (*dims_).at(i); + } + // Returns rank or nullopt in case of unranked shape. c10::optional rank() const { if(!dims_) { @@ -548,7 +378,7 @@ struct VaryingShape { return dims_ == other.dims_; } - const c10::optional& operator[](int i) const { + const c10::optional &operator[](size_t i) const { if (!dims_) { throw std::runtime_error("Rank isn't fixed"); } @@ -567,7 +397,7 @@ struct VaryingShape { return dims_; } - CAFFE2_API VaryingShape merge(const VaryingShape& other) const; + TORCH_API VaryingShape merge(const VaryingShape& other) const; c10::optional> concrete_sizes() const { if (!dims_) { @@ -602,7 +432,7 @@ struct VaryingShape { struct TensorType; using TensorTypePtr = std::shared_ptr; // This type represents a single Tensor with a specific size -struct CAFFE2_API TensorType : public Type { +struct TORCH_API TensorType : public Type { static TensorTypePtr create(const at::Tensor& t); // used by TensorType::create(size_t dim) which in turn used by @@ -622,8 +452,7 @@ struct CAFFE2_API TensorType : public Type { const SymbolicShape& sizes, const VaryingShape& stride_, c10::optional requires_grad, - c10::optional undefined = false, - bool is_inferred = false); + c10::optional undefined = false); static TensorTypePtr create( c10::optional scalar_type, @@ -667,7 +496,7 @@ struct CAFFE2_API TensorType : public Type { } bool operator==(const Type& rhs) const override; - bool isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const override; + bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override; std::string str() const override; @@ -722,7 +551,7 @@ struct CAFFE2_API TensorType : public Type { TensorTypePtr withSymbolicShapes(SymbolicShape ssizes) const { auto cloned = clone(); - cloned->sizes_ = ssizes; + cloned->sizes_ = std::move(ssizes); return cloned; } @@ -750,7 +579,7 @@ struct CAFFE2_API TensorType : public Type { const SymbolicShape& symbolic_sizes() const; - TensorTypePtr merge(TensorTypePtr other, bool merge_sizes = true) const; + TensorTypePtr merge(const TensorType& other, bool merge_sizes = true) const; bool matchTensor(const at::Tensor& t); @@ -768,10 +597,13 @@ struct CAFFE2_API TensorType : public Type { static TensorTypePtr getInferred() { static auto valueInferred = TensorType::create( - /*scalar_type=*/{}, /*device=*/{}, - /*sizes=*/SymbolicShape(), - /*stride=*/VaryingShape{}, /*requires_grad=*/{}, - /*undefined=*/false, /*is_inferred=*/true); + /*scalar_type=*/{}, + /*device=*/{}, + /*sizes=*/SymbolicShape(), + /*stride=*/VaryingShape{}, + /*requires_grad=*/{}, + /*undefined=*/false); + valueInferred->is_inferred_ = true; return valueInferred; } @@ -800,6 +632,17 @@ struct CAFFE2_API TensorType : public Type { static const TypeKind Kind = TypeKind::TensorType; + static std::vector contiguousStridesOf(at::IntArrayRef sizes) { + std::vector strides(sizes.size()); + if (sizes.empty()) // zero-dim case + return strides; + strides.back() = 1; + for (size_t i = strides.size() - 1; i > 0; i--) { + strides[i - 1] = strides[i] * sizes[i]; + } + return strides; + } + private: TensorType( c10::optional scalar_type, @@ -814,17 +657,6 @@ struct CAFFE2_API TensorType : public Type { scalar_type_, device_, sizes_, strides_, requires_grad_, undefined_)); } - static std::vector contiguousStridesOf(at::IntArrayRef sizes) { - std::vector strides(sizes.size()); - if (sizes.empty()) // zero-dim case - return strides; - strides.back() = 1; - for (size_t i = strides.size() - 1; i > 0; i--) { - strides[i - 1] = strides[i] * sizes[i]; - } - return strides; - } - static VaryingShape computeStrideProps( at::IntArrayRef sizes, at::IntArrayRef strides, @@ -853,7 +685,7 @@ struct CAFFE2_API TensorType : public Type { struct ListType; using ListTypePtr = std::shared_ptr; -struct CAFFE2_API ListType +struct TORCH_API ListType : public SingleElementType { // It's not exactly a singleton, but there should be exactly one instance of // List[T] for every T @@ -874,7 +706,7 @@ struct CAFFE2_API ListType return create(contained_types.at(0)); } - bool isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const override; + bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override; // common cast List[Tensor] static ListTypePtr ofTensors(); @@ -895,7 +727,7 @@ struct CAFFE2_API ListType struct DictType; using DictTypePtr = std::shared_ptr; -struct CAFFE2_API DictType : public Type { +struct TORCH_API DictType : public Type { friend struct Type; static const TypeKind Kind = TypeKind::DictType; @@ -977,7 +809,7 @@ struct CAFFE2_API DictType : public Type { struct FutureType; using FutureTypePtr = std::shared_ptr; -struct CAFFE2_API FutureType +struct TORCH_API FutureType : public SingleElementType { friend struct Type; template @@ -996,7 +828,7 @@ struct CAFFE2_API FutureType return create(contained_types.at(0)); } - bool isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const override { + bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override { if (Type::isSubtypeOfExt(rhs, why_not)) { return true; } @@ -1019,7 +851,7 @@ struct CAFFE2_API FutureType struct RRefType; using RRefTypePtr = std::shared_ptr; -struct CAFFE2_API RRefType +struct TORCH_API RRefType : public SingleElementType { friend struct Type; template @@ -1053,7 +885,7 @@ struct NamedType; using NamedTypePtr = std::shared_ptr; using ConstNamedTypePtr = std::shared_ptr; -struct CAFFE2_API NamedType : public Type { +struct TORCH_API NamedType : public Type { NamedType(TypeKind tk, c10::optional name) : Type(tk), name_(std::move(name)) { TORCH_INTERNAL_ASSERT( @@ -1080,7 +912,7 @@ struct CAFFE2_API NamedType : public Type { // static types in named types to reconstruct type tags of loaded // values. Lifting this restriction requires solving the serialization // problem first. -CAFFE2_API void checkNoAny( +TORCH_API void checkNoAny( const Type& base, const char* what, const std::string& attrname, @@ -1090,7 +922,7 @@ struct TupleType; using TupleTypePtr = std::shared_ptr; using NameList = std::vector; // This type represents a Tuple -struct CAFFE2_API TupleType : public NamedType { +struct TORCH_API TupleType : public NamedType { static TupleTypePtr createNamed(const c10::optional& name, const std::vector& field_names, const std::vector& types); @@ -1107,7 +939,7 @@ struct CAFFE2_API TupleType : public NamedType { } bool operator==(const Type& rhs) const override; - bool isSubtypeOfExt(const TypePtr rhs_, std::ostream* why_not) const override; + bool isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const override; std::string str() const override; bool hasFreeVariables() const override { @@ -1161,7 +993,7 @@ struct CAFFE2_API TupleType : public NamedType { struct EnumType; using EnumTypePtr = std::shared_ptr; using EnumNameValue = std::pair; -struct CAFFE2_API EnumType : public NamedType { +struct TORCH_API EnumType : public NamedType { friend struct Type; static const TypeKind Kind = TypeKind::EnumType; @@ -1202,7 +1034,7 @@ struct CAFFE2_API EnumType : public NamedType { return false; } - bool isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const override; + bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override; std::shared_ptr compilation_unit() const { auto cu = cu_.lock(); @@ -1247,7 +1079,7 @@ struct CAFFE2_API EnumType : public NamedType { // EnumType <: AnyEnumType for all Enums struct AnyEnumType; using AnyEnumTypePtr = std::shared_ptr; -struct CAFFE2_API AnyEnumType : public Type { +struct TORCH_API AnyEnumType : public Type { static AnyEnumTypePtr create() { return AnyEnumTypePtr( new AnyEnumType()); // NOLINT(modernize-make-shared) @@ -1273,7 +1105,7 @@ using NumberTypePtr = std::shared_ptr; // Subtype hierarchy for Number Types (NumberType as the base type): // IntType <: NumberType // FloatType <: NumberType -struct CAFFE2_API NumberType : public Type { +struct TORCH_API NumberType : public Type { static NumberTypePtr create() { return NumberTypePtr(new NumberType()); // NOLINT(modernize-make-shared) } @@ -1300,7 +1132,7 @@ struct CAFFE2_API NumberType : public Type { struct FloatType; using FloatTypePtr = std::shared_ptr; // This type represents a Python float number -struct CAFFE2_API FloatType : public NumberType { +struct TORCH_API FloatType : public NumberType { static FloatTypePtr create() { return FloatTypePtr(new FloatType()); // NOLINT(modernize-make-shared) } @@ -1310,7 +1142,7 @@ struct CAFFE2_API FloatType : public NumberType { std::string str() const override { return "float"; } - bool isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const override { + bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override { return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not); } static const TypeKind Kind = TypeKind::FloatType; @@ -1327,7 +1159,7 @@ struct CAFFE2_API FloatType : public NumberType { struct IntType; using IntTypePtr = std::shared_ptr; // This type represents a Python int number -struct CAFFE2_API IntType : public NumberType { +struct TORCH_API IntType : public NumberType { static IntTypePtr create() { return IntTypePtr(new IntType()); // NOLINT(modernize-make-shared) } @@ -1337,7 +1169,7 @@ struct CAFFE2_API IntType : public NumberType { std::string str() const override { return "int"; } - bool isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const override { + bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override { return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not); } static const TypeKind Kind = TypeKind::IntType; @@ -1354,7 +1186,7 @@ struct CAFFE2_API IntType : public NumberType { struct BoolType; using BoolTypePtr = std::shared_ptr; // This node represents a Python bool value -struct CAFFE2_API BoolType : public Type { +struct TORCH_API BoolType : public Type { static BoolTypePtr create() { return BoolTypePtr(new BoolType()); } @@ -1375,7 +1207,7 @@ struct CAFFE2_API BoolType : public Type { struct StringType; using StringTypePtr = std::shared_ptr; // This type represents a Python string -struct CAFFE2_API StringType : public Type { +struct TORCH_API StringType : public Type { static StringTypePtr create() { return StringTypePtr(new StringType()); // NOLINT(modernize-make-shared) } @@ -1397,9 +1229,32 @@ struct CAFFE2_API StringType : public Type { StringType() : Type(TypeKind::StringType) {} }; +struct StorageType; +using StorageTypePtr = std::shared_ptr; +struct TORCH_API StorageType : public Type { + static StorageTypePtr create() { + return StorageTypePtr(new StorageType()); // NOLINT(modernize-make-shared) + } + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return annotation_str(); + } + std::string annotation_str_impl(TypePrinter printer = nullptr) const override { + return "Storage"; + } + static const TypeKind Kind = TypeKind::StorageType; + // global singleton + static StorageTypePtr get(); + + private: + StorageType() : Type(TypeKind::StorageType) {} +}; + struct FunctionType; using FunctionTypePtr = std::shared_ptr; -struct CAFFE2_API FunctionType : public NamedType { +struct TORCH_API FunctionType : public NamedType { static FunctionTypePtr create(torch::jit::Function* function) { return FunctionTypePtr( new FunctionType(function)); // NOLINT(modernize-make-shared) @@ -1431,7 +1286,7 @@ struct CAFFE2_API FunctionType : public NamedType { struct NoneType; using NoneTypePtr = std::shared_ptr; // This type represents a Python None -struct CAFFE2_API NoneType : public Type { +struct TORCH_API NoneType : public Type { static NoneTypePtr create() { return NoneTypePtr(new NoneType()); // NOLINT(modernize-make-shared) } @@ -1441,7 +1296,7 @@ struct CAFFE2_API NoneType : public Type { std::string str() const override { return "None"; } - bool isSubtypeOfExt(const TypePtr rhs, std::ostream *why_not) const override { + bool isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const override { if (rhs->kind() == OptionalType::Kind) { return true; } @@ -1458,7 +1313,7 @@ struct CAFFE2_API NoneType : public Type { struct GeneratorType; using GeneratorTypePtr = std::shared_ptr; // This type represents a Generator -struct CAFFE2_API GeneratorType : public Type { +struct TORCH_API GeneratorType : public Type { static GeneratorTypePtr create() { return GeneratorTypePtr( new GeneratorType()); // NOLINT(modernize-make-shared) @@ -1480,7 +1335,7 @@ struct CAFFE2_API GeneratorType : public Type { struct QuantizerType; using QuantizerTypePtr = std::shared_ptr; // This type represents a Quantizer -struct CAFFE2_API QuantizerType : public Type { +struct TORCH_API QuantizerType : public Type { static QuantizerTypePtr create() { return QuantizerTypePtr( new QuantizerType()); // NOLINT(modernize-make-shared) @@ -1502,7 +1357,7 @@ struct CAFFE2_API QuantizerType : public Type { struct QSchemeType; using QSchemeTypePtr = std::shared_ptr; // This type represents a QScheme -struct CAFFE2_API QSchemeType : public Type { +struct TORCH_API QSchemeType : public Type { static QSchemeTypePtr create() { return QSchemeTypePtr( new QSchemeType()); // NOLINT(modernize-make-shared) @@ -1524,7 +1379,7 @@ struct CAFFE2_API QSchemeType : public Type { struct DeviceObjType; using DeviceObjTypePtr = std::shared_ptr; // This type represents a Device -struct CAFFE2_API DeviceObjType : public Type { +struct TORCH_API DeviceObjType : public Type { static DeviceObjTypePtr create() { return DeviceObjTypePtr( new DeviceObjType()); // NOLINT(modernize-make-shared) @@ -1543,6 +1398,28 @@ struct CAFFE2_API DeviceObjType : public Type { DeviceObjType() : Type(TypeKind::DeviceObjType) {} }; +struct StreamObjType; +using StreamObjTypePtr = std::shared_ptr; +// This type represents a Generator +struct TORCH_API StreamObjType : public Type { + static StreamObjTypePtr create() { + return StreamObjTypePtr( + new StreamObjType()); // NOLINT(modernize-make-shared) + } + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Stream"; + } + static const TypeKind Kind = TypeKind::StreamObjType; + // global singleton + static StreamObjTypePtr get(); + +private: + StreamObjType() : Type(TypeKind::StreamObjType) {} +}; + struct VarType; using VarTypePtr = std::shared_ptr; // This type represents a type variable, used in FunctionSchema @@ -1574,7 +1451,7 @@ struct CapsuleType; using CapsuleTypePtr = std::shared_ptr; // This type represents a Python Capsule. // It does not appear in the IR and is only used during runtime -struct CAFFE2_API CapsuleType : public Type { +struct TORCH_API CapsuleType : public Type { static CapsuleTypePtr create() { return CapsuleTypePtr(new CapsuleType()); // NOLINT(modernize-make-shared) } @@ -1595,7 +1472,7 @@ struct CAFFE2_API CapsuleType : public Type { struct PyObjectType; using PyObjectTypePtr = std::shared_ptr; // This type represents a PyObject Type -struct CAFFE2_API PyObjectType : public Type { +struct TORCH_API PyObjectType : public Type { static PyObjectTypePtr create() { return PyObjectTypePtr(new PyObjectType()); // NOLINT(modernize-make-shared) } @@ -1621,16 +1498,16 @@ enum class TypeVerbosity { Default = Full, }; -CAFFE2_API TypeVerbosity type_verbosity(); +TORCH_API TypeVerbosity type_verbosity(); -CAFFE2_API std::ostream& operator<<(std::ostream& out, const Type& t); +TORCH_API std::ostream& operator<<(std::ostream& out, const Type& t); template -CAFFE2_API std::ostream& operator<<( +TORCH_API std::ostream& operator<<( std::ostream& out, const VaryingShape& t); -CAFFE2_API std::ostream& operator<<(std::ostream& os, const SymbolicShape& s); -CAFFE2_API std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s); -CAFFE2_API std::ostream& operator<<(std::ostream& os, const Stride& s); +TORCH_API std::ostream& operator<<(std::ostream& os, const SymbolicShape& s); +TORCH_API std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s); +TORCH_API std::ostream& operator<<(std::ostream& os, const Stride& s); // what is the type, ignoring extra size/shape information? // e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...) @@ -1682,12 +1559,12 @@ inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) { // Two different tensortypes will return dynamic. // Currently we chose not to support returning a NumberType for a float & int // input because of a lack of operator support for NumberType -CAFFE2_API c10::optional unifyTypes( +TORCH_API c10::optional unifyTypes( const TypePtr& t1, const TypePtr& t2, bool default_to_any = false); -CAFFE2_API c10::optional unifyTypeList( +TORCH_API c10::optional unifyTypeList( at::ArrayRef elements, std::ostream& why_not); @@ -1695,13 +1572,18 @@ namespace detail { template struct getTypePtr_ final { static TypePtr call() { - TORCH_CHECK( - isCustomClassRegistered(), - "Type ", - c10::util::get_fully_qualified_type_name(), - " could not be converted to any of the known types." - ); - auto res = getCustomClassType(); + TypePtr res = []() { + try { + return getCustomClassType(); + } catch(const c10::Error&) { + TORCH_CHECK( + false, + "Type ", + c10::util::get_fully_qualified_type_name(), + " could not be converted to any of the known types." + ); + } + }(); return std::dynamic_pointer_cast(std::move(res)); } }; @@ -1720,6 +1602,18 @@ struct getTypePtr_ final { } }; template <> +struct getTypePtr_ final { + static TypePtr call() { + return StorageType::get(); + } +}; +template <> +struct getTypePtr_ final { + static TypePtr call() { + return StreamObjType::get(); + } +}; +template <> struct getTypePtr_ final { static TypePtr call() { return FloatType::get(); @@ -1785,6 +1679,12 @@ struct getTypePtr_ final { return StringType::get(); } }; +template <> +struct getTypePtr_ final { + static TypePtr call() { + return StringType::get(); + } +}; template struct getTypePtr_> final { static TypePtr call() { @@ -1884,15 +1784,15 @@ struct MatchTypeReturn { // note: It is possible to successfully match a formal, but for type variables // in the formal to still not be defined. In particular, None matches Optional[T] // but does not define the value of T. -CAFFE2_API MatchTypeReturn +TORCH_API MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env); // replace type variables appearing in `type` with the values in // `type_env`. Returns nullptr if a variable used in `type` // does not appear in `type_env` -CAFFE2_API TypePtr tryEvalTypeVariables(TypePtr type, TypeEnv& type_env); +TORCH_API TypePtr tryEvalTypeVariables(TypePtr type, TypeEnv& type_env); -CAFFE2_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type); +TORCH_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type); // This enumerator represents the 'kind' of an attribute - a buffer, a paramter, or neither. // This state is mutually exclusive. Buffers and Parameters can only appear on modules. @@ -1904,7 +1804,7 @@ enum class AttributeKind { // This structure represents all notional booking entities in a class attribute: name, kind (see: AttributeKind), and type (see: TypePtr). // Note: This structure does not represent the value of the attribute. -struct CAFFE2_API ClassAttribute { +struct TORCH_API ClassAttribute { public: ClassAttribute(AttributeKind kind, TypePtr attributeType, @@ -1940,7 +1840,7 @@ using ClassTypePtr = std::shared_ptr; using ::torch::jit::CompilationUnit; // This represents a class in TorchScript. -struct CAFFE2_API ClassType : public NamedType { +struct TORCH_API ClassType : public NamedType { // This represents an attribute of a class; a name associated with an attribute, and a // getter and (optional) setter for that attribute. struct Property { @@ -1953,7 +1853,8 @@ struct CAFFE2_API ClassType : public NamedType { static ClassTypePtr create( c10::optional qualifiedName, std::weak_ptr cu, - bool is_module = false); + bool is_module = false, + std::string doc_string = ""); bool operator==(const Type& rhs) const override { if (auto user_rhs = rhs.cast()) { @@ -2065,6 +1966,13 @@ struct CAFFE2_API ClassType : public NamedType { // valid again. void unsafeRemoveAttribute(const std::string& name); + // [Internal Only] Change the type of an attribute of the ClassType, + // The caller is responsible to make sure the modification is safe: + // it is unsafe to maintain uses of the old type of the attribute, + // and any code that works on the attribute is now invalid. + // Only newly created code is valid again. + void unsafeChangeAttributeType(const std::string& name, TypePtr new_ty); + // Add attribute \p NAME if it doesn't exist or verify that it has a // compatible type otherwise. size_t addOrCheckAttribute( @@ -2139,6 +2047,9 @@ struct CAFFE2_API ClassType : public NamedType { return constantNames_[slot]; } + const std::string& doc_string() const { + return doc_string_; + } IValue getConstant(const std::string& name) const; @@ -2227,7 +2138,7 @@ struct CAFFE2_API ClassType : public NamedType { // These variants are not registered in the global class table. ClassTypePtr refine(at::ArrayRef refined_slots) const; - bool isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const override; + bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override; static const TypeKind Kind = TypeKind::ClassType; @@ -2235,7 +2146,8 @@ struct CAFFE2_API ClassType : public NamedType { ClassType( c10::optional name, std::weak_ptr cu, - bool is_module); + bool is_module, + std::string doc_string); std::string annotation_str_impl(TypePrinter printer = nullptr) const override { const auto& n = name().value(); @@ -2270,6 +2182,9 @@ struct CAFFE2_API ClassType : public NamedType { std::vector properties_; bool isModule_ = false; + + // Doc string of class. + std::string doc_string_ = ""; }; struct InterfaceType; @@ -2283,7 +2198,7 @@ using ::torch::jit::CompilationUnit; // lhs (ClassType or InterfaceType) is a subtype of rhs if: // 1. lhs methods are a superset of rhs methods // 2. if rhs is module interface, the lhs must be module interface or module itself -struct CAFFE2_API InterfaceType : public NamedType { +struct TORCH_API InterfaceType : public NamedType { static InterfaceTypePtr create( QualifiedName qualifiedName, bool is_module=false); @@ -2299,7 +2214,7 @@ struct CAFFE2_API InterfaceType : public NamedType { return std::string("InterfaceType<") + name()->name() + ">"; } - bool isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const override; + bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override; // try to find a method of this interface, // returns nullptr if not found. @@ -2347,7 +2262,7 @@ EnumerationType() : Type(Kind) {} struct LayoutType; using LayoutTypePtr = std::shared_ptr; // This type represents a Generator -struct CAFFE2_API LayoutType : public EnumerationType { +struct TORCH_API LayoutType : public EnumerationType { static LayoutTypePtr create() { return LayoutTypePtr( new LayoutType()); // NOLINT(modernize-make-shared) @@ -2366,7 +2281,7 @@ LayoutType() : EnumerationType() {} struct ScalarTypeType; using ScalarTypeTypePtr = std::shared_ptr; // This type represents a Generator -struct CAFFE2_API ScalarTypeType : public EnumerationType { +struct TORCH_API ScalarTypeType : public EnumerationType { static ScalarTypeTypePtr create() { return ScalarTypeTypePtr( new ScalarTypeType()); // NOLINT(modernize-make-shared) @@ -2386,7 +2301,7 @@ ScalarTypeType() : EnumerationType() {} // List[T] <: AnyList for all T struct AnyListType; using AnyListTypePtr = std::shared_ptr; -struct CAFFE2_API AnyListType : public Type { +struct TORCH_API AnyListType : public Type { static AnyListTypePtr create() { return AnyListTypePtr( new AnyListType()); // NOLINT(modernize-make-shared) @@ -2409,7 +2324,7 @@ struct CAFFE2_API AnyListType : public Type { // Tuple[T...] <: AnyTuple for all T struct AnyTupleType; using AnyTupleTypePtr = std::shared_ptr; -struct CAFFE2_API AnyTupleType : public Type { +struct TORCH_API AnyTupleType : public Type { static AnyTupleTypePtr create() { return AnyTupleTypePtr( new AnyTupleType()); // NOLINT(modernize-make-shared) @@ -2434,7 +2349,7 @@ struct CAFFE2_API AnyTupleType : public Type { // ClassType <: AnyClassType for all classes struct AnyClassType; using AnyClassTypePtr = std::shared_ptr; -struct CAFFE2_API AnyClassType : public Type { +struct TORCH_API AnyClassType : public Type { static AnyClassTypePtr create() { return AnyClassTypePtr( new AnyClassType()); // NOLINT(modernize-make-shared) @@ -2455,19 +2370,19 @@ struct CAFFE2_API AnyClassType : public Type { inline bool IValue::isDoubleList() const { // note: avoids calling type() to avoid extra referencing counting for the returned type. - return isList() && static_cast(payload.as_intrusive_ptr)->elementType->kind() == FloatType::Kind; + return isList() && static_cast(payload.u.as_intrusive_ptr)->elementType->kind() == FloatType::Kind; } inline bool IValue::isTensorList() const { - return isList() && static_cast(payload.as_intrusive_ptr)->elementType->kind() == TensorType::Kind; + return isList() && static_cast(payload.u.as_intrusive_ptr)->elementType->kind() == TensorType::Kind; } inline bool IValue::isIntList() const { - return isList() && static_cast(payload.as_intrusive_ptr)->elementType->kind() == IntType::Kind; + return isList() && static_cast(payload.u.as_intrusive_ptr)->elementType->kind() == IntType::Kind; } inline bool IValue::isBoolList() const { - return isList() && static_cast(payload.as_intrusive_ptr)->elementType->kind() == BoolType::Kind; + return isList() && static_cast(payload.u.as_intrusive_ptr)->elementType->kind() == BoolType::Kind; } template<> @@ -2487,4 +2402,27 @@ inline std::shared_ptr Type::cast() const { } return nullptr; } + +// Used as a return type when inferring the IValue type of a Python object. +struct InferredType { + /* implicit */ InferredType(TypePtr type) : type_(std::move(type)) {} + /* implicit */ InferredType(std::string reason) + : type_(nullptr), reason_(std::move(reason)) {} + TypePtr type() const { + TORCH_INTERNAL_ASSERT(type_); + return type_; + } + bool success() const { + return type_ != nullptr; + } + const std::string& reason() const { + TORCH_INTERNAL_ASSERT(!type_); + return reason_; + } + +private: + TypePtr type_; + std::string reason_; +}; + } // namespace c10 diff --git a/aten/src/ATen/core/jit_type_base.h b/aten/src/ATen/core/jit_type_base.h new file mode 100644 index 0000000000000..e5a6d48340cf9 --- /dev/null +++ b/aten/src/ATen/core/jit_type_base.h @@ -0,0 +1,221 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +#define C10_FORALL_TYPES(_) \ + _(AnyType) \ + _(EnumType) \ + _(AnyEnumType) \ + _(TensorType) \ + _(StorageType) \ + _(TupleType) \ + _(ListType) \ + _(DictType) \ + _(NumberType) \ + _(FloatType) \ + _(FutureType) \ + _(RRefType) \ + _(IntType) \ + _(NoneType) \ + _(StringType) \ + _(GeneratorType) \ + _(QuantizerType) \ + _(BoolType) \ + _(OptionalType) \ + _(VarType) \ + _(DeviceObjType) \ + _(StreamObjType) \ + _(FunctionType) \ + _(ClassType) \ + _(PyObjectType) \ + _(CapsuleType) \ + _(InterfaceType) \ + _(QSchemeType) \ + _(LayoutType) \ + _(ScalarTypeType) \ + _(AnyListType) \ + _(AnyTupleType) \ + _(AnyClassType) + +enum class TypeKind { +#define DEFINE_TYPE(T) T, + C10_FORALL_TYPES(DEFINE_TYPE) +#undef DEFINE_TYPE +}; + +TORCH_API const char* typeKindToString(TypeKind kind); + +struct Type; +using TypePtr = std::shared_ptr; +using ConstTypePtr = std::shared_ptr; + +// Use this to customize how a Type is printed using `annotation_str()`. If +// c10::nullopt is returned, `annotation_str()` falls through to its default +// implementation. +using TypePrinter = + std::function(const ConstTypePtr&)>; + +struct TORCH_API Type : std::enable_shared_from_this { + private: + TypeKind kind_; + + protected: + Type(TypeKind kind) : kind_(kind) {} + + virtual std::string annotation_str_impl(TypePrinter printer) const { + return str(); + } + + public: + virtual bool operator==(const Type& rhs) const = 0; + + // subtyping relation. By default, we return true for the case + // when the type is exactly equal or if this <: T where rhs = Optional[T] + + // if this returns false and the why_not stream is non-null, it contains + // additional details that describe why this is not a subtype of 'rhs'. + // This additional information should only contain details that are not obvious + // from the annotation_str() that describes the type. For instance it is clear that `int <: str` is false + // but not clear why `Foo <: InterfaceBar` might be false. + virtual bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const; + virtual bool is_module() const; + bool isSubtypeOf(const TypePtr& rhs) const { + return isSubtypeOfExt(rhs, nullptr); + } + + // How this type will appear in FunctionSchema declarations + virtual std::string str() const = 0; + + // How this type will appear as if it were a type annotation in Python + // which is sometimes different than how it appears in declarations (e.g. + // int[] vs List[int]) + // + // Takes a custom printer that users can pass in to customize the output of + // this method. + std::string annotation_str(TypePrinter printer) const { + if (printer) { + // the printer can return nullopt to fall through to the default impl + if (auto renamed = printer(shared_from_this())) { + return *renamed; + } + } + return annotation_str_impl(printer); + } + std::string annotation_str() const { + // Overload instead of define a default value for `printer` to help + // debuggers out. + return annotation_str(nullptr); + } + + // Returns a human readable string that includes additional information like + // "type is inferred rather than explictly defined" to help construct more + // user-friendly messages. + virtual std::string repr_str() const { + return annotation_str(); + } + + TypeKind kind() const { + return kind_; + } + + virtual bool requires_grad() const { + for (const auto& ct : containedTypes()) { + if (ct->requires_grad()) { + return true; + } + } + return false; + } + + // Dynamically cast this object to the subclass indicated by the + // template variable, returning nullptr if the cast is invalid. + template + std::shared_ptr cast() { + if (T::Kind == kind()) { + return std::static_pointer_cast(shared_from_this()); + } + return nullptr; + } + template + std::shared_ptr cast() const { + if (T::Kind == kind()) { + return std::static_pointer_cast(shared_from_this()); + } + return nullptr; + } + template + T* castRaw() { + if (T::Kind == kind()) { + return static_cast(this); + } + return nullptr; + } + template + const T* castRaw() const { + if (T::Kind == kind()) { + return static_cast(this); + } + return nullptr; + } + template + std::shared_ptr expect() { + auto r = cast(); + AT_ASSERT(r); + return r; + } + template + std::shared_ptr expect() const { + auto r = cast(); + AT_ASSERT(r); + return r; + } + template + T& expectRef() { + auto* r = castRaw(); + AT_ASSERT(r); + return *r; + } + template + const T& expectRef() const { + auto* r = castRaw(); + AT_ASSERT(r); + return *r; + } + virtual ~Type() = default; + virtual bool hasFreeVariables() const { + return false; + } + // list of types this type contains, e.g. for a List then element type of a + // list for a tuple, the types of the tuple elements + virtual at::ArrayRef containedTypes() const { + return {}; + } + // create a new version of this type, replacing its contained types with + // contained_types + TypePtr withContained(std::vector contained_types) { + auto current_contained = containedTypes(); + AT_ASSERT(current_contained.size() == contained_types.size()); + if (current_contained.equals(contained_types)) { + return shared_from_this(); + } + return createWithContained(std::move(contained_types)); + } + // per-type constructor, you only need to override this if the + // containedTypes() is not empty + virtual TypePtr createWithContained( + std::vector contained_types) const { + AT_ERROR( + "type with contained types did not overload createWithContained: ", + str()); + } +}; + +} diff --git a/aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h b/aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h index ea7a5bd0b54c1..5dc435a22035e 100644 --- a/aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h +++ b/aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h @@ -207,12 +207,116 @@ constexpr auto with_explicit_optional_tensors(KernelFunc kernel_func) { return kernel_func; } +template constexpr bool is_out_argument_() { + return std::is_same::value; } +template using is_out_argument = guts::bool_constant()>; -template +template +struct with_out_arguments_reordered_impl final { +private: + // For an example op + // > aten::example(Tensor a, int64_t b, int64_t c, Tensor(a!) out_d, Tensor(b!) out_e) -> (Tensor(a!), Tensor(b!)) + // we get a KernelFunc + // > KernelFunc = std::tuple example(Tensor& out_d, Tensor& out_e, const Tensor& a, int64_t b, int64_t c) + // > NumOutParameters = 2 + // with the out arguments at the front, and reorder that into + // > std::tuple example(const Tensor& a, int64_t b, int64_t c, Tensor& out_d, Tensor& out_e) + // where the out arguments are in the back. + + using kernel_signature_traits = guts::infer_function_traits_t; + + // Assert that the KernelFunc is what we expect. The following block is + // not strictly necessary for the metaprogramming here, it's just a check. + static_assert( + guts::typelist::all< + is_out_argument, + guts::typelist::take_t< + typename kernel_signature_traits::parameter_types, + NumOutParameters + > + >::value, + "The kernel function has the wrong number of leading Tensor& arguments to match the out arguments in the JIT signature" + ); + + static constexpr size_t num_parameters = kernel_signature_traits::number_of_parameters; + static constexpr size_t num_nonout_parameters = num_parameters - NumOutParameters; + + // kernel_to_schema_permutation_indices contains a mapping from argument index in KernelFunc to the corresponding + // argument index in the schema. + // For the aten::example op, that'll be + // > kernel_to_schema_permutation_indices = [3, 4, 0, 1, 2] + // Interpreted as a mapping, this means + // - argument 0 in KernelFunc maps to argument 3 in the schema, + // - argument 1 in KernelFunc maps to argument 4 in the schema, + // - argument 2 in KernelFunc maps to argument 0 in the schema, + // - ... + // We can use this as a permutation function to reorder types or values correspondingly + using kernel_to_schema_permutation_indices = guts::concat_iseq_t< + guts::make_offset_index_sequence, + std::make_index_sequence + >; + + // For types, we need the inverse permutation because parameters (i.e. types) and arguments (i.e. values) + // need to be mapped in inverted directions. For types, we generate the schema order types from + // the KernelFunction types, but for arguments we get schema order arguments and need to generate + // the KernelFunction arguments. + // That's why in this reordering, we use NumOutParameters instead of the num_nonout_parameters we used above. + using schema_parameters = guts::typelist::concat_t< + guts::typelist::drop_t, + guts::typelist::take_t + >; + + template + struct wrapper_; + template + struct wrapper_, guts::typelist::typelist, std::index_sequence> { + static Return call(SchemaParameters... args) { + // call through to KernelFunc but reorder arguments as determined + // by the permutation we calculated above. + return (*KernelFunc::func_ptr())( + std::forward( + std::get( + std::tuple...>(args...) + ) + )... + ); + } + }; + +public: + using wrapper = wrapper_; +}; + + +/** + * Take a kernel function that has a number of `Tensor`, `const Tensor&` or `Tensor&` arguments + * where all `Tensor&` arguments are at the beginning, and take NumOutParameters. + * Create a wrapper function that has `NumOutParameters` `Tensor&` arguments at the end + * and calls through the underlying kernel function by reordering them to the front. + */ +template 0), int> = 0> +constexpr auto with_out_arguments_reordered(KernelFunc kernel_func) { + // SFINAE case for kernels that have out tensor arguments. + // Wrap them and reorder the arguments. + using impl = with_out_arguments_reordered_impl; + return TORCH_FN((&impl::wrapper::call)); +} + +template = 0> +constexpr auto with_out_arguments_reordered(KernelFunc kernel_func) { + // SFINAE case for kernels that don't have out tensor arguments. + // Don't wrap them but just use the kernel directly. + return kernel_func; +} + +} + +template constexpr auto hacky_wrapper_for_legacy_signatures(FuncPtr kernel_func) { - auto with_tensoroptions_scattered = detail::with_scattered_tensor_options(kernel_func); - auto result = detail::with_explicit_optional_tensors(with_tensoroptions_scattered); + auto with_scattered_tensor_options = detail::with_scattered_tensor_options(kernel_func); + auto with_out_arguments_reordered = detail::with_out_arguments_reordered(with_scattered_tensor_options); + auto result = detail::with_explicit_optional_tensors(with_out_arguments_reordered); static_assert(std::is_same::value, "Generated signature doesn't match the expected one."); return result; }; diff --git a/aten/src/ATen/core/op_registration/infer_schema.h b/aten/src/ATen/core/op_registration/infer_schema.h index 9dd73ae9fdc12..17bf6bb09c68a 100644 --- a/aten/src/ATen/core/op_registration/infer_schema.h +++ b/aten/src/ATen/core/op_registration/infer_schema.h @@ -22,6 +22,8 @@ namespace infer_schema { struct ArgumentDef final { using GetTypeFn = TypePtr(); GetTypeFn* getTypeFn; + constexpr ArgumentDef(): getTypeFn(nullptr) {} + explicit constexpr ArgumentDef(GetTypeFn *getTypeFn): getTypeFn(getTypeFn) {} }; template @@ -50,7 +52,7 @@ constexpr std::array createArgumentVectorFromTypes(s checkStaticTypes(), // Create the return value - std::array{{ArgumentDef{&getTypePtr_>::call}...}} + std::array{ArgumentDef(&getTypePtr_>::call)...} ); } @@ -151,6 +153,6 @@ FunctionSchema inferFunctionSchemaSingleReturn(std::string&& name, std::string&& return detail::infer_schema::createFunctionSchemaFromTraitsSingleReturn>(std::move(name), std::move(overload_name)); } -CAFFE2_API c10::optional findSchemaDifferences(const FunctionSchema& inferred, const FunctionSchema& specified); +TORCH_API c10::optional findSchemaDifferences(const FunctionSchema& inferred, const FunctionSchema& specified); } diff --git a/aten/src/ATen/core/op_registration/op_registration.cpp b/aten/src/ATen/core/op_registration/op_registration.cpp index a2a62152a5963..99fc2862614c4 100644 --- a/aten/src/ATen/core/op_registration/op_registration.cpp +++ b/aten/src/ATen/core/op_registration/op_registration.cpp @@ -1,5 +1,6 @@ #include +#include #include #if !defined(CAFFE2_IS_XPLAT_BUILD) #include diff --git a/aten/src/ATen/core/op_registration/op_registration.h b/aten/src/ATen/core/op_registration/op_registration.h index 63dc16b82ea2b..f7ab2d0919e77 100644 --- a/aten/src/ATen/core/op_registration/op_registration.h +++ b/aten/src/ATen/core/op_registration/op_registration.h @@ -7,7 +7,9 @@ #include #include -#include +#include +#include +#include #include #if defined(EXPOSE_C2_OPS) || !defined(CAFFE2_IS_XPLAT_BUILD) #include @@ -43,7 +45,7 @@ std::unique_ptr inferFunctionSchemaFromFunctor() { * > .schema("my_op") * > .kernel(DispatchKey::CPU)); */ -class CAFFE2_API RegisterOperators final { +class TORCH_API RegisterOperators final { public: RegisterOperators(); ~RegisterOperators(); @@ -53,7 +55,7 @@ class CAFFE2_API RegisterOperators final { RegisterOperators(RegisterOperators&&) noexcept; RegisterOperators& operator=(RegisterOperators&&) noexcept; - class CAFFE2_API Options final { + class TORCH_API Options final { public: Options(const Options&) = delete; Options(Options&&) noexcept = delete; @@ -237,7 +239,7 @@ class CAFFE2_API RegisterOperators final { return std::move(*this).kernel( std::move(dispatch_key), - KernelFunction::makeFromUnboxedFunction(CompileTimeFunctionPointer()), + KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)), impl::CppSignature::make(), // TODO Do schema inference without relying on WrapFunctionIntoFunctor detail::inferFunctionSchemaFromFunctor>::type>() diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index e49e32829cd84..56afe8ca7fb53 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -46,50 +46,6 @@ struct MockKernel final : OperatorKernel { bool* called_; }; -TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithAliasAnalysisAfterRegisteringWithoutAliasAnalysis_thenCanBeCalled) { - { - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU)); - // NB: this is OK right now for BC reasons - // expectThrows([&] { - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::XLA).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)); - // }, "Tried to define the schema for _test::dummy multiple times without providing an explicit alias analysis kind"); - } -} - -TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithoutAliasAnalysisAfterRegisteringWithAliasAnalysis_thenCanBeCalled) { - { - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::XLA).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)); - // NB: this is OK right now for BC reasons - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU)); - } -} - -TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithSameAliasAnalysis_thenCanBeCalled) { - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::XLA).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)); - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); - EXPECT_EQ(op->schema().aliasAnalysis(), at::AliasAnalysisKind::PURE_FUNCTION); -} - -TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithNoAliasAnalysis_thenCanBeCalled) { - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::XLA)); - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); - EXPECT_TRUE(op->schema().isDefaultAliasAnalysisKind()); - EXPECT_EQ(op->schema().aliasAnalysis(), at::AliasAnalysisKind::CONSERVATIVE); -} - -TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithDifferentAliasAnalysis_thenShouldThrow) { - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)); - expectThrows([] { - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::XLA).aliasAnalysis(at::AliasAnalysisKind::CONSERVATIVE)); - }, "Tried to define the schema for _test::dummy with different alias analysis kind"); -} - TEST(OperatorRegistrationTest, whenRegisteringWithSchemaBeforeKernelInOptionsObject_thenCanBeCalled) { bool called = false; auto registrar = c10::RegisterOperators().op(c10::RegisterOperators::options().schema("_test::dummy(Tensor dummy) -> ()").catchAllKernel(&called)); @@ -254,42 +210,6 @@ TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRunningOutOfScope_thenS EXPECT_FALSE(op.has_value()); } -TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringKernelAfterwards_thenCanBeCalled) { - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()"); - - bool called_kernel = false; - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU, &called_kernel)); - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); // assert schema is registered - callOp(*op, dummyTensor(c10::DispatchKey::CPU)); - EXPECT_TRUE(called_kernel); -} - -TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringKernelAfterwardsWithDifferentSchema_thenFails) { - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, int arg) -> ()"); - - bool called_kernel = false; - expectThrows([&] { - c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU, &called_kernel)); - }, "Tried to register multiple operators with the same name and the same overload name but different schemas"); -} - -TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringKernelAfterwardsAndRunsOutOfScope_thenSchemaIsStillThereButCannotBeCalledAnymore) { - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()"); - - { - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU)); - } - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); // assert schema is registered - expectThrows([&] { - callOp(*op, dummyTensor(c10::DispatchKey::CPU)); - }, "Could not run '_test::dummy' with arguments from the 'CPU'" - " backend."); -} - TEST(OperatorRegistrationTest, givenOpWithoutKernelsWithoutTensorInputs_whenRegistering_thenRegisters) { // as long as we don't register non-catchall kernels, ops without tensor arguments are fine auto registrar = c10::RegisterOperators().op("_test::dummy() -> ()"); @@ -298,21 +218,6 @@ TEST(OperatorRegistrationTest, givenOpWithoutKernelsWithoutTensorInputs_whenRegi ASSERT_TRUE(op.has_value()); // assert schema is registered } -TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenRegistering_thenShowsWarning) { - auto registrar = c10::RegisterOperators() - .op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU)); - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); // assert schema is registered - - testing::internal::CaptureStderr(); - c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU)); - std::string output = testing::internal::GetCapturedStderr(); - EXPECT_THAT(output, testing::HasSubstr("_test::dummy")); - EXPECT_THAT(output, testing::HasSubstr("CPU")); - EXPECT_THAT(output, testing::HasSubstr("overwrote a previously registered kernel with the same dispatch key for the same operator")); -} - TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenRegisteringInSameOpCall_thenFails) { expectThrows([&] { auto registrar = c10::RegisterOperators() @@ -322,35 +227,6 @@ TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenRegis }, "In operator registration: Tried to register multiple kernels with same dispatch key CPU for operator schema _test::dummy"); } -TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenCalled_thenCallsNewerKernel) { - bool called_kernel1 = false; - bool called_kernel2 = false; - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU, &called_kernel1)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU, &called_kernel2)); - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); // assert schema is registered - - callOp(*op, dummyTensor(c10::DispatchKey::CPU)); - EXPECT_FALSE(called_kernel1); - EXPECT_TRUE(called_kernel2); -} - -TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenRegistering_thenShowsWarning) { - auto registrar = c10::RegisterOperators() - .op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel()); - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); // assert schema is registered - - testing::internal::CaptureStderr(); - c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel()); - std::string output = testing::internal::GetCapturedStderr(); - EXPECT_THAT(output, testing::HasSubstr("_test::dummy")); - EXPECT_THAT(output, testing::HasSubstr("catch all")); - EXPECT_THAT(output, testing::HasSubstr("overwrote a previously registered kernel with the same dispatch key for the same operator")); -} - TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenRegisteringInSameOpCall_thenFails) { expectThrows([&] { auto registrar = c10::RegisterOperators() @@ -360,160 +236,6 @@ TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenRegisteringInSam }, "Tried to register multiple catch-all kernels for operator schema _test::dummy"); } -TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenCalled_thenCallsNewerKernel) { - bool called_kernel1 = false; - bool called_kernel2 = false; - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called_kernel1)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called_kernel2)); - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); // assert schema is registered - - callOp(*op, dummyTensor(c10::DispatchKey::CPU)); - EXPECT_FALSE(called_kernel1); - EXPECT_TRUE(called_kernel2); -} - -TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenNewerKernelDeletedAndOpCalled_thenCallsOlderKernel) { - bool called_kernel1 = false; - bool called_kernel2 = false; - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU, &called_kernel1)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU, &called_kernel2)); - - registrar2 = c10::RegisterOperators(); // destruct the registrar - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); // assert schema is registered - - callOp(*op, dummyTensor(c10::DispatchKey::CPU)); - EXPECT_TRUE(called_kernel1); - EXPECT_FALSE(called_kernel2); -} - -TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenNewerKernelDeletedAndOpCalled_thenCallsOlderKernel) { - bool called_kernel1 = false; - bool called_kernel2 = false; - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called_kernel1)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called_kernel2)); - - registrar2 = c10::RegisterOperators(); // destruct the registrar - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); // assert schema is registered - - callOp(*op, dummyTensor(c10::DispatchKey::CPU)); - EXPECT_TRUE(called_kernel1); - EXPECT_FALSE(called_kernel2); -} - -TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenOlderKernelDeletedAndOpCalled_thenCallsNewerKernel) { - bool called_kernel1 = false; - bool called_kernel2 = false; - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU, &called_kernel1)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU, &called_kernel2)); - - registrar1 = c10::RegisterOperators(); // destruct the registrar - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); // assert schema is registered - - callOp(*op, dummyTensor(c10::DispatchKey::CPU)); - EXPECT_FALSE(called_kernel1); - EXPECT_TRUE(called_kernel2); -} - -TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenOlderKernelDeletedAndOpCalled_thenCallsNewerKernel) { - bool called_kernel1 = false; - bool called_kernel2 = false; - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called_kernel1)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called_kernel2)); - - registrar1 = c10::RegisterOperators(); // destruct the registrar - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); // assert schema is registered - - callOp(*op, dummyTensor(c10::DispatchKey::CPU)); - EXPECT_FALSE(called_kernel1); - EXPECT_TRUE(called_kernel2); -} - -TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenOlderAndThenNewerKernelDeletedAndOpCalled_thenFails) { - bool called_kernel1 = false; - bool called_kernel2 = false; - auto registrar0 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()"); - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU, &called_kernel1)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU, &called_kernel2)); - - registrar1 = c10::RegisterOperators(); // destruct the registrar - registrar2 = c10::RegisterOperators(); // destruct the registrar - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); // assert schema is registered - - expectThrows([&] { - callOp(*op, dummyTensor(c10::DispatchKey::CPU)); - }, "Could not run '_test::dummy' with arguments from the 'CPU'" - " backend."); -} - -TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenOlderAndThenNewerKernelDeletedAndOpCalled_thenFails) { - bool called_kernel1 = false; - bool called_kernel2 = false; - auto registrar0 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()"); - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called_kernel1)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called_kernel2)); - - registrar1 = c10::RegisterOperators(); // destruct the registrar - registrar2 = c10::RegisterOperators(); // destruct the registrar - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); // assert schema is registered - - expectThrows([&] { - callOp(*op, dummyTensor(c10::DispatchKey::CPU)); - }, "Could not run '_test::dummy' with arguments from the 'CPU'" - " backend."); -} - -TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenNewerAndThenOlderKernelDeletedAndOpCalled_thenFails) { - bool called_kernel1 = false; - bool called_kernel2 = false; - auto registrar0 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()"); - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU, &called_kernel1)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPU, &called_kernel2)); - - registrar2 = c10::RegisterOperators(); // destruct the registrar - registrar1 = c10::RegisterOperators(); // destruct the registrar - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); // assert schema is registered - - expectThrows([&] { - callOp(*op, dummyTensor(c10::DispatchKey::CPU)); - }, "Could not run '_test::dummy' with arguments from the 'CPU'" - " backend."); -} - -TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenNewerAndThenOlderKernelDeletedAndOpCalled_thenFails) { - bool called_kernel1 = false; - bool called_kernel2 = false; - auto registrar0 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()"); - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called_kernel1)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called_kernel2)); - - registrar2 = c10::RegisterOperators(); // destruct the registrar - registrar1 = c10::RegisterOperators(); // destruct the registrar - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); // assert schema is registered - - expectThrows([&] { - callOp(*op, dummyTensor(c10::DispatchKey::CPU)); - }, "Could not run '_test::dummy' with arguments from the 'CPU'" - " backend."); -} - TEST(OperatorRegistrationTest, whenRegisteringCPUTensorType_thenCanOnlyCallUnboxedWithCPUDispatchKey) { bool called_kernel_cpu = false; auto registrar= c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() @@ -569,35 +291,6 @@ TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsInSameOpCallAndCall }, "CUDA"); } -TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsInSameOpCallOutOfScopeAndCalling_thenFails) { - auto registrar0 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()"); - { - bool called_kernel1 = false; - bool called_kernel2 = false; - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() - .kernel(c10::DispatchKey::CPU, &called_kernel1) - .kernel(c10::DispatchKey::CUDA, &called_kernel2)); - } - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); // assert schema is registered - - expectThrows([&] { - callOp(*op, dummyTensor(c10::DispatchKey::CPU)); - }, "Could not run '_test::dummy' with arguments from the 'CPU'" - " backend."); - - expectThrows([&] { - callOp(*op, dummyTensor(c10::DispatchKey::CUDA)); - }, "Could not run '_test::dummy' with arguments from the 'CUDA'" - " backend."); - - expectThrows([&] { - callOp(*op, dummyTensor(c10::DispatchKey::XLA)); - }, "Could not run '_test::dummy' with arguments from the 'XLA'" - " backend."); -} - bool called_stackbased_kernel = false; void stackBasedKernel(const OperatorHandle&, c10::Stack* stack) { called_stackbased_kernel = true; @@ -701,7 +394,7 @@ TEST(OperatorRegistrationTest, whenRegisteringMismatchingKernelsInSameOpCall_the auto registrar1 = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() .kernel(c10::DispatchKey::CPU) .kernel(c10::DispatchKey::CUDA, &called_kernel)); - }, "mismatched with a previous kernel that had the signature"); + }, "Mismatch in kernel C++ signatures"); } void backend_fallback_kernel(const c10::OperatorHandle& op, c10::Stack* stack) { @@ -777,22 +470,6 @@ TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndRegularKer EXPECT_TRUE(called); } -TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndCatchallKernelForSameBackend_thenCallsFallbackKernel) { - auto registrar = c10::Dispatcher::singleton().registerFallback(c10::DispatchKey::CPU, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>(), ""); - - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()", c10::RegisterOperators::options() - .catchAllKernel([] (Tensor, std::string) { - called = true; - })); - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); - - called = false; - auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello "); - EXPECT_FALSE(called); - EXPECT_EQ("hello _test::dummy", stack[1].toString()->string()); -} - bool called_autograd = false; bool called_nonautograd = false; @@ -835,20 +512,6 @@ TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_th EXPECT_TRUE(called_autograd); } -TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallRegularKernel) { - auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() - .kernel(DispatchKey::CPU) - .kernel(DispatchKey::Autograd)); - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); - - called_nonautograd = called_autograd = false; - op->typed().call(dummyTensor(DispatchKey::CPU)); - EXPECT_TRUE(called_nonautograd); - EXPECT_FALSE(called_autograd); -} - TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_thenCanCallAutogradKernel) { auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() .catchAllKernel() @@ -857,10 +520,11 @@ TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_t auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); + // catchAll now maps to Math which has higher precedence than Autograd called_nonautograd = called_autograd = false; op->typed().call(dummyTensor(DispatchKey::CPU, /*requires_grad=*/true)); - EXPECT_FALSE(called_nonautograd); - EXPECT_TRUE(called_autograd); + EXPECT_TRUE(called_nonautograd); + EXPECT_FALSE(called_autograd); } TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_thenCanCallCatchallKernel) { @@ -931,67 +595,68 @@ TEST(OperatorRegistrationTest, AutogradXLAOverridesAutogradKernel) { } TEST(OperatorRegistrationTest, whenRegisterWithXLAKernelAndCatchAll_AutogradXLAIsNotFilled) { - auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() - .catchAllKernel()); - - auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); - - called_nonautograd = called_autograd = false; - op->typed().call(dummyTensor(DispatchKey::XLA, /*requires_grad=*/true)); - EXPECT_TRUE(called_nonautograd); - EXPECT_FALSE(called_autograd); + { + auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() + .catchAllKernel()); - called_nonautograd = called_autograd = false; - op->typed().call(dummyTensor(DispatchKey::XLA)); - EXPECT_FALSE(called_autograd); - EXPECT_TRUE(called_nonautograd); + auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); + ASSERT_TRUE(op.has_value()); - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() - .kernel(DispatchKey::XLA)); + called_nonautograd = called_autograd = false; + op->typed().call(dummyTensor(DispatchKey::XLA, /*requires_grad=*/true)); + EXPECT_TRUE(called_nonautograd); + EXPECT_FALSE(called_autograd); - op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); - ASSERT_TRUE(op.has_value()); + called_nonautograd = called_autograd = false; + op->typed().call(dummyTensor(DispatchKey::XLA)); + EXPECT_FALSE(called_autograd); + EXPECT_TRUE(called_nonautograd); + } + { + auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() + .kernel(DispatchKey::XLA) + .catchAllKernel()); - // When there's direct registration to XLA backend, AutogradXLA doesn't pick up catchAll - // kernel in precompute but just keep fallthrough kernel from backend fallback. - // Thus it falls through AutogradXLA and reaches the kernel at XLA key. - called_nonautograd = called_autograd = false; - op->typed().call(dummyTensor(DispatchKey::XLA, /*requires_grad=*/true)); - EXPECT_FALSE(called_nonautograd); - EXPECT_TRUE(called_autograd); + auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); + ASSERT_TRUE(op.has_value()); - called_nonautograd = called_autograd = false; - op->typed().call(dummyTensor(DispatchKey::XLA)); - EXPECT_TRUE(called_autograd); - EXPECT_FALSE(called_nonautograd); + // When there's direct registration to XLA backend, AutogradXLA doesn't pick up catchAll + // kernel in precompute but just keep fallthrough kernel from backend fallback. + // Thus it falls through AutogradXLA and reaches the kernel at XLA key. + called_nonautograd = called_autograd = false; + op->typed().call(dummyTensor(DispatchKey::XLA, /*requires_grad=*/true)); + EXPECT_FALSE(called_nonautograd); + EXPECT_TRUE(called_autograd); + + called_nonautograd = called_autograd = false; + op->typed().call(dummyTensor(DispatchKey::XLA)); + EXPECT_TRUE(called_autograd); + EXPECT_FALSE(called_nonautograd); + } } TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringWithMismatchingCppSignatures_thenFails) { - auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() - .kernel(DispatchKey::CPU, [] (int64_t) {})); expectThrows([] { auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() - .kernel(DispatchKey::CPU, [] (const int64_t&) {})); - }, "mismatched with a previous kernel that had the signature"); + .kernel(DispatchKey::CPU, [] (const int64_t&) {}) + .kernel(DispatchKey::CUDA, [] (int64_t&) {})); + }, "Mismatch in kernel C++ signatures"); } TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringCatchAllAndBackendWithMismatchingCppSignatures_thenFails) { - auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() - .catchAllKernel([] (int64_t) {})); expectThrows([] { auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() - .kernel(DispatchKey::CPU, [] (const int64_t&) {})); - }, "mismatched with a previous kernel that had the signature"); + .kernel(DispatchKey::CPU, [] (const int64_t&) {}) + .catchAllKernel([] (int64_t) {})); + }, "Mismatch in kernel C++ signatures"); } TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringBackendAndCatchAllWithMismatchingCppSignatures_thenFails) { - auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() - .kernel(DispatchKey::CPU, [] (int64_t) {})); expectThrows([] { auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() - .catchAllKernel([] (const int64_t&) {})); - }, "mismatched with a previous kernel that had the signature"); + .catchAllKernel([] (const int64_t&) {}) + .kernel(DispatchKey::CPU, [] (int64_t) {})); + }, "Mismatch in kernel C++ signatures"); } TEST(OperatorRegistrationTest, givenLambdaKernel_whenAccessingWithMismatchingCppSignatures_thenFails) { @@ -1000,7 +665,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenAccessingWithMismatchingCpp expectThrows([] { c10::Dispatcher::singleton().findSchemaOrThrow("_test::dummy", "") .typed(); - }, "Tried to access operator _test::dummy with a wrong signature"); + }, "Tried to access or call an operator with a wrong signature.\n operator: _test::dummy(int _0) -> ()"); } TEST(OperatorRegistrationTest, givenLambdaKernel_whenAccessingCatchAllWithMismatchingCppSignatures_thenFails) { @@ -1009,7 +674,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenAccessingCatchAllWithMismat expectThrows([] { c10::Dispatcher::singleton().findSchemaOrThrow("_test::dummy", "") .typed(); - }, "Tried to access operator _test::dummy with a wrong signature"); + }, "Tried to access or call an operator with a wrong signature.\n operator: _test::dummy(int _0) -> ()"); } TEST(OperatorRegistrationTest, givenTorchLibrary_whenRegisteringWithMismatchingCppSignatures_thenFails) { @@ -1018,7 +683,7 @@ TEST(OperatorRegistrationTest, givenTorchLibrary_whenRegisteringWithMismatchingC m.impl("dummy", DispatchKey::CPU, [] (int64_t) {}); expectThrows([&] { m.impl("dummy", DispatchKey::CUDA, [] (const int64_t&) {}); - }, "mismatched with a previous kernel that had the signature"); + }, "Mismatch in kernel C++ signatures"); } TEST(OperatorRegistrationTest, givenTorchLibrary_whenAccessingWithMismatchingCppSignatures_thenFails) { @@ -1028,7 +693,7 @@ TEST(OperatorRegistrationTest, givenTorchLibrary_whenAccessingWithMismatchingCpp expectThrows([] { c10::Dispatcher::singleton().findSchemaOrThrow("_test::dummy", "") .typed(); - }, "Tried to access operator _test::dummy with a wrong signature"); + }, "Tried to access or call an operator with a wrong signature.\n operator: _test::dummy(int a) -> ()"); } TEST(OperatorRegistrationTest, givenTorchLibrary_whenAccessingCatchAllWithMismatchingCppSignatures_thenFails) { @@ -1037,7 +702,7 @@ TEST(OperatorRegistrationTest, givenTorchLibrary_whenAccessingCatchAllWithMismat expectThrows([] { c10::Dispatcher::singleton().findSchemaOrThrow("_test::dummy", "") .typed(); - }, "Tried to access operator _test::dummy with a wrong signature"); + }, "Tried to access or call an operator with a wrong signature.\n operator: _test::dummy(int a) -> ()"); } /** @@ -1627,6 +1292,39 @@ TEST(NewOperatorRegistrationTest, schema) { ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def4", ""})->schema().isDefaultAliasAnalysisKind()); } +TEST(NewOperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndCatchallKernelForSameBackend_thenCallsFallbackKernel) { + auto m1 = MAKE_TORCH_LIBRARY_IMPL(_, CPU); + m1.fallback(CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); + + bool called = false; + auto m = MAKE_TORCH_LIBRARY(test); + m.def("fn(Tensor t, str input) -> ()"); + m.impl("fn", [&] (Tensor, std::string) { called = true; }); + + auto op = Dispatcher::singleton().findSchema({"test::fn", ""}); + ASSERT_TRUE(op.has_value()); + + called = false; + auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello "); + // CatchAll now maps to Math and has higher precedence than backend fallback. + EXPECT_TRUE(called); +} + +TEST(NewOperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallRegularKernel) { + auto m = MAKE_TORCH_LIBRARY(test); + m.def("fn(Tensor dummy) -> ()"); + m.impl("fn", c10::DispatchKey::CPU, nonautograd_kernel); + m.impl("fn", c10::DispatchKey::Autograd, autograd_kernel); + + auto op = Dispatcher::singleton().findSchema({"test::fn", ""}); + ASSERT_TRUE(op.has_value()); + + called_nonautograd = called_autograd = false; + callOp(*op, dummyTensor(DispatchKey::CPU)); + EXPECT_TRUE(called_nonautograd); + EXPECT_FALSE(called_autograd); +} + TEST(NewOperatorRegistrationTest, dispatchWithMathKernel) { bool math_called = false; auto m = MAKE_TORCH_LIBRARY(test); @@ -1708,18 +1406,20 @@ TEST(NewOperatorRegistrationTest, dispatchWithMathAndCatchAllKernel) { auto op = Dispatcher::singleton().findSchema({"test::fn", ""}); ASSERT_TRUE(op.has_value()); + // catchAll now maps to Math, which means we have two registrations to Math key. + // The last registration is used. { catchall_called = math_called = false; callOp(*op, dummyTensor(c10::DispatchKey::CPU)); - ASSERT_TRUE(math_called); - ASSERT_FALSE(catchall_called); + ASSERT_FALSE(math_called); + ASSERT_TRUE(catchall_called); } { catchall_called = math_called = false; callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true)); - ASSERT_TRUE(math_called); - ASSERT_FALSE(catchall_called); + ASSERT_FALSE(math_called); + ASSERT_TRUE(catchall_called); } } @@ -1802,6 +1502,152 @@ TEST(NewOperatorRegistrationTest, BackendOverridesMathKernel) { } } +TEST(NewOperatorRegistrationTest, dispatchWithDefaultBackendKernel) { + bool called = false; + auto m = MAKE_TORCH_LIBRARY(test); + m.def("fn", torch::dispatch(c10::DispatchKey::DefaultBackend, [&](const Tensor& x) { called = true; return x; })); + + auto op = Dispatcher::singleton().findSchema({"test::fn", ""}); + ASSERT_TRUE(op.has_value()); + + { + ASSERT_FALSE(called); + callOp(*op, dummyTensor(c10::DispatchKey::CPU)); + ASSERT_TRUE(called); + } + + { + called = false; + // AutogradCPU is fallthrough, calls CPU kernel + callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true)); + ASSERT_TRUE(called); + } + + { + called = false; + callOp(*op, dummyTensor(c10::DispatchKey::XLA)); + ASSERT_TRUE(called); + } + + { + called = false; + // AutogradXLA is fallthrough, calls XLA kernel + callOp(*op, dummyTensor(c10::DispatchKey::XLA, /*requires_grad=*/true)); + ASSERT_TRUE(called); + } + + { + called = false; + callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU)); + ASSERT_TRUE(called); + } + + { + called = false; + // AutogradCPU is fallthrough, calls CPU kernel + callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU, /*requires_grad=*/true)); + ASSERT_TRUE(called); + } +} + +TEST(NewOperatorRegistrationTest, dispatchWithDefaultBackendAndMathKernel) { + bool backend_called = false; + bool math_called = false; + auto m = MAKE_TORCH_LIBRARY(test); + m.def("fn", torch::dispatch(c10::DispatchKey::DefaultBackend, [&](const Tensor& x) { backend_called = true; return x; })); + m.impl("fn", c10::DispatchKey::Math, [&](const Tensor& x) { math_called = true; return x; }); + + auto op = Dispatcher::singleton().findSchema({"test::fn", ""}); + ASSERT_TRUE(op.has_value()); + + { + backend_called = math_called = false; + callOp(*op, dummyTensor(c10::DispatchKey::CPU)); + ASSERT_TRUE(backend_called); + ASSERT_FALSE(math_called); + } + + { + backend_called = math_called = false; + // AutogradCPU is fallthrough, calls CPU kernel + callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true)); + ASSERT_FALSE(math_called); + ASSERT_TRUE(backend_called); + } + + { + backend_called = math_called = false; + callOp(*op, dummyTensor(c10::DispatchKey::XLA)); + ASSERT_TRUE(backend_called); + ASSERT_FALSE(math_called); + } + + { + backend_called = math_called = false; + // AutogradXLA is fallthrough, calls XLA kernel + callOp(*op, dummyTensor(c10::DispatchKey::XLA, /*requires_grad=*/true)); + ASSERT_FALSE(math_called); + ASSERT_TRUE(backend_called); + } + + { + backend_called = math_called = false; + callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU)); + ASSERT_TRUE(backend_called); + ASSERT_FALSE(math_called); + } + + { + backend_called = math_called = false; + // AutogradOther is fallthrough, calls SparseCPU kernel + callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU, /*requires_grad=*/true)); + ASSERT_FALSE(math_called); + ASSERT_TRUE(backend_called); + } +} + +TEST(NewOperatorRegistrationTest, BackendOverridesDefaultBackendKernel) { + bool default_called = false; + bool backend_called = false; + auto m = MAKE_TORCH_LIBRARY(test); + m.def("fn", torch::dispatch(c10::DispatchKey::DefaultBackend, [&](const Tensor& x) { default_called = true; return x; })); + m.impl("fn", c10::DispatchKey::CPU, [&](const Tensor& x) { backend_called = true; return x; }); + + auto op = Dispatcher::singleton().findSchema({"test::fn", ""}); + ASSERT_TRUE(op.has_value()); + + { + default_called = backend_called = false; + callOp(*op, dummyTensor(c10::DispatchKey::CPU)); + ASSERT_TRUE(backend_called); + ASSERT_FALSE(default_called); + } + + { + default_called = backend_called = false; + // AutogradCPU is fallthrough, calls CPU kernel + callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true)); + ASSERT_TRUE(backend_called); + ASSERT_FALSE(default_called); + } + + { + default_called = backend_called = false; + callOp(*op, dummyTensor(c10::DispatchKey::CUDA)); + ASSERT_TRUE(default_called); + ASSERT_FALSE(backend_called); + } + + { + default_called = backend_called = false; + // AutogradCUDA is fallthrough, calls CUDA kernel + callOp(*op, dummyTensor(c10::DispatchKey::CUDA, /*requires_grad=*/true)); + ASSERT_TRUE(default_called); + ASSERT_FALSE(backend_called); + } +} + + TEST(NewOperatorRegistrationTest, dispatch) { bool cpu_called = false; bool cuda_called = false; @@ -1866,7 +1712,7 @@ TEST(NewOperatorRegistrationTest, dispatchAutogradPrecedence) { } bool autograd_called = false; - m.def("fn", torch::dispatch(c10::kAutograd, [&](const Tensor& x) { autograd_called = true; return x; })); + m.impl("fn", c10::kAutograd, [&](const Tensor& x) { autograd_called = true; return x; }); { auto op = Dispatcher::singleton().findSchema({"test::fn", ""}); @@ -1876,7 +1722,7 @@ TEST(NewOperatorRegistrationTest, dispatchAutogradPrecedence) { // Autograd backend kernel has higher precedence than Autograd alias. bool autogradcpu_called = false; - m.def("fn", torch::dispatch(c10::DispatchKey::AutogradCPU, [&](const Tensor& x) { autogradcpu_called = true; return x; })); + m.impl("fn", c10::DispatchKey::AutogradCPU, [&](const Tensor& x) { autogradcpu_called = true; return x; }); { auto op = Dispatcher::singleton().findSchema({"test::fn", ""}); @@ -1909,6 +1755,10 @@ TEST(NewOperatorRegistrationTest, throwsWhenRegisterToBackendMapsToAutogradOther TEST(NewOperatorRegistrationTest, dispatchMultipleTensors) { bool privateuse1_called = false; bool catchall_called = false; + // Similar to in-tree AutogradCPU/AutogradCUDA etc, out-of-tree backends usually register + // a fallthrough kernel for AutogradPrivateUse1. + auto m1 = MAKE_TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1); + m1.fallback(CppFunction::makeFallthrough()); auto m = MAKE_TORCH_LIBRARY(test); m.def("fn", torch::dispatch(c10::DispatchKey::PrivateUse1, [&](const Tensor& x, const Tensor& y) { privateuse1_called = true; return x; })); @@ -2059,7 +1909,7 @@ TEST(NewOperatorRegistrationTest, CppFunction) { m.def("fn3", [](const Tensor& x) { return x; }); // These require explicit schema m.def("fn4(Tensor x) -> Tensor", CppFunction::makeFallthrough()); - m.def("fn5(Tensor x) -> Tensor", CppFunction::makeUnboxedOnly(dummy_fn)); + m.def("fn5(Tensor x) -> Tensor", CppFunction::makeFromUnboxedFunction(dummy_fn)); m.def("fn6(Tensor x) -> Tensor", CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); } @@ -2095,6 +1945,22 @@ TEST(NewOperatorRegistrationTest, testDelayedListener) { EXPECT_EQ(initial_num_deregisters + 1, listener_ptr->num_deregisters_); } +TEST(NewOperatorRegistrationTest, testImplNoDefGetsCaught) { + auto danglingImpls = Dispatcher::singleton().findDanglingImpls(); + std::string error_str = "Discovered operators that have been registered through the dispatcher" + " without explicitly specifying their schemas. Please do so using" + " the TORCH_LIBRARY macro. Suspect operators:\n"; + for (auto& op : danglingImpls) { + auto& op_name = op.operator_name(); + error_str += "\t" + op_name.name; + if (op_name.overload_name != "") { + error_str += "." + op_name.overload_name; + } + error_str += "\n"; + } + ASSERT_EQ(danglingImpls.size(), 0) << error_str; +} + } #pragma GCC diagnostic pop diff --git a/aten/src/ATen/core/op_registration/op_whitelist.h b/aten/src/ATen/core/op_registration/op_whitelist.h index c8437e924a3c9..26d5533244d73 100644 --- a/aten/src/ATen/core/op_registration/op_whitelist.h +++ b/aten/src/ATen/core/op_registration/op_whitelist.h @@ -36,7 +36,9 @@ namespace impl { // returns true iff whitelist contains item // op_whitelist_contains("a;bc;d", "bc") == true constexpr bool op_whitelist_contains(string_view whitelist, string_view item) { - size_t next = -1; + //Choose a really big value for next so that if something goes wrong + //this code will blow up in a hopefully detectable way. + size_t next = std::numeric_limits::max(); for (size_t cur = 0; cur <= whitelist.size(); cur = next) { next = whitelist.find(';', cur); if (next != string_view::npos) { diff --git a/aten/src/ATen/core/operator_name.h b/aten/src/ATen/core/operator_name.h index b120a079a7c8f..2a926977f001a 100644 --- a/aten/src/ATen/core/operator_name.h +++ b/aten/src/ATen/core/operator_name.h @@ -72,8 +72,8 @@ inline bool operator!=(const OperatorName& lhs, const OperatorName& rhs) { return !operator==(lhs, rhs); } -CAFFE2_API std::string toString(const OperatorName& opName); -CAFFE2_API std::ostream& operator<<(std::ostream&, const OperatorName&); +TORCH_API std::string toString(const OperatorName& opName); +TORCH_API std::ostream& operator<<(std::ostream&, const OperatorName&); } // namespace c10 diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index 13e82d4346472..878d08c032323 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -27,7 +27,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) { out << "Tensor"; } if (auto ndim = value->sizes().size()) { - bool has_valid_strides_info = + bool has_valid_strides_info = *ndim > 0 && value->strides().isComplete() && value->strides().size() == ndim; out << "("; @@ -41,10 +41,17 @@ std::ostream& operator<<(std::ostream & out, const Type & t) { } else { out << "*"; } - if (has_valid_strides_info && - type_verbosity() >= TypeVerbosity::TypeAndStride) { - out << ":" << *value->strides()[i]; + } + if (has_valid_strides_info && + type_verbosity() >= TypeVerbosity::TypeAndStride) { + out << ", strides=["; + for (size_t i = 0; i < *ndim; ++i) { + if (i > 0) { + out << ", "; + } + out << *value->strides()[i]; } + out << "]"; } if (type_verbosity() >= TypeVerbosity::Full) { if (value->requiresGrad()) { @@ -61,6 +68,21 @@ std::ostream& operator<<(std::ostream & out, const Type & t) { } } out << ")"; + } else { + if (type_verbosity() >= TypeVerbosity::Full) { + size_t i = 0; + if (value->requiresGrad()) { + out << "(" + << "requires_grad=" << *value->requiresGrad(); + i++; + } + if (value->device()) { + out << ((i++ > 0) ? ", " : "(") << "device=" << *value->device(); + } + if (i > 0) { + out << ")"; + } + } } if (value->undefined() && *value->undefined()) { @@ -127,6 +149,10 @@ BoolTypePtr BoolType::get() { static auto value = BoolType::create(); return value; } +StorageTypePtr StorageType::get() { + static auto value = StorageType::create(); + return value; +} NoneTypePtr NoneType::get() { static auto value = NoneType::create(); return value; @@ -151,6 +177,10 @@ DeviceObjTypePtr DeviceObjType::get() { static auto value = DeviceObjType::create(); return value; } +StreamObjTypePtr StreamObjType::get() { + static auto value = StreamObjType::create(); + return value; +} ScalarTypeTypePtr ScalarTypeType::get() { static auto value = ScalarTypeType::create(); return value; @@ -222,7 +252,7 @@ c10::optional unifyTypesImpl(const TypePtr& t1, const TypePtr& t2) { // Handle non-container types which do not subtype each other and unify if (t1->kind() == TensorType::Kind && t2->kind() == TensorType::Kind) { - return t1->expect()->merge(t2->expect()); + return t1->expectRef().merge(*t2->expect()); } if (t1->isSubtypeOf(NoneType::get()) && !t2->isSubtypeOf(NoneType::get())) { @@ -425,7 +455,7 @@ MatchTypeReturn matchTypeVariables( // unknown type). return matchTypeVariables(opt_formal->getElementType(), actual, type_env); } - // note: if actual was non here we potentially did not fill in the type + // note: if actual was None here we potentially did not fill in the type // variables contained in the formal. It is still a valid match because None // matches Optional[T] later error checking on tryEvalTypeVariables will // report the problem if we never match variables in type T @@ -454,7 +484,7 @@ MatchTypeReturn matchTypeVariables( } // change return types like List[List[t]] into List[List[int]] -CAFFE2_API TypePtr tryEvalTypeVariables(TypePtr type, std::unordered_map& type_env) { +TORCH_API TypePtr tryEvalTypeVariables(TypePtr type, std::unordered_map& type_env) { if (!type->hasFreeVariables()) { return type; } @@ -479,7 +509,7 @@ CAFFE2_API TypePtr tryEvalTypeVariables(TypePtr type, std::unordered_mapkind() == OptionalType::Kind || elem_type->kind() == NumberType::Kind) { // Builtin Union types @@ -506,7 +536,7 @@ const char * typeKindToString(TypeKind kind) { return ""; } -bool Type::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const { +bool Type::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { if (rhs->kind() == TypeKind::AnyType || *this == *rhs) { return true; } @@ -549,16 +579,16 @@ VaryingShape TensorType::sizes() const { })); } -TensorTypePtr TensorType::merge(TensorTypePtr other, bool merge_sizes) const { - auto scalar_type = merge_primitive(scalarType(), other->scalarType()); - auto dev = merge_primitive(device(), other->device()); - auto sprops = stride_properties().merge(other->stride_properties()); - auto gr = merge_primitive(requiresGrad(), other->requiresGrad()); - auto undef = merge_primitive(undefined(), other->undefined()); +TensorTypePtr TensorType::merge(const TensorType& other, bool merge_sizes) const { + auto scalar_type = merge_primitive(scalarType(), other.scalarType()); + auto dev = merge_primitive(device(), other.device()); + auto sprops = stride_properties().merge(other.stride_properties()); + auto gr = merge_primitive(requiresGrad(), other.requiresGrad()); + auto undef = merge_primitive(undefined(), other.undefined()); return TensorType::create( scalar_type, dev, - merge_sizes ? symbolic_sizes().merge(other->symbolic_sizes()) + merge_sizes ? symbolic_sizes().merge(other.symbolic_sizes()) : symbolic_sizes(), sprops, gr, @@ -585,8 +615,10 @@ bool TensorType::matchTensor(const at::Tensor& t) { } // Here we know t.defined() == true and compare all other properties. bool rg = at::GradMode::is_enabled() && t.requires_grad(); - bool matched_strides = (!t.has_storage() && !stride_properties().isComplete()) - || stride_properties() == computeStrideProps(t.sizes(), t.strides(), t.is_contiguous()); + bool matched_strides = (!stride_properties().size()) || + (!t.has_storage() && !stride_properties().isComplete()) || + stride_properties() == + computeStrideProps(t.sizes(), t.strides(), t.is_contiguous()); return scalarType().value_or(t.scalar_type()) == t.scalar_type() && device().value_or(t.device()) == t.device() && requiresGrad().value_or(rg) == rg @@ -728,7 +760,7 @@ TupleType::TupleType( } } -bool TupleType::isSubtypeOfExt(const TypePtr rhs_, std::ostream* why_not) const { +bool TupleType::isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const { if (Type::isSubtypeOfExt(rhs_, why_not)) { return true; } @@ -763,7 +795,7 @@ bool TupleType::isSubtypeOfExt(const TypePtr rhs_, std::ostream* why_not) const }); } -bool ListType::isSubtypeOfExt(const TypePtr rhs_, std::ostream* why_not) const { +bool ListType::isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const { if (Type::isSubtypeOfExt(rhs_, why_not)) { return true; } @@ -967,11 +999,9 @@ TensorTypePtr TensorType::create( const SymbolicShape& sizes, const VaryingShape& strides, c10::optional requires_grad, - c10::optional undefined, - bool is_inferred) { - auto pt = TensorTypePtr(new TensorType( + c10::optional undefined) { + auto pt = TensorTypePtr(new TensorType( scalar_type, device, sizes, strides, requires_grad, undefined)); - pt->is_inferred_ = is_inferred; return pt; } @@ -1006,13 +1036,13 @@ const SymbolicShape& TensorType::symbolic_sizes() const { return sizes_; } -bool TensorType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const { +bool TensorType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { if (auto rhs_p = rhs->cast()) { // if we have the same pointer, avoid computing the merge if (this == rhs_p.get()) { return true; } - return *merge(rhs_p) == *rhs_p; + return *merge(*rhs_p) == *rhs_p; } return Type::isSubtypeOfExt(rhs, why_not); } @@ -1087,7 +1117,7 @@ ClassTypePtr ClassType::refine(at::ArrayRef refined_slots) const { return ptr; } -bool ClassType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const { +bool ClassType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { if (rhs->cast()) { return true; } @@ -1171,7 +1201,7 @@ bool InterfaceType::isSubTypeImpl( return true; } -bool InterfaceType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const { +bool InterfaceType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { // to improve performance this check can be cached if (auto iface = rhs->cast()) { return isSubTypeImpl(*this, *iface, why_not); @@ -1200,19 +1230,21 @@ InterfaceType::~InterfaceType() = default; ClassTypePtr ClassType::create( c10::optional qualifiedName, std::weak_ptr cu, - bool is_module) { + bool is_module, + std::string doc_string) { return ClassTypePtr( - new ClassType(std::move(qualifiedName), std::move(cu), is_module)); + new ClassType(std::move(qualifiedName), std::move(cu), is_module, std::move(doc_string))); } ClassType::ClassType( c10::optional name, std::weak_ptr cu, - bool is_module = false) + bool is_module = false, + std::string doc_string = "") : NamedType(TypeKind::ClassType, std::move(name)), compilation_unit_(std::move(cu)), - isModule_(is_module) { -} + isModule_(is_module), + doc_string_(std::move(doc_string)) {} const std::vector& ClassType::methods() const { return methods_; @@ -1285,7 +1317,7 @@ size_t ClassType::addAttribute( TORCH_CHECK( (type->kind() == TensorType::Kind) || (type->kind() == OptionalType::Kind && - type->expect()->getElementType()->kind() == + type->expectRef().getElementType()->kind() == TensorType::Kind) || (type->kind() == NoneType::Kind), "Expecting parameter or buffer to have either None, Tensor or Optional[Tensor] type, but got: ", @@ -1302,6 +1334,14 @@ void ClassType::unsafeRemoveAttribute(const std::string& name) { AT_ASSERT(attributes_.size() == attributeTypes_.size()); } +void ClassType::unsafeChangeAttributeType(const std::string& name, TypePtr new_ty) { + auto slot = getAttributeSlot(name); + auto old_attr_info = attributes_[slot]; + AT_ASSERT(old_attr_info.getKind() == AttributeKind::REGULAR_ATTRIBUTE); + attributes_[slot] = ClassAttribute(old_attr_info.getKind(), new_ty, old_attr_info.getName()); + attributeTypes_[slot] = new_ty; +} + size_t ClassType::addConstant(const std::string& name, const IValue& value) { checkNotExist(name, "constant"); size_t slot = constantNames_.size(); @@ -1419,7 +1459,7 @@ SymbolicShape SymbolicShape::merge(const SymbolicShape& other) const { return SymbolicShape(std::move(dims)); } -bool EnumType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const { +bool EnumType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { return rhs->kind() == TypeKind::AnyType || rhs->kind() == TypeKind::AnyEnumType || *this == *rhs; } diff --git a/aten/src/ATen/cpp_custom_type_hack.h b/aten/src/ATen/cpp_custom_type_hack.h index 9f8f61f534edb..d690a00e0c2c3 100644 --- a/aten/src/ATen/cpp_custom_type_hack.h +++ b/aten/src/ATen/cpp_custom_type_hack.h @@ -1,11 +1,52 @@ -// WARNING! WARNING! WARNING! -// This file is a temporary hack to enable development of pytorch quantization -// -// It's a stub for wrapping arbitrary cpp types in TorchScript. Proper -// implementation (under development) is to use TorchScript custom types. -// In the meantime, we abuse ByteTensor with custom deleter for this purpose. -// -// Template argument has to be registered with CAFFE_KNOWN_TYPE mechanism. +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP + +// YOU ARE IN THE WRONG PLACE! TURN BACK NOW! + +// This code was a temporary hack to enable embedding arbitrary C++ structures +// into Tensors. THIS IS UNSAFE AND IS NOT SUPPORTED. IF YOU USE THIS CODE, +// IT __WILL__ BREAK. + +// This code has been superseded by custom classes: +// https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html + +// Please use custom classes and **DO NOT ADD MORE CALLSITES TO THINGS DEFINED +// IN THIS FILE**. + +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP #include #include @@ -14,6 +55,8 @@ namespace at { namespace cpp_custom_type_hack { template +[[deprecated("Use custom classes instead: " + "https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] bool isa(const Tensor& packed) { return (packed.scalar_type() == kByte) && (packed.storage().data_ptr().get_deleter() == @@ -21,6 +64,8 @@ bool isa(const Tensor& packed) { } template +[[deprecated("Use custom classes instead: " + "https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] T& cast(const Tensor& packed) { TORCH_CHECK( packed.scalar_type() == kByte, "Expected temporary cpp type wrapper"); @@ -33,6 +78,8 @@ T& cast(const Tensor& packed) { } template +[[deprecated("Use custom classes instead: " + "https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] Tensor create(std::unique_ptr ptr, TensorOptions options) { // None of this should trace, so turn off Tracer dispatching at::AutoNonVariableTypeMode guard; // TODO: remove diff --git a/aten/src/ATen/cpu/vec256/intrinsics.h b/aten/src/ATen/cpu/vec256/intrinsics.h index fc3fd8547d81c..26b8b78a8cee7 100644 --- a/aten/src/ATen/cpu/vec256/intrinsics.h +++ b/aten/src/ATen/cpu/vec256/intrinsics.h @@ -32,6 +32,11 @@ (defined(__VEC__) || defined(__ALTIVEC__)) /* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */ #include +/* We need to undef those tokens defined by to avoid conflicts + with the C++ types. => Can still use __bool/__vector */ +#undef bool +#undef vector +#undef pixel #elif defined(__GNUC__) && defined(__SPE__) /* GCC-compatible compiler, targeting PowerPC with SPE */ #include diff --git a/aten/src/ATen/cpu/vec256/missing_vld1_neon.h b/aten/src/ATen/cpu/vec256/missing_vld1_neon.h index 3bc17f3fa6a4c..5540c8bc782fa 100644 --- a/aten/src/ATen/cpu/vec256/missing_vld1_neon.h +++ b/aten/src/ATen/cpu/vec256/missing_vld1_neon.h @@ -5,7 +5,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1_u8_x2 (const uint8_t *__a) { uint8x8x2_t ret; - asm ("ld1 {%S0.8b - %T0.8b}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -14,7 +14,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1_s8_x2 (const int8_t *__a) { int8x8x2_t ret; - asm ("ld1 {%S0.8b - %T0.8b}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -23,7 +23,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1_u16_x2 (const uint16_t *__a) { uint16x4x2_t ret; - asm ("ld1 {%S0.4h - %T0.4h}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -32,7 +32,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1_s16_x2 (const int16_t *__a) { int16x4x2_t ret; - asm ("ld1 {%S0.4h - %T0.4h}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -41,7 +41,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1_u32_x2 (const uint32_t *__a) { uint32x2x2_t ret; - asm ("ld1 {%S0.2s - %T0.2s}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -50,7 +50,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1_s32_x2 (const int32_t *__a) { int32x2x2_t ret; - asm ("ld1 {%S0.2s - %T0.2s}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -59,7 +59,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1_u64_x2 (const uint64_t *__a) { uint64x1x2_t ret; - asm ("ld1 {%S0.1d - %T0.1d}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -69,7 +69,7 @@ vld1_s64_x2 (const int64_t *__a) { int64x1x2_t ret; __builtin_aarch64_simd_oi __o; - asm ("ld1 {%S0.1d - %T0.1d}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -78,7 +78,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1_f16_x2 (const float16_t *__a) { float16x4x2_t ret; - asm ("ld1 {%S0.4h - %T0.4h}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -87,7 +87,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1_f32_x2 (const float32_t *__a) { float32x2x2_t ret; - asm ("ld1 {%S0.2s - %T0.2s}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -96,7 +96,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1_f64_x2 (const float64_t *__a) { float64x1x2_t ret; - asm ("ld1 {%S0.1d - %T0.1d}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -105,7 +105,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1_p8_x2 (const poly8_t *__a) { poly8x8x2_t ret; - asm ("ld1 {%S0.8b - %T0.8b}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -114,7 +114,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1_p16_x2 (const poly16_t *__a) { poly16x4x2_t ret; - asm ("ld1 {%S0.4h - %T0.4h}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -123,7 +123,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1_p64_x2 (const poly64_t *__a) { poly64x1x2_t ret; - asm ("ld1 {%S0.1d - %T0.1d}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -132,7 +132,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1q_u8_x2 (const uint8_t *__a) { uint8x16x2_t ret; - asm ("ld1 {%S0.16b - %T0.16b}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -141,7 +141,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1q_s8_x2 (const int8_t *__a) { int8x16x2_t ret; - asm ("ld1 {%S0.16b - %T0.16b}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -150,7 +150,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1q_u16_x2 (const uint16_t *__a) { uint16x8x2_t ret; - asm ("ld1 {%S0.8h - %T0.8h}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -159,7 +159,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1q_s16_x2 (const int16_t *__a) { int16x8x2_t ret; - asm ("ld1 {%S0.8h - %T0.8h}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -168,7 +168,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1q_u32_x2 (const uint32_t *__a) { uint32x4x2_t ret; - asm ("ld1 {%S0.4s - %T0.4s}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -177,7 +177,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1q_s32_x2 (const int32_t *__a) { int32x4x2_t ret; - asm ("ld1 {%S0.4s - %T0.4s}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -186,7 +186,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1q_u64_x2 (const uint64_t *__a) { uint64x2x2_t ret; - asm ("ld1 {%S0.2d - %T0.2d}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -195,7 +195,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1q_s64_x2 (const int64_t *__a) { int64x2x2_t ret; - asm ("ld1 {%S0.2d - %T0.2d}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -204,7 +204,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1q_f16_x2 (const float16_t *__a) { float16x8x2_t ret; - asm ("ld1 {%S0.8h - %T0.8h}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -213,7 +213,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1q_f32_x2 (const float32_t *__a) { float32x4x2_t ret; - asm ("ld1 {%S0.4s - %T0.4s}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -222,7 +222,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1q_f64_x2 (const float64_t *__a) { float64x2x2_t ret; - asm ("ld1 {%S0.2d - %T0.2d}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -231,7 +231,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1q_p8_x2 (const poly8_t *__a) { poly8x16x2_t ret; - asm ("ld1 {%S0.16b - %T0.16b}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -240,7 +240,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1q_p16_x2 (const poly16_t *__a) { poly16x8x2_t ret; - asm ("ld1 {%S0.8h - %T0.8h}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -249,7 +249,7 @@ __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vld1q_p64_x2 (const poly64_t *__a) { poly64x2x2_t ret; - asm ("ld1 {%S0.2d - %T0.2d}, [%1]" : "=w" (ret) : "r"(__a) :); + asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w" (ret) : "Q"(*__a)); return ret; } @@ -259,194 +259,194 @@ __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1_s64_x2 (int64_t * __a, int64x1x2_t val) { - asm ("st1 {%S0.1d - %T0.1d}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1_u64_x2 (uint64_t * __a, uint64x1x2_t val) { - asm ("st1 {%S0.1d - %T0.1d}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1_f64_x2 (float64_t * __a, float64x1x2_t val) { - asm ("st1 {%S0.1d - %T0.1d}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1_s8_x2 (int8_t * __a, int8x8x2_t val) { - asm ("st1 {%S0.8b - %T0.8b}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1_p8_x2 (poly8_t * __a, poly8x8x2_t val) { - asm ("st1 {%S0.8b - %T0.8b}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1_s16_x2 (int16_t * __a, int16x4x2_t val) { - asm ("st1 {%S0.4h - %T0.4h}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1_p16_x2 (poly16_t * __a, poly16x4x2_t val) { - asm ("st1 {%S0.4h - %T0.4h}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1_s32_x2 (int32_t * __a, int32x2x2_t val) { - asm ("st1 {%S0.2s - %T0.2s}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1_u8_x2 (uint8_t * __a, uint8x8x2_t val) { - asm ("st1 {%S0.8b - %T0.8b}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1_u16_x2 (uint16_t * __a, uint16x4x2_t val) { - asm ("st1 {%S0.4h - %T0.4h}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1_u32_x2 (uint32_t * __a, uint32x2x2_t val) { - asm ("st1 {%S0.2s - %T0.2s}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1_f16_x2 (float16_t * __a, float16x4x2_t val) { - asm ("st1 {%S0.4h - %T0.4h}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1_f32_x2 (float32_t * __a, float32x2x2_t val) { - asm ("st1 {%S0.2s - %T0.2s}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1_p64_x2 (poly64_t * __a, poly64x1x2_t val) { - asm ("st1 {%S0.1d - %T0.1d}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1q_s8_x2 (int8_t * __a, int8x16x2_t val) { - asm ("st1 {%S0.16b - %T0.16b}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1q_p8_x2 (poly8_t * __a, poly8x16x2_t val) { - asm ("st1 {%S0.16b - %T0.16b}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1q_s16_x2 (int16_t * __a, int16x8x2_t val) { - asm ("st1 {%S0.8h - %T0.8h}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1q_p16_x2 (poly16_t * __a, poly16x8x2_t val) { - asm ("st1 {%S0.8h - %T0.8h}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1q_s32_x2 (int32_t * __a, int32x4x2_t val) { - asm ("st1 {%S0.4s - %T0.4s}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1q_s64_x2 (int64_t * __a, int64x2x2_t val) { - asm ("st1 {%S0.2d - %T0.2d}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1q_u8_x2 (uint8_t * __a, uint8x16x2_t val) { - asm ("st1 {%S0.16b - %T0.16b}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1q_u16_x2 (uint16_t * __a, uint16x8x2_t val) { - asm ("st1 {%S0.8h - %T0.8h}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1q_u32_x2 (uint32_t * __a, uint32x4x2_t val) { - asm ("st1 {%S0.4s - %T0.4s}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1q_u64_x2 (uint64_t * __a, uint64x2x2_t val) { - asm ("st1 {%S0.2d - %T0.2d}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1q_f16_x2 (float16_t * __a, float16x8x2_t val) { - asm ("st1 {%S0.8h - %T0.8h}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1q_f32_x2 (float32_t * __a, float32x4x2_t val) { - asm ("st1 {%S0.4s - %T0.4s}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1q_f64_x2 (float64_t * __a, float64x2x2_t val) { - asm ("st1 {%S0.2d - %T0.2d}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q" (*__a) : "w" (val)); } __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1q_p64_x2 (poly64_t * __a, poly64x2x2_t val) { - asm ("st1 {%S0.2d - %T0.2d}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q" (*__a) : "w" (val)); } diff --git a/aten/src/ATen/cpu/vec256/missing_vst1_neon.h b/aten/src/ATen/cpu/vec256/missing_vst1_neon.h index dbb2ba479f858..711d16f9b231f 100644 --- a/aten/src/ATen/cpu/vec256/missing_vst1_neon.h +++ b/aten/src/ATen/cpu/vec256/missing_vst1_neon.h @@ -4,6 +4,5 @@ __extension__ extern __inline void __attribute__ ((__always_inline__, __gnu_inline__, __artificial__)) vst1q_f32_x2 (float32_t * __a, float32x4x2_t val) { - asm ("st1 {%S0.4s - %T0.4s}, [%1]" :: "w" (val), "r"(__a) :); + asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q" (*__a) : "w" (val)); } - diff --git a/aten/src/ATen/cpu/vec256/vec256.h b/aten/src/ATen/cpu/vec256/vec256.h index 96d17a9e1afa8..ae40b9a5b4fd9 100644 --- a/aten/src/ATen/cpu/vec256/vec256.h +++ b/aten/src/ATen/cpu/vec256/vec256.h @@ -6,6 +6,7 @@ #include #include +#if !defined(__VSX__) || !defined(CPU_CAPABILITY_VSX) #include #include #include @@ -14,6 +15,9 @@ #include #include #include +#else +#include +#endif #include #include diff --git a/aten/src/ATen/cpu/vec256/vec256_base.h b/aten/src/ATen/cpu/vec256/vec256_base.h index 49acbc518dca9..477e366ea18b3 100644 --- a/aten/src/ATen/cpu/vec256/vec256_base.h +++ b/aten/src/ATen/cpu/vec256/vec256_base.h @@ -251,7 +251,7 @@ struct Vec256 { Vec256 angle() const { // other_t_angle is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same::value, "other_t_angle must be T"); - return Vec256(0); + return map(at::native::angle_impl); // compiler is unable to resolve the overload without } template ::value, int>::type = 0> @@ -394,6 +394,20 @@ struct Vec256 { Vec256 i0() const { return map(calc_i0); } + Vec256 igamma(const Vec256 &x) const { + Vec256 ret; + for (int64_t i = 0; i < size(); i++) { + ret[i] = calc_igamma(values[i], x[i]); + } + return ret; + } + Vec256 igammac(const Vec256 &x) const { + Vec256 ret; + for (int64_t i = 0; i < size(); i++) { + ret[i] = calc_igammac(values[i], x[i]); + } + return ret; + } Vec256 neg() const { // NB: the trailing return type is needed because we need to coerce the // return value back to T in the case of unary operator- incuring a @@ -615,23 +629,12 @@ inline T minimum(const T& a, const T& b) { return c; } -// To save BC, it will not propagate NaN based on IEEE 754 201X template ::value, int>::type = 0> Vec256 inline clamp(const Vec256 &a, const Vec256 &min_vec, const Vec256 &max_vec) { Vec256 c = Vec256(); for (int i = 0; i != Vec256::size(); i++) { - c[i] = a[i] < min_vec[i] ? min_vec[i] : (a[i] > max_vec[i] ? max_vec[i] : a[i]); - } - return c; -} - -template ::value, int>::type = 0> -Vec256 inline clamp(const Vec256 &a, const Vec256 &min_vec, const Vec256 &max_vec) { - Vec256 c = Vec256(); - for (int i = 0; i != Vec256::size(); i++) { - c[i] = std::abs(a[i]) < std::abs(min_vec[i]) ? min_vec[i] : (std::abs(a[i]) > std::abs(max_vec[i]) ? max_vec[i] : a[i]); + c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]); } return c; } @@ -646,16 +649,6 @@ Vec256 inline clamp_max(const Vec256 &a, const Vec256 &max_vec) { return c; } -template ::value, int>::type = 0> -Vec256 inline clamp_max(const Vec256 &a, const Vec256 &max_vec) { - Vec256 c = Vec256(); - for (int i = 0; i != Vec256::size(); i++) { - c[i] = std::abs(a[i]) > std::abs(max_vec[i]) ? max_vec[i] : a[i]; - } - return c; -} - template ::value, int>::type = 0> Vec256 inline clamp_min(const Vec256 &a, const Vec256 &min_vec) { @@ -666,16 +659,6 @@ Vec256 inline clamp_min(const Vec256 &a, const Vec256 &min_vec) { return c; } -template ::value, int>::type = 0> -Vec256 inline clamp_min(const Vec256 &a, const Vec256 &min_vec) { - Vec256 c = Vec256(); - for (int i = 0; i != Vec256::size(); i++) { - c[i] = std::abs(a[i]) < std::abs(min_vec[i]) ? min_vec[i] : a[i]; - } - return c; -} - struct Vec256i; #ifdef CPU_CAPABILITY_AVX2 @@ -736,6 +719,14 @@ inline Vec256 operator^(const Vec256& a, const Vec256& b) { #endif +template>::value, int> = 0> +inline Vec256 operator~(const Vec256& a) { + Vec256 ones; // All bits are 1 + memset((T*) ones, 0xFF, 32); + return a ^ ones; +} + + template inline Vec256& operator += (Vec256& a, const Vec256& b) { a = a + b; diff --git a/aten/src/ATen/cpu/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec256/vec256_bfloat16.h index 37d41676e53c9..dbe9cf374d959 100644 --- a/aten/src/ATen/cpu/vec256/vec256_bfloat16.h +++ b/aten/src/ATen/cpu/vec256/vec256_bfloat16.h @@ -25,7 +25,7 @@ static inline void cvtbf16_fp32(const __m256i& a, __m256& o1, __m256& o2) { static inline __m256i cvtfp32_bf16(const __m256& a, const __m256& b) { __m256i lo = _mm256_castps_si256(a); __m256i hi = _mm256_castps_si256(b); - __m256i nan = _mm256_set1_epi32(0x7fc0); + __m256i nan = _mm256_set1_epi32(0xffff); __m256i mask_lo = _mm256_castps_si256(_mm256_cmp_ps(a, a, _CMP_ORD_Q)); __m256i mask_hi = _mm256_castps_si256(_mm256_cmp_ps(b, b, _CMP_ORD_Q)); __m256i ones = _mm256_set1_epi32(0x1); @@ -203,7 +203,23 @@ template <> class Vec256 { return cvtfp32_bf16(o1, o2); } Vec256 angle() const { - return _mm256_set1_epi16(0); + __m256 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto angle_lambda = [](__m256 values) { + const auto zero_vec = _mm256_set1_ps(0.f); + const auto nan_vec = _mm256_set1_ps(NAN); + const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ); + const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ); + const auto pi = _mm256_set1_ps(M_PI); + + const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask); + angle = _mm256_blendv_ps(angle, nan_vec, nan_mask); + return angle; + }; + auto o1 = angle_lambda(lo); + auto o2 = angle_lambda(hi); + return cvtfp32_bf16(o1, o2); } Vec256 real() const { return *this; @@ -290,6 +306,45 @@ template <> class Vec256 { auto o2 = _mm256_loadu_ps(tmp2); return cvtfp32_bf16(o1, o2); } + Vec256 igamma(const Vec256 &x) const { + __m256 lo, hi; + __m256 xlo, xhi; + cvtbf16_fp32(values, lo, hi); + cvtbf16_fp32(x.values, xlo, xhi); + __at_align32__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align32__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm256_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]); + } + auto o1 = _mm256_loadu_ps(tmp1); + auto o2 = _mm256_loadu_ps(tmp2); + return cvtfp32_bf16(o1, o2); + } + + Vec256 igammac(const Vec256 &x) const { + __m256 lo, hi; + __m256 xlo, xhi; + cvtbf16_fp32(values, lo, hi); + cvtbf16_fp32(x.values, xlo, xhi); + __at_align32__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align32__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm256_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]); + } + auto o1 = _mm256_loadu_ps(tmp1); + auto o2 = _mm256_loadu_ps(tmp2); + return cvtfp32_bf16(o1, o2); + } Vec256 log() const { return map(Sleef_logf8_u10); } @@ -446,7 +501,7 @@ Vec256 inline Vec256::operator==(const Vec256& oth } Vec256 inline Vec256::operator!=(const Vec256& other) const { return bfloat16_binary_op_as_fp32(*this, other, [](__m256 x, __m256 y) { - return _mm256_cmp_ps(x, y, _CMP_NEQ_OQ); + return _mm256_cmp_ps(x, y, _CMP_NEQ_UQ); }); } diff --git a/aten/src/ATen/cpu/vec256/vec256_complex_double.h b/aten/src/ATen/cpu/vec256/vec256_complex_double.h index 0827b33a31228..a9f9b6a776cf8 100644 --- a/aten/src/ATen/cpu/vec256/vec256_complex_double.h +++ b/aten/src/ATen/cpu/vec256/vec256_complex_double.h @@ -252,6 +252,12 @@ template <> class Vec256> { Vec256> hypot(const Vec256> &b) const { AT_ERROR("not supported for complex numbers"); } + Vec256> igamma(const Vec256> &x) const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> igammac(const Vec256> &x) const { + AT_ERROR("not supported for complex numbers"); + } Vec256> neg() const { auto zero = _mm256_setzero_pd(); return _mm256_sub_pd(zero, values); @@ -272,18 +278,7 @@ template <> class Vec256> { return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); } Vec256> sqrt() const { - // sqrt(a + bi) - // = sqrt(2)/2 * [sqrt(sqrt(a**2 + b**2) + a) + sgn(b)*sqrt(sqrt(a**2 + b**2) - a)i] - // = sqrt(2)/2 * [sqrt(abs() + a) + sgn(b)*sqrt(abs() - a)i] - - const __m256d scalar = _mm256_set1_pd(std::sqrt(2)/2); //sqrt(2)/2 sqrt(2)/2 - const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0); - auto sign = _mm256_and_pd(values, sign_mask); - auto factor = _mm256_or_pd(scalar, sign); - - auto a_a = _mm256_xor_pd(_mm256_movedup_pd(values), sign_mask); // a -a - auto res_re_im = _mm256_sqrt_pd(_mm256_add_pd(abs_(), a_a)); // sqrt(abs + a) sqrt(abs - a) - return _mm256_mul_pd(factor, res_re_im); + return map(std::sqrt); } Vec256> reciprocal() const; Vec256> rsqrt() const { @@ -306,7 +301,7 @@ template <> class Vec256> { return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ); } Vec256> operator!=(const Vec256>& other) const { - return _mm256_cmp_pd(values, other.values, _CMP_NEQ_OQ); + return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ); } Vec256> operator<(const Vec256>& other) const { TORCH_CHECK(false, "not supported for complex numbers"); @@ -416,32 +411,6 @@ Vec256> inline minimum(const Vec256>& return _mm256_or_pd(min, isnan); } -template <> -Vec256> inline clamp(const Vec256>& a, const Vec256>& min, const Vec256>& max) { - auto abs_a = a.abs_2_(); - auto abs_min = min.abs_2_(); - auto max_mask = _mm256_cmp_pd(abs_a, abs_min, _CMP_LT_OQ); - auto abs_max = max.abs_2_(); - auto min_mask = _mm256_cmp_pd(abs_a, abs_max, _CMP_GT_OQ); - return _mm256_blendv_pd(_mm256_blendv_pd(a, min, max_mask), max, min_mask); -} - -template <> -Vec256> inline clamp_min(const Vec256>& a, const Vec256>& min) { - auto abs_a = a.abs_2_(); - auto abs_min = min.abs_2_(); - auto max_mask = _mm256_cmp_pd(abs_a, abs_min, _CMP_LT_OQ); - return _mm256_blendv_pd(a, min, max_mask); -} - -template <> -Vec256> inline clamp_max(const Vec256>& a, const Vec256>& max) { - auto abs_a = a.abs_2_(); - auto abs_max = max.abs_2_(); - auto min_mask = _mm256_cmp_pd(abs_a, abs_max, _CMP_GT_OQ); - return _mm256_blendv_pd(a, max, min_mask); -} - template <> Vec256> inline operator&(const Vec256>& a, const Vec256>& b) { return _mm256_and_pd(a, b); diff --git a/aten/src/ATen/cpu/vec256/vec256_complex_float.h b/aten/src/ATen/cpu/vec256/vec256_complex_float.h index ea931acc494b5..0398567629caf 100644 --- a/aten/src/ATen/cpu/vec256/vec256_complex_float.h +++ b/aten/src/ATen/cpu/vec256/vec256_complex_float.h @@ -290,6 +290,12 @@ template <> class Vec256> { Vec256> hypot(const Vec256> &b) const { AT_ERROR("not supported for complex numbers"); } + Vec256> igamma(const Vec256> &x) const { + AT_ERROR("not supported for complex numbers"); + } + Vec256> igammac(const Vec256> &x) const { + AT_ERROR("not supported for complex numbers"); + } Vec256> neg() const { auto zero = _mm256_setzero_ps(); return _mm256_sub_ps(zero, values); @@ -310,18 +316,7 @@ template <> class Vec256> { return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); } Vec256> sqrt() const { - // sqrt(a + bi) - // = sqrt(2)/2 * [sqrt(sqrt(a**2 + b**2) + a) + sgn(b)*sqrt(sqrt(a**2 + b**2) - a)i] - // = sqrt(2)/2 * [sqrt(abs() + a) + sgn(b)*sqrt(abs() - a)i] - - const __m256 scalar = _mm256_set1_ps(std::sqrt(2)/2); //sqrt(2)/2 sqrt(2)/2 - const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); - auto sign = _mm256_and_ps(values, sign_mask); - auto factor = _mm256_or_ps(scalar, sign); - - auto a_a = _mm256_xor_ps(_mm256_moveldup_ps(values), sign_mask); // a -a - auto res_re_im = _mm256_sqrt_ps(_mm256_add_ps(abs_(), a_a)); // sqrt(abs + a) sqrt(abs - a) - return _mm256_mul_ps(factor, res_re_im); + return map(std::sqrt); } Vec256> reciprocal() const; Vec256> rsqrt() const { @@ -344,7 +339,7 @@ template <> class Vec256> { return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ); } Vec256> operator!=(const Vec256>& other) const { - return _mm256_cmp_ps(values, other.values, _CMP_NEQ_OQ); + return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ); } Vec256> operator<(const Vec256>& other) const { TORCH_CHECK(false, "not supported for complex numbers"); @@ -456,32 +451,6 @@ Vec256> inline minimum(const Vec256>& a, return _mm256_or_ps(min, isnan); } -template <> -Vec256> inline clamp(const Vec256>& a, const Vec256>& min, const Vec256>& max) { - auto abs_a = a.abs_2_(); - auto abs_min = min.abs_2_(); - auto max_mask = _mm256_cmp_ps(abs_a, abs_min, _CMP_LT_OQ); - auto abs_max = max.abs_2_(); - auto min_mask = _mm256_cmp_ps(abs_a, abs_max, _CMP_GT_OQ); - return _mm256_blendv_ps(_mm256_blendv_ps(a, min, max_mask), max, min_mask); -} - -template <> -Vec256> inline clamp_min(const Vec256>& a, const Vec256>& min) { - auto abs_a = a.abs_2_(); - auto abs_min = min.abs_2_(); - auto max_mask = _mm256_cmp_ps(abs_a, abs_min, _CMP_LT_OQ); - return _mm256_blendv_ps(a, min, max_mask); -} - -template <> -Vec256> inline clamp_max(const Vec256>& a, const Vec256>& max) { - auto abs_a = a.abs_2_(); - auto abs_max = max.abs_2_(); - auto min_mask = _mm256_cmp_ps(abs_a, abs_max, _CMP_GT_OQ); - return _mm256_blendv_ps(a, max, min_mask); -} - template <> Vec256> inline operator&(const Vec256>& a, const Vec256>& b) { return _mm256_and_ps(a, b); diff --git a/aten/src/ATen/cpu/vec256/vec256_double.h b/aten/src/ATen/cpu/vec256/vec256_double.h index fcad154e68b2c..0bea07dbf5929 100644 --- a/aten/src/ATen/cpu/vec256/vec256_double.h +++ b/aten/src/ATen/cpu/vec256/vec256_double.h @@ -108,7 +108,16 @@ template <> class Vec256 { return _mm256_andnot_pd(mask, values); } Vec256 angle() const { - return _mm256_set1_pd(0); + const auto zero_vec = _mm256_set1_pd(0.f); + const auto nan_vec = _mm256_set1_pd(NAN); + const auto not_nan_mask = _mm256_cmp_pd(values, values, _CMP_EQ_OQ); + const auto nan_mask = _mm256_cmp_pd(not_nan_mask, zero_vec, _CMP_EQ_OQ); + const auto pi = _mm256_set1_pd(M_PI); + + const auto neg_mask = _mm256_cmp_pd(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm256_blendv_pd(zero_vec, pi, neg_mask); + angle = _mm256_blendv_pd(angle, nan_vec, nan_mask); + return angle; } Vec256 real() const { return *this; @@ -155,6 +164,26 @@ template <> class Vec256 { Vec256 i0() const { return map(calc_i0); } + Vec256 igamma(const Vec256 &x) const { + __at_align32__ double tmp[size()]; + __at_align32__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vec256 igammac(const Vec256 &x) const { + __at_align32__ double tmp[size()]; + __at_align32__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } Vec256 log() const { return Vec256(Sleef_logd4_u10(values)); } @@ -227,7 +256,7 @@ template <> class Vec256 { } Vec256 operator!=(const Vec256& other) const { - return _mm256_cmp_pd(values, other.values, _CMP_NEQ_OQ); + return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ); } Vec256 operator<(const Vec256& other) const { diff --git a/aten/src/ATen/cpu/vec256/vec256_float.h b/aten/src/ATen/cpu/vec256/vec256_float.h index 1ab11ea81529d..a8fd65b0ba79b 100644 --- a/aten/src/ATen/cpu/vec256/vec256_float.h +++ b/aten/src/ATen/cpu/vec256/vec256_float.h @@ -115,7 +115,16 @@ template <> class Vec256 { return _mm256_andnot_ps(mask, values); } Vec256 angle() const { - return _mm256_set1_ps(0); + const auto zero_vec = _mm256_set1_ps(0.f); + const auto nan_vec = _mm256_set1_ps(NAN); + const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ); + const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ); + const auto pi = _mm256_set1_ps(M_PI); + + const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask); + angle = _mm256_blendv_ps(angle, nan_vec, nan_mask); + return angle; } Vec256 real() const { return *this; @@ -193,6 +202,26 @@ template <> class Vec256 { Vec256 i0() const { return map(calc_i0); } + Vec256 igamma(const Vec256 &x) const { + __at_align32__ float tmp[size()]; + __at_align32__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vec256 igammac(const Vec256 &x) const { + __at_align32__ float tmp[size()]; + __at_align32__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } Vec256 neg() const { return _mm256_xor_ps(_mm256_set1_ps(-0.f), values); } @@ -234,7 +263,7 @@ template <> class Vec256 { } Vec256 operator!=(const Vec256& other) const { - return _mm256_cmp_ps(values, other.values, _CMP_NEQ_OQ); + return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ); } Vec256 operator<(const Vec256& other) const { diff --git a/aten/src/ATen/cpu/vec256/vec256_float_neon.h b/aten/src/ATen/cpu/vec256/vec256_float_neon.h index cfe6b0ea0fb36..58a4afac17cb9 100644 --- a/aten/src/ATen/cpu/vec256/vec256_float_neon.h +++ b/aten/src/ATen/cpu/vec256/vec256_float_neon.h @@ -259,12 +259,12 @@ template <> class Vec256 { // Only required because vec256_qint refers to this. // Once we specialize that implementation for ARM // this should be removed. TODO (kimishpatel) - const float operator[](int idx) const { + float operator[](int idx) const { __at_align32__ float tmp[size()]; store(tmp); return tmp[idx]; - }; - const float operator[](int idx) { + } + float operator[](int idx) { __at_align32__ float tmp[size()]; store(tmp); return tmp[idx]; @@ -362,6 +362,26 @@ template <> class Vec256 { Vec256 i0() const { return map(calc_i0); } + Vec256 igamma(const Vec256 &x) const { + __at_align32__ float tmp[size()]; + __at_align32__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vec256 igammac(const Vec256 &x) const { + __at_align32__ float tmp[size()]; + __at_align32__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } Vec256 log() const { return map(std::log); } @@ -432,14 +452,23 @@ template <> class Vec256 { vsqrtq_f32(values.val[1])); } Vec256 reciprocal() const { - return Vec256( - vrecpeq_f32(values.val[0]), - vrecpeq_f32(values.val[1])); + float32x4_t r0 = vrecpeq_f32(values.val[0]); + float32x4_t r1 = vrecpeq_f32(values.val[1]); + // Run two more Netwon's method iterations to get more accurate results + r0 = vmulq_f32(vrecpsq_f32(values.val[0], r0), r0); + r0 = vmulq_f32(vrecpsq_f32(values.val[0], r0), r0); + r1 = vmulq_f32(vrecpsq_f32(values.val[1], r1), r1); + r1 = vmulq_f32(vrecpsq_f32(values.val[1], r1), r1); + return Vec256(r0, r1); } Vec256 rsqrt() const { - return Vec256( - vrsqrteq_f32(values.val[0]), - vrsqrteq_f32(values.val[1])); + float32x4_t r0 = vrsqrteq_f32(values.val[0]); + float32x4_t r1 = vrsqrteq_f32(values.val[1]); + r0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(values.val[0], r0), r0), r0); + r0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(values.val[0], r0), r0), r0); + r1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(values.val[1], r1), r1), r1); + r1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(values.val[1], r1), r1), r1); + return Vec256(r0, r1); } Vec256 pow(const Vec256 &exp) const { __at_align32__ float tmp[size()]; @@ -665,6 +694,6 @@ Vec256 inline fmadd(const Vec256& a, const Vec256& b, const return Vec256(r0, r1); } -#endif +#endif /* defined(aarch64) */ }}} diff --git a/aten/src/ATen/cpu/vec256/vec256_int.h b/aten/src/ATen/cpu/vec256/vec256_int.h index 98afd8bdd33cc..2ba2744d35268 100644 --- a/aten/src/ATen/cpu/vec256/vec256_int.h +++ b/aten/src/ATen/cpu/vec256/vec256_int.h @@ -104,6 +104,8 @@ class Vec256 : public Vec256i { } void store(void* ptr, int count = size()) const { if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); } else if (count > 0) { __at_align32__ int64_t tmp_values[size()]; @@ -119,9 +121,6 @@ class Vec256 : public Vec256i { auto inverse = _mm256_xor_si256(values, is_larger); return _mm256_sub_epi64(inverse, is_larger); } - Vec256 angle() const { - return _mm256_set1_epi64x(0); - } Vec256 real() const { return *this; } @@ -228,6 +227,8 @@ class Vec256 : public Vec256i { } void store(void* ptr, int count = size()) const { if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); } else if (count > 0) { __at_align32__ int32_t tmp_values[size()]; @@ -246,9 +247,6 @@ class Vec256 : public Vec256i { Vec256 abs() const { return _mm256_abs_epi32(values); } - Vec256 angle() const { - return _mm256_set1_epi32(0); - } Vec256 real() const { return *this; } @@ -449,6 +447,8 @@ class Vec256 : public Vec256i { } void store(void* ptr, int count = size()) const { if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); } else if (count > 0) { __at_align32__ int16_t tmp_values[size()]; @@ -461,9 +461,6 @@ class Vec256 : public Vec256i { Vec256 abs() const { return _mm256_abs_epi16(values); } - Vec256 angle() const { - return _mm256_set1_epi16(0); - } Vec256 real() const { return *this; } @@ -699,6 +696,8 @@ class Vec256 : public Vec256i { } void store(void* ptr, int count = size()) const { if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); } else if (count > 0) { __at_align32__ int8_t tmp_values[size()]; @@ -711,9 +710,6 @@ class Vec256 : public Vec256i { Vec256 abs() const { return _mm256_abs_epi8(values); } - Vec256 angle() const { - return _mm256_set1_epi8(0); - } Vec256 real() const { return *this; } @@ -879,8 +875,8 @@ Vec256 inline operator*(const Vec256& a, const Vec256 template Vec256 inline int_elementwise_binary_256(const Vec256& a, const Vec256& b, Op op) { - __at_align32__ T values_a[Vec256::size()]; - __at_align32__ T values_b[Vec256::size()]; + T values_a[Vec256::size()]; + T values_b[Vec256::size()]; a.store(values_a); b.store(values_b); for (int i = 0; i != Vec256::size(); i++) { @@ -1039,6 +1035,10 @@ template>: inline Vec256 operator^(const Vec256& a, const Vec256& b) { return _mm256_xor_si256(a, b); } +template>::value, int> = 0> +inline Vec256 operator~(const Vec256& a) { + return _mm256_xor_si256(a, _mm256_set1_epi32(-1)); +} Vec256 Vec256::eq(const Vec256& other) const { return (*this == other) & Vec256(1); diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_common_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_common_vsx.h new file mode 100644 index 0000000000000..516179932d34c --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_common_vsx.h @@ -0,0 +1,216 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace at { +namespace vec256 { + +namespace { + +DEFINE_CLAMP_FUNCS(c10::quint8) +DEFINE_CLAMP_FUNCS(c10::qint8) +DEFINE_CLAMP_FUNCS(c10::qint32) +DEFINE_CLAMP_FUNCS(int16_t) +DEFINE_CLAMP_FUNCS(int32_t) +DEFINE_CLAMP_FUNCS(int64_t) +DEFINE_CLAMP_FUNCS(float) +DEFINE_CLAMP_FUNCS(double) + +template <> +Vec256 C10_ALWAYS_INLINE fmadd( + const Vec256& a, + const Vec256& b, + const Vec256& c) { + return Vec256{ + vec_madd(a.vec0(), b.vec0(), c.vec0()), + vec_madd(a.vec1(), b.vec1(), c.vec1())}; +} + +template <> +Vec256 C10_ALWAYS_INLINE fmadd( + const Vec256& a, + const Vec256& b, + const Vec256& c) { + return Vec256{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} +template <> +Vec256 C10_ALWAYS_INLINE fmadd( + const Vec256& a, + const Vec256& b, + const Vec256& c) { + return Vec256{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} +template <> +Vec256 C10_ALWAYS_INLINE fmadd( + const Vec256& a, + const Vec256& b, + const Vec256& c) { + return Vec256{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} + +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(float) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(double) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int64_t) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int32_t) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int16_t) + +template <> +Vec256 C10_ALWAYS_INLINE +convert_to_int_of_same_size(const Vec256& src) { + return Vec256{vec_signed(src.vec0()), vec_signed(src.vec1())}; +} + +template <> +Vec256 C10_ALWAYS_INLINE +convert_to_int_of_same_size( + const Vec256& src) { + return Vec256{vec_signed(src.vec0()), vec_signed(src.vec1())}; +} + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + // int32_t and float have same size + int64_t i; + for (i = 0; i <= (n - Vec256::size()); i += Vec256::size()) { + const int32_t* src_a = src + i; + float* dst_a = dst + i; + vint32 input_vec0 = vec_vsx_ld(offset0, reinterpret_cast(src_a)); + vint32 input_vec1 = + vec_vsx_ld(offset16, reinterpret_cast(src_a)); + vfloat32 c0 = vec_float(input_vec0); + vfloat32 c1 = vec_float(input_vec1); + vec_vsx_st(c0, offset0, dst_a); + vec_vsx_st(c1, offset16, dst_a); + } + + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int64_t* src, double* dst, int64_t n) { + int64_t i; + for (i = 0; i <= (n - Vec256::size()); i += Vec256::size()) { + const int64_t* src_a = src + i; + double* dst_a = dst + i; + vint64 input_vec0 = + vec_vsx_ld(offset0, reinterpret_cast(src_a)); + vint64 input_vec1 = + vec_vsx_ld(offset16, reinterpret_cast(src_a)); + vfloat64 c0 = vec_double(input_vec0); + vfloat64 c1 = vec_double(input_vec1); + vec_vsx_st(c0, offset0, reinterpret_cast(dst_a)); + vec_vsx_st(c1, offset16, reinterpret_cast(dst_a)); + } + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +std::pair, Vec256> inline interleave2( + const Vec256& a, + const Vec256& b) { + // inputs: + // a = {a0, a1, a2, a3} + // b = {b0, b1, b2, b3} + + vfloat64 ab00 = vec_xxpermdi(a.vec0(), b.vec0(), 0); + vfloat64 ab11 = vec_xxpermdi(a.vec0(), b.vec0(), 3); + vfloat64 ab2_00 = vec_xxpermdi(a.vec1(), b.vec1(), 0); + vfloat64 ab2_11 = vec_xxpermdi(a.vec1(), b.vec1(), 3); + // return {a0, b0, a1, b1} + // {a2, b2, a3, b3} + return std::make_pair( + Vec256{ab00, ab11}, Vec256{ab2_00, ab2_11}); +} + +template <> +std::pair, Vec256> inline deinterleave2( + const Vec256& a, + const Vec256& b) { + // inputs: + // a = {a0, b0, a1, b1} + // b = {a2, b2, a3, b3} + vfloat64 aa01 = vec_xxpermdi(a.vec0(), a.vec1(), 0); + vfloat64 aa23 = vec_xxpermdi(b.vec0(), b.vec1(), 0); + + vfloat64 bb_01 = vec_xxpermdi(a.vec0(), a.vec1(), 3); + vfloat64 bb_23 = vec_xxpermdi(b.vec0(), b.vec1(), 3); + + // swap lanes: + // return {a0, a1, a2, a3} + // {b0, b1, b2, b3} + return std::make_pair( + Vec256{aa01, aa23}, Vec256{bb_01, bb_23}); +} + +template <> +std::pair, Vec256> inline interleave2( + const Vec256& a, + const Vec256& b) { + // inputs: + // a = {a0, a1, a2, a3,, a4, a5, a6, a7} + // b = {b0, b1, b2, b3,, b4, b5, b6, b7} + + vfloat32 ab0011 = vec_mergeh(a.vec0(), b.vec0()); + vfloat32 ab2233 = vec_mergel(a.vec0(), b.vec0()); + + vfloat32 ab2_0011 = vec_mergeh(a.vec1(), b.vec1()); + vfloat32 ab2_2233 = vec_mergel(a.vec1(), b.vec1()); + // group cols crossing lanes: + // return {a0, b0, a1, b1,, a2, b2, a3, b3} + // {a4, b4, a5, b5,, a6, b6, a7, b7} + + return std::make_pair( + Vec256{ab0011, ab2233}, Vec256{ab2_0011, ab2_2233}); +} + +template <> +std::pair, Vec256> inline deinterleave2( + const Vec256& a, + const Vec256& b) { + // inputs: + // a = {a0, b0, a1, b1,, a2, b2, a3, b3} + // b = {a4, b4, a5, b5,, a6, b6, a7, b7} + + // {a0,a2,b0,b2} {a1,a3,b1,b3} + vfloat32 a0a2b0b2 = vec_mergeh(a.vec0(), a.vec1()); + vfloat32 a1a3b1b3 = vec_mergel(a.vec0(), a.vec1()); + + vfloat32 aa0123 = vec_mergeh(a0a2b0b2, a1a3b1b3); + vfloat32 bb0123 = vec_mergel(a0a2b0b2, a1a3b1b3); + + vfloat32 a0a2b0b2_2 = vec_mergeh(b.vec0(), b.vec1()); + vfloat32 a1a3b1b3_2 = vec_mergel(b.vec0(), b.vec1()); + + vfloat32 aa0123_2 = vec_mergeh(a0a2b0b2_2, a1a3b1b3_2); + vfloat32 bb0123_2 = vec_mergel(a0a2b0b2_2, a1a3b1b3_2); + + // it could be done with vec_perm ,too + // swap lanes: + // return {a0, a1, a2, a3,, a4, a5, a6, a7} + // {b0, b1, b2, b3,, b4, b5, b6, b7} + + return std::make_pair( + Vec256{aa0123, aa0123_2}, Vec256{bb0123, bb0123_2}); +} + +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_complex_double_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_complex_double_vsx.h new file mode 100644 index 0000000000000..f62ac36850bed --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_complex_double_vsx.h @@ -0,0 +1,597 @@ +#pragma once +#include +#include +#include +#include + +namespace at { +namespace vec256 { +// See Note [Acceptable use of anonymous namespace in header] +namespace { +using ComplexDbl = c10::complex; + +template <> +class Vec256 { + union { + struct { + vfloat64 _vec0; + vfloat64 _vec1; + }; + struct { + vbool64 _vecb0; + vbool64 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = ComplexDbl; + using vec_internal_type = vfloat64; + using vec_internal_mask_type = vbool64; + static constexpr int size() { + return 2; + } + Vec256() {} + C10_ALWAYS_INLINE Vec256(vfloat64 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vfloat64 v1, vfloat64 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool64 v1, vbool64 v2) : _vecb0{v1}, _vecb1{v2} {} + + Vec256(ComplexDbl val) { + double real_value = val.real(); + double imag_value = val.imag(); + _vec0 = vfloat64{real_value, imag_value}; + _vec1 = vfloat64{real_value, imag_value}; + } + Vec256(ComplexDbl val1, ComplexDbl val2) { + _vec0 = vfloat64{val1.real(), val1.imag()}; + _vec1 = vfloat64{val2.real(), val2.imag()}; + } + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return a; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return b; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {a._vec0, b._vec1}; + } + + template + static Vec256 C10_ALWAYS_INLINE + el_blend(const Vec256& a, const Vec256& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + const vbool64 mask_2nd = VsxDblMask2(mask); + return { + (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vec256 blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + // convert std::complex index mask to V index mask: xy -> xxyy + auto mask_complex = + Vec256(vec_splat(mask._vec0, 0), vec_splat(mask._vec1, 0)); + return { + vec_sel(a._vec0, b._vec0, mask_complex._vecb0), + vec_sel(a._vec1, b._vec1, mask_complex._vecb1)}; + } + + static Vec256 C10_ALWAYS_INLINE elwise_blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + template + static Vec256 arange( + ComplexDbl base = 0., + step_t step = static_cast(1)) { + return Vec256(base, base + step); + } + static Vec256 set( + const Vec256& a, + const Vec256& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + } + return b; + } + + static Vec256 C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return { + vec_vsx_ld(offset0, reinterpret_cast(tmp_values)), + vec_vsx_ld(offset16, reinterpret_cast(tmp_values))}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, reinterpret_cast(tmp_values)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(tmp_values)); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + const ComplexDbl& operator[](int idx) const = delete; + ComplexDbl& operator[](int idx) = delete; + + Vec256 map(ComplexDbl (*f)(ComplexDbl)) const { + __at_align32__ ComplexDbl tmp[size()]; + store(tmp); + for (int i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + Vec256 map(ComplexDbl (*f)(const ComplexDbl&)) const { + __at_align32__ ComplexDbl tmp[size()]; + store(tmp); + for (int i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + Vec256 el_swapped() const { + vfloat64 v0 = vec_xxpermdi(_vec0, _vec0, 2); + vfloat64 v1 = vec_xxpermdi(_vec1, _vec1, 2); + return {v0, v1}; + } + + Vec256 el_madd( + const Vec256& multiplier, + const Vec256& val) const { + return { + vec_madd(_vec0, multiplier._vec0, val._vec0), + vec_madd(_vec1, multiplier._vec1, val._vec1)}; + } + + Vec256 el_mergeo() const { + vfloat64 v0 = vec_splat(_vec0, 1); + vfloat64 v1 = vec_splat(_vec1, 1); + return {v0, v1}; + } + + Vec256 el_mergee() const { + vfloat64 v0 = vec_splat(_vec0, 0); + vfloat64 v1 = vec_splat(_vec1, 0); + return {v0, v1}; + } + + static Vec256 el_mergee( + Vec256& first, + Vec256& second) { + // as mergee phased in , we can use vec_perm with mask + return { + vec_mergeh(first._vec0, second._vec0), + vec_mergeh(first._vec1, second._vec1)}; + } + + Vec256 abs_2_() const { + auto a = (*this).elwise_mult(*this); + auto permuted = a.el_swapped(); + a = a + permuted; + return a; + } + + Vec256 abs_() const { + auto ret = abs_2_(); + return ret.elwise_sqrt(); + } + + Vec256 abs() const { + return abs_() & vd_real_mask; + } + + Vec256 angle_() const { + // angle = atan2(b/a) + // auto b_a = _mm256_permute_pd(values, 0x05); // b a + // return Sleef_atan2d4_u10(values, b_a); // 90-angle angle + auto ret = el_swapped(); + for (int i = 0; i < 2; i++) { + ret._vec0[i] = std::atan2(_vec0[i], ret._vec0[i]); + ret._vec1[i] = std::atan2(_vec1[i], ret._vec0[i]); + } + return ret; + } + + Vec256 angle() const { + auto a = angle_().el_swapped(); + return a & vd_real_mask; + } + + Vec256 real_() const { + return *this & vd_real_mask; + } + Vec256 real() const { + return *this & vd_real_mask; + } + Vec256 imag_() const { + return *this & vd_imag_mask; + } + Vec256 imag() const { + return imag_().el_swapped(); + } + + Vec256 conj_() const { + return *this ^ vd_isign_mask; + } + Vec256 conj() const { + return *this ^ vd_isign_mask; + } + + Vec256 log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. + return map(std::log); + } + + Vec256 log2() const { + // log2eB_inv + auto ret = log(); + return ret.elwise_mult(vd_log2e_inv); + } + Vec256 log10() const { + auto ret = log(); + return ret.elwise_mult(vd_log10e_inv); + } + + Vec256 asin() const { + // asin(x) + // = -i*ln(iz + sqrt(1 -z^2)) + // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + auto conj = conj_(); + auto b_a = conj.el_swapped(); + auto ab = conj.elwise_mult(b_a); + auto im = ab + ab; + auto val_2 = (*this).elwise_mult(*this); + auto val_2_swapped = val_2.el_swapped(); + auto re = horizontal_sub(val_2, val_2_swapped); + re = Vec256(vd_one) - re; + auto root = el_blend<0x0A>(re, im).sqrt(); + auto ln = (b_a + root).log(); + return ln.el_swapped().conj(); + } + + Vec256 acos() const { + // acos(x) = pi/2 - asin(x) + return Vec256(vd_pi_2) - asin(); + } + + Vec256 atan() const { + // atan(x) = i/2 * ln((i + z)/(i - z)) + auto ione = Vec256(vd_imag_one); + auto sum = ione + *this; + auto sub = ione - *this; + auto ln = (sum / sub).log(); // ln((i + z)/(i - z)) + return ln * vd_imag_half; // i/2*ln() + } + + Vec256 sin() const { + return map(std::sin); + } + Vec256 sinh() const { + return map(std::sinh); + } + Vec256 cos() const { + return map(std::cos); + } + Vec256 cosh() const { + return map(std::cosh); + } + + Vec256 tan() const { + return map(std::tan); + } + Vec256 tanh() const { + return map(std::tanh); + } + Vec256 ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vec256 floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vec256 neg() const { + auto z = Vec256(vd_zero); + return z - *this; + } + Vec256 round() const { + return {vec_rint(_vec0), vec_rint(_vec1)}; + } + + Vec256 trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vec256 elwise_sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + + void dump() const { + std::cout << _vec0[0] << "," << _vec0[1] << ","; + std::cout << _vec1[0] << "," << _vec1[1] << std::endl; + } + + Vec256 sqrt() const { + return map(std::sqrt); + } + + Vec256 reciprocal() const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() = c/abs_2() + // im = (bc - ad)/abs_2() = d/abs_2() + auto c_d = *this ^ vd_isign_mask; // c -d + auto abs = abs_2_(); + return c_d.elwise_div(abs); + } + + Vec256 rsqrt() const { + return sqrt().reciprocal(); + } + + static Vec256 horizontal_add( + Vec256& first, + Vec256& second) { + auto first_perm = first.el_swapped(); // 2perm + auto second_perm = second.el_swapped(); // 2perm + // summ + auto first_ret = first + first_perm; // 2add + auto second_ret = second + second_perm; // 2 add + // now lets choose evens + return el_mergee(first_ret, second_ret); // 2 mergee's + } + + static Vec256 horizontal_sub( + Vec256& first, + Vec256& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.el_swapped(); // 2perm + auto second_perm = second.el_swapped(); // 2perm + // summ + auto first_ret = first - first_perm; // 2sub + auto second_ret = second - second_perm; // 2 sub + // now lets choose evens + return el_mergee(first_ret, second_ret); // 2 mergee's + } + + Vec256 inline operator*(const Vec256& b) const { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i +#if 1 + // this is more vsx friendly than simulating horizontal from x86 + auto vi = b.el_mergeo(); + auto vr = b.el_mergee(); + vi = vi ^ vd_rsign_mask; + auto ret = elwise_mult(vr); + auto vx_swapped = el_swapped(); + ret = vx_swapped.el_madd(vi, ret); +#else + auto ac_bd = elwise_mult(b); + auto d_c = b.el_swapped(); + d_c = d_c ^ vd_isign_mask; + auto ad_bc = elwise_mult(d_c); + auto ret = horizontal_sub(ac_bd, ad_bc); +#endif + return ret; + } + + Vec256 inline operator/(const Vec256& b) const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() + // im = (bc - ad)/abs_2() +#if 1 + auto vi = b.el_mergeo(); + auto vr = b.el_mergee(); + auto abs_b = b.abs_2_(); + vi = vi ^ vd_isign_mask; + auto ret = elwise_mult(vr); + auto vx_swapped = el_swapped(); + ret = vx_swapped.el_madd(vi, ret); + ret = ret.elwise_div(abs_b); +#else + // Vec256 x86 simulation + auto ac_bd = elwise_mult(b); + auto d_c = b.el_swapped(); + d_c = d_c ^ vd_rsign_mask; + auto ad_bc = elwise_mult(d_c); + auto abs_b = b.abs_2_(); + auto re_im = horizontal_add(ac_bd, ad_bc); + auto ret = re_im.elwise_div(abs_b); +#endif + return ret; + } + + Vec256 exp() const { + return map(std::exp); + } + + Vec256 pow(const Vec256& exp) const { + __at_align32__ ComplexDbl x_tmp[size()]; + __at_align32__ ComplexDbl y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (int i = 0; i < size(); i++) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + + Vec256 sgn() const { + return map(at::native::sgn_impl); + } + + Vec256 hypot(const Vec256& b) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 nextafter(const Vec256& b) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 igamma(const Vec256& x) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 igammac(const Vec256& x) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 log1p() const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 atan2(const Vec256& b) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vec256 erf() const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vec256 erfc() const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 expm1() const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 operator<(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vec256 operator<=(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vec256 operator>(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vec256 operator>=(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 eq(const Vec256& other) const { + auto ret = (*this == other); + return ret & vd_one; + } + Vec256 ne(const Vec256& other) const { + auto ret = (*this != other); + return ret & vd_one; + } + + Vec256 lt(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vec256 le(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vec256 gt(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vec256 ge(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + DEFINE_MEMBER_OP(operator==, ComplexDbl, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, ComplexDbl, vec_cmpne) + + DEFINE_MEMBER_OP(operator+, ComplexDbl, vec_add) + DEFINE_MEMBER_OP(operator-, ComplexDbl, vec_sub) + DEFINE_MEMBER_OP(operator&, ComplexDbl, vec_and) + DEFINE_MEMBER_OP(operator|, ComplexDbl, vec_or) + DEFINE_MEMBER_OP(operator^, ComplexDbl, vec_xor) + // elelemtwise helpers + DEFINE_MEMBER_OP(elwise_mult, ComplexDbl, vec_mul) + DEFINE_MEMBER_OP(elwise_div, ComplexDbl, vec_div) + DEFINE_MEMBER_OP(elwise_gt, ComplexDbl, vec_cmpgt) + DEFINE_MEMBER_OP(elwise_ge, ComplexDbl, vec_cmpge) + DEFINE_MEMBER_OP(elwise_lt, ComplexDbl, vec_cmplt) + DEFINE_MEMBER_OP(elwise_le, ComplexDbl, vec_cmple) +}; + +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ); + // auto max = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_lt(abs_b); + auto max = Vec256::elwise_blendv(a, b, mask); + + return max; + // Exploit the fact that all-ones is a NaN. + // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(max, isnan); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ); + // auto min = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_gt(abs_b); + auto min = Vec256::elwise_blendv(a, b, mask); + return min; + // Exploit the fact that all-ones is a NaN. + // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(min, isnan); +} + + +} // namespace +} // namespace vec256 +} // namespace at + diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_complex_float_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_complex_float_vsx.h new file mode 100644 index 0000000000000..cb9b4c90fbe07 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_complex_float_vsx.h @@ -0,0 +1,670 @@ + +#pragma once +#include +#include +#include +#include + +namespace at { +namespace vec256 { +// See Note [Acceptable use of anonymous namespace in header] +namespace { +using ComplexFlt = c10::complex; + +template <> +class Vec256 { + private: + union { + struct { + vfloat32 _vec0; + vfloat32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = ComplexFlt; + using vec_internal_type = vfloat32; + using vec_internal_mask_type = vbool32; + + static constexpr int size() { + return 4; + } + Vec256() {} + + C10_ALWAYS_INLINE Vec256(vfloat32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vfloat32 v1, vfloat32 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {} + + Vec256(ComplexFlt val) { + float real_value = val.real(); + float imag_value = val.imag(); + _vec0 = vfloat32{real_value, imag_value, real_value, imag_value}; + _vec1 = vfloat32{real_value, imag_value, real_value, imag_value}; + } + + Vec256(ComplexFlt val1, ComplexFlt val2, ComplexFlt val3, ComplexFlt val4) { + _vec0 = vfloat32{val1.real(), val1.imag(), val2.real(), val2.imag()}; + _vec1 = vfloat32{val3.real(), val3.imag(), val4.real(), val4.imag()}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return a; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return b; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {a._vec0, b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_1st = VsxComplexMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_1st = VsxComplexMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_2nd = VsxComplexMask2(mask); + // generated masks + return {a._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_2nd = VsxComplexMask2(mask); + // generated masks + return {b._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_1st = VsxComplexMask1(mask); + const vbool32 mask_2nd = VsxComplexMask2(mask); + return { + (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static Vec256 C10_ALWAYS_INLINE + el_blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_1st = VsxMask1(mask); + const vbool32 mask_2nd = VsxMask2(mask); + return { + (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vec256 blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + // convert std::complex index mask to V index mask: xy -> xxyy + auto mask_complex = Vec256( + vec_mergeh(mask._vec0, mask._vec0), vec_mergeh(mask._vec1, mask._vec1)); + // mask_complex.dump(); + return { + vec_sel(a._vec0, b._vec0, mask_complex._vec0), + vec_sel(a._vec1, b._vec1, mask_complex._vec1), + }; + } + + static Vec256 elwise_blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + return { + vec_sel(a._vec0, b._vec0, mask._vec0), + vec_sel(a._vec1, b._vec1, mask._vec1), + }; + } + + template + static Vec256 arange( + ComplexFlt base = 0., + step_t step = static_cast(1)) { + return Vec256( + base, + base + step, + base + ComplexFlt(2) * step, + base + ComplexFlt(3) * step); + } + static Vec256 set( + const Vec256& a, + const Vec256& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + + static Vec256 C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return { + vec_vsx_ld(offset0, reinterpret_cast(tmp_values)), + vec_vsx_ld(offset16, reinterpret_cast(tmp_values))}; + } + + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, reinterpret_cast(tmp_values)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(tmp_values)); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + const ComplexFlt& operator[](int idx) const = delete; + ComplexFlt& operator[](int idx) = delete; + + Vec256 map(ComplexFlt (*f)(ComplexFlt)) const { + __at_align32__ ComplexFlt tmp[size()]; + store(tmp); + for (int i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + Vec256 map(ComplexFlt (*f)(const ComplexFlt&)) const { + __at_align32__ ComplexFlt tmp[size()]; + store(tmp); + for (int i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + static Vec256 horizontal_add_permD8( + Vec256& first, + Vec256& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.el_swapped(); // 2perm + auto second_perm = second.el_swapped(); // 2perm + // summ + auto first_ret = first + first_perm; // 2add + auto second_ret = second + second_perm; // 2 add + // now lets choose evens + return el_mergee(first_ret, second_ret); // 2 mergee's + } + + static Vec256 horizontal_sub_permD8( + Vec256& first, + Vec256& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.el_swapped(); // 2perm + auto second_perm = second.el_swapped(); // 2perm + // summ + auto first_ret = first - first_perm; // 2sub + auto second_ret = second - second_perm; // 2 sub + // now lets choose evens + return el_mergee(first_ret, second_ret); // 2 mergee's + } + + Vec256 abs_2_() const { + auto a = (*this).elwise_mult(*this); + auto permuted = a.el_swapped(); + a = a + permuted; + return a.el_mergee(); + } + + Vec256 abs_() const { + auto ret = abs_2_(); + return ret.elwise_sqrt(); + } + + Vec256 abs() const { + return abs_() & real_mask; + } + + Vec256 real_() const { + return *this & real_mask; + } + Vec256 real() const { + return *this & real_mask; + } + Vec256 imag_() const { + return *this & imag_mask; + } + Vec256 imag() const { + // we can use swap_mask or sldwi + auto ret = imag_(); + return { + vec_sldw(ret._vec0, ret._vec0, 3), vec_sldw(ret._vec1, ret._vec1, 3)}; + } + + Vec256 conj_() const { + return *this ^ isign_mask; + } + Vec256 conj() const { + return *this ^ isign_mask; + } + + Vec256 log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. + return map(std::log); + } + + Vec256 log2() const { + // log2eB_inv + auto ret = log(); + return ret.elwise_mult(log2e_inv); + } + Vec256 log10() const { + auto ret = log(); + return ret.elwise_mult(log10e_inv); + } + + Vec256 el_swapped() const { + vfloat32 v0 = vec_perm(_vec0, _vec0, swap_mask); + vfloat32 v1 = vec_perm(_vec1, _vec1, swap_mask); + return {v0, v1}; + } + + Vec256 el_mergee() const { + // as mergee phased in , we can use vec_perm with mask + return {vec_mergee(_vec0, _vec0), vec_mergee(_vec1, _vec1)}; + } + + Vec256 el_mergeo() const { + // as mergeo phased in , we can use vec_perm with mask + return {vec_mergeo(_vec0, _vec0), vec_mergeo(_vec1, _vec1)}; + } + + Vec256 el_madd( + const Vec256& multiplier, + const Vec256& val) const { + return { + vec_madd(_vec0, multiplier._vec0, val._vec0), + vec_madd(_vec1, multiplier._vec1, val._vec1)}; + } + + static Vec256 el_mergee( + Vec256& first, + Vec256& second) { + // as mergee phased in , we can use vec_perm with mask + return { + vec_mergee(first._vec0, second._vec0), + vec_mergee(first._vec1, second._vec1)}; + } + + Vec256 angle_() const { + // angle = atan2(b/a) + // auto b_a = _mm256_permute_ps(values, 0xB1); // b a + // return Sleef_atan2f8_u10(values, b_a); // 90-angle angle + auto ret = el_swapped(); + for (int i = 0; i < 4; i++) { + ret._vec0[i] = std::atan2(_vec0[i], ret._vec0[i]); + ret._vec1[i] = std::atan2(_vec1[i], ret._vec0[i]); + } + return ret; + } + + Vec256 angle() const { + auto a = angle_().el_swapped(); + return a & real_mask; + } + + Vec256 sin() const { + return map(std::sin); + } + Vec256 sinh() const { + return map(std::sinh); + } + Vec256 cos() const { + return map(std::cos); + } + Vec256 cosh() const { + return map(std::cosh); + } + Vec256 ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vec256 floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vec256 neg() const { + auto z = Vec256(zero); + return z - *this; + } + Vec256 round() const { + return {vec_round(_vec0), vec_round(_vec1)}; + } + Vec256 tan() const { + return map(std::tan); + } + Vec256 tanh() const { + return map(std::tanh); + } + Vec256 trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vec256 elwise_sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + + void dump() const { + std::cout << _vec0[0] << "," << _vec0[1] << "," << _vec0[2] << "," + << _vec0[3] << ","; + std::cout << _vec1[0] << "," << _vec1[1] << "," << _vec1[2] << "," + << _vec1[3] << std::endl; + } + + Vec256 sqrt() const { + return map(std::sqrt); + } + + Vec256 reciprocal() const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() = c/abs_2() + // im = (bc - ad)/abs_2() = d/abs_2() + auto c_d = *this ^ isign_mask; // c -d + auto abs = abs_2_(); + return c_d.elwise_div(abs); + } + + Vec256 rsqrt() const { + return sqrt().reciprocal(); + } + + Vec256 pow(const Vec256& exp) const { + __at_align32__ ComplexFlt x_tmp[size()]; + __at_align32__ ComplexFlt y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (int i = 0; i < size(); i++) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + + Vec256 atan() const { + // atan(x) = i/2 * ln((i + z)/(i - z)) + auto ione = Vec256(imag_one); + auto sum = ione + *this; + auto sub = ione - *this; + auto ln = (sum / sub).log(); // ln((i + z)/(i - z)) + return ln * imag_half; // i/2*ln() + } + + Vec256 acos() const { + // acos(x) = pi/2 - asin(x) + return Vec256(pi_2) - asin(); + } + + Vec256 inline operator*(const Vec256& b) const { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + +#if 1 + // this is more vsx friendly than simulating horizontal from x86 + + auto vi = b.el_mergeo(); + auto vr = b.el_mergee(); + vi = vi ^ rsign_mask; + auto ret = elwise_mult(vr); + auto vx_swapped = el_swapped(); + ret = vx_swapped.el_madd(vi, ret); + return ret; + +#else + + auto ac_bd = elwise_mult(b); + auto d_c = b.el_swapped(); + d_c = d_c ^ isign_mask; + auto ad_bc = elwise_mult(d_c); + auto ret = horizontal_sub_permD8(ac_bd, ad_bc); + return ret; +#endif + } + + Vec256 inline operator/(const Vec256& b) const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() + // im = (bc - ad)/abs_2() +#if 1 + auto vi = b.el_mergeo(); + auto vr = b.el_mergee(); + auto abs_b = b.abs_2_(); + vi = vi ^ isign_mask; + auto ret = elwise_mult(vr); + auto vx_swapped = el_swapped(); + ret = vx_swapped.el_madd(vi, ret); + ret = ret.elwise_div(abs_b); +#else + // Vec256 x86 simulation + auto ac_bd = elwise_mult(b); + auto d_c = b.el_swapped(); + d_c = d_c ^ rsign_mask; + auto ad_bc = elwise_mult(d_c); + auto abs_b = b.abs_2_(); + auto re_im = horizontal_add_permD8(ac_bd, ad_bc); + auto ret = re_im.elwise_div(abs_b); +#endif + return ret; + } + + Vec256 asin() const { + // asin(x) + // = -i*ln(iz + sqrt(1 -z^2)) + // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + +#if 1 + auto conj = conj_(); + auto b_a = conj.el_swapped(); + auto ab = conj.elwise_mult(b_a); + auto im = ab + ab; + auto val_2 = (*this).elwise_mult(*this); + auto val_2_swapped = val_2.el_swapped(); + auto re = horizontal_sub_permD8(val_2, val_2_swapped); + re = Vec256(one) - re; + auto root = el_blend<0xAA>(re, im).sqrt(); + auto ln = (b_a + root).log(); + return ln.el_swapped().conj(); +#else + return map(std::asin); +#endif + } + + Vec256 exp() const { + return map(std::exp); + } + + Vec256 eq(const Vec256& other) const { + auto ret = (*this == other); + return ret & one; + } + Vec256 ne(const Vec256& other) const { + auto ret = (*this != other); + return ret & one; + } + + Vec256 sgn() const { + return map(at::native::sgn_impl); + } + + Vec256 hypot(const Vec256& b) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 nextafter(const Vec256& b) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 igamma(const Vec256& x) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 igammac(const Vec256& x) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 atan2(const Vec256& b) const { + TORCH_CHECK(false,"not supported for complex numbers"); + } + Vec256 erf() const { + TORCH_CHECK(false,"not supported for complex numbers"); + } + Vec256 erfc() const { + TORCH_CHECK(false,"not supported for complex numbers"); + } + + Vec256 log1p() const { + TORCH_CHECK(false,"not supported for complex numbers"); + } + + Vec256 expm1() const { + TORCH_CHECK(false,"not supported for complex numbers"); + } + + Vec256 operator<(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 operator<=(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 operator>(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 operator>=(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 lt(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 le(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 gt(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vec256 ge(const Vec256& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + DEFINE_MEMBER_OP(operator==, ComplexFlt, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, ComplexFlt, vec_cmpne) + + DEFINE_MEMBER_OP(operator+, ComplexFlt, vec_add) + DEFINE_MEMBER_OP(operator-, ComplexFlt, vec_sub) + DEFINE_MEMBER_OP(operator&, ComplexFlt, vec_and) + DEFINE_MEMBER_OP(operator|, ComplexFlt, vec_or) + DEFINE_MEMBER_OP(operator^, ComplexFlt, vec_xor) + // elelemtwise helpers + DEFINE_MEMBER_OP(elwise_mult, ComplexFlt, vec_mul) + DEFINE_MEMBER_OP(elwise_div, ComplexFlt, vec_div) + DEFINE_MEMBER_OP(elwise_gt, ComplexFlt, vec_cmpgt) + DEFINE_MEMBER_OP(elwise_ge, ComplexFlt, vec_cmpge) + DEFINE_MEMBER_OP(elwise_lt, ComplexFlt, vec_cmplt) + DEFINE_MEMBER_OP(elwise_le, ComplexFlt, vec_cmple) +}; + +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ); + // auto max = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_lt(abs_b); + auto max = Vec256::elwise_blendv(a, b, mask); + + return max; + // Exploit the fact that all-ones is a NaN. + // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(max, isnan); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ); + // auto min = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_gt(abs_b); + auto min = Vec256::elwise_blendv(a, b, mask); + return min; + // Exploit the fact that all-ones is a NaN. + // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(min, isnan); +} + +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_double_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_double_vsx.h new file mode 100644 index 0000000000000..f34bdc7bbcb30 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_double_vsx.h @@ -0,0 +1,392 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { +namespace vec256 { + +namespace { + + +template <> +class Vec256 { + private: + union { + struct { + vfloat64 _vec0; + vfloat64 _vec1; + }; + struct { + vbool64 _vecb0; + vbool64 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = double; + using vec_internal_type = vfloat64; + using vec_internal_mask_type = vbool64; + static constexpr int size() { + return 4; + } + Vec256() {} + C10_ALWAYS_INLINE Vec256(vfloat64 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vfloat64 v1, vfloat64 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool64 v1, vbool64 v2) : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vec256(double scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vec256( + double scalar1, + double scalar2, + double scalar3, + double scalar4) + : _vec0{vfloat64{scalar1, scalar2}}, _vec1{vfloat64{scalar3, scalar4}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + int zero_mask() const { + auto cmp = (*this == vd_zero); + return (cmp._vecb0[0] & 1) | (cmp._vecb0[1] & 2) | (cmp._vecb1[0] & 4) | + (cmp._vecb1[1] & 8); + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return a; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return b; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return { b._vec0, a._vec1 }; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return { a._vec0, b._vec1 }; + } + + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + return { (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1 }; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + return { (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1 }; + } + + + template + static std::enable_if_t> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + const vbool64 mask_2nd = VsxDblMask2(mask); + // generated masks + return { a._vec0, + (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd) }; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + const vbool64 mask_2nd = VsxDblMask2(mask); + // generated masks + return { b._vec0, + (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd) }; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + const vbool64 mask_2nd = VsxDblMask2(mask); + return { + (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd) }; + } + + + static Vec256 C10_ALWAYS_INLINE blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + // the mask used here returned by comparision of vec256 + + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + static Vec256 arange(double base = 0., double step = 1.) { + return Vec256(base, base + step, base + 2 * step, base + 3 * step); + } + + static Vec256 C10_ALWAYS_INLINE + set(const Vec256& a, const Vec256& b, size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + + return b; + } + static Vec256 C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + void dump() const { + std::cout << _vec0[0] << "," << _vec0[1] << "," << _vec1[0] << "," << _vec1[1] << std::endl; + } + Vec256 map(double (*f)(double)) const { + Vec256 ret; + for (int i = 0; i < size()/2; i++) { + ret._vec0[i] = f(_vec0[i]); + } + for (int i = 0; i < size()/2; i++) { + ret._vec1[i] = f(_vec1[i]); + } + return ret; + } + + Vec256 mapbi(double (*f)(double, double), const Vec256& other) + const { + Vec256 ret; + for (int i = 0; i < size()/2; i++) { + ret._vec0[i] = f(_vec0[i], other._vec0[i]); + } + for (int i = 0; i < size()/2; i++) { + ret._vec1[i] = f(_vec1[i], other._vec1[i]); + } + return ret; + } + Vec256 C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vec256 C10_ALWAYS_INLINE acos() const { + return {Sleef_acosd2_u10vsx(_vec0), Sleef_acosd2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE asin() const { + return {Sleef_asind2_u10vsx(_vec0), Sleef_asind2_u10vsx(_vec1)}; + } + Vec256 atan() const { + return {Sleef_atand2_u10vsx(_vec0), Sleef_atand2_u10vsx(_vec1)}; + } + Vec256 atan2(const Vec256& b) const { + return {Sleef_atan2d2_u10vsx(_vec0, b._vec0), Sleef_atan2d2_u10vsx(_vec1, b._vec1)}; + } + Vec256 erf() const { + return {Sleef_erfd2_u10vsx(_vec0), Sleef_erfd2_u10vsx(_vec1)}; + } + Vec256 erfc() const { + return {Sleef_erfcd2_u15vsx(_vec0), Sleef_erfcd2_u15vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE exp() const { + return {Sleef_expd2_u10vsx(_vec0), Sleef_expd2_u10vsx(_vec1)}; + } + Vec256 expm1() const { + return {Sleef_expm1d2_u10vsx(_vec0), Sleef_expm1d2_u10vsx(_vec1)}; + } + + Vec256 lgamma() const __ubsan_ignore_undefined__ { + return {Sleef_lgammad2_u10vsx(_vec0), Sleef_lgammad2_u10vsx(_vec1)}; + } + + Vec256 erfinv() const { + return map(calc_erfinv); + } + + Vec256 angle() const { + return Vec256{0}; + } + Vec256 real() const { + return *this; + } + Vec256 imag() const { + return Vec256{0}; + } + Vec256 conj() const { + return *this; + } + + Vec256 C10_ALWAYS_INLINE log() const { + return {Sleef_logd2_u10vsx(_vec0), Sleef_logd2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE log10() const { + return {Sleef_log10d2_u10vsx(_vec0), Sleef_log10d2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE log1p() const { + return {Sleef_log1pd2_u10vsx(_vec0), Sleef_log1pd2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE log2() const { + return {Sleef_log2d2_u10vsx(_vec0), Sleef_log2d2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE cos() const { + return {Sleef_cosd2_u10vsx(_vec0), Sleef_cosd2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE cosh() const { + return {Sleef_coshd2_u10vsx(_vec0), Sleef_coshd2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE round() const { + return {vec_rint(_vec0), vec_rint(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE sin() const { + return {Sleef_sind2_u10vsx(_vec0), Sleef_sind2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE sinh() const { + return {Sleef_sinhd2_u10vsx(_vec0), Sleef_sinhd2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE tan() const { + return {Sleef_tand2_u10vsx(_vec0), Sleef_tand2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE tanh() const { + return {Sleef_tanhd2_u10vsx(_vec0), Sleef_tanhd2_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vec256 C10_ALWAYS_INLINE frac() const { + return *this - trunc(); + } + + Vec256 C10_ALWAYS_INLINE sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE reciprocal() const { + return { + vec_div(vd_one, _vec0), // vec_re(_vec0) is estimated one. + vec_div(vd_one, _vec1)}; + } + Vec256 C10_ALWAYS_INLINE rsqrt() const { + return sqrt().reciprocal(); + } + + Vec256 C10_ALWAYS_INLINE pow(const Vec256& b) const { + return {Sleef_powd2_u10vsx(_vec0, b._vec0), Sleef_powd2_u10vsx(_vec1, b._vec1)}; + } + Vec256 C10_ALWAYS_INLINE fmod(const Vec256& b) const { + return {Sleef_fmodd2_vsx(_vec0, b._vec0),Sleef_fmodd2_vsx(_vec1, b._vec1)}; + } + + Vec256 hypot(const Vec256& b) const { + return {Sleef_hypotd2_u05vsx(_vec0, b._vec0), Sleef_hypotd2_u05vsx(_vec1, b._vec1)}; + } + + Vec256 nextafter(const Vec256& b) const { + return {Sleef_nextafterd2_vsx(_vec0, b._vec0), Sleef_nextafterd2_vsx(_vec1, b._vec1)}; + } + + Vec256 igamma(const Vec256& x) const { + return mapbi(calc_igamma, x); + } + + Vec256 igammac(const Vec256& x) const { + return mapbi(calc_igammac, x); + } + + + Vec256 i0() const { + return map(calc_i0); + } + + DEFINE_MEMBER_OP(operator==, double, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, double, vec_cmpne) + DEFINE_MEMBER_OP(operator<, double, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, double, vec_cmple) + DEFINE_MEMBER_OP(operator>, double, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, double, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, double, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, double, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, double, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, double, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, double, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, double, vec_cmpge) + DEFINE_MEMBER_OP(operator+, double, vec_add) + DEFINE_MEMBER_OP(operator-, double, vec_sub) + DEFINE_MEMBER_OP(operator*, double, vec_mul) + DEFINE_MEMBER_OP(operator/, double, vec_div) + DEFINE_MEMBER_OP(maximum, double, vec_max) + DEFINE_MEMBER_OP(minimum, double, vec_min) + DEFINE_MEMBER_OP(operator&, double, vec_and) + DEFINE_MEMBER_OP(operator|, double, vec_or) + DEFINE_MEMBER_OP(operator^, double, vec_xor) + DEFINE_MEMBER_TERNARY_OP(madd, double, vec_madd) +}; +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + return a.maximum(b); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + return a.minimum(b); +} +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_float_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_float_vsx.h new file mode 100644 index 0000000000000..2a1a87aa72c87 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_float_vsx.h @@ -0,0 +1,676 @@ +#pragma once + +#include +#include +#include +#include +namespace at { +namespace vec256 { +// See Note [Acceptable use of anonymous namespace in header] + +namespace { + +template <> +class Vec256 { + private: + union { + struct { + vfloat32 _vec0; + vfloat32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = float; + using vec_internal_type = vfloat32; + using vec_internal_mask_type = vbool32; + + static constexpr int size() { + return 8; + } + Vec256() {} + + C10_ALWAYS_INLINE Vec256(vfloat32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vfloat32 v1, vfloat32 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vec256(float scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vec256( + float scalar1, + float scalar2, + float scalar3, + float scalar4, + float scalar5, + float scalar6, + float scalar7, + float scalar8) + : _vec0{vfloat32{scalar1, scalar2, scalar3, scalar4}}, + _vec1{vfloat32{scalar5, scalar6, scalar7, scalar8}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return a; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return b; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {a._vec0, b._vec1}; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_1st = VsxMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1}; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_1st = VsxMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1}; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_2nd = VsxMask2(mask); + // generated masks + return {a._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_2nd = VsxMask2(mask); + // generated masks + return {b._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + const vbool32 mask_1st = VsxMask1(mask); + const vbool32 mask_2nd = VsxMask2(mask); + return { + (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vec256 C10_ALWAYS_INLINE blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + // the mask used here returned by comparision of vec256 + // assuming this we can use the same mask directly with vec_sel + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + + static Vec256 arange(float base = 0.f, float step = 1.f) { + return Vec256( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vec256 set( + const Vec256& a, + const Vec256& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + + return b; + } + static Vec256 C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + const float& operator[](int idx) const = delete; + float& operator[](int idx) = delete; + + Vec256 map(float (*f)(float)) const { + Vec256 ret; + for (int i = 0; i < size() / 2; i++) { + ret._vec0[i] = f(_vec0[i]); + } + for (int i = 0; i < size() / 2; i++) { + ret._vec1[i] = f(_vec1[i]); + } + return ret; + } + + Vec256 mapbi(float (*f)(float, float), const Vec256& other) + const { + Vec256 ret; + for (int i = 0; i < size() / 2; i++) { + ret._vec0[i] = f(_vec0[i], other._vec0[i]); + } + for (int i = 0; i < size() / 2; i++) { + ret._vec1[i] = f(_vec1[i], other._vec1[i]); + } + return ret; + } + + Vec256 _nor() const { + return {vec_nor(_vec0, _vec0), vec_nor(_vec1, _vec1)}; + } + + Vec256 _isnan() const { + auto x = *this; + auto ret = (x == x); + return ret._nor(); + } + + Vec256 _isinf() const { + auto x = *this; + return (x == v_inf) | (x == v_minus_inf); + } + + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + //__m256 cmp = _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_EQ_OQ); + auto cmp = (*this == zero); + // return _mm256_movemask_ps(cmp); + // possible simulation //mask= lvsl ( 0 ) vbpermq( vec, mask <<5) + vuint64 result0 = vec_vbpermq((vuint8)cmp._vecb0, mask_zero_bits); + vuint64 result1 = vec_vbpermq((vuint8)cmp._vecb1, mask_zero_bits); + return (result0[1] >> 12 | (result1[1] >> 8)); + } + + Vec256 C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vec256 C10_ALWAYS_INLINE acos() const { + return {Sleef_acosf4_u10vsx(_vec0), Sleef_acosf4_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE asin() const { + return {Sleef_asinf4_u10vsx(_vec0), Sleef_asinf4_u10vsx(_vec1)}; + } + Vec256 atan() const { + return {Sleef_atanf4_u10vsx(_vec0), Sleef_atanf4_u10vsx(_vec1)}; + } + Vec256 atan2(const Vec256& b) const { + return {Sleef_atan2f4_u10vsx(_vec0, b._vec0), Sleef_atan2f4_u10vsx(_vec1, b._vec1)}; + } + + Vec256 lgamma() const { + return {Sleef_lgammaf4_u10vsx(_vec0), Sleef_lgammaf4_u10vsx(_vec1)}; + } + Vec256 erf() const { + return {Sleef_erff4_u10vsx(_vec0), Sleef_erff4_u10vsx(_vec1)}; + } + + Vec256 erfc() const { + return {Sleef_erfcf4_u15vsx(_vec0), Sleef_erfcf4_u15vsx(_vec1)}; + } + + Vec256 erfinv() const { + return map(calc_erfinv); + } + + Vec256 angle() const { + return Vec256{0}; + } + Vec256 real() const { + return *this; + } + Vec256 imag() const { + return Vec256{0}; + } + Vec256 conj() const { + return *this; + } + + Vec256 C10_ALWAYS_INLINE exp() const { + // implementation logic from avx_mathfun with some modifications from sleef + // Express e**x = e**g 2**n + /// = e**g e**( n loge(2) ) + /// = e**( g + n loge(2) ) + // + auto tmp_x = *this; + auto fx = (tmp_x * log2e_inv).round(); + + auto x = fx.madd(negln2f_hi, tmp_x); + x = fx.madd(negln2f_lo, x); + auto z = x * x; + auto y = x.madd(exp_p0, exp_p1); + y = y.madd(x, exp_p2); + y = y.madd(x, exp_p3); + y = y.madd(x, exp_p4); + y = y.madd(x, exp_p5); + y = y.madd(z, x) + one; + + // vm_pow2n 2^n + vint32 imm0 = vec_signed(fx._vec0); + vint32 imm1 = vec_signed(fx._vec1); + // this pow2n logic is from Sleef code + vint32 imm00 = imm0 >> 1; //>>1 + vint32 imm01 = imm1 >> 1; + vint32 imm10 = imm0 - imm00; + vint32 imm11 = imm1 - imm01; + imm00 = (imm00 + v0x7f) << vu_23; + imm01 = (imm01 + v0x7f) << vu_23; + imm10 = (imm10 + v0x7f) << vu_23; + imm11 = (imm11 + v0x7f) << vu_23; + // treat imm as float vector without conversion + + y._vec0 = (y._vec0 * (vfloat32)imm00) * (vfloat32)imm10; + y._vec1 = (y._vec1 * (vfloat32)imm01) * (vfloat32)imm11; + // boundary check + auto tmp = blendv(y, v_inf, (Vec256(exp_hi) <= tmp_x)); + y = blendv(tmp, zero, (tmp_x < Vec256(exp_lo))); + + return y; + } + Vec256 expm1() const { + return exp() - one; + } + + Vec256 C10_ALWAYS_INLINE log() const { + auto temp = *this; + auto invalid_mask = temp < zero; + // cut off denormalized stuff + auto x = temp.maximum(min_norm_pos); + vint32 imm0 = vec_sr(vint32(x._vec0), vu_23); + vint32 imm1 = vec_sr(vint32(x._vec1), vu_23); + // keep only the fractional part + x = x & inv_mant_mask; + x = x | half; + imm0 = imm0 - v0x7f; + imm1 = imm1 - v0x7f; + Vec256 ex; + ex._vec0 = vec_float(imm0); + ex._vec1 = vec_float(imm1); + ex = ex + one; + auto mask = x < cephes_SQRTHF; + auto t = x & mask; + x = x - one; + ex = ex - (mask & one); + x = x + t; + auto z = x * x; + auto y = x.madd(log_p0, log_p1); + y = y.madd(x, log_p2); + y = y.madd(x, log_p3); + y = y.madd(x, log_p4); + y = y.madd(x, log_p5); + y = y.madd(x, log_p6); + y = y.madd(x, log_p7); + y = y.madd(x, log_p8); + y = y * x * z; + y = ex.madd(log_q1, y); + y = y - z * half; + x = x + y; + x = ex.madd(log_q2, x); + // negative arg will be NAN + x = blendv(x, v_nan, invalid_mask); + // zero is -inf + x = blendv(x, min_inf, (temp == zero)); + return x; + } + Vec256 C10_ALWAYS_INLINE log10() const { + return log() * log10e_inv; + } + Vec256 C10_ALWAYS_INLINE log1p() const { + return ((*this) + one).log(); + } + Vec256 C10_ALWAYS_INLINE log2() const { + return log() * log2e_inv; + } + Vec256 C10_ALWAYS_INLINE ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE cos() const { + // take the absolute value + auto x = abs(); + // extract the sign bit (upper one) + auto sign_bit = (*this) & sign_mask; + // scale by 4/Pi + auto y = x * _4div_pi; + // store the integer part of y in mm0 + // j=(j+1) & (~1) (see the cephes sources) + vint32 imm0 = (vec_signed(y._vec0) + vi_1) & vi_inv1; + vint32 imm1 = (vec_signed(y._vec1) + vi_1) & vi_inv1; + y._vec0 = vec_float(imm0); + y._vec1 = vec_float(imm1); + + imm0 = imm0 - vi_2; + imm1 = imm1 - vi_2; + Vec256 poly_mask; + // get the swap sign flag + vint32 tmp0 = vec_and(vec_nand(imm0, imm0), vi_4); + vint32 tmp1 = vec_and(vec_nand(imm1, imm1), vi_4); + sign_bit._vecb0 = (vbool32)vec_sl(tmp0, vu_29); + sign_bit._vecb1 = (vbool32)vec_sl(tmp1, vu_29); + // get the polynom selection mask + // there is one polynom for 0 <= x <= Pi / 4 + // and another one for Pi / 4 < x <= Pi / 2 + // Both branches will be computed. + + poly_mask._vecb0 = (vbool32)vec_cmpeq((imm0 & vi_2), vi_0); + poly_mask._vecb1 = (vbool32)vec_cmpeq((imm1 & vi_2), vi_0); + + // The magic pass: "Extended precision modular arithmetic" + // x = ((x - y * DP1) - y * DP2) - y * DP3; + x = y.madd(minus_cephes_dp1, x); + x = y.madd(minus_cephes_dp2, x); + x = y.madd(minus_cephes_dp3, x); + + // Evaluate the first polynom (0 <= x <= Pi/4) + auto z = x * x; + y = z.madd(coscof_p0, coscof_p1); + y = y.madd(z, coscof_p2); + y = y * z * z; + y = y - z * half + one; + + // Evaluate the second polynom (Pi/4 <= x <= 0) + auto y_2 = z.madd(sincof_p0, sincof_p1); + y_2 = y_2.madd(z, sincof_p2); + y_2 = y_2 * z; + y_2 = y_2.madd(x, x); + + // select the correct result from the two polynoms + y = blendv(y, y_2, poly_mask); + // update the sign + y = y ^ sign_bit; + + return y; + } + Vec256 C10_ALWAYS_INLINE cosh() const { + // cosh = 1/2 * (e^x + e^-x) + auto x = abs(); + auto e_x = x.exp(); + auto ret = (e_x + Vec256(one) / e_x) * half; + // inf and nan checks +#if 0 + ret = blendv(ret, v_inf, x >= vf_89); + ret = blendv(ret, v_inf, ret._isnan()); + ret = blendv(ret, v_nan, this->_isnan()); +#endif + return ret; + } + Vec256 C10_ALWAYS_INLINE floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + void dump() const { + std::cout << _vec0[0] << "," << _vec0[1] << "," << _vec0[2] << "," + << _vec0[3] << ","; + std::cout << _vec1[0] << "," << _vec1[1] << "," << _vec1[2] << "," + << _vec1[3] << std::endl; + } + + Vec256 C10_ALWAYS_INLINE round() const { + return {vec_round(_vec0), vec_round(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE sin() const { + // take the absolute value and xtract sign + auto x = abs(); + auto sign_bit = (*this) & sign_mask; + + // scale by 4/Pi + auto y = x * _4div_pi; + // store the integer part of y in mm0 + + // j=(j+1) & (~1) (see the cephes sources) + vint32 imm0 = (vec_signed(y._vec0) + vi_1) & vi_inv1; + vint32 imm1 = (vec_signed(y._vec1) + vi_1) & vi_inv1; + y._vec0 = vec_float(imm0); + y._vec1 = vec_float(imm1); + // get the swap sign flag + Vec256 swap_sign_bit, poly_mask; + swap_sign_bit._vecb0 = (vbool32)vec_sl(imm0 & vi_4, vu_29); + swap_sign_bit._vecb1 = (vbool32)vec_sl(imm1 & vi_4, vu_29); + // get the polynom selection mask + // there is one polynom for 0 <= x <= Pi/4 + // and another one for Pi/4 C10_ALWAYS_INLINE sinh() const { + auto temp_abs = abs(); + // get exponent + auto ret = temp_abs.exp(); + auto recp = Vec256(half) / ret; + auto v = ret * half - recp; + // extract the sign bit (upper one) + auto sign_bit = (*this) & sign_mask; + auto z = temp_abs * temp_abs; + auto y = z.madd(p0, p1); + y = y.madd(z, p2); + y = (y * z).madd(temp_abs, temp_abs); + // check and select + auto result = blendv(y, v, temp_abs > one); + return result | sign_bit; + } + Vec256 C10_ALWAYS_INLINE tan() const { + return {Sleef_tanf4_u10vsx(_vec0), Sleef_tanf4_u10vsx(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE tanh() const { + auto x = *this; + auto vabs = abs(); + // get exponent + auto exp2x = (vabs + vabs).exp(); + auto vv = Vec256(one) - Vec256(two) / (exp2x + one); + // extract the sign bit (upper one) + auto sign_bit = (*this) & sign_mask; + auto z = vabs * vabs; + auto y = z.madd(tanh_p0, tanh_p1); + auto tmp = y.madd(z, tanh_p2); + y = z.madd(tmp, tanh_p3); + tmp = y.madd(z, tanh_p4); + y = tmp * z; + tmp = y.madd(x, x); + // add sign + vv = vv | sign_bit; + // check and select + auto sel_mask = vabs >= tanh_0p625; + auto max_mask = vabs > tanh_half_max; + auto max_ret = sign_bit ^ one; + return blendv(blendv(tmp, vv, sel_mask), max_ret, max_mask); + } + Vec256 C10_ALWAYS_INLINE trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vec256 C10_ALWAYS_INLINE frac() const { + return *this - trunc(); + } + + Vec256 C10_ALWAYS_INLINE sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + Vec256 C10_ALWAYS_INLINE reciprocal() const { + return Vec256(one) / (*this); + } + Vec256 C10_ALWAYS_INLINE rsqrt() const { + return sqrt().reciprocal(); + } + + Vec256 C10_ALWAYS_INLINE pow(const Vec256& exp) const { + auto x = *this; + auto sign_bit = (*this) & sign_mask; + // |b| + auto exp_abs = exp.abs(); + auto exp_trunc = exp.trunc(); + Vec256 odd_mask; + odd_mask._vecb0 = (vec_signed(exp._vec0) & vi_1) != vi_0; + odd_mask._vecb1 = (vec_signed(exp._vec1) & vi_1) != vi_0; + // using ln fuction + auto temp = (abs().log() * exp).exp(); + + // is odd or even check from Sleef + auto is_int = (exp == exp_trunc) | (exp_abs >= vcheck); + auto is_odd = odd_mask & is_int & (exp_abs < vcheck); + // if even then then pow result should be absolute + auto temp_sign = temp | sign_bit; // copy_sign + auto out = blendv(temp, temp_sign, is_odd); + // x<0 and y != N, then NAN + auto out1 = blendv(out, v_nan, ((exp.floor() != exp) & (x < zero))); + // y = 0 then 1 + return blendv(out1, one, (exp_abs == zero)); + } + + Vec256 fmod(const Vec256& b) const { + return {Sleef_fmodf4_vsx(_vec0, b._vec0),Sleef_fmodf4_vsx(_vec1, b._vec1)}; + } + + Vec256 hypot(const Vec256& b) const { + return {Sleef_hypotf4_u05vsx(_vec0, b._vec0), Sleef_hypotf4_u05vsx(_vec1, b._vec1)}; + } + + Vec256 nextafter(const Vec256& b) const { + return {Sleef_nextafterf4_vsx(_vec0, b._vec0), Sleef_nextafterf4_vsx(_vec1, b._vec1)}; + } + + Vec256 igamma(const Vec256& x) const { + return mapbi(calc_igamma, x); + } + + Vec256 igammac(const Vec256& x) const { + return mapbi(calc_igammac, x); + } + + Vec256 i0() const { + return map(calc_i0); + } + + DEFINE_MEMBER_OP(operator==, float, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, float, vec_cmpne) + DEFINE_MEMBER_OP(operator<, float, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, float, vec_cmple) + DEFINE_MEMBER_OP(operator>, float, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, float, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, float, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, float, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, float, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, float, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, float, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, float, vec_cmpge) + DEFINE_MEMBER_OP(operator+, float, vec_add) + DEFINE_MEMBER_OP(operator-, float, vec_sub) + DEFINE_MEMBER_OP(operator*, float, vec_mul) + DEFINE_MEMBER_OP(operator/, float, vec_div) + DEFINE_MEMBER_OP(maximum, float, vec_max) + DEFINE_MEMBER_OP(minimum, float, vec_min) + DEFINE_MEMBER_OP(operator&, float, vec_and) + DEFINE_MEMBER_OP(operator|, float, vec_or) + DEFINE_MEMBER_OP(operator^, float, vec_xor) + DEFINE_MEMBER_TERNARY_OP(madd, float, vec_madd) +}; + +template <> +Vec256 inline maximum(const Vec256& a, const Vec256& b) { + return a.maximum(b); +} + +template <> +Vec256 inline minimum(const Vec256& a, const Vec256& b) { + return a.minimum(b); +} + +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_int16_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_int16_vsx.h new file mode 100644 index 0000000000000..33460abe2a587 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_int16_vsx.h @@ -0,0 +1,351 @@ +#pragma once + +#include +#include +#include +namespace at { +namespace vec256 { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +template <> +class Vec256 { + private: + union { + struct { + vint16 _vec0; + vint16 _vec1; + }; + struct { + vbool16 _vecb0; + vbool16 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = int16_t; + using vec_internal_type = vint16; + using vec_internal_mask_type = vbool16; + static constexpr int size() { + return 16; + } + Vec256() {} + C10_ALWAYS_INLINE Vec256(vint16 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool16 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vint16 v1, vint16 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool16 v1, vbool16 v2) : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vec256(int16_t scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + + C10_ALWAYS_INLINE Vec256( + int16_t scalar1, + int16_t scalar2, + int16_t scalar3, + int16_t scalar4, + int16_t scalar5, + int16_t scalar6, + int16_t scalar7, + int16_t scalar8, + int16_t scalar9, + int16_t scalar10, + int16_t scalar11, + int16_t scalar12, + int16_t scalar13, + int16_t scalar14, + int16_t scalar15, + int16_t scalar16) + : _vec0{vint16{ + scalar1, + scalar2, + scalar3, + scalar4, + scalar5, + scalar6, + scalar7, + scalar8}}, + _vec1{vint16{ + scalar9, + scalar10, + scalar11, + scalar12, + scalar13, + scalar14, + scalar15, + scalar16}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return a; + } + + template + static std::enable_if_t<(mask & 65535) == 65535, Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + return b; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t<(mask > 0 && mask < 255), Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr int16_t g0 = (mask & 1) * 0xffff; + constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff; + const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7}; + + return {(vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st), a._vec1}; + } + + template + static std::enable_if_t< + (mask > 255 && (mask & 65535) != 65535 && ((mask & 255) == 255)), + Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr int16_t g0_2 = (mask & 1) * 0xffff; + constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff; + + const vint16 mask_2nd = + vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2}; + // generated masks + return {b._vec0, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) == 0)), + Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr int16_t mask2 = (mask & 65535) >> 16; + constexpr int16_t g0_2 = (mask & 1) * 0xffff; + constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff; + + const vint16 mask_2nd = + vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2}; + // generated masks + return {a, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) != 0) && + ((mask & 255) != 255)), + Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr int16_t g0 = (mask & 1) * 0xffff; + constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff; + constexpr int16_t mask2 = (mask & 65535) >> 16; + constexpr int16_t g0_2 = (mask & 1) * 0xffff; + constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff; + + const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7}; + const vint16 mask_2nd = + vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2}; + // generated masks + return { + (vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st), + (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)}; + } + + static Vec256 C10_ALWAYS_INLINE blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + // the mask used here returned by comparision of vec256 + // assuming this we can use the same mask directly with vec_sel + // warning intel style mask will not work properly + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + + static Vec256 arange(int16_t base = 0, int16_t step = 1) { + return Vec256( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + static Vec256 set( + const Vec256& a, + const Vec256& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vec256 C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy(ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const int16_t& operator[](int idx) const = delete; + int16_t& operator[](int idx) = delete; + + Vec256 angle() const { + return Vec256{0}; + } + Vec256 real() const { + return *this; + } + Vec256 imag() const { + return Vec256{0}; + } + Vec256 conj() const { + return *this; + } + + Vec256 C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vec256 C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + DEFINE_MEMBER_UNARY_OP(operator~, int16_t, vec_not) + DEFINE_MEMBER_OP(operator==, int16_t, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, int16_t, vec_cmpne) + DEFINE_MEMBER_OP(operator<, int16_t, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, int16_t, vec_cmple) + DEFINE_MEMBER_OP(operator>, int16_t, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, int16_t, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, int16_t, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, int16_t, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, int16_t, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, int16_t, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, int16_t, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, int16_t, vec_cmpge) + DEFINE_MEMBER_OP(operator+, int16_t, vec_add) + DEFINE_MEMBER_OP(operator-, int16_t, vec_sub) + DEFINE_MEMBER_OP(operator*, int16_t, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, int16_t, /) + DEFINE_MEMBER_OP(maximum, int16_t, vec_max) + DEFINE_MEMBER_OP(minimum, int16_t, vec_min) + DEFINE_MEMBER_OP(operator&, int16_t, vec_and) + DEFINE_MEMBER_OP(operator|, int16_t, vec_or) + DEFINE_MEMBER_OP(operator^, int16_t, vec_xor) +}; + +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + return a.maximum(b); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + return a.minimum(b); +} + + +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_int32_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_int32_vsx.h new file mode 100644 index 0000000000000..2ee2318f03492 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_int32_vsx.h @@ -0,0 +1,281 @@ +#pragma once + +#include +#include +#include +namespace at { +namespace vec256 { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +template <> +class Vec256 { + private: + union { + struct { + vint32 _vec0; + vint32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = int32_t; + using vec_internal_type = vint32; + using vec_internal_mask_type = vbool32; + static constexpr int size() { + return 8; + } + Vec256() {} + C10_ALWAYS_INLINE Vec256(vint32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vec256(int32_t scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vec256( + int32_t scalar1, + int32_t scalar2, + int32_t scalar3, + int32_t scalar4, + int32_t scalar5, + int32_t scalar6, + int32_t scalar7, + int32_t scalar8) + : _vec0{vint32{scalar1, scalar2, scalar3, scalar4}}, + _vec1{vint32{scalar5, scalar6, scalar7, scalar8}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return a; + } + + template + static std::enable_if_t<(mask & 255) == 255, Vec256> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return b; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t<(mask > 0 && mask < 15), Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr uint32_t g0 = (mask & 1) * 0xffffffff; + constexpr uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + const vbool32 mask_1st = (vbool32){g0, g1, g2, g3}; + + return {(vint32)vec_sel(a._vec0, b._vec0, (vbool32)mask_1st), a._vec1}; + } + + template + static std::enable_if_t< + (mask > 15 && (mask & 255) != 255 && ((mask & 15) == 15)), + Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr uint32_t mask2 = (mask & 255) >> 4; + constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff; + constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff; + + const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2}; + // generated masks + return {b._vec0, (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 15 && ((mask & 255) != 255) && ((mask & 15) == 0)), + Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr uint32_t mask2 = (mask & 255) >> 4; + constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff; + constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff; + + const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2}; + // generated masks + return {a, (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 15 && ((mask & 255) != 255) && ((mask & 15) != 0) && + ((mask & 15) != 15)), + Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr uint32_t g0 = (mask & 1) * 0xffffffff; + constexpr uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + constexpr uint32_t mask2 = (mask & 255) >> 4; + constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff; + constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff; + + const vbool32 mask_1st = (vbool32){g0, g1, g2, g3}; + const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2}; + // generated masks + return { + (vint32)vec_sel(a._vec0, b._vec0, (vbool32)mask_1st), + (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)}; + } + + static Vec256 C10_ALWAYS_INLINE blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + // the mask used here returned by comparision of vec256 + // assuming this we can use the same mask directly with vec_sel + // warning intel style mask will not work properly + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + + static Vec256 arange(int32_t base = 0.f, int32_t step = 1.f) { + return Vec256( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vec256 set( + const Vec256& a, + const Vec256& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + + return b; + } + static Vec256 C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const int32_t& operator[](int idx) const = delete; + int32_t& operator[](int idx) = delete; + + Vec256 angle() const { + return Vec256{0}; + } + Vec256 real() const { + return *this; + } + Vec256 imag() const { + return Vec256{0}; + } + Vec256 conj() const { + return *this; + } + + Vec256 C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vec256 C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + DEFINE_MEMBER_UNARY_OP(operator~, int32_t, vec_not) + DEFINE_MEMBER_OP(operator==, int32_t, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, int32_t, vec_cmpne) + DEFINE_MEMBER_OP(operator<, int32_t, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, int32_t, vec_cmple) + DEFINE_MEMBER_OP(operator>, int32_t, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, int32_t, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, int32_t, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, int32_t, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, int32_t, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, int32_t, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, int32_t, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, int32_t, vec_cmpge) + DEFINE_MEMBER_OP(operator+, int32_t, vec_add) + DEFINE_MEMBER_OP(operator-, int32_t, vec_sub) + DEFINE_MEMBER_OP(operator*, int32_t, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, int32_t, /) + DEFINE_MEMBER_OP(maximum, int32_t, vec_max) + DEFINE_MEMBER_OP(minimum, int32_t, vec_min) + DEFINE_MEMBER_OP(operator&, int32_t, vec_and) + DEFINE_MEMBER_OP(operator|, int32_t, vec_or) + DEFINE_MEMBER_OP(operator^, int32_t, vec_xor) +}; + +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + return a.maximum(b); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + return a.minimum(b); +} + +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_int64_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_int64_vsx.h new file mode 100644 index 0000000000000..d752f71c9a636 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_int64_vsx.h @@ -0,0 +1,233 @@ +#pragma once + +#include +#include +#include +namespace at { +namespace vec256 { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +template <> +class Vec256 { + private: + union { + struct { + vint64 _vec0; + vint64 _vec1; + }; + struct { + vbool64 _vecb0; + vbool64 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = int64_t; + using vec_internal_type = vint64; + using vec_internal_mask_type = vbool64; + static constexpr int size() { + return 4; + } + Vec256() {} + C10_ALWAYS_INLINE Vec256(vint64 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vint64 v1, vint64 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool64 v1, vbool64 v2) : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vec256(int64_t scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vec256( + int64_t scalar1, + int64_t scalar2, + int64_t scalar3, + int64_t scalar4) + : _vec0{vint64{scalar1, scalar2}}, _vec1{vint64{scalar3, scalar4}} {} + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return a; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t<(mask & 15) == 15, Vec256> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + return b; + } + + template + static std::enable_if_t<(mask > 0 && mask < 3), Vec256> C10_ALWAYS_INLINE + blend(const Vec256& a, const Vec256& b) { + constexpr uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + constexpr uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + const vbool64 mask_1st = (vbool64){g0, g1}; + return {(vint64)vec_sel(a._vec0, b._vec0, (vbool64)mask_1st), a._vec1}; + } + + template + static std::enable_if_t<(mask > 3) && (mask & 3) == 0, Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr uint64_t g0_2 = ((mask & 4) >> 2) * 0xffffffffffffffff; + constexpr uint64_t g1_2 = ((mask & 8) >> 3) * 0xffffffffffffffff; + + const vbool64 mask_2nd = (vbool64){g0_2, g1_2}; + return {a._vec0, (vint64)vec_sel(a._vec1, b._vec1, (vbool64)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 3) && (mask & 3) != 0 && (mask & 15) != 15, + Vec256> + C10_ALWAYS_INLINE blend(const Vec256& a, const Vec256& b) { + constexpr uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + constexpr uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + constexpr uint64_t g0_2 = ((mask & 4) >> 2) * 0xffffffffffffffff; + constexpr uint64_t g1_2 = ((mask & 8) >> 3) * 0xffffffffffffffff; + + const vbool64 mask_1st = (vbool64){g0, g1}; + const vbool64 mask_2nd = (vbool64){g0_2, g1_2}; + return { + (vint64)vec_sel(a._vec0, b._vec0, (vbool64)mask_1st), + (vint64)vec_sel(a._vec1, b._vec1, (vbool64)mask_2nd)}; + } + + static Vec256 C10_ALWAYS_INLINE blendv( + const Vec256& a, + const Vec256& b, + const Vec256& mask) { + // the mask used here returned by comparision of vec256 + + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + static Vec256 arange(int64_t base = 0., int64_t step = 1.) { + return Vec256(base, base + step, base + 2 * step, base + 3 * step); + } + + static Vec256 C10_ALWAYS_INLINE + set(const Vec256& a, + const Vec256& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + + return b; + } + static Vec256 C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + static_assert(sizeof(double) == sizeof(value_type)); + const double* dptr = reinterpret_cast(ptr); + return {// treat it as double load + (vint64)vec_vsx_ld(offset0, dptr), + (vint64)vec_vsx_ld(offset16, dptr)}; + } + + __at_align32__ double tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return { + (vint64)vec_vsx_ld(offset0, tmp_values), + (vint64)vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + double* dptr = reinterpret_cast(ptr); + vec_vsx_st((vfloat64)_vec0, offset0, dptr); + vec_vsx_st((vfloat64)_vec1, offset16, dptr); + } else if (count > 0) { + __at_align32__ double tmp_values[size()]; + vec_vsx_st((vfloat64)_vec0, offset0, tmp_values); + vec_vsx_st((vfloat64)_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const int64_t& operator[](int idx) const = delete; + int64_t& operator[](int idx) = delete; + + Vec256 angle() const { + return Vec256{0}; + } + Vec256 real() const { + return *this; + } + Vec256 imag() const { + return Vec256{0}; + } + Vec256 conj() const { + return *this; + } + + Vec256 C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vec256 C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + DEFINE_MEMBER_UNARY_OP(operator~, int64_t, vec_not) + DEFINE_MEMBER_OP(operator==, int64_t, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, int64_t, vec_cmpne) + DEFINE_MEMBER_OP(operator<, int64_t, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, int64_t, vec_cmple) + DEFINE_MEMBER_OP(operator>, int64_t, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, int64_t, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, int64_t, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, int64_t, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, int64_t, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, int64_t, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, int64_t, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, int64_t, vec_cmpge) + DEFINE_MEMBER_OP(operator+, int64_t, vec_add) + DEFINE_MEMBER_OP(operator-, int64_t, vec_sub) + DEFINE_MEMBER_OP(operator*, int64_t, vec_mul) + DEFINE_MEMBER_OP(operator/, int64_t, vec_div) + DEFINE_MEMBER_OP(maximum, int64_t, vec_max) + DEFINE_MEMBER_OP(minimum, int64_t, vec_min) + DEFINE_MEMBER_OP(operator&, int64_t, vec_and) + DEFINE_MEMBER_OP(operator|, int64_t, vec_or) + DEFINE_MEMBER_OP(operator^, int64_t, vec_xor) +}; + +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + return a.maximum(b); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + return a.minimum(b); +} + +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_qint32_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_qint32_vsx.h new file mode 100644 index 0000000000000..a47e295ce03b3 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_qint32_vsx.h @@ -0,0 +1,242 @@ +#pragma once + +#include +#include +#include +#include +#include + +// This file defines Vec256<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vec256, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vec256 -> 1x Vec256 +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over Vec256::float_num_vecs +// iterations. + +namespace at { +namespace vec256 { +namespace { + +template <> +struct Vec256 { + private: + union { + struct { + vint32 _vec0; + vint32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + Vec256() {} + + static constexpr int size() { + return 8; + } + + static constexpr size_t float_num_vecs() { + return 1; + } + static constexpr int int_num_vecs() { + return 1; + } + using float_vec_return_type = std::array, 1>; + using int_vec_return_type = std::array, 1>; + using value_type = c10::qint32::underlying; + using vec_internal_type = vint32; + using vec_internal_mask_type = vbool32; + C10_ALWAYS_INLINE Vec256(vint32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {} + + Vec256(const c10::qint32& val) + : _vec0(vec_splats(val.val_)), _vec1(vec_splats(val.val_)) {} + + static Vec256 C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + float_vec_return_type dequantize( + Vec256 scale, + Vec256 zero_point, + Vec256 scale_zp_premul) const { + vfloat32 float_vals0 = vec_float(_vec0); + vfloat32 float_vals1 = vec_float(_vec1); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + vfloat32 scale_zp_premul0 = scale_zp_premul.vec0(); + vfloat32 scale_zp_premul1 = scale_zp_premul.vec1(); + return {Vec256{ + vec_madd(scale_vec0, float_vals0, scale_zp_premul0), + vec_madd(scale_vec1, float_vals1, scale_zp_premul1)}}; + } + + static Vec256 quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + Vec256 retval; + + const vint32 vmin = vec_splats(std::numeric_limits::min()); + const vint32 vmax = vec_splats(std::numeric_limits::max()); + vfloat32 inverse_scale_v = vec_splats(inverse_scale); + vfloat32 vec_zero_point = vec_splats((float)(zero_point)); + Vec256 vf0 = rhs[0]; + + vfloat32 vecf0 = vf0.vec0(); + vfloat32 vecf1 = vf0.vec1(); + vecf0 = vec_mul(vecf0, inverse_scale_v); + vecf1 = vec_mul(vecf1, inverse_scale_v); + vecf0 = vec_add(vec_rint(vecf0), vec_zero_point); + vecf1 = vec_add(vec_rint(vecf1), vec_zero_point); + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + + veci0 = vec_max(veci0, vmin); + veci1 = vec_max(veci1, vmin); + veci0 = vec_min(veci0, vmax); + veci1 = vec_min(veci1, vmax); + + return {veci0, veci1}; + } + + Vec256 relu(Vec256 zero_point) const { + return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)}; + } + + Vec256 relu6( + Vec256 zero_point, + Vec256 q_six) const { + vint32 max0 = vec_max(_vec0, zero_point._vec0); + vint32 max1 = vec_max(_vec1, zero_point._vec1); + return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)}; + } + + int_vec_return_type widening_subtract(Vec256 b) const { + return {*this - b}; + } + + static Vec256 requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + const vint32 vmin = vec_splats(std::numeric_limits::min()); + const vint32 vmax = vec_splats(std::numeric_limits::max()); + vfloat32 vec_mult = vec_splats(multiplier); + vint32 vec_zero_point = vec_splats(zero_point); + Vec256 vi = inp[0]; + vfloat32 vecf0 = vec_float(vi.vec0()); + vfloat32 vecf1 = vec_float(vi.vec1()); + + vecf0 = vec_mul(vecf0, vec_mult); + vecf1 = vec_mul(vecf1, vec_mult); + + vecf0 = vec_rint(vecf0); + vecf1 = vec_rint(vecf1); + + vint32 veci0 = vec_add(vec_signed(vecf0),vec_zero_point); + vint32 veci1 = vec_add(vec_signed(vecf1),vec_zero_point); + + veci0 = vec_max(veci0, vmin); + veci1 = vec_max(veci1, vmin); + veci0 = vec_min(veci0, vmax); + veci1 = vec_min(veci1, vmax); + + return {veci0, veci1}; + } + + void dump() const { + std::cout << _vec0[0] << " "; + std::cout << _vec0[1] << " "; + std::cout << _vec0[2] << " "; + std::cout << _vec0[3] << " "; + std::cout << _vec1[0] << " "; + std::cout << _vec1[1] << " "; + std::cout << _vec1[2] << " "; + std::cout << _vec1[3] << " "; + std::cout << std::endl; + } + + DEFINE_MEMBER_OP(operator==, c10::qint32, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, c10::qint32, vec_cmpne) + DEFINE_MEMBER_OP(operator<, c10::qint32, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, c10::qint32, vec_cmple) + DEFINE_MEMBER_OP(operator>, c10::qint32, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, c10::qint32, vec_cmpge) + DEFINE_MEMBER_OP(operator+, c10::qint32, vec_add) + DEFINE_MEMBER_OP(operator-, c10::qint32, vec_sub) + DEFINE_MEMBER_OP(operator*, c10::qint32, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint32, /) + DEFINE_MEMBER_OP(maximum, c10::qint32, vec_max) + DEFINE_MEMBER_OP(minimum, c10::qint32, vec_min) + DEFINE_MEMBER_OP(operator&, c10::qint32, vec_and) + DEFINE_MEMBER_OP(operator|, c10::qint32, vec_or) + DEFINE_MEMBER_OP(operator^, c10::qint32, vec_xor) +}; + +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + return a.maximum(b); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + return a.minimum(b); +} +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_qint8_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_qint8_vsx.h new file mode 100644 index 0000000000000..f8b6eced60ef1 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_qint8_vsx.h @@ -0,0 +1,404 @@ +#pragma once + +#include +#include +#include +#include +#include + +// This file defines Vec256<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vec256, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vec256 -> 4x Vec256 +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over Vec256::float_num_vecs +// iterations. + +namespace at { +namespace vec256 { +namespace { + +template <> +struct Vec256 { + private: + union { + struct { + vint8 _vec0; + vint8 _vec1; + }; + struct { + vbool8 _vecb0; + vbool8 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + Vec256() {} + static constexpr int size() { + return 32; + } + + static constexpr size_t float_num_vecs() { + return 4; + } + static constexpr int int_num_vecs() { + return 4; + } + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::qint8::underlying; + using vec_internal_type = vint8; + using vec_internal_mask_type = vbool8; + // Broadcast constructor + C10_ALWAYS_INLINE Vec256(const c10::qint8& val) + : _vec0{vec_splats(val.val_)}, _vec1{vec_splats(val.val_)} {} + + C10_ALWAYS_INLINE Vec256(const Vec256& other) + : _vec0{other._vec0}, _vec1(other._vec1) {} + + C10_ALWAYS_INLINE Vec256(vint8 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool8 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vint8 v1, vint8 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool8 v1, vbool8 v2) : _vecb0{v1}, _vecb1{v2} {} + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + static C10_ALWAYS_INLINE Vec256 loadu( + const void* ptr, + int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + public: + float_vec_return_type C10_ALWAYS_INLINE dequantize( + Vec256 scale, + Vec256 zero_point, + Vec256 scale_zp_premul) const { + vint16 vecshi0 = vec_unpackh(_vec0); + vint16 vecshi1 = vec_unpackl(_vec0); + + vint16 vecshi2 = vec_unpackh(_vec1); + vint16 vecshi3 = vec_unpackl(_vec1); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 veci1 = vec_unpackl(vecshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 veci3 = vec_unpackl(vecshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 veci5 = vec_unpackl(vecshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 veci7 = vec_unpackl(vecshi3); + + vfloat32 vecf0_0 = vec_float(veci0); + vfloat32 vecf1_0 = vec_float(veci1); + + vfloat32 vecf0_1 = vec_float(veci2); + vfloat32 vecf1_1 = vec_float(veci3); + + vfloat32 vecf0_2 = vec_float(veci4); + vfloat32 vecf1_2 = vec_float(veci5); + + vfloat32 vecf0_3 = vec_float(veci6); + vfloat32 vecf1_3 = vec_float(veci7); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + vfloat32 scale_zp_premul0 = scale_zp_premul.vec0(); + vfloat32 scale_zp_premul1 = scale_zp_premul.vec1(); + return { + Vec256{ + vec_madd(scale_vec0, vecf0_0, scale_zp_premul0), + vec_madd(scale_vec1, vecf1_0, scale_zp_premul1)}, + Vec256{ + vec_madd(scale_vec0, vecf0_1, scale_zp_premul0), + vec_madd(scale_vec1, vecf1_1, scale_zp_premul1)}, + Vec256{ + vec_madd(scale_vec0, vecf0_2, scale_zp_premul0), + vec_madd(scale_vec1, vecf1_2, scale_zp_premul1)}, + Vec256{ + vec_madd(scale_vec0, vecf0_3, scale_zp_premul0), + vec_madd(scale_vec1, vecf1_3, scale_zp_premul1)}}; + } + + static Vec256 quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + // constexpr int32_t min_val = std::numeric_limits::min(); + // constexpr int32_t max_val = std::numeric_limits::max(); + + vfloat32 inverse_scale_v = vec_splats(inverse_scale); + vfloat32 vec_zero_point = vec_splats((float)zero_point); + // vint32 vmin = vec_splats(min_val); + // vint32 vmax = vec_splats(max_val); + + Vec256 vf0 = rhs[0]; + Vec256 vf1 = rhs[1]; + Vec256 vf2 = rhs[2]; + Vec256 vf3 = rhs[3]; + vfloat32 vecf0 = vf0.vec0(); + vfloat32 vecf1 = vf0.vec1(); + vfloat32 vecf2 = vf1.vec0(); + vfloat32 vecf3 = vf1.vec1(); + + vfloat32 vecf4 = vf2.vec0(); + vfloat32 vecf5 = vf2.vec1(); + vfloat32 vecf6 = vf3.vec0(); + vfloat32 vecf7 = vf3.vec1(); + + vecf0 = vec_mul(vecf0, inverse_scale_v); + vecf1 = vec_mul(vecf1, inverse_scale_v); + vecf2 = vec_mul(vecf2, inverse_scale_v); + vecf3 = vec_mul(vecf3, inverse_scale_v); + + vecf4 = vec_mul(vecf4, inverse_scale_v); + vecf5 = vec_mul(vecf5, inverse_scale_v); + vecf6 = vec_mul(vecf6, inverse_scale_v); + vecf7 = vec_mul(vecf7, inverse_scale_v); + + vecf0 = vec_add(vec_rint(vecf0), vec_zero_point); + vecf1 = vec_add(vec_rint(vecf1), vec_zero_point); + vecf2 = vec_add(vec_rint(vecf2), vec_zero_point); + vecf3 = vec_add(vec_rint(vecf3), vec_zero_point); + + vecf4 = vec_add(vec_rint(vecf4), vec_zero_point); + vecf5 = vec_add(vec_rint(vecf5), vec_zero_point); + vecf6 = vec_add(vec_rint(vecf6), vec_zero_point); + vecf7 = vec_add(vec_rint(vecf7), vec_zero_point); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + // veci0 = vec_min(vmax, vec_max( vmin, vecf0)) ; + // veci1 = vec_min(vmax, vec_max( vmin, vecf1)) ; + // veci2 = vec_min(vmax, vec_max( vmin, vecf2)) ; + // veci3 = vec_min(vmax, vec_max( vmin, vecf3)) ; + + // veci4 = vec_min(vmax, vec_max( vmin, vecf4)) ; + // veci5 = vec_min(vmax, vec_max( vmin, vecf5)) ; + // veci6 = vec_min(vmax, vec_max( vmin, vecf6)) ; + // veci7 = vec_min(vmax, vec_max( vmin, vecf7)) ; + // vec_packs CLAMP already + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vint8 vec0 = vec_packs(vecshi0, vecshi1); + vint8 vec1 = vec_packs(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + Vec256 C10_ALWAYS_INLINE relu(Vec256 zero_point) const { + return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)}; + } + + Vec256 C10_ALWAYS_INLINE + relu6(Vec256 zero_point, Vec256 q_six) const { + vint8 max0 = vec_max(_vec0, zero_point._vec0); + vint8 max1 = vec_max(_vec1, zero_point._vec1); + return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)}; + } + + int_vec_return_type widening_subtract(Vec256 b) const { + vint16 vecshi0 = vec_unpackh(_vec0); + vint16 vecBshi0 = vec_unpackh(b._vec0); + vint16 vecshi1 = vec_unpackl(_vec0); + vint16 vecBshi1 = vec_unpackl(b._vec0); + + vint16 vecshi2 = vec_unpackh(_vec1); + vint16 vecBshi2 = vec_unpackh(b._vec1); + vint16 vecshi3 = vec_unpackl(_vec1); + vint16 vecBshi3 = vec_unpackl(b._vec1); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 vecBi0 = vec_unpackh(vecBshi0); + vint32 veci1 = vec_unpackl(vecshi0); + vint32 vecBi1 = vec_unpackl(vecBshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 vecBi2 = vec_unpackh(vecBshi1); + vint32 veci3 = vec_unpackl(vecshi1); + vint32 vecBi3 = vec_unpackl(vecBshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 vecBi4 = vec_unpackh(vecBshi2); + vint32 veci5 = vec_unpackl(vecshi2); + vint32 vecBi5 = vec_unpackl(vecBshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 vecBi6 = vec_unpackh(vecBshi3); + vint32 veci7 = vec_unpackl(vecshi3); + vint32 vecBi7 = vec_unpackl(vecBshi3); + + return { + Vec256(veci0 - vecBi0, veci1 - vecBi1), + Vec256(veci2 - vecBi2, veci3 - vecBi3), + Vec256(veci4 - vecBi4, veci5 - vecBi5), + Vec256(veci6 - vecBi6, veci7 - vecBi7)}; + } + + static Vec256 requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + vfloat32 vec_multiplier = vec_splats(multiplier); + vint32 vec_zero_point = vec_splats(zero_point); + + Vec256 vi0 = inp[0]; + Vec256 vi1 = inp[1]; + Vec256 vi2 = inp[2]; + Vec256 vi3 = inp[3]; + + vfloat32 vecf0 = vec_float(vi0.vec0()); + vfloat32 vecf1 = vec_float(vi0.vec1()); + vfloat32 vecf2 = vec_float(vi1.vec0()); + vfloat32 vecf3 = vec_float(vi1.vec1()); + + vfloat32 vecf4 = vec_float(vi2.vec0()); + vfloat32 vecf5 = vec_float(vi2.vec1()); + vfloat32 vecf6 = vec_float(vi3.vec0()); + vfloat32 vecf7 = vec_float(vi3.vec1()); + + vecf0 = vec_mul(vecf0, vec_multiplier); + vecf1 = vec_mul(vecf1, vec_multiplier); + vecf2 = vec_mul(vecf2, vec_multiplier); + vecf3 = vec_mul(vecf3, vec_multiplier); + + vecf4 = vec_mul(vecf4, vec_multiplier); + vecf5 = vec_mul(vecf5, vec_multiplier); + vecf6 = vec_mul(vecf6, vec_multiplier); + vecf7 = vec_mul(vecf7, vec_multiplier); + + vecf0 = vec_rint(vecf0); + vecf1 = vec_rint(vecf1); + vecf2 = vec_rint(vecf2); + vecf3 = vec_rint(vecf3); + + vecf4 = vec_rint(vecf4); + vecf5 = vec_rint(vecf5); + vecf6 = vec_rint(vecf6); + vecf7 = vec_rint(vecf7); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + veci0 = vec_add(veci0, vec_zero_point); + veci1 = vec_add(veci1, vec_zero_point); + veci2 = vec_add(veci2, vec_zero_point); + veci3 = vec_add(veci3, vec_zero_point); + + veci4 = vec_add(veci4, vec_zero_point); + veci5 = vec_add(veci5, vec_zero_point); + veci6 = vec_add(veci6, vec_zero_point); + veci7 = vec_add(veci7, vec_zero_point); + + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vint8 vec0 = vec_packs(vecshi0, vecshi1); + vint8 vec1 = vec_packs(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + void dump() const { + value_type vals[size()]; + store((void*)vals); + for (int i = 0; i < size(); ++i) { + std::cout << (int)(vals[i]) << " "; + } + std::cout << std::endl; + } + + DEFINE_MEMBER_OP(operator==, c10::qint8, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, c10::qint8, vec_cmpne) + DEFINE_MEMBER_OP(operator<, c10::qint8, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, c10::qint8, vec_cmple) + DEFINE_MEMBER_OP(operator>, c10::qint8, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, c10::qint8, vec_cmpge) + DEFINE_MEMBER_OP(operator+, c10::qint8, vec_add) + DEFINE_MEMBER_OP(operator-, c10::qint8, vec_sub) + DEFINE_MEMBER_OP(operator*, c10::qint8, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint8, /) + DEFINE_MEMBER_OP(maximum, c10::qint8, vec_max) + DEFINE_MEMBER_OP(minimum, c10::qint8, vec_min) + DEFINE_MEMBER_OP(operator&, c10::qint8, vec_and) + DEFINE_MEMBER_OP(operator|, c10::qint8, vec_or) + DEFINE_MEMBER_OP(operator^, c10::qint8, vec_xor) +}; + +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + return a.maximum(b); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + return a.minimum(b); +} +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vec256_quint8_vsx.h b/aten/src/ATen/cpu/vec256/vsx/vec256_quint8_vsx.h new file mode 100644 index 0000000000000..96809ce325939 --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vec256_quint8_vsx.h @@ -0,0 +1,413 @@ +#pragma once + +#include +#include +#include +#include +#include + +// This file defines Vec256<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vec256, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vec256 -> 4x Vec256 +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over Vec256::float_num_vecs +// iterations. + +namespace at { +namespace vec256 { +namespace { + +const vint16 mask_unsigned = vec_splats((short int)0xFF); +template <> +struct Vec256 { + private: + union { + struct { + vuint8 _vec0; + vuint8 _vec1; + }; + struct { + vbool8 _vecb0; + vbool8 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + Vec256() {} + static constexpr int size() { + return 32; + } + + static constexpr size_t float_num_vecs() { + return 4; + } + static constexpr int int_num_vecs() { + return 4; + } + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::quint8::underlying; + using vec_internal_type = vuint8; + using vec_internal_mask_type = vbool8; + // Broadcast constructor + C10_ALWAYS_INLINE Vec256(const c10::quint8& val) + : _vec0(vec_splats(val.val_)), _vec1(vec_splats(val.val_)) {} + + C10_ALWAYS_INLINE Vec256(const Vec256& other) + : _vec0{other._vec0}, _vec1(other._vec1) {} + + C10_ALWAYS_INLINE Vec256(vuint8 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vec256(vbool8 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vec256(vuint8 v1, vuint8 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vec256(vbool8 v1, vbool8 v2) : _vecb0{v1}, _vecb1{v2} {} + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + static C10_ALWAYS_INLINE Vec256 loadu( + const void* ptr, + int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + __at_align32__ value_type tmp_values[size()]; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align32__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + public: + float_vec_return_type C10_ALWAYS_INLINE dequantize( + Vec256 scale, + Vec256 zero_point, + Vec256 scale_zp_premul) const { + // unpacking unsigned as signed + vint16 vecshi0 = vec_unpackh((vint8)_vec0); + vint16 vecshi1 = vec_unpackl((vint8)_vec0); + + vint16 vecshi2 = vec_unpackh((vint8)_vec1); + vint16 vecshi3 = vec_unpackl((vint8)_vec1); + + // signed -> unsigned + vecshi0 = vec_and(vecshi0, mask_unsigned); + vecshi1 = vec_and(vecshi1, mask_unsigned); + + vecshi2 = vec_and(vecshi2, mask_unsigned); + vecshi3 = vec_and(vecshi3, mask_unsigned); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 veci1 = vec_unpackl(vecshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 veci3 = vec_unpackl(vecshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 veci5 = vec_unpackl(vecshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 veci7 = vec_unpackl(vecshi3); + + vfloat32 vecf0_0 = vec_float(veci0); + vfloat32 vecf1_0 = vec_float(veci1); + + vfloat32 vecf0_1 = vec_float(veci2); + vfloat32 vecf1_1 = vec_float(veci3); + + vfloat32 vecf0_2 = vec_float(veci4); + vfloat32 vecf1_2 = vec_float(veci5); + + vfloat32 vecf0_3 = vec_float(veci6); + vfloat32 vecf1_3 = vec_float(veci7); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + vfloat32 scale_zp_premul0 = scale_zp_premul.vec0(); + vfloat32 scale_zp_premul1 = scale_zp_premul.vec1(); + return { + Vec256{ + vec_madd(scale_vec0, vecf0_0, scale_zp_premul0), + vec_madd(scale_vec1, vecf1_0, scale_zp_premul1)}, + Vec256{ + vec_madd(scale_vec0, vecf0_1, scale_zp_premul0), + vec_madd(scale_vec1, vecf1_1, scale_zp_premul1)}, + Vec256{ + vec_madd(scale_vec0, vecf0_2, scale_zp_premul0), + vec_madd(scale_vec1, vecf1_2, scale_zp_premul1)}, + Vec256{ + vec_madd(scale_vec0, vecf0_3, scale_zp_premul0), + vec_madd(scale_vec1, vecf1_3, scale_zp_premul1)}}; + } + + static Vec256 quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + // constexpr int32_t min_val = std::numeric_limits::min(); + // constexpr int32_t max_val = std::numeric_limits::max(); + + vfloat32 vec_inverse = vec_splats(inverse_scale); + vfloat32 vec_zero_point = vec_splats((float)zero_point); + // vuint32 vmin = vec_splats(min_val); + // vuint32 vmax = vec_splats(max_val); + Vec256 vf0 = rhs[0]; + Vec256 vf1 = rhs[1]; + Vec256 vf2 = rhs[2]; + Vec256 vf3 = rhs[3]; + vfloat32 vecf0 = vf0.vec0(); + vfloat32 vecf1 = vf0.vec1(); + vfloat32 vecf2 = vf1.vec0(); + vfloat32 vecf3 = vf1.vec1(); + + vfloat32 vecf4 = vf2.vec0(); + vfloat32 vecf5 = vf2.vec1(); + vfloat32 vecf6 = vf3.vec0(); + vfloat32 vecf7 = vf3.vec1(); + + vecf0 = vec_mul(vecf0, vec_inverse); + vecf1 = vec_mul(vecf1, vec_inverse); + vecf2 = vec_mul(vecf2, vec_inverse); + vecf3 = vec_mul(vecf3, vec_inverse); + + vecf4 = vec_mul(vecf4, vec_inverse); + vecf5 = vec_mul(vecf5, vec_inverse); + vecf6 = vec_mul(vecf6, vec_inverse); + vecf7 = vec_mul(vecf7, vec_inverse); + + vecf0 = vec_add(vec_rint(vecf0), vec_zero_point); + vecf1 = vec_add(vec_rint(vecf1), vec_zero_point); + vecf2 = vec_add(vec_rint(vecf2), vec_zero_point); + vecf3 = vec_add(vec_rint(vecf3), vec_zero_point); + + vecf4 = vec_add(vec_rint(vecf4), vec_zero_point); + vecf5 = vec_add(vec_rint(vecf5), vec_zero_point); + vecf6 = vec_add(vec_rint(vecf6), vec_zero_point); + vecf7 = vec_add(vec_rint(vecf7), vec_zero_point); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vuint8 vec0 = vec_packsu(vecshi0, vecshi1); + vuint8 vec1 = vec_packsu(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + Vec256 C10_ALWAYS_INLINE relu(Vec256 zero_point) const { + return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)}; + } + + Vec256 C10_ALWAYS_INLINE + relu6(Vec256 zero_point, Vec256 q_six) const { + vuint8 max0 = vec_max(_vec0, zero_point._vec0); + vuint8 max1 = vec_max(_vec1, zero_point._vec1); + return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)}; + } + + int_vec_return_type widening_subtract(Vec256 b) const { + vint16 vecshi0 = vec_unpackh((vint8)_vec0); + vint16 vecBshi0 = vec_unpackh((vint8)b._vec0); + vint16 vecshi1 = vec_unpackl((vint8)_vec0); + vint16 vecBshi1 = vec_unpackl((vint8)b._vec0); + + vint16 vecshi2 = vec_unpackh((vint8)_vec1); + vint16 vecBshi2 = vec_unpackh((vint8)b._vec1); + vint16 vecshi3 = vec_unpackl((vint8)_vec1); + vint16 vecBshi3 = vec_unpackl((vint8)b._vec1); + + vecshi0 = vec_and(vecshi0, mask_unsigned); + vecBshi0 = vec_and(vecBshi0, mask_unsigned); + vecshi1 = vec_and(vecshi1, mask_unsigned); + vecBshi1 = vec_and(vecBshi1, mask_unsigned); + + vecshi2 = vec_and(vecshi2, mask_unsigned); + vecBshi2 = vec_and(vecBshi2, mask_unsigned); + vecshi3 = vec_and(vecshi3, mask_unsigned); + vecBshi3 = vec_and(vecBshi3, mask_unsigned); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 vecBi0 = vec_unpackh(vecBshi0); + vint32 veci1 = vec_unpackl(vecshi0); + vint32 vecBi1 = vec_unpackl(vecBshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 vecBi2 = vec_unpackh(vecBshi1); + vint32 veci3 = vec_unpackl(vecshi1); + vint32 vecBi3 = vec_unpackl(vecBshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 vecBi4 = vec_unpackh(vecBshi2); + vint32 veci5 = vec_unpackl(vecshi2); + vint32 vecBi5 = vec_unpackl(vecBshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 vecBi6 = vec_unpackh(vecBshi3); + vint32 veci7 = vec_unpackl(vecshi3); + vint32 vecBi7 = vec_unpackl(vecBshi3); + + return { + Vec256(veci0 - vecBi0, veci1 - vecBi1), + Vec256(veci2 - vecBi2, veci3 - vecBi3), + Vec256(veci4 - vecBi4, veci5 - vecBi5), + Vec256(veci6 - vecBi6, veci7 - vecBi7)}; + } + + static Vec256 requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + vfloat32 vec_multiplier = vec_splats(multiplier); + vint32 vec_zero_point = vec_splats(zero_point); + + Vec256 vi0 = inp[0]; + Vec256 vi1 = inp[1]; + Vec256 vi2 = inp[2]; + Vec256 vi3 = inp[3]; + + vfloat32 vecf0 = vec_float(vi0.vec0()); + vfloat32 vecf1 = vec_float(vi0.vec1()); + vfloat32 vecf2 = vec_float(vi1.vec0()); + vfloat32 vecf3 = vec_float(vi1.vec1()); + + vfloat32 vecf4 = vec_float(vi2.vec0()); + vfloat32 vecf5 = vec_float(vi2.vec1()); + vfloat32 vecf6 = vec_float(vi3.vec0()); + vfloat32 vecf7 = vec_float(vi3.vec1()); + + vecf0 = vec_mul(vecf0, vec_multiplier); + vecf1 = vec_mul(vecf1, vec_multiplier); + vecf2 = vec_mul(vecf2, vec_multiplier); + vecf3 = vec_mul(vecf3, vec_multiplier); + + vecf4 = vec_mul(vecf4, vec_multiplier); + vecf5 = vec_mul(vecf5, vec_multiplier); + vecf6 = vec_mul(vecf6, vec_multiplier); + vecf7 = vec_mul(vecf7, vec_multiplier); + + vecf0 = vec_rint(vecf0); + vecf1 = vec_rint(vecf1); + vecf2 = vec_rint(vecf2); + vecf3 = vec_rint(vecf3); + + vecf4 = vec_rint(vecf4); + vecf5 = vec_rint(vecf5); + vecf6 = vec_rint(vecf6); + vecf7 = vec_rint(vecf7); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + veci0 = vec_add(veci0, vec_zero_point); + veci1 = vec_add(veci1, vec_zero_point); + veci2 = vec_add(veci2, vec_zero_point); + veci3 = vec_add(veci3, vec_zero_point); + + veci4 = vec_add(veci4, vec_zero_point); + veci5 = vec_add(veci5, vec_zero_point); + veci6 = vec_add(veci6, vec_zero_point); + veci7 = vec_add(veci7, vec_zero_point); + + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vuint8 vec0 = vec_packsu(vecshi0, vecshi1); + vuint8 vec1 = vec_packsu(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + void dump() const { + value_type vals[size()]; + store((void*)vals); + for (int i = 0; i < size(); ++i) { + std::cout << (int)(vals[i]) << " "; + } + std::cout << std::endl; + } + + DEFINE_MEMBER_OP(operator==, c10::quint8, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, c10::quint8, vec_cmpne) + DEFINE_MEMBER_OP(operator<, c10::quint8, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, c10::quint8, vec_cmple) + DEFINE_MEMBER_OP(operator>, c10::quint8, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, c10::quint8, vec_cmpge) + DEFINE_MEMBER_OP(operator+, c10::quint8, vec_add) + DEFINE_MEMBER_OP(operator-, c10::quint8, vec_sub) + DEFINE_MEMBER_OP(operator*, c10::quint8, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::quint8, /) + DEFINE_MEMBER_OP(maximum, c10::quint8, vec_max) + DEFINE_MEMBER_OP(minimum, c10::quint8, vec_min) + DEFINE_MEMBER_OP(operator&, c10::quint8, vec_and) + DEFINE_MEMBER_OP(operator|, c10::quint8, vec_or) + DEFINE_MEMBER_OP(operator^, c10::quint8, vec_xor) +}; + +template <> +Vec256 inline maximum( + const Vec256& a, + const Vec256& b) { + return a.maximum(b); +} + +template <> +Vec256 inline minimum( + const Vec256& a, + const Vec256& b) { + return a.minimum(b); +} + +} // namespace +} // namespace vec256 +} // namespace at diff --git a/aten/src/ATen/cpu/vec256/vsx/vsx_helpers.h b/aten/src/ATen/cpu/vec256/vsx/vsx_helpers.h new file mode 100644 index 0000000000000..40cb7ef7a66ea --- /dev/null +++ b/aten/src/ATen/cpu/vec256/vsx/vsx_helpers.h @@ -0,0 +1,332 @@ +#pragma once +#include +#include +#include + +using vbool8 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) char; +using vbool16 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) short; +using vbool32 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) int; +using vbool64 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) long long; +using vint8 = __attribute__((altivec(vector__))) signed char; +using vint16 = __attribute__((altivec(vector__))) signed short; +using vint32 = __attribute__((altivec(vector__))) signed int; +using vint64 = __attribute__((altivec(vector__))) signed long long; +using vuint8 = __attribute__((altivec(vector__))) unsigned char; +using vuint16 = __attribute__((altivec(vector__))) unsigned short; +using vuint32 = __attribute__((altivec(vector__))) unsigned int; +using vuint64 = __attribute__((altivec(vector__))) unsigned long long; +using vfloat32 = __attribute__((altivec(vector__))) float; +using vfloat64 = __attribute__((altivec(vector__))) double; + +#if !defined(vec_float) +C10_ALWAYS_INLINE vfloat32 vec_float(const vint32& vec_in) { + vfloat32 vec_out; + __asm__("xvcvsxwsp %x0,%x1" : "=wf"(vec_out) : "wa"(vec_in)); + return vec_out; +} +#endif + +#define vec_not(a) vec_nor(a, a) + +#define DEFINE_MEMBER_UNARY_OP(op, op_type, func) \ + Vec256 C10_ALWAYS_INLINE op() const { \ + return Vec256{func(_vec0), func(_vec1)}; \ + } + +#define DEFINE_MEMBER_OP(op, op_type, func) \ + Vec256 C10_ALWAYS_INLINE op(const Vec256& other) const { \ + return Vec256{ \ + func(_vec0, other._vec0), func(_vec1, other._vec1)}; \ + } + +#define DEFINE_MEMBER_BITWISE_OP(op, op_type, func) \ + Vec256 C10_ALWAYS_INLINE op(const Vec256& other) const { \ + return Vec256{ \ + func(_vecb0, other._vecb0), func(_vecb1, other._vecb1)}; \ + } + +#define DEFINE_MEMBER_TERNARY_OP(op, op_type, func) \ + Vec256 C10_ALWAYS_INLINE op( \ + const Vec256& b, const Vec256& c) const { \ + return Vec256{ \ + func(_vec0, b._vec0, c._vec0), func(_vec1, b._vec1, c._vec1)}; \ + } + +#define DEFINE_MEMBER_EMULATE_BINARY_OP(op, op_type, binary_op) \ + Vec256 C10_ALWAYS_INLINE op(const Vec256& b) const { \ + Vec256::vec_internal_type ret_0; \ + Vec256::vec_internal_type ret_1; \ + for (int i = 0; i < Vec256::size() / 2; i++) { \ + ret_0[i] = _vec0[i] binary_op b._vec0[i]; \ + ret_1[i] = _vec1[i] binary_op b._vec1[i]; \ + } \ + return Vec256{ret_0, ret_1}; \ + } + + +#define DEFINE_MEMBER_OP_AND_ONE(op, op_type, func) \ + Vec256 C10_ALWAYS_INLINE op(const Vec256& other) const { \ + using vvtype = Vec256::vec_internal_type; \ + const vvtype v_one = vec_splats(static_cast(1.0)); \ + vvtype ret0 = (vvtype)func(_vec0, other._vec0); \ + vvtype ret1 = (vvtype)func(_vec1, other._vec1); \ + return Vec256{vec_and(ret0, v_one), vec_and(ret1, v_one)}; \ + } + +#define DEFINE_CLAMP_FUNCS(operand_type) \ + template <> \ + Vec256 C10_ALWAYS_INLINE clamp( \ + const Vec256& a, \ + const Vec256& min, \ + const Vec256& max) { \ + return Vec256{ \ + vec_min(max.vec0(), vec_max(a.vec0(), min.vec0())), \ + vec_min(max.vec1(), vec_max(a.vec1(), min.vec1()))}; \ + } \ + template <> \ + Vec256 C10_ALWAYS_INLINE clamp_min( \ + const Vec256& a, const Vec256& min) { \ + return Vec256{ \ + vec_max(a.vec0(), min.vec0()), vec_max(a.vec1(), min.vec1())}; \ + } \ + template <> \ + Vec256 C10_ALWAYS_INLINE clamp_max( \ + const Vec256& a, const Vec256& max) { \ + return Vec256{ \ + vec_min(a.vec0(), max.vec0()), vec_min(a.vec1(), max.vec1())}; \ + } + +#define DEFINE_REINTERPRET_CAST_FUNCS( \ + first_type, cast_type, cast_inner_vector_type) \ + template <> \ + C10_ALWAYS_INLINE Vec256 cast( \ + const Vec256& src) { \ + return Vec256{(cast_inner_vector_type)src.vec0(), \ + (cast_inner_vector_type)src.vec1()}; \ + } + +#define DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(first_type) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, double, vfloat64) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, float, vfloat32) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, int64_t, vint64) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, int32_t, vint32) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, int16_t, vint16) + +// it can be used to emulate blend faster +constexpr int blendChoice(uint32_t mask, uint32_t half1 = 0xF, uint32_t half2 = 0xF0) { + uint32_t none = 0; + uint32_t both = half1 | half2; + // clamp it between 0 and both + mask = mask & both; + // return (a._vec0, a._vec1) + if (mask == none) return 0; + // return (b._vec0,b._vec1) + else if (mask == both) + return 1; + // return (b._vec0,a._vec1) + else if (mask == half1) + return 2; + // return (a._vec0,b._vec1) + else if (mask == half2) + return 3; + // return (*_vec0,a._vec1) + else if (mask > 0 && mask < half1) + return 4; + // return (*_vec0,b._vec1) + else if ((mask & half2) == half2) + return 5; + // return (a._vec0,*_vec1) + else if ((mask & half1) == 0 && mask > half1) + return 6; + // return (b._vec0,*_vec1) + else if ((mask & half1) == half1 && mask > half1) + return 7; + // return (*_vec0,*_vec1) + return 8; +} + +// it can be used to emulate blend faster +constexpr int blendChoiceDbl(uint32_t mask) { + // clamp it 0 and 0xF + return blendChoice(mask, 0x3, 0xC); +} + +constexpr vbool32 VsxMask1(uint32_t mask) { + uint32_t g0 = (mask & 1) * 0xffffffff; + uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + return (vbool32){g0, g1, g2, g3}; +} + +constexpr vbool32 VsxMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xFF) >> 4; + return VsxMask1(mask2); +} + +constexpr vbool64 VsxDblMask1(uint32_t mask) { + uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + return (vbool64){g0, g1}; +} + +constexpr vbool64 VsxDblMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xF) >> 2; + return VsxDblMask1(mask2); +} + +constexpr int maskForComplex(uint32_t mask) { + mask = mask & 0xF; + int complex_mask = 0; + if (mask & 1) complex_mask |= 3; + if (mask & 2) complex_mask |= (3 << 2); + if (mask & 4) complex_mask |= (3 << 4); + if (mask & 8) complex_mask |= (3 << 6); + return complex_mask; +} + +constexpr int maskForComplexDbl(uint32_t mask) { + mask = mask & 0x3; + int complex_mask = 0; + if (mask & 1) complex_mask |= 3; + if (mask & 2) complex_mask |= (3 << 2); + return complex_mask; +} + +constexpr int blendChoiceComplex(uint32_t mask) { + return blendChoice(maskForComplex(mask)); +} + +constexpr int blendChoiceComplexDbl(uint32_t mask) { + return blendChoiceDbl(maskForComplexDbl(mask)); +} + +constexpr vbool32 VsxComplexMask1(uint32_t mask) { + return VsxMask1(maskForComplex(mask)); +} + +constexpr vbool32 VsxComplexMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xF) >> 2; + return VsxMask1(maskForComplex(mask2)); +} + +constexpr vbool64 VsxComplexDblMask1(uint32_t mask) { return VsxDblMask1(mask); } + +constexpr vbool64 VsxComplexDblMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xF) >> 2; + return VsxDblMask1(mask2); +} + +// constants +namespace at { +namespace vec256 { +// See Note [Acceptable use of anonymous namespace in header] +namespace { +// + constexpr int offset0 = 0; + constexpr int offset16 = 16; + +//#Constants +const vuint8 mask_zero_bits = vuint8{128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 96, 64, 32, 0}; + +const vuint8 swap_mask = + vuint8{4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11}; + +const vint32 v0x7f = vec_splats(0x7f); +const vint32 vi_0 = vec_splats((int)(0)); +const vint32 vi_1 = vec_splats((int)1); +const vint32 vi_2 = vec_splats((int)2); +const vint32 vi_4 = vec_splats((int)4); +const vint32 vi_inv1 = vec_splats((int)~1); +const vuint32 vu_29 = vec_splats(29u); +const vuint32 vu_23 = vec_splats(23u); + +const vbool32 inv_mant_mask = (vbool32)vec_splats((unsigned int)~0xff800000); +const vbool32 sign_mask = (vbool32)vec_splats((int)0x80000000); +const vbool32 real_mask = vbool32{0xFFFFFFFF, 0x0, 0xFFFFFFFF, 0x0}; +const vbool32 imag_mask = vbool32{0x0, 0xFFFFFFFF, 0x0, 0xFFFFFFFF}; +const vbool32 isign_mask = vbool32{0x0, 0x80000000, 0x0, 0x80000000}; +const vbool32 rsign_mask = vbool32{0x80000000, 0x0, 0x80000000, 0x0}; + +const vbool64 vd_imag_mask = vbool64{0x0, 0xFFFFFFFFFFFFFFFF}; +const vbool64 vd_real_mask = vbool64{0xFFFFFFFFFFFFFFFF, 0x0}; +const vbool64 vd_isign_mask = vbool64{0x0, 0x8000000000000000}; +const vbool64 vd_rsign_mask = vbool64{0x8000000000000000, 0x0}; + +const vfloat32 zero = vec_splats(0.f); +const vfloat32 half = vec_splats(0.5f); +const vfloat32 one = vec_splats(1.f); +const vfloat32 two = vec_splats(2.0f); +const vfloat32 _4div_pi = vec_splats(1.27323954473516f); +const vfloat32 v_inf = (vfloat32)vec_splats(0x7f800000u); +const vfloat32 v_minus_inf = vfloat32{ 0xff800000u, 0xff800000u, 0xff800000u, 0xff800000u }; +const vfloat32 v_nan = (vfloat32)vec_splats(0x7fffffff); +const vfloat32 log10e_inv = vec_splats(0.43429448190325176f); +const vfloat32 log2e_inv = vec_splats(1.4426950408889634f); +const vfloat32 log2eB_inv = vec_splats(1.442695036924675f); +const vfloat32 cephes_SQRTHF = vec_splats(0.707106781186547524f); +const vfloat32 coscof_p0 = vec_splats(2.443315711809948E-005f); +const vfloat32 coscof_p1 = vec_splats(-1.388731625493765E-003f); +const vfloat32 coscof_p2 = vec_splats(4.166664568298827E-002f); +const vfloat32 exp_hi = vec_splats(104.f); +const vfloat32 exp_lo = vec_splats(-104.f); +const vfloat32 exp_p0 = vec_splats(0.000198527617612853646278381f); +const vfloat32 exp_p1 = vec_splats((0.00139304355252534151077271f)); +const vfloat32 exp_p2 = vec_splats(0.00833336077630519866943359f); +const vfloat32 exp_p3 = vec_splats(0.0416664853692054748535156f); +const vfloat32 exp_p4 = vec_splats(0.166666671633720397949219f); +const vfloat32 exp_p5 = vec_splats(0.5f); +const vfloat32 log_p0 = vec_splats(7.0376836292E-2f); +const vfloat32 log_p1 = vec_splats(-1.1514610310E-1f); +const vfloat32 log_p2 = vec_splats(1.1676998740E-1f); +const vfloat32 log_p3 = vec_splats(-1.2420140846E-1f); +const vfloat32 log_p4 = vec_splats(+1.4249322787E-1f); +const vfloat32 log_p5 = vec_splats(-1.6668057665E-1f); +const vfloat32 log_p6 = vec_splats(+2.0000714765E-1f); +const vfloat32 log_p7 = vec_splats(-2.4999993993E-1f); +const vfloat32 log_p8 = vec_splats(+3.3333331174E-1f); +const vfloat32 log_q1 = vec_splats(-2.12194440e-4f); +const vfloat32 log_q2 = vec_splats(0.693359375f); +const vfloat32 max_logf = vec_splats(88.02969187150841f); +const vfloat32 max_numf = vec_splats(1.7014117331926442990585209174225846272e38f); +const vfloat32 min_inf = (vfloat32)vec_splats(0xff800000u); +const vfloat32 min_norm_pos = (vfloat32)vec_splats(0x0800000u); +const vfloat32 minus_cephes_dp1 = vec_splats(-0.78515625f); +const vfloat32 minus_cephes_dp2 = vec_splats(-2.4187564849853515625e-4f); +const vfloat32 minus_cephes_dp3 = vec_splats(-3.77489497744594108e-8f); +const vfloat32 negln2f_hi = vec_splats(-0.693145751953125f); +const vfloat32 negln2f_lo = vec_splats(-1.428606765330187045e-06f); +const vfloat32 p0 = vec_splats(2.03721912945E-4f); +const vfloat32 p1 = vec_splats(8.33028376239E-3f); +const vfloat32 p2 = vec_splats(1.66667160211E-1f); +const vfloat32 sincof_p0 = vec_splats(-1.9515295891E-4f); +const vfloat32 sincof_p1 = vec_splats(8.3321608736E-3f); +const vfloat32 sincof_p2 = vec_splats(-1.6666654611E-1f); +const vfloat32 tanh_0p625 = vec_splats(0.625f); +const vfloat32 tanh_half_max = vec_splats(44.014845935754205f); +const vfloat32 tanh_p0 = vec_splats(-5.70498872745E-3f); +const vfloat32 tanh_p1 = vec_splats(2.06390887954E-2f); +const vfloat32 tanh_p2 = vec_splats(-5.37397155531E-2f); +const vfloat32 tanh_p3 = vec_splats(1.33314422036E-1f); +const vfloat32 tanh_p4 = vec_splats(-3.33332819422E-1f); +const vfloat32 vcheck = vec_splats((float)(1LL << 24)); +const vfloat32 imag_one = vfloat32{0.f, 1.f, 0.f, 1.f}; +const vfloat32 imag_half = vfloat32{0.f, 0.5f, 0.f, 0.5f}; +const vfloat32 sqrt2_2 = vfloat32{0.70710676908493042f, 0.70710676908493042, + 0.70710676908493042, 0.70710676908493042}; +const vfloat32 pi_2 = vfloat32{M_PI / 2, 0.0, M_PI / 2, 0.0}; +const vfloat32 vf_89 = vfloat32{89.f, 89.f, 89.f, 89.f}; +const vfloat64 vd_one = vec_splats(1.0); +const vfloat64 vd_zero = vec_splats(0.0); +const vfloat64 vd_log10e_inv = vec_splats(0.43429448190325176); +const vfloat64 vd_log2e_inv = vec_splats(1.4426950408889634); +const vfloat64 vd_imag_one = vfloat64{0.0, 1.0}; +const vfloat64 vd_imag_half = vfloat64{0.0, 0.5}; +const vfloat64 vd_sqrt2_2 = vfloat64{0.70710678118654757, 0.70710678118654757}; +const vfloat64 vd_pi_2 = vfloat64{M_PI / 2.0, 0.0}; + +} // namespace +} // namespace vec256 +} // namespace at + diff --git a/aten/src/ATen/cuda/CUDAApplyUtils.cuh b/aten/src/ATen/cuda/CUDAApplyUtils.cuh index 3e4ea5a2b3c28..6810b51d3f701 100644 --- a/aten/src/ATen/cuda/CUDAApplyUtils.cuh +++ b/aten/src/ATen/cuda/CUDAApplyUtils.cuh @@ -356,9 +356,11 @@ template + int step, + int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK, + int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM> #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ -C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM) +C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) #endif __global__ void kernelPointwiseApply2(detail::TensorInfo a, @@ -400,7 +402,9 @@ inline dim3 getApplyBlock() { return dim3(AT_APPLY_THREADS_PER_BLOCK); } -template +template inline bool CUDA_tensor_apply2(at::Tensor a, at::Tensor b, const Op op, @@ -463,7 +467,9 @@ inline bool CUDA_tensor_apply2(at::Tensor a, kernelPointwiseApply2 \ + TYPE, A, B, step, \ + max_threads_per_block, \ + min_blocks_per_sm> \ <<>>( \ aInfo, bInfo, static_cast(totalElements), op); @@ -549,13 +555,16 @@ inline bool CUDA_tensor_apply2(at::Tensor a, } /* Provides default step = 1 to CUDA_tensor_apply2. */ -template +template inline bool CUDA_tensor_apply2(at::Tensor a, at::Tensor b, const Op op, TensorArgType aType = TensorArgType::ReadWrite, TensorArgType bType = TensorArgType::ReadOnly) { - return CUDA_tensor_apply2(a, b, op, aType, bType); + return CUDA_tensor_apply2(a, b, op, aType, bType); } } // cuda diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index c4e4793b19382..0521adf669c55 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -133,6 +133,56 @@ const char* _cublasGetErrorEnum(cublasStatus_t error) { /* LEVEL 3 BLAS FUNCTIONS */ +#ifndef __HIP_PLATFORM_HCC__ +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200 +#define cublasGemmStridedBatchedExFix cublasGemmStridedBatchedEx +#else +// Workaround for https://github.com/pytorch/pytorch/issues/45724 +cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const void *alpha, + const void *A, + cudaDataType Atype, + int lda, + long long int strideA, + const void *B, + cudaDataType Btype, + int ldb, + long long int strideB, + const void *beta, + void *C, + cudaDataType Ctype, + int ldc, + long long int strideC, + int64_t batchCount, + cudaDataType computeType, + cublasGemmAlgo_t algo) +{ + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + if (prop->major != 7) { + return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, batchCount, computeType, algo); + } + cublasStatus_t result; + constexpr int64_t split = 63 * 1024; + for(int64_t i = 0; i < batchCount; i += split) { + int64_t count = std::min(split, batchCount - i); + result = cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, + (char *)A + i * strideA * 2, Atype, lda, strideA, + (char *)B + i * strideB * 2, Btype, ldb, strideB, + beta, + (char *)C + i * strideC * 2, Ctype, ldc, strideC, + (int)count, computeType, algo); + TORCH_CUDABLAS_CHECK(result); + } + return result; +} +#endif +#endif + #define GEMM_CHECK_ARGVALUES(Dtype) \ do { \ CUDABLAS_NONNEGINT_CHECK(gemm, m); \ @@ -143,6 +193,161 @@ const char* _cublasGetErrorEnum(cublasStatus_t error) { CUDABLAS_POSINT_CHECK(gemm, ldc); \ } while (0) +#define BGEMM_CHECK_ARGVALUES(Dtype) \ + do { \ + CUDABLAS_NONNEGINT_CHECK(bgemm, m); \ + CUDABLAS_NONNEGINT_CHECK(bgemm, n); \ + CUDABLAS_NONNEGINT_CHECK(bgemm, k); \ + CUDABLAS_POSINT_CHECK(bgemm, lda); \ + CUDABLAS_POSINT_CHECK(bgemm, ldb); \ + CUDABLAS_POSINT_CHECK(bgemm, ldc); \ + CUDABLAS_NONNEGINT_CHECK(bgemm, num_batches); \ + } while (0) + +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(double)) { + // See Note [Writing Nondeterministic Operations] + globalContext().alertCuBLASConfigNotDeterministic(); + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + BGEMM_CHECK_ARGVALUES(double); + TORCH_CUDABLAS_CHECK(cublasDgemmStridedBatched( + handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)); +} + +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(float)) { + // See Note [Writing Nondeterministic Operations] + globalContext().alertCuBLASConfigNotDeterministic(); + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + BGEMM_CHECK_ARGVALUES(float); + TORCH_CUDABLAS_CHECK(cublasSgemmStridedBatched( + handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)); +} + +template <> +void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)) { + // See Note [Writing Nondeterministic Operations] + globalContext().alertCuBLASConfigNotDeterministic(); + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + BGEMM_CHECK_ARGVALUES(c10::complex); + TORCH_CUDABLAS_CHECK(cublasZgemmStridedBatched( + handle, opa, opb, m, n, k, reinterpret_cast(&alpha), reinterpret_cast(a), + lda, stridea, reinterpret_cast(b), ldb, strideb, reinterpret_cast(&beta), + reinterpret_cast(c), ldc, stridec, num_batches)); +} + +template <> +void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)) { + // See Note [Writing Nondeterministic Operations] + globalContext().alertCuBLASConfigNotDeterministic(); + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + BGEMM_CHECK_ARGVALUES(c10::complex); + TORCH_CUDABLAS_CHECK(cublasCgemmStridedBatched( + handle, opa, opb, m, n, k, reinterpret_cast(&alpha), reinterpret_cast(a), + lda, stridea, reinterpret_cast(b), ldb, strideb, reinterpret_cast(&beta), + reinterpret_cast(c), ldc, stridec, num_batches)); +} + +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::Half)) { + // See Note [Writing Nondeterministic Operations] + globalContext().alertCuBLASConfigNotDeterministic(); + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + BGEMM_CHECK_ARGVALUES(at::Half); + float falpha = alpha; + float fbeta = beta; +#ifdef __HIP_PLATFORM_HCC__ + TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k, + (void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea, + b, rocblas_datatype_f16_r, (int)ldb, strideb, + (void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec, + c, rocblas_datatype_f16_r, (int)ldc, stridec, + (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard, + 0, 0)); +#else + #if defined(CUDA_VERSION) && CUDA_VERSION < 11000 + // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH + // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required. + TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); + #endif // CUDA_VERSION < 11000 + + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + if (prop->major >= 5){ + TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix( + handle, opa, opb, m, n, k, + (void*)(&falpha), a, CUDA_R_16F, lda, stridea, + b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta), + c, CUDA_R_16F, ldc, stridec, + num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } else { + for (int64_t i = 0; i < num_batches; ++i) { + at::cuda::blas::gemm( + transa, transb, + m, n, k, + alpha, (a + i * stridea), lda, + (b + i * strideb), ldb, beta, + (c + i * stridec), ldc); + } + } + #if defined(CUDA_VERSION) && CUDA_VERSION < 11000 + // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH + // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required. + TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); + #endif // CUDA_VERSION < 11000 +#endif // __HIP_PLATFORM_HCC__ +} + +#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000 +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { + // See Note [Writing Nondeterministic Operations] + globalContext().alertCuBLASConfigNotDeterministic(); + BGEMM_CHECK_ARGVALUES(at::BFloat16); + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + float falpha = alpha; + float fbeta = beta; + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + + #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + TORCH_CHECK(prop->major >= 8, "BFloat16 bgemm in CUDA requires Ampere or later GPU"); + TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(handle, + opa, opb, (int)m, (int)n, (int)k, + (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, + b, CUDA_R_16BF, (int)ldb, strideb, + (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec, + (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + #elif defined(__HIP_PLATFORM_HCC__) + TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k, + (void*)&falpha, a, rocblas_datatype_bf16_r, (int)lda, stridea, + b, rocblas_datatype_bf16_r, (int)ldb, strideb, + (void*)&fbeta, c, rocblas_datatype_bf16_r, (int)ldc, stridec, + c, rocblas_datatype_bf16_r, (int)ldc, stridec, + (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard, + 0, 0, NULL, NULL)); + #else + TORCH_CHECK(false, "BFloat16 bgemm in CUDA requires Ampere or later GPU"); + #endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000 +} +#endif // __HIP_PLATFORM_HCC__ + template <> void gemm(CUDABLAS_GEMM_ARGTYPES(double)) { // See Note [Writing Nondeterministic Operations] @@ -374,7 +579,7 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required. TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); } else { - AT_ERROR("BFloat16 gemm in CUDA requires Ampere or later GPU"); + TORCH_CHECK(false, "BFloat16 gemm in CUDA requires Ampere or later GPU"); } } #endif @@ -586,20 +791,104 @@ void getrfBatched( handle, n, dA_array, ldda, ipiv_array, info_array, batchsize)); } +template <> +void getrfBatched>( + int n, + c10::complex** dA_array, + int ldda, + int* ipiv_array, + int* info_array, + int batchsize) { + auto handle = at::cuda::getCurrentCUDABlasHandle(); + TORCH_CUDABLAS_CHECK(cublasZgetrfBatched( + handle, + n, + reinterpret_cast(dA_array), + ldda, + ipiv_array, + info_array, + batchsize)); +} + +template <> +void getrfBatched>( + int n, + c10::complex** dA_array, + int ldda, + int* ipiv_array, + int* info_array, + int batchsize) { + auto handle = at::cuda::getCurrentCUDABlasHandle(); + TORCH_CUDABLAS_CHECK(cublasCgetrfBatched( + handle, + n, + reinterpret_cast(dA_array), + ldda, + ipiv_array, + info_array, + batchsize)); +} + template <> void getriBatched( - int n, double** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize, double** dC_array) { + int n, double** dA_array, int ldda, int* ipiv_array, double** dC_array, int lddc, int* info_array, int batchsize) { auto handle = at::cuda::getCurrentCUDABlasHandle(); TORCH_CUDABLAS_CHECK(cublasDgetriBatched( - handle, n, dA_array, ldda, ipiv_array, dC_array, n, info_array, batchsize)); + handle, n, dA_array, ldda, ipiv_array, dC_array, lddc, info_array, batchsize)); } template <> void getriBatched( - int n, float** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize, float** dC_array) { + int n, float** dA_array, int ldda, int* ipiv_array, float** dC_array, int lddc, int* info_array, int batchsize) { auto handle = at::cuda::getCurrentCUDABlasHandle(); TORCH_CUDABLAS_CHECK(cublasSgetriBatched( - handle, n, dA_array, ldda, ipiv_array, dC_array, n, info_array, batchsize)); + handle, n, dA_array, ldda, ipiv_array, dC_array, lddc, info_array, batchsize)); +} + +template <> +void getriBatched>( + int n, + c10::complex** dA_array, + int ldda, + int* ipiv_array, + c10::complex** dC_array, + int lddc, + int* info_array, + int batchsize) { + auto handle = at::cuda::getCurrentCUDABlasHandle(); + TORCH_CUDABLAS_CHECK(cublasZgetriBatched( + handle, + n, + reinterpret_cast(dA_array), + ldda, + ipiv_array, + reinterpret_cast(dC_array), + lddc, + info_array, + batchsize)); +} + +template <> +void getriBatched>( + int n, + c10::complex** dA_array, + int ldda, + int* ipiv_array, + c10::complex** dC_array, + int lddc, + int* info_array, + int batchsize) { + auto handle = at::cuda::getCurrentCUDABlasHandle(); + TORCH_CUDABLAS_CHECK(cublasCgetriBatched( + handle, + n, + reinterpret_cast(dA_array), + ldda, + ipiv_array, + reinterpret_cast(dC_array), + lddc, + info_array, + batchsize)); } #endif // CUDART_VERSION diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index 17236dc435db0..d44fc49c589a8 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -69,6 +69,31 @@ template <> void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); #endif +#define CUDABLAS_BGEMM_ARGTYPES(Dtype) \ + char transa, char transb, int64_t m, int64_t n, int64_t k, Dtype alpha, \ + const Dtype *a, int64_t lda, int64_t stridea, \ + const Dtype *b, int64_t ldb, int64_t strideb, \ + Dtype beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches + +template +inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { + AT_ERROR("at::cuda::blas::bgemm: not implemented for ", typeid(Dtype).name()); +} + +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(double)); +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(float)); +template <> +void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)); +template <> +void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)); +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::Half)); +#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000 +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +#endif /* LEVEL 2 BLAS FUNCTIONS */ #define CUDABLAS_GEMV_ARGTYPES(Dtype) \ @@ -97,18 +122,6 @@ template <> void gemv(CUDABLAS_GEMV_ARGTYPES(at::BFloat16)); #endif -template -void ger( - int64_t m, - int64_t n, - Dtype alpha, - Dtype* x, - int64_t incx, - Dtype* y, - int64_t incy, - Dtype* a, - int64_t lda); - /* LEVEL 1 BLAS FUNCTIONS */ #define CUDABLAS_DOT_ARGTYPES(Dtype) \ @@ -155,10 +168,14 @@ template<> void getrfBatched(CUDABLAS_GETRF_ARGTYPES(float)); template<> void getrfBatched(CUDABLAS_GETRF_ARGTYPES(double)); +template<> +void getrfBatched>(CUDABLAS_GETRF_ARGTYPES(c10::complex)); +template<> +void getrfBatched>(CUDABLAS_GETRF_ARGTYPES(c10::complex)); #define CUDABLAS_GETRI_ARGTYPES(Dtype) \ - int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize, Dtype** dC_array + int n, Dtype** dA_array, int ldda, int* ipiv_array, Dtype** dC_array, int lddc, int* info_array, int batchsize template void getriBatched(CUDABLAS_GETRI_ARGTYPES(Dtype)) { @@ -168,6 +185,10 @@ template<> void getriBatched(CUDABLAS_GETRI_ARGTYPES(float)); template<> void getriBatched(CUDABLAS_GETRI_ARGTYPES(double)); +template<> +void getriBatched>(CUDABLAS_GETRI_ARGTYPES(c10::complex)); +template<> +void getriBatched>(CUDABLAS_GETRI_ARGTYPES(c10::complex)); #endif // CUDART_VERSION diff --git a/aten/src/ATen/cuda/CUDAContext.cpp b/aten/src/ATen/cuda/CUDAContext.cpp index fd51cc45e7769..d656369c0d6cd 100644 --- a/aten/src/ATen/cuda/CUDAContext.cpp +++ b/aten/src/ATen/cuda/CUDAContext.cpp @@ -51,6 +51,16 @@ cudaDeviceProp* getDeviceProperties(int64_t device) { return &device_properties[device]; } +bool canDeviceAccessPeer(int64_t device, int64_t peer_device) { + std::call_once(init_flag, initCUDAContextVectors); + if (device == -1) device = c10::cuda::current_device(); + AT_ASSERT(device >= 0 && device < num_gpus); + AT_ASSERT(peer_device >= 0 && peer_device < num_gpus); + int can_access = 0; + AT_CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access, device, peer_device)); + return can_access != 0; +} + Allocator* getCUDADeviceAllocator() { return c10::cuda::CUDACachingAllocator::get(); } diff --git a/aten/src/ATen/cuda/CUDAContext.h b/aten/src/ATen/cuda/CUDAContext.h index e5e4a5462d57a..bf6a7242766bc 100644 --- a/aten/src/ATen/cuda/CUDAContext.h +++ b/aten/src/ATen/cuda/CUDAContext.h @@ -62,6 +62,8 @@ TORCH_CUDA_API int warp_size(); TORCH_CUDA_API cudaDeviceProp* getDeviceProperties(int64_t device); +TORCH_CUDA_API bool canDeviceAccessPeer(int64_t device, int64_t peer_device); + TORCH_CUDA_API Allocator* getCUDADeviceAllocator(); /* Handles */ diff --git a/aten/src/ATen/cuda/CUDADevice.h b/aten/src/ATen/cuda/CUDADevice.h index 9d14ab1627e06..2d0a682488fdd 100644 --- a/aten/src/ATen/cuda/CUDADevice.h +++ b/aten/src/ATen/cuda/CUDADevice.h @@ -11,7 +11,7 @@ namespace cuda { inline Device getDeviceFromPtr(void* ptr) { cudaPointerAttributes attr; AT_CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr)); - return {DeviceType::CUDA, static_cast(attr.device)}; + return {DeviceType::CUDA, static_cast(attr.device)}; } }} // namespace at::cuda diff --git a/aten/src/ATen/cuda/CUDAFuture.h b/aten/src/ATen/cuda/CUDAFuture.h new file mode 100644 index 0000000000000..5915d89e61d54 --- /dev/null +++ b/aten/src/ATen/cuda/CUDAFuture.h @@ -0,0 +1,142 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { namespace cuda { + +struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future { + public: + using at::ivalue::Future::Future; + + protected: + c10::intrusive_ptr createInstance(at::TypePtr type) override { + return c10::make_intrusive(std::move(type)); + } + + void postMarkCompletedHook(const at::IValue& value) override { + currentDevice_ = c10::cuda::current_device(); + + // Extract them once and cache them for later uses. + dataPtrs_ = extractDataPtrs(value); + + std::vector isCudaDeviceUsed(c10::cuda::device_count(), false); + for (const at::DataPtr& data_ptr : dataPtrs_) { + if (data_ptr.device().is_cuda()) { + isCudaDeviceUsed[data_ptr.device().index()] = true; + } + } + + for (c10::DeviceIndex idx = 0; idx < isCudaDeviceUsed.size(); idx++) { + if (isCudaDeviceUsed[idx]) { + at::cuda::CUDAEvent cudaEvent; + cudaEvent.record(at::cuda::getCurrentCUDAStream(idx)); + cudaEvents_.push_back(std::move(cudaEvent)); + } + } + } + + std::function wrapCallback( + std::function callback) override { + return [this, callback{std::move(callback)}]() { + // We'd love to get a stream for all devices, even those that are not used + // by the value, because the callback could use those other devices, but + // unfortunately this could cause a deadlock with NCCL. See + // https://github.com/pytorch/pytorch/pull/48500#issuecomment-735395414 + // In general, if some devices haven't been used yet, by getting a stream + // for them we'd initialize them, and in addition to causing NCCL to + // misbehaving this also ends up using memory on those devices, which the + // user might not want. + std::vector streams; + for (at::cuda::CUDAEvent& cudaEvent : cudaEvents_) { + c10::DeviceIndex idx = cudaEvent.device_index(); + // FIXME Should we find a way to allow to change the priority of + // streams? + at::cuda::CUDAStream stream = + at::cuda::getStreamFromPool(/*isHighPriority=*/false, idx); + cudaEvent.block(stream); + streams.push_back(stream); + } + + // Use the dedicated callback stream to run callback. + at::cuda::CUDAMultiStreamGuard streamGuard(streams); + + // Do not free the underlying data storage of value_ before its + // usage on the stream finishes. + for (const at::DataPtr& data_ptr : dataPtrs_) { + if (data_ptr.device().is_cuda()) { + c10::cuda::CUDACachingAllocator::recordStream( + data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index())); + } + } + + c10::cuda::CUDAGuard deviceGuard(currentDevice_); + + callback(); + }; + } + + void postWaitHook(const at::IValue& value) override { + for (at::cuda::CUDAEvent& cudaEvent : cudaEvents_) { + cudaEvent.block( + at::cuda::getCurrentCUDAStream(cudaEvent.device_index())); + } + + for (const at::DataPtr& data_ptr : dataPtrs_) { + if (data_ptr.device().is_cuda()) { + c10::cuda::CUDACachingAllocator::recordStream( + data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index())); + } + } + } + + virtual std::vector> extractDataPtrs( + const at::IValue& value) { + at::IValue::HashAliasedIValues sub_values; + // Prefer getSubValues() over visit() as the latter is a silent no-op for + // some unsupported types, whereas the former at least fails loudly. + value.getSubValues(sub_values); + + std::vector> data_ptrs; + for (const at::IValue& sub_value : sub_values) { + if (sub_value.isTensor()) { + data_ptrs.emplace_back(sub_value.toTensor().storage().data_ptr()); + } + } + return data_ptrs; + } + + private: + // The device that was current when markCompleted was called, which we'll + // restore when invoking callbacks. + c10::DeviceIndex currentDevice_; + + // The events that correspond to the completion of the async I/O kernels. They + // are recorded on the appropriate streams when the future is marked completed + // and can then be queried/waited/blocked on. There is one event for each + // distinct device on which the value's tensors reside. + std::vector cudaEvents_; + + // A cached version of the data ptrs extracted from the value when the future + // is first marked completed. + std::vector> dataPtrs_; +}; + +} // namespace cuda +} // namespace at diff --git a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp index ea7c015499ea3..9e0adee23fb1c 100644 --- a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp +++ b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp @@ -1,10 +1,15 @@ +#include #include +#include +#include #include #include namespace at { +namespace cuda { +namespace detail { -namespace cuda { namespace detail { +namespace { // Ensures we only call cudaGetDeviceCount only once. static std::once_flag num_gpu_init_flag; @@ -18,7 +23,7 @@ static std::deque cuda_gens_init_flag; // Default, global CUDA generators, one per GPU. static std::vector default_gens_cuda; -/* +/* * Populates the global variables related to CUDA generators * Warning: this function must only be called once! */ @@ -28,6 +33,8 @@ static void initCUDAGenVector(){ default_gens_cuda.resize(num_gpus); } +} // anonymous namespace + /** * PyTorch maintains a collection of default generators that get * initialized once. The purpose of these default generators is to @@ -71,79 +78,230 @@ Generator createCUDAGenerator(DeviceIndex device_index) { } // namespace detail } // namespace cuda +/** + * Note [Why enforce RNG offset % 4 == 0?] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Curand philox does allow offsets that aren't a multiple of 4. + * But jit kernels don't use curand, they use a custom "Philox" class (see + * torch/csrc/jit/tensorexpr/cuda_random.h or + * torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu). + * The "Philox" constructor computes offset/4 (a uint64_t division) to locate its + * internal start in its virtual bitstream viewed as 128-bit chunks, then, when called + * in a thread, returns one 32-bit chunk at a time from that start in the bitstream. + * In other words, if the incoming offset is not a multiple of 4, each thread + * might repeat some previously-generated 32-bit values in the bitstream. See + * https://github.com/pytorch/pytorch/pull/50169. + */ + /** * CUDAGeneratorImpl class implementation */ CUDAGeneratorImpl::CUDAGeneratorImpl(DeviceIndex device_index) : c10::GeneratorImpl{Device(DeviceType::CUDA, device_index), - DispatchKeySet(c10::DispatchKey::CUDA)} { } + DispatchKeySet(c10::DispatchKey::CUDA)} { + at::cuda::assertNotCapturing("Cannot construct a new CUDAGeneratorImpl"); +} /** * Sets the seed to be used by curandStatePhilox4_32_10 * Resets the philox_offset_per_thread_ to 0 - * + * * See Note [Acquire lock when using random generators] */ void CUDAGeneratorImpl::set_current_seed(uint64_t seed) { + at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::set_current_seed"); seed_ = seed; philox_offset_per_thread_ = 0; } +#define CAPTURE_DEFAULT_GENS_MSG \ +"In regions captured by CUDA graphs, you may only use the default CUDA RNG " \ +"generator on the device that's current when capture begins. " \ +"If you need a non-default (user-supplied) generator, or a generator on another " \ +"device, please file an issue." + /** * Gets the current seed of CUDAGeneratorImpl. */ uint64_t CUDAGeneratorImpl::current_seed() const { + // Debatable if current_seed() should be allowed in captured regions. + // Conservatively disallow it for now. + at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::current_seed"); return seed_; } /** * Gets a nondeterministic random number from /dev/urandom or time, * seeds the CPUGeneratorImpl with it and then returns that number. - * + * * FIXME: You can move this function to Generator.cpp if the algorithm * in getNonDeterministicRandom is unified for both CPU and CUDA */ uint64_t CUDAGeneratorImpl::seed() { + at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::seed"); auto random = c10::detail::getNonDeterministicRandom(true); this->set_current_seed(random); return random; } +/** + * Gets the current internal state of CUDAGeneratorImpl. The internal + * state is returned as a CPU byte tensor. + */ +c10::intrusive_ptr CUDAGeneratorImpl::get_state() const { + // The RNG state comprises the seed, and an offset used for Philox. + // The following line is just here for BC reason. sizeof curandStateMtgp32 is 4120. + // It used to be static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32); + // MAX_NUM_BLOCKS was 200 and sizeof(curandStateMtgp32) is 4120. Hardcoding these numbers here + // because this is just host side code and we don't want to worry about linking with cuda + static const size_t states_size = 200 * sizeof(4120); + static const size_t seed_size = sizeof(uint64_t); + static const size_t offset_size = sizeof(int64_t); + static const size_t total_size = states_size + seed_size + offset_size; + + auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); + auto rng_state = state_tensor.data_ptr(); + // since curandStateMTGP is not used anymore, fill gen_states of THCGenerator with deterministic garbage value of -1 + // gen_states in THCGenerator struct was an array of curandStateMtgp32s. + memset(rng_state, -1, states_size); + auto current_seed = this->current_seed(); + auto offset = static_cast(this->philox_offset_per_thread()); // Note that old THCGeneratorState had offset as std::atomic + memcpy(rng_state + states_size, ¤t_seed, seed_size); + memcpy(rng_state + states_size + seed_size, &offset, offset_size); + + return state_tensor.getIntrusivePtr(); +} + +/** + * Sets the internal state of CUDAGeneratorImpl. The new internal state + * must be a strided CPU byte tensor and have appropriate size. See + * comments of CUDAGeneratorImpl::state for information about the layout + * and size of the internal state. + */ +void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) { + static const size_t states_size = 200 * sizeof(4120); // this line is just here for BC reason + static const size_t seed_size = sizeof(uint64_t); + static const size_t offset_size = sizeof(int64_t); + static const size_t total_size = states_size + seed_size + offset_size; + + detail::check_rng_state(new_state); + + bool no_philox_seed = false; + auto new_state_size = new_state.numel(); + if (new_state_size == total_size - offset_size) { + no_philox_seed = true; + } else { + TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size"); + } + + uint64_t input_seed; + auto new_rng_state = new_state.data(); + memcpy(&input_seed, new_rng_state + states_size, seed_size); + this->set_current_seed(input_seed); + int64_t philox_offset = 0; + if (!no_philox_seed) { + memcpy(&philox_offset, new_rng_state + states_size + seed_size, offset_size); + } + this->set_philox_offset_per_thread(static_cast(philox_offset)); +} + /** * Sets the philox_offset_per_thread_ to be used by curandStatePhilox4_32_10 - * + * * See Note [Acquire lock when using random generators] */ void CUDAGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) { + at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::set_philox_offset_per_thread"); + // see Note [Why enforce RNG offset % 4 == 0?] + TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4"); philox_offset_per_thread_ = offset; } /** * Gets the current philox_offset_per_thread_ of CUDAGeneratorImpl. */ -uint64_t CUDAGeneratorImpl::philox_offset_per_thread() { +uint64_t CUDAGeneratorImpl::philox_offset_per_thread() const { + at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::philox_offset_per_thread"); return philox_offset_per_thread_; } +/** + * Called by CUDAGraph to prepare this instance for a graph capture region. + * offset_extragraph is the initial offset at the start of the graphed region. + * offset_intragraph tracks the offset in the graphed region. + */ +void CUDAGeneratorImpl::capture_prologue(int64_t* offset_extragraph) { + offset_extragraph_ = offset_extragraph; + offset_intragraph_ = 0; + graph_expects_this_gen_ = true; +} + +/** + * Called by CUDAGraph to finalize a graph capture region for this instance. + */ +uint64_t CUDAGeneratorImpl::capture_epilogue() { + graph_expects_this_gen_ = false; + return offset_intragraph_; +} + /** * Gets the seed and philox offset value to be used in - * curandStatePhilox4_32_10 - * + * curandStatePhilox4_32_10, in an opaque PhiloxCudaState that's safe + * and can be used non-divergently in callers whether CUDA graph + * capture is underway or not. See + * Note [CUDA Graph-safe RNG states] + * * Each kernel using philox has to sensibly increment offset * for future users of philox. So it gets the "old" value for * itself (before add), and tells subsequent users which offset * they should use, since only the kernel knows how many randoms - * it intends to generate. - * + * it intends to generate. + * * Increment should be at least the number of curand() random numbers used in - * each thread. It is the user's responsibility to make sure that the increment + * each thread. It is the user's responsibility to make sure the increment * for philox is never smaller than the number of curand() calls. Increment * value > the number of curand() calls won't harm but anything less would mean * that you would be reusing random values from previous calls. - * + * * See Note [Acquire lock when using random generators] */ +PhiloxCudaState CUDAGeneratorImpl::philox_cuda_state(uint64_t increment) { + // rounds increment up to the nearest multiple of 4 + increment = ((increment + 3) / 4) * 4; + if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) { + TORCH_CHECK(graph_expects_this_gen_, + "philox_cuda_state for an unexpected CUDA generator used during capture. " + CAPTURE_DEFAULT_GENS_MSG); + // see Note [Why enforce RNG offset % 4 == 0?] + TORCH_INTERNAL_ASSERT(this->offset_intragraph_ % 4 == 0); + uint32_t offset = this->offset_intragraph_; + TORCH_INTERNAL_ASSERT(this->offset_intragraph_ <= + std::numeric_limits::max() - increment); + this->offset_intragraph_ += increment; + return PhiloxCudaState(this->seed_, + this->offset_extragraph_, + offset); + } else { + TORCH_CHECK(!graph_expects_this_gen_, + "CUDA generator expects graph capture to be underway, " + "but the current stream is not capturing."); + // see Note [Why enforce RNG offset % 4 == 0?] + TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0); + uint64_t offset = this->philox_offset_per_thread_; + this->philox_offset_per_thread_ += increment; + return PhiloxCudaState(this->seed_, offset); + } +} + +/** + * Temporarily accommodates call sites that use philox_engine_inputs. + * Allows incremental refactor of call sites to use philox_cuda_state. + */ std::pair CUDAGeneratorImpl::philox_engine_inputs(uint64_t increment) { + at::cuda::assertNotCapturing("Refactor this op to use CUDAGeneratorImpl::philox_cuda_state. " + "Cannot call CUDAGeneratorImpl::philox_engine_inputs"); + // see Note [Why enforce RNG offset % 4 == 0?] + TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0); uint64_t offset = this->philox_offset_per_thread_; this->philox_offset_per_thread_ += increment; return std::make_pair(this->seed_, offset); @@ -159,7 +317,7 @@ DeviceType CUDAGeneratorImpl::device_type() { /** * Public clone method implementation - * + * * See Note [Acquire lock when using random generators] */ std::shared_ptr CUDAGeneratorImpl::clone() const { @@ -168,10 +326,11 @@ std::shared_ptr CUDAGeneratorImpl::clone() const { /** * Private clone method implementation - * + * * See Note [Acquire lock when using random generators] */ CUDAGeneratorImpl* CUDAGeneratorImpl::clone_impl() const { + at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::clone_impl"); auto gen = new CUDAGeneratorImpl(this->device().index()); gen->set_current_seed(this->seed_); gen->set_philox_offset_per_thread(this->philox_offset_per_thread_); diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp new file mode 100644 index 0000000000000..74cc5ca097939 --- /dev/null +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -0,0 +1,168 @@ +#include +#include +#include +#include + +namespace at { +namespace cuda { + +/** + * Note [CUDA Graph Wrapper Class] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Q: Why do we need graph capture and launch bindings in Pytorch? + * Why can't they live in a user extension, for example? + * + * A1: Convenience. + * A2: To ensure valid numerics on replay, some native CUDA ops (like RNG ops with + * CPU statefulness) need cooperation from the capture and replay bindings + * (see Note [CUDA Graph-safe RNG states] in CUDAGeneratorImpl.h). + * + * We can't expect users to know about this cooperation. If users write capture + * bindings naively in an extension, they likely won't interact with the native + * ops properly. Their graphs would yield invalid numerics on replay. + */ + +CUDAGraph::CUDAGraph() + // CUDAStreams may not be default-constructed. + : capture_stream_(at::cuda::getCurrentCUDAStream()) { +#if CUDA_VERSION < 11000 + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0"); +#endif +} + +void CUDAGraph::capture_begin() { +#if CUDA_VERSION >= 11000 + TORCH_CHECK(!has_graph_exec_, + "This CUDAGraph instance already owns a captured graph. " + "To capture a new graph, create a new instance."); + + // For now, a CUDAGraph instance only accommodates the default generator on the device that's + // current when capture begins. If any op in the captured region uses a non-default generator, + // or a generator on another device, the offending generator will throw an error. + // These restrictions simplify CUDAGraph, but could be relaxed in the future: + // in principle, the underlying Cuda calls do permit cross-device ops to be captured. + auto* gen = get_generator_or_default( + c10::nullopt, cuda::detail::getDefaultCUDAGenerator()); + + auto options = TensorOptions().device(at::kCUDA).dtype(at::kLong); + offset_extragraph_ = at::empty({1}, options); + + gen->capture_prologue(offset_extragraph_.data_ptr()); + + auto stream = at::cuda::getCurrentCUDAStream(); + + TORCH_CHECK(stream != at::cuda::getDefaultCUDAStream(), + "CUDA graphs must be captured on a non-default stream. " + "(However, after capture, it's ok to replay them on the " + "default stream.)"); + + capture_stream_ = stream; + capture_gen_ = gen; + + // cudaStreamCaptureModeGlobal is the most conservative option to + // prevent potentially unsafe CUDA API calls during capture. See + // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85 + AT_CUDA_CHECK(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal)); + + // Stashes the current graph's uuid. + cudaStreamCaptureStatus status; + AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &id_)); + TORCH_INTERNAL_ASSERT(status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive); +#else + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0"); +#endif +} + +void CUDAGraph::capture_end() { +#if CUDA_VERSION >= 11000 + auto stream = at::cuda::getCurrentCUDAStream(); + + TORCH_CHECK(stream == capture_stream_, + "Capture must end on the same stream it began on."); + + AT_CUDA_CHECK(cudaStreamEndCapture(capture_stream_, &graph_)); + TORCH_CHECK(graph_ != NULL, "Invalid capture."); + has_graph_ = true; + + // Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people, + // who prefer not to report error message through these arguments moving forward + // (they prefer return value, or errors on api calls internal to the capture) + AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0)); + has_graph_exec_ = true; + + auto* gen = get_generator_or_default( + c10::nullopt, cuda::detail::getDefaultCUDAGenerator()); + TORCH_CHECK(gen == capture_gen_, + "Default CUDA RNG generator on current device at capture end " + "is different from default generator on current device " + "when capture began"); + wholegraph_increment_ = gen->capture_epilogue(); + + // Now that we've instantiated graph_ into graph_exec_, + // we don't need graph_ anymore. + AT_CUDA_CHECK(cudaGraphDestroy(graph_)); + has_graph_ = false; +#else + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0"); +#endif +} + +void CUDAGraph::replay() { +#if CUDA_VERSION >= 11000 + TORCH_CHECK(has_graph_exec_, + "Called CUDAGraph::replay without a preceding successful capture."); + + { + c10::OptionalDeviceGuard device_guard{capture_stream_.device()}; + + // Just like any RNG consumer kernel! + auto* gen = get_generator_or_default( + c10::nullopt, cuda::detail::getDefaultCUDAGenerator()); + PhiloxCudaState rng_engine_inputs; + { + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_cuda_state(wholegraph_increment_); + } + offset_extragraph_.fill_(int64_t(rng_engine_inputs.offset_.val)); + + // graph_exec_ may be replayed in any stream. + AT_CUDA_CHECK(cudaGraphLaunch(graph_exec_, at::cuda::getCurrentCUDAStream())); + } +#else + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0"); +#endif +} + +void CUDAGraph::reset() { +#if CUDA_VERSION >= 11000 + // I'd prefer these checks throw exceptions, not print warnings, + // but the destructor calls reset(), and at least one CI build + // refuses to compile with a throwing destructor. + // + // Instead of calling reset() in the destructor to clean up, I could + // call reset() in the __del__ method of a thin Python wrapper, + // in which case reset would be allowed to throw exceptions. + // But Stackoverflow does not like user-defined __del__. + // __del__ prevents Graph instances from EVER being garbage collected + // if they participate in a reference cycle. + // And exceptions thrown in __del__ only print a warning anyway. + // + // Calling reset() in the C++ destructor, with warnings instead of exceptions + // if calls fail, is the compromise we chose. + if (has_graph_) { + C10_CUDA_CHECK_WARN(cudaGraphDestroy(graph_)); + } + if (has_graph_exec_) { + C10_CUDA_CHECK_WARN(cudaGraphExecDestroy(graph_exec_)); + } +#else + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0"); +#endif +} + +CUDAGraph::~CUDAGraph() { + reset(); +} + +} // namespace cuda +} // namespace at diff --git a/aten/src/ATen/cuda/CUDAGraph.h b/aten/src/ATen/cuda/CUDAGraph.h new file mode 100644 index 0000000000000..3872717150550 --- /dev/null +++ b/aten/src/ATen/cuda/CUDAGraph.h @@ -0,0 +1,43 @@ +#include +#include +#include +#include + +namespace at { +namespace cuda { + +struct TORCH_CUDA_API CUDAGraph { + CUDAGraph(); + ~CUDAGraph(); + + void capture_begin(); + void capture_end(); + void replay(); + void reset(); + + protected: +#if CUDA_VERSION >= 11000 + cudaGraph_t graph_ = NULL; + cudaGraphExec_t graph_exec_ = NULL; +#endif + + // internal states for error checking + bool has_graph_ = false; + bool has_graph_exec_ = false; + + // uuid, retrieved from Cuda + unsigned long long id_; + + // Stream on which capture began + at::cuda::CUDAStream capture_stream_; + + // Default generator on device where capture began + at::CUDAGeneratorImpl* capture_gen_; + + // RNG state trackers + at::Tensor offset_extragraph_; + uint64_t wholegraph_increment_; +}; + +} // namespace cuda +} // namespace at diff --git a/aten/src/ATen/cuda/CUDAGraphsUtils.cuh b/aten/src/ATen/cuda/CUDAGraphsUtils.cuh new file mode 100644 index 0000000000000..4b2d09ad74d43 --- /dev/null +++ b/aten/src/ATen/cuda/CUDAGraphsUtils.cuh @@ -0,0 +1,97 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at { +namespace cuda { +namespace philox { + +// We can't write a __device__ function in CUDAGeneratorImpl.h, because it's in ATen. +// Also, whatever call unpacks PhiloxCudaState in consumer kernels must be inlineable. +// Easiest thing that comes to mind is, define a free function here, in ATen/cuda. +// Any cuda consumer can include this header. +__device__ __forceinline__ std::tuple +unpack(at::PhiloxCudaState arg) { + if (arg.captured_) { + return std::make_tuple(arg.seed_, *(arg.offset_.ptr) + arg.offset_intragraph_); + } else { + return std::make_tuple(arg.seed_, arg.offset_.val); + } +} + +} // namespace philox + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 +// Protects against enum cudaStreamCaptureStatus implementation changes. +// Some compilers seem not to like static_assert without the messages. +static_assert(int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0, + "unexpected int(cudaStreamCaptureStatusNone) value"); +static_assert(int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1, + "unexpected int(cudaStreamCaptureStatusActive) value"); +static_assert(int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2, + "unexpected int(cudaStreamCaptureStatusInvalidated) value"); +#endif + +enum class CaptureStatus: int { + #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + None = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone), + Active = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive), + Invalidated = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) + #else + None = 0 + #endif +}; + +inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) { + switch(status) { + case CaptureStatus::None: + os << "cudaStreamCaptureStatusNone"; + break; + #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + case CaptureStatus::Active: + os << "cudaStreamCaptureStatusActive"; + break; + case CaptureStatus::Invalidated: + os << "cudaStreamCaptureStatusInvalidated"; + break; + #endif + default: + TORCH_INTERNAL_ASSERT(false, + "Unknown CUDA graph CaptureStatus", + int(status)); + } + return os; +} + +inline CaptureStatus currentStreamCaptureStatus() { + #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + // don't create a context if we don't have to + if (at::detail::getCUDAHooks().hasPrimaryContext(c10::cuda::current_device())) { + cudaStreamCaptureStatus is_capturing; + AT_CUDA_CHECK(cudaStreamIsCapturing(at::cuda::getCurrentCUDAStream(), + &is_capturing)); + return CaptureStatus(is_capturing); + } else { + return CaptureStatus::None; + } + #else + return CaptureStatus::None; + #endif +} + +inline void assertNotCapturing(std::string attempt) { + auto status = currentStreamCaptureStatus(); + TORCH_CHECK(status == CaptureStatus::None, + attempt, + " during CUDA graph capture. If you need this call to be captured, " + "please file an issue. " + "Current cudaStreamCaptureStatus: ", + status); +} + +} // namespace cuda +} // namespace at diff --git a/aten/src/ATen/cuda/CUDASolver.cpp b/aten/src/ATen/cuda/CUDASolver.cpp index acaa4234d3e6b..bcd630a06b9e8 100644 --- a/aten/src/ATen/cuda/CUDASolver.cpp +++ b/aten/src/ATen/cuda/CUDASolver.cpp @@ -16,9 +16,9 @@ void getrf( TORCH_CUSOLVER_CHECK( cusolverDnDgetrf_bufferSize(handle, m, n, dA, ldda, &lwork)); auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); - void* buffer = allocator.allocate(sizeof(double)*lwork).get(); + auto dataPtr = allocator.allocate(sizeof(double)*lwork); TORCH_CUSOLVER_CHECK(cusolverDnDgetrf( - handle, m, n, dA, ldda, static_cast(buffer), ipiv, info)); + handle, m, n, dA, ldda, static_cast(dataPtr.get()), ipiv, info)); } template <> @@ -28,9 +28,59 @@ void getrf( TORCH_CUSOLVER_CHECK( cusolverDnSgetrf_bufferSize(handle, m, n, dA, ldda, &lwork)); auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); - void* buffer = allocator.allocate(sizeof(float)*lwork).get(); + auto dataPtr = allocator.allocate(sizeof(float)*lwork); TORCH_CUSOLVER_CHECK(cusolverDnSgetrf( - handle, m, n, dA, ldda, static_cast(buffer), ipiv, info)); + handle, m, n, dA, ldda, static_cast(dataPtr.get()), ipiv, info)); +} + +template <> +void getrf>( + cusolverDnHandle_t handle, + int m, + int n, + c10::complex* dA, + int ldda, + int* ipiv, + int* info) { + int lwork; + TORCH_CUSOLVER_CHECK(cusolverDnZgetrf_bufferSize( + handle, m, n, reinterpret_cast(dA), ldda, &lwork)); + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + auto dataPtr = allocator.allocate(sizeof(cuDoubleComplex) * lwork); + TORCH_CUSOLVER_CHECK(cusolverDnZgetrf( + handle, + m, + n, + reinterpret_cast(dA), + ldda, + static_cast(dataPtr.get()), + ipiv, + info)); +} + +template <> +void getrf>( + cusolverDnHandle_t handle, + int m, + int n, + c10::complex* dA, + int ldda, + int* ipiv, + int* info) { + int lwork; + TORCH_CUSOLVER_CHECK(cusolverDnCgetrf_bufferSize( + handle, m, n, reinterpret_cast(dA), ldda, &lwork)); + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + auto dataPtr = allocator.allocate(sizeof(cuComplex) * lwork); + TORCH_CUSOLVER_CHECK(cusolverDnCgetrf( + handle, + m, + n, + reinterpret_cast(dA), + ldda, + static_cast(dataPtr.get()), + ipiv, + info)); } template <> @@ -47,6 +97,54 @@ void getrs( handle, CUBLAS_OP_N, n, nrhs, dA, lda, ipiv, ret, ldb, info)); } +template <> +void getrs>( + cusolverDnHandle_t handle, + int n, + int nrhs, + c10::complex* dA, + int lda, + int* ipiv, + c10::complex* ret, + int ldb, + int* info) { + TORCH_CUSOLVER_CHECK(cusolverDnZgetrs( + handle, + CUBLAS_OP_N, + n, + nrhs, + reinterpret_cast(dA), + lda, + ipiv, + reinterpret_cast(ret), + ldb, + info)); +} + +template <> +void getrs>( + cusolverDnHandle_t handle, + int n, + int nrhs, + c10::complex* dA, + int lda, + int* ipiv, + c10::complex* ret, + int ldb, + int* info) { + TORCH_CUSOLVER_CHECK(cusolverDnCgetrs( + handle, + CUBLAS_OP_N, + n, + nrhs, + reinterpret_cast(dA), + lda, + ipiv, + reinterpret_cast(ret), + ldb, + info)); +} + } // namespace solver } // namespace cuda } // namespace at diff --git a/aten/src/ATen/cuda/CUDASolver.h b/aten/src/ATen/cuda/CUDASolver.h index 06609409f177f..327c7b824c5e2 100644 --- a/aten/src/ATen/cuda/CUDASolver.h +++ b/aten/src/ATen/cuda/CUDASolver.h @@ -19,6 +19,10 @@ template<> void getrf(CUDASOLVER_GETRF_ARGTYPES(float)); template<> void getrf(CUDASOLVER_GETRF_ARGTYPES(double)); +template<> +void getrf>(CUDASOLVER_GETRF_ARGTYPES(c10::complex)); +template<> +void getrf>(CUDASOLVER_GETRF_ARGTYPES(c10::complex)); #define CUDASOLVER_GETRS_ARGTYPES(Dtype) \ @@ -32,6 +36,10 @@ template<> void getrs(CUDASOLVER_GETRS_ARGTYPES(float)); template<> void getrs(CUDASOLVER_GETRS_ARGTYPES(double)); +template<> +void getrs>(CUDASOLVER_GETRS_ARGTYPES(c10::complex)); +template<> +void getrs>(CUDASOLVER_GETRS_ARGTYPES(c10::complex)); } // namespace solver diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp index 0165c53ac60d0..82421f49de1e0 100644 --- a/aten/src/ATen/cuda/CublasHandlePool.cpp +++ b/aten/src/ATen/cuda/CublasHandlePool.cpp @@ -50,6 +50,15 @@ cublasHandle_t getCurrentCUDABlasHandle() { } else { TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); } +#endif +#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 308 + rocblas_atomics_mode rocblas_mode; + if (at::globalContext().deterministic()) { + rocblas_mode = rocblas_atomics_not_allowed; + } else { + rocblas_mode = rocblas_atomics_allowed; + } + TORCH_CUDABLAS_CHECK(rocblas_set_atomics_mode(handle, rocblas_mode)); #endif return handle; } diff --git a/aten/src/ATen/cuda/DeviceUtils.cuh b/aten/src/ATen/cuda/DeviceUtils.cuh index 0bc3a3e505e9f..52926a84893cd 100644 --- a/aten/src/ATen/cuda/DeviceUtils.cuh +++ b/aten/src/ATen/cuda/DeviceUtils.cuh @@ -1,3 +1,5 @@ +#pragma once + #include #include #include diff --git a/aten/src/ATen/cuda/Exceptions.h b/aten/src/ATen/cuda/Exceptions.h index 615ba3e92b71a..80e39c6bc6bc4 100644 --- a/aten/src/ATen/cuda/Exceptions.h +++ b/aten/src/ATen/cuda/Exceptions.h @@ -19,20 +19,23 @@ class CuDNNError : public c10::Error { } // namespace c10 +#define AT_CUDNN_CHECK_WITH_SHAPES(EXPR, ...) AT_CUDNN_CHECK(EXPR, "\n", ##__VA_ARGS__) + // See Note [CHECK macro] -#define AT_CUDNN_CHECK(EXPR) \ - do { \ - cudnnStatus_t status = EXPR; \ - if (status != CUDNN_STATUS_SUCCESS) { \ - if (status == CUDNN_STATUS_NOT_SUPPORTED) { \ - TORCH_CHECK_WITH(CuDNNError, false, \ - "cuDNN error: ", \ - cudnnGetErrorString(status), \ - ". This error may appear if you passed in a non-contiguous input."); \ - } else { \ - TORCH_CHECK_WITH(CuDNNError, false, "cuDNN error: ", cudnnGetErrorString(status)); \ - } \ - } \ +#define AT_CUDNN_CHECK(EXPR, ...) \ + do { \ + cudnnStatus_t status = EXPR; \ + if (status != CUDNN_STATUS_SUCCESS) { \ + if (status == CUDNN_STATUS_NOT_SUPPORTED) { \ + TORCH_CHECK_WITH(CuDNNError, false, \ + "cuDNN error: ", \ + cudnnGetErrorString(status), \ + ". This error may appear if you passed in a non-contiguous input.", ##__VA_ARGS__); \ + } else { \ + TORCH_CHECK_WITH(CuDNNError, false, \ + "cuDNN error: ", cudnnGetErrorString(status), ##__VA_ARGS__); \ + } \ + } \ } while (0) namespace at { namespace cuda { namespace blas { diff --git a/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp b/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp index c43d53751aee2..7abd99c4e4f05 100644 --- a/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp +++ b/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp @@ -1,7 +1,5 @@ #include -// @generated by aten/src/ATen/gen.py from LegacyTHFunctions.cpp - #include #include #include @@ -435,144 +433,6 @@ Tensor & _th_index_copy_(Tensor & self, int64_t dim, const Tensor & index, const } return self; } -Tensor & _th_take_out(Tensor & result, const Tensor & self, const Tensor & index) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Bool: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long); - THCudaBoolTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - case ScalarType::Byte: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long); - THCudaByteTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - case ScalarType::Char: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long); - THCudaCharTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - case ScalarType::Double: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long); - THCudaDoubleTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - case ScalarType::Float: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long); - THCudaTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - case ScalarType::Int: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long); - THCudaIntTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - case ScalarType::Long: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long); - THCudaLongTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - case ScalarType::Short: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long); - THCudaShortTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - case ScalarType::Half: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take_out", false, DeviceType::CUDA, ScalarType::Long); - THCudaHalfTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - default: - AT_ERROR("_th_take_out not supported on CUDAType for ", dispatch_scalar_type); - } - return result; -} -Tensor _th_take(const Tensor & self, const Tensor & index) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - auto result_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto result = Tensor(c10::intrusive_ptr::reclaim(result_)); - switch (dispatch_scalar_type) { - case ScalarType::Bool: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long); - THCudaBoolTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - case ScalarType::Byte: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long); - THCudaByteTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - case ScalarType::Char: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long); - THCudaCharTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long); - THCudaDoubleTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long); - THCudaTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - case ScalarType::Int: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long); - THCudaIntTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - case ScalarType::Long: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long); - THCudaLongTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - case ScalarType::Short: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long); - THCudaShortTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_take", false, DeviceType::CUDA, dispatch_scalar_type); - auto index_ = checked_dense_tensor_unwrap(index, "index", 2, "_th_take", false, DeviceType::CUDA, ScalarType::Long); - THCudaHalfTensor_take(globalContext().getTHCState(), result_, self_, index_); - break; - } - default: - AT_ERROR("_th_take not supported on CUDAType for ", dispatch_scalar_type); - } - return result; -} Tensor & _th_put_(Tensor & self, const Tensor & index, const Tensor & source, bool accumulate) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); @@ -838,10 +698,12 @@ std::tuple _th_mode(const Tensor & self, int64_t dim, bool keepdi } return std::tuple(values, indices); } -std::tuple _th_sort_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t dim, bool descending) { +std::tuple _th_sort_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t dim, bool descending, bool stable) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); + TORCH_CHECK(!stable, "stable=True is not implemented on CUDA yet."); + switch (dispatch_scalar_type) { case ScalarType::Byte: { auto values_ = checked_dense_tensor_unwrap(values, "values", 0, "_th_sort_out", false, DeviceType::CUDA, dispatch_scalar_type); @@ -904,8 +766,11 @@ std::tuple _th_sort_out(Tensor & values, Tensor & indices, co } return std::tuple(values, indices); } -std::tuple _th_sort(const Tensor & self, int64_t dim, bool descending) { +std::tuple _th_sort(const Tensor & self, int64_t dim, bool descending, bool stable) { // DeviceGuard omitted + + TORCH_CHECK(!stable, "stable=True is not implemented on CUDA yet."); + auto dispatch_scalar_type = infer_scalar_type(self); auto values_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); auto values = Tensor(c10::intrusive_ptr::reclaim(values_)); @@ -1135,877 +1000,182 @@ Tensor _th_renorm(const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) { THCudaDoubleTensor_renorm(globalContext().getTHCState(), result_, self_, p_, dim, maxnorm_); break; } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_renorm", false, DeviceType::CUDA, dispatch_scalar_type); - auto p_ = p.toFloat(); - auto maxnorm_ = maxnorm.toFloat(); - THCudaTensor_renorm(globalContext().getTHCState(), result_, self_, p_, dim, maxnorm_); - break; - } - case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_renorm", false, DeviceType::CUDA, dispatch_scalar_type); - auto p_ = p.toHalf(); - auto maxnorm_ = maxnorm.toHalf(); - THCudaHalfTensor_renorm(globalContext().getTHCState(), result_, self_, p_, dim, maxnorm_); - break; - } - default: - AT_ERROR("_th_renorm not supported on CUDAType for ", dispatch_scalar_type); - } - return result; -} -Tensor & _th_renorm_(Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_renorm_", false, DeviceType::CUDA, dispatch_scalar_type); - auto p_ = p.toDouble(); - auto maxnorm_ = maxnorm.toDouble(); - THCudaDoubleTensor_renorm(globalContext().getTHCState(), self_, self_, p_, dim, maxnorm_); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_renorm_", false, DeviceType::CUDA, dispatch_scalar_type); - auto p_ = p.toFloat(); - auto maxnorm_ = maxnorm.toFloat(); - THCudaTensor_renorm(globalContext().getTHCState(), self_, self_, p_, dim, maxnorm_); - break; - } - case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_renorm_", false, DeviceType::CUDA, dispatch_scalar_type); - auto p_ = p.toHalf(); - auto maxnorm_ = maxnorm.toHalf(); - THCudaHalfTensor_renorm(globalContext().getTHCState(), self_, self_, p_, dim, maxnorm_); - break; - } - default: - AT_ERROR("_th_renorm_ not supported on CUDAType for ", dispatch_scalar_type); - } - return self; -} -Tensor & _th_fmod_out(Tensor & result, const Tensor & self, Scalar other) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Byte: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toByte(); - THCudaByteTensor_fmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Char: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toChar(); - THCudaCharTensor_fmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Double: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toDouble(); - THCudaDoubleTensor_fmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Float: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toFloat(); - THCudaTensor_fmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Int: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toInt(); - THCudaIntTensor_fmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Long: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toLong(); - THCudaLongTensor_fmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Short: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toShort(); - THCudaShortTensor_fmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Half: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toHalf(); - THCudaHalfTensor_fmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - default: - AT_ERROR("_th_fmod_out not supported on CUDAType for ", dispatch_scalar_type); - } - return result; -} -Tensor _th_fmod(const Tensor & self, Scalar other) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - auto result_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto result = Tensor(c10::intrusive_ptr::reclaim(result_)); - switch (dispatch_scalar_type) { - case ScalarType::Byte: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toByte(); - THCudaByteTensor_fmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Char: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toChar(); - THCudaCharTensor_fmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toDouble(); - THCudaDoubleTensor_fmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toFloat(); - THCudaTensor_fmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Int: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toInt(); - THCudaIntTensor_fmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Long: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toLong(); - THCudaLongTensor_fmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Short: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toShort(); - THCudaShortTensor_fmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toHalf(); - THCudaHalfTensor_fmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - default: - AT_ERROR("_th_fmod not supported on CUDAType for ", dispatch_scalar_type); - } - return result; -} -Tensor & _th_fmod_out(Tensor & result, const Tensor & self, const Tensor & other) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Byte: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaByteTensor_cfmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Char: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaCharTensor_cfmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Double: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaDoubleTensor_cfmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Float: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaTensor_cfmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Int: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaIntTensor_cfmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Long: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaLongTensor_cfmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Short: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaShortTensor_cfmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Half: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_fmod_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaHalfTensor_cfmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - default: - AT_ERROR("_th_fmod_out not supported on CUDAType for ", dispatch_scalar_type); - } - return result; -} -Tensor _th_fmod(const Tensor & self, const Tensor & other) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - auto result_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto result = Tensor(c10::intrusive_ptr::reclaim(result_)); - switch (dispatch_scalar_type) { - case ScalarType::Byte: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaByteTensor_cfmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Char: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaCharTensor_cfmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaDoubleTensor_cfmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaTensor_cfmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Int: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaIntTensor_cfmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Long: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaLongTensor_cfmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Short: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaShortTensor_cfmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_fmod", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaHalfTensor_cfmod(globalContext().getTHCState(), result_, self_, other_); - break; - } - default: - AT_ERROR("_th_fmod not supported on CUDAType for ", dispatch_scalar_type); - } - return result; -} -Tensor & _th_fmod_(Tensor & self, Scalar other) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Byte: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toByte(); - THCudaByteTensor_fmod(globalContext().getTHCState(), self_, self_, other_); - break; - } - case ScalarType::Char: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toChar(); - THCudaCharTensor_fmod(globalContext().getTHCState(), self_, self_, other_); - break; - } - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toDouble(); - THCudaDoubleTensor_fmod(globalContext().getTHCState(), self_, self_, other_); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toFloat(); - THCudaTensor_fmod(globalContext().getTHCState(), self_, self_, other_); - break; - } - case ScalarType::Int: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toInt(); - THCudaIntTensor_fmod(globalContext().getTHCState(), self_, self_, other_); - break; - } - case ScalarType::Long: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toLong(); - THCudaLongTensor_fmod(globalContext().getTHCState(), self_, self_, other_); - break; - } - case ScalarType::Short: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toShort(); - THCudaShortTensor_fmod(globalContext().getTHCState(), self_, self_, other_); - break; - } - case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = other.toHalf(); - THCudaHalfTensor_fmod(globalContext().getTHCState(), self_, self_, other_); - break; - } - default: - AT_ERROR("_th_fmod_ not supported on CUDAType for ", dispatch_scalar_type); - } - return self; -} -Tensor & _th_fmod_(Tensor & self, const Tensor & other) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Byte: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 3, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaByteTensor_cfmod(globalContext().getTHCState(), self_, self_, other_); - break; - } - case ScalarType::Char: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 3, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaCharTensor_cfmod(globalContext().getTHCState(), self_, self_, other_); - break; - } - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 3, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaDoubleTensor_cfmod(globalContext().getTHCState(), self_, self_, other_); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 3, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaTensor_cfmod(globalContext().getTHCState(), self_, self_, other_); - break; - } - case ScalarType::Int: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 3, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaIntTensor_cfmod(globalContext().getTHCState(), self_, self_, other_); - break; - } - case ScalarType::Long: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 3, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaLongTensor_cfmod(globalContext().getTHCState(), self_, self_, other_); - break; - } - case ScalarType::Short: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 3, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaShortTensor_cfmod(globalContext().getTHCState(), self_, self_, other_); - break; - } - case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 3, "_th_fmod_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaHalfTensor_cfmod(globalContext().getTHCState(), self_, self_, other_); - break; - } - default: - AT_ERROR("_th_fmod_ not supported on CUDAType for ", dispatch_scalar_type); - } - return self; -} -Tensor & _th_cross_kernel_out(Tensor & result, const Tensor & self, const Tensor & other, int64_t dim) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Byte: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaByteTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); - break; - } - case ScalarType::Char: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaCharTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); - break; - } - case ScalarType::Double: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaDoubleTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); - break; - } - case ScalarType::Float: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); - break; - } - case ScalarType::Int: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaIntTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); - break; - } - case ScalarType::Long: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaLongTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); - break; - } - case ScalarType::Short: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaShortTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); - break; - } - case ScalarType::Half: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaHalfTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); - break; - } - default: - AT_ERROR("_th_cross_kernel_out not supported on CUDAType for ", dispatch_scalar_type); - } - return result; -} -Tensor _th_cross_kernel(const Tensor & self, const Tensor & other, int64_t dim) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - auto result_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto result = Tensor(c10::intrusive_ptr::reclaim(result_)); - switch (dispatch_scalar_type) { - case ScalarType::Byte: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaByteTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); - break; - } - case ScalarType::Char: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaCharTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); - break; - } - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaDoubleTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); - break; - } - case ScalarType::Int: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaIntTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); - break; - } - case ScalarType::Long: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaLongTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); - break; - } - case ScalarType::Short: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaShortTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); - break; - } - case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); - auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaHalfTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); - break; - } - default: - AT_ERROR("_th_cross_kernel not supported on CUDAType for ", dispatch_scalar_type); - } - return result; -} -Tensor & _th_bmm_out(Tensor & result, const Tensor & self, const Tensor & mat2) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Byte: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaByteTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, uint8_t(0), uint8_t(1)); - break; - } - case ScalarType::Char: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaCharTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, int8_t(0), int8_t(1)); - break; - } - case ScalarType::Double: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaDoubleTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, double(0), double(1)); - break; - } - case ScalarType::Float: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, float(0), float(1)); - break; - } - case ScalarType::Int: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaIntTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, int(0), int(1)); - break; - } - case ScalarType::Long: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaLongTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, int64_t(0), int64_t(1)); - break; - } - case ScalarType::Short: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaShortTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, int16_t(0), int16_t(1)); - break; - } - case ScalarType::Half: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaHalfTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, Half(0), Half(1)); + case ScalarType::Float: { + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_renorm", false, DeviceType::CUDA, dispatch_scalar_type); + auto p_ = p.toFloat(); + auto maxnorm_ = maxnorm.toFloat(); + THCudaTensor_renorm(globalContext().getTHCState(), result_, self_, p_, dim, maxnorm_); break; } - case ScalarType::BFloat16: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaBFloat16Tensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, BFloat16(0), BFloat16(1)); + case ScalarType::Half: { + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_renorm", false, DeviceType::CUDA, dispatch_scalar_type); + auto p_ = p.toHalf(); + auto maxnorm_ = maxnorm.toHalf(); + THCudaHalfTensor_renorm(globalContext().getTHCState(), result_, self_, p_, dim, maxnorm_); break; } default: - AT_ERROR("_th_bmm_out not supported on CUDAType for ", dispatch_scalar_type); + AT_ERROR("_th_renorm not supported on CUDAType for ", dispatch_scalar_type); } return result; } -Tensor _th_bmm(const Tensor & self, const Tensor & mat2) { +Tensor & _th_renorm_(Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); - auto result_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto result = Tensor(c10::intrusive_ptr::reclaim(result_)); + switch (dispatch_scalar_type) { - case ScalarType::Byte: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaByteTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, uint8_t(0), uint8_t(1)); - break; - } - case ScalarType::Char: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaCharTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, int8_t(0), int8_t(1)); - break; - } case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaDoubleTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, double(0), double(1)); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_renorm_", false, DeviceType::CUDA, dispatch_scalar_type); + auto p_ = p.toDouble(); + auto maxnorm_ = maxnorm.toDouble(); + THCudaDoubleTensor_renorm(globalContext().getTHCState(), self_, self_, p_, dim, maxnorm_); break; } case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, float(0), float(1)); - break; - } - case ScalarType::Int: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaIntTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, int(0), int(1)); - break; - } - case ScalarType::Long: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaLongTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, int64_t(0), int64_t(1)); - break; - } - case ScalarType::Short: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaShortTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, int16_t(0), int16_t(1)); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_renorm_", false, DeviceType::CUDA, dispatch_scalar_type); + auto p_ = p.toFloat(); + auto maxnorm_ = maxnorm.toFloat(); + THCudaTensor_renorm(globalContext().getTHCState(), self_, self_, p_, dim, maxnorm_); break; } case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaHalfTensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, Half(0), Half(1)); - break; - } - case ScalarType::BFloat16: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto mat2_ = checked_dense_tensor_unwrap(mat2, "mat2", 2, "_th_bmm", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaBFloat16Tensor_baddbmm(globalContext().getTHCState(), result_, result_, self_, mat2_, BFloat16(0), BFloat16(1)); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_renorm_", false, DeviceType::CUDA, dispatch_scalar_type); + auto p_ = p.toHalf(); + auto maxnorm_ = maxnorm.toHalf(); + THCudaHalfTensor_renorm(globalContext().getTHCState(), self_, self_, p_, dim, maxnorm_); break; } default: - AT_ERROR("_th_bmm not supported on CUDAType for ", dispatch_scalar_type); + AT_ERROR("_th_renorm_ not supported on CUDAType for ", dispatch_scalar_type); } - return result; + return self; } -Tensor & _th_baddbmm_out(Tensor & result, const Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) { +Tensor & _th_cross_kernel_out(Tensor & result, const Tensor & self, const Tensor & other, int64_t dim) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); switch (dispatch_scalar_type) { case ScalarType::Byte: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toByte(); - auto alpha_ = alpha.toByte(); - THCudaByteTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); + auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + THCudaByteTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); break; } case ScalarType::Char: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toChar(); - auto alpha_ = alpha.toChar(); - THCudaCharTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); + auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + THCudaCharTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); break; } case ScalarType::Double: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toDouble(); - auto alpha_ = alpha.toDouble(); - THCudaDoubleTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); + auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + THCudaDoubleTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); break; } case ScalarType::Float: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toFloat(); - auto alpha_ = alpha.toFloat(); - THCudaTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); + auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + THCudaTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); break; } case ScalarType::Int: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toInt(); - auto alpha_ = alpha.toInt(); - THCudaIntTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); + auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + THCudaIntTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); break; } case ScalarType::Long: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toLong(); - auto alpha_ = alpha.toLong(); - THCudaLongTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); + auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + THCudaLongTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); break; } case ScalarType::Short: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toShort(); - auto alpha_ = alpha.toShort(); - THCudaShortTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); + auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + THCudaShortTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); break; } case ScalarType::Half: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toHalf(); - auto alpha_ = alpha.toHalf(); - THCudaHalfTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::BFloat16: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toBFloat16(); - auto alpha_ = alpha.toBFloat16(); - THCudaBFloat16Tensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); + auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel_out", false, DeviceType::CUDA, dispatch_scalar_type); + THCudaHalfTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); break; } default: - AT_ERROR("_th_baddbmm_out not supported on CUDAType for ", dispatch_scalar_type); + AT_ERROR("_th_cross_kernel_out not supported on CUDAType for ", dispatch_scalar_type); } return result; } -Tensor _th_baddbmm(const Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) { +Tensor _th_cross_kernel(const Tensor & self, const Tensor & other, int64_t dim) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); auto result_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); auto result = Tensor(c10::intrusive_ptr::reclaim(result_)); switch (dispatch_scalar_type) { case ScalarType::Byte: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toByte(); - auto alpha_ = alpha.toByte(); - THCudaByteTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); + auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); + THCudaByteTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); break; } case ScalarType::Char: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toChar(); - auto alpha_ = alpha.toChar(); - THCudaCharTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); + auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); + THCudaCharTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); break; } case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toDouble(); - auto alpha_ = alpha.toDouble(); - THCudaDoubleTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); + auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); + THCudaDoubleTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); break; } case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toFloat(); - auto alpha_ = alpha.toFloat(); - THCudaTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); + auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); + THCudaTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); break; } case ScalarType::Int: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toInt(); - auto alpha_ = alpha.toInt(); - THCudaIntTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); + auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); + THCudaIntTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); break; } case ScalarType::Long: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toLong(); - auto alpha_ = alpha.toLong(); - THCudaLongTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); + auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); + THCudaLongTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); break; } case ScalarType::Short: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toShort(); - auto alpha_ = alpha.toShort(); - THCudaShortTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); + auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); + THCudaShortTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); break; } case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toHalf(); - auto alpha_ = alpha.toHalf(); - THCudaHalfTensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); - break; - } - case ScalarType::BFloat16: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch1_ = checked_dense_tensor_unwrap(batch1, "batch1", 2, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto batch2_ = checked_dense_tensor_unwrap(batch2, "batch2", 3, "_th_baddbmm", false, DeviceType::CUDA, dispatch_scalar_type); - auto beta_ = beta.toBFloat16(); - auto alpha_ = alpha.toBFloat16(); - THCudaBFloat16Tensor_baddbmm(globalContext().getTHCState(), result_, self_, batch1_, batch2_, beta_, alpha_); + auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); + auto other_ = checked_dense_tensor_unwrap(other, "other", 2, "_th_cross_kernel", false, DeviceType::CUDA, dispatch_scalar_type); + THCudaHalfTensor_crossKernel(globalContext().getTHCState(), result_, self_, other_, dim); break; } default: - AT_ERROR("_th_baddbmm not supported on CUDAType for ", dispatch_scalar_type); + AT_ERROR("_th_cross_kernel not supported on CUDAType for ", dispatch_scalar_type); } return result; } + std::tuple _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); @@ -2057,53 +1227,6 @@ std::tuple _th_gels(const Tensor & self, const Tensor & A) { } return std::tuple(res1, res2); } -std::tuple _th_eig_out(Tensor & res1, Tensor & res2, const Tensor & self, bool eigenvectors) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto res1_ = checked_dense_tensor_unwrap(res1, "res1", 0, "_th_eig_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto res2_ = checked_dense_tensor_unwrap(res2, "res2", 0, "_th_eig_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaDoubleTensor_geev(globalContext().getTHCState(), res1_, res2_, self_, eigenvectors); - break; - } - case ScalarType::Float: { - auto res1_ = checked_dense_tensor_unwrap(res1, "res1", 0, "_th_eig_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto res2_ = checked_dense_tensor_unwrap(res2, "res2", 0, "_th_eig_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaTensor_geev(globalContext().getTHCState(), res1_, res2_, self_, eigenvectors); - break; - } - default: - AT_ERROR("_th_eig_out not supported on CUDAType for ", dispatch_scalar_type); - } - return std::tuple(res1, res2); -} -std::tuple _th_eig(const Tensor & self, bool eigenvectors) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - auto res1_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto res1 = Tensor(c10::intrusive_ptr::reclaim(res1_)); - auto res2_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto res2 = Tensor(c10::intrusive_ptr::reclaim(res2_)); - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaDoubleTensor_geev(globalContext().getTHCState(), res1_, res2_, self_, eigenvectors); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaTensor_geev(globalContext().getTHCState(), res1_, res2_, self_, eigenvectors); - break; - } - default: - AT_ERROR("_th_eig not supported on CUDAType for ", dispatch_scalar_type); - } - return std::tuple(res1, res2); -} Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); @@ -2194,125 +1317,6 @@ std::tuple _th_geqrf(const Tensor & self) { } return std::tuple(res1, res2); } -std::tuple _th_multinomial_alias_setup_out(Tensor & J, Tensor & q, const Tensor & probs) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(J); - - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto probs_ = checked_dense_tensor_unwrap(probs, "probs", 1, "_th_multinomial_alias_setup_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto J_ = checked_dense_tensor_unwrap(J, "J", 1, "_th_multinomial_alias_setup_out", false, DeviceType::CUDA, ScalarType::Long); - auto q_ = checked_dense_tensor_unwrap(q, "q", 1, "_th_multinomial_alias_setup_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaDoubleTensor_multinomialAliasSetup(globalContext().getTHCState(), probs_, J_, q_); - break; - } - case ScalarType::Float: { - auto probs_ = checked_dense_tensor_unwrap(probs, "probs", 1, "_th_multinomial_alias_setup_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto J_ = checked_dense_tensor_unwrap(J, "J", 1, "_th_multinomial_alias_setup_out", false, DeviceType::CUDA, ScalarType::Long); - auto q_ = checked_dense_tensor_unwrap(q, "q", 1, "_th_multinomial_alias_setup_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaTensor_multinomialAliasSetup(globalContext().getTHCState(), probs_, J_, q_); - break; - } - case ScalarType::Half: { - auto probs_ = checked_dense_tensor_unwrap(probs, "probs", 1, "_th_multinomial_alias_setup_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto J_ = checked_dense_tensor_unwrap(J, "J", 1, "_th_multinomial_alias_setup_out", false, DeviceType::CUDA, ScalarType::Long); - auto q_ = checked_dense_tensor_unwrap(q, "q", 1, "_th_multinomial_alias_setup_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaHalfTensor_multinomialAliasSetup(globalContext().getTHCState(), probs_, J_, q_); - break; - } - default: - AT_ERROR("_th_multinomial_alias_setup_out not supported on CUDAType for ", dispatch_scalar_type); - } - return std::tuple(J, q); -} -std::tuple _th_multinomial_alias_setup(const Tensor & probs) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(probs); - auto J_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(ScalarType::Long)).release(); - auto J = Tensor(c10::intrusive_ptr::reclaim(J_)); - auto q_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto q = Tensor(c10::intrusive_ptr::reclaim(q_)); - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto probs_ = checked_dense_tensor_unwrap(probs, "probs", 1, "_th_multinomial_alias_setup", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaDoubleTensor_multinomialAliasSetup(globalContext().getTHCState(), probs_, J_, q_); - break; - } - case ScalarType::Float: { - auto probs_ = checked_dense_tensor_unwrap(probs, "probs", 1, "_th_multinomial_alias_setup", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaTensor_multinomialAliasSetup(globalContext().getTHCState(), probs_, J_, q_); - break; - } - case ScalarType::Half: { - auto probs_ = checked_dense_tensor_unwrap(probs, "probs", 1, "_th_multinomial_alias_setup", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaHalfTensor_multinomialAliasSetup(globalContext().getTHCState(), probs_, J_, q_); - break; - } - default: - AT_ERROR("_th_multinomial_alias_setup not supported on CUDAType for ", dispatch_scalar_type); - } - return std::tuple(J, q); -} -Tensor & _th_multinomial_alias_draw_out(Tensor & result, const Tensor & q, const Tensor & J, int64_t num_samples, c10::optional generator) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(result); - - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_multinomial_alias_draw_out", false, DeviceType::CUDA, ScalarType::Long); - auto q_ = checked_dense_tensor_unwrap(q, "q", 1, "_th_multinomial_alias_draw_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto J_ = checked_dense_tensor_unwrap(J, "J", 2, "_th_multinomial_alias_draw_out", false, DeviceType::CUDA, ScalarType::Long); - THCudaDoubleTensor_multinomialAliasDraw(globalContext().getTHCState(), result_, q_, J_, num_samples, generator); - break; - } - case ScalarType::Float: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_multinomial_alias_draw_out", false, DeviceType::CUDA, ScalarType::Long); - auto q_ = checked_dense_tensor_unwrap(q, "q", 1, "_th_multinomial_alias_draw_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto J_ = checked_dense_tensor_unwrap(J, "J", 2, "_th_multinomial_alias_draw_out", false, DeviceType::CUDA, ScalarType::Long); - THCudaTensor_multinomialAliasDraw(globalContext().getTHCState(), result_, q_, J_, num_samples, generator); - break; - } - case ScalarType::Half: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_multinomial_alias_draw_out", false, DeviceType::CUDA, ScalarType::Long); - auto q_ = checked_dense_tensor_unwrap(q, "q", 1, "_th_multinomial_alias_draw_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto J_ = checked_dense_tensor_unwrap(J, "J", 2, "_th_multinomial_alias_draw_out", false, DeviceType::CUDA, ScalarType::Long); - THCudaHalfTensor_multinomialAliasDraw(globalContext().getTHCState(), result_, q_, J_, num_samples, generator); - break; - } - default: - AT_ERROR("_th_multinomial_alias_draw_out not supported on CUDAType for ", dispatch_scalar_type); - } - return result; -} -Tensor _th_multinomial_alias_draw(const Tensor & q, const Tensor & J, int64_t num_samples, c10::optional generator) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(q); - auto result_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(ScalarType::Long)).release(); - auto result = Tensor(c10::intrusive_ptr::reclaim(result_)); - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto q_ = checked_dense_tensor_unwrap(q, "q", 1, "_th_multinomial_alias_draw", false, DeviceType::CUDA, dispatch_scalar_type); - auto J_ = checked_dense_tensor_unwrap(J, "J", 2, "_th_multinomial_alias_draw", false, DeviceType::CUDA, ScalarType::Long); - THCudaDoubleTensor_multinomialAliasDraw(globalContext().getTHCState(), result_, q_, J_, num_samples, generator); - break; - } - case ScalarType::Float: { - auto q_ = checked_dense_tensor_unwrap(q, "q", 1, "_th_multinomial_alias_draw", false, DeviceType::CUDA, dispatch_scalar_type); - auto J_ = checked_dense_tensor_unwrap(J, "J", 2, "_th_multinomial_alias_draw", false, DeviceType::CUDA, ScalarType::Long); - THCudaTensor_multinomialAliasDraw(globalContext().getTHCState(), result_, q_, J_, num_samples, generator); - break; - } - case ScalarType::Half: { - auto q_ = checked_dense_tensor_unwrap(q, "q", 1, "_th_multinomial_alias_draw", false, DeviceType::CUDA, dispatch_scalar_type); - auto J_ = checked_dense_tensor_unwrap(J, "J", 2, "_th_multinomial_alias_draw", false, DeviceType::CUDA, ScalarType::Long); - THCudaHalfTensor_multinomialAliasDraw(globalContext().getTHCState(), result_, q_, J_, num_samples, generator); - break; - } - default: - AT_ERROR("_th_multinomial_alias_draw not supported on CUDAType for ", dispatch_scalar_type); - } - return result; -} Tensor & _th_copy_ignoring_overlaps_(Tensor & self, const Tensor & src) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); @@ -3380,84 +2384,6 @@ Tensor _thnn_rrelu_with_noise_forward(const Tensor & self, const Tensor & noise, } return output; } -Tensor & _thnn_rrelu_with_noise_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training) { - const OptionalDeviceGuard device_guard(device_of(self)); - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto lower_ = lower.toDouble(); - auto upper_ = upper.toDouble(); - auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 6, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaDoubleRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false); - break; - } - case ScalarType::Float: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto lower_ = lower.toDouble(); - auto upper_ = upper.toDouble(); - auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 6, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false); - break; - } - case ScalarType::Half: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto lower_ = lower.toDouble(); - auto upper_ = upper.toDouble(); - auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 6, "_thnn_rrelu_with_noise_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaHalfRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false); - break; - } - default: - AT_ERROR("_thnn_rrelu_with_noise_backward_out not supported on CUDAType for ", dispatch_scalar_type); - } - return grad_input; -} -Tensor _thnn_rrelu_with_noise_backward(const Tensor & grad_output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training) { - const OptionalDeviceGuard device_guard(device_of(self)); - auto dispatch_scalar_type = infer_scalar_type(self); - auto grad_input_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto grad_input = Tensor(c10::intrusive_ptr::reclaim(grad_input_)); - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto lower_ = lower.toDouble(); - auto upper_ = upper.toDouble(); - THNN_CudaDoubleRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false); - break; - } - case ScalarType::Float: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto lower_ = lower.toDouble(); - auto upper_ = upper.toDouble(); - THNN_CudaRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false); - break; - } - case ScalarType::Half: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto noise_ = checked_dense_tensor_unwrap(noise, "noise", 3, "_thnn_rrelu_with_noise_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto lower_ = lower.toDouble(); - auto upper_ = upper.toDouble(); - THNN_CudaHalfRReLU_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_, noise_, lower_, upper_, training, false); - break; - } - default: - AT_ERROR("_thnn_rrelu_with_noise_backward not supported on CUDAType for ", dispatch_scalar_type); - } - return grad_input; -} Tensor & _thnn_rrelu_with_noise_forward_(Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training, c10::optional generator) { const OptionalDeviceGuard device_guard(device_of(self)); auto dispatch_scalar_type = infer_scalar_type(self); diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index 58f6a8d53e921..b75ef8219b1c4 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -21,13 +21,17 @@ #endif #ifdef USE_MAGMA -#include +#include #endif #ifdef __HIP_PLATFORM_HCC__ #include #endif +#ifndef USE_ROCM +#include +#endif + #include #include @@ -116,10 +120,14 @@ bool CUDAHooks::hasCuDNN() const { return AT_CUDNN_ENABLED(); } -#ifdef USE_DIRECT_NVRTC +#if defined(USE_DIRECT_NVRTC) static std::pair, at::cuda::NVRTC*> load_nvrtc() { return std::make_pair(nullptr, at::cuda::load_nvrtc()); } +#elif !defined(USE_ROCM) +static std::pair, at::cuda::NVRTC*> load_nvrtc() { + return std::make_pair(nullptr, &at::cuda::detail::lazyNVRTC); +} #else static std::pair, at::cuda::NVRTC*> load_nvrtc() { #if defined(_WIN32) @@ -155,7 +163,9 @@ bool CUDAHooks::hasPrimaryContext(int64_t device_index) const { TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(), "hasPrimaryContext expects a valid device index, but got device_index=", device_index); unsigned int ctx_flags; - int ctx_is_active; + // In standalone tests of cuDevicePrimaryCtxGetState, I've seen the "active" argument end up with weird + // (garbage-looking nonzero) values when the context is not active, unless I initialize it to zero. + int ctx_is_active = 0; AT_CUDA_DRIVER_CHECK(CUDAHooks::nvrtc().cuDevicePrimaryCtxGetState(device_index, &ctx_flags, &ctx_is_active)); return ctx_is_active == 1; } @@ -359,6 +369,11 @@ int CUDAHooks::getNumGPUs() const { return at::cuda::device_count(); } +void CUDAHooks::deviceSynchronize(int64_t device_index) const { + at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index)); + c10::cuda::device_synchronize(); +} + // Sigh, the registry doesn't support namespaces :( using at::CUDAHooksRegistry; using at::RegistererCUDAHooksRegistry; diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h index dff8913b153f8..abef2e7ff8355 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.h +++ b/aten/src/ATen/cuda/detail/CUDAHooks.h @@ -38,6 +38,7 @@ struct CUDAHooks : public at::CUDAHooksInterface { int64_t cuFFTGetPlanCacheSize(int64_t device_index) const override; void cuFFTClearPlanCache(int64_t device_index) const override; int getNumGPUs() const override; + void deviceSynchronize(int64_t device_index) const override; }; }}} // at::cuda::detail diff --git a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp new file mode 100644 index 0000000000000..fae48c08b61f9 --- /dev/null +++ b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp @@ -0,0 +1,171 @@ +#include + +#include +#include +#include + +namespace at { +namespace cuda { +namespace detail { +namespace _stubs { + +at::DynamicLibrary& getCUDALibrary() { +#if defined(_WIN32) + static at::DynamicLibrary lib("nvcuda.dll"); +#else + static at::DynamicLibrary lib("libcuda.so.1"); +#endif + return lib; +} + +at::DynamicLibrary& getNVRTCLibrary() { + constexpr auto major = CUDA_VERSION / 1000; + constexpr auto minor = ( CUDA_VERSION / 10 ) % 10; +#if defined(_WIN32) + auto libname = std::string("nvrtc64_") + std::to_string(major) + std::to_string(minor) + "_0.dll"; +#else + static auto libname = std::string("libnvrtc.so.") + std::to_string(major) + "." + std::to_string(minor); +#endif + static at::DynamicLibrary lib(libname.c_str()); + return lib; +} + +#define _STUB_1(LIB, NAME, RETTYPE, ARG1) \ +RETTYPE NAME(ARG1 a1) { \ + auto fn = reinterpret_cast(get## LIB ## Library().sym(__func__)); \ + if (!fn) \ + throw std::runtime_error("Can't get " C10_STRINGIZE(NAME) ); \ + lazyNVRTC.NAME = fn; \ + return fn(a1); \ +} + +#define _STUB_2(LIB, NAME, RETTYPE, ARG1, ARG2) \ +RETTYPE NAME(ARG1 a1, ARG2 a2) { \ + auto fn = reinterpret_cast(get## LIB ## Library().sym(__func__)); \ + if (!fn) \ + throw std::runtime_error("Can't get " C10_STRINGIZE(NAME) ); \ + lazyNVRTC.NAME = fn; \ + return fn(a1, a2); \ +} + +#define _STUB_3(LIB, NAME, RETTYPE, ARG1, ARG2, ARG3) \ +RETTYPE NAME(ARG1 a1, ARG2 a2, ARG3 a3) { \ + auto fn = reinterpret_cast(get## LIB ## Library().sym(__func__)); \ + if (!fn) \ + throw std::runtime_error("Can't get " C10_STRINGIZE(NAME) ); \ + lazyNVRTC.NAME = fn; \ + return fn(a1, a2, a3); \ +} + +#define _STUB_4(LIB, NAME, RETTYPE, ARG1, ARG2, ARG3, ARG4) \ +RETTYPE NAME(ARG1 a1, ARG2 a2, ARG3 a3, ARG4 a4) { \ + auto fn = reinterpret_cast(get## LIB ## Library().sym(__func__)); \ + if (!fn) \ + throw std::runtime_error("Can't get " C10_STRINGIZE(NAME) ); \ + lazyNVRTC.NAME = fn; \ + return fn(a1, a2, a3, a4); \ +} + +#define CUDA_STUB1(NAME, A1) _STUB_1(CUDA, NAME, CUresult CUDAAPI, A1) +#define CUDA_STUB2(NAME, A1, A2) _STUB_2(CUDA, NAME, CUresult CUDAAPI, A1, A2) +#define CUDA_STUB3(NAME, A1, A2, A3) _STUB_3(CUDA, NAME, CUresult CUDAAPI, A1, A2, A3) +#define CUDA_STUB4(NAME, A1, A2, A3, A4) _STUB_4(CUDA, NAME, CUresult CUDAAPI, A1, A2, A3, A4) + +#define NVRTC_STUB1(NAME, A1) _STUB_1(NVRTC, NAME, nvrtcResult, A1) +#define NVRTC_STUB2(NAME, A1, A2) _STUB_2(NVRTC, NAME, nvrtcResult, A1, A2) +#define NVRTC_STUB3(NAME, A1, A2, A3) _STUB_3(NVRTC, NAME, nvrtcResult, A1, A2, A3) + +NVRTC_STUB2(nvrtcVersion, int*, int*); +NVRTC_STUB2(nvrtcAddNameExpression, nvrtcProgram, const char * const); + +nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog, + const char *src, + const char *name, + int numHeaders, + const char * const *headers, + const char * const *includeNames) { + auto fn = reinterpret_cast(getNVRTCLibrary().sym(__func__)); + if (!fn) + throw std::runtime_error("Can't get nvrtcCreateProgram"); + lazyNVRTC.nvrtcCreateProgram = fn; + return fn(prog, src, name, numHeaders, headers, includeNames); +} + +NVRTC_STUB1(nvrtcDestroyProgram, nvrtcProgram *); +NVRTC_STUB2(nvrtcGetPTXSize, nvrtcProgram, size_t *); +NVRTC_STUB2(nvrtcGetPTX, nvrtcProgram, char *); +NVRTC_STUB3(nvrtcCompileProgram, nvrtcProgram, int, const char * const *); +_STUB_1(NVRTC, nvrtcGetErrorString, const char *, nvrtcResult); +NVRTC_STUB2(nvrtcGetProgramLogSize,nvrtcProgram, size_t*); +NVRTC_STUB2(nvrtcGetProgramLog, nvrtcProgram, char *); +NVRTC_STUB3(nvrtcGetLoweredName, nvrtcProgram, const char *, const char **); + +CUDA_STUB2(cuModuleLoadData, CUmodule *, const void *); +CUDA_STUB3(cuModuleGetFunction, CUfunction *, CUmodule, const char *); +CUDA_STUB4(cuOccupancyMaxActiveBlocksPerMultiprocessor, int *, CUfunction, int, size_t); +CUDA_STUB2(cuGetErrorString, CUresult, const char **); +CUDA_STUB1(cuCtxGetCurrent, CUcontext *); +CUDA_STUB1(cuModuleUnload, CUmodule); +CUDA_STUB3(cuDevicePrimaryCtxGetState, CUdevice, unsigned int *, int *); +CUDA_STUB4(cuLinkCreate, unsigned int, CUjit_option *, void **, CUlinkState *); +CUDA_STUB3(cuLinkComplete, CUlinkState, void **, size_t *); + +// Irregularly shaped functions +CUresult CUDAAPI cuLaunchKernel(CUfunction f, + unsigned int gridDimX, + unsigned int gridDimY, + unsigned int gridDimZ, + unsigned int blockDimX, + unsigned int blockDimY, + unsigned int blockDimZ, + unsigned int sharedMemBytes, + CUstream hStream, + void **kernelParams, + void **extra) { + auto fn = reinterpret_cast(getCUDALibrary().sym(__func__)); + if (!fn) + throw std::runtime_error("Can't get cuLaunchKernel"); + lazyNVRTC.cuLaunchKernel = fn; + return fn(f, + gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, + sharedMemBytes, hStream, kernelParams, extra); +} + +CUresult CUDAAPI cuModuleLoadDataEx(CUmodule *module, + const void *image, + unsigned int numOptions, + CUjit_option *options, + void **optionValues) { + auto fn = reinterpret_cast(getCUDALibrary().sym(__func__)); + if (!fn) + throw std::runtime_error("Can't get cuModuleLoadDataEx"); + lazyNVRTC.cuModuleLoadDataEx = fn; + return fn(module, image, numOptions, options, optionValues); +} + +CUresult CUDAAPI +cuLinkAddData(CUlinkState state, + CUjitInputType type, + void *data, + size_t size, + const char *name, + unsigned int numOptions, + CUjit_option *options, + void **optionValues) { + auto fn = reinterpret_cast(getCUDALibrary().sym(__func__)); + if (!fn) + throw std::runtime_error("Can't get cuLinkAddData"); + lazyNVRTC.cuLinkAddData = fn; + return fn(state, type, data, size, name, numOptions, options, optionValues); +} + +} // namespace _stubs + +NVRTC lazyNVRTC = { +#define _REFERENCE_MEMBER(name) _stubs::name, + AT_FORALL_NVRTC(_REFERENCE_MEMBER) +#undef _REFERENCE_MEMBER +}; +} // namespace detail +} // namespace cuda +} // namespace at diff --git a/aten/src/ATen/cuda/detail/LazyNVRTC.h b/aten/src/ATen/cuda/detail/LazyNVRTC.h new file mode 100644 index 0000000000000..810e1c322dbd8 --- /dev/null +++ b/aten/src/ATen/cuda/detail/LazyNVRTC.h @@ -0,0 +1,11 @@ +#pragma once +#include +namespace at { namespace cuda { +// Forward-declares at::cuda::NVRTC +struct NVRTC; + +namespace detail { +extern NVRTC lazyNVRTC; +} + +}} // at::cuda::detail diff --git a/aten/src/ATen/cuda/detail/OffsetCalculator.cuh b/aten/src/ATen/cuda/detail/OffsetCalculator.cuh index 190b6f378ea32..33b499f03b33b 100644 --- a/aten/src/ATen/cuda/detail/OffsetCalculator.cuh +++ b/aten/src/ATen/cuda/detail/OffsetCalculator.cuh @@ -90,7 +90,7 @@ struct TrivialOffsetCalculator { }; template -static OffsetCalculator make_offset_calculator(const at::TensorIterator& iter) { +static OffsetCalculator make_offset_calculator(const at::TensorIteratorBase& iter) { AT_ASSERT(N <= iter.ntensors()); std::array strides; for (int i = 0; i < N; i++) { diff --git a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h index 00e57ca635203..f7381813d0b59 100644 --- a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h +++ b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h @@ -77,22 +77,24 @@ namespace at { namespace cuda { #define HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR cuOccupancyMaxActiveBlocksPerMultiprocessor #endif -#define AT_FORALL_NVRTC(_) \ - _(nvrtcVersion) \ - _(nvrtcCreateProgram) \ - _(nvrtcDestroyProgram) \ - _(nvrtcGetPTXSize) \ - _(nvrtcGetPTX) \ - _(cuModuleLoadData) \ - _(cuModuleGetFunction) \ - _(HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR)\ - _(nvrtcGetErrorString) \ - _(nvrtcGetProgramLogSize) \ - _(nvrtcGetProgramLog) \ - _(cuLaunchKernel) \ - _(nvrtcCompileProgram) \ - _(cuCtxGetCurrent) \ - _(cuModuleUnload) \ +#define AT_FORALL_NVRTC(_) \ + _(nvrtcVersion) \ + _(nvrtcCreateProgram) \ + _(nvrtcAddNameExpression) \ + _(nvrtcDestroyProgram) \ + _(nvrtcGetPTXSize) \ + _(nvrtcGetPTX) \ + _(cuModuleLoadData) \ + _(cuModuleGetFunction) \ + _(HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR) \ + _(nvrtcGetErrorString) \ + _(nvrtcGetProgramLogSize) \ + _(nvrtcGetProgramLog) \ + _(cuLaunchKernel) \ + _(nvrtcCompileProgram) \ + _(cuCtxGetCurrent) \ + _(nvrtcGetLoweredName) \ + _(cuModuleUnload) \ _(cuDevicePrimaryCtxGetState) #endif diff --git a/aten/src/ATen/cudnn/AutocastRNN.cpp b/aten/src/ATen/cudnn/AutocastRNN.cpp index 31e1a26e8fb7e..083d435975c7c 100644 --- a/aten/src/ATen/cudnn/AutocastRNN.cpp +++ b/aten/src/ATen/cudnn/AutocastRNN.cpp @@ -27,6 +27,7 @@ _cudnn_rnn_cast_reflatten(const Tensor & input, const c10::optional& cx, int64_t mode, int64_t hidden_size, + int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, @@ -43,10 +44,18 @@ _cudnn_rnn_cast_reflatten(const Tensor & input, // weight_stride0 is the number of weight tensors per layer and direction, as seen by model.parameters(). // If bias is enabled, there are 4 such tensors (ih and hh weights, ih and hh biases). // If bias is not enabled, there are 2 (ih and hh weights). - // This organization holds for all rnn types (RNN, GRU, and LSTM). - TORCH_INTERNAL_ASSERT((weight_stride0 == 2) || (weight_stride0 == 4), - "weight_stride0 must be 2 (if no bias) or 4 (if bias). Received ", - weight_stride0); + // This organization holds for all rnn types (RNN, GRU, and LSTM). If LSTM with projections is + // used, additional hr weight is added. + if (proj_size > 0) { + TORCH_INTERNAL_ASSERT((weight_stride0 == 3) || (weight_stride0 == 5), + "weight_stride0 must be 3 (if no bias) or 5 (if bias) for LSTM with projections. Received ", + weight_stride0); + } else { + TORCH_INTERNAL_ASSERT((weight_stride0 == 2) || (weight_stride0 == 4), + "weight_stride0 must be 2 (if no bias) or 4 (if bias). Received ", + weight_stride0); + } + Tensor weight_buf, redispatch_weight_buf; std::vector redispatch_weight; @@ -65,6 +74,10 @@ _cudnn_rnn_cast_reflatten(const Tensor & input, // Casts weight tensors to FP16 and ensures all weights for all layers are views into a large flat buffer, // with the right locations and layouts expected by cudnn. // This is (and should be) autograd-exposed. + bool include_bias = true; + if (weight_stride0 == 2 || (weight_stride0 == 3 && proj_size > 0)) { + include_bias = false; + } std::tie(redispatch_weight_buf, redispatch_weight) = at::native::cudnn_rnn::copy_weights_to_flat_buf_views( weight, @@ -72,6 +85,7 @@ _cudnn_rnn_cast_reflatten(const Tensor & input, input.size(-1), mode, hidden_size, + proj_size, num_layers, batch_first, bidirectional, @@ -79,9 +93,8 @@ _cudnn_rnn_cast_reflatten(const Tensor & input, /*flat_buf_options=*/weight[0].options().dtype(at::kHalf), /*set_orig_weights_to_flat_buf=*/false, /*allow_type_change=*/true, - /*include_bias=*/weight_stride0 == 4); + /*include_bias=*/include_bias); } - return at::_cudnn_rnn( cached_cast(at::kHalf, input), needs_cast_and_flatten ? TensorList(redispatch_weight) : weight, @@ -91,6 +104,7 @@ _cudnn_rnn_cast_reflatten(const Tensor & input, cached_cast(at::kHalf, cx), mode, hidden_size, + proj_size, num_layers, batch_first, dropout, diff --git a/aten/src/ATen/cudnn/Descriptors.cpp b/aten/src/ATen/cudnn/Descriptors.cpp index 2863212a03a8b..aba7b407162fe 100644 --- a/aten/src/ATen/cudnn/Descriptors.cpp +++ b/aten/src/ATen/cudnn/Descriptors.cpp @@ -4,7 +4,6 @@ #include #include -#include namespace at { namespace native { @@ -144,4 +143,38 @@ void FilterDescriptor::set(const at::Tensor &t, int64_t pad, bool force_nhwc) { set(getDataType(t), (int) dim, size, filter_format); } +std::string cudnnMemoryFormatToString(cudnnTensorFormat_t tformat) { + switch (tformat) { + case CUDNN_TENSOR_NCHW: + return "CUDNN_TENSOR_NCHW"; + case CUDNN_TENSOR_NHWC: + return "CUDNN_TENSOR_NHWC"; + default: + std::ostringstream oss; + oss << "(unknown cudnn tensor format " << static_cast(tformat) << ")"; + return oss.str(); + } +} + +std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d) { + out << "FilterDescriptor " << static_cast(d.desc()) << "\n"; + int nbDims; + int dimA[CUDNN_DIM_MAX]; + cudnnDataType_t dtype; + cudnnTensorFormat_t tformat; + cudnnGetFilterNdDescriptor(d.desc(), CUDNN_DIM_MAX, &dtype, &tformat, &nbDims, dimA); + out << " type = " << cudnnTypeToString(dtype) << "\n"; + out << " tensor_format = " << cudnnMemoryFormatToString(tformat) << "\n"; + out << " nbDims = " << nbDims << "\n"; + // Read out only nbDims of the arrays! + out << " dimA = "; + for (auto i : ArrayRef{dimA, static_cast(nbDims)}) { + out << i << ", "; + } + out << "\n"; + return out; +} + +void FilterDescriptor::print() { std::cout << *this; } + }} diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h index 04e027491709c..64306d115e16e 100644 --- a/aten/src/ATen/cudnn/Descriptors.h +++ b/aten/src/ATen/cudnn/Descriptors.h @@ -1,5 +1,7 @@ #pragma once +#include + #include #include @@ -12,6 +14,8 @@ namespace at { namespace native { +std::string cudnnTypeToString(cudnnDataType_t dtype); + // TODO: Add constructors for all of the descriptors inline int dataSize(cudnnDataType_t dataType) @@ -37,7 +41,7 @@ static inline void fixSizeOneDimStride(int dim, const int *size, int *stride, bo int64_t z = 1; int index = 0; std::vector permutation(dim); - + if (nhwc) { permutation[index++] = 1; } @@ -153,12 +157,15 @@ class TORCH_CUDA_API FilterDescriptor public: void set(const at::Tensor &t, int64_t pad = 0, bool force_nhwc = false); + void print(); private: void set(cudnnDataType_t dataType, int dim, int* size, cudnnTensorFormat_t filter_format) { AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, filter_format, dim, size)); } }; +std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d); + struct TORCH_CUDA_API ConvolutionDescriptor : public Descriptor { DropoutDescriptor dropout_desc_; - void set(cudnnHandle_t handle, int hidden_size, int num_layers, DropoutDescriptor&& dropout_desc, + void set(cudnnHandle_t handle, int hidden_size, int proj_size, int num_layers, DropoutDescriptor&& dropout_desc, cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional, cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo, bool allow_tf32) { dropout_desc_ = std::move(dropout_desc); + AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6( handle, mut_desc(), @@ -252,12 +260,19 @@ struct TORCH_CUDA_API RNNDescriptor mode, algo, datatype)); + if (proj_size != 0) { + AT_CUDNN_CHECK(cudnnSetRNNProjectionLayers( + handle, + /*rnnDesc=*/mut_desc(), + /*recProjSize=*/proj_size, + /*outProjSize=*/0)); + } cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); if (prop->major >= 7) { if (input_type == CUDNN_DATA_HALF) { cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH); } -#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000 +#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000 else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) { cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_FMA_MATH); } diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h index f57ad03ce645b..afe88761d88f4 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.h +++ b/aten/src/ATen/detail/CUDAHooksInterface.h @@ -67,7 +67,7 @@ constexpr const char* CUDA_HELP = // TODO: Consider putting the stub definitions in another class, so that one // never forgets to implement each virtual function in the real implementation // in CUDAHooks. This probably doesn't buy us much though. -struct CAFFE2_API CUDAHooksInterface { +struct TORCH_API CUDAHooksInterface { // This should never actually be implemented, but it is used to // squelch -Werror=non-virtual-dtor virtual ~CUDAHooksInterface() {} @@ -181,17 +181,21 @@ struct CAFFE2_API CUDAHooksInterface { virtual int getNumGPUs() const { return 0; } + + virtual void deviceSynchronize(int64_t device_index) const { + TORCH_CHECK(false, "Cannot synchronize CUDA device without ATen_cuda library. ", CUDA_HELP); + } }; // NB: dummy argument to suppress "ISO C++11 requires at least one argument // for the "..." in a variadic macro" -struct CAFFE2_API CUDAHooksArgs {}; +struct TORCH_API CUDAHooksArgs {}; C10_DECLARE_REGISTRY(CUDAHooksRegistry, CUDAHooksInterface, CUDAHooksArgs); #define REGISTER_CUDA_HOOKS(clsname) \ C10_REGISTER_CLASS(CUDAHooksRegistry, clsname, clsname) namespace detail { -CAFFE2_API const CUDAHooksInterface& getCUDAHooks(); +TORCH_API const CUDAHooksInterface& getCUDAHooks(); } // namespace detail } // namespace at diff --git a/aten/src/ATen/detail/HIPHooksInterface.h b/aten/src/ATen/detail/HIPHooksInterface.h index e5099a85e6f90..876d7caf30000 100644 --- a/aten/src/ATen/detail/HIPHooksInterface.h +++ b/aten/src/ATen/detail/HIPHooksInterface.h @@ -24,7 +24,7 @@ namespace at { // which we may want to call into from CPU code (and thus must be dynamically // dispatched, to allow for separate compilation of HIP code). See // CUDAHooksInterface for more detailed motivation. -struct CAFFE2_API HIPHooksInterface { +struct TORCH_API HIPHooksInterface { // This should never actually be implemented, but it is used to // squelch -Werror=non-virtual-dtor virtual ~HIPHooksInterface() {} @@ -61,14 +61,14 @@ struct CAFFE2_API HIPHooksInterface { // NB: dummy argument to suppress "ISO C++11 requires at least one argument // for the "..." in a variadic macro" -struct CAFFE2_API HIPHooksArgs {}; +struct TORCH_API HIPHooksArgs {}; C10_DECLARE_REGISTRY(HIPHooksRegistry, HIPHooksInterface, HIPHooksArgs); #define REGISTER_HIP_HOOKS(clsname) \ C10_REGISTER_CLASS(HIPHooksRegistry, clsname, clsname) namespace detail { -CAFFE2_API const HIPHooksInterface& getHIPHooks(); +TORCH_API const HIPHooksInterface& getHIPHooks(); } // namespace detail } // namespace at diff --git a/torch/csrc/jit/tensorexpr/buffer.cpp b/aten/src/ATen/function_wrapper.py similarity index 100% rename from torch/csrc/jit/tensorexpr/buffer.cpp rename to aten/src/ATen/function_wrapper.py diff --git a/aten/src/ATen/metal/Context.cpp b/aten/src/ATen/metal/Context.cpp new file mode 100644 index 0000000000000..2b0e658f2bb60 --- /dev/null +++ b/aten/src/ATen/metal/Context.cpp @@ -0,0 +1,31 @@ +#include + +#include +#include + +namespace at { +namespace metal { + +std::atomic g_metal_impl_registry; + +MetalImplRegistrar::MetalImplRegistrar(MetalInterface* impl) { + g_metal_impl_registry.store(impl); +} + +at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src) { + auto p = at::metal::g_metal_impl_registry.load(); + if (p) { + return p->metal_copy_(self, src); + } + AT_ERROR("Metal backend was not linked to the build"); +} +} // namespace metal + +namespace native { +bool is_metal_available() { + auto p = at::metal::g_metal_impl_registry.load(); + return p ? p->is_metal_available() : false; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/metal/Context.h b/aten/src/ATen/metal/Context.h new file mode 100644 index 0000000000000..3f1b9e75d45bf --- /dev/null +++ b/aten/src/ATen/metal/Context.h @@ -0,0 +1,30 @@ +#ifndef MetalContext_h +#define MetalContext_h + +#include + +#include + +namespace at { +namespace metal { + +struct MetalInterface { + virtual ~MetalInterface() = default; + virtual bool is_metal_available() const = 0; + virtual at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src) + const = 0; +}; + +extern std::atomic g_metal_impl_registry; + +class MetalImplRegistrar { + public: + explicit MetalImplRegistrar(MetalInterface*); +}; + +at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src); + +} // namespace metal +} // namespace at + +#endif /* MetalContext_h */ diff --git a/aten/src/ATen/miopen/Handle.cpp b/aten/src/ATen/miopen/Handle.cpp index 8965ef5a2cce8..6b8c7c6421c42 100644 --- a/aten/src/ATen/miopen/Handle.cpp +++ b/aten/src/ATen/miopen/Handle.cpp @@ -1,39 +1,53 @@ -#include - #include - -#include -#include +#include +#include +#include namespace at { namespace native { - namespace { -struct Handle { - miopenHandle_t handle; - Handle() : handle(NULL) { - MIOPEN_CHECK(miopenCreate(&handle)); - } - ~Handle() { - if (handle) { - miopenDestroy(handle); - } - } -}; +void createMIOpenHandle(miopenHandle_t *handle) { + MIOPEN_CHECK(miopenCreate(handle)); +} -std::mutex mutex; -std::unordered_map handles; +void destroyMIOpenHandle(miopenHandle_t handle) { +// this is because of something dumb in the ordering of +// destruction. Sometimes atexit, the cuda context (or something) +// would already be destroyed by the time this gets destroyed. It +// happens in fbcode setting. @colesbury and I decided to not destroy +// the handle as a workaround. +// - @soumith +// +// Further note: this is now disabled globally, because we are seeing +// the same issue as mentioned above in CUDA 11 CI. +// - @zasdfgbnm +// +// #ifdef NO_MIOPEN_DESTROY_HANDLE +// #else +// miopenDestroy(handle); +// #endif +} -} // namespace +using MIOpenPoolType = at::cuda::DeviceThreadHandlePool; +} // namespace -miopenHandle_t getMiopenHandle() -{ +miopenHandle_t getMiopenHandle() { int device; HIP_CHECK(hipGetDevice(&device)); - std::lock_guard guard(mutex); - return handles[device].handle; + // Thread local PoolWindows are lazily-initialized + // to avoid initialization issues that caused hangs on Windows. + // See: https://github.com/pytorch/pytorch/pull/22405 + // This thread local unique_ptrs will be destroyed when the thread terminates, + // releasing its reserved handles back to the pool. + static auto pool = std::make_shared(); + thread_local std::unique_ptr myPoolWindow( + pool->newPoolWindow()); + + auto handle = myPoolWindow->reserve(device); + MIOPEN_CHECK(miopenSetStream(handle, at::hip::getCurrentHIPStream())); + return handle; } }} // namespace at::native diff --git a/aten/src/ATen/miopen/Utils.h b/aten/src/ATen/miopen/Utils.h index 90ee4b7a14ee5..5952e4f4c796c 100644 --- a/aten/src/ATen/miopen/Utils.h +++ b/aten/src/ATen/miopen/Utils.h @@ -7,12 +7,6 @@ namespace at { namespace native { -inline void setMIOpenStreamToCurrent() { - // NB: Due to in-place HIPify, getCurrentCUDAStream actually means - // getCurrentHIPStream - MIOPEN_CHECK(miopenSetStream(getMiopenHandle(), at::hip::getCurrentHIPStream())); -} - // This function makes tensors which have zero stride contiguous, by // setting the strides to 1. inline Tensor contiguousIfZeroInStrides(const Tensor& t) { diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp index 4f3932f68f5ee..91d3793fa93cd 100644 --- a/aten/src/ATen/native/Activation.cpp +++ b/aten/src/ATen/native/Activation.cpp @@ -116,27 +116,22 @@ Tensor & elu_( return at::elu_out(self, self, alpha, scale, input_scale); } -Tensor& elu_backward_out( - Tensor& grad_input, - const Tensor& grad_output, - Scalar alpha, - Scalar scale, - Scalar input_scale, - const Tensor& output) { - auto iter = TensorIterator::binary_op(grad_input, grad_output, output); - elu_backward_stub(iter.device_type(), iter, alpha, scale, input_scale); - return grad_input; -} - Tensor elu_backward( const Tensor& grad_output, Scalar alpha, Scalar scale, Scalar input_scale, - const Tensor& output) { + bool is_result, + const Tensor& self_or_result) { + TORCH_CHECK( + !is_result || alpha.to() >= 0.0, + "In-place elu backward calculation is triggered with a negative slope which is not supported. " + "This is caused by calling in-place forward function with a negative slope, " + "please call out-of-place version instead."); + Tensor result; - auto iter = TensorIterator::binary_op(result, grad_output, output); - elu_backward_stub(iter.device_type(), iter, alpha, scale, input_scale); + auto iter = TensorIterator::binary_op(result, grad_output, self_or_result); + elu_backward_stub(iter.device_type(), iter, alpha, scale, input_scale, is_result); return iter.output(); } @@ -224,6 +219,13 @@ Tensor silu_backward( return grad_input; } +Tensor math_silu_backward( + const Tensor& grad_output, + const Tensor& input) { + auto input_sigmoid = at::sigmoid(input); + return grad_output * (input_sigmoid * (1 + input * (1 - input_sigmoid))); +} + template inline void _rrelu_with_noise_train( Tensor& output, @@ -270,8 +272,8 @@ Tensor& rrelu_with_noise_out_cpu( }); return output; } else { - auto lower_tensor = scalar_to_tensor(lower, self.device()); - auto upper_tensor = scalar_to_tensor(upper, self.device()); + auto lower_tensor = scalar_to_tensor(lower); + auto upper_tensor = scalar_to_tensor(upper); auto negative = (lower_tensor + upper_tensor) / 2; Scalar negative_slope = negative.item(); return at::leaky_relu_out(output, self, negative_slope); @@ -307,8 +309,8 @@ Tensor rrelu_with_noise_backward( Scalar upper, bool training, bool is_result) { - auto lower_tensor = scalar_to_tensor(lower, grad_output.device()); - auto upper_tensor = scalar_to_tensor(upper, grad_output.device()); + auto lower_tensor = scalar_to_tensor(lower); + auto upper_tensor = scalar_to_tensor(upper); if (training && (upper_tensor - lower_tensor).item().to() > 1E-6) { return grad_output.mul(noise); } else { diff --git a/aten/src/ATen/native/Activation.h b/aten/src/ATen/native/Activation.h index bebfa67c93cff..acc80b8f0229e 100644 --- a/aten/src/ATen/native/Activation.h +++ b/aten/src/ATen/native/Activation.h @@ -23,12 +23,13 @@ using hardswish_backward_fn = void(*)(TensorIterator&); using shrink_fn = void (*)(TensorIterator&, Scalar); using shrink_backward_fn = void (*)(TensorIterator&, Scalar); using elu_fn = void (*)(TensorIterator&, Scalar, Scalar, Scalar); +using elu_backward_fn = void (*)(TensorIterator&, Scalar, Scalar, Scalar, bool); using leaky_relu_fn = void (*)(TensorIterator&, Scalar); using leaky_relu_backward_fn = void (*)(TensorIterator&, Scalar); using log_sigmoid_cpu_fn = void (*)(Tensor& , Tensor&, const Tensor& ); DECLARE_DISPATCH(elu_fn, elu_stub); -DECLARE_DISPATCH(elu_fn, elu_backward_stub); +DECLARE_DISPATCH(elu_backward_fn, elu_backward_stub); DECLARE_DISPATCH(softplus_fn, softplus_stub); DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub); DECLARE_DISPATCH(log_sigmoid_cpu_fn, log_sigmoid_cpu_stub); diff --git a/aten/src/ATen/native/AdaptiveAveragePooling.cpp b/aten/src/ATen/native/AdaptiveAveragePooling.cpp index 9802797874b95..9778aa035cb1f 100644 --- a/aten/src/ATen/native/AdaptiveAveragePooling.cpp +++ b/aten/src/ATen/native/AdaptiveAveragePooling.cpp @@ -1,7 +1,6 @@ #include #include -#include -#include +#include namespace at { @@ -9,295 +8,66 @@ namespace native { namespace { - inline int start_index(int a, int b, int c) { - return (int)std::floor((float)(a * c) / b); - } - - inline int end_index(int a, int b, int c) { - return (int)std::ceil((float)((a + 1) * c) / b); - } - - template - static void adaptive_avg_pool2d_single_out_frame( - scalar_t *input_p, - scalar_t *output_p, - int64_t sizeD, - int64_t isizeH, - int64_t isizeW, - int64_t osizeH, - int64_t osizeW, - int64_t istrideD, - int64_t istrideH, - int64_t istrideW) - { - at::parallel_for(0, sizeD, 0, [&](int64_t start, int64_t end) { - for (auto d = start; d < end; d++) - { - /* loop over output */ - int64_t oh, ow; - for(oh = 0; oh < osizeH; oh++) - { - int istartH = start_index(oh, osizeH, isizeH); - int iendH = end_index(oh, osizeH, isizeH); - int kH = iendH - istartH; - - for(ow = 0; ow < osizeW; ow++) - { - int istartW = start_index(ow, osizeW, isizeW); - int iendW = end_index(ow, osizeW, isizeW); - int kW = iendW - istartW; - - /* local pointers */ - scalar_t *ip = input_p + d*istrideD + istartH*istrideH + istartW*istrideW; - scalar_t *op = output_p + d*osizeH*osizeW + oh*osizeW + ow; - - /* compute local average: */ - scalar_t sum = 0; - int ih, iw; - for(ih = 0; ih < kH; ih++) - { - for(iw = 0; iw < kW; iw++) - { - scalar_t val = *(ip + ih*istrideH + iw*istrideW); - sum += val; - } - } - - /* set output to local average */ - *op = sum / kW / kH; - } - } - } - }); - } - - template - void adaptive_avg_pool2d_out_frame( - scalar_t *input_p, - scalar_t *output_p, - int64_t sizeB, - int64_t sizeD, - int64_t isizeH, - int64_t isizeW, - int64_t osizeH, - int64_t osizeW, - int64_t istrideB, - int64_t istrideD, - int64_t istrideH, - int64_t istrideW) - { - at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) { - for (auto b = start; b < end; b++) - { - adaptive_avg_pool2d_single_out_frame( - input_p + b * istrideB, - output_p + b * sizeD * osizeH * osizeW, - sizeD, - isizeH, isizeW, - osizeH, osizeW, - istrideD, - istrideH, istrideW); - } - }); - } - void adaptive_avg_pool2d_out_cpu_template( at::Tensor& output, at::Tensor const& input, IntArrayRef output_size) { TORCH_CHECK(output_size.size() == 2, "adaptive_avg_pool2d: output_size must be 2"); - for (int64_t i = 0; i < input.ndimension(); i++) { + int64_t ndim = input.ndimension(); + for (int64_t i = 0; i < ndim; i++) { TORCH_CHECK(input.size(i) > 0, "adaptive_avg_pooling2d(): expected input to have non-empty spatial dimensions, " "but input has sizes ", input.sizes(), " with dimension ", i, " being " "empty"); } - TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4), + TORCH_CHECK((ndim == 3 || ndim == 4), "non-empty 3D or 4D (batch mode) tensor expected for input"); + TORCH_CHECK(input.dtype() == output.dtype(), + "expected dtype ", input.dtype(), " for `output` but got dtype ", output.dtype()); - /* sizes */ - int64_t sizeD = input.size(-3); - int64_t isizeH = input.size(-2); - int64_t isizeW = input.size(-1); - /* strides */ - int64_t istrideD = input.stride(-3); - int64_t istrideH = input.stride(-2); - int64_t istrideW = input.stride(-1); - - auto osizeH = output_size[0]; - auto osizeW = output_size[1]; - - /* resize output */ - if (input.ndimension() == 3 || input.size(-4) == 1) - { - if (input.ndimension() == 3) { - output.resize_({sizeD, osizeH, osizeW}); - } else { - output.resize_({1, sizeD, osizeH, osizeW}); - } - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] { - auto input_data = input.data_ptr(); - auto output_data = output.data_ptr(); - adaptive_avg_pool2d_single_out_frame( - input_data, - output_data, - sizeD, - isizeH, isizeW, - osizeH, osizeW, - istrideD, - istrideH, istrideW); - } - ); - } - else - { - int64_t sizeB = input.size(-4); - output.resize_({sizeB, sizeD, osizeH, osizeW}); - int64_t istrideB = input.stride(-4); + int64_t channels = input.size(-3); + int64_t input_height = input.size(-2); + int64_t input_width = input.size(-1); + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] { - auto input_data = input.data_ptr(); - auto output_data = output.data_ptr(); - adaptive_avg_pool2d_out_frame( - input_data, - output_data, - sizeB, - sizeD, - isizeH, isizeW, - osizeH, osizeW, - istrideB, - istrideD, - istrideH, istrideW); - }); + if (ndim == 3) { + output.resize_({channels, output_height, output_width}); + } else { + int64_t nbatch = input.size(0); + output.resize_({nbatch, channels, output_height, output_width}, input.suggest_memory_format()); } - } - - template - static void adaptive_avg_pool2d_backward_single_out_frame( - scalar_t *gradInput_p, - scalar_t *gradOutput_p, - int64_t sizeD, - int64_t isizeH, - int64_t isizeW, - int64_t osizeH, - int64_t osizeW) - { - at::parallel_for(0, sizeD, 0, [&](int64_t start, int64_t end) { - for (auto d = start; d < end; d++) - { - scalar_t *gradInput_p_d = gradInput_p + d*isizeW*isizeH; - scalar_t *gradOutput_p_d = gradOutput_p + d*osizeW*osizeH; - - /* calculate average */ - int64_t oh, ow; - for(oh = 0; oh < osizeH; oh++) - { - int istartH = start_index(oh, osizeH, isizeH); - int iendH = end_index(oh, osizeH, isizeH); - int kH = iendH - istartH; - for(ow = 0; ow < osizeW; ow++) - { - - int istartW = start_index(ow, osizeW, isizeW); - int iendW = end_index(ow, osizeW, isizeW); - int kW = iendW - istartW; - - scalar_t grad_delta = gradOutput_p_d[oh*osizeW +ow] / kH / kW; - - int ih, iw; - for(ih = istartH; ih < iendH; ih++) - { - for(iw = istartW; iw < iendW; iw++) - { - /* update gradient */ - gradInput_p_d[ih*isizeW + iw] += grad_delta; - } - } - } - } - } - }); - } - - template - void adaptive_avg_pool2d_backward_out_frame( - scalar_t *gradInput_p, - scalar_t *gradOutput_p, - int64_t sizeB, - int64_t sizeD, - int64_t isizeH, - int64_t isizeW, - int64_t osizeH, - int64_t osizeW) - { - at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) { - for (auto b = start; b < end; b++) - { - scalar_t *gradInput_p_d = gradInput_p + b * sizeD * isizeW * isizeH; - scalar_t *gradOutput_p_d = gradOutput_p + b * sizeD * osizeW * osizeH; - adaptive_avg_pool2d_backward_single_out_frame( - gradInput_p_d, - gradOutput_p_d, - sizeD, - isizeH, isizeW, - osizeH, osizeW); - } - }); + adaptive_avg_pool2d_kernel(kCPU, output, input, output_size); } Tensor& adaptive_avg_pool2d_backward_out_cpu_template( - Tensor& gradInput, - const Tensor& gradOutput_, + Tensor& grad_input, + const Tensor& grad_output, const Tensor& input) { - /* sizes */ - int sizeD = input.size(-3); - int isizeH = input.size(-2); - int isizeW = input.size(-1); - int osizeH = gradOutput_.size(-2); - int osizeW = gradOutput_.size(-1); - - /* get contiguous gradOutput */ - auto gradOutput = gradOutput_.contiguous(); + int64_t ndim = grad_output.ndimension(); + for (int64_t i = 0; i < ndim; i++) { + TORCH_CHECK(grad_output.size(i) > 0, + "adaptive_avg_pooling2d_backward(): expected grad_output to have non-empty spatial dimensions, " + "but grad_output has sizes ", grad_output.sizes(), " with dimension ", i, " being " + "empty"); + } - /* backprop */ - if (input.ndimension() == 3 || input.size(-4) == 1) - { - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "adaptive_avg_pool2d_backward_cpu", [&] { - /* get raw pointers */ - scalar_t *gradInput_data = gradInput.data_ptr(); - scalar_t *gradOutput_data = gradOutput.data_ptr(); + TORCH_CHECK((ndim == 3 || ndim == 4), + "non-empty 3D or 4D (batch mode) tensor expected for grad_output"); + TORCH_CHECK(input.dtype() == grad_output.dtype(), + "expected dtype ", input.dtype(), " for `grad_output` but got dtype ", grad_output.dtype()); + TORCH_CHECK(input.dtype() == grad_input.dtype(), + "expected dtype ", input.dtype(), " for `grad_input` but got dtype ", grad_input.dtype()); - adaptive_avg_pool2d_backward_single_out_frame( - gradInput_data, gradOutput_data, - sizeD, - isizeH, isizeW, - osizeH, osizeW); - } - ); - } - else - { - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "adaptive_avg_pool2d_backward_cpu", [&] { - /* get raw pointers */ - scalar_t *gradInput_data = gradInput.data_ptr(); - scalar_t *gradOutput_data = gradOutput.data_ptr(); - int64_t sizeB = input.size(-4); + grad_input.resize_(input.sizes(), input.suggest_memory_format()); + grad_input.zero_(); - adaptive_avg_pool2d_backward_out_frame( - gradInput_data, gradOutput_data, - sizeB, sizeD, - isizeH, isizeW, - osizeH, osizeW); - } - ); - } - return gradInput; + adaptive_avg_pool2d_backward_kernel(kCPU, grad_input, grad_output); + return grad_input; } } // namespace @@ -346,25 +116,27 @@ namespace { } Tensor& adaptive_avg_pool2d_backward_out_cpu( - Tensor& gradInput, - const Tensor& gradOutput, + Tensor& grad_input, + const Tensor& grad_output, const Tensor& input) { - gradInput.resize_as_(input); adaptive_avg_pool2d_backward_out_cpu_template( - gradInput, gradOutput, input); - return gradInput; + grad_input, grad_output, input); + return grad_input; } Tensor adaptive_avg_pool2d_backward_cpu( - const Tensor& gradOutput, + const Tensor& grad_output, const Tensor& input) { - auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto grad_input = at::empty({0}, input.options()); adaptive_avg_pool2d_backward_out_cpu_template( - gradInput, gradOutput, input); - return gradInput; + grad_input, grad_output, input); + return grad_input; } +DEFINE_DISPATCH(adaptive_avg_pool2d_kernel); +DEFINE_DISPATCH(adaptive_avg_pool2d_backward_kernel); + } // at::native } // at diff --git a/aten/src/ATen/native/AdaptivePooling.h b/aten/src/ATen/native/AdaptivePooling.h new file mode 100644 index 0000000000000..29b2fd1c94c9c --- /dev/null +++ b/aten/src/ATen/native/AdaptivePooling.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +namespace at { namespace native { + +using adaptive_avg_pooling_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size); +using adaptive_avg_pooling_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output); +DECLARE_DISPATCH(adaptive_avg_pooling_fn, adaptive_avg_pool2d_kernel); +DECLARE_DISPATCH(adaptive_avg_pooling_backward_fn, adaptive_avg_pool2d_backward_kernel); + +static inline int64_t start_index(int64_t a, int64_t b, int64_t c) { + return (int64_t)std::floor((float)(a * c) / b); +} + +static inline int64_t end_index(int64_t a, int64_t b, int64_t c) { + return (int64_t)std::ceil((float)((a + 1) * c) / b); +} + +}} // namespace at::native diff --git a/aten/src/ATen/native/AutogradComposite.cpp b/aten/src/ATen/native/AutogradComposite.cpp new file mode 100644 index 0000000000000..be7184a26565a --- /dev/null +++ b/aten/src/ATen/native/AutogradComposite.cpp @@ -0,0 +1,27 @@ +#include + +namespace at { +namespace native { + +/// This function can be used to create a dual Tensor that holds a tangent to compute forward mode gradients. +/// Note that the dual Tensor's primal is a view of the given primal and the given tangent is used as-is. +/// This function is backward differentiable. +at::Tensor make_dual(const at::Tensor& primal, const at::Tensor& tangent, int64_t level) { + TORCH_CHECK(!primal.fw_grad(level).defined(), "Making a dual Tensor based on a Tensor that " + "already has a forward gradient at the same level ", level, " is not supported."); + + auto dual_tensor = primal.view(primal.sizes()); + dual_tensor.set_fw_grad(tangent, level, /* is_inplace_op */ false); + return dual_tensor; +} + +/// This function can be used to unpack a given dual Tensor to get its primal and tangent. The returned primal +/// is a view of the dual and the tangent is returned as is. +/// This function is backward differentiable. +std::tuple unpack_dual(const at::Tensor& tensor, int64_t level) { + return std::tuple(tensor._fw_primal(level), tensor.fw_grad(level)); +} + +} // namespace native + +} // namespace at diff --git a/aten/src/ATen/native/AveragePool2d.cpp b/aten/src/ATen/native/AveragePool2d.cpp index 04da6a726ecfa..e7834fcf33964 100644 --- a/aten/src/ATen/native/AveragePool2d.cpp +++ b/aten/src/ATen/native/AveragePool2d.cpp @@ -119,9 +119,6 @@ void avg_pool2d_out_cpu_template( const int padH = safe_downcast(padding[0]); const int padW = padding.size() == 1 ? padH : safe_downcast(padding[1]); - TORCH_CHECK((input_.ndimension() == 3 || input_.ndimension() == 4), - "non-empty 2D or 3D (batch mode) tensor expected for input"); - TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0, "divisor must be not zero"); @@ -139,7 +136,7 @@ void avg_pool2d_out_cpu_template( kH, kW, dH, dW, padH, padW, 1, 1, nInputPlane, inputHeight, inputWidth, - outputHeight, outputWidth); + outputHeight, outputWidth, input_.suggest_memory_format()); if (input_.ndimension() == 3) { output.resize_({nInputPlane, outputHeight, outputWidth}); @@ -276,12 +273,8 @@ Tensor& avg_pool2d_backward_out_cpu_template( "avg_pool2d: padding must either be a single int, or a tuple of two ints"); const int padH = safe_downcast(padding[0]); const int padW = padding.size() == 1 ? padH : safe_downcast(padding[1]); - const int64_t ndim = input.ndimension(); - TORCH_CHECK((ndim == 3 || ndim == 4), - "non-empty 3D or 4D (batch mode) tensor expected for input"); - TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0, "divisor must be not zero"); /* sizes */ @@ -299,7 +292,8 @@ Tensor& avg_pool2d_backward_out_cpu_template( kH, kW, dH, dW, padH, padW, nInputPlane, inputHeight, inputWidth, - outputHeight, outputWidth); + outputHeight, outputWidth, + input.suggest_memory_format()); /* get contiguous gradOutput */ const Tensor gradOutput = gradOutput_.contiguous(); diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index f4babb2a14a3c..d2f1358a345d2 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -4,7 +4,9 @@ #include #include +#include #include +#include #include #include @@ -66,14 +68,26 @@ extern "C" void dorgqr_(int *m, int *n, int *k, double *a, int *lda, double *tau extern "C" void sorgqr_(int *m, int *n, int *k, float *a, int *lda, float *tau, float *work, int *lwork, int *info); // syev +extern "C" void zheev_(char *jobz, char *uplo, int *n, std::complex *a, int *lda, double *w, std::complex *work, int *lwork, double *rwork, int *info); +extern "C" void cheev_(char *jobz, char *uplo, int *n, std::complex *a, int *lda, float *w, std::complex *work, int *lwork, float *rwork, int *info); extern "C" void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info); extern "C" void ssyev_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *info); +// syevd +extern "C" void zheevd_(char *jobz, char *uplo, int *n, std::complex *a, int *lda, double *w, std::complex *work, int *lwork, double *rwork, int *lrwork, int *iwork, int *liwork, int *info); +extern "C" void cheevd_(char *jobz, char *uplo, int *n, std::complex *a, int *lda, float *w, std::complex *work, int *lwork, float *rwork, int *lrwork, int *iwork, int *liwork, int *info); +extern "C" void dsyevd_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *iwork, int *liwork, int *info); +extern "C" void ssyevd_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *iwork, int *liwork, int *info); + +// geev +extern "C" void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info); +extern "C" void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info); + // gesdd extern "C" void zgesdd_(char *jobz, int *m, int *n, std::complex *a, int *lda, - double *s, std::complex *u, int *ldu, std::complex *vt, int *ldvt, std::complex *work, int *lwork, int *rwork, int *iwork, int *info); + double *s, std::complex *u, int *ldu, std::complex *vt, int *ldvt, std::complex *work, int *lwork, double *rwork, int *iwork, int *info); extern "C" void cgesdd_(char *jobz, int *m, int *n, std::complex *a, int *lda, - float *s, std::complex *u, int *ldu, std::complex *vt, int *ldvt, std::complex *work, int *lwork, int *rwork, int *iwork, int *info); + float *s, std::complex *u, int *ldu, std::complex *vt, int *ldvt, std::complex *work, int *lwork, float *rwork, int *iwork, int *info); extern "C" void dgesdd_(char *jobz, int *m, int *n, double *a, int *lda, double *s, double *u, int *ldu, double *vt, int *ldvt, double *work, int *lwork, int *iwork, int *info); extern "C" void sgesdd_(char *jobz, int *m, int *n, float *a, int *lda, @@ -116,12 +130,15 @@ void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *wo template void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); -template -void lapackSymeig(char jobz, char uplo, int n, scalar_t *a, int lda, scalar_t *w, scalar_t *work, int lwork, int *info); +template +void lapackSymeig(char jobz, char uplo, int n, scalar_t *a, int lda, value_t *w, scalar_t *work, int lwork, value_t *rwork, int *info); + +template +void lapackSyevd(char jobz, char uplo, int n, scalar_t *a, int lda, value_t *w, scalar_t *work, int lwork, value_t *rwork, int lrwork, int *iwork, int liwork, int *info); template void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda, - value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, int *rwork, int *iwork, int *info); + value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info); template void lapackLuSolve(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info); @@ -255,33 +272,71 @@ template<> void lapackOrgqr(int m, int n, int k, float *a, int lda, float sorgqr_(&m, &n, &k, a, &lda, tau, work, &lwork, info); } -template<> void lapackSymeig(char jobz, char uplo, int n, double *a, int lda, double *w, double *work, int lwork, int *info) { +template<> void lapackSymeig, double>(char jobz, char uplo, int n, c10::complex *a, int lda, double *w, c10::complex *work, int lwork, double *rwork, int *info) { + zheev_(&jobz, &uplo, &n, reinterpret_cast*>(a), &lda, w, reinterpret_cast*>(work), &lwork, rwork, info); +} + +template<> void lapackSymeig, float>(char jobz, char uplo, int n, c10::complex *a, int lda, float *w, c10::complex *work, int lwork, float *rwork, int *info) { + cheev_(&jobz, &uplo, &n, reinterpret_cast*>(a), &lda, w, reinterpret_cast*>(work), &lwork, rwork, info); +} + +template<> void lapackSymeig(char jobz, char uplo, int n, double *a, int lda, double *w, double *work, int lwork, double* rwork, int *info) { + (void)rwork; // unused dsyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info); } -template<> void lapackSymeig(char jobz, char uplo, int n, float *a, int lda, float *w, float *work, int lwork, int *info) { +template<> void lapackSymeig(char jobz, char uplo, int n, float *a, int lda, float *w, float *work, int lwork, float* rwork, int *info) { + (void)rwork; // unused ssyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info); } +template<> void lapackSyevd, double>(char jobz, char uplo, int n, c10::complex *a, int lda, double *w, c10::complex *work, int lwork, double *rwork, int lrwork, int *iwork, int liwork, int *info) { + zheevd_(&jobz, &uplo, &n, reinterpret_cast*>(a), &lda, w, reinterpret_cast*>(work), &lwork, rwork, &lrwork, iwork, &liwork, info); +} + +template<> void lapackSyevd, float>(char jobz, char uplo, int n, c10::complex *a, int lda, float *w, c10::complex *work, int lwork, float *rwork, int lrwork, int *iwork, int liwork, int *info) { + cheevd_(&jobz, &uplo, &n, reinterpret_cast*>(a), &lda, w, reinterpret_cast*>(work), &lwork, rwork, &lrwork, iwork, &liwork, info); +} + +template<> void lapackSyevd(char jobz, char uplo, int n, double *a, int lda, double *w, double *work, int lwork, double *rwork, int lrwork, int *iwork, int liwork, int *info) { + (void)rwork; // unused + (void)lrwork; // unused + dsyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info); +} + +template<> void lapackSyevd(char jobz, char uplo, int n, float *a, int lda, float *w, float *work, int lwork, float *rwork, int lrwork, int *iwork, int liwork, int *info) { + (void)rwork; // unused + (void)lrwork; // unused + ssyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info); +} + +template<> void lapackEig(char jobvl, char jobvr, int n, double *a, int lda, double *wr, double *wi, double* vl, int ldvl, double *vr, int ldvr, double *work, int lwork, int *info) { + dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info); +} + +template<> void lapackEig(char jobvl, char jobvr, int n, float *a, int lda, float *wr, float *wi, float* vl, int ldvl, float *vr, int ldvr, float *work, int lwork, int *info) { + sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info); +} + template<> void lapackSvd, double>(char jobz, int m, int n, c10::complex *a, int lda, - double *s, c10::complex *u, int ldu, c10::complex *vt, int ldvt, c10::complex *work, int lwork, int *rwork, int *iwork, int *info) { + double *s, c10::complex *u, int ldu, c10::complex *vt, int ldvt, c10::complex *work, int lwork, double *rwork, int *iwork, int *info) { zgesdd_(&jobz, &m, &n, reinterpret_cast*>(a), &lda, s, reinterpret_cast*>(u), &ldu, reinterpret_cast*>(vt), &ldvt, reinterpret_cast*>(work), &lwork, rwork, iwork, info); } template<> void lapackSvd, float>(char jobz, int m, int n, c10::complex *a, int lda, - float *s, c10::complex *u, int ldu, c10::complex *vt, int ldvt, c10::complex *work, int lwork, int *rwork, int *iwork, int *info) { + float *s, c10::complex *u, int ldu, c10::complex *vt, int ldvt, c10::complex *work, int lwork, float *rwork, int *iwork, int *info) { cgesdd_(&jobz, &m, &n, reinterpret_cast*>(a), &lda, s, reinterpret_cast*>(u), &ldu, reinterpret_cast*>(vt), &ldvt, reinterpret_cast*>(work), &lwork, rwork, iwork, info); } template<> void lapackSvd(char jobz, int m, int n, double *a, int lda, - double *s, double *u, int ldu, double *vt, int ldvt, double *work, int lwork, int *rwork, int *iwork, int *info) { + double *s, double *u, int ldu, double *vt, int ldvt, double *work, int lwork, double *rwork, int *iwork, int *info) { dgesdd_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, info); } template<> void lapackSvd(char jobz, int m, int n, float *a, int lda, - float *s, float *u, int ldu, float *vt, int ldvt, float *work, int lwork, int *rwork, int *iwork, int *info) { + float *s, float *u, int ldu, float *vt, int ldvt, float *work, int lwork, float *rwork, int *iwork, int *info) { sgesdd_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, info); } @@ -307,8 +362,18 @@ template<> void lapackLuSolve(char trans, int n, int nrhs, float *a, int // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +/* +Computes the solution to a system of linear equations + A X = B, +where A is an n-by-n matrix and X and B are n-by-nrhs matrices. +Note that B is required to be a matrix, the usual, vector case, is obtained with nrhs = 1. +Above description is for non-batched input, the batched input is also supported. +This is an in-place routine, content of both A and b are overriden. +'infos' is an int Tensor containing error codes for each matrix in the batched input. +For more information see LAPACK's documentation for GESV routine. +*/ template -static void apply_solve(Tensor& b, Tensor& A, std::vector& infos) { +static void apply_solve(Tensor& b, Tensor& A, Tensor& infos) { #ifndef USE_LAPACK AT_ERROR("solve: LAPACK library not found in compilation"); #else @@ -319,19 +384,17 @@ static void apply_solve(Tensor& b, Tensor& A, std::vector& infos) { auto batch_size = batchCount(A); auto n = A.size(-2); auto nrhs = b.size(-1); + auto lda = std::max(1, n); - auto ipiv = at::empty({n}, b.options().dtype(kInt)); + auto ipiv = at::empty({lda}, b.options().dtype(kInt)); auto ipiv_data = ipiv.data_ptr(); + auto infos_data = infos.data_ptr(); - int info; for (int64_t i = 0; i < batch_size; i++) { scalar_t* A_working_ptr = &A_data[i * A_mat_stride]; scalar_t* b_working_ptr = &b_data[i * b_mat_stride]; - lapackSolve(n, nrhs, A_working_ptr, n, ipiv_data, b_working_ptr, n, &info); - infos[i] = info; - if (info != 0) { - return; - } + int* info_working_ptr = &infos_data[i]; + lapackSolve(n, nrhs, A_working_ptr, lda, ipiv_data, b_working_ptr, lda, info_working_ptr); } #endif } @@ -339,14 +402,14 @@ static void apply_solve(Tensor& b, Tensor& A, std::vector& infos) { std::tuple _solve_helper_cpu(const Tensor& self, const Tensor& A) { auto self_working_copy = cloneBatchedColumnMajor(self); auto A_working_copy = cloneBatchedColumnMajor(A); - std::vector infos(batchCount(self), 0); + auto infos = at::empty({std::max(1, batchCount(self))}, self.options().dtype(kInt)); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "solve_cpu", [&]{ apply_solve(self_working_copy, A_working_copy, infos); }); if (self.dim() > 2) { batchCheckErrors(infos, "solve_cpu"); } else { - singleCheckErrors(infos[0], "solve_cpu"); + singleCheckErrors(infos.item().toInt(), "solve_cpu"); } return std::tuple(self_working_copy, A_working_copy); } @@ -370,10 +433,110 @@ std::tuple solve_out(Tensor& solution, Tensor& lu, const Tensor return std::tuple(solution, lu); } + +// This is a type dispatching helper function for 'apply_solve' +Tensor& _linalg_solve_out_helper_cpu(Tensor& result, Tensor& input, Tensor& infos) { + // 'result' and 'input' should be in column major order (it should be checked before calling this function) + // the content of 'result', 'input' and 'infos' is overriden by 'apply_solve' + // 'result' should contain data of 'other' tensor (right-hand-side of the linear system of equations) + // 'input' should contain data of original 'input' tensor (left-hand-side of the linear system of equations) + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "linalg_solve_out_cpu", [&]{ + apply_solve(result, input, infos); + }); + return result; +} + +// Solves a system of linear equations matmul(input, x) = other in-place +// LAPACK/MAGMA error codes are saved in 'infos' tensor, they are not checked here +static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor& input, const Tensor& other) { + TORCH_CHECK(infos.scalar_type() == kInt, + "infos dtype ", infos.scalar_type(), " does not match the expected dtype ", kInt); + TORCH_CHECK(result.scalar_type() == input.scalar_type(), + "result dtype ", result.scalar_type(), " does not match input dtype ", input.scalar_type()); + TORCH_CHECK(input.scalar_type() == other.scalar_type(), + "input dtype ", input.scalar_type(), " does not match other dtype ", other.scalar_type()); + + TORCH_CHECK(input.dim() >= 2, + "input should have at least 2 dimensions, but has ", input.dim(), " dimensions instead"); + TORCH_CHECK(other.dim() >= 1, + "other should have at least 1 dimension, but has ", other.dim(), " dimensions instead"); + + // NumPy works for 1-dimensional 'other' or batch of 1-dimensional tensors, we need to unsqueeze it + // because 2-dimensional tensors are expected in the implementation + auto expected_batched_rhs_shape = IntArrayRef(input.sizes().data(), input.dim()-1); // A.shape[:-1] + bool is_rhs_broadcasted = other.dim() == 1 || (input.dim()-1 == other.dim() && other.sizes().equals(expected_batched_rhs_shape)); + Tensor other_ = is_rhs_broadcasted ? other.unsqueeze(-1) : other; + + // _linalg_broadcast_batch_dims also includes linearSolveCheckInputs + // it checks for squareness of 'input' and 'shape' compatibility of 'other' and 'input' + Tensor other_broadcasted, input_broadcasted; + std::tie(other_broadcasted, input_broadcasted) = _linalg_broadcast_batch_dims(other_, input, "linalg_solve"); + + // if result has no elements we can modify it + if (result.numel() == 0) { + at::native::resize_as_(result, other_broadcasted.transpose(-2, -1), MemoryFormat::Contiguous); + result.transpose_(-2, -1); + } else { + // Resize messes up the strides and we expect strictly column major order, so let's not use at::native::resize_output + TORCH_CHECK(result.sizes().equals(other_broadcasted.sizes()), + "result shape ", result.sizes(), " does not match broadcasted other shape ", other_broadcasted.sizes()); + } + + TORCH_CHECK(result.transpose(-2, -1).is_contiguous(), "result tensor must be in batched column major order (Fortran contiguous)."); + result.copy_(other_broadcasted); + + auto input_working_copy = cloneBatchedColumnMajor(input_broadcasted); + at::native::resize_output(infos, {std::max(1, batchCount(input_broadcasted))}); + // if input is empty infos might not get filled; make sure infos doesn't contain garbage then + if (input.numel() == 0) { + infos.fill_(0); + } + result = at::_linalg_solve_out_helper_(result, input_working_copy, infos); + + // NumPy works for 1-dimensional 'other', we need to squeeze the result in this case + if (is_rhs_broadcasted) { + result.squeeze_(-1); + } + + return result; +} + +// Solves a system of linear equations matmul(input, x) = other in-place +Tensor& linalg_solve_out(Tensor& result, const Tensor& input, const Tensor& other) { + auto infos = at::empty({0}, input.options().dtype(kInt)); + result = linalg_solve_out_info(result, infos, input, other); + + // Now check LAPACK/MAGMA error codes + // batchCheckErrors(Tensor, char*) calls 'infos = infos.to(kCPU)' + auto expected_batched_rhs_shape = IntArrayRef(input.sizes().data(), input.dim()-1); // A.shape[:-1] + bool is_rhs_broadcasted = other.dim() == 1 || (input.dim()-1 == other.dim() && other.sizes().equals(expected_batched_rhs_shape)); + if (is_rhs_broadcasted ? result.dim() > 1 : result.dim() > 2) { + batchCheckErrors(infos, "linalg_solve"); + } else { + singleCheckErrors(infos.item().toInt(), "linalg_solve"); + } + + return result; +} + +// Solves a system of linear equations matmul(input, x) = other +Tensor linalg_solve(const Tensor& input, const Tensor& other) { + Tensor result = at::empty({0}, input.options()); + result = at::linalg_solve_out(result, input, other); + return result; +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ inverse ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +/* +Computes the inverse of n-by-n matrix 'self' +This is an in-place routine, it overwrites the content of 'self'. +'infos_lu' and 'infos_getri' are int Tensors containing error codes for each matrix in the batched input. +'infos_lu' is for holding lapackLU errors, and 'infos_getri' is for holding lapackGetri errors. +For more information see LAPACK's documentation for GETRI and GETRF routines. +*/ template -static void apply_inverse(Tensor& self, std::vector& infos) { +static void apply_inverse(Tensor& self, Tensor& infos_lu, Tensor& infos_getri) { #ifndef USE_LAPACK AT_ERROR("inverse: LAPACK library not found in compilation"); #else @@ -382,9 +545,12 @@ static void apply_inverse(Tensor& self, std::vector& infos) { auto self_matrix_stride = matrixStride(self); auto batch_size = batchCount(self); auto n = self.size(-2); + auto lda = std::max(1, n); - auto ipiv = at::empty({n}, self.options().dtype(kInt)); + auto ipiv = at::empty({lda}, self.options().dtype(kInt)); auto ipiv_data = ipiv.data_ptr(); + auto infos_lu_data = infos_lu.data_ptr(); + auto infos_getri_data = infos_getri.data_ptr(); int info; // Run once, first to get the optimum work size @@ -393,39 +559,36 @@ static void apply_inverse(Tensor& self, std::vector& infos) { // and (batch_size - 1) calls to allocate and deallocate workspace using at::empty() int lwork = -1; scalar_t wkopt; - lapackGetri(n, self_data, n, ipiv_data, &wkopt, lwork, &info); + lapackGetri(n, self_data, lda, ipiv_data, &wkopt, lwork, &info); lwork = static_cast(real_impl(wkopt)); Tensor work = at::empty({lwork}, self.options()); auto work_data = work.data_ptr(); for (int64_t i = 0; i < batch_size; i++) { scalar_t* self_working_ptr = &self_data[i * self_matrix_stride]; - lapackLu(n, n, self_working_ptr, n, ipiv_data, &info); - infos[i] = info; - if (info != 0) { - return; - } + int* info_lu_working_ptr = &infos_lu_data[i]; + lapackLu(n, n, self_working_ptr, lda, ipiv_data, info_lu_working_ptr); // now compute the actual inverse - lapackGetri(n, self_working_ptr, n, ipiv_data, work_data, lwork, &info); - infos[i] = info; - if (info != 0) { - return; - } + int* info_getri_working_ptr = &infos_getri_data[i]; + lapackGetri(n, self_working_ptr, lda, ipiv_data, work_data, lwork, info_getri_working_ptr); } #endif } Tensor _inverse_helper_cpu(const Tensor& self) { - std::vector infos(batchCount(self), 0); + auto infos_lu = at::empty({std::max(1, batchCount(self))}, self.options().dtype(kInt)); + auto infos_getri = at::empty({std::max(1, batchCount(self))}, self.options().dtype(kInt)); auto self_working_copy = cloneBatchedColumnMajor(self); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "inverse_cpu", [&]{ - apply_inverse(self_working_copy, infos); + apply_inverse(self_working_copy, infos_lu, infos_getri); }); if (self.dim() > 2) { - batchCheckErrors(infos, "inverse_cpu"); + batchCheckErrors(infos_lu, "inverse_cpu"); + batchCheckErrors(infos_getri, "inverse_cpu"); } else { - singleCheckErrors(infos[0], "inverse_cpu"); + singleCheckErrors(infos_lu.item().toInt(), "inverse_cpu"); + singleCheckErrors(infos_getri.item().toInt(), "inverse_cpu"); } return self_working_copy; } @@ -446,6 +609,75 @@ Tensor& inverse_out(Tensor &result, const Tensor &self) { return result; } +// This is a type dispatching helper function for 'apply_inverse' +Tensor& _linalg_inv_out_helper_cpu(Tensor &result, Tensor& infos_lu, Tensor& infos_getri) { + // This function calculates the inverse matrix in-place + // result should be in column major order and contain matrices to invert + // the content of result is overwritten by 'apply_inverse' + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "linalg_inv_out_cpu", [&]{ + apply_inverse(result, infos_lu, infos_getri); + }); + return result; +} + +// Computes the inverse matrix of 'input', it is is saved to 'result' in-place +// LAPACK/MAGMA/cuSOLVER error codes are saved in 'infos' tensors, they are not checked here +static Tensor& linalg_inv_out_info(Tensor& result, Tensor& infos_lu, Tensor& infos_getri, const Tensor& input) { + squareCheckInputs(input); + TORCH_INTERNAL_ASSERT(infos_lu.scalar_type() == kInt); + TORCH_INTERNAL_ASSERT(infos_getri.scalar_type() == kInt); + TORCH_CHECK(result.scalar_type() == input.scalar_type(), + "result dtype ", result.scalar_type(), " does not match input dtype ", input.scalar_type()); + TORCH_CHECK(result.device() == input.device(), + "result device ", result.device(), " does not match input device ", input.device()); + + // if result has no elements we can modify it + if (result.numel() == 0) { + at::native::resize_as_(result, input.transpose(-2, -1), MemoryFormat::Contiguous); + result.transpose_(-2, -1); + } else { + // Resize messes up the strides and we expect strictly column major order, so let's not use at::native::resize_output + TORCH_CHECK(result.sizes().equals(input.sizes()), + "result shape ", result.sizes(), " does not match input shape ", input.sizes()); + } + + TORCH_CHECK(result.transpose(-2, -1).is_contiguous(), "result tensor must be in batched column major order (Fortran contiguous)."); + result.copy_(input); + + at::native::resize_output(infos_lu, {std::max(1, batchCount(input))}); + at::native::resize_output(infos_getri, {std::max(1, batchCount(input))}); + infos_lu.fill_(0); + infos_getri.fill_(0); + + result = at::_linalg_inv_out_helper_(result, infos_lu, infos_getri); + return result; +} + +// Computes the inverse matrix of 'input', it is is saved to 'result' in-place +Tensor& linalg_inv_out(Tensor &result, const Tensor &input) { + auto infos_lu = at::empty({0}, input.options().dtype(kInt)); + auto infos_getri = at::empty({0}, input.options().dtype(kInt)); + result = linalg_inv_out_info(result, infos_lu, infos_getri, input); + + // Now check LAPACK/MAGMA/cuSOLVER error codes + if (result.dim() > 2) { + batchCheckErrors(infos_lu, "linalg_inv_lu"); + batchCheckErrors(infos_getri, "linalg_inv_getri"); + } else { + singleCheckErrors(infos_lu.item().toInt(), "linalg_inv_lu"); + singleCheckErrors(infos_getri.item().toInt(), "linalg_inv_getri"); + } + + return result; +} + +// Computes the inverse matrix of 'input' +Tensor linalg_inv(const Tensor &input) { + Tensor result = at::empty({0}, input.options()); + result = at::linalg_inv_out(result, input); + return result; +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template @@ -522,11 +754,12 @@ static void apply_cholesky(Tensor& self, bool upper, std::vector& infos auto self_matrix_stride = matrixStride(self); auto batch_size = batchCount(self); auto n = self.size(-2); + auto lda = std::max(1, n); int info; for (int64_t i = 0; i < batch_size; i++) { scalar_t* self_working_ptr = &self_data[i * self_matrix_stride]; - lapackCholesky(uplo, n, self_working_ptr, n, &info); + lapackCholesky(uplo, n, self_working_ptr, lda, &info); infos[i] = info; if (info != 0) { return; @@ -571,6 +804,20 @@ Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) { return result; } +Tensor linalg_cholesky(const Tensor &self) { + squareCheckInputs(self); + return at::_cholesky_helper(self, /*upper=*/false).tril_(); +} + +Tensor& linalg_cholesky_out(Tensor &result, const Tensor &self) { + TORCH_CHECK(result.scalar_type() == self.scalar_type(), + "result dtype ", result.scalar_type(), " does not match self dtype ", self.scalar_type()); + Tensor result_tmp = at::linalg_cholesky(self); + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); + return result; +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template @@ -773,7 +1020,9 @@ static void apply_orgqr(Tensor& self, const Tensor& tau, int64_t m, int64_t n_co #endif } -std::tuple _qr_helper_cpu(const Tensor& self, bool some) { +std::tuple _linalg_qr_helper_cpu(const Tensor& self, std::string mode) { + bool compute_q, reduced; + std::tie(compute_q, reduced) = _parse_qr_mode(mode); std::vector infos(batchCount(self), 0); int64_t m = self.size(-2), n = self.size(-1); @@ -783,25 +1032,22 @@ std::tuple _qr_helper_cpu(const Tensor& self, bool some) { self_sizes[self.dim() - 2] = std::min(m, n); auto tau_working_copy = at::empty(self_sizes, self.options()); Tensor q_working_copy; + Tensor R; // Setup input geometry for apply_orgqr std::vector q_sizes, q_strides; int64_t n_columns_q; - Tensor R; - std::tie(q_sizes, q_strides, n_columns_q) = _compute_geometry_for_Q(self, some); + std::tie(q_sizes, q_strides, n_columns_q) = _compute_geometry_for_Q(self, reduced); // If there are no elements, then we simply return a pair of tensors of required dimensions if (self.numel() == 0) { - // Fix the number of columns of q appropriately - q_sizes[self.dim() - 1] = n_columns_q; - q_working_copy = at::eye(q_sizes[self.dim() - 2], q_sizes[self.dim() - 1], self.options()); - q_working_copy = q_working_copy.expand_as(q_working_copy); - - // We repurpose the same q_sizes for R - // Fix the number of rows and columns of q_working_copy appropriately - q_sizes[self.dim() - 1] = n; - q_sizes[self.dim() - 2] = n_columns_q; - R = at::empty(q_sizes, self.options()); + R = at::empty({n_columns_q, n}, self.options()); + if (compute_q) { + int64_t n_rows_q = q_sizes[self.dim() - 2]; + q_working_copy = at::eye(n_rows_q, n_columns_q, self.options()); + } else { + q_working_copy = at::empty({0}, self.options()); + } return std::make_tuple(q_working_copy, R); } @@ -821,6 +1067,11 @@ std::tuple _qr_helper_cpu(const Tensor& self, bool some) { } R = q_working_copy.slice(-2, 0, n_columns_q).slice(-1, 0, n).triu(); + if (!compute_q) { + // this is for mode='r' + Tensor empty_Q = at::empty({0}, self.options()); + return std::make_tuple(empty_Q, R); + } // Next perform ORGQR for Q using the results (both raw R and TAU) from GEQRF AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "qr_cpu", [&]{ @@ -834,22 +1085,185 @@ std::tuple _qr_helper_cpu(const Tensor& self, bool some) { return std::make_tuple(q_working_copy.narrow(-1, 0, n_columns_q), R); } -std::tuple qr(const Tensor& self, bool some) { +std::tuple linalg_qr(const Tensor& self, std::string mode) { TORCH_CHECK(self.dim() >= 2, - "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); - return at::_qr_helper(self, some); + "qr input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + return at::_linalg_qr_helper(self, mode); } -std::tuple qr_out(Tensor& Q, Tensor& R, const Tensor& self, bool some) { +std::tuple linalg_qr_out(Tensor& Q, Tensor& R, const Tensor& self, std::string mode) { TORCH_CHECK(self.dim() >= 2, - "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + "qr input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); Tensor Q_tmp, R_tmp; - std::tie(Q_tmp, R_tmp) = at::_qr_helper(self, some); - Q.resize_as_(Q_tmp).copy_(Q_tmp); - R.resize_as_(R_tmp).copy_(R_tmp); + std::tie(Q_tmp, R_tmp) = at::_linalg_qr_helper(self, mode); + at::native::resize_output(Q, Q_tmp.sizes()); + Q.copy_(Q_tmp); + at::native::resize_output(R, R_tmp.sizes()); + R.copy_(R_tmp); return std::tuple(Q, R); } +std::tuple qr(const Tensor& self, bool some) { + std::string mode = some ? "reduced" : "complete"; + return at::linalg_qr(self, mode); +} + +std::tuple qr_out(Tensor& Q, Tensor& R, const Tensor& self, bool some) { + std::string mode = some ? "reduced" : "complete"; + return at::linalg_qr_out(Q, R, self, mode); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ syevd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// This function computes eigenvalues 'w' and eigenvectors 'v' of the input that is stored initially in 'v' +// The computation is done in-place: 'v' stores the input and will be overriden, 'w' should be an allocated empty array +// compute_v controls whether eigenvectors should be computed +// uplo_str controls the portion of input matrix to consider in computations, allowed values are "u", "U", "l", "L" +// infos is used to store information for possible checks for error +// This function doesn't do any error checks and it's assumed that every argument is valid +template +static void apply_syevd(Tensor& w, Tensor& v, bool compute_v, const std::string& uplo_str, std::vector& infos) { +#ifndef USE_LAPACK + AT_ERROR("syevd: LAPACK library not found in compilation"); +#else + using value_t = typename c10::scalar_value_type::type; + + auto v_data = v.data_ptr(); + auto w_data = w.data_ptr(); + auto v_matrix_stride = matrixStride(v); + auto w_stride = w.size(-1); + auto batch_size = batchCount(v); + auto n = v.size(-1); + auto lda = std::max(int64_t{1}, n); + + // NumPy allows lowercase input for UPLO argument + // It is assumed that uplo_str is either "U" or "L" + char uplo = std::toupper(uplo_str[0]); + char jobz = compute_v ? 'V' : 'N'; + + // Using 'int' instead of int32_t or int64_t is consistent with the current LAPACK interface + // It really should be changed in the future to something like lapack_int that depends on the specific LAPACK library that is linked + // or switch to supporting only 64-bit indexing by default. + int info; + int lwork = -1; + int lrwork = -1; + int liwork = -1; + scalar_t work_query; + value_t rwork_query; + int iwork_query; + + // Run lapackSyevd once, first to get the optimum work size. + // Since we deal with batches of matrices with the same dimensions, doing this outside + // the main loop saves (batch_size - 1) workspace queries which would provide the same result + // and (batch_size - 1) calls to allocate and deallocate workspace using at::empty() + lapackSyevd(jobz, uplo, n, v_data, lda, w_data, &work_query, lwork, &rwork_query, lrwork, &iwork_query, liwork, &info); + + lwork = std::max(1, real_impl(work_query)); + Tensor work = at::empty({lwork}, v.options()); + liwork = std::max(1, iwork_query); + Tensor iwork = at::empty({liwork}, at::kInt); + + Tensor rwork; + value_t* rwork_data = nullptr; + if (isComplexType(at::typeMetaToScalarType(v.dtype()))) { + lrwork = std::max(1, rwork_query); + rwork = at::empty({lrwork}, w.options()); + rwork_data = rwork.data_ptr(); + } + + // Now call lapackSyevd for each matrix in the batched input + for (auto i = decltype(batch_size){0}; i < batch_size; i++) { + scalar_t* v_working_ptr = &v_data[i * v_matrix_stride]; + value_t* w_working_ptr = &w_data[i * w_stride]; + lapackSyevd(jobz, uplo, n, v_working_ptr, lda, w_working_ptr, work.data_ptr(), lwork, rwork_data, lrwork, iwork.data_ptr(), liwork, &info); + infos[i] = info; + // The current behaviour for Linear Algebra functions to raise an error if something goes wrong or input doesn't satisfy some requirement + // therefore return early since further computations will be wasted anyway + if (info != 0) { + return; + } + } +#endif +} + +// This function computes eigenvalues 'w' and eigenvectors 'v' of the tensor 'self' +// compute_eigenvectors controls whether eigenvectors should be computed +// uplo controls the portion of input matrix to consider in computations, allowed values are "u", "U", "l", "L" +// This function prepares correct input for 'apply_syevd' and checks for possible errors using 'infos' +std::tuple _syevd_helper_cpu(const Tensor& self, bool compute_eigenvectors, std::string uplo) { + std::vector infos(batchCount(self), 0); + + auto self_sizes = self.sizes().vec(); + self_sizes.pop_back(); + ScalarType dtype = toValueType(typeMetaToScalarType(self.dtype())); + auto eigvals = at::empty(self_sizes, self.options().dtype(dtype)); + + auto eigvecs = cloneBatchedColumnMajor(self); + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "syevd_cpu", [&]{ + apply_syevd(eigvals, eigvecs, compute_eigenvectors, uplo, infos); + }); + + if (self.dim() > 2) { + batchCheckErrors(infos, "syevd_cpu"); + } else { + singleCheckErrors(infos[0], "syevd_cpu"); + } + if (compute_eigenvectors) { + return std::tuple(eigvals, eigvecs); + } else { + return std::tuple(eigvals, at::empty({0}, self.options())); + } +} + +std::tuple linalg_eigh(const Tensor& self, std::string uplo) { + squareCheckInputs(self); + checkUplo(uplo); + return at::_syevd_helper(self, /*compute_eigenvectors=*/true, uplo); +} + +// TODO: it's possible to make the _out variant to be a primal function and implement linalg_eigh on top of _out +// TODO: implement _out variant avoiding copy and using already allocated storage directly +std::tuple linalg_eigh_out(Tensor& eigvals, Tensor& eigvecs, const Tensor& self, std::string uplo) { + TORCH_CHECK(eigvecs.scalar_type() == self.scalar_type(), + "eigvecs dtype ", eigvecs.scalar_type(), " does not match self dtype ", self.scalar_type()); + ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype())); + TORCH_CHECK(eigvals.scalar_type() == real_dtype, + "eigvals dtype ", eigvals.scalar_type(), " does not match self dtype ", real_dtype); + + Tensor eigvals_tmp, eigvecs_tmp; + std::tie(eigvals_tmp, eigvecs_tmp) = at::linalg_eigh(self, uplo); + + at::native::resize_output(eigvals, eigvals_tmp.sizes()); + eigvals.copy_(eigvals_tmp); + at::native::resize_output(eigvecs, eigvecs_tmp.sizes()); + eigvecs.copy_(eigvecs_tmp); + + return std::tuple(eigvals, eigvecs); +} + +Tensor linalg_eigvalsh(const Tensor& self, std::string uplo) { + squareCheckInputs(self); + checkUplo(uplo); + Tensor eigvals, eigvecs; + std::tie(eigvals, eigvecs) = at::_syevd_helper(self, /*compute_eigenvectors=*/false, uplo); + return eigvals; +} + +// TODO: it's possible to make the _out variant to be a primal function and implement linalg_eigvalsh on top of _out +// TODO: implement _out variant avoiding copy and using already allocated storage directly +Tensor& linalg_eigvalsh_out(Tensor& result, const Tensor& self, std::string uplo) { + ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype())); + TORCH_CHECK(result.scalar_type() == real_dtype, + "result dtype ", result.scalar_type(), " does not match self dtype ", real_dtype); + + Tensor result_tmp = at::linalg_eigvalsh(self, uplo); + + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); + + return result; +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ symeig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template @@ -859,7 +1273,7 @@ static void apply_symeig(Tensor& self, Tensor& eigvals, bool eigenvectors, bool #else using value_t = typename c10::scalar_value_type::type; auto self_data = self.data_ptr(); - auto eigvals_data = eigvals.data_ptr(); + auto eigvals_data = eigvals.data_ptr(); auto self_matrix_stride = matrixStride(self); auto eigvals_stride = eigvals.size(-1); auto batch_size = batchCount(self); @@ -875,16 +1289,26 @@ static void apply_symeig(Tensor& self, Tensor& eigvals, bool eigenvectors, bool // and (batch_size - 1) calls to allocate and deallocate workspace using at::empty() int lwork = -1; scalar_t wkopt; - lapackSymeig(jobz, uplo, n, self_data, n, eigvals_data, &wkopt, lwork, &info); + + Tensor rwork; + value_t* rwork_data = nullptr; + if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { + int64_t lrwork = std::max(int64_t(1), 3 * n - 2); + ScalarType dtype = toValueType(typeMetaToScalarType(self.dtype())); + rwork = at::empty({lrwork}, self.options().dtype(dtype)); + rwork_data = rwork.data_ptr(); + } + + lapackSymeig(jobz, uplo, n, self_data, n, eigvals_data, &wkopt, lwork, rwork_data, &info); lwork = static_cast(real_impl(wkopt)); Tensor work = at::empty({lwork}, self.options()); for (int64_t i = 0; i < batch_size; i++) { scalar_t* self_working_ptr = &self_data[i * self_matrix_stride]; - scalar_t* eigvals_working_ptr = &eigvals_data[i * eigvals_stride]; + value_t* eigvals_working_ptr = &eigvals_data[i * eigvals_stride]; // now compute the eigenvalues and the eigenvectors (optionally) - lapackSymeig(jobz, uplo, n, self_working_ptr, n, eigvals_working_ptr, work.data_ptr(), lwork, &info); + lapackSymeig(jobz, uplo, n, self_working_ptr, n, eigvals_working_ptr, work.data_ptr(), lwork, rwork_data, &info); infos[i] = info; if (info != 0) { return; @@ -898,14 +1322,15 @@ std::tuple _symeig_helper_cpu(const Tensor& self, bool eigenvect auto self_sizes = self.sizes().vec(); self_sizes.pop_back(); - auto eigvals = at::empty(self_sizes, self.options()); + ScalarType dtype = toValueType(typeMetaToScalarType(self.dtype())); + auto eigvals = at::empty(self_sizes, self.options().dtype(dtype)); if (self.numel() == 0) { return std::tuple(eigvals, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT)); } auto self_working_copy = cloneBatchedColumnMajor(self); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "symeig_cpu", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "symeig_cpu", [&]{ apply_symeig(self_working_copy, eigvals, eigenvectors, upper, infos); }); @@ -935,6 +1360,46 @@ std::tuple symeig_out(Tensor& vals, Tensor& vecs, const Tensor return std::tuple(vals, vecs); } +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +DEFINE_DISPATCH(eig_stub); + +std::tuple eig_out(Tensor& e, Tensor& v, const Tensor& self, bool eigenvectors) { + TORCH_CHECK(self.dim() == 2, "input should be 2 dimensional"); + TORCH_CHECK(self.size(0) == self.size(1), "input should be square"); + TORCH_CHECK(self.isfinite().all().item(), "input should not contain infs or NaNs"); + TORCH_CHECK(e.dtype() == self.dtype(), "Expected 'e' to have dtype ", self.dtype(), " but got ", e.dtype()); + if (eigenvectors) + TORCH_CHECK(v.dtype() == self.dtype(), "Expected 'v' to have dtype ", self.dtype(), " but got ", v.dtype()); + int64_t n = self.size(-1); + + at::native::resize_output(e, {n, 2}); + if (eigenvectors) { + at::native::resize_output(v, self.sizes()); + } + + // optimization: if self is empty, we can immediately return the empty + // tensors, instead of getting empty tensors from eig_helper + if (self.numel() == 0) { + return std::tuple(e, v); + } + + Tensor vals_, vecs_; + std::tie(vals_, vecs_) = eig_stub(self.device().type(), self, eigenvectors); + e.copy_(vals_); + if (eigenvectors) { + v.copy_(vecs_); + } + return std::tuple(e, v); +} + +std::tuple eig(const Tensor& self, bool eigenvectors) { + Tensor e = at::empty({0}, self.options()); + Tensor v = at::empty({0}, self.options()); + at::eig_out(e, v, self, eigenvectors); + return std::tuple(e, v); +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template @@ -958,22 +1423,15 @@ static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT, auto m = self.size(-2); auto n = self.size(-1); auto mn = std::min(m, n); - Tensor iwork = at::empty({8*mn}, at::kInt); + Tensor iwork = at::empty({8 * mn}, at::kInt); auto iwork_data = iwork.data_ptr(); Tensor rwork; - int* rwork_data = nullptr; + value_t* rwork_data = nullptr; if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { - auto mx = std::max(m, n); - int64_t lrwork; // These settings are valid for on LAPACK 3.6+ - if (jobz == 'N'){ - lrwork = 7 * mn; - }else if (mx > 10 * mn){ - lrwork = 7 * mn * mn + 7 * mn; - } else { - lrwork = std::max(7 * mn * mn + 7 * mn, 2 * mx * mn + 2 *mn * mn + mn); - } - rwork = at::empty({std::max(int64_t(1), lrwork)}, at::kInt); - rwork_data = rwork.data_ptr(); + auto lrwork = computeLRWorkDim(jobz, m, n); + // rwork is an array of floats or doubles depending on the type + rwork = at::empty({std::max(int64_t(1), lrwork)}, at::typeMetaToScalarType(S.dtype())); + rwork_data = rwork.data_ptr(); } // Run once, first to get the optimum work size. @@ -992,7 +1450,7 @@ static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT, value_t* S_working_ptr = &S_data[i * S_stride]; scalar_t* U_working_ptr = &U_data[i * U_stride]; scalar_t* VT_working_ptr = &VT_data[i * VT_stride]; - + // Compute S, U (optionally) and VT (optionally) lapackSvd(jobz, m, n, self_working_ptr, m, S_working_ptr, U_working_ptr, m, VT_working_ptr, n, work_data, lwork, rwork_data, iwork_data, &info); @@ -1008,7 +1466,7 @@ std::tuple _svd_helper_cpu(const Tensor& self, bool some std::vector infos(batchCount(self), 0); int64_t m = self.size(-2), n = self.size(-1); int64_t k = std::min(m, n); - + char jobz = compute_uv ? (some ? 'S' : 'A') : 'N'; Tensor U_working_copy, S_working_copy, VT_working_copy; @@ -1029,7 +1487,7 @@ std::tuple _svd_helper_cpu(const Tensor& self, bool some if (compute_uv) { if (some) { - VT_working_copy = VT_working_copy.narrow(-1, 0, k); + VT_working_copy = VT_working_copy.narrow(-2, 0, k); } } else { VT_working_copy.zero_(); @@ -1039,24 +1497,71 @@ std::tuple _svd_helper_cpu(const Tensor& self, bool some U_working_copy.zero_(); VT_working_copy.zero_(); } + // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly. + VT_working_copy.transpose_(-2, -1); return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } std::tuple svd(const Tensor& self, bool some, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, - "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + "svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); return at::_svd_helper(self, some, compute_uv); } -std::tuple svd_out(Tensor& U, Tensor& S, Tensor& VT, +std::tuple svd_out(Tensor& U, Tensor& S, Tensor& V, const Tensor& self, bool some, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, - "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); - Tensor U_tmp, S_tmp, VT_tmp; - std::tie(U_tmp, S_tmp, VT_tmp) = at::_svd_helper(self, some, compute_uv); + "svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + Tensor U_tmp, S_tmp, V_tmp; + std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv); U.resize_as_(U_tmp).copy_(U_tmp); S.resize_as_(S_tmp).copy_(S_tmp); - VT.resize_as_(VT_tmp).copy_(VT_tmp); + V.resize_as_(V_tmp).copy_(V_tmp); + return std::tuple(U, S, V); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/* torch.linalg.svd, implemented in terms of torch.svd. There are two main + differences: + + 1. the 2nd parameter is bool some=True, which if effectively the opposite + of full_matrices=True + + 2. svd returns V, while linalg.svd returns VT. To accommodate the + difference, we transpose() V upon return +*/ + +std::tuple linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) { + TORCH_CHECK(self.dim() >= 2, + "svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + + bool some = !full_matrices; + Tensor U, S, V; + std::tie(U, S, V) = at::_svd_helper(self, some, compute_uv); + if (compute_uv) { + Tensor VT = V.transpose(-2, -1); + return std::make_tuple(U, S, VT); + } else { + Tensor empty_U = at::empty({0}, self.options()); + Tensor empty_VT = at::empty({0}, self.options()); + return std::make_tuple(empty_U, S, empty_VT); + } +} + +static void svd_resize_and_copy(const char *name, const Tensor& src, Tensor &dst) { + TORCH_CHECK(src.device() == dst.device(), "svd output tensor ", name, " is on the wrong device: expected ", src.device(), " got ", dst.device()); + at::native::resize_output(dst, src.sizes()); + dst.copy_(src); +} + +std::tuple linalg_svd_out(Tensor& U, Tensor& S, Tensor& VT, + const Tensor& self, bool full_matrices, bool compute_uv) { + Tensor U_tmp, S_tmp, VT_tmp; + std::tie(U_tmp, S_tmp, VT_tmp) = at::linalg_svd(self, full_matrices, compute_uv); + svd_resize_and_copy("U", U_tmp, U); + svd_resize_and_copy("S", S_tmp, S); + svd_resize_and_copy("V", VT_tmp, VT); return std::tuple(U, S, VT); } @@ -1102,7 +1607,7 @@ Tensor _lu_solve_helper_cpu(const Tensor& self, const Tensor& LU_data, const Ten if (self.numel() == 0 || LU_data.numel() == 0) { return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_solve_cpu", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "lu_solve_cpu", [&]{ apply_lu_solve(self_working_copy, LU_data_working_copy, LU_pivots_working_copy, infos); }); if (self.dim() > 2) { diff --git a/aten/src/ATen/native/BatchLinearAlgebra.h b/aten/src/ATen/native/BatchLinearAlgebra.h new file mode 100644 index 0000000000000..95fc2c6097ce9 --- /dev/null +++ b/aten/src/ATen/native/BatchLinearAlgebra.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include + +#include // for USE_LAPACK + + +namespace at { namespace native { + +#ifdef USE_LAPACK +// Define per-batch functions to be used in the implementation of batched +// linear algebra operations + +template +void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *wr, scalar_t *wi, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info); + +#endif + +using eig_fn = std::tuple (*)(const Tensor&, bool&); + +DECLARE_DISPATCH(eig_fn, eig_stub); + +}} // namespace at::native diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp new file mode 100644 index 0000000000000..d251245c60c58 --- /dev/null +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -0,0 +1,77 @@ +#include +#include +#include +#include + +#include // for USE_LAPACK + +namespace at { namespace native { + +namespace { + +template +void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vecs_, int64_t* info_ptr) { +#ifndef USE_LAPACK + TORCH_CHECK(false, "Calling torch.eig on a CPU tensor requires compiling ", + "PyTorch with LAPACK. Please use PyTorch built with LAPACK support."); +#else + char jobvr = eigenvectors ? 'V' : 'N'; + int64_t n = self.size(-1); + auto self_data = self.data_ptr(); + + auto vals_data = vals_.data_ptr(); + scalar_t* wr = vals_data; + scalar_t* wi = vals_data + n; + + scalar_t* vecs_data = eigenvectors ? vecs_.data_ptr() : nullptr; + int ldvr = eigenvectors ? n : 1; + + if (n > 0) { + // call lapackEig once to get the optimal size for work data + scalar_t wkopt; + int info; + lapackEig('N', jobvr, n, self_data, n, wr, wi, + nullptr, 1, vecs_data, ldvr, &wkopt, -1, &info); + int lwork = static_cast(wkopt); + + // call again to do the actual work + Tensor work = at::empty({lwork}, self.dtype()); + lapackEig('N', jobvr, n, self_data, n, wr, wi, + nullptr, 1, vecs_data, ldvr, work.data_ptr(), lwork, &info); + *info_ptr = info; + } +#endif +} + +std::tuple eig_kernel_impl(const Tensor& self, bool& eigenvectors) { + int64_t n = self.size(-1); + // lapackEig function expects the input to be column major, or stride {1, n}, + // so we must set the stride manually since the default stride for tensors is + // row major, {n, 1} + Tensor self_ = at::empty_strided( + {n, n}, + {1, n}, + at::TensorOptions(self.dtype())); + self_.copy_(self); + + auto options = self.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT); + Tensor vals_ = at::empty_strided({n, 2}, {1, n}, options); + Tensor vecs_ = eigenvectors + ? at::empty_strided({n, n}, {1, n}, options) + : Tensor(); + + int64_t info; + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "eig_cpu", [&]{ + apply_eig(self_, eigenvectors, vals_, vecs_, &info); + }); + singleCheckErrors(info, "eig_cpu"); + return std::tuple(vals_, vecs_); +} + +} // anonymous namespace + +REGISTER_ARCH_DISPATCH(eig_stub, DEFAULT, &eig_kernel_impl); +REGISTER_AVX_DISPATCH(eig_stub, &eig_kernel_impl); +REGISTER_AVX2_DISPATCH(eig_stub, &eig_kernel_impl); + +}} // namespace at::native diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index f8af756773c91..47f75b392f9a0 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -11,6 +11,18 @@ #include namespace at { +namespace meta { + +TORCH_META_FUNC2(add, Tensor) ( + const Tensor& self, const Tensor& other, Scalar alpha +) { + build_binary_op(maybe_get_output(), self, other); + native::alpha_check(dtype(), alpha); +} + +} // namespace meta + + namespace native { DEFINE_DISPATCH(add_stub); @@ -40,14 +52,17 @@ DEFINE_DISPATCH(tanh_backward_stub); DEFINE_DISPATCH(maximum_stub); DEFINE_DISPATCH(minimum_stub); DEFINE_DISPATCH(fmod_stub); -DEFINE_DISPATCH(fmod_scalar_stub); DEFINE_DISPATCH(logaddexp_stub); DEFINE_DISPATCH(logaddexp2_stub); DEFINE_DISPATCH(gcd_stub); DEFINE_DISPATCH(lcm_stub); DEFINE_DISPATCH(hypot_stub); +DEFINE_DISPATCH(igamma_stub); +DEFINE_DISPATCH(igammac_stub); DEFINE_DISPATCH(nextafter_stub); DEFINE_DISPATCH(heaviside_stub); +DEFINE_DISPATCH(copysign_stub); +DEFINE_DISPATCH(xlogy_stub); static Tensor wrapped_scalar_tensor(Scalar scalar) { auto tensor = scalar_to_tensor(scalar); @@ -55,24 +70,11 @@ static Tensor wrapped_scalar_tensor(Scalar scalar) { return tensor; } -Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) { - auto iter = TensorIterator::binary_op(result, self, other); - alpha_check(iter.dtype(), alpha); - add_stub(iter.device_type(), iter, alpha); - TORCH_INTERNAL_ASSERT(result.scalar_type() == iter.output().dtype()); - return result; -} - -Tensor add(const Tensor& self, const Tensor& other, Scalar alpha) { - Tensor result; - auto iter = TensorIterator::binary_op(result, self, other); - alpha_check(iter.dtype(), alpha); - add_stub(iter.device_type(), iter, alpha); - return iter.output(); -} - -Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) { - return native::add_out(self, self, other, alpha); +TORCH_IMPL_FUNC(add_out) ( + const Tensor& self, const Tensor& other, Scalar alpha, Tensor& result +) { + add_stub(device_type(), *this, alpha); + TORCH_INTERNAL_ASSERT(result.scalar_type() == output().dtype()); } Tensor& add_relu_impl( @@ -121,7 +123,32 @@ Tensor& add_relu_(Tensor& self, const Tensor& other, Scalar alpha) { return add_relu_impl(self, self, other, alpha); } -Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) { +Tensor& copysign_out(Tensor& result, const Tensor& self, const Tensor& other) { + auto iter = TensorIterator::binary_float_op(result, self, other); + copysign_stub(iter.device_type(), iter); + return result; +} + +Tensor copysign(const Tensor& self, const Tensor& other) { + Tensor result; + auto iter = TensorIterator::binary_float_op(result, self, other); + copysign_stub(iter.device_type(), iter); + return iter.output(); +} + +Tensor& copysign_(Tensor& self, const Tensor& other) { + return native::copysign_out(self, self, other); +} + +Tensor copysign(const Tensor& self, Scalar other) { + return native::copysign(self, wrapped_scalar_tensor(other)); +} + +Tensor& copysign_(Tensor& self, Scalar other) { + return native::copysign_(self, wrapped_scalar_tensor(other)); +} + +Tensor& div_out(const Tensor& self, const Tensor& other, Tensor& result) { auto iter = TensorIterator::binary_float_op(result, self, other); div_stub(iter.device_type(), iter); return result; @@ -135,7 +162,7 @@ Tensor div(const Tensor& self, const Tensor& other) { } Tensor& div_(Tensor& self, const Tensor& other) { - return native::div_out(self, self, other); + return native::div_out(self, other, self); } // WARNING: There doesn't appear to be any testing for this function @@ -422,36 +449,27 @@ static Tensor wrapped_scalar_tensor_and_check_convert(Scalar scalar, Tensor tens return wrapped_scalar_tensor(scalar); } +// TODO: Make this structured to undo the perf regression from native:: removal +// in call here + Tensor add(const Tensor& self, Scalar other, Scalar alpha) { - return native::add(self, wrapped_scalar_tensor(other), alpha); + return at::add(self, wrapped_scalar_tensor(other), alpha); } Tensor& add_(Tensor& self, Scalar other, Scalar alpha) { - return native::add_(self, wrapped_scalar_tensor(other), alpha); + return self.add_(wrapped_scalar_tensor(other), alpha); } Tensor remainder(const Tensor& self, Scalar other) { - Tensor other_tensor = wrapped_scalar_tensor(other); - // FIXME: 'other' is converted to match the dtype of 'self' to retain - // BC with TH, but in the future, we should use normal type promotion, - // like in numpy - return native::remainder(self, other_tensor.toType(self.scalar_type())); + return native::remainder(self, wrapped_scalar_tensor(other)); } Tensor& remainder_(Tensor& self, Scalar other) { - Tensor other_tensor = wrapped_scalar_tensor(other); - // FIXME: 'other' is converted to match the dtype of 'self' to retain - // BC with TH, but in the future, we should use normal type promotion, - // like in numpy - return native::remainder_(self, other_tensor.toType(self.scalar_type())); + return native::remainder_(self, wrapped_scalar_tensor(other)); } Tensor& remainder_out(Tensor& result, const Tensor& self, Scalar other) { - Tensor other_tensor = wrapped_scalar_tensor(other); - // FIXME: 'other' is converted to match the dtype of 'self' to retain - // BC with TH, but in the future, we should use normal type promotion, - // like in numpy - return native::remainder_out(result, self, other_tensor.toType(self.scalar_type())); + return native::remainder_out(result, self, wrapped_scalar_tensor(other)); } Tensor rsub(const Tensor& self, Scalar other, Scalar alpha) { @@ -808,16 +826,12 @@ Tensor logical_xor(const Tensor& self, Scalar other) { return comparison_op(self Tensor& logical_xor_(Tensor& self, Scalar other) { return comparison_op_(self, other, static_cast(at::logical_xor_out)); } Tensor& maximum_out(Tensor& result, const Tensor& self, const Tensor& other) { - TORCH_CHECK(!self.is_complex() && !other.is_complex(), "maximum does not support complex inputs."); - auto iter = TensorIterator::binary_op(result, self, other); maximum_stub(iter.device_type(), iter); return result; } Tensor maximum(const Tensor& self, const Tensor& other) { - TORCH_CHECK(!self.is_complex() && !other.is_complex(), "maximum does not support complex inputs."); - Tensor result; auto iter = TensorIterator::binary_op(result, self, other); maximum_stub(iter.device_type(), iter); @@ -834,16 +848,12 @@ Tensor max(const Tensor& self, const Tensor& other) { } Tensor& minimum_out(Tensor& result, const Tensor& self, const Tensor& other) { - TORCH_CHECK(!self.is_complex() && !other.is_complex(), "minimum does not support complex inputs."); - auto iter = TensorIterator::binary_op(result, self, other); minimum_stub(iter.device_type(), iter); return result; } Tensor minimum(const Tensor& self, const Tensor& other) { - TORCH_CHECK(!self.is_complex() && !other.is_complex(), "minimum does not support complex inputs."); - Tensor result; auto iter = TensorIterator::binary_op(result, self, other); minimum_stub(iter.device_type(), iter); @@ -869,34 +879,31 @@ Tensor& floor_divide_(Tensor& self, Scalar other) { Tensor& fmod_out(Tensor & result, const Tensor& self, const Tensor& other) { auto iter = TensorIterator::binary_op(result, self, other); - TORCH_CHECK(iter.device_type() == at::kCPU, "Native fmod only supports CPU"); fmod_stub(iter.device_type(), iter); return result; } Tensor& fmod_out(Tensor & result, const Tensor& self, Scalar other) { - auto iter = TensorIterator::unary_op(result, self); - TORCH_CHECK(iter.device_type() == at::kCPU, "Native fmod only supports CPU"); - fmod_scalar_stub(iter.device_type(), iter, other); - return result; + return native::fmod_out(result, self, wrapped_scalar_tensor(other)); } Tensor fmod(const Tensor& self, const Tensor & other) { - Tensor result = at::empty({0}, self.options()); - return at::fmod_out(result, self, other); + Tensor result; + auto iter = TensorIterator::binary_op(result, self, other); + fmod_stub(iter.device_type(), iter); + return iter.output(); } Tensor fmod(const Tensor& self, Scalar other) { - Tensor result = at::empty({0}, self.options()); - return at::fmod_out(result, self, other); + return native::fmod(self, wrapped_scalar_tensor(other)); } Tensor& fmod_(Tensor& self, const Tensor& other) { - return at::fmod_out(self, self, other); + return native::fmod_out(self, self, other); } Tensor& fmod_(Tensor& self, Scalar other) { - return at::fmod_out(self, self, other); + return native::fmod_(self, wrapped_scalar_tensor(other)); } Tensor& logaddexp_out(Tensor& result, const Tensor& self, const Tensor& other) { @@ -968,6 +975,40 @@ Tensor& hypot_(Tensor& self, const Tensor& other) { return at::hypot_out(self, self, other); } +Tensor& igamma_out(Tensor& result, const Tensor& self, const Tensor& other) { + auto iter = TensorIterator::binary_op(result, self, other); + igamma_stub(iter.device_type(), iter); + return result; +} + +Tensor igamma(const Tensor& self, const Tensor& other) { + Tensor result; + auto iter = TensorIterator::binary_op(result, self, other); + igamma_stub(iter.device_type(), iter); + return iter.output(); +} + +Tensor& igamma_(Tensor& self, const Tensor& other) { + return at::igamma_out(self, self, other); +} + +Tensor& igammac_out(Tensor& result, const Tensor& self, const Tensor& other) { + auto iter = TensorIterator::binary_op(result, self, other); + igammac_stub(iter.device_type(), iter); + return result; +} + +Tensor igammac(const Tensor& self, const Tensor& other) { + Tensor result; + auto iter = TensorIterator::binary_op(result, self, other); + igammac_stub(iter.device_type(), iter); + return iter.output(); +} + +Tensor& igammac_(Tensor& self, const Tensor& other) { + return at::igammac_out(self, self, other); +} + Tensor& nextafter_out(Tensor& result, const Tensor& self, const Tensor& other) { auto iter = TensorIterator::binary_op(result, self, other); nextafter_stub(iter.device_type(), iter); @@ -1018,36 +1059,53 @@ Tensor& heaviside_(Tensor& self, const Tensor& values) { return at::heaviside_out(self, self, values); } -// TODO: Deduplicate this with the TensorIterator logic. This would -// also fix the TODOs below. -Tensor binary_op_meta(const Tensor& self, const Tensor& other) { - // TODO: Doesn't do type promotion correctly - // TODO: Doesn't do strides correctly - int64_t dim = std::max(self.dim(), other.dim()); - std::vector sizes(dim); - for (int64_t i = 0; i < dim; i++) { - int64_t j = -1 - i; - if (i >= self.dim() || self.size(j) == 1) { - sizes[dim + j] = other.size(j); - } else if (i >= other.dim() || self.size(i) == 1) { - sizes[dim + j] = self.size(j); - } else { - TORCH_CHECK( - self.size(j) == other.size(j), - "Expected self.size(", j, ") == other.size(", j, "), but got ", self.size(j), " != ", other.size(j) - ); - sizes[dim + j] = self.size(j); - } - } - return at::empty_meta(sizes, self.options()); +Tensor& ldexp_out(Tensor& result, const Tensor& self, const Tensor& other) { + return at::mul_out(result, self, at::pow(2.0, other)); +} + +Tensor ldexp(const Tensor& self, const Tensor& other) { + return at::mul(self, at::pow(2.0, other)); +} + +Tensor& ldexp_(Tensor& self, const Tensor& other) { + return at::ldexp_out(self, self, other); +} + +Tensor& xlogy_out(Tensor& result, const Tensor& self, const Tensor& other) { + auto iter = TensorIterator::binary_float_op(result, self, other); + xlogy_stub(iter.device_type(), iter); + return result; +} + +Tensor& xlogy_out(Tensor& result, Scalar self, const Tensor& other) { + return at::xlogy_out(result, wrapped_scalar_tensor(self), other); +} + +Tensor& xlogy_out(Tensor& result, const Tensor& self, Scalar other) { + return at::xlogy_out(result, self, wrapped_scalar_tensor(other)); +} + +Tensor xlogy(const Tensor& x, const Tensor& y) { + Tensor result; + auto iter = TensorIterator::binary_float_op(result, x, y); + xlogy_stub(iter.device_type(), iter); + return iter.output(); +} + +Tensor xlogy(Scalar x, const Tensor& y) { + return at::xlogy(wrapped_scalar_tensor(x), y); +} + +Tensor xlogy(const Tensor& x, Scalar y) { + return at::xlogy(x, wrapped_scalar_tensor(y)); } -Tensor binary_op_with_scalar_meta(const Tensor& self, const Tensor& other, Scalar x) { - return binary_op_meta(self, other); +Tensor& xlogy_(Tensor& x, const Tensor& y) { + return at::xlogy_out(x, x, y); } -TORCH_LIBRARY_IMPL(aten, Meta, m) { - m.impl("add.Tensor", binary_op_with_scalar_meta); +Tensor& xlogy_(Tensor& x, Scalar y) { + return at::xlogy_out(x, x, wrapped_scalar_tensor(y)); } } // namespace native diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h index e2dad35eb7ece..191611875f08f 100644 --- a/aten/src/ATen/native/BinaryOps.h +++ b/aten/src/ATen/native/BinaryOps.h @@ -10,7 +10,8 @@ namespace at { namespace native { inline void alpha_check(const ScalarType dtype, Scalar alpha) { TORCH_CHECK(! alpha.isBoolean() || dtype == ScalarType::Bool, "Boolean alpha only supported for Boolean results."); - TORCH_CHECK(isFloatingType(dtype) || alpha.isIntegral(true), + TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype) + || alpha.isIntegral(true), "For integral input tensors, argument alpha must not be a floating point number."); } @@ -24,12 +25,15 @@ inline void sub_check(const Tensor& self, const Tensor& other) { "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead."); } +using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, Scalar alpha); + using binary_fn_alpha = void(*)(TensorIterator&, Scalar alpha); +using binary_fn_beta = void(*)(TensorIterator&, double beta); using binary_fn = void(*)(TensorIterator&); using binary_clamp_fn_alpha = void(*)(TensorIterator&, Scalar alpha, Scalar min_val, Scalar max_val); -DECLARE_DISPATCH(binary_fn_alpha, add_stub); +DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub); DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub); DECLARE_DISPATCH(binary_fn_alpha, sub_stub); DECLARE_DISPATCH(binary_fn, mul_stub); @@ -54,19 +58,22 @@ DECLARE_DISPATCH(binary_fn, max_elementwise_stub); DECLARE_DISPATCH(binary_fn, min_elementwise_stub); DECLARE_DISPATCH(binary_fn, maximum_stub); DECLARE_DISPATCH(binary_fn, minimum_stub); -DECLARE_DISPATCH(binary_fn, smooth_l1_stub); +DECLARE_DISPATCH(binary_fn_beta, smooth_l1_stub); DECLARE_DISPATCH(binary_fn, sigmoid_backward_stub); DECLARE_DISPATCH(binary_fn_alpha, logit_backward_stub); DECLARE_DISPATCH(binary_fn, tanh_backward_stub); DECLARE_DISPATCH(binary_fn, mse_stub); DECLARE_DISPATCH(binary_fn, fmod_stub); -DECLARE_DISPATCH(binary_fn_alpha, fmod_scalar_stub); DECLARE_DISPATCH(binary_fn, logaddexp_stub); DECLARE_DISPATCH(binary_fn, logaddexp2_stub); DECLARE_DISPATCH(binary_fn, gcd_stub); DECLARE_DISPATCH(binary_fn, lcm_stub); DECLARE_DISPATCH(binary_fn, hypot_stub); +DECLARE_DISPATCH(binary_fn, igamma_stub); +DECLARE_DISPATCH(binary_fn, igammac_stub); DECLARE_DISPATCH(binary_fn, nextafter_stub); DECLARE_DISPATCH(binary_fn, heaviside_stub); +DECLARE_DISPATCH(binary_fn, copysign_stub); +DECLARE_DISPATCH(binary_fn, xlogy_stub); }} // namespace at::native diff --git a/aten/src/ATen/native/ComplexHelper.h b/aten/src/ATen/native/ComplexHelper.h index b8830691f47a3..3fde6dbb77e1c 100644 --- a/aten/src/ATen/native/ComplexHelper.h +++ b/aten/src/ATen/native/ComplexHelper.h @@ -4,12 +4,25 @@ namespace at { namespace native { -inline std::vector computeStrideForViewAsReal(IntArrayRef oldstride) { - auto res = oldstride.vec(); - for(size_t i = 0; i < res.size(); i++) { - res[i] = res[i] * 2; +// View tensor with new dtype, storage offset, sizes and strides +inline Tensor view_tensor( + const Tensor &tensor, ScalarType dtype, + int64_t offset, IntArrayRef sizes, IntArrayRef strides) { + Storage storage = tensor.storage(); + auto new_tensor = detail::make_tensor( + std::move(storage), tensor.key_set(), scalarTypeToTypeMeta(dtype)); + auto * impl = new_tensor.unsafeGetTensorImpl(); + impl->set_storage_offset(offset); + impl->set_sizes_and_strides(sizes, strides); + return new_tensor; +} + +inline DimVector computeStrideForViewAsReal(IntArrayRef oldstride) { + DimVector res(oldstride.size() + 1); + for(size_t i = 0; i < oldstride.size(); i++) { + res[i] = oldstride[i] * 2; } - res.emplace_back(1); + res.back() = 1; return res; } @@ -18,25 +31,25 @@ inline std::vector computeStrideForViewAsReal(IntArrayRef oldstride) { // in the last two dimensions Tensor view_as_real(const Tensor& self) { TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors"); - auto new_sizes = self.sizes().vec(); + auto old_sizes = self.sizes(); + DimVector new_sizes(old_sizes.size() + 1); + std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin()); // last dimension will always have two elements containing the real and imag vals - new_sizes.emplace_back(2); + new_sizes.back() = 2; auto new_strides = computeStrideForViewAsReal(self.strides()); auto new_storage_offset = 2 * self.storage_offset(); const auto float_type = c10::toValueType(self.scalar_type()); - return at::empty({0}, self.options().dtype(float_type)).set_(self.storage(), new_storage_offset, new_sizes, new_strides); + return view_tensor(self, float_type, new_storage_offset, new_sizes, new_strides); } -inline std::vector computeStrideForViewAsComplex(IntArrayRef oldstride) { - auto res = oldstride.vec(); - int dim = res.size(); - - TORCH_CHECK(res[dim-1] == 1, "Tensor must have a last dimension with stride 1"); - res.pop_back(); +inline DimVector computeStrideForViewAsComplex(IntArrayRef oldstride) { + const int64_t dim = oldstride.size(); + TORCH_CHECK(oldstride[dim-1] == 1, "Tensor must have a last dimension with stride 1"); - for (auto i = decltype(res.size()){0}; i < res.size(); i++) { - TORCH_CHECK(res[i] % 2 == 0, "Tensor must have a stride divisible by 2 for all but last dimension"); - res[i] = res[i] / 2; + DimVector res(dim - 1); + for (int64_t i = 0; i < res.size(); i++) { + TORCH_CHECK(oldstride[i] % 2 == 0, "Tensor must have a stride divisible by 2 for all but last dimension"); + res[i] = oldstride[i] / 2; } return res; } @@ -48,10 +61,10 @@ Tensor view_as_complex(const Tensor& self) { self.scalar_type() == kFloat || self.scalar_type() == kDouble || self.scalar_type() == kHalf, "view_as_complex is only supported for half, float and double tensors, but got a tensor of scalar type: ", self.scalar_type()); - TORCH_CHECK(self.dim() != 0, "Input tensor must have one or more dimensions"); - auto new_sizes = self.sizes().vec(); - TORCH_CHECK(new_sizes[self.dim()-1] == 2, "Tensor must have a last dimension of size 2"); - new_sizes.pop_back(); + auto old_sizes = self.sizes(); + TORCH_CHECK(old_sizes.size() != 0, "Input tensor must have one or more dimensions"); + TORCH_CHECK(old_sizes[old_sizes.size()-1] == 2, "Tensor must have a last dimension of size 2"); + DimVector new_sizes(old_sizes.begin(), old_sizes.end() - 1); const auto new_strides = computeStrideForViewAsComplex(self.strides()); const auto complex_type = c10::toComplexType(self.scalar_type()); @@ -59,7 +72,7 @@ Tensor view_as_complex(const Tensor& self) { TORCH_CHECK(self.storage_offset() % 2 == 0, "Tensor must have a storage_offset divisible by 2"); const auto new_storage_offset = self.storage_offset() / 2; - return at::empty({0}, self.options().dtype(complex_type)).set_(self.storage(), new_storage_offset, new_sizes, new_strides); + return view_tensor(self, complex_type, new_storage_offset, new_sizes, new_strides); } }} // namespace at::native diff --git a/aten/src/ATen/native/CompositeRandomAccessorCommon.h b/aten/src/ATen/native/CompositeRandomAccessorCommon.h index 256ae5b0d98fb..0be75d8244f03 100644 --- a/aten/src/ATen/native/CompositeRandomAccessorCommon.h +++ b/aten/src/ATen/native/CompositeRandomAccessorCommon.h @@ -122,6 +122,9 @@ class CompositeRandomAccessor { using difference_type = typename std::iterator_traits::difference_type; using iterator_category = std::random_access_iterator_tag; + C10_HOST_DEVICE + CompositeRandomAccessor() = default; + C10_HOST_DEVICE CompositeRandomAccessor(KeyAccessor keys, ValueAccessor values) : keys(keys), values(values) @@ -129,7 +132,7 @@ class CompositeRandomAccessor { // Pointer-like operations { C10_HOST_DEVICE - reference operator*() { + reference operator*() const { return TupleInfo::tie(*keys, *values); } diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index ea7903369e93a..66eca30809843 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -46,7 +46,7 @@ struct ConvParams { bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const; bool use_miopen(const at::Tensor& input, const at::Tensor& weight, bool bias_defined) const; bool use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const; - bool use_nnpack(const at::Tensor& input) const; + bool use_nnpack(const at::Tensor& input, const at::Tensor& weight) const; bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias) const; bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const; }; @@ -62,6 +62,7 @@ std::ostream& operator<<(std::ostream & out, const ConvParams& params) { << " benchmark = " << params.benchmark << " deterministic = " << params.deterministic << " cudnn_enabled = " << params.cudnn_enabled + << " allow_tf32 = " << params.allow_tf32 << "}"; return out; } @@ -101,7 +102,7 @@ auto ConvParams::is_output_padding_neg() const -> bool { auto ConvParams::is_output_padding_big() const -> bool { bool is_big = false; for (size_t i = 0; i < output_padding.size(); i++) { - is_big |= (output_padding[i] >= stride[i] || output_padding[i] >= dilation[i]); + is_big |= (output_padding[i] >= stride[i]); } return is_big; } @@ -176,14 +177,10 @@ auto ConvParams::needs_64bit_indexing_no_split(const at::Tensor& input, const at int64_t outsize = 1; if (transposed) { std::vector o = conv_input_size(input.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups); - for (int64_t i = 1; i < o.size(); i++) { - outsize *= o[i]; - } + outsize = prod_intlist(o.begin() + 1, o.end()); } else { std::vector o = conv_output_size(input.sizes(), weight.sizes(), padding, stride, dilation); - for (int64_t i = 1; i < o.size(); i++) { - outsize *= o[i]; - } + outsize = prod_intlist(o.begin() + 1, o.end()); } return outsize > int_max; } @@ -198,6 +195,9 @@ auto ConvParams::use_cudnn(const at::Tensor& input, const at::Tensor& weight) co if (!input.is_cuda() || !cudnn_enabled) { return false; } + if (input.scalar_type() == at::kBFloat16 || weight.scalar_type() == at::kBFloat16) { + return false; + } if (!cudnn_conv_use_channels_last(input, weight)) { // bypass dilation checks for channels-last convolution if (deterministic && is_dilated()) { // cudnn doesn't support deterministic dilated convolution fully yet @@ -233,20 +233,26 @@ auto ConvParams::use_mkldnn(const at::Tensor& input, const at::Tensor& weight) c (input.options().backend() == at::Backend::CPU && input.scalar_type() == kFloat && // only on CPU Float Tensors !transposed && // or transposed tensors - (groups > 1 || weight.size(2) > 3 || input.size(0) > 1 + (is_strided() || is_dilated() || input.size(0) >= 16 || + weight.size(-1) != 1 || weight.size(-2) != 1) && + (groups > 1 + || (weight.size(-1) > 3 && weight.size(-2) > 3) + || input.size(0) > 1 || input.size(0)*input.size(1)*input.size(2)*input.size(3) > 20480)); // for some case, native is faster #endif return false; } -auto ConvParams::use_nnpack(const at::Tensor& input) const -> bool { +auto ConvParams::use_nnpack(const at::Tensor& input, const at::Tensor& weight) const -> bool { #if AT_NNPACK_ENABLED() return at::_nnpack_available() && input.options().backend() == at::Backend::CPU && input.scalar_type() == kFloat && // only on CPU Float Tensors !is_dilated() && // or dilation !transposed && // or transposed tensors - input.ndimension() == 4 // must be in NCHW format + input.ndimension() == 4 && // must be in NCHW format + weight.ndimension() == 4 && + (weight.size(2) < 17) && (weight.size(3) < 17) // NNPACK only supports kernels up to 16x16 #if !defined(C10_MOBILE) && !defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY) && input.size(0) >= 16 // ensure large enough batch size to ensure perf, tuneable #endif @@ -845,7 +851,7 @@ at::Tensor _convolution_nogroup( input, weight, kernel_size, bias, stride, padding, dilation); } else { /* dim == 4, non-dilated */ - if (params.use_nnpack(input)) { + if (params.use_nnpack(input, weight)) { #if AT_NNPACK_ENABLED() return at::_nnpack_spatial_convolution( input, weight, bias, padding, stride); diff --git a/aten/src/ATen/native/ConvolutionMM2d.cpp b/aten/src/ATen/native/ConvolutionMM2d.cpp index 9bc6b476e2215..6a0ca1e679007 100644 --- a/aten/src/ATen/native/ConvolutionMM2d.cpp +++ b/aten/src/ATen/native/ConvolutionMM2d.cpp @@ -191,11 +191,10 @@ static void slow_conv2d_update_output_frame( output.reshape({n_output_plane, output_height * output_width}); if (bias.defined()) { output.copy_(bias.unsqueeze(-1).unsqueeze(-1)); + output2d.addmm_(weight, finput, 1, 1); } else { - output.zero_(); + output2d.addmm_(weight, finput, 0, 1); } - - output2d.addmm_(weight, finput, 1, 1); } void slow_conv2d_backward_update_grad_input_frame( @@ -434,16 +433,23 @@ std::tuple slow_conv2d_forward_out_cpu( const int64_t batch_size = input.size(0); - finput.resize_({batch_size, + if ((input.ndimension() == 4) && (kernel_height == 1) && (stride_height == 1) && (pad_height == 0) && + (kernel_width == 1) && (stride_width == 1) && (pad_width == 0)) { + finput = + input.view({batch_size, n_input_plane, output_height * output_width}) + .detach(); + } else { + finput.resize_({batch_size, n_input_plane * kernel_height * kernel_width, output_height * output_width}); + } output.resize_({batch_size, n_output_plane, output_height, output_width}); at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) { NoGradGuard no_grad; AutoNonVariableTypeMode non_variable_type_mode; for (int64_t t = start; t < end; t++) { - Tensor input_t = input[t]; + Tensor input_t = input[t].unsqueeze(0); Tensor output_t = output[t]; Tensor finput_t = finput[t]; slow_conv2d_update_output_frame( diff --git a/aten/src/ATen/native/ConvolutionMM3d.cpp b/aten/src/ATen/native/ConvolutionMM3d.cpp index d5a29a3abbe1f..95263617e2a8c 100644 --- a/aten/src/ATen/native/ConvolutionMM3d.cpp +++ b/aten/src/ATen/native/ConvolutionMM3d.cpp @@ -581,9 +581,15 @@ std::tuple slow_conv3d_forward_out_cpu( (input_width + 2 * pad_width - kernel_width) / stride_width + 1; const int64_t batch_size = input.size(0); - finput.resize_({batch_size, - n_input_plane * kernel_depth * kernel_height * kernel_width, - output_depth * output_height * output_width}); + if ((kernel_depth == 1) && (kernel_height == 1) && (kernel_width == 1) && + (pad_depth == 0) && (pad_height == 0) && (pad_width == 0) && + (stride_depth == 1) && (stride_height == 1) && (stride_width == 1) && (groups == 1)) { + finput = input.view({batch_size, n_input_plane, output_height * output_width * output_depth}).detach(); + } else { + finput.resize_({batch_size, + n_input_plane * kernel_depth * kernel_height * kernel_width, + output_depth * output_height * output_width}); + } output.resize_( {batch_size, n_output_plane, output_depth, output_height, output_width}); diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index 79fb0a11fba47..98f6ce65a5a31 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -5,12 +5,20 @@ #include #include #include +#include #include #include +#include #include #include +#include #include +#ifdef USE_FBGEMM +#include +#include +#endif + namespace { using namespace at; @@ -34,7 +42,7 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) { } Tensor buf = empty({BLOCK_SZ, BLOCK_SZ}, self.options()); - AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, self.scalar_type(), "copy_", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, self.scalar_type(), "copy_", [&] { scalar_t* sp = src.data_ptr(); scalar_t* rp = self.data_ptr(); scalar_t* bp = buf.data_ptr(); @@ -79,7 +87,7 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) { // (e.g. XLA) may be supported by overriding copy_ and _copy_from. bool is_supported_device(Device device) { DeviceType device_type = device.type(); - return device_type == kCPU || device_type == kCUDA || device_type == kHIP || device_type == kVulkan; + return device_type == kCPU || device_type == kCUDA || device_type == kHIP || device_type == kVulkan || device_type == kMetal; } } // namespace @@ -92,6 +100,47 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) TORCH_CHECK(self.defined(), "self is undefined"); TORCH_CHECK(src.defined(), "src is undefined"); + // FBGeMM kernel support exists only for the following case, + // 1. Memory Format for source and destination tensors is contiguous. + // 2. Device for both the source and destination tensor is CPU. + // 3. dtype conversion between FP32->FP16 and FP16->FP32. + #ifdef USE_FBGEMM + if (((self.dtype() == at::kFloat && src.dtype() == at::kHalf) || + (self.dtype() == at::kHalf && src.dtype() == at::kFloat)) && + (self.device().is_cpu() && src.device().is_cpu()) && + !self.is_sparse() && !src.is_sparse() && + ((self.is_contiguous() && src.is_contiguous()) || + (self.is_non_overlapping_and_dense() && self.strides() == src.strides()))) { + if (src.dtype() == at::kFloat && self.dtype() == at::kHalf) { + auto* output_ptr = + reinterpret_cast(self.data_ptr()); + at::parallel_for( + 0, + self.numel(), + at::internal::GRAIN_SIZE, + [&](int64_t begin, int64_t end) { + fbgemm::FloatToFloat16_simd( + src.data_ptr() + begin, + output_ptr + begin, + end - begin); + }); + } else { + auto in_data = reinterpret_cast( + src.data_ptr()); + auto* output_ptr = self.data_ptr(); + at::parallel_for( + 0, + self.numel(), + at::internal::GRAIN_SIZE, + [&](int64_t begin, int64_t end) { + fbgemm::Float16ToFloat_simd( + in_data + begin, output_ptr + begin, end - begin); + }); + } + return self; + } + #endif + if (self.is_sparse() && src.is_sparse()) { return at::copy_sparse_to_sparse_(self, src, non_blocking); } else if (self.is_sparse() || src.is_sparse()) { @@ -122,7 +171,7 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) TORCH_CHECK(self.qscheme() == src.qscheme(), "Quantized Copy only works with same qscheme"); TORCH_CHECK(self.scalar_type() == src.scalar_type()); - self.set_quantizer_(src.quantizer()); + set_quantizer_(self, src.quantizer()); } if (!self.is_quantized() && src.is_quantized()) { @@ -130,7 +179,15 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) } if (self.device().type() == at::kVulkan || src.device().type() == at::kVulkan) { + #ifdef USE_VULKAN_API + return vulkan::ops::copy_(self, src); + #else return at::vulkan::vulkan_copy_(self, src); + #endif + } + + if (self.device().type() == at::kMetal || src.device().type() == at::kMetal) { + return at::metal::metal_copy_(self, src); } auto iter = TensorIteratorConfig() diff --git a/aten/src/ATen/native/DilatedMaxPool2d.cpp b/aten/src/ATen/native/DilatedMaxPool2d.cpp index 53d1a5f579f38..b97ec9a5893bb 100644 --- a/aten/src/ATen/native/DilatedMaxPool2d.cpp +++ b/aten/src/ATen/native/DilatedMaxPool2d.cpp @@ -169,7 +169,7 @@ void max_pool2d_with_indices_out_cpu_template( kH, kW, dH, dW, padH, padW, dilationH, dilationW, nInputPlane, inputHeight, inputWidth, - outputHeight, outputWidth); + outputHeight, outputWidth, input_.suggest_memory_format()); /* get contiguous input */ Tensor input = input_.contiguous(); @@ -360,7 +360,8 @@ Tensor& max_pool2d_with_indices_backward_out_cpu_template( kH, kW, dH, dW, padH, padW, dilationH, dilationW, nInputPlane, inputHeight, inputWidth, - outputHeight_for_shape_check, outputWidth_for_shape_check); + outputHeight_for_shape_check, outputWidth_for_shape_check, + input.suggest_memory_format()); /* backprop */ if (input.ndimension() == 3) diff --git a/aten/src/ATen/native/DispatchStub.cpp b/aten/src/ATen/native/DispatchStub.cpp index d4c106477fe76..0c562f3637310 100644 --- a/aten/src/ATen/native/DispatchStub.cpp +++ b/aten/src/ATen/native/DispatchStub.cpp @@ -11,12 +11,18 @@ namespace at { namespace native { static CPUCapability compute_cpu_capability() { auto envar = std::getenv("ATEN_CPU_CAPABILITY"); if (envar) { +#ifdef HAVE_VSX_CPU_DEFINITION + if (strcmp(envar, "vsx") == 0) { + return CPUCapability::VSX; + } +#else if (strcmp(envar, "avx2") == 0) { return CPUCapability::AVX2; } if (strcmp(envar, "avx") == 0) { return CPUCapability::AVX; } +#endif if (strcmp(envar, "default") == 0) { return CPUCapability::DEFAULT; } @@ -33,7 +39,11 @@ static CPUCapability compute_cpu_capability() { } } #endif +#ifdef HAVE_VSX_CPU_DEFINITION + return CPUCapability::VSX; +#else return CPUCapability::DEFAULT; +#endif } CPUCapability get_cpu_capability() { diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h index dc21a505e8c11..b5de0f589de77 100644 --- a/aten/src/ATen/native/DispatchStub.h +++ b/aten/src/ATen/native/DispatchStub.h @@ -3,7 +3,9 @@ #include #include #include + #include +#include // Implements instruction set specific function dispatch. // @@ -45,18 +47,22 @@ namespace at { namespace native { enum class CPUCapability { DEFAULT = 0, +#ifdef HAVE_VSX_CPU_DEFINITION + VSX = 1, +#else AVX = 1, AVX2 = 2, +#endif NUM_OPTIONS }; CPUCapability get_cpu_capability(); template -struct CAFFE2_API DispatchStub; +struct TORCH_API DispatchStub; template -struct CAFFE2_API DispatchStub { +struct TORCH_API DispatchStub { using FnPtr = rT (*) (Args...); DispatchStub() = default; @@ -99,6 +105,12 @@ struct CAFFE2_API DispatchStub { AT_ASSERTM(AVX, "DispatchStub: missing AVX kernel"); return AVX; } +#endif +#ifdef HAVE_VSX_CPU_DEFINITION + if (capability >= static_cast(CPUCapability::VSX)) { + AT_ASSERTM(VSX, "DispatchStub: missing VSX kernel"); + return VSX; + } #endif AT_ASSERTM(DEFAULT, "DispatchStub: missing default kernel"); return DEFAULT; @@ -122,6 +134,9 @@ struct CAFFE2_API DispatchStub { #ifdef HAVE_AVX2_CPU_DEFINITION static FnPtr AVX2; #endif +#ifdef HAVE_VSX_CPU_DEFINITION + static FnPtr VSX; +#endif }; namespace { @@ -152,7 +167,7 @@ struct RegisterHIPDispatch { name(const name&) = delete; \ name& operator=(const name&) = delete; \ }; \ - extern CAFFE2_API struct name name + extern TORCH_API struct name name #define DEFINE_DISPATCH(name) struct name name @@ -171,10 +186,17 @@ struct RegisterHIPDispatch { #define REGISTER_AVX2_DISPATCH(name, fn) #endif +#ifdef HAVE_VSX_CPU_DEFINITION +#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn) +#else +#define REGISTER_VSX_DISPATCH(name, fn) +#endif + #define REGISTER_NO_CPU_DISPATCH(name, fn_type) \ REGISTER_ARCH_DISPATCH(name, DEFAULT, static_cast(nullptr)) \ REGISTER_AVX_DISPATCH(name, static_cast(nullptr)) \ - REGISTER_AVX2_DISPATCH(name, static_cast(nullptr)) + REGISTER_AVX2_DISPATCH(name, static_cast(nullptr)) \ + REGISTER_VSX_DISPATCH(name, static_cast(nullptr)) #define REGISTER_CUDA_DISPATCH(name, fn) \ static RegisterCUDADispatch name ## __register(name, fn); diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp index b2b760513a1d1..91d804687290e 100644 --- a/aten/src/ATen/native/Distance.cpp +++ b/aten/src/ATen/native/Distance.cpp @@ -27,7 +27,7 @@ Tensor pdist(const Tensor& self, const double p) { Tensor _euclidean_dist(const Tensor& x1, const Tensor& x2) { /** This function does the fist part of the euclidean distance calculation - * We divide it in two steps to simplify dealing with subgradients in the + * We divide it in two steps to simplify dealing with subgradients in the * backward step */ Tensor x1_norm = x1.pow(2).sum(-1, true); Tensor x1_pad = at::ones_like(x1_norm, LEGACY_CONTIGUOUS_MEMORY_FORMAT); @@ -74,7 +74,7 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, c10 std::vector tensor2_expand_size(expand_batch_portion); tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2}); - int expand_batch_product = std::accumulate(expand_batch_portion.begin(), expand_batch_portion.end(), 1, std::multiplies()); + const int64_t expand_batch_product = prod_intlist(expand_batch_portion); std::vector tensor1_view{expand_batch_product, r1, c1}; std::vector tensor2_view{expand_batch_product, r2, c2}; @@ -147,8 +147,10 @@ Tensor _cdist_backward(const Tensor& grad, const Tensor& x1, const Tensor& x2, c auto device2 = x2.device().type(); TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X2 got: ", device2); IntArrayRef batch_tensor1(x1.sizes().data(), std::max(x1.dim() - 2, 0)); - int batch_product = std::accumulate(batch_tensor1.begin(), batch_tensor1.end(), 1, std::multiplies()); - Tensor grad_x1 = at::empty_like(x1, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT).view({batch_product, n, m}); + const int64_t batch_product = prod_intlist(batch_tensor1); + Tensor grad_x1 = + at::empty_like(x1, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT) + .view({batch_product, n, m}); cdist_backward_stub(device1, grad_x1, grad, x1, x2, p, cdist); return grad_x1; } diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp index 7f2ceb267efca..413ea32acdef2 100644 --- a/aten/src/ATen/native/Distributions.cpp +++ b/aten/src/ATen/native/Distributions.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -117,7 +118,7 @@ DEFINE_DISPATCH(bernoulli_tensor_stub); DEFINE_DISPATCH(bernoulli_scalar_stub); DEFINE_DISPATCH(cauchy_stub); DEFINE_DISPATCH(exponential_stub); -DEFINE_DISPATCH(multinomial_stub); +DEFINE_DISPATCH(multinomial_with_replacement_stub); DEFINE_DISPATCH(geometric_stub); DEFINE_DISPATCH(log_normal_stub); DEFINE_DISPATCH(uniform_stub); @@ -300,24 +301,31 @@ Tensor& random_(Tensor& self, int64_t to, c10::optional gen) { Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) { Tensor ret = at::empty(self.sizes(), self.options()); + auto iter = TensorIteratorConfig() + .add_output(ret) + .add_input(self) + .add_input(output) + .build(); AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "_standard_gamma_grad_cpu", [&] { - CPU_tensor_apply3(ret, self, output, - [](scalar_t& ret_val, const scalar_t& self_val, const scalar_t &output_val) { - ret_val = standard_gamma_grad_one(self_val, output_val); - } - ); + cpu_serial_kernel(iter, [](scalar_t self_val, scalar_t output_val) -> scalar_t{ + return standard_gamma_grad_one(self_val, output_val); + }); }); return ret; } Tensor _dirichlet_grad_cpu(const Tensor& x, const Tensor& alpha, const Tensor& total) { Tensor ret = at::empty(x.sizes(), x.options()); + auto iter = TensorIteratorConfig() + .add_output(ret) + .add_input(x) + .add_input(alpha) + .add_input(total) + .build(); AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "_dirichlet_grad_cpu", [&] { - CPU_tensor_apply4(ret, x, alpha, total, - [](scalar_t& ret_val, const scalar_t& x_val, const scalar_t& alpha_val, const scalar_t& total_val) { - ret_val = dirichlet_grad_one(x_val, alpha_val, total_val); - } - ); + cpu_serial_kernel(iter, [](scalar_t x_val, scalar_t alpha_val, scalar_t total_val) -> scalar_t{ + return dirichlet_grad_one(x_val, alpha_val, total_val); + }); }); return ret; } @@ -328,67 +336,72 @@ Tensor _dirichlet_grad_cpu(const Tensor& x, const Tensor& alpha, const Tensor& t Tensor _s_binomial_cpu(const Tensor& count, const Tensor& prob, c10::optional gen) { Tensor ret = at::zeros(count.sizes(), count.options()); + auto iter = TensorIteratorConfig() + .add_output(ret) + .add_input(count) + .add_input(prob) + .build(); AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "binomial_cpu", [&] { CPUGeneratorImpl* generator = get_generator_or_default(gen, detail::getDefaultCPUGenerator()); // See Note [Acquire lock when using random generators] std::lock_guard lock(generator->mutex_); - CPU_tensor_apply3(ret, count, prob, - [generator](scalar_t& ret_val, const scalar_t& count, const scalar_t& prob){ - - auto uniform_lambda = [generator] () { - at::uniform_real_distribution standard_uniform(0.0, 1.0); - return standard_uniform(generator); - }; - BaseSampler standard_uniform(uniform_lambda); - - auto sample = sample_binomial(count, prob, standard_uniform); - ret_val = static_cast(sample); - } - ); + cpu_serial_kernel(iter, [generator](scalar_t count_val, scalar_t prob_val) -> scalar_t{ + auto uniform_lambda = [generator] () { + at::uniform_real_distribution standard_uniform(0.0, 1.0); + return standard_uniform(generator); + }; + BaseSampler standard_uniform(uniform_lambda); + + auto sample = sample_binomial(count_val, prob_val, standard_uniform); + return static_cast(sample); }); + }); return ret; } Tensor _s_poisson_cpu(const Tensor& lambda, c10::optional gen) { Tensor ret = at::zeros(lambda.sizes(), lambda.options()); + auto iter = TensorIteratorConfig() + .add_output(ret) + .add_input(lambda) + .build(); AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "poisson_cpu", [&] { CPUGeneratorImpl* generator = get_generator_or_default(gen, detail::getDefaultCPUGenerator()); // See Note [Acquire lock when using random generators] std::lock_guard lock(generator->mutex_); - CPU_tensor_apply2(ret, lambda, - [generator](scalar_t& ret_val, const scalar_t& lambda){ - ret_val = static_cast(sample_poisson(static_cast(lambda), generator)); - } - ); + cpu_serial_kernel(iter, [generator](scalar_t lambda_val) -> scalar_t{ + return static_cast(sample_poisson(static_cast(lambda_val), generator)); }); + }); return ret; } Tensor _s_gamma_cpu(const Tensor& alpha, c10::optional gen) { Tensor ret = at::zeros(alpha.sizes(), alpha.options()); + auto iter = TensorIteratorConfig() + .add_output(ret) + .add_input(alpha) + .build(); AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "gamma_cpu", [&] { CPUGeneratorImpl* generator = get_generator_or_default(gen, detail::getDefaultCPUGenerator()); // See Note [Acquire lock when using random generators] std::lock_guard lock(generator->mutex_); - CPU_tensor_apply2(ret, alpha, - [generator](scalar_t& ret_val, const scalar_t& alpha){ - - auto uniform_lambda = [generator] () { - at::uniform_real_distribution standard_uniform(0.0, 1.0); - return standard_uniform(generator); - }; - BaseSampler standard_uniform(uniform_lambda); - - auto normal_lambda = [generator] () { - at::normal_distribution normal(0.0, 1.0); - return normal(generator); - }; - BaseSampler standard_normal(normal_lambda); - auto sample = sample_gamma(alpha, standard_uniform, standard_normal); - ret_val = std::max(std::numeric_limits::min(), (scalar_t) sample); - } - ); + cpu_serial_kernel(iter, [generator](scalar_t alpha_val) -> scalar_t{ + auto uniform_lambda = [generator] () { + at::uniform_real_distribution standard_uniform(0.0, 1.0); + return standard_uniform(generator); + }; + BaseSampler standard_uniform(uniform_lambda); + + auto normal_lambda = [generator] () { + at::normal_distribution normal(0.0, 1.0); + return normal(generator); + }; + BaseSampler standard_normal(normal_lambda); + auto sample = sample_gamma(alpha_val, standard_uniform, standard_normal); + return std::max(std::numeric_limits::min(), (scalar_t) sample); }); + }); return ret; } @@ -401,35 +414,41 @@ Tensor _s_dirichlet_cpu(const Tensor& alpha, c10::optional gen) { // See Note [Acquire lock when using random generators] std::lock_guard lock(generator->mutex_); /* Generate gamma sample by casting alpha to double to prevent underflow. */ - CPU_tensor_apply2(gamma, alpha, - [generator](double& ret_val, const scalar_t& alpha){ - auto uniform_lambda = [generator] () { - at::uniform_real_distribution standard_uniform(0.0, 1.0); - return standard_uniform(generator); - }; - BaseSampler standard_uniform(uniform_lambda); - - auto normal_lambda = [generator] () { - at::normal_distribution normal(0.0, 1.0); - return normal(generator); - }; - BaseSampler standard_normal(normal_lambda); - auto sample = sample_gamma - (alpha, standard_uniform, standard_normal); - ret_val = std::max(std::numeric_limits::min(), sample); - } - ); + auto iter1 = TensorIteratorConfig() + .add_output(gamma) + .add_input(alpha) + .check_all_same_dtype(false) + .build(); + cpu_serial_kernel(iter1, [generator](scalar_t alpha_val) -> double{ + auto uniform_lambda = [generator] () { + at::uniform_real_distribution standard_uniform(0.0, 1.0); + return standard_uniform(generator); + }; + BaseSampler standard_uniform(uniform_lambda); + + auto normal_lambda = [generator] () { + at::normal_distribution normal(0.0, 1.0); + return normal(generator); + }; + BaseSampler standard_normal(normal_lambda); + auto sample = sample_gamma + (alpha_val, standard_uniform, standard_normal); + return std::max(std::numeric_limits::min(), sample); + }); /* Normalize and cast back to scalar_t. */ Tensor gamma_sum = gamma.sum(-1, true).expand(alpha.sizes()); - CPU_tensor_apply3(ret, gamma, gamma_sum, - [](scalar_t& ret_val, const double& gamma, const double& gamma_sum){ - ret_val = gamma / gamma_sum; - auto min_val = std::numeric_limits::min(); - auto max_val = std::nexttoward(static_cast(1.0f), 0.0f); - ret_val = std::min(max_val, std::max(min_val, ret_val)); - ret_val = static_cast(ret_val); - } - ); + auto iter2 = TensorIteratorConfig() + .add_output(ret) + .add_input(gamma) + .add_input(gamma_sum) + .check_all_same_dtype(false) + .build(); + cpu_serial_kernel(iter2, [](double gamma_val, double gamma_sum_val) -> scalar_t{ + auto ret_val = gamma_val / gamma_sum_val; + auto min_val = std::numeric_limits::min(); + auto max_val = std::nexttoward(static_cast(1.0f), 0.0f); + return std::min(max_val, std::max(min_val, static_cast(ret_val))); + }); }); return ret; } @@ -437,11 +456,21 @@ Tensor _s_dirichlet_cpu(const Tensor& alpha, c10::optional gen) { /* The largest consecutive integer representable in float32 (2^24) */ constexpr int64_t FLOAT32_MAX_CONSECUTIVE_INT = 1 << (FLT_MANT_DIG); -Tensor& multinomial_out(Tensor& result, const Tensor& self, int64_t n_sample, bool with_replacement, c10::optional gen) { - TORCH_CHECK(result.device() == self.device(), "multinomial arguments must have the same device"); - TORCH_CHECK(self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim"); - TORCH_CHECK(at::isFloatingType(self.scalar_type()), - "multinomial only supports floating-point dtypes for input, got: ", self.scalar_type()); +Tensor& multinomial_out( + Tensor& result, + const Tensor& self, + int64_t n_sample, + bool with_replacement, + c10::optional gen) { + TORCH_CHECK( + result.device() == self.device(), + "multinomial arguments must have the same device"); + TORCH_CHECK( + self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim"); + TORCH_CHECK( + at::isFloatingType(self.scalar_type()), + "multinomial only supports floating-point dtypes for input, got: ", + self.scalar_type()); TORCH_CHECK(result.scalar_type() == ScalarType::Long, "multinomial expects Long tensor out, got: ", result.scalar_type()); TORCH_CHECK(n_sample > 0, "cannot sample n_sample <= 0 samples"); @@ -450,42 +479,76 @@ Tensor& multinomial_out(Tensor& result, const Tensor& self, int64_t n_sample, bo "cannot sample n_sample > prob_dist.size(-1) samples without replacement"); // Since the index tensor is float, numCategories cannot exceed max // float integer precision - TORCH_CHECK(n_categories <= FLOAT32_MAX_CONSECUTIVE_INT, "number of categories cannot exceed 2^24"); - if (self.dim() > 1) { - int64_t n_dist = self.size(-2); - result.resize_({n_dist, n_sample}); - if (n_dist == 0) { return result; }; - } else { + TORCH_CHECK( + n_categories <= FLOAT32_MAX_CONSECUTIVE_INT, + "number of categories cannot exceed 2^24"); + + if (self.dim() == 1) { result.resize_({n_sample}); + } else { + const int64_t n_dist = self.size(0); + result.resize_({n_dist, n_sample}); } - // Fast-path based on RobertoLat example. + if (result.numel() == 0) { + return result; + } + + // Fast-path for no replacement. // Reference: // https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503 // Half is not supported on CPU. - if (!with_replacement && - !(self.device().is_cpu() && self.scalar_type() == ScalarType::Half)) { - if (result.numel()==0) return result; + TORCH_CHECK( + !(self.device().is_cpu() && self.scalar_type() == ScalarType::Half), + "multinomial is not implemented for half on CPU"); + if (!with_replacement) { // Sanity checks on `self`. auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item(); - TORCH_CHECK(is_valid.to(), "probability tensor contains either `inf`, `nan` or element < 0"); + TORCH_CHECK( + is_valid.to(), + "probability tensor contains either `inf`, `nan` or element < 0"); bool zero_prob_condition; if (self.dim() == 1){ zero_prob_condition = (self.sum() == 0).item().to(); } else { zero_prob_condition = (self.sum(1) == 0).sum().item().to(); } - TORCH_CHECK(!zero_prob_condition, "invalid multinomial distribution (sum of probabilities <= 0)"); - auto rand = at::empty_like(self).uniform_(0, 1, gen); - rand.log_().div_(self); //save memory with inplace operations - auto vals = at::empty(result.sizes(), self.options()); - at::topk_out(vals, result, rand, n_sample); + TORCH_CHECK( + !zero_prob_condition, + "invalid multinomial distribution (sum of probabilities <= 0)"); + + // The algorithm is from gumbel softmax. + // s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1) + // Here we can apply exp to the formula which will not affect result of + // argmax or topk. Then we have + // s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1). + // We can also simplify the formula above by + // s = argmax( p / q ) where q ~ Exp(1) + Tensor q = at::empty_like(self).exponential_(1, gen); + // In theory the probability to generate 0 from exponential distribution is + // 0. However, on CUDA side there is a protection to avoid 0s, but on CPU + // side, there is a very low probability to generate 0 from + // exponential. The probability is about 2^(-DBL_MANT_DIG). We just + // ignore it here, but there may be some risk to get invalid output on CPU. + at::div_out(q, self, q); + if (n_sample == 1) { + at::argmax_out(result, q, /*dim=*/-1, /*keepdim=*/true); + } else { + Tensor vals = at::empty(result.sizes(), self.options()); + at::topk_out(vals, result, q, n_sample); + } return result; } - multinomial_stub(result.device().type(), result, self, n_sample, with_replacement, gen); + + multinomial_with_replacement_stub( + result.device().type(), result, self, n_sample, gen); return result; } -Tensor multinomial(const Tensor& self, int64_t n_sample, bool with_replacement, c10::optional gen) { +Tensor multinomial( + const Tensor& self, + int64_t n_sample, + bool with_replacement, + c10::optional gen) { Tensor result = at::empty({0}, self.options().dtype(kLong)); native::multinomial_out(result, self, n_sample, with_replacement, gen); return result; diff --git a/aten/src/ATen/native/Embedding.cpp b/aten/src/ATen/native/Embedding.cpp index 3f250ae09909a..a4854e1ced4d9 100644 --- a/aten/src/ATen/native/Embedding.cpp +++ b/aten/src/ATen/native/Embedding.cpp @@ -15,18 +15,29 @@ Tensor embedding(const Tensor & weight, const Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) { TORCH_CHECK(weight.dim() >= 1, "'weight' must be at least 1-D"); auto indices_arg = TensorArg(indices, "indices", 1); - checkScalarType("embedding", indices_arg, kLong); + checkScalarTypes("embedding", indices_arg, {kLong, kInt}); + + auto zerofill_padding = [&](Tensor& embedding) { + if (padding_idx >= 0) { + embedding.masked_fill_((indices == padding_idx).reshape({-1, 1}), 0); + } + }; // TODO: use tensor.index() after improving perf if (indices.dim() == 1) { - return weight.index_select(0, indices); + auto out = weight.index_select(0, indices); + zerofill_padding(out); + return out; } auto size = indices.sizes().vec(); for (auto d : weight.sizes().slice(1)) { size.push_back(d); } - return weight.index_select(0, indices.reshape(-1)).view(size); + + auto out = weight.index_select(0, indices.reshape(-1)); + zerofill_padding(out); + return out.view(size); } Tensor embedding_backward( @@ -46,7 +57,7 @@ Tensor embedding_sparse_backward( int64_t padding_idx, bool scale_grad_by_freq) { auto indices_arg = TensorArg(indices_, "indices", 2); - checkScalarType("embedding_backward", indices_arg, kLong); + checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt}); // TODO: implement scale_grad_by_freq if (scale_grad_by_freq) { @@ -57,7 +68,7 @@ Tensor embedding_sparse_backward( Tensor indices = indices_; Tensor grad = grad_; if (padding_idx != -1) { - auto c = indices != padding_idx; + torch::List> c({indices != padding_idx}); indices = indices.index(c); grad = grad.index(c); } @@ -68,14 +79,14 @@ Tensor embedding_sparse_backward( // check if all our grad come from padding_idx if (grad.numel() == 0) { - return at::_sparse_coo_tensor_unsafe(at::empty({1, 0}, indices_.options()), + return at::_sparse_coo_tensor_unsafe(at::empty({1, 0}, indices_.options().dtype(kLong)), at::empty({0, num_features}, dense_options), weight_size); } auto index = indices.reshape({1, -1}); auto values = grad.reshape({-1, num_features}); - return at::_sparse_coo_tensor_unsafe(index, values, weight_size); + return at::_sparse_coo_tensor_unsafe(index.to(kLong), values, weight_size); } Tensor embedding_dense_backward_cpu( @@ -83,50 +94,48 @@ Tensor embedding_dense_backward_cpu( int64_t padding_idx, bool scale_grad_by_freq) { auto indices_arg = TensorArg(indices, "indices", 2); - checkScalarType("embedding_backward", indices_arg, kLong); + checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt}); + auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options()); auto indices_contig = indices.contiguous(); - auto indices_data = indices_contig.data_ptr(); int64_t numel = indices.numel(); + auto grad = grad_.contiguous().view({numel, grad_.size(-1)}); - std::unique_ptr counts; - if (scale_grad_by_freq) { - counts.reset(new int64_t[num_weights]); - for (int i = 0; i < numel; i++) { - counts[indices_data[i]] = 0; - } - for (int i = 0; i < numel; i++) { - counts[indices_data[i]]++; - } - } + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cpu", [&] () { + auto indices_data = indices_contig.data_ptr(); - auto grad = grad_.contiguous().view({numel, grad_.size(-1)}); - auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options()); + std::unique_ptr counts; + if (scale_grad_by_freq) { + counts.reset(new index_t[num_weights]); + for (int i = 0; i < numel; i++) { + counts[indices_data[i]] = 0; + } + for (int i = 0; i < numel; i++) { + counts[indices_data[i]]++; + } + } - auto parallel_section = [&](int64_t start, int64_t end) { - for (int64_t i = 0; i < numel; i++) { - if (indices_data[i] != padding_idx) { - int64_t k = indices_data[i]; - if (k >= start && k < end) { - double scale = 1.0; - if (scale_grad_by_freq) { - scale /= counts[k]; + auto parallel_section = [&](index_t start, index_t end) { + for (int64_t i = 0; i < numel; i++) { + if (indices_data[i] != padding_idx) { + index_t k = indices_data[i]; + if (k >= start && k < end) { + double scale = 1.0; + if (scale_grad_by_freq) { + scale /= counts[k]; + } + grad_weight[k].add_(grad[i], scale); } - grad_weight[k].add_(grad[i], scale); } } - } - }; + }; - if (numel > 1000) { - // The strategy is to parallelize over sections of the vocabulary, so that - // thread 1 handles updates to gradWeight[0..nVocab/nThreads]. Every thread - // has to traverse the entire input, but the dominating factor is the axpy - // BLAS call. - at::parallel_for(0, num_weights, 0, parallel_section); - } else { - parallel_section(0, num_weights); - } + if (numel > 1000) { + at::parallel_for(0, num_weights, 0, parallel_section); + } else { + parallel_section(0, num_weights); + } + }); return grad_weight; } @@ -136,28 +145,30 @@ Tensor & embedding_renorm_cpu_( auto self_arg = TensorArg(self, "self", 1); auto indices_arg = TensorArg(indices, "indices", 2); checkDim("embedding_renorm_", self_arg, 2); - checkScalarType("embedding_renorm_", indices_arg, kLong); + checkScalarTypes("embedding_renorm_", indices_arg, {kLong, kInt}); auto indices_contig = indices.contiguous(); - auto num_indices = indices.numel(); - auto data_ptr = indices_contig.data_ptr(); - auto sorted_indices = std::vector(data_ptr, data_ptr + num_indices); - std::sort(sorted_indices.begin(), sorted_indices.end(), std::less()); - - // Note that we cannot use at::parallel_for here because we perform operations on - // Tensor inside the loop. See github.com/pytorch/pytorch/issues/28370 for more details. - for (auto i = 0; i < num_indices; i++) { - if (i > 0 && sorted_indices[i] == sorted_indices[i - 1]) { - continue; - } - auto row = self[sorted_indices[i]]; - auto norm = row.norm(norm_type).item(); - if (norm > max_norm) { - auto scale = max_norm / (norm + 1e-7); - row *= scale; + + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cpu_", [&]() { + auto data_ptr = indices_contig.data_ptr(); + auto sorted_indices = std::vector(data_ptr, data_ptr + num_indices); + std::sort(sorted_indices.begin(), sorted_indices.end()); + + // Note that we cannot use at::parallel_for here because we perform operations on + // Tensor inside the loop. See github.com/pytorch/pytorch/issues/28370 for more details. + for (auto i = 0; i < num_indices; i++) { + if (i > 0 && sorted_indices[i] == sorted_indices[i - 1]) { + continue; + } + auto row = self[sorted_indices[i]]; + auto norm = row.norm(norm_type).item(); + if (norm > max_norm) { + auto scale = max_norm / (norm + 1e-7); + row *= scale; + } } - } + }); return self; } diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp index a0b2a37ed6d9e..ef318285ed4ec 100644 --- a/aten/src/ATen/native/EmbeddingBag.cpp +++ b/aten/src/ATen/native/EmbeddingBag.cpp @@ -32,11 +32,11 @@ namespace native { template scalar_t dot_impl(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy); -static void make_offset2bag(const Tensor &offsets, const Tensor &indices, Tensor& offset2bag) { +static void make_offset2bag(const Tensor &offsets, Tensor& offset2bag) { offset2bag.index_add_( 0, offsets, at::ones_like(offsets, LEGACY_CONTIGUOUS_MEMORY_FORMAT)); // offset2bag = [1 0 1 0 1] offset2bag[0] -= 1; // offset2bag = [0 0 1 0 1] - offset2bag = offset2bag.cumsum(0); // offset2bag = [0 0 1 1 2] + offset2bag = offset2bag.cumsum(0, offset2bag.scalar_type()); // offset2bag = [0 0 1 1 2] } namespace { @@ -52,18 +52,19 @@ bool isFastPathIndexSelectScale(const Tensor& src, const Tensor& scale, Tensor& // This function combines index_select (using select_indices as the index) and // index_add (using add_indices as the index), without creating an intermediary // tensor to hold the selected embeddings -template -void index_select_add(const Tensor &select_indices, +template +typename std::enable_if::value, void>::type +index_select_add(const Tensor &select_indices, const Tensor &add_indices, const Tensor &src, Tensor &output, const Tensor& /*offsets*/, bool /*include_last_offset*/) { AT_ASSERT(select_indices.numel() == add_indices.numel()); - auto* add_indices_data = add_indices.data_ptr(); - auto* select_indices_data = select_indices.data_ptr(); - auto* src_data = src.data_ptr(); - auto* output_data = output.data_ptr(); + auto* add_indices_data = add_indices.data_ptr(); + auto* select_indices_data = select_indices.data_ptr(); + auto* src_data = src.data_ptr(); + auto* output_data = output.data_ptr(); auto numel = add_indices.numel(); int64_t ddim = src.size(1); auto src_stride0 = src.stride(0); @@ -72,29 +73,30 @@ void index_select_add(const Tensor &select_indices, auto output_stride1 = output.stride(1); for (int64_t i = 0; i < numel; i++) { - THBlas_axpy(ddim, 1, + THBlas_axpy(ddim, 1, src_data + src_stride0 * select_indices_data[i], src_stride1, output_data + output_stride0 * add_indices_data[i], output_stride1); } } -template<> -void index_select_add(const Tensor &select_indices, +template +typename std::enable_if::value, void>::type +index_select_add(const Tensor &select_indices, const Tensor &add_indices, const Tensor &src, Tensor &output, const Tensor& offsets, bool include_last_offset) { int64_t ddim = src.size(1); - auto* select_indices_data = select_indices.data_ptr(); + auto* select_indices_data = select_indices.data_ptr(); auto* output_data = output.data_ptr(); if (isFastPathIndexSelect(src, output)) { auto src_contig = src.contiguous(); auto* src_data = src_contig.data_ptr(); int64_t output_size = offsets.numel() - 1; - auto* offsets_data = offsets.data_ptr(); - std::vector offsets_include_last; + auto* offsets_data = offsets.data_ptr(); + std::vector offsets_include_last; if (include_last_offset) { output_size = offsets.numel() - 1; @@ -103,15 +105,15 @@ void index_select_add(const Tensor &select_indices, offsets_include_last.resize(offsets.numel() + 1); std::memcpy( offsets_include_last.data(), - offsets.data_ptr(), - sizeof(int64_t) * offsets.numel()); + offsets.data_ptr(), + sizeof(index_t) * offsets.numel()); offsets_include_last[offsets.numel()] = select_indices.numel(); offsets_data = offsets_include_last.data(); } #ifdef USE_FBGEMM - auto kernel_fp32_i64 = - fbgemm::GenerateEmbeddingSpMDM( + auto kernel_fp32_index_t = + fbgemm::GenerateEmbeddingSpMDM( /* block_size */ddim, /* has_weight */false, /* normalize_by_lengths */false, @@ -121,9 +123,9 @@ void index_select_add(const Tensor &select_indices, ); #endif at::parallel_for( - 0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) { + 0, output_size, 1, [&](index_t start_idx, index_t end_idx) { #ifdef USE_FBGEMM - kernel_fp32_i64( + kernel_fp32_index_t( /* output_size */end_idx - start_idx, /* index_size */offsets_data[end_idx] - offsets_data[start_idx], /* data_size */src.size(0), @@ -150,7 +152,7 @@ void index_select_add(const Tensor &select_indices, } else { AT_ASSERT(select_indices.numel() == add_indices.numel()); auto* src_data = src.data_ptr(); - auto* add_indices_data = add_indices.data_ptr(); + auto* add_indices_data = add_indices.data_ptr(); auto src_stride0 = src.stride(0); auto src_stride1 = src.stride(1); auto output_stride0 = output.stride(0); @@ -172,8 +174,9 @@ void index_select_add(const Tensor &select_indices, // index_select (using select_indices as the index) // mul (scaling by per_sample_weights) // index_add (using add_indices as the index) -template -static void index_select_scale_add(const Tensor &select_indices, +template +static typename std::enable_if::value, void>::type +index_select_scale_add(const Tensor &select_indices, const Tensor &add_indices, const Tensor &scale, const Tensor &src, @@ -181,10 +184,10 @@ static void index_select_scale_add(const Tensor &select_indices, const Tensor& /*offsets*/, bool /*include_last_offset*/) { AT_ASSERT(select_indices.numel() == add_indices.numel()); - auto* add_indices_data = add_indices.data_ptr(); - auto* select_indices_data = select_indices.data_ptr(); - auto* src_data = src.data_ptr(); - auto* output_data = output.data_ptr(); + auto* add_indices_data = add_indices.data_ptr(); + auto* select_indices_data = select_indices.data_ptr(); + auto* src_data = src.data_ptr(); + auto* output_data = output.data_ptr(); auto numel = add_indices.numel(); int64_t ddim = src.size(1); auto src_stride0 = src.stride(0); @@ -192,7 +195,7 @@ static void index_select_scale_add(const Tensor &select_indices, auto output_stride0 = output.stride(0); auto output_stride1 = output.stride(1); - auto* scale_data = scale.data_ptr(); + auto* scale_data = scale.data_ptr(); auto scale_stride = scale.stride(0); for (int64_t i = 0; i < numel; i++) { @@ -205,8 +208,9 @@ static void index_select_scale_add(const Tensor &select_indices, } } -template<> -void index_select_scale_add(const Tensor &select_indices, +template +typename std::enable_if::value, void>::type +index_select_scale_add(const Tensor &select_indices, const Tensor &add_indices, const Tensor &scale, const Tensor &src, @@ -215,15 +219,15 @@ void index_select_scale_add(const Tensor &select_indices, bool include_last_offset) { int64_t ddim = src.size(1); auto* scale_data = scale.data_ptr(); - auto* select_indices_data = select_indices.data_ptr(); + auto* select_indices_data = select_indices.data_ptr(); auto* output_data = output.data_ptr(); if (isFastPathIndexSelectScale(src, scale, output)) { auto src_contig = src.contiguous(); auto* src_data = src_contig.data_ptr(); int64_t output_size = offsets.numel() - 1; - auto* offsets_data = offsets.data_ptr(); - std::vector offsets_include_last; + auto* offsets_data = offsets.data_ptr(); + std::vector offsets_include_last; if (include_last_offset) { output_size = offsets.numel() - 1; @@ -232,15 +236,15 @@ void index_select_scale_add(const Tensor &select_indices, offsets_include_last.resize(offsets.numel() + 1); std::memcpy( offsets_include_last.data(), - offsets.data_ptr(), - sizeof(int64_t) * offsets.numel()); + offsets.data_ptr(), + sizeof(index_t) * offsets.numel()); offsets_include_last[offsets.numel()] = select_indices.numel(); offsets_data = offsets_include_last.data(); } #ifdef USE_FBGEMM - auto kernel_fp32_i64 = - fbgemm::GenerateEmbeddingSpMDM( + auto kernel_fp32_index_t = + fbgemm::GenerateEmbeddingSpMDM( /* block_size */ddim, /* has_weight */true, /* normalize_by_lengths */false, @@ -250,9 +254,9 @@ void index_select_scale_add(const Tensor &select_indices, ); #endif at::parallel_for( - 0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) { + 0, output_size, 1, [&](index_t start_idx, index_t end_idx) { #ifdef USE_FBGEMM - kernel_fp32_i64( + kernel_fp32_index_t( /* output_size */end_idx - start_idx, /* index_size */offsets_data[end_idx] - offsets_data[start_idx], /* data_size */src.size(0), @@ -279,7 +283,7 @@ void index_select_scale_add(const Tensor &select_indices, } else { AT_ASSERT(select_indices.numel() == add_indices.numel()); auto* src_data = src.data_ptr(); - auto* add_indices_data = add_indices.data_ptr(); + auto* add_indices_data = add_indices.data_ptr(); auto src_stride0 = src.stride(0); auto src_stride1 = src.stride(1); auto output_stride0 = output.stride(0); @@ -308,7 +312,7 @@ static at::Tensor make_bag_size( const bool requires_grad) { at::Tensor bag_size; if (mode == MODE_MEAN || mode == MODE_MAX) { - bag_size = at::zeros(offsets.sizes(), indices.options()); + bag_size = at::zeros(offsets.sizes(), offsets.options()); // Compute this for MODE_MEAN and MODE_MAX (latter needed for backwards) if (offsets.size(0) != 1) { bag_size.slice(0, 0, bag_size.size(0) - 1, 1) = @@ -318,7 +322,7 @@ static at::Tensor make_bag_size( bag_size[-1] = indices.size(0) - offsets[-1]; } else if (requires_grad) { // in MODE_SUM, only allocate bag_size if we need gradients - bag_size = at::empty(offsets.sizes(), indices.options()); + bag_size = at::empty(offsets.sizes(), offsets.options()); } return bag_size; } @@ -384,35 +388,36 @@ std::tuple embedding_bag_cpu_max( } auto max_indices = at::zeros({numBags, featureSize}, indices.options()); - - auto* indices_data = indices.data_ptr(); - auto* offset2bag_data = offset2bag.data_ptr(); - - auto* max_indices_data = max_indices.data_ptr(); - auto max_indices_stride = max_indices.stride(0); - - auto* weight_data = weight.data_ptr(); - auto* output_data = output.data_ptr(); - auto weight_stride0 = weight.stride(0); - auto weight_stride1 = weight.stride(1); - auto output_stride = output.stride(0); - - for (int i = 0; i < numIndices; i++) { - auto bag = offset2bag_data[i]; - auto word_idx = indices_data[i]; - - for (int dim = 0; dim < featureSize; dim++) { - auto& current_item = output_data[output_stride * bag + dim]; - auto weight_item = - weight_data[weight_stride0 * word_idx + dim * weight_stride1]; - bool is_first_for_bag = (i == 0) || offset2bag_data[i - 1] != bag; - - if (is_first_for_bag || weight_item > current_item) { - current_item = weight_item; - max_indices_data[max_indices_stride * bag + dim] = word_idx; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cpu_max", [&] { + auto* indices_data = indices.data_ptr(); + auto* offset2bag_data = offset2bag.data_ptr(); + + auto* max_indices_data = max_indices.data_ptr(); + auto max_indices_stride = max_indices.stride(0); + + auto* weight_data = weight.data_ptr(); + auto* output_data = output.data_ptr(); + auto weight_stride0 = weight.stride(0); + auto weight_stride1 = weight.stride(1); + auto output_stride = output.stride(0); + + for (int i = 0; i < numIndices; ++i) { + auto bag = offset2bag_data[i]; + auto word_idx = indices_data[i]; + + for (int dim = 0; dim < featureSize; dim++) { + auto& current_item = output_data[output_stride * bag + dim]; + auto weight_item = + weight_data[weight_stride0 * word_idx + dim * weight_stride1]; + bool is_first_for_bag = (i == 0) || offset2bag_data[i - 1] != bag; + + if (is_first_for_bag || weight_item > current_item) { + current_item = weight_item; + max_indices_data[max_indices_stride * bag + dim] = word_idx; + } } } - } + }); return std::tuple( output, offset2bag, bag_size, max_indices); @@ -429,19 +434,23 @@ std::tuple _embedding_bag_cpu_impl( bool include_last_offset, bool requires_grad) { auto indices_arg = TensorArg(indices, "indices", 1); - checkScalarType("embedding_bag", indices_arg, kLong); + checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt}); auto offsets_arg = TensorArg(offsets, "offsets", 1); - checkScalarType("embedding_bag", offsets_arg, kLong); + checkScalarTypes("embedding_bag", offsets_arg, {kLong, kInt}); + checkSameType("embedding_bag", indices_arg, offsets_arg); auto weight_arg = TensorArg(weight, "weight", 1); checkScalarTypes("embedding_bag", weight_arg, {kFloat, kDouble}); - int64_t offset_0 = offsets.data_ptr()[0]; - int64_t offset_n = offsets.data_ptr()[offsets.size(0)-1]; - TORCH_CHECK(offset_0 == 0, "offsets[0] has to be 0, i.e., the first sequence " - "in the mini-batch has to start from position 0. " - "However, got ", offsets[0]); - TORCH_CHECK(offset_n <= indices.size(0), "offsets[-1] can not " - "be greater than input's length ", indices.size(0), " but got offsets[-1] of ", - offset_n); + + AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "_embedding_bag_cpu_impl", [&]() { + index_t offset_0 = offsets.data_ptr()[0]; + index_t offset_n = offsets.data_ptr()[offsets.size(0)-1]; + TORCH_CHECK(offset_0 == 0, "offsets[0] has to be 0, i.e., the first sequence " + "in the mini-batch has to start from position 0. " + "However, got ", offsets[0]); + TORCH_CHECK(offset_n <= indices.size(0), "offsets[-1] can not " + "be greater than input's length ", indices.size(0), " but got offsets[-1] of ", + offset_n); + }); if (per_sample_weights.defined()) { TORCH_CHECK(mode == MODE_SUM, @@ -494,9 +503,9 @@ std::tuple _embedding_bag_cpu_impl( // throw out of bounds error. So to keep it simple we just add one more // entry to the end then get rid of it after make_offset2bag. offset2bag = at::zeros( - {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0] + {indices.sizes()[0] + 1}, offsets.options()); // offset2bag = [0 0 0 0 0] - make_offset2bag(offsets, indices, offset2bag); + make_offset2bag(offsets, offset2bag); offset2bag.resize_({indices.sizes()[0]}); @@ -505,14 +514,20 @@ std::tuple _embedding_bag_cpu_impl( } if (mode == MODE_MEAN || mode == MODE_SUM) { - AT_DISPATCH_FLOATING_TYPES(weight.scalar_type(), "embedding_bag_cpu", [&]() { - if (per_sample_weights.defined()) { - AT_ASSERT(mode == MODE_SUM); - index_select_scale_add( - indices, offset2bag, per_sample_weights, weight, output, offsets, include_last_offset); - } else { - index_select_add(indices, offset2bag, weight, output, offsets, include_last_offset); - } + // explicitly capture all required variables to work around windows build + // TODO: fix this when windows can correctly capture variables in nested lambda + AT_DISPATCH_FLOATING_TYPES(weight.scalar_type(), "embedding_bag_cpu", + [&indices, &offset2bag, &per_sample_weights, &weight, &output, &offsets, &include_last_offset, &mode]() { + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cpu", + [&indices, &offset2bag, &per_sample_weights, &weight, &output, &offsets, &include_last_offset, &mode]() { + if (per_sample_weights.defined()) { + AT_ASSERT(mode == MODE_SUM); + index_select_scale_add( + indices, offset2bag, per_sample_weights, weight, output, offsets, include_last_offset); + } else { + index_select_add(indices, offset2bag, weight, output, offsets, include_last_offset); + } + }); }); auto ret = apply_bag_size(offsets, indices, mode, output, bag_size); return std::tuple(ret, offset2bag, bag_size, bag_size); @@ -598,23 +613,24 @@ Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices, bool sparse, const Tensor& per_sample_weights) { auto indices_arg = TensorArg(indices, "indices", 1); - checkScalarType("embedding_bag", indices_arg, kLong); + checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt}); checkContiguous("embedding_bag", indices_arg); auto offsets_arg = TensorArg(offsets, "offsets", 1); - checkScalarType("embedding_bag", offsets_arg, kLong); + checkScalarTypes("embedding_bag", offsets_arg, {kLong, kInt}); + checkSameType("embedding_bag", indices_arg, offsets_arg); checkContiguous("embedding_bag", offsets_arg); Tensor offset2bag_; if (indices.numel() != 0 && offset2bag.numel() == 0) { offset2bag_ = at::zeros( - {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0] + {indices.sizes()[0] + 1}, offsets.options()); // offset2bag = [0 0 0 0 0] - make_offset2bag(offsets, indices, offset2bag_); + make_offset2bag(offsets, offset2bag_); offset2bag_.resize_({indices.sizes()[0]}); } else { auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1); - checkScalarType("embedding_bag", offset2bag_arg, kLong); + checkScalarTypes("embedding_bag", offset2bag_arg, {kLong, kInt}); checkContiguous("embedding_bag", offset2bag_arg); offset2bag_ = offset2bag; } @@ -648,11 +664,12 @@ static Tensor _embedding_bag_dense_backward_cpu_max( return index_grad_weight; } -static std::vector compute_counts( +template +static std::vector compute_counts( int64_t num_weights, - int64_t* indices_data, + index_t* indices_data, int64_t indices_length) { - std::vector counts(num_weights, 0); + std::vector counts(num_weights, 0); for (int i = 0; i < indices_length; i++) { counts[indices_data[i]]++; } @@ -668,12 +685,13 @@ static std::vector compute_counts( // counts_uniq: [3, 4, 6, 7] // // The unique indices can be found at index 0, 3, 4, 6. -static std::vector compute_counts_uniq( +template +static std::vector compute_counts_uniq( int64_t num_weights, - int64_t* indices_data, + index_t* indices_data, int64_t indices_length, - const std::vector& counts) { - std::vector counts_uniq; + const std::vector& counts) { + std::vector counts_uniq; counts_uniq.reserve(num_weights); int64_t o = 0; for (int64_t i = 0; i < indices_length; i += counts[indices_data[i]]) { @@ -714,54 +732,66 @@ void _embedding_bag_dense_backward_cpu_sum_mean( per_sample_weights_stride = per_sample_weights->stride(0); } - auto* indices_data = indices.data_ptr(); - auto* offsets_data = offsets_.data_ptr(); - auto* offset2bag_data = offset2bag.data_ptr(); int64_t numel = indices.numel(); - auto counts = compute_counts(num_weights, indices_data, numel); - auto next_unique_index_idx = - compute_counts_uniq(num_weights, indices_data, numel, counts); - - auto loop = [&](int64_t start, int64_t end) { - for (int64_t i = start; i < end; i++) { - int64_t start = i == 0 ? 0 : next_unique_index_idx[i - 1]; - int64_t index = indices_data[start]; - for (int64_t j = start; j < next_unique_index_idx[i]; j++) { - int64_t source = offset2bag_data[j]; - double scale = 1.0; - if (per_sample_weights) { - AT_ASSERT(mode == MODE_SUM); - scale = per_sample_weights_data[*per_sample_weights_stride * j]; - } - if (scale_grad_by_freq) { - scale /= counts[indices_data[i]]; - } - if (mode == 1) { // MODE_MEAN - if (offsets_.size(0) == 1) { - auto bag_size = indices.size(0); - scale /= bag_size; - } else { - if (source == offsets_.size(0) - 1) { - scale /= indices.size(0) - offsets_data[offsets_.size(0) - 1]; + // explicitly capture all required variables to work around windows build + // TODO: fix this when windows can correctly capture variables in nested lambda + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_dense_backward_cpu_sum_mean", + [&indices, &offsets_, &offset2bag, &num_weights, &numel, &per_sample_weights, + &per_sample_weights_data, &per_sample_weights_stride, &mode, &scale_grad_by_freq, + &grad, &index_grad_weight] { + auto* indices_data = indices.data_ptr(); + auto* offsets_data = offsets_.data_ptr(); + auto* offset2bag_data = offset2bag.data_ptr(); + + auto counts = compute_counts(num_weights, indices_data, numel); + auto next_unique_index_idx = + compute_counts_uniq(num_weights, indices_data, numel, counts); + + auto loop = + [&next_unique_index_idx, &indices_data, &offset2bag_data, &per_sample_weights, + &mode, &per_sample_weights_data, &per_sample_weights_stride, &scale_grad_by_freq, + &counts, &offsets_, &indices, &offsets_data, &grad, &index_grad_weight](index_t start, index_t end) { + for (index_t i = start; i < end; i++) { + index_t start = i == 0 ? 0 : next_unique_index_idx[i - 1]; + index_t index = indices_data[start]; + for (index_t j = start; j < next_unique_index_idx[i]; j++) { + index_t source = offset2bag_data[j]; + double scale = 1.0; + if (per_sample_weights) { + AT_ASSERT(mode == MODE_SUM); + scale = per_sample_weights_data[*per_sample_weights_stride * j]; + } + if (scale_grad_by_freq) { + scale /= counts[indices_data[i]]; + } + if (mode == 1) { // MODE_MEAN + if (offsets_.size(0) == 1) { + auto bag_size = indices.size(0); + scale /= bag_size; } else { - scale /= offsets_data[source + 1] - offsets_data[source]; + if (source == offsets_.size(0) - 1) { + scale /= indices.size(0) - offsets_data[offsets_.size(0) - 1]; + } else { + scale /= offsets_data[source + 1] - offsets_data[source]; + } } } + int64_t ddim = grad.size(1); + auto igwd = index_grad_weight.data_ptr(); + auto gd = grad.data_ptr(); + THBlas_axpy(ddim, (scalar_t)scale, gd + ddim * source, 1, + igwd + ddim * index, 1); } - int64_t ddim = grad.size(1); - auto igwd = index_grad_weight.data_ptr(); - auto gd = grad.data_ptr(); - THBlas_axpy(ddim, (scalar_t)scale, gd + ddim * source, 1, - igwd + ddim * index, 1); } + }; + + if (numel > 1000) { + at::parallel_for(0, (int64_t)next_unique_index_idx.size(), 0, loop); + } else { + loop(0, (int64_t)next_unique_index_idx.size()); } - }; - if (numel > 1000) { - at::parallel_for(0, (int64_t)next_unique_index_idx.size(), 0, loop); - } else { - loop(0, (int64_t)next_unique_index_idx.size()); - } + }); } Tensor _embedding_bag_dense_backward_cpu(const Tensor &grad_, const Tensor &indices_, @@ -820,20 +850,20 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu_template( auto output = at::zeros({num_samples}, grad.options()); auto indices_arg = TensorArg(indices, "indices", 1); - checkScalarType("embedding_bag", indices_arg, kLong); + checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt}); checkContiguous("embedding_bag", indices_arg); Tensor offset2bag_; if (indices.numel() != 0 && offset2bag.numel() == 0) { offset2bag_ = at::zeros( - {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0] + {indices.sizes()[0] + 1}, offset2bag.options()); // offset2bag = [0 0 0 0 0] - make_offset2bag(offsets, indices, offset2bag_); + make_offset2bag(offsets, offset2bag_); offset2bag_.resize_({indices.sizes()[0]}); } else { auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1); - checkScalarType("embedding_bag", offset2bag_arg, kLong); + checkScalarTypes("embedding_bag", offset2bag_arg, {kLong, kInt}); checkContiguous("embedding_bag", offset2bag_arg); offset2bag_ = offset2bag; } @@ -846,23 +876,31 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu_template( auto weight_stride0 = weight.stride(0); auto weight_stride1 = weight.stride(1); - auto* indices_data = indices.data_ptr(); - - // The following are contiguous - auto* output_data = output.data_ptr(); - auto* offset2bag_data = offset2bag_.data_ptr(); - - // XXX: 64 was arbitrarily chosen. There is probably a sweet spot for this number. - parallel_for(0, num_samples, 64, [&](int64_t begin, int64_t end) { - for (int64_t sample_idx = begin; sample_idx < end; sample_idx++) { - auto bag_idx = offset2bag_data[sample_idx]; - auto embedding_idx = indices_data[sample_idx]; - - output_data[sample_idx] = dot_impl( - embedding_features, - grad_data + grad_stride0 * bag_idx, grad_stride1, - weight_data + weight_stride0 * embedding_idx, weight_stride1); - } + // explicitly capture all required variables to work around windows build + // TODO: fix this when windows can correctly capture variables in nested lambda + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_cpu_template", + [&indices, &output, &offset2bag_, &num_samples, &embedding_features, + &grad_data, &grad_stride0, &grad_stride1, &weight_data, &weight_stride0, &weight_stride1] () { + auto* indices_data = indices.data_ptr(); + + // The following are contiguous + auto* output_data = output.data_ptr(); + auto* offset2bag_data = offset2bag_.data_ptr(); + + // XXX: 64 was arbitrarily chosen. There is probably a sweet spot for this number. + parallel_for(0, num_samples, 64, + [&embedding_features, &grad_data, &grad_stride0, &grad_stride1, &weight_data, &weight_stride0, + &weight_stride1, &offset2bag_data, &indices_data, &output_data](index_t begin, index_t end) { + for (index_t sample_idx = begin; sample_idx < end; sample_idx++) { + auto bag_idx = offset2bag_data[sample_idx]; + auto embedding_idx = indices_data[sample_idx]; + + output_data[sample_idx] = dot_impl( + embedding_features, + grad_data + grad_stride0 * bag_idx, grad_stride1, + weight_data + weight_stride0 * embedding_idx, weight_stride1); + } + }); }); return output; } diff --git a/aten/src/ATen/native/Fill.cpp b/aten/src/ATen/native/Fill.cpp index 73f7dcd619264..81bce59bc3533 100644 --- a/aten/src/ATen/native/Fill.cpp +++ b/aten/src/ATen/native/Fill.cpp @@ -4,20 +4,12 @@ #include #include #include +#include namespace at { namespace native { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ fill ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -namespace { - template - inline void fill_fast(Tensor& self, Scalar value_scalar) { - auto value = value_scalar.to(); - scalar_t * dptr = static_cast(self.data_ptr()); - *dptr = value; - } -} // namspace - Tensor& fill_out(Tensor& self, Scalar value) { if (self.is_quantized()) { at::Tensor out = at::ones(self.sizes()).to(kFloat) * value; @@ -26,15 +18,8 @@ Tensor& fill_out(Tensor& self, Scalar value) { self.copy_(out); return self; } - // When filling a number to 1-element CPU tensor, we want to skip - // everything but manipulate data ptr directly. - // Ideally this fast pass should be implemented in TensorIterator, - // but we also want to skip compute_types which in not avoidable - // in TensorIterator for now. - if (self.device() == at::kCPU && self.numel() == 1 && !self.is_complex() && !value.isComplex()) { - AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, self.scalar_type(), "fill_out", [&]() { - fill_fast(self, value);}); - return self; + if (self.device() == at::kCPU && self.numel() == 1) { + return at::detail::scalar_fill(self, value); } auto iter = TensorIteratorConfig() .set_check_mem_overlap(false) // Fill is idempotent, so overlap is okay @@ -106,7 +91,25 @@ Tensor& fill_diagonal_(Tensor& self, Scalar fill_value, bool wrap) { return self; } +Tensor& zero_cpu_(Tensor &self, int64_t nelements) { + void* ptr = self.data_ptr(); + if (nullptr == ptr) { + return self.fill_(0); + } + int64_t size_bytes = nelements * self.dtype().itemsize(); + if (size_bytes > 0) { + std::memset(ptr, 0, size_bytes); + } + return self; +} + Tensor& zero_(Tensor &self) { + int64_t nelements = at::prod_intlist(self.sizes()); + if (self.device() == at::kCPU && + self.is_non_overlapping_and_dense() && + nelements < internal::GRAIN_SIZE) { + return zero_cpu_(self, nelements); + } return self.fill_(0); } diff --git a/aten/src/ATen/native/ForeachOpsKernels.cpp b/aten/src/ATen/native/ForeachOpsKernels.cpp index 912b5116c4ccd..7f352d1bce855 100644 --- a/aten/src/ATen/native/ForeachOpsKernels.cpp +++ b/aten/src/ATen/native/ForeachOpsKernels.cpp @@ -3,125 +3,223 @@ namespace at { namespace native { -#define FOREACH_BINARY_OP_SCALAR(NAME) \ -void foreach_tensor_##NAME##_scalar_kernel_slow_(TensorList tensors, Scalar scalar) { \ +#define FOREACH_BINARY_OP_SCALAR(OP) \ +void foreach_tensor_##OP##_scalar_kernel_slow_(TensorList tensors, Scalar scalar) { \ check_foreach_api_restrictions(tensors); \ \ for (auto& t: tensors) { \ - t.NAME##_(scalar); \ + t.OP##_(scalar); \ } \ } \ \ -std::vector foreach_tensor_##NAME##_scalar_kernel_slow(TensorList tensors, Scalar scalar) { \ +std::vector foreach_tensor_##OP##_scalar_kernel_slow(TensorList tensors, Scalar scalar) { \ check_foreach_api_restrictions(tensors); \ \ std::vector result; \ result.reserve(tensors.size()); \ for (const auto& t: tensors) { \ - result.emplace_back(t.NAME(scalar)); \ + result.emplace_back(t.OP(scalar)); \ } \ \ return result; \ } -#define FOREACH_BINARY_OP_LIST(NAME) \ -std::vector foreach_tensor_##NAME##_list_kernel_slow(TensorList tensors1, TensorList tensors2) { \ +#define FOREACH_BINARY_OP_SCALARLIST(OP) \ +void foreach_tensor_##OP##_scalarlist_kernel_slow_(TensorList tensors, at::ArrayRef scalars) { \ + check_foreach_api_restrictions(tensors, scalars); \ + \ + for (size_t i = 0; i < tensors.size(); i++) { \ + tensors[i].OP##_(scalars[i]); \ + } \ +} \ + \ +std::vector foreach_tensor_##OP##_scalarlist_kernel_slow(TensorList tensors, at::ArrayRef scalars) { \ + check_foreach_api_restrictions(tensors, scalars); \ + std::vector result; \ + result.reserve(tensors.size()); \ + for (size_t i = 0; i < tensors.size(); i++) { \ + result.emplace_back(tensors[i].OP(scalars[i])); \ + } \ + \ + return result; \ +} + +#define FOREACH_BINARY_OP_LIST(OP) \ +std::vector foreach_tensor_##OP##_list_kernel_slow(TensorList tensors1, TensorList tensors2) { \ check_foreach_api_restrictions(tensors1, tensors2); \ \ std::vector result; \ result.reserve(tensors1.size()); \ - for (int i = 0; i < tensors1.size(); i++) { \ - result.emplace_back(tensors1[i].NAME(tensors2[i])); \ + for (size_t i = 0; i < tensors1.size(); i++) { \ + result.emplace_back(tensors1[i].OP(tensors2[i])); \ } \ \ return result; \ } \ \ -void foreach_tensor_##NAME##_list_kernel_slow_(TensorList tensors1, TensorList tensors2) { \ +void foreach_tensor_##OP##_list_kernel_slow_(TensorList tensors1, TensorList tensors2) { \ check_foreach_api_restrictions(tensors1, tensors2); \ \ - for (int i = 0; i < tensors1.size(); i++) { \ - tensors1[i].NAME##_(tensors2[i]); \ + for (size_t i = 0; i < tensors1.size(); i++) { \ + tensors1[i].OP##_(tensors2[i]); \ } \ } -#define FOREACH_BINARY_OP_LIST_ALPHA(NAME) \ -std::vector foreach_tensor_##NAME##_list_kernel_slow(TensorList tensors1, TensorList tensors2, Scalar alpha) { \ +#define FOREACH_BINARY_OP_LIST_ALPHA(OP) \ +std::vector foreach_tensor_##OP##_list_kernel_slow(TensorList tensors1, TensorList tensors2, Scalar alpha) { \ check_foreach_api_restrictions(tensors1, tensors2); \ \ std::vector result; \ result.reserve(tensors1.size()); \ - for (int i = 0; i < tensors1.size(); i++) { \ - result.emplace_back(tensors1[i].NAME(tensors2[i], alpha)); \ + for (size_t i = 0; i < tensors1.size(); i++) { \ + result.emplace_back(tensors1[i].OP(tensors2[i], alpha)); \ } \ \ return result; \ } \ \ -void foreach_tensor_##NAME##_list_kernel_slow_(TensorList tensors1, TensorList tensors2, Scalar alpha) { \ +void foreach_tensor_##OP##_list_kernel_slow_(TensorList tensors1, TensorList tensors2, Scalar alpha) { \ check_foreach_api_restrictions(tensors1, tensors2); \ \ - for (int i = 0; i < tensors1.size(); i++) { \ - tensors1[i].NAME##_(tensors2[i], alpha); \ + for (size_t i = 0; i < tensors1.size(); i++) { \ + tensors1[i].OP##_(tensors2[i], alpha); \ } \ } -#define FOREACH_UNARY_OP(NAME) \ -std::vector foreach_tensor_##NAME##_slow(TensorList tensors) { \ +#define FOREACH_UNARY_OP(OP) \ +std::vector foreach_tensor_##OP##_slow(TensorList tensors) { \ check_foreach_api_restrictions(tensors); \ \ std::vector result; \ result.reserve(tensors.size()); \ for (const auto& t : tensors) { \ - result.emplace_back(t.NAME()); \ + result.emplace_back(t.OP()); \ } \ \ return result; \ } \ \ -void foreach_tensor_##NAME##_slow_(TensorList tensors) { \ +void foreach_tensor_##OP##_slow_(TensorList tensors) { \ check_foreach_api_restrictions(tensors); \ \ for (auto& t : tensors) { \ - t.NAME##_(); \ + t.OP##_(); \ } \ } -#define FOREACH_POINTWISE_OP(NAME) \ -std::vector foreach_tensor_##NAME##_slow(TensorList input, TensorList tensors1, TensorList tensors2, Scalar scalar) { \ - TORCH_CHECK(input.size() > 0, "Tensor list must have at least one tensor."); \ - TORCH_CHECK(input.size() == tensors1.size(), "Tensor lists must be of the same length."); \ - TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must be of the same length."); \ - \ - std::vector result; \ - for (int i = 0; i < input.size(); i++) { \ - result.emplace_back(input[i].NAME(tensors1[i], tensors2[i], scalar)); \ - } \ - \ - return result; \ -} \ - \ -void foreach_tensor_##NAME##_slow_(TensorList input, TensorList tensors1, TensorList tensors2, Scalar scalar) { \ - TORCH_CHECK(input.size() > 0, "Tensor list must have at least one tensor."); \ - TORCH_CHECK(input.size() == tensors1.size(), "Tensor lists must be of the same length."); \ - TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must be of the same length."); \ - \ - for (int i = 0; i < input.size(); i++) { \ - input[i].NAME##_(tensors1[i], tensors2[i], scalar); \ - } \ -} \ +#define FOREACH_POINTWISE_OP_SCALAR(OP) \ +std::vector foreach_tensor_##OP##_scalar_slow(TensorList input, TensorList tensors1, TensorList tensors2, Scalar scalar) { \ + check_foreach_api_restrictions(input, tensors1, tensors2); \ + \ + std::vector result; \ + for (size_t i = 0; i < input.size(); i++) { \ + result.emplace_back(input[i].OP(tensors1[i], tensors2[i], scalar)); \ + } \ + \ + return result; \ +} \ + \ +void foreach_tensor_##OP##_scalar_slow_(TensorList input, TensorList tensors1, TensorList tensors2, Scalar scalar) { \ + check_foreach_api_restrictions(input, tensors1, tensors2); \ + \ + for (size_t i = 0; i < input.size(); i++) { \ + input[i].OP##_(tensors1[i], tensors2[i], scalar); \ + } \ +} \ + +#define FOREACH_POINTWISE_OP_SCALARLIST(OP) \ +std::vector foreach_tensor_##OP##_scalarlist_slow(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef scalars) { \ + check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \ + \ + std::vector result; \ + for (size_t i = 0; i < input.size(); i++) { \ + result.emplace_back(input[i].OP(tensors1[i], tensors2[i], scalars[i])); \ + } \ + \ + return result; \ +} \ + \ +void foreach_tensor_##OP##_scalarlist_slow_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef scalars) { \ + check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \ + \ + for (size_t i = 0; i < input.size(); i++) { \ + input[i].OP##_(tensors1[i], tensors2[i], scalars[i]); \ + } \ +} \ FOREACH_BINARY_OP_LIST_ALPHA(add); FOREACH_BINARY_OP_LIST_ALPHA(sub); + FOREACH_BINARY_OP_SCALAR(add); FOREACH_BINARY_OP_SCALAR(sub); FOREACH_BINARY_OP_SCALAR(mul); FOREACH_BINARY_OP_SCALAR(div); + +FOREACH_BINARY_OP_SCALARLIST(add); +FOREACH_BINARY_OP_SCALARLIST(sub); +FOREACH_BINARY_OP_SCALARLIST(mul); +FOREACH_BINARY_OP_SCALARLIST(div); + FOREACH_BINARY_OP_LIST(mul); FOREACH_BINARY_OP_LIST(div); + FOREACH_UNARY_OP(sqrt); FOREACH_UNARY_OP(exp); -FOREACH_POINTWISE_OP(addcdiv); -FOREACH_POINTWISE_OP(addcmul); +FOREACH_UNARY_OP(abs); +FOREACH_UNARY_OP(acos); +FOREACH_UNARY_OP(asin); +FOREACH_UNARY_OP(atan); +FOREACH_UNARY_OP(ceil); +FOREACH_UNARY_OP(cos); +FOREACH_UNARY_OP(cosh); +FOREACH_UNARY_OP(erf); +FOREACH_UNARY_OP(erfc); +FOREACH_UNARY_OP(expm1); +FOREACH_UNARY_OP(floor); +FOREACH_UNARY_OP(log); +FOREACH_UNARY_OP(log10); +FOREACH_UNARY_OP(log1p); +FOREACH_UNARY_OP(log2); +FOREACH_UNARY_OP(neg); +FOREACH_UNARY_OP(tan); +FOREACH_UNARY_OP(tanh); +FOREACH_UNARY_OP(sin); +FOREACH_UNARY_OP(sinh); +FOREACH_UNARY_OP(round); +FOREACH_UNARY_OP(lgamma); +FOREACH_UNARY_OP(frac); +FOREACH_UNARY_OP(trunc); +FOREACH_UNARY_OP(reciprocal); +FOREACH_UNARY_OP(sigmoid); + +FOREACH_POINTWISE_OP_SCALAR(addcdiv); +FOREACH_POINTWISE_OP_SCALAR(addcmul); + +FOREACH_POINTWISE_OP_SCALARLIST(addcdiv); +FOREACH_POINTWISE_OP_SCALARLIST(addcmul); + +#define FOREACH_MAXIMUM_MINIMUM_OP(NAME) \ +std::vector foreach_tensor_##NAME##_slow(TensorList tensors1, TensorList tensors2) { \ + check_foreach_api_restrictions(tensors1, tensors2); \ + \ + std::vector result; \ + result.reserve(tensors1.size()); \ + for (size_t i = 0; i < tensors1.size(); i++) { \ + result.emplace_back(at::NAME(tensors1[i], tensors2[i])); \ + } \ + \ + return result; \ +} \ + +FOREACH_MAXIMUM_MINIMUM_OP(maximum) +FOREACH_MAXIMUM_MINIMUM_OP(minimum) + +void foreach_tensor_zero_slow_(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + for (auto& t : tensors) { + t.zero_(); + } +} }} // namespace at::native diff --git a/aten/src/ATen/native/ForeachUtils.h b/aten/src/ATen/native/ForeachUtils.h index 5a7aced74702c..e915102de70cf 100644 --- a/aten/src/ATen/native/ForeachUtils.h +++ b/aten/src/ATen/native/ForeachUtils.h @@ -5,18 +5,19 @@ namespace at { namespace native { namespace { -// Set of foreach API restrictions -// - All tensors must be of the same dtype -// - All corresponding tensors must be of the same size void check_foreach_api_restrictions(TensorList tensors) { TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor."); auto expected_dtype = tensors[0].dtype(); - for (const auto& t : tensors) { TORCH_CHECK(t.dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype."); } } +void check_foreach_api_restrictions(TensorList tensors, ArrayRef scalars) { + check_foreach_api_restrictions(tensors); + TORCH_CHECK(tensors.size() == scalars.size(), "Tensor list must have same number of elements as scalar list."); +} + void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2) { TORCH_CHECK(tensors1.size() > 0, "Tensor list must have at least one tensor."); TORCH_CHECK(tensors2.size() > 0, "Tensor list must have at least one tensor."); @@ -31,17 +32,38 @@ void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2) { } } +void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, TensorList tensors3) { + TORCH_CHECK(tensors1.size() > 0, "Tensor list must have at least one tensor."); + TORCH_CHECK(tensors2.size() > 0, "Tensor list must have at least one tensor."); + TORCH_CHECK(tensors3.size() > 0, "Tensor list must have at least one tensor."); + TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors2.size()); + TORCH_CHECK(tensors1.size() == tensors3.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors3.size()); + + auto expected_dtype = tensors1[0].dtype(); + + for (int i = 0; i < tensors1.size(); i++) { + TORCH_CHECK(tensors1[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype."); + TORCH_CHECK(tensors2[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype."); + TORCH_CHECK(tensors1[i].sizes() == tensors2[i].sizes(), "Corresponding tensors in lists must have the same size, got ", tensors1[i].sizes(), " and ", tensors2[i].sizes()); + TORCH_CHECK(tensors1[i].sizes() == tensors3[i].sizes(), "Corresponding tensors in lists must have the same size, got ", tensors1[i].sizes(), " and ", tensors3[i].sizes()); + } +} + +void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef scalars) { + check_foreach_api_restrictions(tensors1, tensors2, tensors3); + TORCH_CHECK(tensors1.size() == scalars.size(), "Tensor list must have same number of elements as scalar list, got ", tensors1.size(), " and ", scalars.size()); +} + // To go via 'fast' path, several conditions must be satisfied // - All tensors must be on the same device // - All tensors must have strided layout // - All tensors must be non-overlapping and dense -// - All tensors must be on the same device // - Resulting tensor must have the same dtype as the input one -bool can_use_fast_route(TensorList tensors, Scalar scalar) { - TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor."); - auto expected_device = tensors[0].device(); - for (auto t : tensors) { +// Check if all tensors have the same device, layout, strides and are not overlapping and dense +bool has_same_attributes(Device expected_device, TensorList tensors) { + auto expected_strides = tensors[0].strides(); + for (const auto& t : tensors) { if (t.device() != expected_device) { return false; } @@ -50,86 +72,145 @@ bool can_use_fast_route(TensorList tensors, Scalar scalar) { return false; } - if (t.device() != expected_device) { + if (!t.is_non_overlapping_and_dense()) { return false; } - if (!t.is_non_overlapping_and_dense()) { + if (t.strides() != expected_strides) { return false; } + } + + return true; +} - // complex scalar + integral or boolean tensor will result in complex tensor - if (scalar.isComplex() && at::isIntegralType(t.scalar_type(), /*includeBool*/ true)) { +bool will_promote_tensor(const Tensor& tensor, Scalar scalar) { + // complex scalar + integral or boolean tensor will result in complex tensor + if (scalar.isComplex() && at::isIntegralType(tensor.scalar_type(), /*includeBool*/ true)) { + return false; + } + + // float scalar + integral or boolean tensor will result in float tensor + if (scalar.isFloatingPoint() && at::isIntegralType(tensor.scalar_type(), /*includeBool*/ true)) { + return false; + } + + // integral scalar + boolean tensor will result in integral tensor + if (scalar.isIntegral(/*includeBool*/ false) && tensor.dtype() == at::kBool) { + return false; + } + return true; +} + +bool can_use_fast_route(TensorList tensors) { +#ifdef __HIP_PLATFORM_HCC__ + return false; +#else + auto expected_device = tensors[0].device(); + for (auto t : tensors) { + if (!has_same_attributes(expected_device, {t})) { return false; } + } + + return true; +#endif +} - // float scalar + integral or boolean tensor will result in float tensor - if (scalar.isFloatingPoint() && at::isIntegralType(t.scalar_type(), /*includeBool*/ true)) { +bool can_use_fast_route(TensorList tensors, Scalar scalar) { +#ifdef __HIP_PLATFORM_HCC__ + return false; +#else + auto expected_device = tensors[0].device(); + + for (auto t : tensors) { + if (!has_same_attributes(expected_device, {t})) { return false; } - // integral scalar + boolean tensor will result in integral tensor - if (scalar.isIntegral(/*includeBool*/ false) && t.dtype() == at::kBool) { + if (!will_promote_tensor(t, scalar)) { return false; } } return true; +#endif +} + +bool can_use_fast_route(TensorList tensors, ArrayRef scalars) { + return can_use_fast_route(tensors); } bool can_use_fast_route(TensorList tensors1, TensorList tensors2) { +#ifdef __HIP_PLATFORM_HCC__ + return false; +#else auto expected_device = tensors1[0].device(); - for (int64_t i = 0; i < tensors1.size(); i++) { - TORCH_CHECK(tensors1[i].sizes() == tensors2[i].sizes(), "Corresponding tensors from tensor lists have different size."); - - if (tensors1[i].device() != expected_device || - tensors2[i].device() != expected_device) { - return false; - } - - if (tensors1[i].layout() != at::kStrided || - tensors2[i].layout() != at::kStrided) { + if (!has_same_attributes(expected_device, {tensors1[i], tensors2[i]})) { return false; } + } - if (tensors1[i].device() != expected_device || - tensors2[i].device() != expected_device) { - return false; - } + return true; +#endif +} - if (tensors1[i].strides() != tensors2[i].strides()) { +bool can_use_fast_route(TensorList tensors1, TensorList tensors2, Scalar scalar) { +#ifdef __HIP_PLATFORM_HCC__ + return false; +#else + auto expected_device = tensors1[0].device(); + for (int64_t i = 0; i < tensors1.size(); i++) { + if (!has_same_attributes(expected_device, {tensors1[i], tensors2[i]})) { return false; } - if (!tensors1[i].is_non_overlapping_and_dense() || - !tensors2[i].is_non_overlapping_and_dense()) { + if (!will_promote_tensor(tensors1[i], scalar)) { return false; } } return true; +#endif } -bool can_use_fast_route(TensorList tensors) { - TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor."); - auto expected_device = tensors[0].device(); - - for (auto t : tensors) { - if (t.layout() != at::kStrided) { +bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3) { +#ifdef __HIP_PLATFORM_HCC__ + return false; +#else + auto expected_device = tensors1[0].device(); + for (int64_t i = 0; i < tensors1.size(); i++) { + if (!has_same_attributes(expected_device, {tensors1[i], tensors2[i], tensors3[i]})) { return false; } + } - if (!t.is_non_overlapping_and_dense()) { + return true; +#endif +} + +bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3, Scalar scalar) { +#ifdef __HIP_PLATFORM_HCC__ + return false; +#else + auto expected_device = tensors1[0].device(); + for (int64_t i = 0; i < tensors1.size(); i++) { + if (!has_same_attributes(expected_device, {tensors1[i], tensors2[i], tensors3[i]})) { return false; } - if (t.device() != expected_device) { + if (!will_promote_tensor(tensors1[i], scalar)) { return false; } } return true; +#endif +} + +bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef scalars) { + return can_use_fast_route(tensors1, tensors2, tensors3); } } diff --git a/aten/src/ATen/native/GridSampler.cpp b/aten/src/ATen/native/GridSampler.cpp index 59242e0e6c034..667cbe8f07b33 100644 --- a/aten/src/ATen/native/GridSampler.cpp +++ b/aten/src/ATen/native/GridSampler.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -422,11 +423,11 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid, for (int64_t w = 0; w < out_W; ++w) { // get the corresponding input x, y, z co-ordinates from grid scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; - scalar_t ix = *grid_ptr_NHW; - scalar_t iy = grid_ptr_NHW[grid_sCoor]; + scalar_t x = *grid_ptr_NHW; + scalar_t y = grid_ptr_NHW[grid_sCoor]; - ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); - iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); + scalar_t ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); + scalar_t iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); if (interpolation_mode == GridSamplerInterpolation::Bilinear) { // get corner pixel values from (x, y) @@ -483,6 +484,43 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid, *out_ptr_NCHW = static_cast(0); } } + } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) { + // grid_sampler_compute_source_index will "clip the value" of idx depends on the padding, + // which would cause calculation to be wrong, + // for example x = -0.1 -> ix = 0 for zero padding, but in bicubic ix = floor(x) = -1 + // There would be more problem in reflection padding, since the -1 and +1 direction is not fixed in boundary condition + ix = grid_sampler_unnormalize(x, inp_W, align_corners); + iy = grid_sampler_unnormalize(y, inp_H, align_corners); + + scalar_t ix_nw = std::floor(ix); + scalar_t iy_nw = std::floor(iy); + + const scalar_t tx = ix - ix_nw; + const scalar_t ty = iy - iy_nw; + + scalar_t *inp_ptr_NC = inp_ptr_N; + scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW; + for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) { + scalar_t coefficients[4]; + + // Interpolate 4 values in the x directon + for (int64_t i = 0; i < 4; ++i) { + coefficients[i] = cubic_interp1d( + get_value_bounded(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + tx); + } + + // Interpolate in the y direction + *out_ptr_NCHW = cubic_interp1d( + coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + ty); + } } } } @@ -547,13 +585,13 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output, for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NHW += gGrid_sW /* grad_grid is contiguous */ ) { // get the corresponding input x, y co-ordinates from grid scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; - scalar_t ix = *grid_ptr_NHW; - scalar_t iy = grid_ptr_NHW[grid_sCoor]; + scalar_t x = *grid_ptr_NHW; + scalar_t y = grid_ptr_NHW[grid_sCoor]; // multipliers for gradients on ix, iy scalar_t gix_mult, giy_mult; - ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult); - iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult); + scalar_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult); + scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult); if (interpolation_mode == GridSamplerInterpolation::Bilinear) { // get corner pixel values from (x, y) @@ -628,6 +666,55 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output, safe_add_2d(gInp_ptr_NC, iy_nearest, ix_nearest, gInp_sH, gInp_sW, inp_H, inp_W, *gOut_ptr_NCHW); } + } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) { + + ix = grid_sampler_unnormalize_set_grad(x, inp_W, align_corners, &gix_mult); + iy = grid_sampler_unnormalize_set_grad(y, inp_H, align_corners, &giy_mult); + + scalar_t ix_nw = std::floor(ix); + scalar_t iy_nw = std::floor(iy); + + const scalar_t tx = ix - ix_nw; + const scalar_t ty = iy - iy_nw; + + scalar_t x_coeffs[4]; + scalar_t y_coeffs[4]; + scalar_t x_coeffs_grad[4]; + scalar_t y_coeffs_grad[4]; + + get_cubic_upsample_coefficients(x_coeffs, tx); + get_cubic_upsample_coefficients(y_coeffs, ty); + get_cubic_coefficients_grad(x_coeffs_grad, tx); + get_cubic_coefficients_grad(y_coeffs_grad, ty); + + scalar_t gix = static_cast(0); + scalar_t giy = static_cast(0); + + scalar_t *gOut_ptr_NCHW = gOut_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW; + scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN; + scalar_t *inp_ptr_NC = inp_ptr_N; + + for (int64_t c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC+= inp_sC) { + scalar_t gOut = *gOut_ptr_NCHW; + + for (int64_t i = 0; i < 4; ++i) { + for (int64_t j = 0; j < 4; ++j) { + + // set input gradient + add_value_bounded(gInp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j, + inp_W, inp_H, gInp_sW, gInp_sH, gOut * x_coeffs[i] * y_coeffs[j], padding_mode, align_corners); + + // set grid gradient + scalar_t val = get_value_bounded(inp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j, + inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners); + + gix -= val * x_coeffs_grad[i] * y_coeffs[j] * gOut; + giy -= val * y_coeffs_grad[j] * x_coeffs[i] * gOut; + } + } + } + gGrid_ptr_NHW[0] = gix_mult * gix; + gGrid_ptr_NHW[1] = giy_mult * giy; } } } @@ -640,6 +727,7 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output, Tensor grid_sampler_2d_cpu(const Tensor& input, const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + // AVX gather instructions use signed 32-bit offsets to gather float values. // Check for possible overflow and fallback to scalar implementation if (input.scalar_type() != kDouble) { @@ -682,6 +770,7 @@ Tensor grid_sampler_3d_cpu(const Tensor& input, const Tensor& grid, std::tuple grid_sampler_2d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + // AVX gather instructions use signed 32-bit offsets to gather float values. // Check for possible overflow and fallback to scalar implementation if (input.scalar_type() != kDouble) { @@ -757,6 +846,10 @@ Tensor grid_sampler(const Tensor& input, const Tensor& grid, grid.size(-1) == input.dim() - 2, "grid_sampler(): expected grid to have size ", input.dim() - 2, " in last " "dimension, but got grid with sizes ", grid.sizes()); + TORCH_CHECK( + !(input.dim() == 5 && static_cast(interpolation_mode) == GridSamplerInterpolation::Bicubic), + "grid_sampler(): bicubic interpolation only supports 4D input" + ); for (int64_t i = 2; i < input.dim(); i++) { TORCH_CHECK(input.size(i) > 0, "grid_sampler(): expected input to have non-empty spatial dimensions, " diff --git a/aten/src/ATen/native/GridSampler.h b/aten/src/ATen/native/GridSampler.h index ebafc9727061e..effc322c0d3a0 100644 --- a/aten/src/ATen/native/GridSampler.h +++ b/aten/src/ATen/native/GridSampler.h @@ -7,7 +7,7 @@ namespace at { namespace native { namespace detail { - enum class GridSamplerInterpolation {Bilinear, Nearest}; + enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic}; enum class GridSamplerPadding {Zeros, Border, Reflection}; } // namespace detail @@ -139,14 +139,12 @@ static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_l } } -// Computes the pixel source index value for a grid coordinate -template -static inline scalar_t grid_sampler_compute_source_index( - scalar_t coord, - int64_t size, - GridSamplerPadding padding_mode, - bool align_corners) { - coord = grid_sampler_unnormalize(coord, size, align_corners); +// Mapping the out-of-boundary points back into boundary +// This would only affect padding_mode=border or reflection +template +static inline scalar_t compute_coordinates(scalar_t coord, int64_t size, + GridSamplerPadding padding_mode, + bool align_corners) { if (padding_mode == GridSamplerPadding::Border) { // clip coordinates to image borders coord = clip_coordinates(coord, size); @@ -163,6 +161,18 @@ static inline scalar_t grid_sampler_compute_source_index( return coord; } +// Computes the pixel source index value for a grid coordinate +template +static inline scalar_t grid_sampler_compute_source_index( + scalar_t coord, + int64_t size, + GridSamplerPadding padding_mode, + bool align_corners) { + coord = grid_sampler_unnormalize(coord, size, align_corners); + coord = compute_coordinates(coord, size, padding_mode, align_corners); + return coord; +} + // grid_sampler_compute_source_index_set_grad works similarly to // grid_sampler_compute_source_index except that it also returns the // `d output / d input` via pointer argument `grad_in`. @@ -202,6 +212,30 @@ static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D, return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; } +template +static inline scalar_t get_value_bounded( + scalar_t* data, + scalar_t x, + scalar_t y, + int64_t W, + int64_t H, + int64_t sW, + int64_t sH, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int64_t ix = static_cast(x); + int64_t iy = static_cast(y); + + if (within_bounds_2d(iy, ix, H, W)) { + return data[iy * sH + ix * sW]; + } + return static_cast(0); +} + template static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w, int64_t sH, int64_t sW, int64_t H, int64_t W, @@ -221,4 +255,47 @@ static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w, } } +template +static inline void add_value_bounded( + scalar_t* data, + scalar_t x, + scalar_t y, + int64_t W, + int64_t H, + int64_t sW, + int64_t sH, + scalar_t delta, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int64_t ix = static_cast(x); + int64_t iy = static_cast(y); + + safe_add_2d(data, iy, ix, sH, sW, H, W, delta); +} + +// Calculate the differential of the cubic convolution, i.e. `d coeff / d x` +template +static inline void get_cubic_coefficients_grad( + scalar_t coeffs[4], + scalar_t t) { + + // Must be the same as forward calculation in + // aten/src/ATen/native/UpSample.h:get_cubic_upsample_coefficients + scalar_t A = -0.75; + + scalar_t x; + x = -1 - t; // 1 < x = |-1 - tx| < 2 + coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A; + x = -t; // x = |0 - tx| <= 1 + coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 1 - t; // x = |1 - tx| <= 1 + coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 2 - t; // 1 < x = |2 - tx| < 2 + coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A; +} + }} // namespace at::native diff --git a/aten/src/ATen/native/Im2Col.cpp b/aten/src/ATen/native/Im2Col.cpp index 743f18c00c925..d59e3a3bf16ee 100644 --- a/aten/src/ATen/native/Im2Col.cpp +++ b/aten/src/ATen/native/Im2Col.cpp @@ -10,7 +10,7 @@ namespace at { namespace native { namespace { - + static void im2col_out_cpu_template( Tensor& output, const Tensor& input_, diff --git a/aten/src/ATen/native/IndexingUtils.h b/aten/src/ATen/native/IndexingUtils.h index 94d61b02dd0b9..92f6957f25ad7 100644 --- a/aten/src/ATen/native/IndexingUtils.h +++ b/aten/src/ATen/native/IndexingUtils.h @@ -1,6 +1,7 @@ #pragma once #include #include +#include #include @@ -15,40 +16,45 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, } -static std::vector expandTensors(const Tensor & self, TensorList indices) { +static std::vector expandTensors(const Tensor & self, const torch::List>& indices) { // If indices come in as ByteTensor or BoolTensor (masks), expand them into the equivalent indexing by LongTensors std::vector result; - for (const auto & index : indices) { - if (index.scalar_type() == kByte || index.scalar_type() == kBool) { - if (index.scalar_type() == kByte) { - TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \ - " please use a dtype torch.bool instead."); - } - // The sizes of the ByteTensor mask or bool tensor must match the sizes of the - // corresponding dimensions in self - for (int64_t j = 0; j < index.dim(); j++) { - int64_t srcIdx = result.size() + j; - if (index.size(j) != self.size(srcIdx)) { - invalid_mask(self, srcIdx, index, j); + for (c10::optional index_opt : indices) { + if (!index_opt.has_value()) { + result.emplace_back(); + } else { + Tensor index = std::move(*index_opt); + if (index.scalar_type() == kByte || index.scalar_type() == kBool) { + if (index.scalar_type() == kByte) { + TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \ + " please use a dtype torch.bool instead."); } + // The sizes of the ByteTensor mask or bool tensor must match the sizes of the + // corresponding dimensions in self + for (int64_t j = 0; j < index.dim(); j++) { + int64_t srcIdx = result.size() + j; + if (index.size(j) != self.size(srcIdx)) { + invalid_mask(self, srcIdx, index, j); + } + } + // Replace with nonzeros + auto nonzero = index.nonzero(); + for (int64_t j = 0; j < index.dim(); j++) { + result.emplace_back(nonzero.select(1, j)); + } + } else { + result.emplace_back(std::move(index)); } - // Replace with nonzeros - auto nonzero = index.nonzero(); - for (int64_t j = 0; j < index.dim(); j++) { - result.emplace_back(nonzero.select(1, j)); - } - } else { - result.emplace_back(index); } } return result; } -static void checkIndexTensorTypes(TensorList indices) { - for (auto& tensor : indices) { - if (tensor.defined()) { - auto scalarType = tensor.scalar_type(); +static void checkIndexTensorTypes(const torch::List>& indices) { + for (c10::optional tensor : indices) { + if (tensor.has_value() && tensor->defined()) { + auto scalarType = tensor->scalar_type(); if (scalarType != kLong && scalarType != kByte && scalarType != kBool) { TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors"); } @@ -56,6 +62,15 @@ static void checkIndexTensorTypes(TensorList indices) { } } +inline torch::List> toListOfOptionalTensors(ArrayRef list) { + torch::List> result; + result.reserve(list.size()); + for (const Tensor& a : list) { + result.push_back(a); + } + return result; +} + static bool hasContiguousSubspace(TensorList tl) { // true if all the non-null tensors are adjacent auto isDefined = [](const Tensor & tensor){ return tensor.defined(); }; diff --git a/aten/src/ATen/native/Linear.cpp b/aten/src/ATen/native/Linear.cpp index 1d839eff28e0d..b9a9cd5e5ad06 100644 --- a/aten/src/ATen/native/Linear.cpp +++ b/aten/src/ATen/native/Linear.cpp @@ -136,241 +136,336 @@ static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArra return result; } -Tensor einsum(std::string eqn, TensorList tensors) { - constexpr size_t number_of_letters = 26; - std::string in_eqn; - size_t pos; - // The equation is given in terms of single lowercase letters ('a'..'z') and potentially an ellipsis. - // Internally, we represent it using indices from 0 to num_total_dimensions, with each letter - // mapped to an index and the ellipsis ('...') being mapped to a number of consequtive indices. - // The mapping of letters to internal indices is given in letter_mapping. A value of -1 means that - // the letter has not been assigned an index yet (because it has not been seen). - // The ellipsis is defined by first_ell_idx (the first index) and num_ell_idxes (the number of indices). - // A value of -1 for num_ell_idxes specifies that we have not seen an ellipsis yet. - // Note: The internal indices are NOT the dimensions used internally. There is a mapping to them below. - - std::array letter_mapping; // map letter to internal (numerical) label - letter_mapping.fill(-1); - int64_t num_ell_idxes = -1; - int64_t first_ell_idx = 0; - - // The internal representation of the left hand side fo the equation (with ellipsis expanded) is stored in input_op_idxes. - // For each operand, we have a vector mapping each dimension to an internal index. - // We also keep track of the number of occurrences for each letter (to infer a right hand side if not given) and - // of the last occurrence of each index. - std::vector> input_op_idxes; // the parsed operand indices - std::array num_letter_occurrences; // number of occurrence in the equation of this letter - num_letter_occurrences.fill(0); - std::vector last_idx_occurrence; // the last operator (left to right) using this index - - if ((pos = eqn.find("->")) != std::string::npos) { // check whether we have a right hand side. in_eq is the left hand side - in_eqn = eqn.substr(0, pos); - } else { - in_eqn = eqn; - } - // remove spaces for einsum compatibility (#9929) - in_eqn.erase(std::remove_if(in_eqn.begin(), in_eqn.end(), isspace), in_eqn.end()); - - // next we parse in_eq (the left hand side) by iterating. It is a string of comma separated terms per index - int64_t operand = 0; - std::stringstream eqn_stream(in_eqn); - std::string term; - int64_t num_total_idxes = 0; - while (! eqn_stream.eof()) { - std::getline(eqn_stream, term, ','); // term = string with indices of current term - TORCH_CHECK((int64_t) tensors.size()>operand, "more operands in equation than tensors"); // we cannot have a longer equation than operands. We need to check here before we use the dimension - - int64_t ell_char_count = 0; // handling of ellipsis '...' is a bit tedious, we count the '.' - // if there is an ellipsis, the number of dimensions it represents must be total dim - letter dimensions - int64_t candidate_num_ell_idxes = tensors[operand].dim() - term.size() + 3; - int64_t dims_in_term = 0; // dimensions we have seen - std::vector current_op_idxes; // mapping of operand dimensions to indices for current term - for (auto &c : term) { // c = character with a single letter or '.' - if (c == '.') { - ell_char_count++; - TORCH_CHECK(ell_char_count <= 3, "can only have '.' in one ellispis '...' in term ", operand, " of the equation"); - if (ell_char_count == 3) { // this completes the ellipsis - if (num_ell_idxes == -1) { // if we have not seen an ellipsis before, keep track of indices and size - first_ell_idx = num_total_idxes; - num_ell_idxes = candidate_num_ell_idxes; - num_total_idxes += num_ell_idxes; - } - else { // we have seen an ellipsis before, so we check compatibility - TORCH_CHECK(candidate_num_ell_idxes == num_ell_idxes, - "ellipsis must represent ", num_ell_idxes, " dimensions in all terms"); - } - for (int64_t i = 0; i < num_ell_idxes; ++i) { // map ellipsis dimensions in operand to indices - current_op_idxes.push_back(first_ell_idx + i); - last_idx_occurrence.push_back(operand); - } - dims_in_term += num_ell_idxes; // keep track of dimensions - } - } else { // a letter (hopefully) - TORCH_CHECK((ell_char_count == 0) || (ell_char_count == 3), "'.' must only occur in ellipsis, operand ", operand); - TORCH_CHECK(('a' <= c) && (c <= 'z'), "only lowercase letters a-z allowed as indices"); - int64_t letter_num = c-'a'; // letter_num = position in letter_mapping - if (letter_mapping[letter_num] == -1) { // new letter, add internal index and mapping - letter_mapping[letter_num] = num_total_idxes; - num_total_idxes++; - last_idx_occurrence.push_back(operand); - } else { // letter we have already seen - last_idx_occurrence[letter_mapping[letter_num]] = operand; - } - num_letter_occurrences[letter_num]++; - current_op_idxes.push_back(letter_mapping[letter_num]); - dims_in_term++; - } +// There are roughly three parts to compute einsum: +// 1. Parse equation to extract the labels for each input operand and output +// 2. Unsqueeze missing dimensions from input operands and permute to align them +// 3. Compute result by multiplying input operands and summing contraction +// dimensions We do the last part by reducing to bmm. +Tensor einsum(std::string equation, TensorList operands) { + TORCH_CHECK(!operands.empty(), "einsum() must provide at least one operand"); + checkDeviceType("einsum()", operands, operands[0].device().type()); + + // Code used to identify ELLIPSIS ("...") + constexpr int ELLIPSIS = '.'; + + // Find arrow (->) to split equation into lhs and rhs + const auto arrow_pos = equation.find("->"); + const auto lhs = equation.substr(0, arrow_pos); + + const auto num_ops = operands.size(); + + // Convert labels for input operands into an index in [0, 25] and store + // them in op_labels for each operand along with ELLIPSIS if present. + std::vector> op_labels(num_ops); + bool found_ell = false; + std::size_t curr_op = 0; + for (auto i = decltype(lhs.length()){0}; i < lhs.length(); ++i) { + switch (lhs[i]) { + case ' ': + // Ignore spaces + break; + + case '.': + TORCH_CHECK( + // Only one ellipsis per operand can be given + !found_ell, + "einsum() found \'.\' for operand ", + curr_op, + " for which an ellipsis was already found"); + TORCH_CHECK( + // Ensure it's a valid ellipsis + i + 2 < lhs.length() && lhs[++i] == '.' && lhs[++i] == '.', + "einsum() found \'.\' for operand ", + curr_op, + " that is not part of any ellipsis"); + op_labels[curr_op].push_back(ELLIPSIS); + found_ell = true; + break; + + case ',': + // Move onto next operand + ++curr_op; + TORCH_CHECK( + curr_op < num_ops, + "einsum() fewer operands were provided than specified in the equation"); + found_ell = false; + break; + + default: + // Parse label + TORCH_CHECK( + lhs[i] >= 'a' && lhs[i] <= 'z', + "einsum() operand subscript must be in range [a, z] but found ", + lhs[i], + " for operand ", + curr_op); + // Convert label to index in [0, 25] and store + op_labels[curr_op].push_back(lhs[i] - 'a'); } - TORCH_CHECK(dims_in_term == tensors[operand].dim(), "dimension mismatch for operand ", operand, ": equation ", dims_in_term, " tensor ", tensors[operand].dim()); - input_op_idxes.push_back(std::move(current_op_idxes)); - operand++; } - // in the check below, we need ==, but > is captured above, so the error message can be specific that it is <. - TORCH_CHECK((int64_t) tensors.size()==operand, "more tensors than operands in equation"); - - // the following parses or infers output (right hand side) - // it also assigns the idxes_to_preprocessed_dims (index -> dimension in preprocessed / output tensors) - // for the output indices. -1 means that the index has not been assigned a dimension yet - std::vector idxes_to_preprocessed_dims(num_total_idxes, -1); // the position of the index in the tensor dimensions - int64_t num_output_dims = 0; - if (pos != std::string::npos) { // parse the user provided right hand side - int64_t ell_char_count = 0; - for (auto &c : eqn.substr(pos+2)) { - if (c == '.') { // '.' as part of ellipsis - ell_char_count++; - TORCH_CHECK(ell_char_count <= 3, "can only have '.' in one ellispis '...' in right hand side of the equation"); - if (ell_char_count == 3) { // ellipsis complete - TORCH_CHECK(num_ell_idxes >= 0, "ellipsis '...' may only appear in right hand side if it does in left hand side"); - for (int64_t i = 0; i < num_ell_idxes; ++i) { - idxes_to_preprocessed_dims[first_ell_idx + i] = num_output_dims; - num_output_dims++; - } - } - } else if (! isspace(c)) { // letter (hopefully) - TORCH_CHECK((ell_char_count == 0) || (ell_char_count == 3), "'.' must only occur in ellipsis in the right hand side"); - TORCH_CHECK(('a' <= c) && (c <= 'z'), "only lowercase letters a-z allowed as indices"); - int64_t letter_num = c-'a'; - TORCH_CHECK(idxes_to_preprocessed_dims[letter_mapping[letter_num]] == -1, "index ", c, " occurs twice in output"); - idxes_to_preprocessed_dims[letter_mapping[letter_num]] = num_output_dims; - num_output_dims++; + + TORCH_CHECK( + curr_op == num_ops - 1, + "einsum() more operands were provided than specified in the equation"); + + // Labels must be within [a, z]. + constexpr int TOTAL_LABELS = 'z' - 'a' + 1; + std::vector label_count(TOTAL_LABELS, 0); + + // The maximum number of dimensions covered by any ellipsis, needed when + // unsqueezing missing dimensions from operands to permute and broadcast + int64_t ell_num_dim = 0; + + // Compute label frequency and number of dimensions covered by ellipsis + // We do this after parsing labels to make it more readable and simpler + // to compute the number of dimensions covered by ellipsis. + for (auto i = decltype(num_ops){0}; i < num_ops; ++i) { + const auto operand = operands[i]; + const auto labels = op_labels[i]; + const int64_t ndims = operand.dim(); + int64_t nlabels = labels.size(); + bool has_ellipsis = false; + + for (const auto& label : labels) { + if (label == ELLIPSIS) { + --nlabels; + has_ellipsis = true; + ell_num_dim = std::max(ell_num_dim, ndims - nlabels); + } else { + ++label_count[label]; } } - } else { // create an inferred right hand side - // the ellipsis (if in the lhs) comes first - if (num_ell_idxes >= 0) { - for (int64_t i = 0; i < num_ell_idxes; ++i) { - idxes_to_preprocessed_dims[first_ell_idx + i] = num_output_dims; - num_output_dims++; + + TORCH_CHECK( + has_ellipsis ? nlabels <= ndims : nlabels == ndims, + "einsum() the number of subscripts in the equation (", + nlabels, + has_ellipsis ? ") is more than the number of dimensions (" + : ") does not match the number of dimensions (", + ndims, + ") for operand ", + i, + has_ellipsis ? "" : " and no ellipsis was given"); + } + + // We want to align the dimensions of every input tensor to have + // shape out_dims + sum_dims. For this, we create a mapping of label + // to index into the permuted shape. + std::vector label_perm_index(TOTAL_LABELS, -1); + + // Current index in the permuted shape + int64_t perm_index = 0; + + // Start index of ellipsis dimensions in the permuted shape + int64_t ell_index = 0; + found_ell = false; + + if (arrow_pos == std::string::npos) { + // Implicit output is ellipsis (...) + labels seen only once + perm_index = ell_num_dim; + found_ell = true; + for (int label = 0; label < TOTAL_LABELS; ++label) { + if (label_count[label] == 1) { + label_perm_index[label] = perm_index++; } } - // then the indices that occur exactly once in alphabetic order - for (size_t idx = 0; idx < number_of_letters; idx++) { - if (num_letter_occurrences[idx] == 1) { - idxes_to_preprocessed_dims[letter_mapping[idx]] = num_output_dims; - num_output_dims++; + } else { + // Parse explicit output + const auto rhs = equation.substr(arrow_pos + 2); + for (auto i = decltype(rhs.length()){0}; i < rhs.length(); ++i) { + switch (rhs[i]) { + case ' ': + // Ignore spaces + break; + + case '.': + TORCH_CHECK( + // There can only be one ellipsis in the output + !found_ell, + "einsum() found \'.\' for output but an ellipsis (...) was already found"); + TORCH_CHECK( + // Ensure ellipsis is correct + i + 2 < rhs.length() && rhs[++i] == '.' && rhs[++i] == '.', + "einsum() found \'.\' for output that is not part of any ellipsis (...)"); + ell_index = perm_index; + perm_index += ell_num_dim; + found_ell = true; + break; + + default: + TORCH_CHECK( + // Labels must be in [a, z] + rhs[i] >= 'a' && rhs[i] <= 'z', + "einsum() subscripts must be in range [a, z] but found ", + rhs[i], + " for the output"); + const auto label = rhs[i] - 'a'; + TORCH_CHECK( + // Ensure label appeared at least once for some input operand and at + // most once for the output + label_count[label] > 0 && label_perm_index[label] == -1, + "einsum() output subscript ", + rhs[i], + label_perm_index[label] > -1 + ? " appears more than once in the output" + : " does not appear in the equation for any input operand"); + label_perm_index[label] = perm_index++; } } } - // now we assign the idxes_to_preprocessed_dims (index -> dimension in preprocessed / output tensors) - // for the non-output indices - those that are eventually summed over - int64_t position = num_output_dims; - for (int64_t i = 0; i < num_total_idxes; i++) { - if (idxes_to_preprocessed_dims[i]==-1) { - idxes_to_preprocessed_dims[i] = position; - position++; + + // Save output size before adding contraction dims (dims to sum out) + const int64_t out_size = perm_index; + + // If ellipsis is not part of the output, add to contraction dimensions + if (!found_ell) { + ell_index = perm_index; + perm_index += ell_num_dim; + } + + // Add contraction labels (labels not present in output) + for (int label = 0; label < TOTAL_LABELS; ++label) { + if (label_count[label] > 0 && label_perm_index[label] == -1) { + label_perm_index[label] = perm_index++; } } - // we now "homogenize the dimensions", i.e. - // - take diagonals for duplicated indices - // - permute the dimensions to match the order given by idxes_to_preprocessed_dims - // - unsqueeze to create all dimensions for each index in each tensor where they are missing - // we also check that sizes match - // after this, all operands will have compatible shapes (i.e. all dimensions are aligned are broadcastable) - std::vector preprocessed_operands; - std::vector size_of_dims(num_total_idxes, -1); // keep track of sizes for each index, -1 means we have not seen a size yet - for (int64_t op = 0; op < (int64_t) tensors.size(); op++) { - auto preprocessed_op = tensors[op]; - std::vector idx_to_dim(num_total_idxes, -1); // the dimension which the index refers to in the original tensor, -1 means it does not appear - std::vector& current_op_input_idxes = input_op_idxes[op]; - int64_t dim = 0; // there are two dimension indices: dim is after taking diagonals, i is in input - for (size_t i = 0; i < current_op_input_idxes.size(); i++) { - auto idx = current_op_input_idxes[i]; - auto dim_out = idxes_to_preprocessed_dims[idx]; - if (idx_to_dim[dim_out] == -1) { // first appearance - idx_to_dim[dim_out] = dim; - if (size_of_dims[idx] == -1) { // keep track of sizes - size_of_dims[idx] = preprocessed_op.size(dim); - } - else { - TORCH_CHECK(size_of_dims[idx] == preprocessed_op.size(dim), "size of dimension does not match previous size, operand ", op, ", dim ", i); + // Here we unsqueeze missing dimensions to make all operands have the same + // number of dimensions. We take diagonals for repeated labels within the + // same operand. Finally we permute the operands to align dimensions as + // per the perm_out_index we computed above. + std::vector permuted_operands; + for (auto i = decltype(num_ops){0}; i < num_ops; ++i) { + std::vector perm_shape(perm_index, -1); + std::vector label_dim(TOTAL_LABELS, -1); + Tensor operand = operands[i]; + const auto labels = op_labels[i]; + const auto original_sizes = operand.sizes(); + + std::size_t j = 0; + for (const auto& label : labels) { + if (label == ELLIPSIS) { + // Add missing dimensions covered by the ellipsis + const int64_t num_missing_dim = + ell_num_dim - (original_sizes.size() - labels.size() + 1); + for (int64_t k = 0; k < num_missing_dim; ++k) { + operand = operand.unsqueeze(j); } - dim++; - } else { // duplicate dimension in tensor --> take diagonal of idx_to_dim[dim_out] and dim and put the diagonal dimension to idx_to_dim[dim_out] - TORCH_CHECK(size_of_dims[idx] == preprocessed_op.size(dim), "size of dimension does not match previous size, operand ", op, ", dim ", i); - preprocessed_op = preprocessed_op.diagonal(0, idx_to_dim[dim_out], dim); - // diagonal moves the diagonal dimension to the back - // now we permute the last dim back to idx_to_dim[dim_out] - std::vector perm(preprocessed_op.dim(), 0); - for (int64_t d = 0; d < preprocessed_op.dim(); d++) { - if (d == idx_to_dim[dim_out]) { - perm[d] = preprocessed_op.dim() - 1; - } else { - perm[d] = d - (d > idx_to_dim[dim_out]); - } + for (int64_t k = 0; k < ell_num_dim; ++k) { + perm_shape[ell_index + k] = j++; } - preprocessed_op = preprocessed_op.permute(perm); + } else if (label_dim[label] != -1) { + // Repeated label, take diagonal + const auto dim = label_dim[label]; + TORCH_CHECK( + operand.size(j) == operand.size(dim), + "einsum() subscript ", + char(label + 'a'), + " is repeated for operand ", + i, + " but the sizes don't match, ", + operand.size(j), + " != ", + operand.size(dim)); + operand = operand.diagonal(0, dim, j).movedim(-1, dim); + } else { + // Lookup output index for label + label_dim[label] = j; + perm_shape[label_perm_index[label]] = j++; } } - // now we permute the dimensions in the right order - std::vector permutation; // permutation for this tensor - for (auto &d : idx_to_dim) { - if (d > -1) { - permutation.push_back(d); + + // Add dimensions for missing labels + for (int64_t& index : perm_shape) { + if (index == -1) { + operand = operand.unsqueeze(-1); + index = j++; } } - preprocessed_op = preprocessed_op.permute(permutation); - // finally, we insert dimensions for idxes not in the operand - for (size_t dim = 0; dim < idx_to_dim.size(); dim++) { - if (idx_to_dim[dim] == -1) { - preprocessed_op = preprocessed_op.unsqueeze(dim); + + permuted_operands.push_back(operand.permute(perm_shape)); + } + + // Check if operands broadcast and keep track of last operand with + // dimension size != 1 for optimizing reductions + std::vector dim_last_op(perm_index, 0); + bool has_zero_size_dim = false; + for (int64_t dim = 0; dim < perm_index; ++dim) { + auto broadcast_size = permuted_operands[0].size(dim); + for (auto i = decltype(num_ops){1}; i < num_ops; ++i) { + const auto dim_size = permuted_operands[i].size(dim); + if (broadcast_size != dim_size && broadcast_size != 1 && dim_size != 1) { + std::ostringstream msg; + msg << "einsum() operands do not broadcast with remapped shapes [original->remapped]:"; + for (auto j = decltype(num_ops){0}; j < num_ops; ++j) { + msg << " " << operands[j].sizes() << "->" + << permuted_operands[j].sizes(); + } + TORCH_CHECK(false, msg.str()); + } + if (dim_size != 1) { + broadcast_size = dim_size; + dim_last_op[dim] = i; } } + has_zero_size_dim |= broadcast_size == 0; + } - preprocessed_operands.push_back(std::move(preprocessed_op)); + // Compute result + Tensor result = permuted_operands[0]; + + // Fast path for when an operand has zero sized dim + if (has_zero_size_dim) { + std::vector out_shape(out_size); + for (int64_t i = 0; i < out_size; ++i) { + out_shape[i] = permuted_operands[dim_last_op[i]].size(i); + } + return at::zeros(out_shape, result.options()); } - // now we reduce the indices from left to right - // numpy allows to optimize the path using various - // algorithms (see eigen_path in numpy docs) - // we start with the leftmost operator and reduce indices that - // appear only there - Tensor result = std::move(preprocessed_operands[0]); - for (int64_t idx = 0; idx < num_total_idxes; idx++) { - if ((last_idx_occurrence[idx] == 0) - && (idxes_to_preprocessed_dims[idx]>=num_output_dims)) { - result = result.sum(idxes_to_preprocessed_dims[idx], true); + // Sum out or squeeze dimensions that are size 1 for all later operands + int64_t dim = out_size; + for (int64_t i = dim; i < perm_index; ++i, ++dim) { + if (dim_last_op[i] == 0) { + if (result.size(dim) == 1) { + result = result.squeeze(dim--); + } else { + result = result.sum(dim--); + } } } - // now we process each tensor using sumproduct_pair - for (int64_t i = 1; i < (int64_t) preprocessed_operands.size(); i++) { + for (auto i = decltype(num_ops){1}; i < num_ops; ++i) { + Tensor operand = permuted_operands[i]; std::vector sum_dims; - for (int64_t idx = 0; idx < num_total_idxes; idx++) { - if ((last_idx_occurrence[idx] == i) - && (idxes_to_preprocessed_dims[idx]>=num_output_dims)) { - sum_dims.push_back(idxes_to_preprocessed_dims[idx]); + + // Sum out or squeeze dimensions that are size 1 for all later operands + dim = out_size; + for (int64_t j = dim; j < perm_index; ++j, ++dim) { + if (dim_last_op[j] < i) { + operand = operand.squeeze(dim); + --dim; + } else if (dim_last_op[j] == i) { + if (result.size(dim) == 1) { + operand = operand.sum(dim); + result = result.squeeze(dim); + --dim; + } else { + sum_dims.push_back(dim); + } } } - result = at::native::sumproduct_pair(result, std::move(preprocessed_operands[i]), sum_dims, true); - } - // finally, we squeeze out all non-result dimensions - auto sizes = result.sizes().vec(); - for (int64_t dim = num_total_idxes-1; dim >= num_output_dims; dim--) { - sizes.erase(sizes.begin() + dim); + + // Multiply tensors and sum out dimensions in sum_dims + if (sum_dims.empty()) { + result = result.mul(operand); + } else if (sum_dims.size() == result.sizes().size()) { + result = result.flatten().dot(operand.flatten()); + } else { + result = sumproduct_pair(result, operand, sum_dims, false); + } } - result = result.view(sizes); return result; } @@ -534,4 +629,10 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, return at::mm(t1, t2).reshape(rsizes); } +Tensor &tensordot_out(Tensor& result, const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2) { + result.copy_(at::native::tensordot(input1, input2, dims1, dims2)); + return result; +} + + }} // namespace at::native diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 9c3742c129de7..8809410599513 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -4,7 +4,11 @@ #include #include #include +#include #include +#include +#include +#include #include #include #include @@ -15,9 +19,13 @@ #include #include +#include + namespace at { namespace native { +DEFINE_DISPATCH(addr_stub); + // Helper function for det methods. // For pivoted LU factorization A = P * L * U. Since we always have det(L) = 1, // det(P) = \pm 1, this method returns a 3-tuple: @@ -28,10 +36,9 @@ static inline std::tuple _lu_det_P_diag_U(const Tensor& self) { std::tie(lu, pivs, infos) = at::_lu_with_info(self, /*pivot=*/true, /*check_errors=*/false); TORCH_CHECK(infos.ge(0).all().item(), "Invalid argument passed to lu"); auto n = self.size(-1); - auto num_exchanges = (at::arange(1, n + 1, pivs.options()) != pivs).sum(-1, /*keepdim=*/false, /*dtype=*/self.scalar_type()).fmod_(2); - // NB: the `.contiguous()` call is added due to the bug in `.prod()` as reported in - // issue #https://github.com/pytorch/pytorch/issues/34061 - auto u_diagonal = lu.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).contiguous(); + auto num_exchanges = (at::arange(1, n + 1, pivs.options()) != pivs) + .sum(-1, /*keepdim=*/false, /*dtype=*/at::kLong).fmod_(2); + auto u_diagonal = lu.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1); return std::tuple(num_exchanges.mul_(-2).add_(1), u_diagonal); } @@ -67,101 +74,307 @@ Tensor logdet(const Tensor& self) { // U is singular when U(i, i) = 0 for some i in [1, self.size(-1)]. Tensor logdet_vals = diag_U.abs_().log_().sum(-1); if (self.dim() > 2) { - logdet_vals.index_put_((det_sign < 0).nonzero_numpy(), at::full({}, NAN, self.options())); + auto indices = toListOfOptionalTensors((det_sign < 0).nonzero_numpy()); + logdet_vals.index_put_(std::move(indices), at::full({}, NAN, self.options())); } else if (det_sign.item() < 0) { logdet_vals.fill_(NAN); } return logdet_vals; } -std::tuple slogdet(const Tensor& self) { +std::tuple linalg_slogdet(const Tensor& self) { squareCheckInputs(self); - TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())), - "Expected a floating point tensor as input"); + ScalarType t = self.scalar_type(); + TORCH_CHECK(t == ScalarType::Double || t == ScalarType::Float || t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble, + "linalg_slogdet: expected a tensor of float, double, cfloat or cdouble types but got ", t); Tensor det_P, diag_U; std::tie(det_P, diag_U) = _lu_det_P_diag_U(self); - auto det_sign = diag_U.sign().prod(-1).mul_(det_P); + auto det_sign = diag_U.sgn().prod(-1).mul_(det_P); // abslogdet_val is -inf if U is singular, in which case diag_U.abs_().log_().sum(-1) will return -inf. // U is singular when U(i, i) = 0 for some i in [1, self.size(-1)]. // Since abslogdet_val cannot take nan, no special case handling is required. - auto abslogdet_val = diag_U.abs_().log_().sum(-1); + // in-place abs is not supported for complex tensors + auto abslogdet_val = isComplexType(t) ? diag_U.abs().log_().sum(-1) : diag_U.abs_().log_().sum(-1); return std::make_tuple(det_sign, abslogdet_val); } +// TODO: implement _out variant avoiding copy and using already allocated storage directly +std::tuple linalg_slogdet_out(const Tensor& input, Tensor& sign, Tensor& logabsdet) { + TORCH_CHECK(sign.scalar_type() == input.scalar_type(), + "sign dtype ", sign.scalar_type(), " does not match input dtype ", input.scalar_type()); + ScalarType real_dtype = toValueType(typeMetaToScalarType(input.dtype())); + TORCH_CHECK(logabsdet.scalar_type() == real_dtype, + "logabsdet dtype ", logabsdet.scalar_type(), " does not match the expected dtype ", real_dtype); + TORCH_CHECK(sign.device() == input.device() && logabsdet.device() == input.device(), + "Expected sign, logabsdet and input to be on the same device, but found sign on ", + sign.device(), ", logabsdet on ", logabsdet.device(), " and input on ", input.device(), " instead."); + + Tensor sign_tmp, logabsdet_tmp; + std::tie(sign_tmp, logabsdet_tmp) = at::linalg_slogdet(input); + + at::native::resize_output(sign, sign_tmp.sizes()); + sign.copy_(sign_tmp); + at::native::resize_output(logabsdet, logabsdet_tmp.sizes()); + logabsdet.copy_(logabsdet_tmp); + + return std::tuple(sign, logabsdet); +} + +std::tuple slogdet(const Tensor& self) { + return at::linalg_slogdet(self); +} + +Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) { + ScalarType t = input.scalar_type(); + TORCH_CHECK((t == ScalarType::Double || t == ScalarType::Float || t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble) + && input.dim() >= 2, + "linalg_pinv(", t, "{", input.sizes(), "}): expected a tensor with 2 or more dimensions " + "of float, double, cfloat or cdouble types"); + TORCH_CHECK(rcond.device() == input.device(), + "Expected rcond and input to be on the same device, but found rcond on ", + rcond.device(), " and input on ", input.device(), " instead."); + TORCH_CHECK(!at::isComplexType(rcond.scalar_type()), + "linalg_pinv: rcond tensor of complex type is not supported."); + + if (input.numel() == 0) { + // The implementation below uses operations that do not work for zero numel tensors + // therefore we need this early return for 'input.numel() == 0' case + auto input_sizes = input.sizes().vec(); + std::swap(input_sizes[input.dim() - 1], input_sizes[input.dim() - 2]); + return at::empty(input_sizes, input.options()); + } + + // If not Hermitian use singular value decomposition, else use eigenvalue decomposition + if (!hermitian) { + // until https://github.com/pytorch/pytorch/issues/45821 is resolved + // svd() returns conjugated V for complex-valued input + Tensor U, S, V_conj; + // TODO: replace input.svd with linalg_svd + std::tie(U, S, V_conj) = input.svd(); + Tensor max_val = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1); // singular values are sorted in descending order + Tensor S_pseudoinv = at::where(S > (rcond.unsqueeze(-1) * max_val), S.reciprocal(), at::zeros({}, S.options())).to(input.dtype()); + // computes V @ diag(S_pseudoinv) @ U.T.conj() + // TODO: replace V_conj.conj() -> V once https://github.com/pytorch/pytorch/issues/45821 is resolved + return at::matmul(V_conj.conj() * S_pseudoinv.unsqueeze(-2), U.conj().transpose(-2, -1)); + } else { + Tensor S, U; + std::tie(S, U) = at::linalg_eigh(input); + // For Hermitian matrices, singular values equal to abs(eigenvalues) + Tensor S_abs = S.abs(); + // eigenvalues are sorted in ascending order starting with negative values, we need a maximum value of abs(eigenvalues) + Tensor max_val = S_abs.amax(/*dim=*/-1, /*keepdim=*/true); + Tensor S_pseudoinv = at::where(S_abs > (rcond.unsqueeze(-1) * max_val), S.reciprocal(), at::zeros({}, S.options())).to(input.dtype()); + // computes U @ diag(S_pseudoinv) @ U.conj().T + return at::matmul(U * S_pseudoinv.unsqueeze(-2), U.conj().transpose(-2, -1)); + } +} + +Tensor linalg_pinv(const Tensor& input, double rcond, bool hermitian) { + Tensor rcond_tensor = at::full({}, rcond, input.options().dtype(ScalarType::Double)); + return at::linalg_pinv(input, rcond_tensor, hermitian); +} + +// TODO: implement _out variant avoiding copy and using already allocated storage directly +Tensor& linalg_pinv_out(Tensor& result, const Tensor& input, const Tensor& rcond, bool hermitian) { + TORCH_CHECK(result.scalar_type() == input.scalar_type(), + "result dtype ", result.scalar_type(), " does not match the expected dtype ", input.scalar_type()); + TORCH_CHECK(result.device() == input.device(), + "Expected result and input to be on the same device, but found result on ", + result.device(), " and input on ", input.device(), " instead."); + + Tensor result_tmp = at::linalg_pinv(input, rcond, hermitian); + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); + return result; +} + +Tensor& linalg_pinv_out(Tensor& result, const Tensor& input, double rcond, bool hermitian) { + Tensor rcond_tensor = at::full({}, rcond, input.options().dtype(ScalarType::Double)); + return at::linalg_pinv_out(result, input, rcond_tensor, hermitian); +} + Tensor pinverse(const Tensor& self, double rcond) { - TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())) && self.dim() >= 2, - "pinverse(", self.scalar_type(), "{", self.sizes(), "}): expected a tensor with 2 or more dimensions " - "of floating types"); + return at::linalg_pinv(self, rcond, /*hermitian=*/false); +} + +Tensor& linalg_matrix_rank_out(Tensor& result, const Tensor& self, optional tol, bool hermitian) { + TORCH_CHECK(result.scalar_type() == ScalarType::Long, + "result dtype ", result.scalar_type(), " does not match the expected dtype ", ScalarType::Long); + + // Matrices or batch of matrices are allowed + TORCH_CHECK(self.dim() >= 2, "linalg_matrix_rank: Expected as input a matrix or a batch of matrices, but got a tensor of size: ", self.sizes()); + + // matrix_rank assigns a scalar value for each matrix in the batch so + // result's shape is equal to self.shape[0:self.ndim-2] + // for single matrix result_shape = {} + auto result_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend()-2); + at::native::resize_output(result, result_shape); + + // NumPy doesn't take into account possible input with no elements and it errors on max not defined for this case + // Let's output 0 for this case, since that kind of matrices have zero number of non-zero rows, hence rank is 0. if (self.numel() == 0) { - // Match NumPy - auto self_sizes = self.sizes().vec(); - std::swap(self_sizes[self.dim() - 1], self_sizes[self.dim() - 2]); - return at::empty(self_sizes, self.options()); + result.fill_(0); + return result; } - Tensor U, S, V; - std::tie(U, S, V) = self.svd(); - Tensor max_val = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1); - Tensor S_pseudoinv = at::where(S > rcond * max_val, S.reciprocal(), at::zeros({}, self.options())); - return at::matmul(V, at::matmul(S_pseudoinv.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1), U.transpose(-2, -1))); -} -static inline Tensor _matrix_rank_helper(const Tensor& self, bool symmetric) { + // We compute matrix rank as the number of singular or absolute eigen values above 'tol' threshold Tensor S; - if (!symmetric) { + if (!hermitian) { Tensor U, V; + // TODO: replace self.svd with linalg_svd std::tie(U, S, V) = self.svd(/*some=*/true, /*compute_uv=*/false); } else { - Tensor eigvecs; - std::tie(S, eigvecs) = self.symeig(/*eigenvectors=*/false); + S = at::linalg_eigvalsh(self); S = S.abs(); } - return S; + + if (tol.has_value()) { + double tol_value = tol.value(); + at::sum_out(result, S > tol_value, /*dim=*/-1); + } else { + ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype())); + double tol_value = _get_epsilon(real_dtype) * std::max(self.size(-1), self.size(-2)); + Tensor max_S = S.amax(/*dim=*/-1); + at::sum_out(result, S > max_S.mul_(tol_value).unsqueeze_(-1), /*dim=*/-1); + } + return result; } -Tensor matrix_rank(const Tensor& self, double tol, bool symmetric) { - TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())) && self.dim() == 2, - "matrix_rank(", self.scalar_type(), "{", self.sizes(), "}): expected a 2D tensor " - "of floating types"); +Tensor linalg_matrix_rank(const Tensor& self, optional tol, bool hermitian) { + Tensor result = at::empty({0}, self.options().dtype(ScalarType::Long)); + result = at::linalg_matrix_rank_out(result, self, tol, hermitian); + return result; +} - Tensor S = _matrix_rank_helper(self, symmetric); - return (S > tol).sum(); +Tensor matrix_rank(const Tensor& self, double tol, bool symmetric) { + return at::linalg_matrix_rank(self, optional(tol), symmetric); } Tensor matrix_rank(const Tensor& self, bool symmetric) { - TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())) && self.dim() == 2, - "matrix_rank(", self.scalar_type(), "{", self.sizes(), "}): expected a 2D tensor " - "of floating types"); - - Tensor S = _matrix_rank_helper(self, symmetric); - double tol = _get_epsilon(self.scalar_type()) * std::max(self.size(0), self.size(1)); - return (S > S.max().mul_(tol)).sum(); + return at::linalg_matrix_rank(self, c10::nullopt, symmetric); } static void check_1d(const Tensor& t, const char* arg, const char* fn) { TORCH_CHECK(t.dim() == 1, fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D"); } -Tensor addr(const Tensor& self, const Tensor& vec1, const Tensor& vec2, Scalar beta, Scalar alpha) { - TORCH_WARN( - "torch.addr is deprecated and may be removed in a future PyTorch release. " - "This function can be implemented using torch.outer as " - "alpha * torch.outer(vec1, vec2) + beta * input when beta is not zero, " - "alpha * torch.outer(vec1, vec2) when beta is zero."); +static void check_addr_scalar(const ScalarType dtype, + const Scalar scalar, + const std::string& scalar_name) { + TORCH_CHECK( + !scalar.isBoolean() || dtype == ScalarType::Bool, + "Boolean ", scalar_name, " only supported for Boolean results."); + TORCH_CHECK( + isFloatingType(dtype) || isComplexType(dtype) || scalar.isIntegral(true), + "For integral input tensors, " + "argument ", scalar_name ," must not be a floating point number."); +} + +static TensorIterator build_addr_iter(Tensor& result, + const Tensor& self, + const Tensor& vec1, + const Tensor& vec2) { + check_1d(vec1, "vec1", "addr"); + check_1d(vec2, "vec2", "addr"); - Tensor outer_result = at::outer(vec1, vec2) * alpha; - if (beta.to() == 0.0) { - return outer_result; + Tensor self_; + if (&result != &self) { + std::tie(self_) = expand_size(self, {vec1.size(0), vec2.size(0)}, "addr"); + } else { + self_ = self; } - return outer_result + (self * beta); + TORCH_CHECK( + self_.dim() == 2, + "2D tensor expected, got ", self_.dim(), "D tensor for input" + ); + TORCH_CHECK( + self_.size(0) == vec1.size(0) && self_.size(1) == vec2.size(0), + "size mismatch, input: ", self_.sizes(), + ", v1: ", vec1.sizes(), + ", v2: ", vec2.sizes() + ); + + auto iter = TensorIteratorConfig() + .set_check_mem_overlap(true) + .add_output(result) + .add_input(self_) + .add_input(vec1.reshape({vec1.size(0), 1})) + .add_input(vec2) + .allow_cpu_scalars(true) + .promote_inputs_to_common_dtype(true) + .cast_common_dtype_to_outputs(true) + .enforce_safe_casting_to_output(true) + .build(); + return iter; +} + +Tensor addr(const Tensor& self, + const Tensor& vec1, const Tensor& vec2, + Scalar beta, Scalar alpha) { + Tensor result; + auto iter = build_addr_iter(result, self, vec1, vec2); + + check_addr_scalar(iter.dtype(), beta, "beta"); + check_addr_scalar(iter.dtype(), alpha, "alpha"); + + addr_stub(iter.device_type(), iter, beta, alpha); + return iter.output(); } -Tensor& addr_(Tensor& self, const Tensor& vec1, const Tensor& vec2, Scalar beta, Scalar alpha) { +Tensor& addr_(Tensor& self, + const Tensor& vec1, const Tensor& vec2, + Scalar beta, Scalar alpha) { return at::addr_out(self, self, vec1, vec2, beta, alpha); } -Tensor& addr_out(Tensor &result, const Tensor& self, const Tensor& vec1, const Tensor& vec2, Scalar beta, Scalar alpha) { +Tensor& addr_out(Tensor &result, + const Tensor& self, + const Tensor& vec1, const Tensor& vec2, + Scalar beta, Scalar alpha) { + auto iter = build_addr_iter(result, self, vec1, vec2); + + check_addr_scalar(iter.dtype(), beta, "beta"); + check_addr_scalar(iter.dtype(), alpha, "alpha"); + + addr_stub(iter.device_type(), iter, beta, alpha); + return result; +} + +// The math_addr and math_addr_out functions support backends +// other than CPU and CUDA, such as XLA. +// They are implemented using the composition of existing ops +Tensor math_addr(const Tensor& self, + const Tensor& vec1, const Tensor& vec2, + Scalar beta, Scalar alpha) { + // when beta==0, values in self should be ignored, + // nans and infs in self should not propagate. + if (beta.toComplexDouble() == 0.0) { + if (alpha.toComplexDouble() == 1.0) { + return at::outer(vec1, vec2); + } + return alpha * at::outer(vec1, vec2); + } + + if (beta.toComplexDouble() == 1.0) { + if (alpha.toComplexDouble() == 1.0) { + return self + at::outer(vec1, vec2); + } + return self + alpha * at::outer(vec1, vec2); + } + + if (alpha.toComplexDouble() == 1.0) { + return beta * self + at::outer(vec1, vec2); + } + return beta * self + alpha * at::outer(vec1, vec2); +} + +Tensor& math_addr_out(Tensor &result, + const Tensor& self, + const Tensor& vec1, const Tensor& vec2, + Scalar beta, Scalar alpha) { auto addr_result = at::addr(self, vec1, vec2, beta, alpha); + // Validates safe casting const auto result_dtype = addr_result.scalar_type(); TORCH_CHECK(canCast(result_dtype, result.scalar_type()), @@ -184,6 +397,46 @@ Tensor ger(const Tensor& self, const Tensor& vec2) { return self.outer(vec2); } +Tensor& inner_out(Tensor& out, const Tensor& self, const Tensor& other) { + checkDeviceType("inner()", {out, self, other}, self.device().type()); + + // If either self or other is a scalar just multiply them + if (self.dim() == 0 || other.dim() == 0) { + at::mul_out(out, self, other); + return out; + } + + // Last dimension should match (tensordot does not enforce this) + TORCH_CHECK( + self.size(-1) == other.size(-1), + "inner() the last dimension must match on both input tensors but got shapes ", + self.sizes(), + " and ", + other.sizes()); + + at::tensordot_out(out, self, other, -1, -1); + return out; +} + +Tensor inner(const Tensor& self, const Tensor& other) { + checkDeviceType("inner()", {self, other}, self.device().type()); + + // If either self or other is a scalar just multiply them + if (self.dim() == 0 || other.dim() == 0) { + return self * other; + } + + // Last dimension should match (tensordot does not enforce this) + TORCH_CHECK( + self.size(-1) == other.size(-1), + "inner() the last dimension must match on both input tensors but got shapes ", + self.sizes(), + " and ", + other.sizes()); + + return at::tensordot(self, other, -1, -1); +} + Tensor& outer_out(Tensor &result, const Tensor& self, const Tensor& vec2) { check_1d(self, "self", "outer"); check_1d(vec2, "vec2", "outer"); @@ -316,7 +569,7 @@ static void addmm_impl_cpu_( } } -static void addbmm_impl_cpu_( +static void addbmm_impl_( Tensor &result, const Tensor &self, const Tensor &batch1, const Tensor &batch2, Scalar beta, Scalar alpha) { TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor"); TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor"); @@ -335,35 +588,44 @@ static void addbmm_impl_cpu_( result.resize_as_(self); - if (beta.to() != 0.0 && !self.is_same(result)) { + if (beta.to>() != 0.0 && !self.is_same(result)) { result.copy_(self); } const int64_t num_batches = batch1.size(0); + if (num_batches == 0) { + if (beta.to>() != 0.0) { + result.mul_(beta); + } else { + result.zero_(); + } + return; + } + for (int64_t batch = 0; batch < num_batches; ++batch) { - addmm_impl_cpu_(result, result, batch1[batch], batch2[batch], beta, alpha); + result.addmm_(batch1[batch], batch2[batch], beta, alpha); beta = 1; // accumulate output once } } -Tensor& addbmm_cpu_out(Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { +Tensor& addbmm_out(Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { Tensor b_self = std::get<0>(expand_size(self, {batch1.size(1), batch2.size(2)}, "addbmm_out")); { at::NoNamesGuard guard; - addbmm_impl_cpu_(result, b_self, batch1, batch2, beta, alpha); + addbmm_impl_(result, b_self, batch1, batch2, beta, alpha); } at::namedinference::propagate_names_for_addmm(result, batch1, batch2, self); return result; } -Tensor &addbmm_cpu_(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { - return addbmm_cpu_out(self, self, batch1, batch2, beta, alpha); +Tensor &addbmm_(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { + return native::addbmm_out(self, self, batch1, batch2, beta, alpha); } -Tensor addbmm_cpu(const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { +Tensor addbmm(const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { Tensor result = at::empty({0}, self.options()); - return addbmm_cpu_out(result, self, batch1, batch2, beta, alpha); + return native::addbmm_out(result, self, batch1, batch2, beta, alpha); } Tensor& addmm_cpu_out(Tensor &result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) { @@ -456,33 +718,45 @@ static inline Tensor& bmm_out_or_baddbmm_(Tensor& self_or_result, const Tensor& // is_bmm_out: true for bmm_out, false for baddbmm_ // self_or_result is "self" for baddbmm_ and "result" for bmm_out CheckedFrom c = (is_bmm_out ? "bmm" : "baddbmm"); - TensorArg self_arg(self_or_result, is_bmm_out ? "self" : "result", 0); - TensorArg b1_arg(batch1, "batch1", 1); - TensorArg b2_arg(batch2, "batch2", 2); - checkBackend(c, {self_or_result, batch1, batch2}, Backend::CPU); - checkDim(c, b1_arg, 3); - checkDim(c, b2_arg, 3); - - int64_t bs = batch1.size(0); - checkSize(c, b2_arg, 0, bs); - int64_t contraction_size = batch1.size(2); - int64_t res_rows = batch1.size(1); - int64_t res_cols = batch2.size(2); - checkSize(c, b2_arg, 1, contraction_size); + + auto checkOnCPU = [](const Tensor& t, CheckedFrom c) { + TORCH_CHECK( + !t.is_cuda(), + "Expect tensor to have CPU backend, but got tensor with ", + toString(t.options().backend()), + " Backend (while checking arguments for ", + c); + }; + + checkOnCPU(self_or_result, c); + checkOnCPU(batch1, c); + checkOnCPU(batch2, c); + + checkDim(c, batch1, "batch1", /* pos */ 1, /* dim */ 3); + checkDim(c, batch2, "batch2", /* pos */ 2, /* dim */ 3); + + const auto batch1_sizes = batch1.sizes(); + const auto batch2_sizes = batch2.sizes(); + + int64_t bs = batch1_sizes[0]; + int64_t contraction_size = batch1_sizes[2]; + int64_t res_rows = batch1_sizes[1]; + int64_t res_cols = batch2_sizes[2]; + + TORCH_CHECK(batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size); if (is_bmm_out) { self_or_result.resize_({bs, res_rows, res_cols}); } else { - checkSize(c, self_arg, 0, bs); - checkSize(c, self_arg, 1, res_rows); - checkSize(c, self_arg, 2, res_cols); + const auto self_sizes = self_or_result.sizes(); + TORCH_CHECK(self_sizes[0] == bs && self_sizes[1] == res_rows && self_sizes[2] == res_cols); } // handle pathological cases that blas may not like if (self_or_result.numel() == 0) { return self_or_result; } else if (contraction_size == 0) { - if (is_bmm_out) { + if (is_bmm_out || (beta.to>() == 0.0)) { return self_or_result.zero_(); } else { return self_or_result.mul_(beta); @@ -490,21 +764,26 @@ static inline Tensor& bmm_out_or_baddbmm_(Tensor& self_or_result, const Tensor& } auto batch_items_contiguous_or_transposed = [&](const Tensor& t) { - return (t.stride(2) == 1 && t.stride(1) >= t.size(2)) - || (t.stride(1) == 1 && t.stride(2) >= t.size(1)); + const auto sizes = t.sizes(); + const auto strides = t.strides(); + return (strides[2] == 1 && strides[1] >= sizes[2]) + || (strides[1] == 1 && strides[2] >= sizes[1]); }; if (contraction_size * res_rows * res_cols < 400) { if (is_bmm_out) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX(batch1.scalar_type(), "bmm", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, batch1.scalar_type(), "bmm", [&] { baddbmm_cpu_kernel(self_or_result, batch1, batch2, beta, alpha); }); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX(batch1.scalar_type(), "baddbmm", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, batch1.scalar_type(), "baddbmm", [&] { baddbmm_cpu_kernel(self_or_result, batch1, batch2, beta, alpha); }); } - } else if (at::hasMKL() && (at::native::is_floating_point(self_or_result) || + } else if (at::hasMKL() && (( + self_or_result.scalar_type() != kHalf && + self_or_result.scalar_type() != kBFloat16 && + at::native::is_floating_point(self_or_result)) || at::native::is_complex(self_or_result)) && batch_items_contiguous_or_transposed(batch1) && batch_items_contiguous_or_transposed(batch2) @@ -673,8 +952,8 @@ Tensor matmul( std::vector tensor2_expand_size(expand_batch_portion); tensor2_expand_size.insert(tensor2_expand_size.end(), {m2, p}); - int expand_batch_product = std::accumulate(expand_batch_portion.begin(), expand_batch_portion.end(), - 1, std::multiplies()); + const int64_t expand_batch_product = + prod_intlist(expand_batch_portion); std::vector tensor1_bmm_view({expand_batch_product}); tensor1_bmm_view.insert(tensor1_bmm_view.end(), {n, m1}); @@ -740,7 +1019,7 @@ Tensor _allocate_buffer(const Tensor& a, int n_copies, bool is_zero = false) { {n_copies, a.size(0), a.size(1), a.size(2)}, a.options().memory_format(at::MemoryFormat::Contiguous) ); - + if (is_zero) { res.zero_(); } @@ -820,8 +1099,8 @@ inline Tensor _blob_to_Tensor( // Blob is assumed to be a 1D array, that is why // we also insert a fake dimension so that the result could directly // be used in _compute_linear_combination - auto tensor = at::from_blob((void*)blob.begin(), blob.size(), in.dtype()) - .unsqueeze(0); + auto tensor = at::from_blob((void*)blob.begin(), blob.size(), + c10::toValueType(in.scalar_type())).unsqueeze(0); return _move_memory_if_cuda_input(tensor, in); } @@ -848,7 +1127,7 @@ Tensor compute_T4(const Tensor& A) { auto As = _allocate_buffer(A, 4); // 3 for {I, A, A^2} _fill_matrix_powers(As, A, 3); - + at::native::matmul( // output for A^2 * (I / 2 + A / 6 + A^2 / 24) As.select(0, 3), @@ -954,7 +1233,7 @@ Tensor compute_T12(const Tensor& A) { reinterpret_cast(&b), {num_prods, num_prods}, {num_prods, 1}, - A.dtype() + c10::toValueType(A.scalar_type()) ); bs = _move_memory_if_cuda_input(bs, A); @@ -1026,7 +1305,7 @@ Tensor compute_T18(const Tensor& A) { reinterpret_cast(&b), {num_prods, num_prods}, {num_prods, 1}, - A.dtype() + c10::toValueType(A.scalar_type()) ); bs = _move_memory_if_cuda_input(bs, A); @@ -1099,7 +1378,7 @@ Tensor mexp_impl( if (!compute_highest_degree_approx) { constexpr std::array< Tensor(*)(const Tensor&), - total_n_degs - 1> + total_n_degs - 1> compute_Ts = { compute_T1, compute_T2, compute_T4, compute_T8, compute_T12 @@ -1190,7 +1469,7 @@ Tensor mexp(const Tensor& a, bool compute_highest_degree_approx = false) { // Based on: // -// Mathias, Roy. +// Mathias, Roy. // A Chain Rule for Matrix Functions and Applications. // SIAM J. Matrix Anal. Appl. 17 (1996): 610-620. // @@ -1199,7 +1478,7 @@ Tensor backward_analytic_function_of_a_matrix( const Tensor& self, const Tensor& grad, const func_t& function_of_a_matrix ) { - auto self_transposed = self.transpose(-2, -1); + auto self_transposed = self.transpose(-2, -1).conj(); auto self_transposed_sizes = self_transposed.sizes().vec(); self_transposed_sizes[self.dim() - 2] <<= 1; self_transposed_sizes[self.dim() - 1] <<= 1; @@ -1225,8 +1504,8 @@ Tensor backward_analytic_function_of_a_matrix( // Mathematics 2019, 7, 1174. // Tensor matrix_exp(const Tensor& a) { - TORCH_CHECK(a.dim() >= 2 - && (at::isFloatingType(a.scalar_type()) + TORCH_CHECK(a.dim() >= 2 + && (at::isFloatingType(a.scalar_type()) || at::isComplexType(a.scalar_type())), "matrix_exp(", a.scalar_type(), "{", a.sizes(), "}): expected a tensor " "of floating or complex types with dim at least 2"); @@ -1292,14 +1571,13 @@ Tensor matrix_power(const Tensor& a, int64_t n) { } Tensor frobenius_norm(const Tensor& self) { - TORCH_CHECK(!self.is_complex(), "frobenius norm not supported for complex tensors"); return at::norm(self); } Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) { // NOTE: As frobenius_norm_out is currently implemented, it will always produce a // strided tensor result, even if the input is sparse. - auto options = self.options().layout(c10::Layout::Strided); + auto options = self.options().layout(c10::Layout::Strided).dtype(toValueType(self.scalar_type())); Tensor result = at::empty({0}, options); return at::native::frobenius_norm_out(result, self, dim, keepdim); } @@ -1309,7 +1587,6 @@ Tensor &frobenius_norm_out( const Tensor& self, IntArrayRef dim, bool keepdim) { - TORCH_CHECK(!self.is_complex(), "frobenius norm not supported for complex tensors"); TORCH_CHECK( dim.size() <= 2, "Expected at most 2 dimensions, but got ", @@ -1353,7 +1630,7 @@ Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) { } Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) { - Tensor result = at::empty({0}, self.options()); + Tensor result = at::empty({0}, self.options().dtype(toValueType(self.scalar_type()))); return at::native::nuclear_norm_out(result, self, dim, keepdim); } @@ -1389,9 +1666,8 @@ static std::vector make_dim_list(int64_t ndim) { } // Checks for valid arguments to linalg_norm when type(ord) == str -static void check_str_ord_valid(const std::string& str_ord, optional opt_dim, int64_t ndim, optional opt_dtype) { +static void check_str_ord_valid(const std::string& str_ord, optional opt_dim, int64_t ndim) { TORCH_CHECK((str_ord == "nuc") || (str_ord == "fro"), "Invalid norm order: ", str_ord); - TORCH_CHECK(!opt_dtype.has_value(), "ord=\'", str_ord, "\' does not yet support the dtype argument"); bool dims_valid = (ndim == 2 && !opt_dim.has_value()) || (opt_dim.has_value() && opt_dim.value().size() == 2); TORCH_CHECK(dims_valid, "order \"", str_ord, "\" can only be used if either len(dim) == 2 or (self.dim() == 2 and dim is None)"); @@ -1425,9 +1701,9 @@ static Tensor _norm_min_max(Tensor& self, double ord, int64_t dim, bool keepdim) } // Performs matrix norm -static Tensor _linalg_norm_matrix(const Tensor &self, optional opt_ord, +static Tensor& _linalg_norm_matrix_out(Tensor& result, const Tensor &self, optional opt_ord, IntArrayRef dim, bool keepdim, optional opt_dtype) { - Tensor result; + Tensor result_; auto ord = opt_ord.value_or(2.0).toDouble(); TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA, "matrix norm only supports CPU AND CUDA device type, got: ", self.device().type()); @@ -1460,12 +1736,12 @@ static Tensor _linalg_norm_matrix(const Tensor &self, optional opt_ord, auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim()); auto permutation_reverse = create_reverse_permutation(permutation); - result = std::get<1>(self_.permute(permutation).svd()).abs(); - result = _norm_min_max(result, ord, result.dim() - 1, keepdim); + result_ = std::get<1>(self_.permute(permutation).svd()).abs(); + result_ = _norm_min_max(result_, ord, result_.dim() - 1, keepdim); if (keepdim) { - result.unsqueeze_(-1); - result = result.permute(permutation_reverse); + result_.unsqueeze_(-1); + result_ = result_.permute(permutation_reverse); } } else { // abs(p) == infinity and abs(p) == 1 will perform identical reductions, except @@ -1482,12 +1758,14 @@ static Tensor _linalg_norm_matrix(const Tensor &self, optional opt_ord, dim_[1]--; } if (std::abs(ord) == 1 || std::abs(ord) == INFINITY) { - result = self_.abs().sum(dim_[0], keepdim); - result = _norm_min_max(result, ord, dim_[1], keepdim); + result_ = self_.abs().sum(dim_[0], keepdim); + result_ = _norm_min_max(result_, ord, dim_[1], keepdim); } else { TORCH_CHECK(false, "Order ", ord, " not supported for matrix norm"); } } + resize_output(result, result_.sizes()); + result.copy_(result_); return result; } @@ -1495,7 +1773,9 @@ static Tensor _linalg_norm_matrix(const Tensor &self, optional opt_ord, // This function mostly serves as a wrapper for at::norm, but it overrides a few cases // for numpy compatibility. These cases are corrected within this wrapper, rather than // in at::norm itself, to avoid breaking backward compatibility. -static Tensor _linalg_norm_vector(const Tensor& self, optional opt_ord, std::vector dim, bool keepdim, optional opt_dtype) { +static Tensor& _linalg_norm_vector_out(Tensor& result, const Tensor& self, optional opt_ord, std::vector dim, bool keepdim, optional opt_dtype) { + Tensor result_; + bool case_was_overridden = false; if (opt_ord.has_value()) { TORCH_INTERNAL_ASSERT(dim.size() == 1); auto ord = opt_ord.value().toDouble(); @@ -1504,20 +1784,15 @@ static Tensor _linalg_norm_vector(const Tensor& self, optional opt_ord, // The ord = +/-infinity case is overridden because at::norm does not match numpy // when the input contains extreme values (like nan or +/-inf) or if the input // size is degenerate (like size(0), size(0, N), etc) + case_was_overridden = true; self_ = self_.abs(); - return _norm_min_max(self_, ord, dim[0], keepdim); + result_ = _norm_min_max(self_, ord, dim[0], keepdim); } else if ((self_.numel() == 0) && (ord < 0)) { // For negative orders with degenerate input sizes, at::norm's result does not - // match numpy. - Tensor result = self_.abs().pow(ord + 1).sum(dim[0], keepdim); - if (ord >= -1) { - // Result must be infinite in this case, and the simplest way to make that - // happen is to simply add infinity - result += INFINITY; - } else { - result = result.pow(1.0 / (ord + 1)); - } - return result; + // match numpy. It should always be infinity. + auto mask = make_dim_mask(dim[0], self_.dim()); + allocate_reduction_result(result, self_, mask, keepdim, result.scalar_type()); + return result.fill_(INFINITY); } } else { // If ord == None, need to check for unique dims because at::norm does not check it @@ -1527,11 +1802,16 @@ static Tensor _linalg_norm_vector(const Tensor& self, optional opt_ord, bool unique_dims = (std::unique(dim_.begin(), dim_.end())) == dim_.end(); TORCH_CHECK(unique_dims, "Expected dims to be different, got this instead: (", dim, ")"); } - if (opt_dtype.has_value()) { - return at::norm(self, opt_ord, dim, keepdim, opt_dtype.value()); - } else { - return at::norm(self, opt_ord, dim, keepdim); + if (!case_was_overridden) { + if (opt_dtype.has_value()) { + result_ = at::norm(self.to(opt_dtype.value()), opt_ord, dim, keepdim); + } else { + result_ = at::norm(self, opt_ord, dim, keepdim); + } } + resize_output(result, result_.sizes()); + result.copy_(result_); + return result; } static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional opt_num_ord, optional opt_str_ord, optional opt_dim, bool keepdim, optional opt_dtype) { @@ -1544,47 +1824,45 @@ static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional "dtype = ", dtype, ", out.dtype = ", result.scalar_type()); } int64_t ndim = self.dim(); - Tensor result_; if (opt_str_ord.has_value()) { // 'ord' is string auto str_ord = opt_str_ord.value(); - check_str_ord_valid(str_ord, opt_dim, ndim, opt_dtype); + check_str_ord_valid(str_ord, opt_dim, ndim); + Tensor self_ = opt_dtype.has_value() ? self.to(opt_dtype.value()) : self; if (str_ord == "fro") { - result_ = at::frobenius_norm(self, opt_dim.value_or(IntArrayRef({0, 1})), keepdim); + at::frobenius_norm_out(result, self_, opt_dim.value_or(IntArrayRef({0, 1})), keepdim); } else if (str_ord == "nuc") { if (opt_dim.has_value()) { - result_ = at::nuclear_norm(self, opt_dim.value(), keepdim); + at::nuclear_norm_out(result, self_, opt_dim.value(), keepdim); } else { - result_ = at::nuclear_norm(self, keepdim); + at::nuclear_norm_out(result, self_, keepdim); } } } else { // 'ord' is int or None std::vector dim_ = opt_dim.has_value() ? opt_dim.value().vec() : make_dim_list(ndim); if (!opt_num_ord.has_value() || dim_.size() == 1) { - result_ = _linalg_norm_vector(self, opt_num_ord, dim_, keepdim, opt_dtype); + _linalg_norm_vector_out(result, self, opt_num_ord, dim_, keepdim, opt_dtype); } else if (dim_.size() == 2) { - result_ = _linalg_norm_matrix(self, opt_num_ord.value(), dim_, keepdim, opt_dtype); + _linalg_norm_matrix_out(result, self, opt_num_ord.value(), dim_, keepdim, opt_dtype); } else { TORCH_CHECK(false, "'dim' must specify 1 or 2 dimensions when order is numerical and input is " "not 1-D or 2-D"); } } - resize_output(result, result_.sizes()); - result.copy_(result_); return result; } // Numerical or None norms Tensor linalg_norm(const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { - auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device()); + auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : toValueType(self.scalar_type())).device(self.device()); Tensor result = at::empty({0}, options); return at::native::linalg_norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype); } // Frobenius and nuclear norms Tensor linalg_norm(const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { - auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device()); + auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : toValueType(self.scalar_type())).device(self.device()); Tensor result = at::empty({0}, options); return at::native::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); } @@ -1599,6 +1877,262 @@ Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, opt return linalg_norm_out_impl(result, self, c10::nullopt, ord, opt_dim, keepdim, opt_dtype); } +Tensor _linalg_cond_exception_helper(const Tensor& self) { + // For batched input if at least one matrix in the batch is not invertible, + // we can't get the result for all other (possibly) invertible matrices in the batch without an explicit for loop. + // This should change when at::inverse works with silent errors + if (self.dim() > 2) { + TORCH_CHECK(false, + "One or more matrices in the batch was not invertible! " + "linalg_cond does not support yet this case."); + } + auto result_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend()-2); + TensorOptions options = self.options().dtype(toValueType(self.scalar_type())); + Tensor result = at::full(result_shape, INFINITY, options); + return result; +} + +// This function helps to dispatch norm computations depending on 'ord' of variant type +Tensor _linalg_cond_helper(const Tensor& self, c10::variant ord_variant) { + // Ignore errors if not invertible, result is INFINITY in this case + // Currently checking for error in at::inverse causes cross-device data movement + // For batched input if at least one matrix in the batch is not invertible, + // then the result for all other (possibly) invertible matrices will be infinity as well + // since there is currently no way to use at::inverse with silent errors + Tensor self_inverse; + try { + self_inverse = at::inverse(self); + } catch (const std::exception& e) { + if (strstr(e.what(), "singular")) { + return _linalg_cond_exception_helper(self); + } else { + TORCH_CHECK(false, "linalg_cond got an unexpected error:\n", e.what()); + } + } + std::array dim_arr = {-2, -1}; + optional dim = IntArrayRef(dim_arr); + + return c10::visit([&](auto&& ord) { + Tensor norm_self = at::linalg_norm(self, ord, dim); + Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); + Tensor result = norm_self * norm_inverse; + return result; + }, ord_variant); +} + +// Return zero for each matrix in the batch +Tensor _linalg_cond_empty_matrix(const Tensor& self, c10::ScalarType dtype) { + auto result_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend()-2); + TensorOptions options = self.options().dtype(toValueType(self.scalar_type())); + return at::zeros(result_shape, options); +} + +void _linalg_cond_check_ord(c10::variant ord_variant) { + if (ord_variant.index() == 0) { + Scalar* ord = c10::get_if(&ord_variant); + double abs_ord = std::abs(ord->toDouble()); + TORCH_CHECK(abs_ord == 2.0 || abs_ord == 1.0 || abs_ord == INFINITY, + "linalg_cond got an invalid norm type: ", ord->toDouble()); + } else if (ord_variant.index() == 1) { + std::string* ord = c10::get_if(&ord_variant); + TORCH_CHECK(*ord == "fro" || *ord == "nuc", + "linalg_cond got an invalid norm type: ", *ord); + } else { + TORCH_CHECK(false, + "linalg_cond: something went wrong while checking the norm type"); + } +} + +// Numerical or None norms +Tensor linalg_cond(const Tensor& self, optional opt_ord) { + TORCH_CHECK(self.dim() >= 2, "linalg_cond only supports matrices or batches of matrices, but got a tensor with ", + self.dim(), " dimensions."); + + // The default case is using 2-norm + Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2; + + c10::variant ord_variant = ord; + _linalg_cond_check_ord(ord_variant); + + // NumPy doesn't define the condition number for 0x0 matrices, we return 0.0 for such input + if (self.numel() == 0) { + auto real_dtype = toValueType(typeMetaToScalarType(self.dtype())); + return _linalg_cond_empty_matrix(self, real_dtype); + } + + // If ord == None or ord == ±2 + if (std::abs(ord.toDouble()) == 2.0) { + auto singular_values = std::get<1>(at::svd(self)); + // singular values are sorted in descending order + auto s_max = at::narrow(singular_values, /*dim=*/-1, /*start=*/0, /*length=*/1); + auto s_min = at::narrow(singular_values, /*dim=*/-1, /*start=*/-1, /*length=*/1); + Tensor result; + if (ord.toDouble() == -2.0) { + result = s_min / s_max; + } else { + result = s_max / s_min; + } + return result; + } + + // ord == ±1 ord == ±inf + // since at::inverse is used in the implementation, self has to be a tensor consisting of square matrices + // the same check as squareCheckInputs(self) but with a slightly more informative error message + TORCH_CHECK(self.size(-1) == self.size(-2), + "linalg_cond with ±1 or ±inf norm types only supports square matrices or batches of square matrices " + "but got ", self.size(-1), " by ", self.size(-2), " matrices"); + + return _linalg_cond_helper(self, ord_variant); +} + +Tensor& linalg_cond_out(Tensor& result, const Tensor& self, optional opt_ord) { + // If ord == None or ord == ±2 then SVD is used to compute the condition number + // the result is always real-valued, for other cases it is complex-valued for the complex-valued input. + ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype())); + Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2; + + TORCH_CHECK(result.scalar_type() == real_dtype, + "result dtype ", result.scalar_type(), " does not match the expected dtype ", real_dtype); + + Tensor result_tmp = at::linalg_cond(self, opt_ord); + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); + return result; +} + +// Frobenius or nuclear norms +Tensor linalg_cond(const Tensor& self, std::string ord) { + // the same checks as squareCheckInputs(self) but with a slightly more informative error message + TORCH_CHECK(self.dim() >= 2, "linalg_cond only supports matrices or batches of matrices, but got a tensor with ", + self.dim(), " dimensions."); + TORCH_CHECK(self.size(-1) == self.size(-2), + "linalg_cond with frobenius or nuclear norm types only supports square matrices or batches of square matrices " + "but got ", self.size(-1), " by ", self.size(-2), " matrices"); + + c10::variant ord_variant = ord; + _linalg_cond_check_ord(ord_variant); + + // NumPy doesn't define the condition number for 0x0 matrices, we return 0.0 for such input + if (self.numel() == 0) { + return _linalg_cond_empty_matrix(self, self.scalar_type()); + } + + return _linalg_cond_helper(self, ord_variant); +} + +// TODO: implement _out variant avoiding copy and using already allocated storage directly +Tensor& linalg_cond_out(Tensor& result, const Tensor& self, std::string ord) { + ScalarType real_type = toValueType(self.scalar_type()); + TORCH_CHECK(result.scalar_type() == real_type, + "result dtype ", result.scalar_type(), " does not match the expected dtype ", real_type); + + Tensor result_tmp = at::linalg_cond(self, ord); + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); + return result; +} + +Tensor linalg_tensorinv(const Tensor& self, int64_t ind) { + /* + The idea is to reduce the problem to 2D square matrix inversion. + Step 1. Calculate the shape of the result and the shape of the intermediate 2D matrix. + Step 2. Reshape `self` to 2D matrix. + Step 3. Invert the 2D matrix self.to_2D() + There is no quick way to find out whether the matrix is invertible, + so at this stage an error from at::inverse can be thrown. + Note that for CUDA this causes cross-device memory synchronization that can be slow. + Step 4. reshape the result. + */ + TORCH_CHECK(ind > 0, "Expected a strictly positive integer for 'ind', but got ", ind); + + // self[ind:] + std::vector shape_ind_end = self.sizes().slice(ind).vec(); + // self[:ind] + std::vector shape_start_ind = self.sizes().slice(0, ind).vec(); + + int64_t prod_ind_end = std::accumulate(shape_ind_end.cbegin(), shape_ind_end.cend(), int64_t{1}, std::multiplies()); + int64_t prod_start_ind = std::accumulate(shape_start_ind.cbegin(), shape_start_ind.cend(), int64_t{1}, std::multiplies()); + + // Check whether the self tensor can be reshaped to the 2D square matrix + TORCH_CHECK(prod_ind_end == prod_start_ind, + "Expected self to satisfy the requirement prod(self.shape[ind:]) == prod(self.shape[:ind]), but got ", + prod_ind_end, " != ", prod_start_ind); + + // Concatenate shape_ind_end and shape_start_ind to form the shape of the result + // self[ind:] + self[:ind] + shape_ind_end.insert(shape_ind_end.cend(), shape_start_ind.cbegin(), shape_start_ind.cend()); + + // If the reshaped self is not invertible catch this error + Tensor result; + try { + result = at::inverse(self.reshape({prod_ind_end, prod_ind_end})); + } catch (...) { + TORCH_CHECK(false, "Failed to invert the input tensor, because it is singular."); + } + + return result.reshape(shape_ind_end); +} + +// TODO: implement _out variant avoiding copy and using already allocated storage directly +Tensor& linalg_tensorinv_out(Tensor& result, const Tensor& self, int64_t ind) { + TORCH_CHECK(result.scalar_type() == self.scalar_type(), + "result dtype ", result.scalar_type(), " does not match self dtype ", self.scalar_type()); + + Tensor result_tmp = at::linalg_tensorinv(self, ind); + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); + return result; +} + +Tensor linalg_tensorsolve(const Tensor& self, const Tensor& other, optional dims) { + /* + The idea is to reduce the problem to 2D matrix solve. + Step 1. (optional) `self` is permuted with `dims` such that dimensions from `dims` are moved to the right. + For example, if we have 4D input with the shape (1, 2, 3, 4) and dims=(0, 2), + then the result of permutation would have the shape (2, 4, 1, 3). + Step 2. reshape `self` to 2D matrix. + Step 3. solve the matrix equation self.to_2D() @ result = other.to_1D() + Step 4. reshape the result. + */ + int64_t ndim = self.dim(); + Tensor self_ = self; + + // move dimensions of `self_` from `dims` to the end + if (dims.has_value()) { + DimVector dest_axes(dims.value().size()); + std::iota(dest_axes.begin(), dest_axes.end(), ndim - dest_axes.size()); + self_ = at::movedim(self_, dims.value(), dest_axes); + } + + // result_shape is self_.sizes[-(an-other.dim):] + std::vector result_shape = self_.sizes().slice(other.dim(), ndim - other.dim()).vec(); + + int64_t result_product = std::accumulate(result_shape.begin(), result_shape.end(), int64_t{1}, std::multiplies()); + int64_t other_product = std::accumulate(other.sizes().begin(), other.sizes().end(), int64_t{1}, std::multiplies()); + + // Check whether the self tensor can be reshaped to the 2D square matrix + TORCH_CHECK(result_product == other_product, + "Expected self to satisfy the requirement prod(self.shape[other.ndim:]) == prod(self.shape[:other.ndim]), but got ", + result_product, " != ", other_product); + + self_ = self_.reshape({result_product, result_product}); + + // 0th output of at::solve is the solution + // normally `other` would be flattened by at::solve expects 2D input + Tensor result = std::get<0>(at::solve(other.reshape({other.numel(), 1}), self_)); + return result.reshape(result_shape); +} + +Tensor& linalg_tensorsolve_out(Tensor& result, const Tensor& self, const Tensor& other, optional dims) { + TORCH_CHECK(result.scalar_type() == self.scalar_type(), + "result dtype ", result.scalar_type(), " does not match self dtype ", self.scalar_type()); + + Tensor result_tmp = at::linalg_tensorsolve(self, other, dims); + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); + return result; +} + static inline Tensor _chain_matmul_general(TensorList matrices, std::vector>& order, int64_t i, int64_t j) { if (i == j) return matrices[i]; @@ -1686,5 +2220,71 @@ Tensor chain_matmul(TensorList matrices) { } } +/* +Calculates the Kronecker product between two Tensors. +*/ +Tensor kron(const Tensor& self, const Tensor& other) { + /* + We can obtain the kron result using tensordot or einsum. The implementation below uses tensordot. + In einsum notation suppose we have `self` with dim 4 and `other` with dim 2 + the result of below tensordot is in einsum 0123, 45 -> 012345. + To obtain the correct kron we need to permute and reshape the array. + The permutation rule is the following: going from right to left + take axes in turn to form the permutation + with our example the correct permutation is 012435 and + the kron shape is (shape_self[0], shape_self[1], shape_self[3]*shape_other[0], + shape_self[4]*shape_other[1]) + */ + std::vector self_sizes = self.sizes().vec(); + std::vector other_sizes = other.sizes().vec(); + int64_t self_ndim = self.dim(); + int64_t other_ndim = other.dim(); + int64_t min_ndim = std::min(self_ndim, other_ndim); + int64_t ndim_diff = std::abs(self_ndim - other_ndim); + + std::vector a_axes(self_ndim); + std::vector b_axes(other_ndim); + std::iota(a_axes.begin(), a_axes.end(), 0); + std::iota(b_axes.begin(), b_axes.end(), 0 + self_ndim); + + bool is_a_larger = self_ndim >= other_ndim; + std::vector kron_permutation(self_ndim + other_ndim); + for (int64_t i = 0; i < ndim_diff; i++) { + kron_permutation[i] = is_a_larger ? a_axes[i] : b_axes[i]; + } + for (int64_t i = 0, j = 0; i < min_ndim; i++, j += 2) { + kron_permutation[self_ndim + other_ndim - 1 - j] = b_axes[other_ndim - 1 - i]; + kron_permutation[self_ndim + other_ndim - 1 - j - 1] = a_axes[self_ndim - 1 - i]; + } + + std::vector result_shape(std::max(self_ndim, other_ndim)); + for (int64_t i = 0; i < ndim_diff; i++) { + result_shape[i] = is_a_larger ? self_sizes[i] : other_sizes[i]; + } + for (int64_t i = 0; i < min_ndim; i++) { + result_shape[ndim_diff + i] = is_a_larger + ? self_sizes[ndim_diff + i] * other_sizes[i] + : other_sizes[ndim_diff + i] * self_sizes[i]; + } + + Tensor result = at::tensordot(self, other, {}, {}); + // Step 2: now permute result + result = result.permute(kron_permutation); + // Step 3: reshape + result = result.reshape(result_shape); + + return result; +} + +Tensor& kron_out(Tensor& result, const Tensor& self, const Tensor& other) { + TORCH_CHECK(result.scalar_type() == self.scalar_type(), + "result dtype ", result.scalar_type(), " does not match self dtype ", self.scalar_type()); + + Tensor result_tmp = at::kron(self, other); + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); + return result; +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/LinearAlgebra.h b/aten/src/ATen/native/LinearAlgebra.h new file mode 100644 index 0000000000000..4fc3aa0814a3c --- /dev/null +++ b/aten/src/ATen/native/LinearAlgebra.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { namespace native { + +using addr_fn = void (*)(TensorIterator &, Scalar beta, Scalar alpha); +DECLARE_DISPATCH(addr_fn, addr_stub); + +}} // namespace at::native diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index 5c07700f1e850..4322c4c792223 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -7,6 +7,7 @@ #include #include #include +#include namespace at { namespace native { @@ -76,7 +77,7 @@ static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, c " but each b matrix is ", self.size(-2), " by ", self.size(-1)); } -// Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig) +// Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig) static inline void squareCheckInputs(const Tensor& self) { TORCH_CHECK(self.dim() >= 2, "Tensor of matrices must have at least 2 dimensions. "); TORCH_CHECK(self.size(-1) == self.size(-2), @@ -97,7 +98,7 @@ static inline void batchCheckErrors(std::vector& infos, const char* nam } else if (info > 0) { if (strstr(name, "svd")) { AT_ERROR(name, ": the updating process of SBDSDC did not converge (error: ", info, ")"); - } else if (strstr(name, "symeig")) { + } else if (strstr(name, "symeig") || strstr(name, "syevd")) { AT_ERROR(name, ": For batch ", i, ": the algorithm failed to converge; ", info, " off-diagonal elements of an intermediate tridiagonal form did not converge to zero."); } else if (!allow_singular) { @@ -110,16 +111,16 @@ static inline void batchCheckErrors(std::vector& infos, const char* nam /* * This is an overloaded case of the previous function for a tensor of infos. */ -static inline void batchCheckErrors(const Tensor& infos, const char* name, bool allow_singular=false) { +static inline void batchCheckErrors(const Tensor& infos, const char* name, bool allow_singular=false, int info_per_batch=1) { auto batch_size = infos.numel(); auto infos_cpu = infos.to(at::kCPU); auto infos_data = infos_cpu.data_ptr(); for (int64_t i = 0; i < batch_size; i++) { auto info = infos_data[i]; if (info < 0) { - AT_ERROR(name, ": For batch ", i, ": Argument ", -info, " has illegal value"); + AT_ERROR(name, ": For batch ", i/info_per_batch, ": Argument ", -info, " has illegal value"); } else if (!allow_singular && info > 0) { - AT_ERROR(name, ": For batch ", i, ": U(", info, ",", info, ") is zero, singular U."); + AT_ERROR(name, ": For batch ", i/info_per_batch, ": U(", info, ",", info, ") is zero, singular U."); } } } @@ -134,7 +135,7 @@ static inline void singleCheckErrors(int64_t info, const char* name, bool allow_ } else if (info > 0) { if (strstr(name, "svd")) { AT_ERROR(name, ": the updating process of SBDSDC did not converge (error: ", info, ")"); - } else if (strstr(name, "symeig")) { + } else if (strstr(name, "eig")) { // this catches both "eig" and "symeig" AT_ERROR(name, ": the algorithm failed to converge; ", info, " off-diagonal elements of an intermediate tridiagonal form did not converge to zero."); } else if (!allow_singular) { @@ -191,16 +192,36 @@ static inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) { return self.permute(perm); } +// parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced) +static inline std::tuple _parse_qr_mode(std::string mode) { + bool compute_q; + bool reduced; + if (mode == "reduced") { + compute_q = true; + reduced = true; + } else if (mode == "complete") { + compute_q = true; + reduced = false; + } else if (mode == "r") { + compute_q = false; + reduced = true; // this is actually irrelevant in this mode + } else { + TORCH_CHECK(false, "qr received unrecognized mode '", mode, + "' but expected one of 'reduced' (default), 'r', or 'complete'"); + } + return std::make_tuple(compute_q, reduced); +} + // Function to compute sizes, strides and the extra columns for the Q matrix in the QR Decomposition static inline std::tuple, std::vector, - int64_t> _compute_geometry_for_Q(const Tensor& input, bool some) { + int64_t> _compute_geometry_for_Q(const Tensor& input, bool reduced) { int64_t m = input.size(-2), n = input.size(-1); int64_t n_columns_q; - // We need to compute the required size of Q based on the `some` option + // We need to compute the required size of Q based on the `reduced` option auto q_sizes = input.sizes().vec(); - if (!some && m > n) { + if (!reduced && m > n) { q_sizes[input.dim() - 1] = m; n_columns_q = m; } else { @@ -241,18 +262,21 @@ static inline std::tuple _create_U_S_VT(const Tensor& in U_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU)); } + // VT should be a column-major or a batch of column-major matrices sizes[input.dim() - 2] = n; sizes[input.dim() - 1] = n; - // VT should be a row-major or a batch of row-major matrices + strides = at::detail::defaultStrides(sizes); + strides[input.dim() - 1] = n; + strides[input.dim() - 2] = 1; Tensor VT_empty; if (!input.is_cuda()) { - VT_empty = at::empty(sizes, input.options()); + VT_empty = at::empty_strided(sizes, strides, input.options()); } else { // NB: VT_empty is an empty tensor created on the CPU intentionally, because magma_(d/s)gesdd // (which is the driver routine for the divide and conquer SVD operation) // takes in arrays on the CPU as input. This routine is a hybrid CPU-GPU routine that // moves the inputs between devices internally. - VT_empty = at::empty(sizes, input.options().device(at::kCPU)); + VT_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU)); } sizes.pop_back(); @@ -318,4 +342,28 @@ static inline std::vector create_reverse_permutation(std::vector 10 * mn) { + return 5 * mn * mn + 5 * mn; + } + return std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn); +} + +// This function checks whether the uplo argument input is valid +// Allowed strings are "u", "U", "l", "L" +static inline void checkUplo(const std::string& uplo) { + // To use std::toupper safely with plain chars (or signed chars), the argument should first be converted to unsigned char + char uplo_uppercase = static_cast(std::toupper(static_cast(uplo[0]))); + TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'), + "Expected UPLO argument to be 'L' or 'U', but got ", uplo); +} + }} // namespace at::native diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp index 8dc5432d8a8c1..f5738dd83d0c2 100644 --- a/aten/src/ATen/native/Loss.cpp +++ b/aten/src/ATen/native/Loss.cpp @@ -97,16 +97,15 @@ Tensor kl_div_backward_cpu(const Tensor& grad, const Tensor& input, const Tensor auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); auto grad_expand = grad.expand_as(input); if (!log_target) { + auto iter = TensorIteratorConfig() + .add_output(grad_input) + .add_input(target) + .add_input(grad_expand) + .build(); AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "kl_div_backward_cpu", [&]() { - at::CPU_tensor_apply3( - grad_input, - target, - grad_expand, - [] (scalar_t& grad_input_val, const scalar_t& target_val, const scalar_t& grad_val) { - if (target_val > 0) { - grad_input_val = -target_val * grad_val; - } - }); + cpu_serial_kernel(iter, [](scalar_t target_val, scalar_t grad_val) -> scalar_t{ + return target_val > 0 ? -target_val * grad_val : 0; + }); }); } else { @@ -295,24 +294,41 @@ Tensor soft_margin_loss( return output; } -Tensor smooth_l1_loss(const Tensor& input, const Tensor& target, const int64_t reduction) { +Tensor smooth_l1_loss(const Tensor& input, const Tensor& target, const int64_t reduction, double beta) { + TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta.") + if (beta == 0) { + return at::native::l1_loss(input, target, reduction); + } Tensor loss; auto iter = TensorIterator::binary_op(loss, input, target); - smooth_l1_stub(iter.device_type(), iter); + smooth_l1_stub(iter.device_type(), iter, beta); return apply_loss_reduction(iter.output(), reduction); } -Tensor& smooth_l1_loss_out(Tensor& result, const Tensor& input, const Tensor& target, int64_t reduction) { +Tensor& smooth_l1_loss_out(Tensor& result, const Tensor& input, const Tensor& target, int64_t reduction, double beta) { + TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta.") + if (beta == 0) { + return at::native::l1_loss_out(result, input, target, reduction); + } if (reduction != Reduction::None) { - result = at::smooth_l1_loss(input, target, reduction); + Tensor loss; + auto iter = TensorIterator::binary_op(loss, input, target); + smooth_l1_stub(iter.device_type(), iter, beta); + if (reduction == Reduction::Mean) { + at::mean_out(result, iter.output(), 0); + } else { + at::sum_out(result, iter.output(), 0); + } } else { auto iter = TensorIterator::binary_op(result, input, target); - smooth_l1_stub(iter.device_type(), iter); + smooth_l1_stub(iter.device_type(), iter, beta); } return result; } -Tensor& smooth_l1_loss_backward_out(Tensor& grad_input, const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction) { +Tensor& smooth_l1_loss_backward_out(Tensor& grad_input, const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double beta) { + if (beta <= 0) + return at::native::l1_loss_backward_out(grad_input, grad_output, input, target, reduction); auto norm = reduction == Reduction::Mean ? 1. / input.numel() : 1.; auto iter = at::TensorIteratorConfig() .add_output(grad_input) @@ -320,13 +336,15 @@ Tensor& smooth_l1_loss_backward_out(Tensor& grad_input, const Tensor& grad_outpu .add_input(target) .add_input(grad_output) .build(); - smooth_l1_backward_stub(iter.device_type(), iter, norm); + smooth_l1_backward_stub(iter.device_type(), iter, norm, beta); return grad_input; } -Tensor smooth_l1_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction) { +Tensor smooth_l1_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double beta) { + if (beta <= 0) + return at::native::l1_loss_backward(grad_output, input, target, reduction); auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - return at::smooth_l1_loss_backward_out(grad_input, grad_output, input, target, reduction); + return at::smooth_l1_loss_backward_out(grad_input, grad_output, input, target, reduction, beta); } Tensor mse_loss(const Tensor& input, const Tensor& target, int64_t reduction) { @@ -372,22 +390,24 @@ Tensor& mse_loss_backward_out(Tensor& grad_input, const Tensor& grad_output, } Tensor l1_loss(const Tensor& input, const Tensor& target, int64_t reduction) { - auto loss = input.sub(target).abs_(); - return apply_loss_reduction(loss, reduction); + const auto float_type = c10::toValueType(input.scalar_type()); + Tensor result = at::empty({0}, input.options().dtype(float_type)); + return at::l1_loss_out(result, input, target, reduction); } -Tensor& l1_loss_out(Tensor&result, const Tensor& input, const Tensor& target, int64_t reduction) { +Tensor& l1_loss_out(Tensor& result, const Tensor& input, const Tensor& target, int64_t reduction) { if (reduction != Reduction::None) { - auto loss = input.sub(target).abs_(); + auto diff = at::sub(input, target); + auto loss = diff.is_complex() ? diff.abs() : diff.abs_(); if (reduction == Reduction::Mean) { - at::mean_out(result, loss, 0); + return at::mean_out(result, loss, IntArrayRef{}); } else { - at::sum_out(result, loss, 0); + return at::sum_out(result, loss, IntArrayRef{}); } } else { - at::sub_out(result, input, target).abs_(); + auto diff = input.is_complex() ? at::sub(input, target) : at::sub_out(result, input, target); + return at::abs_out(result, diff); } - return result; } Tensor l1_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction) { @@ -398,8 +418,7 @@ Tensor l1_loss_backward(const Tensor& grad_output, const Tensor& input, const Te Tensor& l1_loss_backward_out(Tensor& grad_input, const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction) { auto norm = reduction == Reduction::Mean ? grad_output / input.numel() : grad_output; - at::sub_out(grad_input, input, target).sign_().mul_(norm); - return grad_input; + return at::sub_out(grad_input, input, target).sgn_().mul_(norm); } }} // namespace at::native diff --git a/aten/src/ATen/native/LossMulti.h b/aten/src/ATen/native/LossMulti.h new file mode 100644 index 0000000000000..4282c346702cf --- /dev/null +++ b/aten/src/ATen/native/LossMulti.h @@ -0,0 +1,72 @@ +#include +#include +#include + +#pragma once + +namespace at { namespace native { +namespace { + static void multilabel_margin_loss_shape_check( + int64_t& nframe, + int64_t& dim, + const int64_t& ndims, + TensorArg& target_arg, + const Tensor& input, + const Tensor& target) { + bool valid_inputs = (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0; + TORCH_CHECK( + valid_inputs, + "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ", + input.sizes()); + + if (ndims <= 1) { + nframe = 1; + dim = ndims == 0 ? 1 : input.size(0); + TORCH_CHECK( + valid_inputs && target.dim() <= 1 && target.numel() == dim, + "inconsistent size ", + target.sizes(), + " for ", + target_arg); + } else { + nframe = input.size(0); + dim = input.size(1); + TORCH_CHECK( + valid_inputs && target.dim() == 2 && target.size(0) == nframe && + target.size(1) == dim, + "inconsistent size ", + target.sizes(), + " for ", + target_arg); + } + } + + static void multi_margin_loss_shape_check( + int64_t& nframe, + int64_t& dim, + const int64_t& ndims, + TensorArg& target_arg, + const Tensor& input, + const Tensor& target) { + bool valid_inputs = (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0; + if (ndims <= 1) { + nframe = 1; + dim = ndims == 0 ? 1 : input.size(0); + } else { + nframe = input.size(0); + dim = input.size(1); + } + + TORCH_CHECK( + valid_inputs, + "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ", + input.sizes()); + TORCH_CHECK( + valid_inputs && target.dim() <= 1 && target.numel() == nframe, + "inconsistent target size, got: ", + target.sizes()); + } + + +} // anonymous namespace +}} // namespace at::native diff --git a/aten/src/ATen/native/LossMultiLabelMargin.cpp b/aten/src/ATen/native/LossMultiLabelMargin.cpp index 9582bf661a327..e30839afca93c 100644 --- a/aten/src/ATen/native/LossMultiLabelMargin.cpp +++ b/aten/src/ATen/native/LossMultiLabelMargin.cpp @@ -2,6 +2,7 @@ #include #include #include +#include namespace at { namespace native { @@ -39,6 +40,7 @@ inline scalar_t multilabel_margin_loss_forward_inner_sum_cpu( } } } + return sum; } @@ -100,34 +102,32 @@ static void multilabel_margin_loss_forward_out_cpu_template( Tensor& is_target, int64_t reduction) { auto target_arg = TensorArg(target, "target", 2); - - const auto ndims = input.dim(); - - TORCH_CHECK( - input.numel() > 0 && ndims <= 2, - "non-empty vector or matrix expected, got size: ", - input.sizes()); - int64_t nframe, dim; + const int64_t ndims = input.dim(); if (ndims <= 1) { nframe = 1; dim = ndims == 0 ? 1 : input.size(0); - TORCH_CHECK( - target.numel() > 0 && target.dim() <= 1 && target.numel() == dim, - "inconsistent size ", - target.sizes(), - " for ", - target_arg); - } else { + } + else { nframe = input.size(0); dim = input.size(1); - TORCH_CHECK( - target.numel() > 0 && target.dim() == 2 && target.size(0) == nframe && - target.size(1) == dim, - "inconsistent size ", - target.sizes(), - " for ", - target_arg); + } + multilabel_margin_loss_shape_check(nframe, dim, ndims, target_arg, input, target); + + // special case target.dim() <= 1: produce scalar output for scalar inputs + // even if reduction == Reduction::None + if (reduction != Reduction::None || target.dim() <= 1) { + output.resize_({}); + } else { + output.resize_({nframe}); + } + + is_target.resize_as_(target); + TORCH_CHECK(is_target.is_contiguous(), "is_target must be contiguous"); + is_target.zero_(); + + if (input.numel() == 0) { + return; } TORCH_CHECK( @@ -138,18 +138,6 @@ static void multilabel_margin_loss_forward_out_cpu_template( auto input_contiguous = input.contiguous(); auto target_contiguous = target.contiguous(); - is_target.resize_as_(target); - TORCH_CHECK(is_target.is_contiguous(), "is_target must be contiguous"); - is_target.zero_(); - - // special case target.dim() <= 1: produce scalar output for scalar inputs - // even if reduction == Reduction::None - if (reduction != Reduction::None || target.dim() <= 1) { - output.resize_({}); - } else { - output.resize_({nframe}); - } - AT_DISPATCH_FLOATING_TYPES( input.scalar_type(), "multilabel_margin_loss_forward_out_frame", [&] { multilabel_margin_loss_forward_out_frame( @@ -232,39 +220,22 @@ static void multilabel_margin_loss_backward_out_cpu_template( const Tensor& target, int64_t reduction, const Tensor& is_target) { + int64_t nframe, dim; CheckedFrom c = "multilabel_margin_loss_backward_cpu_template"; auto target_arg = TensorArg(target, "target", 3); auto is_target_arg = TensorArg(is_target, "is_target", 5); + const int64_t ndims = input.dim(); - const auto ndims = input.dim(); - - TORCH_CHECK( - input.numel() > 0 && ndims <= 2, - "non-empty vector or matrix expected, got size: ", - input.sizes()); - - int64_t nframe, dim; - if (ndims <= 1) { - nframe = 1; - dim = ndims == 0 ? 1 : input.size(0); - TORCH_CHECK( - target.numel() > 0 && target.dim() <= 1 && target.numel() == dim, - "inconsistent size ", - target.sizes(), - " for ", - target_arg); - } else { - nframe = input.size(0); - dim = input.size(1); - TORCH_CHECK( - target.numel() > 0 && target.dim() == 2 && target.size(0) == nframe && - target.size(1) == dim, - "inconsistent size ", - target.sizes(), - " for ", - target_arg); - } + multilabel_margin_loss_shape_check(nframe, dim, ndims, target_arg, input, target); checkSameSize(c, target_arg, is_target_arg); + + grad_input.resize_as_(input); + if (grad_input.numel() == 0) { + return; + } + + TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous"); + grad_input.zero_(); TORCH_CHECK( target.min().item() >= -1, target_arg, " is out of range"); @@ -275,10 +246,6 @@ static void multilabel_margin_loss_backward_out_cpu_template( auto target_contiguous = target.contiguous(); auto is_target_contiguous = is_target.contiguous(); - grad_input.resize_as_(input); - TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous"); - grad_input.zero_(); - AT_DISPATCH_FLOATING_TYPES( input.scalar_type(), "multilabel_margin_loss_backward_out_frame", [&] { multilabel_margin_loss_backward_out_frame( diff --git a/aten/src/ATen/native/LossMultiMargin.cpp b/aten/src/ATen/native/LossMultiMargin.cpp index 48446a98559d9..7bc7f1fcf72d7 100644 --- a/aten/src/ATen/native/LossMultiMargin.cpp +++ b/aten/src/ATen/native/LossMultiMargin.cpp @@ -1,6 +1,7 @@ #include #include #include +#include namespace at { namespace native { @@ -93,27 +94,13 @@ void multi_margin_loss_out_cpu_template( Scalar margin, const Tensor& weight, int64_t reduction) { + int64_t nframe, dim; const auto ndims = input.dim(); - TORCH_CHECK( - input.numel() > 0 && ndims <= 2, - "non-empty vector or matrix expected, got size: ", - input.sizes()); + auto target_arg = TensorArg(target, "target", 2); TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported"); - int64_t nframe, dim; - if (ndims <= 1) { - nframe = 1; - dim = ndims == 0 ? 1 : input.size(0); - } else { - nframe = input.size(0); - dim = input.size(1); - } - - TORCH_CHECK( - target.numel() > 0 && target.dim() <= 1 && target.numel() == nframe, - "inconsistent target size, got: ", - target.sizes()); + multi_margin_loss_shape_check(nframe, dim, ndims, target_arg, input, target); // produce a scalar output for 1d input if (reduction == Reduction::None && target.dim() > 0) { @@ -121,6 +108,9 @@ void multi_margin_loss_out_cpu_template( } else { output.resize_({}); } + if (input.numel() == 0) { + return; + } auto input_contiguous = input.contiguous(); auto target_contiguous = target.contiguous(); @@ -212,31 +202,20 @@ void multi_margin_loss_backward_out_cpu_template( Scalar margin, const Tensor& weight, int64_t reduction) { + int64_t nframe, dim; + auto target_arg = TensorArg(target, "target", 2); const auto ndims = input.dim(); - TORCH_CHECK( - input.numel() > 0 && ndims <= 2, - "non-empty vector or matrix expected, got size: ", - input.sizes()); - + TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported"); - int64_t nframe, dim; - if (ndims <= 1) { - nframe = 1; - dim = ndims == 0 ? 1 : input.size(0); - } else { - nframe = input.size(0); - dim = input.size(1); - } - - TORCH_CHECK( - target.numel() > 0 && target.dim() <= 1 && target.numel() == nframe, - "inconsistent target size, got: ", - target.sizes()); - + multi_margin_loss_shape_check(nframe, dim, ndims, target_arg, input, target); grad_input.resize_as_(input); TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous"); + if (input.numel() == 0) { + return; + } + auto input_contiguous = input.contiguous(); auto target_contiguous = target.contiguous(); auto weight_contiguous = weight.contiguous(); diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index c00ffec941199..6cd0464de9216 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -277,15 +277,20 @@ static inline float trigamma(float x) { * See note [3-Clause BSD License for the Cephes Math Library]. */ static inline double calc_digamma(double x) { + // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma static double PSI_10 = 2.25175258906672110764; if (x == 0) { - return INFINITY; + // As per C++ standard for gamma related functions and SciPy, + // If the argument is ±0, ±∞ is returned + return std::copysign(INFINITY, -x); } - int x_is_integer = x == floor(x); + bool x_is_integer = x == trunc(x); if (x < 0) { if (x_is_integer) { - return INFINITY; + // As per C++ standard for gamma related functions and SciPy, + // If the argument is a negative integer, NaN is returned + return NAN; } return calc_digamma(1 - x) - M_PI / tan(M_PI * x); } @@ -324,15 +329,20 @@ static inline double calc_digamma(double x) { * See note [3-Clause BSD License for the Cephes Math Library]. */ static inline float calc_digamma(float x) { + // See [C++ Standard Reference: Gamma Function] static float PSI_10 = 2.25175258906672110764f; if (x == 0) { - return INFINITY; + // As per C++ standard for gamma related functions and SciPy, + // If the argument is ±0, ±∞ is returned + return std::copysign(INFINITY, -x); } - int x_is_integer = x == floorf(x); + bool x_is_integer = x == truncf(x); if (x < 0) { if (x_is_integer) { - return INFINITY; + // As per C++ standard for gamma related functions and SciPy, + // If the argument is a negative integer, NaN is returned + return NAN; } // Avoid rounding errors for `tan`'s input. // Those make a big difference at extreme values. @@ -381,6 +391,726 @@ static inline float calc_polygamma(int64_t n, float x) { zeta(double(n + 1), x); } +// regularized lower incomplete gamma +// the regularized lower, upper incomplete gamma, as well as their +// helper functions follow SciPy's implementation + +/* References + * [igam1] "The Digital Library of Mathematical Functions", dlmf.nist.gov + * [igam2] Maddock et. al., "Incomplete Gamma Functions", + * https://www.boost.org/doc/libs/1_61_0/libs/math/doc/html/math_toolkit/sf_gamma/igamma.html + */ + +/* + * This implementation of the regularized incomplete gamma functions and + * their helper functions are derived from the implementation of SciPy's + * gammainc, Cephes's igam and igamc, and Boost's Lanczos approximations. + * See NOTICE for the licenses. + */ +template +static scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M, + const scalar_t denom[], int64_t N) { + // evaluating rational function, i.e., the ratio of two polynomials + // the coefficients for numerator are given by `num` while coeffs for + // denumerator are given by `denom` + + int64_t i, dir; + scalar_t y, num_ans, denom_ans; + scalar_t absx = std::fabs(x); + const scalar_t *p; + + if (absx > 1) { + /* Evaluate as a polynomial in 1/x. */ + dir = -1; + p = num + M; + y = 1 / x; + } + else { + dir = 1; + p = num; + y = x; + } + + /* Evaluate the numerator */ + num_ans = *p; + p += dir; + for (i = 1; i <= M; i++) { + num_ans = num_ans * y + *p; + p += dir; + } + /* Evaluate the denominator */ + if (absx > 1) { + p = denom + N; + } + else { + p = denom; + } + + denom_ans = *p; + p += dir; + for (i = 1; i <= N; i++) { + denom_ans = denom_ans * y + *p; + p += dir; + } + if (absx > 1) { + i = N - M; + return std::pow(x, i) * num_ans / denom_ans; + } + else { + return num_ans / denom_ans; + } +} + +// SciPy's lanczos implementation is taken from Boost +/* (C) Copyright John Maddock 2006. + * Use, modification and distribution are subject to the + * Boost Software License, Version 1.0. See + * https://www.boost.org/LICENSE_1_0.txt or see NOTICE. + */ +template +static scalar_t lanczos_sum_expg_scaled(scalar_t x) { + // lanczos approximation + static const scalar_t lanczos_sum_expg_scaled_num[13] = { + 0.006061842346248906525783753964555936883222, + 0.5098416655656676188125178644804694509993, + 19.51992788247617482847860966235652136208, + 449.9445569063168119446858607650988409623, + 6955.999602515376140356310115515198987526, + 75999.29304014542649875303443598909137092, + 601859.6171681098786670226533699352302507, + 3481712.15498064590882071018964774556468, + 14605578.08768506808414169982791359218571, + 43338889.32467613834773723740590533316085, + 86363131.28813859145546927288977868422342, + 103794043.1163445451906271053616070238554, + 56906521.91347156388090791033559122686859 + }; + static const scalar_t lanczos_sum_expg_scaled_denom[13] = { + 1., + 66., + 1925., + 32670., + 357423., + 2637558., + 13339535., + 45995730., + 105258076., + 150917976., + 120543840., + 39916800., + 0. + }; + return ratevl(x, lanczos_sum_expg_scaled_num, + sizeof(lanczos_sum_expg_scaled_num) / sizeof(lanczos_sum_expg_scaled_num[0]) - 1, + lanczos_sum_expg_scaled_denom, + sizeof(lanczos_sum_expg_scaled_denom) / sizeof(lanczos_sum_expg_scaled_denom[0]) - 1); +} + +template +static scalar_t _igam_helper_fac(scalar_t a, scalar_t x) { + // compute x^a * exp(-a) / gamma(a) + // corrected from (15) and (16) in [igam2] by replacing exp(x - a) with + // exp(a - x). + + scalar_t ax, fac, res, num, numfac; + static scalar_t MAXLOG = std::is_same::value ? + 7.09782712893383996843E2 : 88.72283905206835; + static scalar_t EXP1 = 2.718281828459045; + static scalar_t lanczos_g = 6.024680040776729583740234375; + + if (std::fabs(a - x) > 0.4 * std::fabs(a)) { + ax = a * std::log(x) - x - std::lgamma(a); + if (ax < -MAXLOG) { + return 0.0; + } + return std::exp(ax); + } + + fac = a + lanczos_g - 0.5; + res = std::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a); + + if ((a < 200) && (x < 200)) { + res *= std::exp(a - x) * std::pow(x / fac, a); + } + else { + num = x - a - lanczos_g + 0.5; + numfac = num / fac; + res *= std::exp(a * (std::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac); + } + return res; +} + +template +static scalar_t _igam_helper_series(scalar_t a, scalar_t x) { + // Compute igam using DLMF 8.11.4. [igam1] + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static int MAXITER = 2000; + + int i; + scalar_t ans, ax, c, r; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* power series */ + r = a; + c = 1.0; + ans = 1.0; + + for (i = 0; i < MAXITER; i++) { + r += 1.0; + c *= x / r; + ans += c; + if (c <= MACHEP * ans) { + break; + } + } + return (ans * ax / a); +} + +template +static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in + // _igam_helper_series but extra care is taken to avoid cancellation. + + int n; + scalar_t fac = 1; + scalar_t sum = 0; + scalar_t term, logx; + static scalar_t MAXITER = 2000; + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + + for (n = 1; n < MAXITER; n++) { + fac *= -x / n; + term = fac / (a + n); + sum += term; + if (std::fabs(term) <= MACHEP * std::fabs(sum)) { + break; + } + } + + logx = std::log(x); + term = -std::expm1(a * logx - std::lgamma(1+a)); + return term - std::exp(a * logx - std::lgamma(a)) * sum; +} + +template +static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) { + // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] + static const scalar_t d[25][25] = + {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, + 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, + 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, + 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, + 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, + -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, + -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, + -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, + -1.9752288294349443e-15}, + {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, + -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, + -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, + 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, + 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, + 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, + 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, + -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, + -4.13125571381061e-15}, + {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, + 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, + -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, + -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, + -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, + 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, + 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, + 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, + 8.8592218725911273e-15}, + {6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4, + 2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7, + 1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6, + -2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8, + -1.9111168485973654e-8, 2.3928620439808118e-12, 2.0620131815488798e-9, + -9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14, + -2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12, + 6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14, + 2.0453671226782849e-14}, + {-8.618882909167117e-4, 7.8403922172006663e-4, -2.9907248030319018e-4, + -1.4638452578843418e-6, 6.6414982154651222e-5, -3.9683650471794347e-5, + 1.1375726970678419e-5, 2.5074972262375328e-10, -1.6954149536558306e-6, + 8.9075075322053097e-7, -2.2929348340008049e-7, 2.956794137544049e-11, + 2.8865829742708784e-8, -1.4189739437803219e-8, 3.4463580499464897e-9, + -2.3024517174528067e-13, -3.9409233028046405e-10, 1.8602338968504502e-10, + -4.356323005056618e-11, 1.2786001016296231e-15, 4.6792750266579195e-12, + -2.1492464706134829e-12, 4.9088156148096522e-13, -6.3385914848915603e-18, + -5.0453320690800944e-14}, + {-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4, + -1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7, + -1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6, + -3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7, + 4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9, + 3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15, + 9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11, + -3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13, + -1.3249659916340829e-13}, + {5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4, + 7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5, + -1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6, + -2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13, + -8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8, + 8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10, + 2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11, + 1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18, + 3.6902800842763467e-13}, + {3.4436760689237767e-4, 5.1717909082605922e-5, -3.3493161081142236e-4, + 2.812695154763237e-4, -1.0976582244684731e-4, -1.2741009095484485e-7, + 2.7744451511563644e-5, -1.8263488805711333e-5, 5.7876949497350524e-6, + 4.9387589339362704e-10, -1.0595367014026043e-6, 6.1667143761104075e-7, + -1.7562973359060462e-7, -1.2974473287015439e-12, 2.695423606288966e-8, + -1.4578352908731271e-8, 3.887645959386175e-9, -3.8810022510194121e-17, + -5.3279941738772867e-10, 2.7437977643314845e-10, -6.9957960920705679e-11, + 2.5899863874868481e-17, 8.8566890996696381e-12, -4.403168815871311e-12, + 1.0865561947091654e-12}, + {-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4, + -6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4, + 4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5, + 6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11, + 3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8, + 6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9, + -1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10, + -1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18, + -3.3721464474854592e-12}, + {-5.9676129019274625e-4, -7.2048954160200106e-5, 6.7823088376673284e-4, + -6.4014752602627585e-4, 2.7750107634328704e-4, 1.8197008380465151e-7, + -8.4795071170685032e-5, 6.105192082501531e-5, -2.1073920183404862e-5, + -8.8585890141255994e-10, 4.5284535953805377e-6, -2.8427815022504408e-6, + 8.7082341778646412e-7, 3.6886101871706965e-12, -1.5344695190702061e-7, + 8.862466778790695e-8, -2.5184812301826817e-8, -1.0225912098215092e-14, + 3.8969470758154777e-9, -2.1267304792235635e-9, 5.7370135528051385e-10, + -1.887749850169741e-19, -8.0931538694657866e-11, 4.2382723283449199e-11, + -1.1002224534207726e-11}, + {1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3, + 9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4, + -1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5, + -3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11, + -2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7, + -1.7567877666323291e-13, 7.0145043163668257e-8, -4.040787734999483e-8, + 1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9, + 9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18, + 3.7647749553543836e-11}, + {1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3, + 2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7, + 3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4, + 2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5, + -5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6, + -6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14, + -3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9, + -1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10, + 1.3481607129399749e-10}, + {-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3, + -2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3, + 8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4, + 1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10, + 1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6, + 7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7, + -1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8, + -1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20, + -5.0423112718105824e-10}, + {-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3, + -9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6, + -2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4, + -6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4, + 4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5, + 6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13, + 3.9756029041993247e-7, -2.3956211978815887e-7, 7.1182883382145864e-8, + 8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9, + -1.9661464453856102e-9}, + {1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2, + 7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2, + -5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3, + -1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10, + -1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5, + -3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6, + 1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7, + 1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17, + 7.9795091026746235e-9}, + {3.0249124160905891e-2, 2.4817436002649977e-3, -4.9939134373457022e-2, + 5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6, + 1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3, + 3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3, + -4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4, + -7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12, + -5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6, + -5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7, + 3.3654425209171788e-8}, + {-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1, + -3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2, + 4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2, + 1.4182925050891573e-2, -5.6214161633747336e-3, -2.39598509186381e-9, + 1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4, + 1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5, + -2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6, + -2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16, + -1.4729737374018841e-7}, + {-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1, + -4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5, + -1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2, + -1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2, + 5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3, + 1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12, + 8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5, + 3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6, + -6.6812849447625594e-7}, + {7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968, + 1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1, + -4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1, + -1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8, + -2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3, + -1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3, + 3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5, + 5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14, + 3.1369106244517615e-6}, + {1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906, + 4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4, + 1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1, + 1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1, + -8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2, + -1.7813678334552311e-2, 6.3970330388900056e-3, 4.9430807090480523e-11, + -1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4, + 9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5, + 1.5227271505597605e-5}, + {-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1, + -1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1, + 5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816, + 2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7, + 3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1, + 8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2, + -7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3, + -1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11, + -7.6340103696869031e-5}, + {-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1, + -5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3, + -2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1, + -1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195, + 1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1, + 3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10, + 3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3, + -1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3, + -3.9479941246822517e-4}, + {7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2, + 1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2, + -8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1, + -3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7, + -6.2716159907747034, 5.1168999071852637, -2.0319658112299095, + -4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1, + 1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2, + 2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6, + 2.1250180774699461e-3}, + {2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2, + 7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2, + 3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2, + 1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1, + -2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373, + -7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7, + -8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1, + 1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2, + 1.5109265210467774e-2}, + {-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3, + -1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3, + 1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2, + 7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6, + 1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1, + -7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1, + -4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468, + -7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1, + 4.8683443692930507e-1}}; + + int k, n, sgn; + int maxpow = 0; + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + scalar_t lambda = x / a; + scalar_t sigma = (x - a) / a; + scalar_t eta, res, ck, ckterm, term, absterm; + scalar_t absoldterm = INFINITY; + scalar_t etapow[25] = {1}; + scalar_t sum = 0; + scalar_t afac = 1; + + if (igam) { + sgn = -1; + } + else { + sgn = 1; + } + + if (lambda > 1) { + eta = std::sqrt(-2 * (std::log1p(sigma) - sigma)); + } + else if (lambda < 1) { + eta = -std::sqrt(-2 * (std::log1p(sigma) - sigma)); + } + else { + eta = 0; + } + res = 0.5 * std::erfc(sgn * eta * std::sqrt(a / 2)); + + for (k = 0; k < 25; k++) { + ck = d[k][0]; + for (n = 1; n < 25; n++) { + if (n > maxpow) { + etapow[n] = eta * etapow[n-1]; + maxpow += 1; + } + ckterm = d[k][n]*etapow[n]; + ck += ckterm; + if (std::fabs(ckterm) < MACHEP * std::fabs(ck)) { + break; + } + } + term = ck * afac; + absterm = std::fabs(term); + if (absterm > absoldterm) { + break; + } + sum += term; + if (absterm < MACHEP * std::fabs(sum)) { + break; + } + absoldterm = absterm; + afac /= a; + } + res += sgn * std::exp(-0.5 * a * eta * eta) * sum / std::sqrt(2 * M_PIf * a); + + return res; +} + +template +static scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.9.2. [igam1] + int i; + scalar_t ans, ax, c, yc, r, t, y, z; + scalar_t pk, pkm1, pkm2, qk, qkm1, qkm2; + int MAXITER = 2000; + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static scalar_t BIG = std::is_same::value ? + 4.503599627370496e15 : 16777216.; + static scalar_t BIGINV = std::is_same::value ? + 2.22044604925031308085e-16 : 5.9604644775390625E-8; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* continued fraction */ + y = 1.0 - a; + z = x + y + 1.0; + c = 0.0; + pkm2 = 1.0; + qkm2 = x; + pkm1 = x + 1.0; + qkm1 = z * x; + ans = pkm1 / qkm1; + + for (i = 0; i < MAXITER; i++) { + c += 1.0; + y += 1.0; + z += 2.0; + yc = y * c; + pk = pkm1 * z - pkm2 * yc; + qk = qkm1 * z - qkm2 * yc; + if (qk != 0) { + r = pk / qk; + t = std::fabs((ans - r) / r); + ans = r; + } + else { + t = 1.0; + } + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + if (std::fabs(pk) > BIG) { + pkm2 *= BIGINV; + pkm1 *= BIGINV; + qkm2 *= BIGINV; + qkm1 *= BIGINV; + } + if (t <= MACHEP) { + break; + } + } + return ans * ax; +} + +template +static inline scalar_t calc_igammac(scalar_t a, scalar_t x) { + /* the calculation of the regularized upper incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.4 [igam1]) + * - if x > 1.1 and x < a, using the substraction from the regularized lower + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (5) + */ + scalar_t absxma_a; + + static scalar_t SMALL = 20.0; + static scalar_t LARGE = 200.0; + static scalar_t SMALLRATIO = 0.3; + static scalar_t LARGERATIO = 4.5; + + // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igammac(0, x) = 0.0 iff x > 0. + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 0.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 1.0; + } + else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 1.0; + } + else if (std::isinf(x)) { + return 0.0; + } + + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 0); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 0); + } + + if (x > 1.1) { + if (x < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_continued_fraction(a, x); + } + } + else if (x <= 0.5) { + if (-0.4 / std::log(x) < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } + else { + if (x * 1.1 < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } +} + +template +static inline scalar_t calc_igamma(scalar_t a, scalar_t x) { + /* the calculation of the regularized lower incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.3 [igam1]) + * - if x > 1 and x > a, using the substraction from the regularized upper + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (4) + */ + scalar_t absxma_a; + static scalar_t SMALL = 20.0; + static scalar_t LARGE = 200.0; + static scalar_t SMALLRATIO = 0.3; + static scalar_t LARGERATIO = 4.5; + + // boundary values following SciPy + // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igamma(0, x) = 1.0 iff x > 0. + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 1.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 0.0; // zero integration limit + } + else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 0.0; + } + else if (std::isinf(x)) { + return 1.0; + } + + /* Asymptotic regime where a ~ x. See [igam2] */ + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 1); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 1); + } + + if ((x > 1.0) && (x > a)) { + return 1.0 - calc_igammac(a, x); + } + + return _igam_helper_series(a, x); +} + +template <> +c10::BFloat16 calc_igamma(c10::BFloat16 a, c10::BFloat16 x) { + return calc_igamma(float(a), float(x)); +} + +template <> +c10::Half calc_igamma(c10::Half a, c10::Half x) { + return calc_igamma(float(a), float(x)); +} + +template <> +c10::BFloat16 calc_igammac(c10::BFloat16 a, c10::BFloat16 x) { + return calc_igammac(float(a), float(x)); +} + +template <> +c10::Half calc_igammac(c10::Half a, c10::Half x) { + return calc_igammac(float(a), float(x)); +} + inline c10::BFloat16 calc_erfinv(c10::BFloat16 a) { return calc_erfinv(float(a)); } template diff --git a/aten/src/ATen/native/MaxPooling.cpp b/aten/src/ATen/native/MaxPooling.cpp index a0298ea937de6..682af63cdafea 100644 --- a/aten/src/ATen/native/MaxPooling.cpp +++ b/aten/src/ATen/native/MaxPooling.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -97,7 +98,12 @@ Tensor max_pool1d( IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) { - if (self.requires_grad() || !self.device().is_cpu()) { + if (self.is_quantized()) { + return at::quantized_max_pool1d( + self, kernel_size, stride, padding, dilation, ceil_mode); + } + if ((self.requires_grad() && at::GradMode::is_enabled()) || + !self.device().is_cpu()) { // Needs indices for grad and with_indices defines CUDA dispatch return std::get<0>(at::max_pool1d_with_indices( self, kernel_size, stride, padding, dilation, ceil_mode)); diff --git a/aten/src/ATen/native/MetaTensor.cpp b/aten/src/ATen/native/MetaTensor.cpp index 2ae5fb0f9d59e..af293d7ebe219 100644 --- a/aten/src/ATen/native/MetaTensor.cpp +++ b/aten/src/ATen/native/MetaTensor.cpp @@ -7,34 +7,29 @@ namespace native { // Will be promoted to a public API later, but not now Tensor empty_meta( IntArrayRef size, - const TensorOptions& options_, - c10::optional optional_memory_format + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory, + c10::optional memory_format ) { - TORCH_CHECK( - !(options_.has_memory_format() && optional_memory_format.has_value()), - "Cannot set memory_format both in TensorOptions and explicit argument; please delete " - "the redundant setter."); - TensorOptions options = options_.merge_in(TensorOptions().memory_format(optional_memory_format)); - // TODO: deduplicate this logic with empty_cpu - auto dtype = options.dtype(); - auto device = options.device(); auto tensor = detail::make_tensor( // NB: We include the computed dispatch key, not because it will actually // participate in dispatch, but so that tests like is_sparse/is_cuda // give the correct result (a CUDA meta tensor "is cuda"). If we don't // like this, remove the computeDispatchKey line - DispatchKeySet{DispatchKey::Meta, computeDispatchKey(options)}, - dtype, + DispatchKeySet{DispatchKey::Meta, computeDispatchKey(dtype, layout, device)}, + scalarTypeToTypeMeta(dtype_or_default(dtype)), device ); if (size.size() != 1 || size[0] != 0) { tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); } - auto memory_format = options.memory_format_opt().value_or(MemoryFormat::Contiguous); - tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format); + auto memory_format_ = memory_format.value_or(MemoryFormat::Contiguous); + tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format_); return tensor; } diff --git a/aten/src/ATen/native/NaiveDilatedConvolution.cpp b/aten/src/ATen/native/NaiveDilatedConvolution.cpp index 459dd857727f2..e80b0c5463629 100644 --- a/aten/src/ATen/native/NaiveDilatedConvolution.cpp +++ b/aten/src/ATen/native/NaiveDilatedConvolution.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -181,10 +182,8 @@ void slow_conv_dilated_all_cpu_template( // Temporary buffer: Tensor columns = at::empty({0}, options); if (output.defined() || grad_weight.defined() || grad_input.defined()) { - int64_t m = std::accumulate( - kernel_size.begin(), kernel_size.end(), 1, std::multiplies()); - int64_t n = std::accumulate( - output_size.begin(), output_size.end(), 1, std::multiplies()); + const int64_t m = prod_intlist(kernel_size); + const int64_t n = prod_intlist(output_size); columns.resize_({nInputPlane * m, n}); } // Initialize diff --git a/aten/src/ATen/native/NamedTensor.cpp b/aten/src/ATen/native/NamedTensor.cpp index 7f717121cea4b..c18dc05028c3d 100644 --- a/aten/src/ATen/native/NamedTensor.cpp +++ b/aten/src/ATen/native/NamedTensor.cpp @@ -359,10 +359,10 @@ Tensor scatter_add(const Tensor& self, Dimname dim, const Tensor& index, const T Tensor& scatter_add_(Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) { reportNYIDimnameOverload("scatter_add"); } -std::tuple sort_out(Tensor& values, Tensor& indices, const Tensor& self, Dimname dim, bool keepdim) { +std::tuple sort_out(Tensor& values, Tensor& indices, const Tensor& self, Dimname dim, bool keepdim, bool stable) { reportNYIDimnameOverload("sort"); } -std::tuple sort(const Tensor& self, Dimname dim, bool keepdim) { +std::tuple sort(const Tensor& self, Dimname dim, bool keepdim, bool stable) { reportNYIDimnameOverload("sort"); } Tensor& squeeze_(Tensor& self, Dimname dim) { diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index 2ddcf5bd5c169..ea4a54c13196e 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -118,7 +118,7 @@ void batch_norm_cpu_inference_channels_last(Tensor& output, const Tensor& input, // output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c) // No need to use parallel_for as this function is supposed to be // memory-limited. - // Keep the loop struture simple to make sure compiler vectorization kicks in. + // Keep the loop structure simple to make sure compiler vectorization kicks in. if (n_channel != 1) { for (int64_t n = 0; n < n_batch; ++n) { for (int64_t i = 0; i < image_size; ++i) { @@ -415,6 +415,7 @@ std::tuple _batch_norm_impl_index( bool use_cudnn = false; use_cudnn = (input.is_cuda() + && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16 && (input.scalar_type() != at::kHalf || weight.scalar_type() == at::kFloat) && weight.defined() && bias.defined() diff --git a/aten/src/ATen/native/PixelShuffle.cpp b/aten/src/ATen/native/PixelShuffle.cpp index 14c126f77bdff..20214470ba283 100644 --- a/aten/src/ATen/native/PixelShuffle.cpp +++ b/aten/src/ATen/native/PixelShuffle.cpp @@ -4,31 +4,112 @@ #include #include +#include #include namespace at { namespace native { Tensor pixel_shuffle(const Tensor& self, int64_t upscale_factor) { - AT_ASSERTM(self.dim() == 4, - "pixel_shuffle expects 4D input, but got input with sizes ",self.sizes()); - int64_t b = self.size(0); - int64_t c = self.size(1); - int64_t h = self.size(2); - int64_t w = self.size(3); + TORCH_CHECK(self.dim() >= 3, + "pixel_shuffle expects input to have at least 3 dimensions, but got input with ", + self.dim(), " dimension(s)"); + TORCH_CHECK( + upscale_factor > 0, + "pixel_shuffle expects a positive upscale_factor, but got ", + upscale_factor); + // Format: (B1, ..., Bn), C, H, W + int64_t c = self.size(-3); + int64_t h = self.size(-2); + int64_t w = self.size(-1); + const auto NUM_NON_BATCH_DIMS = 3; + const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS; + int64_t upscale_factor_squared = upscale_factor * upscale_factor; - AT_ASSERTM(c % upscale_factor_squared == 0, - "pixel_shuffle expects input channel to be divisible by square of " - "upscale_factor, but got input with sizes ", self.sizes(), - ", upscale_factor=", upscale_factor, - ", and self.size(1)=", c, " is not divisible by ", upscale_factor_squared); + TORCH_CHECK(c % upscale_factor_squared == 0, + "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of " + "upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared); int64_t oc = c / upscale_factor_squared; int64_t oh = h * upscale_factor; int64_t ow = w * upscale_factor; - auto input_reshaped = self.reshape({b, oc, upscale_factor, upscale_factor, h, w}); - return input_reshaped.permute({0 /* b */, 1 /* oc */, 4 /* h */, 2 /* 1st upscale_factor */, 5 /* w */, 3 /* 2nd upscale_factor */}) - .reshape({b, oc, oh, ow}); + // First, reshape to split the channels dim from c into 3 separate dims: (oc, + // upscale_factor, upscale_factor). This allows shuffling to be done next by + // permuting dims. + std::vector added_dims_shape( + self.sizes().begin(), self_sizes_batch_end); + added_dims_shape.insert( + added_dims_shape.end(), {oc, upscale_factor, upscale_factor, h, w}); + const auto input_reshaped = self.reshape(added_dims_shape); + + // Next, shuffle by permuting the new upscale_factor dims alongside the height and width dims. + std::vector permutation(self.sizes().begin(), self_sizes_batch_end); + // std::iota is used to maintain the batch dims within the permutation. + // Since 2 dims were added, the correct batch dim offsets are now: + // -added_dims_shape.size(), ..., -7, -6. + std::iota(permutation.begin(), permutation.end(), -added_dims_shape.size()); + permutation.insert(permutation.end(), {-5 /* oc */, -2 /* h */, -4 /* 1st upscale_factor */, -1 /* w */, + -3 /* 2nd upscale_factor */}); + const auto input_permuted = input_reshaped.permute(permutation); + + // Finally, upscale by collapsing (h, upscale_factor) -> a single dim (oh) + // and (w, upscale_factor) -> a single dim (ow). + std::vector final_shape(self.sizes().begin(), self_sizes_batch_end); + final_shape.insert(final_shape.end(), {oc, oh, ow}); + return input_permuted.reshape(final_shape); +} + + +Tensor pixel_unshuffle(const Tensor& self, int64_t downscale_factor) { + TORCH_CHECK(self.dim() >= 3, + "pixel_unshuffle expects input to have at least 3 dimensions, but got input with ", + self.dim(), " dimension(s)"); + TORCH_CHECK( + downscale_factor > 0, + "pixel_unshuffle expects a positive downscale_factor, but got ", + downscale_factor); + // Format: (B1, ..., Bn), C, H, W + int64_t c = self.size(-3); + int64_t h = self.size(-2); + int64_t w = self.size(-1); + constexpr auto NUM_NON_BATCH_DIMS = 3; + const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS; + + TORCH_CHECK(h % downscale_factor == 0, + "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=", h, + " is not divisible by ", downscale_factor) + TORCH_CHECK(w % downscale_factor == 0, + "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=", w, + " is not divisible by ", downscale_factor) + int64_t downscale_factor_squared = downscale_factor * downscale_factor; + int64_t oc = c * downscale_factor_squared; + int64_t oh = h / downscale_factor; + int64_t ow = w / downscale_factor; + + // First, reshape to split height dim into (oh, downscale_factor) dims and + // width dim into (ow, downscale_factor) dims. This allows unshuffling to be + // done next by permuting dims. + std::vector added_dims_shape( + self.sizes().begin(), self_sizes_batch_end); + added_dims_shape.insert( + added_dims_shape.end(), {c, oh, downscale_factor, ow, downscale_factor}); + const auto input_reshaped = self.reshape(added_dims_shape); + + // Next, unshuffle by permuting the downscale_factor dims alongside the channel dim. + std::vector permutation(self.sizes().begin(), self_sizes_batch_end); + // std::iota is used to maintain the batch dims within the permutation. + // Since 2 dims were added, the correct batch dim offsets are now: + // -added_dims_shape.size(), ..., -7, -6. + std::iota(permutation.begin(), permutation.end(), -added_dims_shape.size()); + permutation.insert(permutation.end(), {-5 /* c */, -3 /* 1st downscale_factor */, -1 /*2nd downscale_factor */, + -4 /* oh */, -2 /* ow */}); + const auto input_permuted = input_reshaped.permute(permutation); + + // Finally, downscale by collapsing (c, downscale_factor, downscale_factor) -> a single dim (oc), + // resulting in height=oh and width=ow. + std::vector final_shape(self.sizes().begin(), self_sizes_batch_end); + final_shape.insert(final_shape.end(), {oc, oh, ow}); + return input_permuted.reshape(final_shape); } }} // namespace at::native diff --git a/aten/src/ATen/native/PointwiseOps.h b/aten/src/ATen/native/PointwiseOps.h index e81a894549058..98df21121ba30 100644 --- a/aten/src/ATen/native/PointwiseOps.h +++ b/aten/src/ATen/native/PointwiseOps.h @@ -11,10 +11,11 @@ struct TensorIterator; namespace native { using pointwise_fn = void (*)(TensorIterator&, Scalar scalar); +using pointwise_fn_beta = void (*)(TensorIterator&, Scalar scalar, double beta); DECLARE_DISPATCH(pointwise_fn, addcmul_stub); DECLARE_DISPATCH(pointwise_fn, addcdiv_stub); -DECLARE_DISPATCH(pointwise_fn, smooth_l1_backward_stub); +DECLARE_DISPATCH(pointwise_fn_beta, smooth_l1_backward_stub); DECLARE_DISPATCH(pointwise_fn, mse_backward_stub); } // namespace native diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h index b89554fd4d48b..60d7f8a419d25 100644 --- a/aten/src/ATen/native/Pool.h +++ b/aten/src/ATen/native/Pool.h @@ -28,11 +28,12 @@ static inline T pooling_output_shape_pad_lr( T outputSize = div_rtn( inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 + (ceil_mode ? stride - 1 : 0), stride) + 1; - if (pad_l) { + if (ceil_mode) { // ensure that the last pooling starts inside the image // needed to avoid problems in ceil mode - if ((outputSize - 1) * stride >= inputSize + pad_l) + if ((outputSize - 1) * stride >= inputSize + pad_l) { --outputSize; + } } return outputSize; } @@ -53,7 +54,7 @@ pool2d_shape_check( int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW, int64_t nInputPlane, int64_t inputHeight, int64_t inputWidth, - int64_t outputHeight, int64_t outputWidth) + int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format) { const int64_t ndim = input.ndimension(); const int64_t nOutputPlane = nInputPlane; @@ -67,11 +68,22 @@ pool2d_shape_check( TORCH_CHECK(dilationH > 0 && dilationW > 0, "dilation should be greater than zero, but got ", "dilationH: ", dilationH, " dilationW: ", dilationW); - - TORCH_CHECK(input.numel() > 0 && (ndim == 3 || ndim == 4), - "non-empty 3D or 4D input tensor expected but got ndim: ", ndim); + + bool valid_dims = input.size(1) != 0 && input.size(2) != 0; + if (memory_format == at::MemoryFormat::ChannelsLast){ + // Expect tensor in NHWC format and allow 0-dim only for N. + TORCH_CHECK((ndim == 4 && valid_dims && input.size(3) != 0), + "Expected 4D (batch mode) tensor expected for input with channels_last layout" + " with optional 0 dim batch size for input, but got: ", input.sizes()); + } else { + TORCH_CHECK((ndim == 3 && input.size(0) != 0 && valid_dims) || + (ndim == 4 && valid_dims && input.size(3) != 0), + "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got:", + input.sizes()); + } + TORCH_CHECK(kW/2 >= padW && kH/2 >= padH, - "pad should be smaller than half of kernel size, but got ", + "pad should be smaller than or equal to half of kernel size, but got ", "padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH); TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1, @@ -92,13 +104,13 @@ max_pool2d_backward_shape_check( int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW, int64_t nInputPlane, int64_t inputHeight, int64_t inputWidth, - int64_t outputHeight, int64_t outputWidth, + int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format, bool cuda=false) { pool2d_shape_check( input, kH, kW, dH, dW, padH, padW, dilationH, dilationW, - nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth); + nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format); const int64_t ndim = input.ndimension(); const int64_t nOutputPlane = nInputPlane; @@ -121,12 +133,14 @@ avg_pool2d_backward_shape_check( int kH, int kW, int dH, int dW, int padH, int padW, int64_t nInputPlane, int64_t inputHeight, int64_t inputWidth, - int64_t outputHeight, int64_t outputWidth) + int64_t outputHeight, int64_t outputWidth, + MemoryFormat memory_format) { pool2d_shape_check( input, kH, kW, dH, dW, padH, padW, 1, 1, - nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth); + nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, + memory_format); const int64_t ndim = input.ndimension(); const int64_t nOutputPlane = nInputPlane; @@ -171,7 +185,7 @@ pool3d_shape_check( } TORCH_CHECK(kT/2 >= pT && kW/2 >= pW && kH/2 >= pH, - "pad should be smaller than half of kernel size, but got " + "pad should be smaller than or equal to half of kernel size, but got " "kT: ", kT, " kW: ", kW, " kH: ", kH, " padT: ", pT, " padW: ", pW, " padH: ", pH); TORCH_CHECK(otime >= 1 && owidth >= 1 && oheight >= 1, diff --git a/aten/src/ATen/native/Pow.cpp b/aten/src/ATen/native/Pow.cpp index 414c8a6f63904..0c69488e0a916 100644 --- a/aten/src/ATen/native/Pow.cpp +++ b/aten/src/ATen/native/Pow.cpp @@ -4,6 +4,7 @@ #include #include #include +#include namespace at { namespace native { @@ -11,6 +12,10 @@ DEFINE_DISPATCH(pow_tensor_tensor_stub); DEFINE_DISPATCH(pow_tensor_scalar_stub); Tensor& pow_out(Tensor& result, const Tensor& base, const Tensor& exp) { + if (exp.dim() == 0 && exp.device().type() == DeviceType::CPU + && base.device().type() == DeviceType::CUDA) { + return native::pow_out(result, base, exp.item()); + } auto iter = TensorIterator::binary_op(result, base, exp); pow_tensor_tensor_stub(iter.device_type(), iter); return result; @@ -24,17 +29,17 @@ Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) { auto common_dtype = at::result_type(base, exp); TORCH_CHECK(at::can_cast(common_dtype, result.scalar_type()), - "result type ", common_dtype, "can't be cast to the desired output type ", + "result type ", common_dtype, " can't be cast to the desired output type ", result.scalar_type()); - if (exp.isComplex() && (exp.toComplexDouble() == 0.0) ) { - result.resize_as_(base).fill_(1); - } else if (exp.isComplex() && (exp.toComplexDouble() == 1.0) ) { - result.resize_as_(base).fill_(base); - } else if (!exp.isComplex() && (exp.toDouble() == 0.0)) { - result.resize_as_(base).fill_(1); - } else if (!exp.isComplex() && (exp.toDouble() == 1.0)) { - result.resize_as_(base).copy_(base); + if (exp.equal(0.0)) { + resize_output(result, base.sizes()); + result.fill_(1); + namedinference::propagate_names(result, base); + } else if (exp.equal(1.0)) { + resize_output(result, base.sizes()); + result.copy_(base); + namedinference::propagate_names(result, base); } else { auto iter = TensorIterator::unary_op(result, base.to(common_dtype)); pow_tensor_scalar_stub(iter.device_type(), iter, exp); @@ -43,8 +48,14 @@ Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) { } Tensor& pow_out(Tensor& result, Scalar base, const Tensor& exp) { - if (base.toDouble() == 1.0) { - result.resize_as_(exp).fill_(1); + if (base.isComplex() && base.toComplexDouble() == 1.0) { + resize_output(result, exp.sizes()); + result.fill_(1); + namedinference::propagate_names(result, exp); + } else if (!base.isComplex() && base.toDouble() == 1.0) { + resize_output(result, exp.sizes()); + result.fill_(1); + namedinference::propagate_names(result, exp); } else { native::pow_out(result, c10::scalar_to_tensor(base, exp.device()), exp); } @@ -77,6 +88,74 @@ Tensor pow(Scalar base, const Tensor& exp) { return native::pow_out(result, base, exp); } +Tensor& float_power_out(Tensor& result, const Tensor& base, const Tensor& exp) { + auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? + at::kComplexDouble : at::kDouble; + TORCH_CHECK(result.scalar_type() == dtype, + "the output given to float_power has dtype ", result.scalar_type(), + " but the operation's result requires dtype ", dtype); + + return at::pow_out(result, base.to(dtype), exp.to(dtype)); +} + +Tensor& float_power_out(Tensor& result, const Tensor& base, Scalar exp) { + auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble; + TORCH_CHECK(result.scalar_type() == dtype, + "the output given to float_power has dtype ", result.scalar_type(), + " but the operation's result requires dtype ", dtype); + + // Note: need the casts inside the ternary because conversion functions return e.g. c10::complex, + // which causes a complex scalar to always be returned. + exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble()); + return at::pow_out(result, base.to(dtype), exp); +} + +Tensor& float_power_out(Tensor& result, Scalar base, const Tensor& exp) { + auto dtype = (at::isComplexType(exp.scalar_type()) || base.isComplex()) ? at::kComplexDouble : at::kDouble; + TORCH_CHECK(result.scalar_type() == dtype, + "the output given to float_power has dtype ", result.scalar_type(), + " but the operation's result requires dtype ", dtype); + + base = (dtype == at::kComplexDouble) ? Scalar(base.toComplexDouble()) : Scalar(base.toDouble()); + return at::pow_out(result, base, exp.to(dtype)); +} + +Tensor float_power(const Tensor& base, Scalar exp) { + auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble; + exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble()); + return at::pow(base.to(dtype), exp); +} + +Tensor float_power(Scalar base, const Tensor& exp) { + auto dtype = (at::isComplexType(exp.scalar_type()) || base.isComplex()) ? at::kComplexDouble : at::kDouble; + base = (dtype == at::kComplexDouble) ? Scalar(base.toComplexDouble()) : Scalar(base.toDouble()); + return at::pow(base, exp.to(dtype)); +} + +Tensor float_power(const Tensor& base, const Tensor& exp) { + auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble; + return at::pow(base.to(dtype), exp.to(dtype)); +} + +Tensor& float_power_(Tensor& base, const Tensor& exp) { + auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble; + TORCH_CHECK(base.scalar_type() == dtype, + "the base given to float_power_ has dtype ", base.scalar_type(), + " but the operation's result requires dtype ", dtype); + + return base.pow_(exp.to(dtype)); +} + +Tensor& float_power_(Tensor& base, Scalar exp) { + auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble; + TORCH_CHECK(base.scalar_type() == dtype, + "the base given to float_power_ has dtype ", base.scalar_type(), + " but the operation's result requires dtype ", dtype); + + exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble()); + return base.pow_(exp); +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/README.md b/aten/src/ATen/native/README.md index f18114e732464..6e7664c1e1a55 100644 --- a/aten/src/ATen/native/README.md +++ b/aten/src/ATen/native/README.md @@ -279,13 +279,19 @@ to reuse the same function name in both cases. Available backend options can be found at https://github.com/pytorch/pytorch/blob/master/tools/codegen/gen.py#L970. -In addition to backends above, we also support keyword `Math` which is an alias -that maps to all backend and autograd backend keys. In other words, function registered to `Math` key -should be a plain mathematical composition of other `at::` functions and works for any backend. +In addition to the backends above, we also support the keywords: + - `DefaultBackend`: an alias that maps to all backends. Functions registered to + `DefaultBackend` should work for any backend inference. + - `Math`: an alias that maps to all backend and autograd backend keys. Functions + registered to `Math` key should be plain mathematical composition of other + `at::` functions and support training and inference for any backend. If you add `dispatch` section to any API that didn't have it before, you **have to** move the old implementation to `Math` field so that it's still available for other backends to use. +If you implemented a native function in C++ and want to find out which dispatch keyword +should be used in native_functions.yaml, please [follow steps in dispatch keywords](#choosing-the-right-dispatch-keyword) + This work is currently WIP and you can find the design proposal in https://github.com/pytorch/pytorch/issues/44680. @@ -329,17 +335,31 @@ set of reviewers. ### `use_c10_dispatcher` ``` -use_c10_dispatcher: 'with_codegenerated_unboxing_wrapper' use_c10_dispatcher: 'full' +use_c10_dispatcher: 'hacky_wrapper_for_legacy_signatures' ``` This will indicate the level of integration with the c10 dispatcher. -If setting this to 'full' works for your operator, please do. -This will enabled the full templated boxing and unboxing for your operator. -Some ops use features that aren't supported by those templates yet, -and enabling `use_c10_dispatcher: full` for those will result in a compiler error. -For those, use `use_c10_dispatcher: 'with_codegenerated_unboxing_wrapper'` instead, -or just omit the argument because 'with_codegenerated_unboxing_wrapper' is the default. +For any new ops, please set this to 'full'. This is also the default, +so you can just omit it. +This requires the operator function signature to be aligned with the +function schema in native_functions.yaml, i.e. +- out arguments have to be in the end of the argument list instead of in the beginning +- TensorOptions are taken as separate arguments +``` + const c10::optional& dtype, + const c10::optional& layout, + const c10::optional& device, + const c10::optional& pin_memory +``` + instead of one `TensorOptions` argument +- optional tensors are taken as `const c10::optional&` instead of `Tensor` +Some of our kernels are still written in a legacy way, not doing those things, +and need an adapter to work with the dispatcher calling convention. For those, we use +`use_c10_dispatcher: hacky_wrapper_for_legacy_signatures` to codegenerate a corresponding +adapter around them in the operator registration call. Over time, we will migrate all +those kernels to the new calling convention and hacky_wrapper will die. +Please don't use it for new operators. ### `manual_kernel_registration` @@ -384,6 +404,88 @@ will be automatically differentiated! This can be the case if the function imple only calls other operations which are themselves differentiable. In this case, you don't have to write an entry in `tools/autograd/derivatives.yaml`. +### Choosing the right dispatch keyword + +After writing a native function in C++, it's important to think about which dispatch keyword +to use in native_functions.yaml as it gives the dispatcher information about backend and autograd support +of the implementation. + +Here're steps to follow to decide the right dispatch keyword: + +1. Think about inference: does your kernel work for all backends? + + - No: you're likely providing different kernels for different backends, e.g. + backend-dependent logic is used in the implementation or it's implemented through DispatchStub. + DispatchStub only support a backend if you explicitly provide a kernel through `REGISTER_DISPATCH`. + Typically it only supports a few in-tree backends like CPU, CUDA, QuantizedCPU etc but not + out-of-tree backends like XLA. + Write a dispatch section, enumerate all supported backends and point them to the implementations. + ``` + dispatch: + CPU: kernel_cpu + CUDA: kernel_cuda + QuantizedCPU: kernel_quantized_cpu + ``` + + You're done. Now this op will be called in `CPU/CUDA/QuantizedCPU` backend inference! + + Note: to support training, you're required to write a formula in + derivatives.yaml since your backend implementations don't support autograd. + + - Yes: you're likely calling other `at::` ops in the implemetation. Go to step 2. + +2. Think about training: does your kernel support autograd? [check autograd support](#will-your-function-be-automatically-differentiable) + - Yes: in other words, you're providing a `Math` kernel which supports both inference and autograd. + To use autograd support for training, simply skip adding a dispatch section or write + ``` + dispatch: + Math: kernel + ``` + + You're done. This will allow this op to be correctly registered for both inference and training. + + - Yes, but you still want to provide a numerically stable gradient formula instead of using autograd, write + ``` + dispatch: + DefaultBackend: kernel + ``` + + You're done. This op will be called in inference for all backends. + + Note: to support training you're required to add a autograd formula, + or it'll error out in backward pass when calling with a Tensor has requires_grad=True. + + - No: ops in this category are mainly using `_out` boilerplate where its out version doesn't have a derivative + formula defined. For example: + ``` + Tensor& sign_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sign_stub); } + Tensor sign(const Tensor& self) { return unary_op_impl(self, at::sign_out); } + Tensor& sign_(Tensor& self) { return unary_op_impl_(self, at::sign_out); } + ``` + + `sign_out` uses DispatchStub so the supported backends are enumerated in its dispatch section. + For `sign` and `sign_`, write + ``` + dispatch: + DefaultBackend: kernel + ``` + + You're done. This op will be called in inference for all backends. + + Note: to support training you're required to add an autograd formula for `sign`, + or it'll error out in backward pass when calling with a Tensor has requires_grad=True. + + Note: current plan on record for ops using this boilerplate is to replace `at::` with `at::native` in + the implementations and add dispatch section with device keywords instead. + +3. TODO: AutogradCPUOrCUDA + +Note that in native_functions.yaml you can mix using backend keywords and alias keywords above for one op: + - direct registration to backend always has higher precendence than alias + - DO NOT provide multiple alias keywords to the same op: alias keywords have precedence `DefaultBackend > Math`, + e.g. adding both `Math` and `DefaultBackend` kernels for one op will completely ignore `Math` kernel for + both inference and training. Thus this will trigger an error when native_functions.yaml is parsed. + ### Will this function be exposed to python? What are the namespaces? We don't generate python bindings for all functions. There're certain patterns in function diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp index cc940386ada94..307437aca874b 100644 --- a/aten/src/ATen/native/RNN.cpp +++ b/aten/src/ATen/native/RNN.cpp @@ -68,6 +68,14 @@ using CellParamsSerializationType = std::tuple< struct CellParamsBase : torch::CustomClassHolder { virtual Tensor matmul_ih(const Tensor& input) const = 0; virtual Tensor matmul_hh(const Tensor& h) const = 0; + // by default doing nothing. CellParams will override this + // to define correct behavior for LSTMs with projections. + // This function is not pure virtual, because it's useful to + // provide this default implementation, so that all cell params + // that don't support projections work correctly (e.g. QuantizedCellParams variations) + virtual Tensor matmul_hr(const Tensor& h) const { + return h; + } virtual Tensor linear_ih(const Tensor& input_ih) const = 0; virtual Tensor linear_hh(const Tensor& input_hh) const = 0; @@ -79,19 +87,22 @@ struct CellParamsBase : torch::CustomClassHolder { // Pretty much all cells we support take the same set of arguments, but threading those // 4 arguments manually is really annoying. Their lifetime is externally managed, so we only -// pass this struct of references around. +// pass this struct of references around. LSTMs with projections have 5th argument w_hr, for all +// other models it's always going to be undefined. struct CellParams : public CellParamsBase { CellParams( const Tensor& _w_ih, const Tensor& _w_hh, const Tensor& _b_ih, - const Tensor& _b_hh) - : w_ih(_w_ih), w_hh(_w_hh), b_ih_(_b_ih), b_hh_(_b_hh){}; + const Tensor& _b_hh, + const Tensor& _w_hr) + : w_ih(_w_ih), w_hh(_w_hh), b_ih_(_b_ih), b_hh_(_b_hh), w_hr(_w_hr) {}; const Tensor& w_ih; const Tensor& w_hh; const Tensor& b_ih_; /* optional */ const Tensor& b_hh_; /* optional */ + const Tensor& w_hr; /* only defined for LSTMs with projections */ Tensor matmul_ih(const Tensor& input) const override { return at::matmul(input, w_ih.t()); @@ -99,6 +110,12 @@ struct CellParams : public CellParamsBase { Tensor matmul_hh(const Tensor& h) const override { return at::matmul(h, w_hh.t()); } + Tensor matmul_hr(const Tensor& h) const override { + if (w_hr.defined()) { + return at::matmul(h, w_hr.t()); + } + return h; + } Tensor linear_ih(const Tensor& input) const override { return at::linear(input, w_ih, b_ih_); } @@ -468,6 +485,9 @@ struct QRNNCellParamsWrapper { Tensor matmul_hh(const Tensor& h) const { return param_->matmul_hh(h); } + Tensor matmul_hr(const Tensor& h) const { + return param_->matmul_hr(h); + } Tensor linear_ih(const Tensor& input) const { return param_->linear_ih(input); } @@ -509,18 +529,32 @@ static std::vector unpair_vec(std::vector>&& vals) { } // Parses a flat list of parameter tensors into a list of CellParams -static std::vector gather_params(TensorList params, bool has_biases) { +static std::vector gather_params(TensorList params, bool has_biases, bool has_projections = false) { static at::Tensor undefined; std::vector result; if (has_biases) { - TORCH_CHECK(params.size() % 4 == 0, "got an incorrect number of RNN parameters"); - for (size_t i = 0; i < params.size(); i += 4) { - result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3]); + if (has_projections) { + TORCH_CHECK(params.size() % 5 == 0, "got an incorrect number of RNN parameters"); + for (size_t i = 0; i < params.size(); i += 5) { + result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3], params[i + 4]); + } + } else { + TORCH_CHECK(params.size() % 4 == 0, "got an incorrect number of RNN parameters"); + for (size_t i = 0; i < params.size(); i += 4) { + result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3], undefined); + } } } else { - TORCH_CHECK(params.size() % 2 == 0, "got an incorrect number of RNN parameters"); - for (size_t i = 0; i < params.size(); i += 2) { - result.emplace_back(params[i], params[i + 1], undefined, undefined); + if (has_projections) { + TORCH_CHECK(params.size() % 3 == 0, "got an incorrect number of RNN parameters"); + for (size_t i = 0; i < params.size(); i += 3) { + result.emplace_back(params[i], params[i + 1], undefined, undefined, params[i + 2]); + } + } else { + TORCH_CHECK(params.size() % 2 == 0, "got an incorrect number of RNN parameters"); + for (size_t i = 0; i < params.size(); i += 2) { + result.emplace_back(params[i], params[i + 1], undefined, undefined, undefined); + } } } return result; @@ -702,8 +736,10 @@ struct LSTMCell : Cell, cell_params> { auto hgates = params.matmul_hh(hx); auto result = at::_thnn_fused_lstm_cell( igates, hgates, cx, params.b_ih(), params.b_hh()); + // applying projections if w_hr is defined + auto hy = params.matmul_hr(std::get<0>(result)); // Slice off the workspace argument (it's needed only for AD). - return std::make_tuple(std::move(std::get<0>(result)), std::move(std::get<1>(result))); + return std::make_tuple(std::move(hy), std::move(std::get<1>(result))); } const auto gates = params.linear_hh(hx).add_( @@ -715,6 +751,7 @@ struct LSTMCell : Cell, cell_params> { auto outgate = chunked_gates[3].sigmoid_(); auto cy = (forgetgate * cx).add_(ingate * cellgate); auto hy = outgate * cy.tanh(); + hy = params.matmul_hr(hy); return std::make_tuple(std::move(hy), std::move(cy)); } @@ -1404,8 +1441,10 @@ std::tuple lstm( num_layers, dropout_p, train, bidirectional, batch_first); return std::make_tuple(std::move(output), std::move(hy), std::move(cy)); } - + // if cells are of different size, that means projections are used + bool has_projections = (hx[0].size(2) != hx[1].size(2)); if (use_miopen(_input, dropout_p)) { + TORCH_CHECK(!has_projections, "LSTM with projections is not supported with MIOpen"); Tensor output, hy, cy; lstm_miopen_stub(_input.device().type(), output, hy, cy, _input, hx, _params, has_biases, num_layers, dropout_p, train, bidirectional, batch_first); @@ -1413,7 +1452,7 @@ std::tuple lstm( } check_attributes(_input, _params, hx); auto input = batch_first ? _input.transpose(0, 1) : _input; - auto params = gather_params(_params, has_biases); + auto params = gather_params(_params, has_biases, has_projections); auto results = _lstm_impl( input, params, hx[0], hx[1], num_layers, dropout_p, train, bidirectional); if (batch_first) { @@ -1433,8 +1472,10 @@ std::tuple lstm( _params, has_biases, num_layers, dropout_p, train, bidirectional); return std::make_tuple(std::move(output), std::move(hy), std::move(cy)); } - + // if cells are of different size, that means projections are used + bool has_projections = (hx[0].size(2) != hx[1].size(2)); if (use_miopen(data, dropout_p)) { + TORCH_CHECK(!has_projections, "LSTM with projections is not supported with MIOpen"); Tensor output, hy, cy; lstm_packed_miopen_stub(data.device().type(), output, hy, cy, data, batch_sizes, hx, _params, has_biases, num_layers, dropout_p, train, bidirectional); @@ -1442,7 +1483,7 @@ std::tuple lstm( } PackedSequence input { data, batch_sizes }; - auto params = gather_params(_params, has_biases); + auto params = gather_params(_params, has_biases, has_projections); auto result = _lstm_impl( input, params, hx[0], hx[1], num_layers, dropout_p, train, bidirectional); auto & packed_output = std::get<0>(result); @@ -1455,7 +1496,8 @@ std::tuple lstm_cell( const Tensor& input, TensorList hx, const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) { TORCH_CHECK(hx.size() == 2, "lstm_cell expects two hidden states"); - return LSTMCell{}(input, std::make_tuple(hx[0], hx[1]), CellParams{w_ih, w_hh, b_ih, b_hh}); + static at::Tensor undefined; + return LSTMCell{}(input, std::make_tuple(hx[0], hx[1]), CellParams{w_ih, w_hh, b_ih, b_hh, undefined}); } std::tuple @@ -1552,19 +1594,22 @@ std::tuple _thnn_differentiable_gru_cell Tensor gru_cell( const Tensor& input, const Tensor& hx, const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) { - return GRUCell{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh}); + static at::Tensor undefined; + return GRUCell{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh, undefined}); } Tensor rnn_tanh_cell( const Tensor& input, const Tensor& hx, const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) { - return SimpleCell{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh}); + static at::Tensor undefined; + return SimpleCell{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh, undefined}); } Tensor rnn_relu_cell( const Tensor& input, const Tensor& hx, const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) { - return SimpleCell{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh}); + static at::Tensor undefined; + return SimpleCell{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh, undefined}); } // Quantized implementations @@ -1592,6 +1637,7 @@ std::tuple quantized_lstm_input( params.emplace_back(static_cast>(param)); } TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states"); + TORCH_CHECK(hx[0].size(2) == hx[1].size(2), "quantized LSTM with projections is not supported"); auto result_dtype = dtype.has_value() ? dtype.value() : at::kChar; auto input = batch_first ? _input.transpose(0, 1) : _input; TORCH_CHECK(has_biases, "quantized LSTM requires biases"); @@ -1685,6 +1731,7 @@ std::tuple quantized_lstm_data( params.emplace_back(static_cast>(param)); } TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states"); + TORCH_CHECK(hx[0].size(2) == hx[1].size(2), "quantized LSTM with projections is not supported"); auto result_dtype = dtype.has_value() ? dtype.value() : at::kChar; @@ -1784,7 +1831,8 @@ return_type name( \ return cell_type{}( \ input, prepare_hx_fn(hx), params); \ } - +// Set reduced range to be True for all RNN Cells by default. This flag is used only for FBGEMM kernels +// QNNPACK does not reduce range for activations #define DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(name, hx_type, cell_type, return_type, prepare_hx_fn) \ return_type name( \ const Tensor& input, \ @@ -1798,7 +1846,8 @@ return_type name( \ _packed_w_ih, \ _packed_w_hh, \ b_ih, \ - b_hh); \ + b_hh,\ + true); \ return cell_type{}( \ input, prepare_hx_fn(hx), params); \ } @@ -1859,7 +1908,7 @@ static auto cell_params_base_registry = return cell_params_deserializers[type](std::move(state)); }); -TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(aten, m) { +TORCH_LIBRARY_FRAGMENT(aten, m) { m.def( TORCH_SELECTIVE_SCHEMA("aten::quantized_lstm.input(Tensor input, Tensor[] hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)")); m.def( @@ -1878,7 +1927,7 @@ TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(aten, m) { TORCH_SELECTIVE_SCHEMA("aten::quantized_gru.data_legacy(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)")); } -TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(quantized, m) { +TORCH_LIBRARY_FRAGMENT(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params_dynamic(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh, bool reduce_range=False) -> __torch__.torch.classes.rnn.CellParamsBase")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params_fp16(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh) -> __torch__.torch.classes.rnn.CellParamsBase")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params(Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh) -> __torch__.torch.classes.rnn.CellParamsBase")); diff --git a/aten/src/ATen/native/ReduceAllOps.cpp b/aten/src/ATen/native/ReduceAllOps.cpp index 6db19b1a1be94..ae3991715acd3 100644 --- a/aten/src/ATen/native/ReduceAllOps.cpp +++ b/aten/src/ATen/native/ReduceAllOps.cpp @@ -11,7 +11,6 @@ DEFINE_DISPATCH(max_all_stub); DEFINE_DISPATCH(_aminmax_all_stub); Tensor min(const Tensor &self) { - TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors."); TORCH_CHECK(self.numel() > 0, "operation does not have an identity."); Tensor result = at::empty({}, self.options()); min_all_stub(self.device().type(), result, self.contiguous()); @@ -19,7 +18,6 @@ Tensor min(const Tensor &self) { } Tensor max(const Tensor &self) { - TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors."); TORCH_CHECK(self.numel() > 0, "operation does not have an identity."); Tensor result = at::empty({}, self.options()); max_all_stub(self.device().type(), result, self.contiguous()); @@ -27,7 +25,6 @@ Tensor max(const Tensor &self) { } std::tuple _aminmax_all(const Tensor &self) { - TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors."); TORCH_CHECK(self.numel() > 0, "operation does not have an identity."); Tensor min_result = at::empty({}, self.options()); Tensor max_result = at::empty({}, self.options()); diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index ff6d702293b97..fd27b3e7efe5a 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -89,6 +90,18 @@ Tensor cumsum(const Tensor& self, int64_t dim, c10::optional dtype) return result; } +Tensor& cumsum_(Tensor& self, int64_t dim, c10::optional dtype) { + TORCH_CHECK( + !dtype.has_value() || (self.scalar_type() == dtype.value()), + "provided dtype must match the dtype of self tensor in cumsum. Got ", + toString(self.scalar_type()), + " and ", + toString(dtype.value()), + "."); + + return at::_cumsum_out(self, self, dim); +} + Tensor& cumsum_out(Tensor& result, const Tensor& self, int64_t dim, c10::optional dtype) { // result type is favored over dtype; check that they match if provided (NumPy doesn't check) TORCH_CHECK( @@ -126,6 +139,18 @@ Tensor cumprod(const Tensor& self, int64_t dim, c10::optional dtype) return result; } +Tensor& cumprod_(Tensor& self, int64_t dim, c10::optional dtype) { + TORCH_CHECK( + !dtype.has_value() || (self.scalar_type() == dtype.value()), + "provided dtype must match the dtype of self tensor in cumprod. Got ", + toString(self.scalar_type()), + " and ", + toString(dtype.value()), + "."); + + return at::_cumprod_out(self, self, dim); +} + Tensor& cumprod_out(Tensor& result, const Tensor& self, int64_t dim, c10::optional dtype) { // result type is favored over dtype; check that they match if provided (NumPy doesn't check) TORCH_CHECK( @@ -473,6 +498,40 @@ static Tensor& prod_out_impl(Tensor& result, const Tensor& self, IntArrayRef dim return result; } +// NOTE: this could be implemented via diag and sum, but this has perf problems, +// see https://github.com/pytorch/pytorch/pull/47305, +Tensor trace_cpu(const Tensor& self) { + Tensor result; + ScalarType dtype = get_dtype(result, self, c10::nullopt, true); + result = at::empty({}, self.options().dtype(dtype)); + AT_DISPATCH_ALL_TYPES(self.scalar_type(), "trace", [&] { + using accscalar_t = at::acc_type; + accscalar_t sum = 0; + const auto* t_data = self.data_ptr(); + + int64_t t_stride_0, t_stride_1, t_diag_size; + + TORCH_CHECK(self.dim() == 2, "trace: expected a matrix, but got tensor with dim ", self.dim()); + + t_stride_0 = self.stride(0); + t_stride_1 = self.stride(1); + + t_diag_size = std::min(self.size(0), self.size(1)); + for (int64_t i = 0; i < t_diag_size; i++) { + sum += t_data[i * (t_stride_0 + t_stride_1)]; + } + + // all integer types get promoted to kLong + if (result.scalar_type() == at::kLong) { + *result.data_ptr() = sum; + } else { + *result.data_ptr() = sum; + } + }); + + return result; +} + Tensor prod(const Tensor& self, int64_t dim, bool keepdim, c10::optional dtype) { Tensor result; native::prod_out_impl(result, self, dim, keepdim, dtype); @@ -601,22 +660,23 @@ Tensor& logsumexp_out(Tensor& result, const Tensor& self, DimnameList dims, bool static Tensor& norm_out(Tensor &result, const Tensor &self, optional opt_p, IntArrayRef dim, bool keepdim, optional opt_dtype) { - auto p = opt_p.value_or(2.0); - TORCH_CHECK(!(p.toDouble() == 2 && self.is_complex()), "norm with p=2 not supported for complex tensors"); + auto p = opt_p.value_or(2.0).to(); TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA, "norm only supports CPU AND CUDA device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, "norm only supports strided layout, got: ", self.layout()); - ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type(); + ScalarType in_dtype = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type(); TORCH_CHECK( - at::isFloatingType(scalarType) || at::isComplexType(scalarType), - "Can only calculate the mean of floating types. Got ", - toString(scalarType), + at::isFloatingType(in_dtype) || at::isComplexType(in_dtype), + "Can only calculate the norm of floating point and complex dtypes. Got ", + toString(in_dtype), " instead."); - ScalarType dtype = get_dtype(result, self, opt_dtype, true); - auto iter = make_reduction("norm", result, self, dim, keepdim, dtype); + ScalarType out_dtype = result.defined() ? result.scalar_type() : (opt_dtype.has_value() ? opt_dtype.value() : toValueType(self.scalar_type())); + + auto iter = make_reduction("norm", result, self, dim, keepdim, in_dtype, out_dtype); + if (iter.numel() == 0) { result.zero_(); } else { @@ -680,6 +740,12 @@ Tensor norm(const Tensor& self, Scalar p) { return at::native::_norm(self, p); } +// Note [all, any : uint8 compatibility]: +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// For NumPy comptability, `all` and `any` return +// Tensor of dtype `bool`. However for compatibility reason, +// for `uint8`, they return Tensor of same dtype `uint8`. +// Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561 inline Tensor & _all(Tensor & result, TensorIterator & iter) { if (iter.numel() == 0) { result.fill_(1); @@ -695,17 +761,41 @@ Tensor all(const Tensor& self) { "all only supports CPU AND CUDA device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, "all only supports strided layout, got: ", self.layout()); - TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool, - "all only supports torch.uint8 and torch.bool dtypes"); - Tensor result = at::empty({0}, self.options()); - auto iter = make_reduction( - "all", result, self, {}, false, self.scalar_type()); + // Refer [all, any : uint8 compatibility] + Tensor result; + ScalarType out_dtype; + if (self.scalar_type() == ScalarType::Byte){ + result = at::empty({0}, self.options()); + out_dtype = self.scalar_type(); + } else { + result = at::empty({0}, self.options().dtype(kBool)); + out_dtype = ScalarType::Bool; + } + + if (self.is_cuda()) { + // As CUDA supports dynamic type casting, we use this overload of + // `make_reduction`, which doesn't cast input to the result type i.e. kBool., + // otherwise we use the overload below which casts the input to kBool (which is + // an extra operation). + auto iter = make_reduction( + "all", result, self, {}, false, self.scalar_type(), out_dtype); + return _all(result, iter); + } + auto iter = + make_reduction("all", result, self, {}, false, /*out_dtype=*/out_dtype); return _all(result, iter); } Tensor all(const Tensor& self, int64_t dim, bool keepdim) { - Tensor result = at::empty({0}, self.options()); + // Refer [all, any : uint8 compatibility] + Tensor result; + if (self.scalar_type() == ScalarType::Byte){ + result = at::empty({0}, self.options()); + } else { + result = at::empty({0}, self.options().dtype(kBool)); + } + return at::native::all_out(result, self, dim, keepdim); } @@ -714,14 +804,26 @@ Tensor &all_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) { "all only supports CPU AND CUDA device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, "all only supports strided layout, got: ", self.layout()); - TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool, - "all only supports torch.uint8 and torch.bool dtypes"); + // Refer [all, any : uint8 compatibility] + TORCH_CHECK(result.scalar_type() == ScalarType::Bool || result.scalar_type() == ScalarType::Byte, + "all only supports bool tensor for result, got: ", result.scalar_type()); + + auto out_dtype = result.scalar_type(); dim = maybe_wrap_dim(dim, self.dim()); if (_dimreduce_return_trivial(result, self, 1, dim, keepdim)) { return result; } else { - auto iter = make_reduction( - "all", result, self, dim, keepdim, self.scalar_type()); + if (self.is_cuda()) { + // As CUDA supports dynamic type casting, we use this overload of + // `make_reduction`, which doesn't cast input to the result type i.e. kBool., + // otherwise we use the overload below which casts the input to kBool (which is + // an extra operation). + auto iter = make_reduction( + "all", result, self, dim, keepdim, self.scalar_type(), out_dtype); + return _all(result, iter); + } + auto iter = + make_reduction("all", result, self, dim, keepdim, /*out_dtype=*/out_dtype); return _all(result, iter); } } @@ -741,17 +843,41 @@ Tensor any(const Tensor& self) { "any only supports CPU AND CUDA device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided || self.layout() == Layout::Sparse, "any only supports strided AND sparse layout, got: ", self.layout()); - TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool, - "all only supports torch.uint8 and torch.bool dtypes"); + + // Refer [all, any : uint8 compatibility] + Tensor result; + ScalarType out_dtype; + if (self.scalar_type() == ScalarType::Byte){ + result = at::empty({0}, self.options()); + out_dtype = self.scalar_type(); + } else { + result = at::empty({0}, self.options().dtype(kBool)); + out_dtype = ScalarType::Bool; + } - Tensor result = at::empty({0}, self.options()); - auto iter = make_reduction( - "any", result, self, {}, false, self.scalar_type()); + if (self.is_cuda()) { + // As CUDA supports dynamic type casting, we use this overload of + // `make_reduction`, which doesn't cast input to the result type i.e. kBool., + // otherwise we use the overload below which casts the input to kBool (which is + // an extra operation). + auto iter = make_reduction( + "any", result, self, {}, false, self.scalar_type(), out_dtype); + return _any(result, iter); + } + auto iter = + make_reduction("any", result, self, {}, false, /*out_dtype=*/out_dtype); return _any(result, iter); } Tensor any(const Tensor& self, int64_t dim, bool keepdim) { - Tensor result = at::empty({0}, self.options()); + // Refer [all, any : uint8 compatibility] + Tensor result; + if (self.scalar_type() == ScalarType::Byte){ + result = at::empty({0}, self.options()); + } else { + result = at::empty({0}, self.options().dtype(kBool)); + } + return at::native::any_out(result, self, dim, keepdim); } @@ -760,14 +886,26 @@ Tensor &any_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) { "any only supports CPU AND CUDA device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, "any only supports strided layout, got: ", self.layout()); - TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool, - "all only supports torch.uint8 and torch.bool dtypes"); + // Refer [all, any : uint8 compatibility] + TORCH_CHECK(result.scalar_type() == ScalarType::Bool || result.scalar_type() == ScalarType::Byte, + "any only supports bool tensor for result, got: ", result.scalar_type()); + + auto out_dtype = result.scalar_type(); dim = maybe_wrap_dim(dim, self.dim()); if (_dimreduce_return_trivial(result, self, 0, dim, keepdim)) { return result; } else { - auto iter = make_reduction( - "any", result, self, dim, keepdim, self.scalar_type()); + if (self.is_cuda()) { + // As CUDA supports dynamic type casting, we use this overload of + // `make_reduction`, which doesn't cast input to the result type i.e. kBool., + // otherwise we use the overload below which casts the input to kBool (which is + // an extra operation). + auto iter = make_reduction( + "any", result, self, dim, keepdim, self.scalar_type(), out_dtype); + return _any(result, iter); + } + auto iter = + make_reduction("any", result, self, dim, keepdim, /*out_dtype=*/out_dtype); return _any(result, iter); } } @@ -1008,8 +1146,8 @@ Tensor var(const Tensor& self, bool unbiased) { return trivial_return.value(); } - // NOTE: CPU performance significantly regressed when attempting to port to ATen, - // so this dispatches differently based on device type. + // NOTE: CPU performance significantly regressed when attempting to port to ATen, + // so this dispatches differently based on device type. // See https://github.com/pytorch/pytorch/pull/43858. if (self.device().type() == kCPU) { return at::_var(self, unbiased); @@ -1040,8 +1178,8 @@ Tensor std(const Tensor& self, bool unbiased) { return trivial_return.value(); } - // NOTE: CPU performance significantly regressed when attempting to port to ATen, - // so this dispatches differently based on device type. + // NOTE: CPU performance significantly regressed when attempting to port to ATen, + // so this dispatches differently based on device type. // See https://github.com/pytorch/pytorch/pull/43858. if (self.device().type() == kCPU) { return at::_std(self, unbiased); @@ -1121,12 +1259,18 @@ Tensor& logcumsumexp_out(Tensor& result, const Tensor& self, Dimname dim) { Tensor cumsum(const Tensor& self, Dimname dim, c10::optional dtype) { return at::cumsum(self, dimname_to_position(self, dim), dtype); } +Tensor& cumsum_(Tensor& self, Dimname dim, c10::optional dtype) { + return native::cumsum_(self, dimname_to_position(self, dim), dtype); +} Tensor& cumsum_out(Tensor& result, const Tensor& self, Dimname dim, c10::optional dtype) { return at::cumsum_out(result, self, dimname_to_position(self, dim), dtype); } Tensor cumprod(const Tensor& self, Dimname dim, c10::optional dtype) { return at::cumprod(self, dimname_to_position(self, dim), dtype); } +Tensor& cumprod_(Tensor& self, Dimname dim, c10::optional dtype) { + return native::cumprod_(self, dimname_to_position(self, dim), dtype); +} Tensor& cumprod_out(Tensor& result, const Tensor& self, Dimname dim, c10::optional dtype) { return at::cumprod_out(result, self, dimname_to_position(self, dim), dtype); } diff --git a/aten/src/ATen/native/ReflectionPad.cpp b/aten/src/ATen/native/ReflectionPad.cpp index 0ad13bd340a1e..617f1f03a0d82 100644 --- a/aten/src/ATen/native/ReflectionPad.cpp +++ b/aten/src/ATen/native/ReflectionPad.cpp @@ -346,23 +346,43 @@ void reflection_pad2d_out_template( if (input.ndimension() == 3) { /* resize output */ output.resize_({nplane, output_h, output_w}); - AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "reflection_pad2d", [&] { - reflection_pad2d_out_frame( - input.data_ptr(), output.data_ptr(), - nplane, - input_w, input_h, output_w, output_h, - pad_l, pad_t); - }); + if (input.is_quantized()) { + AT_DISPATCH_QINT_TYPES(input.scalar_type(), "qreflection_pad2d", [&] { + reflection_pad2d_out_frame( + input.data_ptr(), output.data_ptr(), + nplane, + input_w, input_h, output_w, output_h, + pad_l, pad_t); + }); + } else { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "reflection_pad2d", [&] { + reflection_pad2d_out_frame( + input.data_ptr(), output.data_ptr(), + nplane, + input_w, input_h, output_w, output_h, + pad_l, pad_t); + }); + } } else { /* resize output */ output.resize_({nbatch, nplane, output_h, output_w}); - AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "reflection_pad2d", [&] { - reflection_pad2d_out_loop( - input.data_ptr(), output.data_ptr(), - nbatch, nplane, - input_w, input_h, output_w, output_h, - pad_l, pad_t); - }); + if (input.is_quantized()) { + AT_DISPATCH_QINT_TYPES(input.scalar_type(), "qreflection_pad2d", [&] { + reflection_pad2d_out_loop( + input.data_ptr(), output.data_ptr(), + nbatch, nplane, + input_w, input_h, output_w, output_h, + pad_l, pad_t); + }); + } else { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "reflection_pad2d", [&] { + reflection_pad2d_out_loop( + input.data_ptr(), output.data_ptr(), + nbatch, nplane, + input_w, input_h, output_w, output_h, + pad_l, pad_t); + }); + } } } @@ -547,7 +567,18 @@ Tensor& reflection_pad2d_out_cpu( } Tensor reflection_pad2d_cpu(const Tensor& input, IntArrayRef padding) { - auto output = at::empty({0}, input.options()); + Tensor output; + if (input.is_quantized()) { + if (input.qscheme() == kPerTensorAffine) { + output = at::_empty_affine_quantized({0}, input.options(), + input.q_scale(), + input.q_zero_point()); + } else { + TORCH_CHECK(false, "Only per tensor quantization is supported"); + } + } else { + output = at::empty({0}, input.options()); + } reflection_pad2d_out_template(output, input, padding); return output; } diff --git a/aten/src/ATen/native/ReplicationPadding.cpp b/aten/src/ATen/native/ReplicationPadding.cpp index 4a9c8cd7ad1af..a4eb075a5c3c3 100644 --- a/aten/src/ATen/native/ReplicationPadding.cpp +++ b/aten/src/ATen/native/ReplicationPadding.cpp @@ -71,9 +71,11 @@ void replication_pad1d_out_cpu_template( int pad_l = paddingSize[0]; int pad_r = paddingSize[1]; - TORCH_CHECK(input_.numel() > 0 - && (input_.ndimension() == 2 || input_.ndimension() == 3), - "non-empty 2D or 3D (batch mode) tensor expected for input"); + // allow empty batch size but not other dimensions. + TORCH_CHECK((input_.dim() == 2 && input_.size(0) != 0 && input_.size(1) != 0) || + (input_.dim() == 3 && input_.size(1) != 0 && input_.size(2) != 0), + "Expected 2D or 3D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ", + input_.sizes()); if (input_.ndimension() == 3) { @@ -91,7 +93,6 @@ void replication_pad1d_out_cpu_template( "input (W: ", iwidth, ") is too small." " Calculated output W: ", owidth); - /* get contiguous input */ auto input = input_.contiguous(); @@ -216,6 +217,9 @@ Tensor& replication_pad1d_backward_out_cpu_template( /* get contiguous gradOutput */ auto gradOutput = gradOutput_.contiguous(); gradInput.resize_as_(input); + if (gradInput.numel() == 0) { + return gradInput; + } gradInput.zero_(); /* backprop */ @@ -339,8 +343,13 @@ void replication_pad2d_out_cpu_template(Tensor& output, int dimslices = 0; int64_t nbatch = 1; - TORCH_CHECK(input_.numel() > 0 && (input_.dim() == 3 || input_.dim() == 4), - "3D or 4D (batch mode) tensor expected for input, but got: ", input_); + // allow 0 dim batch size and nothing else. + bool valid_dims = input_.size(1) != 0 && input_.size(2) != 0; + TORCH_CHECK( + (input_.dim() == 3 && input_.size(0) != 0 && valid_dims) || + (input_.dim() == 4 && valid_dims && input_.size(3) != 0), + "Expected 3D or 4D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ", + input_.sizes()); if (input_.dim() == 4) { @@ -510,6 +519,10 @@ Tensor& replication_pad2d_backward_out_cpu_template( /* resize */ gradInput.resize_as_(input); + if (gradInput.numel() == 0) { + return gradInput; + } + gradInput.zero_(); /* backprop */ @@ -557,8 +570,13 @@ static inline void shapeCheck3d( int dimd = 1; int dimslices = 0; - TORCH_CHECK(input.numel() > 0 && (input.dim() == 4 || input.dim() == 5), - "non-empty 4D or 5D (batch mode) tensor expected for input, but got: ", input); + // allow batch size of 0-dim. + bool valid_dims = input.size(1) != 0 && input.size(2) != 0 && input.size(3) != 0; + TORCH_CHECK( + (input.dim() == 4 && input.size(0) != 0 && valid_dims) || + (input.dim() == 5 && valid_dims && input.size(4) != 0), + "Expected 4D or 5D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ", + input.sizes()); if (input.dim() == 5) { @@ -872,6 +890,9 @@ Tensor& replication_pad3d_backward_out_cpu_template( /* resize */ gradInput.resize_as_(input); + if (gradInput.numel() == 0) { + return gradInput; + } gradInput.zero_(); /* backprop */ diff --git a/aten/src/ATen/native/Resize.cpp b/aten/src/ATen/native/Resize.cpp index d6da309d4cf7a..e5a0423e493c9 100644 --- a/aten/src/ATen/native/Resize.cpp +++ b/aten/src/ATen/native/Resize.cpp @@ -77,12 +77,13 @@ Tensor& resize_as_( Tensor& resize_( Tensor& self, IntArrayRef size, - c10::optional optional_memory_format) { + c10::optional optional_memory_format, + bool resize_storage) { if (self.has_names()) { return resize_named_tensor_(self, size, optional_memory_format); } auto* self_ = self.unsafeGetTensorImpl(); - resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt); + resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt, resize_storage); if (optional_memory_format.has_value()) { auto memory_format = optional_memory_format.value(); @@ -95,5 +96,20 @@ Tensor& resize_( return self; } +Tensor& resize_( + Tensor& self, + IntArrayRef size, + c10::optional optional_memory_format) { + return resize_(self, size, optional_memory_format, /*resize_storage=*/true); +} + +Tensor& resize_meta_( + Tensor& self, + IntArrayRef size, + c10::optional optional_memory_format) { + // meta tensors don't have storage, so don't resize them + return resize_(self, size, optional_memory_format, /*resize_storage=*/false); +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/Resize.h b/aten/src/ATen/native/Resize.h index 8fdc977092f46..bde91c6acf1df 100644 --- a/aten/src/ATen/native/Resize.h +++ b/aten/src/ATen/native/Resize.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include namespace at { namespace native { @@ -13,7 +14,7 @@ namespace at { namespace native { // Issues a warning if the output tensor has one or more elements and // needs resizing // NOTE: In the future the warning will become an error -void resize_output(Tensor& output, IntArrayRef shape); +TORCH_API void resize_output(Tensor& output, IntArrayRef shape); // These functions are called by native::resize_ as well as (legacy) TH resize. // They are not in TH/THTensor.cpp because the at namespace is easier @@ -42,7 +43,8 @@ static inline void maybe_resize_storage_cpu(TensorImpl* self, int64_t new_size) inline TensorImpl* resize_impl_cpu_( TensorImpl* self, IntArrayRef size, - c10::optional stride) { + c10::optional stride, + bool resize_storage = true) { if (self->sizes() == size && (!stride || self->strides() == stride)) { return self; } @@ -51,20 +53,14 @@ inline TensorImpl* resize_impl_cpu_( if (stride) { self->set_sizes_and_strides(size, *stride); // NB: storage size can be different from numel. - for (size_t dim = 0; dim < size.size(); ++dim) { - // FIXME: Don't rely on storage_size being negative because this - // may not be true for some edge cases. - if (size[dim] == 0) { - storage_size = 0; - break; - } - storage_size += (size[dim] - 1) * stride.value()[dim]; - } + storage_size = storage_size_for(size, *stride); } else { self->set_sizes_contiguous(size); storage_size = self->numel(); } - maybe_resize_storage_cpu(self, storage_size); + if (resize_storage) { + maybe_resize_storage_cpu(self, storage_size); + } return self; } @@ -73,7 +69,7 @@ static inline void checkInBoundsForStorage( IntArrayRef size, IntArrayRef stride, int64_t storage_offset, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, const Storage& new_storage) { int64_t storage_size_bytes = detail::computeStorageNbytes(size, stride, data_type.itemsize()); diff --git a/aten/src/ATen/native/ResizeCommon.h b/aten/src/ATen/native/ResizeCommon.h index ec272e227acf7..8204d00cd77e8 100644 --- a/aten/src/ATen/native/ResizeCommon.h +++ b/aten/src/ATen/native/ResizeCommon.h @@ -5,6 +5,21 @@ namespace at { namespace native { +inline int64_t storage_size_for(IntArrayRef size, IntArrayRef stride) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(size.size() == stride.size(), + "storage_size_for(size, stride) requires that size and stride ", + "have the same size as a precondition."); + int64_t storage_size = 1; + for (size_t dim = 0; dim < size.size(); ++dim) { + if (size[dim] == 0) { + storage_size = 0; + break; + } + storage_size += (size[dim] - 1) * stride[dim]; + } + return storage_size; +} + inline Tensor& resize_named_tensor_( Tensor& self, IntArrayRef size, diff --git a/aten/src/ATen/native/SharedReduceOps.h b/aten/src/ATen/native/SharedReduceOps.h index ae80a9e41be90..e25b943d13a87 100644 --- a/aten/src/ATen/native/SharedReduceOps.h +++ b/aten/src/ATen/native/SharedReduceOps.h @@ -2,6 +2,8 @@ // Please note that this file is // used across both CPU and GPU. +#include +#include #include #include #include @@ -157,11 +159,15 @@ struct MeanOps { } }; -template +// This accumulator template is used to calculate the minimum absolute value of +// a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template struct AbsMinOps { - inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const { - return MIN(acc, acc_t(std::abs(data))); + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return MIN(acc, static_cast(std::abs(data))); } inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { @@ -177,17 +183,21 @@ struct AbsMinOps { } #if defined(__CUDACC__) || defined(__HIPCC__) - inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { - return WARP_SHFL_DOWN(data, offset); + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); } #endif }; -template +// This accumulator template is used to calculate the maximum absolute value of +// a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template struct AbsMaxOps { - inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const { - return MAX(acc, acc_t(std::abs(data))); + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return MAX(acc, static_cast(std::abs(data))); } inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { @@ -203,18 +213,22 @@ struct AbsMaxOps { } #if defined(__CUDACC__) || defined(__HIPCC__) - inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { - return WARP_SHFL_DOWN(data, offset); + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); } #endif }; -template +// This accumulator template is used to calculate the norm of the absolute value +// of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template struct NormOps { acc_t norm_; - inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const { - return acc + compat_pow(std::abs(data), norm_); + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return acc + compat_pow(static_cast(std::abs(data)), norm_); } inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { @@ -222,7 +236,7 @@ struct NormOps { } inline C10_DEVICE acc_t project(acc_t a) const { - return compat_pow(a, acc_t(1.0)/norm_); + return compat_pow(a, static_cast(1.0) / norm_); } static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { @@ -230,8 +244,8 @@ struct NormOps { } #if defined(__CUDACC__) || defined(__HIPCC__) - inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { - return WARP_SHFL_DOWN(data, offset); + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); } #endif @@ -239,10 +253,14 @@ struct NormOps { } }; -template +// This accumulator template is used to calculate the order zero norm of the +// absolute value of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template struct NormZeroOps { - inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const { - return acc + (data==acc_t(0) ? acc_t(0) : acc_t(1)); + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return acc + (data == static_cast(0) ? static_cast(0) : static_cast(1)); } inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { @@ -259,16 +277,20 @@ struct NormZeroOps { #if defined(__CUDACC__) || defined(__HIPCC__) - inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { - return WARP_SHFL_DOWN(data, offset); + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); } #endif }; -template +// This accumulator template is used to calculate the order one norm of the +// absolute value of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template struct NormOneOps { - inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const { - return acc + std::abs(data); + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return acc + static_cast(std::abs(data)); } inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { @@ -284,16 +306,40 @@ struct NormOneOps { } #if defined(__CUDACC__) || defined(__HIPCC__) - inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { - return WARP_SHFL_DOWN(data, offset); + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); } #endif }; -template + +template +struct AbsSwitch {}; + +template +inline C10_DEVICE acc_t abs_if_complex(scalar_t data, AbsSwitch s) { + return static_cast(data); +} + +template +inline C10_DEVICE acc_t abs_if_complex(std::complex data, AbsSwitch s) { + return static_cast(std::abs(data)); +} + +template +inline C10_DEVICE acc_t abs_if_complex(c10::complex data, AbsSwitch s) { + return static_cast(std::abs(data)); +} + +// This accumulator template is used to calculate the order two norm of the +// absolute value of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template struct NormTwoOps { - inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data, int64_t /*idx*/) const { - return acc + data * data; + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + acc_t data_ = abs_if_complex(data, AbsSwitch()); + return acc + data_ * data_; } inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { @@ -309,8 +355,8 @@ struct NormTwoOps { } #if defined(__CUDACC__) || defined(__HIPCC__) - inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { - return WARP_SHFL_DOWN(data, offset); + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); } #endif }; diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp index ffddddfd2ba5e..f4b7907df051a 100644 --- a/aten/src/ATen/native/Sorting.cpp +++ b/aten/src/ATen/native/Sorting.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include @@ -17,10 +18,7 @@ DEFINE_DISPATCH(topk_stub); namespace { -// maybe these days, one should define a random access iterator and use -// std::sort... /* Note from TH: - I cut and pasted (slightly adapted) the quicksort code from Sedgewick's 1978 "Implementing Quicksort Programs" article http://www.csie.ntu.edu.tw/~b93076/p847-sedgewick.pdf @@ -35,7 +33,6 @@ namespace { Julien, November 12th 2013 */ - template void quick_select_template( TensorAccessor arr, @@ -156,7 +153,7 @@ void quantile_impl( } else if (dim == self.dim() - 1) { sorted = std::get<0>(self.sort()); } else { - sorted = std::get<0>(self.unsqueeze(-1).transpose_(dim, -1).sort()); + sorted = std::get<0>(self.unsqueeze(-1).transpose(dim, -1).sort()); } // Treat q as a 1D tensor for the following computations @@ -228,6 +225,8 @@ std::tuple kthvalue_out_impl_cpu( k > 0 && k <= (self.dim() > 0 ? self.size(dim) : 1), "selected index k out of range"); + at::assert_no_overlap(self, values); + _reduction_with_indices_allocate_or_resize_output( values, indices, self, dim_, keepdim); if (self.dim() == 0 && self.numel() == 1) { @@ -272,35 +271,146 @@ std::tuple kthvalue_out_impl_cpu( return std::forward_as_tuple(values, indices); } -} // namespace - -std::tuple kthvalue_out_cpu( +// Computes both the median and its index along dimension dim of the input +std::tuple median_with_indices_impl( Tensor& values, Tensor& indices, const Tensor& self, - int64_t k, int64_t dim, - bool keepdim) { - auto result = [&]() { - NoNamesGuard guard; - return kthvalue_out_impl_cpu(values, indices, self, k, dim, keepdim); - }(); - namedinference::propagate_names_for_reduction(values, self, dim, keepdim); - namedinference::propagate_names_for_reduction(indices, self, dim, keepdim); - return result; + bool keepdim, + bool ignore_nan) { + dim = at::maybe_wrap_dim(dim, self.dim()); + + int64_t size = self.dim() > 0 ? self.size(dim) : 1; + TORCH_CHECK( + size > 0, + "median() cannot compute median for a dimension of size 0 because ", + "the operation does not have an identity"); + + checkDeviceType("median", {values, indices}, self.device().type()); + checkScalarType("median", {indices, "indices", 1}, kLong); + checkSameType("median", {values, "values", 0}, {self, "self", 2}); + + std::vector out_shape = self.sizes().vec(); + if (self.dim() > 0) { + if (keepdim) { + out_shape[dim] = 1; + } else { + out_shape.erase(out_shape.begin() + dim); + } + } + + resize_output(values, out_shape); + resize_output(indices, out_shape); + + // Ensure #dim is the same for all tensors required for dim_apply + Tensor in = self.dim() > 0 ? self : self.unsqueeze(0); + Tensor vals = keepdim && self.dim() > 0 ? values : values.unsqueeze(dim); + Tensor inds = keepdim && self.dim() > 0 ? indices : indices.unsqueeze(dim); + + // Make dim to reduce contiguous (stride=1) + if (in.stride(dim) > 1) { + in = in.unsqueeze(-1).transpose_(dim, -1).squeeze_(dim).contiguous(); + vals = vals.unsqueeze(-1).transpose_(dim, -1).squeeze_(dim); + inds = inds.unsqueeze(-1).transpose_(dim, -1).squeeze_(dim); + dim = in.dim() - 1; + } + + AT_DISPATCH_ALL_TYPES(in.scalar_type(), "median_out", [&] { + dim_apply({in, vals, inds}, dim, [&](int64_t it, TensorList tl) { + // Make the current row to be reduced contiguous + scalar_t* ip = tl[0].data_ptr(); + + // For torch.median, search for NaN and return it if found + if (!ignore_nan) { + scalar_t* nanp = std::find_if(ip, ip + size, _isnan); + if (nanp != ip + size) { + *tl[1].data_ptr() = *nanp; + *tl[2].data_ptr() = nanp - ip; + return; + } + } + + // Vector of indices for indirectly partitioning input around median + std::vector idx(size); + auto first = idx.begin(); + auto last = idx.end(); + std::iota(first, last, 0); + + // We partition the input around the median indirectly using the indices + // vector so that nth points to the index of the median in the unmodified + // input tensor. + auto nth = first; + if (!ignore_nan) { + // If we got here, there are no nan values + nth += (size - 1) / 2; + std::nth_element(first, nth, last, [&ip](int64_t i, int64_t j) { + return ip[i] < ip[j] || (ip[i] == ip[j] && i < j); + }); + } else { + // For torch.nanmedian, compute median of non-nan values only + int64_t num_nan = std::count_if(ip, ip + size, _isnan); + nth += (size - num_nan - 1) / 2; + std::nth_element(first, nth, last, [&ip](int64_t i, int64_t j) { + return ip[i] < ip[j] || (ip[i] == ip[j] && i < j) || + (_isnan(ip[j]) && !_isnan(ip[i])); + }); + } + + *tl[1].data_ptr() = ip[*nth]; + *tl[2].data_ptr() = *nth; + }); + }); + + return std::forward_as_tuple(values, indices); } -std::tuple kthvalue( - const Tensor& self, - int64_t k, - int64_t dim, - bool keepdim) { - Tensor values = at::empty({0}, self.options()); - Tensor indices = at::empty({0}, self.options().dtype(kLong)); - at::kthvalue_out(values, indices, self, k, dim, keepdim); - return std::make_tuple(values, indices); +// Computes the median of all values in the input +Tensor median_impl(const Tensor& self, bool ignore_nan) { + NoNamesGuard guard; + + int64_t size = self.numel(); + TORCH_CHECK( + size > 0, + "median() operation does not have an identity for empty input tensor"); + + // Clone the input tensor so we can partition it around the median value + Tensor in = self.clone(); + Tensor out = at::empty({}, self.options()); + + AT_DISPATCH_ALL_TYPES(in.scalar_type(), "median_cpu", [&] { + scalar_t* op = out.data_ptr(); + scalar_t* first = in.data_ptr(); + scalar_t* last = first + size; + + // For torch.median, if there are nan values return nan + if (!ignore_nan && std::any_of(first, last, _isnan)) { + *op = std::numeric_limits::quiet_NaN(); + return; + } + + scalar_t* median = first; + if (!ignore_nan) { + // If we got here, there are no nan values + median += (size - 1) / 2; + std::nth_element(first, median, last); + } else { + // For torch.nanmedian, compute median of non-nan values only + int64_t num_nan = std::count_if(first, last, _isnan); + median += (size - num_nan - 1) / 2; + std::nth_element(first, median, last, [](scalar_t a, scalar_t b) { + return a < b || (_isnan(b) && !_isnan(a)); + }); + } + + *op = *median; + }); + + return out; } +} // namespace + Tensor& quantile_out( Tensor& out, const Tensor& self, @@ -395,6 +505,52 @@ Tensor nanquantile( self, at::scalar_tensor(q, self.options()), std::move(_dim), keepdim); } +std::tuple kthvalue_out_cpu( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t k, + int64_t dim, + bool keepdim) { + auto result = [&]() { + NoNamesGuard guard; + return kthvalue_out_impl_cpu(values, indices, self, k, dim, keepdim); + }(); + namedinference::propagate_names_for_reduction(values, self, dim, keepdim); + namedinference::propagate_names_for_reduction(indices, self, dim, keepdim); + return result; +} + +std::tuple kthvalue_out( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t k, + Dimname dim, + bool keepdim) { + return at::kthvalue_out( + values, indices, self, k, dimname_to_position(self, dim), keepdim); +} + +std::tuple kthvalue( + const Tensor& self, + int64_t k, + int64_t dim, + bool keepdim) { + Tensor values = at::empty({0}, self.options()); + Tensor indices = at::empty({0}, self.options().dtype(kLong)); + at::kthvalue_out(values, indices, self, k, dim, keepdim); + return std::make_tuple(values, indices); +} + +std::tuple kthvalue( + const Tensor& self, + int64_t k, + Dimname dim, + bool keepdim) { + return at::kthvalue(self, k, dimname_to_position(self, dim), keepdim); +} + std::tuple topk_out_cpu( Tensor& values, Tensor& indices, @@ -432,16 +588,30 @@ std::tuple topk( return std::make_tuple(values, indices); } -std::tuple median_out( +std::tuple median_out_cpu( Tensor& values, Tensor& indices, const Tensor& self, int64_t dim, bool keepdim) { - // note: kthvalue counts from 1..n - int64_t k = self.dim() > 0 ? (self.size(dim) + 1) / 2 : 1; - at::kthvalue_out(values, indices, self, k, dim, keepdim); - return std::forward_as_tuple(values, indices); + auto result = [&]() { + NoNamesGuard guard; + return median_with_indices_impl( + values, indices, self, dim, keepdim, /*ignore_nan=*/false); + }(); + namedinference::propagate_names_for_reduction(values, self, dim, keepdim); + namedinference::propagate_names_for_reduction(indices, self, dim, keepdim); + return result; +} + +std::tuple median_out( + Tensor& values, + Tensor& indices, + const Tensor& self, + Dimname dim, + bool keepdim) { + return at::median_out( + values, indices, self, dimname_to_position(self, dim), keepdim); } std::tuple median( @@ -454,67 +624,62 @@ std::tuple median( return std::make_tuple(values, indices); } -std::tuple median_out( - Tensor& values, - Tensor& indices, +std::tuple median( const Tensor& self, Dimname dim, bool keepdim) { - return at::median_out( - values, indices, self, dimname_to_position(self, dim), keepdim); + return at::median(self, dimname_to_position(self, dim), keepdim); } -std::tuple median( +Tensor median_cpu(const Tensor& self) { + return median_impl(self, /*ignore_nan=*/false); +} + +std::tuple nanmedian_out_cpu( + Tensor& values, + Tensor& indices, const Tensor& self, - Dimname dim, + int64_t dim, bool keepdim) { - return at::median(self, dimname_to_position(self, dim), keepdim); + auto result = [&]() { + NoNamesGuard guard; + return median_with_indices_impl( + values, indices, self, dim, keepdim, /*ignore_nan=*/true); + }(); + namedinference::propagate_names_for_reduction(values, self, dim, keepdim); + namedinference::propagate_names_for_reduction(indices, self, dim, keepdim); + return result; } -std::tuple kthvalue_out( +std::tuple nanmedian_out( Tensor& values, Tensor& indices, const Tensor& self, - int64_t k, Dimname dim, bool keepdim) { - return at::kthvalue_out( - values, indices, self, k, dimname_to_position(self, dim), keepdim); + return at::nanmedian_out( + values, indices, self, dimname_to_position(self, dim), keepdim); } -std::tuple kthvalue( +std::tuple nanmedian( + const Tensor& self, + int64_t dim, + bool keepdim) { + Tensor values = at::empty({0}, self.options()); + Tensor indices = at::empty({0}, self.options().dtype(kLong)); + at::nanmedian_out(values, indices, self, dim, keepdim); + return std::make_tuple(values, indices); +} + +std::tuple nanmedian( const Tensor& self, - int64_t k, Dimname dim, bool keepdim) { - return at::kthvalue(self, k, dimname_to_position(self, dim), keepdim); + return at::nanmedian(self, dimname_to_position(self, dim), keepdim); } -// this does not reduce to median with dim because we don't want to copy twice -Tensor median_cpu(const Tensor& self) { - NoNamesGuard guard; - TORCH_CHECK(self.numel() > 0, "median cannot be called with empty tensor"); - if (self.dim() == 0 && self.numel() == 1) { - return self.clone(at::MemoryFormat::Contiguous); - } - auto tmp_values = self.clone(at::MemoryFormat::Contiguous).view(-1); - auto result = at::empty({1}, self.options()); - AT_DISPATCH_ALL_TYPES(self.scalar_type(), "median", [&] { - // note, quick_select is 0 based while kthvalue is not - int64_t k = (tmp_values.size(0) - 1) / 2; - auto val_accessor = tmp_values.accessor(); - quick_select_template( - val_accessor, - k, - [](scalar_t x, scalar_t y) -> bool { - return ((_isnan(x) && !_isnan(y)) || (x > y)); - }, - [&](int64_t i, int64_t j) { - std::swap(val_accessor[i], val_accessor[j]); - }); - result.fill_(tmp_values[k]); - }); - return result.view({}); +Tensor nanmedian_cpu(const Tensor& self) { + return median_impl(self, /*ignore_nan=*/true); } std::tuple sort_out_cpu( @@ -522,8 +687,8 @@ std::tuple sort_out_cpu( Tensor& indices, const Tensor& self, int64_t dim, - bool descending - ) { + bool descending, + bool stable) { values.resize_(self.sizes()).copy_(self); indices.resize_(self.sizes()); @@ -533,7 +698,7 @@ std::tuple sort_out_cpu( return std::forward_as_tuple(values, indices); } - sort_stub(kCPU, values, indices, dim, descending); + sort_stub(kCPU, values, indices, dim, descending, stable); return std::forward_as_tuple(values, indices); } @@ -541,11 +706,21 @@ std::tuple sort_out_cpu( std::tuple sort_cpu( const Tensor& self, int64_t dim, - bool descending - ) { + bool descending, + bool stable) { Tensor values = at::empty({0}, self.options()); Tensor indices = at::empty({0}, self.options().dtype(kLong)); - return sort_out_cpu(values, indices, self, dim, descending); + return sort_out_cpu(values, indices, self, dim, descending, stable); +} + +Tensor& msort_out(Tensor& values, const Tensor& self) { + Tensor indices = at::empty({0}, self.options().dtype(kLong)); + at::sort_out(values, indices, self, 0, false, false); + return values; +} + +Tensor msort(const Tensor& self) { + return std::get<0>(at::sort(self, 0, false)); } } // namespace native diff --git a/aten/src/ATen/native/Sorting.h b/aten/src/ATen/native/Sorting.h index ee69b3ebd2540..f62f81e6dca51 100644 --- a/aten/src/ATen/native/Sorting.h +++ b/aten/src/ATen/native/Sorting.h @@ -5,7 +5,7 @@ namespace at { namespace native { -using sort_fn = void(*)(Tensor& values, Tensor& indices, int64_t dim, bool descending); +using sort_fn = void(*)(Tensor& values, Tensor& indices, int64_t dim, bool descending, bool stable); using topk_fn = void(*)(Tensor&, Tensor&, const Tensor&, int64_t, int64_t, bool, bool); DECLARE_DISPATCH(sort_fn, sort_stub); diff --git a/aten/src/ATen/native/SortingUtils.h b/aten/src/ATen/native/SortingUtils.h index 6b83be84ce903..3223fd3a779d1 100644 --- a/aten/src/ATen/native/SortingUtils.h +++ b/aten/src/ATen/native/SortingUtils.h @@ -40,7 +40,7 @@ void dim_apply(TensorList tensors, int64_t dim, Fn f) { }); } -// ensure we get good values and indices for kthvalue, mode, median +// ensure we get good values and indices for kthvalue, mode // this will always be with the reducing dim as 1-d inline void _reduction_with_indices_allocate_or_resize_output( Tensor& values, diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index 1e9c1bce67d3a..289d1128d2f9a 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -18,12 +19,6 @@ namespace at { namespace native { -// Common code for all FFT functions -static inline Tensor _fft( - const Tensor &self, int64_t signal_ndim, bool complex_input, - const bool complex_output, bool inverse, IntArrayRef signal_sizes, - fft_norm_mode normalization, bool onesided); - namespace { // Promote inputs to FFT functions @@ -107,9 +102,12 @@ Tensor resize_fft_input(Tensor x, IntArrayRef dims, IntArrayRef sizes) { } // Complex to real FFT -Tensor fft_c2r(Tensor input, c10::optional n_opt, +Tensor fft_c2r(c10::string_view function_name, + Tensor out, Tensor input, c10::optional n_opt, int64_t unwrapped_dim, c10::optional norm_str, bool forward) { + TORCH_CHECK(!out.defined() || out.is_floating_point(), function_name, + " expects a floating point output tensor, but got ", out.scalar_type()); input = promote_tensor_fft(input, /*require_complex=*/true); const auto input_dim = input.dim(); const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim); @@ -118,32 +116,27 @@ Tensor fft_c2r(Tensor input, c10::optional n_opt, if (n_opt) { input = resize_fft_input(input, dim, n/2 + 1); } - // _fft only operates on the last dim, so transpose the selected dim to the end - const bool must_transpose = (dim != input_dim - 1); - if (must_transpose) { - input = at::transpose(input, -1, dim); - } const auto norm = norm_from_string(norm_str, forward); if (forward) { // FIXME: _fft does not support complex_output=false with inverse=false input = at::conj(input); } - auto out = _fft(at::view_as_real(input), - /*signal_ndim=*/1, /*complex_input=*/true, - /*complex_output=*/false, /*inverse=*/true, - /*signal_sizes=*/{n}, /*normalization=*/norm, - /*onesided=*/true); - if (must_transpose) { - out = at::transpose(out, -1, dim); + if (out.defined()) { + return at::_fft_c2r_out(out, input, dim, static_cast(norm), n); + } else { + return at::_fft_c2r(input, dim, static_cast(norm), n); } - return out; } // Real to complex FFT -Tensor fft_r2c(Tensor input, c10::optional n_opt, +Tensor fft_r2c(c10::string_view function_name, + Tensor out, Tensor input, c10::optional n_opt, int64_t unwrapped_dim, c10::optional norm_str, bool forward, bool onesided) { - TORCH_CHECK(!input.is_complex(), "Expected a real input tensor to FFT"); + TORCH_CHECK(!input.is_complex(), function_name, + " expects a real input tensor, but got ", input.scalar_type()); + TORCH_CHECK(!out.defined() || out.is_complex(), function_name, + " expects a complex output tensor, but got ", out.scalar_type()); input = promote_tensor_fft(input); const auto input_dim = input.dim(); const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim); @@ -152,32 +145,31 @@ Tensor fft_r2c(Tensor input, c10::optional n_opt, if (n_opt) { input = resize_fft_input(input, dim, n); } - // _fft only operates on the last dim, so transpose the selected dim to the end - const bool must_transpose = (dim != input_dim - 1); - if (must_transpose) { - input = at::transpose(input, -1, dim); - } + const auto norm = norm_from_string(norm_str, forward); - auto out = _fft(input, /*signal_ndim=*/1, /*complex_input=*/false, - /*complex_output=*/true, /*inverse=*/false, - /*signal_sizes=*/{n}, /*normalization=*/norm, - /*onesided=*/onesided); - out = at::view_as_complex(out); - if (must_transpose) { - out = at::transpose(out, -1, dim); + + Tensor ret; + if (out.defined() && forward) { + ret = at::_fft_r2c_out(out, input, dim, static_cast(norm), onesided); + } else { + ret = at::_fft_r2c(input, dim, static_cast(norm), onesided); } + if (!forward) { - // FIXME: _fft does not support complex_input=false with inverse=true - out = at::conj(out); + // FIXME: _fft_r2c doesn't support native r2c IFFT + return out.defined() ? at::conj_out(out, ret) : at::conj(ret); + } else { + return ret; } - return out; } // Complex to complex FFT -Tensor fft_c2c(Tensor input, c10::optional n_opt, +Tensor fft_c2c(c10::string_view function_name, + Tensor out, Tensor input, c10::optional n_opt, int64_t unwrapped_dim, c10::optional norm_str, bool forward) { - TORCH_CHECK(input.is_complex(), "Expected a complex input tensor to FFT"); + TORCH_CHECK(input.is_complex(), function_name, + " expects a complex input tensor, but got ", input.scalar_type()); const auto input_dim = input.dim(); const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim); const auto n = n_opt.value_or(input.sizes()[dim]); @@ -185,22 +177,14 @@ Tensor fft_c2c(Tensor input, c10::optional n_opt, if (n_opt) { input = resize_fft_input(input, dim, n); } - // _fft only operates on the last dim, so transpose the selected dim to the end - const bool must_transpose = (dim != input_dim - 1); - if (must_transpose) { - input = at::transpose(input, -1, dim); - } const auto norm = norm_from_string(norm_str, forward); - auto out = _fft(at::view_as_real(input), - /*signal_ndim=*/1, /*complex_input=*/true, - /*complex_output=*/true, /*inverse=*/!forward, - /*signal_sizes=*/{}, /*normalization=*/norm, - /*onesided=*/false); - out = at::view_as_complex(out); - if (must_transpose) { - out = at::transpose(out, -1, dim); + if (out.defined()) { + TORCH_CHECK(out.is_complex(), function_name, + " expects a complex output tensor, but got ", out.scalar_type()); + return at::_fft_c2c_out(out, input, dim, static_cast(norm), forward); + } else { + return at::_fft_c2c(input, dim, static_cast(norm), forward); } - return out; } // Dimensions to transform, and the signal shape in those dimensions @@ -262,92 +246,110 @@ ShapeAndDims canonicalize_fft_shape_and_dim_args( ret.shape[i] = input_sizes[ret.dim[i]]; } } - + for (int64_t i = 0; i < ret.shape.size(); ++i) { TORCH_CHECK(ret.shape[i] > 0, "Invalid number of data points (", ret.shape[i], ") specified"); } - + return ret; } // Complex to complex n-dimensional fft Tensor fftn_c2c( - const Tensor& input, IntArrayRef shape, IntArrayRef dim, - c10::optional norm_str, bool forward) { - TORCH_CHECK(input.is_complex(), "Expected a complex input tensor to FFT"); - const auto input_dim = input.dim(); - + c10::string_view function_name, + Tensor out, const Tensor& input, IntArrayRef shape, + IntArrayRef dim, c10::optional norm_str, bool forward) { + TORCH_CHECK(input.is_complex(), function_name, " expects a complex input tensor, but got", input.scalar_type()); Tensor x = resize_fft_input(input, dim, shape); - x = at::view_as_real(x); - - const int64_t transform_ndim = dim.size(); const auto norm = norm_from_string(norm_str, forward); - // _fft_with_size only supports 3 dimensions being transformed at a time. - // This limit is inherited from cuFFT. - constexpr int64_t max_signal_ndim = 3; - - // Transform n dimensions, up to 3 at a time - // TODO: rewrite _fft_with_size to transform more than 3 dimensions at once. - for (int64_t i = 0; i < transform_ndim; i += max_signal_ndim) { - const int64_t signal_ndim = std::min(transform_ndim - i, max_signal_ndim); - DimVector source_dim(signal_ndim); - DimVector dest_dim(signal_ndim); - - for (int64_t j = 0; j < signal_ndim; ++j) { - source_dim[j] = dim[i + j]; - dest_dim[j] = j + (input_dim - signal_ndim); - } - - // _fft operates on up-to the last 3 dims, so move selected dims to the end - x = at::movedim(x, source_dim, dest_dim); - - x = _fft(x, signal_ndim, /*complex_input=*/true, /*complex_output=*/true, - /*inverse=*/!forward, /*signal_sizes=*/{}, /*normalization=*/norm, - /*onesided=*/false); - - // Move transform dims back to their original order - x = at::movedim(x, dest_dim, source_dim); + if (out.defined()) { + TORCH_CHECK(out.is_complex(), function_name, " expects a complex output tensor, but got ", out.scalar_type()); + return at::_fft_c2c_out(out, x, dim, static_cast(norm), forward); + } else { + return at::_fft_c2c(x, dim, static_cast(norm), forward); } - - return at::view_as_complex(x); } -} +} // namespace (anonymous) // torch.fft.fft, analogous to NumPy's numpy.fft.fft Tensor fft_fft(const Tensor& self, c10::optional n, int64_t dim, c10::optional norm) { - return self.is_complex() ? - fft_c2c(self, n, dim, norm, /*forward=*/true) : - fft_r2c(self, n, dim, norm, /*forward=*/true, /*onesided=*/false); + return self.is_complex() ? + fft_c2c("fft", {}, self, n, dim, norm, /*forward=*/true) : + fft_r2c("fft", {}, self, n, dim, norm, /*forward=*/true, /*onesided=*/false); +} + +Tensor& fft_fft_out(Tensor& out, const Tensor& self, c10::optional n, + int64_t dim, c10::optional norm) { + if (self.is_complex()) { + fft_c2c("fft", out, self, n, dim, norm, /*forward=*/true); + } else { + fft_r2c("fft", out, self, n, dim, norm, /*forward=*/true, /*onesided=*/false); + } + return out; } Tensor fft_ifft(const Tensor& self, c10::optional n, int64_t dim, c10::optional norm) { - return self.is_complex() ? - fft_c2c(self, n, dim, norm, /*forward=*/false) : - fft_r2c(self, n, dim, norm, /*forward=*/false, /*onesided=*/false); + return self.is_complex() ? + fft_c2c("ifft", {}, self, n, dim, norm, /*forward=*/false) : + fft_r2c("ifft", {}, self, n, dim, norm, /*forward=*/false, /*onesided=*/false); +} + +Tensor& fft_ifft_out(Tensor& out, const Tensor& self, c10::optional n, + int64_t dim, c10::optional norm) { + if (self.is_complex()) { + fft_c2c("ifft", out, self, n, dim, norm, /*forward=*/false); + } else { + fft_r2c("ifft", out, self, n, dim, norm, /*forward=*/false, /*onesided=*/false); + } + return out; } Tensor fft_rfft(const Tensor& self, c10::optional n, int64_t dim, c10::optional norm) { - return fft_r2c(self, n, dim, norm, /*forward=*/true, /*onesided=*/true); + return fft_r2c("rfft", {}, self, n, dim, norm, /*forward=*/true, /*onesided=*/true); +} + +Tensor& fft_rfft_out(Tensor& out, const Tensor& self, c10::optional n, + int64_t dim, c10::optional norm) { + fft_r2c("rfft", out, self, n, dim, norm, /*forward=*/true, /*onesided=*/true); + return out; } Tensor fft_irfft(const Tensor& self, c10::optional n, int64_t dim, c10::optional norm) { - return fft_c2r(self, n, dim, norm, /*forward=*/false); + return fft_c2r("irfft", {}, self, n, dim, norm, /*forward=*/false); +} + +Tensor& fft_irfft_out(Tensor& out, const Tensor& self, c10::optional n, + int64_t dim, c10::optional norm) { + fft_c2r("irfft", out, self, n, dim, norm, /*forward=*/false); + return out; } Tensor fft_hfft(const Tensor& self, c10::optional n, int64_t dim, c10::optional norm) { - return fft_c2r(self, n, dim, norm, /*forward=*/true); + return fft_c2r("hfft", {}, self, n, dim, norm, /*forward=*/true); +} + +Tensor& fft_hfft_out(Tensor& out, const Tensor& self, c10::optional n, + int64_t dim, c10::optional norm) { + fft_c2r("hfft", out, self, n, dim, norm, /*forward=*/true); + return out; } Tensor fft_ihfft(const Tensor& self, c10::optional n, int64_t dim, c10::optional norm) { - return fft_r2c(self, n, dim, norm, /*forward=*/false, /*onesided=*/true); + return fft_r2c("ihfft", {}, self, n, dim, norm, /*forward=*/false, /*onesided=*/true); +} + +Tensor& fft_ihfft_out(Tensor& out, const Tensor& self, c10::optional n, + int64_t dim, c10::optional norm) { + fft_r2c("ihfft", out, self, n, dim, norm, /*forward=*/false, /*onesided=*/true); + return out; } Tensor fft_fftn(const Tensor& self, c10::optional s, @@ -356,7 +358,18 @@ Tensor fft_fftn(const Tensor& self, c10::optional s, auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); // TODO: For real input, perform rfftn then mirror with conjugate symmetry Tensor input = promote_tensor_fft(self, /*require_complex=*/true); - return fftn_c2c(input, desc.shape, desc.dim, norm, /*forward=*/true); + return fftn_c2c("fftn", {}, input, desc.shape, desc.dim, norm, /*forward=*/true); +} + +Tensor& fft_fftn_out(Tensor& out, const Tensor& self, + c10::optional s, + c10::optional dim, + c10::optional norm) { + auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); + // TODO: For real input, perform rfftn then mirror with conjugate symmetry + Tensor input = promote_tensor_fft(self, /*require_complex=*/true); + fftn_c2c("fftn", out, input, desc.shape, desc.dim, norm, /*forward=*/true); + return out; } Tensor fft_ifftn(const Tensor& self, c10::optional s, @@ -364,184 +377,206 @@ Tensor fft_ifftn(const Tensor& self, c10::optional s, c10::optional norm) { auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); Tensor input = promote_tensor_fft(self, /*require_complex=*/true); - return fftn_c2c(input, desc.shape, desc.dim, norm, /*forward=*/false); + return fftn_c2c("ifftn", {}, input, desc.shape, desc.dim, norm, /*forward=*/false); } -Tensor fft_rfftn(const Tensor& self, c10::optional s, - c10::optional dim, - c10::optional norm) { +Tensor& fft_ifftn_out(Tensor& out, const Tensor& self, + c10::optional s, + c10::optional dim, + c10::optional norm) { + auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); + Tensor input = promote_tensor_fft(self, /*require_complex=*/true); + fftn_c2c("ifftn", out, input, desc.shape, desc.dim, norm, /*forward=*/false); + return out; +} + +static Tensor fft_rfftn_impl(Tensor out, const Tensor& self, + c10::optional s, + c10::optional dim, + const c10::optional& norm_str) { + TORCH_CHECK(!self.is_complex(), "rfftn expects a real-valued input tensor, but got ", self.scalar_type()); auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); TORCH_CHECK(desc.shape.size() > 0, "rfftn must transform at least one axis"); + Tensor input = promote_tensor_fft(self, /*require_complex=*/false); + Tensor x = resize_fft_input(input, desc.dim, desc.shape); + const auto norm = norm_from_string(norm_str, /*forward=*/true); + if (out.defined()) { + TORCH_CHECK(out.is_complex(), "rfftn expects a complex-valued output tensor, but got ", out.scalar_type()); + return at::_fft_r2c_out(out, x, desc.dim, static_cast(norm), /*onesided=*/true); + } else { + return at::_fft_r2c(x, desc.dim, static_cast(norm), /*onesided=*/true); + } +} - const auto last_dim = desc.dim.back(); - const auto last_shape = desc.shape.back(); - desc.shape.pop_back(); - desc.dim.pop_back(); +Tensor fft_rfftn(const Tensor& self, c10::optional s, + c10::optional dim, + c10::optional norm_str) { + return fft_rfftn_impl({}, self, s, dim, norm_str); +} - // rfft on last dim to get hermitian complex shape - auto x = native::fft_rfft(self, last_shape, last_dim, norm); - // Normal fft on remaining dims - return fftn_c2c(x, desc.shape, desc.dim, norm, /*forward=*/true); +Tensor& fft_rfftn_out(Tensor& out, const Tensor& self, + c10::optional s, + c10::optional dim, + c10::optional norm_str) { + fft_rfftn_impl(out, self, s, dim, norm_str); + return out; } -Tensor fft_irfftn(const Tensor& self, c10::optional s, - c10::optional dim, - c10::optional norm) { +static Tensor fft_irfftn_impl(Tensor out, const Tensor& self, + c10::optional s, + c10::optional dim, + const c10::optional& norm_str) { auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim); TORCH_CHECK(desc.shape.size() > 0, "irfftn must transform at least one axis"); - const auto last_dim = desc.dim.back(); - const auto last_shape = [&]() -> c10::optional { - // If shape is defaulted in the last dimension, - // pass nullopt to irfft and let it calculate the default size + const auto last_dim_size = [&] { + // Fixup default shape handling in the last dimension, if (!s.has_value() || (s->back() == -1)) { - return c10::nullopt; + const auto last_dim = desc.dim.back(); + return 2 * (self.sizes()[last_dim] - 1); } return desc.shape.back(); }(); - desc.shape.pop_back(); - desc.dim.pop_back(); - - // Normal ifft for all but last dim - Tensor x = promote_tensor_fft(self, /*require_complex=*/true); - x = fftn_c2c(x, desc.shape, desc.dim, norm, /*forward=*/false); - // Then 1d irfft on last dim to get real output - return native::fft_irfft(x, last_shape, last_dim, norm); -} - -// This is a pass-through wrapper function that does the size check and -// inferences. The actual forward implementation function is called -// at::_fft_with_size which dispatches to _fft_cufft (CUDA) or _fft_mkl (CPU). -static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim, - const bool complex_input, const bool complex_output, - const bool inverse, IntArrayRef signal_sizes, - const fft_norm_mode normalization, const bool onesided) { - - TORCH_CHECK(signal_ndim >= 1 && signal_ndim <= 3, - "Expected signal_ndim to be 1, 2, or 3, but got signal_ndim=", - signal_ndim); - TORCH_CHECK(at::isFloatingType(self.scalar_type()), - "Expected an input tensor of floating types, but got input=", - self.toString(), self.sizes()); - - auto signal_tensor_ndim = signal_ndim + static_cast(complex_input); // add complex dim - if (self.dim() < signal_tensor_ndim) { - std::ostringstream ss; - ss << "Given signal_ndim=" << signal_ndim << ", expected an input tensor " - << "of at least " << signal_tensor_ndim << "D"; - if (complex_input) { - ss << " (complex input adds an extra dimension)"; - } - ss << ", but got input=" << self.toString() << self.sizes(); - AT_ERROR(ss.str()); + desc.shape.back() = last_dim_size / 2 + 1; + + Tensor input = promote_tensor_fft(self, /*require_complex=*/true); + Tensor x = resize_fft_input(input, desc.dim, desc.shape); + const auto norm = norm_from_string(norm_str, /*forward=*/false); + if (out.defined()) { + TORCH_CHECK(out.is_floating_point(), "irfftn expects a floating point output tensor, but got ", out.scalar_type()); + return at::_fft_c2r_out(out, x, desc.dim, static_cast(norm), last_dim_size); + } else { + return at::_fft_c2r(x, desc.dim, static_cast(norm), last_dim_size); } +} - auto self_shape = self.sizes(); - auto batch_ndim = self.dim() - signal_tensor_ndim; +Tensor fft_irfftn(const Tensor& self, + c10::optional s, + c10::optional dim, + c10::optional norm_str) { + return fft_irfftn_impl({}, self, s, dim, norm_str); +} - Tensor input = self; - // flatten the batch dims - if (batch_ndim == 0) { - // slightly faster path for non-batch mode - input = input.unsqueeze(0); - } else if (batch_ndim > 1) { - std::vector flatten_input_shape(signal_tensor_ndim + 1); - std::copy(self_shape.begin() + batch_ndim, self_shape.end(), flatten_input_shape.begin() + 1); - flatten_input_shape[0] = -1; - input = input.reshape(flatten_input_shape); - - } - - // now we assume that input is batched as [ B x signal_dims... ] - - if (complex_input) { - TORCH_CHECK(input.size(signal_ndim + 1) == 2, - "Expected an input tensor with a last dimension of size 2 " - "representing real + imaginary components, but got input ", - self.toString(), self.sizes()); - } - - // build signal_sizes and output_size - TORCH_CHECK(signal_sizes.size() == 0 || static_cast(signal_sizes.size()) == signal_ndim, - "Expected signal_sizes to be empty (default) or of signal_ndim=", - signal_ndim, "D, but got signal_sizes=", signal_sizes); - std::vector output_sizes(signal_ndim + 1 + static_cast(complex_output)); - output_sizes[0] = input.size(0); // batch size - std::vector checked_signal_sizes(signal_ndim); - for (int64_t i = 0; i < signal_ndim; i++) { - int64_t input_size = input.size(i + 1); - if (i == signal_ndim - 1 && onesided && complex_input && !complex_output) { - // If last dim and complex-to-real onesided, input is only half of - // signal, and we need to infer basing on signal_sizes, if given - // See native/SpectralOpsUtils.h for detailed description. - int64_t inferred_size; - if (signal_sizes.size() > 0) { - inferred_size = infer_ft_complex_to_real_onesided_size(input_size, signal_sizes[i]); - } else { - inferred_size = infer_ft_complex_to_real_onesided_size(input_size); - } - checked_signal_sizes[i] = inferred_size; - output_sizes[i + 1] = inferred_size; - } else { - if (i == signal_ndim - 1 && onesided && !complex_input && complex_output) { - // if last dim and real-to-complex onesided, output should be only - // half of the signal, and we need to infer using input_size - output_sizes[i + 1] = infer_ft_real_to_complex_onesided_size(input_size); - } else { - output_sizes[i + 1] = input_size; - } - checked_signal_sizes[i] = input_size; - TORCH_CHECK(signal_sizes.size() == 0 || signal_sizes[i] == checked_signal_sizes[i], - "Expected given signal_sizes=", signal_sizes," to have same " - "shape with input at signal dimension ", i, ", but got " - "signal_sizes=", signal_sizes, " and input=", self.toString(), - self.sizes()); +Tensor& fft_irfftn_out(Tensor& out, const Tensor& self, + c10::optional s, + c10::optional dim, + c10::optional norm_str) { + fft_irfftn_impl(out, self, s, dim, norm_str); + return out; +} + +Tensor fft_fft2(const Tensor& self, c10::optional s, + IntArrayRef dim, c10::optional norm) { + return native::fft_fftn(self, s, dim, std::move(norm)); +} + +Tensor& fft_fft2_out(Tensor& out, const Tensor& self, c10::optional s, + IntArrayRef dim, c10::optional norm) { + return native::fft_fftn_out(out, self, s, dim, std::move(norm)); +} + +Tensor fft_ifft2(const Tensor& self, c10::optional s, + IntArrayRef dim, c10::optional norm) { + return native::fft_ifftn(self, s, dim, std::move(norm)); +} + +Tensor& fft_ifft2_out(Tensor& out, const Tensor& self, c10::optional s, + IntArrayRef dim, c10::optional norm) { + return native::fft_ifftn_out(out, self, s, dim, std::move(norm)); +} + +Tensor fft_rfft2(const Tensor& self, c10::optional s, + IntArrayRef dim, c10::optional norm) { + return native::fft_rfftn(self, s, dim, std::move(norm)); +} + +Tensor& fft_rfft2_out(Tensor& out, const Tensor& self, c10::optional s, + IntArrayRef dim, c10::optional norm) { + return native::fft_rfftn_out(out, self, s, dim, std::move(norm)); +} + +Tensor fft_irfft2(const Tensor& self, c10::optional s, + IntArrayRef dim, c10::optional norm) { + return native::fft_irfftn(self, s, dim, std::move(norm)); +} + +Tensor& fft_irfft2_out(Tensor& out, const Tensor& self, c10::optional s, + IntArrayRef dim, c10::optional norm) { + return native::fft_irfftn_out(out, self, s, dim, std::move(norm)); +} + +Tensor& fft_fftfreq_out(Tensor& out, int64_t n, double d) { + ScalarType dtype = out.scalar_type(); + TORCH_CHECK(at::isFloatingType(dtype) || at::isComplexType(dtype), + "fftfreq requires a floating point or complex dtype"); + // TODO: arange doesn't have complex support + at::arange_out(out, n); + auto right_slice = out.slice(0, (n + 1) / 2, 0); + at::arange_out(right_slice, -(n/2), 0, 1); + return out.mul_(1.0 / (n * d)); // Slightly faster than div_(n*d) +} + +Tensor fft_fftfreq(int64_t n, double d, const TensorOptions& options) { + auto out = at::empty({n}, options); + return native::fft_fftfreq_out(out, n, d); +} + +Tensor& fft_rfftfreq_out(Tensor& out, int64_t n, double d) { + ScalarType dtype = out.scalar_type(); + TORCH_CHECK(at::isFloatingType(dtype) || at::isComplexType(dtype), + "rfftfreq requires a floating point or complex dtype"); + // TODO: arange doesn't have complex support + native::arange_out(out, n/2 + 1); + return out.mul_(1.0 / (n * d)); // Slightly faster than div_(n*d) +} + +Tensor fft_rfftfreq(int64_t n, double d, const TensorOptions& options) { + auto out = at::empty({n/2 + 1}, options); + return native::fft_rfftfreq_out(out, n, d); +} + +// If an array dim is specified, wraps them according to self.dim(). +// Otherwise returns a vector of all dims. +DimVector default_alldims(const Tensor& self, c10::optional dim_opt) { + DimVector dim; + if (dim_opt) { + IntArrayRef dim_unwrapped = *dim_opt; + dim.resize(dim_unwrapped.size()); + for (int64_t i = 0; i < dim.size(); ++i) { + dim[i] = maybe_wrap_dim(dim_unwrapped[i], self.dim()); } - } - if (complex_output) { - output_sizes[signal_ndim + 1] = 2; - } - - Tensor output = at::_fft_with_size(input, signal_ndim, complex_input, - complex_output, inverse, - checked_signal_sizes, - static_cast(normalization), - onesided, - output_sizes); - - // unflatten the batch dims - if (batch_ndim == 0) { - // slightly faster path for non-batch mode - output = output.squeeze(0); - } else if (batch_ndim > 1) { - auto output_ndim = self.dim() + static_cast(complex_output) - static_cast(complex_input); - std::vector unflatten_output_shape(output_ndim); - std::copy(self_shape.begin(), self_shape.begin() + batch_ndim, unflatten_output_shape.begin()); - std::copy(output_sizes.begin() + 1, output_sizes.end(), unflatten_output_shape.begin() + batch_ndim); - output = output.reshape(unflatten_output_shape); - } - return output; -} - -// Wrapper to preserve the historic signature of _fft_with_size -// NOTE: This is only used for torchscript backwards compatibility and the new -// signature with normalization modes should be used in all other cases -Tensor _fft_with_size(const Tensor& input, int64_t signal_ndim, - bool complex_input, bool complex_output, - bool inverse, IntArrayRef checked_signal_sizes, - bool normalized, bool onesided, - IntArrayRef output_sizes) { - fft_norm_mode norm; - if (normalized) { - norm = fft_norm_mode::by_root_n; } else { - norm = inverse ? fft_norm_mode::by_n : fft_norm_mode::none; + dim.resize(self.dim()); + std::iota(dim.begin(), dim.end(), 0); } - return at::_fft_with_size( - input, signal_ndim, complex_input, complex_output, inverse, - checked_signal_sizes, static_cast(norm), onesided, output_sizes); + return dim; } +Tensor fft_fftshift(const Tensor& x, c10::optional dim_opt) { + auto dim = default_alldims(x, dim_opt); + + IntArrayRef x_sizes = x.sizes(); + DimVector shift(dim.size()); + for (int64_t i = 0; i < dim.size(); ++i) { + shift[i] = x_sizes[dim[i]] / 2; + } + + return at::roll(x, shift, dim); +} + +Tensor fft_ifftshift(const Tensor& x, c10::optional dim_opt) { + auto dim = default_alldims(x, dim_opt); + + IntArrayRef x_sizes = x.sizes(); + DimVector shift(dim.size()); + for (int64_t i = 0; i < dim.size(); ++i) { + shift[i] = (x_sizes[dim[i]] + 1) / 2; + } + + return at::roll(x, shift, dim); +} + + // We call the following methods via CUDA hooks because they are really only // valid when CUDA is available. See native/cuda/CuFFTPlanCache.h for more details. int64_t _cufft_get_plan_cache_max_size(int64_t device_index) { @@ -560,52 +595,6 @@ void _cufft_clear_plan_cache(int64_t device_index) { detail::getCUDAHooks().cuFFTClearPlanCache(device_index); } -Tensor fft(const Tensor& self, const int64_t signal_ndim, const bool normalized) { - TORCH_WARN_ONCE( - "The function torch.fft is deprecated and will be removed in PyTorch 1.8. " - "Use the new torch.fft module functions, instead, by importing torch.fft " - "and calling torch.fft.fft or torch.fft.fftn."); - return _fft(self, signal_ndim, /* complex_input */ true, - /* complex_output */ true, /* inverse */ false, {}, - normalized ? fft_norm_mode::by_root_n : fft_norm_mode::none, - /* onesided */ false); -} - -Tensor ifft(const Tensor& self, const int64_t signal_ndim, const bool normalized) { - TORCH_WARN_ONCE( - "The function torch.ifft is deprecated and will be removed in a future " - "PyTorch release. Use the new torch.fft module functions, instead, by " - "importing torch.fft and calling torch.fft.ifft or torch.fft.ifftn."); - return _fft(self, signal_ndim, /* complex_input */ true, - /* complex_output */ true, /* inverse */ true, {}, - normalized ? fft_norm_mode::by_root_n : fft_norm_mode::by_n, - /* onesided */ false); -} - -Tensor rfft(const Tensor& self, const int64_t signal_ndim, const bool normalized, - const bool onesided) { - TORCH_WARN_ONCE( - "The function torch.rfft is deprecated and will be removed in a future " - "PyTorch release. Use the new torch.fft module functions, instead, by " - "importing torch.fft and calling torch.fft.fft or torch.fft.rfft."); - return _fft(self, signal_ndim, /* complex_input */ false, - /* complex_output */ true, /* inverse */ false, {}, - normalized ? fft_norm_mode::by_root_n : fft_norm_mode::none, - onesided); -} - -Tensor irfft(const Tensor& self, const int64_t signal_ndim, const bool normalized, - const bool onesided, IntArrayRef signal_sizes) { - TORCH_WARN_ONCE( - "The function torch.irfft is deprecated and will be removed in a future " - "PyTorch release. Use the new torch.fft module functions, instead, by " - "importing torch.fft and calling torch.fft.ifft or torch.fft.irfft."); - return _fft(self, signal_ndim, /* complex_input */ true, - /* complex_output */ false, /* inverse */ true, signal_sizes, - normalized ? fft_norm_mode::by_root_n : fft_norm_mode::by_n, - onesided); -} - template static Stream& write_opt(Stream& SS, const optional& value) { if (value) { @@ -646,9 +635,21 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional hop auto win_length = win_lengthOpt.value_or(n_fft); const bool return_complex = return_complexOpt.value_or( self.is_complex() || (window.defined() && window.is_complex())); - if (!return_complexOpt && !return_complex) { - TORCH_WARN("stft will return complex tensors by default in future, use" - " return_complex=False to preserve the current output format."); + if (!return_complex) { + if (!return_complexOpt.has_value()) { + TORCH_WARN_ONCE( + "stft will soon require the return_complex parameter be given for real inputs, " + "and will further require that return_complex=True in a future PyTorch release." + ); + } + + + // TORCH_WARN_ONCE( + // "stft with return_complex=False is deprecated. In a future pytorch " + // "release, stft will return complex tensors for all inputs, and " + // "return_complex=False will raise an error.\n" + // "Note: you can still call torch.view_as_real on the complex output to " + // "recover the old return format."); } if (!at::isFloatingType(self.scalar_type()) && !at::isComplexType(self.scalar_type())) { @@ -717,12 +718,13 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional hop const bool complex_fft = input.is_complex(); const auto onesided = onesidedOpt.value_or(!complex_fft); + const fft_norm_mode norm = normalized ? fft_norm_mode::by_root_n : fft_norm_mode::none; Tensor out; if (complex_fft) { TORCH_CHECK(!onesided, "Cannot have onesided output if window or input is complex"); - out = at::native::fft(at::view_as_real(input), 1, normalized); + out = at::_fft_c2c(input, input.dim() - 1, static_cast(norm), /*forward=*/true); } else { - out = at::native::rfft(input, 1, normalized, onesided); + out = at::_fft_r2c(input, input.dim() - 1, static_cast(norm), onesided); } out.transpose_(1, 2); @@ -731,12 +733,28 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional hop } if (return_complex) { - return at::view_as_complex(out); - } else { return out; + } else { + return at::view_as_real(out); } } +// Create complex tensor from the old style of real tensor with size=(..., 2) +// This is to support istft in the transition to requiring complex input. +// NOTE: This may return a view of the input tensor, or might clone if necessary +static Tensor as_complex(const Tensor& self) { + const bool can_view_as_complex = [&]{ + auto strides = self.strides(); + for (int64_t i = 0; i + 1 < strides.size(); ++i) { + if (strides[i] % 2 != 0) { + return false; + } + } + return strides.back() == 1 && self.storage_offset() % 2 == 0; + }(); + return at::view_as_complex(can_view_as_complex ? self : self.clone(MemoryFormat::Contiguous)); +} + /* Inverse Short-time Fourier Transform * * This is modeled after librosa but with support for complex time-domain @@ -763,6 +781,11 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho const auto hop_length = hop_lengthOpt.value_or(n_fft >> 2); const auto win_length = win_lengthOpt.value_or(n_fft); + if (!self.is_complex()) { + TORCH_WARN_ONCE( + "istft will require a complex-valued input tensor in a future PyTorch release. " + "Matching the output from stft with return_complex=True. "); + } Tensor input = self.is_complex() ? at::view_as_real(self) : self; const auto input_dim = input.dim(); const auto n_frames = input.size(-2); @@ -833,16 +856,19 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho input = input.unsqueeze(0); } - input = input.transpose(1, 2); // size: (channel, n_frames, fft_size, 2) + input = as_complex(input.transpose(1, 2)); // size: (channel, n_frames, fft_size, 2) + const fft_norm_mode norm = normalized ? fft_norm_mode::by_root_n : fft_norm_mode::by_n; if (return_complex) { TORCH_CHECK(!onesided, "Cannot have onesided output if window or input is complex"); - input = at::native::ifft(input, 1, normalized); // size: (channel, n_frames, n_fft) - input = at::view_as_complex(input); + input = at::_fft_c2c(input, input.dim() - 1, static_cast(norm), /*forward=*/false); // size: (channel, n_frames, n_fft) } else { TORCH_CHECK(!window.defined() || !window.is_complex(), "Complex windows are incompatible with return_complex=False"); - input = at::native::irfft(input, 1, normalized, onesided, {n_fft,}); // size: (channel, n_frames, n_fft) + if (!onesided) { + input = input.slice(-1, 0, n_fft / 2 + 1); + } + input = at::_fft_c2r(input, input.dim() - 1, static_cast(norm), n_fft); // size: (channel, n_frames, n_fft) } TORCH_INTERNAL_ASSERT(input.size(2) == n_fft); @@ -907,4 +933,97 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho onesidedOpt, lengthOpt, /*return_complex=*/false); } +void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) { + const auto input_sizes = input.sizes(); + const auto input_strides = input.strides(); + TORCH_CHECK(dim_.size() > 0); + DimVector dim(dim_.begin(), dim_.end()); + at::maybe_wrap_dims(dim, input_strides.size()); + + if (input.numel() == 0 || input_sizes[dim.back()] <= 2) { + return; // No elements need writing + } + + // Small dimensions may be treated as batch dims since they don't get mirrored + dim.erase( + std::remove_if(dim.begin(), dim.end(), [&](int64_t dim) { + return (input_sizes[dim] <= 2); + }), + dim.end()); + + // Use TensorIterator to coalesce batch dimensions + // NOTE: Can't use TensorIterator loops because we need negative strides + auto iter = TensorIteratorConfig() + .add_output(input) + .add_input(input) + .resize_outputs(false) + .declare_static_shape(input_sizes, dim) + .build(); + + const auto iter_strides = iter.strides(0); + const auto iter_sizes = iter.shape(); + const auto ndim = iter_strides.size() + dim.size(); + DimVector in_strides(ndim), signal_half_sizes(ndim); + // Take coalesced batch dimensions from TensorIterator + std::copy(iter_strides.begin(), iter_strides.end(), in_strides.begin()); + std::copy(iter_sizes.begin(), iter_sizes.end(), signal_half_sizes.begin()); + + // Take transformed dimensions directly from the input + const auto element_size = iter.element_size(0); + for (int64_t i = 0; i < dim.size(); ++i) { + // Convert to byte strides to match TensorIterator + in_strides[iter_strides.size() + i] = input_strides[dim[i]] * element_size; + signal_half_sizes[iter_strides.size() + i] = input_sizes[dim[i]]; + } + + // For the last dimension, use negative strides to perform the mirroring + signal_half_sizes.back() = (input_sizes[dim.back()] - 1) / 2; + auto out_strides = in_strides; + out_strides.back() *= -1; + + auto* data_ptr = static_cast(input.data_ptr()); + const auto* in_data = data_ptr + input_strides[dim.back()] * element_size; + auto* out_data = data_ptr + ( + input_strides[dim.back()] * (input_sizes[dim.back()] - 1) * element_size); + + // Reorder dimensions by stride to maximize data locality + DimVector dim_permute(ndim); + std::iota(dim_permute.begin(), dim_permute.end(), 0); + std::sort(dim_permute.begin(), dim_permute.end(), + [&](auto dim1, auto dim2) { + return in_strides[dim1] < in_strides[dim2]; + }); + + DimVector temp(ndim); + auto apply_permutation = [&] (DimVector & vec) { + // Do permuted index copy into a temporary, then copy back + for (int64_t i = 0; i < ndim; ++i) { + temp[i] = vec[dim_permute[i]]; + } + vec = temp; + }; + apply_permutation(in_strides); + apply_permutation(out_strides); + apply_permutation(signal_half_sizes); + + // Find dims.slice(dims.size() - 1) in the new permuted order. + // These are the dimensions that need explicit Hermitian mirroring + DimVector mirror_dims; + mirror_dims.reserve(dim.size() - 1); + for (int64_t i = 0; i < ndim; ++i) { + if (dim_permute[i] >= iter_strides.size() && // Not a batch dimension + dim_permute[i] != ndim - 1) { // Not the last dim, which is mirrored separately with negative strides + mirror_dims.push_back(i); + } + } + TORCH_INTERNAL_ASSERT(mirror_dims.size() == dim.size() - 1); + + // Dispatch to CPU or CUDA kernel to do the actual conjugate mirroring + fft_fill_with_conjugate_symmetry_stub( + input.device().type(), input.scalar_type(), + mirror_dims, signal_half_sizes, in_strides, in_data, out_strides, out_data); +} + +DEFINE_DISPATCH(fft_fill_with_conjugate_symmetry_stub); + }} // at::native diff --git a/aten/src/ATen/native/SpectralOpsUtils.h b/aten/src/ATen/native/SpectralOpsUtils.h index f498da5adc985..bd38257d12755 100644 --- a/aten/src/ATen/native/SpectralOpsUtils.h +++ b/aten/src/ATen/native/SpectralOpsUtils.h @@ -3,6 +3,7 @@ #include #include #include +#include namespace at { namespace native { @@ -62,4 +63,18 @@ inline int64_t infer_ft_complex_to_real_onesided_size(int64_t complex_size, } } +using fft_fill_with_conjugate_symmetry_fn = + void (*)(ScalarType dtype, IntArrayRef mirror_dims, IntArrayRef half_sizes, + IntArrayRef in_strides, const void* in_data, + IntArrayRef out_strides, void* out_data); +DECLARE_DISPATCH(fft_fill_with_conjugate_symmetry_fn, fft_fill_with_conjugate_symmetry_stub); + +// In real-to-complex transform, cuFFT and MKL only fill half of the values +// due to conjugate symmetry. This function fills in the other half of the full +// fft by using the Hermitian symmetry in the signal. +// self should be the shape of the full signal and dims.back() should be the +// one-sided dimension. +// See NOTE [ Fourier Transform Conjugate Symmetry ] +TORCH_API void _fft_fill_with_conjugate_symmetry_(const Tensor& self, IntArrayRef dims); + }} // at::native diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index ad6625308ff50..3ced0cf5eb524 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -135,6 +135,26 @@ static Tensor reshape_indexer(const Tensor& index, int64_t dims_before, int64_t return index.reshape(shape); } +static ptrdiff_t dataOffset(const Tensor& tensor, ptrdiff_t linearIndex) { + auto size = tensor.sizes(); + auto stride = tensor.strides(); + int nDim = tensor.dim(); + ptrdiff_t dataOffset = 0; + for (int i = nDim - 1; i >= 0; i--) { + dataOffset += (linearIndex % size[i]) * stride[i]; + linearIndex /= size[i]; + } + return dataOffset; +} + +static inline int64_t wrapLinearIndex(int64_t linearIndex, int64_t numel) { + return linearIndex < 0 ? linearIndex + numel : linearIndex; +} + +static inline void checkLinearIndex(int64_t linearIndex, int64_t numel) { + TORCH_CHECK(linearIndex < numel && linearIndex >= -numel, "out of range: ", linearIndex, " out of ", numel); +} + AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list) { int64_t element_size_bytes = src.element_size(); @@ -186,7 +206,7 @@ AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list) } } -static AdvancedIndex make_info(Tensor self, TensorList orig) { +static AdvancedIndex make_info(Tensor self, const torch::List>& orig) { checkIndexTensorTypes(orig); // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors auto indices = expandTensors(self, orig); @@ -261,7 +281,7 @@ static TensorIterator make_index_out_iterator(const AdvancedIndex& info, Tensor& return config.build(); } -Tensor index(const Tensor & self, TensorList indices) { +Tensor index(const Tensor & self, const torch::List>& indices) { TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); auto info = make_info(self, indices); @@ -270,9 +290,36 @@ Tensor index(const Tensor & self, TensorList indices) { return iter.output(); } -Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices) { +Tensor quantized_index(const Tensor & self, const torch::List>& indices) { + TORCH_INTERNAL_ASSERT( + self.qscheme() == c10::kPerTensorAffine || + self.qscheme() == c10::kPerTensorSymmetric, + "Indexing is only supported for per-Tensor quantized Tensors."); + + // For now, this is a naive implementation which does dq -> index -> q. + // TODO(future PR): improve performance by removing the copies. + const auto& self_dq = self.dequantize(); + + TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); + + auto info = make_info(self_dq, indices); + auto iter = make_index_iterator(info); + index_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides); + at::Tensor res = iter.output(); + + return at::quantize_per_tensor( + res, self.q_scale(), self.q_zero_point(), self.scalar_type()); +} + +Tensor& index_out(Tensor& result, const Tensor & self, const torch::List>& indices) { TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); at::assert_no_internal_overlap(result); + at::assert_no_overlap(result, self); + for (const c10::optional& index: indices) { + if (index.has_value()) { + at::assert_no_overlap(result, *index); + } + } auto info = make_info(self, indices); auto iter = make_index_out_iterator(info, result); @@ -280,26 +327,31 @@ Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices) { return result; } -Tensor index_put(const Tensor & self, TensorList indices, const Tensor & value, bool accumulate) { +Tensor index_put(const Tensor & self, const torch::List>& indices, const Tensor & value, bool accumulate) { return self.clone(at::MemoryFormat::Preserve).index_put_(indices, value, accumulate); } -Tensor & _index_put_impl_(Tensor & self, TensorList indices, const Tensor & value, const bool accumulate, const bool unsafe) { - TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); - if (accumulate && self.device().type() == kCUDA) { - TORCH_CHECK(value.device() == self.device(), "expected device ", self.device(), " but got device ", - value.device(), " for value tensor"); - index_put_accum_stub(self.device().type(), self, indices, value, unsafe); - return self; - } - +Tensor & _index_put_impl_(Tensor & self, const torch::List>& indices, const Tensor & value, const bool accumulate, const bool unsafe) { + TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); if (at::has_internal_overlap(self) == MemOverlap::YES) { TORCH_WARN( "Use of index_put_ on expanded tensors is deprecated. " "Please clone() the tensor before performing this operation. " "This also applies to advanced indexing e.g. tensor[indices] = tensor"); } - at::assert_no_partial_overlap(self, value); + at::assert_no_overlap(self, value); + for (const c10::optional& index: indices) { + if (index.has_value()) { + at::assert_no_overlap(self, *index); + } + } + + if (accumulate && self.device().type() == kCUDA) { + TORCH_CHECK(value.device() == self.device(), "expected device ", self.device(), " but got device ", + value.device(), " for value tensor"); + index_put_accum_stub(self.device().type(), self, indices, value, unsafe); + return self; + } auto info = make_info(self, indices); auto iter = make_index_put_iterator(info, value); @@ -308,14 +360,20 @@ Tensor & _index_put_impl_(Tensor & self, TensorList indices, const Tensor & valu } -Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value, const bool accumulate) { +Tensor & index_put_(Tensor & self, const torch::List>& indices, const Tensor & value, const bool accumulate) { return at::_index_put_impl_(self, indices, value, accumulate, /*unsafe=*/false); } Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) { + // See note [Writing Nondeterministic Operations] + // Nondeterministic when index contains duplicate entries + at::globalContext().alertNotDeterministic("index_copy"); dim = maybe_wrap_dim(dim, self.dim()); TORCH_CHECK_INDEX(index.dim() < 2, "index_copy_(): Index should have dimension 1 or 0 (got ", index.dim(), ")"); + at::assert_no_internal_overlap(self); + at::assert_no_overlap(self, index); + at::assert_no_overlap(self, source); int64_t numIndices = index.numel(); if (source.dim() == 0 && numIndices != 1) { @@ -361,7 +419,8 @@ Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const T auto numel = index.numel(); TORCH_CHECK_INDEX(index.dim() <= 1, "index_add_(): Index is supposed to be a vector"); - TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_add_(): Expected dtype int64 for index"); + TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, + "index_add_(): Expected dtype int32/int64 for index"); TORCH_CHECK(self.scalar_type() == source.scalar_type(), "index_add_(): self and source must have the same scalar type"); TORCH_CHECK(dim == 0 || dim < source.dim(), @@ -370,11 +429,10 @@ Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const T "index_add_(): Number of indices should be equal to self.size(dim)"); at::assert_no_internal_overlap(self); - at::assert_no_partial_overlap(self, index); - at::assert_no_partial_overlap(self, source); + at::assert_no_overlap(self, index); + at::assert_no_overlap(self, source); auto index_contig = index.contiguous(); - auto index_data = index_contig.data_ptr(); if (self.dim() > 1) { // Equivalent to: @@ -394,32 +452,41 @@ Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const T auto self_dim_size = self.size(dim); auto iter = TensorIterator::binary_op(selfSlice, selfSlice, sourceSlice); - for (auto i = 0; i < numel; i++) { - auto self_i = index_data[i]; - TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); - auto self_data = static_cast(selfSlice.data_ptr()) + self_i * self_stride_bytes; - auto source_data = static_cast(sourceSlice.data_ptr()) + i * source_stride_bytes; - iter.unsafe_replace_operand(0, self_data); - iter.unsafe_replace_operand(1, self_data); - iter.unsafe_replace_operand(2, source_data); - add_stub(iter.device_type(), iter, 1); - } + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cpu_", [&] () { + auto index_data = index_contig.data_ptr(); + for (auto i = 0; i < numel; i++) { + auto self_i = index_data[i]; + TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); + auto self_data = static_cast(selfSlice.data_ptr()) + self_i * self_stride_bytes; + auto source_data = static_cast(sourceSlice.data_ptr()) + i * source_stride_bytes; + iter.unsafe_replace_operand(0, self_data); + iter.unsafe_replace_operand(1, self_data); + iter.unsafe_replace_operand(2, source_data); + add_stub(iter.device_type(), iter, 1); + } + }); } else { TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must one or zero for given self.dim() (", self.dim(), ")"); - AT_DISPATCH_ALL_TYPES(self.scalar_type(), "index_add_", [&] { + // explicitly capture all required variables to work around windows build + // TODO: fix this when windows can correctly capture variables in nested lambda + AT_DISPATCH_ALL_TYPES_AND_COMPLEX(self.scalar_type(), "index_add_", [&self, &source, &dim, &index_contig, &numel] { auto self_stride = self.dim() == 0 ? 1 : self.stride(dim); auto source_stride = source.dim() == 0 ? 1 : source.stride(dim); // TODO: Maybe TensorAccessor can beused here? auto* self_ptr = self.data_ptr(); auto* source_ptr = source.data_ptr(); - for (auto i = 0; i < numel; i++) { - auto self_i = index_data[i]; - TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self.numel()), "index out of range in self"); - scalar_t *self_ip = self_ptr + self_i * self_stride; - *self_ip += *(source_ptr + i * source_stride); - } + AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_add_cpu_", + [&index_contig, &numel, &self, &self_ptr, &self_stride, &source_ptr, &source_stride] { + auto index_data = index_contig.data_ptr(); + for (auto i = 0; i < numel; i++) { + auto self_i = index_data[i]; + TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self.numel()), "index out of range in self"); + scalar_t *self_ip = self_ptr + self_i * self_stride; + *self_ip += *(source_ptr + i * source_stride); + } + }); }); } return self; @@ -429,17 +496,100 @@ Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const T return self.clone(at::MemoryFormat::Preserve).index_add_(dim, index, source); } +// Check that indices fall within dimension array size +// Avoid redispatch call to min/max +template +static void check_indexarray_range( + const IndexType* indices, + int64_t n, + IndexType indexing_axis_dim) { + for (auto i = 0; i < n; ++i) { + auto idx = indices[i]; + TORCH_CHECK( + 0 <= idx && idx < indexing_axis_dim, + "INDICES element is out of DATA bounds, id=", + idx, + " axis_dim=", + indexing_axis_dim); + } +} + +Tensor & index_select_out_cpu_dim1_( + Tensor & result_contig, const Tensor & self, const Tensor & index_contig) { + + auto self_contig = self.contiguous(); + const caffe2::TypeMeta dataType = self_contig.dtype(); + size_t item_bytesize = dataType.itemsize(); + + auto out = static_cast(result_contig.data_ptr()); + + auto src_base = static_cast(self_contig.data_ptr()); + + auto self_sizes = self_contig.sizes(); + auto outer_dims_product = c10::size_to_dim_(1, self_sizes); + auto block_size = c10::size_from_dim_(2, self_sizes); + auto block_bytesize = block_size * item_bytesize; + + auto src_indexing_axis_dim = self_sizes[1]; + auto src_batch_bytesize = self_sizes[1] * block_bytesize; + auto N = index_contig.numel(); + + auto gathered_batch_bytesize = N * block_bytesize; + + AT_DISPATCH_INDEX_TYPES( + index_contig.scalar_type(), "batch_index_select_compute", [&]() { + + const auto* idxs = index_contig.data_ptr(); + check_indexarray_range(idxs, N, src_indexing_axis_dim); + + // Special-case single-float copy for efficiency + if (self.scalar_type() == ScalarType::Float && block_size == 1) { + for (auto batch = 0; batch < outer_dims_product; ++batch) { + const float* src_floats = + (const float*)(src_base + batch * src_batch_bytesize); + float* dst_floats = (float*)(out + batch * gathered_batch_bytesize); + + for (auto i = 0; i < N; ++i) { + auto idx = idxs[i]; + if (idx < 0) { + idx = idx + src_indexing_axis_dim; + } + dst_floats[i] = src_floats[idx]; + } + } + } else { + // outer_dims_product specifies how many times we repeat inner dimensions, + // so we just iterate over it to cover all outer dimensions. + for (auto batch = 0; batch < outer_dims_product; ++batch) { + for (auto i = 0; i < N; ++i) { + auto idx = idxs[i]; + if (idx < 0) { + idx = idx + src_indexing_axis_dim; + } + + auto src = src_base + batch * src_batch_bytesize + idx * block_bytesize; + auto dst = out + batch * gathered_batch_bytesize + i * block_bytesize; + memcpy(dst, src, block_bytesize); + } + } + } + }); + return result_contig; +} + Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim, const Tensor & index) { dim = maybe_wrap_dim(dim, self.dim()); auto numel = index.numel(); TORCH_CHECK_INDEX(index.dim() <= 1, "index_select(): Index is supposed to be a vector"); - TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_select(): Expected dtype int64 for index"); + TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_select(): Expected dtype int32 or int64 for index"); TORCH_CHECK(self.scalar_type() == result.scalar_type(), "index_select(): self and result must have the same scalar type"); TORCH_CHECK(dim == 0 || dim < self.dim(), "index_select(): Indexing dim ", dim, " is out of bounds of tensor"); at::assert_no_internal_overlap(result); + at::assert_no_overlap(result, self); + at::assert_no_overlap(result, index); auto result_size = self.sizes().vec(); if (self.dim() > 0) { @@ -448,13 +598,17 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim result.resize_(result_size); auto index_contig = index.contiguous(); - auto index_data = index_contig.data_ptr(); if (self.dim() > 1) { if (numel == 0 || self.numel() == 0) { return result; } + if (dim == 1 && result.is_contiguous()) { + // fast pass + return index_select_out_cpu_dim1_(result, self, index_contig); + } + auto selfSlice = self.select(dim, 0); auto resultSlice = result.select(dim, 0); auto selfSlice_data = selfSlice.data_ptr(); @@ -472,17 +626,26 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim .build(); auto grain_size = at::internal::GRAIN_SIZE; - auto outer_loop = [&](int64_t start, int64_t end) { + auto outer_loop = + // explicitly capture all required variables to work around windows build + // TODO: fix this when windows can correctly capture variables in nested lambda + [&index_contig, &iter, &self_dim_size, &selfSlice_data, &self_stride_bytes, &resultSlice_data, + &result_stride_bytes](int64_t start, int64_t end) { auto sub_iter = TensorIterator(iter); - for (int64_t i = start; i < end; i++) { - auto self_i = index_data[i]; - TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); - auto self_data = static_cast(selfSlice_data) + self_i * self_stride_bytes; - auto result_data = static_cast(resultSlice_data) + i * result_stride_bytes; - sub_iter.unsafe_replace_operand(0, result_data); - sub_iter.unsafe_replace_operand(1, self_data); - copy_stub(sub_iter.device_type(), sub_iter, false); - } + AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_", + [&index_contig, &start, &end, &sub_iter, &self_dim_size, &selfSlice_data, &self_stride_bytes, + &resultSlice_data, &result_stride_bytes] () { + auto index_data = index_contig.data_ptr(); + for (int64_t i = start; i < end; i++) { + auto self_i = index_data[i]; + TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); + auto self_data = static_cast(selfSlice_data) + self_i * self_stride_bytes; + auto result_data = static_cast(resultSlice_data) + i * result_stride_bytes; + sub_iter.unsafe_replace_operand(0, result_data); + sub_iter.unsafe_replace_operand(1, self_data); + copy_stub(sub_iter.device_type(), sub_iter, false); + }; + }); }; // parallel on inner loop in case the slice is large enough; @@ -493,14 +656,23 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim // use a fast loop when self and result are contiguous and of the same data type if (iter.is_contiguous() && self.scalar_type() == result.scalar_type()) { auto slice_size_bytes = slice_size * elementSize(self.scalar_type()); - at::parallel_for(0, numel, grain_size / slice_size, [&](int64_t start, int64_t end) { - for (int64_t i = start; i < end; i++) { - auto self_i = index_data[i]; - TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); - auto self_data = static_cast(selfSlice_data) + self_i * self_stride_bytes; - auto result_data = static_cast(resultSlice_data) + i * result_stride_bytes; - memcpy(result_data, self_data, slice_size_bytes); - } + // explicitly capture all required variables to work around windows build + // TODO: fix this when windows can correctly capture variables in nested lambda + at::parallel_for(0, numel, grain_size / slice_size, + [&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data, + &self_stride_bytes, &resultSlice_data, &result_stride_bytes](int64_t start, int64_t end) { + AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_", + [&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data, + &self_stride_bytes, &resultSlice_data, &result_stride_bytes, &start, &end] () { + auto index_data = index_contig.data_ptr(); + for (int64_t i = start; i < end; i++) { + auto self_i = index_data[i]; + TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); + auto self_data = static_cast(selfSlice_data) + self_i * self_stride_bytes; + auto result_data = static_cast(resultSlice_data) + i * result_stride_bytes; + memcpy(result_data, self_data, slice_size_bytes); + } + }); }); } else { at::parallel_for(0, numel, grain_size / slice_size, outer_loop); @@ -508,20 +680,26 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim } } else { TORCH_CHECK(result.dim() <= 1, "result.dim() (", result.dim(), ") must one or zero for given self.dim() (", self.dim(), ")"); - - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "index_select", [&] { + // explicitly capture all required variables to work around windows build + // TODO: fix this when windows can correctly capture variables in nested lambda + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, self.scalar_type(), "index_select", + [&index_contig, &self, &result, &dim, &numel] { auto self_stride = self.dim() == 0 ? 1 : self.stride(dim); auto result_stride = result.dim() == 0 ? 1 : result.stride(dim); auto self_data_ptr = self.data_ptr(); auto result_data_ptr = result.data_ptr(); auto self_numel = self.numel(); - for (auto i = 0; i < numel; i++) { - auto self_i = index_data[i]; - TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self"); - scalar_t *self_ip = self_data_ptr + self_i * self_stride; - *(result_data_ptr + i * result_stride) = *self_ip; - } + AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_", + [&index_contig, &numel, &self_numel, &self_data_ptr, &self_stride, &result_data_ptr, &result_stride] { + auto index_data = index_contig.data_ptr(); + for (auto i = 0; i < numel; i++) { + auto self_i = index_data[i]; + TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self"); + scalar_t *self_ip = self_data_ptr + self_i * self_stride; + *(result_data_ptr + i * result_stride) = *self_ip; + } + }); }); } @@ -553,6 +731,9 @@ Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & gather_out_cpu_cuda(Tensor & result, const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) { result.resize_(index.sizes()); + at::assert_no_internal_overlap(result); + at::assert_no_overlap(result, self); + at::assert_no_partial_overlap(result, index); gather_stub(result.device().type(), result, self, dim, index); return result; } @@ -572,6 +753,9 @@ Tensor gather_backward(const Tensor& grad, const Tensor& self, int64_t dim, cons Tensor & scatter_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) { TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long, "scatter_(): Expected dtype int64 for index."); + at::assert_no_internal_overlap(self); + at::assert_no_overlap(self, source); + at::assert_no_overlap(self, index); scatter_stub(self.device().type(), self, dim, index, source); return self; } @@ -579,6 +763,8 @@ Tensor & scatter_(Tensor & self, int64_t dim, const Tensor & index, const Tensor Tensor & scatter_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar source) { TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long, "scatter_(): Expected dtype int64 for index."); + at::assert_no_internal_overlap(self); + at::assert_no_overlap(self, index); scatter_fill_stub(self.device().type(), self, dim, index, source); return self; } @@ -592,7 +778,7 @@ SCATTER_GATHER_OP get_operator_enum(const std::string& reduce) { } else { TORCH_CHECK(false, - "reduce argument must be either of add, subtract, multiply or divide."); + "reduce argument must be either add or multiply."); } } @@ -602,6 +788,8 @@ Tensor& scatter_scalar_reduce_(Tensor& self, const int64_t dim, const Tensor& in "scatter_(): Expected dtype int64 for index."); TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()), "scatter_(): Expected floating or complex type for self."); + at::assert_no_internal_overlap(self); + at::assert_no_overlap(self, index); SCATTER_GATHER_OP op = get_operator_enum(reduce); scatter_scalar_reduce_stub(self.device().type(), self, dim, index, value, op); return self; @@ -613,6 +801,9 @@ Tensor & scatter_reduce_(Tensor & self, const int64_t dim, const Tensor & index, "scatter_(): Expected dtype int64 for index"); TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()), "scatter_(): Expected floating or complex type for self."); + at::assert_no_internal_overlap(self); + at::assert_no_overlap(self, index); + at::assert_no_overlap(self, src); SCATTER_GATHER_OP op = get_operator_enum(reduce); scatter_reduce_stub(self.device().type(), self, dim, index, src, op); return self; @@ -629,6 +820,9 @@ Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, Scalar so Tensor & scatter_add_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) { TORCH_CHECK_INDEX(index.scalar_type() == ScalarType::Long, "scatter_(): Expected dtype int64 for index."); + at::assert_no_internal_overlap(self); + at::assert_no_overlap(self, index); + at::assert_no_overlap(self, src); scatter_add_stub(self.device().type(), self, dim, index, src); return self; } @@ -725,8 +919,8 @@ static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self, "masked_select(): self and result must have the same scalar type"); at::assert_no_internal_overlap(result); - at::assert_no_partial_overlap(result, self); - at::assert_no_partial_overlap(result, mask); + at::assert_no_overlap(result, self); + at::assert_no_overlap(result, mask); if (mask.dtype() == at::ScalarType::Byte) { TORCH_WARN("masked_select received a mask with dtype torch.uint8, this behavior is now deprecated," \ @@ -815,6 +1009,80 @@ Tensor masked_select_backward(const Tensor& grad, const Tensor& input, const Ten return result.masked_scatter_(mask, grad); } +void take_out_cpu_template( + Tensor& output, + Tensor const& input, + Tensor const& index) +{ + TORCH_CHECK(output.device().type() == at::kCPU, "device type of output (", output.device().type(), ") is not CPU"); + TORCH_CHECK(input.device().type() == at::kCPU, "device type of input (", input.device().type(), ") is not CPU"); + TORCH_CHECK(index.device().type() == at::kCPU, "device type of index (", index.device().type(), ") is not CPU"); + + TORCH_CHECK(output.layout() == Layout::Strided, "take() only supports strided layout, got layout: ", + output.layout(), " on output tensor"); + TORCH_CHECK(input.layout() == Layout::Strided, "take() only supports strided layout, got layout: ", + input.layout(), " on input tensor"); + TORCH_CHECK(index.layout() == Layout::Strided, "take() only supports strided layout, got layout: ", + index.layout(), " on index tensor"); + + TORCH_CHECK(output.scalar_type() == input.scalar_type(), "output and input scalar type must match.", + "But got different types: ", output.scalar_type(), " and ", input.scalar_type()); + TORCH_CHECK(index.scalar_type() == kLong, "index must be an int64 tensor"); + + output.resize_(index.sizes()); + auto output_contiguous = output.contiguous(); + auto index_continuous = index.contiguous(); + bool is_contiguous = input.is_contiguous(); + auto input_size = input.numel(); + at::assert_no_internal_overlap(output); + at::assert_no_partial_overlap(output, index); + at::assert_no_overlap(output, input); + + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::Half, input.scalar_type(), "take_cpu", [&] { + auto output_data = output_contiguous.data_ptr(); + auto input_data = input.data_ptr(); + auto index_data = index.data_ptr(); + + // Exceptions must not be thrown across parallel sections, so we + // record the position of the invalid index and throw the exception after the + // loop. + std::atomic invalidIdxPos(-1); + + at::parallel_for(0, index.numel(), at::internal::GRAIN_SIZE, + [&](int64_t start, int64_t end) { + for (auto i = start; i < end; i++) { + int64_t idx = index_data[i]; + if (idx < input_size && idx >= -input_size) { + idx = wrapLinearIndex(idx, input_size); + if (is_contiguous) { + output_data[i] = input_data[idx]; + } else { + output_data[i] = input_data[dataOffset(input, idx)]; + } + } else { + int64_t tmp = -1; + invalidIdxPos.compare_exchange_strong(tmp, i); + } + } + }); + + if (invalidIdxPos >= 0) { + checkLinearIndex(index_data[invalidIdxPos], input_size); + } + }); +} + +Tensor take_cpu(const Tensor& self, const Tensor& index) { + auto output = at::empty(index.sizes(), self.options()); + take_out_cpu_template(output, self, index); + return output; +} + +Tensor& take_out_cpu(Tensor& out, const Tensor& self, const Tensor& index) { + take_out_cpu_template(out, self, index); + return out; +} + Tensor take_backward(const Tensor& grad, const Tensor& input, const Tensor& index) { return at::zeros_like(input).put_(index, grad, true); } diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.h b/aten/src/ATen/native/TensorAdvancedIndexing.h index 560b461625467..0e0958606de1c 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.h +++ b/aten/src/ATen/native/TensorAdvancedIndexing.h @@ -15,7 +15,7 @@ enum class SCATTER_GATHER_OP: uint8_t {REDUCE_ADD, REDUCE_MULTIPLY}; using index_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides); using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate); -using index_put_accum_fn = void(*)(Tensor &, TensorList , const Tensor &, bool unsafe); +using index_put_accum_fn = void(*)(Tensor &, const c10::List> &, const Tensor &, bool unsafe); using masked_fill_fn = void(*)(TensorIterator &, Scalar scalar); using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride); @@ -42,6 +42,6 @@ DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub); DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub); DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub); -TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices); +TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, const c10::List>& indices); }} // namespace at::native diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 09ed60c19c32d..27db468a407cd 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -38,6 +38,8 @@ Tensor isclose(const Tensor& self, const Tensor& other, double rtol, double atol TORCH_CHECK(self.scalar_type() == other.scalar_type(), self.scalar_type(), " did not match ", other.scalar_type()); TORCH_CHECK(!(self.is_complex() && equal_nan), "isclose with equal_nan=True is not supported for complex inputs."); + TORCH_CHECK(!(self.is_quantized() || other.is_quantized()), + "isclose is not supported for quantized inputs."); // Checks that rtol and atol are non-negative // Note: consistent with Python's isclose but divergent from NumPy's, which @@ -104,7 +106,7 @@ Tensor isinf(const Tensor &self) { (at::isinf(at::imag(self))); } - return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "isinf", [&]() { + return AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(), "isinf", [&]() { return self.abs() == std::numeric_limits::infinity(); }); } @@ -168,7 +170,7 @@ Tensor isfinite(const Tensor& self) { return at::isfinite(self.abs()); } - return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "isfinite", [&]() { + return AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "isfinite", [&]() { return (self == self) * (self.abs() != std::numeric_limits::infinity()); }); } @@ -193,7 +195,35 @@ bool is_nonzero(const Tensor& self) { namespace { -static Tensor wrapped_scalar_tensor( +// DO NOT USE THIS -- it's just an implementation detail of wrapped_scalar tensor below. +at::Tensor scalar_to_tensor_default_dtype( + Scalar s, + const Device device = at::kCPU) { + if (s.isFloatingPoint()) { + return at::scalar_tensor( + s, at::device(device).dtype(at::get_default_dtype())); + } else if (s.isBoolean()) { + return at::scalar_tensor(s, at::device(device).dtype(at::kBool)); + } else if (s.isComplex()) { + return at::scalar_tensor( + s, at::device(device).dtype(at::get_default_complex_dtype())); + } else { + TORCH_INTERNAL_ASSERT(s.isIntegral(false)); + return at::scalar_tensor(s, at::device(device).dtype(at::kLong)); + } +} + +// TLDR: Don't call with `use_default_dtype` true -- this is only necessary to support the partial +// type-promotion that torch.where supports. Once torch.where fully supports type promotion, we +// won't need this function. +// +// Longer explanation: +// `use_default_dtype` is a bit of a hack because torch.where doesn't support type promotion, but +// does support `torch.where(tensor, scalar1, scalar2)` with default scalar types. The trickiness is we +// usually convert double scalars to doubles, and `set_wrapped_number` defines type promotion priority +// as being below tensor types rather than as the default dtype (perhaps we should?). This wouldn't matter +// if we just supported type normal type promotion on torch.where, however. +Tensor wrapped_scalar_tensor( Scalar scalar, Device device, bool use_default_dtype = false) { @@ -284,7 +314,6 @@ std::tuple mode_out(Tensor& values, Tensor& indices, } std::tuple max(const Tensor& self, int64_t dim, bool keepdim) { - TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors."); Tensor max_indices = at::empty({0}, self.options().dtype(kLong)); if (self.is_quantized()) { Tensor max = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type()))); @@ -299,7 +328,6 @@ std::tuple max(const Tensor& self, int64_t dim, bool keepdim) { static std::tuple max_out_impl(Tensor& max, Tensor& max_indices, const Tensor& self, int64_t dim, bool keepdim) { - TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors."); TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA, "max only supports CPU AND CUDA device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, @@ -312,6 +340,7 @@ static std::tuple max_out_impl(Tensor& max, Tensor& max_indic max_indices.device(), " for indices output"); dim = maybe_wrap_dim(dim, self.dim()); if (_dimreduce_return_trivial_no_ident(max, self, dim, keepdim, "max")) { + TORCH_CHECK(!self.is_complex(), "max does not support complex inputs."); AT_ASSERT(max.dim() == 0); max_indices.resize_({}).fill_(0); return std::forward_as_tuple(max, max_indices); @@ -323,7 +352,6 @@ static std::tuple max_out_impl(Tensor& max, Tensor& max_indic std::tuple max_out(Tensor& max, Tensor& max_indices, const Tensor& self, int64_t dim, bool keepdim) { - TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors."); auto result = [&]() { NoNamesGuard guard; return max_out_impl(max, max_indices, self, dim, keepdim); @@ -334,7 +362,6 @@ std::tuple max_out(Tensor& max, Tensor& max_indices, } std::tuple min(const Tensor& self, int64_t dim, bool keepdim) { - TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors."); Tensor min_indices = at::empty({0}, self.options().dtype(kLong)); if (self.is_quantized()) { Tensor min = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type()))); @@ -348,7 +375,6 @@ std::tuple min(const Tensor& self, int64_t dim, bool keepdim) { static std::tuple _aminmax_out_impl(Tensor& min, Tensor& max, const Tensor& self, int64_t dim, bool keepdim) { - TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors."); TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA, "min_max_val only supports CPU AND CUDA device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, @@ -362,6 +388,7 @@ static std::tuple _aminmax_out_impl(Tensor& min, Tensor& max dim = maybe_wrap_dim(dim, self.dim()); if (_dimreduce_return_trivial_no_ident(min, self, dim, keepdim, "min") && _dimreduce_return_trivial_no_ident(max, self, dim, keepdim, "max")) { + TORCH_CHECK(!self.is_complex(), "min_max does not support complex inputs."); return std::forward_as_tuple(min, max); } else { _aminmax_stub(self.device().type(), min, max, self, dim, keepdim); @@ -370,7 +397,6 @@ static std::tuple _aminmax_out_impl(Tensor& min, Tensor& max } std::tuple _aminmax(const Tensor& self, int64_t dim, bool keepdim) { - TORCH_CHECK(!self.is_complex(), "min_max is not yet implemented for complex tensors."); TORCH_CHECK(!self.is_quantized(), "min is not yet implemented for quantized tensors."); Tensor min = at::empty({0}, self.options()); @@ -382,7 +408,6 @@ std::tuple _aminmax(const Tensor& self, int64_t dim, bool keepdi static std::tuple min_out_impl(Tensor& min, Tensor& min_indices, const Tensor& self, int64_t dim, bool keepdim) { - TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors."); TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA, "min only supports CPU AND CUDA device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, @@ -395,6 +420,7 @@ static std::tuple min_out_impl(Tensor& min, Tensor& min_indic min_indices.device(), " for indices output"); dim = maybe_wrap_dim(dim, self.dim()); if (_dimreduce_return_trivial_no_ident(min, self, dim, keepdim, "min")) { + TORCH_CHECK(!self.is_complex(), "min does not support complex inputs."); AT_ASSERT(min.dim() == 0); min_indices.resize_({}).fill_(0); return std::forward_as_tuple(min, min_indices); @@ -406,7 +432,6 @@ static std::tuple min_out_impl(Tensor& min, Tensor& min_indic std::tuple min_out(Tensor& min, Tensor& min_indices, const Tensor& self, int64_t dim, bool keepdim) { - TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors."); auto result = [&]() { NoNamesGuard guard; return min_out_impl(min, min_indices, self, dim, keepdim); @@ -420,21 +445,17 @@ std::tuple min_out(Tensor& min, Tensor& min_indices, // Named tensor overloads std::tuple min(const Tensor& self, Dimname dim, bool keepdim) { - TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors."); return at::min(self, dimname_to_position(self, dim), keepdim); } std::tuple min_out(Tensor& min, Tensor& min_indices, const Tensor& self, Dimname dim, bool keepdim) { - TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors."); return at::min_out(min, min_indices, self, dimname_to_position(self, dim), keepdim); } std::tuple max(const Tensor& self, Dimname dim, bool keepdim) { - TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors."); return at::max(self, dimname_to_position(self, dim), keepdim); } std::tuple max_out(Tensor& max, Tensor& max_indices, const Tensor& self, Dimname dim, bool keepdim) { - TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors."); return at::max_out(max, max_indices, self, dimname_to_position(self, dim), keepdim); } Tensor argmax(const Tensor& self, Dimname dim, bool keepdim) { diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index 58df4cf110f76..d773de927efb6 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -29,10 +29,15 @@ static inline Tensor to_impl(const Tensor& self, const TensorOptions& options, b return self; } + bool pin_out = (non_blocking && self.is_cuda() && options.device().is_cpu() && + (options.layout() == c10::kStrided)); + if (memory_format == MemoryFormat::Preserve) { if (self.is_non_overlapping_and_dense()) { // Copy all strides - auto r = at::empty_strided(self.sizes(), self.strides(), options.memory_format(c10::nullopt)); + auto r = at::empty_strided(self.sizes(), + self.strides(), + options.memory_format(c10::nullopt).pinned_memory(pin_out)); r.copy_(self, non_blocking); return r; } else { @@ -40,7 +45,9 @@ static inline Tensor to_impl(const Tensor& self, const TensorOptions& options, b } } // See Note [Explicit nullopt MemoryFormat argument] - auto r = at::empty(self.sizes(), options.memory_format(memory_format), c10::nullopt); + auto r = at::empty(self.sizes(), + options.memory_format(memory_format).pinned_memory(pin_out), + c10::nullopt); r.copy_(self, non_blocking); return r; } @@ -56,7 +63,7 @@ Tensor to( !(options_.has_memory_format() && optional_memory_format.has_value()), "Cannot set memory_format both in TensorOptions and explicit argument; please delete " "the redundant setter."); - auto options = options_.merge_in(TensorOptions().memory_format(optional_memory_format)); + auto options = options_.merge_memory_format(optional_memory_format); TORCH_CHECK(options.requires_grad_opt() == c10::nullopt, "to(options) expects unset requires_grad flag, but got " @@ -99,7 +106,7 @@ Tensor to_dense_backward(const Tensor& grad, const Tensor& input_) { auto input = input_.coalesce(); return grad.sparse_mask(input); } else if (input_.layout() == c10::kMkldnn) { - return grad.to_mkldnn(); + return grad.to_mkldnn(input_.scalar_type()); } else { AT_ERROR("Unsupported input layout: ", input_.layout()); } @@ -107,7 +114,23 @@ Tensor to_dense_backward(const Tensor& grad, const Tensor& input_) { Tensor to_mkldnn_backward(const Tensor& grad, const Tensor& input_) { AT_ASSERT(input_.layout() == c10::kStrided); - return grad.to_dense(); + return grad.to_dense(input_.scalar_type()); +} + +Tensor view_dtype(const Tensor& self, ScalarType dtype) { + if (self.scalar_type() == dtype) { + return self; + } + auto type_meta = c10::scalarTypeToTypeMeta(dtype); + TORCH_CHECK(self.element_size() == type_meta.itemsize(), + "Viewing a tensor as a new dtype with a different number of bytes per element is not supported."); + Storage storage = self.storage(); + auto new_tensor = detail::make_tensor( + std::move(storage), self.key_set(), type_meta); + auto* impl = new_tensor.unsafeGetTensorImpl(); + impl->set_storage_offset(self.storage_offset()); + impl->set_sizes_and_strides(self.sizes(), self.strides()); + return new_tensor; } }} // namespace at::native diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 149aab7cfc225..10c2d6ee999c3 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -165,46 +165,9 @@ Tensor polar(const Tensor& abs, const Tensor& angle) { } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ empty ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Tensor empty_cpu(IntArrayRef size, const TensorOptions& options_, c10::optional optional_memory_format) { - - TORCH_CHECK( - !(options_.has_memory_format() && optional_memory_format.has_value()), - "Cannot set memory_format both in TensorOptions and explicit argument; please delete " - "the redundant setter."); - TensorOptions options = options_.merge_in(TensorOptions().memory_format(optional_memory_format)); - - AT_ASSERT(options.device().type() == DeviceType::CPU); - TORCH_INTERNAL_ASSERT(impl::variable_excluded_from_dispatch()); - check_size_nonnegative(size); - - c10::Allocator* allocator; - if (options.pinned_memory()) { - allocator = detail::getCUDAHooks().getPinnedMemoryAllocator(); - } else { - allocator = at::getCPUAllocator(); - } - - int64_t nelements = prod_intlist(size); - auto dtype = options.dtype(); - int64_t size_bytes = nelements * dtype.itemsize(); - auto storage_impl = c10::make_intrusive( - c10::StorageImpl::use_byte_size_t(), - size_bytes, - allocator->allocate(size_bytes), - allocator, - /*resizeable=*/true); - - auto tensor = detail::make_tensor( - std::move(storage_impl), at::DispatchKey::CPU, dtype); - // Default TensorImpl has size [0] - if (size.size() != 1 || size[0] != 0) { - tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); - } - - auto memory_format = options.memory_format_opt().value_or(MemoryFormat::Contiguous); - tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format); - - return tensor; +Tensor empty_cpu(IntArrayRef size, c10::optional dtype_opt, c10::optional layout_opt, + c10::optional device_opt, c10::optional pin_memory_opt, c10::optional memory_format_opt) { + return at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); } Tensor empty( @@ -224,9 +187,10 @@ Tensor empty( return result; } -Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, const TensorOptions& options) { +Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional dtype_opt, + c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { check_size_nonnegative(size); - auto t = at::native::empty_cpu({0}, options); + auto t = at::native::empty_cpu({0}, dtype_opt, layout_opt, device_opt, pin_memory_opt); at::native::resize_impl_cpu_(t.unsafeGetTensorImpl(), size, stride); return t; } @@ -278,7 +242,7 @@ Tensor empty_like( TensorOptions options = self.options() .merge_in(options_) - .merge_in(TensorOptions().memory_format(optional_memory_format)); + .merge_memory_format(optional_memory_format); TORCH_CHECK( !(options.layout() != kStrided && @@ -349,6 +313,12 @@ Tensor empty_like( if (memory_format == MemoryFormat::Preserve) { if (self.is_non_overlapping_and_dense()) { result = at::empty_strided(self.sizes(), self.strides(), options.memory_format(c10::nullopt)); + } else if (self.unsafeGetTensorImpl()->support_as_strided() && self.layout() == kStrided) { + // If input tensor is not dense and non-overlapping but strided, we will infer an output strides + // which keeps the layout permutation of the input tensor. + std::vector strides = infer_dense_strides(self.sizes(), self.strides()); + // See Note [Explicit nullopt MemoryFormat argument] + result = at::empty_strided(self.sizes(), strides, options.memory_format(c10::nullopt)); } else { // See Note [Explicit nullopt MemoryFormat argument] result = at::empty(self.sizes(), options.memory_format(self.suggest_memory_format()), c10::nullopt); @@ -368,15 +338,32 @@ Tensor empty_like( Tensor new_empty( const Tensor& self, IntArrayRef size, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt + ) { + auto dtype = dtype_opt.has_value() ? dtype_opt : optTypeMetaToScalarType(self.options().dtype_opt()); + auto layout = layout_opt.has_value() ? layout_opt : self.options().layout_opt(); + auto device = device_opt.has_value() ? device_opt : self.options().device_opt(); + auto pin_memory = pin_memory_opt.has_value() ? pin_memory_opt : self.options().pinned_memory_opt(); + return at::empty(size, dtype, layout, device, pin_memory, c10::nullopt); +} + +Tensor new_empty_strided( + const Tensor& self, + IntArrayRef size, + IntArrayRef stride, const TensorOptions& options ) { - return at::empty(size, self.options().merge_in(options)); + return at::empty_strided(size, stride, self.options().merge_in(options)); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eye ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tensor eye(int64_t n, const TensorOptions& options) { - return native::eye(n, -1, options); + // the default value of `m` equals to `n` + return native::eye(n, n, options); } Tensor eye(int64_t n, int64_t m, const TensorOptions& options) { @@ -385,15 +372,13 @@ Tensor eye(int64_t n, int64_t m, const TensorOptions& options) { } Tensor& eye_out_cpu(Tensor& result, int64_t n) { - return native::eye_out_cpu(result, n, -1); + // the default value of `m` equals to `n` + return native::eye_out_cpu(result, n, n); } Tensor& eye_out_cpu(Tensor& result, int64_t n, int64_t m) { TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n); - - if(m < 0) { - m = n; - } + TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m); result.resize_({n, m}); result.zero_(); @@ -531,7 +516,7 @@ Tensor scalar_tensor(Scalar s, const TensorOptions& options) { // auto result = at::empty({}, options); at::tracer::impl::NoTracerDispatchMode tracer_guard; at::AutoNonVariableTypeMode non_var_type_mode(true); - auto result = empty_cpu({}, options); + auto result = empty_cpu({}, optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); at::native::fill_(result, s); return result; } @@ -725,6 +710,7 @@ Tensor& randperm_out(Tensor& result, int64_t n) { Tensor& randperm_out_cpu(Tensor& result, int64_t n, c10::optional generator) { TORCH_CHECK(n >= 0, "n must be non-negative, got", n); + TORCH_CHECK(!generator.has_value() || (generator.has_value() && result.device() == generator->device()), "Expected a '", result.device(), "' generator device but found '", generator->device(), "'"); check_supported_max_int_with_precision(n, result); result.resize_({n}); auto gen = get_generator_or_default(generator, detail::getDefaultCPUGenerator()); @@ -758,13 +744,14 @@ Tensor range( // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tensor tril_indices_cpu( - int64_t row, int64_t col, int64_t offset, const TensorOptions& options) { - check_args(row, col, options); + int64_t row, int64_t col, int64_t offset, c10::optional dtype_opt, + c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { + check_args(row, col, layout_opt); auto tril_size = get_tril_size(row, col, offset); // create an empty Tensor with correct size - auto result = at::empty({2, tril_size}, options); + auto result = at::native::empty_cpu({2, tril_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt); // The following three approaches result in very little performance // differences. Hence, the 2nd option is taken for simpler code, and to return @@ -803,13 +790,14 @@ Tensor tril_indices_cpu( } Tensor triu_indices_cpu( - int64_t row, int64_t col, int64_t offset, const TensorOptions& options) { - check_args(row, col, options); + int64_t row, int64_t col, int64_t offset, c10::optional dtype_opt, + c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { + check_args(row, col, layout_opt); auto triu_size = row * col - get_tril_size(row, col, offset - 1); // create an empty Tensor with correct size - auto result = at::empty({2, triu_size}, options); + auto result = at::native::empty_cpu({2, triu_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt); AT_DISPATCH_ALL_TYPES(result.scalar_type(), "triu_indices", [&]() -> void { // fill the Tensor with correct values @@ -1058,58 +1046,24 @@ Tensor vander(const Tensor& x, c10::optional N, bool increasing) { template Tensor tensor_cpu(ArrayRef values, const TensorOptions& options) { - auto result = at::empty(values.size(), options); - AT_ASSERT(result.is_contiguous()); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX(result.scalar_type(), "tensor_cpu", [&] { - std::copy(values.begin(), values.end(), result.template data_ptr()); - }); - return result; + return at::detail::tensor_cpu(values, options); } template Tensor tensor_backend(ArrayRef values, const TensorOptions& options) { - auto cpu_tensor = tensor_cpu(values, options.device(DeviceType::CPU)); - return cpu_tensor.to(options.device()); + return at::detail::tensor_backend(values, options); } template Tensor tensor_complex_cpu(ArrayRef values, const TensorOptions& options) { - auto result = at::empty(values.size(), options); - AT_ASSERT(result.is_contiguous()); - AT_DISPATCH_COMPLEX_TYPES(result.scalar_type(), "tensor_cpu", [&] { - std::copy(values.begin(), values.end(), result.template data_ptr()); - }); - return result; + return at::detail::tensor_complex_cpu(values, options); } template Tensor tensor_complex_backend(ArrayRef values, const TensorOptions& options) { - auto cpu_tensor = tensor_complex_cpu(values, options.device(DeviceType::CPU)); - return cpu_tensor.to(options.device()); + return at::detail::tensor_complex_backend(values, options); } -#define TENSOR(T, _1) \ - Tensor tensor(ArrayRef values, const TensorOptions& options) { \ - if (options.device().type() != c10::DeviceType::CPU) { \ - return tensor_backend(values, options); \ - } else { \ - return tensor_cpu(values, options); \ - } \ - } -AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) -#undef TENSOR - -#define TENSOR(T, _1) \ - Tensor tensor(ArrayRef values, const TensorOptions& options) { \ - if (options.device().type() != c10::DeviceType::CPU) { \ - return tensor_complex_backend(values, options); \ - } else { \ - return tensor_complex_cpu(values, options); \ - } \ - } -AT_FORALL_COMPLEX_TYPES(TENSOR) -#undef TENSOR - Tensor from_file(std::string filename, c10::optional shared, c10::optional size, const TensorOptions& options) { TORCH_CHECK(!options.pinned_memory(), "tensors constructed from a file cannot be pinned"); int64_t my_size = size.value_or(0); diff --git a/aten/src/ATen/native/TensorFactories.h b/aten/src/ATen/native/TensorFactories.h index f551adcec693a..d5943ac55ae57 100644 --- a/aten/src/ATen/native/TensorFactories.h +++ b/aten/src/ATen/native/TensorFactories.h @@ -1,5 +1,7 @@ #pragma once +#include +#include #include #include #include @@ -50,22 +52,18 @@ inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) { } inline void check_args( - int64_t row, int64_t col, const TensorOptions& options) { + int64_t row, int64_t col, c10::optional layout_opt) { TORCH_CHECK(row >= 0, "row must be non-negative, got", row); TORCH_CHECK(col >= 0, "col must be non-negative, got", col); - if (options.has_layout()) { + if (layout_opt.has_value()) { TORCH_CHECK( - options.layout() == at::kStrided, + *layout_opt == at::kStrided, "only support layout=torch.strided, got", - options.layout()) + *layout_opt) } } -inline void check_size_nonnegative(IntArrayRef size) { - for (auto x: size) { - TORCH_CHECK(x >= 0, "Trying to create tensor with negative dimension ", x, ": ", size); - } -} +using at::check_size_nonnegative; inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tensor) { TORCH_CHECK(at::scalar_tensor(n, tensor.options()).defined(), diff --git a/aten/src/ATen/native/TensorIterator.h b/aten/src/ATen/native/TensorIterator.h index f15af0b2208fc..e55d2a58d7099 100644 --- a/aten/src/ATen/native/TensorIterator.h +++ b/aten/src/ATen/native/TensorIterator.h @@ -1,535 +1,2 @@ #pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -// TensorIterator is a helper class for element-wise operations, such as -// arithmetic, comparisons, and trigonometric functions. It handles -// broadcasting and type conversions of operands. -// -// This is inspired by NumPy's Array Iterator API (NpyIter). -// -// The files Loops.h and Loops.cuh provide functions to build kernels that -// use TensorIterator. -// -// Example: -// -// auto iter = TensorIteratorConfig() -// .add_output(output) -// .add_input(input) -// .build() -// -// [MyKernel.cpp / MyKernel.cu] -// cpu_kernel(iter, [](float a, float b) { -// return a + b; -// }); -// -// gpu_kernel(iter, []GPU_LAMBDA(float a, float b) -> float { -// return a + b; -// }); -// -// Note [Common Dtype Computation] -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Some operations have a natural notion of a "common dtype" or -// "computation dtype" where all inputs are cast to one dtype, the -// operation is performed, and then the results are cast to all outputs. -// -// TensorIterator infers a common dtype if all inputs have the same dtype, -// and it computes one using type promotion rules on its inputs if -// promote_inputs_to_common_dtype_ is true. Attempting to query -// a common dtype otherwise will throw an exception. -// -// Note that the outputs are not considered when computing a common dtype. - -namespace at { - -struct DimCounter { - DimCounter(IntArrayRef shape, Range range); - - void increment(const std::array& step); - bool is_done() const; - std::array max_2d_step() const; - - IntArrayRef shape; - Range range; - DimVector values; - int64_t offset; -}; - -struct CAFFE2_API OperandInfo { - using StrideVector = SmallVector; - OperandInfo() {} - explicit OperandInfo(Tensor t) : tensor(std::move(t)) { - if (tensor.defined()) { - device = tensor.device(); - target_dtype = tensor.scalar_type(); - current_dtype = target_dtype; - } - validate(); - } - - /// Stride after broadcasting. The stride is in bytes, not number of elements. - StrideVector stride_bytes; - - /// The tensor operand. Note that the strides, data pointer, and - /// other attributes may differ due to dimension reordering and - /// coalescing. - Tensor tensor; - - // Save the original tensor operand in cases when an output is modified - // (e.g. if dtype is changed) - Tensor original_tensor; - - /// The desired device and type for the operand. For inputs, this specifies that - /// the input should be converted to this type if necessary. For outputs, this - /// specifies which type to allocate. target_dtype and device are initialized with the dtype and device of the tensor - /// but during type promotion target_dtype value can become different from tensor's dtype - /// also, during type promotion target_dtype and device can be set for an undefined tensor so that tensor can be properly - /// constructed later. - Device device = kCPU; - ScalarType target_dtype = ScalarType::Undefined; - // Caches dtype of the tensor, because scalar_type is an expensive operation - // If dtype of the tensor is changed (e.g. as a result of type promotion or in allocate_outputs), this - //value should be changed too. - ScalarType current_dtype = ScalarType::Undefined; - - bool is_type_defined() const { return target_dtype != ScalarType::Undefined; } - TensorOptions options() const { - return TensorOptions(target_dtype).device(device); - } - - /// The data pointer. This may be different from tensor.data_ptr() if the - /// iterator is split. - void* data = nullptr; - - bool is_output = false; - - bool will_resize = false; - - bool is_read_write = false; - - void validate() { - TORCH_CHECK( - !tensor.defined() || tensor.layout() == kStrided, - "unsupported tensor layout: ", tensor.layout()); - } -}; - -struct SplitUntil32Bit; - -enum class FastSetupType : uint8_t { - NONE, - CONTIGUOUS, - CHANNELS_LAST, - NON_OVERLAPPING_DENSE -}; - -class TensorIteratorConfig; - -struct CAFFE2_API TensorIterator { - using DimMask = std::bitset<64>; - using PtrVector = SmallVector; - using StrideVector = SmallVector; - - TensorIterator(TensorIteratorConfig&); - - // The inner-loop function operates on the fastest moving dimension. It - // implements element-wise operations in terms of 1-d strided tensors. - // - // Arguments: - // data: data pointers for each operand (length `ntensors`) - // strides: stride for each operand (length `ntensors`) - // size: size of inner loop - // - // The `size` often matches shape[0], but may be smaller due to - // parallelization of the inner loop. - using loop_t = c10::function_ref; - using loop2d_t = c10::function_ref; - - using loop_subiter_t = c10::function_ref; - - void foreach_reduced_elt(loop_subiter_t loop, bool parallelize=true); - - static TensorIterator binary_float_op(Tensor& out, const Tensor& a, const Tensor& b); - static TensorIterator binary_op(Tensor& out, const Tensor& a, const Tensor& b); - static TensorIterator comparison_op(Tensor& out, const Tensor& a, const Tensor& b); - static TensorIterator unary_op(Tensor& out, const Tensor& a); - static TensorIterator nullary_op(Tensor& out); - static TensorIterator reduce_op(Tensor& out, const Tensor& a); - static TensorIterator reduce_op(Tensor& out1, Tensor& out2, const Tensor& a); - - int ndim() const { return shape_.size(); } - IntArrayRef shape() const { return shape_; } - int64_t numel() const; - int ntensors() const { return operands_.size(); } - int noutputs() const { return num_outputs_; } - int ninputs() const { return ntensors() - noutputs(); } - IntArrayRef view_offsets() const { return view_offsets_; } - - /// number of elements in the output operand. this is the same as numel() for - /// operations that are not reductions. - int64_t num_output_elements() const; - - /// number of reduced dimensions in a reduction operation - int num_reduce_dims() const; - - /// 1-dimensional iteration and no buffering or type conversion - bool is_trivial_1d() const; - /// Reducible to 1-dimensional and all operands are contiguous - bool is_contiguous() const; - bool is_dim_reduced(int dim) const; - - /// Accessors for each operand - IntArrayRef strides(int arg) const { return operands_[arg].stride_bytes; } - void* data_ptr(int arg) const; - ScalarType dtype(int arg=0) const { return operands_[arg].current_dtype; } - ScalarType common_dtype() const { - TORCH_INTERNAL_ASSERT(common_dtype_ != ScalarType::Undefined, "Queried for invalid common dtype!"); - return common_dtype_; - } - ScalarType input_dtype(int arg=0) const { return operands_[num_outputs_ + arg].current_dtype; } - Device device(int arg=0) const { return operands_[arg].device; } - DeviceType device_type(int arg=0) const { return device(arg).type(); } - int64_t element_size(int arg) const { return elementSize(dtype(arg)); } - bool is_scalar(int arg) const; - bool is_cpu_scalar(int arg) const; - - const Tensor& tensor(int arg) const { return operands_[arg].tensor; } - Tensor& tensor(int arg) { return operands_[arg].tensor; } - - Tensor output(int arg=0) const { - AT_ASSERT(arg < num_outputs_); - return operands_[arg].tensor; - } - - // Copies from temporary outputs back to the original outputs - // NOTE: only used on CPU - void cast_outputs(); - - Tensor input(int arg=0) const { - AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_); - return operands_[num_outputs_ + arg].tensor; - } - - /// Removes an operand from this iterator - void remove_operand(int arg); - /// Shrinks an iterated dimension - void narrow(int dim, int64_t start, int64_t size); - /// Narrows every dim after and including `start_dim` to size one. - void select_all_keeping_dim(int start_dim, IntArrayRef starts); - /// Replaces the data pointer for the operand at index `arg`. - /// The new pointer should have the same sizes, strides and dtype as the - /// original - void unsafe_replace_operand(int arg, void* data); - - /// Splits this TensorIterator into two iterators. Together they iterate over - /// the entire operation. Used by `with_32bit_indexing()`. - std::unique_ptr split(int dim); - - /// Returns the dimension with the largest extent: (size[dim]-1) * stride[dim] - int get_dim_to_split() const; - - template - T scalar_value(int arg) { - auto& op = operands_[arg]; - return c10::fetch_and_cast(op.tensor.scalar_type(), op.data); - } - - void for_each(loop_t loop, int64_t grain_size = at::internal::GRAIN_SIZE); - void for_each(loop2d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE); - - void parallel_reduce(loop2d_t loop); - - void serial_for_each(loop_t loop, Range range) const; - void serial_for_each(loop2d_t loop, Range range) const; - - /// Create a strides array for a Tensor with shape of this iterator. The - /// parameter `element_size` specifies the size of Tensor's data type in - /// bytes (e.g. `4` for `float`) - StrideVector compatible_stride(int element_size) const; - - /// Inverts the re-ordering done by reorder_dimensions. This can only be - /// called *before* coalesce_dimensions() is called. - DimVector invert_perm(IntArrayRef input) const; - - /// Reapply same re-ordering as it is done by reorder_dimensions. This can - /// only be called *before* coalesce_dimensions() is called. - DimVector apply_perm_and_mul(IntArrayRef input, int mul) const; - - /// Helper functions for CPU iteration - StrideVector get_dim_strides(int dim) const; - StrideVector get_strides() const; - StrideVector get_inner_strides() const { return get_dim_strides(0); } - PtrVector get_data_ptrs(ArrayRef base, IntArrayRef counter) const; - PtrVector get_base_ptrs() const; - - /// true if the stride computation can use 32-bit arithmetic. Used by GPU kernels - bool can_use_32bit_indexing() const; - - /// An "iteratable" object that recursively splits this iterator into sub-iterators - /// that can use 32-bit indexing. - SplitUntil32Bit with_32bit_indexing() const; - - /// If the kernel should accumulate into the output. Only relevant for CUDA - /// reductions. - bool should_accumulate() const { return accumulate_; } - - /// Whether this iterator produces the actual output, - /// as opposed to something that will be accumulated further. Only relevant for - /// CUDA reductions. - bool is_final_output() const { return final_output_; } - - bool has_contiguous_first_dim() const { - int num_tensors = ntensors(); - for (int i = 0; i < num_tensors; i++) { - if (strides(i)[0] != element_size(i)) { - return false; - } - } - return true; - } - -protected: - void build(TensorIteratorConfig&); - - // Mutable reference as it moves tensors out of TensorIteratorConfig - void populate_operands(TensorIteratorConfig&); - void mark_outputs(); - void mark_resize_outputs(const TensorIteratorConfig&); - void compute_mem_overlaps(const TensorIteratorConfig&); - void compute_shape(const TensorIteratorConfig&); - void compute_strides(const TensorIteratorConfig&); - void reorder_dimensions(const TensorIteratorConfig&); - void permute_dimensions(IntArrayRef perm); - void compute_types(const TensorIteratorConfig&); - ScalarType compute_common_dtype(); - void allocate_or_resize_outputs(); - bool fast_set_up(const TensorIteratorConfig&); - FastSetupType compute_fast_setup_type(const TensorIteratorConfig&); - void compute_names(const TensorIteratorConfig&); - void propagate_names_to_outputs(); - void coalesce_dimensions(); - -protected: - - /// Records the "computation" shape of the output tensor. The computation - /// shape is different from the regular shape in a few ways: - /// - /// - The shape may be permuted (via permute_dimensions) so that we - /// process the dimensions in the most computationally efficient order - /// (rather than the logical order given to us by the users.) - /// - The shape may have adjacent dimensions collapsed (via - /// coalesce_dimensions) so that we minimize the number of - /// dimensions we have to explicitly iterate over. For example, - /// a pointwise operation on a contiguous tensor "computationally" - /// consists of only a single dimension. - /// - /// In other words, the computation shape is the output shape as it - /// actually matters for implementing the kernel, but not necessarily the - /// output shape that the user will see in the end. - /// - /// The lifecycle of mutations to shape_ in TensorIterator: - /// - declare_static_shape() sets an initial shape explicitly - /// provided by user, otherwise - /// - compute_shape() computes the true (non-computational) shape - /// specified by the user. - /// - reorder_dimensions() reorders dimensions to improve coalescing. - /// - coalesce_dimensions() then coalesces adjacent dimensions when - /// possible. - /// - /// The shape may also be further modified if we create sub-TensorIterators, - /// e.g., via narrow or select_all_keeping_dim. - DimVector shape_; - - /// Temporarily records the permutation computed by reorder_dimensions. - /// This permutation maps the computation output dimension (dim) to - /// the original true output dimension (perm_[dim]). It is used by - /// invert_perm to undo the permutation. After coalesce_dimensions is - /// called, the permutation is no longer valid (as, in general, there - /// is no permutation that will make computation dimensions to - /// output dimensions); methods that manipulate perm_ are obligated - /// to test that !has_coalesced_dimensions - DimVector perm_; - - /// Has coalesce_dimensions() (or any moral equivalent, e.g., fast_build()) - /// been called? This is SOLELY used to check validity of perm_. - bool has_coalesced_dimensions_ = false; - - /// The index offsets into the original tensors for each dimension. - /// This is only non-zero when you narrow() a TensorIterator (e.g., - /// when you make sub-TensorIterators). - DimVector view_offsets_; - - /// The computed names of the output tensor. Computed by compute_names() - NameVector names_; - - /// The operands of the TensorIterator: both the inputs and outputs. The - /// outputs MUST come first in the operands_ list. There is always an - /// operand for each output of the TensorIterator, even if TensorIterator - /// will ultimately be responsible for allocating the output; in those - /// cases, tensor is simply undefined (and will be populated later - /// during build()). - /// - /// This list is initially populated prior to build(), but build() mutates - /// OperandInfo to populate more information. - SmallVector operands_; - - /// Number of outputs in operands_ (the length of the outputs prefix - /// in operands_). - int num_outputs_ = 0; - - /// Whether or not all operands have the same shape. Having all the same - /// shape affects whether or not the iterator is eligible for fast setup. - bool all_ops_same_shape_ = false; - - /// The "computation" dtype of TensorIterator, specifying what the dtype - /// we will do the internal computation in TensorIterator. Typically, - /// this matches the dtype of the output tensors, but not always! - ScalarType common_dtype_ = ScalarType::Undefined; - - /// Set by split(), see should_accumulate() and is_final_output() - bool accumulate_ = false; - bool final_output_ = true; - - // From TensorIteratorConfig - bool is_reduction_ = false; -}; - -class CAFFE2_API TensorIteratorConfig final { -public: - friend struct TensorIterator; - - TensorIteratorConfig() {} - - C10_DISABLE_COPY_AND_ASSIGN(TensorIteratorConfig); - - /// Construction - TensorIteratorConfig& add_output(const Tensor& output); - TensorIteratorConfig& add_input(const Tensor& input); - - // Sets the check_mem_overlap_ flag, which is true by default. - // If true, inputs are checked for partial overlap with the outputs and - // outputs are checked for internal overlap (e.g. broadcasted views). An error - // is raised if unacceptable overlap is detected. - // If you're migrating an existing operator to using TensorIterator, please - // consider if the previous implementation checked memory overlap. If it did - // not, and if the operator is idempotent (for example, Tensor.fill_(0)), then - // checking memory overlap is BC-breaking. Please don't check memory overlap - // in that case. - TensorIteratorConfig& set_check_mem_overlap(bool check_mem_overlap); - - // Sets the check_all_same_dtype_ flag, which is true by default - // If true, checks that all inputs and defined outputs have the same dtype - // Setting either of promote_inputs_to_common_dtype_ - // or cast_common_dtype_to_outputs_ to true will set - // check_all_same_dtype_ to false. - TensorIteratorConfig& check_all_same_dtype(const bool _check_all_same_dtype); - - // Sets the check_all_same_device_ flag, which is true by default - // If true, all operands must be on the same device, with the possible - // exception of CPU scalars, which can be passed to some CUDA kernels - // as kernel arguments. - TensorIteratorConfig& check_all_same_device(const bool _check_all_same_device); - - // Sets the enforce_safe_casting_to_output_ flag, which is false by default - // If true, the iterator's "common dtype" must be computable - // (see the [Common Dtype Computation] note) and - // canCast(common dtype, output dtype) must be true for all outputs. - TensorIteratorConfig& enforce_safe_casting_to_output(const bool _enforce_safe_casting_to_output); - - // Sets the promote_inputs_to_common_dtype_ flag, which is false by default - // If true, the iterator's "common dtype" is always computed (see the - // [Common Dtype Computation] note) and, on the CPU, temporary copies of - // the inputs in the common dtype are passed as the actual inputs to - // the operation. - // Setting this flag to true sets check_all_same_dtype_ to false. - TensorIteratorConfig& promote_inputs_to_common_dtype(const bool _promote_inputs_to_common_dtype); - - // Sets the promote_integer_inputs_to_float_ flag, which is false by default - // NOTE: If set to true, the promote_inputs_to_common_dtype_ must also be true. - // If true, if the iterator's "common dtype" is an integral type (including bool) - // then it is changed to the default float scalar type. - TensorIteratorConfig& promote_integer_inputs_to_float(const bool _promote_integer_inputs_to_float); - TensorIteratorConfig& is_reduction(const bool _is_reduction); - TensorIteratorConfig& allow_cpu_scalars(const bool _allow_cpu_scalars); - - // Sets the cast_common_dtype_to_outputs_ flag, which is false by default - // If true, the iterator's "common dtype" must be computatable - // (see the [Common Dtype Computation] note) and, on the CPU, temporary - // copies of the outputs are passed as the actual output to the operation. - // These temporaries are then copied to the original outputs after - // the operation is performed (see cast_outputs()). - // Setting this flag to true sets check_all_same_dtype_ to false. - TensorIteratorConfig& cast_common_dtype_to_outputs(const bool _cast_common_dtype_to_outputs); - TensorIteratorConfig& resize_outputs(bool resize_outputs); - - // Bypass output dtype/device computation and fix the dtype/device as specified here. - TensorIteratorConfig& declare_static_dtype_and_device(ScalarType dtype, Device device); - TensorIteratorConfig& declare_static_shape(IntArrayRef shape); - TensorIteratorConfig& declare_static_shape(IntArrayRef shape, const int64_t squash_dim); - - // It would be better if this was && qualified, but this would be at the cost - // of a lot of boilerplate above - TensorIterator build() { - return TensorIterator(*this); - } - -private: - SmallVector tensors_; - int num_outputs_ = 0; - int num_inputs_ = 0; - - c10::optional static_shape_ = c10::nullopt; - c10::optional> static_dtype_and_device_ = c10::nullopt; - bool check_mem_overlap_ = true; - bool allow_cpu_scalars_ = false; - bool is_reduction_ = false; - bool resize_outputs_ = true; - bool check_all_same_dtype_ = true; - bool check_all_same_device_ = true; - bool enforce_safe_casting_to_output_ = false; - bool promote_inputs_to_common_dtype_ = false; - bool promote_integer_inputs_to_float_ = false; - bool cast_common_dtype_to_outputs_ = false; -}; - - - -/// A container-like struct that acts as if it contains splits of a -/// TensorIterator that can use 32-bit indexing. Taken together the splits cover -/// the original TensorIterator. -struct CAFFE2_API SplitUntil32Bit { - struct CAFFE2_API iterator { - iterator() {}; - iterator(const TensorIterator& iter); - iterator(iterator&&) = default; - - TensorIterator& operator*() const; - iterator& operator++(); - bool operator==(const iterator& other) const { - // two iterators are equal if they are the same object or they're both empty - return this == &other || (vec.empty() && other.vec.empty()); - } - // needed for C++11 range-based for loop - bool operator!=(const iterator& other) const { return !(*this == other); } - - /// stack of TensorIterators to be split - std::vector> vec; - }; - - SplitUntil32Bit(const TensorIterator& iter) : iter(iter) {} - - iterator begin() const; - iterator end() const; - -private: - const TensorIterator& iter; -}; - -} // namespace at +#include diff --git a/aten/src/ATen/native/TensorIteratorDynamicCasting.h b/aten/src/ATen/native/TensorIteratorDynamicCasting.h index 31b4461c67e77..8e3b6760091ce 100644 --- a/aten/src/ATen/native/TensorIteratorDynamicCasting.h +++ b/aten/src/ATen/native/TensorIteratorDynamicCasting.h @@ -26,7 +26,7 @@ namespace at { namespace native { // (and returns) of func_t template::arity> struct needs_dynamic_casting { - static bool check(TensorIterator& iter) { + static bool check(TensorIteratorBase& iter) { using traits = function_traits; using cpp_type = typename traits::template arg::type; using cpp_map = c10::CppTypeToScalarType; @@ -40,7 +40,7 @@ struct needs_dynamic_casting { template struct needs_dynamic_casting { - static bool check(TensorIterator& iter) { + static bool check(TensorIteratorBase& iter) { using traits = function_traits; using cpp_type = typename traits::result_type; diff --git a/aten/src/ATen/native/TensorIteratorReduce.cpp b/aten/src/ATen/native/TensorIteratorReduce.cpp index 6d3ba3acb4fcb..fcce06a6c9362 100644 --- a/aten/src/ATen/native/TensorIteratorReduce.cpp +++ b/aten/src/ATen/native/TensorIteratorReduce.cpp @@ -2,18 +2,20 @@ #include #include #include +#include +#include /// Contains the implementation of parallel reductions in TensorIterator. namespace at { -using loop2d_t = TensorIterator::loop2d_t; +using loop2d_t = TensorIteratorBase::loop2d_t; -static bool use_two_pass_reduction(TensorIterator& iter); -static void two_pass_reduction(TensorIterator& iter, loop2d_t loop); -static void parallel_dim_reduction(TensorIterator& iter, loop2d_t loop); +static bool use_two_pass_reduction(TensorIteratorBase& iter); +static void two_pass_reduction(TensorIteratorBase& iter, loop2d_t loop); +static void parallel_dim_reduction(TensorIteratorBase& iter, loop2d_t loop); -void TensorIterator::parallel_reduce(loop2d_t loop) { +void TensorIteratorBase::parallel_reduce(loop2d_t loop) { TORCH_CHECK(ntensors() == 2, "parallel_reduce only supports one input and one output"); int64_t numel = this->numel(); if (numel < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 || @@ -26,11 +28,11 @@ void TensorIterator::parallel_reduce(loop2d_t loop) { } } -static bool use_two_pass_reduction(TensorIterator& iter) { +static bool use_two_pass_reduction(TensorIteratorBase& iter) { return iter.output(0).numel() == 1; } -static void two_pass_reduction(TensorIterator& iter, loop2d_t loop) { +static void two_pass_reduction(TensorIteratorBase& iter, loop2d_t loop) { int max_threads = at::get_num_threads(); auto dst = iter.output(0); @@ -65,7 +67,7 @@ static void two_pass_reduction(TensorIterator& iter, loop2d_t loop) { /// Chooses a dimension over which to parallelize. Prefers the outer-most /// dimension thats larger than the number of available threads. -static int find_split_dim(TensorIterator& iter) { +static int find_split_dim(TensorIteratorBase& iter) { int num_threads = at::get_num_threads(); auto shape = iter.shape(); @@ -84,7 +86,7 @@ static int find_split_dim(TensorIterator& iter) { } static std::tuple -round_columns(TensorIterator& iter, int dim, int multiple, int64_t begin, int64_t end) { +round_columns(TensorIteratorBase& iter, int dim, int multiple, int64_t begin, int64_t end) { begin = begin - (begin % multiple); if (end != iter.shape()[dim]) { // only round the 'end' column down if it's not the final column @@ -93,7 +95,7 @@ round_columns(TensorIterator& iter, int dim, int multiple, int64_t begin, int64_ return std::make_tuple(begin, end); } -static void parallel_dim_reduction(TensorIterator& iter, loop2d_t loop) { +static void parallel_dim_reduction(TensorIteratorBase& iter, loop2d_t loop) { AT_ASSERT(iter.ndim() >= 1); int dim = find_split_dim(iter); int64_t cols = iter.shape()[dim]; @@ -116,7 +118,7 @@ static void parallel_dim_reduction(TensorIterator& iter, loop2d_t loop) { }); } -void TensorIterator::foreach_reduced_elt(loop_subiter_t loop, bool parallelize) { +void TensorIteratorBase::foreach_reduced_elt(loop_subiter_t loop, bool parallelize) { AT_ASSERT(ninputs() == 1); AT_ASSERT(noutputs() >= 1); @@ -153,7 +155,7 @@ void TensorIterator::foreach_reduced_elt(loop_subiter_t loop, bool parallelize) return; } - auto sub_iter = *this; + TensorIterator sub_iter(*this); sub_iter.narrow(dim, begin, end - begin); // On some broken setups, `#ifdef _OPENMP` is true, diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp index 48dab43b2dc81..f395c6956da56 100644 --- a/aten/src/ATen/native/TensorProperties.cpp +++ b/aten/src/ATen/native/TensorProperties.cpp @@ -1,6 +1,5 @@ #include #include -#include #include #include #include @@ -14,15 +13,11 @@ bool is_same_size(const Tensor& self, const Tensor& other) { } int64_t size(const Tensor& self, int64_t dim) { - // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) - dim = maybe_wrap_dim(dim, self.dim(), false); - return self.sizes()[dim]; + return self.size(dim); } int64_t stride(const Tensor& self, int64_t dim) { - // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) - dim = maybe_wrap_dim(dim, self.dim(), false); - return self.strides()[dim]; + return self.stride(dim); } int64_t size(const Tensor& self, Dimname dim) { diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 7fba7916354a1..f2c3f6309a2eb 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1,22 +1,24 @@ -#include -#include #include +#include #include #include +#include +#include #include +#include #include -#include -#include +#include #include -#include -#include -#include -#include #include #include #include -#include -#include +#include +#include +#include +#include +#include +#include +#include namespace at { namespace native { @@ -27,7 +29,7 @@ Tensor _reshape_from_tensor(const Tensor& self, const Tensor& shape_tensor) { TORCH_CHECK(shape_tensor.dim() == 1); std::vector shape; auto accessor = shape_tensor.accessor(); - for (size_t i = 0; i < shape_tensor.numel(); ++i) { + for (int64_t i = 0; i < shape_tensor.numel(); ++i) { shape.push_back(accessor[i]); } return self.reshape(IntArrayRef(shape)); @@ -77,6 +79,10 @@ Tensor& set_cpu_(Tensor& result) { return result; } +Tensor broadcast_to(const Tensor& self, IntArrayRef size) { + return self.expand(size); +} + std::vector broadcast_tensors(TensorList tensors) { return expand_outplace(tensors); } @@ -92,26 +98,28 @@ static inline void check_cat_shape_except_dim(const Tensor & first, const Tensor if (dim == dimension) { continue; } - int64_t first_dim_size = first.size(dim); - int64_t second_dim_size = second.size(dim); + int64_t first_dim_size = first.sizes()[dim]; + int64_t second_dim_size = second.sizes()[dim]; TORCH_CHECK(first_dim_size == second_dim_size, "Sizes of tensors must match except in dimension ", dimension, ". Got ", first_dim_size, " and ", second_dim_size, " in dimension ", dim, " (The offending index is ", index, ")"); } } +static bool should_skip(const Tensor& t) { + return t.numel() == 0 && t.dim() == 1; +} + Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors // to be "skipped". We maintain this behavior for backwards compatibility, but only for this specific // size (i.e. other empty sizes are not skipped). - // FIXME: warn if this is the case - bool allSkipped = true; + bool allContiguous = true; - Tensor notSkippedTensor; // Inputs cannot alias the output tensor - for (int64_t i = 0; i < tensors.size(); i++) { + for (size_t i = 0; i < tensors.size(); i++) { auto lap = at::get_overlap_status(result, tensors[i]); TORCH_CHECK(lap != at::MemOverlapStatus::PARTIAL && lap != at::MemOverlapStatus::FULL, 0, @@ -120,19 +128,23 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { } at::assert_no_internal_overlap(result); - auto should_skip = [](const Tensor& t) { return t.numel() == 0 && t.dim() == 1; }; - for (auto const &tensor : tensors) { - if (should_skip(tensor)) { - continue; + const Tensor* pnotSkippedTensor = [](TensorList tensors) -> const Tensor* { + for (auto const &tensor : tensors) { + if (should_skip(tensor)) { + continue; + } + // we've found a non-empty tensor + return &tensor; } - // we've found a non-empty tensor - allSkipped = false; - notSkippedTensor = tensor; - break; - } - if (allSkipped) { + return nullptr; + }(tensors); + + if (!pnotSkippedTensor) { + // FIXME: warn if this is the case -- see comment about skipped + // tensors at top of function. return result; } + const Tensor& notSkippedTensor = *pnotSkippedTensor; TORCH_CHECK(tensors.size() > 0, "expected a non-empty list of Tensors"); TORCH_CHECK(dim <= notSkippedTensor.dim(), "dimension ", dim, "out of range"); @@ -147,7 +159,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { // compute size of the result in the cat dimension int64_t cat_dim_size = 0; auto first_tensor_mem_format = tensors[0].suggest_memory_format(); - for (int i = 0; i < tensors.size(); i++) { + for (size_t i = 0; i < tensors.size(); i++) { auto const &tensor = tensors[i]; if (should_skip(tensor)) { // don't use fast path for empty tensor @@ -155,7 +167,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { continue; } check_cat_shape_except_dim(notSkippedTensor, tensor, dim, i); - cat_dim_size += tensor.size(dim); + cat_dim_size += tensor.sizes()[dim]; if (!tensor.is_contiguous(first_tensor_mem_format)) { allContiguous = false; @@ -172,7 +184,12 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { // compute the size of the result auto result_size = notSkippedTensor.sizes().vec(); result_size[dim] = cat_dim_size; - result.resize_(result_size, first_tensor_mem_format); + + // skip resizing if size of result is same as expected + if (result.sizes() != result_size) { + result.resize_(result_size, first_tensor_mem_format); + } + if (result.numel() == 0) { return result; } @@ -190,8 +207,8 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { if (reuse_iterator && result.is_contiguous(first_tensor_mem_format) && no_type_promotion) { - auto source_slice = notSkippedTensor; - auto slice_dim_size = source_slice.size(dim); + const auto& source_slice = notSkippedTensor; + auto slice_dim_size = source_slice.sizes()[dim]; auto result_slice = result.narrow(dim, 0, slice_dim_size); auto result_slice_data = result_slice.data_ptr(); auto result_stride_bytes = result.stride(dim) * elementSize(result.scalar_type()); @@ -220,7 +237,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { if (should_skip(tensor)) { continue; } - auto slice_dim_size = tensor.size(dim); + auto slice_dim_size = tensor.sizes()[dim]; auto result_slice = result.narrow(dim, offset, slice_dim_size); auto iter = TensorIteratorConfig() @@ -280,7 +297,7 @@ static bool sizes_match_except(IntArrayRef s1, IntArrayRef s2, int64_t dim_excep if (s1.size() != s2.size()) { return false; } - for (int64_t i = 0; i < s1.size(); ++i) { + for (size_t i = 0; i < s1.size(); ++i) { if (i != dim_except && s1[i] != s2[i]) { return false; } @@ -368,7 +385,7 @@ static Tensor cat_sparse(TensorList tensors, int64_t dim) { // The dimension in each tensor's values object that corresponds to the overall dimension along which we're catting. int64_t values_dim = wrapped - sparse_dim + 1; // The final size along the catted dimension. - int64_t total_size = std::accumulate(tensors.begin(), tensors.end(), 0, [values_dim](int64_t l, Tensor const &r) { + const int64_t total_size = std::accumulate(tensors.begin(), tensors.end(), static_cast(0), [values_dim](int64_t l, Tensor const &r) { return l + r._values().size(values_dim); }); auto zeros_sizes = tensors[0]._values().sizes().vec(); @@ -513,6 +530,59 @@ std::vector chunk(const Tensor& self, int64_t chunks, int64_t dim) { } } +std::vector tensor_split(const Tensor& self, int64_t sections, int64_t dim) { + TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims"); + int64_t dim_ = maybe_wrap_dim(dim, self.dim()); + TORCH_CHECK(sections > 0, "number of sections must be larger than 0, got ", sections); + std::vector splits(sections); + int64_t min_split_size = self.size(dim_) / sections; + int64_t num_splits_one_extra = self.size(dim_) % sections; + int64_t start_idx = 0; + for (int64_t split_idx = 0; split_idx < sections; split_idx++) { + int64_t split_size = (split_idx < num_splits_one_extra) ? (min_split_size + 1) : min_split_size; + splits[split_idx] = at::slice(self, dim_, start_idx, start_idx + split_size); + start_idx += split_size; + } + return splits; +} + +std::vector tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim) { + TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims"); + int64_t dim_ = maybe_wrap_dim(dim, self.dim()); + int64_t num_indices = indices.size(); + std::vector splits(num_indices + 1); + int64_t start_idx = 0; + for (int64_t split_idx = 0; split_idx < num_indices; split_idx++) { + int64_t end_idx = indices[split_idx]; + splits[split_idx] = at::slice(self, dim_, start_idx, end_idx); + start_idx = end_idx; + } + splits[num_indices] = at::slice(self, dim_, start_idx, self.size(dim_)); + return splits; +} + +std::vector tensor_split(const Tensor& self, const Tensor& tensor_indices_or_sections, int64_t dim) { + TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims"); + auto split_device = tensor_indices_or_sections.device(); + TORCH_CHECK(split_device == kCPU, + "tensor_split expected tensor_indices_or_sections to be on cpu, but it's on ", split_device); + auto split_dtype = tensor_indices_or_sections.scalar_type(); + TORCH_CHECK(split_dtype == at::kLong, + "tensor_split expected tensor_indices_or_sections to have dtype of long, but got ", split_dtype); + auto split_dim = tensor_indices_or_sections.dim(); + TORCH_CHECK(split_dim == 1 || split_dim == 0, + "tensor_split expected tensor_indices_or_sections to be a zero-dimensional or one-dimensional tensor, but got a tensor with ", split_dim, " dims"); + + if (split_dim == 0) { + int64_t sections = tensor_indices_or_sections.item(); + return self.tensor_split(sections, dim); + } else { + auto indices_data = tensor_indices_or_sections.data_ptr(); + std::vector indices(indices_data, indices_data + tensor_indices_or_sections.numel()); + return self.tensor_split(indices, dim); + } +} + std::vector unsafe_chunk(const Tensor& self, int64_t chunks, int64_t dim) { TORCH_CHECK(self.dim() > 0, "chunk expects at least a 1-dimensional tensor"); @@ -721,8 +791,90 @@ Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_ return newTensor._coalesced_(self.is_coalesced()); } +Tensor& narrow_copy_dense_cpu_out( + const Tensor& self, int64_t dim, int64_t start, int64_t length, Tensor& output +) { + TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor."); + TORCH_CHECK(self.dtype() == output.dtype()); + + Tensor self_contig = self.contiguous(); + const auto self_sizes = self_contig.sizes(); + + // wrap dim if negative and do bound check + if (dim < 0) { + dim = at::maybe_wrap_dim(dim, self_sizes.size()); + } else { + TORCH_CHECK(dim < self_sizes.size()); + } + + // wrap start and do bound check + const auto cur_size = self_sizes[dim]; + if (start != cur_size && start < 0) { // start being the end is valid, but + // not a valid dim specification. + start = at::maybe_wrap_dim(start, cur_size); + } + TORCH_CHECK( + length >= 0 && start <= cur_size - length, + "start (", + start, + ") + length (", + length, + ") exceeds dimension size (", + cur_size, + ")."); + + // resize output + auto output_sizes = self_sizes.vec(); + output_sizes[dim] = length; + at::native::resize_(output, output_sizes); + + const int64_t unit = c10::size_from_dim_(dim + 1, self_sizes); + const int64_t num_blocks = c10::size_to_dim_(dim, self_sizes); + + const auto itemsize = self_contig.dtype().itemsize(); + size_t src_nbytes = itemsize * self_contig.numel(); + size_t dst_nbytes = itemsize * output.numel(); + + size_t src_block_size = unit * self_sizes[dim]; + size_t dst_block_size = unit * length; + + if (num_blocks == 0 || dst_block_size == 0) { + return output; + } + + char* src_bytes = static_cast(self_contig.data_ptr()); + char* dst_bytes = static_cast(output.data_ptr()); + + size_t src_block_size_bytes = itemsize * src_block_size; + size_t dst_block_size_bytes = itemsize * dst_block_size; + size_t src_offset = unit * start; + + char* src_offset_bytes = src_bytes + itemsize * src_offset; + char* dst_offset_bytes = dst_bytes; + + for (int64_t i = 0; i < num_blocks; ++i) { + char* local_src_offset_bytes = src_offset_bytes + i * src_block_size_bytes; + char* local_dst_offset_bytes = dst_offset_bytes + i * dst_block_size_bytes; + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + static_cast(local_src_offset_bytes + dst_block_size_bytes) <= + static_cast(src_bytes + src_nbytes)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + static_cast(local_dst_offset_bytes + dst_block_size_bytes) <= + static_cast(dst_bytes + dst_nbytes)); + + memcpy( + local_dst_offset_bytes, local_src_offset_bytes, dst_block_size_bytes); + } + return output; +} + Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t length){ - return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous); + return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous); +} + +Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){ + auto output = at::empty_like(self); + return narrow_copy_dense_cpu_out(self, dim, start, length, output); } Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) { @@ -808,6 +960,23 @@ Tensor repeat(const Tensor& self, IntArrayRef repeats) { return result; } +Tensor tile(const Tensor& self, IntArrayRef reps){ + // If self.size() > len(reps), reps is promoted to self.size() by pre-pending + // 1’s to it to keep the same behaviour as `numpy.tile`. + // Thus for a tensor of shape (2, 3, 4, 5), a dims of (2, 2) is treated + // as (1, 1, 2, 2). + const int64_t size_diff = self.dim() - static_cast(reps.size()); + if (size_diff > 0){ + std::vector new_reps(size_diff, 1); + for(auto i = decltype(reps.size()){0}; i < reps.size(); ++i){ + new_reps.emplace_back(reps[i]); + } + return self.repeat(IntArrayRef(new_reps)); + } + // `torch.tile` is equivalent to the already implemented `torch.Tensor.repeat` + return self.repeat(reps); +} + Tensor alias_with_sizes_and_strides( const Tensor& self, const c10::IntArrayRef sizes, @@ -1021,7 +1190,12 @@ Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index) } } -Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) { +Tensor slice( + const Tensor& self, + int64_t dim, + c10::optional start, + c10::optional end, + int64_t step) { int64_t ndim = self.dim(); if (ndim == 0) { TORCH_CHECK_INDEX(false, "slice() cannot be applied to a 0-dim tensor."); @@ -1029,27 +1203,37 @@ Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_ dim = maybe_wrap_dim(dim, ndim); auto sizes = self.sizes().vec(); auto strides = self.strides().vec(); + + // handle optional parameters + int64_t start_val = start.has_value() ? start.value() : 0; + int64_t end_val = end.has_value() ? end.value() : INT64_MAX; + // TODO: support negative strides TORCH_CHECK(step > 0, "slice step must be positive"); - if (start < 0) { - start += sizes[dim]; + + // INT64_MAX stands for default value. + if (start_val == INT64_MAX) { + start_val = 0; + } + if (start_val < 0) { + start_val += sizes[dim]; } - if (end < 0) { - end += sizes[dim]; + if (end_val < 0) { + end_val += sizes[dim]; } - if (start < 0) { - start = 0; - } else if (start >= sizes[dim]) { - start = sizes[dim]; + if (start_val < 0) { + start_val = 0; + } else if (start_val >= sizes[dim]) { + start_val = sizes[dim]; } - if (end < start) { - end = start; - } else if (end >= sizes[dim]) { - end = sizes[dim]; + if (end_val < start_val) { + end_val = start_val; + } else if (end_val >= sizes[dim]) { + end_val = sizes[dim]; } - auto storage_offset = self.storage_offset() + start * strides[dim]; - auto len = end - start; - sizes[dim] = (len + step - 1) / step; // round-up + auto storage_offset = self.storage_offset() + start_val * strides[dim]; + auto len = end_val - start_val; + sizes[dim] = (len + step - 1) / step; // round-up strides[dim] *= step; auto result = self.as_strided(sizes, strides, storage_offset); namedinference::propagate_names(result, self); @@ -1231,6 +1415,47 @@ static inline Tensor & sparse_transpose_(Tensor & self, int64_t dim0, int64_t di return self; } +// torch.row_stack, alias for torch.vstack +Tensor& row_stack_out(Tensor& result, TensorList tensors) { + return at::vstack_out(result, tensors); +} + +Tensor row_stack(TensorList tensors) { + return at::vstack(tensors); +} + +static std::vector reshape_input_for_column_stack(TensorList tensors) { + std::vector result(tensors.size()); + auto transform_lambda = [](const Tensor& input) -> Tensor { + // reshape 0D or 1D tensor t into (t.numel(), 1) + if (input.dim() <= 1) { + return input.reshape({input.numel(), 1}); + } + return input; + }; + std::transform(tensors.cbegin(), + tensors.cend(), + result.begin(), + transform_lambda); + return result; +} + +Tensor& column_stack_out(Tensor& result, TensorList tensors) { + TORCH_CHECK(tensors.size() > 0, + "column_stack expects a non-empty TensorList"); + + auto reshaped_tensors = reshape_input_for_column_stack(tensors); + return at::hstack_out(result, reshaped_tensors); +} + +Tensor column_stack(TensorList tensors) { + TORCH_CHECK(tensors.size() > 0, + "column_stack expects a non-empty TensorList"); + + auto reshaped_tensors = reshape_input_for_column_stack(tensors); + return at::hstack(reshaped_tensors); +} + static Tensor& propagate_transposed_names( Tensor& result, const Tensor& other, @@ -1351,15 +1576,25 @@ inferSqueezeGeometry(const Tensor& tensor, int64_t dim) { return std::make_tuple(sizes, strides); } -std::tuple, std::vector > +namespace { +// Named type instead of a pair/tuple so that we can be sure to +// construct the vectors in place and get NRVO. +struct InferUnsqueezeGeometryResult { + c10::SmallVector sizes; + c10::SmallVector strides; + InferUnsqueezeGeometryResult(IntArrayRef tensor_sizes, IntArrayRef tensor_strides) + : sizes(tensor_sizes.begin(), tensor_sizes.end()) + , strides(tensor_strides.begin(), tensor_strides.end()) {} +}; +} +InferUnsqueezeGeometryResult inferUnsqueezeGeometry(const Tensor& tensor, int64_t dim) { - auto sizes = tensor.sizes().vec(); - auto strides = tensor.strides().vec(); - int64_t new_stride = dim >= tensor.dim() ? 1 : sizes[dim] * strides[dim]; - sizes.insert(sizes.begin() + dim, 1); - strides.insert(strides.begin() + dim, new_stride); + InferUnsqueezeGeometryResult result(tensor.sizes(), tensor.strides()); + int64_t new_stride = dim >= tensor.dim() ? 1 : result.sizes[dim] * result.strides[dim]; + result.sizes.insert(result.sizes.begin() + dim, 1); + result.strides.insert(result.strides.begin() + dim, new_stride); - return std::make_tuple(sizes, strides); + return result; } Tensor squeeze_qtensor(const Tensor& self) { @@ -1508,7 +1743,7 @@ Tensor unsqueeze_qtensor(const Tensor& self, int64_t dim) { axis, quantizer->scalar_type()); } - return make_qtensor(self, std::get<0>(g), std::get<1>(g), quantizer); + return make_qtensor(self, g.sizes, g.strides, quantizer); } Tensor unsqueeze(const Tensor& self, int64_t dim) { @@ -1520,7 +1755,7 @@ Tensor unsqueeze(const Tensor& self, int64_t dim) { return unsqueeze_qtensor(self, dim); } else { auto g = inferUnsqueezeGeometry(self, dim); - return self.as_strided(std::get<0>(g), std::get<1>(g)); + return self.as_strided(g.sizes, g.strides); } } @@ -1528,7 +1763,7 @@ Tensor & unsqueeze_(Tensor& self, int64_t dim) { dim = maybe_wrap_dim(dim, self.dim() + 1); auto g = inferUnsqueezeGeometry(self, dim); - return self.as_strided_(std::get<0>(g), std::get<1>(g)); + return self.as_strided_(g.sizes, g.strides); } Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim) { @@ -1593,16 +1828,20 @@ Tensor flatten(const Tensor& self, DimnameList dims, Dimname out_dim) { return native::flatten(self, *dims.begin(), *(dims.end() - 1), out_dim); } +Tensor ravel(const Tensor& self) { + return self.reshape(-1); +} + Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::optional names) { dim = maybe_wrap_dim(dim, self.dim()); TORCH_CHECK(sizes.size() > 0, "unflatten: sizes must be non-empty"); TORCH_INTERNAL_ASSERT(!names || names->size() == sizes.size()); - auto numel = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies()); + const int64_t numel = prod_intlist(sizes); if (self.has_names()) { TORCH_CHECK(numel == self.size(dim), - "unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ", + "unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ", dim, " (", self.names()[dim], ": ", self.size(dim), ") in Tensor", self.names()); TORCH_CHECK(names, "unflatten: input is a named tensor but no names were given for unflattened sizes"); } else { @@ -1791,7 +2030,7 @@ Tensor diag(const Tensor& self, int64_t dimension) { } Tensor& diag_cpu_out(Tensor &result, const Tensor& self, int64_t dimension) { - AT_DISPATCH_ALL_TYPES(self.scalar_type(), "diag", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, self.scalar_type(), "diag", [&] { apply_diag(result, self, dimension); }); return result; @@ -1828,7 +2067,7 @@ Tensor movedim(const Tensor& self, IntArrayRef src, IntArrayRef dst) { DimVector normalized_dst(dst.size()); auto wrap_dims = [&self_dim](const IntArrayRef& vec, DimVector& normalized_vec) { - for (int i = 0; i < vec.size(); i++) { + for (size_t i = 0; i < vec.size(); i++) { normalized_vec[i] = maybe_wrap_dim(vec[i], self_dim); } }; @@ -1873,7 +2112,7 @@ Tensor movedim(const Tensor& self, IntArrayRef src, IntArrayRef dst) { // order = NA, NA, 0, NA, 1 // source_dims = -1, -1, 2, 3, 4 // destination_dims = 0, 1, -1, 3, -1 - for (int64_t i = 0; i < src.size(); ++i) { + for (size_t i = 0; i < src.size(); ++i) { order[normalized_dst[i]] = normalized_src[i]; source_dims[normalized_src[i]] = -1; destination_dims[normalized_dst[i]] = -1; @@ -1908,4 +2147,28 @@ Tensor movedim(const Tensor& self, int64_t src, int64_t dst) { return at::movedim(self, IntArrayRef{src}, IntArrayRef{dst}); } +Tensor moveaxis(const Tensor& self, IntArrayRef src, IntArrayRef dst) { + return at::movedim(self, src, dst); +} + +Tensor moveaxis(const Tensor& self, int64_t src, int64_t dst) { + return at::movedim(self, IntArrayRef{src}, IntArrayRef{dst}); +} + +Tensor swapaxes(const Tensor& self, int64_t axis0, int64_t axis1) { + return self.transpose(axis0, axis1); +} + +Tensor& swapaxes_(Tensor& self, int64_t axis0, int64_t axis1) { + return self.transpose_(axis0, axis1); +} + +Tensor swapdims(const Tensor& self, int64_t dim0, int64_t dim1) { + return self.transpose(dim0, dim1); +} + +Tensor& swapdims_(Tensor& self, int64_t dim0, int64_t dim1) { + return self.transpose_(dim0, dim1); +} + }} // at::native diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp index 1b86b3f2d634c..5c6ab40b0ad4c 100644 --- a/aten/src/ATen/native/TensorTransformations.cpp +++ b/aten/src/ATen/native/TensorTransformations.cpp @@ -61,15 +61,30 @@ Tensor flip_cpu(const Tensor& self, IntArrayRef dims) { } } - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, in_tensor.scalar_type(), "flip_cpu", [&] { - flip_cpu_kernel( - total_dims, - stride_contiguous_v, - flip_dims_b, - in_tensor, - out_tensor - ); - }); + if (in_tensor.is_quantized()) { + AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(in_tensor.scalar_type(), + "flip_quantized_cpu", [&] { + flip_cpu_kernel( + total_dims, + stride_contiguous_v, + flip_dims_b, + in_tensor, + out_tensor + ); + }); + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kHalf, kBFloat16, + in_tensor.scalar_type(), + "flip_cpu", [&] { + flip_cpu_kernel( + total_dims, + stride_contiguous_v, + flip_dims_b, + in_tensor, + out_tensor + ); + }); + } return out_tensor; } diff --git a/aten/src/ATen/native/TensorTransformations.h b/aten/src/ATen/native/TensorTransformations.h index 6c28bfbab41b8..aaacc0941a121 100644 --- a/aten/src/ATen/native/TensorTransformations.h +++ b/aten/src/ATen/native/TensorTransformations.h @@ -10,8 +10,11 @@ namespace at { namespace native { static inline void flip_check_errors(int64_t total_dims, int64_t flip_dims_size, IntArrayRef dims) { + if (flip_dims_size==0) { + return; + } // check if number of axis in dim is valid - if (flip_dims_size <= 0 || flip_dims_size > total_dims) { + if (flip_dims_size < 0 || flip_dims_size > total_dims) { TORCH_CHECK_INDEX(false, "flip dims size out of range, got flip dims size=", flip_dims_size); } diff --git a/aten/src/ATen/native/TestOps.cpp b/aten/src/ATen/native/TestOps.cpp index c89a7ee02221c..7a6f7c6e8e056 100644 --- a/aten/src/ATen/native/TestOps.cpp +++ b/aten/src/ATen/native/TestOps.cpp @@ -2,6 +2,7 @@ #include #include +#include namespace at { namespace native { @@ -42,5 +43,30 @@ Tensor _test_optional_floatlist( return output; } +// Test default strings can handle escape sequences properly (although commas are broken) +Tensor _test_string_default(const Tensor& dummy, std::string a, std::string b) { + const c10::string_view expect = "\"'\\"; + TORCH_CHECK(a == expect, "Default A failed"); + TORCH_CHECK(b == expect, "Default B failed"); + return dummy; +} + +// Test that overloads with ambiguity created by defaulted parameters work. +// The operator declared first should have priority always + +// Overload a +Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, int64_t b) { + TORCH_CHECK(a == 1); + TORCH_CHECK(b == 1); + return c10::scalar_to_tensor(1); +} + +// Overload b +Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, std::string b) { + TORCH_CHECK(a == 2); + TORCH_CHECK(b == "2"); + return c10::scalar_to_tensor(2); +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/TypeProperties.h b/aten/src/ATen/native/TypeProperties.h index 2e0c750c414ad..85ffed1ee07f8 100644 --- a/aten/src/ATen/native/TypeProperties.h +++ b/aten/src/ATen/native/TypeProperties.h @@ -10,9 +10,9 @@ struct ResultTypeState { c10::ScalarType zeroResult = ScalarType::Undefined; }; -CAFFE2_API ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state); -CAFFE2_API ScalarType result_type(const ResultTypeState& state); +TORCH_API ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state); +TORCH_API ScalarType result_type(const ResultTypeState& state); -CAFFE2_API ScalarType result_type(TensorList tensors); +TORCH_API ScalarType result_type(TensorList tensors); }} diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index f9af400ba2f4a..e2d1de5c07bfe 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -44,13 +44,29 @@ static inline Tensor& unary_op_impl_out(Tensor& result, const Tensor& self, Stub return result; } +template +static inline Tensor& unary_op_impl_float_out(Tensor& result, const Tensor& self, Stub& stub) { + auto iter = TensorIterator::unary_float_op(result, self); + stub(iter.device_type(), iter); + iter.cast_outputs(); + return result; +} + +template +Tensor unary_op_impl_float(const Tensor& self, Stub& stub) { + Tensor result; + auto iter = TensorIterator::unary_float_op(result, self); + stub(iter.device_type(), iter); + return iter.output(); +} + // An alternate version of unary_op_impl_out that follows the same pattern // for non-complex inputs, but returns a floating point tensor // for complex inputs by default. // Note: This is done by running the operation as usual and then copying the // operation's result to the expected result type. template -static inline Tensor& unary_op_impl_with_complex_to_float_out(Tensor& result, const Tensor& self, Stub& stub) { +static inline Tensor& unary_op_impl_with_complex_to_float_out(Tensor& result, const Tensor& self, Stub& stub, bool promotes_integer_to_float) { if (self.is_complex() && !result.is_complex()) { // Checks if the corresponding float type can be cast to the desired dtype const auto float_type = c10::toValueType(self.scalar_type()); @@ -69,6 +85,10 @@ static inline Tensor& unary_op_impl_with_complex_to_float_out(Tensor& result, co return result; } + if (promotes_integer_to_float) { + return unary_op_impl_float_out(result, self, stub); + } + return unary_op_impl_out(result, self, stub); } @@ -101,8 +121,8 @@ static inline Tensor& unary_op_impl_(Tensor& self, OutImpl& out_impl) { return out_impl(self, self); } -Tensor& acos_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, acos_stub); } -Tensor acos(const Tensor& self) { return unary_op_impl(self, at::acos_out); } +Tensor& acos_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, acos_stub); } +Tensor acos(const Tensor& self) { return unary_op_impl_float(self, acos_stub); } Tensor& acos_(Tensor& self) { return unary_op_impl_(self, at::acos_out); } // arccos, alias for acos @@ -133,8 +153,8 @@ Tensor& deg2rad_out(Tensor& result, const Tensor& self) { Tensor deg2rad(const Tensor& self) { return unary_op_impl(self, at::deg2rad_out); } Tensor& deg2rad_(Tensor& self) { return unary_op_impl_(self, at::deg2rad_out); } -Tensor& asin_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, asin_stub); } -Tensor asin(const Tensor& self) { return unary_op_impl(self, at::asin_out); } +Tensor& asin_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, asin_stub); } +Tensor asin(const Tensor& self) { return unary_op_impl_float(self, asin_stub); } Tensor& asin_(Tensor& self) { return unary_op_impl_(self, at::asin_out); } // arcsin, alias of asin @@ -142,8 +162,8 @@ Tensor& arcsin_out(Tensor& result, const Tensor& self) { return at::asin_out(res Tensor arcsin(const Tensor& self) { return self.asin(); } Tensor& arcsin_(Tensor& self) { return self.asin_(); } -Tensor& atan_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, atan_stub); } -Tensor atan(const Tensor& self) { return unary_op_impl(self, at::atan_out); } +Tensor& atan_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, atan_stub); } +Tensor atan(const Tensor& self) { return unary_op_impl_float(self, atan_stub); } Tensor& atan_(Tensor& self) { return unary_op_impl_(self, at::atan_out); } // arctan, alias of atan @@ -157,12 +177,15 @@ Tensor& arctan_(Tensor& self) { return self.atan_(); } // complex input. This makes sense mathematically since the absolute value // and angle of a complex number has no imaginary part. Tensor& abs_out(Tensor& result, const Tensor& self) { - return unary_op_impl_with_complex_to_float_out(result, self, abs_stub); + return unary_op_impl_with_complex_to_float_out(result, self, abs_stub, /*promotes_integer_to_float=*/false); } Tensor abs(const Tensor& self) { return unary_op_impl_with_complex_to_float(self, at::abs_out); } -Tensor& abs_(Tensor& self) { return unary_op_impl_(self, at::abs_out); } +Tensor& abs_(Tensor& self) { + TORCH_CHECK(!self.is_complex(), "In-place abs is not supported for complex tensors."); + return unary_op_impl_(self, at::abs_out); +} // Absolute, alias for abs Tensor& absolute_out(Tensor& result, const Tensor& self) { @@ -176,10 +199,16 @@ Tensor& absolute_(Tensor& self) { } Tensor& angle_out(Tensor& result, const Tensor& self) { - return unary_op_impl_with_complex_to_float_out(result, self, angle_stub); + return unary_op_impl_with_complex_to_float_out(result, self, angle_stub, /*promotes_integer_to_float=*/true); } Tensor angle(const Tensor& self) { - return unary_op_impl_with_complex_to_float(self, at::angle_out); + if (self.is_complex()) { + const auto float_type = c10::toValueType(self.scalar_type()); + Tensor result = at::empty({0}, self.options().dtype(float_type)); + return at::angle_out(result, self); + } + + return unary_op_impl_float(self, angle_stub); } Tensor real(const Tensor& self) { @@ -227,26 +256,30 @@ Tensor& ceil_out(Tensor& result, const Tensor& self) { Tensor ceil(const Tensor& self) { return unary_op_impl(self, at::ceil_out); } Tensor& ceil_(Tensor& self) { return unary_op_impl_(self, at::ceil_out); } -Tensor& exp_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, exp_stub); } -Tensor exp(const Tensor& self) { return unary_op_impl(self, at::exp_out); } +Tensor& exp_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, exp_stub); } +Tensor exp(const Tensor& self) { return unary_op_impl_float(self, exp_stub); } Tensor& exp_(Tensor& self) { return unary_op_impl_(self, at::exp_out); } -Tensor& exp2_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, exp2_stub); } -Tensor exp2(const Tensor& self) { return unary_op_impl(self, at::exp2_out); } +Tensor& exp2_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, exp2_stub); } +Tensor exp2(const Tensor& self) { return unary_op_impl_float(self, exp2_stub); } Tensor& exp2_(Tensor& self) { return unary_op_impl_(self, at::exp2_out); } -Tensor& expm1_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, expm1_stub); } -Tensor expm1(const Tensor& self) { return unary_op_impl(self, at::expm1_out); } +Tensor& expm1_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, expm1_stub); } +Tensor expm1(const Tensor& self) { return unary_op_impl_float(self, expm1_stub); } Tensor& expm1_(Tensor& self) { return unary_op_impl_(self, at::expm1_out); } -Tensor& erf_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, erf_stub); } -Tensor erf(const Tensor& self) { return unary_op_impl(self, at::erf_out); } +Tensor& erf_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, erf_stub); } +Tensor erf(const Tensor& self) { return unary_op_impl_float(self, erf_stub); } Tensor& erf_(Tensor& self) { return unary_op_impl_(self, at::erf_out); } -Tensor& erfc_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, erfc_stub); } -Tensor erfc(const Tensor& self) { return unary_op_impl(self, at::erfc_out); } +Tensor& erfc_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, erfc_stub); } +Tensor erfc(const Tensor& self) { return unary_op_impl_float(self, erfc_stub); } Tensor& erfc_(Tensor& self) { return unary_op_impl_(self, at::erfc_out); } +Tensor& erfinv_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, erfinv_stub); } +Tensor erfinv(const Tensor& self) { return unary_op_impl_float(self, erfinv_stub); } +Tensor& erfinv_(Tensor& self) { return unary_op_impl_(self, at::erfinv_out); } + Tensor& frac_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, frac_stub); } Tensor frac(const Tensor& self) { return unary_op_impl(self, at::frac_out); } Tensor& frac_(Tensor& self) { return unary_op_impl_(self, at::frac_out); } @@ -265,39 +298,47 @@ Tensor& i0_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(re Tensor i0(const Tensor& self) { return unary_op_impl(self, at::i0_out); } Tensor& i0_(Tensor& self) { return unary_op_impl_(self, at::i0_out); } -Tensor& log_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, log_stub); } -Tensor log(const Tensor& self) { return unary_op_impl(self, at::log_out); } +Tensor& log_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, log_stub); } +Tensor log(const Tensor& self) { return unary_op_impl_float(self, log_stub); } Tensor& log_(Tensor& self) { return unary_op_impl_(self, at::log_out); } -Tensor& log10_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, log10_stub); } -Tensor log10(const Tensor& self) { return unary_op_impl(self, at::log10_out); } +Tensor& log10_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, log10_stub); } +Tensor log10(const Tensor& self) { return unary_op_impl_float(self, log10_stub); } Tensor& log10_(Tensor& self) { return unary_op_impl_(self, at::log10_out); } -Tensor& log1p_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, log1p_stub); } -Tensor log1p(const Tensor& self) { return unary_op_impl(self, at::log1p_out); } +Tensor& log1p_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, log1p_stub); } +Tensor log1p(const Tensor& self) { return unary_op_impl_float(self, log1p_stub); } Tensor& log1p_(Tensor& self) { return unary_op_impl_(self, at::log1p_out); } -Tensor& log2_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, log2_stub); } -Tensor log2(const Tensor& self) { return unary_op_impl(self, at::log2_out); } +Tensor& log2_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, log2_stub); } +Tensor log2(const Tensor& self) { return unary_op_impl_float(self, log2_stub); } Tensor& log2_(Tensor& self) { return unary_op_impl_(self, at::log2_out); } Tensor& round_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, round_stub); } Tensor round(const Tensor& self) { return unary_op_impl(self, at::round_out); } Tensor& round_(Tensor& self) { return unary_op_impl_(self, at::round_out); } -Tensor& digamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, digamma_stub); } -Tensor digamma(const Tensor& self) { return unary_op_impl(self, digamma_out); } +Tensor& digamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, digamma_stub); } +Tensor digamma(const Tensor& self) { return unary_op_impl_float(self, digamma_stub); } Tensor& digamma_(Tensor& self) { return unary_op_impl_(self, digamma_out); } -Tensor& reciprocal_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, reciprocal_stub); } -Tensor reciprocal(const Tensor& self) { return unary_op_impl(self, at::reciprocal_out); } +Tensor& reciprocal_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, reciprocal_stub); } +Tensor reciprocal(const Tensor& self) { return unary_op_impl_float(self, reciprocal_stub); } Tensor& reciprocal_(Tensor& self) { return unary_op_impl_(self, at::reciprocal_out); } -Tensor& rsqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, rsqrt_stub); } -Tensor rsqrt(const Tensor& self) { return unary_op_impl(self, at::rsqrt_out); } +Tensor& rsqrt_out(Tensor& result, const Tensor& self) { + return unary_op_impl_float_out(result, self, rsqrt_stub); +} +Tensor rsqrt(const Tensor& self) { + return unary_op_impl_float(self, rsqrt_stub); +} Tensor& rsqrt_(Tensor& self) { return unary_op_impl_(self, at::rsqrt_out); } -Tensor& sign_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sign_stub); } +Tensor& sign_out(Tensor& result, const Tensor& self) { + TORCH_CHECK(!self.is_complex(), + "Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead."); + return unary_op_impl_out(result, self, sign_stub); +} Tensor sign(const Tensor& self) { return unary_op_impl(self, at::sign_out); } Tensor& sign_(Tensor& self) { return unary_op_impl_(self, at::sign_out); } @@ -312,24 +353,28 @@ Tensor& sgn_out(Tensor& result, const Tensor& self) { Tensor sgn(const Tensor& self) { return unary_op_impl(self, at::sgn_out); } Tensor& sgn_(Tensor& self) { return unary_op_impl_(self, at::sgn_out); } -Tensor& sin_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sin_stub); } -Tensor sin(const Tensor& self) { return unary_op_impl(self, at::sin_out); } +Tensor& sin_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, sin_stub); } +Tensor sin(const Tensor& self) { return unary_op_impl_float(self, sin_stub); } Tensor& sin_(Tensor& self) { return unary_op_impl_(self, at::sin_out); } -Tensor& cos_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, cos_stub); } -Tensor cos(const Tensor& self) { return unary_op_impl(self, at::cos_out); } +Tensor& cos_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, cos_stub); } +Tensor cos(const Tensor& self) { return unary_op_impl_float(self, cos_stub); } Tensor& cos_(Tensor& self) { return unary_op_impl_(self, at::cos_out); } -Tensor& sinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sinh_stub); } -Tensor sinh(const Tensor& self) { return unary_op_impl(self, at::sinh_out); } +Tensor& sinc_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, sinc_stub); } +Tensor sinc(const Tensor& self) { return unary_op_impl_float(self, sinc_stub); } +Tensor& sinc_(Tensor& self) { return unary_op_impl_(self, at::sinc_out); } + +Tensor& sinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, sinh_stub); } +Tensor sinh(const Tensor& self) { return unary_op_impl_float(self, sinh_stub); } Tensor& sinh_(Tensor& self) { return unary_op_impl_(self, at::sinh_out); } -Tensor& cosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, cosh_stub); } -Tensor cosh(const Tensor& self) { return unary_op_impl(self, at::cosh_out); } +Tensor& cosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, cosh_stub); } +Tensor cosh(const Tensor& self) { return unary_op_impl_float(self, cosh_stub); } Tensor& cosh_(Tensor& self) { return unary_op_impl_(self, at::cosh_out); } -Tensor& acosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, acosh_stub); } -Tensor acosh(const Tensor& self) { return unary_op_impl(self, at::acosh_out); } +Tensor& acosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, acosh_stub); } +Tensor acosh(const Tensor& self) { return unary_op_impl_float(self, acosh_stub); } Tensor& acosh_(Tensor& self) { return unary_op_impl_(self, at::acosh_out); } // arccosh, alias for acosh @@ -337,8 +382,8 @@ Tensor& arccosh_out(Tensor& result, const Tensor& self) { return at::acosh_out(r Tensor arccosh(const Tensor& self) { return at::acosh(self); } Tensor& arccosh_(Tensor& self) { return at::acosh_(self); } -Tensor& asinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, asinh_stub); } -Tensor asinh(const Tensor& self) { return unary_op_impl(self, at::asinh_out); } +Tensor& asinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, asinh_stub); } +Tensor asinh(const Tensor& self) { return unary_op_impl_float(self, asinh_stub); } Tensor& asinh_(Tensor& self) { return unary_op_impl_(self, at::asinh_out); } // arcsinh, alias for asinh @@ -346,8 +391,8 @@ Tensor& arcsinh_out(Tensor& result, const Tensor& self) { return at::asinh_out(r Tensor arcsinh(const Tensor& self) { return self.asinh(); } Tensor& arcsinh_(Tensor& self) { return self.asinh_(); } -Tensor& atanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, atanh_stub); } -Tensor atanh(const Tensor& self) { return unary_op_impl(self, at::atanh_out); } +Tensor& atanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, atanh_stub); } +Tensor atanh(const Tensor& self) { return unary_op_impl_float(self, atanh_stub); } Tensor& atanh_(Tensor& self) { return unary_op_impl_(self, at::atanh_out); } // arctanh, alias for atanh @@ -355,15 +400,15 @@ Tensor& arctanh_out(Tensor& result, const Tensor& self) { return at::atanh_out(r Tensor arctanh(const Tensor& self) { return self.atanh(); } Tensor& arctanh_(Tensor& self) { return self.atanh_(); } -Tensor& sqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sqrt_stub); } -Tensor sqrt(const Tensor& self) { return unary_op_impl(self, at::sqrt_out); } +Tensor& sqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, sqrt_stub); } +Tensor sqrt(const Tensor& self) { return unary_op_impl_float(self, sqrt_stub); } Tensor& sqrt_(Tensor& self) { return unary_op_impl_(self, at::sqrt_out); } Tensor square(const Tensor& self) { return at::pow(self, 2); } Tensor& square_(Tensor& self) { return at::pow_out(self, self, 2); } -Tensor& sigmoid_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sigmoid_stub); } -Tensor sigmoid(const Tensor& self) { return unary_op_impl(self, at::sigmoid_out); } +Tensor& sigmoid_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, sigmoid_stub); } +Tensor sigmoid(const Tensor& self) { return unary_op_impl_float(self, sigmoid_stub); } Tensor& sigmoid_(Tensor& self) { return unary_op_impl_(self, at::sigmoid_out); } Tensor& logit_out( @@ -384,12 +429,53 @@ Tensor& logit_(Tensor& self, c10::optional eps) { return at::logit_out(self, self, eps); } -Tensor& tanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, tanh_stub); } -Tensor tanh(const Tensor& self) { return unary_op_impl(self, at::tanh_out); } +Tensor& nan_to_num_out( + Tensor& result, + const Tensor& self, + c10::optional nan, + c10::optional pos_inf, + c10::optional neg_inf) { + TORCH_CHECK( + self.scalar_type() == result.scalar_type(), + "nan_to_num: dtype of out: ", + result.scalar_type(), + " should be same as input: ", + self.scalar_type()); + + if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) { + result.resize_as_(self); + result.copy_(self); + return result; + } + + auto iter = TensorIterator::unary_op(result, self); + nan_to_num_stub(iter.device_type(), iter, nan, pos_inf, neg_inf); + return result; +} + +Tensor nan_to_num( + const Tensor& self, + c10::optional nan, + c10::optional pos_inf, + c10::optional neg_inf) { + auto result = at::empty_like(self); + return at::nan_to_num_out(result, self, nan, pos_inf, neg_inf); +} + +Tensor& nan_to_num_( + Tensor& self, + c10::optional nan, + c10::optional pos_inf, + c10::optional neg_inf) { + return at::nan_to_num_out(self, self, nan, pos_inf, neg_inf); +} + +Tensor& tanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, tanh_stub); } +Tensor tanh(const Tensor& self) { return unary_op_impl_float(self, tanh_stub); } Tensor& tanh_(Tensor& self) { return unary_op_impl_(self, at::tanh_out); } -Tensor& tan_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, tan_stub); } -Tensor tan(const Tensor& self) { return unary_op_impl(self, at::tan_out); } +Tensor& tan_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, tan_stub); } +Tensor tan(const Tensor& self) { return unary_op_impl_float(self, tan_stub); } Tensor& tan_(Tensor& self) { return unary_op_impl_(self, at::tan_out); } Tensor& trunc_out(Tensor& result, const Tensor& self) { @@ -463,7 +549,6 @@ Tensor signbit(const Tensor& self) { } Tensor& clamp_out(Tensor& result, const Tensor& self, optional min, optional max) { - TORCH_CHECK(!self.is_complex(), "clamp is not yet implemented for complex tensors."); if (min && max) { TORCH_CHECK(self.layout() == Layout::Strided, "clamp only supports strided layout, got: ", self.layout()); @@ -474,7 +559,7 @@ Tensor& clamp_out(Tensor& result, const Tensor& self, optional min, opti } else if (min) { at::clamp_min_out(result, self, *min); } else { - AT_ERROR("At least one of 'min' or 'max' must not be None"); + TORCH_CHECK(false, "At least one of 'min' or 'max' must not be None"); } return result; } @@ -489,7 +574,6 @@ Tensor& clamp_(Tensor& self, optional min, optional max) { } Tensor& clamp_max_out(Tensor& result, const Tensor& self, Scalar max) { - TORCH_CHECK(!self.is_complex(), "clamp is not yet implemented for complex tensors."); TORCH_CHECK(self.layout() == Layout::Strided, "clamp_max only supports strided layout, got: ", self.layout()); auto iter = TensorIterator::unary_op(result, self); @@ -507,7 +591,6 @@ Tensor& clamp_max_(Tensor& self, Scalar max) { } Tensor& clamp_min_out(Tensor& result, const Tensor& self, Scalar min) { - TORCH_CHECK(!self.is_complex(), "clamp is not yet implemented for complex tensors."); TORCH_CHECK(self.layout() == Layout::Strided, "clamp_min only supports strided layout, got: ", self.layout()); auto iter = TensorIterator::unary_op(result, self); @@ -605,7 +688,6 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) { IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cpu, CPU) \ IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cuda, CUDA) -IMPLEMENT_UNARY_OP_VEC_CUDA(erfinv) IMPLEMENT_UNARY_OP_VEC_CUDA(lgamma) DEFINE_DISPATCH(abs_stub); @@ -642,6 +724,7 @@ DEFINE_DISPATCH(log1p_stub); DEFINE_DISPATCH(log2_stub); DEFINE_DISPATCH(logical_not_stub); DEFINE_DISPATCH(neg_stub); +DEFINE_DISPATCH(nan_to_num_stub); DEFINE_DISPATCH(polygamma_stub); DEFINE_DISPATCH(reciprocal_stub); DEFINE_DISPATCH(round_stub); @@ -652,6 +735,7 @@ DEFINE_DISPATCH(sign_stub); DEFINE_DISPATCH(signbit_stub); DEFINE_DISPATCH(sgn_stub); DEFINE_DISPATCH(sin_stub); +DEFINE_DISPATCH(sinc_stub); DEFINE_DISPATCH(sinh_stub); DEFINE_DISPATCH(sqrt_stub); DEFINE_DISPATCH(tan_stub); diff --git a/aten/src/ATen/native/UnaryOps.h b/aten/src/ATen/native/UnaryOps.h index 0dcd5a0b94735..d92864e6fb2af 100644 --- a/aten/src/ATen/native/UnaryOps.h +++ b/aten/src/ATen/native/UnaryOps.h @@ -55,6 +55,7 @@ DECLARE_DISPATCH(unary_fn, sign_stub); DECLARE_DISPATCH(unary_fn, signbit_stub); DECLARE_DISPATCH(unary_fn, sgn_stub); DECLARE_DISPATCH(unary_fn, sin_stub); +DECLARE_DISPATCH(unary_fn, sinc_stub); DECLARE_DISPATCH(unary_fn, sinh_stub); DECLARE_DISPATCH(unary_fn, sqrt_stub); DECLARE_DISPATCH(unary_fn, tan_stub); @@ -76,7 +77,16 @@ DECLARE_DISPATCH(void(*)(TensorIterator&, c10::optional), random_full DECLARE_DISPATCH(void(*)(TensorIterator&, c10::optional), random_stub); DECLARE_DISPATCH(void(*)(TensorIterator&, const int64_t), polygamma_stub); DECLARE_DISPATCH(void(*)(TensorIterator&, Scalar a, Scalar b), clamp_stub); -DECLARE_DISPATCH(void(*)(Tensor&, const Tensor&, int64_t, bool, c10::optional), multinomial_stub); +DECLARE_DISPATCH( + void (*)(Tensor&, const Tensor&, int64_t, c10::optional), + multinomial_with_replacement_stub); +DECLARE_DISPATCH( + void (*)( + TensorIterator&, + c10::optional, + c10::optional, + c10::optional), + nan_to_num_stub); // Missing unary functions // digamma diff --git a/aten/src/ATen/native/UpSampleNearest1d.cpp b/aten/src/ATen/native/UpSampleNearest1d.cpp index e6e0033a0b847..6478bbb58eafb 100644 --- a/aten/src/ATen/native/UpSampleNearest1d.cpp +++ b/aten/src/ATen/native/UpSampleNearest1d.cpp @@ -1,47 +1,12 @@ #include #include #include +#include namespace at { -namespace native { -namespace { +namespace meta { -static void upsample_nearest1d_out_cpu_template( - Tensor& output, - const Tensor& input, - IntArrayRef output_size, - c10::optional scales) { - TORCH_CHECK( - output_size.size() == 1, - "It is expected output_size equals to 1, but got size ", - output_size.size()); - - int64_t output_width = output_size[0]; - - int64_t nbatch = input.size(0); - int64_t channels = input.size(1); - int64_t input_width = input.size(2); - - upsample_1d_shape_check( - input, - Tensor(), - nbatch, - channels, - input_width, - output_width); - - output.resize_({nbatch, channels, output_width}); - - AT_ASSERT(input_width > 0 && output_width > 0); - upsample_nearest1d_kernel(kCPU, output, input, scales); -} - -static void upsample_nearest1d_backward_out_cpu_template( - Tensor& grad_input, - const Tensor& grad_output, - IntArrayRef output_size, - IntArrayRef input_size, - c10::optional scales) { +static std::array upsample_nearest1d_common_check(IntArrayRef input_size, IntArrayRef output_size) { TORCH_CHECK( output_size.size() == 1, "It is expected output_size equals to 1, but got size ", @@ -58,90 +23,95 @@ static void upsample_nearest1d_backward_out_cpu_template( int64_t channels = input_size[1]; int64_t input_width = input_size[2]; - upsample_1d_shape_check( - Tensor(), - grad_output, - nbatch, - channels, + TORCH_CHECK( + input_width > 0 && output_width > 0, + "Input and output sizes should be greater than 0, but got input (W: ", input_width, - output_width); - - grad_input.resize_({nbatch, channels, input_width}); - grad_input.zero_(); + ") and output (W: ", + output_width, + ")"); - upsample_nearest1d_backward_kernel(kCPU, grad_input, grad_output, scales); + return {nbatch, channels, output_width}; } -} // namespace -Tensor& upsample_nearest1d_out_cpu( - Tensor& output, - const Tensor& input, - IntArrayRef output_size, - c10::optional scales) { - upsample_nearest1d_out_cpu_template(output, input, output_size, scales); - return output; +TORCH_META_FUNC(upsample_nearest1d) ( + const Tensor& input, IntArrayRef output_size, c10::optional scales +) { + auto full_output_size = upsample_nearest1d_common_check(input.sizes(), output_size); + + // Allow for empty batch size but not other dimensions + TORCH_CHECK( + (input.size(1) != 0 && input.size(2) != 0) && input.dim() == 3, + "Non-empty 3D data tensor expected but got a tensor with sizes ", + input.sizes()); + + set_output(full_output_size, input.options()); } -Tensor upsample_nearest1d_cpu( - const Tensor& input, - IntArrayRef output_size, - c10::optional scales) { - auto output = at::empty({0}, input.options()); - upsample_nearest1d_out_cpu_template(output, input, output_size, scales); - return output; +TORCH_META_FUNC(upsample_nearest1d_backward) ( + const Tensor& grad_output, IntArrayRef output_size, IntArrayRef input_size, c10::optional scales +) { + auto full_output_size = upsample_nearest1d_common_check(input_size, output_size); + + check_dim_size(grad_output, 3, 0, full_output_size[0]); + check_dim_size(grad_output, 3, 1, full_output_size[1]); + check_dim_size(grad_output, 3, 2, full_output_size[2]); + + set_output(input_size, grad_output.options()); } -Tensor& upsample_nearest1d_backward_out_cpu( - Tensor& grad_input, - const Tensor& grad_output, +} // namespace meta + + +namespace native { + +TORCH_IMPL_FUNC(upsample_nearest1d_out_cpu) ( + const Tensor& input, IntArrayRef output_size, - IntArrayRef input_size, - c10::optional scales) { - upsample_nearest1d_backward_out_cpu_template( - grad_input, grad_output, output_size, input_size, scales); - return grad_input; + c10::optional scales, + Tensor& output +) { + upsample_nearest1d_kernel(kCPU, output, input, scales); } -Tensor upsample_nearest1d_backward_cpu( +TORCH_IMPL_FUNC(upsample_nearest1d_backward_out_cpu) ( const Tensor& grad_output, IntArrayRef output_size, IntArrayRef input_size, - c10::optional scales) { - auto grad_input = at::zeros(input_size, grad_output.options()); - upsample_nearest1d_backward_out_cpu_template( - grad_input, grad_output, output_size, input_size, scales); - return grad_input; + c10::optional scales, + Tensor& grad_input +) { + grad_input.zero_(); + upsample_nearest1d_backward_kernel(kCPU, grad_input, grad_output, scales); } using at::native::upsample::compute_output_size; using at::native::upsample::get_scale_value; -Tensor upsample_nearest1d_cpu( +// vec variants + +Tensor upsample_nearest1d( const Tensor& input, c10::optional output_size, c10::optional> scale_factors) { - auto output = at::empty({0}, input.options()); auto osize = compute_output_size(input.sizes(), output_size, scale_factors); auto scale_w = get_scale_value(scale_factors, 0); - upsample_nearest1d_out_cpu_template(output, input, osize, scale_w); - return output; + return at::upsample_nearest1d(input, osize, scale_w); } -Tensor upsample_nearest1d_backward_cpu( +Tensor upsample_nearest1d_backward( const Tensor& grad_output, c10::optional output_size, IntArrayRef input_size, c10::optional> scale_factors) { auto osize = compute_output_size(input_size, output_size, scale_factors); auto scale_w = get_scale_value(scale_factors, 0); - auto grad_input = at::zeros(input_size, grad_output.options()); - upsample_nearest1d_backward_out_cpu_template( - grad_input, grad_output, osize, input_size, scale_w); - return grad_input; + return at::upsample_nearest1d_backward(grad_output, osize, input_size, scale_w); } DEFINE_DISPATCH(upsample_nearest1d_kernel); DEFINE_DISPATCH(upsample_nearest1d_backward_kernel); } // namespace native + } // namespace at diff --git a/aten/src/ATen/native/VariableMethodStubs.cpp b/aten/src/ATen/native/VariableMethodStubs.cpp index 7d5cea725cf10..e7d65dc0967df 100644 --- a/aten/src/ATen/native/VariableMethodStubs.cpp +++ b/aten/src/ATen/native/VariableMethodStubs.cpp @@ -8,7 +8,7 @@ namespace at { namespace native { -void backward(const Tensor& self, const Tensor& gradient, c10::optional keep_graph, bool create_graph) { +void _backward(const Tensor& self, TensorList inputs, const Tensor& gradient, c10::optional keep_graph, bool create_graph) { AT_ERROR("backward is not implemented for Tensor"); } @@ -40,5 +40,9 @@ void retain_grad(Tensor& self) { AT_ERROR("retain_grad is not implemented for Tensor"); } +Tensor _fw_primal(const Tensor& self, int64_t level) { + AT_ERROR("_fw_primal is not implemented for Tensor"); +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp index 2f4f082b22b1c..e9af9a7e7d961 100644 --- a/aten/src/ATen/native/cpu/Activation.cpp +++ b/aten/src/ATen/native/cpu/Activation.cpp @@ -226,7 +226,7 @@ void elu_kernel(TensorIterator& it, Scalar alpha, Scalar scale, Scalar input_sca }); } -void elu_backward_kernel(TensorIterator& it, Scalar alpha, Scalar scale, Scalar input_scale) { +void elu_backward_kernel(TensorIterator& it, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result) { AT_DISPATCH_FLOATING_TYPES(it.dtype(), "elu_backward_cpu", [&]() { using Vec = Vec256; auto negcoef = alpha.to() * scale.to(); @@ -238,15 +238,23 @@ void elu_backward_kernel(TensorIterator& it, Scalar alpha, Scalar scale, Scalar const Vec zero_vec(static_cast(0)); cpu_kernel_vec( it, - [negcoef, negiptcoef, poscoef](scalar_t a, scalar_t b) -> scalar_t { - return b <= scalar_t(0) ? a * negiptcoef * (b + negcoef) : a * poscoef; + [negcoef, negiptcoef, poscoef, is_result](scalar_t a, scalar_t b) -> scalar_t { + if (is_result) { + return b <= scalar_t(0) ? a * negiptcoef * (b + negcoef) : a * poscoef; + } else { + return b <= scalar_t(0) ? a * negiptcoef * negcoef * std::exp(b * negiptcoef): a * poscoef; + } }, - [&negcoef_vec, &negiptcoef_vec, &poscoef_vec, &zero_vec](Vec a, Vec b) -> Vec { + [&negcoef_vec, &negiptcoef_vec, &poscoef_vec, &zero_vec, is_result](Vec a, Vec b) -> Vec { auto cmp = (b > zero_vec); - if (!cmp.zero_mask()) { // only a * poscoef (which is very quick) needs to be computed - return a * poscoef_vec; + if (is_result) { + if (!cmp.zero_mask()) { // only a * poscoef (which is very quick) needs to be computed + return a * poscoef_vec; + } else { + return Vec::blendv(a * negiptcoef_vec * (b + negcoef_vec), a * poscoef_vec, cmp); + } } else { - return Vec::blendv(a * negiptcoef_vec * (b + negcoef_vec), a * poscoef_vec, cmp); + return Vec::blendv(a * negiptcoef_vec * negcoef_vec * (b * negiptcoef_vec).exp(), a * poscoef_vec, cmp); } } ); diff --git a/aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp b/aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp new file mode 100644 index 0000000000000..b5ed77f6e400b --- /dev/null +++ b/aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp @@ -0,0 +1,311 @@ +#include + +#include +#include +#include +#include +#include + +namespace at { namespace native { + +namespace { + +template +void cpu_adaptive_avg_pool( + Tensor& output_, + const Tensor& input_, + IntArrayRef output_size) { + auto input = input_.contiguous(); + auto output = output_.contiguous(); + + auto input_data = input.data_ptr(); + auto output_data = output.data_ptr(); + + int64_t ndim = input.ndimension(); + // treat batch size and channels as one dimension + int64_t channels = ndim == 3 ? input.size(0) : input.size(0) * input.size(1); + int64_t input_height = input.size(-2); + int64_t input_width = input.size(-1); + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; + + // parallel on dim of N, C + at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) { + for (int64_t c = begin; c < end; c++) { + scalar_t* input_ptr = input_data + c * input_height * input_width; + scalar_t* output_ptr = output_data + c * output_height * output_width; + + for (int64_t oh = 0; oh < output_height; oh++) { + int64_t ih0 = start_index(oh, output_height, input_height); + int64_t ih1 = end_index(oh, output_height, input_height); + int64_t kh = ih1 - ih0; + + for (int64_t ow = 0; ow < output_width; ow++) { + int64_t iw0 = start_index(ow, output_width, input_width); + int64_t iw1 = end_index(ow, output_width, input_width); + int64_t kw = iw1 - iw0; + + // compute local average + scalar_t sum = 0; + for (int64_t ih = ih0; ih < ih1; ih++) { + for (int64_t iw = iw0; iw < iw1; iw++) { + sum += input_ptr[ih * input_width + iw]; + } + } + output_ptr[oh * output_width + ow] = sum / kh / kw; + } + } + } + }); + + if (!output_.is_contiguous()) { + output_.copy_(output); + } +} + +template +void cpu_adaptive_avg_pool_channels_last( + Tensor& output_, + const Tensor& input_, + IntArrayRef output_size) { + auto memory_format = at::MemoryFormat::ChannelsLast; + auto input = input_.contiguous(memory_format); + auto output = output_.contiguous(memory_format); + + auto input_data = input.data_ptr(); + auto output_data = output.data_ptr(); + + int64_t nbatch = input.size(0); + int64_t channels = input.size(1); + int64_t input_height = input.size(2); + int64_t input_width = input.size(3); + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; + + using Vec = vec256::Vec256; + // parallel on dim N, H, W + at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) { + int64_t n = 0; + int64_t oh = 0; + int64_t ow = 0; + data_index_init(begin, n, nbatch, oh, output_height, ow, output_width); + + for (int64_t i = begin; i < end; i++) { + int64_t ih0 = start_index(oh, output_height, input_height); + int64_t ih1 = end_index(oh, output_height, input_height); + int64_t kh = ih1 - ih0; + + int64_t iw0 = start_index(ow, output_width, input_width); + int64_t iw1 = end_index(ow, output_width, input_width); + int64_t kw = iw1 - iw0; + + scalar_t* out = output_data + i * channels; + int64_t size = channels; + + // Note: For oridinary usage scenario, each out lane should + // fit in L1 cache; otherwise consider block dim C. + // Pass I: zero the out lane + int64_t d1 = 0; + for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) { + Vec out_vec = Vec(scalar_t(0)); + out_vec.store(out + d1); + } + for (; d1 < size; d1++) { + out[d1] = scalar_t(0); + } + // Pass II: compute local sum + for (int64_t ih = ih0; ih < ih1; ih++) { + for (int64_t iw = iw0; iw < iw1; iw++) { + scalar_t* in = input_data + n * input_height * input_width * channels + + ih * input_width * channels + iw * channels; + + int64_t d2 = 0; + for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) { + Vec out_vec = Vec::loadu(out + d2) + Vec::loadu(in + d2); + out_vec.store(out + d2); + } + for (; d2 < size; d2++) { + out[d2] += in[d2]; + } + } + } + // Pass III: compute local average + int64_t d3 = 0; + for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) { + Vec out_vec = Vec::loadu(out + d3) / Vec(scalar_t(kh * kw)); + out_vec.store(out + d3); + } + for (; d3 < size; d3++) { + out[d3] = out[d3] / kh / kw; + } + + // move on to next output index + data_index_step(n, nbatch, oh, output_height, ow, output_width); + } + }); + + if (!output_.is_contiguous(memory_format)) { + output_.copy_(output); + } +} + +template +void cpu_adaptive_avg_pool_backward( + Tensor& grad_input_, + const Tensor& grad_output_) { + auto grad_output = grad_output_.contiguous(); + auto grad_input = grad_input_.contiguous(); + + auto grad_output_data = grad_output.data_ptr(); + auto grad_input_data = grad_input.data_ptr(); + + int64_t ndim = grad_output.ndimension(); + // treat batch size and channels as one dimension + int64_t channels = ndim == 3 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1); + int64_t input_height = grad_input.size(-2); + int64_t input_width = grad_input.size(-1); + int64_t output_height = grad_output.size(-2); + int64_t output_width = grad_output.size(-1); + + // parallel on dim of N, C + at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) { + for (int64_t c = begin; c < end; c++) { + scalar_t* grad_input_ptr = grad_input_data + c * input_height * input_width; + scalar_t* grad_output_ptr = grad_output_data + c * output_height * output_width; + + for (int64_t oh = 0; oh < output_height; oh++) { + int64_t ih0 = start_index(oh, output_height, input_height); + int64_t ih1 = end_index(oh, output_height, input_height); + int64_t kh = ih1 - ih0; + + for (int64_t ow = 0; ow < output_width; ow++) { + int64_t iw0 = start_index(ow, output_width, input_width); + int64_t iw1 = end_index(ow, output_width, input_width); + int64_t kw = iw1 - iw0; + + scalar_t grad_delta = grad_output_ptr[oh * output_width + ow] / kh / kw; + for (int64_t ih = ih0; ih < ih1; ih++) { + for (int64_t iw = iw0; iw < iw1; iw++) { + grad_input_ptr[ih * input_width + iw] += grad_delta; + } + } + } + } + } + }); + + if (!grad_input_.is_contiguous()) { + grad_input_.copy_(grad_input); + } +} + +template +void cpu_adaptive_avg_pool_backward_channels_last( + Tensor& grad_input_, + const Tensor& grad_output_) { + auto memory_format = at::MemoryFormat::ChannelsLast; + auto grad_input = grad_input_.contiguous(memory_format); + auto grad_output = grad_output_.contiguous(memory_format); + + auto grad_input_data = grad_input.data_ptr(); + auto grad_output_data = grad_output.data_ptr(); + + int64_t nbatch = grad_input.size(0); + int64_t channels = grad_input.size(1); + int64_t input_height = grad_input.size(2); + int64_t input_width = grad_input.size(3); + int64_t output_height = grad_output.size(2); + int64_t output_width = grad_output.size(3); + + using Vec = vec256::Vec256; + // parallel on dim N + at::parallel_for(0, nbatch, 0, [&](int64_t begin, int64_t end) { + for (int64_t n = begin; n < end; n++) { + scalar_t* grad_input_ptr = grad_input_data + n * input_height * input_width * channels; + scalar_t* grad_output_ptr = grad_output_data + n * output_height * output_width * channels; + + for (int64_t oh = 0; oh < output_height; oh++) { + int64_t ih0 = start_index(oh, output_height, input_height); + int64_t ih1 = end_index(oh, output_height, input_height); + int64_t kh = ih1 - ih0; + + for (int64_t ow = 0; ow < output_width; ow++) { + int64_t iw0 = start_index(ow, output_width, input_width); + int64_t iw1 = end_index(ow, output_width, input_width); + int64_t kw = iw1 - iw0; + + scalar_t* gout = grad_output_ptr + oh * output_width * channels + ow * channels; + int64_t size = channels; + for (int64_t ih = ih0; ih < ih1; ih++) { + for (int64_t iw = iw0; iw < iw1; iw++) { + scalar_t* gin = grad_input_ptr + ih * input_width * channels + iw * channels; + + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec gin_vec = Vec::loadu(gin + d) + Vec::loadu(gout + d) / Vec(scalar_t(kh * kw)); + gin_vec.store(gin + d); + } + for (; d < size; d++) { + gin[d] += gout[d] / kw / kw; + } + } + } + } + } + } + }); + + if (!grad_input_.is_contiguous(memory_format)) { + grad_input_.copy_(grad_input); + } +} + +void adaptive_avg_pool2d_kernel_impl( + Tensor& output, + const Tensor& input, + IntArrayRef output_size) { + switch (input.suggest_memory_format()) { + case at::MemoryFormat::Contiguous: { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_avg_pool2d", [&] { + cpu_adaptive_avg_pool(output, input, output_size); + }); + break; + } + case at::MemoryFormat::ChannelsLast: { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_avg_pool2d_channels_last", [&]{ + cpu_adaptive_avg_pool_channels_last(output, input, output_size); + }); + break; + } + default: + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); + } +} + +void adapative_avg_pool2d_backward_kernel_impl( + Tensor& grad_input, + const Tensor& grad_output) { + switch (grad_output.suggest_memory_format()) { + case at::MemoryFormat::Contiguous: { + AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "adaptive_avg_pool2d_backward", [&] { + cpu_adaptive_avg_pool_backward(grad_input, grad_output); + }); + break; + } + case at::MemoryFormat::ChannelsLast: { + AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "adaptive_avg_pool2d_backward_channels_last", [&]{ + cpu_adaptive_avg_pool_backward_channels_last(grad_input, grad_output); + }); + break; + } + default: + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); + } +} + +} // anonymous namespace + +REGISTER_DISPATCH(adaptive_avg_pool2d_kernel, &adaptive_avg_pool2d_kernel_impl); +REGISTER_DISPATCH(adaptive_avg_pool2d_backward_kernel, &adapative_avg_pool2d_backward_kernel_impl); + +}} // at::native diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 67a961401fb00..12301dc4a38e4 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -21,7 +21,7 @@ using namespace vec256; // Note: Undefined behavior when performing addition is intentionally // ignored. -void add_kernel(TensorIterator& iter, Scalar alpha_scalar) { +void add_kernel(TensorIteratorBase& iter, Scalar alpha_scalar) { if (iter.dtype() == ScalarType::Bool) { using scalar_t = bool; auto alpha = alpha_scalar.to(); @@ -116,8 +116,8 @@ void div_kernel(TensorIterator& iter) { } void remainder_kernel(TensorIterator& iter) { - if (isIntegralType(iter.dtype(), /*includeBool*/ false)) { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "remainder_cpu", [&]() { + if (isIntegralType(iter.common_dtype(), /*includeBool*/ false)) { + AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "remainder_cpu", [&]() { cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t { TORCH_CHECK(b != 0, "ZeroDivisionError"); scalar_t r = a % b; @@ -128,7 +128,7 @@ void remainder_kernel(TensorIterator& iter) { }); }); } else { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "remainder_cpu", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "remainder_cpu", [&]() { cpu_kernel_vec(iter, [=](scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { scalar_t mod = std::fmod(a, b); @@ -234,17 +234,16 @@ void lshift_kernel(TensorIterator& iter) { } void logical_and_kernel(TensorIterator& iter) { - // We use if-else here specifically for bool instead of using iter.common_dtype() like the CUDA implementation because - // common_dtype() is unavailable for bfloat16. + // See Note [special-case bool outputs] if (iter.dtype() == ScalarType::Bool) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "logical_and_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_and_cpu", [&]() { cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool { return a && b; }); }); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "logical_and_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.common_dtype(), "logical_and_cpu", [&]() { cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t { return static_cast(a && b); @@ -254,37 +253,35 @@ void logical_and_kernel(TensorIterator& iter) { } void logical_or_kernel(TensorIterator& iter) { - // We use if-else here specifically for bool instead of using iter.common_dtype() like the CUDA implementation because - // common_dtype() is unavailable for bfloat16. + // See Note [special-case bool outputs] if (iter.dtype() == ScalarType::Bool) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "logical_or_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_or_cpu", [&]() { cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool { return a || b; }); }); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.dtype(), "logical_or_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_or_cpu", [&]() { cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t { return static_cast(a || b); }); - }); + }); } } void logical_xor_kernel(TensorIterator& iter) { - // We use if-else here specifically for bool instead of using iter.common_dtype() like the CUDA implementation because - // common_dtype() is unavailable for bfloat16. + // See Note [special-case bool outputs] if (iter.dtype() == ScalarType::Bool) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "logical_xor_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_xor_cpu", [&]() { cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool { return bool(a) != bool(b); }); }); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "logical_xor_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.common_dtype(), "logical_xor_cpu", [&]() { cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t { return static_cast(bool(a) != bool(b)); @@ -311,21 +308,22 @@ void rshift_kernel(TensorIterator& iter) { cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t { return a >> b; - }); + }); }); } } void lt_kernel(TensorIterator& iter) { + // See Note [special-case bool outputs] if (iter.dtype() == ScalarType::Bool) { - AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "lt_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "lt_cpu", [&]() { cpu_kernel(iter, - [](scalar_t a, scalar_t b) -> bool { - return a < b; - }); + [](scalar_t a, scalar_t b) -> bool { + return a < b; + }); }); } else { - AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "lt_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "lt_cpu", [&]() { cpu_kernel_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { @@ -334,20 +332,21 @@ void lt_kernel(TensorIterator& iter) { [](Vec256 a, Vec256 b) -> Vec256 { return a.lt(b); }); - }); + }); } } void le_kernel(TensorIterator& iter) { + // See Note [special-case bool outputs] if (iter.dtype() == ScalarType::Bool) { - AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "le_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "le_cpu", [&]() { cpu_kernel(iter, - [](scalar_t a, scalar_t b) -> bool { - return a <= b; - }); + [](scalar_t a, scalar_t b) -> bool { + return a <= b; + }); }); } else { - AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "le_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "le_cpu", [&]() { cpu_kernel_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { @@ -356,20 +355,21 @@ void le_kernel(TensorIterator& iter) { [](Vec256 a, Vec256 b) -> Vec256 { return a.le(b); }); - }); + }); } } void gt_kernel(TensorIterator& iter) { + // See Note [special-case bool outputs] if (iter.dtype() == ScalarType::Bool) { - AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "gt_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "gt_cpu", [&]() { cpu_kernel(iter, - [=](scalar_t a, scalar_t b) -> bool { - return a > b; - }); + [](scalar_t a, scalar_t b) -> bool { + return a > b; + }); }); } else { - AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "gt_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "gt_cpu", [&]() { cpu_kernel_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { @@ -378,20 +378,21 @@ void gt_kernel(TensorIterator& iter) { [](Vec256 a, Vec256 b) -> Vec256 { return a.gt(b); }); - }); + }); } } void ge_kernel(TensorIterator& iter) { + // See Note [special-case bool outputs] if (iter.dtype() == ScalarType::Bool) { - AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "ge_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "ge_cpu", [&]() { cpu_kernel(iter, - [](scalar_t a, scalar_t b) -> bool { - return a >= b; - }); + [](scalar_t a, scalar_t b) -> bool { + return a >= b; + }); }); } else { - AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "ge_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "ge_cpu", [&]() { cpu_kernel_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { @@ -400,20 +401,21 @@ void ge_kernel(TensorIterator& iter) { [](Vec256 a, Vec256 b) -> Vec256 { return a.ge(b); }); - }); + }); } } void eq_kernel(TensorIterator& iter) { + // See Note [special-case bool outputs] if (iter.dtype() == ScalarType::Bool) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "eq_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "eq_cpu", [&]() { cpu_kernel(iter, - [](scalar_t a, scalar_t b) -> bool { - return a == b; - }); + [](scalar_t a, scalar_t b) -> bool { + return a == b; + }); }); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "eq_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.common_dtype(), "eq_cpu", [&]() { cpu_kernel_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { @@ -422,20 +424,21 @@ void eq_kernel(TensorIterator& iter) { [](Vec256 a, Vec256 b) -> Vec256 { return a.eq(b); }); - }); + }); } } void ne_kernel(TensorIterator& iter) { + // See Note [special-case bool outputs] if (iter.dtype() == ScalarType::Bool) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "ne_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "ne_cpu", [&]() { cpu_kernel(iter, - [](scalar_t a, scalar_t b) -> bool { - return a != b; - }); + [](scalar_t a, scalar_t b) -> bool { + return a != b; + }); }); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "ne_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.common_dtype(), "ne_cpu", [&]() { cpu_kernel_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { @@ -444,7 +447,7 @@ void ne_kernel(TensorIterator& iter) { [](Vec256 a, Vec256 b) -> Vec256 { return a.ne(b); }); - }); + }); } } @@ -502,30 +505,31 @@ void minimum_kernel(TensorIterator& iter) { } } -void smooth_l1_kernel(TensorIterator& iter) { +void smooth_l1_kernel(TensorIterator& iter, double beta) { AT_DISPATCH_FLOATING_TYPES_AND2( kBFloat16, kHalf, iter.dtype(), "smooth_l1_cpu", [&]() { using Vec = Vec256; - const Vec one_vec(static_cast(1)); + const scalar_t beta_val(beta); + const Vec beta_val_vec(beta_val); const Vec point_five_vec(static_cast(0.5)); cpu_kernel_vec( iter, - [](scalar_t a, scalar_t b) -> scalar_t { + [&beta_val](scalar_t a, scalar_t b) -> scalar_t { auto z = std::abs(a - b); - return z < static_cast(1) - ? static_cast(0.5) * z * z - : z - static_cast(0.5); + return z < beta_val + ? static_cast(0.5) * z * z / beta_val + : z - static_cast(0.5) * beta_val; }, - [&one_vec, &point_five_vec](Vec a, Vec b) { + [&beta_val_vec, &point_five_vec](Vec a, Vec b) { auto z = (a - b).abs(); return Vec::blendv( - point_five_vec * z * z, z - point_five_vec, z >= one_vec); + point_five_vec * z * z / beta_val_vec, z - point_five_vec * beta_val_vec, z >= beta_val_vec); }); }); } void sigmoid_backward_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "sigmoid_backward_cpu", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "sigmoid_backward_cpu", [&]() { auto one_vec = Vec256((scalar_t)(1)); cpu_kernel_vec(iter, [=](scalar_t a, scalar_t b) -> scalar_t { @@ -588,17 +592,31 @@ void logit_backward_kernel(TensorIterator& iter, Scalar eps_scalar) { } void tanh_backward_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() { - auto one_vec = Vec256(scalar_t{1}); + if (isComplexType(iter.dtype())) { + AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() { + auto one_vec = Vec256(scalar_t{1}); cpu_kernel_vec( iter, [=](scalar_t a, scalar_t b) -> scalar_t { - return a * (scalar_t{1} - b * b); + return a * std::conj(scalar_t{1} - b * b); }, [=](Vec256 a, Vec256 b) { - return a * (one_vec - b * b); + return a * (one_vec - b * b).conj(); }); }); + } else { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() { + auto one_vec = Vec256(scalar_t{1}); + cpu_kernel_vec( + iter, + [=](scalar_t a, scalar_t b) -> scalar_t { + return a * (scalar_t{1} - b * b); + }, + [=](Vec256 a, Vec256 b) { + return a * (one_vec - b * b); + }); + }); + } } void mse_kernel(TensorIterator& iter) { @@ -621,15 +639,15 @@ void mse_kernel(TensorIterator& iter) { } void fmod_kernel(TensorIterator& iter) { - if (isIntegralType(iter.dtype(), /*includeBool=*/ false)) { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "fmod_cpu", [&]() { + if (isIntegralType(iter.common_dtype(), /*includeBool=*/ false)) { + AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "fmod_cpu", [&]() { cpu_kernel(iter, [=](scalar_t x, scalar_t d) -> scalar_t { TORCH_CHECK(d != 0, "ZeroDivisionError"); return x % d; }); }); } else { - AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.dtype(), "fmod_cpu", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.common_dtype(), "fmod_cpu", [&]() { cpu_kernel_vec( iter, [](scalar_t x, scalar_t d) -> scalar_t { @@ -638,34 +656,8 @@ void fmod_kernel(TensorIterator& iter) { [](Vec256 x, Vec256 d) { return x.fmod(d); }); - }); - } -} - -void fmod_scalar_kernel(TensorIterator& iter, Scalar divisor) { - if (isIntegralType(iter.dtype(), /*includeBool=*/ false)) { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "fmod_scalar_cpu", [&]() { - const auto div = divisor.to(); - TORCH_CHECK(div != 0, "ZeroDivisionError"); - cpu_kernel(iter, [=](scalar_t x) -> scalar_t { - return x % div; - }); }); - } else { - AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.dtype(), "fmod_scalar_cpu", [&]() { - const auto div = divisor.to(); - const auto div_vec = Vec256(div); - cpu_kernel_vec( - iter, - [=](scalar_t x) -> scalar_t { - return std::fmod(x, div); - }, - [=](Vec256 x) { - return x.fmod(div_vec); - }); - }); } - } void logaddexp_kernel(TensorIterator& iter) { @@ -751,6 +743,32 @@ void hypot_kernel(TensorIterator& iter) { }); } +void igamma_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "igamma_cpu", [&]() { + cpu_kernel_vec( + iter, + [=](scalar_t a, scalar_t b) -> scalar_t { + return calc_igamma(a, b); + }, + [=](Vec256 a, Vec256 b) { + return a.igamma(b); + }); + }); +} + +void igammac_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "igammac_cpu", [&]() { + cpu_kernel_vec( + iter, + [=](scalar_t a, scalar_t b) -> scalar_t { + return calc_igammac(a, b); + }, + [=](Vec256 a, Vec256 b) { + return a.igammac(b); + }); + }); +} + void nextafter_kernel(TensorIterator& iter) { AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "nextafter_cpu", [&]() { cpu_kernel_vec( @@ -772,6 +790,45 @@ void heaviside_kernel(TensorIterator& iter) { }); } +template +T copysign(T a, T b) { + return std::copysign(a, b); +} + +// Implement copysign for half precision floats using bit ops +// Sign is the most significant bit for both half and bfloat16 types +template<> +c10::Half copysign(c10::Half a, c10::Half b) { + return c10::Half((a.x&0x7fff) | (b.x&0x8000), c10::Half::from_bits()); +} + +template<> +c10::BFloat16 copysign(c10::BFloat16 a, c10::BFloat16 b) { + return c10::BFloat16((a.x&0x7fff) | (b.x&0x8000), c10::BFloat16::from_bits()); +} + +void copysign_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_cpu", [&]() { + cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t { + return copysign(a, b); + }); + }); +} + +void xlogy_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "xlogy_cpu", [&]() { + cpu_kernel(iter, [](scalar_t x, scalar_t y) -> scalar_t { + if (at::_isnan(y)){ + return NAN; + } + if (x == 0){ + return 0; + } + return x * std::log(y); + }); + }); +} + } // namespace REGISTER_DISPATCH(add_stub, &add_kernel); @@ -803,14 +860,17 @@ REGISTER_DISPATCH(logit_backward_stub, &logit_backward_kernel); REGISTER_DISPATCH(tanh_backward_stub, &tanh_backward_kernel); REGISTER_DISPATCH(mse_stub, &mse_kernel); REGISTER_DISPATCH(fmod_stub, &fmod_kernel); -REGISTER_DISPATCH(fmod_scalar_stub, &fmod_scalar_kernel); REGISTER_DISPATCH(logaddexp_stub, &logaddexp_kernel); REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_kernel); REGISTER_DISPATCH(gcd_stub, &gcd_kernel); REGISTER_DISPATCH(lcm_stub, &lcm_kernel); REGISTER_DISPATCH(hypot_stub, &hypot_kernel); +REGISTER_DISPATCH(igamma_stub, &igamma_kernel); +REGISTER_DISPATCH(igammac_stub, &igammac_kernel); REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel); REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel); +REGISTER_DISPATCH(copysign_stub, ©sign_kernel); +REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel); } // namespace native } // namespace at diff --git a/aten/src/ATen/native/cpu/CatKernel.cpp b/aten/src/ATen/native/cpu/CatKernel.cpp index 299850407da35..f86adb8e63181 100644 --- a/aten/src/ATen/native/cpu/CatKernel.cpp +++ b/aten/src/ATen/native/cpu/CatKernel.cpp @@ -15,18 +15,20 @@ struct InputMeta { InputMeta(const Tensor& t, int64_t dim, int64_t inner) : data_ptr(t.data_ptr()) - , inner_size(t.size(dim) * inner) {} + , inner_size(t.sizes()[dim] * inner) {} }; template void cat_serial_kernel_impl(Tensor& result, TensorList tensors, int64_t dim) { - int64_t outer = result.numel() / (result.size(dim) * result.stride(dim)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + dim >= 0 && dim < result.dim(), "dim out of range in cat_serial_kernel_impl"); + int64_t outer = result.numel() / (result.sizes()[dim] * result.strides()[dim]); scalar_t* result_data = result.data_ptr(); int64_t ninputs = tensors.size(); std::vector inputs; inputs.reserve(ninputs); for (auto const &tensor : tensors) { - inputs.emplace_back(tensor, dim, result.stride(dim)); + inputs.emplace_back(tensor, dim, result.strides()[dim]); } using Vec = vec256::Vec256; diff --git a/aten/src/ATen/native/cpu/CopyKernel.cpp b/aten/src/ATen/native/cpu/CopyKernel.cpp index 3dbd52aa19055..bca831dc55e2c 100644 --- a/aten/src/ATen/native/cpu/CopyKernel.cpp +++ b/aten/src/ATen/native/cpu/CopyKernel.cpp @@ -17,6 +17,8 @@ static void copy_kernel(TensorIterator& iter, bool non_blocking) { cpu_kernel(iter, [=](at::Half a) -> at::Half { return a; }); } else if (dtype == ScalarType::BFloat16) { cpu_kernel(iter, [=](at::BFloat16 a) -> at::BFloat16 { return a; }); + } else if (dtype == ScalarType::ComplexHalf) { + cpu_kernel(iter, [=](c10::complex a) -> c10::complex { return a; }); } else if (isQIntType(dtype)) { AT_DISPATCH_QINT_TYPES(dtype, "copy_kernel", [&] { cpu_kernel_vec( diff --git a/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp b/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp index 114ca93dae261..34911a2975e49 100644 --- a/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp @@ -104,7 +104,11 @@ struct Dist { // Special general pnorm derivative if p is less than two struct lttdist_calc { - static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return dist == 0.0 ? Vec(0) : sign(diff) * diff.abs().pow(p - Vec(1)) * Vec(grad) / Vec(dist).pow(p - Vec(1)); } + static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { + Vec result = (dist == 0.0) ? Vec(0) : (sign(diff) * diff.abs().pow(p - Vec(1)) * Vec(grad) / Vec(dist).pow(p - Vec(1))); + result = Vec::blendv(result, Vec(0), (diff == Vec(0)) & (p < Vec(1))); + return result; + } }; // Two norm diff --git a/aten/src/ATen/native/cpu/DistributionTemplates.h b/aten/src/ATen/native/cpu/DistributionTemplates.h index 6fe825bcde1eb..d4b6da57111d8 100644 --- a/aten/src/ATen/native/cpu/DistributionTemplates.h +++ b/aten/src/ATen/native/cpu/DistributionTemplates.h @@ -180,9 +180,7 @@ void normal_kernel(Tensor& self, double mean, double std, RNG generator) { normal_fill(self, static_cast(mean), static_cast(std), generator); #endif } else { - // bfloat16 cannot be properly tested due to the lack of other operations - // like add/sub/mean implemented for half - AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "normal_kernel_cpu", [&] { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] { if (size >= 16 && self.is_contiguous()) { normal_fill(self, static_cast(mean), static_cast(std), generator); } else { @@ -208,7 +206,7 @@ struct NormalKernel { template void uniform_kernel(TensorIterator& iter, double from_, double to_, RNG generator) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() { std::lock_guard lock(generator->mutex_); auto from = static_cast(from_); auto to = static_cast(to_); @@ -230,7 +228,7 @@ struct UniformKernel { template void cauchy_kernel(TensorIterator& iter, double median, double sigma, RNG generator) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_cpu", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "cauchy_cpu", [&]() { std::lock_guard lock(generator->mutex_); at::cauchy_distribution cauchy(median, sigma); cpu_serial_kernel(iter, [&cauchy, generator]() -> scalar_t { diff --git a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp index 4dfe644b89a4a..ece2d527e8990 100644 --- a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp +++ b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp @@ -354,6 +354,10 @@ struct ComputeLocation return unnormalize(in); } + inline Vec compute_coordinates(const Vec &in) const { + return in; + } + inline std::pair apply_get_grad(const Vec &in) const { return std::make_pair(unnormalize(in), Vec(scaling_factor)); } @@ -374,6 +378,10 @@ struct ComputeLocation return clip_coordinates(unnormalize(in)); } + inline Vec compute_coordinates(const Vec &in) const { + return clip_coordinates(in); + } + inline std::pair apply_get_grad(const Vec &in) const { Vec res, grad_clip; std::tie(res, grad_clip) = clip_coordinates_get_grad(unnormalize(in)); @@ -400,6 +408,12 @@ struct ComputeLocation return res; } + inline Vec compute_coordinates(const Vec &in) const { + auto res = reflect_coordinates(in); + res = clip_coordinates(res); + return res; + } + inline std::pair apply_get_grad(const Vec &in) const { Vec res, grad_refl, grad_clip, grad(scaling_factor); std::tie(res, grad_refl) = reflect_coordinates_get_grad(unnormalize(in)); @@ -764,6 +778,202 @@ struct ApplyGridSample +struct ApplyGridSample { + using Vec = Vec256; + using integer_t = int_same_size_t; + using iVec = Vec256; + + const int64_t inp_H; + const int64_t inp_W; + const int64_t inp_sH; + const int64_t inp_sW; + const int64_t C; + const int64_t inp_sC; + const ComputeLocation compute_H; + const ComputeLocation compute_W; + const bool must_in_bound = padding != GridSamplerPadding::Zeros; + + // constant used in cubic convolution + // could be -0.5 or -0.75, use the same value in UpSampleBicubic2d.h + const Vec A = Vec(-0.75); + + ApplyGridSample(const TensorAccessor& input) + : inp_H(input.size(2)) + , inp_W(input.size(3)) + , inp_sH(input.stride(2)) + , inp_sW(input.stride(3)) + , C(input.size(1)) + , inp_sC(input.stride(1)) + , compute_H(input.size(2)) + , compute_W(input.size(3)) {} + + // Calculate the cubic convolution coefficient + inline void get_cubic_coefficients(Vec (&coeffs)[4], const Vec& tx) const { + Vec x; + x = tx + Vec(1); // 1 < x = |-1 - tx| < 2 + coeffs[0] = ((A * x - Vec(5) * A) * x + Vec(8) * A) * x - Vec(4) * A; + x = tx; // x = |0 - tx| <= 1 + coeffs[1] = ((A + Vec(2)) * x - (A + Vec(3))) * x * x + Vec(1); + x = Vec(1) - tx; // x = |1 - tx| <= 1 + coeffs[2] = ((A + Vec(2)) * x - (A + Vec(3))) * x * x + Vec(1); + x = Vec(2) - tx; // 1 < x = |2 - tx| < 2 + coeffs[3] = ((A * x - Vec(5) * A) * x + Vec(8) * A) * x - Vec(4) * A; + } + + // Calculate the differential of the cubic convolution, i.e. `d coeff / d x` + inline void get_cubic_coefficients_grad(Vec (&coeffs)[4], const Vec& tx) const { + Vec x; + x = Vec(-1) - tx; // 1 < x = |-1 - tx| < 2 + coeffs[0] = (Vec(-3) * A * x - Vec(10) * A ) * x - Vec(8) * A; + x = Vec(0) - tx; // x = |0 - tx| <= 1 + coeffs[1] = (Vec(-3) * (A + Vec(2)) * x - Vec(2) * (A + Vec(3))) * x; + x = Vec(1) - tx; // x = |1 - tx| <= 1 + coeffs[2] = (Vec(3) * (A + Vec(2)) * x - Vec(2) * (A + Vec(3))) * x; + x = Vec(2) - tx; // 1 < x = |2 - tx| < 2 + coeffs[3] = (Vec(3) * A * x - Vec(10) * A) * x + Vec(8) * A; + } + + inline Vec get_value_bounded(const scalar_t* data, const Vec& x, const Vec& y) const { + auto ix = convert_to_int_of_same_size(compute_W.compute_coordinates(x)); + auto iy = convert_to_int_of_same_size(compute_H.compute_coordinates(y)); + + auto mask_x = must_in_bound ? iVec(-1) : (ix > iVec(-1)) & (ix < iVec(inp_W)); + auto mask_y = must_in_bound ? iVec(-1) : (iy > iVec(-1)) & (iy < iVec(inp_H)); + auto mask = cast(mask_x & mask_y); + + auto offset = iy * iVec(inp_sH) + ix * iVec(inp_sW); + + auto val = mask_gather(Vec(0), data, offset, mask); + return val; + } + + inline void add_value_bounded(scalar_t* data, int64_t len, const Vec& x, const Vec&y, + const Vec& delta) const { + + auto ix = convert_to_int_of_same_size(compute_W.compute_coordinates(x)); + auto iy = convert_to_int_of_same_size(compute_H.compute_coordinates(y)); + + auto mask_x = must_in_bound ? iVec(-1) : (ix > iVec(-1)) & (ix < iVec(inp_W)); + auto mask_y = must_in_bound ? iVec(-1) : (iy > iVec(-1)) & (iy < iVec(inp_H)); + auto mask = cast(mask_x & mask_y); + + auto i_gInp_offset = iy * iVec(inp_W) + ix; + integer_t i_gInp_offset_arr[iVec::size()]; + i_gInp_offset.store(i_gInp_offset_arr); + + integer_t mask_arr[iVec::size()]; + mask.store(mask_arr); + + scalar_t gInp_corner_arr[Vec::size()]; + delta.store(gInp_corner_arr); + + mask_scatter_add(gInp_corner_arr, data, i_gInp_offset_arr, mask_arr, len); + } + + inline void forward(TensorAccessor& out_slice, + const TensorAccessor& inp_slice, + int64_t offset, const Vec& grid_x, const Vec& grid_y, + int64_t len) const { + + auto x = compute_W.unnormalize(grid_x); + auto y = compute_H.unnormalize(grid_y); + + auto ix = x.floor(); + auto iy = y.floor(); + + Vec coeff_x[4]; + Vec coeff_y[4]; + get_cubic_coefficients(coeff_x, x - ix); + get_cubic_coefficients(coeff_y, y - iy); + + #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) + # pragma unroll + #endif + for (int64_t c = 0; c < C; ++c) { + auto inp_slice_C_ptr = inp_slice[c].data(); + + // Interpolate the 4 values in the x direction + Vec interp_x[4]; + for (int64_t i = 0; i < 4; ++i) { + interp_x[i] = + coeff_x[0] * get_value_bounded(inp_slice_C_ptr, ix - Vec(1), iy + Vec(-1 + i)) + + coeff_x[1] * get_value_bounded(inp_slice_C_ptr, ix + Vec(0), iy + Vec(-1 + i)) + + coeff_x[2] * get_value_bounded(inp_slice_C_ptr, ix + Vec(1), iy + Vec(-1 + i)) + + coeff_x[3] * get_value_bounded(inp_slice_C_ptr, ix + Vec(2), iy + Vec(-1 + i)); + } + + // Interpolate the 4 values in the y direction + auto interpolated = coeff_y[0] * interp_x[0] + coeff_y[1] * interp_x[1] + + coeff_y[2] * interp_x[2] + coeff_y[3] * interp_x[3]; + interpolated.store(out_slice[c].data() + offset, len); + } + } + + inline void backward(TensorAccessor& gInp_slice, + TensorAccessor& gGrid_slice, + const TensorAccessor& gOut_slice, + const TensorAccessor& inp_slice, + int64_t offset, const Vec& grid_x, const Vec& grid_y, + int64_t len) const { + + Vec x = compute_W.unnormalize(grid_x); + Vec y = compute_H.unnormalize(grid_y); + Vec gx_mult = Vec(compute_W.scaling_factor); + Vec gy_mult = Vec(compute_H.scaling_factor); + + auto ix = x.floor(); + auto iy = y.floor(); + + Vec coeff_x[4]; + Vec coeff_y[4]; + get_cubic_coefficients(coeff_x, x - ix); + get_cubic_coefficients(coeff_y, y - iy); + + Vec coeff_x_grad[4]; + Vec coeff_y_grad[4]; + get_cubic_coefficients_grad(coeff_x_grad, x - ix); + get_cubic_coefficients_grad(coeff_y_grad, y - iy); + + auto gx = Vec(0), gy = Vec(0); + #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) + # pragma unroll + #endif + for (int64_t c = 0; c < C; ++c) { + auto inp_slice_C_ptr = inp_slice[c].data(); + auto gInp_slice_C_ptr = gInp_slice[c].data(); + auto gOut = Vec::loadu(gOut_slice[c].data() + offset, len); + + for (int64_t i = 0; i < 4; ++i) { + for (int64_t j = 0; j < 4; ++j) { + auto xx = ix + Vec(-1 + i); + auto yy = iy + Vec(-1 + j); + + add_value_bounded(gInp_slice_C_ptr, len, xx, yy, gOut * coeff_x[i] * coeff_y[j]); + + auto val = get_value_bounded(inp_slice_C_ptr, xx, yy); + gx = gx - val * gOut * coeff_x_grad[i] * coeff_y[j]; + gy = gy - val * gOut * coeff_y_grad[j] * coeff_x[i]; + } + } + } + + gx = gx * gx_mult; + gy = gy * gy_mult; + + constexpr int64_t step = Vec::size(); + auto interleaved_gGrid = interleave2(gx, gy); + auto gGrid_ptr = gGrid_slice.data() + offset * 2; + std::get<0>(interleaved_gGrid).store(gGrid_ptr, + std::min(len * 2, step)); + std::get<1>(interleaved_gGrid).store(gGrid_ptr + step, + std::max(static_cast(0), len * 2 - step)); + } +}; + // ~~~~~~~~~~~~~~~~~~ grid_sample_2d_grid_slice_iterator ~~~~~~~~~~~~~~~~~~~~~~ // Function to apply a vectorized function on a grid slice tensor (without batch // dimension). @@ -940,11 +1150,13 @@ Tensor grid_sampler_2d_cpu_kernel_impl(const Tensor& input, const Tensor& grid, switch (static_cast(interpolation_mode)) { HANDLE_INTERP(GridSamplerInterpolation::Bilinear, true); HANDLE_INTERP(GridSamplerInterpolation::Nearest, true); + HANDLE_INTERP(GridSamplerInterpolation::Bicubic, true); } } else { switch (static_cast(interpolation_mode)) { HANDLE_INTERP(GridSamplerInterpolation::Bilinear, false); HANDLE_INTERP(GridSamplerInterpolation::Nearest, false); + HANDLE_INTERP(GridSamplerInterpolation::Bicubic, false); } } }); @@ -1014,11 +1226,13 @@ grid_sampler_2d_backward_cpu_kernel_impl(const Tensor& grad_output_, switch (static_cast(interpolation_mode)) { HANDLE_INTERP(GridSamplerInterpolation::Bilinear, true); HANDLE_INTERP(GridSamplerInterpolation::Nearest, true); + HANDLE_INTERP(GridSamplerInterpolation::Bicubic, true); } } else { switch (static_cast(interpolation_mode)) { HANDLE_INTERP(GridSamplerInterpolation::Bilinear, false); HANDLE_INTERP(GridSamplerInterpolation::Nearest, false); + HANDLE_INTERP(GridSamplerInterpolation::Bicubic, false); } } }); diff --git a/aten/src/ATen/native/cpu/Intrinsics.h b/aten/src/ATen/native/cpu/Intrinsics.h index c4c1d6b648a37..f3b35328f1882 100644 --- a/aten/src/ATen/native/cpu/Intrinsics.h +++ b/aten/src/ATen/native/cpu/Intrinsics.h @@ -22,6 +22,11 @@ (defined(__VEC__) || defined(__ALTIVEC__)) /* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */ #include +/* We need to undef those tokens defined by to avoid conflicts + with the C++ types. => Can still use __bool/__vector */ +#undef bool +#undef vector +#undef pixel #elif defined(__GNUC__) && defined(__SPE__) /* GCC-compatible compiler, targeting PowerPC with SPE */ #include diff --git a/aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp b/aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp new file mode 100644 index 0000000000000..d6f6a69b09945 --- /dev/null +++ b/aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp @@ -0,0 +1,86 @@ +#include +#include +#include +#include +#include + +namespace at { namespace native { namespace { + +void addr_kernel(TensorIterator &iter, + Scalar beta, Scalar alpha) { + if (iter.dtype() == ScalarType::Bool) { + using scalar_t = bool; + auto beta_val = beta.to(); + auto alpha_val = alpha.to(); + + // when beta is false, values in self should be ignored, + // nans and infs in self should not propagate. + if (beta_val == false) { + cpu_kernel(iter, + [=](scalar_t self_val, + scalar_t vec1_val, + scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t { + return alpha_val && vec1_val && vec2_val; + } + ); + } else { + cpu_kernel(iter, + [=](scalar_t self_val, + scalar_t vec1_val, + scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t { + return (beta_val && self_val) || (alpha_val && vec1_val && vec2_val); + } + ); + } + return; + } + + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, + iter.dtype(), "addr_cpu", [&]() { + using Vec = Vec256; + + auto beta_val = beta.to(); + auto alpha_val = alpha.to(); + + auto beta_vec = Vec(beta_val); + auto alpha_vec = Vec(alpha_val); + + const scalar_t zero_val(0); + // when beta == 0, values in self should be ignored, + // nans and infs in self should not propagate. + if (beta_val == zero_val) { + cpu_kernel_vec(iter, + [=](scalar_t self_val, + scalar_t vec1_val, + scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t { + return alpha_val * vec1_val * vec2_val; + }, + [=](Vec self_vec, + Vec vec1_vec, + Vec vec2_vec) __ubsan_ignore_undefined__ { + return alpha_vec * vec1_vec * vec2_vec; + } + ); + } else { + cpu_kernel_vec(iter, + [=](scalar_t self_val, + scalar_t vec1_val, + scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t { + return beta_val * self_val + alpha_val * vec1_val * vec2_val; + }, + [=](Vec self_vec, + Vec vec1_vec, + Vec vec2_vec) __ubsan_ignore_undefined__ { + return beta_vec * self_vec + alpha_vec * vec1_vec * vec2_vec; + } + ); + } + } + ); +} + +} // anonymous namespace + +REGISTER_DISPATCH(addr_stub, &addr_kernel); + +}} // namespace at::native diff --git a/aten/src/ATen/native/cpu/Loops.h b/aten/src/ATen/native/cpu/Loops.h index f263ce897fbb0..305c14eb9c5a8 100644 --- a/aten/src/ATen/native/cpu/Loops.h +++ b/aten/src/ATen/native/cpu/Loops.h @@ -185,7 +185,7 @@ static inline void unroll_contiguous_scalar_checks( } template -void cpu_kernel(TensorIterator& iter, func_t&& op) { +void cpu_kernel(TensorIteratorBase& iter, func_t&& op) { using traits = function_traits; // this could be extended to work with void return types TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); @@ -207,7 +207,7 @@ void cpu_kernel(TensorIterator& iter, func_t&& op) { } template -void cpu_kernel_vec(TensorIterator& iter, func_t&& op, vec_func_t&& vop) { +void cpu_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop) { using traits = function_traits; // this could be extended to work with void return types TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); @@ -236,7 +236,7 @@ void cpu_kernel_vec(TensorIterator& iter, func_t&& op, vec_func_t&& vop) { } template -void cpu_serial_kernel(TensorIterator& iter, func_t&& op, const Range& range) { +void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op, const Range& range) { using traits = function_traits; constexpr bool result_void = std::is_void::value; TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity && @@ -258,12 +258,12 @@ void cpu_serial_kernel(TensorIterator& iter, func_t&& op, const Range& range) { } template -void cpu_serial_kernel(TensorIterator& iter, func_t&& op) { +void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op) { cpu_serial_kernel(iter, op, {0, iter.numel()}); } template -void cpu_serial_kernel_vec(TensorIterator& iter, func_t&& op, vec_func_t&& vop, const Range& range) { +void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, const Range& range) { using traits = function_traits; // this could be extended to work with void return types TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); @@ -289,7 +289,7 @@ void cpu_serial_kernel_vec(TensorIterator& iter, func_t&& op, vec_func_t&& vop, } template -void cpu_serial_kernel_vec(TensorIterator& iter, func_t&& op, vec_func_t&& vop) { +void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop) { cpu_serial_kernel_vec(iter, op, vop, {0, iter.numel()}); } diff --git a/aten/src/ATen/native/cpu/MaxPooling.cpp b/aten/src/ATen/native/cpu/MaxPooling.cpp index 35575091dcdbd..3741d06a9bf53 100644 --- a/aten/src/ATen/native/cpu/MaxPooling.cpp +++ b/aten/src/ATen/native/cpu/MaxPooling.cpp @@ -30,8 +30,9 @@ void max_pool1d_impl( const Tensor& input, const PoolingParams1D& p) { AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_pool1d_impl", [&] { + const Tensor in = input.contiguous(); scalar_t* const OP = output.data_ptr(); - const scalar_t* const IP = input.contiguous().data_ptr(); + const scalar_t* const IP = in.data_ptr(); // Value used for padding constexpr scalar_t FILL = std::numeric_limits::has_infinity diff --git a/aten/src/ATen/native/cpu/MultinomialKernel.cpp b/aten/src/ATen/native/cpu/MultinomialKernel.cpp index 1f4a520849627..62f1d7b879aca 100644 --- a/aten/src/ATen/native/cpu/MultinomialKernel.cpp +++ b/aten/src/ATen/native/cpu/MultinomialKernel.cpp @@ -11,8 +11,12 @@ namespace at { namespace native { namespace { -template -void multinomial_apply(Tensor& result, const Tensor& self, const int64_t n_sample, const bool with_replacement, c10::optional generator) { +template +void multinomial_with_replacement_apply( + Tensor& result, + const Tensor& self, + const int64_t n_sample, + c10::optional generator) { auto gen = get_generator_or_default(generator, detail::getDefaultCPUGenerator()); // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); @@ -61,8 +65,6 @@ void multinomial_apply(Tensor& result, const Tensor& self, const int64_t n_sampl } TORCH_CHECK(sum > 0, "invalid multinomial distribution (sum of probabilities <= 0)"); - TORCH_CHECK(with_replacement || (n_categories - n_zeros >= n_sample), - "invalid multinomial distribution (with replacement=False, not enough non-negative category to sample)"); /* normalize cumulative probability distribution so that last val is 1 i.e. doesn't assume original self row sums to one */ @@ -100,45 +102,23 @@ void multinomial_apply(Tensor& result, const Tensor& self, const int64_t n_sampl /* store in result tensor (will be incremented for lua compat by wrapper) */ result_ptr[i * result_dist_stride_0 + j * result_dist_stride_1] = sample_idx; - - /* Once a sample is drawn, it cannot be drawn again. ie sample without replacement */ - if (!with_replacement && j < n_sample - 1) { - /* update cumulative distribution so that sample cannot be drawn again */ - scalar_t diff; - scalar_t new_val = 0; - scalar_t sum; - - if (sample_idx != 0) { - new_val = cum_dist_ptr[(sample_idx - 1) * cum_dist_stride_0]; - } - /* marginal cumulative mass (i.e. original probability) of sample */ - diff = cum_dist_ptr[sample_idx * cum_dist_stride_0] - new_val; - /* new sum of marginals is not one anymore... */ - sum = 1.0 - diff; - for (int64_t k = 0; k < n_categories; k++) { - new_val = cum_dist_ptr[k * cum_dist_stride_0]; - if (k >= sample_idx) { - /* remove sampled probability mass from later cumulative probabilities */ - new_val -= diff; - } - /* make total marginals sum to one */ - new_val /= sum; - cum_dist_ptr[k * cum_dist_stride_0] = new_val; - } - } } } } -static void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n_sample, const bool with_replacement, c10::optional gen) { +static void multinomial_with_replacement_kernel_impl( + Tensor& result, + const Tensor& self, + const int64_t n_sample, + c10::optional gen) { AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "multinomial", [&] { - multinomial_apply(result, self, n_sample, with_replacement, gen); + multinomial_with_replacement_apply(result, self, n_sample, gen); }); } - } -REGISTER_DISPATCH(multinomial_stub, &multinomial_kernel_impl); - +REGISTER_DISPATCH( + multinomial_with_replacement_stub, + &multinomial_with_replacement_kernel_impl); } } diff --git a/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp b/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp index 45c803e0fec2f..4a52178972fca 100644 --- a/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp @@ -46,28 +46,39 @@ static void addcdiv_cpu_kernel(TensorIterator& iter, Scalar value) { }); } -static void smooth_l1_backward_cpu_kernel(TensorIterator& iter, Scalar norm) { +static void smooth_l1_backward_cpu_kernel(TensorIterator& iter, Scalar norm, double beta) { ScalarType dtype = iter.dtype(0); AT_DISPATCH_ALL_TYPES(dtype, "smooth_l1_backward_cpu_out", [&] { auto norm_val = norm.to(); + scalar_t beta_val(beta); auto norm_val_vec = Vec256(norm_val); + auto beta_val_vec = Vec256(beta_val); const auto neg_1_vec = Vec256(-1); + const auto zero_vec = Vec256(0); const auto pos_1_vec = Vec256(1); cpu_kernel_vec(iter, [=](scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t { const auto x = input - target; - if (x < -1.) + if (x <= -beta) return -norm_val * grad_output; - else if (x > 1.) + else if (x >= beta) return norm_val * grad_output; else - return norm_val * x * grad_output; + return norm_val * x * grad_output / beta; }, - [norm_val_vec, neg_1_vec, pos_1_vec]( + [norm_val_vec, beta_val_vec, neg_1_vec, zero_vec, pos_1_vec]( Vec256 input, Vec256 target, Vec256 grad_output) -> Vec256 { - auto x = input - target; - x = clamp(x, neg_1_vec, pos_1_vec); - return norm_val_vec * x * grad_output; + // using two blendv calls to simulate the 3 cases + // 1 if x >= beta + // -1 if x <= -beta + // x / beta if |x| < beta + const auto x = input - target; + const auto pos_or_neg_1_vec = Vec256::blendv( + neg_1_vec, pos_1_vec, x > zero_vec); + const auto x_abs = x.abs(); + const auto output = Vec256::blendv( + x / beta_val_vec, pos_or_neg_1_vec, x_abs >= beta_val_vec); + return norm_val_vec * output * grad_output; } ); }); diff --git a/aten/src/ATen/native/cpu/PowKernel.cpp b/aten/src/ATen/native/cpu/PowKernel.cpp index b7ec099a80da4..6f0d153e978a2 100644 --- a/aten/src/ATen/native/cpu/PowKernel.cpp +++ b/aten/src/ATen/native/cpu/PowKernel.cpp @@ -63,7 +63,7 @@ void pow_tensor_scalar_kernel(TensorIterator& iter, Scalar exp_scalar) { ); } else if (exp == -0.5) { cpu_kernel_vec(iter, - [](scalar_t base) -> scalar_t { + [](scalar_t base) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { return 1.0 / std::sqrt(base); }, [](Vec base) -> Vec { return base.rsqrt(); } diff --git a/aten/src/ATen/native/cpu/RangeFactoriesKernel.cpp b/aten/src/ATen/native/cpu/RangeFactoriesKernel.cpp index 55ab614e42d1e..ba7f1af7eabb9 100644 --- a/aten/src/ATen/native/cpu/RangeFactoriesKernel.cpp +++ b/aten/src/ATen/native/cpu/RangeFactoriesKernel.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include diff --git a/aten/src/ATen/native/cpu/Reduce.h b/aten/src/ATen/native/cpu/Reduce.h index 129cdc0845c4f..b94e4b44aae0e 100644 --- a/aten/src/ATen/native/cpu/Reduce.h +++ b/aten/src/ATen/native/cpu/Reduce.h @@ -109,7 +109,7 @@ static inline void vectorized_outer_reduction(char** data, int64_t inner_stride, } template -static void set_result(const int index, const res_t result, const TensorIterator &iter, const int num_outputs) { +static void set_result(const int index, const res_t result, const TensorIteratorBase &iter, const int num_outputs) { // static_assert(std::is_same::value, "data types must match"); if (index < num_outputs) { char *out = (char *) iter.data_ptr(index); @@ -118,20 +118,20 @@ static void set_result(const int index, const res_t result, const TensorIterator } template -static void set_results(const res_t result, const TensorIterator &iter, const int num_outputs) { +static void set_results(const res_t result, const TensorIteratorBase &iter, const int num_outputs) { AT_ASSERT(num_outputs == 1); set_result(0, result, iter, num_outputs); } template static inline typename std::enable_if::type -for_each_in_tuple(const std::tuple& t, const TensorIterator &iter, const int num_outputs) { +for_each_in_tuple(const std::tuple& t, const TensorIteratorBase &iter, const int num_outputs) { return i; } template static inline typename std::enable_if::type -for_each_in_tuple(const std::tuple& t, const TensorIterator &iter, const int num_outputs) { +for_each_in_tuple(const std::tuple& t, const TensorIteratorBase &iter, const int num_outputs) { if (i < (size_t)num_outputs) { set_result(i, std::get(t), iter, num_outputs); return for_each_in_tuple(t, iter, num_outputs); @@ -140,7 +140,7 @@ for_each_in_tuple(const std::tuple& t, const TensorIterator &iter, c } template -static void set_results(const std::tuple& result, const TensorIterator &iter, const int num_outputs) { +static void set_results(const std::tuple& result, const TensorIteratorBase &iter, const int num_outputs) { AT_ASSERT(num_outputs >= 1); std::size_t result_size = for_each_in_tuple(result, iter, num_outputs); AT_ASSERT((size_t)num_outputs == result_size); @@ -178,7 +178,7 @@ struct all_same : guts::conjunction< // into several pieces, reduce each separately, and then combine them. template -void binary_kernel_reduce(TensorIterator& iter, ops_t ops, init_t init) { +void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) { using rf_t = decltype(&ops_t::reduce); using cf_t = decltype(&ops_t::combine); using pf_t = decltype(&ops_t::project); @@ -202,7 +202,7 @@ void binary_kernel_reduce(TensorIterator& iter, ops_t ops, init_t init) { "the accumulate type must be default-constructible" ); const int num_outputs = iter.noutputs(); - iter.foreach_reduced_elt([&ops, &init, num_outputs](TensorIterator &sub_iter) { + iter.foreach_reduced_elt([&ops, &init, num_outputs](TensorIteratorBase &sub_iter) { auto reduction_body = [&ops, &sub_iter, num_outputs](acc_t acc, int64_t begin, int64_t end) -> acc_t { int ntensors = sub_iter.ntensors(); sub_iter.serial_for_each([&acc, &ops, num_outputs, ntensors, begin](char** data, const int64_t* strides, int64_t size) { @@ -244,7 +244,7 @@ void binary_kernel_reduce(TensorIterator& iter, ops_t ops, init_t init) { } template -void binary_kernel_reduce_vec(TensorIterator& iter, func_t op, vec_func_t vop, double ident = 0) { +void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) { using traits = binary_function_traits; static_assert( all_same< diff --git a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp index b2b9f82ed4d1c..b64ba684f562b 100644 --- a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp @@ -74,7 +74,7 @@ static void min_all_kernel_impl(Tensor& result, const Tensor& input) { reduce_all_impl(result, input, upper_bound(), [=](int64_t a, int64_t b) -> int64_t { return min_impl(a, b); }); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX(input.scalar_type(), "min_all", [&] { + AT_DISPATCH_ALL_TYPES(input.scalar_type(), "min_all", [&] { using Vec = vec256::Vec256; reduce_all_impl_vec(result, input, upper_bound(), [=] (scalar_t a , scalar_t b) -> scalar_t { return min_impl(a, b); }, @@ -99,7 +99,7 @@ static void max_all_kernel_impl(Tensor& result, const Tensor& input) { reduce_all_impl(result, input, lower_bound(), [=](int64_t a, int64_t b) -> int64_t { return max_impl(a, b); }); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX(input.scalar_type(), "max_all", [&] { + AT_DISPATCH_ALL_TYPES(input.scalar_type(), "max_all", [&] { using Vec = vec256::Vec256; reduce_all_impl_vec(result, input, lower_bound(), [=] (scalar_t a , scalar_t b) -> scalar_t { return max_impl(a, b); }, @@ -193,7 +193,7 @@ static void _aminmax_all_kernel_impl(Tensor& min_result, Tensor& max_result, } ); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX(input.scalar_type(), "_aminmax_all_all", [&] { + AT_DISPATCH_ALL_TYPES(input.scalar_type(), "_aminmax_all_all", [&] { using Vec = vec256::Vec256; using scalar_t_pair = std::pair; reduce_all_impl_vec_two_outputs( diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index b6d38ce36bc03..14f3d4a1fc21b 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -158,13 +158,23 @@ static void std_var_kernel_impl(TensorIterator &iter, bool unbiased, bool take_s } static void prod_kernel_impl(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cpu", [&] { + // Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context] + if (iter.dtype() == ScalarType::Bool) { + using scalar_t = bool; binary_kernel_reduce_vec( iter, - [=](scalar_t a, scalar_t b) -> scalar_t { return a * b; }, - [=](Vec256 a, Vec256 b) { return a * b; }, + [=](scalar_t a, scalar_t b) -> scalar_t { return a && b; }, + [=](Vec256 a, Vec256 b) { return a && b; }, /*identity=*/1); - }); + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cpu", [&] { + binary_kernel_reduce_vec( + iter, + [=](scalar_t a, scalar_t b) -> scalar_t { return a * b; }, + [=](Vec256 a, Vec256 b) { return a * b; }, + /*identity=*/1); + }); + } } static void norm_kernel_tensor_iterator_impl( @@ -174,99 +184,147 @@ static void norm_kernel_tensor_iterator_impl( if (p.isIntegral(false)) { val = p.to(); } else if (p.isFloatingPoint()) { - val = p.to(); + val = p.to(); } else { AT_ERROR("norm_kernel_tensor_iterator_impl expects norm to be integer or float"); } - + // In the dispatch code blocks below, reduction kernels accumulate results as + // the type `acc_t`. When `scalar_t` is complex, `acc_t` is the downgraded + // real number type. Otherwise, `acc_t` and `scalar_t` are the same type. if (val == 0) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] { + using acc_t = typename scalar_value_type::type; binary_kernel_reduce( iter, - NormZeroOps(), - scalar_t(0) + NormZeroOps(), + acc_t(0) ); }); } else if (val == 1) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] { + using acc_t = typename scalar_value_type::type; binary_kernel_reduce( iter, - NormOneOps(), - scalar_t(0) + NormOneOps(), + acc_t(0) ); }); } else if (val == 2) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] { + using acc_t = typename scalar_value_type::type; binary_kernel_reduce( iter, - NormTwoOps(), - scalar_t(0) + NormTwoOps(), + acc_t(0) ); }); } else if (val == INFINITY) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] { + using acc_t = typename scalar_value_type::type; binary_kernel_reduce( iter, - AbsMaxOps(), - scalar_t(std::numeric_limits::min()) + AbsMaxOps(), + acc_t(0) ); }); } else if (val == -INFINITY) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] { + using acc_t = typename scalar_value_type::type; binary_kernel_reduce( iter, - AbsMinOps(), - scalar_t(std::numeric_limits::max()) + AbsMinOps(), + std::numeric_limits::max() ); }); } else { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] { + using acc_t = typename scalar_value_type::type; binary_kernel_reduce( iter, - NormOps { scalar_t(val) }, - scalar_t(0) + NormOps { acc_t(val) }, + acc_t(0) ); }); } + + // For complex outputs, the above kernels do not touch the imaginary values, + // so we must zero them out + if (isComplexType(iter.output().scalar_type())) { + at::imag(iter.output()).zero_(); + } } static void and_kernel_impl(TensorIterator& iter) { - binary_kernel_reduce_vec( - iter, - [=](uint8_t a, uint8_t b) -> uint8_t { return a && b; }, - [=](Vec256 a, Vec256 b) { - // Adding the implementation here instead of in vec256_base to avoid - // return value inconsistency. Other comparison operators in vec256_base - // return -1/0 (all bit 1 / all bit 0) as true/false to follow the AVX2 - // convention. This would be convenient when combined with other - // vectorized operations. For example, one can use the logical operation - // results as a mask for a bit operation to retrieve/reset multiple - // elements in a vector. - // - // In this method, users would expect, e.g., all(), to return 1/0 as - // true/false. - Vec256 c = Vec256(); - for (int i = 0; i != Vec256::size(); i++) { - c[i] = a[i] && b[i]; - } - return c; - }, - /*ident=*/true); + if (iter.dtype() == ScalarType::Byte) { + // Refer [all, any : uint8 compatibility] + binary_kernel_reduce_vec( + iter, + [=](uint8_t a, uint8_t b) -> uint8_t { return (a && b) ? 1 : 0; }, + [=](Vec256 a, Vec256 b) { + Vec256 c = Vec256(); + + for (decltype(c.size()) i = 0; i != Vec256::size(); i++) { + c[i] = (a[i] && b[i]) ? 1 : 0; + } + return c; + }, + /*ident=*/true); + } else { + binary_kernel_reduce_vec( + iter, + [=](bool a, bool b) -> bool { return a && b; }, + [=](Vec256 a, Vec256 b) { + // Adding the implementation here instead of in vec256_base to avoid + // return value inconsistency. Other comparison operators in + // vec256_base return -1/0 (all bit 1 / all bit 0) as true/false to + // follow the AVX2 convention. This would be convenient when combined + // with other vectorized operations. For example, one can use the + // logical operation results as a mask for a bit operation to + // retrieve/reset multiple elements in a vector. + // + // In this method, users would expect, e.g., all(), to return 1/0 as + // true/false. + Vec256 c = Vec256(); + + for (decltype(c.size()) i = 0; i != Vec256::size(); i++) { + c[i] = a[i] && b[i]; + } + return c; + }, + /*ident=*/true); + } } static void or_kernel_impl(TensorIterator& iter) { - binary_kernel_reduce_vec( - iter, - [=](uint8_t a, uint8_t b) -> uint8_t { return a || b; }, - [=](Vec256 a, Vec256 b) { - Vec256 c = Vec256(); - for (int i = 0; i != Vec256::size(); i++) { - c[i] = a[i] || b[i]; - } - return c; - }, - /*ident=*/false); + if (iter.dtype() == ScalarType::Byte) { + // Refer [all, any : uint8 compatibility] + binary_kernel_reduce_vec( + iter, + [=](uint8_t a, uint8_t b) -> uint8_t { return (a || b) ? 1 : 0; }, + [=](Vec256 a, Vec256 b) { + Vec256 c = Vec256(); + + for (decltype(c.size()) i = 0; i != Vec256::size(); i++) { + c[i] = (a[i] || b[i]) ? 1 : 0; + } + return c; + }, + /*ident=*/false); + } else { + binary_kernel_reduce_vec( + iter, + [=](bool a, bool b) -> bool { return a || b; }, + [=](Vec256 a, Vec256 b) { + Vec256 c = Vec256(); + + for (decltype(c.size()) i = 0; i != Vec256::size(); i++) { + c[i] = a[i] || b[i]; + } + return c; + }, + /*ident=*/false); + } } template @@ -294,7 +352,7 @@ static void min_values_kernel_impl(TensorIterator& iter) { iter, [](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); }, [](Vec256 a, Vec256 b) { return minimum(a, b); }, - upper_bound()); + static_cast(upper_bound())); }); } diff --git a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp index ccb5d5634423d..056b941f77784 100644 --- a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp +++ b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp @@ -7,7 +7,7 @@ namespace at { namespace native { namespace { - + // Implement as functors since lambdas don't get optimized. class ReduceMultiply { public: @@ -145,7 +145,7 @@ struct cpu_scatter_gather_base_kernel { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( ScalarType::Bool, ScalarType::Half, iter.dtype(), - method_name, [&] { + "method_name", [&] { constexpr auto SELF_ITER_STRIDE_IDX = 0; constexpr auto INDEX_ITER_STRIDE_IDX = 1; @@ -240,7 +240,7 @@ struct cpu_scatter_gather_base_kernel { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( ScalarType::Bool, ScalarType::Half, iter.dtype(), - method_name, [&] { + "method_name", [&] { constexpr auto SELF_ITER_STRIDE_IDX = 0; constexpr auto INDEX_ITER_STRIDE_IDX = 2; constexpr auto SRC_ITER_STRIDE_IDX = 1; diff --git a/aten/src/ATen/native/cpu/SortingKernel.cpp b/aten/src/ATen/native/cpu/SortingKernel.cpp index 7d13de1855099..db4a6cc864612 100644 --- a/aten/src/ATen/native/cpu/SortingKernel.cpp +++ b/aten/src/ATen/native/cpu/SortingKernel.cpp @@ -47,10 +47,10 @@ void _dim_apply( auto values_dim_stride = values.stride(dim); auto indices_dim_stride = indices.stride(dim); auto dim_size = values.size(dim); - + AT_DISPATCH_ALL_TYPES_AND2( ScalarType::Bool, ScalarType::Half, iter.dtype(), - method_name, [&] { + "sorting_kernel_method_name", [&] { auto loop = [&](char** data, const int64_t* strides, int64_t n) { auto* values_data_bytes = data[0]; auto* indices_data_bytes = data[1]; @@ -68,7 +68,7 @@ void _dim_apply( indices_data_bytes += strides[1]; } }; - + iter.for_each(loop); } ); @@ -96,7 +96,8 @@ static void sort_kernel( Tensor& values, Tensor& indices, int64_t dim, - bool descending) { + bool descending, + bool stable) { dim = maybe_wrap_dim(dim, values.dim()); _fill_indices(indices, dim); _dim_apply( @@ -114,14 +115,26 @@ static void sort_kernel( auto composite_accessor = CompositeRandomAccessorCPU< decltype(values_accessor), decltype(indices_accessor) >(values_accessor, indices_accessor); - + if (descending) { - std::sort(composite_accessor, composite_accessor + dim_size, - KeyValueCompDesc()); + if (stable) { + std::stable_sort(composite_accessor, composite_accessor + dim_size, + KeyValueCompDesc()); + } + else { + std::sort(composite_accessor, composite_accessor + dim_size, + KeyValueCompDesc()); + } } else { - std::sort(composite_accessor, composite_accessor + dim_size, - KeyValueCompAsc()); + if (stable) { + std::stable_sort(composite_accessor, composite_accessor + dim_size, + KeyValueCompAsc()); + } + else { + std::sort(composite_accessor, composite_accessor + dim_size, + KeyValueCompAsc()); + } } } ); diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp index 6e4b3c325f5d3..4be47fd12837d 100644 --- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp +++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp @@ -81,7 +81,7 @@ static void min_kernel_impl( TORCH_CHECK(result.scalar_type() == self.scalar_type() && indice.scalar_type() == kLong, "Expect dtype ", self.scalar_type(), "and torch.long, but got ", result.scalar_type(), "and", indice.scalar_type()); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(ScalarType::Bool, self.scalar_type(), "min_cpu", [&] { + AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "min_cpu", [&] { compare_base_kernel(result, indice, self, wrap_dim, keepdim, [&] ( scalar_t* result_data, int64_t* indice_data, const scalar_t* self_data, auto self_dim_stride) { @@ -118,7 +118,7 @@ static void max_kernel_impl( TORCH_CHECK(result.scalar_type() == self.scalar_type() && indice.scalar_type() == kLong, "Expect dtype ", self.scalar_type(), "and torch.long, but got ", result.scalar_type(), "and", indice.scalar_type()); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(ScalarType::Bool, self.scalar_type(), "max_cpu", [&] { + AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "max_cpu", [&] { compare_base_kernel(result, indice, self, wrap_dim, keepdim, [&] ( scalar_t* result_data, int64_t* indice_data, const scalar_t* self_data, auto self_dim_stride) { @@ -183,7 +183,8 @@ static void _aminmax_kernel_impl( } static void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "where_cpu", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, + iter.dtype(), "where_cpu", [&] { if (condition_type == at::ScalarType::Byte) { cpu_kernel( iter, diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index 45c7e4e237622..6ed4b3af6f23c 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -167,7 +167,7 @@ static void abs_kernel(TensorIterator& iter) { } static void angle_kernel(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "angle_cpu", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "angle_cpu", [&]() { cpu_kernel_vec( iter, [=](scalar_t a) -> scalar_t { return angle_impl(a); }, @@ -213,11 +213,14 @@ static void bitwise_not_kernel(TensorIterator& iter) { }); } else { AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_not_cpu", [&]() { - cpu_kernel( + cpu_kernel_vec( iter, [](scalar_t a) -> scalar_t { return ~a; - }); + }, + [](Vec256 a) -> Vec256 { + return ~a; + }); }); } } @@ -235,19 +238,19 @@ static void logical_not_kernel(TensorIterator& iter) { // NOTE: this implementation differs from the CUDA implementation which only does single dispatch // (to avoid expensive compilation) because CPU kernels don't handle dynamic_casting // (see needs_dynamic_casting). - AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, iter.dtype(1), "logical_not_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kHalf, kBFloat16, iter.dtype(1), "logical_not_cpu", [&]() { using self_t = scalar_t; - AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, iter.dtype(0), "logical_not_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kHalf, kBFloat16, iter.dtype(0), "logical_not_cpu", [&]() { cpu_kernel(iter, [](self_t a) -> scalar_t { return static_cast(!a); }); }); }); } static void reciprocal_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "reciprocal_cpu", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "reciprocal_cpu", [&]() { cpu_kernel_vec( iter, - [=](scalar_t a) -> scalar_t { return static_cast(1.0) / a; }, + [=](scalar_t a) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { return static_cast(1.0) / a; }, [=](Vec256 a) { return a.reciprocal(); }); }); } @@ -274,7 +277,7 @@ static void sign_kernel(TensorIterator& iter){ [=](scalar_t a) -> scalar_t { return (0 < a) - (a < 0); }, [=](Vec256 self_vec){ - // Comparision operators returns bitmask. + // Comparison operators returns bitmask. auto left = Vec256::blendv(zero_vec, one_vec, zero_vec < self_vec); auto right = Vec256::blendv(zero_vec, one_vec, self_vec < zero_vec); @@ -291,7 +294,7 @@ static void signbit_kernel(TensorIterator& iter){ } static void sgn_kernel(TensorIterator& iter){ - AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), 'sgn_cpu', [&]() { + AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "sgn_cpu", [&]() { cpu_kernel_vec( iter, [=](scalar_t a) -> scalar_t { return sgn_impl(a); }, @@ -299,6 +302,21 @@ static void sgn_kernel(TensorIterator& iter){ }); } +static void sinc_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.common_dtype(), "sinc_cpu", [&]() { + cpu_kernel( + iter, + [=](scalar_t a) -> scalar_t { + if (a == scalar_t(0)) { + return scalar_t(1); + } else { + scalar_t product = scalar_t(M_PI) * a; + return std::sin(product) / product; + } + }); + }); +} + static void sinh_kernel(TensorIterator& iter) { AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "sinh_cpu", [&]() { cpu_kernel_vec( @@ -318,7 +336,7 @@ static void cosh_kernel(TensorIterator& iter) { } static void acosh_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "acosh_cpu", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "acosh_cpu", [&]() { cpu_kernel( iter, [=](scalar_t a) -> scalar_t { return std::acosh(a); }); @@ -326,7 +344,7 @@ static void acosh_kernel(TensorIterator& iter) { } static void asinh_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "asinh_cpu", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "asinh_cpu", [&]() { cpu_kernel( iter, [=](scalar_t a) -> scalar_t { return std::asinh(a); }); @@ -334,7 +352,7 @@ static void asinh_kernel(TensorIterator& iter) { } static void atanh_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "atanh_cpu", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "atanh_cpu", [&]() { cpu_kernel( iter, [=](scalar_t a) -> scalar_t { return std::atanh(a); }); @@ -342,7 +360,7 @@ static void atanh_kernel(TensorIterator& iter) { } static void digamma_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "digamma", [&]() { + AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "digamma", [&]() { cpu_kernel( iter, [=](scalar_t a) -> scalar_t { return calc_digamma(a); }); @@ -380,37 +398,61 @@ static void polygamma_kernel(TensorIterator& iter, int64_t n) { } } +static void nan_to_num_kernel( + TensorIterator& iter, + c10::optional nan, + c10::optional pos_inf, + c10::optional neg_inf) { + AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.dtype(), "nan_to_num", [&]() { + scalar_t nan_replacement = static_cast(nan.value_or(0.)); + scalar_t pos_inf_replacement = pos_inf.has_value() + ? static_cast(pos_inf.value()) + : std::numeric_limits::max(); + scalar_t neg_inf_replacement = neg_inf.has_value() + ? static_cast(neg_inf.value()) + : std::numeric_limits::lowest(); + + cpu_kernel(iter, [=](scalar_t a) -> scalar_t { + return ( + at::_isnan(a) + ? nan_replacement + : (a == std::numeric_limits::infinity() + ? pos_inf_replacement + : (a == -std::numeric_limits::infinity() + ? neg_inf_replacement + : a))); + }); + }); +} + static void clamp_kernel(TensorIterator& iter, Scalar min_scalar, Scalar max_scalar) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "clamp_cpu", [&]() { - c10::scalar_value_type::type (*zabs_)(scalar_t) = zabs; + AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "clamp_cpu", [&]() { auto min = min_scalar.to(); auto max = max_scalar.to(); auto min_vec = Vec256(min); auto max_vec = Vec256(max); cpu_kernel_vec(iter, - [=](scalar_t a) -> scalar_t { return zabs_(a) < zabs_(min) ? min : (zabs_(a) > zabs_(max) ? max : a); }, + [=](scalar_t a) -> scalar_t { return std::min(std::max(a, min), max); }, [=](Vec256 a) { return vec256::clamp(a, min_vec, max_vec); }); }); } static void clamp_max_kernel(TensorIterator& iter, Scalar max_scalar) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "clamp_max_cpu", [&]() { - c10::scalar_value_type::type (*zabs_)(scalar_t) = zabs; + AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "clamp_max_cpu", [&]() { auto max = max_scalar.to(); auto max_vec = Vec256(max); cpu_kernel_vec(iter, - [=](scalar_t a) -> scalar_t { return zabs_(a) > zabs_(max) ? max : a; }, + [=](scalar_t a) -> scalar_t { return std::min(a, max); }, [=](Vec256 a) { return vec256::clamp_max(a, max_vec); }); }); } static void clamp_min_kernel(TensorIterator& iter, Scalar min_scalar) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "clamp_min_cpu", [&]() { - c10::scalar_value_type::type (*zabs_)(scalar_t) = zabs; + AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "clamp_min_cpu", [&]() { auto min = min_scalar.to(); auto min_vec = Vec256(min); cpu_kernel_vec(iter, - [=](scalar_t a) -> scalar_t { return zabs_(a) < zabs_(min) ? min : a; }, + [=](scalar_t a) -> scalar_t { return std::max(a, min); }, [=](Vec256 a) { return vec256::clamp_min(a, min_vec); }); }); } @@ -545,10 +587,10 @@ static void random_full_64_bits_range_kernel(TensorIterator& iter, c10::optional } static void rsqrt_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "rsqrt_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "rsqrt_cpu", [&] { cpu_kernel_vec( iter, - [=](scalar_t a) -> scalar_t { + [=](scalar_t a) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { return (static_cast(1)) / std::sqrt(a); }, [=](Vec256 a) { return a.rsqrt(); }); @@ -560,7 +602,7 @@ static void rsqrt_kernel(TensorIterator& iter) { #define IMPLEMENT_FLOAT_KERNEL(dispatchtypes, op) \ static void op##_kernel(TensorIterator& iter) { \ TORCH_INTERNAL_ASSERT(iter.ntensors() == 2); \ - AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), op##_vml_cpu, [&]() { \ + AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), #op "_vml_cpu", [&]() { \ iter.serial_for_each( \ [&](char** data_, const int64_t* strides, int64_t n) { \ scalar_t* out_data = reinterpret_cast(data_[0]); \ @@ -591,7 +633,7 @@ static void rsqrt_kernel(TensorIterator& iter) { #define IMPLEMENT_COMPLEX_KERNEL(dispatchtypes, op) \ static void op##_kernel(TensorIterator& iter) { \ TORCH_INTERNAL_ASSERT(iter.ntensors() == 2); \ - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), op##_vml_cpu, [&]() {\ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), #op "_vml_cpu", [&]() {\ iter.serial_for_each( \ [&](char** data_, const int64_t* strides, int64_t n) { \ scalar_t* out_data = reinterpret_cast(data_[0]); \ @@ -645,10 +687,12 @@ REGISTER_DISPATCH(bitwise_not_stub, &bitwise_not_kernel); REGISTER_DISPATCH(logical_not_stub, &logical_not_kernel); REGISTER_DISPATCH(frac_stub, &frac_kernel); REGISTER_DISPATCH(reciprocal_stub, &reciprocal_kernel); +REGISTER_DISPATCH(nan_to_num_stub, &nan_to_num_kernel); REGISTER_DISPATCH(neg_stub, &neg_kernel); REGISTER_DISPATCH(sign_stub, &sign_kernel); REGISTER_DISPATCH(signbit_stub, &signbit_kernel); REGISTER_DISPATCH(sgn_stub, &sgn_kernel); +REGISTER_DISPATCH(sinc_stub, &sinc_kernel); REGISTER_DISPATCH(sinh_stub, &sinh_kernel); REGISTER_DISPATCH(cosh_stub, &cosh_kernel); REGISTER_DISPATCH(acosh_stub, &acosh_kernel); @@ -679,7 +723,7 @@ IMPLEMENT_COMPLEX_KERNEL(FLOATING, log10) IMPLEMENT_FLOAT_KERNEL(FLOATING, log1p) IMPLEMENT_COMPLEX_KERNEL(FLOATING, log2) IMPLEMENT_FLOAT_KERNEL(FLOATING, i0) -IMPLEMENT_COMPLEX_KERNEL(FLOATING, round) +IMPLEMENT_FLOAT_KERNEL(FLOATING, round) IMPLEMENT_COMPLEX_KERNEL(FLOATING, sin) IMPLEMENT_COMPLEX_KERNEL(FLOATING, sqrt) IMPLEMENT_COMPLEX_KERNEL(FLOATING, tan) diff --git a/aten/src/ATen/native/cpu/UpSampleKernel.cpp b/aten/src/ATen/native/cpu/UpSampleKernel.cpp index aa6d57cdd2df0..61e7877761d80 100644 --- a/aten/src/ATen/native/cpu/UpSampleKernel.cpp +++ b/aten/src/ATen/native/cpu/UpSampleKernel.cpp @@ -4,36 +4,12 @@ #include #include #include +#include namespace at { namespace native { namespace { -template -inline T data_index_init(T offset) { - return offset; -} - -template -inline T data_index_init(T offset, T &x, const T &X, Args &&... args) { - offset = data_index_init(offset, std::forward(args)...); - x = offset % X; - return offset / X; -} - -inline bool data_index_step() { - return true; -} - -template -inline bool data_index_step(T &x, const T &X, Args &&... args) { - if (data_index_step(std::forward(args)...)) { - x = ((x + 1) == X) ? 0 : (x + 1); - return x == 0; - } - return false; -} - static inline int64_t nearest_idx( int64_t output_index, int64_t input_size, diff --git a/aten/src/ATen/native/cpu/group_norm_kernel.cpp b/aten/src/ATen/native/cpu/group_norm_kernel.cpp index 12537659b3f3b..a32c76757f3e1 100644 --- a/aten/src/ATen/native/cpu/group_norm_kernel.cpp +++ b/aten/src/ATen/native/cpu/group_norm_kernel.cpp @@ -24,9 +24,9 @@ void GroupNormKernelImplInternal( int64_t HxW, int64_t group, T eps, - Tensor* Y, - Tensor* mean, - Tensor* rstd) { + Tensor& Y, + Tensor& mean, + Tensor& rstd) { TORCH_CHECK(X.numel() == N * C * HxW); TORCH_CHECK(!gamma.defined() || gamma.numel() == C); TORCH_CHECK(!beta.defined() || beta.numel() == C); @@ -35,9 +35,9 @@ void GroupNormKernelImplInternal( const T* X_data = X.data_ptr(); const T* gamma_data = gamma.defined() ? gamma.data_ptr() : nullptr; const T* beta_data = beta.defined() ? beta.data_ptr() : nullptr; - T* Y_data = Y->data_ptr(); - T* mean_data = mean->data_ptr(); - T* rstd_data = rstd->data_ptr(); + T* Y_data = Y.data_ptr(); + T* mean_data = mean.data_ptr(); + T* rstd_data = rstd.data_ptr(); const T s = T(1) / static_cast(D * HxW); const bool gamma_null = (gamma_data == nullptr); const bool beta_null = beta_data == nullptr; @@ -94,9 +94,9 @@ void GroupNormKernelImpl( int64_t HxW, int64_t group, double eps, - Tensor* Y, - Tensor* mean, - Tensor* rstd) { + Tensor& Y, + Tensor& mean, + Tensor& rstd) { AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GroupNormKernelImpl", [&]() { GroupNormKernelImplInternal( X, @@ -268,9 +268,9 @@ void GroupNormBackwardKernelImplInternal( int64_t C, int64_t HxW, int64_t group, - Tensor* dX, - Tensor* dgamma, - Tensor* dbeta) { + Tensor& dX, + Tensor& dgamma, + Tensor& dbeta) { TORCH_CHECK(dY.numel() == N * C * HxW); TORCH_CHECK(X.numel() == N * C * HxW); TORCH_CHECK(mean.numel() == N * group); @@ -282,9 +282,9 @@ void GroupNormBackwardKernelImplInternal( const T* mean_data = mean.data_ptr(); const T* rstd_data = rstd.data_ptr(); const T* gamma_data = gamma.defined() ? gamma.data_ptr() : nullptr; - T* dX_data = dX->defined() ? dX->data_ptr() : nullptr; - T* dgamma_data = dgamma->defined() ? dgamma->data_ptr() : nullptr; - T* dbeta_data = dbeta->defined() ? dbeta->data_ptr() : nullptr; + T* dX_data = dX.defined() ? dX.data_ptr() : nullptr; + T* dgamma_data = dgamma.defined() ? dgamma.data_ptr() : nullptr; + T* dbeta_data = dbeta.defined() ? dbeta.data_ptr() : nullptr; Tensor ds = at::empty({N, C}, X.options()); Tensor db = at::empty({N, C}, X.options()); T* ds_data = ds.data_ptr(); @@ -326,9 +326,9 @@ void GroupNormBackwardKernelImpl( int64_t C, int64_t HxW, int64_t group, - Tensor* dX, - Tensor* dgamma, - Tensor* dbeta) { + Tensor& dX, + Tensor& dgamma, + Tensor& dbeta) { AT_DISPATCH_FLOATING_TYPES( X.scalar_type(), "GroupNormBackwardKernelImpl", [&]() { GroupNormBackwardKernelImplInternal( diff --git a/aten/src/ATen/native/cpu/utils.h b/aten/src/ATen/native/cpu/utils.h new file mode 100644 index 0000000000000..32d1de5adb519 --- /dev/null +++ b/aten/src/ATen/native/cpu/utils.h @@ -0,0 +1,30 @@ +#pragma once + +namespace at { namespace native { namespace { + +template +inline T data_index_init(T offset) { + return offset; +} + +template +inline T data_index_init(T offset, T &x, const T &X, Args &&... args) { + offset = data_index_init(offset, std::forward(args)...); + x = offset % X; + return offset / X; +} + +inline bool data_index_step() { + return true; +} + +template +inline bool data_index_step(T &x, const T &X, Args &&... args) { + if (data_index_step(std::forward(args)...)) { + x = ((x + 1) == X) ? 0 : (x + 1); + return x == 0; + } + return false; +} + +}}} // namespace at::native:: diff --git a/aten/src/ATen/native/cpu/zmath.h b/aten/src/ATen/native/cpu/zmath.h index e0554e0cbc29b..7be18ef519b12 100644 --- a/aten/src/ATen/native/cpu/zmath.h +++ b/aten/src/ATen/native/cpu/zmath.h @@ -33,9 +33,18 @@ inline double zabs , double> (c10::complex z) { return std::abs(z); } +// This overload corresponds to non-complex dtypes. +// The function is consistent with its NumPy equivalent +// for non-complex dtypes where `pi` is returned for +// negative real numbers and `0` is returned for 0 or positive +// real numbers. +// Note: `nan` is propagated. template inline VALUE_TYPE angle_impl (SCALAR_TYPE z) { - return 0; + if (at::_isnan(z)) { + return z; + } + return z < 0 ? M_PI : 0; } template<> diff --git a/aten/src/ATen/native/cuda/AbsKernel.cu b/aten/src/ATen/native/cuda/AbsKernel.cu index 4113115d7b129..649b235bf6545 100644 --- a/aten/src/ATen/native/cuda/AbsKernel.cu +++ b/aten/src/ATen/native/cuda/AbsKernel.cu @@ -6,11 +6,16 @@ namespace at { namespace native { +template +struct AbsFunctor { + __device__ __forceinline__ scalar_t operator() (const scalar_t a) const { + return std::abs(a); + } +}; + void abs_kernel_cuda(TensorIterator& iter) { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, iter.dtype(), "abs_cuda", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return std::abs(a); - }); + gpu_kernel(iter, AbsFunctor()); }); } diff --git a/aten/src/ATen/native/cuda/Activation.cu b/aten/src/ATen/native/cuda/Activation.cu index 145fc990daeb5..e512e38e6aaf6 100644 --- a/aten/src/ATen/native/cuda/Activation.cu +++ b/aten/src/ATen/native/cuda/Activation.cu @@ -112,6 +112,7 @@ Tensor prelu_cuda(const Tensor& self, const Tensor& weight_) { input_stride0, input_stride1, input_numel); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } return result; @@ -229,6 +230,7 @@ std::tuple prelu_backward_cuda(const Tensor& grad_out_, const Te input_stride0, input_stride1, input_numel); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); // update weight_grad std::vector reduce_dims; @@ -326,13 +328,17 @@ void elu_kernel(TensorIterator& iter, Scalar alpha, Scalar scale, Scalar input_s }); } -void elu_backward_kernel(TensorIterator& iter, Scalar alpha, Scalar scale, Scalar input_scale) { +void elu_backward_kernel(TensorIterator& iter, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "elu_backward_cuda", [&]() { auto negcoef = alpha.to() * scale.to(); auto poscoef = scale.to(); auto negiptcoef = input_scale.to(); - gpu_kernel(iter, [negcoef, poscoef, negiptcoef]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return b <= scalar_t(0) ? a * negiptcoef * (b + negcoef) : a * poscoef; + gpu_kernel(iter, [negcoef, poscoef, negiptcoef, is_result]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + if (is_result) { + return b <= scalar_t(0) ? a * negiptcoef * (b + negcoef) : a * poscoef; + } else { + return b <= scalar_t(0) ? a * negiptcoef * negcoef * (static_cast(std::exp(b * negiptcoef))) : a * poscoef; + } }); }); } @@ -341,12 +347,10 @@ namespace { void GeluCUDAKernelImpl(TensorIterator& it) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "GeluCUDAKernelImpl", [&] { - using T_ACC = acc_type; - gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { - return static_cast(x) * - c10::cuda::compat::normcdf(static_cast(x)); - }); + using T_ACC = acc_type; + gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + return static_cast(x) * + c10::cuda::compat::normcdf(static_cast(x)); }); }); } @@ -354,17 +358,15 @@ void GeluCUDAKernelImpl(TensorIterator& it) { void GeluBackwardCUDAKernelImpl(TensorIterator& it) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "GeluBackwardCUDAKernelImpl", [&] { - using T_ACC = acc_type; - gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { - constexpr T_ACC kBeta = M_2_SQRTPI * M_SQRT1_2 * T_ACC(0.5); - const T_ACC cdf = c10::cuda::compat::normcdf(static_cast(x)); - const T_ACC pdf = - c10::cuda::compat::exp( - T_ACC(-0.5) * static_cast(x) * static_cast(x)) * - kBeta; - return static_cast(dy) * (cdf + static_cast(x) * pdf); - }); + using T_ACC = acc_type; + gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + constexpr T_ACC kBeta = M_2_SQRTPI * M_SQRT1_2 * T_ACC(0.5); + const T_ACC cdf = c10::cuda::compat::normcdf(static_cast(x)); + const T_ACC pdf = + c10::cuda::compat::exp( + T_ACC(-0.5) * static_cast(x) * static_cast(x)) * + kBeta; + return static_cast(dy) * (cdf + static_cast(x) * pdf); }); }); } @@ -389,68 +391,70 @@ void leaky_relu_backward_kernel(TensorIterator& iter, Scalar negval_) { void hardswish_kernel(TensorIterator& iter) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "hardswish_cuda", [&] { - const scalar_t zero(0.0f); - const scalar_t one_sixth(1.0f / 6.0f); - const scalar_t three(3.0f); - const scalar_t six(6.0f); - gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t { - return self_val * std::min(std::max(self_val + three, zero), six) * one_sixth; - }); + using T_ACC = acc_type; + const T_ACC zero(0.0f); + const T_ACC one_sixth(1.0f / 6.0f); + const T_ACC three(3.0f); + const T_ACC six(6.0f); + gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t { + T_ACC x = static_cast(self_val); + return x * std::min(std::max(x + three, zero), six) * one_sixth; }); }); } void hardswish_backward_kernel(TensorIterator& iter) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_backward_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "hardswish_backward_cuda", [&] { - const scalar_t zero(0.0f); - const scalar_t three(3.0f); - const scalar_t neg_three(-3.0f); - const scalar_t one_half(0.5f); - gpu_kernel( - iter, - [zero, three, neg_three, one_half]GPU_LAMBDA(scalar_t grad_val, scalar_t self_val) -> scalar_t { - if (self_val < neg_three) { - return zero; - } else if (self_val <= three) { - return grad_val * ((self_val / three) + one_half); - } else { - return grad_val; - } - }); + using T_ACC = acc_type; + const T_ACC zero(0.0f); + const T_ACC three(3.0f); + const T_ACC neg_three(-3.0f); + const T_ACC one_half(0.5f); + gpu_kernel( + iter, + [zero, three, neg_three, one_half]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t { + T_ACC grad_val = static_cast(grad_val_); + T_ACC self_val = static_cast(self_val_); + if (self_val < neg_three) { + return zero; + } else if (self_val <= three) { + return grad_val * ((self_val / three) + one_half); + } else { + return grad_val; + } }); }); } void hardsigmoid_kernel(TensorIterator& iter) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardsigmoid_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "hardsigmoid_cuda", [&] { - const scalar_t zero(0.0f); - const scalar_t one_sixth(1.0f / 6.0f); - const scalar_t three(3.0f); - const scalar_t six(6.0f); - gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t { - return std::min(std::max(self_val + three, zero), six) * one_sixth; - }); + using T_ACC = acc_type; + const T_ACC zero(0.0f); + const T_ACC one_sixth(1.0f / 6.0f); + const T_ACC three(3.0f); + const T_ACC six(6.0f); + gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t { + T_ACC x = static_cast(self_val); + return std::min(std::max(x + three, zero), six) * one_sixth; }); }); } void hardsigmoid_backward_kernel(TensorIterator& iter) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardsigmoid_backward_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "hardsigmoid_backward_cuda", [&] { - const scalar_t zero(0.0f); - const scalar_t three(3.0f); - const scalar_t neg_three(-3.0f); - const scalar_t one_sixth(1.0f / 6.0f); - gpu_kernel( - iter, - [zero, three, neg_three, one_sixth]GPU_LAMBDA(scalar_t grad_val, scalar_t self_val) -> scalar_t { - return (self_val >= neg_three && self_val <= three) - ? grad_val * one_sixth - : zero; - }); + using T_ACC = acc_type; + const T_ACC zero(0.0f); + const T_ACC three(3.0f); + const T_ACC neg_three(-3.0f); + const T_ACC one_sixth(1.0f / 6.0f); + gpu_kernel( + iter, + [zero, three, neg_three, one_sixth]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t { + T_ACC grad_val = static_cast(grad_val_); + T_ACC self_val = static_cast(self_val_); + return (self_val >= neg_three && self_val <= three) + ? grad_val * one_sixth + : zero; }); }); } diff --git a/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu b/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu index d051649f069e0..5066480535b9e 100644 --- a/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu +++ b/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu @@ -520,6 +520,7 @@ namespace { sizeB, sizeC, isizeH, isizeW, osizeH, osizeW, kernel_stride_C, kernel_size_C, istrideB, istrideC, istrideH, istrideW); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); break; @@ -562,6 +563,7 @@ namespace { input_data, output_data, isizeH, isizeW, osizeH, osizeW, istrideD, istrideH, istrideW); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); break; @@ -571,7 +573,6 @@ namespace { false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); } - AT_CUDA_CHECK(cudaGetLastError()); } void adaptive_avg_pool2d_backward_out_cuda_template( @@ -665,6 +666,7 @@ namespace { sizeB, sizeC, isizeH, isizeW, osizeH, osizeW, kernel_stride_C, kernel_size_C, ostrideB, ostrideC, ostrideH, ostrideW); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); break; @@ -701,6 +703,7 @@ namespace { atomic_adaptive_average_gradinput <<>> ( gradInput_data, gradOutput_data, isizeH, isizeW, osizeH, osizeW); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { @@ -708,6 +711,7 @@ namespace { adaptive_average_gradinput <<>> ( gradInput_data, gradOutput_data, isizeH, isizeW, osizeH, osizeW); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } ); @@ -719,7 +723,6 @@ namespace { "Unsupported memory format. Supports only ChannelsLast, Contiguous"); } - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace diff --git a/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu b/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu index 50b1d19105341..3e87105298e09 100644 --- a/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu +++ b/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu @@ -123,10 +123,9 @@ void adaptiveaveragepool_loop( istrideD, istrideT, istrideH, istrideW, offsetZ); - + C10_CUDA_KERNEL_LAUNCH_CHECK(); totalZ -= 65535; offsetZ += 65535; - AT_CUDA_CHECK(cudaGetLastError()); } } @@ -217,10 +216,9 @@ void adaptiveaveragegradinput_loop( isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, offsetZ); - + C10_CUDA_KERNEL_LAUNCH_CHECK(); totalZ -= 65535; offsetZ += 65535; - AT_CUDA_CHECK(cudaGetLastError()); } } @@ -312,10 +310,9 @@ void atomicadaptiveaveragegradinput_loop( isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, offsetZ); - + C10_CUDA_KERNEL_LAUNCH_CHECK(); totalZ -= 65535; offsetZ += 65535; - AT_CUDA_CHECK(cudaGetLastError()); } } diff --git a/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu b/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu index 0b4160d7d095b..dfe4c49b80aa6 100644 --- a/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu +++ b/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu @@ -251,10 +251,9 @@ void adaptive_max_pool2d_out_cuda_template( indices_data, isizeH, isizeW, osizeH, osizeW, istrideD, istrideH, istrideW); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); - AT_CUDA_CHECK(cudaGetLastError()); - } else { Tensor input_ = input.contiguous(); int64_t sizeB = input_.size(0); @@ -288,10 +287,9 @@ void adaptive_max_pool2d_out_cuda_template( indices_data, isizeH, isizeW, osizeH, osizeW, istrideD, istrideH, istrideW); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); - AT_CUDA_CHECK(cudaGetLastError()); - } } @@ -346,6 +344,7 @@ void adaptive_max_pool2d_backward_out_cuda_template( gradInput_data, gradOutput_data, indices_data, isizeH, isizeW, osizeH, osizeW); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { @@ -354,10 +353,10 @@ void adaptive_max_pool2d_backward_out_cuda_template( gradInput_data, gradOutput_data, indices_data, isizeH, isizeW, osizeH, osizeW); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } ); - AT_CUDA_CHECK(cudaGetLastError()); } else { int64_t sizeB = input.size(0); int64_t sizeD = input.size(1); @@ -392,6 +391,7 @@ void adaptive_max_pool2d_backward_out_cuda_template( gradInput_data, gradOutput_data, indices_data, isizeH, isizeW, osizeH, osizeW); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { @@ -400,10 +400,10 @@ void adaptive_max_pool2d_backward_out_cuda_template( gradInput_data, gradOutput_data, indices_data, isizeH, isizeW, osizeH, osizeW); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } ); - AT_CUDA_CHECK(cudaGetLastError()); } } diff --git a/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu b/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu index ca3178d0ee338..d515cf78bbca6 100644 --- a/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu +++ b/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu @@ -131,10 +131,10 @@ void adaptivemaxpool_loop( adaptivemaxpool<<>>( input_data, output_data, indices_data, isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, istrideD, istrideT, istrideH, istrideW, offsetZ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); totalZ -= 65535; offsetZ += 65535; - AT_CUDA_CHECK(cudaGetLastError()); } } @@ -209,10 +209,9 @@ void adaptivemaxgradinput_loop( adaptivemaxgradinput<<>>( gradInput_data, gradOutput_data, indices_data, isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, offsetZ); - + C10_CUDA_KERNEL_LAUNCH_CHECK(); totalZ -= 65535; offsetZ += 65535; - AT_CUDA_CHECK(cudaGetLastError()); } } @@ -286,10 +285,9 @@ void atomicadaptivemaxgradinput_loop( atomicadaptivemaxgradinput<<>>( gradInput_data, gradOutput_data, indices_data, isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, offsetZ); - + C10_CUDA_KERNEL_LAUNCH_CHECK(); totalZ -= 65535; offsetZ += 65535; - AT_CUDA_CHECK(cudaGetLastError()); } } diff --git a/aten/src/ATen/native/cuda/AmpKernels.cu b/aten/src/ATen/native/cuda/AmpKernels.cu index 0d8b87f402dec..0ac6537a8de99 100644 --- a/aten/src/ATen/native/cuda/AmpKernels.cu +++ b/aten/src/ATen/native/cuda/AmpKernels.cu @@ -3,9 +3,13 @@ #include #include +#include #include -#include +#include #include +#include +#include + namespace { // Thin wrapper around https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__SINGLE.html#group__CUDA__MATH__SINGLE_1g57a3c8313f570282a1a7bcc78743b08e, @@ -33,49 +37,139 @@ static __host__ __device__ __forceinline__ int isfinite_ensure_cuda_math(float v namespace at { namespace native { -// Multiplies scaled_grad in-place by inv_scale. If an element of scaled_grad was inf or NaN sets found_inf to 1.0. -// -// Args: -// scaled_grad: A (scaled) gradient tensor. May contain infs or NaNs. -// found_inf: A single-element float tensor to which 1.0 will be written if any gradients contain infs/nans. -// Pre-zeroing found_inf, if appropriate, is the responsibility of the caller. -// inv_scale: The inverse of the scale factor by which scaled_grad is currently multiplied. -// -// Returns: -// A tuple with references to scaled_grad, which is now unscaled in place, and found_inf, -// which is now guaranteed to contain 1.0 if an inf or NaN was found in scaled_grad. +namespace { +// Single-tensor fallback for _amp_foreach_non_finite_check_and_unscale_cuda_. +// Handles individual tensors that are acceptable to unscale but not MTA-safe. void _amp_non_finite_check_and_unscale_cuda_(Tensor& scaled_grad, Tensor& found_inf, const Tensor& inv_scale) { - TORCH_CHECK(scaled_grad.is_cuda(), "scaled_grad must be a CUDA tensor."); + // The only way we reach this function is through _amp_foreach_non_finite_check_and_unscale_cuda_, so no input checks. + + // It's not obvious gpu_kernel always guards onto its argument. Guarding here just in case. + const OptionalDeviceGuard device_guard(device_of(scaled_grad)); + + // Acts on scaled_grad in place. + auto iter = TensorIterator::unary_op(scaled_grad, scaled_grad); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + iter.dtype(), + "_amp_non_finite_check_and_unscale_cuda", + [&iter, &found_inf, &inv_scale] { + auto* found_inf_ptr = found_inf.data_ptr(); + auto* inv_scale_ptr = inv_scale.data_ptr(); + + using opmath_t = get_opmath_t::opmath_t; + + gpu_kernel(iter, + [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (scalar_t val_in) -> scalar_t { + auto val = static_cast(val_in); + if (!isfinite_ensure_cuda_math(val)) { + *found_inf_ptr = 1.f; + } + // Every thread accesses inv_scale, but it will hit in cache. + const auto inv_scale_val = *inv_scale_ptr; + return static_cast(inv_scale_val == 1.f ? val : val * inv_scale_val); + }); + }); +} +} // anonymous namespace + + +// Multiplies each tensor in scaled_grads by inv_scale in-place. +// If any element of any tensor in scaled_grads is inf or NaN, sets found_inf to 1.0. +// Uses multi tensor apply (MTA) to process all MTA-safe tensors. +// +// Args: +// scaled_grads: A TensorList of scaled gradient tensors. May contain infs or NaNs. +// found_inf: A single-element float tensor to which 1.0 will be written if any gradient contain infs/nans. +// Pre-zeroing found_inf, if appropriate, is the responsibility of the caller. +// inv_scale: The inverse of the scale factor by which scaled_grads are currently multiplied. +void _amp_foreach_non_finite_check_and_unscale_cuda_(TensorList scaled_grads, + Tensor& found_inf, + const Tensor& inv_scale) +{ + if (scaled_grads.size() == 0) { + return; + } + TORCH_CHECK(inv_scale.is_cuda(), "inv_scale must be a CUDA tensor."); TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor."); TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor."); TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor."); TORCH_CHECK(inv_scale.scalar_type() == at::ScalarType::Float, "inv_scale must be a float tensor."); TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor."); - TORCH_CHECK(scaled_grad.layout() == at::kStrided, "scaled_grad must be a strided (not sparse) Tensor."); - // Act on scaled_grad in place. - auto iter = TensorIterator::unary_op(scaled_grad, scaled_grad); + // Ensures client code (GradScaler) filtered scaled_grads by dtype. + check_foreach_api_restrictions(scaled_grads); + + std::vector> tensor_lists; + + // is_non_overlapping_and_dense() is not available in Python. + // GradScaler can't filter for it. We need to filter here. + if (can_use_fast_route(scaled_grads)) { + // Hopefully common case. + // can_use_fast_route is true, which confirms: + // - all scaled_grads are strided + // - all scaled_grads are non overlapping and dense + // - all scaled_grads are on the same device + TORCH_CHECK(scaled_grads[0].is_cuda(), "scaled_grads must be CUDA tensors."); + // Sets up MTA launch to use scaled_grads as-is. + tensor_lists.emplace_back(scaled_grads.vec()); + } else { + // Hopefully uncommon case. + // can_use_fast_route is an all-or-nothing check. In this path it was false, + // so any of the above confirmations could have gone wrong. + // We filter MTA-safe tensors into an MTA-able list. + // If a tensor is acceptable but not MTA-safe, we fall back to the TensorIterator kernel. + // If a tensor is unacceptable, we throw an error to blame GradScaler. + tensor_lists.resize(1); + tensor_lists[0].reserve(scaled_grads.size()); + auto expected_device = scaled_grads[0].device(); + for (const Tensor& t : scaled_grads) { + // Ensures GradScaler filtered scaled_grads by device. + TORCH_CHECK(t.is_cuda(), "one of scaled_grads was not a CUDA tensor."); + TORCH_CHECK(t.device() == expected_device, "scaled_grads must be on the same device."); + TORCH_CHECK(t.layout() == at::kStrided, "one of scaled_grads was not a strided tensor."); + if (!t.is_non_overlapping_and_dense()) { + // t is acceptable but not MTA-safe. Falls back to single-tensor TensorIterator kernel. + _amp_non_finite_check_and_unscale_cuda_(const_cast(t), + found_inf, + inv_scale); + } else { + tensor_lists[0].push_back(t); + } + } + if (tensor_lists[0].size() == 0) { + return; + } + } AT_DISPATCH_FLOATING_TYPES_AND_HALF( - iter.dtype(), - "_amp_non_finite_check_and_unscale_cuda", - [&iter, &found_inf, &inv_scale] { + tensor_lists[0][0].scalar_type(), + "_amp_foreach_non_finite_check_and_unscale_cuda", + [&tensor_lists, &found_inf, &inv_scale] { auto* found_inf_ptr = found_inf.data_ptr(); auto* inv_scale_ptr = inv_scale.data_ptr(); - gpu_kernel(iter, [found_inf_ptr, inv_scale_ptr]GPU_LAMBDA(scalar_t val) -> scalar_t { - float fval = static_cast(val); - // See isfinite_ensure_cuda_math above. - if (!isfinite_ensure_cuda_math(fval)) { - *found_inf_ptr = 1.f; - } - const auto inv_scale_val = *inv_scale_ptr; // Every thread accesses inv_scale, but it will hit in cache. - return static_cast(inv_scale_val == 1.f ? fval : fval*inv_scale_val); - }); + using opmath_t = get_opmath_t::opmath_t; + + // multi_tensor_apply guards onto tensor_lists[0][0], no need to guard explicitly. + multi_tensor_apply<1>(tensor_lists, + UnaryOpFunctor(), + [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t { + // There is a slight asymmetry here with the TensorIterator kernel above. + // MTA Functors ensure val comes in as opmath_t rather than scalar_t. + if (!isfinite_ensure_cuda_math(val)) { + *found_inf_ptr = 1.f; + } + // Every thread accesses inv_scale, but it will hit in cache. + const auto inv_scale_val = *inv_scale_ptr; + return static_cast(inv_scale_val == 1.f ? val : val * inv_scale_val); + }); }); } @@ -149,6 +243,7 @@ Tensor _amp_update_scale_cuda(Tensor& growth_tracker, growth_factor, backoff_factor, growth_interval); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return new_scale; } diff --git a/aten/src/ATen/native/cuda/AveragePool2d.cu b/aten/src/ATen/native/cuda/AveragePool2d.cu index e8988244289dd..6973cdf21af17 100644 --- a/aten/src/ATen/native/cuda/AveragePool2d.cu +++ b/aten/src/ATen/native/cuda/AveragePool2d.cu @@ -125,7 +125,7 @@ __global__ void avg_pool2d_backward_out_cuda_frame(const int nthreads, const sca const int width, const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, - scalar_t* const bottom_diff, const int divisor_override, + scalar_t* const bottom_diff, const int divisor_override, bool count_include_pad, bool use_divisor) { CUDA_KERNEL_LOOP(index, nthreads) { // find out the local index @@ -176,19 +176,19 @@ __global__ void avg_pool2d_backward_out_cuda_frame(const int nthreads, const sca } template -__global__ void avg_pool2d_backward_out_cuda_frame_nhwc(const int nthreads, +__global__ void avg_pool2d_backward_out_cuda_frame_nhwc(const int nthreads, const scalar_t* const top_diff, const int num, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, - scalar_t* const bottom_diff, const int divisor_override, + scalar_t* const bottom_diff, const int divisor_override, bool count_include_pad, bool use_divisor) { CUDA_KERNEL_LOOP(index, nthreads) { - const int c = index % channels; + const int c = index % channels; const int w = (index / channels) % width; - const int h = (index / channels / width) % height; - const int n = index / channels / width / height; + const int h = (index / channels / width) % height; + const int n = index / channels / width / height; const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1; const int phend = min(h / stride_h + 1, pooled_height); @@ -262,14 +262,6 @@ void avg_pool2d_out_cuda_template( const int padH = safe_downcast(padding[0]); const int padW = padding.size() == 1 ? padH : safe_downcast(padding[1]); - const auto memory_format = input_.suggest_memory_format(); - if (memory_format == at::MemoryFormat::ChannelsLast){ - TORCH_CHECK(input_.ndimension() == 4, - "non-empty 4D (batch mode) tensor expected for input with channels_last layout"); - } else { - TORCH_CHECK((input_.ndimension() == 3 || input_.ndimension() == 4), - "non-empty 3D or 4D (batch mode) tensor expected for input"); - } TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0, "divisor must be not zero"); @@ -281,13 +273,14 @@ void avg_pool2d_out_cuda_template( const int64_t outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode); const int64_t outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode); + const auto memory_format = input_.suggest_memory_format(); pool2d_shape_check( input_, kH, kW, dH, dW, padH, padW, 1, 1, nInputPlane, inputHeight, inputWidth, - outputHeight, outputWidth); + outputHeight, outputWidth, memory_format); Tensor input = input_.contiguous(memory_format); @@ -298,20 +291,38 @@ void avg_pool2d_out_cuda_template( const uint32_t num_blocks = cuda::ATenCeilDiv(count, num_threads); bool use_divisor = divisor_override.has_value(); - const auto divisor_override_value = use_divisor ? divisor_override.value() : 0; - - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), - "avg_pool2d_out_cuda_frame", - [&] { - using accscalar_t = acc_type; - - scalar_t *output_data = output.data_ptr(); - scalar_t *input_data = input.data_ptr(); + const auto divisor_override_value = use_divisor ? divisor_override.value() : 0; - switch (memory_format){ - case MemoryFormat::ChannelsLast: { - output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast); - avg_pool2d_out_cuda_frame_nhwc + if (count != 0) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), + "avg_pool2d_out_cuda_frame", + [&] { + using accscalar_t = acc_type; + + scalar_t *output_data = output.data_ptr(); + scalar_t *input_data = input.data_ptr(); + + switch (memory_format){ + case MemoryFormat::ChannelsLast: { + output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast); + avg_pool2d_out_cuda_frame_nhwc + <<>>( + count, + input_data, + nbatch, + nInputPlane, + inputHeight, inputWidth, + outputHeight, outputWidth, + kH, kW, + dH, dW, + padH, padW, + output_data, + divisor_override_value, + count_include_pad, use_divisor); + break; + } + case MemoryFormat::Contiguous: { + avg_pool2d_out_cuda_frame <<>>( count, input_data, @@ -325,31 +336,13 @@ void avg_pool2d_out_cuda_template( output_data, divisor_override_value, count_include_pad, use_divisor); - break; - } - case MemoryFormat::Contiguous: { - avg_pool2d_out_cuda_frame - <<>>( - count, - input_data, - nbatch, - nInputPlane, - inputHeight, inputWidth, - outputHeight, outputWidth, - kH, kW, - dH, dW, - padH, padW, - output_data, - divisor_override_value, - count_include_pad, use_divisor); - break; + break; + } + default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); } - default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); } - } - ); - - AT_CUDA_CHECK(cudaGetLastError()); + ); + } if (input.ndimension() == 3) { output.resize_({nInputPlane, outputHeight, outputWidth}); @@ -394,15 +387,7 @@ Tensor& avg_pool2d_backward_out_cuda_template( TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0, "divisor must be not zero"); - const auto memory_format = input_.suggest_memory_format(); - if (memory_format == at::MemoryFormat::ChannelsLast) { - TORCH_CHECK(input_.ndimension() == 4, - "non-empty 4D (batch mode) tensor expected for input with channels_last layout"); - } else { - TORCH_CHECK((input_.ndimension() == 3 || input_.ndimension() == 4), - "non-empty 3D or 4D (batch mode) tensor expected for input"); - } - + const auto memory_format = input_.suggest_memory_format(); const Tensor input = input_.contiguous(memory_format); const Tensor gradOutput = gradOutput_.contiguous(memory_format); @@ -421,11 +406,14 @@ Tensor& avg_pool2d_backward_out_cuda_template( kH, kW, dH, dW, padH, padW, nInputPlane, inputHeight, inputWidth, - outputHeight, outputWidth); + outputHeight, outputWidth, memory_format); gradInput.resize_as_(input); - const int32_t count = safe_downcast(input.numel()); + if (count == 0) { + return gradInput; + } + const uint32_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); const uint32_t num_blocks = cuda::ATenCeilDiv(count, num_threads); @@ -455,8 +443,9 @@ Tensor& avg_pool2d_backward_out_cuda_template( dH, dW, padH, padW, gradInput_data, - divisor_override_value, + divisor_override_value, count_include_pad, use_divisor); + C10_CUDA_KERNEL_LAUNCH_CHECK(); break; } case MemoryFormat::Contiguous: { @@ -472,8 +461,9 @@ Tensor& avg_pool2d_backward_out_cuda_template( dH, dW, padH, padW, gradInput_data, - divisor_override_value, + divisor_override_value, count_include_pad, use_divisor); + C10_CUDA_KERNEL_LAUNCH_CHECK(); break; } default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); @@ -481,8 +471,6 @@ Tensor& avg_pool2d_backward_out_cuda_template( } ); - AT_CUDA_CHECK(cudaGetLastError()); - return gradInput; } diff --git a/aten/src/ATen/native/cuda/AveragePool3d.cu b/aten/src/ATen/native/cuda/AveragePool3d.cu index 4214b4dace198..388b04dba76a9 100644 --- a/aten/src/ATen/native/cuda/AveragePool3d.cu +++ b/aten/src/ATen/native/cuda/AveragePool3d.cu @@ -317,16 +317,17 @@ __global__ void avg_pool3d_cuda_update_grad_input( } } -#define LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \ +#define LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \ avg_pool3d_cuda_update_output \ <<>>( \ - work_input.packed_accessor64(), \ - work_output.packed_accessor64(), \ + work_input.packed_accessor64(), \ + work_output.packed_accessor64(), \ kT, kH, \ dT, dH, dW, \ padT, padH, padW, \ count_include_pad, \ offsetZ, divisor); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ break void avg_pool3d_out_cuda_template( @@ -443,11 +444,10 @@ void avg_pool3d_out_cuda_template( padT, padH, padW, count_include_pad, offsetZ, divisor); - break; + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; } - AT_CUDA_CHECK(cudaGetLastError()); - totalZ -= 65535; offsetZ += 65535; } @@ -581,8 +581,7 @@ void avg_pool3d_backward_out_cuda_template( kT, kH, kW, 1.0f/divide_factor, offsetZ); - - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); totalZ -= 65535; offsetZ += 65535; @@ -614,6 +613,7 @@ void avg_pool3d_backward_out_cuda_template( padT, padH, padW, count_include_pad, offsetZ, divisor); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { avg_pool3d_cuda_update_grad_input @@ -625,10 +625,9 @@ void avg_pool3d_backward_out_cuda_template( padT, padH, padW, count_include_pad, offsetZ, divisor); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } - AT_CUDA_CHECK(cudaGetLastError()); - totalZ -= 65535; offsetZ += 65535; } diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index c86f355a67c21..8153b75aae8c4 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -8,13 +8,16 @@ #include #include +#include +#include #include +#include #include // for USE_MAGMA #ifdef USE_MAGMA -#include #include +#include const bool use_magma_ = true; #else @@ -92,10 +95,18 @@ void magmaCholeskyBatched( magma_uplo_t uplo, magma_int_t n, scalar_t** dA_array, magma_int_t ldda, magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue); -template +template void magmaTriangularSolve( - magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, - scalar_t* dA, magma_int_t ldda, scalar_t* dB, magma_int_t lddb); + magma_uplo_t uplo, + magma_trans_t trans, + magma_diag_t diag, + magma_int_t m, + magma_int_t n, + scalar_t* dA, + magma_int_t ldda, + scalar_t* dB, + magma_int_t lddb, + const MAGMAQueue& magma_queue); template void magmaTriangularSolveBatched( @@ -116,17 +127,24 @@ void magmaOrgqr( magma_int_t m, magma_int_t n, magma_int_t k, scalar_t* dA, magma_int_t ldda, scalar_t* tau, scalar_t* dT, magma_int_t nb, magma_int_t* info); -template +template void magmaSymeig( magma_vec_t jobz, magma_uplo_t uplo, magma_int_t n, scalar_t* dA, magma_int_t ldda, - scalar_t* w, scalar_t* wA, magma_int_t ldwa, scalar_t* work, magma_int_t lwork, - magma_int_t* iwork, magma_int_t liwork, magma_int_t* info); + value_t* w, scalar_t* wA, magma_int_t ldwa, scalar_t* work, magma_int_t lwork, value_t* rwork, + magma_int_t lrwork, magma_int_t* iwork, magma_int_t liwork, magma_int_t* info); template +void magmaEig( + magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, scalar_t *A, magma_int_t lda, + scalar_t *wr, scalar_t *wi, scalar_t *VL, magma_int_t ldvl, + scalar_t *VR, magma_int_t ldvr, scalar_t *work, magma_int_t lwork, magma_int_t *info); + +template void magmaSvd( magma_vec_t jobz, magma_int_t m, magma_int_t n, scalar_t* A, - magma_int_t lda, scalar_t* s, scalar_t* U, magma_int_t ldu, + magma_int_t lda, value_t* s, scalar_t* U, magma_int_t ldu, scalar_t* VT, magma_int_t ldvt, scalar_t* work, magma_int_t lwork, + value_t* rwork, magma_int_t* iwork, magma_int_t* info); template @@ -158,6 +176,28 @@ void magmaSolve( AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaSolve>( + magma_int_t n, magma_int_t nrhs, c10::complex* dA, magma_int_t ldda, + magma_int_t* ipiv, c10::complex* dB, magma_int_t lddb, magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_zgesv_gpu(n, nrhs, + reinterpret_cast(dA), ldda, ipiv, + reinterpret_cast(dB), lddb, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaSolve>( + magma_int_t n, magma_int_t nrhs, c10::complex* dA, magma_int_t ldda, + magma_int_t* ipiv, c10::complex* dB, magma_int_t lddb, magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_cgesv_gpu(n, nrhs, + reinterpret_cast(dA), ldda, ipiv, + reinterpret_cast(dB), lddb, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaSolveBatched( magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda, @@ -176,6 +216,28 @@ void magmaSolveBatched( AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaSolveBatched>( + magma_int_t n, magma_int_t nrhs, c10::complex** dA_array, magma_int_t ldda, + magma_int_t** dipiv_array, c10::complex** dB_array, magma_int_t lddb, + magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) { + magma_zgesv_batched(n, nrhs, + reinterpret_cast(dA_array), ldda, dipiv_array, + reinterpret_cast(dB_array), lddb, dinfo_array, batch_count, magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaSolveBatched>( + magma_int_t n, magma_int_t nrhs, c10::complex** dA_array, magma_int_t ldda, + magma_int_t** dipiv_array, c10::complex** dB_array, magma_int_t lddb, + magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) { + magma_cgesv_batched(n, nrhs, + reinterpret_cast(dA_array), ldda, dipiv_array, + reinterpret_cast(dB_array), lddb, dinfo_array, batch_count, magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaLu( magma_int_t m, magma_int_t n, double* dA, magma_int_t ldda, @@ -194,6 +256,24 @@ void magmaLu( AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaLu>( + magma_int_t m, magma_int_t n, c10::complex* dA, magma_int_t ldda, + magma_int_t* ipiv, magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_zgetrf_gpu(m, n, reinterpret_cast(dA), ldda, ipiv, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaLu>( + magma_int_t m, magma_int_t n, c10::complex* dA, magma_int_t ldda, + magma_int_t* ipiv, magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_cgetrf_gpu(m, n, reinterpret_cast(dA), ldda, ipiv, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaLuBatched( magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda, @@ -212,6 +292,24 @@ void magmaLuBatched( AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaLuBatched>( + magma_int_t m, magma_int_t n, c10::complex** dA_array, magma_int_t ldda, + magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize, + const MAGMAQueue& magma_queue) { + magma_zgetrf_batched(m, n, reinterpret_cast(dA_array), ldda, ipiv_array, info_array, batchsize, magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaLuBatched>( + magma_int_t m, magma_int_t n, c10::complex** dA_array, magma_int_t ldda, + magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize, + const MAGMAQueue& magma_queue) { + magma_cgetrf_batched(m, n, reinterpret_cast(dA_array), ldda, ipiv_array, info_array, batchsize, magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaLuNoPiv( magma_int_t m, magma_int_t n, double* dA, magma_int_t ldda, @@ -230,6 +328,24 @@ void magmaLuNoPiv( AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaLuNoPiv>( + magma_int_t m, magma_int_t n, c10::complex* dA, magma_int_t ldda, + magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_zgetrf_nopiv_gpu(m, n, reinterpret_cast(dA), ldda, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaLuNoPiv>( + magma_int_t m, magma_int_t n, c10::complex* dA, magma_int_t ldda, + magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_cgetrf_nopiv_gpu(m, n, reinterpret_cast(dA), ldda, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaLuNoPivBatched( magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda, @@ -246,6 +362,22 @@ void magmaLuNoPivBatched( AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaLuNoPivBatched>( + magma_int_t m, magma_int_t n, c10::complex** dA_array, magma_int_t ldda, + magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) { + magma_zgetrf_nopiv_batched(m, n, reinterpret_cast(dA_array), ldda, info_array, batchsize, magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaLuNoPivBatched>( + magma_int_t m, magma_int_t n, c10::complex** dA_array, magma_int_t ldda, + magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) { + magma_cgetrf_nopiv_batched(m, n, reinterpret_cast(dA_array), ldda, info_array, batchsize, magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> inline magma_int_t magmaGetriOptimalBlocksize(magma_int_t n) { return magma_get_dgetri_nb(n); @@ -256,6 +388,18 @@ inline magma_int_t magmaGetriOptimalBlocksize(magma_int_t n) { return magma_get_sgetri_nb(n); } +template <> +inline magma_int_t magmaGetriOptimalBlocksize>( + magma_int_t n) { + return magma_get_zgetri_nb(n); +} + +template <> +inline magma_int_t magmaGetriOptimalBlocksize>( + magma_int_t n) { + return magma_get_cgetri_nb(n); +} + template<> void magmaGetri( magma_int_t n, double* dA, magma_int_t ldda, magma_int_t* ipiv, double* dwork, @@ -274,6 +418,48 @@ void magmaGetri( AT_CUDA_CHECK(cudaGetLastError()); } +template <> +void magmaGetri>( + magma_int_t n, + c10::complex* dA, + magma_int_t ldda, + magma_int_t* ipiv, + c10::complex* dwork, + magma_int_t lwork, + magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_zgetri_gpu( + n, + reinterpret_cast(dA), + ldda, + ipiv, + reinterpret_cast(dwork), + lwork, + info); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template <> +void magmaGetri>( + magma_int_t n, + c10::complex* dA, + magma_int_t ldda, + magma_int_t* ipiv, + c10::complex* dwork, + magma_int_t lwork, + magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_cgetri_gpu( + n, + reinterpret_cast(dA), + ldda, + ipiv, + reinterpret_cast(dwork), + lwork, + info); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaGetriBatched( magma_int_t n, double** dA_array, magma_int_t ldda, @@ -292,6 +478,54 @@ void magmaGetriBatched( AT_CUDA_CHECK(cudaGetLastError()); } +template <> +void magmaGetriBatched>( + magma_int_t n, + c10::complex** dA_array, + magma_int_t ldda, + magma_int_t** ipiv_array, + c10::complex** dinvA_array, + magma_int_t lddia, + magma_int_t* info_array, + magma_int_t batchsize, + const MAGMAQueue& magma_queue) { + magma_zgetri_outofplace_batched( + n, + reinterpret_cast(dA_array), + ldda, + ipiv_array, + reinterpret_cast(dinvA_array), + lddia, + info_array, + batchsize, + magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template <> +void magmaGetriBatched>( + magma_int_t n, + c10::complex** dA_array, + magma_int_t ldda, + magma_int_t** ipiv_array, + c10::complex** dinvA_array, + magma_int_t lddia, + magma_int_t* info_array, + magma_int_t batchsize, + const MAGMAQueue& magma_queue) { + magma_cgetri_outofplace_batched( + n, + reinterpret_cast(dA_array), + ldda, + ipiv_array, + reinterpret_cast(dinvA_array), + lddia, + info_array, + batchsize, + magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaCholeskySolve( magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, double* dA, magma_int_t ldda, @@ -310,6 +544,28 @@ void magmaCholeskySolve( AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaCholeskySolve>( + magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, c10::complex* dA, magma_int_t ldda, + c10::complex* dB, magma_int_t lddb, magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_zpotrs_gpu(uplo, n, nrhs, + reinterpret_cast(dA), ldda, + reinterpret_cast(dB), lddb, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaCholeskySolve>( + magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, c10::complex* dA, magma_int_t ldda, + c10::complex* dB, magma_int_t lddb, magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_cpotrs_gpu(uplo, n, nrhs, + reinterpret_cast(dA), ldda, + reinterpret_cast(dB), lddb, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaCholeskySolveBatched( magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda, @@ -326,6 +582,26 @@ void magmaCholeskySolveBatched( AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaCholeskySolveBatched>( + magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, c10::complex** dA_array, magma_int_t ldda, + c10::complex** dB_array, magma_int_t lddb, magma_int_t& info, magma_int_t batchsize, const MAGMAQueue& magma_queue) { + info = magma_zpotrs_batched(uplo, n, nrhs, + reinterpret_cast(dA_array), ldda, + reinterpret_cast(dB_array), lddb, batchsize, magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaCholeskySolveBatched>( + magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, c10::complex** dA_array, magma_int_t ldda, + c10::complex** dB_array, magma_int_t lddb, magma_int_t& info, magma_int_t batchsize, const MAGMAQueue& magma_queue) { + info = magma_cpotrs_batched(uplo, n, nrhs, + reinterpret_cast(dA_array), ldda, + reinterpret_cast(dB_array), lddb, batchsize, magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaCholesky( magma_uplo_t uplo, magma_int_t n, double* dA, @@ -394,21 +670,117 @@ void magmaCholeskyBatched>( AT_CUDA_CHECK(cudaGetLastError()); } -template<> +template <> void magmaTriangularSolve( - magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, - double* dA, magma_int_t ldda, double* dB, magma_int_t lddb) { - MagmaStreamSyncGuard guard; - magma_dtrsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb); + magma_uplo_t uplo, + magma_trans_t trans, + magma_diag_t diag, + magma_int_t m, + magma_int_t n, + double* dA, + magma_int_t ldda, + double* dB, + magma_int_t lddb, + const MAGMAQueue& magma_queue) { + magma_dtrsm( + MagmaLeft, + uplo, + trans, + diag, + m, + n, + 1, + dA, + ldda, + dB, + lddb, + magma_queue.get_queue()); AT_CUDA_CHECK(cudaGetLastError()); } -template<> +template <> void magmaTriangularSolve( - magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, - float* dA, magma_int_t ldda, float* dB, magma_int_t lddb) { - MagmaStreamSyncGuard guard; - magma_strsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb); + magma_uplo_t uplo, + magma_trans_t trans, + magma_diag_t diag, + magma_int_t m, + magma_int_t n, + float* dA, + magma_int_t ldda, + float* dB, + magma_int_t lddb, + const MAGMAQueue& magma_queue) { + magma_strsm( + MagmaLeft, + uplo, + trans, + diag, + m, + n, + 1, + dA, + ldda, + dB, + lddb, + magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template <> +void magmaTriangularSolve>( + magma_uplo_t uplo, + magma_trans_t trans, + magma_diag_t diag, + magma_int_t m, + magma_int_t n, + c10::complex* dA, + magma_int_t ldda, + c10::complex* dB, + magma_int_t lddb, + const MAGMAQueue& magma_queue) { + magmaDoubleComplex alpha({1, 0}); + magma_ztrsm( + MagmaLeft, + uplo, + trans, + diag, + m, + n, + alpha, + reinterpret_cast(dA), + ldda, + reinterpret_cast(dB), + lddb, + magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template <> +void magmaTriangularSolve>( + magma_uplo_t uplo, + magma_trans_t trans, + magma_diag_t diag, + magma_int_t m, + magma_int_t n, + c10::complex* dA, + magma_int_t ldda, + c10::complex* dB, + magma_int_t lddb, + const MAGMAQueue& magma_queue) { + magmaFloatComplex alpha({1, 0}); + magma_ctrsm( + MagmaLeft, + uplo, + trans, + diag, + m, + n, + alpha, + reinterpret_cast(dA), + ldda, + reinterpret_cast(dB), + lddb, + magma_queue.get_queue()); AT_CUDA_CHECK(cudaGetLastError()); } @@ -430,6 +802,30 @@ void magmaTriangularSolveBatched( AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaTriangularSolveBatched>( + magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, + c10::complex** dA_array, magma_int_t ldda, c10::complex** dB_array, magma_int_t lddb, magma_int_t batchsize, + const MAGMAQueue& magma_queue) { + magmaDoubleComplex alpha({1, 0}); + magmablas_ztrsm_batched(MagmaLeft, uplo, trans, diag, m, n, alpha, + reinterpret_cast(dA_array), ldda, + reinterpret_cast(dB_array), lddb, batchsize, magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaTriangularSolveBatched>( + magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, + c10::complex** dA_array, magma_int_t ldda, c10::complex** dB_array, magma_int_t lddb, magma_int_t batchsize, + const MAGMAQueue& magma_queue) { + magmaFloatComplex alpha({1, 0}); + magmablas_ctrsm_batched(MagmaLeft, uplo, trans, diag, m, n, alpha, + reinterpret_cast(dA_array), ldda, + reinterpret_cast(dB_array), lddb, batchsize, magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> inline magma_int_t magmaGeqrfOptimalBlocksize(magma_int_t m, magma_int_t n) { return magma_get_dgeqrf_nb(m, n); @@ -440,6 +836,20 @@ inline magma_int_t magmaGeqrfOptimalBlocksize(magma_int_t m, magma_int_t return magma_get_sgeqrf_nb(m, n); } +template <> +inline magma_int_t magmaGeqrfOptimalBlocksize>( + magma_int_t m, + magma_int_t n) { + return magma_get_zgeqrf_nb(m, n); +} + +template <> +inline magma_int_t magmaGeqrfOptimalBlocksize>( + magma_int_t m, + magma_int_t n) { + return magma_get_cgeqrf_nb(m, n); +} + template<> void magmaGeqrf( magma_int_t m, magma_int_t n, double* dA, magma_int_t ldda, @@ -466,6 +876,70 @@ void magmaGeqrf( AT_CUDA_CHECK(cudaGetLastError()); } +template <> +void magmaGeqrf>( + magma_int_t m, + magma_int_t n, + c10::complex* dA, + magma_int_t ldda, + c10::complex* tau, + c10::complex* dT, + magma_int_t* info, + bool is_v2) { + MagmaStreamSyncGuard guard; + if (!is_v2) { + magma_zgeqrf_gpu( + m, + n, + reinterpret_cast(dA), + ldda, + reinterpret_cast(tau), + reinterpret_cast(dT), + info); + } else { + magma_zgeqrf2_gpu( + m, + n, + reinterpret_cast(dA), + ldda, + reinterpret_cast(tau), + info); + } + AT_CUDA_CHECK(cudaGetLastError()); +} + +template <> +void magmaGeqrf>( + magma_int_t m, + magma_int_t n, + c10::complex* dA, + magma_int_t ldda, + c10::complex* tau, + c10::complex* dT, + magma_int_t* info, + bool is_v2) { + MagmaStreamSyncGuard guard; + if (!is_v2) { + magma_cgeqrf_gpu( + m, + n, + reinterpret_cast(dA), + ldda, + reinterpret_cast(tau), + reinterpret_cast(dT), + info); + } else { + magma_cgeqrf2_gpu( + m, + n, + reinterpret_cast(dA), + ldda, + reinterpret_cast(tau), + info); + } + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaOrgqr( magma_int_t m, magma_int_t n, magma_int_t k, double* dA, magma_int_t ldda, @@ -484,11 +958,63 @@ void magmaOrgqr( AT_CUDA_CHECK(cudaGetLastError()); } +template <> +void magmaOrgqr>( + magma_int_t m, + magma_int_t n, + magma_int_t k, + c10::complex* dA, + magma_int_t ldda, + c10::complex* tau, + c10::complex* dT, + magma_int_t nb, + magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_zungqr_gpu( + m, + n, + k, + reinterpret_cast(dA), + ldda, + reinterpret_cast(tau), + reinterpret_cast(dT), + nb, + info); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template <> +void magmaOrgqr>( + magma_int_t m, + magma_int_t n, + magma_int_t k, + c10::complex* dA, + magma_int_t ldda, + c10::complex* tau, + c10::complex* dT, + magma_int_t nb, + magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_cungqr_gpu( + m, + n, + k, + reinterpret_cast(dA), + ldda, + reinterpret_cast(tau), + reinterpret_cast(dT), + nb, + info); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaSymeig( magma_vec_t jobz, magma_uplo_t uplo, magma_int_t n, double* dA, magma_int_t ldda, - double* w, double* wA, magma_int_t ldwa, double* work, magma_int_t lwork, - magma_int_t* iwork, magma_int_t liwork, magma_int_t* info) { + double* w, double* wA, magma_int_t ldwa, double* work, magma_int_t lwork, double* rwork, + magma_int_t lrwork, magma_int_t* iwork, magma_int_t liwork, magma_int_t* info) { + (void)rwork; // unused + (void)lrwork; // unused MagmaStreamSyncGuard guard; magma_dsyevd_gpu(jobz, uplo, n, dA, ldda, w, wA, ldwa, work, lwork, iwork, liwork, info); AT_CUDA_CHECK(cudaGetLastError()); @@ -497,19 +1023,66 @@ void magmaSymeig( template<> void magmaSymeig( magma_vec_t jobz, magma_uplo_t uplo, magma_int_t n, float* dA, magma_int_t ldda, - float* w, float* wA, magma_int_t ldwa, float* work, magma_int_t lwork, - magma_int_t* iwork, magma_int_t liwork, magma_int_t* info) { + float* w, float* wA, magma_int_t ldwa, float* work, magma_int_t lwork, float* rwork, + magma_int_t lrwork, magma_int_t* iwork, magma_int_t liwork, magma_int_t* info) { + (void)rwork; // unused + (void)lrwork; // unused MagmaStreamSyncGuard guard; magma_ssyevd_gpu(jobz, uplo, n, dA, ldda, w, wA, ldwa, work, lwork, iwork, liwork, info); AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaSymeig, double>( + magma_vec_t jobz, magma_uplo_t uplo, magma_int_t n, c10::complex* dA, magma_int_t ldda, + double* w, c10::complex* wA, magma_int_t ldwa, c10::complex* work, magma_int_t lwork, double* rwork, + magma_int_t lrwork, magma_int_t* iwork, magma_int_t liwork, magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_zheevd_gpu( + jobz, uplo, n, reinterpret_cast(dA), ldda, w, reinterpret_cast(wA), + ldwa, reinterpret_cast(work), lwork, rwork, lrwork, iwork, liwork, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaSymeig, float>( + magma_vec_t jobz, magma_uplo_t uplo, magma_int_t n, c10::complex* dA, magma_int_t ldda, + float* w, c10::complex* wA, magma_int_t ldwa, c10::complex* work, magma_int_t lwork, float* rwork, + magma_int_t lrwork, magma_int_t* iwork, magma_int_t liwork, magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_cheevd_gpu( + jobz, uplo, n, reinterpret_cast(dA), ldda, w, reinterpret_cast(wA), + ldwa, reinterpret_cast(work), lwork, rwork, lrwork, iwork, liwork, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaEig( + magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, double *A, magma_int_t lda, + double *wr, double *wi, double *VL, magma_int_t ldvl, + double *VR, magma_int_t ldvr, double *work, magma_int_t lwork, magma_int_t *info) { + MagmaStreamSyncGuard guard; + magma_dgeev(jobvl, jobvr, n, A, lda, wr, wi, VL, ldvl, VR, ldvr, work, lwork, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaEig( + magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, float *A, magma_int_t lda, + float *wr, float *wi, float *VL, magma_int_t ldvl, + float *VR, magma_int_t ldvr, float *work, magma_int_t lwork, magma_int_t *info) { + MagmaStreamSyncGuard guard; + magma_sgeev(jobvl, jobvr, n, A, lda, wr, wi, VL, ldvl, VR, ldvr, work, lwork, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaSvd( magma_vec_t jobz, magma_int_t m, magma_int_t n, double* A, magma_int_t lda, double* s, double* U, magma_int_t ldu, double* VT, magma_int_t ldvt, double* work, magma_int_t lwork, - magma_int_t* iwork, magma_int_t* info) { + double *rwork, magma_int_t* iwork, magma_int_t* info) { + (void)rwork; // unused MagmaStreamSyncGuard guard; magma_dgesdd(jobz, m, n, A, lda, s, U, ldu, VT, ldvt, work, lwork, iwork, info); AT_CUDA_CHECK(cudaGetLastError()); @@ -520,12 +1093,43 @@ void magmaSvd( magma_vec_t jobz, magma_int_t m, magma_int_t n, float* A, magma_int_t lda, float* s, float* U, magma_int_t ldu, float* VT, magma_int_t ldvt, float* work, magma_int_t lwork, - magma_int_t* iwork, magma_int_t* info) { + float* rwork, magma_int_t* iwork, magma_int_t* info) { + (void)rwork; // unused MagmaStreamSyncGuard guard; magma_sgesdd(jobz, m, n, A, lda, s, U, ldu, VT, ldvt, work, lwork, iwork, info); AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaSvd, float>( + magma_vec_t jobz, magma_int_t m, magma_int_t n, c10::complex* A, + magma_int_t lda, float* s, c10::complex* U, magma_int_t ldu, + c10::complex* VT, magma_int_t ldvt, c10::complex* work, magma_int_t lwork, + float *rwork, magma_int_t* iwork, magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_cgesdd(jobz, m, n, reinterpret_cast(A), lda, s, + reinterpret_cast(U), ldu, + reinterpret_cast(VT), ldvt, + reinterpret_cast(work), lwork, + rwork, iwork, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaSvd, double>( + magma_vec_t jobz, magma_int_t m, magma_int_t n, c10::complex* A, + magma_int_t lda, double* s, c10::complex* U, magma_int_t ldu, + c10::complex* VT, magma_int_t ldvt, c10::complex* work, magma_int_t lwork, + double *rwork, magma_int_t* iwork, magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_zgesdd(jobz, m, n, reinterpret_cast(A), lda, s, + reinterpret_cast(U), ldu, + reinterpret_cast(VT), ldvt, + reinterpret_cast(work), lwork, + rwork, iwork, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaLuSolve( magma_int_t n, magma_int_t nrhs, double* dA, magma_int_t ldda, magma_int_t* ipiv, @@ -544,6 +1148,23 @@ void magmaLuSolve( AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaLuSolve>( + magma_int_t n, magma_int_t nrhs, c10::complex* dA, magma_int_t ldda, magma_int_t* ipiv, + c10::complex* dB, magma_int_t lddb, magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_zgetrs_gpu(MagmaNoTrans, n, nrhs, reinterpret_cast(dA), ldda, ipiv, reinterpret_cast(dB), lddb, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaLuSolve>( + magma_int_t n, magma_int_t nrhs, c10::complex* dA, magma_int_t ldda, magma_int_t* ipiv, + c10::complex* dB, magma_int_t lddb, magma_int_t* info) { + MagmaStreamSyncGuard guard; + magma_cgetrs_gpu(MagmaNoTrans, n, nrhs, reinterpret_cast(dA), ldda, ipiv, reinterpret_cast(dB), lddb, info); + AT_CUDA_CHECK(cudaGetLastError()); +} template<> void magmaLuSolveBatched( @@ -562,6 +1183,24 @@ void magmaLuSolveBatched( info = magma_sgetrs_batched(MagmaNoTrans, n, nrhs, dA_array, ldda, dipiv_array, dB_array, lddb, batchsize, magma_queue.get_queue()); AT_CUDA_CHECK(cudaGetLastError()); } + +template<> +void magmaLuSolveBatched>( + magma_int_t n, magma_int_t nrhs, c10::complex** dA_array, magma_int_t ldda, magma_int_t** dipiv_array, + c10::complex** dB_array, magma_int_t lddb, magma_int_t& info, + magma_int_t batchsize, const MAGMAQueue& magma_queue) { + info = magma_zgetrs_batched(MagmaNoTrans, n, nrhs, reinterpret_cast(dA_array), ldda, dipiv_array, reinterpret_cast(dB_array), lddb, batchsize, magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaLuSolveBatched>( + magma_int_t n, magma_int_t nrhs, c10::complex** dA_array, magma_int_t ldda, magma_int_t** dipiv_array, + c10::complex** dB_array, magma_int_t lddb, magma_int_t& info, + magma_int_t batchsize, const MAGMAQueue& magma_queue) { + info = magma_cgetrs_batched(MagmaNoTrans, n, nrhs, reinterpret_cast(dA_array), ldda, dipiv_array, reinterpret_cast(dB_array), lddb, batchsize, magma_queue.get_queue()); + AT_CUDA_CHECK(cudaGetLastError()); +} #endif #define ALLOCATE_ARRAY(name, type, size) \ @@ -571,7 +1210,7 @@ void magmaLuSolveBatched( // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template -static void apply_solve(Tensor& b, Tensor& A, std::vector& infos) { +static void apply_solve(Tensor& b, Tensor& A, Tensor& infos) { #ifndef USE_MAGMA AT_ERROR("solve: MAGMA library not found in " "compilation. Please rebuild with MAGMA."); @@ -580,25 +1219,24 @@ AT_ERROR("solve: MAGMA library not found in " auto b_data = b.data_ptr(); magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)"); magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)"); + magma_int_t lda = std::max(magma_int_t{1}, n); if (b.dim() == 2) { auto ipiv = at::empty({n}, at::kInt); - magma_int_t info = 0; - magmaSolve(n, nrhs, A_data, n, ipiv.data_ptr(), - b_data, n, &info); - infos[0] = info; + infos = infos.to(at::kCPU); // magmaSolve requires infos tensor to live on CPU + magmaSolve(n, nrhs, A_data, lda, ipiv.data_ptr(), + b_data, lda, infos.data_ptr()); } else { + auto infos_data = infos.data_ptr(); auto A_mat_stride = matrixStride(A); auto b_mat_stride = matrixStride(b); magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount"); - magma_int_t* info_array; magma_int_t* ipiv_data; magma_int_t** ipiv_array; scalar_t** A_array; scalar_t** b_array; - ALLOCATE_ARRAY(info_array, magma_int_t, batch_size); ALLOCATE_ARRAY(ipiv_data, magma_int_t, batch_size * n); ALLOCATE_ARRAY(ipiv_array, magma_int_t*, batch_size); ALLOCATE_ARRAY(A_array, scalar_t*, batch_size); @@ -622,10 +1260,10 @@ AT_ERROR("solve: MAGMA library not found in " scalar_t** A_array_cur = &A_array[mini_idx]; scalar_t** b_array_cur = &b_array[mini_idx]; magma_int_t** ipiv_array_cur = &ipiv_array[mini_idx]; - magma_int_t* info_array_cur = &info_array[mini_idx]; + magma_int_t* info_array_cur = &infos_data[mini_idx]; magmaSolveBatched( - n, nrhs, A_array_cur, n, ipiv_array_cur, b_array_cur, n, + n, nrhs, A_array_cur, lda, ipiv_array_cur, b_array_cur, lda, info_array_cur, batch_limit, magma_queue); } @@ -633,12 +1271,8 @@ AT_ERROR("solve: MAGMA library not found in " // which concisely is equal to batch_size % batch_limit if (batch_size % batch_limit != 0) { magmaSolveBatched( - n, nrhs, &A_array[mini_idx], n, &ipiv_array[mini_idx], &b_array[mini_idx], n, - &info_array[mini_idx], batch_size % batch_limit, magma_queue); - } - - for (int64_t i = 0; i < batch_size; i++) { - infos[i] = info_array[i]; + n, nrhs, &A_array[mini_idx], lda, &ipiv_array[mini_idx], &b_array[mini_idx], lda, + &infos_data[mini_idx], batch_size % batch_limit, magma_queue); } } #endif @@ -647,22 +1281,40 @@ AT_ERROR("solve: MAGMA library not found in " std::tuple _solve_helper_cuda(const Tensor& self, const Tensor& A) { auto self_working_copy = cloneBatchedColumnMajor(self); auto A_working_copy = cloneBatchedColumnMajor(A); - std::vector infos(batchCount(self), 0); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "solve_cuda", [&]{ + auto infos = at::empty({std::max(1, batchCount(self))}, self.options().dtype(kInt)); + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "solve_cuda", [&]{ apply_solve(self_working_copy, A_working_copy, infos); }); if (self.dim() > 2) { batchCheckErrors(infos, "solve_cuda"); } else { - singleCheckErrors(infos[0], "solve_cuda"); + singleCheckErrors(infos.item().toInt(), "solve_cuda"); } return std::tuple(self_working_copy, A_working_copy); } +// This is a type dispatching helper function for 'apply_solve' +Tensor& _linalg_solve_out_helper_cuda(Tensor& result, Tensor& input, Tensor& infos) { + // 'result' and 'input' should be in column major order (it should be checked before calling this function) + // the content of 'result', 'input' and 'infos' is overriden by 'apply_solve' + // 'result' should contain data of 'other' tensor (right-hand-side of the linear system of equations) + // 'input' should contain data of origianl 'input' tensor (left-hand-side of the linear system) + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "linalg_solve_out_cpu", [&]{ + apply_solve(result, input, infos); + }); + return result; +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ inverse ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +/* +Computes the inverse of n-by-n matrix 'self', it is saved to 'self_inv'. +'infos' is an int Tensor containing error codes for each matrix in the batched input. +'infos_lu' is for holding magmaLU errors, and 'infos_getri' is for holding magmaGetri errors +For more information see MAGMA's documentation for GETRI and GETRF routines. +*/ template -static void apply_batched_inverse(Tensor& self, Tensor& self_inv, std::vector& infos) { +static void apply_batched_inverse(Tensor& self, Tensor& self_inv, Tensor& infos_lu, Tensor& infos_getri) { #ifndef USE_MAGMA AT_ERROR("inverse: MAGMA library not found in " "compilation. Please rebuild with MAGMA."); @@ -672,17 +1324,24 @@ AT_ERROR("inverse: MAGMA library not found in " auto self_inv_data = self_inv.data_ptr(); auto self_inv_mat_stride = matrixStride(self_inv); + auto infos_lu_data = infos_lu.data_ptr(); + auto infos_getri_data = infos_getri.data_ptr(); + magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount"); + // MAGMA does not work with batch_size == 0, let's return early in this case + if (batch_size == 0) { + return; + } + magma_int_t n = magma_int_cast(self.size(-2), "self.size(-2)"); + magma_int_t lda = std::max(1, n); - magma_int_t* info_array; magma_int_t* ipiv_data; magma_int_t** ipiv_array; scalar_t** self_array; scalar_t** self_inv_array; - ALLOCATE_ARRAY(info_array, magma_int_t, batch_size); - ALLOCATE_ARRAY(ipiv_data, magma_int_t, batch_size * n); + ALLOCATE_ARRAY(ipiv_data, magma_int_t, batch_size * lda); ALLOCATE_ARRAY(ipiv_array, magma_int_t*, batch_size); ALLOCATE_ARRAY(self_array, scalar_t*, batch_size); ALLOCATE_ARRAY(self_inv_array, scalar_t*, batch_size); @@ -696,7 +1355,7 @@ AT_ERROR("inverse: MAGMA library not found in " MAGMAQueue magma_queue(self.get_device()); magmaLuBatched( - n, n, self_array, n, ipiv_array, info_array, + n, n, self_array, lda, ipiv_array, infos_lu_data, batch_size, magma_queue); constexpr int64_t batch_limit = 65535; @@ -708,67 +1367,67 @@ AT_ERROR("inverse: MAGMA library not found in " scalar_t** self_array_cur = &self_array[mini_idx]; scalar_t** self_inv_array_cur = &self_inv_array[mini_idx]; magma_int_t** ipiv_array_cur = &ipiv_array[mini_idx]; - magma_int_t* info_array_cur = &info_array[mini_idx]; + magma_int_t* info_array_cur_getri = &infos_getri_data[mini_idx]; magmaGetriBatched( - n, self_array_cur, n, ipiv_array_cur, self_inv_array_cur, - n, info_array_cur, batch_limit, magma_queue); + n, self_array_cur, lda, ipiv_array_cur, self_inv_array_cur, + lda, info_array_cur_getri, batch_limit, magma_queue); } // Compute whatever is left = batch_size - floor(batch_size / batch_limit) * batch_limit // which concisely is equal to batch_size % batch_limit if (batch_size % batch_limit != 0) { magmaGetriBatched( - n, &self_array[mini_idx], n, &ipiv_array[mini_idx], &self_inv_array[mini_idx], - n, &info_array[mini_idx], batch_size % batch_limit, magma_queue); - } - - for (int64_t i = 0; i < batch_size; i++) { - infos[i] = info_array[i]; + n, &self_array[mini_idx], lda, &ipiv_array[mini_idx], &self_inv_array[mini_idx], + lda, &infos_getri_data[mini_idx], batch_size % batch_limit, magma_queue); } #endif } template -static void apply_single_inverse(Tensor& self, int64_t& info) { +static void apply_single_inverse(Tensor& self, Tensor& infos_lu, Tensor& infos_getri) { #ifndef USE_MAGMA AT_ERROR("inverse: MAGMA library not found in " "compilation. Please rebuild with MAGMA."); #else auto self_data = self.data_ptr(); magma_int_t n = magma_int_cast(self.size(-2), "self.size(-2)"); + magma_int_t lda = std::max(1, n); magma_int_t lwork = n * magmaGetriOptimalBlocksize(n); - magma_int_t info_tmp = 0; - Tensor ipiv = at::empty({n}, at::kInt); + // magmaLu and magmaGetri requires infos tensor to live on CPU + infos_lu = infos_lu.to(at::kCPU); + infos_getri = infos_getri.to(at::kCPU); + + Tensor ipiv = at::empty({lda}, at::kInt); Tensor dwork = at::empty({lwork}, self.options()); - magmaLu(n, n, self_data, n, ipiv.data_ptr(), &info_tmp); - if (info_tmp != 0) { - info = info_tmp; - return; - } + magmaLu(n, n, self_data, lda, ipiv.data_ptr(), infos_lu.data_ptr()); magmaGetri( - n, self_data, n, ipiv.data_ptr(), dwork.data_ptr(), lwork, &info_tmp); - info = info_tmp; + n, self_data, lda, ipiv.data_ptr(), dwork.data_ptr(), lwork, infos_getri.data_ptr()); #endif } Tensor _inverse_helper_cuda_legacy(const Tensor& self) { auto self_inv_working_copy = cloneBatchedColumnMajor(self); if (self.dim() > 2) { - std::vector infos(batchCount(self), 0); + auto infos_lu = at::zeros({std::max(1, batchCount(self))}, self.options().dtype(kInt)); + auto infos_getri = at::zeros({std::max(1, batchCount(self))}, self.options().dtype(kInt)); auto self_working_copy = cloneBatchedColumnMajor(self); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cuda", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "inverse_cuda", [&]{ apply_batched_inverse( - self_working_copy, self_inv_working_copy, infos); + self_working_copy, self_inv_working_copy, infos_lu, infos_getri); }); - batchCheckErrors(infos, "inverse_cuda"); + batchCheckErrors(infos_lu, "inverse_cuda"); + batchCheckErrors(infos_getri, "inverse_cuda"); } else { - int64_t info = 0; - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cuda", [&]{ - apply_single_inverse(self_inv_working_copy, info); + // magmaLu and magmaGetri requires infos tensor to live on CPU + auto infos_lu = at::zeros({1}, self.options().dtype(kInt).device(kCPU)); + auto infos_getri = at::zeros({1}, self.options().dtype(kInt).device(kCPU)); + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "inverse_cuda", [&]{ + apply_single_inverse(self_inv_working_copy, infos_lu, infos_getri); }); - singleCheckErrors(info, "inverse_cuda"); + singleCheckErrors(infos_lu.item().toInt(), "inverse_cuda"); + singleCheckErrors(infos_getri.item().toInt(), "inverse_cuda"); } return self_inv_working_copy; } @@ -785,6 +1444,39 @@ Tensor _inverse_helper_cuda(const Tensor& self) { #endif } +// This is a type dispatching helper function for 'apply_batched_inverse' and 'singleCheckErrors' +Tensor& _linalg_inv_out_helper_cuda_legacy(Tensor& result, Tensor& infos_lu, Tensor& infos_getri) { + // assuming result is in column major order and contains the matrices to invert + if (result.dim() > 2) { + auto input_working_copy = cloneBatchedColumnMajor(result); + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "linalg_inv_out_cuda", [&]{ + apply_batched_inverse( + input_working_copy, result, infos_lu, infos_getri); + }); + } else { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "linalg_inv_out_cuda", [&]{ + apply_single_inverse(result, infos_lu, infos_getri); + }); + } + return result; +} + +// This is a MAGMA/cuSOLVER dispatching helper function +Tensor& _linalg_inv_out_helper_cuda(Tensor &result, Tensor& infos_lu, Tensor& infos_getri) { + // This function calculates the inverse matrix in-place + // result should be in column major order and contain matrices to invert +#ifdef USE_CUSOLVER + if ((result.dim() == 2) || (/* result.dim() > 2 && */ batchCount(result) <= 2) || !use_magma_) { + return _linalg_inv_out_helper_cuda_lib(result, infos_lu, infos_getri); // cusolver or cublas + } else { + return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda + } +#else + return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda +#endif + return result; +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template @@ -859,7 +1551,7 @@ Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upp int64_t info = 0; auto self_working_copy = cloneBatchedColumnMajor(self); auto A_working_copy = cloneBatchedColumnMajor(A); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "cholesky_solve_cuda", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "cholesky_solve_cuda", [&]{ apply_cholesky_solve(self_working_copy, A_working_copy, upper, info); }); TORCH_CHECK(info == 0, "MAGMA cholesky_solve : invalid argument: ", -info); @@ -878,10 +1570,11 @@ AT_ERROR("cholesky: MAGMA library not found in " auto self_data = self.data_ptr(); magma_int_t n = magma_int_cast(self.size(-2), "self.size(-2)"); + auto lda = std::max(1, n); if (self.dim() == 2) { magma_int_t info = 0; - magmaCholesky(uplo, n, self_data, n, &info); + magmaCholesky(uplo, n, self_data, lda, &info); infos[0] = info; } else { auto self_mat_stride = matrixStride(self); @@ -900,10 +1593,11 @@ AT_ERROR("cholesky: MAGMA library not found in " MAGMAQueue magma_queue(self.get_device()); - constexpr int64_t batch_limit = 262140; + int64_t batch_limit = self.is_complex() ? 65535 : 262140; // Compute as many batches of 262140 possible // 262140 is the size of the largest batch of matrices that can be run with // violating maximum kernel configuration + // For complex input the batch limit is 65535 (determined experimentally, see https://github.com/pytorch/pytorch/pull/47047#discussion_r516086923 for more information) // The number of "mini"-batches are floor(batch_size / batch_limit) // and these cover floor(batch_size / batch_limit) * batch_limit cholesky calls int64_t mini_batches = batch_size / batch_limit, mini_idx; @@ -912,14 +1606,14 @@ AT_ERROR("cholesky: MAGMA library not found in " magma_int_t* info_array_cur = &info_array[mini_idx]; magmaCholeskyBatched( - uplo, n, self_array_cur, n, info_array_cur, batch_limit, magma_queue); + uplo, n, self_array_cur, lda, info_array_cur, batch_limit, magma_queue); } // Compute whatever is left = batch_size - floor(batch_size / batch_limit) * batch_limit // which concisely is equal to batch_size % batch_limit if (batch_size % batch_limit != 0) { magmaCholeskyBatched( - uplo, n, &self_array[mini_idx], n, &info_array[mini_idx], batch_size % batch_limit, magma_queue); + uplo, n, &self_array[mini_idx], lda, &info_array[mini_idx], batch_size % batch_limit, magma_queue); } for (int64_t i = 0; i < batch_size; i++) { @@ -1036,7 +1730,7 @@ std::tuple _lu_with_info_cuda(const Tensor& self, bool p self_working_copy = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } else { self_working_copy = cloneBatchedColumnMajor(self); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_cuda", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "lu_cuda", [&]{ apply_lu(self_working_copy, pivots_tensor, infos_tensor, pivot); }); } @@ -1068,11 +1762,14 @@ AT_ERROR("triangular_solve: MAGMA library not found in " magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)"); magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount"); + MAGMAQueue magma_queue(b.get_device()); + // batch_size == 1 implies that: // 1. the RHS and LHS tensors have 2 dimensions, or // 2. the RHS and LHS tensors have more than 2 dimensions but all batch dimensions are 1 if (batch_size == 1) { - magmaTriangularSolve(uplo, trans, diag, n, nrhs, A_data, n, b_data, n); + magmaTriangularSolve( + uplo, trans, diag, n, nrhs, A_data, n, b_data, n, magma_queue); } else { auto A_mat_stride = matrixStride(A); auto b_mat_stride = matrixStride(b); @@ -1120,7 +1817,7 @@ std::tuple _triangular_solve_helper_cuda(const Tensor& self, con bool upper, bool transpose, bool unitriangular) { auto self_working_copy = cloneBatchedColumnMajor(self); auto A_working_copy = cloneBatchedColumnMajor(A); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "triangular_solve_cuda", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "triangular_solve_cuda", [&]{ apply_triangular_solve(self_working_copy, A_working_copy, upper, transpose, unitriangular); }); return std::tuple(self_working_copy, A_working_copy); @@ -1129,18 +1826,18 @@ std::tuple _triangular_solve_helper_cuda(const Tensor& self, con // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ qr ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template -static void apply_qr(Tensor& Q, Tensor& R, int64_t n_columns, std::vector& infos) { +static void apply_qr(Tensor& Q, Tensor& R, int64_t q_size_minus_2, int64_t r_size_minus_1, int64_t n_columns, + bool compute_q, std::vector& infos) { #ifndef USE_MAGMA AT_ERROR("qr: MAGMA library not found in " "compilation. Please rebuild with MAGMA."); #else - auto q_data = Q.data_ptr(); + + magma_int_t m = magma_int_cast(q_size_minus_2, "Q.size(-2)"); + magma_int_t n = magma_int_cast(r_size_minus_1, "R.size(-1)"); + auto r_data = R.data_ptr(); - auto q_matrix_stride = matrixStride(Q); auto r_matrix_stride = matrixStride(R); - - magma_int_t m = magma_int_cast(Q.size(-2), "Q.size(-2)"); - magma_int_t n = magma_int_cast(R.size(-1), "R.size(-1)"); magma_int_t k = m < n ? m : n; magma_int_t nb = magmaGeqrfOptimalBlocksize(m, n); int64_t batch_size = batchCount(R); @@ -1163,6 +1860,10 @@ AT_ERROR("qr: MAGMA library not found in " return; } } + if (!compute_q) { + // this is for mode='r' + return; + } // This phase computes Q (the raw version) // We require to perform ?geqrf_gpu again due to this bug in MAGMA: @@ -1170,6 +1871,8 @@ AT_ERROR("qr: MAGMA library not found in " // - ?geqrf2_gpu gives correct R, but doesn't allow computation of Q via ?orgqr_gpu // Refer to the below link for more details: // http://icl.cs.utk.edu/magma/forum/viewtopic.php?f=2&t=1015&p=2800&hilit=geqrf_gpu#p2800 + auto q_data = Q.data_ptr(); + auto q_matrix_stride = matrixStride(Q); for (int64_t i = 0; i < batch_size; i++) { scalar_t* q_working_ptr = &q_data[i * q_matrix_stride]; magmaGeqrf(m, n, q_working_ptr, m, tau_data, work_data, &info, /*is_v2=*/false); @@ -1186,36 +1889,43 @@ AT_ERROR("qr: MAGMA library not found in " #endif } -std::tuple _qr_helper_cuda(const Tensor& self, bool some) { +std::tuple _linalg_qr_helper_cuda(const Tensor& self, std::string mode) { + bool compute_q, reduced; + std::tie(compute_q, reduced) = _parse_qr_mode(mode); std::vector infos(batchCount(self), 0); // Setup input geometry and inputs for apply_qr std::vector q_sizes, q_strides; int64_t n_columns_q; - std::tie(q_sizes, q_strides, n_columns_q) = _compute_geometry_for_Q(self, some); + std::tie(q_sizes, q_strides, n_columns_q) = _compute_geometry_for_Q(self, reduced); Tensor q_working_copy, r_working_copy; // If there are no elements, then we simply return a pair of tensors of required dimensions if (self.numel() == 0) { - // Fix the number of columns of q_working_copy appropriately - q_sizes[self.dim() - 1] = n_columns_q; - q_working_copy = at::eye(q_sizes[self.dim() - 2], q_sizes[self.dim() - 1], self.options()); - q_working_copy = q_working_copy.expand_as(q_working_copy); - - // We repurpose the same q_sizes for r_working_copy - // Fix the number of rows and columns of q_working_copy appropriately - q_sizes[self.dim() - 1] = self.size(-1); - q_sizes[self.dim() - 2] = n_columns_q; - r_working_copy = at::empty(q_sizes, self.options()); + int64_t n = self.size(-1); + r_working_copy = at::empty({n_columns_q, n}, self.options()); + if (compute_q) { + int64_t n_rows_q = q_sizes[self.dim() - 2]; + q_working_copy = at::eye(n_rows_q, n_columns_q, self.options()); + } else { + q_working_copy = at::empty({0}, self.options()); + } return std::make_tuple(q_working_copy, r_working_copy); } - q_working_copy = at::empty_strided(q_sizes, q_strides, self.options()); - q_working_copy.narrow(-1, 0, self.size(-1)).copy_(self); + if (compute_q) { + q_working_copy = at::empty_strided(q_sizes, q_strides, self.options()); + q_working_copy.narrow(-1, 0, self.size(-1)).copy_(self); + } else { + q_working_copy = at::empty({0}, self.options()); + } r_working_copy = cloneBatchedColumnMajor(self); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "qr_cuda", [&]{ - apply_qr(q_working_copy, r_working_copy, n_columns_q, infos); + int64_t m = q_sizes[self.dim() - 2]; + int64_t n = r_working_copy.size(-1); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "qr_cuda", [&]{ + apply_qr(q_working_copy, r_working_copy, m, n, n_columns_q, compute_q, infos); }); if (self.dim() > 2) { batchCheckErrors(infos, "qr_cuda"); @@ -1223,10 +1933,12 @@ std::tuple _qr_helper_cuda(const Tensor& self, bool some) { singleCheckErrors(infos[0], "qr_cuda"); } - return std::make_tuple(q_working_copy.narrow(-1, 0, n_columns_q), - r_working_copy.narrow(-2, 0, n_columns_q).triu()); + if (compute_q) { + q_working_copy = q_working_copy.narrow(-1, 0, n_columns_q); + } + r_working_copy = r_working_copy.narrow(-2, 0, n_columns_q).triu(); + return std::make_tuple(q_working_copy, r_working_copy); } - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ symeig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template @@ -1235,8 +1947,9 @@ static void apply_symeig(Tensor& self, Tensor& eigvals, bool eigenvectors, bool AT_ERROR("symeig: MAGMA library not found in " "compilation. Please rebuild with MAGMA."); #else + using value_t = typename c10::scalar_value_type::type; auto self_data = self.data_ptr(); - auto eigvals_data = eigvals.data_ptr(); + auto eigvals_data = eigvals.data_ptr(); auto self_matrix_stride = matrixStride(self); auto eigvals_stride = eigvals.size(-1); int64_t batch_size = batchCount(self); @@ -1257,20 +1970,30 @@ AT_ERROR("symeig: MAGMA library not found in " scalar_t wkopt; magma_int_t liwork = -1; magma_int_t iwkopt; - magmaSymeig(jobz, uplo, n, self_data, n, eigvals_data, wA, n, &wkopt, lwork, &iwkopt, liwork, &info); + magma_int_t lrwork = -1; + value_t rwkopt; + magmaSymeig(jobz, uplo, n, self_data, n, eigvals_data, wA, n, &wkopt, lwork, &rwkopt, lrwork, &iwkopt, liwork, &info); scalar_t* work; magma_int_t* iwork; - lwork = magma_int_cast(wkopt, "work_size"); + lwork = magma_int_cast(real_impl(wkopt), "work_size"); liwork = magma_int_cast(iwkopt, "iwork_size"); ALLOCATE_ARRAY(work, scalar_t, lwork); ALLOCATE_ARRAY(iwork, magma_int_t, liwork); + value_t* rwork = nullptr; + c10::Storage storage_rwork; + if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { + lrwork = magma_int_cast(rwkopt, "rwork_size"); + storage_rwork = pin_memory(lrwork); + rwork = static_cast(storage_rwork.data()); + } + for (int64_t i = 0; i < batch_size; i++) { scalar_t* self_working_ptr = &self_data[i * self_matrix_stride]; - scalar_t* eigvals_working_ptr = &eigvals_data[i * eigvals_stride]; - magmaSymeig(jobz, uplo, n, self_working_ptr, n, eigvals_working_ptr, - wA, n, work, lwork, iwork, liwork, &info); + value_t* eigvals_working_ptr = &eigvals_data[i * eigvals_stride]; + magmaSymeig(jobz, uplo, n, self_working_ptr, n, eigvals_working_ptr, + wA, n, work, lwork, rwork, lrwork, iwork, liwork, &info); infos[i] = info; if (info != 0) { return; @@ -1284,6 +2007,7 @@ std::tuple _symeig_helper_cuda(const Tensor& self, bool eigenvec auto self_sizes = self.sizes().vec(); self_sizes.pop_back(); + ScalarType dtype = toValueType(typeMetaToScalarType(self.dtype())); // magmaSymeig uses a hybrid CPU-GPU algorithm to compute the eigenvalues and eigenvectors. // The driver routine magma_(d/s)syev_gpu accepts a tensor on the CPU for eigvalenvalues. @@ -1291,15 +2015,15 @@ std::tuple _symeig_helper_cuda(const Tensor& self, bool eigenvec // In the case where self.numel() == 0, we just return an empty tensor of // dimensions on the CUDA (to avoid the unnecessary "to(at::kCUDA)") auto eigvals_working_copy = self.numel() == 0 - ? at::empty(self_sizes, self.options()) - : at::empty(self_sizes, self.options().device(at::kCPU)); + ? at::empty(self_sizes, self.options().dtype(dtype)) + : at::empty(self_sizes, self.options().dtype(dtype).device(at::kCPU)); if (self.numel() == 0) { return std::tuple(eigvals_working_copy, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT)); } auto self_working_copy = cloneBatchedColumnMajor(self); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "symeig_cuda", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "symeig_cuda", [&]{ apply_symeig(self_working_copy, eigvals_working_copy, eigenvectors, upper, infos); }); @@ -1315,6 +2039,101 @@ std::tuple _symeig_helper_cuda(const Tensor& self, bool eigenvec } } +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// magmaEig uses a hybrid CPU-GPU algorithm, which takes and return CPU +// memory. So, we accept a GPU tensor, copy it to CPU memory, and later copy +// the returned values from CPU to GPU. See also magmaSymeig, which uses a +// similar approach. + +template +static void apply_eig(const Tensor& self, bool eigenvectors, Tensor& out_eigvals, Tensor& out_eigvecs, + int64_t *info_ptr) { +#ifndef USE_MAGMA +TORCH_CHECK(false, "Calling torch.eig on a CUDA tensor requires compiling PyTorch with MAGMA. " + "Either transfer the tensor to the CPU before calling torch.eig or recompile with MAGMA."); +#else + TORCH_INTERNAL_ASSERT(self.device() == at::kCPU, "Internal error: apply_eig needs a CPU tensor"); + magma_vec_t jobvr = eigenvectors ? MagmaVec : MagmaNoVec; + magma_int_t n = magma_int_cast(self.size(-1), "n"); + auto self_data = self.data_ptr(); + + auto out_eigvals_data = out_eigvals.data_ptr(); + scalar_t *wr = out_eigvals_data; + scalar_t *wi = out_eigvals_data+n; + + scalar_t *vr_data = NULL; + magma_int_t ldvr = 1; + if (jobvr == MagmaVec) + { + vr_data = out_eigvecs.data_ptr(); + ldvr = n; + } + + if (n > 0) { + // call magmaEig once to get the optimal size of work_data + scalar_t wkopt; + magma_int_t info; + magmaEig(MagmaNoVec, jobvr, n, self_data, n, wr, wi, NULL, 1, vr_data, ldvr, &wkopt, -1, &info); + magma_int_t lwork = (magma_int_t) wkopt; + + // call it a 2nd time to to the actual work + scalar_t *work_data = nullptr; + ALLOCATE_ARRAY(work_data, scalar_t, lwork); + magmaEig(MagmaNoVec, jobvr, n, self_data, n, wr, wi, NULL, 1, vr_data, ldvr, work_data, lwork, &info); + *info_ptr = info; + } +#endif +} + +/* + * Internal helper; like eig_cuda but: + * 1. assume that self is a square matrix of side "n" + * 2. return CPU tensors (because this is what magmaEig returns), which will be copied to GPU memory + * by the caller + */ +std::tuple eig_kernel_impl(const Tensor& self, bool& eigenvectors) { + int64_t n = self.size(-1); + // copy self to pinned CPU memory + auto self_working_copy = at::empty_strided( + {n, n}, // square matrix + {1, n}, // column-ordered, as magmaEig expects + at::TensorOptions(at::kCPU).dtype(self.dtype()).pinned_memory(true)); + self_working_copy.copy_(self); + + // tensors holding the results. We use empty_strided to make them column-ordered + auto options = self.options().device(at::kCPU).memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto out_eigvals = at::empty_strided({n, 2}, {1, n}, options); + auto out_eigvecs = eigenvectors + ? at::empty_strided({n, n}, {1, n}, options) + : Tensor(); + + int64_t info; + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "eig_cuda", [&]{ + apply_eig(self_working_copy, eigenvectors, out_eigvals, out_eigvecs, &info); + }); + singleCheckErrors(info, "eig_cuda"); + + return std::tuple(out_eigvals, out_eigvecs); +} + +REGISTER_DISPATCH(eig_stub, &eig_kernel_impl); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ syevd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// This function computes eigenvalues 'w' and eigenvectors 'v' of the tensor 'self' +// compute_eigenvectors controls whether eigenvectors should be computed +// uplo controls the portion of input matrix to consider in computations, allowed values are "u", "U", "l", "L" +// '_symeig_helper_cuda' prepares correct input for 'apply_symeig' and checks for possible errors using 'infos' +// See also CPU implementation in aten/src/ATen/native/BatchLinearAlgebra.cpp +std::tuple _syevd_helper_cuda(const Tensor& self, bool compute_eigenvectors, std::string uplo_str) { + // NumPy allows lowercase input for UPLO argument + // It is assumed that uplo_str is either "U" or "L" + char uplo = std::toupper(uplo_str[0]); + bool upper = uplo == 'U' ? true : false; + return _symeig_helper_cuda(self, compute_eigenvectors, upper); +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template @@ -1324,9 +2143,10 @@ static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT, AT_ERROR("svd: MAGMA library not found in " "compilation. Please rebuild with MAGMA."); #else + using value_t = typename c10::scalar_value_type::type; auto self_data = self.data_ptr(); auto U_data = U.data_ptr(); - auto S_data = S.data_ptr(); + auto S_data = S.data_ptr(); auto VT_data = VT.data_ptr(); auto self_stride = matrixStride(self); auto U_stride = matrixStride(U); @@ -1338,7 +2158,18 @@ AT_ERROR("svd: MAGMA library not found in " magma_int_t m = magma_int_cast(self.size(-2), "m"); magma_int_t n = magma_int_cast(self.size(-1), "n"); - auto k = std::min(m, n); + auto mn = std::min(m, n); + + c10::Storage storage_rwork; + value_t* rwork = nullptr; + + magma_int_t* iwork; + ALLOCATE_ARRAY(iwork, magma_int_t, 8 * mn); + if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { + auto lrwork = computeLRWorkDim(jobchar, m, n); + storage_rwork = pin_memory(lrwork); + rwork = static_cast(storage_rwork.data()); + } magma_int_t info = 0; // Run once, first to get the optimum work size. @@ -1347,22 +2178,20 @@ AT_ERROR("svd: MAGMA library not found in " // and (batch_size - 1) calls to allocate and deallocate workspace using at::empty() magma_int_t lwork = -1; scalar_t wkopt; - magma_int_t* iwork; - ALLOCATE_ARRAY(iwork, magma_int_t, 8 * k); - magmaSvd(jobz, m, n, self_data, m, S_data, U_data, m, VT_data, n, &wkopt, lwork, iwork, &info); - lwork = magma_int_cast(wkopt, "work_size"); + magmaSvd(jobz, m, n, self_data, m, S_data, U_data, m, VT_data, n, &wkopt, lwork, rwork, iwork, &info); + lwork = magma_int_cast(real_impl(wkopt), "work_size"); scalar_t* work; ALLOCATE_ARRAY(work, scalar_t, lwork); for (int64_t i = 0; i < batchsize; i++) { scalar_t* self_working_ptr = &self_data[i * self_stride]; - scalar_t* S_working_ptr = &S_data[i * S_stride]; + value_t* S_working_ptr = &S_data[i * S_stride]; scalar_t* U_working_ptr = &U_data[i * U_stride]; scalar_t* VT_working_ptr = &VT_data[i * VT_stride]; // Compute S, U (optionally), VT (optionally) - magmaSvd(jobz, m, n, self_working_ptr, m, - S_working_ptr, U_working_ptr, m, VT_working_ptr, n, work, lwork, iwork, &info); + magmaSvd(jobz, m, n, self_working_ptr, m, + S_working_ptr, U_working_ptr, m, VT_working_ptr, n, work, lwork, rwork, iwork, &info); infos[i] = info; if (info != 0) { return; @@ -1395,7 +2224,7 @@ std::tuple _svd_helper_cuda(const Tensor& self, bool som at::TensorOptions(at::kCPU).dtype(self.dtype()).pinned_memory(true)); self_working_copy.copy_(self); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "svd_cuda", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "svd_cuda", [&] { apply_svd(self_working_copy, U_working_copy, S_working_copy, VT_working_copy, jobchar, infos); }); @@ -1406,12 +2235,12 @@ std::tuple _svd_helper_cuda(const Tensor& self, bool som } U_working_copy = same_stride_to(U_working_copy, self.options()); - S_working_copy = same_stride_to(S_working_copy, self.options()); + S_working_copy = same_stride_to(S_working_copy, S_working_copy.options().device(self.device())); VT_working_copy = same_stride_to(VT_working_copy, self.options()); if (compute_uv) { if (some) { - VT_working_copy = VT_working_copy.narrow(-1, 0, k); + VT_working_copy = VT_working_copy.narrow(-2, 0, k); } } else { VT_working_copy.zero_(); @@ -1419,9 +2248,11 @@ std::tuple _svd_helper_cuda(const Tensor& self, bool som } } else { U_working_copy = same_stride_to(U_working_copy, self.options()).zero_(); - S_working_copy = same_stride_to(S_working_copy, self.options()); + S_working_copy = same_stride_to(S_working_copy, S_working_copy.options().device(self.device())); VT_working_copy = same_stride_to(VT_working_copy, self.options()).zero_(); } + // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly. + VT_working_copy.transpose_(-2, -1); return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } @@ -1509,7 +2340,7 @@ Tensor _lu_solve_helper_cuda(const Tensor& self, const Tensor& LU_data, const Te if (self.numel() == 0 || LU_data.numel() == 0) { return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_solve_cuda", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "lu_solve_cuda", [&]{ apply_lu_solve(self_working_copy, LU_data_working_copy, LU_pivots_working_copy, info); }); TORCH_CHECK(info == 0, "MAGMA lu_solve : invalid argument: ", -info); diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu index cf23b73a4cf05..534f257d55bbc 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu @@ -26,41 +26,50 @@ inline static Tensor column_major_identity_matrix_like(const Tensor& self) { } template -inline static void _apply_single_inverse_helper(scalar_t* self_ptr, scalar_t* self_inv_ptr, int* ipiv_ptr, int* info_ptr, int n) { +inline static void _apply_single_inverse_helper(scalar_t* self_ptr, scalar_t* self_inv_ptr, int* ipiv_ptr, int* info_getrf_ptr, int* info_getrs_ptr, int n, int lda) { // self_inv_ptr should already be an identity matrix auto handle = at::cuda::getCurrentCUDASolverDnHandle(); - at::cuda::solver::getrf(handle, n, n, self_ptr, n, ipiv_ptr, info_ptr); - at::cuda::solver::getrs(handle, n, n, self_ptr, n, ipiv_ptr, self_inv_ptr, n, info_ptr); + at::cuda::solver::getrf(handle, n, n, self_ptr, lda, ipiv_ptr, info_getrf_ptr); + at::cuda::solver::getrs(handle, n, n, self_ptr, lda, ipiv_ptr, self_inv_ptr, lda, info_getrs_ptr); } template -static void apply_batched_inverse_lib(Tensor& self, Tensor& self_inv, Tensor& infos) { +static void apply_batched_inverse_lib(Tensor& self, Tensor& self_inv, Tensor& infos_getrf, Tensor& infos_getrs) { const int batch_size = cuda_int_cast(batchCount(self), "batchCount"); const int n = cuda_int_cast(self.size(-2), "self.size(-2)"); + const int lda = std::max(1, n); auto self_data = self.data_ptr(); auto self_mat_stride = matrixStride(self); auto self_inv_data = self_inv.data_ptr(); auto self_inv_mat_stride = matrixStride(self_inv); + auto infos_getrf_data = infos_getrf.data_ptr(); + auto infos_getrs_data = infos_getrs.data_ptr(); + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); if (use_loop_launch(batch_size, n)) { - int* p_infos = infos.data_ptr(); auto main_stream = at::cuda::getCurrentCUDAStream(); + at::cuda::CUDAEvent main_event; + main_event.record(main_stream); + for (int64_t i = 0; i < batch_size; i++) { auto stream = at::cuda::getStreamFromPool(); at::cuda::CUDAStreamGuard guard(stream); - at::cuda::CUDAEvent can_start; - can_start.record(main_stream); - can_start.block(main_stream); + main_event.block(stream); + + auto dataPtr = allocator.allocate(sizeof(int) * lda); + int* pivot = reinterpret_cast(dataPtr.get()); + + int* infos_getrf_working_ptr = &infos_getrf_data[i]; + int* infos_getrs_working_ptr = &infos_getrs_data[i]; - int* pivot = reinterpret_cast(allocator.allocate(sizeof(int) * n).get()); _apply_single_inverse_helper( - &self_data[i * self_mat_stride], &self_inv_data[i * self_inv_mat_stride], pivot, p_infos + i, n); + &self_data[i * self_mat_stride], &self_inv_data[i * self_inv_mat_stride], pivot, infos_getrf_working_ptr, infos_getrs_working_ptr, n, lda); at::cuda::CUDAEvent finished; finished.record(stream); @@ -77,27 +86,52 @@ static void apply_batched_inverse_lib(Tensor& self, Tensor& self_inv, Tensor& in reinterpret_cast(&self_inv_data[(batch_size-1) * self_inv_mat_stride]) + 1, static_cast(self_inv_mat_stride * sizeof(scalar_t)), self.options().dtype(at::kLong)); - int* ipiv_array = reinterpret_cast(allocator.allocate(sizeof(int)*batch_size*n).get()); + auto dataPtr = allocator.allocate(sizeof(int)*batch_size*lda); + int* ipiv_array = reinterpret_cast(dataPtr.get()); - at::cuda::blas::getrfBatched(n, reinterpret_cast(self_array.data_ptr()), n, - ipiv_array, infos.data_ptr(), batch_size); + at::cuda::blas::getrfBatched(n, reinterpret_cast(self_array.data_ptr()), lda, + ipiv_array, infos_getrf_data, batch_size); - at::cuda::blas::getriBatched(n, reinterpret_cast(self_array.data_ptr()), n, - ipiv_array, infos.data_ptr(), batch_size, reinterpret_cast(self_inv_array.data_ptr())); + at::cuda::blas::getriBatched(n, reinterpret_cast(self_array.data_ptr()), lda, + ipiv_array, reinterpret_cast(self_inv_array.data_ptr()), lda, infos_getrs_data, batch_size); } } template -static void apply_single_inverse_lib(const Tensor& self, Tensor& self_inv, int64_t& info) { +static void apply_single_inverse_lib(const Tensor& self, Tensor& self_inv, Tensor& infos_getrf, Tensor& infos_getrs) { int n = cuda_int_cast(self.size(-2), "self.size(-2)"); + int lda = std::max(1, n); - Tensor ipiv = at::empty({n}, self.options().dtype(at::kInt)); - Tensor info_tmp = at::zeros({1}, self.options().dtype(at::kInt)); + Tensor ipiv = at::empty({lda}, self.options().dtype(at::kInt)); _apply_single_inverse_helper( - self.data_ptr(), self_inv.data_ptr(), ipiv.data_ptr(), info_tmp.data_ptr(), n); + self.data_ptr(), self_inv.data_ptr(), ipiv.data_ptr(), infos_getrf.data_ptr(), infos_getrs.data_ptr(), n, lda); +} + +// This is a type dispatching helper function for 'apply_batched_inverse_lib' and 'apply_single_inverse_lib' +Tensor& _linalg_inv_out_helper_cuda_lib(Tensor& result, Tensor& infos_getrf, Tensor& infos_getrs) { + // assuming result is in column major order and contains the matrices to invert + Tensor input_working_copy = cloneBatchedColumnMajor(result); + + // for getrf + getrs (cusolver path) + // result should be filled with identity matrices + result.zero_(); + result.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(1); + + const int batch_size = cuda_int_cast(batchCount(result), "batchCount"); + + if (result.dim() > 2) { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "linalg_inv_out_cuda", [&]{ + apply_batched_inverse_lib( + input_working_copy, result, infos_getrf, infos_getrs); + }); + } else { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "linalg_inv_out_cuda", [&]{ + apply_single_inverse_lib(input_working_copy, result, infos_getrf, infos_getrs); + }); + } - info = info_tmp.item(); + return result; } Tensor _inverse_helper_cuda_lib(const Tensor& self) { @@ -106,18 +140,22 @@ Tensor _inverse_helper_cuda_lib(const Tensor& self) { const int batch_size = cuda_int_cast(batchCount(self), "batchCount"); if (self.dim() > 2 && batch_size > 1) { - Tensor infos = at::zeros({batchCount(self)}, self.options().dtype(kInt)); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cuda", [&]{ + Tensor infos_getrf = at::zeros({std::max(1, batchCount(self))}, self.options().dtype(kInt)); + Tensor infos_getrs = at::zeros({std::max(1, batchCount(self))}, self.options().dtype(kInt)); + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "inverse_cuda", [&]{ apply_batched_inverse_lib( - self_working_copy, self_inv_working_copy, infos); + self_working_copy, self_inv_working_copy, infos_getrf, infos_getrs); }); - batchCheckErrors(infos, "inverse_cuda"); + batchCheckErrors(infos_getrf, "inverse_cuda"); + batchCheckErrors(infos_getrs, "inverse_cuda"); } else { - int64_t info = 0; - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cuda", [&]{ - apply_single_inverse_lib(self_working_copy, self_inv_working_copy, info); + Tensor infos_getrf = at::zeros({1}, self.options().dtype(kInt)); + Tensor infos_getrs = at::zeros({1}, self.options().dtype(kInt)); + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "inverse_cuda", [&]{ + apply_single_inverse_lib(self_working_copy, self_inv_working_copy, infos_getrf, infos_getrs); }); - singleCheckErrors(info, "inverse_cuda"); + batchCheckErrors(infos_getrf, "inverse_cuda"); + batchCheckErrors(infos_getrs, "inverse_cuda"); } return self_inv_working_copy; diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h index 85014c5773ee3..2be18137a64fb 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h @@ -7,8 +7,8 @@ #include #include -#if defined(CUDART_VERSION) && CUDART_VERSION >= 10000 -// some cusolver functions doesn't work well on cuda 9.2, cusolver is used on cuda >= 10.0 +#if defined(CUDART_VERSION) && defined(CUSOLVER_VERSION) && CUSOLVER_VERSION >= 10200 +// some cusolver functions don't work well on cuda 9.2 or cuda 10.1.105, cusolver is used on cuda >= 10.1.243 #define USE_CUSOLVER #endif @@ -18,6 +18,7 @@ namespace at { namespace native { Tensor _inverse_helper_cuda_lib(const Tensor& self); +Tensor& _linalg_inv_out_helper_cuda_lib(Tensor& result, Tensor& infos_getrf, Tensor& infos_getrs); }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu b/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu index f05d73453dcf3..bbc85f7997e4a 100644 --- a/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu +++ b/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu @@ -8,12 +8,20 @@ namespace at { namespace native { -void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) { +template +struct AddFunctor { + AddFunctor(scalar_t a): alpha(a) {} + __device__ __forceinline__ scalar_t operator() (const scalar_t a, const scalar_t b) const { + return a + alpha * b; + } + private: + scalar_t alpha; +}; + +void add_kernel_cuda(TensorIteratorBase& iter, Scalar alpha_scalar) { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() { - auto alpha = alpha_scalar.to(); - gpu_kernel_with_scalars(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return a + alpha * b; - }); + AddFunctor f(alpha_scalar.to()); + gpu_kernel_with_scalars(iter, f); }); } diff --git a/aten/src/ATen/native/cuda/BinaryBitwiseOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryBitwiseOpsKernels.cu index 128c05bed3cb8..30894b5687621 100644 --- a/aten/src/ATen/native/cuda/BinaryBitwiseOpsKernels.cu +++ b/aten/src/ATen/native/cuda/BinaryBitwiseOpsKernels.cu @@ -9,60 +9,67 @@ namespace at { namespace native { -void bitwise_and_kernel_cuda(TensorIterator& iter) { - if (iter.dtype() == ScalarType::Bool) { - gpu_kernel_with_scalars( - iter, - []GPU_LAMBDA(bool a, bool b) { - return a && b; - }); - } else { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_and_cuda", [&]() { - gpu_kernel_with_scalars( - iter, - []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return a & b; - }); - }); +template +struct BitwiseAndFunctor { + __device__ __forceinline__ scalar_t operator()(scalar_t a, scalar_t b) const { + return a & b; + } +}; + +template<> +struct BitwiseAndFunctor { + __device__ __forceinline__ bool operator()(bool a, bool b) const { + return a && b; } +}; + +void bitwise_and_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_and_cuda", [&]() { + BitwiseAndFunctor f; + gpu_kernel_with_scalars(iter, f); + }); } -void bitwise_or_kernel_cuda(TensorIterator& iter) { - if (iter.dtype() == ScalarType::Bool) { - gpu_kernel_with_scalars( - iter, - []GPU_LAMBDA(bool a, bool b) { - return a || b; - }); - } else { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_or_cuda", [&]() { - gpu_kernel_with_scalars( - iter, - []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return a | b; - }); - }); +template +struct BitwiseOrFunctor { + __device__ __forceinline__ scalar_t operator()(scalar_t a, scalar_t b) const { + return a | b; + } +}; + +template<> +struct BitwiseOrFunctor { + __device__ __forceinline__ bool operator()(bool a, bool b) const { + return a || b; } +}; + +void bitwise_or_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_or_cuda", [&]() { + BitwiseOrFunctor f; + gpu_kernel_with_scalars(iter, f); + }); } -void bitwise_xor_kernel_cuda(TensorIterator& iter) { - if (iter.dtype() == ScalarType::Bool) { - // Boolean type does not work with ^ (bitwise XOR) in C++. bitwise_xor wraps this operation for both Boolean and - // integral types. - gpu_kernel_with_scalars( - iter, - []GPU_LAMBDA(bool a, bool b) { - return a != b; - }); - } else { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_xor_cuda", [&]() { - gpu_kernel_with_scalars( - iter, - []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return a ^ b; - }); - }); +template +struct BitwiseXorFunctor { + __device__ __forceinline__ scalar_t operator()(scalar_t a, scalar_t b) const { + return a ^ b; } +}; + +template<> +struct BitwiseXorFunctor { + __device__ __forceinline__ bool operator()(bool a, bool b) const { + return a != b; + } +}; + +void bitwise_xor_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_xor_cuda", [&]() { + BitwiseXorFunctor f; + gpu_kernel_with_scalars(iter, f); + }); } REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel_cuda); diff --git a/aten/src/ATen/native/cuda/BinaryGeometricKernels.cu b/aten/src/ATen/native/cuda/BinaryGeometricKernels.cu new file mode 100644 index 0000000000000..49238691657c1 --- /dev/null +++ b/aten/src/ATen/native/cuda/BinaryGeometricKernels.cu @@ -0,0 +1,31 @@ +#include +#include +#include +#include +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. + +namespace at { namespace native { + +void atan2_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "atan2_cuda", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return ::atan2(a, b); + }); + }); +} + +void hypot_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "hypot_cuda", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return ::hypot(a, b); + }); + }); +} + +REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda); +REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda); + +}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu index 9b7bc28a829e2..a385aa7215220 100644 --- a/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu +++ b/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu @@ -16,10 +16,8 @@ namespace native { void sigmoid_backward_kernel_cuda(TensorIterator& iter) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "sigmoid_backward_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "sigmoid_backward_cuda", [&] { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return a * (scalar_t(1.) - b) * b; - }); + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return a * (scalar_t(1.) - b) * b; }); }); } @@ -31,42 +29,46 @@ void logit_backward_kernel_cuda(TensorIterator& iter, Scalar eps_scalar) { iter.dtype(), "logit_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "logit_cuda", [&] { - using T_ACC = acc_type; - const T_ACC eps = eps_scalar.to(); - if (eps < T_ACC(0)) { - gpu_kernel( - iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { - const T_ACC dy_acc = static_cast(dy); - const T_ACC x_acc = static_cast(x); - return (x_acc < T_ACC(0) || x_acc > T_ACC(1)) - ? std::numeric_limits::quiet_NaN() - : dy_acc / (x_acc * (T_ACC(1) - x_acc)); - }); - } else { - const T_ACC lo = eps; - const T_ACC hi = T_ACC(1) - eps; - gpu_kernel( - iter, [lo, hi] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { - const T_ACC dy_acc = static_cast(dy); - const T_ACC x_acc = static_cast(x); - return (x_acc < lo || x_acc > hi) - ? T_ACC(0) - : dy_acc / (x_acc * (T_ACC(1) - x_acc)); - }); - } - }); + using T_ACC = acc_type; + const T_ACC eps = eps_scalar.to(); + if (eps < T_ACC(0)) { + gpu_kernel( + iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + const T_ACC dy_acc = static_cast(dy); + const T_ACC x_acc = static_cast(x); + return (x_acc < T_ACC(0) || x_acc > T_ACC(1)) + ? std::numeric_limits::quiet_NaN() + : dy_acc / (x_acc * (T_ACC(1) - x_acc)); + }); + } else { + const T_ACC lo = eps; + const T_ACC hi = T_ACC(1) - eps; + gpu_kernel( + iter, [lo, hi] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + const T_ACC dy_acc = static_cast(dy); + const T_ACC x_acc = static_cast(x); + return (x_acc < lo || x_acc > hi) + ? T_ACC(0) + : dy_acc / (x_acc * (T_ACC(1) - x_acc)); + }); + } }); } void tanh_backward_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "tanh_backward_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "tanh_backward_cuda", [&] { + if(isComplexType(iter.dtype())) { + AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "tanh_backward_complex_cuda", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return a * std::conj(scalar_t{1.} - b * b); + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "tanh_backward_cuda", [&]() { gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { return a * (scalar_t{1.} - b * b); }); }); - }); + } } REGISTER_DISPATCH(sigmoid_backward_stub, &sigmoid_backward_kernel_cuda); diff --git a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu index a2ffdb75c84b4..bc1884d8d642a 100644 --- a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu +++ b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu @@ -1,29 +1,21 @@ #include #include #include -#include #include #include - +#include // NOTE: CUDA on Windows requires that the enclosing function // of a __device__ lambda not have internal linkage. namespace at { namespace native { -void atan2_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "atan2_cuda", [&]() { - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return ::atan2(a, b); - }); - }); -} - -void smooth_l1_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "smooth_l1_cuda", [&]() { - gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { +void smooth_l1_kernel_cuda(TensorIterator& iter, double beta) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "smooth_l1_cuda", [&iter, beta]() { + scalar_t beta_val(beta); + gpu_kernel(iter, [beta_val] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { auto z = ::abs(a - b); - return z < scalar_t(1.) ? scalar_t(0.5) * z * z : z - scalar_t(0.5); + return z < beta_val ? scalar_t(0.5) * z * z / beta_val : z - scalar_t(0.5) * beta_val; }); }); } @@ -38,84 +30,25 @@ void mse_kernel_cuda(TensorIterator& iter) { }); } -void logaddexp_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "logaddexp_cuda", [&]() { - gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { - if (::isinf(a) && a == b) { - return a; +void xlogy_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "xlogy_cuda", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t x, scalar_t y) -> scalar_t { + if (at::_isnan(y)){ + return NAN; } - else { - scalar_t m = ::max(a, b); - return m + ::log((scalar_t)(1.0) + ::exp(-::abs(a - b))); + if (x == 0){ + return 0; } + return x * std::log(y); }); }); } -void logaddexp2_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "logaddexp2_cuda", [&]() { - gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { - if (::isinf(a) && a == b) { - return a; - } - else { - scalar_t m = ::max(a, b); - return m + ::log2((scalar_t)(1.0) + ::pow((scalar_t)(2.0), -::abs(a - b))); - } - }); - }); -} - -void gcd_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "gcd_cuda", [&]() { - gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { - return calc_gcd(a, b); - }); - }); -} - -void lcm_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lcm_cuda", [&]() { - gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { - scalar_t g = calc_gcd(a, b); - return (g == 0) ? 0 : ::abs(a / g * b); - }); - }); -} - -void hypot_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "hypot_cuda", [&]() { - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return ::hypot(a, b); - }); - }); -} - -void nextafter_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "nextafter_cuda", [&]() { - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return ::nextafter(a, b); - }); - }); -} - -void heaviside_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cuda", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return a == 0 ? b : static_cast(a > 0); - }); - }); -} - -REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda); REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda); REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda); -REGISTER_DISPATCH(logaddexp_stub, &logaddexp_kernel_cuda); -REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_kernel_cuda); -REGISTER_DISPATCH(gcd_stub, &gcd_kernel_cuda); -REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda); -REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda); -REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda); -REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda); +REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel_cuda); + +// DO NOT ADD ANY NEW KERNELS HERE +// CUDA compilation times grow quickly. It's perfectly acceptable to have a file per kernel. }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu b/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu index 044fc955b954b..f80d0906dfa26 100644 --- a/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu +++ b/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu @@ -11,6 +11,39 @@ namespace at { namespace native { +template +struct MulScalarFunctor { + MulScalarFunctor(accscalar_t b_): b(b_) {} + __device__ scalar_t operator() (scalar_t a) const { + return a * b; + } + private: + accscalar_t b; +}; + +template +struct DivFunctor { + __device__ scalar_t operator() (scalar_t a, scalar_t b) const { + return a / b; + } +}; + +template +struct MulFunctor { + __device__ scalar_t operator() (scalar_t a, scalar_t b) const { + return a * b; + } +}; + +// Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context] +template<> +struct MulFunctor { + __device__ bool operator() (bool a, bool b) const { + return a && b; + } +}; + + void div_kernel_cuda(TensorIterator& iter) { if (!isIntegralType(iter.common_dtype(), /*includeBool*/ false) && iter.is_cpu_scalar(2)) { // optimization for floating-point types: if the second operand is a CPU @@ -20,44 +53,35 @@ void div_kernel_cuda(TensorIterator& iter) { using accscalar_t = at::acc_type; auto inv_b = accscalar_t(1.0) / iter.scalar_value(2); iter.remove_operand(2); - gpu_kernel(iter, [inv_b]GPU_LAMBDA(scalar_t a) -> scalar_t { - return a * inv_b; - }); + MulScalarFunctor f(inv_b); + gpu_kernel(iter, f); }); } else { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, iter.common_dtype(), "div_cuda", [&]() { - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return a / b; - }); + DivFunctor f; + gpu_kernel_with_scalars(iter, f); }); } } void mul_kernel_cuda(TensorIterator& iter) { - if (iter.common_dtype() == ScalarType::Bool) { - // Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context] - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(bool a, bool b) -> bool { - return a && b; - }); - } else if (!isIntegralType(iter.common_dtype(), /*includeBool*/ false) && + if (!isIntegralType(iter.common_dtype(), /*includeBool*/ true) && (iter.is_cpu_scalar(1) || iter.is_cpu_scalar(2))) { - //if common dtype is half the scalar constant can overflow in half precision, and yet the result can - //still be representable in the half dtype. Cast scalar to acc_type to have better accuracy - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "mul_cuda", [&]() { - using accscalar_t = at::acc_type; - int scalar_arg = iter.is_cpu_scalar(1) ? 1 : 2; - auto b = iter.scalar_value(scalar_arg); - iter.remove_operand(scalar_arg); - const cuda::OptionalCUDAGuard device_guard(device_of(iter.tensor(1))); - gpu_kernel(iter, [b]GPU_LAMBDA(scalar_t a) -> scalar_t { - return a * b; - }); - }); + //if common dtype is half the scalar constant can overflow in half precision, and yet the result can + //still be representable in the half dtype. Cast scalar to acc_type to have better accuracy + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "mul_cuda", [&]() { + using accscalar_t = at::acc_type; + int scalar_arg = iter.is_cpu_scalar(1) ? 1 : 2; + auto b = iter.scalar_value(scalar_arg); + iter.remove_operand(scalar_arg); + const cuda::OptionalCUDAGuard device_guard(device_of(iter.tensor(1))); + MulScalarFunctor f(b); + gpu_kernel(iter, f); + }); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, iter.common_dtype(), "mul_cuda", [&]() { - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return a * b; - }); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "mul_cuda", [&]() { + MulFunctor f; + gpu_kernel_with_scalars(iter, f); }); } } diff --git a/aten/src/ATen/native/cuda/BinaryRemainderKernel.cu b/aten/src/ATen/native/cuda/BinaryRemainderKernel.cu index 86b2703797dcd..9b5cc2ce6ad1d 100644 --- a/aten/src/ATen/native/cuda/BinaryRemainderKernel.cu +++ b/aten/src/ATen/native/cuda/BinaryRemainderKernel.cu @@ -1,8 +1,10 @@ #include #include #include -#include #include +#include + +#include // NOTE: CUDA on Windows requires that the enclosing function // of a __device__ lambda not have internal linkage. @@ -10,28 +12,48 @@ namespace at { namespace native { void remainder_kernel_cuda(TensorIterator& iter) { - if (isIntegralType(iter.dtype(), /*includeBool*/ false)) { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "remainder_cuda", [&]() { + if (isIntegralType(iter.common_dtype(), /*includeBool*/ false)) { + AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "remainder_cuda", [&]() { gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { scalar_t r = a % b; - if ((r != 0) && ((r < 0) != (b < 0))) { + if (!std::is_unsigned::value && (r != 0) && ((r < 0) != (b < 0))) { r += b; } return r; }); }); } else { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "remainder_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "remainder_cuda", [&]() { gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { auto mod = ::fmod(a, b); - if ((mod != 0) && ((b < 0) != (mod < 0))) mod += b; + if (!std::is_unsigned::value && (mod != 0) && ((b < 0) != (mod < 0))) { + mod += b; + } return mod; }); }); } } +void fmod_kernel_cuda(TensorIterator& iter) { + if (isIntegralType(iter.common_dtype(), /*includeBool*/ false)) { + AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "fmod_cuda", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return a % b; + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.common_dtype(), "fmod_cuda", [&]() { + gpu_kernel_with_scalars(iter, + []GPU_LAMBDA(scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { + return ::fmod(a, b); + }); + }); + } +} + REGISTER_DISPATCH(remainder_stub, &remainder_kernel_cuda); +REGISTER_DISPATCH(fmod_stub, &fmod_kernel_cuda); }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu index 4b6533e62db6f..67ff7954294df 100644 --- a/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu +++ b/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu @@ -13,8 +13,9 @@ namespace at { namespace native { void lshift_kernel_cuda(TensorIterator& iter) { if (iter.dtype() == ScalarType::Float || iter.dtype() == ScalarType::Double || - iter.dtype() == ScalarType::Half) { - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half, iter.dtype(), "lshift_cuda", [&]() { + iter.dtype() == ScalarType::Half || + iter.dtype() == ScalarType::BFloat16) { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "lshift_cuda", [&]() { gpu_kernel_with_scalars( iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { @@ -34,8 +35,9 @@ void lshift_kernel_cuda(TensorIterator& iter) { void rshift_kernel_cuda(TensorIterator& iter) { if (iter.dtype() == ScalarType::Float || iter.dtype() == ScalarType::Double || - iter.dtype() == ScalarType::Half) { - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half, iter.dtype(), "rshift_cuda", [&]() { + iter.dtype() == ScalarType::Half || + iter.dtype() == ScalarType::BFloat16) { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "rshift_cuda", [&]() { gpu_kernel_with_scalars( iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { diff --git a/aten/src/ATen/native/cuda/Bucketization.cu b/aten/src/ATen/native/cuda/Bucketization.cu index 84f62726cef0f..e28e7414aac6c 100644 --- a/aten/src/ATen/native/cuda/Bucketization.cu +++ b/aten/src/ATen/native/cuda/Bucketization.cu @@ -86,7 +86,7 @@ void searchsorted_cuda_contiguous(Tensor& result, const Tensor& input, const Ten searchsorted_cuda_kernel<<>>( data_out, data_in, data_bd, idim_in, idim_bd, numel_in, right, boundaries.dim() == 1); - THCudaCheck(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } void dispatch(Tensor& result, const Tensor& input, const Tensor& boundaries, bool out_int32, bool right) { diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh index 093ace17297c1..d11a5bb074c5d 100644 --- a/aten/src/ATen/native/cuda/CUDALoops.cuh +++ b/aten/src/ATen/native/cuda/CUDALoops.cuh @@ -101,9 +101,11 @@ static inline void launch_vectorized_kernel(int64_t N, const func_t& f, array_t switch (vec_size) { case 4: vectorized_elementwise_kernel<4, func_t, array_t><<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); break; case 2: vectorized_elementwise_kernel<2, func_t, array_t><<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); break; case 1: { auto input_calc = TrivialOffsetCalculator(); @@ -111,12 +113,12 @@ static inline void launch_vectorized_kernel(int64_t N, const func_t& f, array_t auto loader = memory::LoadWithoutCast(); auto storer = memory::StoreWithoutCast(); unrolled_elementwise_kernel<<>>(N, f, data, input_calc, output_calc, loader, storer); + C10_CUDA_KERNEL_LAUNCH_CHECK(); break; } default: TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size"); } - AT_CUDA_CHECK(cudaGetLastError()); } template @@ -127,11 +129,11 @@ static inline void launch_unrolled_kernel(int64_t N, const func_t& f, array_t da int64_t grid = (N + block_work_size - 1) / block_work_size; auto stream = at::cuda::getCurrentCUDAStream(); unrolled_elementwise_kernel<<>>(N, f, data, ic, oc, l, s); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template -void gpu_kernel_impl(TensorIterator& iter, const func_t& f) { +void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) { using traits = function_traits; using arg0_t = typename traits::result_type; constexpr int ntensors = traits::arity + 1; diff --git a/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp b/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp deleted file mode 100644 index 2c390a1a6c68c..0000000000000 --- a/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include -#include -#include - -namespace at { namespace native { - -// These are just forwarding stubs - -#define IMPLEMENT_UNARY_OP_PREQUEL(op, _th_op) \ - Tensor& _##op##__cuda(Tensor& self) { \ - return _th_op##_out(self, self); \ - } \ - Tensor& _##op##_out_cuda(Tensor& result, const Tensor& self) { \ - return _th_op##_out(result, self); \ - } - -}} diff --git a/aten/src/ATen/native/cuda/CompareEQKernel.cu b/aten/src/ATen/native/cuda/CompareEQKernel.cu index 947b53bce8fd4..20f76ce0d8e15 100644 --- a/aten/src/ATen/native/cuda/CompareEQKernel.cu +++ b/aten/src/ATen/native/cuda/CompareEQKernel.cu @@ -10,11 +10,16 @@ namespace at { namespace native { +template +struct CompareEqFunctor { + __device__ __forceinline__ bool operator() (scalar_t a, scalar_t b) const { + return a == b; + } +}; + void eq_kernel_cuda(TensorIterator& iter) { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "eq_cuda", [&]() { - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { - return a == b; - }); + gpu_kernel_with_scalars(iter, CompareEqFunctor()); }); } diff --git a/aten/src/ATen/native/cuda/CompareGEKernel.cu b/aten/src/ATen/native/cuda/CompareGEKernel.cu index e276237ea8e6a..c96b7f3929bc8 100644 --- a/aten/src/ATen/native/cuda/CompareGEKernel.cu +++ b/aten/src/ATen/native/cuda/CompareGEKernel.cu @@ -10,11 +10,16 @@ namespace at { namespace native { +template +struct CompareGEFunctor { + __device__ __forceinline__ bool operator() (scalar_t a, scalar_t b) const { + return a >= b; + } +}; + void ge_kernel_cuda(TensorIterator& iter) { AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "ge_cuda", [&]() { - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { - return a >= b; - }); + gpu_kernel_with_scalars(iter, CompareGEFunctor()); }); } diff --git a/aten/src/ATen/native/cuda/CompareGTKernel.cu b/aten/src/ATen/native/cuda/CompareGTKernel.cu index c17b14855dd6c..cbd189ed1b6db 100644 --- a/aten/src/ATen/native/cuda/CompareGTKernel.cu +++ b/aten/src/ATen/native/cuda/CompareGTKernel.cu @@ -10,11 +10,16 @@ namespace at { namespace native { +template +struct CompareGTFunctor { + __device__ __forceinline__ bool operator() (scalar_t a, scalar_t b) const { + return a > b; + } +}; + void gt_kernel_cuda(TensorIterator& iter) { AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "gt_cuda", [&]() { - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { - return a > b; - }); + gpu_kernel_with_scalars(iter, CompareGTFunctor()); }); } diff --git a/aten/src/ATen/native/cuda/CompareLEKernel.cu b/aten/src/ATen/native/cuda/CompareLEKernel.cu index 3987b87e918cc..13e60a78ffb2f 100644 --- a/aten/src/ATen/native/cuda/CompareLEKernel.cu +++ b/aten/src/ATen/native/cuda/CompareLEKernel.cu @@ -10,11 +10,16 @@ namespace at { namespace native { +template +struct CompareLEFunctor { + __device__ __forceinline__ bool operator() (scalar_t a, scalar_t b) const { + return a <= b; + } +}; + void le_kernel_cuda(TensorIterator& iter) { AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "le_cuda", [&]() { - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { - return a <= b; - }); + gpu_kernel_with_scalars(iter, CompareLEFunctor()); }); } diff --git a/aten/src/ATen/native/cuda/CompareLTKernel.cu b/aten/src/ATen/native/cuda/CompareLTKernel.cu index 3684d65f66319..e301284c83e74 100644 --- a/aten/src/ATen/native/cuda/CompareLTKernel.cu +++ b/aten/src/ATen/native/cuda/CompareLTKernel.cu @@ -10,11 +10,16 @@ namespace at { namespace native { +template +struct CompareLTFunctor { + __device__ __forceinline__ bool operator() (scalar_t a, scalar_t b) const { + return a < b; + } +}; + void lt_kernel_cuda(TensorIterator& iter) { AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "lt_cuda", [&]() { - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { - return a < b; - }); + gpu_kernel_with_scalars(iter, CompareLTFunctor()); }); } diff --git a/aten/src/ATen/native/cuda/CompareNEKernel.cu b/aten/src/ATen/native/cuda/CompareNEKernel.cu index 0834a0d2b3bb7..3ef397ec52002 100644 --- a/aten/src/ATen/native/cuda/CompareNEKernel.cu +++ b/aten/src/ATen/native/cuda/CompareNEKernel.cu @@ -10,11 +10,16 @@ namespace at { namespace native { +template +struct CompareNEFunctor { + __device__ __forceinline__ bool operator() (scalar_t a, scalar_t b) const { + return a != b; + } +}; + void ne_kernel_cuda(TensorIterator& iter) { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "ne_cuda", [&]() { - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { - return a != b; - }); + gpu_kernel_with_scalars(iter, CompareNEFunctor()); }); } diff --git a/aten/src/ATen/native/cuda/CopysignKernel.cu b/aten/src/ATen/native/cuda/CopysignKernel.cu new file mode 100644 index 0000000000000..442d649b2e99c --- /dev/null +++ b/aten/src/ATen/native/cuda/CopysignKernel.cu @@ -0,0 +1,32 @@ +#include +#include +#include +#include +#include + +#if defined(__CUDACC__) +#include +#include +#include +#elif defined(__HIPCC__) +#include +#include +#include +#endif + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. + +namespace at { namespace native { + +void copysign_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_cuda", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return c10::cuda::compat::copysign(a, b); + }); + }); +} + +REGISTER_DISPATCH(copysign_stub, ©sign_kernel_cuda); + +}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/CuFFTPlanCache.h b/aten/src/ATen/native/cuda/CuFFTPlanCache.h index 60963d1db7e18..642326dd6a019 100644 --- a/aten/src/ATen/native/cuda/CuFFTPlanCache.h +++ b/aten/src/ATen/native/cuda/CuFFTPlanCache.h @@ -15,51 +15,102 @@ namespace at { namespace native { namespace detail { -// This POD struct is used to let us easily compute hashes of the +// Enum representing the FFT type +enum class CuFFTTransformType : int8_t { + C2C, // Complex-to-complex + R2C, // Real-to-complex + C2R, // Complex-to-real +}; + +// This struct is used to let us easily compute hashes of the // parameters. // It will be the **key** to the plan cache. struct CuFFTParams { - at::ScalarType scalar_type_; - int64_t input_sizes_[max_rank + 2]; - int64_t input_strides_[max_rank + 2]; - uint8_t signal_ndim_; // between 1 and max_rank, i.e., 1 <= signal_ndim <= 3 - bool complex_input_; - bool complex_output_; - int64_t signal_sizes_[max_rank]; - bool onesided_; + int64_t signal_ndim_; // between 1 and max_rank, i.e., 1 <= signal_ndim <= 3 + // These include additional batch dimension as well. + int64_t sizes_[max_rank + 1]; + int64_t input_strides_[max_rank + 1]; + int64_t output_strides_[max_rank + 1]; + CuFFTTransformType fft_type_; + ScalarType value_type_; + + CuFFTParams() = default; + + CuFFTParams(IntArrayRef in_strides, IntArrayRef out_strides, + IntArrayRef signal_sizes, CuFFTTransformType fft_type, ScalarType value_type) { + // Padding bits must be zeroed for hashing + memset(this, 0, sizeof(*this)); + signal_ndim_ = signal_sizes.size() - 1; + fft_type_ = fft_type; + value_type_ = value_type; + + TORCH_INTERNAL_ASSERT(in_strides.size() == signal_sizes.size()); + TORCH_INTERNAL_ASSERT(out_strides.size() == signal_sizes.size()); + TORCH_INTERNAL_ASSERT(1 <= signal_ndim_ && signal_ndim_ <= max_rank); + + std::copy(signal_sizes.cbegin(), signal_sizes.cend(), sizes_); + std::copy(in_strides.cbegin(), in_strides.cend(), input_strides_); + std::copy(out_strides.cbegin(), out_strides.cend(), output_strides_); + } }; -// NB: This can't be a constructor, because then CuFFTParams -// would not be a POD anymore. -static inline void setCuFFTParams(CuFFTParams* params, - const Tensor& input, int64_t signal_ndim, bool complex_input, - bool complex_output, IntArrayRef checked_signal_sizes, bool onesided) { - - memset(params, 0, sizeof(CuFFTParams)); - params->scalar_type_ = input.scalar_type(); - for (int i = 0; i != input.dim(); ++i) { - params->input_sizes_[i] = input.size(i); - if (input.size(i) != 1) { - params->input_strides_[i] = input.stride(i); - } +static_assert(std::is_trivial::value, ""); + +// Returns true if the transform type has complex input +inline bool cufft_complex_input(CuFFTTransformType type) { + switch (type) { + case CuFFTTransformType::C2C: + case CuFFTTransformType::C2R: + return true; + + case CuFFTTransformType::R2C: + return false; } - params->signal_ndim_ = (uint8_t) signal_ndim; - params->complex_input_ = complex_input; - params->complex_output_ = complex_output; - for (size_t i = 0; i != checked_signal_sizes.size(); ++i) { - params->signal_sizes_[i] = checked_signal_sizes[i]; + TORCH_INTERNAL_ASSERT(false); +} + +// Returns true if the transform type has complex output +inline bool cufft_complex_output(CuFFTTransformType type) { + switch (type) { + case CuFFTTransformType::C2C: + case CuFFTTransformType::R2C: + return true; + + case CuFFTTransformType::C2R: + return false; } - params->onesided_ = onesided; + TORCH_INTERNAL_ASSERT(false); } -struct CuFFTHandleDeleter { - void operator()(cufftHandle* x) { +// Create transform type enum from bools representing if input and output are complex +inline CuFFTTransformType GetCuFFTTransformType(bool complex_input, bool complex_output) { + if (complex_input && complex_output) { + return CuFFTTransformType::C2C; + } else if (complex_input && !complex_output) { + return CuFFTTransformType::C2R; + } else if (!complex_input && complex_output) { + return CuFFTTransformType::R2C; + } + TORCH_INTERNAL_ASSERT(false, "Real to real FFTs are not supported"); +} + + +class CuFFTHandle { + ::cufftHandle handle_; +public: + + CuFFTHandle() { + CUFFT_CHECK(cufftCreate(&handle_)); + } + + ::cufftHandle & get() { return handle_; } + const ::cufftHandle & get() const { return handle_; } + + ~CuFFTHandle() { // Not using fftDestroy() for rocFFT to work around double freeing of handles #ifndef __HIP_PLATFORM_HCC__ - if (x != nullptr) { - CUFFT_CHECK(cufftDestroy(*x)); - } + cufftDestroy(handle_); #endif } }; @@ -69,6 +120,101 @@ static bool is_pow_of_two(int64_t x) { return (x & (x - 1)) == 0; } +#ifdef __HIP_PLATFORM_HCC__ + using cufft_size_type = int; +#else + using cufft_size_type = long long int; +#endif + +using CuFFTDimVector = c10::SmallVector; + +// Struct representing a tensor in CuFFT's data layout for planning transforms +// See NOTE [ cuFFT Embedded Strides ]. +struct CuFFTDataLayout { + CuFFTDimVector embed; + cufft_size_type stride, dist; + bool must_clone, simple; +}; + +// Returns a cufft embedding for a contiguous signal of the given size. +// e.g. if the input is cloned, this will be the resulting data layout +// See NOTE [ cuFFT Embedded Strides ]. +inline CuFFTDataLayout cufft_simple_embed(IntArrayRef sizes, bool onesided) { + const auto signal_ndim = sizes.size() - 1; + CuFFTDataLayout layout; + layout.simple = true; + layout.must_clone = false; + layout.embed.assign(sizes.cbegin() + 1, sizes.cend()); + if (onesided) { + layout.embed.back() = sizes.back() / 2 + 1; + } + layout.stride = 1; + layout.dist = 1; + for (const auto& len : layout.embed) { + layout.dist *= len; + } + return layout; +} + +// Convert strides to a CuFFT embedded representation. +// If strides cannot be embedded, returns a simple layout and sets must_clone flag +// See NOTE [ cuFFT Embedded Strides ]. +inline CuFFTDataLayout as_cufft_embed(IntArrayRef strides, IntArrayRef sizes, bool onesided) { + const auto signal_ndim = strides.size() - 1; + CuFFTDataLayout layout; + auto last_stride = strides[signal_ndim]; + layout.must_clone = (last_stride <= 0); + + const auto last_dim_size = onesided ? + sizes[signal_ndim] / 2 + 1 : sizes[signal_ndim]; + const auto signal_numel = at::prod_intlist(sizes.slice(1, sizes.size() - 2)) * last_dim_size; + + // Zero stides are not allowed, even if the batch size is one. + // If that happens just set a dummy case + if (sizes[0] == 1) { + layout.dist = signal_numel; + } else if (strides[0] == 0) { + layout.must_clone = true; + } else { + layout.dist = strides[0]; + } + + // Calculate the embedding shape, or set must_clone if the strides cannot be embedded + layout.embed.resize(signal_ndim); + for (auto i = signal_ndim - 1; !layout.must_clone && i > 0; i--) { + auto stride = strides[i]; + if (sizes[i] == 1) { + layout.embed[i] = 1; + } else if (stride > 0 && stride % last_stride == 0) { + layout.embed[i] = stride / last_stride; + last_stride = stride; + } else { + layout.must_clone = true; + } + } + + if (layout.must_clone) { + // If the input needs to be cloned, assume it will be contiguous + layout = cufft_simple_embed(sizes, onesided); + layout.must_clone = true; + } else { + layout.embed[0] = sizes[1]; + layout.stride = strides[signal_ndim]; + // Determine if layout represents a simple embedding (contiguous data) + layout.simple = [&] { + for (int64_t i = 1; i < signal_ndim - 1; ++i) { + if (layout.embed[i] != sizes[i + 1]) { + return false; + } + } + + return (layout.stride == 1 && layout.dist == signal_numel && + layout.embed.back() == last_dim_size); + }(); + } + return layout; +} + // This class contains all the information needed to execute a cuFFT plan: // 1. the plan // 2. whether to clone input before executing the plan @@ -85,21 +231,26 @@ class CuFFTConfig { CuFFTConfig(const CuFFTConfig&) = delete; CuFFTConfig& operator=(CuFFTConfig const&) = delete; - explicit CuFFTConfig(Tensor& input, int64_t signal_ndim, bool complex_input, - bool complex_output, IntArrayRef checked_signal_sizes, bool onesided, - IntArrayRef output_sizes) { + explicit CuFFTConfig(const CuFFTParams& params): + CuFFTConfig( + IntArrayRef(params.input_strides_, params.signal_ndim_ + 1), + IntArrayRef(params.output_strides_, params.signal_ndim_ + 1), + IntArrayRef(params.sizes_, params.signal_ndim_ + 1), + params.fft_type_, + params.value_type_) {} - // signal sizes -#ifdef __HIP_PLATFORM_HCC__ - std::vector signal_sizes(checked_signal_sizes.begin(), - checked_signal_sizes.end()); -#else - std::vector signal_sizes(checked_signal_sizes.begin(), - checked_signal_sizes.end()); -#endif + // For complex types, strides are in units of 2 * element_size(dtype) + // sizes are for the full signal, including batch size and always two-sided + CuFFTConfig(IntArrayRef in_strides, IntArrayRef out_strides, + IntArrayRef sizes, CuFFTTransformType fft_type, ScalarType dtype): + fft_type_(fft_type), value_type_(dtype) { + + // signal sizes (excluding batch dim) + CuFFTDimVector signal_sizes(sizes.begin() + 1, sizes.end()); // input batch size - long long int batch = input.size(0); + const int64_t batch = sizes[0]; + const int64_t signal_ndim = sizes.size() - 1; // Since cuFFT has limited non-unit stride support and various constraints, we // use a flag to keep track throughout this function to see if we need to @@ -109,7 +260,7 @@ class CuFFTConfig { // For half, base strides on the real part of real-to-complex and // complex-to-real transforms are not supported. Since our output is always // contiguous, only need to check real-to-complex case. - if (input.scalar_type() == ScalarType::Half) { + if (dtype == ScalarType::Half) { // cuFFT on half requires compute capability of at least SM_53 auto dev_prop = at::cuda::getCurrentDeviceProperties(); TORCH_CHECK(dev_prop->major >= 5 && !(dev_prop->major == 5 && dev_prop->minor < 3), @@ -117,145 +268,68 @@ class CuFFTConfig { "capability less than SM_53, but the device containing input half " "tensor only has SM_", dev_prop->major, dev_prop->minor); for (int64_t i = 0; i < signal_ndim; i++) { - auto signal_size = checked_signal_sizes[i]; - TORCH_CHECK(is_pow_of_two(signal_size), - "cuFFT doesn't support signals of half type with size at any ", - "dimension that is not a power of two, but got a signal size of ", - checked_signal_sizes); + TORCH_CHECK(is_pow_of_two(sizes[i + 1]), + "cuFFT only supports dimensions whose sizes are powers of two when" + " computing in half precision, but got a signal size of", + sizes.slice(1)); } - clone_input |= input.stride(signal_ndim) != 1; + clone_input |= in_strides.back() != 1; } - // check the input sizes and strides to see if we need to make it contiguous - // cuFFT doesn't support batch dim with stride 0 - clone_input |= input.stride(0) == 0; - - if (complex_input) { - // Real/imag dimension must be like complex type. - clone_input |= input.stride(-1) != 1; - // Strides of other dimensions needs to be aligned when viewed as of complex - // type, i.e., multiples of 2. We check the batch dim and last signal dim - // here. If the input can be viewed as having embedded strides, the other - // signal dims will also satisfy this. - // See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu. - clone_input |= (batch > 0 && input.stride(0) % 2 != 0) || - input.stride(signal_ndim) % 2 != 0; - - // Complex to real FFTs may overwrite the input buffer (gh-34551) - clone_input |= !complex_output; - } - - // Checks if input strides can be viewed as embedded. - // See NOTE [ cuFFT Embedded Strides ]. - // - // TODO: Figure out why windows fails to compile - // c10::optional> inembed_opt = - // c10::nullopt; - // Then move the following to a helper function. -#ifdef __HIP_PLATFORM_HCC__ - std::vector inembed(signal_ndim); -#else - std::vector inembed(signal_ndim); -#endif - if (!clone_input) { - auto istrides = input.strides(); - auto last_istride = istrides[signal_ndim]; - clone_input = last_istride <= 0; - for (auto i = signal_ndim - 1; !clone_input && i > 0 /* inembed[0] doesn't matteer */; i--) { - auto istride = istrides[i]; - if (istride > 0 && istride % last_istride == 0) { - inembed[i] = istride / last_istride; - last_istride = istride; - } else { - clone_input = true; - } - } + CuFFTDataLayout in_layout; + if (clone_input) { + in_layout = cufft_simple_embed(sizes, fft_type == CuFFTTransformType::C2R); + } else { + in_layout = as_cufft_embed(in_strides, sizes, fft_type == CuFFTTransformType::C2R); } + auto out_layout = as_cufft_embed(out_strides, sizes, fft_type == CuFFTTransformType::R2C); + TORCH_INTERNAL_ASSERT(!out_layout.must_clone, "Out strides cannot be represented as CuFFT embedding"); + clone_input |= in_layout.must_clone; // Check if we can take advantage of simple data layout. // - // Note that this is before the actual cloning. This is intentional so we can - // check for advanced data layout with complex-to-real transform. cuFFT - // out-of-place complex-to-real transforms with advanced layout may overwrite - // input, and we need to clone the input. - // - // This just needs contiguity in cases except for twosided real-to-complex - // transform where we won't have simple data layout as output is two sided. - // // See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu. - bool simple_layout = !(!complex_input && complex_output && !onesided) && // not twosided R2C - (clone_input || input.is_contiguous()); // contiguous - if (!simple_layout && complex_input && !complex_output) { - clone_input = true; - simple_layout = true; - } - - // if input should be cloned but simple layout can't be used (e.g. twosided R2C) - if (clone_input && !simple_layout) { - auto input_size = input.sizes(); - std::copy(input_size.begin() + 1, // begin of signal dim in input - input_size.begin() + signal_ndim + 1, // end of signal dim in input - inembed.begin()); // begin of output - } + const bool simple_layout = in_layout.simple && out_layout.simple; #ifdef __HIP_PLATFORM_HCC__ - - hipfftType exec_type; - if (input.scalar_type() == ScalarType::Float) { - if (complex_input && complex_output) { - exec_type = HIPFFT_C2C; - } else if (complex_input && !complex_output) { - exec_type = HIPFFT_C2R; - } else if (!complex_input && complex_output) { - exec_type = HIPFFT_R2C; - } else { - AT_ERROR("hipFFT doesn't support r2r (float)"); - } - } else if (input.scalar_type() == ScalarType::Double) { - if (complex_input && complex_output) { - exec_type = HIPFFT_Z2Z; - } else if (complex_input && !complex_output) { - exec_type = HIPFFT_Z2D; - } else if (!complex_input && complex_output) { - exec_type = HIPFFT_D2Z; - } else { - AT_ERROR("hipFFT doesn't support r2r (double)"); + hipfftType exec_type = [&]{ + if (dtype == kFloat) { + switch (fft_type) { + case CuFFTTransformType::C2C: return HIPFFT_C2C; + case CuFFTTransformType::R2C: return HIPFFT_R2C; + case CuFFTTransformType::C2R: return HIPFFT_C2R; + } + } else if (dtype == kDouble) { + switch (fft_type) { + case CuFFTTransformType::C2C: return HIPFFT_Z2Z; + case CuFFTTransformType::R2C: return HIPFFT_D2Z; + case CuFFTTransformType::C2R: return HIPFFT_Z2D; + } } - } else { - std::ostringstream ss; - ss << "hipFFT doesn't support tensor of type: " - << toString(input.scalar_type()); - AT_ERROR(ss.str()); - } - + TORCH_CHECK(false, "hipFFT doesn't support transforms of type: ", dtype); + }(); #else cudaDataType itype, otype, exec_type; - if (input.scalar_type() == ScalarType::Float) { + const auto complex_input = cufft_complex_input(fft_type); + const auto complex_output = cufft_complex_output(fft_type); + if (dtype == ScalarType::Float) { itype = complex_input ? CUDA_C_32F : CUDA_R_32F; otype = complex_output ? CUDA_C_32F : CUDA_R_32F; exec_type = CUDA_C_32F; - } else if (input.scalar_type() == ScalarType::Double) { + } else if (dtype == ScalarType::Double) { itype = complex_input ? CUDA_C_64F : CUDA_R_64F; otype = complex_output ? CUDA_C_64F : CUDA_R_64F; exec_type = CUDA_C_64F; - } else if (input.scalar_type() == ScalarType::Half) { + } else if (dtype == ScalarType::Half) { itype = complex_input ? CUDA_C_16F : CUDA_R_16F; otype = complex_output ? CUDA_C_16F : CUDA_R_16F; exec_type = CUDA_C_16F; } else { - std::ostringstream ss; - ss << "cuFFT doesn't support tensor of type: " - << toString(input.scalar_type()); - AT_ERROR(ss.str()); + TORCH_CHECK(false, "cuFFT doesn't support tensor of type: ", dtype); } #endif - // create plan - auto raw_plan_ptr = new cufftHandle(); - CUFFT_CHECK(cufftCreate(raw_plan_ptr)); - plan_ptr.reset(raw_plan_ptr); - // disable auto allocation of workspace to use THC allocator CUFFT_CHECK(cufftSetAutoAllocation(plan(), /* autoAllocate */ 0)); @@ -264,8 +338,8 @@ class CuFFTConfig { // make plan if (simple_layout) { // If with unit-stride, we tell cuFFT by setting inembed == onembed == NULL. - // In such case, cuFFT ignores base_istride, base_ostride, idist, and odist - // by assuming base_istride = base_ostride = 1. + // In such case, cuFFT ignores istride, ostride, idist, and odist + // by assuming istride = ostride = 1. // // See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu. #ifdef __HIP_PLATFORM_HCC__ @@ -280,65 +354,34 @@ class CuFFTConfig { batch, &ws_size_t, exec_type)); #endif } else { - // set idist (stride at batch dim) - // set base_istride (stride at innermost dim of signal) - long long int idist, base_istride; - if (clone_input) { - idist = at::prod_intlist(input.sizes().slice(1, signal_ndim)); - base_istride = 1; - } else if (complex_input) { - idist = input.stride(0) >> 1; - base_istride = input.stride(signal_ndim) >> 1; - } else { - idist = input.stride(0); - base_istride = input.stride(signal_ndim); - } - // Even if batch dimension is one and idist (stride(0)) doesn't matter, - // cuFFT errors if idist = 0. This is hack to make it succeed. - if (idist == 0 && batch == 1) { - idist = 1; - } - - // set odist, onembed, base_ostride #ifdef __HIP_PLATFORM_HCC__ - int odist = at::prod_intlist(output_sizes.slice(1, signal_ndim)); - std::vector onembed(output_sizes.data() + 1, output_sizes.data() + signal_ndim + 1); - int base_ostride = 1; - - int istride = base_istride; - int iidist = idist; CUFFT_CHECK(hipfftMakePlanMany(plan(), signal_ndim, signal_sizes.data(), - inembed.data(), istride, iidist, - onembed.data(), base_ostride, odist, + in_layout.embed.data(), in_layout.stride, in_layout.dist, + out_layout.embed.data(), out_layout.stride, out_layout.dist, exec_type, batch, &ws_size_t)); #else - long long int odist = at::prod_intlist(output_sizes.slice(1, signal_ndim)); - std::vector onembed(output_sizes.data() + 1, output_sizes.data() + signal_ndim + 1); - long long int base_ostride = 1; - CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(), - inembed.data(), base_istride, idist, itype, - onembed.data(), base_ostride, odist, otype, + in_layout.embed.data(), in_layout.stride, in_layout.dist, itype, + out_layout.embed.data(), out_layout.stride, out_layout.dist, otype, batch, &ws_size_t, exec_type)); #endif - } + } ws_size = static_cast(ws_size_t); } -#ifdef __HIP_PLATFORM_HCC__ - cufftHandle &plan() const { return *plan_ptr.get(); } -#else - const cufftHandle &plan() const { return *plan_ptr.get(); } -#endif + const cufftHandle &plan() const { return plan_ptr.get(); } + CuFFTTransformType transform_type() const { return fft_type_; } + ScalarType data_type() const { return value_type_; } bool should_clone_input() const { return clone_input; } - int64_t workspace_size() const { return ws_size; } private: - std::unique_ptr plan_ptr; + CuFFTHandle plan_ptr; bool clone_input; int64_t ws_size; + CuFFTTransformType fft_type_; + ScalarType value_type_; }; #if CUDA_VERSION < 10000 @@ -392,15 +435,13 @@ class CuFFTParamsLRUCache { } // If key is in this cache, return the cached config. Otherwise, emplace the - // config in this cache using value_args and return it. + // config in this cache and return it. // Return const reference because CuFFTConfig shouldn't be tampered with once // created. - // This is similar to c++ 17 try_emplace. - template - const CuFFTConfig &try_emplace_value(K&& key, VArgs&&... value_args) { + const CuFFTConfig &lookup(CuFFTParams params) { AT_ASSERT(_max_size > 0); - map_kkv_iter_t map_it = _cache_map.find(key); + map_kkv_iter_t map_it = _cache_map.find(params); // Hit, put to list front if (map_it != _cache_map.end()) { _usage_list.splice(_usage_list.begin(), _usage_list, map_it->second); @@ -418,8 +459,8 @@ class CuFFTParamsLRUCache { // construct new plan at list front, then insert into _cache_map _usage_list.emplace_front(std::piecewise_construct, - std::forward_as_tuple(key), - std::forward_as_tuple(value_args...)); + std::forward_as_tuple(params), + std::forward_as_tuple(params)); auto kv_it = _usage_list.begin(); _cache_map.emplace(std::piecewise_construct, std::forward_as_tuple(kv_it->first), diff --git a/aten/src/ATen/native/cuda/DilatedMaxPool2d.cu b/aten/src/ATen/native/cuda/DilatedMaxPool2d.cu index 3e0e70c019526..083e34c5c8311 100644 --- a/aten/src/ATen/native/cuda/DilatedMaxPool2d.cu +++ b/aten/src/ATen/native/cuda/DilatedMaxPool2d.cu @@ -73,13 +73,13 @@ __global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom template C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS) -__global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nbatch, +__global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nbatch, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, - const int in_stride_n, const int in_stride_c, + const int in_stride_n, const int in_stride_c, const int in_stride_h, const int in_stride_w, const int kernel_stride_C, const int kernel_size_C, scalar_t* top_data, int64_t* top_mask) { @@ -100,9 +100,9 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba __syncthreads(); - int batch_id = blockIdx.x % nbatch; - int channel_id = blockIdx.x / nbatch; - int channel_offset = threadIdx.x + channel_id * blockDim.x; + int batch_id = blockIdx.x % nbatch; + int channel_id = blockIdx.x / nbatch; + int channel_offset = threadIdx.x + channel_id * blockDim.x; top_data = top_data + batch_id * pooled_height * pooled_width * channels; top_mask = top_mask + batch_id * pooled_height * pooled_width * channels; @@ -130,7 +130,7 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba wstart += dilation_w; for (int ih = hstart; ih < hend; ih++) { for (int iw = wstart; iw < wend; iw++) { - int cached_index = threadIdx.x; + int cached_index = threadIdx.x; const scalar_t *ptr_input = bottom_data + ih * in_stride_h + iw * in_stride_w; for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) { scalar_t val = ptr_input[c*in_stride_c]; @@ -138,20 +138,20 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba out_cached[cached_index] = scalar_cast(val); out_mask_cached[cached_index] = ih * width + iw; } - cached_index += blockDim.x; + cached_index += blockDim.x; } } } scalar_t *ptr_output_data = top_data + (oh * pooled_width + ow) * channels; int64_t *ptr_output_mask = top_mask + (oh * pooled_width + ow) * channels; - int cached_index = threadIdx.x; + int cached_index = threadIdx.x; for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) { ptr_output_data[c] = out_cached[cached_index]; ptr_output_mask[c] = out_mask_cached[cached_index]; out_cached[cached_index] = at::numeric_limits::lower_bound(); out_mask_cached[cached_index] = 0; - cached_index += blockDim.x; + cached_index += blockDim.x; } } } @@ -206,9 +206,9 @@ __global__ void max_pool_backward_nhwc(const int nthreads, const scalar_t* top_d const int stride_h, const int stride_w, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const int out_stride_c, const int out_stride_h, const int out_stride_w, - const int in_stride_n, const int in_stride_c, + const int in_stride_n, const int in_stride_c, const int in_stride_h, const int in_stride_w, - const int kernel_stride_C, const int kernel_size_C, + const int kernel_stride_C, const int kernel_size_C, scalar_t* bottom_diff) { extern __shared__ int smem[]; accscalar_t *out_cached = reinterpret_cast(smem); @@ -216,9 +216,9 @@ __global__ void max_pool_backward_nhwc(const int nthreads, const scalar_t* top_d int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); int block_size = blockDim.x * blockDim.y * blockDim.z; - int batch_id = blockIdx.x % nbatch; - int channel_id = blockIdx.x / nbatch; - int channel_offset = threadIdx.x + channel_id * blockDim.x; + int batch_id = blockIdx.x % nbatch; + int channel_id = blockIdx.x / nbatch; + int channel_offset = threadIdx.x + channel_id * blockDim.x; for (int i = thread_id; i < kernel_size_C*blockDim.x*blockDim.y*blockDim.z; i+= block_size) { out_cached[i] = accscalar_t(0.0); @@ -245,38 +245,38 @@ __global__ void max_pool_backward_nhwc(const int nthreads, const scalar_t* top_d for (int iw = istartW; iw < iendW; iw+=blockDim.y) { int pwstart = p_start(iw, pad_w, kernel_w, dilation_w, stride_w); int pwend = p_end(iw, pad_w, pooled_width, stride_w); - int index_shift = ih * width + iw; + int index_shift = ih * width + iw; if ((phstart + 1 != phend) || (pwstart + 1 != pwend)) { for(int oh = phstart; oh < phend; ++oh) { for(int ow = pwstart; ow < pwend; ++ow) { - int cached_index = threadIdx.x; + int cached_index = threadIdx.x; const int64_t* ptr_top_mask = top_mask + oh * out_stride_h + ow * out_stride_w; for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) { if (ptr_top_mask[c*out_stride_c] == index_shift) { - out_cached[cached_index] += + out_cached[cached_index] += scalar_cast(top_diff[oh * out_stride_h + ow * out_stride_w + c*out_stride_c]); } - cached_index += blockDim.x; + cached_index += blockDim.x; } } } scalar_t *ptr_bottom_diff = bottom_diff + index_shift * channels; - int cached_index = threadIdx.x; + int cached_index = threadIdx.x; for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) { ptr_bottom_diff[c] = scalar_cast(out_cached[cached_index]); out_cached[cached_index] = accscalar_t(0.0); - cached_index += blockDim.x; + cached_index += blockDim.x; } } else { const int64_t* ptr_top_mask = top_mask + phstart * out_stride_h + pwstart * out_stride_w; scalar_t *ptr_bottom_diff = bottom_diff + index_shift * channels; - int cached_index = threadIdx.x; + int cached_index = threadIdx.x; for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) { if (ptr_top_mask[c*out_stride_c] == index_shift) { - ptr_bottom_diff[c] = + ptr_bottom_diff[c] = scalar_cast(top_diff[phstart * out_stride_h + pwstart * out_stride_w + c*out_stride_c]); } - cached_index += blockDim.x; + cached_index += blockDim.x; } } } @@ -346,7 +346,7 @@ void max_pool2d_with_indices_out_cuda_template( kH, kW, dH, dW, padH, padW, dilationH, dilationW, nInputPlane, inputHeight, inputWidth, - outputHeight, outputWidth); + outputHeight, outputWidth, memory_format); Tensor input = input_.contiguous(memory_format); @@ -388,9 +388,9 @@ void max_pool2d_with_indices_out_cuda_template( const dim3 block(block_x, block_y, block_z); int kernel_stride_C = cuda::ATenCeilDiv( - safe_downcast(nInputPlane), block_x * 4); + safe_downcast(nInputPlane), block_x * 4); int kernel_size_C = cuda::ATenCeilDiv( - safe_downcast(nInputPlane), block_x * kernel_stride_C); + safe_downcast(nInputPlane), block_x * kernel_stride_C); int grid_x = nbatch*kernel_stride_C; int grid_y = std::min( @@ -402,17 +402,18 @@ void max_pool2d_with_indices_out_cuda_template( const dim3 grid(grid_x, grid_y, grid_z); size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t)); - AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock); + AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock); max_pool_forward_nhwc <<>>( - input_data, nbatch, + input_data, nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, dH, dW, padH, padW, dilationH, dilationW, - in_stride_n, in_stride_c, + in_stride_n, in_stride_c, in_stride_h, in_stride_w, - kernel_stride_C, kernel_size_C, + kernel_stride_C, kernel_size_C, output_data, indices_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); break; } case MemoryFormat::Contiguous: { @@ -424,6 +425,7 @@ void max_pool2d_with_indices_out_cuda_template( nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, dH, dW, padH, padW, dilationH, dilationW, output_data, indices_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); break; } default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); @@ -431,8 +433,6 @@ void max_pool2d_with_indices_out_cuda_template( } ); - AT_CUDA_CHECK(cudaGetLastError()); - if(input.ndimension() == 3) { output.resize_({nInputPlane, outputHeight, outputWidth}); indices.resize_({nInputPlane, outputHeight, outputWidth}); @@ -513,7 +513,7 @@ void max_pool2d_with_indices_backward_out_cuda_template( kH, kW, dH, dW, padH, padW, dilationH, dilationW, nInputPlane, inputHeight, inputWidth, - outputHeight, outputWidth, + outputHeight, outputWidth, memory_format, /*cuda=*/ true); const Tensor gradOutput = gradOutput_.contiguous(memory_format); @@ -565,11 +565,11 @@ void max_pool2d_with_indices_backward_out_cuda_template( const dim3 grid(grid_x, grid_y, grid_z); size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * sizeof(accscalar_t); - AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock); + AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock); - // The backward kernel is launched on input instead output. - // If it is launched on output layer, atomic_add would not provide much benefit on FP16. - // Please check comments at https://github.com/pytorch/pytorch/pull/34519. + // The backward kernel is launched on input instead output. + // If it is launched on output layer, atomic_add would not provide much benefit on FP16. + // Please check comments at https://github.com/pytorch/pytorch/pull/34519. max_pool_backward_nhwc <<>>( count, @@ -579,10 +579,11 @@ void max_pool2d_with_indices_backward_out_cuda_template( nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, dH, dW, padH, padW, dilationH, dilationW, out_stride_c, out_stride_h, out_stride_w, - in_stride_n, in_stride_c, + in_stride_n, in_stride_c, in_stride_h, in_stride_w, - kernel_stride_C, kernel_size_C, + kernel_stride_C, kernel_size_C, gradInput_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); break; } case MemoryFormat::Contiguous: { @@ -606,14 +607,13 @@ void max_pool2d_with_indices_backward_out_cuda_template( nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, dH, dW, padH, padW, dilationH, dilationW, gradInput_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); break; } default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); } } ); - - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace diff --git a/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu b/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu index 9d72e0027007d..e6eacbb8424ee 100644 --- a/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu +++ b/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu @@ -112,8 +112,7 @@ void max_pool3d_with_indices_out_frame( pT, pH, pW, dilationT, dilationH, dilationW, offsetZ); - - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); totalZ -= 65535; offsetZ += 65535; @@ -178,8 +177,7 @@ void max_pool3d_with_indices_backward_out_frame( pT, pH, pW, dilationT, dilationH, dilationW, offsetZ); - - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); totalZ -= 65535; offsetZ += 65535; diff --git a/aten/src/ATen/native/cuda/DistanceKernel.cu b/aten/src/ATen/native/cuda/DistanceKernel.cu index 385cac5c79e82..515388a0fe3e1 100644 --- a/aten/src/ATen/native/cuda/DistanceKernel.cu +++ b/aten/src/ATen/native/cuda/DistanceKernel.cu @@ -50,7 +50,9 @@ struct dists { // Special case backward when p is less than two struct lt_two { - static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) { return dist == 0.0 ? 0 : sign(diff) * std::pow(std::abs(diff), p - 1) * grad / std::pow(dist, p - 1); } + static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) { + return (dist == 0.0 || (diff == 0.0 && p < 1)) ? 0 : (sign(diff) * std::pow(std::abs(diff), p - 1) * grad / std::pow(dist, p - 1)); + } }; // Two norm @@ -229,17 +231,21 @@ void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, doubl AT_DISPATCH_FLOATING_TYPES(x1.scalar_type(), "cdist_cuda", [&] { if (p == 0.0) { cdist_kernel_cuda_impl::zero><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else if (p == 1.0) { cdist_kernel_cuda_impl::one><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else if (p == 2.0) { cdist_kernel_cuda_impl::two><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else if (std::isinf(p)) { cdist_kernel_cuda_impl::inf><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { cdist_kernel_cuda_impl::p><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } }); - AT_CUDA_CHECK(cudaGetLastError()); } void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) { @@ -255,17 +261,21 @@ void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) { AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_cuda", [&] { if (p == 0.0) { pdist_kernel_cuda_impl::zero><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else if (p == 1.0) { pdist_kernel_cuda_impl::one><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else if (p == 2.0) { pdist_kernel_cuda_impl::two><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else if (std::isinf(p)) { pdist_kernel_cuda_impl::inf><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { pdist_kernel_cuda_impl::p><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } }); - AT_CUDA_CHECK(cudaGetLastError()); } void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) { @@ -293,17 +303,21 @@ void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_cuda_backward", [&] { if (p == 1.0) { pdist_backward_kernel_cuda_impl::one><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else if (p < 2.0) { pdist_backward_kernel_cuda_impl::lt_two><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else if (p == 2.0) { pdist_backward_kernel_cuda_impl::two><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else if (std::isinf(p)) { pdist_backward_kernel_cuda_impl::inf><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { pdist_backward_kernel_cuda_impl::p><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } }); - AT_CUDA_CHECK(cudaGetLastError()); at::sum_out(result, buffer, 0); } @@ -340,25 +354,29 @@ void cdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor cdist_backward_kernel_cuda_impl::one><<>>(buffer.data_ptr(), grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), gs, p, r1, r2, m, count, r_size, l1_size, l2_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else if (p < 2.0) { cdist_backward_kernel_cuda_impl::lt_two><<>>(buffer.data_ptr(), grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), gs, p, r1, r2, m, count, r_size, l1_size, l2_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else if (p == 2.0) { cdist_backward_kernel_cuda_impl::two><<>>(buffer.data_ptr(), grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), gs, p, r1, r2, m, count, r_size, l1_size, l2_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else if (std::isinf(p)) { cdist_backward_kernel_cuda_impl::inf><<>>(buffer.data_ptr(), grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), gs, p, r1, r2, m, count, r_size, l1_size, l2_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { cdist_backward_kernel_cuda_impl::p><<>>(buffer.data_ptr(), grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), gs, p, r1, r2, m, count, r_size, l1_size, l2_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } }); - AT_CUDA_CHECK(cudaGetLastError()); if (x1.dim() > 2) { at::sum_out(result, buffer, 1); diff --git a/aten/src/ATen/native/cuda/DistributionTemplates.h b/aten/src/ATen/native/cuda/DistributionTemplates.h index 1cf107c171f43..1b4f228bf229b 100644 --- a/aten/src/ATen/native/cuda/DistributionTemplates.h +++ b/aten/src/ATen/native/cuda/DistributionTemplates.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -62,16 +63,17 @@ std::tuple calc_execution_policy(int64_t total_elements) { template C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound) __global__ void distribution_elementwise_grid_stride_kernel(int numel, - std::pair seeds, + PhiloxCudaState philox_args, const dist_t dist_func, const transform_t transform_func) { + auto seeds = at::cuda::philox::unpack(philox_args); int idx = blockIdx.x * blockDim.x + threadIdx.x; curandStatePhilox4_32_10_t state; - curand_init( - seeds.first, - idx, - seeds.second, - &state); + curand_init(std::get<0>(seeds), + idx, + std::get<1>(seeds), + &state); + int rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) * blockDim.x * gridDim.x * unroll_factor; for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) { @@ -123,11 +125,11 @@ void distribution_nullary_kernel(at::TensorIterator& iter, auto counter_offset = std::get<0>(execution_policy); auto grid = std::get<1>(execution_policy); auto block = std::get<2>(execution_policy); - std::pair rng_engine_inputs; + PhiloxCudaState rng_engine_inputs; { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); - rng_engine_inputs = gen->philox_engine_inputs(counter_offset); + rng_engine_inputs = gen->philox_cuda_state(counter_offset); } if (!iter.can_use_32bit_indexing()) { @@ -153,6 +155,7 @@ void distribution_nullary_kernel(at::TensorIterator& iter, *out = transform_func(rand); } ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { auto offset_calc = make_offset_calculator<1>(iter); distribution_elementwise_grid_stride_kernel<<>>( @@ -165,8 +168,8 @@ void distribution_nullary_kernel(at::TensorIterator& iter, *out = transform_func(rand); } ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } - AT_CUDA_CHECK(cudaGetLastError()); } // Binary kernel @@ -174,12 +177,14 @@ template seeds, + PhiloxCudaState philox_args, typename function_traits::result_type *output_data, const typename function_traits::template arg<1>::type *input_data_1, const typename function_traits::template arg<2>::type *input_data_2, inp_offset_calc_t inp_calc, out_offset_calc_t out_calc) { + auto seeds = at::cuda::philox::unpack(philox_args); + using input_t_1 = typename function_traits::template arg<1>::type; using input_t_2 = typename function_traits::template arg<2>::type; @@ -190,7 +195,10 @@ __global__ void distribution_binary_elementwise_kernel( int remaining = std::min(numel - base_index, BLOCK_WORK_SIZE); curandStatePhilox4_32_10_t state; - curand_init(seeds.first, blockIdx.x * blockDim.x + threadIdx.x, seeds.second, &state); + curand_init(std::get<0>(seeds), + blockIdx.x * blockDim.x + threadIdx.x, + std::get<1>(seeds), + &state); // load data into registers int thread_idx = threadIdx.x; @@ -222,7 +230,7 @@ __global__ void distribution_binary_elementwise_kernel( } template -void distribution_binary_kernel(TensorIterator &iter, std::pair seeds, const func_t &f) { +void distribution_binary_kernel(TensorIterator &iter, PhiloxCudaState philox_args, const func_t &f) { static_assert(std::is_same::template arg<0>::type, curandStatePhilox4_32_10_t&>::value, "the first argument of functor must be curandStatePhilox4_32_10_t"); using input_t_1 = typename function_traits::template arg<1>::type; using input_t_2 = typename function_traits::template arg<2>::type; @@ -230,7 +238,7 @@ void distribution_binary_kernel(TensorIterator &iter, std::pair>>( - numel, f, seeds, output_data, input_data_1, input_data_2, + numel, f, philox_args, output_data, input_data_1, input_data_2, TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<1>()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { distribution_binary_elementwise_kernel<<>>( - numel, f, seeds, output_data, input_data_1, input_data_2, + numel, f, philox_args, output_data, input_data_1, input_data_2, make_input_offset_calculator<2>(iter), make_output_offset_calculator(iter)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } @@ -570,20 +580,17 @@ struct CauchyKernel { template void bernoulli_tensor_cuda_kernel( at::Tensor& ret, const at::Tensor& p, - std::pair seeds) { - // The template argument `4` below indicates that we want to operate on four - // element at each time. See NOTE [ CUDA_tensor_applyN helpers ] for details. - at::cuda::CUDA_tensor_apply2( - ret, p, - [seeds] __device__( + PhiloxCudaState philox_args) { + auto functor = [philox_args] __device__( int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4, const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) { + auto seeds = at::cuda::philox::unpack(philox_args); curandStatePhilox4_32_10_t state; - curand_init( - seeds.first, - blockIdx.x * blockDim.x + threadIdx.x, - seeds.second, - &state); + curand_init(std::get<0>(seeds), + blockIdx.x * blockDim.x + threadIdx.x, + std::get<1>(seeds), + &state); + // See Note [Register spilling in curand call for CUDA < 10] float4 rand = curand_uniform4(&state); switch (n) { @@ -607,17 +614,21 @@ void bernoulli_tensor_cuda_kernel( v1 = static_cast(rand.x <= p1); } } - } - ); + }; + // The template argument `4` below indicates that we want to operate on four + // element at each time. See NOTE [ CUDA_tensor_applyN helpers ] for details. + at::cuda::CUDA_tensor_apply2(ret, p, functor); } template void bernoulli_kernel(Tensor& self, const Tensor& p_, RNG gen) { - std::pair rng_engine_inputs; + PhiloxCudaState rng_engine_inputs; { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); - rng_engine_inputs = gen->philox_engine_inputs(10); + rng_engine_inputs = gen->philox_cuda_state(10); } auto p = std::get<0>(expand_inplace(self, p_.to(kCUDA))); AT_DISPATCH_ALL_TYPES_AND3( diff --git a/aten/src/ATen/native/cuda/Distributions.cu b/aten/src/ATen/native/cuda/Distributions.cu index 90b7644abfe38..dd09efc9e7199 100644 --- a/aten/src/ATen/native/cuda/Distributions.cu +++ b/aten/src/ATen/native/cuda/Distributions.cu @@ -49,20 +49,20 @@ template void poisson_cuda_kernel( at::Tensor& ret, const at::Tensor& lambda, - std::pair seeds) { - at::cuda::CUDA_tensor_apply2( - ret, - lambda, - [seeds] __device__( + at::PhiloxCudaState philox_args) { + auto functor = [philox_args] __device__( scalar_t & ret_val, const scalar_t& lambda) { + auto seeds = at::cuda::philox::unpack(philox_args); curandStatePhilox4_32_10_t state; - curand_init( - seeds.first, - blockIdx.x * blockDim.x + threadIdx.x, - seeds.second, - &state); + curand_init(std::get<0>(seeds), + blockIdx.x * blockDim.x + threadIdx.x, + std::get<1>(seeds), + &state); ret_val = static_cast(curand_poisson(&state, lambda)); - }); + }; + at::cuda::CUDA_tensor_apply2(ret, lambda, functor); } struct curand_uniform_wrapper { @@ -82,7 +82,7 @@ void binomial_cuda_kernel( at::Tensor& ret, const at::Tensor& count, const at::Tensor& prob, - std::pair seeds) { + at::PhiloxCudaState philox_args) { using accscalar_t = at::acc_type; at::TensorIterator iter = at::TensorIteratorConfig() .add_output(ret) @@ -90,8 +90,8 @@ void binomial_cuda_kernel( .add_input(prob) .build(); - at::native::distribution_binary_kernel(iter, seeds, - [seeds] GPU_LAMBDA (curandStatePhilox4_32_10_t& state, scalar_t count, scalar_t prob) { + at::native::distribution_binary_kernel(iter, philox_args, + [philox_args] GPU_LAMBDA (curandStatePhilox4_32_10_t& state, scalar_t count, scalar_t prob) { #if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__) auto uniform_lambda = curand_uniform_wrapper(state); BaseSampler standard_uniform(uniform_lambda); @@ -108,19 +108,16 @@ template void gamma_cuda_kernel( at::Tensor& ret, const at::Tensor& alpha, - std::pair seeds) { + at::PhiloxCudaState philox_args) { using accscalar_t = at::acc_type; - at::cuda::CUDA_tensor_apply2( - ret, - alpha, - [seeds] __device__( + auto functor = [philox_args] __device__( scalar_t & ret_val, const scalar_t& alpha) { + auto seeds = at::cuda::philox::unpack(philox_args); curandStatePhilox4_32_10_t state; - curand_init( - seeds.first, - blockIdx.x * blockDim.x + threadIdx.x, - seeds.second, - &state); + curand_init(std::get<0>(seeds), + blockIdx.x * blockDim.x + threadIdx.x, + std::get<1>(seeds), + &state); auto uniform_lambda = [&state] __device__ () { return curand_uniform(&state); @@ -134,7 +131,10 @@ void gamma_cuda_kernel( auto sample = sample_gamma(alpha, standard_uniform, standard_normal); auto min_value = std::numeric_limits::min(); ret_val = (min_value > sample) ? min_value : sample; - }); + }; + at::cuda::CUDA_tensor_apply2(ret, alpha, functor); } template @@ -164,11 +164,11 @@ namespace at { namespace native { Tensor _s_poisson_cuda(const Tensor& lambda, c10::optional gen_) { auto gen = get_generator_or_default(gen_, cuda::detail::getDefaultCUDAGenerator()); - std::pair rng_engine_inputs; + PhiloxCudaState rng_engine_inputs; { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); - rng_engine_inputs = gen->philox_engine_inputs(20); + rng_engine_inputs = gen->philox_cuda_state(20); } Tensor ret = at::empty(lambda.sizes(), lambda.options()); AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "poisson_cuda", [&] { @@ -179,11 +179,11 @@ Tensor _s_poisson_cuda(const Tensor& lambda, c10::optional gen_) { Tensor _s_binomial_cuda(const Tensor& count, const Tensor& prob, c10::optional gen_) { auto gen = get_generator_or_default(gen_, cuda::detail::getDefaultCUDAGenerator()); - std::pair rng_engine_inputs; + PhiloxCudaState rng_engine_inputs; { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); - rng_engine_inputs = gen->philox_engine_inputs(42); + rng_engine_inputs = gen->philox_cuda_state(42); } Tensor ret = at::empty(count.sizes(), count.options()); AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.scalar_type(), "binomial_cuda", [&] { @@ -194,11 +194,11 @@ Tensor _s_binomial_cuda(const Tensor& count, const Tensor& prob, c10::optional gen_) { auto gen = get_generator_or_default(gen_, cuda::detail::getDefaultCUDAGenerator()); - std::pair rng_engine_inputs; + PhiloxCudaState rng_engine_inputs; { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); - rng_engine_inputs = gen->philox_engine_inputs(10); + rng_engine_inputs = gen->philox_cuda_state(10); } Tensor ret = at::empty(alpha.sizes(), alpha.options()); AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "gamma_cuda", [&] { @@ -209,11 +209,11 @@ Tensor _s_gamma_cuda(const Tensor& alpha, c10::optional gen_) { Tensor _s_dirichlet_cuda(const Tensor& alpha, c10::optional gen_) { auto gen = get_generator_or_default(gen_, cuda::detail::getDefaultCUDAGenerator()); - std::pair rng_engine_inputs; + PhiloxCudaState rng_engine_inputs; { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); - rng_engine_inputs = gen->philox_engine_inputs(10); + rng_engine_inputs = gen->philox_cuda_state(10); } Tensor ret = at::empty(alpha.sizes(), alpha.options()); AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "dirichlet", [&] { diff --git a/aten/src/ATen/native/cuda/Dropout.cu b/aten/src/ATen/native/cuda/Dropout.cu index 2016e96c9fd81..c3e456d970560 100644 --- a/aten/src/ATen/native/cuda/Dropout.cu +++ b/aten/src/ATen/native/cuda/Dropout.cu @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -30,31 +31,37 @@ template < int ADims, int VEC> #if __CUDA_ARCH__ >= 350 -C10_LAUNCH_BOUNDS_2(256, 8) +C10_LAUNCH_BOUNDS_2(256, 4) #elif defined (__HIP_PLATFORM_HCC__) C10_LAUNCH_BOUNDS_2(256, 4) #endif __global__ void fused_dropout_kernel_vec(at::cuda::detail::TensorInfo a, - at::cuda::detail::TensorInfo b, - at::cuda::detail::TensorInfo c, - IndexType totalElements, accscalar_t p, std::pair seeds - ) { - + at::cuda::detail::TensorInfo b, + at::cuda::detail::TensorInfo c, + IndexType totalElements, accscalar_t p, + PhiloxCudaState philox_args) { // make sure we don't break assumption that we can't have > 4 elements / thread static_assert(VEC <= 4, "Value of VEC must be in [2, 4]"); using LoadT = memory::aligned_vector; using MaskLoadT = memory::aligned_vector; - accscalar_t pinv = accscalar_t(1)/p; + auto seeds = at::cuda::philox::unpack(philox_args); IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; curandStatePhilox4_32_10_t state; - curand_init( - seeds.first, - idx, - seeds.second, - &state); + curand_init(std::get<0>(seeds), + idx, + std::get<1>(seeds), + &state); + + accscalar_t pinv = accscalar_t(1)/p; + + // Helps align the total number of times curand_uniform4 is called by each thread for the same totalElements + // in the vec=2 and vec=4 cases. + bool gridxvec_loop_state = 0; + + float4 rand; // Note: Vectorized loads means we'll stride each thread by an additional VEC factor, as we'll load VEC elements at a time for (IndexType linearIndex = idx * VEC; @@ -68,12 +75,21 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo a, //curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything // Note: need a new set of random values per 4 elements -- we'll handle VEC elements in this thread, so need ceil(VEC / 4) // sets of rand. - float4 rand = curand_uniform4(&state); + if ((VEC == 4) || (gridxvec_loop_state == 0)) { + rand = curand_uniform4(&state); + } else { + // sets up the last two values we generated last iteration to be used this iteration. + rand.x = rand.z; + rand.y = rand.w; + gridxvec_loop_state ^= 1; + } rand.x = rand.x < p; rand.y = rand.y < p; - rand.z = rand.z < p; - rand.w = rand.w < p; + if (VEC == 4) { + rand.z = rand.z < p; + rand.w = rand.w < p; + } // Note: We explicitly check for is_contiguous() before launching the vectorized kernel // and replace IndexToOffset call with linearIndex to allow vectorization of NHWC (or other) @@ -102,27 +118,29 @@ template < typename scalar_t, typename accscalar_t, typename IndexType, - int ADims> + int ADims, + int BDims=ADims> #if __CUDA_ARCH__ >= 350 -C10_LAUNCH_BOUNDS_2(256, 8) +C10_LAUNCH_BOUNDS_2(256, 4) #elif defined (__HIP_PLATFORM_HCC__) C10_LAUNCH_BOUNDS_2(256, 4) #endif __global__ void fused_dropout_kernel(cuda::detail::TensorInfo a, - cuda::detail::TensorInfo b, - cuda::detail::TensorInfo c, - IndexType totalElements, accscalar_t p, std::pair seeds - ) { - - accscalar_t pinv = accscalar_t(1)/p; + cuda::detail::TensorInfo b, + cuda::detail::TensorInfo c, + IndexType totalElements, accscalar_t p, + PhiloxCudaState philox_args) { + auto seeds = at::cuda::philox::unpack(philox_args); IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; curandStatePhilox4_32_10_t state; - curand_init( - seeds.first, - idx, - seeds.second, - &state); + curand_init(std::get<0>(seeds), + idx, + std::get<1>(seeds), + &state); + + accscalar_t pinv = accscalar_t(1)/p; + IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * blockDim.x * gridDim.x * UNROLL; for (IndexType linearIndex = idx; @@ -149,7 +167,7 @@ fused_dropout_kernel(cuda::detail::TensorInfo a, if (li < totalElements) { // Convert `linearIndex` into an offset of `b` const IndexType bOffset = - cuda::detail::IndexToOffset::get(li, b); + cuda::detail::IndexToOffset::get(li, b); b.data[bOffset] = src[ii]*(&rand.x)[ii]*pinv; c.data[bOffset] = (uint8_t)(&rand.x)[ii]; } @@ -178,8 +196,7 @@ template int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) { int vec_size = 4; // get the vector size - auto memory_format = self.suggest_memory_format(); - if (!self.is_contiguous(memory_format) || !ret.is_contiguous(memory_format) || !mask.is_contiguous(memory_format)) { + if (!self.is_non_overlapping_and_dense() || !ret.is_non_overlapping_and_dense() || !mask.is_non_overlapping_and_dense()) { vec_size = 1; } else { vec_size = memory::can_vectorize_up_to((char*)self.data_ptr()); @@ -194,13 +211,126 @@ int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) { return can_vectorize ? vec_size : 1; } +template +inline void launcher( + const Tensor& self, + Tensor& ret, + Tensor& mask, + double p, + const int64_t nelem, + const PhiloxCudaState rng_engine_inputs, + dim3 grid, + dim3 dim_block) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + self.scalar_type(), + "fused_dropout", + [&] { + using accscalar_t = acc_type; + accscalar_t pa = (accscalar_t)(p); + auto self_info = + cuda::detail::getTensorInfo(self); + auto ret_info = + cuda::detail::getTensorInfo(ret); + auto mask_info = + cuda::detail::getTensorInfo(mask); + self_info.collapseDims(); + ret_info.collapseDims(); + mask_info.collapseDims(); // ret and mask are collapsed to 1d + // contiguous tensor + + int vec_size = get_vector_size(self, ret, mask); + + if (vec_size > 1) { + switch (vec_size) { + case 4: + fused_dropout_kernel_vec< + scalar_t, + accscalar_t, + index_type, + 1, + 4> + <<>>( + self_info, + ret_info, + mask_info, + nelem, + pa, + rng_engine_inputs); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 2: + fused_dropout_kernel_vec< + scalar_t, + accscalar_t, + index_type, + 1, + 2> + <<>>( + self_info, + ret_info, + mask_info, + nelem, + pa, + rng_engine_inputs); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + } + } else { + switch (self_info.dims) { + case 1: + fused_dropout_kernel + <<>>( + self_info, + ret_info, + mask_info, + nelem, + pa, + rng_engine_inputs); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + default: + if (!self.is_contiguous() && ret.is_contiguous() && + mask.is_contiguous()) { + fused_dropout_kernel + <<>>( + self_info, + ret_info, + mask_info, + nelem, + pa, + rng_engine_inputs); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + fused_dropout_kernel + <<>>( + self_info, + ret_info, + mask_info, + nelem, + pa, + rng_engine_inputs); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + } + } + }); +} + } //anonymous namespace std::tuple fused_dropout_cuda(const Tensor& self, double p, c10::optional gen_){ auto gen = get_generator_or_default(gen_, cuda::detail::getDefaultCUDAGenerator()); - Tensor ret = at::empty_like(self, self.suggest_memory_format()); - Tensor mask = at::empty(self.sizes(), self.options().dtype(kByte), self.suggest_memory_format()); + Tensor ret = at::empty_like(self); + Tensor mask = at::empty_like(self, self.options().dtype(kByte)); const int64_t nelem = self.numel(); //empty tensors should not get here, but just in case, avoid FPE if (nelem==0) return std::tuple(self, mask); @@ -211,82 +341,19 @@ fused_dropout_cuda(const Tensor& self, double p, c10::optional gen_){ grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x); //number of times random will be generated per thread, to offset philox counter in thc random state int64_t counter_offset = ((nelem - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL; - std::pair rng_engine_inputs; + PhiloxCudaState rng_engine_inputs; { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); - rng_engine_inputs = gen->philox_engine_inputs(counter_offset); + rng_engine_inputs = gen->philox_cuda_state(counter_offset); } if (cuda::detail::canUse32BitIndexMath(self)){ - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "fused_dropout", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "fused_dropout", [&] { - using accscalar_t = acc_type; - accscalar_t pa = (accscalar_t)(p); - auto self_info = cuda::detail::getTensorInfo(self); - auto ret_info = cuda::detail::getTensorInfo(ret); - auto mask_info = cuda::detail::getTensorInfo(mask); - self_info.collapseDims(); - ret_info.collapseDims(); - mask_info.collapseDims(); //ret and mask are collapsed to 1d contiguous tensor - - int vec_size = get_vector_size(self, ret, mask); - - if (vec_size > 1) { - switch (vec_size) { - case 4: - fused_dropout_kernel_vec<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); - break; - case 2: - fused_dropout_kernel_vec<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); - break; - } - } else { - switch (self_info.dims) { - case 1: - fused_dropout_kernel<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); - break; - default: - fused_dropout_kernel<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); - } - } - }); - }); + launcher( + self, ret, mask, p, nelem, rng_engine_inputs, grid, dim_block); } else { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "fused_dropout", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "fused_dropout", [&] { - using accscalar_t = acc_type; - accscalar_t pa = (accscalar_t)(p); - auto self_info = cuda::detail::getTensorInfo(self); - auto ret_info = cuda::detail::getTensorInfo(ret); - auto mask_info = cuda::detail::getTensorInfo(mask); - self_info.collapseDims(); - ret_info.collapseDims(); - mask_info.collapseDims(); //ret and mask are collapsed to 1d contiguous tensor - - int vec_size = get_vector_size(self, ret, mask); - - if (vec_size > 1) { - switch (vec_size) { - case 4: - fused_dropout_kernel_vec<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); - break; - case 2: - fused_dropout_kernel_vec<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); - break; - } - } else { - switch (self_info.dims) { - case 1: - fused_dropout_kernel<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); - break; - default: - fused_dropout_kernel<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); - } - } - }); - }); + launcher( + self, ret, mask, p, nelem, rng_engine_inputs, grid, dim_block); } - AT_CUDA_CHECK(cudaGetLastError()); return std::tuple(ret, mask); } @@ -294,11 +361,9 @@ Tensor masked_scale_cuda(const Tensor& self, const Tensor& mask, double scale){ Tensor ret = at::empty_like(self, self.suggest_memory_format()); TORCH_CHECK(mask.scalar_type() == at::ScalarType::Byte, "mask should be torch.uint8 dtype"); AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "masked_scale", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "masked_scale", [&] { - using accscalar_t = acc_type; - accscalar_t pa = (accscalar_t)(scale); - masked_scale_kernel(ret, self, mask, pa); - }); + using accscalar_t = acc_type; + accscalar_t pa = (accscalar_t)(scale); + masked_scale_kernel(ret, self, mask, pa); }); return ret; } diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu index 365db61d06c0e..80a8bfa5a6e86 100644 --- a/aten/src/ATen/native/cuda/Embedding.cu +++ b/aten/src/ATen/native/cuda/Embedding.cu @@ -29,9 +29,10 @@ static const int BLOCKDIMY = 32; template + typename accscalar_t, + typename index_t> __global__ void embedding_backward_feature_kernel - (int64_t* indices, + (index_t* indices, const scalar_t* __restrict__ grad, scalar_t* __restrict__ grad_weight, int n, // OK to pass as int, we don't expect 2 billion+ samples in one shot @@ -117,10 +118,10 @@ __global__ void embedding_backward_feature_kernel } -template +template __global__ void embedding_backward_kernel( - int64_t* input, int64_t* indices, scalar_t* grad_output, scalar_t* grad_weight, - int64_t* count, int64_t numel, int64_t stride, int padding_idx) { + index_t* input, index_t* indices, scalar_t* grad_output, scalar_t* grad_weight, + index_t* count, int64_t numel, int64_t stride, int padding_idx) { using accscalar_t = acc_type; int idx = blockIdx.x * 4 + threadIdx.y; @@ -179,9 +180,9 @@ __global__ void embedding_backward_kernel( } /* Calculate norms of the rows of weight_ptr given by idx_ptr and capture them in norms */ -template +template __global__ void renorm_kernel( - scalar_t* weights, int64_t* indices, accscalar_t max_norm, + scalar_t* weights, index_t* indices, accscalar_t max_norm, accscalar_t norm_type, int64_t dim, int64_t weights_stride0, int64_t weights_stride1) { @@ -228,7 +229,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice bool scale_grad_by_freq) { auto grad_arg = TensorArg(grad_, "grad", 1); auto indices_arg = TensorArg(indices, "indices", 1); - checkScalarType("embedding_backward", indices_arg, kLong); + checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt}); checkSameGPU("embedding_backward", grad_arg, indices_arg); auto num_indices = indices.numel(); @@ -248,83 +249,84 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice "embedding_backward", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_backward", [&] { - using accscalar_t = acc_type; - embedding_backward_feature_kernel - <<; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () { + embedding_backward_feature_kernel + <<>> - (indices_contig.data_ptr(), + (indices_contig.data_ptr(), grad.data_ptr(), grad_weight.data_ptr(), static_cast(num_indices), static_cast(stride), static_cast(padding_idx)); - }); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); - - AT_CUDA_CHECK(cudaGetLastError()); return grad_weight; } auto sorted_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - using device_ptr = thrust::device_ptr; + Tensor count; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () { + using device_ptr = thrust::device_ptr; - // Sort the inputs into sorted with the corresponding indices; we - // don't need a stable or multidimensional sort, so just use Thrust - // directly - { - sorted_indices.copy_(indices); + // Sort the inputs into sorted with the corresponding indices; we + // don't need a stable or multidimensional sort, so just use Thrust + // directly + { + sorted_indices.copy_(indices); - auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); - auto policy = thrust::cuda::par(allocator).on(stream); + auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); + auto policy = thrust::cuda::par(allocator).on(stream); - // Fill sortedOrigIndices with sequential indices - auto count_iter = thrust::counting_iterator(0); - auto orig_data = device_ptr(orig_indices.data_ptr()); - thrust::copy(policy, count_iter, count_iter + num_indices, orig_data); + // Fill sortedOrigIndices with sequential indices + auto count_iter = thrust::counting_iterator(0); + auto orig_data = device_ptr(orig_indices.data_ptr()); + thrust::copy(policy, count_iter, count_iter + num_indices, orig_data); - // Sort; a stable sort is not required - auto sorted_data = device_ptr(sorted_indices.data_ptr()); - thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data, - ThrustLTOp()); - } + // Sort; a stable sort is not required + auto sorted_data = device_ptr(sorted_indices.data_ptr()); + thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data, + ThrustLTOp()); + } - Tensor count; - if (scale_grad_by_freq) { - count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - - auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); - auto policy = thrust::cuda::par(allocator).on(stream); - - // Compute an increasing sequence per unique item in sortedIndices: - // sorted: 2 5 5 5 7 7 8 9 9 - // count: 1 1 2 3 1 2 1 1 2 - auto sorted_data = device_ptr(sorted_indices.data_ptr()); - auto count_data = device_ptr(count.data_ptr()); - thrust::inclusive_scan_by_key( - policy, - sorted_data, - sorted_data + num_indices, - thrust::make_constant_iterator(1), - count_data - ); - - // Take the maximum of each count per unique key in reverse: - // sorted: 2 5 5 5 7 7 8 9 9 - // count: 1 3 3 3 2 2 1 2 2 - thrust::inclusive_scan_by_key( - policy, - thrust::make_reverse_iterator(sorted_data + num_indices), - thrust::make_reverse_iterator(sorted_data), - thrust::make_reverse_iterator(count_data + num_indices), - thrust::make_reverse_iterator(count_data + num_indices), - thrust::equal_to(), - thrust::maximum() - ); - } + if (scale_grad_by_freq) { + count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); + auto policy = thrust::cuda::par(allocator).on(stream); + + // Compute an increasing sequence per unique item in sortedIndices: + // sorted: 2 5 5 5 7 7 8 9 9 + // count: 1 1 2 3 1 2 1 1 2 + auto sorted_data = device_ptr(sorted_indices.data_ptr()); + auto count_data = device_ptr(count.data_ptr()); + thrust::inclusive_scan_by_key( + policy, + sorted_data, + sorted_data + num_indices, + thrust::make_constant_iterator(1), + count_data + ); + + // Take the maximum of each count per unique key in reverse: + // sorted: 2 5 5 5 7 7 8 9 9 + // count: 1 3 3 3 2 2 1 2 2 + thrust::inclusive_scan_by_key( + policy, + thrust::make_reverse_iterator(sorted_data + num_indices), + thrust::make_reverse_iterator(sorted_data), + thrust::make_reverse_iterator(count_data + num_indices), + thrust::make_reverse_iterator(count_data + num_indices), + thrust::equal_to(), + thrust::maximum() + ); + } + }); return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices, count, num_weights, padding_idx); @@ -341,37 +343,33 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices, auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); auto policy = thrust::cuda::par(allocator).on(stream); - using device_ptr = thrust::device_ptr; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cuda_", [&] () { + using device_ptr = thrust::device_ptr; - auto num_indices = indices.numel(); - auto indices_contig = indices.contiguous(); - auto indices_data = device_ptr(indices_contig.data_ptr()); - - // FIXME: thrust::unique only removes consecutive elements that are equal. - // We have race conditions when indices contain duplicates which are not - // adjacent - auto unique_indices = at::empty(indices.numel(), indices.options()); - auto unique_data = device_ptr(unique_indices.data_ptr()); - auto end = thrust::unique_copy(policy, indices_data, indices_data + num_indices, unique_data); - auto num_unique_indices = static_cast(end - unique_data); - - dim3 grid(num_unique_indices); - dim3 block(128); - int dim = self.stride(0); - - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_backward", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_backward", [&] { + auto num_indices = indices.numel(); + auto indices_contig = std::get<0>(indices.sort()).contiguous(); + auto indices_data = device_ptr(indices_contig.data_ptr()); + + auto unique_indices = at::empty(indices.numel(), indices.options()); + auto unique_data = device_ptr(unique_indices.data_ptr()); + auto end = thrust::unique_copy(policy, indices_data, indices_data + num_indices, unique_data); + auto num_unique_indices = static_cast(end - unique_data); + + dim3 grid(num_unique_indices); + dim3 block(128); + int dim = self.stride(0); + + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_backward", [&] { using accscalar_t = acc_type; renorm_kernel<<>>( self.data_ptr(), - unique_indices.data_ptr(), + unique_indices.data_ptr(), static_cast(max_norm), static_cast(norm_type), dim, self.stride(0), self.stride(1)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); - AT_CUDA_CHECK(cudaGetLastError()); - return self; } diff --git a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu index 0fd742d7b70fe..dd0730a38bcb1 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu @@ -40,8 +40,9 @@ int64_t ceil_div(int64_t x, int64_t y) { return (x + y - 1) / y; } +template __global__ -void krn_partials_per_segment(int64_t *ret, const int64_t *segment_offsets, +void krn_partials_per_segment(index_t *ret, const index_t *segment_offsets, int64_t num_of_segments, int64_t numel) { const int id = blockIdx.x * blockDim.x + threadIdx.x; if(id < num_of_segments) { @@ -52,18 +53,19 @@ void krn_partials_per_segment(int64_t *ret, const int64_t *segment_offsets, } } +template __global__ void krn_partial_segment_offset( - int64_t *ret, - const int64_t *partials_per_segment, - const int64_t *partials_per_segment_offset, - const int64_t *segment_offsets, + index_t *ret, + const index_t *partials_per_segment, + const index_t *partials_per_segment_offset, + const index_t *segment_offsets, int64_t num_of_segments) { const int id = blockIdx.x * blockDim.x + threadIdx.x; if(id < num_of_segments) { - int64_t idx = partials_per_segment_offset[id]; - const int64_t num_partials = partials_per_segment[id]; - const int64_t segment_offset = segment_offsets[id]; + index_t idx = partials_per_segment_offset[id]; + const index_t num_partials = partials_per_segment[id]; + const index_t segment_offset = segment_offsets[id]; for (int64_t i=0; i +template __global__ void compute_grad_weight_bags( - int64_t *indices, scalar_t *gradOutput, - int64_t *offset2bag, int64_t *count, ptrdiff_t numel, - int64_t stride, int mode_mean, const int64_t *bag_size, + index_t *indices, scalar_t *gradOutput, + index_t *offset2bag, index_t *count, ptrdiff_t numel, + int64_t stride, int mode_mean, const index_t *bag_size, scalar_t* per_sample_weights, int64_t per_sample_weights_stride, - int64_t* segment_offsets, int64_t num_of_segments, + index_t* segment_offsets, int64_t num_of_segments, acc_type *grad_weight_per_segment, const int64_t stride_warped) { @@ -113,14 +115,14 @@ __global__ void compute_grad_weight_bags( grad_weight_per_segment[id * stride + startFeature] = weight; } -template +template __global__ void compute_grad_weight( - int64_t *indices, + index_t *indices, scalar_t *gradOutput, - int64_t *count, + index_t *count, ptrdiff_t numel, int64_t stride, - int64_t* segment_offsets, + index_t* segment_offsets, int64_t num_of_segments, acc_type *grad_weight_per_segment, const int64_t stride_warped) { @@ -140,7 +142,7 @@ __global__ void compute_grad_weight( accscalar_t weight = 0; for (int idx=idx_begin; idx < idx_end; ++idx) { - const int64_t target_row = indices[idx]; + const index_t target_row = indices[idx]; const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0; weight += gradOutput[target_row * stride + startFeature] * scale; } @@ -148,12 +150,12 @@ __global__ void compute_grad_weight( } // This kernel assumes that all input tensors are contiguous. -template +template __global__ void sum_and_scatter( - int64_t *input, scalar_t *gradWeight, int64_t stride, - int64_t* segment_offsets, int64_t num_of_segments, + index_t *input, scalar_t *gradWeight, int64_t stride, + index_t* segment_offsets, int64_t num_of_segments, const acc_type *grad_weight_per_segment, - const int64_t *segment_sizes_offsets, int64_t num_of_partial_segments, + const index_t *segment_sizes_offsets, int64_t num_of_partial_segments, const int64_t padding_idx, const int64_t stride_warped) { @@ -206,68 +208,70 @@ Tensor embedding_backward_cuda_kernel( // spawn a warp per index. In this context, a segment is a number of rows that should // be summarized. // Unit: index in `sorted_indices` and `orig_indices` - auto segment_offsets = at::empty({numel}, orig_indices.options()); - int64_t num_of_segments; - { - auto sorted_indices_dev = thrust::device_ptr(sorted_indices.data_ptr()); - auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - auto dummy_dev = thrust::device_ptr(dummy.data_ptr()); - auto ends = thrust::unique_by_key_copy( - policy, - sorted_indices_dev, - sorted_indices_dev + numel, - thrust::make_counting_iterator(0), - dummy_dev, - thrust::device_ptr(segment_offsets.data_ptr())); - num_of_segments = thrust::get<0>(ends) - dummy_dev; - } + AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () { + auto segment_offsets = at::empty({numel}, orig_indices.options()); + int64_t num_of_segments; + { + auto sorted_indices_dev = thrust::device_ptr(sorted_indices.data_ptr()); + auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto dummy_dev = thrust::device_ptr(dummy.data_ptr()); + auto ends = thrust::unique_by_key_copy( + policy, + sorted_indices_dev, + sorted_indices_dev + numel, + thrust::make_counting_iterator(0), + dummy_dev, + thrust::device_ptr(segment_offsets.data_ptr())); + num_of_segments = thrust::get<0>(ends) - dummy_dev; + } - // We split the segments up into sizes of `NROWS_PER_THREAD` - // Compute the number partial-segments per segment (some partial-segments - // may not be the full `NROWS_PER_THREAD` number of rows) - auto partials_per_segment = at::empty({num_of_segments}, orig_indices.options()); - { - krn_partials_per_segment<<>> ( - partials_per_segment.data_ptr(), - segment_offsets.data_ptr(), - num_of_segments, - numel); - } + // We split the segments up into sizes of `NROWS_PER_THREAD` + // Compute the number partial-segments per segment (some partial-segments + // may not be the full `NROWS_PER_THREAD` number of rows) + auto partials_per_segment = at::empty({num_of_segments}, orig_indices.options()); + { + krn_partials_per_segment<<>> ( + partials_per_segment.data_ptr(), + segment_offsets.data_ptr(), + num_of_segments, + numel); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } - // In order to compute `partial_segment_offset`, which is the start index - // of each partial-segment in `sorted_indices`, we need to compute the - // start position of each _segment_ in `partial_segment_offset`. - // Unit: index in `partial_segment_offset` - auto partials_per_segment_offset = at::empty({num_of_segments}, orig_indices.options()); - thrust::exclusive_scan( - policy, - thrust::device_ptr(partials_per_segment.data_ptr()), - thrust::device_ptr(partials_per_segment.data_ptr()+num_of_segments), - thrust::device_ptr(partials_per_segment_offset.data_ptr())); - - // The total number of partial-segments is the sum of `partials_per_segment_offset` - const int num_of_partial_segments = partials_per_segment[num_of_segments-1].item() + - partials_per_segment_offset[num_of_segments-1].item(); - - // Now we can compute the start position of each partial-segment - // Unit: index in `sorted_indices` and `orig_indices` - auto partial_segment_offset = at::empty({num_of_partial_segments}, orig_indices.options()); - { - krn_partial_segment_offset<<>> ( - partial_segment_offset.data_ptr(), - partials_per_segment.data_ptr(), - partials_per_segment_offset.data_ptr(), - segment_offsets.data_ptr(), - num_of_segments); - } + // In order to compute `partial_segment_offset`, which is the start index + // of each partial-segment in `sorted_indices`, we need to compute the + // start position of each _segment_ in `partial_segment_offset`. + // Unit: index in `partial_segment_offset` + auto partials_per_segment_offset = at::empty({num_of_segments}, orig_indices.options()); + thrust::exclusive_scan( + policy, + thrust::device_ptr(partials_per_segment.data_ptr()), + thrust::device_ptr(partials_per_segment.data_ptr()+num_of_segments), + thrust::device_ptr(partials_per_segment_offset.data_ptr())); + + // The total number of partial-segments is the sum of `partials_per_segment_offset` + const int num_of_partial_segments = partials_per_segment[num_of_segments-1].item() + + partials_per_segment_offset[num_of_segments-1].item(); + + // Now we can compute the start position of each partial-segment + // Unit: index in `sorted_indices` and `orig_indices` + auto partial_segment_offset = at::empty({num_of_partial_segments}, orig_indices.options()); + { + krn_partial_segment_offset<<>> ( + partial_segment_offset.data_ptr(), + partials_per_segment.data_ptr(), + partials_per_segment_offset.data_ptr(), + segment_offsets.data_ptr(), + num_of_segments); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } - const int stride_warped = ceil_div(stride, C10_WARP_SIZE)*C10_WARP_SIZE; - const int block = std::min(stride_warped, MAX_BLOCK_SIZE); - const int grid = ceil_div(num_of_partial_segments*stride_warped, block); + const int stride_warped = ceil_div(stride, C10_WARP_SIZE)*C10_WARP_SIZE; + const int block = std::min(stride_warped, MAX_BLOCK_SIZE); + const int grid = ceil_div(num_of_partial_segments*stride_warped, block); - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, - grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_bag_backward_cuda_compute_grad_weight", [&] { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, + grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] { // For numerical stability, the dtype of `grad_weight_per_segment` // should match `acc_type` using partial_weight_t = acc_type; @@ -281,43 +285,44 @@ Tensor embedding_backward_cuda_kernel( // Compute the sum of each partial-segment and handle bags if (offset2bag.defined()) { compute_grad_weight_bags<<>>( - orig_indices.data_ptr(), + orig_indices.data_ptr(), grad.data_ptr(), - offset2bag.data_ptr(), - count.defined() ? count.data_ptr() : nullptr, numel, stride, - mode_mean, bag_size.data_ptr(), + offset2bag.data_ptr(), + count.defined() ? count.data_ptr() : nullptr, numel, stride, + mode_mean, bag_size.data_ptr(), per_sample_weights.defined() ? per_sample_weights.data_ptr() : NULL, per_sample_weights.defined() ? per_sample_weights.stride(0) : 0, - partial_segment_offset.data_ptr(), + partial_segment_offset.data_ptr(), num_of_partial_segments, grad_weight_per_segment.data_ptr(), stride_warped); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { compute_grad_weight<<>>( - orig_indices.data_ptr(), + orig_indices.data_ptr(), grad.data_ptr(), - count.defined() ? count.data_ptr() : nullptr, + count.defined() ? count.data_ptr() : nullptr, numel, stride, - partial_segment_offset.data_ptr(), + partial_segment_offset.data_ptr(), num_of_partial_segments, grad_weight_per_segment.data_ptr(), stride_warped); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } - AT_CUDA_CHECK(cudaGetLastError()); // Finally, we sum all the partial-sums and scatter them // into `grad_weight`. const int grid2 = ceil_div(num_of_segments*stride_warped, block); sum_and_scatter<<>>( - sorted_indices.data_ptr(), + sorted_indices.data_ptr(), grad_weight.data_ptr(), stride, - segment_offsets.data_ptr(), + segment_offsets.data_ptr(), num_of_segments, grad_weight_per_segment.data_ptr(), - partials_per_segment_offset.data_ptr(), + partials_per_segment_offset.data_ptr(), num_of_partial_segments, padding_idx, stride_warped); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); return grad_weight; diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu index 87df361280ef9..a80de4b45138f 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBag.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu @@ -31,12 +31,12 @@ constexpr int MODE_MAX = 2; // This kernel assumes that all input tensors except `weight` and // per_sample_weights are contiguous. -template +template __global__ void EmbeddingBag_updateOutputKernel( - int64_t *input, int64_t *offsets, scalar_t *weight, scalar_t *output, - int64_t *offset2bag, int64_t numIndices, int64_t numBags, + index_t *input, index_t *offsets, scalar_t *weight, scalar_t *output, + index_t *offset2bag, int64_t numIndices, int64_t numBags, int64_t featureSize, int64_t weight_stride0, int64_t weight_stride1, - int mode, int64_t *bag_size, int64_t *max_indices, + int mode, index_t *bag_size, index_t *max_indices, scalar_t* per_sample_weights, int64_t per_sample_weights_stride) { // the strategy here is that each bag x feature is handled by a single thread @@ -135,62 +135,65 @@ Tensor embedding_bag_backward_cuda_sum_avg( auto sorted_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - using device_ptr = thrust::device_ptr; - - // Sort the inputs into sorted with the corresponding indices; we - // don't need a stable or multidimensional sort, so just use Thrust - // directly - { - sorted_indices.copy_(indices); - - auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); - auto policy = thrust::cuda::par(allocator).on(stream); - - // Fill sortedOrigIndices with sequential indices - auto count_iter = thrust::counting_iterator(0); - auto orig_data = device_ptr(orig_indices.data_ptr()); - thrust::copy(policy, count_iter, count_iter + numel, orig_data); - - // Sort; a stable sort is not required - auto sorted_data = device_ptr(sorted_indices.data_ptr()); - thrust::sort_by_key(policy, sorted_data, sorted_data + numel, orig_data, - ThrustLTOp()); - } - Tensor count; - if (scale_grad_by_freq) { - count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - - auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); - auto policy = thrust::cuda::par(allocator).on(stream); - - // Compute an increasing sequence per unique item in sortedIndices: - // sorted: 2 5 5 5 7 7 8 9 9 - // count: 1 1 2 3 1 2 1 1 2 - auto sorted_data = device_ptr(sorted_indices.data_ptr()); - auto count_data = device_ptr(count.data_ptr()); - thrust::inclusive_scan_by_key(policy, sorted_data, sorted_data + numel, - thrust::make_constant_iterator(1), - count_data); - - // Take the maximum of each count per unique key in reverse: - // sorted: 2 5 5 5 7 7 8 9 9 - // count: 1 3 3 3 2 2 1 2 2 - thrust::inclusive_scan_by_key( - policy, thrust::make_reverse_iterator(sorted_data + numel), - thrust::make_reverse_iterator(sorted_data), - thrust::make_reverse_iterator(count_data + numel), - thrust::make_reverse_iterator(count_data + numel), - thrust::equal_to(), thrust::maximum()); - } + + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () { + using device_ptr = thrust::device_ptr; + + // Sort the inputs into sorted with the corresponding indices; we + // don't need a stable or multidimensional sort, so just use Thrust + // directly + { + sorted_indices.copy_(indices); + + auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); + auto policy = thrust::cuda::par(allocator).on(stream); + + // Fill sortedOrigIndices with sequential indices + auto count_iter = thrust::counting_iterator(0); + auto orig_data = device_ptr(orig_indices.data_ptr()); + thrust::copy(policy, count_iter, count_iter + numel, orig_data); + + // Sort; a stable sort is not required + auto sorted_data = device_ptr(sorted_indices.data_ptr()); + thrust::sort_by_key(policy, sorted_data, sorted_data + numel, orig_data, + ThrustLTOp()); + } + + if (scale_grad_by_freq) { + count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); + auto policy = thrust::cuda::par(allocator).on(stream); + + // Compute an increasing sequence per unique item in sortedIndices: + // sorted: 2 5 5 5 7 7 8 9 9 + // count: 1 1 2 3 1 2 1 1 2 + auto sorted_data = device_ptr(sorted_indices.data_ptr()); + auto count_data = device_ptr(count.data_ptr()); + thrust::inclusive_scan_by_key(policy, sorted_data, sorted_data + numel, + thrust::make_constant_iterator(1), + count_data); + + // Take the maximum of each count per unique key in reverse: + // sorted: 2 5 5 5 7 7 8 9 9 + // count: 1 3 3 3 2 2 1 2 2 + thrust::inclusive_scan_by_key( + policy, thrust::make_reverse_iterator(sorted_data + numel), + thrust::make_reverse_iterator(sorted_data), + thrust::make_reverse_iterator(count_data + numel), + thrust::make_reverse_iterator(count_data + numel), + thrust::equal_to(), thrust::maximum()); + } + }); return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices, count, num_weights, /* padding_idx= */ -1, scale_grad_by_freq, mode == MODE_MEAN, offset2bag, bag_size, per_sample_weights); } -template +template __global__ void EmbeddingBag_accGradParametersKernel_max( - int64_t *max_indices, scalar_t *gradOutput, + index_t *max_indices, scalar_t *gradOutput, scalar_t *gradWeight, int64_t stride, int64_t numBags) { using accscalar_t = acc_type; @@ -205,7 +208,7 @@ __global__ void EmbeddingBag_accGradParametersKernel_max( if (featureDim < stride) { int64_t bag = chunk / chunksPerBag; - int64_t word_idx = max_indices[bag * stride + featureDim]; + index_t word_idx = max_indices[bag * stride + featureDim]; if (word_idx >= 0) { // If bag is empty, we have max_indices[idx] set to -1 in forward. gpuAtomicAdd(&(gradWeight[word_idx * stride + featureDim]), @@ -236,13 +239,15 @@ Tensor embedding_bag_backward_cuda_max(const Tensor &grad, AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad.scalar_type(), "embedding_bag_backward_cuda_max", [&] { - EmbeddingBag_accGradParametersKernel_max< - scalar_t><<>>( - max_indices.data_ptr(), grad.data_ptr(), - grad_weight.data_ptr(), stride, numBags); + AT_DISPATCH_INDEX_TYPES(max_indices.scalar_type(), "embedding_bag_backward_cuda_max", [&] () { + EmbeddingBag_accGradParametersKernel_max< + scalar_t, index_t><<>>( + max_indices.data_ptr(), grad.data_ptr(), + grad_weight.data_ptr(), stride, numBags); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); + }); - AT_CUDA_CHECK(cudaGetLastError()); return grad_weight; } } @@ -275,9 +280,10 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices, const Tensor& per_sample_weights, bool include_last_offset) { auto indices_arg = TensorArg(indices, "indices", 1); - checkScalarType("embedding_bag_cuda", indices_arg, kLong); + checkScalarTypes("embedding_bag_cuda", indices_arg, {kLong, kInt}); auto offsets_arg = TensorArg(offsets, "offsets", 1); - checkScalarType("embedding_bag_cuda", offsets_arg, kLong); + checkScalarTypes("embedding_bag_cuda", offsets_arg, {kLong, kInt}); + checkSameType("embedding_bag_cuda", indices_arg, offsets_arg); auto weight_arg = TensorArg(weight, "weight", 1); checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg); checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg); @@ -319,19 +325,19 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices, #endif int grid = 1024; AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, weight.scalar_type(), "embedding_bag_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_bag_cuda", [&] { - EmbeddingBag_updateOutputKernel<<>>( - indices.data_ptr(), offsets.data_ptr(), + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cuda", [&] () { + EmbeddingBag_updateOutputKernel<<>>( + indices.data_ptr(), offsets.data_ptr(), weight.data_ptr(), output.data_ptr(), - offset2bag.data_ptr(), numIndices, numBags, featureSize, - weight.stride(0), weight.stride(1), mode, bag_size.data_ptr(), - mode == MODE_MAX ? max_indices.data_ptr() : NULL, + offset2bag.data_ptr(), numIndices, numBags, featureSize, + weight.stride(0), weight.stride(1), mode, bag_size.data_ptr(), + mode == MODE_MAX ? max_indices.data_ptr() : NULL, per_sample_weights.defined() ? per_sample_weights.data_ptr() : NULL, per_sample_weights.defined() ? per_sample_weights.stride(0) : 0); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); - AT_CUDA_CHECK(cudaGetLastError()); return std::tuple(output, offset2bag, bag_size, max_indices); } @@ -387,12 +393,12 @@ static scalar_t warpReduceSum(scalar_t val) { return val; } -template +template __global__ static void _embedding_bag_per_sample_weights_backward_kernel( const scalar_t* grad, int64_t grad_stride0, int64_t grad_stride1, const scalar_t* weight, int64_t weight_stride0, int64_t weight_stride1, - const int64_t* indices, // contiguous - const int64_t* offset2bag, // contiguous + const index_t* indices, // contiguous + const index_t* offset2bag, // contiguous int64_t num_samples, int64_t embedding_features, scalar_t* output) { @@ -448,17 +454,27 @@ Tensor _embedding_bag_per_sample_weights_backward_cuda( dim3 grid((num_samples + warps_per_block - 1) / warps_per_block); auto output = at::empty({num_samples}, grad.options()); + + // Early return when there is no samples in the batch. This saves unnecesary kernel + // launch, but also prevents cudaGetLastError() to complain about invalid launch args + if (num_samples == 0) { + return output; + } + AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() { - _embedding_bag_per_sample_weights_backward_kernel - <<>>( - grad.data_ptr(), grad.stride(0), grad.stride(1), - weight.data_ptr(), weight.stride(0), weight.stride(1), - indices.data_ptr(), - offset2bag.data_ptr(), - num_samples, - embedding_features, - output.data_ptr()); + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() { + _embedding_bag_per_sample_weights_backward_kernel + <<>>( + grad.data_ptr(), grad.stride(0), grad.stride(1), + weight.data_ptr(), weight.stride(0), weight.stride(1), + indices.data_ptr(), + offset2bag.data_ptr(), + num_samples, + embedding_features, + output.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); } ); return output; diff --git a/aten/src/ATen/native/cuda/FillKernel.cu b/aten/src/ATen/native/cuda/FillKernel.cu index 7376ecfa63948..e4fe4b68f2eb8 100644 --- a/aten/src/ATen/native/cuda/FillKernel.cu +++ b/aten/src/ATen/native/cuda/FillKernel.cu @@ -6,12 +6,19 @@ namespace at { namespace native { +template +struct FillFunctor { + FillFunctor(scalar_t v): value(v) {} + __device__ __forceinline__ scalar_t operator() () const { + return value; + } + private: + scalar_t value; +}; + void fill_kernel_cuda(TensorIterator& iter, Scalar value) { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "fill_cuda", [&]() { - auto value_converted = value.to(); - gpu_kernel(iter, [value_converted]GPU_LAMBDA() -> scalar_t { - return value_converted; - }); + gpu_kernel(iter, FillFunctor(value.to())); }); } diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu index 239859b9138c8..cdfc0d0abec78 100644 --- a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu @@ -6,8 +6,9 @@ namespace at { namespace native { template class Op> std::vector foreach_tensor_list_op(TensorList tensors1, TensorList tensors2, Scalar alpha = 1) { - std::vector> tensor_lists; + std::vector> tensor_lists; std::vector vec_res; + vec_res.reserve(tensors1.size()); for (const auto& t: tensors1) { vec_res.emplace_back(at::native::empty_like(t)); } @@ -17,7 +18,14 @@ std::vector foreach_tensor_list_op(TensorList tensors1, TensorList tenso tensor_lists.emplace_back(std::move(vec_res)); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, tensors1[0].scalar_type(), "foreach_binary_op_list_cuda", [&]() { - multi_tensor_apply<3>(tensor_lists, BinaryOpListAlphaFunctor(), alpha.to()); + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<3>(tensor_lists, + BinaryOpListAlphaFunctor(), + Op(), + alpha.to()); }); return tensor_lists[2]; @@ -25,19 +33,25 @@ std::vector foreach_tensor_list_op(TensorList tensors1, TensorList tenso template class Op> void foreach_tensor_list_op_(TensorList tensors1, TensorList tensors2, Scalar alpha = 1) { - std::vector> tensor_lists; + std::vector> tensor_lists; tensor_lists.emplace_back(tensors1.vec()); tensor_lists.emplace_back(tensors2.vec()); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, tensors1[0].scalar_type(), "foreach_binary_op_list_cuda_", [&]() { - multi_tensor_apply<2>(tensor_lists, BinaryOpListAlphaFunctor_(), alpha.to()); + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<2>(tensor_lists, + BinaryOpListAlphaFunctor(), + Op(), + alpha.to()); }); } #define FOREACH_BINARY_OP_LIST(NAME, OP) \ void foreach_tensor_##NAME##_list_kernel_cuda_(TensorList tensors1, TensorList tensors2) { \ check_foreach_api_restrictions(tensors1, tensors2); \ - \ if (!can_use_fast_route(tensors1, tensors2)) { \ return at::native::foreach_tensor_##NAME##_list_kernel_slow_(tensors1, tensors2); \ } \ @@ -47,7 +61,6 @@ void foreach_tensor_##NAME##_list_kernel_cuda_(TensorList tensors1, TensorList t \ std::vector foreach_tensor_##NAME##_list_kernel_cuda(TensorList tensors1, TensorList tensors2) { \ check_foreach_api_restrictions(tensors1, tensors2); \ - \ if (!can_use_fast_route(tensors1, tensors2)) { \ return at::native::foreach_tensor_##NAME##_list_kernel_slow(tensors1, tensors2); \ } \ @@ -58,8 +71,7 @@ std::vector foreach_tensor_##NAME##_list_kernel_cuda(TensorList tensors1 #define FOREACH_BINARY_OP_LIST_ALPHA(NAME, OP) \ void foreach_tensor_##NAME##_list_kernel_cuda_(TensorList tensors1, TensorList tensors2, Scalar alpha) { \ check_foreach_api_restrictions(tensors1, tensors2); \ - \ - if (!can_use_fast_route(tensors1, tensors2)) { \ + if (!can_use_fast_route(tensors1, tensors2, alpha)) { \ return at::native::foreach_tensor_##NAME##_list_kernel_slow_(tensors1, tensors2, alpha); \ } \ \ @@ -68,8 +80,7 @@ void foreach_tensor_##NAME##_list_kernel_cuda_(TensorList tensors1, TensorList t \ std::vector foreach_tensor_##NAME##_list_kernel_cuda(TensorList tensors1, TensorList tensors2, Scalar alpha) { \ check_foreach_api_restrictions(tensors1, tensors2); \ - \ - if (!can_use_fast_route(tensors1, tensors2)) { \ + if (!can_use_fast_route(tensors1, tensors2, alpha)) { \ return at::native::foreach_tensor_##NAME##_list_kernel_slow(tensors1, tensors2, alpha); \ } \ \ diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu index 215410bbc2a55..43372055dbb54 100644 --- a/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu @@ -6,10 +6,9 @@ namespace at { namespace native { template class Op> std::vector foreach_binary_op(TensorList tensors, Scalar scalar) { - check_foreach_api_restrictions(tensors); - - std::vector> tensor_lists; + std::vector> tensor_lists; std::vector vec_res; + vec_res.reserve(tensors.size()); for (const auto& t: tensors) { vec_res.emplace_back(at::native::empty_like(t)); } @@ -18,27 +17,38 @@ std::vector foreach_binary_op(TensorList tensors, Scalar scalar) { tensor_lists.emplace_back(std::move(vec_res)); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalar_cuda", [&]() { - multi_tensor_apply<2>(tensor_lists, BinaryOpScalarFunctor(), scalar.to()); + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<2>(tensor_lists, + BinaryOpScalarFunctor(), + Op(), + scalar.to()); }); return tensor_lists[1]; } template class Op> void foreach_binary_op_(TensorList tensors, Scalar scalar) { - check_foreach_api_restrictions(tensors); - - std::vector> tensor_lists; + std::vector> tensor_lists; tensor_lists.emplace_back(tensors.vec()); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalar_cuda_", [&]() { - multi_tensor_apply<1>(tensor_lists, BinaryOpScalarFunctor_(), scalar.to()); + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<1>(tensor_lists, + BinaryOpScalarFunctor(), + Op(), + scalar.to()); }); } #define FOREACH_BINARY_OP_SCALAR(NAME, OP) \ void foreach_tensor_##NAME##_scalar_kernel_cuda_(TensorList tensors, Scalar scalar) { \ check_foreach_api_restrictions(tensors); \ - \ if (!can_use_fast_route(tensors, scalar)) { \ return at::native::foreach_tensor_##NAME##_scalar_kernel_slow_(tensors, scalar); \ } \ @@ -48,7 +58,6 @@ void foreach_tensor_##NAME##_scalar_kernel_cuda_(TensorList tensors, Scalar scal \ std::vector foreach_tensor_##NAME##_scalar_kernel_cuda(TensorList tensors, Scalar scalar) { \ check_foreach_api_restrictions(tensors); \ - \ if (!can_use_fast_route(tensors, scalar)) { \ return at::native::foreach_tensor_##NAME##_scalar_kernel_slow(tensors, scalar); \ } \ diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu new file mode 100644 index 0000000000000..8ae678405cd3c --- /dev/null +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu @@ -0,0 +1,74 @@ +#include +#include +#include + +namespace at { namespace native { + +template class Op> +std::vector foreach_binary_op(TensorList tensors, at::ArrayRef scalars) { + std::vector> tensor_lists; + std::vector vec_res; + vec_res.reserve(tensors.size()); + for (const auto& t: tensors) { + vec_res.emplace_back(at::native::empty_like(t)); + } + + tensor_lists.emplace_back(tensors.vec()); + tensor_lists.emplace_back(vec_res); + + AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<2>(tensor_lists, + scalars, + BinaryOpScalarListFunctor(), + + Op()); + }); + return tensor_lists[1]; +} + +template class Op> +void foreach_binary_op_(TensorList tensors, at::ArrayRef scalars) { + std::vector> tensor_lists; + tensor_lists.emplace_back(tensors.vec()); + + AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda_", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<1>(tensor_lists, + scalars, + BinaryOpScalarListFunctor(), + Op()); + }); +} + +#define FOREACH_BINARY_OP_SCALARLIST(NAME, OP) \ +void foreach_tensor_##NAME##_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef scalars) { \ + check_foreach_api_restrictions(tensors, scalars); \ + if (!can_use_fast_route(tensors, scalars)) { \ + return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow_(tensors, scalars); \ + } \ + \ + foreach_binary_op_(tensors, scalars); \ +} \ + \ +std::vector foreach_tensor_##NAME##_scalarlist_kernel_cuda(TensorList tensors, at::ArrayRef scalars) { \ + check_foreach_api_restrictions(tensors, scalars); \ + if (!can_use_fast_route(tensors, scalars)) { \ + return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow(tensors, scalars); \ + } \ + \ + return foreach_binary_op(tensors, scalars); \ +} + +FOREACH_BINARY_OP_SCALARLIST(add, std::plus); +FOREACH_BINARY_OP_SCALARLIST(sub, std::minus); +FOREACH_BINARY_OP_SCALARLIST(mul, std::multiplies); +FOREACH_BINARY_OP_SCALARLIST(div, std::divides); + +}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/ForeachFunctors.cuh b/aten/src/ATen/native/cuda/ForeachFunctors.cuh index a04d27110c9ae..4805b0fc00cc0 100644 --- a/aten/src/ATen/native/cuda/ForeachFunctors.cuh +++ b/aten/src/ATen/native/cuda/ForeachFunctors.cuh @@ -5,499 +5,418 @@ namespace at { namespace native { namespace { -template class Op> -struct BinaryOpScalarFunctor_ { - __device__ void operator() ( - int chunk_size, - TensorListMetadata<1>& tl, - T scalar) { - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - T* x = (T*)tl.addresses[0][tensor_loc]; - x += chunk_idx * chunk_size; - - n -= chunk_idx * chunk_size; - - T r_x[kILP]; - - // to make things simple, we put aligned case in a different code path - if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x)) { - for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) { - // load - load_store(r_x, x, 0 , i_start); -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = Op()(static_cast(r_x[ii]), scalar); - } - // store - load_store(x, r_x, i_start, 0); +// For FP16 or BFloat16 inputs, ops should perform internal math in FP32. +template struct get_opmath_t { using opmath_t = scalar_t; }; +template<> struct get_opmath_t { using opmath_t = float; }; +template<> struct get_opmath_t { using opmath_t = float; }; + +// Initializes args and checks if all args are aligned +template +__device__ bool init_args( + T** args, + TensorListMetadata& tl, + int chunk_idx, + int chunk_size, + int tensor_loc) { + bool all_aligned = true; + for (int i = 0; i < depth; i++) { + args[i] = (T*)tl.addresses[i][tensor_loc]; + args[i] += chunk_idx * chunk_size; + + if (!is_aligned(args[i])) { + all_aligned = false; + } + } + return all_aligned; +} + +// Initializes args and checks if all args are aligned +template +__device__ bool init_args( + T** args, + TensorListScalarListMetadata& tl, + int chunk_idx, + int chunk_size, + int tensor_loc) { + bool all_aligned = true; + for (int i = 0; i < depth; i++) { + args[i] = (T*)tl.addresses[i][tensor_loc]; + args[i] += chunk_idx * chunk_size; + + if (!is_aligned(args[i])) { + all_aligned = false; + } + } + return all_aligned; +} + +template +__device__ void load_args(T r_args[][kILP], T** args, int i_start, int chunk_size, int n) { +#pragma unroll + for(int ii = 0; ii < kILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + for (int r_index = 0; r_index < depth; r_index++) { + r_args[r_index][ii] = 0; + if(i < n && i < chunk_size) { + r_args[r_index][ii] = args[r_index][i]; + } + } + } +} + +template +__device__ void store_args(T* dst, T* src, int i_start, int chunk_size, int n) { +#pragma unroll + for(int ii = 0; ii < kILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if(i < n && i < chunk_size) + dst[i] = src[ii]; + } +} + +template +__device__ __forceinline__ void binary_op_scalar( + T r_args[][kILP], + T** args, + opmath_t scalar, + int n, + int chunk_size, + bool all_aligned, + Op op) { + // to make things simple, we put aligned case in a different code path + if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); +#pragma unroll + for(int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast(op(static_cast(r_args[0][ii]), + static_cast(scalar))); } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); } - else { - for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) { -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = 0; - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) { - r_x[ii] = x[i]; - } - } -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = Op()(static_cast(r_x[ii]), scalar); - } + } + else { + for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); #pragma unroll - for(int ii = 0; ii < kILP; ii++) { - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) - x[i] = r_x[ii]; - } + for(int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast(op(static_cast(r_args[0][ii]), + static_cast(scalar))); } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); } } -}; +} + +template +__device__ __forceinline__ void pointwise_op_scalar( + T r_args[][kILP], + T** args, + opmath_t scalar, + int n, + int chunk_size, + bool all_aligned, + Op op) { + // to make things simple, we put aligned case in a different code path + if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); + load_store(r_args[2], args[2], 0, i_start); +#pragma unroll + for(int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast(static_cast(r_args[0][ii]) + + scalar * op(static_cast(r_args[1][ii]), + static_cast(r_args[2][ii]))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } + else { + for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for(int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast(static_cast(r_args[0][ii]) + + scalar * op(static_cast(r_args[1][ii]), + static_cast(r_args[2][ii]))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } +} -template class Op> +// +// Binary Functors +// +template struct BinaryOpScalarFunctor { - __device__ void operator() ( + using opmath_t = typename get_opmath_t::opmath_t; + template __device__ __forceinline__ void operator() ( int chunk_size, - TensorListMetadata<2>& tl, - T scalar) { + TensorListMetadata& tl, + Op op, + opmath_t scalar) { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - T* x = (T*)tl.addresses[0][tensor_loc]; - x += chunk_idx * chunk_size; - - T* out = (T*)tl.addresses[1][tensor_loc]; - out += chunk_idx * chunk_size; + int n = tl.numel_for_tensor[tensor_loc]; + T* args[depth]; + bool all_aligned = init_args(args, tl, chunk_idx, chunk_size, tensor_loc); n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; - T r_x[kILP]; - - // to make things simple, we put aligned case in a different code path - if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x) && is_aligned(out)) { - for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) { - // load - load_store(r_x, x, 0 , i_start); -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = Op()(static_cast(r_x[ii]), scalar); - } - // store - load_store(out, r_x, i_start, 0); - } - } - else { - for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) { -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = 0; - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) { - r_x[ii] = x[i]; - } - } -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = Op()(static_cast(r_x[ii]), scalar); - } -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) - out[i] = r_x[ii]; - } - } - } + binary_op_scalar(r_args, args, scalar, n, chunk_size, all_aligned, op); } }; -template class Op> -struct BinaryOpListAlphaFunctor_ { - __device__ void operator() ( +template +struct BinaryOpScalarListFunctor { + using opmath_t = typename get_opmath_t::opmath_t; + template __device__ __forceinline__ void operator() ( int chunk_size, - TensorListMetadata<2>& tl, - T alpha) { + TensorListScalarListMetadata& tl, + Op op) { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - T* x = (T*)tl.addresses[0][tensor_loc]; - x += chunk_idx * chunk_size; - - T* y = (T*)tl.addresses[1][tensor_loc]; - y += chunk_idx * chunk_size; + int n = tl.numel_for_tensor[tensor_loc]; + T* args[depth]; + bool all_aligned = init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + opmath_t scalar = tl.scalar_vals[tensor_loc]; n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; - T r_x[kILP]; - T r_y[kILP]; - - // to make things simple, we put aligned case in a different code path - if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x) && is_aligned(y)) { - for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) { - // load - load_store(r_x, x, 0 , i_start); - load_store(r_y, y, 0 , i_start); -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = Op()(static_cast(r_x[ii]), alpha * static_cast(r_y[ii])); - } - // store - load_store(x, r_x, i_start , 0); - } - } - else { - for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) { -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = 0; - r_y[ii] = 0; - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) { - r_x[ii] = x[i]; - r_y[ii] = y[i]; - } - } -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = Op()(static_cast(r_x[ii]), alpha * static_cast(r_y[ii])); - } -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) - x[i] = r_x[ii]; - } - } - } + binary_op_scalar(r_args, args, scalar, n, chunk_size, all_aligned, op); } }; -template class Op> +template struct BinaryOpListAlphaFunctor { - __device__ void operator() ( + using opmath_t = typename get_opmath_t::opmath_t; + template __device__ __forceinline__ void operator() ( int chunk_size, - TensorListMetadata<3>& tl, - T alpha) { + TensorListMetadata& tl, + Op op, + opmath_t alpha) { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - T* x = (T*)tl.addresses[0][tensor_loc]; - x += chunk_idx * chunk_size; - - T* y = (T*)tl.addresses[1][tensor_loc]; - y += chunk_idx * chunk_size; - - T* out = (T*)tl.addresses[2][tensor_loc]; - out += chunk_idx * chunk_size; + int n = tl.numel_for_tensor[tensor_loc]; + T* args[depth]; + bool all_aligned = init_args(args, tl, chunk_idx, chunk_size, tensor_loc); n -= chunk_idx * chunk_size; - - T r_x[kILP]; - T r_y[kILP]; + T r_args[r_args_depth][kILP]; // to make things simple, we put aligned case in a different code path - if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x) && is_aligned(y) && is_aligned(out)) { + if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) { // load - load_store(r_x, x, 0 , i_start); - load_store(r_y, y, 0 , i_start); + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); #pragma unroll for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = Op()(static_cast(r_x[ii]), alpha * static_cast(r_y[ii])); + r_args[0][ii] = static_cast(op(static_cast(r_args[0][ii]), + alpha * static_cast(r_args[1][ii]))); } // store - load_store(out, r_x, i_start , 0); + load_store(args[res_arg_index], r_args[0], i_start , 0); } } else { for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); #pragma unroll for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = 0; - r_y[ii] = 0; - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) { - r_x[ii] = x[i]; - r_y[ii] = y[i]; - } - } -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = Op()(static_cast(r_x[ii]), alpha * static_cast(r_y[ii])); - } -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) - out[i] = r_x[ii]; + r_args[0][ii] = static_cast(op(static_cast(r_args[0][ii]), + alpha * static_cast(r_args[1][ii]))); } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); } } } }; -template class Op> -struct UnaryOpFunctor_ { - __device__ void operator() ( +// +// Unary Functors +// + +template +struct ZeroFunctor { + __device__ __forceinline__ void operator() ( int chunk_size, TensorListMetadata<1>& tl) { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - T* x = (T*)tl.addresses[0][tensor_loc]; - x += chunk_idx * chunk_size; + int n = tl.numel_for_tensor[tensor_loc]; + T* args[depth]; + bool all_aligned = init_args(args, tl, chunk_idx, chunk_size, tensor_loc); n -= chunk_idx * chunk_size; - - T r_x[kILP]; + T r_args[r_args_depth][kILP]; // to make things simple, we put aligned case in a different code path - if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x)) { + if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) { // load - load_store(r_x, x, 0 , i_start); + load_store(r_args[0], args[0], 0, i_start); #pragma unroll for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = Op()(static_cast(r_x[ii])); + r_args[0][ii] = 0; } // store - load_store(x, r_x, i_start, 0); + load_store(args[0], r_args[0], i_start, 0); } } else { for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); #pragma unroll for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = 0; - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) { - r_x[ii] = x[i]; - } - } -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = Op()(static_cast(r_x[ii])); - } -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) - x[i] = r_x[ii]; + r_args[0][ii] = 0; } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); } } } }; -template class Op> +template struct UnaryOpFunctor { - __device__ void operator() ( + using opmath_t = typename get_opmath_t::opmath_t; + template __device__ __forceinline__ void operator() ( int chunk_size, - TensorListMetadata<2>& tl) { + TensorListMetadata& tl, + Op op) { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - T* x = (T*)tl.addresses[0][tensor_loc]; - x += chunk_idx * chunk_size; - - T* out = (T*)tl.addresses[1][tensor_loc]; - out += chunk_idx * chunk_size; + int n = tl.numel_for_tensor[tensor_loc]; + T* args[depth]; + bool all_aligned = init_args(args, tl, chunk_idx, chunk_size, tensor_loc); n -= chunk_idx * chunk_size; - - T r_x[kILP]; + T r_args[r_args_depth][kILP]; // to make things simple, we put aligned case in a different code path - if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x) && is_aligned(out)) { + if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) { // load - load_store(r_x, x, 0 , i_start); + load_store(r_args[0], args[0], 0, i_start); #pragma unroll for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = Op()(static_cast(r_x[ii])); + r_args[0][ii] = static_cast(op(static_cast(r_args[0][ii]))); } // store - load_store(out, r_x, i_start, 0); + load_store(args[res_arg_index], r_args[0], i_start, 0); } } else { for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); #pragma unroll for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = 0; - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) { - r_x[ii] = x[i]; - } - } -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = Op()(static_cast(r_x[ii])); - } -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) - out[i] = r_x[ii]; + r_args[0][ii] = static_cast(op(static_cast(r_args[0][ii]))); } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); } } } }; -template class Op> -struct PointwiseOpFunctor_ { - __device__ void operator() ( +// +// Pointwise Functors +// + +template +struct PointwiseOpScalarFunctor { + using opmath_t = typename get_opmath_t::opmath_t; + template __device__ __forceinline__ void operator() ( int chunk_size, - TensorListMetadata<3>& tl, - T scalar) { + TensorListMetadata& tl, + Op op, + opmath_t scalar) { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - T* x = (T*)tl.addresses[0][tensor_loc]; - x += chunk_idx * chunk_size; - - T* y = (T*)tl.addresses[1][tensor_loc]; - y += chunk_idx * chunk_size; - - T* z = (T*)tl.addresses[2][tensor_loc]; - z += chunk_idx * chunk_size; + int n = tl.numel_for_tensor[tensor_loc]; + T* args[depth]; + bool all_aligned = init_args(args, tl, chunk_idx, chunk_size, tensor_loc); n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; - T r_x[kILP]; - T r_y[kILP]; - T r_z[kILP]; - - // to make things simple, we put aligned case in a different code path - if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x) && is_aligned(y) && is_aligned(z)) { - for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) { - // load - load_store(r_x, x, 0 , i_start); - load_store(r_y, y, 0 , i_start); - load_store(r_z, z, 0 , i_start); -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = static_cast(r_x[ii]) + scalar * Op()(static_cast(r_y[ii]), static_cast(r_z[ii])); - } - // store - load_store(x, r_x, i_start, 0); - } - } - else { - for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) { -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = 0; - r_y[ii] = 0; - r_z[ii] = 0; - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) { - r_x[ii] = x[i]; - r_y[ii] = y[i]; - r_z[ii] = z[i]; - } - } -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = static_cast(r_x[ii]) + scalar * Op()(static_cast(r_y[ii]), static_cast(r_z[ii])); - } -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) - x[i] = r_x[ii]; - } - } - } + pointwise_op_scalar(r_args, args, scalar, n, chunk_size, all_aligned, op); } }; -template class Op> -struct PointwiseOpFunctor { - __device__ void operator() ( +template +struct PointwiseOpScalarListFunctor { + using opmath_t = typename get_opmath_t::opmath_t; + template __device__ __forceinline__ void operator() ( int chunk_size, - TensorListMetadata<4>& tl, - T scalar) { + TensorListScalarListMetadata& tl, + Op op) { int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - T* x = (T*)tl.addresses[0][tensor_loc]; - x += chunk_idx * chunk_size; + int n = tl.numel_for_tensor[tensor_loc]; - T* y = (T*)tl.addresses[1][tensor_loc]; - y += chunk_idx * chunk_size; + T* args[depth]; + bool all_aligned = init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + opmath_t scalar = tl.scalar_vals[tensor_loc]; + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; - T* z = (T*)tl.addresses[2][tensor_loc]; - z += chunk_idx * chunk_size; + pointwise_op_scalar(r_args, args, scalar, n, chunk_size, all_aligned, op); + } +}; - T* out = (T*)tl.addresses[3][tensor_loc]; - out += chunk_idx * chunk_size; +template +struct PointwiseOpListFunctor { + using opmath_t = typename get_opmath_t::opmath_t; + template __device__ __forceinline__ void operator() ( + int chunk_size, + TensorListMetadata& tl, + Op op) { + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.numel_for_tensor[tensor_loc]; + T* args[depth]; + bool all_aligned = init_args(args, tl, chunk_idx, chunk_size, tensor_loc); n -= chunk_idx * chunk_size; - - T r_x[kILP]; - T r_y[kILP]; - T r_z[kILP]; + T r_args[depth - 1][kILP]; // to make things simple, we put aligned case in a different code path - if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x) && is_aligned(y) && is_aligned(z) && is_aligned(out)) { + if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) { // load - load_store(r_x, x, 0 , i_start); - load_store(r_y, y, 0 , i_start); - load_store(r_z, z, 0 , i_start); + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); #pragma unroll for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = static_cast(r_x[ii]) + scalar * Op()(static_cast(r_y[ii]), static_cast(r_z[ii])); + r_args[0][ii] = static_cast(op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]))); } // store - load_store(out, r_x, i_start, 0); + load_store(args[2], r_args[0], i_start , 0); } } else { for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); #pragma unroll for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = 0; - r_y[ii] = 0; - r_z[ii] = 0; - - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) { - r_x[ii] = x[i]; - r_y[ii] = y[i]; - r_z[ii] = z[i]; - } - } -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = static_cast(r_x[ii]) + scalar * Op()(static_cast(r_y[ii]), static_cast(r_z[ii])); - } -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) - out[i] = r_x[ii]; + r_args[0][ii] = static_cast(op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]))); } + store_args(args[2], r_args[0], i_start, chunk_size, n); } } } }; } // namespace - }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu b/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu index b514f3294c52e..43f1928495c76 100644 --- a/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu +++ b/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu @@ -1,13 +1,15 @@ #include #include #include +#include namespace at { namespace native { template class Op> std::vector foreach_pointwise_op(TensorList input, TensorList tensors1, TensorList tensors2, Scalar scalar) { - std::vector> tensor_lists; + std::vector> tensor_lists; std::vector vec_res; + vec_res.reserve(input.size()); for (const auto& t: input) { vec_res.emplace_back(at::native::empty_like(t)); } @@ -18,7 +20,14 @@ std::vector foreach_pointwise_op(TensorList input, TensorList tensors1, tensor_lists.emplace_back(std::move(vec_res)); AT_DISPATCH_ALL_TYPES_AND(kHalf, input[0].scalar_type(), "foreach_pointwise_op_cuda", [&]() { - multi_tensor_apply<4>(tensor_lists, PointwiseOpFunctor(), scalar.to()); + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<4>(tensor_lists, + PointwiseOpScalarFunctor(), + Op(), + scalar.to()); }); return tensor_lists[3]; @@ -26,46 +35,155 @@ std::vector foreach_pointwise_op(TensorList input, TensorList tensors1, template class Op> void foreach_pointwise_op_(TensorList input, TensorList tensors1, TensorList tensors2, Scalar scalar) { - std::vector> tensor_lists; + std::vector> tensor_lists; tensor_lists.emplace_back(input.vec()); tensor_lists.emplace_back(tensors1.vec()); tensor_lists.emplace_back(tensors2.vec()); AT_DISPATCH_ALL_TYPES_AND(kHalf, input[0].scalar_type(), "foreach_pointwise_op__cuda", [&]() { - multi_tensor_apply<3>(tensor_lists, PointwiseOpFunctor_(), scalar.to()); + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<3>(tensor_lists, + PointwiseOpScalarFunctor(), + Op(), + scalar.to()); }); } -#define FOREACH_UNARY_OP(NAME, OP) \ -std::vector foreach_tensor_##NAME##_cuda(TensorList input, TensorList tensors1, TensorList tensors2, Scalar scalar) { \ - TORCH_CHECK(input.size() > 0, "Tensor list must have at least one tensor."); \ - TORCH_CHECK(input.size() == tensors1.size(), "Tensor lists must be of the same length."); \ - TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must be of the same length."); \ - \ - if (!can_use_fast_route(input, scalar) || \ - !can_use_fast_route(tensors1, tensors2) || \ - !can_use_fast_route(input, tensors1)) { \ - return at::native::foreach_tensor_##NAME##_slow(input, tensors1, tensors2, scalar); \ - } \ - \ - return foreach_pointwise_op(input, tensors1, tensors2, scalar); \ -} \ - \ -void foreach_tensor_##NAME##_cuda_(TensorList input, TensorList tensors1, TensorList tensors2, Scalar scalar) { \ - TORCH_CHECK(input.size() > 0, "Tensor list must have at least one tensor."); \ - TORCH_CHECK(input.size() == tensors1.size(), "Tensor lists must be of the same length."); \ - TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must be of the same length."); \ - \ - if (!can_use_fast_route(input, scalar) || \ - !can_use_fast_route(tensors1, tensors2) || \ - !can_use_fast_route(input, tensors1)) { \ - at::native::foreach_tensor_##NAME##_slow_(input, tensors1, tensors2, scalar); \ - } \ - \ - foreach_pointwise_op_(input, tensors1, tensors2, scalar); \ +template class Op> +void foreach_pointwise_op_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef scalars) { + std::vector> tensor_lists; + tensor_lists.reserve(3); + tensor_lists.emplace_back(input.vec()); + tensor_lists.emplace_back(tensors1.vec()); + tensor_lists.emplace_back(tensors2.vec()); + + AT_DISPATCH_ALL_TYPES_AND(kHalf, input[0].scalar_type(), "foreach_pointwise_op__cuda", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<3>(tensor_lists, + scalars, + PointwiseOpScalarListFunctor(), + Op()); + }); } -FOREACH_UNARY_OP(addcmul, std::multiplies); -FOREACH_UNARY_OP(addcdiv, std::divides); +template class Op> +std::vector foreach_pointwise_op(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef scalars) { + std::vector> tensor_lists; + tensor_lists.reserve(4); + std::vector vec_res; + vec_res.reserve(input.size()); + for (const auto& t: input) { + vec_res.emplace_back(at::native::empty_like(t)); + } + + tensor_lists.emplace_back(input.vec()); + tensor_lists.emplace_back(tensors1.vec()); + tensor_lists.emplace_back(tensors2.vec()); + tensor_lists.emplace_back(std::move(vec_res)); + + AT_DISPATCH_ALL_TYPES_AND(kHalf, input[0].scalar_type(), "foreach_pointwise_op_cuda", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<4>(tensor_lists, + scalars, + PointwiseOpScalarListFunctor(), + Op()); + }); + + return tensor_lists[3]; +} + +#define FOREACH_POINTWISE_OP_SCALAR(NAME, OP) \ +std::vector foreach_tensor_##NAME##_scalar_cuda(TensorList input, TensorList tensors1, TensorList tensors2, Scalar scalar) { \ + check_foreach_api_restrictions(input, tensors1, tensors2); \ + \ + if (!can_use_fast_route(input, tensors1, tensors2, scalar)) { \ + return at::native::foreach_tensor_##NAME##_scalar_slow(input, tensors1, tensors2, scalar); \ + } \ + \ + return foreach_pointwise_op(input, tensors1, tensors2, scalar); \ +} \ + \ +void foreach_tensor_##NAME##_scalar_cuda_(TensorList input, TensorList tensors1, TensorList tensors2, Scalar scalar) { \ + check_foreach_api_restrictions(input, tensors1, tensors2); \ + \ + if (!can_use_fast_route(input, tensors1, tensors2, scalar)) { \ + return at::native::foreach_tensor_##NAME##_scalar_slow_(input, tensors1, tensors2, scalar); \ + } \ + \ + foreach_pointwise_op_(input, tensors1, tensors2, scalar); \ +} + + +#define FOREACH_POINTWISE_OP_SCALARLIST(NAME, OP) \ +std::vector foreach_tensor_##NAME##_scalarlist_cuda(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef scalars) { \ + check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \ + \ + if (!can_use_fast_route(input, tensors1, tensors2, scalars)) { \ + return at::native::foreach_tensor_##NAME##_scalarlist_slow(input, tensors1, tensors2, scalars); \ + } \ + \ + return foreach_pointwise_op(input, tensors1, tensors2, scalars); \ +} \ + \ +void foreach_tensor_##NAME##_scalarlist_cuda_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef scalars) { \ + check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \ + \ + if (!can_use_fast_route(input, tensors1, tensors2, scalars)) { \ + return at::native::foreach_tensor_##NAME##_scalarlist_slow_(input, tensors1, tensors2, scalars); \ + } \ + \ + foreach_pointwise_op_(input, tensors1, tensors2, scalars); \ +} + +FOREACH_POINTWISE_OP_SCALAR(addcmul, std::multiplies); +FOREACH_POINTWISE_OP_SCALAR(addcdiv, std::divides); +FOREACH_POINTWISE_OP_SCALARLIST(addcmul, std::multiplies); +FOREACH_POINTWISE_OP_SCALARLIST(addcdiv, std::divides); + +#define FOREACH_MAXIMUM_MINIMUM_OP(NAME, OP) \ +std::vector foreach_tensor_##NAME##_cuda(TensorList tensors1, TensorList tensors2) { \ + check_foreach_api_restrictions(tensors1, tensors2); \ + if (!can_use_fast_route(tensors1, tensors2)) { \ + return at::native::foreach_tensor_##NAME##_slow(tensors1, tensors2); \ + } \ + \ + std::vector> tensor_lists; \ + std::vector vec_res; \ + vec_res.reserve(tensors1.size()); \ + for (const auto& t: tensors1) { \ + vec_res.emplace_back(at::native::empty_like(t)); \ + } \ + \ + tensor_lists.emplace_back(tensors1.vec()); \ + tensor_lists.emplace_back(tensors2.vec()); \ + tensor_lists.emplace_back(std::move(vec_res)); \ + \ + AT_DISPATCH_ALL_TYPES_AND(kHalf, tensors1[0].scalar_type(), "foreach_maximum_minimum_op_cuda", [&]() { \ + using opmath_t = get_opmath_t::opmath_t; \ + auto op = [] GPU_LAMBDA (opmath_t a, opmath_t b) -> opmath_t { \ + opmath_t c = a OP b ? a : b; \ + if (_isnan(a)) { \ + c = a; \ + } \ + return c;}; \ + multi_tensor_apply<3>(tensor_lists, \ + PointwiseOpListFunctor(), \ + op); \ + }); \ + \ + return tensor_lists[2]; \ +} \ + +FOREACH_MAXIMUM_MINIMUM_OP(maximum, >) +FOREACH_MAXIMUM_MINIMUM_OP(minimum, <) }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu index 32bb6ab6b5095..e3e8e86e7268a 100644 --- a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu +++ b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu @@ -5,9 +5,10 @@ namespace at { namespace native { template class Op> -std::vector foreach_unary_op(TensorList tensors) { - std::vector> tensor_lists; +std::vector foreach_unary_op_complex(TensorList tensors) { + std::vector> tensor_lists; std::vector vec_res; + vec_res.reserve(tensors.size()); for (const auto& t: tensors) { vec_res.emplace_back(at::native::empty_like(t)); } @@ -16,22 +17,194 @@ std::vector foreach_unary_op(TensorList tensors) { tensor_lists.emplace_back(std::move(vec_res)); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, tensors[0].scalar_type(), "foreach_unary_op_cuda", [&]() { - multi_tensor_apply<2>(tensor_lists, UnaryOpFunctor()); + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<2>(tensor_lists, + UnaryOpFunctor(), + Op()); }); return tensor_lists[1]; } template class Op> -void foreach_unary_op_(TensorList tensors) { - std::vector> tensor_lists; +void foreach_unary_op_complex_(TensorList tensors) { + std::vector> tensor_lists; tensor_lists.emplace_back(tensors.vec()); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, tensors[0].scalar_type(), "foreach_unary_op_cuda_", [&]() { - multi_tensor_apply<1>(tensor_lists, UnaryOpFunctor_()); + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<1>(tensor_lists, + UnaryOpFunctor(), + Op()); }); } -#define FOREACH_UNARY_OP(NAME, NAME1) \ +template class Op> +std::vector foreach_unary_op_complex_bfloat16(TensorList tensors) { + std::vector> tensor_lists; + std::vector vec_res; + vec_res.reserve(tensors.size()); + for (const auto& t: tensors) { + vec_res.emplace_back(at::native::empty_like(t)); + } + + tensor_lists.emplace_back(tensors.vec()); + tensor_lists.emplace_back(std::move(vec_res)); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, tensors[0].scalar_type(), "foreach_unary_op_cuda", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<2>(tensor_lists, + UnaryOpFunctor(), + Op()); + }); + return tensor_lists[1]; +} + +template class Op> +void foreach_unary_op_complex_bfloat16_(TensorList tensors) { + std::vector> tensor_lists; + tensor_lists.emplace_back(tensors.vec()); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, tensors[0].scalar_type(), "foreach_unary_op_cuda_", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<1>(tensor_lists, + UnaryOpFunctor(), + Op()); + }); +} + +template class Op> +std::vector foreach_unary_op(TensorList tensors) { + std::vector> tensor_lists; + std::vector vec_res; + vec_res.reserve(tensors.size()); + for (const auto& t: tensors) { + vec_res.emplace_back(at::native::empty_like(t)); + } + + tensor_lists.emplace_back(tensors.vec()); + tensor_lists.emplace_back(std::move(vec_res)); + + AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half, tensors[0].scalar_type(), "foreach_unary_op_cuda", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<2>(tensor_lists, + UnaryOpFunctor(), + Op()); + }); + return tensor_lists[1]; +} + +template class Op> +void foreach_unary_op_(TensorList tensors) { + std::vector> tensor_lists; + tensor_lists.emplace_back(tensors.vec()); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensors[0].scalar_type(), "foreach_unary_op_cuda_", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<1>(tensor_lists, + UnaryOpFunctor(), + Op()); + }); +} + +template class Op> +void foreach_op_unary_(TensorList tensors) { + std::vector> tensor_lists; + tensor_lists.emplace_back(tensors.vec()); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensors[0].scalar_type(), "foreach_unary_op_cuda_", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<1>(tensor_lists, + UnaryOpFunctor(), + Op()); + }); +} + +template class Op> +std::vector foreach_unary_op_bfloat16(TensorList tensors) { + std::vector> tensor_lists; + std::vector vec_res; + vec_res.reserve(tensors.size()); + for (const auto& t: tensors) { + vec_res.emplace_back(at::native::empty_like(t)); + } + + tensor_lists.emplace_back(tensors.vec()); + tensor_lists.emplace_back(std::move(vec_res)); + + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, tensors[0].scalar_type(), "foreach_unary_op_cuda", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<2>(tensor_lists, + UnaryOpFunctor(), + Op()); + }); + return tensor_lists[1]; +} + +template class Op> +void foreach_unary_op_bfloat16_(TensorList tensors) { + std::vector> tensor_lists; + tensor_lists.emplace_back(tensors.vec()); + + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, tensors[0].scalar_type(), "foreach_unary_op_cuda_", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<1>(tensor_lists, + UnaryOpFunctor(), + Op()); + }); +} + +#define FOREACH_UNARY_OP_COMPLEX(NAME, NAME1) \ +template \ +struct NAME1 { \ + __device__ T operator()(T t) const { return std::NAME(t); } \ +}; \ + \ +std::vector foreach_tensor_##NAME##_cuda(TensorList tensors) { \ + check_foreach_api_restrictions(tensors); \ + if (!can_use_fast_route(tensors)) { \ + return at::native::foreach_tensor_##NAME##_slow(tensors); \ + } \ + \ + return foreach_unary_op_complex(tensors); \ +} \ + \ +void foreach_tensor_##NAME##_cuda_(TensorList tensors) { \ + check_foreach_api_restrictions(tensors); \ + if (!can_use_fast_route(tensors)) { \ + return at::native::foreach_tensor_##NAME##_slow_(tensors); \ + } \ + \ + foreach_unary_op_complex_(tensors); \ +} + +#define FOREACH_UNARY_OP_COMPLEX_BFLOAT16(NAME, NAME1) \ template \ struct NAME1 { \ __device__ T operator()(T t) const { return std::NAME(t); } \ @@ -39,7 +212,30 @@ struct NAME1 { \ \ std::vector foreach_tensor_##NAME##_cuda(TensorList tensors) { \ check_foreach_api_restrictions(tensors); \ + if (!can_use_fast_route(tensors)) { \ + return at::native::foreach_tensor_##NAME##_slow(tensors); \ + } \ \ + return foreach_unary_op_complex_bfloat16(tensors); \ +} \ + \ +void foreach_tensor_##NAME##_cuda_(TensorList tensors) { \ + check_foreach_api_restrictions(tensors); \ + if (!can_use_fast_route(tensors)) { \ + return at::native::foreach_tensor_##NAME##_slow_(tensors); \ + } \ + \ + foreach_unary_op_complex_bfloat16_(tensors); \ +} + +#define FOREACH_UNARY_OP(NAME, NAME1) \ +template \ +struct NAME1 { \ + __device__ T operator()(T t) const { return std::NAME(t); } \ +}; \ + \ +std::vector foreach_tensor_##NAME##_cuda(TensorList tensors) { \ + check_foreach_api_restrictions(tensors); \ if (!can_use_fast_route(tensors)) { \ return at::native::foreach_tensor_##NAME##_slow(tensors); \ } \ @@ -49,7 +245,6 @@ std::vector foreach_tensor_##NAME##_cuda(TensorList tensors) { \ \ void foreach_tensor_##NAME##_cuda_(TensorList tensors) { \ check_foreach_api_restrictions(tensors); \ - \ if (!can_use_fast_route(tensors)) { \ return at::native::foreach_tensor_##NAME##_slow_(tensors); \ } \ @@ -57,7 +252,384 @@ void foreach_tensor_##NAME##_cuda_(TensorList tensors) { \ foreach_unary_op_(tensors); \ } -FOREACH_UNARY_OP(exp, Exp); -FOREACH_UNARY_OP(sqrt, Sqrt); +#define FOREACH_UNARY_OP_BFLOAT16(NAME, NAME1) \ +template \ +struct NAME1 { \ + __device__ T operator()(T t) const { return std::NAME(t); } \ +}; \ + \ +std::vector foreach_tensor_##NAME##_cuda(TensorList tensors) { \ + check_foreach_api_restrictions(tensors); \ + \ + if (!can_use_fast_route(tensors)) { \ + return at::native::foreach_tensor_##NAME##_slow(tensors); \ + } \ + \ + return foreach_unary_op_bfloat16(tensors); \ +} \ + \ +void foreach_tensor_##NAME##_cuda_(TensorList tensors) { \ + check_foreach_api_restrictions(tensors); \ + \ + if (!can_use_fast_route(tensors)) { \ + return at::native::foreach_tensor_##NAME##_slow_(tensors); \ + } \ + \ + foreach_unary_op_bfloat16_(tensors); \ +} + +FOREACH_UNARY_OP(ceil, Ceil); +FOREACH_UNARY_OP(erfc, Erfc); +FOREACH_UNARY_OP(expm1, Expm1); +FOREACH_UNARY_OP(floor, Floor); +FOREACH_UNARY_OP(lgamma, Lgamma); + +FOREACH_UNARY_OP_BFLOAT16(log1p, Log1p); +FOREACH_UNARY_OP_BFLOAT16(erf, Erf); + +FOREACH_UNARY_OP_COMPLEX(acos, Acos); +FOREACH_UNARY_OP_COMPLEX(asin, Asin); +FOREACH_UNARY_OP_COMPLEX(atan, Atan); +FOREACH_UNARY_OP_COMPLEX(cosh, Cosh); +FOREACH_UNARY_OP_COMPLEX(tan, Tan); +FOREACH_UNARY_OP_COMPLEX(sin, Sin); +FOREACH_UNARY_OP_COMPLEX(sinh, Sinh); + +FOREACH_UNARY_OP_COMPLEX_BFLOAT16(exp, Exp); +FOREACH_UNARY_OP_COMPLEX_BFLOAT16(sqrt, Sqrt); +FOREACH_UNARY_OP_COMPLEX_BFLOAT16(cos, Cos); +FOREACH_UNARY_OP_COMPLEX_BFLOAT16(tanh, Tanh); +FOREACH_UNARY_OP_COMPLEX_BFLOAT16(log, Log); +FOREACH_UNARY_OP_COMPLEX_BFLOAT16(log10, Log10); +FOREACH_UNARY_OP_COMPLEX_BFLOAT16(log2, Log2); + +// +// Special cases +// +std::vector foreach_tensor_neg_cuda(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_neg_slow(tensors); + } + + return foreach_unary_op_complex_bfloat16(tensors); +} + +void foreach_tensor_neg_cuda_(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_neg_slow_(tensors); + } + + foreach_unary_op_complex_bfloat16_(tensors); +} + +template \ +struct Round { \ + __device__ T operator()(T t) const { return std::nearbyint(t); } \ +}; + +std::vector foreach_tensor_round_cuda(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_round_slow(tensors); + } + + return foreach_unary_op(tensors); +} + +void foreach_tensor_round_cuda_(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_round_slow_(tensors); + } + + foreach_unary_op_(tensors); +} + +// Abs have to go via slow path in case of a complex type. +// This is because foreach kernels can't return a different dtype than passed, while +// abs with complex input will produce float output. +template +struct Abs { + __device__ T operator()(T t) const { return std::abs(t); } +}; + +std::vector foreach_tensor_abs_cuda(TensorList tensors) { + check_foreach_api_restrictions(tensors); + bool has_complex = false; + for (auto t : tensors) { + if (at::isComplexType(t.scalar_type())) { + has_complex = true; + } + } + + if (!can_use_fast_route(tensors) || has_complex) { + return at::native::foreach_tensor_abs_slow(tensors); + } + + return foreach_unary_op_complex_bfloat16(tensors); +} + +void foreach_tensor_abs_cuda_(TensorList tensors) { + check_foreach_api_restrictions(tensors); + bool has_complex = false; + for (auto t : tensors) { + if (at::isComplexType(t.scalar_type())) { + has_complex = true; + } + } + + if (!can_use_fast_route(tensors) || has_complex) { + return at::native::foreach_tensor_abs_slow_(tensors); + } + + foreach_unary_op_complex_bfloat16_(tensors); +} + +template +struct Trunc { + __device__ T operator()(T t) const { return t - std::trunc(t); } +}; + +std::vector foreach_tensor_frac_cuda(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_frac_slow(tensors); + } + + std::vector> tensor_lists; + std::vector vec_res; + vec_res.reserve(tensors.size()); + for (const auto& t: tensors) { + vec_res.emplace_back(at::native::empty_like(t)); + } + + tensor_lists.emplace_back(tensors.vec()); + tensor_lists.emplace_back(std::move(vec_res)); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensors[0].scalar_type(), "foreach_unary_op_cuda", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<2>(tensor_lists, + UnaryOpFunctor(), + Trunc()); + }); + return tensor_lists[1]; +} + +void foreach_tensor_frac_cuda_(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_frac_slow_(tensors); + } + + std::vector> tensor_lists; + tensor_lists.emplace_back(tensors.vec()); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensors[0].scalar_type(), "foreach_unary_op_cuda_", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<1>(tensor_lists, + UnaryOpFunctor(), + Trunc()); + }); +} + +template +struct Sigmoid { + T one = T(1); + __device__ T operator()(T t) const { return (one / (one + std::exp(-t))); } +}; + +std::vector foreach_tensor_sigmoid_cuda(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_sigmoid_slow(tensors); + } + + std::vector> tensor_lists; + std::vector vec_res; + vec_res.reserve(tensors.size()); + for (const auto& t: tensors) { + vec_res.emplace_back(at::native::empty_like(t)); + } + + tensor_lists.emplace_back(tensors.vec()); + tensor_lists.emplace_back(std::move(vec_res)); + + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, tensors[0].scalar_type(), "foreach_unary_op_cuda", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<2>(tensor_lists, + UnaryOpFunctor(), + Sigmoid()); + }); + return tensor_lists[1]; +} + +void foreach_tensor_sigmoid_cuda_(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_sigmoid_slow_(tensors); + } + + std::vector> tensor_lists; + tensor_lists.emplace_back(tensors.vec()); + + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, tensors[0].scalar_type(), "foreach_unary_op_cuda_", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<1>(tensor_lists, + UnaryOpFunctor(), + Sigmoid()); + }); +} + +template +struct Reciprocal { + T one = T(1); + __device__ T operator()(T t) const { return (one / t); } +}; + +std::vector foreach_tensor_reciprocal_cuda(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_reciprocal_slow(tensors); + } + + std::vector> tensor_lists; + std::vector vec_res; + vec_res.reserve(tensors.size()); + for (const auto& t: tensors) { + vec_res.emplace_back(at::native::empty_like(t)); + } + + tensor_lists.emplace_back(tensors.vec()); + tensor_lists.emplace_back(std::move(vec_res)); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, tensors[0].scalar_type(), "foreach_unary_op_cuda", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<2>(tensor_lists, + UnaryOpFunctor(), + Reciprocal()); + }); + return tensor_lists[1]; +} + +void foreach_tensor_reciprocal_cuda_(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_reciprocal_slow_(tensors); + } + + std::vector> tensor_lists; + tensor_lists.emplace_back(tensors.vec()); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, tensors[0].scalar_type(), "foreach_unary_op_cuda_", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<1>(tensor_lists, + UnaryOpFunctor(), + Reciprocal()); + }); +} + +template +struct Truncf { + __device__ T operator()(T t) const { return std::trunc(t); } +}; + +std::vector foreach_tensor_trunc_cuda(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_trunc_slow(tensors); + } + + std::vector> tensor_lists; + std::vector vec_res; + vec_res.reserve(tensors.size()); + for (const auto& t: tensors) { + vec_res.emplace_back(at::native::empty_like(t)); + } + + tensor_lists.emplace_back(tensors.vec()); + tensor_lists.emplace_back(std::move(vec_res)); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensors[0].scalar_type(), "foreach_unary_op_cuda", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<2>(tensor_lists, + UnaryOpFunctor(), + Truncf()); + }); + return tensor_lists[1]; +} + +void foreach_tensor_trunc_cuda_(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_trunc_slow_(tensors); + } + + std::vector> tensor_lists; + tensor_lists.emplace_back(tensors.vec()); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensors[0].scalar_type(), "foreach_unary_op_cuda_", [&]() { + using opmath_t = get_opmath_t::opmath_t; + multi_tensor_apply<1>(tensor_lists, + UnaryOpFunctor(), + Truncf()); + }); +} + +void foreach_tensor_zero_cuda_(TensorList tensors) { + check_foreach_api_restrictions(tensors); + + if (!can_use_fast_route(tensors)) { + return at::native::foreach_tensor_zero_slow_(tensors); + } + + std::vector> tensor_lists; + tensor_lists.emplace_back(tensors.vec()); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, tensors[0].scalar_type(), "foreach_zero_cuda_", [&]() { + multi_tensor_apply<1>(tensor_lists, + ZeroFunctor()); + }); +} }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/FractionalMaxPool2d.cu b/aten/src/ATen/native/cuda/FractionalMaxPool2d.cu index f1932e64add3f..41fc2dea5856d 100644 --- a/aten/src/ATen/native/cuda/FractionalMaxPool2d.cu +++ b/aten/src/ATen/native/cuda/FractionalMaxPool2d.cu @@ -205,9 +205,9 @@ void fractional_max_pool2d_out_cuda_template( <<>>( devOutput, devIndices, devInput, devSamples, poolSizeH, poolSizeW); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); - AT_CUDA_CHECK(cudaGetLastError()); } void fractional_max_pool2d_backward_out_cuda_template( @@ -272,9 +272,9 @@ void fractional_max_pool2d_backward_out_cuda_template( fractional_max_pool2d_backward_out_cuda_frame <<>>( devGradInput, devGradOutput, devIndices); - } - ); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + ); } }// namespace diff --git a/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu b/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu index 85be04ba20d3d..0d492de485708 100644 --- a/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu +++ b/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu @@ -241,9 +241,9 @@ void fractional_max_pool3d_out_cuda_template( randomSamples.packed_accessor64(), poolSizeT, poolSizeH, poolSizeW ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); - AT_CUDA_CHECK(cudaGetLastError()); } void fractional_max_pool3d_backward_out_cuda_template( @@ -327,9 +327,9 @@ void fractional_max_pool3d_backward_out_cuda_template( gradOutput_.packed_accessor64(), indices_.packed_accessor64() ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); - AT_CUDA_CHECK(cudaGetLastError()); } }// namespace diff --git a/aten/src/ATen/native/cuda/FunctionOfAMatrixUtilsKernel.cu b/aten/src/ATen/native/cuda/FunctionOfAMatrixUtilsKernel.cu index 0e90ae4b3cbbc..180385aaf0520 100644 --- a/aten/src/ATen/native/cuda/FunctionOfAMatrixUtilsKernel.cu +++ b/aten/src/ATen/native/cuda/FunctionOfAMatrixUtilsKernel.cu @@ -37,8 +37,7 @@ void _lauch_kernel(int total_n_elems, const func_t& f) { auto stream = at::cuda::getCurrentCUDAStream(); _elemwise_kernel <<>>(total_n_elems, f); - - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template diff --git a/aten/src/ATen/native/cuda/GcdLcmKernel.cu b/aten/src/ATen/native/cuda/GcdLcmKernel.cu new file mode 100644 index 0000000000000..92e5ef73d6c28 --- /dev/null +++ b/aten/src/ATen/native/cuda/GcdLcmKernel.cu @@ -0,0 +1,33 @@ +#include +#include +#include +#include +#include +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. + +namespace at { namespace native { + +void gcd_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "gcd_cuda", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { + return calc_gcd(a, b); + }); + }); +} + +void lcm_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lcm_cuda", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { + scalar_t g = calc_gcd(a, b); + return (g == 0) ? 0 : ::abs(a / g * b); + }); + }); +} + +REGISTER_DISPATCH(gcd_stub, &gcd_kernel_cuda); +REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda); + +}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/GridSampler.cu b/aten/src/ATen/native/cuda/GridSampler.cu index 023167109af27..a08c13037e348 100644 --- a/aten/src/ATen/native/cuda/GridSampler.cu +++ b/aten/src/ATen/native/cuda/GridSampler.cu @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -50,11 +51,11 @@ namespace { const index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; // get the corresponding input x, y co-ordinates from grid - scalar_t ix = grid.data[grid_offset]; - scalar_t iy = grid.data[grid_offset + grid_sCoor]; + scalar_t x = grid.data[grid_offset]; + scalar_t y = grid.data[grid_offset + grid_sCoor]; - ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); - iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); + scalar_t ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); + scalar_t iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); if (interpolation_mode == GridSamplerInterpolation::Bilinear) { // get NE, NW, SE, SW pixel values from (x, y) @@ -105,6 +106,38 @@ namespace { *out_ptr_NCHW = static_cast(0); } } + } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) { + + ix = grid_sampler_unnormalize(x, inp_W, align_corners); + iy = grid_sampler_unnormalize(y, inp_H, align_corners); + + scalar_t ix_nw = ::floor(ix); + scalar_t iy_nw = ::floor(iy); + + const scalar_t tx = ix - ix_nw; + const scalar_t ty = iy - iy_nw; + + auto inp_ptr_NC = input.data + n * inp_sN; + auto out_ptr_NCHW = output.data + n * out_sN + h * out_sH + w * out_sW; + for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { + scalar_t coefficients[4]; + + for (index_t i = 0; i < 4; ++i) { + coefficients[i] = cubic_interp1d( + get_value_bounded(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + tx); + } + + *out_ptr_NCHW = cubic_interp1d( + coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + ty); + } } } } @@ -300,13 +333,13 @@ namespace { const auto grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; // get the corresponding input x, y co-ordinates from grid - scalar_t ix = grid.data[grid_offset]; - scalar_t iy = grid.data[grid_offset + grid_sCoor]; + scalar_t x = grid.data[grid_offset]; + scalar_t y = grid.data[grid_offset + grid_sCoor]; // multipliers for gradients on ix and iy scalar_t gix_mult, giy_mult; - ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult); - iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult); + scalar_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult); + scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult); if (interpolation_mode == GridSamplerInterpolation::Bilinear) { // get NE, NW, SE, SW pixel values from (x, y) @@ -387,6 +420,57 @@ namespace { scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW; gGrid_ptr_NHW[0] = static_cast(0); gGrid_ptr_NHW[1] = static_cast(0); + } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) { + + ix = grid_sampler_unnormalize_set_grad(x, inp_W, align_corners, &gix_mult); + iy = grid_sampler_unnormalize_set_grad(y, inp_H, align_corners, &giy_mult); + + scalar_t ix_nw = ::floor(ix); + scalar_t iy_nw = ::floor(iy); + + const scalar_t tx = ix - ix_nw; + const scalar_t ty = iy - iy_nw; + + scalar_t x_coeffs[4]; + scalar_t y_coeffs[4]; + scalar_t x_coeffs_grad[4]; + scalar_t y_coeffs_grad[4]; + + get_cubic_upsampling_coefficients(x_coeffs, tx); + get_cubic_upsampling_coefficients(y_coeffs, ty); + get_cubic_coefficients_grad(x_coeffs_grad, tx); + get_cubic_coefficients_grad(y_coeffs_grad, ty); + + scalar_t gix = static_cast(0); + scalar_t giy = static_cast(0); + + scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW; + scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN; + scalar_t *inp_ptr_NC = input.data + n * inp_sN; + + for (index_t c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC+= inp_sC) { + scalar_t gOut = *gOut_ptr_NCHW; + + for (index_t i = 0; i < 4; ++i) { + for (index_t j = 0; j < 4; ++j) { + + // set input gradient + add_value_bounded(gInp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j, inp_W, inp_H, + gInp_sW, gInp_sH, gOut * x_coeffs[i] * y_coeffs[j], padding_mode, align_corners); + + // set grid gradient + scalar_t val = get_value_bounded(inp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j, + inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners); + + gix -= val * x_coeffs_grad[i] * y_coeffs[j] * gOut; + giy -= val * y_coeffs_grad[j] * x_coeffs[i] * gOut; + } + } + } + + scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW; + gGrid_ptr_NHW[0] = gix_mult * gix; + gGrid_ptr_NHW[1] = giy_mult * giy; } } } @@ -624,6 +708,7 @@ Tensor grid_sampler_2d_cuda(const Tensor& input, const Tensor& grid, static_cast(interpolation_mode), static_cast(padding_mode), align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { grid_sampler_2d_kernel <<>>( @@ -634,6 +719,7 @@ Tensor grid_sampler_2d_cuda(const Tensor& input, const Tensor& grid, static_cast(interpolation_mode), static_cast(padding_mode), align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } }); } @@ -663,16 +749,18 @@ Tensor grid_sampler_3d_cuda(const Tensor& input, const Tensor& grid, static_cast(interpolation_mode), static_cast(padding_mode), align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { - grid_sampler_3d_kernel - <<>>( - count, - getTensorInfo(input), - getTensorInfo(grid), - getTensorInfo(output), - static_cast(interpolation_mode), - static_cast(padding_mode), - align_corners); + grid_sampler_3d_kernel + <<>>( + count, + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(output), + static_cast(interpolation_mode), + static_cast(padding_mode), + align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } }); } @@ -708,6 +796,7 @@ grid_sampler_2d_backward_cuda(const Tensor& grad_output, const Tensor& input, static_cast(interpolation_mode), static_cast(padding_mode), align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { grid_sampler_2d_backward_kernel <<>>( @@ -720,6 +809,7 @@ grid_sampler_2d_backward_cuda(const Tensor& grad_output, const Tensor& input, static_cast(interpolation_mode), static_cast(padding_mode), align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } }); } @@ -756,6 +846,7 @@ grid_sampler_3d_backward_cuda(const Tensor& grad_output, const Tensor& input, static_cast(interpolation_mode), static_cast(padding_mode), align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { grid_sampler_3d_backward_kernel <<>>( @@ -768,6 +859,7 @@ grid_sampler_3d_backward_cuda(const Tensor& grad_output, const Tensor& input, static_cast(interpolation_mode), static_cast(padding_mode), align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } }); } diff --git a/aten/src/ATen/native/cuda/GridSampler.cuh b/aten/src/ATen/native/cuda/GridSampler.cuh index 4a94a3fda1bb7..0c4acd1be41c0 100644 --- a/aten/src/ATen/native/cuda/GridSampler.cuh +++ b/aten/src/ATen/native/cuda/GridSampler.cuh @@ -7,7 +7,7 @@ namespace at { namespace native { namespace detail { - enum class GridSamplerInterpolation {Bilinear, Nearest}; + enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic}; enum class GridSamplerPadding {Zeros, Border, Reflection}; } // namespace detail @@ -153,15 +153,11 @@ scalar_t safe_downgrade_to_int_range(scalar_t x){ return x; } -// Computes the pixel source index value for a grid coordinate -template +template static __forceinline__ __device__ -scalar_t grid_sampler_compute_source_index( - scalar_t coord, - int size, - GridSamplerPadding padding_mode, - bool align_corners) { - coord = grid_sampler_unnormalize(coord, size, align_corners); +scalar_t compute_coordinates(scalar_t coord, int size, + GridSamplerPadding padding_mode, + bool align_corners) { if (padding_mode == GridSamplerPadding::Border) { // clip coordinates to image borders coord = clip_coordinates(coord, size); @@ -176,7 +172,20 @@ scalar_t grid_sampler_compute_source_index( coord = clip_coordinates(coord, size); } - coord = safe_downgrade_to_int_range(coord); + coord = safe_downgrade_to_int_range(coord); + return coord; +} + +// Computes the pixel source index value for a grid coordinate +template +static __forceinline__ __device__ +scalar_t grid_sampler_compute_source_index( + scalar_t coord, + int size, + GridSamplerPadding padding_mode, + bool align_corners) { + coord = grid_sampler_unnormalize(coord, size, align_corners); + coord = compute_coordinates(coord, size, padding_mode, align_corners); return coord; } @@ -224,6 +233,25 @@ bool within_bounds_3d(int d, int h, int w, int D, int H, int W) { return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; } +template +static __forceinline__ __device__ +scalar_t get_value_bounded( + scalar_t *data, scalar_t x, scalar_t y, int W, int H, int sW, int sH, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int ix = static_cast(x); + int iy = static_cast(y); + + if (within_bounds_2d(iy, ix, H, W)) { + return data[iy * sH + ix * sW]; + } + return static_cast(0); +} + template static __forceinline__ __device__ void safe_add_2d(scalar_t *data, int h, int w, @@ -244,4 +272,44 @@ void safe_add_3d(scalar_t *data, int d, int h, int w, } } +template +static __forceinline__ __device__ +void add_value_bounded( + scalar_t* data, scalar_t x, scalar_t y, int W, int H, int sW, int sH, + scalar_t delta, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int ix = static_cast(x); + int iy = static_cast(y); + + safe_add_2d(data, iy, ix, sH, sW, H, W, delta); +} + +// Calculate the differential of the cubic convolution, i.e. `d coeff / d x` +template +static __forceinline__ __device__ +void get_cubic_coefficients_grad( + scalar_t coeffs[4], + scalar_t t) { + + // Must be the same as forward calculation in + // aten/src/ATen/native/cuda/UpSample.cuh:get_cubic_upsample_coefficients + scalar_t A = -0.75; + + scalar_t x; + x = -1 - t; // 1 < x = |-1 - tx| < 2 + coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A; + x = -t; // x = |0 - tx| <= 1 + coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 1 - t; // x = |1 - tx| <= 1 + coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 2 - t; // 1 < x = |2 - tx| < 2 + coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A; +} + + }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/IGammaKernel.cu b/aten/src/ATen/native/cuda/IGammaKernel.cu new file mode 100644 index 0000000000000..d74f09573c621 --- /dev/null +++ b/aten/src/ATen/native/cuda/IGammaKernel.cu @@ -0,0 +1,542 @@ +#include +#include +#include +#include +#include +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. + +namespace { + +/* + * This implementation of the regularized incomplete gamma functions and + * their helper functions are derived from the implementation of SciPy's + * gammainc, Cephes's igam and igamc, and Boost's Lanczos approximations. + * See NOTICE for the licenses. + */ +// regularized lower & upper incomplete gamma +template +__host__ __device__ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M, + const scalar_t denom[], int64_t N) { + // evaluating rational function, i.e., the ratio of two polynomials + // the coefficients for numerator are given by `num` while coeffs for + // denumerator are given by `denom` + + using accscalar_t = at::acc_type; + int64_t i, dir; + accscalar_t y, num_ans, denom_ans; + accscalar_t absx = ::fabs(x); + const accscalar_t *p; + + if (absx > 1) { + /* Evaluate as a polynomial in 1/x. */ + dir = -1; + p = num + M; + y = 1 / x; + } + else { + dir = 1; + p = num; + y = x; + } + + /* Evaluate the numerator */ + num_ans = *p; + p += dir; + for (i = 1; i <= M; i++) { + num_ans = num_ans * y + *p; + p += dir; + } + /* Evaluate the denominator */ + if (absx > 1) { + p = denom + N; + } + else { + p = denom; + } + + denom_ans = *p; + p += dir; + for (i = 1; i <= N; i++) { + denom_ans = denom_ans * y + *p; + p += dir; + } + if (absx > 1) { + i = N - M; + return ::pow(x, static_cast(i)) * num_ans / denom_ans; + } + else { + return num_ans / denom_ans; + } +} + +template +__host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) { + // lanczos approximation + using accscalar_t = at::acc_type; + + static const accscalar_t lanczos_sum_expg_scaled_num[13] = { + 0.006061842346248906525783753964555936883222, + 0.5098416655656676188125178644804694509993, + 19.51992788247617482847860966235652136208, + 449.9445569063168119446858607650988409623, + 6955.999602515376140356310115515198987526, + 75999.29304014542649875303443598909137092, + 601859.6171681098786670226533699352302507, + 3481712.15498064590882071018964774556468, + 14605578.08768506808414169982791359218571, + 43338889.32467613834773723740590533316085, + 86363131.28813859145546927288977868422342, + 103794043.1163445451906271053616070238554, + 56906521.91347156388090791033559122686859 + }; + static const accscalar_t lanczos_sum_expg_scaled_denom[13] = { + 1., + 66., + 1925., + 32670., + 357423., + 2637558., + 13339535., + 45995730., + 105258076., + 150917976., + 120543840., + 39916800., + 0 + }; + return ratevl(static_cast(x), lanczos_sum_expg_scaled_num, + sizeof(lanczos_sum_expg_scaled_num) / sizeof(lanczos_sum_expg_scaled_num[0]) - 1, + lanczos_sum_expg_scaled_denom, + sizeof(lanczos_sum_expg_scaled_denom) / sizeof(lanczos_sum_expg_scaled_denom[0]) - 1); +} + +template +__host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) { + // compute x^a * exp(-a) / gamma(a) + // corrected from (15) and (16) in [igam2] by replacing exp(x - a) with + // exp(a - x). + + using accscalar_t = at::acc_type; + accscalar_t ax, fac, res, num, numfac; + static accscalar_t MAXLOG = std::is_same::value ? + 7.09782712893383996843E2 : 88.72283905206835; + static accscalar_t EXP1 = 2.718281828459045; + static accscalar_t lanczos_g = 6.024680040776729583740234375; + + if (::fabs(a - x) > 0.4 * ::fabs(a)) { + ax = a * ::log(x) - x - ::lgamma(a); + if (ax < -MAXLOG) { + return 0.0; + } + return ::exp(ax); + } + + fac = a + lanczos_g - 0.5; + res = ::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a); + + if ((a < 200) && (x < 200)) { + res *= ::exp(a - x) * ::pow(x / fac, a); + } + else { + num = x - a - lanczos_g + 0.5; + numfac = num / fac; + res *= ::exp(a * (::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac); + } + return res; +} + +template +__host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) { + // Compute igam using DLMF 8.11.4. [igam1] + + using accscalar_t = at::acc_type; + static accscalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static int MAXITER = 2000; + + int i; + accscalar_t ans, ax, c, r; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* power series */ + r = a; + c = 1.0; + ans = 1.0; + + for (i = 0; i < MAXITER; i++) { + r += 1.0; + c *= x / r; + ans += c; + if (c <= MACHEP * ans) { + break; + } + } + return (ans * ax / a); +} + +template +__host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in + // _igam_helper_series but extra care is taken to avoid cancellation. + + using accscalar_t = at::acc_type; + int n; + accscalar_t fac = 1; + accscalar_t sum = 0; + accscalar_t term, logx; + static accscalar_t MAXITER = 2000; + static accscalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + + for (n = 1; n < MAXITER; n++) { + fac *= -x / n; + term = fac / (a + n); + sum += term; + if (::fabs(term) <= MACHEP * ::fabs(sum)) { + break; + } + } + + logx = ::log(x); + term = -::expm1(a * logx - ::lgamma(1+a)); + return term - ::exp(a * logx - ::lgamma(a)) * sum; +} + +template +__host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) { + // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] + + using accscalar_t = at::acc_type; + static const accscalar_t d[25][25] = + {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15}, + {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15}, + {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15}, + {6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4, 2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7, 1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6, -2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8, -1.9111168485973654e-8, 2.3928620439808118e-12, 2.0620131815488798e-9, -9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14, -2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12, 6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14, 2.0453671226782849e-14}, + {-8.618882909167117e-4, 7.8403922172006663e-4, -2.9907248030319018e-4, -1.4638452578843418e-6, 6.6414982154651222e-5, -3.9683650471794347e-5, 1.1375726970678419e-5, 2.5074972262375328e-10, -1.6954149536558306e-6, 8.9075075322053097e-7, -2.2929348340008049e-7, 2.956794137544049e-11, 2.8865829742708784e-8, -1.4189739437803219e-8, 3.4463580499464897e-9, -2.3024517174528067e-13, -3.9409233028046405e-10, 1.8602338968504502e-10, -4.356323005056618e-11, 1.2786001016296231e-15, 4.6792750266579195e-12, -2.1492464706134829e-12, 4.9088156148096522e-13, -6.3385914848915603e-18, -5.0453320690800944e-14}, + {-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4, -1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7, -1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6, -3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7, 4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9, 3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15, 9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11, -3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13, -1.3249659916340829e-13}, + {5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4, 7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5, -1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6, -2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13, -8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8, 8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10, 2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11, 1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18, 3.6902800842763467e-13}, + {3.4436760689237767e-4, 5.1717909082605922e-5, -3.3493161081142236e-4, 2.812695154763237e-4, -1.0976582244684731e-4, -1.2741009095484485e-7, 2.7744451511563644e-5, -1.8263488805711333e-5, 5.7876949497350524e-6, 4.9387589339362704e-10, -1.0595367014026043e-6, 6.1667143761104075e-7, -1.7562973359060462e-7, -1.2974473287015439e-12, 2.695423606288966e-8, -1.4578352908731271e-8, 3.887645959386175e-9, -3.8810022510194121e-17, -5.3279941738772867e-10, 2.7437977643314845e-10, -6.9957960920705679e-11, 2.5899863874868481e-17, 8.8566890996696381e-12, -4.403168815871311e-12, 1.0865561947091654e-12}, + {-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4, -6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4, 4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5, 6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11, 3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8, 6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9, -1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10, -1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18, -3.3721464474854592e-12}, + {-5.9676129019274625e-4, -7.2048954160200106e-5, 6.7823088376673284e-4, -6.4014752602627585e-4, 2.7750107634328704e-4, 1.8197008380465151e-7, -8.4795071170685032e-5, 6.105192082501531e-5, -2.1073920183404862e-5, -8.8585890141255994e-10, 4.5284535953805377e-6, -2.8427815022504408e-6, 8.7082341778646412e-7, 3.6886101871706965e-12, -1.5344695190702061e-7, 8.862466778790695e-8, -2.5184812301826817e-8, -1.0225912098215092e-14, 3.8969470758154777e-9, -2.1267304792235635e-9, 5.7370135528051385e-10, -1.887749850169741e-19, -8.0931538694657866e-11, 4.2382723283449199e-11, -1.1002224534207726e-11}, + {1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3, 9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4, -1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5, -3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11, -2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7, -1.7567877666323291e-13, 7.0145043163668257e-8, -4.040787734999483e-8, 1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9, 9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18, 3.7647749553543836e-11}, + {1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3, 2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7, 3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4, 2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5, -5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6, -6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14, -3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9, -1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10, 1.3481607129399749e-10}, + {-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3, -2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3, 8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4, 1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10, 1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6, 7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7, -1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8, -1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20, -5.0423112718105824e-10}, + {-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3, -9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6, -2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4, -6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4, 4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5, 6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13, 3.9756029041993247e-7, -2.3956211978815887e-7, 7.1182883382145864e-8, 8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9, -1.9661464453856102e-9}, + {1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2, 7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2, -5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3, -1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10, -1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5, -3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6, 1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7, 1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17, 7.9795091026746235e-9}, + {3.0249124160905891e-2, 2.4817436002649977e-3, -4.9939134373457022e-2, 5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6, 1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3, 3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3, -4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4, -7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12, -5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6, -5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7, 3.3654425209171788e-8}, + {-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1, -3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2, 4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2, 1.4182925050891573e-2, -5.6214161633747336e-3, -2.39598509186381e-9, 1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4, 1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5, -2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6, -2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16, -1.4729737374018841e-7}, + {-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1, -4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5, -1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2, -1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2, 5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3, 1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12, 8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5, 3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6, -6.6812849447625594e-7}, + {7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968, 1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1, -4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1, -1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8, -2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3, -1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3, 3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5, 5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14, 3.1369106244517615e-6}, + {1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906, 4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4, 1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1, 1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1, -8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2, -1.7813678334552311e-2, 6.3970330388900056e-3, 4.9430807090480523e-11, -1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4, 9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5, 1.5227271505597605e-5}, + {-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1, -1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1, 5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816, 2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7, 3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1, 8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2, -7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3, -1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11, -7.6340103696869031e-5}, + {-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1, -5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3, -2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1, -1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195, 1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1, 3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10, 3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3, -1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3, -3.9479941246822517e-4}, + {7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2, 1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2, -8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1, -3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7, -6.2716159907747034, 5.1168999071852637, -2.0319658112299095, -4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1, 1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2, 2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6, 2.1250180774699461e-3}, + {2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2, 7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2, 3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2, 1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1, -2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373, -7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7, -8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1, 1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2, 1.5109265210467774e-2}, + {-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3, -1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3, 1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2, 7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6, 1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1, -7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1, -4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468, -7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1, 4.8683443692930507e-1}}; + + int k, n, sgn; + int maxpow = 0; + static accscalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + accscalar_t lambda = x / a; + accscalar_t sigma = (x - a) / a; + accscalar_t eta, res, ck, ckterm, term, absterm; + accscalar_t absoldterm = INFINITY; + accscalar_t etapow[25] = {1}; + accscalar_t sum = 0; + accscalar_t afac = 1; + + if (igam) { + sgn = -1; + } + else { + sgn = 1; + } + + if (lambda > 1) { + eta = ::sqrt(-2 * (::log1p(sigma) - sigma)); + } + else if (lambda < 1) { + eta = -::sqrt(-2 * (::log1p(sigma) - sigma)); + } + else { + eta = 0; + } + res = 0.5 * ::erfc(sgn * eta * ::sqrt(a / 2)); + + for (k = 0; k < 25; k++) { + ck = d[k][0]; + for (n = 1; n < 25; n++) { + if (n > maxpow) { + etapow[n] = eta * etapow[n-1]; + maxpow += 1; + } + ckterm = d[k][n]*etapow[n]; + ck += ckterm; + if (std::fabs(ckterm) < MACHEP * std::fabs(ck)) { + break; + } + } + term = ck * afac; + absterm = std::fabs(term); + if (absterm > absoldterm) { + break; + } + sum += term; + if (absterm < MACHEP * std::fabs(sum)) { + break; + } + absoldterm = absterm; + afac /= a; + } + res += sgn * ::exp(-0.5 * a * eta * eta) * sum / ::sqrt(2 * 3.1415926535 * a); + + return res; +} + +template +__host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.9.2. [igam1] + + using accscalar_t = at::acc_type; + int i; + accscalar_t ans, ax, c, yc, r, t, y, z; + accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2; + int MAXITER = 2000; + static accscalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static accscalar_t BIG = std::is_same::value ? + 4.503599627370496e15 : 16777216.; + static accscalar_t BIGINV = std::is_same::value ? + 2.22044604925031308085e-16 : 5.9604644775390625E-8; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* continued fraction */ + y = 1.0 - a; + z = x + y + 1.0; + c = 0.0; + pkm2 = 1.0; + qkm2 = x; + pkm1 = x + 1.0; + qkm1 = z * x; + ans = pkm1 / qkm1; + + for (i = 0; i < MAXITER; i++) { + c += 1.0; + y += 1.0; + z += 2.0; + yc = y * c; + pk = pkm1 * z - pkm2 * yc; + qk = qkm1 * z - qkm2 * yc; + if (qk != 0) { + r = pk / qk; + t = ::fabs((ans - r) / r); + ans = r; + } + else { + t = 1.0; + } + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + if (::fabs(pk) > BIG) { + pkm2 *= BIGINV; + pkm1 *= BIGINV; + qkm2 *= BIGINV; + qkm1 *= BIGINV; + } + if (t <= MACHEP) { + break; + } + } + return ans * ax; +} + +template +__noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) { + /* the calculation of the regularized upper incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.4 [igam1]) + * - if x > 1.1 and x < a, using the substraction from the regularized lower + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (5) + */ + + using accscalar_t = at::acc_type; + accscalar_t absxma_a; + + static accscalar_t SMALL = 20.0; + static accscalar_t LARGE = 200.0; + static accscalar_t SMALLRATIO = 0.3; + static accscalar_t LARGERATIO = 4.5; + + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 0.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 1.0; + } + else if (::isinf(static_cast(a))) { + if (::isinf(static_cast(x))) { + return std::numeric_limits::quiet_NaN(); + } + return 1.0; + } + else if (::isinf(static_cast(x))) { + return 0.0; + } + + absxma_a = ::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 0); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / ::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 0); + } + + if (x > 1.1) { + if (x < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_continued_fraction(a, x); + } + } + else if (x <= 0.5) { + if (-0.4 / ::log(x) < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } + else { + if (x * 1.1 < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } +} + +// NOTE: this __noinline__ is important -- otherwise, observed compile times significantly +// increase. The same kernel seems to get recompiled mulitple times via gpu_kernel_with_scalars, +// multiple dtypes, etc. +template +__noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) { + /* the calculation of the regularized lower incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.3 [igam1]) + * - if x > 1 and x > a, using the substraction from the regularized upper + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (4) + */ + + using accscalar_t = at::acc_type; + accscalar_t absxma_a; + static accscalar_t SMALL = 20.0; + static accscalar_t LARGE = 200.0; + static accscalar_t SMALLRATIO = 0.3; + static accscalar_t LARGERATIO = 4.5; + + // boundary values following SciPy + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 1.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 0.0; // zero integration limit + } + else if (::isinf(static_cast(a))) { + if (::isinf(static_cast(x))) { + return std::numeric_limits::quiet_NaN(); + } + return 0.0; + } + else if (::isinf(static_cast(x))) { + return 1.0; + } + + /* Asymptotic regime where a ~ x. */ + absxma_a = ::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 1); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / ::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 1); + } + + if ((x > 1.0) && (x > a)) { + return 1.0 - calc_igammac(a, x); + } + + return _igam_helper_series(a, x); +} + +} + +// end of regularized lower & upper incomplete gamma + +namespace at { namespace native { + +void igamma_kernel_cuda(TensorIterator& iter) { + + AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "igamma_cuda", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return calc_igamma(a, b); + }); + }); +} + +void igammac_kernel_cuda(TensorIterator& iter) { + + AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "igammac_cuda", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return calc_igammac(a, b); + }); + }); +} + +REGISTER_DISPATCH(igamma_stub, &igamma_kernel_cuda); +REGISTER_DISPATCH(igammac_stub, &igammac_kernel_cuda); + +// DO NOT ADD ANY NEW KERNELS HERE +// CUDA compilation times grow quickly. It's perfectly acceptable to have a file per kernel. + +}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu index b69267e90437e..d88f202487af7 100644 --- a/aten/src/ATen/native/cuda/IndexKernel.cu +++ b/aten/src/ATen/native/cuda/IndexKernel.cu @@ -4,9 +4,13 @@ #include #include #include -#include +#include #include +#include +#include #include +#include +#include namespace at { namespace native { @@ -14,6 +18,54 @@ static constexpr int launch_bound2 = 4; static constexpr int launch_size_nd = 128; +template +__device__ __forceinline__ IndexType indexToOffset( + const cuda::detail::TensorInfo& info, + int64_t index, + IndexType size) { + IndexType linearIndex = static_cast(index); + CUDA_KERNEL_ASSERT(linearIndex < size && linearIndex >= -size); + if (linearIndex < 0) { + linearIndex += size; + } + return cuda::detail::IndexToOffset::get(linearIndex, info); +} + +template +void dispatchTakePutImpl(const Tensor& input, Tensor& output, const Tensor& index) { + auto inputInfo = cuda::detail::getTensorInfo(input); + inputInfo.collapseDims(); + auto numel = input.numel(); + if (inputInfo.isContiguous()) { + cuda::CUDA_tensor_apply2( + output, + index, + [inputInfo, numel] __device__ ( + T & out, const int64_t& idx) { + auto offset = indexToOffset<-2, T, IndexType>(inputInfo, idx, numel); + out = inputInfo.data[offset]; + }); + } else { + cuda::CUDA_tensor_apply2( + output, + index, + [inputInfo, numel] __device__ ( + T & out, const int64_t& idx) { + auto offset = indexToOffset<-1, T, IndexType>(inputInfo, idx, numel); + out = inputInfo.data[offset]; + }); + } +} + +template +void dispatchTakePut(const Tensor& input, Tensor& output, const Tensor& index) { + if (cuda::detail::canUse32BitIndexMath(input)) { + dispatchTakePutImpl(input, output, index); + } else { + dispatchTakePutImpl(input, output, index); + } +} + template C10_LAUNCH_BOUNDS_2(nt, launch_bound2) __global__ void index_elementwise_kernel(int N, func_t f) { @@ -39,7 +91,7 @@ static void launch_kernel(int64_t N, const func_t& f) { dim3 grid((N + block.x * vt - 1) / (block.x * vt)); auto stream = at::cuda::getCurrentCUDAStream(); index_elementwise_kernel<<>>(N, f); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -138,7 +190,7 @@ static Tensor & masked_select_out_cuda_impl(Tensor & result, const Tensor & self Tensor _mask = (mask.dim() == 0) ? mask.unsqueeze(0) : mask; Tensor _self = (self.dim() == 0) ? self.unsqueeze(0) : self; std::tie(_mask, _self) = expand_outplace(_mask, _self); - at::native::index_out(result, _self, _mask); + at::native::index_out(result, _self, c10::List>({_mask})); return result; } @@ -154,6 +206,52 @@ Tensor & masked_select_out_cuda(Tensor & result, const Tensor & self, const Tens return masked_select_out_cuda_impl(result, self, mask); } +void take_out_cuda_template(Tensor& output, const Tensor& input, const Tensor& index) { + TORCH_CHECK(output.device().type() == at::kCUDA, "device type of output (", output.device().type(), ") is not GPU"); + TORCH_CHECK(input.device().type() == at::kCUDA, "device type of input (", input.device().type(), ") is not GPU"); + TORCH_CHECK(index.device().type() == at::kCUDA, "device type of index (", index.device().type(), ") is not GPU"); + + TORCH_CHECK(output.layout() == Layout::Strided, "take() only supports strided layout, got layout: ", output.layout(), " on output tensor"); + TORCH_CHECK(input.layout() == Layout::Strided, "take() only supports strided layout, got layout: ", input.layout(), " on input tensor"); + TORCH_CHECK(index.layout() == Layout::Strided, "take() only supports strided layout, got layout: ", index.layout(), " on index tensor"); + + TORCH_CHECK(output.scalar_type() == input.scalar_type(), + "output and input scalar type must match. but got different types: ", output.scalar_type(), " and ", input.scalar_type()); + TORCH_CHECK(index.scalar_type() == kLong, "index must be an int64 tensor"); + + TensorArg output_arg{ output, "output", 1 }; + TensorArg input_arg{ input, "input", 2 }; + TensorArg index_arg{ index, "index", 3 }; + checkAllSameGPU("take", {output_arg, input_arg, index_arg}); + + TORCH_CHECK(input.dim() < MAX_CUTORCH_DIMS, CUTORCH_DIM_WARNING); + TORCH_CHECK(output.dim() < MAX_CUTORCH_DIMS, CUTORCH_DIM_WARNING); + TORCH_CHECK(index.dim() < MAX_CUTORCH_DIMS, CUTORCH_DIM_WARNING); + + TORCH_CHECK(!(input.numel() == 0 && index.numel() != 0), "tried to take from an empty tensor"); + + at::assert_no_internal_overlap(output); + at::assert_no_partial_overlap(output, index); + at::assert_no_overlap(output, input); + + output.resize_(index.sizes()); + + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::Half, input.scalar_type(), "take_cuda", [&] { + dispatchTakePut(input, output, index); + }); +} + +Tensor take_cuda(const Tensor& self, const Tensor& index) { + auto out = at::empty(index.sizes(), self.options()); + take_out_cuda_template(out, self, index); + return out; +} + +Tensor& take_out_cuda(Tensor& out, const Tensor& self, const Tensor& index) { + take_out_cuda_template(out, self, index); + return out; +} + REGISTER_DISPATCH(index_stub, &index_kernel); REGISTER_DISPATCH(index_put_stub, &index_put_kernel); diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index 37ed52755a527..035dc188c81ca 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -160,7 +160,7 @@ computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) { } -static std::tuple> makeLinearIndex(Tensor self, TensorList orig, bool check_range) { +static std::tuple> makeLinearIndex(Tensor self, const c10::List>& orig, bool check_range) { checkIndexTensorTypes(orig); // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors auto indices = expandTensors(self, orig); @@ -184,7 +184,7 @@ static std::tuple>& indices, const Tensor & value, bool unsafe) { if (indices.size() > (size_t)self.dim()) { TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); } @@ -232,19 +232,18 @@ void index_put_accum_kernel(Tensor & self, TensorList indices, const Tensor & va AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, value_.scalar_type(), "indexing_backward", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "indexing_backward", [&] { - indexing_backward_kernel<<>>( - sorted_indices.data_ptr(), - orig_indices.data_ptr(), - value_.data_ptr(), - src_.data_ptr(), - num_indices, - sliceSize, - strideBefore, - nElemBefore); + indexing_backward_kernel<<>>( + sorted_indices.data_ptr(), + orig_indices.data_ptr(), + value_.data_ptr(), + src_.data_ptr(), + num_indices, + sliceSize, + strideBefore, + nElemBefore); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); - }); - AT_CUDA_CHECK(cudaGetLastError()); + if (permuted) self.copy_(src_.permute(inversePerm)); } @@ -308,10 +307,10 @@ static ptrdiff_t getSliceSize(const Tensor & dst, // the number of indices chosen is large, then the // indexAddLargeIndex kernel is a better choice to increase // parallelism. -template +template __global__ void indexAddSmallIndex(cuda::detail::TensorInfo dst, cuda::detail::TensorInfo src, - cuda::detail::TensorInfo indices, + cuda::detail::TensorInfo indices, int dstAddDim, int srcAddDim, IndexType innerSize, @@ -324,7 +323,7 @@ __global__ void indexAddSmallIndex(cuda::detail::TensorInfo dst, for (IndexType srcIndex = 0; srcIndex < indices.sizes[0]; ++srcIndex) { // Lua indices begin at 1 IndexType dstIndex = - indices.data[cuda::detail::IndexToOffset::get(srcIndex, indices)]; + indices.data[cuda::detail::IndexToOffset::get(srcIndex, indices)]; CUDA_KERNEL_ASSERT(dstIndex < dstAddDimSize); // We stride over the output ignoring the indexed dimension @@ -351,11 +350,11 @@ __global__ void indexAddSmallIndex(cuda::detail::TensorInfo dst, // the number of indices chosen is small, then the // indexAddSmallIndex kernel is a better choice to reduce memory // accesses. -template __global__ void indexAddLargeIndex(cuda::detail::TensorInfo dst, cuda::detail::TensorInfo src, - cuda::detail::TensorInfo indices, + cuda::detail::TensorInfo indices, int dstAddDim, int srcAddDim, IndexType totalSize, @@ -378,7 +377,7 @@ __global__ void indexAddLargeIndex(cuda::detail::TensorInfo dst, // Lua indices begin at 1 IndexType dstIndex = - indices.data[cuda::detail::IndexToOffset::get(srcIndex, indices)]; + indices.data[cuda::detail::IndexToOffset::get(srcIndex, indices)]; CUDA_KERNEL_ASSERT(dstIndex < dstAddDimSize); IndexType dstOffset = @@ -438,7 +437,7 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const checkAllSameGPU("index_add", {self_arg, index_arg, source_arg}); TORCH_CHECK_INDEX(index.dim() <= 1, "index_add_(): Index is supposed to be a vector"); - TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_add_(): Expected dtype int64 for index"); + TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_add_(): Expected dtype int32/int64 for index"); TORCH_CHECK(self.scalar_type() == source.scalar_type(), "index_add_(): self and source must have the same scalar type"); TORCH_CHECK(dim == 0 || dim < source.dim(), @@ -446,6 +445,10 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const TORCH_CHECK(index.numel() == (source.dim() == 0 ? 1 : source.size(dim)), "index_add_(): Number of indices should be equal to self.size(dim)"); + at::assert_no_internal_overlap(self); + at::assert_no_overlap(self, index); + at::assert_no_overlap(self, source); + // Scalars are treated as 1-d tensor Tensor self_ = (self.dim() == 0) ? self.view(1) : self; Tensor source_ = (source.dim() == 0) ? source.view(1) : source; @@ -476,21 +479,23 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; -#define SMALL_INDEX(TENSOR_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \ - indexAddSmallIndex \ - <<>>( \ - selfInfo, sourceInfo, indexInfo, \ - selfAddDim, sourceAddDim, sliceSize, selfAddDimSize); - -#define LARGE_INDEX(TENSOR_TYPE, TYPE, \ - SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR) \ - indexAddLargeIndex \ - <<>>( \ - selfInfo, sourceInfo, indexInfo, \ - selfAddDim, sourceAddDim, sourceTotalSize, \ - (IDX_IS_MAJOR) ? sliceSize : numIndex, \ - selfAddDimSize); +#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \ + indexAddSmallIndex \ + <<>>( \ + selfInfo, sourceInfo, indexInfo, \ + selfAddDim, sourceAddDim, sliceSize, selfAddDimSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); + +#define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, \ + SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR) \ + indexAddLargeIndex \ + <<>>( \ + selfInfo, sourceInfo, indexInfo, \ + selfAddDim, sourceAddDim, sourceTotalSize, \ + (IDX_IS_MAJOR) ? sliceSize : numIndex, \ + selfAddDimSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); dim3 smallIndexGrid(std::min(THCCeilDiv(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128)); @@ -501,75 +506,74 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const if (cuda::detail::canUse32BitIndexMath(self) && cuda::detail::canUse32BitIndexMath(source) && cuda::detail::canUse32BitIndexMath(index)) { - AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_add", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "index_add", [&] { - cuda::detail::TensorInfo selfInfo = - cuda::detail::getTensorInfo(self_); - int selfAddDim = selfInfo.collapseDims(dim); - selfInfo.reduceDim(selfAddDim); - + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_add", [&] { + cuda::detail::TensorInfo selfInfo = + cuda::detail::getTensorInfo(self_); + int selfAddDim = selfInfo.collapseDims(dim); + selfInfo.reduceDim(selfAddDim); + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () { auto sourceInfo = cuda::detail::getTensorInfo(source_); int sourceAddDim = sourceInfo.collapseDims(dim); sourceInfo.reduceDim(sourceAddDim); auto indexInfo = - cuda::detail::getTensorInfo(index); + cuda::detail::getTensorInfo(index); indexInfo.collapseDims(); // A reasonable choice for when to have each thread iterate over // index to choose if (numIndex <= 16) { if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { - SMALL_INDEX(scalar_t, unsigned int, 1, 1, -2); + SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2); } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { - SMALL_INDEX(scalar_t, unsigned int, 2, 2, -2); + SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2); } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { - SMALL_INDEX(scalar_t, unsigned int, 3, 3, -2); + SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2); } else { - SMALL_INDEX(scalar_t, unsigned int, -1, -1, -1); + SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1); } } else { bool indexIsMajor = indexShouldBeMajor(selfInfo, selfAddDim); if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { - LARGE_INDEX(scalar_t, unsigned int, 1, 1, -2, true); + LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true); } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { if (indexIsMajor) { - LARGE_INDEX(scalar_t, unsigned int, 2, 2, -2, true); + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true); } else { - LARGE_INDEX(scalar_t, unsigned int, 2, 2, -2, false); + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false); } } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { if (indexIsMajor) { - LARGE_INDEX(scalar_t, unsigned int, 3, 3, -2, true); + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true); } else { - LARGE_INDEX(scalar_t, unsigned int, 3, 3, -2, false); + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false); } } else { - LARGE_INDEX(scalar_t, unsigned int, -1, -1, -1, true); + LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true); } } }); }); } else { AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_add", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "index_add", [&] { - cuda::detail::TensorInfo selfInfo = - cuda::detail::getTensorInfo(self_); - int selfAddDim = selfInfo.collapseDims(dim); - selfInfo.reduceDim(selfAddDim); - - cuda::detail::TensorInfo sourceInfo = - cuda::detail::getTensorInfo(source_); - int sourceAddDim = sourceInfo.collapseDims(dim); - sourceInfo.reduceDim(sourceAddDim); - - cuda::detail::TensorInfo indexInfo = - cuda::detail::getTensorInfo(index); + cuda::detail::TensorInfo selfInfo = + cuda::detail::getTensorInfo(self_); + int selfAddDim = selfInfo.collapseDims(dim); + selfInfo.reduceDim(selfAddDim); + + cuda::detail::TensorInfo sourceInfo = + cuda::detail::getTensorInfo(source_); + int sourceAddDim = sourceInfo.collapseDims(dim); + sourceInfo.reduceDim(sourceAddDim); + + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () { + cuda::detail::TensorInfo indexInfo = + cuda::detail::getTensorInfo(index); indexInfo.collapseDims(); - LARGE_INDEX(scalar_t, uint64_t, -1, -1, -1, true); + LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true); }); }); } @@ -586,10 +590,10 @@ namespace { // the number of indices chosen is large, then the // indexSelectLargeIndex kernel is a better choice to increase // parallelism. -template +template __global__ void indexSelectSmallIndex(cuda::detail::TensorInfo dst, cuda::detail::TensorInfo src, - cuda::detail::TensorInfo indices, + cuda::detail::TensorInfo indices, int dstSelectDim, int srcSelectDim, IndexType innerSize, @@ -601,7 +605,7 @@ __global__ void indexSelectSmallIndex(cuda::detail::TensorInfo dst // re-accessing indices in addition to src elements can be slow. for (IndexType dstIndex = 0; dstIndex < indices.sizes[0]; ++dstIndex) { IndexType srcIndex = - indices.data[cuda::detail::IndexToOffset::get(dstIndex, indices)]; + indices.data[cuda::detail::IndexToOffset::get(dstIndex, indices)]; CUDA_KERNEL_ASSERT(srcIndex < srcSelectDimSize); // We stride over the output ignoring the indexed dimension @@ -628,11 +632,11 @@ __global__ void indexSelectSmallIndex(cuda::detail::TensorInfo dst // the number of indices chosen is small, then the // indexSelectSmallIndex kernel is a better choice to reduce memory // accesses. -template __global__ void indexSelectLargeIndex(cuda::detail::TensorInfo dst, cuda::detail::TensorInfo src, - cuda::detail::TensorInfo indices, + cuda::detail::TensorInfo indices, int dstSelectDim, int srcSelectDim, IndexType totalSize, @@ -654,7 +658,7 @@ __global__ void indexSelectLargeIndex(cuda::detail::TensorInfo dst } IndexType srcIndex = - indices.data[cuda::detail::IndexToOffset::get(dstIndex, indices)]; + indices.data[cuda::detail::IndexToOffset::get(dstIndex, indices)]; CUDA_KERNEL_ASSERT(srcIndex < srcSelectDimSize); IndexType dstOffset = @@ -722,22 +726,24 @@ void index_select_out_cuda_impl(Tensor& out, const Tensor& self, long dim, int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; -#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \ - indexSelectSmallIndex \ - <<>>( \ - outInfo, selfInfo, indicesInfo, \ - outSelectDim, selfSelectDim, static_cast(sliceSize), \ - selfSelectDimSize); - -#define LARGE_INDEX(TENSOR_TYPE, TYPE, \ - DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR) \ - indexSelectLargeIndex \ - <<>>( \ - outInfo, selfInfo, indicesInfo, \ - outSelectDim, selfSelectDim, static_cast(outTotalSize), \ - static_cast((IDX_IS_MAJOR) ? sliceSize : numIndices), \ - selfSelectDimSize); +#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \ + indexSelectSmallIndex \ + <<>>( \ + outInfo, selfInfo, indicesInfo, \ + outSelectDim, selfSelectDim, static_cast(sliceSize), \ + selfSelectDimSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); + +#define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, \ + DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR) \ + indexSelectLargeIndex \ + <<>>( \ + outInfo, selfInfo, indicesInfo, \ + outSelectDim, selfSelectDim, static_cast(outTotalSize), \ + static_cast((IDX_IS_MAJOR) ? sliceSize : numIndices), \ + selfSelectDimSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); dim3 smallIndexGrid(std::min(THCCeilDiv(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128)); @@ -755,42 +761,44 @@ void index_select_out_cuda_impl(Tensor& out, const Tensor& self, long dim, int selfSelectDim = selfInfo.collapseDims(dim); selfInfo.reduceDim(selfSelectDim); - auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo(index)); - indicesInfo.collapseDims(); - - // A reasonable choice for when to have each thread iterate over - // indices to choose - if (numIndices <= 16) { - if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) { - SMALL_INDEX(scalar_t, unsigned int, 1, 1, -2); - } else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) { - SMALL_INDEX(scalar_t, unsigned int, 2, 2, -2); - } else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) { - SMALL_INDEX(scalar_t, unsigned int, 3, 3, -2); - } else { - SMALL_INDEX(scalar_t, unsigned int, -1, -1, -1); - } - } else { - bool indexIsMajor = indexShouldBeMajor(outInfo, outSelectDim); - - if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) { - LARGE_INDEX(scalar_t, unsigned int, 1, 1, -2, true); - } else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) { - if (indexIsMajor) { - LARGE_INDEX(scalar_t, unsigned int, 2, 2, -2, true); + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_select_out_cuda_impl", [&] () { + auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo(index)); + indicesInfo.collapseDims(); + + // A reasonable choice for when to have each thread iterate over + // indices to choose + if (numIndices <= 16) { + if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2); + } else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2); + } else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) { + SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2); } else { - LARGE_INDEX(scalar_t, unsigned int, 2, 2, -2, false); + SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1); } - } else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) { - if (indexIsMajor) { - LARGE_INDEX(scalar_t, unsigned int, 3, 3, -2, true); + } else { + bool indexIsMajor = indexShouldBeMajor(outInfo, outSelectDim); + + if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true); + } else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) { + if (indexIsMajor) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true); + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false); + } + } else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) { + if (indexIsMajor) { + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true); + } else { + LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false); + } } else { - LARGE_INDEX(scalar_t, unsigned int, 3, 3, -2, false); + LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true); } - } else { - LARGE_INDEX(scalar_t, unsigned int, -1, -1, -1, true); } - } + }); } else { auto outInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo(out)); int outSelectDim = outInfo.collapseDims(dim); @@ -799,11 +807,12 @@ void index_select_out_cuda_impl(Tensor& out, const Tensor& self, long dim, auto selfInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo(self)); int selfSelectDim = selfInfo.collapseDims(dim); selfInfo.reduceDim(selfSelectDim); + AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_select_out_cuda_impl", [&] () { + auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo(index)); + indicesInfo.collapseDims(); - auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo(index)); - indicesInfo.collapseDims(); - - LARGE_INDEX(scalar_t, uint64_t, -1, -1, -1, true); + LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true); + }); } #undef SMALL_INDEX #undef LARGE_INDEX @@ -818,22 +827,17 @@ Tensor& index_select_out_cuda(Tensor& out, const Tensor& self, int64_t dim, TORCH_CHECK(at::cuda::check_device({out, self, index}), "Input, output and indices must be on the current device"); at::assert_no_internal_overlap(out); + at::assert_no_overlap(out, self); + at::assert_no_overlap(out, index); dim = at::maybe_wrap_dim(dim, self); TORCH_CHECK(self.dim() <= MAX_TENSORINFO_DIMS, DIM_WARNING); TORCH_CHECK(index.dim() <= MAX_TENSORINFO_DIMS, DIM_WARNING); -#if defined(__HIP_PLATFORM_HCC__) AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, out.scalar_type(), "index_select_cuda", [&] { index_select_out_cuda_impl(out, self, dim, index); }); -#else // __HIP_PLATFORM_HCC__ - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( - at::ScalarType::Half, at::ScalarType::Bool, - out.scalar_type(), "index_select_cuda", - [&] { index_select_out_cuda_impl(out, self, dim, index); }); -#endif // __HIP_PLATFORM_HCC__ return out; } @@ -876,7 +880,8 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out){ //However, out with correct sizes and incorrect strides will have to be copied to from the intermediate we've produced. bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h && out.sizes()[1] == self.dim() && !out.t().is_contiguous(); at::Tensor out_temp = need_to_copy ? - at::native::empty_cuda({self.dim(), num_nonzeros_h}, out.options()) : + at::native::empty_cuda({self.dim(), num_nonzeros_h}, optTypeMetaToScalarType(out.options().dtype_opt()), + out.options().layout_opt(), out.options().device_opt(), out.options().pinned_memory_opt()) : out.resize_({self.dim(), num_nonzeros_h}); //Scalars are expected to produce output of size (1,0), so we can't write to it if (self.dim() > 0) { @@ -925,7 +930,7 @@ Tensor& nonzero_out_cuda(Tensor& out, const Tensor& self){ } Tensor nonzero_cuda(const Tensor& self){ - Tensor out = at::native::empty_cuda({0}, self.options().dtype(kLong)); + Tensor out = at::native::empty_cuda({0}, kLong, self.options().layout_opt(), self.options().device_opt(), self.options().pinned_memory_opt()); return nonzero_out_cuda(out, self); } diff --git a/aten/src/ATen/native/cuda/LegacyDefinitions.cpp b/aten/src/ATen/native/cuda/LegacyDefinitions.cpp index 92895d947758e..1bbe47dbfb2e3 100644 --- a/aten/src/ATen/native/cuda/LegacyDefinitions.cpp +++ b/aten/src/ATen/native/cuda/LegacyDefinitions.cpp @@ -76,50 +76,4 @@ Tensor & masked_scatter__cuda(Tensor& self, const Tensor & mask, const Tensor & } } -Tensor & fmod_cuda_out(Tensor & result, const Tensor & self, Scalar other) { - at::assert_no_internal_overlap(result); - return legacy::cuda::_th_fmod_out(result, self, other); -} - -Tensor fmod_cuda(const Tensor & self, Scalar other) { - return legacy::cuda::_th_fmod(self, other); -} - -Tensor & fmod_cuda_out(Tensor & result, const Tensor & self, const Tensor & other) { - at::assert_no_internal_overlap(result); - Tensor b_self, b_other; - // optimization that codegen used to do; avoids broadcast. - if (other.dim() == 0) { - return fmod_cuda_out(result, self, other.item()); - } - std::tie(b_self, b_other) = expand_outplace(self, other, "fmod_out"); - return legacy::cuda::_th_fmod_out(result, b_self, b_other); -} - -Tensor fmod_cuda(const Tensor & self, const Tensor & other) { - // optimization that codegen used to do; avoids broadcast. - if (other.dim() == 0) { - return fmod_cuda(self, other.item()); - } - Tensor b_self, b_other; - std::tie(b_self, b_other) = expand_outplace(self, other, "fmod"); - return legacy::cuda::_th_fmod(b_self, b_other); -} - -Tensor & fmod_cuda_(Tensor & self, Scalar other) { - at::assert_no_internal_overlap(self); - return legacy::cuda::_th_fmod_(self, other); -} - -Tensor & fmod_cuda_(Tensor & self, const Tensor & other) { - // optimization that codegen used to do; avoids broadcast. - if (other.dim() == 0) { - return fmod_cuda_(self, other.item()); - } - at::assert_no_internal_overlap(self); - Tensor b_other; - std::tie(b_other) = expand_inplace(self, other, "fmod_"); - return legacy::cuda::_th_fmod_(self, b_other); -} - }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/LinearAlgebra.cu b/aten/src/ATen/native/cuda/LinearAlgebra.cu index 76f5c0a99efe5..69a366cc9cd5a 100644 --- a/aten/src/ATen/native/cuda/LinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/LinearAlgebra.cu @@ -2,35 +2,14 @@ #include #include #include +#include +#include +#include +#include +#include namespace at { namespace native { -Tensor baddbmm_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { - Tensor b_self; - std::tie(b_self) = expand_size(self, {batch1.size(0), batch1.size(1), batch2.size(2)}, "baddbmm"); - return legacy::cuda::_th_baddbmm(b_self, batch1, batch2, beta, alpha); -} - -Tensor& baddbmm_out_cuda(Tensor &result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { - Tensor b_self; - std::tie(b_self) = expand_size(self, {batch1.size(0), batch1.size(1), batch2.size(2)}, "baddbmm_out"); - return legacy::cuda::_th_baddbmm_out(result, b_self, batch1, batch2, beta, alpha); -} - -Tensor& baddbmm__cuda(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { - return baddbmm_out_cuda(self, self, batch1, batch2, beta, alpha); -} - -Tensor& bmm_out_cuda(Tensor &result, const Tensor& batch1, const Tensor& batch2) { - result.resize_({ batch1.size(0), batch1.size(1), batch2.size(2) }); - return legacy::cuda::_th_bmm_out(result, batch1, batch2); -} - -Tensor bmm_cuda(const Tensor& self, const Tensor& mat2) { - Tensor result = at::empty({0}, self.options()); - return native::bmm_out_cuda(result, self, mat2); -} - Tensor prepare_matrix_for_cublas(Tensor& tensor, bool& transpose_tensor) { Tensor tensor_; IntArrayRef tensor_strides = tensor.strides(); @@ -50,6 +29,37 @@ Tensor prepare_matrix_for_cublas(Tensor& tensor, bool& transpose_tensor) { return tensor_; } +Tensor prepare_batch_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, int64_t& ld_tensor, bool transpose_result, int64_t m, int64_t n) { + IntArrayRef tensor_strides = tensor.strides(); + Tensor tensor_; + int fast_dim = transpose_result ? 2 : 1; + int leading_dim = transpose_result ? 1 : 2; + + if (tensor_strides[fast_dim] == 1 && + (tensor_strides[leading_dim] >= std::max(1, m))) { + transpose_tensor = false; + tensor_ = tensor; + ld_tensor = tensor_strides[leading_dim]; + } else if ((tensor_strides[leading_dim] == 1) && + (tensor_strides[fast_dim] >= std::max(1, n))) { + transpose_tensor = true; + tensor_ = tensor; + ld_tensor = tensor_strides[fast_dim]; + } else { + transpose_tensor = !transpose_result; + // gemm call requires leading dimension and stride parameters to be non-zero + bool is_stride_non_zero = tensor.stride(1) != 0 && tensor.stride(2) != 0; + if (tensor.is_contiguous() && is_stride_non_zero) { + tensor_ = tensor; + } else { + tensor_ = tensor.clone(at::MemoryFormat::Contiguous); + } + ld_tensor = tensor_.stride(1); + } + + return tensor_; +} + namespace { Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) { @@ -142,6 +152,99 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma return result; } +Tensor& baddbmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { + TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor"); + TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor"); + TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor"); + + TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {batch1, "batch1", 2}, {batch2, "batch2", 3}}; + checkAllSameGPU("baddbmm", args); + + IntArrayRef batch1_sizes = batch1.sizes(); + IntArrayRef batch2_sizes = batch2.sizes(); + IntArrayRef self_sizes = self.sizes(); + + TORCH_CHECK(self_sizes[0] == batch1_sizes[0], "self dim 0 must match batch1 dim 0"); + TORCH_CHECK(self_sizes[0] == batch2_sizes[0], "self dim 0 must match batch2 dim 0"); + TORCH_CHECK(self_sizes[1] == batch1_sizes[1], "self dim 1 must match batch1 dim 1"); + TORCH_CHECK(self_sizes[2] == batch2_sizes[2], "self dim 2 must match batch2 dim 2"); + TORCH_CHECK(batch1_sizes[2] == batch2_sizes[1], "batch1 dim 2 must match batch2 dim 1"); + + if (!result.is_same(self)) { + result.resize_as_(self); + if (beta.to>() != 0.0) { + result.copy_(self); + } + } + + // handle pathological cases that blas may not like + if (result.numel() == 0) { + return result; + } else if (batch1_sizes[2] == 0) { + if (beta.to>() == 0.0) { + return result.zero_(); + } else { + return result.mul_(beta); + } + } + + bool transpose_result = false; + Tensor result_; + IntArrayRef result_strides = result.strides(); + IntArrayRef result_sizes = result.sizes(); + + if ((result_strides[1] == 1) && + ((result_sizes[2] == 1) || (result_strides[2] >= std::max(1, result_sizes[1])))) { + result_ = result; + } else if ((result_strides[2] == 1) && + (result_sizes[1] == 1 || (result_strides[1] >= std::max(1, result_sizes[2])))) { + transpose_result = true; + result_ = result; + } else { + result_ = result.transpose(1, 2).clone(at::MemoryFormat::Contiguous); + result_ = result_.transpose(1, 2); + } + + int leading_dim = transpose_result ? 1 : 2; + + Tensor batch1_ = transpose_result ? batch2 : batch1; + Tensor batch2_ = transpose_result ? batch1 : batch2; + int64_t m = result_sizes[transpose_result ? 2 : 1]; + int64_t n = result_sizes[leading_dim]; + int64_t k = batch1_.size(leading_dim); + + int64_t lda, ldb, ldc; + bool transpose_batch1, transpose_batch2; + batch1_ = prepare_batch_matrix_for_cublas(batch1_, transpose_batch1, lda, transpose_result, m, k); + batch2_ = prepare_batch_matrix_for_cublas(batch2_, transpose_batch2, ldb, transpose_result, k, n); + + ldc = result_.stride(leading_dim); + int64_t num_batches = result_.size(0); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "baddbmm_cuda", [&] { + scalar_t alpha_val = alpha.to(); + scalar_t beta_val = beta.to(); + scalar_t* batch1_ptr = batch1_.data_ptr(); + scalar_t* batch2_ptr = batch2_.data_ptr(); + scalar_t* result_ptr = result_.data_ptr(); + at::cuda::blas::bgemm( + transpose_batch1 ? 't' : 'n', + transpose_batch2 ? 't' : 'n', + m, n, k, + alpha_val, + batch1_ptr, lda, batch1_.stride(0), + batch2_ptr, ldb, batch2_.stride(0), + beta_val, + result_ptr, ldc, result_.stride(0), + num_batches + ); + }); + if (!result.is_same(result_)) { + result.copy_(result_); + } + return result; +} + } // anonymous namespace Tensor& mm_out_cuda(Tensor& result, const Tensor& self, const Tensor& mat2) { @@ -178,69 +281,49 @@ Tensor& addmm__cuda(Tensor& self, const Tensor& mat1, const Tensor& mat2, return self; } -Tensor& addbmm_out_cuda(Tensor& out, const Tensor& self, - const Tensor& batch1, const Tensor& batch2, - Scalar beta, Scalar alpha) { - TORCH_CHECK(batch1.dim() == 3 && batch2.dim() == 3, - "Batch tensors should be 3D, got dimensions ", batch1.dim(), - " and ", batch2.dim()); - +Tensor& baddbmm_out_cuda(Tensor &result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { Tensor self_; - if (&out != &self) { - std::tie(self_) = expand_size(self, {batch1.size(1), batch2.size(2)}, "addbmm"); + if (&result != &self) { + std::tie(self_) = expand_size(self, {batch1.size(0), batch1.size(1), batch2.size(2)}, "baddbmm"); } else { - self_ = self; + self_ = self; } - - TORCH_CHECK(out.device() == self_.device() && - out.device() == batch1.device() && - out.device() == batch2.device(), - "Expected all tensors to be on the same device. Found: ", - out.device(), ", ", self_.device(), ", ", - batch1.device(), " and ", batch2.device()); - TORCH_CHECK(self_.dim() == 2, - "2D tensor expected, got ", self_.dim(), "D tensor for input"); - int64_t batchnum = batch1.size(0); - int64_t m1d1 = batch1.size(1); - int64_t innerdim = batch1.size(2); - int64_t m2d2 = batch2.size(2); - TORCH_CHECK(batchnum == batch2.size(0), - "equal number of batches expected"); - TORCH_CHECK(m1d1 == self_.size(0), - "first dimension of batch1 must match first dimension of input"); - TORCH_CHECK(m2d2 == self_.size(1), - "second dimension of batch2 must match second dimension of input"); - TORCH_CHECK(innerdim == batch2.size(1), - "second dimension of batch1 must match first dimension of batch2"); - - if (&out != &self) { - at::native::resize_as_(out, self_); - if (beta.to() != 0.0) { - at::native::copy_(out, self_); - } + { + at::NoNamesGuard guard; + baddbmm_out_cuda_impl(result, self_, batch1, batch2, beta, alpha); } + namedinference::propagate_names_if_nonempty( + result, + namedinference::compute_baddbmm_outnames(result, batch1, batch2, self)); + return result; +} - for (int64_t i=0; i(); + auto alpha_val = alpha.to(); + + // when beta is false, values in self should be ignored, + // nans and infs in self should not propagate. + if (beta_val == false) { + gpu_kernel( + iter, + [=] GPU_LAMBDA (scalar_t self_val, + scalar_t vec1_val, scalar_t vec2_val) -> scalar_t { + return alpha_val && vec1_val && vec2_val; + } + ); + } else { + gpu_kernel( + iter, + [=] GPU_LAMBDA (scalar_t self_val, + scalar_t vec1_val, scalar_t vec2_val) -> scalar_t { + return (beta_val && self_val) || (alpha_val && vec1_val && vec2_val); + } + ); + } + return; + } + + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, + iter.dtype(), "addr_cuda", [&] { + auto beta_val = beta.to(); + auto alpha_val = alpha.to(); + + scalar_t zero_val(0); + // when beta==0, values in self should be ignored, + // nans and infs in self should not propagate. + if (beta_val == zero_val) { + gpu_kernel( + iter, + [=] GPU_LAMBDA (scalar_t self_val, + scalar_t vec1_val, scalar_t vec2_val) -> scalar_t { + return alpha_val * vec1_val * vec2_val; + } + ); + } else { + gpu_kernel( + iter, + [=] GPU_LAMBDA (scalar_t self_val, + scalar_t vec1_val, scalar_t vec2_val) -> scalar_t { + return beta_val * self_val + alpha_val * vec1_val * vec2_val; + } + ); + } + }); +} + +} // anonymous namespace + +REGISTER_DISPATCH(addr_stub, &addr_kernel_cuda); + +}} diff --git a/aten/src/ATen/native/cuda/LogAddExpKernel.cu b/aten/src/ATen/native/cuda/LogAddExpKernel.cu new file mode 100644 index 0000000000000..ac2f94aceeee9 --- /dev/null +++ b/aten/src/ATen/native/cuda/LogAddExpKernel.cu @@ -0,0 +1,43 @@ +#include +#include +#include +#include +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. + +namespace at { namespace native { + +void logaddexp_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "logaddexp_cuda", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { + if (::isinf(a) && a == b) { + return a; + } + else { + scalar_t m = ::max(a, b); + return m + ::log((scalar_t)(1.0) + ::exp(-::abs(a - b))); + } + }); + }); +} + +void logaddexp2_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "logaddexp2_cuda", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { + if (::isinf(a) && a == b) { + return a; + } + else { + scalar_t m = ::max(a, b); + return m + ::log2((scalar_t)(1.0) + ::pow((scalar_t)(2.0), -::abs(a - b))); + } + }); + }); +} + +REGISTER_DISPATCH(logaddexp_stub, &logaddexp_kernel_cuda); +REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_kernel_cuda); + +}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh index bb913dc0ec9e7..e74debfb29be0 100644 --- a/aten/src/ATen/native/cuda/Loops.cuh +++ b/aten/src/ATen/native/cuda/Loops.cuh @@ -20,7 +20,7 @@ constexpr int block_work_size = BLOCK_WORK_SIZE; namespace at { namespace native { template -static OffsetCalculator make_input_offset_calculator(const TensorIterator& iter) { +static OffsetCalculator make_input_offset_calculator(const TensorIteratorBase& iter) { // array size can not be 0, this happens when N == 0 constexpr int array_size = std::max(N, 1); TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs()); @@ -34,7 +34,7 @@ static OffsetCalculator make_input_offset_calculator(const TensorIterator& it } template -static OffsetCalculator make_output_offset_calculator(const TensorIterator& iter) { +static OffsetCalculator make_output_offset_calculator(const TensorIteratorBase& iter) { TORCH_INTERNAL_ASSERT(num_outputs == iter.noutputs()); std::array strides; int64_t element_sizes[num_outputs]; @@ -88,7 +88,7 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) { namespace at { namespace native { template -void gpu_kernel(TensorIterator& iter, const func_t& f) { +void gpu_kernel(TensorIteratorBase& iter, const func_t& f) { for (int arg = 0; arg < iter.ntensors(); arg++) { TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda()); @@ -139,8 +139,7 @@ struct BUnaryFunctor { }; template -void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) { - ASSERT_HOST_DEVICE_LAMBDA(func_t); +void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { TORCH_INTERNAL_ASSERT(iter.ntensors() == 3); using traits = function_traits; @@ -153,6 +152,11 @@ void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) { if (iter.is_cpu_scalar(1)) { AUnaryFunctor af(f, iter.scalar_value(1)); iter.remove_operand(1); + // TODO: When all kernels that use gpu_kernel_with_scalars are + // ported to structured, this device guard can be deleted. This + // works around incorrect device guard generation for pre-structured + // kernels device guards, but structured kernels do it right and + // we can assume the device is already set correctly const OptionalDeviceGuard device_guard(device_of(iter.tensor(1))); gpu_kernel(iter, af); } else if (iter.is_cpu_scalar(2)) { @@ -184,11 +188,11 @@ static inline void launch_unrolled_kernel_for_multi_outputs(int64_t N, const fun int64_t grid = (N + block_work_size - 1) / block_work_size; auto stream = at::cuda::getCurrentCUDAStream(); unrolled_elementwise_kernel_for_multi_outputs<<>>(N, f, data, ic, oc); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template -void gpu_kernel_multiple_outputs_impl(TensorIterator& iter, const func_t& f) { +void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f) { using traits = function_traits; using output_t = typename traits::result_type; static_assert(is_tuple::value, "f's return type must be `thrust::tuple`"); @@ -219,7 +223,7 @@ void gpu_kernel_multiple_outputs_impl(TensorIterator& iter, const func_t& f) { } // namespace template -void gpu_kernel_multiple_outputs(TensorIterator& iter, const func_t& f) { +void gpu_kernel_multiple_outputs(TensorIteratorBase& iter, const func_t& f) { ASSERT_HOST_DEVICE_LAMBDA(func_t); for (int arg = 0; arg < iter.ntensors(); arg++) { diff --git a/aten/src/ATen/native/cuda/LossCTC.cu b/aten/src/ATen/native/cuda/LossCTC.cu index 602541cda5660..69718b206d6b7 100644 --- a/aten/src/ATen/native/cuda/LossCTC.cu +++ b/aten/src/ATen/native/cuda/LossCTC.cu @@ -281,7 +281,7 @@ std::tuple ctc_loss_gpu_template(const Tensor& log_probs, const log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2), tg_batch_offsets.data_ptr(), tg_target_stride, batch_size, BLANK); - AT_CUDA_CHECK(cudaGetLastError()); // catch launch errors + C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(neg_log_likelihood, log_alpha); } @@ -633,7 +633,7 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ log_beta.stride(0), log_beta.stride(1), log_beta.stride(2), tg_batch_offsets.data_ptr(), tg_target_stride, batch_size, BLANK); - AT_CUDA_CHECK(cudaGetLastError()); // catch launch errors + C10_CUDA_KERNEL_LAUNCH_CHECK(); } // Very crude heuristic for what is a small problem., based on linearly regressing problem dimensions on @@ -690,7 +690,7 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ log_beta.stride(0), log_beta.stride(1), log_beta.stride(2), tg_batch_offsets.data_ptr(), tg_target_stride, batch_size, num_labels, BLANK, zero_infinity); - AT_CUDA_CHECK(cudaGetLastError()); // catch launch errors + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { // small problem, use naive algorithm // Still no block/grid configuration guru... int threads_input = max_threads; @@ -713,7 +713,7 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ log_beta.stride(0), log_beta.stride(1), log_beta.stride(2), tg_batch_offsets.data_ptr(), tg_target_stride, batch_size, num_labels, BLANK, zero_infinity); - AT_CUDA_CHECK(cudaGetLastError()); // catch launch errors + C10_CUDA_KERNEL_LAUNCH_CHECK(); // catch launch errors } // zero those invalid graident elements due to padding @@ -737,7 +737,7 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ grad.size(1), grad.size(2) ); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } return grad; diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh index eec428ae2a12b..17c30cd00ea7d 100644 --- a/aten/src/ATen/native/cuda/Math.cuh +++ b/aten/src/ATen/native/cuda/Math.cuh @@ -54,12 +54,12 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) { a = q; i = 0; b = 0.0; - while((i < 9) || (a <= 9.0)){ + while ((i < 9) || (a <= 9.0)) { i += 1; a += 1.0; b = ::pow( a, -x ); s += b; - if((-MACHEP < (b / s)) && ((b / s) < MACHEP)) { + if ((-MACHEP < (b / s)) && ((b / s) < MACHEP)) { return static_cast(s); } }; @@ -68,16 +68,16 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) { s -= 0.5 * b; a = 1.0; k = 0.0; - for(int i=0; i < 12; i++) { + for (int i=0; i < 12; i++) { a *= x + k; b /= w; t = a * b / A[i]; s = s + t; t = t / s; - if(t < 0){ + if (t < 0){ t = -t; } - if((-MACHEP (s); } k += 1.0; @@ -93,6 +93,7 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) { */ template static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) { + // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma using accscalar_t = at::acc_type; static const double PI_f64 = 3.14159265358979323846; const accscalar_t PSI_10 = 2.25175258906672110764; @@ -108,14 +109,18 @@ static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) { accscalar_t x = static_cast(in); if (x == 0) { - return static_cast(INFINITY); + // As per C++ standard for gamma related functions and SciPy, + // If the argument is ±0, ±∞ is returned + return std::copysign(static_cast(INFINITY), -x); } - bool x_is_integer = x == ::floor(x); + bool x_is_integer = x == ::trunc(x); accscalar_t result = 0; if (x < 0) { if (x_is_integer) { - return static_cast(INFINITY); + // As per C++ standard for gamma related functions and SciPy, + // If the argument is a negative integer, NaN is returned + return static_cast(NAN); } // Rounding errors in tan's input can really affect the output // for extreme values, so we always perform this computation in double. @@ -174,7 +179,6 @@ static inline __host__ __device__ scalar_t calc_polygamma(int n, scalar_t x) { return ((n % 2) ? 1.0 : -1.0) * ::exp(::lgamma(static_cast(n) + 1.0)) * zeta(static_cast(n + 1), x); } - template static inline C10_HOST_DEVICE scalar_t calc_gcd(scalar_t a_in, scalar_t b_in) { scalar_t a = ::abs(a_in); diff --git a/aten/src/ATen/native/cuda/MaxMinElementwiseKernel.cu b/aten/src/ATen/native/cuda/MaxMinElementwiseKernel.cu index 7788dc054d8f9..6142e427ffd1f 100644 --- a/aten/src/ATen/native/cuda/MaxMinElementwiseKernel.cu +++ b/aten/src/ATen/native/cuda/MaxMinElementwiseKernel.cu @@ -12,18 +12,18 @@ namespace at { namespace native { void maximum_kernel_cuda(TensorIterator& iter) { if (iter.dtype() == ScalarType::Bool) { - gpu_kernel(iter, []GPU_LAMBDA(bool a, bool b) -> bool { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(bool a, bool b) -> bool { return a || b; }); } else if (isIntegralType(iter.dtype(), /*includeBool=*/ false)) { AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "max_elementwise_cuda", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { return ::max(a, b); }); }); } else { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "max_elementwise_cuda", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { if (a != a) { return a; } else if (b != b) { @@ -38,18 +38,18 @@ void maximum_kernel_cuda(TensorIterator& iter) { void minimum_kernel_cuda(TensorIterator& iter) { if (iter.dtype() == ScalarType::Bool) { - gpu_kernel(iter, []GPU_LAMBDA(bool a, bool b) -> bool { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(bool a, bool b) -> bool { return a && b; }); } else if (isIntegralType(iter.dtype(), /*includeBool=*/ false)) { AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "minimum_cuda", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { return ::min(a, b); }); }); } else { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "min_elementwise_cuda", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { if (a != a) { return a; } else if (b != b) { diff --git a/aten/src/ATen/native/cuda/MaxUnpooling.cu b/aten/src/ATen/native/cuda/MaxUnpooling.cu index 33382e0700bb1..c3517ab49d1c4 100644 --- a/aten/src/ATen/native/cuda/MaxUnpooling.cu +++ b/aten/src/ATen/native/cuda/MaxUnpooling.cu @@ -169,8 +169,8 @@ Tensor& max_unpooling2d_forward_out_cuda( oheight, owidth, output.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); })); - AT_CUDA_CHECK(cudaGetLastError()); if (self.ndimension() == 3) { output.resize_({numChannels, oheight, owidth}); } @@ -343,7 +343,7 @@ Tensor& max_unpooling3d_forward_out_cuda( oH, oW, offsetZ); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); totalZ -= 65535; offsetZ += 65535; } @@ -446,8 +446,8 @@ at::Tensor& max_unpooling2d_backward_out_cuda( oheight, owidth, grad_input.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); })); - AT_CUDA_CHECK(cudaGetLastError()); return grad_input; } at::Tensor max_unpooling2d_backward_cuda( @@ -550,7 +550,7 @@ at::Tensor& max_unpooling3d_backward_out_cuda( indices.packed_accessor64(), grad_input_reshaped.packed_accessor64(), offsetZ); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); totalZ -= 65535; offsetZ += 65535; } diff --git a/aten/src/ATen/native/cuda/MiscUtils.h b/aten/src/ATen/native/cuda/MiscUtils.h index 31e6d69aa0a10..8f78e8d780031 100644 --- a/aten/src/ATen/native/cuda/MiscUtils.h +++ b/aten/src/ATen/native/cuda/MiscUtils.h @@ -6,8 +6,8 @@ #include // for USE_MAGMA #ifdef USE_MAGMA -#include #include +#include #endif namespace at { diff --git a/aten/src/ATen/native/cuda/MultiTensorApply.cuh b/aten/src/ATen/native/cuda/MultiTensorApply.cuh index f82a0d9a58c8f..a9877c088b109 100644 --- a/aten/src/ATen/native/cuda/MultiTensorApply.cuh +++ b/aten/src/ATen/native/cuda/MultiTensorApply.cuh @@ -26,41 +26,109 @@ __device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int s // TensorListMetadata has to be < 4KB - the limit for kernel launch argument static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; +static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30}; template struct TensorListMetadata { void* addresses[n][depth_to_max_tensors[n-1]]; - int sizes[depth_to_max_tensors[n-1]]; + int numel_for_tensor[depth_to_max_tensors[n-1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n-1]]; + int block_to_chunk[depth_to_max_blocks[n-1]]; +}; + +template struct TensorListScalarListMetadata +{ + void* addresses[n][depth_to_max_tensors_scalarlist[n-1]]; + int numel_for_tensor[depth_to_max_tensors_scalarlist[n-1]]; + scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n-1]]; unsigned char block_to_tensor[depth_to_max_blocks[n-1]]; int block_to_chunk[depth_to_max_blocks[n-1]]; }; template C10_LAUNCH_BOUNDS_1(kBlockSize) -__global__ void +__global__ void multi_tensor_apply_kernel( T tensorListMeta, U callable, ArgTypes... args) { // Hand the chunk information to the user-supplied functor to process however it likes. - callable(kChunkSize, tensorListMeta, args...); + callable(kChunkSize, tensorListMeta, args...); } template void multi_tensor_apply( std::vector>& tensor_lists, + at::ArrayRef scalars, T callable, ArgTypes... args) { TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth."); - const cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); + size_t n_tensors = tensor_lists[0].size(); + using scalar_vals_t = typename T::opmath_t; + TensorListScalarListMetadata tensorListMeta; + + int loc_block_info = 0; + int loc_tensor_info = 0; + for(size_t t = 0; t < n_tensors; t++) { + + tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t]; + + tensorListMeta.numel_for_tensor[loc_tensor_info] = tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); + } + loc_tensor_info++; + + int chunks = (tensor_lists[0][t].numel() + kChunkSize - 1)/kChunkSize; + for (int chunk = 0; chunk < chunks; chunk++) { + tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tensorListMeta.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + bool tensors_full = (loc_tensor_info == depth_to_max_tensors_scalarlist[depth-1] && + chunk == chunks - 1); + bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]); + bool last_chunk = (t == n_tensors - 1 && chunk == chunks - 1); + + if (tensors_full || blocks_full || last_chunk) { + multi_tensor_apply_kernel<<>>( + tensorListMeta, + callable, + args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + // Reset. + loc_block_info = 0; + if(chunk == chunks - 1) { + loc_tensor_info = 0; + } + else { + tensorListMeta.numel_for_tensor[0] = tensorListMeta.numel_for_tensor[loc_tensor_info-1]; + tensorListMeta.scalar_vals[0] = tensorListMeta.scalar_vals[loc_tensor_info-1]; + for(int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][0] = tensorListMeta.addresses[d][loc_tensor_info-1]; + } + loc_tensor_info = 1; + } + } + } + } + } + + +template +void multi_tensor_apply( + std::vector>& tensor_lists, + T callable, + ArgTypes... args) { + TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth."); size_t n_tensors = tensor_lists[0].size(); TensorListMetadata tensorListMeta; int loc_block_info = 0; int loc_tensor_info = 0; for(size_t t = 0; t < n_tensors; t++) { - tensorListMeta.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); + tensorListMeta.numel_for_tensor[loc_tensor_info] = tensor_lists[0][t].numel(); for (int d = 0; d < depth; d++) { tensorListMeta.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); } @@ -82,16 +150,15 @@ void multi_tensor_apply( tensorListMeta, callable, args...); - - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); // Reset. loc_block_info = 0; if(chunk == chunks - 1) { - loc_tensor_info = 0; + loc_tensor_info = 0; } else { - tensorListMeta.sizes[0] = tensorListMeta.sizes[loc_tensor_info-1]; + tensorListMeta.numel_for_tensor[0] = tensorListMeta.numel_for_tensor[loc_tensor_info-1]; for(int d = 0; d < depth; d++) { tensorListMeta.addresses[d][0] = tensorListMeta.addresses[d][loc_tensor_info-1]; } diff --git a/aten/src/ATen/native/cuda/MultinomialKernel.cu b/aten/src/ATen/native/cuda/MultinomialKernel.cu index 4e96c28683360..cc74848b632a5 100644 --- a/aten/src/ATen/native/cuda/MultinomialKernel.cu +++ b/aten/src/ATen/native/cuda/MultinomialKernel.cu @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -73,6 +74,7 @@ void renormRows(Tensor& t) { <<>>(t.data_ptr(), rows, cols); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } @@ -113,7 +115,7 @@ __device__ int binarySearchForMultinomial(scalar_t* cumdist, template __global__ void -sampleMultinomialWithReplacement(std::pair seeds, +sampleMultinomialWithReplacement(PhiloxCudaState philox_args, int totalSamples, int64_t* dest, int64_t distributions, @@ -124,11 +126,16 @@ sampleMultinomialWithReplacement(std::pair seeds, // search due to divergence. It seems possible to compute multiple // values and limit divergence though later on. + auto seeds = at::cuda::philox::unpack(philox_args); + // global index formula for 2D grid of 1D blocks int idx = blockIdx.y * gridDim.x * blockDim.x + blockIdx.x * blockDim.x + threadIdx.x; curandStatePhilox4_32_10_t state; - curand_init(seeds.first, idx, seeds.second, &state); + curand_init(std::get<0>(seeds), + idx, + std::get<1>(seeds), + &state); // The block determines the distribution for which we generate a point for (int64_t curDist = blockIdx.y; @@ -293,7 +300,11 @@ sampleMultinomialOnce(int64_t* dest, } } -void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n_sample, const bool with_replacement, c10::optional generator) { +void multinomial_with_replacement_kernel_impl( + Tensor& result, + const Tensor& self, + const int64_t n_sample, + c10::optional generator) { auto gen = get_generator_or_default(generator, cuda::detail::getDefaultCUDAGenerator()); int inputSize = self.dim(); @@ -322,7 +333,9 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n // To exploit greater parallelism for the sampling, generate the // Uniform random samples in a separate kernel launch, into // temporarily allocated memory. The device RNG is thread-limited - Tensor sampled = native::empty_cuda({numDist, n_sample}, self_v.options()); + Tensor sampled = native::empty_cuda({numDist, n_sample}, optTypeMetaToScalarType(self_v.options().dtype_opt()), + self_v.options().layout_opt(), self_v.options().device_opt(), + self_v.options().pinned_memory_opt()); at::native::uniform_(sampled, 0.0, 1.0, generator); dim3 block(numCategories < maxThreads ? numCategories : maxThreads); @@ -340,6 +353,7 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n self_v.stride(0), self_v.stride(1) ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { // Generic, slow implementation with memory allocations @@ -359,9 +373,8 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n // Prefix sum along rows at::_cumsum_out(prefixSum, normDist, 1); - std::pair rng_engine_inputs; + PhiloxCudaState rng_engine_inputs; - if (with_replacement) { // Binary search is warp divergent (so effectively we're running // with just a single thread), but for better utilization, // we need each block to have at least 4 warps. @@ -379,30 +392,29 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n // curand_uniform4 (See Note [Register spilling in curand call for CUDA < 10]), // offset is 4 times that. auto offset = ((numDist-1)/grid.y+1)*4; - rng_engine_inputs = gen->philox_engine_inputs(offset); + rng_engine_inputs = gen->philox_cuda_state(offset); } // Sample with replacement sampleMultinomialWithReplacement <<>>( - rng_engine_inputs, + rng_engine_inputs, n_sample, result.data_ptr(), numDist, numCategories, prefixSum.data_ptr(), normDist.data_ptr()); - } + C10_CUDA_KERNEL_LAUNCH_CHECK(); } }); - AT_CUDA_CHECK(cudaGetLastError()); - if (inputSize == 1) { result.resize_({n_sample}); } } } -REGISTER_DISPATCH(multinomial_stub, &multinomial_kernel_impl); - +REGISTER_DISPATCH( + multinomial_with_replacement_stub, + &multinomial_with_replacement_kernel_impl); }} diff --git a/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu b/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu index 2ad6f0785a17d..522e3bbd8760a 100644 --- a/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu +++ b/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu @@ -188,10 +188,8 @@ void slow_conv_dilated_all_cuda_template( int64_t nInputPlane = weight.size(1); int64_t nOutputPlane = weight.size(0); // Temporary buffers: - int64_t m = std::accumulate( - kernel_size.begin(), kernel_size.end(), 1, std::multiplies()); - int64_t output_vsize = std::accumulate( - output_size.begin(), output_size.end(), 1, std::multiplies()); + const int64_t m = prod_intlist(kernel_size); + const int64_t output_vsize = prod_intlist(output_size); Tensor columns = at::empty({0}, options); if (output.defined() || grad_weight.defined() || grad_input.defined()) { columns.resize_({nInputPlane * m, output_vsize}); diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu index 4830ca149cffd..6bf4e0f32f134 100644 --- a/aten/src/ATen/native/cuda/Normalization.cu +++ b/aten/src/ATen/native/cuda/Normalization.cu @@ -5,26 +5,24 @@ namespace at { namespace native { std::tuple batch_norm_cuda_out(Tensor& output, Tensor& save_mean, Tensor& save_invstd, const Tensor& self, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool train, double momentum, double epsilon) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "batch_norm_cuda", [&] { - auto mean_st = running_mean.dtype(); - auto var_st = running_var.dtype(); - TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types"); - bool is_half_float = std::is_same::value && mean_st == at::kFloat; - bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; - if (cuda::detail::canUse32BitIndexMath(self)) { - if (is_half_float || is_bfloat16_float) { - batch_norm_cuda_template(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon); - } else { - batch_norm_cuda_template(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon); - } + auto mean_st = running_mean.dtype(); + auto var_st = running_var.dtype(); + TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types"); + bool is_half_float = std::is_same::value && mean_st == at::kFloat; + bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; + if (cuda::detail::canUse32BitIndexMath(self)) { + if (is_half_float || is_bfloat16_float) { + batch_norm_cuda_template(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon); } else { - if (is_half_float || is_bfloat16_float) { - batch_norm_cuda_template(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon); - } else { - batch_norm_cuda_template(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon); - } + batch_norm_cuda_template(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon); } - }); + } else { + if (is_half_float || is_bfloat16_float) { + batch_norm_cuda_template(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon); + } else { + batch_norm_cuda_template(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon); + } + } }); return std::tuple(output, save_mean, save_invstd); } @@ -54,38 +52,34 @@ std::tuple batch_norm_cuda(const Tensor& self, const Ten std::tuple batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& self, const Tensor& weight, const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd, bool train, double epsilon, std::array grad_input_mask) { return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_backward_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "batch_norm_backward_cuda", [&] { - auto mean_st = running_mean.dtype(); - auto var_st = running_var.dtype(); - TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types"); - bool is_half_float = std::is_same::value && mean_st == at::kFloat; - bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; - if (cuda::detail::canUse32BitIndexMath(self)) { - if (is_half_float || is_bfloat16_float) { - return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); - } else { - return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); - } + auto mean_st = running_mean.dtype(); + auto var_st = running_var.dtype(); + TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types"); + bool is_half_float = std::is_same::value && mean_st == at::kFloat; + bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; + if (cuda::detail::canUse32BitIndexMath(self)) { + if (is_half_float || is_bfloat16_float) { + return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); } else { - if (is_half_float || is_bfloat16_float) { - return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); - } else { - return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); - } + return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); } - }); + } else { + if (is_half_float || is_bfloat16_float) { + return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); + } else { + return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); + } + } }); } std::tuple batch_norm_stats_cuda(const Tensor& self, double epsilon) { return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_stats_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "batch_norm_stats_cuda", [&] { - if (cuda::detail::canUse32BitIndexMath(self)) { - return batch_norm_stats_cuda_template(self, epsilon); - } else { - return batch_norm_stats_cuda_template(self, epsilon); - } - }); + if (cuda::detail::canUse32BitIndexMath(self)) { + return batch_norm_stats_cuda_template(self, epsilon); + } else { + return batch_norm_stats_cuda_template(self, epsilon); + } }); } @@ -99,26 +93,24 @@ Tensor batch_norm_elemt_cuda(const Tensor& self, const Tensor& weight, const Ten Tensor& batch_norm_elemt_cuda_out(Tensor& output, const Tensor& self, const Tensor& weight, const Tensor& bias, const Tensor& mean, const Tensor& invstd, double epsilon) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_elemt", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "batch_norm_elemt", [&] { - auto mean_st = mean.dtype(); - auto invstd_st = invstd.dtype(); - TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types"); - bool is_half_float = std::is_same::value && mean_st == at::kFloat; - bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; - if (cuda::detail::canUse32BitIndexMath(self)) { - if (is_half_float || is_bfloat16_float) { - batch_norm_elemt_cuda_template(output, self, weight, bias, mean, invstd, epsilon); - } else { - batch_norm_elemt_cuda_template(output, self, weight, bias, mean, invstd, epsilon); - } + auto mean_st = mean.dtype(); + auto invstd_st = invstd.dtype(); + TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types"); + bool is_half_float = std::is_same::value && mean_st == at::kFloat; + bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; + if (cuda::detail::canUse32BitIndexMath(self)) { + if (is_half_float || is_bfloat16_float) { + batch_norm_elemt_cuda_template(output, self, weight, bias, mean, invstd, epsilon); + } else { + batch_norm_elemt_cuda_template(output, self, weight, bias, mean, invstd, epsilon); + } + } else { + if (is_half_float || is_bfloat16_float) { + batch_norm_elemt_cuda_template(output, self, weight, bias, mean, invstd, epsilon); } else { - if (is_half_float || is_bfloat16_float) { - batch_norm_elemt_cuda_template(output, self, weight, bias, mean, invstd, epsilon); - } else { - batch_norm_elemt_cuda_template(output, self, weight, bias, mean, invstd, epsilon); - } + batch_norm_elemt_cuda_template(output, self, weight, bias, mean, invstd, epsilon); } - }); + } }); return output; } @@ -128,104 +120,98 @@ std::tuple batch_norm_gather_stats_cuda(const Tensor& self, cons const Tensor& running_var, double momentum, double epsilon, int64_t count) { std::vector counts(mean.size(0), count); Tensor counts_ = at::from_blob((void*)counts.data(), {(int64_t)counts.size()}, self.options().dtype(at::kLong).device(at::kCPU)); - counts_ = counts_.to(self.device()).to(running_mean.dtype()); + counts_ = counts_.to(self.device()).to(running_mean.defined() ? running_mean.dtype() : self.dtype()); return batch_norm_gather_stats_with_counts_cuda(self, mean, invstd, running_mean, running_var, momentum, epsilon, counts_); } -std::tuple batch_norm_gather_stats_with_counts_cuda(const Tensor& self, const Tensor& mean, const Tensor& invstd, const Tensor& running_mean, - const Tensor& running_var, double momentum, double epsilon, const Tensor& counts) { - - return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, running_mean.scalar_type(), "batch_norm_update_stats_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "batch_norm_update_stats_cuda", [&] { - using accscalar_t = at::acc_type; - if (cuda::detail::canUse32BitIndexMath(self)) { - return batch_norm_gather_stats_cuda_template(mean, invstd, running_mean, running_var, momentum, epsilon, counts); - } else { - return batch_norm_gather_stats_cuda_template(mean, invstd, running_mean, running_var, momentum, epsilon, counts); - } - }); +std::tuple batch_norm_gather_stats_with_counts_cuda( + const Tensor& self, const Tensor& mean, const Tensor& invstd, const Tensor& running_mean /* optional */, + const Tensor& running_var /* optional */, double momentum, double epsilon, const Tensor& counts) { + + auto scalar_type = running_mean.defined() ? running_mean.scalar_type() : self.scalar_type(); + return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "batch_norm_update_stats_cuda", [&] { + using accscalar_t = at::acc_type; + if (cuda::detail::canUse32BitIndexMath(self)) { + return batch_norm_gather_stats_cuda_template(mean, invstd, running_mean, running_var, momentum, epsilon, counts); + } else { + return batch_norm_gather_stats_cuda_template(mean, invstd, running_mean, running_var, momentum, epsilon, counts); + } }); } std::tuple batch_norm_backward_reduce_cuda(const Tensor& self, const Tensor& input, const Tensor& mean, const Tensor& invstd, const Tensor& weight, bool input_g, bool weight_g, bool bias_g) { return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_backward_reduce", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "batch_norm_backward_reduce", [&] { - auto mean_st = mean.dtype(); - auto invstd_st = invstd.dtype(); - TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types"); - bool is_half_float = std::is_same::value && mean_st == at::kFloat; - bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; - if (cuda::detail::canUse32BitIndexMath(self)) { - if (is_half_float || is_bfloat16_float) { - return batch_norm_backward_reduce_cuda_template(self, input, mean, invstd, weight, input_g, weight_g, bias_g); - } else { - return batch_norm_backward_reduce_cuda_template(self, input, mean, invstd, weight, input_g, weight_g, bias_g); - } + auto mean_st = mean.dtype(); + auto invstd_st = invstd.dtype(); + TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types"); + bool is_half_float = std::is_same::value && mean_st == at::kFloat; + bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; + if (cuda::detail::canUse32BitIndexMath(self)) { + if (is_half_float || is_bfloat16_float) { + return batch_norm_backward_reduce_cuda_template(self, input, mean, invstd, weight, input_g, weight_g, bias_g); + } else { + return batch_norm_backward_reduce_cuda_template(self, input, mean, invstd, weight, input_g, weight_g, bias_g); + } + } else { + if (is_half_float || is_bfloat16_float) { + return batch_norm_backward_reduce_cuda_template(self, input, mean, invstd, weight, input_g, weight_g, bias_g); } else { - if (is_half_float || is_bfloat16_float) { - return batch_norm_backward_reduce_cuda_template(self, input, mean, invstd, weight, input_g, weight_g, bias_g); - } else { - return batch_norm_backward_reduce_cuda_template(self, input, mean, invstd, weight, input_g, weight_g, bias_g); - } + return batch_norm_backward_reduce_cuda_template(self, input, mean, invstd, weight, input_g, weight_g, bias_g); } - }); + } }); } Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, const Tensor& mean, const Tensor& invstd, const Tensor& weight, const Tensor& mean_dy, const Tensor& mean_dy_xmu) { return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_backward_elemt", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "batch_norm_backward_elemt", [&] { - auto mean_st = mean.dtype(); - auto invstd_st = invstd.dtype(); - TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types"); - bool is_half_float = std::is_same::value && mean_st == at::kFloat; - bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; - if (cuda::detail::canUse32BitIndexMath(self)) { - if (is_half_float || is_bfloat16_float) { - return batch_norm_backward_elemt_cuda_template(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu); - } else { - return batch_norm_backward_elemt_cuda_template(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu); - } + auto mean_st = mean.dtype(); + auto invstd_st = invstd.dtype(); + TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types"); + bool is_half_float = std::is_same::value && mean_st == at::kFloat; + bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; + if (cuda::detail::canUse32BitIndexMath(self)) { + if (is_half_float || is_bfloat16_float) { + return batch_norm_backward_elemt_cuda_template(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu); } else { - if (is_half_float || is_bfloat16_float) { - return batch_norm_backward_elemt_cuda_template(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu); - } else { - return batch_norm_backward_elemt_cuda_template(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu); - } + return batch_norm_backward_elemt_cuda_template(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu); } - }); + } else { + if (is_half_float || is_bfloat16_float) { + return batch_norm_backward_elemt_cuda_template(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu); + } else { + return batch_norm_backward_elemt_cuda_template(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu); + } + } }); } std::tuple batch_norm_update_stats_cuda( const Tensor& self, const Tensor& running_mean, const Tensor& running_var, double momentum) { return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_backward", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "batch_norm_backward", [&] { - auto mean_st = running_mean.dtype(); - auto var_st = running_var.dtype(); - TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types"); - // Some workloads depend on passing in half input and float stats, which is - // usually handled by cuDNN. However, the JIT sometimes replaces cuDNN calls with this - // one so it needs to support the same case, or people start to complain. - bool is_half_float = std::is_same::value && mean_st == at::kFloat; - bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; - if (cuda::detail::canUse32BitIndexMath(self)) { - if (is_half_float || is_bfloat16_float) { - return batch_norm_update_stats_cuda_template(self, running_mean, running_var, momentum); - } else { - return batch_norm_update_stats_cuda_template(self, running_mean, running_var, momentum); - } + auto mean_st = running_mean.dtype(); + auto var_st = running_var.dtype(); + TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types"); + // Some workloads depend on passing in half input and float stats, which is + // usually handled by cuDNN. However, the JIT sometimes replaces cuDNN calls with this + // one so it needs to support the same case, or people start to complain. + bool is_half_float = std::is_same::value && mean_st == at::kFloat; + bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; + if (cuda::detail::canUse32BitIndexMath(self)) { + if (is_half_float || is_bfloat16_float) { + return batch_norm_update_stats_cuda_template(self, running_mean, running_var, momentum); + } else { + return batch_norm_update_stats_cuda_template(self, running_mean, running_var, momentum); + } + } else { + if (is_half_float || is_bfloat16_float) { + return batch_norm_update_stats_cuda_template(self, running_mean, running_var, momentum); } else { - if (is_half_float || is_bfloat16_float) { - return batch_norm_update_stats_cuda_template(self, running_mean, running_var, momentum); - } else { - return batch_norm_update_stats_cuda_template(self, running_mean, running_var, momentum); - } + return batch_norm_update_stats_cuda_template(self, running_mean, running_var, momentum); } - }); + } }); } diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh index a0d37dd44be1f..a0445f1291928 100644 --- a/aten/src/ATen/native/cuda/Normalization.cuh +++ b/aten/src/ATen/native/cuda/Normalization.cuh @@ -104,7 +104,7 @@ static __device__ __forceinline__ Float2 warpSum(Float2 <<>> (input, output, running_mean, running_var, weight, bias, epsilon); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { // for the reduction, we cannot use blocks for the batch dim, but if we have few threads in // the feature dimension, we'll use some threads for blocks @@ -566,10 +567,11 @@ void batch_norm_cuda_template(Tensor& output_, Tensor& save_mean_, Tensor& save_ dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf)); batch_norm_collect_statistics_kernel <<>> (input, epsilon, momentum, running_mean, running_var, save_mean, save_invstd); + C10_CUDA_KERNEL_LAUNCH_CHECK(); batch_norm_transform_input_kernel <<>> (input, output, save_mean, save_invstd, weight, bias, epsilon); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } - AT_CUDA_CHECK(cudaGetLastError()); } template @@ -615,7 +617,7 @@ std::tuple batch_norm_backward_cuda_template(const Tenso batch_norm_backward_kernel <<>> (input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(grad_input_, grad_weight_, grad_bias_); } @@ -654,7 +656,7 @@ std::tuple batch_norm_stats_cuda_template(const Tensor& input_, dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf)); batch_norm_collect_statistics_kernel <<>> (input, epsilon, 0.0, dummy_mean, dummy_invstd, mean, invstd); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(mean_, invstd_); } @@ -694,7 +696,7 @@ void batch_norm_elemt_cuda_template(Tensor& output_, const Tensor& input_, const dim3 threads_trans(tf, tb); batch_norm_transform_input_kernel <<>> (input, output, mean, invstd, weight, bias, epsilon); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -727,7 +729,7 @@ std::tuple batch_norm_gather_stats_cuda_template(const Tensor& m int grid = std::max(1, features/block); batch_norm_reduce_statistics_kernel <<>> (mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, momentum, counts); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(save_mean_, save_invstd_); } @@ -777,7 +779,7 @@ std::tuple batch_norm_backward_reduce_cuda_templ batch_norm_backward_reduce_kernel <<>> (input, grad_output, mean, invstd, sum_dy, sum_dy_xmu, grad_weight, grad_bias); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(sum_dy_, sum_dy_xmu_, grad_weight_, grad_bias_); } @@ -819,7 +821,7 @@ Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Te dim3 threads_trans(tf, tb); batch_norm_backward_elemt_kernel <<>> (input, grad_output, mean, invstd, weight, mean_dy, mean_dy_xmu, grad_input); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return grad_input_reshaped.view(input_.sizes()); } @@ -853,7 +855,7 @@ std::tuple batch_norm_update_stats_cuda_template( // NB: epsilon is unused by the Var transform, so we set it to 0 batch_norm_collect_statistics_kernel <<>> (input, 0., momentum, running_mean, running_var, save_mean, save_var); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(save_mean_, save_var_); } diff --git a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh index 8265c59993763..051583a12a537 100644 --- a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh +++ b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh @@ -258,50 +258,24 @@ void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_ele dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { - case 0: // 1 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; + #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) case L2E: \ + softmax_warp_forward \ + <<>>(dst, \ + src, batch_count, softmax_elements_stride, softmax_elements); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + break; + + LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1 + LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2 + LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4 + LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8 + LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16 + LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32 + LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64 + LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128 + LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256 + LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512 + LAUNCH_SOFTMAX_WARP_FORWARD(10); ; // 1024 default: break; } @@ -333,53 +307,27 @@ void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { - case 0: // 1 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - softmax_warp_backward - <<>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements); - break; + #define LAUNCH_SOFTMAX_WARP_BACKWARD(L2E) case L2E: \ + softmax_warp_backward \ + <<>> \ + (grad_input, grad, output, batch_count, softmax_elements_stride, \ + softmax_elements); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + break; + + LAUNCH_SOFTMAX_WARP_BACKWARD(0); // 1 + LAUNCH_SOFTMAX_WARP_BACKWARD(1); // 2 + LAUNCH_SOFTMAX_WARP_BACKWARD(2); // 4 + LAUNCH_SOFTMAX_WARP_BACKWARD(3); // 8 + LAUNCH_SOFTMAX_WARP_BACKWARD(4); // 16 + LAUNCH_SOFTMAX_WARP_BACKWARD(5); // 32 + LAUNCH_SOFTMAX_WARP_BACKWARD(6); // 64 + LAUNCH_SOFTMAX_WARP_BACKWARD(7); // 128 + LAUNCH_SOFTMAX_WARP_BACKWARD(8); // 256 + LAUNCH_SOFTMAX_WARP_BACKWARD(9); // 512 + LAUNCH_SOFTMAX_WARP_BACKWARD(10); // 1024 default: break; } } } - diff --git a/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu b/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu index 501ef90477da6..f349438964143 100644 --- a/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu +++ b/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu @@ -9,7 +9,7 @@ namespace at { namespace native { void addcmul_cuda_kernel(TensorIterator& iter, Scalar value) { - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "addcmul_cuda", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, iter.dtype(), "addcmul_cuda", [&]() { auto alpha = value.to(); gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t { return a + alpha * b * c; @@ -18,7 +18,7 @@ void addcmul_cuda_kernel(TensorIterator& iter, Scalar value) { } void addcdiv_cuda_kernel(TensorIterator& iter, Scalar value) { - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "addcdiv_cuda", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, iter.dtype(), "addcdiv_cuda", [&]() { auto alpha = value.to(); gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t { return a + alpha * (b / c); @@ -26,17 +26,18 @@ void addcdiv_cuda_kernel(TensorIterator& iter, Scalar value) { }); } -void smooth_l1_backward_cuda_kernel(TensorIterator& iter, Scalar norm) { - AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "smooth_l1_backward_cuda", [&]() { +void smooth_l1_backward_cuda_kernel(TensorIterator& iter, Scalar norm, double beta) { + AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "smooth_l1_backward_cuda", [&iter, &norm, beta] { auto norm_val = norm.to(); - gpu_kernel(iter, [norm_val]GPU_LAMBDA(scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t { + scalar_t beta_val(beta); + gpu_kernel(iter, [norm_val, beta_val]GPU_LAMBDA(scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t { const auto x = input - target; - if (x < scalar_t(-1)) + if (x < -beta_val) return -norm_val * grad_output; - else if (x > scalar_t(1)) + else if (x > beta_val) return norm_val * grad_output; else - return norm_val * x * grad_output; + return norm_val * x * grad_output / beta_val; }); }); } diff --git a/aten/src/ATen/native/cuda/ROCmLoops.cuh b/aten/src/ATen/native/cuda/ROCmLoops.cuh index e0dc83556677a..c339364b5a021 100644 --- a/aten/src/ATen/native/cuda/ROCmLoops.cuh +++ b/aten/src/ATen/native/cuda/ROCmLoops.cuh @@ -134,7 +134,7 @@ static void launch_kernel(int64_t N, const func_t& f) { dim3 grid((N + block.x * vt - 1) / (block.x * vt)); auto stream = at::cuda::getCurrentCUDAStream(); elementwise_kernel<<>>(N, f); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -296,7 +296,7 @@ static void launch_kernel(int64_t N, const func_t& f, array_t data) { int64_t grid = (N + block_work_size - 1) / block_work_size; auto stream = at::cuda::getCurrentCUDAStream(); elementwise_kernel<<>>(N, f, data); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template::value, int> = 0> @@ -306,7 +306,7 @@ static void launch_kernel(int64_t N, const func_t& f, array_t data) {} template -void gpu_kernel_impl(TensorIterator& iter, const func_t& f) { +void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) { using traits = function_traits; using arg0_t = typename traits::result_type; constexpr int ntensors = traits::arity + 1; diff --git a/aten/src/ATen/native/cuda/RangeFactories.cu b/aten/src/ATen/native/cuda/RangeFactories.cu index 38f3f8487fc4e..107c3c28fdac6 100644 --- a/aten/src/ATen/native/cuda/RangeFactories.cu +++ b/aten/src/ATen/native/cuda/RangeFactories.cu @@ -39,8 +39,10 @@ void gpu_kernel_with_index(at::Tensor &output, func_t f) { using scalar_t = typename function_traits::result_type; if (N <= std::numeric_limits::max()) { elementwise_kernel_with_index<<>>(N, f, output.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { elementwise_kernel_with_index<<>>(N, f, output.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } @@ -105,7 +107,6 @@ Tensor& linspace_cuda_out(Tensor& result, Scalar start, Scalar end, c10::optiona result.copy_(r); } - AT_CUDA_CHECK(cudaGetLastError()); return result; } @@ -164,7 +165,6 @@ Tensor& logspace_cuda_out(Tensor& result, Scalar start, Scalar end, c10::optiona result.copy_(r); } - AT_CUDA_CHECK(cudaGetLastError()); return result; } @@ -201,71 +201,67 @@ Tensor& range_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) { }); - AT_CUDA_CHECK(cudaGetLastError()); return result; } Tensor& arange_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) { AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, result.scalar_type(), "arange_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "arange_cuda", [&] { - using accscalar_t = at::acc_type; - auto xstart = start.to(); - auto xend = end.to(); - auto xstep = step.to(); - - // we use double precision for (start - end) / step - // to compute size_d for consistency across devices. - // The problem with using accscalar_t is that accscalar_t might be float32 on gpu for a float32 scalar_t, - // but double on cpu for the same, - // and the effective output size starts differing on CPU vs GPU because of precision issues, which - // we dont want. - // the corner-case we do want to take into account is int64_t, which has higher precision than double - double size_d; - if (std::is_same::value) { - size_d = std::ceil(static_cast(end.to() - start.to()) - / step.to()); - } else { - size_d = std::ceil(static_cast(end.to() - start.to()) - / step.to()); - } - - TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); - TORCH_CHECK(std::isfinite(static_cast(xstart)) && - std::isfinite(static_cast(xend)), - "unsupported range: ", xstart, " -> ", xend); - TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), - "upper bound and larger bound inconsistent with step sign"); - - TORCH_CHECK(size_d >= 0 && size_d <= static_cast(std::numeric_limits::max()), - "invalid size, possible overflow?"); - int64_t size = static_cast(size_d); - int64_t numel = result.numel(); - - if (numel != size) { - if(numel > 0){ - TORCH_WARN("The number of elements in the out tensor of shape ", result.sizes(), - " is ", numel, " which does not match the computed number of elements ", size, - ". Note that this may occur as a result of rounding error. " - "The out tensor will be resized to a tensor of shape (", size, ",)."); - } - result.resize_({size}); - } - bool is_contiguous = result.is_contiguous(); - Tensor r = !is_contiguous ? at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : result; + using accscalar_t = at::acc_type; + auto xstart = start.to(); + auto xend = end.to(); + auto xstep = step.to(); - gpu_kernel_with_index(r, [xstart, xstep]GPU_LAMBDA(int64_t ind) -> scalar_t { - accscalar_t inc = xstep * static_cast(ind); - accscalar_t val = xstart + inc; - return static_cast(val); - }); + // we use double precision for (start - end) / step + // to compute size_d for consistency across devices. + // The problem with using accscalar_t is that accscalar_t might be float32 on gpu for a float32 scalar_t, + // but double on cpu for the same, + // and the effective output size starts differing on CPU vs GPU because of precision issues, which + // we dont want. + // the corner-case we do want to take into account is int64_t, which has higher precision than double + double size_d; + if (std::is_same::value) { + size_d = std::ceil(static_cast(end.to() - start.to()) + / step.to()); + } else { + size_d = std::ceil(static_cast(end.to() - start.to()) + / step.to()); + } - if(!is_contiguous) { - result.copy_(r); + TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); + TORCH_CHECK(std::isfinite(static_cast(xstart)) && + std::isfinite(static_cast(xend)), + "unsupported range: ", xstart, " -> ", xend); + TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), + "upper bound and larger bound inconsistent with step sign"); + + TORCH_CHECK(size_d >= 0 && size_d <= static_cast(std::numeric_limits::max()), + "invalid size, possible overflow?"); + int64_t size = static_cast(size_d); + int64_t numel = result.numel(); + + if (numel != size) { + if(numel > 0){ + TORCH_WARN("The number of elements in the out tensor of shape ", result.sizes(), + " is ", numel, " which does not match the computed number of elements ", size, + ". Note that this may occur as a result of rounding error. " + "The out tensor will be resized to a tensor of shape (", size, ",)."); } + result.resize_({size}); + } + bool is_contiguous = result.is_contiguous(); + Tensor r = !is_contiguous ? at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : result; + + gpu_kernel_with_index(r, [xstart, xstep]GPU_LAMBDA(int64_t ind) -> scalar_t { + accscalar_t inc = xstep * static_cast(ind); + accscalar_t val = xstart + inc; + return static_cast(val); }); + + if(!is_contiguous) { + result.copy_(r); + } }); - AT_CUDA_CHECK(cudaGetLastError()); return result; } diff --git a/aten/src/ATen/native/cuda/RecordStream.cu b/aten/src/ATen/native/cuda/RecordStream.cu new file mode 100644 index 0000000000000..d48561df00e5c --- /dev/null +++ b/aten/src/ATen/native/cuda/RecordStream.cu @@ -0,0 +1,7 @@ +#include +#include +namespace at { namespace native { +void record_stream_cuda(Tensor& self, c10::Stream stream) { + c10::cuda::CUDACachingAllocator::recordStream(self.storage().data_ptr(), at::cuda::CUDAStream::unpack(stream.pack())); +} +}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index 618088cefb3a4..ea797e6011afe 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -817,15 +817,16 @@ static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction) switch(config.output_vec_size) { case 4: reduce_kernel<<>>(reduction); + C10_CUDA_KERNEL_LAUNCH_CHECK(); break; case 2: reduce_kernel<<>>(reduction); + C10_CUDA_KERNEL_LAUNCH_CHECK(); break; default: reduce_kernel<<>>(reduction); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } - - AT_CUDA_CHECK(cudaGetLastError()); } class AccumulationBuffer { @@ -872,7 +873,7 @@ int get_output_vec_size(TensorIterator &iter) { vec_size /= 2; } }; - + uint64_t base_address = reinterpret_cast(iter.data_ptr(iter.noutputs())) / sizeof(scalar_t); update_vec_size(base_address); diff --git a/aten/src/ATen/native/cuda/ReduceLogicKernel.cu b/aten/src/ATen/native/cuda/ReduceLogicKernel.cu index ca2db43637dd6..fcf60678929e4 100644 --- a/aten/src/ATen/native/cuda/ReduceLogicKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceLogicKernel.cu @@ -3,22 +3,33 @@ #include #include #include +#include namespace at { namespace native { void and_kernel_cuda(TensorIterator& iter) { - gpu_reduce_kernel( - iter, func_wrapper ([]GPU_LAMBDA(uint8_t a, uint8_t b) -> uint8_t { - return a && b; - }), true); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + kHalf, kBFloat16, kBool, iter.common_dtype(), "and_cuda", [&]() { + gpu_reduce_kernel( + iter, + func_wrapper([] GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { + return (static_cast(a) && static_cast(b)); + }), + true); + }); } void or_kernel_cuda(TensorIterator& iter) { - gpu_reduce_kernel( - iter, func_wrapper ([]GPU_LAMBDA(uint8_t a, uint8_t b) -> uint8_t { - return a || b; - }), false); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + kHalf, kBFloat16, kBool, iter.common_dtype(), "or_cuda", [&]() { + gpu_reduce_kernel( + iter, + func_wrapper([] GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { + return (static_cast(a) || static_cast(b)); + }), + false); + }); } REGISTER_DISPATCH(and_stub, &and_kernel_cuda); diff --git a/aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu b/aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu index 83d11ed9f9e10..cb070e15f191c 100644 --- a/aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu @@ -13,20 +13,32 @@ namespace at { namespace native { +template +struct MaxNanFunctor { + __device__ __forceinline__ acc_t operator()(acc_t a, acc_t b) const { + return (THCNumerics::isnan(a) || a > b) ? a : b; + } +}; + template void max_values_kernel_cuda_impl(TensorIterator& iter) { gpu_reduce_kernel( - iter, func_wrapper ([]GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { - return (THCNumerics::isnan(a) || a > b) ? a : b; - }), at::numeric_limits::lower_bound()); + iter, func_wrapper (MaxNanFunctor()), + at::numeric_limits::lower_bound()); } +template +struct MinNanFunctor { + __device__ __forceinline__ acc_t operator()(acc_t a, acc_t b) const { + return (THCNumerics::isnan(a) || a < b) ? a : b; + } +}; + template void min_values_kernel_cuda_impl(TensorIterator& iter) { gpu_reduce_kernel( - iter, func_wrapper ([]GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { - return (THCNumerics::isnan(a) || a < b) ? a : b; - }), at::numeric_limits::upper_bound()); + iter, func_wrapper (MinNanFunctor()), + at::numeric_limits::upper_bound()); } void max_values_kernel_cuda(TensorIterator& iter) { diff --git a/aten/src/ATen/native/cuda/ReduceMomentKernel.cu b/aten/src/ATen/native/cuda/ReduceMomentKernel.cu index 34ce6b3a25be3..d9d289ba8cea8 100644 --- a/aten/src/ATen/native/cuda/ReduceMomentKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceMomentKernel.cu @@ -36,26 +36,28 @@ static void std_var_kernel_cuda(TensorIterator& iter, bool unbiased, bool take_s template void mean_kernel_impl(TensorIterator& iter) { - float factor = float(iter.num_output_elements()) / iter.numel(); - gpu_reduce_kernel(iter, MeanOps {factor}); + // returns acc_t for all non-complex dtypes and returns T for c10::complex + using factor_t = typename c10::scalar_value_type::type; + factor_t factor = static_cast(iter.num_output_elements()) / iter.numel(); + gpu_reduce_kernel(iter, MeanOps {factor}); } static void mean_kernel_cuda(TensorIterator& iter) { if (iter.dtype() == kHalf) { - return mean_kernel_impl(iter); + mean_kernel_impl(iter); } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) { // type promotion that does cast and reduction in a single kernel - return mean_kernel_impl(iter); - } - else if(iter.dtype() == kBFloat16) { - return mean_kernel_impl(iter); + mean_kernel_impl(iter); + } else if(iter.dtype() == kBFloat16) { + mean_kernel_impl(iter); } else if (iter.dtype(1) == kBFloat16 && iter.dtype() == kFloat) { // type promotion that does cast and reduction in a single kernel - return mean_kernel_impl(iter); + mean_kernel_impl(iter); + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "mean_cuda", [&]() { + mean_kernel_impl(iter); + }); } - AT_DISPATCH_ALL_TYPES(iter.dtype(), "mean_cuda", [&]() { - mean_kernel_impl(iter); - }); } REGISTER_DISPATCH(std_var_stub, &std_var_kernel_cuda); diff --git a/aten/src/ATen/native/cuda/ReduceNormKernel.cu b/aten/src/ATen/native/cuda/ReduceNormKernel.cu index 39a355a96756f..3a24f00f6ebfa 100644 --- a/aten/src/ATen/native/cuda/ReduceNormKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceNormKernel.cu @@ -7,48 +7,49 @@ namespace at { namespace native { -template +// This reduction accumulates results as the type `acc_t`. By default, when +// `scalar_t` is complex, `acc_t` is the downgraded real number type. +// Otherwise, `acc_t` and `scalar_t` are the same type. +template ::type, typename out_t=typename scalar_value_type::type> void norm_kernel_cuda_impl(TensorIterator& iter, Scalar val) { - float p; + double p; if (val.isIntegral(false)) { p = val.to(); } else if (val.isFloatingPoint()) { - p = val.to(); + p = val.to(); } else { AT_ERROR("norm_kernel_cuda_impl expects norm to be integer or float"); } - if (p == static_cast(0)) { - gpu_reduce_kernel(iter, NormZeroOps(), 0); - } else if (p == static_cast(1)) { - gpu_reduce_kernel(iter, NormOneOps(), 0); - } else if (p == static_cast(2)) { - gpu_reduce_kernel(iter, NormTwoOps(), 0); - } else if (p == static_cast(INFINITY)) { - gpu_reduce_kernel(iter, AbsMaxOps(), std::numeric_limits::min()); - } else if (p == static_cast(-INFINITY)) { - gpu_reduce_kernel(iter, AbsMinOps(), std::numeric_limits::max()); + if (p == static_cast(0)) { + gpu_reduce_kernel(iter, NormZeroOps(), 0); + } else if (p == static_cast(1)) { + gpu_reduce_kernel(iter, NormOneOps(), 0); + } else if (p == static_cast(2)) { + gpu_reduce_kernel(iter, NormTwoOps(), 0); + } else if (p == static_cast(INFINITY)) { + gpu_reduce_kernel(iter, AbsMaxOps(), 0); + } else if (p == static_cast(-INFINITY)) { + gpu_reduce_kernel(iter, AbsMinOps(), std::numeric_limits::max()); } else { - gpu_reduce_kernel(iter, NormOps{ acc_t(p) }, 0); + gpu_reduce_kernel(iter, NormOps{ acc_t(p) }, 0); } } static void norm_kernel_cuda(TensorIterator& iter, Scalar p) { - if (iter.dtype() == kHalf) { + if (iter.input_dtype() == kHalf) { return norm_kernel_cuda_impl(iter, p); - } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) { + } else if (iter.dtype(1) == kHalf && iter.input_dtype() == kFloat) { // type promotion that does cast and reduction in a single kernel return norm_kernel_cuda_impl(iter, p); } - #ifdef __HIP_PLATFORM_HCC__ - else if(iter.dtype() == kBFloat16) { + else if(iter.input_dtype() == kBFloat16) { return norm_kernel_cuda_impl(iter, p); - } else if (iter.dtype(1) == kBFloat16 && iter.dtype() == kFloat) { + } else if (iter.dtype(1) == kBFloat16 && iter.input_dtype() == kFloat) { // type promotion that does cast and reduction in a single kernel return norm_kernel_cuda_impl(iter, p); } - #endif - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "norm_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.input_dtype(), "norm_cuda", [&] { norm_kernel_cuda_impl(iter, p); }); } diff --git a/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu b/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu index f7d23e74ee1be..9919b0f0eac44 100644 --- a/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu @@ -35,6 +35,17 @@ struct prod_functor { } }; +// Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context] +template <> +struct prod_functor { + void operator()(TensorIterator& iter) { + gpu_reduce_kernel( + iter, func_wrapper([] GPU_LAMBDA(bool a, bool b) -> bool { + return a && b; + }), 1); + } +}; + // The function `reduce_dispatch` below dispatches to the kernel based // on the type of `iter`. It takes care of the common logic // for handling Half-Precision floating types. @@ -88,7 +99,7 @@ static void nansum_kernel_cuda(TensorIterator& iter) { static void prod_kernel_cuda(TensorIterator& iter) { auto general_dispatcher = [](TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES(iter.dtype(), "prod_cuda", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(ScalarType::Bool, iter.dtype(), "prod_cuda", [&]() { prod_functor{}(iter); }); }; diff --git a/aten/src/ATen/native/cuda/ReflectionPad.cu b/aten/src/ATen/native/cuda/ReflectionPad.cu index 2b182f32b5e75..95a6825d507f8 100644 --- a/aten/src/ATen/native/cuda/ReflectionPad.cu +++ b/aten/src/ATen/native/cuda/ReflectionPad.cu @@ -200,10 +200,9 @@ void reflection_pad1d_out_template( grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>( input.data_ptr(), output.data_ptr(), input_w, pad_l, pad_r); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); - - AT_CUDA_CHECK(cudaGetLastError()); } void reflection_pad1d_backward_out_template( @@ -213,7 +212,7 @@ void reflection_pad1d_backward_out_template( if (grad_input.numel() == 0) { return; } - + TORCH_CHECK(canUse32BitIndexMath(input), "input tensor must fit into 32-bit index math"); @@ -252,15 +251,14 @@ void reflection_pad1d_backward_out_template( grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>( grad_input.data_ptr(), grad_output.data_ptr(), input_w, pad_l, pad_r); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); - - AT_CUDA_CHECK(cudaGetLastError()); } void reflection_pad2d_out_template( Tensor &output, const Tensor &input_, IntArrayRef padding) { - + TORCH_CHECK(canUse32BitIndexMath(input_), "input tensor must fit into 32-bit index math"); @@ -331,10 +329,9 @@ void reflection_pad2d_out_template( input.data_ptr(), output.data_ptr(), input_w, input_h, pad_t, pad_b, pad_l, pad_r); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); - - AT_CUDA_CHECK(cudaGetLastError()); } void reflection_pad2d_backward_out_template( @@ -344,7 +341,7 @@ void reflection_pad2d_backward_out_template( if (grad_input.numel() == 0) { return; } - + TORCH_CHECK(canUse32BitIndexMath(input), "input tensor must fit into 32-bit index math"); TORCH_CHECK(canUse32BitIndexMath(grad_output_), @@ -393,10 +390,9 @@ void reflection_pad2d_backward_out_template( grad_input.data_ptr(), grad_output.data_ptr(), input_w, input_h, pad_t, pad_b, pad_l, pad_r); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); - - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace diff --git a/aten/src/ATen/native/cuda/Repeat.cu b/aten/src/ATen/native/cuda/Repeat.cu index f70459928bf06..8437e80ebb48f 100644 --- a/aten/src/ATen/native/cuda/Repeat.cu +++ b/aten/src/ATen/native/cuda/Repeat.cu @@ -23,6 +23,7 @@ static void compute_cuda(int64_t *repeat_ptr, int64_t *cumsum_ptr, int64_t *resu int64_t grid = std::min((size + warps_per_block - 1) / warps_per_block, 2048L); compute_cuda_kernel<<>>(repeat_ptr, cumsum_ptr, result_ptr, size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } namespace at { namespace native { diff --git a/aten/src/ATen/native/cuda/ReplicationPadding.cu b/aten/src/ATen/native/cuda/ReplicationPadding.cu index 515dc61eca64c..8f164c8476f7b 100644 --- a/aten/src/ATen/native/cuda/ReplicationPadding.cu +++ b/aten/src/ATen/native/cuda/ReplicationPadding.cu @@ -217,14 +217,17 @@ void replication_pad1d_out_cuda_template( int numBatch = 1; int numInputDims = input.ndimension(); - TORCH_CHECK(input.numel() > 0 && (numInputDims == 2 || numInputDims == 3), - "2D or 3D (batch mode) tensor expected for input") + TORCH_CHECK( + (numInputDims == 2 && input.size(0) != 0 && input.size(1) != 0) || + (numInputDims == 3 && input.size(1) != 0 && input.size(2) != 0), + "Expected 2D or 3D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ", + input.sizes()); - if (numInputDims == 3) { - numBatch = input.size(0); - planeDim++; - dimw++; - } + if (numInputDims == 3) { + numBatch = input.size(0); + planeDim++; + dimw++; + } int numPlanes = input.size(planeDim); int inputW = input.size(dimw); @@ -234,13 +237,19 @@ void replication_pad1d_out_cuda_template( "input (W: ", inputW, ")is too small." " Calculated output W: ", outputW); + if (numInputDims == 2) { + output.resize_({numPlanes, outputW}); + } else { + output.resize_({numBatch, numPlanes, outputW}); + } - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "replication_pad1d_cuda", [&] { - + if (input.numel() == 0) { + return; + } + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "replication_pad1d_cuda", [&] { if (numInputDims == 2) { - output.resize_({numPlanes, outputW}); auto input_ = input.unsqueeze(0); auto output_ = output.unsqueeze(0); auto devInput = input_.packed_accessor64(); @@ -254,8 +263,8 @@ void replication_pad1d_out_cuda_template( replication_pad_forward_kernel1d <<>>(devInput, devOutput, padL, padR); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { - output.resize_({numBatch, numPlanes, outputW}); auto devInput = input.packed_accessor64(); auto devOutput = output.packed_accessor64(); @@ -267,10 +276,10 @@ void replication_pad1d_out_cuda_template( replication_pad_forward_kernel1d <<>>(devInput, devOutput, padL, padR); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } - } + } ); - AT_CUDA_CHECK(cudaGetLastError()); } void replication_pad1d_backward_out_cuda_template( @@ -304,6 +313,9 @@ void replication_pad1d_backward_out_cuda_template( gradOutput.size(dimw)); gradInput.resize_as_(input); + if (gradInput.numel() == 0) { + return; + } gradInput.zero_(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( @@ -312,8 +324,8 @@ void replication_pad1d_backward_out_cuda_template( auto gradInput_ = gradInput; auto gradOutput_ = gradOutput; if (numInputDims == 2) { - gradInput_ = gradInput.unsqueeze(0); - gradOutput_ = gradOutput.unsqueeze(0); + gradInput_ = gradInput.unsqueeze(0); + gradOutput_ = gradOutput.unsqueeze(0); } auto devGradInput = gradInput_.packed_accessor64(); auto devGradOutput = gradOutput_.packed_accessor64(); @@ -327,9 +339,8 @@ void replication_pad1d_backward_out_cuda_template( replication_pad_backward_kernel <<>>(devGradInput, devGradOutput, padL, padR); - } - ); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); } void replication_pad2d_out_cuda_template( @@ -351,9 +362,12 @@ void replication_pad2d_out_cuda_template( int numBatch = 1; int numInputDims = input.dim(); - TORCH_CHECK(input.numel() && (numInputDims == 3 || numInputDims == 4), - "non-empty 3D or 4D (batch mode) tensor expected for input, but got: ", - input) + bool valid_dims = input.size(1) != 0 && input.size(2) != 0; + TORCH_CHECK( + (numInputDims == 3 && input.size(0) != 0 && valid_dims) || + (numInputDims == 4 && valid_dims && input.size(3) != 0), + "Expected 3D or 4D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ", + input.sizes()); if (numInputDims == 4) { numBatch = input.size(0); @@ -372,12 +386,19 @@ void replication_pad2d_out_cuda_template( "input (H: ", inputH, ", W: ", inputW, ") is too small." " Calculated output H: ", outputH, " W: ", outputW); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "replication_pad2d_cuda", [&] { + if (numInputDims == 3) { + output.resize_({numPlanes, outputH, outputW}); + } else { + output.resize_({numBatch, numPlanes, outputH, outputW}); + } + if (input.numel() == 0) { + return; + } + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "replication_pad2d_cuda", [&] { if (numInputDims == 3) { - output.resize_({numPlanes, outputH, outputW}); auto input_ = input.unsqueeze(0); auto output_ = output.unsqueeze(0); auto devInput = input_.packed_accessor64(); @@ -392,8 +413,8 @@ void replication_pad2d_out_cuda_template( replication_pad_forward_kernel2d <<>>( devInput, devOutput, padT, padB, padL, padR); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { - output.resize_({numBatch, numPlanes, outputH, outputW}); auto devInput = input.packed_accessor64(); auto devOutput = output.packed_accessor64(); @@ -406,10 +427,10 @@ void replication_pad2d_out_cuda_template( replication_pad_forward_kernel2d <<>>(devInput, devOutput, padT, padB, padL, padR); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } - } + } ); - AT_CUDA_CHECK(cudaGetLastError()); } void replication_pad2d_backward_out_cuda_template( @@ -452,6 +473,9 @@ void replication_pad2d_backward_out_cuda_template( gradOutput.size(dimh)); gradInput.resize_as_(input); + if (gradInput.numel() == 0) { + return; + } gradInput.zero_(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( @@ -474,9 +498,9 @@ void replication_pad2d_backward_out_cuda_template( replication_pad_backward_kernel <<>>(devGradInput, devGradOutput, padT, padB, padL, padR); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } ); - AT_CUDA_CHECK(cudaGetLastError()); } static inline void shapeCheck3d( @@ -488,8 +512,12 @@ static inline void shapeCheck3d( "input tensor must fit into 32-bit index math"); int numInputDims = input.dim(); - TORCH_CHECK(input.numel() && (numInputDims == 4 || numInputDims == 5), - "non-empty 4D or 5D (batch mode) tensor expected for input, but got: ", input); + bool valid_dims = input.size(1) != 0 && input.size(2) != 0 && input.size(3) != 0; + TORCH_CHECK( + (numInputDims == 4 && input.size(0) != 0 && valid_dims) || + (numInputDims == 5 && valid_dims && input.size(4) != 0), + "Expected 4D or 5D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ", + input.sizes()); int planeDim = 0; int dimd = 1; @@ -526,8 +554,12 @@ static inline void shapeAndGradOutputCheck3d( "input tensor must fit into 32-bit index math"); int numInputDims = input.dim(); - TORCH_CHECK(input.numel() && (numInputDims == 4 || numInputDims == 5), - "non-empty 4D or 5D (batch mode) tensor expected for input, but got: ", input); + bool valid_dims = input.size(1) != 0 && input.size(2) != 0 && input.size(3) != 0; + TORCH_CHECK( + (numInputDims == 4 && valid_dims) || + (numInputDims == 5 && valid_dims && input.size(4) != 0), + "Expected 4D or 5D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ", + input.sizes()); int planeDim = 0; int dimd = 1; @@ -608,11 +640,19 @@ void replication_pad3d_out_cuda_template( int outputH = inputH + ptop + pbottom; int outputW = inputW + pleft + pright; - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "replication_pad3d_cuda", [&] { + if (numInputDims == 4) { + output.resize_({numPlanes, outputD, outputH, outputW}); + } else { + output.resize_({numBatch, numPlanes, outputD, outputH, outputW}); + } + if (input.numel() == 0) { + return; + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "replication_pad3d_cuda", [&] { if (numInputDims == 4) { - output.resize_({numPlanes, outputD, outputH, outputW}); auto input_ = input.unsqueeze(0); auto output_ = output.unsqueeze(0); auto devInput = input_.packed_accessor64(); @@ -628,8 +668,8 @@ void replication_pad3d_out_cuda_template( replication_pad_forward_kernel3d <<>>( devInput, devOutput, pfront, pback, ptop, pbottom, pleft, pright); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { - output.resize_({numBatch, numPlanes, outputD, outputH, outputW}); auto devInput = input.packed_accessor64(); auto devOutput = output.packed_accessor64(); @@ -643,10 +683,10 @@ void replication_pad3d_out_cuda_template( replication_pad_forward_kernel3d <<>>( devInput, devOutput, pfront, pback, ptop, pbottom, pleft, pright); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } - } + } ); - AT_CUDA_CHECK(cudaGetLastError()); } void replication_pad3d_backward_out_cuda_template( @@ -679,11 +719,13 @@ void replication_pad3d_backward_out_cuda_template( } gradInput.resize_as_(input); + if (gradInput.numel() == 0) { + return; + } gradInput.zero_(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "replication_pad3d_backward_cuda", [&] { - + input.scalar_type(), "replication_pad3d_backward_cuda", [&] { auto gradInput_ = gradInput; auto gradOutput_ = gradOutput; if (numInputDims == 4) { @@ -703,9 +745,9 @@ void replication_pad3d_backward_out_cuda_template( replication_pad_backward_kernel <<>>( devGradInput, devGradOutput, pfront, pback, ptop, pbottom, pleft, pright); - } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } ); - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace @@ -751,7 +793,7 @@ Tensor replication_pad1d_backward_cuda( // See Note [Writing Nondeterministic Operations] // Nondeterministic because of atomicAdd usage globalContext().alertNotDeterministic("replication_pad1d_backward_cuda"); - auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto gradInput = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); replication_pad1d_backward_out_cuda_template( gradInput, gradOutput, input, paddingSize); return gradInput; @@ -799,7 +841,7 @@ Tensor replication_pad2d_backward_cuda( // See Note [Writing Nondeterministic Operations] // Nondeterministic because of atomicAdd usage globalContext().alertNotDeterministic("replication_pad2d_backward_cuda"); - auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto gradInput = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); replication_pad2d_backward_out_cuda_template( gradInput, gradOutput, input, paddingSize); return gradInput; @@ -847,7 +889,7 @@ Tensor replication_pad3d_backward_cuda( // See Note [Writing Nondeterministic Operations] // Nondeterministic because of atomicAdd usage globalContext().alertNotDeterministic("replication_pad3d_backward_cuda"); - auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto gradInput = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); replication_pad3d_backward_out_cuda_template( gradInput, gradOutput, input, paddingSize); return gradInput; diff --git a/aten/src/ATen/native/cuda/Resize.cuh b/aten/src/ATen/native/cuda/Resize.cuh index d6af01638d210..346181466595f 100644 --- a/aten/src/ATen/native/cuda/Resize.cuh +++ b/aten/src/ATen/native/cuda/Resize.cuh @@ -2,6 +2,7 @@ #include #include +#include #include @@ -51,15 +52,7 @@ inline TensorImpl* resize_impl_cuda_( if (stride) { self->set_sizes_and_strides(size, *stride); // NB: storage size can be different from numel. - for (size_t dim = 0; dim < size.size(); ++dim) { - // FIXME: Don't rely on storage_size being negative because this - // may not be true for some edge cases. - if (size[dim] == 0) { - storage_size = 0; - break; - } - storage_size += (size[dim] - 1) * stride.value()[dim]; - } + storage_size = storage_size_for(size, *stride); } else { self->set_sizes_contiguous(size); storage_size = self->numel(); diff --git a/aten/src/ATen/native/cuda/ScanKernels.cu b/aten/src/ATen/native/cuda/ScanKernels.cu index 6bc2c381e1dba..3848545050547 100644 --- a/aten/src/ATen/native/cuda/ScanKernels.cu +++ b/aten/src/ATen/native/cuda/ScanKernels.cu @@ -128,16 +128,16 @@ __global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *se */ template __global__ void tensor_kernel_scan_outer_dim_with_indices(scalar_t *self_, scalar_t *values_, int64_t *indices_, - int num_orows, int num_irows, int row_size, scalar_t init, BinaryFunction binary_op) { - for (int orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { - for (int irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { + const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, scalar_t init, BinaryFunction binary_op) { + for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { + for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { scalar_t *self = self_ + orow * row_size * num_irows + irow; scalar_t *values = values_ + orow * row_size * num_irows + irow; int64_t *indices = indices_ + orow * row_size * num_irows + irow; scalar_t out = init; int64_t out_idx = 0; - for (int64_t col = 0; col < row_size; ++col) { + for (auto col = decltype(row_size){0}; col < row_size; ++col) { if(THCNumerics::isnan(*self) || (!THCNumerics::isnan(out) && binary_op(*self, out))) { out = *self; out_idx = col; @@ -152,25 +152,38 @@ __global__ void tensor_kernel_scan_outer_dim_with_indices(scalar_t *self_, scala } } +void check_fits_in_unsigned(int64_t val, const char* name) { + constexpr auto umax = std::numeric_limits::max(); + TORCH_CHECK( + val >= 0 && val <= umax, name, " must fit in a 32-bit uint32_t value"); +} + + template __host__ void scan_outer_dim_with_indices(const Tensor& self, Tensor& values, Tensor& indices, int dim, scalar_t init, BinaryFunction binary_op) { - int row_size = self.size(dim); + int64_t row_size = self.size(dim); auto sizes = self.sizes(); // Treat all outer dimensions (i.e. dim_ < dim) as one. - int num_orows = std::accumulate(sizes.begin(), sizes.begin() + dim, 1, std::multiplies()); + const int64_t num_orows = prod_intlist(sizes.begin(), sizes.begin() + dim); // Treat all inner dimensions (i.e. dim > dimension) as one. - int num_irows = std::accumulate(sizes.begin() + dim + 1, sizes.end(), 1, std::multiplies()); + const int64_t num_irows = prod_intlist(sizes.begin() + dim + 1, sizes.end()); + //for performance reasons, cuda kernels use uint32_t for loops over irows, orows and row, + //make sure that input is not bigger than supported by uint32_t + check_fits_in_unsigned(num_irows, "num_irows"); + check_fits_in_unsigned(num_orows, "num_orows"); + check_fits_in_unsigned(row_size, "row_size"); + dim3 threads(std::min(512, int(num_irows))); - int maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int(threads.x)))); + int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x}))); tensor_kernel_scan_outer_dim_with_indices<<>>( self.data_ptr(), values.data_ptr(), indices.data_ptr(), num_orows, num_irows, row_size, init, binary_op); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -186,7 +199,7 @@ __host__ void scan_innermost_dim_with_indices(const Tensor& self, Tensor& values tensor_kernel_scan_innermost_dim_with_indices<<>>( self.data_ptr(), values.data_ptr(), indices.data_ptr(), num_rows, row_size, init, binary_op); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -254,16 +267,16 @@ void cummin_helper_cuda(const Tensor& self, Tensor& values, Tensor& indices, int */ template __global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, scalar_t *src_, - unsigned num_orows, unsigned num_irows, unsigned row_size, - scalar_t init, BinaryOp binary_op) + const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, + const scalar_t init, BinaryOp binary_op) { - for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { - for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { + for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { + for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { scalar_t *src = src_ + orow * row_size * num_irows + irow; scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow; scalar_t acc = init; - for (unsigned col = 0; col < row_size; ++col) { + for (uint32_t col = 0; col < row_size; ++col) { acc = binary_op(acc, *src); *tgt = acc; @@ -286,12 +299,12 @@ __global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, scalar_t *src_, */ template __device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, T *src_, - unsigned num_rows, unsigned row_size, + const uint32_t num_rows, const uint32_t row_size, T init, BinaryFunction binary_op){ - for (unsigned block_row = blockIdx.x * blockDim.y; + for (uint32_t block_row = blockIdx.x * blockDim.y; block_row < num_rows; block_row += blockDim.y * gridDim.x) { - unsigned row = block_row + threadIdx.y; + uint32_t row = block_row + threadIdx.y; T block_total = init; T *row_src = src_ + row * row_size; @@ -299,10 +312,10 @@ __device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, T *sr // Perform scan on one block at a time, keeping track of the total value of // all blocks processed so far. - for (unsigned block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) { + for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) { // Load data into shared memory (two values per thread). - unsigned col1 = block_col + threadIdx.x; - unsigned col2 = block_col + num_threads_x + threadIdx.x; + uint32_t col1 = block_col + threadIdx.x; + uint32_t col2 = block_col + num_threads_x + threadIdx.x; if (row < num_rows) { if (col1 < row_size) { row_buf[threadIdx.x] = row_src[col1]; @@ -324,18 +337,18 @@ __device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, T *sr __syncthreads(); // Parallel reduction (up-sweep). - for (unsigned s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) { + for (uint32_t s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) { if (row < num_rows && threadIdx.x < s) { - unsigned offset = (2 * threadIdx.x + 1) * d - 1; + uint32_t offset = (2 * threadIdx.x + 1) * d - 1; row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]); } __syncthreads(); } // Down-sweep. - for (unsigned s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) { + for (uint32_t s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) { if (row < num_rows && threadIdx.x < s - 1) { - unsigned offset = 2 * (threadIdx.x + 1) * d - 1; + uint32_t offset = 2 * (threadIdx.x + 1) * d - 1; row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]); } __syncthreads(); @@ -361,8 +374,8 @@ __global__ typename std::enable_if::value, void>::type tensor_kernel_scan_innermost_dim( T* tgt_, T* src_, - unsigned num_rows, - unsigned row_size, + const uint32_t num_rows, + const uint32_t row_size, T init, BinaryFunction binary_op) { __shared__ T sbuf[num_threads_y][2 * num_threads_x]; @@ -381,8 +394,8 @@ __global__ typename std::enable_if::value, void>::type tensor_kernel_scan_innermost_dim( T* tgt_, T* src_, - unsigned num_rows, - unsigned row_size, + const uint32_t num_rows, + const uint32_t row_size, T init, BinaryFunction binary_op) { // As we cannot directly initialize shared array for complex types @@ -399,23 +412,18 @@ tensor_kernel_scan_innermost_dim( row_buf, tgt_, src_, num_rows, row_size, init, binary_op); } -void check_fits_in_unsigned(int64_t val, const char* name) { - constexpr auto umax = std::numeric_limits::max(); - TORCH_CHECK( - val >= 0 && val <= umax, name, " must fit in a 32-bit unsigned value"); -} template __host__ void scan_outer_dim(const Tensor& self, Tensor& result, int dim, scalar_t init, BinaryFunction binary_op) { - int64_t row_size = self.size(dim); + const int64_t row_size = self.size(dim); auto sizes = self.sizes(); // Treat all outer dimensions (i.e. dim_ < dim) as one. - int64_t num_orows = std::accumulate(sizes.begin(), sizes.begin() + dim, 1, std::multiplies()); + const int64_t num_orows = prod_intlist(sizes.begin(), sizes.begin() + dim); // Treat all inner dimensions (i.e. dim > dimension) as one. - int64_t num_irows = std::accumulate(sizes.begin() + dim + 1, sizes.end(), 1, std::multiplies()); + const int64_t num_irows = prod_intlist(sizes.begin() + dim + 1, sizes.end()); dim3 threads(std::min(512, int(num_irows))); int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; @@ -428,7 +436,7 @@ __host__ void scan_outer_dim(const Tensor& self, Tensor& result, tensor_kernel_scan_outer_dim<<>>( result.data_ptr(), self.data_ptr(), num_orows, num_irows, row_size, init, binary_op); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -448,7 +456,7 @@ void scan_innermost_dim(const Tensor& self, Tensor& result, scalar_t init, Binar tensor_kernel_scan_innermost_dim<<>>( result.data_ptr(), self.data_ptr(), num_rows, row_size, init, binary_op); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -477,6 +485,7 @@ void scan_cub(const Tensor& self, Tensor& result, scalar_t init, BinaryFunction result.data_ptr() + i - 1, self.data_ptr() + i, binary_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } size_t temp_storage_bytes = 0; AT_CUDA_CHECK(cub::DeviceScan::InclusiveScan( @@ -489,7 +498,8 @@ void scan_cub(const Tensor& self, Tensor& result, scalar_t init, BinaryFunction at::cuda::getCurrentCUDAStream())); auto temp_storage = at::native::empty_cuda( {static_cast(temp_storage_bytes)}, - self.options().dtype(kByte)); + kByte, self.options().layout_opt(), self.options().device_opt(), + self.options().pinned_memory_opt()); AT_CUDA_CHECK(cub::DeviceScan::InclusiveScan( temp_storage.data_ptr(), temp_storage_bytes, diff --git a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu index 552384b459454..ff3b5bb08baa3 100644 --- a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu +++ b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu @@ -72,11 +72,11 @@ static void _launch_scatter_gather_kernel(int64_t N, const func_t& f) { return; } - dim3 block(nt); - dim3 grid((N + block.x * vt - 1) / (block.x * vt)); - auto stream = at::cuda::getCurrentCUDAStream(); + const dim3 block(nt); + const dim3 grid((N + block.x * vt - 1) / (block.x * vt)); + const auto stream = at::cuda::getCurrentCUDAStream(); _scatter_gather_elementwise_kernel<<>>(N, f); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -192,7 +192,7 @@ struct cuda_scatter_gather_base_kernel { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), - method_name, [&] { + "cuda_scatter_gather_base_kernel_func", [&] { using dtype = typename std::conditional, scalar_t>::type; @@ -264,7 +264,7 @@ struct cuda_scatter_gather_base_kernel { AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), - method_name, [&] { + "cuda_scatter_gather_base_kernel_reduce_multiply", [&] { using dtype = typename std::conditional, scalar_t>::type; @@ -365,7 +365,7 @@ struct cuda_scatter_fill_base_kernel { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), - method_name, [&] { + "cuda_scatter_fill_base_kernel_func", [&] { using dtype = typename std::conditional, scalar_t>::type; @@ -417,7 +417,7 @@ struct cuda_scatter_fill_base_kernel { AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), - method_name, [&] { + "cuda_scatter_fill_base_kernel_reduce_multiply", [&] { using dtype = typename std::conditional, scalar_t>::type; @@ -494,5 +494,5 @@ REGISTER_DISPATCH(scatter_fill_stub, &scatter_fill_cuda_kernel); REGISTER_DISPATCH(scatter_add_stub, &scatter_add_cuda_kernel); REGISTER_DISPATCH(scatter_reduce_stub, &scatter_reduce_cuda_kernel); REGISTER_DISPATCH(scatter_scalar_reduce_stub, &scatter_scalar_reduce_cuda_kernel); - + }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/Shape.cu b/aten/src/ATen/native/cuda/Shape.cu index 86523e523a7d4..9bcf046765c8f 100644 --- a/aten/src/ATen/native/cuda/Shape.cu +++ b/aten/src/ATen/native/cuda/Shape.cu @@ -12,7 +12,11 @@ namespace at { namespace native { +#ifdef __HIP_PLATFORM_HCC__ +constexpr int CAT_ARRAY_BATCH_SIZE = 1024; +#else constexpr int CAT_ARRAY_BATCH_SIZE = 128; +#endif constexpr int CAT_ARRAY_MAX_INPUT_DIMS = 4; namespace { @@ -32,8 +36,8 @@ inline bool getCatGrid(ptrdiff_t nTensors, dim3& grid) { template struct CatArrIndexToOffset { static inline __device__ IndexType compute( - const IndexType outputSize[Dims], - const IndexType outputStride[Dims], + const IndexType tensorSize[Dims], + const IndexType tensorStride[Dims], const IndexType dimSize, const unsigned int concatDim, IndexType linearIndex) { @@ -45,22 +49,22 @@ struct CatArrIndexToOffset { #pragma unroll for (int i = Dims - 1; i >= 1; --i) { - IndexType curDimSize = i == concatDim ? dimSize : outputSize[i]; + IndexType curDimSize = i == concatDim ? dimSize : tensorSize[i]; IndexType nextDimIndex = linearIndex / curDimSize; IndexType curDimIndex = linearIndex - curDimSize * nextDimIndex; - IndexType curDimOffset = curDimIndex * outputStride[i]; + IndexType curDimOffset = curDimIndex * tensorStride[i]; offset += curDimOffset; linearIndex = nextDimIndex; } - return offset + linearIndex * outputStride[0]; + return offset + linearIndex * tensorStride[0]; } }; template -struct OutputTensorSizeStride { - IndexType outputSize[MaxDims]; - IndexType outputStride[MaxDims]; +struct TensorSizeStride { + IndexType tensorSize[MaxDims]; + IndexType tensorStride[MaxDims]; }; /** @@ -78,28 +82,71 @@ struct OutputTensorSizeStride { * The most important assumption made is that the input tensors are contiguous. */ + +// Use pinned memory and and pass the struct by pointer on ROCm +template +struct CatArrInputTensor { + T* input; + IndexType offset; + IndexType dimSize; + IndexType nElements; +}; + +template +C10_LAUNCH_BOUNDS_1(512) +__global__ void HIP_CatArrayBatchedCopy( + T* output, + CatArrInputTensor* inputs, + TensorSizeStride os, + const int concatDim, + IndexType dimStride) { + + IndexType tid = blockIdx.x * blockDim.x + threadIdx.x; + IndexType nElements = inputs[blockIdx.y].nElements; + + if(tid >= nElements) return; + + T* data = inputs[blockIdx.y].input; + IndexType offset = inputs[blockIdx.y].offset; + IndexType dimSize = inputs[blockIdx.y].dimSize; + IndexType dataOffset = offset * dimStride; + + IndexType stride = gridDim.x * blockDim.x; + + while( tid < nElements){ + IndexType elementOffset = CatArrIndexToOffset::compute( + os.tensorSize, os.tensorStride, dimSize, concatDim, tid); + output[dataOffset + elementOffset] = data[tid]; + + tid += stride; + } +} + // pass meta data directly through kernel argument instead of pin memory -template +// In contiguous case, we will not need stride_size, setting it as 1 as placeholder +// to pass compile. +template struct CatArrInputTensorMetadata { T* input[n]; IndexType offset[n]; IndexType dimSize[n]; IndexType nElements[n]; + bool isContiguous[n]; + TensorSizeStride tensorStride[stride_size]; }; -template -#ifdef __HIP_PLATFORM_HCC__ -C10_LAUNCH_BOUNDS_1(512) -#endif +template __global__ void CatArrayBatchedCopy( T* output, - CatArrInputTensorMetadata inputs, - OutputTensorSizeStride os, + CatArrInputTensorMetadata inputs, + TensorSizeStride os, const int concatDim, IndexType dimStride) { IndexType tid = blockIdx.x * blockDim.x + threadIdx.x; IndexType nElements = inputs.nElements[blockIdx.y]; + TensorSizeStride ins = stride_size > 1 ? inputs.tensorStride[blockIdx.y] : inputs.tensorStride[0]; + bool isContig = inputs.isContiguous[blockIdx.y]; if(tid >= nElements) return; @@ -111,10 +158,15 @@ __global__ void CatArrayBatchedCopy( IndexType stride = gridDim.x * blockDim.x; while( tid < nElements){ - IndexType elementOffset = CatArrIndexToOffset::compute( - os.outputSize, os.outputStride, dimSize, concatDim, tid); - output[dataOffset + elementOffset] = data[tid]; - + IndexType elementOffset = CatArrIndexToOffset::compute( + os.tensorSize, os.tensorStride, dimSize, concatDim, tid); + if (isContig) { + output[dataOffset + elementOffset] = data[tid]; + } else { + IndexType inElementOffset = CatArrIndexToOffset::compute( + ins.tensorSize, ins.tensorStride, dimSize, concatDim, tid); + output[dataOffset + elementOffset] = data[inElementOffset]; + } tid += stride; } } @@ -142,31 +194,39 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, } template -void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension, +void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension, int nDims, c10::MemoryFormat memory_format) { // First, let's set up our kernel parameters. We start with a raw pointer to // the storage for the output Tensor. scalar_t *data = out.data_ptr(); - CatArrInputTensorMetadata catMetaData; - OutputTensorSizeStride param; + + // Kernel Parameter + long tensorMetadataSize = + sizeof(CatArrInputTensor) * CAT_ARRAY_BATCH_SIZE; + auto d_inputs_storage = at::empty( + {tensorMetadataSize}, out.options().dtype(at::kByte)); + auto d_inputs = static_cast *>( + d_inputs_storage.data_ptr()); + + TensorSizeStride outputParam; // Next, let's initialize the size, stride arrays for the output Tensor. if (memory_format == c10::MemoryFormat::Contiguous) { for (int i = 0; i < nDims; ++i) { - param.outputSize[i] = at::native::size(out, i); - param.outputStride[i] = out.stride(i); + outputParam.tensorSize[i] = at::native::size(out, i); + outputParam.tensorStride[i] = out.stride(i); } } else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) { // permute the semantics of dims from NCHW to NHWC so that the input // tensor is now contiguous - param.outputSize[0] = at::native::size(out, 0); - param.outputStride[0] = out.stride(0); + outputParam.tensorSize[0] = at::native::size(out, 0); + outputParam.tensorStride[0] = out.stride(0); for (int i = 1; i < nDims - 1; ++i) { - param.outputSize[i] = at::native::size(out, i + 1); - param.outputStride[i] = out.stride(i + 1); + outputParam.tensorSize[i] = at::native::size(out, i + 1); + outputParam.tensorStride[i] = out.stride(i + 1); } - param.outputSize[nDims - 1] = at::native::size(out, 1); - param.outputStride[nDims - 1] = out.stride(1); + outputParam.tensorSize[nDims - 1] = at::native::size(out, 1); + outputParam.tensorStride[nDims - 1] = out.stride(1); } else { TORCH_CHECK(false, "unsupported memory format"); } @@ -177,16 +237,144 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension, int batchCounter = 0; int64_t offset = 0; for (int i = 0; i < inputs.size() ; i += CAT_ARRAY_BATCH_SIZE) { + // Re-allocate stackInputs every iteration to avoid read-after-write hazard + { + auto stackInputs_storage = at::empty({tensorMetadataSize}, + out.options().dtype(at::kByte).device(at::kCPU).pinned_memory(true)); + auto stackInputs = + static_cast *>( + stackInputs_storage.data_ptr()); + for (batchCounter = 0; + batchCounter < CAT_ARRAY_BATCH_SIZE && + (i+batchCounter) < inputs.size(); + ++batchCounter) { + int64_t dimSize = 0; + // There is a legacy case where a 1-D empty tensor can be concat with + // high-dimensional tensor + if (inputs[i+batchCounter].numel() > 0) { + dimSize = at::native::size(inputs[i+batchCounter], dimension); + } + + stackInputs[batchCounter].input = + inputs[i+batchCounter].data_ptr(); + stackInputs[batchCounter].offset = offset; + stackInputs[batchCounter].dimSize = dimSize; + stackInputs[batchCounter].nElements = inputs[i+batchCounter].numel(); + + // update offset + offset += dimSize; + } + at::native::copy_(d_inputs_storage, stackInputs_storage, + /* non_blocking= */ true); + } + + // Next, let's consider how we set our kernel launch parameters. + // We borrow from THCApply, which the kernel's internal indexing + // is based on. + dim3 applyBlock = dim3(32*16); + + //Get grid where x dim fills half gpu and y dim is number of tensors. + //This will have cating two tensors fill the entire grid, but prevent + //many threads from needlessly load meta data if their sizes is small. + dim3 catGrid; + getCatGrid(batchCounter, catGrid); + + if (memory_format != c10::MemoryFormat::Contiguous) { + switch (dimension) { + case 0: + break; + case 1: + dimension = nDims - dimension; + break; + default: + dimension--; + } + } + // Template Declarations for dim = 1, 2, 3, 4 +#define HANDLE_CASE(DIMS) \ + HIP_CatArrayBatchedCopy<<<\ + catGrid, applyBlock, 0, stream.stream()>>>(\ + data, d_inputs, outputParam, dimension, outputParam.tensorStride[dimension]); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); + switch (nDims) { + case 1: + HANDLE_CASE(1); + break; + case 2: + HANDLE_CASE(2); + break; + case 3: + HANDLE_CASE(3); + break; + case 4: + HANDLE_CASE(4); + break; + } +#undef HANDLE_CASE + } +} + +template +void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension, + int nDims, c10::MemoryFormat memory_format) { + // First, let's set up our kernel parameters. We start with a raw pointer to + // the storage for the output Tensor. + scalar_t *data = out.data_ptr(); + CatArrInputTensorMetadata catMetaData; + TensorSizeStride outputParam; + + // Next, let's initialize the size, stride arrays for the output Tensor. + if (memory_format == c10::MemoryFormat::Contiguous) { + for (int i = 0; i < nDims; ++i) { + outputParam.tensorSize[i] = at::native::size(out, i); + outputParam.tensorStride[i] = out.stride(i); + } + } else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) { + // permute the semantics of dims from NCHW to NHWC so that the input + // tensor is now contiguous + outputParam.tensorSize[0] = at::native::size(out, 0); + outputParam.tensorStride[0] = out.stride(0); + for (int i = 1; i < nDims - 1; ++i) { + outputParam.tensorSize[i] = at::native::size(out, i + 1); + outputParam.tensorStride[i] = out.stride(i + 1); + } + outputParam.tensorSize[nDims - 1] = at::native::size(out, 1); + outputParam.tensorStride[nDims - 1] = out.stride(1); + } else { + TORCH_CHECK(false, "unsupported memory format"); + } + + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + + // Now we loop + int batchCounter = 0; + int64_t offset = 0; + for (int i = 0; i < inputs.size() ; i += batch_size) { for (batchCounter = 0; - batchCounter < CAT_ARRAY_BATCH_SIZE && + batchCounter < batch_size && (i+batchCounter) < inputs.size(); ++batchCounter) { - int64_t dimSize = at::native::size(inputs[i+batchCounter], dimension); + int64_t dimSize = 0; + // There is a legacy case where a 1-D empty tensor can be concat with + // high-dimensional tensor + if (inputs[i+batchCounter].numel() > 0) { + dimSize = at::native::size(inputs[i+batchCounter], dimension); + } catMetaData.input[batchCounter] = inputs[i+batchCounter].data_ptr(); catMetaData.offset[batchCounter] = offset; catMetaData.dimSize[batchCounter] = dimSize; catMetaData.nElements[batchCounter] = inputs[i+batchCounter].numel(); - + if (stride_size > 1) { + auto strides = inputs[i+batchCounter].strides(); + auto sizes = inputs[i+batchCounter].sizes(); + for(int j = 0; j < nDims; j++){ + catMetaData.tensorStride[batchCounter].tensorSize[j] = sizes[j]; + catMetaData.tensorStride[batchCounter].tensorStride[j] = strides[j]; + } + catMetaData.isContiguous[batchCounter] = false; + } else { + catMetaData.isContiguous[batchCounter] = true; + } // update offset offset += dimSize; } @@ -214,9 +402,10 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension, } // Template Declarations for dim = 1, 2, 3, 4 #define HANDLE_CASE(DIMS) \ - CatArrayBatchedCopy<<<\ + CatArrayBatchedCopy<<<\ catGrid, applyBlock, 0, stream.stream()>>>(\ - data, catMetaData, param, dimension, param.outputStride[dimension]); + data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); switch (nDims) { case 1: HANDLE_CASE(1); @@ -232,10 +421,8 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension, break; } #undef HANDLE_CASE - AT_CUDA_CHECK(cudaGetLastError()); } } - } // namespace Tensor cat_cuda(TensorList inputs, int64_t dimension) { @@ -275,7 +462,6 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) { auto should_skip = [](const Tensor &t) { return t.dim() == 1 && at::native::size(t, 0) == 0; }; - bool hasSkippedInput = false; const Tensor *notSkippedTensor = NULL; // non-owning reference int nDims = 0; @@ -296,10 +482,8 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) { } at::assert_no_internal_overlap(out); - for (int i = 0; i < inputs.size(); i++) - { + for (int i = 0; i < inputs.size(); i++) { if (should_skip(inputs[i])) { - hasSkippedInput = true; continue; } nDims = inputs[i].dim(); @@ -337,7 +521,12 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) { // Compute the size of the result size[dimension] = cat_dim_size; - out.resize_(size, memory_format); + + // skip resizing if size of result is same as expected + if (out.sizes() != size) { + out.resize_(size, memory_format); + } + if (out.numel() == 0) { return out; } @@ -345,11 +534,10 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) { // We parallelize the copy if all 6 conditions pass: // // 1. There is more than one input tensor - // 2. No empty inputs - // 3. The out tensor is 32-bit indexable - // 4. The number of dimensions is <= 4 - // 5. All input tensors are contiguous (output tensor may be non-contig) - // 6. All input tensors can use 32-bit indexing + // 2. The out tensor is 32-bit indexable + // 3. The number of dimensions is <= 4 + // 4. All input tensors are contiguous (output tensor may be non-contig) + // 5. All input tensors can use 32-bit indexing const bool all32BitIndexable = std::all_of(inputs.begin(), inputs.end(), [] (const Tensor& t) { @@ -365,20 +553,48 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) { return t.scalar_type() == firstType; }); allSameType = allSameType && (out.scalar_type() == firstType); + +#ifdef __HIP_PLATFORM_HCC__ if (inputs.size() > 1 && - !hasSkippedInput && out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS && at::cuda::detail::canUse32BitIndexMath(out) && allContiguous && all32BitIndexable && allSameType) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, out.scalar_type(), "cat_cuda", [&]() { - parallel_cat(out, inputs, dimension, nDims, memory_format); + hip_parallel_cat(out, inputs, dimension, nDims, memory_format); }); - +#else + // We support the contiguous inputs and non-contiguous input (<=4 dims) in different ways + // For contiguous input, we don't need to pass stride meta data to cuda kernel through constant + // memory. Therefore, we could pass more inputs to cuda threads. + // For non-contiguous, we reduce the number of inputs passed to cuda kernel due to the limitation + // of constant memory. + if (inputs.size() > 1 && + out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS && + at::cuda::detail::canUse32BitIndexMath(out) && + allContiguous && + all32BitIndexable && + allSameType) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, + out.scalar_type(), "cat_cuda", [&]() { + parallel_cat(out, inputs, dimension, nDims, memory_format); + }); + } else if (inputs.size() > 1 && + out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS && + at::cuda::detail::canUse32BitIndexMath(out) && + nDims <= CAT_ARRAY_MAX_INPUT_DIMS && + all32BitIndexable && + allSameType) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, + out.scalar_type(), "cat_cuda", [&]() { + parallel_cat(out, inputs, dimension, nDims, memory_format); + }); +#endif } else { int64_t offset = 0; for (int j = 0; j < inputs.size(); j++) diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu index ca00a3520f299..fb43dcb4c3c3d 100644 --- a/aten/src/ATen/native/cuda/SoftMax.cu +++ b/aten/src/ATen/native/cuda/SoftMax.cu @@ -709,32 +709,32 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t if (inner_size == 1) { dim3 grid(outer_size); AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "host_softmax", [&] { - using accscalar_t = acc_type; - if (!half_to_float) { - if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { - dispatch_softmax_forward( - output.data_ptr(), input.data_ptr(), dim_size, dim_size, outer_size); - } else { - constexpr int ILP = sizeof(float4) / sizeof(scalar_t); - dim3 block = SoftMax_getBlockSize(ILP, dim_size); - cunn_SoftMaxForward - <<>>( - output.data_ptr(), input.data_ptr(), dim_size - ); - } - } else { - if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { - dispatch_softmax_forward( - output.data_ptr(), input.data_ptr(), dim_size, dim_size, outer_size); + using accscalar_t = acc_type; + if (!half_to_float) { + if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { + dispatch_softmax_forward( + output.data_ptr(), input.data_ptr(), dim_size, dim_size, outer_size); + } else { + constexpr int ILP = sizeof(float4) / sizeof(scalar_t); + dim3 block = SoftMax_getBlockSize(ILP, dim_size); + cunn_SoftMaxForward + <<>>( + output.data_ptr(), input.data_ptr(), dim_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } } else { - constexpr int ILP = sizeof(float4) / sizeof(accscalar_t); - dim3 block = SoftMax_getBlockSize(ILP, dim_size); - cunn_SoftMaxForward - <<>>( - output.data_ptr(), input.data_ptr(), dim_size - ); + if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { + dispatch_softmax_forward( + output.data_ptr(), input.data_ptr(), dim_size, dim_size, outer_size); + } else { + constexpr int ILP = sizeof(float4) / sizeof(accscalar_t); + dim3 block = SoftMax_getBlockSize(ILP, dim_size); + cunn_SoftMaxForward + <<>>( + output.data_ptr(), input.data_ptr(), dim_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } } - } }); // This kernel runs in a 2D grid, where each application along y dimension has a fixed // outer_size, and runs in parallel over inner_size. Dimension x is parallel over outer_size. @@ -743,29 +743,28 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t uint32_t smem_size; dim3 grid, block; AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "host_softmax", [&] { - using accscalar_t = acc_type; - if (!half_to_float) { - SpatialSoftMax_getLaunchSizes( - &cunn_SpatialSoftMaxForward, - outer_size, dim_size, inner_size, - grid, block, smem_size); - cunn_SpatialSoftMaxForward - <<>>( - output.data_ptr(), input.data_ptr(), outer_size, dim_size, inner_size - ); - } else { - SpatialSoftMax_getLaunchSizes( - &cunn_SpatialSoftMaxForward, - outer_size, dim_size, inner_size, - grid, block, smem_size); - cunn_SpatialSoftMaxForward - <<>>( - output.data_ptr(), input.data_ptr(), outer_size, dim_size, inner_size - ); - } + using accscalar_t = acc_type; + if (!half_to_float) { + SpatialSoftMax_getLaunchSizes( + &cunn_SpatialSoftMaxForward, + outer_size, dim_size, inner_size, + grid, block, smem_size); + cunn_SpatialSoftMaxForward + <<>>( + output.data_ptr(), input.data_ptr(), outer_size, dim_size, inner_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + SpatialSoftMax_getLaunchSizes( + &cunn_SpatialSoftMaxForward, + outer_size, dim_size, inner_size, + grid, block, smem_size); + cunn_SpatialSoftMaxForward + <<>>( + output.data_ptr(), input.data_ptr(), outer_size, dim_size, inner_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } }); } - AT_CUDA_CHECK(cudaGetLastError()); } return output; } @@ -807,6 +806,7 @@ Tensor host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t <<>>( gI.data_ptr(), output.data_ptr(), grad.data_ptr(), dim_size ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } else { if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { @@ -819,6 +819,7 @@ Tensor host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t <<>>( gI.data_ptr(), output.data_ptr(), grad.data_ptr(), dim_size ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } }); @@ -826,33 +827,35 @@ Tensor host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t uint32_t smem_size; dim3 grid, block; AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, gI.scalar_type(), "host_softmax_backward", [&] { - using accscalar_t = acc_type; - if (!half_to_float) { - SpatialSoftMax_getLaunchSizes( - &cunn_SpatialSoftMaxBackward, - outer_size, dim_size, inner_size, - grid, block, smem_size); - - cunn_SpatialSoftMaxBackward - <<>>( - gI.data_ptr(), output.data_ptr(), grad.data_ptr(), - outer_size, dim_size, inner_size - ); - } else { - SpatialSoftMax_getLaunchSizes( - &cunn_SpatialSoftMaxBackward, - outer_size, dim_size, inner_size, - grid, block, smem_size); - - cunn_SpatialSoftMaxBackward - <<>>( - gI.data_ptr(), output.data_ptr(), grad.data_ptr(), - outer_size, dim_size, inner_size - ); - } + using accscalar_t = acc_type; + if (!half_to_float) { + SpatialSoftMax_getLaunchSizes( + &cunn_SpatialSoftMaxBackward, + outer_size, dim_size, inner_size, + grid, block, smem_size); + + cunn_SpatialSoftMaxBackward + <<>>( + gI.data_ptr(), output.data_ptr(), grad.data_ptr(), + outer_size, dim_size, inner_size + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + SpatialSoftMax_getLaunchSizes( + &cunn_SpatialSoftMaxBackward, + outer_size, dim_size, inner_size, + grid, block, smem_size); + + cunn_SpatialSoftMaxBackward + <<>>( + gI.data_ptr(), output.data_ptr(), grad.data_ptr(), + outer_size, dim_size, inner_size + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } }); } - AT_CUDA_CHECK(cudaGetLastError()); + return gI; } } diff --git a/aten/src/ATen/native/cuda/Sorting.cu b/aten/src/ATen/native/cuda/Sorting.cu new file mode 100644 index 0000000000000..0774143a83e2a --- /dev/null +++ b/aten/src/ATen/native/cuda/Sorting.cu @@ -0,0 +1,463 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // only for THCRoundUp? +#include +#include +#include // AddOp + +#include +#include + +namespace at { +namespace native { + +namespace { + +// Finds the rank k element, and its index, of the values along dimension dim +template +__global__ void gatherKthValue( + cuda::detail::TensorInfo input, + index_t inputSliceSize, + index_t k, + index_t numInputSlices, + index_t inputWithinSliceStride, + cuda::detail::TensorInfo kthValue, + cuda::detail::TensorInfo indices) { + // Indices are limited to integer fp precision, so counts can fit in + // int32, regardless of index_t + __shared__ int smem[C10_WARP_SIZE]; // one per each warp, up to warp limit + + index_t slice = getLinearBlockId(); + if (slice >= numInputSlices) { + return; + } + + // Find the start offset for our slice + index_t sliceStartIndex = + cuda::detail::IndexToOffset::get(slice, input); + index_t kthValueSliceStartIndex = + cuda::detail::IndexToOffset::get(slice, kthValue); + index_t indicesSliceStartIndex = + cuda::detail::IndexToOffset::get(slice, indices); + + scalar_t* inputSliceStart = &input.data[sliceStartIndex]; + scalar_t* kthValueSliceStart = &kthValue.data[kthValueSliceStartIndex]; + int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex]; + + // Find the k-th highest element in our input + scalar_t kValue = static_cast(0); + radixSelect< + scalar_t, + typename TopKTypeConfig::RadixType, + index_t, + false>( + inputSliceStart, + k, + inputSliceSize, + inputWithinSliceStride, + smem, + &kValue); + + // Find the index of the k-th highest element + index_t kValueIndex = 0; + bool foundKValue = false; + + for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) { + bool inRange = (i < inputSliceSize); + scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) + : static_cast(0); + bool isKValue = inRange && + ((v == kValue) || + (THCNumerics::isnan(v) && + THCNumerics::isnan(kValue))); + if (isKValue) { + kValueIndex = i; + foundKValue = true; + break; + } + } + + if (foundKValue) { + kthValueSliceStart[0] = kValue; + indicesSliceStart[0] = kValueIndex; + } +} + +// CUDA kernel to find the median, and its index, of the values along dimension dim +template +__global__ void gatherMedian( + cuda::detail::TensorInfo values, + cuda::detail::TensorInfo indices, + cuda::detail::TensorInfo input, + index_t inputSliceSize, + index_t numInputSlices, + index_t inputWithinSliceStride, + bool ignore_nan) { + // Shared memory for the subroutine RadixSelect. Note that RadixSelect converts the + // floating point type to int with the same relative ordering. + __shared__ int smem[C10_WARP_SIZE]; // one per each warp, up to warp limit + + index_t slice = getLinearBlockId(); + if (slice >= numInputSlices) { + return; + } + + // Finds the start offset for our slice + index_t valuesSliceStartIndex = + cuda::detail::IndexToOffset::get(slice, values); + index_t indicesSliceStartIndex = + cuda::detail::IndexToOffset::get(slice, indices); + index_t inputSliceStartIndex = + cuda::detail::IndexToOffset::get(slice, input); + + scalar_t* valuesSliceStart = &values.data[valuesSliceStartIndex]; + int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex]; + scalar_t* inputSliceStart = &input.data[inputSliceStartIndex]; + + index_t nan_count = 0; + for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) { + scalar_t val = doLdg(&inputSliceStart[i * inputWithinSliceStride]); + nan_count += THCNumerics::isnan(val) ? 1 : 0; + } + + // Counts number of nan values + // This code performs a parallel sum reduction (not the most efficient code) + __shared__ int64_t num_nan; + if (threadIdx.x == 0) { + num_nan = 0; + } + __syncthreads(); + if (nan_count > 0) { + atomicAdd(&num_nan, nan_count); + } + __syncthreads(); + + // For torch.median, if we found nan set k to last index so the computed value + // is nan, otherwise set k to the middle element of the non-nan values + index_t k = (!ignore_nan && num_nan > 0) ? inputSliceSize - 1 + : (inputSliceSize - num_nan - 1) / 2; + + // Find the median + scalar_t median = static_cast(0); + radixSelect< + scalar_t, + typename TopKTypeConfig::RadixType, + index_t, + false>( + inputSliceStart, + k + 1, + inputSliceSize, + inputWithinSliceStride, + smem, + &median); + + valuesSliceStart[0] = median; + + // Find the index of the median value in the slice + for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) { + scalar_t val = doLdg(&inputSliceStart[i * inputWithinSliceStride]); + if (val == median || + (THCNumerics::isnan(val) && + THCNumerics::isnan(median))) { + indicesSliceStart[0] = i; + break; + } + } +} + +struct KthValueLauncher { + int64_t k; + + KthValueLauncher(int64_t k) : k(k) {} + + template + inline void launch( + cuda::detail::TensorInfo values_info, + int collapse_values_dim, + cuda::detail::TensorInfo indices_info, + int collapse_indices_dim, + cuda::detail::TensorInfo self_info, + int collapse_self_dim, + int64_t num_slices, + int64_t slice_size) { + dim3 grid; + if (!getGridFromTiles(num_slices, grid)) { + AT_ERROR("slices are too many"); + } + + dim3 block(std::min( + THCRoundUp(slice_size, (int64_t)C10_WARP_SIZE), (int64_t)1024)); + auto stream = at::cuda::getCurrentCUDAStream(); + gatherKthValue<<>>( + self_info, + slice_size, + k, + num_slices, + /* The actual dimension that the k-selection is running in */ + /* may have changed from collapseDims() */ + self_info.strides[collapse_self_dim], + values_info, + indices_info); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +}; + +struct MedianLauncher { + bool ignore_nan; + + MedianLauncher(bool ignore_nan) : ignore_nan(ignore_nan) {} + + template + inline void launch( + cuda::detail::TensorInfo values_info, + int collapse_values_dim, + cuda::detail::TensorInfo indices_info, + int collapse_indices_dim, + cuda::detail::TensorInfo self_info, + int collapse_self_dim, + int64_t num_slices, + int64_t slice_size) { + dim3 grid; + if (!getGridFromTiles(num_slices, grid)) { + AT_ERROR("slices are too many"); + } + + dim3 block(std::min( + THCRoundUp(slice_size, (int64_t)C10_WARP_SIZE), (int64_t)1024)); + auto stream = at::cuda::getCurrentCUDAStream(); + gatherMedian<<>>( + values_info, + indices_info, + self_info, + slice_size, + num_slices, + self_info.strides[collapse_self_dim], + ignore_nan); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +}; + +template +void kthvalue_cuda_template( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t k, + int64_t dim_, + bool keepdim) { + int64_t dim = maybe_wrap_dim(dim_, self.dim()); + int64_t slicesize = self.dim() == 0 ? 1 : self.size(dim); + // FIXME: This seems bogus, I only do this because it was the old behaviour. + // The reductions are fine, as long as the axis being reduced along + // isn't of 0 elements (and the output has elements). + TORCH_CHECK( + self.numel() > 0, + "cannot perform reduction function kthvalue", + " on tensor with no elements because the operation does not have an identity"); + TORCH_CHECK(k >= 1 && k <= slicesize, "selected number k out of range"); + + at::assert_no_overlap(self, values); + + _reduction_with_indices_allocate_or_resize_output( + values, indices, self, dim, keepdim); + if (self.dim() == 0 && self.numel() == 1) { + values.copy_(self); + indices.zero_(); + return; + } + + TORCH_CHECK( + self.dim() <= MAX_TENSORINFO_DIMS, + "cannot operate on more than ", + MAX_TENSORINFO_DIMS, + " dimensions"); + + // Based on required index size, run the algorithm with the + // appropriate index type + if (cuda::detail::canUse32BitIndexMath(self) && + cuda::detail::canUse32BitIndexMath(values) && + cuda::detail::canUse32BitIndexMath(indices)) { + run_launcher( + values, indices, self, dim, KthValueLauncher(k)); + } else { + run_launcher( + values, indices, self, dim, KthValueLauncher(k)); + } + + if (!keepdim) { + values.squeeze_(dim); + indices.squeeze_(dim); + } +} + +std::tuple kthvalue_out_impl_cuda( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t k, + int64_t dim, + bool keepdim) { + AT_DISPATCH_ALL_TYPES_AND( + at::ScalarType::Half, self.scalar_type(), "kthvalue_cuda", [&] { + kthvalue_cuda_template( + values, indices, self, k, dim, keepdim); + }); + return std::forward_as_tuple(values, indices); +} + +std::tuple median_with_indices_impl( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t dim, + bool keepdim, + bool ignore_nan) { + // See note [Writing Nondeterministic Operations] + // If there are duplicate elements of a median value, the procedure for choosing which + // of the duplicates to use for the indices output is nondeterministic. + at::globalContext().alertNotDeterministic("median CUDA with indices output"); + NoNamesGuard guard; + + dim = at::maybe_wrap_dim(dim, self.dim()); + Tensor in = self.dim() > 0 ? self.contiguous() : self.unsqueeze(0); + + int64_t size = in.size(dim); + TORCH_CHECK( + size > 0, + "median() cannot compute median for a dimension of size 0 because ", + "the operation does not have an identity"); + + checkDeviceType("median", {values, indices}, self.device().type()); + checkScalarType("median", {indices, "indices", 1}, kLong); + checkSameType("median", {values, "values", 0}, {self, "self", 2}); + + TORCH_CHECK( + self.dim() <= MAX_TENSORINFO_DIMS, + "median() cannot operate on more than ", + MAX_TENSORINFO_DIMS, + " dimensions"); + + std::vector out_shape = self.sizes().vec(); + if (self.dim() > 0) { + if (keepdim) { + out_shape[dim] = 1; + } else { + out_shape.erase(out_shape.begin() + dim); + } + } + + values.resize_(out_shape); + indices.resize_(out_shape); + + // Only launch kernel for non-empty tensors + if (self.numel() > 0) { + // Ensure #dim is the same for all tensors required for reduction + Tensor vals = keepdim && self.dim() > 0 ? values : values.unsqueeze(dim); + Tensor inds = keepdim && self.dim() > 0 ? indices : indices.unsqueeze(dim); + + AT_DISPATCH_ALL_TYPES_AND( + at::ScalarType::Half, self.scalar_type(), "median_out_impl", [&] { + if (cuda::detail::canUse32BitIndexMath(vals) && + cuda::detail::canUse32BitIndexMath(inds) && + cuda::detail::canUse32BitIndexMath(in)) { + run_launcher( + vals, inds, in, dim, MedianLauncher(ignore_nan)); + } else { + run_launcher( + vals, inds, in, dim, MedianLauncher(ignore_nan)); + } + }); + } + + guard.reset(); + namedinference::propagate_names_for_reduction(values, self, dim, keepdim); + namedinference::propagate_names_for_reduction(indices, self, dim, keepdim); + + return std::forward_as_tuple(values, indices); +} + +Tensor median_impl(const Tensor& self, bool ignore_nan) { + NoNamesGuard guard; + + int64_t size = self.numel(); + TORCH_CHECK(size > 0, "median() input tensor cannot be empty"); + + // Sort input tensor to efficiently query for median element + Tensor sorted = std::get<0>(self.flatten().sort()); + + if (!ignore_nan) { + // For torch.median return either the middle element or nan (sorted as + // largest) if there are any + int64_t k = (size - 1) / 2; + return at::where(sorted[-1].isnan(), sorted[-1], sorted[k]); + } else { + // For torch.nanmedian return the middle element among the non-nan values + Tensor k = ((size - 1) - sorted.isnan().sum()) / 2; + return sorted[k.toType(kLong)]; + } +} + +} // namespace + +// Mark: kthvalue + +std::tuple kthvalue_out_cuda( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t k, + int64_t dim, + bool keepdim) { + // See note [Writing Nondeterministic Operations] + // If there are duplicate elements of the kth value, the procedure for choosing which + // of the duplicates to use for the indices output is nondeterministic. + at::globalContext().alertNotDeterministic("kthvalue CUDA"); + auto result = [&]() { + NoNamesGuard guard; + // `kthvalue_out_impl_cuda` expects contiguous in input `self`. + return kthvalue_out_impl_cuda(values, indices, self.contiguous(), k, dim, keepdim); + }(); + namedinference::propagate_names_for_reduction(values, self, dim, keepdim); + namedinference::propagate_names_for_reduction(indices, self, dim, keepdim); + return result; +} + +// Mark: median + +std::tuple median_out_cuda( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t dim, + bool keepdim) { + return median_with_indices_impl( + values, indices, self, dim, keepdim, /*ignore_nan=*/false); +} + +Tensor median_cuda(const Tensor& self) { + return median_impl(self, /*ignore_nan=*/false); +} + +std::tuple nanmedian_out_cuda( + Tensor& values, + Tensor& indices, + const Tensor& self, + int64_t dim, + bool keepdim) { + return median_with_indices_impl( + values, indices, self, dim, keepdim, /*ignore_nan=*/true); +} + +Tensor nanmedian_cuda(const Tensor& self) { + return median_impl(self, /*ignore_nan=*/true); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/SortingCommon.cuh b/aten/src/ATen/native/cuda/SortingCommon.cuh index 54513955e9127..0e5cb7371d585 100644 --- a/aten/src/ATen/native/cuda/SortingCommon.cuh +++ b/aten/src/ATen/native/cuda/SortingCommon.cuh @@ -143,6 +143,7 @@ static uint64_t nextHighestPowerOf2(uint64_t n) { } +// WARNING: This function assumes input tensors are contiguous template void run_launcher( Tensor& values, diff --git a/aten/src/ATen/native/cuda/SortingKthValue.cu b/aten/src/ATen/native/cuda/SortingKthValue.cu deleted file mode 100644 index 953239197a668..0000000000000 --- a/aten/src/ATen/native/cuda/SortingKthValue.cu +++ /dev/null @@ -1,259 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include // only for THCRoundUp? -#include -#include -#include // AddOp - -#include -#include -#include - -namespace at { -namespace native { - -namespace { - -template -__global__ void gatherKthValue( - cuda::detail::TensorInfo input, - index_t inputSliceSize, - index_t k, - - index_t numInputSlices, - index_t inputWithinSliceStride, - - cuda::detail::TensorInfo kthValue, - cuda::detail::TensorInfo indices) { - // Indices are limited to integer fp precision, so counts can fit in - // int32, regardless of index_t - __shared__ int smem[C10_WARP_SIZE]; // one per each warp, up to warp limit - - index_t slice = getLinearBlockId(); - if (slice >= numInputSlices) { - return; - } - - // Find the start offset for our slice - index_t sliceStartIndex = - cuda::detail::IndexToOffset::get(slice, input); - index_t kthValueSliceStartIndex = - cuda::detail::IndexToOffset::get(slice, kthValue); - index_t indicesSliceStartIndex = - cuda::detail::IndexToOffset::get(slice, indices); - - scalar_t* inputSliceStart = &input.data[sliceStartIndex]; - scalar_t* kthValueSliceStart = &kthValue.data[kthValueSliceStartIndex]; - int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex]; - - // Find the k-th highest element in our input - scalar_t kValue = static_cast(0); - radixSelect< - scalar_t, - typename TopKTypeConfig::RadixType, - index_t, - false>( - inputSliceStart, - k, - inputSliceSize, - inputWithinSliceStride, - smem, - &kValue); - - // Find the index of the k-th highest element - index_t kValueIndex = 0; - bool foundKValue = false; - - for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) { - bool inRange = (i < inputSliceSize); - scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) - : static_cast(0); - bool isKValue = inRange && - ((v == kValue) || - (THCNumerics::isnan(v) && - THCNumerics::isnan(kValue))); - if (isKValue) { - kValueIndex = i; - foundKValue = true; - break; - } - } - - if (foundKValue) { - kthValueSliceStart[0] = kValue; - indicesSliceStart[0] = kValueIndex; - } -} - -struct KthValueLauncher { - int64_t k; - - KthValueLauncher(int64_t k) : k(k) {} - - template - inline void launch( - cuda::detail::TensorInfo values_info, - int collapse_values_dim, - cuda::detail::TensorInfo indices_info, - int collapse_indices_dim, - cuda::detail::TensorInfo self_info, - int collapse_self_dim, - int64_t num_slices, - int64_t slice_size) { - dim3 grid; - if (!getGridFromTiles(num_slices, grid)) { - AT_ERROR("slices are too many"); - } - - dim3 block( - std::min(THCRoundUp(slice_size, (int64_t)C10_WARP_SIZE), (int64_t)1024)); - auto stream = at::cuda::getCurrentCUDAStream(); - gatherKthValue<<>>( - self_info, - slice_size, - k, - num_slices, - /* The actual dimension that the k-selection is running in */ - /* may have changed from collapseDims() */ - self_info.strides[collapse_self_dim], - values_info, - indices_info); - } -}; - -template -void kthvalue_cuda_template( - Tensor& values, - Tensor& indices, - const Tensor& self, - int64_t k, - int64_t dim_, - bool keepdim) { - int64_t dim = maybe_wrap_dim(dim_, self.dim()); - int64_t slicesize = self.size(dim); - // FIXME: This seems bogus, I only do this because it was the old behaviour. - // The reductions are fine, as long as the axis being reduced along - // isn't of 0 elements (and the output has elements). - TORCH_CHECK( - self.numel() > 0, - "cannot perform reduction function kthvalue", - " on tensor with no elements because the operation does not have an identity"); - TORCH_CHECK(k >= 1 && k <= slicesize, "selected number k out of range"); - - _reduction_with_indices_allocate_or_resize_output( - values, indices, self, dim, keepdim); - if (self.dim() == 0 && self.numel() == 1) { - values.copy_(self); - indices.zero_(); - return; - } - - TORCH_CHECK( - self.dim() <= MAX_TENSORINFO_DIMS, - "cannot operate on more than ", - MAX_TENSORINFO_DIMS, - " dimensions"); - - // Based on required index size, run the algorithm with the - // appropriate index type - if (cuda::detail::canUse32BitIndexMath(self) && - cuda::detail::canUse32BitIndexMath(values) && - cuda::detail::canUse32BitIndexMath(indices)) { - run_launcher( - values, indices, self, dim, KthValueLauncher(k)); - } else { - run_launcher( - values, indices, self, dim, KthValueLauncher(k)); - } - - if (!keepdim) { - values.squeeze_(dim); - indices.squeeze_(dim); - } - - AT_CUDA_CHECK(cudaGetLastError()); -} - -// this does not reduce to median with dim because we don't want to copy twice -template -Tensor median_cuda_template(const Tensor& self) { - TORCH_CHECK(self.numel() > 0, "median cannot be called with empty tensor"); - if (self.dim() == 0 && self.numel() == 1) { - return self.clone(at::MemoryFormat::Contiguous); - } - auto self_copy = self.clone(at::MemoryFormat::Contiguous).view(-1); - auto values = at::empty({1}, self.options()); - auto indices = at::empty({1}, self.options().dtype(kLong)); - TORCH_CHECK( - self.dim() <= MAX_TENSORINFO_DIMS, - "cannot operate on more than ", - MAX_TENSORINFO_DIMS, - " dimensions"); - - // Based on required index size, run the algorithm with the - // appropriate index type - if (cuda::detail::canUse32BitIndexMath(self) && - cuda::detail::canUse32BitIndexMath(values) && - cuda::detail::canUse32BitIndexMath(indices)) { - run_launcher( - values, - indices, - self_copy, - 0, - KthValueLauncher((self_copy.size(0) + 1) / 2)); // KthValue is 1-based - } else { - run_launcher( - values, - indices, - self_copy, - 0, - KthValueLauncher((self_copy.size(0) + 1) / 2)); // KthValue is 1-based - } - return values.view({}); -} - -} // namespace - -static std::tuple kthvalue_out_impl_cuda( - Tensor& values, - Tensor& indices, - const Tensor& self, - int64_t k, - int64_t dim, - bool keepdim) { - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), "kthvalue_cuda", [&] { - kthvalue_cuda_template(values, indices, self, k, dim, keepdim); - }); - return std::forward_as_tuple(values, indices); -} - -std::tuple kthvalue_out_cuda( - Tensor& values, - Tensor& indices, - const Tensor& self, - int64_t k, - int64_t dim, - bool keepdim) { - auto result = [&]() { - NoNamesGuard guard; - return kthvalue_out_impl_cuda(values, indices, self, k, dim, keepdim); - }(); - namedinference::propagate_names_for_reduction(values, self, dim, keepdim); - namedinference::propagate_names_for_reduction(indices, self, dim, keepdim); - return result; -} - -Tensor median_cuda(const Tensor& self) { - NoNamesGuard guard; - return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), "median", [&] { - return median_cuda_template(self); - }); -} - -} // namespace native -} // namespace at diff --git a/aten/src/ATen/native/cuda/SpectralOps.cu b/aten/src/ATen/native/cuda/SpectralOps.cu index bf7aac20815e4..e5e91cea4cccd 100644 --- a/aten/src/ATen/native/cuda/SpectralOps.cu +++ b/aten/src/ATen/native/cuda/SpectralOps.cu @@ -4,7 +4,11 @@ #include #include #include +#include +#include #include +#include +#include #include #include #include @@ -22,116 +26,165 @@ namespace at { namespace native { using namespace at::native::detail; -// In real-to-complex transform, cuFFT only fills half of the values due to -// conjugate symmetry. See native/SpectralUtils.h for more details. -// The following structs are used to fill in the other half with symmetry in -// case of real-to-complex transform with onesided=False flag. -// See NOTE [ Fourier Transform Conjugate Symmetry ] in native/SpectralOpsUtils.h. +// Offset calculator for indexing in Hermitian mirrored order. +// In mirrored dims, maps linear index i to (n - i) % n +template +struct HermitianSymmetryOffsetCalculator { + using offset_type = at::detail::Array; + using dim_type = std::remove_cv_t; + dim_type dims; + IntDivider sizes_[MAX_DIMS]; + index_t strides_[MAX_DIMS]; + uint32_t mirror_dim_; // bit mask + static_assert(MAX_DIMS < 32, "Need a bigger mask type"); + + HermitianSymmetryOffsetCalculator( + IntArrayRef sizes, IntArrayRef strides, IntArrayRef dim, + const int64_t element_size){ + TORCH_INTERNAL_ASSERT(sizes.size() == strides.size()); + TORCH_INTERNAL_ASSERT(sizes.size() <= MAX_DIMS); + dims = sizes.size(); + + for (dim_type i = 0; i < MAX_DIMS; ++i) { + if (i < dims) { + sizes_[i] = IntDivider(sizes[i]); + strides_[i] = strides[i] / element_size; + } else { + sizes_[i] = IntDivider(1); + strides_[i] = 0; + } + } -// counting_iterator => index to fill -struct cnt_to_dst_idx_functor : public thrust::unary_function -{ - int64_t last_dim_size; - int64_t last_dim_start_slice; - int64_t last_dim_to_fill_size; + mirror_dim_ = 0; + for (int64_t i = 0; i < dim.size(); ++i) { + mirror_dim_ |= (uint32_t{1} << dim[i]); + } + } - cnt_to_dst_idx_functor(int64_t last_dim_size, int64_t last_dim_start_slice) : - last_dim_size(last_dim_size), last_dim_start_slice(last_dim_start_slice), - last_dim_to_fill_size(last_dim_size - last_dim_start_slice) {} + C10_HOST_DEVICE offset_type get(index_t linear_idx) const { + index_t offset = 0; - // HIP wants __host__ __device__ tag, CUDA does not -#ifdef __HIP_PLATFORM_HCC__ - __host__ __device__ -#endif - cnt_to_dst_idx_functor & operator=(const cnt_to_dst_idx_functor&) = default; + for (dim_type dim = 0; dim < dims; ++dim) { + auto divmod = sizes_[dim].divmod(linear_idx); + linear_idx = divmod.div; - __host__ __device__ __forceinline__ - int64_t operator()(const int64_t& i) const - { - int64_t imag = i % 2; - int64_t idx = i / 2; - int64_t num_dim = idx / last_dim_to_fill_size; - int64_t slice_idx = idx % last_dim_to_fill_size; - return (num_dim * last_dim_size + last_dim_start_slice + slice_idx) * 2 + imag; + if ((mirror_dim_ & (uint32_t{1} << dim)) == 0) { + offset += divmod.mod * strides_[dim]; + } else if (divmod.mod != 0) { + offset += (sizes_[dim].divisor - divmod.mod) * strides_[dim]; + } + } + offset_type offsets; + offsets[0] = offset; + return offsets; } }; -// index to fill => index to read from -template -struct dst_idx_to_src_functor : public thrust::unary_function -{ - // output can have at most dim 5 (batch + 3 signal dim + real/imag) - int64_t sizes[max_rank + 2], strides[max_rank + 2]; - const int64_t signal_ndim; - scalar_t *data; // device ptr - - dst_idx_to_src_functor(const Tensor& batched_complex_signal) - : signal_ndim(batched_complex_signal.dim() - 1), - data(batched_complex_signal.data_ptr()) { - for (int64_t i = 0; i < signal_ndim; i++) { - sizes[i] = batched_complex_signal.size(i); - strides[i] = batched_complex_signal.stride(i); - } +// out[:] = conj(in[:]) where in and out ordering is generalized by offset calculators +template +C10_LAUNCH_BOUNDS_1(cuda::detail::CUDA_NUM_THREADS) +__global__ void _fft_conjugate_copy_kernel( + int64_t numel, scalar_t * out_data, const scalar_t * in_data, + inp_calc_t ic, out_calc_t oc) { + CUDA_KERNEL_LOOP_TYPE(index, numel, int64_t) { + auto in_offset = ic.get(index)[0]; + auto out_offset = oc.get(index)[0]; + out_data[out_offset] = std::conj(in_data[in_offset]); } +} - __device__ __forceinline__ - scalar_t operator()(const int64_t& write_idx_with_imag) const - { - int64_t imag = write_idx_with_imag % 2; - // all but first (batch) and last (real/imag) dims need to be reflected - int64_t read_idx = 0; - int64_t remainder = write_idx_with_imag - imag; - int64_t dim_idx, dim_stride; - for (int64_t i = 0; i < signal_ndim; i++) { - dim_stride = strides[i]; - dim_idx = remainder / dim_stride; - if (i == 0) { - read_idx += dim_idx * dim_stride; - } else if (dim_idx != 0) { - read_idx += (sizes[i] - dim_idx) * dim_stride; - } - remainder = remainder % dim_stride; - } - if (imag) { - return -data[read_idx + 1]; - } else { - return data[read_idx]; - } - } -}; +// In real-to-complex transform, cuFFT only fills half of the values due to +// conjugate symmetry. See native/SpectralUtils.h for more details. +// The following function fills in the other half with symmetry in +// case of real-to-complex transform with onesided=False flag. +// See NOTE [ Fourier Transform Conjugate Symmetry ] in native/SpectralOpsUtils.h. -// input should be a contiguous batched tensor of same size as full (twosided) +// input should be a tensor of same size as full (twosided) // signals, but only contains half (onesided) of the values. // This function modifies inplace. -__forceinline__ -static void _fft_fill_with_conjugate_symmetry_(Tensor& input, - int64_t size_last_dim, int64_t last_dim_start_slice) { - if (last_dim_start_slice >= size_last_dim) { - return; - } +void _fft_fill_with_conjugate_symmetry_cuda_( + ScalarType dtype, IntArrayRef mirror_dims, IntArrayRef signal_half_sizes, + IntArrayRef in_strides, const void * in_data, + IntArrayRef out_strides, void * out_data) { + // Do the actual conjugate mirroring. + // TODO: consider adding a 32bit indexed kernel for improved performance + auto* in_strides_ptr = in_strides.data(); + const int ndim = in_strides.size(); + const int64_t element_size = scalarTypeToTypeMeta(dtype).itemsize(); + OffsetCalculator<1, int64_t> input_offset_calculator( + ndim, signal_half_sizes.data(), &in_strides_ptr, &element_size); + HermitianSymmetryOffsetCalculator output_offset_calculator( + signal_half_sizes, out_strides, mirror_dims, element_size); + + const auto numel = at::prod_intlist(signal_half_sizes); + AT_DISPATCH_COMPLEX_TYPES(dtype, "_fft_fill_with_conjugate_symmetry", [&] { + using namespace cuda::detail; + _fft_conjugate_copy_kernel<<< + GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>( + numel, + static_cast(out_data), + static_cast(in_data), + input_offset_calculator, + output_offset_calculator); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +REGISTER_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cuda_); - // copy - int64_t n = input.numel() / size_last_dim * (size_last_dim - last_dim_start_slice); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); - auto policy = thrust::cuda::par(allocator).on(stream); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "_fft_fill_with_conjugate_symmetry_", [&] { - typedef thrust::device_ptr device_ptr; - typedef thrust::counting_iterator counter; - typedef thrust::transform_iterator dst_idx_iterator; - typedef thrust::permutation_iterator dst_iterator; - typedef thrust::transform_iterator, dst_idx_iterator> src_iterator; - - dst_idx_iterator dst_idxs(counter(0), cnt_to_dst_idx_functor(size_last_dim, last_dim_start_slice)); - - auto data = device_ptr(input.data_ptr()); - dst_iterator dsts(data, dst_idxs); - src_iterator srcs(dst_idxs, dst_idx_to_src_functor(input)); - thrust::copy_n(policy, srcs, n, dsts); - }); +// Execute a pre-planned tranform +static void exec_cufft_plan( + const CuFFTConfig &config, void* in_data, void* out_data, bool forward) { + auto& plan = config.plan(); +#ifdef __HIP_PLATFORM_HCC__ + auto value_type = config.data_type(); + if (value_type == kFloat) { + switch (config.transform_type()) { + case CuFFTTransformType::C2C: { + CUFFT_CHECK(hipfftExecC2C(plan, static_cast(in_data), + static_cast(out_data), + forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD)); + return; + } + case CuFFTTransformType::R2C: { + CUFFT_CHECK(hipfftExecC2R(plan, static_cast(in_data), + static_cast(out_data))); + return; + } + case CuFFTTransformType::C2R: { + CUFFT_CHECK(hipfftExecR2C(plan, static_cast(in_data), + static_cast(out_data))); + return; + } + } + } else if (value_type == kDouble) { + switch (config.transform_type()) { + case CuFFTTransformType::C2C: { + CUFFT_CHECK(hipfftExecZ2Z(plan, static_cast(in_data), + static_cast(out_data), + forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD)); + return; + } + case CuFFTTransformType::R2C: { + CUFFT_CHECK(hipfftExecD2Z(plan, static_cast(in_data), + static_cast(out_data))); + return; + } + case CuFFTTransformType::C2R: { + CUFFT_CHECK(hipfftExecZ2D(plan, static_cast(in_data), + static_cast(out_data))); + return; + } + } + } + TORCH_CHECK(false, "hipFFT doesn't support transforms on type: ", value_type); +#else + CUFFT_CHECK(cufftXtExec(plan, in_data, out_data, + forward ? CUFFT_FORWARD : CUFFT_INVERSE)); +#endif } + // NOTE [ cuFFT Embedded Strides ] // // cuFFT supports a subset of arbitrary strides via their "advanced data layout" @@ -195,45 +248,7 @@ static inline Tensor _run_cufft( CUFFT_CHECK(cufftSetWorkArea(plan, ws.data_ptr())); // run -#ifdef __HIP_PLATFORM_HCC__ - if (input.scalar_type() == ScalarType::Float) { - if (complex_input && complex_output) { - CUFFT_CHECK(hipfftExecC2C(plan, static_cast(input.data_ptr()), - static_cast(output.data_ptr()), - inverse ? HIPFFT_BACKWARD : HIPFFT_FORWARD)); - } else if (complex_input && !complex_output) { - CUFFT_CHECK(hipfftExecC2R(plan, static_cast(input.data_ptr()), - static_cast(output.data_ptr()))); - } else if (!complex_input && complex_output) { - CUFFT_CHECK(hipfftExecR2C(plan, static_cast(input.data_ptr()), - static_cast(output.data_ptr()))); - } else { - AT_ERROR("hipFFT doesn't support r2r (float)"); - } - } else if (input.scalar_type() == ScalarType::Double) { - if (complex_input && complex_output) { - CUFFT_CHECK(hipfftExecZ2Z(plan, static_cast(input.data_ptr()), - static_cast(output.data_ptr()), - inverse ? HIPFFT_BACKWARD : HIPFFT_FORWARD)); - } else if (complex_input && !complex_output) { - CUFFT_CHECK(hipfftExecZ2D(plan, static_cast(input.data_ptr()), - static_cast(output.data_ptr()))); - } else if (!complex_input && complex_output) { - CUFFT_CHECK(hipfftExecD2Z(plan, static_cast(input.data_ptr()), - static_cast(output.data_ptr()))); - } else { - AT_ERROR("hipFFT doesn't support r2r (double)"); - } - } else { - std::ostringstream ss; - ss << "hipFFT doesn't support tensor of type: " - << toString(input.scalar_type()); - AT_ERROR(ss.str()); - } -#else - CUFFT_CHECK(cufftXtExec(plan, input.data_ptr(), output.data_ptr(), - inverse ? CUFFT_INVERSE : CUFFT_FORWARD)); -#endif + exec_cufft_plan(config, input.data_ptr(), output.data_ptr(), !inverse); // rescale if requested auto size_last_signal_dim = checked_signal_sizes[signal_ndim - 1]; @@ -255,8 +270,10 @@ static inline Tensor _run_cufft( // if needed, fill out the other half using conjugate symmetry if (!complex_input && complex_output && !onesided) { - auto start_slice = infer_ft_real_to_complex_onesided_size(size_last_signal_dim); - _fft_fill_with_conjugate_symmetry_(output, size_last_signal_dim, start_slice); + DimVector signal_dims(signal_ndim); + std::iota(signal_dims.begin(), signal_dims.end(), 1); + auto out_as_complex = at::view_as_complex(output); + at::native::_fft_fill_with_conjugate_symmetry_(out_as_complex, signal_dims); } return output; } @@ -320,74 +337,296 @@ void cufft_clear_plan_cache_impl(int64_t device_index) { } // namespace at::native::detail -// cuFFT -// Currently not utilizing multi GPUs so this can be potentially sped up. -Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim, - bool complex_input, bool complex_output, bool inverse, - IntArrayRef checked_signal_sizes, int64_t normalization, bool onesided, - IntArrayRef output_sizes) { - - CuFFTParamsLRUCache& plan_cache = cufft_get_plan_cache(self.device().index()); - - Tensor input = self; - bool input_was_cloned = false; - - // Slice when twosided complex-to-real. This is not always needed because we - // calculate the inembed. But it will benefit us in certain cases where we - // clone the input tensor. - // - // See NOTE [ cuFFT Embedded Strides ]. - // See NOTE [ Fourier Transform Conjugate Symmetry ] in native/SpectralOpsUtils.h. - if (complex_input && !complex_output && !onesided) { - auto onesided_size = infer_ft_real_to_complex_onesided_size(checked_signal_sizes[signal_ndim - 1]); - input = input.narrow(signal_ndim, 0, onesided_size); - } +namespace { +constexpr int64_t cufft_max_ndim = 3; - // cuFFT requires input and output data pointers to complex type aligned. - // Our newly allocated output tensor is always 512 bytes aligned so it is fine - // (see kRoundSmall and kRoundLarge in THCCachingAllocator.cpp), but we do - // need to check input tensor to make sure that it is not unaligned, e.g., - // from a slicing. - auto complex_size_bytes = 2 * input.element_size(); - if (reinterpret_cast(input.data_ptr()) % complex_size_bytes != 0) { - input = input.clone(at::MemoryFormat::Contiguous); - input_was_cloned = true; +// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r) +static Tensor& _exec_fft(Tensor& out, const Tensor& self, IntArrayRef out_sizes, + IntArrayRef dim, bool forward) { + const auto ndim = self.dim(); + const int64_t signal_ndim = dim.size(); + const auto batch_dims = ndim - signal_ndim; + + // Permute dimensions so batch dimensions come first, and in stride order + // This maximizes data locality when collapsing to a single batch dimension + DimVector dim_permute(ndim); + std::iota(dim_permute.begin(), dim_permute.end(), int64_t{0}); + + c10::SmallVector is_transformed_dim(ndim); + for (const auto& d : dim) { + is_transformed_dim[d] = true; + } + auto batch_end = std::partition(dim_permute.begin(), dim_permute.end(), + [&](int64_t d) {return !is_transformed_dim[d]; }); + auto self_strides = self.strides(); + std::sort(dim_permute.begin(), batch_end, + [&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; }); + std::copy(dim.cbegin(), dim.cend(), batch_end); + auto input = self.permute(dim_permute); + + // Collapse batch dimensions into a single dimension + DimVector batched_sizes(signal_ndim + 1); + batched_sizes[0] = -1; + std::copy(input.sizes().cbegin() + batch_dims, input.sizes().cend(), batched_sizes.begin() + 1); + input = input.reshape(batched_sizes); + + const auto batch_size = input.sizes()[0]; + DimVector signal_size(signal_ndim + 1); + signal_size[0] = batch_size; + for (int64_t i = 0; i < signal_ndim; ++i) { + auto in_size = input.sizes()[i + 1]; + auto out_size = out_sizes[dim[i]]; + signal_size[i + 1] = std::max(in_size, out_size); + TORCH_INTERNAL_ASSERT(in_size == signal_size[i + 1] || + in_size == (signal_size[i + 1] / 2) + 1); + TORCH_INTERNAL_ASSERT(out_size == signal_size[i + 1] || + out_size == (signal_size[i + 1] / 2) + 1); } - // Now that we have done error check and data_ptr checks, we delegate all - // further cuFFT parameter computation and plan creation to the helper class - // CuFFTConfig in CuFFTPlanCache.h. + batched_sizes[0] = batch_size; + DimVector batched_out_sizes(batched_sizes.begin(), batched_sizes.end()); + for (size_t i = 0; i < dim.size(); ++i) { + batched_out_sizes[i + 1] = out_sizes[dim[i]]; + } + out.resize_(batched_out_sizes, MemoryFormat::Contiguous); - // If plan caching is enabled, we check the cache. Note that this accesses - // plan_cache.max_size() and thus makes this function less functional. - // However, integrating additional arguments into the "public" level c++ APIs, - // e.g., irfft, is difficult as we have a long call sequence looking like - // irfft --> _fft --> _fft_with_size --dispatching-to-> _fft_cufft + // Create the transform plan (either from cache or locally) + const auto value_type = c10::toValueType(input.scalar_type()); + auto fft_type = GetCuFFTTransformType(input.is_complex(), out.is_complex()); + CuFFTParams Params(input.strides(), out.strides(), signal_size, fft_type, value_type); + CuFFTParamsLRUCache& plan_cache = cufft_get_plan_cache(input.device().index()); + std::unique_lock guard(plan_cache.mutex, std::defer_lock); + c10::optional uncached_plan; + const CuFFTConfig * config = nullptr; - // This read is not locked for perf reason. Shouldn't matter too much because - // we check again after acquiring the lock. if (plan_cache.max_size() > 0) { - CuFFTParams params; - setCuFFTParams(¶ms, input, signal_ndim, complex_input, - complex_output, checked_signal_sizes, onesided); - std::lock_guard guard(plan_cache.mutex); + guard.lock(); if (plan_cache.max_size() > 0) { // check again after acquiring the lock - const CuFFTConfig &config = plan_cache.try_emplace_value(std::move(params), - input, signal_ndim, complex_input, - complex_output, checked_signal_sizes, - onesided, output_sizes); - return _run_cufft(config, input, signal_ndim, complex_input, - complex_output, inverse, checked_signal_sizes, - static_cast(normalization), - onesided, output_sizes, input_was_cloned); + config = &plan_cache.lookup(Params); + } + } + + if (config == nullptr) { + uncached_plan.emplace(Params); + config = &uncached_plan.value(); + } + + auto & plan = config->plan(); + + if (config->should_clone_input()) { + input = input.clone(MemoryFormat::Contiguous); + } + + // prepare cufft for execution + CUFFT_CHECK(cufftSetStream(plan, at::cuda::getCurrentCUDAStream())); + auto workspace = at::empty({ config->workspace_size() }, at::device(at::kCUDA).dtype(at::kByte)); + CUFFT_CHECK(cufftSetWorkArea(plan, workspace.data_ptr())); + + // execute transform plan + exec_cufft_plan(*config, input.data_ptr(), out.data_ptr(), forward); + + // Inplace reshaping to original batch shape and inverting the dimension permutation + DimVector out_strides(ndim); + int64_t batch_numel = 1; + for (int64_t i = batch_dims - 1; i >= 0; --i) { + out_strides[dim_permute[i]] = batch_numel * out.strides()[0]; + batch_numel *= out_sizes[dim_permute[i]]; + } + for (int64_t i = batch_dims; i < ndim; ++i) { + out_strides[dim_permute[i]] = out.strides()[1 + (i - batch_dims)]; + } + return out.as_strided_(out_sizes, out_strides, out.storage_offset()); +} + +// Calculates the normalization constant and applies it in-place to self +// sizes is the sizes of a twosided tensor and dims are all transformed dims +double _fft_normalization_scale(int64_t normalization, IntArrayRef sizes, IntArrayRef dims) { + auto norm = static_cast(normalization); + if (norm == fft_norm_mode::none) { + return 1.0; + } + + int64_t signal_numel = 1; + for (auto dim : dims) { + signal_numel *= sizes[dim]; + } + const double scale_denom = (norm == fft_norm_mode::by_root_n) ? + std::sqrt(signal_numel) : static_cast(signal_numel); + return 1.0 / scale_denom; +} + +const Tensor& _fft_apply_normalization(const Tensor& self, int64_t normalization, IntArrayRef sizes, IntArrayRef dims) { + auto scale = _fft_normalization_scale(normalization, sizes, dims); + return (scale == 1.0) ? self : self.mul_(scale); +} + +Tensor& _fft_apply_normalization_out(Tensor& out, const Tensor& self, int64_t normalization, IntArrayRef sizes, IntArrayRef dims) { + auto scale = _fft_normalization_scale(normalization, sizes, dims); + return at::mul_out(out, self, c10::scalar_to_tensor(scale)); +} + +} // namespace (anonymous) + +// n-dimensional real to complex FFT +Tensor _fft_r2c_cufft(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) { + TORCH_CHECK(self.is_floating_point()); + auto input_sizes = self.sizes(); + DimVector onesided_sizes(input_sizes.begin(), input_sizes.end()); + auto last_dim = dim.back(); + auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1; + onesided_sizes[last_dim] = last_dim_halfsize; + IntArrayRef out_sizes = onesided ? onesided_sizes : input_sizes; + + const auto out_options = self.options().dtype(c10::toComplexType(self.scalar_type())); + auto output = at::empty(out_sizes, out_options); + + // CuFFT requires real input to be over-aligned, as if it were complex + const auto complex_size = 2 * self.element_size(); + const bool complex_aligned = ( + reinterpret_cast(self.data_ptr()) % complex_size == 0); + auto working_tensor = self; + if (!complex_aligned) { + working_tensor = self.movedim(last_dim, -1) + .clone(MemoryFormat::Contiguous) + .movedim(-1, last_dim); + } + + // First do the R2C transform on the last dimension + { + auto target_sizes = dim.size() == 1 ? out_sizes : onesided_sizes; + _exec_fft(output, working_tensor, target_sizes, last_dim, /*forward=*/true); + if (dim.size() > 1) { + working_tensor = at::empty(out_sizes, out_options); } } - CuFFTConfig config(input, signal_ndim, complex_input, complex_output, - checked_signal_sizes, onesided, output_sizes); - return _run_cufft(config, input, signal_ndim, complex_input, - complex_output, inverse, checked_signal_sizes, - static_cast(normalization), - onesided, output_sizes, input_was_cloned); + + // Then any remaining C2C transforms + DimVector sorted_dims(dim.begin(), dim.end() - 1); + while (!sorted_dims.empty()) { + std::swap(output, working_tensor); + + // Resort dimensions every time as _exec_fft re-strides the output + auto strides = working_tensor.strides(); + std::sort(sorted_dims.begin(), sorted_dims.end(), + [&](int64_t a, int64_t b) { return strides[a] > strides[b]; }); + + const auto max_dims = std::min(static_cast(cufft_max_ndim), sorted_dims.size()); + auto last_dims = IntArrayRef(sorted_dims).slice(sorted_dims.size() - max_dims, max_dims); + + // Intermediate results are always onesided + _exec_fft(output, working_tensor, onesided_sizes, last_dims, /*forward=*/true); + sorted_dims.resize(sorted_dims.size() - max_dims); + } + + // Only need to normalize the onesided slice since data in the other half is overwritten + auto out_slice = output.slice(last_dim, 0, last_dim_halfsize); + _fft_apply_normalization(out_slice, normalization, input_sizes, dim); + + if (!onesided) { + if (output.sizes()[last_dim] != out_sizes[last_dim]) { + working_tensor.resize_(out_sizes, MemoryFormat::Contiguous); + working_tensor.slice(last_dim, 0, last_dim_halfsize).copy_(output); + output = std::move(working_tensor); + } + at::native::_fft_fill_with_conjugate_symmetry_(output, dim); + } + return output; +} + +Tensor& _fft_r2c_cufft_out(Tensor& out, const Tensor& self, IntArrayRef dim, + int64_t normalization, bool onesided) { + auto result = _fft_r2c_cufft(self, dim, static_cast(fft_norm_mode::none), /*onesided=*/true); + if (onesided) { + return _fft_apply_normalization_out(out, result, normalization, self.sizes(), dim); + } + + resize_output(out, self.sizes()); + + auto last_dim = dim.back(); + auto last_dim_halfsize = result.sizes()[last_dim]; + auto out_slice = out.slice(last_dim, 0, last_dim_halfsize); + _fft_apply_normalization_out(out_slice, result, normalization, self.sizes(), dim); + at::native::_fft_fill_with_conjugate_symmetry_(out, dim); + return out; +} + +// n-dimensional complex to real IFFT +Tensor _fft_c2r_cufft(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t lastdim) { + TORCH_CHECK(self.is_complex()); + auto in_sizes = self.sizes(); + DimVector out_sizes(in_sizes.begin(), in_sizes.end()); + out_sizes[dim.back()] = lastdim; + + // First complete any C2C transforms + Tensor temp; + if (dim.size() > 1) { + temp = _fft_c2c_cufft( + self, dim.slice(0, dim.size() - 1), + static_cast(fft_norm_mode::none), /*forward=*/false); + } else { + // Complex to real FFTs may overwrite the input buffer, so must always clone (gh-34551) + temp = self.clone(MemoryFormat::Contiguous); + } + + // Finally, do a 1D C2R transform + // TODO: could transform up to 2 other dims in the same cuFFT operation + auto output = at::empty(out_sizes, self.options().dtype(c10::toValueType(self.scalar_type()))); + _exec_fft(output, temp, out_sizes, dim.back(), /*forward=*/false); + return _fft_apply_normalization(output, normalization, out_sizes, dim); +} + +Tensor& _fft_c2r_cufft_out(Tensor& out, const Tensor& self, IntArrayRef dim, + int64_t normalization, int64_t lastdim) { + auto result = _fft_c2r_cufft(self, dim, static_cast(fft_norm_mode::none), lastdim); + return _fft_apply_normalization_out(out, result, normalization, result.sizes(), dim); +} + +// n-dimensional complex to complex FFT/IFFT +Tensor _fft_c2c_cufft(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) { + TORCH_CHECK(self.is_complex()); + if (dim.empty()) { + return self.clone(); + } + + auto out_sizes = self.sizes(); + auto output = at::empty(out_sizes, self.options()); + + // Perform any number of C2C transforms + DimVector sorted_dims(dim.begin(), dim.end()); + auto self_strides = self.strides(); + auto working_tensor = self; + while (true) { + // Sort dimensions every time as _exec_fft re-strides the output + auto strides = working_tensor.strides(); + std::sort(sorted_dims.begin(), sorted_dims.end(), + [&](int64_t a, int64_t b) { return strides[a] > strides[b]; }); + + const auto max_dims = std::min(static_cast(cufft_max_ndim), sorted_dims.size()); + auto first_dims = IntArrayRef(sorted_dims).slice(sorted_dims.size() - max_dims, max_dims); + + _exec_fft(output, working_tensor, out_sizes, first_dims, forward); + sorted_dims.resize(sorted_dims.size() - max_dims); + + if (sorted_dims.empty()) { + break; + } + + if (working_tensor.is_same(self)) { + working_tensor = std::move(output); + output = at::empty(out_sizes, self.options()); + } else { + std::swap(output, working_tensor); + } + } + + return _fft_apply_normalization(output, normalization, out_sizes, dim); } +Tensor& _fft_c2c_cufft_out(Tensor& out, const Tensor& self, IntArrayRef dim, + int64_t normalization, bool forward) { + auto result = _fft_c2c_cufft(self, dim, static_cast(fft_norm_mode::none), forward); + return _fft_apply_normalization_out(out, result, normalization, result.sizes(), dim); +} + + }} // at::native diff --git a/aten/src/ATen/native/cuda/StepKernel.cu b/aten/src/ATen/native/cuda/StepKernel.cu new file mode 100644 index 0000000000000..61aaa493122ce --- /dev/null +++ b/aten/src/ATen/native/cuda/StepKernel.cu @@ -0,0 +1,31 @@ +#include +#include +#include +#include +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. + +namespace at { namespace native { + +void nextafter_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "nextafter_cuda", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return ::nextafter(a, b); + }); + }); +} + +void heaviside_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cuda", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return a == 0 ? b : static_cast(a > 0); + }); + }); +} + +REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda); +REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda); + +}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu index 443bea3f71aca..b10ae52e44fd0 100644 --- a/aten/src/ATen/native/cuda/TensorCompare.cu +++ b/aten/src/ATen/native/cuda/TensorCompare.cu @@ -17,7 +17,7 @@ DECLARE_DISPATCH(is_infinity_op_fn, isneginf_stub); namespace { void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBool, iter.dtype(), "where_cuda", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.dtype(), "where_cuda", [&] { if (condition_type == at::ScalarType::Byte) { gpu_kernel( iter, diff --git a/aten/src/ATen/native/cuda/TensorFactories.cu b/aten/src/ATen/native/cuda/TensorFactories.cu index 0331e768a6922..effeef69f0cf4 100644 --- a/aten/src/ATen/native/cuda/TensorFactories.cu +++ b/aten/src/ATen/native/cuda/TensorFactories.cu @@ -22,15 +22,13 @@ namespace at { namespace native { Tensor& eye_out_cuda(Tensor& result, int64_t n) { - return at::native::eye_out_cuda(result, n, /*m=*/-1); + // the default value of `m` equals to `n` + return at::native::eye_out_cuda(result, n, n); } Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) { TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n); - - if(m < 0) { - m = n; - } + TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m); result.resize_({n, m}); result.zero_(); @@ -43,16 +41,16 @@ Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) { return result; } -Tensor empty_cuda(IntArrayRef size, const TensorOptions& options, c10::optional optional_memory_format) { - AT_ASSERT(options.device().type() == at::DeviceType::CUDA); - TORCH_INTERNAL_ASSERT(impl::variable_excluded_from_dispatch()); - TORCH_CHECK(!options.pinned_memory(), "Only dense CPU tensors can be pinned"); +Tensor empty_cuda(IntArrayRef size, c10::optional dtype_opt, c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt, c10::optional memory_format_opt) { + AT_ASSERT(device_or_default(device_opt).type() == at::DeviceType::CUDA); + TORCH_CHECK(!pin_memory_opt.has_value() || !*pin_memory_opt, "Only dense CPU tensors can be pinned"); check_size_nonnegative(size); auto* allocator = at::cuda::getCUDADeviceAllocator(); int64_t nelements = prod_intlist(size); - auto dtype = options.dtype(); - int64_t size_bytes = nelements * dtype.itemsize(); + auto dtype = dtype_or_default(dtype_opt); + auto dtype_meta = scalarTypeToTypeMeta(dtype); + int64_t size_bytes = nelements * dtype_meta.itemsize(); auto storage_impl = c10::make_intrusive( c10::StorageImpl::use_byte_size_t(), size_bytes, @@ -61,29 +59,26 @@ Tensor empty_cuda(IntArrayRef size, const TensorOptions& options, c10::optional< /*resizeable=*/true); auto tensor = - detail::make_tensor(storage_impl, DispatchKey::CUDA, dtype); + detail::make_tensor(storage_impl, DispatchKey::CUDA, dtype_meta); // Default TensorImpl has size [0] if (size.size() != 1 || size[0] != 0) { tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); } - TORCH_CHECK( - !(options.has_memory_format() && optional_memory_format.has_value()), - "Cannot set memory_format both in TensorOptions and explicit argument; please delete " - "the redundant setter."); - auto memory_format = options.memory_format_opt().value_or(optional_memory_format.value_or(MemoryFormat::Contiguous)); + auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous); tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format); return tensor; } -Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, const TensorOptions& options) { - auto t = at::native::empty_cuda({0}, options); +Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, c10::optional dtype_opt, c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { + auto t = at::native::empty_cuda({0}, dtype_opt, layout_opt, device_opt, pin_memory_opt); at::native::resize_impl_cuda_(t.unsafeGetTensorImpl(), size, stride); return t; } Tensor& randperm_out_cuda(Tensor& result, int64_t n, c10::optional generator) { TORCH_CHECK(n >= 0, "n must be non-negative, got", n); + TORCH_CHECK(!generator.has_value() || (generator.has_value() && result.device() == generator->device()), "Expected a '", result.device(), "' generator device but found '", generator->device(), "'"); check_supported_max_int_with_precision(n, result); result.resize_({n}); @@ -327,11 +322,12 @@ void tril_indices_kernel(scalar_t * tensor, // implementation, please enable them in test/test_cuda.py and make sure they // pass on your local server. Tensor tril_indices_cuda( - int64_t row, int64_t col, int64_t offset, const TensorOptions& options) { - check_args(row, col, options); + int64_t row, int64_t col, int64_t offset, c10::optional dtype_opt, + c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { + check_args(row, col, layout_opt); auto tril_size = get_tril_size(row, col, offset); - auto tensor = empty_cuda({2, tril_size}, options); + auto tensor = empty_cuda({2, tril_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt); if (tril_size > 0) { auto m_first_row = offset > 0 ? @@ -361,6 +357,7 @@ Tensor tril_indices_cuda( col, tril_size - rectangle_size, tril_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } @@ -401,11 +398,12 @@ void triu_indices_kernel(scalar_t * tensor, // implementation, please enable them in test/test_cuda.py and make sure they // pass on your local server. Tensor triu_indices_cuda( - int64_t row, int64_t col, int64_t offset, const TensorOptions& options) { - check_args(row, col, options); + int64_t row, int64_t col, int64_t offset, c10::optional dtype_opt, + c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { + check_args(row, col, layout_opt); auto triu_size = row * col - get_tril_size(row, col, offset - 1); - auto tensor = empty_cuda({2, triu_size}, options); + auto tensor = empty_cuda({2, triu_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt); if (triu_size > 0) { // # of triu elements in the first row @@ -437,6 +435,7 @@ Tensor triu_indices_cuda( col, rectangle_size, triu_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } diff --git a/aten/src/ATen/native/cuda/TensorTransformations.cu b/aten/src/ATen/native/cuda/TensorTransformations.cu index 4318b35c12952..9dfa4e8759cf0 100644 --- a/aten/src/ATen/native/cuda/TensorTransformations.cu +++ b/aten/src/ATen/native/cuda/TensorTransformations.cu @@ -87,7 +87,7 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) { // use kernel_pointwise_flip_apply2 only when to-flip dim is the 1st or last dim, where collapseDims can reduce the amount of work if (flip_dims_size == 1 && in_tensor.is_contiguous() && (flip_dims[0] == 0 || flip_dims[0] == total_dims - 1)) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, in_tensor.scalar_type(), "flip_cuda", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, in_tensor.scalar_type(), "flip_cuda", [&] { auto in_tensor_info = cuda::detail::getTensorInfo(in_tensor); auto out_tensor_info = cuda::detail::getTensorInfo(out_tensor); int flip_dim = in_tensor_info.collapseDims(flip_dims[0]); @@ -95,6 +95,7 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) { kernel_pointwise_flip_apply2 <<>>( in_tensor_info, out_tensor_info, N, flip_dim, total_dims); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); return out_tensor; } @@ -122,7 +123,7 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) { } } - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, in_tensor.scalar_type(), "flip_cuda", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, in_tensor.scalar_type(), "flip_cuda", [&] { flip_cuda_kernel<<>>( in_tensor.data_ptr(), out_tensor.data_ptr(), N, flip_dims_t.cuda().data_ptr(), @@ -131,6 +132,7 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) { stride_contiguous.cuda().data_ptr(), shape_t.cuda().data_ptr(), total_dims); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); return out_tensor; @@ -195,6 +197,7 @@ Tensor roll_cuda(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) { size, in_tensor.stride(dim), total_dims); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); return out_tensor; diff --git a/aten/src/ATen/native/cuda/TriangularOps.cu b/aten/src/ATen/native/cuda/TriangularOps.cu index bb17233b38660..8d497b5c94af6 100644 --- a/aten/src/ATen/native/cuda/TriangularOps.cu +++ b/aten/src/ATen/native/cuda/TriangularOps.cu @@ -60,22 +60,23 @@ Tensor& triu_tril_cuda_template(Tensor& result, const Tensor& self, int64_t k, c int64_t N = self.numel(); dim3 dim_block = cuda::getApplyBlock(); dim3 dim_grid((N + dim_block.x - 1) / dim_block.x); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), name, [&]{ + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu_tril_cuda_template", [&]{ if (cuda::detail::canUse32BitIndexMath(result) && cuda::detail::canUse32BitIndexMath(self)) { auto result_info = cuda::detail::getTensorInfo(result); auto self_info = cuda::detail::getTensorInfo(self); triu_tril_kernel <<>>( result_info, self_info, k, N); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { auto result_info = cuda::detail::getTensorInfo(result); auto self_info = cuda::detail::getTensorInfo(self); triu_tril_kernel <<>>( result_info, self_info, k, N); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } }); - AT_CUDA_CHECK(cudaGetLastError()); return result; } @@ -191,6 +192,7 @@ Tensor& apply_diag(Tensor& result, const Tensor& self, int64_t dimension) { sz, self_stride_0 + self_stride_1, result_stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } else { auto n_elems = self.numel(); @@ -219,6 +221,7 @@ Tensor& apply_diag(Tensor& result, const Tensor& self, int64_t dimension) { n_elems, result_stride_0 + result_stride_1, self_stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } @@ -226,7 +229,7 @@ Tensor& apply_diag(Tensor& result, const Tensor& self, int64_t dimension) { } Tensor& diag_cuda_out(Tensor& result, const Tensor& self, int64_t dimension) { - AT_DISPATCH_ALL_TYPES_AND(ScalarType::Half, self.scalar_type(), "diag_cuda", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(ScalarType::Half, ScalarType::Bool, self.scalar_type(), "diag_cuda", [&] { apply_diag(result, self, dimension); }); return result; diff --git a/aten/src/ATen/native/cuda/UnaryComplexKernels.cu b/aten/src/ATen/native/cuda/UnaryComplexKernels.cu index 30fa3dc90176c..6e192b51494fc 100644 --- a/aten/src/ATen/native/cuda/UnaryComplexKernels.cu +++ b/aten/src/ATen/native/cuda/UnaryComplexKernels.cu @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -11,7 +12,10 @@ namespace at { namespace native { // We manually overload angle because std::arg does not work with types other than c10::complex. template __host__ __device__ static inline scalar_t angle_wrapper(scalar_t v) { - return 0; + if (at::_isnan(v)){ + return v; + } + return v < 0 ? M_PI : 0; } template @@ -20,7 +24,7 @@ __host__ __device__ static inline c10::complex angle_wrapper(c10::complex } void angle_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "angle_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "angle_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return angle_wrapper(a); }); diff --git a/aten/src/ATen/native/cuda/UnaryFractionKernels.cu b/aten/src/ATen/native/cuda/UnaryFractionKernels.cu index eb9250befd569..c904e1776eed7 100644 --- a/aten/src/ATen/native/cuda/UnaryFractionKernels.cu +++ b/aten/src/ATen/native/cuda/UnaryFractionKernels.cu @@ -88,7 +88,7 @@ __host__ __device__ static inline c10::complex reciprocal_wrapper(c10::comple } void reciprocal_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "reciprocal_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "reciprocal_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return reciprocal_wrapper(a); }); @@ -114,7 +114,7 @@ __host__ __device__ static inline c10::complex nearbyint_wrapper(c10::co } void round_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "round_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half, iter.dtype(), "round_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { // We do not use std::round because we would like to round midway numbers to the nearest even integer. return nearbyint_wrapper(a); diff --git a/aten/src/ATen/native/cuda/UnaryGammaKernels.cu b/aten/src/ATen/native/cuda/UnaryGammaKernels.cu index d752d606474de..97dbeefccc773 100644 --- a/aten/src/ATen/native/cuda/UnaryGammaKernels.cu +++ b/aten/src/ATen/native/cuda/UnaryGammaKernels.cu @@ -11,7 +11,7 @@ namespace at { namespace native { void digamma_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "digamma_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "digamma_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return calc_digamma(a); }); diff --git a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu index 9138aa7a0098d..bac3a05439d20 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu @@ -11,7 +11,7 @@ namespace at { namespace native { void acos_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "acos_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "acos_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::acos(a); }); @@ -19,7 +19,7 @@ void acos_kernel_cuda(TensorIterator& iter) { } void asin_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "asin_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "asin_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::asin(a); }); @@ -27,7 +27,7 @@ void asin_kernel_cuda(TensorIterator& iter) { } void atan_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "atan_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "atan_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::atan(a); }); @@ -35,7 +35,7 @@ void atan_kernel_cuda(TensorIterator& iter) { } void sin_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "sin_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "sin_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::sin(a); }); @@ -43,7 +43,7 @@ void sin_kernel_cuda(TensorIterator& iter) { } void cos_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "cos_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "cos_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::cos(a); }); @@ -51,7 +51,7 @@ void cos_kernel_cuda(TensorIterator& iter) { } void sinh_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "sinh_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "sinh_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::sinh(a); }); @@ -59,7 +59,7 @@ void sinh_kernel_cuda(TensorIterator& iter) { } void cosh_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "cosh_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "cosh_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::cosh(a); }); @@ -67,7 +67,7 @@ void cosh_kernel_cuda(TensorIterator& iter) { } void tanh_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "tanh_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "tanh_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::tanh(a); }); @@ -75,7 +75,7 @@ void tanh_kernel_cuda(TensorIterator& iter) { } void acosh_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "acosh_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "acosh_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::acosh(a); }); @@ -83,7 +83,7 @@ void acosh_kernel_cuda(TensorIterator& iter) { } void asinh_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "asinh_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "asinh_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::asinh(a); }); @@ -91,7 +91,7 @@ void asinh_kernel_cuda(TensorIterator& iter) { } void atanh_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "atanh_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "atanh_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::atanh(a); }); @@ -99,7 +99,7 @@ void atanh_kernel_cuda(TensorIterator& iter) { } void tan_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "tan_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "tan_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::tan(a); }); diff --git a/aten/src/ATen/native/cuda/UnaryLogKernels.cu b/aten/src/ATen/native/cuda/UnaryLogKernels.cu index a43fa541554b4..84edc13e14ae6 100644 --- a/aten/src/ATen/native/cuda/UnaryLogKernels.cu +++ b/aten/src/ATen/native/cuda/UnaryLogKernels.cu @@ -11,7 +11,7 @@ namespace at { namespace native { void log_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "log_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "log_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::log(a); }); @@ -19,7 +19,7 @@ void log_kernel_cuda(TensorIterator& iter) { } void log10_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "log10_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "log10_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::log10(a); }); @@ -27,7 +27,7 @@ void log10_kernel_cuda(TensorIterator& iter) { } void log1p_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "log1p_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "log1p_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::log1p(a); }); @@ -35,7 +35,7 @@ void log1p_kernel_cuda(TensorIterator& iter) { } void log2_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "log2_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "log2_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::log2(a); }); diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu index 1067d7c61bc5f..e727335aaf173 100644 --- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu @@ -10,7 +10,9 @@ #include #include #include +#include #include +#include #include namespace at { @@ -31,7 +33,7 @@ void bitwise_not_kernel_cuda(TensorIterator& iter) { } void exp_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exp_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "exp_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::exp(a); }); @@ -39,7 +41,7 @@ void exp_kernel_cuda(TensorIterator& iter) { } void exp2_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "exp2_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "exp2_cuda", [&]() { gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return ::exp2(a); }); @@ -47,7 +49,7 @@ void exp2_kernel_cuda(TensorIterator& iter) { } void expm1_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "expm1_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "expm1_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::expm1(a); }); @@ -76,7 +78,7 @@ __host__ __device__ static inline c10::complex rsqrt_wrapper(c10::complex } void rsqrt_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "rsqrt_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "rsqrt_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { // In CUDA, ::rsqrt is overloaded for float and at::Half here is implicitly cast to float. return rsqrt_wrapper(a); @@ -85,7 +87,7 @@ void rsqrt_kernel_cuda(TensorIterator& iter) { } void sqrt_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "sqrt_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "sqrt_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::sqrt(a); }); @@ -93,7 +95,7 @@ void sqrt_kernel_cuda(TensorIterator& iter) { } void sigmoid_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "sigmoid_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "sigmoid_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { scalar_t one = scalar_t(1); return one / (one + std::exp(- a)); @@ -101,6 +103,19 @@ void sigmoid_kernel_cuda(TensorIterator& iter) { }); } +void sinc_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "sinc_cuda", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + if (a == scalar_t(0)) { + return scalar_t(1); + } else { + scalar_t product = scalar_t(M_PI) * a; + return std::sin(product) / product; + } + }); + }); +} + void logit_kernel_cuda(TensorIterator& iter, Scalar eps_scalar) { AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, @@ -129,7 +144,7 @@ void logit_kernel_cuda(TensorIterator& iter, Scalar eps_scalar) { } void erf_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "erf_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "erf_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::erf(a); }); @@ -137,7 +152,7 @@ void erf_kernel_cuda(TensorIterator& iter) { } void erfc_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "erfc_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfc_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::erfc(a); }); @@ -145,7 +160,7 @@ void erfc_kernel_cuda(TensorIterator& iter) { } void erfinv_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "erfinv_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfinv_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::erfinv(a); }); @@ -153,40 +168,84 @@ void erfinv_kernel_cuda(TensorIterator& iter) { } void clamp_kernel_cuda(TensorIterator& iter, Scalar min_value, Scalar max_value) { - AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "clamp_cuda", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "clamp_cuda", [&]() { auto lower = min_value.to(); auto upper = max_value.to(); gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t { - return (v < lower) ? lower : (v > upper ? upper : v); + // Propagate nan, which doesn't propagate automatically for ROCm + if (_isnan(v)) { + return v; + } else { + return ::min(::max(v, lower), upper); + } }); }); } void clamp_min_kernel_cuda(TensorIterator& iter, Scalar min_value) { - AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "clamp_min_cuda", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "clamp_min_cuda", [&]() { auto lower = min_value.to(); gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t { - return v < lower ? lower : v; + // Propagate nan, which doesn't propagate automatically for ROCm + if (_isnan(v)) { + return v; + } else { + return ::max(v, lower); + } }); }); } void clamp_max_kernel_cuda(TensorIterator& iter, Scalar max_value) { - AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "clamp_max_cuda", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "clamp_max_cuda", [&]() { auto upper = max_value.to(); gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t { - return v > upper ? upper : v; + // Propagate nan, which doesn't propagate automatically for ROCm + if (_isnan(v)) { + return v; + } else { + return ::min(v, upper); + } }); }); } -void kaiser_window_kernel_cuda(TensorIterator& iter, int64_t window_length, double beta){ +void nan_to_num_kernel_cuda( + TensorIterator& iter, + c10::optional nan, + c10::optional pos_inf, + c10::optional neg_inf) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "nan_to_num_cuda", [&]() { + scalar_t nan_replacement = static_cast(nan.value_or(0.)); + scalar_t pos_inf_replacement = pos_inf.has_value() + ? static_cast(pos_inf.value()) + : std::numeric_limits::max(); + scalar_t neg_inf_replacement = neg_inf.has_value() + ? static_cast(neg_inf.value()) + : std::numeric_limits::lowest(); + gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t a) -> scalar_t { + return ( + at::_isnan(a) + ? nan_replacement + : (a == std::numeric_limits::infinity() + ? pos_inf_replacement + : (a == -std::numeric_limits::infinity() + ? neg_inf_replacement + : a))); + }); + }); +} + +void kaiser_window_kernel_cuda(TensorIterator& iter, int64_t window_length, double beta_){ AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "kaiser_window_cuda", [&](){ - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "kaiser_window_cuda", [&] { - const scalar_t alpha = static_cast((window_length - 1) / 2.0); - gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t a) -> scalar_t { - return calc_i0(static_cast(beta) * ::sqrt(1 - ::pow((a - alpha) / alpha, static_cast(2.0)))) / calc_i0(static_cast(beta)); - }); + using T_ACC = acc_type; + const T_ACC inv_alpha = static_cast(2.0 / (window_length - 1)); + const T_ACC beta = static_cast(beta_); + const T_ACC inv_i0_beta = 1.0 / calc_i0(beta); + gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t a) -> scalar_t { + T_ACC x = static_cast(a) * inv_alpha - 1; + T_ACC y = std::max(0, 1 - x * x); + return calc_i0(beta * ::sqrt(y)) * inv_i0_beta; }); }); } @@ -199,6 +258,7 @@ REGISTER_DISPATCH(i0_stub, &i0_kernel_cuda); REGISTER_DISPATCH(rsqrt_stub, &rsqrt_kernel_cuda); REGISTER_DISPATCH(sqrt_stub, &sqrt_kernel_cuda); REGISTER_DISPATCH(sigmoid_stub, &sigmoid_kernel_cuda); +REGISTER_DISPATCH(sinc_stub, &sinc_kernel_cuda); REGISTER_DISPATCH(logit_stub, &logit_kernel_cuda); REGISTER_DISPATCH(erf_stub, &erf_kernel_cuda); REGISTER_DISPATCH(erfc_stub, &erfc_kernel_cuda); @@ -206,6 +266,7 @@ REGISTER_DISPATCH(erfinv_stub, &erfinv_kernel_cuda); REGISTER_DISPATCH(clamp_stub, &clamp_kernel_cuda); REGISTER_DISPATCH(clamp_min_stub, &clamp_min_kernel_cuda); REGISTER_DISPATCH(clamp_max_stub, &clamp_max_kernel_cuda); +REGISTER_DISPATCH(nan_to_num_stub, &nan_to_num_kernel_cuda); REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel_cuda); } // namespace native diff --git a/aten/src/ATen/native/cuda/UnarySignKernels.cu b/aten/src/ATen/native/cuda/UnarySignKernels.cu index cd02c89f23f02..617f09c4fbd67 100644 --- a/aten/src/ATen/native/cuda/UnarySignKernels.cu +++ b/aten/src/ATen/native/cuda/UnarySignKernels.cu @@ -1,4 +1,3 @@ -#include #include #include #include @@ -8,24 +7,24 @@ #include #include +#include + namespace at { namespace native { void logical_not_kernel_cuda(TensorIterator& iter) { - // error check -- this is just ensuring we don't dispatch on types that aren't in ALL_TYPES_AND2(...) + // error check -- this is just ensuring we don't dispatch on types that aren't in ALL_TYPES_AND_COMPLEX_AND3(...) // so we don't have to maintain a separate list or to do double dispatch. - AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, iter.dtype(0), "logical_not_cuda", [&]() {}); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kHalf, kBFloat16, iter.dtype(0), "logical_not_cuda", [&]() {}); - AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, iter.dtype(1), "logical_not_cuda", [&]() { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kHalf, kBFloat16, iter.dtype(1), "logical_not_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> bool { return !a; }); }); } void neg_kernel_cuda(TensorIterator& iter) { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "neg_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "neg_cuda", [&] { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return -a; - }); + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return -a; }); }); } @@ -36,7 +35,7 @@ void sign_kernel_cuda(TensorIterator& iter){ return a; }); } else { - AT_DISPATCH_ALL_TYPES_AND(ScalarType::Half, iter.dtype(), "sign_cuda", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "sign_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { scalar_t zero = scalar_t(0); return (zero < a) - (a < zero); @@ -47,7 +46,7 @@ void sign_kernel_cuda(TensorIterator& iter){ void signbit_kernel_cuda(TensorIterator& iter){ AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, ScalarType::Half, iter.input_dtype(), "signbit_cuda", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> bool { return a < 0; }); + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> bool { return !std::is_unsigned::value && a < 0; }); }); } diff --git a/aten/src/ATen/native/cuda/UnfoldBackwardKernel.cu b/aten/src/ATen/native/cuda/UnfoldBackwardKernel.cu index d7fd3e924b49d..64bda79809bbb 100644 --- a/aten/src/ATen/native/cuda/UnfoldBackwardKernel.cu +++ b/aten/src/ATen/native/cuda/UnfoldBackwardKernel.cu @@ -41,11 +41,11 @@ static void _launch_unfold_backward_kernel(int total_n_elems, func_t f) { dim3 block(n_threads); constexpr int total_work_block = n_threads * n_elems_per_thread; dim3 grid((total_n_elems + total_work_block - 1) / total_work_block); - + auto stream = at::cuda::getCurrentCUDAStream(); _unfold_backward_elementwise_kernel <<>>(total_n_elems, f); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template diff --git a/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu b/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu index fa4cde69e4991..13f0741bb5da6 100644 --- a/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu +++ b/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu @@ -228,9 +228,8 @@ static void upsample_bicubic2d_out_cuda_template( align_corners, idata, odata); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } static void upsample_bicubic2d_backward_out_cuda_template( @@ -303,9 +302,8 @@ static void upsample_bicubic2d_backward_out_cuda_template( 0, stream>>>( num_kernels, rheight, rwidth, align_corners, idata, odata); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace diff --git a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu index d65d6fa5e1b84..4b142d5024d89 100644 --- a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu +++ b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu @@ -213,9 +213,8 @@ static void upsample_bilinear2d_out_cuda_template( 0, stream>>>( num_kernels, rheight, rwidth, align_corners, idata, odata); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } static void upsample_bilinear2d_backward_out_cuda_template( @@ -306,9 +305,8 @@ static void upsample_bilinear2d_backward_out_cuda_template( align_corners, idata, odata); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace diff --git a/aten/src/ATen/native/cuda/UpSampleLinear1d.cu b/aten/src/ATen/native/cuda/UpSampleLinear1d.cu index a81d3e6c78b6f..eda43fbfa3986 100644 --- a/aten/src/ATen/native/cuda/UpSampleLinear1d.cu +++ b/aten/src/ATen/native/cuda/UpSampleLinear1d.cu @@ -160,9 +160,8 @@ static void upsample_linear1d_out_cuda_template( num_threads, 0, stream>>>(num_kernels, rwidth, align_corners, idata, odata); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } static void upsample_linear1d_backward_out_cuda_template( @@ -221,9 +220,8 @@ static void upsample_linear1d_backward_out_cuda_template( num_threads, 0, stream>>>(num_kernels, rwidth, align_corners, idata, odata); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace diff --git a/aten/src/ATen/native/cuda/UpSampleNearest1d.cu b/aten/src/ATen/native/cuda/UpSampleNearest1d.cu index 08bea73727ea7..b269bd303e765 100644 --- a/aten/src/ATen/native/cuda/UpSampleNearest1d.cu +++ b/aten/src/ATen/native/cuda/UpSampleNearest1d.cu @@ -128,9 +128,8 @@ static void upsample_nearest1d_out_cuda_template( upsample_nearest1d_out_frame<<>>( idata, nbatch, channels, input_width, output_width, odata, scale_factor); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } static void upsample_nearest1d_backward_out_cuda_template( @@ -191,48 +190,30 @@ static void upsample_nearest1d_backward_out_cuda_template( upsample_nearest1d_backward_out_frame <<>>( odata, nbatch, channels, output_width, input_width, idata, scale_factor); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace -Tensor& upsample_nearest1d_out_cuda( - Tensor& output, +TORCH_IMPL_FUNC(upsample_nearest1d_out_cuda) ( const Tensor& input, IntArrayRef output_size, - c10::optional scales) { + c10::optional scales, + Tensor& output +) { upsample_nearest1d_out_cuda_template(output, input, output_size, scales); - return output; -} - -Tensor upsample_nearest1d_cuda(const Tensor& input, IntArrayRef output_size, c10::optional scales) { - Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - upsample_nearest1d_out_cuda_template(output, input, output_size, scales); - return output; -} - -Tensor& upsample_nearest1d_backward_out_cuda( - Tensor& grad_input, - const Tensor& grad_output, - IntArrayRef output_size, - IntArrayRef input_size, - c10::optional scales) { - upsample_nearest1d_backward_out_cuda_template( - grad_input, grad_output, output_size, input_size, scales); - return grad_input; } -Tensor upsample_nearest1d_backward_cuda( +TORCH_IMPL_FUNC(upsample_nearest1d_backward_out_cuda) ( const Tensor& grad_output, IntArrayRef output_size, IntArrayRef input_size, - c10::optional scales) { - Tensor grad_input = at::empty_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + c10::optional scales, + Tensor& grad_input +) { upsample_nearest1d_backward_out_cuda_template( grad_input, grad_output, output_size, input_size, scales); - return grad_input; } using at::native::upsample::compute_output_size; diff --git a/aten/src/ATen/native/cuda/UpSampleNearest2d.cu b/aten/src/ATen/native/cuda/UpSampleNearest2d.cu index 49a74f46ee14c..0ac02e292b285 100644 --- a/aten/src/ATen/native/cuda/UpSampleNearest2d.cu +++ b/aten/src/ATen/native/cuda/UpSampleNearest2d.cu @@ -204,9 +204,8 @@ static void upsample_nearest2d_out_cuda_template( output_width, height_scale, width_scale); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } static void upsample_nearest2d_backward_out_cuda_template( @@ -287,8 +286,8 @@ static void upsample_nearest2d_backward_out_cuda_template( idata, height_scale, width_scale); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace diff --git a/aten/src/ATen/native/cuda/UpSampleNearest3d.cu b/aten/src/ATen/native/cuda/UpSampleNearest3d.cu index 76f694274f89b..000e116e7bdfa 100644 --- a/aten/src/ATen/native/cuda/UpSampleNearest3d.cu +++ b/aten/src/ATen/native/cuda/UpSampleNearest3d.cu @@ -199,9 +199,8 @@ static void upsample_nearest3d_out_cuda_template( depth_scale, height_scale, width_scale); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } static void upsample_nearest3d_backward_out_cuda_template( @@ -292,9 +291,8 @@ static void upsample_nearest3d_backward_out_cuda_template( depth_scale, height_scale, width_scale); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace diff --git a/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu b/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu index 0498daa037c94..8ac7abca18243 100644 --- a/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu +++ b/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu @@ -8,12 +8,24 @@ #include #include #include +#include #include namespace at { namespace native { namespace { +__device__ __forceinline__ size_t +idx_3d(const size_t nc, + const size_t depth, + const size_t height, + const size_t width, + const size_t z, + const size_t y, + const size_t x) { + return ((nc * depth + z) * height + y) * width + x; +} + template C10_LAUNCH_BOUNDS_1(1024) __global__ void upsample_trilinear3d_out_frame( @@ -101,43 +113,31 @@ __global__ void upsample_trilinear3d_out_frame( template C10_LAUNCH_BOUNDS_1(1024) __global__ void upsample_trilinear3d_backward_out_frame( - const int n, + const size_t nc_, + const int depth1, + const int height1, + const int width1, + const int depth2, + const int height2, + const int width2, const accscalar_t rdepth, const accscalar_t rheight, const accscalar_t rwidth, const bool align_corners, - PackedTensorAccessor64 idata, - const PackedTensorAccessor64 odata) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - - const int batchsize = idata.size(0); - const int channels = idata.size(1); - const int depth1 = idata.size(2); - const int height1 = idata.size(3); - const int width1 = idata.size(4); - const int depth2 = odata.size(2); - const int height2 = odata.size(3); - const int width2 = odata.size(4); - - if (index < n) { - const int w2 = (index % (height2 * width2)) % width2; // 0:width2-1 - const int h2 = (index % (height2 * width2)) / width2; // 0:height2-1 - const int t2 = index / (height2 * width2); // 0:depth2-1 - // special case: just copy - if (depth1 == depth2 && height1 == height2 && width1 == width2) { - const int t1 = t2; - const int h1 = h2; - const int w1 = w2; + scalar_t* __restrict__ idata, + const scalar_t* __restrict__ odata) { + const size_t i_numel = nc_ * depth1 * height1 * width1; + const size_t o_numel = nc_ * depth2 * height2 * width2; + + for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel; index += blockDim.x * gridDim.x) { + size_t index_temp = index; + const int w2 = index_temp % width2; // 0:width2-1 + index_temp /= width2; + const int h2 = index_temp % height2; // 0:height2-1 + index_temp /= height2; + const int t2 = index_temp % depth2; // 0:depth2-1 + const int nc = index_temp / depth2; - for (int n = 0; n < batchsize; n++) { - for (int c = 0; c < channels; ++c) { - const scalar_t val = odata[n][c][t1][h1][w1]; - idata[n][c][t2][h2][w2] = val; - } - } - return; - } - // const accscalar_t t1r = area_pixel_compute_source_index( rdepth, t2, align_corners, /*cubic=*/false); const int t1 = t1r; @@ -159,35 +159,55 @@ __global__ void upsample_trilinear3d_backward_out_frame( const accscalar_t w1lambda = w1r - w1; const accscalar_t w0lambda = static_cast(1) - w1lambda; // - for (int n = 0; n < batchsize; n++) { - for (int c = 0; c < channels; ++c) { - const scalar_t d2val = odata[n][c][t2][h2][w2]; - gpuAtomicAdd( - &idata[n][c][t1][h1][w1], - static_cast(t0lambda * h0lambda * w0lambda * d2val)); - gpuAtomicAdd( - &idata[n][c][t1][h1][w1 + w1p], - static_cast(t0lambda * h0lambda * w1lambda * d2val)); - gpuAtomicAdd( - &idata[n][c][t1][h1 + h1p][w1], - static_cast(t0lambda * h1lambda * w0lambda * d2val)); - gpuAtomicAdd( - &idata[n][c][t1][h1 + h1p][w1 + w1p], - static_cast(t0lambda * h1lambda * w1lambda * d2val)); - gpuAtomicAdd( - &idata[n][c][t1 + t1p][h1][w1], - static_cast(t1lambda * h0lambda * w0lambda * d2val)); - gpuAtomicAdd( - &idata[n][c][t1 + t1p][h1][w1 + w1p], - static_cast(t1lambda * h0lambda * w1lambda * d2val)); - gpuAtomicAdd( - &idata[n][c][t1 + t1p][h1 + h1p][w1], - static_cast(t1lambda * h1lambda * w0lambda * d2val)); - gpuAtomicAdd( - &idata[n][c][t1 + t1p][h1 + h1p][w1 + w1p], - static_cast(t1lambda * h1lambda * w1lambda * d2val)); - } - } + const scalar_t d2val = odata[index]; + fastAtomicAdd( + idata, + idx_3d(nc, depth1, height1, width1, t1, h1, w1), + i_numel, + static_cast(t0lambda * h0lambda * w0lambda * d2val), + true); + fastAtomicAdd( + idata, + idx_3d(nc, depth1, height1, width1, t1, h1, w1 + w1p), + i_numel, + static_cast(t0lambda * h0lambda * w1lambda * d2val), + true); + fastAtomicAdd( + idata, + idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1), + i_numel, + static_cast(t0lambda * h1lambda * w0lambda * d2val), + true); + fastAtomicAdd( + idata, + idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1 + w1p), + i_numel, + static_cast(t0lambda * h1lambda * w1lambda * d2val), + true); + fastAtomicAdd( + idata, + idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1), + i_numel, + static_cast(t1lambda * h0lambda * w0lambda * d2val), + true); + fastAtomicAdd( + idata, + idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1 + w1p), + i_numel, + static_cast(t1lambda * h0lambda * w1lambda * d2val), + true); + fastAtomicAdd( + idata, + idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1), + i_numel, + static_cast(t1lambda * h1lambda * w0lambda * d2val), + true); + fastAtomicAdd( + idata, + idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1 + w1p), + i_numel, + static_cast(t1lambda * h1lambda * w1lambda * d2val), + true); } } @@ -234,7 +254,6 @@ static void upsample_trilinear3d_out_cuda_template( output_depth, output_height, output_width}); - output.zero_(); AT_ASSERT( input_depth > 0 && input_height > 0 && input_width > 0 && @@ -271,9 +290,8 @@ static void upsample_trilinear3d_out_cuda_template( align_corners, idata, odata); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } static void upsample_trilinear3d_backward_out_cuda_template( @@ -326,21 +344,27 @@ static void upsample_trilinear3d_backward_out_cuda_template( grad_input.resize_( {nbatch, channels, input_depth, input_height, input_width}); + // A contiguous tensor is required for the kernel launch config + grad_input.contiguous(); + // Numbers are added atomically to grad_input tensor from multiple threads, + // so it has to be initialized to zero. grad_input.zero_(); - const int num_kernels = output_depth * output_height * output_width; + // const size_t num_kernels = nbatch * channels * output_depth * output_height * output_width; + const size_t num_kernels = grad_output.numel(); const int num_threads = std::min( at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (num_kernels > 0) { AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad_output.scalar_type(), "upsample_trilinear3d_backward_out_frame", [&] { using accscalar_t = at::acc_type; - auto idata = grad_input.packed_accessor64(); - auto odata = grad_output.packed_accessor64(); + auto idata = grad_input.data_ptr(); + auto odata = grad_output.data_ptr(); const accscalar_t rdepth = area_pixel_compute_scale( input_depth, output_depth, align_corners, scales_d); @@ -350,20 +374,26 @@ static void upsample_trilinear3d_backward_out_cuda_template( input_width, output_width, align_corners, scales_w); upsample_trilinear3d_backward_out_frame - <<(num_threads)), num_threads, 0, stream>>>( - num_kernels, + nbatch * channels, + input_depth, + input_height, + input_width, + output_depth, + output_height, + output_width, rdepth, rheight, rwidth, align_corners, idata, odata); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); + } } } // namespace diff --git a/aten/src/ATen/native/cuda/WeightNorm.cu b/aten/src/ATen/native/cuda/WeightNorm.cu index d90dc03007fd5..8261eda01a3c9 100644 --- a/aten/src/ATen/native/cuda/WeightNorm.cu +++ b/aten/src/ATen/native/cuda/WeightNorm.cu @@ -394,14 +394,14 @@ std::tuple weight_norm_cuda g.data_ptr(), fast_dim_size, slower_dims_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } // The kernel execution is asynchronous, so this will only catch errors on the kernel launch, // not the kernel's execution. Errors in kernel execution aren't guaranteed to be caught // until a later error check on a synchronizing CUDA call. Unfortunately, without manually - // synchronizing here, this is the best we can do. - AT_CUDA_CHECK(cudaGetLastError()); + // synchronizing here, the foregoing is the best we can do. return std::tuple{w, norms}; } @@ -486,14 +486,14 @@ std::tuple weight_norm_cuda_backward saved_norms.data_ptr(), fast_dim_size, slower_dims_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } // The kernel execution is asynchronous, so this will only catch errors on the kernel launch, // not the kernel's execution. Errors in kernel execution aren't guaranteed to be caught // until a later error check on a synchronizing CUDA call. Unfortunately, without manually - // synchronizing here, this is the best we can do. - AT_CUDA_CHECK(cudaGetLastError()); + // synchronizing here, the foregoing is the best we can do. return std::tuple{grad_v, grad_g}; } diff --git a/aten/src/ATen/native/cuda/block_reduce.cuh b/aten/src/ATen/native/cuda/block_reduce.cuh index 780a44c093b15..346c7cbc4d5d9 100644 --- a/aten/src/ATen/native/cuda/block_reduce.cuh +++ b/aten/src/ATen/native/cuda/block_reduce.cuh @@ -1,6 +1,6 @@ #pragma once -#include +#include namespace at { namespace native { diff --git a/aten/src/ATen/native/cuda/group_norm_kernel.cu b/aten/src/ATen/native/cuda/group_norm_kernel.cu index 3aeee9efe025e..1fd710a65e9fa 100644 --- a/aten/src/ATen/native/cuda/group_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/group_norm_kernel.cu @@ -1,13 +1,18 @@ #include +#include + +#include + #include #include #include #include +#include #include #include +#include #include -#include #include @@ -27,8 +32,6 @@ __global__ void RowwiseMomentsCUDAKernel( T* mean, T* rstd) { using T_ACC = acc_type; - __shared__ T_ACC m_shared[C10_WARP_SIZE]; - __shared__ T_ACC v_shared[C10_WARP_SIZE]; const int64_t i = blockIdx.x; T_ACC sum1 = 0; T_ACC sum2 = 0; @@ -37,8 +40,15 @@ __global__ void RowwiseMomentsCUDAKernel( sum1 += static_cast(X[index]); sum2 += static_cast(X[index]) * static_cast(X[index]); } - sum1 = cuda_utils::BlockReduceSum(sum1, m_shared); - sum2 = cuda_utils::BlockReduceSum(sum2, v_shared); + if (blockDim.x <= C10_WARP_SIZE) { + sum1 = cuda_utils::WarpReduceSum(sum1); + sum2 = cuda_utils::WarpReduceSum(sum2); + } else { + __shared__ T_ACC m_shared[C10_WARP_SIZE]; + __shared__ T_ACC v_shared[C10_WARP_SIZE]; + sum1 = cuda_utils::BlockReduceSum(sum1, m_shared); + sum2 = cuda_utils::BlockReduceSum(sum2, v_shared); + } if (threadIdx.x == 0) { const T_ACC scale = T_ACC(1) / static_cast(N); sum1 *= scale; @@ -64,44 +74,190 @@ __global__ void ComputeFusedParamsCUDAKernel( if (index < N * C) { const int64_t ng = index / (C / group); const int64_t c = index % C; - const T_ACC x = (gamma == nullptr) + const T_ACC scale = (gamma == nullptr) ? static_cast(rstd[ng]) : static_cast(rstd[ng]) * static_cast(gamma[c]); - a[index] = x; - b[index] = -x * static_cast(mean[ng]) + - (beta == nullptr ? T_ACC(0) : static_cast(beta[c])); + a[index] = scale; + b[index] = -scale * static_cast(mean[ng]) + + ((beta == nullptr) ? 0 : static_cast(beta[c])); } } template -__global__ void GroupNormForwardSimpleCUDAKernel( +__global__ void Compute1dBackwardFusedParamsCUDAKernel( + int64_t C, + int64_t group, + const T* dY, + const T* X, + const T* mean, + const T* rstd, + const T* gamma, + acc_type* c2, + acc_type* c3) { + using T_ACC = acc_type; + const int64_t G = group; + const int64_t D = C / G; + const int64_t n = blockIdx.x; + const int64_t g = blockIdx.y; + const int64_t ng = n * G + g; + T_ACC sum1 = 0; + T_ACC sum2 = 0; + for (int64_t i = threadIdx.x; i < D; i += blockDim.x) { + const int64_t index = ng * D + i; + const int64_t c = g * D + i; + const T_ACC gamma_v = + gamma == nullptr ? T_ACC(1) : static_cast(gamma[c]); + sum1 += dY[index] * X[index] * gamma_v; + sum2 += dY[index] * gamma_v; + } + if (blockDim.x <= C10_WARP_SIZE) { + sum1 = cuda_utils::WarpReduceSum(sum1); + sum2 = cuda_utils::WarpReduceSum(sum2); + } else { + __shared__ T_ACC ds_shared[C10_WARP_SIZE]; + __shared__ T_ACC db_shared[C10_WARP_SIZE]; + sum1 = cuda_utils::BlockReduceSum(sum1, ds_shared); + sum2 = cuda_utils::BlockReduceSum(sum2, db_shared); + } + if (threadIdx.x == 0) { + const T_ACC s = T_ACC(1) / static_cast(D); + const T_ACC x = (sum2 * static_cast(mean[ng]) - sum1) * + static_cast(rstd[ng]) * static_cast(rstd[ng]) * + static_cast(rstd[ng]) * s; + c2[ng] = x; + c3[ng] = -x * static_cast(mean[ng]) - + sum2 * static_cast(rstd[ng]) * s; + } +} + +template +__global__ void GammaBeta1dBackwardCUDAKernel1( int64_t N, int64_t C, - int64_t HxW, + int64_t group, + const T* dY, const T* X, - const acc_type* a, - const acc_type* b, - T* Y) { + const T* mean, + const T* rstd, + T* dgamma, + T* dbeta) { using T_ACC = acc_type; - const int64_t index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < N * C * HxW) { - const int64_t nc = index / HxW; - Y[index] = a[nc] * static_cast(X[index]) + b[nc]; + const int64_t c = blockIdx.x * blockDim.x + threadIdx.x; + if (c < C) { + const int64_t G = group; + const int64_t D = C / G; + T_ACC sum1 = 0; + T_ACC sum2 = 0; + for (int64_t n = 0; n < N; ++n) { + const int64_t nc = n * C + c; + const int64_t ng = n * G + c / D; + const T_ACC dy_acc = static_cast(dY[nc]); + const T_ACC x_acc = static_cast(X[nc]); + sum1 += (dgamma == nullptr) + ? T_ACC(0) + : ((dy_acc * x_acc - dy_acc * static_cast(mean[ng])) * + static_cast(rstd[ng])); + sum2 += (dbeta == nullptr) ? T_ACC(0) : dy_acc; + } + if (dgamma != nullptr) { + dgamma[c] = sum1; + } + if (dbeta != nullptr) { + dbeta[c] = sum2; + } } } template -__global__ void GroupNormForwardCUDAKernel( - int64_t HxW, +__global__ void GammaBeta1dBackwardCUDAKernel2( + int64_t N, + int64_t C, + int64_t group, + const T* dY, const T* X, - const acc_type* a, - const acc_type* b, - T* Y) { + const T* mean, + const T* rstd, + T* dgamma, + T* dbeta) { using T_ACC = acc_type; - const int64_t nc = blockIdx.x; - for (int64_t hw = threadIdx.x; hw < HxW; hw += blockDim.x) { - const int64_t index = nc * HxW + hw; - Y[index] = a[nc] * static_cast(X[index]) + b[nc]; + __shared__ T_ACC g_shared[kReduceTileSize][kReduceTileSize + 1]; + __shared__ T_ACC b_shared[kReduceTileSize][kReduceTileSize + 1]; + const int64_t c = blockIdx.x * blockDim.x + threadIdx.x; + T_ACC dg_sum1 = 0; + T_ACC dg_sum2 = 0; + T_ACC db_sum1 = 0; + T_ACC db_sum2 = 0; + if (c < C) { + const int64_t G = group; + const int64_t D = C / G; + // Accumulate each 32 cols into a 32 * 32 tile. + // Since the blockDim is (32, 16), accumulate twice for 1st and 2nd 16 rows + // of a 32 contiguous elements. + for (int64_t n = threadIdx.y; n < N; n += blockDim.y * 2) { + const int64_t n1 = n; + const int64_t n2 = n + blockDim.y; + const int64_t nc1 = n1 * C + c; + const int64_t nc2 = n2 * C + c; + const int64_t ng1 = n1 * G + c / D; + const int64_t ng2 = n2 * G + c / D; + const T_ACC dy1_acc = static_cast(dY[nc1]); + const T_ACC x1_acc = static_cast(X[nc1]); + dg_sum1 += dgamma == nullptr + ? T_ACC(0) + : ((dy1_acc * x1_acc - dy1_acc * static_cast(mean[ng1])) * + static_cast(rstd[ng1])); + db_sum1 += dbeta == nullptr ? T_ACC(0) : dy1_acc; + if (n2 < N) { + const T_ACC dy2_acc = static_cast(dY[nc2]); + const T_ACC x2_acc = static_cast(X[nc2]); + dg_sum2 += dgamma == nullptr + ? T_ACC(0) + : ((dy2_acc * x2_acc - dy2_acc * static_cast(mean[ng2])) * + static_cast(rstd[ng2])); + db_sum2 += dbeta == nullptr ? T_ACC(0) : dy2_acc; + } + } + } + + // Write accumulated tile to shared memory. + g_shared[threadIdx.y][threadIdx.x] = dg_sum1; + g_shared[threadIdx.y + blockDim.y][threadIdx.x] = dg_sum2; + b_shared[threadIdx.y][threadIdx.x] = db_sum1; + b_shared[threadIdx.y + blockDim.y][threadIdx.x] = db_sum2; + __syncthreads(); + + // Do warp reduce for the 1st 16 cols in the tile. + T_ACC sum1 = g_shared[threadIdx.x][threadIdx.y]; + T_ACC sum2 = b_shared[threadIdx.x][threadIdx.y]; + sum1 = cuda_utils::WarpReduceSum(sum1); + sum2 = cuda_utils::WarpReduceSum(sum2); + if (threadIdx.x == 0) { + const int64_t c = blockIdx.x * blockDim.x + threadIdx.y; + if (c < C) { + if (dgamma != nullptr) { + dgamma[c] = sum1; + } + if (dbeta != nullptr) { + dbeta[c] = sum2; + } + } + } + + // Do warp reduce for the 2nd 16 cols in the tile. + sum1 = g_shared[threadIdx.x][threadIdx.y + blockDim.y]; + sum2 = b_shared[threadIdx.x][threadIdx.y + blockDim.y]; + sum1 = cuda_utils::WarpReduceSum(sum1); + sum2 = cuda_utils::WarpReduceSum(sum2); + if (threadIdx.x == 0) { + const int64_t c = blockIdx.x * blockDim.x + threadIdx.y + blockDim.y; + if (c < C) { + if (dgamma != nullptr) { + dgamma[c] = sum1; + } + if (dbeta != nullptr) { + dbeta[c] = sum2; + } + } } } @@ -113,8 +269,6 @@ __global__ void ComputeInternalGradientsCUDAKernel( acc_type* ds, acc_type* db) { using T_ACC = acc_type; - __shared__ T_ACC ds_shared[C10_WARP_SIZE]; - __shared__ T_ACC db_shared[C10_WARP_SIZE]; const int64_t nc = blockIdx.x; T_ACC sum1 = 0; T_ACC sum2 = 0; @@ -123,32 +277,21 @@ __global__ void ComputeInternalGradientsCUDAKernel( sum1 += static_cast(dY[index]) * static_cast(X[index]); sum2 += static_cast(dY[index]); } - sum1 = cuda_utils::BlockReduceSum(sum1, ds_shared); - sum2 = cuda_utils::BlockReduceSum(sum2, db_shared); + if (blockDim.x <= C10_WARP_SIZE) { + sum1 = cuda_utils::WarpReduceSum(sum1); + sum2 = cuda_utils::WarpReduceSum(sum2); + } else { + __shared__ T_ACC ds_shared[C10_WARP_SIZE]; + __shared__ T_ACC db_shared[C10_WARP_SIZE]; + sum1 = cuda_utils::BlockReduceSum(sum1, ds_shared); + sum2 = cuda_utils::BlockReduceSum(sum2, db_shared); + } if (threadIdx.x == 0) { ds[nc] = sum1; db[nc] = sum2; } } -template -__global__ void ComputeGradOutputCoeffientCUDAKernel( - int64_t N, - int64_t C, - int64_t group, - const T* rstd, - const T* gamma, - acc_type* c1) { - using T_ACC = acc_type; - const int64_t nc = blockIdx.x * blockDim.x + threadIdx.x; - if (nc < N * C) { - const int64_t ng = nc / (C / group); - const int64_t c = nc % C; - c1[nc] = static_cast(rstd[ng]) * - (gamma == nullptr ? T_ACC(1) : static_cast(gamma[c])); - } -} - template __global__ void ComputeBackwardFusedParamsCUDAKernel( int64_t C, @@ -162,8 +305,6 @@ __global__ void ComputeBackwardFusedParamsCUDAKernel( acc_type* c2, acc_type* c3) { using T_ACC = acc_type; - __shared__ T_ACC ds_shared[C10_WARP_SIZE]; - __shared__ T_ACC db_shared[C10_WARP_SIZE]; const int64_t G = group; const int64_t D = C / G; const int64_t n = blockIdx.x; @@ -179,8 +320,15 @@ __global__ void ComputeBackwardFusedParamsCUDAKernel( sum1 += ds[index] * gamma_v; sum2 += db[index] * gamma_v; } - sum1 = cuda_utils::BlockReduceSum(sum1, ds_shared); - sum2 = cuda_utils::BlockReduceSum(sum2, db_shared); + if (blockDim.x <= C10_WARP_SIZE) { + sum1 = cuda_utils::WarpReduceSum(sum1); + sum2 = cuda_utils::WarpReduceSum(sum2); + } else { + __shared__ T_ACC ds_shared[C10_WARP_SIZE]; + __shared__ T_ACC db_shared[C10_WARP_SIZE]; + sum1 = cuda_utils::BlockReduceSum(sum1, ds_shared); + sum2 = cuda_utils::BlockReduceSum(sum2, db_shared); + } if (threadIdx.x == 0) { const T_ACC s = T_ACC(1) / static_cast(D * HxW); const T_ACC x = (sum2 * static_cast(mean[ng]) - sum1) * @@ -193,51 +341,7 @@ __global__ void ComputeBackwardFusedParamsCUDAKernel( } template -__global__ void GroupNormBackwardSimpleCUDAKernel( - int64_t N, - int64_t C, - int64_t HxW, - int64_t group, - const T* dY, - const T* X, - const acc_type* c1, - const acc_type* c2, - const acc_type* c3, - T* dX) { - using T_ACC = acc_type; - const int64_t index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < N * C * HxW) { - const int64_t nc = index / HxW; - const int64_t ng = nc / (C / group); - dX[index] = c1[nc] * static_cast(dY[index]) + - c2[ng] * static_cast(X[index]) + c3[ng]; - } -} - -template -__global__ void GroupNormBackwardCUDAKernel( - int64_t C, - int64_t HxW, - int64_t group, - const T* dY, - const T* X, - const acc_type* c1, - const acc_type* c2, - const acc_type* c3, - T* dX) { - using T_ACC = acc_type; - const int64_t D = C / group; - const int64_t nc = blockIdx.x; - const int64_t ng = nc / D; - for (int64_t hw = threadIdx.x; hw < HxW; hw += blockDim.x) { - const int64_t index = nc * HxW + hw; - dX[index] = c1[nc] * static_cast(dY[index]) + - c2[ng] * static_cast(X[index]) + c3[ng]; - } -} - -template -__global__ void GammaBetaBackwardSimpleCUDAKernel( +__global__ void GammaBetaBackwardCUDAKernel1( int64_t N, int64_t C, int64_t group, @@ -273,7 +377,7 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel( } template -__global__ void GammaBetaBackwardCUDAKernel( +__global__ void GammaBetaBackwardCUDAKernel2( int64_t N, int64_t C, int64_t group, @@ -294,6 +398,9 @@ __global__ void GammaBetaBackwardCUDAKernel( if (c < C) { const int64_t G = group; const int64_t D = C / G; + // Accumulate each 32 cols into a 32 * 32 tile. + // Since the blockDim is (32, 16), accumulate twice for 1st and 2nd 16 rows + // of a 32 contiguous elements. for (int64_t n = threadIdx.y; n < N; n += blockDim.y * 2) { const int64_t n1 = n; const int64_t n2 = n + blockDim.y; @@ -315,11 +422,15 @@ __global__ void GammaBetaBackwardCUDAKernel( } } } + + // Write accumulated tile to shared memory. g_shared[threadIdx.y][threadIdx.x] = dg_sum1; g_shared[threadIdx.y + blockDim.y][threadIdx.x] = dg_sum2; b_shared[threadIdx.y][threadIdx.x] = db_sum1; b_shared[threadIdx.y + blockDim.y][threadIdx.x] = db_sum2; __syncthreads(); + + // Do warp reduce for the 1st 16 cols in the tile. T_ACC sum1 = g_shared[threadIdx.x][threadIdx.y]; T_ACC sum2 = b_shared[threadIdx.x][threadIdx.y]; sum1 = cuda_utils::WarpReduceSum(sum1); @@ -335,6 +446,8 @@ __global__ void GammaBetaBackwardCUDAKernel( } } } + + // Do warp reduce for the 2st 16 cols in the tile. sum1 = g_shared[threadIdx.x][threadIdx.y + blockDim.y]; sum2 = b_shared[threadIdx.x][threadIdx.y + blockDim.y]; sum1 = cuda_utils::WarpReduceSum(sum1); @@ -352,6 +465,78 @@ __global__ void GammaBetaBackwardCUDAKernel( } } +template +void GroupNorm1dForward( + const Tensor& X, + const Tensor& mean, + const Tensor& rstd, + const Tensor& gamma, + const Tensor& beta, + int64_t N, + int64_t C, + int64_t group, + Tensor& Y) { + using T_ACC = acc_type; + const int64_t G = group; + const int64_t D = C / G; + if (gamma.defined() && beta.defined()) { + auto iter = TensorIteratorConfig() + .resize_outputs(false) + .add_output(Y.view({N, G, D})) + .add_input(X.view({N, G, D})) + .add_input(mean.view({N, G, 1})) + .add_input(rstd.view({N, G, 1})) + .add_input(gamma.view({1, G, D})) + .add_input(beta.view({1, G, D})) + .build(); + gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd, T gamma, T beta) -> T { + return (static_cast(x) - static_cast(mean)) * + static_cast(rstd) * static_cast(gamma) + + static_cast(beta); + }); + } else if (gamma.defined()) { + auto iter = TensorIteratorConfig() + .resize_outputs(false) + .add_output(Y.view({N, G, D})) + .add_input(X.view({N, G, D})) + .add_input(mean.view({N, G, 1})) + .add_input(rstd.view({N, G, 1})) + .add_input(gamma.view({1, G, D})) + .build(); + gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd, T gamma) -> T { + return (static_cast(x) - static_cast(mean)) * + static_cast(rstd) * static_cast(gamma); + }); + } else if (beta.defined()) { + auto iter = TensorIteratorConfig() + .resize_outputs(false) + .add_output(Y.view({N, G, D})) + .add_input(X.view({N, G, D})) + .add_input(mean.view({N, G, 1})) + .add_input(rstd.view({N, G, 1})) + .add_input(beta.view({1, G, D})) + .build(); + gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd, T beta) -> T { + return (static_cast(x) - static_cast(mean)) * + static_cast(rstd) + + static_cast(beta); + }); + } else { + auto iter = TensorIteratorConfig() + .resize_outputs(false) + .add_output(Y.view({N * G, D})) + .add_input(X.view({N * G, D})) + .add_input(mean.view({N * G, 1})) + .add_input(rstd.view({N * G, 1})) + .build(); + gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd) -> T { + return (static_cast(x) - static_cast(mean)) * + static_cast(rstd); + }); + } + AT_CUDA_CHECK(cudaGetLastError()); +} + template void GroupNormKernelImplInternal( const Tensor& X, @@ -362,9 +547,9 @@ void GroupNormKernelImplInternal( int64_t HxW, int64_t group, T eps, - Tensor* Y, - Tensor* mean, - Tensor* rstd) { + Tensor& Y, + Tensor& mean, + Tensor& rstd) { using T_ACC = acc_type; TORCH_CHECK(X.numel() == N * C * HxW); TORCH_CHECK(!gamma.defined() || gamma.numel() == C); @@ -375,32 +560,63 @@ void GroupNormKernelImplInternal( const int64_t G = group; const int64_t D = C / G; const T* X_data = X.data_ptr(); - const T* gamma_data = gamma.defined() ? gamma.data_ptr() : nullptr; - const T* beta_data = beta.defined() ? beta.data_ptr() : nullptr; - T* Y_data = Y->data_ptr(); - T* mean_data = mean->data_ptr(); - T* rstd_data = rstd->data_ptr(); - const auto kAccType = X.scalar_type() == kHalf ? kFloat : X.scalar_type(); - Tensor a = at::empty({N, C}, X.options().dtype(kAccType)); - Tensor b = at::empty({N, C}, X.options().dtype(kAccType)); - T_ACC* a_data = a.data_ptr(); - T_ACC* b_data = b.data_ptr(); + T* Y_data = Y.data_ptr(); + T* mean_data = mean.data_ptr(); + T* rstd_data = rstd.data_ptr(); + cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); - RowwiseMomentsCUDAKernel - <<>>( - D * HxW, eps, X_data, mean_data, rstd_data); - int64_t B = (N * C + kCUDANumThreads - 1) / kCUDANumThreads; - ComputeFusedParamsCUDAKernel<<>>( - N, C, G, mean_data, rstd_data, gamma_data, beta_data, a_data, b_data); - if (HxW < kCUDANumThreads) { - B = (N * C * HxW + kCUDANumThreads - 1) / kCUDANumThreads; - GroupNormForwardSimpleCUDAKernel<<>>( - N, C, HxW, X_data, a_data, b_data, Y_data); + const int64_t num_threads = D * HxW < cuda_utils::kCUDABlockReduceNumThreads + ? C10_WARP_SIZE + : cuda_utils::kCUDABlockReduceNumThreads; + RowwiseMomentsCUDAKernel<<>>( + D * HxW, eps, X_data, mean_data, rstd_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + if (HxW == 1) { + GroupNorm1dForward(X, mean, rstd, gamma, beta, N, C, G, Y); + } else if (!gamma.defined() && !beta.defined()) { + auto iter = TensorIteratorConfig() + .resize_outputs(false) + .add_output(Y.view({N * G, D * HxW})) + .add_input(X.view({N * G, D * HxW})) + .add_input(mean.view({N * G, 1})) + .add_input(rstd.view({N * G, 1})) + .build(); + gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd) -> T { + return (static_cast(x) - static_cast(mean)) * + static_cast(rstd); + }); } else { - GroupNormForwardCUDAKernel<<>>( - HxW, X_data, a_data, b_data, Y_data); + const auto kAccType = + (X.scalar_type() == kHalf || X.scalar_type() == kBFloat16) + ? kFloat + : X.scalar_type(); + Tensor a = at::empty({N, C}, X.options().dtype(kAccType)); + Tensor b = at::empty({N, C}, X.options().dtype(kAccType)); + const T* gamma_data = gamma.defined() ? gamma.data_ptr() : nullptr; + const T* beta_data = beta.defined() ? beta.data_ptr() : nullptr; + T_ACC* a_data = a.data_ptr(); + T_ACC* b_data = b.data_ptr(); + + // TODO: Since there is some issues in gpu_kernel_multiple_outputs, we are + // using maunal kernel here. Make it using gpu_kernel_multiple_outputs once + // the issue fixed. + const int64_t B = (N * C + kCUDANumThreads - 1) / kCUDANumThreads; + ComputeFusedParamsCUDAKernel<<>>( + N, C, G, mean_data, rstd_data, gamma_data, beta_data, a_data, b_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + auto iter = TensorIteratorConfig() + .check_all_same_dtype(std::is_same::value) + .resize_outputs(false) + .add_output(Y.view({N * C, HxW})) + .add_input(X.view({N * C, HxW})) + .add_input(a.view({N * C, 1})) + .add_input(b.view({N * C, 1})) + .build(); + gpu_kernel(iter, [] GPU_LAMBDA(T x, T_ACC a, T_ACC b) -> T { + return a * static_cast(x) + b; + }); } - AT_CUDA_CHECK(cudaGetLastError()); } void GroupNormKernelImpl( @@ -412,30 +628,154 @@ void GroupNormKernelImpl( int64_t HxW, int64_t group, double eps, - Tensor* Y, - Tensor* mean, - Tensor* rstd) { + Tensor& Y, + Tensor& mean, + Tensor& rstd) { AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, X.scalar_type(), "GroupNormKernelImpl", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "GroupNormKernelImpl", [&]() { - GroupNormKernelImplInternal( - X, - gamma, - beta, + GroupNormKernelImplInternal( + X, + gamma, + beta, + N, + C, + HxW, + group, + static_cast(eps), + Y, + mean, + rstd); + }); +} + +template +void GroupNorm1dBackward( + const Tensor dY, + const Tensor X, + const Tensor mean, + const Tensor rstd, + const Tensor gamma, + int64_t N, + int64_t C, + int64_t group, + Tensor& dX, + Tensor& dgamma, + Tensor& dbeta) { + using T_ACC = acc_type; + const int64_t G = group; + const int64_t D = C / G; + const T* dY_data = dY.data_ptr(); + const T* X_data = X.data_ptr(); + const T* mean_data = mean.data_ptr(); + const T* rstd_data = rstd.data_ptr(); + + cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); + if (dX.defined()) { + const T* gamma_data = gamma.defined() ? gamma.data_ptr() : nullptr; + const auto kAccType = + (X.scalar_type() == kHalf || X.scalar_type() == kBFloat16) + ? kFloat + : X.scalar_type(); + Tensor c2 = at::empty({N, G}, X.options().dtype(kAccType)); + Tensor c3 = at::empty({N, G}, X.options().dtype(kAccType)); + T_ACC* c2_data = c2.data_ptr(); + T_ACC* c3_data = c3.data_ptr(); + const int64_t num_threads = (C / G) < cuda_utils::kCUDABlockReduceNumThreads + ? C10_WARP_SIZE + : cuda_utils::kCUDABlockReduceNumThreads; + Compute1dBackwardFusedParamsCUDAKernel + <<>>( + C, + G, + dY_data, + X_data, + mean_data, + rstd_data, + gamma_data, + c2_data, + c3_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + if (gamma.defined()) { + auto iter = TensorIteratorConfig() + .check_all_same_dtype(std::is_same::value) + .resize_outputs(false) + .add_output(dX.view({N, G, D})) + .add_input(dY.view({N, G, D})) + .add_input(X.view({N, G, D})) + .add_input(rstd.view({N, G, 1})) + .add_input(gamma.view({1, G, D})) + .add_input(c2.view({N, G, 1})) + .add_input(c3.view({N, G, 1})) + .build(); + gpu_kernel( + iter, + [] GPU_LAMBDA(T dy, T x, T rstd, T gamma, T_ACC c2, T_ACC c3) -> T { + const T_ACC c1 = + static_cast(rstd) * static_cast(gamma); + return c1 * static_cast(dy) + c2 * static_cast(x) + + c3; + }); + } else { + auto iter = TensorIteratorConfig() + .check_all_same_dtype(std::is_same::value) + .resize_outputs(false) + .add_output(dX.view({N * G, D})) + .add_input(dY.view({N * G, D})) + .add_input(X.view({N * G, D})) + .add_input(rstd.view({N * G, 1})) + .add_input(c2.view({N * G, 1})) + .add_input(c3.view({N * G, 1})) + .build(); + gpu_kernel( + iter, [] GPU_LAMBDA(T dy, T x, T rstd, T_ACC c2, T_ACC c3) -> T { + const T_ACC c1 = static_cast(rstd); + return c1 * static_cast(dy) + c2 * static_cast(x) + + c3; + }); + } + } + if (dgamma.defined() || dbeta.defined()) { + T* dgamma_data = dgamma.defined() ? dgamma.data_ptr() : nullptr; + T* dbeta_data = dbeta.defined() ? dbeta.data_ptr() : nullptr; + if (N <= 128) { + const int64_t B = (C + kCUDANumThreads - 1) / kCUDANumThreads; + GammaBeta1dBackwardCUDAKernel1<<>>( + N, + C, + G, + dY_data, + X_data, + mean_data, + rstd_data, + dgamma_data, + dbeta_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + const int64_t B = (C + kReduceTileSize - 1) / kReduceTileSize; + // The algorithm for colwise reduction here is to accumulate each 32 cols + // to a 32 * 32 tile and write the tile to shared memmory. Then do warp + // reduce for each col in the tile. So here the blockDim must be (32, 16). + constexpr int kThreadX = kReduceTileSize; + constexpr int kThreadY = kReduceTileSize / 2; + GammaBeta1dBackwardCUDAKernel2 + <<>>( N, C, - HxW, - group, - static_cast(eps), - Y, - mean, - rstd); - }); - }); + G, + dY_data, + X_data, + mean_data, + rstd_data, + dgamma_data, + dbeta_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + } } template @@ -449,11 +789,12 @@ void GroupNormBackwardKernelImplInternal( int64_t C, int64_t HxW, int64_t group, - Tensor* dX, - Tensor* dgamma, - Tensor* dbeta) { + Tensor& dX, + Tensor& dgamma, + Tensor& dbeta) { using T_ACC = acc_type; const int64_t G = group; + const int64_t D = C / G; TORCH_CHECK(dY.numel() == N * C * HxW); TORCH_CHECK(X.numel() == N * C * HxW); TORCH_CHECK(mean.numel() == N * G); @@ -462,15 +803,11 @@ void GroupNormBackwardKernelImplInternal( cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); if (N == 0) { - if (dgamma->defined()) { - T* dgamma_data = dgamma->data_ptr(); - AT_CUDA_CHECK(cudaMemsetAsync( - dgamma_data, 0, dgamma->numel() * sizeof(T), cuda_stream)); + if (dgamma.defined()) { + dgamma.fill_(T(0)); } - if (dbeta->defined()) { - T* dbeta_data = dbeta->data_ptr(); - AT_CUDA_CHECK(cudaMemsetAsync( - dbeta_data, 0, dbeta->numel() * sizeof(T), cuda_stream)); + if (dbeta.defined()) { + dbeta.fill_(T(0)); } return; } @@ -480,31 +817,52 @@ void GroupNormBackwardKernelImplInternal( const T* mean_data = mean.data_ptr(); const T* rstd_data = rstd.data_ptr(); const T* gamma_data = gamma.defined() ? gamma.data_ptr() : nullptr; - T* dX_data = dX->defined() ? dX->data_ptr() : nullptr; - const auto kAccType = X.scalar_type() == kHalf ? kFloat : X.scalar_type(); + const auto kAccType = + (X.scalar_type() == kHalf || X.scalar_type() == kBFloat16) + ? kFloat + : X.scalar_type(); Tensor ds = at::empty({N, C}, X.options().dtype(kAccType)); Tensor db = at::empty({N, C}, X.options().dtype(kAccType)); T_ACC* ds_data = ds.data_ptr(); T_ACC* db_data = db.data_ptr(); - ComputeInternalGradientsCUDAKernel - <<>>( - HxW, dY_data, X_data, ds_data, db_data); - if (dX_data != nullptr) { - Tensor c1 = at::empty({N, C}, X.options().dtype(kAccType)); + + if (HxW == 1) { + GroupNorm1dBackward( + dY, X, mean, rstd, gamma, N, C, G, dX, dgamma, dbeta); + return; + } + + int64_t num_threads = HxW < cuda_utils::kCUDABlockReduceNumThreads + ? C10_WARP_SIZE + : cuda_utils::kCUDABlockReduceNumThreads; + ComputeInternalGradientsCUDAKernel<<>>( + HxW, dY_data, X_data, ds_data, db_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + if (dX.defined()) { + Tensor c1 = at::empty({0}, X.options().dtype(kAccType)); Tensor c2 = at::empty({N, G}, X.options().dtype(kAccType)); Tensor c3 = at::empty({N, G}, X.options().dtype(kAccType)); - T_ACC* c1_data = c1.data_ptr(); T_ACC* c2_data = c2.data_ptr(); T_ACC* c3_data = c3.data_ptr(); - int64_t B = (N * C + kCUDANumThreads - 1) / kCUDANumThreads; - ComputeGradOutputCoeffientCUDAKernel - <<>>( - N, C, G, rstd_data, gamma_data, c1_data); + + if (gamma.defined()) { + auto iter = TensorIteratorConfig() + .check_all_same_dtype(std::is_same::value) + .add_output(c1) + .add_input(rstd.view({N, G, 1})) + .add_input(gamma.view({1, G, D})) + .build(); + gpu_kernel(iter, [] GPU_LAMBDA(T rstd, T gamma) -> T_ACC { + return static_cast(rstd) * static_cast(gamma); + }); + } + + num_threads = (C / G) < cuda_utils::kCUDABlockReduceNumThreads + ? C10_WARP_SIZE + : cuda_utils::kCUDABlockReduceNumThreads; ComputeBackwardFusedParamsCUDAKernel - <<>>( + <<>>( C, HxW, G, @@ -515,40 +873,67 @@ void GroupNormBackwardKernelImplInternal( db_data, c2_data, c3_data); - if (HxW < kCUDANumThreads) { - B = (N * C * HxW + kCUDANumThreads - 1) / kCUDANumThreads; - GroupNormBackwardSimpleCUDAKernel< - T><<>>( - N, C, HxW, G, dY_data, X_data, c1_data, c2_data, c3_data, dX_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + if (gamma.defined()) { + auto iter = TensorIteratorConfig() + .check_all_same_dtype(std::is_same::value) + .resize_outputs(false) + .add_output(dX.view({N * G, D, HxW})) + .add_input(dY.view({N * G, D, HxW})) + .add_input(X.view({N * G, D, HxW})) + .add_input(c1.view({N * G, D, 1})) + .add_input(c2.view({N * G, 1, 1})) + .add_input(c3.view({N * G, 1, 1})) + .build(); + gpu_kernel( + iter, [] GPU_LAMBDA(T dy, T x, T_ACC c1, T_ACC c2, T_ACC c3) -> T { + return c1 * static_cast(dy) + c2 * static_cast(x) + + c3; + }); } else { - GroupNormBackwardCUDAKernel - <<>>( - C, HxW, G, dY_data, X_data, c1_data, c2_data, c3_data, dX_data); + auto iter = TensorIteratorConfig() + .check_all_same_dtype(std::is_same::value) + .resize_outputs(false) + .add_output(dX.view({N * G, D * HxW})) + .add_input(dY.view({N * G, D * HxW})) + .add_input(X.view({N * G, D * HxW})) + .add_input(rstd.view({N * G, 1})) + .add_input(c2.view({N * G, 1})) + .add_input(c3.view({N * G, 1})) + .build(); + gpu_kernel( + iter, [] GPU_LAMBDA(T dy, T x, T_ACC c1, T_ACC c2, T_ACC c3) -> T { + return c1 * static_cast(dy) + c2 * static_cast(x) + + c3; + }); } - AT_CUDA_CHECK(cudaGetLastError()); } - if (dgamma->defined() || dbeta->defined()) { - T* dgamma_data = dgamma->defined() ? dgamma->data_ptr() : nullptr; - T* dbeta_data = dbeta->defined() ? dbeta->data_ptr() : nullptr; - if (N < 512) { + if (dgamma.defined() || dbeta.defined()) { + T* dgamma_data = dgamma.defined() ? dgamma.data_ptr() : nullptr; + T* dbeta_data = dbeta.defined() ? dbeta.data_ptr() : nullptr; + if (N <= 128) { // For small batch size, do colwise reduce directly. const int64_t B = (C + kCUDANumThreads - 1) / kCUDANumThreads; - GammaBetaBackwardSimpleCUDAKernel - <<>>( - N, - C, - G, - mean_data, - rstd_data, - ds_data, - db_data, - dgamma_data, - dbeta_data); + GammaBetaBackwardCUDAKernel1<<>>( + N, + C, + G, + mean_data, + rstd_data, + ds_data, + db_data, + dgamma_data, + dbeta_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { const int64_t B = (C + kReduceTileSize - 1) / kReduceTileSize; + // The algorithm for colwise reduction here is to accumulate each 32 cols + // to a 32 * 32 tile and write the tile to shared memmory. Then do warp + // reduce for each col in the tile. So here the blockDim must be (32, 16). constexpr int kThreadX = kReduceTileSize; constexpr int kThreadY = kReduceTileSize / 2; - GammaBetaBackwardCUDAKernel + GammaBetaBackwardCUDAKernel2 <<>>( N, C, @@ -559,9 +944,9 @@ void GroupNormBackwardKernelImplInternal( db_data, dgamma_data, dbeta_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } - AT_CUDA_CHECK(cudaGetLastError()); } void GroupNormBackwardKernelImpl( @@ -574,31 +959,17 @@ void GroupNormBackwardKernelImpl( int64_t C, int64_t HxW, int64_t group, - Tensor* dX, - Tensor* dgamma, - Tensor* dbeta) { + Tensor& dX, + Tensor& dgamma, + Tensor& dbeta) { AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, X.scalar_type(), "GroupNormBackwardKernelImpl", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM( - scalar_t, "GroupNormBackwardKernelImpl", [&]() { - GroupNormBackwardKernelImplInternal( - dY, - X, - mean, - rstd, - gamma, - N, - C, - HxW, - group, - dX, - dgamma, - dbeta); - }); + GroupNormBackwardKernelImplInternal( + dY, X, mean, rstd, gamma, N, C, HxW, group, dX, dgamma, dbeta); }); } diff --git a/aten/src/ATen/native/cuda/im2col.cuh b/aten/src/ATen/native/cuda/im2col.cuh index 77e39de0d76a2..aee072fcea824 100644 --- a/aten/src/ATen/native/cuda/im2col.cuh +++ b/aten/src/ATen/native/cuda/im2col.cuh @@ -108,7 +108,7 @@ void im2col( height_col, width_col, data_col); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -208,7 +208,7 @@ void col2im( output_height, output_width, data_im); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } // namespace native diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 1fa8e6be58d3e..27e25be626a59 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -278,9 +278,10 @@ void LayerNormKernelImplInternal( RowwiseMomentsCUDAKernel <<>>( N, eps, X_data, mean_data, rstd_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); LayerNormForwardCUDAKernel<<>>( N, X_data, mean_data, rstd_data, gamma_data, beta_data, Y_data); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } void LayerNormKernelImpl( @@ -339,6 +340,7 @@ void LayerNormBackwardKernelImplInternal( ComputeInternalGradientsCUDAKernel <<>>( N, dY_data, X_data, gamma_data, ds_data, db_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); const int64_t B = (M + kCUDANumThreads - 1) / kCUDANumThreads; ComputeGradientFusedParamsCUDAKernel <<>>( @@ -350,6 +352,7 @@ void LayerNormBackwardKernelImplInternal( db_data, scale_data, bias_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); LayerNormBackwardCUDAKenrel<<>>( N, dY_data, @@ -359,6 +362,7 @@ void LayerNormBackwardKernelImplInternal( scale_data, bias_data, dX_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } if (dgamma->defined() || dbeta->defined()) { T* dgamma_data = @@ -377,6 +381,7 @@ void LayerNormBackwardKernelImplInternal( rstd_data, dgamma_data, dbeta_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { const int64_t B = (N + kColwiseReduceTileSize - 1) / kColwiseReduceTileSize; @@ -392,6 +397,7 @@ void LayerNormBackwardKernelImplInternal( rstd_data, dgamma_data, dbeta_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } } @@ -417,47 +423,76 @@ void LayerNormBackwardKernelImpl( } // namespace std::tuple layer_norm_cuda( - const Tensor& X, - const Tensor& gamma /* optional */, - const Tensor& beta /* optional */, - int64_t M, - int64_t N, + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& weight /* optional */, + const Tensor& bias /* optional */, double eps) { + + auto inputs = _prepare_layer_norm_inputs(input, normalized_shape, weight, bias); + auto X = std::get<0>(inputs); + auto gamma = std::get<1>(inputs); + auto beta = std::get<2>(inputs); + auto M = std::get<3>(inputs); + auto N = std::get<4>(inputs); + Tensor Y = at::native::empty_like(X, LEGACY_CONTIGUOUS_MEMORY_FORMAT); Tensor mean = at::empty({M}, X.options()); Tensor rstd = at::empty({M}, X.options()); if (M > 0) { LayerNormKernelImpl(X, gamma, beta, M, N, eps, &Y, &mean, &rstd); + + const auto input_shape = input.sizes(); + const size_t axis = input.dim() - normalized_shape.size(); + + std::vector stat_shape; + for (size_t idx = 0; idx < axis; ++idx) { + stat_shape.push_back(input_shape[idx]); + } + for (size_t idx = axis; idx < input.dim(); ++idx) { + stat_shape.push_back(1); + } + + mean = mean.view(stat_shape); + rstd = rstd.view(stat_shape); } return std::make_tuple(std::move(Y), std::move(mean), std::move(rstd)); } std::tuple layer_norm_backward_cuda( const Tensor& dY, - const Tensor& X, + const Tensor& input, + IntArrayRef normalized_shape, const Tensor& mean, const Tensor& rstd, - const Tensor& gamma, - int64_t M, - int64_t N, + const Tensor& weight /* optional */, + const Tensor& bias /* optional */, std::array grad_input_mask) { - Tensor dX; - Tensor dgamma; - Tensor dbeta; - if (grad_input_mask[0]) { - dX = at::native::empty_like(X, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - } - if (grad_input_mask[1]) { - dgamma = M > 0 ? at::native::empty_like(gamma, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : at::native::zeros_like(gamma, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - } - if (grad_input_mask[2]) { - dbeta = M > 0 ? at::native::empty_like(gamma, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : at::native::zeros_like(gamma, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - } - if (M > 0) { - LayerNormBackwardKernelImpl( - dY, X, mean, rstd, gamma, M, N, &dX, &dgamma, &dbeta); - } - return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta)); + + auto inputs = _prepare_layer_norm_inputs(input, normalized_shape, weight, bias); + auto X = std::get<0>(inputs); + auto gamma = std::get<1>(inputs); + auto beta = std::get<2>(inputs); + auto M = std::get<3>(inputs); + auto N = std::get<4>(inputs); + + Tensor dX; + Tensor dgamma; + Tensor dbeta; + if (grad_input_mask[0]) { + dX = at::native::empty_like(X, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (grad_input_mask[1]) { + dgamma = M > 0 ? at::native::empty_like(gamma, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : at::native::zeros_like(gamma, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (grad_input_mask[2]) { + dbeta = M > 0 ? at::native::empty_like(beta, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : at::native::zeros_like(beta, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (M > 0) { + LayerNormBackwardKernelImpl( + dY, X, mean, rstd, gamma, M, N, &dX, &dgamma, &dbeta); + } + return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta)); } diff --git a/aten/src/ATen/native/cuda/vol2col.cuh b/aten/src/ATen/native/cuda/vol2col.cuh index 496d7578e2d5c..960b44e6d106e 100644 --- a/aten/src/ATen/native/cuda/vol2col.cuh +++ b/aten/src/ATen/native/cuda/vol2col.cuh @@ -129,7 +129,7 @@ void vol2col( height_col, width_col, data_col); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -264,7 +264,7 @@ void col2vol( output_height, output_width, data_vol); - AT_CUDA_CHECK(cudaGetLastError()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } // namespace native diff --git a/aten/src/ATen/native/cudnn/BatchNorm.cpp b/aten/src/ATen/native/cudnn/BatchNorm.cpp index 30053c2e02f24..5d86ee495926a 100644 --- a/aten/src/ATen/native/cudnn/BatchNorm.cpp +++ b/aten/src/ATen/native/cudnn/BatchNorm.cpp @@ -122,7 +122,7 @@ std::tuple cudnn_batch_norm( int64_t num_features = input_t.size(1); save_mean = at::empty({ num_features }, weight_t.options()); save_var = at::empty({ num_features }, weight_t.options()); - + #if CUDNN_VERSION >= 7400 auto op = CUDNN_BATCHNORM_OPS_BN; size_t workspace_size; @@ -225,6 +225,7 @@ std::tuple cudnn_batch_norm_backward( { // TODO: Is it worth it to have a contiguous call or maybe we should go with // whatever format is given here. + TensorArg input{ input_t, "input", 1 }, grad_output{ grad_output_t.contiguous(input_t.suggest_memory_format()), "grad_output", 2 }, weight{ weight_t, "weight", 3 }, @@ -246,7 +247,7 @@ std::tuple cudnn_batch_norm_backward( checkAllContiguous(c, {save_mean, save_var}); // TODO: TensorArg check should start handle memory format TORCH_CHECK(input->is_contiguous(input->suggest_memory_format())); - TORCH_CHECK(grad_output->is_contiguous(grad_output->suggest_memory_format())); + TORCH_CHECK(grad_output->is_contiguous(input->suggest_memory_format())); checkDimRange(c, input, 2, 6 /* exclusive */); checkSameSize(c, input, grad_output); auto num_features = input->size(1); @@ -310,7 +311,7 @@ std::tuple cudnn_batch_norm_backward( odesc.desc(), grad_output->data_ptr(), nullptr, nullptr, idesc.desc(), grad_input_t.data_ptr(), - wdesc.desc(), weight->data_ptr(), + wdesc.desc(), weight->data_ptr(), nullptr, grad_weight_t.data_ptr(), grad_bias_t.data_ptr(), diff --git a/aten/src/ATen/native/cudnn/ConvPlaceholders.cpp b/aten/src/ATen/native/cudnn/ConvPlaceholders.cpp new file mode 100644 index 0000000000000..bac8df92a5fc3 --- /dev/null +++ b/aten/src/ATen/native/cudnn/ConvPlaceholders.cpp @@ -0,0 +1,147 @@ +#include // for the definition of AT_CUDNN_ENABLED +#include +#include + +namespace at { namespace native { + +// --------------------------------------------------------------------- +// +// Placeholder operators +// +// --------------------------------------------------------------------- + +#if !AT_CUDNN_ENABLED() + +// See Note [ATen preprocessor philosophy] + +at::Tensor cudnn_convolution( + const at::Tensor& input, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("cudnn_convolution: ATen not compiled with cuDNN support"); +} + +at::Tensor cudnn_convolution_backward_input( + IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("cudnn_convolution_backward_input: ATen not compiled with cuDNN support"); +} + +at::Tensor cudnn_convolution_backward_weight( + IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("cudnn_convolution_backward_weight: ATen not compiled with cuDNN support"); +} + +std::tuple cudnn_convolution_backward( + const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { + AT_ERROR("cudnn_convolution_backward: ATen not compiled with cuDNN support"); +} + +at::Tensor cudnn_convolution_transpose( + const at::Tensor& input, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("cudnn_convolution_transpose: ATen not compiled with cuDNN support"); +} + +at::Tensor cudnn_convolution_transpose_backward_input( + const at::Tensor& grad_output, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support"); +} + +at::Tensor cudnn_convolution_transpose_backward_weight( + IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("cudnn_convolution_transpose_backward_weight: ATen not compiled with cuDNN support"); +} + +std::tuple cudnn_convolution_transpose_backward( + const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { + AT_ERROR("cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support"); +} + +void raw_cudnn_convolution_forward_out( + const Tensor& output, const Tensor& input, const Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("raw_cudnn_convolution_forward_out: ATen not compiled with cuDNN support"); +} + +void raw_cudnn_convolution_backward_input_out( + const at::Tensor& grad_input, + const at::Tensor& grad_output, + const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("raw_cudnn_convolution_backward_input_out: ATen not compiled with cuDNN support"); +} + +void raw_cudnn_convolution_backward_weight_out( + const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) { + AT_ERROR("raw_cudnn_convolution_backward_weight_out: ATen not compiled with cuDNN support"); +} + +#endif // AT_CUDNN_ENABLED + +// --------------------------------------------------------------------- +// +// Deprecated operators +// +// --------------------------------------------------------------------- + +// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future +Tensor cudnn_convolution_deprecated( + const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias /* optional */, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic) { + auto output = at::cudnn_convolution(input, weight, padding, stride, dilation, groups, benchmark, deterministic); + if (bias.defined()) { + output = output + reshape_bias(input.dim(), bias); + } + return output; +} + +// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future +Tensor cudnn_convolution_deprecated2( + const Tensor& input_t, const Tensor& weight_t, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic) +{ + return at::cudnn_convolution(input_t, weight_t, padding, stride, dilation, groups, benchmark, deterministic, at::globalContext().allowTF32CuDNN()); +} + +// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future +Tensor cudnn_convolution_transpose_deprecated( + const Tensor& input, const Tensor& weight, const Tensor& bias /* optional */, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic) +{ + auto output = at::cudnn_convolution_transpose(input, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic); + if (bias.defined()) { + output = output + reshape_bias(input.dim(), bias); + } + return output; +} + +// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future +Tensor cudnn_convolution_transpose_deprecated2( + const Tensor& input_t, const Tensor& weight_t, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic) +{ + return at::cudnn_convolution_transpose(input_t, weight_t, padding, output_padding, stride, dilation, groups, benchmark, deterministic, at::globalContext().allowTF32CuDNN()); +} + +}} diff --git a/aten/src/ATen/native/cudnn/ConvShared.cpp b/aten/src/ATen/native/cudnn/ConvShared.cpp new file mode 100644 index 0000000000000..e360008e2707e --- /dev/null +++ b/aten/src/ATen/native/cudnn/ConvShared.cpp @@ -0,0 +1,500 @@ +#include // for the definition of AT_CUDNN_ENABLED + +#if AT_CUDNN_ENABLED() + +#include + +// NOTE [cuDNN API version] +// +// ConvPlaceholders.cpp contains placeholder implementation of cudnn +// convolution when cudnn is not enabled. These operators only raises +// errors, and do no real computation. This file also contains deprecated +// operators. These operators are implemented using currnet operators. +// +// cuDNN v7 and v8 have different API. ConvShared.{cpp, h} contains +// code shared by v7 and v8. Conv_v7.cpp contains implementation of +// convolution using cuDNN v7 API. Conv_v8.cpp contains implementation +// with v8 API. +// +// NOTE [ Convolution design ] +// +// cuDNN convolutions does not handle bias. Bias is handled outside. +// +// The general strategy: +// +// - cudnn_convolution (Tensor) +// Entry points for clients +// +// - cudnn_convolution_forward (TensorArg) +// Entry point, which may be reused between regular +// convolution and transposed convolution. +// +// - raw_cudnn_convolution_forward_out (Tensor) +// Function that has different implementation on Conv_v7.cpp +// and Conv_v8.cpp +// +// The raw API directly invokes CuDNN and are implemeted differently +// on cuDNN v7 and cuDNN v8 +// +// There are a few reasons this should never be directly exposed +// via ATen: +// +// - It takes output as a parameter (this should be computed!) +// - It doesn't do input checking +// - It doesn't resize output (it is assumed to be correctly sized) +// +// Where does argument checking happen? Here's the division of +// responsibility: +// - Things that happen in at::Tensor +// - TensorArg allocation +// - Things that happen in TensorArg +// - Check arguments (type, GPU, shape) + +namespace at { namespace native { + +// --------------------------------------------------------------------- +// +// ConvolutionParams and ConvolutionArgs +// +// --------------------------------------------------------------------- + +std::ostream& operator<<(std::ostream & out, const ConvolutionParams& params) { + out << "ConvolutionParams \n" + << " data_type = " << cudnnTypeToString(params.dataType) << "\n" + << " padding = " << ArrayRef{params.padding} << "\n" + << " stride = " << ArrayRef{params.stride} << "\n" + << " dilation = " << ArrayRef{params.dilation} << "\n" + << " groups = " << params.groups << "\n" + << " deterministic = " << (params.deterministic ? "true" : "false") << "\n" + << " allow_tf32 = " << (params.allow_tf32 ? "true" : "false") << "\n"; + + return out; +} + +// NB: This can't be a constructor, because then ConvolutionParams +// would not be a POD anymore. +// TODO: Use TensorGeometry here instead of the entire Tensor, which we +// don't actually need. (OTOH: We can always pass in +// grad_input/grad_output, so this is not very pressing) +void setConvolutionParams( + ConvolutionParams* params, + const at::Tensor& input, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool deterministic, bool allow_tf32) { + + cudnnDataType_t dataType = getCudnnDataType(input); + memset(params, 0, sizeof(ConvolutionParams)); + params->dataType = dataType; + // ASSERT(weight.dim() == input.dim()) + for (int i = 0; i != input.dim(); ++i) { + params->input_size[i] = (int) input.size(i); + params->input_stride[i] = (int) input.stride(i); + params->weight_size[i] = (int) weight.size(i); + } + // ASSERT(padding.size() == stride.size()) + // ASSERT(padding.size() == dilation.size()) + for (size_t i = 0; i != padding.size(); ++i) { + params->padding[i] = padding[i]; + params->stride[i] = stride[i]; + params->dilation[i] = dilation[i]; + } + // In principle, we shouldn't parametrize by groups for legacy + // CuDNN, but it doesn't seem worth the effort to actually do this. + params->groups = groups; + params->deterministic = deterministic; + params->allow_tf32 = allow_tf32; +} + +std::string repro_from_args(const ConvolutionArgs& args) { + auto pybool = [](bool b) -> const char* { return b ? "True" : "False"; }; + std::string partial_dtype; + switch (args.params.dataType) { + case CUDNN_DATA_FLOAT: partial_dtype = "float"; break; + case CUDNN_DATA_DOUBLE: partial_dtype = "double"; break; + case CUDNN_DATA_HALF: partial_dtype = "half"; break; + default: partial_dtype = "unsupported"; + } + const std::string full_dtype = "torch." + partial_dtype; + const int out_channels = args.weight.sizes()[0]; + const int in_channels = args.weight.sizes()[1] * args.params.groups; + const size_t dim = args.input.sizes().size(); + const std::string channels_last_xd = dim == 4 ? "channels_last" : "channels_last_3d"; + const std::string to_channels_last = args.input.suggest_memory_format() == at::MemoryFormat::ChannelsLast \ + ? ".to(memory_format=torch." + channels_last_xd + ")" : ""; + + std::ostringstream ss; + ss << "You can try to repro this exception using the following code snippet. "; + ss << "If that doesn't trigger the error, please include your original repro script when reporting this issue.\n\n"; + ss << "import torch\n"; + ss << "torch.backends.cuda.matmul.allow_tf32 = " << pybool(at::globalContext().allowTF32CuBLAS()) << "\n"; + ss << "torch.backends.cudnn.benchmark = " << pybool(at::globalContext().benchmarkCuDNN()) << "\n"; + ss << "torch.backends.cudnn.deterministic = " << pybool(args.params.deterministic) << "\n"; + ss << "torch.backends.cudnn.allow_tf32 = " << pybool(args.params.allow_tf32) << "\n"; + ss << "data = torch.randn(" << args.input.sizes() << ", dtype=" << full_dtype << ", "; + ss << "device='cuda', requires_grad=True)" << to_channels_last << "\n"; + ss << "net = torch.nn.Conv" << dim-2 << "d(" << in_channels << ", " << out_channels << ", "; + ss << "kernel_size=" << args.weight.sizes().slice(2) << ", "; + ss << "padding=" << ArrayRef(args.params.padding, dim-2) << ", "; + ss << "stride=" << ArrayRef(args.params.stride, dim-2) << ", "; + ss << "dilation=" << ArrayRef(args.params.dilation, dim-2) << ", "; + ss << "groups=" << args.params.groups << ")\n"; + ss << "net = net.cuda()." << partial_dtype << "()" << to_channels_last << "\n"; + ss << "out = net(data)\n"; + ss << "out.backward(torch.randn_like(out))\n"; + ss << "torch.cuda.synchronize()\n\n"; + + return ss.str(); +} + +std::ostream& operator<<(std::ostream & out, const ConvolutionArgs& args) { + out << repro_from_args(args) // already has a trailing newline + << args.params // already has a trailing newline + << "input: " << args.idesc // already has a trailing newline + << "output: " << args.odesc // already has a trailing newline + << "weight: " << args.wdesc // already has a trailing newline + << "Pointer addresses: " << "\n" + << " input: " << args.input.data_ptr() << "\n" + << " output: " << args.output.data_ptr() << "\n" + << " weight: " << args.weight.data_ptr() << "\n"; + + return out; +} + +// --------------------------------------------------------------------- +// +// Checking +// +// --------------------------------------------------------------------- + +// Used on pad, stride and dilation +static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name) +{ + TORCH_CHECK(args.size() <= expected_size, + "Too many ", arg_name, " values (", args.size(), ") supplied, expecting ", + expected_size, " (while checking arguments for ", c, ")"); + TORCH_CHECK(args.size() >= expected_size, + "Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ", + expected_size, " (while checking arguments for ", c, ")"); + + auto num_negative_values = std::count_if(args.begin(), args.end(), [](int x){return x < 0;}); + if (num_negative_values > 0){ + std::stringstream ss; + ss << arg_name << " should be greater than zero but got ("; + std::copy(args.begin(), args.end() - 1, std::ostream_iterator(ss,", ")); + ss << args.back() << ")" << " (while checking arguments for " << c << ")"; + AT_ERROR(ss.str()); + } +} + + +// NOTE [ Convolution checks ] +// +// NB: For many call sites, it is not strictly necessary to check all of +// these relationships (for example, for forward convolution, we compute +// the size of output ourselves, so we don't actually need to check +// output. However, writing a single function that does everything +// means we get to reuse it for both forwards and all backwards +// variants, even when the set of "real" inputs varies. The magic of +// relational computing! +// +// (There is one downside, which is that it is slightly harder to write +// error messages which are able to distinguish between real inputs +// (which the user can change) and computed inputs (which the user can +// only indirectly affect). It would be an interesting exercise to +// come up with a general framework to handle such situations.) +static void convolution_shape_check( + CheckedFrom c, + const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) +{ + check_args(c, padding, input->dim() - 2, "padding"); + check_args(c, stride, padding.size(), "stride"); + check_args(c, dilation, padding.size(), "dilation"); + + // Input + checkDimRange(c, input, 3, 6 /* exclusive */); + checkSize(c, input, input_channels_dim, weight->size(1) * groups); + + // Weight + checkSameDim(c, input, weight); + + // TODO: check that output->size() matches output_sizes + // TODO: check that weight matches output->sizes() + checkSameDim(c, input, output); +} + +// --------------------------------------------------------------------- +// +// Convolution forward / Transposed convolution backward +// +// --------------------------------------------------------------------- + +Tensor cudnn_convolution_forward( + CheckedFrom c, + const TensorArg& input, const TensorArg& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) +{ + checkAllSameType(c, {input, weight}); + checkAllSameGPU(c, {input, weight}); + + auto layout = cudnn_conv_use_channels_last(*input, *weight) ? + at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous; + auto output_t = at::empty( + conv_output_size(input->sizes(), weight->sizes(), + padding, stride, dilation), + input->options(), + layout); + + if (output_t.numel() == 0) { + return output_t; + } + + // Avoid ambiguity of "output" when this is being used as backwards + TensorArg output{ output_t, "result", 0 }; + convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups); + + // See #4500 + Tensor weight_contig = weight->contiguous(layout); + // Make sure that NC11 strides follow formula + weight_contig.resize_(weight_contig.sizes(), layout); + Tensor input_contig = input->contiguous(layout); + input_contig.resize_(input_contig.sizes(), layout); + + raw_cudnn_convolution_forward_out( + *output, input_contig, weight_contig, + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + + return *output; +} + +Tensor cudnn_convolution( + const Tensor& input_t, const Tensor& weight_t, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) +{ + TensorArg input { input_t, "input", 1 }, + weight { weight_t, "weight", 2 }; + CheckedFrom c = "cudnn_convolution"; + auto output_t = cudnn_convolution_forward( + c, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + return output_t; +} + +// NB: output_padding not needed here, as there is no ambiguity to +// resolve +Tensor cudnn_convolution_transpose_backward_input( + const Tensor& grad_output_t, const Tensor& weight_t, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) +{ + TensorArg grad_output { grad_output_t, "grad_output", 1 }, + weight { weight_t, "weight", 2 }; + return cudnn_convolution_forward( + "cudnn_convolution_transpose_backward_input", + grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); +} + +std::tuple cudnn_convolution_transpose_backward( + const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { + + Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format()); + + Tensor grad_input, grad_weight; + if (output_mask[0]) { + grad_input = at::cudnn_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + } + if (output_mask[1]) { + grad_weight = at::cudnn_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + } + + return std::tuple{grad_input, grad_weight}; +} + +// --------------------------------------------------------------------- +// +// Convolution backward / Transposed convolution forward +// +// --------------------------------------------------------------------- + +// NOTE [ Backward vs transpose convolutions ] +// +// Backward and transpose are algorithmically equivalent, but they +// compute their geometry differently. In a backwards, you knew what +// the original size of the input tensor was, so you can cache that +// geometry and fill it directly. In transposed convolution, it is +// more conventional to not explicitly specify the output (previously +// input) size, and compute it. This, however, leaves a degree of +// freedom; this degree of freedom is resolved using the +// output_padding parameter. Both of these interfaces are equivalent, +// but they are differently convenient depending on the use case. + +Tensor cudnn_convolution_backward_input( + CheckedFrom c, + IntArrayRef input_size, const TensorArg& grad_output, const TensorArg& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) +{ + checkAllSameType(c, {grad_output, weight}); + checkAllSameGPU(c, {grad_output, weight}); + + auto layout = cudnn_conv_use_channels_last(*grad_output, *weight) ? + at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous; + auto grad_input_t = at::empty(input_size, grad_output->options(), layout); + + // Avoid "grad_input" when this is being used as transposed convolution + TensorArg grad_input{ grad_input_t, "result", 0 }; + convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups); + + // See #4500 + Tensor weight_contig = weight->contiguous(layout); + // Make sure that NC11 strides follow formula + weight_contig.resize_(weight_contig.sizes(), layout); + + Tensor grad_output_contig = grad_output->contiguous(layout); + grad_output_contig.resize_(grad_output_contig.sizes(), layout); + + raw_cudnn_convolution_backward_input_out( + *grad_input, grad_output_contig, weight_contig, + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + + return *grad_input; +} + +Tensor cudnn_convolution_transpose_forward( + CheckedFrom c, + const TensorArg& grad_output, const TensorArg& weight, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) +{ + auto input_size = conv_input_size(grad_output->sizes(), weight->sizes(), + padding, output_padding, stride, dilation, groups); + return cudnn_convolution_backward_input(c, input_size, grad_output, weight, + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); +} + +Tensor cudnn_convolution_backward_input( + IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) +{ + TensorArg grad_output{ grad_output_t, "grad_output", 1 }, + weight{ weight_t, "weight", 2 }; + return cudnn_convolution_backward_input( + "cudnn_convolution_backward_input", + input_size, grad_output, weight, + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); +} + +std::tuple cudnn_convolution_backward( + const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { + + Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format()); + + Tensor grad_input, grad_weight; + if (input.numel() == 0) { + if (output_mask[0]) { + grad_input = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (output_mask[1]) { + grad_weight = at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + } else { + if (output_mask[0]) { + grad_input = at::cudnn_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + } + if (output_mask[1]) { + grad_weight = at::cudnn_convolution_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + } + } + + return std::tuple{grad_input, grad_weight}; +} + +Tensor cudnn_convolution_transpose( + const Tensor& input_t, const Tensor& weight_t, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) +{ + TensorArg input { input_t, "input", 1 }, + weight { weight_t, "weight", 2 }; + CheckedFrom c = "cudnn_convolution_transpose"; + auto output_t = cudnn_convolution_transpose_forward( + c, input, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + return output_t; +} + +// --------------------------------------------------------------------- +// +// Convolution backward (weight) +// +// --------------------------------------------------------------------- + +Tensor cudnn_convolution_backward_weight( + CheckedFrom c, + IntArrayRef weight_size, const Tensor& grad_output_t, const Tensor& input_t, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) +{ + auto layout = cudnn_conv_use_channels_last(input_t, grad_output_t) ? + at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous; + + Tensor grad_output_contig_t = grad_output_t.contiguous(layout); + // Make sure that NC11 strides follow formula + grad_output_contig_t.resize_(grad_output_contig_t.sizes(), layout); + TensorArg grad_output_contig{ grad_output_contig_t, "grad_output", 1 }; + + Tensor input_contig_t = input_t.contiguous(layout); + input_contig_t.resize_(input_contig_t.sizes(), layout); + TensorArg input{ input_contig_t, "input", 2}; + + checkAllSameType(c, {grad_output_contig, input}); + checkAllSameGPU(c, {grad_output_contig, input}); + + auto grad_weight_t = at::empty(weight_size, grad_output_contig->options(), layout); + + // For uniformity with everything else, although it seems grad_weight + // would be unambiguous too. + TensorArg grad_weight{ grad_weight_t, "result", 0 }; + convolution_shape_check(c, input, grad_weight, grad_output_contig, padding, stride, dilation, groups); + + raw_cudnn_convolution_backward_weight_out( + *grad_weight, *grad_output_contig, *input, + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + + return grad_weight_t; +} + +Tensor cudnn_convolution_backward_weight( + IntArrayRef weight_size, + const Tensor& grad_output_t, + const Tensor& input_t, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) +{ + return cudnn_convolution_backward_weight( + "cudnn_convolution_backward_weight", + weight_size, grad_output_t, input_t, + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); +} + +Tensor cudnn_convolution_transpose_backward_weight( + IntArrayRef weight_size, + const Tensor& grad_output_t, + const Tensor& input_t, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32) +{ + return cudnn_convolution_backward_weight( + "cudnn_convolution_backward_weight", + weight_size, input_t, grad_output_t, + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); +} + +}} + +#endif // AT_CUDNN_ENABLED diff --git a/aten/src/ATen/native/cudnn/ConvShared.h b/aten/src/ATen/native/cudnn/ConvShared.h new file mode 100644 index 0000000000000..e30b5c7be581c --- /dev/null +++ b/aten/src/ATen/native/cudnn/ConvShared.h @@ -0,0 +1,88 @@ +#include + +#include +#include +#include +#include + +namespace at { namespace native { + +// --------------------------------------------------------------------- +// +// Helper classes +// +// --------------------------------------------------------------------- + +// This POD struct is used to let us easily compute hashes of the +// parameters +struct ConvolutionParams +{ + cudnnDataType_t dataType; + int input_size[2 + max_dim]; + int input_stride[2 + max_dim]; + int weight_size[2 + max_dim]; + int padding[max_dim]; + int stride[max_dim]; + int dilation[max_dim]; + int64_t groups; + bool deterministic; + bool allow_tf32; + // NB: transposed purposely omitted: transposed just swaps + // forward and backward, so you can reuse the benchmark entry, +}; + +// Convenience struct for passing around descriptors and data +// pointers +struct ConvolutionArgs { + cudnnHandle_t handle; + ConvolutionParams params; + TensorDescriptor idesc, odesc; + FilterDescriptor wdesc; + const Tensor& input, output, weight; + ConvolutionDescriptor cdesc; + + ConvolutionArgs(const Tensor& input, const Tensor& output, const Tensor& weight) : input(input), output(output), weight(weight) { + } +}; + +std::ostream& operator<<(std::ostream & out, const ConvolutionParams& params); + +// NB: This can't be a constructor, because then ConvolutionParams +// would not be a POD anymore. +// TODO: Use TensorGeometry here instead of the entire Tensor, which we +// don't actually need. (OTOH: We can always pass in +// grad_input/grad_output, so this is not very pressing) +void setConvolutionParams( + ConvolutionParams* params, + const at::Tensor& input, const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool deterministic, bool allow_tf32); + +std::string repro_from_args(const ConvolutionArgs& args); + +std::ostream& operator<<(std::ostream & out, const ConvolutionArgs& args); + +// --------------------------------------------------------------------- +// +// Raw functions +// +// --------------------------------------------------------------------- + +void raw_cudnn_convolution_forward_out( + const Tensor& output, const Tensor& input, const Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32); + +void raw_cudnn_convolution_backward_input_out( + const at::Tensor& grad_input, + const at::Tensor& grad_output, + const at::Tensor& weight, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32); + +void raw_cudnn_convolution_backward_weight_out( + const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, + bool benchmark, bool deterministic, bool allow_tf32); + +}} diff --git a/aten/src/ATen/native/cudnn/Conv.cpp b/aten/src/ATen/native/cudnn/Conv_v7.cpp similarity index 54% rename from aten/src/ATen/native/cudnn/Conv.cpp rename to aten/src/ATen/native/cudnn/Conv_v7.cpp index 4ddd533ec8f8b..5e1f124f11854 100644 --- a/aten/src/ATen/native/cudnn/Conv.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v7.cpp @@ -1,83 +1,18 @@ +#include // for the definition of AT_CUDNN_ENABLED + +#if AT_CUDNN_ENABLED() + #include #include +#include #include #include #include #include -#include #include -#include - -#if !AT_CUDNN_ENABLED() - -namespace at { namespace native { - -// See Note [ATen preprocessor philosophy] - -at::Tensor cudnn_convolution( - const at::Tensor& input, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR("cudnn_convolution: ATen not compiled with cuDNN support"); -} - -at::Tensor cudnn_convolution_backward_input( - IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR("cudnn_convolution_backward_input: ATen not compiled with cuDNN support"); -} - -at::Tensor cudnn_convolution_backward_weight( - IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR("cudnn_convolution_backward_weight: ATen not compiled with cuDNN support"); -} - -std::tuple cudnn_convolution_backward( - const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { - AT_ERROR("cudnn_convolution_backward: ATen not compiled with cuDNN support"); -} - -at::Tensor cudnn_convolution_transpose( - const at::Tensor& input, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR("cudnn_convolution_transpose: ATen not compiled with cuDNN support"); -} - -at::Tensor cudnn_convolution_transpose_backward_input( - const at::Tensor& grad_output, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR("cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support"); -} - -at::Tensor cudnn_convolution_transpose_backward_weight( - IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR("cudnn_convolution_transpose_backward_weight: ATen not compiled with cuDNN support"); -} - -std::tuple cudnn_convolution_transpose_backward( - const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { - AT_ERROR("cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support"); -} - -}} - -#else // AT_CUDNN_ENABLED +#include #include - -#include -#include #include #include #include @@ -100,6 +35,10 @@ std::tuple cudnn_convolution_transpose_backward( // if(dataType == CUDNN_DATA_HALF) // AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH)); // +// Update: AT_CUDNN_CHECK is updated with AT_CUDNN_CHECK_WITH_SHAPES, which +// automatically prints tensor shapes and convolution parameters if there is +// a cuDNN exception thrown. +// // When cudnnSetConvolutionMathType is called before cudnnGet/cudnnFind, it informs // cudnnGet/cudnnFind to iterate/take into account both tensor core and non-tensor-core algos. // If you don't call cudnnSetConvolutionMathType before calling cudnnGet/cudnnFind, @@ -125,149 +64,6 @@ namespace at { namespace native { // TODO: Go through all the checking code again and make sure // we haven't missed anything. -// TODO: Move this into the standard library, with a better name? -Tensor narrowGroup(const Tensor& t, int dim, int group_idx, int64_t groups) { - auto group_size = t.size(dim) / groups; - return t.narrow(dim, group_idx * group_size, group_size); -} - -// --------------------------------------------------------------------- -// -// Checking -// -// --------------------------------------------------------------------- - -// Note [Legacy CuDNN grouped convolution support] -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// CuDNN earlier than CuDNN 7 does not directly support group -// convolution, so we provide support for it by sequentially -// running a convolution per group with appropriately -// adjusted sizes. https://blog.yani.io/filter-group-tutorial/ -// has a fairly good diagram explaining how it works. - -// Used on pad, stride and dilation -static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name) -{ - TORCH_CHECK(args.size() <= expected_size, - "Too many ", arg_name, " values (", args.size(), ") supplied, expecting ", - expected_size, " (while checking arguments for ", c, ")"); - TORCH_CHECK(args.size() >= expected_size, - "Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ", - expected_size, " (while checking arguments for ", c, ")"); - - auto num_negative_values = std::count_if(args.begin(), args.end(), [](int x){return x < 0;}); - if (num_negative_values > 0){ - std::stringstream ss; - ss << arg_name << " should be greater than zero but got ("; - std::copy(args.begin(), args.end() - 1, std::ostream_iterator(ss,", ")); - ss << args.back() << ")" << " (while checking arguments for " << c << ")"; - AT_ERROR(ss.str()); - } -} - - -// NOTE [ Convolution checks ] -// -// NB: For many call sites, it is not strictly necessary to check all of -// these relationships (for example, for forward convolution, we compute -// the size of output ourselves, so we don't actually need to check -// output. However, writing a single function that does everything -// means we get to reuse it for both forwards and all backwards -// variants, even when the set of "real" inputs varies. The magic of -// relational computing! -// -// (There is one downside, which is that it is slightly harder to write -// error messages which are able to distinguish between real inputs -// (which the user can change) and computed inputs (which the user can -// only indirectly affect). It would be an interesting exercise to -// come up with a general framework to handle such situations.) -static void convolution_shape_check( - CheckedFrom c, - const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) -{ - check_args(c, padding, input->dim() - 2, "padding"); - check_args(c, stride, padding.size(), "stride"); - check_args(c, dilation, padding.size(), "dilation"); - - // Input - checkDimRange(c, input, 3, 6 /* exclusive */); - checkSize(c, input, input_channels_dim, weight->size(1) * groups); - - // Weight - checkSameDim(c, input, weight); - - // TODO: check that output->size() matches output_sizes - // TODO: check that weight matches output->sizes() - checkSameDim(c, input, output); -} - -// This POD struct is used to let us easily compute hashes of the -// parameters -struct ConvolutionParams -{ - cudnnDataType_t dataType; - int input_size[2 + max_dim]; - int input_stride[2 + max_dim]; - int weight_size[2 + max_dim]; - int padding[max_dim]; - int stride[max_dim]; - int dilation[max_dim]; - int64_t groups; - bool deterministic; - bool allow_tf32; - // NB: transposed purposely omitted: transposed just swaps - // forward and backward, so you can reuse the benchmark entry, -}; - -// NB: This can't be a constructor, because then ConvolutionParams -// would not be a POD anymore. -// TODO: Use TensorGeometry here instead of the entire Tensor, which we -// don't actually need. (OTOH: We can always pass in -// grad_input/grad_output, so this is not very pressing) -void setConvolutionParams( - ConvolutionParams* params, - const at::Tensor& input, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool deterministic, bool allow_tf32) { - - cudnnDataType_t dataType = getCudnnDataType(input); - memset(params, 0, sizeof(ConvolutionParams)); - params->dataType = dataType; - // ASSERT(weight.dim() == input.dim()) - for (int i = 0; i != input.dim(); ++i) { - params->input_size[i] = (int) input.size(i); - params->input_stride[i] = (int) input.stride(i); - params->weight_size[i] = (int) weight.size(i); - } - // ASSERT(padding.size() == stride.size()) - // ASSERT(padding.size() == dilation.size()) - for (size_t i = 0; i != padding.size(); ++i) { - params->padding[i] = padding[i]; - params->stride[i] = stride[i]; - params->dilation[i] = dilation[i]; - } - // In principle, we shouldn't parametrize by groups for legacy - // CuDNN, but it doesn't seem worth the effort to actually do this. - params->groups = groups; - params->deterministic = deterministic; - params->allow_tf32 = allow_tf32; -} - -// Convenience struct for passing around descriptors and data -// pointers -struct ConvolutionArgs { - cudnnHandle_t handle; - ConvolutionParams params; - TensorDescriptor idesc, odesc; - FilterDescriptor wdesc; - const Tensor& input, output, weight; - ConvolutionDescriptor cdesc; - - ConvolutionArgs(const Tensor& input, const Tensor& output, const Tensor& weight) : input(input), output(output), weight(weight) { - } -}; - // --------------------------------------------------------------------- // // Benchmarking @@ -457,7 +253,7 @@ struct algorithm_search { int perf_count; std::unique_ptr perf_results(new perf_t[num_algos]); if (!benchmark) { - AT_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7( + AT_CUDNN_CHECK_WITH_SHAPES(cudnnGetConvolutionForwardAlgorithm_v7( args.handle, args.idesc.desc(), args.wdesc.desc(), @@ -465,11 +261,11 @@ struct algorithm_search { args.odesc.desc(), num_algos, &perf_count, - perf_results.get())); + perf_results.get()), args); } else { size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos); Workspace ws(max_ws_size); - AT_CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx( + AT_CUDNN_CHECK_WITH_SHAPES(cudnnFindConvolutionForwardAlgorithmEx( args.handle, args.idesc.desc(), args.input.data_ptr(), args.wdesc.desc(), args.weight.data_ptr(), @@ -479,7 +275,7 @@ struct algorithm_search { &perf_count, perf_results.get(), ws.data, - ws.size)); + ws.size), args); // Free the cached blocks in our caching allocator. They are // needed here because the above benchmarking uses a huge amount of memory, @@ -493,14 +289,14 @@ struct algorithm_search { const ConvolutionArgs& args, algo_t algo, size_t* workspaceSize) { - AT_CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize( + AT_CUDNN_CHECK_WITH_SHAPES(cudnnGetConvolutionForwardWorkspaceSize( args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(), args.odesc.desc(), algo, - workspaceSize)); + workspaceSize), args); } }; @@ -527,7 +323,7 @@ struct algorithm_search { int perf_count; std::unique_ptr perf_results(new perf_t[num_algos]); if (!benchmark) { - AT_CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm_v7( + AT_CUDNN_CHECK_WITH_SHAPES(cudnnGetConvolutionBackwardDataAlgorithm_v7( args.handle, args.wdesc.desc(), args.odesc.desc(), @@ -535,11 +331,11 @@ struct algorithm_search { args.idesc.desc(), num_algos, &perf_count, - perf_results.get())); + perf_results.get()), args); } else { size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos); Workspace ws(max_ws_size); - AT_CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx( + AT_CUDNN_CHECK_WITH_SHAPES(cudnnFindConvolutionBackwardDataAlgorithmEx( args.handle, args.wdesc.desc(), args.weight.data_ptr(), args.odesc.desc(), args.output.data_ptr(), @@ -549,7 +345,7 @@ struct algorithm_search { &perf_count, perf_results.get(), ws.data, - ws.size)); + ws.size), args); // Free the cached blocks in our caching allocator. They are // needed here because the above benchmarking uses a huge amount of memory, @@ -563,14 +359,14 @@ struct algorithm_search { const ConvolutionArgs& args, cudnnConvolutionBwdDataAlgo_t algo, size_t* workspaceSize) { - AT_CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize( + AT_CUDNN_CHECK_WITH_SHAPES(cudnnGetConvolutionBackwardDataWorkspaceSize( args.handle, args.wdesc.desc(), args.odesc.desc(), args.cdesc.desc(), args.idesc.desc(), algo, - workspaceSize)); + workspaceSize), args); } }; @@ -599,7 +395,7 @@ struct algorithm_search { std::unique_ptr perf_results(new perf_t[num_algos]); int perf_count; if (!benchmark) { - AT_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm_v7( + AT_CUDNN_CHECK_WITH_SHAPES(cudnnGetConvolutionBackwardFilterAlgorithm_v7( args.handle, args.idesc.desc(), args.odesc.desc(), @@ -607,11 +403,11 @@ struct algorithm_search { args.wdesc.desc(), num_algos, &perf_count, - perf_results.get())); + perf_results.get()), args); } else { size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos); Workspace ws(max_ws_size); - AT_CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx( + AT_CUDNN_CHECK_WITH_SHAPES(cudnnFindConvolutionBackwardFilterAlgorithmEx( args.handle, args.idesc.desc(), args.input.data_ptr(), args.odesc.desc(), args.output.data_ptr(), @@ -621,7 +417,7 @@ struct algorithm_search { &perf_count, perf_results.get(), ws.data, - ws.size)); + ws.size), args); // Free the cached blocks in our caching allocator. They are // needed here because the above benchmarking uses a huge amount of memory, @@ -633,14 +429,14 @@ struct algorithm_search { static void getWorkspaceSize(const ConvolutionArgs& args, algo_t algo, size_t* workspaceSize) { - AT_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize( + AT_CUDNN_CHECK_WITH_SHAPES(cudnnGetConvolutionBackwardFilterWorkspaceSize( args.handle, args.idesc.desc(), args.odesc.desc(), args.cdesc.desc(), args.wdesc.desc(), algo, - workspaceSize)); + workspaceSize), args); } }; @@ -708,18 +504,7 @@ inline Tensor allocate_workspace(size_t size, const Tensor &other) { return at::empty({static_cast(size)}, other.options().dtype(kByte)); } -// NOTE [ Convolution design ] -// -// cuDNN convolutions does not handle bias. Bias is handled outside. -// -// The general strategy: -// -// - cudnn_convolution (Tensor) -// Entry points for clients -// -// - cudnn_convolution_forward (TensorArg) -// Entry point, which may be reused between regular -// convolution and transposed convolution. +// NOTE [ raw_cudnn_convolution_forward_out ] // // - raw_cudnn_convolution_forward_out (Tensor) // Functiont that handles tensors that are too large to use 32bit indexing. @@ -729,14 +514,6 @@ inline Tensor allocate_workspace(size_t size, const Tensor &other) { // Low level function which invokes CuDNN, and takes an output // tensor which is directly written to (thus _out). // -// Where does argument checking happen? Here's the division of -// responsibility: -// - Things that happen in at::Tensor -// - TensorArg allocation -// - Things that happen in TensorArg -// - Check arguments (type, GPU, shape) -// -// TODO: Consider renaming zero-indexed arguments to "self" // --------------------------------------------------------------------- @@ -812,16 +589,6 @@ if (args.params.dataType == CUDNN_DATA_FLOAT) { // // --------------------------------------------------------------------- -// The raw API directly invokes CuDNN and does not emulate support -// for group convolution on old versions of CuDNN. -// -// There are a few reasons this should never be directly exposed -// via ATen: -// -// - It takes output as a parameter (this should be computed!) -// - It doesn't do input checking -// - It doesn't resize output (it is assumed to be correctly sized) -// void raw_cudnn_convolution_forward_out_32bit( const Tensor& output, const Tensor& input, const Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, @@ -850,17 +617,18 @@ void raw_cudnn_convolution_forward_out_32bit( // whether to use Tensor core kernels or not // See Note [behavior of cudnnFind and cudnnGet] ASSERT_CORRECT_PRECISION(fwdAlgPerf.mathType); - AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), fwdAlgPerf.mathType)); + AT_CUDNN_CHECK_WITH_SHAPES(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), fwdAlgPerf.mathType), args); Constant one(dataType, 1); Constant zero(dataType, 0); - AT_CUDNN_CHECK(cudnnConvolutionForward( - args.handle, - &one, args.idesc.desc(), input.data_ptr(), - args.wdesc.desc(), weight.data_ptr(), - args.cdesc.desc(), fwdAlgPerf.algo, workspace.data_ptr(), fwdAlgPerf.memory, - &zero, args.odesc.desc(), output.data_ptr())); + AT_CUDNN_CHECK_WITH_SHAPES(cudnnConvolutionForward( + args.handle, + &one, args.idesc.desc(), input.data_ptr(), + args.wdesc.desc(), weight.data_ptr(), + args.cdesc.desc(), fwdAlgPerf.algo, workspace.data_ptr(), fwdAlgPerf.memory, + &zero, args.odesc.desc(), output.data_ptr()), + args, "Forward algorithm: ", static_cast(fwdAlgPerf.algo), "\n"); } ); } @@ -872,90 +640,6 @@ void raw_cudnn_convolution_forward_out( split_batch_dim_to_32bit_out(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, 1024 * 1024 * 256, raw_cudnn_convolution_forward_out_32bit); } -Tensor cudnn_convolution_forward( - CheckedFrom c, - const TensorArg& input, const TensorArg& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) -{ - checkAllSameType(c, {input, weight}); - checkAllSameGPU(c, {input, weight}); - - auto layout = cudnn_conv_use_channels_last(*input, *weight) ? - at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous; - auto output_t = at::empty( - conv_output_size(input->sizes(), weight->sizes(), - padding, stride, dilation), - input->options(), - layout); - - if (output_t.numel() == 0) { - return output_t; - } - - // Avoid ambiguity of "output" when this is being used as backwards - TensorArg output{ output_t, "result", 0 }; - convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups); - - // See #4500 - Tensor weight_contig = weight->contiguous(layout); - // Make sure that NC11 strides follow formula - weight_contig.resize_(weight_contig.sizes(), layout); - Tensor input_contig = input->contiguous(layout); - input_contig.resize_(input_contig.sizes(), layout); - - raw_cudnn_convolution_forward_out( - *output, input_contig, weight_contig, - padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - - return *output; -} - -Tensor cudnn_convolution( - const Tensor& input_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) -{ - TensorArg input { input_t, "input", 1 }, - weight { weight_t, "weight", 2 }; - CheckedFrom c = "cudnn_convolution"; - auto output_t = cudnn_convolution_forward( - c, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - return output_t; -} - -// NB: output_padding not needed here, as there is no ambiguity to -// resolve -Tensor cudnn_convolution_transpose_backward_input( - const Tensor& grad_output_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) -{ - TensorArg grad_output { grad_output_t, "grad_output", 1 }, - weight { weight_t, "weight", 2 }; - return cudnn_convolution_forward( - "cudnn_convolution_transpose_backward_input", - grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); -} - -std::tuple cudnn_convolution_transpose_backward( - const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { - - Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format()); - - Tensor grad_input, grad_weight; - if (output_mask[0]) { - grad_input = at::cudnn_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - } - if (output_mask[1]) { - grad_weight = at::cudnn_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - } - - return std::tuple{grad_input, grad_weight}; -} - // --------------------------------------------------------------------- // // Convolution backward / Transposed convolution forward @@ -986,17 +670,22 @@ void raw_cudnn_convolution_backward_input_out_32bit( // whether to use Tensor core kernels or not // See Note [behavior of cudnnFind and cudnnGet] ASSERT_CORRECT_PRECISION(bwdDataAlgPerf.mathType); - AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), bwdDataAlgPerf.mathType)); + AT_CUDNN_CHECK_WITH_SHAPES(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), bwdDataAlgPerf.mathType), args); Constant one(dataType, 1); Constant zero(dataType, 0); - AT_CUDNN_CHECK(cudnnConvolutionBackwardData( + AT_CUDNN_CHECK_WITH_SHAPES(cudnnConvolutionBackwardData( args.handle, &one, args.wdesc.desc(), weight.data_ptr(), args.odesc.desc(), grad_output.data_ptr(), args.cdesc.desc(), bwdDataAlgPerf.algo, workspace.data_ptr(), bwdDataAlgPerf.memory, - &zero, args.idesc.desc(), grad_input.data_ptr())); + &zero, args.idesc.desc(), grad_input.data_ptr()), + args, + "Additional pointer addresses: \n", + " grad_output: ", grad_output.data_ptr(), "\n", + " grad_input: ", grad_input.data_ptr(), "\n", + "Backward data algorithm: ", static_cast(bwdDataAlgPerf.algo), "\n"); } ); } @@ -1010,115 +699,6 @@ void raw_cudnn_convolution_backward_input_out( split_batch_dim_to_32bit_out(grad_input, grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, 1024 * 1024 * 128, raw_cudnn_convolution_backward_input_out_32bit); } -// NOTE [ Backward vs transpose convolutions ] -// -// Backward and transpose are algorithmically equivalent, but they -// compute their geometry differently. In a backwards, you knew what -// the original size of the input tensor was, so you can cache that -// geometry and fill it directly. In transposed convolution, it is -// more conventional to not explicitly specify the output (previously -// input) size, and compute it. This, however, leaves a degree of -// freedom; this degree of freedom is resolved using the -// output_padding parameter. Both of these interfaces are equivalent, -// but they are differently convenient depending on the use case. - -Tensor cudnn_convolution_backward_input( - CheckedFrom c, - IntArrayRef input_size, const TensorArg& grad_output, const TensorArg& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) -{ - checkAllSameType(c, {grad_output, weight}); - checkAllSameGPU(c, {grad_output, weight}); - - auto layout = cudnn_conv_use_channels_last(*grad_output, *weight) ? - at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous; - auto grad_input_t = at::empty(input_size, grad_output->options(), layout); - - // Avoid "grad_input" when this is being used as transposed convolution - TensorArg grad_input{ grad_input_t, "result", 0 }; - convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups); - - // See #4500 - Tensor weight_contig = weight->contiguous(layout); - // Make sure that NC11 strides follow formula - weight_contig.resize_(weight_contig.sizes(), layout); - - Tensor grad_output_contig = grad_output->contiguous(layout); - grad_output_contig.resize_(grad_output_contig.sizes(), layout); - - raw_cudnn_convolution_backward_input_out( - *grad_input, grad_output_contig, weight_contig, - padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - - return *grad_input; -} - -Tensor cudnn_convolution_transpose_forward( - CheckedFrom c, - const TensorArg& grad_output, const TensorArg& weight, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) -{ - auto input_size = conv_input_size(grad_output->sizes(), weight->sizes(), - padding, output_padding, stride, dilation, groups); - return cudnn_convolution_backward_input(c, input_size, grad_output, weight, - padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); -} - -Tensor cudnn_convolution_backward_input( - IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) -{ - TensorArg grad_output{ grad_output_t, "grad_output", 1 }, - weight{ weight_t, "weight", 2 }; - return cudnn_convolution_backward_input( - "cudnn_convolution_backward_input", - input_size, grad_output, weight, - padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); -} - -std::tuple cudnn_convolution_backward( - const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { - - Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format()); - - Tensor grad_input, grad_weight; - if (input.numel() == 0) { - if (output_mask[0]) { - grad_input = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - } - if (output_mask[1]) { - grad_weight = at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - } - } else { - if (output_mask[0]) { - grad_input = at::cudnn_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - } - if (output_mask[1]) { - grad_weight = at::cudnn_convolution_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - } - } - - return std::tuple{grad_input, grad_weight}; -} - -Tensor cudnn_convolution_transpose( - const Tensor& input_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) -{ - TensorArg input { input_t, "input", 1 }, - weight { weight_t, "weight", 2 }; - CheckedFrom c = "cudnn_convolution_transpose"; - auto output_t = cudnn_convolution_transpose_forward( - c, input, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - return output_t; -} - // --------------------------------------------------------------------- // // Convolution backward (weight) @@ -1148,17 +728,22 @@ void raw_cudnn_convolution_backward_weight_out_32bit( // whether to use Tensor core kernels or not // See Note [behavior of cudnnFind and cudnnGet] ASSERT_CORRECT_PRECISION(bwdFilterAlgPerf.mathType); - AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), bwdFilterAlgPerf.mathType)); + AT_CUDNN_CHECK_WITH_SHAPES(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), bwdFilterAlgPerf.mathType), args); Constant one(dataType, 1); Constant zero(dataType, 0); - AT_CUDNN_CHECK(cudnnConvolutionBackwardFilter( + AT_CUDNN_CHECK_WITH_SHAPES(cudnnConvolutionBackwardFilter( args.handle, &one, args.idesc.desc(), input.data_ptr(), args.odesc.desc(), grad_output.data_ptr(), args.cdesc.desc(), bwdFilterAlgPerf.algo, workspace.data_ptr(), bwdFilterAlgPerf.memory, - &zero, args.wdesc.desc(), grad_weight.data_ptr())); + &zero, args.wdesc.desc(), grad_weight.data_ptr()), + args, + "Additional pointer addresses: \n", + " grad_output: ", grad_output.data_ptr(), "\n", + " grad_weight: ", grad_weight.data_ptr(), "\n", + "Backward filter algorithm: ", static_cast(bwdFilterAlgPerf.algo), "\n"); } ); } @@ -1211,115 +796,6 @@ void raw_cudnn_convolution_backward_weight_out( TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN."); } -Tensor cudnn_convolution_backward_weight( - CheckedFrom c, - IntArrayRef weight_size, const Tensor& grad_output_t, const Tensor& input_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) -{ - auto layout = cudnn_conv_use_channels_last(input_t, grad_output_t) ? - at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous; - - Tensor grad_output_contig_t = grad_output_t.contiguous(layout); - // Make sure that NC11 strides follow formula - grad_output_contig_t.resize_(grad_output_contig_t.sizes(), layout); - TensorArg grad_output_contig{ grad_output_contig_t, "grad_output", 1 }; - - Tensor input_contig_t = input_t.contiguous(layout); - input_contig_t.resize_(input_contig_t.sizes(), layout); - TensorArg input{ input_contig_t, "input", 2}; - - checkAllSameType(c, {grad_output_contig, input}); - checkAllSameGPU(c, {grad_output_contig, input}); - - auto grad_weight_t = at::empty(weight_size, grad_output_contig->options(), layout); - - // For uniformity with everything else, although it seems grad_weight - // would be unambiguous too. - TensorArg grad_weight{ grad_weight_t, "result", 0 }; - convolution_shape_check(c, input, grad_weight, grad_output_contig, padding, stride, dilation, groups); - - raw_cudnn_convolution_backward_weight_out( - *grad_weight, *grad_output_contig, *input, - padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); - - return grad_weight_t; -} - -Tensor cudnn_convolution_backward_weight( - IntArrayRef weight_size, - const Tensor& grad_output_t, - const Tensor& input_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) -{ - return cudnn_convolution_backward_weight( - "cudnn_convolution_backward_weight", - weight_size, grad_output_t, input_t, - padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); -} - -Tensor cudnn_convolution_transpose_backward_weight( - IntArrayRef weight_size, - const Tensor& grad_output_t, - const Tensor& input_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, bool allow_tf32) -{ - return cudnn_convolution_backward_weight( - "cudnn_convolution_backward_weight", - weight_size, input_t, grad_output_t, - padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); -} - }} // namespace at::native #endif - - -namespace at { namespace native { - -// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future -Tensor cudnn_convolution_deprecated( - const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias /* optional */, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) { - auto output = at::cudnn_convolution(input, weight, padding, stride, dilation, groups, benchmark, deterministic); - if (bias.defined()) { - output = output + reshape_bias(input.dim(), bias); - } - return output; -} - -// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future -Tensor cudnn_convolution_deprecated2( - const Tensor& input_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) -{ - return at::cudnn_convolution(input_t, weight_t, padding, stride, dilation, groups, benchmark, deterministic, at::globalContext().allowTF32CuDNN()); -} - -// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future -Tensor cudnn_convolution_transpose_deprecated( - const Tensor& input, const Tensor& weight, const Tensor& bias /* optional */, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) -{ - auto output = at::cudnn_convolution_transpose(input, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic); - if (bias.defined()) { - output = output + reshape_bias(input.dim(), bias); - } - return output; -} - -// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future -Tensor cudnn_convolution_transpose_deprecated2( - const Tensor& input_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) -{ - return at::cudnn_convolution_transpose(input_t, weight_t, padding, output_padding, stride, dilation, groups, benchmark, deterministic, at::globalContext().allowTF32CuDNN()); -} - -}} // namespace at::native diff --git a/aten/src/ATen/native/cudnn/Conv_v8.cpp b/aten/src/ATen/native/cudnn/Conv_v8.cpp new file mode 100644 index 0000000000000..53f8c37f5e64e --- /dev/null +++ b/aten/src/ATen/native/cudnn/Conv_v8.cpp @@ -0,0 +1,5 @@ +#include // for the definition of AT_CUDNN_ENABLED + +#if AT_CUDNN_ENABLED() && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000 +// Coming soon +#endif // AT_CUDNN_ENABLED and CUDNN_VERSION diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index 5be7d6eea8ea2..8e1f254da9f80 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -19,7 +19,7 @@ namespace at { namespace native { Tensor _cudnn_rnn_flatten_weight( TensorList weight_arr, int64_t weight_stride0, int64_t input_size, - int64_t fn_mode, int64_t fn_hidden_size, + int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_proj_size, int64_t fn_num_layers, bool batch_first, bool fn_bidirectional ) { @@ -30,7 +30,7 @@ std::tuple _cudnn_rnn( const Tensor& input_r, TensorList weight, int64_t weight_stride0, const Tensor& weight_buf_r, const Tensor& hx, const Tensor& cx, - int64_t fn_mode, int64_t fn_hidden_size, + int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_proj_size, int64_t fn_num_layers, bool batch_first, double fn_dropout, bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes, const Tensor& fn_dropout_state @@ -42,7 +42,7 @@ std::tuple> _cudnn_rnn_backward( const Tensor& input, TensorList weight, int64_t weight_stride0, const Tensor& weight_buf, const Tensor& hx, const Tensor& cx, const Tensor& output, const Tensor& grad_output_r, const Tensor& grad_hy_r, const Tensor& grad_cy_r, - int64_t mode, int64_t hidden_size, + int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, IntArrayRef batch_sizes, const Tensor& dropout_state, const Tensor& reserve, @@ -92,6 +92,7 @@ namespace { struct RNNDescriptorParams { int64_t hidden_size; + int64_t proj_size; int64_t num_layers; cudnnDirectionMode_t bidirectional; cudnnRNNMode_t mode; @@ -135,19 +136,19 @@ namespace { this->algo = algo; } - void set(int64_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional, cudnnDataType_t datatype, cudnnDataType_t input_datatype) { + void set(int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool bidirectional, cudnnDataType_t datatype, cudnnDataType_t input_datatype) { this->set_mode(mode); this->hidden_size = hidden_size; + this->proj_size = proj_size; this->num_layers = num_layers; this->set_bidirectional(bidirectional); this->datatype = datatype; this->input_datatype = input_datatype; } - RNNDescriptor descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const { RNNDescriptor rnn_desc; - rnn_desc.set(handle, hidden_size, num_layers, std::move(dropout_desc), input_mode, bidirectional, mode, datatype, input_datatype, algo, at::globalContext().allowTF32CuDNN()); + rnn_desc.set(handle, hidden_size, proj_size, num_layers, std::move(dropout_desc), input_mode, bidirectional, mode, datatype, input_datatype, algo, at::globalContext().allowTF32CuDNN()); return rnn_desc; } @@ -359,7 +360,7 @@ namespace { size_t weight_size; AT_CUDNN_CHECK(cudnnGetRNNParamsSize(handle, rnn_desc.desc(), x_desc.desc(), &weight_size, datatype)); auto elem_size = dataSize(datatype); - AT_ASSERTM(weight_size % elem_size == 0, "cudnnGetRNNParamsSize returned nonsensical weight_size"); + TORCH_INTERNAL_ASSERT(weight_size % elem_size == 0, "cudnnGetRNNParamsSize returned nonsensical weight_size"); return weight_size / elem_size; } @@ -378,6 +379,58 @@ namespace { } } + void add_projection_weights( + cudnnHandle_t handle, + const RNNDescriptor& rnn_desc, + const TensorDescriptor& x_desc, + const FilterDescriptor& w_desc, + const Tensor& weight_buf, + int64_t layer, + std::vector& params + ) { + void* matrix_pointer = nullptr; + // assuming it's LSTM which has 8 "linear layers" (i.e. 4 weights and 4 biases) + int64_t linear_id = 8; + FilterDescriptor lin_layer_mat_desc; + AT_CUDNN_CHECK(cudnnGetRNNLinLayerMatrixParams( + /*handle=*/handle, + /*rnnDesc=*/rnn_desc.desc(), + /*layer=*/layer, + /*xDesc=*/x_desc.desc(), + /*wDesc=*/w_desc.desc(), + /*w=*/weight_buf.data_ptr(), + /*linLayerID=*/linear_id, + /*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(), + /*linLayerMat=*/&matrix_pointer)); + + cudnnDataType_t data_type; + cudnnTensorFormat_t format; + int nb_dims; + constexpr int min_dim = 3; + int filter_dim_a[min_dim]; + AT_CUDNN_CHECK(cudnnGetFilterNdDescriptor( + lin_layer_mat_desc.desc(), + min_dim, + &data_type, + &format, + &nb_dims, + filter_dim_a + )); + + TORCH_INTERNAL_ASSERT(nb_dims <= min_dim, "nb_dims = ", nb_dims, "; min_dim = ", min_dim); + auto elem_size = dataSize(getCudnnDataType(weight_buf)); + auto offset_bytes = (char*)matrix_pointer - (char*)weight_buf.data_ptr(); + TORCH_INTERNAL_ASSERT(offset_bytes % elem_size == 0, "offset_bytes = ", offset_bytes, "; elem_size = ", elem_size); + size_t offset = offset_bytes / elem_size; + + int mat_numel = prod_intlist(filter_dim_a, filter_dim_a + nb_dims); + // Generate a new parameter tensor which is a view into the weight_buf. + std::initializer_list size = {mat_numel, 1}; + Tensor param = at::empty({0}, weight_buf.options()).set_(weight_buf.storage(), offset, size); + params.emplace_back(std::move(param)); + } + + /* Returns weight and bias tensors for each layer of the RNN. These tensors are views on the underlying weight buffer allocated by CuDNN. @@ -433,24 +486,20 @@ namespace { cudnnTensorFormat_t format; int nb_dims; constexpr int min_dim = 3; - // TODO: The use of CPU tensor here is a bit goofy in C++, - // some sort of alloca would be good enough except that it is - // kind of convenient to be able to prod() on it. - Tensor filter_dim_a = at::empty(min_dim, at::initialTensorOptions().dtype(kInt)); + int filter_dim_a[min_dim]; AT_CUDNN_CHECK(cudnnGetFilterNdDescriptor( lin_layer_mat_desc.desc(), min_dim, &data_type, &format, &nb_dims, - filter_dim_a.data_ptr() + filter_dim_a )); - AT_ASSERTM(nb_dims <= min_dim, "nb_dims = ", nb_dims, "; min_dim = ", min_dim); - filter_dim_a = filter_dim_a.slice(0, 0, nb_dims); + TORCH_INTERNAL_ASSERT(nb_dims <= min_dim, "nb_dims = ", nb_dims, "; min_dim = ", min_dim); auto elem_size = dataSize(getCudnnDataType(weight_buf)); auto offset_bytes = (char*)matrix_pointer - (char*)weight_buf.data_ptr(); - AT_ASSERTM(offset_bytes % elem_size == 0, "offset_bytes = ", offset_bytes, "; elem_size = ", elem_size); + TORCH_INTERNAL_ASSERT(offset_bytes % elem_size == 0, "offset_bytes = ", offset_bytes, "; elem_size = ", elem_size); size_t offset = offset_bytes / elem_size; // for all the RNN types provided by CUDNN, all the ih weights @@ -458,7 +507,7 @@ namespace { // (same for the hh weights, and the ih and hh biases). // Since we're storing all the weights in a single tensor anyway, // might as well merge the CUDNN ones into a single tensor as well - int mat_numel = *filter_dim_a.prod(at::ScalarType::Int).data_ptr(); + int mat_numel = prod_intlist(filter_dim_a, filter_dim_a + nb_dims); if (linear_id == 0 || linear_id == num_linear_layers / 2) { // We could also exclude bias params by restricting cudnn_methods to just { cudnnGetRNNLinLayerMatrixParams } // at the very top. However, to do so would throw off the cur_offset account, which is currently a strict @@ -474,15 +523,20 @@ namespace { layer_params_count++; } } else { - AT_ASSERTM(cur_offset == offset, "cur_offset = ", cur_offset, "; offset = ", offset); + TORCH_INTERNAL_ASSERT(cur_offset == offset, "cur_offset = ", cur_offset, "; offset = ", offset); } cur_offset = offset + mat_numel; } } // for cudnn_method + if (rnn.proj_size != 0) { + add_projection_weights(handle, rnn_desc, x_desc, w_desc, weight_buf, layer, params); + layer_params_count++; + } + if (layer == 0) { global_layer_params_count = layer_params_count; } else { - AT_ASSERTM(global_layer_params_count == layer_params_count, + TORCH_INTERNAL_ASSERT(global_layer_params_count == layer_params_count, "global_layer_params_count = ", global_layer_params_count, "; layer_params_count = ", layer_params_count); } @@ -502,7 +556,11 @@ namespace { int64_t num_dir_layers = rnn.num_directions() * rnn.num_layers; const auto cudnn_methods = { cudnnGetRNNLinLayerMatrixParams, cudnnGetRNNLinLayerBiasParams }; std::vector data_ptrs; - data_ptrs.reserve(num_dir_layers * 2 * 2); + if (rnn.proj_size != 0) { + data_ptrs.reserve(num_dir_layers * (2 * 2 + 1)); + } else { + data_ptrs.reserve(num_dir_layers * 2 * 2); + } for (int64_t layer = 0; layer < num_dir_layers; layer++) { for (auto cudnn_method : cudnn_methods) { // This API returns a separate pointer for weight of every gate, @@ -526,34 +584,73 @@ namespace { data_ptrs.push_back(matrix_pointer); } } + if (rnn.proj_size != 0) { + // assuming it's LSTM which has 8 "linear layers" (i.e. 4 weights and 4 biases) + int64_t linear_id = 8; + FilterDescriptor lin_layer_mat_desc; + void* matrix_pointer; + AT_CUDNN_CHECK(cudnnGetRNNLinLayerMatrixParams( + handle, + rnn_desc.desc(), + layer, + x_desc.desc(), + w_desc.desc(), + weight_buf.data_ptr(), + linear_id, + lin_layer_mat_desc.mut_desc(), + &matrix_pointer + )); + data_ptrs.push_back(matrix_pointer); + } } return data_ptrs; } + void _viewOrCopyOneParam(const Tensor& param_from, const Tensor& param_to, + bool copy, bool allow_type_change=false) { + // if copying, allow_type_change may be true or false. + // if viewing, allow_type_change must be false. + TORCH_INTERNAL_ASSERT(copy || !allow_type_change, + "if viewing, type change is not allowed."); + TORCH_INTERNAL_ASSERT(allow_type_change || (param_from.scalar_type() == param_to.scalar_type()), + "parameter types mismatch"); + if (copy) { + param_to.copy_(param_from.view_as(param_to)); + } else { + param_from.resize_as_(param_to); + } + } + void _viewOrCopyParams(MatrixRef params_from, MatrixRef params_to, bool copy, bool allow_type_change=false) { - AT_ASSERTM(params_from.size(0) == params_to.size(0), "number of layers mismatch"); + TORCH_INTERNAL_ASSERT(params_from.size(0) == params_to.size(0), "number of layers mismatch"); for (size_t i = 0; i < params_from.size(0); i++) { auto layer_params_from = params_from[i]; auto layer_params_to = params_to[i]; // NOTE: these lists have all weights before all biases, so if the layer // doesn't use biases, iteration will terminate once layer_params_from ends // and ignore them. + + // NOTE: there is an exception from the above statement. If LSTMs with projections + // are used, weights layout will be w_ih, w_hh, b_ih, b_hh, w_hr. So need to handle no-bias + // case specially, because will need to copy 0->0, 1->1, 2->4. This case can be uniquely + // identified by checking if number of defined parameters for each layer is 3. + if (layer_params_from.size() == 3 && layer_params_to.size() != 3) { + _viewOrCopyOneParam(layer_params_from[0], layer_params_to[0], copy, allow_type_change); + _viewOrCopyOneParam(layer_params_from[1], layer_params_to[1], copy, allow_type_change); + _viewOrCopyOneParam(layer_params_from[2], layer_params_to[4], copy, allow_type_change); + continue; + } + if (layer_params_to.size() == 3 && layer_params_from.size() != 3) { + _viewOrCopyOneParam(layer_params_from[0], layer_params_to[0], copy, allow_type_change); + _viewOrCopyOneParam(layer_params_from[1], layer_params_to[1], copy, allow_type_change); + _viewOrCopyOneParam(layer_params_from[4], layer_params_to[2], copy, allow_type_change); + continue; + } for (auto a = layer_params_from.begin(), b = layer_params_to.begin(); - a != layer_params_from.end() && b != layer_params_to.end(); - ++a, ++b) { - auto param_from = *a, param_to = *b; - // if copying, allow_type_change may be true or false. - // if viewing, allow_type_change must be false. - TORCH_INTERNAL_ASSERT(copy || !allow_type_change, - "if viewing, type change is not allowed."); - TORCH_INTERNAL_ASSERT(allow_type_change || (param_from.scalar_type() == param_to.scalar_type()), - "parameter types mismatch"); - if (copy) { - param_to.copy_(param_from.view_as(param_to)); - } else { - param_from.resize_as_(param_to); - } + a != layer_params_from.end() && b != layer_params_to.end(); + ++a, ++b) { + _viewOrCopyOneParam(*a, *b, copy, allow_type_change); } } } @@ -576,36 +673,93 @@ namespace { } std::vector _hidden_size(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) { + if (rnn.proj_size != 0) { + return {rnn.num_layers * rnn.num_directions(), tensors.mini_batch, rnn.proj_size}; + } else { + return {rnn.num_layers * rnn.num_directions(), tensors.mini_batch, rnn.hidden_size}; + } + } + + std::vector _cell_size(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) { return {rnn.num_layers * rnn.num_directions(), tensors.mini_batch, rnn.hidden_size}; } std::vector _output_size(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) { + auto out_size = rnn.hidden_size; + if (rnn.proj_size != 0) { + out_size = rnn.proj_size; + } if (tensors.is_input_packed()) { - return {tensors.batch_sizes_sum, rnn.hidden_size * rnn.num_directions()}; + return {tensors.batch_sizes_sum, out_size * rnn.num_directions()}; } else { - return {tensors.seq_length, tensors.mini_batch, rnn.hidden_size * rnn.num_directions()}; + return {tensors.seq_length, tensors.mini_batch, out_size * rnn.num_directions()}; } } - cudnnRNNAlgo_t get_algo(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors, const Tensor input){ - cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); - const int64_t bsize = tensors.mini_batch; - //excluding Turing from using persistent rnn. - if (prop->major == 7 && prop->minor != 5 && getCudnnDataType(input) == CUDNN_DATA_HALF && !tensors.is_input_packed()) { - if (rnn.num_layers == 1 && rnn.hidden_size <= 1024 && rnn.num_directions() == 1 && - rnn.hidden_size % 128 == 0 && tensors.input_size % 128 == 0){ - //technically, batch size should be multiple of 8, but there are quite a few multiple-of-8 batchsizes that give bad perf, - //weed them out - if ((bsize % 16 == 0 && bsize != 80 && bsize !=112) || bsize == 8){ - if ((tensors.seq_length >=40 && bsize <=128) || - (tensors.seq_length >=20 && bsize <=96) || - (tensors.seq_length >=10 && bsize <=32)) { - return CUDNN_RNN_ALGO_PERSIST_STATIC; - } - } - } + inline bool use_persist_common_heuristics(const RNNDescriptorParams& rnn, + const TensorDescriptorListParams& tensors) { + return rnn.num_layers == 1 && + rnn.hidden_size <= 1024 && + rnn.num_directions() == 1 && + rnn.hidden_size % 128 == 0 && + tensors.input_size % 128 == 0; + } + + inline bool use_persist_device_heuristics(const RNNDescriptorParams& rnn, + const TensorDescriptorListParams& tensors) { + auto bsize = tensors.mini_batch; + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + if (prop->major == 7) { + if (prop->minor == 5) { + // Excludes Turing from using persistent rnn. + return false; + } else { + // technically, batch size should be multiple of 8, but there are quite a few multiple-of-8 batchsizes that give bad perf, + // weed them out + return ((bsize % 16 == 0 && bsize != 80 && bsize !=112) || bsize == 8) && + ((tensors.seq_length >=40 && bsize <=128) || + (tensors.seq_length >=20 && bsize <=96) || + (tensors.seq_length >=10 && bsize <=32)); + } + } else if (prop->major >= 8) { + if (prop->minor == 6) { + // Excludes sm_86 GPU devices from using persistent rnn. + // This is because there are some edge cases that will throw exceptions with cudnn 8.0.5 on Nvidia A40 GPU. + return false; } + // Based on tests by Vasily Volkov and xwang233. Vasily only tried bsize <= 128, + // so conservatively enable persistence for bsize <= 128 only. + // TODO: Run more tests for bsize > 128. + if (rnn.mode == CUDNN_GRU) { + // Persistent GRU performance is flakier than other RNN types. Exclude them for now. + // TODO: Write a more refined GRU heuristic. + return false; + } else if (rnn.mode == CUDNN_LSTM) { + // Persistent LSTMs are comparable to or better than non-persistent for bsize <= 128. + return (bsize % 8 == 0) && (bsize <= 128); + } else { + // Persistent RNN_RELU and TANH show poor performance when bsize >= 96 AND hidden size >= 896. + return (bsize % 8 == 0) && (bsize <= 128) && (bsize < 96 || rnn.hidden_size < 896); + } + } else { + return false; + } + } + + cudnnRNNAlgo_t get_algo(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors, const Tensor input) { + // LSTM with projections only works with standard algorithm + if (rnn.proj_size != 0) { return CUDNN_RNN_ALGO_STANDARD; + } + + if (getCudnnDataType(input) == CUDNN_DATA_HALF && + !tensors.is_input_packed()) { + if (use_persist_common_heuristics(rnn, tensors) && + use_persist_device_heuristics(rnn, tensors)) { + return CUDNN_RNN_ALGO_PERSIST_STATIC; + } + } + return CUDNN_RNN_ALGO_STANDARD; } cudnnDataType_t promote_rnn_math_type(cudnnDataType_t dtype) { @@ -626,6 +780,7 @@ namespace cudnn_rnn { int64_t input_size, int64_t mode, int64_t hidden_size, + int64_t proj_size, int64_t num_layers, bool batch_first, bool bidirectional, @@ -638,12 +793,11 @@ namespace cudnn_rnn { // because to extract flat_buf_datatype from flat_buf_options, we'd need to say // auto flat_buf_datatype = getCudnnDataTypeFromScalarType(typeMetaToScalarType(options.dtype())); // typeMetaToScalarType is a surprisingly nontrivial function. We should avoid it if we can. - TORCH_CHECK(weight_arr.size() > 0, "copy_weights_to_flat_buf_views: cannot flatten empty weight list"); RNNDescriptorParams rnn; - rnn.set(mode, hidden_size, num_layers, bidirectional, promote_rnn_math_type(flat_buf_datatype), flat_buf_datatype); + rnn.set(mode, hidden_size, proj_size, num_layers, bidirectional, promote_rnn_math_type(flat_buf_datatype), flat_buf_datatype); auto handle = getCudnnHandle(); RNNDescriptor rnn_desc = rnn.descriptor(handle); @@ -665,21 +819,27 @@ namespace cudnn_rnn { std::vector params_arr; size_t params_stride0; std::tie(params_arr, params_stride0) = get_parameters(handle, rnn, rnn_desc, x_desc, w_desc, weight_buf, include_bias); - MatrixRef weight{weight_arr, static_cast(weight_stride0)}, params{params_arr, params_stride0}; // Copy weights _viewOrCopyParams(weight, params, /*copy=*/true, allow_type_change); - if (set_orig_weights_to_flat_buf) { // Update the storage for (size_t i = 0; i < weight.size(0); i++) { - for (auto orig_param_it = weight[i].begin(), new_param_it = params[i].begin(); - orig_param_it != weight[i].end() && new_param_it != params[i].end(); - orig_param_it++, new_param_it++) { - auto orig_param = *orig_param_it, new_param = *new_param_it; - orig_param.set_(new_param.view_as(orig_param)); + // There is a special case for LSTM with projections and no bias, + // where weight copy is done in 0->0, 1->1, 2->4 layout + if (weight[i].size() == 3 && params[i].size() == 5) { + weight[i][0].set_(params[i][0].view_as(weight[i][0])); + weight[i][1].set_(params[i][1].view_as(weight[i][1])); + weight[i][2].set_(params[i][4].view_as(weight[i][2])); + } else { + for (auto orig_param_it = weight[i].begin(), new_param_it = params[i].begin(); + orig_param_it != weight[i].end() && new_param_it != params[i].end(); + orig_param_it++, new_param_it++) { + auto orig_param = *orig_param_it, new_param = *new_param_it; + orig_param.set_(new_param.view_as(orig_param)); + } } } } @@ -698,7 +858,7 @@ using namespace cudnn_rnn; Tensor _cudnn_rnn_flatten_weight( TensorList weight_arr, int64_t weight_stride0, int64_t input_size, - int64_t fn_mode, int64_t fn_hidden_size, + int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_proj_size, int64_t fn_num_layers, bool batch_first, bool fn_bidirectional ) { @@ -709,6 +869,7 @@ Tensor _cudnn_rnn_flatten_weight( input_size, fn_mode, fn_hidden_size, + fn_proj_size, fn_num_layers, batch_first, fn_bidirectional, @@ -727,12 +888,11 @@ std::tuple _cudnn_rnn( const Tensor& input_r, TensorList weight, int64_t weight_stride0, const Tensor& weight_buf_r, const Tensor& hx, const Tensor& cx, - int64_t fn_mode, int64_t fn_hidden_size, + int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_proj_size, int64_t fn_num_layers, bool batch_first, double fn_dropout, bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes, const Tensor& fn_dropout_state ) { - check_attributes(input_r, weight, {hx, cx}, /*check_dtype=*/true); auto input = input_r; auto weight_buf = weight_buf_r; @@ -746,7 +906,7 @@ std::tuple _cudnn_rnn( } RNNParams fn; auto datatype = getCudnnDataType(input); - fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype); + fn.rnn.set(fn_mode, fn_hidden_size, fn_proj_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype); fn.dropout.set(fn_train, fn_dropout, fn_dropout_state); fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first); @@ -764,6 +924,7 @@ std::tuple _cudnn_rnn( } auto hidden_size = _hidden_size(fn.rnn, fn.tensors); + auto cell_size = _cell_size(fn.rnn, fn.tensors); auto output_size = _output_size(fn.rnn, fn.tensors); TORCH_CHECK(hx.is_contiguous(), @@ -776,7 +937,7 @@ std::tuple _cudnn_rnn( auto hy = at::empty(hidden_size, hx.options()); Tensor cy; if (cx.defined()) { - cy = at::empty(hidden_size, cx.options()); + cy = at::empty(cell_size, cx.options()); } else { cy = at::empty({0}, hx.options()); // NB: Not allowed to return undefined tensors } @@ -802,9 +963,8 @@ std::tuple _cudnn_rnn( w_desc.set(weight_buf, 3); } - TORCH_CHECK(!cx.defined() || cx.sizes().equals(hidden_size), - "Expected cell size ", IntArrayRef{hidden_size}, ", got ", cx.sizes()); - + TORCH_CHECK(!cx.defined() || cx.sizes().equals(cell_size), + "Expected cell size ", IntArrayRef{cell_size}, ", got ", cx.sizes()); size_t workspace_size; auto x_descs_arr = descs.get_x_descs(); auto y_descs_arr = descs.get_y_descs(); @@ -816,7 +976,6 @@ std::tuple _cudnn_rnn( &workspace_size )); Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte)); - Tensor reserve; // NB: Previously, the test was for fn.requires_grad, but we don't have // this information. Use 'train' as a proxy. @@ -873,7 +1032,7 @@ std::tuple _cudnn_rnn_backward_input( const Tensor& input_r, const Tensor& weight_buf, const Tensor& hx, const Tensor& cx, const Tensor& output_r, const Tensor& grad_output_r, const Tensor& grad_hy, const Tensor& grad_cy, - int64_t fn_mode, int64_t fn_hidden_size, + int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_proj_size, int64_t fn_num_layers, bool batch_first, double fn_dropout, bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes, const Tensor& fn_dropout_state, const Tensor& fn_reserve, @@ -886,7 +1045,7 @@ std::tuple _cudnn_rnn_backward_input( RNNParams fn; auto datatype = getCudnnDataType(input); - fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype); + fn.rnn.set(fn_mode, fn_hidden_size, fn_proj_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype); fn.dropout.set(fn_train, fn_dropout, fn_dropout_state); fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first); @@ -907,6 +1066,7 @@ std::tuple _cudnn_rnn_backward_input( auto input_size = _input_size(fn.tensors); auto hidden_size = _hidden_size(fn.rnn, fn.tensors); + auto cell_size = _cell_size(fn.rnn, fn.tensors); auto output_size = _output_size(fn.rnn, fn.tensors); TORCH_CHECK(hx.is_contiguous(), @@ -920,10 +1080,10 @@ std::tuple _cudnn_rnn_backward_input( auto w = weight_buf; auto dx = at::empty(input.sizes(), input.options()); // TODO: more compact way of saying this auto dhy = grad_hy.contiguous().view(hidden_size); - auto dcy = grad_cy.defined() ? grad_cy.contiguous().view(hidden_size) : Tensor(); + auto dcy = grad_cy.defined() ? grad_cy.contiguous().view(cell_size) : Tensor(); auto dhx = at::empty(hidden_size, hx.options()); - AT_ASSERTM(cx.defined() || !output_mask[2], "illegally required grad of cx for non-LSTM RNN"); - auto dcx = cx.defined() ? at::empty(hidden_size, cx.options()) : Tensor(); + TORCH_INTERNAL_ASSERT(cx.defined() || !output_mask[2], "illegally required grad of cx for non-LSTM RNN"); + auto dcx = cx.defined() ? at::empty(cell_size, cx.options()) : Tensor(); TORCH_CHECK(fn_train, "cudnn RNN backward can only be called in training mode"); @@ -935,12 +1095,12 @@ std::tuple _cudnn_rnn_backward_input( TORCH_CHECK(!hx.defined() || hx.sizes().equals(hidden_size), "Expected hidden size ", IntArrayRef{hidden_size}, ", got ", hx.sizes()); - TORCH_CHECK(!cx.defined() || cx.sizes().equals(hidden_size), - "Expected cell size ", IntArrayRef{hidden_size}, ", got ", cx.sizes()); + TORCH_CHECK(!cx.defined() || cx.sizes().equals(cell_size), + "Expected cell size ", IntArrayRef{cell_size}, ", got ", cx.sizes()); TORCH_CHECK(!dhy.defined() || dhy.sizes().equals(hidden_size), "Expected d_hidden size ", IntArrayRef{hidden_size}, ", got ", dhy.sizes()); - TORCH_CHECK(!dcy.defined() || dcy.sizes().equals(hidden_size), - "Expected d_cell size ", IntArrayRef{hidden_size}, ", got ", dcy.sizes()); + TORCH_CHECK(!dcy.defined() || dcy.sizes().equals(cell_size), + "Expected d_cell size ", IntArrayRef{cell_size}, ", got ", dcy.sizes()); TORCH_CHECK(dhy.is_cuda() && dy.is_cuda() && (!dcy.defined() || dcy.is_cuda()), "Gradients aren't CUDA tensors"); @@ -996,20 +1156,19 @@ std::vector _cudnn_rnn_backward_weight( const Tensor& input_r, TensorList weight_arr, int64_t weight_stride0, const Tensor& weight_buf, const Tensor& hx, const Tensor& cx, const Tensor& output_r, - int64_t fn_mode, int64_t fn_hidden_size, + int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_proj_size, int64_t fn_num_layers, bool batch_first, double fn_dropout, bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes, const Tensor& fn_dropout_state, const Tensor& fn_reserve ) { MatrixRef weight{ weight_arr, static_cast(weight_stride0) }; - auto input = input_r; auto output = output_r; RNNParams fn; auto datatype = getCudnnDataType(input); - fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype); + fn.rnn.set(fn_mode, fn_hidden_size, fn_proj_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype); fn.dropout.set(fn_train, fn_dropout, fn_dropout_state); fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first); @@ -1105,7 +1264,7 @@ std::tuple> _cudnn_rnn_backward( const Tensor& input, TensorList weight, int64_t weight_stride0, const Tensor& weight_buf, const Tensor& hx, const Tensor& cx, const Tensor& output, const Tensor& grad_output_r, const Tensor& grad_hy_r, const Tensor& grad_cy_r, - int64_t mode, int64_t hidden_size, + int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, IntArrayRef batch_sizes, const Tensor& dropout_state, const Tensor& reserve, @@ -1121,10 +1280,10 @@ std::tuple> _cudnn_rnn_backward( Tensor dx, dhx, dcx; // NB: unconditionally compute this gradient, because it mutates reserve - std::tie(dx, dhx, dcx) = at::native::_cudnn_rnn_backward_input(input, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, {output_mask[0], output_mask[1], output_mask[2]}); + std::tie(dx, dhx, dcx) = at::native::_cudnn_rnn_backward_input(input, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, {output_mask[0], output_mask[1], output_mask[2]}); std::vector dw; if (output_mask[3]) { - dw = at::native::_cudnn_rnn_backward_weight(input, weight, weight_stride0, weight_buf, hx, cx, output, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve); + dw = at::native::_cudnn_rnn_backward_weight(input, weight, weight_stride0, weight_buf, hx, cx, output, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve); } return std::tuple>{dx, dhx, dcx, dw}; } @@ -1231,7 +1390,7 @@ DropoutState& get_dropout_state(double dropout_p, bool train, TensorOptions opti Tensor try_get_weight_buf( const Tensor& input, TensorList parameters, bool has_biases, - cudnnRNNMode_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional) { + cudnnRNNMode_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool bidirectional) { // Prepare all relevant descriptors auto handle = getCudnnHandle(); @@ -1239,7 +1398,7 @@ Tensor try_get_weight_buf( auto datatype = getCudnnDataType(any_param); RNNDescriptorParams rnn; - rnn.set(mode, hidden_size, num_layers, bidirectional, promote_rnn_math_type(datatype), datatype); + rnn.set(mode, hidden_size, proj_size, num_layers, bidirectional, promote_rnn_math_type(datatype), datatype); RNNDescriptor rnn_desc = rnn.descriptor(handle); TensorGeometry x_geom ({1, input.size(-1)}); @@ -1266,13 +1425,34 @@ Tensor try_get_weight_buf( int64_t num_parameters = parameters.size(); int64_t num_ptrs = expected_data_ptrs.size(); - AT_ASSERT(num_ptrs == (num_parameters * (has_biases ? 1 : 2))); - AT_ASSERT(num_ptrs % (has_biases ? 4 : 2) == 0); - for (int64_t param_i = 0, ptr_i = 0; - ptr_i < num_ptrs; - ptr_i += (has_biases ? 2 : 4), param_i += 2) { - if (expected_data_ptrs[ptr_i] != parameters[param_i].data_ptr()) return {}; - if (expected_data_ptrs[ptr_i + 1] != parameters[param_i + 1].data_ptr()) return {}; + if (proj_size != 0) { + AT_ASSERT(num_parameters % (has_biases ? 5 : 3) == 0); + AT_ASSERT(num_ptrs % 5 == 0); + if (has_biases) { + AT_ASSERT(num_ptrs == num_parameters); + for (int64_t i = 0; i < num_parameters; i++) { + if (expected_data_ptrs[i] != parameters[i].data_ptr()) return {}; + } + } else { + AT_ASSERT(num_parameters % 3 == 0); + AT_ASSERT(num_ptrs == num_parameters * 5 / 3); + for (int64_t param_i = 0, ptr_i = 0; + ptr_i < num_ptrs; + ptr_i += 5, param_i += 3) { + if (expected_data_ptrs[ptr_i] != parameters[param_i].data_ptr()) return {}; + if (expected_data_ptrs[ptr_i + 1] != parameters[param_i + 1].data_ptr()) return {}; + if (expected_data_ptrs[ptr_i + 4] != parameters[param_i + 2].data_ptr()) return {}; + } + } + } else { + AT_ASSERT(num_ptrs == (num_parameters * (has_biases ? 1 : 2))); + AT_ASSERT(num_parameters % (has_biases ? 4 : 2) == 0); + for (int64_t param_i = 0, ptr_i = 0; + ptr_i < num_ptrs; + ptr_i += (has_biases ? 2 : 4), param_i += 2) { + if (expected_data_ptrs[ptr_i] != parameters[param_i].data_ptr()) return {}; + if (expected_data_ptrs[ptr_i + 1] != parameters[param_i + 1].data_ptr()) return {}; + } } if (!parameters[num_parameters - 1].is_contiguous()) return {}; return weight_buf; @@ -1286,22 +1466,32 @@ std::pair _cudnn_impl( Tensor hx, cx; std::tie(hx, cx) = unpack_hidden(hidden); int64_t hidden_size = hx.size(2); + int64_t proj_size = 0; + // For LSTM models with projections hidden size could be different + if (cx.defined() && cx.size(2) != hx.size(2)) { + hidden_size = cx.size(2); + proj_size = hx.size(2); + } // TODO: try_get_weight_buf returns a Tensor, but _cudnn_rnn below takes a c10::optional // in weight_buf's slot. Do we want try_get_weight_buf to return a c10::optional // instead of a defined or undefined Tensor? auto weight_buf = try_get_weight_buf( - input, params, has_biases, mode, hidden_size, num_layers, bidirectional); + input, params, has_biases, mode, hidden_size, proj_size, num_layers, bidirectional); TORCH_CHECK(_batch_sizes.dim() == 1, "batch_sizes tensor should be 1D"); IntArrayRef batch_sizes { _batch_sizes.data_ptr(), static_cast(_batch_sizes.size(0)) }; auto & dropout_state = get_dropout_state(dropout_p, train, input.options()); std::unique_lock lock { dropout_state }; + int64_t num_params = has_biases ? 4 : 2; + if (proj_size != 0) { + ++num_params; + } // cudnn_output = std::tuple auto cudnn_output = at::_cudnn_rnn( - input, params, has_biases ? 4 : 2, weight_buf, - hx, cx, static_cast(mode), hidden_size, num_layers, /*batch_first=*/false, + input, params, num_params, weight_buf, + hx, cx, static_cast(mode), hidden_size, proj_size, num_layers, /*batch_first=*/false, dropout_p, train, bidirectional, batch_sizes, dropout_state.buffer); return {std::get<0>(cudnn_output), @@ -1316,16 +1506,24 @@ std::pair _cudnn_impl( Tensor hx, cx; std::tie(hx, cx) = unpack_hidden(hidden); int64_t hidden_size = hx.size(2); - + int64_t proj_size = 0; + // For LSTM models with projections hidden size could be different + if (cx.defined() && cx.size(2) != hx.size(2)) { + hidden_size = cx.size(2); + proj_size = hx.size(2); + } auto weight_buf = try_get_weight_buf( - input, params, has_biases, mode, hidden_size, num_layers, bidirectional); - + input, params, has_biases, mode, hidden_size, proj_size, num_layers, bidirectional); auto & dropout_state = get_dropout_state(dropout_p, train, input.options()); std::unique_lock lock { dropout_state }; + int64_t num_params = has_biases ? 4 : 2; + if (proj_size != 0) { + ++num_params; + } // cudnn_output = std::tuple auto cudnn_output = at::_cudnn_rnn( - input, params, has_biases ? 4 : 2, weight_buf, - hx, cx, static_cast(mode), hidden_size, num_layers, batch_first, dropout_p, + input, params, num_params, weight_buf, + hx, cx, static_cast(mode), hidden_size, proj_size, num_layers, batch_first, dropout_p, train, bidirectional, /*batch_sizes=*/{}, dropout_state.buffer); return {std::get<0>(cudnn_output), diff --git a/aten/src/ATen/native/cudnn/RNNUtils.h b/aten/src/ATen/native/cudnn/RNNUtils.h index 89b58ebef1d8a..e1b79bb3c81fa 100644 --- a/aten/src/ATen/native/cudnn/RNNUtils.h +++ b/aten/src/ATen/native/cudnn/RNNUtils.h @@ -14,6 +14,7 @@ TORCH_CUDA_API std::tuple> copy_weights_to_flat_buf_ int64_t input_size, int64_t mode, int64_t hidden_size, + int64_t proj_size, int64_t num_layers, bool batch_first, bool bidirectional, diff --git a/aten/src/ATen/native/group_norm.cpp b/aten/src/ATen/native/group_norm.cpp index beb4d940363e3..51ec4cc507a71 100644 --- a/aten/src/ATen/native/group_norm.cpp +++ b/aten/src/ATen/native/group_norm.cpp @@ -29,18 +29,7 @@ std::tuple native_group_norm( Tensor mean = at::empty({N, group}, X.options()); Tensor rstd = at::empty({N, group}, X.options()); GroupNormKernel( - X.device().type(), - X, - gamma, - beta, - N, - C, - HxW, - group, - eps, - &Y, - &mean, - &rstd); + X.device().type(), X, gamma, beta, N, C, HxW, group, eps, Y, mean, rstd); return std::make_tuple(Y, mean, rstd); } @@ -78,9 +67,9 @@ std::tuple native_group_norm_backward( C, HxW, group, - &dX, - &dgamma, - &dbeta); + dX, + dgamma, + dbeta); return std::make_tuple(dX, dgamma, dbeta); } @@ -117,15 +106,15 @@ Tensor group_norm( input.sizes()); const auto input_shape = input.sizes(); - const int64_t HxW = std::accumulate( - input_shape.cbegin() + 2, - input_shape.cend(), - 1LL, - std::multiplies()); + const int64_t HxW = + prod_intlist(input_shape.cbegin() + 2, input_shape.cend()); + const Tensor kEmpty; const auto& X = input.is_contiguous() ? input : input.contiguous(); - const auto& gamma = weight.is_contiguous() ? weight : weight.contiguous(); - const auto& beta = bias.is_contiguous() ? bias : bias.contiguous(); + const auto& gamma = weight.defined() ? weight.contiguous() : kEmpty; + const auto& beta = bias.defined() ? bias.contiguous() : kEmpty; + TORCH_CHECK(!gamma.defined() || gamma.numel() == C); + TORCH_CHECK(!beta.defined() || beta.numel() == C); return std::get<0>( at::native_group_norm(X, gamma, beta, N, C, HxW, num_groups, eps)); } @@ -133,21 +122,34 @@ Tensor group_norm( DEFINE_DISPATCH(GroupNormKernel); DEFINE_DISPATCH(GroupNormBackwardKernel); +// Ported from pytorch/xla repo std::tuple math_group_norm( - const at::Tensor& input, const at::Tensor& weight, - const at::Tensor& bias, int64_t N, int64_t C, int64_t HxW, - int64_t group, double eps) { + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& bias, + int64_t N, + int64_t C, + int64_t HxW, + int64_t group, + double eps) { auto input_shape = input.sizes(); at::Tensor input_reshaped = input.view({1, N * group, N ? -1 : 1}); auto outputs = at::native_batch_norm( - input_reshaped, /*weight=*/{}, /*bias=*/{}, /*running_mean=*/{}, - /*running_var=*/{}, /*training=*/true, /*momentum=*/0, eps); + input_reshaped, + /*weight=*/{}, + /*bias=*/{}, + /*running_mean=*/{}, + /*running_var=*/{}, + /*training=*/true, + /*momentum=*/0, + eps); at::Tensor out = std::get<0>(outputs); out = out.view(input_shape); std::vector affine_param_shape(input.dim(), 1); affine_param_shape[1] = C; if (weight.defined() && bias.defined()) { - out = bias.view(affine_param_shape).addcmul(out, weight.view(affine_param_shape), 1); + out = bias.view(affine_param_shape) + .addcmul(out, weight.view(affine_param_shape), 1); } else if (weight.defined()) { out = out.mul(weight.view(affine_param_shape)); } else if (bias.defined()) { diff --git a/aten/src/ATen/native/group_norm.h b/aten/src/ATen/native/group_norm.h index 73e91541003ac..58fc0867b1ac6 100644 --- a/aten/src/ATen/native/group_norm.h +++ b/aten/src/ATen/native/group_norm.h @@ -15,9 +15,9 @@ using forward_fn = void (*)( int64_t /* HxW */, int64_t /* group */, double /* eps */, - Tensor* /* Y */, - Tensor* /* mean */, - Tensor* /* rstd */); + Tensor& /* Y */, + Tensor& /* mean */, + Tensor& /* rstd */); using backward_fn = void (*)( const Tensor& /* dY */, @@ -29,9 +29,9 @@ using backward_fn = void (*)( int64_t /* C */, int64_t /* HxW */, int64_t /* group */, - Tensor* /* dX */, - Tensor* /* dgamma */, - Tensor* /* dbeta */); + Tensor& /* dX */, + Tensor& /* dgamma */, + Tensor& /* dbeta */); DECLARE_DISPATCH(forward_fn, GroupNormKernel); DECLARE_DISPATCH(backward_fn, GroupNormBackwardKernel); diff --git a/aten/src/ATen/native/im2col_shape_check.h b/aten/src/ATen/native/im2col_shape_check.h index a2c57b079eba1..78725bb81ff7e 100644 --- a/aten/src/ATen/native/im2col_shape_check.h +++ b/aten/src/ATen/native/im2col_shape_check.h @@ -37,9 +37,11 @@ static inline void col2im_shape_check( dilation_width); int64_t ndim = input.ndimension(); + // allow dim=0 only the batch dimension. TORCH_CHECK( - input.numel() != 0 && (ndim == 2 || ndim == 3), - "Expected non-empty 2D or 3D input tensor, but got input of sizes", + (ndim == 2 && input.size(0) != 0 && input.size(1) != 0) || + (ndim == 3 && input.size(1) != 0 && input.size(2) != 0), + "Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non-zero dimensions for input, but got: ", input.sizes()); int64_t batch_dim = (ndim == 3) ? 0 : -1; @@ -155,9 +157,12 @@ static inline void im2col_shape_check( int64_t ndim = input.ndimension(); + // allow dim=0 only the batch dimension. + bool valid_dims = input.size(1) != 0 && input.size(2) != 0; TORCH_CHECK( - input.numel() != 0 && (ndim == 3 || ndim == 4), - "Expected non-empty 3D or 4D input tensor, but got input of size ", + (ndim == 3 && input.size(0) && valid_dims) || + (ndim == 4 && valid_dims && input.size(3) != 0), + "Expected 3D or 4D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ", input.sizes()); int64_t dim_batch = 0; diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp index 1f311d6fcdba8..a639fe41d534a 100644 --- a/aten/src/ATen/native/layer_norm.cpp +++ b/aten/src/ATen/native/layer_norm.cpp @@ -18,47 +18,76 @@ namespace at { namespace native { std::tuple layer_norm_cpu( - const Tensor& X, - const Tensor& gamma /* optional */, - const Tensor& beta /* optional */, - int64_t M, - int64_t N, + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& weight /* optional */, + const Tensor& bias /* optional */, double eps) { + + auto inputs = _prepare_layer_norm_inputs(input, normalized_shape, weight, bias); + auto X = std::get<0>(inputs); + auto gamma = std::get<1>(inputs); + auto beta = std::get<2>(inputs); + auto M = std::get<3>(inputs); + auto N = std::get<4>(inputs); + Tensor Y = at::native::empty_like(X, at::MemoryFormat::Contiguous); Tensor mean = at::empty({M}, X.options()); Tensor rstd = at::empty({M}, X.options()); if (M > 0) { LayerNormKernel(kCPU, X, gamma, beta, M, N, eps, &Y, &mean, &rstd); + + const auto input_shape = input.sizes(); + const size_t axis = input.dim() - normalized_shape.size(); + + std::vector stat_shape; + for (size_t idx = 0; idx < axis; ++idx) { + stat_shape.push_back(input_shape[idx]); + } + for (size_t idx = axis; idx < input.dim(); ++idx) { + stat_shape.push_back(1); + } + + mean = mean.view(stat_shape); + rstd = rstd.view(stat_shape); } return std::make_tuple(std::move(Y), std::move(mean), std::move(rstd)); } std::tuple layer_norm_backward_cpu( const Tensor& dY, - const Tensor& X, + const Tensor& input, + IntArrayRef normalized_shape, const Tensor& mean, const Tensor& rstd, - const Tensor& gamma, - int64_t M, - int64_t N, + const Tensor& weight /* optional */, + const Tensor& bias /* optional */, std::array grad_input_mask) { - Tensor dX; - Tensor dgamma; - Tensor dbeta; - if (grad_input_mask[0]) { - dX = at::native::empty_like(X, at::MemoryFormat::Contiguous); - } - if (grad_input_mask[1]) { - dgamma = M > 0 ? at::native::empty_like(gamma, at::MemoryFormat::Contiguous) : at::native::zeros_like(gamma, at::MemoryFormat::Contiguous); - } - if (grad_input_mask[2]) { - dbeta = M > 0 ? at::native::empty_like(gamma, at::MemoryFormat::Contiguous) : at::native::zeros_like(gamma, at::MemoryFormat::Contiguous); - } - if (M > 0) { - LayerNormBackwardKernel( - kCPU, dY, X, mean, rstd, gamma, M, N, &dX, &dgamma, &dbeta); - } - return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta)); + + auto inputs = _prepare_layer_norm_inputs(input, normalized_shape, weight, bias); + auto X = std::get<0>(inputs); + auto gamma = std::get<1>(inputs); + auto beta = std::get<2>(inputs); + auto M = std::get<3>(inputs); + auto N = std::get<4>(inputs); + + Tensor dX; + Tensor dgamma; + Tensor dbeta; + if (grad_input_mask[0]) { + dX = at::native::empty_like(X, at::MemoryFormat::Contiguous); + } + if (grad_input_mask[1]) { + dgamma = M > 0 ? at::native::empty_like(gamma, at::MemoryFormat::Contiguous) : at::native::zeros_like(gamma, at::MemoryFormat::Contiguous); + } + if (grad_input_mask[2]) { + dbeta = M > 0 ? at::native::empty_like(beta, at::MemoryFormat::Contiguous) : at::native::zeros_like(beta, at::MemoryFormat::Contiguous); + } + if (M > 0) { + LayerNormBackwardKernel( + kCPU, dY, X, mean, rstd, gamma, M, N, &dX, &dgamma, &dbeta); + } + return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta)); } Tensor layer_norm( @@ -69,18 +98,58 @@ Tensor layer_norm( double eps, bool /* cudnn_enable, deprecated */) { + return std::get<0>(at::native_layer_norm(input, normalized_shape, weight, bias, eps)); +} + +DEFINE_DISPATCH(LayerNormKernel); +DEFINE_DISPATCH(LayerNormBackwardKernel); + +// Ported from pytorch/xla repo +std::tuple math_native_layer_norm( + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& weight, + const Tensor& bias, + double eps) { auto inputs = _prepare_layer_norm_inputs(input, normalized_shape, weight, bias); auto X = std::get<0>(inputs); auto gamma = std::get<1>(inputs); auto beta = std::get<2>(inputs); auto M = std::get<3>(inputs); auto N = std::get<4>(inputs); - - return std::get<0>(at::native_layer_norm(X, gamma, beta, M, N, eps)); + auto input_shape = input.sizes(); + const auto input_ndim = input.dim(); + const int normalized_ndim = normalized_shape.size(); + const int axis = input_ndim - normalized_ndim; + at::Tensor input_reshaped = input.view({1, M, -1}); + // Unlike Batch Normalization, which applies scalar scale and bias for each + // entire channel/plane with the affine option, Layer Normalization applies + // per-element scale and bias. E.g. For input {N, C, H, W}, weight for + // batchnorm has shape {C} while weight for layernorm has shape {H, W} or {W}. + auto outputs = at::native_batch_norm( + input_reshaped, /*weight=*/{}, /*bias=*/{}, /*running_mean=*/{}, + /*running_var=*/{}, /*training=*/true, /*momentum=*/0, eps); + at::Tensor out = std::get<0>(outputs); + out = out.view(input_shape); + if (weight.defined() && bias.defined()) { + out = bias.addcmul(out, weight, 1); + } else if (weight.defined()) { + out = out.mul(weight); + } else if (bias.defined()) { + out = out.add(bias); + } + at::Tensor mean = std::get<1>(outputs); + at::Tensor rstd = std::get<2>(outputs); + std::vector stat_shape; + for (size_t idx = 0; idx < axis; ++idx) { + stat_shape.push_back(input_shape[idx]); + } + for (size_t idx = axis; idx < input.dim(); ++idx) { + stat_shape.push_back(1); + } + mean = mean.view(stat_shape); + rstd = rstd.view(stat_shape); + return std::make_tuple(out, mean, rstd); } - -DEFINE_DISPATCH(LayerNormKernel); -DEFINE_DISPATCH(LayerNormBackwardKernel); - } // namespace native } // namespace at diff --git a/aten/src/ATen/native/layer_norm.h b/aten/src/ATen/native/layer_norm.h index bf931fb26c5fd..fa936ab7d4ce3 100644 --- a/aten/src/ATen/native/layer_norm.h +++ b/aten/src/ATen/native/layer_norm.h @@ -52,16 +52,10 @@ std::tuple _prepare_layer_norm_inputs( } const int axis = input_ndim - normalized_ndim; - const int64_t M = std::accumulate( - input_shape.cbegin(), - input_shape.cbegin() + axis, - 1LL, - std::multiplies()); - const int64_t N = std::accumulate( - input_shape.cbegin() + axis, - input_shape.cend(), - 1LL, - std::multiplies()); + const int64_t M = + prod_intlist(input_shape.cbegin(), input_shape.cbegin() + axis); + const int64_t N = + prod_intlist(input_shape.cbegin() + axis, input_shape.cend()); const auto& X = input.is_contiguous() ? input : input.contiguous(); const auto& gamma = weight.is_contiguous() ? weight : weight.contiguous(); diff --git a/aten/src/ATen/native/metal/MetalAten.mm b/aten/src/ATen/native/metal/MetalAten.mm new file mode 100644 index 0000000000000..77fec7dae1ff1 --- /dev/null +++ b/aten/src/ATen/native/metal/MetalAten.mm @@ -0,0 +1,294 @@ +#import +#import +#import +#import +#import + +#include +#include + +namespace at { +namespace native { +namespace metal { + +at::Tensor& copy_from_metal_(at::Tensor& dst, const at::Tensor& src) { + TORCH_INTERNAL_ASSERT( + src.device().type() == DeviceType::Metal, + "copy_from_metal input tensor's device is not metal"); + TORCH_INTERNAL_ASSERT( + dst.device().type() == DeviceType::CPU, + "copy_from_metal is implemented only for CPU device output"); + TORCH_INTERNAL_ASSERT( + dst.layout() == Layout::Strided, + "copy_from_metal is implemented only for Strided layout output"); + TORCH_INTERNAL_ASSERT( + dst.scalar_type() == ScalarType::Float, + "copy_from_metal is implemented only for float dtype output, got:", + dst.scalar_type()); + TORCH_INTERNAL_ASSERT( + dst.is_contiguous(), + "copy_from_metal is implemented only for contiguous output tensor"); + + MetalTensor& mtensor = MetalTensor::fromTensor(src); + mtensor.copy_data_to_host(dst.data_ptr()); + return dst; +} + +at::Tensor& copy_to_metal_(at::Tensor& dst, const at::Tensor& src) { + TORCH_INTERNAL_ASSERT( + dst.device().type() == DeviceType::Metal, + "copy_to_metal_ output tensor's device is not metal"); + TORCH_INTERNAL_ASSERT( + src.device().type() == DeviceType::CPU, + "copy_to_metal_ is implemented only for CPU device input"); + TORCH_INTERNAL_ASSERT( + src.layout() == Layout::Strided, + "copy_to_metal_ is implemented only for Strided layout input"); + TORCH_INTERNAL_ASSERT( + src.scalar_type() == ScalarType::Float, + "copy_to_metal_ is implemented only for float dtype"); + auto cpu_tensor_contiguous = src.contiguous(); + MetalTensor& mtensor = MetalTensor::fromTensor(dst); + mtensor.set_data_from_host(cpu_tensor_contiguous.data_ptr()); + return dst; +} + +at::Tensor& metal_copy_impl_(at::Tensor& dst, const at::Tensor& src) { + if (src.device().type() == at::kMetal && dst.device().type() == at::kCPU) { + return copy_from_metal_(dst, src); + } + if (src.device().type() == at::kCPU && dst.device().type() == at::kMetal) { + return copy_to_metal_(dst, src); + } + TORCH_INTERNAL_ASSERT( + src.device().type() == DeviceType::Metal, + "metal_copy_ is implemented only for CPU,Strided,float->Metal; Metal->CPU,Strided,float"); + return dst; +} + +#pragma mark - ATen Ops + +Tensor empty( + IntArrayRef size, + optional dtype, + optional layout, + optional device, + optional pin_memory, + c10::optional memory_format) { + TORCH_CHECK( + !pin_memory.has_value(), + "'pin_memory' argument is incompatible with Metal tensor"); + TORCH_CHECK( + !memory_format.has_value(), + "'memory_format' argument is incompatible with Metal tensor"); + MetalTensor mt{size.vec()}; + return MetalTensor::toTensor( + std::move(mt), at::device(at::kMetal).dtype(dtype)); +}; + +at::Tensor empty_strided( + IntArrayRef size, + IntArrayRef stride, + optional dtype, + optional layout, + optional device, + optional pin_memory) { + TORCH_CHECK( + !pin_memory.has_value() || !pin_memory.value(), + "'pin_memory' argument is incompatible with Metal tensor"); + MetalTensor mt{size.vec(), stride.vec()}; + return MetalTensor::toTensor( + std::move(mt), at::device(at::kMetal).dtype(dtype)); +} + +Tensor addmm( + const Tensor& bias, + const Tensor& input, + const Tensor& weight, + Scalar beta, + Scalar alpha) { + TORCH_CHECK(input.is_metal()); + TORCH_CHECK(input.dim() == 2 && weight.dim() == 2); + TORCH_CHECK(beta.toFloat() == 1.0f); + TORCH_CHECK(alpha.toFloat() == 1.0f); + auto&& sizes = weight.sizes(); + at::Tensor transposedWeight = weight.t().contiguous(); + at::Tensor mWeight = + transposedWeight.view({sizes[1], sizes[0], 1, 1}).contiguous(); + return mpscnn::addmm(bias, input, mWeight); +} + +Tensor conv2d( + const Tensor& input, + const Tensor& weight, + const c10::optional& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups) { + TORCH_CHECK(input.is_metal()); + Conv2DParams params{ + input.sizes(), weight.sizes(), padding, stride, dilation, groups}; + TORCH_INTERNAL_ASSERT(input.dim() == 4, "Expected 4-dimensional input"); + TORCH_INTERNAL_ASSERT(weight.dim() == 4, "Expected 4-dimensional weight"); + TORCH_CHECK(weight.device().type() == kCPU); + return mpscnn::conv2d(input, weight, bias, params); +} + +Tensor log_softmax_int( + const Tensor& input, + int64_t dim, + c10::optional dtype) { + TORCH_CHECK(dim == 1); + return mpscnn::log_softmax_int(input); +} + +Tensor max_pool2d( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode) { + TORCH_CHECK(input.is_metal()); + TORCH_CHECK( + dilation[0] == dilation[1] == 1, "dilation is not supported on MPSCNN"); + TORCH_CHECK(ceil_mode == false, "ceil_mode is not supported on MPSCNN"); + return mpscnn::max_pool2d( + input, kernel_size, stride, padding, dilation, ceil_mode); +} + +Tensor relu(const Tensor& input) { + TORCH_CHECK(input.is_metal()); + return mpscnn::relu(input); +} + +Tensor& relu_(Tensor& input) { + TORCH_CHECK(input.is_metal()); + return mpscnn::relu_(input); +} + +Tensor sigmoid(const Tensor& input) { + TORCH_CHECK(input.is_metal()); + return mpscnn::sigmoid(input); +} + +Tensor t(const Tensor& input) { + TORCH_CHECK(input.is_metal()); + TORCH_CHECK(input.dim() == 2); + return mpscnn::t(input); +} + +Tensor view(const Tensor& input, IntArrayRef size) { + TORCH_CHECK(input.is_metal()); + return mpscnn::view(input, size); +} + +Tensor upsample_nearest2d_vec( + const Tensor& input, + c10::optional output_size, + c10::optional> scale_factors) { + TORCH_CHECK(input.is_metal()); + return mpscnn::upsample_nearest2d_vec(input, output_size, scale_factors); +} + +Tensor add_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) { + TORCH_CHECK(input1.is_metal()); + TORCH_CHECK(input1.dim() == input2.dim()); + TORCH_CHECK(input1.sizes()[2] == input2.sizes()[2]); + TORCH_CHECK(input1.sizes()[3] == input2.sizes()[3]); + return mpscnn::add(input1, input2.is_metal() ? input2 : input2.metal()); +} + +Tensor& add__Tensor(Tensor& input1, const Tensor& input2, Scalar alpha) { + TORCH_CHECK(input1.is_metal()); + TORCH_CHECK(input1.dim() == input2.dim()); + TORCH_CHECK(input1.sizes()[2] == input2.sizes()[2]); + TORCH_CHECK(input1.sizes()[3] == input2.sizes()[3]); + return mpscnn::add_(input1, input2.is_metal() ? input2 : input2.metal()); +} + +Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) { + TORCH_CHECK(input1.is_metal()); + TORCH_CHECK(input1.dim() == input2.dim()); + TORCH_CHECK(input2.sizes()[2] == input2.sizes()[3] == 1); + return mpscnn::sub(input1, input2.is_metal() ? input2 : input2.metal()); +} + +Tensor mul_Tensor(const Tensor& input1, const Tensor& input2) { + TORCH_CHECK(input1.is_metal()); + TORCH_CHECK(input1.dim() == input2.dim()); + TORCH_CHECK(input2.sizes()[2] == input2.sizes()[3] == 1); + return mpscnn::mul(input1, input2.is_metal() ? input2 : input2.metal()); +} + +Tensor adaptive_avg_pool2d(const Tensor& input, IntArrayRef output_size) { + // averages across the width and height, and outputs a 1x1xC image. + TORCH_CHECK(output_size[0] == 1 && output_size[1] == 1); + TORCH_CHECK(input.is_metal()); + return mpscnn::global_avg_pool2d(input, output_size); +} + +Tensor& hardtanh_(Tensor& input, Scalar min_val, Scalar max_val) { + TORCH_CHECK(input.is_metal()); + return mpscnn::hardtanh_(input, min_val, max_val); +} + +Tensor reshape(const Tensor& input, IntArrayRef shape) { + TORCH_CHECK(input.is_metal()); + return mpscnn::reshape(input, shape); +} + +Tensor flatten_using_ints( + const Tensor& input, + int64_t start_dim, + int64_t end_dim) { + TORCH_CHECK(input.is_metal()); + return mpscnn::flatten_using_ints(input, start_dim, end_dim); +} + +TORCH_LIBRARY_IMPL(aten, Metal, m) { + m.impl("conv2d", TORCH_FN(conv2d)); + m.impl("add.Tensor", TORCH_FN(add_Tensor)); + m.impl("add_.Tensor", TORCH_FN(add__Tensor)); + m.impl("addmm", TORCH_FN(addmm)); + m.impl("empty.memory_format", empty); + m.impl("empty_strided", TORCH_FN(empty_strided)); + m.impl("log_softmax.int", TORCH_FN(log_softmax_int)); + m.impl("max_pool2d", TORCH_FN(max_pool2d)); + m.impl("mul.Tensor", TORCH_FN(mul_Tensor)); + m.impl("relu", TORCH_FN(relu)); + m.impl("relu_", TORCH_FN(relu_)); + m.impl("sigmoid", TORCH_FN(sigmoid)); + m.impl("sub.Tensor", TORCH_FN(sub_Tensor)); + m.impl("upsample_nearest2d.vec", TORCH_FN(upsample_nearest2d_vec)); + m.impl("view", TORCH_FN(view)); + m.impl("adaptive_avg_pool2d", TORCH_FN(adaptive_avg_pool2d)); + m.impl("hardtanh_", TORCH_FN(hardtanh_)); + m.impl("reshape", TORCH_FN(reshape)); + m.impl("flatten.using_ints", TORCH_FN(flatten_using_ints)); +} + +} // namespace metal +} // namespace native + +struct MetalImpl : public at::metal::MetalInterface { + bool is_metal_available() const override { +#if defined(USE_PYTORCH_METAL) + return [[MPSCNNContext sharedInstance] available]; +#else + return false; +#endif + } + at::Tensor& metal_copy_(at::Tensor& input, const at::Tensor& src) + const override { + TORCH_CHECK( + is_metal_available(), "Metal is not available on the current device"); + return native::metal::metal_copy_impl_(input, src); + } +}; +#if defined(USE_PYTORCH_METAL) +static at::metal::MetalImplRegistrar g_metal_impl(new MetalImpl()); +#endif + +} // namespace at diff --git a/aten/src/ATen/native/metal/MetalCommandBuffer.h b/aten/src/ATen/native/metal/MetalCommandBuffer.h new file mode 100644 index 0000000000000..9469a35245193 --- /dev/null +++ b/aten/src/ATen/native/metal/MetalCommandBuffer.h @@ -0,0 +1,16 @@ +#import +#import +#import + +@interface MetalCommandBuffer : NSObject +@property(nonatomic, strong, readonly) NSThread* thread; +@property(nonatomic, strong, readonly) id buffer; + ++ (MetalCommandBuffer*)newBuffer; ++ (MetalCommandBuffer*)currentBuffer; +- (void)synchronize; + +- (void)add:(MPSTemporaryImage*)image; +- (void)remove:(MPSTemporaryImage*)image; + +@end diff --git a/aten/src/ATen/native/metal/MetalCommandBuffer.mm b/aten/src/ATen/native/metal/MetalCommandBuffer.mm new file mode 100644 index 0000000000000..33b5db9386ad7 --- /dev/null +++ b/aten/src/ATen/native/metal/MetalCommandBuffer.mm @@ -0,0 +1,79 @@ +#import +#import +#import + +#include + +NSString* cb_key = @"PTCommandBuffer"; +@implementation MetalCommandBuffer { + NSMutableArray* _images; + std::mutex _mutex; +} + ++ (MetalCommandBuffer*)newBuffer { + MetalCommandBuffer* cb = [MetalCommandBuffer new]; + cb->_buffer = [[MPSCNNContext sharedInstance].commandQueue commandBuffer]; + cb->_thread = [NSThread currentThread]; + cb->_images = [NSMutableArray new]; + return cb; +} + ++ (MetalCommandBuffer*)currentBuffer { + NSThread* thd = [NSThread currentThread]; + NSMutableDictionary* dict = [thd threadDictionary]; + MetalCommandBuffer* cb = dict[cb_key]; + if (!cb) { + cb = [MetalCommandBuffer new]; + cb->_buffer = [[MPSCNNContext sharedInstance].commandQueue commandBuffer]; + cb->_thread = thd; + cb->_images = [NSMutableArray new]; + dict[cb_key] = cb; + } + return cb; +} + +- (void)flush { + [[_thread threadDictionary] removeObjectForKey:cb_key]; +} + +- (void)add:(MPSTemporaryImage*)image { + if (![image isTemporaryImage]) { + return; + } + std::lock_guard g(_mutex); + [_images addObject:image]; +} + +- (void)remove:(MPSTemporaryImage*)image { + if (![image isTemporaryImage]) { + return; + } + std::lock_guard g(_mutex); + [_images removeObject:image]; +} + +- (void)synchronize { + if (_buffer.status == 0) { + // recycle all temporary images manually before flushing the command buffer + [self recycle]; + [_buffer commit]; + [_buffer waitUntilCompleted]; + [[_thread threadDictionary] removeObjectForKey:cb_key]; + } +} + +- (void)recycle { + for (MPSTemporaryImage* image in _images) { + [image recycle]; + } +} + +- (BOOL)isEqual:(id)object { + if (![object isKindOfClass:[MetalCommandBuffer class]]) { + return NO; + } + MetalCommandBuffer* mc = (MetalCommandBuffer*)object; + return (_thread == mc.thread && _buffer == mc.buffer); +} + +@end diff --git a/aten/src/ATen/native/metal/MetalConvolution.h b/aten/src/ATen/native/metal/MetalConvolution.h new file mode 100644 index 0000000000000..7a7bdfbd21c23 --- /dev/null +++ b/aten/src/ATen/native/metal/MetalConvolution.h @@ -0,0 +1,53 @@ +#import + +#include + +namespace at { +namespace native { +namespace metal { + +enum class NeuronType { + None, + Clamp, + Relu, + Sigmoid, + Tanh, +}; + +struct Conv2DParams final { + Conv2DParams() = delete; + Conv2DParams( + c10::IntArrayRef inputSizes, + c10::IntArrayRef weightSizes, + c10::IntArrayRef padding, + c10::IntArrayRef stride, + c10::IntArrayRef dilation, + int64_t groups); + + std::vector output_sizes() const; + bool isDepthwise() const; + + int64_t N; // batch size + int64_t C; // channels + int64_t H; // input height + int64_t W; // input width + int64_t OC; // output channels + int64_t IC; // input channels + int64_t KH; // kernel height + int64_t KW; // kernel width + int64_t SY; // stride y (height) + int64_t SX; // stride x (width) + int64_t PY; // padding y (height) + int64_t PX; // padding x (width) + int64_t DY; // dilation y (height) + int64_t DX; // dilation x (width) + int64_t G; // groups + int64_t OW; // output width + int64_t OH; // output height +}; + +NeuronType neuronType(const Conv2dOpContext& context); + +} // namespace metal +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/metal/MetalConvolution.mm b/aten/src/ATen/native/metal/MetalConvolution.mm new file mode 100644 index 0000000000000..1d316e2144d2b --- /dev/null +++ b/aten/src/ATen/native/metal/MetalConvolution.mm @@ -0,0 +1,65 @@ +#import +#import +#import + +namespace at { +namespace native { +namespace metal { + +Conv2DParams::Conv2DParams( + c10::IntArrayRef inputSizes, + c10::IntArrayRef weightSizes, + c10::IntArrayRef padding, + c10::IntArrayRef stride, + c10::IntArrayRef dilation, + int64_t groups) + : N(inputSizes[0]), + C(inputSizes[1]), + H(inputSizes[2]), + W(inputSizes[3]), + OC(weightSizes[0]), + IC(weightSizes[1]), + KH(weightSizes[2]), + KW(weightSizes[3]), + SY(stride[0]), + SX(stride[1]), + PY(padding[0]), + PX(padding[1]), + DY(dilation[0]), + DX(dilation[1]), + G(groups) { + OH = std::floor((H + 2 * PY - DY * (KH - 1) - 1) / SY + 1); + OW = std::floor((W + 2 * PX - DX * (KW - 1) - 1) / SX + 1); +}; + +std::vector Conv2DParams::output_sizes() const { + return {N, OC, OH, OW}; +} + +bool Conv2DParams::isDepthwise() const { + // Currently, only channel multipler of 1 is supported + // i.e. inputFeatureChannels == outputFeatureChannels + return G > 1 && IC == 1 && OC == G && OC == C; +} + +NeuronType neuronType(const Conv2dOpContext& context) { + float inf_max = std::numeric_limits::infinity(); + float inf_min = -std::numeric_limits::infinity(); + float output_max = context.output_max.has_value() + ? context.output_max.value().toFloat() + : inf_max; + float output_min = context.output_min.has_value() + ? context.output_min.value().toFloat() + : inf_min; + if (output_max == inf_max && output_min == 0) { + return NeuronType::Relu; + } else if (output_max < inf_max && output_min > inf_min) { + return NeuronType::Clamp; + } else { + return NeuronType::None; + } +} + +} // namespace metal +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/metal/MetalGuardImpl.cpp b/aten/src/ATen/native/metal/MetalGuardImpl.cpp new file mode 100644 index 0000000000000..c3aed9cad5b61 --- /dev/null +++ b/aten/src/ATen/native/metal/MetalGuardImpl.cpp @@ -0,0 +1,64 @@ +#include +#include + +namespace at { +namespace detail { + +struct MetalGuardImpl final : public c10::impl::DeviceGuardImplInterface { + MetalGuardImpl() {} + + explicit MetalGuardImpl(DeviceType t) { + TORCH_INTERNAL_ASSERT(t == DeviceType::Metal); + } + + DeviceType type() const override { + return DeviceType::Metal; + } + Device exchangeDevice(Device) const override { + // no-op + return Device(DeviceType::Metal, -1); + } + Device getDevice() const override { + return Device(DeviceType::Metal, -1); + } + void setDevice(Device) const override { + // no-op + } + void uncheckedSetDevice(Device d) const noexcept override { + // no-op + } + Stream getStream(Device d) const noexcept override { + // no-op + return Stream(Stream::DEFAULT, Device(DeviceType::Metal, -1)); + } + // NB: These do NOT set the current device + Stream exchangeStream(Stream s) const noexcept override { + // no-op + return Stream(Stream::DEFAULT, Device(DeviceType::Metal, -1)); + } + DeviceIndex deviceCount() const noexcept override { + return 1; + } + + // Event-related functions + void record( + void** event, + const Stream& stream, + const DeviceIndex device_index, + const EventFlag flag) const override { + TORCH_CHECK(false, "Metal backend doesn't support events."); + } + void block(void* event, const Stream& stream) const override { + TORCH_CHECK(false, "Metal backend doesn't support events.") + } + bool queryEvent(void* event) const override { + TORCH_CHECK(false, "Metal backend doesn't support events.") + } + void destroyEvent(void* event, const DeviceIndex device_index) const + noexcept override {} +}; + +C10_REGISTER_GUARD_IMPL(Metal, MetalGuardImpl); + +} // namespace detail +} // namespace at diff --git a/aten/src/ATen/native/metal/MetalPrepackOpContext.h b/aten/src/ATen/native/metal/MetalPrepackOpContext.h new file mode 100644 index 0000000000000..04c6da7a3aed9 --- /dev/null +++ b/aten/src/ATen/native/metal/MetalPrepackOpContext.h @@ -0,0 +1,97 @@ +#pragma once + +#include +#include + +namespace at { +namespace native { +namespace metal { + +using SerializationTypeConv2dPrePack = std::tuple< + Tensor, + c10::optional, + std::vector, + std::vector, + std::vector, + int64_t, + c10::optional, + c10::optional>; + +class Conv2dOpContext : public torch::jit::CustomClassHolder { + public: + SerializationTypeConv2dPrePack pack() { + return std::make_tuple( + weight, + bias, + stride, + padding, + dilation, + groups, + output_min, + output_max); + } + Conv2dOpContext() = delete; + Conv2dOpContext( + at::Tensor&& weight, + c10::optional&& bias, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + int64_t groups, + c10::optional output_min, + c10::optional output_max) + : weight(std::move(weight)), + bias(std::move(bias)), + stride(stride), + padding(padding), + dilation(dilation), + groups(groups), + output_min(output_min), + output_max(output_max) {} + + void release_resources() override { + if (releaseCallback) { + releaseCallback(conv2dOp); + conv2dOp = nullptr; + } + } + + Tensor weight; + c10::optional bias; + std::vector stride; + std::vector padding; + std::vector dilation; + int64_t groups; + c10::optional output_min; + c10::optional output_max; + void* conv2dOp = nullptr; // reserved to hold MPSCNNConv2dOp objects + std::function releaseCallback = nullptr; +}; + +// The MPSCNNConvolution class takes weights in the order +// [outputChannels][kernelHeight][kernelWidth][inputChannels/groups]. +static inline std::vector permuteWeights( + const float* src, + const std::vector& sizes) { + const int64_t M = sizes[0]; + const int64_t Cf = sizes[1]; + const int64_t kH = sizes[2]; + const int64_t kW = sizes[3]; + std::vector packedWeights(M * kH * kW * Cf); + for (auto m = 0; m < M; ++m) { + for (auto c = 0; c < Cf; ++c) { + for (auto kh = 0; kh < kH; ++kh) { + for (auto kw = 0; kw < kW; ++kw) { + int64_t oc = m * kH * kW * Cf + kh * kW * Cf + kw * Cf + c; + int64_t ic = m * Cf * kH * kW + c * kH * kW + kh * kW + kw; + packedWeights[oc] = src[ic]; + } + } + } + } + return packedWeights; +} + +} // namespace metal +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp b/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp new file mode 100644 index 0000000000000..115f2140f397e --- /dev/null +++ b/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp @@ -0,0 +1,128 @@ +#include +#include +#include + + +#if (C10_IOS || TARGET_OS_MAC) +#import +#endif + +namespace at { +namespace native { +namespace metal { + +c10::intrusive_ptr unpack( + Tensor&& weight, + c10::optional&& bias, + std::vector&& stride, + std::vector&& padding, + std::vector&& dilation, + int64_t groups, + c10::optional output_min, + c10::optional output_max) { + const Tensor weightContig = weight.contiguous(); + const auto ws = weightContig.sizes(); + auto packed_buffer = permuteWeights(weightContig.data_ptr(), ws.vec()); + auto packedWeight = at::empty(ws); + int64_t size_bytes = at::prod_intlist(ws) * sizeof(float); + memcpy(packedWeight.data_ptr(), packed_buffer.data(), size_bytes); + return c10::make_intrusive( + std::move(packedWeight), + std::move(bias), + stride, + padding, + dilation, + groups, + output_min, + output_max); +} + +TORCH_LIBRARY(metal, m) { + m.class_("Conv2dOpContext") + .def_pickle( + [](const c10::intrusive_ptr& op_context) + -> SerializationTypeConv2dPrePack { // __getstate__ + return op_context->pack(); + }, + [](SerializationTypeConv2dPrePack state) + -> c10::intrusive_ptr { // __setstate__ + return unpack( + std::move(std::get<0>(state)), + std::move(std::get<1>(state)), + std::move(std::get<2>(state)), + std::move(std::get<3>(state)), + std::move(std::get<4>(state)), + std::move(std::get<5>(state)), + std::move(std::get<6>(state)), + std::move(std::get<7>(state))); + }); + m.def("copy_to_host(Tensor X) -> Tensor Y"); +} + +TORCH_LIBRARY(metal_prepack, m) { + m.def( + "conv2d_prepack(Tensor W, Tensor? B, int[2] stride, " + "int[2] padding, int[2] dilation, int groups, " + "Scalar? output_min=None, Scalar? output_max=None) " + "-> __torch__.torch.classes.metal.Conv2dOpContext"); + m.def( + "conv2d_run(Tensor X, " + "__torch__.torch.classes.metal.Conv2dOpContext W_prepack) -> Tensor Y"); +} + +c10::intrusive_ptr conv2d_prepack( + Tensor&& weight, + c10::optional&& bias, + std::vector&& stride, + std::vector&& padding, + std::vector&& dilation, + int64_t groups, + c10::optional output_min, + c10::optional output_max) { + TORCH_CHECK(weight.dim() == 4); + return c10::make_intrusive( + std::move(weight), + std::move(bias), + stride, + padding, + dilation, + groups, + output_min, + output_max); +} + +Tensor conv2d_prepack_run( + const Tensor& input, + const c10::intrusive_ptr& op_context) { +#if (C10_IOS || TARGET_OS_MAC) + return mpscnn::conv2d(input, *op_context); +#else + TORCH_CHECK(false, "conv2d_prepack_run can only be invoked on iOS and MacOS"); + return input; +#endif +} + +Tensor copy_to_host(const Tensor& input) { +#if (C10_IOS || TARGET_OS_MAC) + return mpscnn::copy_to_host(input); +#else + TORCH_CHECK(false, "copy_to_host can only be invoked on iOS and MacOS"); + return input; +#endif +} + +TORCH_LIBRARY_IMPL(metal_prepack, CPU, m) { + m.impl("conv2d_prepack", TORCH_FN(conv2d_prepack)); +} + +TORCH_LIBRARY_IMPL(metal_prepack, Metal, m) { + m.impl("conv2d_run", conv2d_prepack_run); +} + +TORCH_LIBRARY_IMPL(metal, Metal, m) { + m.impl("copy_to_host", copy_to_host); +} + +} // namespace metal +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/metal/MetalShaders.h b/aten/src/ATen/native/metal/MetalShaders.h new file mode 100644 index 0000000000000..2cf0c55c5587c --- /dev/null +++ b/aten/src/ATen/native/metal/MetalShaders.h @@ -0,0 +1,270 @@ +#ifndef MPSCNNShaders_h +#define MPSCNNShaders_h + +static const char* METAL_SHADERS = R"METAL_SHADERS( +#include +using namespace metal; + +constant ushort ushort_arg_0[[function_constant(0)]]; +constant ushort ushort_arg_1[[function_constant(1)]]; +constant ushort ushort_arg_2[[function_constant(2)]]; +constant ushort ushort_arg_3[[function_constant(3)]]; +constant ushort ushort_arg_4[[function_constant(4)]]; +constant ushort ushort_arg_5[[function_constant(5)]]; +constant ushort ushort_arg_6[[function_constant(6)]]; +constant ushort ushort_arg_7[[function_constant(7)]]; +constant ushort ushort_arg_8[[function_constant(8)]]; +constant ushort ushort_arg_9[[function_constant(9)]]; +constant float float_arg_0 [[function_constant(10)]]; +constant float float_arg_1 [[function_constant(11)]]; + + +inline constexpr ushort divRoundUp(ushort x, ushort y) { return (x + (y - 1)) / y; } + +kernel void elementwise_add_nonarray(texture2d in0[[texture(0)]], + texture2d in1[[texture(1)]], + texture2d out[[texture(2)]], + ushort2 gid[[thread_position_in_grid]]) { + if (gid.x >= out.get_width() || gid.y >= out.get_height()) { + return; + } + out.write(in0.read(gid) + in1.read(gid), gid); +} + +kernel void elementwise_add(texture2d_array in0[[texture(0)]], + texture2d_array in1[[texture(1)]], + texture2d_array out[[texture(2)]], + ushort3 gid[[thread_position_in_grid]]) { + if (gid.x >= out.get_width() || gid.y >= out.get_height()) { + return; + } + ushort2 gid_ = gid.xy; + out.write(in0.read(gid_, gid.z) + in1.read(gid_, gid.z), gid_, gid.z); +} + +kernel void elementwise_sub_nonarray(texture2d in0[[texture(0)]], + texture2d in1[[texture(1)]], + texture2d out[[texture(2)]], + ushort2 gid[[thread_position_in_grid]]) { + if (gid.x >= out.get_width() || gid.y >= out.get_height()) { + return; + } + ushort2 gid2{0,0}; + out.write(in0.read(gid) - in1.read(gid2), gid); +} + +kernel void elementwise_sub(texture2d_array in0[[texture(0)]], + texture2d_array in1[[texture(1)]], + texture2d_array out[[texture(2)]], + ushort3 gid[[thread_position_in_grid]]) { + if (gid.x >= out.get_width() || gid.y >= out.get_height()) { + return; + } + ushort2 gid1 = gid.xy; + ushort2 gid2{0,0}; + out.write(in0.read(gid1, gid.z) - in1.read(gid2, gid.z), gid1, gid.z); +} +kernel void elementwise_mul_nonarray(texture2d in0[[texture(0)]], + texture2d in1[[texture(1)]], + texture2d out[[texture(2)]], + ushort2 gid[[thread_position_in_grid]]) { + if (gid.x >= out.get_width() || gid.y >= out.get_height()) { + return; + } + ushort2 gid2{0,0}; + out.write(in0.read(gid) * in1.read(gid2), gid); +} + +kernel void elementwise_mul(texture2d_array in0[[texture(0)]], + texture2d_array in1[[texture(1)]], + texture2d_array out[[texture(2)]], + ushort3 gid[[thread_position_in_grid]]) { + if (gid.x >= out.get_width() || gid.y >= out.get_height()) { + return; + } + ushort2 gid1 = gid.xy; + ushort2 gid2{0,0}; + out.write(in0.read(gid1, gid.z) * in1.read(gid2, gid.z), gid1, gid.z); +} + +kernel void copy_nchw_to_metal(constant float* in[[buffer(0)]], + texture2d_array out[[texture(0)]], + ushort3 gid[[thread_position_in_grid]]) { + const ushort C = ushort_arg_0; + const ushort H = ushort_arg_1; + const ushort W = ushort_arg_2; + if (gid.x >= W || gid.y >= H) { + return; + } + const ushort n = gid.z / divRoundUp(C, 4); + const ushort c = gid.z - n * divRoundUp(C, 4); + // TODO: are the `else` branches needed? + // TODO: trick the optimizer for case where C == 4? +#define CHW_TO_CHWP4(idx, n, c_, h, w) \ +if ((c_) < C) { \ +trns[idx] = in[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)]; \ +} else { \ +trns[idx] = 0.0h; \ +} + half4 trns; + CHW_TO_CHWP4(0, n, c * 4 + 0, gid.y, gid.x); + CHW_TO_CHWP4(1, n, c * 4 + 1, gid.y, gid.x); + CHW_TO_CHWP4(2, n, c * 4 + 2, gid.y, gid.x); + CHW_TO_CHWP4(3, n, c * 4 + 3, gid.y, gid.x); +#undef CHW_TO_CHWP4 + out.write(trns, gid.xy, gid.z); +} + +kernel void copy_nchw_to_metal_nonarray(constant float* in[[buffer(0)]], + texture2d out[[texture(0)]], + ushort2 gid[[thread_position_in_grid]]) { + const ushort C = ushort_arg_0; + const ushort H = ushort_arg_1; + const ushort W = ushort_arg_2; + if (gid.x >= W || gid.y >= H) { + return; + } + half4 trns; + // TODO: are the `else` branches needed? + // TODO: trick the optimizer for case where C % 4 == 0? +#define CHW_TO_CHWP4(idx, c, h, w) \ +if ((c) < C) { \ +trns[idx] = in[int(c) * H * W + int(h) * W + int(w)]; \ +} else { \ +trns[idx] = 0.0h; \ +} + CHW_TO_CHWP4(0, 0, gid.y, gid.x); + CHW_TO_CHWP4(1, 1, gid.y, gid.x); + CHW_TO_CHWP4(2, 2, gid.y, gid.x); + CHW_TO_CHWP4(3, 3, gid.y, gid.x); +#undef CHW_TO_CHWP4 + out.write(trns, gid.xy); +} + +kernel void copy_metal_to_nchw(texture2d_array in[[texture(0)]], + device float* out[[buffer(0)]], + ushort3 gid[[thread_position_in_grid]]) { + const ushort C = ushort_arg_0; + const ushort H = ushort_arg_1; + const ushort W = ushort_arg_2; + if (gid.x >= W || gid.y >= H) { + return; + } + const ushort n = gid.z / divRoundUp(C, 4); + const ushort c = gid.z - n * divRoundUp(C, 4); + half4 cs = in.read(gid.xy, gid.z); +#define CHWP4_TO_CHW(idx, n, c_, h, w) \ +if ((c_) < C) { \ +out[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)] = cs[idx]; \ +} + CHWP4_TO_CHW(0, n, c * 4 + 0, gid.y, gid.x); + CHWP4_TO_CHW(1, n, c * 4 + 1, gid.y, gid.x); + CHWP4_TO_CHW(2, n, c * 4 + 2, gid.y, gid.x); + CHWP4_TO_CHW(3, n, c * 4 + 3, gid.y, gid.x); +#undef CHWP4_TO_CHW +} + +kernel void copy_metal_to_nchw_nonarray(texture2d in[[texture(0)]], + device float* out[[buffer(0)]], + ushort2 gid[[thread_position_in_grid]]) { + const ushort C = ushort_arg_0; + const ushort H = ushort_arg_1; + const ushort W = ushort_arg_2; + if (gid.x >= W || gid.y >= H) { + return; + } + half4 cs = in.read(gid.xy); +#define CHWP4_TO_CHW(idx, c, h, w) \ +if ((c) < C) { \ +out[int(c) * H * W + int(h) * W + int(w)] = cs[idx]; \ +} + CHWP4_TO_CHW(0, 0, gid.y, gid.x); + CHWP4_TO_CHW(1, 1, gid.y, gid.x); + CHWP4_TO_CHW(2, 2, gid.y, gid.x); + CHWP4_TO_CHW(3, 3, gid.y, gid.x); +#undef CHWP4_TO_CHW +} + +kernel void copy(texture2d_array in[[texture(0)]], + texture2d_array out[[texture(1)]], + ushort3 gid[[thread_position_in_grid]]) { + if (gid.x >= out.get_width() || gid.y >= out.get_height()) { + return; + } + ushort2 gid_ = gid.xy; + out.write(in.read(gid_, gid.z), gid_, gid.z); +} + +kernel void copy_nonarray(texture2d in[[texture(0)]], + texture2d out[[texture(1)]], + ushort2 gid[[thread_position_in_grid]]) { + if (gid.x >= out.get_width() || gid.y >= out.get_height()) { + return; + } + out.write(in.read(gid), gid); +} + +kernel void clamp_half4(texture2d_array in[[texture(0)]], + texture2d_array out[[texture(1)]], + constant half* clamp_buf[[buffer(0)]], + ushort3 gid[[thread_position_in_grid]]) { + if (gid.x >= out.get_width() || gid.y >= out.get_height()) { + return; + } + const half4 min_(clamp_buf[0], clamp_buf[0], clamp_buf[0], clamp_buf[0]); + const half4 max_(clamp_buf[1], clamp_buf[1], clamp_buf[1], clamp_buf[1]); + ushort2 gid_ = gid.xy; + half4 value = in.read(gid_, gid.z); + half4 clamped = clamp(value, min_, max_); + out.write(clamped, gid_, gid.z); +} + +kernel void clamp_half4_nonarray(texture2d in[[texture(0)]], + texture2d out[[texture(1)]], + constant half* clamp_buf[[buffer(0)]], + ushort2 gid[[thread_position_in_grid]]) { + if (gid.x >= out.get_width() || gid.y >= out.get_height()) { + return; + } + const half4 min_(clamp_buf[0], clamp_buf[0], clamp_buf[0], clamp_buf[0]); + const half4 max_(clamp_buf[1], clamp_buf[1], clamp_buf[1], clamp_buf[1]); + half4 value = in.read(gid); + half4 clamped = clamp(value, min_, max_); + out.write(clamped, gid); +} + +kernel void resize_nearest(texture2d_array in[[texture(0)]], + texture2d_array out[[texture(1)]], + ushort3 gid[[thread_position_in_grid]]) { + const ushort oH = ushort_arg_0; + const ushort oW = ushort_arg_1; + if (gid.x >= oW || gid.y >= oH) { + return; + } + const float height_scale = float(ushort_arg_2) / 10000; + const float width_scale = float(ushort_arg_3) / 10000; + constexpr sampler s(coord::pixel, address::clamp_to_edge, filter::nearest); + const int in_y = (int)(gid.y / height_scale); + const int in_x = (int)(gid.x / width_scale); + out.write(in.sample(s, float2(in_x, in_y), gid.z), gid.xy, gid.z); +} + +kernel void resize_nearest_nonarray(texture2d in[[texture(0)]], + texture2d out[[texture(1)]], + ushort2 gid[[thread_position_in_grid]]) { + const ushort oH = ushort_arg_0; + const ushort oW = ushort_arg_1; + if (gid.x >= oW || gid.y >= oH) { + return; + } + const float height_scale = float(ushort_arg_2) / 10000; + const float width_scale = float(ushort_arg_3) / 10000; + constexpr sampler s(coord::pixel, address::clamp_to_edge, filter::nearest); + const int in_y = (int)(gid.y / height_scale); + const int in_x = (int)(gid.x / width_scale); + out.write(in.sample(s, float2(in_x, in_y)), gid.xy); +} + +)METAL_SHADERS"; + +#endif /* MPSCNNShaders_h */ diff --git a/aten/src/ATen/native/metal/MetalTensor.h b/aten/src/ATen/native/metal/MetalTensor.h new file mode 100644 index 0000000000000..c21e904e3dd62 --- /dev/null +++ b/aten/src/ATen/native/metal/MetalTensor.h @@ -0,0 +1,47 @@ +#include + +namespace at { +namespace native { +namespace metal { + +class MPSImageWrapper; +class MetalTensor final { + class Impl; + + public: + MetalTensor(){}; + explicit MetalTensor(const std::vector& sizes); + explicit MetalTensor( + const std::vector& sizes, + const std::vector& strides); + ~MetalTensor() = default; + + MetalTensor(MetalTensor&&) = default; + MetalTensor& operator=(MetalTensor&&) = default; + + MetalTensor(const MetalTensor&) = default; + MetalTensor& operator=(const MetalTensor&) = default; + + friend std::ostream& operator<<(std::ostream& output, const MetalTensor& mt); + + static at::Tensor toTensor(MetalTensor&& mt, const TensorOptions& options); + static MetalTensor& fromTensor(const at::Tensor& tensor); + + bool defined() const; + IntArrayRef sizes() const; + IntArrayRef strides() const; + int64_t dim() const; + int64_t numel() const; + void set_data_from_host(const float* inputData); + void copy_data_to_host(float* host); + MPSImageWrapper* texture() const; + + private: + std::shared_ptr impl(); + std::shared_ptr impl() const; + std::shared_ptr _impl; +}; + +} // namespace metal +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/metal/MetalTensor.mm b/aten/src/ATen/native/metal/MetalTensor.mm new file mode 100644 index 0000000000000..3d5590500aea9 --- /dev/null +++ b/aten/src/ATen/native/metal/MetalTensor.mm @@ -0,0 +1,145 @@ +#import +#import +#import +#import + +namespace at { +namespace native { +namespace metal { + +class API_AVAILABLE(ios(10.0), macos(10.13)) MetalTensor::Impl { + public: + Impl(const std::vector& sizes, const std::vector& strides) + : _sizes(sizes), + _strides(strides), + _numel(std::accumulate( + std::begin(_sizes), + std::end(_sizes), + (int64_t)1, + std::multiplies())), + _textureImpl(std::make_unique(sizes)) {} + + IntArrayRef sizes() const { + return _sizes; + } + IntArrayRef strides() const { + return _strides; + } + int64_t dim() const { + return _sizes.size(); + } + int64_t numel() const { + return _numel; + } + void set_data_from_host(const float* inputData) { + _textureImpl->copyDataFromHost(inputData); + } + void copy_data_to_host(float* host) { + _textureImpl->copyDataToHost(host); + } + MPSImageWrapper* texture() const { + return _textureImpl.get(); + } + + private: + std::vector _sizes; + std::vector _strides; + int64_t _numel; + std::unique_ptr _textureImpl; +}; + +MetalTensor::MetalTensor(const std::vector& sizes) + : MetalTensor(sizes, compute_strides(sizes)) {} + +MetalTensor::MetalTensor( + const std::vector& sizes, + const std::vector& strides) + : _impl(std::make_shared(std::move(sizes), std::move(strides))) {} + +bool MetalTensor::defined() const { + return static_cast(_impl); +} + +at::Tensor MetalTensor::toTensor( + MetalTensor&& mt, + const TensorOptions& options) { + using MetalTensorImpl = at::MetalTensorImpl; + auto sizes = mt.sizes(); // sizes is stored in TensorImpl + auto strides = mt.strides(); // strides is stored in MetalTensorImpl + return detail::make_tensor( + DispatchKeySet(DispatchKey::Metal), + options.dtype(), + at::Device(at::kMetal), + std::move(mt), + std::vector(sizes.begin(), sizes.end()), + std::vector(strides.begin(), strides.end())); +} + +MetalTensor& MetalTensor::fromTensor(const at::Tensor& tensor) { + using MetalTensorImpl = at::MetalTensorImpl; + TORCH_INTERNAL_ASSERT( + tensor.is_metal(), "unbox expects Metal tensor as inputs"); + MetalTensorImpl* impl = + static_cast(tensor.unsafeGetTensorImpl()); + return impl->unsafe_opaque_handle(); +} + +std::shared_ptr MetalTensor::impl() { + return _impl; +} + +std::shared_ptr MetalTensor::impl() const { + return _impl; +} + +IntArrayRef MetalTensor::sizes() const { + return impl()->sizes(); +} + +IntArrayRef MetalTensor::strides() const { + return impl()->strides(); +} + +int64_t MetalTensor::dim() const { + return impl()->dim(); +} + +int64_t MetalTensor::numel() const { + return impl()->numel(); +} + +void MetalTensor::set_data_from_host(const float* inputData) { + impl()->set_data_from_host(inputData); +} + +void MetalTensor::copy_data_to_host(float* hostData) { + impl()->copy_data_to_host(hostData); +} + +API_AVAILABLE(ios(10.0)) +MPSImageWrapper* MetalTensor::texture() const { + return impl()->texture(); +} + +std::ostream& operator<<(std::ostream& output, const MetalTensor& mt) { + auto&& sizes = mt.sizes(); + auto&& strides = mt.strides(); + output << "[MetalTensor] | Size:{"; + std::ostringstream oss; + std::copy( + sizes.begin(), sizes.end() - 1, std::ostream_iterator(oss, ",")); + oss << sizes.back(); + output << oss.str() << "}, Stride:{"; + std::string sizesStr = oss.str(); + oss.str(""); + oss.clear(); + std::copy( + strides.begin(), strides.end() - 1, std::ostream_iterator(oss, ",")); + oss << sizes.back(); + output << oss.str() << "}"; + return output; +} + +} +} +} diff --git a/aten/src/ATen/native/metal/MetalTensorImpl.h b/aten/src/ATen/native/metal/MetalTensorImpl.h new file mode 100644 index 0000000000000..fd41a6089a810 --- /dev/null +++ b/aten/src/ATen/native/metal/MetalTensorImpl.h @@ -0,0 +1,55 @@ +#ifndef MetalTensorImpl_h +#define MetalTensorImpl_h + +#include +#include +#include +#import +#import + +namespace at { +template +struct TORCH_API MetalTensorImpl : public OpaqueTensorImpl { + MetalTensorImpl( + at::DispatchKeySet key_set, + const caffe2::TypeMeta& data_type, + c10::Device device, + OpaqueHandle opaque_handle, + c10::IntArrayRef sizes, + c10::IntArrayRef strides) + : OpaqueTensorImpl( + key_set, + data_type, + device, + opaque_handle, + sizes), + strides_(strides.vec()) {} + + IntArrayRef strides() const override { + return strides_; + } + + bool is_contiguous( + c10::MemoryFormat memory_format = + c10::MemoryFormat::Contiguous) const override { + return true; + } + + int64_t stride(int64_t d) const override { + d = at::maybe_wrap_dim(d, this->dim(), false); + return strides_[d]; + } + + void release_resources() override { + using MetalTensor = at::native::metal::MetalTensor; + auto&& handle = (MetalTensor)this->opaque_handle(); + handle.texture()->recycleImage(); + OpaqueTensorImpl::release_resources(); + } + + private: + SmallVector strides_; +}; +} // namespace at + +#endif /* MetalTensorImpl_h*/ diff --git a/aten/src/ATen/native/metal/MetalUtils.h b/aten/src/ATen/native/metal/MetalUtils.h new file mode 100644 index 0000000000000..79685bef149db --- /dev/null +++ b/aten/src/ATen/native/metal/MetalUtils.h @@ -0,0 +1,33 @@ +#include + +namespace at { +namespace native { +namespace metal { + +std::vector fp32_to_fp16(const std::vector& src); +std::vector fp16_to_fp32(const std::vector& src); +std::vector NCHW_to_NC4( + const float* src, + const std::vector& sizes); +std::vector NC4_to_NCHW( + const float* src, + const std::vector& sizes); + +// When copying the result back to a CPU tensor, the memory format becomes NCHW. +// Thus,we compute the strides based on contiguous memory format. +static inline std::vector compute_strides(const std::vector& sizes) { + const auto dim = sizes.size(); + std::vector strides(dim, 0); + if (dim > 0) { + const auto last_idx = dim - 1; + strides[last_idx] = 1; + for (int i = last_idx - 1; i >= 0; --i) { + strides[i] = strides[i + 1] * std::max(sizes[i + 1], 1); + } + } + return strides; +} + +} // namespace metal +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/metal/MetalUtils.mm b/aten/src/ATen/native/metal/MetalUtils.mm new file mode 100644 index 0000000000000..59ec44b290286 --- /dev/null +++ b/aten/src/ATen/native/metal/MetalUtils.mm @@ -0,0 +1,110 @@ +#import +#import + +#import + +#include +#include +#include + +namespace at { +namespace native { +namespace metal { + +std::vector fp32_to_fp16(const std::vector& src) { + unsigned long count = src.size(); + std::vector output(count, 0); + vImage_Buffer float32{(void*)src.data(), 1, count, count * sizeof(float)}; + vImage_Buffer float16{ + (void*)output.data(), 1, count, count * sizeof(uint16_t)}; + if (vImageConvert_PlanarFtoPlanar16F(&float32, &float16, 0) != + kvImageNoError) { + TORCH_CHECK(false, "fp32_to_fp16 failed"); + return {}; + } + + return output; +} + +std::vector fp16_to_fp32(const std::vector& src) { + unsigned long count = src.size(); + std::vector output(count, 0); + vImage_Buffer float16{(void*)src.data(), 1, count, count * sizeof(uint16_t)}; + vImage_Buffer float32{(void*)output.data(), 1, count, count * sizeof(float)}; + if (vImageConvert_Planar16FtoPlanarF(&float16, &float32, 0) != + kvImageNoError) { + TORCH_CHECK(false, "fp16_to_fp32 failed"); + return {}; + } + return output; +} + +std::vector NCHW_to_NC4( + const float* src, + const std::vector& sizes) { + int64_t N = sizes[0]; + int64_t C = sizes[1]; + int64_t H = sizes[2]; + int64_t W = sizes[3]; + int64_t src_image_count = C * H * W; + int64_t src_count = N * src_image_count; + int64_t slices = (C + 3) / 4; + int64_t numComponents = C < 3 ? C : 4; + int64_t dst_image_count = slices * numComponents * W * H; + int64_t dst_count = N * dst_image_count; + std::vector output(dst_count, 0.0f); + for (int n = 0; n < N; ++n) { + int64_t src_image = n * src_image_count; + int64_t dst_image = n * dst_image_count; + for (int i = 0; i < slices; ++i) { + int64_t slice = i * W * H * numComponents; + for (int j = 0; j < W * H; ++j) { + for (int k = 0; k < numComponents; ++k) { + int ii = src_image + slice + k * W * H + j; + int oi = dst_image + slice + j * numComponents + k; + if (k < C && ii < src_count) { + output[oi] = src[ii]; + } + } + } + } + } + + return output; +} + +std::vector NC4_to_NCHW( + const float* src, + const std::vector& sizes) { + int64_t N = sizes[0]; + int64_t C = sizes[1]; + int64_t H = sizes[2]; + int64_t W = sizes[3]; + int64_t slices = (C + 3) / 4; + int64_t numComponents = C < 3 ? C : 4; + int64_t src_image_count = slices * numComponents * W * H; + int64_t dst_image_count = C * H * W; + int64_t dst_count = N * dst_image_count; + std::vector output(dst_count, 0.0f); + for (int n = 0; n < N; ++n) { + int64_t src_image = n * src_image_count; + int64_t dst_image = n * dst_image_count; + for (int i = 0; i < slices; ++i) { + int64_t slice = i * W * H * numComponents; + for (int j = 0; j < numComponents; ++j) { + for (int k = 0; k < W * H; ++k) { + int ii = src_image + slice + k * numComponents + j; + int oi = dst_image + slice + j * W * H + k; + if (j < C && oi < dst_count) { + output[oi] = src[ii]; + } + } + } + } + } + return output; +} + +} +} +} diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNN.h b/aten/src/ATen/native/metal/mpscnn/MPSCNN.h new file mode 100644 index 0000000000000..48cc99e7f1983 --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNN.h @@ -0,0 +1,31 @@ +#import +#import + +namespace at { +namespace native { +namespace metal { +namespace mpscnn { + +struct LaunchParams { + MTLSize threadsPerThreadgroup; + MTLSize threadgroupsPerGrid; + MTLSize threadsPerGrid; // iOS 11.0 +}; + +API_AVAILABLE(ios(10.0), macos(10.13)) +LaunchParams spatialPointwiseKernelLaunchParams( + id pipeline, + MPSImage* im); + +API_AVAILABLE(ios(10.0), macos(10.13)) +NSString* kernelFor( + MPSImage* image, + NSString* arrayKernel, + NSString* nonArrayKernel); + +int computeMPSAlignOffset(int kernel, int pad); + +} +} // namespace metal +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNN.mm b/aten/src/ATen/native/metal/mpscnn/MPSCNN.mm new file mode 100644 index 0000000000000..eba52ac605262 --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNN.mm @@ -0,0 +1,59 @@ +#import + +namespace at { +namespace native { +namespace metal { +namespace mpscnn { + +auto divRoundUp(uint x, uint y) -> uint { + return (x + y - 1) / y; +} + +int computeMPSAlignOffset(int kernel, int pad) { + // To set the offset, we can just match the top-left pixel (in the input + // image, with negative values for padding) that we look at. For 3x3s1p1, we + // look at the (-1, -1) pixel in the original impl. For 3x3s1p0, we look at + // (0, 0) pixel. For 3x3s1p2, look at (-2, -2) MPSCNN always looks at + // (-floor(kernel_size - 1 / 2), -floor(kernel_size - 1 / 2)) Thus, we just + // need to match this up. + + // For 3x3s1p1, offset should be (0, 0) + // For 3x3s1p0, offset should be (1, 1) + // For 3x3s1p2, offset should be (-1, -1) + const int mps_offset = kernel / 2; + const int c2_offset = pad; + return mps_offset - c2_offset; +} + +NSString* kernelFor( + MPSImage* X, + NSString* arrayKernel, + NSString* nonArrayKernel) { + if (X.featureChannels > 4 || X.numberOfImages > 1) { + return arrayKernel; + } + return nonArrayKernel; +} + +LaunchParams spatialPointwiseKernelLaunchParams( + id pipeline, + MPSImage* im) { + const auto threadsPerThreadgroup = MTLSizeMake( + 8 /* threadExecutionWidth */, + 4 /* maxThreadsPerThreadgroup / threadExecutionWidth */, + 1); + const auto threadgroupsPerGrid = MTLSizeMake( + divRoundUp(im.width, threadsPerThreadgroup.width), + divRoundUp(im.height, threadsPerThreadgroup.height), + im.numberOfImages * divRoundUp(im.featureChannels, 4)); + const auto threadsPerGrid = MTLSizeMake( + im.width, + im.height, + im.numberOfImages * divRoundUp(im.featureChannels, 4)); + return {threadsPerThreadgroup, threadgroupsPerGrid, threadsPerGrid}; +}; + +} +} +} +} diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.h b/aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.h new file mode 100644 index 0000000000000..c3635ca43adfb --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.h @@ -0,0 +1,5 @@ +#import + +@interface MPSCNNClampOp : NSObject + +@end diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.mm b/aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.mm new file mode 100644 index 0000000000000..5892e0e3b1c1a --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.mm @@ -0,0 +1,53 @@ +#import +#import +#import +#import + +@implementation MPSCNNClampOp { + MPSImage* _X; + MPSImage* _Y; + NSNumber* _min; + NSNumber* _max; +} + ++ (id)newWithTextures:(NSArray*)textures + Args:(NSArray*)args { + MPSCNNClampOp* op = [MPSCNNClampOp new]; + op->_X = textures[0]; + op->_Y = textures[1]; + op->_min = args[0]; + op->_max = args[1]; + + return op; +} + +- (void)encode:(id)cb { + /* + `clamp(vector, float, float)` is not available on iOS 10.0, + have to use `clamp(vector, half4, half4)` instead. + */ + id encoder = [cb computeCommandEncoder]; + id state = [[MPSCNNContext sharedInstance] + pipelineState:at::native::metal::mpscnn::kernelFor( + _X, @"clamp_half4", @"clamp_half4_nonarray")]; + + [encoder setComputePipelineState:state]; + [encoder setTexture:[_X texture] atIndex:0]; + [encoder setTexture:[_Y texture] atIndex:1]; + id clampBuffer = [[MPSCNNContext sharedInstance].device + newBufferWithLength:2 * sizeof(fp16) + options:MTLResourceOptionCPUCacheModeWriteCombined]; + fp16* clampBufferPtr = (fp16*)[clampBuffer contents]; + clampBufferPtr[0] = _min.floatValue; + clampBufferPtr[1] = _max.floatValue; + [encoder setBuffer:clampBuffer offset:0 atIndex:0]; + const auto& launchParams = + at::native::metal::mpscnn::spatialPointwiseKernelLaunchParams(state, _Y); + [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid + threadsPerThreadgroup:launchParams.threadsPerThreadgroup]; + [encoder endEncoding]; + [_X markRead]; + [_Y markRead]; +} + +@end diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNContext.h b/aten/src/ATen/native/metal/mpscnn/MPSCNNContext.h new file mode 100644 index 0000000000000..738a2d7307e7b --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNContext.h @@ -0,0 +1,18 @@ +#import +#import +#import + +API_AVAILABLE(ios(10.0), macos(10.13)) +@interface MPSCNNContext : NSObject +@property(nonatomic, strong, readonly) id device; +@property(nonatomic, strong, readonly) id commandQueue; +@property(nonatomic, strong, readonly) id library; + ++ (instancetype)sharedInstance; +- (BOOL)available; +- (id)pipelineState:(NSString*)kernel; +- (id)specializedPipelineState:(NSString*)kernel + Constants:(NSArray*) + constants; + +@end diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNContext.mm b/aten/src/ATen/native/metal/mpscnn/MPSCNNContext.mm new file mode 100644 index 0000000000000..3834380c72770 --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNContext.mm @@ -0,0 +1,139 @@ +#import +#import + +#include +#include + +#if C10_IOS +#import +#elif TARGET_OS_MAC +#import +#endif + +@implementation MPSCNNContext { + std::mutex _pipelineCacheMutex; + NSMutableDictionary>* _pipelineCache; +} + ++ (instancetype)sharedInstance { + static dispatch_once_t onceToken; + static MPSCNNContext* instance = nil; + dispatch_once(&onceToken, ^{ + instance = [[MPSCNNContext alloc] init]; + instance->_device = MTLCreateSystemDefaultDevice(); + instance->_library = [instance.device + newLibraryWithSource:[NSString stringWithUTF8String:METAL_SHADERS] + options:nil + error:nil]; + instance->_commandQueue = [instance.device newCommandQueue]; + instance->_pipelineCache = + [NSMutableDictionary> new]; + }); + return instance; +} + +- (BOOL)available { +#if !defined(__APPLE__) + return false; +#elif TARGET_IPHONE_SIMULATOR + return false; +#elif TARGET_OS_IPHONE + if (!MPSSupportsMTLDevice(_device)) { + return false; + } + if ([UIDevice currentDevice].systemVersion.floatValue < 10.2) { + return false; + } + if (![MTLCreateSystemDefaultDevice() + supportsFeatureSet:MTLFeatureSet_iOS_GPUFamily3_v2]) { + return false; + } +#elif TARGET_OS_MAC + if (!MPSSupportsMTLDevice(_device)) { + return false; + } + NSOperatingSystemVersion supportedVer = {10, 13, 0}; + if (![[NSProcessInfo processInfo] + isOperatingSystemAtLeastVersion:supportedVer]) { + return false; + } + if (![MTLCreateSystemDefaultDevice() + supportsFeatureSet:MTLFeatureSet_macOS_GPUFamily1_v3]) { + return false; + } +#else + return false; +#endif + + return _device && _library && _commandQueue; +} + +- (id)pipelineState:(NSString*)kernel { + TORCH_CHECK(_library, "Failed to load kernels"); + std::lock_guard g(_pipelineCacheMutex); + id state = _pipelineCache[kernel]; + if (state) { + return state; + } + id func = [_library newFunctionWithName:kernel]; + TORCH_CHECK(func != nil, "Failed to load the kernel function", kernel); + NSError* errors; + state = [_device newComputePipelineStateWithFunction:func error:&errors]; + TORCH_CHECK(state != nil, errors.localizedDescription.UTF8String); + _pipelineCache[kernel] = state; + return state; +} + +- (id)specializedPipelineState:(NSString*)kernel + Constants:(NSArray*) + constants { + TORCH_CHECK(_library, "Failed to load kernels"); + std::string kernelStr = std::string([kernel UTF8String]); + for (auto i = 0; i < constants.count; ++i) { + kernelStr += "_" + std::string([constants[i] stringValue].UTF8String); + } + std::lock_guard g(_pipelineCacheMutex); + id state = _pipelineCache[kernel]; + if (state) { + return state; + } + MTLFunctionConstantValues* constantValues = [MTLFunctionConstantValues new]; + NSUInteger ushortArgIndex = 0; + NSUInteger floatArgIndex = 10; + for (auto i = 0; i < constants.count; ++i) { + NSNumber* constant = constants[i]; + const char* type = constant.objCType; + if (strcmp(type, @encode(NSUInteger)) == 0 || + strcmp(type, @encode(NSInteger)) == 0) { + TORCH_CHECK(ushortArgIndex <= 10); + ushort value = ushort([constant unsignedIntegerValue]); + [constantValues setConstantValue:&value + type:MTLDataTypeUShort + atIndex:ushortArgIndex]; + ushortArgIndex++; + } + if (strcmp(type, @encode(float)) == 0 || + strcmp(type, @encode(double)) == 0) { + TORCH_CHECK(floatArgIndex <= 2); + float value = [constant floatValue]; + [constantValues setConstantValue:&value + type:MTLDataTypeFloat + atIndex:floatArgIndex]; + floatArgIndex++; + } + } + NSError* errors; + id func = [_library newFunctionWithName:kernel + constantValues:constantValues + error:&errors]; + TORCH_CHECK( + func, "Couldn't get function: ", errors.localizedDescription.UTF8String); + state = [_device newComputePipelineStateWithFunction:func error:&errors]; + TORCH_CHECK(state != nil, errors.localizedDescription.UTF8String); + kernel = [NSString stringWithCString:kernelStr.c_str() + encoding:NSUTF8StringEncoding]; + _pipelineCache[kernel] = state; + return state; +} + +@end diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.h b/aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.h new file mode 100644 index 0000000000000..140c4aecd4563 --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.h @@ -0,0 +1,23 @@ +#import +#import +#import + +API_AVAILABLE(ios(10.0), macos(10.13)) +@interface MPSCNNConvDataSource : NSObject +@property(nonatomic, assign) void* weights; +@property(nonatomic, assign) float* bias; + +- (id)initWithWeights:(void*)weights + Bias:(float*)bias + Desc:(MPSCNNConvolutionDescriptor*)desc; + +@end + +using namespace at::native::metal; +API_AVAILABLE(ios(10.0), macos(10.13)) +@interface MPSCNNConvOp : NSObject ++ (MPSCNNConvOp*)conv2d:(const Conv2DParams&)params + weights:(float*)w + bias:(float*)b + neuronFilter:(NeuronType)t; +@end diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.mm b/aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.mm new file mode 100644 index 0000000000000..1ab2c2213f497 --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.mm @@ -0,0 +1,173 @@ +#import +#import +#import +#import + +#include + +@implementation MPSCNNConvDataSource { + void* _weights; + float* _bias; + MPSCNNConvolutionDescriptor* _descriptor; +} + +- (id)initWithWeights:(void*)weights + Bias:(float*)bias + Desc:(MPSCNNConvolutionDescriptor*)desc + API_AVAILABLE(ios(10.0), macos(10.13)) { + self = [super init]; + if (self) { + _weights = (float*)weights; + _bias = (float*)bias; + _descriptor = desc; + } + return self; +} + +- (nonnull id)copyWithZone:(nullable NSZone*)zone { + MPSCNNConvDataSource* dataSource = [MPSCNNConvDataSource allocWithZone:zone]; + dataSource->_weights = _weights; + dataSource->_bias = _bias; + dataSource->_descriptor = _descriptor; + return dataSource; +} + +- (float* _Nullable)biasTerms { + return _bias; +} + +- (MPSDataType)dataType API_AVAILABLE(ios(10.0), macos(10.13)) { + return MPSDataTypeFloat32; +} + +- (NSString* _Nullable)label { + return @""; +} + +- (BOOL)load { + return true; +} + +- (void)purge { + _bias = nullptr; + _weights = nullptr; +} + +- (void*)weights { + return _weights; +} + +- (MPSCNNConvolutionDescriptor* _Nonnull)descriptor { + return _descriptor; +} + +@end + +@implementation MPSCNNConvOp { +} + +@synthesize kernel = _kernel; + ++ (MPSCNNConvOp*)conv2d:(const Conv2DParams&)params + weights:(float*)w + bias:(float*)b + neuronFilter:(NeuronType)t API_AVAILABLE(ios(10.0), macos(10.13)) { + using namespace at::native::metal::mpscnn; + TORCH_CHECK( + params.DX == params.DY == 1, "Dilated convolution is not supported yet."); + const int64_t oC = params.OC; + const int64_t iC = params.C; + const int64_t kH = params.KH; + const int64_t kW = params.KW; + MPSCNNNeuron* neuron = [MPSCNNConvOp neuron:t]; + MPSCNNConvolutionDescriptor* desc = nil; + if (params.isDepthwise()) { + if (@available(iOS 11.0, *)) { + desc = [MPSCNNDepthWiseConvolutionDescriptor + cnnConvolutionDescriptorWithKernelWidth:kW + kernelHeight:kH + inputFeatureChannels:iC + outputFeatureChannels:oC + neuronFilter:neuron]; + desc.groups = 1; + } else { + TORCH_CHECK( + false, + "MPSCNNDepthWiseConvolutionDescriptor is only available on iOS 11.0 and above"); + } + } else { + if (params.G > 1) { + TORCH_CHECK( + params.IC % 4 == 0, + "MPSCNNConvolution requires number of input \ + channels in each group to be multiple of 4 for \ + group > 1."); + } + desc = [MPSCNNConvolutionDescriptor + cnnConvolutionDescriptorWithKernelWidth:kW + kernelHeight:kH + inputFeatureChannels:iC + outputFeatureChannels:oC + neuronFilter:neuron]; + desc.groups = params.G; + } + desc.strideInPixelsX = params.SX; + desc.strideInPixelsY = params.SY; + id dataSource = + [[MPSCNNConvDataSource alloc] initWithWeights:(float*)w + Bias:(float*)b + Desc:desc]; + MPSCNNConvolution* conv = nil; + if (@available(iOS 11.0, *)) { + conv = [[MPSCNNConvolution alloc] + initWithDevice:[MPSCNNContext sharedInstance].device + weights:dataSource]; + + } else { +#if TARGET_OS_IPHONE + // Fallback on earlier versions + conv = [[MPSCNNConvolution alloc] + initWithDevice:[MPSCNNContext sharedInstance].device + convolutionDescriptor:desc + kernelWeights:w + biasTerms:b + flags:MPSCNNConvolutionFlagsNone]; +#endif + } + [conv setEdgeMode:MPSImageEdgeModeZero]; + MPSOffset offset; + offset.x = computeMPSAlignOffset(kW, params.PX); + offset.y = computeMPSAlignOffset(kH, params.PY); + offset.z = 0; + [conv setOffset:offset]; + + TORCH_CHECK(conv.inputFeatureChannels == params.IC * params.G); + TORCH_CHECK(oC % conv.groups == 0); + TORCH_CHECK(conv.outputFeatureChannels == oC); + TORCH_CHECK(conv.kernelWidth == kW); + TORCH_CHECK(conv.kernelHeight == kH); + + MPSCNNConvOp* op = [MPSCNNConvOp new]; + op->_kernel = conv; + return op; +} + +- (void)encode:(id)cb + sourceImage:(MPSImage*)src + destinationImage:(MPSImage*)dst { + [_kernel encodeToCommandBuffer:cb sourceImage:src destinationImage:dst]; +} + ++ (MPSCNNNeuron*)neuron:(NeuronType)type { + if (type == NeuronType::Relu) { + return [MPSCNNNeuronOp relu]; + } else if (type == NeuronType::Sigmoid) { + return [MPSCNNNeuronOp sigmoid]; + } else if (type == NeuronType::Tanh) { + return [MPSCNNNeuronOp tanh]; + } else { + return nil; + } +} + +@end diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.h b/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.h new file mode 100644 index 0000000000000..711ccd8088fe6 --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.h @@ -0,0 +1,12 @@ +#import +#import +#import + +using namespace at::native::metal; +@interface MPSCNNNeuronOp : NSObject + ++ (MPSCNNNeuronReLU*)relu; ++ (MPSCNNNeuronSigmoid*)sigmoid; ++ (MPSCNNNeuronTanH*)tanh; + +@end diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm b/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm new file mode 100644 index 0000000000000..94c50e2c865b5 --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm @@ -0,0 +1,39 @@ +#import +#import + +@implementation MPSCNNNeuronOp + ++ (MPSCNNNeuronReLU*)relu { + static MPSCNNNeuronReLU* relu = nil; + static dispatch_once_t onceToken; + dispatch_once(&onceToken, ^{ + relu = [[MPSCNNNeuronReLU alloc] + initWithDevice:[MPSCNNContext sharedInstance].device + a:0]; + }); + return relu; +} + ++ (MPSCNNNeuronSigmoid*)sigmoid { + static dispatch_once_t onceToken; + static MPSCNNNeuronSigmoid* sigmoid = nil; + dispatch_once(&onceToken, ^{ + sigmoid = [[MPSCNNNeuronSigmoid alloc] + initWithDevice:[MPSCNNContext sharedInstance].device]; + }); + return sigmoid; +} + ++ (MPSCNNNeuronTanH*)tanh { + static dispatch_once_t onceToken; + static MPSCNNNeuronTanH* tanh = nil; + dispatch_once(&onceToken, ^{ + tanh = [[MPSCNNNeuronTanH alloc] + initWithDevice:[MPSCNNContext sharedInstance].device + a:1 + b:1]; + }); + return tanh; +} + +@end diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNOp.h b/aten/src/ATen/native/metal/mpscnn/MPSCNNOp.h new file mode 100644 index 0000000000000..cf2f243d86db4 --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNOp.h @@ -0,0 +1,27 @@ +#import +#import +#import + +#if (defined(__ARM_NEON__) || defined(__ARM_NEON)) +typedef float16_t fp16; +#else +typedef uint16_t fp16; +#endif + +@protocol MPSCNNOp + +@property(nonatomic, strong) MPSCNNKernel* kernel; + +- (void)encode:(id)cb + sourceImage:(MPSImage*)src + destinationImage:(MPSImage*)dst; + +@end + +@protocol MPSCNNShaderOp + ++ (id)newWithTextures:(NSArray*)textures + Args:(NSArray*)args; +- (void)encode:(id)cb; + +@end diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNOps.h b/aten/src/ATen/native/metal/mpscnn/MPSCNNOps.h new file mode 100644 index 0000000000000..4a55b5e540d2b --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNOps.h @@ -0,0 +1,69 @@ +#import +#import + +#include + +namespace at { +namespace native { +namespace metal { +namespace mpscnn { + +Tensor conv2d( + const Tensor& input, // metal + const Tensor& weight, // cpu + const c10::optional& bias, // cpu + const Conv2DParams& params, + NeuronType t = NeuronType::None); + +// conv2d with prepacked weights +Tensor conv2d(const Tensor& input, Conv2dOpContext& context); + +Tensor max_pool2d( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode); + +Tensor global_avg_pool2d(const Tensor& input, IntArrayRef output_size); + +Tensor relu(const Tensor& input); + +Tensor& relu_(Tensor& input); + +Tensor sigmoid(const Tensor& input); + +Tensor& hardtanh_(Tensor& input, Scalar min_val, Scalar max_val); + +Tensor t(const Tensor& input); + +Tensor view(const Tensor& input, IntArrayRef size); + +Tensor reshape(const Tensor& input, IntArrayRef shape); + +Tensor addmm(const Tensor& bias, const Tensor& input, const Tensor& weight); + +Tensor add(const Tensor& input1, const Tensor& input2); + +Tensor& add_(Tensor& input1, const Tensor& input2); + +Tensor sub(const Tensor& input1, const Tensor& input2); + +Tensor mul(const Tensor& input1, const Tensor& input2); + +Tensor log_softmax_int(const Tensor& input); + +Tensor upsample_nearest2d_vec( + const Tensor& input, + c10::optional output_size, + c10::optional> scale_factors); + +Tensor flatten_using_ints(const Tensor & input, int64_t start_dim, int64_t end_dim); + +Tensor copy_to_host(const Tensor& input); + +} // namespace mpscnn +} // namespace metal +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNOps.mm b/aten/src/ATen/native/metal/mpscnn/MPSCNNOps.mm new file mode 100644 index 0000000000000..49b04ba6d6233 --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNOps.mm @@ -0,0 +1,642 @@ +#import +#import +#import +#import +#import +#import +#import +#import +#import +#import +#import +#import + +#include +#include +#include + +namespace at { +namespace native { +namespace metal { +namespace mpscnn { + +using MetalTensor = at::native::metal::MetalTensor; +using MetalTensorImpl = at::MetalTensorImpl; + +API_AVAILABLE(ios(10.0), macos(10.13)) +static inline MPSImage* imageFromMetalTensor(const MetalTensor& tensor) { + return tensor.texture()->image(); +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +static inline MPSImage* imageFromTensor(const Tensor& tensor) { + TORCH_CHECK(tensor.is_metal()); + MetalTensorImpl* impl = (MetalTensorImpl*)tensor.unsafeGetTensorImpl(); + MetalTensor& metalTensor = impl->unsafe_opaque_handle(); + return imageFromMetalTensor(metalTensor); +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +static inline MetalCommandBuffer* commandBufferFromInputTensor( + const Tensor& tensor) { + TORCH_CHECK(tensor.is_metal()); + MetalTensorImpl* impl = (MetalTensorImpl*)tensor.unsafeGetTensorImpl(); + MetalTensor& metalTensor = impl->unsafe_opaque_handle(); + MetalCommandBuffer* cmdBuffer = metalTensor.texture()->commandBuffer(); + TORCH_CHECK(cmdBuffer, @"Command Buffer can't be nil!"); + return cmdBuffer; +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor conv2d( + const Tensor& input, + const Tensor& weight, + const c10::optional& bias, + const Conv2DParams& params, + NeuronType t) { + TORCH_CHECK(weight.device().type() == kCPU); + MPSImage* X = imageFromTensor(input); + const int64_t oC = weight.sizes()[0]; + const int64_t iC = weight.sizes()[1]; + const int64_t kH = weight.sizes()[2]; + const int64_t kW = weight.sizes()[3]; + auto packedWeights = at::native::metal::permuteWeights( + weight.data_ptr(), {oC, iC, kH, kW}); + // MPSCNN Convolution + float* w = packedWeights.data(); + float* b = bias.has_value() ? bias->data_ptr() : nullptr; + MPSCNNConvOp* op = [MPSCNNConvOp conv2d:params + weights:w + bias:b + neuronFilter:t]; + auto outputSize = params.output_sizes(); + MetalTensor mt{outputSize}; + MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input); + mt.texture()->allocateTemporaryTextureStorage(outputSize, commandBuffer); + MPSImage* Y = imageFromMetalTensor(mt); + [op encode:commandBuffer.buffer sourceImage:X destinationImage:Y]; + auto output = MetalTensor::toTensor(std::move(mt), input.options()); + return output; +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor conv2d(const Tensor& input, Conv2dOpContext& context) { + MPSImage* X = imageFromTensor(input); + Conv2DParams params{input.sizes(), + context.weight.sizes(), + context.padding, + context.stride, + context.dilation, + context.groups}; + MPSCNNConvOp* op = (__bridge MPSCNNConvOp*)(context.conv2dOp); + NeuronType nt = neuronType(context); + if (!op) { + float* w = context.weight.data_ptr(); + float* b = context.bias.has_value() ? ((*context.bias).data_ptr()) + : nullptr; + op = [MPSCNNConvOp conv2d:params weights:w bias:b neuronFilter:nt]; + context.conv2dOp = (void*)CFBridgingRetain(op); + context.releaseCallback = ^(void* res) { + if (res) { + CFBridgingRelease(res); + } + }; + } + + auto outputSize = params.output_sizes(); + MetalTensor mt{outputSize}; + MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input); + mt.texture()->allocateTemporaryTextureStorage(outputSize, commandBuffer); + MPSImage* Y1 = imageFromMetalTensor(mt); + [op encode:commandBuffer.buffer sourceImage:X destinationImage:Y1]; + // fuse hardtanh with convolution + if (nt == NeuronType::Clamp) { + MPSImage* Y2 = [MPSImage temporaryImageFromSize:[Y1 sizes] + commandBuffer:commandBuffer]; + float min = context.output_min.value().toFloat(); + float max = context.output_max.value().toFloat(); + MPSCNNClampOp* clampOp = + [MPSCNNClampOp newWithTextures:@[ Y1, Y2 ] Args:@[ @(min), @(max) ]]; + [clampOp encode:commandBuffer.buffer]; + mt.texture()->copyFromTexture(Y2); + } + auto output = MetalTensor::toTensor(std::move(mt), input.options()); + return output; +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor max_pool2d( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode) { + const int64_t iN = input.sizes()[0]; + const int64_t iC = input.sizes()[1]; + const int64_t iH = input.sizes()[2]; + const int64_t iW = input.sizes()[3]; + const int64_t kH = kernel_size[0]; + const int64_t kW = kernel_size[1]; + const int64_t sH = stride[0]; + const int64_t sW = stride[1]; + const int64_t pH = padding[0]; + const int64_t pW = padding[1]; + const int64_t dH = dilation[0]; + const int64_t dW = dilation[1]; + MPSImage* X = imageFromTensor(input); + MPSCNNPoolingMax* pool = [[MPSCNNPoolingMax alloc] + initWithDevice:[MPSCNNContext sharedInstance].device + kernelWidth:kernel_size[0] + kernelHeight:kernel_size[1] + strideInPixelsX:stride[0] + strideInPixelsY:stride[1]]; + [pool setEdgeMode:MPSImageEdgeModeClamp]; + [pool setOffset:{.x = static_cast(kernel_size[0] / 2), + .y = static_cast(kernel_size[1] / 2), + .z = 0}]; + + int64_t oN = iN; + int64_t oC = iC; + int64_t oH = pooling_output_shape(iH, kH, pH, sH, dH, ceil_mode); + int64_t oW = pooling_output_shape(iW, kW, pW, sW, dW, ceil_mode); + + std::vector outputSize{oN, oC, oH, oW}; + MetalTensor mt{outputSize}; + MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input); + mt.texture()->allocateTemporaryTextureStorage(outputSize, commandBuffer); + MPSImage* Y = imageFromMetalTensor(mt); + [pool encodeToCommandBuffer:commandBuffer.buffer + sourceImage:X + destinationImage:Y]; + auto output = MetalTensor::toTensor(std::move(mt), input.options()); + return output; +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor global_avg_pool2d(const Tensor& input, IntArrayRef output_size) { + MPSImage* X = imageFromTensor(input); + MPSCNNPoolingAverage* pool = [[MPSCNNPoolingAverage alloc] + initWithDevice:[MPSCNNContext sharedInstance].device + kernelWidth:X.width + kernelHeight:X.height + strideInPixelsX:X.width + strideInPixelsY:X.height]; + [pool setEdgeMode:MPSImageEdgeModeClamp]; + [pool setOffset:{.x = static_cast(X.width / 2), + .y = static_cast(X.height / 2), + .z = 0}]; + std::vector outputSize{ + input.sizes()[0], input.sizes()[1], output_size[0], output_size[1]}; + MetalTensor mt{outputSize}; + MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input); + mt.texture()->allocateTemporaryTextureStorage(outputSize, commandBuffer); + MPSImage* Y = imageFromMetalTensor(mt); + [pool encodeToCommandBuffer:commandBuffer.buffer + sourceImage:X + destinationImage:Y]; + auto output = MetalTensor::toTensor(std::move(mt), input.options()); + return output; +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor neuronKernel(const Tensor& input, MPSCNNNeuron* neuron) { + MPSImage* X = imageFromTensor(input); + std::vector outputSize = input.sizes().vec(); + std::vector textureSize = outputSize; + if (input.dim() == 2) { + textureSize = {outputSize[0], outputSize[1], 1, 1}; + } + MetalTensor mt{outputSize}; + MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input); + mt.texture()->allocateTemporaryTextureStorage(textureSize, commandBuffer); + MPSImage* Y = imageFromMetalTensor(mt); + [neuron encodeToCommandBuffer:commandBuffer.buffer + sourceImage:X + destinationImage:Y]; + auto output = MetalTensor::toTensor(std::move(mt), input.options()); + return output; +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor& neuronKernel_(Tensor& input, MPSCNNNeuron* neuron) { + MPSImage* X = imageFromTensor(input); + std::vector outputSize = input.sizes().vec(); + std::vector textureSize = outputSize; + if (input.dim() == 2) { + textureSize = {outputSize[0], outputSize[1], 1, 1}; + } + MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input); + MPSImage* Y = [MPSImage temporaryImageFromSize:input.sizes().vec() + commandBuffer:commandBuffer]; + [neuron encodeToCommandBuffer:commandBuffer.buffer + sourceImage:X + destinationImage:Y]; + MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl(); + MetalTensor& metalTensor = impl->unsafe_opaque_handle(); + metalTensor.texture()->copyFromTexture(Y); + return input; +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor relu(const Tensor& input) { + return neuronKernel(input, [MPSCNNNeuronOp relu]); +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor& relu_(Tensor& input) { + return neuronKernel_(input, [MPSCNNNeuronOp relu]); +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor sigmoid(const Tensor& input) { + return neuronKernel(input, [MPSCNNNeuronOp sigmoid]); +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor tanh(const Tensor& input) { + return neuronKernel(input, [MPSCNNNeuronOp tanh]); +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor& hardtanh_(Tensor& input, Scalar min_val, Scalar max_val) { + MPSImage* X = imageFromTensor(input); + MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input); + MPSImage* Y = [MPSImage temporaryImageFromSize:input.sizes().vec() + commandBuffer:commandBuffer]; + float min = min_val.toFloat(); + float max = max_val.toFloat(); + MPSCNNClampOp* clampOp = [MPSCNNClampOp newWithTextures:@[ X, Y ] + Args:@[ @(min), @(max) ]]; + [clampOp encode:commandBuffer.buffer]; + MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl(); + MetalTensor& metalTensor = impl->unsafe_opaque_handle(); + metalTensor.texture()->copyFromTexture(Y); + return input; +} + +/* + A fully connected layer takes an MPSImage object with dimensions source.width x + source.height x Ni, convolves it with + Weights[No][source.width][source.height][Ni],and produces a 1 x 1 x No output. + + Thus, the following conditions must be true: + kernelWidth == source.width + kernelHeight == source.height + clipRect.size.width == 1 + clipRect.size.height == 1 + + You can think of a fully connected layer as a matrix multiplication + where the image is flattened into a vector of length + source.width*source.height*Ni, and the weights are arranged in a matrix of + dimension No x (source.width*source.height*Ni) to produce an output vector of + length No + + The value of the strideInPixelsX, strideInPixelsY, and groups properties must + be 1. The offset property is not applicable and it is ignored. Because the clip + rectangle is clamped to the destination image bounds, if the destination is 1 x + 1, you do not need to set the clipRect property. + */ +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor addmm(const Tensor& bias, const Tensor& input, const Tensor& weight) { + MPSImage* X = imageFromTensor(input); + const int64_t N = X.numberOfImages; + const int64_t oC = weight.sizes()[0]; + const int64_t kH = X.height; + const int64_t kW = X.width; + const int64_t iC = weight.sizes()[1] / kH / kW; + auto packedWeights = at::native::metal::permuteWeights( + weight.data_ptr(), {oC, iC, kH, kW}); + MPSCNNConvolutionDescriptor* desc = + [MPSCNNConvolutionDescriptor cnnConvolutionDescriptorWithKernelWidth:kW + kernelHeight:kH + inputFeatureChannels:iC + outputFeatureChannels:oC + neuronFilter:nil]; + desc.strideInPixelsX = 1; + desc.strideInPixelsY = 1; + MPSCNNConvDataSource* ds = [[MPSCNNConvDataSource alloc] + initWithWeights:packedWeights.data() + Bias:bias.defined() ? bias.data_ptr() : nil + Desc:desc]; + MPSCNNFullyConnected* fc = nil; + if (@available(iOS 11.0, *)) { + fc = [[MPSCNNFullyConnected alloc] + initWithDevice:[MPSCNNContext sharedInstance].device + weights:ds]; + } else { +#if TARGET_OS_IPHONE + fc = [[MPSCNNFullyConnected alloc] + initWithDevice:[MPSCNNContext sharedInstance].device + convolutionDescriptor:desc + kernelWeights:(float*)packedWeights.data() + biasTerms:bias.defined() ? bias.data_ptr() : nil + flags:MPSCNNConvolutionFlagsNone]; +#endif + } + [fc setClipRect:MTLRegionMake3D(0, 0, 0, 1, 1, N)]; + [fc setOffset:{.x = static_cast(X.width / 2), + .y = static_cast(X.height / 2), + .z = 0}]; + std::vector outputSize = {N, oC, 1, 1}; + MetalTensor mt{{N, oC}}; + + MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input); + mt.texture()->allocateTemporaryTextureStorage(outputSize, commandBuffer); + MPSImage* Y = imageFromMetalTensor(mt); + [fc encodeToCommandBuffer:commandBuffer.buffer + sourceImage:X + destinationImage:Y]; + auto output = MetalTensor::toTensor(std::move(mt), input.options()); + return output; +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor binaryElementwiseKernel( + const Tensor& input1, + const Tensor& input2, + NSString* arrayKernel, + NSString* nonarrayKernel) { + MPSImage* X1 = imageFromTensor(input1); + MPSImage* X2 = imageFromTensor(input2); + std::vector outputSize = input1.sizes().vec(); + MetalTensor mt{outputSize}; + MetalCommandBuffer* cb1 = commandBufferFromInputTensor(input1); + MetalCommandBuffer* cb2 = commandBufferFromInputTensor(input2); + TORCH_CHECK([cb1 isEqual:cb2], @"inputs have different command buffer"); + mt.texture()->allocateTemporaryTextureStorage(outputSize, cb1); + MPSImage* Y = imageFromMetalTensor(mt); + id state = [[MPSCNNContext sharedInstance] + pipelineState:kernelFor(X1, arrayKernel, nonarrayKernel)]; + id encoder = [cb1.buffer computeCommandEncoder]; + [encoder setComputePipelineState:state]; + [encoder setTexture:[X1 texture] atIndex:0]; + [encoder setTexture:[X2 texture] atIndex:1]; + [encoder setTexture:[Y texture] atIndex:2]; + const auto& launchParams = spatialPointwiseKernelLaunchParams(state, Y); + [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid + threadsPerThreadgroup:launchParams.threadsPerThreadgroup]; + [encoder endEncoding]; + [X1 markRead]; + [X2 markRead]; + auto output = MetalTensor::toTensor(std::move(mt), input1.options()); + return output; +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor& binaryElementwiseKernel_( + Tensor& input1, + const Tensor& input2, + NSString* arrayKernel, + NSString* nonarrayKernel) { + MPSImage* X1 = imageFromTensor(input1); + MPSImage* X2 = imageFromTensor(input2); + std::vector outputSize = input1.sizes().vec(); + MetalCommandBuffer* cb1 = commandBufferFromInputTensor(input1); + MetalCommandBuffer* cb2 = commandBufferFromInputTensor(input2); + TORCH_CHECK([cb1 isEqual:cb2], @"inputs have different command buffer"); + MPSImage* Y = [MPSImage temporaryImageFromSize:outputSize commandBuffer:cb1]; + id state = [[MPSCNNContext sharedInstance] + pipelineState:kernelFor(X1, arrayKernel, nonarrayKernel)]; + id encoder = [cb1.buffer computeCommandEncoder]; + [encoder setComputePipelineState:state]; + [encoder setTexture:[X1 texture] atIndex:0]; + [encoder setTexture:[X2 texture] atIndex:1]; + [encoder setTexture:[Y texture] atIndex:2]; + const auto& launchParams = spatialPointwiseKernelLaunchParams(state, Y); + [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid + threadsPerThreadgroup:launchParams.threadsPerThreadgroup]; + [encoder endEncoding]; + [X1 markRead]; + [X2 markRead]; + MetalTensorImpl* impl = (MetalTensorImpl*)input1.unsafeGetTensorImpl(); + MetalTensor& metalTensor = impl->unsafe_opaque_handle(); + metalTensor.texture()->copyFromTexture(Y); + return input1; +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor add(const Tensor& input1, const Tensor& input2) { + return binaryElementwiseKernel( + input1, input2, @"elementwise_add", @"elementwise_add_nonarray"); +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor& add_(Tensor& input1, const Tensor& input2) { + return binaryElementwiseKernel_( + input1, input2, @"elementwise_add", @"elementwise_add_nonarray"); +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor sub(const Tensor& input1, const Tensor& input2) { + return binaryElementwiseKernel( + input1, input2, @"elementwise_sub", @"elementwise_sub_nonarray"); +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor mul(const Tensor& input1, const Tensor& input2) { + return binaryElementwiseKernel( + input1, input2, @"elementwise_mul", @"elementwise_mul_nonarray"); +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor t(const Tensor& input) { + auto strides = input.strides().vec(); + auto sizes = input.sizes().vec(); + MPSImage* X = imageFromTensor(input); + TORCH_CHECK(X.numberOfImages == 1); + TORCH_CHECK(X.featureChannels == 1); + MetalTensor mt({sizes[1], sizes[0]}); + MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input); + mt.texture()->allocateTemporaryTextureStorage( + {1, 1, sizes[1], sizes[0]}, commandBuffer); + MPSImage* Y = imageFromMetalTensor(mt); + MPSImageTranspose* transpose = [[MPSImageTranspose alloc] + initWithDevice:[MPSCNNContext sharedInstance].device]; + [transpose encodeToCommandBuffer:commandBuffer.buffer + sourceImage:X + destinationImage:Y]; + auto output = MetalTensor::toTensor(std::move(mt), input.options()); + return output; +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor view(const Tensor& input, IntArrayRef size) { + auto inferred_size = at::infer_size(size, input.numel()); + auto stride = + at::detail::computeStride(input.sizes(), input.strides(), inferred_size); + TORCH_CHECK( + stride.has_value(), + "view size is " + "not compatible with input tensor's size and stride (at least one dimension" + " spans across two contiguous subspaces). Use .reshape(...) instead."); + auto stride_value = *stride; + + MPSImage* X = imageFromTensor(input); + MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input); + MetalTensor mt{inferred_size, stride_value}; + mt.texture()->setCommandBuffer(commandBuffer); + mt.texture()->copyFromTexture(X); + auto output = MetalTensor::toTensor(std::move(mt), input.options()); + return output; +} + +Tensor reshape(const Tensor& input, IntArrayRef shape) { + return view(input, shape); +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor log_softmax_int(const Tensor& input) { + MPSImage* X = imageFromTensor(input); + TORCH_CHECK(X.height == 1 && X.width == 1); + std::vector outputSize = input.sizes().vec(); + MPSCNNLogSoftMax* logSoftmax = [[MPSCNNLogSoftMax alloc] + initWithDevice:[MPSCNNContext sharedInstance].device]; + + MetalTensor mt{outputSize}; + MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input); + mt.texture()->allocateTemporaryTextureStorage( + {outputSize[0], outputSize[1], 1, 1}, commandBuffer); + MPSImage* Y = imageFromMetalTensor(mt); + [logSoftmax encodeToCommandBuffer:commandBuffer.buffer + sourceImage:X + destinationImage:Y]; + auto output = MetalTensor::toTensor(std::move(mt), input.options()); + return output; +} + +API_AVAILABLE(ios(10.0), macos(10.13)) +Tensor upsample_nearest2d_vec( + const Tensor& input, + c10::optional output_size, + c10::optional> scale_factors) { + auto osize = + upsample::compute_output_size(input.sizes(), output_size, scale_factors); + auto scale_h = upsample::get_scale_value(scale_factors, 0); + auto scale_w = upsample::get_scale_value(scale_factors, 1); + int64_t output_height = osize[0]; + int64_t output_width = osize[1]; + int64_t nbatch = input.size(0); + int64_t channels = input.size(1); + int64_t input_height = input.size(2); + int64_t input_width = input.size(3); + upsample_2d_shape_check( + input, + Tensor(), + nbatch, + channels, + input_height, + input_width, + output_height, + output_width); + std::vector outputSizes{ + nbatch, channels, output_height, output_width}; + MPSImage* X = imageFromTensor(input); + MetalTensor mt{outputSizes}; + MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input); + mt.texture()->allocateTemporaryTextureStorage(outputSizes, commandBuffer); + MPSImage* Y = imageFromMetalTensor(mt); + if (@available(iOS 11.0, *)) { + MPSCNNUpsamplingNearest* kernel = [[MPSCNNUpsamplingNearest alloc] + initWithDevice:[MPSCNNContext sharedInstance].device + integerScaleFactorX:(NSUInteger)scale_w.value() + integerScaleFactorY:(NSUInteger)scale_h.value()]; + [kernel encodeToCommandBuffer:commandBuffer.buffer + sourceImage:X + destinationImage:Y]; + } else { + NSUInteger sh = scale_h.value() * 10000; + NSUInteger sw = scale_w.value() * 10000; + id state = [[MPSCNNContext sharedInstance] + specializedPipelineState:kernelFor( + Y, + @"resize_nearest", + @"resize_nearest_nonarray") + Constants:@[ + @(output_height), + @(output_width), + @(sh), + @(sw) + ]]; + id encoder = + [commandBuffer.buffer computeCommandEncoder]; + [encoder setComputePipelineState:state]; + [encoder setTexture:[X texture] atIndex:0]; + [encoder setTexture:[Y texture] atIndex:1]; + const auto& launchParams = spatialPointwiseKernelLaunchParams(state, Y); + [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid + threadsPerThreadgroup:launchParams.threadsPerThreadgroup]; + [encoder endEncoding]; + [X markRead]; + [Y markRead]; + } + auto output = MetalTensor::toTensor(std::move(mt), input.options()); + return output; +} + +Tensor flatten_using_ints( + const Tensor& input, + int64_t start_dim, + int64_t end_dim) { + start_dim = maybe_wrap_dim(start_dim, input.dim()); + end_dim = maybe_wrap_dim(end_dim, input.dim()); + TORCH_CHECK( + start_dim <= end_dim, + "flatten() has invalid args: start_dim cannot come after end_dim"); + std::vector shape; + if (input.dim() == 0) { + return input.reshape({1}); + } + if (start_dim == end_dim) { + return input; + } + auto slice_numel = + prod_intlist(input.sizes().slice(start_dim, end_dim - start_dim + 1)); + shape.reserve(input.dim() - end_dim + start_dim); + for (int64_t i = 0; i < start_dim; i++) { + shape.push_back(input.size(i)); + } + shape.push_back(slice_numel); + for (int64_t i = end_dim + 1; i < input.dim(); i++) { + shape.push_back(input.size(i)); + } + return input.reshape(shape); +} + +Tensor copy_to_host(const Tensor& input) { + MPSImage* X = imageFromTensor(input); + MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input); + auto&& sizes = [X sizes]; + MetalTensor mt{sizes}; + mt.texture()->setCommandBuffer(commandBuffer); + mt.texture()->allocateTextureStorage(sizes); + MPSImage* Y = imageFromMetalTensor(mt); + id encoder = + [commandBuffer.buffer computeCommandEncoder]; + id state = [[MPSCNNContext sharedInstance] + specializedPipelineState:metal::mpscnn::kernelFor( + X, @"copy", @"copy_nonarray") + Constants:@[ + @(X.featureChannels), + @(X.height), + @(X.width) + ]]; + + [encoder setComputePipelineState:state]; + [encoder setTexture:[X texture] atIndex:0]; + [encoder setTexture:[Y texture] atIndex:1]; + + const auto& launchParams = + metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X); + [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid + threadsPerThreadgroup:launchParams.threadsPerThreadgroup]; + [encoder endEncoding]; + [X markRead]; + auto output = MetalTensor::toTensor(std::move(mt), input.options()); + return output; +} + +} +} +} +} diff --git a/aten/src/ATen/native/metal/mpscnn/MPSImage+Tensor.h b/aten/src/ATen/native/metal/mpscnn/MPSImage+Tensor.h new file mode 100644 index 0000000000000..27bb318d32835 --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/MPSImage+Tensor.h @@ -0,0 +1,47 @@ +#include +#import +#import +#import + +@interface MPSImage (Tensor) + ++ (MPSImage*)imageFromCPUTensor:(const at::Tensor&)tensor; +- (at::Tensor)toCPUTensor; + ++ (MPSImage*)imageFromFp16Array:(const uint16_t*)src + Sizes:(const std::vector&)sizes; +- (std::vector)toFp16Array; + ++ (MPSImage*)imageFromSize:(const std::vector&)size; ++ (MPSTemporaryImage*)temporaryImageFromSize:(const std::vector&)size + commandBuffer:(MetalCommandBuffer*)cmdBuffer; + +- (std::vector)sizes; +- (int64_t)readCount; +- (BOOL)isTemporaryImage; +- (void)markRead; +- (void)recycle; + +@end + +@interface MPSImage (Shaders) + ++ (MPSImage*)imageFromImage:(MPSImage*)image; + ++ (MPSTemporaryImage*)temporaryImageFromImage:(MPSImage*)image + CommandBuffer:(MetalCommandBuffer*)cb; + ++ (MPSImage*)imageFromTemporaryImage:(MPSTemporaryImage*)image + CommandBuffer:(MetalCommandBuffer*)cb + waitUntilCompleted:(BOOL)b; + ++ (MPSImage*)imageFromHost:(const float*)src + Sizes:(const std::vector&)sizes; + ++ (MPSTemporaryImage*)temporaryImageFromHost:(const float*)src + Sizes:(const std::vector&)sizes + CommandBuffer:(MetalCommandBuffer*)cb; + ++ (void)copyToHost:(float*)dst FromImage:(MPSImage*)image; + +@end diff --git a/aten/src/ATen/native/metal/mpscnn/MPSImage+Tensor.mm b/aten/src/ATen/native/metal/mpscnn/MPSImage+Tensor.mm new file mode 100644 index 0000000000000..3098bb70d0139 --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/MPSImage+Tensor.mm @@ -0,0 +1,337 @@ +#include +#include +#include +#include + +#include + +using namespace at::native; +@implementation MPSImage (Tensor) + ++ (MPSImage*)imageFromCPUTensor:(const at::Tensor&)tensor { + TORCH_CHECK(tensor.device().is_cpu()); + TORCH_CHECK(tensor.dim() == 4); + auto contiguousTensor = tensor.contiguous(); + float* src = tensor.data_ptr(); + std::vector sizes = tensor.sizes().vec(); + auto c4 = metal::NCHW_to_NC4(src, sizes); + auto c4fp16 = metal::fp32_to_fp16(c4); + return [self imageFromFp16Array:c4fp16.data() Sizes:sizes]; +} + ++ (MPSImage*)imageFromFp16Array:(const uint16_t*)src + Sizes:(const std::vector&)sizes { + int64_t N = sizes[0]; + int64_t C = sizes[1]; + int64_t H = sizes[2]; + int64_t W = sizes[3]; + MPSImageDescriptor* desc = [MPSImageDescriptor + imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 + width:W + height:H + featureChannels:C + numberOfImages:N + usage:MTLTextureUsageShaderRead | + MTLTextureUsageShaderWrite]; + MPSImage* image = + [[MPSImage alloc] initWithDevice:[MPSCNNContext sharedInstance].device + imageDescriptor:desc]; + + int64_t slices = (C + 3) / 4 * N; + int64_t numComponents = image.featureChannels < 3 ? image.featureChannels : 4; + int64_t bytesPerRow = W * numComponents * sizeof(uint16_t); + uint8_t* ptr = (uint8_t*)src; + for (int i = 0; i < slices; ++i) { + [image.texture replaceRegion:MTLRegionMake2D(0, 0, W, H) + mipmapLevel:0 + slice:i + withBytes:ptr + bytesPerRow:bytesPerRow + bytesPerImage:0]; + ptr += H * bytesPerRow; + } + return image; +} + ++ (MPSImage*)imageFromSize:(const std::vector&)size { + MPSImageDescriptor* desc = [MPSImageDescriptor + imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 + width:size[3] + height:size[2] + featureChannels:size[1] + numberOfImages:size[0] + usage:MTLTextureUsageShaderRead | + MTLTextureUsageShaderWrite]; + return [[MPSImage alloc] initWithDevice:[MPSCNNContext sharedInstance].device + imageDescriptor:desc]; +} + +- (std::vector)toFp16Array { + if (self.pixelFormat == MTLPixelFormatR16Float || + self.pixelFormat == MTLPixelFormatRG16Float || + self.pixelFormat == MTLPixelFormatRGBA16Float) { + int64_t slices = (self.featureChannels + 3) / 4; + int64_t C = self.featureChannels < 3 ? self.featureChannels : slices * 4; + int64_t numComponents = self.featureChannels < 3 ? self.featureChannels : 4; + int64_t count = self.width * self.height * self.numberOfImages * C; + std::vector output(count, 0); + int64_t bytesPerRow = self.width * numComponents * sizeof(uint16_t); + uint8_t* buffer = (uint8_t*)output.data(); + for (int i = 0; i < slices * self.numberOfImages; ++i) { + [self.texture getBytes:buffer + bytesPerRow:bytesPerRow + bytesPerImage:0 + fromRegion:MTLRegionMake2D(0, 0, self.width, self.height) + mipmapLevel:0 + slice:i]; + buffer += self.height * bytesPerRow; + } + return output; + } + TORCH_CHECK( + false, "Copy to float buffer failed: The pixel format didn't match"); + return {}; +} + +- (at::Tensor)toCPUTensor { + auto outputSize = [self sizes]; + std::vector fp16 = [self toFp16Array]; + auto fp32 = metal::fp16_to_fp32(fp16); + std::vector fp32_nchw = metal::NC4_to_NCHW(fp32.data(), outputSize); + auto tensor = at::empty(outputSize); + int64_t size_bytes = at::prod_intlist(outputSize) * sizeof(float); + memcpy(tensor.data_ptr(), fp32_nchw.data(), size_bytes); + return tensor; +} + +- (std::vector)sizes { + int64_t N = self.numberOfImages; + int64_t C = self.featureChannels; + int64_t H = self.height; + int64_t W = self.width; + return {N, C, H, W}; +} + ++ (MPSTemporaryImage*)temporaryImageFromSize:(const std::vector&)size + commandBuffer:(MetalCommandBuffer*)cmdBuffer { + NSCAssert(cmdBuffer, @"CommandBuffer is nil!"); + MPSImageDescriptor* desc = [MPSImageDescriptor + imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 + width:size[3] + height:size[2] + featureChannels:size[1] + numberOfImages:size[0] + usage:MTLTextureUsageShaderRead | + MTLTextureUsageShaderWrite]; + MPSTemporaryImage* image = + [MPSTemporaryImage temporaryImageWithCommandBuffer:cmdBuffer.buffer + imageDescriptor:desc]; + image.readCount = INT_MAX; + [cmdBuffer add:image]; + return image; +} + +- (BOOL)isTemporaryImage { + return [self isKindOfClass:[MPSTemporaryImage class]]; +} + +- (void)markRead { + if ([self isTemporaryImage]) { + MPSTemporaryImage* tmpImage = (MPSTemporaryImage*)self; + if (tmpImage.readCount > 0) { + tmpImage.readCount -= 1; + } + } +} + +- (void)recycle { + if ([self isTemporaryImage]) { + MPSTemporaryImage* tmpImage = (MPSTemporaryImage*)self; + if (tmpImage.readCount > 0) { + tmpImage.readCount = 0; + } + } +} + +- (int64_t)readCount { + if ([self isTemporaryImage]) { + MPSTemporaryImage* tmpImage = (MPSTemporaryImage*)self; + return (int64_t)tmpImage.readCount; + } + return -1; +} + +@end + +@implementation MPSImage (Shaders) + ++ (MPSImage*)imageFromImage:(MPSImage*)X { + auto&& sizes = [X sizes]; + MPSImage* Y = [MPSImage imageFromSize:sizes]; + MetalCommandBuffer* cb = [MetalCommandBuffer newBuffer]; + id encoder = [cb.buffer computeCommandEncoder]; + id state = [[MPSCNNContext sharedInstance] + pipelineState:metal::mpscnn::kernelFor(X, @"copy", @"copy_nonarray")]; + [encoder setComputePipelineState:state]; + [encoder setTexture:[X texture] atIndex:0]; + [encoder setTexture:[Y texture] atIndex:1]; + + const auto& launchParams = + metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X); + [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid + threadsPerThreadgroup:launchParams.threadsPerThreadgroup]; + [encoder endEncoding]; + [cb synchronize]; + return Y; +} + ++ (MPSTemporaryImage*)temporaryImageFromImage:(MPSImage*)X + CommandBuffer:(MetalCommandBuffer*)cb { + NSCAssert(cb, @"CommandBuffer is nil!"); + MPSTemporaryImage* Y = [MPSImage temporaryImageFromSize:[X sizes] + commandBuffer:cb]; + id encoder = [cb.buffer computeCommandEncoder]; + id state = [[MPSCNNContext sharedInstance] + pipelineState:metal::mpscnn::kernelFor(X, @"copy", @"copy_nonarray")]; + [encoder setComputePipelineState:state]; + [encoder setTexture:[X texture] atIndex:0]; + [encoder setTexture:[Y texture] atIndex:1]; + + const auto& launchParams = + metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X); + [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid + threadsPerThreadgroup:launchParams.threadsPerThreadgroup]; + [encoder endEncoding]; + return Y; +} + ++ (MPSImage*)imageFromTemporaryImage:(MPSTemporaryImage*)X + CommandBuffer:(MetalCommandBuffer*)cb + waitUntilCompleted:(BOOL)b { + NSCAssert(cb, @"CommandBuffer is nil!"); + auto&& sizes = [X sizes]; + MPSImage* Y = [MPSImage imageFromSize:sizes]; + id encoder = [cb.buffer computeCommandEncoder]; + id state = [[MPSCNNContext sharedInstance] + pipelineState:metal::mpscnn::kernelFor(X, @"copy", @"copy_nonarray")]; + + [encoder setComputePipelineState:state]; + [encoder setTexture:[X texture] atIndex:0]; + [encoder setTexture:[Y texture] atIndex:1]; + + const auto& launchParams = + metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X); + [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid + threadsPerThreadgroup:launchParams.threadsPerThreadgroup]; + [encoder endEncoding]; + [X markRead]; + if (b) { + [cb synchronize]; + } + return Y; +} + ++ (MPSImage*)imageFromHost:(const float*)src + Sizes:(const std::vector&)sizes { + int64_t size_bytes = at::prod_intlist(sizes) * sizeof(float); + // allocte buffer on CPU + id buff = [[MPSCNNContext sharedInstance].device + newBufferWithLength:size_bytes + options:MTLResourceOptionCPUCacheModeWriteCombined]; + memcpy(buff.contents, src, size_bytes); + MPSImage* output = [MPSImage imageFromSize:sizes]; + id state = [[MPSCNNContext sharedInstance] + specializedPipelineState:metal::mpscnn::kernelFor( + output, + @"copy_nchw_to_metal", + @"copy_nchw_to_metal_nonarray") + Constants:@[ + @(output.featureChannels), + @(output.height), + @(output.width) + ]]; + MetalCommandBuffer* cb = [MetalCommandBuffer newBuffer]; + id encoder = [cb.buffer computeCommandEncoder]; + [encoder setComputePipelineState:state]; + [encoder setBuffer:buff offset:0 atIndex:0]; + [encoder setTexture:[output texture] atIndex:0]; + const auto& launchParams = + metal::mpscnn::spatialPointwiseKernelLaunchParams(state, output); + [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid + threadsPerThreadgroup:launchParams.threadsPerThreadgroup]; + [encoder endEncoding]; + [cb synchronize]; + return output; +} + ++ (MPSTemporaryImage*)temporaryImageFromHost:(const float*)src + Sizes:(const std::vector&)sizes + CommandBuffer:(MetalCommandBuffer*)cb { + NSCAssert(cb, @"CommandBuffer is nil!"); + int64_t size_bytes = at::prod_intlist(sizes) * sizeof(float); + // allocte buffer on CPU + id buff = [[MPSCNNContext sharedInstance].device + newBufferWithLength:size_bytes + options:MTLResourceOptionCPUCacheModeWriteCombined]; + memcpy(buff.contents, src, size_bytes); + MPSTemporaryImage* output = [MPSImage temporaryImageFromSize:sizes + commandBuffer:cb]; + id state = [[MPSCNNContext sharedInstance] + specializedPipelineState:metal::mpscnn::kernelFor( + output, + @"copy_nchw_to_metal", + @"copy_nchw_to_metal_nonarray") + Constants:@[ + @(output.featureChannels), + @(output.height), + @(output.width) + ]]; + id encoder = [cb.buffer computeCommandEncoder]; + [encoder setComputePipelineState:state]; + [encoder setBuffer:buff offset:0 atIndex:0]; + [encoder setTexture:[output texture] atIndex:0]; + const auto& launchParams = + metal::mpscnn::spatialPointwiseKernelLaunchParams(state, output); + [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid + threadsPerThreadgroup:launchParams.threadsPerThreadgroup]; + [encoder endEncoding]; + [output markRead]; + return output; +} + ++ (void)copyToHost:(float*)dst FromImage:(MPSImage*)image { + auto&& sizes = [image sizes]; + int64_t size_bytes = at::prod_intlist(sizes) * sizeof(float); + id buffer = [[MPSCNNContext sharedInstance].device + newBufferWithLength:size_bytes + options:MTLResourceOptionCPUCacheModeDefault]; + + id cb = + [MPSCNNContext sharedInstance].commandQueue.commandBuffer; + id encoder = [cb computeCommandEncoder]; + id state = [[MPSCNNContext sharedInstance] + specializedPipelineState:metal::mpscnn::kernelFor( + image, + @"copy_metal_to_nchw", + @"copy_metal_to_nchw_nonarray") + Constants:@[ + @(image.featureChannels), + @(image.height), + @(image.width) + ]]; + + [encoder setComputePipelineState:state]; + [encoder setBuffer:buffer offset:0 atIndex:0]; + [encoder setTexture:[image texture] atIndex:0]; + + const auto& launchParams = + metal::mpscnn::spatialPointwiseKernelLaunchParams(state, image); + [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid + threadsPerThreadgroup:launchParams.threadsPerThreadgroup]; + [encoder endEncoding]; + [cb commit]; + [cb waitUntilCompleted]; + memcpy(dst, buffer.contents, buffer.length); +} + +@end diff --git a/aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.h b/aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.h new file mode 100644 index 0000000000000..c6b902116c3b2 --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.h @@ -0,0 +1,49 @@ +#ifndef MPSImageWrapper_h +#define MPSImageWrapper_h + +#import +#import +#include + +namespace at { +namespace native { +namespace metal { + +enum class TextureType { + TextureNone, + TextureType2D, + TextureType2DArray, +}; + +class API_AVAILABLE(ios(10.0), macos(10.13)) MPSImageWrapper { + public: + MPSImageWrapper(IntArrayRef sizes); + operator bool() const { + return _image; + } + void copyDataFromHost(const float* inputData); + void copyDataToHost(float* hostData); + void allocateTextureStorage(IntArrayRef sizes); + void allocateTemporaryTextureStorage( + IntArrayRef sizes, + MetalCommandBuffer* commandBuffer); + void copyFromTexture(MPSImage* image); + void setCommandBuffer(MetalCommandBuffer* buffer); + MetalCommandBuffer* commandBuffer() const; + TextureType textureType() const; + IntArrayRef textureSizes() const; + MPSImage* image() const; + void recycleImage(); + void synchronize(); + + private: + std::vector _textureSizes; + MPSImage* _image = nullptr; + MetalCommandBuffer* _commandBuffer; +}; + +} // namespace metal +} // namespace native +} // namespace at + +#endif /* MPSImageWrapper_h */ diff --git a/aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.mm b/aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.mm new file mode 100644 index 0000000000000..6130ce62a491b --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.mm @@ -0,0 +1,116 @@ +#import +#import +#import +#import +#import +#import + +#include + +namespace at { +namespace native { +namespace metal { + +std::vector textureSizeFromSizes(IntArrayRef sizes, TextureType type) { + if (sizes.size() == 2) { + if (type == TextureType::TextureType2DArray) { + return {sizes[0], sizes[1], 1, 1}; + } else if (type == TextureType::TextureType2D) { + return {1, 1, sizes[0], sizes[1]}; + } else { + return {}; + } + } + return sizes.vec(); +} +MPSImageWrapper::MPSImageWrapper(IntArrayRef sizes) { + _textureSizes = textureSizeFromSizes(sizes, TextureType::TextureType2D); +} + +void MPSImageWrapper::copyDataFromHost(const float* inputData) { + TORCH_CHECK(inputData); + TORCH_CHECK(_textureSizes.size() == 4); + _commandBuffer = [MetalCommandBuffer currentBuffer]; + _image = [MPSImage temporaryImageFromHost:inputData + Sizes:_textureSizes + CommandBuffer:_commandBuffer]; +} + +void MPSImageWrapper::copyDataToHost(float* hostData) { + TORCH_CHECK(_image); + synchronize(); + [MPSImage copyToHost:hostData FromImage:_image]; +} + +MPSImage* MPSImageWrapper::image() const { + return _image; +} + +void MPSImageWrapper::recycleImage() { + if ([_image isTemporaryImage]) { + [_image recycle]; + [_commandBuffer remove:(MPSTemporaryImage*)_image]; + } +} + +void MPSImageWrapper::setCommandBuffer(MetalCommandBuffer* cb) { + _commandBuffer = cb; +} +MetalCommandBuffer* MPSImageWrapper::commandBuffer() const { + return _commandBuffer; +} + +IntArrayRef MPSImageWrapper::textureSizes() const { + return _textureSizes; +} + +TextureType MPSImageWrapper::textureType() const { + if (!_image) { + return TextureType::TextureNone; + } + MTLTextureType textureType = _image.textureType; + if (textureType == MTLTextureType2D) { + return TextureType::TextureType2D; + } else if (textureType == MTLTextureType2DArray) { + return TextureType::TextureType2DArray; + } + return TextureType::TextureNone; +} + +void MPSImageWrapper::allocateTextureStorage(IntArrayRef sizes) { + _textureSizes = sizes.vec(); + _image = [MPSImage imageFromSize:_textureSizes]; +} + +void MPSImageWrapper::allocateTemporaryTextureStorage( + IntArrayRef sizes, + MetalCommandBuffer* commandBuffer) { + TORCH_CHECK(commandBuffer) + _textureSizes = sizes.vec(); + _commandBuffer = commandBuffer; + _image = [MPSImage temporaryImageFromSize:_textureSizes + commandBuffer:commandBuffer]; +} + +void MPSImageWrapper::copyFromTexture(MPSImage* image) { + if ([image isTemporaryImage]) { + _image = [MPSImage temporaryImageFromImage:image + CommandBuffer:_commandBuffer]; + } else { + _image = [MPSImage imageFromImage:image]; + } +} + +void MPSImageWrapper::synchronize() { + if ([_image isTemporaryImage]) { + _image = [MPSImage imageFromTemporaryImage:(MPSTemporaryImage*)_image + CommandBuffer:_commandBuffer + waitUntilCompleted:NO]; + } + [_commandBuffer synchronize]; + _commandBuffer = nil; +} + +} +} +} diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h new file mode 100644 index 0000000000000..b7193b8863b71 --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h @@ -0,0 +1,24 @@ +#ifndef MPSCNNTests_h +#define MPSCNNTests_h + +bool test_synchronization(); +bool test_nchw_to_nc4_cpu(); +bool test_copy_nchw_to_metal(); +bool test_conv2d(); +bool test_depthwiseConv(); +bool test_max_pool2d(); +bool test_relu(); +bool test_addmm(); +bool test_add(); +bool test_sub(); +bool test_mul(); +bool test_t(); +bool test_view(); +bool test_softmax(); +bool test_sigmoid(); +bool test_upsampling_nearest2d_vec(); +bool test_adaptive_avg_pool2d(); +bool test_hardtanh_(); +bool test_reshape(); + +#endif diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm new file mode 100644 index 0000000000000..dd265b6a1b8f4 --- /dev/null +++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm @@ -0,0 +1,406 @@ +#import +#import +#import +#import +#import +#import +#import + +#include +#import + +#include +#include +#include + +#define ITER_COUNT 5 + +namespace { + +int64_t rand(int64_t min, int64_t max) { + return min + (std::rand() % static_cast(max - min + 1)); +} + +bool checkRtol(const at::Tensor& diff, const std::vector inputs) { + double maxValue = 0.0; + for (auto& tensor : inputs) { + maxValue = fmax(tensor.abs().max().item(), maxValue); + } + return diff.abs().max().item() < (0.01 + 2e-2 * maxValue); +} +bool almostEqual(const at::Tensor& a, const at::Tensor& b) { + return checkRtol(a - b, {a, b}) && a.strides().vec() == b.strides().vec(); +} + +bool almostEqualTensor(const at::Tensor& a, const at::Tensor& b, float t) { + if (a.sizes() != b.sizes()) { + return false; + } + if (a.numel() != b.numel()) { + return false; + } + for (int i = 0; i < a.numel(); ++i) { + float x1 = a.data_ptr()[i]; + float x2 = b.data_ptr()[i]; + if (std::abs(x1 - x2) > t) { + return false; + } + } + return true; +} + +bool almostEqualVec( + const std::vector vec1, + const std::vector vec2, + float t) { + if (vec1.size() != vec2.size()) { + return false; + } + for (int i = 0; i < vec1.size(); ++i) { + if (std::abs(vec1[i] - vec2[i]) > t) { + return false; + } + } + return true; +} + +typedef bool (^Func)(void); +bool TEST(const std::vector& sizes, std::string name, Func block) { + std::stringstream ss; + std::copy(sizes.begin(), sizes.end(), std::ostream_iterator(ss, " ")); + __block std::string str1 = ss.str(); + bool b = block(); + void (^print)(NSString*) = ^(NSString* result) { + NSLog(@"[%s],[%s],[%@]", name.c_str(), str1.c_str(), result); + }; + b ? print(@"SUCCEED") : print(@"FAILED"); + return b; +} + +} + +using namespace at::native::metal; + +bool test_synchronization() { + __block std::vector size{1, 3, 2, 2}; + return TEST(size, __PRETTY_FUNCTION__, ^bool(void) { + auto x1 = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto mx1 = x1.metal(); + TORCH_CHECK(mx1.device().type() == at::kMetal); + auto x2 = mx1.cpu(); + TORCH_CHECK(x2.device().type() == at::kCPU); + return almostEqual(x1, x2); + }); +} + +bool test_nchw_to_nc4_cpu() { + bool result = true; + for (int i = 0; i < ITER_COUNT; ++i) { + int64_t N = rand(1, 24); + int64_t C = rand(1, 48); + int64_t H = rand(1, 320); + int64_t W = rand(1, 320); + __block std::vector size{N, C, H, W}; + bool b = TEST(size, __PRETTY_FUNCTION__, ^bool { + auto t = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + int len = std::accumulate( + std::begin(size), + std::end(size), + (int64_t)1, + std::multiplies()); + auto buf = + std::vector{t.data_ptr(), t.data_ptr() + len}; + auto c4 = NCHW_to_NC4((float*)t.data_ptr(), t.sizes().vec()); + auto n4 = NC4_to_NCHW((float*)c4.data(), t.sizes().vec()); + return n4 == buf; + }); + if (!b) { + result = false; + } + } + return result; +} + +bool test_copy_nchw_to_metal() { + __block std::vector size{1, 3, 224, 224}; + return TEST(size, __PRETTY_FUNCTION__, ^bool(void) { + auto t1 = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + MetalCommandBuffer* cb = [MetalCommandBuffer newBuffer]; + MPSTemporaryImage* img1 = + [MPSImage temporaryImageFromHost:t1.data_ptr() + Sizes:t1.sizes().vec() + CommandBuffer:cb]; + MPSImage* img2 = [MPSImage imageFromTemporaryImage:img1 + CommandBuffer:cb + waitUntilCompleted:YES]; + auto t2 = at::zeros(size); + [MPSImage copyToHost:t2.data_ptr() FromImage:img2]; + return almostEqual(t1, t2); + }); +} + +bool test_conv2d() { + bool result = true; + for (int i = 0; i < ITER_COUNT; ++i) { + int64_t N = rand(1, 10); + int64_t C = rand(1, 48); + int64_t IH = rand(1, 300); + int64_t IW = rand(1, 300); + int64_t OC = rand(1, 48); + int64_t IC = C; + int64_t KH = rand(1, MIN(10, IH)); + int64_t KW = rand(1, MIN(10, IW)); + int64_t PH = rand(1, 10); + int64_t PW = rand(1, 10); + int64_t SH = rand(1, 10); + int64_t SW = rand(1, 10); + bool b = TEST({N, C, IH, IW}, __PRETTY_FUNCTION__, ^bool { + auto X = at::rand( + {N, C, IH, IW}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto W = at::rand( + {OC, IC, KH, KW}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto B = at::rand({OC}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto S = c10::IntArrayRef({SH, SW}); + auto P = c10::IntArrayRef({PH, PW}); + auto D = + c10::IntArrayRef({1, 1}); // Dilated convolution is not supported yet + int64_t groups = 1; + auto Y1 = at::native::conv2d(X, W, B, S, P, D, groups); + auto X2 = X.metal(); + Conv2DParams params{X.sizes(), W.sizes(), P, S, D, groups}; + auto Y2 = mpscnn::conv2d(X2, W, B, params).cpu(); + return almostEqual(Y1, Y2); + }); + if (!b) { + result = false; + } + } + return result; +} + +bool test_depthwiseConv() { + __block std::vector x{1, 32, 112, 112}; + __block std::vector w{32, 1, 3, 3}; + __block std::vector b{32}; + __block std::vector p{1, 1}; + int g = 32; + return TEST(x, __PRETTY_FUNCTION__, ^bool { + auto S = std::vector{1, 1}; + auto D = std::vector{1, 1}; + auto OP = std::vector({0, 0}); + auto X = at::rand(x, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto W = at::rand(w, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto B = at::rand(b, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto Y1 = at::native::_convolution( + X, W, B, {1, 1}, p, {1, 1}, false, {0, 0}, g, false, false, true, true); + auto X2 = X.metal(); + Conv2DParams params{X.sizes(), W.sizes(), p, S, D, g}; + if (!params.isDepthwise()) { + return false; + } + auto Y2 = mpscnn::conv2d(X2, W, B, params).cpu(); + return almostEqual(Y1, Y2); + }); +} + +bool test_max_pool2d() { + __block std::vector size{1, 3, 4, 4}; + return TEST(size, __PRETTY_FUNCTION__, ^bool { + auto X = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto Y1 = at::native::max_pool2d(X, {2, 2}, {2, 2}, {0, 0}, {1, 1}, false); + auto X2 = X.metal(); + auto Y2 = + mpscnn::max_pool2d(X2, {2, 2}, {2, 2}, {0, 0}, {1, 1}, false).cpu(); + return almostEqual(Y1, Y2); + }); +} + +bool test_relu() { + __block std::vector size{1, 3, 4, 4}; + return TEST(size, __PRETTY_FUNCTION__, ^bool { + auto X = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto Y1 = torch::native::relu(X); + auto X2 = X.metal(); + auto Y2 = mpscnn::relu(X2).cpu(); + return almostEqual(Y1, Y2); + }); +} + +bool test_sigmoid() { + __block std::vector size{1, 3, 4, 4}; + return TEST(size, __PRETTY_FUNCTION__, ^bool { + auto X = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto Y1 = at::native::sigmoid(X); + auto X2 = X.metal(); + auto Y2 = mpscnn::sigmoid(X2).cpu(); + return almostEqual(Y1, Y2); + }); +} + +bool test_addmm() { + bool result = true; + for (int i = 0; i < ITER_COUNT; ++i) { + int64_t N = rand(1, 10); + int64_t IC = rand(1, 128); + int64_t OC = rand(1, 128); + bool b = TEST({N, IC, OC}, __PRETTY_FUNCTION__, ^bool { + auto X1 = + at::rand({N, IC}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto W1 = + at::rand({IC, OC}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto B1 = + at::rand({1, OC}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto Y1 = at::native::addmm_cpu(B1, X1, W1); + auto X2 = X1.view({N, IC, 1, 1}).contiguous().metal(); + auto W2 = W1.t().view({W1.sizes()[1], W1.sizes()[0], 1, 1}).contiguous(); + auto Y2 = mpscnn::addmm(B1, X2, W2).cpu(); + return almostEqual(Y1, Y2); + }); + if (!b) { + result = false; + } + } + return result; +} + +bool test_add() { + __block std::vector x{1, 180, 12, 12}; + return TEST(x, __PRETTY_FUNCTION__, ^bool { + auto X1 = at::rand(x, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto X2 = at::rand(x, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto Y1 = at::add(X1, X2); + auto MX1 = X1.metal(); + auto MX2 = X2.metal(); + auto Y2 = mpscnn::add(MX1, MX2).cpu(); + return almostEqual(Y1, Y2); + }); +} + +bool test_sub() { + __block std::vector x1{1, 3, 192, 192}; + __block std::vector x2{1, 3, 1, 1}; + return TEST(x1, __PRETTY_FUNCTION__, ^bool { + auto X1 = at::rand(x1, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto X2 = at::rand(x2, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto Y1 = at::native::sub(X1, X2); + auto MX1 = X1.metal(); + auto MX2 = X2.metal(); + auto Y2 = mpscnn::sub(MX1, MX2).cpu(); + return almostEqual(Y1, Y2); + }); +} + +bool test_mul() { + __block std::vector x1{1, 3, 192, 192}; + __block std::vector x2{1, 3, 1, 1}; + return TEST(x1, __PRETTY_FUNCTION__, ^bool { + auto X1 = at::rand(x1, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto X2 = at::rand(x2, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto Y1 = at::native::mul(X1, X2); + auto MX1 = X1.metal(); + auto MX2 = X2.metal(); + auto Y2 = mpscnn::mul(MX1, MX2).cpu(); + return almostEqual(Y1, Y2); + }); +} + +bool test_t() { + bool result = true; + for (int i = 0; i < ITER_COUNT; ++i) { + int64_t H = rand(1, 256); + int64_t W = rand(1, 256); + bool b = TEST({H, W}, __PRETTY_FUNCTION__, ^bool { + auto X1 = + torch::rand({H, W}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto Y1 = at::native::t(X1).contiguous(); + auto X2 = X1.metal(); + auto Y2 = mpscnn::t(X2).cpu(); + return almostEqual(Y1, Y2); + }); + if (!b) { + result = false; + } + } + return result; +} + +bool test_view() { + __block std::vector size{1, 3, 2, 2}; + return TEST(size, __PRETTY_FUNCTION__, ^bool { + auto X1 = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto Y1 = X1.view({3, 4}).contiguous(); + auto X2 = X1.metal(); + auto Y2 = mpscnn::view(X2, {3, 4}).cpu(); + bool b1 = (Y1.sizes() == Y2.sizes()); + bool b2 = (Y1.strides() == Y2.strides()); + return b1 && b2; + }); +} + +bool test_softmax() { + __block std::vector size{2, 3, 1, 1}; + return TEST(size, __PRETTY_FUNCTION__, ^bool { + auto X1 = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto Y1 = at::native::log_softmax(X1, 1); + auto X2 = X1.metal(); + auto Y2 = mpscnn::log_softmax_int(X2).cpu(); + return almostEqual(Y1, Y2); + }); +} + +bool test_upsampling_nearest2d_vec() { + __block std::vector size{1, 48, 24, 24}; + return TEST(size, __PRETTY_FUNCTION__, ^bool { + auto X1 = torch::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto Y1 = torch::native::upsample_nearest2d_cpu( + X1, + c10::optional({}), + c10::optional>({2, 2})); + auto X2 = X1.metal(); + auto Y2 = mpscnn::upsample_nearest2d_vec( + X2, + c10::optional({}), + c10::optional>({2, 2})) + .cpu(); + return almostEqual(Y1, Y2); + }); +} + +bool test_adaptive_avg_pool2d() { + __block std::vector size{1, 48, 24, 24}; + return TEST(size, __PRETTY_FUNCTION__, ^bool { + auto X1 = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto Y1 = at::native::adaptive_avg_pool2d(X1, {1, 1}); + auto X2 = X1.metal(); + auto Y2 = mpscnn::global_avg_pool2d(X2, {1, 1}).cpu(); + return almostEqual(Y1, Y2); + }); +} + +bool test_reshape() { + __block std::vector size{1, 1280, 1, 1}; + return TEST(size, __PRETTY_FUNCTION__, ^bool { + auto X1 = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto Y1 = at::native::reshape(X1, {1, -1}); + auto X2 = X1.metal(); + auto Y2 = torch::native::metal::mpscnn::reshape(X2, {1, -1}).cpu(); + return almostEqual(Y1, Y2); + }); +} + +bool test_hardtanh_() { +#if TARGET_OS_IPHONE + __block std::vector size{1, 32, 112, 112}; + return TEST(size, __PRETTY_FUNCTION__, ^bool { + auto X1 = torch::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto Y1 = at::native::hardtanh_(X1, 0, 6.0); + auto X2 = X1.metal(); + auto Y2 = at::native::metal::mpscnn::hardtanh_(X2, 0, 6.0).cpu(); + return almostEqual(Y1, Y2); + }); +#else + // Skip this test on MacOS as the shader function doesn't work well + // Will get back and fix it - T82700462 + return true; +#endif +} diff --git a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp index 0d4af95c7a765..92473ecc68c85 100644 --- a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp +++ b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp @@ -62,7 +62,6 @@ std::tuple miopen_batch_norm( running_mean{ running_mean_t, "running_mean", 4 }, running_var{ running_var_t, "running_var", 5 }; CheckedFrom c = "miopen_batch_norm"; - setMIOpenStreamToCurrent(); checkAllDefined(c, {input, weight, bias}); if (!training) { @@ -151,7 +150,6 @@ std::tuple miopen_batch_norm_backward( save_mean{ save_mean_t, "save_mean", 4 }, save_var{ save_var_t, "save_var", 5 }; CheckedFrom c = "miopen_batch_norm_backward"; - setMIOpenStreamToCurrent(); checkAllDefined(c, {input, grad_output, weight, save_mean, save_var}); checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var}); diff --git a/aten/src/ATen/native/miopen/Conv_miopen.cpp b/aten/src/ATen/native/miopen/Conv_miopen.cpp index 3f6e78e77c9f7..f0b0d6fdd5b7a 100644 --- a/aten/src/ATen/native/miopen/Conv_miopen.cpp +++ b/aten/src/ATen/native/miopen/Conv_miopen.cpp @@ -468,7 +468,6 @@ void findAlgorithm(const ConvolutionArgs& args, bool benchmark, algo_t* algo) { if (args.params.deterministic && !benchmark) { *algo = search::DEFAULT_ALGO; - return; } if (cache.find(args.params, algo)) { @@ -625,7 +624,6 @@ Tensor miopen_convolution( TensorArg input { input_t, "input", 1 }, weight { weight_t, "weight", 2 }, bias { bias_t, "bias", 3 }; - setMIOpenStreamToCurrent(); CheckedFrom c = "miopen_convolution"; auto output_t = miopen_convolution_forward( c, input, weight, padding, stride, dilation, groups, benchmark, deterministic); @@ -700,7 +698,6 @@ Tensor miopen_depthwise_convolution( TensorArg input { input_t, "input", 1 }, weight { weight_t, "weight", 2 }, bias { bias_t, "bias", 3 }; - setMIOpenStreamToCurrent(); CheckedFrom c = "miopen_depthwise_convolution"; auto output_t = miopen_depthwise_convolution_forward( c, input, weight, padding, stride, dilation, groups, benchmark, deterministic); @@ -717,7 +714,6 @@ Tensor miopen_convolution_transpose_backward_input( { TensorArg grad_output { grad_output_t, "grad_output", 1 }, weight { weight_t, "weight", 2 }; - setMIOpenStreamToCurrent(); return miopen_convolution_forward( "miopen_convolution_transpose_backward_input", grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic); @@ -828,7 +824,6 @@ Tensor miopen_convolution_backward_input( { TensorArg grad_output{ grad_output_t, "grad_output", 1 }, weight{ weight_t, "weight", 2 }; - setMIOpenStreamToCurrent(); return miopen_convolution_backward_input( "miopen_convolution_backward_input", input_size, grad_output, weight, @@ -898,7 +893,6 @@ Tensor miopen_depthwise_convolution_backward_input( { TensorArg grad_output{ grad_output_t, "grad_output", 1 }, weight{ weight_t, "weight", 2 }; - setMIOpenStreamToCurrent(); return miopen_depthwise_convolution_backward_input( "miopen_depthwise_convolution_backward_input", input_size, grad_output, weight, @@ -1088,7 +1082,6 @@ Tensor miopen_convolution_backward_weight( { TensorArg grad_output{ grad_output_t, "grad_output", 1 }, input{ input_t, "input", 2 }; - setMIOpenStreamToCurrent(); return miopen_convolution_backward_weight( "miopen_convolution_backward_weight", weight_size, grad_output, input, @@ -1104,7 +1097,6 @@ Tensor miopen_convolution_transpose_backward_weight( { TensorArg grad_output{ grad_output_t, "grad_output", 1 }, input{ input_t, "input", 2 }; - setMIOpenStreamToCurrent(); return miopen_convolution_backward_weight( "miopen_convolution_backward_weight", weight_size, input, grad_output, @@ -1120,7 +1112,6 @@ Tensor miopen_depthwise_convolution_backward_weight( { TensorArg grad_output{ grad_output_t, "grad_output", 1 }, input{ input_t, "input", 2 }; - setMIOpenStreamToCurrent(); return miopen_depthwise_convolution_backward_weight( "miopen_depthwise_convolution_backward_weight", weight_size, grad_output, input, @@ -1137,7 +1128,6 @@ Tensor miopen_convolution_backward_bias( const Tensor& grad_output_t) { TensorArg grad_output{ grad_output_t, "grad_output", 1 }; - setMIOpenStreamToCurrent(); auto grad_bias_t = at::empty( { grad_output->size(output_channels_dim) }, grad_output->options()); diff --git a/aten/src/ATen/native/miopen/RNN_miopen.cpp b/aten/src/ATen/native/miopen/RNN_miopen.cpp index 1493cece32128..10b535f890ac0 100644 --- a/aten/src/ATen/native/miopen/RNN_miopen.cpp +++ b/aten/src/ATen/native/miopen/RNN_miopen.cpp @@ -509,7 +509,6 @@ std::tuple miopen_rnn( size_t reserver_size; MIOPEN_CHECK(miopenGetRNNTrainingReserveSize(handle, descs.rnn_desc.desc(), fn.tensors.seq_length, x_descs_arr.data(), &reserver_size)); reserve = at::empty(reserver_size, input.options().dtype(kByte)); - setMIOpenStreamToCurrent(); MIOPEN_CHECK(miopenRNNForwardTraining(handle, descs.rnn_desc.desc(), fn.tensors.seq_length, x_descs_arr.data(), x.data_ptr(), descs.hx_desc.desc(), hx.data_ptr(), @@ -521,7 +520,6 @@ std::tuple miopen_rnn( workspace.data_ptr(), workspace_size, reserve.data_ptr(), reserver_size )); } else { //Inference. reserve = at::empty({0}, input.options().dtype(kByte)); - setMIOpenStreamToCurrent(); MIOPEN_CHECK(miopenRNNForwardInference(handle, descs.rnn_desc.desc(), fn.tensors.seq_length, x_descs_arr.data(), x.data_ptr(), descs.hx_desc.desc(), hx.data_ptr(), @@ -630,7 +628,6 @@ std::tuple miopen_rnn_backward_input( )); auto workspace = at::empty(workspace_size, input.options().dtype(kByte)); - setMIOpenStreamToCurrent(); MIOPEN_CHECK(miopenRNNBackwardData( handle, descs.rnn_desc.desc(), @@ -715,7 +712,6 @@ std::vector miopen_rnn_backward_weight( auto x_descs_arr = descs.get_x_descs(); auto y_descs_arr = descs.get_y_descs(); - setMIOpenStreamToCurrent(); MIOPEN_CHECK(miopenRNNBackwardWeights( handle, descs.rnn_desc.desc(), diff --git a/aten/src/ATen/native/mkl/LinearAlgebra.cpp b/aten/src/ATen/native/mkl/LinearAlgebra.cpp index 9592fa7560346..0fc22c2c637df 100644 --- a/aten/src/ATen/native/mkl/LinearAlgebra.cpp +++ b/aten/src/ATen/native/mkl/LinearAlgebra.cpp @@ -72,36 +72,63 @@ static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANS template static inline void baddbmm_mkl_template(const Tensor& res, const Tensor& mat1, const Tensor& mat2, Scalar beta_, Scalar alpha_) { - auto is_transposed = [&](const TensorAccessor& t) { - return t.stride(0) == 1 && t.stride(1) >= t.size(0); + const auto mat1_strides = mat1.strides(); + const auto mat2_strides = mat2.strides(); + const auto mat1_sizes = mat1.sizes(); + const auto mat2_sizes = mat2.sizes(); + + auto is_transposed = [](const c10::IntArrayRef& strides, const c10::IntArrayRef& sizes) { + return strides[1] == 1 && strides[2] >= sizes[1]; }; - auto mat1_acc = mat1.accessor(); - auto mat2_acc = mat2.accessor(); - auto res_acc = res.accessor(); + const CBLAS_TRANSPOSE trans_A = + is_transposed(mat1_strides, mat1_sizes) ? CblasTrans : CblasNoTrans; + const CBLAS_TRANSPOSE trans_B = + is_transposed(mat2_strides, mat2_sizes) ? CblasTrans : CblasNoTrans; + - const CBLAS_TRANSPOSE trans_A = is_transposed(mat1_acc[0]) ? CblasTrans : CblasNoTrans; - const CBLAS_TRANSPOSE trans_B = is_transposed(mat2_acc[0]) ? CblasTrans : CblasNoTrans; + // mat1: batch_size * M * K + const int batch_size = mat1_sizes[0]; + const int M = mat1_sizes[1]; + // mat2: batch_size * K * N + const int N = mat2_sizes[2]; + const int K = mat1_sizes[2]; - const int batch_size = mat1_acc.size(0); - const int M = mat1_acc.size(1); - const int N = mat2_acc.size(2); - const int K = mat1_acc.size(2); scalar_t alpha = alpha_.to(); scalar_t beta = beta_.to(); - const int lda = is_transposed(mat1_acc[0]) ? mat1_acc[0].stride(1) : mat1_acc[0].stride(0); - const int ldb = is_transposed(mat2_acc[0]) ? mat2_acc[0].stride(1) : mat2_acc[0].stride(0); - const int ldc = res[0].stride(0); - - std::vector A(batch_size); - std::vector B(batch_size); - std::vector C(batch_size); - - for (int64_t batch = 0; batch < batch_size; batch++) { - A[batch] = mat1_acc[batch].data(); - B[batch] = mat2_acc[batch].data(); - C[batch] = res_acc[batch].data(); + const int lda = trans_A == CblasTrans ? mat1_strides[2] : mat1_strides[1]; + const int ldb = trans_B == CblasTrans ? mat2_strides[2] : mat2_strides[1]; + const int ldc = res.strides()[1]; + + std::vector A; + A.reserve(batch_size); + std::vector B; + B.reserve(batch_size); + std::vector C; + C.reserve(batch_size); + + // avoid using tensor accessor in the case of mat1/mat2 not being transposed + // or only transposed in the last two axis + scalar_t* res_data = static_cast(res.data_ptr()); + const auto res_sizes = res.sizes(); + if (mat1_strides[0] == mat1_sizes[1] * mat1_sizes[2] && + mat2_strides[0] == mat2_sizes[1] * mat2_sizes[2]) { + scalar_t* mat1_data = static_cast(mat1.data_ptr()); + scalar_t* mat2_data = static_cast(mat2.data_ptr()); + for (int64_t batch = 0; batch < batch_size; batch++) { + A.emplace_back(mat1_data + batch * mat1_sizes[1] * mat1_sizes[2]); + B.emplace_back(mat2_data + batch * mat2_sizes[1] * mat2_sizes[2]); + C.emplace_back(res_data + batch * res_sizes[1] * res_sizes[2]); + } + } else { + auto mat1_acc = mat1.accessor(); + auto mat2_acc = mat2.accessor(); + for (int64_t batch = 0; batch < batch_size; batch++) { + A.emplace_back(mat1_acc[batch].data()); + B.emplace_back(mat2_acc[batch].data()); + C.emplace_back(res_data + batch * res_sizes[1] * res_sizes[2]); + } } gemm_batched(trans_A, trans_B, batch_size, M, N, K, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), ldc); diff --git a/aten/src/ATen/native/mkl/SpectralOps.cpp b/aten/src/ATen/native/mkl/SpectralOps.cpp index 29c4481bb1cd5..d5a39e45941b6 100644 --- a/aten/src/ATen/native/mkl/SpectralOps.cpp +++ b/aten/src/ATen/native/mkl/SpectralOps.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -7,11 +8,32 @@ namespace at { namespace native { -Tensor _fft_mkl(const Tensor& input, int64_t signal_ndim, - bool complex_input, bool complex_output, - bool inverse, IntArrayRef checked_signal_sizes, - int64_t normalization, bool onesided, - IntArrayRef output_sizes) { +REGISTER_NO_CPU_DISPATCH(fft_fill_with_conjugate_symmetry_stub, fft_fill_with_conjugate_symmetry_fn); + +Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) { + AT_ERROR("fft: ATen not compiled with MKL support"); +} + +Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) { + AT_ERROR("fft: ATen not compiled with MKL support"); +} + +Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) { + AT_ERROR("fft: ATen not compiled with MKL support"); +} + +Tensor& _fft_r2c_mkl_out(Tensor& out, const Tensor& self, IntArrayRef dim, int64_t normalization, + bool onesided) { + AT_ERROR("fft: ATen not compiled with MKL support"); +} + +Tensor& _fft_c2r_mkl_out(Tensor& out, const Tensor& self, IntArrayRef dim, int64_t normalization, + int64_t last_dim_size) { + AT_ERROR("fft: ATen not compiled with MKL support"); +} + +Tensor& _fft_c2c_mkl_out(Tensor& out, const Tensor& self, IntArrayRef dim, int64_t normalization, + bool forward) { AT_ERROR("fft: ATen not compiled with MKL support"); } @@ -26,6 +48,8 @@ Tensor _fft_mkl(const Tensor& input, int64_t signal_ndim, #include #include +#include + #include #include #include @@ -46,195 +70,193 @@ namespace at { namespace native { // See NOTE [ Fourier Transform Conjugate Symmetry ] in native/SpectralOpsUtils.h. template -static inline void _fft_fill_with_conjugate_symmetry_slice(Tensor& output, - int64_t signal_ndim, int64_t size_last_dim, - int64_t start_last_dim_idx, int64_t i, int64_t num) { - scalar_t *data = output.data_ptr(); - - // A slice means a slice of last dimension (of size size_last_dim) - - // This function iterates through the slices to fill, i.e. to_slice_data - // (basically data_slices[i:i+num]), and keeps track of the slices it reads - // data from, i.e., from_slice_data, using from_slice_indices, a vector - // containing the index of the from_slice_data slice. - - // Compute the indices for the first from_slice_data - std::vector from_slice_indices(signal_ndim); // up to before last signal dim - int64_t remainder = i; - // set last signal dim values - int64_t from_slice_offset = 0; - for (int64_t d = signal_ndim - 1; d >= 0; d--) { - int64_t dim_size = output.size(d); - int64_t dim_idx = remainder % dim_size; - remainder = remainder / dim_size; - from_slice_indices[d] = dim_idx; - if (d == 0) { - from_slice_offset += dim_idx * output.stride(d); - } else if (dim_idx != 0) { - from_slice_offset += (dim_size - dim_idx) * output.stride(d); - } - } +static __ubsan_ignore_undefined__ // UBSAN gives false positives on using negative indexes with a pointer +void _fft_fill_with_conjugate_symmetry_slice( + Range range, at::ArrayRef is_mirrored_dim, IntArrayRef signal_half_sizes, + IntArrayRef in_strides, const scalar_t * in_ptr, + IntArrayRef out_strides, scalar_t * out_ptr) { + const auto ndim = signal_half_sizes.size(); + DimVector iter_index(ndim, 0); - // First to_slice_data and from_slice_data - scalar_t *to_slice_data = data + i * size_last_dim * 2; - scalar_t *from_slice_data = data + from_slice_offset; - - while (num > 0) { - // Fill to_slice_data from values in from_slice_data - for (int64_t j = start_last_dim_idx; j < size_last_dim; j++) { - // multiply index by 2 because of the last complex dim has size 2 - int64_t to_idx = j * 2; - int64_t from_idx = (size_last_dim - j) * 2; - to_slice_data[to_idx] = from_slice_data[from_idx]; - to_slice_data[to_idx + 1] = -from_slice_data[from_idx + 1]; - } - // Compute the next to_slice_data and from_slice_data slices - to_slice_data += size_last_dim * 2; - for (int64_t d = signal_ndim - 1; d >= 0; d--) { - // Compute the next index at this dimension using conjugate symmetry - // Break out of this loop if nothing carries over - from_slice_indices[d] = (from_slice_indices[d] + 1) % output.size(d); - if (d > 0) { - // At d > 0 nonbatch dim, to get next from_slice_data offset - // 1. if this dim idx becomes 1, will need to add (size - 1) * stride - // 2. otherwise, will need to subtract stride - if (from_slice_indices[d] == 0) { - // Subtract. Carries over to previous dimension - from_slice_data -= output.stride(d); - } else if (from_slice_indices[d] == 1) { - // Dimension index becomes 1 - // Doesn't carry over to previous dimension - from_slice_data += (output.size(d) - 1) * output.stride(d); - break; + // We explicitly loop over one row, then use this lambda to iterate over + // n-dimensions. This advances iter_index by one row, while updating in_ptr + // and out_ptr to point to the new row of data. + auto advance_index = [&] () __ubsan_ignore_undefined__ { + for (size_t i = 1; i < iter_index.size(); ++i) { + if (iter_index[i] + 1 < signal_half_sizes[i]) { + ++iter_index[i]; + in_ptr += in_strides[i]; + if (is_mirrored_dim[i]) { + if (iter_index[i] == 1) { + out_ptr += (signal_half_sizes[i] - 1) * out_strides[i]; + } else { + out_ptr -= out_strides[i]; + } } else { - // Subtract. Doesn't carry over to previous dimension - from_slice_data -= output.stride(d); - break; + out_ptr += out_strides[i]; } + return; + } + + in_ptr -= in_strides[i] * iter_index[i]; + if (is_mirrored_dim[i]) { + out_ptr -= out_strides[i]; } else { - // At d = 0 nonbatch dim, it means that to_slice_data ise now at a the - // beginning of a data sample. It maps to itself by conjugate symmetry. - from_slice_data = to_slice_data; + out_ptr -= out_strides[i] * iter_index[i]; + } + iter_index[i] = 0; + } + }; + + // The data slice we operate on may start part-way into the data + // Update iter_index and pointers to reference the start of the slice + if (range.begin > 0) { + iter_index[0] = range.begin % signal_half_sizes[0]; + auto linear_idx = range.begin / signal_half_sizes[0]; + + for (size_t i = 1; i < ndim && linear_idx > 0; ++i) { + iter_index[i] = linear_idx % signal_half_sizes[i]; + linear_idx = linear_idx / signal_half_sizes[i]; + + if (iter_index[i] > 0) { + in_ptr += in_strides[i] * iter_index[i]; + if (is_mirrored_dim[i]) { + out_ptr += out_strides[i] * (signal_half_sizes[i] - iter_index[i]); + } else { + out_ptr += out_strides[i] * iter_index[i]; + } } } - num--; + } + + auto numel_remaining = range.end - range.begin; + + if (is_mirrored_dim[0]) { + // Explicitly loop over a Hermitian mirrored dimension + if (iter_index[0] > 0) { + auto end = std::min(signal_half_sizes[0], iter_index[0] + numel_remaining); + for (int64_t i = iter_index[0]; i < end; ++i) { + out_ptr[(signal_half_sizes[0] - i) * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]); + } + numel_remaining -= (end - iter_index[0]); + iter_index[0] = 0; + advance_index(); + } + + while (numel_remaining > 0) { + auto end = std::min(signal_half_sizes[0], numel_remaining); + out_ptr[0] = std::conj(in_ptr[0]); + for (int64_t i = 1; i < end; ++i) { + out_ptr[(signal_half_sizes[0] - i) * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]); + } + numel_remaining -= end; + advance_index(); + } + } else { + // Explicit loop over a non-mirrored dimension, so just a simple conjugated copy + while (numel_remaining > 0) { + auto end = std::min(signal_half_sizes[0], iter_index[0] + numel_remaining); + for (int64_t i = iter_index[0]; i != end; ++i) { + out_ptr[i * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]); + } + numel_remaining -= (end - iter_index[0]); + iter_index[0] = 0; + advance_index(); + } } } -// input should be a contiguous batched tensor of same size as full (twosided) -// signals, but only contains half (onesided) of the values. -// This function modifies inplace. -static inline void _fft_fill_with_conjugate_symmetry_(Tensor& input, - int64_t signal_ndim, int64_t size_last_dim, - int64_t last_dim_start_slice) { - if (last_dim_start_slice >= size_last_dim) { - return; +static void _fft_fill_with_conjugate_symmetry_cpu_( + ScalarType dtype, IntArrayRef mirror_dims, IntArrayRef signal_half_sizes, + IntArrayRef in_strides_bytes, const void * in_data, + IntArrayRef out_strides_bytes, void * out_data) { + + // Convert strides from bytes to elements + const auto element_size = scalarTypeToTypeMeta(dtype).itemsize(); + const auto ndim = signal_half_sizes.size(); + DimVector in_strides(ndim), out_strides(ndim); + for (int64_t i = 0; i < ndim; ++i) { + TORCH_INTERNAL_ASSERT(in_strides_bytes[i] % element_size == 0); + in_strides[i] = in_strides_bytes[i] / element_size; + TORCH_INTERNAL_ASSERT(out_strides_bytes[i] % element_size == 0); + out_strides[i] = out_strides_bytes[i] / element_size; } - int64_t num = 1; - for (int64_t d = 0; d < signal_ndim; d++) { - num *= input.size(d); + // Construct boolean mask for mirrored dims + c10::SmallVector is_mirrored_dim(ndim, false); + for (const auto& dim : mirror_dims) { + is_mirrored_dim[dim] = true; } - at::parallel_for(0, num, 500, [&](int64_t start, int64_t end) { - AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "_fft_fill_with_conjugate_symmetry", [&] { - _fft_fill_with_conjugate_symmetry_slice(input, signal_ndim, size_last_dim, - last_dim_start_slice, start, (end - start)); - }); + const auto numel = at::prod_intlist(signal_half_sizes); + AT_DISPATCH_COMPLEX_TYPES(dtype, "_fft_fill_with_conjugate_symmetry", [&] { + at::parallel_for(0, numel, at::internal::GRAIN_SIZE, + [&](int64_t begin, int64_t end) { + _fft_fill_with_conjugate_symmetry_slice( + {begin, end}, is_mirrored_dim, signal_half_sizes, + in_strides, static_cast(in_data), + out_strides, static_cast(out_data)); + }); }); } -// MKL DFTI -Tensor _fft_mkl(const Tensor& self, int64_t signal_ndim, - bool complex_input, bool complex_output, - bool inverse, IntArrayRef checked_signal_sizes, - int64_t normalization, bool onesided, - IntArrayRef output_sizes) { - int64_t batch = self.size(0); - Tensor input = self; - // real/imag dimension must aligned when viewed as of complex type - if (complex_input) { - bool need_contiguous = input.stride(-1) != 1; - for (int64_t i = 0; !need_contiguous && i <= signal_ndim; i++) { - need_contiguous |= input.stride(i) % 2 != 0; - } - if (need_contiguous) { - input = input.contiguous(); - } - } +// Register this one implementation for all cpu types instead of compiling multiple times +REGISTER_ARCH_DISPATCH(fft_fill_with_conjugate_symmetry_stub, DEFAULT, &_fft_fill_with_conjugate_symmetry_cpu_) +REGISTER_AVX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) +REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) - // check if we can use MKL because MKL_LONG is 32bit on some OS, e.g. Windows - // need to check input and output size and strides - // be careful about complex domain, where the stride needs to be divided by 2 - // only need to test upper bound MKL_LONG_MAX as these values are non-negative - if (sizeof(MKL_LONG) < sizeof(int64_t)) { - bool need_contiguous = false; - int64_t inumel = 1 /* istride if we contiguous-fy */, onumel = 1; - int64_t isize, osize, istride, ostride; - for (int64_t i = signal_ndim; i >= 0; i--) { - isize = input.size(i); - osize = output_sizes[i]; - istride = complex_input ? input.stride(i) >> 1 : input.stride(i); - ostride = onumel; - TORCH_CHECK(isize <= MKL_LONG_MAX && osize <= MKL_LONG_MAX && ostride <= MKL_LONG_MAX, - "MKL FFT: input signal numel exceeds allowed range [1 ~ ", MKL_LONG_MAX, "]"); - if (!need_contiguous && istride > MKL_LONG_MAX) { - // If we didn't plan to contiguous-fy but the `istride` exceeds bound, - // check if we can stride (equal to `inumel`) get back within bound if - // we contiguous-fy. If so, then we need to always check `inumel` - // instead for the remaining iterations. The iterations before this are - // fine as `inumel` is non-decreasing. - need_contiguous = true; - } - TORCH_CHECK(!need_contiguous || inumel <= MKL_LONG_MAX, - "MKL FFT: input signal numel exceeds allowed range [1 ~ ", MKL_LONG_MAX, "]"); - inumel *= isize; - onumel *= osize; - } - } - Tensor output = at::empty(output_sizes, input.options()); +// Constructs an mkl-fft plan descriptor representing the desired transform +// For complex types, strides are in units of 2 * element_size(dtype) +// sizes are for the full signal, including batch size and always two-sided +static DftiDescriptor _plan_mkl_fft( + IntArrayRef in_strides, IntArrayRef out_strides, IntArrayRef sizes, + bool complex_input, bool complex_output, + int64_t normalization, bool forward, ScalarType dtype) { + const int64_t signal_ndim = sizes.size() - 1; + TORCH_INTERNAL_ASSERT(in_strides.size() == sizes.size()); + TORCH_INTERNAL_ASSERT(out_strides.size() == sizes.size()); // precision - DFTI_CONFIG_VALUE prec; - if (input.scalar_type() == ScalarType::Float) { - prec = DFTI_SINGLE; - } else if (input.scalar_type() == ScalarType::Double) { - prec = DFTI_DOUBLE; - } else { - std::ostringstream ss; - ss << "MKL FFT doesn't support tensor of type: " - << toString(input.scalar_type()); - AT_ERROR(ss.str()); - } + const DFTI_CONFIG_VALUE prec = [&]{ + switch (c10::toValueType(dtype)) { + case ScalarType::Float: return DFTI_SINGLE; + case ScalarType::Double: return DFTI_DOUBLE; + default: TORCH_CHECK(false, "MKL FFT doesn't support tensors of type: ", dtype); + } + }(); // signal type - DFTI_CONFIG_VALUE signal_type; - if (!inverse) { - signal_type = complex_input ? DFTI_COMPLEX : DFTI_REAL; - } else { - signal_type = complex_output ? DFTI_COMPLEX : DFTI_REAL; - } + const DFTI_CONFIG_VALUE signal_type = [&]{ + if (forward) { + return complex_input ? DFTI_COMPLEX : DFTI_REAL; + } else { + return complex_output ? DFTI_COMPLEX : DFTI_REAL; + } + }(); // create descriptor with signal size - std::vector mkl_signal_sizes(checked_signal_sizes.begin(), checked_signal_sizes.end()); + using MklDimVector = c10::SmallVector; + MklDimVector mkl_signal_sizes(sizes.begin() + 1, sizes.end()); DftiDescriptor descriptor; descriptor.init(prec, signal_type, signal_ndim, mkl_signal_sizes.data()); // out of place FFT MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE)); // batch mode - MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch)); + MKL_LONG mkl_batch_size = sizes[0]; + MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, mkl_batch_size)); - auto istrides = input.strides(); - auto ostrides = output.strides(); // batch dim stride, i.e., dist between each data - MKL_LONG idist = complex_input ? istrides[0] >> 1 : istrides[0]; - MKL_LONG odist = complex_output ? ostrides[0] >> 1 : ostrides[0]; + TORCH_CHECK(in_strides[0] <= MKL_LONG_MAX && out_strides[0] <= MKL_LONG_MAX); + MKL_LONG idist = in_strides[0]; + MKL_LONG odist = out_strides[0]; MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist)); MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_DISTANCE, odist)); + // signal strides // first val is offset, set to zero (ignored) - std::vector mkl_istrides(1 + signal_ndim, 0), mkl_ostrides(1 + signal_ndim, 0); + MklDimVector mkl_istrides(1 + signal_ndim, 0), mkl_ostrides(1 + signal_ndim, 0); for (int64_t i = 1; i <= signal_ndim; i++) { - mkl_istrides[i] = complex_input ? istrides[i] >> 1 : istrides[i]; - mkl_ostrides[i] = complex_output ? ostrides[i] >> 1 : ostrides[i]; + TORCH_CHECK(in_strides[i] <= MKL_LONG_MAX && out_strides[i] <= MKL_LONG_MAX); + mkl_istrides[i] = in_strides[i]; + mkl_ostrides[i] = out_strides[i]; } MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_STRIDES, mkl_istrides.data())); MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_ostrides.data())); @@ -245,33 +267,195 @@ Tensor _fft_mkl(const Tensor& self, int64_t signal_ndim, } // rescale if requested const auto norm = static_cast(normalization); + int64_t signal_numel = at::prod_intlist(IntArrayRef(sizes.data() + 1, signal_ndim)); if (norm != fft_norm_mode::none) { - auto signal_numel = at::prod_intlist(checked_signal_sizes); - double double_scale; - if (norm == fft_norm_mode::by_root_n) { - double_scale = 1.0 / std::sqrt(static_cast(signal_numel)); - } else { - double_scale = 1.0 / static_cast(signal_numel); - } - MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), - inverse ? DFTI_BACKWARD_SCALE : DFTI_FORWARD_SCALE, - prec == DFTI_DOUBLE ? double_scale : static_cast(double_scale))); + const double scale = ( + (norm == fft_norm_mode::by_root_n) ? + 1.0 / std::sqrt(static_cast(signal_numel)) : + 1.0 / static_cast(signal_numel)); + const auto scale_direction = forward ? DFTI_FORWARD_SCALE : DFTI_BACKWARD_SCALE; + MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), scale_direction, scale)); } + + if (sizeof(MKL_LONG) < sizeof(int64_t)) { + TORCH_CHECK(signal_numel <= MKL_LONG_MAX, + "MKL FFT: input signal numel exceeds allowed range [1, ", MKL_LONG_MAX, "]"); + } + // finalize MKL_DFTI_CHECK(DftiCommitDescriptor(descriptor.get())); - // run - if (!inverse) { - MKL_DFTI_CHECK(DftiComputeForward(descriptor.get(), input.data_ptr(), output.data_ptr())); + + return descriptor; +} + +// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r) +static Tensor& _exec_fft(Tensor& out, const Tensor& self, IntArrayRef out_sizes, + IntArrayRef dim, int64_t normalization, bool forward) { + const auto ndim = self.dim(); + const int64_t signal_ndim = dim.size(); + const auto batch_dims = ndim - signal_ndim; + + // Permute dimensions so batch dimensions come first, and in stride order + // This maximizes data locality when collapsing to a single batch dimension + DimVector dim_permute(ndim); + std::iota(dim_permute.begin(), dim_permute.end(), int64_t{0}); + + c10::SmallVector is_transformed_dim(ndim); + for (const auto& d : dim) { + is_transformed_dim[d] = true; + } + auto batch_end = std::partition(dim_permute.begin(), dim_permute.end(), + [&](int64_t d) {return !is_transformed_dim[d]; }); + auto self_strides = self.strides(); + std::sort(dim_permute.begin(), batch_end, + [&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; }); + std::copy(dim.cbegin(), dim.cend(), batch_end); + auto input = self.permute(dim_permute); + + // Collapse batch dimensions into a single dimension + DimVector batched_sizes(signal_ndim + 1); + batched_sizes[0] = -1; + std::copy(input.sizes().cbegin() + batch_dims, input.sizes().cend(), batched_sizes.begin() + 1); + input = input.reshape(batched_sizes); + + const auto batch_size = input.sizes()[0]; + DimVector signal_size(signal_ndim + 1); + signal_size[0] = batch_size; + for (int64_t i = 0; i < signal_ndim; ++i) { + auto in_size = input.sizes()[i + 1]; + auto out_size = out_sizes[dim[i]]; + signal_size[i + 1] = std::max(in_size, out_size); + TORCH_INTERNAL_ASSERT(in_size == signal_size[i + 1] || + in_size == (signal_size[i + 1] / 2) + 1); + TORCH_INTERNAL_ASSERT(out_size == signal_size[i + 1] || + out_size == (signal_size[i + 1] / 2) + 1); + } + + batched_sizes[0] = batch_size; + DimVector batched_out_sizes(batched_sizes.begin(), batched_sizes.end()); + for (size_t i = 0; i < dim.size(); ++i) { + batched_out_sizes[i + 1] = out_sizes[dim[i]]; + } + + const auto value_type = c10::toValueType(input.scalar_type()); + out.resize_(batched_out_sizes, MemoryFormat::Contiguous); + + auto descriptor = _plan_mkl_fft( + input.strides(), out.strides(), signal_size, input.is_complex(), + out.is_complex(), normalization, forward, value_type); + + // run the FFT + if (forward) { + MKL_DFTI_CHECK(DftiComputeForward(descriptor.get(), input.data_ptr(), out.data_ptr())); } else { - MKL_DFTI_CHECK(DftiComputeBackward(descriptor.get(), input.data_ptr(), output.data_ptr())); + MKL_DFTI_CHECK(DftiComputeBackward(descriptor.get(), input.data_ptr(), out.data_ptr())); } - // now if needed, fill out the other half using Hermitian symmetry dim - if (!complex_input && complex_output && !onesided) { - auto size_last_signal_dim = checked_signal_sizes[signal_ndim - 1]; - auto start_slice = infer_ft_real_to_complex_onesided_size(size_last_signal_dim); - _fft_fill_with_conjugate_symmetry_(output, signal_ndim, size_last_signal_dim, start_slice); + + // Inplace reshaping to original batch shape and inverting the dimension permutation + DimVector out_strides(ndim); + int64_t batch_numel = 1; + for (int64_t i = batch_dims - 1; i >= 0; --i) { + out_strides[dim_permute[i]] = batch_numel * out.strides()[0]; + batch_numel *= out_sizes[dim_permute[i]]; + } + for (int64_t i = batch_dims; i < ndim; ++i) { + out_strides[dim_permute[i]] = out.strides()[1 + (i - batch_dims)]; + } + return out.as_strided_(out_sizes, out_strides, out.storage_offset()); +} + +// Sort transform dimensions by input layout, for best performance +// exclude_last is for onesided transforms where the last dimension cannot be reordered +static DimVector _sort_dims(const Tensor& self, IntArrayRef dim, bool exclude_last=false) { + DimVector sorted_dims(dim.begin(), dim.end()); + auto self_strides = self.strides(); + std::sort(sorted_dims.begin(), sorted_dims.end() - exclude_last, + [&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; }); + return sorted_dims; +} + +// n-dimensional complex to real IFFT +Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) { + TORCH_CHECK(self.is_complex()); + // NOTE: Multi-dimensional C2R transforms don't agree with numpy in cases + // where the input isn't strictly Hermitian-symmetric. Instead, we use a + // multi-dim C2C transform followed by a 1D C2R transform. + // + // Such inputs are technically out of contract though, so maybe a disagreement + // is okay. + auto input = self; + if (dim.size() > 1) { + auto c2c_dims = dim.slice(0, dim.size() - 1); + input = _fft_c2c_mkl(self, c2c_dims, normalization, /*foward=*/false); + dim = dim.slice(dim.size() - 1); + } + + auto in_sizes = input.sizes(); + DimVector out_sizes(in_sizes.begin(), in_sizes.end()); + out_sizes[dim.back()] = last_dim_size; + auto out = at::empty(out_sizes, self.options().dtype(c10::toValueType(self.scalar_type()))); + return _exec_fft(out, input, out_sizes, dim, normalization, /*forward=*/false); +} + +Tensor& _fft_c2r_mkl_out(Tensor& out, const Tensor& self, IntArrayRef dim, int64_t normalization, + int64_t last_dim_size) { + auto result = _fft_c2r_mkl(self, dim, normalization, last_dim_size); + resize_output(out, result.sizes()); + return out.copy_(result); +} + +// n-dimensional real to complex FFT +Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) { + TORCH_CHECK(self.is_floating_point()); + auto input_sizes = self.sizes(); + DimVector out_sizes(input_sizes.begin(), input_sizes.end()); + auto last_dim = dim.back(); + auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1; + if (onesided) { + out_sizes[last_dim] = last_dim_halfsize; } - return output; + + auto sorted_dims = _sort_dims(self, dim, /*exclude_last=*/true); + auto out = at::empty(out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type()))); + _exec_fft(out, self, out_sizes, sorted_dims, normalization, /*forward=*/true); + + if (!onesided) { + at::native::_fft_fill_with_conjugate_symmetry_(out, dim); + } + return out; +} + +Tensor& _fft_r2c_mkl_out(Tensor& out, const Tensor& self, IntArrayRef dim, int64_t normalization, + bool onesided) { + auto result = _fft_r2c_mkl(self, dim, normalization, /*onesided=*/true); + if (onesided) { + resize_output(out, result.sizes()); + return out.copy_(result); + } + + resize_output(out, self.sizes()); + + auto last_dim = dim.back(); + auto last_dim_halfsize = result.sizes()[last_dim]; + auto out_slice = out.slice(last_dim, 0, last_dim_halfsize); + out_slice.copy_(result); + at::native::_fft_fill_with_conjugate_symmetry_(out, dim); + return out; +} + +// n-dimensional complex to complex FFT/IFFT +Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) { + TORCH_CHECK(self.is_complex()); + const auto sorted_dims = _sort_dims(self, dim); + auto out = at::empty(self.sizes(), self.options()); + return _exec_fft(out, self, self.sizes(), sorted_dims, normalization, forward); +} + +Tensor& _fft_c2c_mkl_out(Tensor& out, const Tensor& self, IntArrayRef dim, int64_t normalization, + bool forward) { + auto result = _fft_c2c_mkl(self, dim, normalization, forward); + resize_output(out, result.sizes()); + return out.copy_(result); } }} // namespace at::native diff --git a/aten/src/ATen/native/mkldnn/BinaryOps.cpp b/aten/src/ATen/native/mkldnn/BinaryOps.cpp index 3364fe8b335c6..3358079f4df57 100644 --- a/aten/src/ATen/native/mkldnn/BinaryOps.cpp +++ b/aten/src/ATen/native/mkldnn/BinaryOps.cpp @@ -8,10 +8,11 @@ namespace at { namespace native { Tensor& mkldnn_add_out( - Tensor& result, const Tensor& self, const Tensor& other, - Scalar alpha) { + Scalar alpha, + Tensor& result + ) { TORCH_CHECK(false, "mkldnn_add_out: ATen not compiled with MKLDNN support"); } @@ -46,10 +47,11 @@ namespace at { namespace native { Tensor& mkldnn_add_out( - Tensor& result, const Tensor& self, const Tensor& other, - Scalar alpha) { + Scalar alpha, + Tensor& result + ) { ideep::tensor& x = itensor_from_mkldnn(self); ideep::tensor& y = itensor_from_mkldnn(other); @@ -68,11 +70,12 @@ Tensor mkldnn_add(const Tensor& self, const Tensor& other, Scalar alpha) { const std::vector scales{1.0, alpha.to()}; ideep::sum::compute(scales, {x, y}, z); - return new_with_itensor_mkldnn(std::move(z), self.options()); + return new_with_itensor_mkldnn(std::move(z), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()); } Tensor& mkldnn_add_(Tensor& self, const Tensor& other, Scalar alpha) { - return native::mkldnn_add_out(self, self, other, alpha); + return native::mkldnn_add_out(self, other, alpha, self); } Tensor& mkldnn_mul_out(Tensor& result, const Tensor& self, const Tensor& other) { @@ -99,7 +102,9 @@ Tensor& mkldnn_mul_out(Tensor& result, const Tensor& self, const Tensor& other) } Tensor mkldnn_mul(const Tensor& self, const Tensor& other) { - Tensor result = empty_mkldnn(self.sizes(), self.options()); + Tensor result = empty_mkldnn(self.sizes(), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().layout_opt(), self.options().device_opt(), + self.options().pinned_memory_opt()); return native::mkldnn_mul_out(result, self, other); } diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp index 664f7bbd8f1e4..3b27752bdc1d2 100644 --- a/aten/src/ATen/native/mkldnn/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/Conv.cpp @@ -38,21 +38,6 @@ std::tuple mkldnn_convolution_backward( #include #include -namespace { -// Helper function for getting an ideep tensor out of an aten Tensor. -// Note in case the aten Tensor is a dense tensor, the returned ideep -// tensor is just a view of the storage of the aten dense tensor, so -// caller needs to make sure the aten dense tensor's lifetime is -// longer than the ideep tensor. -inline ideep::tensor get_mkldnn_tensor(const at::Tensor& tensor) { - if (tensor.is_mkldnn()) { - return at::native::itensor_from_mkldnn(tensor); - } else { - return at::native::itensor_view_from_dense(tensor); - } -} -} - namespace at { namespace native { ideep::tensor _mkldnn_convolution( @@ -106,11 +91,11 @@ Tensor mkldnn_convolution( IntArrayRef stride, IntArrayRef dilation, int64_t groups) { - const ideep::tensor mkldnn_input = get_mkldnn_tensor(input); - const ideep::tensor mkldnn_weight = get_mkldnn_tensor(weight); + const ideep::tensor mkldnn_input = itensor_from_tensor(input); + const ideep::tensor mkldnn_weight = itensor_from_tensor(weight); c10::optional mkldnn_bias{c10::nullopt}; if (bias.defined()) { - mkldnn_bias = get_mkldnn_tensor(bias); + mkldnn_bias = itensor_from_tensor(bias); } ideep::tensor mkldnn_output = _mkldnn_convolution( @@ -123,10 +108,12 @@ Tensor mkldnn_convolution( groups); if (input.is_mkldnn()) { - return new_with_itensor_mkldnn(std::move(mkldnn_output), input.options()); + return new_with_itensor_mkldnn(std::move(mkldnn_output), optTypeMetaToScalarType(input.options().dtype_opt()), + input.options().device_opt()); } else { return mkldnn_to_dense( - new_with_itensor_mkldnn(std::move(mkldnn_output), input.options())); + new_with_itensor_mkldnn(std::move(mkldnn_output), optTypeMetaToScalarType(input.options().dtype_opt()), + input.options().device_opt())); } } @@ -134,8 +121,8 @@ Tensor mkldnn_convolution_backward_input( IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) { - auto mkldnn_grad_output = get_mkldnn_tensor(grad_output); - auto mkldnn_weight = get_mkldnn_tensor(weight); + auto mkldnn_grad_output = itensor_from_tensor(grad_output); + auto mkldnn_weight = itensor_from_tensor(weight); ideep::tensor mkldnn_grad_input; ideep::convolution_backward_data::compute( @@ -150,15 +137,16 @@ Tensor mkldnn_convolution_backward_input( groups); return mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_input), - grad_output.options())); + optTypeMetaToScalarType(grad_output.options().dtype_opt()), + grad_output.options().device_opt())); } std::tuple mkldnn_convolution_backward_weights( IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) { - const ideep::tensor mkldnn_grad_output = get_mkldnn_tensor(grad_output); - const ideep::tensor mkldnn_input = get_mkldnn_tensor(input); + const ideep::tensor mkldnn_grad_output = itensor_from_tensor(grad_output); + const ideep::tensor mkldnn_input = itensor_from_tensor(input); ideep::tensor mkldnn_grad_weight, mkldnn_grad_bias; if (bias_defined) { @@ -188,9 +176,11 @@ std::tuple mkldnn_convolution_backward_weights( return std::make_tuple( mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_weight), - grad_output.options())), + optTypeMetaToScalarType(grad_output.options().dtype_opt()), + grad_output.options().device_opt())), mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_bias), - grad_output.options()))); + optTypeMetaToScalarType(grad_output.options().dtype_opt()), + grad_output.options().device_opt()))); } std::tuple mkldnn_convolution_backward( diff --git a/aten/src/ATen/native/mkldnn/Linear.cpp b/aten/src/ATen/native/mkldnn/Linear.cpp index bcc9b786b869b..21d240ef52799 100644 --- a/aten/src/ATen/native/mkldnn/Linear.cpp +++ b/aten/src/ATen/native/mkldnn/Linear.cpp @@ -54,9 +54,11 @@ Tensor mkldnn_linear( output_size.push_back(weight.size(0)); if (self.dim() > 2) { - return new_with_itensor_mkldnn(std::move(y), self.options()).reshape(output_size); + return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()).reshape(output_size); } - return new_with_itensor_mkldnn(std::move(y), self.options()); + return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()); } } // namespace native diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp index a581740349679..ce397aabc3a32 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp +++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp @@ -21,7 +21,7 @@ namespace at { namespace native { * NOTE: if this is generally useful we may want to move this to its own header. */ template -struct CAFFE2_API IntrusivePtrTargetWrapper : c10::intrusive_ptr_target { +struct TORCH_API IntrusivePtrTargetWrapper : c10::intrusive_ptr_target { private: T target_; @@ -40,14 +40,34 @@ using IDeepTensorWrapperPtr = c10::intrusive_ptr; using MKLDNNTensorImpl = OpaqueTensorImpl; using MKLDNNTensor = Tensor; -Tensor new_with_itensor_mkldnn(ideep::tensor&& it, const TensorOptions& options) { +ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) { + switch (type) { + case ScalarType::Float: + return ideep::tensor::data_type::f32; + case ScalarType::QInt32: + return ideep::tensor::data_type::s32; + case ScalarType::QInt8: + return ideep::tensor::data_type::s8; + case ScalarType::QUInt8: + case ScalarType::Byte: + return ideep::tensor::data_type::u8; + case ScalarType::BFloat16: + return ideep::tensor::data_type::bf16; + default: + TORCH_CHECK(false, "get_mkldnn_dtype: unsupported data type"); + } +} + +Tensor new_with_itensor_mkldnn(ideep::tensor&& it, c10::optional dtype, c10::optional device) { // NOTE: int32_t dims from ideep::tensor but sizes needs int64_t // TODO: support int64_t dims in ideep::tensor to avoid extra conversion auto dims = it.get_dims(); IDeepTensorWrapperPtr handle = c10::make_intrusive(std::move(it)); + caffe2::TypeMeta dtype_ = scalarTypeToTypeMeta(dtype_or_default(dtype)); + Device device_ = device_or_default(device); return detail::make_tensor( DispatchKeySet(DispatchKey::MkldnnCPU), - options.dtype(), options.device(), handle, + dtype_, device_, handle, std::vector(dims.begin(), dims.end())); } @@ -73,6 +93,20 @@ ideep::tensor itensor_view_from_dense(const Tensor& tensor) { ideep::tensor::data_type::f32}, tensor.template data_ptr()}; } + +// Helper function for getting an ideep tensor out of an aten Tensor. +// Note in case the aten Tensor is a dense tensor, the returned ideep +// tensor is just a view of the storage of the aten dense tensor, so +// caller needs to make sure the aten dense tensor's lifetime is +// longer than the ideep tensor. +ideep::tensor itensor_from_tensor(const Tensor& tensor) { + if (tensor.is_mkldnn()) { + return itensor_from_mkldnn(tensor); + } else { + return itensor_view_from_dense(tensor); + } +} + }} #endif // AT_MKLDNN_ENABLED() diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.h b/aten/src/ATen/native/mkldnn/MKLDNNCommon.h index 0167b8183d469..e6c10b94ced34 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.h +++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.h @@ -8,8 +8,11 @@ namespace at { namespace native { +// Mapping ScalarType to ideep tensor data_type +ideep::tensor::data_type get_mkldnn_dtype(ScalarType type); + // Construct aten MKL-DNN tensor given an ideep tensor -Tensor new_with_itensor_mkldnn(ideep::tensor&& it, const TensorOptions& options); +Tensor new_with_itensor_mkldnn(ideep::tensor&& it, c10::optional dtype, c10::optional device); // Retrieve `ideep::tensor` from MKL-DNN tensor ideep::tensor& itensor_from_mkldnn(const Tensor& mkldnn_tensor); @@ -17,6 +20,10 @@ ideep::tensor& itensor_from_mkldnn(const Tensor& mkldnn_tensor); // Construct an `ideep::tensor` "view" from dense tensor, note the // ideep::tensor will share the underlying buffer ideep::tensor itensor_view_from_dense(const Tensor& tensor); + +// Helper function for getting an ideep tensor out of an aten Tensor or MKL-DNN tensor. +ideep::tensor itensor_from_tensor(const Tensor& tensor); + }} #endif // AT_MKLDNN_ENABLED diff --git a/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp b/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp index 971fa7a3af2fb..d5c062f0050c6 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp +++ b/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp @@ -8,35 +8,58 @@ namespace at { namespace native { #if AT_MKLDNN_ENABLED() -Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor) { +Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional dtype) { + TORCH_CHECK(mkldnn_tensor.scalar_type() == ScalarType::Float || + mkldnn_tensor.scalar_type() == ScalarType::BFloat16, + "mkldnn_to_dense expects float or bfloat16 tensor input"); ideep::tensor& stensor = itensor_from_mkldnn(mkldnn_tensor); auto dims = stensor.get_dims(); + auto data_type = dtype.has_value() ? dtype.value() : mkldnn_tensor.scalar_type(); + TORCH_CHECK(data_type == ScalarType::Float || data_type == ScalarType::BFloat16, + "mkldnn tensor only can be converted to be a float or bfloat16 cpu tensor") // NOTE: int32_t dims from ideep::tensor but sizes needs int64_t Tensor cpu_tensor = at::empty( std::vector(dims.begin(), dims.end()), - mkldnn_tensor.options().layout(c10::kStrided)); + mkldnn_tensor.options().layout(c10::kStrided).dtype(data_type)); if (stensor.is_empty()) return cpu_tensor; - auto pub_tensor = stensor.to_public(cpu_tensor.template data_ptr()); + auto pub_tensor = + data_type == ScalarType::Float + ? stensor.to_public(cpu_tensor.template data_ptr(), + ideep::tensor::data_type::f32) + : stensor.to_public(cpu_tensor.template data_ptr(), + ideep::tensor::data_type::bf16); cpu_tensor.as_strided_(dims, pub_tensor.get_strides()); return cpu_tensor; } -Tensor dense_to_mkldnn(const Tensor& cpu_tensor) { +Tensor dense_to_mkldnn(const Tensor& cpu_tensor, c10::optional dtype) { TORCH_CHECK(cpu_tensor.device().type() == DeviceType::CPU, "dense_to_mkldnn expects CPU tensor input"); TORCH_CHECK(cpu_tensor.layout() == Layout::Strided, "dense_to_mkldnn expects strided tensor input"); - TORCH_CHECK(cpu_tensor.scalar_type() == ScalarType::Float, - "dense_to_mkldnn expects float tensor input"); + TORCH_CHECK(cpu_tensor.scalar_type() == ScalarType::Float || + cpu_tensor.scalar_type() == ScalarType::BFloat16, + "dense_to_mkldnn expects float or bfloat16 tensor input"); TORCH_CHECK(cpu_tensor.dim() <= 5, "Can't convert cpu tensor with the number of dimensions > 5"); // TODO: consider to convert non-contiguous tensor to `ideep::tensor` directly. auto cpu_tensor_cont = cpu_tensor.contiguous(); - Tensor mkldnn_tensor = empty_mkldnn(cpu_tensor_cont.sizes(), cpu_tensor_cont.options()); + auto data_type = dtype.has_value() ? dtype.value() : cpu_tensor.scalar_type(); + TORCH_CHECK(data_type == ScalarType::Float || data_type == ScalarType::BFloat16, + "cpu tensor only can be converted to be a float or bfloat16 mkldnn tensor") + Tensor mkldnn_tensor = empty_mkldnn(cpu_tensor_cont.sizes(), data_type, + cpu_tensor_cont.options().layout_opt(), cpu_tensor_cont.options().device_opt(), + cpu_tensor_cont.options().pinned_memory_opt()); ideep::tensor& dtensor = itensor_from_mkldnn(mkldnn_tensor); - dtensor.feed_from(dtensor.get_dims(), - ideep::tensor::data_type::f32, - (cpu_tensor_cont.template data_ptr())); + if (cpu_tensor.scalar_type() == ScalarType::Float) { + dtensor.feed_from(dtensor.get_dims(), + ideep::tensor::data_type::f32, + (cpu_tensor_cont.template data_ptr())); + } else { + dtensor.feed_from(dtensor.get_dims(), + ideep::tensor::data_type::bf16, + cpu_tensor_cont.template data_ptr()); + } return mkldnn_tensor; } @@ -79,7 +102,8 @@ Tensor mkldnn_reorder_conv2d_weight( result.init(desc); result.feed_from(w); - return new_with_itensor_mkldnn(std::move(result), self.options()); + return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()); } Tensor mkldnn_reorder_conv3d_weight( @@ -105,16 +129,16 @@ Tensor mkldnn_reorder_conv3d_weight( result.init(desc); result.feed_from(w); - return new_with_itensor_mkldnn(std::move(result), self.options()); + return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()), self.options().device_opt()); } #else -Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor) { +Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional dtype) { TORCH_CHECK(false, "MKL-DNN build is disabled"); } -Tensor dense_to_mkldnn(const Tensor& cpu_tensor) { +Tensor dense_to_mkldnn(const Tensor& cpu_tensor, c10::optional dtype) { TORCH_CHECK(false, "MKL-DNN build is disabled"); } diff --git a/aten/src/ATen/native/mkldnn/Normalization.cpp b/aten/src/ATen/native/mkldnn/Normalization.cpp index 86d9d0643a27c..ca331392acd8c 100644 --- a/aten/src/ATen/native/mkldnn/Normalization.cpp +++ b/aten/src/ATen/native/mkldnn/Normalization.cpp @@ -56,18 +56,24 @@ std::tuple mkldnn_batch_norm( // ideep::batch_normalization_forward_training::compute( // x, w, b, y, saved_mean, saved_var, m, v, momentum, eps); // return std::make_tuple( - // new_with_itensor_mkldnn(std::move(y), input.options()), - // new_with_itensor_mkldnn(std::move(saved_mean), input.options()), - // new_with_itensor_mkldnn(std::move(saved_var), input.options())); + // new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()), + // input.options().device_opt()), + // new_with_itensor_mkldnn(std::move(saved_mean), optTypeMetaToScalarType(input.options().dtype_opt()), + // input.options().device_opt()), + // new_with_itensor_mkldnn(std::move(saved_var), optTypeMetaToScalarType(input.options().dtype_opt()), + // input.options().device_opt())); } else { TORCH_CHECK(input.dim() == 4 || input.dim() == 5, "mkldnn_batch_norm: currently mkldnn only support 2d and 3d batchnorm"); ideep::batch_normalization_forward_inference::compute( x, m, v, w, b, y, eps); return std::make_tuple( - new_with_itensor_mkldnn(std::move(y), input.options()), - new_with_itensor_mkldnn(ideep::tensor{}, input.options()), - new_with_itensor_mkldnn(ideep::tensor{}, input.options())); + new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()), + input.options().device_opt()), + new_with_itensor_mkldnn(ideep::tensor{}, optTypeMetaToScalarType(input.options().dtype_opt()), + input.options().device_opt()), + new_with_itensor_mkldnn(ideep::tensor{}, optTypeMetaToScalarType(input.options().dtype_opt()), + input.options().device_opt())); } } diff --git a/aten/src/ATen/native/mkldnn/Pooling.cpp b/aten/src/ATen/native/mkldnn/Pooling.cpp index a272bc3d6070b..5f744f494443e 100644 --- a/aten/src/ATen/native/mkldnn/Pooling.cpp +++ b/aten/src/ATen/native/mkldnn/Pooling.cpp @@ -174,7 +174,7 @@ static Tensor _mkldnn_pooling( algo, ideep::prop_kind::forward); - return new_with_itensor_mkldnn(std::move(y), input.options()); + return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()), input.options().device_opt()); } Tensor mkldnn_max_pool2d( diff --git a/aten/src/ATen/native/mkldnn/Relu.cpp b/aten/src/ATen/native/mkldnn/Relu.cpp index 42397255caf00..6915447980bba 100644 --- a/aten/src/ATen/native/mkldnn/Relu.cpp +++ b/aten/src/ATen/native/mkldnn/Relu.cpp @@ -28,7 +28,8 @@ Tensor mkldnn_relu(const Tensor& input) { ideep::tensor y; ideep::eltwise_forward::compute( x, y, ideep::algorithm::eltwise_relu, ideep::prop_kind::forward_training, /*alpha*/ 0.0); - return new_with_itensor_mkldnn(std::move(y), input.options()); + return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()), + input.options().device_opt()); } Tensor& mkldnn_relu_(Tensor& input) { diff --git a/aten/src/ATen/native/mkldnn/SoftMax.cpp b/aten/src/ATen/native/mkldnn/SoftMax.cpp index cdeb6cb859719..861cca0aae53c 100644 --- a/aten/src/ATen/native/mkldnn/SoftMax.cpp +++ b/aten/src/ATen/native/mkldnn/SoftMax.cpp @@ -35,7 +35,8 @@ Tensor mkldnn_softmax( ideep::tensor& x = itensor_from_mkldnn(self); ideep::tensor y; ideep::softmax_forward::compute(x, y, wrapped_dim); - return new_with_itensor_mkldnn(std::move(y), self.options()); + return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()); } } // namespace native diff --git a/aten/src/ATen/native/mkldnn/TensorFactories.cpp b/aten/src/ATen/native/mkldnn/TensorFactories.cpp index 603819ed32873..dc34281d25cac 100644 --- a/aten/src/ATen/native/mkldnn/TensorFactories.cpp +++ b/aten/src/ATen/native/mkldnn/TensorFactories.cpp @@ -4,23 +4,21 @@ namespace at { namespace native { #if AT_MKLDNN_ENABLED() -Tensor empty_mkldnn(IntArrayRef sizes, const TensorOptions& options, c10::optional optional_memory_format) { - TORCH_CHECK( - !options.has_memory_format(), - "'memory_format' argument is incompatible with mkldnn tensor"); +Tensor empty_mkldnn(IntArrayRef sizes, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional optional_memory_format) { TORCH_CHECK( !optional_memory_format.has_value(), "'memory_format' argument is incompatible with mkldnn tensor"); // NOTE: int32_t dims from ideep::tensor but sizes needs int64_t // TODO: support int64_t dims in ideep::tensor to avoid extra conversion ideep::tensor::dims dst_dims (sizes.begin(), sizes.end()); - ideep::tensor it {dst_dims, ideep::tensor::data_type::f32}; - return new_with_itensor_mkldnn(std::move(it), options); + auto data_type = dtype.has_value() ? get_mkldnn_dtype(dtype.value()) : ideep::tensor::data_type::f32; + ideep::tensor it {dst_dims, data_type}; + return new_with_itensor_mkldnn(std::move(it), dtype, device); } #else -Tensor empty_mkldnn(IntArrayRef sizes, const TensorOptions& options, c10::optional optional_memory_format) { +Tensor empty_mkldnn(IntArrayRef sizes, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional optional_memory_format) { TORCH_CHECK(false, "empty_mkldnn: MKL-DNN build is disabled"); } diff --git a/aten/src/ATen/native/mkldnn/TensorShape.cpp b/aten/src/ATen/native/mkldnn/TensorShape.cpp index 3229a07e94609..6e31a3a8aa93c 100644 --- a/aten/src/ATen/native/mkldnn/TensorShape.cpp +++ b/aten/src/ATen/native/mkldnn/TensorShape.cpp @@ -51,7 +51,8 @@ Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) { const ideep::tensor& x = itensor_from_mkldnn(self); ideep::tensor y{x}; y.reshape(inferred_size); - return new_with_itensor_mkldnn(std::move(y), self.options()); + return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()); } Tensor mkldnn_clone(const Tensor& self, c10::optional optional_memory_format) { @@ -62,7 +63,8 @@ Tensor mkldnn_clone(const Tensor& self, c10::optional optiona ideep::tensor& src = itensor_from_mkldnn(self); ideep::tensor dst; ideep::direct_copy::compute(src, dst); - return new_with_itensor_mkldnn(std::move(dst), self.options()); + return new_with_itensor_mkldnn(std::move(dst), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()); } Tensor mkldnn_transpose(const Tensor& self, int64_t dim0, int64_t dim1) { @@ -72,7 +74,8 @@ Tensor mkldnn_transpose(const Tensor& self, int64_t dim0, int64_t dim1) { std::iota(axes.begin(), axes.end(), 0); std::swap(axes[dim0], axes[dim1]); y.transpose_from(x, axes); - return new_with_itensor_mkldnn(std::move(y), self.options()); + return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()); } Tensor& mkldnn_transpose_(Tensor& self, int64_t dim0, int64_t dim1) { diff --git a/aten/src/ATen/native/mkldnn/UnaryOps.cpp b/aten/src/ATen/native/mkldnn/UnaryOps.cpp index 4eb02dc483c5c..1434512b5241a 100644 --- a/aten/src/ATen/native/mkldnn/UnaryOps.cpp +++ b/aten/src/ATen/native/mkldnn/UnaryOps.cpp @@ -30,7 +30,8 @@ Tensor mkldnn_sigmoid(const Tensor& self) { ideep::tensor y; ideep::eltwise_forward::compute( x, y, ideep::algorithm::eltwise_logistic, ideep::prop_kind::forward); - return new_with_itensor_mkldnn(std::move(y), self.options()); + return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().device_opt()); } Tensor& mkldnn_sigmoid_(Tensor& self) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 78b6d33303000..792c44bd1c139 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -7,47 +7,39 @@ # DEPRECATED. DO NOT USE - func: _cast_Byte(Tensor self, bool non_blocking=False) -> Tensor - use_c10_dispatcher: full variants: function # DEPRECATED. DO NOT USE - func: _cast_Char(Tensor self, bool non_blocking=False) -> Tensor - use_c10_dispatcher: full variants: function # DEPRECATED. DO NOT USE - func: _cast_Double(Tensor self, bool non_blocking=False) -> Tensor - use_c10_dispatcher: full variants: function # DEPRECATED. DO NOT USE - func: _cast_Float(Tensor self, bool non_blocking=False) -> Tensor - use_c10_dispatcher: full variants: function # DEPRECATED. DO NOT USE - func: _cast_Int(Tensor self, bool non_blocking=False) -> Tensor - use_c10_dispatcher: full variants: function # DEPRECATED. DO NOT USE - func: _cast_Long(Tensor self, bool non_blocking=False) -> Tensor - use_c10_dispatcher: full variants: function # DEPRECATED. DO NOT USE - func: _cast_Short(Tensor self, bool non_blocking=False) -> Tensor - use_c10_dispatcher: full variants: function # DEPRECATED. DO NOT USE - func: _cast_Half(Tensor self, bool non_blocking=False) -> Tensor - use_c10_dispatcher: full variants: function # Computes the gradient of current tensor w.r.t. graph leaves. -- func: backward(Tensor self, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> () - use_c10_dispatcher: full +- func: _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> () + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures manual_kernel_registration: True variants: method @@ -59,18 +51,15 @@ # where Variables *are* Tensors (as opposed to them containing tensors, which # is what the previous interpretation was.) - func: set_data(Tensor(a!) self, Tensor new_data) -> () - use_c10_dispatcher: full manual_kernel_registration: True variants: method - func: data(Tensor self) -> Tensor - use_c10_dispatcher: full manual_kernel_registration: True variants: method # True if this `Variable` is a leaf and thus does not have a `grad_fn`. - func: is_leaf(Tensor self) -> bool - use_c10_dispatcher: full manual_kernel_registration: True variants: method @@ -85,26 +74,36 @@ # assert y2.output_nr == 2 # - func: output_nr(Tensor self) -> int - use_c10_dispatcher: full manual_kernel_registration: True variants: method - func: _version(Tensor self) -> int - use_c10_dispatcher: full manual_kernel_registration: True variants: method - func: requires_grad_(Tensor(a!) self, bool requires_grad=True) -> Tensor(a!) - use_c10_dispatcher: full manual_kernel_registration: True variants: method # Enables .grad attribute for non-leaf Tensors. - func: retain_grad(Tensor(a!) self) -> () - use_c10_dispatcher: full manual_kernel_registration: True variants: method +- func: _fw_primal(Tensor(a) self, int level) -> Tensor(a) + use_c10_dispatcher: full + variants: method + dispatch: + DefaultBackend: _fw_primal + +- func: make_dual(Tensor(a) primal, Tensor tangent, int level) -> Tensor(a) + use_c10_dispatcher: full + variants: function + +- func: unpack_dual(Tensor(a) dual, int level) -> (Tensor(a) primal, Tensor tangent) + use_c10_dispatcher: full + variants: function + - func: rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!) variants: method @@ -118,50 +117,43 @@ variants: method - func: align_as(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method - func: align_tensors(Tensor[] tensors) -> Tensor[] - use_c10_dispatcher: full - func: refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a) variants: method - func: _use_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank) -> bool - use_c10_dispatcher: full dispatch: CUDA: _use_cudnn_ctc_loss - func: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) - use_c10_dispatcher: full dispatch: CUDA: _cudnn_ctc_loss - func: _use_cudnn_rnn_flatten_weight() -> bool - use_c10_dispatcher: full -- func: _cudnn_rnn_flatten_weight(Tensor[] weight_arr, int weight_stride0, int input_size, int mode, int hidden_size, int num_layers, bool batch_first, bool bidirectional) -> Tensor - use_c10_dispatcher: full +- func: _cudnn_rnn_flatten_weight(Tensor[] weight_arr, int weight_stride0, int input_size, int mode, int hidden_size, int proj_size, int num_layers, bool batch_first, bool bidirectional) -> Tensor dispatch: CUDA: _cudnn_rnn_flatten_weight -- func: _cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) - use_c10_dispatcher: full +- func: _cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, int hidden_size, int proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: _cudnn_rnn -- func: _cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) - use_c10_dispatcher: full +- func: _cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: _cudnn_rnn_backward - func: _cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: _cudnn_init_dropout_state - func: _debug_has_internal_overlap(Tensor self) -> int - use_c10_dispatcher: full variants: function - func: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor) @@ -170,62 +162,50 @@ CUDA: fused_dropout_cuda - func: _masked_scale(Tensor self, Tensor mask, float scale) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CUDA: masked_scale_cuda - func: _sobol_engine_draw(Tensor quasi, int n, Tensor sobolstate, int dimension, int num_generated, ScalarType? dtype) -> (Tensor, Tensor) - use_c10_dispatcher: full - func: _sobol_engine_ff_(Tensor(a!) self, int n, Tensor sobolstate, int dimension, int num_generated) -> Tensor(a!) - use_c10_dispatcher: full - func: _sobol_engine_scramble_(Tensor(a!) self, Tensor ltm, int dimension) -> Tensor(a!) - use_c10_dispatcher: full - func: _sobol_engine_initialize_state_(Tensor(a!) self, int dimension) -> Tensor(a!) - use_c10_dispatcher: full - func: _reshape_from_tensor(Tensor self, Tensor shape) -> Tensor - use_c10_dispatcher: full - func: _shape_as_tensor(Tensor self) -> Tensor - use_c10_dispatcher: full - func: dropout(Tensor input, float p, bool train) -> Tensor - use_c10_dispatcher: full - func: dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) - use_c10_dispatcher: full - func: feature_dropout(Tensor input, float p, bool train) -> Tensor - use_c10_dispatcher: full - func: feature_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) - use_c10_dispatcher: full - func: alpha_dropout(Tensor input, float p, bool train) -> Tensor - use_c10_dispatcher: full - func: alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) - use_c10_dispatcher: full - func: feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor - use_c10_dispatcher: full - func: feature_alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) - use_c10_dispatcher: full - func: abs(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: abs - func: abs_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: abs_ - func: abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: abs_out @@ -256,97 +236,103 @@ # Absolute, alias for abs - func: absolute(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: absolute_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: angle(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: angle - func: angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: angle_out - func: view_as_real(Tensor(a) self) -> Tensor(a) - use_c10_dispatcher: full variants: function + dispatch: + CPU, CUDA: view_as_real - func: view_as_complex(Tensor(a) self) -> Tensor(a) - use_c10_dispatcher: full variants: function + dispatch: + CPU, CUDA: view_as_complex - func: sgn(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: sgn - func: sgn_(Tensor(a!) self) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: method + dispatch: + DefaultBackend: sgn_ - func: sgn.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: sgn_out - func: real(Tensor(a) self) -> Tensor(a) - use_c10_dispatcher: full variants: function - func: imag(Tensor(a) self) -> Tensor(a) - use_c10_dispatcher: full variants: function -- func: conj(Tensor self) -> Tensor - use_c10_dispatcher: full +- func: conj(Tensor(a) self) -> Tensor(a) variants: function, method - func: conj.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: conj_out - func: _conj(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function + dispatch: + DefaultBackend: _conj - func: acos(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: acos - func: acos_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: acos_ - func: acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: acos_out # arccos, alias of acos - func: arccos(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: arccos_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method - func: arccos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor - use_c10_dispatcher: full - func: adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor - use_c10_dispatcher: full # Return: (Tensor output, Tensor indices) - func: adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor) - use_c10_dispatcher: full - func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor - use_c10_dispatcher: full + structured_delegate: add.out variants: function, method dispatch: CPU, CUDA: add @@ -354,87 +340,102 @@ MkldnnCPU: mkldnn_add - func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) - use_c10_dispatcher: full variants: method + structured_delegate: add.out dispatch: CPU, CUDA: add_ SparseCPU, SparseCUDA: add_sparse_ MkldnnCPU: mkldnn_add_ - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: full + structured: True + structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: add_out SparseCPU: add_out_sparse_cpu SparseCUDA: add_out_sparse_cuda MkldnnCPU: mkldnn_add_out -- func: add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor - use_c10_dispatcher: full +- func: _add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor variants: function dispatch: CPU: add_relu -- func: add_relu_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) - use_c10_dispatcher: full +- func: _add_relu_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) variants: function dispatch: CPU: add_relu_ -- func: add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +- func: _add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function dispatch: CPU: add_relu_out # For C++ only, until we have conversion from C++ numbers to Tensor - func: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: add - func: add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: add_ - func: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + CPU, CUDA: addmv - func: addmv_(Tensor(a!) self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + CPU, CUDA: addmv_ - func: addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + CPU, CUDA: addmv_out - func: _addmv_impl_(Tensor(a!) self, Tensor self2, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) - use_c10_dispatcher: full dispatch: CPU: addmv_impl_cpu CUDA: addmv_impl_cuda - func: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + CPU, CUDA: addr + Math: math_addr - func: addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: addr_ - func: addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + CPU, CUDA: addr_out + Math: math_addr_out - func: affine_grid_generator(Tensor theta, int[] size, bool align_corners) -> Tensor - use_c10_dispatcher: full variants: function + dispatch: + DefaultBackend: affine_grid_generator - func: affine_grid_generator_backward(Tensor grad, int[] size, bool align_corners) -> Tensor - use_c10_dispatcher: full variants: function - func: all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: all - func: all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: all_out @@ -442,18 +443,18 @@ variants: function, method - func: all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool - use_c10_dispatcher: full variants: function, method - func: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: any - func: any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: any_out @@ -461,19 +462,22 @@ variants: function, method - func: any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: arange_cpu_out CUDA: arange_cuda_out @@ -484,91 +488,103 @@ # preserve tracing. Get rid of this when arange can directly take tensors for bounds # (so that it can be traced directly). - func: _dim_arange(Tensor like, int dim) -> Tensor - use_c10_dispatcher: full - func: argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: argmax +- func: argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + CPU, CUDA: argmax_out + - func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: argmin +- func: argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + CPU, CUDA: argmin_out + - func: acosh(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: acosh - func: acosh_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: acosh_ - func: acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: acosh_out # arccosh, alias for acosh - func: arccosh(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: arccosh_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method - func: arccosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: asinh(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: asinh - func: asinh_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: asinh_ - func: asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: asinh_out # arcsinh, alias for asinh - func: arcsinh(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: arcsinh_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method - func: arcsinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: atanh(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: atanh - func: atanh_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: atanh_ - func: atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: atanh_out # arctanh, alias for atanh - func: arctanh(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: arctanh_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method - func: arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a) - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: as_strided_tensorimpl @@ -576,131 +592,132 @@ device_guard: False - func: as_strided_(Tensor(a!) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function, method device_guard: False + dispatch: + DefaultBackend: as_strided_ - func: asin(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + CPU, CUDA: asin + SparseCPU, SparseCUDA: asin_sparse - func: asin_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: asin_ SparseCPU, SparseCUDA: asin_sparse_ - func: asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: asin_out SparseCPU, SparseCUDA: asin_out_sparse # arcsin, alias of asin - func: arcsin(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: arcsin_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method - func: arcsin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: atan(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: atan - func: atan_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: atan_ - func: atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: atan_out # arctan, alias of atan - func: arctan(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: arctan_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method - func: arctan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: atleast_1d(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function - func: atleast_1d.Sequence(Tensor[] tensors) -> Tensor[] - use_c10_dispatcher: full - func: atleast_2d(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function - func: atleast_2d.Sequence(Tensor[] tensors) -> Tensor[] - use_c10_dispatcher: full variants: function - func: atleast_3d(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function - func: atleast_3d.Sequence(Tensor[] tensors) -> Tensor[] - use_c10_dispatcher: full variants: function - func: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU: baddbmm_cpu CUDA: baddbmm_cuda - func: baddbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU: baddbmm__cpu CUDA: baddbmm__cuda - func: _baddbmm_mkl_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) - use_c10_dispatcher: full variants: function - func: baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function dispatch: CPU: baddbmm_out_cpu CUDA: baddbmm_out_cuda - func: bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: bartlett_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: quantized_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: QuantizedCPU: quantized_batch_norm - func: _batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: _batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures # Sample bernoulli with values in `self` as probability. - func: bernoulli(Tensor self, *, Generator? generator=None) -> Tensor variants: function, method + dispatch: + DefaultBackend: bernoulli - func: bernoulli.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function dispatch: CPU, CUDA: bernoulli_out @@ -722,10 +739,10 @@ variants: function, method - func: bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn variants: function dispatch: @@ -733,6 +750,7 @@ CUDA: binary_cross_entropy_cuda - func: binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn variants: function dispatch: @@ -740,7 +758,7 @@ CUDA: binary_cross_entropy_out_cuda - func: binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn variants: function dispatch: @@ -748,6 +766,7 @@ CUDA: binary_cross_entropy_backward_cuda - func: binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn variants: function dispatch: @@ -755,88 +774,109 @@ CUDA: binary_cross_entropy_backward_out_cuda - func: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function + dispatch: + DefaultBackend: binary_cross_entropy_with_logits - func: binary_cross_entropy_with_logits_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function - func: bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function, method dispatch: CPU: _bincount_cpu CUDA: _bincount_cuda - func: bitwise_not(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: bitwise_not_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: bitwise_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: bitwise_not_out +- func: copysign.Tensor(Tensor self, Tensor other) -> Tensor + variants: function, method + dispatch: + CPU, CUDA: copysign + +- func: copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + variants: method + dispatch: + CPU, CUDA: copysign_ + +- func: copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + CPU, CUDA: copysign_out + +- func: copysign.Scalar(Tensor self, Scalar other) -> Tensor + variants: function, method + dispatch: + CPU, CUDA: copysign + +- func: copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + variants: method + dispatch: + CPU, CUDA: copysign_ + - func: logical_not(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: logical_not_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: logical_not_out - func: logical_xor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: function, method - func: logical_xor_(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: logical_xor.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: logical_xor_out - func: logical_and(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: function, method - func: logical_and_(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: logical_and.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: logical_and_out - func: logical_or(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: function, method - func: logical_or_(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: logical_or.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: logical_or_out - func: blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: blackman_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: bmm(Tensor self, Tensor mat2) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU: bmm_cpu @@ -845,12 +885,12 @@ SparseCUDA: bmm_sparse_cuda - func: _bmm(Tensor self, Tensor mat2, *, bool deterministic=False) -> Tensor - use_c10_dispatcher: full variants: function dispatch: SparseCUDA: _bmm_sparse_cuda - func: bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function dispatch: CPU: bmm_out_cpu @@ -859,433 +899,488 @@ SparseCUDA: bmm_out_sparse_cuda - func: _bmm.out(Tensor self, Tensor mat2, *, bool deterministic=False, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function dispatch: SparseCUDA: _bmm_out_sparse_cuda - func: broadcast_tensors(Tensor[] tensors) -> Tensor[] - use_c10_dispatcher: full device_guard: False -- func: cat(Tensor[] tensors, int dim=0) -> Tensor +- func: broadcast_to(Tensor(a) self, int[] size) -> Tensor(a) use_c10_dispatcher: full + variants: function, method + dispatch: + Math: broadcast_to + +- func: cat(Tensor[] tensors, int dim=0) -> Tensor + dispatch: + DefaultBackend: cat - func: cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: cat_out - func: cat.names(Tensor[] tensors, Dimname dim) -> Tensor - func: cat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: block_diag(Tensor[] tensors) -> Tensor - use_c10_dispatcher: full variants: function - func: ceil(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: ceil - func: ceil_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: ceil_ - func: ceil.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: ceil_out - func: chain_matmul(Tensor[] matrices) -> Tensor - use_c10_dispatcher: full variants: function - func: unsafe_chunk(Tensor self, int chunks, int dim=0) -> Tensor[] - use_c10_dispatcher: full variants: function, method device_guard: False - func: chunk(Tensor(a) self, int chunks, int dim=0) -> Tensor(a)[] - use_c10_dispatcher: full variants: function, method device_guard: False -- func: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor +- func: tensor_split.sections(Tensor(a) self, int sections, int dim=0) -> Tensor(a)[] + variants: function, method + +- func: tensor_split.indices(Tensor(a) self, int[] indices, int dim=0) -> Tensor(a)[] + variants: function, method + +- func: tensor_split.tensor_indices_or_sections(Tensor(a) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[] use_c10_dispatcher: full variants: function, method + +- func: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor + variants: function, method dispatch: CPU, CUDA: clamp QuantizedCPU: clamp_quantized_cpu - func: clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: clamp_ - func: clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: clamp_out - func: clamp_max(Tensor self, Scalar max) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: clamp_max - func: clamp_max_(Tensor(a!) self, Scalar max) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: clamp_max_ - func: clamp_max.out(Tensor self, Scalar max, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: clamp_max_out - func: clamp_min(Tensor self, Scalar min) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: clamp_min - func: clamp_min_(Tensor(a!) self, Scalar min) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: clamp_min_ - func: clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: clamp_min_out # clip is an alias for clamp - func: clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor - use_c10_dispatcher: full variants: function, method - func: clip_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function, method - func: clip.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: cudnn_is_acceptable(Tensor self) -> bool - use_c10_dispatcher: full device_guard: False - func: complex(Tensor real, Tensor imag) -> Tensor - use_c10_dispatcher: full variants: function + dispatch: + DefaultBackend: complex - func: complex.out(Tensor real, Tensor imag, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: complex_out - func: polar(Tensor abs, Tensor angle) -> Tensor - use_c10_dispatcher: full variants: function + dispatch: + DefaultBackend: polar - func: polar.out(Tensor abs, Tensor angle, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: polar_out - func: constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> Tensor - use_c10_dispatcher: full variants: function + dispatch: + DefaultBackend: constant_pad_nd - func: contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a) - use_c10_dispatcher: full variants: method + manual_cpp_binding: True - func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: convolution_overrideable - func: convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) - use_c10_dispatcher: full + dispatch: + DefaultBackend: convolution_backward_overrideable - func: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: _convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: _convolution_nogroup(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: _convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, bool[3] output_mask) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor - use_c10_dispatcher: full + dispatch: + DefaultBackend: conv_tbc - func: conv_tbc_backward(Tensor self, Tensor input, Tensor weight, Tensor bias, int pad) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full # NB: we inherit the goofy argument order from PyTorch torch.nn.functional - func: conv_transpose1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] output_padding=0, int groups=1, int[1] dilation=1) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: conv_transpose3d.input(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int groups=1, int[3] dilation=1) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) - use_c10_dispatcher: full variants: method device_guard: False + dispatch: + DefaultBackend: copy_ - func: _copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor - use_c10_dispatcher: full dispatch: {} - func: cos(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: cos - func: cos_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: cos_ - func: cos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: cos_out - func: cosh(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: cosh - func: cosh_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: cosh_ - func: cosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: cosh_out - func: cosine_embedding_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor - use_c10_dispatcher: full - func: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + CPU, CUDA: count_nonzero - func: count_nonzero(Tensor self, int? dim=None) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: count_nonzero - func: cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid - use_c10_dispatcher: full dispatch: CUDA: cudnn_affine_grid_generator_forward # TODO: Why do I have to call this grad?! - func: cudnn_affine_grid_generator_backward(Tensor grad, int N, int C, int H, int W) -> Tensor grad_theta - use_c10_dispatcher: full dispatch: CUDA: cudnn_affine_grid_generator_backward - func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: cudnn_batch_norm # NB: You can only use this if you used cudnn_batch_norm training=True - func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: cudnn_batch_norm_backward - func: cudnn_convolution.deprecated(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: cudnn_convolution_deprecated - func: cudnn_convolution.deprecated2(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_deprecated2 - func: cudnn_convolution(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor - use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution - func: cudnn_convolution_backward_input(int[] self_size, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor - use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_backward_input - func: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32, bool[2] output_mask) -> (Tensor, Tensor) - use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_backward - func: cudnn_convolution_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor - use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_backward_weight - func: cudnn_convolution_transpose.deprecated(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: cudnn_convolution_transpose_deprecated - func: cudnn_convolution_transpose.deprecated2(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_transpose_deprecated2 - func: cudnn_convolution_transpose(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor - use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_transpose # NB: output_padding not strictly needed here, but it's helpful for the float # backwards - func: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32, bool[2] output_mask) -> (Tensor, Tensor) - use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_transpose_backward - func: cudnn_convolution_transpose_backward_input(Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor - use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_transpose_backward_input - func: cudnn_convolution_transpose_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor - use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_transpose_backward_weight # NB: input is special cased in a way I don't quite understand - func: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output - use_c10_dispatcher: full dispatch: CUDA: cudnn_grid_sampler_forward - func: cudnn_grid_sampler_backward(Tensor self, Tensor grid, Tensor grad_output) -> (Tensor grad_self, Tensor grad_grid) - use_c10_dispatcher: full dispatch: CUDA: cudnn_grid_sampler_backward - func: cummax(Tensor self, int dim) -> (Tensor values, Tensor indices) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: cummax - func: cummax.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: cummax_out - func: cummax.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices) variants: function, method - func: cummax.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: _cummax_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> () + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function dispatch: CPU: cummax_helper_cpu CUDA: cummax_helper_cuda - func: cummin(Tensor self, int dim) -> (Tensor values, Tensor indices) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: cummin - func: cummin.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: cummin_out - func: cummin.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices) variants: function, method - func: cummin.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: _cummin_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> () + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function dispatch: CPU: cummin_helper_cpu CUDA: cummin_helper_cuda - func: cummaxmin_backward(Tensor grad, Tensor input, Tensor indices, int dim) -> Tensor - use_c10_dispatcher: full variants: function device_guard: False - func: cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: cumprod + +- func: cumprod_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!) + variants: method + dispatch: + DefaultBackend: cumprod_ - func: cumprod.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: cumprod_out - func: cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor variants: function, method +- func: cumprod_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!) + variants: method + - func: cumprod.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: cumprod_backward(Tensor grad, Tensor input, int dim) -> Tensor - use_c10_dispatcher: full variants: function device_guard: False - func: cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: cumsum + +- func: cumsum_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!) + variants: method + dispatch: + DefaultBackend: cumsum_ - func: cumsum.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: cumsum_out - func: cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor variants: function, method +- func: cumsum_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!) + variants: method + - func: cumsum.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: ctc_loss.IntList(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor - use_c10_dispatcher: full # convenience function that converts to intlists for you - func: ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor - use_c10_dispatcher: full - func: _ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) - use_c10_dispatcher: full dispatch: CPU: ctc_loss_cpu CUDA: ctc_loss_gpu - func: _ctc_loss_backward(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor - use_c10_dispatcher: full dispatch: CPU: ctc_loss_backward_cpu CUDA: ctc_loss_backward_gpu - func: diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor - use_c10_dispatcher: full variants: function, method - func: diagflat(Tensor self, int offset=0) -> Tensor - use_c10_dispatcher: full variants: function, method - func: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: diagonal - func: diagonal.Dimname(Tensor(a) self, *, Dimname outdim, Dimname dim1, Dimname dim2, int offset=0) -> Tensor(a) variants: function, method - func: diagonal_backward(Tensor grad, int[] input_sizes, int offset, int dim1, int dim2) -> Tensor - use_c10_dispatcher: full variants: function device_guard: False - func: fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: div.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: div SparseCPU, SparseCUDA: div_sparse - func: div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: div_ @@ -1298,92 +1393,88 @@ # For C++ only, until we have conversion from C++ numbers to Tensor - func: div.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: div - func: div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: div_ # divide, alias for div - func: divide.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: function, method - func: divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: divide.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: function, method - func: divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method # true_divide, an alias for div - func: true_divide.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: function, method - func: true_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: true_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: true_divide.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: function, method - func: true_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: dot(Tensor self, Tensor tensor) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU: dot CUDA: dot_cuda - func: dot.out(Tensor self, Tensor tensor, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: dot_out - func: vdot(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU: vdot CUDA: vdot_cuda - func: vdot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: vdot_out - func: einsum(str equation, Tensor[] tensors) -> Tensor - use_c10_dispatcher: full - func: embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor - use_c10_dispatcher: full + dispatch: + DefaultBackend: embedding - func: embedding_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor - use_c10_dispatcher: full - func: embedding_dense_backward(Tensor grad_output, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor - use_c10_dispatcher: full dispatch: CPU: embedding_dense_backward_cpu CUDA: embedding_dense_backward_cuda - func: embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!) - use_c10_dispatcher: full dispatch: CPU: embedding_renorm_cpu_ CUDA: embedding_renorm_cuda_ - func: embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor - use_c10_dispatcher: full # NOTE [ embedding_bag Native Functions ] # The `_embedding_bag.*` variants assume that input tensors except for `weight`, @@ -1396,49 +1487,56 @@ - func: _embedding_bag_forward_only(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: _embedding_bag_forward_only_cpu CUDA: _embedding_bag_forward_only_cuda - func: rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor) - use_c10_dispatcher: full + +# row_stack is the alias of vstack +- func: row_stack(Tensor[] tensors) -> Tensor + dispatch: + Math: row_stack + +- func: row_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + Math: row_stack_out - func: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: _embedding_bag_cpu CUDA: _embedding_bag_cuda - func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: _embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, int num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: _embedding_bag_dense_backward_cpu CUDA: _embedding_bag_dense_backward_cuda - func: _embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode) -> Tensor - use_c10_dispatcher: full dispatch: CPU: _embedding_bag_per_sample_weights_backward_cpu CUDA: _embedding_bag_per_sample_weights_backward_cuda - func: empty_meta(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - #use_c10_dispatcher: full - func: empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - #use_c10_dispatcher: full dispatch: CPU: empty_cpu CUDA: empty_cuda @@ -1446,20 +1544,23 @@ SparseCPU, SparseCUDA: empty_sparse - func: new_empty(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - #use_c10_dispatcher: full + variants: method + +- func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: method - func: new_full(Tensor self, int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: method - func: new_zeros(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: method # other overrides are to provide a more helpful error message that dtype is required - func: _empty_affine_quantized(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: empty_affine_quantized_other_backends_stub QuantizedCPU, QuantizedCUDA: empty_affine_quantized @@ -1467,128 +1568,143 @@ # it's a factory function receiving a tensor argument, thus overriding explicitly # other overrides are to provide a more helpful error message that dtype is required - func: _empty_per_channel_affine_quantized(int[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures category_override: factory dispatch: CPU: empty_per_channel_affine_quantized_other_backends_stub QuantizedCPU, QuantizedCUDA: empty_per_channel_affine_quantized - func: resize_(Tensor(a!) self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!) - use_c10_dispatcher: full variants: method device_guard: False dispatch: CPU: resize_ CUDA: resize_cuda_ QuantizedCPU: quantized_resize_cpu_ + Meta: resize_meta_ - func: empty_quantized(int[] size, Tensor qtensor) -> Tensor - use_c10_dispatcher: full variants: function dispatch: QuantizedCPU, QuantizedCUDA: empty_quantized - func: empty.out(int[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: empty_strided(int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full dispatch: CPU: empty_strided_cpu CUDA: empty_strided_cuda - func: erf(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: erf - func: erf_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: erf_ - func: erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: erf_out - func: erfc(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: erfc - func: erfc_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: erfc_ - func: erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: erfc_out - func: exp(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: exp - func: exp_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: exp_ - func: exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: exp_out - func: exp2(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: exp2 - func: exp2_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: exp2_ - func: exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: exp2_out - func: expm1(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: expm1 - func: expm1_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: expm1_ - func: expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: expm1_out - func: expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a) - use_c10_dispatcher: full variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. device_guard: False + dispatch: + DefaultBackend: expand - func: expand_as(Tensor(a) self, Tensor other) -> Tensor(a) - use_c10_dispatcher: full variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. device_guard: False - func: eye(int n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: eye.m(int n, int m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: eye.out(int n, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: eye_out_cpu CUDA: eye_out_cuda - func: eye.m_out(int n, int m, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: eye_out_cpu CUDA: eye_out_cuda - func: flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a) - use_c10_dispatcher: full variants: function, method - func: flatten.named_out_dim(Tensor(a) self, int start_dim, int end_dim, Dimname out_dim) -> Tensor(a) @@ -1607,102 +1723,107 @@ variants: method - func: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: fill_ - func: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: fill_ - func: floor(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: floor - func: floor_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: floor_ - func: floor.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: floor_out - func: floor_divide(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: floor_divide SparseCPU, SparseCUDA: floor_divide_sparse - func: floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: floor_divide_ SparseCPU, SparseCUDA: floor_divide_sparse_ - func: floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: floor_divide_out SparseCPU, SparseCUDA: floor_divide_out_sparse_zerodim - func: floor_divide.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: function, method - func: floor_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: frac(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: frac - func: frac_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: frac_ - func: frac.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: frac_out - func: full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: full.out(int[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: from_file(str filename, bool? shared=None, int? size=0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: from_file - func: gcd.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: gcd_out - func: gcd(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: function, method - func: gcd_(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method - func: lcm.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: lcm_out - func: lcm(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: function, method - func: lcm_(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method # NOTE [ grid_sampler Native Functions ] @@ -1721,133 +1842,143 @@ # Nor does it take in `align_corners` because it only supports the mode # `align_corners = True`. - func: grid_sampler(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor - use_c10_dispatcher: full - func: grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor - use_c10_dispatcher: full dispatch: CPU: grid_sampler_2d_cpu CUDA: grid_sampler_2d_cuda - func: grid_sampler_2d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> (Tensor, Tensor) - use_c10_dispatcher: full dispatch: CPU: grid_sampler_2d_backward_cpu CUDA: grid_sampler_2d_backward_cuda # See NOTE [ grid_sample CPU fallback ] - func: _grid_sampler_2d_cpu_fallback(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor - use_c10_dispatcher: full + dispatch: + DefaultBackend: _grid_sampler_2d_cpu_fallback - func: _grid_sampler_2d_cpu_fallback_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> (Tensor, Tensor) - use_c10_dispatcher: full - func: grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor - use_c10_dispatcher: full dispatch: CPU: grid_sampler_3d_cpu CUDA: grid_sampler_3d_cuda - func: grid_sampler_3d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> (Tensor, Tensor) - use_c10_dispatcher: full dispatch: CPU: grid_sampler_3d_backward_cpu CUDA: grid_sampler_3d_backward_cuda - func: hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: hann_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: hamming_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: hamming_window.periodic_alpha(int window_length, bool periodic, float alpha, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: hamming_window.periodic_alpha_beta(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: kaiser_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: kaiser_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: kaiser_window.beta(int window_length, bool periodic, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: hinge_embedding_loss(Tensor self, Tensor target, float margin=1.0, int reduction=Mean) -> Tensor - use_c10_dispatcher: full - func: group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, int N, int C, int HxW, int group, float eps) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: native_group_norm Math: math_group_norm - func: native_group_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, int N, int C, int HxW, int group, bool[3] output_mask) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: native_group_norm_backward -- func: ifft(Tensor self, int signal_ndim, bool normalized=False) -> Tensor - use_c10_dispatcher: full - variants: function, method +# Real to complex forward FFT +- func: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor + variants: function + dispatch: + CPU: _fft_r2c_mkl + CUDA: _fft_r2c_cufft -- func: rfft(Tensor self, int signal_ndim, bool normalized=False, bool onesided=True) -> Tensor - use_c10_dispatcher: full - variants: function, method +- func: _fft_r2c.out(Tensor self, int[] dim, int normalization, bool onesided, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + variants: function + dispatch: + CPU: _fft_r2c_mkl_out + CUDA: _fft_r2c_cufft_out -- func: irfft(Tensor self, int signal_ndim, bool normalized=False, bool onesided=True, int[] signal_sizes=[]) -> Tensor - use_c10_dispatcher: full - variants: function, method +# Complex to real inverse FFT +- func: _fft_c2r(Tensor self, int[] dim, int normalization, int last_dim_size) -> Tensor + variants: function + dispatch: + CPU: _fft_c2r_mkl + CUDA: _fft_c2r_cufft -- func: _fft_with_size(Tensor self, int signal_ndim, bool complex_input, bool complex_output, bool inverse, int[] checked_signal_sizes, bool normalized, bool onesided, int[] output_sizes) -> Tensor - use_c10_dispatcher: full +- func: _fft_c2r.out(Tensor self, int[] dim, int normalization, int last_dim_size, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function + dispatch: + CPU: _fft_c2r_mkl_out + CUDA: _fft_c2r_cufft_out -- func: _fft_with_size.norm_modes(Tensor self, int signal_ndim, bool complex_input, bool complex_output, bool inverse, int[] checked_signal_sizes, int normalization, bool onesided, int[] output_sizes) -> Tensor - use_c10_dispatcher: full +# Standard complex to complex FFT (forward or backward) +- func: _fft_c2c(Tensor self, int[] dim, int normalization, bool forward) -> Tensor + variants: function + dispatch: + CPU: _fft_c2c_mkl + CUDA: _fft_c2c_cufft + +- func: _fft_c2c.out(Tensor self, int[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function dispatch: - CPU: _fft_mkl - CUDA: _fft_cufft + CPU: _fft_c2c_mkl_out + CUDA: _fft_c2c_cufft_out - func: _cufft_get_plan_cache_size(int device_index) -> int - use_c10_dispatcher: full - func: _cufft_get_plan_cache_max_size(int device_index) -> int - use_c10_dispatcher: full - func: _cufft_set_plan_cache_max_size(int device_index, int max_size) -> () - use_c10_dispatcher: full - func: _cufft_clear_plan_cache(int device_index) -> () - use_c10_dispatcher: full - func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor variants: function, method dispatch: CPU, CUDA: index + QuantizedCPU: quantized_index # NB: This function is special-cased in tools/autograd/gen_variable_type.py # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp: # - Tensor Tensor::index(ArrayRef indices) # - Tensor Tensor::index(std::initializer_list indices) - func: index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: index_copy_ - func: index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor - use_c10_dispatcher: full variants: function, method - func: index_copy_.dimname(Tensor(a!) self, Dimname dim, Tensor index, Tensor source) -> Tensor(a!) @@ -1858,6 +1989,8 @@ - func: index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!) variants: function, method + dispatch: + DefaultBackend: index_put_ # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp: # - Tensor & Tensor::index_put_(ArrayRef indices, Tensor const & rhs) # - Tensor & Tensor::index_put_(ArrayRef indices, Scalar v) @@ -1873,28 +2006,29 @@ CPU, CUDA: _index_put_impl_ - func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function - func: inverse(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: inverse - func: inverse.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: inverse_out - func: _inverse_helper(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CPU: _inverse_helper_cpu CUDA: _inverse_helper_cuda - func: isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor - use_c10_dispatcher: full variants: function, method - func: isnan(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method device_guard: False dispatch: @@ -1902,53 +2036,58 @@ SparseCPU, SparseCUDA: isnan_sparse - func: is_distributed(Tensor self) -> bool - use_c10_dispatcher: full variants: function, method device_guard: False - func: is_floating_point(Tensor self) -> bool - use_c10_dispatcher: full variants: function, method device_guard: False - func: is_complex(Tensor self) -> bool - use_c10_dispatcher: full variants: function, method device_guard: False - func: isreal(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: is_nonzero(Tensor self) -> bool - use_c10_dispatcher: full variants: function, method device_guard: False - func: is_same_size(Tensor self, Tensor other) -> bool - use_c10_dispatcher: full variants: function, method device_guard: False - func: is_signed(Tensor self) -> bool - use_c10_dispatcher: full variants: function, method device_guard: False - func: kl_div(Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor - use_c10_dispatcher: full + dispatch: + DefaultBackend: kl_div - func: kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor - use_c10_dispatcher: full dispatch: CPU: kl_div_backward_cpu CUDA: kl_div_backward_cuda +- func: kron(Tensor self, Tensor other) -> Tensor + variants: function, method + dispatch: + Math: kron + +- func: kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + Math: kron_out + - func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: kthvalue - func: kthvalue.values(Tensor self, int k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: kthvalue_out_cpu CUDA: kthvalue_out_cuda @@ -1957,233 +2096,318 @@ variants: function, method - func: kthvalue.dimname_out(Tensor self, int k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures -- func: native_layer_norm(Tensor input, Tensor? weight, Tensor? bias, int M, int N, float eps) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full +- func: native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: layer_norm_cpu CUDA: layer_norm_cuda + Math: math_native_layer_norm -- func: native_layer_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, int M, int N, bool[3] output_mask) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full +- func: native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: layer_norm_backward_cpu CUDA: layer_norm_backward_cuda +- func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor + variants: function, method + dispatch: + DefaultBackend: nan_to_num + +- func: nan_to_num_(Tensor(a!) self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor(a!) + variants: function, method + dispatch: + DefaultBackend: nan_to_num_ + +- func: nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: nan_to_num_out + - func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn - func: mkldnn_linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: MkldnnCPU: mkldnn_linear - func: fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor - use_c10_dispatcher: full - func: fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor - use_c10_dispatcher: full - func: fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int) - use_c10_dispatcher: full - func: fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor - use_c10_dispatcher: full - func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor - use_c10_dispatcher: full - func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor - use_c10_dispatcher: full - func: fbgemm_pack_quantized_matrix(Tensor input) -> Tensor - use_c10_dispatcher: full - func: fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor - use_c10_dispatcher: full + +- func: ldexp.Tensor(Tensor self, Tensor other) -> Tensor + variants: function, method + +- func: ldexp_(Tensor(a!) self, Tensor other) -> Tensor(a!) + variants: function, method + +- func: ldexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: linspace(Scalar start, Scalar end, int? steps=None, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: linspace.out(Scalar start, Scalar end, int? steps=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: linspace_cpu_out CUDA: linspace_cuda_out - func: log(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: log - func: log_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: log_ - func: log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: log_out - func: log10(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: log10 - func: log10_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: log10_ - func: log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: log10_out - func: log1p(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + CPU, CUDA: log1p + SparseCPU, SparseCUDA: log1p_sparse - func: log1p_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: log1p_ SparseCPU, SparseCUDA: log1p_sparse_ - func: log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: log1p_out SparseCPU, SparseCUDA: log1p_out_sparse - func: log2(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: log2 - func: log2_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: log2_ - func: log2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: log2_out - func: logaddexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: logaddexp_out - func: logaddexp(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function + dispatch: + DefaultBackend: logaddexp - func: logaddexp2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: logaddexp2_out - func: logaddexp2(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function + dispatch: + DefaultBackend: logaddexp2 -- func: logdet(Tensor self) -> Tensor +- func: xlogy.Tensor(Tensor self, Tensor other) -> Tensor use_c10_dispatcher: full variants: function, method + dispatch: + CPU, CUDA: xlogy -- func: logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +- func: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor + use_c10_dispatcher: full + variants: function + dispatch: + CPU, CUDA: xlogy + +- func: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: full + variants: function, method + dispatch: + CPU, CUDA: xlogy + +# xlogy: inplace variant +- func: xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: full + variants: function, method + dispatch: + CPU, CUDA: xlogy_ + +- func: xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!) use_c10_dispatcher: full + variants: function, method + dispatch: + CPU, CUDA: xlogy_ + +# xlogy: out variant +- func: xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + variants: function + dispatch: + CPU, CUDA: xlogy_out + +- func: xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + variants: function + dispatch: + CPU, CUDA: xlogy_out + +- func: xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + variants: function + dispatch: + CPU, CUDA: xlogy_out + +- func: logdet(Tensor self) -> Tensor + variants: function, method + dispatch: + DefaultBackend: logdet + +- func: logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: logspace.out(Scalar start, Scalar end, int? steps=None, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: logspace_cpu_out CUDA: logspace_cuda_out # log_softmax allows positional dtype, unlike most operators, because kwonly is BC-breaking when loading jit models. - func: log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor - use_c10_dispatcher: full variants: function, method - func: log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor variants: function, method - func: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor - use_c10_dispatcher: full dispatch: CPU: log_softmax_cpu CUDA: log_softmax_cuda - func: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor - use_c10_dispatcher: full dispatch: CPU: log_softmax_backward_cpu CUDA: log_softmax_backward_cuda - func: _logcumsumexp(Tensor self, int dim) -> Tensor - use_c10_dispatcher: full dispatch: CPU: _logcumsumexp_cpu CUDA: _logcumsumexp_cuda - func: _logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: _logcumsumexp_out_cpu CUDA: _logcumsumexp_out_cuda - func: logcumsumexp(Tensor self, int dim) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: logcumsumexp - func: logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: logcumsumexp_out - func: logcumsumexp.dimname(Tensor self, Dimname dim) -> Tensor variants: function, method - func: logcumsumexp.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: logsumexp - func: logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: logsumexp_out - func: logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor variants: function, method - func: logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor - use_c10_dispatcher: full - func: matmul(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: function, method - func: matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: matrix_rank.tol(Tensor self, float tol, bool symmetric=False) -> Tensor - use_c10_dispatcher: full - func: matrix_rank(Tensor self, bool symmetric=False) -> Tensor - use_c10_dispatcher: full - func: matrix_power(Tensor self, int n) -> Tensor - use_c10_dispatcher: full variants: function, method - func: matrix_exp(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: matrix_exp - func: matrix_exp_backward(Tensor self, Tensor grad) -> Tensor - use_c10_dispatcher: full - func: _aminmax(Tensor self) -> (Tensor, Tensor) - use_c10_dispatcher: full variants: function dispatch: CPU, CUDA: _aminmax_all - func: _aminmax.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor) - use_c10_dispatcher: full variants: function dispatch: CPU, CUDA: _aminmax @@ -2193,14 +2417,17 @@ CPU, CUDA: _compute_linear_combination - func: _compute_linear_combination.out(Tensor input, Tensor coefficients, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: _compute_linear_combination_out - func: max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: max - func: max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: max_out @@ -2208,65 +2435,63 @@ variants: function, method - func: max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, int[] sizes, bool keepdim) -> Tensor - use_c10_dispatcher: full variants: function device_guard: False - func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: amax - func: amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: amax_out # Return: (Tensor output, Tensor indices) - func: max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) - use_c10_dispatcher: full - func: max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor - use_c10_dispatcher: full - func: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor - use_c10_dispatcher: full - func: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor - use_c10_dispatcher: full dispatch: MkldnnCPU: mkldnn_max_pool2d - func: mkldnn_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor - use_c10_dispatcher: full dispatch: MkldnnCPU: mkldnn_max_pool3d +- func: quantized_max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor + dispatch: + QuantizedCPU: quantized_max_pool1d + - func: quantized_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor - use_c10_dispatcher: full dispatch: QuantizedCPU: quantized_max_pool2d - func: max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor - use_c10_dispatcher: full # The CPU and GPU dispatch variants are named weirdly here because otherwise there # are namespacing issues in C++ - func: mean(Tensor self, *, ScalarType? dtype=None) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: mean_cpu_gpu QuantizedCPU: mean_quantized_cpu - func: mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: mean_cpu_gpu QuantizedCPU: mean_quantized_cpu - func: mean.out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: mean_out_cpu_gpu QuantizedCPU: mean_out_quantized_cpu @@ -2275,23 +2500,61 @@ variants: function, method - func: mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + +- func: median(Tensor self) -> Tensor + variants: function, method + dispatch: + CPU: median_cpu + CUDA: median_cuda - func: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: median - func: median.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + CPU: median_out_cpu + CUDA: median_out_cuda - func: median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) variants: function, method - func: median.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + +- func: nanmedian(Tensor self) -> Tensor + variants: function, method + dispatch: + CPU: nanmedian_cpu + CUDA: nanmedian_cuda + +- func: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + variants: function, method + dispatch: + DefaultBackend: nanmedian + +- func: nanmedian.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + CPU: nanmedian_out_cpu + CUDA: nanmedian_out_cuda + +- func: nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + variants: function, method + +- func: nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: min - func: min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: min_out @@ -2299,116 +2562,109 @@ variants: function, method - func: min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: amin - func: amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: amin_out - func: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: mkldnn_convolution - func: mkldnn_convolution_backward_input(int[] self_size, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool bias_defined) -> Tensor - use_c10_dispatcher: full - func: mkldnn_convolution_backward_weights(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool bias_defined) -> (Tensor, Tensor) - use_c10_dispatcher: full - func: mkldnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full + dispatch: + DefaultBackend: mkldnn_convolution_backward - func: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: miopen_batch_norm - func: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: miopen_batch_norm_backward - func: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: miopen_convolution - func: miopen_convolution_backward_input(int[] self_size, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - use_c10_dispatcher: full dispatch: CUDA: miopen_convolution_backward_input - func: miopen_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[3] output_mask) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full dispatch: CUDA: miopen_convolution_backward - func: miopen_convolution_backward_bias(Tensor grad_output) -> Tensor - use_c10_dispatcher: full dispatch: CUDA: miopen_convolution_backward_bias - func: miopen_convolution_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - use_c10_dispatcher: full dispatch: CUDA: miopen_convolution_backward_weight - func: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: miopen_convolution_transpose # NB: output_padding not strictly needed here, but it's helpful for the float # backwards - func: miopen_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[3] output_mask) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full dispatch: CUDA: miopen_convolution_transpose_backward - func: miopen_convolution_transpose_backward_input(Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - use_c10_dispatcher: full dispatch: CUDA: miopen_convolution_transpose_backward_input - func: miopen_convolution_transpose_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - use_c10_dispatcher: full dispatch: CUDA: miopen_convolution_transpose_backward_weight - func: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: miopen_depthwise_convolution - func: miopen_depthwise_convolution_backward_input(int[] self_size, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - use_c10_dispatcher: full dispatch: CUDA: miopen_depthwise_convolution_backward_input - func: miopen_depthwise_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[3] output_mask) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full dispatch: CUDA: miopen_depthwise_convolution_backward - func: miopen_depthwise_convolution_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - use_c10_dispatcher: full dispatch: CUDA: miopen_depthwise_convolution_backward_weight - func: miopen_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: miopen_rnn - func: miopen_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: miopen_rnn_backward - func: mm(Tensor self, Tensor mat2) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU: mm_cpu @@ -2416,27 +2672,42 @@ SparseCPU, SparseCUDA: _sparse_mm - func: mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: mm_cpu_out CUDA: mm_out_cuda SparseCPU, SparseCUDA: _sparse_mm_out - func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor + +- func: _sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor use_c10_dispatcher: full + dispatch: + SparseCPU: sparse_sparse_matmul_cpu + SparseCUDA: sparse_sparse_matmul_cuda + +- func: _sparse_matrix_mask_helper(Tensor t, Tensor mask_indices) -> Tensor + dispatch: + SparseCPU: sparse_matrix_mask_helper_cpu + SparseCUDA: sparse_matrix_mask_helper_cuda - func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) - use_c10_dispatcher: full variants: function, method + dispatch: + CPU, CUDA: mode - func: mode.values(Tensor self, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: mode_out - func: mode.dimname(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) variants: function, method - func: mode.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: mul.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: mul @@ -2444,7 +2715,6 @@ MkldnnCPU: mkldnn_mul - func: mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: mul_ @@ -2452,6 +2722,7 @@ MkldnnCPU: mkldnn_mul_ - func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: mul_out SparseCPU: mul_out_sparse_cpu @@ -2460,203 +2731,207 @@ # For C++ only, until we have conversion from C++ numbers to Tensor - func: mul.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: mul - func: mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: mul_ # multiply, alias for mul - func: multiply.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: function, method - func: multiply_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: multiply.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: multiply.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: function, method - func: multiply_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: mv(Tensor self, Tensor vec) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: mv SparseCPU, SparseCUDA: mv_sparse - func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: mv_out - func: mvlgamma(Tensor self, int p) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: mvlgamma - func: mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: mvlgamma_ - func: narrow_copy(Tensor self, int dim, int start, int length) -> Tensor - use_c10_dispatcher: full - variants: method + variants: function, method dispatch: - CPU, CUDA: narrow_copy_dense + CPU: narrow_copy_dense_cpu SparseCPU, SparseCUDA: narrow_copy_sparse + DefaultBackend: narrow_copy_dense + +- func: narrow_copy.out(Tensor self, int dim, int start, int length, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU: narrow_copy_dense_cpu_out - func: narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a) - use_c10_dispatcher: full variants: function, method device_guard: False - func: narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> Tensor(a) - use_c10_dispatcher: full variants: function, method device_guard: False - func: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: batch_norm_cpu CUDA: batch_norm_cuda MkldnnCPU: mkldnn_batch_norm - func: native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: batch_norm_cuda_out - func: batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor) - use_c10_dispatcher: full dispatch: CUDA: batch_norm_stats_cuda - func: batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: batch_norm_elemt_cuda - func: batch_norm_elemt.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: batch_norm_elemt_cuda_out # for backward compatibility - func: batch_norm_gather_stats(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count) -> (Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: batch_norm_gather_stats_cuda - func: batch_norm_gather_stats_with_counts(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts) -> (Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: batch_norm_gather_stats_with_counts_cuda - func: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: batch_norm_backward_cpu CUDA: batch_norm_backward_cuda - func: batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: batch_norm_backward_reduce_cuda - func: batch_norm_backward_elemt(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor mean_dy, Tensor mean_dy_xmu) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: batch_norm_backward_elemt_cuda - func: batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum) -> (Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: batch_norm_update_stats_cpu CUDA: batch_norm_update_stats_cuda - func: is_vulkan_available() -> bool - use_c10_dispatcher: full - func: _nnpack_available() -> bool - use_c10_dispatcher: full - func: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, int[2] padding, int[2] stride=1) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function + dispatch: + DefaultBackend: _nnpack_spatial_convolution - func: _nnpack_spatial_convolution_backward(Tensor input, Tensor grad_output, Tensor weight, int[2] padding, bool[3] output_mask) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full variants: function - func: _nnpack_spatial_convolution_backward_input(Tensor input, Tensor grad_output, Tensor weight, int[2] padding) -> Tensor - use_c10_dispatcher: full variants: function - func: _nnpack_spatial_convolution_backward_weight(Tensor input, int[] weightsize, Tensor grad_output, int[2] padding) -> Tensor - use_c10_dispatcher: full variants: function - func: ones.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: ones.out(int[] size, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor - use_c10_dispatcher: full - func: cdist(Tensor x1, Tensor x2, float p=2, int? compute_mode=None) -> Tensor - use_c10_dispatcher: full - func: _euclidean_dist(Tensor x1, Tensor x2) -> Tensor - use_c10_dispatcher: full + dispatch: + DefaultBackend: _euclidean_dist - func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor - use_c10_dispatcher: full dispatch: CPU, CUDA: _cdist_forward - func: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor - use_c10_dispatcher: full dispatch: CPU, CUDA: _cdist_backward - func: pdist(Tensor self, float p=2) -> Tensor - use_c10_dispatcher: full - func: _pdist_forward(Tensor self, float p=2) -> Tensor - use_c10_dispatcher: full dispatch: CPU, CUDA: _pdist_forward - func: _pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor - use_c10_dispatcher: full dispatch: CPU, CUDA: _pdist_backward - func: cosine_similarity(Tensor x1, Tensor x2, int dim=1, float eps=1e-08) -> Tensor - use_c10_dispatcher: full variants: function - func: permute(Tensor(a) self, int[] dims) -> Tensor(a) - use_c10_dispatcher: full variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. + dispatch: + DefaultBackend: permute - func: movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) - use_c10_dispatcher: full variants: function, method - func: movedim.int(Tensor(a) self, int source, int destination) -> Tensor(a) - use_c10_dispatcher: full + variants: function, method + +# moveaxis, alias for movedim +- func: moveaxis.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) + variants: function, method + +- func: moveaxis.int(Tensor(a) self, int source, int destination) -> Tensor(a) variants: function, method # Only exposed from C++ -- in Python, @@ -2667,223 +2942,253 @@ # behavior on Windows, for reasons I don't understand # (maybe related to capital letter collation somehow...) - func: numpy_T(Tensor(a) self) -> Tensor(a) - use_c10_dispatcher: full variants: method - func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor + +- func: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor use_c10_dispatcher: full - func: channel_shuffle(Tensor self, int groups) -> Tensor - use_c10_dispatcher: full dispatch: CPU: channel_shuffle QuantizedCPU: channel_shuffle_quantized_cpu - func: is_pinned(Tensor self) -> bool - use_c10_dispatcher: full variants: method - func: pin_memory(Tensor(a) self) -> Tensor(a) - use_c10_dispatcher: full variants: method - func: pinverse(Tensor self, float rcond=1e-15) -> Tensor - use_c10_dispatcher: full variants: function, method - func: poisson_nll_loss(Tensor input, Tensor target, bool log_input, bool full, float eps, int reduction) -> Tensor - use_c10_dispatcher: full variants: function - func: rad2deg(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: rad2deg - func: rad2deg_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: rad2deg_ - func: rad2deg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: rad2deg_out - func: deg2rad(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: deg2rad - func: deg2rad_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: deg2rad_ - func: deg2rad.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: deg2rad_out - func: scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: rand.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: rand.generator_with_names(int[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: rand.generator(int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: rand.out(int[] size, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: rand.generator_out(int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randint(int high, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randint.generator(int high, int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randint.low(int low, int high, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randint.low_generator(int low, int high, int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randint.out(int high, int[] size, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randint.generator_out(int high, int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randint.low_out(int low, int high, int[] size, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randint.low_generator_out(int low, int high, int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randint_like(Tensor self, int high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randint_like.low_dtype(Tensor self, int low, int high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randn(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randn.generator(int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randn.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: randn.generator_with_names(int[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: randn.out(int[] size, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randn.generator_out(int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randperm(int n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randperm.generator(int n, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randperm.out(int n, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randperm.generator_out(int n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: randperm_out_cpu CUDA: randperm_out_cuda - func: range.step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: range.out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: range_cpu_out CUDA: range_cuda_out +- func: ravel(Tensor(a) self) -> Tensor(a) + variants: function, method + - func: reciprocal(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: reciprocal - func: reciprocal_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: reciprocal_ - func: reciprocal.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: reciprocal_out - func: neg(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: neg - func: neg_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: neg_ SparseCPU, SparseCUDA: neg_sparse_ - func: neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: neg_out SparseCPU, SparseCUDA: neg_out_sparse # Alias for neg - func: negative(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: negative_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method - func: negative.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: repeat(Tensor self, int[] repeats) -> Tensor - use_c10_dispatcher: full variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. + dispatch: + DefaultBackend: repeat - func: repeat_interleave.Tensor(Tensor repeats) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CPU: repeat_interleave_cpu CUDA: repeat_interleave_cuda - func: repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None) -> Tensor - use_c10_dispatcher: full variants: function, method - func: repeat_interleave.self_int(Tensor self, int repeats, int? dim=None) -> Tensor - use_c10_dispatcher: full variants: function, method - func: reshape(Tensor(a) self, int[] shape) -> Tensor(a) - use_c10_dispatcher: full variants: function, method device_guard: False - func: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor - use_c10_dispatcher: full device_guard: False dispatch: MkldnnCPU: mkldnn_reshape - func: reshape_as(Tensor(a) self, Tensor other) -> Tensor(a) - use_c10_dispatcher: full variants: method device_guard: False - func: round(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: round - func: round_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: round_ - func: round.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: round_out CUDA: round_out @@ -2893,7 +3198,6 @@ - func: rrelu_(Tensor(a!) self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) - func: relu(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: relu @@ -2901,7 +3205,6 @@ QuantizedCPU: relu_quantized_cpu - func: relu_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: relu_ @@ -2909,60 +3212,56 @@ QuantizedCPU: relu_quantized_cpu_ - func: prelu(Tensor self, Tensor weight) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU: prelu_cpu CUDA: prelu_cuda - func: prelu_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor) - use_c10_dispatcher: full variants: function, method dispatch: CPU: prelu_backward_cpu CUDA: prelu_backward_cuda - func: gelu(Tensor self) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: gelu_cpu CUDA: gelu_cuda - func: gelu_backward(Tensor grad, Tensor self) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: gelu_backward_cpu CUDA: gelu_backward_cuda - func: infinitely_differentiable_gelu_backward(Tensor grad, Tensor self) -> Tensor - use_c10_dispatcher: full variants: function python_module: nn device_guard: False - func: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: hardshrink - func: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: hardshrink_backward - func: rsqrt(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: rsqrt - func: rsqrt_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: rsqrt_ - func: rsqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: rsqrt_out @@ -2971,46 +3270,50 @@ device_guard: False - func: select.int(Tensor(a) self, int dim, int index) -> Tensor(a) - use_c10_dispatcher: full variants: function, method device_guard: False + dispatch: + DefaultBackend: select - func: select_backward(Tensor grad, int[] input_sizes, int dim, int index) -> Tensor - use_c10_dispatcher: full variants: function device_guard: False - func: selu(Tensor self) -> Tensor - use_c10_dispatcher: full - func: selu_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full - func: celu(Tensor self, Scalar alpha=1.0) -> Tensor - use_c10_dispatcher: full + dispatch: + DefaultBackend: celu - func: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!) - use_c10_dispatcher: full + dispatch: + DefaultBackend: celu_ - func: silu(Tensor self) -> Tensor - use_c10_dispatcher: full python_module: nn + dispatch: + DefaultBackend: silu - func: silu_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full python_module: nn + dispatch: + DefaultBackend: silu_ - func: silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: silu_out - func: silu_backward(Tensor grad_output, Tensor self) -> Tensor - use_c10_dispatcher: full python_module: nn + dispatch: + CPU, CUDA: silu_backward + Math: math_silu_backward - func: sigmoid(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: sigmoid @@ -3018,53 +3321,75 @@ MkldnnCPU: mkldnn_sigmoid - func: sigmoid_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: sigmoid_ MkldnnCPU: mkldnn_sigmoid_ - func: sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: sigmoid_out - func: logit(Tensor self, float? eps=None) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: logit - func: logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: logit_ - func: logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: logit_out - func: sin(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: sin - func: sin_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: sin_ - func: sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: sin_out -- func: sinh(Tensor self) -> Tensor +- func: sinc(Tensor self) -> Tensor use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: sinc -- func: sinh_(Tensor(a!) self) -> Tensor(a!) +- func: sinc_(Tensor(a!) self) -> Tensor(a!) use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: sinc_ + +- func: sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + CPU, CUDA: sinc_out + +- func: sinh(Tensor self) -> Tensor + variants: function, method + dispatch: + DefaultBackend: sinh + +- func: sinh_(Tensor(a!) self) -> Tensor(a!) + variants: function, method + dispatch: + DefaultBackend: sinh_ - func: sinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: sinh_out @@ -3080,117 +3405,125 @@ # changing metadata of the detached tensor and expecting the original tensor to also # be updated. - func: detach(Tensor(a) self) -> Tensor(a) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: detach # Like `detach()`, but modifies this `Variable` in-place. This method may # only be called on non-view `Variable`s. You can use `is_view()` to check # this. If this `Variable` is a view, throws an `std::runtime_error()`. - func: detach_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: detach_ - func: size.int(Tensor self, int dim) -> int - use_c10_dispatcher: full - variants: function, method + variants: function device_guard: False + manual_cpp_binding: True - func: size.Dimname(Tensor self, Dimname dim) -> int variants: function, method device_guard: False -- func: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a) +- func: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a) use_c10_dispatcher: full variants: function, method device_guard: False + dispatch: + DefaultBackend: slice - func: slice_backward(Tensor grad, int[] input_sizes, int dim, int start, int end, int step) -> Tensor - use_c10_dispatcher: full variants: function device_guard: False - func: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: slogdet - func: smm(Tensor self, Tensor mat2) -> Tensor - use_c10_dispatcher: full variants: function, method # softmax allows positional dtype, unlike most operators, because kwonly is BC-breaking when loading jit models. - func: softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor - use_c10_dispatcher: full variants: function, method - func: softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor variants: function, method - func: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor - use_c10_dispatcher: full dispatch: CPU: softmax_cpu CUDA: softmax_cuda MkldnnCPU: mkldnn_softmax - func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor - use_c10_dispatcher: full dispatch: CPU: softmax_backward_cpu CUDA: softmax_backward_cuda - func: unsafe_split.Tensor(Tensor self, int split_size, int dim=0) -> Tensor[] - use_c10_dispatcher: full variants: function, method device_guard: False + dispatch: + DefaultBackend: unsafe_split - func: split.Tensor(Tensor(a) self, int split_size, int dim=0) -> Tensor(a)[] - use_c10_dispatcher: full variants: function, method device_guard: False + dispatch: + DefaultBackend: split - func: unsafe_split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[] - use_c10_dispatcher: full variants: function, method device_guard: False + dispatch: + DefaultBackend: unsafe_split_with_sizes - func: split_with_sizes(Tensor(a) self, int[] split_sizes, int dim=0) -> Tensor(a)[] - use_c10_dispatcher: full variants: function, method device_guard: False + dispatch: + DefaultBackend: split_with_sizes - func: squeeze(Tensor(a) self) -> Tensor(a) - use_c10_dispatcher: full variants: function, method device_guard: False + dispatch: + DefaultBackend: squeeze - func: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a) - use_c10_dispatcher: full variants: function, method device_guard: False + dispatch: + DefaultBackend: squeeze - func: squeeze.dimname(Tensor(a) self, Dimname dim) -> Tensor(a) variants: function, method device_guard: False - func: squeeze_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: method device_guard: False + dispatch: + DefaultBackend: squeeze_ - func: squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!) - use_c10_dispatcher: full variants: method device_guard: False + dispatch: + DefaultBackend: squeeze_ - func: squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!) variants: method device_guard: False - func: sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor - use_c10_dispatcher: full variants: function, method - func: sspaddmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: _sspaddmm_out_only_sparse CUDA: _sspaddmm_out_only_sparse_cuda @@ -3198,54 +3531,56 @@ SparseCUDA: _sspaddmm_out_cuda - func: stack(Tensor[] tensors, int dim=0) -> Tensor - use_c10_dispatcher: full + dispatch: + DefaultBackend: stack - func: stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: stack_out - func: hstack(Tensor[] tensors) -> Tensor - use_c10_dispatcher: full - func: hstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: vstack(Tensor[] tensors) -> Tensor - use_c10_dispatcher: full - func: vstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: dstack(Tensor[] tensors) -> Tensor - use_c10_dispatcher: full - func: dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures # The signature is designed to be consistent with librosa except that it is # missing the `pad_mode` and `center` arguments, which are taken care of at # `torch.functional.py`. They shall be moved here once we have mapping between # Python strings and C++ Enum in codegen. - func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function, method - func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function, method - func: stride.int(Tensor self, int dim) -> int - use_c10_dispatcher: full - variants: function, method + variants: function device_guard: False + manual_cpp_binding: True - func: stride.Dimname(Tensor self, Dimname dim) -> int variants: function, method device_guard: False - func: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: sum - func: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: sum @@ -3254,72 +3589,69 @@ variants: function, method - func: sum.IntList_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: sum_out - func: sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: nansum(Tensor self, *, ScalarType? dtype=None) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: nansum - func: nansum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: nansum - func: nansum.IntList_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: nansum_out - func: sum_to_size(Tensor self, int[] size) -> Tensor - use_c10_dispatcher: full variants: method device_guard: False - func: sqrt(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: sqrt - func: sqrt_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: sqrt_ - func: sqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: sqrt_out - func: square(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: square_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method - func: std(Tensor self, bool unbiased=True) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: std - func: std.dim(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: std - func: std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) - use_c10_dispatcher: full variants: function dispatch: CPU, CUDA: std_mean - func: std_mean.dim(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) - use_c10_dispatcher: full variants: function dispatch: CPU, CUDA: std_mean @@ -3328,6 +3660,7 @@ variants: function - func: std.out(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: std_out @@ -3335,20 +3668,20 @@ variants: function, method - func: std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: prod - func: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: prod - func: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: prod_out @@ -3356,51 +3689,62 @@ variants: function, method - func: prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: t(Tensor(a) self) -> Tensor(a) - use_c10_dispatcher: full device_guard: False variants: function, method + dispatch: + DefaultBackend: t - func: t_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full device_guard: False variants: method + dispatch: + DefaultBackend: t_ - func: tan(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: tan - func: tan_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: tan_ - func: tan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: tan_out - func: tanh(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: tanh QuantizedCPU: tanh_quantized_cpu - func: tanh_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: tanh_ - func: tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: tanh_out - func: tensordot(Tensor self, Tensor other, int[] dims_self, int[] dims_other) -> Tensor - use_c10_dispatcher: full variants: function +- func: tensordot.out(Tensor self, Tensor other, int[] dims_self, int[] dims_other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + variants: function + dispatch: + CPU, CUDA: tensordot_out + # TODO: namespace threshold in 'nn' - func: threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CPU: threshold @@ -3408,72 +3752,69 @@ QuantizedCPU: threshold_quantized_cpu - func: threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!) - use_c10_dispatcher: full variants: function dispatch: CPU: threshold_ CUDA: threshold__cuda - func: threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: threshold_out CUDA: threshold_out_cuda - func: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CPU: threshold_backward CUDA: threshold_backward_cuda +- func: tile(Tensor self, int[] dims) -> Tensor + variants: function, method + - func: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) - use_c10_dispatcher: full variants: function, method device_guard: False + dispatch: + DefaultBackend: transpose - func: transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a) variants: function, method device_guard: False - func: _mkldnn_transpose(Tensor self, int dim0, int dim1) -> Tensor - use_c10_dispatcher: full device_guard: False dispatch: MkldnnCPU: mkldnn_transpose - func: transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) - use_c10_dispatcher: full variants: method device_guard: False + dispatch: + DefaultBackend: transpose_ - func: _mkldnn_transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) - use_c10_dispatcher: full device_guard: False dispatch: MkldnnCPU: mkldnn_transpose_ - func: one_hot(Tensor self, int num_classes=-1) -> Tensor - use_c10_dispatcher: full python_module: nn variants: function - func: flip(Tensor self, int[] dims) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: - CPU: flip_cpu + CPU, QuantizedCPU: flip_cpu CUDA: flip_cuda - func: fliplr(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: flipud(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU: roll_cpu @@ -3482,75 +3823,70 @@ # default int[] value [0,1] should not add space after comma, since codegen parser uses ', ' to split args - func: rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: rot90 - func: trapz.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor - use_c10_dispatcher: full - func: trapz.dx(Tensor y, *, float dx=1, int dim=-1) -> Tensor - use_c10_dispatcher: full - func: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor - use_c10_dispatcher: full + dispatch: + DefaultBackend: _trilinear - func: triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, float margin=1.0, float p=2, float eps=1e-06, bool swap=False, int reduction=Mean) -> Tensor - use_c10_dispatcher: full - func: trunc(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: trunc - func: trunc_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: trunc_ - func: trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: trunc_out # Alias for trunc - func: fix(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: fix_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method - func: fix.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: type_as(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method - func: _has_compatible_shallow_copy_type(Tensor self, Tensor from) -> bool - use_c10_dispatcher: full variants: function - func: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor) - use_c10_dispatcher: full variants: function dispatch: CPU: _unique_cpu CUDA: _unique_cuda - func: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full variants: function dispatch: CPU: unique_dim_cpu CUDA: unique_dim_cuda - func: unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full variants: function dispatch: CPU: unique_consecutive_cpu CUDA: unique_consecutive_cuda - func: unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full variants: function dispatch: CPU: unique_dim_consecutive_cpu @@ -3561,41 +3897,41 @@ # Please don't rely on these two operators, they will be removed soon - func: _unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full variants: function dispatch: CPU: _unique2_cpu CUDA: _unique2_cuda - func: _unsafe_view(Tensor self, int[] size) -> Tensor - use_c10_dispatcher: full + dispatch: + DefaultBackend: _unsafe_view - func: unsqueeze(Tensor(a) self, int dim) -> Tensor(a) - use_c10_dispatcher: full variants: function, method device_guard: False + dispatch: + DefaultBackend: unsqueeze - func: unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!) - use_c10_dispatcher: full variants: method device_guard: False + dispatch: + DefaultBackend: unsqueeze_ - func: vander(Tensor x, int? N=None, bool increasing=False) -> Tensor - use_c10_dispatcher: full - func: var(Tensor self, bool unbiased=True) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: var - func: var.dim(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: var - func: var.out(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: var_out @@ -3603,15 +3939,14 @@ variants: function, method - func: var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) - use_c10_dispatcher: full variants: function dispatch: CPU, CUDA: var_mean - func: var_mean.dim(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) - use_c10_dispatcher: full variants: function dispatch: CPU, CUDA: var_mean @@ -3620,7 +3955,6 @@ variants: function - func: view_as(Tensor(a) self, Tensor other) -> Tensor(a) - use_c10_dispatcher: full variants: method device_guard: False @@ -3628,70 +3962,60 @@ # this allows us to implicitly calculate the broadcast derivative, while only dealing with the # _s_where derivative. - func: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: function, method - func: where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: function - func: where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: function - func: where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: function - func: where(Tensor condition) -> Tensor[] - use_c10_dispatcher: full variants: function - func: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CPU, CUDA: _s_where - func: norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor - use_c10_dispatcher: full variants: function # VariableType::_weight_norm does not want to be given a gap in the autograd graph, # so we don't define "dispatch" variants for it. - func: _weight_norm(Tensor v, Tensor g, int dim=0) -> Tensor - use_c10_dispatcher: full variants: function - func: _weight_norm_cuda_interface(Tensor v, Tensor g, int dim=0) -> (Tensor, Tensor) - use_c10_dispatcher: full variants: function dispatch: CUDA: weight_norm_cuda - func: _weight_norm_cuda_interface_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor) - use_c10_dispatcher: full variants: function dispatch: CUDA: weight_norm_cuda_backward - func: _weight_norm_differentiable_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor) - use_c10_dispatcher: full variants: function - func: zeros.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: zeros.out(int[] size, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CPU: _standard_gamma_grad_cpu @@ -3704,7 +4028,6 @@ CUDA: _s_gamma_cuda - func: _dirichlet_grad(Tensor x, Tensor alpha, Tensor total) -> Tensor - use_c10_dispatcher: full dispatch: CPU: _dirichlet_grad_cpu CUDA: _dirichlet_grad_cuda @@ -3729,93 +4052,88 @@ # complicated - func: native_norm(Tensor self, Scalar p=2) -> Tensor - use_c10_dispatcher: full dispatch: SparseCPU, SparseCUDA: norm_sparse - func: native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype) -> Tensor - use_c10_dispatcher: full dispatch: SparseCPU, SparseCUDA: norm_sparse # TODO: reduce signatures down to one when optional args is available - func: _sparse_sum(Tensor self) -> Tensor - use_c10_dispatcher: full - func: _sparse_sum.dtype(Tensor self, *, ScalarType dtype) -> Tensor - use_c10_dispatcher: full - func: _sparse_sum.dim(Tensor self, int[1] dim) -> Tensor - use_c10_dispatcher: full + dispatch: + DefaultBackend: _sparse_sum - func: _sparse_sum.dim_dtype(Tensor self, int[1] dim, *, ScalarType dtype) -> Tensor - use_c10_dispatcher: full - func: _sparse_sum_backward(Tensor grad, Tensor self, int[] dim) -> Tensor - use_c10_dispatcher: full dispatch: SparseCPU: _sparse_sum_backward_cpu SparseCUDA: _sparse_sum_backward_cuda - func: _sparse_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor - use_c10_dispatcher: full variants: function - func: _sparse_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor variants: function - func: _sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor - use_c10_dispatcher: full dispatch: SparseCPU: softmax_sparse_cpu SparseCUDA: softmax_sparse_cuda - func: _sparse_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor - use_c10_dispatcher: full dispatch: SparseCPU: softmax_backward_sparse_cpu SparseCUDA: softmax_backward_sparse_cuda - func: _sparse_log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor - use_c10_dispatcher: full variants: function - func: _sparse_log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor variants: function - func: _sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor - use_c10_dispatcher: full dispatch: SparseCPU: log_softmax_sparse_cpu SparseCUDA: log_softmax_sparse_cuda - func: _sparse_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor - use_c10_dispatcher: full dispatch: SparseCPU: log_softmax_backward_sparse_cpu SparseCUDA: log_softmax_backward_sparse_cuda - func: norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: norm - func: norm.Scalar(Tensor self, Scalar p=2) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: norm - func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: norm - func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: norm - func: norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: norm_out - func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: norm_out @@ -3826,36 +4144,36 @@ variants: function, method - func: norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: frobenius_norm(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function - func: frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> Tensor - use_c10_dispatcher: full variants: function - func: frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function - func: nuclear_norm(Tensor self, bool keepdim=False) -> Tensor - use_c10_dispatcher: full variants: function - func: nuclear_norm.out(Tensor self, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function - func: nuclear_norm.dim(Tensor self, int[2] dim, bool keepdim=False) -> Tensor - use_c10_dispatcher: full variants: function - func: nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function - func: clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: clone @@ -3864,11 +4182,11 @@ QuantizedCPU, QuantizedCUDA: quantized_clone - func: resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: resize_as_ - func: zero_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: zero_ @@ -3876,19 +4194,18 @@ MkldnnCPU: mkldnn_zero_ - func: sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: sub_out SparseCPU, SparseCUDA: sub_out_sparse - func: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: sub SparseCPU, SparseCUDA: sub_sparse - func: sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: sub_ @@ -3896,61 +4213,63 @@ # For C++ only, until we have conversion from C++ numbers to Tensor - func: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: sub - func: sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: sub_ # subtract, alias for sub - func: subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor - use_c10_dispatcher: full variants: function, method - func: subtract_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) - use_c10_dispatcher: full variants: method # For C++ only, until we have conversion from C++ numbers to Tensor - func: subtract.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor - use_c10_dispatcher: full variants: function, method - func: subtract_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CPU, CUDA: rsub - func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: heaviside_out - func: heaviside(Tensor self, Tensor values) -> Tensor - use_c10_dispatcher: full variants: function, method - func: heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: method # For C++ only, until we have conversion from C++ numbers to Tensor - func: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor - use_c10_dispatcher: full variants: function + dispatch: + DefaultBackend: rsub # Functionally the same as addmm, but we give it a different derivative formula # that doesn't propagate gradients to non-present entries on sparse. - func: _sparse_addmm(Tensor self, Tensor sparse, Tensor dense, *, Scalar beta=1, Scalar alpha=1) -> Tensor - use_c10_dispatcher: full + dispatch: + DefaultBackend: _sparse_addmm - func: addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: addmm_cpu_out CUDA: addmm_out_cuda @@ -3958,7 +4277,6 @@ SparseCUDA: addmm_out_sparse_dense_cuda - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU: addmm_cpu @@ -3967,7 +4285,6 @@ SparseCUDA: addmm_sparse_dense_cuda - func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU: addmm_cpu_ @@ -4089,61 +4406,52 @@ # FIXME: would be nicer if TensorOptions was optional based; not adding default arguments for options given # the default would never make sense. - func: sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: _sparse_coo_tensor_unsafe(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: _validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size) -> () - use_c10_dispatcher: full - func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor - use_c10_dispatcher: full dispatch: SparseCPU, SparseCUDA: new_with_dims_sparse - func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, int[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor - use_c10_dispatcher: full dispatch: SparseCPU, SparseCUDA: new_with_dims_and_tensor_sparse - func: sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: SparseCPU, SparseCUDA: sparse_resize_ - func: sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: SparseCPU, SparseCUDA: sparse_resize_and_clear_ - func: sparse_mask(Tensor self, Tensor mask) -> Tensor - use_c10_dispatcher: full variants: method dispatch: SparseCPU: sparse_mask_cpu SparseCUDA: sparse_mask_cuda -- func: to_dense(Tensor self) -> Tensor - use_c10_dispatcher: full +- func: to_dense(Tensor self, ScalarType? dtype=None) -> Tensor variants: method dispatch: SparseCPU, SparseCUDA: sparse_to_dense MkldnnCPU: mkldnn_to_dense - func: to_dense_backward(Tensor grad, Tensor input) -> Tensor - use_c10_dispatcher: full - func: sparse_dim(Tensor self) -> int - use_c10_dispatcher: full variants: method dispatch: SparseCPU, SparseCUDA: sparse_dim_sparse @@ -4151,14 +4459,12 @@ # legacy method - func: _dimI(Tensor self) -> int - use_c10_dispatcher: full variants: method dispatch: SparseCPU, SparseCUDA: sparse_dim_sparse device_guard: False - func: dense_dim(Tensor self) -> int - use_c10_dispatcher: full variants: method dispatch: SparseCPU, SparseCUDA: dense_dim_sparse @@ -4166,42 +4472,36 @@ # legacy method - func: _dimV(Tensor self) -> int - use_c10_dispatcher: full variants: method dispatch: SparseCPU, SparseCUDA: dense_dim_sparse device_guard: False - func: _nnz(Tensor self) -> int - use_c10_dispatcher: full variants: method dispatch: SparseCPU, SparseCUDA: _nnz_sparse device_guard: False - func: coalesce(Tensor self) -> Tensor - use_c10_dispatcher: full variants: method dispatch: SparseCPU: coalesce_sparse_cpu SparseCUDA: coalesce_sparse_cuda - func: is_coalesced(Tensor self) -> bool - use_c10_dispatcher: full variants: method dispatch: SparseCPU, SparseCUDA: is_coalesced_sparse device_guard: False - func: _indices(Tensor(a) self) -> Tensor(a) - use_c10_dispatcher: full variants: method dispatch: SparseCPU, SparseCUDA: _indices_sparse device_guard: False - func: _values(Tensor(a) self) -> Tensor(a) - use_c10_dispatcher: full variants: method dispatch: SparseCPU, SparseCUDA: _values_sparse @@ -4211,285 +4511,237 @@ # a bit unsafe. Similar to _indices and _values, this is useful for implementing # custom sparse operations in Python/C++ extension. - func: _coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: SparseCPU, SparseCUDA: _coalesced_sparse_ device_guard: False - func: indices(Tensor(a) self) -> Tensor(a) - use_c10_dispatcher: full variants: method dispatch: SparseCPU, SparseCUDA: indices_sparse device_guard: False - func: values(Tensor(a) self) -> Tensor(a) - use_c10_dispatcher: full variants: method dispatch: SparseCPU, SparseCUDA: values_sparse device_guard: False - func: hspmm.out(Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: SparseCPU: hspmm_out_sparse_cpu SparseCUDA: hspmm_out_sparse_cuda - func: hspmm(Tensor mat1, Tensor mat2) -> Tensor - use_c10_dispatcher: full dispatch: SparseCPU: hspmm_sparse_cpu SparseCUDA: hspmm_sparse_cuda - func: copy_sparse_to_sparse_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) - use_c10_dispatcher: full variants: function dispatch: SparseCPU, SparseCUDA: copy_sparse_ - func: unbind.int(Tensor(a) self, int dim=0) -> Tensor(a)[] - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: unbind - func: unbind.Dimname(Tensor(a) self, Dimname dim) -> Tensor(a)[] variants: function, method - func: to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: dense_to_sparse - func: to_sparse(Tensor self) -> Tensor - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: dense_to_sparse -- func: to_mkldnn(Tensor self) -> Tensor - use_c10_dispatcher: full +- func: to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor variants: method dispatch: CPU: dense_to_mkldnn - func: mkldnn_reorder_conv2d_weight(Tensor self, int[2] padding=0, int[2] stride=1, int[2] dilation=1, int groups=1) -> Tensor - use_c10_dispatcher: full variants: function python_module: nn dispatch: MkldnnCPU: mkldnn_reorder_conv2d_weight - func: mkldnn_reorder_conv3d_weight(Tensor self, int[3] padding=0, int[3] stride=1, int[3] dilation=1, int groups=1) -> Tensor - use_c10_dispatcher: full variants: function python_module: nn dispatch: MkldnnCPU: mkldnn_reorder_conv3d_weight - func: to_mkldnn_backward(Tensor grad, Tensor input) -> Tensor - use_c10_dispatcher: full - func: quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CPU, CUDA: quantize_per_tensor - func: quantize_per_tensor.tensors(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype) -> Tensor[] - use_c10_dispatcher: full variants: function dispatch: CPU: quantize_per_tensor_list_cpu - func: quantize_per_channel(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CPU: quantize_per_channel_cpu - func: dequantize.self(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: QuantizedCPU, QuantizedCUDA: dequantize_quant - func: dequantize.tensors(Tensor[] tensors) -> Tensor[] - use_c10_dispatcher: full variants: function dispatch: QuantizedCPU: dequantize_tensors_quantized_cpu - func: q_scale(Tensor self) -> float - use_c10_dispatcher: full variants: function, method dispatch: QuantizedCPU, QuantizedCUDA: q_scale_quant - func: q_zero_point(Tensor self) -> int - use_c10_dispatcher: full variants: function, method dispatch: QuantizedCPU, QuantizedCUDA: q_zero_point_quant - func: q_per_channel_scales(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: QuantizedCPU, QuantizedCUDA: q_per_channel_scales - func: q_per_channel_zero_points(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: QuantizedCPU, QuantizedCUDA: q_per_channel_zero_points - func: q_per_channel_axis(Tensor self) -> int - use_c10_dispatcher: full variants: function, method dispatch: QuantizedCPU, QuantizedCUDA: q_per_channel_axis - func: int_repr(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: QuantizedCPU: int_repr_quantized_cpu QuantizedCUDA: int_repr_quantized_cuda - func: _make_per_tensor_quantized_tensor(Tensor self, float scale, int zero_point) -> Tensor - use_c10_dispatcher: full dispatch: CPU: make_per_tensor_quantized_tensor_cpu CUDA: make_per_tensor_quantized_tensor_cuda - func: _make_per_channel_quantized_tensor(Tensor self, Tensor scale, Tensor zero_point, int axis) -> Tensor - use_c10_dispatcher: full dispatch: CPU: make_per_channel_quantized_tensor_cpu - func: qscheme(Tensor self) -> QScheme - use_c10_dispatcher: full variants: method dispatch: QuantizedCPU, QuantizedCUDA: qscheme_quant - func: fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CPU, CUDA: fake_quantize_per_tensor_affine - func: fake_quantize_per_tensor_affine_backward(Tensor grad, Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor - use_c10_dispatcher: full variants: function - func: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CPU, CUDA: _fake_quantize_learnable_per_tensor_affine - func: _fake_quantize_learnable_per_tensor_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full variants: function - func: fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CPU, CUDA: fake_quantize_per_channel_affine - func: fake_quantize_per_channel_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor - use_c10_dispatcher: full variants: function - func: _fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CPU, CUDA: _fake_quantize_learnable_per_channel_affine - func: _fake_quantize_learnable_per_channel_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full variants: function - func: _choose_qparams_per_tensor(Tensor self, bool reduce_range=False) -> (float, int) - use_c10_dispatcher: full variants: function - func: _saturate_weight_to_fp16(Tensor weight) -> Tensor - use_c10_dispatcher: full variants: function -- func: choose_qparams_optimized(Tensor input, int numel, int n_bins, float ratio, int bit_width) -> (float, float) - use_c10_dispatcher: full +- func: choose_qparams_optimized(Tensor input, int numel, int n_bins, float ratio, int bit_width) -> (Tensor, Tensor) variants: function # to(Device) must not exist because all constructors of Device also works for # TensorOptions. Otherwise, an ambiguity error is thrown. # See NOTE [ TensorOptions Constructors ]. - func: to.dtype_layout(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: method device_guard: False - func: to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor - use_c10_dispatcher: full variants: method device_guard: False - func: to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor - use_c10_dispatcher: full variants: method device_guard: False - func: to.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor - use_c10_dispatcher: full variants: method device_guard: False - func: meshgrid(Tensor[] tensors) -> Tensor[] - use_c10_dispatcher: full - func: cartesian_prod(Tensor[] tensors) -> Tensor - use_c10_dispatcher: full variants: function - func: combinations(Tensor self, int r=2, bool with_replacement=False) -> Tensor - use_c10_dispatcher: full variants: function - func: item(Tensor self) -> Scalar - use_c10_dispatcher: full variants: method - func: result_type.Tensor(Tensor tensor, Tensor other) -> ScalarType - use_c10_dispatcher: full variants: function - func: result_type.Scalar(Tensor tensor, Scalar other) -> ScalarType - use_c10_dispatcher: full variants: function - func: result_type.Scalar_Tensor(Scalar scalar, Tensor tensor) -> ScalarType - use_c10_dispatcher: full variants: function - func: result_type.Scalar_Scalar(Scalar scalar1, Scalar scalar2) -> ScalarType - use_c10_dispatcher: full - func: can_cast(ScalarType from, ScalarType to) -> bool - use_c10_dispatcher: full variants: function - func: promote_types(ScalarType type1, ScalarType type2) -> ScalarType - use_c10_dispatcher: full variants: function # NB: Does NOT check precondition that numel == 1 - func: _local_scalar_dense(Tensor self) -> Scalar - use_c10_dispatcher: full dispatch: CPU: _local_scalar_dense_cpu CUDA: _local_scalar_dense_cuda @@ -4497,107 +4749,93 @@ # Fused RNN kernels - func: _thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: _thnn_fused_lstm_cell_cuda - func: _thnn_fused_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: _thnn_fused_lstm_cell_backward_cuda - func: _thnn_differentiable_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor input_gates, Tensor hidden_gates, Tensor? input_bias, Tensor? hidden_bias, Tensor cx, Tensor cy) -> (Tensor, Tensor, Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: _thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CUDA: _thnn_fused_gru_cell_cuda - func: _thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) - use_c10_dispatcher: full dispatch: CUDA: _thnn_fused_gru_cell_backward_cuda - func: _thnn_differentiable_gru_cell_backward(Tensor grad_hy, Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias, Tensor? hidden_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures # RNN cells and layers - func: lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full - func: lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full - func: gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) - use_c10_dispatcher: full - func: gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) - use_c10_dispatcher: full - func: rnn_tanh.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) - use_c10_dispatcher: full - func: rnn_tanh.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) - use_c10_dispatcher: full - func: rnn_relu.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) - use_c10_dispatcher: full - func: rnn_relu.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) - use_c10_dispatcher: full - func: lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures # Quantized RNN layer registration has been moved to C10 dispatch in `RNN.cpp` # Quantized RNN layers # - func: quantized_lstm(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor) -# use_c10_dispatcher: full + # - func: quantized_lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor) -# use_c10_dispatcher: full + # Quantized GRU layers # - func: quantized_gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) -# use_c10_dispatcher: full +# # - func: quantized_gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) -# use_c10_dispatcher: full +# # Quantized RNN cells - func: quantized_lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> (Tensor, Tensor) - use_c10_dispatcher: full - func: quantized_gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor - use_c10_dispatcher: full - func: quantized_rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor - use_c10_dispatcher: full - func: quantized_rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor - use_c10_dispatcher: full # PackedSequence utilities - func: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor) - use_c10_dispatcher: full + dispatch: + DefaultBackend: _pack_padded_sequence - func: _pack_padded_sequence_backward(Tensor grad, int[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor - use_c10_dispatcher: full - func: _pad_packed_sequence(Tensor data, Tensor batch_sizes, bool batch_first, Scalar padding_value, int total_length) -> (Tensor, Tensor) - use_c10_dispatcher: full # wrappers for legacy TH methods @@ -4616,112 +4854,103 @@ QuantizedCPU, QuantizedCUDA: set_storage_quantized_ - func: set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!) - use_c10_dispatcher: full variants: method device_guard: False dispatch: CPU, CUDA: set_tensor_ - func: set_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU: set_cpu_ CUDA: set_cuda_ -- func: set_quantizer_(Tensor(a!) self, ConstQuantizerPtr quantizer) -> Tensor(a!) - variants: method - dispatch: - QuantizedCPU, QuantizedCUDA: set_quantizer_ - - func: is_set_to(Tensor self, Tensor tensor) -> bool - use_c10_dispatcher: full variants: method device_guard: False dispatch: CPU, CUDA: is_set_to - func: masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU: masked_fill__cpu CUDA: masked_fill__cuda - func: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor - use_c10_dispatcher: full variants: function, method - func: masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU: masked_fill__cpu CUDA: masked_fill__cuda - func: masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor - use_c10_dispatcher: full variants: function, method - func: masked_scatter_(Tensor(a!) self, Tensor mask, Tensor source) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU: masked_scatter__cpu CUDA: masked_scatter__cuda - func: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor - use_c10_dispatcher: full variants: function, method - func: view(Tensor(a) self, int[] size) -> Tensor(a) - use_c10_dispatcher: full variants: method device_guard: False dispatch: CPU, CUDA, QuantizedCPU, QuantizedCUDA: view MkldnnCPU: mkldnn_view -- func: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) +# Warning: If you want to change the name or overload name of this +# operator, you might also want to change the `isBlockListedSchema` +# function in `torch/csrc/jit/frontend/schema_catching.cpp`. +# The name and overload name of this operator is hardcoded in that +# function in order to workaround a bug: +# https://github.com/pytorch/pytorch/issues/47964 +- func: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) use_c10_dispatcher: full + variants: method + device_guard: False + dispatch: + DefaultBackend: view_dtype + +- func: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) variants: method dispatch: CPU: legacy::cpu::_th_put_ CUDA: legacy::cuda::_th_put_ - func: index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU: index_add_cpu_ CUDA: index_add_cuda_ - func: index_add(Tensor self, int dim, Tensor index, Tensor source) -> Tensor - use_c10_dispatcher: full variants: function, method - func: index_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor variants: function, method - func: index_fill_.int_Scalar(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU: legacy::cpu::_th_index_fill_ CUDA: legacy::cuda::_th_index_fill_ - func: index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor - use_c10_dispatcher: full variants: function, method - func: index_fill_.int_Tensor(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: index_fill_ - func: index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor - use_c10_dispatcher: full variants: function, method - func: index_fill_.Dimname_Scalar(Tensor(a!) self, Dimname dim, Tensor index, Scalar value) -> Tensor(a!) @@ -4737,23 +4966,19 @@ variants: function, method - func: scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: scatter_ - func: scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor - use_c10_dispatcher: full variants: function, method - func: scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: scatter_fill_ - func: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor - use_c10_dispatcher: full variants: function, method - func: scatter.dimname_src(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor @@ -4763,330 +4988,272 @@ variants: function, method - func: scatter_.reduce(Tensor(a!) self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: scatter_reduce_ - func: scatter_.value_reduce(Tensor(a!) self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: scatter_scalar_reduce_ - func: scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: scatter_add_ - func: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor - use_c10_dispatcher: full variants: function, method - func: scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor variants: function, method - func: eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: eq_ - func: eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: eq_ - func: bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function dispatch: CPU, CUDA: bitwise_and_out - func: bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function dispatch: CPU, CUDA: bitwise_and_out - func: bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: __and__.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: __and__.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: __iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: __iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function dispatch: CPU, CUDA: bitwise_or_out - func: bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function dispatch: CPU, CUDA: bitwise_or_out - func: bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: bitwise_or_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: bitwise_or_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: __or__.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: __or__.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: __ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: __ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function dispatch: CPU, CUDA: bitwise_xor_out - func: bitwise_xor.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: function dispatch: CPU, CUDA: bitwise_xor_out - func: bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: __xor__.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: __xor__.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: __ixor__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: __ixor__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: __lshift__.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: __lshift__ - func: __lshift__.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: __lshift__ - func: __ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: __ilshift__ - func: __ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: __ilshift__ - func: __rshift__.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: __rshift__ - func: __rshift__.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: __rshift__ - func: __irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: __irshift__ - func: __irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: __irshift__ - func: lgamma_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU: _lgamma__cpu CUDA: _lgamma__cuda - func: atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: atan2_ - func: tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU: tril_cpu_ CUDA: tril_cuda_ - func: triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU: triu_cpu_ CUDA: triu_cuda_ - func: digamma_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: digamma_ - func: polygamma_(Tensor(a!) self, int n) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU: legacy::cpu::_th_renorm_ CUDA: legacy::cuda::_th_renorm_ -- func: pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!) - use_c10_dispatcher: full - variants: method - dispatch: - CPU, CUDA: pow_ - -- func: pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!) - use_c10_dispatcher: full - variants: method - dispatch: - CPU, CUDA: pow_ - - func: lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU: lerp_cpu_scalar_ CUDA: lerp_cuda_scalar_ - func: lerp_.Tensor(Tensor(a!) self, Tensor end, Tensor weight) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU: lerp_cpu_tensor_ CUDA: lerp_cuda_tensor_ - func: fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: - CPU: fmod_ - CUDA: fmod_cuda_ + CPU, CUDA: fmod_ - func: fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: - CPU: fmod_ - CUDA: fmod_cuda_ + CPU, CUDA: fmod_ - func: remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: remainder_ - func: remainder_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: CPU, CUDA: remainder_ - func: addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: - CPU: addbmm_cpu_ - CUDA: addbmm__cuda + CPU, CUDA: addbmm_ - func: addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: - CPU: addbmm_cpu_out - CUDA: addbmm_out_cuda + CPU, CUDA: addbmm_out - func: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: - CPU: addbmm_cpu - CUDA: addbmm_cuda + CPU, CUDA: addbmm - func: addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: addcdiv_ - func: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!) variants: method @@ -5131,384 +5298,384 @@ # wrappers for TH functions - func: diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: diag_cpu_out CUDA: diag_cuda_out - func: diag(Tensor self, int diagonal=0) -> Tensor - use_c10_dispatcher: full variants: method, function + dispatch: + DefaultBackend: diag - func: diag_backward(Tensor grad, int[] input_sizes, int diagonal) -> Tensor - use_c10_dispatcher: full variants: function device_guard: False - func: cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: cross_out - func: cross(Tensor self, Tensor other, int? dim=None) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: cross - func: triu.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: triu_cpu_out CUDA: triu_cuda_out - func: triu(Tensor self, int diagonal=0) -> Tensor - use_c10_dispatcher: full variants: method, function + dispatch: + DefaultBackend: triu - func: tril.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: tril_cpu_out CUDA: tril_cuda_out - func: tril(Tensor self, int diagonal=0) -> Tensor - use_c10_dispatcher: full variants: method, function + dispatch: + DefaultBackend: tril - func: tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full dispatch: CPU: tril_indices_cpu CUDA: tril_indices_cuda - func: triu_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - use_c10_dispatcher: full dispatch: CPU: triu_indices_cpu CUDA: triu_indices_cuda - func: trace(Tensor self) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: - CPU: legacy::cpu::_th_trace + CPU: trace_cpu CUDA: trace_cuda - func: trace_backward(Tensor grad, int[] sizes) -> Tensor - use_c10_dispatcher: full variants: function device_guard: False - func: ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: ne_out QuantizedCPU: ne_out_quantized_cpu - func: ne.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: ne QuantizedCPU: ne_quantized_cpu - func: ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: ne_out QuantizedCPU: ne_out_quantized_cpu - func: ne.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: ne QuantizedCPU: ne_quantized_cpu - func: ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: ne_ - func: ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: ne_ # not_equal, alias for torch.ne - func: not_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: not_equal.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: not_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: not_equal.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: not_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: not_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: eq_out QuantizedCPU: eq_out_quantized_cpu - func: eq.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: eq QuantizedCPU: eq_quantized_cpu - func: eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: eq_out QuantizedCPU: eq_out_quantized_cpu - func: eq.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: eq QuantizedCPU: eq_quantized_cpu - func: ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: ge_out QuantizedCPU: ge_out_quantized_cpu - func: ge.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: ge QuantizedCPU: ge_quantized_cpu - func: ge.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: ge_out QuantizedCPU: ge_out_quantized_cpu - func: ge.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: ge QuantizedCPU: ge_quantized_cpu - func: ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: ge_ - func: ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: ge_ # greater_equal, alias for torch.ge - func: greater_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: greater_equal.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: greater_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: greater_equal.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: greater_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: greater_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: le.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: le_out QuantizedCPU: le_out_quantized_cpu - func: le.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: le QuantizedCPU: le_quantized_cpu - func: le.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: le_out QuantizedCPU: le_out_quantized_cpu - func: le.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: le QuantizedCPU: le_quantized_cpu - func: le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: le_ - func: le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: le_ # less_equal, alias for torch.le - func: less_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: less_equal.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: less_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: less_equal.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: less_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: less_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: gt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: gt_out QuantizedCPU: gt_out_quantized_cpu - func: gt.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: gt QuantizedCPU: gt_quantized_cpu - func: gt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: gt_out QuantizedCPU: gt_out_quantized_cpu - func: gt.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: gt QuantizedCPU: gt_quantized_cpu - func: gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: gt_ - func: gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: gt_ # greater, alias for torch.gt - func: greater.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: greater.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: greater.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: greater.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: greater_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: greater_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: lt_out QuantizedCPU: lt_out_quantized_cpu - func: lt.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: lt QuantizedCPU: lt_quantized_cpu - func: lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: lt_out QuantizedCPU: lt_out_quantized_cpu - func: lt.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: lt QuantizedCPU: lt_quantized_cpu - func: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: lt_ - func: lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: lt_ # less, alias for torch.lt - func: less.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: less.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: less.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: less.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: less_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: less_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - use_c10_dispatcher: full variants: method - func: take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: - CPU: legacy::cpu::_th_take_out - CUDA: legacy::cuda::_th_take_out + CPU: take_out_cpu + CUDA: take_out_cuda - func: take(Tensor self, Tensor index) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: - CPU: legacy::cpu::_th_take - CUDA: legacy::cuda::_th_take + CPU: take_cpu + CUDA: take_cuda - func: take_backward(Tensor grad, Tensor input, Tensor index) -> Tensor - use_c10_dispatcher: full variants: function device_guard: False - func: index_select.out(Tensor self, int dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: index_select_out_cpu_ CUDA: index_select_out_cuda - func: index_select(Tensor self, int dim, Tensor index) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU: index_select_cpu_ @@ -5517,266 +5684,301 @@ SparseCUDA: index_select_sparse - func: index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: index_select.dimname(Tensor self, Dimname dim, Tensor index) -> Tensor variants: method, function - func: index_select_backward(Tensor grad, int[] self_sizes, int dim, Tensor index) -> Tensor - use_c10_dispatcher: full variants: function device_guard: False - func: masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: masked_select_out_cpu CUDA: masked_select_out_cuda - func: masked_select(Tensor self, Tensor mask) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU: masked_select_cpu CUDA: masked_select_cuda - func: masked_select_backward(Tensor grad, Tensor input, Tensor mask) -> Tensor - use_c10_dispatcher: full variants: function device_guard: False - func: nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: legacy::cpu::_th_nonzero_out CUDA: nonzero_out_cuda - func: nonzero(Tensor self) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU: legacy::cpu::_th_nonzero CUDA: nonzero_cuda - func: nonzero_numpy(Tensor self) -> Tensor[] - use_c10_dispatcher: full variants: method, function - func: gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: gather_out_cpu_cuda CUDA: gather_out_cpu_cuda - func: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: gather - func: gather_backward(Tensor grad, Tensor self, int dim, Tensor index, bool sparse_grad) -> Tensor - use_c10_dispatcher: full variants: function device_guard: False - func: gather.dimname_out(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: gather.dimname(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False) -> Tensor variants: method, function - func: _gather_sparse_backward(Tensor self, int dim, Tensor index, Tensor grad) -> Tensor - use_c10_dispatcher: full - func: addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: addcmul_out - func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor - use_c10_dispatcher: full variants: method, function + dispatch: + DefaultBackend: addcmul - func: addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: addcmul_ - func: addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: addcdiv_out - func: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor - use_c10_dispatcher: full variants: method, function + dispatch: + DefaultBackend: addcdiv - func: lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: legacy::cpu::_th_gels_out CUDA: legacy::cuda::_th_gels_out - func: lstsq(Tensor self, Tensor A) -> (Tensor solution, Tensor QR) - use_c10_dispatcher: full variants: method, function dispatch: CPU: legacy::cpu::_th_gels CUDA: legacy::cuda::_th_gels - func: triangular_solve.X(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!) solution, Tensor(b!) cloned_coefficient) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: triangular_solve_out - func: triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient) - use_c10_dispatcher: full variants: method, function + dispatch: + DefaultBackend: triangular_solve - func: _triangular_solve_helper(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular) -> (Tensor, Tensor) - use_c10_dispatcher: full variants: function dispatch: CPU: _triangular_solve_helper_cpu CUDA: _triangular_solve_helper_cuda - func: symeig.e(Tensor self, bool eigenvectors=False, bool upper=True, *, Tensor(a!) e, Tensor(b!) V) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: symeig_out - func: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors) - use_c10_dispatcher: full variants: method, function + dispatch: + DefaultBackend: symeig - func: _symeig_helper(Tensor self, bool eigenvectors, bool upper) -> (Tensor, Tensor) - use_c10_dispatcher: full variants: function dispatch: CPU: _symeig_helper_cpu CUDA: _symeig_helper_cuda - func: eig.e(Tensor self, bool eigenvectors=False, *, Tensor(a!) e, Tensor(b!) v) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: - CPU: legacy::cpu::_th_eig_out - CUDA: legacy::cuda::_th_eig_out + DefaultBackend: eig_out - func: eig(Tensor self, bool eigenvectors=False) -> (Tensor eigenvalues, Tensor eigenvectors) - use_c10_dispatcher: full variants: method, function dispatch: - CPU: legacy::cpu::_th_eig - CUDA: legacy::cuda::_th_eig + DefaultBackend: eig - func: svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + Math: svd_out - func: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) - use_c10_dispatcher: full variants: method, function + dispatch: + Math: svd -- func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full +- func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor U, Tensor S, Tensor V) variants: function dispatch: CPU: _svd_helper_cpu CUDA: _svd_helper_cuda +# swapaxes, alias for transpose +- func: swapaxes(Tensor(a) self, int axis0, int axis1) -> Tensor(a) + variants: function, method + device_guard: False + +- func: swapaxes_(Tensor(a!) self, int axis0, int axis1) -> Tensor(a!) + variants: method + device_guard: False + +# swapdims, alias for transpose +- func: swapdims(Tensor(a) self, int dim0, int dim1) -> Tensor(a) + variants: function, method + device_guard: False + +- func: swapdims_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) + variants: method + device_guard: False + - func: cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: cholesky_out - func: cholesky(Tensor self, bool upper=False) -> Tensor - use_c10_dispatcher: full variants: method, function + dispatch: + DefaultBackend: cholesky - func: _cholesky_helper(Tensor self, bool upper) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CPU: _cholesky_helper_cpu CUDA: _cholesky_helper_cuda - func: cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: cholesky_solve_out - func: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor - use_c10_dispatcher: full variants: method, function + dispatch: + DefaultBackend: cholesky_solve - func: _cholesky_solve_helper(Tensor self, Tensor A, bool upper) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CPU: _cholesky_solve_helper_cpu CUDA: _cholesky_solve_helper_cuda - func: solve(Tensor self, Tensor A) -> (Tensor solution, Tensor LU) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: solve - func: solve.solution(Tensor self, Tensor A, *, Tensor(a!) solution, Tensor(b!) lu) -> (Tensor(a!) solution, Tensor(b!) LU) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: solve_out - func: _solve_helper(Tensor self, Tensor A) -> (Tensor, Tensor) - use_c10_dispatcher: full variants: function dispatch: CPU: _solve_helper_cpu CUDA: _solve_helper_cuda - func: cholesky_inverse.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: legacy::cpu::_th_potri_out CUDA: legacy::cuda::_th_potri_out - func: cholesky_inverse(Tensor self, bool upper=False) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU: legacy::cpu::_th_potri CUDA: legacy::cuda::_th_potri - func: qr.Q(Tensor self, bool some=True, *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + Math: qr_out - func: qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R) - use_c10_dispatcher: full variants: method, function - -- func: _qr_helper(Tensor self, bool some) -> (Tensor, Tensor) - use_c10_dispatcher: full - variants: function dispatch: - CPU: _qr_helper_cpu - CUDA: _qr_helper_cuda + Math: qr - func: geqrf.a(Tensor self, *, Tensor(a!) a, Tensor(b!) tau) -> (Tensor(a!) a, Tensor(b!) tau) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: legacy::cpu::_th_geqrf_out CUDA: legacy::cuda::_th_geqrf_out - func: geqrf(Tensor self) -> (Tensor a, Tensor tau) - use_c10_dispatcher: full variants: method, function dispatch: CPU: legacy::cpu::_th_geqrf CUDA: legacy::cuda::_th_geqrf - func: orgqr.out(Tensor self, Tensor input2, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: legacy::cpu::_th_orgqr_out - func: orgqr(Tensor self, Tensor input2) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU: legacy::cpu::_th_orgqr - func: ormqr.out(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: legacy::cpu::_th_ormqr_out - func: ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU: legacy::cpu::_th_ormqr - func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full variants: function dispatch: CPU: _lu_with_info_cpu CUDA: _lu_with_info_cuda - func: lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: lu_solve_out - func: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor - use_c10_dispatcher: full variants: method, function + dispatch: + DefaultBackend: lu_solve - func: _lu_solve_helper(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CPU: _lu_solve_helper_cpu @@ -5784,6 +5986,7 @@ # TODO: remove dispatch section when porting TH CUDA to ATen - func: multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: multinomial_out @@ -5792,364 +5995,388 @@ dispatch: CPU, CUDA: multinomial -- func: _multinomial_alias_setup(Tensor probs) -> (Tensor, Tensor) - use_c10_dispatcher: full - variants: function - dispatch: - CPU: legacy::cpu::_th_multinomial_alias_setup - CUDA: legacy::cuda::_th_multinomial_alias_setup - -- func: _multinomial_alias_draw(Tensor J, Tensor q, int num_samples, *, Generator? generator=None) -> Tensor - variants: function - dispatch: - CPU: legacy::cpu::_th_multinomial_alias_draw - CUDA: legacy::cuda::_th_multinomial_alias_draw - - func: lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: _lgamma_out_cpu CUDA: _lgamma_out_cuda - func: lgamma(Tensor self) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: lgamma - func: digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: digamma_out - func: digamma(Tensor self) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: digamma - func: polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: polygamma_out - func: polygamma(int n, Tensor self) -> Tensor - use_c10_dispatcher: full variants: method, function + dispatch: + DefaultBackend: polygamma - func: erfinv(Tensor self) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: erfinv - func: erfinv_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: method dispatch: - CPU: _erfinv__cpu - CUDA: _erfinv__cuda + CPU, CUDA: erfinv_ - func: erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: - CPU: _erfinv_out_cpu - CUDA: _erfinv_out_cuda + CPU, CUDA: erfinv_out - func: i0(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: i0 - func: i0_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: i0_ - func: i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: i0_out - func: sign(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: sign - func: sign_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full variants: method + dispatch: + DefaultBackend: sign_ - func: sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: sign_out - func: signbit(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: signbit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: signbit_out CUDA: signbit_out - func: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor - use_c10_dispatcher: full variants: method, function + dispatch: + DefaultBackend: dist - func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: atan2_out - func: atan2(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: atan2 - func: lerp.Scalar_out(Tensor self, Tensor end, Scalar weight, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: lerp_cpu_scalar_out CUDA: lerp_cuda_scalar_out - func: lerp.Tensor_out(Tensor self, Tensor end, Tensor weight, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: lerp_cpu_tensor_out CUDA: lerp_cuda_tensor_out - func: lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU: lerp_cpu_scalar CUDA: lerp_cuda_scalar - func: lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU: lerp_cpu_tensor CUDA: lerp_cuda_tensor - func: histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: legacy::cpu::_th_histc_out CUDA: _histc_out_cuda - func: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU: legacy::cpu::_th_histc CUDA: _histc_cuda - func: fmod.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: - CPU: fmod_out - CUDA: fmod_cuda_out + CPU, CUDA: fmod_out - func: fmod.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: - CPU: fmod - CUDA: fmod_cuda + CPU, CUDA: fmod - func: fmod.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: - CPU: fmod_out - CUDA: fmod_cuda_out + CPU, CUDA: fmod_out - func: fmod.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: - CPU: fmod - CUDA: fmod_cuda + CPU, CUDA: fmod - func: hypot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: hypot_out - func: hypot(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: hypot - func: hypot_(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + variants: method + dispatch: + DefaultBackend: hypot_ + +- func: igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + CPU, CUDA: igamma_out + +- func: igamma(Tensor self, Tensor other) -> Tensor + variants: method, function + dispatch: + CPU, CUDA: igamma + +- func: igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + variants: method + dispatch: + CPU, CUDA: igamma_ + +- func: igammac.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + CPU, CUDA: igammac_out + +- func: igammac(Tensor self, Tensor other) -> Tensor + variants: method, function + dispatch: + CPU, CUDA: igammac + +- func: igammac_(Tensor(a!) self, Tensor other) -> Tensor(a!) variants: method + dispatch: + CPU, CUDA: igammac_ - func: nextafter.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: nextafter_out - func: nextafter(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: nextafter - func: nextafter_(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: method + dispatch: + DefaultBackend: nextafter_ - func: remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: remainder_out - func: remainder.Scalar(Tensor self, Scalar other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: remainder - func: remainder.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: remainder_out - func: remainder.Tensor(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: remainder - func: min(Tensor self) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: min QuantizedCPU: min_quantized_cpu - func: max(Tensor self) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: max QuantizedCPU: max_quantized_cpu - func: maximum(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: maximum - func: maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: maximum_out # binary max, alias of maximum # NOTE: max is not an alias for maximum, since there is also unary max - func: max.other(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function - func: max.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: minimum(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: minimum - func: minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: minimum_out # binary min, alias for minimum # NOTE: min is not an alias for minimum, since there is also unary min - func: min.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: min.other(Tensor self, Tensor other) -> Tensor - use_c10_dispatcher: full - variants: method, function - -- func: median(Tensor self) -> Tensor - use_c10_dispatcher: full variants: method, function - dispatch: - CPU: median_cpu - CUDA: median_cuda - func: quantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: quantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False) -> Tensor - use_c10_dispatcher: full variants: method, function - func: quantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: quantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False) -> Tensor - use_c10_dispatcher: full variants: method, function - func: nanquantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: nanquantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False) -> Tensor - use_c10_dispatcher: full variants: method, function - func: nanquantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: nanquantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False) -> Tensor - use_c10_dispatcher: full variants: method, function -- func: sort.values(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) +- func: sort.values(Tensor self, int dim=-1, bool descending=False, bool stable=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: sort_out_cpu CUDA: legacy::cuda::_th_sort_out -- func: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) - use_c10_dispatcher: full +- func: sort(Tensor self, int dim=-1, bool descending=False, bool stable=False) -> (Tensor values, Tensor indices) variants: method, function dispatch: CPU: sort_cpu CUDA: legacy::cuda::_th_sort QuantizedCPU: sort_quantized_cpu -- func: sort.dimname_values(Tensor self, Dimname dim, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) +- func: sort.dimname_values(Tensor self, Dimname dim, bool descending=False, bool stable=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + +- func: sort.dimname(Tensor self, Dimname dim, bool descending=False, bool stable=False) -> (Tensor values, Tensor indices) + variants: method, function + +- func: msort.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + Math: msort_out -- func: sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices) +- func: msort(Tensor self) -> Tensor variants: method, function + dispatch: + Math: msort - func: argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor - use_c10_dispatcher: full variants: method, function - func: argsort.dimname(Tensor self, Dimname dim, bool descending=False) -> Tensor variants: method, function - func: topk.values(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: topk_out_cpu CUDA: legacy::cuda::_th_topk_out - func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: topk QuantizedCPU: topk_quantized_cpu - func: all(Tensor self) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: all - func: any(Tensor self) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: any SparseCPU, SparseCUDA: any_sparse - func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: legacy::cpu::_th_renorm_out CUDA: legacy::cuda::_th_renorm_out - func: renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU: legacy::cpu::_th_renorm CUDA: legacy::cuda::_th_renorm - func: unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a) - use_c10_dispatcher: full variants: method device_guard: False dispatch: @@ -6157,13 +6384,11 @@ QuantizedCPU, QuantizedCUDA: unfold - func: unfold_backward(Tensor grad_in, int[] input_sizes, int dim, int size, int step) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CPU, CUDA: unfold_backward - func: equal(Tensor self, Tensor other) -> bool - use_c10_dispatcher: full variants: method, function dispatch: CPU: cpu_equal @@ -6171,42 +6396,94 @@ QuantizedCPU: equal_quantized_cpu - func: pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: pow_out - func: pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor - use_c10_dispatcher: full variants: method, function dispatch: CPU, CUDA: pow - func: pow.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: pow_out - func: pow.Scalar(Scalar self, Tensor exponent) -> Tensor - use_c10_dispatcher: full dispatch: CPU, CUDA: pow - func: pow.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: pow_out SparseCPU, SparseCUDA: pow_out_sparse_scalar - func: pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor - use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: pow SparseCPU, SparseCUDA: pow_sparse_scalar +- func: pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!) + use_c10_dispatcher: full + variants: method + dispatch: + CPU, CUDA: pow_ + +- func: pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!) + use_c10_dispatcher: full + variants: method + dispatch: + CPU, CUDA: pow_ + +- func: float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + Math: float_power_out + +- func: float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor + variants: function, method + dispatch: + Math: float_power + +- func: float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + Math: float_power_out + +- func: float_power.Scalar(Scalar self, Tensor exponent) -> Tensor + dispatch: + Math: float_power + +- func: float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + Math: float_power_out + +- func: float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor + variants: function, method + dispatch: + Math: float_power + +- func: float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!) + variants: method + dispatch: + Math: float_power_ + +- func: float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!) + variants: method + dispatch: + Math: float_power_ + - func: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!) variants: method dispatch: CPU, CUDA: normal_ - func: normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: normal_out @@ -6215,6 +6492,7 @@ CPU, CUDA: normal - func: normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: normal_out @@ -6223,6 +6501,7 @@ CPU, CUDA: normal - func: normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: normal_out @@ -6231,286 +6510,659 @@ CPU, CUDA: normal - func: normal.float_float(float mean, float std, int[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: normal.float_float_out(float mean, float std, int[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: alias(Tensor(a) self) -> Tensor(a) - use_c10_dispatcher: full variants: method, function + dispatch: + DefaultBackend: alias - func: _index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) - use_c10_dispatcher: full dispatch: CPU: legacy::cpu::_th_index_copy_ CUDA: legacy::cuda::_th_index_copy_ - func: _cumsum(Tensor self, int dim) -> Tensor - use_c10_dispatcher: full dispatch: CPU: _cumsum_cpu CUDA: _cumsum_cuda - func: _cumsum.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: _cumsum_out_cpu CUDA: _cumsum_out_cuda - func: _cumprod(Tensor self, int dim) -> Tensor - use_c10_dispatcher: full dispatch: CPU: _cumprod_cpu CUDA: _cumprod_cuda - func: _cumprod.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: _cumprod_out_cpu CUDA: _cumprod_out_cuda - func: _var(Tensor self, bool unbiased=True) -> Tensor - use_c10_dispatcher: full dispatch: CPU: legacy::cpu::_th_var - func: _std(Tensor self, bool unbiased=True) -> Tensor - use_c10_dispatcher: full dispatch: CPU: legacy::cpu::_th_std -- func: _amp_non_finite_check_and_unscale_(Tensor(a!) self, Tensor(b!) found_inf, Tensor inv_scale) -> () - use_c10_dispatcher: full +- func: _amp_foreach_non_finite_check_and_unscale_(Tensor(a!)[] self, Tensor(b!) found_inf, Tensor inv_scale) -> () variants: function dispatch: - CUDA: _amp_non_finite_check_and_unscale_cuda_ + CUDA: _amp_foreach_non_finite_check_and_unscale_cuda_ - func: _amp_update_scale(Tensor(a!) growth_tracker, Tensor current_scale, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor - use_c10_dispatcher: full variants: function dispatch: CUDA: _amp_update_scale_cuda - func: _cat(Tensor[] tensors, int dim=0) -> Tensor - use_c10_dispatcher: full dispatch: CPU: _cat_cpu CUDA: cat_cuda QuantizedCPU: cat_quantized_cpu - func: _cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: _cat_out_cpu CUDA: cat_out_cuda QuantizedCPU: cat_out_quantized_cpu - func: _foreach_add.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[] - use_c10_dispatcher: full - device_guard: False variants: function dispatch: CPU: foreach_tensor_add_scalar_kernel_slow CUDA: foreach_tensor_add_scalar_kernel_cuda - func: _foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () - device_guard: False variants: function dispatch: CPU: foreach_tensor_add_scalar_kernel_slow_ CUDA: foreach_tensor_add_scalar_kernel_cuda_ - func: _foreach_sub.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[] - device_guard: False variants: function dispatch: CPU: foreach_tensor_sub_scalar_kernel_slow CUDA: foreach_tensor_sub_scalar_kernel_cuda - func: _foreach_sub_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () - device_guard: False variants: function dispatch: CPU: foreach_tensor_sub_scalar_kernel_slow_ CUDA: foreach_tensor_sub_scalar_kernel_cuda_ - func: _foreach_mul.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[] - device_guard: False variants: function dispatch: CPU: foreach_tensor_mul_scalar_kernel_slow CUDA: foreach_tensor_mul_scalar_kernel_cuda - func: _foreach_mul_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () - device_guard: False variants: function dispatch: CPU: foreach_tensor_mul_scalar_kernel_slow_ CUDA: foreach_tensor_mul_scalar_kernel_cuda_ - func: _foreach_div.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[] - device_guard: False variants: function dispatch: CPU: foreach_tensor_div_scalar_kernel_slow CUDA: foreach_tensor_div_scalar_kernel_cuda - func: _foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () - device_guard: False variants: function dispatch: CPU: foreach_tensor_div_scalar_kernel_slow_ CUDA: foreach_tensor_div_scalar_kernel_cuda_ -- func: _foreach_add.List(Tensor[] tensors1, Tensor[] tensors2, Scalar alpha=1) -> Tensor[] - device_guard: False +- func: _foreach_add.List(Tensor[] tensors1, Tensor[] tensors2, *, Scalar alpha=1) -> Tensor[] variants: function dispatch: CPU: foreach_tensor_add_list_kernel_slow CUDA: foreach_tensor_add_list_kernel_cuda -- func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, Scalar alpha=1) -> () - device_guard: False +- func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () variants: function dispatch: CPU: foreach_tensor_add_list_kernel_slow_ CUDA: foreach_tensor_add_list_kernel_cuda_ -- func: _foreach_sub.List(Tensor[] tensors1, Tensor[] tensors2, Scalar alpha=1) -> Tensor[] - device_guard: False +- func: _foreach_sub.List(Tensor[] tensors1, Tensor[] tensors2, *, Scalar alpha=1) -> Tensor[] variants: function dispatch: CPU: foreach_tensor_sub_list_kernel_slow CUDA: foreach_tensor_sub_list_kernel_cuda -- func: _foreach_sub_.List(Tensor(a!)[] self, Tensor[] other, Scalar alpha=1) -> () - device_guard: False +- func: _foreach_sub_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () variants: function dispatch: CPU: foreach_tensor_sub_list_kernel_slow_ CUDA: foreach_tensor_sub_list_kernel_cuda_ - func: _foreach_mul.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[] - device_guard: False variants: function dispatch: CPU: foreach_tensor_mul_list_kernel_slow CUDA: foreach_tensor_mul_list_kernel_cuda - func: _foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> () - device_guard: False variants: function dispatch: CPU: foreach_tensor_mul_list_kernel_slow_ CUDA: foreach_tensor_mul_list_kernel_cuda_ -- func: _foreach_div.List(Tensor(a!)[] self, Tensor[] other) -> Tensor[] - device_guard: False +- func: _foreach_div.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[] variants: function dispatch: CPU: foreach_tensor_div_list_kernel_slow CUDA: foreach_tensor_div_list_kernel_cuda - func: _foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> () - device_guard: False variants: function dispatch: CPU: foreach_tensor_div_list_kernel_slow_ CUDA: foreach_tensor_div_list_kernel_cuda_ +- func: _foreach_add.ScalarList(Tensor[] tensors, float[] scalars) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_add_scalarlist_kernel_slow + CUDA: foreach_tensor_add_scalarlist_kernel_cuda + +- func: _foreach_add_.ScalarList(Tensor(a!)[] self, float[] scalars) -> () + variants: function + dispatch: + CPU: foreach_tensor_add_scalarlist_kernel_slow_ + CUDA: foreach_tensor_add_scalarlist_kernel_cuda_ + +- func: _foreach_sub.ScalarList(Tensor[] tensors, float[] scalars) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_sub_scalarlist_kernel_slow + CUDA: foreach_tensor_sub_scalarlist_kernel_cuda + +- func: _foreach_sub_.ScalarList(Tensor(a!)[] self, float[] scalars) -> () + variants: function + dispatch: + CPU: foreach_tensor_sub_scalarlist_kernel_slow_ + CUDA: foreach_tensor_sub_scalarlist_kernel_cuda_ + +- func: _foreach_div.ScalarList(Tensor[] tensors, float[] scalars) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_div_scalarlist_kernel_slow + CUDA: foreach_tensor_div_scalarlist_kernel_cuda + +- func: _foreach_div_.ScalarList(Tensor(a!)[] self, float[] scalars) -> () + variants: function + dispatch: + CPU: foreach_tensor_div_scalarlist_kernel_slow_ + CUDA: foreach_tensor_div_scalarlist_kernel_cuda_ + +- func: _foreach_mul.ScalarList(Tensor[] tensors, float[] scalars) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_mul_scalarlist_kernel_slow + CUDA: foreach_tensor_mul_scalarlist_kernel_cuda + +- func: _foreach_mul_.ScalarList(Tensor(a!)[] self, float[] scalars) -> () + variants: function + dispatch: + CPU: foreach_tensor_mul_scalarlist_kernel_slow_ + CUDA: foreach_tensor_mul_scalarlist_kernel_cuda_ + - func: _foreach_exp(Tensor[] tensors) -> Tensor[] - device_guard: False variants: function dispatch: CPU: foreach_tensor_exp_slow CUDA: foreach_tensor_exp_cuda +- func: _foreach_zero_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_zero_slow_ + CUDA: foreach_tensor_zero_cuda_ + - func: _foreach_exp_(Tensor(a!)[] self) -> () - device_guard: False variants: function dispatch: CPU: foreach_tensor_exp_slow_ CUDA: foreach_tensor_exp_cuda_ - func: _foreach_sqrt(Tensor[] tensors) -> Tensor[] - device_guard: False variants: function dispatch: CPU: foreach_tensor_sqrt_slow CUDA: foreach_tensor_sqrt_cuda - func: _foreach_sqrt_(Tensor(a!)[] self) -> () - device_guard: False variants: function dispatch: CPU: foreach_tensor_sqrt_slow_ CUDA: foreach_tensor_sqrt_cuda_ -- func: _foreach_addcdiv_(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () - device_guard: False +- func: _foreach_abs(Tensor[] tensors) -> Tensor[] variants: function dispatch: - CPU: foreach_tensor_addcdiv_slow_ - CUDA: foreach_tensor_addcdiv_cuda_ + CPU: foreach_tensor_abs_slow + CUDA: foreach_tensor_abs_cuda -- func: _foreach_addcmul_(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () - device_guard: False +- func: _foreach_abs_(Tensor(a!)[] self) -> () variants: function dispatch: - CPU: foreach_tensor_addcmul_slow_ - CUDA: foreach_tensor_addcmul_cuda_ + CPU: foreach_tensor_abs_slow_ + CUDA: foreach_tensor_abs_cuda_ -- func: _foreach_addcdiv(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] - device_guard: False +- func: _foreach_acos(Tensor[] tensors) -> Tensor[] variants: function dispatch: - CPU: foreach_tensor_addcdiv_slow - CUDA: foreach_tensor_addcdiv_cuda + CPU: foreach_tensor_acos_slow + CUDA: foreach_tensor_acos_cuda -- func: _foreach_addcmul(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] - device_guard: False +- func: _foreach_acos_(Tensor(a!)[] self) -> () variants: function dispatch: - CPU: foreach_tensor_addcmul_slow - CUDA: foreach_tensor_addcmul_cuda + CPU: foreach_tensor_acos_slow_ + CUDA: foreach_tensor_acos_cuda_ -- func: _mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor, Tensor) - use_c10_dispatcher: full +- func: _foreach_asin(Tensor[] tensors) -> Tensor[] + variants: function dispatch: - CPU: legacy::cpu::_th_mode - CUDA: legacy::cuda::_th_mode + CPU: foreach_tensor_asin_slow + CUDA: foreach_tensor_asin_cuda -- func: _mode.values(Tensor self, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) +- func: _foreach_asin_(Tensor(a!)[] self) -> () + variants: function dispatch: - CPU: legacy::cpu::_th_mode_out - CUDA: legacy::cuda::_th_mode_out + CPU: foreach_tensor_asin_slow_ + CUDA: foreach_tensor_asin_cuda_ -- func: bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor - use_c10_dispatcher: full +- func: _foreach_atan(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_atan_slow + CUDA: foreach_tensor_atan_cuda + +- func: _foreach_atan_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_atan_slow_ + CUDA: foreach_tensor_atan_cuda_ + +- func: _foreach_ceil(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_ceil_slow + CUDA: foreach_tensor_ceil_cuda + +- func: _foreach_ceil_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_ceil_slow_ + CUDA: foreach_tensor_ceil_cuda_ + +- func: _foreach_cos(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_cos_slow + CUDA: foreach_tensor_cos_cuda + +- func: _foreach_cos_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_cos_slow_ + CUDA: foreach_tensor_cos_cuda_ + +- func: _foreach_cosh(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_cosh_slow + CUDA: foreach_tensor_cosh_cuda + +- func: _foreach_cosh_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_cosh_slow_ + CUDA: foreach_tensor_cosh_cuda_ + +- func: _foreach_erf(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_erf_slow + CUDA: foreach_tensor_erf_cuda + +- func: _foreach_erf_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_erf_slow_ + CUDA: foreach_tensor_erf_cuda_ + +- func: _foreach_erfc(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_erfc_slow + CUDA: foreach_tensor_erfc_cuda + +- func: _foreach_erfc_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_erfc_slow_ + CUDA: foreach_tensor_erfc_cuda_ + +- func: _foreach_expm1(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_expm1_slow + CUDA: foreach_tensor_expm1_cuda + +- func: _foreach_expm1_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_expm1_slow_ + CUDA: foreach_tensor_expm1_cuda_ + +- func: _foreach_floor(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_floor_slow + CUDA: foreach_tensor_floor_cuda + +- func: _foreach_floor_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_floor_slow_ + CUDA: foreach_tensor_floor_cuda_ + +- func: _foreach_log(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_log_slow + CUDA: foreach_tensor_log_cuda + +- func: _foreach_log_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_log_slow_ + CUDA: foreach_tensor_log_cuda_ + +- func: _foreach_log10(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_log10_slow + CUDA: foreach_tensor_log10_cuda + +- func: _foreach_log10_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_log10_slow_ + CUDA: foreach_tensor_log10_cuda_ + +- func: _foreach_log1p(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_log1p_slow + CUDA: foreach_tensor_log1p_cuda + +- func: _foreach_log1p_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_log1p_slow_ + CUDA: foreach_tensor_log1p_cuda_ + +- func: _foreach_log2(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_log2_slow + CUDA: foreach_tensor_log2_cuda + +- func: _foreach_log2_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_log2_slow_ + CUDA: foreach_tensor_log2_cuda_ + +- func: _foreach_neg(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_neg_slow + CUDA: foreach_tensor_neg_cuda + +- func: _foreach_neg_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_neg_slow_ + CUDA: foreach_tensor_neg_cuda_ + +- func: _foreach_tan(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_tan_slow + CUDA: foreach_tensor_tan_cuda + +- func: _foreach_tan_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_tan_slow_ + CUDA: foreach_tensor_tan_cuda_ + +- func: _foreach_tanh(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_tanh_slow + CUDA: foreach_tensor_tanh_cuda + +- func: _foreach_tanh_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_tanh_slow_ + CUDA: foreach_tensor_tanh_cuda_ + +- func: _foreach_sin(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_sin_slow + CUDA: foreach_tensor_sin_cuda + +- func: _foreach_sin_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_sin_slow_ + CUDA: foreach_tensor_sin_cuda_ + +- func: _foreach_sinh(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_sinh_slow + CUDA: foreach_tensor_sinh_cuda + +- func: _foreach_sinh_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_sinh_slow_ + CUDA: foreach_tensor_sinh_cuda_ + +- func: _foreach_round(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_round_slow + CUDA: foreach_tensor_round_cuda + +- func: _foreach_round_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_round_slow_ + CUDA: foreach_tensor_round_cuda_ + +- func: _foreach_lgamma(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_lgamma_slow + CUDA: foreach_tensor_lgamma_cuda + +- func: _foreach_lgamma_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_lgamma_slow_ + CUDA: foreach_tensor_lgamma_cuda_ + +- func: _foreach_frac(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_frac_slow + CUDA: foreach_tensor_frac_cuda + +- func: _foreach_frac_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_frac_slow_ + CUDA: foreach_tensor_frac_cuda_ + +- func: _foreach_reciprocal(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_reciprocal_slow + CUDA: foreach_tensor_reciprocal_cuda + +- func: _foreach_reciprocal_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_reciprocal_slow_ + CUDA: foreach_tensor_reciprocal_cuda_ + +- func: _foreach_sigmoid(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_sigmoid_slow + CUDA: foreach_tensor_sigmoid_cuda + +- func: _foreach_sigmoid_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_sigmoid_slow_ + CUDA: foreach_tensor_sigmoid_cuda_ + +- func: _foreach_trunc(Tensor[] tensors) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_trunc_slow + CUDA: foreach_tensor_trunc_cuda + +- func: _foreach_trunc_(Tensor(a!)[] self) -> () + variants: function + dispatch: + CPU: foreach_tensor_trunc_slow_ + CUDA: foreach_tensor_trunc_cuda_ + +- func: _foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () + variants: function + dispatch: + CPU: foreach_tensor_addcdiv_scalar_slow_ + CUDA: foreach_tensor_addcdiv_scalar_cuda_ + +- func: _foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () + variants: function + dispatch: + CPU: foreach_tensor_addcmul_scalar_slow_ + CUDA: foreach_tensor_addcmul_scalar_cuda_ + +- func: _foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, float[] scalars) -> () + variants: function + dispatch: + CPU: foreach_tensor_addcdiv_scalarlist_slow_ + CUDA: foreach_tensor_addcdiv_scalarlist_cuda_ + +- func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, float[] scalars) -> () + variants: function + dispatch: + CPU: foreach_tensor_addcmul_scalarlist_slow_ + CUDA: foreach_tensor_addcmul_scalarlist_cuda_ + +- func: _foreach_addcdiv.Scalar(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_addcdiv_scalar_slow + CUDA: foreach_tensor_addcdiv_scalar_cuda + +- func: _foreach_addcmul.Scalar(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_addcmul_scalar_slow + CUDA: foreach_tensor_addcmul_scalar_cuda + +- func: _foreach_addcdiv.ScalarList(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, float[] scalars) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_addcdiv_scalarlist_slow + CUDA: foreach_tensor_addcdiv_scalarlist_cuda + +- func: _foreach_addcmul.ScalarList(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, float[] scalars) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_addcmul_scalarlist_slow + CUDA: foreach_tensor_addcmul_scalarlist_cuda + +- func: _foreach_maximum.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_maximum_slow + CUDA: foreach_tensor_maximum_cuda + +- func: _foreach_minimum.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[] + variants: function + dispatch: + CPU: foreach_tensor_minimum_slow + CUDA: foreach_tensor_minimum_cuda + +- func: _mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor, Tensor) + dispatch: + CPU: legacy::cpu::_th_mode + CUDA: legacy::cuda::_th_mode + +- func: _mode.values(Tensor self, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + CPU: legacy::cpu::_th_mode_out + CUDA: legacy::cuda::_th_mode_out + +- func: bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor dispatch: CPU: bucketize_cpu CUDA: bucketize_cuda - func: bucketize.Tensor_out(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: bucketize_out_cpu CUDA: bucketize_out_cuda - func: bucketize.Scalar(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor - use_c10_dispatcher: full dispatch: CPU: bucketize_cpu CUDA: bucketize_cuda - func: searchsorted.Tensor(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False) -> Tensor - use_c10_dispatcher: full dispatch: CPU: searchsorted_cpu CUDA: searchsorted_cuda - func: searchsorted.Tensor_out(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU: searchsorted_out_cpu CUDA: searchsorted_out_cuda - func: searchsorted.Scalar(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False) -> Tensor - use_c10_dispatcher: full dispatch: CPU: searchsorted_cpu CUDA: searchsorted_cuda @@ -6518,394 +7170,408 @@ ## NN wrappers - func: mse_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: mse_loss_out - func: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: mse_loss - func: mse_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: mse_loss_backward_out - func: mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: mse_loss_backward - func: l1_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn + dispatch: + DefaultBackend: l1_loss_out - func: l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor - use_c10_dispatcher: full python_module: nn + dispatch: + DefaultBackend: l1_loss - func: l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: l1_loss_backward_out - func: l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor - use_c10_dispatcher: full python_module: nn + dispatch: + DefaultBackend: l1_loss_backward - func: multi_margin_loss.out(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: multi_margin_loss_cpu_out CUDA: legacy::cuda::_thnn_multi_margin_loss_forward_out - func: multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: multi_margin_loss_cpu CUDA: legacy::cuda::_thnn_multi_margin_loss_forward - func: multi_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: multi_margin_loss_cpu_backward_out CUDA: legacy::cuda::_thnn_multi_margin_loss_backward_out - func: multi_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: multi_margin_loss_cpu_backward CUDA: legacy::cuda::_thnn_multi_margin_loss_backward - func: multilabel_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn - func: multilabel_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor - use_c10_dispatcher: full python_module: nn - func: multilabel_margin_loss_forward.output(Tensor self, Tensor target, int reduction, *, Tensor(a!) output, Tensor(b!) is_target) -> (Tensor(a!), Tensor(b!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: multilabel_margin_loss_forward_out_cpu CUDA: legacy::cuda::_thnn_multilabel_margin_loss_forward_out - func: multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target) - use_c10_dispatcher: full python_module: nn dispatch: CPU: multilabel_margin_loss_forward_cpu CUDA: legacy::cuda::_thnn_multilabel_margin_loss_forward - func: multilabel_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: multilabel_margin_loss_backward_cpu_out CUDA: legacy::cuda::_thnn_multilabel_margin_loss_backward_out - func: multilabel_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: multilabel_margin_loss_backward_cpu CUDA: legacy::cuda::_thnn_multilabel_margin_loss_backward - func: nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, int ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn - func: nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, int ignore_index=-100) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn - func: nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: nll_loss_forward_out_cpu CUDA: legacy::cuda::_thnn_nll_loss_forward_out - func: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: nll_loss_forward_cpu CUDA: legacy::cuda::_thnn_nll_loss_forward - func: nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: nll_loss_backward_out_cpu CUDA: legacy::cuda::_thnn_nll_loss_backward_out - func: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: nll_loss_backward_cpu CUDA: legacy::cuda::_thnn_nll_loss_backward - func: nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, int ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn - func: nll_loss2d(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, int ignore_index=-100) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn - func: nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: nll_loss2d_forward_out_cpu CUDA: legacy::cuda::_thnn_nll_loss2d_forward_out - func: nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: nll_loss2d_forward_cpu CUDA: legacy::cuda::_thnn_nll_loss2d_forward - func: nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: nll_loss2d_backward_out_cpu CUDA: legacy::cuda::_thnn_nll_loss2d_backward_out - func: nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: nll_loss2d_backward_cpu CUDA: legacy::cuda::_thnn_nll_loss2d_backward -- func: smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) +- func: smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: smooth_l1_loss_out CUDA: smooth_l1_loss_out -- func: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor - use_c10_dispatcher: full +- func: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor python_module: nn dispatch: CPU, CUDA: smooth_l1_loss -- func: smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!) +- func: smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: smooth_l1_loss_backward_out CUDA: smooth_l1_loss_backward_out -- func: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor - use_c10_dispatcher: full +- func: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor python_module: nn + dispatch: + DefaultBackend: smooth_l1_loss_backward - func: soft_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn + dispatch: + DefaultBackend: soft_margin_loss_out - func: soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor - use_c10_dispatcher: full python_module: nn + dispatch: + DefaultBackend: soft_margin_loss - func: soft_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn + dispatch: + DefaultBackend: soft_margin_loss_backward_out - func: soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor - use_c10_dispatcher: full python_module: nn + dispatch: + DefaultBackend: soft_margin_loss_backward - func: elu.out(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: elu_out - func: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: elu -- func: elu_backward.grad_input(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) - python_module: nn - dispatch: - CPU, CUDA: elu_backward_out - -- func: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, Tensor output) -> Tensor +- func: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: elu_backward - func: elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!) - use_c10_dispatcher: full python_module: nn + dispatch: + DefaultBackend: elu_ - func: glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: glu_out CUDA: legacy::cuda::_thnn_glu_forward_out - func: glu(Tensor self, int dim=-1) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: glu CUDA: legacy::cuda::_thnn_glu_forward - func: glu_backward.grad_input(Tensor grad_output, Tensor self, int dim, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: glu_backward_out CUDA: legacy::cuda::_thnn_glu_backward_out - func: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: glu_backward CUDA: legacy::cuda::_thnn_glu_backward - func: hardsigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: hardsigmoid_out - func: hardsigmoid(Tensor self) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: hardsigmoid QuantizedCPU: hardsigmoid_quantized_cpu - func: hardsigmoid_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: hardsigmoid_ - func: hardsigmoid_backward(Tensor grad_output, Tensor self) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: hardsigmoid_backward - func: hardtanh.out(Tensor self, Scalar min_val=-1, Scalar max_val=1, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: hardtanh_out QuantizedCPU: hardtanh_out_quantized_cpu - func: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: hardtanh QuantizedCPU: hardtanh_quantized_cpu - func: hardtanh_backward.grad_input(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: hardtanh_backward_out - func: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: hardtanh_backward - func: hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> Tensor(a!) - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: hardtanh_ QuantizedCPU: hardtanh_quantized_cpu_ - func: hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: hardswish_out - func: hardswish(Tensor self) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: hardswish - func: hardswish_(Tensor(a!) self) -> Tensor(a!) - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: hardswish_ - func: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: hardswish_backward - func: leaky_relu.out(Tensor self, Scalar negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: leaky_relu_out QuantizedCPU: leaky_relu_out_quantized_cpu - func: leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: leaky_relu - QuantizedCPU: heaky_relu_quantized_cpu + QuantizedCPU: leaky_relu_quantized_cpu - func: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: leaky_relu_backward - func: leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!) - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: leaky_relu_ QuantizedCPU: leaky_relu_quantized_cpu_ - func: log_sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn - func: log_sigmoid(Tensor self) -> Tensor - use_c10_dispatcher: full python_module: nn - func: log_sigmoid_forward.output(Tensor self, *, Tensor(a!) output, Tensor(b!) buffer) -> (Tensor(a!), Tensor(b!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: log_sigmoid_forward_out_cpu CUDA: legacy::cuda::_thnn_log_sigmoid_forward_out - func: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer) - use_c10_dispatcher: full python_module: nn dispatch: CPU: log_sigmoid_forward_cpu CUDA: legacy::cuda::_thnn_log_sigmoid_forward - func: log_sigmoid_backward.grad_input(Tensor grad_output, Tensor self, Tensor buffer, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: log_sigmoid_backward_out_cpu CUDA: legacy::cuda::_thnn_log_sigmoid_backward_out - func: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: log_sigmoid_backward_cpu CUDA: legacy::cuda::_thnn_log_sigmoid_backward - func: rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: rrelu_with_noise_out_cpu @@ -6918,8 +7584,9 @@ CUDA: legacy::cuda::_thnn_rrelu_with_noise_forward - func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor - use_c10_dispatcher: full python_module: nn + dispatch: + DefaultBackend: rrelu_with_noise_backward - func: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) python_module: nn @@ -6928,79 +7595,77 @@ CUDA: legacy::cuda::_thnn_rrelu_with_noise_forward_ - func: softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: softplus_out - func: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: softplus - func: softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: softplus_backward_out - func: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, Tensor output) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: softplus_backward - func: softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: softshrink_out - func: softshrink(Tensor self, Scalar lambd=0.5) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: softshrink - func: softshrink_backward.grad_input(Tensor grad_output, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: softshrink_backward_out - func: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: softshrink_backward - func: adaptive_avg_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: adaptive_avg_pool2d_out_cpu MkldnnCPU: mkldnn_adaptive_avg_pool2d_out - func: adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor - use_c10_dispatcher: full python_module: nn - func: mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor - use_c10_dispatcher: full dispatch: MkldnnCPU: mkldnn_adaptive_avg_pool2d - func: _adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor - use_c10_dispatcher: full dispatch: CPU: adaptive_avg_pool2d_cpu CUDA: adaptive_avg_pool2d_cuda QuantizedCPU: adaptive_avg_pool2d_quantized_cpu - func: _adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: adaptive_avg_pool2d_backward_cpu CUDA: adaptive_avg_pool2d_backward_cuda - func: adaptive_avg_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: adaptive_avg_pool3d_out_cpu @@ -7008,7 +7673,6 @@ QuantizedCPU: adaptive_avg_pool3d_out_quantized_cpu - func: adaptive_avg_pool3d(Tensor self, int[3] output_size) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: adaptive_avg_pool3d_cpu @@ -7016,13 +7680,13 @@ QuantizedCPU: adaptive_avg_pool3d_quantized_cpu - func: adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: adaptive_avg_pool3d_backward_out_cpu CUDA: adaptive_avg_pool3d_backward_out_cuda - func: adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: adaptive_avg_pool3d_backward_cpu @@ -7030,6 +7694,7 @@ # Return: (Tensor output, Tensor indices) - func: adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: adaptive_max_pool2d_out_cpu @@ -7037,20 +7702,19 @@ # Return: (Tensor output, Tensor indices) - func: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor) - use_c10_dispatcher: full python_module: nn dispatch: CPU: adaptive_max_pool2d_cpu CUDA: adaptive_max_pool2d_cuda - func: adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: adaptive_max_pool2d_backward_out_cpu CUDA: adaptive_max_pool2d_backward_out_cuda - func: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: adaptive_max_pool2d_backward_cpu @@ -7058,6 +7722,7 @@ # Return: (Tensor output, Tensor indices) - func: adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: adaptive_max_pool3d_out_cpu @@ -7065,26 +7730,26 @@ # Return: (Tensor output, Tensor indices) - func: adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor) - use_c10_dispatcher: full python_module: nn dispatch: CPU: adaptive_max_pool3d_cpu CUDA: adaptive_max_pool3d_cuda - func: adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: adaptive_max_pool3d_backward_out_cpu CUDA: adaptive_max_pool3d_backward_out_cuda - func: adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: adaptive_max_pool3d_backward_cpu CUDA: adaptive_max_pool3d_backward_cuda - func: avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: avg_pool2d_out_cpu @@ -7092,7 +7757,6 @@ MkldnnCPU: mkldnn_avg_pool2d_out - func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: avg_pool2d_cpu @@ -7101,19 +7765,20 @@ QuantizedCPU: avg_pool2d_quantized_cpu - func: avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: avg_pool2d_backward_out_cpu CUDA: avg_pool2d_backward_out_cuda - func: avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: avg_pool2d_backward_cpu CUDA: avg_pool2d_backward_cuda - func: avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: avg_pool3d_out_cpu @@ -7121,7 +7786,6 @@ MkldnnCPU: mkldnn_avg_pool3d_out - func: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: avg_pool3d_cpu @@ -7130,13 +7794,13 @@ QuantizedCPU: avg_pool3d_quantized_cpu - func: avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: avg_pool3d_backward_out_cpu CUDA: avg_pool3d_backward_out_cuda - func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: avg_pool3d_backward_cpu @@ -7144,6 +7808,7 @@ # Return: (Tensor output, Tensor indices) - func: fractional_max_pool2d.output(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: fractional_max_pool2d_out_cpu @@ -7151,20 +7816,19 @@ # Return: (Tensor output, Tensor indices) - func: fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor) - use_c10_dispatcher: full python_module: nn dispatch: CPU: fractional_max_pool2d_cpu CUDA: fractional_max_pool2d_cuda - func: fractional_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: fractional_max_pool2d_backward_out_cpu CUDA: fractional_max_pool2d_backward_out_cuda - func: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: fractional_max_pool2d_backward_cpu @@ -7172,6 +7836,7 @@ # Return: (Tensor output, Tensor indices) - func: fractional_max_pool3d.output(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: fractional_max_pool3d_out_cpu @@ -7179,20 +7844,19 @@ # Return: (Tensor output, Tensor indices) - func: fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor) - use_c10_dispatcher: full python_module: nn dispatch: CPU: fractional_max_pool3d_cpu CUDA: fractional_max_pool3d_cuda - func: fractional_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: fractional_max_pool3d_backward_out_cpu CUDA: fractional_max_pool3d_backward_out_cuda - func: fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: fractional_max_pool3d_backward_cpu @@ -7200,6 +7864,7 @@ # Return: (Tensor output, Tensor indices) - func: max_pool2d_with_indices.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: max_pool2d_with_indices_out_cpu @@ -7207,20 +7872,19 @@ # Return: (Tensor output, Tensor indices) - func: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) - use_c10_dispatcher: full python_module: nn dispatch: CPU: max_pool2d_with_indices_cpu CUDA: max_pool2d_with_indices_cuda - func: max_pool2d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: max_pool2d_with_indices_backward_out_cpu CUDA: max_pool2d_with_indices_backward_out_cuda - func: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: max_pool2d_with_indices_backward_cpu @@ -7228,6 +7892,7 @@ # Return: (Tensor output, Tensor indices) - func: max_pool3d_with_indices.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: max_pool3d_with_indices_out_cpu @@ -7235,224 +7900,219 @@ # Return: (Tensor output, Tensor indices) - func: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) - use_c10_dispatcher: full python_module: nn dispatch: CPU: max_pool3d_with_indices_cpu CUDA: max_pool3d_with_indices_cuda - func: max_pool3d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: max_pool3d_with_indices_backward_out_cpu CUDA: max_pool3d_with_indices_backward_out_cuda - func: max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: max_pool3d_with_indices_backward_cpu CUDA: max_pool3d_with_indices_backward_cuda - func: max_unpool2d.out(Tensor self, Tensor indices, int[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: max_unpooling2d_forward_out_cpu CUDA: max_unpooling2d_forward_out_cuda - func: max_unpool2d(Tensor self, Tensor indices, int[2] output_size) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: max_unpooling2d_forward_cpu CUDA: max_unpooling2d_forward_cuda - func: max_unpool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, int[2] output_size, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: max_unpooling2d_backward_out_cpu CUDA: max_unpooling2d_backward_out_cuda - func: max_unpool2d_backward(Tensor grad_output, Tensor self, Tensor indices, int[2] output_size) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: max_unpooling2d_backward_cpu CUDA: max_unpooling2d_backward_cuda - func: max_unpool3d.out(Tensor self, Tensor indices, int[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: max_unpooling3d_forward_out_cpu CUDA: max_unpooling3d_forward_out_cuda - func: max_unpool3d(Tensor self, Tensor indices, int[3] output_size, int[3] stride, int[3] padding) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: max_unpooling3d_forward_cpu CUDA: max_unpooling3d_forward_cuda - func: max_unpool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, int[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: max_unpooling3d_backward_out_cpu CUDA: max_unpooling3d_backward_out_cuda - func: max_unpool3d_backward(Tensor grad_output, Tensor self, Tensor indices, int[3] output_size, int[3] stride, int[3] padding) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: max_unpooling3d_backward_cpu CUDA: max_unpooling3d_backward_cuda - func: reflection_pad1d.out(Tensor self, int[2] padding, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: - CPU: reflection_pad1d_out_cpu + CPU, QuantizedCPU: reflection_pad1d_out_cpu CUDA: reflection_pad1d_out_cuda - func: reflection_pad1d(Tensor self, int[2] padding) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: - CPU: reflection_pad1d_cpu + CPU, QuantizedCPU: reflection_pad1d_cpu CUDA: reflection_pad1d_cuda - QuantizedCPU: reflection_pad1d_cpu - func: reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, int[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: reflection_pad1d_backward_out_cpu CUDA: reflection_pad1d_backward_out_cuda - func: reflection_pad1d_backward(Tensor grad_output, Tensor self, int[2] padding) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: reflection_pad1d_backward_cpu CUDA: reflection_pad1d_backward_cuda - func: reflection_pad2d.out(Tensor self, int[4] padding, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: - CPU: reflection_pad2d_out_cpu + CPU, QuantizedCPU: reflection_pad2d_out_cpu CUDA: reflection_pad2d_out_cuda - func: reflection_pad2d(Tensor self, int[4] padding) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: - CPU: reflection_pad2d_cpu + CPU, QuantizedCPU: reflection_pad2d_cpu CUDA: reflection_pad2d_cuda - func: reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, int[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: reflection_pad2d_backward_out_cpu CUDA: reflection_pad2d_backward_out_cuda - func: reflection_pad2d_backward(Tensor grad_output, Tensor self, int[4] padding) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: reflection_pad2d_backward_cpu CUDA: reflection_pad2d_backward_cuda - func: replication_pad1d.out(Tensor self, int[2] padding, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: replication_pad1d_out_cpu CUDA: replication_pad1d_out_cuda - func: replication_pad1d(Tensor self, int[2] padding) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: replication_pad1d_cpu CUDA: replication_pad1d_cuda - func: replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, int[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: replication_pad1d_backward_out_cpu CUDA: replication_pad1d_backward_out_cuda - func: replication_pad1d_backward(Tensor grad_output, Tensor self, int[2] padding) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: replication_pad1d_backward_cpu CUDA: replication_pad1d_backward_cuda - func: replication_pad2d.out(Tensor self, int[4] padding, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: replication_pad2d_out_cpu CUDA: replication_pad2d_out_cuda - func: replication_pad2d(Tensor self, int[4] padding) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: replication_pad2d_cpu CUDA: replication_pad2d_cuda - func: replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, int[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: replication_pad2d_backward_out_cpu CUDA: replication_pad2d_backward_out_cuda - func: replication_pad2d_backward(Tensor grad_output, Tensor self, int[4] padding) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: replication_pad2d_backward_cpu CUDA: replication_pad2d_backward_cuda - func: replication_pad3d.out(Tensor self, int[6] padding, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: replication_pad3d_out_cpu CUDA: replication_pad3d_out_cuda - func: replication_pad3d(Tensor self, int[6] padding) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: replication_pad3d_cpu CUDA: replication_pad3d_cuda - func: replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, int[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: replication_pad3d_backward_out_cpu CUDA: replication_pad3d_backward_out_cuda - func: replication_pad3d_backward(Tensor grad_output, Tensor self, int[6] padding) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: replication_pad3d_backward_cpu CUDA: replication_pad3d_backward_cuda - func: upsample_linear1d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_linear1d_cpu CUDA: upsample_linear1d_cuda - func: upsample_linear1d_backward.vec(Tensor grad_output, int[]? output_size, int[] input_size, bool align_corners, float[]? scale_factors) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_linear1d_backward_cpu CUDA: upsample_linear1d_backward_cuda - func: upsample_bilinear2d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_bilinear2d_cpu @@ -7460,56 +8120,46 @@ QuantizedCPU: upsample_bilinear2d_quantized_cpu - func: upsample_bilinear2d_backward.vec(Tensor grad_output, int[]? output_size, int[] input_size, bool align_corners, float[]? scale_factors) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_bilinear2d_backward_cpu CUDA: upsample_bilinear2d_backward_cuda - func: upsample_trilinear3d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_trilinear3d_cpu CUDA: upsample_trilinear3d_cuda - func: upsample_trilinear3d_backward.vec(Tensor grad_output, int[]? output_size, int[] input_size, bool align_corners, float[]? scale_factors) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_trilinear3d_backward_cpu CUDA: upsample_trilinear3d_backward_cuda - func: upsample_bicubic2d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_bicubic2d_cpu CUDA: upsample_bicubic2d_cuda - func: upsample_bicubic2d_backward.vec(Tensor grad_output, int[]? output_size, int[] input_size, bool align_corners, float[]? scale_factors) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_bicubic2d_backward_cpu CUDA: upsample_bicubic2d_backward_cuda - func: upsample_nearest1d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: - CPU: upsample_nearest1d_cpu - CUDA: upsample_nearest1d_cuda + DefaultBackend: upsample_nearest1d - func: upsample_nearest1d_backward.vec(Tensor grad_output, int[]? output_size, int[] input_size, float[]? scale_factors) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: - CPU: upsample_nearest1d_backward_cpu - CUDA: upsample_nearest1d_backward_cuda + DefaultBackend: upsample_nearest1d_backward - func: upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_nearest2d_cpu @@ -7517,14 +8167,12 @@ QuantizedCPU: upsample_nearest2d_quantized_cpu - func: upsample_nearest2d_backward.vec(Tensor grad_output, int[]? output_size, int[] input_size, float[]? scale_factors) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_nearest2d_backward_cpu CUDA: upsample_nearest2d_backward_cuda - func: upsample_nearest3d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_nearest3d_cpu @@ -7532,7 +8180,6 @@ QuantizedCPU: upsample_nearest3d_quantized_cpu - func: upsample_nearest3d_backward.vec(Tensor grad_output, int[]? output_size, int[] input_size, float[]? scale_factors) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_nearest3d_backward_cpu @@ -7540,39 +8187,39 @@ # NOTE: all of the non-"vec" upsample overloads are only kept for backward compatibility. - func: upsample_linear1d.out(Tensor self, int[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: upsample_linear1d_out_cpu CUDA: upsample_linear1d_out_cuda - func: upsample_linear1d(Tensor self, int[1] output_size, bool align_corners, float? scales=None) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_linear1d_cpu CUDA: upsample_linear1d_cuda - func: upsample_linear1d_backward.grad_input(Tensor grad_output, int[1] output_size, int[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: upsample_linear1d_backward_out_cpu CUDA: upsample_linear1d_backward_out_cuda - func: upsample_linear1d_backward(Tensor grad_output, int[1] output_size, int[3] input_size, bool align_corners, float? scales=None) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_linear1d_backward_cpu CUDA: upsample_linear1d_backward_cuda - func: upsample_bilinear2d.out(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: upsample_bilinear2d_out_cpu CUDA: upsample_bilinear2d_out_cuda - func: upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_bilinear2d_cpu @@ -7580,104 +8227,102 @@ QuantizedCPU: upsample_bilinear2d_quantized_cpu - func: upsample_bilinear2d_backward.grad_input(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: upsample_bilinear2d_backward_out_cpu CUDA: upsample_bilinear2d_backward_out_cuda - func: upsample_bilinear2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_bilinear2d_backward_cpu CUDA: upsample_bilinear2d_backward_cuda - func: upsample_bicubic2d.out(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: upsample_bicubic2d_out_cpu CUDA: upsample_bicubic2d_out_cuda - func: upsample_bicubic2d(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_bicubic2d_cpu CUDA: upsample_bicubic2d_cuda - func: upsample_bicubic2d_backward.grad_input(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: upsample_bicubic2d_backward_out_cpu CUDA: upsample_bicubic2d_backward_out_cuda - func: upsample_bicubic2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_bicubic2d_backward_cpu CUDA: upsample_bicubic2d_backward_cuda - func: upsample_trilinear3d.out(Tensor self, int[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: upsample_trilinear3d_out_cpu CUDA: upsample_trilinear3d_out_cuda - func: upsample_trilinear3d(Tensor self, int[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_trilinear3d_cpu CUDA: upsample_trilinear3d_cuda - func: upsample_trilinear3d_backward.grad_input(Tensor grad_output, int[3] output_size, int[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: upsample_trilinear3d_backward_out_cpu CUDA: upsample_trilinear3d_backward_out_cuda - func: upsample_trilinear3d_backward(Tensor grad_output, int[3] output_size, int[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_trilinear3d_backward_cpu CUDA: upsample_trilinear3d_backward_cuda - func: upsample_nearest1d.out(Tensor self, int[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: full python_module: nn + structured: True dispatch: CPU: upsample_nearest1d_out_cpu CUDA: upsample_nearest1d_out_cuda - func: upsample_nearest1d(Tensor self, int[1] output_size, float? scales=None) -> Tensor - use_c10_dispatcher: full python_module: nn - dispatch: - CPU: upsample_nearest1d_cpu - CUDA: upsample_nearest1d_cuda + structured_delegate: upsample_nearest1d.out - func: upsample_nearest1d_backward.grad_input(Tensor grad_output, int[1] output_size, int[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: full python_module: nn + structured: True dispatch: CPU: upsample_nearest1d_backward_out_cpu CUDA: upsample_nearest1d_backward_out_cuda - func: upsample_nearest1d_backward(Tensor grad_output, int[1] output_size, int[3] input_size, float? scales=None) -> Tensor - use_c10_dispatcher: full python_module: nn - dispatch: - CPU: upsample_nearest1d_backward_cpu - CUDA: upsample_nearest1d_backward_cuda + structured_delegate: upsample_nearest1d_backward.grad_input - func: upsample_nearest2d.out(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: upsample_nearest2d_out_cpu CUDA: upsample_nearest2d_out_cuda - func: upsample_nearest2d(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_nearest2d_cpu @@ -7685,26 +8330,26 @@ QuantizedCPU: upsample_nearest2d_quantized_cpu - func: upsample_nearest2d_backward.grad_input(Tensor grad_output, int[2] output_size, int[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: upsample_nearest2d_backward_out_cpu CUDA: upsample_nearest2d_backward_out_cuda - func: upsample_nearest2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_nearest2d_backward_cpu CUDA: upsample_nearest2d_backward_cuda - func: upsample_nearest3d.out(Tensor self, int[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: upsample_nearest3d_out_cpu CUDA: upsample_nearest3d_out_cuda - func: upsample_nearest3d(Tensor self, int[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_nearest3d_cpu @@ -7712,47 +8357,47 @@ QuantizedCPU: upsample_nearest3d_quantized_cpu - func: upsample_nearest3d_backward.grad_input(Tensor grad_output, int[3] output_size, int[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: upsample_nearest3d_backward_out_cpu CUDA: upsample_nearest3d_backward_out_cuda - func: upsample_nearest3d_backward(Tensor grad_output, int[3] output_size, int[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: upsample_nearest3d_backward_cpu CUDA: upsample_nearest3d_backward_cuda - func: sigmoid_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: sigmoid_backward_out - func: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: sigmoid_backward - func: logit_backward.grad_input(Tensor grad_output, Tensor self, float? eps=None, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: logit_backward_out - func: logit_backward(Tensor grad_output, Tensor self, float? eps=None) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: logit_backward - func: tanh_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU, CUDA: tanh_backward_out - func: tanh_backward(Tensor grad_output, Tensor output) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU, CUDA: tanh_backward @@ -7776,251 +8421,269 @@ # make the operational distinction clear. - func: slow_conv_transpose2d.out(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv_transpose2d_out_cpu CUDA: slow_conv_transpose2d_out_cuda - func: slow_conv_transpose2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int[2] dilation=1) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv_transpose2d_cpu CUDA: slow_conv_transpose2d_cuda -- func: slow_conv_transpose2d_backward.grad_output(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] output_padding, int[2] dilation, Tensor columns, Tensor ones, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight, Tensor(c!)? grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) +- func: slow_conv_transpose2d_backward.grad_output(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] output_padding, int[2] dilation, Tensor columns, Tensor ones, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv_transpose2d_backward_out_cpu CUDA: slow_conv_transpose2d_backward_out_cuda - func: slow_conv_transpose2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] output_padding, int[2] dilation, Tensor columns, Tensor ones, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) - use_c10_dispatcher: full python_module: nn dispatch: CPU: slow_conv_transpose2d_backward_cpu CUDA: slow_conv_transpose2d_backward_cuda - func: slow_conv_transpose3d.out(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv_transpose3d_out_cpu CUDA: slow_conv_transpose3d_out_cuda - func: slow_conv_transpose3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int[3] dilation=1) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv_transpose3d_cpu CUDA: slow_conv_transpose3d_cuda -- func: slow_conv_transpose3d_backward.grad_output(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, int[3] output_padding, int[3] dilation, Tensor finput, Tensor fgrad_input, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight, Tensor(c!)? grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) +- func: slow_conv_transpose3d_backward.grad_output(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, int[3] output_padding, int[3] dilation, Tensor finput, Tensor fgrad_input, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv_transpose3d_backward_out_cpu CUDA: slow_conv_transpose3d_backward_out_cuda - func: slow_conv_transpose3d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, int[3] output_padding, int[3] dilation, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) - use_c10_dispatcher: full python_module: nn dispatch: CPU: slow_conv_transpose3d_backward_cpu CUDA: slow_conv_transpose3d_backward_cuda - func: thnn_conv2d.out(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn - func: thnn_conv2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn - func: thnn_conv2d_forward.output(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding, *, Tensor(a!) output, Tensor(b!) finput, Tensor(c!) fgrad_input) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv2d_forward_out_cpu CUDA: legacy::cuda::_thnn_conv2d_forward_out - func: thnn_conv2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding) -> (Tensor output, Tensor finput, Tensor fgrad_input) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv2d_forward_cpu CUDA: legacy::cuda::_thnn_conv2d_forward -- func: thnn_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight, Tensor(c!)? grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) +- func: thnn_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv2d_backward_out_cpu CUDA: slow_conv2d_backward_out_cuda - func: thnn_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) - use_c10_dispatcher: full python_module: nn dispatch: CPU: slow_conv2d_backward_cpu CUDA: slow_conv2d_backward_cuda - func: thnn_conv_depthwise2d.out(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn - func: thnn_conv_depthwise2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn - func: thnn_conv_depthwise2d_forward.out(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding, int[2] dilation, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CUDA: legacy::cuda::_thnn_conv_depthwise2d_forward_out - func: thnn_conv_depthwise2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding, int[2] dilation) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CUDA: legacy::cuda::_thnn_conv_depthwise2d_forward -- func: thnn_conv_depthwise2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight) -> (Tensor(a!), Tensor(b!)) +- func: thnn_conv_depthwise2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, *, Tensor(a!) grad_input, Tensor(b!) grad_weight) -> (Tensor(a!), Tensor(b!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CUDA: thnn_conv_depthwise2d_backward_out - func: thnn_conv_depthwise2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool[2] output_mask) -> (Tensor grad_input, Tensor grad_weight) - use_c10_dispatcher: full python_module: nn dispatch: CUDA: thnn_conv_depthwise2d_backward - func: slow_conv3d.out(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn - func: slow_conv3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn - func: slow_conv3d_forward.output(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding, *, Tensor(a!) output, Tensor(b!) finput, Tensor(c!) fgrad_input) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv3d_forward_out_cpu - func: slow_conv3d_forward(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding) -> (Tensor output, Tensor finput, Tensor fgrad_input) - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv3d_forward_cpu -- func: slow_conv3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, Tensor finput, Tensor fgrad_input, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight, Tensor(c!)? grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) +- func: slow_conv3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, Tensor finput, Tensor fgrad_input, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv3d_backward_out_cpu - func: slow_conv3d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) - use_c10_dispatcher: full python_module: nn dispatch: CPU: slow_conv3d_backward_cpu - func: slow_conv_dilated2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv_dilated2d_cpu CUDA: slow_conv_dilated2d_cuda - func: slow_conv_dilated2d_backward(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) - use_c10_dispatcher: full python_module: nn dispatch: CPU: slow_conv_dilated2d_backward_cpu CUDA: slow_conv_dilated2d_backward_cuda - func: slow_conv_dilated3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1) -> Tensor - use_c10_dispatcher: full + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv_dilated3d_cpu CUDA: slow_conv_dilated3d_cuda - func: slow_conv_dilated3d_backward(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) - use_c10_dispatcher: full python_module: nn dispatch: CPU: slow_conv_dilated3d_backward_cpu CUDA: slow_conv_dilated3d_backward_cuda - func: col2im.out(Tensor self, int[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: col2im_out_cpu CUDA: col2im_out_cuda - func: col2im(Tensor self, int[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: col2im_cpu CUDA: col2im_cuda - func: col2im_backward.grad_input(Tensor grad_output, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: col2im_backward_out_cpu CUDA: col2im_backward_out_cuda - func: col2im_backward(Tensor grad_output, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: col2im_backward_cpu CUDA: col2im_backward_cuda +- func: column_stack(Tensor[] tensors) -> Tensor + dispatch: + Math: column_stack + +- func: column_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + Math: column_stack_out + - func: im2col.out(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: im2col_out_cpu CUDA: im2col_out_cuda - func: im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: im2col_cpu CUDA: im2col_cuda - func: im2col_backward.grad_input(Tensor grad_output, int[2] input_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) grad_input) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: im2col_backward_out_cpu CUDA: im2col_backward_out_cuda - func: im2col_backward(Tensor grad_output, int[2] input_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: im2col_backward_cpu CUDA: im2col_backward_cuda - func: isfinite(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method device_guard: False - func: isinf(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method device_guard: False +- func: record_stream(Tensor(a!) self, Stream s) -> () + variants: method + dispatch: + CUDA: record_stream_cuda + - func: isposinf(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: isposinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: isposinf_out - func: isneginf(Tensor self) -> Tensor - use_c10_dispatcher: full variants: function, method - func: isneginf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: CPU, CUDA: isneginf_out @@ -8029,12 +8692,10 @@ # of the vmap frontend API (see torch/_vmap_internals.py). They are not # user-facing, hence the leading underscore. Please don't use them them anywhere else. - func: _add_batch_dim(Tensor self, int batch_dim, int level) -> Tensor - use_c10_dispatcher: full variants: function # See NOTE [_add_batch_dim and _remove_batch_dim] - func: _remove_batch_dim(Tensor self, int level, int batch_size, int out_dim) -> Tensor - use_c10_dispatcher: full variants: function ## Functions related to the fast Fourier transform and the torch.fft namespace @@ -8050,57 +8711,157 @@ # NOTE: NOT an alias for torch.fft, which has different semantics - func: fft_fft(Tensor self, int? n=None, int dim=-1, str? norm=None) -> Tensor python_module: fft - use_c10_dispatcher: full + variants: function + +- func: fft_fft.out(Tensor self, int? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft variants: function - func: fft_ifft(Tensor self, int? n=None, int dim=-1, str? norm=None) -> Tensor python_module: fft - use_c10_dispatcher: full + variants: function + +- func: fft_ifft.out(Tensor self, int? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft variants: function - func: fft_rfft(Tensor self, int? n=None, int dim=-1, str? norm=None) -> Tensor python_module: fft - use_c10_dispatcher: full + variants: function + +- func: fft_rfft.out(Tensor self, int? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft variants: function - func: fft_irfft(Tensor self, int? n=None, int dim=-1, str? norm=None) -> Tensor python_module: fft - use_c10_dispatcher: full + variants: function + +- func: fft_irfft.out(Tensor self, int? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft variants: function - func: fft_hfft(Tensor self, int? n=None, int dim=-1, str? norm=None) -> Tensor python_module: fft - use_c10_dispatcher: full + variants: function + +- func: fft_hfft.out(Tensor self, int? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft variants: function - func: fft_ihfft(Tensor self, int? n=None, int dim=-1, str? norm=None) -> Tensor python_module: fft - use_c10_dispatcher: full + variants: function + +- func: fft_ihfft.out(Tensor self, int? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft + variants: function + +- func: fft_fft2(Tensor self, int[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + python_module: fft + variants: function + +- func: fft_fft2.out(Tensor self, int[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft + variants: function + +- func: fft_ifft2(Tensor self, int[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + python_module: fft + variants: function + +- func: fft_ifft2.out(Tensor self, int[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft + variants: function + +- func: fft_rfft2(Tensor self, int[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + python_module: fft + variants: function + +- func: fft_rfft2.out(Tensor self, int[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft + variants: function + +- func: fft_irfft2(Tensor self, int[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + python_module: fft + variants: function + +- func: fft_irfft2.out(Tensor self, int[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft variants: function - func: fft_fftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor python_module: fft - use_c10_dispatcher: full + variants: function + +- func: fft_fftn.out(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft variants: function - func: fft_ifftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor python_module: fft - use_c10_dispatcher: full + variants: function + +- func: fft_ifftn.out(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft variants: function - func: fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor python_module: fft - use_c10_dispatcher: full + variants: function + +- func: fft_rfftn.out(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft variants: function - func: fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor python_module: fft - use_c10_dispatcher: full variants: function -- func: fft(Tensor self, int signal_ndim, bool normalized=False) -> Tensor - use_c10_dispatcher: full - variants: function, method +- func: fft_irfftn.out(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft + variants: function + +- func: fft_fftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft + variants: function + +- func: fft_fftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft + variants: function + +- func: fft_rfftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft + variants: function + +- func: fft_rfftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: fft + variants: function + +- func: fft_fftshift(Tensor self, int[1]? dim=None) -> Tensor + python_module: fft + variants: function + +- func: fft_ifftshift(Tensor self, int[1]? dim=None) -> Tensor + python_module: fft + variants: function ## Functions for linear algebra and the torch.linalg namespace # Note [linalg namespace binding] @@ -8111,28 +8872,115 @@ # # See linalg_det as an example. +- func: linalg_cholesky(Tensor self) -> Tensor + python_module: linalg + variants: function + dispatch: + DefaultBackend: linalg_cholesky + +- func: linalg_cholesky.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: linalg + variants: function + dispatch: + DefaultBackend: linalg_cholesky_out + # torch.linalg.det, alias for torch.det - func: linalg_det(Tensor self) -> Tensor python_module: linalg - use_c10_dispatcher: full variants: function - func: det(Tensor self) -> Tensor + variants: function, method + dispatch: + DefaultBackend: det + +- func: linalg_slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) + python_module: linalg + use_c10_dispatcher: full + variants: function + dispatch: + CPU, CUDA: linalg_slogdet + +- func: linalg_slogdet.out(Tensor self, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet) + python_module: linalg + use_c10_dispatcher: full + dispatch: + CPU, CUDA: linalg_slogdet_out + +- func: _syevd_helper(Tensor self, bool compute_eigenvectors, str uplo) -> (Tensor, Tensor) + variants: function + dispatch: + CPU: _syevd_helper_cpu + CUDA: _syevd_helper_cuda + +- func: linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors) + python_module: linalg + variants: function + dispatch: + DefaultBackend: linalg_eigh + +- func: linalg_eigh.eigvals(Tensor self, str UPLO="L", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: linalg + dispatch: + DefaultBackend: linalg_eigh_out + +- func: linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor + python_module: linalg + variants: function + dispatch: + DefaultBackend: linalg_eigvalsh + +- func: linalg_eigvalsh.out(Tensor self, str UPLO='L', *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: linalg + dispatch: + DefaultBackend: linalg_eigvalsh_out + +- func: _linalg_inv_out_helper_(Tensor(a!) self, Tensor(b!) infos_lu, Tensor(c!) infos_getri) -> Tensor(a!) use_c10_dispatcher: full + variants: function + dispatch: + CPU: _linalg_inv_out_helper_cpu + CUDA: _linalg_inv_out_helper_cuda + +- func: linalg_inv(Tensor self) -> Tensor + python_module: linalg + use_c10_dispatcher: full + variants: function + dispatch: + DefaultBackend: linalg_inv + +- func: linalg_inv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + variants: function + dispatch: + DefaultBackend: linalg_inv_out + +- func: inner(Tensor self, Tensor other) -> Tensor variants: function, method +- func: inner.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + # torch.outer, alias for torch.ger - func: outer(Tensor self, Tensor vec2) -> Tensor - use_c10_dispatcher: full variants: function, method - func: outer.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: ger(Tensor self, Tensor vec2) -> Tensor - use_c10_dispatcher: full variants: function, method + dispatch: + DefaultBackend: ger - func: ger.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: ger_out - func: linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor python_module: linalg @@ -8143,35 +8991,192 @@ variants: function - func: linalg_norm.out(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: linalg variants: function - func: linalg_norm.ord_str_out(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: linalg + variants: function + +- func: linalg_svd.U(Tensor self, bool full_matrices=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: linalg + +- func: linalg_svd(Tensor self, bool full_matrices=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) + python_module: linalg + use_c10_dispatcher: full + variants: function + +- func: linalg_cond(Tensor self, Scalar? p=None) -> Tensor + python_module: linalg + variants: function + dispatch: + Math: linalg_cond + +- func: linalg_cond.out(Tensor self, Scalar? p=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: linalg + variants: function + dispatch: + Math: linalg_cond_out + +- func: linalg_cond.p_str(Tensor self, str p) -> Tensor + python_module: linalg + variants: function + dispatch: + Math: linalg_cond + +- func: linalg_cond.p_str_out(Tensor self, str p, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: linalg + variants: function + dispatch: + Math: linalg_cond_out + +- func: linalg_pinv(Tensor self, float rcond=1e-15, bool hermitian=False) -> Tensor + python_module: linalg + use_c10_dispatcher: full + variants: function + dispatch: + Math: linalg_pinv + +- func: linalg_pinv.rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False) -> Tensor + python_module: linalg + use_c10_dispatcher: full + variants: function + dispatch: + Math: linalg_pinv + +- func: linalg_pinv.out(Tensor self, float rcond=1e-15, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + variants: function + dispatch: + Math: linalg_pinv_out + +- func: linalg_pinv.out_rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + variants: function + dispatch: + Math: linalg_pinv_out + +- func: _linalg_solve_out_helper_(Tensor(a!) self, Tensor(b!) other, Tensor(c!) infos) -> Tensor(a!) + use_c10_dispatcher: full + variants: function + dispatch: + CPU: _linalg_solve_out_helper_cpu + CUDA: _linalg_solve_out_helper_cuda + +- func: linalg_solve(Tensor input, Tensor other) -> Tensor + python_module: linalg + use_c10_dispatcher: full + variants: function + dispatch: + DefaultBackend: linalg_solve + +- func: linalg_solve.out(Tensor input, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + dispatch: + DefaultBackend: linalg_solve_out + +- func: linalg_tensorinv(Tensor self, int ind=2) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: linalg + variants: function + dispatch: + Math: linalg_tensorinv + +- func: linalg_tensorinv.out(Tensor self, int ind=2, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: linalg + variants: function + dispatch: + Math: linalg_tensorinv_out + +- func: linalg_tensorsolve(Tensor self, Tensor other, int[]? dims=None) -> Tensor python_module: linalg variants: function + dispatch: + Math: linalg_tensorsolve + +- func: linalg_tensorsolve.out(Tensor self, Tensor other, int[]? dims=None, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: linalg + variants: function + dispatch: + Math: linalg_tensorsolve_out + +- func: linalg_qr(Tensor self, str mode='reduced') -> (Tensor Q, Tensor R) + python_module: linalg + use_c10_dispatcher: full + variants: function + dispatch: + DefaultBackend: linalg_qr + +- func: linalg_qr.out(Tensor self, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: linalg + variants: function + dispatch: + DefaultBackend: linalg_qr_out + +- func: _linalg_qr_helper(Tensor self, str mode) -> (Tensor, Tensor) + use_c10_dispatcher: full + variants: function + dispatch: + CPU: _linalg_qr_helper_cpu + CUDA: _linalg_qr_helper_cuda + +- func: linalg_matrix_rank(Tensor self, float? tol=None, bool hermitian=False) -> Tensor + python_module: linalg + variants: function + dispatch: + Math: linalg_matrix_rank + +- func: linalg_matrix_rank.out(Tensor self, float? tol=None, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + python_module: linalg + variants: function + dispatch: + Math: linalg_matrix_rank_out ## Functions that are only for testing # It is undocumented and should not be used outside of tests. - func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor - use_c10_dispatcher: full # Note: this function is only for testing. - func: _test_optional_intlist(Tensor values, int[]? addends) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: _test_optional_intlist # Note: this function is only for testing. - func: _test_optional_filled_intlist(Tensor values, int[2]? addends) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: _test_optional_intlist # Note: this function is only for testing. - func: _test_optional_floatlist(Tensor values, float[]? addends) -> Tensor - use_c10_dispatcher: full python_module: nn dispatch: CPU: _test_optional_floatlist + +# Note: this function is only for testing. +- func: _test_string_default(Tensor dummy, str a="\"'\\", str b='"\'\\') -> Tensor + python_module: nn + +# Note: this function is only for testing. +- func: _test_ambiguous_defaults.a(Tensor dummy, int a=1, int b=1) -> Tensor + use_c10_dispatcher: full + python_module: nn + +# Note: this function is only for testing. +- func: _test_ambiguous_defaults.b(Tensor dummy, int a=2, str b="2") -> Tensor + cpp_no_default_args: ['a', 'b'] + use_c10_dispatcher: full + python_module: nn diff --git a/aten/src/ATen/native/quantized/QTensor.cpp b/aten/src/ATen/native/quantized/QTensor.cpp index 9db2a6eb2ac4f..2b3ff71ae5f56 100644 --- a/aten/src/ATen/native/quantized/QTensor.cpp +++ b/aten/src/ATen/native/quantized/QTensor.cpp @@ -128,11 +128,6 @@ QScheme qscheme_quant(const Tensor& self) { return quantizer->qscheme(); } -Tensor& set_quantizer_(Tensor& self, ConstQuantizerPtr quantizer) { - get_qtensorimpl(self)->set_quantizer_(quantizer); - return self; -} - Tensor quantized_clone( const Tensor& self, c10::optional optional_memory_format) { @@ -245,15 +240,14 @@ float calculate_quant_loss( float scale = data_range == 0 ? 1.0 : static_cast(static_cast(data_range / qmax)); - float inverse_scale = 1.0f / scale; + float inverse_scale = scale == 0 ? 1.0f : 1.0f / scale; float norm = 0.0f; - constexpr int VLEN = 8; int i = 0; -// TODO add FBGEMM kernel -// #ifdef USE_FBGEMM -// #endif + // TODO add FBGEMM kernel + // #ifdef USE_FBGEMM + // #endif // remainder loop for (; i < numel; i++) { @@ -271,7 +265,7 @@ float calculate_quant_loss( and tries to minimize the quant error by doing `torch.norm(x-fake_quant(x,s,z))` Returns the optimized xmax and xmin value of the tensor. */ -std::tuple choose_qparams_optimized( +std::tuple choose_qparams_optimized( const at::Tensor& input_tensor, int64_t numel, const int64_t n_bins, @@ -318,7 +312,11 @@ std::tuple choose_qparams_optimized( } } - return std::make_tuple((float) xmax, (float) xmin); + at::Tensor xmax_tensor = at::empty({1}); + at::Tensor xmin_tensor = at::empty({1}); + xmax_tensor[0] = xmax; + xmin_tensor[0] = xmin; + return std::make_tuple(xmax_tensor, xmin_tensor); } } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/TensorCompare.cpp b/aten/src/ATen/native/quantized/TensorCompare.cpp index 38a105f680891..48e14e1b6adb0 100644 --- a/aten/src/ATen/native/quantized/TensorCompare.cpp +++ b/aten/src/ATen/native/quantized/TensorCompare.cpp @@ -23,11 +23,12 @@ Tensor min_quantized_cpu(const Tensor& self) { std::tuple sort_quantized_cpu( const Tensor& self, int64_t dim, - bool descending) { + bool descending, + bool stable) { Tensor sort_int; Tensor sort_indicies; std::tie(sort_int, sort_indicies) = - at::sort(self.int_repr(), dim, descending); + at::sort(self.int_repr(), dim, descending, stable); return std::forward_as_tuple( at::_make_per_tensor_quantized_tensor( sort_int, self.q_scale(), self.q_zero_point()), diff --git a/aten/src/ATen/native/quantized/TensorFactories.cpp b/aten/src/ATen/native/quantized/TensorFactories.cpp index 11b84fa4713d4..a05026d49f46d 100644 --- a/aten/src/ATen/native/quantized/TensorFactories.cpp +++ b/aten/src/ATen/native/quantized/TensorFactories.cpp @@ -20,7 +20,7 @@ Tensor empty_affine_quantized( !(options_.has_memory_format() && optional_memory_format.has_value()), "Cannot set memory_format both in TensorOptions and explicit argument; please delete " "the redundant setter."); - auto options = options_.merge_in(TensorOptions().memory_format(optional_memory_format)); + auto options = options_.merge_memory_format(optional_memory_format); TORCH_CHECK( options.has_dtype(), "Must provide data type for Tensor creation functions."); @@ -42,7 +42,7 @@ Tensor empty_per_channel_affine_quantized( !(options_.has_memory_format() && optional_memory_format.has_value()), "Cannot set memory_format both in TensorOptions and explicit argument; please delete " "the redundant setter."); - auto options = options_.merge_in(TensorOptions().memory_format(optional_memory_format)); + auto options = options_.merge_memory_format(optional_memory_format); TORCH_CHECK( options.has_dtype(), "Must provide data type for Tensor creation functions."); diff --git a/aten/src/ATen/native/quantized/affine_quantizer.cpp b/aten/src/ATen/native/quantized/affine_quantizer.cpp index cbf116d741e3a..ecbe1de4bbfa0 100644 --- a/aten/src/ATen/native/quantized/affine_quantizer.cpp +++ b/aten/src/ATen/native/quantized/affine_quantizer.cpp @@ -17,6 +17,8 @@ DEFINE_DISPATCH(quantize_tensor_per_channel_float_qparams_stub); DEFINE_DISPATCH(dequantize_tensor_per_tensor_affine_stub); DEFINE_DISPATCH(dequantize_tensor_per_channel_affine_stub); DEFINE_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub); +DEFINE_DISPATCH(quantize_tensor_per_tensor_affine_sub_byte_stub); +DEFINE_DISPATCH(dequantize_tensor_per_tensor_affine_sub_byte_stub); namespace { @@ -55,7 +57,8 @@ void checkQuantizedTensor(const std::string& fn_name, Tensor t) { fn_name, " expects a ", caffe2::TypeMeta::Make(), - " Tensor"); + " Tensor, got ", + t.scalar_type()); } template @@ -103,13 +106,21 @@ Tensor quantize_tensor_per_tensor_affine( checkSameDevice(fn_name, rtensor, qtensor); checkSameSize(fn_name, qtensor, rtensor); - AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() { + AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() { checkQuantizedTensor(fn_name, qtensor); checkZeroPoint(fn_name, zero_point); }); - quantize_tensor_per_tensor_affine_stub( + // Temporary solution to pack the tensor if dtype is torch.quint4x2 + // Can move this into the fbgemm::Quantize op. + if (qtensor.scalar_type() == at::ScalarType::QUInt4x2) { + quantize_tensor_per_tensor_affine_sub_byte_stub( + rtensor.device().type(), rtensor, qtensor, scale, zero_point); + } + else { + quantize_tensor_per_tensor_affine_stub( rtensor.device().type(), rtensor, qtensor, scale, zero_point); + } return qtensor; } @@ -163,7 +174,7 @@ Tensor quantize_tensor_per_channel_float_qparams( checkSameDevice(fn_name, rtensor, qtensor); checkSameSize(fn_name, qtensor, rtensor); - AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() { + AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() { checkQuantizedTensor(fn_name, qtensor); }); @@ -195,13 +206,18 @@ Tensor dequantize_tensor_per_tensor_affine( checkSameDevice(fn_name, rtensor, qtensor); checkSameSize(fn_name, qtensor, rtensor); - AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() { + AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() { checkQuantizedTensor(fn_name, qtensor); checkZeroPoint(fn_name, zero_point); }); - dequantize_tensor_per_tensor_affine_stub( - qtensor.device().type(), qtensor, rtensor, scale, zero_point); + if (qtensor.scalar_type() == at::ScalarType::QUInt4x2) { + dequantize_tensor_per_tensor_affine_sub_byte_stub( + qtensor.device().type(), qtensor, rtensor, scale, zero_point); + } else { + dequantize_tensor_per_tensor_affine_stub( + qtensor.device().type(), qtensor, rtensor, scale, zero_point); + } return rtensor; } @@ -253,7 +269,7 @@ Tensor dequantize_tensor_per_channel_float_qparams( checkSameDevice(fn_name, rtensor, qtensor); checkSameSize(fn_name, qtensor, rtensor); - AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() { + AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() { checkQuantizedTensor(fn_name, qtensor); }); @@ -359,7 +375,8 @@ uint8_t quantize_val_arm( const float value) { const int32_t qmin = std::numeric_limits::min(); const int32_t qmax = std::numeric_limits::max(); - auto r = zero_point + static_cast(Round(value / scale)); + float inv_scale = 1.0f / scale; + auto r = zero_point + static_cast(Round(value * inv_scale)); r = std::max(r, qmin); r = std::min(r, qmax); return static_cast(r); @@ -379,7 +396,7 @@ void quantize_vec( } template -CAFFE2_API float dequantize_val(double scale, int64_t zero_point, T value) { +TORCH_API float dequantize_val(double scale, int64_t zero_point, T value) { // We need to convert the qint8 value to float to ensure the subtraction // subexpression returns a float return (static_cast(value.val_) - zero_point) * scale; @@ -394,17 +411,13 @@ CAFFE2_API float dequantize_val(double scale, int64_t zero_point, T value) { * Note: For the case of embedding quantization we will set zero_point * to (-Xmin/scale), where Xmin is the min value in input tensor row. */ -template -T quantize_val_float_qparams(float scale, float zero_point, float value) { - int64_t qvalue; +int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax) { + int qvalue; - // TODO make sure qmax and qmin for dtypes other than int8, uint8 is correctly defined. - constexpr int64_t qmin = std::numeric_limits::min(); - constexpr int64_t qmax = std::numeric_limits::max(); float inv_scale = scale == 0 ? 1.0f : 1.0f / scale; qvalue = lrintf(value * inv_scale + zero_point); qvalue = std::max(qmin, std::min(qvalue, qmax)); - return static_cast(qvalue); + return qvalue; } template @@ -428,74 +441,68 @@ DST_T requantize_from_int(double multiplier, int64_t zero_point, int64_t src) { std::min(std::max(quantize_down, min), max)); } -template CAFFE2_API qint8 +template TORCH_API qint8 quantize_val(double scale, int64_t zero_point, float value); -template CAFFE2_API quint8 +template TORCH_API quint8 quantize_val(double scale, int64_t zero_point, float value); -template CAFFE2_API qint32 +template TORCH_API qint32 quantize_val(double scale, int64_t zero_point, float value); -template CAFFE2_API void quantize_vec( +template TORCH_API void quantize_vec( double scale, int64_t zero_point, const float* src, c10::qint8* dst, size_t count); -template CAFFE2_API void quantize_vec( +template TORCH_API void quantize_vec( double scale, int64_t zero_point, const float* src, c10::quint8* dst, size_t count); -template CAFFE2_API void quantize_vec( +template TORCH_API void quantize_vec( double scale, int64_t zero_point, const float* src, c10::qint32* dst, size_t count); -template CAFFE2_API float dequantize_val( +template TORCH_API float dequantize_val( double scale, int64_t zero_point, qint8 value); -template CAFFE2_API float dequantize_val( +template TORCH_API float dequantize_val( double scale, int64_t zero_point, quint8 value); -template CAFFE2_API float dequantize_val( +template TORCH_API float dequantize_val( double scale, int64_t zero_point, qint32 value); -template CAFFE2_API qint8 +template TORCH_API qint8 requantize_val(double, int64_t, double, int64_t, qint8); -template CAFFE2_API quint8 +template TORCH_API quint8 requantize_val(double, int64_t, double, int64_t, qint8); -template CAFFE2_API qint32 +template TORCH_API qint32 requantize_val(double, int64_t, double, int64_t, qint8); -template CAFFE2_API qint8 +template TORCH_API qint8 requantize_val(double, int64_t, double, int64_t, quint8); -template CAFFE2_API quint8 +template TORCH_API quint8 requantize_val(double, int64_t, double, int64_t, quint8); -template CAFFE2_API qint32 +template TORCH_API qint32 requantize_val(double, int64_t, double, int64_t, quint8); -template CAFFE2_API qint8 +template TORCH_API qint8 requantize_val(double, int64_t, double, int64_t, qint32); -template CAFFE2_API quint8 +template TORCH_API quint8 requantize_val(double, int64_t, double, int64_t, qint32); -template CAFFE2_API qint32 +template TORCH_API qint32 requantize_val(double, int64_t, double, int64_t, qint32); -template CAFFE2_API qint8 requantize_from_int(double, int64_t, int64_t); -template CAFFE2_API quint8 +template TORCH_API qint8 requantize_from_int(double, int64_t, int64_t); +template TORCH_API quint8 requantize_from_int(double, int64_t, int64_t); -template CAFFE2_API qint32 +template TORCH_API qint32 requantize_from_int(double, int64_t, int64_t); -template CAFFE2_API qint8 -quantize_val_float_qparams(float scale, float zero_point, float value); -template CAFFE2_API quint8 -quantize_val_float_qparams(float scale, float zero_point, float value); -template CAFFE2_API qint32 -quantize_val_float_qparams(float scale, float zero_point, float value); } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/affine_quantizer.h b/aten/src/ATen/native/quantized/affine_quantizer.h index 862a36f5f61a1..d583106116e69 100644 --- a/aten/src/ATen/native/quantized/affine_quantizer.h +++ b/aten/src/ATen/native/quantized/affine_quantizer.h @@ -77,6 +77,12 @@ using dequantize_tensor_per_channel_float_qparams_fn = void (*)( Tensor zero_points, int64_t axis); +using quantize_tensor_per_tensor_affine_sub_byte_fn = + void (*)(Tensor rtensor, Tensor qtensor, float scale, float zero_point); + +using dequantize_tensor_per_tensor_affine_sub_byte_fn = + void (*)(Tensor qtensor, Tensor rtensor, float scale, float zero_point); + DECLARE_DISPATCH( quantize_tensor_per_tensor_affine_fn, quantize_tensor_per_tensor_affine_stub); @@ -97,10 +103,17 @@ DECLARE_DISPATCH( dequantize_tensor_per_channel_float_qparams_fn, dequantize_tensor_per_channel_float_qparams_stub); +DECLARE_DISPATCH( + quantize_tensor_per_tensor_affine_sub_byte_fn, + quantize_tensor_per_tensor_affine_sub_byte_stub); + +DECLARE_DISPATCH( + dequantize_tensor_per_tensor_affine_sub_byte_fn, + dequantize_tensor_per_tensor_affine_sub_byte_stub); // Quantize a float value into a uint value given scale and zero_point template -CAFFE2_API T quantize_val(double scale, int64_t zero_point, float value); +TORCH_API T quantize_val(double scale, int64_t zero_point, float value); // TODO combine this with quantize_val once the numerics for ARM are aligned // with it uint8_t quantize_val_arm( @@ -115,38 +128,37 @@ void quantize_vec( T* dst, size_t count = 8); template -CAFFE2_API Tensor quantize_tensor( +TORCH_API Tensor quantize_tensor( Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point); template -CAFFE2_API float dequantize_val(double scale, int64_t zero_point, T value); +TORCH_API float dequantize_val(double scale, int64_t zero_point, T value); template -CAFFE2_API float dequantize_vec( +TORCH_API float dequantize_vec( double scale, int64_t zero_point, const T* src, float* dst, size_t count = 8); template -CAFFE2_API Tensor dequantize_tensor( +TORCH_API Tensor dequantize_tensor( Tensor qtensor, Tensor rtensor, double scale, int64_t zero_point); template -CAFFE2_API DST_T requantize_val(double, int64_t, double, int64_t, SRC_T src); +TORCH_API DST_T requantize_val(double, int64_t, double, int64_t, SRC_T src); // Given a multiplier and a zero_point, requantize int32_t computed values back // to quantized values. See comment above // make_per_tensor_affine_quantizer function for the usage of int64_t template -CAFFE2_API DST_T +TORCH_API DST_T requantize_from_int(double multiplier, int64_t zero_point, int64_t src); -template -CAFFE2_API T quantize_val_float_qparams(float scale, float zero_point, float value); +int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax); } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h b/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h index 921b585b89a76..d2e7500bf3021 100644 --- a/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h +++ b/aten/src/ATen/native/quantized/cpu/embedding_packed_params.h @@ -7,8 +7,18 @@ struct EmbeddingPackedParamsBase : public torch::jit::CustomClassHolder { virtual at::Tensor embeddingbag_byte( const at::Tensor& indices, const c10::optional& offsets, - bool sparse, + bool pruned_weights, const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset, + bool is_embedding_op) = 0; + + virtual at::Tensor embeddingbag_4bit( + const at::Tensor& indices, + const c10::optional& offsets, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, bool include_last_offset) = 0; virtual at::Tensor unpack() = 0; diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp index aa3b298e4b2b0..f25a3019347c4 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp @@ -60,33 +60,34 @@ void CopyToChannelsLast3dTensor( } } -} // namespace - -template <> -fbgemm::conv_param_t<2> MakeFbgemmConvParam<2>( - int N, - int C, - int M, - const std::vector& image_shape, - int groups, - const std::vector& kernels, - const std::vector& strides, - const std::vector& pads, - const std::vector& dilations) { - return fbgemm::conv_param_t<2>( - N, // batch size - C, // input channels - M, // output channels - {image_shape[0], image_shape[1]}, // feature map size - groups, // groups - {kernels[0], kernels[1]}, // kernels - {strides[0], strides[1]}, // strides - {pads[0], pads[1], pads[0], pads[1]}, // paddings - {dilations[0], dilations[1]}); // dilations +template +void CopyICFirst3dTensorToChannelsLast3dTensor( + int64_t G, + int64_t IC_G, + int64_t OC_G, + int64_t D, + int64_t H, + int64_t W, + const T* src, + T* dst) { + // IC OC/G THW -> G OC/G THW IC/G + const int64_t inner_size = D * H * W; + for (int64_t i = 0; i < G * OC_G; ++i) { + for (int64_t j = 0; j < inner_size; ++j) { + for (int64_t ic = 0; ic < IC_G; ++ic) { + int g = i / OC_G; + int oc = i % OC_G; + dst[(i * inner_size + j) * IC_G + ic] = + src[((g * IC_G + ic) * OC_G + oc) * inner_size + j]; + } + } + } } -template <> -fbgemm::conv_param_t<3> MakeFbgemmConvParam<3>( +} // namespace + +template +fbgemm::conv_param_t MakeFbgemmConvParam( int N, int C, int M, @@ -95,17 +96,43 @@ fbgemm::conv_param_t<3> MakeFbgemmConvParam<3>( const std::vector& kernels, const std::vector& strides, const std::vector& pads, - const std::vector& dilations) { - return fbgemm::conv_param_t<3>( + const std::vector& dilations, + const std::vector& output_padding, + bool transposed) { + std::array image_shape_; + std::array kernels_; + std::array strides_; + std::array pads_; + std::array dilations_; + std::array output_padding_; + std::move(image_shape.begin(), image_shape.begin() + image_shape.size(), image_shape_.begin()); + std::move( + kernels.begin(), kernels.begin() + kernels.size(), kernels_.begin()); + std::move( + strides.begin(), strides.begin() + strides.size(), strides_.begin()); + std::move( + dilations.begin(), + dilations.begin() + dilations.size(), + dilations_.begin()); + std::move( + output_padding.begin(), + output_padding.begin() + output_padding.size(), + output_padding_.begin()); + std::copy(pads.begin(), pads.begin() + pads.size(), pads_.begin()); + std::move(pads.begin(), pads.begin() + pads.size(), pads_.begin() + pads.size()); + + return fbgemm::conv_param_t( N, // batch size C, // input channels M, // output channels - {image_shape[0], image_shape[1], image_shape[2]}, // feature map size + image_shape_, // feature map size groups, // groups - {kernels[0], kernels[1], kernels[2]}, // kernels - {strides[0], strides[1], strides[2]}, // strides - {pads[0], pads[1], pads[2], pads[0], pads[1], pads[2]}, // paddings - {dilations[0], dilations[1], dilations[2]}); // dilations + kernels_, // kernels + strides_, // strides + pads_, // paddings + dilations_, // dilations + output_padding_, // output paddings for conv transpose + transposed); } Tensor MakeStridedQTensorCPU( @@ -206,14 +233,132 @@ Tensor ConvertToChannelsLast3dTensor(const Tensor& src) { return dst; } +template <> +Tensor TransposeConvTensorUnpackConversion<2>(const Tensor& src, int groups) { + // OC IC/G HW -> IC OC/G HW logically + auto oc_g_ic_g_hw_tensors = src.chunk(groups); + auto fused_tensor = at::cat(oc_g_ic_g_hw_tensors, 1); + set_quantizer_(fused_tensor, src.quantizer()); + return fused_tensor.permute({1, 0, 2, 3}); +} + +template fbgemm::conv_param_t<1> MakeFbgemmConvParam<1>( + int N, + int C, + int M, + const std::vector& image_shape, + int groups, + const std::vector& kernels, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilations, + const std::vector& output_padding, + bool transposed); + +template fbgemm::conv_param_t<2> MakeFbgemmConvParam<2>( + int N, + int C, + int M, + const std::vector& image_shape, + int groups, + const std::vector& kernels, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilations, + const std::vector& output_padding, + bool transposed); + +template fbgemm::conv_param_t<3> MakeFbgemmConvParam<3>( + int N, + int C, + int M, + const std::vector& image_shape, + int groups, + const std::vector& kernels, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilations, + const std::vector& output_padding, + bool transposed); +template <> +Tensor TransposeConvTensorUnpackConversion<3>(const Tensor& src, int groups) { + // OC IC/G DHW -> IC OC/G DHW logically + auto oc_g_ic_g_hw_tensors = src.chunk(groups); + auto fused_tensor = at::cat(oc_g_ic_g_hw_tensors, 1); + set_quantizer_(fused_tensor, src.quantizer()); + return fused_tensor.permute({1, 0, 2, 3, 4}); +} + +template <> +Tensor ConvertConvWeightsToChannelLastTensor<2>( + const at::Tensor& src, + int groups, + bool transpose) { + return transpose ? + // 2D conv transpose weight transform + // IC OC/G KH KW -> G OC/G KH KW IC/G + [&]() { + auto ic_g_oc_g_hw_tensors = src.chunk(groups); + for (auto& tensor : ic_g_oc_g_hw_tensors) { + tensor = tensor.unsqueeze(0); + } + auto fused_tensor = at::cat(ic_g_oc_g_hw_tensors); + set_quantizer_(fused_tensor, src.quantizer()); + return fused_tensor.permute({0, 2, 3, 4, 1}) + .contiguous(c10::MemoryFormat::Contiguous); + }() + // 2d conv weight transform + : src.contiguous(c10::MemoryFormat::ChannelsLast); +} + +template <> +Tensor ConvertConvWeightsToChannelLastTensor<3>( + const at::Tensor& src, + int groups, + bool transpose) { + if (!transpose) { + return ConvertToChannelsLast3dTensor(src); + } else { + TORCH_CHECK(src.dim() == 5); + Tensor dst; + const int64_t N = src.size(0); + const int64_t IC_G = N / groups; + const int64_t OC_G = src.size(1); + const int64_t D = src.size(2); + const int64_t H = src.size(3); + const int64_t W = src.size(4); + dst = MakeStridedQTensorCPU( + {groups * OC_G, IC_G, D, H, W}, + {D * H * W * IC_G, 1, H * W * IC_G, W * IC_G, IC_G}, + src.options(), + src.quantizer()); + AT_DISPATCH_QINT_TYPES( + src.scalar_type(), "CopyICFirst3dTensorToChannelsLast3dTensor", [&]() { + const Tensor src_contig = src.contiguous(); + CopyICFirst3dTensorToChannelsLast3dTensor( + groups, + IC_G, + OC_G, + D, + H, + W, + src_contig.data_ptr(), + dst.data_ptr()); + }); + return dst; + } +} + } // namespace fbgemm_utils } // namespace native } // namespace at + #endif // USE_FBGEMM -template -CAFFE2_API torch::class_> register_conv_params() { + template + TORCH_API torch::class_> + register_conv_params() { static auto register_conv_params = torch::class_>( "quantized", "Conv" + c10::to_string(kSpatialDim) + "dPackedParamsBase") @@ -252,9 +397,9 @@ CAFFE2_API torch::class_> register_conv_params } template -CAFFE2_API torch::class_> register_conv_params<2>(); +TORCH_API torch::class_> register_conv_params<2>(); template -CAFFE2_API torch::class_> register_conv_params<3>(); +TORCH_API torch::class_> register_conv_params<3>(); torch::class_ register_linear_params() { using SerializationType = std::tuple>; diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h index 9833980223530..916bf03fc098b 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h @@ -20,7 +20,7 @@ // of the A rows. The column offsets are needed for the asymmetric quantization // (affine quantization) of input matrix. // Note that in JIT mode we can think of a way to fuse col_offsets with bias. -struct CAFFE2_API PackedLinearWeight : public LinearPackedParamsBase { +struct TORCH_API PackedLinearWeight : public LinearPackedParamsBase { PackedLinearWeight( std::unique_ptr> w, c10::optional bias, @@ -74,7 +74,7 @@ struct CAFFE2_API PackedLinearWeight : public LinearPackedParamsBase { at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range=false); }; -struct CAFFE2_API PackedLinearWeightFp16 : public LinearPackedParamsBase { +struct TORCH_API PackedLinearWeightFp16 : public LinearPackedParamsBase { PackedLinearWeightFp16( std::unique_ptr w, c10::optional bias) @@ -117,7 +117,7 @@ struct CAFFE2_API PackedLinearWeightFp16 : public LinearPackedParamsBase { }; template -struct CAFFE2_API PackedConvWeight : public ConvPackedParamsBase { +struct TORCH_API PackedConvWeight : public ConvPackedParamsBase { PackedConvWeight( std::unique_ptr> w, c10::optional bias, @@ -257,7 +257,9 @@ fbgemm::conv_param_t MakeFbgemmConvParam( const std::vector& kernels, const std::vector& strides, const std::vector& pads, - const std::vector& dilations); + const std::vector& dilations, + const std::vector& output_padding = std::vector(kSpatialDim, 0), + bool transposed = false); // TODO: Remove functions below when ChannelsLast3d is ready. Tensor MakeStridedQTensorCPU( @@ -288,13 +290,23 @@ Tensor MakeEmptyPerChannelAffineQuantizedChannelsLast3dTensor( Tensor ConvertToChannelsLast3dTensor(const Tensor& src); +template +Tensor TransposeConvTensorUnpackConversion( + const Tensor& src, + int groups); + +template +Tensor ConvertConvWeightsToChannelLastTensor( + const at::Tensor& src, + int groups, + bool transpose); } // namespace fbgemm_utils } // namespace native } // namespace at #endif // USE_FBGEMM -struct CAFFE2_API PackedEmbeddingBagWeight : public EmbeddingPackedParamsBase { +struct TORCH_API PackedEmbeddingBagWeight : public EmbeddingPackedParamsBase { PackedEmbeddingBagWeight( at::Tensor packed_w, std::vector w_scale, @@ -302,12 +314,16 @@ struct CAFFE2_API PackedEmbeddingBagWeight : public EmbeddingPackedParamsBase { int64_t bit_rate, c10::QScheme q_scheme, int64_t version) - : packed_w(std::move(packed_w)), - w_scale(std::move(w_scale)), - w_zp(std::move(w_zp)), - bit_rate_(bit_rate), - q_scheme(q_scheme), - version_(version) {} + : packed_w(std::move(packed_w)), + w_scale(std::move(w_scale)), + w_zp(std::move(w_zp)), + bit_rate_(bit_rate), + q_scheme(q_scheme), + version_(version) { + if (!packed_w.is_contiguous()) { + packed_w = packed_w.contiguous(); + } + } at::Tensor packed_w; std::vector w_scale; @@ -330,7 +346,17 @@ struct CAFFE2_API PackedEmbeddingBagWeight : public EmbeddingPackedParamsBase { at::Tensor embeddingbag_byte( const at::Tensor& indices, const c10::optional& offsets, - bool sparse, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset, + bool is_embedding_op) override; + + at::Tensor embeddingbag_4bit( + const at::Tensor& indices, + const c10::optional& offsets, + bool pruned_weights, const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, bool include_last_offset) override; }; diff --git a/aten/src/ATen/native/quantized/cpu/int_repr_quant.cpp b/aten/src/ATen/native/quantized/cpu/int_repr_quant.cpp index 65036302e6ef0..29e7a9b259bb3 100644 --- a/aten/src/ATen/native/quantized/cpu/int_repr_quant.cpp +++ b/aten/src/ATen/native/quantized/cpu/int_repr_quant.cpp @@ -10,17 +10,29 @@ namespace native { // format of the output the same as input Tensor int_repr_quantized_cpu(const Tensor& self) { Tensor dst; - AT_DISPATCH_QINT_TYPES(self.scalar_type(), "int_repr", [&]() { - dst = at::empty( - self.sizes(), - self.options().dtype(UNDERLYING_TYPE), - self.suggest_memory_format()); - auto iter = TensorIteratorConfig() - .check_all_same_dtype(false) - .add_output(dst) - .add_input(self) - .build(); - cpu_kernel(iter, [](scalar_t value) -> underlying_t { return value.val_; }); + AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(self.scalar_type(), "int_repr", [&]() { + if (bit_width == 4) { + int64_t out_size = std::ceil(self.numel() * 0.5); + dst = at::empty( + {out_size}, + self.options().dtype(UNDERLYING_TYPE), + self.suggest_memory_format()); + const underlying_t* qdata = reinterpret_cast(self.data_ptr()); + for (int64_t i = 0; i < dst.numel(); ++i) { + dst[i] = static_cast(qdata[i]); + } + } else { + dst = at::empty( + self.sizes(), + self.options().dtype(UNDERLYING_TYPE), + self.suggest_memory_format()); + auto iter = TensorIteratorConfig() + .check_all_same_dtype(false) + .add_output(dst) + .add_input(self) + .build(); + cpu_kernel(iter, [](scalar_t value) -> underlying_t { return value.val_; }); + } }); return dst; } diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp index ddde74b61d520..8137049a75c87 100644 --- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp +++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp @@ -486,7 +486,8 @@ static void leaky_qrelu_out_kernel(Tensor& out, const Tensor& qx, }); } -void qsigmoid_kernel(const Tensor& qx, Tensor& qy) { +void qsigmoid_kernel( + const Tensor& qx, Tensor& qy, double output_scale, int64_t output_zero_point ) { int64_t zero_point = qx.q_zero_point(); float scale = qx.q_scale(); auto scale_vec = Vec256(scale); @@ -494,19 +495,6 @@ void qsigmoid_kernel(const Tensor& qx, Tensor& qy) { auto scale_neg_zp_premul_vec = scale_vec * zero_point_vec.neg(); AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qsigmoid", [&]() { - // Naive implemenentation: uses dequantize/execute/quantize routine - // - Output scale is set to 1.0 / 2^(BIT_NUM) - // - For signed types output zero point is set to 0 - // - For unsigned types output zero point is set to (qmax + qmin) / 2.0 - // See https://stackoverflow.com/a/34448562/3606192 for potential - // optimizations - float output_scale = 0.00390625; // 1.0 / 2^8 - int64_t output_zero_point = 0; - if (SCALAR_TYPE == at::kQInt32) { - output_scale = 2.3283064365386963e-10; // 1.0 / 2^32 - } else if (SCALAR_TYPE == at::kQInt8) { - output_zero_point = -128; - } float inv_output_scale = 1.0 / output_scale; qy = at::_empty_affine_quantized( @@ -641,6 +629,56 @@ void qclamp_kernel( }); } +void qclamp_min_kernel(const Tensor& qx, Scalar min_scalar, Tensor& qy) { + AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qclamp", [&]() { + qy = at::_empty_affine_quantized( + qx.sizes(), + at::device(kCPU) + .dtype(SCALAR_TYPE) + .memory_format(qx.suggest_memory_format()), + qx.q_scale(), + qx.q_zero_point(), + c10::nullopt); + using Vec = Vec256; + auto iter = TensorIterator::unary_op(qy, qx); + auto min = min_scalar.to(); + scalar_t min_q = at::native::quantize_val( + qx.q_scale(), qx.q_zero_point(), min); + auto min_vec = Vec(min_q); + cpu_kernel_vec( + iter, + [&](scalar_t value) -> scalar_t { + return scalar_t(std::max(value.val_, min_q.val_)); + }, + [&](Vec val) -> Vec { return val.maximum(min_vec); }); + }); +} + +void qclamp_max_kernel(const Tensor& qx, Scalar max_scalar, Tensor& qy) { + AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qclamp", [&]() { + qy = at::_empty_affine_quantized( + qx.sizes(), + at::device(kCPU) + .dtype(SCALAR_TYPE) + .memory_format(qx.suggest_memory_format()), + qx.q_scale(), + qx.q_zero_point(), + c10::nullopt); + using Vec = Vec256; + auto iter = TensorIterator::unary_op(qy, qx); + auto max = max_scalar.to(); + scalar_t max_q = at::native::quantize_val( + qx.q_scale(), qx.q_zero_point(), max); + auto max_vec = Vec(max_q); + cpu_kernel_vec( + iter, + [&](scalar_t value) -> scalar_t { + return scalar_t(std::min(value.val_, max_q.val_)); + }, + [&](Vec val) -> Vec { return val.minimum(max_vec); }); + }); +} + void qthreshold_kernel( // TODO: For future tasks, since output quantization parameters are set equal to // the input ones, it might make sense to implement this completely in the @@ -2530,7 +2568,9 @@ void dequantize_tensor_per_tensor_affine_cpu( #endif // USE_FBGEMM // TODO: add fbgemm for per channel -void quantize_tensor_per_channel_affine_cpu( +// Generic template defaults to naive quantize implementation +template +void quantize_tensor_per_channel_impl( Tensor rtensor, Tensor qtensor, Tensor scales, @@ -2542,47 +2582,253 @@ void quantize_tensor_per_channel_affine_cpu( // Since current implemntation on channels_last format does not // cover per channel quant with arbitrary axis value, it is better // to check and fail. - TORCH_CHECK(rtensor.is_contiguous() || (axis <=1), + int64_t batches = size_to_dim_(axis, rtensor.sizes()); + int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes()); + int64_t channels = rtensor.size(axis); + auto scales_data = scales.data_ptr(); + auto zero_points_data = zero_points.data_ptr(); + const float* in = rtensor.data_ptr(); + auto out = qtensor.data_ptr(); + if (axis == 1 && + (rtensor.is_contiguous(MemoryFormat::ChannelsLast) || + rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) { + // This code handles per channel quant when axis = 1 and + // channels_last contig. + // If axis = 0 and channels_last contig, implementation for channels + // first (NCHW) works. + for (auto b = 0; b < batches; ++b) { + for (auto e = 0; e < elements_per_channel; ++e) { + for (auto c = 0; c < channels; ++c) { + auto i = b * channels * elements_per_channel + e * channels + c; + out[i] = at::native::quantize_val( + scales_data[c], zero_points_data[c], in[i]); + } + } + } + } else { + for (auto b = 0; b < batches; ++b) { + for (auto c = 0; c < channels; ++c) { + for (auto e = 0; e < elements_per_channel; ++e) { + auto i = b * channels * elements_per_channel + + c * elements_per_channel + e; + out[i] = at::native::quantize_val( + scales_data[c], zero_points_data[c], in[i]); + } + } + } + } +} + +#if defined(__ARM_NEON__) || defined(__aarch64__) +// Specialized implementation from caffe2::Int8Quantize. +// There may be slight accuracy difference between this and implementation of +// quantize_val +// TODO Update quantize_tensor_per_channel_impl implementation to follow +// quantize_val, i.e. f = Round(value/scale + zero_point) +// TODO Make quantize_tensor_per_channel_impl work for other datatypes too +// (int8, int32). +template <> +void quantize_tensor_per_channel_impl( + Tensor rtensor, + Tensor qtensor, + Tensor scales, + Tensor zero_points, + int64_t axis) { + int64_t batches = size_to_dim_(axis, rtensor.sizes()); + int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes()); + int64_t channels = rtensor.size(axis); + auto scales_data = scales.data_ptr(); + auto zero_points_data = zero_points.data_ptr(); + const float* in = rtensor.data_ptr(); + auto out = (uint8_t*)qtensor.data_ptr(); +#if defined(__ARM_NEON__) + // magic float and magic int to take care of rounding + // int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000 + // Some detail: + // 12582912.0f is 2**23 + 2**22. The trick is based on the fact that when you + // add a small number to a large number, the result rounds to the precision of + // the least significant bit of the large number. For IEEE-754 + // single-precision number mantissa has 23 bits, and adding 2**23 would cause + // rounding to the nearest even integer. The we cast to int and subtract the + // same number (0x4B400000 is the integer representation of 12582912.0f) to + // get only the mantissa. This works if -2**22 < x < 2**22, but preserves the + // sign for negative numbers. + const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f); + // Copy reciprocal of scales (double) into float array + // Copy zero_points with magic int (int64_t) into int32_t array + std::vector inv_scales(channels); + std::vector zero_points_int32t(channels); + for (int i = 0; i < channels; ++i) { + inv_scales[i] = 1.0f / (float)scales_data[i]; + zero_points_int32t[i] = (int32_t)(uint32_t)zero_points_data[i] - 0x4B400000; + } + if (axis == 1 && + (rtensor.is_contiguous(MemoryFormat::ChannelsLast) || + rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) { + // This code handles per channel quant when axis = 1 and + // channels_last contig. + // If axis = 0 and channels_last contig, implementation for channels + // first (NCHW) works. + for (uint32_t b = 0; b < batches; ++b) { + for (uint32_t e = 0; e < elements_per_channel; ++e) { + uint32_t c = 0; + while (c + 8 < channels) { + const int32x4_t voffset0123 = vld1q_s32(&zero_points_int32t[c]); + const float32x4_t vinv_scale0123 = vld1q_f32(&inv_scales[c]); + c += 4; + const int32x4_t voffset4567 = vld1q_s32(&zero_points_int32t[c]); + const float32x4_t vinv_scale4567 = vld1q_f32(&inv_scales[c]); + c += 4; + const float32x4_t vin0123 = vld1q_f32(in); + in += 4; + const float32x4_t vin4567 = vld1q_f32(in); + in += 4; + const int32x4_t vraw0123 = vaddq_s32( + voffset0123, + vreinterpretq_s32_f32( + vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale0123)))); + const int32x4_t vraw4567 = vaddq_s32( + voffset4567, + vreinterpretq_s32_f32( + vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale4567)))); + const int16x8_t vraw01234567 = + vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567)); + const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567); + vst1_u8(out, vout01234567); + out += 8; + } + for (; c < channels; ++c) { + (*out++) = + at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++)); + } + } + } + } else { + for (uint32_t b = 0; b < batches; ++b) { + for (uint32_t c = 0; c < channels; ++c) { + uint32_t e = 0; + const int32x4_t voffset = vdupq_n_s32(zero_points_int32t[c]); + const float32x4_t vinv_scale = vdupq_n_f32(inv_scales[c]); + for (; e + 8 < elements_per_channel; e += 8) { + const float32x4_t vin0123 = vld1q_f32(in); + in += 4; + const float32x4_t vin4567 = vld1q_f32(in); + in += 4; + const int32x4_t vraw0123 = vaddq_s32( + voffset, + vreinterpretq_s32_f32( + vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale)))); + const int32x4_t vraw4567 = vaddq_s32( + voffset, + vreinterpretq_s32_f32( + vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale)))); + const int16x8_t vraw01234567 = + vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567)); + const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567); + vst1_u8(out, vout01234567); + out += 8; + } + for (; e < elements_per_channel; ++e) { + (*out++) = + at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++)); + } + } + } + } +#else // defined(__ARM_NEON__) + // Copy scales (double) into float array + // Copy zero_points (int64_t) into int16_t array + std::vector inv_scales(channels); + std::vector zero_points_int16t(channels); + for (int i = 0; i < channels; ++i) { + inv_scales[i] = 1.0f / (float)scales_data[i]; + zero_points_int16t[i] = (int16_t)(uint16_t)zero_points_data[i]; + } + if (axis == 1 && + (rtensor.is_contiguous(MemoryFormat::ChannelsLast) || + rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) { + // This code handles per channel quant when axis = 1 and + // channels_last contig. + // If axis = 0 and channels_last contig, implementation for channels + // first (NCHW) works. + for (uint32_t b = 0; b < batches; ++b) { + for (uint32_t e = 0; e < elements_per_channel; ++e) { + uint32_t c = 0; + while (c + 8 < channels) { + const int16x8_t vzero_point = vld1q_s16(&zero_points_int16t[c]); + const float32x4_t vinv_scale0123 = vld1q_f32(&inv_scales[c]); + c += 4; + const float32x4_t vinv_scale4567 = vld1q_f32(&inv_scales[c]); + c += 4; + const float32x4_t vin0123 = vld1q_f32(in); + in += 4; + const float32x4_t vin4567 = vld1q_f32(in); + in += 4; + const int32x4_t v0123_rounded = + vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale0123)); + const int32x4_t v4567_rounded = + vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale4567)); + const int16x8_t v01234567_packed = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded), + vzero_point); + const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed); + vst1_u8(out, vout01234567); + out += 8; + } + for (; c < channels; ++c) { + (*out++) = + at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++)); + } + } + } + } else { + for (uint32_t b = 0; b < batches; ++b) { + for (uint32_t c = 0; c < channels; ++c) { + uint32_t e = 0; + const int16x8_t vzero_point = vdupq_n_s16(zero_points_int16t[c]); + const float32x4_t vinv_scale = vdupq_n_f32(inv_scales[c]); + for (; e + 8 < elements_per_channel; e += 8) { + const float32x4_t vin0123 = vld1q_f32(in); + in += 4; + const float32x4_t vin4567 = vld1q_f32(in); + in += 4; + const int32x4_t v0123_rounded = + vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale)); + const int32x4_t v4567_rounded = + vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale)); + const int16x8_t v01234567_packed = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded), + vzero_point); + const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed); + vst1_u8(out, vout01234567); + out += 8; + } + for (; e < elements_per_channel; ++e) { + (*out++) = + at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++)); + } + } + } + } +#endif // defined(__ARM_NEON__) +} +#endif // defined(__ARM_NEON__) || defined(__aarch64__) + +void quantize_tensor_per_channel_affine_cpu( + Tensor rtensor, + Tensor qtensor, + Tensor scales, + Tensor zero_points, + int64_t axis) { + TORCH_CHECK( + rtensor.is_contiguous() || (axis <= 1), "If tensor is channels_last contig then per channel quantization " "is supported only for axis = 0 or 1."); AT_DISPATCH_QINT_TYPES( qtensor.scalar_type(), "quantize_tensor_per_channel_affine_cpu", [&]() { - int64_t batches = size_to_dim_(axis, rtensor.sizes()); - int64_t elements_per_channel = - size_from_dim_(axis + 1, rtensor.sizes()); - int64_t channel = rtensor.size(axis); - auto scales_data = scales.data_ptr(); - auto zero_points_data = zero_points.data_ptr(); check_tensor_memory_format(rtensor, qtensor); - const float* rdata = rtensor.data_ptr(); - auto qdata = qtensor.data_ptr(); - if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) || - rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) { - // This code handles per channel quant when axis = 1 and - // channels_last contig. - // If axis = 0 and channels_last contig, implementation - // for channels first (NCHW) works. - for (auto b = 0; b < batches; ++b) { - for (auto e = 0; e < elements_per_channel; ++e) { - for (auto c = 0; c < channel; ++c) { - auto i = b * channel * elements_per_channel + e * channel + c; - qdata[i] = quantize_val( - scales_data[c], zero_points_data[c], rdata[i]); - } - } - } - } else { - for (auto b = 0; b < batches; ++b) { - for (auto c = 0; c < channel; ++c) { - for (auto e = 0; e < elements_per_channel; ++e) { - auto i = b * channel * elements_per_channel + - c * elements_per_channel + e; - qdata[i] = quantize_val( - scales_data[c], zero_points_data[c], rdata[i]); - } - } - } - } + quantize_tensor_per_channel_impl( + rtensor, qtensor, scales, zero_points, axis); }); } @@ -2592,7 +2838,8 @@ void dequantize_per_channel_affine_kernel( Tensor rtensor, Tensor scales, Tensor zero_points, - int64_t axis) { + int64_t axis, + int bit_width=8) { // For contiguous tensors, e.g. NCHW, arbitrary axis can be used. // For channels_last/3d however axis == 0 or 1. @@ -2611,6 +2858,7 @@ void dequantize_per_channel_affine_kernel( check_tensor_memory_format(qtensor, rtensor); const auto* qd = qtensor.data_ptr(); float* rd = rtensor.data_ptr(); + const auto elem_per_byte = 8 / bit_width; if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) || rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) { for (auto b = 0; b < batches; ++b) { @@ -2619,8 +2867,12 @@ void dequantize_per_channel_affine_kernel( auto i = b * channel * elements_per_channel + e * channel + c; // We need to convert the qint8 value to float to ensure the // subtraction subexpression returns a float - rd[i] = (static_cast(qd[i].val_) - zero_points_data[c]) * - scales_data[c]; + auto qvalue = qd[i / elem_per_byte].val_; + if (bit_width < 8) { + qvalue >>= (i % elem_per_byte) * bit_width; + qvalue &= (1 << bit_width) - 1; + } + rd[i] = (static_cast(qvalue) - zero_points_data[c]) * scales_data[c]; } } } @@ -2632,8 +2884,12 @@ void dequantize_per_channel_affine_kernel( c * elements_per_channel + e; // We need to convert the qint8 value to float to ensure the // subtraction subexpression returns a float - rd[i] = (static_cast(qd[i].val_) - zero_points_data[c]) * - scales_data[c]; + auto qvalue = qd[i / elem_per_byte].val_; + if (bit_width < 8) { + qvalue >>= (i % elem_per_byte) * bit_width; + qvalue &= (1 << bit_width) - 1; + } + rd[i] = (static_cast(qvalue) - zero_points_data[c]) * scales_data[c]; } } } @@ -2667,7 +2923,7 @@ void quantize_tensor_per_channel_float_qparams_cpu( TORCH_CHECK(rtensor.is_contiguous() || (axis <=1), "If tensor is channels_last contig then per channel quantization " "is supported only for axis = 0 or 1."); - AT_DISPATCH_QINT_TYPES( + AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES( qtensor.scalar_type(), "quantize_tensor_per_channel_float_qparams_cpu", [&]() { int64_t batches = size_to_dim_(axis, rtensor.sizes()); int64_t elements_per_channel = @@ -2677,15 +2933,22 @@ void quantize_tensor_per_channel_float_qparams_cpu( auto zero_points_data = zero_points.data_ptr(); check_tensor_memory_format(rtensor, qtensor); const float* rdata = rtensor.data_ptr(); - auto qdata = qtensor.data_ptr(); + auto qdata = reinterpret_cast(qtensor.data_ptr()); + const auto elem_per_byte = CHAR_BIT / bit_width; + int qvalue = 0; if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) || rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) { for (auto b = 0; b < batches; ++b) { for (auto e = 0; e < elements_per_channel; ++e) { for (auto c = 0; c < channel; ++c) { auto i = b * channel * elements_per_channel + e * channel + c; - qdata[i] = quantize_val_float_qparams( - scales_data[c], zero_points_data[c], rdata[i]); + qvalue = quantize_val_float_qparams( + scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max); + if (i % elem_per_byte == 0) { + qdata[i / elem_per_byte] = static_cast(qvalue); + } else { + qdata[i / elem_per_byte] |= static_cast(qvalue << ((i % elem_per_byte) * bit_width)); + } } } } @@ -2695,8 +2958,13 @@ void quantize_tensor_per_channel_float_qparams_cpu( for (auto e = 0; e < elements_per_channel; ++e) { auto i = b * channel * elements_per_channel + c * elements_per_channel + e; - qdata[i] = quantize_val_float_qparams( - scales_data[c], zero_points_data[c], rdata[i]); + qvalue = quantize_val_float_qparams( + scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max); + if (i % elem_per_byte == 0) { + qdata[i / elem_per_byte] = static_cast(qvalue); + } else { + qdata[i / elem_per_byte] |= static_cast(qvalue << ((i % elem_per_byte) * bit_width)); + } } } } @@ -2710,12 +2978,66 @@ void dequantize_tensor_per_channel_float_qparams_cpu( Tensor scales, Tensor zero_points, int64_t axis) { - AT_DISPATCH_QINT_TYPES( + AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES( qtensor.scalar_type(), "dequantize_tensor_per_channel_float_qparams_cpu", [&]() { - dequantize_per_channel_affine_kernel(qtensor, rtensor, scales, zero_points, axis); + dequantize_per_channel_affine_kernel(qtensor, rtensor, scales, zero_points, axis, bit_width); }); } +void quantize_tensor_per_tensor_affine_sub_byte_cpu( + Tensor rtensor, + Tensor qtensor, + float scale, + float zero_point) { + // TODO Use fbgemm kernel to pack values + AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES( + qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_sub_byte_cpu", [&]() { + check_tensor_memory_format(rtensor, qtensor); + const float* const rdata = rtensor.data_ptr(); + auto qdata = reinterpret_cast(qtensor.data_ptr()); + auto numel = rtensor.numel(); + const auto elem_per_byte = CHAR_BIT / bit_width; + for (int i = 0; i < numel; ++i) { + float inv_scale = scale == 0 ? 1.0f : 1.0f / scale; + int qvalue = lrintf(std::nearbyint(rdata[i] * inv_scale) + zero_point); + qvalue = std::max(quant_min, std::min(qvalue, quant_max)); + + // We pack sub_byte values and align them to a byte. + // Eg. for 4-bits Index 0 is packed in the lower 4-bits + // and index 1 is packed in the upper 4-bits. + if (i % elem_per_byte == 0) { + qdata[i / elem_per_byte] = static_cast(qvalue); + } else { + qdata[i / elem_per_byte] |= static_cast(qvalue << ((i % elem_per_byte) * bit_width)); + } + } // for numel + }); +} + +void dequantize_tensor_per_tensor_affine_sub_byte_cpu( + Tensor qtensor, + Tensor rtensor, + float scale, + float zero_point) { + // TODO Use fbgemm kernel to pack values + AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES( + qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_sub_byte_cpu", [&]() { + check_tensor_memory_format(rtensor, qtensor); + auto rdata = rtensor.data_ptr(); + const underlying_t* qdata = reinterpret_cast(qtensor.data_ptr()); + auto numel = rtensor.numel(); + const auto elem_per_byte = CHAR_BIT / bit_width; + + for (int i = 0; i < numel; ++i) { + underlying_t qvalue = qdata[i / elem_per_byte]; + qvalue >>= (i % elem_per_byte) * bit_width; + qvalue &= (1 << bit_width) - 1; + rdata[i] = (static_cast(qvalue) - zero_point) * scale; + } + }); + +} + } // namespace REGISTER_DISPATCH(dequantize_tensor_per_channel_affine_stub, @@ -2747,6 +3069,8 @@ REGISTER_DISPATCH(qbatch_norm_stub, &q_batch_norm_kernel); REGISTER_DISPATCH(qcat_nhwc_stub, &qcat_nhwc_kernel); REGISTER_DISPATCH(qcat_relu_nhwc_stub, &qcat_nhwc_kernel); REGISTER_DISPATCH(qclamp_stub, &qclamp_kernel); +REGISTER_DISPATCH(qclamp_min_stub, &qclamp_min_kernel); +REGISTER_DISPATCH(qclamp_max_stub, &qclamp_max_kernel); REGISTER_DISPATCH(qelu_stub, &qelu_kernel); REGISTER_DISPATCH(qhardsigmoid_stub, &qhardsigmoid_kernel); REGISTER_DISPATCH(qhardswish_stub, &qhardswish_kernel); @@ -2773,6 +3097,13 @@ REGISTER_DISPATCH( REGISTER_DISPATCH(quantized_normalize_stub, &quantized_normalize_kernel); REGISTER_DISPATCH(qupsample_bilinear2d_nhwc_stub, &qupsample_bilinear2d_nhwc_kernel); +REGISTER_DISPATCH( + quantize_tensor_per_tensor_affine_sub_byte_stub, + &quantize_tensor_per_tensor_affine_sub_byte_cpu); +REGISTER_DISPATCH( + dequantize_tensor_per_tensor_affine_sub_byte_stub, + &dequantize_tensor_per_tensor_affine_sub_byte_cpu); + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/qadd.cpp b/aten/src/ATen/native/quantized/cpu/qadd.cpp index 22db20eeedb6e..b947db8f820c7 100644 --- a/aten/src/ATen/native/quantized/cpu/qadd.cpp +++ b/aten/src/ATen/native/quantized/cpu/qadd.cpp @@ -88,7 +88,7 @@ Tensor _add_scalar_out(Tensor& out, const Tensor& self, Scalar other) { if (q_min > z - c_q) { s_prime = (((double)q_max - (z - c_q))) / ((double)q_max - q_min) * s; z_prime = q_min; - out.set_quantizer_(make_per_tensor_affine_quantizer( + set_quantizer_(out, make_per_tensor_affine_quantizer( s_prime, z_prime, self.scalar_type())); if (ReLUFused) { qadd_scalar_relu_stub(self.device().type(), out, self, c_q); @@ -98,7 +98,7 @@ Tensor _add_scalar_out(Tensor& out, const Tensor& self, Scalar other) { } else if (q_max < z - c_q) { s_prime = ((double)(z - c_q) - q_min) / ((double)q_max - q_min) * s; z_prime = q_max; - out.set_quantizer_(make_per_tensor_affine_quantizer( + set_quantizer_(out, make_per_tensor_affine_quantizer( s_prime, z_prime, self.scalar_type())); if (ReLUFused) { qadd_scalar_relu_stub(self.device().type(), out, self, c_q); @@ -109,7 +109,7 @@ Tensor _add_scalar_out(Tensor& out, const Tensor& self, Scalar other) { s_prime = s; z_prime = z - c_q; out.copy_(self); - out.set_quantizer_(make_per_tensor_affine_quantizer( + set_quantizer_(out, make_per_tensor_affine_quantizer( s_prime, z_prime, self.scalar_type())); if (ReLUFused) { at::native::relu_quantized_cpu_(out); @@ -243,6 +243,15 @@ Tensor qadd_scalar(Tensor qa, Scalar b) { return _add_scalar_out(qc, qa, b); } +template +Tensor qadd_scalar2(Scalar b, Tensor qa) { + TORCH_CHECK(qa.qscheme() == kPerTensorAffine || + qa.qscheme() == kPerTensorSymmetric, + "Only per tensor quantization is supported in Add."); + auto qc = at::empty_like(qa, qa.suggest_memory_format()); + return _add_scalar_out(qc, qa, b); +} + template Tensor qadd_scalar_out(Tensor qa, Scalar b, Tensor out) { check_inputs(qa, out); @@ -266,29 +275,31 @@ Tensor qadd_scalar_tensor_out(Tensor qa, Tensor b, Tensor out) { } TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { - m.impl("add", TORCH_FN(qadd)); - m.impl("add.out", TORCH_FN(qadd_out)); - m.impl("add.Scalar", TORCH_FN(qadd_scalar)); - m.impl("add.Scalar_out", TORCH_FN(qadd_scalar_out)); - m.impl("add_relu", TORCH_FN(qadd)); - m.impl("add_relu.out", TORCH_FN(qadd_out)); - m.impl("add_relu.Scalar", TORCH_FN(qadd_scalar)); - m.impl("add_relu.Scalar_out", TORCH_FN(qadd_scalar_out)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add"), TORCH_FN(qadd)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add.out"), TORCH_FN(qadd_out)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add.Scalar"), TORCH_FN(qadd_scalar)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add.Scalar2"), TORCH_FN(qadd_scalar2)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add.Scalar_out"), TORCH_FN(qadd_scalar_out)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu"), TORCH_FN(qadd)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.out"), TORCH_FN(qadd_out)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.Scalar"), TORCH_FN(qadd_scalar)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.Scalar2"), TORCH_FN(qadd_scalar2)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.Scalar_out"), TORCH_FN(qadd_scalar_out)); // deprecated functions, kept for backward compatibility - m.impl("add_out", TORCH_FN(qadd_out)); - m.impl("add_relu_out", TORCH_FN(qadd_out)); - m.impl("add_scalar", TORCH_FN(qadd_scalar)); - m.impl("add_scalar_relu", TORCH_FN(qadd_scalar)); - m.impl("add_scalar_out", TORCH_FN(qadd_scalar_out)); - m.impl("add_scalar_relu_out", TORCH_FN(qadd_scalar_out)); - m.impl("add_scalar.Tensor", TORCH_FN(qadd_scalar_tensor)); - m.impl("add_scalar_relu.Tensor", TORCH_FN(qadd_scalar_tensor)); - m.impl("add_scalar_out.Tensor", TORCH_FN(qadd_scalar_tensor_out)); - m.impl("add_scalar_relu_out.Tensor", TORCH_FN(qadd_scalar_tensor_out)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add_out"), TORCH_FN(qadd_out)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu_out"), TORCH_FN(qadd_out)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add_scalar"), TORCH_FN(qadd_scalar)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add_scalar_relu"), TORCH_FN(qadd_scalar)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add_scalar_out"), TORCH_FN(qadd_scalar_out)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add_scalar_relu_out"), TORCH_FN(qadd_scalar_out)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add_scalar.Tensor"), TORCH_FN(qadd_scalar_tensor)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add_scalar_relu.Tensor"), TORCH_FN(qadd_scalar_tensor)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add_scalar_out.Tensor"), TORCH_FN(qadd_scalar_tensor_out)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add_scalar_relu_out.Tensor"), TORCH_FN(qadd_scalar_tensor_out)); } TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) { - m.impl("add", TORCH_FN(qadd)); + m.impl(TORCH_SELECTIVE_NAME("_quantized::add"), TORCH_FN(qadd)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp b/aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp index effafcacc76e0..b053940abba29 100644 --- a/aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp +++ b/aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp @@ -378,14 +378,14 @@ Tensor quantized_batch_norm( } TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { - m.impl("batch_norm", TORCH_FN(q_batch_norm_impl)); - m.impl("batch_norm_relu", TORCH_FN(q_batch_norm_impl)); - m.impl("batch_norm1d", TORCH_FN(q_batch_norm1d_impl)); - m.impl("batch_norm1d_relu", TORCH_FN(q_batch_norm1d_impl)); - m.impl("batch_norm2d", TORCH_FN(q_batch_norm2d_impl)); - m.impl("batch_norm2d_relu", TORCH_FN(q_batch_norm2d_impl)); - m.impl("batch_norm3d", TORCH_FN(q_batch_norm3d_impl)); - m.impl("batch_norm3d_relu", TORCH_FN(q_batch_norm3d_impl)); + m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm"), TORCH_FN(q_batch_norm_impl)); + m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm_relu"), TORCH_FN(q_batch_norm_impl)); + m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm1d"), TORCH_FN(q_batch_norm1d_impl)); + m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm1d_relu"), TORCH_FN(q_batch_norm1d_impl)); + m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm2d"), TORCH_FN(q_batch_norm2d_impl)); + m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm2d_relu"), TORCH_FN(q_batch_norm2d_impl)); + m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm3d"), TORCH_FN(q_batch_norm3d_impl)); + m.impl(TORCH_SELECTIVE_NAME("quantized::batch_norm3d_relu"), TORCH_FN(q_batch_norm3d_impl)); } } // namespace native diff --git a/aten/src/ATen/native/quantized/cpu/qclamp.cpp b/aten/src/ATen/native/quantized/cpu/qclamp.cpp index a700163077850..8c0e38390647e 100644 --- a/aten/src/ATen/native/quantized/cpu/qclamp.cpp +++ b/aten/src/ATen/native/quantized/cpu/qclamp.cpp @@ -15,6 +15,8 @@ namespace at { namespace native { DEFINE_DISPATCH(qclamp_stub); +DEFINE_DISPATCH(qclamp_min_stub); +DEFINE_DISPATCH(qclamp_max_stub); namespace { @@ -84,14 +86,26 @@ Tensor quantized_clamp_impl( Tensor qy; if (min && max) { #ifdef USE_PYTORCH_QNNPACK - if (at::globalContext().qEngine() == at::QEngine::QNNPACK && qx.scalar_type() == kQUInt8) { + if (at::globalContext().qEngine() == at::QEngine::QNNPACK && + qx.scalar_type() == kQUInt8) { return qnnpack_clamp(qx, *min, *max); } #endif qclamp_stub(qx.device().type(), qx, *min, *max, qy); } else { - TORCH_CHECK( - false, "Both min and max should be specified for quantized clamp!"); +#ifdef USE_PYTORCH_QNNPACK + if (at::globalContext().qEngine() == at::QEngine::QNNPACK) { + TORCH_CHECK( + false, "Both min and max should be specified for quantized clamp!"); + } +#endif + if (max) { + qclamp_max_stub(qx.device().type(), qx, *max, qy); + } else if (min) { + qclamp_min_stub(qx.device().type(), qx, *min, qy); + } else { + TORCH_CHECK(false, "At least one of 'min' or 'max' must not be None"); + } } return qy; } @@ -140,7 +154,7 @@ Tensor& hardtanh_quantized_cpu_( } TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { - m.impl("clamp", TORCH_FN(clamp_quantized_cpu)); + m.impl(TORCH_SELECTIVE_NAME("quantized::clamp"), TORCH_FN(clamp_quantized_cpu)); } } // namespace native diff --git a/aten/src/ATen/native/quantized/cpu/qconcat.cpp b/aten/src/ATen/native/quantized/cpu/qconcat.cpp index 0656f40e3554b..ca08c365d83d6 100644 --- a/aten/src/ATen/native/quantized/cpu/qconcat.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconcat.cpp @@ -102,10 +102,10 @@ Tensor qcat_out(const c10::List& qxs, int64_t dim, Tensor out) { } // namespace TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { - m.impl("cat", TORCH_FN(qcat)); - m.impl("cat_relu", TORCH_FN(qcat)); - m.impl("cat_out", TORCH_FN(qcat_out)); - m.impl("cat_relu_out", TORCH_FN(qcat_out)); + m.impl(TORCH_SELECTIVE_NAME("quantized::cat"), TORCH_FN(qcat)); + m.impl(TORCH_SELECTIVE_NAME("quantized::cat_relu"), TORCH_FN(qcat)); + m.impl(TORCH_SELECTIVE_NAME("quantized::cat_out"), TORCH_FN(qcat_out)); + m.impl(TORCH_SELECTIVE_NAME("quantized::cat_relu_out"), TORCH_FN(qcat_out)); } Tensor cat_quantized_cpu(TensorList qxs, int64_t dim) { diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 12563eb36d443..05762bfb036f9 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -5,17 +5,17 @@ #include #include #include -#include +#include #include #include #include -#include #include +#include namespace { // To have a sanity check for maximum matrix size. constexpr int64_t kReasonableMaxDim = 1000000; -} +} // namespace template bool ConvDimChecks( @@ -307,6 +307,9 @@ at::Tensor PackedConvWeight::apply_impl( const int dilation_d = kSpatialDim == 2 ? 1 : dilation_[0]; const int dilation_h = dilation_[kSpatialDim - 2]; const int dilation_w = dilation_[kSpatialDim - 1]; + const int output_padding_d = kSpatialDim == 2 ? 0 : output_padding_[0]; + const int output_padding_h = output_padding_[kSpatialDim - 2]; + const int output_padding_w = output_padding_[kSpatialDim - 1]; if (kSpatialDim == 2) { TORCH_CHECK( @@ -381,7 +384,13 @@ at::Tensor PackedConvWeight::apply_impl( : std::vector{pad_d, pad_h, pad_w}, kSpatialDim == 2 ? std::vector{dilation_h, dilation_w} - : std::vector{dilation_d, dilation_h, dilation_w}); + : std::vector{dilation_d, dilation_h, dilation_w}, + kSpatialDim == 2 + ? std::vector{output_padding_h, output_padding_w} + : std::vector{output_padding_d, + output_padding_h, + output_padding_w}, + transpose()); const float act_scale = act.q_scale(); const int32_t act_zero_point = act.q_zero_point(); @@ -397,8 +406,20 @@ at::Tensor PackedConvWeight::apply_impl( GetQuantizationParams( act_scale, output_scale, &output_multiplier_float, &act_times_w_scale); - const at::SmallVector output_shape = - MakeConvOutputShape(N, M, conv_p.OUT_DIM); + at::SmallVector output_shape; + if (transpose()) { + output_shape = MakeDeConvOutputShape( + N, + M, + kSpatialDim == 2 ? std::vector{H, W} : std::vector{D, H, W}, + kernel, + stride(), + padding(), + output_padding(), + dilation()); + } else { + output_shape = MakeConvOutputShape(N, M, conv_p.OUT_DIM); + } if (N > 0) { TORCH_CHECK( std::all_of( @@ -850,30 +871,33 @@ class QConvInt8ForBC final { }; TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { - m.impl("conv1d", QConv1dInt8::run); - m.impl("conv1d_relu", QConv1dInt8::run); - m.impl("conv2d.new", QConvInt8<2, false>::run); - m.impl("conv2d_relu.new", QConvInt8<2, true>::run); - m.impl("conv3d.new", QConvInt8<3, false>::run); - m.impl("conv3d_relu.new", QConvInt8<3, true>::run); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d"), QConv1dInt8::run); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d_relu"), QConv1dInt8::run); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d.new"), QConvInt8<2, false>::run); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_relu.new"), QConvInt8<2, true>::run); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d.new"), QConvInt8<3, false>::run); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d_relu.new"), QConvInt8<3, true>::run); // for backward compatibility - m.impl("conv2d", QConvInt8ForBC<2, false>::run); - m.impl("conv2d_relu", QConvInt8ForBC<2, true>::run); - m.impl("conv3d", QConvInt8ForBC<3, false>::run); - m.impl("conv3d_relu", QConvInt8ForBC<3, true>::run); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d"), QConvInt8ForBC<2, false>::run); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_relu"), QConvInt8ForBC<2, true>::run); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d"), QConvInt8ForBC<3, false>::run); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d_relu"), QConvInt8ForBC<3, true>::run); // transpose - m.impl("conv_transpose1d", QConv1dInt8::run); - m.impl("conv_transpose2d", QConvInt8<2, false>::run); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose1d"), QConv1dInt8::run); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d"), QConvInt8<2, false>::run); + m.impl( + TORCH_SELECTIVE_NAME("quantized::conv_transpose3d"), + QConvInt8<3, false>::run); } TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) { - m.impl("conv2d", QConvInt8<2, false>::run); - m.impl("conv2d_relu", QConvInt8<2, true>::run); + m.impl(TORCH_SELECTIVE_NAME("_quantized::conv2d"), QConvInt8<2, false>::run); + m.impl(TORCH_SELECTIVE_NAME("_quantized::conv2d_relu"), QConvInt8<2, true>::run); // transpose - m.impl("conv_transpose1d", QConv1dInt8::run); - m.impl("conv_transpose2d", QConvInt8<2, false>::run); + m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose1d"), QConv1dInt8::run); + m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose2d"), QConvInt8<2, false>::run); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp index 4387b255dfe18..1bd0da28f053b 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp @@ -23,7 +23,6 @@ c10::intrusive_ptr> PackedConvWeight< torch::List dilation, int64_t groups, bool transpose) { - TORCH_CHECK(!transpose, "FBGEMM doesn't support transpose packing yet!"); TORCH_CHECK( weight.ndimension() == kSpatialDim + 2, "Weights are expected to have ", @@ -75,7 +74,9 @@ c10::intrusive_ptr> PackedConvWeight< : std::vector{kernel_d, kernel_h, kernel_w}, std::vector(stride.begin(), stride.end()), std::vector(padding.begin(), padding.end()), - std::vector(dilation.begin(), dilation.end())); + std::vector(dilation.begin(), dilation.end()), + std::vector(output_padding.begin(), output_padding.end()), + transpose); const auto qtype = weight.qscheme(); std::vector zero_points; @@ -84,8 +85,8 @@ c10::intrusive_ptr> PackedConvWeight< } else if (qtype == c10::kPerChannelAffine) { int64_t axis = weight.q_per_channel_axis(); TORCH_CHECK( - axis == 0, - "Only per output channel quantization is supported for the weights"); + !transpose, + "Per Channel Quantization is currently disabled for transposed conv"); zero_points.resize(output_channels); for (int i = 0; i < output_channels; ++i) { zero_points[i] = weight.q_per_channel_zero_points()[i].item(); @@ -96,11 +97,14 @@ c10::intrusive_ptr> PackedConvWeight< // FBGEMM expects weights to be in channels last // TODO: Change this when ChannelsLast3d is ready. - const at::Tensor weight_nhwc = kSpatialDim == 2 - ? weight.contiguous(c10::MemoryFormat::ChannelsLast) - : at::native::fbgemm_utils::ConvertToChannelsLast3dTensor(weight); + // FBGEMM needs G OC/G kDim0 ... kDimN IC/G + // for both conv and conv transpose + // but PyTorch lays them out as {out_c, in_c/groups, kH, kW} + // (or for ConvTranspose {in_c, out_c/groups, kH, kW}) + const at::Tensor weight_nhwc = + at::native::fbgemm_utils::ConvertConvWeightsToChannelLastTensor(weight, groups, transpose); const int8_t* weight_data_int8 = - reinterpret_cast(weight_nhwc.data_ptr()); + reinterpret_cast(weight_nhwc.data_ptr()); std::vector col_offsets(output_channels); // compute column offsets (Similar to // fbgemm::col_offsets_with_zero_pt_s8acc32_ref) please note that offsets @@ -166,6 +170,8 @@ c10::intrusive_ptr> PackedConvWeight< return ret_ptr; } +template struct PackedConvWeight<2>; +template struct PackedConvWeight<3>; #endif // USE_FBGEMM #ifdef USE_PYTORCH_QNNPACK @@ -270,6 +276,18 @@ c10::intrusive_ptr> PackedConvWeightsQnnp< return ret_ptr; } +template +c10::intrusive_ptr> PackedConvWeightsQnnp< + 2>:: + prepack( + at::Tensor weight, + c10::optional bias_in, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose); #endif // USE_PYTORCH_QNNPACK namespace at { @@ -320,7 +338,6 @@ class QConvPackWeightInt8 final { auto& ctx = at::globalContext(); #ifdef USE_FBGEMM if (ctx.qEngine() == at::QEngine::FBGEMM) { - TORCH_CHECK(!transpose, "FBGEMM doesn't support transpose packing yet!"); return PackedConvWeight::prepack( weight, bias, stride, padding, output_padding, dilation, groups, transpose); @@ -415,21 +432,23 @@ class QConv1dPackWeightInt8 final { TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { // Conv // conv_prepack is deprecated, please use conv2d_prepack for 2D conv. - m.impl("conv_prepack", TORCH_FN(QConvPackWeightInt8<2>::run_conv)); - m.impl("conv1d_prepack", TORCH_FN(QConv1dPackWeightInt8::run_conv)); - m.impl("conv2d_prepack", TORCH_FN(QConvPackWeightInt8<2>::run_conv)); - m.impl("conv3d_prepack", TORCH_FN(QConvPackWeightInt8<3>::run_conv)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_conv)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d_prepack"), TORCH_FN(QConv1dPackWeightInt8::run_conv)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_conv)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d_prepack"), TORCH_FN(QConvPackWeightInt8<3>::run_conv)); // ConvTranspose - m.impl("conv_transpose1d_prepack", TORCH_FN(QConv1dPackWeightInt8::run_deconv)); - m.impl("conv_transpose2d_prepack", TORCH_FN(QConvPackWeightInt8<2>::run_deconv)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose1d_prepack"), TORCH_FN(QConv1dPackWeightInt8::run_deconv)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_deconv)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_prepack"), TORCH_FN(QConvPackWeightInt8<3>::run_deconv)); } TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) { // Conv - m.impl("conv2d_prepack", TORCH_FN(QConvPackWeightInt8<2>::run_conv)); + m.impl(TORCH_SELECTIVE_NAME("_quantized::conv2d_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_conv)); // ConvTranspose - m.impl("conv_transpose1d_prepack", TORCH_FN(QConv1dPackWeightInt8::run_deconv)); - m.impl("conv_transpose2d_prepack", TORCH_FN(QConvPackWeightInt8<2>::run_deconv)); + m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose1d_prepack"), TORCH_FN(QConv1dPackWeightInt8::run_deconv)); + m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose2d_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_deconv)); + m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose3d_prepack"), TORCH_FN(QConvPackWeightInt8<3>::run_deconv)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp index 9e8a103cb17cc..484bfe44fc76e 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp @@ -13,7 +13,6 @@ template std::tuple> PackedConvWeight< kSpatialDim>::unpack() { auto* packed_weights_p = w.get(); - // output channels const int output_channels = packed_weights_p->outputChannels(); const int input_channels = packed_weights_p->inputChannels(); @@ -54,6 +53,9 @@ std::tuple> PackedConvWeight< w_scale[0], w_zp[0]); } else if (q_scheme == c10::kPerChannelAffine) { + TORCH_CHECK( + !transpose(), + "Per Channel Quantization is currently disabled for transposed conv"); auto scales = at::from_blob( w_scale.data(), w_scale.size(), device(c10::kCPU).dtype(c10::kFloat)); auto zero_points = at::from_blob( @@ -82,7 +84,11 @@ std::tuple> PackedConvWeight< int8_t* unpacked_weights_p = reinterpret_cast(unpacked_weights.data_ptr()); packed_weights_p->unpack(unpacked_weights_p); - + if(transpose()){ + unpacked_weights = + at::native::fbgemm_utils::TransposeConvTensorUnpackConversion< + kSpatialDim>(unpacked_weights, groups); + } return std::tuple>( unpacked_weights, bias); } @@ -243,36 +249,43 @@ class QConvTranspose final { TORCH_LIBRARY_IMPL(quantized, CatchAll, m) { // conv_unpack is deprecated, please use conv2d_unpack for 2D conv. - m.impl("conv_unpack", TORCH_FN(QConvUnpackWeightsInt8<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_unpack"), TORCH_FN(QConvUnpackWeightsInt8<2>::run)); // We use conv2d_unpack to be consistent with conv3d_unpack - m.impl("conv1d_unpack", TORCH_FN(QConv1dUnpackWeightsInt8::run)); - m.impl("conv2d_unpack", TORCH_FN(QConvUnpackWeightsInt8<2>::run)); - m.impl("conv3d_unpack", TORCH_FN(QConvUnpackWeightsInt8<3>::run)); - - m.impl("conv2d_stride", TORCH_FN(QConvStride<2>::run)); - m.impl("conv2d_padding", TORCH_FN(QConvPadding<2>::run)); - m.impl("conv2d_output_padding", TORCH_FN(QConvOutputPadding<2>::run)); - m.impl("conv2d_dilation", TORCH_FN(QConvDilation<2>::run)); - m.impl("conv2d_groups", TORCH_FN(QConvGroups<2>::run)); - m.impl("conv2d_transpose", TORCH_FN(QConvTranspose<2>::run)); - - m.impl("conv3d_stride", TORCH_FN(QConvStride<3>::run)); - m.impl("conv3d_padding", TORCH_FN(QConvPadding<3>::run)); - m.impl("conv3d_output_padding", TORCH_FN(QConvOutputPadding<3>::run)); - m.impl("conv3d_dilation", TORCH_FN(QConvDilation<3>::run)); - m.impl("conv3d_groups", TORCH_FN(QConvGroups<3>::run)); - m.impl("conv3d_transpose", TORCH_FN(QConvTranspose<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d_unpack"), TORCH_FN(QConv1dUnpackWeightsInt8::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_unpack"), TORCH_FN(QConvUnpackWeightsInt8<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d_unpack"), TORCH_FN(QConvUnpackWeightsInt8<3>::run)); + + m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_stride"), TORCH_FN(QConvStride<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_padding"), TORCH_FN(QConvPadding<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_output_padding"), TORCH_FN(QConvOutputPadding<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_dilation"), TORCH_FN(QConvDilation<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_groups"), TORCH_FN(QConvGroups<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_transpose"), TORCH_FN(QConvTranspose<2>::run)); + + m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d_stride"), TORCH_FN(QConvStride<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d_padding"), TORCH_FN(QConvPadding<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d_output_padding"), TORCH_FN(QConvOutputPadding<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d_dilation"), TORCH_FN(QConvDilation<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d_groups"), TORCH_FN(QConvGroups<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv3d_transpose"), TORCH_FN(QConvTranspose<3>::run)); // ConvTranspose is the same, however, we want to have different name. - m.impl("conv_transpose1d_unpack", TORCH_FN(QConv1dUnpackWeightsInt8::run)); - m.impl("conv_transpose2d_unpack", TORCH_FN(QConvUnpackWeightsInt8<2>::run)); - - m.impl("conv_transpose2d_stride", TORCH_FN(QConvStride<2>::run)); - m.impl("conv_transpose2d_padding", TORCH_FN(QConvPadding<2>::run)); - m.impl("conv_transpose2d_output_padding", TORCH_FN(QConvOutputPadding<2>::run)); - m.impl("conv_transpose2d_dilation", TORCH_FN(QConvDilation<2>::run)); - m.impl("conv_transpose2d_groups", TORCH_FN(QConvGroups<2>::run)); - m.impl("conv_transpose2d_transpose", TORCH_FN(QConvTranspose<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose1d_unpack"), TORCH_FN(QConv1dUnpackWeightsInt8::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_unpack"), TORCH_FN(QConvUnpackWeightsInt8<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_unpack"), TORCH_FN(QConvUnpackWeightsInt8<3>::run)); + + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_stride"), TORCH_FN(QConvStride<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_padding"), TORCH_FN(QConvPadding<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_output_padding"), TORCH_FN(QConvOutputPadding<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_dilation"), TORCH_FN(QConvDilation<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_groups"), TORCH_FN(QConvGroups<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_transpose"), TORCH_FN(QConvTranspose<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_stride"), TORCH_FN(QConvStride<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_padding"), TORCH_FN(QConvPadding<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_output_padding"), TORCH_FN(QConvOutputPadding<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_dilation"), TORCH_FN(QConvDilation<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_groups"), TORCH_FN(QConvGroups<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_transpose"), TORCH_FN(QConvTranspose<3>::run)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qelu.cpp b/aten/src/ATen/native/quantized/cpu/qelu.cpp index 92b635471e784..e873506026e6f 100644 --- a/aten/src/ATen/native/quantized/cpu/qelu.cpp +++ b/aten/src/ATen/native/quantized/cpu/qelu.cpp @@ -24,8 +24,8 @@ Tensor quantized_celu(const Tensor& qx, double output_scale, int64_t output_zero } TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { - m.impl("elu", quantized_elu); - m.impl("celu", quantized_celu); + m.impl(TORCH_SELECTIVE_NAME("quantized::elu"), quantized_elu); + m.impl(TORCH_SELECTIVE_NAME("quantized::celu"), quantized_celu); } }} // namespace at::native diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp index da494936aad7b..f017ad93cdee4 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #ifdef USE_FBGEMM #include @@ -11,267 +12,59 @@ torch::class_ register_embedding_params(); -at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte( +namespace { +template +at::Tensor embedding_bag_4bit_impl( + const at::Tensor& weight, const at::Tensor& indices, - const c10::optional& offsets_in, - bool sparse, + const at::Tensor& offsets, + bool pruned_weights, const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, bool include_last_offset) { - TORCH_CHECK( - offsets_in.has_value(), - "embedding_bag_byte_rowwise_offsets expects offsets to be set"); - auto offsets = offsets_in.value(); - auto offsets_data = offsets.data_ptr(); - const auto indices_data = indices.data_ptr(); - - const auto weight_data = packed_w.data_ptr(); - - const int64_t N = packed_w.size(0); - const int64_t D = - packed_w.size(1) - 8; // NB: -8 to account for scale and bias - const int64_t M = offsets.size(0); - - int64_t output_size = M - 1; - std::vector offsets_include_last; - - if (!include_last_offset) { - output_size = M; - offsets_include_last.resize(M + 1); - std::memcpy( - offsets_include_last.data(), - offsets.data_ptr(), - sizeof(int64_t) * M); - offsets_include_last[M] = indices.numel(); - offsets_data = offsets_include_last.data(); - } - - std::vector shape = {output_size, D}; - auto output = at::empty(shape, packed_w.options().dtype(at::kFloat)); - auto* output_data = output.data_ptr(); - -#ifdef USE_FBGEMM - - auto kernel_i8_i64 = - fbgemm::GenerateEmbeddingSpMDM( - /*block_size=*/D, - /*has_weight=*/per_sample_weights_.has_value(), - /*normalize_by_lengths=*/false, - /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers) - /*is_weight_positional=*/false, - /*use_offsets=*/true); - - if (packed_w.is_contiguous()) { - at::parallel_for( - 0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) { - bool success = kernel_i8_i64( - /*output_size=*/end_idx - start_idx, - /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx], - /*data_size=*/N, - /*input=*/weight_data, - /*indices=*/indices_data + offsets_data[start_idx], - /*offsets_or_lengths=*/offsets_data + start_idx, - /*weights=*/ - per_sample_weights_ - ? per_sample_weights_.value().data_ptr() + - offsets_data[start_idx] - : nullptr, - /*out=*/output_data + start_idx * D); + TORCH_CHECK(weight.dim() == 2); + TORCH_CHECK(offsets.dim() == 1); - TORCH_CHECK( - success, - "FBGEMM GenerateEmbeddingSpMDM kernel failed for 8-bit input"); - }); - } else { - auto weight_contig = packed_w.contiguous(); - const auto weight_data_contig = weight_contig.data_ptr(); - at::parallel_for( - 0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) { - bool success = kernel_i8_i64( - /*output_size=*/end_idx - start_idx, - /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx], - /*data_size=*/N, - /*input=*/weight_data_contig, - /*indices=*/indices_data + offsets_data[start_idx], - /*offsets_or_lengths=*/offsets_data + start_idx, - /*weights=*/ - per_sample_weights_ - ? per_sample_weights_.value().data_ptr() + - offsets_data[start_idx] - : nullptr, - /*out=*/output_data + start_idx * D); - TORCH_CHECK( - success, - "FBGEMM GenerateEmbeddingSpMDM kernel failed for 8-bit input"); - }); - } -#endif - // TODO add default (non-FBGEMM) implementation. - return output; -} - -namespace at { -namespace native { -namespace { - -Tensor embedding_bag_byte_rowwise_offsets( - const Tensor& weight, - const Tensor& indices, - const c10::optional& offsets_in, - const bool /* scale_grad_by_freq */, - const int64_t /* mode */, - bool /* sparse */, - const c10::optional& per_sample_weights_, - bool include_last_offset) { - TORCH_CHECK(weight.scalar_type() == at::kByte); - TORCH_CHECK(weight.ndimension() == 2); - TORCH_CHECK( - offsets_in.has_value(), - "embedding_bag_byte_rowwise_offsets expects offsets to be set"); - - auto offsets = offsets_in.value(); - auto offsets_data = offsets.data_ptr(); const auto weight_data = weight.data_ptr(); - const auto indices_data = indices.data_ptr(); - - const int64_t N = weight.size(0); - const int64_t D = weight.size(1) - 8; // NB: -8 to account for scale and bias - const int64_t M = offsets.size(0); - - int64_t output_size = M - 1; - std::vector offsets_include_last; - - if (!include_last_offset) { - output_size = M; - offsets_include_last.resize(M + 1); - std::memcpy( - offsets_include_last.data(), - offsets.data_ptr(), - sizeof(int64_t) * M); - offsets_include_last[M] = indices.numel(); - offsets_data = offsets_include_last.data(); - } + const auto indices_data = indices.data_ptr(); + auto offsets_data = offsets.data_ptr(); - std::vector shape = {output_size, D}; - auto output = at::empty(shape, weight.options().dtype(at::kFloat)); - auto* output_data = output.data_ptr(); - -#ifdef USE_FBGEMM - - auto kernel_i8_i64 = - fbgemm::GenerateEmbeddingSpMDM( - /*block_size=*/D, - /*has_weight=*/per_sample_weights_.has_value(), - /*normalize_by_lengths=*/false, - /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers) - /*is_weight_positional=*/false, - /*use_offsets=*/true); - - if (weight.is_contiguous()) { - at::parallel_for( - 0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) { - bool success = kernel_i8_i64( - /*output_size=*/end_idx - start_idx, - /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx], - /*data_size=*/N, - /*input=*/weight_data, - /*indices=*/indices_data + offsets_data[start_idx], - /*offsets_or_lengths=*/offsets_data + start_idx, - /*weights=*/ - per_sample_weights_ - ? per_sample_weights_.value().data_ptr() + - offsets_data[start_idx] - : nullptr, - /*out=*/output_data + start_idx * D); - - TORCH_CHECK( - success, - "FBGEMM GenerateEmbeddingSpMDM kernel failed for 8-bit input"); - }); - } else { - auto weight_contig = weight.contiguous(); - const auto weight_data_contig = weight_contig.data_ptr(); - at::parallel_for( - 0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) { - bool success = kernel_i8_i64( - /*output_size=*/end_idx - start_idx, - /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx], - /*data_size=*/N, - /*input=*/weight_data_contig, - /*indices=*/indices_data + offsets_data[start_idx], - /*offsets_or_lengths=*/offsets_data + start_idx, - /*weights=*/ - per_sample_weights_ - ? per_sample_weights_.value().data_ptr() + - offsets_data[start_idx] - : nullptr, - /*out=*/output_data + start_idx * D); - TORCH_CHECK( - success, - "FBGEMM GenerateEmbeddingSpMDM kernel failed for 8-bit input"); - }); - } -#endif - return output; -} - -Tensor embedding_bag_4bit_rowwise_offsets( - const Tensor& weight, - const Tensor& indices, - const c10::optional& offsets_in, - const bool /* scale_grad_by_freq */, - const int64_t /* mode */, - bool sparse, - const c10::optional& per_sample_weights_, - const c10::optional& compressed_indices_mapping, - bool include_last_offset) { - TORCH_CHECK( - offsets_in.has_value(), - "embedding_bag_4bit_rowwise_offsets expects offsets to be set"); - - TORCH_CHECK(weight.ndimension() == 2); - TORCH_CHECK(indices.ndimension() == 1); - - auto offsets = offsets_in.value(); - TORCH_CHECK(offsets.ndimension() == 1); - - // FBGEMM expects the offsets to be of int type. - at::Tensor offsets_new = offsets.toType(ScalarType::Int); - - auto offsets_data = offsets_new.data_ptr(); - const auto weight_data = weight.data_ptr(); - uint8_t* input_data = nullptr; - if (!weight.is_contiguous()) { - auto weight_contig = weight.contiguous(); - input_data = weight_contig.data_ptr(); - } else { - input_data = weight.data_ptr(); - } - - // Get compressed indices for sparse op. + // Get compressed indices for pruned_weights op. int32_t* compressed_indices_mapping_data = nullptr; int compressed_index_size = 0; - if (sparse) { + bool fallback_to_no_sparse = false; + if (pruned_weights) { compressed_index_size = compressed_indices_mapping.value().numel(); compressed_indices_mapping_data = compressed_indices_mapping.value().data_ptr(); + + // if compressed_indices_mapping is [0], it is a indicator that + // we should fallback to non sparse embedding look up kernel. + if ((compressed_index_size == 1 && + compressed_indices_mapping_data[0] == 0)) { + fallback_to_no_sparse = true; + } } - const auto indices_data = indices.data_ptr(); - const int64_t N = weight.size(0); + const auto weight_sizes = weight.sizes(); + const int64_t N = weight_sizes[0]; + const int64_t weight_size = weight_sizes[1]; const int64_t D = - (weight.size(1) - 4) * 2; // NB: 2-byte fp16 scale and 2-byte zero_offset - const int64_t M = offsets.size(0); + (weight_size - 4) * 2; // NB: 2-byte fp16 scale and 2-byte zero_offset + const int64_t M = offsets.sizes()[0]; int64_t output_size = M - 1; - std::vector offsets_include_last_val; + std::vector offsets_include_last_val; if (!include_last_offset) { output_size = M; offsets_include_last_val.resize(M + 1); - // Avoid `null pointer passed as argument 2` ASAN violation when ofests + // Avoid `null pointer passed as argument 2` ASAN violation when offsets // tensor is empty. if (M > 0) { std::memcpy( - offsets_include_last_val.data(), offsets_data, sizeof(int) * M); + offsets_include_last_val.data(), + offsets_data, + sizeof(OffsetType) * M); } offsets_include_last_val[M] = indices.numel(); offsets_data = offsets_include_last_val.data(); @@ -280,14 +73,15 @@ Tensor embedding_bag_4bit_rowwise_offsets( const std::vector shape = {output_size, D}; auto output = at::empty(shape, weight.options().dtype(at::kFloat)); auto* output_data = output.data_ptr(); - const int64_t block_size = output.size(1); - TORCH_CHECK(block_size % 2 == 0, "block size must be divisible by 2"); + + const int64_t block_size = D; const int index_size = indices.numel(); constexpr int prefetch_distance = 16; + #ifdef USE_FBGEMM - if (!sparse) { + if (!pruned_weights || fallback_to_no_sparse) { // Generate the fbgemm kernel - auto kernel_64_ = fbgemm::GenerateEmbeddingSpMDMNBit( + auto kernel = fbgemm::GenerateEmbeddingSpMDMNBit( /*bit rate=*/4, /*block size=*/block_size, /*has weights=*/per_sample_weights_.has_value(), @@ -296,11 +90,11 @@ Tensor embedding_bag_4bit_rowwise_offsets( /*is_weight_positional=*/false, /*use_offsets=*/true); - bool success = kernel_64_( + bool success = kernel( /*output_size=*/output_size, /*index_size=*/index_size, /*data_size=*/N, - /*input=*/input_data, + /*input=*/weight_data, /*indices=*/indices_data, /*offsets=*/offsets_data, /*weights=*/ @@ -313,8 +107,8 @@ Tensor embedding_bag_4bit_rowwise_offsets( success, "FBGEMM GenerateEmbeddingSpMDMNBit kernel failed for 4-bit input"); } else { - auto kernel_64_ = - fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse( + auto kernel = + fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse( /*bit rate=*/4, /*block_size=*/block_size, /*has weights=*/per_sample_weights_.has_value(), @@ -322,7 +116,7 @@ Tensor embedding_bag_4bit_rowwise_offsets( /*prefetch distance*/ prefetch_distance, /*is_weight_positional*/ false, /*use_offsets*/ true); - bool success = kernel_64_( + bool success = kernel( /*output_size=*/output_size, /*index_size=*/index_size, /*data_size=*/compressed_index_size, @@ -341,8 +135,8 @@ Tensor embedding_bag_4bit_rowwise_offsets( } #else - auto accessor = offsets.accessor(); - std::vector lengths_data; + auto accessor = offsets.accessor(); + std::vector lengths_data; int64_t lower = accessor[0]; for (int64_t i = 1; i < offsets.numel(); ++i) { @@ -366,7 +160,7 @@ Tensor embedding_bag_4bit_rowwise_offsets( for (int i = 0; i < lengths_data[m]; ++i, ++current) { int64_t idx; - if (!sparse) { + if (!pruned_weights) { idx = indices_data[current]; TORCH_CHECK((idx >= 0 && idx < N), "Invalid indices data"); } else { @@ -380,7 +174,7 @@ Tensor embedding_bag_4bit_rowwise_offsets( } } const at::Half* scale_bias = reinterpret_cast( - input_data + (idx + 1) * weight.size(1) - 2 * sizeof(at::Half)); + weight_data + (idx + 1) * weight_size - 2 * sizeof(at::Half)); float weight_val = 1.0f; if (per_sample_weights_.has_value()) { @@ -391,7 +185,7 @@ Tensor embedding_bag_4bit_rowwise_offsets( for (int j = 0; j < block_size; ++j) { uint8_t quantized = - input_data[idx * weight.size(1) + j / /*NUM_ELEM_PER_BYTE*/ 2]; + weight_data[idx * weight_size + j / /*NUM_ELEM_PER_BYTE*/ 2]; quantized >>= (j % 2) * 4; quantized &= (1 << 4) - 1; @@ -405,6 +199,460 @@ Tensor embedding_bag_4bit_rowwise_offsets( return output; } +template +at::Tensor& embedding_bag_byte_impl( + at::Tensor& output, + const at::Tensor& weight, + const at::Tensor& indices, + const at::Tensor& offsets, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset, + bool is_embedding_op) { + TORCH_CHECK(weight.scalar_type() == at::kByte); + TORCH_CHECK(weight.dim() == 2); + TORCH_CHECK(offsets.dim() == 1); + const auto weight_data = weight.data_ptr(); + const auto indices_data = indices.data_ptr(); + auto offsets_data = offsets.data_ptr(); + + // Get compressed indices for pruned_weights. + int32_t* compressed_indices_mapping_data = nullptr; + int compressed_index_size = 0; + bool fallback_to_no_sparse = false; + if (pruned_weights) { + compressed_index_size = compressed_indices_mapping.value().numel(); + compressed_indices_mapping_data = + compressed_indices_mapping.value().data_ptr(); + + // if compressed_indices_mapping is [0], it is a indicator that + // we should fallback to non sparse embedding look up kernel. + if ((compressed_index_size == 1 && + compressed_indices_mapping_data[0] == 0)) { + fallback_to_no_sparse = true; + } + } + + const auto weight_sizes = weight.sizes(); + const int64_t N = weight_sizes[0]; + const int64_t D = weight_sizes[1] - 8; // NB: -8 to account for scale and bias + const int64_t M = offsets.sizes()[0]; + + int64_t output_size = M - 1; + std::vector offsets_include_last_val; + + if (!include_last_offset) { + output_size = M; + offsets_include_last_val.resize(M + 1); + // Avoid `null pointer passed as argument 2` ASAN violation when offsets + // tensor is empty. + if (M > 0) { + std::memcpy( + offsets_include_last_val.data(), + offsets_data, + sizeof(OffsetType) * M); + } + offsets_include_last_val[M] = indices.numel(); + offsets_data = offsets_include_last_val.data(); + } + std::vector shape; + if (indices.dim() == 2 && is_embedding_op) { + const auto indices_sizes = indices.sizes(); + shape = {indices_sizes[0], indices_sizes[1], D}; + } else { + shape = {output_size, D}; + } + output.resize_(shape); + auto* output_data = output.data_ptr(); + + const int index_size = indices.numel(); +#ifdef USE_FBGEMM + if (!pruned_weights || fallback_to_no_sparse) { + auto kernel_i8 = + fbgemm::GenerateEmbeddingSpMDM( + /*block_size=*/D, + /*has_weight=*/per_sample_weights_.has_value(), + /*normalize_by_lengths=*/false, + /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers) + /*is_weight_positional=*/false, + /*use_offsets=*/true); + + at::parallel_for( + 0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) { + bool success = kernel_i8( + /*output_size=*/end_idx - start_idx, + /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx], + /*data_size=*/N, + /*input=*/weight_data, + /*indices=*/indices_data + offsets_data[start_idx], + /*offsets_or_lengths=*/offsets_data + start_idx, + /*weights=*/ + per_sample_weights_ + ? per_sample_weights_.value().data_ptr() + + offsets_data[start_idx] + : nullptr, + /*out=*/output_data + start_idx * D); + + TORCH_CHECK( + success, + "FBGEMM GenerateEmbeddingSpMDM kernel failed for 8-bit input"); + }); + } else { + // pruned weights + auto kernel_i8_sparse = fbgemm:: + GenerateEmbeddingSpMDMRowWiseSparse( + /*block_size=*/D, + /*has_weight=*/per_sample_weights_.has_value(), + /*normalize_by_lengths=*/false, + /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers) + /*is_weight_positional=*/false, + /*use_offsets=*/true); + + auto success = kernel_i8_sparse( + /*output_size=*/output_size, + /*index_size=*/index_size, + /*data_size=*/compressed_index_size, + /*input=*/weight_data, + /*indices=*/indices_data, + /*offsets=*/offsets_data, + /*weights=*/ + per_sample_weights_.has_value() + ? per_sample_weights_.value().data_ptr() + : nullptr, + /*output=*/output_data, + /*compressed_indices_table=*/compressed_indices_mapping_data); + TORCH_CHECK( + success, + "FBGEMM GenerateEmbeddingSpMDMRowWiseSparse kernel failed for 8-bit input"); + } + return output; +#endif + // TODO add default (non-FBGEMM) implementation. + TORCH_CHECK( + false, + "embedding_bag_byte expects FBGEMM support. This PyTorch installation was not built with FBGEMM operators"); +} + +at::Tensor& embedding_bag_byte_helper( + at::Tensor& output, + const at::Tensor& weight, + const at::Tensor& indices, + const c10::optional& offsets_in, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset, + bool is_embedding_op) { + at::Tensor offsets; + TORCH_CHECK( + indices.dim() == 1 || indices.dim() == 2, + "qembedding/qembedding_bag operator supports 1 or 2d indices, got ", + indices.dim()); + // For embedding_bag operator with 2D indices, we set the offsets explicitly + // here. + if (indices.dim() == 2 && !is_embedding_op) { + TORCH_CHECK( + !offsets_in.has_value(), + "embedding_bag_byte operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences."); + + offsets = at::arange( + 0, indices.numel(), indices.sizes()[1], indices.scalar_type()); + } else { + TORCH_CHECK( + offsets_in.has_value(), + "embedding_bag_byte expects offsets to be set for 1D indices."); + offsets = offsets_in.value(); + } + + TORCH_CHECK( + indices.scalar_type() == at::kInt || indices.scalar_type() == at::kLong, + "Expect 32 or 64 bit indices, but found ", + indices.scalar_type(), + " instead."); + TORCH_CHECK( + offsets.scalar_type() == at::kInt || offsets.scalar_type() == at::kLong, + "Expect 32 or 64 bit offsets, but found ", + offsets.scalar_type(), + " instead."); + TORCH_CHECK( + weight.is_contiguous() && indices.is_contiguous() && + offsets.is_contiguous(), + "Expect weight, indices, and offsets to be contiguous."); + + // Using helper function to support different type combination without the + // need to cast, which can be additional performance overhead + if (indices.scalar_type() == at::kInt && offsets.scalar_type() == at::kInt) { + return embedding_bag_byte_impl( + output, + weight, + indices, + offsets, + pruned_weights, + per_sample_weights_, + compressed_indices_mapping, + include_last_offset, + is_embedding_op); + } else if ( + indices.scalar_type() == at::kInt && offsets.scalar_type() == at::kLong) { + return embedding_bag_byte_impl( + output, + weight, + indices, + offsets, + pruned_weights, + per_sample_weights_, + compressed_indices_mapping, + include_last_offset, + is_embedding_op); + } else if ( + indices.scalar_type() == at::kLong && offsets.scalar_type() == at::kInt) { + return embedding_bag_byte_impl( + output, + weight, + indices, + offsets, + pruned_weights, + per_sample_weights_, + compressed_indices_mapping, + include_last_offset, + is_embedding_op); + } + + // default case given the TORCH_CHECK above + return embedding_bag_byte_impl( + output, + weight, + indices, + offsets, + pruned_weights, + per_sample_weights_, + compressed_indices_mapping, + include_last_offset, + is_embedding_op); +} + +at::Tensor embedding_bag_4bit_helper( + const at::Tensor& weight, + const at::Tensor& indices, + const c10::optional& offsets_in, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset) { + at::Tensor offsets; + TORCH_CHECK( + indices.dim() == 1 || indices.dim() == 2, + "qembedding/qembedding_bag operator supports 1 or 2d indices, got ", + indices.dim()); + + // For embedding_bag operator with 2D indices, we need to set the offsets + // explicitly here. + if (indices.dim() == 2) { + TORCH_CHECK( + !offsets_in.has_value(), + "embedding_bag_4bit operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences."); + + offsets = at::arange( + 0, indices.numel(), indices.sizes()[1], indices.scalar_type()); + } else { + TORCH_CHECK( + offsets_in.has_value(), + "embedding_bag_4bit operator expects offsets to be set for 1D indices."); + offsets = offsets_in.value(); + } + + TORCH_CHECK( + indices.scalar_type() == at::kInt || indices.scalar_type() == at::kLong, + "Expect 32 or 64 bit indices, but found ", + indices.scalar_type(), + " instead."); + TORCH_CHECK( + offsets.scalar_type() == at::kInt || offsets.scalar_type() == at::kLong, + "Expect 32 or 64 bit offsets, but found ", + offsets.scalar_type(), + " instead."); + TORCH_CHECK( + weight.is_contiguous() && indices.is_contiguous() && + offsets.is_contiguous(), + "Expect weight, indices, and offsets to be contiguous."); + + // Using helper function to support different type combination without the + // need to cast, which can be additional performance overhead + if (indices.scalar_type() == at::kInt && offsets.scalar_type() == at::kInt) { + return embedding_bag_4bit_impl( + weight, + indices, + offsets, + pruned_weights, + per_sample_weights_, + compressed_indices_mapping, + include_last_offset); + } else if ( + indices.scalar_type() == at::kInt && offsets.scalar_type() == at::kLong) { + return embedding_bag_4bit_impl( + weight, + indices, + offsets, + pruned_weights, + per_sample_weights_, + compressed_indices_mapping, + include_last_offset); + } else if ( + indices.scalar_type() == at::kLong && offsets.scalar_type() == at::kInt) { + return embedding_bag_4bit_impl( + weight, + indices, + offsets, + pruned_weights, + per_sample_weights_, + compressed_indices_mapping, + include_last_offset); + } + return embedding_bag_4bit_impl( + weight, + indices, + offsets, + pruned_weights, + per_sample_weights_, + compressed_indices_mapping, + include_last_offset); +} +} // namespace + +at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte( + const at::Tensor& indices, + const c10::optional& offsets_in, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset, + bool is_embedding_op) { + auto output = at::empty({0}, packed_w.options().dtype(at::kFloat)); + return embedding_bag_byte_helper( + output, + packed_w, + indices, + offsets_in, + pruned_weights, + per_sample_weights_, + compressed_indices_mapping, + include_last_offset, + is_embedding_op); +} + +at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit( + const at::Tensor& indices, + const c10::optional& offsets_in, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset) { + if (per_sample_weights_.has_value()) { + TORCH_CHECK( + (per_sample_weights_.value().scalar_type() == at::kFloat || + per_sample_weights_.value().scalar_type() == at::kHalf), + "Expect fp32 or fp16 weights, but found", + per_sample_weights_.value().scalar_type(), + " instead") + } + + return embedding_bag_4bit_helper( + packed_w, + indices, + offsets_in, + pruned_weights, + per_sample_weights_.has_value() + ? per_sample_weights_.value().to(at::kFloat) + : per_sample_weights_, + compressed_indices_mapping, + include_last_offset); +} + +namespace at { +namespace native { + +Tensor& embedding_bag_byte_rowwise_offsets_out( + Tensor& output, + const Tensor& weight, + const Tensor& indices, + const c10::optional& offsets_in, + const bool /* scale_grad_by_freq */, + const int64_t /* mode */, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset) { + return embedding_bag_byte_helper( + output, + weight, + indices, + offsets_in, + pruned_weights, + per_sample_weights_, + compressed_indices_mapping, + include_last_offset, + false /* is_embedding_op */); +} + +namespace { + +Tensor embedding_bag_byte_rowwise_offsets( + const Tensor& weight, + const Tensor& indices, + const c10::optional& offsets_in, + const bool /* scale_grad_by_freq */, + const int64_t /* mode */, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset) { + auto output = at::empty({0}, weight.options().dtype(at::kFloat)); + embedding_bag_byte_rowwise_offsets_out( + output, + weight, + indices, + offsets_in, + false /*unused scale_grad_by_freq*/, + 0 /*unused mode*/, + pruned_weights, + per_sample_weights_, + compressed_indices_mapping, + include_last_offset); + return output; +} + +Tensor embedding_bag_4bit_rowwise_offsets( + const Tensor& weight, + const Tensor& indices, + const c10::optional& offsets_in, + const bool /* scale_grad_by_freq */, + const int64_t /* mode */, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset) { + if (per_sample_weights_.has_value()) { + TORCH_CHECK( + (per_sample_weights_.value().scalar_type() == at::kFloat || + per_sample_weights_.value().scalar_type() == at::kHalf), + "Expect fp32 or fp16 weights, but found", + per_sample_weights_.value().scalar_type(), + " instead") + } + + return embedding_bag_4bit_helper( + weight, + indices, + offsets_in, + pruned_weights, + per_sample_weights_.has_value() + ? per_sample_weights_.value().to(at::kFloat) + : per_sample_weights_, + compressed_indices_mapping, + include_last_offset); +} + template class QEmbeddingBag final { public: @@ -414,13 +662,27 @@ class QEmbeddingBag final { const c10::optional& offsets, const bool /* scale_grad_by_freq */, const int64_t /* mode */, - bool sparse, + bool pruned_weights, const c10::optional& per_sample_weights_, const c10::optional& compressed_indices_mapping, bool include_last_offset) { if (bit_rate == 8) { return packed_weight->embeddingbag_byte( - indices, offsets, sparse, per_sample_weights_, include_last_offset); + indices, + offsets, + pruned_weights, + per_sample_weights_, + compressed_indices_mapping, + include_last_offset, + false /* is_embedding_op */); + } else if (bit_rate == 4) { + return packed_weight->embeddingbag_4bit( + indices, + offsets, + pruned_weights, + per_sample_weights_, + compressed_indices_mapping, + include_last_offset); } else { TORCH_INTERNAL_ASSERT( "Currently only support 8-bit embedding_bag quantization"); @@ -434,13 +696,21 @@ class QEmbedding final { static at::Tensor run( const c10::intrusive_ptr& packed_weight, const Tensor& indices, - bool sparse) { + bool pruned_weights) { + // Set default offsets here since the FBGEMM lookup op expects it. const auto offsets_size = indices.numel(); - at::Tensor offsets = at::arange(0, offsets_size, at::kLong); + at::Tensor offsets = at::arange(0, offsets_size, indices.scalar_type()); at::Tensor output; if (bit_rate == 8) { return packed_weight->embeddingbag_byte( - indices, offsets, sparse, c10::nullopt, false); + indices, + offsets, + pruned_weights, + c10::nullopt, + c10::nullopt, + false /* include_last_offset */, + true /* is_embedding_op */); + } else { TORCH_INTERNAL_ASSERT( "Currently only support 8-bit embedding quantization"); @@ -451,14 +721,23 @@ class QEmbedding final { TORCH_LIBRARY_IMPL(quantized, CPU, m) { // Function that works on TorchBind packed weights. - m.impl("embedding_bag_byte", TORCH_FN(QEmbeddingBag<8>::run)); - m.impl("embedding_byte", TORCH_FN(QEmbedding<8>::run)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte"), + TORCH_FN(QEmbeddingBag<8>::run)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit"), + TORCH_FN(QEmbeddingBag<4>::run)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_byte"), + TORCH_FN(QEmbedding<8>::run)); // Functions that work on at::Tensor packed weight. m.impl( - "embedding_bag_byte_rowwise_offsets", embedding_bag_byte_rowwise_offsets); + TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_rowwise_offsets"), + embedding_bag_byte_rowwise_offsets); m.impl( - "embedding_bag_4bit_rowwise_offsets", embedding_bag_4bit_rowwise_offsets); + TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_rowwise_offsets"), + embedding_bag_4bit_rowwise_offsets); } } // namespace } // namespace native diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag.h b/aten/src/ATen/native/quantized/cpu/qembeddingbag.h new file mode 100644 index 0000000000000..b8e523b6216e2 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag.h @@ -0,0 +1,17 @@ +#include + +namespace at { +namespace native { +Tensor& embedding_bag_byte_rowwise_offsets_out( + Tensor& output, + const Tensor& weight, + const Tensor& indices, + const c10::optional& offsets_in, + const bool /* scale_grad_by_freq */, + const int64_t /* mode */, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset); +} // native +} // at diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp index 96d592594d04f..0305f0b380ba1 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp @@ -14,7 +14,6 @@ torch::class_ register_embedding_params(); * To prepack the weights we store the scale and bias (where bias is Xmin) * for each row along with the quantized weights. */ -// TODO: Extend this to support 4-bits once 4-bit qtensor support is added. c10::intrusive_ptr PackedEmbeddingBagWeight::prepack( at::Tensor qweight) { static constexpr int64_t version = 1; @@ -22,13 +21,24 @@ c10::intrusive_ptr PackedEmbeddingBagWeight::prepack( qweight.dim() == 2, "quantized::embedding_bag_prepack weight tensor rank should be 2"); TORCH_CHECK( - qweight.scalar_type() == c10::kQUInt8, - "qembedding_bag_prepack currently only supports quint8 weights"); + qweight.scalar_type() == c10::kQUInt8 || + qweight.scalar_type() == c10::kQUInt4x2, + "qembedding_bag_prepack currently only supports quint8 and quint4x2 weights"); at::Tensor weight_contig = qweight.contiguous(qweight.suggest_memory_format()); - const uint8_t* weight_data = - reinterpret_cast(weight_contig.data_ptr()); + + int bit_width, scale_bias_bytes; + uint8_t* weight_data = static_cast(weight_contig.data_ptr()); + if (qweight.scalar_type() == c10::kQUInt8) { + bit_width = 8; + scale_bias_bytes = 8; // extra 8 bytes to store FP scale and bias per row. + } else { + bit_width = 4; + scale_bias_bytes = + 4; // extra 4 bytes to store at::Half scale and bias per row. + } + const auto num_elem_per_byte = 8 / bit_width; int64_t embedding_rows = qweight.size(0); int64_t embedding_cols = qweight.size(1); @@ -50,8 +60,9 @@ c10::intrusive_ptr PackedEmbeddingBagWeight::prepack( std::vector output_shape = { embedding_rows, - embedding_cols + - 8}; // extra 8 bytes to store FP scale and zero_point per row. + static_cast( + (embedding_cols + num_elem_per_byte - 1) / num_elem_per_byte + + scale_bias_bytes)}; // extra bytes to store scale and bias per row. size_t output_columns = output_shape[1]; // Allocate output packed weights. @@ -61,28 +72,46 @@ c10::intrusive_ptr PackedEmbeddingBagWeight::prepack( weight_contig.suggest_memory_format()); auto* output_data = output.data_ptr(); - at::parallel_for( - 0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) { - for (int64_t row = start_idx; row < end_idx; ++row) { - const uint8_t* input_row = weight_data + row * embedding_cols; - std::uint8_t* output_row = output_data + row * output_columns; - float* output_row_scale_bias = - reinterpret_cast(output_row + embedding_cols); - output_row_scale_bias[0] = weight_scales[row]; - output_row_scale_bias[1] = weight_bias[row]; - for (int64_t col = 0; col < embedding_cols; ++col) { - output_row[col] = input_row[col]; + if (bit_width == 8) { + at::parallel_for( + 0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) { + for (int64_t row = start_idx; row < end_idx; ++row) { + const uint8_t* input_row = weight_data + row * embedding_cols; + std::uint8_t* output_row = output_data + row * output_columns; + float* output_row_scale_bias = + reinterpret_cast(output_row + embedding_cols); + output_row_scale_bias[0] = weight_scales[row]; + output_row_scale_bias[1] = weight_bias[row]; + for (int64_t col = 0; col < embedding_cols; ++col) { + output_row[col] = input_row[col]; + } } - } - }); + }); + } else { + // Re-calculate the number of embedding_cols, to account for values packed + // in a byte. + embedding_cols = + (embedding_cols + num_elem_per_byte - 1) / num_elem_per_byte; + at::parallel_for( + 0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) { + for (int64_t row = start_idx; row < end_idx; ++row) { + const uint8_t* input_row = weight_data + row * embedding_cols; + std::uint8_t* output_row = output_data + row * output_columns; + at::Half* output_row_scale_bias = + reinterpret_cast(output_row + embedding_cols); + output_row_scale_bias[0] = weight_scales[row]; + output_row_scale_bias[1] = weight_bias[row]; + for (int64_t col = 0; col < embedding_cols; ++col) { + // The weight values have already been packed, so here we just + // store it in the output tensor. + output_row[col] = input_row[col]; + } + } + }); + } auto packed_ptr = c10::make_intrusive( - output, - weight_scales, - weight_zero_points, - 8 /* bit rate */, - qtype, - version); + output, weight_scales, weight_zero_points, bit_width, qtype, version); return packed_ptr; } @@ -113,8 +142,14 @@ Tensor qembeddingbag_byte_prepack(const Tensor& weight) { auto* output_data = output.data_ptr(); #ifdef USE_FBGEMM - fbgemm::FloatToFused8BitRowwiseQuantizedSBFloat( - weight_data, embedding_rows, embedding_cols, output_data); + at::parallel_for( + 0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) { + for (int64_t row = start_idx; row < end_idx; ++row) { + fbgemm::FloatToFused8BitRowwiseQuantizedSBFloat( + weight_data + row * embedding_cols, 1, + embedding_cols, output_data + row * output_shape[1]); + } + }); #else size_t output_columns = output_shape[1]; constexpr float kEpsilon = 1e-8f; @@ -184,8 +219,14 @@ Tensor _qembeddingbag_nbit_prepack_helper( #ifdef USE_FBGEMM if (!optimized_qparams) { - fbgemm::FloatToFusedNBitRowwiseQuantizedSBHalf( - bit_width, weight_data, embedding_rows, embedding_cols, output_data); + at::parallel_for( + 0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) { + for (int64_t row = start_idx; row < end_idx; ++row) { + fbgemm::FloatToFusedNBitRowwiseQuantizedSBHalf( + bit_width, weight_data + row * embedding_cols, 1, + embedding_cols, output_data + row * output_shape[1]); + } + }); } else { #endif // USE_FBGEMM const auto output_columns = output.size(output.dim() - 1); @@ -196,8 +237,14 @@ Tensor _qembeddingbag_nbit_prepack_helper( float Xmin, Xmax; if (optimized_qparams) { - std::tie(Xmax, Xmin) = at::choose_qparams_optimized( + at::Tensor xmax_tensor, xmin_tensor; + std::tie(xmax_tensor, xmin_tensor) = at::choose_qparams_optimized( weight_contig[row], embedding_cols, 200, 0.16, bit_width); + TORCH_CHECK( + xmax_tensor.numel() == 1 && xmin_tensor.numel() == 1, + "Expected choose_qparams_optimized to return min/max tensors of size 1"); + Xmax = xmax_tensor.item(); + Xmin = xmin_tensor.item(); } else { Xmin = *std::min_element(input_row, input_row + embedding_cols); Xmax = *std::max_element(input_row, input_row + embedding_cols); @@ -254,7 +301,9 @@ Tensor _qembeddingbag_nbit_prepack_helper( // To later de-quantize values, the scale (range / 15) and zero_point // are stored alongside the data. More precisely, each row first has quantized // values, and then 2-byte fp16 scale and 2-byte zero_offset. -Tensor qembeddingbag_4bit_prepack(const Tensor& weight, bool optimized_qparams) { +Tensor qembeddingbag_4bit_prepack( + const Tensor& weight, + bool optimized_qparams) { return _qembeddingbag_nbit_prepack_helper( weight, 4 /*bit_width*/, optimized_qparams); } @@ -267,7 +316,9 @@ Tensor qembeddingbag_4bit_prepack(const Tensor& weight, bool optimized_qparams) // are stored alongside the data. More precisely, each row first has quantized // values, and then 2-byte fp16 scale and 2-byte zero_offset. // TODO() - Add 2Bit Embedding Lookup operator. -Tensor qembeddingbag_2bit_prepack(const Tensor& weight, bool optimized_qparams) { +Tensor qembeddingbag_2bit_prepack( + const Tensor& weight, + bool optimized_qparams) { return _qembeddingbag_nbit_prepack_helper( weight, 2 /*bit_width*/, optimized_qparams); } @@ -280,13 +331,21 @@ class QEmbeddingPackWeights final { }; TORCH_LIBRARY_IMPL(quantized, CPU, m) { - m.impl("embedding_bag_byte_prepack", TORCH_FN(qembeddingbag_byte_prepack)); - m.impl("embedding_bag_4bit_prepack", TORCH_FN(qembeddingbag_4bit_prepack)); - m.impl("embedding_bag_2bit_prepack", TORCH_FN(qembeddingbag_2bit_prepack)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack"), + TORCH_FN(qembeddingbag_byte_prepack)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack"), + TORCH_FN(qembeddingbag_4bit_prepack)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack"), + TORCH_FN(qembeddingbag_2bit_prepack)); } TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { - m.impl("embedding_bag_prepack", TORCH_FN(QEmbeddingPackWeights::run)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_prepack"), + TORCH_FN(QEmbeddingPackWeights::run)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp index 4a9ae73ee1373..542d166151b20 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp @@ -9,38 +9,69 @@ torch::class_ register_embedding_params(); at::Tensor PackedEmbeddingBagWeight::unpack() { auto packed_weight = packed_w; at::Tensor weight_origin; - if (bit_rate_ == 8) { + + if (bit_rate_ == 8 || bit_rate_ == 4) { const auto input_rows = packed_weight.size(0); const auto input_columns = packed_weight.size(1); - - // The last 2 values are used to store the FP32 scale and zero_point values - // per row. - int output_columns = input_columns - 2 * sizeof(float); + int scale_bias_bytes; + const auto num_elem_per_byte = 8 / bit_rate_; + if (bit_rate_ == 8) { + // The last 2 values are used to store the FP32 scale and zero_point + // values per row. + scale_bias_bytes = 8; + } else { + scale_bias_bytes = 4; + } const auto* input = packed_weight.data_ptr(); - std::vector output_shape = {input_rows, output_columns}; + // Calculate the output shape, accounting for the last n bytes to be used + // for scale/bias rest of the entries are packed depending on the bit_width. + std::vector output_shape = { + input_rows, + static_cast(input_columns - scale_bias_bytes) * + num_elem_per_byte}; auto scales = at::from_blob( w_scale.data(), w_scale.size(), device(c10::kCPU).dtype(c10::kFloat)); auto zero_points = at::from_blob( w_zp.data(), w_zp.size(), device(c10::kCPU).dtype(c10::kFloat)); - weight_origin = at::_empty_per_channel_affine_quantized( - output_shape, - scales.toType(c10::kFloat), - zero_points.toType(c10::kFloat), - 0, // The output channel axis is 0 - device(c10::kCPU).dtype(c10::kQUInt8)); - - uint8_t* output_data = - reinterpret_cast(weight_origin.data_ptr()); - + auto output_columns = output_shape[1]; + uint8_t* output_data; + + // Allocate output weight tensor based on the bit_width + if (bit_rate_ == 8) { + weight_origin = at::_empty_per_channel_affine_quantized( + output_shape, + scales.toType(c10::kFloat), + zero_points.toType(c10::kFloat), + 0, // The output channel axis is 0 + device(c10::kCPU).dtype(c10::kQUInt8)); + output_data = static_cast(weight_origin.data_ptr()); + } else { + // We create empty qtensor with the full output shape, and dtype set to + // quint4x2 This will internally allocate appropriate storage bytes to + // account for the packed nature of this dtype. + weight_origin = at::_empty_per_channel_affine_quantized( + output_shape, + scales.toType(c10::kFloat), + zero_points.toType(c10::kFloat), + 0, // The output channel axis is 0 + device(c10::kCPU).dtype(c10::kQUInt4x2)); + output_data = static_cast(weight_origin.data_ptr()); + } + + // Copy over the data from the packed weight to the output. + // For sub-byte tensors this will copy the packed bytes over since the + // sub_byte qtensors are expected to store data in packed format. at::parallel_for(0, input_rows, 1, [&](int32_t start_idx, int32_t end_idx) { for (int64_t row = start_idx; row < end_idx; ++row) { const std::uint8_t* input_row = input + row * input_columns; - uint8_t* output_row = output_data + row * output_columns; + uint8_t* output_row = + output_data + row * output_columns / num_elem_per_byte; - for (std::size_t col = 0; col < output_columns; ++col) { + for (std::size_t col = 0; col < output_columns / num_elem_per_byte; + ++col) { output_row[col] = input_row[col]; } // output_columns } @@ -49,7 +80,8 @@ at::Tensor PackedEmbeddingBagWeight::unpack() { return weight_origin; } TORCH_INTERNAL_ASSERT( - "Currently only supporting 8-bit quantization of embedding bag."); + false, + "We currently only support 8-bit and 4-bit quantization of embedding_bag."); return weight_origin; } @@ -74,8 +106,16 @@ Tensor qembeddingbag_byte_unpack(const Tensor& packed_weight) { float* output_data = output.data_ptr(); #ifdef USE_FBGEMM - fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloat( - input, input_rows, input_columns, output_data); + at::parallel_for( + 0, input_rows, 1, [&](int32_t start_idx, int32_t end_idx) { + for (int64_t row = start_idx; row < end_idx; ++row) { + fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloat( + input + row * input_columns, + 1, + input_columns, + output_data + row * output_columns); + } + }); #else for (std::size_t row = 0; row < input_rows; ++row) { const std::uint8_t* input_row = input + row * input_columns; @@ -113,8 +153,16 @@ Tensor _qembeddingbag_nbit_unpack_helper( packed_weight.suggest_memory_format()); float* output_data = output.data_ptr(); #ifdef USE_FBGEMM - fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloat( - BIT_RATE, input_data, input_rows, input_columns, output_data); + at::parallel_for( + 0, input_rows, 1, [&](int32_t start_idx, int32_t end_idx) { + for (int64_t row = start_idx; row < end_idx; ++row) { + fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloat(BIT_RATE, + input_data + row * input_columns, + 1, + input_columns, + output_data + row * output_dimensions[1]); + } + }); #else auto output_columns = output_dimensions[1]; for (size_t row = 0; row < input_rows; ++row) { @@ -171,15 +219,23 @@ class QEmbeddingUnpackWeights final { }; TORCH_LIBRARY_IMPL(quantized, CPU, m) { - m.impl("embedding_bag_byte_unpack", qembeddingbag_byte_unpack); - m.impl("embedding_bag_4bit_unpack", qembeddingbag_4bit_unpack); - m.impl("embedding_bag_2bit_unpack", qembeddingbag_2bit_unpack); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_unpack"), + qembeddingbag_byte_unpack); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_unpack"), + qembeddingbag_4bit_unpack); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_unpack"), + qembeddingbag_2bit_unpack); } TORCH_LIBRARY_IMPL(quantized, CatchAll, m) { // Unpack the packed embedding_bag weights using TorchBind custom class. // TODO extend to support 4-bit qtensor. - m.impl("embedding_bag_unpack", TORCH_FN(QEmbeddingUnpackWeights::run)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_unpack"), + TORCH_FN(QEmbeddingUnpackWeights::run)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp b/aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp index 2ef3106bd20bb..c60f49f5de8c3 100644 --- a/aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp +++ b/aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp @@ -43,9 +43,10 @@ Tensor qnnpack_hardsigmoid(Tensor input) { "failed to create QNNPACK Hardsigmoid operator"); Tensor qy = at::_empty_affine_quantized( input_contig.sizes(), - input_contig.options(), + at::device(kCPU).dtype(input_contig.dtype()), o_scale, - o_zero_point); + o_zero_point, + input_contig.suggest_memory_format()); const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_hardsigmoid_nc_q8( hardsigmoid_op, diff --git a/aten/src/ATen/native/quantized/cpu/qhardswish.cpp b/aten/src/ATen/native/quantized/cpu/qhardswish.cpp index f0dbd644b2bee..064b88a8c91fd 100644 --- a/aten/src/ATen/native/quantized/cpu/qhardswish.cpp +++ b/aten/src/ATen/native/quantized/cpu/qhardswish.cpp @@ -85,7 +85,7 @@ Tensor quantized_hardswish(const Tensor& qx, double output_scale, int64_t output } TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { - m.impl("hardswish", TORCH_FN(quantized_hardswish)); + m.impl(TORCH_SELECTIVE_NAME("quantized::hardswish"), TORCH_FN(quantized_hardswish)); } }} // namespace at::native diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp index fdc6d1dd4d8b5..a7b4f4b743578 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp @@ -397,12 +397,12 @@ class QLinearInt8 final { }; TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { - m.impl("linear", TORCH_FN(QLinearInt8::run)); - m.impl("linear_relu", TORCH_FN(QLinearInt8::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::linear"), TORCH_FN(QLinearInt8::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::linear_relu"), TORCH_FN(QLinearInt8::run)); } TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) { - m.impl("linear", TORCH_FN(QLinearInt8::run)); + m.impl(TORCH_SELECTIVE_NAME("_quantized::linear"), TORCH_FN(QLinearInt8::run)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp index 2accf060deab5..af2d7749ee50d 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp @@ -455,13 +455,13 @@ class QLinearDynamicFp16 final { }; TORCH_LIBRARY_IMPL(quantized, CPU, m) { - m.impl("linear_dynamic", TORCH_FN(QLinearDynamicInt8::run)); - m.impl("linear_relu_dynamic", TORCH_FN(QLinearDynamicInt8::run)); - m.impl("linear_dynamic_fp16", TORCH_FN(QLinearDynamicFp16::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::linear_dynamic"), TORCH_FN(QLinearDynamicInt8::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::linear_relu_dynamic"), TORCH_FN(QLinearDynamicInt8::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::linear_dynamic_fp16"), TORCH_FN(QLinearDynamicFp16::run)); } TORCH_LIBRARY_IMPL(_quantized, CPU, m) { - m.impl("linear_dynamic", TORCH_FN(QLinearDynamicInt8::run)); + m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_dynamic"), TORCH_FN(QLinearDynamicInt8::run)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp index ee4b6ee2aaf61..23912f87d1235 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp @@ -318,22 +318,22 @@ class QLinearPackWeightFp16Legacy final { }; TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { - m.impl("linear_prepack", TORCH_FN(QLinearPackWeightInt8::run)); - m.impl("linear_prepack_legacy", TORCH_FN(QLinearPackWeightInt8Legacy::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack"), TORCH_FN(QLinearPackWeightInt8::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack_legacy"), TORCH_FN(QLinearPackWeightInt8Legacy::run)); } TORCH_LIBRARY_IMPL(quantized, CPU, m) { - m.impl("linear_prepack_fp16", TORCH_FN(QLinearPackWeightFp16::run)); - m.impl("linear_prepack_fp16_legacy", TORCH_FN(QLinearPackWeightFp16Legacy::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack_fp16"), TORCH_FN(QLinearPackWeightFp16::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack_fp16_legacy"), TORCH_FN(QLinearPackWeightFp16Legacy::run)); } TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) { - m.impl("linear_prepack", TORCH_FN(QLinearPackWeightInt8::run)); + m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack"), TORCH_FN(QLinearPackWeightInt8::run)); } TORCH_LIBRARY_IMPL(_quantized, CPU, m) { - m.impl("linear_prepack_fp16", TORCH_FN(QLinearPackWeightFp16::run)); - m.impl("linear_prepack_fp16_legacy", TORCH_FN(QLinearPackWeightFp16Legacy::run)); + m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16"), TORCH_FN(QLinearPackWeightFp16::run)); + m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16_legacy"), TORCH_FN(QLinearPackWeightFp16Legacy::run)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp index 1bc8711a22f4d..ecbae04dd9572 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp @@ -137,13 +137,13 @@ class QLinearUnpackWeightFp16Legacy final { }; TORCH_LIBRARY_IMPL(quantized, CPU, m) { - m.impl("linear_unpack.legacy", TORCH_FN(QLinearUnpackWeightInt8Legacy::run)); - m.impl("linear_unpack_fp16.legacy", TORCH_FN(QLinearUnpackWeightFp16Legacy::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::linear_unpack.legacy"), TORCH_FN(QLinearUnpackWeightInt8Legacy::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::linear_unpack_fp16.legacy"), TORCH_FN(QLinearUnpackWeightFp16Legacy::run)); } TORCH_LIBRARY_IMPL(quantized, CatchAll, m) { - m.impl("linear_unpack", TORCH_FN(QLinearUnpackWeightInt8::run)); - m.impl("linear_unpack_fp16", TORCH_FN(QLinearUnpackWeightFp16::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::linear_unpack"), TORCH_FN(QLinearUnpackWeightInt8::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::linear_unpack_fp16"), TORCH_FN(QLinearUnpackWeightFp16::run)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qmul.cpp b/aten/src/ATen/native/quantized/cpu/qmul.cpp index 13aa8acc669a8..a54b29e44a8cf 100644 --- a/aten/src/ATen/native/quantized/cpu/qmul.cpp +++ b/aten/src/ATen/native/quantized/cpu/qmul.cpp @@ -59,7 +59,7 @@ Tensor _mul_scalar_out(Tensor& out, const Tensor& self, Scalar other) { } else { out.copy_(self); } - out.set_quantizer_(make_per_tensor_affine_quantizer( + set_quantizer_(out, make_per_tensor_affine_quantizer( scale_prime, zero_point_prime, self.scalar_type())); } else if (other_val == 0.0) { scale_prime = 1.0; @@ -74,7 +74,7 @@ Tensor _mul_scalar_out(Tensor& out, const Tensor& self, Scalar other) { [&](Vec256 vec) -> Vec256 { return Vec256(scalar_t(0)); }); - out.set_quantizer_(make_per_tensor_affine_quantizer( + set_quantizer_(out, make_per_tensor_affine_quantizer( scale_prime, zero_point_prime, self.scalar_type())); } else /* other_val < 0.0 */ { scale_prime = std::abs(other_val) * self_scale; @@ -91,7 +91,7 @@ Tensor _mul_scalar_out(Tensor& out, const Tensor& self, Scalar other) { } return a; }); - out.set_quantizer_(make_per_tensor_affine_quantizer( + set_quantizer_(out, make_per_tensor_affine_quantizer( scale_prime, zero_point_prime, self.scalar_type())); } }); @@ -136,6 +136,18 @@ class QMulScalar final { } }; +template +class QMulScalar2 final { + public: + static Tensor run(Scalar b, Tensor qa) { + TORCH_CHECK(qa.qscheme() == kPerTensorAffine || + qa.qscheme() == kPerTensorSymmetric, + "Only per tensor quantization is supported in Mul."); + auto qc = at::empty_like(qa, qa.suggest_memory_format()); + return _mul_scalar_out(qc, qa, b); + } +}; + template class QMulScalarOut final { public: @@ -173,26 +185,28 @@ class QMulScalarTensorOut final { }; TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { - m.impl("mul", TORCH_FN(QMul::run)); - m.impl("mul.out", TORCH_FN(QMulOut::run)); - m.impl("mul.Scalar", TORCH_FN(QMulScalar::run)); - m.impl("mul.Scalar_out", TORCH_FN(QMulScalarOut::run)); - m.impl("mul_relu", TORCH_FN(QMul::run)); - m.impl("mul_relu.out", TORCH_FN(QMulOut::run)); - m.impl("mul_relu.Scalar", TORCH_FN(QMulScalar::run)); - m.impl("mul_relu.Scalar_out", TORCH_FN(QMulScalarOut::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul"), TORCH_FN(QMul::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul.out"), TORCH_FN(QMulOut::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul.Scalar"), TORCH_FN(QMulScalar::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul.Scalar2"), TORCH_FN(QMulScalar2::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul.Scalar_out"), TORCH_FN(QMulScalarOut::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu"), TORCH_FN(QMul::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.out"), TORCH_FN(QMulOut::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.Scalar"), TORCH_FN(QMulScalar::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.Scalar2"), TORCH_FN(QMulScalar2::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.Scalar_out"), TORCH_FN(QMulScalarOut::run)); // deprecated functions, kept for backward compatibility - m.impl("mul_out", TORCH_FN(QMulOut::run)); - m.impl("mul_relu_out", TORCH_FN(QMulOut::run)); - m.impl("mul_scalar", TORCH_FN(QMulScalar::run)); - m.impl("mul_scalar_relu", TORCH_FN(QMulScalar::run)); - m.impl("mul_scalar_out", TORCH_FN(QMulScalarOut::run)); - m.impl("mul_scalar_relu_out", TORCH_FN(QMulScalarOut::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul_out"), TORCH_FN(QMulOut::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu_out"), TORCH_FN(QMulOut::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul_scalar"), TORCH_FN(QMulScalar::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul_scalar_relu"), TORCH_FN(QMulScalar::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul_scalar_out"), TORCH_FN(QMulScalarOut::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul_scalar_relu_out"), TORCH_FN(QMulScalarOut::run)); // TODO: remove after broadcasting is supported - m.impl("mul_scalar.Tensor", TORCH_FN(QMulScalarTensor::run)); - m.impl("mul_scalar_relu.Tensor", TORCH_FN(QMulScalarTensor::run)); - m.impl("mul_scalar_out.Tensor", TORCH_FN(QMulScalarTensorOut::run)); - m.impl("mul_scalar_relu_out.Tensor", TORCH_FN(QMulScalarTensorOut::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul_scalar.Tensor"), TORCH_FN(QMulScalarTensor::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul_scalar_relu.Tensor"), TORCH_FN(QMulScalarTensor::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul_scalar_out.Tensor"), TORCH_FN(QMulScalarTensorOut::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul_scalar_relu_out.Tensor"), TORCH_FN(QMulScalarTensorOut::run)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt b/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt index 6f6133f931b36..01c815139de34 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt @@ -21,6 +21,12 @@ option(PYTORCH_QNNPACK_BUILD_BENCHMARKS "Build QNNPACK benchmarks" ON) # Enable runtime requantization. add_definitions(-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION=1) +# ---[ Target processor +SET(PYTORCH_QNNPACK_TARGET_PROCESSOR "${CMAKE_SYSTEM_PROCESSOR}") +IF(CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64)$") + SET(PYTORCH_QNNPACK_TARGET_PROCESSOR "${CMAKE_OSX_ARCHITECTURES}") +ENDIF() + # ---[ CMake options if(PYTORCH_QNNPACK_BUILD_TESTS) enable_testing() @@ -40,7 +46,7 @@ if(NOT CMAKE_SYSTEM_PROCESSOR) else() message(FATAL_ERROR "CMAKE_SYSTEM_PROCESSOR is not defined") endif() -elseif(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|x86_64|armv[5-8].*|aarch64)$") +elseif(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|x86_64|armv[5-8].*|aarch64|arm64)$") message(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_PROCESSOR = ${CMAKE_SYSTEM_PROCESSOR}") endif() @@ -244,11 +250,11 @@ if(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]" OR IOS_ARCH MATCHES "^armv7") list(APPEND PYTORCH_QNNPACK_UKERNELS ${PYTORCH_QNNPACK_ARM_NEON_UKERNELS}) list(APPEND PYTORCH_QNNPACK_UKERNELS ${PYTORCH_QNNPACK_AARCH32_ASM_UKERNELS}) endif() -if(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR IOS_ARCH MATCHES "^arm64.*") +if(PYTORCH_QNNPACK_TARGET_PROCESSOR MATCHES "^(aarch64|arm64)$" OR IOS_ARCH MATCHES "^arm64.*") list(APPEND PYTORCH_QNNPACK_UKERNELS ${PYTORCH_QNNPACK_ARM_NEON_UKERNELS}) list(APPEND PYTORCH_QNNPACK_UKERNELS ${PYTORCH_QNNPACK_AARCH64_ASM_UKERNELS}) endif() -if(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|x86_64)$" OR IOS_ARCH MATCHES "^(i386|x86_64)$") +if(PYTORCH_QNNPACK_TARGET_PROCESSOR MATCHES "^(i[3-6]86|x86_64)$" OR IOS_ARCH MATCHES "^(i386|x86_64)$") list(APPEND PYTORCH_QNNPACK_UKERNELS ${PYTORCH_QNNPACK_X86_SSE2_UKERNELS}) endif() @@ -271,13 +277,13 @@ if(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]" OR IOS_ARCH MATCHES "^armv7") set_property(SOURCE ${PYTORCH_QNNPACK_AARCH32_ASM_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -arch ${IOS_ARCH} ") endif() endif() -if(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR IOS_ARCH MATCHES "^arm64.*") +if(PYTORCH_QNNPACK_TARGET_PROCESSOR MATCHES "^(aarch64|arm64)$" OR IOS_ARCH MATCHES "^arm64.*") set_property(SOURCE ${PYTORCH_QNNPACK_ARM_NEON_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -O2 ") if(IOS) set_property(SOURCE ${PYTORCH_QNNPACK_AARCH64_ASM_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -arch ${IOS_ARCH} ") endif() endif() -if(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|x86_64)$" OR IOS_ARCH MATCHES "^(i386|x86_64)$") +if(PYTORCH_QNNPACK_TARGET_PROCESSOR MATCHES "^(i[3-6]86|x86_64)$" OR IOS_ARCH MATCHES "^(i386|x86_64)$") set_property(SOURCE ${PYTORCH_QNNPACK_X86_SSE2_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -O2 -msse2 ") endif() if(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]" OR IOS_ARCH MATCHES "^armv7") diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/mp8x9p8q-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/mp8x9p8q-neon.c index 3145442299dc9..fa1fdebdd4d4b 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/mp8x9p8q-neon.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/mp8x9p8q-neon.c @@ -30,16 +30,19 @@ void pytorch_q8avgpool_ukernel_mp8x9p8q__neon( const int32x4_t vbias = vld1q_dup_s32(&quantization_params->neon.bias); const float32x4_t vscale = vdupq_n_f32(quantization_params->neon.scale); +#if defined(__aarch64__) const int16x8_t voutput_zero_point = vld1q_dup_s16(&quantization_params->neon.output_zero_point); const uint8x8_t voutput_min = vld1_dup_u8(&quantization_params->neon.output_min); const uint8x8_t voutput_max = vld1_dup_u8(&quantization_params->neon.output_max); +#else const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin); const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax); const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic); const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic); +#endif do { { diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8x9-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8x9-neon.c index 453cf80fa08b7..dc7209cd5f328 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8x9-neon.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8x9-neon.c @@ -30,16 +30,19 @@ void pytorch_q8avgpool_ukernel_up8x9__neon( const int32x4_t vbias = vld1q_dup_s32(&quantization_params->neon.bias); const float32x4_t vscale = vdupq_n_f32(quantization_params->neon.scale); +#if defined(__aarch64__) const int16x8_t voutput_zero_point = vld1q_dup_s16(&quantization_params->neon.output_zero_point); const uint8x8_t voutput_min = vld1_dup_u8(&quantization_params->neon.output_min); const uint8x8_t voutput_max = vld1_dup_u8(&quantization_params->neon.output_max); +#else const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin); const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax); const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic); const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic); +#endif do { const uint8_t* i0 = input[0]; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/mp8x25-neon-per-channel.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/mp8x25-neon-per-channel.c index 940cd2847833b..3e2d11408eac3 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/mp8x25-neon-per-channel.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/mp8x25-neon-per-channel.c @@ -23,14 +23,17 @@ void pytorch_q8dwconv_ukernel_mp8x25_per_channel__neon( quantization_params[restrict static 1]) { const uint8x8_t vinput_zero_point = vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point); +#ifdef __aarch64__ const int16x8_t voutput_zero_point = vld1q_dup_s16(&quantization_params->neon.output_zero_point); const uint8x8_t voutput_min = vld1_dup_u8(&quantization_params->neon.output_min); const uint8x8_t voutput_max = vld1_dup_u8(&quantization_params->neon.output_max); +#else const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin); const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax); const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic); const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic); +#endif do { uint8_t* output_start = output; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/mp8x25-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/mp8x25-neon.c index e338f6d9673ab..25c7957714d68 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/mp8x25-neon.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/mp8x25-neon.c @@ -27,14 +27,17 @@ void pytorch_q8dwconv_ukernel_mp8x25__neon( vdup_n_u8(quantization_params->neon.kernel_zero_points[0]); const float32x4_t requantization_scale_v = vdupq_n_f32(quantization_params->neon.requantization_scales[0]); +#ifdef __aarch64__ const int16x8_t voutput_zero_point = vld1q_dup_s16(&quantization_params->neon.output_zero_point); const uint8x8_t voutput_min = vld1_dup_u8(&quantization_params->neon.output_min); const uint8x8_t voutput_max = vld1_dup_u8(&quantization_params->neon.output_max); +#else const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin); const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax); const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic); const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic); +#endif do { uint8_t* output_start = output; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-neon-per-channel.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-neon-per-channel.c index c8a102aaaa716..68ff1a3b41b1e 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-neon-per-channel.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-neon-per-channel.c @@ -23,16 +23,19 @@ void pytorch_q8dwconv_ukernel_up8x9_per_channel__neon( quantization_params[restrict static 1]) { const uint8x8_t va_zero_point = vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point); +#ifdef __aarch64__ const int16x8_t voutput_zero_point = vld1q_dup_s16(&quantization_params->neon.output_zero_point); const uint8x8_t voutput_min = vld1_dup_u8(&quantization_params->neon.output_min); const uint8x8_t voutput_max = vld1_dup_u8(&quantization_params->neon.output_max); +#else const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin); const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax); const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic); const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic); +#endif #ifdef __aarch64__ /* Larger number of registers on AArch64 make it possible to process few diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-neon.c index b6dd3b7a4455e..9f442938f7dc6 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-neon.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-neon.c @@ -27,16 +27,19 @@ void pytorch_q8dwconv_ukernel_up8x9__neon( vdup_n_u8(quantization_params->neon.kernel_zero_points[0]); const float32x4_t requantization_scale_v = vdupq_n_f32(quantization_params->neon.requantization_scales[0]); +#ifdef __aarch64__ const int16x8_t voutput_zero_point = vld1q_dup_s16(&quantization_params->neon.output_zero_point); const uint8x8_t voutput_min = vld1_dup_u8(&quantization_params->neon.output_min); const uint8x8_t voutput_max = vld1_dup_u8(&quantization_params->neon.output_max); +#else const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin); const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax); const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic); const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic); +#endif #ifdef __aarch64__ /* Larger number of registers on AArch64 make it possible to process few diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/mp8x7p7q-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/mp8x7p7q-neon.c index 88a59311bc89e..27040ef67280f 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/mp8x7p7q-neon.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/mp8x7p7q-neon.c @@ -119,16 +119,19 @@ void pytorch_q8gavgpool_ukernel_mp8x7p7q__neon( const float32x4_t vscale = vdupq_n_f32(quantization_params->neon.scale); +#if defined(__aarch64__) const int16x8_t voutput_zero_point = vld1q_dup_s16(&quantization_params->neon.output_zero_point); const uint8x8_t voutput_min = vld1_dup_u8(&quantization_params->neon.output_min); const uint8x8_t voutput_max = vld1_dup_u8(&quantization_params->neon.output_max); +#else const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin); const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax); const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic); const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic); +#endif i0 = (const uint8_t*)((uintptr_t)i0 + input_increment); i1 = (const uint8_t*)((uintptr_t)i1 + input_increment); diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8x7-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8x7-neon.c index 36359286bd065..3d69ef13a6045 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8x7-neon.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8x7-neon.c @@ -52,16 +52,19 @@ void pytorch_q8gavgpool_ukernel_up8x7__neon( } const int32x4_t vbias = vld1q_dup_s32(&quantization_params->neon.bias); const float32x4_t vscale = vdupq_n_f32(quantization_params->neon.scale); +#if defined(__aarch64__) const int16x8_t voutput_zero_point = vld1q_dup_s16(&quantization_params->neon.output_zero_point); const uint8x8_t voutput_min = vld1_dup_u8(&quantization_params->neon.output_min); const uint8x8_t voutput_max = vld1_dup_u8(&quantization_params->neon.output_max); +#else const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin); const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax); const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic); const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic); +#endif do { const uint8x8_t vi0 = vld1_u8(i0); diff --git a/aten/src/ATen/native/quantized/cpu/qnormalization.cpp b/aten/src/ATen/native/quantized/cpu/qnormalization.cpp index f5bef2b93a0ad..c8bbe9d29b24d 100644 --- a/aten/src/ATen/native/quantized/cpu/qnormalization.cpp +++ b/aten/src/ATen/native/quantized/cpu/qnormalization.cpp @@ -71,11 +71,8 @@ Tensor quantized_group_norm_impl( const int64_t batches = input_shape[0]; const int64_t num_channels = input_shape[1]; - const int64_t elements_per_batch = std::accumulate( - input_shape.cbegin() + 1, - input_shape.cend(), - 1LL, - std::multiplies()); + const int64_t elements_per_batch = + prod_intlist(input_shape.cbegin() + 1, input_shape.cend()); const int64_t M = batches * num_groups; const int64_t N = elements_per_batch / num_groups; @@ -120,7 +117,7 @@ Tensor quantized_instance_norm_impl( TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { // TODO: this is kind of... blegh - m.impl("layer_norm", []( + m.impl(TORCH_SELECTIVE_NAME("quantized::layer_norm"), []( Tensor input, std::vector normalized_shape, // because IntArrayRef doesn't work c10::optional weight, @@ -134,7 +131,7 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { bias.has_value() ? *bias : Tensor(), eps, output_scale, output_zero_point); }); - m.impl("group_norm", []( + m.impl(TORCH_SELECTIVE_NAME("quantized::group_norm"), []( Tensor qx, int64_t num_groups, c10::optional weight, @@ -148,7 +145,7 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { bias.has_value() ? *bias : Tensor(), eps, output_scale, output_zero_point); }); - m.impl("instance_norm", []( + m.impl(TORCH_SELECTIVE_NAME("quantized::instance_norm"), []( Tensor qx, c10::optional weight, c10::optional bias, diff --git a/aten/src/ATen/native/quantized/cpu/qpool.cpp b/aten/src/ATen/native/quantized/cpu/qpool.cpp index f986ab4934b92..7fa56619609bf 100644 --- a/aten/src/ATen/native/quantized/cpu/qpool.cpp +++ b/aten/src/ATen/native/quantized/cpu/qpool.cpp @@ -134,7 +134,12 @@ Tensor q_maxpool_2d( int64_t oC = iC; int64_t oH = pooling_output_shape(iH, kH, pH, sH, dH, ceil_mode); int64_t oW = pooling_output_shape(iW, kW, pW, sW, dW, ceil_mode); - TORCH_CHECK(oH > 0 && oW > 0, "the resulting Tensor is too small."); + TORCH_CHECK(oH > 0 && oW > 0, + "Given input size: (", + iC, "x", iH, "x", iW, + "). Calculated output size: (", + oC, "x", oH, "x", oW, + "). Output size is too small."); std::vector oSizes; if (ndim == 3) { @@ -232,7 +237,7 @@ void check_maxpool2d_params( } #ifdef USE_PYTORCH_QNNPACK - static Tensor qnnpack_maxpool( + static Tensor qnnpack_maxpool2d( Tensor input, IntArrayRef kernel_size, IntArrayRef stride, @@ -243,23 +248,23 @@ void check_maxpool2d_params( TORCH_CHECK( input.ndimension() == 4, - "qnnpack_maxpool(): Expected input to be 4-dimensional: got ", + "qnnpack_maxpool2d(): Expected input to be 4-dimensional: got ", input.ndimension()); TORCH_CHECK( kernel_size.size() == 2, - "qnnpack_maxpool(): Expected kernel_size to be 2-dimensional: got ", + "qnnpack_maxpool2d(): Expected kernel_size to be 2-dimensional: got ", kernel_size.size()); TORCH_CHECK( stride.size() == 2, - "qnnpack_maxpool(): Expected stride to be 2-dimensional: got ", + "qnnpack_maxpool2d(): Expected stride to be 2-dimensional: got ", stride.size()); TORCH_CHECK( dilation.size() == 2, - "qnnpack_maxpool(): Expected dilation to be 2-dimensional: got ", + "qnnpack_maxpool2d(): Expected dilation to be 2-dimensional: got ", dilation.size()); TORCH_CHECK( padding.size() == 2, - "qnnpack_maxpool(): Expected padding to be 2-dimensional: got ", + "qnnpack_maxpool2d(): Expected padding to be 2-dimensional: got ", padding.size()); int64_t batch_size = input.size(0); @@ -284,10 +289,10 @@ void check_maxpool2d_params( TORCH_CHECK( kH > 0 && kW > 0, - "qnnpack_maxpool(): kernel_size should be greater than zero."); + "qnnpack_maxpool2d(): kernel_size should be greater than zero."); TORCH_CHECK( strideH > 0 && strideW > 0, - "qnnpack_maxpool(): strides should be greater than zero."); + "qnnpack_maxpool2d(): strides should be greater than zero."); const pytorch_qnnp_status createStatus = pytorch_qnnp_create_max_pooling2d_nhwc_u8( @@ -318,7 +323,7 @@ void check_maxpool2d_params( TORCH_CHECK( outH > 0 && outW > 0, - "qnnpack_maxpool(): the resulting output Tensor size should be >= 0"); + "qnnpack_maxpool2d(): the resulting output Tensor size should be >= 0"); std::unique_ptr qnnpack_uniq_ptr(qnnpack_operator); @@ -375,7 +380,7 @@ Tensor quantized_max_pool2d( } #ifdef USE_PYTORCH_QNNPACK if (at::globalContext().qEngine() == at::QEngine::QNNPACK && qx.scalar_type() == kQUInt8 && !ceil_mode) { - return qnnpack_maxpool(qx, kernel_size, stride, padding, dilation, ceil_mode); + return qnnpack_maxpool2d(qx, kernel_size, stride, padding, dilation, ceil_mode); } #endif Tensor qy; @@ -395,9 +400,37 @@ Tensor quantized_max_pool2d( return qy; } +// Quantized max_pool1d is a special case of the max_pool2d, with one of the +// dimensions and kernels removed. +Tensor quantized_max_pool1d( + const Tensor& qx, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode) { + // (C, L) -> (C, 1, L) => kSqueezeDim = 1 + // (N, C, L) -> (N, C, 1, L) => kSqueezeDim = 2 + const int32_t kSqueezeDim = qx.dim() - 1; + const auto qx_unsqueeze = qx.unsqueeze(kSqueezeDim); + if (stride.empty()) { + stride = kernel_size; + } + auto qy = at::quantized_max_pool2d( + qx.unsqueeze(kSqueezeDim), + {1, kernel_size[0]}, + {1, stride[0]}, + {0, padding[0]}, + {1, dilation[0]}, + ceil_mode); + qy = qy.squeeze(kSqueezeDim); + return qy; +} + // Keep the registry in the anonymous namespace. namespace { -class QMaxPool2D_arr_args final { +template +class QMaxPool_arr_args final { public: static Tensor run( Tensor qx, @@ -406,17 +439,20 @@ class QMaxPool2D_arr_args final { std::vector padding, std::vector dilation, bool ceil_mode) { - #ifdef USE_PYTORCH_QNNPACK - if (at::globalContext().qEngine() == at::QEngine::QNNPACK && qx.scalar_type() == kQUInt8 && !ceil_mode) { - return qnnpack_maxpool(qx, kernel_size, stride, padding, dilation, ceil_mode); + if (kSpatialDim == 1) { + return at::quantized_max_pool1d(qx, kernel_size, stride, padding, + dilation, ceil_mode); + } else if (kSpatialDim == 2) { + return at::quantized_max_pool2d(qx, kernel_size, stride, padding, + dilation, ceil_mode); } - #endif - return at::max_pool2d(qx, kernel_size, stride, padding, dilation, ceil_mode); + TORCH_CHECK(false, "MaxPool", kSpatialDim, "D is not supported."); } }; TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { - m.impl("max_pool2d", TORCH_FN(QMaxPool2D_arr_args::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::max_pool1d"), TORCH_FN(QMaxPool_arr_args<1>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::max_pool2d"), TORCH_FN(QMaxPool_arr_args<2>::run)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qreduction.cpp b/aten/src/ATen/native/quantized/cpu/qreduction.cpp index 739638b7a67ee..74b2661142300 100644 --- a/aten/src/ATen/native/quantized/cpu/qreduction.cpp +++ b/aten/src/ATen/native/quantized/cpu/qreduction.cpp @@ -83,7 +83,14 @@ Tensor& mean_out_quantized_cpu( c10::optional opt_dtype) { #ifdef USE_PYTORCH_QNNPACK if (at::globalContext().qEngine() == at::QEngine::QNNPACK && - self.scalar_type() == kQUInt8) { + self.scalar_type() == kQUInt8 && + // QNNPACK currently is only supported for NCHW + dim=(2, 3) + // Remove these checks after generic version is implemented. + self.ndimension() == 4 && + dim.size() == 2 && + dim[0] == 2 && + dim[1] == 3 + ){ result = qnnpack_mean(self, dim); return result; } diff --git a/aten/src/ATen/native/quantized/cpu/qrelu.cpp b/aten/src/ATen/native/quantized/cpu/qrelu.cpp index 447e5cb23af5b..ca03081a1a255 100644 --- a/aten/src/ATen/native/quantized/cpu/qrelu.cpp +++ b/aten/src/ATen/native/quantized/cpu/qrelu.cpp @@ -113,7 +113,7 @@ Tensor& leaky_relu_out_quantized_cpu(Tensor& result, const Tensor& self, return result; } -Tensor heaky_relu_quantized_cpu(const Tensor& self, Scalar negval) { +Tensor leaky_relu_quantized_cpu(const Tensor& self, Scalar negval) { const auto qx = self.contiguous(self.suggest_memory_format()); auto qy = at::_empty_affine_quantized(qx.sizes(), at::device(kCPU).dtype(self.scalar_type()), @@ -170,8 +170,27 @@ class QRelu6 final { } }; +class QLeakyRelu final { + public: + static Tensor run(Tensor self, Scalar negative_slope, bool inplace, double output_scale, int64_t output_zero_point) { + // inplace argument is ignored now, TODO:support inplace + if (inplace) { + TORCH_WARN("inplace=True is not supported for quantized::leaky_relu yet"); + } + const auto qx = self.contiguous(self.suggest_memory_format()); + auto qy = at::_empty_affine_quantized(qx.sizes(), + at::device(kCPU).dtype(self.scalar_type()), + output_scale, + output_zero_point, + self.suggest_memory_format()); + qrelu_leaky_stub(self.device().type(), qy, qx, negative_slope); + return qy; + } +}; + TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { - m.impl("relu6", TORCH_FN(QRelu6::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::relu6"), TORCH_FN(QRelu6::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::leaky_relu"), TORCH_FN(QLeakyRelu::run)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qsigmoid.cpp b/aten/src/ATen/native/quantized/cpu/qsigmoid.cpp index 5c2bcd859bed4..4ee7ec2b1d545 100644 --- a/aten/src/ATen/native/quantized/cpu/qsigmoid.cpp +++ b/aten/src/ATen/native/quantized/cpu/qsigmoid.cpp @@ -17,15 +17,11 @@ namespace native { DEFINE_DISPATCH(qsigmoid_stub); #ifdef USE_PYTORCH_QNNPACK -// This ALWAYS outputs scale=1.0/256, dtype=quint8 -// The zero_point is 0 for qint32 and quint8, but -128 for qint8. -Tensor qnnpack_sigmoid(Tensor input) { +Tensor qnnpack_sigmoid( + Tensor input, double output_scale, int64_t output_zero_point) { TORCH_CHECK(input.ndimension() > 0, "qnnpack_sigmoid(): Got empty input tensor"); Tensor qy; - constexpr float output_scale = 1.0f / 256.0f; - constexpr int32_t output_zero_point = 0; - initQNNPACK(); Tensor input_contig = input.contiguous(input.suggest_memory_format()); @@ -52,9 +48,10 @@ Tensor qnnpack_sigmoid(Tensor input) { "failed to create QNNPACK sigmoid operator"); qy = at::_empty_affine_quantized( input_contig.sizes(), - input.options(), + at::device(kCPU).dtype(input_contig.dtype()), output_scale, - output_zero_point); + output_zero_point, + input_contig.suggest_memory_format()); const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_sigmoid_nc_q8( sigmoid_op, @@ -76,17 +73,60 @@ Tensor qnnpack_sigmoid(Tensor input) { "failed to run QNNPACK sigmoid operator"); return qy; } + #endif // USE_PYTORCH_QNNPACK +// This ALWAYS outputs scale=1.0/256, dtype=quint8 +// The zero_point is 0 for qint32 and quint8, but -128 for qint8. Tensor sigmoid_quantized_cpu(const Tensor& qx) { #ifdef USE_PYTORCH_QNNPACK if (at::globalContext().qEngine() == at::QEngine::QNNPACK && qx.scalar_type() == kQUInt8) { - return qnnpack_sigmoid(qx); + constexpr double output_scale = 1.0f / 256.0f; + constexpr int64_t output_zero_point = 0; + return qnnpack_sigmoid(qx, output_scale, output_zero_point); } #endif // USE_PYTORCH_QNNPACK Tensor qy; - qsigmoid_stub(qx.device().type(), qx, qy); + AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qsigmoid", [&]() { + // Naive implemenentation: uses dequantize/execute/quantize routine + // - Output scale is set to 1.0 / 2^(BIT_NUM) + // - For signed types output zero point is set to 0 + // - For unsigned types output zero point is set to (qmax + qmin) / 2.0 + // See https://stackoverflow.com/a/34448562/3606192 for potential + // optimizations + double output_scale = 0.00390625; // 1.0 / 2^8 + int64_t output_zero_point = 0; + if (SCALAR_TYPE == at::kQInt32) { + output_scale = 2.3283064365386963e-10; // 1.0 / 2^32 + } else if (SCALAR_TYPE == at::kQInt8) { + output_zero_point = -128; + } + qsigmoid_stub(qx.device().type(), qx, qy, output_scale, output_zero_point); + }); return qy; } + +namespace { + +class QSigmoid final { + public: + static Tensor run(Tensor qx, double output_scale, int64_t output_zero_point) { +#ifdef USE_PYTORCH_QNNPACK + if (at::globalContext().qEngine() == at::QEngine::QNNPACK && + qx.scalar_type() == kQUInt8) { + return qnnpack_sigmoid(qx, output_scale, output_zero_point); + } +#endif // USE_PYTORCH_QNNPACK + Tensor qy; + qsigmoid_stub(qx.device().type(), qx, qy, output_scale, output_zero_point); + return qy; + } +}; + +TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { + m.impl(TORCH_SELECTIVE_NAME("quantized::sigmoid"), TORCH_FN(QSigmoid::run)); +} +} // namespace + }} // namespace at::native diff --git a/aten/src/ATen/native/quantized/cpu/qtanh.cpp b/aten/src/ATen/native/quantized/cpu/qtanh.cpp index 60d58ec39b270..a61b337f841d2 100644 --- a/aten/src/ATen/native/quantized/cpu/qtanh.cpp +++ b/aten/src/ATen/native/quantized/cpu/qtanh.cpp @@ -50,9 +50,10 @@ Tensor qnnpack_tanh(Tensor input) { "failed to create QNNPACK TanH operator"); qy = at::_empty_affine_quantized( input_contig.sizes(), - input.options(), + at::device(kCPU).dtype(input_contig.dtype()), output_scale, - output_zero_point); + output_zero_point, + input_contig.suggest_memory_format()); const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_tanh_nc_q8( tanh_op, diff --git a/aten/src/ATen/native/quantized/cpu/qthreshold.cpp b/aten/src/ATen/native/quantized/cpu/qthreshold.cpp index 281274d27be2b..a42da4081c713 100644 --- a/aten/src/ATen/native/quantized/cpu/qthreshold.cpp +++ b/aten/src/ATen/native/quantized/cpu/qthreshold.cpp @@ -35,7 +35,7 @@ Tensor threshold_quantized_cpu( } TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { - m.impl("threshold", TORCH_FN(threshold_quantized_cpu)); + m.impl(TORCH_SELECTIVE_NAME("quantized::threshold"), TORCH_FN(threshold_quantized_cpu)); } } // namespace native diff --git a/aten/src/ATen/native/quantized/cpu/quant_utils.h b/aten/src/ATen/native/quantized/cpu/quant_utils.h index 0d63c3d9cae58..6e4bcf1dee8c9 100644 --- a/aten/src/ATen/native/quantized/cpu/quant_utils.h +++ b/aten/src/ATen/native/quantized/cpu/quant_utils.h @@ -127,7 +127,8 @@ inline TensorQuantizationParams ChooseQuantizationParams( // to be a middle value between qmin and qmax. // If either min or max is 0, then we just use 0 as zero_point. if (min < 0 && max > 0 && preserve_sparsity) { - initial_zero_point = (qmin + qmax) / 2 + 1; + const auto midpoint = qmin + (qmax - qmin) / 2; // Overflow-safe midpoint + initial_zero_point = midpoint + 1; } // Now we need to nudge the zero point to be an integer diff --git a/aten/src/ATen/native/quantized/cpu/quantized_ops.h b/aten/src/ATen/native/quantized/cpu/quantized_ops.h index baf522731e6d8..e275e3cdc7eda 100644 --- a/aten/src/ATen/native/quantized/cpu/quantized_ops.h +++ b/aten/src/ATen/native/quantized/cpu/quantized_ops.h @@ -8,13 +8,17 @@ namespace native { using qrelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); using qrelu_leaky_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/, Scalar /*negval_*/); -using qsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); +using qsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, double output_scale, int64_t output_zero_point); using qhardsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); using qclamp_fn = void (*)( const at::Tensor& /*qx*/, Scalar min, Scalar max, at::Tensor& /*qy*/); +using qclamp_minmax_fn = void (*)( + const at::Tensor& /*qx*/, + Scalar /*min or max*/, + at::Tensor& /*qy*/); using qthreshold_fn = void (*)( const at::Tensor& /*qx*/, Scalar threshold, @@ -167,6 +171,8 @@ DECLARE_DISPATCH(qbinary_fn, qmul_stub); DECLARE_DISPATCH(qcat_nhwc_fn, qcat_nhwc_stub); DECLARE_DISPATCH(qcat_nhwc_fn, qcat_relu_nhwc_stub); DECLARE_DISPATCH(qclamp_fn, qclamp_stub); +DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_min_stub); +DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_max_stub); DECLARE_DISPATCH(qelu_fn, qelu_stub); DECLARE_DISPATCH(qhardsigmoid_fn, qhardsigmoid_stub); DECLARE_DISPATCH(qhardswish_fn, qhardswish_stub); diff --git a/aten/src/ATen/native/quantized/cuda/affine_quantizer.cu b/aten/src/ATen/native/quantized/cuda/affine_quantizer.cu index 2c0c2a312e076..12f9058f6efd8 100644 --- a/aten/src/ATen/native/quantized/cuda/affine_quantizer.cu +++ b/aten/src/ATen/native/quantized/cuda/affine_quantizer.cu @@ -25,14 +25,16 @@ void quantize_tensor_per_tensor_affine_cuda( .add_input(qtensor) .build(); - gpu_kernel(iter, - [=] GPU_LAMBDA (float raw_val, scalar_t quantized_val) -> scalar_t { - int64_t qvalue = static_cast(nearbyint(raw_val / scale + zero_point)); - qvalue = std::max(qvalue, qmin); - qvalue = std::min(qvalue, qmax); - quantized_val.val_ = qvalue; - return quantized_val; - }); + gpu_kernel( + iter, + [=] GPU_LAMBDA(float raw_val, scalar_t quantized_val) -> scalar_t { + int64_t qvalue = + static_cast(nearbyint(raw_val / scale) + zero_point); + qvalue = std::max(qvalue, qmin); + qvalue = std::min(qvalue, qmax); + quantized_val.val_ = qvalue; + return quantized_val; + }); }); } diff --git a/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu b/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu index 8e25f5ff443de..e2f51398b48f3 100644 --- a/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu +++ b/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu @@ -34,17 +34,16 @@ void fake_quantize_tensor_kernel_cuda( .add_output(output) .add_input(input) .build(); - gpu_kernel(iter, - [=] GPU_LAMBDA (float input_val) -> float { - return (fminf( + gpu_kernel(iter, [=] GPU_LAMBDA(float input_val) -> float { + return (fminf( quant_max, fmaxf( quant_min, - static_cast(std::nearbyint( - input_val * inv_scale + zero_point)))) - + static_cast( + std::nearbyint(input_val * inv_scale) + zero_point))) - zero_point) * - scale; - }); + scale; + }); } void fake_quantize_grad_tensor_kernel_cuda( @@ -63,11 +62,10 @@ void fake_quantize_grad_tensor_kernel_cuda( .add_input(output_grad) .add_input(input) .build(); - gpu_kernel(iter, - [=] GPU_LAMBDA (float dy, float x) -> float { - int64_t Xq = std::nearbyint(x * inv_scale + zero_point); - return (Xq >= quant_min && Xq <= quant_max) * dy; - }); + gpu_kernel(iter, [=] GPU_LAMBDA(float dy, float x) -> float { + int64_t Xq = std::nearbyint(x * inv_scale) + zero_point; + return (Xq >= quant_min && Xq <= quant_max) * dy; + }); } void _fake_quantize_grad_learnable_tensor_kernel_cuda( @@ -82,7 +80,7 @@ void _fake_quantize_grad_learnable_tensor_kernel_cuda( gpu_kernel_multiple_outputs( iter, [=] GPU_LAMBDA (float XInput, float dYInput) -> thrust::tuple { float dXOutput, dZeroPointOutput, dScaleOutput; - int64_t xq = std::nearbyint(zero_point + XInput * inv_scale); + int64_t xq = std::nearbyint(XInput * inv_scale) + zero_point; dXOutput = dYInput * (xq >= quant_min && xq <= quant_max); xq = std::max(std::min(xq, quant_max), quant_min); float xfq = static_cast((xq - zero_point) * scale); @@ -108,12 +106,13 @@ void fake_quant_per_channel_cuda(TensorIterator &iter, int64_t quant_min, int64_ [=] GPU_LAMBDA (float input_val, float scale, int64_t zero_point) -> float { float inv_scale = 1.0f / scale; return (fminf( - quant_max, - fmaxf( - quant_min, - static_cast(std::nearbyint( - input_val * inv_scale + zero_point)))) - - zero_point) * + quant_max, + fmaxf( + quant_min, + static_cast( + std::nearbyint(input_val * inv_scale) + + zero_point))) - + zero_point) * scale; }); } @@ -122,7 +121,7 @@ void fake_quant_grad_per_channel_cuda(TensorIterator &iter, int64_t quant_min, i gpu_kernel(iter, [=] GPU_LAMBDA (float x, float dy, float scale, int64_t zero_point) -> float { float inv_scale = 1.0f / scale; - int64_t Xq = std::nearbyint(x * inv_scale + zero_point); + int64_t Xq = std::nearbyint(x * inv_scale) + zero_point; return (Xq >= quant_min && Xq <= quant_max) * dy; }); } diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index 6049ccbe1e467..2c8a6d4e49466 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -20,174 +20,181 @@ TORCH_LIBRARY(quantized, m) { register_conv_params<3>(); register_embedding_params(); - m.def("add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc"); - m.def("add.out(Tensor qa, Tensor qb, Tensor(a!) out) -> Tensor(a!) out"); - m.def("add.Scalar(Tensor qa, Scalar b) -> Tensor qc"); - m.def("add.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out"); - m.def("add_relu(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc"); - m.def("add_relu.Scalar(Tensor qa, Scalar b) -> Tensor qc"); - m.def("add_relu.out(Tensor qa, Tensor qb, Tensor(a!) out) -> Tensor(a!) out"); - m.def("add_relu.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out"); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add.out(Tensor qa, Tensor qb, Tensor(a!) out) -> Tensor(a!) out")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add.Scalar(Tensor qa, Scalar b) -> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add.Scalar2(Scalar b, Tensor qa) -> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu.Scalar(Tensor qa, Scalar b) -> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu.Scalar2(Scalar b, Tensor qa) -> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu.out(Tensor qa, Tensor qb, Tensor(a!) out) -> Tensor(a!) out")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out")); // deprecated functions, kept for backward compatibility - m.def("add_out(Tensor qa, Tensor qb, Tensor(a!) out) -> Tensor(a!) out"); - m.def("add_relu_out(Tensor qa, Tensor qb, Tensor(a!) out) -> Tensor(a!) out"); - m.def("add_scalar(Tensor qa, Scalar b) -> Tensor qc"); - m.def("add_scalar_relu(Tensor qa, Scalar b) -> Tensor qc"); - m.def("add_scalar_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out"); - m.def("add_scalar_relu_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out"); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_out(Tensor qa, Tensor qb, Tensor(a!) out) -> Tensor(a!) out")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu_out(Tensor qa, Tensor qb, Tensor(a!) out) -> Tensor(a!) out")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_scalar(Tensor qa, Scalar b) -> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_scalar_relu(Tensor qa, Scalar b) -> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_scalar_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_scalar_relu_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out")); // TODO: remove after broadcasting is supported - m.def("add_scalar_out.Tensor(Tensor qa, Tensor b, Tensor(a!) out) -> Tensor(a!) out"); - m.def("add_scalar.Tensor(Tensor qa, Tensor b) -> Tensor qc"); - m.def("add_scalar_relu.Tensor(Tensor qa, Tensor b) -> Tensor qc"); - m.def("add_scalar_relu_out.Tensor(Tensor qa, Tensor b, Tensor(a!) out) -> Tensor(a!) out"); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_scalar_out.Tensor(Tensor qa, Tensor b, Tensor(a!) out) -> Tensor(a!) out")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_scalar.Tensor(Tensor qa, Tensor b) -> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_scalar_relu.Tensor(Tensor qa, Tensor b) -> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_scalar_relu_out.Tensor(Tensor qa, Tensor b, Tensor(a!) out) -> Tensor(a!) out")); // This is needed for graph mode quantization, when we fuse // dequant - aten::batch_norm - quant into quantized::batch_norm // and dimension is unknown given only the aten op call // quantized::batch_norm supports both 2d and 3d batch norm right now // it should also support 1d batch_norm after quantized::batch_norm1d is // implemented - m.def("batch_norm(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor"); - m.def("batch_norm_relu(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor"); - m.def("batch_norm1d(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor"); - m.def("batch_norm1d_relu(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor"); - m.def("batch_norm2d(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor"); - m.def("batch_norm2d_relu(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor"); - m.def("batch_norm3d(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor"); - m.def("batch_norm3d_relu(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor"); - m.def("clamp(Tensor qx, Scalar? min, Scalar? max) -> Tensor qy"); - m.def("threshold(Tensor qx, Scalar threshold, Scalar value) -> Tensor qy"); - m.def("cat(Tensor[] qx, int dim, float? scale, int? zero_point) -> Tensor"); - m.def("cat_relu(Tensor[] qx, int dim, float? scale, int? zero_point) -> Tensor"); - m.def("cat_out(Tensor[] qx, int dim, Tensor(a!) out) -> Tensor(a!)"); - m.def("cat_relu_out(Tensor[] qx, int dim, Tensor(a!) out) -> Tensor(a!)"); - m.def("conv1d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor"); - m.def("conv1d_relu(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor"); - m.def("conv2d.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor"); - m.def("conv2d_relu.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor"); - m.def("conv3d.new(Tensor qx, __torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor"); - m.def("conv3d_relu.new(Tensor qx, __torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor"); - m.def("conv2d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase weight, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point) -> Tensor"); - m.def("conv2d_relu(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase weight, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point) -> Tensor"); - m.def("conv3d(Tensor qx, __torch__.torch.classes.quantized.Conv3dPackedParamsBase weight, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point) -> Tensor"); - m.def("conv3d_relu(Tensor qx, __torch__.torch.classes.quantized.Conv3dPackedParamsBase weight, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point) -> Tensor"); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::batch_norm(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::batch_norm_relu(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::batch_norm1d(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::batch_norm1d_relu(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::batch_norm2d(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::batch_norm2d_relu(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::batch_norm3d(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::batch_norm3d_relu(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::clamp(Tensor qx, Scalar? min=None, Scalar? max=None) -> Tensor qy")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::threshold(Tensor qx, Scalar threshold, Scalar value) -> Tensor qy")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::cat(Tensor[] qx, int dim, float? scale, int? zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::cat_relu(Tensor[] qx, int dim, float? scale, int? zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::cat_out(Tensor[] qx, int dim, Tensor(a!) out) -> Tensor(a!)")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::cat_relu_out(Tensor[] qx, int dim, Tensor(a!) out) -> Tensor(a!)")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv1d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv1d_relu(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_relu.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d.new(Tensor qx, __torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d_relu.new(Tensor qx, __torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase weight, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_relu(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase weight, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d(Tensor qx, __torch__.torch.classes.quantized.Conv3dPackedParamsBase weight, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d_relu(Tensor qx, __torch__.torch.classes.quantized.Conv3dPackedParamsBase weight, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point) -> Tensor")); // conv_prepack is deprecated, please use conv2d_prepack for 2D conv. - m.def("conv_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase"); - m.def("conv1d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase"); - m.def("conv2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase"); - m.def("conv3d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv3dPackedParamsBase"); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv1d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv3dPackedParamsBase")); // conv_unpack is deprecated, please use conv2d_unpack for 2D conv. - m.def("conv_unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)"); - m.def("conv1d_unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)"); - m.def("conv2d_unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)"); - m.def("conv3d_unpack(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)"); - m.def("conv2d_stride(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]"); - m.def("conv2d_padding(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]"); - m.def("conv2d_dilation(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]"); - m.def("conv2d_groups(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int"); - m.def("conv2d_transpose(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int"); - m.def("conv3d_stride(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]"); - m.def("conv3d_padding(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]"); - m.def("conv3d_dilation(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]"); - m.def("conv3d_groups(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int"); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv1d_unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d_unpack(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_stride(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_padding(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_output_padding(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_dilation(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_groups(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_transpose(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d_stride(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d_padding(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d_output_padding(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d_dilation(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d_groups(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d_transpose(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int")); // conv_tranpsose - m.def("conv_transpose1d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor"); - m.def("conv_transpose2d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor"); - m.def("conv_transpose1d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase"); - m.def("conv_transpose2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase"); - m.def("conv_transpose1d_unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)"); - m.def("conv_transpose2d_unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)"); - m.def("conv_transpose2d_stride(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]"); - m.def("conv_transpose2d_padding(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]"); - m.def("conv_transpose2d_dilation(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]"); - m.def("conv_transpose2d_groups(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int"); - m.def("conv_transpose2d_transpose(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int"); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose1d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d(Tensor qx, __torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose1d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose1d_unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_stride(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_padding(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_output_padding(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_dilation(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_groups(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_transpose(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv3dPackedParamsBase")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_unpack(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_stride(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_padding(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_output_padding(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_dilation(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_groups(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_transpose(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int")); - m.def("elu(Tensor self, float output_scale, int output_zero_point, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor"); - m.def("embedding_bag_prepack(Tensor weight) -> __torch__.torch.classes.quantized.EmbeddingPackedParamsBase W_prepack"); - m.def("embedding_bag_unpack(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase W_prepack) -> Tensor W_origin"); - m.def("embedding_bag_byte_prepack(Tensor weight) -> Tensor"); - m.def("embedding_bag_byte_unpack(Tensor weight) -> Tensor"); - m.def("embedding_bag_4bit_prepack(Tensor weight, bool optimized_qparams=False) -> Tensor"); - m.def("embedding_bag_4bit_unpack(Tensor weight) -> Tensor"); - m.def("embedding_bag_2bit_prepack(Tensor weight, bool optimized_qparams=False) -> Tensor"); - m.def("embedding_bag_2bit_unpack(Tensor weight) -> Tensor"); - m.def("embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> Tensor"); - m.def("embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"); - m.def("embedding_bag_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"); - m.def("embedding_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, bool sparse=False) -> Tensor"); - m.def("celu(Tensor self, float output_scale, int output_zero_point, Scalar alpha=1) -> Tensor"); - m.def("hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor"); - m.def("group_norm(Tensor input, int num_groups, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor"); - m.def("hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor"); - m.def("instance_norm(Tensor input, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor"); - m.def("layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor"); - m.def( - "linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y"); - m.def( - "linear_relu(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y"); - m.def( - "linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"); - m.def( - "linear_relu_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"); - m.def( - "linear_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"); - m.def( - "linear_prepack(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack"); - m.def( - "linear_prepack_fp16(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack"); - m.def("linear_prepack_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack"); - m.def( - "linear_prepack_fp16_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack"); - m.def( - "linear_unpack(__torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> (Tensor W_origin, Tensor? B_origin)"); - m.def( - "linear_unpack_fp16(__torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> (Tensor W_origin, Tensor? B_origin)"); - m.def( - "linear_unpack.legacy(Tensor W_prepack) -> (Tensor W_origin, Tensor? B_origin)"); - m.def( - "linear_unpack_fp16.legacy(Tensor W_prepack) -> (Tensor W_origin, Tensor? B_origin)"); - m.def("mul(Tensor qa, Tensor qb, float scale, int zero_point)-> Tensor qc"); - m.def("mul.out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out"); - m.def("mul.Scalar(Tensor qa, Scalar b)-> Tensor qc"); - m.def("mul.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out)-> Tensor(a!) out"); - m.def("mul_relu(Tensor qa, Tensor qb, float scale, int zero_point)-> Tensor qc"); - m.def("mul_relu.out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out"); - m.def("mul_relu.Scalar(Tensor qa, Scalar b)-> Tensor qc"); - m.def("mul_relu.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out)-> Tensor(a!) out"); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::elu(Tensor self, float output_scale, int output_zero_point, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_prepack(Tensor weight) -> __torch__.torch.classes.quantized.EmbeddingPackedParamsBase W_prepack")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_unpack(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase W_prepack) -> Tensor W_origin")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_prepack(Tensor weight) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_unpack(Tensor weight) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_prepack(Tensor weight, bool optimized_qparams=False) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_unpack(Tensor weight) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_prepack(Tensor weight, bool optimized_qparams=False) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_unpack(Tensor weight) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, bool pruned_weights=False) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::celu(Tensor self, float output_scale, int output_zero_point, Scalar alpha=1) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::group_norm(Tensor input, int num_groups, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::instance_norm(Tensor input, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack_fp16(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack_fp16_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_unpack(__torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> (Tensor W_origin, Tensor? B_origin)")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_unpack_fp16(__torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> (Tensor W_origin, Tensor? B_origin)")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_unpack.legacy(Tensor W_prepack) -> (Tensor W_origin, Tensor? B_origin)")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_unpack_fp16.legacy(Tensor W_prepack) -> (Tensor W_origin, Tensor? B_origin)")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul(Tensor qa, Tensor qb, float scale, int zero_point)-> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul.out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul.Scalar(Tensor qa, Scalar b)-> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul.Scalar2(Scalar b, Tensor qa)-> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out)-> Tensor(a!) out")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu(Tensor qa, Tensor qb, float scale, int zero_point)-> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu.out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu.Scalar(Tensor qa, Scalar b)-> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu.Scalar2(Scalar b, Tensor qa)-> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out)-> Tensor(a!) out")); // deprecated functions, kept for backward compatibility - m.def("mul_out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out"); - m.def("mul_relu_out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out"); - m.def("mul_scalar(Tensor qa, Scalar b)-> Tensor qc"); - m.def("mul_scalar_relu(Tensor qa, Scalar b)-> Tensor qc"); - m.def("mul_scalar_out(Tensor qa, Scalar b, Tensor(a!) out)-> Tensor(a!) out"); - m.def("mul_scalar_relu_out(Tensor qa, Scalar b, Tensor(a!) out)-> Tensor(a!) out"); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu_out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_scalar(Tensor qa, Scalar b)-> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_scalar_relu(Tensor qa, Scalar b)-> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_scalar_out(Tensor qa, Scalar b, Tensor(a!) out)-> Tensor(a!) out")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_scalar_relu_out(Tensor qa, Scalar b, Tensor(a!) out)-> Tensor(a!) out")); // TODO: remove after broadcasting is supported - m.def("mul_scalar.Tensor(Tensor qa, Tensor b)-> Tensor qc"); - m.def("mul_scalar_relu.Tensor(Tensor qa, Tensor b)-> Tensor qc"); - m.def("mul_scalar_out.Tensor(Tensor qa, Tensor b, Tensor(a!) out)-> Tensor(a!) out"); - m.def("mul_scalar_relu_out.Tensor(Tensor qa, Tensor b, Tensor(a!) out)-> Tensor(a!) out"); - // NB: missing a space after comma here... - m.def("max_pool2d(Tensor qx, int[] kernel_size, int[] stride, int[] padding, int[] dilation,bool ceil_mode) -> Tensor"); - m.def("relu6(Tensor qx, bool inplace=False) -> Tensor"); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_scalar.Tensor(Tensor qa, Tensor b)-> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_scalar_relu.Tensor(Tensor qa, Tensor b)-> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_scalar_out.Tensor(Tensor qa, Tensor b, Tensor(a!) out)-> Tensor(a!) out")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_scalar_relu_out.Tensor(Tensor qa, Tensor b, Tensor(a!) out)-> Tensor(a!) out")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::max_pool1d(Tensor qx, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::max_pool2d(Tensor qx, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::relu6(Tensor qx, bool inplace=False) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::leaky_relu(Tensor qx, Scalar negative_slope, bool inplace, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::sigmoid(Tensor qx, float output_scale, int output_zero_point) -> Tensor")); } // According to #33294: The "_" prefix registration will be // removed when the operators are all migrated to mobile. // https://github.com/pytorch/pytorch/issues/36510 TORCH_LIBRARY(_quantized, m) { - m.def("add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc"); - m.def("conv2d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor"); - m.def("conv2d_relu(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor"); - m.def("conv2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase"); - m.def( - "linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y"); - m.def( - "linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"); - m.def( - "linear_prepack(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack"); - m.def( - "linear_prepack_fp16(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack"); - m.def("linear_prepack_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack"); - m.def( - "linear_prepack_fp16_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack"); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv2d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv2d_relu(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv_transpose1d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv_transpose2d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv_transpose1d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv_transpose2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv_transpose3d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv3dPackedParamsBase")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack_fp16(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack_fp16_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack")); } diff --git a/aten/src/ATen/native/sparse/SparseMatMul.cpp b/aten/src/ATen/native/sparse/SparseMatMul.cpp new file mode 100644 index 0000000000000..84a98a5cc4c3c --- /dev/null +++ b/aten/src/ATen/native/sparse/SparseMatMul.cpp @@ -0,0 +1,285 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { namespace native { + +using namespace at::sparse; + +/* + This is an implementation of the SMMP algorithm: + "Sparse Matrix Multiplication Package (SMMP)" + + Randolph E. Bank and Craig C. Douglas + https://doi.org/10.1007/BF02070824 +*/ +namespace { +void csr_to_coo(const int64_t n_row, const int64_t Ap[], int64_t Bi[]) { + /* + Expands a compressed row pointer into a row indices array + Inputs: + `n_row` is the number of rows in `Ap` + `Ap` is the row pointer + + Output: + `Bi` is the row indices + */ + for (int64_t i = 0; i < n_row; i++) { + for (int64_t jj = Ap[i]; jj < Ap[i + 1]; jj++) { + Bi[jj] = i; + } + } +} + +int64_t _csr_matmult_maxnnz( + const int64_t n_row, + const int64_t n_col, + const int64_t Ap[], + const int64_t Aj[], + const int64_t Bp[], + const int64_t Bj[]) { + /* + Compute needed buffer size for matrix `C` in `C = A@B` operation. + + The matrices should be in proper CSR structure, and their dimensions + should be compatible. + */ + std::vector mask(n_col, -1); + int64_t nnz = 0; + for (int64_t i = 0; i < n_row; i++) { + int64_t row_nnz = 0; + + for (int64_t jj = Ap[i]; jj < Ap[i + 1]; jj++) { + int64_t j = Aj[jj]; + for (int64_t kk = Bp[j]; kk < Bp[j + 1]; kk++) { + int64_t k = Bj[kk]; + if (mask[k] != i) { + mask[k] = i; + row_nnz++; + } + } + } + int64_t next_nnz = nnz + row_nnz; + nnz = next_nnz; + } + return nnz; +} + +template +void _csr_matmult( + const int64_t n_row, + const int64_t n_col, + const int64_t Ap[], + const int64_t Aj[], + const scalar_t Ax[], + const int64_t Bp[], + const int64_t Bj[], + const scalar_t Bx[], + int64_t Cp[], + int64_t Cj[], + scalar_t Cx[]) { + /* + Compute CSR entries for matrix C = A@B. + + The matrices `A` and 'B' should be in proper CSR structure, and their dimensions + should be compatible. + + Inputs: + `n_row` - number of row in A + `n_col` - number of columns in B + `Ap[n_row+1]` - row pointer + `Aj[nnz(A)]` - column indices + `Ax[nnz(A)] - nonzeros + `Bp[?]` - row pointer + `Bj[nnz(B)]` - column indices + `Bx[nnz(B)]` - nonzeros + Outputs: + `Cp[n_row+1]` - row pointer + `Cj[nnz(C)]` - column indices + `Cx[nnz(C)]` - nonzeros + + Note: + Output arrays Cp, Cj, and Cx must be preallocated + */ + std::vector next(n_col, -1); + std::vector sums(n_col, 0); + + int64_t nnz = 0; + + Cp[0] = 0; + + for (int64_t i = 0; i < n_row; i++) { + int64_t head = -2; + int64_t length = 0; + + int64_t jj_start = Ap[i]; + int64_t jj_end = Ap[i + 1]; + for (int64_t jj = jj_start; jj < jj_end; jj++) { + int64_t j = Aj[jj]; + scalar_t v = Ax[jj]; + + int64_t kk_start = Bp[j]; + int64_t kk_end = Bp[j + 1]; + for (int64_t kk = kk_start; kk < kk_end; kk++) { + int64_t k = Bj[kk]; + + sums[k] += v * Bx[kk]; + + if (next[k] == -1) { + next[k] = head; + head = k; + length++; + } + } + } + + for (int64_t jj = 0; jj < length; jj++) { + Cj[nnz] = head; + Cx[nnz] = sums[head]; + nnz++; + + int64_t temp = head; + head = next[head]; + + next[temp] = -1; // clear arrays + sums[temp] = 0; + } + + Cp[i + 1] = nnz; + } +} + + +template +void sparse_matmul_kernel( + Tensor& output, + const Tensor& mat1, + const Tensor& mat2) { + /* + Computes the sparse-sparse matrix multiplication between `mat1` and `mat2`, which are sparse tensors in COO format. + */ + + auto M = mat1.size(0); + auto K = mat1.size(1); + auto N = mat2.size(1); + + auto mat1_indices_ = mat1._indices().contiguous(); + auto mat1_values = mat1._values().contiguous(); + Tensor mat1_row_indices = mat1_indices_.select(0, 0); + Tensor mat1_col_indices = mat1_indices_.select(0, 1); + + Tensor mat1_indptr = coo_to_csr(mat1_row_indices.data_ptr(), M, mat1._nnz()); + + auto mat2_indices_ = mat2._indices().contiguous(); + auto mat2_values = mat2._values().contiguous(); + Tensor mat2_row_indices = mat2_indices_.select(0, 0); + Tensor mat2_col_indices = mat2_indices_.select(0, 1); + + Tensor mat2_indptr = coo_to_csr(mat2_row_indices.data_ptr(), K, mat2._nnz()); + + auto nnz = _csr_matmult_maxnnz(M, N, mat1_indptr.data_ptr(), mat1_col_indices.data_ptr(), + mat2_indptr.data_ptr(), mat2_col_indices.data_ptr()); + + auto output_indices = output._indices(); + auto output_values = output._values(); + + Tensor output_indptr = at::empty({M + 1}, kLong); + at::native::resize_output(output_indices, {2, nnz}); + at::native::resize_output(output_values, nnz); + + Tensor output_row_indices = output_indices.select(0, 0); + Tensor output_col_indices = output_indices.select(0, 1); + + _csr_matmult(M, N, mat1_indptr.data_ptr(), mat1_col_indices.data_ptr(), mat1_values.data_ptr(), + mat2_indptr.data_ptr(), mat2_col_indices.data_ptr(), mat2_values.data_ptr(), + output_indptr.data_ptr(), output_col_indices.data_ptr(), output_values.data_ptr()); + + csr_to_coo(M, output_indptr.data_ptr(), output_row_indices.data_ptr()); +} + +} // end anonymous namespace + +Tensor sparse_matrix_mask_helper_cpu( + const SparseTensor& t, + const Tensor& mask_indices +) { + /* + This is a helper function which filter values from `t._values()` using the `mask_indices`. + This CPU implementation uses a simple hash_map to filter values by matching the `mask_indices` + with the indices at tensor input `t`. + + Inputs: + `t` - tensor input + `mask_indices` - mask indices tensor + */ + int64_t r_nnz = mask_indices.size(1); + auto t_v = t._values(); + Tensor r_values = at::zeros({r_nnz}, t_v.options()); + auto t_i = t._indices(); + auto t_nnz = t._nnz(); + + std::unordered_map t_flatten_indices = std::unordered_map{}; + + // Step 1: flatten the sparse indices `t._indices()` tensor and then map this flatten value `index` to the original position `i` + auto t_indices_accessor = t_i.accessor(); + for(int64_t i = 0; i < t_nnz; i++) { + int64_t index = t_indices_accessor[0][i] * t.size(1) + t_indices_accessor[1][i]; + t_flatten_indices[index] = i; + } + + // Step 2: Filter `t._values()` values by matching the flatten `mask_indices` with the flatten `t._indices()` using the + // hash_map `t_flatten_indices` + AT_DISPATCH_FLOATING_TYPES(r_values.scalar_type(), "_sparse_matrix_mask", [&] { + auto r_values_accessor = r_values.accessor(); + auto t_values = t_v.accessor(); + auto mask_indices_accessor = mask_indices.accessor(); + at::parallel_for(0, r_nnz, 0, [&](int64_t start, int64_t end) { + for (auto i = start; i < end; i++) { + auto x = mask_indices_accessor[0][i]; + auto y = mask_indices_accessor[1][i]; + int64_t index = (x * t.size(1) + y); + auto iter = t_flatten_indices.find(index); + if (iter != t_flatten_indices.end()) { + assert(iter->second < t_nnz); + assert(i < r_nnz); + r_values_accessor[i] = t_values[ iter->second ]; + } + } + }); + }); + return r_values; +} + +Tensor sparse_sparse_matmul_cpu(const Tensor& mat1_, const Tensor& mat2_) { + TORCH_INTERNAL_ASSERT(mat1_.is_sparse()); + TORCH_INTERNAL_ASSERT(mat2_.is_sparse()); + TORCH_CHECK(mat1_.dim() == 2); + TORCH_CHECK(mat2_.dim() == 2); + TORCH_CHECK(mat1_.dense_dim() == 0, "sparse_sparse_matmul_cpu: scalar values expected, got ", mat1_.dense_dim(), "D values"); + TORCH_CHECK(mat2_.dense_dim() == 0, "sparse_sparse_matmul_cpu: scalar values expected, got ", mat2_.dense_dim(), "D values"); + + TORCH_CHECK( + mat1_.size(1) == mat2_.size(0), "mat1 and mat2 shapes cannot be multiplied (", + mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")"); + + TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(), + "mat1 dtype ", mat1_.scalar_type(), " does not match mat2 dtype ", mat2_.scalar_type()); + + auto output = at::native::empty_like(mat1_); + output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0); + + AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] { + sparse_matmul_kernel(output, mat1_.coalesce(), mat2_.coalesce()); + }); + return output; +} + + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index 3196c083eea03..fb7e16539c15e 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include @@ -14,7 +15,6 @@ namespace at { namespace native { using namespace at::sparse; - /****************************************************************************** * access methods ******************************************************************************/ @@ -70,23 +70,23 @@ Tensor values_sparse(const Tensor& self) { /*** Helper methods ***/ -SparseTensor new_sparse(const TensorOptions& options) { - TORCH_INTERNAL_ASSERT(impl::variable_excluded_from_dispatch()); - AT_ASSERT(options.layout() == kSparse); +SparseTensor new_sparse(c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { + AT_ASSERT(layout.has_value() && *layout == kSparse); DispatchKey dispatch_key; - if (options.device().is_cuda()) { + if (device_or_default(device).is_cuda()) { dispatch_key = DispatchKey::SparseCUDA; } else { dispatch_key = DispatchKey::SparseCPU; } return detail::make_tensor( - DispatchKeySet(dispatch_key), options.dtype()); + DispatchKeySet(dispatch_key), scalarTypeToTypeMeta(dtype_or_default(dtype))); } /** Actual dispatched creation methods ***/ -SparseTensor new_with_dims_sparse(int64_t sparse_dim, int64_t dense_dim, ArrayRef size, const TensorOptions& options) { - SparseTensor self = new_sparse(options); +SparseTensor new_with_dims_sparse(int64_t sparse_dim, int64_t dense_dim, ArrayRef size, c10::optional dtype, + c10::optional layout, c10::optional device, c10::optional pin_memory) { + SparseTensor self = new_sparse(dtype, layout, device, pin_memory); get_sparse_impl(self)->resize_and_clear_(sparse_dim, dense_dim, size); return self; } @@ -95,15 +95,18 @@ SparseTensor new_with_dims_and_tensor_sparse( int64_t sparse_dim, int64_t dense_dim, ArrayRef size, - const LongTensor& indices, + const Tensor& indices, const Tensor& values, - const TensorOptions& options) { - SparseTensor self = new_sparse(options); + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { + SparseTensor self = new_sparse(dtype, layout, device, pin_memory); get_sparse_impl(self)->resize_(sparse_dim, dense_dim, size); // NOTE: There is no guarantee that `indices` and `values` don't contain AutogradMeta. However, // we want to maintain the invariant that `indices_` and `values_` of a sparse tensor don't // contain AutogradMeta, and to achieve that we shallow-copy `indices` and `values` here. - auto indices_shallow_copy = LongTensor(indices.unsafeGetTensorImpl()->shallow_copy_and_detach( + auto indices_shallow_copy = Tensor(indices.unsafeGetTensorImpl()->shallow_copy_and_detach( /*version_counter=*/indices.unsafeGetTensorImpl()->version_counter(), /*allow_tensor_metadata_change=*/true)); auto values_shallow_copy = Tensor(values.unsafeGetTensorImpl()->shallow_copy_and_detach( @@ -116,9 +119,9 @@ SparseTensor new_with_dims_and_tensor_sparse( /** Public creation API that dispatch to methods above **/ /** Empty init **/ -Tensor empty_sparse(IntArrayRef size, const TensorOptions& options, c10::optional optional_memory_format) { - TORCH_CHECK(!options.pinned_memory(), "Only dense CPU tensors can be pinned"); - return new_with_dims_sparse(size.size(), 0, size, options); +Tensor empty_sparse(IntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional optional_memory_format) { + TORCH_CHECK(!pin_memory.has_value() || !*pin_memory, "Only dense CPU tensors can be pinned"); + return new_with_dims_sparse(size.size(), 0, size, dtype, layout, device, pin_memory); } /* Shape init */ @@ -160,11 +163,11 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values_, const Ten // If the indices has elements in it, we infer the minimum sparse dimension sizes // as the max value of each dim in indices. // NB: It used to keepdim. I think that was wrong. - LongTensor min_indices = std::get(indices.min(/* dim */ 1, /* keepdim */ false)); - LongTensor computed_indices_sizes = std::get(indices.max(/* dim */ 1, /* keepdim */ false)); + Tensor min_indices = std::get(indices.min(/* dim */ 1, /* keepdim */ false)); + Tensor computed_indices_sizes = std::get(indices.max(/* dim */ 1, /* keepdim */ false)); computed_indices_sizes.add_(1); // len = max_index + 1 - LongTensor cpu_min_indices = min_indices.to(at::DeviceType::CPU); - LongTensor cpu_computed_indices_sizes = computed_indices_sizes.to(at::DeviceType::CPU); + Tensor cpu_min_indices = min_indices.to(at::DeviceType::CPU); + Tensor cpu_computed_indices_sizes = computed_indices_sizes.to(at::DeviceType::CPU); auto cpu_min_indices_accessor = cpu_min_indices.accessor(); auto cpu_computed_indices_sizes_accessor = cpu_computed_indices_sizes.accessor(); for (int64_t d = 0; d < sparse_dim; d++) { @@ -203,9 +206,9 @@ void _validate_sparse_coo_tensor_args(const Tensor& indices, const Tensor& value // Check to make sure all indices are within the boundaries of `size` if (indices.numel() > 0) { - LongTensor min_indices = std::get(indices.min(/* dim */ 1, /* keepdim */ false)); - LongTensor max_indices = std::get(indices.max(/* dim */ 1, /* keepdim */ false)); - LongTensor cpu_min_indices, cpu_max_indices; + Tensor min_indices = std::get(indices.min(/* dim */ 1, /* keepdim */ false)); + Tensor max_indices = std::get(indices.max(/* dim */ 1, /* keepdim */ false)); + Tensor cpu_min_indices, cpu_max_indices; if (indices.is_cuda()) { cpu_min_indices = min_indices.to(at::DeviceType::CPU); cpu_max_indices = max_indices.to(at::DeviceType::CPU); @@ -261,7 +264,9 @@ SparseTensor clone_sparse(const SparseTensor& self, c10::optional 0) { - std::vector ix = indices.chunk(indices.size(0), 0); + auto ix = toListOfOptionalTensors(indices.chunk(indices.size(0), 0)); values = self.index(ix).squeeze(0).clone(at::MemoryFormat::Preserve); } else { AT_ASSERT(nz.sizes().equals({0, 1})); @@ -338,9 +343,10 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){ // NB: Dropped the resizeNd variants -Tensor sparse_to_dense(const SparseTensor& self) { +Tensor sparse_to_dense(const SparseTensor& self, c10::optional dtype) { + TORCH_CHECK(!dtype.has_value(), "dtype argument is not supported by sparse_to_dense"); if(self.scalar_type() == ScalarType::Half && self.options().device().is_cpu()) { - AT_ERROR("to_dense() not supported for float16 on CPU"); + TORCH_CHECK(false, "to_dense() not supported for float16 on CPU"); } Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided)); return dst.add_(self); @@ -369,23 +375,23 @@ SparseTensor coalesce_sparse_cpu(const SparseTensor& self) { return dst; } - LongTensor indices = self._indices(); + Tensor indices = self._indices(); Tensor values = self._values().contiguous(); int64_t sparse_dim = self.sparse_dim(); int64_t dense_dim = self.dense_dim(); int64_t nnz = self._nnz(); - LongTensor indices_scalar = flatten_indices(indices, self.sizes()); + Tensor indices_scalar = flatten_indices(indices, self.sizes()); - SparseTensor dst = new_sparse(self.options()); + SparseTensor dst = new_sparse(optTypeMetaToScalarType(self.options().dtype_opt()), self.options().layout_opt(), self.options().device_opt(), self.options().pinned_memory_opt()); get_sparse_impl(dst)->resize_(sparse_dim, dense_dim, self.sizes()); // TODO: is there a more idiomatic way to do this? - LongTensor newIndices = at::empty(indices.sizes(), indices.options()); + Tensor newIndices = at::empty(indices.sizes(), indices.options()); Tensor newValues = at::empty(values.sizes(), values.options()); alias_into_sparse(dst, newIndices, newValues); - LongTensor indicesBuffer; - LongTensor indicesPermutation; + Tensor indicesBuffer; + Tensor indicesPermutation; std::tie(indicesBuffer, indicesPermutation) = indices_scalar.sort(0); // NB: The accessor accesses here rely on self._nnz() > 0 (tested earlier in this function) auto newIndicesAccessor = newIndices.accessor(); @@ -440,7 +446,7 @@ void inline sparse_mask_out_cpu_kernel( const Tensor& t, const int64_t r_nnz, const int64_t sparse_dim, - const LongTensor& mask_indices + const Tensor& mask_indices ) { auto r_values_accessor = r_values.accessor(); auto mask_indices_accessor = mask_indices.accessor(); @@ -470,7 +476,7 @@ SparseTensor& sparse_mask_out_cpu(SparseTensor& r, const Tensor& t, const Sparse } int64_t dim = t.dim(); int64_t sparse_dim = mask.sparse_dim(); - LongTensor mask_indices = mask._indices(); + Tensor mask_indices = mask._indices(); Tensor mask_values = mask._values(); Tensor r_values = at::empty(mask_values.sizes(), r._values().options()); alias_into_sparse(r, mask_indices.clone(), r_values); @@ -486,7 +492,7 @@ SparseTensor& sparse_mask_out_cpu(SparseTensor& r, const Tensor& t, const Sparse // Get a flattened sparse indices, similar to NOTE [ Flatten Sparse Indices ]. // Keeping this implementation because it is faster than flatten_indices() - LongTensor indices = at::zeros({mask._nnz()}, mask_indices.options()); + Tensor indices = at::zeros({mask._nnz()}, mask_indices.options()); for (int64_t d = 0; d < mask.sparse_dim(); d++) { indices.mul_(mask.size(d)); indices.add_(mask_indices.select(0, d)); diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index 2bb5842b47263..6c3298b72e759 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -21,27 +21,13 @@ using namespace at::sparse; // -------------------------------------------------------------------- namespace { - LongTensor _to_csr(const int64_t* indices, int64_t dim, int64_t nnz) { - LongTensor csr = native::zeros({dim + 1}, kLong); - - // TODO: eliminate this conditional when zero-size dims supported correctly - if (nnz > 0) { - auto csr_accessor = csr.accessor(); - // Convert the sparse matrix to CSR format - at::parallel_for(0, nnz, 10000, [&](int64_t start, int64_t end) { - int64_t h, hp0, hp1; - for (auto i = start; i < end; i++) { - hp0 = indices[i]; - hp1 = (i+1 == nnz) ? dim : indices[i+1]; - if (hp0 != hp1) for (h = hp0; h < hp1; h++) { - csr_accessor[h+1] = i+1; - } - } - }); + + inline SparseTensor get_result_tensor_for_unary_op(const SparseTensor& input) { + if (c10::isIntegralType(input.scalar_type(), /*includeBool=*/true)) { + return at::empty_like(input, input.options().dtype(c10::get_default_dtype())); } - return csr; + return at::empty_like(input); } - } // -------------------------------------------------------------------- @@ -102,6 +88,10 @@ SparseTensor& mul_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, Scal SparseTensor& log1p_out_sparse(SparseTensor& r, const SparseTensor& t) { TORCH_CHECK(r.is_sparse(), "Tensor should be sparse"); TORCH_CHECK(t.is_sparse(), "Tensor should be sparse"); + TORCH_CHECK( + !c10::isIntegralType(r.scalar_type(), /*includeBool=*/true), + "log1p: result type cannot be Integral, got:", + r.scalar_type()); if (is_same_tensor(r, t)) { // don't have in-place log1p for uncoalesced input because coalesce() is not in-place @@ -114,6 +104,11 @@ SparseTensor& log1p_out_sparse(SparseTensor& r, const SparseTensor& t) { return r; } +SparseTensor log1p_sparse(const SparseTensor& t) { + auto result = get_result_tensor_for_unary_op(t); + return log1p_out_sparse(result, t); +} + SparseTensor& log1p_sparse_(SparseTensor& t) { return log1p_out_sparse(t, t); } @@ -147,6 +142,10 @@ SparseTensor& neg_sparse_(SparseTensor& t) { SparseTensor& asin_out_sparse(SparseTensor& r, const SparseTensor& t) { TORCH_CHECK(r.is_sparse(), "Tensor should be sparse"); TORCH_CHECK(t.is_sparse(), "Tensor should be sparse"); + TORCH_CHECK( + !c10::isIntegralType(r.scalar_type(), /*includeBool=*/true), + "asin: result type cannot be Integral, got:", + r.scalar_type()); if (is_same_tensor(r, t)) { // don't have in-place asin for uncoalesced input because coalesce() is not in-place, see above comment @@ -158,6 +157,11 @@ SparseTensor& asin_out_sparse(SparseTensor& r, const SparseTensor& t) { return r; } +SparseTensor asin_sparse(const SparseTensor& t) { + auto result = get_result_tensor_for_unary_op(t); + return asin_out_sparse(result, t); +} + SparseTensor& asin_sparse_(SparseTensor& t) { return asin_out_sparse(t, t); } @@ -217,7 +221,7 @@ static SparseTensor& coalesce_(SparseTensor& tensor) { // values=[1., 1.] (after truncation), which sum to 2.f instead of 3.f. // To perform floor division the sparse tensor must be coalesced first. -SparseTensor& div_out_sparse_zerodim(SparseTensor& r, const SparseTensor& t, const Tensor& value) { +SparseTensor& div_out_sparse_zerodim(const SparseTensor& t, const Tensor& value, SparseTensor& r) { TORCH_CHECK(value.dim() == 0, "Sparse division requires a scalar or ", "zero-dim dense tensor divisor (got shape ", value.sizes(), " for divisor)"); TORCH_CHECK(!value.is_sparse(), "Sparse division requires a scalar or ", @@ -255,15 +259,15 @@ Tensor div_sparse(const Tensor& self, const Tensor& value) { commonDtype = typeMetaToScalarType(at::get_default_dtype()); } Tensor result = at::empty({0}, self.options().dtype(commonDtype)); - return div_out_sparse_zerodim(result, self, value); + return div_out_sparse_zerodim(self, value, result); } Tensor& div_sparse_(Tensor& self, const Tensor& value) { - return div_out_sparse_zerodim(self, self, value); + return div_out_sparse_zerodim(self, value, self); } SparseTensor& div_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, Scalar value) { - return div_out_sparse_zerodim(r, t, wrapped_scalar_tensor(value)); + return div_out_sparse_zerodim(t, wrapped_scalar_tensor(value), r); } // -------------------------------------------------------------------- @@ -429,7 +433,7 @@ SparseTensor& add_out_sparse_contiguous(SparseTensor& r, const SparseTensor& t, bool coalesced = t.is_coalesced() && src.is_coalesced(); int64_t sparse_dim = src.sparse_dim(); - LongTensor r_indices = at::empty({src.sparse_dim(), max_nnz}, t._indices().options()); + Tensor r_indices = at::empty({src.sparse_dim(), max_nnz}, t._indices().options()); Tensor t_values = t._values().to(commonDtype); Tensor s_values = src._values().to(commonDtype); @@ -524,7 +528,7 @@ SparseTensor& add_out_sparse_non_contiguous(SparseTensor& r, const SparseTensor& } }); - LongTensor r_indices = at::cat({t._indices(), src._indices()}, 1); + Tensor r_indices = at::cat({t._indices(), src._indices()}, 1); Tensor r_values = at::cat({t_values, s_values}, 0).to(r.scalar_type()); alias_into_sparse(r, r_indices, r_values); @@ -540,7 +544,7 @@ SparseTensor& add_out_sparse_non_contiguous(SparseTensor& r, const SparseTensor& Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTensor& sparse_, Scalar value); -SparseTensor& add_out_sparse_cpu(SparseTensor& r, const SparseTensor& t, const SparseTensor& src, Scalar value) { +SparseTensor& add_out_sparse_cpu(const SparseTensor& t, const SparseTensor& src, Scalar value, SparseTensor& r) { if (!t.is_sparse()) { return add_out_dense_sparse_cpu(r, t, src, value); } @@ -616,7 +620,7 @@ Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTen r.resize_as_(dense); SparseTensor sparse = sparse_.coalesce(); - LongTensor indices = sparse._indices(); + Tensor indices = sparse._indices(); Tensor values = sparse._values(); int64_t nDim = dense.dim(); int64_t nDimI = sparse.sparse_dim(); @@ -646,7 +650,7 @@ Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTen dstBuffer.add_(srcBuffer, value); } } else { - AT_DISPATCH_ALL_TYPES( + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, commonDtype, "add_dense_sparse", [&] { add_dense_sparse_worker_cpu(resultBuffer, value, sparse, indices, valuesBuffer); }); @@ -697,9 +701,9 @@ SparseTensor& mul_out_sparse_cpu(SparseTensor& r, const Tensor& t_, const Tensor int64_t t_nnz = t._nnz(), s_nnz = src._nnz(); int64_t max_nnz = std::min(t_nnz, s_nnz); // multiply by zero is zero, and can be dropped int64_t sparse_dim = src.sparse_dim(); - LongTensor t_indices = t._indices(); - LongTensor src_indices = src._indices(); - LongTensor r_indices = at::empty({sparse_dim, max_nnz}, t_indices.options()); + Tensor t_indices = t._indices(); + Tensor src_indices = src._indices(); + Tensor r_indices = at::empty({sparse_dim, max_nnz}, t_indices.options()); int64_t match, d; int64_t r_i = 0, t_i = 0, s_i = 0; @@ -865,7 +869,7 @@ Tensor& s_addmm_out_sparse_dense_cpu( return r; } - LongTensor indices = sparse_._indices(); + Tensor indices = sparse_._indices(); Tensor values = sparse_._values(); AT_DISPATCH_ALL_TYPES( @@ -997,13 +1001,13 @@ SparseTensor& hspmm_out_sparse_cpu(SparseTensor& r, const SparseTensor& sparse_, return r; } - LongTensor indices = at::empty({1, nnz}, at::initialTensorOptions().dtype(kLong)); + Tensor indices = at::empty({1, nnz}, at::initialTensorOptions().dtype(kLong)); // Initialize the sparse matrix that will be used with spaddmm to send rows // from the dense matrix to rows of the output's value tensor SparseTensor newSparse = sparse.clone(); - LongTensor spIndices = newSparse._indices(); - LongTensor valueIndices = spIndices.select(0, 0); + Tensor spIndices = newSparse._indices(); + Tensor valueIndices = spIndices.select(0, 0); // Compute output indices auto valueIndices_accessor = valueIndices.accessor(); @@ -1084,18 +1088,19 @@ SparseTensor& _sspaddmm_out_cpu( "sspaddmm: Argument #1: Expected dim 1 size ", dim_k, ", got ", t.size(1)); int64_t nnz = sparse._nnz(); - LongTensor indices = sparse._indices(); + // We have to make indices contiguous as we use indices.data_ptr in _to_csr which assumes row-contiguous storage + Tensor indices = sparse._indices().contiguous(); Tensor values = sparse._values(); - LongTensor csr = _to_csr(indices.data_ptr(), dim_i, nnz); + Tensor csr = coo_to_csr(indices.data_ptr(), dim_i, nnz); int64_t t_nnz = t._nnz(); int64_t r_nnz = nnz * dim_k + t_nnz; - LongTensor newi = at::empty({2, r_nnz}, kLong); - LongTensor newv = native::zeros({r_nnz}, values.options()); + Tensor newi = at::empty({2, r_nnz}, kLong); + Tensor newv = native::zeros({r_nnz}, values.options()); if (t_nnz != 0) { - LongTensor narrowi = newi.narrow(1, 0, t_nnz); + Tensor narrowi = newi.narrow(1, 0, t_nnz); Tensor narrowv = newv.narrow(0, 0, t_nnz); narrowi.copy_(t._indices()); @@ -1205,7 +1210,7 @@ Tensor _sparse_sum(const SparseTensor& input, IntArrayRef dims_to_sum) { auto dims_to_sum_v = dims_to_sum.vec(); maybe_wrap_dims(dims_to_sum_v, input_dim); - LongTensor indices = input._indices(); + Tensor indices = input._indices(); Tensor values = input._values(); IntArrayRef sizes = input.sizes(); const int64_t sparse_dim = input.sparse_dim(); @@ -1241,7 +1246,7 @@ Tensor _sparse_sum(const SparseTensor& input, IntArrayRef dims_to_sum) { } else { // !sum_all_sparse_dim // new indices - LongTensor new_indices; + Tensor new_indices; if (sparse_dims_to_sum_size == 0) { new_indices = indices.clone(at::MemoryFormat::Contiguous); } @@ -1323,7 +1328,7 @@ Tensor _sparse_sum_backward_cpu(const Tensor& grad_, const SparseTensor& input_, auto dims_to_sum_v = dims_to_sum.vec(); maybe_wrap_dims(dims_to_sum_v, input_dim); - LongTensor input_indices = input._indices(); + Tensor input_indices = input._indices(); Tensor input_values = input._values(); IntArrayRef input_sizes = input.sizes(); const int64_t input_sparse_dim = input.sparse_dim(); @@ -1364,7 +1369,7 @@ Tensor _sparse_sum_backward_cpu(const Tensor& grad_, const SparseTensor& input_, else { TORCH_CHECK(grad_.is_sparse(), "_sparse_sum_backward_cpu: expected grad_ Tensor to be sparse, but got dense"); auto grad = grad_.coalesce(); - LongTensor grad_indices = grad._indices(); + Tensor grad_indices = grad._indices(); Tensor grad_values = grad._values(); const int64_t grad_sparse_dim = grad.sparse_dim(); const int64_t grad_nnz = grad._nnz(); @@ -1508,12 +1513,12 @@ Tensor& bmm_out_sparse_cpu(Tensor& result, const SparseTensor& self, const Tenso SparseTensor self_coalesced = self.coalesce(); int64_t nnz = self_coalesced._nnz(); - LongTensor indices = self_coalesced._indices(); + Tensor indices = self_coalesced._indices(); Tensor values = self_coalesced._values(); - LongTensor indices_dim0 = indices[0]; + Tensor indices_dim0 = indices[0]; auto indices_dim0_accessor = indices_dim0.accessor(); - LongTensor indices_dim1_dim2 = indices.slice(0, 1, 3); + Tensor indices_dim1_dim2 = indices.slice(0, 1, 3); int64_t dim_i = self_coalesced.size(1); int64_t dim_j = self_coalesced.size(2); @@ -1563,7 +1568,7 @@ Tensor& bmm_out_sparse_cpu(Tensor& result, const SparseTensor& self, const Tenso // Create tensors to view just the current set of matrices const Tensor dense_matrix = mat2[cur_mat_num]; Tensor result_matrix = result[cur_mat_num]; - LongTensor sparse_indices = indices_dim1_dim2.slice(1, mat_el_begin_idx, mat_el_end_idx); + Tensor sparse_indices = indices_dim1_dim2.slice(1, mat_el_begin_idx, mat_el_end_idx); Tensor sparse_values = values.slice(0, mat_el_begin_idx, mat_el_end_idx); int64_t sparse_nnz = mat_el_end_idx - mat_el_begin_idx; diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp index 6328d7df88157..1a3650e6880af 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp @@ -22,7 +22,7 @@ SparseTensor& sparse_mask_out_cuda(SparseTensor& r, const Tensor& t, const Spars if (mask._nnz() == 0) { return r.zero_(); } - LongTensor mask_indices = mask._indices(); + Tensor mask_indices = mask._indices(); Tensor mask_values = mask._values(); Tensor r_values = at::empty(mask_values.sizes(), r._values().options()); alias_into_sparse(r, mask_indices.clone(at::MemoryFormat::Contiguous), r_values); @@ -33,7 +33,7 @@ SparseTensor& sparse_mask_out_cuda(SparseTensor& r, const Tensor& t, const Spars // Get a flattened sparse indices, similar to NOTE [ Flatten Sparse Indices ]. // Keeping this implementation because it is faster than flatten_indices() - LongTensor indices = at::zeros({mask._nnz()}, mask_indices.options()); + Tensor indices = at::zeros({mask._nnz()}, mask_indices.options()); for (int64_t d = 0; d < mask.sparse_dim(); d++) { indices.mul_(mask.size(d)); // This used to use a buffer but I deoptimized it diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu index 5d25138500d7d..e08e06b18be50 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu @@ -52,10 +52,10 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) { // indices will be modified by Thrust, so we have to clone or use new storage // here. - LongTensor indices1D = flatten_indices(self._indices(), self.sizes(), true); + Tensor indices1D = flatten_indices(self._indices(), self.sizes(), true); - LongTensor origIndices = at::empty({nnz}, self._indices().options()); - LongTensor uniqueOffsets = at::empty({nnz}, self._indices().options()); + Tensor origIndices = at::empty({nnz}, self._indices().options()); + Tensor uniqueOffsets = at::empty({nnz}, self._indices().options()); typedef thrust::device_ptr thrust_ptr; thrust_ptr indicesIter(indices1D.data_ptr()); @@ -96,18 +96,16 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) { dim3 block(C10_WARP_SIZE, SZ); AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, values.scalar_type(), "coalesce_sparse_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "coalesce_sparse_cuda", [&] { - using cuda_accscalar_t = acc_type; - apply::coalesceValuesKernel<<>>( - uniqueOffsets.data_ptr(), - origIndices.data_ptr(), - values.data_ptr(), - newValues.data_ptr(), - nnz, - newNnz, - stride - ); - }); + using cuda_accscalar_t = acc_type; + apply::coalesceValuesKernel<<>>( + uniqueOffsets.data_ptr(), + origIndices.data_ptr(), + values.data_ptr(), + newValues.data_ptr(), + nnz, + newNnz, + stride + ); }); } @@ -128,14 +126,14 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) { //////////////////////////////////////////////////////////// // unflatten indices if necessary - LongTensor newIndices; + Tensor newIndices; if (sparse_dim == 1) { newIndices = indices1D; } else { newIndices = at::empty({sparse_dim, newNnz}, origIndices.options()); for (int64_t d = sparse_dim - 1; d >= 0; d--) { // NB: Not a select, so I can preserve the outer dimension - LongTensor indicesSlice = newIndices.narrow(0, d, 1); + Tensor indicesSlice = newIndices.narrow(0, d, 1); // Note for the porting guide: THCTensor_(copy) does NOT do normal // broadcasting logic; instead, it will blast the elements from one // to the other so long as the numel is the same diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu index 81058ec266f2c..fce3446816e7e 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu @@ -39,9 +39,9 @@ using at::cuda::detail::getTensorInfo; // -------------------------------------------------------------------- namespace { - IntTensor _to_csr_int(const LongTensor& rowIndices, int64_t dim, int64_t nnz) { - IntTensor csr = at::empty({dim+1}, CUDA(kInt)); - IntTensor rowIndicesInt = at::empty({rowIndices.size(0)}, CUDA(kInt)); + Tensor _to_csr_int(const Tensor& rowIndices, int64_t dim, int64_t nnz) { + Tensor csr = at::empty({dim+1}, CUDA(kInt)); + Tensor rowIndicesInt = at::empty({rowIndices.size(0)}, CUDA(kInt)); rowIndicesInt.copy_(rowIndices); sparse::cuda::Xcoo2csr(rowIndicesInt.data_ptr(), nnz, dim, csr.data_ptr()); return csr; @@ -52,13 +52,13 @@ namespace { // wired at all) template -void s_addmm_out_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int64_t k, Tensor& r_, Scalar beta, const Tensor& t, Scalar alpha, LongTensor& indices, Tensor& values, const Tensor& dense) { +void s_addmm_out_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int64_t k, Tensor& r_, Scalar beta, const Tensor& t, Scalar alpha, Tensor& indices, Tensor& values, const Tensor& dense) { scalar_t cast_beta = beta.to(); scalar_t cast_alpha = alpha.to(); - LongTensor rowIndices = indices.select(0, 0); - LongTensor colIndices = indices.select(0, 1); - IntTensor csr = _to_csr_int(rowIndices, m, nnz); - IntTensor colIndicesInt = at::empty({colIndices.size(0)}, indices.options().dtype(kInt)); + Tensor rowIndices = indices.select(0, 0); + Tensor colIndices = indices.select(0, 1); + Tensor csr = _to_csr_int(rowIndices, m, nnz); + Tensor colIndicesInt = at::empty({colIndices.size(0)}, indices.options().dtype(kInt)); colIndicesInt.copy_(colIndices); Tensor r__; @@ -147,7 +147,7 @@ Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseT SparseTensor sparse = sparse_.coalesce(); int64_t nnz = sparse._nnz(); - LongTensor indices = sparse._indices(); + Tensor indices = sparse._indices(); Tensor values = sparse._values(); @@ -247,7 +247,7 @@ SparseTensor& hspmm_out_sparse_cuda(SparseTensor& r_, const SparseTensor& sparse int64_t nnz = sparse._nnz(); - LongTensor indices = at::empty({1, nnz}, CUDA(kLong)); + Tensor indices = at::empty({1, nnz}, CUDA(kLong)); // create values in column-major format to avoid copying in spaddmm Tensor values = at::empty({n, nnz}, dense.options()); values.transpose_(0, 1); @@ -255,8 +255,8 @@ SparseTensor& hspmm_out_sparse_cuda(SparseTensor& r_, const SparseTensor& sparse // why does sparse need to be cloned? If this is really necessary maybe we // need to fuse this with newCoalesce SparseTensor newSparse = sparse.clone(); - LongTensor spIndices = newSparse._indices(); - LongTensor dstIndices = spIndices.select(0, 0); + Tensor spIndices = newSparse._indices(); + Tensor dstIndices = spIndices.select(0, 0); // Save destination indices to output hybrid tensor indices.copy_(dstIndices); // Replace destination indices with 0, 1, 2, 3, ... and compute output values @@ -320,7 +320,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT r.copy_(dense_buffer); } - LongTensor indices = sparse._indices(); + Tensor indices = sparse._indices(); int64_t nDim = dense.dim(); int64_t nDimI = sparse.sparse_dim(); @@ -338,15 +338,13 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT if (sparse.dense_dim() == 0) { TORCH_CHECK(cuda::getApplyGrid(nnz, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions"); - AT_DISPATCH_ALL_TYPES_AND2( - at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_dense_sparse_cuda", [&] { - apply::sparseElementwiseKernelScalar, uint64_t, scalar_t> - <<>>( - TensorCAddOp(value.to()), - V_INFO(r), I_INFO(indices), V_INFO(values), - static_cast(nnz)); - }); + AT_DISPATCH_ALL_TYPES_AND3( + at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] { + apply::sparseElementwiseKernelScalar, uint64_t, scalar_t> + <<>>( + TensorCAddOp(value.to()), + V_INFO(r), I_INFO(indices), V_INFO(values), + static_cast(nnz)); }); } else { TORCH_CHECK(cuda::getApplyGrid(nnz * block.x, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions"); @@ -356,28 +354,24 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_dense_sparse_cuda", [&] { - apply::sparseElementwiseKernel, uint64_t, scalar_t> - <<>>( - TensorCAddOp(value.to()), - V_INFO(r), I_INFO(indices), V_INFO(values), - static_cast(nnz)); - }); + apply::sparseElementwiseKernel, uint64_t, scalar_t> + <<>>( + TensorCAddOp(value.to()), + V_INFO(r), I_INFO(indices), V_INFO(values), + static_cast(nnz)); }); } } else { - LongTensor indices1D = flatten_indices(indices, sparse.sizes(), 0); + Tensor indices1D = flatten_indices(indices, sparse.sizes(), 0); // FIXME: at some point we can wrap the scale into indexAdd // NB: Purposely not inplace! AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_dense_sparse_cuda", [&] { - if (value.to() != static_cast(1)) { - values = values.mul(value); - } - }); + if (value.to() != static_cast(1)) { + values = values.mul(value); + } }); int64_t view_rows = 1; @@ -405,7 +399,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT Tensor& add_out_dense_sparse_cuda(Tensor& r, const Tensor& dense, const SparseTensor& sparse_, Scalar value); -SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const SparseTensor& src, Scalar value) { +SparseTensor& add_out_sparse_cuda(const SparseTensor& t, const SparseTensor& src, Scalar value, SparseTensor& r_) { if (!t.is_sparse()) { return add_out_dense_sparse_cuda(r_, t, src, value); } @@ -437,21 +431,19 @@ SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const // rather than merging them. This removes the need to synchronously fetch nnz // at the end of the operation, at the cost of having a non-coalesced result. // This trade-off is preferable for the common use-case of gradient accumulation. - LongTensor t_indices_ = t._indices(); - LongTensor s_indices_ = src._indices(); + Tensor t_indices_ = t._indices(); + Tensor s_indices_ = src._indices(); Tensor t_values_ = t._values().to(commonDtype); Tensor s_values_ = src._values().to(commonDtype); AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_sparse_cuda", [&] { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_sparse_cuda", [&] { - if (value.to() != static_cast(1)) { - s_values_ = s_values_.mul(value); - } - }); + if (value.to() != static_cast(1)) { + s_values_ = s_values_.mul(value); + } }); - LongTensor r_indices_ = at::cat({t_indices_, s_indices_}, 1); + Tensor r_indices_ = at::cat({t_indices_, s_indices_}, 1); Tensor r_values_ = at::cat({t_values_, s_values_}, 0); if (r_.scalar_type() != commonDtype) { @@ -509,11 +501,11 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons int64_t sparse_dim = src.sparse_dim(); auto commonDtype = at::result_type(t, src); TORCH_CHECK(canCast(commonDtype, r_.scalar_type()), "Can't convert result type ", commonDtype, " to output ", r_.scalar_type()); - LongTensor t_indices_ = t._indices().contiguous(); + Tensor t_indices_ = t._indices().contiguous(); Tensor t_values_ = t._values().to(commonDtype); - LongTensor s_indices_ = src._indices().contiguous(); + Tensor s_indices_ = src._indices().contiguous(); Tensor s_values_ = src._values().to(commonDtype); - LongTensor r_indices_ = at::empty({sparse_dim, max_nnz}, t_indices_.options()); + Tensor r_indices_ = at::empty({sparse_dim, max_nnz}, t_indices_.options()); r_.resize_as_(src); Tensor r_values_ = new_values_with_size_of(t_values_, max_nnz).zero_(); @@ -526,7 +518,7 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); TORCH_CHECK(cuda::getApplyGrid(valueSize, grid, curDevice), "mul: Argument #0: tensor too large or too many dimensions"); - LongTensor resultNnz = at::empty({1}, CUDA(kLong)); + Tensor resultNnz = at::empty({1}, CUDA(kLong)); AT_DISPATCH_ALL_TYPES_AND( at::ScalarType::Half, commonDtype, "mul_out_sparse_cuda", [&] { apply::valueSparseIntersectionKernel, uint64_t, scalar_t> @@ -549,7 +541,7 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons get_sparse_impl(r_)->set_indices_and_values_unsafe(r_indices_, r_values_); // sync! (surely there is a more idiomatic way to do this...) - LongTensor cpu_resultNnz = at::empty({1}, CPU(kLong)); + Tensor cpu_resultNnz = at::empty({1}, CPU(kLong)); cpu_resultNnz.copy_(resultNnz); get_sparse_impl(r_)->set_nnz_and_narrow(cpu_resultNnz.accessor()[0]); @@ -609,7 +601,7 @@ Tensor _sparse_sum_backward_cuda(const Tensor& grad_, const SparseTensor& input_ auto dims_to_sum_v = dims_to_sum.vec(); maybe_wrap_dims(dims_to_sum_v, input_dim); - LongTensor input_indices = input._indices(); + Tensor input_indices = input._indices(); Tensor input_values = input._values(); IntArrayRef input_sizes = input.sizes(); const int64_t input_sparse_dim = input.sparse_dim(); @@ -649,7 +641,7 @@ Tensor _sparse_sum_backward_cuda(const Tensor& grad_, const SparseTensor& input_ else { TORCH_CHECK(grad_.is_sparse(), "_sparse_sum_backward_cuda: expected grad_ Tensor to be sparse, but got dense"); auto grad = grad_.coalesce(); - LongTensor grad_indices = grad._indices(); + Tensor grad_indices = grad._indices(); Tensor grad_values = grad._values(); const int64_t grad_sparse_dim = grad.sparse_dim(); const int64_t grad_nnz = grad._nnz(); @@ -687,7 +679,7 @@ Tensor _sparse_sum_backward_cuda(const Tensor& grad_, const SparseTensor& input_ thrust_ptr input_indices_iter(input_indices_1D.data_ptr()); // store lower_bound of input indices at grad indices - LongTensor input_indices_pos = at::empty_like(input_indices_1D, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + Tensor input_indices_pos = at::empty_like(input_indices_1D, LEGACY_CONTIGUOUS_MEMORY_FORMAT); thrust_ptr input_indices_pos_iter(input_indices_pos.data_ptr()); thrust::lower_bound(policy, grad_indices_iter, grad_indices_iter + grad_nnz, @@ -775,7 +767,7 @@ __global__ void search_end_matrix_indices_cuda_kernel( // Search through a 1D tensor of sorted sparse matrix // indices to find the end index for each matrix -void search_end_matrix_indices(int64_t* mat_el_end_indices, int64_t num_matrices, const LongTensor& indices_1D) { +void search_end_matrix_indices(int64_t* mat_el_end_indices, int64_t num_matrices, const Tensor& indices_1D) { int curDevice = -1; cudaGetDevice(&curDevice); cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); @@ -863,10 +855,10 @@ Tensor& _bmm_out_sparse_cuda(Tensor& result, const SparseTensor& self, const Ten SparseTensor self_coalesced = coalesce_sparse_cuda(self); int64_t nnz = self_coalesced._nnz(); - LongTensor indices = self_coalesced._indices(); + Tensor indices = self_coalesced._indices(); Tensor values = self_coalesced._values(); - LongTensor indices_dim0 = indices[0]; + Tensor indices_dim0 = indices[0]; // Need to convert dim1 and dim2 indices to 32-bit since cusparseSpMM // only supports 32-bit indices diff --git a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu new file mode 100644 index 0000000000000..d7f98ee20b7b0 --- /dev/null +++ b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu @@ -0,0 +1,896 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + + +#if defined(__CUDACC__) && (CUSPARSE_VERSION >= 11000) +#define IS_CUSPARSE11_AVAILABLE() 1 +#else +#define IS_CUSPARSE11_AVAILABLE() 0 +#endif + +#if IS_CUSPARSE11_AVAILABLE() +#include +#endif + +namespace at { +namespace native { + +namespace { + +using namespace at::sparse; + +Tensor _to_csr_int(const Tensor& rowIndices, int64_t dim, int64_t nnz) { + Tensor csr = at::empty({dim + 1}, CUDA(kInt)); + Tensor rowIndicesInt = at::empty({rowIndices.size(0)}, CUDA(kInt)); + rowIndicesInt.copy_(rowIndices); + sparse::cuda::Xcoo2csr( + rowIndicesInt.data_ptr(), nnz, dim, csr.data_ptr()); + return csr; +} + +int confirm_mult_size(const std::vector& mat1_size, const std::vector& mat2_size) { + TORCH_CHECK( + mat1_size[1] == mat2_size[0], + "mat1 and mat2 shapes cannot be multiplied (", + mat1_size[0], + "x", + mat1_size[1], + " and ", + mat2_size[0], + "x", + mat2_size[1], + ")"); + return mat1_size[1]; +} + +void create_general_description_(cusparseMatDescr_t& description_) { + TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&description_)); + TORCH_CUDASPARSE_CHECK(cusparseSetMatType(description_, CUSPARSE_MATRIX_TYPE_GENERAL)); + TORCH_CUDASPARSE_CHECK(cusparseSetMatIndexBase(description_, CUSPARSE_INDEX_BASE_ZERO)); +} + +// csrMatrixRef is used to have a representation of a raw CSR matrix representation +// comming from `sparse_sparse_matmul_cuda_kernel` function. +// Moreover this implements a RAII guard for a cusparse descriptor +template +struct csrMatrixRef { + int* csr_indices_{nullptr}; + int* csr_pointers_{nullptr}; + scalar_t* csr_values_{nullptr}; + int nnz_{0}; + std::vector size_{}; + + #if IS_CUSPARSE11_AVAILABLE() + cusparseSpMatDescr_t description_{0}; + #else + cusparseMatDescr_t description_{0}; + #endif + + csrMatrixRef() { + #if !IS_CUSPARSE11_AVAILABLE() + create_general_description_(description_); + #endif + } + + csrMatrixRef( + int* csr_indices, + int* csr_pointers, + scalar_t* csr_values, + int nnz, + const std::vector& size) + : csr_indices_{csr_indices}, + csr_pointers_{csr_pointers}, + csr_values_{csr_values}, + nnz_{nnz}, + size_{size} { + #if IS_CUSPARSE11_AVAILABLE() + cudaDataType cuda_data_type; + if ( std::is_same::value ) { + cuda_data_type = CUDA_R_32F; + } else if ( std::is_same::value) { + cuda_data_type = CUDA_R_64F; + } else { + TORCH_CHECK(false, "Tensor types must be either float32 or float64"); + } + TORCH_CUDASPARSE_CHECK(cusparseCreateCsr( + &description_, + this->size(0), + this->size(1), + this->nnz_, + this->csr_pointers_, + this->csr_indices_, + this->csr_values_, + CUSPARSE_INDEX_32I, + CUSPARSE_INDEX_32I, + CUSPARSE_INDEX_BASE_ZERO, + cuda_data_type)); + #else + create_general_description_(description_); + #endif + } + + ~csrMatrixRef() { + #if IS_CUSPARSE11_AVAILABLE() + cusparseDestroySpMat(description_); + #else + cusparseDestroyMatDescr(description_); + #endif + } + + int size(int index) const { + return size_.at(index); + } +}; + +// csrOutput is used to represent the output for `CusparseMatrixMultiplyOp` +// Note that `csrOutput` is different from `csrMatrixRef` and the purpose +// of this was to have a materialized version of a CSR matrix. +// Moreover this implements a RAII guard for a cusparse descriptor +struct csrOutput { + Tensor csr_indices_{}; + Tensor csr_pointers_{}; + at::Tensor csr_values_{}; + int nnz_{0}; + std::vector size_; + + cusparseMatDescr_t description_{0}; + + csrOutput(const std::vector &size) : size_{size} { + create_general_description_(description_); + } + + ~csrOutput() { + cusparseDestroyMatDescr(description_); + } + + int size(int index) const { + return size_.at(index); + } +}; + +#if IS_CUSPARSE11_AVAILABLE() + +// RAII guard helps to support cuSparse 11 API for `A @ B` operation +// This generic template exists because with cuSparse the `scalar_t` type could be a double or float +template +struct CusparseMatrixMultiplyOp { + + cusparseSpGEMMDescr_t spgemmDesc; + + CusparseMatrixMultiplyOp() { + static_assert(std::is_same::value || std::is_same::value, + "cusparse csr sparse-sparse MM only supports data type of float and double."); + // SpGEMM Computation + TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&spgemmDesc)); + } + + ~CusparseMatrixMultiplyOp() { + // destroy matrix/vector descriptors + cusparseSpGEMM_destroyDescr(spgemmDesc); + } + + csrOutput operator ()( + const csrMatrixRef& A, + const csrMatrixRef& B, + Tensor& output_values, + Tensor& output_indices) { + const int A_num_rows = A.size(0); + const int A_num_cols = A.size(1); + const int A_num_nnz = A.nnz_; + + const int B_num_rows = B.size(0); + const int B_num_cols = B.size(1); + const int B_num_nnz = B.nnz_; + + int* dA_csrOffsets = A.csr_pointers_; + int* dA_columns = A.csr_indices_; + scalar_t* dA_values = A.csr_values_; + + int* dB_csrOffsets = B.csr_pointers_; + int* dB_columns = B.csr_indices_; + scalar_t* dB_values = B.csr_values_; + + cudaDataType computeType; + if ( std::is_same::value ) { + computeType = CUDA_R_32F; + } else if ( std::is_same::value) { + computeType = CUDA_R_64F; + } else { + TORCH_CHECK(false, "Tensor types must be either float32 or float64"); + } + csrOutput out({A.size(0), B.size(1)}); + + out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt)); + + int* dC_csrOffsets = out.csr_pointers_.data_ptr(); + int* dC_columns = nullptr; + scalar_t* dC_values = nullptr; + + scalar_t alpha = 1.0f; + scalar_t beta = 0.0f; + cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE; + cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE; + + csrMatrixRef C( + nullptr, + nullptr, + nullptr, + /*nnz*/0, + {A_num_rows, B_num_cols} + ); + + //-------------------------------------------------------------------------- + // CUSPARSE APIs + cusparseHandle_t handle = at::cuda::getCurrentCUDASparseHandle(); + void *dBuffer1 = NULL, *dBuffer2 = NULL; + size_t bufferSize1 = 0, bufferSize2 = 0; + + cusparseSpMatDescr_t matA = A.description_; + cusparseSpMatDescr_t matB = B.description_; + cusparseSpMatDescr_t matC = C.description_; + //-------------------------------------------------------------------------- + + // ask bufferSize1 bytes for external memory + TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_workEstimation( + handle, + opA, + opB, + &alpha, + matA, + matB, + &beta, + matC, + computeType, + CUSPARSE_SPGEMM_DEFAULT, + spgemmDesc, + &bufferSize1, + NULL)); + + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + + at::DataPtr dataPtr1 = allocator.allocate(bufferSize1); + dBuffer1 = dataPtr1.get(); + // inspect the matrices A and B to understand the memory requiremnent for + // the next step + TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_workEstimation( + handle, + opA, + opB, + &alpha, + matA, + matB, + &beta, + matC, + computeType, + CUSPARSE_SPGEMM_DEFAULT, + spgemmDesc, + &bufferSize1, + dBuffer1)); + + // ask bufferSize2 bytes for external memory + TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_compute( + handle, + opA, + opB, + &alpha, + matA, + matB, + &beta, + matC, + computeType, + CUSPARSE_SPGEMM_DEFAULT, + spgemmDesc, + &bufferSize2, + NULL)); + + at::DataPtr dataPtr2 = allocator.allocate(bufferSize2); + dBuffer2 = dataPtr2.get(); + + // compute the intermediate product of A * B + TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_compute( + handle, + opA, + opB, + &alpha, + matA, + matB, + &beta, + matC, + computeType, + CUSPARSE_SPGEMM_DEFAULT, + spgemmDesc, + &bufferSize2, + dBuffer2)); + // get matrix C non-zero entries C_num_nnz1 + int64_t C_num_rows1, C_num_cols1, C_num_nnz1; + TORCH_CUDASPARSE_CHECK( + cusparseSpMatGetSize(matC, &C_num_rows1, &C_num_cols1, &C_num_nnz1)); + // allocate matrix C + // allocate C offsets + out.nnz_ = C_num_nnz1; + + out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt)); + out.csr_values_ = at::empty({out.nnz_}, output_values.options()); + dC_columns = out.csr_indices_.data_ptr(); + dC_values = out.csr_values_.data_ptr(); + + // update matC with the new pointers + TORCH_CUDASPARSE_CHECK( + cusparseCsrSetPointers(matC, dC_csrOffsets, dC_columns, dC_values)); + + // copy the final products to the matrix C + TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_copy( + handle, + opA, + opB, + &alpha, + matA, + matB, + &beta, + matC, + computeType, + CUSPARSE_SPGEMM_DEFAULT, + spgemmDesc)); + return out; + } +}; + + +template struct CusparseMatrixMultiplyOp; + +template struct CusparseMatrixMultiplyOp; + +#else // if not IS_CUSPARSE11_AVAILABLE() + +using DcsrMatrixRef = csrMatrixRef; +using ScsrMatrixRef = csrMatrixRef; + +// RAII guard helps to support cuSparse 10 API for `A @ B` operation +// This generic template exists because with cuSparse the `scalar_t` type could be a double or float +template +struct CusparseMatrixMultiplyOp { + csrOutput operator()( + const csrMatrixRef& lhs, + const csrMatrixRef& rhs, + Tensor &output_values, + Tensor &output_indices) + { + TORCH_INTERNAL_ASSERT(false, "cusparse csr sparse-sparse MM only supports data type of float and double."); + } +}; + +// Specializacion for `A @ B` operation for double values with cuSparse +template<> struct CusparseMatrixMultiplyOp { + csrgemm2Info_t gemm2Info_; + + CusparseMatrixMultiplyOp() { + TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_)); + } + ~CusparseMatrixMultiplyOp() { + cusparseDestroyCsrgemm2Info(gemm2Info_); + } + + csrOutput operator ()( + const DcsrMatrixRef& lhs, + const DcsrMatrixRef& rhs, + Tensor &output_values, + Tensor &output_indices) { + double alpha = 1.0; + DcsrMatrixRef empty; + return Dgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices); + } + + csrOutput Dgemm2( + const DcsrMatrixRef& A, + const DcsrMatrixRef& B, + const DcsrMatrixRef& C, + const double* alpha, + const double* beta, + Tensor &output_values, + Tensor &output_indices) { + void* buffer_{nullptr}; + cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle(); + TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST)); + + csrOutput out({A.size(0), B.size(1)}); + int innerSize = confirm_mult_size(A.size_, B.size_); + out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt)); + + // Compute needed buffer size + size_t new_bubber_sz; + TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2_bufferSizeExt( + cusparseHandle_, + out.size(0), + out.size(1), + innerSize, + alpha, + A.description_, + A.nnz_, + A.csr_pointers_, + A.csr_indices_, + B.description_, + B.nnz_, + B.csr_pointers_, + B.csr_indices_, + beta, + C.description_, + C.nnz_, + C.csr_pointers_, + C.csr_indices_, + gemm2Info_, + &new_bubber_sz)); + + // (Re)allocate buffer if needed + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + at::DataPtr data_ptr = allocator.allocate(new_bubber_sz); + buffer_ = data_ptr.get(); + + // Find the resulting non-zero pattern. + TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz( + cusparseHandle_, + out.size(0), + out.size(1), + innerSize, + A.description_, + A.nnz_, + A.csr_pointers_, + A.csr_indices_, + B.description_, + B.nnz_, + B.csr_pointers_, + B.csr_indices_, + C.description_, + C.nnz_, + C.csr_pointers_, + C.csr_indices_, + out.description_, + out.csr_pointers_.data_ptr(), + &out.nnz_, + gemm2Info_, + buffer_)); + + out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt)); + out.csr_values_ = at::empty({out.nnz_}, output_values.options()); + + // Perform the gemm2 operation for doubles + // out = alpha ∗ A ∗ B + beta ∗ C + TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2( + cusparseHandle_, + out.size(0), + out.size(1), + innerSize, + alpha, + A.description_, + A.nnz_, + A.csr_values_, + A.csr_pointers_, + A.csr_indices_, + B.description_, + B.nnz_, + B.csr_values_, + B.csr_pointers_, + B.csr_indices_, + beta, + C.description_, + C.nnz_, + C.csr_values_, + C.csr_pointers_, + C.csr_indices_, + out.description_, + out.csr_values_.data_ptr(), + out.csr_pointers_.data_ptr(), + out.csr_indices_.data_ptr(), + gemm2Info_, + buffer_)); + return out; + } +}; + +// Specializacion for `A @ B` operation for float values with cuSparse +template<> struct CusparseMatrixMultiplyOp { + csrgemm2Info_t gemm2Info_; + + CusparseMatrixMultiplyOp() { + TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_)); + + } + ~CusparseMatrixMultiplyOp() { + cusparseDestroyCsrgemm2Info(gemm2Info_); + } + csrOutput operator()( + const ScsrMatrixRef& lhs, + const ScsrMatrixRef& rhs, + Tensor &output_values, + Tensor &output_indices) { + float alpha = 1.0; + ScsrMatrixRef empty; + return Sgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices); + } + + csrOutput Sgemm2( + const ScsrMatrixRef& A, + const ScsrMatrixRef& B, + const ScsrMatrixRef& C, + const float* alpha, + const float* beta, + Tensor &output_values, + Tensor &output_indices) { + void* buffer_{nullptr}; + cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle(); + TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST)); + + csrOutput out({A.size(0), B.size(1)}); + + int innerSize = confirm_mult_size(A.size_, B.size_); + + out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt)); + + // Compute needed buffer size + size_t new_bubber_sz; + TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2_bufferSizeExt( + cusparseHandle_, + out.size(0), + out.size(1), + innerSize, + alpha, + A.description_, + A.nnz_, + A.csr_pointers_, + A.csr_indices_, + B.description_, + B.nnz_, + B.csr_pointers_, + B.csr_indices_, + beta, + C.description_, + C.nnz_, + C.csr_pointers_, + C.csr_indices_, + gemm2Info_, + &new_bubber_sz)); + + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + at::DataPtr data_ptr = allocator.allocate(new_bubber_sz); + buffer_ = data_ptr.get(); + + // Find the resulting non-zero pattern. + TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz( + cusparseHandle_, + out.size(0), + out.size(1), + innerSize, + A.description_, + A.nnz_, + A.csr_pointers_, + A.csr_indices_, + B.description_, + B.nnz_, + B.csr_pointers_, + B.csr_indices_, + C.description_, + C.nnz_, + C.csr_pointers_, + C.csr_indices_, + out.description_, + out.csr_pointers_.data_ptr(), + &out.nnz_, + gemm2Info_, + buffer_)); + + out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt)); + out.csr_values_ = at::empty({out.nnz_}, output_values.options()); + + // Perform the gemm2 operation for doubles + // out = alpha ∗ A ∗ B + beta ∗ C + TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2( + cusparseHandle_, + out.size(0), + out.size(1), + innerSize, + alpha, + A.description_, + A.nnz_, + A.csr_values_, + A.csr_pointers_, + A.csr_indices_, + B.description_, + B.nnz_, + B.csr_values_, + B.csr_pointers_, + B.csr_indices_, + beta, + C.description_, + C.nnz_, + C.csr_values_, + C.csr_pointers_, + C.csr_indices_, + out.description_, + out.csr_values_.data_ptr(), + out.csr_pointers_.data_ptr(), + out.csr_indices_.data_ptr(), + gemm2Info_, + buffer_)); + return out; + } +}; + + + +#endif // IS_CUSPARSE11_AVAILABLE() + +template +void sparse_sparse_matmul_cuda_kernel( + Tensor& result, + const Tensor& mat1, + const Tensor& mat2) { + + static_assert(std::is_same::value || std::is_same::value, + "sparse_sparse_matmul_cuda_kernel only supports float and double value types"); + + Tensor mat1_indices_ = mat1._indices().contiguous(); + Tensor mat1_values = mat1._values().contiguous(); + + Tensor mat1_row_indices = mat1_indices_.select(0, 0); + Tensor mat1_col_indices = mat1_indices_.select(0, 1); + + Tensor mat1_indptr = _to_csr_int(mat1_row_indices, mat1.size(0), mat1._nnz()); + + Tensor mat1_indices = at::empty( + {mat1_col_indices.size(0)}, mat1_col_indices.options().dtype(kInt)); + + mat1_indices.copy_(mat1_col_indices); + + Tensor mat2_indices_ = mat2._indices().contiguous(); + Tensor mat2_values = mat2._values().contiguous(); + Tensor mat2_row_indices = mat2_indices_.select(0, 0); + Tensor mat2_col_indices = mat2_indices_.select(0, 1); + + Tensor mat2_indptr = _to_csr_int(mat2_row_indices, mat2.size(0), mat2._nnz()); + Tensor mat2_indices = at::empty({mat2_col_indices.size(0)}, mat2_col_indices.options().dtype(kInt)); + mat2_indices.copy_(mat2_col_indices); + + auto m = mat1.size(0); + auto k1 = mat1.size(1); + + auto k2 = mat2.size(0); + auto n = mat2.size(1); + TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k1 <= INT_MAX), + "At the moment, cusparseDcsrgemm2 only supports m, n, k, nnz with the bound [val] <= ", INT_MAX, ".", + "If you need this, please file an issue on GitHub." + ); + auto output_indices = result._indices(); + auto output_values = result._values(); + + if ((k1 == 0 && k2 == 0) || (n == 0 && m == 0)) { + output_indices.zero_(); + output_values.zero_(); + return; + } + + csrMatrixRef csr_mat1( + mat1_indices.data_ptr(), + mat1_indptr.data_ptr(), + mat1_values.data_ptr(), + (int)mat1._nnz(), + {(int)mat1.size(0), (int)mat1.size(1)}); + + csrMatrixRef csr_mat2( + mat2_indices.data_ptr(), + mat2_indptr.data_ptr(), + mat2_values.data_ptr(), + (int)mat2._nnz(), + {(int)mat2.size(0), (int)mat2.size(1)}); + + // Sparse matrix multiplication + CusparseMatrixMultiplyOp op; + csrOutput csr_output = op(csr_mat1, csr_mat2, output_values, output_indices); + auto nnz = csr_output.nnz_; + + output_values.set_(csr_output.csr_values_); + output_indices.resize_({2, nnz}); + auto output_indices_accessor = output_indices.packed_accessor(); + + auto csr_output_pointers_accessor = + csr_output.csr_pointers_.packed_accessor(); + + auto csr_output_ind_accessor = + csr_output.csr_indices_.packed_accessor(); + + auto major_dim = result.size(0); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); + auto policy = thrust::cuda::par(allocator).on(stream); + + // Filling the COO row indices + thrust::for_each( + policy, + thrust::make_counting_iterator(int64_t(0)), + thrust::make_counting_iterator(int64_t(major_dim)), + [output_indices_accessor, + csr_output_pointers_accessor, + major_dim, + nnz] __device__(int64_t i) { + auto Ap = csr_output_pointers_accessor.data(); + int64_t* indices_row = output_indices_accessor[0].data(); + + for (int jj = Ap[i]; jj < Ap[i + 1]; jj++) { + indices_row[jj] = i; + } + }); + + // Filling the COO column indices + thrust::for_each( + policy, + thrust::make_counting_iterator(int64_t(0)), + thrust::make_counting_iterator(int64_t(csr_output.nnz_)), + [output_indices_accessor, + csr_output_pointers_accessor, + csr_output_ind_accessor, + major_dim, + nnz] __device__(int64_t i) { + int64_t* indices_col = output_indices_accessor[1].data(); + indices_col[i] = csr_output_ind_accessor[i]; + }); +} + +} // end anonymous namespace + +Tensor sparse_matrix_mask_helper_cuda( + const SparseTensor& t, + const Tensor& mask_indices +) { + /* + This is a helper function which filter values from `t._values()` using the `mask_indices`. + This CUDA implementation uses `thrust::set_intersection_by_key` operation to find the intersection + of the `mask_indices` and the `t._indices()` to then filter the values. + + Inputs: + `t` - tensor input + `mask_indices` - mask indices tensor + */ + int64_t r_nnz = mask_indices.size(1); + auto t_v = t._values().contiguous(); + + Tensor r_values = at::zeros({r_nnz}, t_v.options()); + + auto t_i = t._indices().contiguous(); + auto t_indices_accessor = t_i.packed_accessor(); + auto t_nnz = t._nnz(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); + auto policy = thrust::cuda::par(allocator).on(stream); + + Tensor t_flatten_indices = at::empty({t_nnz}, mask_indices.options()); + auto t_flatten_indices_accessor = t_flatten_indices.packed_accessor(); + auto t_n_cols = t.size(1); + + // Step 1: flatten the sparse indices `t._indices()` tensor into a 1D indices tensor `t_flatten_indices`. + thrust::for_each( + policy, + thrust::make_counting_iterator(int64_t(0)), + thrust::make_counting_iterator(int64_t(t_nnz)), + [t_indices_accessor, t_flatten_indices_accessor, t_n_cols] __device__ (int64_t i) mutable { + auto index = t_indices_accessor[0][i] * t_n_cols + t_indices_accessor[1][i]; + t_flatten_indices_accessor[i] = index; + }); + + Tensor mask_flatten_indices = at::empty({r_nnz}, mask_indices.options()); + auto mask_flatten_indices_accessor = mask_flatten_indices.packed_accessor(); + auto mask_indices_accessor = mask_indices.packed_accessor(); + + // Step 2: flatten the sparse indices `mask_indices` tensor into a 1D indices tensor `mask_flatten_indices`. + thrust::for_each( + policy, + thrust::make_counting_iterator(int64_t(0)), + thrust::make_counting_iterator(int64_t(r_nnz)), + [mask_flatten_indices_accessor, mask_indices_accessor, t_n_cols] __device__ (int64_t i) mutable { + auto index = mask_indices_accessor[0][i] * t_n_cols + mask_indices_accessor[1][i]; + mask_flatten_indices_accessor[i] = index; + }); + auto max_sz = std::max(r_nnz, t_nnz); + Tensor t_index_set = at::empty({max_sz}, mask_indices.options()); + + // Step 3: find the intersection between `t_flatten_indices` and `mask_flatten_indices` indices. + // Note: the original positions from `t_flatten_indices` are stored in `t_index_set` + auto result_end = thrust::set_intersection_by_key( + policy, + t_flatten_indices.data_ptr(), + t_flatten_indices.data_ptr() + t_nnz, + mask_flatten_indices.data_ptr(), + mask_flatten_indices.data_ptr() + r_nnz, + thrust::make_counting_iterator(int64_t(0)), + thrust::make_discard_iterator(), + t_index_set.data_ptr()); + + // new_sz is the size of the intersection of the `mask_indices` and the `t._indices()` + auto new_sz = thrust::distance(t_index_set.data_ptr(), result_end.second); + + Tensor mask_index_set = at::empty({max_sz}, mask_indices.options()); + + // Step 4: Repeat the intersection operation between `mask_flatten_indices` and `t_flatten_indices` indices. + // But now store the positions from `mask_flatten_indices` in `mask_index_set` + thrust::set_intersection_by_key( + policy, + mask_flatten_indices.data_ptr(), + mask_flatten_indices.data_ptr() + r_nnz, + t_flatten_indices.data_ptr(), + t_flatten_indices.data_ptr() + t_nnz, + thrust::make_counting_iterator(int64_t(0)), + thrust::make_discard_iterator(), + mask_index_set.data_ptr()); + + // Step 5: Filter `t._values()` values by using `mask_index_set` and `t_index_set` + AT_DISPATCH_FLOATING_TYPES(r_values.scalar_type(), "_sparse_matrix_mask", [&] { + auto r_values_accessor = r_values.packed_accessor(); + auto t_values = t_v.packed_accessor(); + auto mask_index_set_ptr = mask_index_set.packed_accessor(); + auto t_index_set_ptr = t_index_set.packed_accessor(); + thrust::for_each( + policy, + thrust::make_counting_iterator(int64_t(0)), + thrust::make_counting_iterator(int64_t(new_sz)), + [r_values_accessor, t_values, t_index_set_ptr, mask_index_set_ptr, r_nnz] __device__ (int64_t i) mutable { + int64_t target = mask_index_set_ptr[i]; + int64_t origin = t_index_set_ptr[i]; + r_values_accessor[target] = t_values[origin]; + }); + }); + return r_values; +} + +Tensor sparse_sparse_matmul_cuda(const Tensor& mat1_, const Tensor& mat2_) { + TORCH_INTERNAL_ASSERT(mat1_.is_sparse()); + TORCH_INTERNAL_ASSERT(mat2_.is_sparse()); + TORCH_CHECK(mat1_.dim() == 2); + TORCH_CHECK(mat2_.dim() == 2); + TORCH_CHECK(mat1_.dense_dim() == 0, "sparse_mm: scalar values expected, mat1 got ", mat1_.dense_dim(), "D values"); + TORCH_CHECK(mat2_.dense_dim() == 0, "sparse_mm: scalar values expected, mat2 got ", mat2_.dense_dim(), "D values"); + + TORCH_CHECK( + mat1_.size(1) == mat2_.size(0), "mat1 and mat2 shapes cannot be multiplied (", + mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")"); + + TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(), + "mat1 dtype ", mat1_.scalar_type(), " does not match mat2 dtype ", mat2_.scalar_type()); + + auto output = at::native::empty_like(mat1_); + output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0); + + AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] { + sparse_sparse_matmul_cuda_kernel(output, mat1_.coalesce(), mat2_.coalesce()); + }); + return output; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/Vulkan.cpp b/aten/src/ATen/native/vulkan/Vulkan.cpp index d6fa6a32291b3..3646ae7e94962 100644 --- a/aten/src/ATen/native/vulkan/Vulkan.cpp +++ b/aten/src/ATen/native/vulkan/Vulkan.cpp @@ -7,6 +7,7 @@ #include #include +#include #ifdef USE_VULKAN_WRAPPER #include @@ -778,17 +779,17 @@ void ComputeUnit::createComputePipeline( { uint32_t offset = 0; size_t size = sizeof(WorkGroupSize::x); - spMapEntries[0].constantID = 1; + spMapEntries[0].constantID = 0; spMapEntries[0].offset = offset; spMapEntries[0].size = size; offset += size; size = sizeof(WorkGroupSize::y); - spMapEntries[1].constantID = 2; + spMapEntries[1].constantID = 1; spMapEntries[1].offset = offset; spMapEntries[1].size = size; offset += size; size = sizeof(WorkGroupSize::z); - spMapEntries[2].constantID = 3; + spMapEntries[2].constantID = 2; spMapEntries[2].offset = offset; spMapEntries[2].size = size; } @@ -1037,12 +1038,6 @@ ComputeUnit& ComputeUnitFactory::get( // VBuffer <-> VImage void copy_buffer_to_image(const VBuffer& buffer, VImage& image) { const auto device = context().device(); - struct ConstBlock { - int32_t w; - int32_t h; - }; - const ConstBlock constBlock{image.w(), image.h()}; - VBuffer constBuffer = makeUniformConstBuffer(&constBlock, sizeof(constBlock)); VkDescriptorSetLayout descrSetLayout{}; VkDescriptorSetLayoutBinding bindings[] = { @@ -1064,7 +1059,6 @@ void copy_buffer_to_image(const VBuffer& buffer, VImage& image) { image.bindStorageImage(descrSet, 0); buffer.bind(descrSet, 1); - constBuffer.bind(descrSet, 2); WorkGroupSize workGroupSize{8, 8, 1}; auto& computeUnit = context().computeUnitFactory().get( @@ -1096,12 +1090,6 @@ void copy_image_to_buffer( TORCH_INTERNAL_ASSERT( buffer.sizeBytes() >= image.capacityBytes(), "VulkanBuffer's capacity is less than VulkanImage capacity to copy from"); - struct ConstBlock { - int32_t w; - int32_t h; - }; - const ConstBlock constBlock{image.w(), image.h()}; - VBuffer constBuffer = makeUniformConstBuffer(&constBlock, sizeof(constBlock)); VkDescriptorSetLayout descrSetLayout{}; const VkDescriptorSetLayoutBinding bindings[] = { @@ -1124,7 +1112,6 @@ void copy_image_to_buffer( image.bindShaderRead(descrSet, 0); buffer.bind(descrSet, 1); - constBuffer.bind(descrSet, 2); const WorkGroupSize workGroupSize{8, 8, 1}; auto& computeUnit = context().computeUnitFactory().get( @@ -1182,11 +1169,7 @@ class VulkanTensor::Impl final { explicit Impl(std::vector sizes) : sizes_(std::move(sizes)), strides_(std::vector(sizes_.size())), - numel_(std::accumulate( - std::begin(sizes_), - std::end(sizes_), - 1, - std::multiplies())) { + numel_(prod_intlist(sizes_)) { TORCH_CHECK( initVulkanContextOnce(), "Vulkan Failed to create Vulkan Context"); } @@ -1289,8 +1272,7 @@ class VulkanTensor::Impl final { VkDeviceSize buffer_size_for_sizes(std::vector sizes) const { const auto d = sizes.size(); - const auto numel = std::accumulate( - std::begin(sizes), std::end(sizes), 1, std::multiplies()); + const auto numel = prod_intlist(sizes); VkDeviceSize bufferSize{sizeof(float) * numel}; // alignment to be able to copy between image and buffer if (d == 4) { diff --git a/aten/src/ATen/native/vulkan/VulkanAten.cpp b/aten/src/ATen/native/vulkan/VulkanAten.cpp index 72d5e15208ecf..b43b87e167909 100644 --- a/aten/src/ATen/native/vulkan/VulkanAten.cpp +++ b/aten/src/ATen/native/vulkan/VulkanAten.cpp @@ -55,17 +55,20 @@ VulkanTensor& vtensor_from_vulkan(Tensor& tensor) { Tensor empty( IntArrayRef size, - const TensorOptions& options, + optional dtype, + optional layout, + optional device, + optional pin_memory, const optional memory_format) { TORCH_CHECK( - !options.has_pinned_memory(), + !pin_memory.has_value(), "'pin_memory' argument is incompatible with Vulkan tensor"); TORCH_CHECK( - !options.has_memory_format() && !memory_format, + !memory_format.has_value(), "'memory_format' argument is incompatible with Vulkan tensor"); VulkanTensor vt{size.vec()}; return new_with_vtensor_vulkan( - std::move(vt), at::device(at::kVulkan).dtype(options.dtype())); + std::move(vt), at::device(at::kVulkan).dtype(dtype)); } Tensor empty_strided( @@ -76,7 +79,7 @@ Tensor empty_strided( optional device, optional pin_memory) { return vulkan::aten::empty( - size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory), c10::nullopt); + size, dtype, layout, device, pin_memory, c10::nullopt); } Tensor upsample_nearest2d( @@ -161,7 +164,7 @@ Tensor avg_pool2d( pooling_output_shape(iW, kW, padW, dW, 1, ceil_mode); pool2d_shape_check( - self, kH, kW, dH, dW, padH, padW, 1, 1, iC, iH, iW, oH, oW); + self, kH, kW, dH, dW, padH, padW, 1, 1, iC, iH, iW, oH, oW, self.suggest_memory_format()); VulkanTensor y{{iN, iC, oH, oW}}; vulkan::detail::avg_pool2d( @@ -231,7 +234,8 @@ Tensor max_pool2d( iH, iW, oH, - oW); + oW, + self.suggest_memory_format()); VulkanTensor y{{iN, iC, oH, oW}}; vulkan::detail::max_pool2d( @@ -519,6 +523,7 @@ Tensor mean( const IntArrayRef dim, const bool keepdim, const optional dtype) { + TORCH_INTERNAL_ASSERT(!keepdim, "keepdim not implemented for Vulkan mean"); TORCH_INTERNAL_ASSERT(self.is_vulkan(), "mean expects Vulkan tensor input"); // Mean is implemented only for HW dimensions of 4-d tensor @@ -537,15 +542,17 @@ Tensor mean( return new_with_vtensor_vulkan(std::move(output), self.options()); } +#ifndef USE_VULKAN_API + TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl("slice.Tensor", TORCH_FN(at::native::vulkan::aten::slice)); - m.impl("reshape", TORCH_FN(at::native::vulkan::aten::reshape)); + m.impl("view", TORCH_FN(at::native::vulkan::aten::reshape)); m.impl("select.int", TORCH_FN(at::native::vulkan::aten::select)); m.impl("transpose.int", TORCH_FN(at::native::vulkan::aten::transpose)); - m.impl_UNBOXED("transpose_", at::native::vulkan::aten::transpose_); + m.impl("transpose_", at::native::vulkan::aten::transpose_); m.impl("view", TORCH_FN(at::native::vulkan::aten::view)); m.impl("unsqueeze", TORCH_FN(at::native::vulkan::aten::unsqueeze)); - m.impl_UNBOXED("empty.memory_format", at::native::vulkan::aten::empty); + m.impl("empty.memory_format", at::native::vulkan::aten::empty); m.impl("empty_strided", TORCH_FN(at::native::vulkan::aten::empty_strided)); m.impl("add.Tensor", TORCH_FN(at::native::vulkan::aten::add)); m.impl("clamp", TORCH_FN(at::native::vulkan::aten::clamp)); @@ -563,13 +570,15 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl("_cat", TORCH_FN(at::native::vulkan::aten::cat)); m.impl("mul.Scalar", TORCH_FN(at::native::vulkan::aten::mul_scalar)); m.impl("add.Scalar", TORCH_FN(at::native::vulkan::aten::add_scalar)); - m.impl_UNBOXED( + m.impl( "convolution_overrideable", at::native::vulkan::aten::convolution); - m.impl_UNBOXED("hardtanh_", at::native::vulkan::aten::hardtanh_); - m.impl_UNBOXED("relu_", at::native::vulkan::aten::relu_); - m.impl_UNBOXED("add_.Tensor", at::native::vulkan::aten::add_); + m.impl("hardtanh_", at::native::vulkan::aten::hardtanh_); + m.impl("relu_", at::native::vulkan::aten::relu_); + m.impl("add_.Tensor", at::native::vulkan::aten::add_); } +#endif /* USE_VULKAN_API */ + Tensor& copy_from_vulkan_(Tensor& self, const Tensor& src) { TORCH_INTERNAL_ASSERT( src.device().type() == DeviceType::Vulkan, diff --git a/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h b/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h index 6afc28676f2b2..5b8c15855f08f 100644 --- a/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h +++ b/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h @@ -11,7 +11,7 @@ template struct VulkanOpaqueTensorImpl : public OpaqueTensorImpl { VulkanOpaqueTensorImpl( at::DispatchKeySet key_set, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, c10::Device device, OpaqueHandle opaque_handle, c10::IntArrayRef sizes, @@ -21,7 +21,8 @@ struct VulkanOpaqueTensorImpl : public OpaqueTensorImpl { data_type, device, opaque_handle, - sizes), + sizes, + false), strides_(strides.vec()) {} IntArrayRef strides() const override { diff --git a/aten/src/ATen/native/vulkan/VulkanOps.cpp b/aten/src/ATen/native/vulkan/VulkanOps.cpp index 302525582c9d5..8ad79a0c6f317 100644 --- a/aten/src/ATen/native/vulkan/VulkanOps.cpp +++ b/aten/src/ATen/native/vulkan/VulkanOps.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -28,21 +29,12 @@ void upsample_nearest2d( float scaleH, float scaleW) { auto device = context().device(); - auto physicalDevice = context().physicalDevice(); int64_t C = IN * IC; struct ConstBlock { - int32_t IW; - int32_t IH; - int32_t OW; - int32_t OH; float scaleX; float scaleY; }; - ConstBlock cb{safe_downcast(IW), - safe_downcast(IH), - safe_downcast(OW), - safe_downcast(OH), - scaleW, + ConstBlock cb{scaleW, scaleH}; VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb)); @@ -66,7 +58,7 @@ void upsample_nearest2d( WorkGroupSize workGroupSize{8, 8, 1}; auto& computeUnit = context().computeUnitFactory().get( - GLSL_SPV(upsampleNearest2d), descriptorSetLayout, workGroupSize); + GLSL_SPV(upsample_nearest2d), descriptorSetLayout, workGroupSize); computeUnit.createCommandBuffer(descriptorSet); input.image()->addImageMemoryBarrierToShaderRead(computeUnit.commandBuffer()); computeUnit.dispatchCommandBuffer(OW, OH, C, workGroupSize); @@ -112,17 +104,6 @@ void adaptive_avg_pool2d( const int64_t IC) { auto device = context().device(); int64_t C = IN * IC; - struct ConstBlock { - int32_t IW; - int32_t IH; - int32_t OW; - int32_t OH; - }; - ConstBlock cb{safe_downcast(IW), - safe_downcast(IH), - safe_downcast(OW), - safe_downcast(OH)}; - VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb)); VkDescriptorSetLayout descriptorSetLayout{}; VkDescriptorPool descriptorPool{}; @@ -140,7 +121,6 @@ void adaptive_avg_pool2d( output.image()->bindStorageImage(descriptorSet, 0); input.image()->bindShaderRead(descriptorSet, 1); - constBuffer.bind(descriptorSet, 2); WorkGroupSize workGroupSize{8, 8, 1}; auto& computeUnit = context().computeUnitFactory().get( @@ -239,20 +219,14 @@ void avg_pool2d( auto device = context().device(); const auto c = _n * _c; struct ConstBlock { - int32_t inputSize[4]; - int32_t outputSize[4]; int32_t kernelSize[2]; int32_t stride[2]; int32_t padding[2]; - int32_t dilate[2]; }; ConstBlock cb{ - {iW, iH, c, 0}, - {oW, oH, c, 0}, {kW, kH}, {dW, dH}, {padW, padH}, - {1, 1}, }; VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb)); @@ -502,17 +476,10 @@ void add( auto W = os4[3]; auto device = context().device(); - auto physicalDevice = context().physicalDevice(); struct ConstBlock { - int32_t W; - int32_t H; - int32_t C; float alpha; }; - ConstBlock cb{safe_downcast(W), - safe_downcast(H), - safe_downcast(C), - alpha}; + ConstBlock cb{alpha}; VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb)); VkDescriptorSetLayout descriptorSetLayout{}; @@ -553,22 +520,16 @@ void add( void add(VulkanTensor& output, const VulkanTensor& input, const float s) { const auto sizes = input.sizes(); - const auto C = std::accumulate( - sizes.cbegin(), sizes.cend() - 2, 1, std::multiplies()); + const auto C = prod_intlist(sizes.cbegin(), sizes.cend() - 2); const auto C_4 = UP_DIV(C, 4); const auto H = sizes[2]; const auto W = sizes[3]; auto device = context().device(); struct ConstBlock { - int32_t inputSize[4]; float s; }; - ConstBlock cb{{safe_downcast(W), - safe_downcast(H), - safe_downcast(C_4), - 0}, - s}; + ConstBlock cb{s}; VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb)); VkDescriptorSetLayout descriptorSetLayout{}; @@ -606,22 +567,16 @@ void add(VulkanTensor& output, const VulkanTensor& input, const float s) { void mul(VulkanTensor& output, const VulkanTensor& input, const float s) { const auto sizes = input.sizes(); - const auto C = std::accumulate( - sizes.cbegin(), sizes.cend() - 2, 1, std::multiplies()); + const auto C = prod_intlist(sizes.cbegin(), sizes.cend() - 2); const auto C_4 = UP_DIV(C, 4); const auto H = sizes[2]; const auto W = sizes[3]; auto device = context().device(); struct ConstBlock { - int32_t inputSize[4]; float s; }; - ConstBlock cb{{safe_downcast(W), - safe_downcast(H), - safe_downcast(C_4), - 0}, - s}; + ConstBlock cb{s}; VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb)); VkDescriptorSetLayout descriptorSetLayout{}; @@ -1158,24 +1113,13 @@ void clamp( auto C = sizes[0] * sizes[1]; auto H = sizes[2]; auto W = sizes[3]; - auto C_4 = UP_DIV(C, 4); auto device = context().device(); - auto physicalDevice = context().physicalDevice(); struct ConstBlock { - int32_t W; - int32_t H; - int32_t C_4; - int32_t C; float min; float max; }; - ConstBlock cb{safe_downcast(W), - safe_downcast(H), - safe_downcast(C_4), - safe_downcast(C), - min, - max}; + ConstBlock cb{min, max}; VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb)); VkDescriptorSetLayout descriptorSetLayout{}; @@ -1222,14 +1166,10 @@ void addmm( const auto m2Sizes = m2.sizes(); TORCH_INTERNAL_ASSERT(m1Sizes.size() == 2); TORCH_INTERNAL_ASSERT(m2Sizes.size() == 2); - const auto m1H = m1Sizes[0]; const auto m1W = m1Sizes[1]; const auto m1C = 1; - const auto m1C_4 = UP_DIV(m1C, 4); const auto m2H = m2Sizes[0]; - const auto m2W = m2Sizes[1]; const auto m2C = 1; - const auto m2C_4 = UP_DIV(m2C, 4); const auto OH = m1Sizes[0]; const auto OW = m2Sizes[1]; @@ -1238,26 +1178,14 @@ void addmm( const auto C = m1C; const auto C_4 = UP_DIV(C, 4); - const auto K = m1W; auto device = context().device(); struct ConstBlock { - int32_t OW; - int32_t OH; - int32_t C_4; - int32_t C; - float beta; float alpha; - int32_t K; + float beta; }; - ConstBlock cb{safe_downcast(OW), - safe_downcast(OH), - safe_downcast(C_4), - safe_downcast(C), - beta, - alpha, - safe_downcast(K)}; + ConstBlock cb{alpha, beta}; VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb)); VkDescriptorSetLayout descriptorSetLayout{}; @@ -1269,15 +1197,14 @@ void addmm( VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, }; } else { descriptorTypes = { VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, }; } @@ -1291,9 +1218,9 @@ void addmm( output.image()->bindStorageImage(descriptorSet, 0); m1.image()->bindShaderRead(descriptorSet, 1); m2.image()->bindShaderRead(descriptorSet, 2); - constBuffer.bind(descriptorSet, 3); if (hasT) { - (*t).image()->bindShaderRead(descriptorSet, 4); + (*t).image()->bindShaderRead(descriptorSet, 3); + constBuffer.bind(descriptorSet, 4); } WorkGroupSize workGroupSize{8, 8, 1}; @@ -1331,17 +1258,13 @@ void mean(VulkanTensor& output, const VulkanTensor& input) { int32_t C = safe_downcast(isizes[1]); int32_t H = safe_downcast(isizes[2]); int32_t W = safe_downcast(isizes[3]); - int32_t C_4 = UP_DIV(N * C, 4); auto device = context().device(); - auto physicalDevice = context().physicalDevice(); struct ConstBlock { int32_t W; int32_t H; - int32_t OW; - int32_t OH; }; - ConstBlock cb{W, H, C, N}; + ConstBlock cb{W, H}; VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb)); VkDescriptorSetLayout descriptorSetLayout{}; @@ -1364,12 +1287,12 @@ void mean(VulkanTensor& output, const VulkanTensor& input) { WorkGroupSize workGroupSize{1, 1, 1}; auto& computeUnit = context().computeUnitFactory().get( - GLSL_SPV(mean), descriptorSetLayout, workGroupSize); + GLSL_SPV(mean2d), descriptorSetLayout, workGroupSize); computeUnit.createCommandBuffer(descriptorSet); auto commandBuffer = computeUnit.commandBuffer(); output.image()->addImageMemoryBarrierToGeneral(commandBuffer); input.image()->addImageMemoryBarrierToShaderRead(commandBuffer); - computeUnit.dispatchCommandBuffer(1, 1, C_4, workGroupSize); + computeUnit.dispatchCommandBuffer(C, N, 1, workGroupSize); computeUnit.endCommandBuffer(); computeUnit.submitAndWaitCommandBuffer(); vkDestroyDescriptorPool(device, descriptorPool, nullptr); diff --git a/aten/src/ATen/native/vulkan/VulkanRegisterOpContextClass.cpp b/aten/src/ATen/native/vulkan/VulkanRegisterOpContextClass.cpp index 710727797c3e9..0a1c5fcea72df 100644 --- a/aten/src/ATen/native/vulkan/VulkanRegisterOpContextClass.cpp +++ b/aten/src/ATen/native/vulkan/VulkanRegisterOpContextClass.cpp @@ -8,6 +8,8 @@ namespace at { namespace native { namespace vulkan { +#ifndef USE_VULKAN_API + using detail::convolution2d::createConv2dClampPrePackOpContext; TORCH_LIBRARY(vulkan, m) { @@ -49,6 +51,9 @@ TORCH_LIBRARY_IMPL(vulkan_prepack, CPU, m) { TORCH_LIBRARY_IMPL(vulkan_prepack, Vulkan, m) { m.impl("conv2d_clamp_run", detail::convolution2d::conv2d_clamp_run); } + +#endif /* USE_VULKAN_API */ + } // namespace vulkan } // namespace native } // namespace at diff --git a/aten/src/ATen/native/vulkan/api/Adapter.h b/aten/src/ATen/native/vulkan/api/Adapter.h index 239edfb74518d..b4203530f6350 100644 --- a/aten/src/ATen/native/vulkan/api/Adapter.h +++ b/aten/src/ATen/native/vulkan/api/Adapter.h @@ -1,7 +1,10 @@ #pragma once +#ifdef USE_VULKAN_API + #include #include +#include namespace at { namespace native { @@ -28,9 +31,15 @@ struct Adapter final { // for now. return VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU == properties.deviceType; } + + inline Shader::WorkGroup local_work_group_size() const { + return { 4u, 4u, 4u, }; + } }; } // namespace api } // namespace vulkan } // namespace native } // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/Allocator.h b/aten/src/ATen/native/vulkan/api/Allocator.h index f0f0c9baa59c4..15a69e1e261a2 100644 --- a/aten/src/ATen/native/vulkan/api/Allocator.h +++ b/aten/src/ATen/native/vulkan/api/Allocator.h @@ -1,13 +1,40 @@ #pragma once +// +// Do NOT include vk_mem_alloc.h directly. +// Always include this file (Allocator.h) instead. +// + +#ifdef USE_VULKAN_API + #include +#define VMA_VULKAN_VERSION 1000000 + +#ifdef USE_VULKAN_WRAPPER + #define VMA_STATIC_VULKAN_FUNCTIONS 0 +#else + #define VMA_DYNAMIC_VULKAN_FUNCTIONS 0 +#endif + +#define VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE (64ull * 1024 * 1024) +#define VMA_SMALL_HEAP_MAX_SIZE (256ull * 1024 * 1024) + #ifdef DEBUG + #define VMA_DEBUG_ALIGNMENT 4096 + #define VMA_DEBUG_ALWAYS_DEDICATED_MEMORY 0 + #define VMA_DEBUG_DETECT_CORRUPTION 1 + #define VMA_DEBUG_GLOBAL_MUTEX 1 + #define VMA_DEBUG_INITIALIZE_ALLOCATIONS 1 + #define VMA_DEBUG_MARGIN 64 + #define VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY 256 + #define VMA_RECORDING_ENABLED 1 + #define VMA_DEBUG_LOG(format, ...) \ do { \ printf(format, ##__VA_ARGS__); \ printf("\n"); \ - } while(false) + } while (false) #endif /* DEBUG */ #ifdef __clang__ @@ -16,11 +43,10 @@ #pragma clang diagnostic ignored "-Wunused-variable" #endif /* __clang__ */ -// Do NOT include vk_mem_alloc.h directly. -// Always include this file (Allocator.h) instead. - #include #ifdef __clang__ #pragma clang diagnostic pop #endif /* __clang__ */ + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/Cache.h b/aten/src/ATen/native/vulkan/api/Cache.h index 36291a2227d4f..a93385088277d 100644 --- a/aten/src/ATen/native/vulkan/api/Cache.h +++ b/aten/src/ATen/native/vulkan/api/Cache.h @@ -1,5 +1,7 @@ #pragma once +#ifdef USE_VULKAN_API + #include namespace at { @@ -60,6 +62,10 @@ class Cache final { Factory factory_; }; +// +// Impl +// + template inline Cache::Cache(Factory factory) : factory_(std::move(factory)) { @@ -70,7 +76,7 @@ template inline auto Cache::retrieve( const Descriptor& descriptor) { auto iterator = cache_.find(descriptor); - if (cache_.cend() == iterator) { + if C10_UNLIKELY(cache_.cend() == iterator) { iterator = cache_.insert({descriptor, factory_(descriptor)}).first; } @@ -86,3 +92,5 @@ inline void Cache::purge() { } // namespace vulkan } // namespace native } // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/Command.cpp b/aten/src/ATen/native/vulkan/api/Command.cpp index a7793aea16dcb..247b51fa53950 100644 --- a/aten/src/ATen/native/vulkan/api/Command.cpp +++ b/aten/src/ATen/native/vulkan/api/Command.cpp @@ -1,29 +1,30 @@ #include +#include +#include namespace at { namespace native { namespace vulkan { namespace api { +namespace { -Command::Pool::Factory::Factory(const GPU& gpu) - : device_(gpu.device) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device_, - "Invalid Vulkan device!"); -} +VkCommandPool create_command_pool( + const VkDevice device, + const uint32_t queue_family_index) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device, + "Invalid Vulkan device!"); -typename Command::Pool::Factory::Handle Command::Pool::Factory::operator()( - const Descriptor& descriptor) const { const VkCommandPoolCreateInfo command_pool_create_info{ VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO, nullptr, VK_COMMAND_POOL_CREATE_TRANSIENT_BIT, - descriptor.queue_family_index, + queue_family_index, }; VkCommandPool command_pool{}; VK_CHECK(vkCreateCommandPool( - device_, + device, &command_pool_create_info, nullptr, &command_pool)); @@ -32,15 +33,14 @@ typename Command::Pool::Factory::Handle Command::Pool::Factory::operator()( command_pool, "Invalid Vulkan command pool!"); - return Handle{ - command_pool, - Deleter(device_), - }; + return command_pool; } -void Command::Pool::purge( +void allocate_command_buffers( const VkDevice device, - const VkCommandPool command_pool) { + const VkCommandPool command_pool, + VkCommandBuffer* const command_buffers, + const uint32_t count) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( device, "Invalid Vulkan device!"); @@ -49,53 +49,58 @@ void Command::Pool::purge( command_pool, "Invalid Vulkan command pool!"); - VK_CHECK(vkResetCommandPool(device, command_pool, 0u)); -} - -namespace { - -VkCommandBuffer allocate_command_buffer( - const VkDevice device, - const VkCommandPool command_pool) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device, - "Invalid Vulkan device!"); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - command_pool, - "Invalid Vulkan command pool!"); + command_buffers && (count > 0u), + "Invalid usage!"); const VkCommandBufferAllocateInfo command_buffer_allocate_info{ VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO, nullptr, command_pool, VK_COMMAND_BUFFER_LEVEL_PRIMARY, - 1u, + count, }; - VkCommandBuffer command_buffer{}; VK_CHECK(vkAllocateCommandBuffers( device, &command_buffer_allocate_info, - &command_buffer)); - - TORCH_CHECK( - command_buffer, - "Invalid Vulkan command buffer!"); - - return command_buffer; + command_buffers)); } } // namespace -Command::Buffer::Buffer(const VkDevice device, const VkCommandPool command_pool) - : command_buffer_(allocate_command_buffer(device, command_pool)) { +Command::Buffer::Buffer(const VkCommandBuffer command_buffer) + : command_buffer_(command_buffer) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( command_buffer_, "Invalid Vulkan command buffer!"); } +Command::Buffer::Buffer(Buffer&& buffer) + : command_buffer_(std::move(buffer.command_buffer_)), + bound_(std::move(buffer.bound_)), + barriers_(std::move(buffer.barriers_)) { + buffer.invalidate(); +} + +Command::Buffer& Command::Buffer::operator=(Buffer&& buffer) { + if (&buffer != this) { + command_buffer_ = std::move(buffer.command_buffer_); + bound_ = std::move(buffer.bound_); + barriers_ = std::move(buffer.barriers_); + + buffer.invalidate(); + }; + + return *this; +} + void Command::Buffer::Buffer::begin() { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + command_buffer_, + "This command buffer is in an invalid state! " + "Potential reason: This command buffer is moved from."); + const VkCommandBufferBeginInfo command_buffer_begin_info{ VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO, nullptr, @@ -106,43 +111,392 @@ void Command::Buffer::Buffer::begin() { VK_CHECK(vkBeginCommandBuffer( command_buffer_, &command_buffer_begin_info)); + + // Reset + bound_.reset(); + barriers_.reset(); } void Command::Buffer::Buffer::end() { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + command_buffer_, + "This command buffer is in an invalid state! " + "Potential reason: This command buffer is moved from."); + VK_CHECK(vkEndCommandBuffer(command_buffer_)); } -void Command::Buffer::bind(const VkPipeline pipeline) { +void Command::Buffer::barrier(const Pipeline::Barrier& barrier) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + command_buffer_, + "This command buffer is in an invalid state! " + "Potential reason: This command buffer is moved from."); + + barriers_.stage.src |= barrier.stage.src; + barriers_.stage.dst |= barrier.stage.dst; + + barriers_.buffers.insert( + barriers_.buffers.end(), + barrier.buffers.begin(), + barrier.buffers.end()); + + barriers_.images.insert( + barriers_.images.end(), + barrier.images.begin(), + barrier.images.end()); +} + +void Command::Buffer::bind(const Pipeline::Object& pipeline) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + command_buffer_, + "This command buffer is in an invalid state! " + "Potential reason: This command buffer is moved from."); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( pipeline, "Invalid Vulkan pipeline!"); - vkCmdBindPipeline( - command_buffer_, - VK_PIPELINE_BIND_POINT_COMPUTE, - pipeline); + if (pipeline.handle != bound_.pipeline.handle) { + vkCmdBindPipeline( + command_buffer_, + VK_PIPELINE_BIND_POINT_COMPUTE, + pipeline.handle); + + bound_.pipeline = pipeline; + } } -void Command::Buffer::bind( - const VkPipelineLayout pipeline_layout, - const VkDescriptorSet descriptor_set) { +void Command::Buffer::bind(const Descriptor::Set& set) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - pipeline_layout, - "Invalid Vulkan pipeline layout!"); + command_buffer_, + "This command buffer is in an invalid state! " + "Potential reason: This command buffer is moved from."); + + const VkDescriptorSet descriptor_set = set.handle(); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( descriptor_set, "Invalid Vulkan descriptor set!"); - vkCmdBindDescriptorSets( + if (descriptor_set != bound_.descriptor_set) { + vkCmdBindDescriptorSets( + command_buffer_, + VK_PIPELINE_BIND_POINT_COMPUTE, + bound_.pipeline.layout, + 0u, + 1u, + &descriptor_set, + 0u, + nullptr); + + bound_.descriptor_set = descriptor_set; + } +} + +void Command::Buffer::copy( + const Resource::Buffer::Object source, + const Resource::Buffer::Object destination) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( command_buffer_, - VK_PIPELINE_BIND_POINT_COMPUTE, - pipeline_layout, - 0u, + "This command buffer is in an invalid state! " + "Potential reason: This command buffer is moved from."); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + source, + "Invalid Vulkan source buffer!"); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + destination, + "Invalid Vulkan destination buffer!"); + + barrier(); + + const VkBufferCopy buffer_copy{ + 0u, + 0u, + std::min(source.range, destination.range), + }; + + vkCmdCopyBuffer( + command_buffer_, + source.handle, + destination.handle, 1u, - &descriptor_set, + &buffer_copy); +} + +void Command::Buffer::dispatch( + const Shader::WorkGroup& global_work_group) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + command_buffer_, + "This command buffer is in an invalid state! " + "Potential reason: This command buffer is moved from."); + + barrier(); + + vkCmdDispatch( + command_buffer_, + utils::div_up( + global_work_group.data[0u], + bound_.pipeline.local_work_group.data[0u]), + utils::div_up( + global_work_group.data[1u], + bound_.pipeline.local_work_group.data[1u]), + utils::div_up( + global_work_group.data[2u], + bound_.pipeline.local_work_group.data[2u])); +} + +void Command::Buffer::barrier() { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + command_buffer_, + "This command buffer is in an invalid state! " + "Potential reason: This command buffer is moved from."); + + if (barriers_.stage) { + c10::SmallVector buffer_memory_barriers; + + for (const Resource::Buffer::Barrier& barrier : barriers_.buffers) { + buffer_memory_barriers.push_back({ + VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER, + nullptr, + barrier.memory.src, + barrier.memory.dst, + VK_QUEUE_FAMILY_IGNORED, + VK_QUEUE_FAMILY_IGNORED, + barrier.object.handle, + barrier.object.offset, + barrier.object.range, + }); + } + + c10::SmallVector image_memory_barriers; + + for (const Resource::Image::Barrier& barrier : barriers_.images) { + image_memory_barriers.push_back({ + VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER, + nullptr, + barrier.memory.src, + barrier.memory.dst, + barrier.layout.src, + barrier.layout.dst, + VK_QUEUE_FAMILY_IGNORED, + VK_QUEUE_FAMILY_IGNORED, + barrier.object.handle, + { + VK_IMAGE_ASPECT_COLOR_BIT, + 0u, + VK_REMAINING_MIP_LEVELS, + 0u, + VK_REMAINING_ARRAY_LAYERS, + }, + }); + } + + vkCmdPipelineBarrier( + command_buffer_, + barriers_.stage.src, + barriers_.stage.dst, + 0u, + 0u, + nullptr, + buffer_memory_barriers.size(), + buffer_memory_barriers.data(), + image_memory_barriers.size(), + image_memory_barriers.data()); + } + + // Reset + barriers_.reset(); +} + +void Command::Buffer::invalidate() { + command_buffer_ = VK_NULL_HANDLE; +} + +inline void Command::Buffer::Bound::reset() { + pipeline = {}; + descriptor_set = VK_NULL_HANDLE; +} + +inline Command::Buffer::Barrier::Stage::operator bool() const { + return (0u != src) || (0u != dst); +} + +inline void Command::Buffer::Barrier::reset() { + stage = {}; + buffers.clear(); + images.clear(); +} + +Command::Pool::Pool(const GPU& gpu) + : device_(gpu.device), + command_pool_( + create_command_pool(gpu.device, gpu.adapter->compute_queue_family_index), + VK_DELETER(CommandPool)(device_)), + buffer_{} { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_, + "Invalid Vulkan device!"); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + command_pool_, + "Invalid Vulkan command pool!"); + + buffer_.pool.reserve(Configuration::kReserve); +} + +Command::Pool::Pool(Pool&& pool) + : device_(std::move(pool.device_)), + command_pool_(std::move(pool.command_pool_)), + buffer_(std::move(pool.buffer_)), + stream_(std::move(pool.stream_)) { + pool.invalidate(); +} + +Command::Pool& Command::Pool::operator=(Pool&& pool) { + if (&pool != this) { + device_ = std::move(pool.device_); + command_pool_ = std::move(pool.command_pool_); + buffer_ = std::move(pool.buffer_); + stream_ = std::move(pool.stream_); + + pool.invalidate(); + }; + + return *this; +} + +Command::Pool::~Pool() { + try { + if (device_ && command_pool_) { + purge(); + } + } + catch (const std::exception& e) { + LOG(WARNING) + << "Vulkan: Command pool destructor raised an exception! Error: " + << e.what(); + } + catch (...) { + LOG(WARNING) + << "Vulkan: Command pool destructor raised an unknown exception!"; + } +} + +Command::Buffer Command::Pool::allocate() { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && command_pool_, + "This command pool is in an invalid state! " + "Potential reason: This command pool is moved from."); + + if (buffer_.pool.size() == buffer_.in_use) { + buffer_.pool.resize( + buffer_.pool.size() + + Configuration::kQuantum); + + allocate_command_buffers( + device_, + command_pool_.get(), + buffer_.pool.data() + buffer_.in_use, + Configuration::kQuantum); + } + + return Buffer(buffer_.pool[buffer_.in_use++]); +} + +Command::Buffer& Command::Pool::stream() { + if (!stream_.buffer) { + stream_.buffer = allocate(); + stream_.buffer.begin(); + stream_.counter = 0u; + } + + return stream_.buffer; +} + +void Command::Pool::purge() { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && command_pool_, + "This command pool is in an invalid state! " + "Potential reason: This command pool is moved from."); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !stream_.buffer, + "Pending command buffer detected. Make sure all command buffers are " + "submitted to the queue for execution prior to reclaiming pool memory."); + + buffer_.in_use = 0u; + VK_CHECK(vkResetCommandPool(device_, command_pool_.get(), 0u)); +} + +void Command::Pool::submit( + const VkQueue queue, + const c10::ArrayRef buffers, + const Resource::Fence fence) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && command_pool_, + "This command pool is in an invalid state! " + "Potential reason: This command pool is moved from."); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + queue, + "Invalid Vulkan queue!"); + + c10::SmallVector command_buffers; + command_buffers.reserve(buffers.size()); + + for (const Buffer& buffer : buffers) { + VkCommandBuffer command_buffer = buffer.handle(); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + command_buffer, + "Invalid Vulkan command buffer!"); + + // Are we submitting our one and only command stream, or a regular command + // buffer whose scope is manually maintained by the user? Automatically + // maintain state and submission rate if the former. + + if (stream_.buffer.handle() == command_buffer) { + // Hand the stream off to the driver if: + // - The user has implictly signaled interest in the results via a fence. + // - We are over the submission cutoff. We don't want to starve the GPU. + + if (fence || (stream_.counter++ > Configuration::kSubmit)) { + stream_.buffer.end(); + stream_.buffer.invalidate(); + } + // Skip - Accumulate more calls prior to submission. + else { + command_buffer = VK_NULL_HANDLE; + } + } + + if (command_buffer) { + command_buffers.push_back(command_buffer); + } + } + + if (!command_buffers.empty()) { + const VkSubmitInfo submit_info{ + VK_STRUCTURE_TYPE_SUBMIT_INFO, + nullptr, + 0u, + nullptr, + nullptr, + command_buffers.size(), + command_buffers.data(), 0u, - nullptr); + nullptr, + }; + + VK_CHECK(vkQueueSubmit(queue, 1u, &submit_info, fence.handle())); + } +} + +void Command::Pool::invalidate() { + device_ = VK_NULL_HANDLE; + command_pool_.reset(); } } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Command.h b/aten/src/ATen/native/vulkan/api/Command.h index b0c171faa490b..8b60d8bcd8f54 100644 --- a/aten/src/ATen/native/vulkan/api/Command.h +++ b/aten/src/ATen/native/vulkan/api/Command.h @@ -1,8 +1,13 @@ #pragma once +#ifdef USE_VULKAN_API + #include -#include -#include +#include +#include +#include +#include +#include namespace at { namespace native { @@ -10,73 +15,109 @@ namespace vulkan { namespace api { struct Command final { + class Pool; + // - // Pool + // Buffer // - struct Pool final { - /* - Descriptor - */ + class Buffer final { + public: + explicit Buffer(VkCommandBuffer command_buffer = VK_NULL_HANDLE); + Buffer(const Buffer&) = delete; + Buffer& operator=(const Buffer&) = delete; + Buffer(Buffer&&); + Buffer& operator=(Buffer&&); + ~Buffer() = default; - struct Descriptor final { - uint32_t queue_family_index; - }; + operator bool() const; + VkCommandBuffer handle() const; - /* - Factory - */ + void begin(); + void end(); - class Factory final { - public: - explicit Factory(const GPU& gpu); + void barrier(const Pipeline::Barrier& barrier); + void bind(const Pipeline::Object& pipeline); + void bind(const Descriptor::Set& set); + void copy(Resource::Buffer::Object source, Resource::Buffer::Object destination); + void dispatch(const Shader::WorkGroup& global_work_group); - typedef Pool::Descriptor Descriptor; - typedef VK_DELETER(CommandPool) Deleter; - typedef Handle Handle; + private: + friend class Pool; - struct Hasher { - size_t operator()(const Descriptor& descriptor) const; - }; + void barrier(); + void invalidate(); - Handle operator()(const Descriptor& descriptor) const; + private: + VkCommandBuffer command_buffer_; - private: - VkDevice device_; - }; + struct Bound final { + Pipeline::Object pipeline; + VkDescriptorSet descriptor_set; + + void reset(); + } bound_; - /* - Cache - */ + struct Barrier final { + struct Stage final { + VkPipelineStageFlags src; + VkPipelineStageFlags dst; - typedef api::Cache Cache; - Cache cache; + operator bool() const; + } stage; - explicit Pool(const GPU& gpu) - : cache(Factory(gpu)) { - } + c10::SmallVector buffers; + c10::SmallVector images; - static void purge(VkDevice device, VkCommandPool command_pool); - } pool; + void reset(); + } barriers_; + }; // - // Buffer + // Pool // - class Buffer final { + class Pool final { public: - Buffer(VkDevice device, VkCommandPool command_pool); + explicit Pool(const GPU& gpu); + Pool(const Pool&) = delete; + Pool& operator=(const Pool&) = delete; + Pool(Pool&&); + Pool& operator=(Pool&&); + ~Pool(); + + Buffer allocate(); + Buffer& stream(); + void purge(); + + void submit( + VkQueue queue, + c10::ArrayRef buffers, + Resource::Fence fence = {}); - void begin(); - void end(); - - void bind(VkPipeline pipeline); - void bind(VkPipelineLayout pipeline_layout, VkDescriptorSet descriptor_set); - void dispatch(); + private: + void invalidate(); private: - VkCommandBuffer command_buffer_; - }; + struct Configuration final { + static constexpr uint32_t kQuantum = 4u; + static constexpr uint32_t kReserve = 16u; + static constexpr uint32_t kSubmit = 10u; + }; + + VkDevice device_; + Handle command_pool_; + + struct { + std::vector pool; + size_t in_use; + } buffer_; + + struct { + Buffer buffer; + uint32_t counter; + } stream_; + } pool /* [thread_count] */; explicit Command(const GPU& gpu) : pool(gpu) { @@ -87,18 +128,17 @@ struct Command final { // Impl // -inline bool operator==( - const Command::Pool::Descriptor& _1, - const Command::Pool::Descriptor& _2) { - return _1.queue_family_index == _2.queue_family_index; +inline Command::Buffer::operator bool() const { + return VK_NULL_HANDLE != command_buffer_; } -inline size_t Command::Pool::Factory::Hasher::operator()( - const Descriptor& descriptor) const { - return c10::get_hash(descriptor.queue_family_index); +inline VkCommandBuffer Command::Buffer::handle() const { + return command_buffer_; } } // namespace api } // namespace vulkan } // namespace native } // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/Common.h b/aten/src/ATen/native/vulkan/api/Common.h index cbd53e8045eff..49f9ffa21a221 100644 --- a/aten/src/ATen/native/vulkan/api/Common.h +++ b/aten/src/ATen/native/vulkan/api/Common.h @@ -1,24 +1,41 @@ #pragma once +#ifdef USE_VULKAN_API + #include +#ifdef USE_VULKAN_SHADERC_RUNTIME +#include +#define VK_KERNEL(name) \ + ::at::native::vulkan::api::Shader::Descriptor{ \ + name##_glsl, \ + } +#else +#include +#define VK_KERNEL(name) \ + ::at::native::vulkan::api::Shader::Descriptor{ \ + name##_spv, \ + name##_spv_len, \ + } +#endif /* USE_VULKAN_SHADERC_RUNTIME */ + #ifdef USE_VULKAN_WRAPPER #include #else #include -#endif +#endif /* USE_VULKAN_WRAPPER */ #define VK_CHECK(function) \ - { \ + do { \ const VkResult result = (function); \ TORCH_CHECK(VK_SUCCESS == result, "VkResult:", result); \ - } + } while (false) #define VK_CHECK_RELAXED(function) \ - { \ + do { \ const VkResult result = (function); \ TORCH_CHECK(VK_SUCCESS <= result, "VkResult:", result); \ - } + } while (false) #define VK_DELETER(Handle) \ at::native::vulkan::api::destroy_##Handle @@ -173,3 +190,5 @@ inline void Handle::reset(Type payload) { } // namespace vulkan } // namespace native } // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/Context.cpp b/aten/src/ATen/native/vulkan/api/Context.cpp index d0fa08dbde1d4..0a9a6e130f4f6 100644 --- a/aten/src/ATen/native/vulkan/api/Context.cpp +++ b/aten/src/ATen/native/vulkan/api/Context.cpp @@ -43,6 +43,40 @@ VkDevice create_device( &queue_priorities, }; + uint32_t device_extension_properties_count = 0; + VK_CHECK(vkEnumerateDeviceExtensionProperties( + physical_device, + nullptr, + &device_extension_properties_count, + nullptr)); + + std::vector device_extension_properties( + device_extension_properties_count); + + VK_CHECK(vkEnumerateDeviceExtensionProperties( + physical_device, + nullptr, + &device_extension_properties_count, + device_extension_properties.data())); + + constexpr const char* const requested_device_extensions[]{ + #ifdef VK_KHR_portability_subset + // https://vulkan.lunarg.com/doc/view/1.2.162.0/mac/1.2-extensions/vkspec.html#VUID-VkDeviceCreateInfo-pProperties-04451 + VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME, + #endif + }; + + std::vector enabled_device_extensions; + + for (const auto& requested_device_extension : requested_device_extensions) { + for (const auto& extension : device_extension_properties) { + if (strcmp(requested_device_extension, extension.extensionName) == 0) { + enabled_device_extensions.push_back(requested_device_extension); + break; + } + } + } + const VkDeviceCreateInfo device_create_info{ VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO, nullptr, @@ -51,7 +85,8 @@ VkDevice create_device( &device_queue_create_info, 0u, nullptr, - 0u, + static_cast(enabled_device_extensions.size()), + enabled_device_extensions.data(), nullptr, }; @@ -78,25 +113,48 @@ VkQueue acquire_queue( } // namespace -void Context::Deleter::operator()(const VkDevice device) const { - // No VK_CHECK. Don't want an exception thrown in the destructor. - vkDeviceWaitIdle(device); - vkDestroyDevice(device, nullptr); -} - Context::Context(const Adapter& adapter) : adapter_(adapter), device_( create_device( adapter.handle, adapter.compute_queue_family_index), - Deleter{}), + &VK_DELETER(Device)), queue_(acquire_queue(device(), adapter.compute_queue_family_index)), command_(gpu()), shader_(gpu()), pipeline_(gpu()), descriptor_(gpu()), resource_(gpu()) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_, + "Invalid Vulkan device!"); +} + +Context::~Context() { + try { + flush(); + } + catch (const std::exception& e) { + LOG(WARNING) + << "Vulkan: Context destructor raised an exception! Error: " + << e.what(); + } + catch (...) { + LOG(WARNING) << "Vulkan: Context destructor raised an unknown exception!"; + } +} + +void Context::flush() { + VK_CHECK(vkQueueWaitIdle(queue())); + + resource().pool.purge(); + descriptor().pool.purge(); + command().pool.purge(); +} + +bool available() { + return context(); } Context* context() { @@ -106,6 +164,41 @@ Context* context() { return context; } +Descriptor::Set dispatch_prologue( + Command::Buffer& command_buffer, + const Shader::Layout::Signature& shader_layout_signature, + const Shader::Descriptor& shader_descriptor) { + Context* const context = api::context(); + const GPU gpu = context->gpu(); + Descriptor& descriptor = context->descriptor(); + Pipeline& pipeline = context->pipeline(); + Shader& shader = context->shader(); + + const Shader::Layout::Object shader_layout = + shader.layout.cache.retrieve({ + shader_layout_signature, + }); + + command_buffer.bind( + pipeline.cache.retrieve({ + pipeline.layout.cache.retrieve({ + shader_layout.handle, + }), + shader.cache.retrieve(shader_descriptor), + gpu.adapter->local_work_group_size(), + })); + + return descriptor.pool.allocate(shader_layout); +} + +void dispatch_epilogue( + Command::Buffer& command_buffer, + const Descriptor::Set& descriptor_set, + const Shader::WorkGroup& global_work_group) { + command_buffer.bind(descriptor_set); + command_buffer.dispatch(global_work_group); +} + } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/Context.h b/aten/src/ATen/native/vulkan/api/Context.h index 5d593bdd9bc17..41adfc5fb2723 100644 --- a/aten/src/ATen/native/vulkan/api/Context.h +++ b/aten/src/ATen/native/vulkan/api/Context.h @@ -1,5 +1,7 @@ #pragma once +#ifdef USE_VULKAN_API + #include #include #include @@ -29,58 +31,39 @@ class Context final { Context(Context&&) = default; Context& operator=(const Context&) = delete; Context& operator=(Context&&) = default; - ~Context() = default; - - inline GPU gpu() { - // A GPU is simply a (physical device, logical device, device queue) trio. - return { - &adapter_, - device(), - queue(), - }; - } + ~Context(); - inline Command& command() { - return command_; - } + GPU gpu(); + Command& command(); + Shader& shader(); + Pipeline& pipeline(); + Descriptor& descriptor(); + Resource& resource(); - inline Shader& shader() { - return shader_; - } + // GPU RPC - inline Pipeline& pipeline() { - return pipeline_; - } + template + void dispatch( + Command::Buffer& command_buffer, + const Shader::Layout::Signature& shader_layout_signature, + const Shader::Descriptor& shader_descriptor, + const Shader::WorkGroup& global_work_group, + Arguments&&... arguments); - inline Descriptor& descriptor() { - return descriptor_; - } + // This function is expensive and its use consequential for performance. Only + // use this function for debugging or as a short term hack on way to a more + // performant solution. - inline Resource& resource() { - return resource_; - } + void flush(); private: - inline VkDevice device() { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device_); - return device_.get(); - } - - inline VkQueue queue() { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(queue_); - return queue_; - } - - private: - class Deleter final { - public: - void operator()(VkDevice device) const; - }; + VkDevice device(); + VkQueue queue(); private: // Construction and destruction order matters. Do not move members around. Adapter adapter_; - Handle device_; + Handle device_; VkQueue queue_; Command command_; Shader shader_; @@ -89,9 +72,109 @@ class Context final { Resource resource_; }; +bool available(); Context* context(); +// +// Impl +// + +inline GPU Context::gpu() { + // A GPU is simply a (physical device, logical device, device queue) trio. + return { + &adapter_, + device(), + queue(), + }; +} + +inline Command& Context::command() { + return command_; +} + +inline Shader& Context::shader() { + return shader_; +} + +inline Pipeline& Context::pipeline() { + return pipeline_; +} + +inline Descriptor& Context::descriptor() { + return descriptor_; +} + +inline Resource& Context::resource() { + return resource_; +} + +inline VkDevice Context::device() { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device_); + return device_.get(); +} + +inline VkQueue Context::queue() { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(queue_); + return queue_; +} + +namespace detail { + +template< + size_t...Indices, + typename ...Arguments> +inline void bind( + Descriptor::Set& descriptor_set, + const std::index_sequence, + Arguments&&...arguments) { + C10_UNUSED const int _[]{ + 0, + (descriptor_set.bind(Indices, std::forward(arguments)), 0)..., + }; +} + +} // namespace detail + +template +inline void Context::dispatch( + Command::Buffer& command_buffer, + const Shader::Layout::Signature& shader_layout_signature, + const Shader::Descriptor& shader_descriptor, + const Shader::WorkGroup& global_work_group, + Arguments&&... arguments) { + // Forward declaration + Descriptor::Set dispatch_prologue( + Command::Buffer&, + const Shader::Layout::Signature&, + const Shader::Descriptor&); + + // Factor out template parameter independent code to minimize code bloat. + Descriptor::Set descriptor_set = dispatch_prologue( + command_buffer, + shader_layout_signature, + shader_descriptor); + + detail::bind( + descriptor_set, + std::index_sequence_for{}, + std::forward(arguments)...); + + // Forward declaration + void dispatch_epilogue( + Command::Buffer&, + const Descriptor::Set&, + const Shader::WorkGroup&); + + // Factor out template parameter independent code to minimize code bloat. + dispatch_epilogue( + command_buffer, + descriptor_set, + global_work_group); +} + } // namespace api } // namespace vulkan } // namespace native } // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/Descriptor.cpp b/aten/src/ATen/native/vulkan/api/Descriptor.cpp index ff0505ccebca2..5bdcb0b7fd023 100644 --- a/aten/src/ATen/native/vulkan/api/Descriptor.cpp +++ b/aten/src/ATen/native/vulkan/api/Descriptor.cpp @@ -4,16 +4,18 @@ namespace at { namespace native { namespace vulkan { namespace api { +namespace { -const Descriptor::Pool::Descriptor Descriptor::Pool::kDefault{ - 1024u, - { - // Note: It is OK for the sum of descriptors per type, below, to exceed - // the max total figure above, but be concenious of memory consumption. - // Considering how the descriptor pool must be frequently purged anyway - // as a result of the impracticality of having enormous pools that - // persist through the execution of the program, there is diminishing - // return in increasing max counts. +VkDescriptorPool create_descriptor_pool(const VkDevice device) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device, + "Invalid Vulkan device!"); + + const struct { + uint32_t capacity; + c10::SmallVector sizes; + } descriptor { + 1024u, { /* Buffers @@ -21,11 +23,11 @@ const Descriptor::Pool::Descriptor Descriptor::Pool::kDefault{ { VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - 256u, + 1024u, }, { VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, - 256u, + 1024u, }, /* @@ -34,29 +36,19 @@ const Descriptor::Pool::Descriptor Descriptor::Pool::kDefault{ { VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - 256u, + 1024u, }, { VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - 256u, + 1024u, }, }, - }, -}; - -Descriptor::Pool::Factory::Factory(const GPU& gpu) - : device_(gpu.device) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device, - "Invalid Vulkan device!"); -} + }; -typename Descriptor::Pool::Factory::Handle Descriptor::Pool::Factory::operator()( - const Descriptor& descriptor) const { const VkDescriptorPoolCreateInfo descriptor_pool_create_info{ VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO, nullptr, - 0u, /* Do not use VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT */ + 0u, descriptor.capacity, static_cast(descriptor.sizes.size()), descriptor.sizes.data(), @@ -64,7 +56,7 @@ typename Descriptor::Pool::Factory::Handle Descriptor::Pool::Factory::operator() VkDescriptorPool descriptor_pool{}; VK_CHECK(vkCreateDescriptorPool( - device_, + device, &descriptor_pool_create_info, nullptr, &descriptor_pool)); @@ -73,15 +65,15 @@ typename Descriptor::Pool::Factory::Handle Descriptor::Pool::Factory::operator() descriptor_pool, "Invalid Vulkan descriptor pool!"); - return Handle{ - descriptor_pool, - Deleter(device_), - }; + return descriptor_pool; } -void Descriptor::Pool::purge( +void allocate_descriptor_sets( const VkDevice device, - const VkDescriptorPool descriptor_pool) { + const VkDescriptorPool descriptor_pool, + const VkDescriptorSetLayout descriptor_set_layout, + VkDescriptorSet* const descriptor_sets, + const uint32_t count) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( device, "Invalid Vulkan device!"); @@ -90,48 +82,327 @@ void Descriptor::Pool::purge( descriptor_pool, "Invalid Vulkan descriptor pool!"); - VK_CHECK(vkResetDescriptorPool(device, descriptor_pool, 0u)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + descriptor_set_layout, + "Invalid Vulkan descriptor set layout!"); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + descriptor_sets && (count > 0u), + "Invalid usage!"); + + const std::vector descriptor_set_layouts{ + count, + descriptor_set_layout, + }; + + const VkDescriptorSetAllocateInfo descriptor_set_allocate_info{ + VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO, + nullptr, + descriptor_pool, + descriptor_set_layouts.size(), + descriptor_set_layouts.data(), + }; + + VK_CHECK(vkAllocateDescriptorSets( + device, + &descriptor_set_allocate_info, + descriptor_sets)); } -Descriptor::Factory::Factory( +} // namespace + +Descriptor::Set::Set( const VkDevice device, - const VkDescriptorPool descriptor_pool) + VkDescriptorSet descriptor_set, + const Shader::Layout::Signature& shader_layout_signature) : device_(device), - descriptor_pool_(descriptor_pool) { + descriptor_set_(descriptor_set), + shader_layout_signature_(shader_layout_signature), + bindings_{} { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - device, + device_, "Invalid Vulkan device!"); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - descriptor_pool, + descriptor_set_, + "Invalid Vulkan descriptor set!"); +} + +Descriptor::Set::Set(Set&& set) + : device_(std::move(set.device_)), + descriptor_set_(std::move(set.descriptor_set_)), + shader_layout_signature_(std::move(set.shader_layout_signature_)), + bindings_(std::move(set.bindings_)) { + set.invalidate(); +} + +Descriptor::Set& Descriptor::Set::operator=(Set&& set) { + if (&set != this) { + device_ = std::move(set.device_); + descriptor_set_ = std::move(set.descriptor_set_); + shader_layout_signature_ = std::move(set.shader_layout_signature_); + bindings_ = std::move(set.bindings_); + + set.invalidate(); + }; + + return *this; +} + +Descriptor::Set& Descriptor::Set::bind( + const uint32_t binding, + const Resource::Buffer::Object& buffer) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && descriptor_set_, + "This descriptor set is in an invalid state! " + "Potential reason: This descriptor set is moved from."); + + update({ + binding, + shader_layout_signature_[binding], + { + .buffer = { + buffer.handle, + buffer.offset, + buffer.range, + }, + }, + }); + + return *this; +} + +Descriptor::Set& Descriptor::Set::bind( + const uint32_t binding, + const Resource::Image::Object& image) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && descriptor_set_, + "This descriptor set is in an invalid state! " + "Potential reason: This descriptor set is moved from."); + + update({ + binding, + shader_layout_signature_[binding], + { + .image = { + image.sampler, + image.view, + [](const VkDescriptorType type, const VkImageLayout layout) { + return (VK_DESCRIPTOR_TYPE_STORAGE_IMAGE == type) ? + VK_IMAGE_LAYOUT_GENERAL : layout; + }(shader_layout_signature_[binding], image.layout), + }, + }, + }); + + return *this; +} + +VkDescriptorSet Descriptor::Set::handle() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && descriptor_set_, + "This descriptor set is in an invalid state! " + "Potential reason: This descriptor set is moved from."); + + if (bindings_.dirty) { + const auto is_buffer = [](const VkDescriptorType type) { + switch (type) { + case VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER: + case VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER: + case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER: + case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER: + case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC: + case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC: + return true; + + default: + return false; + } + }; + + const auto is_image = [](const VkDescriptorType type) { + switch (type) { + case VK_DESCRIPTOR_TYPE_SAMPLER: + case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER: + case VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE: + case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE: + case VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT: + return true; + + default: + return false; + } + }; + + c10::SmallVector write_descriptor_sets; + + for (const Item& item : bindings_.items) { + VkWriteDescriptorSet write{ + VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET, + nullptr, + descriptor_set_, + item.binding, + 0u, + 1u, + item.type, + nullptr, + nullptr, + nullptr, + }; + + if (is_buffer(item.type)) { + write.pBufferInfo = &item.info.buffer; + } + else if (is_image(item.type)) { + write.pImageInfo = &item.info.image; + } + + write_descriptor_sets.emplace_back(write); + } + + vkUpdateDescriptorSets( + device_, + write_descriptor_sets.size(), + write_descriptor_sets.data(), + 0u, + nullptr); + + // Reset + bindings_.dirty = false; + } + + return descriptor_set_; +} + +void Descriptor::Set::invalidate() { + device_ = VK_NULL_HANDLE; + descriptor_set_ = VK_NULL_HANDLE; +} + +void Descriptor::Set::update(const Item& item) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && descriptor_set_, + "This descriptor set is in an invalid state! " + "Potential reason: This descriptor set is moved from."); + + const auto items_itr = std::find_if( + bindings_.items.begin(), + bindings_.items.end(), + [binding = item.binding](const Item& other) { + return other.binding == binding; + }); + + if (bindings_.items.end() == items_itr) { + bindings_.items.emplace_back(item); + } + else { + *items_itr = item; + } + + bindings_.dirty = true; +} + +Descriptor::Pool::Pool(const GPU& gpu) + : device_(gpu.device), + descriptor_pool_( + create_descriptor_pool(gpu.device), + VK_DELETER(DescriptorPool)(device_)) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_, + "Invalid Vulkan device!"); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + descriptor_pool_, "Invalid Vulkan descriptor pool!"); } -VkDescriptorSet Descriptor::Factory::allocate( - const VkDescriptorSetLayout descriptor_set_layout) { - const VkDescriptorSetAllocateInfo descriptor_set_allocate_info{ - VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO, - nullptr, - descriptor_pool_, - 1u, - &descriptor_set_layout, +Descriptor::Pool::Pool(Pool&& pool) + : device_(std::move(pool.device_)), + descriptor_pool_(std::move(pool.descriptor_pool_)), + set_(std::move(pool.set_)) { + pool.invalidate(); +} + +Descriptor::Pool& Descriptor::Pool::operator=(Pool&& pool) { + if (&pool != this) { + device_ = std::move(pool.device_); + descriptor_pool_ = std::move(pool.descriptor_pool_); + set_ = std::move(pool.set_); + + pool.invalidate(); }; - VkDescriptorSet descriptor_set{}; - VK_CHECK(vkAllocateDescriptorSets( + return *this; +} + +Descriptor::Pool::~Pool() { + try { + if (device_ && descriptor_pool_) { + purge(); + } + } + catch (const std::exception& e) { + LOG(WARNING) + << "Vulkan: Descriptor pool destructor raised an exception! Error: " + << e.what(); + } + catch (...) { + LOG(WARNING) + << "Vulkan: Descriptor pool destructor raised an unknown exception!"; + } +} + +Descriptor::Set Descriptor::Pool::allocate( + const Shader::Layout::Object& shader_layout) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && descriptor_pool_, + "This descriptor pool is in an invalid state! " + "Potential reason: This descriptor pool is moved from."); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + shader_layout, + "Invalid Vulkan shader layout!"); + + auto iterator = set_.layouts.find(shader_layout.handle); + if (set_.layouts.cend() == iterator) { + iterator = set_.layouts.insert({shader_layout.handle, {}}).first; + iterator->second.pool.reserve(Configuration::kReserve); + } + + auto& layout = iterator->second; + + if (layout.pool.size() == layout.in_use) { + layout.pool.resize( + layout.pool.size() + + Configuration::kQuantum); + + allocate_descriptor_sets( + device_, + descriptor_pool_.get(), + shader_layout.handle, + layout.pool.data() + layout.in_use, + Configuration::kQuantum); + } + + return Set( device_, - &descriptor_set_allocate_info, - &descriptor_set)); + layout.pool[layout.in_use++], + shader_layout.signature); +} - TORCH_CHECK( - descriptor_set, - "Invalid Vulkan descriptor set!"); +void Descriptor::Pool::purge() { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && descriptor_pool_, + "This descriptor pool is in an invalid state! " + "Potential reason: This descriptor pool is moved from."); - return descriptor_set; + VK_CHECK(vkResetDescriptorPool(device_, descriptor_pool_.get(), 0u)); + set_.layouts.clear(); } -void Descriptor::Factory::purge() { - Pool::purge(device_, descriptor_pool_); +void Descriptor::Pool::invalidate() { + device_ = VK_NULL_HANDLE; + descriptor_pool_.reset(); } } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Descriptor.h b/aten/src/ATen/native/vulkan/api/Descriptor.h index bc6c147239900..6c50a350d7f3b 100644 --- a/aten/src/ATen/native/vulkan/api/Descriptor.h +++ b/aten/src/ATen/native/vulkan/api/Descriptor.h @@ -1,8 +1,10 @@ #pragma once +#ifdef USE_VULKAN_API + #include -#include -#include +#include +#include namespace at { namespace native { @@ -51,113 +53,99 @@ namespace api { struct Descriptor final { // - // Pool + // Set // - struct Pool final { - /* - Descriptor - */ - - struct Descriptor final { - uint32_t capacity; - c10::SmallVector sizes; - }; - - static const Descriptor kDefault; - - /* - Factory - */ - - class Factory final { - public: - explicit Factory(const GPU& gpu); + class Set final { + public: + Set( + VkDevice device, + VkDescriptorSet descriptor_set, + const Shader::Layout::Signature& shader_layout_signature); + Set(const Set&) = delete; + Set& operator=(const Set&) = delete; + Set(Set&&); + Set& operator=(Set&&); + ~Set() = default; - typedef Pool::Descriptor Descriptor; - typedef VK_DELETER(DescriptorPool) Deleter; - typedef Handle Handle; + Set& bind(uint32_t binding, const Resource::Buffer::Object& buffer); + Set& bind(uint32_t binding, const Resource::Image::Object& image); - struct Hasher { - size_t operator()(const Descriptor& descriptor) const; - }; + VkDescriptorSet handle() const; - Handle operator()(const Descriptor& descriptor) const; + private: + void invalidate(); - private: - VkDevice device_; + private: + struct Item final { + uint32_t binding; + VkDescriptorType type; + + union { + VkDescriptorBufferInfo buffer; + VkDescriptorImageInfo image; + } info; }; - /* - Cache - */ - - typedef api::Cache Cache; - Cache cache; + void update(const Item& item); - explicit Pool(const GPU& gpu) - : cache(Factory(gpu)) { - } + private: + VkDevice device_; + VkDescriptorSet descriptor_set_; + Shader::Layout::Signature shader_layout_signature_; - static void purge(VkDevice device, VkDescriptorPool descriptor_pool); - } pool; + struct { + c10::SmallVector items; + mutable bool dirty; + } bindings_; + }; - /* - Factory - */ + // + // Pool + // - class Factory final { + class Pool final { public: - Factory(VkDevice device, VkDescriptorPool descriptor_pool); - - VkDescriptorSet allocate(VkDescriptorSetLayout descriptor_set_layout); + explicit Pool(const GPU& gpu); + Pool(const Pool&) = delete; + Pool& operator=(const Pool&) = delete; + Pool(Pool&&); + Pool& operator=(Pool&&); + ~Pool(); + + Set allocate(const Shader::Layout::Object& shader_layout); void purge(); private: + void invalidate(); + + private: + struct Configuration final { + static constexpr uint32_t kQuantum = 16u; + static constexpr uint32_t kReserve = 64u; + }; + VkDevice device_; - VkDescriptorPool descriptor_pool_; - } factory; + Handle descriptor_pool_; - explicit Descriptor(const GPU& gpu) - : pool(gpu), - factory(gpu.device, pool.cache.retrieve(Pool::kDefault)) { - } -}; + struct { + struct Layout final { + std::vector pool; + size_t in_use; + }; -// -// Impl -// + ska::flat_hash_map layouts; + } set_; + } pool /* [thread_count] */; -inline bool operator==( - const Descriptor::Pool::Descriptor& _1, - const Descriptor::Pool::Descriptor& _2) { - return (_1.capacity == _2.capacity) && - (_1.sizes == _2.sizes); -} - -inline size_t Descriptor::Pool::Factory::Hasher::operator()( - const Descriptor& descriptor) const { - size_t hash = c10::get_hash(descriptor.capacity); - - for (const VkDescriptorPoolSize& descriptor_pool_size : descriptor.sizes) { - hash = c10::hash_combine( - hash, - c10::get_hash( - descriptor_pool_size.type, - descriptor_pool_size.descriptorCount)); + explicit Descriptor(const GPU& gpu) + : pool(gpu) { } - - return hash; -} +}; } // namespace api } // namespace vulkan } // namespace native } // namespace at -inline bool operator==( - const VkDescriptorPoolSize& _1, - const VkDescriptorPoolSize& _2) { - return (_1.type == _2.type) && - (_1.descriptorCount == _2.descriptorCount); -} +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.cpp b/aten/src/ATen/native/vulkan/api/Pipeline.cpp index bd9881c054436..89e85892ee0c3 100644 --- a/aten/src/ATen/native/vulkan/api/Pipeline.cpp +++ b/aten/src/ATen/native/vulkan/api/Pipeline.cpp @@ -100,39 +100,32 @@ typename Pipeline::Factory::Handle Pipeline::Factory::operator()( descriptor.shader_module, "Invalid Vulkan shader module!"); - constexpr uint32_t x_offset = 0u; - constexpr uint32_t x_size = sizeof(Shader::WorkGroup::x); - constexpr uint32_t y_offset = x_offset + x_size; - constexpr uint32_t y_size = sizeof(Shader::WorkGroup::y); - constexpr uint32_t z_offset = y_offset + y_size; - constexpr uint32_t z_size = sizeof(Shader::WorkGroup::z); - constexpr VkSpecializationMapEntry specialization_map_entires[3]{ // X { - 1u, - x_offset, - x_size, + 0u, + offsetof(Shader::WorkGroup, data[0u]), + sizeof(Shader::WorkGroup::data[0u]), }, // Y { - 2u, - y_offset, - y_size, + 1u, + offsetof(Shader::WorkGroup, data[1u]), + sizeof(Shader::WorkGroup::data[1u]), }, // Z { - 3u, - z_offset, - z_size, + 2u, + offsetof(Shader::WorkGroup, data[2u]), + sizeof(Shader::WorkGroup::data[2u]), }, }; const VkSpecializationInfo specialization_info{ 3u, specialization_map_entires, - sizeof(Shader::WorkGroup), - &descriptor.work_group, + sizeof(descriptor.local_work_group), + &descriptor.local_work_group, }; const VkComputePipelineCreateInfo compute_pipeline_create_info{ @@ -172,6 +165,14 @@ typename Pipeline::Factory::Handle Pipeline::Factory::operator()( }; } +Pipeline::Cache::Cache(Factory factory) + : cache_(std::move(factory)) { +} + +void Pipeline::Cache::purge() { + cache_.purge(); +} + } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.h b/aten/src/ATen/native/vulkan/api/Pipeline.h index c327a140eded3..794193d8a1614 100644 --- a/aten/src/ATen/native/vulkan/api/Pipeline.h +++ b/aten/src/ATen/native/vulkan/api/Pipeline.h @@ -1,7 +1,10 @@ #pragma once +#ifdef USE_VULKAN_API + #include #include +#include #include #include @@ -30,6 +33,22 @@ namespace api { // struct Pipeline final { + // + // Barrier + // + + struct Barrier final { + struct Stage final { + VkPipelineStageFlags src; + VkPipelineStageFlags dst; + } stage; + + c10::SmallVector buffers; + c10::SmallVector images; + + operator bool() const; + }; + // // Layout // @@ -77,6 +96,21 @@ struct Pipeline final { } } layout; + // + // Stage + // + + struct Stage final { + typedef uint8_t Flags; + + enum Type : Flags { + None = 0u << 0u, + Compute = 1u << 0u, + Host = 1u << 1u, + Transfer = 1u << 2u, + }; + }; + /* Descriptor */ @@ -84,7 +118,7 @@ struct Pipeline final { struct Descriptor final { VkPipelineLayout pipeline_layout; VkShaderModule shader_module; - Shader::WorkGroup work_group; + Shader::WorkGroup local_work_group; }; /* @@ -110,12 +144,37 @@ struct Pipeline final { api::Handle pipeline_cache_; }; + /* + Object + */ + + struct Object final { + VkPipeline handle; + VkPipelineLayout layout; + Shader::WorkGroup local_work_group; + + operator bool() const; + }; + /* Cache */ - typedef api::Cache Cache; - Cache cache; + class Cache final { + public: + explicit Cache(Factory factory); + Cache(const Cache&) = delete; + Cache& operator=(const Cache&) = delete; + Cache(Cache&&) = default; + Cache& operator=(Cache&&) = default; + ~Cache() = default; + + Object retrieve(const Descriptor& descriptor); + void purge(); + + private: + api::Cache cache_; + } cache; explicit Pipeline(const GPU& gpu) : layout(gpu), @@ -127,10 +186,21 @@ struct Pipeline final { // Impl // +inline Pipeline::Barrier::operator bool() const { + return (0u != stage.src) || + (0u != stage.dst) || + !buffers.empty() || + !images.empty(); +} + inline bool operator==( const Pipeline::Layout::Descriptor& _1, const Pipeline::Layout::Descriptor& _2) { - return (_1.descriptor_set_layout == _2.descriptor_set_layout); + static_assert( + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); + + return (0 == memcmp(&_1, &_2, sizeof(Pipeline::Layout::Descriptor))); } inline size_t Pipeline::Layout::Factory::Hasher::operator()( @@ -141,9 +211,11 @@ inline size_t Pipeline::Layout::Factory::Hasher::operator()( inline bool operator==( const Pipeline::Descriptor& _1, const Pipeline::Descriptor& _2) { - return (_1.pipeline_layout == _2.pipeline_layout) && - (_1.shader_module == _2.shader_module) && - (_1.work_group == _2.work_group); + static_assert( + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); + + return (0 == memcmp(&_1, &_2, sizeof(Pipeline::Descriptor))); } inline size_t Pipeline::Factory::Hasher::operator()( @@ -151,12 +223,28 @@ inline size_t Pipeline::Factory::Hasher::operator()( return c10::get_hash( descriptor.pipeline_layout, descriptor.shader_module, - descriptor.work_group.x, - descriptor.work_group.y, - descriptor.work_group.z); + descriptor.local_work_group.data[0u], + descriptor.local_work_group.data[1u], + descriptor.local_work_group.data[2u]); +} + +inline Pipeline::Object::operator bool() const { + return (VK_NULL_HANDLE != handle) && + (VK_NULL_HANDLE != layout); +} + +inline Pipeline::Object Pipeline::Cache::retrieve( + const Descriptor& descriptor) { + return { + cache_.retrieve(descriptor), + descriptor.pipeline_layout, + descriptor.local_work_group, + }; } } // namespace api } // namespace vulkan } // namespace native } // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/Resource.cpp b/aten/src/ATen/native/vulkan/api/Resource.cpp index 6969883cb1832..adda610fb90c6 100644 --- a/aten/src/ATen/native/vulkan/api/Resource.cpp +++ b/aten/src/ATen/native/vulkan/api/Resource.cpp @@ -32,7 +32,7 @@ VmaAllocator create_allocator( nullptr, 1u, nullptr, - nullptr, // TODO (Ashkan): VULKAN_WRAPPER + nullptr, nullptr, instance, VK_API_VERSION_1_0, @@ -40,19 +40,20 @@ VmaAllocator create_allocator( VmaAllocator allocator{}; VK_CHECK(vmaCreateAllocator(&allocator_create_info, &allocator)); - TORCH_CHECK(allocator, "Invalid VMA allocator!"); + TORCH_CHECK(allocator, "Invalid VMA (Vulkan Memory Allocator) allocator!"); return allocator; } VmaAllocationCreateInfo create_allocation_create_info( - const VmaMemoryUsage usage) { + const Resource::Memory::Descriptor& descriptor) { return VmaAllocationCreateInfo{ - 0u, /* VMA_ALLOCATION_CREATE_MAPPED_BIT - MoltenVK Issue #175 */ - /* VMA_ALLOCATION_CREATE_STRATEGY_MIN_FRAGMENTATION_BIT */ - usage, - 0u, - 0u, + VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT | + /* VMA_ALLOCATION_CREATE_MAPPED_BIT - MoltenVK Issue #175 */ + 0, + descriptor.usage, + descriptor.required, + descriptor.preferred, 0u, VK_NULL_HANDLE, nullptr, @@ -63,53 +64,59 @@ void release_buffer(const Resource::Buffer& buffer) { // Safe to pass null as buffer or allocation. vmaDestroyBuffer( buffer.memory.allocator, - buffer.handle, + buffer.object.handle, buffer.memory.allocation); } void release_image(const Resource::Image& image) { - if (VK_NULL_HANDLE != image.view) { + // Sampler is an immutable object. Its lifetime is managed through the cache. + + if (VK_NULL_HANDLE != image.object.view) { VmaAllocatorInfo allocator_info{}; vmaGetAllocatorInfo(image.memory.allocator, &allocator_info); - vkDestroyImageView(allocator_info.device, image.view, nullptr); + vkDestroyImageView(allocator_info.device, image.object.view, nullptr); } // Safe to pass null as image or allocation. vmaDestroyImage( image.memory.allocator, - image.handle, + image.object.handle, image.memory.allocation); } } // namespace -void* map(const Resource::Memory& memory) { - // Call will be ignored by implementation if the memory type this allocation - // belongs to is not HOST_VISIBLE or is HOST_COHERENT, which is the behavior - // we want. - VK_CHECK(vmaInvalidateAllocation( - memory.allocator, memory.allocation, 0u, VK_WHOLE_SIZE)); - +void* map( + const Resource::Memory& memory, + const Resource::Memory::Access::Flags access) { void* data = nullptr; VK_CHECK(vmaMapMemory(memory.allocator, memory.allocation, &data)); + if (access & Resource::Memory::Access::Read) { + // Call will be ignored by implementation if the memory type this allocation + // belongs to is not HOST_VISIBLE or is HOST_COHERENT, which is the behavior + // we want. + VK_CHECK(vmaInvalidateAllocation( + memory.allocator, memory.allocation, 0u, VK_WHOLE_SIZE)); + } + return data; } Resource::Memory::Scope::Scope( const VmaAllocator allocator, const VmaAllocation allocation, - const Access access) + const Access::Flags access) : allocator_(allocator), allocation_(allocation), access_(access) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( allocator, - "Invalid VMA allocator!"); + "Invalid VMA (Vulkan Memory Allocator) allocator!"); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( allocation, - "Invalid VMA allocation!"); + "Invalid VMA (Vulkan Memory Allocator) allocation!"); } void Resource::Memory::Scope::operator()(const void* const data) const { @@ -117,30 +124,312 @@ void Resource::Memory::Scope::operator()(const void* const data) const { return; } - vmaUnmapMemory(allocator_, allocation_); - - if (Access::Write == access_) { + if (access_ & Access::Write) { // Call will be ignored by implementation if the memory type this allocation // belongs to is not HOST_VISIBLE or is HOST_COHERENT, which is the behavior // we want. VK_CHECK(vmaFlushAllocation(allocator_, allocation_, 0u, VK_WHOLE_SIZE)); } + + vmaUnmapMemory(allocator_, allocation_); +} + +Resource::Image::Sampler::Factory::Factory(const GPU& gpu) + : device_(gpu.device) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_, + "Invalid Vulkan device!"); +} + +typename Resource::Image::Sampler::Factory::Handle +Resource::Image::Sampler::Factory::operator()( + const Descriptor& descriptor) const { + const VkSamplerCreateInfo sampler_create_info{ + VK_STRUCTURE_TYPE_SAMPLER_CREATE_INFO, + nullptr, + 0u, + descriptor.filter, + descriptor.filter, + descriptor.mipmap_mode, + descriptor.address_mode, + descriptor.address_mode, + descriptor.address_mode, + 0.0f, + VK_FALSE, + 1.0f, + VK_FALSE, + VK_COMPARE_OP_NEVER, + 0.0f, + VK_LOD_CLAMP_NONE, + descriptor.border, + VK_FALSE, + }; + + VkSampler sampler{}; + VK_CHECK(vkCreateSampler( + device_, + &sampler_create_info, + nullptr, + &sampler)); + + TORCH_CHECK( + sampler, + "Invalid Vulkan image sampler!"); + + return Handle{ + sampler, + Deleter(device_), + }; } -Resource::Pool::Pool(const GPU& gpu) +VkFence Resource::Fence::handle(const bool add_to_waitlist) const { + if (!pool) { + return VK_NULL_HANDLE; + } + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + id < pool->fence_.pool.size(), + "Invalid Vulkan fence!"); + + const VkFence fence = pool->fence_.pool[id].get(); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + fence, + "Invalid Vulkan fence!"); + + if (add_to_waitlist) { + pool->fence_.waitlist.push_back(fence); + } + + return fence; +} + +void Resource::Fence::wait(const uint64_t timeout_nanoseconds) { + const VkFence fence = handle(/* add_to_waitlist = */ false); + + const auto waitlist_itr = std::find( + pool->fence_.waitlist.cbegin(), + pool->fence_.waitlist.cend(), + fence); + + if (pool->fence_.waitlist.cend() != waitlist_itr) { + VK_CHECK(vkWaitForFences( + pool->device_, + 1u, + &fence, + VK_TRUE, + timeout_nanoseconds)); + + VK_CHECK(vkResetFences( + pool->device_, + 1u, + &fence)); + + pool->fence_.waitlist.erase(waitlist_itr); + } +} + +namespace { + +class Linear final : public Resource::Pool::Policy { + public: + Linear( + VkDeviceSize block_size, + uint32_t min_block_count, + uint32_t max_block_count); + + virtual void enact( + VmaAllocator allocator, + const VkMemoryRequirements& memory_requirements, + VmaAllocationCreateInfo& allocation_create_info) override; + + private: + struct Configuration final { + static constexpr uint32_t kReserve = 16u; + }; + + struct Entry final { + class Deleter final { + public: + explicit Deleter(VmaAllocator); + void operator()(VmaPool) const; + + private: + VmaAllocator allocator_; + }; + + uint32_t memory_type_index; + Handle handle; + }; + + std::vector pools_; + + struct { + VkDeviceSize size; + uint32_t min; + uint32_t max; + } block_; +}; + +Linear::Entry::Deleter::Deleter(const VmaAllocator allocator) + : allocator_(allocator) { +} + +void Linear::Entry::Deleter::operator()(const VmaPool pool) const { + vmaDestroyPool(allocator_, pool); +} + +Linear::Linear( + const VkDeviceSize block_size, + const uint32_t min_block_count, + const uint32_t max_block_count) + : block_ { + block_size, + min_block_count, + max_block_count, + } { + pools_.reserve(Configuration::kReserve); +} + +void Linear::enact( + const VmaAllocator allocator, + const VkMemoryRequirements& memory_requirements, + VmaAllocationCreateInfo& allocation_create_info) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + allocator, + "Invalid VMA (Vulkan Memory Allocator) allocator!"); + + uint32_t memory_type_index = 0u; + VK_CHECK(vmaFindMemoryTypeIndex( + allocator, + memory_requirements.memoryTypeBits, + &allocation_create_info, + &memory_type_index)); + + auto pool_itr = std::find_if( + pools_.begin(), + pools_.end(), + [memory_type_index](const Entry& entry) { + return entry.memory_type_index == memory_type_index; + }); + + if (pools_.end() == pool_itr) { + const VmaPoolCreateInfo pool_create_info{ + memory_type_index, + VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT, + block_.size, + block_.min, + block_.max, + 0u, + }; + + VmaPool pool{}; + VK_CHECK(vmaCreatePool( + allocator, + &pool_create_info, + &pool)); + + TORCH_CHECK( + pool, + "Invalid VMA (Vulkan Memory Allocator) memory pool!"); + + pools_.push_back({ + memory_type_index, + { + pool, + Entry::Deleter(allocator), + }, + }); + + pool_itr = std::prev(pools_.end()); + } + + allocation_create_info.pool = pool_itr->handle.get(); +} + +} // namespace + +std::unique_ptr Resource::Pool::Policy::linear( + const VkDeviceSize block_size, + const uint32_t min_block_count, + const uint32_t max_block_count) { + return std::make_unique( + block_size, + min_block_count, + max_block_count); +} + +Resource::Pool::Pool( + const GPU& gpu, + std::unique_ptr policy) : device_(gpu.device), allocator_( create_allocator( - gpu.adapter->runtime->instance(), - gpu.adapter->handle, - device_), - vmaDestroyAllocator) { - buffers_.reserve(Configuration::kReserve); - images_.reserve(Configuration::kReserve); + gpu.adapter->runtime->instance(), + gpu.adapter->handle, + device_), + vmaDestroyAllocator), + memory_{ + std::move(policy), + }, + image_{ + .sampler = Image::Sampler{gpu}, + }, + fence_{} { + buffer_.pool.reserve(Configuration::kReserve); + image_.pool.reserve(Configuration::kReserve); + fence_.pool.reserve(Configuration::kReserve); +} + +Resource::Pool::Pool(Pool&& pool) + : device_(std::move(pool.device_)), + allocator_(std::move(pool.allocator_)), + memory_(std::move(pool.memory_)), + buffer_(std::move(pool.buffer_)), + image_(std::move(pool.image_)), + fence_(std::move(pool.fence_)) { + pool.invalidate(); +} + +Resource::Pool& Resource::Pool::operator=(Pool&& pool) { + if (&pool != this) { + device_ = std::move(pool.device_); + allocator_ = std::move(pool.allocator_); + memory_ = std::move(pool.memory_); + buffer_ = std::move(pool.buffer_); + image_ = std::move(pool.image_); + fence_ = std::move(pool.fence_); + + pool.invalidate(); + }; + + return *this; +} + +Resource::Pool::~Pool() { + try { + if (device_ && allocator_) { + purge(); + } + } + catch (const std::exception& e) { + LOG(WARNING) + << "Vulkan: Resource pool destructor raised an exception! Error: " + << e.what(); + } + catch (...) { + LOG(WARNING) + << "Vulkan: Resource pool destructor raised an unknown exception!"; + } } -Resource::Buffer Resource::Pool::allocate( +Resource::Buffer Resource::Pool::buffer( const Buffer::Descriptor& descriptor) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && allocator_, + "This resource pool is in an invalid state! ", + "Potential reason: This resource pool is moved from."); + const VkBufferCreateInfo buffer_create_info{ VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO, nullptr, @@ -152,40 +441,74 @@ Resource::Buffer Resource::Pool::allocate( nullptr, }; - const VmaAllocationCreateInfo allocation_create_info = + VkBuffer buffer{}; + VK_CHECK(vkCreateBuffer( + device_, + &buffer_create_info, + nullptr, + &buffer)); + + TORCH_CHECK( + buffer, + "Invalid Vulkan buffer!"); + + VkMemoryRequirements memory_requirements{}; + vkGetBufferMemoryRequirements( + device_, + buffer, + &memory_requirements); + + VmaAllocationCreateInfo allocation_create_info = create_allocation_create_info(descriptor.usage.memory); - VkBuffer buffer{}; - VmaAllocation allocation{}; - VmaAllocationInfo allocation_info{}; + if (memory_.policy) { + memory_.policy->enact( + allocator_.get(), + memory_requirements, + allocation_create_info); + } - VK_CHECK(vmaCreateBuffer( + VmaAllocation allocation{}; + VK_CHECK(vmaAllocateMemory( allocator_.get(), - &buffer_create_info, + &memory_requirements, &allocation_create_info, - &buffer, &allocation, - &allocation_info)); + nullptr)); - TORCH_CHECK(buffer, "Invalid Vulkan buffer!"); - TORCH_CHECK(allocation, "Invalid VMA allocation!"); + TORCH_CHECK( + allocation, + "Invalid VMA (Vulkan Memory Allocator) allocation!"); - buffers_.emplace_back( + VK_CHECK(vmaBindBufferMemory( + allocator_.get(), + allocation, + buffer)); + + buffer_.pool.emplace_back( Buffer{ - buffer, + Buffer::Object{ + buffer, + 0u, + descriptor.size, + }, Memory{ allocator_.get(), allocation, - allocation_info, }, }, &release_buffer); - return buffers_.back().get(); + return buffer_.pool.back().get(); } -Resource::Image Resource::Pool::allocate( +Resource::Image Resource::Pool::image( const Image::Descriptor& descriptor) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && allocator_, + "This resource pool is in an invalid state! ", + "Potential reason: This resource pool is moved from."); + const VkImageCreateInfo image_create_info{ VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, nullptr, @@ -204,23 +527,49 @@ Resource::Image Resource::Pool::allocate( VK_IMAGE_LAYOUT_UNDEFINED, }; - const VmaAllocationCreateInfo allocation_create_info = + VkImage image{}; + VK_CHECK(vkCreateImage( + device_, + &image_create_info, + nullptr, + &image)); + + TORCH_CHECK( + image, + "Invalid Vulkan image!"); + + VkMemoryRequirements memory_requirements{}; + vkGetImageMemoryRequirements( + device_, + image, + &memory_requirements); + + VmaAllocationCreateInfo allocation_create_info = create_allocation_create_info(descriptor.usage.memory); - VkImage image{}; - VmaAllocation allocation{}; - VmaAllocationInfo allocation_info{}; + if (memory_.policy) { + memory_.policy->enact( + allocator_.get(), + memory_requirements, + allocation_create_info); + } - VK_CHECK(vmaCreateImage( + VmaAllocation allocation{}; + VK_CHECK(vmaAllocateMemory( allocator_.get(), - &image_create_info, + &memory_requirements, &allocation_create_info, - &image, &allocation, - &allocation_info)); + nullptr)); - TORCH_CHECK(image, "Invalid Vulkan image!"); - TORCH_CHECK(allocation, "Invalid VMA allocation!"); + TORCH_CHECK( + allocation, + "Invalid VMA (Vulkan Memory Allocator) allocation!"); + + VK_CHECK(vmaBindImageMemory( + allocator_.get(), + allocation, + image)); const VkImageViewCreateInfo image_view_create_info{ VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO, @@ -238,9 +587,9 @@ Resource::Image Resource::Pool::allocate( { VK_IMAGE_ASPECT_COLOR_BIT, 0u, - 1u, + VK_REMAINING_MIP_LEVELS, 0u, - 1u, + VK_REMAINING_ARRAY_LAYERS, }, }; @@ -255,24 +604,87 @@ Resource::Image Resource::Pool::allocate( view, "Invalid Vulkan image view!"); - images_.emplace_back( + image_.pool.emplace_back( Image{ - image, - view, + Image::Object{ + image, + VK_IMAGE_LAYOUT_UNDEFINED, + view, + image_.sampler.cache.retrieve(descriptor.sampler), + }, Memory{ allocator_.get(), allocation, - allocation_info, }, }, &release_image); - return images_.back().get(); + return image_.pool.back().get(); +} + +Resource::Fence Resource::Pool::fence() { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && allocator_, + "This resource pool is in an invalid state! ", + "Potential reason: This resource pool is moved from."); + + if (fence_.pool.size() == fence_.in_use) { + const VkFenceCreateInfo fence_create_info{ + VK_STRUCTURE_TYPE_FENCE_CREATE_INFO, + nullptr, + 0u, + }; + + VkFence fence{}; + VK_CHECK(vkCreateFence( + device_, + &fence_create_info, + nullptr, + &fence)); + + TORCH_CHECK( + fence, + "Invalid Vulkan fence!"); + + fence_.pool.emplace_back(fence, VK_DELETER(Fence)(device_)); + } + + return Fence{ + this, + fence_.in_use++, + }; } void Resource::Pool::purge() { - images_.clear(); - buffers_.clear(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && allocator_, + "This resource pool is in an invalid state! ", + "Potential reason: This resource pool is moved from."); + + if (!fence_.waitlist.empty()) { + VK_CHECK(vkWaitForFences( + device_, + fence_.waitlist.size(), + fence_.waitlist.data(), + VK_TRUE, + UINT64_MAX)); + + VK_CHECK(vkResetFences( + device_, + fence_.waitlist.size(), + fence_.waitlist.data())); + + fence_.waitlist.clear(); + } + + fence_.in_use = 0u; + image_.pool.clear(); + buffer_.pool.clear(); +} + +void Resource::Pool::invalidate() { + device_ = VK_NULL_HANDLE; + allocator_.reset(); } } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Resource.h b/aten/src/ATen/native/vulkan/api/Resource.h index 00145ebe071fb..19a7df3d04d2c 100644 --- a/aten/src/ATen/native/vulkan/api/Resource.h +++ b/aten/src/ATen/native/vulkan/api/Resource.h @@ -1,7 +1,11 @@ #pragma once +#ifdef USE_VULKAN_API + #include #include +#include +#include namespace at { namespace native { @@ -9,28 +13,70 @@ namespace vulkan { namespace api { struct Resource final { - /* - Memory - */ + class Pool; + + // + // Memory + // struct Memory final { - VmaAllocator allocator; - VmaAllocation allocation; - VmaAllocationInfo allocation_info; + /* + Descriptor + */ + + struct Descriptor final { + VmaMemoryUsage usage; + VkMemoryPropertyFlags /* optional */ required; + VkMemoryPropertyFlags /* optional */ preferred; + }; + + /* + Barrier + */ + + struct Barrier final { + VkAccessFlags src; + VkAccessFlags dst; + }; + + /* + Access + */ + + struct Access final { + typedef uint8_t Flags; + + enum Type : Flags { + None = 0u << 0u, + Read = 1u << 0u, + Write = 1u << 1u, + }; + + template + using Pointer = std::add_pointer_t< + std::conditional_t< + 0u != (access & Write), + Type, + std::add_const_t>>; + }; class Scope; template - using Data = Handle; + using Handle = Handle; template< typename Type, - typename Pointer = std::add_pointer_t>> - Data map() const &; + typename Pointer = Access::Pointer> + Handle map() const &; template< typename Type, - typename Pointer = std::add_pointer_t> - Data map() &; + Access::Flags kAccess, + typename Pointer = Access::Pointer> + Handle map() &; + + VmaAllocator allocator; + VmaAllocation allocation; private: // Intentionally disabed to ensure memory access is always properly @@ -40,15 +86,15 @@ struct Resource final { // for seemingly ineffective memory writes and hard to hunt down bugs. template - Data map() const && = delete; + Handle map() const && = delete; - template - Data map() && = delete; + template + Handle map() && = delete; }; - /* - Buffer - */ + // + // Buffer + // struct Buffer final { /* @@ -60,21 +106,92 @@ struct Resource final { struct { VkBufferUsageFlags buffer; - VmaMemoryUsage memory; + Memory::Descriptor memory; } usage; }; - VkBuffer handle; + /* + Object + */ + + struct Object final { + VkBuffer handle; + VkDeviceSize offset; + VkDeviceSize range; + + operator bool() const; + }; + + /* + Barrier + */ + + struct Barrier final { + Object object; + Memory::Barrier memory; + }; + + Object object; Memory memory; operator bool() const; }; - /* - Image - */ + // + // Image + // struct Image final { + // + // Sampler + // + + struct Sampler final { + /* + Descriptor + */ + + struct Descriptor final { + VkFilter filter; + VkSamplerMipmapMode mipmap_mode; + VkSamplerAddressMode address_mode; + VkBorderColor border; + }; + + /* + Factory + */ + + class Factory final { + public: + explicit Factory(const GPU& gpu); + + typedef Sampler::Descriptor Descriptor; + typedef VK_DELETER(Sampler) Deleter; + typedef Handle Handle; + + struct Hasher { + size_t operator()(const Descriptor& descriptor) const; + }; + + Handle operator()(const Descriptor& descriptor) const; + + private: + VkDevice device_; + }; + + /* + Cache + */ + + typedef api::Cache Cache; + Cache cache; + + explicit Sampler(const GPU& gpu) + : cache(Factory(gpu)) { + } + }; + /* Descriptor */ @@ -86,34 +203,108 @@ struct Resource final { struct { VkImageUsageFlags image; - VmaMemoryUsage memory; + Memory::Descriptor memory; } usage; struct { VkImageViewType type; VkFormat format; } view; + + Sampler::Descriptor sampler; }; - VkImage handle; - VkImageView view; + /* + Object + */ + + struct Object final { + VkImage handle; + VkImageLayout layout; + VkImageView view; + VkSampler sampler; + + operator bool() const; + }; + + /* + Barrier + */ + + struct Barrier final { + Object object; + Memory::Barrier memory; + + struct { + VkImageLayout src; + VkImageLayout dst; + } layout; + }; + + Object object; Memory memory; operator bool() const; }; - /* - Pool - */ + // + // Fence + // + + struct Fence final { + Pool* pool; + size_t id; + + operator bool() const; + VkFence handle(bool add_to_waitlist = true) const; + void wait(uint64_t timeout_nanoseconds = UINT64_MAX); + }; + + // + // Pool + // class Pool final { public: - explicit Pool(const GPU& gpu); + class Policy { + public: + virtual ~Policy() = default; + + static std::unique_ptr linear( + VkDeviceSize block_size = VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE, + uint32_t min_block_count = 1u, + uint32_t max_block_count = UINT32_MAX); + + virtual void enact( + VmaAllocator allocator, + const VkMemoryRequirements& memory_requirements, + VmaAllocationCreateInfo& allocation_create_info) = 0; + }; + + explicit Pool(const GPU& gpu, std::unique_ptr = {}); + Pool(const Pool&) = delete; + Pool& operator=(const Pool&) = delete; + Pool(Pool&&); + Pool& operator=(Pool&&); + ~Pool(); + + // Primary - Buffer allocate(const Buffer::Descriptor& descriptor); - Image allocate(const Image::Descriptor& descriptor); + Buffer buffer(const Buffer::Descriptor& descriptor); + Image image(const Image::Descriptor& descriptor); + Fence fence(); void purge(); + // Helper + + template + Buffer uniform(const Block& block); + + private: + friend struct Fence; + + void invalidate(); + private: struct Configuration final { static constexpr uint32_t kReserve = 256u; @@ -121,12 +312,29 @@ struct Resource final { VkDevice device_; Handle allocator_; - std::vector> buffers_; - std::vector> images_; + + struct { + std::unique_ptr policy; + } memory_; + + struct { + std::vector> pool; + } buffer_; + + struct { + std::vector> pool; + Image::Sampler sampler; + } image_; + + struct { + std::vector> pool; + mutable std::vector waitlist; + size_t in_use; + } fence_; } pool; explicit Resource(const GPU& gpu) - : pool(gpu) { + : pool(gpu, Pool::Policy::linear()) { } }; @@ -136,49 +344,114 @@ struct Resource final { class Resource::Memory::Scope final { public: - enum class Access { - Read, - Write, - }; + Scope( + VmaAllocator allocator, + VmaAllocation allocation, + Access::Flags access); - Scope(VmaAllocator allocator, VmaAllocation allocation, Access access); void operator()(const void* data) const; private: VmaAllocator allocator_; VmaAllocation allocation_; - Access access_; + Access::Flags access_; }; template -inline Resource::Memory::Data Resource::Memory::map() const & { - void* map(const Memory& memory); +inline Resource::Memory::Handle Resource::Memory::map() const & { + // Forward declaration + void* map(const Memory&, Access::Flags); - return Data{ - reinterpret_cast(map(*this)), - Scope(allocator, allocation, Scope::Access::Read), + return Handle{ + reinterpret_cast(map(*this, Access::Read)), + Scope(allocator, allocation, Access::Read), }; } -template -inline Resource::Memory::Data Resource::Memory::map() & { - void* map(const Memory& memory); +template +inline Resource::Memory::Handle Resource::Memory::map() & { + // Forward declaration + void* map(const Memory&, Access::Flags); - return Data{ - reinterpret_cast(map(*this)), - Scope(allocator, allocation, Scope::Access::Write), + static_assert( + (kAccess == Access::Read) || + (kAccess == Access::Write) || + (kAccess == (Access::Read | Access::Write)), + "Invalid memory access!"); + + return Handle{ + reinterpret_cast(map(*this, kAccess)), + Scope(allocator, allocation, kAccess), }; } +inline Resource::Buffer::Object::operator bool() const { + return VK_NULL_HANDLE != handle; +} + inline Resource::Buffer::operator bool() const { + return object; +} + +inline bool operator==( + const Resource::Image::Sampler::Descriptor& _1, + const Resource::Image::Sampler::Descriptor& _2) { + static_assert( + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); + + return (0 == memcmp(&_1, &_2, sizeof(Resource::Image::Sampler::Descriptor))); +} + +inline size_t Resource::Image::Sampler::Factory::Hasher::operator()( + const Descriptor& descriptor) const { + return c10::get_hash( + descriptor.filter, + descriptor.mipmap_mode, + descriptor.address_mode, + descriptor.border); +} + +inline Resource::Image::Object::operator bool() const { return VK_NULL_HANDLE != handle; } inline Resource::Image::operator bool() const { - return VK_NULL_HANDLE != handle; + return object; +} + +inline Resource::Fence::operator bool() const { + return pool; +} + +template +inline Resource::Buffer Resource::Pool::uniform(const Block& block) { + Buffer uniform = this->buffer({ + sizeof(Block), + { + VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, + { + VMA_MEMORY_USAGE_CPU_TO_GPU, + 0u, + 0u, + }, + }, + }); + + { + Memory::Handle memory = uniform.memory.template map< + Block, + Memory::Access::Write>(); + + *memory.get() = block; + } + + return uniform; } } // namespace api } // namespace vulkan } // namespace native } // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/Runtime.cpp b/aten/src/ATen/native/vulkan/api/Runtime.cpp index ce6e3b4231e4c..c3ad6ebddb45e 100644 --- a/aten/src/ATen/native/vulkan/api/Runtime.cpp +++ b/aten/src/ATen/native/vulkan/api/Runtime.cpp @@ -10,7 +10,7 @@ namespace api { namespace { struct Configuration final { -#ifndef DEBUG +#ifdef DEBUG static constexpr Runtime::Type kRuntime = Runtime::Type::Debug; #else static constexpr Runtime::Type kRuntime = Runtime::Type::Release; @@ -86,7 +86,9 @@ VkInstance create_instance(const Runtime::Type type) { nullptr, &instance_extension_count, instance_extension_properties.data())); constexpr const char* const requested_instance_extensions[]{ + #ifdef VK_EXT_debug_report VK_EXT_DEBUG_REPORT_EXTENSION_NAME, + #endif }; for (const auto& requested_instance_extension : requested_instance_extensions) { @@ -323,10 +325,6 @@ Runtime* initialize() { return runtime.get(); } -bool available() { - return initialize(); -} - Runtime* runtime() { Runtime* const runtime = initialize(); TORCH_CHECK( diff --git a/aten/src/ATen/native/vulkan/api/Runtime.h b/aten/src/ATen/native/vulkan/api/Runtime.h index 766aeb50cabce..55eae70f8723d 100644 --- a/aten/src/ATen/native/vulkan/api/Runtime.h +++ b/aten/src/ATen/native/vulkan/api/Runtime.h @@ -1,5 +1,7 @@ #pragma once +#ifdef USE_VULKAN_API + #include namespace at { @@ -26,15 +28,12 @@ class Runtime final { explicit Runtime(Type type); Runtime(const Runtime&) = delete; - Runtime(Runtime&&) = default; Runtime& operator=(const Runtime&) = delete; + Runtime(Runtime&&) = default; Runtime& operator=(Runtime&&) = default; ~Runtime() = default; - inline VkInstance instance() const { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(instance_); - return instance_.get(); - } + VkInstance instance() const; typedef std::function Selector; Adapter select(const Selector& selector); @@ -42,8 +41,8 @@ class Runtime final { private: class Debug final { public: - explicit Debug(VkInstance instance); - void operator()(VkDebugReportCallbackEXT debug_report_callback) const; + explicit Debug(VkInstance); + void operator()(VkDebugReportCallbackEXT) const; private: VkInstance instance_; @@ -55,10 +54,20 @@ class Runtime final { Handle debug_report_callback_; }; -bool available(); Runtime* runtime(); +// +// Impl +// + +inline VkInstance Runtime::instance() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(instance_); + return instance_.get(); +} + } // namespace api } // namespace vulkan } // namespace native } // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/Shader.cpp b/aten/src/ATen/native/vulkan/api/Shader.cpp index 977f915a61d16..7995dd160c35f 100644 --- a/aten/src/ATen/native/vulkan/api/Shader.cpp +++ b/aten/src/ATen/native/vulkan/api/Shader.cpp @@ -9,7 +9,6 @@ namespace native { namespace vulkan { namespace api { - Shader::Layout::Factory::Factory(const GPU& gpu) : device_(gpu.device) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( @@ -19,12 +18,25 @@ Shader::Layout::Factory::Factory(const GPU& gpu) Shader::Layout::Factory::Handle Shader::Layout::Factory::operator()( const Descriptor& descriptor) const { + c10::SmallVector bindings; + + uint32_t binding = 0u; + for (const VkDescriptorType type : descriptor.signature) { + bindings.push_back({ + binding++, + type, + 1u, + VK_SHADER_STAGE_COMPUTE_BIT, + nullptr, + }); + } + const VkDescriptorSetLayoutCreateInfo descriptor_set_layout_create_info{ VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, nullptr, 0u, - static_cast(descriptor.bindings.size()), - descriptor.bindings.data(), + static_cast(bindings.size()), + bindings.data(), }; VkDescriptorSetLayout descriptor_set_layout{}; @@ -44,24 +56,12 @@ Shader::Layout::Factory::Handle Shader::Layout::Factory::operator()( }; } -Shader::Descriptor::Descriptor(const char* const glsl) - : type(Type::Source) { - TORCH_CHECK(glsl, "Invalid shader source code!"); - - shader.source = { - glsl, - 0u, - }; +Shader::Layout::Cache::Cache(Factory factory) + : cache_(std::move(factory)) { } -Shader::Descriptor::Descriptor(const uint32_t* const code, const uint32_t size) - : type(Type::Binary) { - TORCH_CHECK(code && (0u != size), "Invalid shader binary!"); - - shader.binary = { - code, - size, - }; +void Shader::Layout::Cache::purge() { + cache_.purge(); } #ifdef USE_VULKAN_SHADERC_RUNTIME @@ -71,6 +71,7 @@ struct Shader::Factory::Compiler final { shaderc::CompileOptions options; Compiler() { + options.SetNanClamp(/*enable =*/ true); options.SetSourceLanguage(shaderc_source_language_glsl); options.SetTargetEnvironment(shaderc_target_env_vulkan, shaderc_env_version_vulkan_1_0); options.SetWarningsAsErrors(); diff --git a/aten/src/ATen/native/vulkan/api/Shader.h b/aten/src/ATen/native/vulkan/api/Shader.h index ff02b2ba90647..f005eb1c11e99 100644 --- a/aten/src/ATen/native/vulkan/api/Shader.h +++ b/aten/src/ATen/native/vulkan/api/Shader.h @@ -1,7 +1,10 @@ #pragma once +#ifdef USE_VULKAN_API + #include #include +#include #include namespace at { @@ -38,12 +41,18 @@ struct Shader final { // struct Layout final { + /* + Signature + */ + + typedef c10::SmallVector Signature; + /* Descriptor */ struct Descriptor final { - c10::SmallVector bindings; + Signature signature; }; /* @@ -68,12 +77,32 @@ struct Shader final { VkDevice device_; }; + struct Object final { + VkDescriptorSetLayout handle; + Signature signature; + + operator bool() const; + }; + /* Cache */ - typedef api::Cache Cache; - Cache cache; + class Cache final { + public: + explicit Cache(Factory factory); + Cache(const Cache&) = delete; + Cache& operator=(const Cache&) = delete; + Cache(Cache&&) = default; + Cache& operator=(Cache&&) = default; + ~Cache() = default; + + Object retrieve(const Descriptor& descriptor); + void purge(); + + private: + api::Cache cache_; + } cache; explicit Layout(const GPU& gpu) : cache(Factory(gpu)) { @@ -84,11 +113,7 @@ struct Shader final { // Work Group // - struct WorkGroup final { - uint32_t x; - uint32_t y; - uint32_t z; - }; + typedef utils::uvec3 WorkGroup; /* Descriptor @@ -165,45 +190,76 @@ struct Shader final { inline bool operator==( const Shader::Layout::Descriptor& _1, const Shader::Layout::Descriptor& _2) { - return _1.bindings == _2.bindings; + return _1.signature == _2.signature; } inline size_t Shader::Layout::Factory::Hasher::operator()( const Descriptor& descriptor) const { size_t hash = 0u; - for (const VkDescriptorSetLayoutBinding& binding : descriptor.bindings) { + for (const VkDescriptorType type : descriptor.signature) { hash = c10::hash_combine( hash, - c10::get_hash( - binding.binding, - binding.descriptorType, - binding.descriptorCount, - binding.stageFlags, - binding.pImmutableSamplers)); + c10::get_hash(type)); } return hash; } +inline Shader::Layout::Object::operator bool() const { + return VK_NULL_HANDLE != handle; +} + +inline Shader::Layout::Object Shader::Layout::Cache::retrieve( + const Descriptor& descriptor) { + return { + cache_.retrieve(descriptor), + descriptor.signature, + }; +} + inline bool operator==( const Shader::WorkGroup& _1, const Shader::WorkGroup& _2) { - return (_1.x == _2.x) && - (_1.y == _2.y) && - (_1.z == _2.z); + static_assert( + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); + + return (0 == memcmp(&_1, &_2, sizeof(Shader::WorkGroup))); +} + +inline Shader::Descriptor::Descriptor(const char* const glsl) + : type(Type::Source), + shader{ + .source = { + glsl, + 0u, + }, + } { + TORCH_CHECK(glsl, "Invalid shader source code!"); +} + +inline Shader::Descriptor::Descriptor( + const uint32_t* const code, + const uint32_t size) + : type(Type::Binary), + shader{ + .binary = { + code, + size, + }, + } { + TORCH_CHECK(code && (0u != size), "Invalid shader binary!"); } inline bool operator==( const Shader::Descriptor& _1, const Shader::Descriptor& _2) { static_assert( - sizeof(Shader::Descriptor::shader.source) == sizeof(Shader::Descriptor::shader.binary), - "This implementation requires sizeof(Source) to be equal to sizeof(Binary)."); + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); - return (_1.type == _2.type) && - (_1.shader.binary.spirv == _2.shader.binary.spirv) && - (_1.shader.binary.size == _2.shader.binary.size); + return (0 == memcmp(&_1, &_2, sizeof(Shader::Descriptor))); } inline size_t Shader::Factory::Hasher::operator()( @@ -226,9 +282,11 @@ inline size_t Shader::Factory::Hasher::operator()( inline bool operator==( const VkDescriptorSetLayoutBinding& _1, const VkDescriptorSetLayoutBinding& _2) { - return (_1.binding == _2.binding) && - (_1.descriptorType == _2.descriptorType) && - (_1.descriptorCount == _2.descriptorCount) && - (_1.stageFlags == _2.stageFlags) && - (_1.pImmutableSamplers == _2.pImmutableSamplers); + static_assert( + std::is_trivially_copyable::value, + "This implementation is no longer valid!"); + + return (0 == memcmp(&_1, &_2, sizeof(VkDescriptorSetLayoutBinding))); } + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/Utils.h b/aten/src/ATen/native/vulkan/api/Utils.h new file mode 100644 index 0000000000000..1d261849a5e7e --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Utils.h @@ -0,0 +1,112 @@ +#pragma once + +#ifdef USE_VULKAN_API + +namespace at { +namespace native { +namespace vulkan { +namespace api { +namespace utils { + +// +// Alignment +// + +template +inline constexpr Type align_down( + const Type number, + const Type multiple) { + return (number / multiple) * multiple; +} + +template +inline constexpr Type align_up( + const Type number, + const Type multiple) { + return align_down(number + multiple - 1, multiple); +} + +template +inline constexpr Type div_up( + const Type numerator, + const Type denominator) { + return (numerator + denominator - 1) / denominator; +} + +// +// Cast +// + +namespace detail { + +template +inline constexpr To safe_downcast(const From v) { + typedef std::common_type_t Type; + constexpr Type min{static_cast(std::numeric_limits::lowest())}; + constexpr Type max{static_cast(std::numeric_limits::max())}; + TORCH_CHECK(min <= v && v <= max, "Cast failed: out of range!"); + return static_cast(v); +} + +template +inline constexpr bool is_signed_to_unsigned() { + return std::is_signed::value && std::is_unsigned::value; +} + +} // namespace detail + +template < + typename To, + typename From, + std::enable_if_t(), bool> = true> +inline constexpr To safe_downcast(const From v) { + TORCH_CHECK(v >= From{}, "Cast failed: negative signed to unsigned!"); + return detail::safe_downcast(v); +} + +template < + typename To, + typename From, + std::enable_if_t(), bool> = true> +inline constexpr To safe_downcast(const From v) { + return detail::safe_downcast(v); +} + +// +// Vector +// + +namespace detail { + +template +struct vec final { + Type data[N]; +}; + +} // namespace detail + +template +using ivec = detail::vec; +using ivec2 = ivec<2u>; +using ivec3 = ivec<3u>; +using ivec4 = ivec<4u>; + +template +using uvec = detail::vec; +using uvec2 = uvec<2u>; +using uvec3 = uvec<3u>; +using uvec4 = uvec<4u>; + +template +using vec = detail::vec; +using vec2 = vec<2u>; +using vec3 = vec<3u>; +using vec4 = vec<4u>; + +} // namespace utils +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/api.h b/aten/src/ATen/native/vulkan/api/api.h index 658824e3bf2b4..c20d8b71e3c6f 100644 --- a/aten/src/ATen/native/vulkan/api/api.h +++ b/aten/src/ATen/native/vulkan/api/api.h @@ -1,5 +1,7 @@ #pragma once +#ifdef USE_VULKAN_API + #include #include @@ -10,3 +12,6 @@ #include #include #include +#include + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h b/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h index fdeadf9cdbfa7..c20cfd08ca7d1 100644 --- a/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h +++ b/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h @@ -25,7 +25,7 @@ /** \mainpage Vulkan Memory Allocator -Version 3.0.0-development (2020-06-24) +Version 3.0.0-development (2020-11-03) Copyright (c) 2017-2020 Advanced Micro Devices, Inc. All rights reserved. \n License: MIT @@ -53,6 +53,7 @@ Documentation of all members: vk_mem_alloc.h - \subpage staying_within_budget - [Querying for budget](@ref staying_within_budget_querying_for_budget) - [Controlling memory usage](@ref staying_within_budget_controlling_memory_usage) + - \subpage resource_aliasing - \subpage custom_memory_pools - [Choosing memory type index](@ref custom_memory_pools_MemTypeIndex) - [Linear allocation algorithm](@ref linear_algorithm) @@ -126,7 +127,7 @@ To do it properly: \code #define VMA_IMPLEMENTATION -#include vk_mem_alloc.h +#include \endcode It may be a good idea to create dedicated CPP file just for this purpose. @@ -141,6 +142,15 @@ before including these headers (like `WIN32_LEAN_AND_MEAN` or `WINVER` for Windows, `VK_USE_PLATFORM_WIN32_KHR` for Vulkan), you must define them before every `#include` of this library. +You may need to configure the way you import Vulkan functions. + +- By default, VMA assumes you you link statically with Vulkan API. If this is not the case, + `#define VMA_STATIC_VULKAN_FUNCTIONS 0` before `#include` of the VMA implementation and use another way. +- You can `#define VMA_DYNAMIC_VULKAN_FUNCTIONS 1` and make sure `vkGetInstanceProcAddr` and `vkGetDeviceProcAddr` globals are defined. + All the remaining Vulkan functions will be fetched automatically. +- Finally, you can provide your own pointers to all Vulkan functions needed by VMA using structure member + VmaAllocatorCreateInfo::pVulkanFunctions, if you fetched them in some custom way e.g. using some loader like [Volk](https://github.com/zeux/volk). + \section quick_start_initialization Initialization @@ -152,6 +162,7 @@ At program startup: \code VmaAllocatorCreateInfo allocatorInfo = {}; +allocatorInfo.vulkanApiVersion = VK_API_VERSION_1_2; allocatorInfo.physicalDevice = physicalDevice; allocatorInfo.device = device; allocatorInfo.instance = instance; @@ -160,6 +171,13 @@ VmaAllocator allocator; vmaCreateAllocator(&allocatorInfo, &allocator); \endcode +Only members `physicalDevice`, `device`, `instance` are required. +However, you should inform the library which Vulkan version do you use by setting +VmaAllocatorCreateInfo::vulkanApiVersion and which extensions did you enable +by setting VmaAllocatorCreateInfo::flags (like #VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT for VK_KHR_buffer_device_address). +Otherwise, VMA would use only features of Vulkan 1.0 core with no extensions. + + \section quick_start_resource_allocation Resource allocation When you want to create a buffer or image: @@ -301,6 +319,7 @@ VmaAllocation allocation; vmaCreateBuffer(allocator, &bufferInfo, &allocInfo, &buffer, &allocation, nullptr); \endcode + \section choosing_memory_type_custom_memory_pools Custom memory pools If you allocate from custom memory pool, all the ways of specifying memory @@ -513,7 +532,7 @@ VmaAllocation alloc; VmaAllocationInfo allocInfo; vmaCreateBuffer(allocator, &bufCreateInfo, &allocCreateInfo, &buf, &alloc, &allocInfo); -if(allocInfo.pUserData != nullptr) +if(allocInfo.pMappedData != nullptr) { // Allocation ended up in mappable memory. // It's persistently mapped. You can access it directly. @@ -599,6 +618,114 @@ set to more than 0 will try to allocate memory blocks without checking whether t fit within budget. +\page resource_aliasing Resource aliasing (overlap) + +New explicit graphics APIs (Vulkan and Direct3D 12), thanks to manual memory +management, give an opportunity to alias (overlap) multiple resources in the +same region of memory - a feature not available in the old APIs (Direct3D 11, OpenGL). +It can be useful to save video memory, but it must be used with caution. + +For example, if you know the flow of your whole render frame in advance, you +are going to use some intermediate textures or buffers only during a small range of render passes, +and you know these ranges don't overlap in time, you can bind these resources to +the same place in memory, even if they have completely different parameters (width, height, format etc.). + +![Resource aliasing (overlap)](../gfx/Aliasing.png) + +Such scenario is possible using VMA, but you need to create your images manually. +Then you need to calculate parameters of an allocation to be made using formula: + +- allocation size = max(size of each image) +- allocation alignment = max(alignment of each image) +- allocation memoryTypeBits = bitwise AND(memoryTypeBits of each image) + +Following example shows two different images bound to the same place in memory, +allocated to fit largest of them. + +\code +// A 512x512 texture to be sampled. +VkImageCreateInfo img1CreateInfo = { VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO }; +img1CreateInfo.imageType = VK_IMAGE_TYPE_2D; +img1CreateInfo.extent.width = 512; +img1CreateInfo.extent.height = 512; +img1CreateInfo.extent.depth = 1; +img1CreateInfo.mipLevels = 10; +img1CreateInfo.arrayLayers = 1; +img1CreateInfo.format = VK_FORMAT_R8G8B8A8_SRGB; +img1CreateInfo.tiling = VK_IMAGE_TILING_OPTIMAL; +img1CreateInfo.initialLayout = VK_IMAGE_LAYOUT_UNDEFINED; +img1CreateInfo.usage = VK_IMAGE_USAGE_TRANSFER_DST_BIT | VK_IMAGE_USAGE_SAMPLED_BIT; +img1CreateInfo.samples = VK_SAMPLE_COUNT_1_BIT; + +// A full screen texture to be used as color attachment. +VkImageCreateInfo img2CreateInfo = { VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO }; +img2CreateInfo.imageType = VK_IMAGE_TYPE_2D; +img2CreateInfo.extent.width = 1920; +img2CreateInfo.extent.height = 1080; +img2CreateInfo.extent.depth = 1; +img2CreateInfo.mipLevels = 1; +img2CreateInfo.arrayLayers = 1; +img2CreateInfo.format = VK_FORMAT_R8G8B8A8_UNORM; +img2CreateInfo.tiling = VK_IMAGE_TILING_OPTIMAL; +img2CreateInfo.initialLayout = VK_IMAGE_LAYOUT_UNDEFINED; +img2CreateInfo.usage = VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT; +img2CreateInfo.samples = VK_SAMPLE_COUNT_1_BIT; + +VkImage img1; +res = vkCreateImage(device, &img1CreateInfo, nullptr, &img1); +VkImage img2; +res = vkCreateImage(device, &img2CreateInfo, nullptr, &img2); + +VkMemoryRequirements img1MemReq; +vkGetImageMemoryRequirements(device, img1, &img1MemReq); +VkMemoryRequirements img2MemReq; +vkGetImageMemoryRequirements(device, img2, &img2MemReq); + +VkMemoryRequirements finalMemReq = {}; +finalMemReq.size = std::max(img1MemReq.size, img2MemReq.size); +finalMemReq.alignment = std::max(img1MemReq.alignment, img2MemReq.alignment); +finalMemReq.memoryTypeBits = img1MemReq.memoryTypeBits & img2MemReq.memoryTypeBits; +// Validate if(finalMemReq.memoryTypeBits != 0) + +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; + +VmaAllocation alloc; +res = vmaAllocateMemory(allocator, &finalMemReq, &allocCreateInfo, &alloc, nullptr); + +res = vmaBindImageMemory(allocator, alloc, img1); +res = vmaBindImageMemory(allocator, alloc, img2); + +// You can use img1, img2 here, but not at the same time! + +vmaFreeMemory(allocator, alloc); +vkDestroyImage(allocator, img2, nullptr); +vkDestroyImage(allocator, img1, nullptr); +\endcode + +Remember that using resouces that alias in memory requires proper synchronization. +You need to issue a memory barrier to make sure commands that use `img1` and `img2` +don't overlap on GPU timeline. +You also need to treat a resource after aliasing as uninitialized - containing garbage data. +For example, if you use `img1` and then want to use `img2`, you need to issue +an image memory barrier for `img2` with `oldLayout` = `VK_IMAGE_LAYOUT_UNDEFINED`. + +Additional considerations: + +- Vulkan also allows to interpret contents of memory between aliasing resources consistently in some cases. +See chapter 11.8. "Memory Aliasing" of Vulkan specification or `VK_IMAGE_CREATE_ALIAS_BIT` flag. +- You can create more complex layout where different images and buffers are bound +at different offsets inside one large allocation. For example, one can imagine +a big texture used in some render passes, aliasing with a set of many small buffers +used between in some further passes. To bind a resource at non-zero offset of an allocation, +use vmaBindBufferMemory2() / vmaBindImageMemory2(). +- Before allocating memory for the resources you want to alias, check `memoryTypeBits` +returned in memory requirements of each resource to make sure the bits overlap. +Some GPUs may expose multiple memory types suitable e.g. only for buffers or +images with `COLOR_ATTACHMENT` usage, so the sets of memory types supported by your +resources may be disjoint. Aliasing them is not possible in that case. + + \page custom_memory_pools Custom memory pools A memory pool contains a number of `VkDeviceMemory` blocks. @@ -1286,7 +1413,7 @@ To do it, define macro `VMA_DEBUG_INITIALIZE_ALLOCATIONS` to 1. \code #define VMA_DEBUG_INITIALIZE_ALLOCATIONS 1 -#include vk_mem_alloc.h +#include \endcode It makes memory of all new allocations initialized to bit pattern `0xDCDCDCDC`. @@ -1313,7 +1440,7 @@ number of bytes as a margin before and after every allocation. \code #define VMA_DEBUG_MARGIN 16 -#include vk_mem_alloc.h +#include \endcode ![Allocations with margin](../gfx/Margins_2.png) @@ -1347,7 +1474,7 @@ of contents of the margins. \code #define VMA_DEBUG_MARGIN 16 #define VMA_DEBUG_DETECT_CORRUPTION 1 -#include vk_mem_alloc.h +#include \endcode When this feature is enabled, number of bytes specified as `VMA_DEBUG_MARGIN` @@ -1639,7 +1766,7 @@ VmaAllocatorCreateInfo::pDeviceMemoryCallbacks. When device memory of certain heap runs out of free space, new allocations may fail (returning error code) or they may succeed, silently pushing some existing memory blocks from GPU VRAM to system RAM (which degrades performance). This -behavior is implementation-dependant - it depends on GPU vendor and graphics +behavior is implementation-dependent - it depends on GPU vendor and graphics driver. On AMD cards it can be controlled while creating Vulkan device object by using @@ -1871,6 +1998,8 @@ Features deliberately excluded from the scope of this library: explicit memory type index and dedicated allocation anyway, so they don't interact with main features of this library. Such special purpose allocations should be made manually, using `vkCreateBuffer()` and `vkAllocateMemory()`. +- Sub-allocation of parts of one large buffer. Although recommended as a good practice, + it is the user's responsibility to implement such logic on top of VMA. - Recreation of buffers and images. Although the library has functions for buffer and image creation (vmaCreateBuffer(), vmaCreateImage()), you need to recreate these objects yourself after defragmentation. That's because the big @@ -1890,16 +2019,6 @@ Features deliberately excluded from the scope of this library: */ -#if VMA_RECORDING_ENABLED - #include - #if defined(_WIN32) - #include - #else - #include - #include - #endif -#endif - #ifdef __cplusplus extern "C" { #endif @@ -1912,7 +2031,7 @@ available through VmaAllocatorCreateInfo::pRecordSettings. #define VMA_RECORDING_ENABLED 0 #endif -#ifndef NOMINMAX +#if !defined(NOMINMAX) && defined(VMA_IMPLEMENTATION) #define NOMINMAX // For windows.h #endif @@ -1997,7 +2116,7 @@ available through VmaAllocatorCreateInfo::pRecordSettings. // Define these macros to decorate all public functions with additional code, // before and after returned type, appropriately. This may be useful for -// exporing the functions when compiling VMA as a separate library. Example: +// exporting the functions when compiling VMA as a separate library. Example: // #define VMA_CALL_PRE __declspec(dllexport) // #define VMA_CALL_POST __cdecl #ifndef VMA_CALL_PRE @@ -2188,10 +2307,10 @@ typedef enum VmaAllocatorCreateFlagBits { 1. (For Vulkan version < 1.2) Found as available and enabled device extension VK_KHR_buffer_device_address. This extension is promoted to core Vulkan 1.2. - 2. Found as available and enabled device feature `VkPhysicalDeviceBufferDeviceAddressFeatures*::bufferDeviceAddress`. + 2. Found as available and enabled device feature `VkPhysicalDeviceBufferDeviceAddressFeatures::bufferDeviceAddress`. - When this flag is set, you can create buffers with `VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT*` using VMA. - The library automatically adds `VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT*` to + When this flag is set, you can create buffers with `VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT` using VMA. + The library automatically adds `VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT` to allocated memory blocks wherever it might be needed. For more information, see documentation chapter \ref enabling_buffer_device_address. @@ -2347,7 +2466,7 @@ typedef struct VmaAllocatorCreateInfo It must be a value in the format as created by macro `VK_MAKE_VERSION` or a constant like: `VK_API_VERSION_1_1`, `VK_API_VERSION_1_0`. The patch version number specified is ignored. Only the major and minor versions are considered. It must be less or equal (preferably equal) to value as passed to `vkCreateInstance` as `VkApplicationInfo::apiVersion`. - Only versions 1.0 and 1.1 are supported by the current implementation. + Only versions 1.0, 1.1, 1.2 are supported by the current implementation. Leaving it initialized to zero is equivalent to `VK_API_VERSION_1_0`. */ uint32_t vulkanApiVersion; @@ -2644,7 +2763,7 @@ typedef enum VmaAllocationCreateFlagBits { Pointer to mapped memory will be returned through VmaAllocationInfo::pMappedData. - Is it valid to use this flag for allocation made from memory type that is not + It is valid to use this flag for allocation made from memory type that is not `HOST_VISIBLE`. This flag is then ignored and memory is not mapped. This is useful if you need an allocation that is efficient to use on GPU (`DEVICE_LOCAL`) and still want to map it directly if possible on platforms that @@ -2749,7 +2868,7 @@ typedef struct VmaAllocationCreateInfo VkMemoryPropertyFlags requiredFlags; /** \brief Flags that preferably should be set in a memory type chosen for an allocation. - Set to 0 if no additional flags are prefered. \n + Set to 0 if no additional flags are preferred. \n If `pool` is not null, this member is ignored. */ VkMemoryPropertyFlags preferredFlags; /** \brief Bitmask containing one bit set for every memory type acceptable for this allocation. @@ -3085,7 +3204,12 @@ typedef struct VmaAllocationInfo { If the allocation is lost, it is equal to `VK_NULL_HANDLE`. */ VkDeviceMemory VMA_NULLABLE_NON_DISPATCHABLE deviceMemory; - /** \brief Offset into deviceMemory object to the beginning of this allocation, in bytes. (deviceMemory, offset) pair is unique to this allocation. + /** \brief Offset in `VkDeviceMemory` object to the beginning of this allocation, in bytes. `(deviceMemory, offset)` pair is unique to this allocation. + + You usually don't need to use this offset. If you create a buffer or an image together with the allocation using e.g. function + vmaCreateBuffer(), vmaCreateImage(), functions that operate on these resources refer to the beginning of the buffer or image, + not entire device memory block. Functions like vmaMapMemory(), vmaBindBufferMemory() also refer to the beginning of the allocation + and apply this offset automatically. It can change after call to vmaDefragment() if this allocation is passed to the function, or if allocation is lost. */ @@ -3219,7 +3343,7 @@ VMA_CALL_PRE VkResult VMA_CALL_POST vmaResizeAllocation( /** \brief Returns current information about specified allocation and atomically marks it as used in current frame. -Current paramters of given allocation are returned in `pAllocationInfo`. +Current paramteres of given allocation are returned in `pAllocationInfo`. This function also atomically "touches" allocation - marks it as used in current frame, just like vmaTouchAllocation(). @@ -3719,7 +3843,7 @@ VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindBufferMemory( This function is similar to vmaBindBufferMemory(), but it provides additional parameters. If `pNext` is not null, #VmaAllocator object must have been created with #VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT flag -or with VmaAllocatorCreateInfo::vulkanApiVersion `== VK_API_VERSION_1_1`. Otherwise the call fails. +or with VmaAllocatorCreateInfo::vulkanApiVersion `>= VK_API_VERSION_1_1`. Otherwise the call fails. */ VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindBufferMemory2( VmaAllocator VMA_NOT_NULL allocator, @@ -3753,7 +3877,7 @@ VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindImageMemory( This function is similar to vmaBindImageMemory(), but it provides additional parameters. If `pNext` is not null, #VmaAllocator object must have been created with #VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT flag -or with VmaAllocatorCreateInfo::vulkanApiVersion `== VK_API_VERSION_1_1`. Otherwise the call fails. +or with VmaAllocatorCreateInfo::vulkanApiVersion `>= VK_API_VERSION_1_1`. Otherwise the call fails. */ VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindImageMemory2( VmaAllocator VMA_NOT_NULL allocator, @@ -3780,13 +3904,17 @@ If the function succeeded, you must destroy both buffer and allocation when you no longer need them using either convenience function vmaDestroyBuffer() or separately, using `vkDestroyBuffer()` and vmaFreeMemory(). -If VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT flag was used, +If #VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT flag was used, VK_KHR_dedicated_allocation extension is used internally to query driver whether it requires or prefers the new buffer to have dedicated allocation. If yes, and if dedicated allocation is possible (VmaAllocationCreateInfo::pool is null -and VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT is not used), it creates dedicated +and #VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT is not used), it creates dedicated allocation for this buffer, just like when using -VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. +#VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. + +\note This function creates a new `VkBuffer`. Sub-allocation of parts of one large buffer, +although recommended as a good practice, is out of scope of this library and could be implemented +by the user as a higher-level logic on top of VMA. */ VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateBuffer( VmaAllocator VMA_NOT_NULL allocator, @@ -3856,6 +3984,16 @@ VMA_CALL_PRE void VMA_CALL_POST vmaDestroyImage( #include #include +#if VMA_RECORDING_ENABLED + #include + #if defined(_WIN32) + #include + #else + #include + #include + #endif +#endif + /******************************************************************************* CONFIGURATION SECTION @@ -3881,6 +4019,10 @@ internally, like: */ #if !defined(VMA_DYNAMIC_VULKAN_FUNCTIONS) #define VMA_DYNAMIC_VULKAN_FUNCTIONS 1 + #if defined(VK_NO_PROTOTYPES) + extern PFN_vkGetInstanceProcAddr vkGetInstanceProcAddr; + extern PFN_vkGetDeviceProcAddr vkGetDeviceProcAddr; + #endif #endif // Define this macro to 1 to make the library use STL containers instead of its own implementation. @@ -3943,7 +4085,7 @@ remove them if not needed. #if defined(__ANDROID_API__) && (__ANDROID_API__ < 16) #include -void *vma_aligned_alloc(size_t alignment, size_t size) +static void* vma_aligned_alloc(size_t alignment, size_t size) { // alignment must be >= sizeof(void*) if(alignment < sizeof(void*)) @@ -3960,7 +4102,7 @@ void *vma_aligned_alloc(size_t alignment, size_t size) #include #endif -void *vma_aligned_alloc(size_t alignment, size_t size) +static void* vma_aligned_alloc(size_t alignment, size_t size) { #if defined(__APPLE__) && (defined(MAC_OS_X_VERSION_10_16) || defined(__IPHONE_14_0)) #if MAC_OS_X_VERSION_MAX_ALLOWED >= MAC_OS_X_VERSION_10_16 || __IPHONE_OS_VERSION_MAX_ALLOWED >= __IPHONE_14_0 @@ -3986,17 +4128,29 @@ void *vma_aligned_alloc(size_t alignment, size_t size) return VMA_NULL; } #elif defined(_WIN32) -void *vma_aligned_alloc(size_t alignment, size_t size) +static void* vma_aligned_alloc(size_t alignment, size_t size) { return _aligned_malloc(size, alignment); } #else -void *vma_aligned_alloc(size_t alignment, size_t size) +static void* vma_aligned_alloc(size_t alignment, size_t size) { return aligned_alloc(alignment, size); } #endif +#if defined(_WIN32) +static void vma_aligned_free(void* ptr) +{ + _aligned_free(ptr); +} +#else +static void vma_aligned_free(void* ptr) +{ + free(ptr); +} +#endif + // If your compiler is not compatible with C++11 and definition of // aligned_alloc() function is missing, uncommeting following line may help: @@ -4029,12 +4183,13 @@ void *vma_aligned_alloc(size_t alignment, size_t size) #define VMA_SYSTEM_ALIGNED_MALLOC(size, alignment) vma_aligned_alloc((alignment), (size)) #endif -#ifndef VMA_SYSTEM_FREE - #if defined(_WIN32) - #define VMA_SYSTEM_FREE(ptr) _aligned_free(ptr) +#ifndef VMA_SYSTEM_ALIGNED_FREE + // VMA_SYSTEM_FREE is the old name, but might have been defined by the user + #if defined(VMA_SYSTEM_FREE) + #define VMA_SYSTEM_ALIGNED_FREE(ptr) VMA_SYSTEM_FREE(ptr) #else - #define VMA_SYSTEM_FREE(ptr) free(ptr) - #endif + #define VMA_SYSTEM_ALIGNED_FREE(ptr) vma_aligned_free(ptr) + #endif #endif #ifndef VMA_MIN @@ -4594,7 +4749,7 @@ static IterT VmaBinaryFindFirstNotLess(IterT beg, IterT end, const KeyT &key, co size_t down = 0, up = (end - beg); while(down < up) { - const size_t mid = (down + up) / 2; + const size_t mid = down + (up - down) / 2; // Overflow-safe midpoint calculation if(cmp(*(beg+mid), key)) { down = mid + 1; @@ -4685,7 +4840,7 @@ static void VmaFree(const VkAllocationCallbacks* pAllocationCallbacks, void* ptr } else { - VMA_SYSTEM_FREE(ptr); + VMA_SYSTEM_ALIGNED_FREE(ptr); } } @@ -13649,7 +13804,7 @@ uint32_t VmaBlockVector::ProcessDefragmentations( { VmaMutexLockWrite lock(m_Mutex, m_hAllocator->m_UseMutex); - const uint32_t moveCount = std::min(uint32_t(pCtx->defragmentationMoves.size()) - pCtx->defragmentationMovesProcessed, maxMoves); + const uint32_t moveCount = VMA_MIN(uint32_t(pCtx->defragmentationMoves.size()) - pCtx->defragmentationMovesProcessed, maxMoves); for(uint32_t i = 0; i < moveCount; ++ i) { diff --git a/aten/src/ATen/native/vulkan/glsl/KO4C4HW_to_image.glsl b/aten/src/ATen/native/vulkan/glsl/KO4C4HW_to_image.glsl index b5aa038098c6e..2c02e034603ef 100644 --- a/aten/src/ATen/native/vulkan/glsl/KO4C4HW_to_image.glsl +++ b/aten/src/ATen/native/vulkan/glsl/KO4C4HW_to_image.glsl @@ -1,7 +1,6 @@ #version 450 core #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; layout(set = 0, binding = 1) readonly buffer kernel { vec4 data[]; @@ -13,7 +12,7 @@ layout(set = 0, binding = 2) uniform constBlock { } uConstBlock; -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { ivec3 pos = ivec3(gl_GlobalInvocationID) * ivec3(4, 1, 1); diff --git a/aten/src/ATen/native/vulkan/glsl/adaptive_avg_pool2d.glsl b/aten/src/ATen/native/vulkan/glsl/adaptive_avg_pool2d.glsl index ab07da5e4897b..b751d057d4ad8 100644 --- a/aten/src/ATen/native/vulkan/glsl/adaptive_avg_pool2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/adaptive_avg_pool2d.glsl @@ -1,42 +1,41 @@ #version 450 core #define PRECISION $precision + layout(std430) buffer; -layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; -layout(set = 0, binding = 2) uniform constBlock { - int IW; - int IH; - int OW; - int OH; -} -uConstBlock; -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; + vec2 kernel; + vec2 stride; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - int ow = uConstBlock.OW; - int oh = uConstBlock.OH; - if (pos.x < ow && pos.y < oh) { - int iw = uConstBlock.IW; - int ih = uConstBlock.IH; - - int sx = int(floor(float(pos.x * iw) / ow)); - int sy = int(floor(float(pos.y * ih) / oh)); - int ex = int(ceil(float((pos.x + 1) * iw) / ow)); - int ey = int(ceil(float((pos.y + 1) * ih) / oh)); - - vec4 r = vec4(1.0) / float(ex - sx) / float(ey - sy); - vec4 acc = vec4(0); - - int xi, yi; - for (xi = sx; xi < ex; ++xi) { - for (yi = sy; yi < ey; ++yi) { - acc += texelFetch(uInput, ivec3(xi, yi, pos.z), 0); + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const vec2 ipos = pos.xy * uBlock.stride; + + const ivec2 start = ivec2(ipos); + const ivec2 end = ivec2(ceil(ipos + uBlock.kernel)); + const ivec2 range = end - start; + + vec4 sum = vec4(0); + + for (int y = start.y; y < end.y; ++y) { + for (int x = start.x; x < end.x; ++x) { + sum += texelFetch(uInput, ivec3(x, y, pos.z), 0); } } - imageStore(uOutput, pos, r * acc); + imageStore( + uOutput, + pos, + sum / (range.x * range.y)); } } diff --git a/aten/src/ATen/native/vulkan/glsl/add.glsl b/aten/src/ATen/native/vulkan/glsl/add.glsl index 9b7e992e78c54..361927373a49e 100644 --- a/aten/src/ATen/native/vulkan/glsl/add.glsl +++ b/aten/src/ATen/native/vulkan/glsl/add.glsl @@ -1,27 +1,27 @@ #version 450 core #define PRECISION $precision + layout(std430) buffer; -layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput0; -layout(set = 0, binding = 2) uniform PRECISION sampler3D uInput1; -layout(set = 0, binding = 3) uniform constBlock { - int W; - int H; - int C; +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput0; +layout(set = 0, binding = 2) uniform PRECISION sampler3D uInput1; +layout(set = 0, binding = 3) uniform PRECISION restrict Block { + ivec3 size; float alpha; -} -uConstBlock; +} uBlock; -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - ivec3 WHC = ivec3(uConstBlock.W, uConstBlock.H, uConstBlock.C); - if (all(lessThan(pos, WHC))) { - vec4 v = texelFetch(uInput0, pos, 0) + - uConstBlock.alpha * texelFetch(uInput1, pos, 0); - imageStore(uOutput, pos, v); + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size))) { + imageStore( + uOutput, + pos, + texelFetch(uInput0, pos, 0) + uBlock.alpha * texelFetch(uInput1, pos, 0)); } } diff --git a/aten/src/ATen/native/vulkan/glsl/add_.glsl b/aten/src/ATen/native/vulkan/glsl/add_.glsl new file mode 100644 index 0000000000000..d6360a376c58b --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/add_.glsl @@ -0,0 +1,26 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput0; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec3 size; + float alpha; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size))) { + imageStore( + uOutput, + pos, + imageLoad(uOutput, pos) + uBlock.alpha * texelFetch(uInput0, pos, 0)); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/add_scalar.glsl b/aten/src/ATen/native/vulkan/glsl/add_scalar.glsl index 559cdd7441c33..735086a8150a4 100644 --- a/aten/src/ATen/native/vulkan/glsl/add_scalar.glsl +++ b/aten/src/ATen/native/vulkan/glsl/add_scalar.glsl @@ -1,21 +1,26 @@ #version 450 core +#define PRECISION $precision + layout(std430) buffer; -layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly highp uniform image3D uOutput; -layout(set = 0, binding = 1) uniform highp sampler3D uInput; -layout(set = 0, binding = 2) uniform constBlock { - ivec4 sizes; +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec3 size; float other; -} -uConstBlock; +} uBlock; -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - if (all(lessThan(pos, uConstBlock.sizes.xyz))) { - vec4 v = texelFetch(uInput, pos, 0) + uConstBlock.other; - imageStore(uOutput, pos, v); + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size))) { + imageStore( + uOutput, + pos, + texelFetch(uInput, pos, 0) + uBlock.other); } } diff --git a/aten/src/ATen/native/vulkan/glsl/add_scalar_.glsl b/aten/src/ATen/native/vulkan/glsl/add_scalar_.glsl new file mode 100644 index 0000000000000..a418a28bb5c39 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/add_scalar_.glsl @@ -0,0 +1,25 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION restrict Block { + ivec3 size; + float other; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size))) { + imageStore( + uOutput, + pos, + imageLoad(uOutput, pos) + uBlock.other); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/addmm.glsl b/aten/src/ATen/native/vulkan/glsl/addmm.glsl index 79987990e5954..a8f09252a167c 100644 --- a/aten/src/ATen/native/vulkan/glsl/addmm.glsl +++ b/aten/src/ATen/native/vulkan/glsl/addmm.glsl @@ -1,33 +1,37 @@ #version 450 core #define PRECISION $precision + layout(std430) buffer; -layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1; -layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2; -layout(set = 0, binding = 3) uniform constBlock { - ivec4 outputSize; - float beta; - float alpha; - int K; -} -uConstBlock; -layout(set = 0, binding = 4) uniform PRECISION sampler3D uT; -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1; +layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2; +layout(set = 0, binding = 3) uniform PRECISION sampler3D uT; +layout(set = 0, binding = 4) uniform PRECISION restrict Block { + ivec4 size; + vec2 multiplier; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - if (all(lessThan(pos, uConstBlock.outputSize.xyz))) { - int K = uConstBlock.K; - vec4 mmv = vec4(0); - int ki = 0; - for (; ki < K; ++ki) { - vec4 m1ki = texelFetch(uM1, ivec3(ki, pos.y, pos.z), 0); - vec4 m2ki = texelFetch(uM2, ivec3(pos.x, ki, pos.z), 0); - mmv += m1ki * m2ki; + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + vec4 sum = vec4(0); + + for (int k = 0; k < uBlock.size.w; ++k) { + sum = fma( + texelFetch(uM1, ivec3(k, pos.y, pos.z), 0), + texelFetch(uM2, ivec3(pos.x, k, pos.z), 0), + sum); } - vec4 tv = texelFetch(uT, pos, 0); - imageStore(uOutput, pos, uConstBlock.beta * tv + uConstBlock.alpha * mmv); + + imageStore( + uOutput, + pos, + uBlock.multiplier.x * sum + uBlock.multiplier.y * texelFetch(uT, pos, 0)); } } diff --git a/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl b/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl index 552e75c11d596..fe1f2780ac3a9 100644 --- a/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl @@ -1,42 +1,41 @@ #version 450 core +#define PRECISION $precision + layout(std430) buffer; -layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly highp uniform image3D uOutput; -layout(set = 0, binding = 1) uniform highp sampler3D uInput; -layout(set = 0, binding = 2) uniform constBlock { - ivec4 inputSize; - ivec4 outputSize; - ivec2 kernelSize; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; + ivec4 kernel; ivec2 stride; ivec2 padding; - ivec2 dilate; -} -uConstBlock; - -#define UP_DIV(x, y) (((x) + (y)-1) / (y)) +} uBlock; -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - ivec3 outputSize = uConstBlock.outputSize.xyz; - if (all(lessThan(pos, outputSize))) { - ivec2 s0 = pos.xy * uConstBlock.stride - uConstBlock.padding; - ivec2 sfxy = max(ivec2(0), (UP_DIV(-s0, uConstBlock.dilate))); - ivec2 efxy = - min(uConstBlock.kernelSize, - UP_DIV(uConstBlock.inputSize.xy - s0, uConstBlock.dilate)); - - vec4 r = vec4(1.0) / float(efxy.x - sfxy.x) / float(efxy.x - sfxy.x); - vec4 acc = vec4(0); - - for (int kyi = sfxy.y; kyi < efxy.y; ++kyi) { - for (int kxi = sfxy.x; kxi < efxy.x; ++kxi) { - ivec2 ixy = s0 + ivec2(kxi, kyi); - acc += texelFetch(uInput, ivec3(ixy.x, ixy.y, pos.z), 0); + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding; + + const ivec2 start = max(ivec2(0), ipos); + const ivec2 end = min(ipos + uBlock.kernel.xy, uBlock.kernel.zw); + + vec4 sum = vec4(0); + + for (int y = start.y; y < end.y; ++y) { + for (int x = start.x; x < end.x; ++x) { + sum += texelFetch(uInput, ivec3(x, y, pos.z), 0); } } - imageStore(uOutput, pos, r * acc); + imageStore( + uOutput, + pos, + sum / uBlock.size.w); } } diff --git a/aten/src/ATen/native/vulkan/glsl/clamp.glsl b/aten/src/ATen/native/vulkan/glsl/clamp.glsl index 24104c2285a1d..52c2d2d96c268 100644 --- a/aten/src/ATen/native/vulkan/glsl/clamp.glsl +++ b/aten/src/ATen/native/vulkan/glsl/clamp.glsl @@ -1,23 +1,26 @@ #version 450 core #define PRECISION $precision + layout(std430) buffer; -layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; -layout(set = 0, binding = 2) uniform constBlock { + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { ivec4 size; - float minValue; - float maxValue; -} -uConstBlock; + vec2 clamp; +} uBlock; -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - if (all(lessThan(pos, uConstBlock.size.xyz))) { - vec4 v = texelFetch(uInput, pos, 0); + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { imageStore( - uOutput, pos, clamp(v, uConstBlock.minValue, uConstBlock.maxValue)); + uOutput, + pos, + clamp(texelFetch(uInput, pos, 0), uBlock.clamp.x, uBlock.clamp.y)); } } diff --git a/aten/src/ATen/native/vulkan/glsl/clamp_.glsl b/aten/src/ATen/native/vulkan/glsl/clamp_.glsl new file mode 100644 index 0000000000000..3f138bb93ec6b --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/clamp_.glsl @@ -0,0 +1,25 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION restrict Block { + ivec4 size; + vec2 clamp; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + imageStore( + uOutput, + pos, + clamp(imageLoad(uOutput, pos), uBlock.clamp.x, uBlock.clamp.y)); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d.glsl new file mode 100644 index 0000000000000..4403d34b88ca7 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/conv2d.glsl @@ -0,0 +1,60 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION sampler3D uKernel; +layout(set = 0, binding = 3) buffer PRECISION restrict readonly Bias { + vec4 data[]; +} uBias; +layout(set = 0, binding = 4) uniform PRECISION restrict Block { + ivec4 size; + ivec4 kernel; + ivec2 ikernel; + ivec2 stride; + ivec2 padding; + ivec2 dilate; + vec2 clamp; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding; + + const ivec2 start = max(ivec2(0), ipos); + const ivec2 end = min(ipos + uBlock.kernel.xy, uBlock.kernel.zw); + ivec2 kstart = (start - ipos) / uBlock.dilate; + + kstart.x *= 4; + kstart.y += pos.z * uBlock.ikernel.y; + + vec4 sum = uBias.data[pos.z]; + + for (int z4 = 0; z4 < uBlock.size.w; ++z4, kstart.x += uBlock.ikernel.x) { + for (int y = start.y, ky = kstart.y; y < end.y; y += uBlock.dilate.y, ++ky) { + for (int x = start.x, kx = kstart.x; x < end.x; x += uBlock.dilate.x, kx += 4) { + const vec4 In = texelFetch(uInput, ivec3(x, y, z4), 0); + const ivec4 kxs = kx + ivec4(0, 1, 2, 3); + + sum = fma(In.xxxx, texelFetch(uKernel, ivec3(kxs.x, ky, 0), 0), sum); + sum = fma(In.yyyy, texelFetch(uKernel, ivec3(kxs.y, ky, 0), 0), sum); + sum = fma(In.zzzz, texelFetch(uKernel, ivec3(kxs.z, ky, 0), 0), sum); + sum = fma(In.wwww, texelFetch(uKernel, ivec3(kxs.w, ky, 0), 0), sum); + } + } + } + + imageStore( + uOutput, + pos, + clamp(sum, uBlock.clamp.x, uBlock.clamp.y)); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_dw.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_dw.glsl new file mode 100644 index 0000000000000..2ac7d4c49838d --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/conv2d_dw.glsl @@ -0,0 +1,51 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION sampler3D uKernel; +layout(set = 0, binding = 3) buffer PRECISION restrict readonly Bias { + vec4 data[]; +} uBias; +layout(set = 0, binding = 4) uniform PRECISION restrict Block { + ivec4 size; + ivec4 kernel; + ivec2 stride; + ivec2 padding; + ivec2 dilate; + vec2 clamp; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding; + + const ivec2 start = max(ivec2(0), ipos); + const ivec2 end = min(ipos + uBlock.kernel.xy, uBlock.kernel.zw); + const ivec2 kstart = (start - ipos) / uBlock.dilate; + + vec4 sum = uBias.data[pos.z]; + + for (int y = start.y, ky = kstart.y; y < end.y; y += uBlock.dilate.y, ++ky) { + for (int x = start.x, kx = kstart.x + ky * uBlock.size.w; x < end.x; x += uBlock.dilate.x, ++kx) { + sum = fma( + texelFetch(uInput, ivec3(x, y, pos.z), 0), + texelFetch(uKernel, ivec3(kx, pos.z, 0), 0), + sum); + } + } + + imageStore( + uOutput, + pos, + clamp(sum, uBlock.clamp.x, uBlock.clamp.y)); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_dw_clamp.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_dw_clamp.glsl index 6c5ac25460574..5155c07669c1c 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d_dw_clamp.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d_dw_clamp.glsl @@ -1,7 +1,6 @@ #version 450 core #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; layout(set = 0, binding = 2) uniform PRECISION sampler3D uKernel; @@ -23,7 +22,7 @@ uConstBlock; #define UP_DIV(x, y) (((x) + (y)-1) / (y)) -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { ivec3 pos = ivec3(gl_GlobalInvocationID); diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp.glsl index 1b4f3a6daa368..89411284fed43 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp.glsl @@ -1,7 +1,6 @@ #version 450 core #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; layout(set = 0, binding = 2) uniform PRECISION sampler3D uKernel; @@ -23,7 +22,7 @@ uConstBlock; #define UP_DIV(x, y) (((x) + (y)-1) / (y)) -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { ivec3 gpos = ivec3(gl_GlobalInvocationID); diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp_1x.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp_1x.glsl new file mode 100644 index 0000000000000..8baae9b5fcd5f --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp_1x.glsl @@ -0,0 +1,65 @@ +#version 450 core +#define PRECISION $precision +layout(std430) buffer; +layout(set = 0, rgba32f, binding = 0) writeonly PRECISION uniform image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION sampler3D uKernel; +layout(set = 0, binding = 3) readonly buffer bias { + vec4 data[]; +} +uBias; +layout(set = 0, binding = 4) uniform constBlock { + ivec2 padding; + ivec2 kernelSize; + ivec2 stride; + ivec2 dilate; + ivec4 outputSize; + ivec4 inputSize; + float outputMin; + float outputMax; +} +uConstBlock; + +#define UP_DIV(x, y) (((x) + (y)-1) / (y)) + +layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; + +void main() { + ivec3 pos = ivec3(gl_GlobalInvocationID); + if (all(lessThan(pos, uConstBlock.outputSize.xyz))) { + int kernelX = uConstBlock.kernelSize.x; + int kernelY = uConstBlock.kernelSize.y; + ivec3 inputSize = uConstBlock.inputSize.xyz; + ivec2 s0 = pos.xy * uConstBlock.stride - uConstBlock.padding; + int fx, fy, fz; + ivec2 sfxy = max(ivec2(0), (UP_DIV(-s0, uConstBlock.dilate))); + ivec2 efxy = + min(uConstBlock.kernelSize, + UP_DIV(uConstBlock.inputSize.xy - s0, uConstBlock.dilate)); + vec4 color = uBias.data[pos.z]; + int kY = pos.z; + int strideX = uConstBlock.stride.x; + for (fy = sfxy.y; fy < efxy.y; ++fy) { + int sy = fy * uConstBlock.dilate.y + s0.y; + for (fx = 0; fx < kernelX; ++fx) { + int kZ = fx + fy * kernelX; + int sx = fx * uConstBlock.dilate.x + s0.x; + fz = 0; + for (; fz < inputSize.z; ++fz) { + int kX = 4 * fz; + vec4 k0 = texelFetch(uKernel, ivec3(kX + 0, kY, kZ), 0); + vec4 k1 = texelFetch(uKernel, ivec3(kX + 1, kY, kZ), 0); + vec4 k2 = texelFetch(uKernel, ivec3(kX + 2, kY, kZ), 0); + vec4 k3 = texelFetch(uKernel, ivec3(kX + 3, kY, kZ), 0); + + mat4 k = mat4(k0, k1, k2, k3); + + color += k * texelFetch(uInput, ivec3(sx, sy, fz), 0); + } + } + } + vec4 outputMin = vec4(uConstBlock.outputMin); + vec4 outputMax = vec4(uConstBlock.outputMax); + imageStore(uOutput, ivec3(pos.x, pos.y, pos.z), clamp(color, outputMin, outputMax)); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl new file mode 100644 index 0000000000000..cd6f60e9641f6 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl @@ -0,0 +1,46 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION sampler3D uKernel; +layout(set = 0, binding = 3) buffer PRECISION restrict readonly Bias { + vec4 data[]; +} uBias; +layout(set = 0, binding = 4) uniform PRECISION restrict Block { + ivec4 size; + ivec2 stride; + ivec2 padding; + vec2 clamp; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding; + + vec4 sum = uBias.data[pos.z]; + + for (int z = 0, z4 = 0; z < uBlock.size.w; z += 4, ++z4) { + const vec4 In = texelFetch(uInput, ivec3(ipos, z4), 0); + const ivec4 kxs = z + ivec4(0, 1, 2, 3); + + sum = fma(In.xxxx, texelFetch(uKernel, ivec3(kxs.x, pos.z, 0), 0), sum); + sum = fma(In.yyyy, texelFetch(uKernel, ivec3(kxs.y, pos.z, 0), 0), sum); + sum = fma(In.zzzz, texelFetch(uKernel, ivec3(kxs.z, pos.z, 0), 0), sum); + sum = fma(In.wwww, texelFetch(uKernel, ivec3(kxs.w, pos.z, 0), 0), sum); + } + + imageStore( + uOutput, + pos, + clamp(sum, uBlock.clamp.x, uBlock.clamp.y)); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/image_to_nchw.glsl b/aten/src/ATen/native/vulkan/glsl/image_to_nchw.glsl index 1f35f334c997f..01d653bf06de1 100644 --- a/aten/src/ATen/native/vulkan/glsl/image_to_nchw.glsl +++ b/aten/src/ATen/native/vulkan/glsl/image_to_nchw.glsl @@ -1,31 +1,33 @@ #version 450 core #define PRECISION $precision + layout(std430) buffer; -layout(std430) uniform; -layout(set = 0, binding = 0) uniform PRECISION sampler3D uInput; -layout(set = 0, binding = 1) writeonly buffer destBuffer { + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0) uniform PRECISION sampler3D uImage; +layout(set = 0, binding = 1) buffer PRECISION restrict writeonly Buffer { float data[]; -} -uOutBuffer; -layout(set = 0, binding = 2) uniform sizeBlock { - int width; - int height; -} -uSizeBlock; +} uBuffer; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; + ivec4 offset; +} uBlock; -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - int W = uSizeBlock.width; - int H = uSizeBlock.height; - int WH = W * H; - if (pos.x < W && pos.y < H) { - vec4 color = texelFetch(uInput, pos, 0); - int z = pos.z * 4; - uOutBuffer.data[W * pos.y + pos.x + (z + 0) * WH] = color.r; - uOutBuffer.data[W * pos.y + pos.x + (z + 1) * WH] = color.g; - uOutBuffer.data[W * pos.y + pos.x + (z + 2) * WH] = color.b; - uOutBuffer.data[W * pos.y + pos.x + (z + 3) * WH] = color.a; + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const vec4 texel = texelFetch(uImage, pos, 0); + + const int base = pos.x + uBlock.size.x * pos.y + uBlock.size.w * pos.z; + const ivec4 index = base + uBlock.offset; + + uBuffer.data[index.x] = texel.r; + uBuffer.data[index.y] = texel.g; + uBuffer.data[index.z] = texel.b; + uBuffer.data[index.w] = texel.a; } } diff --git a/aten/src/ATen/native/vulkan/glsl/max_pool2d.glsl b/aten/src/ATen/native/vulkan/glsl/max_pool2d.glsl index 458e4b68f551b..88373605d010a 100644 --- a/aten/src/ATen/native/vulkan/glsl/max_pool2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/max_pool2d.glsl @@ -1,7 +1,6 @@ #version 450 core #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; layout(set = 0, binding = 2) uniform constBlock { @@ -17,7 +16,7 @@ uConstBlock; #define UP_DIV(x, y) (((x) + (y)-1) / (y)) #define FLT_MAX 3.402823466e+38 -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { ivec3 pos = ivec3(gl_GlobalInvocationID); diff --git a/aten/src/ATen/native/vulkan/glsl/mean.glsl b/aten/src/ATen/native/vulkan/glsl/mean.glsl index a602d7a2e9776..551fd747f103d 100644 --- a/aten/src/ATen/native/vulkan/glsl/mean.glsl +++ b/aten/src/ATen/native/vulkan/glsl/mean.glsl @@ -1,37 +1,36 @@ #version 450 core #define PRECISION $precision + layout(std430) buffer; -layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; -layout(set = 0, binding = 2) uniform constBlock { - int W; - int H; - int OW; - int OH; -} -uConstBlock; -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; + ivec2 isize; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// This implementation is suboptimal and should be revisted. void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - int W = uConstBlock.W; - int H = uConstBlock.H; - int OW = uConstBlock.OW; - int OH = uConstBlock.OH; - vec4 r = vec4(1.0) / float(W) / float(H); - vec4 acc = vec4(0); - int xi, yi; - for (xi = 0; xi < W; ++xi) { - for (yi = 0; yi < H; ++yi) { - acc += texelFetch(uInput, ivec3(xi, yi, pos.z), 0); + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + vec4 sum = vec4(0); + + for (int y = 0; y < uBlock.isize.y; ++y) { + for (int x = 0; x < uBlock.isize.x; ++x) { + sum += texelFetch(uInput, ivec3(x, y, pos.z), 0); + } } - } - vec4 outValue = r * acc; - for (int vi = 0; vi < 4; ++vi) { - int oy = (4 * pos.z + vi) / OW; - int ox = (4 * pos.z + vi) % OW; - imageStore(uOutput, ivec3(ox, oy, 0), vec4(outValue[vi], 0, 0, 0)); + + imageStore( + uOutput, + pos, + sum / uBlock.size.w); } } diff --git a/aten/src/ATen/native/vulkan/glsl/mean2d.glsl b/aten/src/ATen/native/vulkan/glsl/mean2d.glsl new file mode 100644 index 0000000000000..b8d0add329f25 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/mean2d.glsl @@ -0,0 +1,40 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; + ivec2 isize; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// This implementation is suboptimal and should be revisted. + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + vec4 sum = vec4(0); + + const int z = pos.x + uBlock.size.x * pos.y; + const int zi = z / 4; + const int zo = z % 4; + + for (int y = 0; y < uBlock.isize.y; ++y) { + for (int x = 0; x < uBlock.isize.x; ++x) { + sum += texelFetch(uInput, ivec3(x, y, zi), 0); + } + } + + imageStore( + uOutput, + pos, + vec4(sum[zo], 0, 0, 0) / uBlock.size.w); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/mm.glsl b/aten/src/ATen/native/vulkan/glsl/mm.glsl index 771617d64b8ac..157acfe9c074b 100644 --- a/aten/src/ATen/native/vulkan/glsl/mm.glsl +++ b/aten/src/ATen/native/vulkan/glsl/mm.glsl @@ -1,31 +1,32 @@ #version 450 core #define PRECISION $precision + layout(std430) buffer; -layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1; -layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2; -layout(set = 0, binding = 3) uniform constBlock { - ivec4 outputSize; - float beta; - float alpha; - int K; -} -uConstBlock; -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1; +layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2; +layout(set = 0, binding = 3) uniform PRECISION restrict Block { + ivec4 size; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - if (all(lessThan(pos, uConstBlock.outputSize.xyz))) { - int K = uConstBlock.K; - vec4 mmv = vec4(0); - int ki = 0; - for (; ki < K; ++ki) { - vec4 m1ki = texelFetch(uM1, ivec3(ki, pos.y, pos.z), 0); - vec4 m2ki = texelFetch(uM2, ivec3(pos.x, ki, pos.z), 0); - mmv += m1ki * m2ki; + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + vec4 sum = vec4(0); + + for (int k = 0; k < uBlock.size.w; ++k) { + sum = fma( + texelFetch(uM1, ivec3(k, pos.y, pos.z), 0), + texelFetch(uM2, ivec3(pos.x, k, pos.z), 0), + sum); } - imageStore(uOutput, pos, uConstBlock.alpha * mmv); + + imageStore(uOutput, pos, sum); } } diff --git a/aten/src/ATen/native/vulkan/glsl/mul_scalar.glsl b/aten/src/ATen/native/vulkan/glsl/mul_scalar.glsl index d34a99d2c6e84..c0ae48fe3883f 100644 --- a/aten/src/ATen/native/vulkan/glsl/mul_scalar.glsl +++ b/aten/src/ATen/native/vulkan/glsl/mul_scalar.glsl @@ -1,21 +1,26 @@ #version 450 core +#define PRECISION $precision + layout(std430) buffer; -layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly highp uniform image3D uOutput; -layout(set = 0, binding = 1) uniform highp sampler3D uInput; -layout(set = 0, binding = 2) uniform constBlock { - ivec4 sizes; +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec3 size; float other; -} -uConstBlock; +} uBlock; -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - if (all(lessThan(pos, uConstBlock.sizes.xyz))) { - vec4 v = uConstBlock.other * texelFetch(uInput, pos, 0); - imageStore(uOutput, pos, v); + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size))) { + imageStore( + uOutput, + pos, + texelFetch(uInput, pos, 0) * uBlock.other); } } diff --git a/aten/src/ATen/native/vulkan/glsl/mul_scalar_.glsl b/aten/src/ATen/native/vulkan/glsl/mul_scalar_.glsl new file mode 100644 index 0000000000000..f959052879adc --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/mul_scalar_.glsl @@ -0,0 +1,25 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION restrict Block { + ivec3 size; + float other; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size))) { + imageStore( + uOutput, + pos, + imageLoad(uOutput, pos) * uBlock.other); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/nchw_to_image.glsl b/aten/src/ATen/native/vulkan/glsl/nchw_to_image.glsl index 84e2dcfb6504f..adbafcbd0438a 100644 --- a/aten/src/ATen/native/vulkan/glsl/nchw_to_image.glsl +++ b/aten/src/ATen/native/vulkan/glsl/nchw_to_image.glsl @@ -1,32 +1,35 @@ #version 450 core #define PRECISION $precision + layout(std430) buffer; -layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uImage; -layout(set = 0, binding = 1) readonly buffer destBuffer { + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uImage; +layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { float data[]; -} -uInBuffer; -layout(set = 0, binding = 2) uniform sizeBlock { - int width; - int height; -} -uSizeBlock; +} uBuffer; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; + ivec4 offset; +} uBlock; -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - int W = uSizeBlock.width; - int H = uSizeBlock.height; - if (pos.x < W && pos.y < H) { - vec4 color; - int z = pos.z * 4; - int WH = W * H; - color.r = uInBuffer.data[W * pos.y + pos.x + (z + 0) * WH]; - color.g = uInBuffer.data[W * pos.y + pos.x + (z + 1) * WH]; - color.b = uInBuffer.data[W * pos.y + pos.x + (z + 2) * WH]; - color.a = uInBuffer.data[W * pos.y + pos.x + (z + 3) * WH]; - imageStore(uImage, pos, color); + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const int base = pos.x + uBlock.size.x * pos.y + uBlock.size.w * pos.z; + const ivec4 index = base + uBlock.offset; + + imageStore( + uImage, + pos, + vec4( + uBuffer.data[index.x], + uBuffer.data[index.y], + uBuffer.data[index.z], + uBuffer.data[index.w])); } } diff --git a/aten/src/ATen/native/vulkan/glsl/permute.glsl b/aten/src/ATen/native/vulkan/glsl/permute.glsl index bd0b6637efae0..3d1191ff6eea4 100644 --- a/aten/src/ATen/native/vulkan/glsl/permute.glsl +++ b/aten/src/ATen/native/vulkan/glsl/permute.glsl @@ -1,6 +1,5 @@ #version 450 core layout(std430) buffer; -layout(std430) uniform; layout(set = 0, binding = 0) writeonly buffer outputBuffer { float data[]; } @@ -17,7 +16,7 @@ layout(set = 0, binding = 2) uniform constBlock { } uConst; -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { ivec3 pos = ivec3(gl_GlobalInvocationID); diff --git a/aten/src/ATen/native/vulkan/glsl/upsampleNearest2d.glsl b/aten/src/ATen/native/vulkan/glsl/upsampleNearest2d.glsl deleted file mode 100644 index d7e4619a283a9..0000000000000 --- a/aten/src/ATen/native/vulkan/glsl/upsampleNearest2d.glsl +++ /dev/null @@ -1,35 +0,0 @@ -#version 450 core -#define PRECISION $precision -layout(std430) buffer; -layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; -layout(set = 0, binding = 2) uniform constBlock { - int IW; - int IH; - int OW; - int OH; - float scaleX; - float scaleY; -} -uConstBlock; - -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; - -void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - int ow = uConstBlock.OW; - int oh = uConstBlock.OH; - if (pos.x < ow && pos.y < oh) { - int iw = uConstBlock.IW; - int ih = uConstBlock.IH; - float srcX = float(pos.x) * uConstBlock.scaleX; - int x1 = int(floor(srcX)); - int x11 = clamp(x1, 0, iw - 1); - float srcY = float(pos.y) * uConstBlock.scaleY; - int y1 = int(floor(srcY)); - int y11 = clamp(y1, 0, ih - 1); - vec4 outValue = texelFetch(uInput, ivec3(x11, y11, pos.z), 0); - imageStore(uOutput, pos, outValue); - } -} diff --git a/aten/src/ATen/native/vulkan/glsl/upsample_nearest2d.glsl b/aten/src/ATen/native/vulkan/glsl/upsample_nearest2d.glsl new file mode 100644 index 0000000000000..b4db9b87dacb9 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/upsample_nearest2d.glsl @@ -0,0 +1,32 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; + ivec2 isize; + vec2 scale; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const ivec2 ipos = clamp( + ivec2(pos.xy * uBlock.scale), + ivec2(0), + uBlock.isize); + + imageStore( + uOutput, + pos, + texelFetch(uInput, ivec3(ipos, pos.z), 0)); + } +} diff --git a/aten/src/ATen/native/vulkan/ops/Add.cpp b/aten/src/ATen/native/vulkan/ops/Add.cpp new file mode 100644 index 0000000000000..95b7fd67c0951 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Add.cpp @@ -0,0 +1,267 @@ +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +using namespace api::utils; + +Tensor add_scalar( + const Tensor& self_arg, + const Scalar other, + const Scalar alpha) { + api::Context* const context = api::context(); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_self = convert(self); + + vTensor v_output{ + context, + v_self.sizes(), + v_self.options(), + }; + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + if C10_LIKELY(v_output.has_image() && v_self.has_image()) { + const struct Block final { + uvec3 extents; + float other; + } block { + v_self.extents(), + other.to() * alpha.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(add_scalar), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_self.image( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_pool.submit(context->gpu().queue, command_buffer); + + return convert(v_output); +} + +Tensor& add_scalar_( + Tensor& self, + const Scalar other, + const Scalar alpha) { + api::Context* const context = api::context(); + + TORCH_CHECK( + self.is_vulkan(), + "Vulkan: In-place add is only supported on Vulkan tensors."); + + vTensor& v_self = convert(self); + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + if C10_LIKELY(v_self.has_image()) { + const struct Block final { + uvec3 extents; + float other; + } block { + v_self.extents(), + other.to() * alpha.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(add_scalar_), + v_self.extents(), + // Read-Write access triggers an async synchronization if necessory + // and inserts appropriate barriers if hazards are detected. + v_self.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Read | vTensor::Access::Write), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_pool.submit(context->gpu().queue, command_buffer); + + return self; +} + +Tensor add_tensor( + const Tensor& self_arg, + const Tensor& other_arg, + const Scalar alpha) { + api::Context* const context = api::context(); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_self = convert(self); + + const Tensor other = other_arg.is_vulkan() ? other_arg : other_arg.vulkan(); + const vTensor& v_other = convert(other); + + vTensor v_output{ + context, + v_self.sizes(), + v_self.options(), + }; + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + if C10_LIKELY(v_self.has_image() && v_other.has_image()) { + const struct Block final { + uvec3 extents; + float alpha; + } block { + v_output.extents(), + alpha.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(add), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_self.image( + command_buffer, + vTensor::Stage::Compute), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_other.image( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_pool.submit(context->gpu().queue, command_buffer); + + return convert(v_output); +} + +Tensor& add_tensor_( + Tensor& self, + const Tensor& other_arg, + const Scalar alpha) { + api::Context* const context = api::context(); + + TORCH_CHECK( + self.is_vulkan(), + "Vulkan: In-place add is only supported on Vulkan tensors."); + + vTensor& v_self = convert(self); + + const Tensor other = other_arg.is_vulkan() ? other_arg : other_arg.vulkan(); + const vTensor& v_other = convert(other); + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + if C10_LIKELY(v_self.has_image() && v_other.has_image() && !self.is_same(other)) { + const struct Block final { + uvec3 extents; + float alpha; + } block { + v_self.extents(), + alpha.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(add_), + v_self.extents(), + // Read-Write access triggers an async synchronization if necessory + // and inserts appropriate barriers if hazards are detected. + v_self.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Read | vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_other.image( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_pool.submit(context->gpu().queue, command_buffer); + + return self; +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("add.Scalar", TORCH_FN(add_scalar)); + m.impl("add_.Scalar", TORCH_FN(add_scalar_)); + m.impl("add.Tensor", TORCH_FN(add_tensor)); + m.impl("add_.Tensor", TORCH_FN(add_tensor_)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Clamp.cpp b/aten/src/ATen/native/vulkan/ops/Clamp.cpp new file mode 100644 index 0000000000000..75e9a1bb0ffff --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Clamp.cpp @@ -0,0 +1,180 @@ +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +using namespace api::utils; + +Tensor clamp( + const Tensor& self_arg, + const c10::optional min, + const c10::optional max) { + TORCH_CHECK( + min || max, + "At least one of 'min' or 'max' must not be None"); + + api::Context* const context = api::context(); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_self = convert(self); + + vTensor v_output{ + context, + v_self.sizes(), + v_self.options(), + }; + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + if C10_LIKELY(v_output.has_image() && v_self.has_image()) { + const struct Block final { + uvec3 extents; + uint32_t _; + vec2 clamp; + } block { + v_output.extents(), + 0u, + { + min ? min->to() : -std::numeric_limits::infinity(), + max ? max->to() : std::numeric_limits::infinity(), + }, + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(clamp), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_self.image( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_pool.submit(context->gpu().queue, command_buffer); + + return convert(v_output); +} + +Tensor& clamp_( + Tensor& self, + const c10::optional min, + const c10::optional max) { + api::Context* const context = api::context(); + + TORCH_CHECK( + min || max, + "At least one of 'min' or 'max' must not be None"); + + TORCH_CHECK( + self.is_vulkan(), + "Vulkan: In-place clamp is only supported on Vulkan tensors."); + + vTensor& v_self = convert(self); + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + if C10_LIKELY(v_self.has_image()) { + const struct Block final { + uvec3 extents; + uint32_t _; + vec2 clamp; + } block { + v_self.extents(), + 0u, + { + min ? min->to() : -std::numeric_limits::infinity(), + max ? max->to() : std::numeric_limits::infinity(), + }, + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(clamp_), + v_self.extents(), + // Read-Write access triggers an async synchronization if necessory + // and inserts appropriate barriers if hazards are detected. + v_self.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Read | vTensor::Access::Write), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_pool.submit(context->gpu().queue, command_buffer); + + return self; +} + +Tensor hardtanh( + const Tensor& self, + const Scalar min, + const Scalar max) { + return ops::clamp(self, min, max); +} + +Tensor& hardtanh_( + Tensor& self, + const Scalar min, + const Scalar max) { + return ops::clamp_(self, min, max); +} + +Tensor relu(const Tensor& self) { + return ops::clamp(self, 0, c10::nullopt); +} + +Tensor& relu_(Tensor& self) { + return ops::clamp_(self, 0, c10::nullopt); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("clamp", TORCH_FN(clamp)); + m.impl("clamp_", TORCH_FN(clamp_)); + m.impl("hardtanh", hardtanh); + m.impl("hardtanh_", hardtanh_); + m.impl("relu", relu); + m.impl("relu_", relu_); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Common.h b/aten/src/ATen/native/vulkan/ops/Common.h new file mode 100644 index 0000000000000..3c9b2e8b3b9f6 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Common.h @@ -0,0 +1,43 @@ +#pragma once + +#ifdef USE_VULKAN_API + +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +struct Layout final { + // 4D Activation Maps + struct Activation4D final { + static constexpr size_t batch = 0u; + static constexpr size_t channels = 1u; + static constexpr size_t height = 2u; + static constexpr size_t width = 3u; + }; + + // Convolution Filters + struct Filter final { + static constexpr size_t output = 0u; + static constexpr size_t input = 1u; + static constexpr size_t height = 2u; + static constexpr size_t width = 3u; + }; + + // Parameters (Pooling Kernels, Dilation, Padding, Stride, etc.) + struct Parameter final { + static constexpr size_t height = 0u; + static constexpr size_t width = 1u; + }; +}; + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.cpp b/aten/src/ATen/native/vulkan/ops/Convolution.cpp new file mode 100644 index 0000000000000..f1b9f778cca84 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Convolution.cpp @@ -0,0 +1,1070 @@ +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +using namespace api::utils; + +struct Experimentation final { + static constexpr bool kUseConv2dOldApi = false; +}; + +inline bool is_depthwise( + const IntArrayRef filter, + const int64_t groups) { + return (filter[Layout::Filter::output] == groups) && + // Only K == 1 supported. + (filter[Layout::Filter::input] == 1); +} + +inline bool is_pointwise(const IntArrayRef filter) { + return (1 == filter[Layout::Filter::height]) && + (1 == filter[Layout::Filter::width]); +} + +vTensor pack_weights_dw( + api::Context* const context, + api::Command::Buffer& command_buffer, + api::Resource::Pool& pool, + const Tensor& weight) { + /* Source */ + const IntArrayRef src_filter = weight.sizes(); + const float* const src_weight_ptr = weight.data_ptr(); + + const int64_t src_kw_sz = src_filter[Layout::Filter::width]; + const int64_t src_kh_sz = src_filter[Layout::Filter::height]; + const int64_t src_kernel_sz = src_kw_sz * src_kh_sz; + const int64_t src_block_sz = src_kernel_sz * src_filter[Layout::Filter::input]; + const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], INT64_C(4)); + + /* Destination */ + const int64_t dst_kw_sz = src_kernel_sz; + const int64_t dst_kh_sz = num_stacks; + const int64_t dst_kernel_sz = dst_kw_sz * dst_kh_sz; + + vTensor v_weight{ + context, + &pool, + { + 4, + dst_kh_sz, + dst_kw_sz, + }, + weight.options(), + }; + + using Future = vTensor::Future; + Future v_weight_future = v_weight.host(command_buffer); + Future::Payload v_weight_payload = v_weight_future.wait(); + + float* const dst_weight_ptr = v_weight_payload.get(); + memset(dst_weight_ptr, 0, v_weight.nbytes()); + + for (int64_t src_oc = 0; src_oc < src_filter[Layout::Filter::output]; ++src_oc) { + /* Source */ + const float* const src_weight_oc_ptr = src_weight_ptr + src_oc * src_block_sz; + + /* Destination */ + const int64_t dst_oh = src_oc / 4; + const int64_t dst_c = src_oc % 4; + + float* const dst_weight_c_ptr = dst_weight_ptr + + dst_c * dst_kernel_sz + + dst_oh * dst_kw_sz; + + for (int64_t src_ih = 0; src_ih < src_filter[Layout::Filter::height]; ++src_ih) { + memcpy( + dst_weight_c_ptr + src_ih * src_kw_sz, + src_weight_oc_ptr + src_ih * src_kw_sz, + sizeof(float) * src_kw_sz); + } + } + + return v_weight; +} + +vTensor pack_weights_2d( + api::Context* const context, + api::Command::Buffer& command_buffer, + api::Resource::Pool& pool, + const Tensor& weight) { + /* Source */ + const IntArrayRef src_filter = weight.sizes(); + const float* const src_weight_ptr = weight.data_ptr(); + + const int64_t src_kw_sz = src_filter[Layout::Filter::width]; + const int64_t src_kh_sz = src_filter[Layout::Filter::height]; + const int64_t src_kernel_sz = src_kw_sz * src_kh_sz; + const int64_t src_block_sz = src_kernel_sz * src_filter[Layout::Filter::input]; + + const int64_t num_stacks = div_up(src_filter[Layout::Filter::output], INT64_C(4)); + const int64_t stack_depth = api::utils::align_up(src_filter[Layout::Filter::input], INT64_C(4)); + + /* Destination */ + const int64_t dst_kw_sz = src_kw_sz * stack_depth; + const int64_t dst_kh_sz = src_kh_sz * num_stacks; + const int64_t dst_kernel_sz = dst_kw_sz * dst_kh_sz; + + vTensor v_weight{ + context, + &pool, + { + 4, + dst_kh_sz, + dst_kw_sz, + }, + weight.options(), + }; + + using Future = vTensor::Future; + Future v_weight_future = v_weight.host(command_buffer); + Future::Payload v_weight_payload = v_weight_future.wait(); + + float* const dst_weight_ptr = v_weight_payload.get(); + memset(dst_weight_ptr, 0, v_weight.nbytes()); + + for (int64_t src_oc = 0; src_oc < src_filter[Layout::Filter::output]; ++src_oc) { + /* Source */ + const float* const src_weight_oc_ptr = src_weight_ptr + src_oc * src_block_sz; + + /* Destination */ + const int64_t dst_oh = src_oc / 4; + const int64_t dst_c = src_oc % 4; + + float* const dst_weight_c_ptr = dst_weight_ptr + dst_c * dst_kernel_sz; + + for (int64_t src_ic = 0; src_ic < src_filter[Layout::Filter::input]; ++src_ic) { + const int64_t dst_ic4 = src_ic / 4; + + for (int64_t src_ih = 0; src_ih < src_kh_sz; ++src_ih) { + for (int64_t src_iw = 0; src_iw < src_kw_sz; ++src_iw) { + memcpy( + dst_weight_c_ptr + (dst_oh * src_kh_sz + src_ih) * dst_kw_sz + + dst_ic4 * src_kw_sz * 4 + src_iw * 4 + src_ic % 4, + src_weight_oc_ptr + src_ic * src_kernel_sz + src_ih * src_kw_sz + src_iw, + sizeof(float)); + } + } + } + } + + return v_weight; +} + +vTensor pack_weights_2d_old( + api::Context* const context, + api::Command::Buffer& command_buffer, + api::Resource::Pool& pool, + const Tensor& weight) { + const IntArrayRef src_filter = weight.sizes(); + const float* const src_weight_ptr = weight.data_ptr(); + + const uint32_t OC = src_filter[Layout::Filter::output]; + const uint32_t OC_4 = at::native::vulkan::api::utils::div_up(OC, 4u); + const uint32_t C = src_filter[Layout::Filter::input]; + const uint32_t C_4 = at::native::vulkan::api::utils::div_up(C, 4u); + const uint32_t KH = src_filter[Layout::Filter::height]; + const uint32_t KW = src_filter[Layout::Filter::width]; + + vTensor v_weight{ + context, + &pool, + { + 1, + 4 * KH * KW, + OC_4, + 4 * C_4 + }, + weight.options(), + }; + + using Future = vTensor::Future; + Future v_weight_future = v_weight.host(command_buffer); + Future::Payload v_weight_payload = v_weight_future.wait(); + + float* const dst_weight_ptr = v_weight_payload.get(); + memset(dst_weight_ptr, 0, v_weight.nbytes()); + + const float* const src = src_weight_ptr; + float* const dst = dst_weight_ptr; + + { + uint32_t ridx = 0; + const uint32_t oc_4SizeNumel = KW * KH * C_4 * 16; + for (uint32_t oc = 0; oc < OC; ++oc) { + int oc_4 = oc / 4; + int oc_4_i = oc % 4; + float* dst_oc = dst + oc_4 * oc_4SizeNumel; + for (uint32_t ic = 0; ic < C; ++ic) { + int ic_4 = ic / 4; + int ic_4_i = ic % 4; + float* dst_ic = dst_oc + ic_4 * KW * KH * 16; + for (uint32_t ky = 0; ky < KH; ++ky) { + float* dst_ky = dst_ic + ky * KW * 16; + for (uint32_t kx = 0; kx < KW; ++kx) { + float* dst_kx = dst_ky + kx * 16; + dst_kx[4 * ic_4_i + oc_4_i] = src[ridx++]; + } + } + } + } + + // shader KO4C4HW_to_image + struct Image3D { + float* data_; + uint32_t dim0_, dim1_, dim2_; + + Image3D(uint32_t dim0, uint32_t dim1, uint32_t dim2) { + dim0_ = dim0; + dim1_ = dim1; + dim2_ = dim2; + data_ = new float[dim0 * dim1 * dim2 * 4]; // TODO: memory leak + memset(data_, 0.f, dim0 * dim1 * dim2 * 4 * sizeof(float)); + } + + inline uint32_t idx(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3) { + return i3 + i2 * 4 + i1 * 4 * dim2_ + i0 * 4 * dim2_ * dim1_; + } + + void set(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, float value) { + data_[idx(i0, i1, i2, i3)] = value; + } + + float get(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3) { + return data_[idx(i0, i1, i2, i3)]; + } + } image{4 * C_4, OC_4, KH * KW}; + + for (uint32_t sx = 0; sx < C_4; ++sx) { + for (uint32_t sy = 0; sy < OC_4; ++sy) { + for (uint32_t sz = 0; sz < (KH * KW); ++sz) { + for (uint32_t vi = 0; vi < 4; ++vi) { + int bufferVIdx = 4 * sx * KH * KW + 4 * sy * C_4 * KH * KW + 4 * sz; + image.set(4 * sx + 0, sy, sz, vi, dst[4 * (bufferVIdx + 0) + vi]); + image.set(4 * sx + 1, sy, sz, vi, dst[4 * (bufferVIdx + 1) + vi]); + image.set(4 * sx + 2, sy, sz, vi, dst[4 * (bufferVIdx + 2) + vi]); + image.set(4 * sx + 3, sy, sz, vi, dst[4 * (bufferVIdx + 3) + vi]); + } + } + } + } + + // inverse function of nchw_to_image + const uint32_t W = 4 * C_4; + const uint32_t H = OC_4; + const uint32_t D = KH * KW; + for (uint32_t sx = 0; sx < W; ++sx) { + for (uint32_t sy = 0; sy < H; ++sy) { + for (uint32_t sz = 0; sz < D; ++sz) { + for (uint32_t szvi = 0; szvi < 4; ++szvi) { + dst_weight_ptr[W * sy + sx + (4 * sz + szvi) * W * H] = image.get(sx, sy, sz, szvi); + } + } + } + } + } + + return v_weight; +} + +vTensor pack_weights( + api::Resource::Pool& pool, + const Tensor& weight_arg, + const int64_t groups) { + if (weight_arg.is_vulkan()) { + return convert(weight_arg); + } + + api::Context* const context = api::context(); + api::Command::Buffer& command_buffer = context->command().pool.stream(); + + const Tensor weight = weight_arg.contiguous(); + + if (is_depthwise(weight.sizes(), groups)) { + return pack_weights_dw( + context, + command_buffer, + pool, + weight); + } + + if (Experimentation::kUseConv2dOldApi) { + return pack_weights_2d_old( + context, + command_buffer, + pool, + weight); + } + + return pack_weights_2d( + context, + command_buffer, + pool, + weight); +} + +vTensor pack_biases( + api::Resource::Pool& pool, + const c10::optional& bias, + const Tensor& weight) { + if (bias && bias->is_vulkan()) { + return convert(*bias); + } + + api::Context* const context = api::context(); + api::Command::Buffer& command_buffer = context->command().pool.stream(); + + vTensor v_bias{ + context, + &pool, + { + // 1D + weight.size(Layout::Filter::output), + }, + weight.options(), + }; + + { + using Future = vTensor::Future; + Future v_bias_future = v_bias.host(command_buffer); + Future::Payload v_bias_payload = v_bias_future.wait(); + + if (bias) { + memcpy( + v_bias_payload.get(), + bias->contiguous().data_ptr(), + std::min(bias->nbytes(), v_bias.nbytes())); + } + else { + memset( + v_bias_payload.get(), + // 2's complement integers and IEEE-754 floating point numbers both + // have identical bit representations for 0, so can use memset which + // only accepts uint8_t parameter. + 0, + v_bias.nbytes()); + } + } + + return v_bias; +} + +std::array pack_filter( + const Tensor& weight, + const IntArrayRef dilation) { + const IntArrayRef filter = weight.sizes(); + + const auto effective = [](const int64_t k, const int64_t d) { + return k + (k - 1) * (d - 1); + }; + + return { + align_up(filter[Layout::Filter::output], INT64_C(4)), + align_up(filter[Layout::Filter::input], INT64_C(4)), + effective( + filter[Layout::Filter::height], + dilation[Layout::Parameter::height]), + effective( + filter[Layout::Filter::width], + dilation[Layout::Parameter::width]), + }; +} + +std::array pack_params(const std::vector& vector) { + TORCH_INTERNAL_ASSERT(2u == vector.size(), "Invalid usage!"); + + return { + vector[0], + vector[1], + }; +} + +bool available( + const Tensor& weight, + const c10::optional& bias, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const bool transposed, + const IntArrayRef /* output_padding */, + const int64_t groups, + const c10::optional output_min, + const c10::optional output_max) { + return api::available() && + // Weight + (4 == weight.ndimension()) && + (weight.size(Layout::Filter::height) > 0) && + (weight.size(Layout::Filter::width) > 0) && + ((c10::DeviceType::CPU == weight.device().type()) || + (c10::DeviceType::Vulkan == weight.device().type())) && + (kFloat == weight.scalar_type()) && + // Bias + ((bias && bias->defined()) ? ((1 == bias->ndimension()) && + ((c10::DeviceType::CPU == bias->device().type()) || + (c10::DeviceType::Vulkan == bias->device().type())) && + (kFloat == bias->scalar_type()) && + (transposed ? false /* to be addded in the future */ + : (weight.size(Layout::Filter::output) == + bias->size(Layout::Filter::output)))) + : true) && + // Stride + (stride[Layout::Parameter::height] > 0) && + (stride[Layout::Parameter::width] > 0) && + // Padding + (padding[Layout::Parameter::height] >= 0) && + (padding[Layout::Parameter::width] >= 0) && + // Dilation + (dilation[Layout::Parameter::height] > 0) && + (dilation[Layout::Parameter::width] > 0) && + // Groups + (groups > 0) && + // Input + (weight.size(Layout::Filter::input) > 0) && + // Output + (weight.size(Layout::Filter::output) > 0) && + // Output - Groups + ((weight.size(Layout::Filter::output) % groups) == 0) && + // Output Min / Max + (!output_min || output_min->isFloatingPoint()) && + (!output_max || output_max->isFloatingPoint()) && + true; +} + +bool usable(const Tensor& input) { + // Input + return (4 == input.ndimension()) && + (c10::DeviceType::Vulkan == input.device().type()) && + (kFloat == input.scalar_type()) && + (input.size(Layout::Activation4D::batch) >= 0) && + (input.size(Layout::Activation4D::channels) > 0) && + (input.size(Layout::Activation4D::height) > 0) && + (input.size(Layout::Activation4D::width) > 0) && + !input.requires_grad() && + true; +} + +void conv2d_dw( + api::Context* const context, + api::Command::Buffer& command_buffer, + vTensor& v_output, + const vTensor& v_input, + const vTensor& v_weight, + const vTensor& v_bias, + const IntArrayRef filter, + const IntArrayRef src_filter, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const float output_min, + const float output_max) { + if C10_LIKELY(v_output.has_image() && v_input.has_image() && v_weight.has_image()) { + const struct Block final { + uvec3 extents; + int32_t src_filter_width; + ivec4 kernel; + ivec2 stride; + ivec2 padding; + ivec2 dilate; + vec2 clamp; + } block { + v_output.extents(), + safe_downcast(src_filter[Layout::Filter::width]), + { + safe_downcast(filter[Layout::Filter::width]), + safe_downcast(filter[Layout::Filter::height]), + safe_downcast(v_input.sizes()[Layout::Activation4D::width]), + safe_downcast(v_input.sizes()[Layout::Activation4D::height]), + }, + { + safe_downcast(stride[Layout::Parameter::width]), + safe_downcast(stride[Layout::Parameter::height]), + }, + { + safe_downcast(padding[Layout::Parameter::width]), + safe_downcast(padding[Layout::Parameter::height]), + }, + { + safe_downcast(dilation[Layout::Parameter::width]), + safe_downcast(dilation[Layout::Parameter::height]), + }, + { + output_min, + output_max, + }, + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(conv2d_dw), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_input.image( + command_buffer, + vTensor::Stage::Compute), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_weight.image( + command_buffer, + vTensor::Stage::Compute), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_bias.buffer( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } +} + +void conv2d_pw( + api::Context* const context, + api::Command::Buffer& command_buffer, + vTensor& v_output, + const vTensor& v_input, + const vTensor& v_weight, + const vTensor& v_bias, + const IntArrayRef filter, + const IntArrayRef stride, + const IntArrayRef padding, + const float output_min, + const float output_max) { + if C10_LIKELY(v_output.has_image() && v_input.has_image() && v_weight.has_image()) { + const struct Block final { + uvec3 extents; + int32_t ic; + ivec2 stride; + ivec2 padding; + vec2 clamp; + } block { + v_output.extents(), + safe_downcast(filter[Layout::Filter::input]), + { + safe_downcast(stride[Layout::Parameter::width]), + safe_downcast(stride[Layout::Parameter::height]), + }, + { + safe_downcast(padding[Layout::Parameter::width]), + safe_downcast(padding[Layout::Parameter::height]), + }, + { + output_min, + output_max, + }, + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(conv2d_pw), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_input.image( + command_buffer, + vTensor::Stage::Compute), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_weight.image( + command_buffer, + vTensor::Stage::Compute), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_bias.buffer( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } +} + +void conv2d( + api::Context* const context, + api::Command::Buffer& command_buffer, + vTensor& v_output, + const vTensor& v_input, + const vTensor& v_weight, + const vTensor& v_bias, + const IntArrayRef filter, + const IntArrayRef src_filter, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const float output_min, + const float output_max) { + if C10_LIKELY(v_output.has_image() && v_input.has_image() && v_weight.has_image()) { + const struct Block final { + uvec3 extents; + int32_t ic4; + ivec4 kernel; + ivec2 ikernel; + ivec2 stride; + ivec2 padding; + ivec2 dilate; + vec2 clamp; + ivec4 src_filter; + } block { + v_output.extents(), + safe_downcast(filter[Layout::Filter::input] / 4), + { + safe_downcast(filter[Layout::Filter::width]), + safe_downcast(filter[Layout::Filter::height]), + safe_downcast(v_input.sizes()[Layout::Activation4D::width]), + safe_downcast(v_input.sizes()[Layout::Activation4D::height]), + }, + { + safe_downcast(src_filter[Layout::Filter::width] * 4), + safe_downcast(src_filter[Layout::Filter::height]), + }, + { + safe_downcast(stride[Layout::Parameter::width]), + safe_downcast(stride[Layout::Parameter::height]), + }, + { + safe_downcast(padding[Layout::Parameter::width]), + safe_downcast(padding[Layout::Parameter::height]), + }, + { + safe_downcast(dilation[Layout::Parameter::width]), + safe_downcast(dilation[Layout::Parameter::height]), + }, + { + output_min, + output_max, + }, + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(conv2d), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_input.image( + command_buffer, + vTensor::Stage::Compute), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_weight.image( + command_buffer, + vTensor::Stage::Compute), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_bias.buffer( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } +} + +void conv2d_old( + api::Context* const context, + api::Command::Buffer& command_buffer, + vTensor& v_output, + const vTensor& v_input, + const vTensor& v_weight, + const vTensor& v_bias, + const IntArrayRef filter, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const float output_min, + const float output_max) { + using namespace api::utils; + + if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) { + const int32_t W = v_input.extents().data[0]; + const int32_t H = v_input.extents().data[1]; + const int32_t C_4 = v_input.extents().data[2]; + const int32_t C = 4 * C_4; + + const int32_t OW = v_output.extents().data[0]; + const int32_t OH = v_output.extents().data[1]; + const int32_t OC_4 = v_output.extents().data[2]; + const int32_t OC = 4 * OC_4; + + const struct Block final { + int32_t padding_x, padding_y; + int32_t kernel_x, kernel_y; + int32_t stride_x, stride_y; + int32_t dilate_x, dilate_y; + int32_t outputSize[4]; + int32_t inputSize[4]; + float outputMin; + float outputMax; + } block { + safe_downcast(padding[Layout::Parameter::width]), + safe_downcast(padding[Layout::Parameter::height]), + safe_downcast(filter[Layout::Filter::width]), + safe_downcast(filter[Layout::Filter::height]), + safe_downcast(stride[Layout::Parameter::width]), + safe_downcast(stride[Layout::Parameter::height]), + safe_downcast(dilation[Layout::Parameter::width]), + safe_downcast(dilation[Layout::Parameter::height]), + { OW, OH, OC_4, OC }, + { W, H, C_4, C }, + output_min, + output_max, + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(conv2d_nogroup_clamp), + //VK_KERNEL(conv2d_nogroup_clamp_1x), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_input.image( + command_buffer, + vTensor::Stage::Compute), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_weight.image( + command_buffer, + vTensor::Stage::Compute), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_bias.buffer( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } +} + +Tensor convolution( + const Tensor& input, + const Tensor& weight, + const c10::optional& bias, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const bool transposed, + const IntArrayRef output_padding, + const int64_t groups) { + return Conv2dOpContext::create( + api::context()->resource().pool, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups + ).run(input); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("convolution_overrideable", convolution); +} + +#endif /* USE_VULKAN_API */ + +} // namespace + +Conv2dOpContext::Conv2dOpContext( + api::Resource::Pool& pool, + const Tensor& weight, + const c10::optional& bias, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const bool /* transposed */, + const IntArrayRef /* output_padding */, + const int64_t groups, + const c10::optional output_min, + const c10::optional output_max) + : packed_{ + pack_weights(pool, weight, groups), + pack_biases(pool, bias, weight), + pack_filter(weight, expand_param_if_needed(dilation, "dilation", 2)), + pack_params(expand_param_if_needed(stride, "stride", 2)), + pack_params(expand_param_if_needed(padding, "padding", 2)), + pack_params(expand_param_if_needed(dilation, "dilation", 2)), + groups, + output_min ? output_min->template to() : -std::numeric_limits::infinity(), + output_max ? output_max->template to() : +std::numeric_limits::infinity(), + }, + unpacked_{ + weight, + bias, + weight.sizes().vec(), + stride.vec(), + padding.vec(), + dilation.vec(), + groups, + output_min, + output_max, + } { +} + +Conv2dOpContext Conv2dOpContext::create( + api::Resource::Pool& pool, + const Tensor& weight, + const c10::optional& bias, + const IntArrayRef stride_arg, + const IntArrayRef padding_arg, + const IntArrayRef dilation_arg, + const bool transposed, + const IntArrayRef output_padding_arg, + const int64_t groups, + const c10::optional output_min, + const c10::optional output_max) { + const auto stride = expand_param_if_needed(stride_arg, "stride", 2); + const auto padding = expand_param_if_needed(padding_arg, "padding", 2); + const auto dilation = expand_param_if_needed(dilation_arg, "dilation", 2); + const auto output_padding = output_padding_arg; // TODO: Deconvolutions + + TORCH_CHECK( + available( + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_min, + output_max), + "Vulkan::convolution not available! " + "Reason: The provided (weight, bias, stride, padding, dilation, groups, " + "transposed, output_padding, output_min, output_max) parameters are either " + "invalid individually or their combination is not supported by Vulkan impl."); + + // Pass in the originals + return Conv2dOpContext{ + pool, + weight, + bias, + stride_arg, + padding_arg, + dilation_arg, + transposed, + output_padding_arg, + groups, + output_min, + output_max, + }; +} + +Tensor Conv2dOpContext::run(const Tensor& input_arg) const { + api::Context* const context = api::context(); + + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_input = convert(input); + + TORCH_CHECK( + usable(input), + "Vulkan Convolution not usable! " + "Reason: The provided input tensor is either invalid or unsupported by Vulkan impl."); + + vTensor v_output{ + context, + conv_output_size( + v_input.sizes(), + unpacked_.filter, + packed_.padding, + packed_.stride, + packed_.dilation), + input.options(), + }; + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + if (is_depthwise(unpacked_.filter, unpacked_.groups)) { + conv2d_dw( + context, + command_buffer, + v_output, + v_input, + packed_.v_weight, + packed_.v_bias, + packed_.filter, + unpacked_.filter, + packed_.stride, + packed_.padding, + packed_.dilation, + packed_.output_min, + packed_.output_max); + } + else { + if (Experimentation::kUseConv2dOldApi) { + conv2d_old( + context, + command_buffer, + v_output, + v_input, + packed_.v_weight, + packed_.v_bias, + packed_.filter, + packed_.stride, + packed_.padding, + packed_.dilation, + packed_.output_min, + packed_.output_max); + } else { + if (is_pointwise(unpacked_.filter)) { + conv2d_pw( + context, + command_buffer, + v_output, + v_input, + packed_.v_weight, + packed_.v_bias, + packed_.filter, + packed_.stride, + packed_.padding, + packed_.output_min, + packed_.output_max); + } + else { + conv2d( + context, + command_buffer, + v_output, + v_input, + packed_.v_weight, + packed_.v_bias, + packed_.filter, + unpacked_.filter, + packed_.stride, + packed_.padding, + packed_.dilation, + packed_.output_min, + packed_.output_max); + } + } + } + } + command_pool.submit(context->gpu().queue, command_buffer); + + return convert(v_output); +} + +Conv2dOpContext::State Conv2dOpContext::unpack() const { + return Conv2dOpContext::State{ + unpacked_.weight, + unpacked_.bias, + unpacked_.stride, + unpacked_.padding, + unpacked_.dilation, + unpacked_.groups, + unpacked_.output_min, + unpacked_.output_max, + }; +} + +c10::intrusive_ptr conv2d_clamp_prepack( + Tensor&& weight, + c10::optional&& bias, + std::vector&& stride, + std::vector&& padding, + std::vector&& dilation, + const int64_t groups, + const c10::optional output_min, + const c10::optional output_max) { + return c10::make_intrusive( + Conv2dOpContext::create( + persistent()->pool, + std::move(weight), + std::move(bias), + std::move(stride), + std::move(padding), + std::move(dilation), + /* transposed = */ false, + /* output_padding = */ {}, + groups, + output_min, + output_max)); +} + +Tensor conv2d_clamp_run( + const Tensor& input, + const c10::intrusive_ptr& context) { + return context->run(input); +} + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.h b/aten/src/ATen/native/vulkan/ops/Convolution.h new file mode 100644 index 0000000000000..7bd27bb1942bf --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Convolution.h @@ -0,0 +1,100 @@ +#pragma once + +#ifdef USE_VULKAN_API + +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +class Conv2dOpContext final : public torch::jit::CustomClassHolder { + public: + static Conv2dOpContext create( + api::Resource::Pool& pool, + const Tensor& weight, + const c10::optional& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool transposed, + IntArrayRef output_padding, + int64_t groups, + c10::optional output_min = c10::nullopt, + c10::optional output_max = c10::nullopt); + + using State = std::tuple< + Tensor, + c10::optional, + std::vector, + std::vector, + std::vector, + int64_t, + c10::optional, + c10::optional>; + + Tensor run(const Tensor& input) const; + State unpack() const; + + private: + Conv2dOpContext( + api::Resource::Pool& pool, + const Tensor& weight, + const c10::optional& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool transposed, + IntArrayRef output_padding, + int64_t groups, + c10::optional output_min = c10::nullopt, + c10::optional output_max = c10::nullopt); + + private: + struct { + vTensor v_weight; + vTensor v_bias; + std::array filter; + std::array stride; + std::array padding; + std::array dilation; + int32_t groups; + float output_min; + float output_max; + } packed_; + + struct { + Tensor weight; + c10::optional bias; + std::vector filter; + std::vector stride; + std::vector padding; + std::vector dilation; + int64_t groups; + c10::optional output_min; + c10::optional output_max; + } unpacked_; +}; + +Tensor conv2d_clamp_run( + const Tensor& input, + const c10::intrusive_ptr& context); + +c10::intrusive_ptr conv2d_clamp_prepack( + Tensor&& weight, + c10::optional&& bias, + std::vector&& stride, + std::vector&& padding, + std::vector&& dilation, + const int64_t groups, + const c10::optional output_min, + const c10::optional output_max); + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Copy.cpp b/aten/src/ATen/native/vulkan/ops/Copy.cpp new file mode 100644 index 0000000000000..1cf6b1ad6aa93 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Copy.cpp @@ -0,0 +1,168 @@ +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +Tensor& copy_(Tensor& self, const Tensor& src) { + api::Context* const context = api::context(); + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + // X -> Vulkan + if (at::kVulkan == self.device().type()) { + vTensor& v_self = convert(self); + + // Vulkan -> Vulkan + if (at::kVulkan == src.device().type()) { + command_buffer.copy( + // - Read-only access is implied on const tensors. Memory barriers + // are automatically inserted if a RAW hazard is detected. + // - Recording any potential pending sync operations into the same + // command buffer prevents an expensive queue submission. + convert(src).buffer( + command_buffer, + vTensor::Stage::Transfer), + // - Write-only access never triggers a sync as the contents will be + // overwritten regardless. Having said that, appropriate barriers + // are inserted automatically if WAR or WAW hazards are detected. + // - Recording pending sync operations into the same command buffer + // prevents an expensive queue submission. + v_self.buffer( + command_buffer, + vTensor::Stage::Transfer, + vTensor::Access::Write)); + + command_pool.submit(context->gpu().queue, command_buffer); + } + // CPU -> Vulkan + else { + const Tensor cpu_src = src.device().is_cpu() ? src : src.cpu(); + + // Requesting write-only host access to the tensor never triggers a sync + // as the contents will be overwritten regardless. Having said that, + // appropriate barriers are inserted automatically if WAR or WAW hazards + // are detected. Examples of such scenario for instance are if any of + // these async operations are on going in the background on 'self': + // - On discrete systems: + // * buffer-to-staging transfers + // * staging-to-buffer transfers + // - On UMA buffer is an alias for staging and accessible both on host + // and device. Consequently: + // * buffer-to-image NHWC -> NC4HW packing + // * image-to-buffer NC4HW -> NHWC unpacking + + using Future = vTensor::Future; + Future v_self_future = v_self.host(command_buffer); + + // Ideally we would have been able to put as much distance between + // requesting the data - a call to host() - and accessing the data + // - a call to wait() - but a local view of the computation graph + // in eager mode makes that optimization non-trivial. + + // This wait() will be a no-op if no hazards are detected, including the + // obvious, yet important, special case of 'self' being an empty tensor. + + Future::Payload v_self_payload = v_self_future.wait(); + + memcpy( + v_self_payload.get(), + cpu_src.contiguous().data_ptr(), + std::min(src.nbytes(), self.nbytes())); + } + } + // Vulkan -> X + else if (at::kVulkan == src.device().type()) { + const vTensor& v_src = convert(src); + + // Vulkan -> CPU + if (self.device().is_cpu()) { + // Similar notes as above applies, with the additional consideration of + // potential syncs on read accesses. Namely, + // - on discrete systems, if the (staging, buffer, image) trio, or + // - on UMA, if the (buffer, image) duo + // have gone out of sync as a result of one processor writing to one + // resource which is then either accessed as an another resource type on + // the same or another processor. Same considerations regarding hazard + // avoidance as above applies. + + using Future = vTensor::Future; + const Future v_src_future = v_src.host(command_buffer); + + // Ideally we would have been able to put as much distance between + // requesting the data - a call to host() - and accessing the data + // - a call to wait() - but a local view of the computation graph + // in eager mode makes that optimization non-trivial. + + // This wait() is a no-op if data is not out of sync. More often than + // not though, waits here are expected as the GPU catches up with + // compute submitted from CPU. + + const Future::Payload v_src_payload = v_src_future.wait(); + + memcpy( + self.data_ptr(), + v_src_payload.get(), + std::min(src.nbytes(), self.nbytes())); + } + else { + TORCH_CHECK(false, "Unsupported!"); + } + + // + // WARNING + // + + // This is not great. We almost never want to flush the GPU pipeline as + // that has far reaching consequences, especially if PyTorch is not the only + // process accessing the GPU. If we have done our job properly, above + // synchronization mechanisms should be enough to ensure correctness at a more + // modest cost, as there is no need to flush the entirety of jobs in flight + // if one is only interested on waiting on computation affecting one single + // tensor to finish. + // + // Having said that, we still do need to release all pool resources at one + // point per inference run or we will run out of memory otherwise. There is + // no perfect answer to this problem that checks all boxes, which leaves us + // with one of several design decisions: + // + // 1) Use graph mode to gain an understanding of the computation graph, + // itself allowing us to place pool purges intelligently. Best option + // for performance and memory consumption. Not without its downsides if + // flexibility is a top priority. + // 2) If on eager mode, and hence are seeing operations one at a time, expose + // this release of resources to the user as a Python / C++ function. This + // makes for suboptimal user experience but is efficient in terms of + // performance. + // 3) If on eager mode, and interested in keeping this bookkeeping transparent + // to the user, release all resources somewhere ... like here. This is + // not ideal since it requires a pipeline flush to make sure these objects + // are not already in use by a workload in flight. Cannot do much better + // within the constraints of this approach. Good for user experience, + // suboptimal for performance. + // 4) If on eager mode, and interested in keeping this bookkeeping transparent + // to the user, and performance does not matter, make CPU and GPU run in + // lockstep. Obviously this is just bad. Mentioned for the sake of + // completeness. + + context->flush(); + } + else { + TORCH_INTERNAL_ASSERT( + false, + "Invalid code path taken! Either the source or the destination tensor " + "was expected to be Vulkan a tensor! Incorrect dispatch?"); + } + } + // No queue submission here. All queue submissions must have been handled + // above either explicitly or as a result of calling tensor.host(). + + return self; +} + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Copy.h b/aten/src/ATen/native/vulkan/ops/Copy.h new file mode 100644 index 0000000000000..e69af06357c5a --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Copy.h @@ -0,0 +1,19 @@ +#pragma once + +#ifdef USE_VULKAN_API + +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +Tensor& copy_(Tensor& self, const Tensor& src); + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Factory.cpp b/aten/src/ATen/native/vulkan/ops/Factory.cpp new file mode 100644 index 0000000000000..14deb30b98889 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Factory.cpp @@ -0,0 +1,58 @@ +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +Tensor empty_memory_format( + const IntArrayRef sizes, + const c10::optional dtype, + const c10::optional layout, + const c10::optional device, + const c10::optional pin_memory, + const optional memory_format) { + return convert(vTensor{ + api::context(), + sizes, + TensorOptions() + .dtype(dtype) + .layout(layout) + .device(device) + .pinned_memory(pin_memory) + .memory_format(memory_format), + }); +} + +Tensor empty_strided( + const IntArrayRef sizes, + const IntArrayRef /* strides */, + const optional dtype, + const optional layout, + const optional device, + const optional pin_memory) { + return empty_memory_format( + sizes, + dtype, + layout, + device, + pin_memory, + c10::MemoryFormat::Contiguous); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("empty.memory_format", at::native::vulkan::ops::empty_memory_format); + m.impl("empty_strided", TORCH_FN(at::native::vulkan::ops::empty_strided)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Mean.cpp b/aten/src/ATen/native/vulkan/ops/Mean.cpp new file mode 100644 index 0000000000000..6a413f55ded55 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Mean.cpp @@ -0,0 +1,119 @@ +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +using namespace api::utils; + +Tensor mean( + const at::Tensor& input_arg, + const IntArrayRef dim, + const bool keepdim, + const optional dtype) { + TORCH_CHECK( + input_arg.dim() == 4, + "Vulkan mean expects 4-dimensional input!"); + + static const std::unordered_set expected_dims_set({2, 3}); + std::unordered_set dims_set; + + for (const auto& d : dim) { + dims_set.insert(utils::normalize(d, 4)); + } + + TORCH_CHECK( + dims_set == expected_dims_set, + "Vulkan mean currently only supports image-wide reduction!"); + + api::Context* const context = api::context(); + + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_input = convert(input); + const IntArrayRef v_input_sizes = v_input.sizes(); + + c10::SmallVector output_sizes{ + v_input_sizes[Layout::Activation4D::batch], + v_input_sizes[Layout::Activation4D::channels], + }; + + if (keepdim) { + output_sizes.push_back(1); + output_sizes.push_back(1); + } + + vTensor v_output{ + context, + output_sizes, + v_input.options(), + }; + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + if C10_LIKELY(v_input.has_image()) { + const struct Block final { + uvec3 extents; + int32_t range; + ivec2 iextents; + } block { + v_output.extents(), + safe_downcast( + v_input_sizes[Layout::Activation4D::width] * + v_input_sizes[Layout::Activation4D::height]), + { + safe_downcast(v_input_sizes[Layout::Activation4D::width]), + safe_downcast(v_input_sizes[Layout::Activation4D::height]), + }, + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + keepdim ? VK_KERNEL(mean) : VK_KERNEL(mean2d), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_input.image( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_pool.submit(context->gpu().queue, command_buffer); + + return convert(v_output); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("mean.dim", TORCH_FN(mean)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Mm.cpp b/aten/src/ATen/native/vulkan/ops/Mm.cpp new file mode 100644 index 0000000000000..26cd4f86c5542 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Mm.cpp @@ -0,0 +1,385 @@ +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +using namespace api::utils; + +vTensor pack_weights( + api::Resource::Pool& pool, + const Tensor& weight_arg) { + if (weight_arg.is_vulkan()) { + return convert(weight_arg); + } + + api::Context* const context = api::context(); + api::Command::Buffer& command_buffer = context->command().pool.stream(); + + const Tensor weight = weight_arg.contiguous(); + const IntArrayRef w_sizes = weight.sizes(); + const float* const src_weight_ptr = weight.data_ptr(); + + vTensor v_weight{ + context, + &pool, + w_sizes, + weight.options(), + }; + + { + using Future = vTensor::Future; + Future v_weight_future = v_weight.host(command_buffer); + Future::Payload v_weight_payload = v_weight_future.wait(); + + memcpy( + v_weight_payload.get(), + src_weight_ptr, + std::min(weight.nbytes(), v_weight.nbytes())); + } + + return v_weight; +} + +vTensor pack_biases( + api::Resource::Pool& pool, + const Tensor& weight_arg, + const c10::optional& bias_arg) { + if (bias_arg && bias_arg->is_vulkan()) { + return convert(*bias_arg); + } + + api::Context* const context = api::context(); + api::Command::Buffer& command_buffer = context->command().pool.stream(); + + vTensor v_bias{ + context, + &pool, + { + weight_arg.size(Layout::Parameter::width), + }, + weight_arg.options(), + }; + + { + using Future = vTensor::Future; + Future v_bias_future = v_bias.host(command_buffer); + Future::Payload v_bias_payload = v_bias_future.wait(); + + if (bias_arg) { + memcpy( + v_bias_payload.get(), + bias_arg->contiguous().data_ptr(), + std::min(bias_arg->nbytes(), v_bias.nbytes())); + } + else { + memset( + v_bias_payload.get(), + // 2's complement integers and IEEE-754 floating point numbers both + // have identical bit representations for 0, so can use memset which + // only accepts uint8_t parameter. + 0, + v_bias.nbytes()); + } + } + + return v_bias; +} + +bool available( + const Tensor& weight, + const c10::optional& bias) { + return api::available() && + // Weight + (2 == weight.ndimension()) && + (weight.size(Layout::Parameter::height) > 0) && + (weight.size(Layout::Parameter::width) > 0) && + ((c10::DeviceType::CPU == weight.device().type()) || + (c10::DeviceType::Vulkan == weight.device().type())) && + (kFloat == weight.scalar_type()) && + !weight.requires_grad() && + // Bias + ((bias && bias->defined()) ? ((bias->ndimension() > 0) && + ((c10::DeviceType::CPU == bias->device().type()) || + (c10::DeviceType::Vulkan == bias->device().type())) && + (kFloat == bias->scalar_type()) && + ((bias->ndimension() > 1) ? + (bias->size(Layout::Parameter::width) == + weight.size(Layout::Parameter::width)) + : true) && + !bias->requires_grad()) + : true) && + true; +} + +bool usable( + const Tensor& input, + const Tensor& weight, + const c10::optional& /* bias */) { + return (2 == input.ndimension()) && + (c10::DeviceType::Vulkan == input.device().type()) && + (kFloat == input.scalar_type()) && + (input.size(Layout::Parameter::width) == + weight.size(Layout::Parameter::height)) && + !input.requires_grad() && + true; +} + +Tensor addmm( + const Tensor& bias, + const Tensor& input, + const Tensor& weight, + const Scalar beta, + const Scalar alpha) { + return LinearOpContext::create( + api::context()->resource().pool, + weight, + bias).run( + input, + alpha.to(), + beta.to()); +} + +Tensor mm( + const Tensor& mat1_arg, + const Tensor& mat2_arg) { + api::Context* const context = api::context(); + + const Tensor mat1 = mat1_arg.is_vulkan() ? mat1_arg : mat1_arg.vulkan(); + const vTensor& v_mat1 = convert(mat1); + + const Tensor mat2 = mat2_arg.is_vulkan() ? mat2_arg : mat2_arg.vulkan(); + const vTensor& v_mat2 = convert(mat2); + + const auto v_mat1_sizes = v_mat1.sizes(); + const auto v_mat2_sizes = v_mat2.sizes(); + + TORCH_CHECK( + v_mat1_sizes[Layout::Parameter::width] == + v_mat2_sizes[Layout::Parameter::height], + "Incompatible matrix dimensions!"); + + vTensor v_output{ + context, + { + v_mat1_sizes[Layout::Parameter::height], + v_mat2_sizes[Layout::Parameter::width], + }, + mat1.options(), + }; + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + if C10_LIKELY(v_mat1.has_image() && v_mat2.has_image()) { + const struct Block final { + uvec3 extents; + int32_t K; + } block { + v_output.extents(), + safe_downcast(v_mat1_sizes[Layout::Parameter::width]), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(mm), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_mat1.image( + command_buffer, + vTensor::Stage::Compute), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_mat2.image( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_pool.submit(context->gpu().queue, command_buffer); + + return convert(v_output); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("addmm", TORCH_FN(addmm)); + m.impl("mm", TORCH_FN(mm)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace + +LinearOpContext::LinearOpContext( + api::Resource::Pool& pool, + const Tensor& weight, + const c10::optional& bias) + : packed_{ + pack_weights(pool, weight), + pack_biases(pool, weight, bias), + }, + unpacked_{ + weight, + bias, + } { +} + +LinearOpContext LinearOpContext::create( + api::Resource::Pool& pool, + const Tensor& weight, + const c10::optional& bias) { + TORCH_CHECK( + available(weight, bias), + "Vulkan Linear not available! " + "Reason: The provided (weight, bias) parameters are either invalid " + "individually or their combination is not supported by Vulkan Impl."); + + // Pass in the originals + return LinearOpContext{ + pool, + weight, + bias, + }; +} + +Tensor LinearOpContext::run( + const Tensor& input_arg, + const float alpha, + const float beta) const { + api::Context* const context = api::context(); + + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_input = convert(input); + + TORCH_CHECK( + usable(input, unpacked_.weight, unpacked_.bias), + "Vulkan Linear not usable! " + "Reason: The provided input tensor is either invalid on its own, or its " + "combination with the provided weight and bias tensors are unsupported by " + "Vulkan impl."); + + vTensor v_output{ + context, + { + v_input.sizes()[Layout::Parameter::height], + packed_.v_weight.sizes()[Layout::Parameter::width], + }, + input.options(), + }; + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + if C10_LIKELY( + v_output.has_image() && + v_input.has_image() && + packed_.v_weight.has_image() && + packed_.v_bias.has_image()) { + const struct Block final { + uvec3 extents; + int32_t K; + vec2 multiplier; + } block { + v_output.extents(), + safe_downcast(v_input.sizes()[Layout::Parameter::width]), + { + alpha, + beta, + }, + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(addmm), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_input.image( + command_buffer, + vTensor::Stage::Compute), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + packed_.v_weight.image( + command_buffer, + vTensor::Stage::Compute), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + packed_.v_bias.image( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_pool.submit(context->gpu().queue, command_buffer); + + return convert(v_output); +} + +LinearOpContext::State LinearOpContext::unpack() const { + return LinearOpContext::State{ + unpacked_.weight, + unpacked_.bias, + }; +} + +c10::intrusive_ptr linear_prepack( + Tensor&& weight, + c10::optional&& bias) { + return c10::make_intrusive( + LinearOpContext::create( + persistent()->pool, + std::move(weight), + std::move(bias))); +} + +Tensor linear_run( + const Tensor& input, + const c10::intrusive_ptr& context) { + return context->run(input, 1.0, 1.0); +} + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Mm.h b/aten/src/ATen/native/vulkan/ops/Mm.h new file mode 100644 index 0000000000000..2c389c555a1a1 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Mm.h @@ -0,0 +1,56 @@ +#pragma once + +#ifdef USE_VULKAN_API + +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +class LinearOpContext final : public torch::jit::CustomClassHolder { + public: + static LinearOpContext create( + api::Resource::Pool& pool, + const Tensor& weight, + const c10::optional& bias); + + using State = std::tuple>; + + Tensor run(const Tensor& input, float beta, float alpha) const; + State unpack() const; + + private: + LinearOpContext( + api::Resource::Pool& pool, + const Tensor& weight, + const c10::optional& bias); + + private: + struct { + vTensor v_weight; + vTensor v_bias; + } packed_; + + struct { + Tensor weight; + c10::optional bias; + } unpacked_; +}; + +c10::intrusive_ptr linear_prepack( + Tensor&& weight, + c10::optional&& bias); + +Tensor linear_run( + const Tensor& input, + const c10::intrusive_ptr& context); + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Mul.cpp b/aten/src/ATen/native/vulkan/ops/Mul.cpp new file mode 100644 index 0000000000000..1e494287a5ae3 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Mul.cpp @@ -0,0 +1,134 @@ +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +using namespace api::utils; + +Tensor mul_scalar( + const Tensor& self_arg, + const Scalar other) { + api::Context* const context = api::context(); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_self = convert(self); + + vTensor v_output{ + context, + v_self.sizes(), + v_self.options(), + }; + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + if C10_LIKELY(v_output.has_image() && v_self.has_image()) { + const struct Block final { + uvec3 extents; + float other; + } block { + v_output.extents(), + other.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(mul_scalar), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_self.image( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_pool.submit(context->gpu().queue, command_buffer); + + return convert(v_output); +} + +Tensor& mul_scalar_( + Tensor& self, + const Scalar other) { + api::Context* const context = api::context(); + + TORCH_CHECK( + self.is_vulkan(), + "Vulkan: In-place mul_scalar is only supported on Vulkan tensors."); + + vTensor& v_self = convert(self); + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + if C10_LIKELY(v_self.has_image()) { + const struct Block final { + uvec3 extents; + float other; + } block { + v_self.extents(), + other.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(mul_scalar_), + v_self.extents(), + // Read-Write access triggers an async synchronization if necessory + // and inserts appropriate barriers if hazards are detected. + v_self.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Read | vTensor::Access::Write), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_pool.submit(context->gpu().queue, command_buffer); + + return self; +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("mul.Scalar", TORCH_FN(mul_scalar)); + m.impl("mul_.Scalar", TORCH_FN(mul_scalar_)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Persistent.cpp b/aten/src/ATen/native/vulkan/ops/Persistent.cpp new file mode 100644 index 0000000000000..bea5e97e50211 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Persistent.cpp @@ -0,0 +1,33 @@ +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +Persistent* persistent() { + static const std::unique_ptr persistent( + []() -> Persistent* { + try { + return new Persistent{ + api::Resource::Pool{ + api::context()->gpu(), + }, + }; + } + catch (...) { + return nullptr; + } + }()); + + TORCH_CHECK( + persistent, + "Vulkan: Failed to initialize the persistent resource pool!"); + + return persistent.get(); +} + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Persistent.h b/aten/src/ATen/native/vulkan/ops/Persistent.h new file mode 100644 index 0000000000000..a823dcb67d47a --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Persistent.h @@ -0,0 +1,35 @@ +#pragma once + +#ifdef USE_VULKAN_API + +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +// +// This class is meant for allocation of resources that will persist through the +// execution of the program, or until they are explicitly free'd by this code's +// clients, and its usage pattern is in direct contrast with the primary resource +// pool from which tensors draw from. Whereas the primary resource pool is +// purged in its entirety at the end of each inference run, the intended usage +// pattern for this class is such that it delegates object lifetime management +// to the users so resources can stick around for as long as required. This +// is ideal for prepacked weights, or scnearios where a precomputed or +// once-transformed data can be stored and reused in subsequent runs. +// + +struct Persistent final { + api::Resource::Pool pool; +}; + +Persistent* persistent(); + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Pool.cpp b/aten/src/ATen/native/vulkan/ops/Pool.cpp new file mode 100644 index 0000000000000..1e6450c9ba97e --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Pool.cpp @@ -0,0 +1,251 @@ +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +using namespace api::utils; + +Tensor adaptive_avg_pool2d( + const at::Tensor& self_arg, + const IntArrayRef output_size) { + TORCH_CHECK( + self_arg.dim() == 4, + "Vulkan adaptive_avg_pool2d expects 4-dimensional input!"); + + api::Context* const context = api::context(); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_self = convert(self); + + vTensor v_output{ + context, + { + self.size(Layout::Activation4D::batch), + self.size(Layout::Activation4D::channels), + output_size[Layout::Activation4D::batch], + output_size[Layout::Activation4D::channels], + }, + v_self.options(), + }; + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + if C10_LIKELY(v_self.has_image()) { + const uvec3 v_output_size = v_output.extents(); + const uvec3 v_self_size = v_self.extents(); + + const vec2 stride { + static_cast(v_self_size.data[0u]) / v_output_size.data[0u], + static_cast(v_self_size.data[1u]) / v_output_size.data[1u], + }; + + const struct Block final { + uvec3 extents; + uint32_t _; + vec2 kernel; + vec2 stride; + } block { + v_output.extents(), + 0u, + { + v_self_size.data[0u] - (v_output_size.data[0u] - 1u) * stride.data[0u], + v_self_size.data[1u] - (v_output_size.data[1u] - 1u) * stride.data[1u], + }, + stride, + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(adaptive_avg_pool2d), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_self.image( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_pool.submit(context->gpu().queue, command_buffer); + + return convert(v_output); +} + +Tensor avg_pool2d( + const Tensor& self_arg, + const IntArrayRef kernel_arg, + IntArrayRef stride_arg, + const IntArrayRef padding_arg, + const bool ceil_mode, + const bool /* count_include_pad */, + const c10::optional /* divisor_override */) { + if (stride_arg.empty()) { + stride_arg = kernel_arg; + } + + TORCH_CHECK(!kernel_arg.empty(), "Kernel size cannot be empty!"); + TORCH_CHECK(!stride_arg.empty(), "Stride cannot be empty!"); + TORCH_CHECK(!padding_arg.empty(), "Padding cannot be empty!"); + + static const auto normalize = [](const IntArrayRef parameter) { + return std::array{ + parameter[0], + (2 == parameter.size()) ? parameter[1] : parameter[0], + }; + }; + + const auto input_size = self_arg.sizes(); + const auto kernel = normalize(kernel_arg); + const auto stride = normalize(stride_arg); + const auto padding = normalize(padding_arg); + const auto dilation = std::array{1, 1}; + + const int64_t output_height = pooling_output_shape( + input_size[Layout::Activation4D::height], + kernel[Layout::Parameter::height], + padding[Layout::Parameter::height], + stride[Layout::Parameter::height], + dilation[Layout::Parameter::height], + ceil_mode); + + const int64_t output_width = pooling_output_shape( + input_size[Layout::Activation4D::width], + kernel[Layout::Parameter::width], + padding[Layout::Parameter::width], + stride[Layout::Parameter::width], + dilation[Layout::Parameter::width], + ceil_mode); + + pool2d_shape_check( + self_arg, + kernel[Layout::Parameter::height], + kernel[Layout::Parameter::width], + stride[Layout::Parameter::height], + stride[Layout::Parameter::width], + padding[Layout::Parameter::height], + padding[Layout::Parameter::width], + dilation[Layout::Parameter::height], + dilation[Layout::Parameter::width], + input_size[Layout::Activation4D::channels], + input_size[Layout::Activation4D::height], + input_size[Layout::Activation4D::width], + output_height, + output_width, + self_arg.suggest_memory_format()); + + api::Context* const context = api::context(); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_self = convert(self); + + vTensor v_output{ + context, + { + input_size[Layout::Activation4D::batch], + input_size[Layout::Activation4D::channels], + output_height, + output_width, + }, + v_self.options(), + }; + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + if C10_LIKELY(v_self.has_image()) { + const struct Block final { + uvec3 extents; + int32_t range; + ivec4 kernel; + ivec2 stride; + ivec2 padding; + } block { + v_output.extents(), + safe_downcast( + kernel[Layout::Parameter::width] * + kernel[Layout::Parameter::height]), + { + safe_downcast(kernel[Layout::Parameter::width]), + safe_downcast(kernel[Layout::Parameter::height]), + safe_downcast(self.size(Layout::Activation4D::width)), + safe_downcast(self.size(Layout::Activation4D::height)), + }, + { + safe_downcast(stride[Layout::Parameter::width]), + safe_downcast(stride[Layout::Parameter::height]), + }, + { + safe_downcast(padding[Layout::Parameter::width]), + safe_downcast(padding[Layout::Parameter::height]), + }, + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(avg_pool2d), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_self.image( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_pool.submit(context->gpu().queue, command_buffer); + + return convert(v_output); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("_adaptive_avg_pool2d", TORCH_FN(adaptive_avg_pool2d)); + m.impl("avg_pool2d", TORCH_FN(avg_pool2d)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Register.cpp b/aten/src/ATen/native/vulkan/ops/Register.cpp new file mode 100644 index 0000000000000..7b226654af01d --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Register.cpp @@ -0,0 +1,80 @@ +#ifdef USE_VULKAN_API + +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +TORCH_LIBRARY(vulkan, m) { + m.class_("Conv2dOpContext") + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr& context) { + return context->unpack(); + }, + // __setstate__ + [](Conv2dOpContext::State state) { + return conv2d_clamp_prepack( + std::move(std::get<0>(state)), + std::move(std::get<1>(state)), + std::move(std::get<2>(state)), + std::move(std::get<3>(state)), + std::move(std::get<4>(state)), + std::move(std::get<5>(state)), + std::move(std::get<6>(state)), + std::move(std::get<7>(state))); + }); + m.class_("LinearOpContext") + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr& context) { + return context->unpack(); + }, + // __setstate__ + [](LinearOpContext::State state) { + return linear_prepack( + std::move(std::get<0>(state)), std::move(std::get<1>(state))); + }); +} + +TORCH_LIBRARY(vulkan_prepack, m) { + m.def( + "conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, " + "int[2] padding, int[2] dilation, int groups, " + "Scalar? output_min=None, Scalar? output_max=None) " + "-> __torch__.torch.classes.vulkan.Conv2dOpContext"); + m.def( + "conv2d_clamp_run(Tensor X, " + "__torch__.torch.classes.vulkan.Conv2dOpContext W_prepack) -> Tensor Y"); + m.def( + "linear_prepack(Tensor W, Tensor? B) " + "-> __torch__.torch.classes.vulkan.LinearOpContext"); + m.def( + "linear_run(Tensor X, " + "__torch__.torch.classes.vulkan.LinearOpContext BW_prepack) -> Tensor Y"); +} + +TORCH_LIBRARY_IMPL(vulkan_prepack, CPU, m) { + m.impl("conv2d_clamp_prepack", TORCH_FN(conv2d_clamp_prepack)); + m.impl("linear_prepack", TORCH_FN(linear_prepack)); +} + +TORCH_LIBRARY_IMPL(vulkan_prepack, Vulkan, m) { + m.impl("conv2d_clamp_run", TORCH_FN(conv2d_clamp_run)); + m.impl("linear_run", TORCH_FN(linear_run)); +} + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Shape.cpp b/aten/src/ATen/native/vulkan/ops/Shape.cpp new file mode 100644 index 0000000000000..9d2a248f0707d --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Shape.cpp @@ -0,0 +1,57 @@ +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +Tensor view( + const Tensor& self_arg, + const IntArrayRef shape) { + api::Context* const context = api::context(); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_self = convert(self); + + vTensor v_output{ + context, + shape, + self.options(), + }; + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + command_buffer.copy( + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_self.buffer( + command_buffer, + vTensor::Stage::Transfer), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.buffer( + command_buffer, + vTensor::Stage::Transfer, + vTensor::Access::Write)); + } + command_pool.submit(context->gpu().queue, command_buffer); + + return convert(v_output); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("view", TORCH_FN(view)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Tensor.cpp b/aten/src/ATen/native/vulkan/ops/Tensor.cpp new file mode 100644 index 0000000000000..0bf7acbe7deea --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Tensor.cpp @@ -0,0 +1,1313 @@ +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +using namespace api::utils; + +VkFormat vk_format(const caffe2::TypeMeta dtype) { + switch (c10::typeMetaToScalarType(dtype)) { + case kFloat: + #ifdef USE_VULKAN_FP16_INFERENCE + return VK_FORMAT_R16G16B16A16_SFLOAT; + #else + return VK_FORMAT_R32G32B32A32_SFLOAT; + #endif /* USE_VULKAN_FP16_INFERENCE */ + + default: + TORCH_CHECK( + false, + "Vulkan tensor format not supported!"); + } + + return VK_FORMAT_UNDEFINED; +} + +VkExtent3D vk_extent(const uvec3& extent) { + return { + extent.data[0u], + extent.data[1u], + extent.data[2u], + }; +} + +vTensor::Access::Flags access( + const VkAccessFlags vk_access) { + vTensor::Access::Flags access = 0u; + + constexpr VkAccessFlags kRead = + VK_ACCESS_HOST_READ_BIT | + VK_ACCESS_MEMORY_READ_BIT | + VK_ACCESS_SHADER_READ_BIT | + VK_ACCESS_TRANSFER_READ_BIT | + VK_ACCESS_UNIFORM_READ_BIT; + + constexpr VkAccessFlags kWrite = + VK_ACCESS_HOST_WRITE_BIT | + VK_ACCESS_MEMORY_WRITE_BIT | + VK_ACCESS_SHADER_WRITE_BIT | + VK_ACCESS_TRANSFER_WRITE_BIT; + + if (vk_access & kRead) { + access |= vTensor::Access::Read; + } + + if (vk_access & kWrite) { + access |= vTensor::Access::Write; + } + + return access; +} + +VkAccessFlags vk_access( + const vTensor::Stage::Flags stage, + const vTensor::Access::Flags access) { + VkAccessFlags vk_access = 0u; + + if (access & vTensor::Access::Read) { + if (stage & vTensor::Stage::Compute) { + vk_access |= VK_ACCESS_SHADER_READ_BIT; + } + + if (stage & vTensor::Stage::Host) { + vk_access |= VK_ACCESS_HOST_READ_BIT; + } + + if (stage & vTensor::Stage::Transfer) { + vk_access |= VK_ACCESS_TRANSFER_READ_BIT; + } + } + + if (access & vTensor::Access::Write) { + if (stage & vTensor::Stage::Compute) { + vk_access |= VK_ACCESS_SHADER_WRITE_BIT; + } + + if (stage & vTensor::Stage::Host) { + vk_access |= VK_ACCESS_HOST_WRITE_BIT; + } + + if (stage & vTensor::Stage::Transfer) { + vk_access |= VK_ACCESS_TRANSFER_WRITE_BIT; + } + } + + return vk_access; +} + +VkImageLayout vk_layout( + const vTensor::Stage::Flags stage, + const vTensor::Access::Flags access) { + switch (stage) { + case vTensor::Stage::Compute: + switch (access) { + case vTensor::Access::Read: + return VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL; + + default: + return VK_IMAGE_LAYOUT_GENERAL; + } break; + + case vTensor::Stage::Transfer: + switch (access) { + case vTensor::Access::Read: + return VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL; + + case vTensor::Access::Write: + return VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL; + + default: + TORCH_INTERNAL_ASSERT(false, "Invalid!"); + } break; + + default: + TORCH_INTERNAL_ASSERT(false, "Invalid!"); + } + + return VK_IMAGE_LAYOUT_UNDEFINED; +} + +VkPipelineStageFlags vk_stage( + const vTensor::Stage::Flags stage) { + VkPipelineStageFlags vk_stage = 0u; + + if (stage & vTensor::Stage::Compute) { + vk_stage |= VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT; + } + + if (stage & vTensor::Stage::Host) { + vk_stage |= VK_PIPELINE_STAGE_HOST_BIT; + } + + if (stage & vTensor::Stage::Transfer) { + vk_stage |= VK_PIPELINE_STAGE_TRANSFER_BIT; + } + + return vk_stage; +} + +VkDeviceSize buffer_bytes( + const IntArrayRef sizes, + const caffe2::TypeMeta dtype) { + VkDeviceSize size = c10::elementSize(c10::typeMetaToScalarType(dtype)); + + // Forward declaration + bool requires_image(IntArrayRef); + + if (requires_image(sizes)) { + // Forward declaration + uvec3 image_extents(IntArrayRef); + + const uvec3 extents = image_extents(sizes); + size *= extents.data[0u] * extents.data[1u] * (4u * extents.data[2u]); + } + else { + size *= prod_intlist(sizes); + } + + return size; +} + +vTensor::Buffer allocate_buffer( + const api::Adapter* const adapter, + api::Resource::Pool* const pool, + const IntArrayRef sizes, + const TensorOptions& options) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + adapter, + "Invalid Vulkan adapter!"); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + pool, + "Invalid Vulkan resource pool!"); + + TORCH_CHECK(!sizes.empty(), "Invalid Vulkan tensor size!"); + verify(options); + + // Forward declaration + bool requires_staging(const api::Adapter*); + + const VkFlags usage = + VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | + VK_BUFFER_USAGE_TRANSFER_SRC_BIT | + VK_BUFFER_USAGE_TRANSFER_DST_BIT; + + const auto memory = [adapter]() -> api::Resource::Memory::Descriptor { + if (requires_staging(adapter)) { + return { + VMA_MEMORY_USAGE_GPU_ONLY, + 0u, + 0u, + }; + } + + return { + VMA_MEMORY_USAGE_GPU_TO_CPU, + 0u, + VK_MEMORY_PROPERTY_HOST_COHERENT_BIT, + }; + }(); + + return pool->buffer({ + buffer_bytes(sizes, options.dtype()), + // Usage + { + usage, + memory, + }, + }); +} + +bool requires_image(const IntArrayRef sizes) { + return (1u <= sizes.size()) && (sizes.size() <= 4u); +} + +uvec3 image_extents(const IntArrayRef sizes) { + int64_t width = 1; + int64_t height = 1; + int64_t depth = 1; + + switch (sizes.size()) { + case 1: + width = sizes[0]; + break; + + case 2: + width = sizes[1]; + height = sizes[0]; + break; + + case 3: + width = sizes[2]; + height = sizes[1]; + depth = sizes[0]; + break; + + case 4: + width = sizes[3]; + height = sizes[2]; + depth = sizes[0] * sizes[1]; + break; + + default: + TORCH_INTERNAL_ASSERT( + false, + "Only Tensors with 1 <= dim <= 4 can be represented as a Vulkan Image!"); + } + + return { + width, + height, + div_up(depth, INT64_C(4)), + }; +} + +vTensor::Image allocate_image( + api::Resource::Pool* const pool, + const VkExtent3D& extents, + const TensorOptions& options) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + pool, + "Invalid Vulkan resource pool!"); + + verify(options); + + return pool->image({ + VK_IMAGE_TYPE_3D, + vk_format(options.dtype()), + extents, + // Usage + { + VK_IMAGE_USAGE_SAMPLED_BIT | + VK_IMAGE_USAGE_STORAGE_BIT, + { + VMA_MEMORY_USAGE_GPU_ONLY, + 0u, + 0u, + }, + }, + // View + { + VK_IMAGE_VIEW_TYPE_3D, + vk_format(options.dtype()), + }, + // Sampler + { + VK_FILTER_NEAREST, + VK_SAMPLER_MIPMAP_MODE_NEAREST, + VK_SAMPLER_ADDRESS_MODE_REPEAT, + VK_BORDER_COLOR_FLOAT_TRANSPARENT_BLACK, + }, + }); +} + +bool requires_staging(const api::Adapter* const adapter) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + adapter, + "Invalid Vulkan adapter!"); + + return !adapter->has_unified_memory(); +} + +vTensor::Buffer allocate_staging( + const api::Adapter* const adapter, + api::Resource::Pool* const pool, + const IntArrayRef sizes, + const TensorOptions& options) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + adapter, + "Invalid Vulkan adapter!"); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + pool, + "Invalid Vulkan resource pool!"); + + TORCH_CHECK(!sizes.empty(), "Invalid Vulkan tensor size!"); + verify(options); + + return pool->buffer({ + buffer_bytes(sizes, options.dtype()), + // Usage + { + VK_BUFFER_USAGE_TRANSFER_SRC_BIT | + VK_BUFFER_USAGE_TRANSFER_DST_BIT, + { + VMA_MEMORY_USAGE_CPU_COPY, + 0u, + 0u, + }, + }, + }); +} + +vTensor::Fence allocate_fence(api::Resource::Pool* const pool) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + pool, + "Invalid Vulkan resource pool!"); + + return pool->fence(); +} + +enum class Barrier { + None, + Exectution, + Memory, +}; + +Barrier categorize( + const VkAccessFlags vk_src_access, + const VkAccessFlags vk_dst_access) { + if (0u == vk_src_access) { + return Barrier::None; + } + + const vTensor::Access::Flags src_access = access(vk_src_access); + const vTensor::Access::Flags dst_access = access(vk_dst_access); + + if ((src_access & vTensor::Access::Read) == src_access) { + if ((dst_access & vTensor::Access::Read) == dst_access) { + // RAR (Read after Read) + return Barrier::None; + } + + // WAR (Write after Read) + return Barrier::Exectution; + } + + // RAW (Read after Write), or WAW (Write after Write) + return Barrier::Memory; +}; + +Barrier categorize( + const VkAccessFlags vk_src_access, + const VkAccessFlags vk_dst_access, + const VkImageLayout vk_src_layout, + const VkImageLayout vk_dst_layout) { + if (vk_src_layout != vk_dst_layout) { + return Barrier::Memory; + } + + return categorize(vk_src_access, vk_dst_access); +} + +} // namespace + +vTensor::vTensor( + api::Context* const context, + const IntArrayRef sizes, + const TensorOptions& options) + : vTensor( + context, + &context->resource().pool, + sizes, + options) { +} + +vTensor::vTensor( + api::Context* const context, + api::Resource::Pool* const pool, + const IntArrayRef sizes, + const TensorOptions& options) + : view_(new View{ + context, + pool, + sizes, + options, + }) { +} + +const vTensor* vTensor::host( + api::Command::Buffer& command_buffer) const { + view_->staging(command_buffer, Stage::Host, Access::Read); + return this; +} + +vTensor* vTensor::host( + api::Command::Buffer& command_buffer, + const Access::Flags access) { + view_->staging(command_buffer, Stage::Host, access); + return this; +} + +vTensor::Buffer::Object vTensor::buffer( + api::Command::Buffer& command_buffer, + const Stage::Flags stage) const & { + return view_->buffer( + command_buffer, + stage, + Access::Read).object; +} + +vTensor::Buffer::Object vTensor::buffer( + api::Command::Buffer& command_buffer, + const Stage::Flags stage, + const Access::Flags access) & { + return view_->buffer( + command_buffer, + stage, + access).object; +} + +vTensor::Image::Object vTensor::image( + api::Command::Buffer& command_buffer, + const Stage::Flags stage) const & { + return view_->image( + command_buffer, + stage, + Access::Read).object; +} + +vTensor::Image::Object vTensor::image( + api::Command::Buffer& command_buffer, + const Stage::Flags stage, + const Access::Flags access) & { + return view_->image( + command_buffer, + stage, + access).object; +} + +vTensor::View::View() + // Resources + : buffer_{}, + image_{}, + staging_{}, + fence_{}, + // Context + context_(nullptr), + pool_(nullptr), + // State + state_{}, + // Metadata + extents_{} { +} + +vTensor::View::View( + api::Context* const context, + api::Resource::Pool* const pool, + const IntArrayRef sizes, + const TensorOptions& options) + // Resources + : buffer_{}, + image_{}, + staging_{}, + fence_{}, + // Context + context_(context), + pool_(pool), + // State + state_(context->gpu().adapter, sizes), + // Metadata + extents_(image_extents(sizes)), + options_(options), + sizes_(sizes), + strides_(sizes.size()) { + ops::verify(options); +} + +class vTensor::View::CMD final { + public: + CMD(const View&, api::Command::Buffer&); + CMD(const CMD&) = delete; + CMD& operator=(const CMD&) = delete; + CMD(CMD&&) = delete; + CMD& operator=(CMD&&) = delete; + ~CMD() = default; + + typedef api::Resource::Buffer Buffer; + typedef api::Resource::Image Image; + typedef api::Resource::Fence Fence; + + void barrier(State::Transition transition); + + void copy_buffer_to_staging( + State& state, + const Buffer::Object& buffer, + Buffer::Object& staging); + + void copy_staging_to_buffer( + State& state, + const Buffer::Object& staging, + Buffer::Object& buffer); + + void copy_buffer_to_image( + State& state, + const Buffer::Object& buffer, + Image::Object& image); + + void copy_image_to_buffer( + State& state, + const Image::Object& image, + Buffer::Object& buffer); + + void submit(Fence fence); + + private: + const View& view_; + api::Command::Buffer& command_buffer_; +}; + +vTensor::View::CMD::CMD( + const View& view, + api::Command::Buffer& command_buffer) + : view_(view), + command_buffer_(command_buffer) { +} + +void vTensor::View::CMD::barrier(State::Transition transition) { + // Buffer and Staging are just an alias for the same memory region on UMA. + + if (view_.state_.is_uma()) { + transition.first.buffer.stage |= transition.first.staging.stage; + transition.first.buffer.access |= transition.first.staging.access; + transition.first.staging = {}; + + transition.second.buffer.stage |= transition.second.staging.stage; + transition.second.buffer.access |= transition.second.staging.access; + transition.second.staging = {}; + } + + // Filter out host dependencies out of source, per Vulkan spec host write ordering guarantees: + // https://www.khronos.org/registry/vulkan/specs/1.2/html/vkspec.html#synchronization-submission-host-writes + + const auto filter_stage =[](VkPipelineStageFlags& stage) { + stage &= ~VK_PIPELINE_STAGE_HOST_BIT; + }; + + filter_stage(transition.first.buffer.stage); + filter_stage(transition.first.staging.stage); + + const auto filter_access =[](VkAccessFlags& access) { + access &= ~(VK_ACCESS_HOST_READ_BIT | VK_ACCESS_HOST_WRITE_BIT); + }; + + filter_access(transition.first.buffer.access); + filter_access(transition.first.staging.access); + + api::Pipeline::Barrier barrier{}; + + if (transition.second.staging) { + const State::Bundle::Buffer from = transition.first.staging; + const State::Bundle::Buffer to = transition.second.staging; + + const Barrier category = categorize( + from.access, + to.access); + + if (Barrier::None != category) { + barrier.stage.src |= from.stage; + barrier.stage.dst |= to.stage; + + if (Barrier::Memory == category) { + barrier.buffers.push_back({ + view_.staging().object, + { + from.access, + to.access, + }, + }); + } + } + } + + if (transition.second.buffer) { + const State::Bundle::Buffer from = transition.first.buffer; + const State::Bundle::Buffer to = transition.second.buffer; + + const Barrier category = categorize( + from.access, + to.access); + + if (Barrier::None != category) { + barrier.stage.src |= from.stage; + barrier.stage.dst |= to.stage; + + if (Barrier::Memory == category) { + barrier.buffers.push_back({ + view_.buffer().object, + { + from.access, + to.access, + }, + }); + } + } + } + + if (transition.second.image) { + const State::Bundle::Image from = transition.first.image; + const State::Bundle::Image to = transition.second.image; + + const Barrier category = categorize( + from.access, + to.access, + from.layout, + to.layout); + + if (Barrier::None != category) { + barrier.stage.src |= from.stage; + barrier.stage.dst |= to.stage; + + if (Barrier::Memory == category) { + TORCH_INTERNAL_ASSERT( + from.layout == view_.image().object.layout, + "Invalid image layout!"); + + barrier.images.push_back({ + view_.image().object, + { + from.access, + to.access, + }, + { + from.layout, + to.layout, + }, + }); + + view_.image().object.layout = to.layout; + } + } + } + + // If we are left with anything meaningful, insert a barrier. + + if (barrier) { + if (0u == barrier.stage.src) { + barrier.stage.src = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT; + } + + if (0u == barrier.stage.dst) { + barrier.stage.src = VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT; + } + + command_buffer_.barrier(barrier); + } +} + +void vTensor::View::CMD::copy_buffer_to_staging( + State& state, + const Buffer::Object& buffer, + Buffer::Object& staging) { + if (state.is_clean(Component::Staging) || state.is_uma()) { + return; + } + + barrier( + state.transition({ + // Staging + { + vk_stage(Stage::Transfer), + vk_access(Stage::Transfer, Access::Write), + }, + // Buffer + { + vk_stage(Stage::Transfer), + vk_access(Stage::Transfer, Access::Read), + }, + // Image + {}, + })); + + command_buffer_.copy(buffer, staging); +} + +void vTensor::View::CMD::copy_staging_to_buffer( + State& state, + const Buffer::Object& staging, + Buffer::Object& buffer) { + if (state.is_clean(Component::Buffer) || state.is_uma()) { + return; + } + + barrier( + state.transition({ + // Staging + { + vk_stage(Stage::Transfer), + vk_access(Stage::Transfer, Access::Read), + }, + // Buffer + { + vk_stage(Stage::Transfer), + vk_access(Stage::Transfer, Access::Write), + }, + // Image + {}, + })); + + command_buffer_.copy(staging, buffer); +} + +void vTensor::View::CMD::copy_buffer_to_image( + State& state, + const Buffer::Object& buffer, + Image::Object& image) { + if (state.is_clean(Component::Image)) { + return; + } + + barrier( + state.transition({ + // Staging + {}, + // Buffer + { + vk_stage(Stage::Compute), + vk_access(Stage::Compute, Access::Read), + }, + // Image + { + vk_stage(Stage::Compute), + vk_access(Stage::Compute, Access::Write), + vk_layout(Stage::Compute, Access::Write), + }, + })); + + const uvec3 extents = view_.extents(); + const uint32_t plane = extents.data[0u] * extents.data[1u]; + + const struct Block final { + uvec3 extents; + uint32_t block; + uvec4 offset; + } block { + extents, + 4u * plane, + { + 0u * plane, + 1u * plane, + 2u * plane, + 3u * plane, + }, + }; + + view_.context_->dispatch( + command_buffer_, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(nchw_to_image), + extents, + image, + buffer, + view_.context_->resource().pool.uniform(block).object); +} + +void vTensor::View::CMD::copy_image_to_buffer( + State& state, + const Image::Object& image, + Buffer::Object& buffer) { + if (state.is_clean(Component::Buffer)) { + return; + } + + barrier( + state.transition({ + // Staging + {}, + // Buffer + { + vk_stage(Stage::Compute), + vk_access(Stage::Compute, Access::Write), + }, + // Image + { + vk_stage(Stage::Compute), + vk_access(Stage::Compute, Access::Read), + vk_layout(Stage::Compute, Access::Read), + }, + })); + + const uvec3 extents = view_.extents(); + const uint32_t plane = extents.data[0u] * extents.data[1u]; + + const struct Block final { + uvec3 extents; + uint32_t block; + uvec4 offset; + } block { + extents, + 4u * plane, + { + 0u * plane, + 1u * plane, + 2u * plane, + 3u * plane, + }, + }; + + view_.context_->dispatch( + command_buffer_, + { + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(image_to_nchw), + view_.extents(), + image, + buffer, + view_.context_->resource().pool.uniform(block).object); +} + +void vTensor::View::CMD::submit(const api::Resource::Fence fence) { + view_.context_->command().pool.submit( + view_.context_->gpu().queue, + command_buffer_, + fence); +} + +vTensor::Buffer& vTensor::View::buffer() const { + if (!buffer_) { + buffer_ = allocate_buffer( + context_->gpu().adapter, + pool_, + sizes(), + options()); + } + + return buffer_; +} + +vTensor::Buffer& vTensor::View::buffer( + api::Command::Buffer& command_buffer, + const Stage::Flags stage, + const Access::Flags access) const { + CMD cmd(*this, command_buffer); + return buffer(cmd, stage, access); +} + +vTensor::Buffer& vTensor::View::buffer( + CMD& cmd, + const Stage::Flags stage, + const Access::Flags access) const { + if ((access & Access::Read) && state_.is_dirty(Component::Buffer)) { + if (state_.is_clean(Component::Staging)) { + cmd.copy_staging_to_buffer( + state_, + staging(cmd, Stage::Transfer, Access::Read).object, + buffer().object); + } + else if (state_.is_clean(Component::Image)) { + cmd.copy_image_to_buffer( + state_, + image(cmd, Stage::Compute, Access::Read).object, + buffer().object); + } + else { + TORCH_INTERNAL_ASSERT( + false, + "Invalid state!"); + } + } + + cmd.barrier( + state_.transition({ + // Staging + {}, + // Buffer + { + vk_stage(stage), + vk_access(stage, access), + }, + // Image + {}, + })); + + if (access & Access::Write) { + state_.set_dirty(Component::All); + } + + state_.set_clean(Component::Buffer); + + return buffer(); +} + +vTensor::Image& vTensor::View::image() const { + if (!image_ && state_.is_available(Component::Image)) { + image_ = allocate_image( + pool_, + vk_extent(extents()), + options()); + } + + return image_; +} + +vTensor::Image& vTensor::View::image( + api::Command::Buffer& command_buffer, + const Stage::Flags stage, + const Access::Flags access) const { + CMD cmd(*this, command_buffer); + return image(cmd, stage, access); +} + +vTensor::Image& vTensor::View::image( + CMD& cmd, + const Stage::Flags stage, + const Access::Flags access) const { + if ((access & Access::Read) && state_.is_dirty(Component::Image)) { + cmd.copy_buffer_to_image( + state_, + buffer(cmd, stage, Access::Read).object, + image().object); + } + + cmd.barrier( + state_.transition({ + // Staging + {}, + // Buffer + {}, + // Image + { + vk_stage(stage), + vk_access(stage, access), + vk_layout(stage, access), + }, + })); + + if (access & Access::Write) { + state_.set_dirty(Component::All); + } + + state_.set_clean(Component::Image); + + return image(); +} + +vTensor::Buffer& vTensor::View::staging() const { + if (!state_.is_available(Component::Staging)) { + return buffer(); + } + + if (!staging_) { + staging_ = allocate_staging( + context_->gpu().adapter, + pool_, + sizes(), + options()); + } + + return staging_; +} + +vTensor::Buffer& vTensor::View::staging( + api::Command::Buffer& command_buffer, + const Stage::Flags stage, + const Access::Flags access) const { + CMD cmd(*this, command_buffer); + Buffer& staging = this->staging(cmd, stage, access); + cmd.submit(fence(access)); + + return staging; +} + +vTensor::Buffer& vTensor::View::staging( + CMD& cmd, + const Stage::Flags stage, + const Access::Flags access) const { + if ((access & Access::Read) && state_.is_dirty(Component::Staging)) { + cmd.copy_buffer_to_staging( + state_, + buffer(cmd, Stage::Transfer, Access::Read).object, + staging().object); + } + + cmd.barrier( + state_.transition({ + // Staging + { + vk_stage(stage), + vk_access(stage, access), + }, + // Buffer + {}, + // Image + {}, + })); + + if (access & Access::Write) { + state_.set_dirty(Component::All); + } + + state_.set_clean(Component::Staging); + + return staging(); +} + +vTensor::Fence& vTensor::View::fence(const Access::Flags access) const { + if (access & Access::Read) { + fence_ = allocate_fence(&context_->resource().pool); + } + + return fence_; +} + +vTensor::Memory& vTensor::View::wait() const { + if (fence_) { + fence_.wait(); + } + + return staging().memory; +} + +void vTensor::View::verify() const { + TORCH_INTERNAL_ASSERT(!image_ || state_.is_available(Component::Image)); + TORCH_INTERNAL_ASSERT(!staging_ || state_.is_discrete()); +} + +vTensor::View::State::State() + : available_{}, + dirty_{}, + bundle_{} { +} + +vTensor::View::State::State( + const api::Adapter* const adapter, + const IntArrayRef sizes) + : available_{}, + dirty_{}, + bundle_{} { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + adapter, + "Invalid Vulkan adapter!"); + + available_ |= Component::Buffer; + + if (requires_image(sizes)) { + available_ |= Component::Image; + } + + if (requires_staging(adapter)) { + available_ |= Component::Staging; + } +} + +vTensor::View::State::Transition +vTensor::View::State::transition(const Bundle bundle) { + const Bundle from = bundle_; + Bundle& to = bundle_; + + if (bundle.staging) { + to.staging = bundle.staging; + } + + if (bundle.buffer) { + to.buffer = bundle.buffer; + } + + if (bundle.image) { + to.image = bundle.image; + } + +#ifdef DEBUG + // Forward declaration + std::ostream& operator<<( + std::ostream&, + const View::State::Bundle&); + + std::cout << "From:" << std::endl << from << std::endl; + std::cout << "To:" << std::endl << to << std::endl; +#endif /* DEBUG */ + + return Transition{ + from, + to, + }; +} + +void verify(const TensorOptions& options) { + TORCH_CHECK( + !options.has_requires_grad() || !options.requires_grad(), + "'requires_grad' tensor option is not yet supported under Vulkan!"); + + TORCH_CHECK( + !options.has_pinned_memory() || !options.pinned_memory(), + "'pinned_memory' tensor option is not yet supported under Vulkan!"); + + TORCH_CHECK( + !options.has_layout() || (c10::kStrided == options.layout()), + "'layout' tensor option is not yet supported under Vulkan!"); + + TORCH_CHECK( + !options.has_memory_format() || + (c10::MemoryFormat::Contiguous == options.memory_format_opt()), + "'memory_format' tensor option is not yet supported under Vulkan!"); +} + +// +// Debug +// + +namespace { + +// Considering that VkAccessFlags is a weak typedef of a built-in data type, we +// need to introduce a new type to allow overload resolution distinguish between +// the two. + +struct Access final { + VkAccessFlags value; +}; + +std::ostream& operator<<( + std::ostream& stream, + const Access& access) { + stream << "Access: "; + + if (0u == access.value) { + return stream << " 0"; + } + + if (access.value & VK_ACCESS_HOST_READ_BIT) { + stream << " VK_ACCESS_HOST_READ_BIT"; + } + + if (access.value & VK_ACCESS_HOST_WRITE_BIT) { + stream << " VK_ACCESS_HOST_WRITE_BIT"; + } + + if (access.value & VK_ACCESS_MEMORY_READ_BIT) { + stream << " VK_ACCESS_MEMORY_READ_BIT"; + } + + if (access.value & VK_ACCESS_MEMORY_WRITE_BIT) { + stream << " VK_ACCESS_MEMORY_WRITE_BIT"; + } + + if (access.value & VK_ACCESS_SHADER_READ_BIT) { + stream << " VK_ACCESS_SHADER_READ_BIT"; + } + + if (access.value & VK_ACCESS_SHADER_WRITE_BIT) { + stream << " VK_ACCESS_SHADER_WRITE_BIT"; + } + + if (access.value & VK_ACCESS_TRANSFER_READ_BIT) { + stream << " VK_ACCESS_TRANSFER_READ_BIT"; + } + + if (access.value & VK_ACCESS_TRANSFER_WRITE_BIT) { + stream << " VK_ACCESS_TRANSFER_WRITE_BIT"; + } + + return stream; +} + +// Considering that VkImageLayout is a weak typedef of a built-in data type, +// we need to introduce a new type to allow overload resolution distinguish +// between the two. + +struct Image final { + struct Layout final { + VkImageLayout value; + }; +}; + +std::ostream& operator<<( + std::ostream& stream, + const Image::Layout& layout) { + stream << "Layout: "; + + switch (layout.value) { + case VK_IMAGE_LAYOUT_UNDEFINED: + stream << " VK_IMAGE_LAYOUT_UNDEFINED"; + break; + + case VK_IMAGE_LAYOUT_GENERAL: + stream << " VK_IMAGE_LAYOUT_GENERAL"; + break; + + case VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL: + stream << " VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL"; + break; + + default: + stream << " Unknown!"; + break; + }; + + return stream; +} + +// Considering that VkPipelineStageFlags is a weak typedef of a built-in data +// type, we need to introduce a new type to allow overload resolution distinguish +// between the two. + +struct Stage final { + VkPipelineStageFlags value; +}; + +std::ostream& operator<<( + std::ostream& stream, + const Stage& stage) { + stream << "Stage: "; + + if (0u == stage.value) { + return stream << " 0"; + } + + if (stage.value & VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT) { + stream << " VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT"; + } + + if (stage.value & VK_PIPELINE_STAGE_HOST_BIT) { + stream << " VK_PIPELINE_STAGE_HOST_BIT"; + } + + if (stage.value & VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT) { + stream << " VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT"; + } + + if (stage.value & VK_PIPELINE_STAGE_TRANSFER_BIT) { + stream << " VK_PIPELINE_STAGE_TRANSFER_BIT"; + } + + return stream; +} + +} // namespace + +std::ostream& operator<<( + std::ostream& stream, + const vTensor::View::State::Bundle& bundle) { + stream << "Staging\n " << + Stage{ + bundle.staging.stage, + } << "\n " << + Access{ + bundle.staging.access, + } << std::endl; + + stream << "Buffer\n " << + Stage{ + bundle.buffer.stage, + } << "\n " << + Access{ + bundle.buffer.access, + } << std::endl; + + stream << "Image\n " << + Stage{ + bundle.image.stage, + } << "\n " << + Access{ + bundle.image.access, + } << "\n " << + Image::Layout{ + bundle.image.layout, + } << std::endl; + + return stream; +} + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Tensor.h b/aten/src/ATen/native/vulkan/ops/Tensor.h new file mode 100644 index 0000000000000..f404988b420b3 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Tensor.h @@ -0,0 +1,618 @@ +#pragma once + +#ifdef USE_VULKAN_API + +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +// +// This class represents a Vulkan tensor and provides an abstraction layer +// that allows both the CPU, and the GPU, to view a Vulkan (buffer, image) +// pair as one coherent, synchronized unit of storage on both UMA and discrete +// systems. Expanding on the previous sentence, this class tries to address +// two orthogonal implementation complexities that arise as a result of the +// aforementioned goal of memory coherence: +// +// 1) First, synchronization across processors; CPUs and GPUs are separate +// processors, and even though they share the same address space in a system +// with a unified memory architecture, their address spaces only partially +// overlap on systems with a discrete GPU. Consequently on discrete systems, +// while it is still technically possible to take advantage of this shared +// address space to maintain one single copy of the data, different access +// latencies from CPU and GPU to this shared location usually necessitates +// maintaining two copies each in processor-local memory, otherwise memory +// access latency will hurt from the processor to which this data is not +// close. This shared memory is more often than not located in system memory, +// making for slow GPU read and write access over the PCI-e bus on discrete. +// Maintaining two separate copies on the other hand, requires synchronization +// to guarantee coherence. This is not an issue on UMA and this implementation +// accounts for that optimization. +// +// 2) Second, synchronization across resources (i.e. buffers and images); GPU +// drivers pack images in proprietory formats for better locality of access +// and to enable lossless compression. These conversions are both expensive +// (in general) and manual (in Vulkan.) This requires a second order of +// synchronization to guarantee coherence between the contents of the buffer +// and image otherwise they will go out of sync. +// +// It is extremely important to keep in mind that the functionality this class +// provides is generally expensive. For optimal performance, the user of this +// class should: +// +// 1) Avoid frequent CPU <=> GPU transfers which will be triggered if data is +// write accessed on one processor and read / write accessed on the other. +// +// 2) Avoid frequent buffer <=> image conversions which will be trigerred if +// data is write accessed as a buffer (image) and read accessed as an +// image (buffer). +// +// 3) When and if a synchronization is unavoidable, place as much distance +// between the synchronization is triggered and the data is accessed since +// all synchronizations this class provides are async. +// +// For optimal performance, access the data as images, and keep the data on GPU, +// and above all understand the expensive data flow that this class abstracts +// away. +// +// vTensor tries to address a specific concern and intentionally does not expose +// GPU tensor memory directly. Please keep that behavior intact as the whole +// data model fundamentally depends on limiting what the user can achieve through +// the interface to guarantee performance and coherence. +// +// A vTensor is associated with an api::Context as preparation for multi-GPU +// support. +// + +class vTensor final { + public: + vTensor() = default; + vTensor( + api::Context* context, + IntArrayRef sizes, + const TensorOptions& options); + vTensor( + api::Context* context, + api::Resource::Pool* pool, + IntArrayRef sizes, + const TensorOptions& options); + + /* + Types + */ + + typedef api::Pipeline::Stage Stage; + typedef api::Resource::Memory::Access Access; + typedef api::Resource::Buffer Buffer; + typedef api::Resource::Fence Fence; + typedef api::Resource::Image Image; + typedef api::Resource::Memory Memory; + + /* + Future + */ + + template + class Future final { + template + using is_convertible = std::enable_if_t< + std::is_convertible< + Access::Pointer, + Access::Pointer>::value>; + + public: + explicit Future(const vTensor* tensor); + Future(const Future&) = delete; + Future& operator=(const Future&) = delete; + Future(Future&&); + Future& operator=(Future&&) &; + Future& operator=(Future&&) && = delete; + template> + Future(Future&&); + template> + Future& operator=(Future&&) &; + template + Future& operator=(Future&&) && = delete; + ~Future(); + + typedef Memory::Handle< + Access::Pointer< + Type, + kAccess>> Payload; + + // This is a blocking operation as the name suggests. A call to host() will + // trigger an async copy if pending writes are detected. Consequently, for + // optimal performance, put as much time and distance between the place + // where a vTensor::host() call occurs and the location where the returned + // future is explicitly waited on as a result of a call to this function. + + Payload wait() const &; + + private: + template + friend class Future; + + // Intentionally disabed to enforce a usage pattern wherein the Future's + // lifetime exceeds that of the Payload as we use the Future's destructor + // to eagerly (as opposed to lazily and upon first use) upload the + // modifications back onto the GPU in an effort to hide the upload latency. + + Payload wait() const && = delete; + + private: + const vTensor* tensor_; + }; + + /* + Host access - these functions will be expensive if they trigger a GPU -> CPU + sync due to pending writes. A call to host() will trigger an async copy in + such scenarios, which is then explicitly waited on as part of Future::wait(). + Consequently, for optimal performance, put as much time and distance between + the place where this function is called, and the location where the future is + waited on. + */ + + template + Future host(api::Command::Buffer&) const &; + + template + Future host(api::Command::Buffer&) &; + + /* + Device access - these functions will be expensive if they trigger a buffer + <-> image or CPU -> GPU sync due to pending writes. These functions are + non-blocking on the host as the copy operation is carried out by the GPU + asynchronously. Regardless, they result in extra work that could have been + avoided or at least minimized if all data access had occured through one + single processor (GPU in this case) and on one type of resource (image for + best performance.) Consequently, for optimal performance, avoid mixed reads + and writes across processor boundaries, and do your best to minimize layout + transitions as a result of working with images only (as opposed to mixed + buffer - image usage.) + This implementation intentionally restricts user access to the buffer and + image objects only, as opposed to their underlying memory, for the sake of + predictability of usage and efficiency. + */ + + Buffer::Object buffer(api::Command::Buffer&, Stage::Flags) const &; + Buffer::Object buffer(api::Command::Buffer&, Stage::Flags, Access::Flags) &; + + bool has_image() const; + Image::Object image(api::Command::Buffer&, Stage::Flags) const &; + Image::Object image(api::Command::Buffer&, Stage::Flags, Access::Flags) &; + + /* + Metadata + */ + + const api::utils::uvec3& extents() const; + const TensorOptions& options() const; + IntArrayRef sizes() const; + IntArrayRef strides() const; + size_t nbytes() const; + + private: + // Some overloads below are intentionally disabled to enforce a usage pattern + // that ensures the Tensor's lifetime exceeds that of the scope in which the + // underlying data is accessed. Allowing deleted overloads below to be + // invoked on a temporary would open the door to the possibility of accessing + // the underlying memory out of the expected scope. + + /* + Host + */ + + const vTensor* host(api::Command::Buffer&) const; + vTensor* host(api::Command::Buffer&, Access::Flags); + + template + Future host(api::Command::Buffer&) const && = delete; + + template + Future host(api::Command::Buffer&) && = delete; + + /* + Device + */ + + Buffer::Object buffer(api::Command::Buffer&, Stage::Flags) const && = delete; + Buffer::Object buffer(api::Command::Buffer&, Stage::Flags, Access::Flags) && = delete; + + Image::Object image(api::Command::Buffer&, Stage::Flags) const && = delete; + Image::Object image(api::Command::Buffer&, Stage::Flags, Access::Flags) && = delete; + + private: + class View final { + public: + View(); + View( + api::Context* context, + api::Resource::Pool* pool, + IntArrayRef sizes, + const TensorOptions& options); + View(const View&) = delete; + View& operator=(const View&) = delete; + View(View&&) = default; + View operator=(View&&) = delete; + ~View() = default; + + /* + Buffer + */ + + Buffer& buffer(api::Command::Buffer&, Stage::Flags, Access::Flags) const; + + /* + Image + */ + + bool has_image() const; + Image& image(api::Command::Buffer&, Stage::Flags, Access::Flags) const; + + /* + Host + */ + + Buffer& staging(api::Command::Buffer&, Stage::Flags, Access::Flags) const; + vTensor::Memory& wait() const; + + /* + Metadata + */ + + const api::utils::uvec3& extents() const; + const TensorOptions& options() const; + IntArrayRef sizes() const; + IntArrayRef strides() const; + + private: + class CMD; + + class State final { + public: + State(); + State(const api::Adapter*, IntArrayRef); + + struct Bundle final { + struct Buffer final { + VkPipelineStageFlags stage; + VkAccessFlags access; + + operator bool() const; + } staging, buffer; + + struct Image final { + VkPipelineStageFlags stage; + VkAccessFlags access; + VkImageLayout layout; + + operator bool() const; + } image; + }; + + struct Component final { + typedef uint8_t Flags; + + enum Type : Flags { + Buffer = 1u << 0u, + Image = 1u << 1u, + Staging = 1u << 2u, + All = Buffer | Image | Staging, + }; + }; + + // Availability + bool is_available(Component::Flags) const; + bool is_discrete() const; + bool is_uma() const; + + // Clean / Dirty + bool is_clean(Component::Flags) const; + bool is_dirty(Component::Flags) const; + void set_clean(Component::Flags); + void set_dirty(Component::Flags); + + // Transition + typedef std::pair Transition; + Transition transition(Bundle to); + + private: + Component::Flags available_; + Component::Flags dirty_; + Bundle bundle_; + }; + + typedef State::Component Component; + + private: + // Accessors / Lazy Allocation + Buffer& buffer() const; + Buffer& buffer(CMD&, Stage::Flags, Access::Flags) const; + Image& image() const; + Image& image(CMD&, Stage::Flags, Access::Flags) const; + Buffer& staging() const; + Buffer& staging(CMD&, Stage::Flags, Access::Flags) const; + Fence& fence(Access::Flags) const; + + // Validation + void verify() const; + + private: + // Resources + mutable Buffer buffer_; + mutable Image image_; + mutable Buffer staging_; + mutable Fence fence_; + + // Context + api::Context* context_; + api::Resource::Pool* pool_; + + // State + mutable State state_; + + // Metadata + api::utils::uvec3 extents_; + TensorOptions options_; + c10::SmallVector sizes_; + c10::SmallVector strides_; + + private: + // Debug + friend std::ostream& operator<<( + std::ostream&, + const View::State::Bundle&); + }; + + // Even at the cost of a heap allocation plus the resulting negative impact + // on cache locality due to the subsequent pointer chasing, it is still + // critcal to share the view across vTensor implementations to minimize + // programmer errors. Ideally this class should have been only made movable, + // and non-copyable - something we cannot do unfortunately due to the inner + // workings of at::TensorImpl requiring copy semantics in + // at::TensorImpl::release_resources() to function as expected. Now that this + // class is made copyable though, a new door to a whole new class of bugs is + // opened, in that there now is a chance of two [shallow] copies, have their + // State objects go out of sync as a result of an operation being performed on + // one shallow copy that is not reflected in the other. Technically, if the + // programmer is very careful, it is possible to avoid this trap and not pay + // the cost of indirection, but the resulting bugs of missing memory barriers + // will be so frustrating to hunt down for those unfamiliar with the internal + // mechanics of this class, that I decided to take the performance pentalty + // of this extra layer of indirection in favor of making this class easier + // to use. + + std::shared_ptr view_; + + private: + // Debug + friend std::ostream& operator<<( + std::ostream&, + const View::State::Bundle&); +}; + +const vTensor& convert(const Tensor& tensor); +vTensor& convert(Tensor& tensor); +Tensor convert(const vTensor& tensor); + +using vTensorImpl = VulkanOpaqueTensorImpl; +void verify(const TensorOptions& options); + +// +// Impl +// + +template +inline vTensor::Future::Future( + const vTensor* const tensor) + : tensor_(tensor) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + tensor_, + "Invalid Vulkan tensor!"); +} + +template +inline vTensor::Future::Future( + Future&& future) + : tensor_(std::move(future.tensor_)) { + future.tensor_ = nullptr; +} + +template +inline vTensor::Future& +vTensor::Future::operator=( + Future&& future) & { + tensor_ = std::move(future.tensor_); + future.tensor_ = nullptr; + return *this; +} + +template +template +inline vTensor::Future::Future( + Future&& future) + : tensor_(std::move(future.tensor_)) { + future.tensor_ = nullptr; +} + +template +template +inline vTensor::Future& +vTensor::Future::operator=( + Future&& future) & { + tensor_ = std::move(future.tensor_); + future.tensor_ = nullptr; + return *this; +} + +template +inline vTensor::Future::~Future() { +#if VULKAN_SYNC_TENSORS_EAGERLY + // Sync eagerly in an effort to hide latency. + // Upside: Kick off the async transfer early on to keep the GPU busy. + // Downside: An extra CPU command submission. + if (tensor_ && (Access::Write & kAccess)) { + if (tensor_->has_image()) { + tensor_->image(); + } + else { + tensor_->buffer(); + } + } +#endif +} + +template +inline typename vTensor::Future::Payload +vTensor::Future::wait() const & { + TORCH_CHECK( + tensor_, + "vTensor::Future is in an invalid state! " + "Potential reason: This future is moved from."); + + return tensor_->view_->wait().template map(); +} + +template +inline vTensor::Future +vTensor::host(api::Command::Buffer& command_buffer) const & { + return Future(host(command_buffer)); +} + +template +inline vTensor::Future +vTensor::host(api::Command::Buffer& command_buffer) & { + return Future(host(command_buffer, kAccess)); +} + +inline bool vTensor::has_image() const { + return view_->has_image(); +} + +inline const api::utils::uvec3& vTensor::extents() const { + return view_->extents(); +} + +inline const TensorOptions& vTensor::options() const { + return view_->options(); +} + +inline IntArrayRef vTensor::sizes() const { + return view_->sizes(); +} + +inline size_t vTensor::nbytes() const { + return c10::elementSize(c10::typeMetaToScalarType(options().dtype())) * + prod_intlist(sizes()); +} + +inline IntArrayRef vTensor::strides() const { + return view_->strides(); +} + +inline bool vTensor::View::has_image() const { + return state_.is_available(View::Component::Image); +} + +inline const api::utils::uvec3& vTensor::View::extents() const { + return extents_; +} + +inline const TensorOptions& vTensor::View::options() const { + return options_; +} + +inline IntArrayRef vTensor::View::sizes() const { + return sizes_; +} + +inline IntArrayRef vTensor::View::strides() const { + return strides_; +} + +inline vTensor::View::State::Bundle::Buffer::operator bool() const { + return (0u != stage) && + (0u != access); +} + +inline vTensor::View::State::Bundle::Image::operator bool() const { + return (0u != stage) && + (0u != access) && + (VK_IMAGE_LAYOUT_UNDEFINED != layout); +} + +inline bool vTensor::View::State::is_available( + const Component::Flags components) const { + return available_ & components; +} + +inline bool vTensor::View::State::is_discrete() const { + return is_available(Component::Staging); +} + +inline bool vTensor::View::State::is_uma() const { + return !is_discrete(); +} + +inline bool vTensor::View::State::is_clean( + const Component::Flags components) const { + return !is_dirty(components); +} + +inline bool vTensor::View::State::is_dirty( + const Component::Flags components) const { + return dirty_ & components; +} + +inline void vTensor::View::State::set_clean( + const Component::Flags components) { + dirty_ &= ~components; +} + +inline void vTensor::View::State::set_dirty( + const Component::Flags components) { + dirty_ |= components; +} + +inline const vTensor& convert(const Tensor& tensor) { + TORCH_INTERNAL_ASSERT( + tensor.is_vulkan(), + "Vulkan tensor expected!"); + + const vTensorImpl* const impl = + static_cast(tensor.unsafeGetTensorImpl()); + + return impl->opaque_handle(); +} + +inline vTensor& convert(Tensor& tensor) { + TORCH_INTERNAL_ASSERT( + tensor.is_vulkan(), + "Vulkan tensor expected!"); + + vTensorImpl* const impl = + static_cast(tensor.unsafeGetTensorImpl()); + + return impl->unsafe_opaque_handle(); +} + +inline Tensor convert(const vTensor& tensor) { + return at::detail::make_tensor( + DispatchKeySet(DispatchKey::Vulkan), + tensor.options().dtype(), + at::Device(at::kVulkan), + tensor, + tensor.sizes(), + tensor.strides()); +} + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Upsample.cpp b/aten/src/ATen/native/vulkan/ops/Upsample.cpp new file mode 100644 index 0000000000000..00cefc1bdf538 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Upsample.cpp @@ -0,0 +1,112 @@ +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +using namespace api::utils; + +Tensor upsample_nearest2d( + const Tensor& input_arg, + const IntArrayRef output_sizes, + const c10::optional scales_h, + const c10::optional scales_w) { + api::Context* const context = api::context(); + + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_input = convert(input); + const auto v_input_sizes = v_input.sizes(); + + TORCH_CHECK( + (4 == v_input_sizes.size()) && (2 == output_sizes.size()), + "Invalid input!"); + + vTensor v_output{ + context, + { + v_input_sizes[Layout::Activation4D::batch], + v_input_sizes[Layout::Activation4D::channels], + output_sizes[Layout::Parameter::height], + output_sizes[Layout::Parameter::width], + }, + input.options(), + }; + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + if C10_LIKELY(v_input.has_image()) { + const struct Block final { + uvec3 extents; + uint32_t _; + ivec2 iextents; + vec2 scale; + } block { + v_output.extents(), + 0u, + { + safe_downcast(input.size(Layout::Activation4D::width) - 1), + safe_downcast(input.size(Layout::Activation4D::height) - 1), + }, + { + compute_scales_value( + scales_w, + v_input_sizes[Layout::Activation4D::width], + output_sizes[Layout::Parameter::width]), + compute_scales_value( + scales_h, + v_input_sizes[Layout::Activation4D::height], + output_sizes[Layout::Parameter::height]), + }, + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(upsample_nearest2d), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_input.image( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_pool.submit(context->gpu().queue, command_buffer); + + return convert(v_output); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("upsample_nearest2d", TORCH_FN(upsample_nearest2d)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Utils.h b/aten/src/ATen/native/vulkan/ops/Utils.h new file mode 100644 index 0000000000000..de218cfc472ab --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Utils.h @@ -0,0 +1,25 @@ +#pragma once + +#ifdef USE_VULKAN_API + +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace utils { + +inline int64_t normalize( + const int64_t dimension, + const int64_t n) { + return (dimension % n + n) % n; +} + +} // namespace utils +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/xnnpack/OpContext.cpp b/aten/src/ATen/native/xnnpack/OpContext.cpp index fe78dcda1f997..fe525cb2df4dc 100644 --- a/aten/src/ATen/native/xnnpack/OpContext.cpp +++ b/aten/src/ATen/native/xnnpack/OpContext.cpp @@ -27,9 +27,19 @@ XNNPackLinearOpContext::create_context( output_max ? output_max->to() : xnnpack::ContextLinear::kMax) ); + if (at::globalContext().releaseWeightsWhenPrepacking()) { + linear_op_context->free_orig_weight_and_bias(); + } + return linear_op_context; } +void XNNPackLinearOpContext::free_orig_weight_and_bias() { + orig_weight_and_bias_freed_ = true; + orig_weight_.reset(); + orig_bias_.reset(); +} + Tensor XNNPackLinearOpContext::run(const Tensor& input) { return xnnpack::internal::linear::run(op_context_, input); } @@ -70,6 +80,10 @@ XNNPackConv2dOpContext::create_context(at::Tensor&& weight, output_max, std::move(op_context)); + if (at::globalContext().releaseWeightsWhenPrepacking()) { + conv2d_op_context->free_orig_weight_and_bias(); + } + return conv2d_op_context; } @@ -111,6 +125,10 @@ XNNPackTransposeConv2dOpContext::create_context(at::Tensor&& weight, output_max, std::move(op_context)); + if (at::globalContext().releaseWeightsWhenPrepacking()) { + conv2d_op_context->free_orig_weight_and_bias(); + } + return conv2d_op_context; } @@ -122,6 +140,18 @@ Tensor XNNPackTransposeConv2dOpContext::run(const Tensor& input) { return xnnpack::internal::convolution2d::run(op_context_, input); } +void XNNPackConv2dOpContext::free_orig_weight_and_bias() { + orig_weight_and_bias_freed_ = true; + orig_weight_.reset(); + orig_bias_.reset(); +} + +void XNNPackTransposeConv2dOpContext::free_orig_weight_and_bias() { + orig_weight_and_bias_freed_ = true; + orig_weight_.reset(); + orig_bias_.reset(); +} + } // namespace xnnpack } // namespace native } // namespace at diff --git a/aten/src/ATen/native/xnnpack/OpContext.h b/aten/src/ATen/native/xnnpack/OpContext.h index e696ad3aa81dc..e26c3383d6a60 100644 --- a/aten/src/ATen/native/xnnpack/OpContext.h +++ b/aten/src/ATen/native/xnnpack/OpContext.h @@ -43,13 +43,16 @@ class LinearOpContext : public torch::jit::CustomClassHolder { c10::optional orig_bias_; c10::optional output_min_; c10::optional output_max_; + bool orig_weight_and_bias_freed_; public: SerializationTypeLinearPrePack unpack() { + TORCH_CHECK(!orig_weight_and_bias_freed_, "Original weight and bias have been freed"); return std::make_tuple(orig_weight_, orig_bias_, output_min_, output_max_); } virtual Tensor run(const Tensor& input) = 0; + virtual void free_orig_weight_and_bias() = 0; }; class XNNPackLinearOpContext final : public LinearOpContext { @@ -68,9 +71,11 @@ class XNNPackLinearOpContext final : public LinearOpContext { orig_bias_ = std::move(bias); output_min_ = min; output_max_ = max; + orig_weight_and_bias_freed_ = false; } - Tensor run(const Tensor& input); + Tensor run(const Tensor& input) override; + void free_orig_weight_and_bias() override; static c10::intrusive_ptr create_context( Tensor&& weight, @@ -89,9 +94,11 @@ class Conv2dOpContext : public torch::jit::CustomClassHolder { int64_t groups_; c10::optional output_min_; c10::optional output_max_; + bool orig_weight_and_bias_freed_; public: SerializationTypeConv2dPrePack unpack() { + TORCH_CHECK(!orig_weight_and_bias_freed_, "Original weight and bias have been freed"); return std::make_tuple( orig_weight_, orig_bias_, @@ -104,6 +111,7 @@ class Conv2dOpContext : public torch::jit::CustomClassHolder { } virtual Tensor run(const Tensor& input) = 0; + virtual void free_orig_weight_and_bias() = 0; }; class TransposeConv2dOpContext : public torch::jit::CustomClassHolder { @@ -117,9 +125,11 @@ class TransposeConv2dOpContext : public torch::jit::CustomClassHolder { int64_t groups_; c10::optional output_min_; c10::optional output_max_; + bool orig_weight_and_bias_freed_; public: SerializationTypeTransposeConv2dPrePack unpack() { + TORCH_CHECK(!orig_weight_and_bias_freed_, "Original weight and bias have been freed"); return std::make_tuple( orig_weight_, orig_bias_, @@ -133,6 +143,7 @@ class TransposeConv2dOpContext : public torch::jit::CustomClassHolder { } virtual Tensor run(const Tensor& input) = 0; + virtual void free_orig_weight_and_bias() = 0; }; class XNNPackConv2dOpContext final : public Conv2dOpContext { @@ -159,9 +170,11 @@ class XNNPackConv2dOpContext final : public Conv2dOpContext { groups_ = groups; output_min_ = min; output_max_ = max; + orig_weight_and_bias_freed_ = false; } Tensor run(const Tensor& input) override; + void free_orig_weight_and_bias() override; static c10::intrusive_ptr create_context( Tensor&& weight, @@ -200,9 +213,11 @@ class XNNPackTransposeConv2dOpContext final : public TransposeConv2dOpContext { groups_ = groups; output_min_ = min; output_max_ = max; + orig_weight_and_bias_freed_ = false; } Tensor run(const Tensor& input) override; + void free_orig_weight_and_bias() override; static c10::intrusive_ptr create_context( Tensor&& weight, diff --git a/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp b/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp index e8442a64d0ad8..da13fb9574d53 100644 --- a/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp +++ b/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp @@ -73,21 +73,21 @@ TORCH_LIBRARY(xnnpack, m) { } TORCH_LIBRARY(prepacked, m) { - m.def("linear_clamp_prepack(Tensor W, Tensor? B=None, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.LinearOpContext"); - m.def("linear_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> Tensor Y"); - m.def("conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] dilation, int groups, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.Conv2dOpContext"); - m.def("conv2d_transpose_clamp_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] output_padding, int[2] dilation, int groups, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.TransposeConv2dOpContext"); - m.def("conv2d_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.Conv2dOpContext W_prepack) -> Tensor Y"); - m.def("conv2d_transpose_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.TransposeConv2dOpContext W_prepack) -> Tensor Y"); + m.def(TORCH_SELECTIVE_SCHEMA("prepacked::linear_clamp_prepack(Tensor W, Tensor? B=None, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.LinearOpContext")); + m.def(TORCH_SELECTIVE_SCHEMA("prepacked::linear_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> Tensor Y")); + m.def(TORCH_SELECTIVE_SCHEMA("prepacked::conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] dilation, int groups, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.Conv2dOpContext")); + m.def(TORCH_SELECTIVE_SCHEMA("prepacked::conv2d_transpose_clamp_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] output_padding, int[2] dilation, int groups, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.TransposeConv2dOpContext")); + m.def(TORCH_SELECTIVE_SCHEMA("prepacked::conv2d_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.Conv2dOpContext W_prepack) -> Tensor Y")); + m.def(TORCH_SELECTIVE_SCHEMA("prepacked::conv2d_transpose_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.TransposeConv2dOpContext W_prepack) -> Tensor Y")); } TORCH_LIBRARY_IMPL(prepacked, CPU, m) { - m.impl("linear_clamp_prepack", TORCH_FN(createLinearClampPrePackOpContext)); - m.impl("linear_clamp_run", TORCH_FN(internal::linear::linear_clamp_run)); - m.impl("conv2d_clamp_prepack", TORCH_FN(createConv2dClampPrePackOpContext)); - m.impl("conv2d_transpose_clamp_prepack", TORCH_FN(createConv2dTransposeClampPrePackOpContext)); - m.impl("conv2d_clamp_run", TORCH_FN(internal::convolution2d::conv2d_clamp_run)); - m.impl("conv2d_transpose_clamp_run", TORCH_FN(internal::convolution2d::conv2d_transpose_clamp_run)); + m.impl(TORCH_SELECTIVE_NAME("prepacked::linear_clamp_prepack"), TORCH_FN(createLinearClampPrePackOpContext)); + m.impl(TORCH_SELECTIVE_NAME("prepacked::linear_clamp_run"), TORCH_FN(internal::linear::linear_clamp_run)); + m.impl(TORCH_SELECTIVE_NAME("prepacked::conv2d_clamp_prepack"), TORCH_FN(createConv2dClampPrePackOpContext)); + m.impl(TORCH_SELECTIVE_NAME("prepacked::conv2d_transpose_clamp_prepack"), TORCH_FN(createConv2dTransposeClampPrePackOpContext)); + m.impl(TORCH_SELECTIVE_NAME("prepacked::conv2d_clamp_run"), TORCH_FN(internal::convolution2d::conv2d_clamp_run)); + m.impl(TORCH_SELECTIVE_NAME("prepacked::conv2d_transpose_clamp_run"), TORCH_FN(internal::convolution2d::conv2d_transpose_clamp_run)); } } // namespace xnnpack diff --git a/aten/src/ATen/nnapi/CMakeLists.txt b/aten/src/ATen/nnapi/CMakeLists.txt new file mode 100644 index 0000000000000..01324049bde6f --- /dev/null +++ b/aten/src/ATen/nnapi/CMakeLists.txt @@ -0,0 +1,21 @@ +# Define this to build the NNAPI binding out of tree. +if(PYTORCH_NNAPI_STANDALONE) + cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + project(pytorch_nnapi) + + set(CMAKE_CXX_STANDARD 14) + find_package(Torch REQUIRED) + + set(NNAPI_SRCS + nnapi_bind.cpp + nnapi_wrapper.cpp + nnapi_model_loader.cpp + ) + + add_library(pytorch_nnapi SHARED ${NNAPI_SRCS}) + target_link_libraries(pytorch_nnapi torch) +else() + # Building within the PyTorch tree. + file(GLOB ATen_NNAPI_SRCS "*.cpp") + set(ATen_NNAPI_SRCS ${ATen_NNAPI_SRCS} PARENT_SCOPE) +endif() diff --git a/aten/src/ATen/nnapi/NeuralNetworks.h b/aten/src/ATen/nnapi/NeuralNetworks.h new file mode 100644 index 0000000000000..bfc3ea4ac49df --- /dev/null +++ b/aten/src/ATen/nnapi/NeuralNetworks.h @@ -0,0 +1,84 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + +Most of NeuralNetworks.h has been stripped for simplicity. +We don't need any of the function declarations since +we call them all through dlopen/dlsym. +Operation codes are pulled directly from serialized models. + +*/ + +#ifndef MINIMAL_NEURAL_NETWORKS_H +#define MINIMAL_NEURAL_NETWORKS_H + +#include + +typedef enum { + ANEURALNETWORKS_NO_ERROR = 0, + ANEURALNETWORKS_OUT_OF_MEMORY = 1, + ANEURALNETWORKS_INCOMPLETE = 2, + ANEURALNETWORKS_UNEXPECTED_NULL = 3, + ANEURALNETWORKS_BAD_DATA = 4, + ANEURALNETWORKS_OP_FAILED = 5, + ANEURALNETWORKS_BAD_STATE = 6, + ANEURALNETWORKS_UNMAPPABLE = 7, + ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE = 8, + ANEURALNETWORKS_UNAVAILABLE_DEVICE = 9, +} ResultCode; + +typedef enum { + ANEURALNETWORKS_FLOAT32 = 0, + ANEURALNETWORKS_INT32 = 1, + ANEURALNETWORKS_UINT32 = 2, + ANEURALNETWORKS_TENSOR_FLOAT32 = 3, + ANEURALNETWORKS_TENSOR_INT32 = 4, + ANEURALNETWORKS_TENSOR_QUANT8_ASYMM = 5, + ANEURALNETWORKS_BOOL = 6, + ANEURALNETWORKS_TENSOR_QUANT16_SYMM = 7, + ANEURALNETWORKS_TENSOR_FLOAT16 = 8, + ANEURALNETWORKS_TENSOR_BOOL8 = 9, + ANEURALNETWORKS_FLOAT16 = 10, + ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL = 11, + ANEURALNETWORKS_TENSOR_QUANT16_ASYMM = 12, + ANEURALNETWORKS_TENSOR_QUANT8_SYMM = 13, +} OperandCode; + +typedef enum { + ANEURALNETWORKS_PREFER_LOW_POWER = 0, + ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1, + ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2, +} PreferenceCode; + +typedef struct ANeuralNetworksMemory ANeuralNetworksMemory; +typedef struct ANeuralNetworksModel ANeuralNetworksModel; +typedef struct ANeuralNetworksDevice ANeuralNetworksDevice; +typedef struct ANeuralNetworksCompilation ANeuralNetworksCompilation; +typedef struct ANeuralNetworksExecution ANeuralNetworksExecution; +typedef struct ANeuralNetworksEvent ANeuralNetworksEvent; + +typedef int32_t ANeuralNetworksOperationType; + +typedef struct ANeuralNetworksOperandType { + int32_t type; + uint32_t dimensionCount; + const uint32_t* dimensions; + float scale; + int32_t zeroPoint; +} ANeuralNetworksOperandType; + +#endif // MINIMAL_NEURAL_NETWORKS_H diff --git a/aten/src/ATen/nnapi/codegen.py b/aten/src/ATen/nnapi/codegen.py new file mode 100755 index 0000000000000..a24823da6f7cf --- /dev/null +++ b/aten/src/ATen/nnapi/codegen.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +""" +Code generator for NNAPI wrapper. We can't link directly against +libneuralnetworks.so because we want PyTorch to work on Android +devices that don't have it available. Instead, we generate a wrapper +that opens libneuralnetworks.so with dlopen and finds the functions +we need with dlsym. We also generate a "check" wrapper that checks +return values and throws C++ exceptions on errors. +""" +import sys +import re +import pathlib +import textwrap + + +PREFIX = """\ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is generated by nnapi/codegen.py +""" + + +NNAPI_FUNCTIONS = [ + ("int", "ANeuralNetworks_getDeviceCount", "uint32_t* numDevices"), # noqa: B950 + ("int", "ANeuralNetworks_getDevice", "uint32_t devIndex, ANeuralNetworksDevice** device"), # noqa: B950 + ("int", "ANeuralNetworksDevice_getName", "const ANeuralNetworksDevice* device, const char** name"), # noqa: B950 + ("int", "ANeuralNetworksDevice_getVersion", "const ANeuralNetworksDevice* device, const char** version"), # noqa: B950 + ("int", "ANeuralNetworksDevice_getFeatureLevel", "const ANeuralNetworksDevice* device, int64_t* featureLevel"), # noqa: B950 + ("int", "ANeuralNetworksModel_getSupportedOperationsForDevices", " const ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, bool* supportedOps"), # noqa: B950 + ("int", "ANeuralNetworksCompilation_createForDevices", "ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, ANeuralNetworksCompilation** compilation"), # noqa: B950 + ("int", "ANeuralNetworksExecution_compute", "ANeuralNetworksExecution* execution"), # noqa: B950 + ("int", "ANeuralNetworksMemory_createFromFd", "size_t size, int protect, int fd, size_t offset, ANeuralNetworksMemory** memory"), # noqa: B950 + ("void", "ANeuralNetworksMemory_free", "ANeuralNetworksMemory* memory"), # noqa: B950 + ("int", "ANeuralNetworksModel_create", "ANeuralNetworksModel** model"), # noqa: B950 + ("void", "ANeuralNetworksModel_free", "ANeuralNetworksModel* model"), # noqa: B950 + ("int", "ANeuralNetworksModel_finish", "ANeuralNetworksModel* model"), # noqa: B950 + ("int", "ANeuralNetworksModel_addOperand", "ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type"), # noqa: B950 + ("int", "ANeuralNetworksModel_setOperandValue", "ANeuralNetworksModel* model, int32_t index, const void* buffer, size_t length"), # noqa: B950 + ("int", "ANeuralNetworksModel_setOperandValueFromMemory", "ANeuralNetworksModel* model, int32_t index, const ANeuralNetworksMemory* memory, size_t offset, size_t length"), # noqa: B950 + ("int", "ANeuralNetworksModel_addOperation", "ANeuralNetworksModel* model, ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs"), # noqa: B950 + ("int", "ANeuralNetworksModel_identifyInputsAndOutputs", "ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs"), # noqa: B950 + ("int", "ANeuralNetworksModel_relaxComputationFloat32toFloat16", "ANeuralNetworksModel* model, bool allow"), # noqa: B950 + ("int", "ANeuralNetworksCompilation_create", "ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation"), # noqa: B950 + ("void", "ANeuralNetworksCompilation_free", "ANeuralNetworksCompilation* compilation"), # noqa: B950 + ("int", "ANeuralNetworksCompilation_setPreference", "ANeuralNetworksCompilation* compilation, int32_t preference"), # noqa: B950 + ("int", "ANeuralNetworksCompilation_finish", "ANeuralNetworksCompilation* compilation"), # noqa: B950 + ("int", "ANeuralNetworksExecution_create", "ANeuralNetworksCompilation* compilation, ANeuralNetworksExecution** execution"), # noqa: B950 + ("void", "ANeuralNetworksExecution_free", "ANeuralNetworksExecution* execution"), # noqa: B950 + ("int", "ANeuralNetworksExecution_setInput", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const void* buffer, size_t length"), # noqa: B950 + ("int", "ANeuralNetworksExecution_setInputFromMemory", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length"), # noqa: B950 + ("int", "ANeuralNetworksExecution_setOutput", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, void* buffer, size_t length"), # noqa: B950 + ("int", "ANeuralNetworksExecution_setOutputFromMemory", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length"), # noqa: B950 + ("int", "ANeuralNetworksExecution_startCompute", "ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event"), # noqa: B950 + ("int", "ANeuralNetworksEvent_wait", "ANeuralNetworksEvent* event"), # noqa: B950 + ("void", "ANeuralNetworksEvent_free", "ANeuralNetworksEvent* event"), # noqa: B950 + ("int", "ANeuralNetworksExecution_getOutputOperandRank", "ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank"), # noqa: B950 + ("int", "ANeuralNetworksExecution_getOutputOperandDimensions", "ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions"), # noqa: B950 +] + + +def main(argv): + struct_members = [] + load_functions = [] + define_checks = [] + + for ret, name, args in NNAPI_FUNCTIONS: + short_name = name.replace("ANeuralNetworks", "", 1) + + struct_members.append(f" {ret}(*{short_name})({args});") + + load_functions.append(f' *(void**)&nnapi_.{short_name} = dlsym(handle, "{name}");') + load_functions.append(f' check_nnapi_.{short_name} = check_{short_name};') + + call_args = "".join(re.findall(r"\w+(?:,|$)", args)) + if ret == "void": + define_checks.append(textwrap.dedent(f"""\ + {ret} check_{short_name}({args}) {{ + CAFFE_ENFORCE(nnapi_.{short_name}); + nnapi_.{short_name}({call_args}); + }}""")) + if ret == "int": + define_checks.append(textwrap.dedent(f"""\ + {ret} check_{short_name}({args}) {{ + CAFFE_ENFORCE(nnapi_.{short_name}); + int ret = nnapi_.{short_name}({call_args}); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; + }}""")) + + out_dir = pathlib.Path(__file__).parent + + (out_dir / "nnapi_wrapper.h").write_text( + PREFIX + + textwrap.dedent("""\ + #ifndef NNAPI_WRAPPER_H_ + #define NNAPI_WRAPPER_H_ + #include + #include + #include + struct nnapi_wrapper { + __STRUCT_MEMBERS__ + }; + #ifdef __cplusplus + void nnapi_wrapper_load(struct nnapi_wrapper** nnapi, struct nnapi_wrapper** check_nnapi); + #endif + #endif + """) + .replace("__STRUCT_MEMBERS__", "\n".join(struct_members)) + ) + + (out_dir / "nnapi_wrapper.cpp").write_text( + PREFIX + + textwrap.dedent("""\ + #ifndef _WIN32 + #include + #endif + #include + #include + static int loaded = 0; + static struct nnapi_wrapper nnapi_; + static struct nnapi_wrapper check_nnapi_; + __DEFINE_CHECK_FUNCTIONS__ + void nnapi_wrapper_load(struct nnapi_wrapper** nnapi, struct nnapi_wrapper** check_nnapi) { + #ifdef _WIN32 + TORCH_CHECK(false, "Running NNAPI models is not supported on Windows."); + #else + if (!loaded) { + // Clear error flag. + dlerror(); + void* handle = dlopen("libneuralnetworks.so", RTLD_LAZY | RTLD_LOCAL); + CAFFE_ENFORCE(handle, "Failed to load libneuralnetworks.so ", dlerror()); + __LOAD_FUNCTIONS__ + loaded = 1; + } + *nnapi = &nnapi_; + *check_nnapi = &check_nnapi_; + #endif + } + """) + .replace("__DEFINE_CHECK_FUNCTIONS__", "\n".join(define_checks)) + .replace("__LOAD_FUNCTIONS__", "\n".join(load_functions)) + ) + + +if __name__ == "__main__": + sys.exit(main(sys.argv)) diff --git a/aten/src/ATen/nnapi/nnapi_bind.cpp b/aten/src/ATen/nnapi/nnapi_bind.cpp new file mode 100644 index 0000000000000..9e652290ab4aa --- /dev/null +++ b/aten/src/ATen/nnapi/nnapi_bind.cpp @@ -0,0 +1,204 @@ +#include + +#include +#include + +#include +#include + + +namespace torch { +namespace nnapi { +namespace { + +nnapi_wrapper* nnapi; +nnapi_wrapper* check_nnapi; + +void load_platform_library() { + static int run_once = [](){ + nnapi_wrapper_load(&nnapi, &check_nnapi); + CAFFE_ENFORCE(nnapi); + CAFFE_ENFORCE(nnapi->Model_free); + CAFFE_ENFORCE(nnapi->Compilation_free); + CAFFE_ENFORCE(nnapi->Execution_free); + return 0; + }(); + (void)run_once; +} + +#define MAKE_SMART_PTR(type) \ + struct type ## Freer { \ + void operator()(ANeuralNetworks ## type * obj) { \ + if (!nnapi) { /* obj must be null. */ return; } \ + nnapi-> type ## _free(obj); \ + } \ + }; \ + typedef std::unique_ptr type ## Ptr; + +MAKE_SMART_PTR(Model) +MAKE_SMART_PTR(Compilation) +MAKE_SMART_PTR(Execution) + +#undef MAKE_SMART_PTR + +struct NnapiCompilation : torch::jit::CustomClassHolder { + NnapiCompilation() { + // Could possibly call load_platform_library here, but error reporting + // can be complicated if the constructor is called during model loading. + // Instead, delay all work until the explicit init call. + } + + ~NnapiCompilation() { + } + + void init( + torch::Tensor serialized_model_tensor, + std::vector parameter_buffers) { + TORCH_CHECK(!model_, "Attempted to re-initialize NnapiCompilation."); + + load_platform_library(); + + std::vector buffers; + std::vector buffer_sizes; + for (auto& t : parameter_buffers) { + TORCH_CHECK(t.is_contiguous()); + buffers.push_back(t.data_ptr()); + buffer_sizes.push_back(t.nbytes()); + } + + TORCH_CHECK(serialized_model_tensor.is_contiguous()); + c10::ArrayRef ser_model = { + serialized_model_tensor.data_ptr(), + serialized_model_tensor.nbytes() + }; + TORCH_CHECK(ser_model.size() > 0); + + ANeuralNetworksModel* model; + check_nnapi->Model_create(&model); + CAFFE_ENFORCE(model); + model_.reset(model); + + int load_result = ::caffe2::nnapi::load_nnapi_model( + nnapi, + model_.get(), + ser_model.data(), + ser_model.size(), + buffers.size(), + buffers.data(), + buffer_sizes.data(), + 0, + nullptr, + nullptr, + &num_inputs_, + &num_outputs_, + nullptr); + CAFFE_ENFORCE(load_result == 0); + + check_nnapi->Model_finish(model_.get()); + + ANeuralNetworksCompilation* compilation; + check_nnapi->Compilation_create(model_.get(), &compilation); + // TODO: Make this configurable. + check_nnapi->Compilation_setPreference(compilation, ANEURALNETWORKS_PREFER_SUSTAINED_SPEED); + check_nnapi->Compilation_finish(compilation); + compilation_.reset(compilation); + } + + void run( + std::vector inputs, + std::vector outputs) { + ANeuralNetworksExecution* execution; + check_nnapi->Execution_create(compilation_.get(), &execution); + ExecutionPtr execution_unique_ptr(execution); + + TORCH_CHECK((int32_t)inputs.size() == num_inputs_); + TORCH_CHECK((int32_t)outputs.size() == num_outputs_); + + for (size_t i = 0; i < inputs.size(); i++) { + auto& t = inputs[i]; + // TODO: Check contiguous and dtype. + ANeuralNetworksOperandType op_type; + std::vector dim; + get_operand_type(t, &op_type, &dim); + check_nnapi->Execution_setInput( + execution, + i, + &op_type, + t.data_ptr(), + t.nbytes()); + } + + for (size_t i = 0; i < outputs.size(); i++) { + auto& t = outputs[i]; + // TODO: Check contiguous and dtype. + check_nnapi->Execution_setOutput( + execution, + i, + nullptr, + t.data_ptr(), + t.nbytes()); + } + + check_nnapi->Execution_compute(execution); + + // TODO: Maybe skip this for fixed-size outputs? + for (size_t i = 0; i < outputs.size(); i++) { + auto& t = outputs[i]; + uint32_t rank; + check_nnapi->Execution_getOutputOperandRank(execution, i, &rank); + std::vector dims(rank); + check_nnapi->Execution_getOutputOperandDimensions(execution, i, dims.data()); + std::vector long_dims(dims.begin(), dims.end()); + // TODO: Maybe check that only the batch dimension is changed? + t.resize_(long_dims); + } + } + + static void get_operand_type(const Tensor& t, ANeuralNetworksOperandType* operand, std::vector* dims) { + operand->dimensionCount = t.dim(); + TORCH_CHECK(operand->dimensionCount == t.dim()); // Check for overflow. + dims->resize(t.dim()); + operand->dimensions = dims->data(); + for (size_t i = 0; i < dims->size(); i++) { + (*dims)[i] = t.sizes()[i]; + TORCH_CHECK((*dims)[i] == t.sizes()[i]); // Check for overflow. + } + if (t.scalar_type() == c10::kFloat) { + operand->type = ANEURALNETWORKS_TENSOR_FLOAT32; + operand->scale = 0; + operand->zeroPoint = 0; + return; + } + if (t.scalar_type() == c10::kQUInt8) { + TORCH_CHECK(t.is_quantized()); + operand->type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM; + operand->scale = t.q_scale(); + operand->zeroPoint = t.q_zero_point(); + return; + } + // TODO: Support more dtypes. + CAFFE_THROW("Bad dtype"); + } + + ModelPtr model_; + CompilationPtr compilation_; + int32_t num_inputs_; + int32_t num_outputs_; +}; + +static auto register_NnapiCompilation = [](){ + try { + return torch::jit::class_("_nnapi", "Compilation") + .def(torch::jit::init<>()) + .def("init", &NnapiCompilation::init) + .def("run", &NnapiCompilation::run) + ; + } catch (std::exception& exn) { + LOG(ERROR) << "Failed to register class nnapi.Compilation: " << exn.what(); + throw; + } +}(); + +} // namespace +} // namespace nnapi +} // namespace torch diff --git a/aten/src/ATen/nnapi/nnapi_model_loader.cpp b/aten/src/ATen/nnapi/nnapi_model_loader.cpp new file mode 100644 index 0000000000000..27fe72d616529 --- /dev/null +++ b/aten/src/ATen/nnapi/nnapi_model_loader.cpp @@ -0,0 +1,264 @@ +#include + +#include +#include +#include + + +#ifndef NNAPI_LOADER_STANDALONE + +# include + +#else + +#define CAFFE_ENFORCE(cond, ...) do { if (!cond) { return -1; } } while (0) + +#endif + + +#define NNAPI_CHECK(res) CAFFE_ENFORCE(res == ANEURALNETWORKS_NO_ERROR) + + +namespace caffe2 { +namespace nnapi { + +namespace { + +/* +Serialized format for NNAPI models. It is basically just a list arguments +for calls to be made to NNAPI. +*/ + +typedef enum _SourceType { + SOURCE_IMMEDIATE = 0, + SOURCE_NUMBERED_BUFFER = 2, + SOURCE_NUMBERED_MEMORY = 3, +} SourceType; + +typedef struct _SerializedOperand { + int32_t type; + uint32_t dimension_count; + float scale; + int32_t zero_point; +} SerializedOperand; + +typedef struct _SerializedValue { + int32_t index; + int32_t source_type; + uint32_t source_length; +} SerializedValue; + +typedef struct _SerializedOperation { + int32_t operation_type; + uint32_t input_count; + uint32_t output_count; +} SerializedOperation; + +typedef struct _SerializedModel { + int32_t version; + int32_t operand_count; + int32_t value_count; + int32_t operation_count; + int32_t input_count; + int32_t output_count; + // SerializedOperand operands[operand_count]; + // SerializedValue values[value_count]; + // SerializedOperation operations[operation_count]; + // uint32_t operand_dimensions[sum(dimension_count)] + // uint32_t value_data[sum(source_length+pad)/4] + // uint32_t operation_args[sum(input_count + output_count)] + // uint32_t model_inputs[input_count] + // uint32_t model_outputs[output_count] +} SerializedModel; + + +/** + * Get the physically stored size of a value. All values are padded out + * to a multiple of 4 bytes to ensure the next value is 4-byte aligned. + */ +static uint32_t value_physical_size(uint32_t len) { + uint32_t phys = len; + if (len % 4 == 0) { + return len; + } + return len + 4 - (phys % 4); +} + +} // namespace + + +int load_nnapi_model( + struct nnapi_wrapper* nnapi, + ANeuralNetworksModel* model, + const void* serialized_model, + int64_t model_length, + size_t num_buffers, + const void** buffer_ptrs, + int32_t* buffer_sizes, + size_t num_memories, + ANeuralNetworksMemory** memories, + int32_t* memory_sizes, + int32_t* out_input_count, + int32_t* out_output_count, + size_t* out_bytes_consumed) { + int64_t required_size = 0; + const uint8_t* next_pointer = (const uint8_t*)serialized_model; + const uint8_t* end_of_buf = (const uint8_t*)serialized_model + model_length; + + required_size += sizeof(SerializedModel); + CAFFE_ENFORCE(model_length >= required_size, "Model is too small. Size = ", model_length); + const SerializedModel* ser_model = (SerializedModel*)next_pointer; + next_pointer = (uint8_t*)serialized_model + required_size; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + CAFFE_ENFORCE(ser_model->version == 1); + // Keep these small to avoid integer overflow. + CAFFE_ENFORCE(ser_model->operand_count < (1 << 24)); + CAFFE_ENFORCE(ser_model->value_count < (1 << 24)); + CAFFE_ENFORCE(ser_model->operation_count < (1 << 24)); + CAFFE_ENFORCE(ser_model->input_count < (1 << 24)); + CAFFE_ENFORCE(ser_model->output_count < (1 << 24)); + + required_size += sizeof(SerializedOperand) * ser_model->operand_count; + CAFFE_ENFORCE(model_length >= required_size, "Model is too small. Size = ", model_length); + const SerializedOperand* operands = (const SerializedOperand*)next_pointer; + next_pointer = (uint8_t*)serialized_model + required_size; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + required_size += sizeof(SerializedValue) * ser_model->value_count; + CAFFE_ENFORCE(model_length >= required_size, "Model is too small. Size = ", model_length); + const SerializedValue* values = (const SerializedValue*)next_pointer; + next_pointer = (uint8_t*)serialized_model + required_size; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + required_size += sizeof(SerializedOperation) * ser_model->operation_count; + CAFFE_ENFORCE(model_length >= required_size, "Model is too small. Size = ", model_length); + const SerializedOperation* operations = (const SerializedOperation*)next_pointer; + next_pointer = (uint8_t*)serialized_model + required_size; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + for (int i = 0; i < ser_model->operand_count; i++) { + required_size += 4 * operands[i].dimension_count; + } + + for (int i = 0; i < ser_model->value_count; i++) { + required_size += value_physical_size(values[i].source_length); + } + + for (int i = 0; i < ser_model->operation_count; i++) { + required_size += 4 * (operations[i].input_count + operations[i].output_count); + } + + required_size += 4 * (ser_model->input_count + ser_model->output_count); + + CAFFE_ENFORCE(model_length >= required_size, "Model is too small. Size = ", model_length); + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + for (int i = 0; i < ser_model->operand_count; i++) { + ANeuralNetworksOperandType operand; + operand.type = operands[i].type; + operand.scale = operands[i].scale; + operand.zeroPoint = operands[i].zero_point; + operand.dimensionCount = operands[i].dimension_count; + operand.dimensions = operands[i].dimension_count ? (const uint32_t*)next_pointer : NULL; + + next_pointer += 4 * operands[i].dimension_count; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + int result = nnapi->Model_addOperand(model, &operand); + NNAPI_CHECK(result); + } + + for (int i = 0; i < ser_model->value_count; i++) { + uint32_t len = values[i].source_length; + const uint8_t* stored_pointer = next_pointer; + const void* value_pointer = NULL; + size_t value_length; + + switch ((SourceType)values[i].source_type) { + case SOURCE_IMMEDIATE: + { + value_pointer = stored_pointer; + value_length = len; + } + break; + case SOURCE_NUMBERED_BUFFER: + { + CAFFE_ENFORCE(len == 12); + uint32_t buffer_number = *(uint32_t*)stored_pointer; + uint32_t buffer_offset = *(uint32_t*)(stored_pointer + 4); + uint32_t operand_length = *(uint32_t*)(stored_pointer + 8); + CAFFE_ENFORCE(buffer_number < num_buffers); + CAFFE_ENFORCE(buffer_offset + operand_length >= buffer_offset); // No integer overflow + CAFFE_ENFORCE(buffer_offset + operand_length <= (uint32_t)buffer_sizes[buffer_number]); // No buffer overflow + value_pointer = (uint8_t*)buffer_ptrs[buffer_number] + buffer_offset; + value_length = operand_length; + } + break; + case SOURCE_NUMBERED_MEMORY: + CAFFE_ENFORCE(false, "Memory inputs not implemented yet."); + break; + default: + CAFFE_ENFORCE(false, "Unknown source type: ", values[i].source_type); + } + + CAFFE_ENFORCE(value_pointer != NULL); + + next_pointer += value_physical_size(len); + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + int result = nnapi->Model_setOperandValue( + model, + values[i].index, + value_pointer, + value_length); + NNAPI_CHECK(result); + } + + for (int i = 0; i < ser_model->operation_count; i++) { + const uint32_t* inputs = (const uint32_t*)next_pointer; + next_pointer += 4 * operations[i].input_count; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + const uint32_t* outputs = (const uint32_t*)next_pointer; + next_pointer += 4 * operations[i].output_count; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + int result = nnapi->Model_addOperation( + model, + operations[i].operation_type, + operations[i].input_count, + inputs, + operations[i].output_count, + outputs); + NNAPI_CHECK(result); + } + + const uint32_t* model_inputs = (const uint32_t*)next_pointer; + next_pointer += 4 * ser_model->input_count; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + const uint32_t* model_outputs = (const uint32_t*)next_pointer; + next_pointer += 4 * ser_model->output_count; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + int result = nnapi->Model_identifyInputsAndOutputs( + model, + ser_model->input_count, + model_inputs, + ser_model->output_count, + model_outputs); + NNAPI_CHECK(result); + + *out_input_count = ser_model->input_count; + *out_output_count = ser_model->output_count; + + // TODO: Maybe eliminate required_size and just rely on next_pointer for bounds checking. + CAFFE_ENFORCE(next_pointer <= end_of_buf); + CAFFE_ENFORCE(next_pointer == (const uint8_t*)serialized_model + required_size); + if (out_bytes_consumed != NULL) { + *out_bytes_consumed = next_pointer - (const uint8_t*)serialized_model; + } + + return 0; +} + +}} // namespace caffe2::nnapi diff --git a/aten/src/ATen/nnapi/nnapi_model_loader.h b/aten/src/ATen/nnapi/nnapi_model_loader.h new file mode 100644 index 0000000000000..6a07a76454b81 --- /dev/null +++ b/aten/src/ATen/nnapi/nnapi_model_loader.h @@ -0,0 +1,29 @@ +#ifndef NNAPI_MODEL_LOADER_H_ +#define NNAPI_MODEL_LOADER_H_ + +#include + +#include +#include + +namespace caffe2 { +namespace nnapi { + +int load_nnapi_model( + struct nnapi_wrapper* nnapi, + ANeuralNetworksModel* model, + const void* serialized_model, + int64_t model_length, + size_t num_buffers, + const void** buffer_ptrs, + int32_t* buffer_sizes, + size_t num_memories, + ANeuralNetworksMemory** memories, + int32_t* memory_sizes, + int32_t* out_input_count, + int32_t* out_output_count, + size_t* out_bytes_consumed); + +}} // namespace caffe2::nnapi + +#endif // NNAPI_MODEL_LOADER_H_ diff --git a/aten/src/ATen/nnapi/nnapi_wrapper.cpp b/aten/src/ATen/nnapi/nnapi_wrapper.cpp new file mode 100644 index 0000000000000..f3b18a142313e --- /dev/null +++ b/aten/src/ATen/nnapi/nnapi_wrapper.cpp @@ -0,0 +1,331 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is generated by nnapi/codegen.py +#ifndef _WIN32 +#include +#endif +#include +#include +static int loaded = 0; +static struct nnapi_wrapper nnapi_; +static struct nnapi_wrapper check_nnapi_; +int check__getDeviceCount(uint32_t* numDevices) { + CAFFE_ENFORCE(nnapi_._getDeviceCount); + int ret = nnapi_._getDeviceCount(numDevices); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check__getDevice(uint32_t devIndex, ANeuralNetworksDevice** device) { + CAFFE_ENFORCE(nnapi_._getDevice); + int ret = nnapi_._getDevice(devIndex,device); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Device_getName(const ANeuralNetworksDevice* device, const char** name) { + CAFFE_ENFORCE(nnapi_.Device_getName); + int ret = nnapi_.Device_getName(device,name); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Device_getVersion(const ANeuralNetworksDevice* device, const char** version) { + CAFFE_ENFORCE(nnapi_.Device_getVersion); + int ret = nnapi_.Device_getVersion(device,version); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Device_getFeatureLevel(const ANeuralNetworksDevice* device, int64_t* featureLevel) { + CAFFE_ENFORCE(nnapi_.Device_getFeatureLevel); + int ret = nnapi_.Device_getFeatureLevel(device,featureLevel); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Model_getSupportedOperationsForDevices( const ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, bool* supportedOps) { + CAFFE_ENFORCE(nnapi_.Model_getSupportedOperationsForDevices); + int ret = nnapi_.Model_getSupportedOperationsForDevices(model,devices,numDevices,supportedOps); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Compilation_createForDevices(ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, ANeuralNetworksCompilation** compilation) { + CAFFE_ENFORCE(nnapi_.Compilation_createForDevices); + int ret = nnapi_.Compilation_createForDevices(model,devices,numDevices,compilation); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Execution_compute(ANeuralNetworksExecution* execution) { + CAFFE_ENFORCE(nnapi_.Execution_compute); + int ret = nnapi_.Execution_compute(execution); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Memory_createFromFd(size_t size, int protect, int fd, size_t offset, ANeuralNetworksMemory** memory) { + CAFFE_ENFORCE(nnapi_.Memory_createFromFd); + int ret = nnapi_.Memory_createFromFd(size,protect,fd,offset,memory); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +void check_Memory_free(ANeuralNetworksMemory* memory) { + CAFFE_ENFORCE(nnapi_.Memory_free); + nnapi_.Memory_free(memory); +} +int check_Model_create(ANeuralNetworksModel** model) { + CAFFE_ENFORCE(nnapi_.Model_create); + int ret = nnapi_.Model_create(model); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +void check_Model_free(ANeuralNetworksModel* model) { + CAFFE_ENFORCE(nnapi_.Model_free); + nnapi_.Model_free(model); +} +int check_Model_finish(ANeuralNetworksModel* model) { + CAFFE_ENFORCE(nnapi_.Model_finish); + int ret = nnapi_.Model_finish(model); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Model_addOperand(ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type) { + CAFFE_ENFORCE(nnapi_.Model_addOperand); + int ret = nnapi_.Model_addOperand(model,type); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Model_setOperandValue(ANeuralNetworksModel* model, int32_t index, const void* buffer, size_t length) { + CAFFE_ENFORCE(nnapi_.Model_setOperandValue); + int ret = nnapi_.Model_setOperandValue(model,index,buffer,length); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Model_setOperandValueFromMemory(ANeuralNetworksModel* model, int32_t index, const ANeuralNetworksMemory* memory, size_t offset, size_t length) { + CAFFE_ENFORCE(nnapi_.Model_setOperandValueFromMemory); + int ret = nnapi_.Model_setOperandValueFromMemory(model,index,memory,offset,length); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Model_addOperation(ANeuralNetworksModel* model, ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs) { + CAFFE_ENFORCE(nnapi_.Model_addOperation); + int ret = nnapi_.Model_addOperation(model,type,inputCount,inputs,outputCount,outputs); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Model_identifyInputsAndOutputs(ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs) { + CAFFE_ENFORCE(nnapi_.Model_identifyInputsAndOutputs); + int ret = nnapi_.Model_identifyInputsAndOutputs(model,inputCount,inputs,outputCount,outputs); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Model_relaxComputationFloat32toFloat16(ANeuralNetworksModel* model, bool allow) { + CAFFE_ENFORCE(nnapi_.Model_relaxComputationFloat32toFloat16); + int ret = nnapi_.Model_relaxComputationFloat32toFloat16(model,allow); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Compilation_create(ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation) { + CAFFE_ENFORCE(nnapi_.Compilation_create); + int ret = nnapi_.Compilation_create(model,compilation); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +void check_Compilation_free(ANeuralNetworksCompilation* compilation) { + CAFFE_ENFORCE(nnapi_.Compilation_free); + nnapi_.Compilation_free(compilation); +} +int check_Compilation_setPreference(ANeuralNetworksCompilation* compilation, int32_t preference) { + CAFFE_ENFORCE(nnapi_.Compilation_setPreference); + int ret = nnapi_.Compilation_setPreference(compilation,preference); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Compilation_finish(ANeuralNetworksCompilation* compilation) { + CAFFE_ENFORCE(nnapi_.Compilation_finish); + int ret = nnapi_.Compilation_finish(compilation); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Execution_create(ANeuralNetworksCompilation* compilation, ANeuralNetworksExecution** execution) { + CAFFE_ENFORCE(nnapi_.Execution_create); + int ret = nnapi_.Execution_create(compilation,execution); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +void check_Execution_free(ANeuralNetworksExecution* execution) { + CAFFE_ENFORCE(nnapi_.Execution_free); + nnapi_.Execution_free(execution); +} +int check_Execution_setInput(ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const void* buffer, size_t length) { + CAFFE_ENFORCE(nnapi_.Execution_setInput); + int ret = nnapi_.Execution_setInput(execution,index,type,buffer,length); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Execution_setInputFromMemory(ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length) { + CAFFE_ENFORCE(nnapi_.Execution_setInputFromMemory); + int ret = nnapi_.Execution_setInputFromMemory(execution,index,type,memory,offset,length); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Execution_setOutput(ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, void* buffer, size_t length) { + CAFFE_ENFORCE(nnapi_.Execution_setOutput); + int ret = nnapi_.Execution_setOutput(execution,index,type,buffer,length); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Execution_setOutputFromMemory(ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length) { + CAFFE_ENFORCE(nnapi_.Execution_setOutputFromMemory); + int ret = nnapi_.Execution_setOutputFromMemory(execution,index,type,memory,offset,length); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Execution_startCompute(ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event) { + CAFFE_ENFORCE(nnapi_.Execution_startCompute); + int ret = nnapi_.Execution_startCompute(execution,event); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Event_wait(ANeuralNetworksEvent* event) { + CAFFE_ENFORCE(nnapi_.Event_wait); + int ret = nnapi_.Event_wait(event); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +void check_Event_free(ANeuralNetworksEvent* event) { + CAFFE_ENFORCE(nnapi_.Event_free); + nnapi_.Event_free(event); +} +int check_Execution_getOutputOperandRank(ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank) { + CAFFE_ENFORCE(nnapi_.Execution_getOutputOperandRank); + int ret = nnapi_.Execution_getOutputOperandRank(execution,index,rank); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Execution_getOutputOperandDimensions(ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions) { + CAFFE_ENFORCE(nnapi_.Execution_getOutputOperandDimensions); + int ret = nnapi_.Execution_getOutputOperandDimensions(execution,index,dimensions); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +void nnapi_wrapper_load(struct nnapi_wrapper** nnapi, struct nnapi_wrapper** check_nnapi) { +#ifdef _WIN32 + TORCH_CHECK(false, "Running NNAPI models is not supported on Windows."); +#else + if (!loaded) { + // Clear error flag. + dlerror(); + void* handle = dlopen("libneuralnetworks.so", RTLD_LAZY | RTLD_LOCAL); + CAFFE_ENFORCE(handle, "Failed to load libneuralnetworks.so ", dlerror()); + *(void**)&nnapi_._getDeviceCount = dlsym(handle, "ANeuralNetworks_getDeviceCount"); + check_nnapi_._getDeviceCount = check__getDeviceCount; + *(void**)&nnapi_._getDevice = dlsym(handle, "ANeuralNetworks_getDevice"); + check_nnapi_._getDevice = check__getDevice; + *(void**)&nnapi_.Device_getName = dlsym(handle, "ANeuralNetworksDevice_getName"); + check_nnapi_.Device_getName = check_Device_getName; + *(void**)&nnapi_.Device_getVersion = dlsym(handle, "ANeuralNetworksDevice_getVersion"); + check_nnapi_.Device_getVersion = check_Device_getVersion; + *(void**)&nnapi_.Device_getFeatureLevel = dlsym(handle, "ANeuralNetworksDevice_getFeatureLevel"); + check_nnapi_.Device_getFeatureLevel = check_Device_getFeatureLevel; + *(void**)&nnapi_.Model_getSupportedOperationsForDevices = dlsym(handle, "ANeuralNetworksModel_getSupportedOperationsForDevices"); + check_nnapi_.Model_getSupportedOperationsForDevices = check_Model_getSupportedOperationsForDevices; + *(void**)&nnapi_.Compilation_createForDevices = dlsym(handle, "ANeuralNetworksCompilation_createForDevices"); + check_nnapi_.Compilation_createForDevices = check_Compilation_createForDevices; + *(void**)&nnapi_.Execution_compute = dlsym(handle, "ANeuralNetworksExecution_compute"); + check_nnapi_.Execution_compute = check_Execution_compute; + *(void**)&nnapi_.Memory_createFromFd = dlsym(handle, "ANeuralNetworksMemory_createFromFd"); + check_nnapi_.Memory_createFromFd = check_Memory_createFromFd; + *(void**)&nnapi_.Memory_free = dlsym(handle, "ANeuralNetworksMemory_free"); + check_nnapi_.Memory_free = check_Memory_free; + *(void**)&nnapi_.Model_create = dlsym(handle, "ANeuralNetworksModel_create"); + check_nnapi_.Model_create = check_Model_create; + *(void**)&nnapi_.Model_free = dlsym(handle, "ANeuralNetworksModel_free"); + check_nnapi_.Model_free = check_Model_free; + *(void**)&nnapi_.Model_finish = dlsym(handle, "ANeuralNetworksModel_finish"); + check_nnapi_.Model_finish = check_Model_finish; + *(void**)&nnapi_.Model_addOperand = dlsym(handle, "ANeuralNetworksModel_addOperand"); + check_nnapi_.Model_addOperand = check_Model_addOperand; + *(void**)&nnapi_.Model_setOperandValue = dlsym(handle, "ANeuralNetworksModel_setOperandValue"); + check_nnapi_.Model_setOperandValue = check_Model_setOperandValue; + *(void**)&nnapi_.Model_setOperandValueFromMemory = dlsym(handle, "ANeuralNetworksModel_setOperandValueFromMemory"); + check_nnapi_.Model_setOperandValueFromMemory = check_Model_setOperandValueFromMemory; + *(void**)&nnapi_.Model_addOperation = dlsym(handle, "ANeuralNetworksModel_addOperation"); + check_nnapi_.Model_addOperation = check_Model_addOperation; + *(void**)&nnapi_.Model_identifyInputsAndOutputs = dlsym(handle, "ANeuralNetworksModel_identifyInputsAndOutputs"); + check_nnapi_.Model_identifyInputsAndOutputs = check_Model_identifyInputsAndOutputs; + *(void**)&nnapi_.Model_relaxComputationFloat32toFloat16 = dlsym(handle, "ANeuralNetworksModel_relaxComputationFloat32toFloat16"); + check_nnapi_.Model_relaxComputationFloat32toFloat16 = check_Model_relaxComputationFloat32toFloat16; + *(void**)&nnapi_.Compilation_create = dlsym(handle, "ANeuralNetworksCompilation_create"); + check_nnapi_.Compilation_create = check_Compilation_create; + *(void**)&nnapi_.Compilation_free = dlsym(handle, "ANeuralNetworksCompilation_free"); + check_nnapi_.Compilation_free = check_Compilation_free; + *(void**)&nnapi_.Compilation_setPreference = dlsym(handle, "ANeuralNetworksCompilation_setPreference"); + check_nnapi_.Compilation_setPreference = check_Compilation_setPreference; + *(void**)&nnapi_.Compilation_finish = dlsym(handle, "ANeuralNetworksCompilation_finish"); + check_nnapi_.Compilation_finish = check_Compilation_finish; + *(void**)&nnapi_.Execution_create = dlsym(handle, "ANeuralNetworksExecution_create"); + check_nnapi_.Execution_create = check_Execution_create; + *(void**)&nnapi_.Execution_free = dlsym(handle, "ANeuralNetworksExecution_free"); + check_nnapi_.Execution_free = check_Execution_free; + *(void**)&nnapi_.Execution_setInput = dlsym(handle, "ANeuralNetworksExecution_setInput"); + check_nnapi_.Execution_setInput = check_Execution_setInput; + *(void**)&nnapi_.Execution_setInputFromMemory = dlsym(handle, "ANeuralNetworksExecution_setInputFromMemory"); + check_nnapi_.Execution_setInputFromMemory = check_Execution_setInputFromMemory; + *(void**)&nnapi_.Execution_setOutput = dlsym(handle, "ANeuralNetworksExecution_setOutput"); + check_nnapi_.Execution_setOutput = check_Execution_setOutput; + *(void**)&nnapi_.Execution_setOutputFromMemory = dlsym(handle, "ANeuralNetworksExecution_setOutputFromMemory"); + check_nnapi_.Execution_setOutputFromMemory = check_Execution_setOutputFromMemory; + *(void**)&nnapi_.Execution_startCompute = dlsym(handle, "ANeuralNetworksExecution_startCompute"); + check_nnapi_.Execution_startCompute = check_Execution_startCompute; + *(void**)&nnapi_.Event_wait = dlsym(handle, "ANeuralNetworksEvent_wait"); + check_nnapi_.Event_wait = check_Event_wait; + *(void**)&nnapi_.Event_free = dlsym(handle, "ANeuralNetworksEvent_free"); + check_nnapi_.Event_free = check_Event_free; + *(void**)&nnapi_.Execution_getOutputOperandRank = dlsym(handle, "ANeuralNetworksExecution_getOutputOperandRank"); + check_nnapi_.Execution_getOutputOperandRank = check_Execution_getOutputOperandRank; + *(void**)&nnapi_.Execution_getOutputOperandDimensions = dlsym(handle, "ANeuralNetworksExecution_getOutputOperandDimensions"); + check_nnapi_.Execution_getOutputOperandDimensions = check_Execution_getOutputOperandDimensions; + loaded = 1; + } + *nnapi = &nnapi_; + *check_nnapi = &check_nnapi_; +#endif +} diff --git a/aten/src/ATen/nnapi/nnapi_wrapper.h b/aten/src/ATen/nnapi/nnapi_wrapper.h new file mode 100644 index 0000000000000..7ab07bfe2b121 --- /dev/null +++ b/aten/src/ATen/nnapi/nnapi_wrapper.h @@ -0,0 +1,62 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is generated by nnapi/codegen.py +#ifndef NNAPI_WRAPPER_H_ +#define NNAPI_WRAPPER_H_ +#include +#include +#include +struct nnapi_wrapper { + int(*_getDeviceCount)(uint32_t* numDevices); + int(*_getDevice)(uint32_t devIndex, ANeuralNetworksDevice** device); + int(*Device_getName)(const ANeuralNetworksDevice* device, const char** name); + int(*Device_getVersion)(const ANeuralNetworksDevice* device, const char** version); + int(*Device_getFeatureLevel)(const ANeuralNetworksDevice* device, int64_t* featureLevel); + int(*Model_getSupportedOperationsForDevices)( const ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, bool* supportedOps); + int(*Compilation_createForDevices)(ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, ANeuralNetworksCompilation** compilation); + int(*Execution_compute)(ANeuralNetworksExecution* execution); + int(*Memory_createFromFd)(size_t size, int protect, int fd, size_t offset, ANeuralNetworksMemory** memory); + void(*Memory_free)(ANeuralNetworksMemory* memory); + int(*Model_create)(ANeuralNetworksModel** model); + void(*Model_free)(ANeuralNetworksModel* model); + int(*Model_finish)(ANeuralNetworksModel* model); + int(*Model_addOperand)(ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type); + int(*Model_setOperandValue)(ANeuralNetworksModel* model, int32_t index, const void* buffer, size_t length); + int(*Model_setOperandValueFromMemory)(ANeuralNetworksModel* model, int32_t index, const ANeuralNetworksMemory* memory, size_t offset, size_t length); + int(*Model_addOperation)(ANeuralNetworksModel* model, ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs); + int(*Model_identifyInputsAndOutputs)(ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs); + int(*Model_relaxComputationFloat32toFloat16)(ANeuralNetworksModel* model, bool allow); + int(*Compilation_create)(ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation); + void(*Compilation_free)(ANeuralNetworksCompilation* compilation); + int(*Compilation_setPreference)(ANeuralNetworksCompilation* compilation, int32_t preference); + int(*Compilation_finish)(ANeuralNetworksCompilation* compilation); + int(*Execution_create)(ANeuralNetworksCompilation* compilation, ANeuralNetworksExecution** execution); + void(*Execution_free)(ANeuralNetworksExecution* execution); + int(*Execution_setInput)(ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const void* buffer, size_t length); + int(*Execution_setInputFromMemory)(ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length); + int(*Execution_setOutput)(ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, void* buffer, size_t length); + int(*Execution_setOutputFromMemory)(ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length); + int(*Execution_startCompute)(ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event); + int(*Event_wait)(ANeuralNetworksEvent* event); + void(*Event_free)(ANeuralNetworksEvent* event); + int(*Execution_getOutputOperandRank)(ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank); + int(*Execution_getOutputOperandDimensions)(ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions); +}; +#ifdef __cplusplus +void nnapi_wrapper_load(struct nnapi_wrapper** nnapi, struct nnapi_wrapper** check_nnapi); +#endif +#endif diff --git a/aten/src/ATen/quantized/QTensorImpl.cpp b/aten/src/ATen/quantized/QTensorImpl.cpp index 40ecbf6d57209..1c79ac186c1ab 100644 --- a/aten/src/ATen/quantized/QTensorImpl.cpp +++ b/aten/src/ATen/quantized/QTensorImpl.cpp @@ -5,7 +5,7 @@ namespace at { QTensorImpl::QTensorImpl( Storage&& storage, DispatchKeySet key_set, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, QuantizerPtr quantizer) : TensorImpl(std::move(storage), key_set, data_type), quantizer_(quantizer) {} diff --git a/aten/src/ATen/quantized/QTensorImpl.h b/aten/src/ATen/quantized/QTensorImpl.h index c2728c7aab466..9c5db9f57f999 100644 --- a/aten/src/ATen/quantized/QTensorImpl.h +++ b/aten/src/ATen/quantized/QTensorImpl.h @@ -13,12 +13,12 @@ namespace at { * * We'll use QTensor in code or documentation to refer to a Tensor with QTensorImpl. */ -struct CAFFE2_API QTensorImpl : public c10::TensorImpl { +struct TORCH_API QTensorImpl : public c10::TensorImpl { public: QTensorImpl( Storage&& storage, DispatchKeySet key_set, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, QuantizerPtr quantizer); // TODO: Expose in PyTorch Frontend @@ -51,6 +51,27 @@ struct CAFFE2_API QTensorImpl : public c10::TensorImpl { return impl; } + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override { + auto impl = c10::make_intrusive( + Storage(storage()), key_set(), data_type_, quantizer_); + copy_tensor_metadata( + /*src_impl=*/this, + /*dest_impl=*/impl.get(), + /*version_counter=*/std::move(version_counter), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + impl->refresh_numel(); + impl->refresh_contiguous(); + return impl; + } + /** * Shallow-copies data from another TensorImpl into this TensorImpl. * diff --git a/aten/src/ATen/quantized/Quantizer.cpp b/aten/src/ATen/quantized/Quantizer.cpp index 1f9225b527709..6647a8bea69ee 100644 --- a/aten/src/ATen/quantized/Quantizer.cpp +++ b/aten/src/ATen/quantized/Quantizer.cpp @@ -77,6 +77,18 @@ QTensorImpl* get_qtensorimpl(const Tensor& self) { return static_cast(self.unsafeGetTensorImpl()); } +int64_t get_sub_byte_tensor_size(int64_t size_bytes, at::ScalarType t) { + int64_t new_size_bytes; + switch(t) { + case at::ScalarType::QUInt4x2: + new_size_bytes = std::ceil(size_bytes * 0.5); + break; + default: + new_size_bytes = size_bytes; + } + return new_size_bytes; +} + inline Tensor new_qtensor( IntArrayRef sizes, const TensorOptions& options, @@ -99,7 +111,9 @@ inline Tensor new_qtensor( TORCH_CHECK( isQIntType(typeMetaToScalarType(dtype)), "ScalarType is not supported in new_qtensor."); - int64_t size_bytes = nelements * dtype.itemsize(); + auto scalar_type = typeMetaToScalarType(dtype); + int64_t size_bytes = get_sub_byte_tensor_size(nelements * dtype.itemsize(), scalar_type); + auto storage = c10::make_intrusive( StorageImpl::use_byte_size_t(), size_bytes, @@ -193,4 +207,8 @@ Tensor PerChannelAffineFloatQParamsQuantizer::dequantize(Tensor qtensor) { Quantizer::~Quantizer() {} +C10_EXPORT void set_quantizer_(const Tensor& self, ConstQuantizerPtr quantizer) { + get_qtensorimpl(self)->set_quantizer_(quantizer); +} + } // namespace at diff --git a/aten/src/ATen/quantized/Quantizer.h b/aten/src/ATen/quantized/Quantizer.h index e592d805fd203..1c740b774a008 100644 --- a/aten/src/ATen/quantized/Quantizer.h +++ b/aten/src/ATen/quantized/Quantizer.h @@ -24,7 +24,7 @@ namespace at { * the quantized value. For example, affine quantizer is * the most commonly used scheme in this category. */ -struct CAFFE2_API UniformQuantizer : public Quantizer { +struct TORCH_API UniformQuantizer : public Quantizer { explicit UniformQuantizer(ScalarType scalar_type) : Quantizer(scalar_type) {} }; @@ -33,7 +33,7 @@ struct CAFFE2_API UniformQuantizer : public Quantizer { * These quantization scheme may map float value non-uniformly to the quantized * value. K-means quantization is a representative example in this category. */ -struct CAFFE2_API NonUniformQuantizer : public Quantizer { +struct TORCH_API NonUniformQuantizer : public Quantizer { explicit NonUniformQuantizer(ScalarType scalar_type) : Quantizer(scalar_type) {} }; @@ -47,7 +47,7 @@ struct CAFFE2_API NonUniformQuantizer : public Quantizer { * For dequantize: * X = (Y - zero_point) * scale */ -struct CAFFE2_API AffineQuantizer : public UniformQuantizer { +struct TORCH_API AffineQuantizer : public UniformQuantizer { explicit AffineQuantizer(ScalarType scalar_type) : UniformQuantizer(scalar_type) {} }; @@ -58,7 +58,7 @@ struct CAFFE2_API AffineQuantizer : public UniformQuantizer { * PerTensorAffineQuantizer stores a scale and a zero_point, which is used for * all the values in the Tensor. */ -struct CAFFE2_API PerTensorAffineQuantizer : public AffineQuantizer { +struct TORCH_API PerTensorAffineQuantizer : public AffineQuantizer { explicit PerTensorAffineQuantizer(ScalarType scalar_type, double scale, int64_t zero_point) : AffineQuantizer(scalar_type), scale_(scale), @@ -107,7 +107,7 @@ struct CAFFE2_API PerTensorAffineQuantizer : public AffineQuantizer { * processors since it requires each multiplication result within a single * dot-product to have a different scale. */ -struct CAFFE2_API PerChannelAffineQuantizer : public AffineQuantizer { +struct TORCH_API PerChannelAffineQuantizer : public AffineQuantizer { explicit PerChannelAffineQuantizer( ScalarType scalar_type, Tensor scales, @@ -169,7 +169,7 @@ struct CAFFE2_API PerChannelAffineQuantizer : public AffineQuantizer { * be exactly represented in the quantized space. We can get additional precision by * using floating point values for zero point. */ -struct CAFFE2_API PerChannelAffineFloatQParamsQuantizer : public PerChannelAffineQuantizer { +struct TORCH_API PerChannelAffineFloatQParamsQuantizer : public PerChannelAffineQuantizer { explicit PerChannelAffineFloatQParamsQuantizer( ScalarType scalar_type, Tensor scales, @@ -205,24 +205,26 @@ struct CAFFE2_API PerChannelAffineFloatQParamsQuantizer : public PerChannelAffin // setters/getters for QTensorImpl fields; otherwise, you should use // the low level setters/getters that were implemented using this. // This may be called repeatedly, so make sure it's pretty cheap. -CAFFE2_API QTensorImpl* get_qtensorimpl(const Tensor& self); +TORCH_API QTensorImpl* get_qtensorimpl(const Tensor& self); // double and int64_t are because of the native function API, we only have these // argument types right now in native functions -CAFFE2_API QuantizerPtr +TORCH_API QuantizerPtr make_per_tensor_affine_quantizer( double scale, int64_t zero_point, ScalarType scalar_type); -CAFFE2_API QuantizerPtr make_per_channel_affine_quantizer( +TORCH_API QuantizerPtr make_per_channel_affine_quantizer( const Tensor& scales, const Tensor& zero_points, int64_t axis, ScalarType scalar_type); // Create a Quantized Tensor given arguments for normal Tensor and a quantizer -CAFFE2_API Tensor new_qtensor( +TORCH_API Tensor new_qtensor( IntArrayRef sizes, const TensorOptions& options, QuantizerPtr quantizer); +TORCH_API void set_quantizer_(const Tensor& self, ConstQuantizerPtr quantizer); + } // namespace at diff --git a/aten/src/ATen/record_function.cpp b/aten/src/ATen/record_function.cpp index 8bf25c3cac2f2..a75b1a1295dbe 100644 --- a/aten/src/ATen/record_function.cpp +++ b/aten/src/ATen/record_function.cpp @@ -1,4 +1,6 @@ #include +#include +#include #include #include #include @@ -9,18 +11,16 @@ namespace { // Used to generate unique callback handles CallbackHandle next_unique_callback_handle() { - static std::atomic unique_cb_id {0}; - return CallbackHandle(++unique_cb_id); + static std::atomic unique_cb_id {1}; + return CallbackHandle(unique_cb_id++); } RecordFunctionHandle next_unique_record_function_handle() { - static std::atomic unique_rf_id {0}; - return RecordFunctionHandle(++unique_rf_id); + static std::atomic unique_rf_id {1}; + return RecordFunctionHandle(unique_rf_id++); } -// Thread local vector of callbacks, holds pairs (callbacks, unique_id); -// must be sorted in increasing handles order -thread_local RecordFunctionCallbacks sorted_tls_callbacks_; +thread_local RecordFunctionTLS rf_tls_; std::atomic defaultNodeId(-1); @@ -30,39 +30,58 @@ std::atomic defaultNodeId(-1); std::atomic next_thread_id_ {0}; thread_local uint64_t current_thread_id_ = 0; -thread_local bool tls_record_function_enabled_ = true; - // Low probability constant -const double kLowProb = 0.001; -thread_local int tries_left_ = 0; +static const double kLowProb = 0.001; +struct CoinflipTLS { + int tries_left_; + std::mt19937 genGeo_; + std::mt19937 genZeroOne_; + std::geometric_distribution distGeo_; + std::uniform_real_distribution distZeroOne_; + CoinflipTLS(); +}; + +CoinflipTLS::CoinflipTLS() + : tries_left_(0), genGeo_(std::random_device()()), genZeroOne_(std::random_device()()), distGeo_(kLowProb), distZeroOne_(0.0, 1.0) {} +thread_local CoinflipTLS coinflip_tls_; int sample_geometric() { - static thread_local auto gen = - std::make_unique(std::random_device()()); - std::geometric_distribution dist(kLowProb); - return dist(*gen); + return coinflip_tls_.distGeo_(coinflip_tls_.genGeo_); } double sample_zero_one() { - static thread_local auto gen = - std::make_unique(std::random_device()()); - std::uniform_real_distribution dist(0.0, 1.0); - return dist(*gen); + return coinflip_tls_.distZeroOne_(coinflip_tls_.genZeroOne_); } } // namespace +const RecordFunctionTLS& get_record_function_tls_() { + return rf_tls_; +} + +void set_record_function_tls_(const RecordFunctionTLS& tls) { + rf_tls_ = tls; +} + class CallbackManager { public: CallbackHandle addThreadLocalCallback(RecordFunctionCallback cb) { + if (cb.samplingProb() > kLowProb) { + // pre-sampling of RecordFunction with prob. kLowProb cannot be used + at::bumpRecordAllFunctions(); + } // note: monotonically increasing callbacks_unique_id keeps // sorted_tls_callbacks_ sorted auto handle = next_unique_callback_handle(); - sorted_tls_callbacks_.emplace_back(std::move(cb), handle); + rf_tls_.sorted_tls_callbacks_.emplace_back(std::move(cb), handle); return handle; } CallbackHandle addGlobalCallback(RecordFunctionCallback cb) { + if (cb.samplingProb() > kLowProb) { + // pre-sampling of RecordFunction with prob. kLowProb cannot be used + at::bumpRecordAllFunctions(); + } auto handle = next_unique_callback_handle(); sorted_global_callbacks_.emplace_back(std::move(cb), handle); return handle; @@ -79,13 +98,17 @@ class CallbackManager { return el.second == handle; }); if (it != cbs.end()) { + if (it->first.samplingProb() > kLowProb) { + // try to restore pre-sampling of RecordFunction + at::releaseRecordAllFunctions(); + } // keeps it sorted cbs.erase(it); return true; } return false; }; - auto found = find_and_remove(sorted_tls_callbacks_); + auto found = find_and_remove(rf_tls_.sorted_tls_callbacks_); if (!found) { found = find_and_remove(sorted_global_callbacks_); } @@ -99,7 +122,7 @@ class CallbackManager { } void clearThreadLocalCallbacks() { - sorted_tls_callbacks_.clear(); + rf_tls_.sorted_tls_callbacks_.clear(); } inline bool hasGlobalCallbacks() const { @@ -107,45 +130,107 @@ class CallbackManager { } inline bool hasThreadLocalCallbacks() const { - return !sorted_tls_callbacks_.empty(); + return !rf_tls_.sorted_tls_callbacks_.empty(); + } + + // We need this function to be inlined: init() is a hot path and + // callbackShouldRun is even hotter because it's called multiple + // times per init(). Profiling shows that the function prologue is + // taking up a significant fraction of the time. + static bool C10_ALWAYS_INLINE callbackShouldRun( + const RecordFunctionCallback& cb, RecordScope scope, bool pre_sampled) { + TORCH_INTERNAL_ASSERT( + !pre_sampled || (cb.sampling_prob_ <= kLowProb), + "Incorrect usage of a pre-sampled RecordFunction with a high-frequency " + " or non-sampled callback"); + + // first check whether this callback is interested in + // the given scope type + if (!cb.checkScope(scope)) { + return false; + } + // if we have registered should_run_ function, use it + if (cb.should_run_) { + return cb.should_run_(cb); + } + + // otherwise potentially do the sampling + double sampling_prob = cb.sampling_prob_; + if (pre_sampled) { + // adjust the sampling rate to account for kLowProb pre-sampling of + // the RecordFunction + sampling_prob /= kLowProb; + } + + if (sampling_prob < 1.0) { + // model the low probability events as events happening + // with probability kLowProb followed by another sampling with + // probability (sampling_prob / kLowProb), then replace the coin + // flip for kLowProb with a thread local number of tries tries_left_ + // sampled from the geometric distribution. + if (sampling_prob < kLowProb) { + if (coinflip_tls_.tries_left_ == 0) { + coinflip_tls_.tries_left_ = sample_geometric(); + return (sample_zero_one() < sampling_prob / kLowProb); + } else { + --coinflip_tls_.tries_left_; + return false; + } + } else { + return (sample_zero_one() < sampling_prob); + } + } + + return true; } // init is called by RecordFunction in constructor to // determine which thread local and global callbacks are going // to be executed and whether any of them need inputs - inline void init(RecordFunction& rec_fn) { - auto scope = rec_fn.scope(); - bool found_active_cb = false; + inline void init(RecordFunction& rec_fn, RecordScope scope, bool pre_sampled) { bool found_needs_inputs = false; bool found_needs_ids = false; - auto init_handles = [ - scope, &found_active_cb, &found_needs_inputs, &found_needs_ids]( - CallbackHandles& handles, RecordFunctionCallbacks& cbs, ObserverContextList& ctx_list) { - handles.clear(); - - size_t num_callbacks = 0; - for (const auto& cb : cbs) { - if (cb.first.shouldRun(scope)) { - handles.push_back(cb.second); - ++num_callbacks; - found_active_cb = true; - if (cb.first.needsInputs()) { - found_needs_inputs = true; - } - if (cb.first.needsIds()) { - found_needs_ids = true; - } + + for (const auto& cb: rf_tls_.sorted_tls_callbacks_) { + if (callbackShouldRun(cb.first, scope, pre_sampled)) { + if (cb.first.needsInputs()) { + found_needs_inputs = true; + } + if (cb.first.needsIds()) { + found_needs_ids = true; } + if (!rec_fn.state_) { + rec_fn.state_ = std::make_unique(scope); + } + rec_fn.state_->sorted_active_tls_handles_.push_back(cb.second); } - // Pre-allocate observer context list with nullptr. - ctx_list.resize(num_callbacks); - }; + } + + for (const auto& cb: sorted_global_callbacks_) { + if (callbackShouldRun(cb.first, scope, pre_sampled)) { + if (cb.first.needsInputs()) { + found_needs_inputs = true; + } + if (cb.first.needsIds()) { + found_needs_ids = true; + } + if (!rec_fn.state_) { + rec_fn.state_ = std::make_unique(scope); + } + rec_fn.state_->sorted_active_global_handles_.push_back(cb.second); + } + } + + if (!rec_fn.state_) { + return; + } + + // Pre-allocate observer context list with nullptr. + rec_fn.state_->tls_ctx_.resize(rec_fn.state_->sorted_active_tls_handles_.size()); + rec_fn.state_->global_ctx_.resize(rec_fn.state_->sorted_active_global_handles_.size()); - init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_, rec_fn.tls_ctx_); - init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_, rec_fn.global_ctx_); - rec_fn.active = found_active_cb; - rec_fn.needs_inputs = found_needs_inputs; - if (found_needs_ids && found_active_cb) { + rec_fn.state_->needs_inputs = found_needs_inputs; + if (found_needs_ids) { rec_fn.setHandle(next_unique_record_function_handle()); } } @@ -153,34 +238,37 @@ class CallbackManager { void runStartCallbacks(RecordFunction& rf) { mergeRunCallbacks( sorted_global_callbacks_, - rf.sorted_active_global_handles_, - rf.global_ctx_, + rf.state_->sorted_active_global_handles_, + rf.state_->global_ctx_, /* is_start */ true, rf); mergeRunCallbacks( - sorted_tls_callbacks_, - rf.sorted_active_tls_handles_, - rf.tls_ctx_, + rf_tls_.sorted_tls_callbacks_, + rf.state_->sorted_active_tls_handles_, + rf.state_->tls_ctx_, /* is_start */ true, rf); - rf.called_start_callbacks_ = true; + rf.state_->called_start_callbacks_ = true; } void runEndCallbacks(RecordFunction& rf) { mergeRunCallbacks( sorted_global_callbacks_, - rf.sorted_active_global_handles_, - rf.global_ctx_, + rf.state_->sorted_active_global_handles_, + rf.state_->global_ctx_, /* is_start */ false, rf); mergeRunCallbacks( - sorted_tls_callbacks_, - rf.sorted_active_tls_handles_, - rf.tls_ctx_, + rf_tls_.sorted_tls_callbacks_, + rf.state_->sorted_active_tls_handles_, + rf.state_->tls_ctx_, /* is_start */ false, rf); } + // Global callbacks; must be sorted in increasing handle order + RecordFunctionCallbacks sorted_global_callbacks_; + private: bool tryRunCallback( const RecordFunctionCallback& rfcb, @@ -189,10 +277,12 @@ class CallbackManager { bool is_start) { try { if (is_start) { - ctx = rfcb.start()(rf); + ctx = rfcb.start() ? rfcb.start()(rf) : nullptr; } else { - rfcb.end()(rf, ctx.get()); + if (rfcb.end()) { + rfcb.end()(rf, ctx.get()); + } } return true; } catch (const std::exception &e) { @@ -235,9 +325,6 @@ class CallbackManager { << "the code after profiler is finished"; } } - - // Global callbacks; must be sorted in increasing handle order - RecordFunctionCallbacks sorted_global_callbacks_; }; namespace { @@ -248,48 +335,16 @@ namespace { } } // namespace -bool RecordFunctionCallback::shouldRun(RecordScope scope) const { - // first check whether this callback is interested in - // the given scope type - if (!checkScope(scope)) { - return false; - } - // if we have registered should_run_ function, use it - if (should_run_) { - return should_run_(*this); - } - // otherwise potentially do the uniform sampling - if (sampling_prob_ != 1.0) { - // model the low probability events as events happening - // with prob. kLowProb followed by another sampling with - // prob. (sampling_prob_ / kLowProb), then replace the coin - // flip for kLowProb with a thread local number of tries tries_left_ - // sampled from the geometric distribution - if (sampling_prob_ < kLowProb) { - if (tries_left_ == 0) { - tries_left_ = sample_geometric(); - return (sample_zero_one() < sampling_prob_ / kLowProb); - } else { - --tries_left_; - return false; - } - } else { - return (sample_zero_one() < sampling_prob_); - } - } - return true; -} - RecordFunctionCallbacks _getTLSCallbacks() { - return sorted_tls_callbacks_; + return rf_tls_.sorted_tls_callbacks_; } void _setTLSCallbacks(const RecordFunctionCallbacks& callbacks) { // keep the original handles - sorted_tls_callbacks_ = callbacks; + rf_tls_.sorted_tls_callbacks_ = callbacks; std::sort( - sorted_tls_callbacks_.begin(), - sorted_tls_callbacks_.end(), + rf_tls_.sorted_tls_callbacks_.begin(), + rf_tls_.sorted_tls_callbacks_.end(), [](const std::pair& l, const std::pair& r) { return l.second < r.second; @@ -338,16 +393,20 @@ void clearCallbacks() { } bool isRecordFunctionEnabled() { - return tls_record_function_enabled_; + return rf_tls_.tls_record_function_enabled_; } void enableRecordFunction(bool enable) { - tls_record_function_enabled_ = enable; + rf_tls_.tls_record_function_enabled_ = enable; } -RecordFunction::RecordFunction(RecordScope scope) : scope_(scope) { - if (hasCallbacks() && isRecordFunctionEnabled()) { - manager().init(*this); +RecordFunction::RecordFunction(RecordScope scope, bool pre_sampled) { + auto* rf_tls_ptr = &rf_tls_; + if (rf_tls_ptr->tls_record_function_enabled_) { + auto& m = manager(); + if (!m.sorted_global_callbacks_.empty() || !rf_tls_ptr->sorted_tls_callbacks_.empty()) { + m.init(*this, scope, pre_sampled); + } } } @@ -361,23 +420,39 @@ uint64_t RecordFunction::currentThreadId() { } void RecordFunction::before(const char* name, int64_t sequence_nr) { - if (!active) { + if (!isActive()) { return; } - name_ = StringView(name); - sequence_nr_ = sequence_nr; - thread_id_ = currentThreadId(); + state_->name_ = StringView(name); + state_->sequence_nr_ = sequence_nr; + state_->thread_id_ = currentThreadId(); + state_->operator_name_.reset(); manager().runStartCallbacks(*this); } void RecordFunction::before(std::string name, int64_t sequence_nr) { - if (!active) { + if (!isActive()) { + return; + } + state_->name_ = StringView(std::move(name)); + state_->sequence_nr_ = sequence_nr; + state_->thread_id_ = currentThreadId(); + state_->operator_name_.reset(); + + manager().runStartCallbacks(*this); +} + +void RecordFunction::before( + c10::OperatorHandle const& op, + int64_t sequence_nr) { + if (!isActive()) { return; } - name_ = StringView(std::move(name)); - sequence_nr_ = sequence_nr; - thread_id_ = currentThreadId(); + state_->sequence_nr_ = sequence_nr; + state_->thread_id_ = currentThreadId(); + state_->operator_name_ = op.operator_name(); + state_->name_ = StringView(op.schema().name()); manager().runStartCallbacks(*this); } @@ -396,10 +471,55 @@ RecordFunction::~RecordFunction() { } void RecordFunction::end() { - if (active && called_start_callbacks_) { + if (isActive() && state_->called_start_callbacks_) { manager().runEndCallbacks(*this); + state_.reset(); + } +} + +// RecordFunction pre-sampling +namespace { +// Whether to try to create RecordFunction on each call (>0) or +// use pre-sampling (=0) +std::atomic global_record_all_functions_ {0}; +} + +void bumpRecordAllFunctions() { + global_record_all_functions_.fetch_add(1, std::memory_order_relaxed); +} + +void releaseRecordAllFunctions() { + TORCH_CHECK(global_record_all_functions_.fetch_sub(1, std::memory_order_relaxed) >= 0); +} + +bool checkRecordAllFunctions() { + return (global_record_all_functions_.load(std::memory_order_relaxed) > 0); +} + +bool shouldRunRecordFunction(bool* pre_sampled) { + auto* rf_tls_ptr = &rf_tls_; + if (rf_tls_ptr->sorted_tls_callbacks_.empty() && !manager().hasGlobalCallbacks()) { + *pre_sampled = false; + return false; + } + if (global_record_all_functions_.load(std::memory_order_relaxed) > 0) { + *pre_sampled = false; + return true; + } + if (!rf_tls_ptr->tls_record_function_enabled_) { + *pre_sampled = false; + return false; + } + + *pre_sampled = true; + auto* coinflip_tls_ptr = &coinflip_tls_; + if (coinflip_tls_ptr->tries_left_ == 0) { + coinflip_tls_ptr->tries_left_ = sample_geometric(); + return true; + } else { + --coinflip_tls_ptr->tries_left_; + return false; } - active = false; } } // namespace at diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index 9b4d11ef1d5f3..96b22150082b4 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -1,20 +1,30 @@ #pragma once #include -#include +#include #include +#include +#include #include #include +namespace c10 { +class TORCH_API OperatorHandle; +} + namespace at { // Kind of record function scope; enum class C10_API_ENUM RecordScope : uint8_t { // c10/ATen ops, autograd nodes FUNCTION = 0, + // Functions/nodes called from the autograd + BACKWARD_FUNCTION, // TorchScript functions, methods TORCHSCRIPT_FUNCTION, + // Kernel Function dtype Tag + KERNEL_FUNCTION_DTYPE, // User defined scope (e.g. with record_function()) USER_SCOPE, NUM_SCOPES, // must be the last in the list @@ -82,15 +92,21 @@ typedef uint64_t RecordFunctionHandle; struct TORCH_API RecordFunction { // Default constructor is used with before function called afterwards: // scope - record scope that this function tracks + // pre_sampled - whether this RecordFunction was already pre-sampled with + // kLowProb probability RecordFunction( - RecordScope scope = RecordScope::FUNCTION); + RecordScope scope = RecordScope::FUNCTION, + bool pre_sampled = false); template void before( F fn, const std::vector* args, int64_t current_sequence_nr = -1) { - inputs_ = *args; + if (!isActive()) { + return; + } + state_->inputs_ = *args; before(fn, current_sequence_nr); } @@ -101,26 +117,45 @@ struct TORCH_API RecordFunction { RecordFunction& operator=(const RecordFunction&) = delete; inline const StringView& name() const { - return name_; + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called name() on inactive RecordFunction"); + return state_->name_; } inline int64_t seqNr() const { - return sequence_nr_; + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called seqNr() on inactive RecordFunction"); + return state_->sequence_nr_; } const std::vector& inputs() const { - return inputs_; + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called inputs() on inactive RecordFunction"); + return state_->inputs_; } // Retrieves the thread_id that this RecordFunction ran start callbacks with. // Useful for writing thread safe end callbacks that may be potentially // executed in a different thread (async ops) - inline uint64_t getStartCallbacksThreadId() const { - return thread_id_; + inline uint64_t threadId() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called threadId() on inactive RecordFunction"); + return state_->thread_id_; + } + + // For backward functions - thread id of the corresponding forward function, + // or zero otherwise; + // used alongside with sequence number to correlate backward functions with + // the forward ones + inline uint64_t forwardThreadId() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called forwardThreadId() on inactive RecordFunction"); + return state_->fwd_thread_id_; + } + + inline void setForwardThreadId(uint64_t thread_id) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called setForwardThreadId() on inactive RecordFunction"); + state_->fwd_thread_id_ = thread_id; } inline RecordScope scope() const { - return scope_; + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called scope() on inactive RecordFunction"); + return state_->scope_; } // Returns logical thread_id for the current thread @@ -133,6 +168,7 @@ struct TORCH_API RecordFunction { // start callbacks void before(const char* name, int64_t sequence_nr = -1); void before(std::string name, int64_t sequence_nr = -1); + void before(c10::OperatorHandle const& op, int64_t sequence_nr = -1); // Sets node ID for distributed profiling static void setDefaultNodeId(int64_t defaultNodeId); @@ -144,7 +180,10 @@ struct TORCH_API RecordFunction { F fn, c10::ArrayRef args, int64_t current_sequence_nr = -1) { - inputs_ = args.vec(); + if (!isActive()) { + return; + } + state_->inputs_ = args.vec(); before(fn, current_sequence_nr); } @@ -153,61 +192,94 @@ struct TORCH_API RecordFunction { F fn, std::vector&& args, int64_t current_sequence_nr = -1) { - inputs_ = std::move(args); + if (!isActive()) { + return; + } + state_->inputs_ = std::move(args); before(fn, current_sequence_nr); } - // Calls end callbacks + // Calls end callbacks. After end(), accessors will no longer provide useful results. void end(); inline RecordFunctionHandle handle() const { - return handle_; + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called handle() on inactive RecordFunction"); + return state_->handle_; + } + + inline c10::optional operator_name() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called operator_name() on inactive RecordFunction"); + return state_->operator_name_; } inline void setHandle(RecordFunctionHandle handle) { - handle_ = handle; + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called setHandle() on inactive RecordFunction"); + state_->handle_ = handle; + } + + // Whether this RecordFunction runs any callbacks. + bool isActive() const { + return state_ != nullptr; } - // Whether this RecordFunction runs any callbacks - bool active = false; - // Whether any of the picked callbacks require inputs - bool needs_inputs = false; + bool needsInputs() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called needsInputs() on inactive RecordFunction"); + return state_->needs_inputs; + } private: + // Allows the modification of some internal states for callbacks. friend class CallbackManager; - // Used internally to keep track of thread local and global callbacks - // that were picked to run; must be sorted; - CallbackHandles sorted_active_tls_handles_; - CallbackHandles sorted_active_global_handles_; + struct State { + explicit State(RecordScope scope) : scope_(scope) {} + + // Whether any of the picked callbacks require inputs + bool needs_inputs = false; + + // In cases when RecordFunction might be active but we chose not to + // use the observers (e.g. operator is not observed), this boolean + // flag is used to check whether the start callbacks were called + bool called_start_callbacks_ = false; + + // Whether the RecordFunction is pre-sampled + bool pre_sampled_ = false; + + // Used internally to keep track of thread local and global callbacks + // that were picked to run; must be sorted; + CallbackHandles sorted_active_tls_handles_; + CallbackHandles sorted_active_global_handles_; + + // Stores various ObserverContext objects with event metadata for thread local + // callbacks. + ObserverContextList tls_ctx_; + + // Stores various ObserverContext objects with event metadata for global + // callbacks. + ObserverContextList global_ctx_; - // Stores various ObserverContext objects with event metadata for thread local - // callbacks. - ObserverContextList tls_ctx_; + StringView name_; + int64_t sequence_nr_ = -1; + std::vector inputs_; - // Stores various ObserverContext objects with event metadata for global - // callbacks. - ObserverContextList global_ctx_; + c10::optional operator_name_; - // In cases when RecordFunction might be active but we chose not to - // use the observers (e.g. operator is not observed), this boolean - // flag is used to check whether the start callbacks were called - bool called_start_callbacks_ = false; + // Kind of scope this RecordFunction is observing + const RecordScope scope_; - StringView name_; - int64_t sequence_nr_ = -1; - std::vector inputs_; + // The logical thread_id that this RecordFunction was created with + uint64_t thread_id_ = 0; - // Kind of scope this RecordFunction is observing - const RecordScope scope_; + // For backward functions - thread id of the the forward function + uint64_t fwd_thread_id_ = 0; - // The logical thread_id that this RecordFunction was created with - uint64_t thread_id_ = 0; + // Unique id for this RecordFunction, used in callbacks to track start + // and end of ranges + RecordFunctionHandle handle_ {0}; + }; - // Unique id for this RecordFunction, used in callbacks to track start - // and end of ranges - RecordFunctionHandle handle_ {0}; + std::unique_ptr state_; }; // @@ -233,25 +305,16 @@ struct TORCH_API RecordFunction { */ class TORCH_API RecordFunctionCallback { public: - // This interface supports observers that require passing an ObserverContext - // between start and end callbacks. - explicit RecordFunctionCallback( - std::function(const RecordFunction&)> start, - std::function end = - [](const RecordFunction&, ObserverContext*) {}): - start_(std::move(start)), - end_(std::move(end)) { - scopes_.fill(true); - } + using StartCallback = std::unique_ptr(*)(const RecordFunction&); + using EndCallback = void (*)(const RecordFunction&, ObserverContext*); - // This interface is for observers that do not pass an ObserverContext object + // This interface supports observers that require passing an ObserverContext // between start and end callbacks. explicit RecordFunctionCallback( - std::function start, - std::function end = - [](const RecordFunction&) {}): - start_{[start](const RecordFunction& rf) { start(rf); return nullptr; }}, - end_{[end](const RecordFunction& rf, ObserverContext*) { end(rf); }} { + StartCallback start, + EndCallback end = nullptr) : + start_(start), + end_(end) { scopes_.fill(true); } @@ -266,7 +329,7 @@ class TORCH_API RecordFunctionCallback { } RecordFunctionCallback& samplingProb(double sampling_prob) { - TORCH_CHECK(sampling_prob >= 0.0 && sampling_prob_ <= 1.0, + TORCH_CHECK(sampling_prob >= 0.0 && sampling_prob <= 1.0, "Invalid sampling probability"); sampling_prob_ = sampling_prob; return *this; @@ -286,8 +349,8 @@ class TORCH_API RecordFunctionCallback { } RecordFunctionCallback& setShouldRun( - std::function should_run) { - should_run_ = std::move(should_run); + bool(*should_run)(const RecordFunctionCallback&)) { + should_run_ = should_run; return *this; } @@ -307,33 +370,31 @@ class TORCH_API RecordFunctionCallback { return scopes_[(size_t)sc]; } - inline const std::function(const RecordFunction&)>& start() const { + inline StartCallback start() const { return start_; } - inline const std::function& end() const { + inline EndCallback end() const { return end_; } - // whether the callbacks should run in the given scope - bool shouldRun(RecordScope scope) const; - private: - std::function(const RecordFunction&)> start_; - std::function end_; - std::function should_run_; - bool needs_inputs_ = false; - bool needs_ids_ = false; + friend class CallbackManager; + StartCallback start_; + EndCallback end_; + bool(*should_run_)(const RecordFunctionCallback&) = nullptr; double sampling_prob_ = 1.0; std::array(RecordScope::NUM_SCOPES)> scopes_ = {}; + bool needs_inputs_ = false; + bool needs_ids_ = false; }; // Using macro to minimize inputs copies, // optional argument - function's seq_no #define RECORD_FUNCTION_WITH_SCOPE(scope, fn, inputs, ...) \ at::RecordFunction guard(scope); \ - if (guard.active) { \ - if (guard.needs_inputs) { \ + if (guard.isActive()) { \ + if (guard.needsInputs()) { \ guard.before(fn, inputs, ##__VA_ARGS__); \ } else { \ guard.before(fn, ##__VA_ARGS__); \ @@ -354,6 +415,11 @@ class TORCH_API RecordFunctionCallback { RECORD_FUNCTION_WITH_SCOPE( \ at::RecordScope::USER_SCOPE, fn, {}) +// RECORD_USER_SCOPE with inputs +#define RECORD_USER_SCOPE_WITH_INPUTS(fn, inputs) \ + RECORD_FUNCTION_WITH_SCOPE( \ + at::RecordScope::USER_SCOPE, fn, inputs) + // Notes: // - two types of callbacks are provided: thread local and global // - thread local callbacks are added/removed only for the given thread @@ -471,4 +537,33 @@ class TORCH_API DisableRecordFunctionGuard : public RecordFunctionGuard { TORCH_API RecordFunctionCallbacks _getTLSCallbacks(); TORCH_API void _setTLSCallbacks(const RecordFunctionCallbacks& callbacks); +struct TORCH_API RecordFunctionTLS { + // Thread local vector of callbacks, holds pairs (callbacks, unique_id); + // must be sorted in increasing handles order + RecordFunctionCallbacks sorted_tls_callbacks_; + + bool tls_record_function_enabled_ = true; + + // Stores the number of coin flips before the next successful coin flip + int tries_left_ = 0; +}; + +TORCH_API const RecordFunctionTLS& get_record_function_tls_(); + +TORCH_API void set_record_function_tls_(const RecordFunctionTLS& tls); + +// Checks whether RecordFunction should be called, +// sets boolean pointed by the argument to whether pre-sampling was used +TORCH_API bool shouldRunRecordFunction(bool*); + +// The following functions are used to disable/enable pre-sampling of RecordFunction +// when high-frequency/non-sampled callbacks are added/removed. +// Note: every call to bumpRecordAllFunctions() is supposed to be matched with +// the corresponding releaseRecordAllFunctions() call. +// Note: disabling pre-sampling of RecordFunction incurs an extra overhead, since +// RecordFunction will be created for each operator call. +TORCH_API void bumpRecordAllFunctions(); +TORCH_API void releaseRecordAllFunctions(); +TORCH_API bool checkRecordAllFunctions(); + } // namespace at diff --git a/aten/src/ATen/templates/Functions.cpp b/aten/src/ATen/templates/Functions.cpp index 7c9aa96f6e70e..37c2919bb4586 100644 --- a/aten/src/ATen/templates/Functions.cpp +++ b/aten/src/ATen/templates/Functions.cpp @@ -3,17 +3,25 @@ #include #include -#include -#include -#include #include -#ifdef USE_VULKAN -#include -#endif namespace at { -using native::tensor; +Tensor var(const Tensor& self, int dim) { + return at::var(self, IntArrayRef{dim}); +} + +std::tuple var_mean(const Tensor& self, int dim) { + return at::var_mean(self, IntArrayRef{dim}); +} + +Tensor std(const Tensor& self, int dim) { + return at::std(self, IntArrayRef{dim}); +} + +std::tuple std_mean(const Tensor& self, int dim) { + return at::std_mean(self, IntArrayRef{dim}); +} ${function_definitions} diff --git a/aten/src/ATen/templates/Functions.h b/aten/src/ATen/templates/Functions.h index 81e46642ad581..767ca01a7302d 100644 --- a/aten/src/ATen/templates/Functions.h +++ b/aten/src/ATen/templates/Functions.h @@ -7,7 +7,6 @@ #include #include #include -#include // TODO: try to delete this #include #include #include @@ -19,28 +18,38 @@ namespace at { -using native::tensor; +// These functions are defined in ATen/Utils.cpp. +#define TENSOR(T, S) \ + TORCH_API Tensor tensor(ArrayRef values, const TensorOptions& options); \ + inline Tensor tensor( \ + std::initializer_list values, const TensorOptions& options) { \ + return at::tensor(ArrayRef(values), options); \ + } \ + inline Tensor tensor(T value, const TensorOptions& options) { \ + return at::tensor(ArrayRef(value), options); \ + } \ + inline Tensor tensor(ArrayRef values) { \ + return at::tensor(std::move(values), at::dtype(k##S)); \ + } \ + inline Tensor tensor(std::initializer_list values) { \ + return at::tensor(ArrayRef(values)); \ + } \ + inline Tensor tensor(T value) { \ + return at::tensor(ArrayRef(value)); \ + } +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) +AT_FORALL_COMPLEX_TYPES(TENSOR) +#undef TENSOR ${function_declarations} // Special C++ only overloads for std()-like functions (See gh-40287) // These are needed because int -> bool conversion takes precedence over int -> IntArrayRef // So, for example std(0) would select the std(unbiased=False) overload -inline Tensor var(const Tensor& self, int dim) { - return at::native::var(self, IntArrayRef{dim}); -} - -inline std::tuple var_mean(const Tensor& self, int dim) { - return at::native::var_mean(self, IntArrayRef{dim}); -} - -inline Tensor std(const Tensor& self, int dim) { - return at::native::std(self, IntArrayRef{dim}); -} - -inline std::tuple std_mean(const Tensor& self, int dim) { - return at::native::std_mean(self, IntArrayRef{dim}); -} +TORCH_API Tensor var(const Tensor& self, int dim); +TORCH_API std::tuple var_mean(const Tensor& self, int dim); +TORCH_API Tensor std(const Tensor& self, int dim); +TORCH_API std::tuple std_mean(const Tensor& self, int dim); namespace { inline std::vector zero_sizes(const TensorOptions& options) { @@ -125,4 +134,12 @@ inline int64_t numel(const Tensor& tensor) { return tensor.numel(); } +inline int64_t size(const Tensor& tensor, int64_t dim) { + return tensor.size(dim); +} + +inline int64_t stride(const Tensor& tensor, int64_t dim) { + return tensor.stride(dim); +} + } diff --git a/aten/src/ATen/templates/MetaFunctions.h b/aten/src/ATen/templates/MetaFunctions.h new file mode 100644 index 0000000000000..7ad20b734330a --- /dev/null +++ b/aten/src/ATen/templates/MetaFunctions.h @@ -0,0 +1,15 @@ +#pragma once + +// ${generated_comment} + +#include +#include + +namespace at { + +namespace meta { + +${declarations} + +} // namespace meta +} // namespace at diff --git a/aten/src/ATen/templates/NativeFunctions.h b/aten/src/ATen/templates/NativeFunctions.h index 7d0dab41a6405..2e35fde1b95e9 100644 --- a/aten/src/ATen/templates/NativeFunctions.h +++ b/aten/src/ATen/templates/NativeFunctions.h @@ -3,9 +3,10 @@ // ${generated_comment} #include +#include +#include #include #include -#include #include #include @@ -25,29 +26,6 @@ struct Type; namespace at { namespace native { -// These functions are defined in native/TensorFactories.cpp. -#define TENSOR(T, S) \ - CAFFE2_API Tensor tensor(ArrayRef values, const TensorOptions& options); \ - inline Tensor tensor( \ - std::initializer_list values, const TensorOptions& options) { \ - return native::tensor(ArrayRef(values), options); \ - } \ - inline Tensor tensor(T value, const TensorOptions& options) { \ - return native::tensor(ArrayRef(value), options); \ - } \ - inline Tensor tensor(ArrayRef values) { \ - return native::tensor(std::move(values), at::dtype(k##S)); \ - } \ - inline Tensor tensor(std::initializer_list values) { \ - return native::tensor(ArrayRef(values)); \ - } \ - inline Tensor tensor(T value) { \ - return native::tensor(ArrayRef(value)); \ - } -AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) -AT_FORALL_COMPLEX_TYPES(TENSOR) -#undef TENSOR - ${native_function_declarations} } // namespace native diff --git a/aten/src/ATen/templates/BackendSelectRegister.cpp b/aten/src/ATen/templates/RegisterBackendSelect.cpp similarity index 89% rename from aten/src/ATen/templates/BackendSelectRegister.cpp rename to aten/src/ATen/templates/RegisterBackendSelect.cpp index db7276913201e..bcbf25f3117ff 100644 --- a/aten/src/ATen/templates/BackendSelectRegister.cpp +++ b/aten/src/ATen/templates/RegisterBackendSelect.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include namespace at { diff --git a/aten/src/ATen/templates/TypeDerived.cpp b/aten/src/ATen/templates/RegisterDispatchKey.cpp similarity index 65% rename from aten/src/ATen/templates/TypeDerived.cpp rename to aten/src/ATen/templates/RegisterDispatchKey.cpp index d65c13ae8d97d..ed4359c6883ed 100644 --- a/aten/src/ATen/templates/TypeDerived.cpp +++ b/aten/src/ATen/templates/RegisterDispatchKey.cpp @@ -5,15 +5,13 @@ #define __STDC_FORMAT_MACROS #endif -#include - // ${generated_comment} -$storage_tensor_headers -#include +#include #include #include #include +#include #include #include #include @@ -22,6 +20,9 @@ #include #include #include +#include +#include +#include #include #include @@ -36,21 +37,17 @@ namespace at { -/* example -Tensor * ${Type}::add(Tensor & a, Tensor & b) { - std::cout << "add Tensor with backend ${Backend}\n"; - return &a; -} -*/ - -namespace ${Type} { - -${type_derived_method_definitions} +${dispatch_definitions} -} // namespace ${Type} +// NB: TORCH_LIBRARY_IMPL must be in an anonymous namespace to avoid +// ambiguity with conflicting identifiers that may have been defined in +// at namespace already. +namespace { -TORCH_LIBRARY_IMPL(aten, ${Backend}, m) { - ${function_registrations} +TORCH_LIBRARY_IMPL(aten, ${DispatchKey}, m) { + ${dispatch_registrations} } +} // anonymous namespace + } // namespace at diff --git a/aten/src/ATen/templates/TypeDefault.cpp b/aten/src/ATen/templates/RegisterSchema.cpp similarity index 91% rename from aten/src/ATen/templates/TypeDefault.cpp rename to aten/src/ATen/templates/RegisterSchema.cpp index 58c80381d3407..7098d587a3f93 100644 --- a/aten/src/ATen/templates/TypeDefault.cpp +++ b/aten/src/ATen/templates/RegisterSchema.cpp @@ -17,19 +17,14 @@ #include namespace at { -namespace TypeDefault { - -${type_method_definitions} - -} // namespace TypeDefault - TORCH_LIBRARY(aten, m) { - ${function_registrations}; + ${schema_registrations}; // String Ops // Implementations located in torch/csrc/jit/runtime/register_prim_ops.cpp m.def(TORCH_SELECTIVE_SCHEMA("aten::splitlines(str self, bool keepends=False) -> str[]")); - m.def(TORCH_SELECTIVE_SCHEMA("aten::slice.str(str string, int start, int end=9223372036854775807, int step=1) -> str")); + m.def(TORCH_SELECTIVE_SCHEMA( + "aten::slice.str(str string, int? start=0, int? end=9223372036854775807, int step=1) -> str")); m.def(TORCH_SELECTIVE_SCHEMA("aten::isupper(str self) -> bool")); m.def(TORCH_SELECTIVE_SCHEMA("aten::islower(str self) -> bool")); m.def(TORCH_SELECTIVE_SCHEMA("aten::capitalize(str self) -> str")); @@ -63,9 +58,4 @@ TORCH_LIBRARY(aten, m) { // Implementations located in torch/csrc/jit/runtime/register_distributed_ops.cpp m.def("get_gradients(int context_id) -> Dict(Tensor, Tensor)"); } - -TORCH_LIBRARY_IMPL(aten, Math, m) { - ${math_function_registrations}; -} - } // namespace at diff --git a/tools/autograd/templates/RegistrationDeclarations.h b/aten/src/ATen/templates/RegistrationDeclarations.h similarity index 100% rename from tools/autograd/templates/RegistrationDeclarations.h rename to aten/src/ATen/templates/RegistrationDeclarations.h diff --git a/aten/src/ATen/templates/SchemaRegister.cpp b/aten/src/ATen/templates/SchemaRegister.cpp deleted file mode 100644 index f48e732f47600..0000000000000 --- a/aten/src/ATen/templates/SchemaRegister.cpp +++ /dev/null @@ -1,10 +0,0 @@ -// ${generated_comment} - -#include -#include - -using namespace at; - -TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(aten, m) { - ${schema_registrations} -} diff --git a/aten/src/ATen/templates/SparseTypeDerived.cpp b/aten/src/ATen/templates/SparseTypeDerived.cpp deleted file mode 100644 index b0a4fed24a630..0000000000000 --- a/aten/src/ATen/templates/SparseTypeDerived.cpp +++ /dev/null @@ -1,43 +0,0 @@ -// required for old g++ to compile PRId64 macros, see -// https://github.com/pytorch/pytorch/issues/3571 -// for context -#define __STDC_FORMAT_MACROS - -#include - -// ${generated_comment} - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include -$extra_cuda_headers - -namespace at { - -namespace ${Type} { - -${type_derived_method_definitions} - -} // namespace ${Type} - -TORCH_LIBRARY_IMPL(aten, ${Backend}, m) { - ${function_registrations}; -} - -} // namespace at diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index 202b2124f2862..0dfef701c51b2 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -4,12 +4,15 @@ #include #include #include +#include #include #include +#include #include #include #include #include +#include #include #include #include @@ -25,6 +28,7 @@ class Tensor; } namespace c10{ struct TensorOptions; +template class List; } namespace at { struct Generator; @@ -49,6 +53,8 @@ namespace at { class Tensor; using TensorList = ArrayRef; +using Stream = c10::Stream; + namespace impl { inline bool variable_excluded_from_dispatch() { #ifdef C10_MOBILE @@ -78,7 +84,7 @@ inline bool variable_excluded_from_dispatch() { // // Note that Tensor can also be NULL, i.e. it is not associated with any underlying TensorImpl, and // special care must be taken to handle this. -class CAFFE2_API Tensor { +class TORCH_API Tensor { public: Tensor(){}; // This constructor should not be used by end users and is an implementation @@ -111,6 +117,26 @@ class CAFFE2_API Tensor { return impl_->storage_offset(); } + Tensor contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const { + if (is_contiguous(memory_format)) { + return *this; + } else { + return __dispatch_contiguous(memory_format); + } + } + + int64_t size(int64_t dim) const { + // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) + dim = c10::maybe_wrap_dim(dim, this->dim(), false); + return sizes()[dim]; + } + + int64_t stride(int64_t dim) const { + // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) + dim = c10::maybe_wrap_dim(dim, this->dim(), false); + return strides()[dim]; + } + TensorImpl * unsafeGetTensorImpl() const { return impl_.get(); } @@ -182,10 +208,6 @@ class CAFFE2_API Tensor { Tensor& operator=(const Tensor&) &&; Tensor& operator=(Tensor&&) &&; - #ifdef _MSC_VER - #pragma warning( pop ) - #endif - bool is_same(const Tensor& other) const noexcept { return impl_ == other.impl_; } @@ -205,7 +227,7 @@ class CAFFE2_API Tensor { return impl_->strides(); } // See impl::get_opt_names in ATen/NamedTensor.h for docs. - optional opt_names() const { + c10::optional opt_names() const { return impl::get_opt_names(unsafeGetTensorImpl()); } // See impl::get_names in ATen/NamedTensor.h for docs. @@ -329,6 +351,9 @@ class CAFFE2_API Tensor { /// Returns if a `Tensor` is vulkan tensor. bool is_vulkan() const; + /// Returns if a `Tensor` is metal tensor. + bool is_metal() const; + /// Returns if a `Tensor` has quantized backend. bool is_quantized() const; @@ -375,7 +400,7 @@ class CAFFE2_API Tensor { template TensorAccessor accessor() const& { static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()"); - TORCH_CHECK(dim() == N, "expected ", N, " dims but tensor has ", dim()); + TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim()); return TensorAccessor(data_ptr(),sizes().data(),strides().data()); } template @@ -389,7 +414,7 @@ class CAFFE2_API Tensor { template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> GenericPackedTensorAccessor generic_packed_accessor() const& { static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()"); - TORCH_CHECK(dim() == N, "expected ", N, " dims but tensor has ", dim()); + TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim()); return GenericPackedTensorAccessor(static_cast::PtrType>(data_ptr()),sizes().data(),strides().data()); } template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> @@ -447,6 +472,7 @@ class CAFFE2_API Tensor { Tensor cuda() const; Tensor hip() const; Tensor vulkan() const; + Tensor metal() const; // ~~~~~ Autograd API ~~~~~ @@ -487,7 +513,7 @@ class CAFFE2_API Tensor { /// // f requires grad, has no operation creating it /// @endcode - /// \fn void backward(const Tensor & gradient={}, c10::optional retain_graph=c10::nullopt, bool create_graph=false) const; + /// \fn void backward(const Tensor & gradient={}, c10::optional retain_graph=c10::nullopt, bool create_graph=false, c10::optional inputs=c10::nullopt) const; /// /// Computes the gradient of current tensor with respect to graph leaves. /// @@ -514,6 +540,23 @@ class CAFFE2_API Tensor { /// \param create_graph If ``true``, graph of the derivative will /// be constructed, allowing to compute higher order derivative /// products. Defaults to ``false``. + /// \param inputs Inputs w.r.t. which the gradient will be accumulated into + /// ``at::Tensor::grad``. All other Tensors will be ignored. If not + /// provided, the gradient is accumulated into all the leaf Tensors + /// that were used to compute the current tensor. All the provided inputs + /// must be leaf Tensors. + void backward(const Tensor & gradient={}, c10::optional retain_graph=c10::nullopt, bool create_graph=false, c10::optional inputs=c10::nullopt) const { + // NB: Adding this wrapper to _backward here because we'd like our + // 'backwards' api to accept the 'inputs' argument optionally. Since code gen + // currently does not support optional of TensorList our approach is to replace + // backward in native_functions.yaml with _backward and call it here instead. + if (inputs.has_value()) { + TORCH_CHECK(inputs.value().size() > 0, "'inputs' argument to backward cannot be empty") + this->_backward(inputs.value(), gradient, retain_graph, create_graph); + } else { + this->_backward({}, gradient, retain_graph, create_graph); + } + } /// \fn Tensor detach() const; /// @@ -553,6 +596,23 @@ class CAFFE2_API Tensor { return impl_->grad(); } + // The Forward AD API functions below are low level and are not to be used by end + // users who should use the API provided in torch/csrc/autograd.h + + /// This function returns the forward gradient for this Tensor at the given level. + const Tensor& fw_grad(uint64_t level) const { + return impl_->fw_grad(level, *this); + } + + /// This function can be used to set the value of the forward grad. + /// Note that the given new_grad might not be used directly if it has different + /// metadata (size/stride/storage offset) compared to this Tensor. In that case, + /// new_grad content will be copied into a new Tensor + void set_fw_grad(const Tensor& new_grad, uint64_t level, bool is_inplace_op) { + impl_->set_fw_grad(new_grad, *this, level, is_inplace_op); + } + + // STOP. Thinking of adding a method here, which only makes use // of other ATen methods? Define it in native_functions.yaml. @@ -697,6 +757,12 @@ class CAFFE2_API Tensor { c10::intrusive_ptr impl_; }; +// For "multiple ... operators specified" warnings, closing brace of class +// declaration must be included between pragma push & pop +#ifdef _MSC_VER +#pragma warning( pop ) +#endif + int64_t get_device(Tensor self); template diff --git a/aten/src/ATen/templates/TensorMethods.cpp b/aten/src/ATen/templates/TensorMethods.cpp index 064f5911cb107..82a4336da0391 100644 --- a/aten/src/ATen/templates/TensorMethods.cpp +++ b/aten/src/ATen/templates/TensorMethods.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -14,6 +15,8 @@ namespace at { +using Stream = c10::Stream; + Tensor Tensor::cpu() const { return to(options().device(DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); } @@ -31,6 +34,10 @@ Tensor Tensor::vulkan() const { return to(options().device(DeviceType::Vulkan), /*non_blocking*/ false, /*copy*/ false); } +Tensor Tensor::metal() const { + return to(options().device(DeviceType::Metal), /*non_blocking*/ false, /*copy*/ false); +} + Tensor Tensor::toType(ScalarType t) const { return to(options().dtype(t), /*non_blocking*/ false, /*copy*/ false); } @@ -127,10 +134,20 @@ bool Tensor::is_vulkan() const { return impl_->is_vulkan(); } +bool Tensor::is_metal() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_metal(); +} + + bool is_vulkan(Tensor self) { return self.is_vulkan(); } +bool is_metal(Tensor self) { + return self.is_metal(); +} + bool Tensor::is_quantized() const { // NB: this is not a native function to avoid dispatching overhead. return impl_->is_quantized(); diff --git a/aten/src/ATen/templates/TypeDerived.h b/aten/src/ATen/templates/TypeDerived.h deleted file mode 100644 index 4b571f40383f1..0000000000000 --- a/aten/src/ATen/templates/TypeDerived.h +++ /dev/null @@ -1,38 +0,0 @@ -#pragma once - -// ${generated_comment} - -#include -#include -#include -#include -#include -#include -#include -#include - -$extra_cuda_headers - -namespace c10 { -struct Storage; -} - -namespace at { - -class Tensor; -using TensorList = ArrayRef; - -class Context; -struct Generator; - -struct Quantizer; -// This is temporary typedef to enable Quantizer in aten native function API -// we'll remove them when we are actually exposing Quantizer class -// to frontend -using ConstQuantizerPtr = const c10::intrusive_ptr&; - -namespace ${Type} { - ${type_derived_method_declarations} -} - -} // namespace at diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt index 43d0fc8ccd923..4bf9bf46a9652 100644 --- a/aten/src/ATen/test/CMakeLists.txt +++ b/aten/src/ATen/test/CMakeLists.txt @@ -29,7 +29,9 @@ list(APPEND ATen_CPU_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/tensor_iterator_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/math_kernel_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/memory_overlapping_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/mobile_memory_cleanup.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpu_generator_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpu_profiling_allocator_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/pow_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/variant_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce_ops_test.cpp @@ -75,15 +77,17 @@ list(APPEND ATen_HIP_TEST_SRCS # ${CMAKE_CURRENT_SOURCE_DIR}/hip/hip_stream_test.cpp list(APPEND ATen_VULKAN_TEST_SRCS - ${CMAKE_CURRENT_SOURCE_DIR}/vulkan_api_test.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/vulkan_test.cpp) + ${CMAKE_CURRENT_SOURCE_DIR}/vulkan_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/vulkan_api_test.cpp) list(APPEND ATen_MOBILE_TEST_SRCS - ${CMAKE_CURRENT_SOURCE_DIR}/cpu_caching_allocator_test.cpp) + ${CMAKE_CURRENT_SOURCE_DIR}/vec256_test_all_types.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpu_profiling_allocator_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpu_caching_allocator_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/quantized_test.cpp) list(APPEND ATen_VEC256_TEST_SRCS - ${CMAKE_CURRENT_SOURCE_DIR}/vec256_test.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/vec256_test_all_types.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/vec256_test_all_types.cpp ) # Caffe2 specific tests diff --git a/aten/src/ATen/test/basic.cpp b/aten/src/ATen/test/basic.cpp index 144c4671e50f4..1055eec9833a3 100644 --- a/aten/src/ATen/test/basic.cpp +++ b/aten/src/ATen/test/basic.cpp @@ -80,6 +80,23 @@ void TestAdd(DeprecatedTypeProperties& type) { } } +void TestZeros(DeprecatedTypeProperties& type) { + auto begin = std::chrono::high_resolution_clock::now(); + Tensor a = zeros({1024, 1024}, type); + for (int i = 1; i < 1000; ++i) { + a = zeros({128, 128}, type); + } + auto end = std::chrono::high_resolution_clock::now(); + std::cout << std::dec << " " + << std::chrono::duration_cast( + end - begin) + .count() + << " ms" << std::endl; + + std::srand(std::time(nullptr)); + ASSERT_EQ(norm(a).item(), 0.0); +} + void TestLoadsOfAdds(DeprecatedTypeProperties& type) { auto begin = std::chrono::high_resolution_clock::now(); Tensor d = ones({3, 4}, type); @@ -309,6 +326,7 @@ void test(DeprecatedTypeProperties& type) { TestSort(type); TestRandperm(type); TestAdd(type); + TestZeros(type); TestLoadsOfAdds(type); TestLoadOfAddsWithCopy(type); TestIsContiguous(type); @@ -370,14 +388,8 @@ TEST(BasicTest, FactoryMethodsTest) { ASSERT_FALSE(tensor0.is_pinned()); // Test setting requires_grad to true. - tensor0 = at::empty({4}, at::TensorOptions().requires_grad(true)); - ASSERT_EQ(tensor0.dtype(), at::kFloat); - ASSERT_EQ(tensor0.layout(), at::kStrided); - ASSERT_EQ(tensor0.device(), at::kCPU); - // This is a bug. Requires_grad was set to TRUE but this is being ignored. - // Issue https://github.com/pytorch/pytorch/issues/30405 - ASSERT_FALSE(tensor0.requires_grad()); - ASSERT_FALSE(tensor0.is_pinned()); + // This is a bug. Requires_grad was set to TRUE but this is not implemented. + EXPECT_ANY_THROW(at::empty({4}, at::TensorOptions().requires_grad(true))); // Test setting dtype at::Tensor tensor1 = at::empty({4}, at::TensorOptions().dtype(at::kHalf)); diff --git a/aten/src/ATen/test/cpu_caching_allocator_test.cpp b/aten/src/ATen/test/cpu_caching_allocator_test.cpp index 28a9b0476524e..cead52f5a7cc2 100644 --- a/aten/src/ATen/test/cpu_caching_allocator_test.cpp +++ b/aten/src/ATen/test/cpu_caching_allocator_test.cpp @@ -3,7 +3,7 @@ #include #include -#include +#include TEST(CPUCachingAllocatorTest, check_alloc_free) { c10::CPUCachingAllocator caching_allocator; diff --git a/aten/src/ATen/test/cpu_profiling_allocator_test.cpp b/aten/src/ATen/test/cpu_profiling_allocator_test.cpp new file mode 100644 index 0000000000000..32396718a1c8d --- /dev/null +++ b/aten/src/ATen/test/cpu_profiling_allocator_test.cpp @@ -0,0 +1,176 @@ +#include + +#include +#include +#include +#include + +at::Tensor run_with_control_flow( + at::Tensor input, + at::Tensor conv_weight, + at::Tensor linear_weight, + bool cond, + std::vector& pointers, + bool record = false, + bool validate = false) { + if (cond) { + input = input * 2; + } + void* input_ptr = input.data_ptr(); + auto conv_out = at::conv2d(input, conv_weight); + void* conv_out_ptr = input.data_ptr(); + auto conv_out_flat = conv_out.view({conv_out.size(0), -1}); + auto output = at::linear(conv_out_flat, linear_weight); + if (record) { + pointers.push_back(input_ptr); + pointers.push_back(conv_out_ptr); + } + if (validate) { + TORCH_CHECK(input_ptr == pointers[0]); + TORCH_CHECK(conv_out_ptr == pointers[1]); + } + return output; +} + +TEST(CPUAllocationPlanTest, with_control_flow) { + at::Tensor a = at::rand({23, 16, 16, 16}); + at::Tensor conv_weight = at::rand({16, 16, 3, 3}); + // output shape + // 23, 16, 14, 14 + // Flattened shape = 23, 3136 + at::Tensor linear_weight = at::rand({32, 3136}); + at::Tensor output, ref_output; + std::vector pointers; + + auto valid_allocation_plan = [&]() { + c10::AllocationPlan plan; + { + c10::WithProfileAllocationsGuard profile_guard(&plan); + ref_output = run_with_control_flow( + a, conv_weight, linear_weight, true, pointers); + } + }; + ASSERT_NO_THROW(valid_allocation_plan()); + + auto validate_allocation_plan = + [&](bool record_mode, bool validation_mode) -> bool { + c10::AllocationPlan plan; + { + c10::WithProfileAllocationsGuard profile_guard(&plan); + ref_output = + run_with_control_flow(a, conv_weight, linear_weight, record_mode, pointers); + } + bool success{true}; + for (uint64_t i = 0; i < 10; ++i) { + bool validation_success; + { + c10::WithValidateAllocationPlanGuard + validation_guard(&plan, &validation_success); + output = run_with_control_flow( + a, conv_weight, linear_weight, validation_mode, pointers); + } + success = success && validation_success; + } + return success; + }; + ASSERT_FALSE(validate_allocation_plan(false, true)); + ASSERT_FALSE(validate_allocation_plan(true, false)); + ASSERT_TRUE(validate_allocation_plan(true, true)); + ASSERT_TRUE(validate_allocation_plan(false, false)); +} + +TEST(CPUAllocationPlanTest, with_profiling_alloc) { + at::Tensor a = at::rand({23, 16, 16, 16}); + at::Tensor conv_weight = at::rand({16, 16, 3, 3}); + // output shape + // 23, 16, 14, 14 + // Flattened shape = 23, 3136 + at::Tensor linear_weight = at::rand({32, 3136}); + at::Tensor output, ref_output; + std::vector pointers; + + auto valid_allocation_plan = [&]() { + c10::AllocationPlan plan; + { + c10::WithProfileAllocationsGuard profile_guard(&plan); + ref_output = run_with_control_flow( + a, conv_weight, linear_weight, false, pointers); + } + }; + ASSERT_NO_THROW(valid_allocation_plan()); + + auto validate_allocation_plan = + [&](bool record_mode, + bool validation_mode, + bool validate_pointers) { + pointers.clear(); + c10::AllocationPlan plan; + { + c10::WithProfileAllocationsGuard profile_guard(&plan); + ref_output = run_with_control_flow( + a, + conv_weight, + linear_weight, + record_mode, + pointers, + false, + false); + } + c10::CPUProfilingAllocator profiling_allocator; + { + c10::WithProfilingAllocatorGuard + profiling_allocator_guard(&profiling_allocator, &plan); + output = run_with_control_flow( + a, + conv_weight, + linear_weight, + validation_mode, + pointers, + validate_pointers, + false); + } + for (uint64_t i = 0; i < 10; ++i) { + { + c10::WithProfilingAllocatorGuard + profiling_allocator_guard(&profiling_allocator, &plan); + output = run_with_control_flow( + a, + conv_weight, + linear_weight, + validation_mode, + pointers, + false, + validate_pointers); + } + } + }; + // When control flow conditions are same between profiling and evaluation + // profiling allocator should not throw. + ASSERT_NO_THROW(validate_allocation_plan(true, true, false)); + ASSERT_TRUE(ref_output.equal(output)); + ASSERT_NO_THROW(validate_allocation_plan(false, false, false)); + ASSERT_TRUE(ref_output.equal(output)); + // Furthermore profiling allocator should return the same pointers + // back for the intermediate tensors + ASSERT_NO_THROW(validate_allocation_plan(true, true, true)); + ASSERT_TRUE(ref_output.equal(output)); + ASSERT_NO_THROW(validate_allocation_plan(false, false, true)); + ASSERT_TRUE(ref_output.equal(output)); + + // When control flow conditions are different between profiling and evaluation + // profiling allocator should throw. + ASSERT_THROW(validate_allocation_plan(true, false, false), c10::Error); + ASSERT_THROW(validate_allocation_plan(false, true, false), c10::Error); +} + +int main(int argc, char* argv[]) { + // Setting the priority high to make sure no other allocator gets used instead of this. + c10::SetCPUAllocator(c10::GetDefaultMobileCPUAllocator(), /*priority*/ 100); + // Need to disable mkldnn for this test since it allocatred memory + // via raw_allocate inteface which requires context pointer and raw + // pointer to be the same. Tis is not true for mobile allocator. + at::globalContext().setUserEnabledMkldnn(false); + ::testing::InitGoogleTest(&argc, argv); + at::manual_seed(42); + return RUN_ALL_TESTS(); +} diff --git a/aten/src/ATen/test/cpu_rng_test.cpp b/aten/src/ATen/test/cpu_rng_test.cpp index f77ee41bbd845..805ed40557b68 100644 --- a/aten/src/ATen/test/cpu_rng_test.cpp +++ b/aten/src/ATen/test/cpu_rng_test.cpp @@ -28,6 +28,8 @@ struct TestCPUGenerator : public c10::GeneratorImpl { void set_current_seed(uint64_t seed) override { throw std::runtime_error("not implemented"); } uint64_t current_seed() const override { throw std::runtime_error("not implemented"); } uint64_t seed() override { throw std::runtime_error("not implemented"); } + void set_state(const c10::TensorImpl& new_state) override { throw std::runtime_error("not implemented"); } + c10::intrusive_ptr get_state() const override { throw std::runtime_error("not implemented"); } TestCPUGenerator* clone_impl() const override { throw std::runtime_error("not implemented"); } static DeviceType device_type() { return DeviceType::CPU; } @@ -57,15 +59,15 @@ Tensor& normal_(Tensor& self, double mean, double std, c10::optional return at::native::templates::normal_impl_(self, mean, std, gen); } -Tensor& normal_Tensor_float_out(Tensor& output, const Tensor& mean, double std, c10::optional gen) { +Tensor& normal_Tensor_float_out(const Tensor& mean, double std, c10::optional gen, Tensor& output) { return at::native::templates::normal_out_impl(output, mean, std, gen); } -Tensor& normal_float_Tensor_out(Tensor& output, double mean, const Tensor& std, c10::optional gen) { +Tensor& normal_float_Tensor_out(double mean, const Tensor& std, c10::optional gen, Tensor& output) { return at::native::templates::normal_out_impl(output, mean, std, gen); } -Tensor& normal_Tensor_Tensor_out(Tensor& output, const Tensor& mean, const Tensor& std, c10::optional gen) { +Tensor& normal_Tensor_Tensor_out(const Tensor& mean, const Tensor& std, c10::optional gen, Tensor& output) { return at::native::templates::normal_out_impl(output, mean, std, gen); } @@ -121,36 +123,36 @@ Tensor& bernoulli_float(Tensor& self, double p, c10::optional gen) { return at::native::templates::bernoulli_impl_(self, p, gen); } -Tensor& bernoulli_out(Tensor& result, const Tensor& self, c10::optional gen) { +Tensor& bernoulli_out(const Tensor& self, c10::optional gen, Tensor& result) { return at::native::templates::bernoulli_out_impl(result, self, gen); } TORCH_LIBRARY_IMPL(aten, CustomRNGKeyId, m) { // Random - m.impl_UNBOXED("random_.from", random_from_to); - m.impl_UNBOXED("random_.to", random_to); - m.impl_UNBOXED("random_", random_); + m.impl("random_.from", random_from_to); + m.impl("random_.to", random_to); + m.impl("random_", random_); // Normal - m.impl_UNBOXED("normal_", normal_); - m.impl_UNBOXED("normal.Tensor_float_out", normal_Tensor_float_out); - m.impl_UNBOXED("normal.float_Tensor_out", normal_float_Tensor_out); - m.impl_UNBOXED("normal.Tensor_Tensor_out", normal_Tensor_Tensor_out); - m.impl_UNBOXED("normal.Tensor_float", normal_Tensor_float); - m.impl_UNBOXED("normal.float_Tensor", normal_float_Tensor); - m.impl_UNBOXED("normal.Tensor_Tensor", normal_Tensor_Tensor); - m.impl_UNBOXED("uniform_", uniform_); + m.impl("normal_", normal_); + m.impl("normal.Tensor_float_out", normal_Tensor_float_out); + m.impl("normal.float_Tensor_out", normal_float_Tensor_out); + m.impl("normal.Tensor_Tensor_out", normal_Tensor_Tensor_out); + m.impl("normal.Tensor_float", normal_Tensor_float); + m.impl("normal.float_Tensor", normal_float_Tensor); + m.impl("normal.Tensor_Tensor", normal_Tensor_Tensor); + m.impl("uniform_", uniform_); // Cauchy - m.impl_UNBOXED("cauchy_", cauchy_); + m.impl("cauchy_", cauchy_); // LogNormal - m.impl_UNBOXED("log_normal_", log_normal_); + m.impl("log_normal_", log_normal_); // Geometric - m.impl_UNBOXED("geometric_", geometric_); + m.impl("geometric_", geometric_); // Exponential - m.impl_UNBOXED("exponential_", exponential_); + m.impl("exponential_", exponential_); // Bernoulli - m.impl_UNBOXED("bernoulli.out", bernoulli_out); - m.impl_UNBOXED("bernoulli_.Tensor", bernoulli_Tensor); - m.impl_UNBOXED("bernoulli_.float", bernoulli_float); + m.impl("bernoulli.out", bernoulli_out); + m.impl("bernoulli_.Tensor", bernoulli_Tensor); + m.impl("bernoulli_.float", bernoulli_float); } class RNGTest : public ::testing::Test { diff --git a/aten/src/ATen/test/cuda_atomic_ops_test.cu b/aten/src/ATen/test/cuda_atomic_ops_test.cu index 69e48a3655f9d..920a72452916a 100644 --- a/aten/src/ATen/test/cuda_atomic_ops_test.cu +++ b/aten/src/ATen/test/cuda_atomic_ops_test.cu @@ -1,6 +1,7 @@ #include #include #include +#include const int blocksize = 256; const int factor = 4; @@ -10,7 +11,7 @@ template __global__ void addition_test_kernel(T * a, T * sum) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int idx = (tid) % arraysize; - + gpuAtomicAdd(&sum[idx], a[idx]); } @@ -18,7 +19,7 @@ template __global__ void mul_test_kernel(T * a, T * sum) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int idx = (tid) % arraysize; - + gpuAtomicMul(&sum[idx], a[idx]); } @@ -28,7 +29,7 @@ void test_atomic_add() { dim3 dimGrid(1, 1); T *a, *sum, *answer, *ad, *sumd; - + a = (T*)malloc(arraysize * sizeof(T)); sum = (T*)malloc(arraysize * sizeof(T)); answer = (T*)malloc(arraysize * sizeof(T)); @@ -41,7 +42,7 @@ void test_atomic_add() { cudaMalloc((void**)&ad, arraysize * sizeof(T)); cudaMalloc((void**)&sumd, arraysize * sizeof(T)); - + cudaMemcpy(ad, a, arraysize * sizeof(T), cudaMemcpyHostToDevice); cudaMemcpy(sumd, sum, arraysize * sizeof(T), cudaMemcpyHostToDevice); @@ -66,7 +67,7 @@ void test_atomic_mul() { dim3 dimGrid(1, 1); T *a, *sum, *answer, *ad, *sumd; - + a = (T*)malloc(arraysize * sizeof(T)); sum = (T*)malloc(arraysize * sizeof(T)); answer = (T*)malloc(arraysize * sizeof(T)); @@ -74,12 +75,12 @@ void test_atomic_mul() { for (int i = 0; i < arraysize; ++i) { a[i] = 2; sum[i] = 2; - answer[i] = pow(sum[i], factor); + answer[i] = pow(sum[i], static_cast(factor)); } cudaMalloc((void**)&ad, arraysize * sizeof(T)); cudaMalloc((void**)&sumd, arraysize * sizeof(T)); - + cudaMemcpy(ad, a, arraysize * sizeof(T), cudaMemcpyHostToDevice); cudaMemcpy(sumd, sum, arraysize * sizeof(T), cudaMemcpyHostToDevice); @@ -104,7 +105,7 @@ TEST(TestAtomicOps, TestAtomicAdd) { test_atomic_add(); test_atomic_add(); test_atomic_add(); - + test_atomic_add(); test_atomic_add(); test_atomic_add(); @@ -113,7 +114,7 @@ TEST(TestAtomicOps, TestAtomicAdd) { test_atomic_add >(); } -TEST(TestAtomicOps, TestAtomicMul) { +TEST(TestAtomicOps, DISABLED_ON_WINDOWS(TestAtomicMul)) { test_atomic_mul(); test_atomic_mul(); test_atomic_mul(); diff --git a/aten/src/ATen/test/extension_backend_test.cpp b/aten/src/ATen/test/extension_backend_test.cpp index a8220a319a753..f0ec67a49ac03 100644 --- a/aten/src/ATen/test/extension_backend_test.cpp +++ b/aten/src/ATen/test/extension_backend_test.cpp @@ -10,7 +10,8 @@ using namespace at; static int test_int; -Tensor empty_override(IntArrayRef size, const TensorOptions & options, c10::optional optional_memory_format) { +Tensor empty_override(IntArrayRef size, c10::optional dtype, c10::optional layout, + c10::optional device, c10::optional pin_memory, c10::optional optional_memory_format) { test_int = 1; auto tensor_impl = c10::make_intrusive( Storage( @@ -29,9 +30,21 @@ Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) { return a; } +Tensor empty_strided_override( + IntArrayRef size, + IntArrayRef stride, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { + + return empty_override(size, dtype, layout, device, pin_memory, c10::nullopt); +} + TORCH_LIBRARY_IMPL(aten, MSNPU, m) { - m.impl_UNBOXED("aten::empty.memory_format", empty_override); - m.impl_UNBOXED("aten::add.Tensor", add_override); + m.impl("aten::empty.memory_format", empty_override); + m.impl("aten::empty_strided", empty_strided_override); + m.impl("aten::add.Tensor", add_override); } TEST(BackendExtensionTest, TestRegisterOp) { diff --git a/aten/src/ATen/test/ivalue_test.cpp b/aten/src/ATen/test/ivalue_test.cpp index 6474aa45d4dd2..a0e2648758ffa 100644 --- a/aten/src/ATen/test/ivalue_test.cpp +++ b/aten/src/ATen/test/ivalue_test.cpp @@ -51,6 +51,91 @@ TEST(IValueTest, Basic) { ASSERT_EQ(tv.use_count(), 2); } +static std::array makeSampleIValues() { + return { at::rand({3, 4}), "hello", 42, true, 1.5 }; +} + +static std::array makeMoreSampleIValues() { + return { at::rand({3, 4}), "goodbye", 23, false, 0.5 }; +} + +// IValue::operator== doesn't seem to work on Tensors. +#define EXPECT_IVALUE_EQ(a, b) \ + EXPECT_EQ((a).isTensor(), (b).isTensor()); \ + if ((a).isTensor()) { \ + EXPECT_TRUE(a.toTensor().equal(b.toTensor())); \ + } else { \ + EXPECT_EQ(a, b); \ + } + +TEST(IValueTest, Swap) { + // swap() has the following 3 cases: tensor, intrusive_ptr, or + // neither. Exercise all pairs of the three. + + auto sampleInputs = makeSampleIValues(); + auto sampleTargets = makeMoreSampleIValues(); + for (const auto& input: sampleInputs) { + for (const auto& target: sampleTargets) { + IValue a(input); + IValue b(target); + EXPECT_IVALUE_EQ(a, input); + EXPECT_IVALUE_EQ(b, target); + a.swap(b); + EXPECT_IVALUE_EQ(a, target); + EXPECT_IVALUE_EQ(b, input); + } + } +} + +TEST(IValueTest, CopyConstruct) { + auto sampleInputs = makeSampleIValues(); + for (const IValue& v: sampleInputs) { + IValue copy(v); + EXPECT_IVALUE_EQ(copy, v); + } +} + +TEST(IValueTest, MoveConstruct) { + auto sampleInputs = makeSampleIValues(); + for (const IValue& v: sampleInputs) { + IValue source(v); + IValue target(std::move(source)); + EXPECT_IVALUE_EQ(target, v); + EXPECT_TRUE(source.isNone()); + } +} + +TEST(IValueTest, CopyAssign) { + auto sampleInputs = makeSampleIValues(); + auto sampleTargets = makeMoreSampleIValues(); + + for (const IValue& input: sampleInputs) { + for (const IValue& target: sampleTargets) { + IValue copyTo(target); + IValue copyFrom(input); + copyTo = copyFrom; + EXPECT_IVALUE_EQ(copyTo, input); + EXPECT_IVALUE_EQ(copyFrom, input); + EXPECT_IVALUE_EQ(copyTo, copyFrom); + } + } +} + +TEST(IValueTest, MoveAssign) { + auto sampleInputs = makeSampleIValues(); + auto sampleTargets = makeMoreSampleIValues(); + + for (const IValue& input: sampleInputs) { + for (const IValue& target: sampleTargets) { + IValue moveTo(target); + IValue moveFrom(input); + moveTo = std::move(moveFrom); + EXPECT_IVALUE_EQ(moveTo, input); + EXPECT_TRUE(moveFrom.isNone()); + } + } +} + TEST(IValueTest, Tuple) { std::tuple t = std::make_tuple(123, at::randn({1})); auto iv = IValue(t); @@ -259,6 +344,18 @@ TEST(IValueTest, ListNestedEquality) { EXPECT_NE(c2, c3); } +TEST(IValueTest, StreamEquality) { + at::Device device1 = at::Device(kCUDA, 0); + at::Device device2 = at::Device(kCUDA, 1); + c10::Stream stream1 = c10::Stream(c10::Stream::Default::DEFAULT, device1); + c10::Stream stream2 = c10::Stream(c10::Stream::Default::DEFAULT, device2); + IValue lhs(stream1); + IValue rhs_different(stream2); + IValue rhs_same(stream1); + EXPECT_FALSE(lhs.equals(rhs_different).toBool()); + EXPECT_TRUE(lhs.equals(rhs_same).toBool()); +} + TEST(IValueTest, EnumEquality) { auto cu = std::make_shared(); IValue int_ivalue_1(1); @@ -306,5 +403,137 @@ TEST(IValueTest, EnumEquality) { ); } +TEST(IValueTest, isPtrType) { + IValue tensor(at::rand({3, 4})); + IValue undefinedTensor((at::Tensor())); + IValue integer(42); + IValue str("hello"); + + EXPECT_TRUE(tensor.isPtrType()); + EXPECT_FALSE(undefinedTensor.isPtrType()); + EXPECT_FALSE(integer.isPtrType()); + EXPECT_TRUE(str.isPtrType()); +} + +TEST(IValueTest, isAliasOf) { + auto sampleIValues = makeSampleIValues(); + for (auto& iv: sampleIValues) { + for (auto& iv2: sampleIValues) { + if (&iv == &iv2 && iv.isPtrType()) { + EXPECT_TRUE(iv.isAliasOf(iv2)); + } else { + EXPECT_FALSE(iv.isAliasOf(iv2)); + } + } + } +} + +TEST(IValueTest, internalToPointer) { + IValue tensor(at::rand({3, 4})); + IValue str("hello"); + + EXPECT_EQ(tensor.internalToPointer(), tensor.unsafeToTensorImpl()); + EXPECT_NE(str.internalToPointer(), nullptr); + + IValue nullStr((c10::intrusive_ptr())); + ASSERT_TRUE(nullStr.isString()); + EXPECT_EQ(nullStr.internalToPointer(), nullptr); +} + +TEST(IValueTest, IdentityComparisonAndHashing) { + at::Tensor t1 = at::rand({3, 4}); + at::Tensor t2 = at::rand({3, 4}); + IValue tv1(t1), tv2(t2); + IValue tv1b(t1); + + EXPECT_EQ(tv1.hash(), tv1b.hash()); + EXPECT_NE(tv1.hash(), tv2.hash()); + + EXPECT_TRUE(tv1.is(tv1)); + EXPECT_TRUE(tv1.is(tv1b)); + EXPECT_TRUE(tv1b.is(tv1)); + EXPECT_TRUE(tv2.is(tv2)); + + EXPECT_FALSE(tv1.is(tv2)); + EXPECT_FALSE(tv2.is(tv1)); + + IValue none; + IValue undefinedTensor((at::Tensor())); + + EXPECT_TRUE(none.is(undefinedTensor)); + EXPECT_TRUE(undefinedTensor.is(none)); + + // Is this a bug? We should probably have a is b => a.hash() == b.hash() + EXPECT_NE(none.hash(), undefinedTensor.hash()); + + auto sampleIValues = makeSampleIValues(); + auto sampleIValues2 = makeSampleIValues(); + auto moreSampleIValues = makeMoreSampleIValues(); + + ASSERT_EQ(sampleIValues.size(), moreSampleIValues.size()); + for (int ii = 0; ii < sampleIValues.size(); ++ii) { + // Constant strings will have the same pointer value. + if (sampleIValues[ii].isPtrType() && !sampleIValues[ii].isString()) { + EXPECT_NE(sampleIValues[ii].hash(), sampleIValues2[ii].hash()); + } else { + EXPECT_EQ(sampleIValues[ii].hash(), sampleIValues2[ii].hash()); + } + EXPECT_NE(sampleIValues[ii].hash(), moreSampleIValues[ii].hash()); + } +} + +TEST(IValueTest, getSubValues) { + // Scalars have no subvalues. + IValue integer(42), float_(1.5); + + IValue::HashAliasedIValues subvalues; + + integer.getSubValues(subvalues); + EXPECT_TRUE(subvalues.empty()); + + subvalues.clear(); + + float_.getSubValues(subvalues); + EXPECT_TRUE(subvalues.empty()); + + subvalues.clear(); + + at::Tensor t1(at::rand({3, 4})), t2(at::rand({3, 4})); + IValue tv1(t1), tv2(t2); + IValue list(std::vector{t1, t2}); + IValue tuple(ivalue::Tuple::create({tv1, tv2})); + + std::unordered_map m; + m[1] = t1; + m[2] = t2; + + IValue dict(std::move(m)); + + auto objType = ClassType::create(nullopt, {}); + objType->addAttribute("t1", tv1.type()); + objType->addAttribute("t2", tv2.type()); + + auto o = ivalue::Object::create(StrongTypePtr(nullptr, objType), 2); + o->setSlot(0, tv1); + o->setSlot(1, tv2); + + IValue object(o); + tv1.getSubValues(subvalues); + EXPECT_EQ(subvalues.size(), 1); + EXPECT_EQ(subvalues.count(tv1), 1); + + subvalues.clear(); + + for (auto& container: {list, tuple, dict, object}) { + container.getSubValues(subvalues); + EXPECT_EQ(subvalues.size(), 3); + EXPECT_EQ(subvalues.count(container), 1); + EXPECT_EQ(subvalues.count(tv1), 1); + EXPECT_EQ(subvalues.count(tv2), 1); + + subvalues.clear(); + } +} + // TODO(gmagogsfm): Add type conversion test? } // namespace c10 diff --git a/aten/src/ATen/test/math_kernel_test.cpp b/aten/src/ATen/test/math_kernel_test.cpp index 9a4dfd640c3e3..6b5657d21a252 100644 --- a/aten/src/ATen/test/math_kernel_test.cpp +++ b/aten/src/ATen/test/math_kernel_test.cpp @@ -4,9 +4,22 @@ using namespace at; -#define ASSERT_ALLCLOSE_TOLERANCES(t1, t2, atol, rtol) \ - ASSERT_TRUE(t1.is_same_size(t2)); \ - ASSERT_TRUE(t1.allclose(t2, atol, rtol)); +bool allClose(const at::Tensor& t1, const at::Tensor& t2, double rtol=1e-5, double atol=1e-8) { + if (!t1.is_same_size(t2)) { + std::cerr << "Difference in tensor shapes: " + << t1.sizes() << " v.s. " << t2.sizes() << std::endl; + return false; + } + bool equal = t1.allclose(t2, rtol, atol); + if (!equal) { + std::cerr << "Difference in tensor value: \nFirst tensor:\n" + << t1 << "\nSecond tensor:\n" << t2 << std::endl; + } + return equal; +} + +#define ASSERT_ALLCLOSE_TOLERANCES(t1, t2, rtol, atol) \ + ASSERT_TRUE(allClose(t1, t2, rtol, atol)); // Ideally we want to test both forward and backward on math kernels but I // haven't found an easy way to do it. Currently we only test forward here @@ -37,4 +50,65 @@ TEST(MathKernelTest, NativeGroupNorm) { } } +TEST(MathKernelTest, NativeLayerNorm) { + const auto input = rand({20, 10, 10, 10}); + const auto input_shape = input.sizes(); + const auto input_ndim = input.dim(); + + double eps = 1e-05; + for (bool undef_weight: {true, false}) { + for (int normalized_size: {2, 3}) { + Tensor undef; + std::vector normalized_shape(normalized_size, 10); + const auto weight = rand(normalized_shape); + const auto bias = rand(normalized_shape); + + auto out = at::native_layer_norm( + input, normalized_shape, undef_weight ? undef : weight, undef_weight ? undef : bias, + eps); + auto math_out = at::native::math_native_layer_norm( + input, normalized_shape, undef_weight ? undef : weight, undef_weight ? undef : bias, + eps); + ASSERT_ALLCLOSE_TOLERANCES(std::get<0>(out), std::get<0>(math_out), 1e-3, 1e-5); + ASSERT_ALLCLOSE_TOLERANCES(std::get<1>(out), std::get<1>(math_out), 1e-3, 1e-5); + ASSERT_ALLCLOSE_TOLERANCES(std::get<2>(out), std::get<2>(math_out), 1e-3, 1e-5); + } + } +} + +TEST(MathKernelTest, Addr) { + const auto vec1 = arange(1., 4.); + const auto vec2 = arange(1., 3.); + const auto M = zeros({3, 2}); + for (float beta: {1., 1.2, 0.}) { + // nans and infs are not propagated to the output when beta == 0 + if (beta == 0) { + M[0][0] = std::numeric_limits::infinity(); + M[2][0] = std::numeric_limits::quiet_NaN(); + } + for (float alpha: {1., 2., 0.}) { + auto out = at::native::addr(M, vec1, vec2, beta, alpha); + auto math_out = at::native::math_addr(M, vec1, vec2, beta, alpha); + ASSERT_ALLCLOSE_TOLERANCES(out, math_out, 1e-4, 1e-6); + } + } +} + +TEST(MathKernelTest, SiluBackward) { + const auto input = rand({20, 10}); + const auto grad_output = rand({20, 10}); + auto out = at::native::silu_backward(grad_output, input); + auto math_out = at::native::math_silu_backward(grad_output, input); + ASSERT_ALLCLOSE_TOLERANCES(out, math_out, 1e-4, 1e-6); +} + +TEST(MathKernelTest, NarrowCopy) { + auto x = rand({5, 8, 7}); + for (int64_t dim = 0; dim < 3; ++dim) { + const int64_t start = 1, length = 4; + auto y_ref = x.narrow(dim, start, length); + auto y_test = at::native::narrow_copy_dense(x, dim, start, length); + ASSERT_ALLCLOSE_TOLERANCES(y_ref, y_test, 0, 0); + } +} diff --git a/aten/src/ATen/test/mobile_memory_cleanup.cpp b/aten/src/ATen/test/mobile_memory_cleanup.cpp new file mode 100644 index 0000000000000..8682fd0a4f151 --- /dev/null +++ b/aten/src/ATen/test/mobile_memory_cleanup.cpp @@ -0,0 +1,39 @@ +#include + +#include +#include + +using namespace torch::jit; + +#ifdef USE_XNNPACK + +TEST(MemoryCleanUp, NoErrorWithoutRelease) { + Module m("m"); + m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false); + m.register_parameter("bias", torch::ones({20}), false); + m.define(R"( + def forward(self, input): + return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) + )"); + m.eval(); + auto m_optimized = optimizeForMobile(m); + std::stringstream ss; + EXPECT_NO_THROW(m_optimized.save(ss)); +} + +TEST(MemoryCleanUp, UnpackError) { + at::globalContext().setReleaseWeightsWhenPrepacking(true); + Module m("m"); + m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false); + m.register_parameter("bias", torch::ones({20}), false); + m.define(R"( + def forward(self, input): + return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) + )"); + m.eval(); + auto m_optimized = optimizeForMobile(m); + std::stringstream ss; + EXPECT_ANY_THROW(m_optimized.save(ss)); +} + +#endif diff --git a/aten/src/ATen/test/quantized_test.cpp b/aten/src/ATen/test/quantized_test.cpp index a2b64618ccfe5..efb81a340d226 100644 --- a/aten/src/ATen/test/quantized_test.cpp +++ b/aten/src/ATen/test/quantized_test.cpp @@ -101,9 +101,8 @@ TEST(TestQTensor, EmptyQuantized) { int zero_point = 10; int val = 100; int numel = 10; - Tensor q = at::_empty_affine_quantized({numel}, - at::device(at::kCPU).dtype(kQUInt8), - scale, zero_point); + Tensor q = at::_empty_affine_quantized( + {numel}, at::device(at::kCPU).dtype(kQUInt8), scale, zero_point); // Assigning to QTensor auto* q_data = q.data_ptr(); for (int i = 0; i < numel; ++i) { @@ -142,7 +141,66 @@ TEST(TestQTensor, EmptyPerchannelQuantized) { for (int i = 0; i < numel; ++i) { ASSERT_EQ( r_data[i], - (val - zero_points[i].item().to()) * - scales[i].item().to()); + (val - zero_points[i].item().to()) * scales[i].item().to()); + } +} + +TEST(TestQTensor, QuantizePerChannel4d) { + int C = 64, H = 10, W = 10; + auto scales = rand({C}).toType(kDouble); + auto zero_points = randint(10, {C}).toType(kLong); + int ch_axis = 1; + // create 4d tensor where each H x W image is a range(0, H*W) + Tensor tensor = at::empty({1, C, H, W}, at::device(at::kCPU).dtype(kFloat)); + auto* tensor_data = tensor.data_ptr(); + for (int c = 0, i = 0; c < C; ++c) { + for (int e = 0; e < H * W; ++e, ++i) { + tensor_data[i] = e; + } + } + // quantize and check values + Tensor q = at::native::quantize_per_channel_cpu( + tensor, scales, zero_points, ch_axis, kQUInt8); + auto* q_data = (uint8_t*)q.data_ptr(); + for (int c = 0, i = 0; c < C; ++c) { + float inv_scale = 1.0f / static_cast(scales[c].item()); + int64_t zero_point = zero_points[c].item(); + for (int e = 0; e < H * W; ++e, ++i) { + // downsize qval to 255 if val is greater than max uint8_t value + int qval = std::min(zero_point + std::nearbyint(e * inv_scale), 255); + ASSERT_EQ((int)q_data[i], qval); + } + } +} + +TEST(TestQTensor, QuantizePerChannel4dChannelsLast) { + int C = 64, H = 10, W = 10; + auto scales = rand({C}).toType(kDouble); + auto zero_points = randint(10, {C}).toType(kLong); + int ch_axis = 1; + // create 4d tensor where each H x W image is a range(0, H*W) + Tensor tensor = at::empty( + {1, C, H, W}, + at::device(at::kCPU).dtype(kFloat).memory_format( + at::MemoryFormat::ChannelsLast)); + auto* tensor_data = tensor.data_ptr(); + for (int e = 0, i = 0; e < H * W; ++e) { + for (int c = 0; c < C; ++c, ++i) { + tensor_data[i] = e; + } + } + + // quantize and check values + Tensor q = at::native::quantize_per_channel_cpu( + tensor, scales, zero_points, ch_axis, kQUInt8); + auto* q_data = (uint8_t*)q.data_ptr(); + for (int e = 0, i = 0; e < H * W; ++e) { + for (int c = 0; c < C; ++c, ++i) { + float inv_scale = 1.0f / static_cast(scales[c].item()); + int64_t zero_point = zero_points[c].item(); + // downsize qval to 255 if val is greater than max uint8_t value + int qval = std::min(zero_point + std::nearbyint(e * inv_scale), 255); + ASSERT_EQ((int)q_data[i], qval); + } } } diff --git a/aten/src/ATen/test/scalar_test.cpp b/aten/src/ATen/test/scalar_test.cpp index dc73460b3728f..3b7bfb47fe620 100644 --- a/aten/src/ATen/test/scalar_test.cpp +++ b/aten/src/ATen/test/scalar_test.cpp @@ -128,3 +128,33 @@ TEST(TestScalar, TestScalar) { ASSERT_EQ(float_one.item(), 1); ASSERT_EQ(float_one.item(), 1); } + +TEST(TestScalar, TestConj) { + Scalar int_scalar = 257; + Scalar float_scalar = 3.0; + Scalar complex_scalar = c10::complex(2.3, 3.5); + + ASSERT_EQ(int_scalar.conj().toInt(), 257); + ASSERT_EQ(float_scalar.conj().toDouble(), 3.0); + ASSERT_EQ(complex_scalar.conj().toComplexDouble(), c10::complex(2.3, -3.5)); +} + +TEST(TestScalar, TestEqual) { + ASSERT_FALSE(Scalar(1.0).equal(false)); + ASSERT_FALSE(Scalar(1.0).equal(true)); + ASSERT_FALSE(Scalar(true).equal(1.0)); + ASSERT_TRUE(Scalar(true).equal(true)); + + ASSERT_TRUE(Scalar(c10::complex{2.0, 5.0}).equal(c10::complex{2.0, 5.0})); + ASSERT_TRUE(Scalar(c10::complex{2.0, 0}).equal(2.0)); + ASSERT_TRUE(Scalar(c10::complex{2.0, 0}).equal(2)); + + ASSERT_TRUE(Scalar(2.0).equal(c10::complex{2.0, 0.0})); + ASSERT_FALSE(Scalar(2.0).equal(c10::complex{2.0, 4.0})); + ASSERT_FALSE(Scalar(2.0).equal(3.0)); + ASSERT_TRUE(Scalar(2.0).equal(2)); + + ASSERT_TRUE(Scalar(2).equal(c10::complex{2.0, 0})); + ASSERT_TRUE(Scalar(2).equal(2)); + ASSERT_TRUE(Scalar(2).equal(2.0)); +} diff --git a/aten/src/ATen/test/thread_init_test.cpp b/aten/src/ATen/test/thread_init_test.cpp index 0650e9a3e6b4e..55df55f3b58cf 100644 --- a/aten/src/ATen/test/thread_init_test.cpp +++ b/aten/src/ATen/test/thread_init_test.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include diff --git a/aten/src/ATen/test/undefined_tensor_test.cpp b/aten/src/ATen/test/undefined_tensor_test.cpp index 7a6dd50f91630..0102a8cf4f49c 100644 --- a/aten/src/ATen/test/undefined_tensor_test.cpp +++ b/aten/src/ATen/test/undefined_tensor_test.cpp @@ -19,7 +19,7 @@ TEST(TestUndefined, UndefinedTest) { ASSERT_EQ(std::string("UndefinedType"), und.toString()); ASSERT_ANY_THROW(und.strides()); - ASSERT_ANY_THROW(und.dim()); + ASSERT_EQ(und.dim(), 1); ASSERT_ANY_THROW([]() { return Tensor(); }() = Scalar(5)); ASSERT_ANY_THROW(und.add(und)); ASSERT_ANY_THROW(und.add(ft)); diff --git a/aten/src/ATen/test/vec256_test.cpp b/aten/src/ATen/test/vec256_test.cpp deleted file mode 100644 index 73680073c014b..0000000000000 --- a/aten/src/ATen/test/vec256_test.cpp +++ /dev/null @@ -1,671 +0,0 @@ -#include - -#include -#include - -#include - -using namespace at::vec256; - -bool check_equal(const at::Tensor& a, const at::Tensor& b) { - return (a.equal(b)); -} - -bool check_almost_equal( - const at::Tensor& a, const at::Tensor& b, const float tolerance) { - double max_val = a.abs().max().item(); - max_val = std::max(max_val, b.abs().max().item()); - if ((a - b).abs().max().item() > tolerance * max_val) { - std::cout << "Max difference:" - << (a - b).abs().max().item() << std::endl; - return false; - } - return true; -} - -template -void BlendTestHelperScalar( - const T* a_ptr, - const T* b_ptr, - T* res_ptr, - const int64_t num_els, - const int64_t count) { - for(auto i = 0; i < num_els; ++i) { - for (auto j = 0; j < Vec256::size(); ++j) { - auto index = i * Vec256::size() + j; - if (j < count) { - res_ptr[index] = b_ptr[index]; - } else { - res_ptr[index] = a_ptr[index]; - } - } - } -} - -namespace Impl { -float reciprocal(const float a) { - return (1/a); -} - -float rsqrt(const float a) { - return (1/std::sqrt(a)); -} - -float frac(const float a) { - return a - (static_cast(a)); -} -} - -template -void BlendTestHelperVector( - const T* a_ptr, - const T* b_ptr, - T* res_ptr, - const int64_t num_els, - const int64_t count) { - for(auto i = 0; i < num_els; ++i) { - auto a_elements = Vec256::loadu(a_ptr); - auto b_elements = Vec256::loadu(b_ptr); - a_ptr += Vec256::size(); - b_ptr += Vec256::size(); - auto res_elements = Vec256::set(a_elements, b_elements, count); - res_elements.store(res_ptr); - res_ptr += Vec256::size(); - } -} - -#define TranscedentalTester(opnamespace, name) \ -void TranscedentalHelper_##name(const float tolerance = 1e-6) { \ - at::Tensor a = at::rand({23, 23}); \ - a = a * -10; \ - a = a + 10; \ - at::Tensor ref_res = at::zeros({23, 23}); \ - at::Tensor vec_res = at::zeros({23, 23}); \ - float* a_ptr = a.data_ptr(); \ - float* ref_res_ptr = ref_res.data_ptr(); \ - float* vec_res_ptr = vec_res.data_ptr(); \ - size_t num_els = \ - (a.numel() / Vec256::size()) * Vec256::size(); \ - for(auto i = 0; i < num_els; ++i) { \ - ref_res_ptr[i] = opnamespace::name(a_ptr[i]); \ - } \ - for (size_t i = 0; i < num_els; i += Vec256::size()) { \ - auto a_elements = Vec256::loadu(a_ptr); \ - a_ptr += Vec256::size(); \ - auto res = a_elements.name(); \ - res.store(vec_res_ptr); \ - vec_res_ptr += Vec256::size(); \ - } \ - ASSERT_TRUE(check_almost_equal(ref_res, vec_res, tolerance)); \ -} - -#define TranscedentalTester2(name) \ -void TranscedentalHelper_##name(const float tolerance = 1e-6) { \ - at::Tensor a = at::rand({23, 23}); \ - at::Tensor b = at::rand({23, 23}); \ - a = a * -10; \ - a = a + 10; \ - at::Tensor ref_res = at::zeros({23, 23}); \ - at::Tensor vec_res = at::zeros({23, 23}); \ - float* a_ptr = a.data_ptr(); \ - float* b_ptr = a.data_ptr(); \ - float* ref_res_ptr = ref_res.data_ptr(); \ - float* vec_res_ptr = vec_res.data_ptr(); \ - size_t num_els = \ - (a.numel() / Vec256::size()) * Vec256::size(); \ - for(auto i = 0; i < num_els; ++i) { \ - ref_res_ptr[i] = std::name(a_ptr[i], b_ptr[i]); \ - } \ - for (size_t i = 0; i < num_els; i += Vec256::size()) { \ - auto a_elements = Vec256::loadu(a_ptr); \ - auto b_elements = Vec256::loadu(b_ptr); \ - a_ptr += Vec256::size(); \ - b_ptr += Vec256::size(); \ - auto res = a_elements.name(b_elements); \ - res.store(vec_res_ptr); \ - vec_res_ptr += Vec256::size(); \ - } \ - ASSERT_TRUE(check_almost_equal(ref_res, vec_res, tolerance)); \ -} - -// Not testing all the transcendentals. -// In fact fewer than these might suffice, since current implementation -// actually just calls STL version of these. -// So what is really being checked is the logic to map a function. -TranscedentalTester(std, abs) -TranscedentalTester(std, acos) -TranscedentalTester(std, asin) -TranscedentalTester(std, atan) -TranscedentalTester(std, erf) -TranscedentalTester(std, exp) -TranscedentalTester(std, log) -TranscedentalTester(std, tan) -TranscedentalTester(std, trunc) -TranscedentalTester(std, sqrt) - -TranscedentalTester2(atan2) -TranscedentalTester2(fmod) -TranscedentalTester2(pow) - -TranscedentalTester(Impl, reciprocal) -TranscedentalTester(Impl, rsqrt) -TranscedentalTester(Impl, frac) - -enum class OP_TYPE { - EQ = 0, - NE, - GT, - GE, - LT, - LE, - MIN, - MAX, - ADD, - SUB, - MUL, - DIV, - AND, - OR, - EXOR -}; - -void BasicOpTestHelper(const OP_TYPE& compare_type) { - at::Tensor a = at::rand({23, 23}); - at::Tensor b = at::rand({23, 23}); - at::Tensor ref_res = at::zeros({23, 23}); - at::Tensor vec_res = at::zeros({23, 23}); - - size_t num_els = - (a.numel() / Vec256::size()) * Vec256::size(); - // Vector components - float* a_ptr = a.data_ptr(); - float* b_ptr = b.data_ptr(); - float* ref_res_ptr = ref_res.data_ptr(); - for (size_t i = 0; i < num_els; ++i) { - switch (compare_type) { - case OP_TYPE::EQ: - if (a_ptr[i] == b_ptr[i]) { - ref_res_ptr[i] = 1.0f; - } else { - ref_res_ptr[i] = 0; - } - break; - case OP_TYPE::NE: - if (a_ptr[i] != b_ptr[i]) { - ref_res_ptr[i] = 1.0f; - } else { - ref_res_ptr[i] = 0; - } - break; - case OP_TYPE::GT: - if (a_ptr[i] > b_ptr[i]) { - ref_res_ptr[i] = 1.0f; - } else { - ref_res_ptr[i] = 0; - } - break; - case OP_TYPE::GE: - if (a_ptr[i] >= b_ptr[i]) { - ref_res_ptr[i] = 1.0f; - } else { - ref_res_ptr[i] = 0; - } - break; - case OP_TYPE::LT: - if (a_ptr[i] < b_ptr[i]) { - ref_res_ptr[i] = 1.0f; - } else { - ref_res_ptr[i] = 0; - } - break; - case OP_TYPE::LE: - if (a_ptr[i] <= b_ptr[i]) { - ref_res_ptr[i] = 1.0f; - } else { - ref_res_ptr[i] = 0; - } - break; - case OP_TYPE::MIN: - ref_res_ptr[i] = std::min(a_ptr[i], b_ptr[i]); - break; - case OP_TYPE::MAX: - ref_res_ptr[i] = std::max(a_ptr[i], b_ptr[i]); - break; - case OP_TYPE::ADD: - ref_res_ptr[i] = a_ptr[i] + b_ptr[i]; - break; - case OP_TYPE::SUB: - ref_res_ptr[i] = a_ptr[i] - b_ptr[i]; - break; - case OP_TYPE::MUL: - ref_res_ptr[i] = a_ptr[i] * b_ptr[i]; - break; - case OP_TYPE::DIV: - ref_res_ptr[i] = a_ptr[i] / b_ptr[i]; - break; - case OP_TYPE::OR: - { - uint32_t *a_val, *b_val; - a_val = reinterpret_cast(&a_ptr[i]); - b_val = reinterpret_cast(&b_ptr[i]); - uint32_t c_val = (*a_val) | (*b_val); - float* c_val_float; - c_val_float = reinterpret_cast(&c_val); - ref_res_ptr[i] = *c_val_float; - } - break; - case OP_TYPE::AND: - { - uint32_t *a_val, *b_val; - a_val = reinterpret_cast(&a_ptr[i]); - b_val = reinterpret_cast(&b_ptr[i]); - uint32_t c_val = (*a_val) & (*b_val); - float* c_val_float; - c_val_float = reinterpret_cast(&c_val); - ref_res_ptr[i] = *c_val_float; - } - break; - case OP_TYPE::EXOR: - { - uint32_t *a_val, *b_val; - a_val = reinterpret_cast(&a_ptr[i]); - b_val = reinterpret_cast(&b_ptr[i]); - uint32_t c_val = (*a_val) ^ (*b_val); - float* c_val_float; - c_val_float = reinterpret_cast(&c_val); - ref_res_ptr[i] = *c_val_float; - } - break; - } - } - - // Vectorized impl - float* vec_res_ptr = vec_res.data_ptr(); - for (size_t i = 0; i < num_els; i += Vec256::size()) { - auto a_elements = Vec256::loadu(a_ptr); - auto b_elements = Vec256::loadu(b_ptr); - a_ptr += Vec256::size(); - b_ptr += Vec256::size(); - Vec256 res_elements; - switch (compare_type) { - case OP_TYPE::EQ: - res_elements = a_elements.eq(b_elements); - break; - case OP_TYPE::NE: - res_elements = a_elements.ne(b_elements); - break; - case OP_TYPE::GT: - res_elements = a_elements.gt(b_elements); - break; - case OP_TYPE::GE: - res_elements = a_elements.ge(b_elements); - break; - case OP_TYPE::LT: - res_elements = a_elements.lt(b_elements); - break; - case OP_TYPE::LE: - res_elements = a_elements.le(b_elements); - break; - case OP_TYPE::MIN: - res_elements = at::vec256::minimum(a_elements, b_elements); - break; - case OP_TYPE::MAX: - res_elements = at::vec256::maximum(a_elements, b_elements); - break; - case OP_TYPE::ADD: - res_elements = a_elements + b_elements; - break; - case OP_TYPE::SUB: - res_elements = a_elements - b_elements; - break; - case OP_TYPE::MUL: - res_elements = a_elements * b_elements; - break; - case OP_TYPE::DIV: - res_elements = a_elements / b_elements; - break; - case OP_TYPE::OR: - res_elements = a_elements | b_elements; - break; - case OP_TYPE::AND: - res_elements = a_elements & b_elements; - break; - case OP_TYPE::EXOR: - res_elements = a_elements ^ b_elements; - break; - } - res_elements.store(vec_res_ptr); - vec_res_ptr += Vec256::size(); - } - ASSERT_TRUE(check_equal(ref_res, vec_res)); -} - -// Checks both loads and stores. -TEST(Vec256TestFloat, CopyTest) { - at::Tensor a = at::rand({23, 23}); - at::Tensor b = at::zeros({23, 23}); - // Copy goes through vec256 via tensoriterator - b.copy_(a); - ASSERT_TRUE(check_equal(a, b)); -} - -TEST(Vec256TestFloat, arangeTest) { - at::Tensor arange_output_ref = at::zeros({8}); - at::Tensor arange_output_vectorized = at::zeros({8}); - float base = 7.f; - float step = 5.f; - float* ref_output_ptr = arange_output_ref.data_ptr(); - for (int64_t i = 0; i < 8; ++i) { - ref_output_ptr[i] = base + i * step; - } - float* vec_output_ptr = arange_output_vectorized.data_ptr(); - auto arange_output = Vec256::arange(base, step); - arange_output.store(vec_output_ptr); - ASSERT_TRUE(check_equal(arange_output_ref, arange_output_vectorized)); -} - -// Checks blend and blendv. -TEST(Vec256TestFloat, Blend) { - at::Tensor a = at::rand({23, 23}); - at::Tensor b = at::rand({23, 23}); - at::Tensor ref_res = at::zeros({23, 23}); - at::Tensor vec_res = at::zeros({23, 23}); - - // Check templatized blend. - // Reference result: - const int64_t mask = 0xC5; - // Only check over multiple of Vec::size elements - size_t num_els = - (a.numel() / Vec256::size()) * Vec256::size(); - // Vector components - float* a_ptr = a.data_ptr(); - float* b_ptr = b.data_ptr(); - float* ref_res_ptr = ref_res.data_ptr(); - int64_t tmp_mask = mask; - for (size_t i = 0; i < num_els; ++i) { - if (i % Vec256::size() == 0) { - tmp_mask = mask; - } - if (tmp_mask & 0x1) { - ref_res_ptr[i] = b_ptr[i]; - } else { - ref_res_ptr[i] = a_ptr[i]; - } - tmp_mask = tmp_mask >> 1; - } - - // Vectorized impl - float* vec_res_ptr = vec_res.data_ptr(); - for (size_t i = 0; i < num_els; i += Vec256::size()) { - auto a_elements = Vec256::loadu(a_ptr); - auto b_elements = Vec256::loadu(b_ptr); - a_ptr += Vec256::size(); - b_ptr += Vec256::size(); - auto res_elements = Vec256::blend(a_elements, b_elements); - res_elements.store(vec_res_ptr); - vec_res_ptr += Vec256::size(); - } - ASSERT_TRUE(check_equal(ref_res, vec_res)); - - // Vector components - a_ptr = a.data_ptr(); - b_ptr = b.data_ptr(); - int32_t full_int_mask = 0xFFFFFFFF; - float* full_ptr = reinterpret_cast(&full_int_mask); - float full_float_mask = *full_ptr; - Vec256 float_mask(full_float_mask, 0.f, full_float_mask, 0.f, - 0.f, full_float_mask, 0.f, 0.f); - float float_mask_array[Vec256::size()]; - float_mask.store(float_mask_array); - ref_res_ptr = ref_res.data_ptr(); - for (size_t i = 0; i < num_els; ++i) { - if (float_mask_array[i % Vec256::size()] != 0) { - ref_res_ptr[i] = b_ptr[i]; - } else { - ref_res_ptr[i] = a_ptr[i]; - } - tmp_mask = tmp_mask >> 1; - } - - // Vectorized impl - vec_res_ptr = vec_res.data_ptr(); - for (size_t i = 0; i < num_els; i += Vec256::size()) { - auto a_elements = Vec256::loadu(a_ptr); - auto b_elements = Vec256::loadu(b_ptr); - a_ptr += Vec256::size(); - b_ptr += Vec256::size(); - auto res_elements = Vec256::blendv(a_elements, b_elements, float_mask); - res_elements.store(vec_res_ptr); - vec_res_ptr += Vec256::size(); - } - ASSERT_TRUE(check_equal(ref_res, vec_res)); -} - -// Checks Set -TEST(Vec256TestFloat, Set) { - at::Tensor a = at::rand({23, 23}); - at::Tensor b = at::rand({23, 23}); - at::Tensor ref_res = at::zeros({23, 23}); - at::Tensor vec_res = at::zeros({23, 23}); - - const float* a_ptr = a.data_ptr(); - const float* b_ptr = b.data_ptr(); - float* ref_res_ptr = ref_res.data_ptr(); - float* vec_res_ptr = vec_res.data_ptr(); - - // Only check over multiple of Vec::size elements - const size_t num_els = (a.numel() / Vec256::size()); - BlendTestHelperScalar(a_ptr, b_ptr, ref_res_ptr, num_els, 0); - BlendTestHelperVector(a_ptr, b_ptr, vec_res_ptr, num_els, 0); - ASSERT_TRUE(check_equal(ref_res, vec_res)); - BlendTestHelperScalar(a_ptr, b_ptr, ref_res_ptr, num_els, 1); - BlendTestHelperVector(a_ptr, b_ptr, vec_res_ptr, num_els, 1); - ASSERT_TRUE(check_equal(ref_res, vec_res)); - BlendTestHelperScalar(a_ptr, b_ptr, ref_res_ptr, num_els, 4); - BlendTestHelperVector(a_ptr, b_ptr, vec_res_ptr, num_els, 4); - ASSERT_TRUE(check_equal(ref_res, vec_res)); - BlendTestHelperScalar(a_ptr, b_ptr, ref_res_ptr, num_els, 6); - BlendTestHelperVector(a_ptr, b_ptr, vec_res_ptr, num_els, 6); - ASSERT_TRUE(check_equal(ref_res, vec_res)); - BlendTestHelperScalar(a_ptr, b_ptr, ref_res_ptr, num_els, 8); - BlendTestHelperVector(a_ptr, b_ptr, vec_res_ptr, num_els, 8); - ASSERT_TRUE(check_equal(ref_res, vec_res)); -} - -TEST(Vec256TestFloat, Abs) { - TranscedentalHelper_abs(); -} - -TEST(Vec256TestFloat, acos) { - TranscedentalHelper_acos(); -} - -TEST(Vec256TestFloat, asin) { - TranscedentalHelper_asin(); -} - -TEST(Vec256TestFloat, atan) { - TranscedentalHelper_atan(); -} - -TEST(Vec256TestFloat, erf) { - TranscedentalHelper_erf(); -} - -TEST(Vec256TestFloat, exp) { - TranscedentalHelper_exp(); -} - -TEST(Vec256TestFloat, tan) { - TranscedentalHelper_tan(); -} - -TEST(Vec256TestFloat, log) { - TranscedentalHelper_log(); -} - -TEST(Vec256TestFloat, trunc) { - TranscedentalHelper_trunc(); -} - -TEST(Vec256TestFloat, sqrt) { - TranscedentalHelper_sqrt(); -} - -TEST(Vec256TestFloat, atan2) { - TranscedentalHelper_atan2(); -} - -TEST(Vec256TestFloat, fmod) { - TranscedentalHelper_fmod(); -} - -TEST(Vec256TestFloat, pow) { - TranscedentalHelper_pow(); -} - -TEST(Vec256TestFloat, reciprocal) { - TranscedentalHelper_reciprocal(1e-3); -} - -TEST(Vec256TestFloat, rsqrt) { - // rsqrt tolerance is much worse. - // If we did not set seed even this is violated sometimes. - TranscedentalHelper_rsqrt(5e-3); -} - -TEST(Vec256TestFloat, frac) { - TranscedentalHelper_frac(); -} - -TEST(Vec256TestFloat, compare_eq) { - BasicOpTestHelper(OP_TYPE::EQ); -} - -TEST(Vec256TestFloat, compare_ne) { - BasicOpTestHelper(OP_TYPE::NE); -} - -TEST(Vec256TestFloat, compare_gt) { - BasicOpTestHelper(OP_TYPE::GT); -} - -TEST(Vec256TestFloat, compare_ge) { - BasicOpTestHelper(OP_TYPE::GE); -} - -TEST(Vec256TestFloat, compare_lt) { - BasicOpTestHelper(OP_TYPE::LT); -} - -TEST(Vec256TestFloat, compare_le) { - BasicOpTestHelper(OP_TYPE::LE); -} - -TEST(Vec256TestFloat, check_min) { - BasicOpTestHelper(OP_TYPE::MIN); -} - -TEST(Vec256TestFloat, check_max) { - BasicOpTestHelper(OP_TYPE::MAX); -} - -TEST(Vec256TestFloat, compare_add) { - BasicOpTestHelper(OP_TYPE::ADD); -} - -TEST(Vec256TestFloat, compare_sub) { - BasicOpTestHelper(OP_TYPE::SUB); -} - -TEST(Vec256TestFloat, check_mul) { - BasicOpTestHelper(OP_TYPE::MUL); -} - -TEST(Vec256TestFloat, check_div) { - BasicOpTestHelper(OP_TYPE::DIV); -} - -TEST(Vec256TestFloat, compare_or) { - BasicOpTestHelper(OP_TYPE::OR); -} - -TEST(Vec256TestFloat, check_and) { - BasicOpTestHelper(OP_TYPE::AND); -} - -TEST(Vec256TestFloat, check_xor) { - BasicOpTestHelper(OP_TYPE::EXOR); -} - -TEST(Vec256TestFloat, check_convert) { - at::Tensor a = at::rand({23, 23}); - a = a * -10; - a = a + 10; - at::Tensor ref_res = - at::empty({23, 23}, at::device(at::kCPU).dtype(at::kInt)); - at::Tensor vec_res = - at::empty({23, 23}, at::device(at::kCPU).dtype(at::kInt)); - float* a_float_ptr = a.data_ptr(); - int32_t* ref_res_int_ptr = ref_res.data_ptr(); - int32_t* vec_res_int_ptr = vec_res.data_ptr(); - for(auto i = 0; i < a.numel(); ++i) { - ref_res_int_ptr[i] = static_cast(a_float_ptr[i]); - } - at::vec256::convert(a_float_ptr, vec_res_int_ptr, a.numel()); - ASSERT_TRUE(check_almost_equal(ref_res, vec_res, 1e-6)); - - a = at::randint(-100, 100, {23, 23}); - a = a.to(at::kInt); - ref_res = at::empty({23, 23}); - vec_res = at::empty({23, 23}); - int32_t* a_int_ptr = a.data_ptr(); - float* ref_res_float_ptr = ref_res.data_ptr(); - float* vec_res_float_ptr = vec_res.data_ptr(); - for(auto i = 0; i < a.numel(); ++i) { - ref_res_float_ptr[i] = static_cast(a_int_ptr[i]); - } - at::vec256::convert(a_int_ptr, vec_res_float_ptr, a.numel()); - ASSERT_TRUE(check_almost_equal(ref_res, vec_res, 1e-6)); -} - -TEST(Vec256TestFloat, check_fmadd) { - at::Tensor a = at::rand({23, 23}); - a = a * -10; - a = a + 10; - at::Tensor b = at::rand({23, 23}); - b = b * -5; - b = b + 5; - at::Tensor c = at::rand({23, 23}); - c = c * 20; - at::Tensor ref_res = at::zeros({23, 23}); - at::Tensor vec_res = at::zeros({23, 23}); - float* a_ptr = a.data_ptr(); - float* b_ptr = a.data_ptr(); - float* c_ptr = a.data_ptr(); - float* ref_res_ptr = ref_res.data_ptr(); - float* vec_res_ptr = vec_res.data_ptr(); - size_t num_els = - (a.numel() / Vec256::size()) * Vec256::size(); - for(auto i = 0; i < num_els; ++i) { - ref_res_ptr[i] = a_ptr[i] * b_ptr[i] + c_ptr[i]; - } - for (size_t i = 0; i < num_els; i += Vec256::size()) { - auto a_elements = Vec256::loadu(a_ptr); - auto b_elements = Vec256::loadu(b_ptr); - auto c_elements = Vec256::loadu(c_ptr); - a_ptr += Vec256::size(); - b_ptr += Vec256::size(); - c_ptr += Vec256::size(); - auto res_elements = at::vec256::fmadd(a_elements, b_elements, c_elements); - res_elements.store(vec_res_ptr); - vec_res_ptr += Vec256::size(); - } - ASSERT_TRUE(check_almost_equal(ref_res, vec_res, 1e-6)); -} - -int main(int argc, char* argv[]) { - ::testing::InitGoogleTest(&argc, argv); - at::manual_seed(42); - return RUN_ALL_TESTS(); -} diff --git a/aten/src/ATen/test/vec256_test_all_types.cpp b/aten/src/ATen/test/vec256_test_all_types.cpp index 713de26baddc5..595b2eefd60ef 100644 --- a/aten/src/ATen/test/vec256_test_all_types.cpp +++ b/aten/src/ATen/test/vec256_test_all_types.cpp @@ -22,6 +22,8 @@ namespace { template class SqrtAndReciprocalReal : public ::testing::Test {}; template + class FractionAndRemainderReal : public ::testing::Test {}; + template class Trigonometric : public ::testing::Test {}; template class ErrorFunctions : public ::testing::Test {}; @@ -42,6 +44,8 @@ namespace { template class Pow : public ::testing::Test {}; template + class RangeFactories : public ::testing::Test {}; + template class BitwiseFloatsAdditional : public ::testing::Test {}; template class BitwiseFloatsAdditional2 : public ::testing::Test {}; @@ -57,7 +61,6 @@ namespace { using QuantTestedTypes = ::testing::Types; using RealFloatIntTestedTypes = ::testing::Types; using FloatIntTestedTypes = ::testing::Types; - using SingleFloat = ::testing::Types; using ComplexTypes = ::testing::Types; TYPED_TEST_CASE(Memory, ALLTestedTypes); TYPED_TEST_CASE(Arithmetics, FloatIntTestedTypes); @@ -69,6 +72,7 @@ namespace { TYPED_TEST_CASE(Rounding, RealFloatTestedTypes); TYPED_TEST_CASE(SqrtAndReciprocal, FloatTestedTypes); TYPED_TEST_CASE(SqrtAndReciprocalReal, RealFloatTestedTypes); + TYPED_TEST_CASE(FractionAndRemainderReal, RealFloatTestedTypes); TYPED_TEST_CASE(Trigonometric, RealFloatTestedTypes); TYPED_TEST_CASE(ErrorFunctions, RealFloatTestedTypes); TYPED_TEST_CASE(Exponents, RealFloatTestedTypes); @@ -80,6 +84,7 @@ namespace { TYPED_TEST_CASE(LogarithmReals, RealFloatTestedTypes); TYPED_TEST_CASE(Pow, RealFloatTestedTypes); TYPED_TEST_CASE(RealTests, RealFloatTestedTypes); + TYPED_TEST_CASE(RangeFactories, FloatIntTestedTypes); TYPED_TEST_CASE(BitwiseFloatsAdditional, RealFloatTestedTypes); TYPED_TEST_CASE(BitwiseFloatsAdditional2, FloatTestedTypes); TYPED_TEST_CASE(QuantizationTests, QuantTestedTypes); @@ -198,7 +203,6 @@ namespace { [](vec v) { return v.sqrt(); }, createDefaultUnaryTestCase(TestSeed(), false, true)); } - TYPED_TEST(SqrtAndReciprocalReal, RSqrt) { using vec = TypeParam; test_unary( @@ -217,6 +221,23 @@ namespace { createDefaultUnaryTestCase(TestSeed()), RESOLVE_OVERLOAD(filter_zero)); } + TYPED_TEST(FractionAndRemainderReal, Frac) { + using vec = TypeParam; + test_unary( + NAME_INFO(frac), + RESOLVE_OVERLOAD(frac), + [](vec v) { return v.frac(); }, + createDefaultUnaryTestCase(TestSeed(), false, true)); + } + TYPED_TEST(FractionAndRemainderReal, Fmod) { + using vec = TypeParam; + test_binary( + NAME_INFO(fmod), + RESOLVE_OVERLOAD(std::fmod), + [](vec v0, vec v1) { return v0.fmod(v1); }, + createDefaultBinaryTestCase(TestSeed()), + RESOLVE_OVERLOAD(filter_fmod)); + } TYPED_TEST(Trigonometric, Sin) { using vec = TypeParam; using UVT = UvalueType; @@ -702,29 +723,74 @@ namespace { ASSERT_EQ(expected, actual) << "Failure Details:\n" << std::hex << "Expected:\n#\t" << expected << "\nActual:\n#\t" << actual; - } // + } + } + TYPED_TEST(BitwiseFloatsAdditional, Convert) { + using vec = TypeParam; + using VT = ValueType; + using IntVT = at::vec256::int_same_size_t; + + // verify float to int + CACHE_ALIGN VT input1[vec::size()]; + CACHE_ALIGN IntVT expected_vals1[vec::size()]; + CACHE_ALIGN IntVT actual_vals1[vec::size()]; + for (int64_t i = 0; i < vec::size(); i++) { + input1[i] = (VT)i * (VT)2.1 + (VT)0.5; + expected_vals1[i] = static_cast(input1[i]); + } + at::vec256::convert(input1, actual_vals1, vec::size()); + auto expected1 = VecType::loadu(expected_vals1); + auto actual1 = VecType::loadu(actual_vals1); + if (AssertVec256>(NAME_INFO(test_convert_to_int), expected1, actual1).check()) { + return; + } + + // verify int to float + CACHE_ALIGN IntVT input2[vec::size()]; + CACHE_ALIGN VT expected_vals2[vec::size()]; + CACHE_ALIGN VT actual_vals2[vec::size()]; + for (int64_t i = 0; i < vec::size(); i++) { + input2[i] = (IntVT)i * (IntVT)2 + (IntVT)1; + expected_vals2[i] = (VT)input2[i]; + } + at::vec256::convert(input2, actual_vals2, vec::size()); + auto expected2 = vec::loadu(expected_vals2); + auto actual2 = vec::loadu(actual_vals2); + AssertVec256(NAME_INFO(test_convert_to_float), expected2, actual2).check(); + } + TYPED_TEST(BitwiseFloatsAdditional, Fmadd) { + using vec = TypeParam; + using VT = ValueType; + + auto test_case = TestingCase::getBuilder() + .addDomain(CheckWithinDomains{ + {{(VT)-1000, (VT)1000}, {(VT)-1000, (VT)1000}, {(VT)-1000, (VT)1000}}, + true, getDefaultTolerance()}) + .setTestSeed(TestSeed()); + + test_ternary( + NAME_INFO(clamp), RESOLVE_OVERLOAD(local_fmadd), + [](const vec& v0, const vec& v1, const vec& v2) { + return at::vec256::fmadd(v0, v1, v2); + }, + test_case, + RESOLVE_OVERLOAD(filter_fmadd)); } template typename std::enable_if_t<(mask < 0 || mask> 255), void> - test_blend(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()]) + test_blend(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()]) { } template typename std::enable_if_t<(mask >= 0 && mask <= 255), void> - test_blend(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()]) - { - //generate expected_val + test_blend(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()]) { + // generate expected_val int64_t m = mask; for (int64_t i = 0; i < vec::size(); i++) { - if (m & 0x01) { - expected_val[i] = b[i]; - } - else { - expected_val[i] = a[i]; - } + expected_val[i] = (m & 0x01) ? b[i] : a[i]; m = m >> 1; } - //test with blend + // test with blend auto vec_a = vec::loadu(a); auto vec_b = vec::loadu(b); auto expected = vec::loadu(expected_val); @@ -733,6 +799,47 @@ namespace { if (AssertVec256(std::string(NAME_INFO(test_blend)) + mask_str, expected, actual).check()) return; test_blend(expected_val, a, b); } + template + std::enable_if_t<(!is_complex::value && idx == N), bool> + test_blendv(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()], VT mask[vec::size()]) { + // generate expected_val + for (int64_t i = 0; i < vec::size(); i++) { + int64_t hex_mask = 0; + std::memcpy(&hex_mask, &mask[i], sizeof(VT)); + expected_val[i] = (hex_mask & 0x01) ? b[i] : a[i]; + } + // test with blendv + auto vec_a = vec::loadu(a); + auto vec_b = vec::loadu(b); + auto vec_m = vec::loadu(mask); + auto expected = vec::loadu(expected_val); + auto actual = vec::blendv(vec_a, vec_b, vec_m); + auto mask_str = std::string("\nblendv mask: "); + for (int64_t i = 0; i < vec::size(); i++) { + mask_str += std::to_string(mask[i]) + " "; + } + if (AssertVec256(std::string(NAME_INFO(test_blendv)) + mask_str, expected, actual).check()) { + return false; + } + return true; + } + template + std::enable_if_t<(!is_complex::value && idx != N), bool> + test_blendv(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()], VT mask[vec::size()]) { + // shuffle mask and do blendv test + VT m = mask[idx]; + if (!test_blendv(expected_val, a, b, mask)) return false; + if (m != (VT)0) { + mask[idx] = (VT)0; + } + else { + int64_t hex_mask = 0xFFFFFFFFFFFFFFFF; + std::memcpy(&mask[idx], &hex_mask, sizeof(VT)); + } + if (!test_blendv(expected_val, a, b, mask)) return false; + mask[idx] = m; + return true; + } template void blend_init(T(&a)[N], T(&b)[N]) { a[0] = (T)1.0; @@ -760,6 +867,16 @@ namespace { a[1] = a[0] + add; b[1] = b[0] + add; } + TYPED_TEST(BitwiseFloatsAdditional, Blendv) { + using vec = TypeParam; + using VT = ValueType; + CACHE_ALIGN VT a[vec::size()]; + CACHE_ALIGN VT b[vec::size()]; + CACHE_ALIGN VT mask[vec::size()] = {0}; + CACHE_ALIGN VT expected_val[vec::size()]; + blend_init(a, b); + test_blendv(expected_val, a, b, mask); + } TYPED_TEST(BitwiseFloatsAdditional2, Blend) { using vec = TypeParam; using VT = ValueType; @@ -770,6 +887,60 @@ namespace { constexpr int64_t power_sets = 1LL << (vec::size()); test_blend(expected_val, a, b); } + template + void test_set(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()], int64_t count){ + if (count < 0) return; + //generate expected_val + for (int64_t i = 0; i < vec::size(); i++) { + expected_val[i] = (i < count) ? b[i] : a[i]; + } + // test with set + auto vec_a = vec::loadu(a); + auto vec_b = vec::loadu(b); + auto expected = vec::loadu(expected_val); + auto actual = vec::set(vec_a, vec_b, count); + + auto count_str = std::string("\ncount: ") + std::to_string(count); + if (AssertVec256(std::string(NAME_INFO(test_set)) + count_str, expected, actual).check()) { + return; + } + test_set(expected_val, a, b, (count == 0 ? -1 : count / 2)); + } + TYPED_TEST(BitwiseFloatsAdditional2, Set) { + using vec = TypeParam; + using VT = ValueType; + CACHE_ALIGN VT a[vec::size()]; + CACHE_ALIGN VT b[vec::size()]; + CACHE_ALIGN VT expected_val[vec::size()]; + blend_init(a, b); + test_set(expected_val, a, b, vec::size()); + } + template + std::enable_if_t::value, void> + arange_init(T& base, T& step) { + base = (T)5.0; + step = (T)2.0; + } + template + std::enable_if_t::value, void> + arange_init(T& base, T& step) { + base = T(5.0, 5.0); + step = T(2.0, 3.0); + } + TYPED_TEST(RangeFactories, Arange) { + using vec = TypeParam; + using VT = ValueType; + using UVT = UvalueType; + CACHE_ALIGN VT expected_val[vec::size()]; + VT base, step; + arange_init(base, step); + for (int64_t i = 0; i < vec::size(); i++) { + expected_val[i] = base + VT((UVT)i) * step; + } + auto expected = vec::loadu(expected_val); + auto actual = vec::arange(base, step); + AssertVec256(NAME_INFO(test_arange), expected, actual).check(); + } TEST(ComplexTests, TestComplexFloatImagRealConj) { float aa[] = { 1.5488e-28,2.5488e-28,3.5488e-28,4.5488e-28,5.5488e-28,6.5488e-28,7.5488e-28,8.5488e-28 }; float exp[] = { aa[0],0,aa[2],0,aa[4],0,aa[6],0 }; @@ -976,14 +1147,15 @@ namespace { DomainRange{(VT)fake_zp, (VT)fake_qsix} }}) .setTestSeed(TestSeed()); - test_ternary( - NAME_INFO(relu6), - RESOLVE_OVERLOAD(relu6), - [](/*const*/ vec& v0, const vec& v1, const vec& v2) { - return v0.relu6(v1, v2); - }, - test_case); + test_ternary( + NAME_INFO(relu6), + RESOLVE_OVERLOAD(relu6), + [](/*const*/ vec& v0, const vec& v1, const vec& v2) { + return v0.relu6(v1, v2); + }, + test_case); } + #else #error GTEST does not have TYPED_TEST #endif diff --git a/aten/src/ATen/test/vec256_test_all_types.h b/aten/src/ATen/test/vec256_test_all_types.h index 3226af8422d13..9f97c9d1bc462 100644 --- a/aten/src/ATen/test/vec256_test_all_types.h +++ b/aten/src/ATen/test/vec256_test_all_types.h @@ -18,7 +18,7 @@ #define not_inline __attribute__((noinline)) #elif defined(_WIN32) #define CACHE_ALIGN __declspec(align(CACHE_LINE)) -#define not_inline __declspec(noinline) +#define not_inline __declspec(noinline) #else CACHE_ALIGN #define #define not_inline @@ -41,7 +41,7 @@ CACHE_ALIGN #define } #if defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_AVX2) && (defined(__GNUC__) || defined(__GNUG__)) -#undef CHECK_DEQUANT_WITH_LOW_PRECISION +#undef CHECK_DEQUANT_WITH_LOW_PRECISION #define CHECK_WITH_FMA 1 #elif !defined(CPU_CAPABILITY_VSX) && !defined(CPU_CAPABILITY_AVX2) #undef CHECK_DEQUANT_WITH_LOW_PRECISION @@ -352,6 +352,11 @@ T rsqrt(T x) { return 1 / std::sqrt(x); } +template +T frac(T x) { + return x - std::trunc(x); +} + template T maximum(const T& a, const T& b) { return (a > b) ? a : b; @@ -405,13 +410,36 @@ void filter_clamp(T& f, T& s, T& t) { } } +template +std::enable_if_t::value, void> filter_fmod(T& a, T& b) { + // This is to make sure fmod won't cause overflow when doing the div + if (std::abs(b) < (T)1) { + b = b < (T)0 ? (T)-1 : T(1); + } +} + +template +std::enable_if_t::value, void> filter_fmadd(T& a, T& b, T& c) { + // This is to setup a limit to make sure fmadd (a * b + c) won't overflow + T max = std::sqrt(std::numeric_limits::max()) / T(2.0); + T min = ((T)0 - max); + + if (a > max) a = max; + else if (a < min) a = min; + + if (b > max) b = max; + else if (b < min) b = min; + + if (c > max) c = max; + else if (c < min) c = min; +} + template void filter_zero(T& val) { val = is_zero(val) ? (T)1 : val; } template -std::enable_if_t>::value, void> filter_zero(Complex& val) -{ +std::enable_if_t>::value, void> filter_zero(Complex& val) { T rr = val.real(); T ii = val.imag(); rr = is_zero(rr) ? (T)1 : rr; @@ -445,7 +473,7 @@ std::enable_if_t < !is_complex::value, void> filter_add_overflow(T& a, T& b) T max = std::numeric_limits::max(); T min = std::numeric_limits::min(); // min <= (a +b) <= max; - // min - b <= a <= max - b + // min - b <= a <= max - b if (b < 0) { if (a < min - b) { a = min - b; @@ -464,7 +492,7 @@ std::enable_if_t < !is_complex::value, void> filter_sub_overflow(T& a, T& b) T max = std::numeric_limits::max(); T min = std::numeric_limits::min(); // min <= (a-b) <= max; - // min + b <= a <= max +b + // min + b <= a <= max +b if (b < 0) { if (a > max + b) { a = max + b; @@ -504,7 +532,7 @@ filter_mult_overflow(T& val1, T& val2) { // correct first; val1 = c; } - } // is_zero + } // is_zero } template @@ -929,8 +957,8 @@ void test_unary( AssertVec256 vecAssert(testNameInfo, seed, vec_expected, actual, input); if (vecAssert.check(bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) return; - }// trial - //inrease Seed + }// trial + //inrease Seed changeSeedBy += 1; } for (auto& custom : testCase.getCustomChecks()) { @@ -1056,7 +1084,7 @@ void test_ternary( auto vec_expected = vec_type::loadu(expected); AssertVec256 vecAssert(testNameInfo, seed, vec_expected, actual, input0, input1, input2); if (vecAssert.check(bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) return; - }// trial + }// trial changeSeedBy += 1; } } @@ -1171,7 +1199,7 @@ std::enable_if_t>::value, Complex> local_division(Compl return x / y; #else //re = (ac + bd)/abs_2() - //im = (bc - ad)/abs_2() + //im = (bc - ad)/abs_2() T x_real = x.real(); T x_imag = x.imag(); T y_real = y.real(); @@ -1204,6 +1232,13 @@ std::enable_if_t>::value, Complex> local_division(Compl } +template +std::enable_if_t::value, T> local_fmadd(T a, T b, T c) { + PreventFma noFma; + T ab = a * b; + return noFma.add(ab, c); +} + template std::enable_if_t::value, T> local_sqrt(T x) { return std::sqrt(x); @@ -1211,22 +1246,7 @@ std::enable_if_t::value, T> local_sqrt(T x) { template std::enable_if_t>::value, Complex> local_sqrt(Complex x) { -#if defined(TEST_AGAINST_DEFAULT) return std::sqrt(x); -#else - PreventFma noFma; - // sqrt(2) / 2 * [sqrt(abs() + a) + sgn(b) * sqrt(abs() - a)i] - T real = x.real(); - T imag = x.imag(); - T abs = local_abs(x).real(); - T sqrt2_2 = std::sqrt(static_cast(2)) / static_cast(2); - T abs_r = noFma.add(abs, real); - T abs_i = noFma.sub(abs, real); - T res_r = sqrt2_2 * std::sqrt(abs_r); - T res_i = sqrt2_2 * std::sqrt(abs_i); - if (std::signbit(imag)) res_i = -res_i; - return Complex(res_r, res_i); -#endif } template @@ -1236,26 +1256,7 @@ std::enable_if_t::value, T> local_asin(T x) { template std::enable_if_t>::value, Complex> local_asin(Complex x) { -#if defined(TEST_AGAINST_DEFAULT) return std::asin(x); -#else - // asin(x) - // = -i*ln(iz + sqrt(1 -z^2)) - // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) - // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) - PreventFma noFma; - T a = x.real(); - T b = x.imag(); - T aa = a * a; - T bb = b * b; - T _ab = a * (-b); - T _2ab = noFma.add(_ab, _ab); - T aa_bb = static_cast(1) - noFma.sub(aa, bb); // 1 - (a*a-b*b) - Complex temp = Complex(-b, a) + local_sqrt(Complex(aa_bb, _2ab)); - auto ln = std::log(temp); - //-i*ln() => -i * ln => (ln.imag, -ln.real) - return Complex(ln.imag(), -ln.real()); -#endif } template @@ -1265,13 +1266,7 @@ std::enable_if_t::value, T> local_acos(T x) { template std::enable_if_t>::value, Complex> local_acos(Complex x) { -#if defined(TEST_AGAINST_DEFAULT) return std::acos(x); -#else - // pi/2 - asin(x) - auto half_pi = static_cast(M_PI) / static_cast(2); - return Complex(half_pi, 0) - local_asin(x); -#endif } template @@ -1373,9 +1368,9 @@ float dequantize_val(float scale, int64_t zero_point, T value) { float neg_p = -(zero_point * scale); float v = static_cast(value); float ret = fma(v, scale, neg_p); -#else +#else float ret = (static_cast(value) - zero_point) * scale; -#endif +#endif return ret; } @@ -1415,7 +1410,7 @@ TestingCase createDefaultUnaryTestCase(TestSeed seed = TestSeed(), bool bitwi using UVT = UvalueType; TestingCase testCase; if (!bitwise && std::is_floating_point::value) { - //for float types lets add manual ranges + //for float types lets add manual ranges UVT tolerance = getDefaultTolerance(); testCase = TestingCase::getBuilder() .set(bitwise, false) @@ -1443,7 +1438,7 @@ TestingCase createDefaultBinaryTestCase(TestSeed seed = TestSeed(), bool bitw using UVT = UvalueType; TestingCase testCase; if (!bitwise && std::is_floating_point::value) { - //for float types lets add manual ranges + //for float types lets add manual ranges UVT tolerance = getDefaultTolerance(); testCase = TestingCase::getBuilder() .set(bitwise, false) diff --git a/aten/src/ATen/test/vmap_test.cpp b/aten/src/ATen/test/vmap_test.cpp index 32aadc25b383e..87c59965f1dc2 100644 --- a/aten/src/ATen/test/vmap_test.cpp +++ b/aten/src/ATen/test/vmap_test.cpp @@ -15,8 +15,7 @@ TEST(VmapTest, TestBatchedTensor) { ASSERT_EQ(x.sizes(), expected_size); ASSERT_EQ(x.dim(), 2); ASSERT_EQ(x.numel(), 8); - ASSERT_THROW(x.strides(), c10::Error); - ASSERT_THROW(x.is_contiguous(), c10::Error); + ASSERT_EQ(x.is_contiguous(), false); ASSERT_THROW(x.storage(), c10::Error); ASSERT_THROW(x.storage_offset(), c10::Error); } @@ -297,7 +296,7 @@ TEST(VmapTest, TestVmapPhysicalViewNewLogicalFromPhysical) { VmapPhysicalView physical_view(ones({2, 3, 4}), /*levels = {2}*/4); Tensor physical = ones({2, 6, 7}); - auto result = physical_view.newLogicalFromPhysical(physical); + auto result = physical_view.getPhysicalToLogicalMap().apply(physical); auto* batched = maybeGetBatchedImpl(result); ASSERT_TRUE(batched != nullptr); ASSERT_TRUE(batched->value().is_same(physical)); @@ -308,7 +307,7 @@ TEST(VmapTest, TestVmapPhysicalViewNewLogicalFromPhysical) { VmapPhysicalView physical_view(ones({2, 3, 4, 5, 6}), /*levels = {1, 3, 4}*/2 | 8 | 16); Tensor physical = ones({2, 3, 4, 7}); - auto result = physical_view.newLogicalFromPhysical(physical); + auto result = physical_view.getPhysicalToLogicalMap().apply(physical); auto* batched = maybeGetBatchedImpl(result); ASSERT_TRUE(batched != nullptr); ASSERT_TRUE(batched->value().is_same(physical)); @@ -319,7 +318,7 @@ TEST(VmapTest, TestVmapPhysicalViewNewLogicalFromPhysical) { VmapPhysicalView physical_view(ones({2}), /*levels = {2}*/4); Tensor physical = ones({2}); - auto result = physical_view.newLogicalFromPhysical(physical); + auto result = physical_view.getPhysicalToLogicalMap().apply(physical); auto* batched = maybeGetBatchedImpl(result); ASSERT_TRUE(batched != nullptr); ASSERT_TRUE(batched->value().is_same(physical)); diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index ebf9ffce99d04..cbd65fd9b68fa 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -1,11 +1,934 @@ +#ifdef USE_VULKAN_API + #include +#include -#ifdef USE_VULKAN_API +// TODO: These functions should move to a common place. + +namespace { + +bool checkRtol(const at::Tensor& diff, const std::vector& inputs) { + float maxValue = 0.0f; -#include + for (const auto& tensor : inputs) { + maxValue = fmax(tensor.abs().max().item(), maxValue); + } + +#ifdef USE_VULKAN_FP16_INFERENCE + constexpr float tolerance = 1e-2; +#else + constexpr float tolerance = 1e-5; +#endif + + return diff.abs().max().item() < (tolerance * maxValue); +} + +bool almostEqual(const at::Tensor& a, const at::Tensor& b) { + return checkRtol(a - b, {a, b}); +} + +bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) { + return (a - b).abs().max().item() == 0.0f; +} + +} // namespace namespace { +TEST(VulkanAPITest, adaptive_avg_pool2d) { + if (!at::is_vulkan_available()) { + return; + } + + const auto in_cpu = at::rand({5, 7, 47, 31}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + const auto out_cpu = at::adaptive_avg_pool2d(in_cpu, {3, 3}); + const auto out_vulkan = at::adaptive_avg_pool2d(in_cpu.vulkan(), {3, 3}); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << out_cpu << std::endl; + std::cout << "Got:\n" << out_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, add) { + if (!at::is_vulkan_available()) { + return; + } + + const auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + const auto a_vulkan = a_cpu.vulkan(); + + const auto b_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + const auto b_vulkan = b_cpu.vulkan(); + + const auto c_cpu = at::add(a_cpu, b_cpu, 2.1f); + const auto c_vulkan = at::add(a_vulkan, b_vulkan, 2.1f); + + const auto check = almostEqual(c_cpu, c_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << c_cpu << std::endl; + std::cout << "Got:\n" << c_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, add_) { + if (!at::is_vulkan_available()) { + return; + } + + auto a_cpu = at::rand({61, 17, 29, 83}, at::device(at::kCPU).dtype(at::kFloat)); + auto a_vulkan = a_cpu.vulkan(); + + const auto b_cpu = at::rand({61, 17, 29, 83}, at::device(at::kCPU).dtype(at::kFloat)); + const auto b_vulkan = b_cpu.vulkan(); + + a_cpu.add_(b_cpu, 2.1f); + a_vulkan.add_(b_vulkan, 2.1f); + + const auto check = almostEqual(a_cpu, a_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << a_cpu << std::endl; + std::cout << "Got:\n" << a_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, add_scalar) { + if (!at::is_vulkan_available()) { + return; + } + + const auto a_cpu = at::rand({13, 23, 59, 73}, at::device(at::kCPU).dtype(at::kFloat)); + const auto a_vulkan = a_cpu.vulkan(); + + const float b_scalar = 3.1415f; + + const auto c_cpu = at::add(a_cpu, b_scalar, 2.1f); + const auto c_vulkan = at::add(a_vulkan, b_scalar, 2.1f); + + const auto check = almostEqual(c_cpu, c_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << c_cpu << std::endl; + std::cout << "Got:\n" << c_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, add_scalar_) { + if (!at::is_vulkan_available()) { + return; + } + + auto a_cpu = at::rand({47, 2, 23, 97}, at::device(at::kCPU).dtype(at::kFloat)); + auto a_vulkan = a_cpu.vulkan(); + + const float b_scalar = 3.1415f; + + a_cpu.add_(b_scalar, 2.1f); + a_vulkan.add_(b_scalar, 2.1f); + + const auto check = almostEqual(a_cpu, a_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << a_cpu << std::endl; + std::cout << "Got:\n" << a_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, addmm) { + if (!at::is_vulkan_available()) { + return; + } + + constexpr float alpha = 2.1f; + constexpr float beta = 103.24; + + const auto bias_cpu = at::rand({179, 163}, at::device(at::kCPU).dtype(at::kFloat)); + const auto m1_cpu = at::rand({179, 67}, at::device(at::kCPU).dtype(at::kFloat)); + const auto m2_cpu = at::rand({67, 163}, at::device(at::kCPU).dtype(at::kFloat)); + const auto out_cpu = at::addmm(bias_cpu, m1_cpu, m2_cpu, beta, alpha); + + const auto bias_vulkan = bias_cpu.vulkan(); + const auto m1_vulkan = m1_cpu.vulkan(); + const auto m2_vulkan = m2_cpu.vulkan(); + const auto out_vulkan = at::addmm(bias_vulkan, m1_vulkan, m2_vulkan, beta, alpha); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << out_cpu << std::endl; + std::cout << "Got:\n" << out_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, addmm_expand) { + if (!at::is_vulkan_available()) { + return; + } + + constexpr float alpha = 2.1f; + constexpr float beta = 103.24; + + const auto bias_cpu = at::rand({1000}, at::device(at::kCPU).dtype(at::kFloat)); + const auto m1_cpu = at::rand({1, 1280}, at::device(at::kCPU).dtype(at::kFloat)); + const auto m2_cpu = at::rand({1280, 1000}, at::device(at::kCPU).dtype(at::kFloat)); + const auto out_cpu = at::addmm(bias_cpu, m1_cpu, m2_cpu, beta, alpha); + + const auto bias_vulkan = bias_cpu.vulkan(); + const auto m1_vulkan = m1_cpu.vulkan(); + const auto m2_vulkan = m2_cpu.vulkan(); + const auto out_vulkan = at::addmm(bias_vulkan, m1_vulkan, m2_vulkan, beta, alpha); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << out_cpu << std::endl; + std::cout << "Got:\n" << out_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, avg_pool2d) { + if (!at::is_vulkan_available()) { + return; + } + + const auto in_cpu = at::rand({3, 19, 43, 79}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + const auto out_cpu = at::avg_pool2d(in_cpu, {5, 3}, {1, 2}, {2, 0}, true); + const auto out_vulkan = at::avg_pool2d(in_cpu.vulkan(), {5, 3}, {1, 2}, {2, 0}, true); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << out_cpu << std::endl; + std::cout << "Got:\n" << out_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, clamp) { + if (!at::is_vulkan_available()) { + return; + } + + const auto in_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_vulkan = in_cpu.vulkan(); + + const float min_value = 0.2f; + const float max_value = 0.8f; + + const auto out_cpu = at::clamp(in_cpu, min_value, max_value); + const auto out_vulkan = at::clamp(in_vulkan, min_value, max_value); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << out_cpu << std::endl; + std::cout << "Got:\n" << out_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, clamp_) { + if (!at::is_vulkan_available()) { + return; + } + + const auto cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); + const auto vulkan = cpu.vulkan(); + + const float min_value = 0.2f; + const float max_value = 0.8f; + + cpu.clamp_(min_value, max_value); + vulkan.clamp_(min_value, max_value); + + const auto check = almostEqual(cpu, vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << cpu << std::endl; + std::cout << "Got:\n" << vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, conv2d) { + if (!at::is_vulkan_available()) { + return; + } + + constexpr int64_t groups = 1; + constexpr std::array stride{1, 2}; + constexpr std::array padding{3, 0}; + //TODO: Support conv2d with dilation != 1 + constexpr std::array dilation{1, 1}; + + constexpr struct { + uint32_t batches; + uint32_t channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + batches, + channels, + width, + height, + }; + } + } input {1, 37, 223, 227}; + + constexpr struct { + uint32_t output_channels; + uint32_t input_channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + output_channels, + input_channels, + width, + height, + }; + } + } weights {83, input.channels, 13, 2}; + + const auto input_cpu = at::randn(input.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto weights_cpu = at::randn(weights.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_cpu = at::randn({weights.output_channels}, at::device(at::kCPU).dtype(at::kFloat)); + + const auto output_cpu = at::conv2d( + input_cpu, + weights_cpu, + bias_cpu, + stride, + padding, + dilation, + groups); + + const auto output_vulkan = at::conv2d( + input_cpu.vulkan(), + weights_cpu, + bias_cpu, + stride, + padding, + dilation, + groups); + + const bool check = almostEqual(output_cpu, output_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << output_cpu << std::endl; + std::cout << "Got:\n" << output_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, conv2d_dw) { + if (!at::is_vulkan_available()) { + return; + } + + constexpr int64_t groups = 7; + constexpr std::array stride{2, 3}; + constexpr std::array padding{0, 4}; + constexpr std::array dilation{3, 1}; + + constexpr struct { + uint32_t batches; + uint32_t channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + batches, + channels, + width, + height, + }; + } + } input {1, groups, 137, 199}; + + constexpr struct { + uint32_t output_channels; + uint32_t input_channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + output_channels, + input_channels, + width, + height, + }; + } + } weights {groups, 1, 17, 7}; + + const auto input_cpu = at::rand(input.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto weights_cpu = at::rand(weights.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_cpu = at::rand({weights.output_channels}, at::device(at::kCPU).dtype(at::kFloat)); + + const auto output_cpu = at::conv2d( + input_cpu, + weights_cpu, + bias_cpu, + stride, + padding, + dilation, + groups); + + const auto output_vulkan = at::conv2d( + input_cpu.vulkan(), + weights_cpu, + bias_cpu, + stride, + padding, + dilation, + groups); + + const bool check = almostEqual(output_cpu, output_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << output_cpu << std::endl; + std::cout << "Got:\n" << output_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, conv2d_pw) { + if (!at::is_vulkan_available()) { + return; + } + + constexpr int64_t groups = 1; + constexpr std::array stride{1, 1}; + constexpr std::array padding{0, 0}; + constexpr std::array dilation{1, 1}; + + constexpr struct { + uint32_t batches; + uint32_t channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + batches, + channels, + width, + height, + }; + } + } input {1, 17, 127, 397}; + + constexpr struct { + uint32_t output_channels; + uint32_t input_channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + output_channels, + input_channels, + width, + height, + }; + } + } weights {29, input.channels, 1, 1}; + + const auto input_cpu = at::randn(input.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto weights_cpu = at::randn(weights.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_cpu = at::randn({weights.output_channels}, at::device(at::kCPU).dtype(at::kFloat)); + + const auto output_cpu = at::conv2d( + input_cpu, + weights_cpu, + bias_cpu, + stride, + padding, + dilation, + groups); + + const auto output_vulkan = at::conv2d( + input_cpu.vulkan(), + weights_cpu, + bias_cpu, + stride, + padding, + dilation, + groups); + + const bool check = almostEqual(output_cpu, output_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << output_cpu << std::endl; + std::cout << "Got:\n" << output_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, copy) { + if (!at::is_vulkan_available()) { + return; + } + + const auto cpu = at::rand({13, 17, 37, 19}, at::device(at::kCPU).dtype(at::kFloat)); + const auto vulkan = cpu.vulkan(); + + const auto check = exactlyEqual(cpu, vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << cpu << std::endl; + std::cout << "Got:\n" << vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, empty) { + if (!at::is_vulkan_available()) { + return; + } + + ASSERT_NO_THROW(at::empty({1, 17, 41, 53}, at::device(at::kVulkan).dtype(at::kFloat))); +} + +TEST(VulkanAPITest, mean) { + const auto in_cpu = at::rand({17, 3, 79, 53}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + const auto out_cpu = at::mean(in_cpu, {-1, -2}, true); + + const auto in_vulkan = in_cpu.vulkan(); + const auto out_vulkan = at::mean(in_vulkan, {-1, -2}, true); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << out_cpu << std::endl; + std::cout << "Got:\n" << out_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, mean2d) { + const auto in_cpu = at::rand({11, 7, 173, 37}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + const auto out_cpu = at::mean(in_cpu, {-1, -2}, false); + + const auto in_vulkan = in_cpu.vulkan(); + const auto out_vulkan = at::mean(in_vulkan, {-1, -2}, false); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << out_cpu << std::endl; + std::cout << "Got:\n" << out_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, mm) { + if (!at::is_vulkan_available()) { + return; + } + + const auto m1_cpu = at::rand({241, 313}, at::device(at::kCPU).dtype(at::kFloat)); + const auto m2_cpu = at::rand({313, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto out_cpu = m1_cpu.mm(m2_cpu); + + const auto m1_vulkan = m1_cpu.vulkan(); + const auto m2_vulkan = m2_cpu.vulkan(); + const auto out_vulkan = m1_vulkan.mm(m2_vulkan); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << out_cpu << std::endl; + std::cout << "Got:\n" << out_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, mul_scalar) { + if (!at::is_vulkan_available()) { + return; + } + + const auto a_cpu = at::rand({17, 213, 213, 7}, at::device(at::kCPU).dtype(at::kFloat)); + const auto a_vulkan = a_cpu.vulkan(); + + const float b_scalar = 3.1415f; + + const auto c_cpu = at::mul(a_cpu, b_scalar); + const auto c_vulkan = at::mul(a_vulkan, b_scalar); + + const auto check = almostEqual(c_cpu, c_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << c_cpu << std::endl; + std::cout << "Got:\n" << c_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, mul_scalar_) { + if (!at::is_vulkan_available()) { + return; + } + + auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + auto a_vulkan = a_cpu.vulkan(); + + const float b_scalar = 3.1415f; + + a_cpu.mul_(b_scalar); + a_vulkan.mul_(b_scalar); + + const auto check = almostEqual(a_cpu, a_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << a_cpu << std::endl; + std::cout << "Got:\n" << a_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, reshape) { + if (!at::is_vulkan_available()) { + return; + } + + const auto in_cpu = at::rand({47, 11, 83, 97}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_vulkan = in_cpu.vulkan(); + + const std::array shape{47 * 83, 11 * 97}; + + const auto out_cpu = at::reshape(in_cpu, shape); + const auto out_vulkan = at::reshape(in_vulkan, shape); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << out_cpu << std::endl; + std::cout << "Got:\n" << out_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, reshape_) { + if (!at::is_vulkan_available()) { + return; + } + + const auto cpu = at::rand({59, 41, 19, 67}, at::device(at::kCPU).dtype(at::kFloat)); + const auto vulkan = cpu.vulkan(); + + const std::array shape{59, 41 * 67, 19}; + + cpu.reshape(shape); + vulkan.reshape(shape); + + const auto check = almostEqual(cpu, vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << cpu << std::endl; + std::cout << "Got:\n" << vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, upsample_nearest2d) { + if (!at::is_vulkan_available()) { + return; + } + + const auto in_cpu = at::rand({1, 2, 2, 3}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + const auto out_cpu = at::upsample_nearest2d(in_cpu, {4, 6}); + + const auto in_vulkan = in_cpu.vulkan(); + const auto out_vulkan = at::upsample_nearest2d(in_vulkan, {4, 6}); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + std::cout << "Expected:\n" << out_cpu << std::endl; + std::cout << "Got:\n" << out_vulkan.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + +enum class OpType { + addmm, + conv2d, + hardtanh_, + mean, + }; + +class BaseOp { + public: + explicit BaseOp(const OpType type) : type_(type) {} + virtual ~BaseOp() = default; + + virtual at::Tensor run(at::Tensor&) const = 0; + virtual std::string toString() const = 0; + + private: + OpType type_; +}; + +class Addmm final : public BaseOp { + public: + Addmm( + const int64_t m1H, + const int64_t m1W, + const int64_t m2W, + const float beta, + const float alpha) + : BaseOp(OpType::addmm), + m2_(at::rand(c10::IntArrayRef({m1W, m2W}), at::device(at::kCPU).dtype(at::kFloat))), + v_m2(m2_.vulkan()), + b_(at::rand(c10::IntArrayRef({m1H, m2W}), at::device(at::kCPU).dtype(at::kFloat))), + v_b_(b_.vulkan()), + beta_(beta), + alpha_(alpha) { + } + + at::Tensor run(at::Tensor& t) const override { + if (t.is_vulkan()) { + return at::addmm(v_b_, t, v_m2, beta_, alpha_); + } + + return at::addmm(b_, t, m2_, beta_, alpha_); + } + + std::string toString() const override { + return "addmm"; + } + + private: + at::Tensor m2_; + at::Tensor v_m2; + at::Tensor b_; + at::Tensor v_b_; + float beta_; + float alpha_; +}; + +class Conv2d final : public BaseOp { + public: + Conv2d( + const c10::IntArrayRef wsizes, + const int64_t groups, + const int64_t stride, + const int64_t padding) + : BaseOp(OpType::conv2d), + groups_(groups), + stride_(stride), + padding_(padding), + w_(at::rand(wsizes, at::device(at::kCPU).dtype(at::kFloat))), + b_(at::rand(wsizes[0], at::device(at::kCPU).dtype(at::kFloat))){ + } + + at::Tensor run(at::Tensor& t) const override { + return at::conv2d(t, w_, b_, {stride_}, {padding_}, {1}, groups_); + } + + std::string toString() const override { + return "conv2d"; + } + + private: + int64_t groups_; + int64_t stride_; + int64_t padding_; + at::Tensor w_; + at::Tensor b_; +}; + +class Hardtanh_ final : public BaseOp { + public: + Hardtanh_() : BaseOp(OpType::hardtanh_) {} + + at::Tensor run(at::Tensor& input) const override { + return at::hardtanh_(input, 0, 6); + } + + std::string toString() const override { + return "hardtanh_"; + } +}; + +class Mean final : public BaseOp { + public: + Mean() : BaseOp(OpType::mean) {} + + at::Tensor run(at::Tensor& input) const override { + return at::mean(input, {2, 3}, false); + } + + std::string toString() const override { + return "mean"; + } +}; + +class OpsList { + public: + OpsList() {} + explicit OpsList(std::vector> ops) + : ops_(std::move(ops)) { + } + + auto run(const at::Tensor& input) { + at::Tensor output = input; + + for (const auto& op : ops_) { + output = op->run(output); + } + + return output; + } + + auto run(const at::Tensor& input, const at::Tensor& v_input) { + at::Tensor output = input; + at::Tensor v_output = v_input; + + for (const auto& op : ops_) { + output = op->run(output); + v_output = op->run(v_output); + } + + return std::make_pair(output, v_output); + } + + protected: + std::vector> ops_; +}; + +class MobileNetV2 final : public OpsList { + public: + MobileNetV2() { + ops_.emplace_back(new Conv2d({32, 3, 3, 3}, 1, 2, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({32, 1, 3, 3}, 32, 1, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({16, 32, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Conv2d({96, 16, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({96, 1, 3, 3}, 96, 2, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({24, 96, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Conv2d({144, 24, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({144, 1, 3, 3}, 144, 1, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({24, 144, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Conv2d({144, 24, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({144, 1, 3, 3}, 144, 2, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({32, 144, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Conv2d({192, 32, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({192, 1, 3, 3}, 192, 1, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({32, 192, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Conv2d({192, 32, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({192, 1, 3, 3}, 192, 1, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({32, 192, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Conv2d({192, 32, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({192, 1, 3, 3}, 192, 2, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({64, 192, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Conv2d({384, 64, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({384, 1, 3, 3}, 384, 1, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({64, 384, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Conv2d({384, 64, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({384, 1, 3, 3}, 384, 1, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({64, 384, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Conv2d({384, 64, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({384, 1, 3, 3}, 384, 1, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({64, 384, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Conv2d({384, 64, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({384, 1, 3, 3}, 384, 1, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({96, 384, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Conv2d({576, 96, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({576, 1, 3, 3}, 576, 1, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({96, 576, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Conv2d({576, 96, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({576, 1, 3, 3}, 576, 1, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({96, 576, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Conv2d({576, 96, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({576, 1, 3, 3}, 576, 2, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({160, 576, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Conv2d({960, 160, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({960, 1, 3, 3}, 960, 1, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({160, 960, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Conv2d({960, 160, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({960, 1, 3, 3}, 960, 1, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({160, 960, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Conv2d({960, 160, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({960, 1, 3, 3}, 960, 1, 1)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Conv2d({320, 960, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Conv2d({1280, 320, 1, 1}, 1, 1, 0)); + ops_.emplace_back(new Hardtanh_()); + ops_.emplace_back(new Mean()); + ops_.emplace_back(new Addmm(1, 1280, 1000, 0, 1)); + } +}; + +TEST(VulkanAPITest, mobilenetv2) { + if (!at::is_vulkan_available()) { + return; + } + + MobileNetV2 mn2; + + const auto input = at::rand({1, 3, 224, 224}, at::device(at::kCPU).dtype(at::kFloat)); + const auto output = mn2.run(input, input.vulkan()); + + const auto check = almostEqual(output.first, output.second.cpu()); + if (!check) { + std::cout << "Expected:\n" << output.first << std::endl; + std::cout << "Got:\n" << output.second.cpu() << std::endl; + } + + ASSERT_TRUE(check); +} + } // namespace #endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/test/vulkan_test.cpp b/aten/src/ATen/test/vulkan_test.cpp index c8d1b72cc06bf..6b066b4337be9 100644 --- a/aten/src/ATen/test/vulkan_test.cpp +++ b/aten/src/ATen/test/vulkan_test.cpp @@ -1,3 +1,5 @@ +#ifndef USE_VULKAN_API + #include #include @@ -45,7 +47,12 @@ TEST(VulkanTest, upsampleNearest2D) { auto t_out = tv_out.to(at::TensorOptions{at::Device{at::kCPU}}.dtype(at::kFloat)); - ASSERT_TRUE(almostEqual(t_out, t_out_expected)); + bool check = almostEqual(t_out_expected, t_out); + if (!check) { + std::cout << "expected:\n" << t_out_expected << std::endl; + std::cout << "got:\n" << t_out << std::endl; + } + ASSERT_TRUE(check); } TEST(VulkanTest, add) { @@ -208,7 +215,12 @@ TEST(VulkanTest, conv2dDWWeightsOnCPU) { auto tv_in = t_in.vulkan(); auto tv_out = at::conv2d(tv_in, t_w, t_b, stride, padding, dilation, groups); auto t_out = tv_out.cpu(); - ASSERT_TRUE(almostEqual(t_out, t_out_expected)); + bool check = almostEqual(t_out_expected, t_out); + if (!check) { + std::cout << "expected:\n" << t_out_expected << std::endl; + std::cout << "got:\n" << t_out << std::endl; + } + ASSERT_TRUE(check); } TEST(VulkanTest, addmm) { @@ -227,7 +239,12 @@ TEST(VulkanTest, addmm) { auto tv_b = t_b.vulkan(); auto tv_out = at::addmm(tv_b, tv_m1, tv_m2, beta, alpha); auto t_out = tv_out.cpu(); - ASSERT_TRUE(almostEqual(t_out, t_out_expected)); + bool check = almostEqual(t_out_expected, t_out); + if (!check) { + std::cout << "expected:\n" << t_out_expected << std::endl; + std::cout << "got:\n" << t_out << std::endl; + } + ASSERT_TRUE(check); } TEST(VulkanTest, mm) { @@ -242,7 +259,12 @@ TEST(VulkanTest, mm) { auto tv_m2 = t_m2.vulkan(); auto tv_out = tv_m1.mm(tv_m2); auto t_out = tv_out.cpu(); - ASSERT_TRUE(almostEqual(t_out, t_out_expected)); + bool check = almostEqual(t_out_expected, t_out); + if (!check) { + std::cout << "expected:\n" << t_out_expected << std::endl; + std::cout << "got:\n" << t_out << std::endl; + } + ASSERT_TRUE(check); } TEST(VulkanTest, clamp) { @@ -301,7 +323,12 @@ TEST(VulkanTest, mean) { auto tv_in = t_in.vulkan(); auto tv_out = at::mean(tv_in, {2, 3}, false); auto t_out = tv_out.cpu(); - ASSERT_TRUE(almostEqual(t_out, t_out_expected)); + bool check = almostEqual(t_out_expected, t_out); + if (!check) { + std::cout << "expected:\n" << t_out_expected << std::endl; + std::cout << "got:\n" << t_out << std::endl; + } + ASSERT_TRUE(check); } enum class OpType { conv2d, hardtanh_, mean, addmm }; @@ -874,7 +901,7 @@ TEST(VulkanTest, cat) { ASSERT_TRUE(check); } -TEST(VulkanTest, max_pool2d) { +TEST(VulkanTest, DISABLED_max_pool2d) { if (!at::is_vulkan_available()) return; @@ -900,10 +927,10 @@ TEST(VulkanTest, avg_pool2d) { auto t_in = at::rand({1, 3, 7, 7}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto t_out_expected = at::avg_pool2d(t_in, {2, 2}, {1}, {0}, {1}); + auto t_out_expected = at::avg_pool2d(t_in, {2, 2}, {1}, {0}, true); auto tv_in = t_in.vulkan(); - auto tv_out = at::avg_pool2d(tv_in, {2, 2}, {1}, {0}, {1}); + auto tv_out = at::avg_pool2d(tv_in, {2, 2}, {1}, {0}, true); auto t_out = tv_out.cpu(); const auto check = almostEqual(t_out, t_out_expected); @@ -913,3 +940,5 @@ TEST(VulkanTest, avg_pool2d) { } ASSERT_TRUE(check); } + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/vulkan/Context.cpp b/aten/src/ATen/vulkan/Context.cpp index 8d2b5281d2aee..793c690a0c141 100644 --- a/aten/src/ATen/vulkan/Context.cpp +++ b/aten/src/ATen/vulkan/Context.cpp @@ -3,6 +3,10 @@ #include #include +#ifdef USE_VULKAN_API +#include +#endif /* USE_VULKAN_API */ + namespace at { namespace vulkan { @@ -23,8 +27,12 @@ at::Tensor& vulkan_copy_(at::Tensor& self, const at::Tensor& src) { namespace native { bool is_vulkan_available() { +#ifdef USE_VULKAN_API + return native::vulkan::api::available(); +#else auto p = at::vulkan::g_vulkan_impl_registry.load(); return p ? p->is_vulkan_available() : false; +#endif } } // namespace native diff --git a/aten/src/TH/CMakeLists.txt b/aten/src/TH/CMakeLists.txt index 6a491991a090b..5661a697da38e 100644 --- a/aten/src/TH/CMakeLists.txt +++ b/aten/src/TH/CMakeLists.txt @@ -65,6 +65,7 @@ install(FILES THGenerateComplexTypes.h THGenerateIntTypes.h THGenerateQUInt8Type.h + THGenerateQUInt4x2Type.h THGenerateQInt8Type.h THGenerateQInt32Type.h THGenerateQTypes.h @@ -78,7 +79,6 @@ install(FILES THHalf.h THTensor.hpp THStorageFunctions.hpp - THGenerator.hpp DESTINATION "${ATEN_INSTALL_INCLUDE_SUBDIR}/TH") install(FILES diff --git a/aten/src/TH/THAllocator.cpp b/aten/src/TH/THAllocator.cpp index 6cdc62ab6da60..53b67a17032f9 100644 --- a/aten/src/TH/THAllocator.cpp +++ b/aten/src/TH/THAllocator.cpp @@ -6,10 +6,11 @@ #endif #include +#include /* stuff for mapped files */ #ifdef _WIN32 -#include +#include #endif #if defined(HAVE_MMAP) @@ -74,24 +75,26 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags, #ifdef _WIN32 if (flags_ & TH_ALLOCATOR_MAPPED_SHAREDMEM) { // Shadowing - const char *filename; - const char *eventname; + const wchar_t *filename; + const wchar_t *eventname; + const std::wstring wFilename = c10::u8u16(filename_); + const std::wstring wEventname = c10::u8u16(eventname_); LARGE_INTEGER hfilesz; if (filename_[0] == '/') { - filename = filename_.c_str() + 1; - eventname = eventname_.c_str() + 1; + filename = wFilename.c_str() + 1; + eventname = wEventname.c_str() + 1; } else { - filename = filename_.c_str(); - eventname = eventname_.c_str(); + filename = wFilename.c_str(); + eventname = wEventname.c_str(); } hfilesz.QuadPart = size; if (flags_ & TH_ALLOCATOR_MAPPED_EXCLUSIVE) { - event_ = CreateEvent(nullptr, FALSE, FALSE, eventname); + event_ = CreateEventW(nullptr, FALSE, FALSE, eventname); } else if (flags_ & TH_ALLOCATOR_MAPPED_NOCREATE) { - event_ = OpenEvent(EVENT_ALL_ACCESS, FALSE, eventname); + event_ = OpenEventW(EVENT_ALL_ACCESS, FALSE, eventname); } else { AT_ERROR("Expected either TH_ALLOCATOR_MAPPED_EXCLUSIVE or TH_ALLOCATOR_MAPPED_NOCREATE"); } @@ -101,9 +104,9 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags, } if (flags_ & TH_ALLOCATOR_MAPPED_EXCLUSIVE) { - handle_ = CreateFileMapping(INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, filename); + handle_ = CreateFileMappingW(INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, filename); } else if (flags_ & TH_ALLOCATOR_MAPPED_NOCREATE) { - handle_ = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, filename); + handle_ = OpenFileMappingW(FILE_MAP_ALL_ACCESS, FALSE, filename); } else { AT_ERROR("Expected either TH_ALLOCATOR_MAPPED_EXCLUSIVE or TH_ALLOCATOR_MAPPED_NOCREATE"); } @@ -136,15 +139,21 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags, AT_ERROR("TH_ALLOCATOR_MAPPED_FROMFD not supported on Windows"); } + // Shadowing + const wchar_t *filename; + const std::wstring wFilename = c10::u8u16(filename_); + + filename = wFilename.c_str(); + /* open file */ /* FILE_FLAG_RANDOM_ACCESS ? */ if (flags_) { - hfile = CreateFileA(filename_.c_str(), GENERIC_READ|GENERIC_WRITE, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, 0); + hfile = CreateFileW(filename, GENERIC_READ|GENERIC_WRITE, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, 0); if (hfile == INVALID_HANDLE_VALUE) { AT_ERROR("could not open file <", filename_, "> in read-write mode; error code: <", GetLastError(), ">"); } } else { - hfile = CreateFileA(filename_.c_str(), GENERIC_READ, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); + hfile = CreateFileW(filename, GENERIC_READ, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); if (hfile == INVALID_HANDLE_VALUE) { AT_ERROR("could not open file <", filename_, "> in read-only mode; error code: <", GetLastError(), ">"); } @@ -181,11 +190,11 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags, /* get map handle */ if (flags_) { - if ( (hmfile = CreateFileMapping(hfile, NULL, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) { + if ( (hmfile = CreateFileMappingW(hfile, NULL, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) { AT_ERROR("could not create a map on file <", filename_, ">; error code: <", GetLastError(), ">"); } } else { - if ( (hmfile = CreateFileMapping(hfile, NULL, PAGE_WRITECOPY, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) { + if ( (hmfile = CreateFileMappingW(hfile, NULL, PAGE_WRITECOPY, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) { AT_ERROR("could not create a map on file <", filename_, ">; error code: <", GetLastError(), ">"); } } @@ -288,6 +297,7 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags, if (base_ptr_ == MAP_FAILED) { base_ptr_ = nullptr; /* let's be sure it is NULL */ + AT_ERROR("unable to mmap ", size_, " bytes from file <", filename_, ">: ", strerror(errno), " (", errno, ")"); } if (flags_ & TH_ALLOCATOR_MAPPED_KEEPFD) { @@ -332,7 +342,7 @@ typedef struct{ HANDLE handle; HANDLE wait; } ReleaseContext; -static VOID CALLBACK WaitForReleaseHandle(PVOID lpParam, BOOLEAN TimerOrWaitFired) +static void CALLBACK WaitForReleaseHandle(PVOID lpParam, BOOLEAN TimerOrWaitFired) { if (lpParam) { ReleaseContext *ctx = (ReleaseContext *)lpParam; diff --git a/aten/src/TH/THAllocator.h b/aten/src/TH/THAllocator.h index d189bd1b91513..4a4e385281e75 100644 --- a/aten/src/TH/THAllocator.h +++ b/aten/src/TH/THAllocator.h @@ -21,7 +21,7 @@ TH_API c10::Allocator* getTHDefaultAllocator(void); // the non-file descriptor constructor enum WithFd { WITH_FD }; -class CAFFE2_API THMapAllocator { +class TORCH_API THMapAllocator { public: THMapAllocator(const char *filename, int flags, size_t size); THMapAllocator(WithFd, const char *filename, int fd, int flags, size_t size); @@ -71,11 +71,11 @@ class CAFFE2_API THMapAllocator { }; // Base-from-member idiom -struct CAFFE2_API THRefcountedMapAllocatorArgCheck { +struct TORCH_API THRefcountedMapAllocatorArgCheck { THRefcountedMapAllocatorArgCheck(int flags); }; -class CAFFE2_API THRefcountedMapAllocator +class TORCH_API THRefcountedMapAllocator : private THRefcountedMapAllocatorArgCheck, public THMapAllocator { public: diff --git a/aten/src/TH/THGenerateQTypes.h b/aten/src/TH/THGenerateQTypes.h index ee958b3a3210e..611b990f508f4 100644 --- a/aten/src/TH/THGenerateQTypes.h +++ b/aten/src/TH/THGenerateQTypes.h @@ -10,6 +10,7 @@ #include #include #include +#include #ifdef THQLocalGenerateManyTypes #undef THQLocalGenerateManyTypes diff --git a/aten/src/TH/THGenerateQUInt4x2Type.h b/aten/src/TH/THGenerateQUInt4x2Type.h new file mode 100644 index 0000000000000..4ecea45143593 --- /dev/null +++ b/aten/src/TH/THGenerateQUInt4x2Type.h @@ -0,0 +1,24 @@ +#ifndef TH_GENERIC_FILE +#error "You must define TH_GENERIC_FILE before including THGenerateQUInt4x2Type.h" +#endif + +#define quantized_t c10::quint4x2 +#define scalar_t uint8_t +#define Real QUInt4x2 +#define RealUnderlying Byte +#define THQUANTIZED +#define THQUINT8 +#define TH_REAL_IS_BYTE +#line 1 TH_GENERIC_FILE +#include TH_GENERIC_FILE +#undef scalar_t +#undef quantized_t +#undef Real +#undef RealUnderlying +#undef TH_REAL_IS_BYTE +#undef THQUINT8 +#undef THQUANTIZED + +#ifndef THGenerateManyTypes +#undef TH_GENERIC_FILE +#endif diff --git a/aten/src/TH/THGenerator.hpp b/aten/src/TH/THGenerator.hpp deleted file mode 100644 index 1a40611f8b5b8..0000000000000 --- a/aten/src/TH/THGenerator.hpp +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -#include - -/** - * THGeneratorState is a POD class needed for memcpys - * in torch.get_rng_state() and torch.set_rng_state(). - * It is a legacy class and even though it is replaced with - * at::CPUGeneratorImpl, we need this class and some of its fields - * to support backward compatibility on loading checkpoints. - */ -struct THGeneratorState { - /* The initial seed. */ - uint64_t the_initial_seed; - int left; /* = 1; */ - int seeded; /* = 0; */ - uint64_t next; - uint64_t state[at::MERSENNE_STATE_N]; /* the array for the state vector */ - - /********************************/ - - /* For normal distribution */ - double normal_x; - double normal_y; - double normal_rho; - int normal_is_valid; /* = 0; */ -}; - -/** - * THGeneratorStateNew is a POD class containing - * new data introduced in at::CPUGeneratorImpl and the legacy state. It is used - * as a helper for torch.get_rng_state() and torch.set_rng_state() - * functions. - */ -struct THGeneratorStateNew { - THGeneratorState legacy_pod; - float next_float_normal_sample; - bool is_next_float_normal_sample_valid; -}; diff --git a/aten/src/TH/THStorageFunctions.hpp b/aten/src/TH/THStorageFunctions.hpp index b78f8c7a3035a..8d5c28daa796b 100644 --- a/aten/src/TH/THStorageFunctions.hpp +++ b/aten/src/TH/THStorageFunctions.hpp @@ -8,6 +8,7 @@ #include #include +#include // Note [Weak references for intrusive refcounting] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/aten/src/TH/generic/THLapack.cpp b/aten/src/TH/generic/THLapack.cpp index 6a776f4d0a174..5b4ef15e7c2c6 100644 --- a/aten/src/TH/generic/THLapack.cpp +++ b/aten/src/TH/generic/THLapack.cpp @@ -2,11 +2,8 @@ #define TH_GENERIC_FILE "TH/generic/THLapack.cpp" #else - TH_EXTERNC void dgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info); TH_EXTERNC void sgels_(char *trans, int *m, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, float *work, int *lwork, int *info); -TH_EXTERNC void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info); -TH_EXTERNC void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info); TH_EXTERNC void dpotri_(char *uplo, int *n, double *a, int *lda, int *info); TH_EXTERNC void spotri_(char *uplo, int *n, float *a, int *lda, int *info); TH_EXTERNC void sgeqrf_(int *m, int *n, float *a, int *lda, float *tau, float *work, int *lwork, int *info); @@ -32,21 +29,6 @@ void THLapack_(gels)(char trans, int m, int n, int nrhs, scalar_t *a, int lda, s #endif } -/* Compute for an N-by-N real nonsymmetric matrix A, the eigenvalues and, -optionally, the left and/or right eigenvectors */ -void THLapack_(geev)(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *wr, scalar_t *wi, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info) -{ -#ifdef USE_LAPACK -#if defined(TH_REAL_IS_DOUBLE) - dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info); -#else - sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info); -#endif -#else - THError("geev : Lapack library not found in compile time\n"); -#endif -} - /* Cholesky factorization based Matrix Inverse */ void THLapack_(potri)(char uplo, int n, scalar_t *a, int lda, int *info) { diff --git a/aten/src/TH/generic/THLapack.h b/aten/src/TH/generic/THLapack.h index 287915c74d261..121eee871c67b 100644 --- a/aten/src/TH/generic/THLapack.h +++ b/aten/src/TH/generic/THLapack.h @@ -4,8 +4,6 @@ /* ||AX-B|| */ TH_API void THLapack_(gels)(char trans, int m, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, scalar_t *work, int lwork, int *info); -/* Non-sym eigenvals */ -TH_API void THLapack_(geev)(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *wr, scalar_t *wi, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info); /* Positive Definite matrices */ /* Matrix inverse based on Cholesky factorization */ diff --git a/aten/src/TH/generic/THStorage.cpp b/aten/src/TH/generic/THStorage.cpp index 2db795719557a..a085f31c740f3 100644 --- a/aten/src/TH/generic/THStorage.cpp +++ b/aten/src/TH/generic/THStorage.cpp @@ -115,10 +115,9 @@ void THStorage_(resizeBytes)(THStorage* storage, ptrdiff_t size_bytes) { void THStorage_(fill)(THStorage *storage, scalar_t value) { - ptrdiff_t i; auto type_meta = caffe2::TypeMeta::Make(); size_t numel = storage->nbytes() / type_meta.itemsize(); - for (i = 0; i < numel; i++) + for (size_t i = 0; i < numel; i++) THStorage_(data)(storage)[i] = value; } diff --git a/aten/src/TH/generic/THStorage.h b/aten/src/TH/generic/THStorage.h index cd419c695ba50..a41991c469c7a 100644 --- a/aten/src/TH/generic/THStorage.h +++ b/aten/src/TH/generic/THStorage.h @@ -38,6 +38,7 @@ #define THQUInt8Storage THStorage #define THQInt8Storage THStorage #define THQInt32Storage THStorage +#define THQUInt4x2Storage THStorage #define THComplexFloatStorage THStorage #define THComplexDoubleStorage THStorage diff --git a/aten/src/TH/generic/THStorageCopy.cpp b/aten/src/TH/generic/THStorageCopy.cpp index dc19deea7652c..2d6ec8a05eb66 100644 --- a/aten/src/TH/generic/THStorageCopy.cpp +++ b/aten/src/TH/generic/THStorageCopy.cpp @@ -8,7 +8,7 @@ void THStorage_(copy)(THStorage *storage, THStorage *src) scalar_t *scalar_src = THStorage_(data)(src); scalar_t *data = THStorage_(data)(storage); uint64_t numel = storage->nbytes() / sizeof(scalar_t); - for (ptrdiff_t i = 0; i < numel; ++i) { + for (uint64_t i = 0; i < numel; ++i) { data[i] = scalar_src[i]; } } @@ -19,11 +19,10 @@ void THStorage_(copy)(THStorage *storage, THStorage *src) #define IMPLEMENT_THStorage_COPY(TYPENAMESRC) \ void THStorage_(copy##TYPENAMESRC)( \ THStorage * storage, TH##TYPENAMESRC##Storage * src) { \ - ptrdiff_t i; \ auto data = THStorage_(data)(storage); \ auto src_data = TH##TYPENAMESRC##Storage_data(src); \ uint64_t numel = storage->nbytes() / sizeof(scalar_t); \ - for (i = 0; i < numel; i++) \ + for (uint64_t i = 0; i < numel; i++) \ data[i] = static_cast(src_data[i]); \ } diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp index 764220c24673a..9c1eb3cdfe224 100644 --- a/aten/src/TH/generic/THTensorEvenMoreMath.cpp +++ b/aten/src/TH/generic/THTensorEvenMoreMath.cpp @@ -5,6 +5,7 @@ #include #include #include +#include // Finds non-zero elements of a tensor and returns their subscripts void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor) @@ -216,50 +217,6 @@ static inline int64_t THTensor_(wrapLinearIndex)(int64_t linearIndex, int64_t nu return linearIndex < 0 ? linearIndex + numel : linearIndex; } -void THTensor_(take)(THTensor *r_, THTensor *src, THLongTensor *index) -{ - THTensor_(resizeNd)(r_, index->dim(), THTensor_getSizePtr(index), NULL); - THTensor* dst = THTensor_(newContiguous)(r_); - - index = THLongTensor_newContiguous(index); - int64_t* index_data = THLongTensor_data(index); - ptrdiff_t srcElements = THTensor_(nElement)(src); - scalar_t* src_data = src->data(); - scalar_t* dst_data = dst->data(); - ptrdiff_t nIndices = THLongTensor_nElement(index); - int isContiguous = THTensor_(isContiguous)(src); - - // Exceptions must not be thrown across parallel sections, so we - // record the position of the invalid index and throw the exception after the - // loop. - std::atomic invalidIdxPos(-1); - - at::parallel_for(0, nIndices, TH_OMP_OVERHEAD_THRESHOLD, - [&](int64_t start, int64_t end) { - for (auto i = start; i < end; i++) { - int64_t idx = index_data[i]; - if (idx < srcElements && idx >= -srcElements) { - idx = THTensor_(wrapLinearIndex)(idx, srcElements); - if (isContiguous) { - dst_data[i] = src_data[idx]; - } else { - dst_data[i] = src_data[THTensor_(dataOffset)(src, idx)]; - } - } else { - int64_t tmp = -1; - invalidIdxPos.compare_exchange_strong(tmp, i); - } - } - }); - - if (invalidIdxPos >= 0) { - THTensor_(checkLinearIndex)(index_data[invalidIdxPos], srcElements); - } - - THLongTensor_free(index); - THTensor_(freeCopyTo)(dst, r_); -} - void THTensor_(put)(THTensor *tensor, THLongTensor *index, THTensor *src, int accumulate) { THArgCheck(THLongTensor_nElement(index) == THTensor_(nElement)(src), 3, @@ -298,6 +255,13 @@ void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, scalar numel = THLongTensor_nElement(index); THArgCheck(THTensor_nDimensionLegacyNoScalars(index) == 1, 3, "Index is supposed to be a vector"); THArgCheck(dim < THTensor_nDimensionLegacyNoScalars(tensor), 4,"Indexing dim %d is out of bounds of tensor", dim); + at::assert_no_overlap(tensor, index); + if (at::has_internal_overlap(tensor) == at::MemOverlap::YES) { + TORCH_WARN( + "Use of index_fill_ on expanded tensors is deprecated. " + "Please clone() the tensor before performing this operation. " + "This also applies to advanced indexing e.g. tensor[mask] = scalar"); + } index = THLongTensor_newContiguous(index); index_data = THLongTensor_data(index); diff --git a/aten/src/TH/generic/THTensorLapack.cpp b/aten/src/TH/generic/THTensorLapack.cpp index 2494e21791e44..e6c2001691910 100644 --- a/aten/src/TH/generic/THTensorLapack.cpp +++ b/aten/src/TH/generic/THTensorLapack.cpp @@ -191,88 +191,6 @@ void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a) if (free_b) c10::raw::intrusive_ptr::decref(b); } -void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, bool eigenvectors) -{ - char jobvr = eigenvectors ? 'V' : 'N'; - int n, lda, lwork, info, ldvr; - THTensor *work=nullptr, *wi, *wr, *a; - scalar_t wkopt; - scalar_t *rv_data; - int64_t i; - - THTensor *re__ = NULL; - THTensor *rv__ = NULL; - - THArgCheck(a_->dim() == 2, 1, "A should be 2 dimensional"); - THArgCheck(a_->size(0) == a_->size(1), 1,"A should be square"); - THArgCheck(THTensor_(isFinite)(a_), 1, "A should not contain infs or NaNs"); - - /* we want to definitely clone a_ for geev*/ - a = THTensor_(cloneColumnMajor)(NULL, a_); - - n = a->size(0); - lda = n; - - wi = THTensor_(newWithSize1d)(n); - wr = THTensor_(newWithSize1d)(n); - - rv_data = NULL; - ldvr = 1; - if (jobvr == 'V') - { - THTensor_(resize2d)(rv_,n,n); - /* guard against someone passing a correct size, but wrong stride */ - rv__ = THTensor_(newTransposedContiguous)(rv_); - rv_data = rv__->data(); - ldvr = n; - } - THTensor_(resize2d)(re_,n,2); - re__ = THTensor_(newContiguous)(re_); - - if (n > 0) { // lapack doesn't work with size 0 - /* get optimal workspace size */ - THLapack_(geev)('N', jobvr, n, a->data(), lda, wr->data(), wi->data(), - NULL, 1, rv_data, ldvr, &wkopt, -1, &info); - - lwork = (int)wkopt; - work = THTensor_(newWithSize1d)(lwork); - - THLapack_(geev)('N', jobvr, n, a->data(), lda, wr->data(), wi->data(), - NULL, 1, rv_data, ldvr, work->data(), lwork, &info); - - THLapackCheckWithCleanup(" Lapack Error in %s : %d off-diagonal elements of an didn't converge to zero", - THCleanup(c10::raw::intrusive_ptr::decref(re__); - c10::raw::intrusive_ptr::decref(rv__); - c10::raw::intrusive_ptr::decref(a); - c10::raw::intrusive_ptr::decref(wi); - c10::raw::intrusive_ptr::decref(wr); - c10::raw::intrusive_ptr::decref(work);), - "geev", info,""); - } - - { - scalar_t *re_data = re__->data(); - scalar_t *wi_data = wi->data(); - scalar_t *wr_data = wr->data(); - for (i=0; iis_empty(), 1, "'input' should not be empty"); + THArgCheck(!a->is_empty(), 2, "'input' should not be empty"); + THArgCheck(!tau->is_empty(), 3, "'tau' should not be empty"); THTensor *ra__ = NULL; ra__ = THTensor_(cloneColumnMajor)(ra_, a); diff --git a/aten/src/TH/generic/THTensorLapack.h b/aten/src/TH/generic/THTensorLapack.h index 05dbbf9f12ec5..c19df681cd6f1 100644 --- a/aten/src/TH/generic/THTensorLapack.h +++ b/aten/src/TH/generic/THTensorLapack.h @@ -3,7 +3,6 @@ #else TH_API void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b_, THTensor *a_); -TH_API void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, bool eigenvectors); TH_API void THTensor_(potri)(THTensor *ra_, THTensor *a, bool upper); TH_API void THTensor_(geqrf)(THTensor *ra_, THTensor *rtau_, THTensor *a); TH_API void THTensor_(orgqr)(THTensor *ra_, THTensor *a, THTensor *tau); diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h index 1d0daf1206de5..3f56494e5999a 100644 --- a/aten/src/TH/generic/THTensorMath.h +++ b/aten/src/TH/generic/THTensorMath.h @@ -31,7 +31,6 @@ TH_API void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, TH_API void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, int64_t k, int dimension, int keepdim); TH_API void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim); -TH_API accreal THTensor_(trace)(THTensor *t); #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp index 93708556dfb5b..2faeadd76e010 100644 --- a/aten/src/TH/generic/THTensorMoreMath.cpp +++ b/aten/src/TH/generic/THTensorMoreMath.cpp @@ -252,27 +252,6 @@ void THTensor_(preserveReduceDimSemantics)( #if !defined(TH_REAL_IS_BOOL) /* non bool only part */ -accreal THTensor_(trace)(THTensor *t) -{ - scalar_t *t_data = t->data(); - accreal sum = 0; - int64_t i = 0; - int64_t t_stride_0, t_stride_1, t_diag_size; - - THArgCheck(THTensor_(nDimensionLegacyAll)(t) == 2, 1, "expected a matrix"); - - t_stride_0 = THTensor_(stride)(t, 0); - t_stride_1 = THTensor_(stride)(t, 1); - t_diag_size = THMin(THTensor_(size)(t, 0), THTensor_(size)(t, 1)); - while(i < t_diag_size) - { - sum += t_data[i*(t_stride_0+t_stride_1)]; - i++; - } - - return sum; -} - /* Implementation of the Quickselect algorithm, based on Nicolas Devillard's public domain implementation at http://ndevilla.free.fr/median/median/ Adapted similarly to the above Quicksort algorithm. */ diff --git a/aten/src/TH/generic/THTensorRandom.cpp b/aten/src/TH/generic/THTensorRandom.cpp index 399bcc38e1de7..c37b0b9bb7f05 100644 --- a/aten/src/TH/generic/THTensorRandom.cpp +++ b/aten/src/TH/generic/THTensorRandom.cpp @@ -11,7 +11,6 @@ #include #include #include -#include #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) @@ -149,119 +148,4 @@ void THTensor_(multinomialAliasDraw)(THLongTensor *self, THTensor *q, THLongTens } } #endif - -#if defined(TH_REAL_IS_BYTE) -void THTensor_(getRNGState)(at::Generator _generator, THTensor *self) -{ - // See Note [Acquire lock when using random generators] - std::lock_guard lock(_generator.mutex()); - static const size_t size = sizeof(THGeneratorStateNew); - THTensor_(resize1d)(self, size); - THArgCheck(THTensor_(nElement)(self) == size, 1, "RNG state is wrong size"); - THArgCheck(THTensor_(isContiguous)(self), 1, "RNG state needs to be contiguous"); - static_assert(std::is_pod::value, "THGeneratorStateNew is not a PODType"); - - // cast byte tensor to POD type - THGeneratorStateNew* rng_state = (THGeneratorStateNew*)self->data(); - - // accumulate generator data to be copied into byte tensor - auto accum_state = std::make_unique(); - auto cast_generator = at::check_generator(_generator); - auto rng_data = cast_generator->engine().data(); - accum_state->legacy_pod.the_initial_seed = rng_data.seed_; - accum_state->legacy_pod.left = rng_data.left_; - accum_state->legacy_pod.seeded = rng_data.seeded_; - accum_state->legacy_pod.next = rng_data.next_; - std::copy(rng_data.state_.begin(), rng_data.state_.end(), std::begin(accum_state->legacy_pod.state)); - accum_state->legacy_pod.normal_x = 0.0; // we don't use it anymore and this is just a dummy - accum_state->legacy_pod.normal_rho = 0.0; // we don't use it anymore and this is just a dummy - accum_state->legacy_pod.normal_is_valid = false; - accum_state->legacy_pod.normal_y = 0.0; - accum_state->next_float_normal_sample = 0.0f; - accum_state->is_next_float_normal_sample_valid = false; - if(cast_generator->next_double_normal_sample()) { - accum_state->legacy_pod.normal_is_valid = true; - accum_state->legacy_pod.normal_y = *(cast_generator->next_double_normal_sample()); - } - if(cast_generator->next_float_normal_sample()) { - accum_state->is_next_float_normal_sample_valid = true; - accum_state->next_float_normal_sample = *(cast_generator->next_float_normal_sample()); - } - - memcpy(rng_state, accum_state.get(), size); -} - -void THTensor_(setRNGState)(at::Generator _generator, THTensor *self) -{ - // See Note [Acquire lock when using random generators] - std::lock_guard lock(_generator.mutex()); - auto cast_generator = at::check_generator(_generator); - THArgCheck(THTensor_(isContiguous)(self), 1, "RNG state needs to be contiguous"); - static_assert(std::is_pod::value, "THGeneratorState is not a PODType"); - static_assert(std::is_pod::value, "THGeneratorStateNew is not a PODType"); - - static const size_t size_legacy = sizeof(THGeneratorState); - static const size_t size_current = sizeof(THGeneratorStateNew); - static_assert(size_legacy != size_current, "Legacy THGeneratorState and THGeneratorStateNew can't be of the same size"); - - at::mt19937 engine; - auto float_normal_sample = c10::optional(); - auto double_normal_sample = c10::optional(); - - // Construct the state of at::CPUGeneratorImpl based on input byte tensor size. - THGeneratorState* legacy_pod; - if (THTensor_(nElement)(self) == size_legacy) { - legacy_pod = (THGeneratorState*)self->data(); - // Note that in legacy THGeneratorState, we didn't have float version - // of normal sample and hence we leave the c10::optional as is - - // Update next_double_normal_sample. - // Note that legacy THGeneratorState stores two uniform values (normal_x, normal_y) - // and a rho value (normal_rho). These three values were redundant and in the new - // DistributionsHelper.h, we store the actual extra normal sample, rather than three - // intermediate values. - if (legacy_pod->normal_is_valid) { - auto r = legacy_pod->normal_rho; - auto theta = 2.0 * M_PI * legacy_pod->normal_x; - // we return the sin version of the normal sample when in caching mode - double_normal_sample = c10::optional(r * ::sin(theta)); - } - } else if (THTensor_(nElement)(self) == size_current) { - auto rng_state = (THGeneratorStateNew*)self->data(); - legacy_pod = &rng_state->legacy_pod; - // update next_float_normal_sample - if (rng_state->is_next_float_normal_sample_valid) { - float_normal_sample = c10::optional(rng_state->next_float_normal_sample); - } - - // Update next_double_normal_sample. - // Note that in getRNGState, we now return the actual normal sample in normal_y - // and if it's valid in normal_is_valid. The redundant normal_x and normal_rho - // are squashed to 0.0. - if (legacy_pod->normal_is_valid) { - double_normal_sample = c10::optional(legacy_pod->normal_y); - } - } else { - AT_ERROR("Expected either a THGeneratorState of size ", size_legacy, - " or a THGeneratorStateNew of size ", size_current, - " but found the input RNG state size to be ", THTensor_(nElement)(self)); - } - - // construct engine_ - // Note that legacy THGeneratorState stored a state array of 64 bit uints, whereas in our - // redefined mt19937, we have changed to a state array of 32 bit uints. Hence, we are - // doing a std::copy. - at::mt19937_data_pod rng_data; - std::copy(std::begin(legacy_pod->state), std::end(legacy_pod->state), rng_data.state_.begin()); - rng_data.seed_ = legacy_pod->the_initial_seed; - rng_data.left_ = legacy_pod->left; - rng_data.seeded_ = legacy_pod->seeded; - rng_data.next_ = static_cast(legacy_pod->next); - engine.set_data(rng_data); - THArgCheck(engine.is_valid(), 1, "Invalid mt19937 state"); - cast_generator->set_engine(engine); - cast_generator->set_next_float_normal_sample(float_normal_sample); - cast_generator->set_next_double_normal_sample(double_normal_sample); -} -#endif #endif diff --git a/aten/src/TH/generic/THTensorRandom.h b/aten/src/TH/generic/THTensorRandom.h index ffc52bc69390a..ddeb905680cd7 100644 --- a/aten/src/TH/generic/THTensorRandom.h +++ b/aten/src/TH/generic/THTensorRandom.h @@ -9,9 +9,4 @@ TH_API void THTensor_(multinomialAliasSetup)(THTensor *prob_dist, THLongTensor * TH_API void THTensor_(multinomialAliasDraw)(THLongTensor *self, THTensor *q, THLongTensor *J, int n_sample, c10::optional _generator); #endif -#if defined(TH_REAL_IS_BYTE) -TH_API void THTensor_(getRNGState)(at::Generator _generator, THTensor *self); -TH_API void THTensor_(setRNGState)(at::Generator _generator, THTensor *self); -#endif - #endif diff --git a/aten/src/THC/CMakeLists.txt b/aten/src/THC/CMakeLists.txt index bee2f5b84e502..8ceab78f5abed 100644 --- a/aten/src/THC/CMakeLists.txt +++ b/aten/src/THC/CMakeLists.txt @@ -41,14 +41,12 @@ set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/THCTensor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/THCReduceApplyUtils.cu - ${CMAKE_CURRENT_SOURCE_DIR}/THCBlas.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCSleep.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCStorage.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCStorageCopy.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensor.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorCopy.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMath.cu - ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathBlas.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathMagma.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathPairwise.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathReduce.cu @@ -68,7 +66,6 @@ install(FILES THC.h ${CMAKE_CURRENT_BINARY_DIR}/THCGeneral.h THCGeneral.hpp - THCBlas.h THCSleep.h THCStorage.h THCStorageCopy.h @@ -141,8 +138,6 @@ install(FILES generic/THCTensorMasked.cu generic/THCTensorMath.h generic/THCTensorMath.cu - generic/THCTensorMathBlas.cu - generic/THCTensorMathBlas.h generic/THCTensorMathMagma.h generic/THCTensorMathMagma.cu generic/THCTensorMathPairwise.h diff --git a/aten/src/THC/THC.h b/aten/src/THC/THC.h index 79be433e1a84f..7e522a599b9e1 100644 --- a/aten/src/THC/THC.h +++ b/aten/src/THC/THC.h @@ -3,7 +3,6 @@ #include #include -#include #include #include #include diff --git a/aten/src/THC/THCApply.cuh b/aten/src/THC/THCApply.cuh index 368f1566e84c6..7e52e1a1130c0 100644 --- a/aten/src/THC/THCApply.cuh +++ b/aten/src/THC/THCApply.cuh @@ -6,6 +6,7 @@ #include #include #include +#include // // This file contains pointwise operation functions and kernels that @@ -242,14 +243,11 @@ bool THC_pointwiseApply1(THCState* state, // (or vice versa), the contiguous tensor can be collapsed to one // dimension, and the loop to translate the linear index to the array // index can be similarly collapsed. That is what this unrolling is for. -#define HANDLE_CASE(TYPE, A) \ - kernelPointwiseApply1 \ - <<>>( \ - OffsetInfo \ - (aInfo), \ - (TYPE) totalElements, op); +#define HANDLE_CASE(TYPE, A) \ + kernelPointwiseApply1 \ + <<>>( \ + OffsetInfo(aInfo), (TYPE) totalElements, op); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); #define HANDLE_A_CASE(TYPE, A) { \ switch (A) { \ @@ -298,6 +296,7 @@ bool THC_pointwiseApply1(THCState* state, uint64_t, 1> <<>>( aOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { #if CUDA_VERSION < 9000 @@ -310,6 +309,7 @@ bool THC_pointwiseApply1(THCState* state, uint64_t, -1> <<>>( aOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } #undef HANDLE_CASE @@ -392,16 +392,13 @@ bool THC_pointwiseApply2(THCState* state, // dimension, and the loop to translate the linear index to the array // index can be similarly collapsed. That is what this unrolling is for. #define HANDLE_CASE(TYPE, A, B) \ - kernelPointwiseApply2 \ + kernelPointwiseApply2 \ <<>>( \ - OffsetInfo \ - (aInfo), \ - OffsetInfo \ - (bInfo), \ - (TYPE) totalElements, op); + OffsetInfo(aInfo), \ + OffsetInfo(bInfo), \ + (TYPE) totalElements, op); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); + #define HANDLE_B_CASE(TYPE, A, B) { \ switch (B) { \ @@ -474,6 +471,7 @@ bool THC_pointwiseApply2(THCState* state, uint64_t, 1, 1> <<>>( aOffset, bOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { #if CUDA_VERSION < 9000 grid.x = min(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x); @@ -488,6 +486,7 @@ bool THC_pointwiseApply2(THCState* state, uint64_t, -1, -1> <<>>( aOffset, bOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } #undef HANDLE_CASE @@ -598,7 +597,8 @@ bool THC_pointwiseApply3(THCState* state, (bInfo), \ OffsetInfo \ (cInfo), \ - (TYPE) totalElements, op); + (TYPE) totalElements, op); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); #define HANDLE_C_CASE(TYPE, A, B, C) { \ switch (C) { \ @@ -697,6 +697,7 @@ bool THC_pointwiseApply3(THCState* state, uint64_t, 1, 1, 1> <<>>( aOffset, bOffset, cOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { #if CUDA_VERSION < 9000 grid.x = min(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x); @@ -715,6 +716,7 @@ bool THC_pointwiseApply3(THCState* state, uint64_t, -1, -1, -1> <<>>( aOffset, bOffset, cOffset, (uint64_t) totalElements, op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } #undef HANDLE_CASE diff --git a/aten/src/THC/THCAsmUtils.cuh b/aten/src/THC/THCAsmUtils.cuh index 6375891bd7f24..be0bf6ffa1ba1 100644 --- a/aten/src/THC/THCAsmUtils.cuh +++ b/aten/src/THC/THCAsmUtils.cuh @@ -94,15 +94,16 @@ __device__ __forceinline__ int getLaneId() { #if defined(__HIP_PLATFORM_HCC__) __device__ __forceinline__ unsigned long long int getLaneMaskLt() { - std::uint64_t m = (1ull << getLaneId()) - 1ull; + const std::uint64_t m = (1ull << getLaneId()) - 1ull; return m; +} #else __device__ __forceinline__ unsigned getLaneMaskLt() { unsigned mask; asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask)); return mask; -#endif } +#endif #if defined (__HIP_PLATFORM_HCC__) __device__ __forceinline__ unsigned long long int getLaneMaskLe() { @@ -119,27 +120,28 @@ __device__ __forceinline__ unsigned getLaneMaskLe() { #if defined(__HIP_PLATFORM_HCC__) __device__ __forceinline__ unsigned long long int getLaneMaskGt() { - std::uint64_t m = getLaneMaskLe(); + const std::uint64_t m = getLaneMaskLe(); return m ? ~m : m; +} #else __device__ __forceinline__ unsigned getLaneMaskGt() { unsigned mask; asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask)); return mask; -#endif } +#endif #if defined(__HIP_PLATFORM_HCC__) __device__ __forceinline__ unsigned long long int getLaneMaskGe() { - std::uint64_t m = getLaneMaskLt(); + const std::uint64_t m = getLaneMaskLt(); return ~m; +} #else __device__ __forceinline__ unsigned getLaneMaskGe() { unsigned mask; asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask)); return mask; -#endif } - +#endif #endif // THC_ASM_UTILS_INC diff --git a/aten/src/THC/THCBlas.cu b/aten/src/THC/THCBlas.cu deleted file mode 100644 index 73d411f05ef16..0000000000000 --- a/aten/src/THC/THCBlas.cu +++ /dev/null @@ -1,362 +0,0 @@ -#include -#include -#include -#include -#include - -#include -#include - -#ifdef __HIP_PLATFORM_HCC__ -#include -#endif - -/* Level 2 */ - -void adjustLdLevel2(int64_t m, int64_t n, int64_t *lda) -{ - // Note: leading dimensions generally are checked that they are > 0 and at least as big the result - // requires (even if the value won't be used). - // TODO: why does Level3 check trans but this doesn't? - if (n <= 1) - *lda = std::max(m, 1); -} - -void THCudaBlas_Sger(THCState *state, int64_t m, int64_t n, float alpha, float *x, int64_t incx, float *y, int64_t incy, float *a, int64_t lda) -{ - adjustLdLevel2(m, n, &lda); - - if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) - { - int i_m = (int)m; - int i_n = (int)n; - int i_lda = (int)lda; - int i_incx = (int)incx; - int i_incy = (int)incy; - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasSger(handle, i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda)); - return; - } - THError("Cublas_Sger only supports m, n, lda, incx, incy" - "with the bound [val] <= %d", INT_MAX); -} - -void THCudaBlas_Dger(THCState *state, int64_t m, int64_t n, double alpha, double *x, int64_t incx, double *y, int64_t incy, double *a, int64_t lda) -{ - adjustLdLevel2(m, n, &lda); - - if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) - { - int i_m = (int)m; - int i_n = (int)n; - int i_lda = (int)lda; - int i_incx = (int)incx; - int i_incy = (int)incy; - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasDger(handle, i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda)); - return; - } - THError("Cublas_Dger only supports m, n, lda, incx, incy" - "with the bound [val] <= %d", INT_MAX); -} - - -cublasOperation_t convertTransToCublasOperation(char trans) { - if (trans == 't') return CUBLAS_OP_T; - else if (trans == 'n') return CUBLAS_OP_N; - else if (trans == 'c') return CUBLAS_OP_C; - else { - THError("trans must be one of: t, n, c"); - return CUBLAS_OP_T; - } -} - -void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t *lda, int64_t *ldb, int64_t *ldc) -{ - int transa_ = ((transa == 't') || (transa == 'T')); - int transb_ = ((transb == 't') || (transb == 'T')); - - // Note: leading dimensions generally are checked that they are > 0 and at least as big the result - // requires (even if the value won't be used). - if(n <= 1) - *ldc = std::max(m, 1); - - if(transa_) - { - if(m <= 1) - *lda = std::max(k, 1); - } - else - { - if(k <= 1) - *lda = std::max(m, 1); - } - - if(transb_) - { - if(k <= 1) - *ldb = std::max(n, 1); - } - else - { - if(n <= 1) - *ldb = std::max(k, 1); - } - -} - -/* Level 3 */ -void THCudaBlas_Sgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, float alpha, float *a, int64_t lda, float *b, int64_t ldb, float beta, float *c, int64_t ldc) -{ - at::cuda::blas::gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); -} - -// In CUDA 8.0, definition of data types for sgemmex changed -#if CUDA_VERSION < 8000 -# define CUDA_R_16F CUBLAS_DATA_HALF -#endif - -void THCudaBlas_Hgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, at::Half alpha, at::Half *a, int64_t lda, at::Half *b, int64_t ldb, at::Half beta, at::Half *c, int64_t ldc) -{ - at::cuda::blas::gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); -} - -#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000 -void THCudaBlas_Bgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, at::BFloat16 alpha, at::BFloat16 *a, int64_t lda, at::BFloat16 *b, int64_t ldb, at::BFloat16 beta, at::BFloat16 *c, int64_t ldc) -{ - at::cuda::blas::gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); -} -#endif - -void THCudaBlas_Dgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, double alpha, double *a, int64_t lda, double *b, int64_t ldb, double beta, double *c, int64_t ldc) -{ - at::cuda::blas::gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); -} - -void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - at::Half alpha, const at::Half *a, int64_t lda, int64_t strideA, const at::Half *b, int64_t ldb, int64_t strideB, - at::Half beta, at::Half *c, int64_t ldc, int64_t strideC, int64_t batchCount) -{ - // See Note [Writing Nondeterministic Operations] - at::globalContext().alertCuBLASConfigNotDeterministic(); - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - - { - THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - float fAlpha = alpha; - float fBeta = beta; -#ifdef __HIP_PLATFORM_HCC__ - THCublasCheck(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k, - (void*)&fAlpha, a, rocblas_datatype_f16_r, (int)lda, strideA, - b, rocblas_datatype_f16_r, (int)ldb, strideB, - (void*)&fBeta, c, rocblas_datatype_f16_r, (int)ldc, strideC, - c, rocblas_datatype_f16_r, (int)ldc, strideC, - (int) batchCount, rocblas_datatype_f32_r, rocblas_gemm_algo_standard, - 0, 0)); -#else -#if defined(CUDA_VERSION) && CUDA_VERSION < 11000 - // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH - // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required. - THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); -#endif // CUDA_VERSION < 11000 - THCublasCheck(cublasGemmStridedBatchedEx(handle, - opa, opb, (int)m, (int)n, (int)k, - (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA, - b, CUDA_R_16F, (int)ldb, strideB, - (void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC, - (int)batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -#if defined(CUDA_VERSION) && CUDA_VERSION < 11000 - // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH - // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required. - THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); -#endif // CUDA_VERSION < 11000 -#endif // __HIP_PLATFORM_HCC__ -} - -#ifdef __HIP_PLATFORM_HCC__ -void THCudaBlas_BgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - at::BFloat16 alpha, const at::BFloat16 *a, int64_t lda, int64_t strideA, const at::BFloat16 *b, int64_t ldb, int64_t strideB, - at::BFloat16 beta, at::BFloat16 *c, int64_t ldc, int64_t strideC, int64_t batchCount) -{ - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - - { - THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - float fAlpha = alpha; - float fBeta = beta; - THCublasCheck(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k, - (void*)&fAlpha, a, rocblas_datatype_bf16_r, (int)lda, strideA, - b, rocblas_datatype_bf16_r, (int)ldb, strideB, - (void*)&fBeta, c, rocblas_datatype_bf16_r, (int)ldc, strideC, - c, rocblas_datatype_bf16_r, (int)ldc, strideC, - (int) batchCount, rocblas_datatype_f32_r, rocblas_gemm_algo_standard, - 0, 0, NULL, NULL)); -} -#endif // __HIP_PLATFORM_HCC__ - -#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 -void THCudaBlas_BgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - at::BFloat16 alpha, const at::BFloat16 *a, int64_t lda, int64_t strideA, const at::BFloat16 *b, int64_t ldb, int64_t strideB, - at::BFloat16 beta, at::BFloat16 *c, int64_t ldc, int64_t strideC, int64_t batchCount) -{ - at::globalContext().alertCuBLASConfigNotDeterministic(); - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - - { - THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - - cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); - if (prop->major < 8) { - TORCH_CHECK(false, "BFloat16 gemm in CUDA requires Ampere or later GPU"); - } - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - float fAlpha = alpha; - float fBeta = beta; - THCublasCheck(cublasGemmStridedBatchedEx(handle, - opa, opb, (int)m, (int)n, (int)k, - (void*)&fAlpha, a, CUDA_R_16BF, (int)lda, strideA, - b, CUDA_R_16BF, (int)ldb, strideB, - (void*)&fBeta, c, CUDA_R_16BF, (int)ldc, strideC, - (int)batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -} -#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000 - -void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb, - float beta, float *c[], int64_t ldc, int64_t batchCount) -{ - // See Note [Writing Nondeterministic Operations] - at::globalContext().alertCuBLASConfigNotDeterministic(); - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - { - THError("Cublas_SgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - -#ifdef __HIP_PLATFORM_HCC__ - - const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n; - const int64_t strideb = (transb == 'N' || transb == 'n') ? ldb*n : ldb*k; - const int64_t stridec = ldc*n; - - THCudaBlas_SgemmStridedBatched(state, transa, transb, m, n, k, alpha, *a, lda, stridea, *b, ldb, strideb, beta, *c, ldc, stridec, batchCount); - -#else - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasSgemmBatched(handle, - opa, opb, (int)m, (int)n, (int)k, - &alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc, - (int)batchCount)); -#endif -} - -void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - float alpha, const float *a, int64_t lda, int64_t strideA, const float *b, int64_t ldb, int64_t strideB, - float beta, float *c, int64_t ldc, int64_t strideC, int64_t batchCount) -{ - // See Note [Writing Nondeterministic Operations] - at::globalContext().alertCuBLASConfigNotDeterministic(); - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - - { - THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasSgemmStridedBatched(handle, - opa, opb, (int)m, (int)n, (int)k, - &alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC, - (int)batchCount)); -} - -void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - double alpha, const double *a[], int64_t lda, const double *b[], int64_t ldb, - double beta, double *c[], int64_t ldc, int64_t batchCount) -{ - // See Note [Writing Nondeterministic Operations] - at::globalContext().alertCuBLASConfigNotDeterministic(); - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - { - THError("Cublas_DgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - -#ifdef __HIP_PLATFORM_HCC__ - - const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n; - const int64_t strideb = (transb == 'N' || transb == 'n') ? ldb*n : ldb*k; - const int64_t stridec = ldc*n; - - THCudaBlas_DgemmStridedBatched(state, transa, transb, m, n, k, alpha, *a, lda, stridea, *b, ldb, strideb, beta, *c, ldc, stridec, batchCount); - -#else - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasDgemmBatched(handle, - opa, opb, (int)m, (int)n, (int)k, - &alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc, - (int)batchCount)); -#endif -} - -void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - double alpha, const double *a, int64_t lda, int64_t strideA, const double *b, int64_t ldb, int64_t strideB, - double beta, double *c, int64_t ldc, int64_t strideC, int64_t batchCount) -{ - // See Note [Writing Nondeterministic Operations] - at::globalContext().alertCuBLASConfigNotDeterministic(); - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - { - THError("Cublas_DgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasDgemmStridedBatched(handle, - opa, opb, (int)m, (int)n, (int)k, - &alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC, - (int)batchCount)); -} diff --git a/aten/src/THC/THCBlas.h b/aten/src/THC/THCBlas.h deleted file mode 100644 index a9b646a4374ff..0000000000000 --- a/aten/src/THC/THCBlas.h +++ /dev/null @@ -1,44 +0,0 @@ -#ifndef THC_BLAS_INC -#define THC_BLAS_INC - -#include -#include -#include - -/* Level 2 */ -THC_API void THCudaBlas_Sger(THCState *state, int64_t m, int64_t n, float alpha, float *x, int64_t incx, float *y, int64_t incy, float *a, int64_t lda); -THC_API void THCudaBlas_Dger(THCState *state, int64_t m, int64_t n, double alpha, double *x, int64_t incx, double *y, int64_t incy, double *a, int64_t lda); - -/* Level 3 */ -THC_API void THCudaBlas_Sgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, float alpha, float *a, int64_t lda, float *b, int64_t ldb, float beta, float *c, int64_t ldc); -THC_API void THCudaBlas_Dgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, double alpha, double *a, int64_t lda, double *b, int64_t ldb, double beta, double *c, int64_t ldc); - -THC_API void THCudaBlas_Hgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, THHalf alpha, THHalf *a, int64_t lda, THHalf *b, int64_t ldb, THHalf beta, THHalf *c, int64_t ldc); -#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000 -THC_API void THCudaBlas_Bgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, at::BFloat16 alpha, at::BFloat16 *a, int64_t lda, at::BFloat16 *b, int64_t ldb, at::BFloat16 beta, at::BFloat16 *c, int64_t ldc); -#endif - -THC_API void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb, - float beta, float *c[], int64_t ldc, int64_t batchCount); -THC_API void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - double alpha, const double *a[], int64_t lda, const double *b[], int64_t ldb, - double beta, double *c[], int64_t ldc, int64_t batchCount); -THC_API void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - float alpha, const float *a, int64_t lda, int64_t strideA, const float *b, int64_t ldb, int64_t strideB, - float beta, float *c, int64_t ldc, int64_t strideC, int64_t batchCount); -THC_API void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - double alpha, const double *a, int64_t lda, int64_t strideA, const double *b, int64_t ldb, int64_t strideB, - double beta, double *c, int64_t ldc, int64_t strideC, int64_t batchCount); - -void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - THHalf alpha, const THHalf *a, int64_t lda, int64_t strideA, const THHalf *b, int64_t ldb, int64_t strideB, - THHalf beta, THHalf *c, int64_t ldc, int64_t strideC, int64_t batchCount); - -#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000 -void THCudaBlas_BgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - at::BFloat16 alpha, const at::BFloat16 *a, int64_t lda, int64_t strideA, const at::BFloat16 *b, int64_t ldb, int64_t strideB, - at::BFloat16 beta, at::BFloat16 *c, int64_t ldc, int64_t strideC, int64_t batchCount); -#endif - -#endif diff --git a/aten/src/THC/THCCachingHostAllocator.h b/aten/src/THC/THCCachingHostAllocator.h index 38688842e10d9..30dc664cc118c 100644 --- a/aten/src/THC/THCCachingHostAllocator.h +++ b/aten/src/THC/THCCachingHostAllocator.h @@ -21,13 +21,13 @@ // Note that this allocator does not split larger allocations into smaller // blocks, unlike the caching device allocator. // -THC_API c10::Allocator* getTHCCachingHostAllocator(void); +TORCH_CUDA_API c10::Allocator* getTHCCachingHostAllocator(void); // Records an event in the specified stream. The allocation 'ptr' will not be // re-used until the event has occurred. -THC_API cudaError_t THCCachingHostAllocator_recordEvent(void *ptr, at::cuda::CUDAStream stream); +TORCH_CUDA_API cudaError_t THCCachingHostAllocator_recordEvent(void *ptr, at::cuda::CUDAStream stream); // Releases cached pinned memory allocations via cudaHostFree -THC_API void THCCachingHostAllocator_emptyCache(void); +TORCH_CUDA_API void THCCachingHostAllocator_emptyCache(void); #endif diff --git a/aten/src/THC/THCDeviceUtils.cuh b/aten/src/THC/THCDeviceUtils.cuh index 171488d912144..c82ca76ca5e20 100644 --- a/aten/src/THC/THCDeviceUtils.cuh +++ b/aten/src/THC/THCDeviceUtils.cuh @@ -7,6 +7,8 @@ #include #endif +#include + /* The largest consecutive integer representable in float32 (2^24) */ #define FLOAT32_MAX_CONSECUTIVE_INT 16777216.0f @@ -32,7 +34,7 @@ __host__ __device__ __forceinline__ T THCRoundUp(T a, T b) { */ template __device__ __forceinline__ T doLdg(const T* p) { -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 && !defined __HIP_PLATFORM_HCC__ return __ldg(p); #else return *p; diff --git a/aten/src/THC/THCGeneral.h.in b/aten/src/THC/THCGeneral.h.in index 7cb8c4c81755f..09095bd945974 100644 --- a/aten/src/THC/THCGeneral.h.in +++ b/aten/src/THC/THCGeneral.h.in @@ -14,11 +14,6 @@ #cmakedefine USE_MAGMA -// TH & THC are now part of the same library as ATen and Caffe2 -// NB: However, we are planning to split it out to a torch_cuda library -#define THC_API TORCH_CUDA_API -#define THC_CLASS TORCH_CUDA_API - #ifndef THAssert #define THAssert(exp) \ do { \ @@ -36,22 +31,22 @@ typedef struct _THCCudaResourcesPerDevice { size_t scratchSpacePerStream; } THCCudaResourcesPerDevice; -THC_API THCState* THCState_alloc(void); -THC_API void THCState_free(THCState* state); +TORCH_CUDA_API THCState* THCState_alloc(void); +TORCH_CUDA_API void THCState_free(THCState* state); -THC_API void THCudaInit(THCState* state); -THC_API void THCudaShutdown(THCState* state); +TORCH_CUDA_API void THCudaInit(THCState* state); +TORCH_CUDA_API void THCudaShutdown(THCState* state); /* If device `dev` can access allocations on device `devToAccess`, this will return */ /* 1; otherwise, 0. */ -THC_API int THCState_getPeerToPeerAccess(THCState* state, int dev, int devToAccess); +TORCH_CUDA_API int THCState_getPeerToPeerAccess(THCState* state, int dev, int devToAccess); -THC_API c10::Allocator* THCState_getCudaHostAllocator(THCState* state); +TORCH_CUDA_API c10::Allocator* THCState_getCudaHostAllocator(THCState* state); -THC_API void THCMagma_init(THCState *state); +TORCH_CUDA_API void THCMagma_init(THCState *state); /* For the current device and stream, returns the allocated scratch space */ -THC_API size_t THCState_getCurrentDeviceScratchSpaceSize(THCState* state); +TORCH_CUDA_API size_t THCState_getCurrentDeviceScratchSpaceSize(THCState* state); #define THCAssertSameGPU(expr) if (!expr) THError("arguments are located on different GPUs") #define THCudaCheck(err) __THCudaCheck(err, __FILE__, __LINE__) @@ -59,16 +54,16 @@ THC_API size_t THCState_getCurrentDeviceScratchSpaceSize(THCState* state); #define THCublasCheck(err) __THCublasCheck(err, __FILE__, __LINE__) #define THCusparseCheck(err) __THCusparseCheck(err, __FILE__, __LINE__) -THC_API void __THCudaCheck(cudaError_t err, const char *file, const int line); -THC_API void __THCudaCheckWarn(cudaError_t err, const char *file, const int line); -THC_API void __THCublasCheck(cublasStatus_t status, const char *file, const int line); -THC_API void __THCusparseCheck(cusparseStatus_t status, const char *file, const int line); +TORCH_CUDA_API void __THCudaCheck(cudaError_t err, const char *file, const int line); +TORCH_CUDA_API void __THCudaCheckWarn(cudaError_t err, const char *file, const int line); +TORCH_CUDA_API void __THCublasCheck(cublasStatus_t status, const char *file, const int line); +TORCH_CUDA_API void __THCusparseCheck(cusparseStatus_t status, const char *file, const int line); -THC_API void* THCudaMalloc(THCState *state, size_t size); -THC_API void THCudaFree(THCState *state, void* ptr); +TORCH_CUDA_API void* THCudaMalloc(THCState *state, size_t size); +TORCH_CUDA_API void THCudaFree(THCState *state, void* ptr); at::DataPtr THCudaHostAlloc(THCState *state, size_t size); -THC_API void THCudaHostRecord(THCState *state, void *ptr); +TORCH_CUDA_API void THCudaHostRecord(THCState *state, void *ptr); #endif diff --git a/aten/src/THC/THCReduceAll.cuh b/aten/src/THC/THCReduceAll.cuh index 9546f85f61c9a..af2e264e6528c 100644 --- a/aten/src/THC/THCReduceAll.cuh +++ b/aten/src/THC/THCReduceAll.cuh @@ -10,6 +10,7 @@ // #include +#include #include #ifdef __HIP_PLATFORM_HCC__ @@ -209,6 +210,7 @@ void callReduceAll(THCState* state, <<>>( in, (IndexType) totalElements, init, modifyOp, reduceOp, (AccT*) scratchSpace); + C10_CUDA_KERNEL_LAUNCH_CHECK(); int numPass1Blocks = grid.x; getPass2ReduceBlockGrid(state, totalElements, grid, block); @@ -218,6 +220,7 @@ void callReduceAll(THCState* state, <<>>( numPass1Blocks, init, reduceOp, (AccT*) scratchSpace, devOut); + C10_CUDA_KERNEL_LAUNCH_CHECK(); THCudaFree(state, scratchSpace); } else { @@ -227,6 +230,7 @@ void callReduceAll(THCState* state, kernelReduceAll <<>>( in, (IndexType) totalElements, init, modifyOp, reduceOp, devOut); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } diff --git a/aten/src/THC/THCReduceApplyUtils.cuh b/aten/src/THC/THCReduceApplyUtils.cuh index 4324e19597739..8048c5ac12ecb 100644 --- a/aten/src/THC/THCReduceApplyUtils.cuh +++ b/aten/src/THC/THCReduceApplyUtils.cuh @@ -147,6 +147,6 @@ __device__ T reduceBlockWithNThreadLocalReductions(T *smem, void THCCheckTensorDims(THCState* state, THCudaTensor* tensor, int arg); // Produces a grid with at least one point per tile -THC_API bool THC_getGridFromTiles(ptrdiff_t gridTiles, dim3& grid); +TORCH_CUDA_API bool THC_getGridFromTiles(ptrdiff_t gridTiles, dim3& grid); #endif // THC_REDUCE_APPLY_UTILS_INC diff --git a/aten/src/THC/THCSleep.h b/aten/src/THC/THCSleep.h index b3f20a39340e6..3d1c07c075811 100644 --- a/aten/src/THC/THCSleep.h +++ b/aten/src/THC/THCSleep.h @@ -5,6 +5,6 @@ #include // enqueues a kernel that spins for the specified number of cycles -THC_API void THC_sleep(THCState* state, int64_t cycles); +TORCH_CUDA_API void THC_sleep(THCState* state, int64_t cycles); #endif diff --git a/aten/src/THC/THCStorage.hpp b/aten/src/THC/THCStorage.hpp index 4e7e68d18e725..5c26eba2b6df0 100644 --- a/aten/src/THC/THCStorage.hpp +++ b/aten/src/THC/THCStorage.hpp @@ -13,17 +13,17 @@ #include #include -THC_API THCStorage* THCStorage_new(THCState* state); +TORCH_CUDA_API THCStorage* THCStorage_new(THCState* state); -THC_API void THCStorage_retain(THCState *state, THCStorage *storage); +TORCH_CUDA_API void THCStorage_retain(THCState *state, THCStorage *storage); -THC_API void THCStorage_resizeBytes( +TORCH_CUDA_API void THCStorage_resizeBytes( THCState* state, THCStorage* storage, ptrdiff_t size_bytes); -THC_API int THCStorage_getDevice(THCState* state, const THCStorage* storage); +TORCH_CUDA_API int THCStorage_getDevice(THCState* state, const THCStorage* storage); -THC_API THCStorage* THCStorage_newWithDataAndAllocator( +TORCH_CUDA_API THCStorage* THCStorage_newWithDataAndAllocator( THCState* state, at::DataPtr&& data, ptrdiff_t size, diff --git a/aten/src/THC/THCTensor.h b/aten/src/THC/THCTensor.h index acb0fa0b06074..63b8325624023 100644 --- a/aten/src/THC/THCTensor.h +++ b/aten/src/THC/THCTensor.h @@ -9,7 +9,7 @@ #define THC_DESC_BUFF_LEN 64 -typedef struct THC_CLASS THCDescBuff +typedef struct TORCH_CUDA_API THCDescBuff { char str[THC_DESC_BUFF_LEN]; } THCDescBuff; diff --git a/aten/src/THC/THCTensor.hpp b/aten/src/THC/THCTensor.hpp index dd31247497a14..f2dbbc88d276f 100644 --- a/aten/src/THC/THCTensor.hpp +++ b/aten/src/THC/THCTensor.hpp @@ -12,46 +12,46 @@ #include // See [NOTE: nDimension vs nDimensionLegacyNoScalars vs nDimensionLegacyAll] -THC_API int THCTensor_nDimension(THCState *state, const THCTensor *self); -THC_API int THCTensor_nDimensionLegacyNoScalars(THCState *state, const THCTensor *self); -THC_API int THCTensor_nDimensionLegacyAll(THCState *state, const THCTensor *self); +TORCH_CUDA_API int THCTensor_nDimension(THCState *state, const THCTensor *self); +TORCH_CUDA_API int THCTensor_nDimensionLegacyNoScalars(THCState *state, const THCTensor *self); +TORCH_CUDA_API int THCTensor_nDimensionLegacyAll(THCState *state, const THCTensor *self); -THC_API int64_t THCTensor_size(THCState *state, const THCTensor *self, int dim); -THC_API int64_t THCTensor_sizeLegacyNoScalars(THCState *state, const THCTensor *self, int dim); -THC_API int64_t THCTensor_stride(THCState *state, const THCTensor *self, int dim); -THC_API int64_t THCTensor_strideLegacyNoScalars(THCState *state, const THCTensor *self, int dim); +TORCH_CUDA_API int64_t THCTensor_size(THCState *state, const THCTensor *self, int dim); +TORCH_CUDA_API int64_t THCTensor_sizeLegacyNoScalars(THCState *state, const THCTensor *self, int dim); +TORCH_CUDA_API int64_t THCTensor_stride(THCState *state, const THCTensor *self, int dim); +TORCH_CUDA_API int64_t THCTensor_strideLegacyNoScalars(THCState *state, const THCTensor *self, int dim); -THC_API THCTensor *THCTensor_new(THCState *state, caffe2::TypeMeta type_meta); +TORCH_CUDA_API THCTensor *THCTensor_new(THCState *state, caffe2::TypeMeta type_meta); -THC_API void THCTensor_resize(THCState *state, THCTensor *tensor, at::IntArrayRef size, at::IntArrayRef stride); -THC_API void THCTensor_resizeNd(THCState *state, THCTensor *tensor, int nDimension, const int64_t *size, const int64_t *stride); -THC_API void THCTensor_resizeAs(THCState *state, THCTensor *tensor, THCTensor *src); +TORCH_CUDA_API void THCTensor_resize(THCState *state, THCTensor *tensor, at::IntArrayRef size, at::IntArrayRef stride); +TORCH_CUDA_API void THCTensor_resizeNd(THCState *state, THCTensor *tensor, int nDimension, const int64_t *size, const int64_t *stride); +TORCH_CUDA_API void THCTensor_resizeAs(THCState *state, THCTensor *tensor, THCTensor *src); -THC_API void THCTensor_set(THCState *state, THCTensor *self, THCTensor *src); -THC_API void THCTensor_setStorage(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, at::IntArrayRef size_, at::IntArrayRef stride_); +TORCH_CUDA_API void THCTensor_set(THCState *state, THCTensor *self, THCTensor *src); +TORCH_CUDA_API void THCTensor_setStorage(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, at::IntArrayRef size_, at::IntArrayRef stride_); -THC_API void THCTensor_squeeze1d(THCState *state, THCTensor *self, THCTensor *src, int dimension_); -THC_API void THCTensor_unsqueeze1d(THCState *state, THCTensor *self, THCTensor *src, int dimension_); +TORCH_CUDA_API void THCTensor_squeeze1d(THCState *state, THCTensor *self, THCTensor *src, int dimension_); +TORCH_CUDA_API void THCTensor_unsqueeze1d(THCState *state, THCTensor *self, THCTensor *src, int dimension_); -THC_API bool THCTensor_allContiguous(THCState *state, THCTensor **inputs, int numInputs); -THC_API ptrdiff_t THCTensor_nElement(THCState *state, const THCTensor *self); +TORCH_CUDA_API bool THCTensor_allContiguous(THCState *state, THCTensor **inputs, int numInputs); +TORCH_CUDA_API ptrdiff_t THCTensor_nElement(THCState *state, const THCTensor *self); -THC_API void THCTensor_retain(THCState *state, THCTensor *self); -THC_API void THCTensor_free(THCState *state, THCTensor *self); +TORCH_CUDA_API void THCTensor_retain(THCState *state, THCTensor *self); +TORCH_CUDA_API void THCTensor_free(THCState *state, THCTensor *self); -THC_API int THCTensor_getDevice(THCState* state, const THCTensor* tensor); -THC_API bool THCTensor_allSameDevice(THCState* state, THCTensor ** inputs, int numInputs); +TORCH_CUDA_API int THCTensor_getDevice(THCState* state, const THCTensor* tensor); +TORCH_CUDA_API bool THCTensor_allSameDevice(THCState* state, THCTensor ** inputs, int numInputs); /* Can we use 32 bit math for indexing? */ -THC_API bool THCTensor_canUse32BitIndexMath(THCState* state, const THCTensor* t, ptrdiff_t max_elem=INT32_MAX); +TORCH_CUDA_API bool THCTensor_canUse32BitIndexMath(THCState* state, const THCTensor* t, ptrdiff_t max_elem=INT32_MAX); /* Are all tensors 32-bit indexable? */ -THC_API bool THCTensor_all32BitIndexable(THCState* state, THCTensor** inputs, int numInputs); -THC_API void THCTensor_preserveReduceDimSemantics(THCState *state, THCTensor *tensor, int in_dims, +TORCH_CUDA_API bool THCTensor_all32BitIndexable(THCState* state, THCTensor** inputs, int numInputs); +TORCH_CUDA_API void THCTensor_preserveReduceDimSemantics(THCState *state, THCTensor *tensor, int in_dims, int64_t dimension, int keepdim); /* Returns false if there is no possibility that the tensor */ /* has more than one index that references the same datapoint, */ /* true otherwise. */ -THC_API bool THCTensor_maybeOverlappingIndices(THCState* state, const THCTensor* t); +TORCH_CUDA_API bool THCTensor_maybeOverlappingIndices(THCState* state, const THCTensor* t); #include #include diff --git a/aten/src/THC/THCTensorIndex.cu b/aten/src/THC/THCTensorIndex.cu index 0287f31f658ec..3bb429ed30e36 100644 --- a/aten/src/THC/THCTensorIndex.cu +++ b/aten/src/THC/THCTensorIndex.cu @@ -1,7 +1,6 @@ #include #include #include -#include #include #include #include @@ -218,20 +217,6 @@ struct WrapIndexOp { int64_t size; }; -template -struct TensorTakeOp { - TensorTakeOp(TensorInfo info, IndexType numel, int64_t*, int64_t*) - : info(info), numel(numel) {} - - __device__ __forceinline__ void operator()(T* out, int64_t* index) { - auto offset = indexToOffset(info, *index, numel); - *out = info.data[offset]; - } - - const TensorInfo info; - IndexType numel; -}; - template struct TensorPutOp { TensorPutOp(TensorInfo info, IndexType numel, int64_t*, int64_t*) diff --git a/aten/src/THC/THCTensorMath.h b/aten/src/THC/THCTensorMath.h index 68fbb240afb42..fd316f93ed552 100644 --- a/aten/src/THC/THCTensorMath.h +++ b/aten/src/THC/THCTensorMath.h @@ -13,12 +13,6 @@ #include #include -#include -#include - -#include -#include - #include #include diff --git a/aten/src/THC/THCTensorMathBlas.cu b/aten/src/THC/THCTensorMathBlas.cu deleted file mode 100644 index 383d1ed17b1d2..0000000000000 --- a/aten/src/THC/THCTensorMathBlas.cu +++ /dev/null @@ -1,13 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include -#include diff --git a/aten/src/THC/THCTensorMathMagma.cu b/aten/src/THC/THCTensorMathMagma.cu index a2fd5fe8baf5d..36316a6bf2ebf 100644 --- a/aten/src/THC/THCTensorMathMagma.cu +++ b/aten/src/THC/THCTensorMathMagma.cu @@ -8,9 +8,7 @@ #include #ifdef USE_MAGMA -#include -#else -#include +#include #endif #ifndef DIVUP diff --git a/aten/src/THC/THCTensorMathMagma.cuh b/aten/src/THC/THCTensorMathMagma.cuh index 08124d3d4c916..1fb5821afce56 100644 --- a/aten/src/THC/THCTensorMathMagma.cuh +++ b/aten/src/THC/THCTensorMathMagma.cuh @@ -2,9 +2,7 @@ #define THC_TENSOR_MATH_MAGMA_CUH #ifdef USE_MAGMA -#include -#else -#include +#include #endif #ifdef USE_MAGMA diff --git a/aten/src/THC/THCTensorMathPairwise.cu b/aten/src/THC/THCTensorMathPairwise.cu index da57a1ad36f83..04fb34df4f704 100644 --- a/aten/src/THC/THCTensorMathPairwise.cu +++ b/aten/src/THC/THCTensorMathPairwise.cu @@ -21,36 +21,5 @@ struct TensorMulConstantOp { const T val; }; -template -struct TensorFmodOp { - TensorFmodOp(T v) : val((float)v) {} - __device__ __forceinline__ void operator()(T* out, T* in) { - *out = (T) fmodf((float) *in, val); - } - - __device__ __forceinline__ void operator()(T* v) { - *v = (T) fmodf((float) *v, val); - } - - const float val; -}; - -template <> -struct TensorFmodOp { - TensorFmodOp(double v) : val(v) {} - __device__ __forceinline__ void operator()(double* out, double* in) { - *out = fmod(*in, val); - } - - __device__ __forceinline__ void operator()(double* v) { - *v = fmod(*v, val); - } - - const double val; -}; - -#include -#include - #include #include diff --git a/aten/src/THC/THCTensorMathPointwise.cuh b/aten/src/THC/THCTensorMathPointwise.cuh index 2b511983934f1..bb2e31de26698 100644 --- a/aten/src/THC/THCTensorMathPointwise.cuh +++ b/aten/src/THC/THCTensorMathPointwise.cuh @@ -36,50 +36,6 @@ struct TensorMulOp { } }; -template -struct TensorCFmodOp { - __device__ __forceinline__ void operator()(T* out, T* in) { - *out = *out % *in; - } - - __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) { - *out = *in1 % *in2; - } -}; - -template <> -struct TensorCFmodOp { - __device__ __forceinline__ void operator()(float* out, float* in) { - *out = fmodf(*out, *in); - } - - __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) { - *out = fmodf(*in1, *in2); - } -}; - -template <> -struct TensorCFmodOp { - __device__ __forceinline__ void operator()(double* out, double* in) { - *out = fmod(*out, *in); - } - - __device__ __forceinline__ void operator()(double* out, double* in1, double* in2) { - *out = fmod(*in1, *in2); - } -}; - -template <> -struct TensorCFmodOp { - __device__ __forceinline__ void operator()(at::Half* out, at::Half* in) { - *out = fmodf(*out, *in); - } - - __device__ __forceinline__ void operator()(at::Half* out, at::Half* in1, at::Half* in2) { - *out = fmodf(*in1, *in2); - } -}; - template struct TensorCrossOp { TensorCrossOp(int64_t sx, int64_t sy, int64_t so) : sx(sx), sy(sy), so(so) {} diff --git a/aten/src/THC/THCTensorRandom.cu b/aten/src/THC/THCTensorRandom.cu index aefb427f4e67c..8655ea2fb829a 100644 --- a/aten/src/THC/THCTensorRandom.cu +++ b/aten/src/THC/THCTensorRandom.cu @@ -12,60 +12,6 @@ #define MAX_NUM_BLOCKS 200 #define BLOCK_SIZE 256 -// NB: ROCm compiler seems to have a bug where __host__ functions must be -// explicitly specified extern "C" otherwise ROCm compiler doesn't respect it. -// See https://github.com/RadeonOpenCompute/hcc/issues/839 -__host__ void THCRandom_getRNGState(at::Generator gen_, THByteTensor *rng_state) -{ - auto gen = at::check_generator(gen_); - std::lock_guard lock(gen->mutex_); - // The RNG state comprises the seed, and an offset used for Philox. - // The following line is just here for BC reason. sizeof curandStateMtgp32 is 4120. - // It used to be static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32); - // MAX_NUM_BLOCKS was 200 and sizeof(curandStateMtgp32) is 4120. Hardcoding these numbers here - // because this is just host side code and we don't want to worry about linking with cuda - static const size_t states_size = 200 * sizeof(4120); - static const size_t seed_size = sizeof(uint64_t); - static const size_t offset_size = sizeof(int64_t); - static const size_t total_size = states_size + seed_size + offset_size; - THByteTensor_resize1d(rng_state, total_size); - THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size"); - THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous"); - // since curandStateMTGP is not used anymore, fill gen_states of THCGenerator with deterministic garbage value of -1 - // gen_states in THCGenerator struct was an array of curandStateMtgp32s. - memset(THByteTensor_data(rng_state), -1, states_size); - auto current_seed = gen->current_seed(); - auto offset = static_cast(gen->philox_offset_per_thread()); // Note that old THCGeneratorState had offset as std::atomic - memcpy(THByteTensor_data(rng_state) + states_size, ¤t_seed, seed_size); - memcpy(THByteTensor_data(rng_state) + states_size + seed_size, &offset, offset_size); -} - -__host__ void THCRandom_setRNGState(at::Generator gen_, THByteTensor *rng_state) -{ - auto gen = at::check_generator(gen_); - std::lock_guard lock(gen->mutex_); - static const size_t states_size = 200 * sizeof(4120); // this line is just here for BC reason - static const size_t seed_size = sizeof(uint64_t); - static const size_t offset_size = sizeof(int64_t); - static const size_t total_size = states_size + seed_size + offset_size; - bool no_philox_seed = false; - if (THByteTensor_nElement(rng_state) == total_size - offset_size) { - no_philox_seed = true; - } - else { - THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size"); - } - THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous"); - uint64_t input_seed; - memcpy(&input_seed, THByteTensor_data(rng_state) + states_size, seed_size); - gen->set_current_seed(input_seed); - int64_t philox_offset = 0; - if (!no_philox_seed) { - memcpy(&philox_offset, THByteTensor_data(rng_state) + states_size + seed_size, offset_size); - } - gen->set_philox_offset_per_thread(static_cast(philox_offset)); -} - #include #include diff --git a/aten/src/THC/THCTensorRandom.h b/aten/src/THC/THCTensorRandom.h index 0504f94a31a29..696e36f70bec7 100644 --- a/aten/src/THC/THCTensorRandom.h +++ b/aten/src/THC/THCTensorRandom.h @@ -9,9 +9,4 @@ #include #include -#include - -THC_API void THCRandom_getRNGState(at::Generator gen_, THByteTensor *rng_state); -THC_API void THCRandom_setRNGState(at::Generator gen_, THByteTensor *rng_state); - #endif diff --git a/aten/src/THC/THCTensorSort.cu b/aten/src/THC/THCTensorSort.cu index 8969209a1bdc2..189e73b909fb5 100644 --- a/aten/src/THC/THCTensorSort.cu +++ b/aten/src/THC/THCTensorSort.cu @@ -1,5 +1,6 @@ #include #include +#include void THCudaLongTensor_fillSliceWithIndex(THCState* state, THCudaLongTensor* t, @@ -28,8 +29,10 @@ void THCudaLongTensor_fillSliceWithIndex(THCState* state, #define FILL_INDEX(T, DIM) \ fillSliceWithIndex \ - <<>>( \ - info, numSlices, sliceSize, info.strides[collapseDim]) + <<>>( \ + info, numSlices, sliceSize, info.strides[collapseDim]); \ + C10_CUDA_KERNEL_LAUNCH_CHECK() + if (THCTensor_canUse32BitIndexMath(state, t)) { TensorInfo info = @@ -59,6 +62,5 @@ void THCudaLongTensor_fillSliceWithIndex(THCState* state, } #undef FILL_INDEX - THCudaCheck(cudaGetLastError()); } } diff --git a/aten/src/THC/generic/THCStorage.h b/aten/src/THC/generic/THCStorage.h index 56c126058e45b..8f050779ddf2e 100644 --- a/aten/src/THC/generic/THCStorage.h +++ b/aten/src/THC/generic/THCStorage.h @@ -19,34 +19,34 @@ #define THCudaComplexFloatStorage THCStorage #define THCudaComplexDoubleStorage THCStorage -THC_API scalar_t* THCStorage_(data)(THCState *state, const THCStorage*); -THC_API int THCStorage_(elementSize)(THCState *state); +TORCH_CUDA_API scalar_t* THCStorage_(data)(THCState *state, const THCStorage*); +TORCH_CUDA_API int THCStorage_(elementSize)(THCState *state); /* slow access -- checks everything */ -THC_API void THCStorage_(set)(THCState *state, THCStorage*, ptrdiff_t, scalar_t); -THC_API scalar_t THCStorage_(get)(THCState *state, const THCStorage*, ptrdiff_t); +TORCH_CUDA_API void THCStorage_(set)(THCState *state, THCStorage*, ptrdiff_t, scalar_t); +TORCH_CUDA_API scalar_t THCStorage_(get)(THCState *state, const THCStorage*, ptrdiff_t); -THC_API THCStorage* THCStorage_(new)(THCState *state); -THC_API THCStorage* THCStorage_(newWithSize)(THCState *state, ptrdiff_t size); -THC_API THCStorage* THCStorage_(newWithSize1)(THCState *state, scalar_t); -THC_API THCStorage* THCStorage_(newWithMapping)(THCState *state, const char *filename, ptrdiff_t size, int shared); +TORCH_CUDA_API THCStorage* THCStorage_(new)(THCState *state); +TORCH_CUDA_API THCStorage* THCStorage_(newWithSize)(THCState *state, ptrdiff_t size); +TORCH_CUDA_API THCStorage* THCStorage_(newWithSize1)(THCState *state, scalar_t); +TORCH_CUDA_API THCStorage* THCStorage_(newWithMapping)(THCState *state, const char *filename, ptrdiff_t size, int shared); -THC_API THCStorage* THCStorage_(newWithAllocator)( +TORCH_CUDA_API THCStorage* THCStorage_(newWithAllocator)( THCState *state, ptrdiff_t size, at::Allocator* allocator); -THC_API THCStorage* THCStorage_(newWithDataAndAllocator)( +TORCH_CUDA_API THCStorage* THCStorage_(newWithDataAndAllocator)( THCState *state, at::DataPtr&& data, ptrdiff_t size, at::Allocator* allocator); -THC_API void THCStorage_(setFlag)(THCState *state, THCStorage *storage, const char flag); -THC_API void THCStorage_(clearFlag)(THCState *state, THCStorage *storage, const char flag); -THC_API void THCStorage_(retain)(THCState *state, THCStorage *storage); +TORCH_CUDA_API void THCStorage_(setFlag)(THCState *state, THCStorage *storage, const char flag); +TORCH_CUDA_API void THCStorage_(clearFlag)(THCState *state, THCStorage *storage, const char flag); +TORCH_CUDA_API void THCStorage_(retain)(THCState *state, THCStorage *storage); -THC_API void THCStorage_(free)(THCState *state, THCStorage *storage); -THC_API void THCStorage_( +TORCH_CUDA_API void THCStorage_(free)(THCState *state, THCStorage *storage); +TORCH_CUDA_API void THCStorage_( resizeBytes)(THCState* state, THCStorage* storage, ptrdiff_t size_bytes); -THC_API void THCStorage_(fill)(THCState *state, THCStorage *storage, scalar_t value); +TORCH_CUDA_API void THCStorage_(fill)(THCState *state, THCStorage *storage, scalar_t value); -THC_API int THCStorage_(getDevice)(THCState* state, const THCStorage* storage); +TORCH_CUDA_API int THCStorage_(getDevice)(THCState* state, const THCStorage* storage); #endif diff --git a/aten/src/THC/generic/THCStorageCopy.h b/aten/src/THC/generic/THCStorageCopy.h index fa360465c2bd2..d2f623057be50 100644 --- a/aten/src/THC/generic/THCStorageCopy.h +++ b/aten/src/THC/generic/THCStorageCopy.h @@ -4,57 +4,57 @@ /* Support for copy between different Storage types */ -THC_API void THCStorage_(copy)(THCState *state, THCStorage *storage, THCStorage *src); +TORCH_CUDA_API void THCStorage_(copy)(THCState *state, THCStorage *storage, THCStorage *src); #if !defined(THC_REAL_IS_COMPLEXFLOAT) && !defined(THC_REAL_IS_COMPLEXDOUBLE) - THC_API void THCStorage_(copyByte)(THCState *state, THCStorage *storage, struct THByteStorage *src); - THC_API void THCStorage_(copyChar)(THCState *state, THCStorage *storage, struct THCharStorage *src); - THC_API void THCStorage_(copyShort)(THCState *state, THCStorage *storage, struct THShortStorage *src); - THC_API void THCStorage_(copyInt)(THCState *state, THCStorage *storage, struct THIntStorage *src); - THC_API void THCStorage_(copyLong)(THCState *state, THCStorage *storage, struct THLongStorage *src); - THC_API void THCStorage_(copyFloat)(THCState *state, THCStorage *storage, struct THFloatStorage *src); - THC_API void THCStorage_(copyDouble)(THCState *state, THCStorage *storage, struct THDoubleStorage *src); - THC_API void THCStorage_(copyHalf)(THCState *state, THCStorage *storage, struct THHalfStorage *src); - THC_API void THCStorage_(copyBool)(THCState *state, THCStorage *storage, struct THBoolStorage *src); - THC_API void THCStorage_(copyBFloat16)(THCState *state, THCStorage *storage, struct THBFloat16Storage *src); + TORCH_CUDA_API void THCStorage_(copyByte)(THCState *state, THCStorage *storage, struct THByteStorage *src); + TORCH_CUDA_API void THCStorage_(copyChar)(THCState *state, THCStorage *storage, struct THCharStorage *src); + TORCH_CUDA_API void THCStorage_(copyShort)(THCState *state, THCStorage *storage, struct THShortStorage *src); + TORCH_CUDA_API void THCStorage_(copyInt)(THCState *state, THCStorage *storage, struct THIntStorage *src); + TORCH_CUDA_API void THCStorage_(copyLong)(THCState *state, THCStorage *storage, struct THLongStorage *src); + TORCH_CUDA_API void THCStorage_(copyFloat)(THCState *state, THCStorage *storage, struct THFloatStorage *src); + TORCH_CUDA_API void THCStorage_(copyDouble)(THCState *state, THCStorage *storage, struct THDoubleStorage *src); + TORCH_CUDA_API void THCStorage_(copyHalf)(THCState *state, THCStorage *storage, struct THHalfStorage *src); + TORCH_CUDA_API void THCStorage_(copyBool)(THCState *state, THCStorage *storage, struct THBoolStorage *src); + TORCH_CUDA_API void THCStorage_(copyBFloat16)(THCState *state, THCStorage *storage, struct THBFloat16Storage *src); #else - THC_API void THCStorage_(copyComplexFloat)(THCState *state, THCStorage *storage, struct THComplexFloatStorage *src); - THC_API void THCStorage_(copyComplexDouble)(THCState *state, THCStorage *storage, struct THComplexDoubleStorage *src); + TORCH_CUDA_API void THCStorage_(copyComplexFloat)(THCState *state, THCStorage *storage, struct THComplexFloatStorage *src); + TORCH_CUDA_API void THCStorage_(copyComplexDouble)(THCState *state, THCStorage *storage, struct THComplexDoubleStorage *src); #endif #if !defined(THC_REAL_IS_COMPLEXFLOAT) && !defined(THC_REAL_IS_COMPLEXDOUBLE) - THC_API void THCStorage_(copyCudaByte)(THCState *state, THCStorage *storage, struct THCudaByteStorage *src); - THC_API void THCStorage_(copyCudaChar)(THCState *state, THCStorage *storage, struct THCudaCharStorage *src); - THC_API void THCStorage_(copyCudaShort)(THCState *state, THCStorage *storage, struct THCudaShortStorage *src); - THC_API void THCStorage_(copyCudaInt)(THCState *state, THCStorage *storage, struct THCudaIntStorage *src); - THC_API void THCStorage_(copyCudaLong)(THCState *state, THCStorage *storage, struct THCudaLongStorage *src); - THC_API void THCStorage_(copyCudaFloat)(THCState *state, THCStorage *storage, struct THCudaStorage *src); - THC_API void THCStorage_(copyCudaDouble)(THCState *state, THCStorage *storage, struct THCudaDoubleStorage *src); - THC_API void THCStorage_(copyCudaHalf)(THCState *state, THCStorage *storage, struct THCudaHalfStorage *src); - THC_API void THCStorage_(copyCudaBool)(THCState *state, THCStorage *storage, struct THCudaBoolStorage *src); - THC_API void THCStorage_(copyCudaBFloat16)(THCState *state, THCStorage *storage, struct THCudaBFloat16Storage *src); + TORCH_CUDA_API void THCStorage_(copyCudaByte)(THCState *state, THCStorage *storage, struct THCudaByteStorage *src); + TORCH_CUDA_API void THCStorage_(copyCudaChar)(THCState *state, THCStorage *storage, struct THCudaCharStorage *src); + TORCH_CUDA_API void THCStorage_(copyCudaShort)(THCState *state, THCStorage *storage, struct THCudaShortStorage *src); + TORCH_CUDA_API void THCStorage_(copyCudaInt)(THCState *state, THCStorage *storage, struct THCudaIntStorage *src); + TORCH_CUDA_API void THCStorage_(copyCudaLong)(THCState *state, THCStorage *storage, struct THCudaLongStorage *src); + TORCH_CUDA_API void THCStorage_(copyCudaFloat)(THCState *state, THCStorage *storage, struct THCudaStorage *src); + TORCH_CUDA_API void THCStorage_(copyCudaDouble)(THCState *state, THCStorage *storage, struct THCudaDoubleStorage *src); + TORCH_CUDA_API void THCStorage_(copyCudaHalf)(THCState *state, THCStorage *storage, struct THCudaHalfStorage *src); + TORCH_CUDA_API void THCStorage_(copyCudaBool)(THCState *state, THCStorage *storage, struct THCudaBoolStorage *src); + TORCH_CUDA_API void THCStorage_(copyCudaBFloat16)(THCState *state, THCStorage *storage, struct THCudaBFloat16Storage *src); #else - THC_API void THCStorage_(copyCudaComplexFloat)(THCState *state, THCStorage *storage, struct THCudaComplexFloatStorage *src); - THC_API void THCStorage_(copyCudaComplexDouble)(THCState *state, THCStorage *storage, struct THCudaComplexDoubleStorage *src); + TORCH_CUDA_API void THCStorage_(copyCudaComplexFloat)(THCState *state, THCStorage *storage, struct THCudaComplexFloatStorage *src); + TORCH_CUDA_API void THCStorage_(copyCudaComplexDouble)(THCState *state, THCStorage *storage, struct THCudaComplexDoubleStorage *src); #endif #if !defined(THC_REAL_IS_COMPLEXFLOAT) && !defined(THC_REAL_IS_COMPLEXDOUBLE) - THC_API void TH_CONCAT_2(THByteStorage_copyCuda , Real)(THCState *state, THByteStorage *self, struct THCStorage *src); - THC_API void TH_CONCAT_2(THCharStorage_copyCuda , Real)(THCState *state, THCharStorage *self, struct THCStorage *src); - THC_API void TH_CONCAT_2(THShortStorage_copyCuda , Real)(THCState *state, THShortStorage *self, struct THCStorage *src); - THC_API void TH_CONCAT_2(THIntStorage_copyCuda , Real)(THCState *state, THIntStorage *self, struct THCStorage *src); - THC_API void TH_CONCAT_2(THLongStorage_copyCuda , Real)(THCState *state, THLongStorage *self, struct THCStorage *src); - THC_API void TH_CONCAT_2(THFloatStorage_copyCuda , Real)(THCState *state, THFloatStorage *self, struct THCStorage *src); - THC_API void TH_CONCAT_2(THDoubleStorage_copyCuda, Real)(THCState *state, THDoubleStorage *self, struct THCStorage *src); - THC_API void TH_CONCAT_2(THHalfStorage_copyCuda, Real)(THCState *state, THHalfStorage *self, struct THCStorage *src); - THC_API void TH_CONCAT_2(THBoolStorage_copyCuda, Real)(THCState *state, THBoolStorage *self, struct THCStorage *src); - THC_API void TH_CONCAT_2(THBFloat16Storage_copyCuda, Real)(THCState *state, THBFloat16Storage *self, struct THCStorage *src); + TORCH_CUDA_API void TH_CONCAT_2(THByteStorage_copyCuda , Real)(THCState *state, THByteStorage *self, struct THCStorage *src); + TORCH_CUDA_API void TH_CONCAT_2(THCharStorage_copyCuda , Real)(THCState *state, THCharStorage *self, struct THCStorage *src); + TORCH_CUDA_API void TH_CONCAT_2(THShortStorage_copyCuda , Real)(THCState *state, THShortStorage *self, struct THCStorage *src); + TORCH_CUDA_API void TH_CONCAT_2(THIntStorage_copyCuda , Real)(THCState *state, THIntStorage *self, struct THCStorage *src); + TORCH_CUDA_API void TH_CONCAT_2(THLongStorage_copyCuda , Real)(THCState *state, THLongStorage *self, struct THCStorage *src); + TORCH_CUDA_API void TH_CONCAT_2(THFloatStorage_copyCuda , Real)(THCState *state, THFloatStorage *self, struct THCStorage *src); + TORCH_CUDA_API void TH_CONCAT_2(THDoubleStorage_copyCuda, Real)(THCState *state, THDoubleStorage *self, struct THCStorage *src); + TORCH_CUDA_API void TH_CONCAT_2(THHalfStorage_copyCuda, Real)(THCState *state, THHalfStorage *self, struct THCStorage *src); + TORCH_CUDA_API void TH_CONCAT_2(THBoolStorage_copyCuda, Real)(THCState *state, THBoolStorage *self, struct THCStorage *src); + TORCH_CUDA_API void TH_CONCAT_2(THBFloat16Storage_copyCuda, Real)(THCState *state, THBFloat16Storage *self, struct THCStorage *src); #else - THC_API void TH_CONCAT_2(THComplexFloatStorage_copyCuda , Real)(THCState *state, THComplexFloatStorage *self, struct THCStorage *src); - THC_API void TH_CONCAT_2(THComplexDoubleStorage_copyCuda, Real)(THCState *state, THComplexDoubleStorage *self, struct THCStorage *src); + TORCH_CUDA_API void TH_CONCAT_2(THComplexFloatStorage_copyCuda , Real)(THCState *state, THComplexFloatStorage *self, struct THCStorage *src); + TORCH_CUDA_API void TH_CONCAT_2(THComplexDoubleStorage_copyCuda, Real)(THCState *state, THComplexDoubleStorage *self, struct THCStorage *src); #endif -THC_API void THStorage_(copyCuda)(THCState *state, THStorage *self, THCStorage *src); -THC_API void THCStorage_(copyCuda)(THCState *state, THCStorage *self, THCStorage *src); -THC_API void THCStorage_(copyCPU)(THCState *state, THCStorage *self, THStorage *src); +TORCH_CUDA_API void THStorage_(copyCuda)(THCState *state, THStorage *self, THCStorage *src); +TORCH_CUDA_API void THCStorage_(copyCuda)(THCState *state, THCStorage *self, THCStorage *src); +TORCH_CUDA_API void THCStorage_(copyCPU)(THCState *state, THCStorage *self, THStorage *src); #endif diff --git a/aten/src/THC/generic/THCTensor.h b/aten/src/THC/generic/THCTensor.h index 525a0f8733913..5c1ea298c7772 100644 --- a/aten/src/THC/generic/THCTensor.h +++ b/aten/src/THC/generic/THCTensor.h @@ -20,87 +20,87 @@ #define THCudaComplexDoubleTensor THCTensor /**** access methods ****/ -THC_API THCStorage* THCTensor_(storage)(THCState *state, const THCTensor *self); -THC_API ptrdiff_t THCTensor_(storageOffset)(THCState *state, const THCTensor *self); +TORCH_CUDA_API THCStorage* THCTensor_(storage)(THCState *state, const THCTensor *self); +TORCH_CUDA_API ptrdiff_t THCTensor_(storageOffset)(THCState *state, const THCTensor *self); // See [NOTE: nDimension vs nDimensionLegacyNoScalars vs nDimensionLegacyAll] -THC_API int THCTensor_(nDimension)(THCState *state, const THCTensor *self); -THC_API int THCTensor_(nDimensionLegacyNoScalars)(THCState *state, const THCTensor *self); -THC_API int THCTensor_(nDimensionLegacyAll)(THCState *state, const THCTensor *self); +TORCH_CUDA_API int THCTensor_(nDimension)(THCState *state, const THCTensor *self); +TORCH_CUDA_API int THCTensor_(nDimensionLegacyNoScalars)(THCState *state, const THCTensor *self); +TORCH_CUDA_API int THCTensor_(nDimensionLegacyAll)(THCState *state, const THCTensor *self); -THC_API int64_t THCTensor_(size)(THCState *state, const THCTensor *self, int dim); -THC_API int64_t THCTensor_(sizeLegacyNoScalars)(THCState *state, const THCTensor *self, int dim); -THC_API int64_t THCTensor_(stride)(THCState *state, const THCTensor *self, int dim); -THC_API int64_t THCTensor_(strideLegacyNoScalars)(THCState *state, const THCTensor *self, int dim); -THC_API scalar_t *THCTensor_(data)(THCState *state, const THCTensor *self); +TORCH_CUDA_API int64_t THCTensor_(size)(THCState *state, const THCTensor *self, int dim); +TORCH_CUDA_API int64_t THCTensor_(sizeLegacyNoScalars)(THCState *state, const THCTensor *self, int dim); +TORCH_CUDA_API int64_t THCTensor_(stride)(THCState *state, const THCTensor *self, int dim); +TORCH_CUDA_API int64_t THCTensor_(strideLegacyNoScalars)(THCState *state, const THCTensor *self, int dim); +TORCH_CUDA_API scalar_t *THCTensor_(data)(THCState *state, const THCTensor *self); -THC_API void THCTensor_(setFlag)(THCState *state, THCTensor *self, const char flag); -THC_API void THCTensor_(clearFlag)(THCState *state, THCTensor *self, const char flag); +TORCH_CUDA_API void THCTensor_(setFlag)(THCState *state, THCTensor *self, const char flag); +TORCH_CUDA_API void THCTensor_(clearFlag)(THCState *state, THCTensor *self, const char flag); /**** creation methods ****/ -THC_API THCTensor *THCTensor_(new)(THCState *state); -THC_API THCTensor *THCTensor_(newWithTensor)(THCState *state, THCTensor *tensor); -THC_API THCTensor *THCTensor_(newWithStorage1d)(THCState *state, THCStorage *storage_, ptrdiff_t storageOffset_, +TORCH_CUDA_API THCTensor *THCTensor_(new)(THCState *state); +TORCH_CUDA_API THCTensor *THCTensor_(newWithTensor)(THCState *state, THCTensor *tensor); +TORCH_CUDA_API THCTensor *THCTensor_(newWithStorage1d)(THCState *state, THCStorage *storage_, ptrdiff_t storageOffset_, int64_t size0_, int64_t stride0_); /* stride might be NULL */ -THC_API THCTensor *THCTensor_(newWithSize1d)(THCState *state, int64_t size0_); +TORCH_CUDA_API THCTensor *THCTensor_(newWithSize1d)(THCState *state, int64_t size0_); -THC_API THCTensor *THCTensor_(newClone)(THCState *state, THCTensor *self); -THC_API THCTensor *THCTensor_(newContiguous)(THCState *state, THCTensor *tensor); -THC_API THCTensor *THCTensor_(newSelect)(THCState *state, THCTensor *tensor, int dimension_, int64_t sliceIndex_); -THC_API THCTensor *THCTensor_(newNarrow)(THCState *state, THCTensor *tensor, int dimension_, int64_t firstIndex_, int64_t size_); -THC_API THCTensor *THCTensor_(newTranspose)(THCState *state, THCTensor *tensor, int dimension1_, int dimension2_); -THC_API THCTensor *THCTensor_(newFoldBatchDim)(THCState *state, THCTensor *input); +TORCH_CUDA_API THCTensor *THCTensor_(newClone)(THCState *state, THCTensor *self); +TORCH_CUDA_API THCTensor *THCTensor_(newContiguous)(THCState *state, THCTensor *tensor); +TORCH_CUDA_API THCTensor *THCTensor_(newSelect)(THCState *state, THCTensor *tensor, int dimension_, int64_t sliceIndex_); +TORCH_CUDA_API THCTensor *THCTensor_(newNarrow)(THCState *state, THCTensor *tensor, int dimension_, int64_t firstIndex_, int64_t size_); +TORCH_CUDA_API THCTensor *THCTensor_(newTranspose)(THCState *state, THCTensor *tensor, int dimension1_, int dimension2_); +TORCH_CUDA_API THCTensor *THCTensor_(newFoldBatchDim)(THCState *state, THCTensor *input); // resize* methods simply resize the storage. So they may not retain the current data at current indices. // This is especially likely to happen when the tensor is not contiguous. In general, if you still need the // values, unless you are doing some size and stride tricks, do not use resize*. -THC_API void THCTensor_(resizeNd)(THCState *state, THCTensor *tensor, int nDimension, const int64_t *size, const int64_t *stride); -THC_API void THCTensor_(resizeAs)(THCState *state, THCTensor *tensor, THCTensor *src); -THC_API void THCTensor_(resize0d)(THCState *state, THCTensor *tensor); -THC_API void THCTensor_(resize1d)(THCState *state, THCTensor *tensor, int64_t size0_); -THC_API void THCTensor_(resize2d)(THCState *state, THCTensor *tensor, int64_t size0_, int64_t size1_); -THC_API void THCTensor_(resize3d)(THCState *state, THCTensor *tensor, int64_t size0_, int64_t size1_, int64_t size2_); -THC_API void THCTensor_(resize4d)(THCState *state, THCTensor *tensor, int64_t size0_, int64_t size1_, int64_t size2_, int64_t size3_); -THC_API void THCTensor_(resize5d)(THCState *state, THCTensor *tensor, int64_t size0_, int64_t size1_, int64_t size2_, int64_t size3_, int64_t size4_); +TORCH_CUDA_API void THCTensor_(resizeNd)(THCState *state, THCTensor *tensor, int nDimension, const int64_t *size, const int64_t *stride); +TORCH_CUDA_API void THCTensor_(resizeAs)(THCState *state, THCTensor *tensor, THCTensor *src); +TORCH_CUDA_API void THCTensor_(resize0d)(THCState *state, THCTensor *tensor); +TORCH_CUDA_API void THCTensor_(resize1d)(THCState *state, THCTensor *tensor, int64_t size0_); +TORCH_CUDA_API void THCTensor_(resize2d)(THCState *state, THCTensor *tensor, int64_t size0_, int64_t size1_); +TORCH_CUDA_API void THCTensor_(resize3d)(THCState *state, THCTensor *tensor, int64_t size0_, int64_t size1_, int64_t size2_); +TORCH_CUDA_API void THCTensor_(resize4d)(THCState *state, THCTensor *tensor, int64_t size0_, int64_t size1_, int64_t size2_, int64_t size3_); +TORCH_CUDA_API void THCTensor_(resize5d)(THCState *state, THCTensor *tensor, int64_t size0_, int64_t size1_, int64_t size2_, int64_t size3_, int64_t size4_); -THC_API void THCTensor_(set)(THCState *state, THCTensor *self, THCTensor *src); +TORCH_CUDA_API void THCTensor_(set)(THCState *state, THCTensor *self, THCTensor *src); -THC_API void THCTensor_(narrow)(THCState *state, THCTensor *self, THCTensor *src, int dimension_, int64_t firstIndex_, int64_t size_); -THC_API void THCTensor_(select)(THCState *state, THCTensor *self, THCTensor *src, int dimension_, int64_t sliceIndex_); -THC_API void THCTensor_(transpose)(THCState *state, THCTensor *self, THCTensor *src, int dimension1_, int dimension2_); +TORCH_CUDA_API void THCTensor_(narrow)(THCState *state, THCTensor *self, THCTensor *src, int dimension_, int64_t firstIndex_, int64_t size_); +TORCH_CUDA_API void THCTensor_(select)(THCState *state, THCTensor *self, THCTensor *src, int dimension_, int64_t sliceIndex_); +TORCH_CUDA_API void THCTensor_(transpose)(THCState *state, THCTensor *self, THCTensor *src, int dimension1_, int dimension2_); -THC_API void THCTensor_(squeeze1d)(THCState *state, THCTensor *self, THCTensor *src, int dimension_); -THC_API void THCTensor_(unsqueeze1d)(THCState *state, THCTensor *self, THCTensor *src, int dimension_); +TORCH_CUDA_API void THCTensor_(squeeze1d)(THCState *state, THCTensor *self, THCTensor *src, int dimension_); +TORCH_CUDA_API void THCTensor_(unsqueeze1d)(THCState *state, THCTensor *self, THCTensor *src, int dimension_); -THC_API int THCTensor_(isContiguous)(THCState *state, const THCTensor *self); -THC_API int THCTensor_(isSameSizeAs)(THCState *state, const THCTensor *self, const THCTensor *src); -THC_API ptrdiff_t THCTensor_(nElement)(THCState *state, const THCTensor *self); +TORCH_CUDA_API int THCTensor_(isContiguous)(THCState *state, const THCTensor *self); +TORCH_CUDA_API int THCTensor_(isSameSizeAs)(THCState *state, const THCTensor *self, const THCTensor *src); +TORCH_CUDA_API ptrdiff_t THCTensor_(nElement)(THCState *state, const THCTensor *self); -THC_API void THCTensor_(retain)(THCState *state, THCTensor *self); -THC_API void THCTensor_(free)(THCState *state, THCTensor *self); -THC_API void THCTensor_(freeCopyTo)(THCState *state, THCTensor *self, THCTensor *dst); +TORCH_CUDA_API void THCTensor_(retain)(THCState *state, THCTensor *self); +TORCH_CUDA_API void THCTensor_(free)(THCState *state, THCTensor *self); +TORCH_CUDA_API void THCTensor_(freeCopyTo)(THCState *state, THCTensor *self, THCTensor *dst); /* Slow access methods [check everything] */ -THC_API void THCTensor_(set0d)(THCState *state, THCTensor *tensor, scalar_t value); -THC_API void THCTensor_(set1d)(THCState *state, THCTensor *tensor, int64_t x0, scalar_t value); -THC_API void THCTensor_(set2d)(THCState *state, THCTensor *tensor, int64_t x0, int64_t x1, scalar_t value); -THC_API void THCTensor_(set3d)(THCState *state, THCTensor *tensor, int64_t x0, int64_t x1, int64_t x2, scalar_t value); -THC_API void THCTensor_(set4d)(THCState *state, THCTensor *tensor, int64_t x0, int64_t x1, int64_t x2, int64_t x3, scalar_t value); - -THC_API scalar_t THCTensor_(get0d)(THCState *state, const THCTensor *tensor); -THC_API scalar_t THCTensor_(get1d)(THCState *state, const THCTensor *tensor, int64_t x0); -THC_API scalar_t THCTensor_(get2d)(THCState *state, const THCTensor *tensor, int64_t x0, int64_t x1); -THC_API scalar_t THCTensor_(get3d)(THCState *state, const THCTensor *tensor, int64_t x0, int64_t x1, int64_t x2); -THC_API scalar_t THCTensor_(get4d)(THCState *state, const THCTensor *tensor, int64_t x0, int64_t x1, int64_t x2, int64_t x3); +TORCH_CUDA_API void THCTensor_(set0d)(THCState *state, THCTensor *tensor, scalar_t value); +TORCH_CUDA_API void THCTensor_(set1d)(THCState *state, THCTensor *tensor, int64_t x0, scalar_t value); +TORCH_CUDA_API void THCTensor_(set2d)(THCState *state, THCTensor *tensor, int64_t x0, int64_t x1, scalar_t value); +TORCH_CUDA_API void THCTensor_(set3d)(THCState *state, THCTensor *tensor, int64_t x0, int64_t x1, int64_t x2, scalar_t value); +TORCH_CUDA_API void THCTensor_(set4d)(THCState *state, THCTensor *tensor, int64_t x0, int64_t x1, int64_t x2, int64_t x3, scalar_t value); + +TORCH_CUDA_API scalar_t THCTensor_(get0d)(THCState *state, const THCTensor *tensor); +TORCH_CUDA_API scalar_t THCTensor_(get1d)(THCState *state, const THCTensor *tensor, int64_t x0); +TORCH_CUDA_API scalar_t THCTensor_(get2d)(THCState *state, const THCTensor *tensor, int64_t x0, int64_t x1); +TORCH_CUDA_API scalar_t THCTensor_(get3d)(THCState *state, const THCTensor *tensor, int64_t x0, int64_t x1, int64_t x2); +TORCH_CUDA_API scalar_t THCTensor_(get4d)(THCState *state, const THCTensor *tensor, int64_t x0, int64_t x1, int64_t x2, int64_t x3); /* CUDA-specific functions */ -THC_API int THCTensor_(getDevice)(THCState *state, const THCTensor *self); -THC_API int THCTensor_(checkGPU)(THCState *state, unsigned int nTensors, ...); +TORCH_CUDA_API int THCTensor_(getDevice)(THCState *state, const THCTensor *self); +TORCH_CUDA_API int THCTensor_(checkGPU)(THCState *state, unsigned int nTensors, ...); /* debug methods */ -THC_API THCDescBuff THCTensor_(sizeDesc)(THCState *state, const THCTensor *tensor); +TORCH_CUDA_API THCDescBuff THCTensor_(sizeDesc)(THCState *state, const THCTensor *tensor); #endif diff --git a/aten/src/THC/generic/THCTensor.hpp b/aten/src/THC/generic/THCTensor.hpp index 05d3eea7f79ac..a23d8df5324bf 100644 --- a/aten/src/THC/generic/THCTensor.hpp +++ b/aten/src/THC/generic/THCTensor.hpp @@ -8,9 +8,9 @@ // NOTE: functions exist here only to support dispatch via Declarations.cwrap. You probably don't want to put // new functions in here, they should probably be un-genericized. -THC_API void THCTensor_(setStorage)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, +TORCH_CUDA_API void THCTensor_(setStorage)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, at::IntArrayRef size_, at::IntArrayRef stride_); -THC_API void THCTensor_(resize)(THCState *state, THCTensor *self, at::IntArrayRef size, at::IntArrayRef stride); +TORCH_CUDA_API void THCTensor_(resize)(THCState *state, THCTensor *self, at::IntArrayRef size, at::IntArrayRef stride); #endif diff --git a/aten/src/THC/generic/THCTensorCopy.h b/aten/src/THC/generic/THCTensorCopy.h index 896b2635d7282..97642360f55c3 100644 --- a/aten/src/THC/generic/THCTensorCopy.h +++ b/aten/src/THC/generic/THCTensorCopy.h @@ -2,10 +2,10 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorCopy.h" #else -THC_API void THCTensor_(copy)(THCState *state, THCTensor *self, THCTensor *src); -THC_API void THCTensor_(copyIgnoringOverlaps)(THCState *state, THCTensor *self, THCTensor *src); +TORCH_CUDA_API void THCTensor_(copy)(THCState *state, THCTensor *self, THCTensor *src); +TORCH_CUDA_API void THCTensor_(copyIgnoringOverlaps)(THCState *state, THCTensor *self, THCTensor *src); -THC_API void THCTensor_(copyAsyncCPU)(THCState *state, THCTensor *self, THTensor *src); -THC_API void THTensor_(copyAsyncCuda)(THCState *state, THTensor *self, THCTensor *src); +TORCH_CUDA_API void THCTensor_(copyAsyncCPU)(THCState *state, THCTensor *self, THTensor *src); +TORCH_CUDA_API void THTensor_(copyAsyncCuda)(THCState *state, THTensor *self, THCTensor *src); #endif diff --git a/aten/src/THC/generic/THCTensorIndex.cu b/aten/src/THC/generic/THCTensorIndex.cu index a6c621c8ef15d..3f506d3457141 100644 --- a/aten/src/THC/generic/THCTensorIndex.cu +++ b/aten/src/THC/generic/THCTensorIndex.cu @@ -3,6 +3,8 @@ #else #include +#include +#include // Check tensor dimensions for index operations, and return the slice size. // src can be nullptr in case of indexFill: in that case it is ignored. @@ -126,11 +128,12 @@ void THCTensor_(indexCopy)(THCState *state, THCTensor *dst, int dim, THCudaLongT int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; -#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \ - indexCopySmallIndex \ - <<>>( \ - dstInfo, srcInfo, indicesInfo, \ - dstCopyDim, srcCopyDim, sliceSize, dstCopyDimSize); +#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \ + indexCopySmallIndex \ + <<>>( \ + dstInfo, srcInfo, indicesInfo, \ + dstCopyDim, srcCopyDim, sliceSize, dstCopyDimSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); #define LARGE_INDEX(TENSOR_TYPE, TYPE, \ DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR) \ @@ -140,7 +143,8 @@ void THCTensor_(indexCopy)(THCState *state, THCTensor *dst, int dim, THCudaLongT dstInfo, srcInfo, indicesInfo, \ dstCopyDim, srcCopyDim, srcTotalSize, \ (IDX_IS_MAJOR) ? sliceSize : numIndices, \ - dstCopyDimSize); + dstCopyDimSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); dim3 smallIndexGrid(std::min(THCCeilDiv(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128)); @@ -220,21 +224,6 @@ void THCTensor_(indexCopy)(THCState *state, THCTensor *dst, int dim, THCudaLongT #undef LARGE_INDEX } -void THCTensor_(take)(THCState *state, THCTensor *dst, THCTensor *src, THCudaLongTensor *index) -{ - THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, dst, src)); - THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index)); - - THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, src) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING); - THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, dst) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING); - THArgCheck(THCudaLongTensor_nDimensionLegacyNoScalars(state, index) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING); - THArgCheck(!(THCTensor_(numel)(state, src) == 0 && THCudaLongTensor_numel(state, index) != 0), 2, - "tried to take from an empty tensor"); - - THCTensor_(resizeNd)(state, dst, index->dim(), THTensor_getSizePtr(index), NULL); - dispatchTakePut(state, src, dst, index); -} - static void THCTensor_(sort_indices)(THCState *state, THCudaLongTensor *index, THCTensor *src) { THCThrustAllocator thrustAlloc(state); @@ -294,6 +283,13 @@ void THCTensor_(indexFill)(THCState *state, THCTensor *dst, int dim, THCudaLongT THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING); dims = THCudaLongTensor_nDimensionLegacyNoScalars(state, indices); THArgCheck(dims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING); + at::assert_no_overlap(dst, indices); + if (at::has_internal_overlap(dst) == at::MemOverlap::YES) { + TORCH_WARN( + "Use of index_fill_ on expanded tensors is deprecated. " + "Please clone() the tensor before performing this operation. " + "This also applies to advanced indexing e.g. tensor[mask] = scalar"); + } // The `src` is partitioned into two parts: // -the size of each slice we are indexing, which is the @@ -314,11 +310,12 @@ void THCTensor_(indexFill)(THCState *state, THCTensor *dst, int dim, THCudaLongT int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; -#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, IDX_DIM) \ +#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, IDX_DIM) \ indexFillSmallIndex \ - <<>>( \ - dstInfo, indicesInfo, \ - dstFillDim, sliceSize, dstFillDimSize, val); + <<>>( \ + dstInfo, indicesInfo, \ + dstFillDim, sliceSize, dstFillDimSize, val); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); #define LARGE_INDEX(TENSOR_TYPE, TYPE, DST_DIM, IDX_DIM, IDX_IS_MAJOR) \ indexFillLargeIndex \ @@ -326,7 +323,8 @@ void THCTensor_(indexFill)(THCState *state, THCTensor *dst, int dim, THCudaLongT dstInfo, indicesInfo, \ dstFillDim, sliceSize * numIndices, \ (IDX_IS_MAJOR) ? sliceSize : numIndices, \ - dstFillDimSize, val); + dstFillDimSize, val); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); dim3 smallIndexGrid(std::min(THCCeilDiv(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8))); dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128)); diff --git a/aten/src/THC/generic/THCTensorIndex.h b/aten/src/THC/generic/THCTensorIndex.h index 2e40a0ff99873..95490fbf084e6 100644 --- a/aten/src/THC/generic/THCTensorIndex.h +++ b/aten/src/THC/generic/THCTensorIndex.h @@ -2,10 +2,10 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorIndex.h" #else -THC_API void THCTensor_(indexCopy)(THCState *state, THCTensor *res_, int dim, THCudaLongTensor *indices, THCTensor *src); -THC_API void THCTensor_(indexFill)(THCState *state, THCTensor *tensor, int dim, THCudaLongTensor *index, scalar_t val); -THC_API void THCTensor_(indexSelect)(THCState *state, THCTensor *tensor, THCTensor *src, int dim, THCudaLongTensor *index); -THC_API void THCTensor_(take)(THCState *state, THCTensor *res_, THCTensor *src, THCudaLongTensor *index); -THC_API void THCTensor_(put)(THCState *state, THCTensor *res_, THCudaLongTensor *indices, THCTensor *src, int accumulate); +TORCH_CUDA_API void THCTensor_(indexCopy)(THCState *state, THCTensor *res_, int dim, THCudaLongTensor *indices, THCTensor *src); +TORCH_CUDA_API void THCTensor_(indexFill)(THCState *state, THCTensor *tensor, int dim, THCudaLongTensor *index, scalar_t val); +TORCH_CUDA_API void THCTensor_(indexSelect)(THCState *state, THCTensor *tensor, THCTensor *src, int dim, THCudaLongTensor *index); +TORCH_CUDA_API void THCTensor_(take)(THCState *state, THCTensor *res_, THCTensor *src, THCudaLongTensor *index); +TORCH_CUDA_API void THCTensor_(put)(THCState *state, THCTensor *res_, THCudaLongTensor *indices, THCTensor *src, int accumulate); #endif diff --git a/aten/src/THC/generic/THCTensorMasked.cu b/aten/src/THC/generic/THCTensorMasked.cu index 344b0c35ef6e8..4e93ac260e420 100644 --- a/aten/src/THC/generic/THCTensorMasked.cu +++ b/aten/src/THC/generic/THCTensorMasked.cu @@ -60,7 +60,7 @@ void THCTensor_(maskedCopy)(THCState* state, "mask and tensor must have the same number of elements"); // Determine our output size - ptrdiff_t totalElements = THTensor_wrap(mask).sum().item(); + int64_t totalElements = THTensor_wrap(mask).sum().item(); // The number of `1` elements present in the mask must be <= the // number of elements available in `src` @@ -126,7 +126,7 @@ void THCTensor_(maskedCopyBool)(THCState* state, "mask and tensor must have the same number of elements"); // Determine our output size - ptrdiff_t totalElements = THTensor_wrap(mask).sum().item(); + int64_t totalElements = THTensor_wrap(mask).sum().item(); // The number of `1` elements present in the mask must be <= the // number of elements available in `src` diff --git a/aten/src/THC/generic/THCTensorMasked.h b/aten/src/THC/generic/THCTensorMasked.h index ece7c60246244..f834e3a31dea3 100644 --- a/aten/src/THC/generic/THCTensorMasked.h +++ b/aten/src/THC/generic/THCTensorMasked.h @@ -2,35 +2,35 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorMasked.h" #else -THC_API void THCTensor_(maskedFill)(THCState *state, +TORCH_CUDA_API void THCTensor_(maskedFill)(THCState *state, THCTensor *tensor, THCudaByteTensor *mask, scalar_t value); -THC_API void THCTensor_(maskedFillBool)(THCState *state, +TORCH_CUDA_API void THCTensor_(maskedFillBool)(THCState *state, THCTensor *tensor, THCudaBoolTensor *mask, scalar_t value); // FIXME: remove now that we have THCudaByteTensor? -THC_API void THCTensor_(maskedFillByte)(THCState *state, +TORCH_CUDA_API void THCTensor_(maskedFillByte)(THCState *state, THCTensor *tensor, THByteTensor *mask, scalar_t value); -THC_API void THCTensor_(maskedCopy)(THCState *state, +TORCH_CUDA_API void THCTensor_(maskedCopy)(THCState *state, THCTensor *tensor, THCudaByteTensor *mask, THCTensor *src); -THC_API void THCTensor_(maskedCopyBool)(THCState *state, +TORCH_CUDA_API void THCTensor_(maskedCopyBool)(THCState *state, THCTensor *tensor, THCudaBoolTensor *mask, THCTensor *src); // FIXME: remove now that we have THCudaByteTensor? -THC_API void THCTensor_(maskedCopyByte)(THCState *state, +TORCH_CUDA_API void THCTensor_(maskedCopyByte)(THCState *state, THCTensor *tensor, THByteTensor *mask, THCTensor *src); diff --git a/aten/src/THC/generic/THCTensorMath.h b/aten/src/THC/generic/THCTensorMath.h index bdf5dc4763c81..47f6434ce5f27 100644 --- a/aten/src/THC/generic/THCTensorMath.h +++ b/aten/src/THC/generic/THCTensorMath.h @@ -2,9 +2,9 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorMath.h" #else -THC_API void THCTensor_(fill)(THCState *state, THCTensor *self, scalar_t value); -THC_API void THCTensor_(zero)(THCState *state, THCTensor *self); -THC_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t); +TORCH_CUDA_API void THCTensor_(fill)(THCState *state, THCTensor *self, scalar_t value); +TORCH_CUDA_API void THCTensor_(zero)(THCState *state, THCTensor *self); +TORCH_CUDA_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t); #endif diff --git a/aten/src/THC/generic/THCTensorMathBlas.cu b/aten/src/THC/generic/THCTensorMathBlas.cu deleted file mode 100644 index a5d159a9cace7..0000000000000 --- a/aten/src/THC/generic/THCTensorMathBlas.cu +++ /dev/null @@ -1,326 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THC/generic/THCTensorMathBlas.cu" -#else - -#include -#include - -#define ERROR_ONLY_FP_TYPES(func) \ - THError("%s for CUDA tensors only supports floating-point types. Try converting the tensors with .float()", func); - -__global__ void createBatchGemmBuffer3(const scalar_t** buffer1, const scalar_t ** buffer2, const scalar_t ** buffer3, scalar_t* data1, - scalar_t * data2, scalar_t * data3, int64_t stride1, int64_t stride2, int64_t stride3, int64_t num_batches) { - const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < num_batches) { - buffer1[idx] = data1 + idx * stride1; - buffer2[idx] = data2 + idx * stride2; - buffer3[idx] = data3 + idx * stride3; - } -} - -void THCTensor_(baddbmm)(THCState *state, THCTensor *result, THCTensor *t, - THCTensor *batch1, THCTensor *batch2, - scalar_t beta, scalar_t alpha) { -#if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_BFLOAT16) - THCAssertSameGPU(THCTensor_(checkGPU)(state, 4, result, t, batch1, batch2)); - THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, t) == 3, 4, "expected 3D tensor"); - THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, batch1) == 3, 6, "expected 3D tensor"); - THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, batch2) == 3, 7, "expected 3D tensor"); - THArgCheck(THCTensor_(size)(state, t, 0) == THCTensor_(size)(state, batch1, 0), 6, - "equal number of batches expected"); - THArgCheck(THCTensor_(size)(state, t, 0) == THCTensor_(size)(state, batch2, 0), 7, - "equal number of batches expected"); - auto maybe_outnames = at::namedinference::compute_baddbmm_outnames(result, batch1, batch2, t); - { - at::NoNamesGuard guard; - THArgCheck(THCTensor_(size)(state, t, 1) == THCTensor_(size)(state, batch1, 1), 6, - "wrong matrix size"); - THArgCheck(THCTensor_(size)(state, t, 2) == THCTensor_(size)(state, batch2, 2), 7, - "wrong matrix size"); - THArgCheck(THCTensor_(size)(state, batch1, 2) == THCTensor_(size)(state, batch2, 1), 6, - "wrong matrix size"); - - if (t != result) { - THCTensor_(resizeAs)(state, result, t); - if (ScalarConvert::to(beta) != 0.0) { - THCTensor_(copy)(state, result, t); - } - } - - bool transpose_result; - char transpose_batch1, transpose_batch2; - int64_t lda, ldb, ldc; - THCTensor *result_, *batch1_, *batch2_; - if (result->stride(1) == 1 && - (result->size(2) == 1 || result->stride(2) >= std::max(1, result->size(1)))) - { - transpose_result = false; - result_ = result; - ldc = result_->stride(2); - } - else if (result->stride(2) == 1 && - (result->size(1) == 1 || result->stride(1) >= std::max(1, result->size(2)))) - { - transpose_result = true; - - THCTensor *swap = batch2; - batch2 = batch1; - batch1 = swap; - - result_ = result; - ldc = result_->stride(1); - } - else - { - transpose_result = false; - - THCTensor *transp_r_ = THCTensor_(newTranspose)(state, result, 1, 2); - result_ = THCTensor_(newClone)(state, transp_r_); - THCTensor_(free)(state, transp_r_); - THCTensor_(transpose)(state, result_, NULL, 1, 2); - - ldc = result_->stride(2); - } - - const int64_t m = result->size(transpose_result ? 2 : 1); - const int64_t n = result->size(transpose_result ? 1 : 2); - const int64_t k = batch1->size(transpose_result ? 1 : 2); - - if (batch1->stride(transpose_result ? 2 : 1) == 1 && - batch1->stride(transpose_result ? 1 : 2) >= std::max(1, m)) - { - transpose_batch1 = 'n'; - batch1_ = batch1; - lda = batch1_->stride(transpose_result ? 1 : 2); - } - else if (batch1->stride(transpose_result ? 1 : 2) == 1 && - batch1->stride(transpose_result ? 2 : 1) >= std::max(1, k)) - { - transpose_batch1 = 't'; - batch1_ = batch1; - lda = batch1_->stride(transpose_result ? 2 : 1); - } - else - { - transpose_batch1 = transpose_result ? 'n' : 't'; - // batch1_ is later freed if batch1_ != batch1 - if (THCTensor_(isContiguous)(state, batch1)) { - batch1_ = batch1; - } else { - batch1_ = THCTensor_(newContiguous)(state, batch1); - } - lda = batch1_->stride(1); - } - - if (batch2->stride(transpose_result ? 2 : 1) == 1 && - batch2->stride(transpose_result ? 1 : 2) >= std::max(1, k)) - { - transpose_batch2 = 'n'; - batch2_ = batch2; - ldb = batch2_->stride(transpose_result ? 1 : 2); - } - else if (batch2->stride(transpose_result ? 1 : 2) == 1 && - batch2->stride(transpose_result ? 2 : 1) >= std::max(1, n)) - { - transpose_batch2 = 't'; - batch2_ = batch2; - ldb = batch2_->stride(transpose_result ? 2 : 1); - } - else - { - transpose_batch2 = transpose_result ? 'n' : 't'; - // batch2_ is later freed if batch2_ != batch2 - if (THCTensor_(isContiguous)(state, batch2)) { - batch2_ = batch2; - } else { - batch2_ = THCTensor_(newContiguous)(state, batch2); - } - ldb = batch2_->stride(1); - } - int64_t num_batches = result_->size(0); - -#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) - // Compute pointers to matrices in each batch. -#if CUDA_VERSION < 8000 && !defined __HIP_PLATFORM_HCC__ - size_t matrices_size = num_batches * sizeof(scalar_t*); - -// Copy pointers to device. - auto d_matrices1 = static_cast(THCudaMalloc(state, matrices_size)); - auto d_matrices2 = static_cast(THCudaMalloc(state, matrices_size)); - auto d_result_matrices = static_cast(THCudaMalloc(state, matrices_size)); - - const int64_t block = 512; - const int64_t grid = (num_batches + block - 1) / block; - - createBatchGemmBuffer3<<>>( - d_matrices1, d_matrices2, (const scalar_t**)d_result_matrices, THCTensor_(data)(state, batch1_), - THCTensor_(data)(state, batch2_), THCTensor_(data)(state, result_), - batch1_->stride(0), batch2_->stride(0), result_->stride(0), num_batches); - -#ifdef THC_REAL_IS_FLOAT - THCudaBlas_SgemmBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - d_matrices1, lda, - d_matrices2, ldb, - beta, - d_result_matrices, ldc, - num_batches); -#elif defined(THC_REAL_IS_DOUBLE) - THCudaBlas_DgemmBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - d_matrices1, lda, - d_matrices2, ldb, - beta, - d_result_matrices, ldc, - num_batches); -#endif //THC_REAL - - THCudaFree(state, d_matrices1); - THCudaFree(state, d_matrices2); - THCudaFree(state, d_result_matrices); - -#else -#ifdef THC_REAL_IS_FLOAT - THCudaBlas_SgemmStridedBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - THCTensor_(data)(state, batch1_), lda, batch1_->stride(0), - THCTensor_(data)(state, batch2_), ldb, batch2_->stride(0), - beta, - THCTensor_(data)(state, result_), ldc, result_->stride(0), - num_batches); -#elif defined(THC_REAL_IS_DOUBLE) - THCudaBlas_DgemmStridedBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - THCTensor_(data)(state, batch1_), lda, batch1_->stride(0), - THCTensor_(data)(state, batch2_), ldb, batch2_->stride(0), - beta, - THCTensor_(data)(state, result_), ldc, result_->stride(0), - num_batches); -#endif //THC_REAL -#endif //CUDA_VERSION - -#elif defined(THC_REAL_IS_HALF) - -#if CUDA_VERSION < 9010 && !defined(__HIP_PLATFORM_HCC__) - // Currently no HgemmBatched in Cublas - for (int64_t i = 0; i < num_batches; ++i) { - THCudaBlas_Hgemm( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - THCTensor_(data)(state, batch1_) + i * batch1_->stride(0), lda, - THCTensor_(data)(state, batch2_) + i * batch2_->stride(0), ldb, - beta, - THCTensor_(data)(state, result_) + i * result_->stride(0), ldc); - } -#else -#ifndef __HIP_PLATFORM_HCC__ - cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); - if (prop->major >= 5){ -#endif - - THCudaBlas_HgemmStridedBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - THCTensor_(data)(state, batch1_), lda, batch1_->stride(0), - THCTensor_(data)(state, batch2_), ldb, batch2_->stride(0), - beta, - THCTensor_(data)(state, result_), ldc, result_->stride(0), - num_batches); -#ifndef __HIP_PLATFORM_HCC__ - } else { - for (int64_t i = 0; i < num_batches; ++i) { - THCudaBlas_Hgemm( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - THCTensor_(data)(state, batch1_) + i * batch1_->stride(0), lda, - THCTensor_(data)(state, batch2_) + i * batch2_->stride(0), ldb, - beta, - THCTensor_(data)(state, result_) + i * result_->stride(0), ldc); - } - } -#endif -#endif //CUDA_VERSION - -#elif defined(THC_REAL_IS_BFLOAT16) -#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000 - THCudaBlas_BgemmStridedBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - THCTensor_(data)(state, batch1_), lda, batch1_->stride(0), - THCTensor_(data)(state, batch2_), ldb, batch2_->stride(0), - beta, - THCTensor_(data)(state, result_), ldc, result_->stride(0), - num_batches); -#endif // __HIP_PLATFORM_HCC__ -#endif - - if (batch1_ != batch1) { - THCTensor_(free)(state, batch1_); - } - - if (batch2_ != batch2) { - THCTensor_(free)(state, batch2_); - } - - if (result_ != result) { - THCTensor_(freeCopyTo)(state, result_, result); - } - -#if defined(THC_REAL_IS_BFLOAT16) && !(defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000) - // To avoid "variable was set but never used" warning - [&transpose_batch1, &transpose_batch2, &lda, &ldb, &ldc]{}(); - TORCH_CHECK(false, "BgemmStridedBatched is not supported with at::BFloat16 type"); -#endif - } - at::namedinference::propagate_names_if_nonempty(result, maybe_outnames); - -#else - ERROR_ONLY_FP_TYPES("baddbmm"); -#endif -} - -#endif diff --git a/aten/src/THC/generic/THCTensorMathBlas.h b/aten/src/THC/generic/THCTensorMathBlas.h deleted file mode 100644 index e15baafaca641..0000000000000 --- a/aten/src/THC/generic/THCTensorMathBlas.h +++ /dev/null @@ -1,7 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THC/generic/THCTensorMathBlas.h" -#else - -THC_API void THCTensor_(baddbmm)(THCState *state, THCTensor *result, THCTensor *t, THCTensor *batch1, THCTensor *batch2, scalar_t beta, scalar_t alpha); - -#endif diff --git a/aten/src/THC/generic/THCTensorMathMagma.cu b/aten/src/THC/generic/THCTensorMathMagma.cu index d0dfae33c4d09..216a964438875 100644 --- a/aten/src/THC/generic/THCTensorMathMagma.cu +++ b/aten/src/THC/generic/THCTensorMathMagma.cu @@ -2,6 +2,8 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorMathMagma.cu" #else +#include + #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) #ifdef USE_MAGMA @@ -115,85 +117,6 @@ void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor #endif } -void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, bool eigenvectors) -{ -#ifdef USE_MAGMA - char jobvrs = eigenvectors ? 'V' : 'N'; - THArgCheck(a_->dim() == 2, 3, "A should be 2 dimensional"); - THArgCheck(a_->size(0) == a_->size(1), 3, "A should be square"); - - magma_vec_t jobvr = jobvrs == 'N' ? MagmaNoVec : MagmaVec; - int64_t n = a_->size(0); - - scalar_t *a_data = th_magma_malloc_pinned(n * n); - THCTensor_(copyTensor2d)(state, a_data, a_); - - scalar_t *wr = th_magma_malloc_pinned(n); - scalar_t *wi = th_magma_malloc_pinned(n); - - scalar_t *vr_data = NULL; - int64_t ldvr = 1; - if (jobvr == MagmaVec) - { - vr_data = th_magma_malloc_pinned(n * n); - ldvr = n; - } - - scalar_t *work_data = nullptr; - - if (n > 0) { - int info; - scalar_t wkopt; - at::native::MagmaStreamSyncGuard guard; - -#if defined(THC_REAL_IS_FLOAT) - magma_sgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, &wkopt, -1, &info); -#else - magma_dgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, &wkopt, -1, &info); -#endif - - int lwork = (int) wkopt; - work_data = th_magma_malloc_pinned(lwork); - -#if defined(THC_REAL_IS_FLOAT) - magma_sgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, work_data, lwork, &info); -#else - magma_dgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, work_data, lwork, &info); -#endif - - if (info > 0) - THError("MAGMA geev : Failed to converge. %d off-diagonal elements of an didn't converge to zero", info); - else if (info < 0) - THError("MAGMA geev : Argument %d : illegal value", -info); - } - - { - THCTensor_(resize2d)(state, re_, 2, n); - THCTensor *re = THCTensor_(newContiguous)(state, re_); - if (n > 0) { - auto stream = c10::cuda::getCurrentCUDAStream(); - THCudaCheck(cudaMemcpyAsync(THCStorage_(data)(state, THTensor_getStoragePtr(re)) + re->storage_offset(), wr, n*sizeof(scalar_t), cudaMemcpyHostToDevice, stream)); - THCudaCheck(cudaMemcpyAsync(THCStorage_(data)(state, THTensor_getStoragePtr(re)) + re->storage_offset() + n, wi, n*sizeof(scalar_t), cudaMemcpyHostToDevice, stream)); - AT_CUDA_CHECK(cudaStreamSynchronize(stream)); - } - THCTensor_(freeCopyTo)(state, re, re_); - THCTensor_(transpose)(state, re_, NULL, 0, 1); - } - - if (jobvr == MagmaVec) - THCTensor_(copyArray2d)(state, rv_, vr_data, n, n); - - magma_free_pinned(work_data); - magma_free_pinned(vr_data); - magma_free_pinned(wi); - magma_free_pinned(wr); - magma_free_pinned(a_data); - -#else - THError(NoMagma(geev)); -#endif -} - __global__ void THCTensor_(copyUpperSymmetric)(scalar_t *input, int n, int len) { for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < len; idx += 65535) { @@ -250,8 +173,10 @@ void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, bool upper dim3 threads(128); if (uplo == 'U') { THCTensor_(copyUpperSymmetric)<<>>(input_data, n, len); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { THCTensor_(copyLowerSymmetric)<<>>(input_data, n, len); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } THCTensor_(freeCopyTo)(state, input, ra_); diff --git a/aten/src/THC/generic/THCTensorMathMagma.h b/aten/src/THC/generic/THCTensorMathMagma.h index ae46a62c9ec61..48c2f54f26d3d 100644 --- a/aten/src/THC/generic/THCTensorMathMagma.h +++ b/aten/src/THC/generic/THCTensorMathMagma.h @@ -5,10 +5,9 @@ #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) // MAGMA (i.e. CUDA implementation of LAPACK functions) -THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_); -THC_API void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, bool eigenvectors); -THC_API void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, bool upper); -THC_API void THCTensor_(geqrf)(THCState *state, THCTensor *ra_, THCTensor *rtau_, THCTensor *a_); +TORCH_CUDA_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_); +TORCH_CUDA_API void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, bool upper); +TORCH_CUDA_API void THCTensor_(geqrf)(THCState *state, THCTensor *ra_, THCTensor *rtau_, THCTensor *a_); #endif // defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) diff --git a/aten/src/THC/generic/THCTensorMathPairwise.h b/aten/src/THC/generic/THCTensorMathPairwise.h index 7bae0a54fa41f..6a3f96e440416 100644 --- a/aten/src/THC/generic/THCTensorMathPairwise.h +++ b/aten/src/THC/generic/THCTensorMathPairwise.h @@ -2,12 +2,11 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorMathPairwise.h" #else -THC_API int THCTensor_(equal)(THCState *state, THCTensor *self, THCTensor *src); +TORCH_CUDA_API int THCTensor_(equal)(THCState *state, THCTensor *self, THCTensor *src); #if !defined(THC_REAL_IS_BOOL) -THC_API void THCTensor_(mul)(THCState *state, THCTensor *self, THCTensor *src, scalar_t value); -THC_API void THCTensor_(fmod)(THCState *state, THCTensor *self, THCTensor *src, scalar_t value); +TORCH_CUDA_API void THCTensor_(mul)(THCState *state, THCTensor *self, THCTensor *src, scalar_t value); #endif diff --git a/aten/src/THC/generic/THCTensorMathPointwise.cu b/aten/src/THC/generic/THCTensorMathPointwise.cu index 54fe16bc85c0e..f7857cddc4978 100644 --- a/aten/src/THC/generic/THCTensorMathPointwise.cu +++ b/aten/src/THC/generic/THCTensorMathPointwise.cu @@ -11,47 +11,6 @@ static void propagate_names_if_named_tensor_enabled(THCTensor* result, THCTensor at::namedinference::propagate_names(result, src); } -#define IMPLEMENT_CUDA_TENSOR_BASIC_FUNC_(NAME, CFUNC, REAL) \ - struct Tensor_##NAME##_##REAL##_Op { \ - __device__ __forceinline__ void operator()(scalar_t* out, scalar_t* in) const { \ - *out = CFUNC(*in); \ - } \ - \ - __device__ __forceinline__ void operator()(scalar_t* v) const { \ - *v = CFUNC(*v); \ - } \ - }; \ - \ - void THCTensor_(NAME)(THCState* state, THCTensor* self_, THCTensor* src) { \ - THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); \ - at::assert_no_internal_overlap(self_); \ - if (self_ == src) { \ - if (!THC_pointwiseApply1(state, self_, Tensor_##NAME##_##REAL##_Op())) { \ - THArgCheck(false, 2, CUTORCH_DIM_WARNING); \ - } \ - } else { \ - THCTensor_(resizeAs)(state, self_, src); \ - \ - if (!THC_pointwiseApply2(state, self_, src, Tensor_##NAME##_##REAL##_Op())) { \ - THArgCheck(false, 2, CUTORCH_DIM_WARNING); \ - } \ - } \ - \ - THCudaCheck(cudaGetLastError()); \ - propagate_names_if_named_tensor_enabled(self_, src); \ - } - -#define IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(NAME, CFUNC, REAL) \ - IMPLEMENT_CUDA_TENSOR_BASIC_FUNC_(NAME, CFUNC, REAL) - -#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) - -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( sqrt, THCNumerics::sqrt, Real) - -#endif -#undef IMPLEMENT_CUDA_TENSOR_BASIC_FUNC_ -#undef IMPLEMENT_CUDA_TENSOR_BASIC_FUNC - void THCTensor_(crossKernel)(THCState *state, THCTensor *self, THCTensor *x, THCTensor *y, int dimension) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self, x, y)); @@ -69,37 +28,5 @@ void THCTensor_(crossKernel)(THCState *state, THCTensor *self, THCTensor *x, THC THCTensor_(free)(state, ny); THCTensor_(free)(state, nself); } - -namespace { -c10::intrusive_ptr retainTensorImpl(THCTensor* self) { - c10::raw::intrusive_ptr::incref(self); - return c10::intrusive_ptr::reclaim(self); -} -} - -void THCTensor_(cmul)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2) -{ - auto out = at::Tensor(retainTensorImpl(self_)); - at::mul_out(out, at::Tensor(retainTensorImpl(src1)), at::Tensor(retainTensorImpl(src2))); -} - -void THCTensor_(cfmod)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2) -{ - THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self, src1, src2)); - THArgCheck(THCTensor_(nElement)(state, src1) == - THCTensor_(nElement)(state, src2), 2, "sizes do not match"); - - if (self == src1) { - if (!THC_pointwiseApply2(state, self, src2, TensorCFmodOp())) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else { - THCTensor_(resizeAs)(state, self, src1); - if (!THC_pointwiseApply3(state, self, src1, src2, TensorCFmodOp())) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } -} - #endif #endif diff --git a/aten/src/THC/generic/THCTensorMathPointwise.h b/aten/src/THC/generic/THCTensorMathPointwise.h index 658f344f8b149..8a9ea1ad78859 100644 --- a/aten/src/THC/generic/THCTensorMathPointwise.h +++ b/aten/src/THC/generic/THCTensorMathPointwise.h @@ -4,23 +4,7 @@ #if !defined(THC_REAL_IS_BOOL) -#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) - -THC_API void THCTensor_(atan)(THCState *state, THCTensor *self, THCTensor *src); -THC_API void THCTensor_(sqrt)(THCState *state, THCTensor *self, THCTensor *src); - -#endif - -THC_API void THCTensor_(clamp)(THCState *state, THCTensor *self, THCTensor *src, scalar_t min_value, scalar_t max_value); -THC_API void THCTensor_(crossKernel)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2, int dimension); - -THC_API void THCTensor_(cadd)(THCState *state, THCTensor *self, THCTensor *src1, scalar_t value, THCTensor *src2); -THC_API void THCTensor_(csub)(THCState *state, THCTensor *self, THCTensor *src1, scalar_t value, THCTensor *src2); -THC_API void THCTensor_(cmul)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2); -THC_API void THCTensor_(cdiv)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2); -THC_API void THCTensor_(clshift)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2); -THC_API void THCTensor_(crshift)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2); -THC_API void THCTensor_(cfmod)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2); +TORCH_CUDA_API void THCTensor_(crossKernel)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2, int dimension); #endif #endif diff --git a/aten/src/THC/generic/THCTensorMathReduce.cu b/aten/src/THC/generic/THCTensorMathReduce.cu index 76f470ce7dfbf..13108491b0ec9 100644 --- a/aten/src/THC/generic/THCTensorMathReduce.cu +++ b/aten/src/THC/generic/THCTensorMathReduce.cu @@ -2,6 +2,8 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorMathReduce.cu" #else +#include + #if !defined(THC_REAL_IS_BOOL) void THCTensor_(prod)(THCState* state, THCTensor *self, THCTensor *src, int dimension, int keepdim) { @@ -41,12 +43,9 @@ void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, scalar dim3 threads(32); THCTensor_kernel_renorm - <<>> - (THCTensor_(data)(state, data), scalar_cast(value), size, scalar_cast(maxnorm)); - - cudaError_t errcode = cudaGetLastError(); - if(errcode != cudaSuccess) - THError(cudaGetErrorString(errcode)); + <<>>(THCTensor_(data)(state, data), + scalar_cast(value), size, scalar_cast(maxnorm)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } THCTensor_(free)(state, src_); diff --git a/aten/src/THC/generic/THCTensorMathReduce.h b/aten/src/THC/generic/THCTensorMathReduce.h index ebb62a64ffb1a..9326cae7c44b3 100644 --- a/aten/src/THC/generic/THCTensorMathReduce.h +++ b/aten/src/THC/generic/THCTensorMathReduce.h @@ -6,16 +6,16 @@ #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) -THC_API void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, scalar_t value, int dimension, scalar_t max_norm); -THC_API void THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, scalar_t value, int dimension, int keepdim); +TORCH_CUDA_API void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, scalar_t value, int dimension, scalar_t max_norm); +TORCH_CUDA_API void THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, scalar_t value, int dimension, int keepdim); -THC_API accreal THCTensor_(std_all)(THCState *state, THCTensor *self, bool unbiased); -THC_API accreal THCTensor_(normall)(THCState *state, THCTensor *self, scalar_t value); -THC_API accreal THCTensor_(var_all)(THCState *state, THCTensor *self, bool unbiased); +TORCH_CUDA_API accreal THCTensor_(std_all)(THCState *state, THCTensor *self, bool unbiased); +TORCH_CUDA_API accreal THCTensor_(normall)(THCState *state, THCTensor *self, scalar_t value); +TORCH_CUDA_API accreal THCTensor_(var_all)(THCState *state, THCTensor *self, bool unbiased); #endif -THC_API void THCTensor_(prod)(THCState *state, THCTensor *self, THCTensor *src, int dim, int keepdim); +TORCH_CUDA_API void THCTensor_(prod)(THCState *state, THCTensor *self, THCTensor *src, int dim, int keepdim); #endif diff --git a/aten/src/THC/generic/THCTensorMode.cu b/aten/src/THC/generic/THCTensorMode.cu index 9fe955f3cf8d6..8c428c9a5d1bb 100644 --- a/aten/src/THC/generic/THCTensorMode.cu +++ b/aten/src/THC/generic/THCTensorMode.cu @@ -2,6 +2,7 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorMode.cu" #else +#include #include void THCTensor_(calculateMode)(THCState *state, @@ -235,14 +236,14 @@ void THCTensor_(mode)(THCState *state, // Macro that calls kernel --> note that we set the block dimensions here, and // the amount of shared memory - #define HANDLE_MODE(SIZE) \ - { \ - dim3 blockSize(SIZE / 2); \ -\ - int memsize = (sizeof(scalar_t) * SIZE) + (2 * SIZE * sizeof(unsigned int)); \ - computeMode \ - <<>>( \ - THCTensor_(data)(state, contiguous), tiValues, tiIndices, sliceSize); \ + #define HANDLE_MODE(SIZE) \ + { \ + const dim3 blockSize(SIZE / 2); \ + const auto memsize = (sizeof(scalar_t) * SIZE) + (2 * SIZE * sizeof(unsigned int)); \ + computeMode \ + <<>>( \ + THCTensor_(data)(state, contiguous), tiValues, tiIndices, sliceSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } // Tradeoff between compilation time and the number of specializations. Ideally we would have diff --git a/aten/src/THC/generic/THCTensorMode.h b/aten/src/THC/generic/THCTensorMode.h index 796eb66e5379d..23d905eeb6f8d 100644 --- a/aten/src/THC/generic/THCTensorMode.h +++ b/aten/src/THC/generic/THCTensorMode.h @@ -4,7 +4,7 @@ /* Returns the mode, and index of the mode, for the set of values * along a given dimension in the input tensor. */ -THC_API void THCTensor_(mode)(THCState *state, +TORCH_CUDA_API void THCTensor_(mode)(THCState *state, THCTensor *values, THCudaLongTensor *indices, THCTensor *input, diff --git a/aten/src/THC/generic/THCTensorRandom.cu b/aten/src/THC/generic/THCTensorRandom.cu index f3ca8bf93b1be..1ef540ba3302e 100644 --- a/aten/src/THC/generic/THCTensorRandom.cu +++ b/aten/src/THC/generic/THCTensorRandom.cu @@ -5,6 +5,7 @@ #include #include #include +#include #include #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) @@ -39,6 +40,8 @@ void THCTensor_(multinomialAliasSetup)(THCState *state, THCTensor *_probs, THCud THCudaLongTensor_data(state, larger_short), one, inputsize ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + at::Tensor smaller_short_wrapped = THTensor_wrap(smaller_short); at::Tensor smaller_wrapped = THTensor_wrap(smaller); at::Tensor larger_short_wrapped = THTensor_wrap(larger_short); @@ -57,6 +60,8 @@ void THCTensor_(multinomialAliasSetup)(THCState *state, THCTensor *_probs, THCud THCudaLongTensor_data(state, larger_short), inputsize - h_large_c, h_large_c ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + scalar_t q_max = at::max(THTensor_wrap(_q)).item(); condDiv<<< inputBlockDim, BLOCK_SIZE, 0, c10::cuda::getCurrentCUDAStream()>>>( @@ -64,6 +69,7 @@ void THCTensor_(multinomialAliasSetup)(THCState *state, THCTensor *_probs, THCud THCudaLongTensor_data(state, _J), inputsize, q_max ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); THCudaLongTensor_free(state, smaller); THCudaLongTensor_free(state, larger); @@ -104,6 +110,8 @@ void THCTensor_(multinomialAliasDraw)(THCState *state, THCudaLongTensor *self, T THCTensor_(data)(state, uniform), THCTensor_(data)(state, bernoulli) ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + THCTensor_(free)(state, uniform); THCTensor_(free)(state, bernoulli); } diff --git a/aten/src/THC/generic/THCTensorRandom.h b/aten/src/THC/generic/THCTensorRandom.h index d20fe0574c0a9..4993342c9dd4f 100644 --- a/aten/src/THC/generic/THCTensorRandom.h +++ b/aten/src/THC/generic/THCTensorRandom.h @@ -6,8 +6,8 @@ #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) -THC_API void THCTensor_(multinomialAliasSetup)(struct THCState *state, THCTensor *probs, THCudaLongTensor *J, THCTensor *q); -THC_API void THCTensor_(multinomialAliasDraw)(THCState *state, THCudaLongTensor *self, THCTensor *_q, THCudaLongTensor *_J, int n_sample, c10::optional gen_); +TORCH_CUDA_API void THCTensor_(multinomialAliasSetup)(struct THCState *state, THCTensor *probs, THCudaLongTensor *J, THCTensor *q); +TORCH_CUDA_API void THCTensor_(multinomialAliasDraw)(THCState *state, THCudaLongTensor *self, THCTensor *_q, THCudaLongTensor *_J, int n_sample, c10::optional gen_); #endif #endif diff --git a/aten/src/THC/generic/THCTensorScatterGather.cu b/aten/src/THC/generic/THCTensorScatterGather.cu index 832539d370ce8..a1ab8d63f1636 100644 --- a/aten/src/THC/generic/THCTensorScatterGather.cu +++ b/aten/src/THC/generic/THCTensorScatterGather.cu @@ -2,10 +2,13 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorScatterGather.cu" #else +#include + #define RUN(TYPE, DIMS, REAL) \ - THCudaTensor_gatherKernel \ - <<>>( \ - tensorInfo, srcInfo, indexInfo, dim, (TYPE)totalElements); + THCudaTensor_gatherKernel \ + <<>>( \ + tensorInfo, srcInfo, indexInfo, dim, (TYPE)totalElements); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); void THCTensor_(gather)(THCState* state, THCTensor *tensor, THCTensor *src, int dim, THCudaLongTensor *index) { @@ -61,19 +64,15 @@ void THCTensor_(gather)(THCState* state, THCTensor *tensor, switch (indexInfo.dims) { case 1: RUN(unsigned int, 1, scalar_t); - THCudaCheck(cudaGetLastError()); break; case 2: RUN(unsigned int, 2, scalar_t); - THCudaCheck(cudaGetLastError()); break; case 3: RUN(unsigned int, 3, scalar_t); - THCudaCheck(cudaGetLastError()); break; default: RUN(unsigned int, -1, scalar_t); - THCudaCheck(cudaGetLastError()); break; } } else { @@ -84,7 +83,6 @@ void THCTensor_(gather)(THCState* state, THCTensor *tensor, TensorInfo indexInfo = getTensorInfo(state, index); RUN(uint64_t, -1, scalar_t); - THCudaCheck(cudaGetLastError()); } } diff --git a/aten/src/THC/generic/THCTensorScatterGather.h b/aten/src/THC/generic/THCTensorScatterGather.h index d56854d0a14e8..9c7344af8cb75 100644 --- a/aten/src/THC/generic/THCTensorScatterGather.h +++ b/aten/src/THC/generic/THCTensorScatterGather.h @@ -2,6 +2,6 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorScatterGather.h" #else -THC_API void THCTensor_(gather)(THCState* state, THCTensor *tensor, THCTensor *src, int dim, THCudaLongTensor *index); +TORCH_CUDA_API void THCTensor_(gather)(THCState* state, THCTensor *tensor, THCTensor *src, int dim, THCudaLongTensor *index); #endif diff --git a/aten/src/THC/generic/THCTensorSort.cu b/aten/src/THC/generic/THCTensorSort.cu index b4da00a98b7fd..e378fe03358ea 100644 --- a/aten/src/THC/generic/THCTensorSort.cu +++ b/aten/src/THC/generic/THCTensorSort.cu @@ -2,6 +2,8 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorSort.cu" #else +#include + // In alignment with default sort on a c++ map, this function // will permute key and value tensors identically, and // in such a way that the 'key' tensor is ordered numerically @@ -53,8 +55,9 @@ void THCTensor_(sortKeyValueInplace)(THCState* state, dim3 block(blockSize); \ \ if (dir) { \ - bitonicSortKVInPlace, TYPE, SIZE> \ - <<>>( \ + bitonicSortKVInPlace, TYPE, SIZE> \ + <<>>( \ keyInfo, \ keySlices, \ (TYPE) keySliceSize, \ @@ -62,16 +65,19 @@ void THCTensor_(sortKeyValueInplace)(THCState* state, valueInfo, \ (TYPE) valueInfo.strides[collapseValueDim], \ GTComp()); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } else { \ - bitonicSortKVInPlace, TYPE, SIZE> \ - <<>>( \ + bitonicSortKVInPlace, TYPE, SIZE> \ + <<>>( \ keyInfo, \ keySlices, \ (TYPE) keySliceSize, \ (TYPE) keyInfo.strides[collapseKeyDim], \ valueInfo, \ (TYPE) valueInfo.strides[collapseValueDim], \ - LTComp()); \ + LTComp()); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } \ } while (0) @@ -147,8 +153,6 @@ void THCTensor_(sortKeyValueInplace)(THCState* state, #undef HANDLE_CASE #undef HANDLE_SORT_CASE #undef HANDLE_A_CASE - - THCudaCheck(cudaGetLastError()); } void THCTensor_(sortViaThrust)(THCState* state, diff --git a/aten/src/THC/generic/THCTensorSort.h b/aten/src/THC/generic/THCTensorSort.h index d6569a3dd7047..0115327c0eafd 100644 --- a/aten/src/THC/generic/THCTensorSort.h +++ b/aten/src/THC/generic/THCTensorSort.h @@ -4,14 +4,14 @@ /* Performs an in-place sort of (keys, values). Only works for slice sizes <= 2048 at the moment (slice size == size of keys/values dim `dim`) */ -THC_API void THCTensor_(sortKeyValueInplace)(THCState* state, +TORCH_CUDA_API void THCTensor_(sortKeyValueInplace)(THCState* state, THCTensor* keys, THCudaLongTensor* values, int dim, bool dir); /* Performs an out-of-place sort of `input`, returning the per-slice indices in `indices` and the sorted values in `sorted` */ -THC_API void THCTensor_(sort)(THCState* state, +TORCH_CUDA_API void THCTensor_(sort)(THCState* state, THCTensor* sorted, THCudaLongTensor* indices, THCTensor* input, diff --git a/aten/src/THC/generic/THCTensorTopK.cu b/aten/src/THC/generic/THCTensorTopK.cu index a50f5e8f51ac7..8d7bf7701c040 100644 --- a/aten/src/THC/generic/THCTensorTopK.cu +++ b/aten/src/THC/generic/THCTensorTopK.cu @@ -3,15 +3,13 @@ #else #include +#include void THCTensor_(topk)(THCState* state, THCTensor *topK, THCudaLongTensor *indices, THCTensor *input_, int64_t k, int dim, int dir, int sorted) { - #if defined(THC_REAL_IS_BFLOAT16) && !defined(__HIP_PLATFORM_HCC__) - TORCH_CHECK(false, "topk not suppported with BFloat16"); - #else THAssert(topK != NULL && indices != NULL && input_ != NULL); THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, topK, indices, input_)); dim = at::maybe_wrap_dim(dim, input_); @@ -40,8 +38,8 @@ void THCTensor_(topk)(THCState* state, // is provided to the kernel for the arguments. #define RUN_K(INDEX_T, DIM, DIR) \ - gatherTopK \ - <<>>( \ + gatherTopK \ + <<>>( \ inputInfo, \ static_cast(sliceSize), \ static_cast(k), \ @@ -53,7 +51,8 @@ void THCTensor_(topk)(THCState* state, static_cast(topKSlices), \ static_cast(topKInfo.strides[collapseTopKDim]), \ indicesInfo, \ - static_cast(indicesInfo.strides[collapseIndicesDim])) + static_cast(indicesInfo.strides[collapseIndicesDim])); \ + C10_CUDA_KERNEL_LAUNCH_CHECK() #define RUN_DIR(INDEX_T, DIM) \ if (dir) { \ @@ -74,10 +73,10 @@ void THCTensor_(topk)(THCState* state, } #define RUN_T(INDEX_T) \ - TensorInfo inputInfo = \ - getTensorInfo(state, input); \ - TensorInfo topKInfo = \ - getTensorInfo(state, topK); \ + TensorInfo inputInfo = \ + getTensorInfo(state, input); \ + TensorInfo topKInfo = \ + getTensorInfo(state, topK); \ TensorInfo indicesInfo = \ getTensorInfo(state, indices); \ \ @@ -186,7 +185,6 @@ void THCTensor_(topk)(THCState* state, THCudaLongTensor_free(state, input); THCudaCheck(cudaGetLastError()); - #endif // THC_REAL_IS_BFLOAT16 && !__HIP_PLATFORM_HCC__ } #endif // THC_GENERIC_FILE diff --git a/aten/src/THC/generic/THCTensorTopK.h b/aten/src/THC/generic/THCTensorTopK.h index ffe6e959ea0bf..10eafed964a27 100644 --- a/aten/src/THC/generic/THCTensorTopK.h +++ b/aten/src/THC/generic/THCTensorTopK.h @@ -4,7 +4,7 @@ /* Returns the set of all kth smallest (or largest) elements, depending */ /* on `dir` */ -THC_API void THCTensor_(topk)(THCState* state, +TORCH_CUDA_API void THCTensor_(topk)(THCState* state, THCTensor* topK, THCudaLongTensor* indices, THCTensor* input, diff --git a/aten/src/THCUNN/RReLU.cu b/aten/src/THCUNN/RReLU.cu index 7a5c1811f2525..048f5f7294b22 100644 --- a/aten/src/THCUNN/RReLU.cu +++ b/aten/src/THCUNN/RReLU.cu @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -39,12 +40,17 @@ inline double __device__ curand_uniform_type(curandStatePhilox4_32_10_t } template -__global__ void rreluUpdateOutputTrain(int n, std::pair seeds, +__global__ void rreluUpdateOutputTrain(int n, at::PhiloxCudaState philox_args, T *input, T* noise, T *output, double a, double b) { + auto seeds = at::cuda::philox::unpack(philox_args); int idx = blockIdx.x * blockDim.x + threadIdx.x; curandStatePhilox4_32_10_t state; - curand_init(seeds.first, idx, seeds.second, &state); + curand_init(std::get<0>(seeds), + idx, + std::get<1>(seeds), + &state); + CUDA_KERNEL_LOOP(i, n) { if (input[i] <= 0) diff --git a/aten/src/THCUNN/generic/MultiLabelMarginCriterion.cu b/aten/src/THCUNN/generic/MultiLabelMarginCriterion.cu index ab8d2cb1ad687..6e8d9bc91976d 100644 --- a/aten/src/THCUNN/generic/MultiLabelMarginCriterion.cu +++ b/aten/src/THCUNN/generic/MultiLabelMarginCriterion.cu @@ -3,21 +3,30 @@ #else static inline void THNN_(MultiLabelMarginCriterion_shapeCheck)( - THCState *state, - THCTensor *input, THCTensor *target) { - if (input->dim() <= 1) { + THCState *state, + THCTensor *input, THCTensor *target) { + int64_t ndims = input->dim(); + bool valid_inputs = (ndims == 2 && input->size(1) != 0) || (ndims == 1 && input->size(0) != 0) || ndims == 0; + TORCH_CHECK( + valid_inputs, + "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ", + input->sizes()); + + if (ndims <= 1) { int dim = input->dim() == 0 ? 1 : input->size(0); int target_size = target->dim() == 0 ? 1 : target->size(0); - TORCH_CHECK(!target->is_empty() && (target->dim() <= 1) && (target_size == dim), - "inconsistent target size: ", target->sizes(), " for input of size: ", input->sizes()); - } else if (input->dim() == 2) { + + TORCH_CHECK(valid_inputs && target->dim() <= 1 && target->numel() == dim, + "inconsistent target size: ", target->sizes(), " for input of size: ", input->sizes()); + } else if (ndims == 2) { int nframe = input->size(0); int dim = input->size(1); - TORCH_CHECK(!target->is_empty() && (target->dim() == 2) - && (target->size(0) == nframe) && (target->size(1) == dim), - "inconsistent target size: ", target->sizes(), " for input of size: ", input->sizes()); + + TORCH_CHECK( + valid_inputs && target->dim() == 2 && target->size(0) == nframe && target->size(1) == dim, + "inconsistent target size: ", target->sizes(), " for input of size: ", input->sizes()); } else { - TORCH_CHECK(false, "non-empty vector or matrix expected, got size: ", input->sizes()); + TORCH_CHECK(false, "Expected input of ndims <= 2, but got ndims: ", ndims); } } @@ -31,6 +40,9 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)( int64_t reduction) { THNN_(MultiLabelMarginCriterion_shapeCheck)(state, input, target); + if (input->numel() == 0) { + return; + } input = THCTensor_(newContiguous)(state, input); target = THCIndexTensor_(newContiguous)(state, target); istarget = THCTensor_(newContiguous)(state, istarget); @@ -100,7 +112,8 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)( } } else { - TORCH_INTERNAL_ASSERT(false, "non-empty vector or matrix expected (shouldn't get here)"); + TORCH_CHECK(false, "Expected 2D input with optional zero batch dim, or 1D input with non-zero dims, but got sizes: ", + input->sizes()); } THCTensor_(free)(state, input); @@ -117,11 +130,17 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)( THCTensor *istarget, int64_t reduction) { + THNN_(MultiLabelMarginCriterion_shapeCheck)(state, input, target); input = THCTensor_(newContiguous)(state, input); + THCTensor_(resizeAs)(state, gradInput, input); + if (input->numel() == 0) { + THCTensor_(free)(state, input); + return; + } + target = THCIndexTensor_(newContiguous)(state, target); istarget = THCTensor_(newContiguous)(state, istarget); gradOutput = THCTensor_(newContiguous)(state, gradOutput); - THCTensor_(resizeAs)(state, gradInput, input); if(gradInput->dim() <= 1) { @@ -149,10 +168,11 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)( { int nframe = gradInput->size(0); int dim = gradInput->size(1); - THArgCheck(!target->is_empty() && (target->dim() == 2) && (target->size(0) == nframe) + THArgCheck((input->size(1) != 0) && (target->dim() == 2) && (target->size(0) == nframe) && (target->size(1) == dim), 3, "inconsistent target size"); - THArgCheck(!istarget->is_empty() && (istarget->dim() == 2) && (istarget->size(0) == nframe) + THArgCheck((istarget->dim() == 2) && (istarget->size(0) == nframe) && (istarget->size(1) == dim), 3, "inconsistent isTarget size"); + dim3 blocks(gradInput->size(0)); dim3 threads(MULTILABELMARGIN_THREADS); @@ -168,7 +188,8 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)( reduction != at::Reduction::None); } else { - AT_ERROR("non-empty vector or matrix expected, got size: ", gradInput->sizes()); + TORCH_CHECK(false, "Expected 2D input with optional zero batch dim, or 1D input with non-zero dims, but got sizes: ", + gradInput->sizes()); } THCudaCheck(cudaGetLastError()); diff --git a/aten/src/THCUNN/generic/MultiMarginCriterion.cu b/aten/src/THCUNN/generic/MultiMarginCriterion.cu index f2df15054a4c8..129413f0b7b24 100644 --- a/aten/src/THCUNN/generic/MultiMarginCriterion.cu +++ b/aten/src/THCUNN/generic/MultiMarginCriterion.cu @@ -2,6 +2,30 @@ #define THC_GENERIC_FILE "THCUNN/generic/MultiMarginCriterion.cu" #else +static inline void THNN_(MultiMarginCriterion_shapeCheck)( + THCState *state, + THCTensor *input, THCTensor *target) { + int64_t nframe, dim; + int64_t ndims = input->dim(); + bool valid_inputs = (ndims == 2 && input->size(1) != 0) || (ndims == 1 && input->size(0) != 0) || ndims == 0; + if (ndims <= 1) { + nframe = 1; + dim = ndims == 0 ? 1 : input->size(0); + } else { + nframe = input->size(0); + dim = input->size(1); + } + + TORCH_CHECK( + valid_inputs, + "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ", + input->sizes()); + TORCH_CHECK( + valid_inputs && target->dim() <= 1 && target->numel() == nframe, + "inconsistent target size, got: ", + target->sizes()); +} + // TODO: improve error messages void THNN_(MultiMarginCriterion_updateOutput)( THCState *state, @@ -13,6 +37,10 @@ void THNN_(MultiMarginCriterion_updateOutput)( THCTensor *weights, accreal margin_) { + THNN_(MultiMarginCriterion_shapeCheck)(state, input, target); + if (input->numel() == 0) { + return; + } scalar_t margin = ScalarConvert::to(margin_); THCUNN_assertSameGPU(state, 2, input, target); input = THCTensor_(newContiguous)(state, input); @@ -59,7 +87,8 @@ void THNN_(MultiMarginCriterion_updateOutput)( else if (input->dim() == 2) { int nframe = input->size(0); - THArgCheck(!target->is_empty() && (THTensor_nDimensionLegacyNoScalars(target) == 1) && (THTensor_sizeLegacyNoScalars(target, 0) == nframe), 3, + // allow zero-dim target for 2D input. + THArgCheck((input->size(1) != 0) && (THTensor_nDimensionLegacyNoScalars(target) == 1) && (THTensor_sizeLegacyNoScalars(target, 0) == nframe), 3, "inconsistent target size"); dim3 blocks(input->size(0)); dim3 threads(MULTIMARGIN_THREADS); @@ -130,7 +159,8 @@ void THNN_(MultiMarginCriterion_updateOutput)( } else { - AT_ERROR("non-empty vector or matrix expected, got sizes: ", input->sizes()); + TORCH_CHECK(false, "Expected 2D input with optional zero batch dim, or 1D input with non-zero dims, but got sizes: ", + input->sizes()); } THCTensor_(free)(state, input); @@ -149,11 +179,17 @@ void THNN_(MultiMarginCriterion_updateGradInput)( THCTensor *weights, accreal margin_) { + THNN_(MultiMarginCriterion_shapeCheck)(state, input, target); + input = THCTensor_(newContiguous)(state, input); + THCTensor_(resizeAs)(state, gradInput, input); + if (input->numel() == 0) { + THCTensor_(free)(state, input); + return; + } scalar_t margin = ScalarConvert::to(margin_); THCUNN_assertSameGPU(state, 3, input, gradInput, target); - input = THCTensor_(newContiguous)(state, input); gradOutput = THCTensor_(newContiguous)(state, gradOutput); - THCTensor_(resizeAs)(state, gradInput, input); + if(weights) weights = THCTensor_(newContiguous)(state, weights); @@ -195,7 +231,7 @@ void THNN_(MultiMarginCriterion_updateGradInput)( else if (input->dim() == 2) { int nframe = gradInput->size(0); - THArgCheck(!target->is_empty() && (THTensor_nDimensionLegacyNoScalars(target) == 1) && (THTensor_sizeLegacyNoScalars(target, 0) == nframe), 3, + THArgCheck((input->size(1) != 0) && (THTensor_nDimensionLegacyNoScalars(target) == 1) && (THTensor_sizeLegacyNoScalars(target, 0) == nframe), 3, "inconsistent target size"); dim3 blocks(gradInput->size(0)); dim3 threads(MULTIMARGIN_THREADS); @@ -232,7 +268,8 @@ void THNN_(MultiMarginCriterion_updateGradInput)( } else { - AT_ERROR("non-empty vector or matrix expected, got ", input->sizes()); + TORCH_CHECK(false, "Expected 2D input with optional zero batch dim, or 1D input with non-zero dims, but got sizes: ", + input->sizes()); } THCTensor_(free)(state, input); diff --git a/aten/src/THCUNN/generic/RReLU.cu b/aten/src/THCUNN/generic/RReLU.cu index fd4e6ff0cf4c6..a320138614503 100644 --- a/aten/src/THCUNN/generic/RReLU.cu +++ b/aten/src/THCUNN/generic/RReLU.cu @@ -31,11 +31,11 @@ void THNN_(RReLU_updateOutput)( const uint32_t curand4_engine_calls = 4; dim3 grid = NUM_BLOCKS(n); uint64_t counter_offset = ((n - 1) / (BLOCK_SIZE * grid.x) + 1) * curand4_engine_calls; - std::pair rng_engine_inputs; + at::PhiloxCudaState rng_engine_inputs; { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); - rng_engine_inputs = gen->philox_engine_inputs(counter_offset); + rng_engine_inputs = gen->philox_cuda_state(counter_offset); } if (inplace) { @@ -67,53 +67,4 @@ void THNN_(RReLU_updateOutput)( } } } - -void THNN_(RReLU_updateGradInput)( - THCState *state, - THCTensor *input, - THCTensor *gradOutput, - THCTensor *gradInput, - THCTensor *noise, - double lower, - double upper, - bool train, - bool inplace) -{ - THCUNN_check_nElement(state, input, gradOutput); - THCUNN_assertSameGPU(state, 4, input, gradOutput, gradInput, noise); - - auto gradOutputTensor = THTensor_wrap(gradOutput).contiguous(); - gradOutput = gradOutputTensor.unsafeGetTensorImpl(); - - if (train && upper - lower > 1E-6) // e.g. if upper == lower, RReLU behaves like LeakyReLU - { - // multiply the gradient by the noise tensor - if (inplace) - { - THCTensor_(cmul)(state, gradOutput, gradOutput, noise); - THCTensor_(set)(state, gradInput, gradOutput); - } - else - { - THCTensor_(resizeAs)(state, gradInput, input); - THCTensor_(cmul)(state, gradInput, gradOutput, noise); - } - } - else - { - // use constant factor for negative input values - const scalar_t negSlope = ScalarConvert::to((lower + upper) / 2); - if (inplace) - { - THC_pointwiseApply2(state, gradOutput, input, RReLUupdateGradInputEvalIP_functor(negSlope)); - THCTensor_(set)(state, gradInput, gradOutput); - } - else - { - THCTensor_(resizeAs)(state, gradInput, input); - THC_pointwiseApply3(state, gradInput, gradOutput, input, RReLUupdateGradInputEval_functor(negSlope)); - } - } -} - #endif diff --git a/aten/src/THCUNN/generic/SpatialConvolutionMM.cu b/aten/src/THCUNN/generic/SpatialConvolutionMM.cu index 535c43636af0c..599b098539139 100644 --- a/aten/src/THCUNN/generic/SpatialConvolutionMM.cu +++ b/aten/src/THCUNN/generic/SpatialConvolutionMM.cu @@ -114,9 +114,6 @@ void THNN_(SpatialConvolutionMM_updateOutput)( int kW, int kH, int dW, int dH, int padW, int padH) { - #if defined(THC_REAL_IS_BFLOAT16) && !defined(__HIP_PLATFORM_HCC__) - TORCH_CHECK(false, "SpatialConvolutionMM_updateOutput not suppported with BFloat16"); - #else THCUNN_assertSameGPU(state, 5, input, output, weight, columns, ones); if (bias) { THCUNN_assertSameGPU(state, 2, weight, bias); @@ -166,10 +163,12 @@ void THNN_(SpatialConvolutionMM_updateOutput)( // Define a buffer of ones, for bias accumulation // Note: this buffer can be shared with other modules, it only ever gets increased, // and always contains ones. - if (ones->dim() != 2 || ones->size(0)*ones->size(1) < outputHeight*outputWidth) { - // Resize plane and fill with ones... - THCTensor_(resize2d)(state, ones, outputHeight, outputWidth); - THCTensor_(fill)(state, ones, ScalarConvert::to(1)); + if (bias) { + if (ones->dim() != 2 || ones->size(0)*ones->size(1) < outputHeight*outputWidth) { + // Resize plane and fill with ones... + THCTensor_(resize2d)(state, ones, outputHeight, outputWidth); + THCTensor_(fill)(state, ones, ScalarConvert::to(1)); + } } // Helpers @@ -191,16 +190,7 @@ void THNN_(SpatialConvolutionMM_updateOutput)( // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) if (bias) { - #ifdef THC_REAL_IS_FLOAT - THCudaBlas_Sgemm( - #elif defined(THC_REAL_IS_HALF) - THCudaBlas_Hgemm( - #elif defined(THC_REAL_IS_DOUBLE) - THCudaBlas_Dgemm( - #elif defined(THC_REAL_IS_BFLOAT16) - THCudaBlas_Bgemm( - #endif - state, + at::cuda::blas::gemm( 't', 'n', n_, m_, k_, ScalarConvert::to(1), @@ -235,16 +225,7 @@ void THNN_(SpatialConvolutionMM_updateOutput)( // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) auto gemm_in_ptr = (kW != 1 || kH != 1) ? THCTensor_(data)(state, columns) : THCTensor_(data)(state, input_n); - #ifdef THC_REAL_IS_FLOAT - THCudaBlas_Sgemm( - #elif defined(THC_REAL_IS_HALF) - THCudaBlas_Hgemm( - #elif defined(THC_REAL_IS_DOUBLE) - THCudaBlas_Dgemm( - #elif defined(THC_REAL_IS_BFLOAT16) - THCudaBlas_Bgemm( - #endif - state, + at::cuda::blas::gemm( 'n', 'n', n, m, k, ScalarConvert::to(1), @@ -267,7 +248,6 @@ void THNN_(SpatialConvolutionMM_updateOutput)( THCTensor_(free)(state, input); THCTensor_(free)(state, weight); - #endif // THC_REAL_IS_BFLOAT16 && !__HIP_PLATFORM_HCC__ } void THNN_(SpatialConvolutionMM_updateGradInput)( @@ -281,10 +261,6 @@ void THNN_(SpatialConvolutionMM_updateGradInput)( int kW, int kH, int dW, int dH, int padW, int padH) { - - #if defined(THC_REAL_IS_BFLOAT16) && !defined(__HIP_PLATFORM_HCC__) - TORCH_CHECK(false, "SpatialConvolutionMM_updateGradInput not suppported with BFloat16"); - #else THCUNN_assertSameGPU(state, 5, input, gradOutput, weight, gradColumns, gradInput); weight = THNN_(newViewWeightMM2d)(state, weight); @@ -338,16 +314,7 @@ void THNN_(SpatialConvolutionMM_updateGradInput)( int64_t k = nOutputPlane; // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) - #ifdef THC_REAL_IS_FLOAT - THCudaBlas_Sgemm( - #elif defined(THC_REAL_IS_HALF) - THCudaBlas_Hgemm( - #elif defined(THC_REAL_IS_DOUBLE) - THCudaBlas_Dgemm( - #elif defined(THC_REAL_IS_BFLOAT16) - THCudaBlas_Bgemm( - #endif - state, + at::cuda::blas::gemm( 'n', 't', n, m, k, ScalarConvert::to(1), @@ -380,7 +347,6 @@ void THNN_(SpatialConvolutionMM_updateGradInput)( THCTensor_(free)(state, input); THCTensor_(free)(state, gradOutput); - #endif // THC_REAL_IS_BFLOAT16 && !__HIP_PLATFORM_HCC__ } void THNN_(SpatialConvolutionMM_accGradParameters)( @@ -395,10 +361,6 @@ void THNN_(SpatialConvolutionMM_accGradParameters)( int dW, int dH, int padW, int padH, accreal scale_) { - - #if defined(THC_REAL_IS_BFLOAT16) && !defined(__HIP_PLATFORM_HCC__) - TORCH_CHECK(false, "SpatialConvolutionMM_updateGradParameters not suppported with BFloat16"); - #else scalar_t scale = ScalarConvert::to(scale_); THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, gradBias, columns, ones); if (gradWeight) { @@ -482,16 +444,7 @@ void THNN_(SpatialConvolutionMM_accGradParameters)( // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) auto gemm_in_ptr = (kW != 1 || kH != 1) ? THCTensor_(data)(state, columns) : THCTensor_(data)(state, input_n); - #ifdef THC_REAL_IS_FLOAT - THCudaBlas_Sgemm( - #elif defined(THC_REAL_IS_HALF) - THCudaBlas_Hgemm( - #elif defined(THC_REAL_IS_DOUBLE) - THCudaBlas_Dgemm( - #elif defined(THC_REAL_IS_BFLOAT16) - THCudaBlas_Bgemm( - #endif - state, + at::cuda::blas::gemm( 't', 'n', n, m, k, scale, @@ -510,7 +463,7 @@ void THNN_(SpatialConvolutionMM_accGradParameters)( int64_t k_ = outputHeight * outputWidth; // Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices) - #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) + //#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_BFLOAT16) at::cuda::blas::gemv( 't', k_, m_, @@ -520,23 +473,6 @@ void THNN_(SpatialConvolutionMM_accGradParameters)( ScalarConvert::to(1), THCTensor_(data)(state, gradBias), 1 ); - #endif - #if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_BFLOAT16) - #ifdef THC_REAL_IS_HALF - THCudaBlas_Hgemm( - #elif defined(THC_REAL_IS_BFLOAT16) - THCudaBlas_Bgemm( - #endif - state, - 't', 'n', - m_, 1, k_, - scale, - THCTensor_(data)(state, gradOutput_n), k_, - THCTensor_(data)(state, ones), k_, - ScalarConvert::to(1), - THCTensor_(data)(state, gradBias), m_ - ); - #endif } } @@ -554,7 +490,6 @@ void THNN_(SpatialConvolutionMM_accGradParameters)( THCTensor_(free)(state, input); THCTensor_(free)(state, gradOutput); - #endif // THC_REAL_IS_BFLOAT16 && !__HIP_PLATFORM_HCC__ } #endif diff --git a/aten/src/THCUNN/generic/SpatialDepthwiseConvolution.cu b/aten/src/THCUNN/generic/SpatialDepthwiseConvolution.cu index 18d8da647d15f..53eff031a822e 100644 --- a/aten/src/THCUNN/generic/SpatialDepthwiseConvolution.cu +++ b/aten/src/THCUNN/generic/SpatialDepthwiseConvolution.cu @@ -13,9 +13,6 @@ void THNN_(SpatialDepthwiseConvolution_updateOutput)( int padW, int padH, int dilationW, int dilationH) { - #if defined(THC_REAL_IS_BFLOAT16) && !defined(__HIP_PLATFORM_HCC__) - TORCH_CHECK(false, "SpatialDepthwiseConvolution_updateOutput not suppported with BFloat16"); - #else THCUNN_assertSameGPU(state, 3, input, output, weight); // Only handle 4D Input Tensors for now @@ -94,7 +91,6 @@ void THNN_(SpatialDepthwiseConvolution_updateOutput)( THCTensor_(free)(state, input); THCTensor_(free)(state, weight); if (bias) THCTensor_(free)(state, bias); - #endif // THC_REAL_IS_BFLOAT16 && !__HIP_PLATFORM_HCC__ } void THNN_(SpatialDepthwiseConvolution_updateGradInput)( @@ -108,9 +104,6 @@ void THNN_(SpatialDepthwiseConvolution_updateGradInput)( int padW, int padH, int dilationW, int dilationH) { - #if defined(THC_REAL_IS_BFLOAT16) && !defined(__HIP_PLATFORM_HCC__) - TORCH_CHECK(false, "SpatialDepthwiseConvolution_updateGradInput not suppported with BFloat16"); - #else THCUNN_assertSameGPU(state, 3, gradOutput, gradInput, weight); // Only handle 4D Input Tensors for now @@ -203,7 +196,6 @@ void THNN_(SpatialDepthwiseConvolution_updateGradInput)( THCTensor_(free)(state, weight); THCTensor_(free)(state, gradOutput); - #endif // THC_REAL_IS_BFLOAT16 && !__HIP_PLATFORM_HCC__ } void THNN_(SpatialDepthwiseConvolution_accGradParameters)( @@ -216,9 +208,6 @@ void THNN_(SpatialDepthwiseConvolution_accGradParameters)( int padW, int padH, int dilationW, int dilationH) { - #if defined(THC_REAL_IS_BFLOAT16) && !defined(__HIP_PLATFORM_HCC__) - TORCH_CHECK(false, "SpatialDepthwiseConvolution_accGradParameters not suppported with BFloat16"); - #else THCUNN_assertSameGPU(state, 3, input, gradOutput, gradWeight); // Only handle 4D Input Tensors for now @@ -271,7 +260,6 @@ void THNN_(SpatialDepthwiseConvolution_accGradParameters)( THCudaCheck(cudaGetLastError()); THCTensor_(free)(state, gradOutput); - #endif // THC_REAL_IS_BFLOAT16 && !__HIP_PLATFORM_HCC__ } #endif diff --git a/aten/src/THCUNN/generic/THCUNN.h b/aten/src/THCUNN/generic/THCUNN.h index 20ebacebb9a87..9d67aa7d7f699 100644 --- a/aten/src/THCUNN/generic/THCUNN.h +++ b/aten/src/THCUNN/generic/THCUNN.h @@ -5,7 +5,7 @@ #include #include -THC_API void THNN_(ClassNLLCriterion_updateOutput)( +TORCH_CUDA_API void THNN_(ClassNLLCriterion_updateOutput)( THCState *state, THCTensor *input, THCIndexTensor *target, @@ -15,7 +15,7 @@ THC_API void THNN_(ClassNLLCriterion_updateOutput)( THCTensor *total_weight, int64_t ignore_index); -THC_API void THNN_(ClassNLLCriterion_updateGradInput)( +TORCH_CUDA_API void THNN_(ClassNLLCriterion_updateGradInput)( THCState *state, THCTensor *input, THCIndexTensor *target, @@ -26,33 +26,33 @@ THC_API void THNN_(ClassNLLCriterion_updateGradInput)( THCTensor *total_weight, int64_t ignore_index); -THC_API void THNN_(GatedLinear_updateOutput)( +TORCH_CUDA_API void THNN_(GatedLinear_updateOutput)( THCState *state, THCTensor *input, THCTensor *output, int dim); -THC_API void THNN_(GatedLinear_updateGradInput)( +TORCH_CUDA_API void THNN_(GatedLinear_updateGradInput)( THCState *state, THCTensor *input, THCTensor *gradOutput, THCTensor *gradInput, int dim); -THC_API void THNN_(LogSigmoid_updateOutput)( +TORCH_CUDA_API void THNN_(LogSigmoid_updateOutput)( THCState *state, THCTensor *input, THCTensor *output, THCTensor *buffer); -THC_API void THNN_(LogSigmoid_updateGradInput)( +TORCH_CUDA_API void THNN_(LogSigmoid_updateGradInput)( THCState *state, THCTensor *input, THCTensor *gradOutput, THCTensor *gradInput, THCTensor *buffer); -THC_API void THNN_(MultiLabelMarginCriterion_updateOutput)( +TORCH_CUDA_API void THNN_(MultiLabelMarginCriterion_updateOutput)( THCState *state, THCTensor *input, THCIndexTensor *target, @@ -60,7 +60,7 @@ THC_API void THNN_(MultiLabelMarginCriterion_updateOutput)( THCTensor *is_target, int64_t reduction); -THC_API void THNN_(MultiLabelMarginCriterion_updateGradInput)( +TORCH_CUDA_API void THNN_(MultiLabelMarginCriterion_updateGradInput)( THCState *state, THCTensor *input, THCIndexTensor *target, @@ -69,7 +69,7 @@ THC_API void THNN_(MultiLabelMarginCriterion_updateGradInput)( THCTensor *is_target, int64_t reduction); -THC_API void THNN_(MultiMarginCriterion_updateOutput)( +TORCH_CUDA_API void THNN_(MultiMarginCriterion_updateOutput)( THCState *state, THCTensor *input, THCIndexTensor *target, @@ -79,7 +79,7 @@ THC_API void THNN_(MultiMarginCriterion_updateOutput)( THCTensor *weights, // [OPTIONAL] accreal margin); -THC_API void THNN_(MultiMarginCriterion_updateGradInput)( +TORCH_CUDA_API void THNN_(MultiMarginCriterion_updateGradInput)( THCState *state, THCTensor *input, THCIndexTensor *target, @@ -90,7 +90,7 @@ THC_API void THNN_(MultiMarginCriterion_updateGradInput)( THCTensor *weights, // [OPTIONAL] accreal margin); -THC_API void THNN_(SpatialClassNLLCriterion_updateOutput)( +TORCH_CUDA_API void THNN_(SpatialClassNLLCriterion_updateOutput)( THCState *state, THCTensor *input, THCIndexTensor *target, @@ -100,7 +100,7 @@ THC_API void THNN_(SpatialClassNLLCriterion_updateOutput)( THCTensor *total_weight, int64_t ignore_index); -THC_API void THNN_(SpatialClassNLLCriterion_updateGradInput)( +TORCH_CUDA_API void THNN_(SpatialClassNLLCriterion_updateGradInput)( THCState *state, THCTensor *input, THCIndexTensor *target, @@ -111,7 +111,7 @@ THC_API void THNN_(SpatialClassNLLCriterion_updateGradInput)( THCTensor *total_weight, int64_t ignore_index); -THC_API void THNN_(SpatialConvolutionMM_updateOutput)( +TORCH_CUDA_API void THNN_(SpatialConvolutionMM_updateOutput)( THCState *state, THCTensor *input, THCTensor *output, @@ -123,7 +123,7 @@ THC_API void THNN_(SpatialConvolutionMM_updateOutput)( int dW, int dH, int padW, int padH); -THC_API void THNN_(SpatialConvolutionMM_updateGradInput)( +TORCH_CUDA_API void THNN_(SpatialConvolutionMM_updateGradInput)( THCState *state, THCTensor *input, THCTensor *gradOutput, @@ -135,7 +135,7 @@ THC_API void THNN_(SpatialConvolutionMM_updateGradInput)( int dW, int dH, int padW, int padH); -THC_API void THNN_(SpatialConvolutionMM_accGradParameters)( +TORCH_CUDA_API void THNN_(SpatialConvolutionMM_accGradParameters)( THCState *state, THCTensor *input, THCTensor *gradOutput, @@ -148,7 +148,7 @@ THC_API void THNN_(SpatialConvolutionMM_accGradParameters)( int padW, int padH, accreal scale); -THC_API void THNN_(SpatialDepthwiseConvolution_updateOutput)( +TORCH_CUDA_API void THNN_(SpatialDepthwiseConvolution_updateOutput)( THCState *state, THCTensor *input, THCTensor *output, @@ -159,7 +159,7 @@ THC_API void THNN_(SpatialDepthwiseConvolution_updateOutput)( int padW, int padH, int dilationW, int dilationH); -THC_API void THNN_(SpatialDepthwiseConvolution_updateGradInput)( +TORCH_CUDA_API void THNN_(SpatialDepthwiseConvolution_updateGradInput)( THCState *state, THCTensor *input, THCTensor *gradOutput, @@ -170,7 +170,7 @@ THC_API void THNN_(SpatialDepthwiseConvolution_updateGradInput)( int padW, int padH, int dilationW, int dilationH); -THC_API void THNN_(SpatialDepthwiseConvolution_accGradParameters)( +TORCH_CUDA_API void THNN_(SpatialDepthwiseConvolution_accGradParameters)( THCState *state, THCTensor *input, THCTensor *gradOutput, @@ -180,7 +180,7 @@ THC_API void THNN_(SpatialDepthwiseConvolution_accGradParameters)( int padW, int padH, int dilationW, int dilationH); -THC_API void THNN_(RReLU_updateOutput)( +TORCH_CUDA_API void THNN_(RReLU_updateOutput)( THCState *state, THCTensor *input, THCTensor *output, @@ -191,7 +191,7 @@ THC_API void THNN_(RReLU_updateOutput)( bool inplace, c10::optional generator); -THC_API void THNN_(RReLU_updateGradInput)( +TORCH_CUDA_API void THNN_(RReLU_updateGradInput)( THCState *state, THCTensor *input, THCTensor *gradOutput, diff --git a/aten/tools/valgrind.sup b/aten/tools/valgrind.sup index 0b7ef9501b5be..ad5f66e0b0531 100644 --- a/aten/tools/valgrind.sup +++ b/aten/tools/valgrind.sup @@ -24,6 +24,15 @@ ... } +{ + ignore_cuda_ioctl_param_points_to_uninitialised_bytes + Memcheck:Param + ioctl(generic) + fun:ioctl + obj:*libcuda.so* + ... +} + { ignore_libomp_setaffinity_check Memcheck:Param diff --git a/benchmarks/cpp/tensorexpr/CMakeLists.txt b/benchmarks/cpp/tensorexpr/CMakeLists.txt new file mode 100644 index 0000000000000..85ab72ab3589b --- /dev/null +++ b/benchmarks/cpp/tensorexpr/CMakeLists.txt @@ -0,0 +1,9 @@ +add_executable( + tensorexpr_bench + bench_approx.cpp + bench_compile.cpp + bench_fuser_overhead.cpp + bench_gemm.cpp + main.cpp) + +target_link_libraries(tensorexpr_bench PRIVATE torch_library benchmark) diff --git a/benchmarks/cpp/tensorexpr/bench_approx.cpp b/benchmarks/cpp/tensorexpr/bench_approx.cpp new file mode 100644 index 0000000000000..220ea71497ff1 --- /dev/null +++ b/benchmarks/cpp/tensorexpr/bench_approx.cpp @@ -0,0 +1,145 @@ +#include +#include +#include +#include +#include +#include + +using namespace torch::jit::tensorexpr; + +static void log_sleef(benchmark::State& state) { + KernelScope ks; + auto N = VarHandle("N", kInt); + Placeholder A("A", kFloat, {N}); + torch::jit::tensorexpr::Tensor* B = + Compute("B", {N}, [&](const VarHandle& i) { + return log(A.load(i)); + }); + LoopNest ln({B}); + ln.prepareForCodegen(); + ln.vectorizeInnerLoops(); + Stmt* s = ln.root_stmt(); + s = torch::jit::tensorexpr::IRSimplifier::simplify(s); + std::vector args; + args.emplace_back(B); + args.emplace_back(A); + args.emplace_back(N); + LLVMCodeGen cg(s, args); + at::Tensor A_t = torch::abs(torch::randn({state.range(0)})); + at::Tensor B_t = torch::randn({state.range(0)}); + auto B_ref = at::log(A_t); + cg.call({B_t.data_ptr(), A_t.data_ptr(), state.range(0)}); + assert(at::allclose(B_t, B_ref)); + for (auto _ : state) { + cg.call({B_t.data_ptr(), A_t.data_ptr(), state.range(0)}); + } + state.counters["log/s"] = benchmark::Counter( + uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate); +} + +static void log_fast(benchmark::State& state) { + KernelScope ks; + auto N = VarHandle("N", kInt); + Placeholder A("A", kFloat, {N}); + torch::jit::tensorexpr::Tensor* B = + Compute("B", {N}, [&](const VarHandle& i) { + return fast_log(A.load(i)); + }); + LoopNest ln({B}); + ln.prepareForCodegen(); + ln.vectorizeInnerLoops(); + Stmt* s = ln.root_stmt(); + s = torch::jit::tensorexpr::IRSimplifier::simplify(s); + std::vector args; + args.emplace_back(B); + args.emplace_back(A); + args.emplace_back(N); + LLVMCodeGen cg(s, args); + at::Tensor A_t = torch::abs(torch::randn({state.range(0)})); + at::Tensor B_t = torch::randn({state.range(0)}); + auto B_ref = at::log(A_t); + cg.call({B_t.data_ptr(), A_t.data_ptr(), state.range(0)}); + assert(at::allclose(B_t, B_ref)); + for (auto _ : state) { + cg.call({B_t.data_ptr(), A_t.data_ptr(), state.range(0)}); + } + state.counters["log/s"] = benchmark::Counter( + uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate); +} + +static void log_aten(benchmark::State& state) { + at::Tensor A_t = torch::abs(torch::randn({state.range(0)})); + at::Tensor B_t = torch::randn({state.range(0)}); + for (auto _ : state) { + at::native::log_out(B_t, A_t); + } + state.counters["log/s"] = benchmark::Counter( + uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate); +} + +static void logit_fast(benchmark::State& state) { + KernelScope ks; + auto N = VarHandle("N", kInt); + Placeholder A("A", kFloat, {N}); + torch::jit::tensorexpr::Tensor* B = + Compute("B", {N}, [&](const VarHandle& i) { + auto A_elem = A.load(i); + return fast_log(A_elem / (FloatImm::make(1.0f) - A_elem)); + }); + LoopNest ln({B}); + ln.prepareForCodegen(); + ln.vectorizeInnerLoops(); + Stmt* s = ln.root_stmt(); + s = torch::jit::tensorexpr::IRSimplifier::simplify(s); + std::vector args; + args.emplace_back(B); + args.emplace_back(A); + args.emplace_back(N); + LLVMCodeGen cg(s, args); + at::Tensor A_t = torch::abs(torch::randn({state.range(0)})); + at::Tensor B_t = torch::randn({state.range(0)}); + auto B_ref = at::logit(A_t); + cg.call({B_t.data_ptr(), A_t.data_ptr(), state.range(0)}); + assert(at::allclose(B_t, B_ref)); + for (auto _ : state) { + cg.call({B_t.data_ptr(), A_t.data_ptr(), state.range(0)}); + } + state.counters["logit/s"] = benchmark::Counter( + uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate); +} + +static void logit_aten(benchmark::State& state) { + at::Tensor A_t = torch::abs(torch::randn({state.range(0)})); + at::Tensor B_t = torch::randn({state.range(0)}); + for (auto _ : state) { + at::native::logit_out(B_t, A_t); + } + state.counters["logit/s"] = benchmark::Counter( + uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate); +} + +BENCHMARK(log_sleef) + ->Args({2<<5}) + ->Args({2<<8}) + ->Args({2<<12}) + ->Args({2<<14}); +BENCHMARK(log_fast) + ->Args({2<<5}) + ->Args({2<<8}) + ->Args({2<<12}) + ->Args({2<<14}); +BENCHMARK(log_aten) + ->Args({2<<5}) + ->Args({2<<8}) + ->Args({2<<12}) + ->Args({2<<14}); +BENCHMARK(logit_fast) + ->Args({2<<5}) + ->Args({2<<8}) + ->Args({2<<12}) + ->Args({2<<14}); +BENCHMARK(logit_aten) + ->Args({2<<5}) + ->Args({2<<8}) + ->Args({2<<12}) + ->Args({2<<14}); diff --git a/benchmarks/cpp/tensorexpr/bench_compile.cpp b/benchmarks/cpp/tensorexpr/bench_compile.cpp new file mode 100644 index 0000000000000..d2eae8f8ab3e6 --- /dev/null +++ b/benchmarks/cpp/tensorexpr/bench_compile.cpp @@ -0,0 +1,74 @@ +#include +#include +#include +#include +#include + +#ifdef TORCH_ENABLE_LLVM +namespace te = torch::jit::tensorexpr; + +static void BM_CompileSwish(benchmark::State& state) { + for (auto _ : state) { + constexpr int N = 512; + te::KernelScope ks; + te::VarHandle n("n", te::kInt); + te::Placeholder A(te::BufHandle("A", {N}, te::kFloat)); + te::Tensor* relu = te::Compute("relu", {{n, "n"}}, [&](const te::VarHandle& i) { + return te::Max::make(A.load(i), 0.f, false); + }); + te::Tensor* min6 = te::Compute("min6", {{n, "n"}}, [&](const te::VarHandle& i) { + return te::Min::make(relu->call(i), 6.f, false); + }); + te::Tensor* plus3 = te::Compute("plus3", {{n, "n"}}, [&](const te::VarHandle& i) { + return min6->call(i) + 3.f; + }); + te::Tensor* times = te::Compute("times", {{n, "n"}}, [&](const te::VarHandle& i) { + return A.load(i) * plus3->call(i); + }); + te::Tensor* sixth = te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) { + return times->call(i) * 1.f / 6.f; + }); + te::LoopNest nest({sixth}); + for (auto tensor : {relu, min6, plus3, times}) { + nest.computeInline(tensor->buf()); + } + nest.prepareForCodegen(); + te::Stmt* s = te::IRSimplifier::simplify(nest.root_stmt()); + te::LLVMCodeGen cg(s, {A, sixth}); + } +} + +static void BM_CompileSwishLLVMOnly(benchmark::State& state) { + constexpr int N = 512; + te::KernelScope ks; + te::VarHandle n("n", te::kInt); + te::Placeholder A(te::BufHandle("A", {N}, te::kFloat)); + te::Tensor* relu = te::Compute("relu", {{n, "n"}}, [&](const te::VarHandle& i) { + return te::Max::make(A.load(i), 0.f, false); + }); + te::Tensor* min6 = te::Compute("min6", {{n, "n"}}, [&](const te::VarHandle& i) { + return te::Min::make(relu->call(i), 6.f, false); + }); + te::Tensor* plus3 = te::Compute("plus3", {{n, "n"}}, [&](const te::VarHandle& i) { + return min6->call(i) + 3.f; + }); + te::Tensor* times = te::Compute("times", {{n, "n"}}, [&](const te::VarHandle& i) { + return A.load(i) * plus3->call(i); + }); + te::Tensor* sixth = te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) { + return times->call(i) * 1.f / 6.f; + }); + te::LoopNest nest({sixth}); + for (auto tensor : {relu, min6, plus3, times}) { + nest.computeInline(tensor->buf()); + } + nest.prepareForCodegen(); + te::Stmt* s = te::IRSimplifier::simplify(nest.root_stmt()); + for (auto _ : state) { + te::LLVMCodeGen cg(s, {A, sixth}); + } +} + +BENCHMARK(BM_CompileSwish); +BENCHMARK(BM_CompileSwishLLVMOnly); +#endif // TORCH_ENABLE_LLVM diff --git a/benchmarks/cpp/tensorexpr/bench_fuser_overhead.cpp b/benchmarks/cpp/tensorexpr/bench_fuser_overhead.cpp new file mode 100644 index 0000000000000..1ce66747f2f06 --- /dev/null +++ b/benchmarks/cpp/tensorexpr/bench_fuser_overhead.cpp @@ -0,0 +1,57 @@ +#include +#include +#include + +using namespace torch::jit; + +static const std::string two_adds = R"JIT( +def two_adds(self, x: Tensor, y: Tensor, z: Tensor) -> Tensor: + return x + y + z +)JIT"; + +static void FusedOverhead(benchmark::State& state) { + torch::NoGradGuard ng; + torch::AutoNonVariableTypeMode nv; + overrideCanFuseOnCPU(true); + + Module m("m"); + m.define(two_adds); + + auto x = torch::ones({1}); + auto y = torch::ones({1}); + auto z = torch::ones({1}); + + // Warmup. + for (int i = 0; i < 8; i++) { + m.run_method("two_adds", x, y, z); + } + + for (auto _ : state) { + m.run_method("two_adds", x, y, z); + } +} + +static void UnfusedOverhead(benchmark::State& state) { + torch::NoGradGuard ng; + torch::AutoNonVariableTypeMode nv; + overrideCanFuseOnCPU(false); + + Module m("m"); + m.define(two_adds); + + auto x = torch::ones({1}); + auto y = torch::ones({1}); + auto z = torch::ones({1}); + + // Warmup. + for (int i = 0; i < 8; i++) { + m.run_method("two_adds", x, y, z); + } + + for (auto _ : state) { + m.run_method("two_adds", x, y, z); + } +} + +BENCHMARK(FusedOverhead); +BENCHMARK(UnfusedOverhead); diff --git a/benchmarks/cpp/tensorexpr/bench_gemm.cpp b/benchmarks/cpp/tensorexpr/bench_gemm.cpp new file mode 100644 index 0000000000000..78855264a5b46 --- /dev/null +++ b/benchmarks/cpp/tensorexpr/bench_gemm.cpp @@ -0,0 +1,339 @@ +#include +#include +#include +#include +#include + +namespace te = torch::jit::tensorexpr; + +namespace { +class Gemm : public benchmark::Fixture { + public: + void SetUp(const benchmark::State& state) override { + M = state.range(0); + N = state.range(1); + K = state.range(2); + A = torch::randn({M, K}); + B = torch::randn({K, N}); + C = torch::mm(A, B); + } + + void TearDown(benchmark::State& state) override { + state.counters["GFLOPS"] = benchmark::Counter( + uint64_t(state.iterations()) * 2 * M * N * K, + benchmark::Counter::kIsRate); + } + + int M; + int N; + int K; + at::Tensor A; + at::Tensor B; + at::Tensor C; +}; +} + +BENCHMARK_DEFINE_F(Gemm, Torch)(benchmark::State& state) { + for (auto _ : state) { + torch::mm_out(C, A, B); + } +} + +BENCHMARK_DEFINE_F(Gemm, TensorExprNoopt)(benchmark::State& state) { + te::KernelScope ks; + + te::Placeholder AP(te::BufHandle("A", {M, K}, te::kFloat)); + te::Placeholder BP(te::BufHandle("B", {K, N}, te::kFloat)); + te::Tensor* CT = te::Reduce( + "gemm", + {{M, "M"}, {N, "N"}}, + te::Sum(), + [&](const te::ExprHandle& m, + const te::ExprHandle& n, + const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); }, + {{K, "K"}}); + te::LoopNest loop({CT}); + loop.prepareForCodegen(); + te::Stmt* s = loop.root_stmt(); + s = te::IRSimplifier::simplify(s); + auto cg = CreateCodeGen("llvm_codegen", s, {AP, BP, CT}); + + for (auto _ : state) { + cg->call({A.data_ptr(), B.data_ptr(), C.data_ptr()}); + } +} + +BENCHMARK_DEFINE_F(Gemm, TensorExprTile32x32)(benchmark::State& state) { + te::KernelScope ks; + + te::Placeholder AP(te::BufHandle("A", {M, K}, te::kFloat)); + te::Placeholder BP(te::BufHandle("B", {K, N}, te::kFloat)); + te::Tensor* CT = te::Reduce( + "gemm", + {{M, "M"}, {N, "N"}}, + te::Sum(), + [&](const te::ExprHandle& m, + const te::ExprHandle& n, + const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); }, + {{K, "K"}}); + te::LoopNest loop({CT}); + + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* m = loops[0]; + te::For* mo; + te::For* mi; + loop.splitWithMask(m, 32, &mo, &mi); + } + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* n = loops[2]; + te::For* no; + te::For* ni; + loop.splitWithMask(n, 32, &no, &ni); + } + // mo, mi, no, ni, k -> + // mo, no, mi, ni, k + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* mi = loops[1]; + te::For* no = loops[2]; + loop.reorderAxis(mi, no); + } + // mo, no, mi, ni, k -> + // mo, no, mi, k, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* ni = loops[3]; + te::For* k = loops[4]; + loop.reorderAxis(ni, k); + } + // mo, no, mi, k, ni -> + // mo, no, k, mi, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* mi = loops[2]; + te::For* k = loops[3]; + loop.reorderAxis(mi, k); + } + + loop.prepareForCodegen(); + te::Stmt* s = loop.root_stmt(); + s = te::IRSimplifier::simplify(s); + auto cg = CreateCodeGen("llvm_codegen", s, {AP, BP, CT}); + + for (auto _ : state) { + cg->call({A.data_ptr(), B.data_ptr(), C.data_ptr()}); + } +} + +BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16)(benchmark::State& state) { + te::KernelScope ks; + + te::Placeholder AP(te::BufHandle("A", {M, K}, te::kFloat)); + te::Placeholder BP(te::BufHandle("B", {K, N}, te::kFloat)); + te::Tensor* CT = te::Reduce( + "gemm", + {{M, "M"}, {N, "N"}}, + te::Sum(), + [&](const te::ExprHandle& m, + const te::ExprHandle& n, + const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); }, + {{K, "K"}}); + te::LoopNest loop({CT}); + + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* m = loops[0]; + te::For* mo; + te::For* mi; + loop.splitWithMask(m, 4, &mo, &mi); + } + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* n = loops[2]; + te::For* no; + te::For* ni; + loop.splitWithMask(n, 16, &no, &ni); + } + // mo, mi, no, ni, k -> + // mo, no, mi, ni, k + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* mi = loops[1]; + te::For* no = loops[2]; + loop.reorderAxis(mi, no); + } + // mo, no, mi, ni, k -> + // mo, no, mi, k, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* ni = loops[3]; + te::For* k = loops[4]; + loop.reorderAxis(ni, k); + } + // mo, no, mi, k, ni -> + // mo, no, k, mi, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* mi = loops[2]; + te::For* k = loops[3]; + loop.reorderAxis(mi, k); + } + + loop.prepareForCodegen(); + te::Stmt* s = loop.root_stmt(); + s = te::IRSimplifier::simplify(s); + auto cg = CreateCodeGen("llvm_codegen", s, {AP, BP, CT}); + + for (auto _ : state) { + cg->call({A.data_ptr(), B.data_ptr(), C.data_ptr()}); + } +} + +BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16VecUnroll)(benchmark::State& state) { + te::KernelScope ks; + + te::Placeholder AP(te::BufHandle("A", {M, K}, te::kFloat)); + te::Placeholder BP(te::BufHandle("B", {K, N}, te::kFloat)); + te::Tensor* CT = te::Reduce( + "gemm", + {{M, "M"}, {N, "N"}}, + te::Sum(), + [&](const te::ExprHandle& m, + const te::ExprHandle& n, + const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); }, + {{K, "K"}}); + te::LoopNest loop({CT}); + + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* m = loops[0]; + te::For* mo; + te::For* mi; + loop.splitWithMask(m, 4, &mo, &mi); + } + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* n = loops[2]; + te::For* no; + te::For* ni; + loop.splitWithMask(n, 16, &no, &ni); + } + // mo, mi, no, ni, k -> + // mo, no, mi, ni, k + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* mi = loops[1]; + te::For* no = loops[2]; + loop.reorderAxis(mi, no); + } + // mo, no, mi, ni, k -> + // mo, no, mi, k, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* ni = loops[3]; + te::For* k = loops[4]; + loop.reorderAxis(ni, k); + } + // mo, no, mi, k, ni -> + // mo, no, k, mi, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* mi = loops[2]; + te::For* k = loops[3]; + loop.reorderAxis(mi, k); + } + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* mi = loops[3]; + te::For* ni = loops[4]; + te::Stmt* unrolled; + loop.vectorize(ni); + loop.unroll(mi, &unrolled); + } + + loop.prepareForCodegen(); + te::Stmt* s = loop.root_stmt(); + s = te::IRSimplifier::simplify(s); + auto cg = CreateCodeGen("llvm_codegen", s, {AP, BP, CT}); + + for (auto _ : state) { + cg->call({A.data_ptr(), B.data_ptr(), C.data_ptr()}); + } +} + +BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16Cache)(benchmark::State& state) { + te::KernelScope ks; + + te::Placeholder AP(te::BufHandle("A", {M, K}, te::kFloat)); + te::Placeholder BP(te::BufHandle("B", {K, N}, te::kFloat)); + te::Tensor* CT = te::Reduce( + "gemm", + {{M, "M"}, {N, "N"}}, + te::Sum(), + [&](const te::ExprHandle& m, + const te::ExprHandle& n, + const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); }, + {{K, "K"}}); + te::LoopNest loop({CT}); + + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* m = loops[0]; + te::For* mo; + te::For* mi; + loop.splitWithMask(m, 4, &mo, &mi); + } + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* n = loops[2]; + te::For* no; + te::For* ni; + loop.splitWithMask(n, 16, &no, &ni); + } + // mo, mi, no, ni, k -> + // mo, no, mi, ni, k + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* mi = loops[1]; + te::For* no = loops[2]; + loop.reorderAxis(mi, no); + } + // mo, no, mi, ni, k -> + // mo, no, mi, k, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* ni = loops[3]; + te::For* k = loops[4]; + loop.reorderAxis(ni, k); + } + // mo, no, mi, k, ni -> + // mo, no, k, mi, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + te::For* mi = loops[2]; + te::For* k = loops[3]; + loop.reorderAxis(mi, k); + } + { + auto const& loops = loop.getLoopStmtsFor(CT); + loop.cacheAccesses(CT->buf(), "C_regs", loops[2]); + } + + loop.prepareForCodegen(); + te::Stmt* s = loop.root_stmt(); + s = te::IRSimplifier::simplify(s); + auto cg = CreateCodeGen("llvm_codegen", s, {AP, BP, CT}); + + for (auto _ : state) { + cg->call({A.data_ptr(), B.data_ptr(), C.data_ptr()}); + } +} + +BENCHMARK_REGISTER_F(Gemm, Torch)->Args({128, 128, 128}); +BENCHMARK_REGISTER_F(Gemm, TensorExprNoopt)->Args({128, 128, 128}); +BENCHMARK_REGISTER_F(Gemm, TensorExprTile32x32)->Args({128, 128, 128}); +BENCHMARK_REGISTER_F(Gemm, TensorExprTile4x16)->Args({128, 128, 128}); +BENCHMARK_REGISTER_F(Gemm, TensorExprTile4x16VecUnroll)->Args({128, 128, 128}); +BENCHMARK_REGISTER_F(Gemm, TensorExprTile4x16Cache)->Args({128, 128, 128}); diff --git a/benchmarks/cpp/tensorexpr/main.cpp b/benchmarks/cpp/tensorexpr/main.cpp new file mode 100644 index 0000000000000..71fefa0472287 --- /dev/null +++ b/benchmarks/cpp/tensorexpr/main.cpp @@ -0,0 +1,3 @@ +#include + +BENCHMARK_MAIN(); diff --git a/benchmarks/distributed/ddp/benchmark.py b/benchmarks/distributed/ddp/benchmark.py index 4567749665f6f..202ad3c5f56cd 100644 --- a/benchmarks/distributed/ddp/benchmark.py +++ b/benchmarks/distributed/ddp/benchmark.py @@ -26,10 +26,6 @@ import torchvision -if not torch._six.PY3: - raise RuntimeError("DDP benchmark requires Python 3") - - def allgather_object(obj): buffer = io.BytesIO() torch.save(obj, buffer) diff --git a/benchmarks/distributed/ddp/diff.py b/benchmarks/distributed/ddp/diff.py index 59cd5b533c1b4..dc984626888a0 100644 --- a/benchmarks/distributed/ddp/diff.py +++ b/benchmarks/distributed/ddp/diff.py @@ -9,10 +9,6 @@ import numpy as np -if not torch._six.PY3: - raise RuntimeError("DDP benchmark requires Python 3") - - def load(path): with open(path, 'r') as f: return json.load(f) @@ -24,7 +20,7 @@ def main(): args = parser.parse_args() if len(args.file) != 2: - raise "Must specify 2 files to diff" + raise RuntimeError("Must specify 2 files to diff") ja = load(args.file[0]) jb = load(args.file[1]) diff --git a/benchmarks/distributed/pipeline/benchmark_dataset.py b/benchmarks/distributed/pipeline/benchmark_dataset.py new file mode 100644 index 0000000000000..e8c516e4bc441 --- /dev/null +++ b/benchmarks/distributed/pipeline/benchmark_dataset.py @@ -0,0 +1,56 @@ +import torch +from torch.utils.data import Dataset + + +def collate_sentences_lm(samples): + + if len(samples) == 0: + return {} + + id = torch.LongTensor([s["id"] for s in samples]) + src_tokens = torch.stack([s["source"] for s in samples], 0) + tgt_tokens = torch.stack([s["target"] for s in samples], 0) + ntokens = len(samples) * len(samples[0]["target"]) + src_lengths = torch.LongTensor([len(samples[0]["source"])] * len(samples)) + + batch = { + "id": id, + "nsentences": len(samples), + "ntokens": ntokens, + "input": src_tokens, + "target": tgt_tokens, + } + return batch + + +class BenchmarkLMDataset(Dataset): + """ + Dataset to benchmark a translation like seq2seq task. + Args: + vocab_size (int, optional): size of the vocabulary (default 10000). + max_source_positions (int, optional): max number of tokens in the + source sentence (default: 1024). + total_samples (int, optional): the total number of rows in the + dataset (default: 10000). + """ + + def __init__( + self, vocab_size=10000, max_source_positions=1024, total_samples=10000, + ): + self.vocab_size = vocab_size + self.max_source_positions = max_source_positions + self.total_samples = total_samples + self.sizes = [self.max_source_positions] * self.total_samples + + def __getitem__(self, index): + length = self.sizes[index] + source = torch.randint(1, self.vocab_size, (length,)) + target = source.clone() + return { + "id": index, + "source": source, + "target": target, + } + + def __len__(self): + return self.total_samples diff --git a/benchmarks/distributed/pipeline/pipe.py b/benchmarks/distributed/pipeline/pipe.py new file mode 100644 index 0000000000000..3433110f3194e --- /dev/null +++ b/benchmarks/distributed/pipeline/pipe.py @@ -0,0 +1,277 @@ +import argparse +import math +import os +import time + +from .benchmark_dataset import BenchmarkLMDataset, collate_sentences_lm +import torch +from torch.distributed import rpc +import torch.nn as nn +from torch.utils.data import DataLoader + +from torch.distributed.pipeline.sync import Pipe +from torch.testing._internal.distributed.pipeline.utils import convert_to_balance +from torch.optim import Adam # type: ignore + +def sizeof_fmt(num, suffix='B'): + for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti']: + if abs(num) < 1024.0: + return "%3.2f%sB" % (num, unit) + num /= 1024.0 + + +def init_random_seed(seed: int): + import numpy + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + numpy.random.seed(seed) + + +iteration_count = 0 + + +class EmbeddingLayer(nn.Embedding): + def __init__(self, ntoken, ninp, initrange): + super().__init__(ntoken, ninp) + self.ninp = ninp + self.weight.data.uniform_(-initrange, initrange) + + def forward(self, src): + return super().forward(src) * math.sqrt(self.ninp) + + +class PositionalEncodingLayer(nn.Module): + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(PositionalEncodingLayer, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + self.register_buffer("pe", pe) + + def forward(self, x): + x = x + self.pe[: x.size(0), :] + return self.dropout(x) + + +class TransformerDecoderLayer(nn.TransformerEncoderLayer): + """Though this class inherits from torch.nn.TransformerEncoderLayer, + it functions as a decoder in this model""" + + def __init__(self, ninp, nhead, nhid, droupout): + super().__init__(ninp, nhead, nhid, droupout) + self.src_mask = None + + def _generate_square_subsequent_mask(self, sz): + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0)) + return mask + + def forward(self, src): + global iteration_count + iteration_count += 1 + + if self.src_mask is None or self.src_mask.size(0) != len(src): + device = src.device + mask = self._generate_square_subsequent_mask(len(src)).to(device) + self.src_mask = mask + + return super().forward(src, self.src_mask) + + +class LinearLayer(nn.Linear): + def __init__(self, ninp, ntoken, initrange): + super().__init__(ninp, ntoken) + self.bias.data.zero_() + self.weight.data.uniform_(-initrange, initrange) + + +class TransformerLMSequential(nn.Sequential): + """A small language model based on the design of GPT-2 using nn.Sequential + for compatibility with Pipe""" + + def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder): + layers = [ + EmbeddingLayer(ntokens, ninp, initrange), + PositionalEncodingLayer(ninp, dropout), + ] + for _ in range(ndecoder): + layers.append(TransformerDecoderLayer(ninp, nhead, nhid, dropout)) + + layers.append(LinearLayer(ninp, ntokens, initrange)) + super(TransformerLMSequential, self).__init__(*layers) + + +def make_model(args, device, ntokens): + ninp = 2048 # embedding dimension + nhid = 2048 # the dimension of the feedforward network model in nn.TransformerEncoder + nhead = 32 # the number of heads in the multiheadattention models + dropout = 0 + initrange = 0.1 + ndecoder = args.num_decoder_layers + + model = TransformerLMSequential(ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device) + + criterion = nn.CrossEntropyLoss() + lr = 0.01 # learning rate + + def make_adam(model): + return Adam(model.parameters(), lr=lr) + + optimizer = make_adam + + return model, criterion, optimizer + + +def train(lm_dataloader, model, criterion, optimizer, vocab_size, args): + model.train() + + vocab_size = 10000 + total_loss = 0.0 + start_time = time.time() + word_counter = 0 + + optimizer = optimizer(model) + + def get_first_device(model): + if model.devices: + return model.devices[0] + else: + return torch.cuda.current_device() + + def get_last_device(model): + if model.devices: + return model.devices[-1] + else: + return torch.cuda.current_device() + + + print('Number of parameters for model: {}'.format(sum(p.numel() for p in model.parameters()))) + for i, batch in enumerate(lm_dataloader): + bi = batch["input"] + if args.max_batch and i > args.max_batch: + break + optimizer.zero_grad() + try: + tmp = batch["input"].to(get_first_device(model)) + output = model(tmp).local_value() + except Exception as e: + raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e + + target = batch["target"].to(get_last_device(model)) + output = output.to(target.device) + + loss = criterion(output.view(-1, vocab_size), target.view(-1)) + loss.backward() + del target + del output + + torch.nn.utils.clip_grad_value_(model.parameters(), 0.05) + optimizer.step() + + total_loss += loss.item() + log_interval = 1 + word_counter += batch["ntokens"] + if i % log_interval == 0 and i > 0: + cur_loss = total_loss / log_interval + elapsed = time.time() - start_time + print( + "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format( + i, word_counter / elapsed, cur_loss, math.exp(cur_loss) + ) + ) + word_counter = 0 + total_loss = 0 + start_time = time.time() + + print('Peak memory usage for GPUs: ', end='') + for i in range(len(model.devices)): + print("cuda:{}: {}, ".format( + i, + sizeof_fmt(torch.cuda.memory_stats(i)["allocated_bytes.all.peak"])), end='') + print() + + +def generate_balance(num_devices, num_layers): + balance = [] + layers_assigned = 0 + for i in range(num_devices): + x = (num_layers - layers_assigned) / (num_devices - i) + if x.is_integer(): + balance.append(int(x)) + layers_assigned += x + else: + balance.append(math.ceil(x)) + layers_assigned += math.ceil(x) + return balance + + +def make_model_and_data(args, device): + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + vocab_size = 10000 + model, criterion, optimizer = make_model(args, device, vocab_size) + lm_dataset = BenchmarkLMDataset() + lm_dataloader = DataLoader( + lm_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=collate_sentences_lm + ) + return { + "model": model, + "criterion": criterion, + "optimizer": optimizer, + "data": lm_dataloader, + "vocab_size": vocab_size, + } + + +def bench_single_process(args): + os.environ.update({"MASTER_ADDR" : args.host}) + os.environ.update({"MASTER_PORT" : "10638"}) + + rpc.init_rpc( + "worker", + rank=0, + world_size=1, + ) + + num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1 + num_devices = min(args.num_devices, num_devices) + assert num_devices > 0 + init_random_seed(0) + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + blob = make_model_and_data(args, None) + model = blob["model"] + + balance = generate_balance(num_devices, len(model)) + model = convert_to_balance(model, balance) + p = Pipe( + model, chunks=args.chunks, checkpoint=args.checkpoint + ) + del model + del blob["model"] + + train(blob["data"], p, blob["criterion"], blob["optimizer"], blob["vocab_size"], args) + +parser = argparse.ArgumentParser(description="benchmark") +parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname") +parser.add_argument("--chunks", type=int, default=4, help="number of microbatches per batch") +parser.add_argument("--batch-size", type=int, default=8, help="size of a batch") +parser.add_argument("--max-batch", type=int, default=10, help="Max number of batches") +parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model") +parser.add_argument( + "--checkpoint", default="except_last", choices=["always", "except_last", "never"], + help="Checkpointing strategy for pipe" +) +parser.add_argument( + "--num-devices", type=int, default=4, help="Number of GPU devices to use" +) + +if __name__ == "__main__": + args = parser.parse_args() + print(f"Running benchmark with args: {args}") + bench_single_process(args) diff --git a/benchmarks/distributed/rpc/rl/README.md b/benchmarks/distributed/rpc/rl/README.md new file mode 100644 index 0000000000000..1cd29a7a4b619 --- /dev/null +++ b/benchmarks/distributed/rpc/rl/README.md @@ -0,0 +1,42 @@ +# Distributed RPC Reinforcement Learning Benchmark + +This tool is used to measure `torch.distributed.rpc` throughput and latency for reinforcement learning. + +The benchmark spawns one *agent* process and a configurable number of *observer* processes. As this benchmark focuses on RPC throughput and latency, the agent uses a dummy policy and observers all use randomly generated states and rewards. In each iteration, observers pass their state to the agent through `torch.distributed.rpc` and wait for the agent to respond with an action. If `batch=False`, then the agent will process and respond to a single observer request at a time. Otherwise, the agent will accumulate requests from multiple observers and run them through the policy in one shot. There is also a separate *coordinator* process that manages the *agent* and *observers*. + +In addition to printing measurements, this benchmark produces a JSON file. Users may choose a single argument to provide multiple comma-separated entries for (ie: `world_size="10,50,100"`) in which case the JSON file produced can be passed to the plotting repo to visually see how results differ. In this case, each entry for the variable argument will be placed on the x axis. + +The benchmark results comprise of 4 key metrics: +1. _Agent Latency_ - How long does it take from the time the first action request in a batch is received from an observer to the time an action is selected by the agent for each request in that batch. If `batch=False` you can think of it as `batch_size=1`. +2. _Agent Throughput_ - The number of request processed per second for a given batch. Agent throughput is literally computed as `(batch_size / agent_latency)`. If not using batch, you can think of it as `batch_size=1`. +3. _Observer Latency_ - Time it takes from the moment an action is requested by a single observer to the time the response is received from the agent. Therefore if `batch=False`, observer latency is the agent latency plus the transit time it takes for the request to get to the agent from the observer plus the transit time it takes for the response to get to the observer from the agent. When `batch=True` there will be more variation due to some observer requests being queued in a batch for longer than others depending on what order those requests came into the batch in. +4. _Observer Throughput_ - Number of requests processed per second for a single observer. Observer Throughput is literally computed as `(1 / observer_latency)`. + +## Requirements + +This benchmark depends on PyTorch. + +## How to run + +For any environments you are interested in, pass the corresponding arguments to `python launcher.py`. + +```python launcher.py --world_size="10,20" --master_addr="127.0.0.1" --master_port="29501 --batch="True" --state_size="10-20-10" --nlayers="5" --out_features="10" --output_file_path="benchmark_report.json"``` + +Example Output: + +``` +-------------------------------------------------------------- +PyTorch distributed rpc benchmark reinforcement learning suite +-------------------------------------------------------------- +master_addr : 127.0.0.1 +master_port : 29501 +batch : True +state_size : 10-20-10 +nlayers : 5 +out_features : 10 +output_file_path : benchmark_report.json +x_axis_name : world_size +world_size | agent latency (seconds) agent throughput observer latency (seconds) observer throughput + p50 p75 p90 p95 p50 p75 p90 p95 p50 p75 p90 p95 p50 p75 p90 p95 +10 0.002 0.002 0.002 0.002 4432 4706 4948 5128 0.002 0.003 0.003 0.003 407 422 434 443 +20 0.004 0.005 0.005 0.005 4244 4620 4884 5014 0.005 0.005 0.006 0.006 191 207 215 220 diff --git a/benchmarks/distributed/rpc/rl/agent.py b/benchmarks/distributed/rpc/rl/agent.py new file mode 100644 index 0000000000000..4f55bdef84920 --- /dev/null +++ b/benchmarks/distributed/rpc/rl/agent.py @@ -0,0 +1,169 @@ +from functools import reduce +import time +import threading + +import torch +from torch.distributions import Categorical +import torch.distributed.rpc as rpc +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + + +OBSERVER_NAME = "observer{}" + + +class Policy(nn.Module): + def __init__(self, in_features, nlayers, out_features): + r""" + Inits policy class + Args: + in_features (int): Number of input features the model takes + nlayers (int): Number of layers in the model + out_features (int): Number of features the model outputs + """ + super(Policy, self).__init__() + + self.model = nn.Sequential( + nn.Flatten(1, -1), + nn.Linear(in_features, out_features), + * [nn.Linear(out_features, out_features) for _ in range(nlayers)] + ) + self.dim = 0 + + def forward(self, x): + action_scores = self.model(x) + return F.softmax(action_scores, dim=self.dim) + + +class AgentBase: + def __init__(self): + r""" + Inits agent class + """ + self.id = rpc.get_worker_info().id + self.running_reward = 0 + self.eps = 1e-7 + + self.rewards = {} + + self.future_actions = torch.futures.Future() + self.lock = threading.Lock() + + self.agent_latency_start = None + self.agent_latency_end = None + self.agent_latency = [] + self.agent_throughput = [] + + def reset_metrics(self): + r""" + Sets all benchmark metrics to their empty values + """ + self.agent_latency_start = None + self.agent_latency_end = None + self.agent_latency = [] + self.agent_throughput = [] + + def set_world(self, batch_size, state_size, nlayers, out_features, batch=True): + r""" + Further initializes agent to be aware of rpc environment + Args: + batch_size (int): size of batches of observer requests to process + state_size (list): List of ints dictating the dimensions of the state + nlayers (int): Number of layers in the model + out_features (int): Number of out features in the model + batch (bool): Whether to process and respond to observer requests as a batch or 1 at a time + """ + self.batch = batch + self.policy = Policy(reduce((lambda x, y: x * y), state_size), nlayers, out_features) + self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2) + + self.batch_size = batch_size + for rank in range(batch_size): + ob_info = rpc.get_worker_info(OBSERVER_NAME.format(rank + 2)) + + self.rewards[ob_info.id] = [] + + self.saved_log_probs = [] if self.batch else { + k: [] for k in range(self.batch_size)} + + self.pending_states = self.batch_size + self.state_size = state_size + self.states = torch.zeros(self.batch_size, *state_size) + + @staticmethod + @rpc.functions.async_execution + def select_action_batch(agent_rref, observer_id, state): + r""" + Receives state from an observer to select action for. Queues the observers's request + for an action until queue size equals batch size named during Agent initiation, at which point + actions are selected for all pending observer requests and communicated back to observers + Args: + agent_rref (RRef): RRFef of this agent + observer_id (int): Observer id of observer calling this function + state (Tensor): Tensor representing current state held by observer + """ + self = agent_rref.local_value() + observer_id -= 2 + + self.states[observer_id].copy_(state) + future_action = self.future_actions.then( + lambda future_actions: future_actions.wait()[observer_id].item() + ) + + with self.lock: + if self.pending_states == self.batch_size: + self.agent_latency_start = time.time() + self.pending_states -= 1 + if self.pending_states == 0: + self.pending_states = self.batch_size + probs = self.policy(self.states) + m = Categorical(probs) + actions = m.sample() + self.saved_log_probs.append(m.log_prob(actions).t()) + future_actions = self.future_actions + self.future_actions = torch.futures.Future() + future_actions.set_result(actions) + + self.agent_latency_end = time.time() + + batch_latency = self.agent_latency_end - self.agent_latency_start + self.agent_latency.append(batch_latency) + self.agent_throughput.append(self.batch_size / batch_latency) + + return future_action + + @staticmethod + def select_action_non_batch(agent_rref, observer_id, state): + r""" + Select actions based on observer state and communicates back to observer + Args: + agent_rref (RRef): RRef of this agent + observer_id (int): Observer id of observer calling this function + state (Tensor): Tensor representing current state held by observer + """ + self = agent_rref.local_value() + observer_id -= 2 + agent_latency_start = time.time() + + state = state.float().unsqueeze(0) + probs = self.policy(state) + m = Categorical(probs) + action = m.sample() + self.saved_log_probs[observer_id].append(m.log_prob(action)) + + agent_latency_end = time.time() + non_batch_latency = agent_latency_end - agent_latency_start + self.agent_latency.append(non_batch_latency) + self.agent_throughput.append(1 / non_batch_latency) + + return action.item() + + def finish_episode(self, rets): + r""" + Finishes the episode + Args: + rets (list): List containing rewards generated by selct action calls during + episode run + """ + return self.agent_latency, self.agent_throughput diff --git a/benchmarks/distributed/rpc/rl/coordinator.py b/benchmarks/distributed/rpc/rl/coordinator.py new file mode 100644 index 0000000000000..1b53fe4ac00c6 --- /dev/null +++ b/benchmarks/distributed/rpc/rl/coordinator.py @@ -0,0 +1,139 @@ +import numpy as np +import time + +import torch +import torch.distributed.rpc as rpc + +from agent import AgentBase +from observer import ObserverBase + +COORDINATOR_NAME = "coordinator" +AGENT_NAME = "agent" +OBSERVER_NAME = "observer{}" + +EPISODE_STEPS = 100 + + +class CoordinatorBase: + def __init__(self, batch_size, batch, state_size, nlayers, out_features): + r""" + Coordinator object to run on worker. Only one coordinator exists. Responsible + for facilitating communication between agent and observers and recording benchmark + throughput and latency data. + Args: + batch_size (int): Number of observer requests to process in a batch + batch (bool): Whether to process and respond to observer requests as a batch or 1 at a time + state_size (list): List of ints dictating the dimensions of the state + nlayers (int): Number of layers in the model + out_features (int): Number of out features in the model + """ + self.batch_size = batch_size + self.batch = batch + + self.agent_rref = None # Agent RRef + self.ob_rrefs = [] # Observer RRef + + agent_info = rpc.get_worker_info(AGENT_NAME) + self.agent_rref = rpc.remote(agent_info, AgentBase) + + for rank in range(batch_size): + ob_info = rpc.get_worker_info(OBSERVER_NAME.format(rank + 2)) + ob_ref = rpc.remote(ob_info, ObserverBase) + self.ob_rrefs.append(ob_ref) + + ob_ref.rpc_sync().set_state(state_size, batch) + + self.agent_rref.rpc_sync().set_world( + batch_size, state_size, nlayers, out_features, self.batch) + + def run_coordinator(self, episodes, episode_steps, queue): + r""" + Runs n benchmark episodes. Each episode is started by coordinator telling each + observer to contact the agent. Each episode is concluded by coordinator telling agent + to finish the episode, and then the coordinator records benchmark data + Args: + episodes (int): Number of episodes to run + episode_steps (int): Number steps to be run in each episdoe by each observer + queue (SimpleQueue): SimpleQueue from torch.multiprocessing.get_context() for + saving benchmark run results to + """ + + agent_latency_final = [] + agent_throughput_final = [] + + observer_latency_final = [] + observer_throughput_final = [] + + for ep in range(episodes): + ep_start_time = time.time() + + print(f"Episode {ep} - ", end='') + + n_steps = episode_steps + agent_start_time = time.time() + + futs = [] + for ob_rref in self.ob_rrefs: + futs.append(ob_rref.rpc_async().run_ob_episode( + self.agent_rref, n_steps)) + + rets = torch.futures.wait_all(futs) + agent_latency, agent_throughput = self.agent_rref.rpc_sync().finish_episode(rets) + + self.agent_rref.rpc_sync().reset_metrics() + + agent_latency_final += agent_latency + agent_throughput_final += agent_throughput + + observer_latency_final += [ret[2] for ret in rets] + observer_throughput_final += [ret[3] for ret in rets] + + ep_end_time = time.time() + episode_time = ep_end_time - ep_start_time + print(round(episode_time, 3)) + + observer_latency_final = [t for s in observer_latency_final for t in s] + observer_throughput_final = [ + t for s in observer_throughput_final for t in s] + + benchmark_metrics = {'agent latency (seconds)': {}, + 'agent throughput': {}, + 'observer latency (seconds)': {}, + 'observer throughput': {}} + + + print("For batch size {0}".format(self.batch_size)) + print("\nAgent Latency - ", len(agent_latency_final)) + agent_latency_final = sorted(agent_latency_final) + for p in [50, 75, 90, 95]: + v = np.percentile(agent_latency_final, p) + print("p" + str(p) + ":", round(v, 3)) + p = f'p{p}' + benchmark_metrics['agent latency (seconds)'][p] = round(v, 3) + + print("\nAgent Throughput - ", len(agent_throughput_final)) + agent_throughput_final = sorted(agent_throughput_final) + for p in [50, 75, 90, 95]: + v = np.percentile(agent_throughput_final, p) + print("p" + str(p) + ":", int(v)) + p = f'p{p}' + benchmark_metrics['agent throughput'][p] = int(v) + + print("\nObserver Latency - ", len(observer_latency_final)) + observer_latency_final = sorted(observer_latency_final) + for p in [50, 75, 90, 95]: + v = np.percentile(observer_latency_final, p) + print("p" + str(p) + ":", round(v, 3)) + p = f'p{p}' + benchmark_metrics['observer latency (seconds)'][p] = round(v, 3) + + print("\nObserver Throughput - ", len(observer_throughput_final)) + observer_throughput_final = sorted(observer_throughput_final) + for p in [50, 75, 90, 95]: + v = np.percentile(observer_throughput_final, p) + print("p" + str(p) + ":", int(v)) + p = f'p{p}' + benchmark_metrics['observer throughput'][p] = int(v) + + if queue: + queue.put(benchmark_metrics) diff --git a/benchmarks/distributed/rpc/rl/launcher.py b/benchmarks/distributed/rpc/rl/launcher.py new file mode 100644 index 0000000000000..5a612aab0e9e1 --- /dev/null +++ b/benchmarks/distributed/rpc/rl/launcher.py @@ -0,0 +1,213 @@ +import argparse +import os +import time + +import json +import torch.distributed.rpc as rpc +import torch.multiprocessing as mp + + +from coordinator import CoordinatorBase + +COORDINATOR_NAME = "coordinator" +AGENT_NAME = "agent" +OBSERVER_NAME = "observer{}" + +TOTAL_EPISODES = 10 +TOTAL_EPISODE_STEPS = 100 + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +parser = argparse.ArgumentParser(description='PyTorch RPC RL Benchmark') +parser.add_argument('--world_size', type=str, default='10') +parser.add_argument('--master_addr', type=str, default='127.0.0.1') +parser.add_argument('--master_port', type=str, default='29501') +parser.add_argument('--batch', type=str, default='True') + +parser.add_argument('--state_size', type=str, default='10-20-10') +parser.add_argument('--nlayers', type=str, default='5') +parser.add_argument('--out_features', type=str, default='10') +parser.add_argument('--output_file_path', type=str, default='benchmark_report.json') + +args = parser.parse_args() +args = vars(args) + +def run_worker(rank, world_size, master_addr, master_port, batch, state_size, nlayers, out_features, queue): + r""" + inits an rpc worker + Args: + rank (int): Rpc rank of worker machine + world_size (int): Number of workers in rpc network (number of observers + + 1 agent + 1 coordinator) + master_addr (str): Master address of cooridator + master_port (str): Master port of coordinator + batch (bool): Whether agent will use batching or process one observer + request a at a time + state_size (str): Numerical str representing state dimensions (ie: 5-15-10) + nlayers (int): Number of layers in model + out_features (int): Number of out features in model + queue (SimpleQueue): SimpleQueue from torch.multiprocessing.get_context() for + saving benchmark run results to + """ + state_size = list(map(int, state_size.split('-'))) + batch_size = world_size - 2 # No. of observers + + os.environ['MASTER_ADDR'] = master_addr + os.environ['MASTER_PORT'] = master_port + if rank == 0: + rpc.init_rpc(COORDINATOR_NAME, rank=rank, world_size=world_size) + + coordinator = CoordinatorBase( + batch_size, batch, state_size, nlayers, out_features) + coordinator.run_coordinator(TOTAL_EPISODES, TOTAL_EPISODE_STEPS, queue) + + elif rank == 1: + rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size) + else: + rpc.init_rpc(OBSERVER_NAME.format(rank), + rank=rank, world_size=world_size) + rpc.shutdown() + +def find_graph_variable(args): + r""" + Determines if user specified multiple entries for a single argument, in which case + benchmark is run for each of these entries. Comma separated values in a given argument indicate multiple entries. + Output is presented so that user can use plot repo to plot the results with each of the + variable argument's entries on the x-axis. Args is modified in accordance with this. + More than 1 argument with multiple entries is not permitted. + Args: + args (dict): Dictionary containing arguments passed by the user (and default arguments) + """ + var_types = {'world_size': int, + 'state_size': str, + 'nlayers': int, + 'out_features': int, + 'batch': str2bool} + for arg in var_types.keys(): + if ',' in args[arg]: + if args.get('x_axis_name'): + raise("Only 1 x axis graph variable allowed") + args[arg] = list(map(var_types[arg], args[arg].split(','))) # convert , separated str to list + args['x_axis_name'] = arg + else: + args[arg] = var_types[arg](args[arg]) # convert string to proper type + +def append_spaces(string, length): + r""" + Returns a modified string with spaces appended to the end. If length of string argument + is greater than or equal to length, a single space is appended, otherwise x spaces are appended + where x is the difference between the length of string and the length argument + Args: + string (str): String to be modified + length (int): Size of desired return string with spaces appended + Return: (str) + """ + string = str(string) + offset = length - len(string) + if offset <= 0: + offset = 1 + string += ' ' * offset + return string + +def print_benchmark_results(report): + r""" + Prints benchmark results + Args: + report (dict): JSON formatted dictionary containing relevant data on the run of this application + """ + print("--------------------------------------------------------------") + print("PyTorch distributed rpc benchmark reinforcement learning suite") + print("--------------------------------------------------------------") + for key, val in report.items(): + if key != "benchmark_results": + print(f'{key} : {val}') + + x_axis_name = report.get('x_axis_name') + col_width = 7 + heading = "" + if x_axis_name: + x_axis_output_label = f'{x_axis_name} |' + heading += append_spaces(x_axis_output_label, col_width) + metric_headers = ['agent latency (seconds)', 'agent throughput', + 'observer latency (seconds)', 'observer throughput'] + percentile_subheaders = ['p50', 'p75', 'p90', 'p95'] + subheading = "" + if x_axis_name: + subheading += append_spaces(' ' * (len(x_axis_output_label) - 1), col_width) + for header in metric_headers: + heading += append_spaces(header, col_width * len(percentile_subheaders)) + for percentile in percentile_subheaders: + subheading += append_spaces(percentile, col_width) + print(heading) + print(subheading) + + for benchmark_run in report['benchmark_results']: + run_results = "" + if x_axis_name: + run_results += append_spaces(benchmark_run[x_axis_name], max(col_width, len(x_axis_output_label))) + for metric_name in metric_headers: + percentile_results = benchmark_run[metric_name] + for percentile in percentile_subheaders: + run_results += append_spaces(percentile_results[percentile], col_width) + print(run_results) + +def main(): + r""" + Runs rpc benchmark once if no argument has multiple entries, and otherwise once for each of the multiple entries. + Multiple entries is indicated by comma separated values, and may only be done for a single argument. + Results are printed as well as saved to output file. In case of multiple entries for a single argument, + the plot repo can be used to benchmark results on the y axis with each entry on the x axis. + """ + find_graph_variable(args) + + # run once if no x axis variables + x_axis_variables = args[args['x_axis_name']] if args.get('x_axis_name') else [None] + ctx = mp.get_context('spawn') + queue = ctx.SimpleQueue() + benchmark_runs = [] + for i, x_axis_variable in enumerate(x_axis_variables): # run benchmark for every x axis variable + if len(x_axis_variables) > 1: + args[args['x_axis_name']] = x_axis_variable # set x axis variable for this benchmark iteration + processes = [] + start_time = time.time() + for rank in range(args['world_size']): + prc = ctx.Process( + target=run_worker, + args=( + rank, args['world_size'], args['master_addr'], args['master_port'], + args['batch'], args['state_size'], args['nlayers'], + args['out_features'], queue + ) + ) + prc.start() + processes.append(prc) + benchmark_run_results = queue.get() + for process in processes: + process.join() + print(f"Time taken benchmark run {i} -, {time.time() - start_time}") + if args.get('x_axis_name'): + # save x axis value was for this iteration in the results + benchmark_run_results[args['x_axis_name']] = x_axis_variable + benchmark_runs.append(benchmark_run_results) + + report = args + report['benchmark_results'] = benchmark_runs + if args.get('x_axis_name'): + # x_axis_name was variable so dont save a constant in the report for that variable + del report[args['x_axis_name']] + with open(args['output_file_path'], 'w') as f: + json.dump(report, f) + print_benchmark_results(report) + +if __name__ == '__main__': + main() diff --git a/benchmarks/distributed/rpc/rl/observer.py b/benchmarks/distributed/rpc/rl/observer.py new file mode 100644 index 0000000000000..159bfdeccb169 --- /dev/null +++ b/benchmarks/distributed/rpc/rl/observer.py @@ -0,0 +1,71 @@ +import random +import time + +import torch +import torch.distributed.rpc as rpc +from torch.distributed.rpc import rpc_sync + +from agent import AgentBase + + +class ObserverBase: + def __init__(self): + r""" + Inits observer class + """ + self.id = rpc.get_worker_info().id + + def set_state(self, state_size, batch): + r""" + Further initializes observer to be aware of rpc environment + Args: + state_size (list): List of integers denoting dimensions of state + batch (bool): Whether agent will be using batch select action + """ + self.state_size = state_size + self.select_action = AgentBase.select_action_batch if batch else AgentBase.select_action_non_batch + + def reset(self): + r""" + Resets state randomly + """ + state = torch.rand(self.state_size) + return state + + def step(self, action): + r""" + Generates random state and reward + Args: + action (int): Int received from agent representing action to take on state + """ + state = torch.rand(self.state_size) + reward = random.randint(0, 1) + + return state, reward + + def run_ob_episode(self, agent_rref, n_steps): + r""" + Runs single observer episode where for n_steps, an action is selected + from the agent based on curent state and state is updated + Args: + agent_rref (RRef): Remote Reference to the agent + n_steps (int): Number of times to select an action to transform state per episode + """ + state, ep_reward = self.reset(), None + rewards = torch.zeros(n_steps) + observer_latencies = [] + observer_throughput = [] + + for st in range(n_steps): + ob_latency_start = time.time() + action = rpc_sync(agent_rref.owner(), self.select_action, args=( + agent_rref, self.id, state)) + + ob_latency = time.time() - ob_latency_start + observer_latencies.append(ob_latency) + observer_throughput.append(1 / ob_latency) + + state, reward = self.step(action) + rewards[st] = reward + + return [rewards, ep_reward, observer_latencies, observer_throughput] diff --git a/benchmarks/fastrnns/cells.py b/benchmarks/fastrnns/cells.py index fc910300b0bed..6e797b9e2d1cb 100644 --- a/benchmarks/fastrnns/cells.py +++ b/benchmarks/fastrnns/cells.py @@ -1,4 +1,6 @@ import torch +from typing import Tuple +from torch import Tensor def milstm_cell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias): @@ -22,8 +24,8 @@ def milstm_cell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias): return hy, cy -def lstm_cell(input, hidden, w_ih, w_hh, b_ih, b_hh): - # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] +def lstm_cell(input: Tensor, hidden: Tuple[Tensor, Tensor], w_ih: Tensor, + w_hh: Tensor, b_ih: Tensor, b_hh: Tensor) -> Tuple[Tensor, Tensor]: hx, cx = hidden gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh @@ -40,8 +42,8 @@ def lstm_cell(input, hidden, w_ih, w_hh, b_ih, b_hh): return hy, cy -def flat_lstm_cell(input, hx, cx, w_ih, w_hh, b_ih, b_hh): - # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] +def flat_lstm_cell(input: Tensor, hx: Tensor, cx: Tensor, w_ih: Tensor, + w_hh: Tensor, b_ih: Tensor, b_hh: Tensor) -> Tuple[Tensor, Tensor]: gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) @@ -57,8 +59,8 @@ def flat_lstm_cell(input, hx, cx, w_ih, w_hh, b_ih, b_hh): return hy, cy -def premul_lstm_cell(igates, hidden, w_hh, b_ih, b_hh): - # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] +def premul_lstm_cell(igates: Tensor, hidden: Tuple[Tensor, Tensor], w_hh: Tensor, + b_ih: Tensor, b_hh: Tensor) -> Tuple[Tensor, Tensor]: hx, cx = hidden gates = igates + torch.mm(hx, w_hh.t()) + b_ih + b_hh @@ -75,8 +77,7 @@ def premul_lstm_cell(igates, hidden, w_hh, b_ih, b_hh): return hy, cy -def premul_lstm_cell_no_bias(igates, hidden, w_hh, b_hh): - # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor) -> Tuple[Tensor, Tensor] +def premul_lstm_cell_no_bias(igates: Tensor, hidden: Tuple[Tensor, Tensor], w_hh: Tensor, b_hh: Tensor) -> Tuple[Tensor, Tensor]: hx, cx = hidden gates = igates + torch.mm(hx, w_hh.t()) + b_hh diff --git a/benchmarks/fastrnns/custom_lstms.py b/benchmarks/fastrnns/custom_lstms.py index d835b3e533f15..60abb1ac574cf 100644 --- a/benchmarks/fastrnns/custom_lstms.py +++ b/benchmarks/fastrnns/custom_lstms.py @@ -86,8 +86,7 @@ def script_lnlstm(input_size, hidden_size, num_layers, bias=True, LSTMState = namedtuple('LSTMState', ['hx', 'cx']) -def reverse(lst): - # type: (List[Tensor]) -> List[Tensor] +def reverse(lst: List[Tensor]) -> List[Tensor]: return lst[::-1] @@ -102,8 +101,7 @@ def __init__(self, input_size, hidden_size): self.bias_hh = Parameter(torch.randn(4 * hidden_size)) @jit.script_method - def forward(self, input, state): - # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: hx, cx = state gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih + torch.mm(hx, self.weight_hh.t()) + self.bias_hh) @@ -165,8 +163,7 @@ def __init__(self, input_size, hidden_size, decompose_layernorm=False): self.layernorm_c = ln(hidden_size) @jit.script_method - def forward(self, input, state): - # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: hx, cx = state igates = self.layernorm_i(torch.mm(input, self.weight_ih.t())) hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t())) @@ -190,8 +187,7 @@ def __init__(self, cell, *cell_args): self.cell = cell(*cell_args) @jit.script_method - def forward(self, input, state): - # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: inputs = input.unbind(0) outputs = torch.jit.annotate(List[Tensor], []) for i in range(len(inputs)): @@ -206,8 +202,7 @@ def __init__(self, cell, *cell_args): self.cell = cell(*cell_args) @jit.script_method - def forward(self, input, state): - # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: inputs = reverse(input.unbind(0)) outputs = jit.annotate(List[Tensor], []) for i in range(len(inputs)): @@ -227,8 +222,7 @@ def __init__(self, cell, *cell_args): ]) @jit.script_method - def forward(self, input, states): - # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]] + def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: # List[LSTMState]: [forward LSTMState, backward LSTMState] outputs = jit.annotate(List[Tensor], []) output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) @@ -258,8 +252,7 @@ def __init__(self, num_layers, layer, first_layer_args, other_layer_args): other_layer_args) @jit.script_method - def forward(self, input, states): - # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]] + def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: # List[LSTMState]: One state per layer output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) output = input @@ -286,8 +279,7 @@ def __init__(self, num_layers, layer, first_layer_args, other_layer_args): other_layer_args) @jit.script_method - def forward(self, input, states): - # type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]] + def forward(self, input: Tensor, states: List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]: # List[List[LSTMState]]: The outer list is for layers, # inner list is for directions. output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], []) @@ -322,8 +314,7 @@ def __init__(self, num_layers, layer, first_layer_args, other_layer_args): self.dropout_layer = nn.Dropout(0.4) @jit.script_method - def forward(self, input, states): - # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]] + def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: # List[LSTMState]: One state per layer output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) output = input diff --git a/benchmarks/fastrnns/factory.py b/benchmarks/fastrnns/factory.py index 056bcd746aec8..91ac39b06a84d 100644 --- a/benchmarks/fastrnns/factory.py +++ b/benchmarks/fastrnns/factory.py @@ -1,6 +1,8 @@ import torch from collections import namedtuple +from typing import List, Tuple +from torch import Tensor from .cells import lstm_cell, premul_lstm_cell, premul_lstm_cell_no_bias, flat_lstm_cell @@ -234,8 +236,10 @@ def forward(sequences, hidden): def varlen_lstm_factory(cell, script): - def dynamic_rnn(sequences, hiddens, wih, whh, bih, bhh): - # type: (List[Tensor], Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]] # noqa + def dynamic_rnn(sequences: List[Tensor], hiddens: Tuple[Tensor, Tensor], wih: Tensor, + whh: Tensor, bih: Tensor, bhh: Tensor + ) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]: + # noqa hx, cx = hiddens hxs = hx.unbind(1) cxs = cx.unbind(1) @@ -359,8 +363,8 @@ def lstm_inputs(seqLength=100, numLayers=1, inputSize=512, hiddenSize=512, def lstm_factory(cell, script): - def dynamic_rnn(input, hidden, wih, whh, bih, bhh): - # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + def dynamic_rnn(input: Tensor, hidden: Tuple[Tensor, Tensor], wih: Tensor, whh: Tensor, + bih: Tensor, bhh: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: hx, cx = hidden outputs = [] inputs = input.unbind(0) @@ -379,8 +383,8 @@ def dynamic_rnn(input, hidden, wih, whh, bih, bhh): # premul: we're going to premultiply the inputs & weights def lstm_factory_premul(premul_cell, script): - def dynamic_rnn(input, hidden, wih, whh, bih, bhh): - # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + def dynamic_rnn(input: Tensor, hidden: Tuple[Tensor, Tensor], wih: Tensor, whh: Tensor, + bih: Tensor, bhh: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: hx, cx = hidden outputs = [] inputs = torch.matmul(input, wih.t()).unbind(0) @@ -399,8 +403,8 @@ def dynamic_rnn(input, hidden, wih, whh, bih, bhh): # premul: we're going to premultiply the inputs & weights, and add bias def lstm_factory_premul_bias(premul_cell, script): - def dynamic_rnn(input, hidden, wih, whh, bih, bhh): - # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + def dynamic_rnn(input: Tensor, hidden: Tuple[Tensor, Tensor], wih: Tensor, whh: Tensor, + bih: Tensor, bhh: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: hx, cx = hidden outputs = [] inpSize = input.size() @@ -442,8 +446,7 @@ def dynamic_rnn(input, hx, cx, wih, whh, bih, bhh): def lstm_factory_multilayer(cell, script): - def dynamic_rnn(input, hidden, params): - # type: (Tensor, Tuple[Tensor, Tensor], List[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + def dynamic_rnn(input: Tensor, hidden: Tuple[Tensor, Tensor], params: List[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: params_stride = 4 # NB: this assumes that biases are there hx, cx = hidden hy, cy = hidden # for scoping... diff --git a/benchmarks/fastrnns/fuser.py b/benchmarks/fastrnns/fuser.py index 620c19a13cf1a..e1daab594c508 100644 --- a/benchmarks/fastrnns/fuser.py +++ b/benchmarks/fastrnns/fuser.py @@ -1,12 +1,10 @@ import torch def set_fuser(fuser_name, executor_name): - assert fuser_name in ['te', 'old', 'none'] + assert fuser_name in ['te', 'old', 'none', 'default'] if fuser_name == 'te': torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_mode(True) - torch._C._jit_set_bailout_depth(20) - torch._C._jit_set_num_profiled_runs(2) torch._C._jit_override_can_fuse_on_cpu(False) torch._C._jit_override_can_fuse_on_gpu(True) torch._C._jit_set_texpr_fuser_enabled(True) @@ -21,16 +19,18 @@ def set_fuser(fuser_name, executor_name): torch._C._jit_override_can_fuse_on_gpu(False) torch._C._jit_override_can_fuse_on_cpu(False) torch._C._jit_set_texpr_fuser_enabled(False) + elif fuser_name == 'default': + pass # --executor overrides settings of --fuser if executor_name == 'profiling': torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_mode(True) - torch._C._jit_set_bailout_depth(20) - torch._C._jit_set_num_profiled_runs(2) elif executor_name == 'simple': torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_mode(False) elif executor_name == 'legacy': torch._C._jit_set_profiling_executor(False) torch._C._jit_set_profiling_mode(False) + elif executor_name == 'default': + pass diff --git a/benchmarks/functional_autograd_benchmark/ppl_models.py b/benchmarks/functional_autograd_benchmark/ppl_models.py index 906ebac5d41b6..94ba6698a91dc 100644 --- a/benchmarks/functional_autograd_benchmark/ppl_models.py +++ b/benchmarks/functional_autograd_benchmark/ppl_models.py @@ -24,8 +24,9 @@ def forward(beta_value: Tensor) -> Tensor: mu = X.mm(beta_value) # We need to compute the first and second gradient of this score with respect - # to beta_value. - score = dist.Bernoulli(logits=mu).log_prob(Y).sum() + beta_prior.log_prob(beta_value).sum() + # to beta_value. We disable Bernoulli validation because Y is a relaxed value. + score = (dist.Bernoulli(logits=mu, validate_args=False).log_prob(Y).sum() + + beta_prior.log_prob(beta_value).sum()) return score return forward, (beta_value.to(device),) @@ -40,7 +41,7 @@ def get_robust_regression(device: torch.device) -> GetterReturnType: Y = torch.rand(N, 1, device=device) # Predefined nu_alpha and nu_beta, nu_alpha.shape: (1, 1), nu_beta.shape: (1, 1) - nu_alpha = torch.randn(1, 1, device=device) + nu_alpha = torch.rand(1, 1, device=device) nu_beta = torch.rand(1, 1, device=device) nu = dist.Gamma(nu_alpha, nu_beta) diff --git a/benchmarks/functional_autograd_benchmark/torchvision_models.py b/benchmarks/functional_autograd_benchmark/torchvision_models.py index 25361af77661d..c1d9eaf6105a0 100644 --- a/benchmarks/functional_autograd_benchmark/torchvision_models.py +++ b/benchmarks/functional_autograd_benchmark/torchvision_models.py @@ -247,7 +247,7 @@ class IntermediateLayerGetter(nn.ModuleDict): Additionally, it is only able to query submodules that are directly assigned to the model. So if `model` is passed, `model.feature1` can be returned, but not `model.feature1.layer2`. - Arguments: + Args: model (nn.Module): model on which we will extract the features return_layers (Dict[name, new_name]): a dict containing the names of the modules for which the activations will be returned as @@ -324,7 +324,7 @@ def forward(self, x): class FCN(_SimpleSegmentationModel): """ Implements a Fully-Convolutional Network for semantic segmentation. - Arguments: + Args: backbone (nn.Module): the network used to compute the features for the model. The backbone should return an OrderedDict[Tensor], with the key being "out" for the last feature map used, and "aux" if an auxiliary classifier @@ -509,7 +509,7 @@ def box_area(boxes): """ Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates. - Arguments: + Args: boxes (Tensor[N, 4]): boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format Returns: diff --git a/benchmarks/operator_benchmark/README.md b/benchmarks/operator_benchmark/README.md index 95e0e46bf79a2..2f170ab847ddf 100644 --- a/benchmarks/operator_benchmark/README.md +++ b/benchmarks/operator_benchmark/README.md @@ -60,13 +60,15 @@ add_short_configs = op_bench.cross_product_configs( ) class AddBenchmark(op_bench.TorchBenchmarkBase): - def init(self, M, N, K): - self.input_one = torch.rand(M, N, K) - self.input_two = torch.rand(M, N, K) + def init(self, M, N, K, device): + self.inputs = { + "input_one": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()), + "input_two": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) + } self.set_module_name("add") - def forward(self): - return torch.add(self.input_one, self.input_two) + def forward(self, input_one, input_two): + return torch.add(input_one, input_two) op_bench.generate_pt_test(add_short_configs, AddBenchmark) ``` @@ -174,14 +176,15 @@ add_short_configs = op_bench.config_list( ) class AddBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, K, device): + self.inputs = { + "input_one": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()), + "input_two": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) + } + self.set_module_name("add") -    def init(self, M, N, K): -        self.input_one = torch.rand(M, N, K) -        self.input_two = torch.rand(M, N, K) -        self.set_module_name("add") - -    def forward(self): -        return torch.add(self.input_one, self.input_two) + def forward(self, input_one, input_two): + return torch.add(input_one, input_two) op_bench.generate_pt_test(add_long_configs + add_short_configs, AddBenchmark) @@ -218,26 +221,28 @@ Let's look at it in detail: #### Part 2. Create Tensors and Add Computation After inputs are provided, we now look at adding the computation of an operator. Adding a new operator requires implementing a new `TorchBenchmarkBase` subclass. Every new class is required to implement 2 methods: -* `init` is used to create tensors based on the inputs we provided before. In this example, the parameters to `init` are `M, N, and K` which have been specified in the input configuration. -* `forward` includes the operator to be tested and the computation based on the created tensors in `init`. Besides the object itself, it doesn't take any additional parameters.  +* `init` is used to create tensors based on the inputs we provided before. In this example, the parameters to `init` are `M, N, and K` which have been specified in the input configuration. `init` also packed all the needed inputs together into a dictionary `self.inputs` which will be provided to `forward` as arguments for running the benchmark. +* `forward` includes the operator to be tested and the computation based on the created tensors in `init`. Apart from `self`, the order of the arguments must match the entries specified in `self.inputs`.   The example below shows the code for `torch.add`:   ``` # Given one set of M, N, K, the init method creates input tensors based on # that. The forward method does torch.add calculation on those input tensors. -class AddBenchmark(op_bench.TorchBenchmarkBase): -    def init(self, M, N, K): +class AddBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, K, device): # this is the method where you need to create tensors # M, N, and K can be in different order, but they must match with # names in the configs. -        self.input_one = torch.rand(M, N, K) -        self.input_two = torch.rand(M, N, K) -        self.set_module_name("add") + self.inputs = { + "input_one": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()), + "input_two": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) + } + self.set_module_name("add") -    def forward(self): + def forward(self, input_one, input_two): # this is the method to have operator and do computation -        return torch.add(self.input_one, self.input_two) + return torch.add(input_one, input_two) ``` #### Part 3. Register Tests With the Benchmark Suite @@ -336,15 +341,16 @@ unary_ops_list = op_bench.op_list( ) class UnaryOpBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, device, op_func): + self.inputs = { + "input": torch.rand(M, N, device=device) + } + self.op_func = op_func -    def init(self, M, N, op_func): -        self.input_one = torch.rand(M, N) -        self.op_func = op_func + def forward(self, input): + return self.op_func(input) -    def forward(self): -        return self.op_func(self.input_one) - -op_bench.generate_pt_tests_from_list(unary_ops_list, unary_ops_configs, UnaryOpBenchmark) +op_bench.generate_pt_tests_from_op_list(unary_ops_list, unary_ops_configs, UnaryOpBenchmark) if __name__ == "__main__":     op_bench.benchmark_runner.main() @@ -371,27 +377,28 @@ unary_ops_list = op_bench.op_list( In this example, both operators share the same input so we only need to implement one TorchBenchmakrBase subclass.  Every new subclass is required to implement 3 methods: * `init` is used to create tensors and set the operator name and function. In this example, the parameters to `init` are `M`, `N`, and `op_func` which have been specified in the configurations. -* `forward` includes the operator to be tested and the computation based on the created tensors in `init`. Besides the object itself, it doesn't take any additional parameters.  +* `forward` includes the operator to be tested and the computation based on the created tensors in `init`. Apart from `self`, the order of the arguments must match the entries specified in `self.inputs`. Here is the code for `abs` and `acos`: ``` class UnaryOpBenchmark(op_bench.TorchBenchmarkBase): - -    def init(self, M, N, op_func): + def init(self, M, N, device, op_func): # The M and N match with the attr_names in the input configuration # The op_func matches with the attr_name in the ops configuration -        self.input_one = torch.rand(M, N) -        self.op_func = op_func + self.inputs = { + "input": torch.rand(M, N, device=device) + } + self.op_func = op_func -    def forward(self): -        return self.op_func(self.input_one) + def forward(self, input): + return self.op_func(input) ``` #### Part 3. Register a List of Operators -To register multiple operators, we introduced the `generate_pt_tests_from_list` function which takes three parameters. First, the list of operators. Second,the configs. Third, the benchmark class.   +To register multiple operators, we introduced the `generate_pt_tests_from_op_list` function which takes three parameters. First, the list of operators. Second,the configs. Third, the benchmark class.   Here is an example: ``` -op_bench.generate_pt_tests_from_list(unary_ops_list, unary_ops_configs, UnaryOpBenchmark) +op_bench.generate_pt_tests_from_op_list(unary_ops_list, unary_ops_configs, UnaryOpBenchmark) ``` diff --git a/benchmarks/operator_benchmark/benchmark_all_other_test.py b/benchmarks/operator_benchmark/benchmark_all_other_test.py index 4ea7ab47a4c26..adaf8a09ee960 100644 --- a/benchmarks/operator_benchmark/benchmark_all_other_test.py +++ b/benchmarks/operator_benchmark/benchmark_all_other_test.py @@ -2,9 +2,10 @@ from pt import ( # noqa add_test, as_strided_test, batchnorm_test, binary_test, cat_test, # noqa channel_shuffle_test, chunk_test, conv_test, diag_test, embeddingbag_test, # noqa - fill_test, gather_test, linear_test, matmul_test, pool_test, # noqa + fill_test, gather_test, linear_test, matmul_test, nan_to_num_test, pool_test, # noqa softmax_test, hardsigmoid_test, hardswish_test, layernorm_test, # noqa - groupnorm_test, instancenorm_test # noqa + groupnorm_test, instancenorm_test, remainder_test, softmax_test, # noqa + split_test, sum_test, tensor_to_test # noqa ) if __name__ == "__main__": diff --git a/benchmarks/operator_benchmark/benchmark_all_quantized_test.py b/benchmarks/operator_benchmark/benchmark_all_quantized_test.py index 076a2685f61ed..d0f5f9ff7896f 100644 --- a/benchmarks/operator_benchmark/benchmark_all_quantized_test.py +++ b/benchmarks/operator_benchmark/benchmark_all_quantized_test.py @@ -18,6 +18,7 @@ quantization_test, qunary_test, qembedding_pack_test, + qembeddingbag_test, ) diff --git a/benchmarks/operator_benchmark/benchmark_caffe2.py b/benchmarks/operator_benchmark/benchmark_caffe2.py index c374db61431ef..b0534bd9722dd 100644 --- a/benchmarks/operator_benchmark/benchmark_caffe2.py +++ b/benchmarks/operator_benchmark/benchmark_caffe2.py @@ -3,7 +3,7 @@ from caffe2.proto import caffe2_pb2 import benchmark_utils from collections import namedtuple -from benchmark_test_generator import _generate_test +from benchmark_test_generator import _register_test """Caffe2 performance microbenchmarks. @@ -50,10 +50,15 @@ def tensor(self, shapes, dtype='float32', device='cpu'): Return: C2 tensor of dtype """ + return self.feed_tensor(benchmark_utils.numpy_random(dtype, *shapes), device) + + def feed_tensor(self, tensor, device='cpu'): + """ Similar to tensor, but can supply any data compatible with FeedBlob + """ blob_name = 'blob_' + str(Caffe2BenchmarkBase.tensor_index) dev = self._device_option(device) with core.DeviceScope(dev): - workspace.FeedBlob(blob_name, benchmark_utils.numpy_random(dtype, *shapes)) + workspace.FeedBlob(blob_name, tensor) Caffe2BenchmarkBase.tensor_index += 1 return blob_name @@ -93,6 +98,10 @@ def test_name(self, name_type="long", **kargs): Caffe2BenchmarkBase.test_index += 1 return name + def extract_inputs_tuple(self): + # add a dummy function here to match the interface of TorchBenchmarkBase + pass + class Caffe2OperatorTestCase(object): """ This class includes all the information needed to benchmark an operator. @@ -107,7 +116,7 @@ def __init__(self, op_bench, test_config): self.test_config = test_config self.framework = "Caffe2" - def run_forward(self, num_runs, print_per_iter=False): + def run_forward(self, num_runs, print_per_iter=False, cuda_sync=False): """ Run the forward path of an operator in a loop """ with core.DeviceScope(self.op_bench.dev): @@ -115,7 +124,7 @@ def run_forward(self, num_runs, print_per_iter=False): if not workspace.RunOperatorMultiple(op, num_runs): raise ValueError("Unable to run operator test case: {}".format(self.test_name)) - def run_backward(self, num_runs): + def run_backward(self, num_runs, print_per_iter=False): """ Run the backward path of an operator in a loop """ with core.DeviceScope(self.op_bench.dev): @@ -185,12 +194,12 @@ def generate_c2_test_from_ops(ops_metadata, bench_op, tags): def generate_c2_test(configs, c2_bench_op): """ This function creates Caffe2 op test based on the given operator """ - return _generate_test(configs, c2_bench_op, create_caffe2_op_test_case, - run_backward=False) + return _register_test(configs, c2_bench_op, create_caffe2_op_test_case, + False) def generate_c2_gradient_test(configs, c2_bench_op): """ This function creates Caffe2 op test based on the given operator """ - return _generate_test(configs, c2_bench_op, create_caffe2_op_test_case, - run_backward=True) + return _register_test(configs, c2_bench_op, create_caffe2_op_test_case, + True) diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py index ed8ecfcac375b..10d08b100d662 100644 --- a/benchmarks/operator_benchmark/benchmark_core.py +++ b/benchmarks/operator_benchmark/benchmark_core.py @@ -118,6 +118,13 @@ def _build_test(configs, bench_op, OperatorTestCase, run_backward, op_name_funct op._set_backward_test(run_backward) op.init(**init_dict) + op.extract_inputs_tuple() + + if not run_backward: + for _, attr in vars(op).items(): + if isinstance(attr, torch.nn.Module): + for param in attr.parameters(): + param.requires_grad = False input_name = None @@ -244,7 +251,7 @@ def _iteration_result_is_significant(self, iters, run_time_sec, curr_test_total_ def _launch_forward(self, test_case, iters, print_per_iter): """ Use Python's timeit module to measure execution time (unit: second). """ - cuda_sync = True if 'cuda' in test_case.test_config.test_name else False + cuda_sync = 'cuda' in test_case.test_config.test_name func = test_case.run_forward if self.use_jit: func = test_case.run_jit_forward diff --git a/benchmarks/operator_benchmark/benchmark_pytorch.py b/benchmarks/operator_benchmark/benchmark_pytorch.py index 4d927d73bfc04..2203a0af2ec38 100644 --- a/benchmarks/operator_benchmark/benchmark_pytorch.py +++ b/benchmarks/operator_benchmark/benchmark_pytorch.py @@ -1,7 +1,7 @@ import time import json import torch -import torch.utils.cpp_extension as cpp_extension # noqa +import cpp_extension # noqa """PyTorch performance microbenchmarks. @@ -10,7 +10,7 @@ microbenchmarks. """ -class TorchBenchmarkBase(object): +class TorchBenchmarkBase(torch.nn.Module): """ This is a base class used to create Pytorch operator benchmark. module_name is the name of the operator being benchmarked. test_name is the name (it's created by concatenating all the @@ -18,8 +18,8 @@ class TorchBenchmarkBase(object): """ def __init__(self): + super(TorchBenchmarkBase, self).__init__() self.user_given_name = None - self._jit_forward = None self._pass_count = 0 self._num_inputs_require_grads = 0 @@ -49,32 +49,26 @@ def auto_set(self): self._auto_set_counter += 1 return (self._pass_count == self._auto_set_counter) - def forward(self): - pass + def extract_inputs_tuple(self): + self.inputs_tuple = tuple(self.inputs.values()) - def _wrap_forward(self, foo): - """ The function passed to JIT trace must have at least one argument, - this function is to wrap the forward method to meet that requirement. - _consume op is used to avoid the dead-code-elimination optimization - in JIT. - """ - return torch.ops.operator_benchmark._consume(self.forward()) - - def _generate_jit_forward_graph(self): - """ generate a graph for the forward function via tracing - """ + @torch.jit.export + def get_inputs(self): + # Need to convert the inputs to tuple outside of JIT so that + # JIT can infer the size of the inputs. + return self.inputs_tuple - func = torch.jit.trace(self._wrap_forward, torch.rand(1)) - place_holder = torch.rand(1) # noqa + @torch.jit.export + def forward_impl(self): + # This is to supply the inputs to the forward function which + # will be called in both the eager and JIT mode of local runs + return self.forward(*self.get_inputs()) - @torch.jit.script - def _jit_forward_graph(iters, place_holder): - # type: (int, Tensor) - result = torch.jit.annotate(torch.Tensor, place_holder) - for _ in range(iters): - result = func(place_holder) - return result - return _jit_forward_graph + @torch.jit.export + def forward_consume(self, iters: int): + # _consume is used to avoid the dead-code-elimination optimization + for _ in range(iters): + torch.ops.operator_benchmark._consume(self.forward_impl()) def module_name(self): """ this is used to label the operator being benchmarked @@ -121,13 +115,20 @@ def __init__(self, op_bench, test_config): self.place_holder_tensor = torch.ones(1) self.framework = "PyTorch" self.time_series = [] + self._jit_forward_graph = None + + def _generate_jit_forward_graph(self): + """ generate a graph for the forward function via scripting + """ + scripted_op_bench = torch.jit.script(self.op_bench) + return scripted_op_bench.forward_consume def run_jit_forward(self, num_runs, print_per_iter=False, cuda_sync=False): """ Run the forward path of an op with JIT mode """ - if self.op_bench._jit_forward is None: - self.op_bench._jit_forward = self.op_bench._generate_jit_forward_graph() - self.op_bench._jit_forward(num_runs, self.place_holder_tensor) + if self._jit_forward_graph is None: + self._jit_forward_graph = self._generate_jit_forward_graph() + self._jit_forward_graph(num_runs) def _print_per_iter(self): # print last 50 values @@ -148,15 +149,15 @@ def run_forward(self, num_runs, print_per_iter, cuda_sync): if print_per_iter: for _ in range(num_runs): start_time = time.time() - self.output = self.op_bench.forward() - if cuda_sync: + self.output = self.op_bench.forward_impl() + if cuda_sync: torch.cuda.synchronize(torch.cuda.current_device()) end_time = time.time() self.time_series.append((end_time - start_time) * 1e3) else: for _ in range(num_runs): - self.output = self.op_bench.forward() - if cuda_sync: + self.output = self.op_bench.forward_impl() + if cuda_sync: torch.cuda.synchronize(torch.cuda.current_device()) def _output_mean(self): diff --git a/benchmarks/operator_benchmark/benchmark_runner.py b/benchmarks/operator_benchmark/benchmark_runner.py index 1a3ec19d7ece2..b9347364428ea 100644 --- a/benchmarks/operator_benchmark/benchmark_runner.py +++ b/benchmarks/operator_benchmark/benchmark_runner.py @@ -10,14 +10,12 @@ This is the main function for running performance microbenchmark tests. It also registers existing benchmark tests via Python module imports. """ +parser = argparse.ArgumentParser( + description="Run microbenchmarks.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, +) - -def main(): - parser = argparse.ArgumentParser( - description="Run microbenchmarks.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - +def parse_args(): parser.add_argument( '--tag_filter', help='tag_filter can be used to run the shapes which matches the tag. (all is used to run all the shapes)', @@ -145,6 +143,10 @@ def main(): if args.mkl_num_threads: benchmark_utils.set_mkl_threads(args.mkl_num_threads) + return args + +def main(): + args = parse_args() benchmark_core.BenchmarkRunner(args).run() diff --git a/benchmarks/operator_benchmark/benchmark_test_generator.py b/benchmarks/operator_benchmark/benchmark_test_generator.py index 6dd8150dfccd9..ec60c33c205a1 100644 --- a/benchmarks/operator_benchmark/benchmark_test_generator.py +++ b/benchmarks/operator_benchmark/benchmark_test_generator.py @@ -37,3 +37,7 @@ def forward(self): """ for op in ops_list: _register_test(configs, pt_bench_op, create_pytorch_op_test_case, False, op) + +def generate_pt_gradient_tests_from_op_list(ops_list, configs, pt_bench_op): + for op in ops_list: + _register_test(configs, pt_bench_op, create_pytorch_op_test_case, True, op) diff --git a/benchmarks/operator_benchmark/c2/batch_box_cox_test.py b/benchmarks/operator_benchmark/c2/batch_box_cox_test.py new file mode 100644 index 0000000000000..958828a01d0c8 --- /dev/null +++ b/benchmarks/operator_benchmark/c2/batch_box_cox_test.py @@ -0,0 +1,46 @@ +import benchmark_caffe2 as op_bench_c2 +import operator_benchmark as op_bench +from benchmark_caffe2 import Caffe2BenchmarkBase # noqa +from caffe2.python import core + + +"""Microbenchmarks for BatchBoxCox operator.""" + +# Configs for C2 BatchBoxCox operator +batch_box_cox_long_configs = op_bench.cross_product_configs( + M=[32, 64, 128], N=range(32, 128, 32), dtype=["float", "double"], tags=["long"] +) + + +batch_box_cox_short_configs = op_bench.config_list( + attrs=[ + [16, 16, "float"], + [16, 16, "double"], + [64, 64, "float"], + [64, 64, "double"], + ], + attr_names=["M", "N", "dtype"], + tags=["short"], +) + + +class BatchBoxCoxBenchmark(op_bench_c2.Caffe2BenchmarkBase): + def init(self, M, N, dtype): + self.data = self.tensor([M, N], dtype) + self.lambda1 = self.tensor([N], dtype) + self.lambda2 = self.tensor([N], dtype) + self.output = self.tensor([1, 1], dtype) + self.set_module_name("batch_box_cox") + + def forward(self): + op = core.CreateOperator("BatchBoxCox", [self.data, self.lambda1, self.lambda2], self.output) + return op + + +op_bench_c2.generate_c2_test( + batch_box_cox_long_configs + batch_box_cox_short_configs, BatchBoxCoxBenchmark +) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/batch_gather_test.py b/benchmarks/operator_benchmark/c2/batch_gather_test.py new file mode 100644 index 0000000000000..ff3d84b99b2bc --- /dev/null +++ b/benchmarks/operator_benchmark/c2/batch_gather_test.py @@ -0,0 +1,56 @@ +import benchmark_caffe2 as op_bench_c2 +import operator_benchmark as op_bench +from benchmark_caffe2 import Caffe2BenchmarkBase # noqa +from caffe2.python import core +import numpy + + +"""Microbenchmarks for element-wise BatchGather operator.""" + +# Configs for C2 BatherGather operator +batch_gather_configs_short = op_bench.config_list( + attr_names=["M", "N", "K"], + attrs=[ + [8, 8, 1], + [256, 512, 1], + [512, 512, 1], + [8, 8, 2], + [256, 512, 2], + [512, 512, 2], + ], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + }, + tags=["short"] +) + +batch_gather_configs_long = op_bench.cross_product_configs( + M=[128, 1024], + N=[128, 1024], + K=[1, 2], + device=['cpu', 'cuda'], + tags=["long"] +) + +class BatchGatherBenchmark(op_bench_c2.Caffe2BenchmarkBase): + def init(self, M, N, K, device): + self.input_one = self.tensor([M, N, K], device=device) + max_val = N + numpy.random.seed((1 << 32) - 1) + index_dim = numpy.random.randint(0, N) + self.index = self.feed_tensor(numpy.random.randint(0, max_val, index_dim), device=device) + self.output = self.tensor([M, index_dim, K], device=device) + self.set_module_name("batch_gather") + + def forward(self): + op = core.CreateOperator("BatchGather", [self.input_one, self.index], self.output) + return op + + +op_bench_c2.generate_c2_test( + batch_gather_configs_long + batch_gather_configs_short, BatchGatherBenchmark +) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/clip_ranges_test.py b/benchmarks/operator_benchmark/c2/clip_ranges_test.py new file mode 100644 index 0000000000000..2bb32f0624457 --- /dev/null +++ b/benchmarks/operator_benchmark/c2/clip_ranges_test.py @@ -0,0 +1,51 @@ +import benchmark_caffe2 as op_bench_c2 +import operator_benchmark as op_bench +from benchmark_caffe2 import Caffe2BenchmarkBase # noqa +from caffe2.python import core, dyndep + +dyndep.InitOpsLibrary("@/caffe2/caffe2/fb/operators:clip_ranges_op") + +"""Microbenchmarks for ClipRanges operator.""" + +# Configs for C2 ClipRanges operator +clip_ranges_long_configs = op_bench.cross_product_configs( + LENGTH=range(1, 100), + M=[1], + N=[2], + MAX_LENGTH=range(1, 100), + dtype=["int32"], + tags=["long"] +) + + +clip_ranges_short_configs = op_bench.config_list( + attrs=[ + [6, 1, 2, 1, "int32"], + [7, 1, 2, 2, "int32"], + [8, 1, 2, 3, "int32"], + [9, 1, 2, 4, "int32"], + [10, 1, 2, 5, "int32"], + ], + attr_names=["LENGTH", "M", "N", "MAX_LENGTH", "dtype"], + tags=["short"], +) + + +class ClipRangesBenchmark(op_bench_c2.Caffe2BenchmarkBase): + def init(self, LENGTH, M, N, MAX_LENGTH, dtype): + self.input = self.tensor([LENGTH, M, N], dtype) + self.max_length = MAX_LENGTH + self.set_module_name("clip_ranges") + + def forward(self): + op = core.CreateOperator("ClipRanges", self.input, self.input, max_length=self.max_length) + return op + + +op_bench_c2.generate_c2_test( + clip_ranges_long_configs + clip_ranges_short_configs, ClipRangesBenchmark +) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/concat_test.py b/benchmarks/operator_benchmark/c2/concat_test.py new file mode 100644 index 0000000000000..7e18c4a745a1f --- /dev/null +++ b/benchmarks/operator_benchmark/c2/concat_test.py @@ -0,0 +1,116 @@ +import operator_benchmark as op_bench +import benchmark_caffe2 as op_bench_c2 +import random +from benchmark_caffe2 import Caffe2BenchmarkBase # noqa +from caffe2.python import core + + +"""Microbenchmarks for Concat operator. Supports both Caffe2/PyTorch.""" + +# Configs for C2 concat operator +cat_configs_short = op_bench.config_list( + attr_names=['sizes', 'N', 'axis'], + attrs=[ + [(1, 1, 1), 2, 0], # noqa + [(512, 512, 2), 2, 1], # noqa + [(128, 1024, 2), 2, 1], # noqa + ], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + 'dtype': ['float'], + }, + tags=['short'], +) + +cat_configs_long = op_bench.config_list( + attr_names=['sizes', 'N', 'axis'], + attrs=[ + [(2**10, 2**10, 2), 2, 0], # noqa + [(2**10+1, 2**10-1, 2), 2, 1], # noqa + [(2**10, 2**10, 2), 2, 2], # noqa + + [[ lambda: random.randint(2**6, 2**7), 2**7-17, 2**6+1], # noqa + 5, 0], + [[ 2**6+2**5, lambda: random.randint(2**6, 2**7), 2**6], # noqa + 5, 1], + [[ 2**7, 2**6, lambda: random.randint(2**6, 2**7)], # noqa + 5, 2], + + [[lambda: random.randint(2**5, 2**6), 2**5, 2**6], # noqa + 50, 0], + [[2**5, lambda: random.randint(2**5, 2**6), 2**6], # noqa + 50, 1], + [[2**5+1, 2**6+1, lambda: random.randint(2**5, 2**6)], # noqa + 50, 2], + ], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + 'dtype': ['float'], + }, + tags=['long'], +) + +# There is a different codepath on CUDA for >4 dimensions +cat_configs_multidim = op_bench.config_list( + attr_names=['sizes', 'N', 'axis', 'dtype'], + attrs=[ + [(2**6, 2**5, 2**2, 2**4, 2**5), 2, 2], # noqa + [(2**4, 2**5, 2**2, 2**4, 2**5), 8, 2], # noqa + [(2**3+1, 2**5-1, 2**2+1, 2**4-1, 2**5+1), 17, 4], # noqa + ], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + 'dtype': ['float'], + }, + tags=['multidim'], +) + +cat_configs_manyinputs = op_bench.config_list( + attr_names=['sizes', 'N', 'axis', 'dtype'], + attrs=[ + [[lambda: random.randint(1, 10000)], 100, 0], + [[lambda: random.randint(1, 1000)], 1000, 0], + [[lambda: random.randint(1, 500)], 2000, 0], + [[lambda: random.randint(1, 300)], 3000, 0], + ], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + 'dtype': ['float'], + }, + tags=['manyinputs'], +) + + +class ConcatBenchmark(op_bench_c2.Caffe2BenchmarkBase): + def init(self, sizes, N, axis, dtype, device): + random.seed(42) + self.inputs = [] + self.args = {'axis': axis} + gen_sizes = [] + for i in range(N): + gen_sizes.append([old_size() if callable(old_size) else old_size + for old_size in sizes]) + + for s in gen_sizes: + self.inputs.append(self.tensor(s, dtype, device=device)) + + self.output = self.tensor(gen_sizes[0], dtype, device=device) + self.split_info = self.tensor(gen_sizes[0], "int") + self.set_module_name("concat") + + def forward(self): + op = core.CreateOperator( + "Concat", self.inputs, [self.output, self.split_info], **self.args + ) + return op + + +op_bench_c2.generate_c2_test(cat_configs_short + + cat_configs_long + + cat_configs_multidim + + cat_configs_manyinputs, + ConcatBenchmark) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/quantile_op_test.py b/benchmarks/operator_benchmark/c2/quantile_op_test.py new file mode 100644 index 0000000000000..f22384f8e0fd2 --- /dev/null +++ b/benchmarks/operator_benchmark/c2/quantile_op_test.py @@ -0,0 +1,47 @@ +import benchmark_caffe2 as op_bench_c2 +import operator_benchmark as op_bench +from benchmark_caffe2 import Caffe2BenchmarkBase # noqa +from caffe2.python import core + + +"""Microbenchmarks for QuantileOp operator.""" + +# Configs for C2 QuantileOp operator +quantile_op_long_configs = op_bench.cross_product_configs( + M=[32, 64, 128], N=range(32, 128, 32), dtype=["float", "double"], tags=["long"] +) + + +quantile_op_short_configs = op_bench.config_list( + attrs=[ + [16, 16, "float"], + [16, 16, "double"], + [64, 64, "float"], + [64, 64, "double"], + ], + attr_names=["M", "N", "dtype"], + tags=["short"], +) + + +class QuantileOpBenchmark(op_bench_c2.Caffe2BenchmarkBase): + def init(self, M, N, dtype): + self.data = [self.tensor([N], dtype) for _ in range(M)] + self.quantile = 0.3 + self.output = self.tensor([1], dtype) + self.set_module_name("quantile_op") + + def forward(self): + op = core.CreateOperator( + "Quantile", inputs=self.data, outputs=self.output, quantile=self.quantile + ) + return op + + +op_bench_c2.generate_c2_test( + quantile_op_long_configs + quantile_op_short_configs, QuantileOpBenchmark +) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/replace_nan_test.py b/benchmarks/operator_benchmark/c2/replace_nan_test.py new file mode 100644 index 0000000000000..f91c6f11c240f --- /dev/null +++ b/benchmarks/operator_benchmark/c2/replace_nan_test.py @@ -0,0 +1,43 @@ +import benchmark_caffe2 as op_bench_c2 +import operator_benchmark as op_bench +from benchmark_caffe2 import Caffe2BenchmarkBase # noqa +from caffe2.python import core + + +"""Microbenchmarks for element-wise ReplaceNaN operator.""" + +# Configs for C2 ReplaceNaN operator +replace_nan_long_configs = op_bench.cross_product_configs( + M=[32, 64, 128], N=range(32, 128, 32), dtype=["float", "double"], tags=["long"] +) + + +replace_nan_short_configs = op_bench.config_list( + attrs=[ + [16, 16, "float"], + [16, 16, "double"], + [64, 64, "float"], + [64, 64, "double"], + ], + attr_names=["M", "N", "dtype"], + tags=["short"], +) + + +class ReplaceNaNBenchmark(op_bench_c2.Caffe2BenchmarkBase): + def init(self, M, N, dtype): + self.input = self.tensor([M, N], dtype) + self.set_module_name("replace_nan") + + def forward(self): + op = core.CreateOperator("ReplaceNaN", self.input, self.input, value=1.0) + return op + + +op_bench_c2.generate_c2_test( + replace_nan_long_configs + replace_nan_short_configs, ReplaceNaNBenchmark +) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/add_test.py b/benchmarks/operator_benchmark/pt/add_test.py index 911c44b5436a4..de0277fc3de70 100644 --- a/benchmarks/operator_benchmark/pt/add_test.py +++ b/benchmarks/operator_benchmark/pt/add_test.py @@ -29,12 +29,14 @@ class AddBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, K, device): - self.input_one = torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) - self.input_two = torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) + self.inputs = { + "input_one": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()), + "input_two": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) + } self.set_module_name("add") - def forward(self): - return torch.add(self.input_one, self.input_two) + def forward(self, input_one, input_two): + return torch.add(input_one, input_two) # The generated test names based on add_short_configs will be in the following pattern: # add_M8_N16_K32_devicecpu @@ -53,13 +55,15 @@ def forward(self): class AddmmBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, K, device): - self.input_one = torch.rand(M, K, device=device, requires_grad=self.auto_set()) - self.mat1 = torch.rand(M, N, device=device, requires_grad=self.auto_set()) - self.mat2 = torch.rand(N, K, device=device, requires_grad=self.auto_set()) + self.inputs = { + "input_one": torch.rand(M, K, device=device, requires_grad=self.auto_set()), + "mat1": torch.rand(M, N, device=device, requires_grad=self.auto_set()), + "mat2": torch.rand(N, K, device=device, requires_grad=self.auto_set()) + } self.set_module_name("addmm") - def forward(self): - return torch.addmm(self.input_one, self.mat1, self.mat2) + def forward(self, input_one, mat1, mat2): + return torch.addmm(input_one, mat1, mat2) op_bench.generate_pt_test(add_long_configs + add_short_configs, AddmmBenchmark) op_bench.generate_pt_gradient_test(add_long_configs + add_short_configs, AddmmBenchmark) @@ -70,13 +74,15 @@ def forward(self): class AddrBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, device, dtype): - self.input_one = torch.rand((M, N), device=device, requires_grad=self.auto_set(), dtype=dtype) - self.vec1 = torch.rand((M,), device=device, requires_grad=self.auto_set(), dtype=dtype) - self.vec2 = torch.rand((N,), device=device, requires_grad=self.auto_set(), dtype=dtype) + self.inputs = { + "input_one": torch.rand((M, N), device=device, requires_grad=self.auto_set(), dtype=dtype), + "vec1": torch.rand((M,), device=device, requires_grad=self.auto_set(), dtype=dtype), + "vec2": torch.rand((N,), device=device, requires_grad=self.auto_set(), dtype=dtype) + } self.set_module_name("addr") - def forward(self): - return torch.addr(self.input_one, self.vec1, self.vec2) + def forward(self, input_one, vec1, vec2): + return torch.addr(input_one, vec1, vec2) addr_configs = op_bench.cross_product_configs( M=[8, 256], @@ -95,13 +101,15 @@ def forward(self): class AddbmmBenchmark(op_bench.TorchBenchmarkBase): def init(self, B, M, N, K, device): - self.input_one = torch.rand((M, N), device=device, requires_grad=self.auto_set()) - self.batch1 = torch.rand((B, M, K), device=device, requires_grad=self.auto_set()) - self.batch2 = torch.rand((B, K, N,), device=device, requires_grad=self.auto_set()) + self.inputs = { + "input_one": torch.rand((M, N), device=device, requires_grad=self.auto_set()), + "batch1": torch.rand((B, M, K), device=device, requires_grad=self.auto_set()), + "batch2": torch.rand((B, K, N,), device=device, requires_grad=self.auto_set()) + } self.set_module_name("addbmm") - def forward(self): - return torch.addbmm(self.input_one, self.batch1, self.batch2) + def forward(self, input_one, batch1, batch2): + return torch.addbmm(input_one, batch1, batch2) addbmm_configs = op_bench.cross_product_configs( B=[2, 100], diff --git a/benchmarks/operator_benchmark/pt/as_strided_test.py b/benchmarks/operator_benchmark/pt/as_strided_test.py index a43702c15e221..77eff29811be6 100644 --- a/benchmarks/operator_benchmark/pt/as_strided_test.py +++ b/benchmarks/operator_benchmark/pt/as_strided_test.py @@ -1,5 +1,6 @@ import operator_benchmark as op_bench import torch +from typing import List """Microbenchmarks for as_strided operator""" @@ -32,15 +33,19 @@ class As_stridedBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, size, stride, storage_offset, device): - self.input_one = torch.rand(M, N, device=device) - self.size = size - self.stride = stride - self.storage_offset = storage_offset + self.inputs = { + "input_one": torch.rand(M, N, device=device), + "size": size, + "stride": stride, + "storage_offset": storage_offset + } self.set_module_name('as_strided') - def forward(self): + def forward( + self, input_one, size: List[int], stride: List[int], storage_offset: int + ): return torch.as_strided( - self.input_one, self.size, self.stride, self.storage_offset) + input_one, size, stride, storage_offset) op_bench.generate_pt_test(as_strided_configs_short + as_strided_configs_long, diff --git a/benchmarks/operator_benchmark/pt/batchnorm_test.py b/benchmarks/operator_benchmark/pt/batchnorm_test.py index 7257be36b9f11..816bdcc553421 100644 --- a/benchmarks/operator_benchmark/pt/batchnorm_test.py +++ b/benchmarks/operator_benchmark/pt/batchnorm_test.py @@ -28,15 +28,17 @@ class BatchNormBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, K, device): - self.input_one = torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) - self.mean = torch.rand(N, device=device) - self.var = torch.rand(N, device=device) - self.weight = torch.rand(N, device=device) - self.bias = torch.rand(N, device=device) + self.inputs = { + "input_one": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()), + "mean": torch.rand(N, device=device), + "var": torch.rand(N, device=device), + "weight": torch.rand(N, device=device), + "bias": torch.rand(N, device=device) + } self.set_module_name("batchnorm") - def forward(self): - return F.batch_norm(self.input_one, self.mean, self.var, self.weight, self.bias) + def forward(self, input_one, mean, var, weight, bias): + return F.batch_norm(input_one, mean, var, weight, bias) op_bench.generate_pt_test(batchnorm_configs_short + batchnorm_configs_long, BatchNormBenchmark) diff --git a/benchmarks/operator_benchmark/pt/binary_test.py b/benchmarks/operator_benchmark/pt/binary_test.py index bd177775764e1..9650392deae77 100644 --- a/benchmarks/operator_benchmark/pt/binary_test.py +++ b/benchmarks/operator_benchmark/pt/binary_test.py @@ -29,12 +29,14 @@ class BinaryOpBcastBenchmark(op_bench.TorchBenchmarkBase): def init(self, in_one, in_two, dtype, device, op_func): - self.in_one = torch.randn(in_one, device=device).to(dtype=dtype) - self.in_two = torch.randn(in_two, device=device).to(dtype=dtype) + self.inputs = { + "in_one": torch.randn(in_one, device=device).to(dtype=dtype), + "in_two": torch.randn(in_two, device=device).to(dtype=dtype) + } self.op_func = op_func - def forward(self): - return self.op_func(self.in_one, self.in_two) + def forward(self, in_one, in_two): + return self.op_func(in_one, in_two) op_bench.generate_pt_tests_from_op_list(binary_ops_bcast_list, @@ -42,12 +44,15 @@ def forward(self): BinaryOpBcastBenchmark) +def copy(in1, in2): + return in1.copy_(in2) + # Benchmark ops performance without broadcast binary_ops_list = op_bench.op_list( attr_names=['op_name', 'op_func'], attrs=[ ['add', torch.add], - ['copy_', lambda in1, in2: in1.copy_(in2)], + ['copy_', copy], ], ) @@ -79,12 +84,14 @@ def forward(self): class BinaryOpBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, K, device, dtype_one, dtype_two, op_func): - self.input_one = torch.randn(M, N, K, device=device).to(dtype=dtype_one) - self.input_two = torch.randn(M, N, K, device=device).to(dtype=dtype_two) + self.inputs = { + "input_one": torch.randn(M, N, K, device=device).to(dtype=dtype_one), + "input_two": torch.randn(M, N, K, device=device).to(dtype=dtype_two) + } self.op_func = op_func - def forward(self): - return self.op_func(self.input_one, self.input_two) + def forward(self, input_one, input_two): + return self.op_func(input_one, input_two) op_bench.generate_pt_tests_from_op_list(binary_ops_list, diff --git a/benchmarks/operator_benchmark/pt/cat_test.py b/benchmarks/operator_benchmark/pt/cat_test.py index 97df910873660..c1022f296a2fb 100644 --- a/benchmarks/operator_benchmark/pt/cat_test.py +++ b/benchmarks/operator_benchmark/pt/cat_test.py @@ -1,6 +1,7 @@ import operator_benchmark as op_bench import torch import random +from typing import List """Microbenchmarks for Cat operator""" @@ -78,16 +79,19 @@ class CatBenchmark(op_bench.TorchBenchmarkBase): def init(self, sizes, N, dim, device): random.seed(42) - self.inputs = [] + inputs = [] for i in range(N): current_sizes = [old_size() if callable(old_size) else old_size for old_size in sizes] - self.inputs.append(torch.rand(current_sizes, device=device)) - self.dim = dim + inputs.append(torch.rand(current_sizes, device=device)) + self.inputs = { + "inputs": inputs, + "dim": dim + } self.set_module_name('cat') - def forward(self): - return torch.cat(self.inputs, dim=self.dim) + def forward(self, inputs: List[torch.Tensor], dim: int): + return torch.cat(inputs, dim=dim) op_bench.generate_pt_test(cat_configs_short + diff --git a/benchmarks/operator_benchmark/pt/channel_shuffle_test.py b/benchmarks/operator_benchmark/pt/channel_shuffle_test.py index 258bb6d69c042..87163f004b2db 100644 --- a/benchmarks/operator_benchmark/pt/channel_shuffle_test.py +++ b/benchmarks/operator_benchmark/pt/channel_shuffle_test.py @@ -36,16 +36,19 @@ class ChannelSHuffleBenchmark(op_bench.TorchBenchmarkBase): def init(self, batch_size, channels_per_group, height, width, groups, channel_last): - self.groups = groups channels = channels_per_group * groups data_shape = (batch_size, channels, height, width) - self.input_data = torch.rand(data_shape) + input_data = torch.rand(data_shape) if channel_last: - self.input_data = self.input_data.contiguous(memory_format=torch.channels_last) + input_data = input_data.contiguous(memory_format=torch.channels_last) + self.inputs = { + "input_data": input_data, + "groups": groups + } self.set_module_name('channel_shuffle') - def forward(self): - return torch.channel_shuffle(self.input_data, self.groups) + def forward(self, input_data, groups: int): + return torch.channel_shuffle(input_data, groups) op_bench.generate_pt_test(channel_shuffle_short_configs + channel_shuffle_long_configs, diff --git a/benchmarks/operator_benchmark/pt/chunk_test.py b/benchmarks/operator_benchmark/pt/chunk_test.py index 885301dfdcb0b..6c1148dbcdaab 100644 --- a/benchmarks/operator_benchmark/pt/chunk_test.py +++ b/benchmarks/operator_benchmark/pt/chunk_test.py @@ -30,12 +30,14 @@ class ChunkBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, chunks, device): - self.input_one = torch.rand(M, N, device=device) - self.chunks = chunks - self.set_module_name('chunk') - - def forward(self): - return torch.chunk(self.input_one, self.chunks) + self.inputs = { + "input_one": torch.rand(M, N, device=device), + "chunks": chunks + } + self.set_module_name("chunk") + + def forward(self, input_one, chunks: int): + return torch.chunk(input_one, chunks) op_bench.generate_pt_test(chunk_short_configs + chunks_long_configs, diff --git a/benchmarks/operator_benchmark/pt/clip_ranges_test.py b/benchmarks/operator_benchmark/pt/clip_ranges_test.py new file mode 100644 index 0000000000000..3b6b95d937867 --- /dev/null +++ b/benchmarks/operator_benchmark/pt/clip_ranges_test.py @@ -0,0 +1,54 @@ +import operator_benchmark as op_bench +import torch + + +"""Microbenchmarks for ClipRanges operator.""" +torch.ops.load_library("//caffe2/torch/fb/sparsenn:sparsenn_operators") + +# Configs for C2 ClipRanges operator +clip_ranges_long_configs = op_bench.cross_product_configs( + LENGTH=range(1, 100), + M=[1], + N=[2], + MAX_LENGTH=range(1, 100), + device=['cpu', 'cuda'], + dtype=[torch.int32], + tags=["long"], +) + + +clip_ranges_short_configs = op_bench.config_list( + attrs=[ + [6, 1, 2, 1, torch.int32], + [7, 1, 2, 2, torch.int32], + [8, 1, 2, 3, torch.int32], + [9, 1, 2, 4, torch.int32], + [10, 1, 2, 5, torch.int32], + ], + attr_names=["LENGTH", "M", "N", "MAX_LENGTH", "dtype"], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + }, + tags=["short"], +) + + +class ClipRangesBenchmark(op_bench.TorchBenchmarkBase): + def init(self, LENGTH, M, N, MAX_LENGTH, device, dtype): + self.inputs = { + "input": torch.rand(LENGTH, M, N, device=device).type(dtype), + "max_length": MAX_LENGTH + } + self.set_module_name("clip_ranges") + + def forward(self, input, max_length: int): + return torch.ops.fb.clip_ranges(input, max_length) + + +op_bench.generate_pt_test( + clip_ranges_long_configs + clip_ranges_short_configs, ClipRangesBenchmark +) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/conv_test.py b/benchmarks/operator_benchmark/pt/conv_test.py index 75c4f61290b3a..28b5118157168 100644 --- a/benchmarks/operator_benchmark/pt/conv_test.py +++ b/benchmarks/operator_benchmark/pt/conv_test.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn -from . import configs +from pt import configs """ Microbenchmarks for Conv1d and ConvTranspose1d operators. @@ -11,22 +11,26 @@ class Conv1dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, L, device): - self.input = torch.rand(N, IC, L, device=device) + self.inputs = { + "input": torch.rand(N, IC, L, device=device, requires_grad=self.auto_set()) + } self.conv1d = nn.Conv1d(IC, OC, kernel, stride=stride).to(device=device) self.set_module_name('Conv1d') - def forward(self): - return self.conv1d(self.input) + def forward(self, input): + return self.conv1d(input) class ConvTranspose1dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, L, device): - self.input = torch.rand(N, IC, L, device=device) + self.inputs = { + "input": torch.rand(N, IC, L, device=device) + } self.convtranspose1d = nn.ConvTranspose1d(IC, OC, kernel, stride=stride).to(device=device) self.set_module_name('ConvTranspose1d') - def forward(self): - return self.convtranspose1d(self.input) + def forward(self, input): + return self.convtranspose1d(input) op_bench.generate_pt_test(configs.conv_1d_configs_short + configs.conv_1d_configs_long, @@ -42,24 +46,28 @@ def forward(self): class Conv2dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, H, W, G, pad, device): - self.input = torch.rand(N, IC, H, W, device=device) + self.inputs = { + "input": torch.rand(N, IC, H, W, device=device) + } self.conv2d = nn.Conv2d( IC, OC, kernel, stride=stride, groups=G, padding=pad).to(device=device) self.set_module_name('Conv2d') - def forward(self): - return self.conv2d(self.input) + def forward(self, input): + return self.conv2d(input) class ConvTranspose2dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, H, W, G, pad, device): - self.input = torch.rand(N, IC, H, W, device=device) + self.inputs = { + "input": torch.rand(N, IC, H, W, device=device) + } self.convtranspose2d = nn.ConvTranspose2d( IC, OC, kernel, stride=stride, groups=G, padding=pad).to(device=device) self.set_module_name('ConvTranspose2d') - def forward(self): - return self.convtranspose2d(self.input) + def forward(self, input): + return self.convtranspose2d(input) op_bench.generate_pt_test(configs.conv_2d_configs_short + configs.conv_2d_configs_long, @@ -74,22 +82,26 @@ def forward(self): class Conv3dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, D, H, W, device): - self.input = torch.rand(N, IC, D, H, W, device=device) + self.inputs = { + "input": torch.rand(N, IC, D, H, W, device=device) + } self.conv3d = nn.Conv3d(IC, OC, kernel, stride=stride).to(device=device) self.set_module_name('Conv3d') - def forward(self): - return self.conv3d(self.input) + def forward(self, input): + return self.conv3d(input) class ConvTranspose3dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, D, H, W, device): - self.input = torch.rand(N, IC, D, H, W, device=device) + self.inputs = { + "input": torch.rand(N, IC, D, H, W, device=device) + } self.convtranspose3d = nn.ConvTranspose3d(IC, OC, kernel, stride=stride).to(device=device) self.set_module_name('ConvTranspose3d') - def forward(self): - return self.convtranspose3d(self.input) + def forward(self, input): + return self.convtranspose3d(input) op_bench.generate_pt_test(configs.conv_3d_configs_short, Conv3dBenchmark) diff --git a/benchmarks/operator_benchmark/pt/diag_test.py b/benchmarks/operator_benchmark/pt/diag_test.py index 3fa1054428955..79ad86d29510e 100644 --- a/benchmarks/operator_benchmark/pt/diag_test.py +++ b/benchmarks/operator_benchmark/pt/diag_test.py @@ -22,13 +22,19 @@ class DiagBenchmark(op_bench.TorchBenchmarkBase): def init(self, dim, M, N, diagonal, out, device): - self.input = torch.rand(M, N, device=device) if dim == 2 else torch.rand(M, device=device) - self.diagonal = diagonal - self.out = torch.tensor((),) if out else None + self.inputs = { + "input": torch.rand(M, N, device=device) if dim == 2 else torch.rand(M, device=device), + "diagonal": diagonal, + "out": out, + "out_tensor": torch.tensor((),) + } self.set_module_name('diag') - def forward(self): - return torch.diag(self.input, diagonal=self.diagonal, out=self.out) + def forward(self, input, diagonal: int, out: bool, out_tensor): + if out: + return torch.diag(input, diagonal=diagonal, out=out_tensor) + else: + return torch.diag(input, diagonal=diagonal) op_bench.generate_pt_test(diag_configs_short, DiagBenchmark) diff --git a/benchmarks/operator_benchmark/pt/embeddingbag_test.py b/benchmarks/operator_benchmark/pt/embeddingbag_test.py index 0ad8ce46348b7..a8c100a797217 100644 --- a/benchmarks/operator_benchmark/pt/embeddingbag_test.py +++ b/benchmarks/operator_benchmark/pt/embeddingbag_test.py @@ -1,7 +1,7 @@ import operator_benchmark as op_bench import torch import numpy -from . import configs +from pt import configs """EmbeddingBag Operator Benchmark""" @@ -14,13 +14,16 @@ def init(self, embeddingbags, dim, mode, input_size, offset, sparse, include_las include_last_offset=include_last_offset, sparse=sparse).to(device=device) numpy.random.seed((1 << 32) - 1) - self.input = torch.tensor(numpy.random.randint(0, embeddingbags, input_size), device=device).long() offsets = torch.LongTensor([offset], device=device) - self.offset = torch.cat((offsets, torch.tensor([self.input.size(0)], dtype=torch.long)), 0) + input = torch.tensor(numpy.random.randint(0, embeddingbags, input_size), device=device).long() + self.inputs = { + "input": input, + "offset": torch.cat((offsets, torch.tensor([input.size(0)], dtype=torch.long)), 0) + } self.set_module_name('embeddingbag') - def forward(self): - return self.embedding(self.input, self.offset) + def forward(self, input, offset): + return self.embedding(input, offset) op_bench.generate_pt_test(configs.embeddingbag_short_configs, EmbeddingBagBenchmark) op_bench.generate_pt_gradient_test(configs.embeddingbag_short_configs, EmbeddingBagBenchmark) diff --git a/benchmarks/operator_benchmark/pt/fill_test.py b/benchmarks/operator_benchmark/pt/fill_test.py index 5a162db9f5f5b..97f59394a66a4 100644 --- a/benchmarks/operator_benchmark/pt/fill_test.py +++ b/benchmarks/operator_benchmark/pt/fill_test.py @@ -28,11 +28,13 @@ class Fill_Benchmark(op_bench.TorchBenchmarkBase): def init(self, N, device, dtype): - self.input_one = torch.zeros(N, device=device).type(dtype) + self.inputs = { + "input_one": torch.zeros(N, device=device).type(dtype) + } self.set_module_name("fill_") - def forward(self): - return self.input_one.fill_(10) + def forward(self, input_one): + return input_one.fill_(10) op_bench.generate_pt_test(fill_short_configs + fill_long_configs, diff --git a/benchmarks/operator_benchmark/pt/gather_test.py b/benchmarks/operator_benchmark/pt/gather_test.py index 509c1b937c3a7..6538cb3a8b902 100644 --- a/benchmarks/operator_benchmark/pt/gather_test.py +++ b/benchmarks/operator_benchmark/pt/gather_test.py @@ -30,15 +30,17 @@ class GatherBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, dim, device): - self.input_one = torch.rand(M, N, device=device) - self.dim = dim min_val = M if dim == 0 else N numpy.random.seed((1 << 32) - 1) - self.index = torch.tensor(numpy.random.randint(0, min_val, (M, N)), device=device) + self.inputs = { + "input_one": torch.rand(M, N, device=device), + "dim": dim, + "index": torch.tensor(numpy.random.randint(0, min_val, (M, N)), device=device) + } self.set_module_name("gather") - def forward(self): - return torch.gather(self.input_one, self.dim, self.index) + def forward(self, input_one, dim: int, index): + return torch.gather(input_one, dim, index) op_bench.generate_pt_test(gather_configs_short + gather_configs_long, diff --git a/benchmarks/operator_benchmark/pt/groupnorm_test.py b/benchmarks/operator_benchmark/pt/groupnorm_test.py index eb941b863dc73..f360ae26b2070 100644 --- a/benchmarks/operator_benchmark/pt/groupnorm_test.py +++ b/benchmarks/operator_benchmark/pt/groupnorm_test.py @@ -18,16 +18,18 @@ class GroupNormBenchmark(op_bench.TorchBenchmarkBase): def init(self, dims, num_groups): - self.X = (torch.rand(*dims) - 0.5) * 256 - self.num_groups = num_groups num_channels = dims[1] - self.weight = torch.rand(num_channels, dtype=torch.float) - self.bias = torch.rand(num_channels, dtype=torch.float) - self.eps = 1e-5 - - def forward(self): + self.inputs = { + "input": (torch.rand(*dims) - 0.5) * 256, + "num_groups": num_groups, + "weight": torch.rand(num_channels, dtype=torch.float), + "bias": torch.rand(num_channels, dtype=torch.float), + "eps": 1e-5 + } + + def forward(self, input, num_groups: int, weight, bias, eps: float): return F.group_norm( - self.X, self.num_groups, weight=self.weight, bias=self.bias, eps=self.eps) + input, num_groups, weight=weight, bias=bias, eps=eps) op_bench.generate_pt_test(groupnorm_configs_short, GroupNormBenchmark) diff --git a/benchmarks/operator_benchmark/pt/hardsigmoid_test.py b/benchmarks/operator_benchmark/pt/hardsigmoid_test.py index c3011d0a1fe41..f1161e485e721 100644 --- a/benchmarks/operator_benchmark/pt/hardsigmoid_test.py +++ b/benchmarks/operator_benchmark/pt/hardsigmoid_test.py @@ -45,11 +45,13 @@ class HardsigmoidBenchmark(op_bench.TorchBenchmarkBase): def init(self, N, C, H, W, device, op_func): - self.input_one = torch.rand(N, C, H, W, device=device) + self.inputs = { + "input_one": torch.rand(N, C, H, W, device=device) + } self.op_func = op_func() - def forward(self): - return self.op_func(self.input_one) + def forward(self, input_one): + return self.op_func(input_one) op_bench.generate_pt_tests_from_op_list(hardsigmoid_ops_list, diff --git a/benchmarks/operator_benchmark/pt/hardswish_test.py b/benchmarks/operator_benchmark/pt/hardswish_test.py index 3879679bd33b0..0f1f94c0ddbac 100644 --- a/benchmarks/operator_benchmark/pt/hardswish_test.py +++ b/benchmarks/operator_benchmark/pt/hardswish_test.py @@ -45,11 +45,13 @@ class HardswishBenchmark(op_bench.TorchBenchmarkBase): def init(self, N, C, H, W, device, op_func): - self.input_one = torch.rand(N, C, H, W, device=device) + self.inputs = { + "input_one": torch.rand(N, C, H, W, device=device) + } self.op_func = op_func() - def forward(self): - return self.op_func(self.input_one) + def forward(self, input_one): + return self.op_func(input_one) op_bench.generate_pt_tests_from_op_list(hardswish_ops_list, diff --git a/benchmarks/operator_benchmark/pt/index_select_test.py b/benchmarks/operator_benchmark/pt/index_select_test.py new file mode 100644 index 0000000000000..8418edb2840b1 --- /dev/null +++ b/benchmarks/operator_benchmark/pt/index_select_test.py @@ -0,0 +1,57 @@ +import operator_benchmark as op_bench +import torch +import numpy + + +"""Microbenchmarks for index_select operator.""" + +# An example input from this configuration is M=4, N=4, dim=0. +index_select_configs_short = op_bench.config_list( + attr_names=["M", "N", "K", "dim"], + attrs=[ + [8, 8, 1, 1], + [256, 512, 1, 1], + [512, 512, 1, 1], + [8, 8, 2, 1], + [256, 512, 2, 1], + [512, 512, 2, 1], + ], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + }, + tags=["short"] +) + + +index_select_configs_long = op_bench.cross_product_configs( + M=[128, 1024], + N=[128, 1024], + K=[1, 2], + dim=[1], + device=['cpu', 'cuda'], + tags=["long"] +) + + +class IndexSelectBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, K, dim, device): + max_val = N + numpy.random.seed((1 << 32) - 1) + index_dim = numpy.random.randint(0, N) + self.inputs = { + "input_one": torch.rand(M, N, K, device=device), + "dim" : dim, + "index" : torch.tensor(numpy.random.randint(0, max_val, index_dim), device=device), + } + self.set_module_name("index_select") + + def forward(self, input_one, dim, index): + return torch.index_select(input_one, dim, index) + + +op_bench.generate_pt_test(index_select_configs_short + index_select_configs_long, + IndexSelectBenchmark) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/instancenorm_test.py b/benchmarks/operator_benchmark/pt/instancenorm_test.py index 4eac02bc8bd8f..b152a9c753030 100644 --- a/benchmarks/operator_benchmark/pt/instancenorm_test.py +++ b/benchmarks/operator_benchmark/pt/instancenorm_test.py @@ -17,15 +17,17 @@ class InstanceNormBenchmark(op_bench.TorchBenchmarkBase): def init(self, dims): - self.X = (torch.rand(*dims) - 0.5) * 256 num_channels = dims[1] - self.weight = torch.rand(num_channels, dtype=torch.float) - self.bias = torch.rand(num_channels, dtype=torch.float) - self.eps = 1e-5 - - def forward(self): + self.inputs = { + "input": (torch.rand(*dims) - 0.5) * 256, + "weight": torch.rand(num_channels, dtype=torch.float), + "bias": torch.rand(num_channels, dtype=torch.float), + "eps": 1e-5 + } + + def forward(self, input, weight, bias, eps: float): return F.instance_norm( - self.X, weight=self.weight, bias=self.bias, eps=self.eps) + input, weight=weight, bias=bias, eps=eps) op_bench.generate_pt_test(instancenorm_configs_short, InstanceNormBenchmark) diff --git a/benchmarks/operator_benchmark/pt/layernorm_test.py b/benchmarks/operator_benchmark/pt/layernorm_test.py index f0aa81a8291c9..b18abf26eaf83 100644 --- a/benchmarks/operator_benchmark/pt/layernorm_test.py +++ b/benchmarks/operator_benchmark/pt/layernorm_test.py @@ -19,14 +19,17 @@ class LayerNormBenchmark(op_bench.TorchBenchmarkBase): def init(self, dims): - self.X = (torch.rand(*dims) - 0.5) * 256 - self.weight = torch.rand(*self.X.size()[1:], dtype=torch.float) - self.bias = torch.rand(*self.X.size()[1:], dtype=torch.float) - self.eps = 1e-5 - - def forward(self): + input = (torch.rand(*dims) - 0.5) * 256 + self.inputs = { + "input": input, + "weight": torch.rand(*input.size()[1:], dtype=torch.float), + "bias": torch.rand(*input.size()[1:], dtype=torch.float), + "eps": 1e-5 + } + + def forward(self, input, weight, bias, eps: float): return F.layer_norm( - self.X, self.X.size()[1:], weight=self.weight, bias=self.bias, eps=self.eps) + input, input.size()[1:], weight=weight, bias=bias, eps=eps) op_bench.generate_pt_test(layernorm_configs_short, LayerNormBenchmark) diff --git a/benchmarks/operator_benchmark/pt/linear_test.py b/benchmarks/operator_benchmark/pt/linear_test.py index d5ce8afa8acf6..84263ed6f2d43 100644 --- a/benchmarks/operator_benchmark/pt/linear_test.py +++ b/benchmarks/operator_benchmark/pt/linear_test.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn -from . import configs +from pt import configs """Microbenchmarks for Linear operator.""" @@ -11,12 +11,14 @@ class LinearBenchmark(op_bench.TorchBenchmarkBase): def init(self, N, IN, OUT, device): - self.input_one = torch.rand(N, IN, device=device) + self.inputs = { + "input_one": torch.rand(N, IN, device=device) + } self.linear = nn.Linear(IN, OUT).to(device=device) self.set_module_name("linear") - def forward(self): - return self.linear(self.input_one) + def forward(self, input_one): + return self.linear(input_one) op_bench.generate_pt_test(configs.linear_configs_short + configs.linear_configs_long, diff --git a/benchmarks/operator_benchmark/pt/matmul_test.py b/benchmarks/operator_benchmark/pt/matmul_test.py index 0c60524b911ad..e5d7d27589d4e 100644 --- a/benchmarks/operator_benchmark/pt/matmul_test.py +++ b/benchmarks/operator_benchmark/pt/matmul_test.py @@ -31,14 +31,18 @@ class MatMulBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, K, trans_a, trans_b, device): - self.input_one = torch.rand(M, N, device=device) if trans_a \ - else torch.rand(N, M, device=device).t() - self.input_two = torch.rand(N, K, device=device) if trans_b \ - else torch.rand(K, N, device=device).t() + self.inputs = { + "input_one": torch.rand(M, N, device=device) + if trans_a + else torch.rand(N, M, device=device).t(), + "input_two": torch.rand(N, K, device=device) + if trans_b + else torch.rand(K, N, device=device).t(), + } self.set_module_name("matmul") - def forward(self): - return torch.matmul(self.input_one, self.input_two) + def forward(self, input_one, input_two): + return torch.matmul(input_one, input_two) op_bench.generate_pt_test(mm_long_configs + mm_short_configs, MatMulBenchmark) diff --git a/benchmarks/operator_benchmark/pt/nan_to_num_test.py b/benchmarks/operator_benchmark/pt/nan_to_num_test.py new file mode 100644 index 0000000000000..0e8d82ad781fb --- /dev/null +++ b/benchmarks/operator_benchmark/pt/nan_to_num_test.py @@ -0,0 +1,63 @@ +import operator_benchmark as op_bench +import torch +import math + + +"""Microbenchmarks for torch.nan_to_num / nan_to_num_ operators""" + +# Configs for PT torch.nan_to_num / nan_to_num_ operators + +nan_to_num_ops_list = op_bench.op_list( + attr_names=['op_name', 'op_func'], + attrs=[ + ['nan_to_num', torch.nan_to_num], + ['nan_to_num_', torch.nan_to_num_], + ], +) + +nan_to_num_long_configs = op_bench.cross_product_configs( + M=[32, 64, 128], + N=range(32, 128, 32), + dtype=[torch.float, torch.double], + replace_inf=[True, False], + tags=["long"], +) + + +nan_to_num_short_configs = op_bench.cross_product_configs( + M=[16, 64], + N=[64, 64], + dtype=[torch.float, torch.double], + replace_inf=[True, False], + tags=["short"], +) + + +class ReplaceNaNBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, dtype, replace_inf, op_func): + input = torch.randn(M, N, dtype=dtype) + input[0][0] = float("nan") + self.inputs = { + "input": input, + "replace_inf": replace_inf + } + self.op_func = op_func + self.set_module_name("nan_to_num") + + def forward(self, input, replace_inf: bool): + # compare inplace + if replace_inf: + return self.op_func(input, nan=1.0) + else: + return self.op_func(input, nan=1.0, posinf=math.inf, neginf=-math.inf) + + +op_bench.generate_pt_tests_from_op_list( + nan_to_num_ops_list, + nan_to_num_long_configs + nan_to_num_short_configs, + ReplaceNaNBenchmark, +) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/pool_test.py b/benchmarks/operator_benchmark/pt/pool_test.py index 88a75522566db..f465c41a09678 100644 --- a/benchmarks/operator_benchmark/pt/pool_test.py +++ b/benchmarks/operator_benchmark/pt/pool_test.py @@ -41,13 +41,13 @@ class Pool1dBenchmark(op_bench.TorchBenchmarkBase): def init(self, kernel, stride, N, C, L, device, op_func): - self.input = torch.rand(N, C, L, device=device) - self.kernel = kernel - self.stride = stride - self.op_func = op_func(self.kernel, stride=self.stride) + self.inputs = { + "input": torch.rand(N, C, L, device=device) + } + self.op_func = op_func(kernel, stride=stride) - def forward(self): - return self.op_func(self.input) + def forward(self, input): + return self.op_func(input) op_bench.generate_pt_tests_from_op_list(pool_1d_ops_list, @@ -98,14 +98,14 @@ def forward(self): class Pool2dBenchmark(op_bench.TorchBenchmarkBase): def init(self, kernel, stride, N, C, H, W, device, op_func): - self.input = torch.rand(N, C, H, W, device=device) - self.kernel = kernel - self.stride = stride - self.op_func = op_func(self.kernel, stride=self.stride) + self.inputs = { + "input": torch.rand(N, C, H, W, device=device) + } + self.op_func = op_func(kernel, stride=stride) - def forward(self): - return self.op_func(self.input) + def forward(self, input): + return self.op_func(input) op_bench.generate_pt_tests_from_op_list(pool_2d_ops_list, @@ -158,13 +158,13 @@ def forward(self): class Pool3dBenchmark(op_bench.TorchBenchmarkBase): def init(self, kernel, stride, N, C, D, H, W, device, op_func): - self.input = torch.rand(N, C, D, H, W, device=device) - self.kernel = kernel - self.stride = stride - self.op_func = op_func(self.kernel, stride=self.stride) + self.inputs = { + "input": torch.rand(N, C, D, H, W, device=device) + } + self.op_func = op_func(kernel, stride=stride) - def forward(self): - return self.op_func(self.input) + def forward(self, input): + return self.op_func(input) op_bench.generate_pt_tests_from_op_list(pool_3d_ops_list, diff --git a/benchmarks/operator_benchmark/pt/qactivation_test.py b/benchmarks/operator_benchmark/pt/qactivation_test.py index 7ef51958dc606..ea4d554d93569 100644 --- a/benchmarks/operator_benchmark/pt/qactivation_test.py +++ b/benchmarks/operator_benchmark/pt/qactivation_test.py @@ -42,12 +42,9 @@ qactivation_ops = op_bench.op_list( attrs=( - ('relu', nnq.functional.relu), + ('relu', torch.nn.ReLU()), ('relu6', torch.ops.quantized.relu6), ('functional.hardtanh', nnq.functional.hardtanh), - ('functional.hardswish', nnq.functional.hardswish), - ('functional.elu', nnq.functional.elu), - ('functional.celu', nnq.functional.celu), ('functional.hardsigmoid', nnq.functional.hardsigmoid), ('functional.leaky_relu', nnq.functional.leaky_relu), ('functional.sigmoid', torch.nn.functional.sigmoid), @@ -66,28 +63,49 @@ def _setup(self, dims, contig, dtype): self.zero_point = 0 # Quantize the tensor - self.q_input = torch.quantize_per_tensor(f_input, scale=self.scale, - zero_point=self.zero_point, - dtype=dtype) + q_input = torch.quantize_per_tensor(f_input, scale=self.scale, + zero_point=self.zero_point, + dtype=dtype) if not contig: # Make non-contiguous - new_shape = list(range(self.q_input.ndim))[::-1] - self.q_input = self.q_input.permute(new_shape) + new_shape = list(range(q_input.ndim))[::-1] + q_input = q_input.permute(new_shape) + + self.inputs = { + "q_input": q_input + } def init(self, dims, contig, inplace, dtype, op_func): self._setup(dims, contig, dtype) self.qop = op_func - def forward(self): - if self.qop in (nnq.functional.hardswish, nnq.functional.elu, - nnq.functional.celu): - return self.qop(self.q_input, scale=self.scale, zero_point=self.zero_point) - return self.qop(self.q_input) + +class QActivationBenchmark(QActivationBenchmarkBase): + def forward(self, q_input): + return self.qop(q_input) op_bench.generate_pt_tests_from_op_list(qactivation_ops, qactivation_short_configs + qactivation_long_configs, - QActivationBenchmarkBase) + QActivationBenchmark) + + +qactivation_scale_zero_point_ops = op_bench.op_list( + attrs=( + ('functional.hardswish', nnq.functional.hardswish), + ('functional.elu', nnq.functional.elu), + ('functional.celu', nnq.functional.celu), + ), + attr_names=('op_name', 'op_func'), +) + +class QActivationScaleZeroPointBenchmark(QActivationBenchmarkBase): + def forward(self, q_input): + return self.qop(q_input, scale=self.scale, zero_point=self.zero_point) + +op_bench.generate_pt_tests_from_op_list(qactivation_scale_zero_point_ops, + qactivation_short_configs + qactivation_long_configs, + QActivationScaleZeroPointBenchmark) if __name__ == "__main__": op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/qarithmetic_test.py b/benchmarks/operator_benchmark/pt/qarithmetic_test.py index 87c75845900b3..01be129fe597a 100644 --- a/benchmarks/operator_benchmark/pt/qarithmetic_test.py +++ b/benchmarks/operator_benchmark/pt/qarithmetic_test.py @@ -1,5 +1,5 @@ import torch - +from torch._ops import ops import operator_benchmark as op_bench qarithmetic_binary_configs = op_bench.cross_product_configs( @@ -10,58 +10,77 @@ tags=('short',) ) + qarithmetic_binary_ops = op_bench.op_list( attrs=( - ('add', 'add'), - ('add_scalar', 'add_scalar'), - ('add_relu', 'add_relu'), - ('mul', 'mul'), - ('mul_scalar', 'mul_scalar'), + ('add', ops.quantized.add), + ('add_relu', ops.quantized.add_relu), + ('mul', ops.quantized.mul), ), attr_names=('op_name', 'op_func'), ) +qarithmetic_binary_scalar_ops = op_bench.op_list( + attrs=( + ('add_scalar', ops.quantized.add_scalar), + ('mul_scalar', ops.quantized.mul_scalar), + ), + attr_names=('op_name', 'op_func'), +) -r"""Base class to use QFunctional. - -Children will need to set `self.qop` to the qfunctional op under test. -I.e. `self.qop = 'add'` -""" class _QFunctionalBinaryArithmeticBenchmarkBase(op_bench.TorchBenchmarkBase): def setup(self, N, dtype, contig): self.qfunctional = torch.nn.quantized.QFunctional() # TODO: Consider more diverse shapes f_input = (torch.rand(N, N) - 0.5) * 256 - scale = 1.0 - zero_point = 0 - - self.q_input_a = torch.quantize_per_tensor(f_input, scale=scale, - zero_point=zero_point, + self.scale = 1.0 + self.zero_point = 0 + self.q_input_a = torch.quantize_per_tensor(f_input, scale=self.scale, + zero_point=self.zero_point, dtype=dtype) if not contig: permute_dims = list(range(f_input.ndim))[::-1] self.q_input_a = self.q_input_a.permute(permute_dims) - def forward(self): - return getattr(self.qfunctional, self.qop)(self.q_input_a, - self.q_input_b) - -class QFunctionalAddBenchmarkBase(_QFunctionalBinaryArithmeticBenchmarkBase): +class QFunctionalBenchmark(_QFunctionalBinaryArithmeticBenchmarkBase): def init(self, N, dtype, contig, op_func): - super(QFunctionalAddBenchmarkBase, self).setup(N, dtype, contig) - self.qop = op_func - if self.qop.endswith('_scalar'): - self.q_input_b = 42 - else: - self.q_input_b = self.q_input_a + super(QFunctionalBenchmark, self).setup(N, dtype, contig) + self.inputs = { + "q_input_a": self.q_input_a, + "q_input_b": self.q_input_a, + "scale": self.scale, + "zero_point": self.zero_point + } + self.op_func = op_func + + def forward(self, q_input_a, q_input_b, scale: float, zero_point: int): + return self.op_func(q_input_a, q_input_b, scale=scale, zero_point=zero_point) op_bench.generate_pt_tests_from_op_list(qarithmetic_binary_ops, qarithmetic_binary_configs, - QFunctionalAddBenchmarkBase) + QFunctionalBenchmark) + + +class QFunctionalScalarBenchmark(_QFunctionalBinaryArithmeticBenchmarkBase): + def init(self, N, dtype, contig, op_func): + super(QFunctionalScalarBenchmark, self).setup(N, dtype, contig) + self.inputs = { + "q_input": self.q_input_a, + "scalar_input": 42 + } + self.op_func = op_func + + def forward(self, q_input, scalar_input: int): + return self.op_func(q_input, scalar_input) + + +op_bench.generate_pt_tests_from_op_list(qarithmetic_binary_scalar_ops, + qarithmetic_binary_configs, + QFunctionalScalarBenchmark) if __name__ == '__main__': diff --git a/benchmarks/operator_benchmark/pt/qbatchnorm_test.py b/benchmarks/operator_benchmark/pt/qbatchnorm_test.py index f729f79dcce70..b7d591096a8d6 100644 --- a/benchmarks/operator_benchmark/pt/qbatchnorm_test.py +++ b/benchmarks/operator_benchmark/pt/qbatchnorm_test.py @@ -23,15 +23,17 @@ def init(self, M, N, K, device, dtype): self._init(M, N, K, device) x_scale = 0.1 x_zero_point = 0 - self.q_input_one = torch.quantize_per_tensor( - self.input_one, scale=x_scale, zero_point=x_zero_point, dtype=dtype) - self.mean = torch.rand(N) - self.var = torch.rand(N) - self.weight = torch.rand(N) - self.bias = torch.rand(N) - self.eps = 1e-5 - self.Y_scale = 0.1 - self.Y_zero_point = 0 + self.inputs = { + "q_input_one": torch.quantize_per_tensor( + self.input_one, scale=x_scale, zero_point=x_zero_point, dtype=dtype), + "mean": torch.rand(N), + "var": torch.rand(N), + "weight": torch.rand(N), + "bias": torch.rand(N), + "eps": 1e-5, + "Y_scale": 0.1, + "Y_zero_point": 0 + } def _init(self, M, N, K, device): pass @@ -45,10 +47,20 @@ def _init(self, M, N, K, device): self.set_module_name("QBatchNorm1d") self.input_one = torch.rand(M, N, K, device=device, requires_grad=self.auto_set()) - def forward(self): + def forward( + self, + q_input_one, + weight, + bias, + mean, + var, + eps: float, + Y_scale: float, + Y_zero_point: int + ): return torch.ops.quantized.batch_norm1d( - self.q_input_one, self.weight, self.bias, self.mean, self.var, self.eps, - self.Y_scale, self.Y_zero_point) + q_input_one, weight, bias, mean, var, eps, + Y_scale, Y_zero_point) class QBatchNorm2dBenchmark(QBatchNormBenchmark): @@ -58,10 +70,20 @@ def _init(self, M, N, K, device): # add a 1 as the last dimension self.input_one = torch.rand(M, N, K, 1, device=device, requires_grad=self.auto_set()) - def forward(self): + def forward( + self, + q_input_one, + weight, + bias, + mean, + var, + eps: float, + Y_scale: float, + Y_zero_point: int + ): return torch.ops.quantized.batch_norm2d( - self.q_input_one, self.weight, self.bias, self.mean, self.var, self.eps, - self.Y_scale, self.Y_zero_point) + q_input_one, weight, bias, mean, var, eps, + Y_scale, Y_zero_point) op_bench.generate_pt_test(batchnorm_configs_short, QBatchNorm1dBenchmark) diff --git a/benchmarks/operator_benchmark/pt/qcat_test.py b/benchmarks/operator_benchmark/pt/qcat_test.py index 77a66f53d2f03..32dd32e43adfe 100644 --- a/benchmarks/operator_benchmark/pt/qcat_test.py +++ b/benchmarks/operator_benchmark/pt/qcat_test.py @@ -2,6 +2,7 @@ import torch import torch.nn.quantized as nnq +from typing import List """Microbenchmarks for quantized Cat operator""" @@ -53,11 +54,14 @@ def init(self, M, N, K, L, dim, contig, dtype): elif contig == 'none': self.input = (q_input_non_contig, q_input_non_contig) - self.dim = dim + self.inputs = { + "input": self.input, + "dim": dim + } self.set_module_name('qcat') - def forward(self): - return self.qf.cat(self.input, dim=self.dim) + def forward(self, input: List[torch.Tensor], dim: int): + return self.qf.cat(input, dim=dim) op_bench.generate_pt_test(qcat_configs_short + qcat_configs_long, diff --git a/benchmarks/operator_benchmark/pt/qcomparators_test.py b/benchmarks/operator_benchmark/pt/qcomparators_test.py index d86ec20eb65d4..9c26f6dee23b8 100644 --- a/benchmarks/operator_benchmark/pt/qcomparators_test.py +++ b/benchmarks/operator_benchmark/pt/qcomparators_test.py @@ -34,23 +34,32 @@ def init(self, N, dtype, contig, other_scalar, out_variant, op_func): q_input_a = torch.quantize_per_tensor(f_input, scale=scale, zero_point=zero_point, dtype=dtype) - if other_scalar: - q_input_b = 42 - else: - q_input_b = q_input_a.clone() + q_input_b = q_input_a.clone() if not contig: permute_dims = list(range(f_input.ndim))[::-1] q_input_a = q_input_a.permute(permute_dims) self.qop = op_func - self.args = (q_input_a, q_input_b) - self.kwargs = {} + self.inputs = { + "q_input_a": q_input_a, + "q_input_b": q_input_b, + "out_variant": out_variant, + "other_scalar": other_scalar, + } + + def forward(self, q_input_a, q_input_b, out_variant: bool, other_scalar: bool): if out_variant: - self.kwargs['out'] = torch.tensor([], dtype=torch.bool) + if other_scalar: + return self.qop(q_input_a, 42, out=torch.tensor(True, dtype=torch.bool)) + else: + return self.qop(q_input_a, q_input_b, out=torch.tensor(True, dtype=torch.bool)) + else: + if other_scalar: + return self.qop(q_input_a, 42) + else: + return self.qop(q_input_a, q_input_b) - def forward(self): - return self.qop(*self.args, **self.kwargs) op_bench.generate_pt_tests_from_op_list(qcomparators_ops, diff --git a/benchmarks/operator_benchmark/pt/qconv_test.py b/benchmarks/operator_benchmark/pt/qconv_test.py index e014df5fdd6f7..14e8e143a7ca8 100644 --- a/benchmarks/operator_benchmark/pt/qconv_test.py +++ b/benchmarks/operator_benchmark/pt/qconv_test.py @@ -3,7 +3,7 @@ import torch import torch.nn.quantized as nnq -from . import configs +from pt import configs """ Microbenchmarks for qConv operators. @@ -24,16 +24,18 @@ def init(self, IC, OC, kernel, stride, N, L, device): W = torch.randn(OC, IC // G, kernel, dtype=torch.float32) self.qW = torch.quantize_per_tensor(W, scale=self.scale, zero_point=0, dtype=torch.qint8) - self.input = qX + self.inputs = { + "input": qX + } self.qconv1d = nnq.Conv1d(IC, OC, kernel, stride=stride, padding=pad, groups=G) self.qconv1d.set_weight_bias(self.qW, None) - self.qconv1d.scale = torch.tensor([self.scale], dtype=torch.double) - self.qconv1d.zero_point = torch.tensor([self.zero_point], dtype=torch.int) + self.qconv1d.scale = torch.tensor(self.scale, dtype=torch.double) + self.qconv1d.zero_point = torch.tensor(self.zero_point, dtype=torch.int) self.set_module_name("QConv1d") - def forward(self): - return self.qconv1d(self.input) + def forward(self, input): + return self.qconv1d(input) class QConv2dBenchmark(op_bench.TorchBenchmarkBase): @@ -51,16 +53,18 @@ def init(self, IC, OC, kernel, stride, N, H, W, G, pad, device): W = torch.randn(OC, IC // G, kernel, kernel, dtype=torch.float32) self.qW = torch.quantize_per_tensor(W, scale=self.scale, zero_point=0, dtype=torch.qint8) - self.input = qX + self.inputs = { + "input": qX + } self.qconv2d = nnq.Conv2d(IC, OC, kernel, stride=stride, padding=pad, groups=G) self.qconv2d.set_weight_bias(self.qW, None) - self.qconv2d.scale = torch.tensor([self.scale], dtype=torch.double) - self.qconv2d.zero_point = torch.tensor([self.zero_point], dtype=torch.int) + self.qconv2d.scale = torch.tensor(self.scale, dtype=torch.double) + self.qconv2d.zero_point = torch.tensor(self.zero_point, dtype=torch.int) self.set_module_name("QConv2d") - def forward(self): - return self.qconv2d(self.input) + def forward(self, input): + return self.qconv2d(input) op_bench.generate_pt_test(configs.remove_cuda(configs.conv_1d_configs_short + configs.conv_1d_configs_long), QConv1dBenchmark) diff --git a/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py b/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py new file mode 100644 index 0000000000000..4bd06b027969b --- /dev/null +++ b/benchmarks/operator_benchmark/pt/qembedding_bag_lookups_test.py @@ -0,0 +1,248 @@ + +import operator_benchmark as op_bench +import torch +import numpy as np +from typing import Optional + +from torch.testing._internal.common_quantization import ( + lengths_to_offsets +) + +torch.ops.load_library("//caffe2/torch/fb/sparsenn:sparsenn_operators") + + +embedding_bag_rowwise_offsets_short_configs = op_bench.cross_product_configs( + num_embeddings=(80,), + embedding_dim=(128, 256), + num_offsets=range(2, 10), + enable_per_sample_weights=(True, False), + include_last_offset=(True, False), + is_pruned_weights=(True, False,), + use_32bit_indices=(True, False), + use_32bit_offsets=(True, False), + tags=['short'], +) + + +embedding_bag_rowwise_offsets_long_configs = op_bench.cross_product_configs( + num_embeddings=(100, 120, 1000, 10_000, 20_000), + embedding_dim=(16, 64, 128, 256), + num_offsets=range(10, 20), + enable_per_sample_weights=(True, False), + include_last_offset=(True, False), + is_pruned_weights=(True, False,), + use_32bit_indices=(True, False), + use_32bit_offsets=(True, False), + tags=['long'] +) + + +full_configs = embedding_bag_rowwise_offsets_short_configs + embedding_bag_rowwise_offsets_long_configs + +four_bit_rowwise_ops = op_bench.op_list( + attrs=( + ('qembeddingbag_4bit_rowwise_offsets', torch.ops.quantized.embedding_bag_4bit_rowwise_offsets), + ), + attr_names=('op_name', 'op_func'), +) + +byte_rowwise_ops = op_bench.op_list( + attrs=( + ('qembeddingbag_byte_rowwise_offsets', torch.ops.quantized.embedding_bag_byte_rowwise_offsets), + ), + attr_names=('op_name', 'op_func'), +) + + +def get_pruned_weights_and_mapping(q_weights): + indicator = torch.from_numpy(np.random.uniform( + low=-1.0, high=1.0, size=[q_weights.shape[0]]).astype(np.float32)) + + q_pruned_weights, compressed_indices_mapping = torch.ops.fb.embedding_bag_rowwise_prune( + q_weights, indicator, 0.01, torch.int32) + + return q_pruned_weights, compressed_indices_mapping + + +class EmbedddingBag4BitRowwiseOffsetsTest(op_bench.TorchBenchmarkBase): + def init(self, + num_embeddings: int, + embedding_dim: int, + num_offsets: int, + enable_per_sample_weights: bool, + include_last_offset: bool, + is_pruned_weights: bool, + use_32bit_indices: bool, + use_32bit_offsets: bool, + op_func): + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.num_offsets = num_offsets + self.enable_per_sample_weights = enable_per_sample_weights + self.include_last_offset = include_last_offset + self.max_segment_length = 20 + self.num_lengths = np.random.randint(1, num_offsets + 1) + self.lengths = np.random.randint(0, self.max_segment_length + 1, + size=self.num_lengths).astype(np.int32) + self.num_indices = np.sum(self.lengths) + self.is_pruned_weights = is_pruned_weights + self.use_32bit_indices = use_32bit_indices + self.use_32bit_offsets = use_32bit_offsets + + self.offsets = lengths_to_offsets(self.lengths) + self.indices = torch.from_numpy(np.random.randint( + low=0, high=num_embeddings, size=self.num_indices, dtype=np.int64)) + + self.indices = self.indices.int() if self.use_32bit_indices else self.indices + self.offsets = self.offsets.int() if self.use_32bit_offsets else self.offsets + + if self.include_last_offset: + self.offsets = torch.cat( + (self.offsets, torch.tensor([self.indices.size(0)], dtype=torch.long)), 0 + ) + + self.weights = torch.from_numpy((np.random.random_sample(( + self.num_embeddings, self.embedding_dim)) + 1).astype(np.float32)) + self.indices = torch.from_numpy(np.random.randint( + low=0, high=self.num_embeddings, size=self.num_indices, dtype=np.int64)) + self.prepack_func = torch.ops.quantized.embedding_bag_4bit_prepack + + self.prepacked_weights = self.prepack_func(self.weights) + self.per_sample_weights = torch.from_numpy(np.random.uniform( + low=0.01, high=0.5, size=[len(self.indices)]).astype(np.float32)) if \ + self.enable_per_sample_weights else None + + self.compressed_indices = None + + if self.is_pruned_weights: + self.prepacked_weights, self.compressed_indices = get_pruned_weights_and_mapping(self.prepacked_weights) + + self.inputs = { + "prepacked_weights": self.prepacked_weights, + "indices": self.indices, + "offsets": self.offsets, + "mode": 0, + "per_sample_weights": self.per_sample_weights, + "include_last_offset": self.include_last_offset, + "is_pruned_weights": self.is_pruned_weights, + "compressed_indices": self.compressed_indices + } + + self.op_func = op_func + + def forward( + self, + prepacked_weights, + indices, + offsets, + mode: int, + per_sample_weights: Optional[torch.Tensor], + include_last_offset: bool, + is_pruned_weights: bool, + compressed_indices: Optional[torch.Tensor] + ): + + return self.op_func(prepacked_weights, indices, offsets, + mode=mode, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + pruned_weights=is_pruned_weights, + compressed_indices_mapping=compressed_indices) + + +class EmbedddingBagByteRowwiseOffsetsTest(op_bench.TorchBenchmarkBase): + def init(self, + num_embeddings: int, + embedding_dim: int, + num_offsets: int, + enable_per_sample_weights: bool, + include_last_offset: bool, + is_pruned_weights: bool, + use_32bit_indices: bool, + use_32bit_offsets: bool, + op_func): + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.num_offsets = num_offsets + self.enable_per_sample_weights = enable_per_sample_weights + self.include_last_offset = include_last_offset + self.max_segment_length = 20 + self.num_lengths = np.random.randint(1, num_offsets + 1) + self.lengths = np.random.randint(0, self.max_segment_length + 1, + size=self.num_lengths).astype(np.int32) + self.is_pruned_weights = is_pruned_weights + self.use_32bit_indices = use_32bit_indices + self.use_32bit_offsets = use_32bit_offsets + + self.num_indices = np.sum(self.lengths) + self.offsets = lengths_to_offsets(self.lengths) + self.indices = torch.from_numpy(np.random.randint( + low=0, high=num_embeddings, size=self.num_indices, dtype=np.int64)) + + self.indices = self.indices.int() if self.use_32bit_indices else self.indices + self.offsets = self.offsets.int() if self.use_32bit_offsets else self.offsets + + if include_last_offset: + self.offsets = torch.cat( + (self.offsets, torch.tensor([self.indices.size(0)], dtype=torch.long)), 0 + ) + + self.weights = torch.from_numpy((np.random.random_sample(( + self.num_embeddings, self.embedding_dim)) + 1).astype(np.float32)) + self.indices = torch.from_numpy(np.random.randint( + low=0, high=self.num_embeddings, size=self.num_indices, dtype=np.int64)) + + self.prepack_func = torch.ops.quantized.embedding_bag_byte_prepack + + self.prepacked_weights = self.prepack_func(self.weights) + self.per_sample_weights = torch.from_numpy(np.random.uniform( + low=0.01, high=0.5, size=[len(self.indices)]).astype(np.float32)) if \ + self.enable_per_sample_weights else None + + self.compressed_indices = None + + if self.is_pruned_weights: + self.prepacked_weights, self.compressed_indices = get_pruned_weights_and_mapping(self.prepacked_weights) + + self.inputs = { + "prepacked_weights": self.prepacked_weights, + "indices": self.indices, + "offsets": self.offsets, + "mode": 0, + "per_sample_weights": self.per_sample_weights, + "include_last_offset": self.include_last_offset, + "is_pruned_weights": self.is_pruned_weights, + "compressed_indices": self.compressed_indices + } + + self.op_func = op_func + + def forward( + self, + prepacked_weights, + indices, + offsets, + mode: int, + per_sample_weights: Optional[torch.Tensor], + include_last_offset: bool, + is_pruned_weights: bool, + compressed_indices: Optional[torch.Tensor] + ): + return self.op_func(prepacked_weights, indices, offsets, + mode=0, + per_sample_weights=per_sample_weights, + include_last_offset=self.include_last_offset, + pruned_weights=self.is_pruned_weights, + compressed_indices_mapping=self.compressed_indices) + + +op_bench.generate_pt_tests_from_op_list(four_bit_rowwise_ops, + full_configs, + EmbedddingBag4BitRowwiseOffsetsTest) +op_bench.generate_pt_tests_from_op_list(byte_rowwise_ops, + full_configs, + EmbedddingBagByteRowwiseOffsetsTest) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/qembedding_pack_test.py b/benchmarks/operator_benchmark/pt/qembedding_pack_test.py index e64d4fa1962b5..f9a3aaff051af 100644 --- a/benchmarks/operator_benchmark/pt/qembedding_pack_test.py +++ b/benchmarks/operator_benchmark/pt/qembedding_pack_test.py @@ -35,21 +35,25 @@ class EmbeddingBagFloatToFusedBase(op_bench.TorchBenchmarkBase): def init(self, num_embeddings, embedding_dim, op_func): - self.weight = torch.from_numpy((np.random.random_sample(( - num_embeddings, embedding_dim)) + 1).astype(np.float32)) + self.inputs = { + "weight": torch.from_numpy((np.random.random_sample(( + num_embeddings, embedding_dim)) + 1).astype(np.float32)) + } self.op_func = op_func - def forward(self): - return self.op_func(self.weight) + def forward(self, weight): + return self.op_func(weight) class EmbeddingBagFusedToFloatBase(op_bench.TorchBenchmarkBase): def init(self, num_embeddings, embedding_dim, op_func): weight = torch.randn(num_embeddings, embedding_dim + 8, dtype=torch.float) - self.packed_weight = weight.to(torch.uint8) + self.inputs = { + "packed_weight": weight.to(torch.uint8) + } self.op_func = op_func - def forward(self): - return self.op_func(self.packed_weight) + def forward(self, packed_weight): + return self.op_func(packed_weight) op_bench.generate_pt_tests_from_op_list(conversion_ops, diff --git a/benchmarks/operator_benchmark/pt/qembeddingbag_test.py b/benchmarks/operator_benchmark/pt/qembeddingbag_test.py index e80de5ebb6197..872f8c28fccd4 100644 --- a/benchmarks/operator_benchmark/pt/qembeddingbag_test.py +++ b/benchmarks/operator_benchmark/pt/qembeddingbag_test.py @@ -1,9 +1,9 @@ import operator_benchmark as op_bench import torch -import torch.nn.quantized.dynamic as nnqd +import torch.nn.quantized as nnq import numpy -from . import configs +from pt import configs """ Microbenchmarks for qEmbeddingBag operators. @@ -11,7 +11,7 @@ class QEmbeddingBagBenchmark(op_bench.TorchBenchmarkBase): def init(self, embeddingbags, dim, mode, input_size, offset, sparse, include_last_offset, device): - self.embedding = nnqd.EmbeddingBag( + self.embedding = nnq.EmbeddingBag( num_embeddings=embeddingbags, embedding_dim=dim, mode=mode, @@ -20,10 +20,14 @@ def init(self, embeddingbags, dim, mode, input_size, offset, sparse, include_las self.input = torch.tensor(numpy.random.randint(0, embeddingbags, input_size), device=device).long() offset = torch.LongTensor([offset], device=device) self.offset = torch.cat((offset, torch.tensor([self.input.size(0)], dtype=torch.long)), 0) + self.inputs = { + "input": self.input, + "offset": self.offset + } self.set_module_name('qEmbeddingBag') - def forward(self): - return self.embedding(self.input, self.offset) + def forward(self, input, offset): + return self.embedding(input, offset) op_bench.generate_pt_test(configs.embeddingbag_short_configs, QEmbeddingBagBenchmark) diff --git a/benchmarks/operator_benchmark/pt/qgroupnorm_test.py b/benchmarks/operator_benchmark/pt/qgroupnorm_test.py index 6881bc4c518dc..942d6ab6560c2 100644 --- a/benchmarks/operator_benchmark/pt/qgroupnorm_test.py +++ b/benchmarks/operator_benchmark/pt/qgroupnorm_test.py @@ -20,23 +20,26 @@ class QGroupNormBenchmark(op_bench.TorchBenchmarkBase): def init(self, dims, num_groups, dtype): X = (torch.rand(*dims) - 0.5) * 256 - self.num_groups = num_groups num_channels = dims[1] scale = 1.0 zero_point = 0 - self.qX = torch.quantize_per_tensor( - X, scale=scale, zero_point=zero_point, dtype=dtype) - self.weight = torch.rand(num_channels, dtype=torch.float) - self.bias = torch.rand(num_channels, dtype=torch.float) - self.eps = 1e-5 - self.Y_scale = 0.1 - self.Y_zero_point = 0 - - def forward(self): + + self.inputs = { + "qX": torch.quantize_per_tensor( + X, scale=scale, zero_point=zero_point, dtype=dtype), + "num_groups": num_groups, + "weight": torch.rand(num_channels, dtype=torch.float), + "bias": torch.rand(num_channels, dtype=torch.float), + "eps": 1e-5, + "Y_scale": 0.1, + "Y_zero_point": 0 + } + + def forward(self, qX, num_groups: int, weight, bias, eps: float, Y_scale: float, Y_zero_point: int): return torch.ops.quantized.group_norm( - self.qX, self.num_groups, weight=self.weight, bias=self.bias, - eps=self.eps, output_scale=self.Y_scale, - output_zero_point=self.Y_zero_point) + qX, num_groups, weight=weight, bias=bias, + eps=eps, output_scale=Y_scale, + output_zero_point=Y_zero_point) op_bench.generate_pt_test(groupnorm_configs_short, QGroupNormBenchmark) diff --git a/benchmarks/operator_benchmark/pt/qinstancenorm_test.py b/benchmarks/operator_benchmark/pt/qinstancenorm_test.py index 5a770728bb5ee..df084700fac0d 100644 --- a/benchmarks/operator_benchmark/pt/qinstancenorm_test.py +++ b/benchmarks/operator_benchmark/pt/qinstancenorm_test.py @@ -22,19 +22,22 @@ def init(self, dims, dtype): num_channels = dims[1] scale = 1.0 zero_point = 0 - self.qX = torch.quantize_per_tensor( - X, scale=scale, zero_point=zero_point, dtype=dtype) - self.weight = torch.rand(num_channels, dtype=torch.float) - self.bias = torch.rand(num_channels, dtype=torch.float) - self.eps = 1e-5 - self.Y_scale = 0.1 - self.Y_zero_point = 0 - - def forward(self): + + self.inputs = { + "qX": torch.quantize_per_tensor( + X, scale=scale, zero_point=zero_point, dtype=dtype), + "weight": torch.rand(num_channels, dtype=torch.float), + "bias": torch.rand(num_channels, dtype=torch.float), + "eps": 1e-5, + "Y_scale": 0.1, + "Y_zero_point": 0 + } + + def forward(self, qX, weight, bias, eps: float, Y_scale: float, Y_zero_point: int): return torch.ops.quantized.instance_norm( - self.qX, weight=self.weight, bias=self.bias, - eps=self.eps, output_scale=self.Y_scale, - output_zero_point=self.Y_zero_point) + qX, weight=weight, bias=bias, + eps=eps, output_scale=Y_scale, + output_zero_point=Y_zero_point) op_bench.generate_pt_test(instancenorm_configs_short, QInstanceNormBenchmark) diff --git a/benchmarks/operator_benchmark/pt/qinterpolate_test.py b/benchmarks/operator_benchmark/pt/qinterpolate_test.py index a1861f4fe4b93..753154f13598c 100644 --- a/benchmarks/operator_benchmark/pt/qinterpolate_test.py +++ b/benchmarks/operator_benchmark/pt/qinterpolate_test.py @@ -44,15 +44,18 @@ def init(self, M, N, K, dtype, mode, scale, contig): dtype=dtype) if not contig: permute_dims = list(range(q_input.ndim))[::-1] - self.q_input_a = self.q_input_a.permute(permute_dims) + self.q_input = self.q_input.permute(permute_dims) - self.mode = mode - self.scale_factor = scale + self.inputs = { + "q_input": self.q_input, + "scale_factor": scale, + "mode": mode + } self.set_module_name('q_interpolate') - def forward(self): - return torch.nn.quantized.functional.interpolate( - self.q_input, scale_factor=self.scale_factor, mode=self.mode) + def forward(self, q_input, scale_factor: float, mode: str): + return torch.nn.functional.interpolate( + q_input, scale_factor=scale_factor, mode=mode) op_bench.generate_pt_test(qinterpolate_short_configs + qinterpolate_long_configs, diff --git a/benchmarks/operator_benchmark/pt/qlayernorm_test.py b/benchmarks/operator_benchmark/pt/qlayernorm_test.py index ee3224c315153..0a145ee015eab 100644 --- a/benchmarks/operator_benchmark/pt/qlayernorm_test.py +++ b/benchmarks/operator_benchmark/pt/qlayernorm_test.py @@ -25,17 +25,21 @@ def init(self, dims, dtype): zero_point = 0 self.qX = torch.quantize_per_tensor( X, scale=scale, zero_point=zero_point, dtype=dtype) - self.weight = torch.rand(*self.qX.size()[1:], dtype=torch.float) - self.bias = torch.rand(*self.qX.size()[1:], dtype=torch.float) - self.eps = 1e-5 - self.Y_scale = 0.1 - self.Y_zero_point = 0 - def forward(self): + self.inputs = { + "qX": self.qX, + "weight": torch.rand(*self.qX.size()[1:], dtype=torch.float), + "bias": torch.rand(*self.qX.size()[1:], dtype=torch.float), + "eps": 1e-5, + "Y_scale": 0.1, + "Y_zero_point": 0 + } + + def forward(self, qX, weight, bias, eps: float, Y_scale: float, Y_zero_point: int): return torch.ops.quantized.layer_norm( - self.qX, self.qX.size()[1:], weight=self.weight, bias=self.bias, - eps=self.eps, output_scale=self.Y_scale, - output_zero_point=self.Y_zero_point) + qX, qX.size()[1:], weight=weight, bias=bias, + eps=eps, output_scale=Y_scale, + output_zero_point=Y_zero_point) op_bench.generate_pt_test(layernorm_configs_short, QLayerNormBenchmark) diff --git a/benchmarks/operator_benchmark/pt/qlinear_test.py b/benchmarks/operator_benchmark/pt/qlinear_test.py index 16a6f6521fd3e..6e4dd9d97eca5 100644 --- a/benchmarks/operator_benchmark/pt/qlinear_test.py +++ b/benchmarks/operator_benchmark/pt/qlinear_test.py @@ -5,7 +5,7 @@ import torch.nn.quantized as nnq import torch.nn.quantized.dynamic as nnqd -from . import configs +from pt import configs """ Microbenchmarks for Quantized Linear operators. @@ -26,21 +26,25 @@ def init(self, N, IN, OUT, linear_under_test): self.qlinear.scale = scale self.qlinear.zero_point = zero_point - def forward(self): + def forward(self, input): # Assume that the `self.input` is set in the child - return self.qlinear(self.input) + return self.qlinear(input) class QLinearBenchmark(_QLinearBenchmarkBase): def init(self, N, IN, OUT, device): super(QLinearBenchmark, self).init(N, IN, OUT, nnq.Linear(IN, OUT)) - self.input = self.qX + self.inputs = { + "input": self.qX + } self.set_module_name("QLinear") class QDynamicLinearBenchmark(_QLinearBenchmarkBase): def init(self, N, IN, OUT, device): super(QDynamicLinearBenchmark, self).init(N, IN, OUT, nnqd.Linear(IN, OUT)) - self.input = self.X + self.inputs = { + "input": self.X + } self.set_module_name("QDynamicLinear") diff --git a/benchmarks/operator_benchmark/pt/qobserver_test.py b/benchmarks/operator_benchmark/pt/qobserver_test.py index 149acd2605656..6521773a73ff9 100644 --- a/benchmarks/operator_benchmark/pt/qobserver_test.py +++ b/benchmarks/operator_benchmark/pt/qobserver_test.py @@ -104,19 +104,22 @@ class QObserverBenchmark(op_bench.TorchBenchmarkBase): def init(self, C, M, N, dtype, qscheme, op_func, device): - self.f_input = torch.rand(C, M, N, device=device) + self.inputs = { + "f_input": torch.rand(C, M, N, device=device) + } self.op_func = op_func(dtype=dtype, qscheme=qscheme).to(device) - def forward(self): - self.op_func(self.f_input) - self.op_func.calculate_qparams() - return + def forward(self, f_input): + self.op_func(f_input) + return self.op_func.calculate_qparams() + class QObserverBenchmarkCalculateQparams(op_bench.TorchBenchmarkBase): def init(self, C, M, N, dtype, qscheme, op_func, device): self.f_input = torch.rand(C, M, N, device=device) self.q_observer = op_func(dtype=dtype, qscheme=qscheme).to(device) self.q_observer(self.f_input) + self.inputs = {} def forward(self): return self.q_observer.calculate_qparams() diff --git a/benchmarks/operator_benchmark/pt/qpool_test.py b/benchmarks/operator_benchmark/pt/qpool_test.py index d53b4d05db987..b5c40fd4977a7 100644 --- a/benchmarks/operator_benchmark/pt/qpool_test.py +++ b/benchmarks/operator_benchmark/pt/qpool_test.py @@ -88,8 +88,12 @@ def setup(self, N, C, H, W, dtype, contig): self.q_input = self.q_input.permute(0, 2, 3, 1).contiguous() self.q_input = self.q_input.permute(0, 3, 1, 2) - def forward(self): - return self.pool_op(self.q_input) + self.inputs = { + "q_input": self.q_input + } + + def forward(self, q_input): + return self.pool_op(q_input) class QMaxPool2dBenchmark(_QPool2dBenchmarkBase): diff --git a/benchmarks/operator_benchmark/pt/qrnn_test.py b/benchmarks/operator_benchmark/pt/qrnn_test.py index 187a8f1a82e03..c6d696b817945 100644 --- a/benchmarks/operator_benchmark/pt/qrnn_test.py +++ b/benchmarks/operator_benchmark/pt/qrnn_test.py @@ -45,20 +45,25 @@ def init(self, I, H, NL, B, D, dtype): {nn.LSTM, nn.Linear}, dtype=dtype)[0] - self.x = torch.randn(sequence_len, # sequence length - batch_size, # batch size - I) # Number of features in X - self.h = torch.randn(NL * (D + 1), # layer_num * dir_num - batch_size, # batch size - H) # hidden size - self.c = torch.randn(NL * (D + 1), # layer_num * dir_num - batch_size, # batch size - H) # hidden size + x = torch.randn(sequence_len, # sequence length + batch_size, # batch size + I) # Number of features in X + h = torch.randn(NL * (D + 1), # layer_num * dir_num + batch_size, # batch size + H) # hidden size + c = torch.randn(NL * (D + 1), # layer_num * dir_num + batch_size, # batch size + H) # hidden size + self.inputs = { + "x": x, + "h": h, + "c": c + } self.set_module_name("QLSTM") - def forward(self): - return self.cell(self.x, (self.h, self.c)) + def forward(self, x, h, c): + return self.cell(x, (h, c))[0] op_bench.generate_pt_test(qrnn_configs, LSTMBenchmark) diff --git a/benchmarks/operator_benchmark/pt/qtensor_method_test.py b/benchmarks/operator_benchmark/pt/qtensor_method_test.py index 4834e3cb166bb..50dc59780ab86 100644 --- a/benchmarks/operator_benchmark/pt/qtensor_method_test.py +++ b/benchmarks/operator_benchmark/pt/qtensor_method_test.py @@ -22,16 +22,9 @@ tags=['long'] ) -qmethods_tensor_input_list = op_bench.op_list( - attr_names=['op_name', 'op_func'], - attrs=[ - ['q_copy', 'copy_'], - ], -) - class _QMethodBenchmarkBase(op_bench.TorchBenchmarkBase): - def init(self, M, N, dtype, contig, op_func): + def init(self, M, N, dtype, contig): f_input = torch.rand(M, N) scale = 1.0 zero_point = 0 @@ -41,23 +34,20 @@ def init(self, M, N, dtype, contig, op_func): if not contig: permute_dims = list(range(self.q_input.ndim))[::-1] self.q_input = self.q_input.permute(permute_dims) - self.op_func = op_func - -class QMethodTensorInputBenchmark(_QMethodBenchmarkBase): - def forward(self): - getattr(self.q_input, self.op_func)(self.q_input) + self.inputs = { + "q_input": self.q_input, + } -class QMethodNoInputBenchmark(_QMethodBenchmarkBase): - def forward(self): - getattr(self.q_input, self.op_func)() +class QMethodTensorInputCopyBenchmark(_QMethodBenchmarkBase): + def forward(self, q_input): + return q_input.copy_(q_input) -op_bench.generate_pt_tests_from_op_list( - qmethods_tensor_input_list, +op_bench.generate_pt_test( qmethods_configs_short + qmethods_configs_long, - QMethodTensorInputBenchmark + QMethodTensorInputCopyBenchmark ) if __name__ == "__main__": diff --git a/benchmarks/operator_benchmark/pt/quantization_test.py b/benchmarks/operator_benchmark/pt/quantization_test.py index 4a83a2d1b75cd..af09a5fa2523f 100644 --- a/benchmarks/operator_benchmark/pt/quantization_test.py +++ b/benchmarks/operator_benchmark/pt/quantization_test.py @@ -56,8 +56,12 @@ def init(self, C, M, N, dtype, mode): self.op = nnq.DeQuantize() self.set_module_name('DequantizePerTensor') - def forward(self): - return self.op(self.input) + self.inputs = { + "input": self.input + } + + def forward(self, input): + return self.op(input) op_bench.generate_pt_test( @@ -98,12 +102,22 @@ def init(self, C, M, N, dtype, axis, mode): if mode == 'D': self.input = self.op(self.input, **self.kwargs) - # Dequantize doesn't take any arguments - self.op = lambda x, **kwargs: x.dequantize() + + def dequant(input, scales, zero_points, axis: int, dtype: int): + return input.dequantize() + self.op = dequant self.set_module_name('DequantizePerChannel') - def forward(self): - return self.op(self.input, **self.kwargs) + self.inputs = { + "input": self.input, + 'scales': torch.tensor([1.0] * channel_len), + 'zero_points': torch.tensor([0] * channel_len), + 'axis': axis, + 'dtype': dtype + } + + def forward(self, input, scales, zero_points, axis: int, dtype: int): + return self.op(input, scales=scales, zero_points=zero_points, axis=axis, dtype=dtype) op_bench.generate_pt_test( @@ -141,12 +155,14 @@ def forward(self): class FakeQuantizeBenchmark(op_bench.TorchBenchmarkBase): r"""Benchmarks fake quantization with default parameters.""" def init(self, N, C, H, W): - self.input = torch.rand(N, C, H, W) + self.inputs = { + "input": torch.rand(N, C, H, W) + } self.op = tq.FakeQuantize() self.set_module_name('FakeQuantize') - def forward(self): - return self.op(self.input) + def forward(self, input): + return self.op(input) op_bench.generate_pt_test( @@ -160,11 +176,37 @@ def forward(self): # scale and zero point. # original_kernel represents the original fake quantize c++ kernel. +def fakeQuantizePerTensorPyModule( + input, scale, zero_point, + quant_min: int, quant_max: int +): + return _LearnableFakeQuantizePerTensorOp.apply(input, scale, zero_point, quant_min, quant_max, 1.0) + +def fakeQuantizePerTensorLearnableKernel( + input, scale, zero_point, + quant_min: int, quant_max: int +): + return torch._fake_quantize_learnable_per_tensor_affine(input, scale, zero_point, quant_min, quant_max) + +def fakeQuantizePerTensorOriginalKernel( + input, scale, zero_point, + quant_min: int, quant_max: int +): + return torch.fake_quantize_per_tensor_affine(input, 1.0, 0, quant_min, quant_max) + +fake_quantize_per_tensor_ops = op_bench.op_list( + attrs=( + ('py_module', fakeQuantizePerTensorPyModule), + ('learnable_kernel', fakeQuantizePerTensorLearnableKernel), + ('original_kernel', fakeQuantizePerTensorOriginalKernel) + ), + attr_names=('op_name', 'op_func'), +) + fake_quantize_operator_configs_short = op_bench.config_list( cross_product_configs={ 'nbits': (4, 8), 'device': ('cpu', 'cuda'), - 'op_type': ('py_module', 'learnable_kernel', 'original_kernel') }, **fake_quantize_configs_short_dict ) @@ -172,87 +214,114 @@ def forward(self): fake_quantize_operator_configs_long = op_bench.cross_product_configs( nbits=(4, 8), device=('cpu', 'cuda'), - op_type=('py_module', 'learnable_kernel', 'original_kernel'), **fake_quantize_configs_long_dict ) -class FakeQuantizePerTensorOpBenchmark(op_bench.TorchBenchmarkBase): +class FakeQuantizePerTensorBaseOpBenchmark(op_bench.TorchBenchmarkBase): r"""Benchmarks 3 different fake quantize per tensor operators.""" - def init(self, N, C, H, W, nbits, device, op_type): + def init(self, N, C, H, W, nbits, device, op_func): self.quant_min = 0 self.quant_max = 2 ** nbits - 1 self.quant_range = 2 ** nbits - self.input = torch.rand(N, C, H, W, dtype=torch.float, device=device) - self.scale = torch.tensor([1.]).to(device) - self.zero_point = torch.tensor([0.]).to(device) - self.input.requires_grad_() - self.scale.requires_grad_() - self.zero_point.requires_grad_() - self.args = [ - self.input, self.scale, self.zero_point, - self.quant_min, self.quant_max - ] - if op_type == 'py_module': - self.op = _LearnableFakeQuantizePerTensorOp.apply - self.args.append(1.) - elif op_type == 'learnable_kernel': - self.op = torch._fake_quantize_learnable_per_tensor_affine - else: - # Replace tensors with float and long types for original per tensor - # fake quantize kernel. - self.args[1], self.args[2] = 1., 0 - self.op = torch.fake_quantize_per_tensor_affine + self.input = torch.rand(N, C, H, W, dtype=torch.float, device=device, requires_grad=self.auto_set()) + self.scale = torch.tensor([1.], requires_grad=self.auto_set()).to(device) + self.zero_point = torch.tensor([0.], requires_grad=self.auto_set()).to(device) + + self.inputs = { + "input": self.input, + "scale": self.scale, + "zero_point": self.zero_point, + "quant_min": self.quant_min, + "quant_max": self.quant_max, + } + self.op_func = op_func - def forward(self): - return self.op(*self.args) + def forward( + self, input, scale, zero_point, + quant_min: int, quant_max: int + ): + return self.op_func(input, scale, zero_point, quant_min, quant_max) -op_bench.generate_pt_test( +op_bench.generate_pt_tests_from_op_list( + fake_quantize_per_tensor_ops, fake_quantize_operator_configs_short + fake_quantize_operator_configs_long, - FakeQuantizePerTensorOpBenchmark + FakeQuantizePerTensorBaseOpBenchmark ) - -op_bench.generate_pt_gradient_test( +op_bench.generate_pt_gradient_tests_from_op_list( + fake_quantize_per_tensor_ops, fake_quantize_operator_configs_short + fake_quantize_operator_configs_long, - FakeQuantizePerTensorOpBenchmark + FakeQuantizePerTensorBaseOpBenchmark +) + +def fakeQuantizePerChannelPyModule( + input, scale, zero_point, axis: int, + quant_min: int, quant_max: int +): + return _LearnableFakeQuantizePerChannelOp.apply(input, scale, zero_point, axis, quant_min, quant_max, 1.0) + +def fakeQuantizePerChannelLearnableKernel( + input, scale, zero_point, axis: int, + quant_min: int, quant_max: int +): + return torch._fake_quantize_learnable_per_channel_affine(input, scale, zero_point, axis, quant_min, quant_max) + +def fakeQuantizePerChannelOriginalKernel( + input, scale, zero_point, axis: int, + quant_min: int, quant_max: int +): + return torch.fake_quantize_per_channel_affine(input, scale, zero_point, axis, quant_min, quant_max) + +fake_quantize_per_channel_ops = op_bench.op_list( + attrs=( + ('py_module', fakeQuantizePerChannelPyModule), + ('learnable_kernel', fakeQuantizePerChannelLearnableKernel), + ('original_kernel', fakeQuantizePerChannelOriginalKernel) + ), + attr_names=('op_name', 'op_func'), ) class FakeQuantizePerChannelOpBenchmark(op_bench.TorchBenchmarkBase): r"""Benchmarks 3 different fake quantize per channel operators.""" - def init(self, N, C, H, W, nbits, device, op_type): + def init(self, N, C, H, W, nbits, device, op_func): self.quant_min = 0 self.quant_max = 2 ** nbits - 1 self.quant_range = 2 ** nbits # Axis is chosen with respect to the number of channels: C. self.axis = 1 - self.input = torch.rand(N, C, H, W, dtype=torch.float, device=device) - self.scale = torch.ones(C, device=device, dtype=torch.float32) - self.zero_point = torch.zeros(C, device=device, dtype=torch.float32) - self.input.requires_grad_() - self.scale.requires_grad_() - self.zero_point.requires_grad_() - self.args = [ - self.input, self.scale, self.zero_point, - self.axis, self.quant_min, self.quant_max - ] - if op_type == 'py_module': - self.op = _LearnableFakeQuantizePerChannelOp.apply - self.args.append(1.) - elif op_type == 'learnable_kernel': - self.op = torch._fake_quantize_learnable_per_channel_affine + self.input = torch.rand(N, C, H, W, dtype=torch.float, device=device, requires_grad=self.auto_set()) + + if op_func.__name__ == 'fakeQuantizePerChannelOriginalKernel': + self.scale = torch.ones(C, device=device, dtype=torch.float32, requires_grad=False) + self.zero_point = torch.zeros(C, device=device, dtype=torch.int64, requires_grad=False) else: - self.args[1] = torch.ones(C, device=device, dtype=torch.float32) - self.args[2] = torch.zeros(C, device=device, dtype=torch.int64) - self.op = torch.fake_quantize_per_channel_affine + self.scale = torch.ones(C, device=device, dtype=torch.float32, requires_grad=self.auto_set()) + self.zero_point = torch.zeros(C, device=device, dtype=torch.float32, requires_grad=self.auto_set()) + + self.inputs = { + "input": self.input, + "scale": self.scale, + "zero_point": self.zero_point, + "axis": self.axis, + "quant_min": self.quant_min, + "quant_max": self.quant_max, + } - def forward(self): - return self.op(*self.args) + self.op_func = op_func -op_bench.generate_pt_test( + def forward( + self, input, scale, zero_point, + axis: int, quant_min: int, quant_max: int + ): + return self.op_func(input, scale, zero_point, axis, quant_min, quant_max) + +op_bench.generate_pt_tests_from_op_list( + fake_quantize_per_channel_ops, fake_quantize_operator_configs_short + fake_quantize_operator_configs_long, FakeQuantizePerChannelOpBenchmark ) -op_bench.generate_pt_gradient_test( +op_bench.generate_pt_gradient_tests_from_op_list( + fake_quantize_per_channel_ops, fake_quantize_operator_configs_short + fake_quantize_operator_configs_long, FakeQuantizePerChannelOpBenchmark ) diff --git a/benchmarks/operator_benchmark/pt/qunary_test.py b/benchmarks/operator_benchmark/pt/qunary_test.py index 4b800857caffb..2b3cb34ab30c9 100644 --- a/benchmarks/operator_benchmark/pt/qunary_test.py +++ b/benchmarks/operator_benchmark/pt/qunary_test.py @@ -30,13 +30,15 @@ def init(self, M, N, dtype, op_func): f_input = torch.rand(M, N) scale = 1.0 zero_point = 0 - self.q_input = torch.quantize_per_tensor(f_input, scale=scale, + self.inputs = { + "q_input": torch.quantize_per_tensor(f_input, scale=scale, zero_point=zero_point, dtype=dtype) + } self.op_func = op_func - def forward(self): - return self.op_func(self.q_input) + def forward(self, q_input): + return self.op_func(q_input) # TODO: Uncomment the ops whenever they are implemented for quantized tensor. @@ -153,17 +155,19 @@ def forward(self): class QTopkOpBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, dtype, k): - self.k = k f_input = torch.rand(M, N) scale = 1.0 zero_point = 0 - self.q_input = torch.quantize_per_tensor(f_input, scale=scale, + self.inputs = { + "q_input": torch.quantize_per_tensor(f_input, scale=scale, zero_point=zero_point, - dtype=dtype) + dtype=dtype), + "k": k + } self.set_module_name('qtopk') - def forward(self): - return torch.topk(self.q_input, self.k) + def forward(self, q_input, k: int): + return torch.topk(q_input, k) op_bench.generate_pt_test(qunary_ops_topk_configs_short + qunary_ops_topk_configs_long, QTopkOpBenchmark) diff --git a/benchmarks/operator_benchmark/pt/remainder_test.py b/benchmarks/operator_benchmark/pt/remainder_test.py index ffb38f785b550..1aa7770d63e1b 100644 --- a/benchmarks/operator_benchmark/pt/remainder_test.py +++ b/benchmarks/operator_benchmark/pt/remainder_test.py @@ -47,10 +47,15 @@ def init(self, M, N, K, device, dtype, op_func): # +1 so we don't divide by zero self.divisor = (self.divisor * 40 + 1).to(dtype=dtype) + self.inputs = { + "dividend": self.dividend, + "divisor": self.divisor + } + self.op_func = op_func - def forward(self): - return self.op_func(self.dividend, self.divisor) + def forward(self, dividend, divisor): + return self.op_func(dividend, divisor) op_bench.generate_pt_tests_from_op_list(remainder_ops_list, diff --git a/benchmarks/operator_benchmark/pt/softmax_test.py b/benchmarks/operator_benchmark/pt/softmax_test.py index 65446c5c30ee3..237d9001e017b 100644 --- a/benchmarks/operator_benchmark/pt/softmax_test.py +++ b/benchmarks/operator_benchmark/pt/softmax_test.py @@ -47,11 +47,13 @@ class SoftmaxBenchmark(op_bench.TorchBenchmarkBase): def init(self, N, C, H, W, device, op_func): - self.input_one = torch.rand(N, C, H, W, device=device) + self.inputs = { + "input": torch.rand(N, C, H, W, device=device) + } self.op_func = op_func() - def forward(self): - return self.op_func(self.input_one) + def forward(self, input): + return self.op_func(input) op_bench.generate_pt_tests_from_op_list(softmax_ops_list, diff --git a/benchmarks/operator_benchmark/pt/split_test.py b/benchmarks/operator_benchmark/pt/split_test.py index f4da9437351eb..2972db5d2d1b9 100644 --- a/benchmarks/operator_benchmark/pt/split_test.py +++ b/benchmarks/operator_benchmark/pt/split_test.py @@ -30,12 +30,14 @@ class SplitBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, parts, device): - self.input_one = torch.rand(M, N, device=device) - self.split_size = int(M * N / parts) + self.inputs = { + "input": torch.rand(M, N, device=device), + "split_size": int(M * N / parts) + } self.set_module_name('split') - def forward(self): - return torch.split(self.input_one, self.split_size) + def forward(self, input, split_size: int): + return torch.split(input, split_size) op_bench.generate_pt_test(split_configs_short + split_configs_long, diff --git a/benchmarks/operator_benchmark/pt/sum_test.py b/benchmarks/operator_benchmark/pt/sum_test.py index 6b7fef83469e7..799267dfc7de8 100644 --- a/benchmarks/operator_benchmark/pt/sum_test.py +++ b/benchmarks/operator_benchmark/pt/sum_test.py @@ -33,11 +33,14 @@ def init(self, R, V, dim, contiguous, device): else: self.input_tensor = tensor - self.dim = dim + self.inputs = { + "input_tensor": self.input_tensor, + "dim": dim + } self.set_module_name("sum") - def forward(self): - return self.input_tensor.sum(dim=self.dim) + def forward(self, input_tensor, dim: int): + return input_tensor.sum(dim=dim) op_bench.generate_pt_test(sum_configs, SumBenchmark) diff --git a/benchmarks/operator_benchmark/pt/tensor_to_test.py b/benchmarks/operator_benchmark/pt/tensor_to_test.py new file mode 100644 index 0000000000000..0afaa3191d4ef --- /dev/null +++ b/benchmarks/operator_benchmark/pt/tensor_to_test.py @@ -0,0 +1,43 @@ +import operator_benchmark as op_bench +import torch + +tensor_conversion_short_configs = op_bench.cross_product_configs( + M=(8, 16, 32,), + N=(16, 64, 128,), + device=['cpu', 'cuda'], + tags=['short'], +) + +tensor_conversion_long_configs = op_bench.cross_product_configs( + M=(64, 128, 256, 512,), + N=(256, 512, 1024, 2048,), + device=['cpu', 'cuda'], + tags=['long'], +) + +class FloatToHalfTensorConversionBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, device): + self.inputs = { + "input": torch.rand(M, N, device=device, requires_grad=False, dtype=torch.float) + } + + def forward(self, input): + return input.to(torch.half) + +class HalfToFloatTensorConversionBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, device): + self.inputs = { + "input": torch.rand(M, N, device=device, requires_grad=False, dtype=torch.half) + } + + def forward(self, input): + return input.to(torch.float) + + +op_bench.generate_pt_test(tensor_conversion_short_configs, FloatToHalfTensorConversionBenchmark) +op_bench.generate_pt_test(tensor_conversion_long_configs, FloatToHalfTensorConversionBenchmark) +op_bench.generate_pt_test(tensor_conversion_short_configs, HalfToFloatTensorConversionBenchmark) +op_bench.generate_pt_test(tensor_conversion_long_configs, HalfToFloatTensorConversionBenchmark) + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/unary_test.py b/benchmarks/operator_benchmark/pt/unary_test.py index 1391283b1e103..7fd465d6525d6 100644 --- a/benchmarks/operator_benchmark/pt/unary_test.py +++ b/benchmarks/operator_benchmark/pt/unary_test.py @@ -27,12 +27,43 @@ class UnaryOpBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, device, op_func): - self.input_one = torch.rand(M, N, device=device) + self.inputs = { + "input": torch.rand(M, N, device=device) + } self.op_func = op_func - def forward(self): - return self.op_func(self.input_one) + def forward(self, input): + return self.op_func(input) +def bernoulli_(input): + return input.bernoulli_() + +def cauchy_(input): + return input.cauchy_() + +def digamma_(input): + return input.digamma_() + +def exponential_(input): + return input.exponential_() + +def normal_(input): + return input.normal_() + +def random_(input): + return input.random_() + +def sign_(input): + return input.sign_() + +def uniform_(input): + return input.uniform_() + +def half_(input): + return input.half() + +def long_(input): + return input.long() unary_ops_list = op_bench.op_list( attr_names=['op_name', 'op_func'], @@ -105,18 +136,18 @@ def forward(self): ['tanh_', torch.tanh_], ['trunc', torch.trunc], ['trunc_', torch.trunc_], - ['unique', torch.unique], + ['unique', torch.functional._return_output], ['zero_', torch.zero_], - ['bernoulli_', lambda t: t.bernoulli_()], - ['cauchy_', lambda t: t.cauchy_()], - ['digamma_', lambda t: t.digamma_()], - ['exponential_', lambda t: t.exponential_()], - ['normal_', lambda t: t.normal_()], - ['random_', lambda t: t.random_()], - ['sign_', lambda t: t.sign_()], - ['uniform_', lambda t: t.uniform_()], - ['half', lambda t: t.half()], - ['long', lambda t: t.long()], + ['bernoulli_', bernoulli_], + ['cauchy_', cauchy_], + ['digamma_', digamma_], + ['exponential_', exponential_], + ['normal_', normal_], + ['random_', random_], + ['sign_', sign_], + ['uniform_', uniform_], + ['half', half_], + ['long', long_], ], ) diff --git a/benchmarks/operator_benchmark/pt_extension/cpp_extension_test.py b/benchmarks/operator_benchmark/pt_extension/cpp_extension_test.py index 07e887205300a..b0f2d94869902 100644 --- a/benchmarks/operator_benchmark/pt_extension/cpp_extension_test.py +++ b/benchmarks/operator_benchmark/pt_extension/cpp_extension_test.py @@ -1,8 +1,7 @@ -import torch -import cpp_extension # noqa - import unittest +import cpp_extension # noqa +import torch class TestConsumeOp(unittest.TestCase): @@ -23,3 +22,24 @@ def foo(x): value = r(x) self.assertEqual(value, torch.sum(x)) self.assertEqual(occurance, iters) + + def test_jit_consume_op_for_list_input(self): + iters = 6 + + def foo(x): + for i in range(iters): + result = torch.ops.operator_benchmark._consume(torch.chunk(x, 2)) + return result + + r = torch.jit.trace(foo, torch.rand(2, 2)) + + graph = str(r.graph) + occurance = graph.count("aten::chunk") + + x = torch.rand(2, 2) + value = r(x) + + self.assertTrue( + all([torch.allclose(t1, t2) for t1, t2 in zip(value, torch.chunk(x, 2))]) + ) + self.assertEqual(occurance, iters) diff --git a/benchmarks/operator_benchmark/pt_extension/extension.cpp b/benchmarks/operator_benchmark/pt_extension/extension.cpp index 0bbaed886ee64..2dbdfdd8b3e66 100644 --- a/benchmarks/operator_benchmark/pt_extension/extension.cpp +++ b/benchmarks/operator_benchmark/pt_extension/extension.cpp @@ -1,20 +1,28 @@ #include #include +using torch::List; using torch::Tensor; Tensor consume(Tensor a) { return a; } +List consume_list(List a) { + return a; +} + // When JIT tracing is used on function with constant for loop, // the for loop is optimized away because of dead code elimination. // That caused an issue for our op benchmark which needs to run an op // in a loop and report the execution time. This diff resolves that issue by // registering this consume op with correct alias information which is DEFAULT. -auto reg = torch::RegisterOperators() - .op("operator_benchmark::_consume", &consume); +TORCH_LIBRARY_FRAGMENT(operator_benchmark, m) { + m.def("_consume", &consume); + m.def("_consume.list", &consume_list); +} PYBIND11_MODULE(cpp_extension, m) { m.def("_consume", &consume, "consume"); + m.def("_consume_list", &consume_list, "consume_list"); } diff --git a/benchmarks/profiler_benchmark/profiler_bench.py b/benchmarks/profiler_benchmark/profiler_bench.py index 616d1078ee7da..75cd490fed2e6 100644 --- a/benchmarks/profiler_benchmark/profiler_bench.py +++ b/benchmarks/profiler_benchmark/profiler_bench.py @@ -1,33 +1,22 @@ -from functools import partial -import itertools -import statistics +import argparse +import sys import timeit import torch -TENSOR_SIZES = [1, 32, 128, 256, 512] -INTERNAL_ITER = 256 -PARALLEL_TASKS_NUM = 4 -N = 100 +from torch.utils.benchmark import Timer +PARALLEL_TASKS_NUM = 4 +INTERNAL_ITER = None def loop_workload(x): for i in range(INTERNAL_ITER): x = torch.mm(x, x) return x -traced_loop_workload = None -def run_profiler_benchmark_loop(input_x, use_cuda, profiling_enabled): - if profiling_enabled: - with torch.autograd.profiler.profile(use_cuda=use_cuda) as prof: - traced_loop_workload(input_x) - else: - traced_loop_workload(input_x) - -def parallel_task(x): - for i in range(int(INTERNAL_ITER / PARALLEL_TASKS_NUM)): - x = torch.mm(x, x) - return x - def parallel_workload(x): + def parallel_task(x): + for i in range(int(INTERNAL_ITER / PARALLEL_TASKS_NUM)): + x = torch.mm(x, x) + return x futs = [] for i in range(PARALLEL_TASKS_NUM): futs.append(torch.jit._fork(parallel_task, x)) @@ -35,50 +24,77 @@ def parallel_workload(x): torch.jit._wait(futs[i]) return x -traced_parallel_workload = None -def run_profiler_benchmark_parallel(input_x, use_cuda, profiling_enabled): - if profiling_enabled: - with torch.autograd.profiler.profile(use_cuda=use_cuda) as prof: - traced_parallel_workload(input_x) - else: - traced_parallel_workload(input_x) if __name__ == '__main__': - for workload_name in ["loop", "parallel"]: - print("Payload: {}; {} iterations, N = {}\n".format( - workload_name, INTERNAL_ITER, N)) - for params in itertools.product([False, True], TENSOR_SIZES, [False, True]): - use_cuda = params[0] - profiling_tensor_size = params[1] - profiling_enabled = params[2] - - if (use_cuda and not torch.cuda.is_available()): - continue - - print("Profiling {}, tensor size {}x{}, use cuda: {}".format( - "enabled" if profiling_enabled else "disabled", - profiling_tensor_size, profiling_tensor_size, use_cuda)) - - input_x = torch.rand(profiling_tensor_size, profiling_tensor_size) - if use_cuda: - input_x = input_x.cuda() - workload = None - if workload_name == "loop": - workload = partial( - run_profiler_benchmark_loop, input_x, use_cuda, profiling_enabled) - traced_loop_workload = torch.jit.trace(loop_workload, input_x) - elif workload_name == "parallel": - workload = partial( - run_profiler_benchmark_parallel, input_x, use_cuda, profiling_enabled) - traced_parallel_workload = torch.jit.trace( - parallel_workload, input_x) - - runtimes = timeit.repeat(workload, repeat=N, number=1) - avg_time = statistics.mean(runtimes) * 1000.0 - stddev_time = statistics.stdev(runtimes) * 1000.0 - print("\tavg. time: {:.3f} ms, stddev: {:.3f} ms".format( - avg_time, stddev_time)) - if workload_name == "loop": - print("\ttime per iteration: {:.3f} ms".format( - avg_time / INTERNAL_ITER)) - print() + torch._C._set_graph_executor_optimize(False) + parser = argparse.ArgumentParser( + description='Profiler benchmark') + + parser.add_argument('--with_cuda', action='store_true') + parser.add_argument('--with_stack', action='store_true') + parser.add_argument('--use_script', action='store_true') + parser.add_argument('--use_kineto', action='store_true') + parser.add_argument('--profiling_tensor_size', default=1, type=int) + parser.add_argument('--workload', default='loop', type=str) + parser.add_argument('--internal_iter', default=256, type=int) + parser.add_argument('--timer_min_run_time', default=10, type=int) + parser.add_argument('--cuda_only', action='store_true') + + args = parser.parse_args() + + if args.with_cuda and not torch.cuda.is_available(): + print("No CUDA available") + sys.exit() + + print("Payload: {}, {} iterations; timer min. runtime = {}\n".format( + args.workload, args.internal_iter, args.timer_min_run_time)) + INTERNAL_ITER = args.internal_iter + + for profiling_enabled in [False, True]: + print("Profiling {}, tensor size {}x{}, use cuda: {}, use kineto: {}, with stacks: {}, use script: {}".format( + "enabled" if profiling_enabled else "disabled", + args.profiling_tensor_size, + args.profiling_tensor_size, + args.with_cuda, + args.use_kineto, + args.with_stack, + args.use_script)) + + input_x = torch.rand( + args.profiling_tensor_size, + args.profiling_tensor_size) + + if args.with_cuda: + input_x = input_x.cuda() + + workload = None + assert args.workload in ["loop", "parallel"] + if args.workload == "loop": + workload = loop_workload + else: + workload = parallel_workload + + if args.use_script: + traced_workload = torch.jit.trace(workload, (input_x,)) + workload = traced_workload + + if profiling_enabled: + def payload(): + x = None + with torch.autograd.profiler.profile( + use_cuda=args.with_cuda, + with_stack=args.with_stack, + use_kineto=args.use_kineto, + use_cpu=not args.cuda_only) as prof: + x = workload(input_x) + return x + else: + def payload(): + return workload(input_x) + + t = Timer( + "payload()", + globals={"payload": payload}, + timer=timeit.default_timer, + ).blocked_autorange(min_run_time=args.timer_min_run_time) + print(t) diff --git a/benchmarks/record_function_benchmark/record_function_bench.py b/benchmarks/record_function_benchmark/record_function_bench.py index ddd8243ebf0a3..830328247bb50 100644 --- a/benchmarks/record_function_benchmark/record_function_bench.py +++ b/benchmarks/record_function_benchmark/record_function_bench.py @@ -1,7 +1,7 @@ import argparse import sys import torch -import torch.utils._benchmark as benchmark_utils +import torch.utils.benchmark as benchmark_utils try: diff --git a/benchmarks/sparse/matmul_dlmc_bench.py b/benchmarks/sparse/matmul_dlmc_bench.py new file mode 100644 index 0000000000000..6112b6105e6ff --- /dev/null +++ b/benchmarks/sparse/matmul_dlmc_bench.py @@ -0,0 +1,198 @@ +# Sparse benchmarks + +# These benchmarks are for the sparse matrix functionality. +# They exist for comparing the performance of sparse matrix routines +# torch.sparse.mm(sparse, sparse)` with different backends (CPU/CUDA) +# and with other frameworks such as scipy. + +import sys +from scipy import sparse +import numpy as np +from pathlib import Path +import pandas as pd +import argparse +import torch +import torch.utils.benchmark as benchmark_utils + +def read_matrix_params(path): + sys.stdin = open(path) + nrows, ncols, nnz = map(lambda el: int(el), input().split(', ')) + return (nrows, ncols), nnz + + +def load_matrix(path): + sys.stdin = open(path) + nrows, ncols, nnz = map(lambda el: int(el), input().split(', ')) + index_pointers = map(lambda el: int(el), input().split()) + indices = map(lambda el: int(el), input().split()) + + index_pointers = list(index_pointers) + indices = list(indices) + data = np.random.rand(nnz) + coo = sparse.csr_matrix( + (data, np.array(indices), np.array(index_pointers)), + shape=(nrows, ncols)).tocoo() + return torch.sparse_coo_tensor([coo.row, coo.col], coo.data, coo.shape) + + +def scipy_coo_matmul(mat1, mat2): + result = mat1.dot(mat2).tocoo() + return torch.sparse_coo_tensor([result.row, result.col], result.data, + result.shape) + + +def to_coo_scipy(x): + indices_1 = x._indices().numpy() + values_1 = x._values().numpy() + return sparse.coo_matrix((values_1, (indices_1[0], indices_1[1])), + shape=x.shape) + + +def torch_backward(a_dense, b_dense): + a_dense.requires_grad = True + b_dense.requires_grad = True + r1 = a_dense.matmul(b_dense) + f1 = torch.sum(r1) + f1.backward() + + +def sparse_torch_backward(a, b): + a.requires_grad = True + b.requires_grad = True + + r2 = torch.sparse.mm(a, b) + f2 = torch.sparse.sum(r2) + f2.backward() + + +def load_dataset(dataset_path, hidden_size, sparsity, n_limit=20): + current_folder_path = f"{dataset_path}/{sparsity}" + path = Path(current_folder_path) + files = path.glob('**/*.smtx') + xs = [] + ys = [] + print(dataset_path, hidden_size, sparsity) + index = 0 + for elem in files: + if index == n_limit: + break + print('.', end='') + size, nnz = read_matrix_params(elem.as_posix()) + if size[1] == hidden_size: + xs.append(load_matrix(elem.as_posix())) + if size[0] == hidden_size: + ys.append(load_matrix(elem.as_posix())) + index += 1 + print() + return zip(xs, ys) + + +if __name__ == '__main__': + + path = Path() + parser = argparse.ArgumentParser(description='Sparse Matmul Bench') + + parser.add_argument('--path', type=str, help='dataset path') + parser.add_argument('--dataset', + type=str, + help='dataset name', + default='random_pruning') + parser.add_argument('--operation', + type=str, + help='matmul or backward', + default='matmul') + parser.add_argument('--output', + type=str, + help='dataframe output path', + default='/tmp/matmul_bench.pkl') + args = parser.parse_args() + print('path =', args.path) + print('dataset =', args.dataset) + print('operation =', args.operation) + print('output =', args.output) + + dataset_path = args.path + dataset_name = args.dataset + dataset_path = f"{dataset_path}/{dataset_name}" + df_output_path = args.output + tasks = [] + if args.operation == 'matmul': + tasks = [ + ("matmul", "cpu", "torch", "torch.mm(dense_x, dense_y)"), + ("matmul", "cpu", "torch.sparse", "torch.sparse.mm(tx, ty)"), + ("matmul", "cpu", "scipy", + "scipy_coo_matmul(scipy_varx, scipy_vary)"), + ("matmul", "cuda", "torch", + "torch.mm(dense_cuda_x, dense_cuda_y)"), + ("matmul", "cuda", "torch.sparse", + "torch.sparse.mm(tx_cuda, ty_cuda)"), + ] + else: + tasks = [ + ("backward", "cpu", "torch", "torch_backward(dense_x, dense_y)"), + ("backward", "cpu", "torch.sparse", + "sparse_torch_backward(tx, ty)"), + ("backward", "cuda", "torch", + "torch_backward(dense_cuda_x, dense_cuda_y)"), + ("backward", "cuda", "torch.sparse", + "sparse_torch_backward(tx_cuda, ty_cuda)"), + ] + serialized_results = [] + repeats = 2 + timers = [ + benchmark_utils.Timer( + stmt=stmt, + globals={ + "scipy_coo_matmul": scipy_coo_matmul, + "torch_backward": torch_backward, + "sparse_torch_backward": sparse_torch_backward, + "scipy_varx": to_coo_scipy(x), + "scipy_vary": to_coo_scipy(y), + "tx": x, + "ty": y, + "tx_cuda": x.cuda(), + "ty_cuda": y.cuda(), + "dense_cuda_x": x.to_dense().cuda(), + "dense_cuda_y": y.to_dense().cuda(), + "dense_x": x.to_dense(), + "dense_y": y.to_dense(), + }, + label=label, + sub_label=sub_label, + description=f"{sparsity}", + env=device, + # num_threads=num_threads, + ) for hidden_size in [512] + for sparsity in [0.5, 0.7, 0.8, 0.9, 0.95, 0.98] + for label, device, sub_label, stmt in tasks + for num_threads in [1, 4, 8, 16] + for x, y in load_dataset(dataset_path, hidden_size, sparsity) + ] + measurements = [] + + for i, timer in enumerate(timers * repeats): + m = timer.blocked_autorange(min_run_time=0.05) + serialized_results.append(pickle.dumps(m)) + m.metadata = { + "device": 'cuda' if m.task_spec.env.find("cuda") >= 0 else 'cpu' + } + measurements.append(m) + print(f"\r{i + 1} / {len(timers) * repeats}", end="") + sys.stdout.flush() + print() + + comparison = benchmark_utils.Compare( + [pickle.loads(i) for i in serialized_results]) + + print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n") + comparison.print() + + print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n") + comparison.trim_significant_figures() + comparison.colorize() + comparison.print() + + table = [(m.task_spec.sub_label, m.task_spec.description, + m.metadata["device"], m.mean) for m in measurements] + df = pd.DataFrame(table, columns=['method', 'sparsity', 'device', 'time']) + df.to_pickle(df_output_path) diff --git a/benchmarks/sparse/test.sh b/benchmarks/sparse/test.sh new file mode 100644 index 0000000000000..d7a3bc667b652 --- /dev/null +++ b/benchmarks/sparse/test.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +DATASET_ROOT_DIR=$HOME/datasets/ + +# wget https://storage.googleapis.com/sgk-sc2020/dlmc.tar.gz -P $DATASET_ROOT_DIR +# tar -xvf $DATASET_ROOT_DIR/dlmc.tar.gz + +echo "!! SPARSE SPMS TIME BENCHMARK!! " + +python matmul_dlmc_bench.py --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset random_pruning --operation matmul --output /tmp/matmul_bench.pkl +python matmul_dlmc_bench.py --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset random_pruning --operation backward --output /tmp/backward_bench.pkl + +python plot_results.py -i /tmp/matmul_bench.pkl +python plot_results.py -i /tmp/backward_bench.pkl diff --git a/benchmarks/static_runtime/CMakeLists.txt b/benchmarks/static_runtime/CMakeLists.txt index 6191150dc61b0..0a263c2a5a911 100644 --- a/benchmarks/static_runtime/CMakeLists.txt +++ b/benchmarks/static_runtime/CMakeLists.txt @@ -1,3 +1,7 @@ -list(APPEND STATIC_RUNTIME_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt_bench.cc) list(APPEND STATIC_RUNTIME_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt.cc) +list(APPEND STATIC_RUNTIME_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt_bench.cc) set(STATIC_RUNTIME_BENCHMARK_SRCS ${STATIC_RUNTIME_BENCHMARK_SRCS} PARENT_SCOPE) + +list(APPEND STATIC_RUNTIME_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt.cc) +list(APPEND STATIC_RUNTIME_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/test_static_runtime.cc) +set(STATIC_RUNTIME_TEST_SRCS ${STATIC_RUNTIME_TEST_SRCS} PARENT_SCOPE) diff --git a/benchmarks/static_runtime/deep_wide_pt.cc b/benchmarks/static_runtime/deep_wide_pt.cc index 6ce19abd8c847..c708bf1959a3a 100644 --- a/benchmarks/static_runtime/deep_wide_pt.cc +++ b/benchmarks/static_runtime/deep_wide_pt.cc @@ -17,7 +17,7 @@ class DeepAndWide(Module): def forward(self: __torch__.DeepAndWide, ad_emb_packed: Tensor, user_emb: Tensor, - wide: Tensor) -> Tensor: + wide: Tensor) -> Tuple[Tensor]: _0 = self._fc_b _1 = self._fc_w _2 = self._sigma @@ -29,7 +29,7 @@ class DeepAndWide(Module): dp = torch.flatten(dp_unflatten, 1, -1) input = torch.cat([dp, wide_preproc], 1) fc1 = torch.addmm(_0, input, torch.t(_1), beta=1, alpha=1) - return torch.sigmoid(fc1) + return (torch.sigmoid(fc1),) )JIT"; const std::string trivial_model_1 = R"JIT( @@ -38,6 +38,25 @@ const std::string trivial_model_1 = R"JIT( return a + b * c + s )JIT"; +const std::string leaky_relu_model_const = R"JIT( + def forward(self, input): + x = torch.leaky_relu(input, 0.1) + x = torch.leaky_relu(x, 0.1) + x = torch.leaky_relu(x, 0.1) + x = torch.leaky_relu(x, 0.1) + return torch.leaky_relu(x, 0.1) +)JIT"; + +const std::string leaky_relu_model = R"JIT( + def forward(self, input, neg_slope): + x = torch.leaky_relu(input, neg_slope) + x = torch.leaky_relu(x, neg_slope) + x = torch.leaky_relu(x, neg_slope) + x = torch.leaky_relu(x, neg_slope) + return torch.leaky_relu(x, neg_slope) +)JIT"; + + void import_libs( std::shared_ptr cu, const std::string& class_name, @@ -81,3 +100,31 @@ torch::jit::Module getTrivialScriptModel() { module.define(trivial_model_1); return module; } + +torch::jit::Module getLeakyReLUScriptModel() { + torch::jit::Module module("leaky_relu"); + module.define(leaky_relu_model); + return module; +} + +torch::jit::Module getLeakyReLUConstScriptModel() { + torch::jit::Module module("leaky_relu_const"); + module.define(leaky_relu_model_const); + return module; +} + +const std::string long_model = R"JIT( + def forward(self, a, b, c): + d = torch.relu(a * b) + e = torch.relu(a * c) + f = torch.relu(e * d) + g = torch.relu(f * f) + h = torch.relu(g * c) + return h +)JIT"; + +torch::jit::Module getLongScriptModel() { + torch::jit::Module module("m"); + module.define(long_model); + return module; +} diff --git a/benchmarks/static_runtime/deep_wide_pt.h b/benchmarks/static_runtime/deep_wide_pt.h index f4f394c7ef630..c473eaf1bb952 100644 --- a/benchmarks/static_runtime/deep_wide_pt.h +++ b/benchmarks/static_runtime/deep_wide_pt.h @@ -1,6 +1,7 @@ #pragma once #include +#include struct DeepAndWide : torch::nn::Module { DeepAndWide(int num_features = 50) { @@ -31,6 +32,107 @@ struct DeepAndWide : torch::nn::Module { torch::Tensor mu_, sigma_, fc_w_, fc_b_; }; +// Implementation using native functions and pre-allocated tensors. +// It could be used as a "speed of light" for static runtime. +struct DeepAndWideFast : torch::nn::Module { + DeepAndWideFast(int num_features = 50) { + mu_ = register_parameter("mu_", torch::randn({1, num_features})); + sigma_ = register_parameter("sigma_", torch::randn({1, num_features})); + fc_w_ = register_parameter("fc_w_", torch::randn({1, num_features + 1})); + fc_b_ = register_parameter("fc_b_", torch::randn({1})); + allocated = false; + prealloc_tensors = {}; + } + + torch::Tensor forward( + torch::Tensor ad_emb_packed, + torch::Tensor user_emb, + torch::Tensor wide) { + torch::NoGradGuard no_grad; + if (!allocated) { + auto wide_offset = at::add(wide, mu_); + auto wide_normalized = at::native::mul(wide_offset, sigma_); + // Placeholder for ReplaceNaN + auto wide_preproc = at::native::clamp(wide_normalized, -10.0, 10.0); + + auto user_emb_t = at::native::transpose(user_emb, 1, 2); + auto dp_unflatten = at::native::bmm_cpu(ad_emb_packed, user_emb_t); + // auto dp = at::native::flatten(dp_unflatten, 1); + auto dp = dp_unflatten.view({dp_unflatten.size(0), 1}); + auto input = at::native::_cat_cpu({dp, wide_preproc}, 1); + + // fc1 = torch::nn::functional::linear(input, fc_w_, fc_b_); + fc_w_t_ = torch::t(fc_w_); + auto fc1 = torch::addmm(fc_b_, input, fc_w_t_); + + auto pred = at::native::sigmoid(fc1); + + prealloc_tensors = {wide_offset, + wide_normalized, + wide_preproc, + user_emb_t, + dp_unflatten, + dp, + input, + fc1, + pred}; + allocated = true; + + return pred; + } else { + // Potential optimization: add and mul could be fused together (e.g. with + // Eigen). + at::add_out(prealloc_tensors[0], wide, mu_); + at::native::mul_out(prealloc_tensors[1], prealloc_tensors[0], sigma_); + + at::native::clamp_out( + prealloc_tensors[2], prealloc_tensors[1], -10.0, 10.0); + + // Potential optimization: original tensor could be pre-transposed. + // prealloc_tensors[3] = at::native::transpose(user_emb, 1, 2); + if (prealloc_tensors[3].data_ptr() != user_emb.data_ptr()) { + auto sizes = user_emb.sizes(); + auto strides = user_emb.strides(); + prealloc_tensors[3].set_( + user_emb.storage(), + 0, + {sizes[0], sizes[2], sizes[1]}, + {strides[0], strides[2], strides[1]}); + } + + // Potential optimization: call MKLDNN directly. + at::native::bmm_out_cpu( + prealloc_tensors[4], ad_emb_packed, prealloc_tensors[3]); + + if (prealloc_tensors[5].data_ptr() != prealloc_tensors[4].data_ptr()) { + // in unlikely case that the input tensor changed we need to + // reinitialize the view + prealloc_tensors[5] = + prealloc_tensors[4].view({prealloc_tensors[4].size(0), 1}); + } + + // Potential optimization: we can replace cat with carefully constructed + // tensor views on the output that are passed to the _out ops above. + at::native::_cat_out_cpu( + prealloc_tensors[6], {prealloc_tensors[5], prealloc_tensors[2]}, 1); + at::native::addmm_cpu_out( + prealloc_tensors[7], fc_b_, prealloc_tensors[6], fc_w_t_); + at::native::sigmoid_out(prealloc_tensors[8], prealloc_tensors[7]); + + return prealloc_tensors[8]; + } + } + torch::Tensor mu_, sigma_, fc_w_, fc_b_, fc_w_t_; + std::vector prealloc_tensors; + bool allocated = false; +}; + torch::jit::Module getDeepAndWideSciptModel(int num_features = 50); torch::jit::Module getTrivialScriptModel(); + +torch::jit::Module getLeakyReLUScriptModel(); + +torch::jit::Module getLeakyReLUConstScriptModel(); + +torch::jit::Module getLongScriptModel(); diff --git a/benchmarks/static_runtime/deep_wide_pt_bench.cc b/benchmarks/static_runtime/deep_wide_pt_bench.cc index ef960d28d7eb8..295df23dd7e6a 100644 --- a/benchmarks/static_runtime/deep_wide_pt_bench.cc +++ b/benchmarks/static_runtime/deep_wide_pt_bench.cc @@ -22,6 +22,21 @@ static void BM_deep_wide_base(benchmark::State& state) { } } +static void BM_deep_wide_fast(benchmark::State& state) { + std::shared_ptr net = + std::make_shared(num_features); + + const int batch_size = state.range(0); + auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size}); + auto user_emb = torch::randn({batch_size, 1, embedding_size}); + auto wide = torch::randn({batch_size, num_features}); + // warmup + net->forward(ad_emb_packed, user_emb, wide); + for (auto _ : state) { + net->forward(ad_emb_packed, user_emb, wide); + } +} + static void BM_deep_wide_jit_graph_executor(benchmark::State& state) { auto mod = getDeepAndWideSciptModel(); @@ -60,7 +75,8 @@ static void BM_deep_wide_jit_profiling_executor(benchmark::State& state) { static void BM_deep_wide_static(benchmark::State& state) { auto mod = getDeepAndWideSciptModel(); - torch::jit::StaticRuntime runtime(mod); + auto g = torch::jit::PrepareForStaticRuntime(mod); + torch::jit::StaticRuntime runtime(g); const int batch_size = state.range(0); auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size}); @@ -75,7 +91,83 @@ static void BM_deep_wide_static(benchmark::State& state) { } } +const std::shared_ptr& getStaticGraph() { + static const std::shared_ptr g = + torch::jit::PrepareForStaticRuntime(getDeepAndWideSciptModel()); + return g; +} + +static void BM_deep_wide_static_threaded(benchmark::State& state) { + auto g = getStaticGraph(); + torch::jit::StaticRuntime runtime(g); + + const int batch_size = 1; // state.range(0); + auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size}); + auto user_emb = torch::randn({batch_size, 1, embedding_size}); + auto wide = torch::randn({batch_size, num_features}); + + std::vector inputs({ad_emb_packed, user_emb, wide}); + + for (auto _ : state) { + runtime.run(inputs); + } +} + +static void BM_leaky_relu_const(benchmark::State& state) { + auto mod = getLeakyReLUConstScriptModel(); + auto g = torch::jit::PrepareForStaticRuntime(mod); + torch::jit::StaticRuntime runtime(g); + + const int batch_size = state.range(0); + auto data = torch::randn({batch_size, num_features}); + std::vector inputs({data}); + + runtime.run(inputs); + for (auto _ : state) { + runtime.run(inputs); + } +} + +static void BM_leaky_relu(benchmark::State& state) { + auto mod = getLeakyReLUScriptModel(); + auto g = torch::jit::PrepareForStaticRuntime(mod); + torch::jit::StaticRuntime runtime(g); + + const int batch_size = state.range(0); + auto neg_slope = torch::randn(1); + auto data = torch::randn({batch_size, num_features}); + std::vector inputs({data, neg_slope[0]}); + + runtime.run(inputs); + for (auto _ : state) { + runtime.run(inputs); + } +} + +BENCHMARK(BM_leaky_relu)->RangeMultiplier(8)->Ranges({{1, 20}}); +BENCHMARK(BM_leaky_relu_const)->RangeMultiplier(8)->Ranges({{1, 20}}); + +static void BM_long_static_memory_optimization(benchmark::State& state) { + auto mod = getLongScriptModel(); + torch::jit::InferenceModuleOptions opts; + opts.optimize_memory = state.range(1); + auto g = torch::jit::PrepareForStaticRuntime(mod, opts); + torch::jit::StaticRuntime runtime(g); + + const auto N = state.range(0); + auto a = torch::randn({N, N}); + auto b = torch::randn({N, N}); + auto c = torch::randn({N, N}); + std::vector inputs({a, b, c}); + + runtime.run(inputs); + for (auto _ : state) { + runtime.run(inputs); + } +} + BENCHMARK(BM_deep_wide_base)->RangeMultiplier(8)->Ranges({{1, 20}}); +BENCHMARK(BM_deep_wide_fast)->RangeMultiplier(8)->Ranges({{1, 20}}); BENCHMARK(BM_deep_wide_jit_graph_executor) ->RangeMultiplier(8) @@ -86,5 +178,21 @@ BENCHMARK(BM_deep_wide_jit_profiling_executor) ->Ranges({{1, 20}}); BENCHMARK(BM_deep_wide_static)->RangeMultiplier(8)->Ranges({{1, 20}}); - -BENCHMARK_MAIN(); +BENCHMARK(BM_deep_wide_static_threaded)->Threads(8); + +BENCHMARK(BM_long_static_memory_optimization) + ->Args({2<<0, 0}) + ->Args({2<<2, 0}) + ->Args({2<<4, 0}) + ->Args({2<<8, 0}) + ->Args({2<<0, 1}) + ->Args({2<<2, 1}) + ->Args({2<<4, 1}) + ->Args({2<<8, 1}); + +int main(int argc, char** argv) +{ + c10::ParseCommandLineFlags(&argc, &argv); + ::benchmark::Initialize(&argc, argv); + ::benchmark::RunSpecifiedBenchmarks(); +} diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h new file mode 100644 index 0000000000000..a1cd49e3661f9 --- /dev/null +++ b/benchmarks/static_runtime/test_scripts.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +const auto list_construct_script = R"JIT( + def forward(self, a, b): + return [a, b] +)JIT"; + +const auto list_unpack_script = R"JIT( + def forward(self, a, b): + c = [a, b] + x, y = c + z = x + y + return z +)JIT"; + +const auto tuple_construct_script = R"JIT( + def forward(self, a, b): + return (a, b) +)JIT"; + +const auto add_script = R"JIT( + def forward(self, a, b): + return a + b +)JIT"; diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc index 3ad0956ced737..251e2654b0135 100644 --- a/benchmarks/static_runtime/test_static_runtime.cc +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -1,6 +1,110 @@ #include +#include #include #include "deep_wide_pt.h" +#include "test_scripts.h" + +using namespace caffe2; +using namespace torch; +using namespace torch::jit; +using c10::IValue; + +namespace { +static at::Tensor getTensor(const at::IValue& ival) { + if (ival.isTensor()) { + return ival.toTensor(); + } else if (ival.isTensorList()) { + auto tensor_vec = ival.toTensorVector(); + TORCH_CHECK(tensor_vec.size() == 1); + return tensor_vec[0]; + } else if (ival.isTuple()) { + auto tuple = ival.toTuple(); + auto ivalue_vec = tuple->elements(); + TORCH_CHECK(ivalue_vec.size() == 1); + return ivalue_vec[0].toTensor(); + } else { + CAFFE_THROW("Unknown input IValue"); + } +} + +void compareTensorLists( + const std::vector& l, /* values */ + const std::vector& r /* expects */) { + EXPECT_TRUE(l.size() == r.size()); + for (int i = 0; i < l.size(); ++i) { + ASSERT_TRUE(l[i].isTensor()); + ASSERT_TRUE(r[i].isTensor()); + LOG(INFO) << "output " << i << ": \n" << l[i] << std::endl; + LOG(INFO) << "expect " << i << ": \n" << r[i] << std::endl; + EXPECT_TRUE(l[i].toTensor().equal(r[i].toTensor())); + } +} + +void compareTensorLists( + const std::vector& l, /* values */ + const std::vector& r /* expects */) { + EXPECT_TRUE(l.size() == r.size()); + for (int i = 0; i < l.size(); ++i) { + LOG(INFO) << "output " << i << ": \n" << l[i] << std::endl; + LOG(INFO) << "expect " << i << ": \n" << r[i] << std::endl; + EXPECT_TRUE(l[i].equal(r[i])); + } +} + +// Given a model/function in jit script, run the model/function +// with the jit interpreter and static runtime, and compare the results +void testStaticRuntime( + const std::string& jit_script, + const std::vector& args) { + script::Module module("module"); + module.define(jit_script); + + auto expect = module.forward(args); + + StaticRuntime runtime(module); + auto actual = runtime.run(args, {}); + + if (expect.isTuple()) { + compareTensorLists( + expect.toTuple()->elements(), actual.toTuple()->elements()); + } else if (expect.isList()) { + compareTensorLists( + expect.toTensorVector(), actual.toTensorVector()); + } else { + EXPECT_TRUE(expect.toTensor().equal(actual.toTensor())); + } +} +} // namespace + +TEST(StaticRuntime, IndividualOps_Binary) { + auto a = at::randn({2, 3}); + auto b = at::ones({2, 3}); + + std::vector args{a, b}; + + testStaticRuntime(add_script, args); + testStaticRuntime(list_construct_script, args); + testStaticRuntime(list_unpack_script, args); + testStaticRuntime(tuple_construct_script, args); +} + +TEST(StaticRuntime, LongModel) { + torch::jit::Module mod = getLongScriptModel(); + auto a = torch::randn({2, 2}); + auto b = torch::randn({2, 2}); + auto c = torch::randn({2, 2}); + + // run jit graph executor + std::vector input_ivalues({a, b, c}); + at::Tensor output_1 = mod.forward(input_ivalues).toTensor(); + + // run static runtime + std::vector input_tensors({a, b, c}); + auto g = torch::jit::PrepareForStaticRuntime(mod); + torch::jit::StaticRuntime runtime(g); + at::Tensor output_2 = runtime.run(input_tensors)[0]; + EXPECT_TRUE(output_1.equal(output_2)); +} TEST(StaticRuntime, TrivialModel) { torch::jit::Module mod = getTrivialScriptModel(); @@ -14,7 +118,166 @@ TEST(StaticRuntime, TrivialModel) { // run static runtime std::vector input_tensors({a, b, c}); - torch::jit::StaticRuntime runtime(mod); + auto g = torch::jit::PrepareForStaticRuntime(mod); + torch::jit::StaticRuntime runtime(g); + at::Tensor output_2 = runtime.run(input_tensors)[0]; + EXPECT_TRUE(output_1.equal(output_2)); +} + +TEST(StaticRuntime, LeakyReLU) { + torch::jit::Module mod = getLeakyReLUConstScriptModel(); + auto inputs = torch::randn({2, 2}); + + // run jit graph executor + std::vector input_ivalues({inputs}); + at::Tensor output_1 = mod.forward(input_ivalues).toTensor(); + + // run static runtime + std::vector input_tensors({inputs}); + auto g = torch::jit::PrepareForStaticRuntime(mod); + torch::jit::StaticRuntime runtime(g); at::Tensor output_2 = runtime.run(input_tensors)[0]; EXPECT_TRUE(output_1.equal(output_2)); } + +TEST(StaticRuntime, DeepWide) { + const int embedding_size = 32; + const int num_features = 50; + torch::jit::Module mod = getDeepAndWideSciptModel(); + auto g = torch::jit::PrepareForStaticRuntime(mod); + torch::jit::StaticRuntime runtime(g); + + for (int batch_size : {1, 8, 32}) { + for (int i = 0; i < 2; ++i) { + auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size}); + auto user_emb = torch::randn({batch_size, 1, embedding_size}); + auto wide = torch::randn({batch_size, num_features}); + + // run jit graph executor + std::vector inputs({ad_emb_packed, user_emb, wide}); + auto output_1 = getTensor(mod.forward(inputs)); + + // run static runtime + std::vector input_tensors({ad_emb_packed, user_emb, wide}); + at::Tensor output_2 = runtime.run(input_tensors)[0]; + EXPECT_TRUE(output_1.equal(output_2)); + } + } +} + +TEST(StaticRuntime, KWargsAPI_1) { + const int embedding_size = 32; + const int num_features = 50; + auto module = getDeepAndWideSciptModel(); + torch::jit::StaticRuntime runtime(module); + + for (int batch_size : {1, 8, 32}) { + for (int i = 0; i < 2; ++i) { + auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size}); + auto user_emb = torch::randn({batch_size, 1, embedding_size}); + auto wide = torch::randn({batch_size, num_features}); + + // run jit graph executor + std::vector inputs({ad_emb_packed, user_emb, wide}); + at::Tensor output_1 = getTensor(module.forward(inputs)); + + // run static runtime + at::Tensor output_2 = getTensor(runtime.run(inputs, {})); + EXPECT_TRUE(output_1.equal(output_2)); + } + } +} + +TEST(StaticRuntime, KWargsAPI_2) { + const int embedding_size = 32; + const int num_features = 50; + auto module = getDeepAndWideSciptModel(); + auto g = torch::jit::PrepareForStaticRuntime(module); + torch::jit::StaticRuntime runtime(module); + + for (int batch_size : {1, 8, 32}) { + for (int i = 0; i < 2; ++i) { + auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size}); + auto user_emb = torch::randn({batch_size, 1, embedding_size}); + auto wide = torch::randn({batch_size, num_features}); + + // run jit graph executor + std::vector args({ad_emb_packed, user_emb, wide}); + at::Tensor output_1 = getTensor(module.forward(args)); + + std::unordered_map kwargs( + {{"ad_emb_packed", ad_emb_packed}, + {"user_emb", user_emb}, + {"wide", wide}}); + + // run static runtime + at::Tensor output_2 = getTensor(runtime.run({}, kwargs)); + EXPECT_TRUE(output_1.equal(output_2)); + } + } +} + +TEST(StaticRuntime, CleanUpMemory) { + const int embedding_size = 32; + const int num_features = 50; + torch::jit::Module mod = getDeepAndWideSciptModel(); + auto g = torch::jit::PrepareForStaticRuntime(mod); + + for (auto cleanup_memory : {true, false}) { + for (auto enable_out_variant : {true, false}) { + VLOG(1) << "cleanup_memory: " << cleanup_memory + << ", enable_out_variant: " << enable_out_variant; + torch::jit::StaticRuntimeOptions opts{cleanup_memory, enable_out_variant}; + torch::jit::StaticRuntime runtime(g, opts); + + for (int batch_size : {1, 8, 32}) { + for (int i = 0; i < 2; ++i) { + auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size}); + auto user_emb = torch::randn({batch_size, 1, embedding_size}); + auto wide = torch::randn({batch_size, num_features}); + + // run jit graph executor + std::vector inputs({ad_emb_packed, user_emb, wide}); + auto output_1 = getTensor(mod.forward(inputs)); + + // run static runtime + std::vector input_tensors( + {ad_emb_packed, user_emb, wide}); + at::Tensor output_2 = runtime.run(input_tensors)[0]; + EXPECT_TRUE(output_1.equal(output_2)); + } + } + } + } +} + +TEST(StaticRuntime, FusionPass) { + const int embedding_size = 32; + const int num_features = 50; + for (int batch_size : {1, 8, 32}) { + for (int i = 0; i < 2; ++i) { + torch::jit::Module module = getDeepAndWideSciptModel(); + auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size}); + auto user_emb = torch::randn({batch_size, 1, embedding_size}); + auto wide = torch::randn({batch_size, num_features}); + + // run jit graph executor + std::vector inputs({ad_emb_packed, user_emb, wide}); + auto output_1 = getTensor(module.forward(inputs)); + + Method method = module.get_method("forward"); + auto graph = method.graph(); + fuseStaticSubgraphs(graph); + bool hit = false; + for (const auto& n : module.get_method("forward").graph()->nodes()) { + if (n->kind() == torch::jit::prim::StaticSubgraph) { + hit = true; + } + } + EXPECT_TRUE(hit); + auto output_2 = getTensor(module.forward(inputs)); + EXPECT_TRUE(output_1.equal(output_2)); + } + } +} + diff --git a/benchmarks/tensorexpr/__main__.py b/benchmarks/tensorexpr/__main__.py index 3653761395272..a1f0a5ee2feda 100644 --- a/benchmarks/tensorexpr/__main__.py +++ b/benchmarks/tensorexpr/__main__.py @@ -111,6 +111,11 @@ def main(): action='store_true', help="Print generated kernel(s).", ) + parser.add_argument( + "--no-dynamic-shape", + action='store_true', + help="Disable shape randomization in dynamic benchmarks.", + ) args = parser.parse_args() @@ -128,6 +133,7 @@ def main(): elif args.cuda_fuser == "nvf": import torch torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_texpr_fuser_enabled(False) torch._C._jit_set_nvfuser_enabled(True) torch._C._jit_set_profiling_mode(True) else : diff --git a/benchmarks/tensorexpr/benchmark.py b/benchmarks/tensorexpr/benchmark.py index 7d6127b10181a..6c9b91bc8ec5b 100644 --- a/benchmarks/tensorexpr/benchmark.py +++ b/benchmarks/tensorexpr/benchmark.py @@ -237,6 +237,70 @@ def cuda_pointwise_context(loop_levels, block_count, block_size): if block_size: torch._C._jit_set_te_cuda_pointwise_block_size(old_block_size) +# Auxiliary class to facilitate dynamic input shape +class DynamicShape(object): + r''' + An Auxiliary class for dynamic shape benchmarks + + Pre-computes input with random shapes and also + modifies the compute method so in each call the + fuser sees a different input tensor shape + ''' + + # Number of random inputs in an instance + SAMPLE_SIZE = 100 + + def __init__(self, dynamic_range=1.2): + self._input_samples = [] + self._input_sample_index = 0 + self._dynamic_range = 1. / dynamic_range if dynamic_range > 1.0 else dynamic_range + self._enable_dynamic_shapes = True + + # Returns the input test case that current index points to + @property + def inputs(self): + return self._input_samples[self._input_sample_index] + + # An inputs assignment actually adds a test case in the class buffer + @inputs.setter + def inputs(self, val): + self._input_samples.append(val) + + # Runs normal compute while increment test case index + def compute(self): + super().compute() + self._input_sample_index = (self._input_sample_index + 1) % self.SAMPLE_SIZE + + # Defined by benchmark, the benchmark needs to specify the input + # tensor construction in this method, essentially the same way + # a benchmark creates the inputs list in the initializer + def instantiate_input(self): + raise NotImplementedError + + # Instantiate random shaped inputs and start the benchmark run + def run(self, args): + # force disable dynamic shape from command line + if args.no_dynamic_shape: + self._enable_dynamic_shapes = False + self.load_inputs() + super().run(args) + + # pre-compute inputs so the creations of random tensors + # do not add to the compute time + def load_inputs(self): + for i in range(self.SAMPLE_SIZE - 1): + self.instantiate_input() + + # returns a randomized shape + def rand_shape(self, shape): + if not self._enable_dynamic_shapes: + return shape + ratios = np.random.uniform(self._dynamic_range, 1.0, len(shape)) + dyn_shape = list( + np.multiply(shape, ratios).astype(int) + ) + return dyn_shape + benchmark_classes = [] diff --git a/benchmarks/tensorexpr/elementwise.py b/benchmarks/tensorexpr/elementwise.py index 6697b5bcc661c..af1352dfa949e 100644 --- a/benchmarks/tensorexpr/elementwise.py +++ b/benchmarks/tensorexpr/elementwise.py @@ -159,6 +159,7 @@ def register_element_ops(): # benchmark.register_benchmark_class(ElementMulBench) register_element_ops() + class SimpleElementBench(benchmark.Benchmark): def __init__(self, mode, device, dtype, N): super().__init__(mode, device, dtype) @@ -207,4 +208,23 @@ def memory_workload(self): def default_configs(): return [[1 << 25]] + benchmark.register_benchmark_class(SimpleElementBench) + + +class DynamicSimpleElementBench(benchmark.DynamicShape, SimpleElementBench): + def __init__(self, mode, device, dtype, N): + benchmark.DynamicShape.__init__(self) + SimpleElementBench.__init__(self, mode, device, dtype, N) + + @classmethod + def module(cls): + return "simple_dynamic_element" + + def instantiate_input(self): + N, = self.rand_shape([self.N]) + data = self.rand([N], device=self.device, dtype=self.dtype, requires_grad=self.requires_grad) + self.inputs = [data] + + +benchmark.register_benchmark_class(DynamicSimpleElementBench) diff --git a/benchmarks/tensorexpr/reduction.py b/benchmarks/tensorexpr/reduction.py index 6b8edc10d5ea4..bc3e4e158a175 100644 --- a/benchmarks/tensorexpr/reduction.py +++ b/benchmarks/tensorexpr/reduction.py @@ -149,9 +149,51 @@ def __init__(self, mode, device, dtype, dim0, dim1): @staticmethod def module(): return "reduce2d_outer" - benchmark.register_benchmark_class(ReduceRowBench) benchmark.register_benchmark_class(ReduceMidBench) benchmark.register_benchmark_class(ReduceColBench) benchmark.register_benchmark_class(Reduce2DInnerBench) benchmark.register_benchmark_class(Reduce2DOuterBench) + + +class DynamicReduce2DBench(benchmark.DynamicShape, Reduce2DBench): + ''' + A benchmark class to validate 2 dimensional reduction performance. + Only a simple add is fused to induce the fuser and isolate reduction perf. + ''' + + def __init__(self, mode, device, dtype, red_dim, dim0, dim1): + benchmark.DynamicShape.__init__(self) + Reduce2DBench.__init__(self, mode, device, dtype, red_dim, dim0, dim1) + + def instantiate_input(self): + dim0, dim1 = self.rand_shape([self.dim0, self.dim1]) + + self.inputs = [self.randn( + [dim0, dim1], device=self.device, dtype=self.dtype, requires_grad=self.requires_grad + )] + + @staticmethod + def module(): + return "dynamicreduce2d" + + +class DynamicReduce2DInnerBench(DynamicReduce2DBench): + def __init__(self, mode, device, dtype, dim0, dim1): + super().__init__(mode, device, dtype, 1, dim0, dim1) + + @staticmethod + def module(): + return "reduce2d_dynamic_inner" + + +class DynamicReduce2DOuterBench(DynamicReduce2DBench): + def __init__(self, mode, device, dtype, dim0, dim1): + super().__init__(mode, device, dtype, 0, dim0, dim1) + + @staticmethod + def module(): + return "reduce2d_dynamic_outer" + +benchmark.register_benchmark_class(DynamicReduce2DInnerBench) +benchmark.register_benchmark_class(DynamicReduce2DOuterBench) diff --git a/benchmarks/tensorexpr/rnn_eltwise.py b/benchmarks/tensorexpr/rnn_eltwise.py index 269e59d99685c..a56502c814627 100644 --- a/benchmarks/tensorexpr/rnn_eltwise.py +++ b/benchmarks/tensorexpr/rnn_eltwise.py @@ -65,3 +65,41 @@ def default_configs(): return [[64, 512]] benchmark.register_benchmark_class(RNNEltwise) + + +class DynamicLSTM(benchmark.DynamicShape, RNNEltwise): + def __init__(self, mode, device, dtype, b, hs): + benchmark.DynamicShape.__init__(self) + RNNEltwise.__init__(self, mode, device, dtype, b, hs) + + def instantiate_input(self): + b, hs = self.rand_shape([self.b, self.hs]) + + self.input = self.rand( + [b, 4 * hs], device=self.device, dtype=self.dtype, requires_grad=self.requires_grad + ) + self.hx = self.rand( + [b, 4 * hs], device=self.device, dtype=self.dtype, requires_grad=self.requires_grad + ) + self.cx = self.rand( + [b, hs], device=self.device, dtype=self.dtype, requires_grad=self.requires_grad + ) + self.b_ih = self.rand( + [b, 4 * hs], device=self.device, dtype=self.dtype, requires_grad=self.requires_grad + ) + self.b_hh = self.rand( + [b, 4 * hs], device=self.device, dtype=self.dtype, requires_grad=self.requires_grad + ) + self.inputs = [ + self.input, + self.hx, + self.cx, + self.b_ih, + self.b_hh, + ] + + @staticmethod + def module(): + return "dynamic_lstm" + +benchmark.register_benchmark_class(DynamicLSTM) diff --git a/binaries/CMakeLists.txt b/binaries/CMakeLists.txt index 075bc05b5ecfb..74df5089e4e31 100644 --- a/binaries/CMakeLists.txt +++ b/binaries/CMakeLists.txt @@ -4,6 +4,7 @@ if(INTERN_BUILD_MOBILE) caffe2_binary_target("speed_benchmark.cc") else() caffe2_binary_target("speed_benchmark_torch.cc") + caffe2_binary_target("compare_models_torch.cc") endif() return() endif() @@ -33,6 +34,7 @@ caffe2_binary_target("print_registered_core_operators.cc") caffe2_binary_target("run_plan.cc") caffe2_binary_target("speed_benchmark.cc") caffe2_binary_target("speed_benchmark_torch.cc") +caffe2_binary_target("compare_models_torch.cc") caffe2_binary_target("split_db.cc") caffe2_binary_target("db_throughput.cc") diff --git a/binaries/benchmark_helper.cc b/binaries/benchmark_helper.cc index 9af1ee51a41e6..fd8c705923e4f 100644 --- a/binaries/benchmark_helper.cc +++ b/binaries/benchmark_helper.cc @@ -19,6 +19,9 @@ #include #include #ifdef _WIN32 +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif #include #include #endif diff --git a/binaries/compare_models_torch.cc b/binaries/compare_models_torch.cc new file mode 100644 index 0000000000000..6275087fd4fab --- /dev/null +++ b/binaries/compare_models_torch.cc @@ -0,0 +1,262 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +C10_DEFINE_string( + refmodel, + "", + "The reference torch script model to compare against."); +C10_DEFINE_string( + model, + "", + "The torch script model to compare to the reference model."); +C10_DEFINE_string( + input_dims, + "", + "Alternate to input_files, if all inputs are simple " + "float TensorCPUs, specify the dimension using comma " + "separated numbers. If multiple input needed, use " + "semicolon to separate the dimension of different " + "tensors."); +C10_DEFINE_string(input_type, "", "Input type (uint8_t/float)"); +C10_DEFINE_string( + input_memory_format, + "contiguous_format", + "Input memory format (contiguous_format/channels_last)"); +C10_DEFINE_bool( + no_inputs, + false, + "Whether the model has any input. Will ignore other input arugments if true"); +C10_DEFINE_bool( + use_caching_allocator, + false, + "Whether to cache allocations between inference iterations"); +C10_DEFINE_bool( + print_output, + false, + "Whether to print output with all one input tensor."); +C10_DEFINE_int(iter, 10, "The number of iterations to run."); +C10_DEFINE_int(pytext_len, 0, "Length of input sequence."); +C10_DEFINE_string( + backend, + "cpu", + "what backend to use for model (vulkan, cpu, metal) (default=cpu)"); +C10_DEFINE_string( + refbackend, + "cpu", + "what backend to use for model (vulkan, cpu, metal) (default=cpu)"); +C10_DEFINE_string(tolerance, "1e-5", "tolerance to use for comparison"); + +bool checkRtol( + const at::Tensor& diff, + const std::vector& inputs, + float tolerance) { + float maxValue = 0.0f; + + for (const auto& tensor : inputs) { + maxValue = fmax(tensor.abs().max().item(), maxValue); + } + float maxDiff = diff.abs().max().item(); + + return maxDiff < (tolerance * maxValue); +} + +bool almostEqual(const at::Tensor& a, const at::Tensor& b, float tolerance) { + return checkRtol(a - b, {a, b}, tolerance); +} + +std::vector split( + char separator, + const std::string& string, + bool ignore_empty = true) { + std::vector pieces; + std::stringstream ss(string); + std::string item; + while (getline(ss, item, separator)) { + if (!ignore_empty || !item.empty()) { + pieces.push_back(std::move(item)); + } + } + return pieces; +} + +std::vector create_inputs( + std::vector& refinputs, + std::vector& inputs, + std::string& refbackend, + std::string& backend) { + if (FLAGS_no_inputs) { + return {}; + } + + CAFFE_ENFORCE_GE(FLAGS_input_dims.size(), 0, "Input dims must be specified."); + CAFFE_ENFORCE_GE(FLAGS_input_type.size(), 0, "Input type must be specified."); + + std::vector input_dims_list = split(';', FLAGS_input_dims); + std::vector input_type_list = split(';', FLAGS_input_type); + std::vector input_memory_format_list = + split(';', FLAGS_input_memory_format); + + CAFFE_ENFORCE_GE( + input_dims_list.size(), 0, "Input dims not specified correctly."); + CAFFE_ENFORCE_GE( + input_type_list.size(), 0, "Input type not specified correctly."); + CAFFE_ENFORCE_GE( + input_memory_format_list.size(), + 0, + "Input format list not specified correctly."); + + CAFFE_ENFORCE_EQ( + input_dims_list.size(), + input_type_list.size(), + "Input dims and type should have the same number of items."); + CAFFE_ENFORCE_EQ( + input_dims_list.size(), + input_memory_format_list.size(), + "Input dims and format should have the same number of items."); + + for (size_t i = 0; i < input_dims_list.size(); ++i) { + auto input_dims_str = split(',', input_dims_list[i]); + std::vector input_dims; + input_dims.reserve(input_dims_str.size()); + for (const auto& s : input_dims_str) { + input_dims.push_back(c10::stoi(s)); + } + + at::ScalarType input_type; + if (input_type_list[i] == "float") { + input_type = at::ScalarType::Float; + } else if (input_type_list[i] == "uint8_t") { + input_type = at::ScalarType::Byte; + } else if (input_type_list[i] == "int64") { + input_type = at::ScalarType::Long; + } else { + CAFFE_THROW("Unsupported input type: ", input_type_list[i]); + } + + at::MemoryFormat input_memory_format; + if (input_memory_format_list[i] == "channels_last") { + if (input_dims.size() != 4u) { + CAFFE_THROW( + "channels_last memory format only available on 4D tensors!"); + } + input_memory_format = at::MemoryFormat::ChannelsLast; + } else if (input_memory_format_list[i] == "contiguous_format") { + input_memory_format = at::MemoryFormat::Contiguous; + } else { + CAFFE_THROW( + "Unsupported input memory format: ", input_memory_format_list[i]); + } + + const auto input_tensor = torch::rand( + input_dims, + at::TensorOptions(input_type).memory_format(input_memory_format)); + + if (refbackend == "vulkan") { + refinputs.emplace_back(input_tensor.vulkan()); + } else { + refinputs.emplace_back(input_tensor); + } + + if (backend == "vulkan") { + inputs.emplace_back(input_tensor.vulkan()); + } else { + inputs.emplace_back(input_tensor); + } + } + + if (FLAGS_pytext_len > 0) { + auto stensor = FLAGS_pytext_len * at::ones({1}, torch::kI64); + if (refbackend == "vulkan") { + refinputs.emplace_back(stensor.vulkan()); + } else { + refinputs.emplace_back(stensor); + } + + if (backend == "vulkan") { + inputs.emplace_back(stensor.vulkan()); + } else { + inputs.emplace_back(stensor); + } + } + + return inputs; +} + +int main(int argc, char** argv) { + c10::SetUsageMessage( + "Run accuracy comparison to a reference model for a pytorch model.\n" + "Example usage:\n" + "./compare_models_torch" + " --refmodel=" + " --model=" + " --iter=20"); + if (!c10::ParseCommandLineFlags(&argc, &argv)) { + std::cerr << "Failed to parse command line flags!" << std::endl; + return 1; + } + + std::stringstream ss(FLAGS_tolerance); + float tolerance = 0; + ss >> tolerance; + + torch::autograd::AutoGradMode guard(false); + torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard(false); + auto module = torch::jit::load(FLAGS_model); + auto refmodule = torch::jit::load(FLAGS_refmodel); + + module.eval(); + refmodule.eval(); + + c10::CPUCachingAllocator caching_allocator; + c10::optional caching_allocator_guard; + if (FLAGS_use_caching_allocator) { + caching_allocator_guard.emplace(&caching_allocator); + } + std::cout << "Running modules." << std::endl; + + int passed = 0; + for (int i = 0; i < FLAGS_iter; ++i) { + std::vector refinputs; + std::vector inputs; + create_inputs(refinputs, inputs, FLAGS_refbackend, FLAGS_backend); + + const auto refoutput = refmodule.forward(refinputs).toTensor().cpu(); + const auto output = module.forward(inputs).toTensor().cpu(); + + bool check = almostEqual(refoutput, output, tolerance); + if (check) { + passed += 1; + } + } + std::cout << "Output was equal within tolerance " << passed << "/" + << FLAGS_iter + << " times. Pass rate: " << (float)passed / (float)FLAGS_iter * 100 + << std::setprecision(2) << "%" << std::endl; + + return 0; +} diff --git a/binaries/optimize_for_mobile.cc b/binaries/optimize_for_mobile.cc index 2e0f9a052c9f6..991bca7e55871 100644 --- a/binaries/optimize_for_mobile.cc +++ b/binaries/optimize_for_mobile.cc @@ -15,9 +15,10 @@ */ #include - +#include #include "torch/script.h" #include "torch/csrc/jit/api/module.h" +#include #include "torch/csrc/jit/passes/vulkan_rewrite.h" #include "torch/csrc/jit/passes/xnnpack_rewrite.h" #include "torch/csrc/jit/serialization/import.h" @@ -29,6 +30,7 @@ C10_DEFINE_string( "", "Name of the output model to be saved."); C10_DEFINE_string(backend, "", "The backend to be optimized"); +C10_DEFINE_string(preserved_methods, "", "Methods to be preserved") int main(int argc, char** argv) { c10::SetUsageMessage( @@ -36,7 +38,8 @@ int main(int argc, char** argv) { "./optimize_for_mobile" " --model=" " [--output=]" - " [--backend=]" + " [--backend=]" + " [--preserved_methods=]" ); if (!c10::ParseCommandLineFlags(&argc, &argv)) { @@ -48,12 +51,27 @@ int main(int argc, char** argv) { CAFFE_ENFORCE(FLAGS_model != "", c10::UsageMessage()); std::string output_model_name = - FLAGS_model.substr(0, FLAGS_model.find(".")) + "_optimized.bc"; + FLAGS_model.substr(0, FLAGS_model.find(".")) + "_optimized.ptl"; if (FLAGS_output != "") { output_model_name = FLAGS_output; } + std::vector preserved_methods; + if(FLAGS_preserved_methods != ""){ + std::stringstream ss(FLAGS_preserved_methods); + std::string m; + while(std::getline(ss, m, ';')){ + if(m != ""){ + preserved_methods.emplace_back(std::move(m)); + } + } + std::cout<<"The following methods will be preserved:"< empty_preserved_methods; - optimized_module = torch::jit::vulkanOptimizeForMobile(module, empty_preserved_methods); - } else { + optimized_module = torch::jit::vulkanOptimizeForMobile(module, preserved_methods); + } else if (FLAGS_backend == "metal"){ + optimized_module = torch::jit::metalOptimizeForMobile(module, preserved_methods); + }else{ CAFFE_ENFORCE(false, "Unknown backend: " + FLAGS_backend); } auto new_ops = torch::jit::export_opnames(optimized_module); diff --git a/binaries/record_function_benchmark.cc b/binaries/record_function_benchmark.cc index a7e3383b97f4f..c80f46d756524 100644 --- a/binaries/record_function_benchmark.cc +++ b/binaries/record_function_benchmark.cc @@ -7,53 +7,79 @@ #include #include -C10_DEFINE_int(iter, 100, "Number of iterations"); -C10_DEFINE_int(warmup_iter, 10, "Number of warmup iterations"); -C10_DEFINE_int(rec_fn_iter, 10e6, - "Number of iterations for the pure RecordFunction benchmark"); +C10_DEFINE_int(iter, 10000, "Number of iterations"); +C10_DEFINE_int(sampled_iter, 10e6, + "Number of iterations for the sampled observer benchmark"); namespace { -const int kInnerIter = 100; -const int kNumSampledCb = 2; const int kTensorSize = 16; const int kSmallTensorSize = 1; -const float kSampingProb = 0.1; - const float kLowSamplingProb = 0.0001; } -void setupBenchmarkCallbacks() { - // non-sampled callback - at::addGlobalCallback(at::RecordFunctionCallback( - [&](const at::RecordFunction& fn) {}, - [](const at::RecordFunction&) {}) - .needsInputs(true)); - - // sampled - for (auto idx = 0; idx < kNumSampledCb; ++idx) { - at::addGlobalCallback(at::RecordFunctionCallback( - [](const at::RecordFunction& fn) {}, - [](const at::RecordFunction&) {}) - .needsInputs(true) - .samplingProb(kSampingProb) - ); +void addTestCallback( + double sampling_prob = 1.0, + at::RecordFunctionCallback::StartCallback fn = + [](const at::RecordFunction&) -> std::unique_ptr { return nullptr; }) { + auto cb = at::RecordFunctionCallback( + fn, + [](const at::RecordFunction&, at::ObserverContext*) {}) + .needsInputs(false); + if (sampling_prob < 1.0) { + cb.samplingProb(sampling_prob); } + at::addGlobalCallback(cb); } -float runBench(int tensor_size, int outer_iter) { +float runTensorGEMMBench(int tensor_size, int iter) { typedef std::chrono::high_resolution_clock clock; typedef std::chrono::microseconds us; std::chrono::time_point start_time = clock::now(); - for (auto idx = 0; idx < kInnerIter * outer_iter; ++idx) { - torch::mm( - torch::randn({tensor_size, tensor_size}), - torch::randn({tensor_size, tensor_size})); + auto inp = torch::randn({tensor_size, tensor_size}); + for (auto idx = 0; idx < iter; ++idx) { + torch::mm(inp, inp); } auto duration = static_cast( std::chrono::duration_cast(clock::now() - start_time).count()); return duration; } +float runPureRecordFunctionBench(int iter) { + typedef std::chrono::high_resolution_clock clock; + typedef std::chrono::microseconds us; + std::chrono::time_point start_time = clock::now(); + for (auto idx = 0; idx < iter; ++idx) { + bool pre_sampled = false; + if (at::shouldRunRecordFunction(&pre_sampled)) { + at::RecordFunction guard(at::RecordScope::USER_SCOPE, pre_sampled); + if (C10_UNLIKELY(guard.isActive())) { + guard.before("Test", -1); + } + } + } + auto duration = static_cast( + std::chrono::duration_cast(clock::now() - start_time).count()); + return duration; +} + +void runBenchmark() { + float duration = 0; + for (auto tensor_size : std::set({kSmallTensorSize, kTensorSize})) { + duration = runTensorGEMMBench(tensor_size, FLAGS_iter); + std::cout << "Tensor GEMM benchmark (" + << tensor_size + << "x" + << tensor_size + << ", " << FLAGS_iter << "): " << duration + << " us." << std::endl; + } + duration = runPureRecordFunctionBench(FLAGS_iter); + std::cout << "Pure RecordFunction benchmark (" + << FLAGS_iter << "): " + << duration + << " us." << std::endl; +} + int main(int argc, char** argv) { if (!c10::ParseCommandLineFlags(&argc, &argv)) { std::cout << "Failed to parse command line flags" << std::endl; @@ -61,45 +87,40 @@ int main(int argc, char** argv) { } at::enableRecordFunction(); - setupBenchmarkCallbacks(); + at::clearCallbacks(); - auto duration = runBench(kSmallTensorSize, FLAGS_warmup_iter); - std::cout << "Warmup time: " << duration << " us." << std::endl; + std::cout << "Warm up" << std::endl; + runBenchmark(); - for (auto tensor_size : std::set({kSmallTensorSize, kTensorSize})) { - duration = runBench(tensor_size, FLAGS_iter); - std::cout << "Time per iteration (" - << tensor_size - << "x" - << tensor_size - << "): " << (duration/FLAGS_iter) - << " us." << std::endl; - } + std::cout << "Running without observers" << std::endl; + runBenchmark(); + addTestCallback(); + std::cout << "Running with empty non-sampled observer" << std::endl; + runBenchmark(); at::clearCallbacks(); - int cb_count = 0; - at::addGlobalCallback(at::RecordFunctionCallback( - [&](const at::RecordFunction& fn) { + addTestCallback(kLowSamplingProb); + std::cout << "Running with empty sampled observer" << std::endl; + runBenchmark(); + at::clearCallbacks(); + + std::cout << "Checking number of sampled observer invocations" << std::endl; + static int cb_count = 0; + addTestCallback( + kLowSamplingProb, + [](const at::RecordFunction&) -> std::unique_ptr { ++cb_count; - }, - [](const at::RecordFunction&) {}) - .needsInputs(true) - .samplingProb(kLowSamplingProb) + return nullptr; + } ); - typedef std::chrono::high_resolution_clock clock; - typedef std::chrono::microseconds us; - std::chrono::time_point start_time = clock::now(); - for (auto n = 0; n < FLAGS_rec_fn_iter; ++n) { - RECORD_USER_SCOPE("test"); - } - duration = static_cast( - std::chrono::duration_cast(clock::now() - start_time).count()); - std::cout << "Pure RecordFunction runtime of " << FLAGS_rec_fn_iter - << " iterations " << duration + auto duration = runPureRecordFunctionBench(FLAGS_sampled_iter); + + std::cout << "Pure RecordFunction runtime of " << FLAGS_sampled_iter + << " iterations: " << duration << " us, number of callback invocations: " << cb_count - << ", expected number: ~" << (int)(FLAGS_rec_fn_iter * kLowSamplingProb) + << ", expected number: ~" << (int)(FLAGS_sampled_iter * kLowSamplingProb) << " invocations" << std::endl; at::clearCallbacks(); diff --git a/binaries/speed_benchmark_torch.cc b/binaries/speed_benchmark_torch.cc index db78467cfb433..88cc0b5dd9562 100644 --- a/binaries/speed_benchmark_torch.cc +++ b/binaries/speed_benchmark_torch.cc @@ -24,7 +24,7 @@ #include "torch/csrc/jit/serialization/import.h" #include "torch/script.h" -#include "c10/core/CPUCachingAllocator.h" +#include "c10/mobile/CPUCachingAllocator.h" #include using namespace std::chrono; @@ -70,6 +70,8 @@ C10_DEFINE_bool( C10_DEFINE_int(pytext_len, 0, "Length of input sequence."); C10_DEFINE_bool(vulkan, false, "Whether to use Vulkan backend (GPU)."); +namespace { + std::vector split(char separator, const std::string& string, bool ignore_empty = true) { std::vector pieces; @@ -143,14 +145,11 @@ std::vector create_inputs() { "Unsupported input memory format: ", input_memory_format_list[i]); } - const auto input_tensor = torch::ones( - input_dims, - at::TensorOptions(input_type).memory_format(input_memory_format)); - if (FLAGS_vulkan) { - inputs.push_back(input_tensor.vulkan()); - } else { - inputs.push_back(input_tensor); - } + inputs.push_back( + torch::ones( + input_dims, + at::TensorOptions(input_type). + memory_format(input_memory_format))); } if (FLAGS_pytext_len > 0) { @@ -161,6 +160,39 @@ std::vector create_inputs() { return inputs; } +class Runner { + public: + virtual ~Runner() = default; + virtual c10::IValue run( + torch::jit::Module& module, + const std::vector& inputs) { + return module.forward(inputs); + } +}; + +class vkRunner final : public Runner { + public: + virtual ~vkRunner() = default; + virtual c10::IValue run( + torch::jit::Module& module, + const std::vector& inputs) override { + // Upload the input tensor(s) to GPU memory. + inputs_.clear(); + inputs_.reserve(inputs.size()); + for (const auto& input : inputs) { + inputs_.emplace_back(input.toTensor().vulkan()); + } + + // Run, and download the output tensor to system memory. + return module.forward(inputs_).toTensor().cpu(); + } + + private: + std::vector inputs_; +}; + +} // namespace + int main(int argc, char** argv) { c10::SetUsageMessage( "Run speed benchmark for pytorch model.\n" @@ -199,9 +231,13 @@ int main(int argc, char** argv) { inputs = all_inputs.get(FLAGS_use_bundled_input).toTuple()->elements(); } + const std::unique_ptr runner = + FLAGS_vulkan ? std::make_unique() : + std::make_unique(); + module.eval(); if (FLAGS_print_output) { - std::cout << module.forward(inputs) << std::endl; + std::cout << runner->run(module, inputs) << std::endl; } c10::CPUCachingAllocator caching_allocator; @@ -217,7 +253,7 @@ int main(int argc, char** argv) { FLAGS_warmup, "."); for (int i = 0; i < FLAGS_warmup; ++i) { - module.forward(inputs); + runner->run(module, inputs); } std::cout << "Main runs." << std::endl; @@ -231,7 +267,7 @@ int main(int argc, char** argv) { auto micros = timer.MicroSeconds(); for (int i = 0; i < FLAGS_iter; ++i) { auto start = high_resolution_clock::now(); - module.forward(inputs); + runner->run(module, inputs); auto stop = high_resolution_clock::now(); auto duration = duration_cast(stop - start); times.push_back(duration.count()); diff --git a/c10/CMakeLists.txt b/c10/CMakeLists.txt index 17fd7e6801229..b175e5bdd6ce2 100644 --- a/c10/CMakeLists.txt +++ b/c10/CMakeLists.txt @@ -17,12 +17,13 @@ set(C10_USE_GFLAGS ${USE_GFLAGS}) # used in cmake_macros.h.in set(C10_USE_GLOG ${USE_GLOG}) # used in cmake_macros.h.in set(C10_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) # used in cmake_macros.h.in set(C10_USE_NUMA ${USE_NUMA}) +set(C10_USE_MSVC_STATIC_RUNTIME ${CAFFE2_USE_MSVC_STATIC_RUNTIME}) configure_file( ${CMAKE_CURRENT_LIST_DIR}/macros/cmake_macros.h.in ${CMAKE_BINARY_DIR}/c10/macros/cmake_macros.h) # Note: if you want to add ANY dependency to the c10 library, make sure you -# check with the core PyTorch developers as the dependendency will be +# check with the core PyTorch developers as the dependency will be # transitively passed on to all libraries dependent on PyTorch. file(GLOB C10_SRCS *.cpp @@ -32,6 +33,7 @@ file(GLOB C10_SRCS core/dispatch/*.cpp core/op_registration/*.cpp core/impl/*.cpp + mobile/*.cpp macros/*.cpp util/*.cpp ) diff --git a/c10/core/Allocator.cpp b/c10/core/Allocator.cpp index b62530c926750..b8db967b6e0e6 100644 --- a/c10/core/Allocator.cpp +++ b/c10/core/Allocator.cpp @@ -35,14 +35,14 @@ at::Allocator* GetAllocator(const at::DeviceType& t) { } bool memoryProfilingEnabled() { - const auto& state = ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE); - auto* reporter_ptr = static_cast(state.get()); + auto* reporter_ptr = static_cast( + ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE)); return reporter_ptr && reporter_ptr->memoryProfilingEnabled(); } void reportMemoryUsageToProfiler(void* ptr, int64_t alloc_size, Device device) { - const auto& state = ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE); - auto* reporter_ptr = static_cast(state.get()); + auto* reporter_ptr = static_cast( + ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE)); if (reporter_ptr) { reporter_ptr->reportMemoryUsage(ptr, alloc_size, device); } diff --git a/c10/core/Backend.h b/c10/core/Backend.h index ab0c45ad1fc38..dfbe07efb2376 100644 --- a/c10/core/Backend.h +++ b/c10/core/Backend.h @@ -37,6 +37,7 @@ enum class Backend { MSNPU, XLA, Vulkan, + Metal, QuantizedCPU, QuantizedCUDA, Undefined, @@ -107,6 +108,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) { return Backend::XLA; } else if (t == DispatchKey::Vulkan) { return Backend::Vulkan; + } else if (t == DispatchKey::Metal) { + return Backend::Metal; } else if (t == DispatchKey::SparseCPU) { return Backend::SparseCPU; } else if (t == DispatchKey::SparseCUDA) { @@ -150,6 +153,8 @@ static inline DispatchKey backendToDispatchKey(Backend b) { return DispatchKey::MkldnnCPU; case Backend::Vulkan: return DispatchKey::Vulkan; + case Backend::Metal: + return DispatchKey::Metal; case Backend::QuantizedCPU: return DispatchKey::QuantizedCPU; case Backend::QuantizedCUDA: @@ -188,6 +193,8 @@ static inline DeviceType backendToDeviceType(Backend b) { return DeviceType::CUDA; case Backend::Vulkan: return DeviceType::Vulkan; + case Backend::Metal: + return DeviceType::Metal; case Backend::Undefined: AT_ERROR("Undefined backend is not a valid device type"); default: @@ -292,6 +299,8 @@ static inline const char* toString(Backend b) { return "MkldnnCPU"; case Backend::Vulkan: return "Vulkan"; + case Backend::Metal: + return "Metal"; case Backend::QuantizedCPU: return "QuantizedCPU"; case Backend::QuantizedCUDA: diff --git a/c10/core/CPUAllocator.cpp b/c10/core/CPUAllocator.cpp index e830aa4832d07..f733031628fec 100644 --- a/c10/core/CPUAllocator.cpp +++ b/c10/core/CPUAllocator.cpp @@ -1,6 +1,7 @@ #include -#include #include +#include +#include // TODO: rename flags to C10 C10_DEFINE_bool( @@ -44,8 +45,9 @@ void* alloc_cpu(size_t nbytes) { // We might have clowny upstream code that tries to alloc a negative number // of bytes. Let's catch it early. CAFFE_ENFORCE( - ((ptrdiff_t)nbytes) >= 0, - "alloc_cpu() seems to have been called with negative number: ", nbytes); + ((ptrdiff_t)nbytes) >= 0, + "alloc_cpu() seems to have been called with negative number: ", + nbytes); void* data; #ifdef __ANDROID__ @@ -77,7 +79,7 @@ void* alloc_cpu(size_t nbytes) { CHECK( !FLAGS_caffe2_cpu_allocator_do_zero_fill || !FLAGS_caffe2_cpu_allocator_do_junk_fill) - << "Cannot request both zero-fill and junk-fill at the same time"; + << "Cannot request both zero-fill and junk-fill at the same time"; if (FLAGS_caffe2_cpu_allocator_do_zero_fill) { memset(data, 0, nbytes); } else if (FLAGS_caffe2_cpu_allocator_do_junk_fill) { @@ -156,13 +158,20 @@ class DefaultMobileCPUAllocator final : public at::Allocator { // TODO: enable with better TLS support on mobile // profiledCPUMemoryReporter().Delete(pointer); auto allocator_ptr = GetThreadLocalCachingAllocator(); + auto profiling_allocator_ptr = GetThreadLocalProfilingAllocator(); if (allocator_ptr != nullptr) { allocator_ptr->free(pointer); + } else if (profiling_allocator_ptr != nullptr) { + profiling_allocator_ptr->free(pointer); } else { c10::free_cpu(pointer); // This adds extra cost to freeing memory to the default case when // caching allocator is not enabled. CPUCachingAllocator::record_free(pointer); + auto allocation_planner = GetThreadLocalAllocationPlanner(); + if (allocation_planner != nullptr) { + allocation_planner->record_free(pointer); + } } } @@ -179,10 +188,17 @@ class DefaultMobileCPUAllocator final : public at::Allocator { auto alloc_size = PreGuardBytes + nbytes + PostGuardBytes; void* data; auto allocator_ptr = GetThreadLocalCachingAllocator(); + auto profiling_allocator_ptr = GetThreadLocalProfilingAllocator(); if (allocator_ptr != nullptr) { data = allocator_ptr->allocate(alloc_size); + } else if (profiling_allocator_ptr != nullptr) { + data = profiling_allocator_ptr->allocate(alloc_size); } else { data = c10::alloc_cpu(alloc_size); + auto allocation_planner = GetThreadLocalAllocationPlanner(); + if (allocation_planner != nullptr) { + allocation_planner->record_allocation(alloc_size, data); + } } // profiledCPUMemoryReporter().New(data, alloc_size); return { @@ -282,12 +298,32 @@ void ProfiledCPUMemoryReporter::Delete(void* ptr) { return; } if (FLAGS_caffe2_report_cpu_memory_usage) { - LOG(INFO) << "C10 deleted " << nbytes << " bytes, total alloc " - << allocated << " bytes."; + LOG(INFO) << "C10 deleted " << nbytes << " bytes, total alloc " << allocated + << " bytes."; } if (profile_memory) { - reportMemoryUsageToProfiler(ptr, -nbytes, c10::Device(c10::DeviceType::CPU)); + reportMemoryUsageToProfiler( + ptr, -nbytes, c10::Device(c10::DeviceType::CPU)); + } +} + +C10_API at::Allocator* cpu_caching_alloc = nullptr; +C10_API uint8_t cpu_caching_alloc_priority = 0; + +void SetCPUCachingAllocator(Allocator* alloc, uint8_t priority) { + if (priority >= cpu_caching_alloc_priority) { + cpu_caching_alloc = alloc; + cpu_caching_alloc_priority = priority; + } +} + +Allocator* GetCPUCachingAllocator() { + if (cpu_caching_alloc == nullptr) { + VLOG(1) + << "There is not caching allocator registered for CPU, use the default allocator instead."; + return GetAllocator(DeviceType::CPU); } + return cpu_caching_alloc; } } // namespace c10 diff --git a/c10/core/CPUAllocator.h b/c10/core/CPUAllocator.h index 6cf745195f406..e6465007e48f6 100644 --- a/c10/core/CPUAllocator.h +++ b/c10/core/CPUAllocator.h @@ -65,4 +65,11 @@ C10_API at::Allocator* GetDefaultCPUAllocator(); // Get the Default Mobile CPU Allocator C10_API at::Allocator* GetDefaultMobileCPUAllocator(); +// The CPUCachingAllocator is experimental and might disappear in the future. +// The only place that uses it is in StaticRuntime. +// Set the CPU Caching Allocator +C10_API void SetCPUCachingAllocator(Allocator* alloc, uint8_t priority = 0); +// Get the CPU Caching Allocator +C10_API Allocator* GetCPUCachingAllocator(); + } // namespace c10 diff --git a/c10/core/DefaultDtype.cpp b/c10/core/DefaultDtype.cpp index daae181db9d74..583d4452bfbd3 100644 --- a/c10/core/DefaultDtype.cpp +++ b/c10/core/DefaultDtype.cpp @@ -3,21 +3,32 @@ namespace c10 { static auto default_dtype = caffe2::TypeMeta::Make(); +static auto default_dtype_as_scalartype = default_dtype.toScalarType(); static auto default_complex_dtype = caffe2::TypeMeta::Make>(); void set_default_dtype(caffe2::TypeMeta dtype) { - default_dtype = std::move(dtype); - if(dtype == caffe2::TypeMeta::Make()) { - default_complex_dtype = std::move(caffe2::TypeMeta::Make>()); - } else { - default_complex_dtype = std::move(caffe2::TypeMeta::Make>()); + default_dtype = dtype; + default_dtype_as_scalartype = default_dtype.toScalarType(); + switch (default_dtype_as_scalartype) { + case ScalarType::Half: + default_complex_dtype = ScalarType::ComplexHalf; + break; + case ScalarType::Double: + default_complex_dtype = ScalarType::ComplexDouble; + break; + default: + default_complex_dtype = ScalarType::ComplexFloat; + break; } } -const caffe2::TypeMeta& get_default_dtype() { +const caffe2::TypeMeta get_default_dtype() { return default_dtype; } -const caffe2::TypeMeta& get_default_complex_dtype() { +ScalarType get_default_dtype_as_scalartype() { + return default_dtype_as_scalartype; +} +const caffe2::TypeMeta get_default_complex_dtype() { return default_complex_dtype; } } // namespace c10 diff --git a/c10/core/DefaultDtype.h b/c10/core/DefaultDtype.h index 402a6069bfc34..d0a17474bda49 100644 --- a/c10/core/DefaultDtype.h +++ b/c10/core/DefaultDtype.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace caffe2 { class TypeMeta; @@ -8,6 +9,7 @@ class TypeMeta; namespace c10 { C10_API void set_default_dtype(caffe2::TypeMeta dtype); -C10_API const caffe2::TypeMeta& get_default_dtype(); -C10_API const caffe2::TypeMeta& get_default_complex_dtype(); +C10_API const caffe2::TypeMeta get_default_dtype(); +C10_API ScalarType get_default_dtype_as_scalartype(); +C10_API const caffe2::TypeMeta get_default_complex_dtype(); } // namespace c10 diff --git a/c10/core/Device.cpp b/c10/core/Device.cpp index 60c40b516f451..dbe38e17f39d9 100644 --- a/c10/core/Device.cpp +++ b/c10/core/Device.cpp @@ -30,7 +30,7 @@ namespace c10 { namespace { DeviceType parse_type(const std::string& device_string) { - static const std::array, 10> types = {{ + static const std::array, 11> types = {{ {"cpu", DeviceType::CPU}, {"cuda", DeviceType::CUDA}, {"mkldnn", DeviceType::MKLDNN}, @@ -41,6 +41,7 @@ DeviceType parse_type(const std::string& device_string) { {"fpga", DeviceType::FPGA}, {"msnpu", DeviceType::MSNPU}, {"xla", DeviceType::XLA}, + {"vulkan", DeviceType::Vulkan}, }}; auto device = std::find_if( types.begin(), @@ -52,7 +53,7 @@ DeviceType parse_type(const std::string& device_string) { return device->second; } AT_ERROR( - "Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla device type at start of device string: ", device_string); + "Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan device type at start of device string: ", device_string); } } // namespace diff --git a/c10/core/Device.h b/c10/core/Device.h index f1249e865f8be..04cd711c37b2c 100644 --- a/c10/core/Device.h +++ b/c10/core/Device.h @@ -15,7 +15,7 @@ namespace c10 { /// A DeviceIndex is not independently meaningful without knowing /// the DeviceType it is associated; try to use Device rather than /// DeviceIndex directly. -using DeviceIndex = int16_t; +using DeviceIndex = int8_t; /// Represents a a compute device on which a tensor is located. A device is /// uniquely identified by a type, which specifies the type of machine it is @@ -93,10 +93,14 @@ struct C10_API Device final { DeviceType type_; DeviceIndex index_ = -1; void validate() { - TORCH_CHECK(index_ == -1 || index_ >= 0, - "Device index must be -1 or non-negative, got ", index_); - TORCH_CHECK(!is_cpu() || index_ <= 0, - "CPU device index must be -1 or zero, got ", index_); + // Removing these checks in release builds noticeably improves + // performance in micro-benchmarks. + // This is safe to do, because backends that use the DeviceIndex + // have a later check when we actually try to switch to that device. + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(index_ == -1 || index_ >= 0, + "Device index must be -1 or non-negative, got ", (int)index_); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_cpu() || index_ <= 0, + "CPU device index must be -1 or zero, got ", (int)index_); } }; @@ -112,8 +116,8 @@ struct hash { size_t operator()(c10::Device d) const noexcept { // Are you here because this static assert failed? Make sure you ensure // that the bitmasking code below is updated accordingly! - static_assert(sizeof(c10::DeviceType) == 2, "DeviceType is not 16-bit"); - static_assert(sizeof(c10::DeviceIndex) == 2, "DeviceIndex is not 16-bit"); + static_assert(sizeof(c10::DeviceType) == 1, "DeviceType is not 8-bit"); + static_assert(sizeof(c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit"); // Note [Hazard when concatenating signed integers] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // We must first convert to a same-sized unsigned type, before promoting to @@ -124,8 +128,8 @@ struct hash { // Technically, by C/C++ integer promotion rules, we only need one of the // uint32_t casts to the result type, but we put in both for explicitness's sake. uint32_t bits = - static_cast(static_cast(d.type())) << 16 - | static_cast(static_cast(d.index())); + static_cast(static_cast(d.type())) << 16 + | static_cast(static_cast(d.index())); return std::hash{}(bits); } }; diff --git a/c10/core/DeviceType.cpp b/c10/core/DeviceType.cpp index 9c8c53b3f0cac..a8bcead7f44e0 100644 --- a/c10/core/DeviceType.cpp +++ b/c10/core/DeviceType.cpp @@ -29,6 +29,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) { return lower_case ? "xla" : "XLA"; case DeviceType::Vulkan: return lower_case ? "vulkan" : "VULKAN"; + case DeviceType::Metal: + return lower_case ? "metal" : "METAL"; default: AT_ERROR( "Unknown device: ", @@ -62,6 +64,7 @@ bool isValidDeviceType(DeviceType d) { case DeviceType::MSNPU: case DeviceType::XLA: case DeviceType::Vulkan: + case DeviceType::Metal: return true; default: return false; diff --git a/c10/core/DeviceType.h b/c10/core/DeviceType.h index 0289cf0a02306..86935436ae1c3 100644 --- a/c10/core/DeviceType.h +++ b/c10/core/DeviceType.h @@ -12,7 +12,7 @@ namespace c10 { -enum class DeviceType : int16_t { +enum class DeviceType : int8_t { CPU = 0, CUDA = 1, // CUDA. MKLDNN = 2, // Reserved for explicit MKLDNN @@ -24,12 +24,12 @@ enum class DeviceType : int16_t { MSNPU = 8, // MSNPU XLA = 9, // XLA / TPU Vulkan = 10, // Vulkan + Metal = 11, //Metal // NB: If you add more devices: // - Change the implementations of DeviceTypeName and isValidDeviceType // in DeviceType.cpp // - Change the number below - COMPILE_TIME_MAX_DEVICE_TYPES = 11, - ONLY_FOR_TEST = 20901, // This device type is only for test. + COMPILE_TIME_MAX_DEVICE_TYPES = 12, }; constexpr DeviceType kCPU = DeviceType::CPU; @@ -39,6 +39,7 @@ constexpr DeviceType kFPGA = DeviceType::FPGA; constexpr DeviceType kMSNPU = DeviceType::MSNPU; constexpr DeviceType kXLA = DeviceType::XLA; constexpr DeviceType kVulkan = DeviceType::Vulkan; +constexpr DeviceType kMetal = DeviceType::Metal; // define explicit int constant constexpr int COMPILE_TIME_MAX_DEVICE_TYPES = diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index 70c36de555fcb..36c9ab0d61646 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -21,7 +21,8 @@ const char* toString(DispatchKey t) { return "XLA"; case DispatchKey::Vulkan: return "Vulkan"; - + case DispatchKey::Metal: + return "Metal"; case DispatchKey::MKLDNN: return "MKLDNN"; case DispatchKey::OpenGL: @@ -30,7 +31,6 @@ const char* toString(DispatchKey t) { return "OpenCL"; case DispatchKey::IDEEP: return "IDEEP"; - case DispatchKey::QuantizedCPU: return "QuantizedCPU"; case DispatchKey::QuantizedCUDA: @@ -53,6 +53,9 @@ const char* toString(DispatchKey t) { case DispatchKey::SparseHIP: return "SparseHIP"; + case DispatchKey::NestedTensor: + return "NestedTensor"; + case DispatchKey::PrivateUse1: return "PrivateUse1"; case DispatchKey::PrivateUse2: @@ -71,6 +74,8 @@ const char* toString(DispatchKey t) { return "AutogradCUDA"; case DispatchKey::AutogradXLA: return "AutogradXLA"; + case DispatchKey::AutogradNestedTensor: + return "AutogradNestedTensor"; case DispatchKey::AutogradPrivateUse1: return "AutogradPrivateUse1"; case DispatchKey::AutogradPrivateUse2: @@ -99,6 +104,9 @@ const char* toString(DispatchKey t) { case DispatchKey::Math: return "Math"; + case DispatchKey::DefaultBackend: + return "DefaultBackend"; + case DispatchKey::TESTING_ONLY_GenericWrapper: return "TESTING_ONLY_GenericWrapper"; @@ -114,6 +122,13 @@ std::ostream& operator<<(std::ostream& str, DispatchKey rhs) { return str << toString(rhs); } +// for a given backend key, return the associated autograd key. +// for non-backend keys, return AutogradOther as a default. +// Note: it's convenient and fast to return a default here rather than (say) +// returning an optional, or throwing. But it makes callers +// responsible for either a) enforcing the invariant that only backend keys +// be passed as arguments, or b) interpreting our return value carefully. +// DispatchKey getAutogradKeyFromBackend(DispatchKey t) { switch (t) { case DispatchKey::CPU: @@ -122,6 +137,8 @@ DispatchKey getAutogradKeyFromBackend(DispatchKey t) { return DispatchKey::AutogradCUDA; case DispatchKey::XLA: return DispatchKey::AutogradXLA; + case DispatchKey::NestedTensor: + return DispatchKey::AutogradNestedTensor; case DispatchKey::PrivateUse1: return DispatchKey::AutogradPrivateUse1; case DispatchKey::PrivateUse2: diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index b32f991df3659..4b6ca26757bce 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -51,8 +51,8 @@ enum class DispatchKey : uint8_t { // Here are backends which you think of as traditionally specifying // how to implement operations on some device. - CPU, // registered at build/aten/src/ATen/CPUType.cpp - CUDA, // registered at build/aten/src/ATen/CUDAType.cpp + CPU, // registered at build/aten/src/ATen/RegisterCPU.cpp + CUDA, // registered at build/aten/src/ATen/RegisterCUDA.cpp HIP, // NB: I think this is not actually used, due to Note [Masquerading as // CUDA] FPGA, // Xilinx support lives out of tree at https://gitlab.com/pytorch-complex/vitis_kernels @@ -60,6 +60,7 @@ enum class DispatchKey : uint8_t { // test/cpp_extensions/msnpu_extension.cpp XLA, // lives out of tree at https://github.com/pytorch/xla Vulkan, + Metal, // These are Caffe2 device types which we grandfathered into // DispatchKey. @@ -72,8 +73,8 @@ enum class DispatchKey : uint8_t { // Here are backends which specify more specialized operators // based on the dtype of the tensor. - QuantizedCPU, // registered at build/aten/src/ATen/QuantizedCPUType.cpp - QuantizedCUDA, // registered at build/aten/src/ATen/QuantizedCUDAType.cpp + QuantizedCPU, // registered at build/aten/src/ATen/RegisterQuantizedCPU.cpp + QuantizedCUDA, // registered at build/aten/src/ATen/RegisterQuantizedCUDA.cpp ComplexCPU, // lives out of tree at // https://gitlab.com/pytorch-complex/pytorch-cpu-strided-complex ComplexCUDA, // and @@ -96,13 +97,14 @@ enum class DispatchKey : uint8_t { // based on the layout of the tensor. Note that the sparse backends // are one case where ordering matters: sparse multi-dispatches with // the corresponding dense tensors, and must be handled before them. - MkldnnCPU, // registered at build/aten/src/ATen/MkldnnCPUType.cpp + MkldnnCPU, // registered at build/aten/src/ATen/RegisterMkldnnCPU.cpp // NB: not to be confused with MKLDNN, which is Caffe2 only - SparseCPU, // registered at build/aten/src/ATen/SparseCPUType.cpp - SparseCUDA, // registered at build/aten/src/ATen/SparseCUDAType.cpp + SparseCPU, // registered at build/aten/src/ATen/RegisterSparseCPU.cpp + SparseCUDA, // registered at build/aten/src/ATen/RegisterSparseCUDA.cpp SparseHIP, // TODO: I think this is not actually used, due to Note // [Masquerading as CUDA] + NestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor // Here are reserved backends for user-defined backends, see Note [Private use // DispatchKey] // To see some example about how to use this, check out MSNPU @@ -216,6 +218,7 @@ enum class DispatchKey : uint8_t { AutogradCPU, AutogradCUDA, AutogradXLA, + AutogradNestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor // Here are some reserved pre-autograd keys for user-defined backends, see // Note [Private use DispatchKey] AutogradPrivateUse1, @@ -273,11 +276,12 @@ enum class DispatchKey : uint8_t { // See Note [Alias Dispatch Key : Autograd] Autograd, - Math, + Math, // registered at build/aten/src/ATen/RegisterMath.cpp + DefaultBackend, // registered at build/aten/src/ATen/RegisterDefaultBackend.cpp // Define an alias key to represent end of alias dispatch keys. // If you add new alias keys after Autograd, please also update it here. - EndOfAliasKeys = Math, // + EndOfAliasKeys = DefaultBackend, // // ~~~~~~~~~~~~~~~~~~~~~~~~~ BC ALIASES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // // The aliases exist for backwards compatibility reasons, they shouldn't diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index b331fd5a75d11..ef8355ef463c4 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -2,37 +2,22 @@ namespace c10 { -// backend dispatch keys that map to DispatchKey::AutogradOther -constexpr DispatchKeySet autogradother_backends = DispatchKeySet({ - DispatchKey::HIP, - DispatchKey::FPGA, - DispatchKey::MSNPU, - DispatchKey::Vulkan, - DispatchKey::MKLDNN, - DispatchKey::OpenGL, - DispatchKey::OpenCL, - DispatchKey::IDEEP, - DispatchKey::QuantizedCPU, - DispatchKey::QuantizedCUDA, - DispatchKey::ComplexCPU, - DispatchKey::ComplexCUDA, - DispatchKey::CustomRNGKeyId, - DispatchKey::MkldnnCPU, - DispatchKey::SparseCPU, - DispatchKey::SparseCUDA, - DispatchKey::SparseHIP, -}); - // backend_dispatch_keyset should include all runtime backend keys. +// Alias key DispatchKey::DefaultBackend maps to backend_dispatch_keyset constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends | DispatchKeySet({ DispatchKey::CPU, DispatchKey::CUDA, DispatchKey::XLA, + DispatchKey::NestedTensor, DispatchKey::PrivateUse1, DispatchKey::PrivateUse2, DispatchKey::PrivateUse3, }); +bool isBackendDispatchKey(DispatchKey t) { + return t != DispatchKey::Undefined && backend_dispatch_keyset.has(t); +} + // math_dispatch_keyset contains all keys in backend_dispatch_keyset and autograd_dispatch_keyset // Alias key DispatchKey::Math maps to math_dispatch_keyset. constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | autograd_dispatch_keyset; @@ -44,11 +29,15 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { return autograd_dispatch_keyset; case DispatchKey::Math: return math_dispatch_keyset; + case DispatchKey::DefaultBackend: + return backend_dispatch_keyset; default: return DispatchKeySet(t); } } +// for a given autograd key, return the (guaranteed nonempty) set of associated backend keys. +// for a non-autograd key, return the empty keyset. DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) { switch (t) { case DispatchKey::AutogradCPU: @@ -57,6 +46,8 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) { return DispatchKeySet(DispatchKey::CUDA); case DispatchKey::AutogradXLA: return DispatchKeySet(DispatchKey::XLA); + case DispatchKey::AutogradNestedTensor: + return DispatchKeySet(DispatchKey::NestedTensor); case DispatchKey::AutogradPrivateUse1: return DispatchKeySet(DispatchKey::PrivateUse1); case DispatchKey::AutogradPrivateUse2: diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index 58d6beec7f143..58d456b950ed6 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -61,8 +61,8 @@ class DispatchKeySet final { } } // Test if a DispatchKey is in the set - bool has(DispatchKey t) const { - TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined); + bool inline has(DispatchKey t) const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined); return static_cast(repr_ & DispatchKeySet(t).repr_); } // Test if DispatchKeySet is a superset of ks. @@ -124,7 +124,7 @@ class DispatchKeySet final { public: // STL iterator for DispatchKeySet. Iterates through all DispatchKeys in the // set. The iterator is only invalidated by the destruction of the underlying - // DispatchKeySet as the iterator stores a pointer to the raw represenation of + // DispatchKeySet as the iterator stores a pointer to the raw representation of // the DispatchKeySet. class iterator { public: @@ -188,16 +188,45 @@ C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet); // autograd_dispatch_keyset should include all runtime autograd keys. // Alias key DispatchKey::Autograd maps to autograd_dispatch_keyset. +// NB: keys in this set also get associated with Math constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({ DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA, + DispatchKey::AutogradNestedTensor, DispatchKey::AutogradPrivateUse1, DispatchKey::AutogradPrivateUse2, DispatchKey::AutogradPrivateUse3, DispatchKey::AutogradOther, }); +// backend dispatch keys that map to DispatchKey::AutogradOther +// NB: keys in this set also get associated with Math +constexpr DispatchKeySet autogradother_backends = DispatchKeySet({ + DispatchKey::HIP, + DispatchKey::FPGA, + DispatchKey::MSNPU, + DispatchKey::Vulkan, + DispatchKey::Metal, + DispatchKey::MKLDNN, + DispatchKey::OpenGL, + DispatchKey::OpenCL, + DispatchKey::IDEEP, + DispatchKey::QuantizedCPU, + DispatchKey::QuantizedCUDA, + DispatchKey::ComplexCPU, + DispatchKey::ComplexCUDA, + DispatchKey::CustomRNGKeyId, + DispatchKey::MkldnnCPU, + DispatchKey::SparseCPU, + DispatchKey::SparseCUDA, + DispatchKey::SparseHIP, + DispatchKey::Meta, +}); + +// true if t is a backend dispatch key +C10_API bool isBackendDispatchKey(DispatchKey t); + // Resolve alias dispatch key to DispatchKeySet if applicable C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t); @@ -206,7 +235,7 @@ C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t); C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t); // This API exists because we have a use case for checking -// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefind) +// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined) // in OperatorEntry.cpp but we disallow it in has() API. C10_API bool isIncludedInAlias(DispatchKey k, DispatchKey alias); diff --git a/c10/core/GeneratorImpl.h b/c10/core/GeneratorImpl.h index fff105a9858b0..84e620e93a72e 100644 --- a/c10/core/GeneratorImpl.h +++ b/c10/core/GeneratorImpl.h @@ -13,6 +13,7 @@ #include #include #include +#include /** * Note [Generator] @@ -42,7 +43,7 @@ * Please use the public mutex_ when using any methods from these classes, except for the * read-only methods. You can learn about the usage by looking into the unittests * (aten/src/ATen/cpu_generator_test.cpp) and other places where we have used lock_guard. - * + * * TODO: Look into changing the threading semantics of Generators in ATen (e.g., making * them non-thread safe and instead making the generator state splittable, to accommodate * forks into other threads). @@ -71,6 +72,8 @@ struct C10_API GeneratorImpl : public c10::intrusive_ptr_target { virtual void set_current_seed(uint64_t seed) = 0; virtual uint64_t current_seed() const = 0; virtual uint64_t seed() = 0; + virtual void set_state(const c10::TensorImpl& new_state) = 0; + virtual c10::intrusive_ptr get_state() const = 0; Device device() const; // See Note [Acquire lock when using random generators] @@ -96,7 +99,7 @@ struct C10_API GeneratorImpl : public c10::intrusive_ptr_target { namespace detail { -CAFFE2_API uint64_t getNonDeterministicRandom(bool is_cuda = false); +TORCH_API uint64_t getNonDeterministicRandom(bool is_cuda = false); } // namespace detail diff --git a/c10/core/MemoryFormat.h b/c10/core/MemoryFormat.h index e25814cd0717c..6528f6c8f1101 100644 --- a/c10/core/MemoryFormat.h +++ b/c10/core/MemoryFormat.h @@ -98,7 +98,7 @@ inline std::vector get_channels_last_strides_3d(IntArrayRef sizes) { // 1. Please do not combine these helper functions, each helper function handles // exactly one case of sizes + memory_format, by doing this, the strides indices // will be a constant array and we can access it using constant index number, -// the complier will fully unroll the loop on strides indices to gain a better +// the compiler will fully unroll the loop on strides indices to gain a better // performance. // 2. No error check in helper function, caller ensures the correctness of the input // 3. All helper functions have similar comments, only 1st helper function is commented here. @@ -205,7 +205,7 @@ inline bool is_channels_last_strides_3d_s5(const IntArrayRef sizes, const IntArr // a. we identify corner cases where the implementation compromises on. // // By the time accumulated permutation is enabled to replace implicit -// memory_foramt through strides, we should be updating our tests and fix the +// memory_format through strides, we should be updating our tests and fix the // issues in our tests. // // We use Channels Last 2d as an example above. diff --git a/c10/core/Scalar.cpp b/c10/core/Scalar.cpp index 04bba06a91a58..203b544924ec3 100644 --- a/c10/core/Scalar.cpp +++ b/c10/core/Scalar.cpp @@ -3,7 +3,7 @@ namespace c10 { Scalar Scalar::operator-() const { - TORCH_CHECK(!isBoolean(), "torch boolean negative, the `-` operator, is not suppported."); + TORCH_CHECK(!isBoolean(), "torch boolean negative, the `-` operator, is not supported."); if (isFloatingPoint()) { return Scalar(-v.d); } else if (isComplex()) { @@ -13,4 +13,22 @@ Scalar Scalar::operator-() const { } } +Scalar Scalar::conj() const { + if (isComplex()) { + return Scalar(std::conj(v.z)); + } else { + return *this; + } +} + +Scalar Scalar::log() const { + if (isComplex()) { + return std::log(v.z); + } else if (isFloatingPoint()) { + return std::log(v.d); + } else { + return std::log(v.i); + } +} + } // namespace c10 diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index 19f0d3b90e6fc..368228e8202e0 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -87,6 +87,45 @@ class C10_API Scalar { } Scalar operator-() const; + Scalar conj() const; + Scalar log() const; + + template::value, int>::type = 0> + bool equal(T num) const { + if (isComplex()) { + auto val = v.z; + return (val.real() == num) && (val.imag() == T()); + } else if (isFloatingPoint()) { + return v.d == num; + } else if (isIntegral(/*includeBool=*/false)) { + return v.i == num; + } else { + // boolean scalar does not equal to a non boolean value + return false; + } + } + + template::value, int>::type = 0> + bool equal(T num) const { + if (isComplex()) { + return v.z == num; + } else if (isFloatingPoint()) { + return (v.d == num.real()) && (num.imag() == T()); + } else if (isIntegral(/*includeBool=*/false)) { + return (v.i == num.real()) && (num.imag() == T()); + } else { + // boolean scalar does not equal to a non boolean value + return false; + } + } + + bool equal(bool num) const { + if (isBoolean()) { + return static_cast(v.i) == num; + } else { + return false; + } + } ScalarType type() const { if (isComplex()) { diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 41980540017c3..29fa2020f6841 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -3,9 +3,12 @@ #include #include #include +#include +#include +#include #include +#include #include -#include #include #include @@ -38,7 +41,8 @@ namespace c10 { _(c10::qint8, QInt8) /* 12 */ \ _(c10::quint8, QUInt8) /* 13 */ \ _(c10::qint32, QInt32) /* 14 */ \ - _(at::BFloat16, BFloat16) /* 15 */ + _(at::BFloat16, BFloat16) /* 15 */ \ + _(c10::quint4x2, QUInt4x2) /* 16 */ // If you want to support ComplexHalf for real, add ComplexHalf @@ -67,6 +71,8 @@ enum class ScalarType : int8_t { NumOptions }; +constexpr uint16_t NumScalarTypes = static_cast(ScalarType::NumOptions); + namespace impl { // These are used to map ScalarTypes to C++ types. @@ -93,7 +99,7 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType) #undef SPECIALIZE_ScalarTypeToCPPType -} +} // namespace impl template struct CppTypeToScalarType; @@ -109,6 +115,13 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) #undef SPECIALIZE_CppTypeToScalarType +#define AT_FORALL_INT_TYPES(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) + #define AT_FORALL_SCALAR_TYPES(_) \ _(uint8_t, Byte) \ _(int8_t, Char) \ @@ -154,70 +167,13 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) #define AT_FORALL_QINT_TYPES(_) \ _(c10::qint8, QInt8) \ _(c10::quint8, QUInt8) \ - _(c10::qint32, QInt32) + _(c10::qint32, QInt32) \ + _(c10::quint4x2, QUInt4x2) #define AT_FORALL_COMPLEX_TYPES(_) \ _(c10::complex, ComplexFloat) \ _(c10::complex, ComplexDouble) -static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) { -#define DEFINE_CASE(ctype, name) \ - case ScalarType::name: \ - return caffe2::TypeMeta::Make(); - - switch (scalar_type) { - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE) - case ScalarType::Undefined: - return caffe2::TypeMeta(); - default: - AT_ERROR( - "Unrecognized Scalartype ", - scalar_type, - " (please report this error)"); - } -#undef DEFINE_CASE -} - -static inline c10::optional tryTypeMetaToScalarType( - caffe2::TypeMeta dtype) { -#define DEFINE_IF(ctype, name) \ - if (dtype == caffe2::TypeMeta::Make()) { \ - return {ScalarType::name}; \ - } - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_IF) -#undef DEFINE_IF - if (dtype == caffe2::TypeMeta()) { - return {ScalarType::Undefined}; - } - return c10::nullopt; -} - -static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { - if (auto scalar_type = tryTypeMetaToScalarType(dtype)) { - return *scalar_type; - } - AT_ERROR( - "Unsupported TypeMeta in ATen: ", dtype, " (please report this error)"); -} - -inline optional optTypeMetaToScalarType(optional type_meta) { - if (!type_meta.has_value()) { - return c10::nullopt; - } - return typeMetaToScalarType(*type_meta); -} - -static inline bool operator==(ScalarType t, caffe2::TypeMeta m) { - if (auto mt = tryTypeMetaToScalarType(m)) { - return (*mt) == t; - } - return false; -} - -static inline bool operator==(caffe2::TypeMeta m, ScalarType t) { - return t == m; -} - #define DEFINE_CONSTANT(_, name) \ constexpr ScalarType k##name = ScalarType::name; @@ -279,7 +235,7 @@ static inline bool isComplexType(ScalarType t) { static inline bool isQIntType(ScalarType t) { // Don't forget to extend this when adding new QInt types - return t == ScalarType:: QInt8 || t == ScalarType::QUInt8 || t == ScalarType::QInt32; + return t == ScalarType:: QInt8 || t == ScalarType::QUInt8 || t == ScalarType::QInt32 || t == ScalarType::QUInt4x2; } static inline ScalarType toQIntType(ScalarType t) { @@ -303,6 +259,8 @@ static inline ScalarType toUnderlying(ScalarType t) { return ScalarType::Char; case ScalarType::QInt32: return ScalarType::Int; + case ScalarType::QUInt4x2: + return ScalarType::Byte; default: return t; } diff --git a/c10/core/ScalarTypeToTypeMeta.h b/c10/core/ScalarTypeToTypeMeta.h new file mode 100644 index 0000000000000..b6e7f6cf1993d --- /dev/null +++ b/c10/core/ScalarTypeToTypeMeta.h @@ -0,0 +1,55 @@ +#pragma once + +#include +#include + +// these just expose TypeMeta/ScalarType bridge functions in c10 +// TODO move to typeid.h (or codemod away) when TypeMeta et al +// are moved from caffe2 to c10 (see note at top of typeid.h) + +namespace c10 { + +/** + * convert ScalarType enum values to TypeMeta handles + */ +static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) { + return caffe2::TypeMeta::fromScalarType(scalar_type); +} + +/** + * convert TypeMeta handles to ScalarType enum values + */ +static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { + return dtype.toScalarType(); +} + +/** + * typeMetaToScalarType(), lifted to optional + */ +static inline optional optTypeMetaToScalarType(optional type_meta) { + if (!type_meta.has_value()) { + return c10::nullopt; + } + return type_meta->toScalarType(); +} + +/** + * convenience: equality across TypeMeta/ScalarType conversion + */ +static inline bool operator==(ScalarType t, caffe2::TypeMeta m) { + return m.isScalarType(t); +} + +static inline bool operator==(caffe2::TypeMeta m, ScalarType t) { + return t == m; +} + +static inline bool operator!=(ScalarType t, caffe2::TypeMeta m) { + return !(t == m); +} + +static inline bool operator!=(caffe2::TypeMeta m, ScalarType t) { + return !(t == m); +} + +} // namespace c10 diff --git a/c10/core/Stream.cpp b/c10/core/Stream.cpp index 9a5c838c73fed..1a56c9d685671 100644 --- a/c10/core/Stream.cpp +++ b/c10/core/Stream.cpp @@ -2,7 +2,7 @@ namespace c10 { -// Not very parseable, but I don't know a good compact syntax for streams. +// Not very parsable, but I don't know a good compact syntax for streams. // Feel free to change this into something more compact if needed. std::ostream& operator<<(std::ostream& stream, const Stream& s) { stream << "stream " << s.id() << " on device " << s.device(); diff --git a/c10/core/Stream.h b/c10/core/Stream.h index 6962be72bf722..62d5261534eea 100644 --- a/c10/core/Stream.h +++ b/c10/core/Stream.h @@ -111,24 +111,24 @@ class Stream final { uint64_t pack() const noexcept { // Are you here because this static assert failed? Make sure you ensure // that the bitmasking code below is updated accordingly! - static_assert(sizeof(DeviceType) == 2, "DeviceType is not 16-bit"); - static_assert(sizeof(DeviceIndex) == 2, "DeviceIndex is not 16-bit"); + static_assert(sizeof(DeviceType) == 1, "DeviceType is not 8-bit"); + static_assert(sizeof(DeviceIndex) == 1, "DeviceIndex is not 8-bit"); static_assert(sizeof(StreamId) == 4, "DeviceIndex is not 32-bit"); // Concat these together into a 64-bit integer // See Note [Hazard when concatenating signed integers] uint64_t bits = - static_cast(static_cast(device_type())) << 48 - | static_cast(static_cast(device_index())) << 32 + static_cast(static_cast(device_type())) << 48 + | static_cast(static_cast(device_index())) << 32 | static_cast(static_cast(id())); return bits; } static Stream unpack(uint64_t bits) { - auto stream_id = static_cast(bits) & 0xFFFFFFFFull; + const auto stream_id = static_cast(bits & 0xFFFFFFFFull); bits >>= 32; - auto device_index = static_cast(bits) & 0xFFFFull; + const auto device_index = static_cast(bits & 0xFFFFull); bits >>= 16; - auto device_type = static_cast(bits); + const auto device_type = static_cast(bits); TORCH_CHECK(isValidDeviceType(device_type)); // Unfortunately, we can't check if the StreamId is valid here; it // will be checked upon first use. diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 8702ed4fdebf4..0451f601abad7 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -44,26 +44,38 @@ const at::Tensor& TensorImpl::grad() const { return autograd_meta_->grad(); } +const at::Tensor& TensorImpl::fw_grad(uint64_t level, const at::Tensor& self) const { + // See TensorImpl::grad() above for explanation about the line below + if (!autograd_meta_) return impl::GetAutogradMetaFactory()->undefined_tensor(); + return autograd_meta_->fw_grad(level, self); +} + +void TensorImpl::set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op) { + if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make(); + autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op); +} + TensorImpl::TensorImpl( Storage&& storage, DispatchKeySet key_set, - const caffe2::TypeMeta& data_type) + const caffe2::TypeMeta data_type) : TensorImpl(std::move(storage), key_set, data_type, storage.device()) {} -TensorImpl::TensorImpl(DispatchKeySet key_set, const caffe2::TypeMeta& data_type, c10::optional device_opt) +TensorImpl::TensorImpl(DispatchKeySet key_set, const caffe2::TypeMeta data_type, c10::optional device_opt) : TensorImpl({}, key_set, data_type, std::move(device_opt)) {} -TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::TypeMeta& data_type, +TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::TypeMeta data_type, c10::optional device_opt) : storage_(std::move(storage)), - sizes_{0}, storage_offset_(0), numel_(0), data_type_(data_type), device_opt_(device_opt) { + + init_bitfields(); + if (!key_set.empty()) { - AT_ASSERT(data_type.id() == caffe2::TypeIdentifier::uninitialized() || - device_opt_.has_value()); + TORCH_INTERNAL_ASSERT(data_type == ScalarType::Undefined || device_opt_.has_value()); // UndefinedTensorImpl is a singleton, so we skip logging it C10_LOG_API_USAGE_ONCE("tensor.create"); } @@ -78,15 +90,16 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2:: // we would also like to check that non-cpu devices have an index, but some Caffe2 operators create // Storages with default devices. - strides_.push_back(1); } +#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY IntArrayRef TensorImpl::sizes() const { - return sizes_; + return sizes_and_strides_.sizes_arrayref(); } +#endif IntArrayRef TensorImpl::strides() const { - return strides_; + return sizes_and_strides_.strides_arrayref(); } bool TensorImpl::compute_contiguous() const { @@ -95,9 +108,10 @@ bool TensorImpl::compute_contiguous() const { return is_contiguous; int64_t z = 1; for (int64_t d = dim() - 1; d >= 0; d--) { - if (sizes_[d] != 1) { - if (strides_[d] == z) { - z *= sizes_[d]; + const auto size_d = sizes_and_strides_.size_at_unchecked(d); + if (size_d != 1) { + if (sizes_and_strides_.stride_at_unchecked(d) == z) { + z *= size_d; } else { is_contiguous = false; break; @@ -110,16 +124,17 @@ bool TensorImpl::compute_contiguous() const { bool TensorImpl::compute_channels_last_contiguous_2d() const { // Please don't combine these code, constant array is used here to let // compiler fully unroll the loop to get better performance - switch (sizes_.size()) { + switch (sizes_and_strides_.size()) { case 4: { int64_t expected = 1; for (auto& d : {1, 3, 2, 0}) { - if (sizes_[d] != 1) { - if (strides_[d] != expected) { + const auto size_d = sizes_and_strides_.size_at_unchecked(d); + if (size_d != 1) { + if (sizes_and_strides_.stride_at_unchecked(d) != expected) { return false; } - expected *= sizes_[d]; + expected *= size_d; } } return true; @@ -135,16 +150,17 @@ bool TensorImpl::compute_channels_last_contiguous_2d() const { bool TensorImpl::compute_channels_last_contiguous_3d() const { // Please don't combine these code, constant array is used here to let // compiler fully unroll the loop to get better performance - switch (sizes_.size()) { + switch (sizes_and_strides_.size()) { case 5: { int64_t expected = 1; for (auto& d : {1, 4, 3, 2, 0}) { - if (sizes_[d] != 1) { - if (strides_[d] != expected) { + const auto size_d = sizes_and_strides_.size_at_unchecked(d); + if (size_d != 1) { + if (sizes_and_strides_.stride_at_unchecked(d) != expected) { return false; } - expected *= sizes_[d]; + expected *= size_d; } } return true; @@ -158,16 +174,16 @@ bool TensorImpl::compute_channels_last_contiguous_3d() const { } bool TensorImpl::compute_strides_like_channels_last_2d() const { - return is_channels_last_strides_2d(sizes_, strides_); + return is_channels_last_strides_2d(TensorImpl::sizes(), TensorImpl::strides()); } bool TensorImpl::compute_strides_like_channels_last_3d() const { - return is_channels_last_strides_3d(sizes_, strides_); + return is_channels_last_strides_3d(TensorImpl::sizes(), TensorImpl::strides()); } bool TensorImpl::compute_non_overlapping_and_dense() const { if (dim() == 1) { - return sizes_[0] < 2 || strides_[0] == 1; + return sizes_and_strides_.size_at_unchecked(0) < 2 || sizes_and_strides_.stride_at_unchecked(0) == 1; } SmallVector perm; perm.resize(dim()); @@ -176,22 +192,23 @@ bool TensorImpl::compute_non_overlapping_and_dense() const { } // Sort by strides, leaving 0 and 1 sized dims at the end of the array std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) { - if (sizes_[a] < 2) { + if (sizes_and_strides_.size_at_unchecked(a) < 2) { return false; - } else if (sizes_[b] < 2) { + } else if (sizes_and_strides_.size_at_unchecked(b) < 2) { return true; } - return strides_[a] < strides_[b]; + return sizes_and_strides_.stride_at_unchecked(a) < sizes_and_strides_.stride_at_unchecked(b); }); auto require_stride = 1; for (int64_t i = 0; i < dim(); i ++) { - if (sizes_[perm[i]] < 2) { + const auto size_perm_i = sizes_and_strides_.size_at_unchecked(perm[i]); + if (size_perm_i < 2) { return true; } - if (strides_[perm[i]] != require_stride) { + if (sizes_and_strides_.stride_at_unchecked(perm[i]) != require_stride) { return false; } - require_stride *= sizes_[perm[i]]; + require_stride *= size_perm_i; } return true; } @@ -203,18 +220,20 @@ void TensorImpl::release_resources() { } } +#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY int64_t TensorImpl::dim() const { - return sizes_.size(); + return sizes_and_strides_.size(); } +#endif int64_t TensorImpl::size(int64_t d) const { d = at::maybe_wrap_dim(d, dim(), false); - return sizes_[d]; + return sizes_and_strides_.size_at_unchecked(d); } int64_t TensorImpl::stride(int64_t d) const { d = at::maybe_wrap_dim(d, dim(), false); - return strides_[d]; + return sizes_and_strides_.stride_at_unchecked(d); } bool TensorImpl::has_storage() const { @@ -287,14 +306,44 @@ c10::AutogradMetaInterface* TensorImpl::autograd_meta() const { return autograd_meta_.get(); } -void TensorImpl::copy_tensor_metadata( +c10::intrusive_ptr TensorImpl::shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const { + auto impl = c10::make_intrusive( + // No need to populate Storage; copy_tensor_metadata will do it for us. + key_set_, data_type_, device_opt_); + copy_tensor_metadata( + /*src_impl=*/this, + /*dest_impl=*/impl.get(), + /*version_counter=*/version_counter, + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + impl->refresh_numel(); + impl->refresh_contiguous(); + return impl; +} + +c10::intrusive_ptr TensorImpl::shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const { + auto impl = c10::make_intrusive( + // No need to populate Storage; copy_tensor_metadata will do it for us. + key_set_, data_type_, device_opt_); + copy_tensor_metadata( + /*src_impl=*/this, + /*dest_impl=*/impl.get(), + /*version_counter=*/std::move(version_counter), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + impl->refresh_numel(); + impl->refresh_contiguous(); + return impl; +} + +void TensorImpl::copy_tensor_metadata_except_version_counter( const TensorImpl* src_impl, TensorImpl* dest_impl, - const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) { dest_impl->storage_ = src_impl->storage_; - dest_impl->sizes_ = src_impl->sizes_; - dest_impl->strides_ = src_impl->strides_; + dest_impl->sizes_and_strides_ = src_impl->sizes_and_strides_; dest_impl->storage_offset_ = src_impl->storage_offset_; dest_impl->data_type_ = src_impl->data_type_; dest_impl->device_opt_ = src_impl->device_opt_; @@ -307,13 +356,30 @@ void TensorImpl::copy_tensor_metadata( dest_impl->is_non_overlapping_and_dense_ = src_impl->is_non_overlapping_and_dense_; dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_; dest_impl->reserved_ = src_impl->reserved_; - dest_impl->set_version_counter(version_counter); dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); if (src_impl->named_tensor_meta_ != nullptr) { dest_impl->named_tensor_meta_ = src_impl->named_tensor_meta_->clone(); } } +void TensorImpl::copy_tensor_metadata( + const TensorImpl* src_impl, + TensorImpl* dest_impl, + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) { + copy_tensor_metadata_except_version_counter(src_impl, dest_impl, allow_tensor_metadata_change); + dest_impl->set_version_counter(version_counter); +} + +void TensorImpl::copy_tensor_metadata( + const TensorImpl* src_impl, + TensorImpl* dest_impl, + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) { + copy_tensor_metadata_except_version_counter(src_impl, dest_impl, allow_tensor_metadata_change); + dest_impl->set_version_counter(std::move(version_counter)); +} + namespace impl { namespace { diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 5b383303df920..944d3098be1fa 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -10,8 +11,10 @@ #include #include #include +#include #include + #include #include #include @@ -19,7 +22,7 @@ #include // A global boolean variable to control whether we free memory when a Tensor -// is shrinked to a smaller size. As a result, a Tensor is always going to +// is shrunk to a smaller size. As a result, a Tensor is always going to // keep the memory allocated for its maximum capacity reshaped to so far. // // This parameter is respected "upper-case" methods which call Resize() @@ -136,6 +139,8 @@ struct C10_API AutogradMetaInterface { virtual bool requires_grad() const = 0; virtual at::Tensor& mutable_grad() = 0; virtual const at::Tensor& grad() const = 0; + virtual const at::Tensor& fw_grad(uint64_t level, const at::Tensor& self) const = 0; + virtual void set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op) = 0; virtual ~AutogradMetaInterface(); }; @@ -242,6 +247,21 @@ struct C10_API VariableVersion { } }; +/** + * NOTE: Some TensorImpl methods are small and not overridden in the + * PyTorch codebase itself, but may theoretically need to be + * overridden by third-party TensorImpl subclasses. This macro allows + * users that need maximum performance and don't need these extension + * points to disable them with a build-time flag. (In particular, + * XLA's XLATensorImpl currently overrides these methods, so we can't + * enable this flag by default.) + */ +#ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY +#define TENSORIMPL_MAYBE_VIRTUAL +#else +#define TENSORIMPL_MAYBE_VIRTUAL virtual +#endif + /** * The low-level representation of a tensor, which contains a pointer * to a storage (which contains the actual data) and metadata (e.g., sizes and @@ -322,24 +342,24 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { TensorImpl( Storage&& storage, DispatchKeySet, - const caffe2::TypeMeta& data_type); + const caffe2::TypeMeta data_type); /** * Construct a 1-dim 0 size tensor that doesn't have a storage. */ - TensorImpl(DispatchKeySet, const caffe2::TypeMeta& data_type, c10::optional device_opt); + TensorImpl(DispatchKeySet, const caffe2::TypeMeta data_type, c10::optional device_opt); // Legacy constructors so I don't have to go update call sites. // TODO: When Variable is added, delete these constructors TensorImpl( Storage&& storage, DispatchKey dispatch_key, - const caffe2::TypeMeta& data_type) + const caffe2::TypeMeta data_type) : TensorImpl( std::move(storage), DispatchKeySet(dispatch_key), data_type) {} - TensorImpl(DispatchKey dispatch_key, const caffe2::TypeMeta& data_type, c10::optional device_opt) + TensorImpl(DispatchKey dispatch_key, const caffe2::TypeMeta data_type, c10::optional device_opt) : TensorImpl(DispatchKeySet(dispatch_key), data_type, device_opt) {} private: @@ -347,7 +367,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // storage. Still, we pass it in separately because it's easier to write // the initializer list if we're not worried about storage being moved out // from under us. - TensorImpl(Storage&& storage, DispatchKeySet, const caffe2::TypeMeta& data_type, c10::optional); + TensorImpl(Storage&& storage, DispatchKeySet, const caffe2::TypeMeta data_type, c10::optional); public: TensorImpl(const TensorImpl&) = delete; @@ -373,7 +393,14 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * Return a reference to the sizes of this tensor. This reference remains * valid as long as the tensor is live and not resized. */ - virtual IntArrayRef sizes() const; + TENSORIMPL_MAYBE_VIRTUAL IntArrayRef sizes() const +#ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY + { + return sizes_and_strides_.sizes_arrayref(); + } +#else + ; +#endif /** * Return a reference to the strides of this tensor. This reference remains @@ -385,7 +412,14 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * Return the number of dimensions of this tensor. Note that 0-dimension * represents a Tensor that is a Scalar, e.g., one that has a single element. */ - virtual int64_t dim() const; + TENSORIMPL_MAYBE_VIRTUAL int64_t dim() const +#ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY + { + return sizes_and_strides_.size(); + } +#else + ; +#endif /** * True if this tensor has storage. See storage() for details. @@ -410,7 +444,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * is no longer true; numel always accurately reports the product * of sizes of a tensor. */ - virtual int64_t numel() const { + TENSORIMPL_MAYBE_VIRTUAL int64_t numel() const { #ifdef DEBUG TORCH_INTERNAL_ASSERT(compute_numel() == numel_); #endif @@ -469,6 +503,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return key_set_.has(DispatchKey::Vulkan); } + bool is_metal() const { + return key_set_.has(DispatchKey::Metal); + } + // TODO: remove this once we don't automatically enabled Autograd dispatch keys // in TensorImpl constructor. // DON'T USE THIS API!! It's only created for testing purpose in @@ -558,9 +596,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { /** * Set whether or not a tensor requires gradient. - * - * It is only valid to call this method on a Variable. - * See Note [Tensor versus Variable in C++]. */ void set_requires_grad(bool requires_grad); @@ -570,30 +605,57 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * we can automatically differentiate back to them. A tensor that * requires gradient and has no history is a "leaf" tensor, which we * accumulate gradients into. - * - * It is only valid to call this method on a Variable. - * See Note [Tensor versus Variable in C++]. */ bool requires_grad() const; /** * Return a mutable reference to the gradient. This is conventionally * used as `t.grad() = x` to set a gradient to a completely new tensor. - * - * It is only valid to call this method on a Variable. - * See Note [Tensor versus Variable in C++]. */ at::Tensor& mutable_grad(); /** * Return the accumulated gradient of a tensor. This gradient is written * into when performing backwards, when this tensor is a leaf tensor. - * - * It is only valid to call this method on a Variable. - * See Note [Tensor versus Variable in C++]. */ const at::Tensor& grad() const; + /** + * Return the accumulated gradient of a tensor. This gradient is computed + * using forward mode AD. + * + * This is an internal API that should never be used by end users. + * + * The API is as follows: + * - "level" allows to specify the level of forward AD nesting for which the + * gradient should be returned. Note that since levels are not fully + * supported yet, this argument should be 0. See documentation for + * torch::autograd::enter_dual_level for more details about forward AD nesting. + * - "self" should represent the Tensor whose forward grad is accessed. It is + * required when dealing with view. + */ + const at::Tensor& fw_grad(uint64_t level, const at::Tensor& self) const; + + /** + * Sets the forward gradient for this Tensor. + * The given Tensor might not be used directly and its content will be copied. + * + * This is an internal API that should never be used by end users. + * + * The API is as follows: + * - "new_grad" is a Tensor containing the new value of the gradient that should + * be set + * - "self" should represent the Tensor whose forward grad is accessed. It is + * required when dealing with view. + * - "level" allows to specify the level of forward AD nesting for which the + * gradient should be set. Note that since levels are not fully supported + * yet, this argument should be 0. See documentation for torch::autograd::enter_dual_level + * for more details about forward AD nesting. + * - "is_inplace_op" is a boolean flag that tells if this gradient was generated + * by an inplace operation or an out of place one. This allows better error checking. + */ + void set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op); + /** * Return a typed data pointer to the actual data which this tensor refers to. * This checks that the requested type (from the template parameter) matches @@ -661,7 +723,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * Returns the TypeMeta of a tensor, which describes what data type * it is (e.g., int, float, ...) */ - const caffe2::TypeMeta& dtype() const { + const caffe2::TypeMeta dtype() const { return data_type_; } @@ -705,7 +767,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { */ virtual void set_size(int64_t dim, int64_t new_size) { TORCH_CHECK(allow_tensor_metadata_change(), "set_size ", err_msg_tensor_metadata_change_not_allowed); - sizes_.at(dim) = new_size; + sizes_and_strides_.size_at(dim) = new_size; refresh_numel(); refresh_contiguous(); } @@ -718,7 +780,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { */ virtual void set_stride(int64_t dim, int64_t new_stride) { TORCH_CHECK(allow_tensor_metadata_change(), "set_stride ", err_msg_tensor_metadata_change_not_allowed); - strides_[dim] = new_stride; + sizes_and_strides_.stride_at_unchecked(dim) = new_stride; refresh_contiguous(); } @@ -743,12 +805,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { */ void set_sizes_contiguous(IntArrayRef new_size) { TORCH_CHECK(allow_tensor_metadata_change(), "set_sizes_contiguous ", err_msg_tensor_metadata_change_not_allowed); - auto new_dim = new_size.size(); - sizes_.resize(new_dim); - for (size_t dim = 0; dim < new_dim; ++dim) { - sizes_[dim] = new_size[dim]; - } + sizes_and_strides_.set_sizes(new_size); refresh_numel(); empty_tensor_restride(MemoryFormat::Contiguous); @@ -770,27 +828,25 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { ") must match dimensionality of strides (", new_stride.size(), ")"); - auto new_dim = new_size.size(); + const auto new_dim = new_size.size(); - sizes_.resize(new_dim); - for (size_t dim = 0; dim < new_dim; ++dim) { - sizes_[dim] = new_size[dim]; - } + sizes_and_strides_.set_sizes(new_size); - strides_.resize(new_dim); if (new_dim > 0) { for (size_t dim = new_dim - 1; ; dim--) { if (new_stride[dim] >= 0) { - strides_[dim] = new_stride[dim]; + sizes_and_strides_.stride_at_unchecked(dim) = new_stride[dim]; } else { // XXX: This behavior is surprising and may need to be removed to // support negative strides. Some pytorch functions rely on it: // for example, torch.cat (run TestTorch.test_cat_empty). if (dim == new_dim - 1) { - strides_[dim] = 1; + sizes_and_strides_.stride_at_unchecked(dim) = 1; } else { // Keep stride monotonically increasing to match NumPy. - strides_[dim] = std::max(sizes_[dim + 1], 1) * strides_[dim + 1]; + sizes_and_strides_.stride_at_unchecked(dim) = + std::max(sizes_and_strides_.size_at_unchecked(dim + 1), 1) * + sizes_and_strides_.stride_at_unchecked(dim + 1); } } if (dim == 0) break; @@ -931,18 +987,17 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { */ virtual c10::intrusive_ptr shallow_copy_and_detach( const c10::VariableVersion& version_counter, - bool allow_tensor_metadata_change) const { - auto impl = c10::make_intrusive( - Storage(storage()), key_set_, data_type_); - copy_tensor_metadata( - /*src_impl=*/this, - /*dest_impl=*/impl.get(), - /*version_counter=*/version_counter, - /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); - impl->refresh_numel(); - impl->refresh_contiguous(); - return impl; - } + bool allow_tensor_metadata_change) const; + + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + virtual c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const; /** * Shallow-copies data from another TensorImpl into this TensorImpl. @@ -965,6 +1020,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { version_counter_ = version_counter; } + void set_version_counter( + c10::VariableVersion&& version_counter) noexcept { + version_counter_ = std::move(version_counter); + } + const c10::VariableVersion& version_counter() const noexcept { return version_counter_; } @@ -1014,12 +1074,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * This op is auto-asynchronous if the underlying device (CUDA) supports it. */ void Extend(int64_t num, float growthPct) { - TORCH_CHECK(sizes_.size() >= 1u); + TORCH_CHECK(sizes_and_strides_.size() >= 1u); TORCH_CHECK(num >= 0, "`num` must be non-negative for Extend"); TORCH_CHECK( is_contiguous_, "Right now Extend is only supported for contiguous Tensor."); - auto newDims = sizes_; + using SizesVector = SmallVector; + SizesVector newDims(sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); newDims[0] += num; if (!storage_.data()) { Resize(newDims); @@ -1031,16 +1092,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { static_cast(1), std::multiplies()); if (newNumel * data_type_.itemsize() <= storage_.nbytes()) { - sizes_ = newDims; + sizes_and_strides_.set_sizes(newDims); numel_ = newNumel; return; } - auto newCapacity = sizes_; - newCapacity[0] = std::max( - newDims[0], std::ceil(sizes_[0] * (growthPct + 100) / 100)); + SizesVector newCapacity(sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); + newCapacity[0] = std::max( + newDims[0], static_cast(std::ceil(sizes_and_strides_.size_at_unchecked(0) * (1 + growthPct / 100)))); auto oldData = std::move(storage_.data_ptr()); auto oldSize = numel_; - auto oldDims = sizes_; Resize(newCapacity); auto* newData = raw_mutable_data(data_type_); if (data_type_.copy()) { @@ -1067,7 +1127,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { true); // non-blocking } reserved_ = true; - sizes_ = newDims; + sizes_and_strides_.set_sizes(newDims); numel_ = newNumel; } @@ -1084,7 +1144,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { "Right now ReserveSpace is only supported for contiguous Tensor."); TORCH_CHECK( storage_.unique(), "Can't call ReserveSpace on shared storage."); - auto newCapacity = sizes_; + // TODO: eliminate newCapacity. + SmallVector newCapacity(sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); newCapacity[0] = outer_dim; auto newNumel = std::accumulate( newCapacity.begin(), @@ -1097,11 +1158,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // Old data is discarded storage_.data_ptr().clear(); auto oldSize = numel_; - auto oldDims = sizes_; + SmallVector oldDims(sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end()); Resize(newCapacity); // Allocate new memory but don't copy over the data raw_mutable_data(data_type_); - sizes_ = oldDims; + sizes_and_strides_.set_sizes(oldDims); numel_ = oldSize; reserved_ = true; } @@ -1171,7 +1232,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { " The old caffe2 mixes Reshape and Resize but this behavior has " "been changed. If you find this error, most likely you will need " "to change corresponding code from Reshape to Resize."); - sizes_ = dims; + sizes_and_strides_.set_sizes(dims); empty_tensor_restride(MemoryFormat::Contiguous); } @@ -1231,10 +1292,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { void ShareExternalPointer( DataPtr&& data_ptr, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, size_t size_bytes) { TORCH_CHECK( - data_type.id() != caffe2::TypeIdentifier::uninitialized(), + data_type != ScalarType::Undefined, "To share with a raw external pointer you need to pass in an " "initialized data_type(TypeMeta)."); if (!size_bytes) { @@ -1271,7 +1332,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * If the existing data does not match the desired type, it will be deleted * and a new storage will be created. */ - inline void* raw_mutable_data(const caffe2::TypeMeta& meta) { + inline void* raw_mutable_data(const caffe2::TypeMeta meta) { // For 0-size tensors it's fine to return any pointer (including nullptr) if (data_type_ == meta && storage_initialized()) { return static_cast(static_cast(storage_.data()) + storage_offset_ * meta.itemsize()); @@ -1335,7 +1396,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // error in attempt to invoke TypeMeta::ctor() static_assert( std::is_default_constructible::value, - "Tensor can't hold non-default-constructible types"); + "Tensor can't hold non-default-constructable types"); return static_cast(raw_mutable_data(caffe2::TypeMeta::Make())); } @@ -1365,7 +1426,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { void set_storage_and_dtype( at::Storage storage, - const caffe2::TypeMeta& data_type) { + const caffe2::TypeMeta data_type) { set_storage_keep_dtype(storage); data_type_ = data_type; } @@ -1376,7 +1437,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * WARNING: This function doesn't rearrange data and assumes tensor is a memory * contiguous */ - virtual void empty_tensor_restride(MemoryFormat memory_format) { + void empty_tensor_restride(MemoryFormat memory_format) { #ifdef DEBUG TORCH_INTERNAL_ASSERT(compute_numel() == numel_, "If you are seeing this error, that means empty_tensor_restride was " @@ -1385,13 +1446,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { switch (memory_format) { case MemoryFormat::Contiguous: { // dim_ is a virtual call, don't repeat it - auto dim_ = dim(); - strides_.resize(dim_); + const auto dim_ = dim(); + sizes_and_strides_.resize(dim_); if (dim_ > 0) { - int last_idx = dim_ - 1; - strides_[last_idx] = 1; + const auto last_idx = dim_ - 1; + sizes_and_strides_.stride_at_unchecked(last_idx) = 1; for (auto i = last_idx - 1; i >= 0; --i) { - strides_[i] = strides_[i + 1] * std::max(sizes_[i + 1], 1); + sizes_and_strides_.stride_at_unchecked(i) = sizes_and_strides_.stride_at_unchecked(i + 1) * std::max(sizes_and_strides_.size_at_unchecked(i + 1), 1); } } break; @@ -1449,11 +1510,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { typename = typename std::enable_if::value>::type> bool SetDimsTemplate(ArrayRef src) { auto old_numel = numel_; - sizes_.resize(src.size()); + sizes_and_strides_.resize(src.size()); int64_t new_numel = 1; for (size_t i = 0; i < src.size(); ++i) { new_numel *= src[i]; - sizes_[i] = src[i]; + sizes_and_strides_.size_at_unchecked(i) = src[i]; } numel_ = new_numel; empty_tensor_restride(MemoryFormat::Contiguous); @@ -1577,6 +1638,24 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change); + /** + * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / storage_offset) + * from one TensorImpl to another TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE [ TensorImpl Shallow-Copying ]. + */ + static void copy_tensor_metadata( + const TensorImpl* src_impl, + TensorImpl* dest_impl, + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change); + +private: + static void copy_tensor_metadata_except_version_counter( + const TensorImpl* src_impl, + TensorImpl* dest_impl, + bool allow_tensor_metadata_change); + protected: // Error message to show when the user tries to change tensor metadata on // Tensor created from .data or .detach(). @@ -1632,12 +1711,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp PyObject* pyobj_ = nullptr; - // We could save a word or two by combining the SmallVector structs, - // since their size is redundant, and if we need to overflow the buffer space - // we could keep the two pointers together. However, that would require - // implementing another struct from scratch, so only do this if we're desperate. - SmallVector sizes_; - SmallVector strides_; + c10::impl::SizesAndStrides sizes_and_strides_; int64_t storage_offset_ = 0; // If sizes and strides are empty, the numel is 1!! However, most of the @@ -1671,36 +1745,47 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // INVARIANT: named_tensor_meta_ != nullptr <==> key_set_.has(DispatchKey::Named) DispatchKeySet key_set_; - // You get to have eight byte-size fields here, before you - // should pack this into a bitfield. + // Tensor is contiguous bool is_contiguous_ = true; + // default member initializers for bit-fields only available with -std=c++2a or -std=gnu++2a + inline void init_bitfields() { + is_channels_last_ = false; + is_channels_last_contiguous_ = false; + is_channels_last_3d_ = false; + is_channels_last_3d_contiguous_ = false; + is_non_overlapping_and_dense_ = true; + is_wrapped_number_ = false; + allow_tensor_metadata_change_ = true; + reserved_ = false; + } + // Tensor is stored in the channels last 2d memory format, when dimensions // order is (N)CHW and C-strides < W-strides < H-strides (< N-strides) // (If size of any dimension is equal to 1, this dimension strides value // is not taken into account). - bool is_channels_last_ = false; + bool is_channels_last_ : 1; // Channels last contiguous tensor is channel last tensor which occupies // contiguous memory block. - bool is_channels_last_contiguous_ = false; + bool is_channels_last_contiguous_ : 1; // Tensor is stored in the channels last 3d memory format, when dimensions // order is (N)CDHW and C-strides < W-strides < H-strides < D - strides (< N-strides) // (If size of any dimension is equal to 1, this dimension strides value // is not taken into account). - bool is_channels_last_3d_ = false; + bool is_channels_last_3d_ : 1; // Channels last 3d contiguous tensor is channel last 3d tensor which occupies // contiguous memory block. - bool is_channels_last_3d_contiguous_ = false; + bool is_channels_last_3d_contiguous_ : 1; // Dense tensor is the tensor that store values in a contiguous block of memory. // Non-overlapping tensor is the tensor in which elements occupy individual // non-repetitive memory. - bool is_non_overlapping_and_dense_ = false; + bool is_non_overlapping_and_dense_ : 1; - bool is_wrapped_number_ = false; + bool is_wrapped_number_ : 1; // NOTE [ Metadata Change for a Detached Tensor ] // @@ -1717,14 +1802,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // NOTE: For a full list of tensor metadata fields, please see // `copy_tensor_metadata()` in TensorImpl and its subclasses to find // which fields are copied by value. - bool allow_tensor_metadata_change_ = true; + bool allow_tensor_metadata_change_ : 1; // we decide to keep reserved_ and it will // live in Tensor after the split // The logic is that if Extend() or ReserveSpace() were ever called, // then subsequent Resize()s will not free up Storage. - bool reserved_ = false; - + bool reserved_ : 1; }; // Note [TensorImpl size constraints] @@ -1759,31 +1843,26 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // autograd metadata pointer // version counter pointer // PyObject pointer -// sizes SmallVector (begin) -// sizes SmallVector (end) -// sizes SmallVector (capacity) -// sizes SmallVector (pre-allocated 0) -// sizes SmallVector (pre-allocated 1) -// sizes SmallVector (pre-allocated 2) -// sizes SmallVector (pre-allocated 3) -// sizes SmallVector (pre-allocated 4) -// strides SmallVector (begin) -// strides SmallVector (end) -// strides SmallVector (capacity) -// strides SmallVector (pre-allocated 0) -// strides SmallVector (pre-allocated 1) -// strides SmallVector (pre-allocated 2) -// strides SmallVector (pre-allocated 3) -// strides SmallVector (pre-allocated 4) +// SizesAndStrides size/pointer +// SizesAndStrides sizes (pre-allocated 0) +// SizesAndStrides sizes (pre-allocated 1) +// SizesAndStrides sizes (pre-allocated 2) +// SizesAndStrides sizes (pre-allocated 3) +// SizesAndStrides sizes (pre-allocated 4) +// SizesAndStrides strides (pre-allocated 0) +// SizesAndStrides strides (pre-allocated 1) +// SizesAndStrides strides (pre-allocated 2) +// SizesAndStrides strides (pre-allocated 3) +// SizesAndStrides strides (pre-allocated 4) // storage offset // numel -// data type pointer +// data type // (optional) device // tensor type id // miscellaneous bitfield // static_assert(sizeof(void*) != sizeof(int64_t) || // if 64-bit... - sizeof(TensorImpl) == sizeof(int64_t) * 31, + sizeof(TensorImpl) == sizeof(int64_t) * 24, "You changed the size of TensorImpl on 64-bit arch." "See Note [TensorImpl size constraints] on how to proceed."); } // namespace c10 diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index a42f4d4284f48..34e17c37f7740 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -17,6 +18,29 @@ #include namespace c10 { + +DispatchKey computeDispatchKey(c10::optional dtype, c10::optional layout, c10::optional device); + +inline ScalarType dtype_or_default(c10::optional dtype) { + return value_or_else(dtype, [] {return get_default_dtype_as_scalartype();}); +} + +inline caffe2::TypeMeta dtype_or_default(c10::optional dtype) { + return value_or_else(dtype, [] {return get_default_dtype();}); +} + +inline Layout layout_or_default(c10::optional layout) { + return layout.value_or(kStrided); +} + +inline Device device_or_default(c10::optional device) { + return value_or_else(device, [] {return Device(kCPU);}); +} + +inline bool pinned_memory_or_default(c10::optional pinned_memory) { + return pinned_memory.value_or(false); +} + /// A class to encapsulate construction axes of an Tensor. TensorOptions was /// designed to support the Python style API for specifying construction options /// on factory functions, e.g., @@ -97,6 +121,8 @@ namespace c10 { /// To get around this, we templatize the `Device` constructor. Since overload /// resolution is done before template resolution, our problem is solved. +DispatchKey computeDispatchKey(optional dtype, optional layout, optional device); + struct C10_API TensorOptions { TensorOptions() @@ -228,7 +254,7 @@ struct C10_API TensorOptions { /// Returns the device of the `TensorOptions`. Device device() const noexcept { - return has_device_ ? device_ : Device(kCPU); + return device_or_default(device_opt()); } /// Returns whether the device is specified. @@ -249,7 +275,7 @@ struct C10_API TensorOptions { /// Returns the dtype of the `TensorOptions`. caffe2::TypeMeta dtype() const noexcept { - return has_dtype_ ? dtype_ : get_default_dtype(); + return dtype_or_default(dtype_opt()); } /// Returns whether the dtype is specified. @@ -265,7 +291,7 @@ struct C10_API TensorOptions { /// Returns the layout of the `TensorOptions`. Layout layout() const noexcept { - return has_layout_ ? layout_ : kStrided; + return layout_or_default(layout_opt()); } /// Returns whether the layout is specified. @@ -298,7 +324,7 @@ struct C10_API TensorOptions { /// Returns the `pinned_memory` property of the `TensorOptions`. bool pinned_memory() const noexcept { - return has_pinned_memory_ ? pinned_memory_ : false; + return pinned_memory_or_default(pinned_memory_opt()); } /// Returns whether the `pinned_memory` is specified. @@ -353,15 +379,24 @@ struct C10_API TensorOptions { /// device guard. /// TensorOptions merge_in(TensorOptions options) const noexcept { - TensorOptions r = options; - if (!r.has_device()) r.set_device(device_opt()); - if (!r.has_dtype()) r.set_dtype(dtype_opt()); - if (!r.has_layout()) r.set_layout(layout_opt()); + TensorOptions merged = *this; + if (options.has_device()) merged.set_device(options.device_opt()); + if (options.has_dtype()) merged.set_dtype(options.dtype_opt()); + if (options.has_layout()) merged.set_layout(options.layout_opt()); // NB: requires grad is right biased; not a logical AND/OR! - if (!r.has_requires_grad()) r.set_requires_grad(requires_grad_opt()); - if (!r.has_pinned_memory()) r.set_pinned_memory(pinned_memory_opt()); - if (!r.has_memory_format()) r.set_memory_format(memory_format_opt()); - return r; + if (options.has_requires_grad()) merged.set_requires_grad(options.requires_grad_opt()); + if (options.has_pinned_memory()) merged.set_pinned_memory(options.pinned_memory_opt()); + if (options.has_memory_format()) merged.set_memory_format(options.memory_format_opt()); + return merged; + } + + // TODO remove after TensorOptions rationalization + TensorOptions merge_memory_format(c10::optional optional_memory_format) const noexcept { + TensorOptions merged = *this; + if (optional_memory_format.has_value()) { + merged.set_memory_format(*optional_memory_format); + } + return merged; } // Resolves the tensor type set specified by the current construction axes. @@ -369,66 +404,8 @@ struct C10_API TensorOptions { return DispatchKeySet(computeDispatchKey()); } - inline DispatchKey computeDispatchKey() const { - switch (layout()) { - case Layout::Strided: - switch (device().type()) { - case DeviceType::CPU: { - auto dtype_tmp = typeMetaToScalarType(dtype()); - if (isQIntType(dtype_tmp)) { - return DispatchKey::QuantizedCPU; - } - return DispatchKey::CPU; - } - case DeviceType::CUDA: { - auto dtype_tmp = typeMetaToScalarType(dtype()); - if (isQIntType(dtype_tmp)) { - return DispatchKey::QuantizedCUDA; - } - return DispatchKey::CUDA; - } - case DeviceType::MKLDNN: - return DispatchKey::MKLDNN; - case DeviceType::OPENGL: - return DispatchKey::OpenGL; - case DeviceType::OPENCL: - return DispatchKey::OpenCL; - case DeviceType::IDEEP: - return DispatchKey::IDEEP; - case DeviceType::HIP: - return DispatchKey::HIP; - case DeviceType::FPGA: - return DispatchKey::FPGA; - case DeviceType::MSNPU: - return DispatchKey::MSNPU; - case DeviceType::XLA: - return DispatchKey::XLA; - case DeviceType::Vulkan: - return DispatchKey::Vulkan; - default: - AT_ERROR("Unsupported device type for dense layout: ", device().type()); - } - case Layout::Sparse: - switch (device().type()) { - case DeviceType::CPU: - return DispatchKey::SparseCPU; - case DeviceType::CUDA: - return DispatchKey::SparseCUDA; - case DeviceType::HIP: - return DispatchKey::SparseHIP; - default: - AT_ERROR("Unsupported device type for sparse layout: ", device().type()); - } - case Layout::Mkldnn: - switch (device().type()) { - case DeviceType::CPU: - return DispatchKey::MkldnnCPU; - default: - AT_ERROR("Unsupported device type for mkldnn layout: ", device().type()); - } - default: - AT_ERROR("Unsupported layout: ", layout()); - } + DispatchKey computeDispatchKey() const { + return c10::computeDispatchKey(optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt()); } private: @@ -527,8 +504,8 @@ struct C10_API TensorOptions { // NB: We didn't use c10::optional here, because then we can't pack // the has_***_ boolean fields. - caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make(); // 64-bit - Device device_ = at::kCPU; // 32-bit + Device device_ = at::kCPU; // 16-bit + caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make(); // 16-bit Layout layout_ = at::kStrided; // 8-bit MemoryFormat memory_format_ = MemoryFormat::Contiguous; // 8-bit @@ -611,13 +588,70 @@ inline std::string toString(const TensorOptions options) { // This is intended to be a centralized location by which we can determine // what an appropriate DispatchKey for a tensor is. -// -// This takes a TensorOptions, rather than just a DeviceType and Layout, because -// we reserve the right to change dispatch based on *any* aspect of -// TensorOptions. WARNING: If you do this, you need to fix the calls -// to computeDispatchKey in caffe2/tensor.h -inline DispatchKey computeDispatchKey(TensorOptions options) { - return options.computeDispatchKey(); +inline DispatchKey computeDispatchKey(c10::optional dtype, c10::optional layout, c10::optional device) { + const auto layout_ = layout_or_default(layout); + const auto device_ = device_or_default(device); + switch (layout_) { + case Layout::Strided: { + const auto dtype_ = dtype_or_default(dtype); + switch (device_.type()) { + case DeviceType::CPU: { + if (isQIntType(dtype_)) { + return DispatchKey::QuantizedCPU; + } + return DispatchKey::CPU; + } + case DeviceType::CUDA: { + if (isQIntType(dtype_)) { + return DispatchKey::QuantizedCUDA; + } + return DispatchKey::CUDA; + } + case DeviceType::MKLDNN: + return DispatchKey::MKLDNN; + case DeviceType::OPENGL: + return DispatchKey::OpenGL; + case DeviceType::OPENCL: + return DispatchKey::OpenCL; + case DeviceType::IDEEP: + return DispatchKey::IDEEP; + case DeviceType::HIP: + return DispatchKey::HIP; + case DeviceType::FPGA: + return DispatchKey::FPGA; + case DeviceType::MSNPU: + return DispatchKey::MSNPU; + case DeviceType::XLA: + return DispatchKey::XLA; + case DeviceType::Vulkan: + return DispatchKey::Vulkan; + case DeviceType::Metal: + return DispatchKey::Metal; + default: + AT_ERROR("Unsupported device type for dense layout: ", device_.type()); + } + } + case Layout::Sparse: + switch (device_.type()) { + case DeviceType::CPU: + return DispatchKey::SparseCPU; + case DeviceType::CUDA: + return DispatchKey::SparseCUDA; + case DeviceType::HIP: + return DispatchKey::SparseHIP; + default: + AT_ERROR("Unsupported device type for sparse layout: ", device_.type()); + } + case Layout::Mkldnn: + switch (device_.type()) { + case DeviceType::CPU: + return DispatchKey::MkldnnCPU; + default: + AT_ERROR("Unsupported device type for mkldnn layout: ", device_.type()); + } + default: + AT_ERROR("Unsupported layout: ", layout_); + } } // We deliberately ignore handling AutogradCPU/CUDA/XLA... keys to @@ -655,6 +689,10 @@ inline DeviceType computeDeviceType(DispatchKey tid) { return DeviceType::CPU; } else if (tid == DispatchKey::Vulkan) { return DeviceType::Vulkan; + } else if (tid == DispatchKey::Metal) { + return DeviceType::Metal; + } else if (tid == DispatchKey::QuantizedCPU) { + return DeviceType::CPU; } else { AT_ASSERTM(false, "Unknown DispatchKey: ", tid); } diff --git a/c10/core/UndefinedTensorImpl.cpp b/c10/core/UndefinedTensorImpl.cpp index f79897c72e040..2b6365855c2bf 100644 --- a/c10/core/UndefinedTensorImpl.cpp +++ b/c10/core/UndefinedTensorImpl.cpp @@ -8,10 +8,6 @@ UndefinedTensorImpl::UndefinedTensorImpl() : TensorImpl(DispatchKey::Undefined, caffe2::TypeMeta(), c10::nullopt) { } -IntArrayRef UndefinedTensorImpl::sizes() const { - AT_ERROR("sizes() called on undefined Tensor"); -} - int64_t UndefinedTensorImpl::size(int64_t d) const { AT_ERROR("size(dim) called on an undefined Tensor"); } @@ -20,10 +16,6 @@ int64_t UndefinedTensorImpl::stride(int64_t d) const { AT_ERROR("stride(dim) called on an undefined Tensor"); } -int64_t UndefinedTensorImpl::dim() const { - AT_ERROR("dim() called on undefined Tensor"); -} - bool UndefinedTensorImpl::has_storage() const { AT_ERROR("has_storage() called on undefined Tensor"); } diff --git a/c10/core/UndefinedTensorImpl.h b/c10/core/UndefinedTensorImpl.h index 9f1cb93c10eb0..ab34c8f52a1e9 100644 --- a/c10/core/UndefinedTensorImpl.h +++ b/c10/core/UndefinedTensorImpl.h @@ -17,19 +17,15 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl { #endif return &_singleton; } - IntArrayRef sizes() const override; IntArrayRef strides() const override; int64_t size(int64_t d) const override; int64_t stride(int64_t d) const override; - int64_t dim() const override; bool has_storage() const override; const Storage& storage() const override; int64_t storage_offset() const override; private: UndefinedTensorImpl(); static UndefinedTensorImpl _singleton; -public: - friend struct UndefinedType; }; } // namespace c10 diff --git a/c10/core/impl/DeviceGuardImplInterface.h b/c10/core/impl/DeviceGuardImplInterface.h index 516aebba07472..258f8953f4de4 100644 --- a/c10/core/impl/DeviceGuardImplInterface.h +++ b/c10/core/impl/DeviceGuardImplInterface.h @@ -126,7 +126,7 @@ struct C10_API DeviceGuardImplInterface { /** * Increments the event's version and enqueues a job with this version * in the stream's work queue. When the stream process that job - * it nofifies all streams waiting on / blocked by that version of the + * it notifies all streams waiting on / blocked by that version of the * event to continue and marks that version as recorded. * */ virtual void record( @@ -209,7 +209,15 @@ class C10_API DeviceGuardImplRegistrar { static ::c10::impl::DeviceGuardImplRegistrar C10_ANONYMOUS_VARIABLE(g_##DeviceType)(::c10::DeviceType::DevType, new DeviceGuardImpl()); inline const DeviceGuardImplInterface* getDeviceGuardImpl(DeviceType type) { - auto p = device_guard_impl_registry[static_cast(type)].load(); + // Two adjacent int16_t fields DeviceType and DeviceIndex has field access + // miscompiled on NVCC. To workaround this issue, we apply a mask to the + // DeviceType. First check if the DeviceType is 16-bit. + // FB employees can see + // https://fb.workplace.com/groups/llvm.gcc/permalink/4053565044692080/ + // for more details + static_assert(sizeof(DeviceType) == 1, "DeviceType is not 8-bit"); + auto p = device_guard_impl_registry[static_cast(type) & 0xFF].load(); + // This seems to be the first place where you make use of a device // when you pass devices to factory functions. Give a nicer error // message in this case. diff --git a/c10/core/impl/LocalDispatchKeySet.cpp b/c10/core/impl/LocalDispatchKeySet.cpp index 358e6ef7e1f7f..78400e9556726 100644 --- a/c10/core/impl/LocalDispatchKeySet.cpp +++ b/c10/core/impl/LocalDispatchKeySet.cpp @@ -5,10 +5,6 @@ namespace c10 { namespace impl { -C10_DEFINE_bool(disable_variable_dispatch, false, "This flag forcibly disables the Variable code paths from executing, which currently breaks profiling in the process."); - -namespace { - /// In the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY build setting, /// thread_local is not supported. #ifndef CAFFE2_FB_LIMITED_MOBILE_CAPABILITY @@ -18,25 +14,15 @@ thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set; #else // defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY) -static PODLocalDispatchKeySet raw_local_dispatch_key_set; +PODLocalDispatchKeySet raw_local_dispatch_key_set; #endif -} // anonymous namespace - +#if defined(_MSC_VER) || defined(C10_ANDROID) LocalDispatchKeySet tls_local_dispatch_key_set() { - // Hack until variable performance is fixed - // - // ezyang: I'm pretty unhappy about this implementation, it looks wrong - // to me, as it seems to be performing a mutation on - // raw_local_dispatch_key_set. I can't conveniently test the correct - // version though... - if (FLAGS_disable_variable_dispatch) { - raw_local_dispatch_key_set.set_excluded( - raw_local_dispatch_key_set.excluded() | autograd_dispatch_keyset); - } return raw_local_dispatch_key_set; } +#endif // defined(_MSC_VER) || defined(C10_ANDROID) void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set) { raw_local_dispatch_key_set = PODLocalDispatchKeySet { diff --git a/c10/core/impl/LocalDispatchKeySet.h b/c10/core/impl/LocalDispatchKeySet.h index 5262b1d4d6c07..6c03755ea73d8 100644 --- a/c10/core/impl/LocalDispatchKeySet.h +++ b/c10/core/impl/LocalDispatchKeySet.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include // TLS management for DispatchKeySet (the "local" DispatchKeySet(s)) @@ -23,8 +24,6 @@ namespace c10 { namespace impl { -C10_DECLARE_bool(disable_variable_dispatch); - // POD version of LocalDispatchKeySet. Declared here just so that // we can put it in the guards. struct C10_API PODLocalDispatchKeySet { @@ -54,7 +53,25 @@ struct C10_API LocalDispatchKeySet { DispatchKeySet excluded_; }; +// thread_local variables cannot be C10_API on Windows. +// Inlining this seems to break AutoNonVariableTypeGuard on Android. +#if defined(_MSC_VER) || defined(C10_ANDROID) C10_API LocalDispatchKeySet tls_local_dispatch_key_set(); +#else // defined(_MSC_VER) || defined(C10_ANDROID) +/// In the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY build setting, +/// thread_local is not supported. +#ifndef CAFFE2_FB_LIMITED_MOBILE_CAPABILITY + extern C10_API thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set; +#else // defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY) + extern C10_API PODLocalDispatchKeySet raw_local_dispatch_key_set; +#endif + +inline C10_API LocalDispatchKeySet tls_local_dispatch_key_set() { + // Don't let people fiddle with the thread_local directly just + // because they include this header. + return raw_local_dispatch_key_set; +} +#endif // defined(_MSC_VER) || defined(C10_ANDROID) // Internal, use ThreadLocalStateGuard C10_API void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set); diff --git a/c10/core/impl/SizesAndStrides.cpp b/c10/core/impl/SizesAndStrides.cpp new file mode 100644 index 0000000000000..bf7ec3ff887d5 --- /dev/null +++ b/c10/core/impl/SizesAndStrides.cpp @@ -0,0 +1,66 @@ +#include + +namespace c10 { +namespace impl { + +void SizesAndStrides::resizeSlowPath(const size_t newSize, const size_t oldSize) { + if (newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline(), "resizeSlowPath called when fast path should have been hit!"); + int64_t* tempStorage = outOfLineStorage_; + memcpy( + &inlineStorage_[0], + &tempStorage[0], + C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * sizeof(inlineStorage_[0])); + memcpy( + &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE], + &tempStorage[oldSize], + C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * sizeof(inlineStorage_[0])); + // CANNOT USE freeOutOfLineStorage() HERE! outOfLineStorage_ + // HAS BEEN OVERWRITTEN! + free(tempStorage); + } else { + if (isInline()) { + // CANNOT USE allocateOutOfLineStorage(newSize) HERE! WOULD + // OVERWRITE inlineStorage_! + int64_t* tempStorage = static_cast(malloc(storageBytes(newSize))); + TORCH_CHECK(tempStorage, "Could not allocate memory to change Tensor SizesAndStrides!"); + const auto bytesToCopy = oldSize * sizeof(inlineStorage_[0]); + const auto bytesToZero = (newSize > oldSize) ? (newSize - oldSize) * sizeof(tempStorage[0]) : 0; + memcpy(&tempStorage[0], &inlineStorage_[0], bytesToCopy); + if (bytesToZero) { + memset(&tempStorage[oldSize], 0, bytesToZero); + } + memcpy(&tempStorage[newSize], &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE], bytesToCopy); + if (bytesToZero) { + memset(&tempStorage[newSize + oldSize], 0, bytesToZero); + } + outOfLineStorage_ = tempStorage; + } else { + const bool isGrowing = oldSize < newSize; + if (isGrowing) { + // Resize before shifting so that we have room. + resizeOutOfLineStorage(newSize); + } + // Shift the old strides to their new starting point. Note + // that this does not occur in the inline path above because + // the stride starting point is not moving. + memmove( + outOfLineStorage_ + newSize, + outOfLineStorage_ + oldSize, + std::min(oldSize, newSize) * sizeof(outOfLineStorage_[0])); + if (!isGrowing) { + // Resize after shifting so that we don't lose data. + resizeOutOfLineStorage(newSize); + } else { + // Zero the end of the sizes portion. + const auto bytesToZero = (newSize - oldSize) * sizeof(outOfLineStorage_[0]); + memset(&outOfLineStorage_[oldSize], 0, bytesToZero); + memset(&outOfLineStorage_[newSize + oldSize], 0, bytesToZero); + } + } + } + size_ = newSize; +} + +} // namespace impl +} // namespace c10 diff --git a/c10/core/impl/SizesAndStrides.h b/c10/core/impl/SizesAndStrides.h new file mode 100644 index 0000000000000..4f7e19330acab --- /dev/null +++ b/c10/core/impl/SizesAndStrides.h @@ -0,0 +1,293 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#define C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE 5 + +namespace c10 { +namespace impl { + +// Packed container for TensorImpl sizes and strides. +// This design improves on the previous approach of using a pair of +// c10::SmallVector by specializing for the operations we +// actually use and enforcing that the number of sizes is the same as +// the number of strides. The memory layout is as follows: +// +// 1 size_t for the size +// 5 eightbytes of inline sizes and 5 eightbytes of inline strides, OR pointer to out-of-line array +class C10_API SizesAndStrides { + public: + // TODO: different iterator types for sizes & strides to prevent + // mixing the two accidentally. + using sizes_iterator = int64_t*; + using sizes_const_iterator = const int64_t*; + using strides_iterator = int64_t*; + using strides_const_iterator = const int64_t*; + + SizesAndStrides() : size_(1) { + size_at_unchecked(0) = 0; + stride_at_unchecked(0) = 1; + } + + ~SizesAndStrides() { + if (C10_UNLIKELY(!isInline())) { + free(outOfLineStorage_); + } + } + + SizesAndStrides(const SizesAndStrides& rhs) : size_(rhs.size_) { + if (C10_LIKELY(rhs.isInline())) { + copyDataInline(rhs); + } else { + allocateOutOfLineStorage(size_); + copyDataOutline(rhs); + } + } + + SizesAndStrides& operator=(const SizesAndStrides& rhs) { + if (this == &rhs) { + return *this; + } + if (C10_LIKELY(rhs.isInline())) { + if (C10_UNLIKELY(!isInline())) { + free(outOfLineStorage_); + } + copyDataInline(rhs); + } else { + if (isInline()) { + allocateOutOfLineStorage(rhs.size_); + } else { + resizeOutOfLineStorage(rhs.size_); + } + copyDataOutline(rhs); + } + size_ = rhs.size_; + return *this; + } + + // Move from rhs. rhs.size() == 0 afterwards. + SizesAndStrides(SizesAndStrides&& rhs) noexcept : size_(rhs.size_) { + if (C10_LIKELY(isInline())) { + memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_)); + } else { + outOfLineStorage_ = rhs.outOfLineStorage_; + rhs.outOfLineStorage_ = nullptr; + } + + rhs.size_ = 0; + } + + // Move from rhs. rhs.size() == 0 afterwards. + SizesAndStrides& operator=(SizesAndStrides&& rhs) noexcept { + if (this == &rhs) { + return *this; + } + if (C10_LIKELY(rhs.isInline())) { + if (C10_UNLIKELY(!isInline())) { + free(outOfLineStorage_); + } + copyDataInline(rhs); + } else { + // They're outline. We're going to steal their vector. + if (!isInline()) { + free(outOfLineStorage_); + } + outOfLineStorage_ = rhs.outOfLineStorage_; + rhs.outOfLineStorage_ = nullptr; + } + size_ = rhs.size_; + rhs.size_ = 0; + + return *this; + } + + size_t size() const noexcept { + return size_; + } + + const int64_t* sizes_data() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[0]; + } else { + return &outOfLineStorage_[0]; + } + } + + int64_t* sizes_data() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[0]; + } else { + return &outOfLineStorage_[0]; + } + } + + sizes_const_iterator sizes_begin() const noexcept { + return sizes_data(); + } + + sizes_iterator sizes_begin() noexcept { + return sizes_data(); + } + + sizes_const_iterator sizes_end() const noexcept { + return sizes_begin() + size(); + } + + sizes_iterator sizes_end() noexcept { + return sizes_begin() + size(); + } + + IntArrayRef sizes_arrayref() const noexcept { + return IntArrayRef{sizes_data(), size()}; + } + + void set_sizes(IntArrayRef newSizes) { + resize(newSizes.size()); + std::copy(newSizes.begin(), newSizes.end(), sizes_begin()); + } + + const int64_t* strides_data() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + int64_t* strides_data() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_const_iterator strides_begin() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_iterator strides_begin() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_const_iterator strides_end() const noexcept { + return strides_begin() + size(); + } + + strides_iterator strides_end() noexcept { + return strides_begin() + size(); + } + + IntArrayRef strides_arrayref() const noexcept { + return IntArrayRef{strides_data(), size()}; + } + + // Size accessors. + int64_t size_at(size_t idx) const noexcept { + assert(idx < size()); + return sizes_data()[idx]; + } + + int64_t& size_at(size_t idx) noexcept { + assert(idx < size()); + return sizes_data()[idx]; + } + + int64_t size_at_unchecked(size_t idx) const noexcept { + return sizes_data()[idx]; + } + + int64_t& size_at_unchecked(size_t idx) noexcept { + return sizes_data()[idx]; + } + + // Size accessors. + int64_t stride_at(size_t idx) const noexcept { + assert(idx < size()); + return strides_data()[idx]; + } + + int64_t& stride_at(size_t idx) noexcept { + assert(idx < size()); + return strides_data()[idx]; + } + + int64_t stride_at_unchecked(size_t idx) const noexcept { + return strides_data()[idx]; + } + + int64_t& stride_at_unchecked(size_t idx) noexcept { + return strides_data()[idx]; + } + + void resize(size_t newSize) { + const auto oldSize = size(); + if (newSize == oldSize) { + return; + } + if (C10_LIKELY(newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE && isInline())) { + if (oldSize < newSize) { + const auto bytesToZero = (newSize - oldSize) * sizeof(inlineStorage_[0]); + memset(&inlineStorage_[oldSize], 0, bytesToZero); + memset(&inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize], 0, bytesToZero); + } + size_ = newSize; + } else { + resizeSlowPath(newSize, oldSize); + } + } + + void resizeSlowPath(size_t newSize, size_t oldSize); + + private: + bool isInline() const noexcept { + return size_ <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE; + } + + void copyDataInline(const SizesAndStrides& rhs) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.isInline()); + memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_)); + } + + static size_t storageBytes(size_t size) noexcept { + return size * 2 * sizeof(int64_t); + } + + void allocateOutOfLineStorage(size_t size) { + outOfLineStorage_ = static_cast(malloc(storageBytes(size))); + TORCH_CHECK(outOfLineStorage_, "Could not allocate memory for Tensor SizesAndStrides!"); + } + + void resizeOutOfLineStorage(size_t newSize) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline()); + outOfLineStorage_ = static_cast(realloc(outOfLineStorage_, storageBytes(newSize))); + TORCH_CHECK(outOfLineStorage_, "Could not allocate memory for Tensor SizesAndStrides!"); + } + + void copyDataOutline(const SizesAndStrides& rhs) noexcept { + memcpy(outOfLineStorage_, rhs.outOfLineStorage_, storageBytes(rhs.size_)); + } + + size_t size_; + union { + int64_t *outOfLineStorage_; + int64_t inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * 2]{}; + }; + +}; + +} // namespace impl +} // namespace c10 diff --git a/c10/cuda/CMakeLists.txt b/c10/cuda/CMakeLists.txt index c8fa53df6f02b..256fc54b08a18 100644 --- a/c10/cuda/CMakeLists.txt +++ b/c10/cuda/CMakeLists.txt @@ -13,7 +13,7 @@ configure_file( ${CMAKE_BINARY_DIR}/c10/cuda/impl/cuda_cmake_macros.h) # Note: if you want to add ANY dependency to the c10 library, make sure you -# check with the core PyTorch developers as the dependendency will be +# check with the core PyTorch developers as the dependency will be # transitively passed on to all libraries dependent on PyTorch. # Note: if you add a new source file/header, you will need to update diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 2285a332f7093..493296248e5ba 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -62,7 +62,7 @@ constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 M constexpr size_t kSmallBuffer = 2097152; // "small" allocations are packed in 2 MiB blocks constexpr size_t kLargeBuffer = 20971520; // "large" allocations may be packed in 20 MiB blocks constexpr size_t kMinLargeAlloc = 10485760; // allocations between 1 and 10 MiB may use kLargeBuffer -constexpr size_t kRoundLarge = 2097152; // round up large allocs to 2 MiB +constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB typedef std::bitset(StatType::NUM_TYPES)> StatTypes; @@ -202,6 +202,13 @@ class DeviceCachingAllocator { // outstanding cuda events std::deque> cuda_events; + // record used memory. + size_t total_allocated_memory = 0; + + size_t allowed_memory_maximum = 0; + + bool set_fraction = false; + public: DeviceCachingAllocator() : @@ -241,10 +248,16 @@ class DeviceCachingAllocator { size_t device_free; size_t device_total; C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); + std::string allowed_info; + + if (set_fraction) { + allowed_info = format_size(allowed_memory_maximum) + " allowed; "; + } stats.num_ooms += 1; // "total capacity": total global memory on GPU + // "allowed": memory is allowed to use, which set by fraction. // "already allocated": memory allocated by the program using the // caching allocator // "free": free memory as reported by the CUDA API @@ -268,6 +281,7 @@ class DeviceCachingAllocator { format_size(stats.allocated_bytes[static_cast(StatType::AGGREGATE)].current), " already allocated; ", format_size(device_free), " free; ", + allowed_info, format_size(stats.reserved_bytes[static_cast(StatType::AGGREGATE)].current), " reserved in total by PyTorch)"); } else { @@ -373,6 +387,15 @@ class DeviceCachingAllocator { block->stream_uses.insert(stream); } + /** set memory fraction to limit maximum allocated memory **/ + void setMemoryFraction(double fraction) { + size_t device_free; + size_t device_total; + C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); + allowed_memory_maximum = static_cast(fraction * device_total); + set_fraction = true; + } + /** returns cached blocks to the system allocator **/ void emptyCache() { std::lock_guard lock(mutex); @@ -630,14 +653,19 @@ class DeviceCachingAllocator { if (isRetry) { stats.num_alloc_retries += 1; } + if (set_fraction && total_allocated_memory + size > allowed_memory_maximum) { + p.err = cudaErrorMemoryAllocation; + } else { + p.err = cudaMalloc(&ptr, size); + } - p.err = cudaMalloc(&ptr, size); if (p.err != cudaSuccess) { if (!isRetry || p.err == cudaErrorMemoryAllocation) cudaGetLastError(); // clear CUDA error return false; } + total_allocated_memory += size; p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr); update_stat_array(stats.segment, 1, p.stat_types); update_stat_array(stats.reserved_bytes, size, p.stat_types); @@ -665,6 +693,7 @@ class DeviceCachingAllocator { Block* block = *it; if (!block->prev && !block->next) { C10_CUDA_CHECK(cudaFree((void*)block->ptr)); + total_allocated_memory -= block->size; StatTypes stat_types; stat_types[static_cast(StatType::AGGREGATE)] = true; @@ -846,6 +875,25 @@ class THCCachingAllocator { device_allocator[block->device]->free(block); } + void setMemoryFraction(double fraction, int device) { + TORCH_INTERNAL_ASSERT( + 0 <= device && device < device_allocator.size(), + "Allocator not initialized for device ", + device, + ": did you call init?"); + TORCH_INTERNAL_ASSERT( + 0 <= fraction && fraction <= 1, + "invalid fraction:", + fraction, + ". Please set within (0, 1)."); + int activated_device; + cudaGetDevice (&activated_device); + if (activated_device != device) { + cudaSetDevice(device); + } + device_allocator[device]->setMemoryFraction(fraction); + } + void emptyCache() { int count = device_allocator.size(); for (int i = 0; i < count; i++) @@ -896,6 +944,19 @@ class THCCachingAllocator { THCCachingAllocator caching_allocator; +// Returns whether to force all allocations to bypass the caching allocator and +// go straight to cudaMalloc. This setting is useful when debugging GPU memory +// errors, since the caching allocator foils cuda-memcheck. +bool forceUncachedAllocator() { + static bool force_uncached = + getenv("PYTORCH_NO_CUDA_MEMORY_CACHING") != nullptr; + return force_uncached; +} + +static void uncached_delete(void* ptr) { + C10_CUDA_CHECK(cudaFree(ptr)); +} + // NB: I decided not to fold this into THCCachingAllocator, because the latter // has a lot more methods and it wasn't altogether clear that they should // actually be publicly exposed @@ -904,6 +965,10 @@ struct CudaCachingAllocator : public Allocator { int device; C10_CUDA_CHECK(cudaGetDevice(&device)); void* r = nullptr; + if (forceUncachedAllocator()) { + C10_CUDA_CHECK(cudaMalloc(&r, size)); + return {r, r, &uncached_delete, Device(DeviceType::CUDA, device)}; + } if (size != 0) { caching_allocator.malloc(&r, device, size, cuda::getCurrentCUDAStream(device)); } @@ -925,6 +990,10 @@ void init(int device_count) { caching_allocator.init(device_count); } +void setMemoryFraction(double fraction, int device) { + caching_allocator.setMemoryFraction(fraction, device); +} + void emptyCache(void) { caching_allocator.emptyCache(); } @@ -949,8 +1018,8 @@ std::mutex* getFreeMutex() } static inline void assertValidDevice(int device) { - int device_num = device_count(); - AT_ASSERTM(0 <= device && device < device_num, "Invalid device argument."); + int device_num = caching_allocator.device_allocator.size(); + TORCH_CHECK(0 <= device && device < device_num, "Invalid device argument."); } DeviceStats getDeviceStats(int device) { diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 7b9ce4b3211ff..8af8ec5073feb 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -111,6 +111,7 @@ C10_CUDA_API void raw_delete(void* ptr); C10_CUDA_API Allocator* get(); C10_CUDA_API void init(int device_count); +C10_CUDA_API void setMemoryFraction(double fraction, int device); C10_CUDA_API void emptyCache(); C10_CUDA_API void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock); C10_CUDA_API void* getBaseAllocation(void *ptr, size_t *size); diff --git a/c10/cuda/CUDAException.h b/c10/cuda/CUDAException.h index bdce44f1c6d1c..5d1a473b55974 100644 --- a/c10/cuda/CUDAException.h +++ b/c10/cuda/CUDAException.h @@ -29,3 +29,8 @@ TORCH_WARN("CUDA warning: ", cudaGetErrorString(__err)); \ } \ } while (0) + +// This should be used directly after every kernel launch to ensure +// the launch happened correctly and provide an early, close-to-source +// diagnostic if it didn't. +#define C10_CUDA_KERNEL_LAUNCH_CHECK() C10_CUDA_CHECK(cudaGetLastError()) diff --git a/c10/cuda/CUDAFunctions.cpp b/c10/cuda/CUDAFunctions.cpp index 99bb721798d99..4c87cb9ec8c63 100644 --- a/c10/cuda/CUDAFunctions.cpp +++ b/c10/cuda/CUDAFunctions.cpp @@ -1,5 +1,7 @@ #include +#include + namespace c10 { namespace cuda { @@ -93,6 +95,8 @@ DeviceIndex device_count() noexcept { // initialize number of devices only once static int count = []() { try { + auto result = device_count_impl(); + TORCH_INTERNAL_ASSERT(result <= std::numeric_limits::max(), "Too many CUDA devices, DeviceIndex overflowed"); return device_count_impl(); } catch (const c10::Error& ex) { // We don't want to fail, but still log the warning diff --git a/c10/cuda/CUDAMathCompat.h b/c10/cuda/CUDAMathCompat.h index 7652ca0f639dc..1fb0c3ec29c2e 100644 --- a/c10/cuda/CUDAMathCompat.h +++ b/c10/cuda/CUDAMathCompat.h @@ -42,6 +42,13 @@ __MATH_FUNCTIONS_DECL__ double ceil(double x) { return ::ceil(x); } +__MATH_FUNCTIONS_DECL__ float copysign(float x, float y) { + return ::copysignf(x, y); +} +__MATH_FUNCTIONS_DECL__ double copysign(double x, double y) { + return ::copysign(x, y); +} + __MATH_FUNCTIONS_DECL__ float floor(float x) { return ::floorf(x); } diff --git a/c10/cuda/CUDAStream.cpp b/c10/cuda/CUDAStream.cpp index 393826f75a030..d1e290c3f02cc 100644 --- a/c10/cuda/CUDAStream.cpp +++ b/c10/cuda/CUDAStream.cpp @@ -43,12 +43,9 @@ static constexpr int kStreamsPerPoolBits = 5; static constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits; static constexpr unsigned int kDefaultFlags = cudaStreamNonBlocking; -// Note: stream priority is not supported by HIP // Note: lower numbers are higher priorities, zero is default priority -#ifndef __HIP_PLATFORM_HCC__ static int kHighPriority = -1; static int kLowPriority = 0; -#endif // __HIP_PLATFORM_HCC__ // Default streams static std::once_flag init_flag; @@ -63,7 +60,7 @@ static LeakyStreamInternals default_streams[C10_COMPILE_TIME_MAX_GPUS]; // in the pool to be returned when a stream is requested (round-robin fashion // , see the note in CUDAStream.h). // -// unique_ptr is used instead of vector because T might be non-moveable +// unique_ptr is used instead of vector because T might be non-movable // and non-copyable. static std::once_flag device_flags[C10_COMPILE_TIME_MAX_GPUS]; static std::atomic low_priority_counters[C10_COMPILE_TIME_MAX_GPUS]; @@ -229,17 +226,10 @@ static void initDeviceStreamState(DeviceIndex device_index) { lowpri_stream.device_index = device_index; hipri_stream.device_index = device_index; -#ifndef __HIP_PLATFORM_HCC__ C10_CUDA_CHECK(cudaStreamCreateWithPriority( &lowpri_stream.stream, kDefaultFlags, kLowPriority)); C10_CUDA_CHECK(cudaStreamCreateWithPriority( &hipri_stream.stream, kDefaultFlags, kHighPriority)); -#else - C10_CUDA_CHECK( - cudaStreamCreateWithFlags(&lowpri_stream.stream, kDefaultFlags)); - C10_CUDA_CHECK( - cudaStreamCreateWithFlags(&hipri_stream.stream, kDefaultFlags)); -#endif // __HIP_PLATFORM_HCC__ } } diff --git a/c10/cuda/CUDAStream.h b/c10/cuda/CUDAStream.h index d9bc553aa2636..05eddf5ce122d 100644 --- a/c10/cuda/CUDAStream.h +++ b/c10/cuda/CUDAStream.h @@ -120,14 +120,10 @@ class C10_CUDA_API CUDAStream { } int priority() const { - #ifndef __HIP_PLATFORM_HCC__ DeviceGuard guard{stream_.device()}; int priority = 0; C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority)); return priority; - #else - AT_ERROR("cuStreamGetPriority with HIP is not supported"); - #endif } /// Explicit conversion to cudaStream_t. @@ -154,10 +150,9 @@ class C10_CUDA_API CUDAStream { } static std::tuple priority_range() { - #ifndef __HIP_PLATFORM_HCC__ // Note: this returns the range of priority **supported by PyTorch**, not // the range of priority **supported by CUDA**. The former is a subset of - // the latter. Curently PyTorch only supports 0 and -1, which are "low" and + // the latter. Currently PyTorch only supports 0 and -1, which are "low" and // "high" priority. int least_priority, greatest_priority; C10_CUDA_CHECK( @@ -165,9 +160,6 @@ class C10_CUDA_API CUDAStream { TORCH_INTERNAL_ASSERT(least_priority >= 0, "Unexpected CUDA stream priority range"); TORCH_INTERNAL_ASSERT(greatest_priority <= -1, "Unexpected CUDA stream priority range"); return std::make_tuple(0, -1); - #else - AT_ERROR("cuDeviceGetStreamPriorityRange with HIP is not supported"); - #endif } // Deleted for now; use CUDAEvent::block instead @@ -187,7 +179,7 @@ class C10_CUDA_API CUDAStream { * isHighPriority to true, or a stream for a specific device by setting device * (defaulting to the current CUDA stream.) */ -CAFFE2_API CUDAStream +TORCH_API CUDAStream getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1); /** @@ -196,7 +188,7 @@ getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1); * where most computation occurs when you aren't explicitly using * streams. */ -CAFFE2_API CUDAStream getDefaultCUDAStream(DeviceIndex device_index = -1); +TORCH_API CUDAStream getDefaultCUDAStream(DeviceIndex device_index = -1); /** * Get the current CUDA stream, for the passed CUDA device, or for the @@ -205,7 +197,7 @@ CAFFE2_API CUDAStream getDefaultCUDAStream(DeviceIndex device_index = -1); * be different if someone called 'setCurrentCUDAStream' or used 'StreamGuard' * or 'CUDAStreamGuard'. */ -CAFFE2_API CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1); +TORCH_API CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1); /** * Set the current stream on the device of the passed in stream to be @@ -217,7 +209,7 @@ CAFFE2_API CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1); * (which will switch both your current device and current stream in the way you * expect, and reset it back to its original state afterwards). */ -CAFFE2_API void setCurrentCUDAStream(CUDAStream stream); +TORCH_API void setCurrentCUDAStream(CUDAStream stream); C10_API std::ostream& operator<<(std::ostream& stream, const CUDAStream& s); diff --git a/c10/macros/Export.h b/c10/macros/Export.h index 5888207c5f807..64d1037be0e4f 100644 --- a/c10/macros/Export.h +++ b/c10/macros/Export.h @@ -92,11 +92,10 @@ #endif // This one is being used by libtorch.so -// TODO: rename this to TORCH_API #ifdef CAFFE2_BUILD_MAIN_LIB -#define CAFFE2_API C10_EXPORT +#define TORCH_API C10_EXPORT #else -#define CAFFE2_API C10_IMPORT +#define TORCH_API C10_IMPORT #endif // NB: For now, HIP is overloaded to use the same macro, but ideally @@ -113,8 +112,8 @@ #define TORCH_HIP_API C10_IMPORT #endif -// Enums only need to be exported on windows -#ifdef _WIN32 +// Enums only need to be exported on windows for non-CUDA files +#if defined(_WIN32) && defined(__CUDACC__) #define C10_API_ENUM C10_API #else #define C10_API_ENUM diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index f8e703e087468..5499a7d8b81c4 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -77,8 +77,10 @@ * str and ending with a number that varies with the line. */ #ifdef __COUNTER__ +#define C10_UID __COUNTER__ #define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __COUNTER__) #else +#define C10_UID __LINE__ #define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __LINE__) #endif @@ -184,6 +186,14 @@ namespace at { namespace cuda { using namespace c10::hip; }} #define C10_NOINLINE #endif +#if __has_attribute(always_inline) || defined(__GNUC__) +#define C10_ALWAYS_INLINE __attribute__((__always_inline__)) inline +#elif defined(_MSC_VER) +#define C10_ALWAYS_INLINE __forceinline +#else +#define C10_ALWAYS_INLINE inline +#endif + #include #include @@ -193,11 +203,14 @@ namespace at { namespace cuda { using namespace c10::hip; }} #define C10_DEVICE __device__ #define C10_HOST __host__ // constants from (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications) -// The maximum number of threads per multiprocessor is 1024 for Turing architecture (7.5) -// but 2048 for previous architectures. You'll get warnings if you exceed these constants. +// The maximum number of threads per multiprocessor is 1024 for Turing architecture (7.5), +// 1536 for Geforce Ampere (8.6), +// and 2048 for all other architectures. You'll get warnings if you exceed these constants. // Hence, the following macros adjust the input values from the user to resolve potential warnings. #if __CUDA_ARCH__ == 750 constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024; +#elif __CUDA_ARCH__ == 860 +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1536; #else constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048; #endif @@ -303,7 +316,7 @@ __host__ __device__ #define C10_MOBILE 1 #endif // ANDROID / IOS -// Portably determine if a type T is trivially copyable or not. +// Portable determination of whether type T is trivially copyable. // Warning: __has_trivial_copy for GCC may not always detect the non-POD // correctly. For example, T = std::unique_ptr may evaluate to true and be // treated as POD. This can cause unexpected behavior. diff --git a/c10/macros/cmake_macros.h.in b/c10/macros/cmake_macros.h.in index 5e42506f20dcb..2845fa1cd8d2e 100644 --- a/c10/macros/cmake_macros.h.in +++ b/c10/macros/cmake_macros.h.in @@ -8,6 +8,7 @@ #cmakedefine C10_USE_GLOG #cmakedefine C10_USE_GFLAGS #cmakedefine C10_USE_NUMA +#cmakedefine C10_USE_MSVC_STATIC_RUNTIME // Used by libtorch mobile build to enable features that are not enabled by // caffe2 mobile build. Should only use it when necessary as we are committed diff --git a/c10/core/CPUCachingAllocator.cpp b/c10/mobile/CPUCachingAllocator.cpp similarity index 96% rename from c10/core/CPUCachingAllocator.cpp rename to c10/mobile/CPUCachingAllocator.cpp index 232b8f2306e24..0114856ca89b6 100644 --- a/c10/core/CPUCachingAllocator.cpp +++ b/c10/mobile/CPUCachingAllocator.cpp @@ -1,4 +1,4 @@ -#include +#include namespace c10 { @@ -61,7 +61,7 @@ void CPUCachingAllocator::record_free(void* ptr) { // is being freed outside the scope of this allocator. // At the moment only way to capture this is to have the allocator, // that uses this CachingAllocator as the backing allocator, - // call this function explicity upon freeing memory while + // call this function explicitly upon freeing memory while // outside the scope of caching allocator. // If the memory is freed in some other way, then we will likely // have undefined behavior or page fault. But this can be @@ -95,8 +95,8 @@ CPUCachingAllocator* GetThreadLocalCachingAllocator() { WithCPUCachingAllocatorGuard::WithCPUCachingAllocatorGuard( CPUCachingAllocator* allocator) { - caching_allocator_ptr = allocator; prev_caching_allocator_ptr_ = GetThreadLocalCachingAllocator(); + caching_allocator_ptr = allocator; } WithCPUCachingAllocatorGuard::~WithCPUCachingAllocatorGuard() { diff --git a/c10/core/CPUCachingAllocator.h b/c10/mobile/CPUCachingAllocator.h similarity index 70% rename from c10/core/CPUCachingAllocator.h rename to c10/mobile/CPUCachingAllocator.h index ac5f3a95c8810..c80fee0682eb1 100644 --- a/c10/core/CPUCachingAllocator.h +++ b/c10/mobile/CPUCachingAllocator.h @@ -10,6 +10,38 @@ #include #include +/* + * CPUCachingAllocator: + * DISCLAIMER: + * This is subject to change (beta) and only supported on mobile builds. + * If code snippet such as in 'Usage pattern' is used outside of mobile + * build you will not observe the intended behavior. + * See below for more information. + * Why? + * It has been observed that some mobile platforms, such as pixel 3, return + * memory aggressively to the system. This results in page faults in some cases + * and ends up hurting performance. This caching allocator aims to address that. + * Furthermore it also allows users to specify their own allocator by implementing + * allocate/free virtual interfaces. + * What are the cons? + * There are some cons that were observed where use of caching allocator led to + * worse performance on some platforms. Reason being that the caching mechanism + * used by this allocator left us worse off compared to the corresponding platform's + * tuned memory allocator. In that case it seemed better to not use this allocator. + * Note there are some ideas to fix this in the works. + * + * Usage: + * Usage pattern: + * Instantiate and own the caching allocator. + * std::unique_ptr caching_allocator = + * std::make_unique(); + * Use caching allocator with a scoped guard at inference time. + * { + * WithCPUCachingAllocatorGuard(caching_allocator.get()); + * ... model.forward(...); + * } + */ + namespace c10 { class C10_API CPUCachingAllocator { @@ -22,13 +54,16 @@ class C10_API CPUCachingAllocator { * No speculative allocation for any future allocations. */ private: + inline void* allocate_and_cache(const size_t bytes); + void free_cached(); + protected: // Invariants. // 1. If memory is ever allocated via this allocator then // the pointer will exist in allocation_map_, unless the allocator // returned the memory to OS via free_cached. // 1.1. Therefore even when the said memory is "freed" via this // allocator (and thus cached), it will continue to stay - // in allocaiton_map_. Furthermore it will also exist in + // in allocation_map_. Furthermore it will also exist in // available_map_. Thus an allocated memory pointer can be in both // allocation_map_ and available_map_ simultaneously. // 2. Memory pointer maybe removed from allocation_map_, when it @@ -39,9 +74,6 @@ class C10_API CPUCachingAllocator { // As a result of above invariants, allocated memory ptr cannot be in // available_map_ unless it is in allocation_map_ as well. ska::flat_hash_map> available_map_; - inline void* allocate_and_cache(const size_t bytes); - void free_cached(); - protected: static ska::flat_hash_map allocation_map_; // Since allocation_map, which is a global instance, is mutated/read via // all public APIs we need a global mutex. @@ -64,16 +96,6 @@ CPUCachingAllocator* GetDefaultCPUCachingAllocator(); bool ThreadLocalCachingAllocatorEnabled(); CPUCachingAllocator* GetThreadLocalCachingAllocator(); -/* - * Usage pattern: - * std::unique_ptr caching_allocator = - * std::make_unique(); - * { - * WithCPUCachingAllocatorGuard(caching_allocator.get()); - * ... - * } - */ - class C10_API WithCPUCachingAllocatorGuard { public: WithCPUCachingAllocatorGuard(CPUCachingAllocator* allocator); diff --git a/c10/mobile/CPUProfilingAllocator.cpp b/c10/mobile/CPUProfilingAllocator.cpp new file mode 100644 index 0000000000000..0118d0a295871 --- /dev/null +++ b/c10/mobile/CPUProfilingAllocator.cpp @@ -0,0 +1,447 @@ +#include + +#include + +namespace c10 { + +namespace { +thread_local AllocationPlanner* allocation_planner{nullptr}; +thread_local CPUProfilingAllocator* profiling_allocator{nullptr}; + +struct MemBlock { + uint64_t start_offset, end_offset; + MemBlock(uint64_t s, uint64_t e) : start_offset(s), end_offset(e) {} + bool operator<(const MemBlock& other) const { + return start_offset < other.start_offset; + } +}; + +enum class EventType { + Allocate = 0, + Free, + Invalid +}; + +struct MemEvent { + uint64_t time; + uint64_t allocation_id; + uint64_t size; + EventType type{EventType::Invalid}; + MemEvent(uint64_t t, uint64_t id, uint64_t s, EventType e) : + time(t), allocation_id(id), size(s), type(e) {} +}; + +bool overlaps(const MemBlock& a, const MemBlock& b) { + // two blocks dont overlap if + // |---a--------|--------------b--------| + // strat_a end_a <= start_b end_b + return + !((a.end_offset <= b.start_offset) || (b.end_offset <= a.start_offset)); +} + +bool validate_allocation_plan( + const std::vector& alloc_events, + const std::vector& allocation_offsets) { + std::set allocations; + for (const auto& event : alloc_events) { + auto alloc_id = event.allocation_id; + // Skip allocations not managed by AllocationPlan + if (allocation_offsets[alloc_id] == std::numeric_limits::max()) { + continue; + } + auto start_offset = allocation_offsets[alloc_id]; + auto end_offset = allocation_offsets[alloc_id] + event.size; + MemBlock mem_block(start_offset, end_offset); + if (event.type == EventType::Allocate) { + auto it = allocations.lower_bound(mem_block); + if (it != allocations.end()) { + auto next_block = *it; + if (overlaps(next_block, mem_block)) { + return false; + } + } + if (it != allocations.begin()) { + auto prev_block = *(--it); + if (overlaps(prev_block, mem_block)) { + return false; + } + } + allocations.emplace(mem_block); + } else if (event.type == EventType::Free) { + auto it = allocations.find(mem_block); + TORCH_CHECK((*it).end_offset == end_offset, + "Enf offset of allocation being freed must match the one recorded."); + TORCH_CHECK( + it != allocations.end(), + "ProfilingAllocator: Allocate event " + "must have preceded deallocate event."); + allocations.erase(it); + } else { + TORCH_CHECK(false, "ProfilingAllocator: Invalid event type."); + } + } + return true; +} + +std::vector create_and_sort_mem_events( + const std::vector& allocation_sizes, + const std::vector& allocation_lifetimes) { + std::vector events; + for (uint64_t i = 0; i < allocation_sizes.size(); ++i) { + // If observed allocation are freed outside the scope of + // observation, then allocations are not managed by the + // AllocationPlan. + if (allocation_lifetimes[i] == std::numeric_limits::max()) { + continue; + } + events.emplace_back(i, i, allocation_sizes[i], EventType::Allocate); + events.emplace_back(allocation_lifetimes[i], i, allocation_sizes[i], EventType::Free); + } + std::sort( + events.begin(), + events.end(), + [](const MemEvent& a, + const MemEvent& b) -> bool {return a.time < b.time;}); + return events; +} + +std::vector formulate_greedy_allocation_plan( + const std::vector& allocation_sizes, + const std::vector& allocation_lifetimes) { + // Step 1. Construct all allocation/free events. + // Sort these events by timestamp. + // Step 2. Iterate through all events. + // 2.1 If allocate event: + // Find all candidate in free_size_to_offset map + // Greedily pick the first one. + // Remove the entry from free_size_to_offset map. + // new_offset = offset + request_size + // new_size = size - request_size + // Add new entry to both maps + // 2.2 If free event. + // Check if the returned offset merges with another chunk. + // If so merge until no more merging is possible. + // If returned offset does not merge, then + // just return it as a chunk. + + // lower_bound on this map will get all candidates of + // the right size for allocation. + std::map free_size_to_offset; + // This provides fast lookup when we want to insert freed block + // back, especially when we want to merge blocks. + ska::flat_hash_map::iterator> free_start_offset_to_size_iter; + ska::flat_hash_map::iterator> free_end_offset_to_size_iter; + // Upon free end_ptr = offset + size + // If end_ptr exists merge freed allocation + // Also find corresponding offset in size_to_offset + // Remove that entry and update with new size and offset + // If end_ptr does not exist then just insert offset,size + // in map and correspondingly size, offset in the other map. + // Merging should always be done recursively until no more chunks + // that can be found. + // After last free we should have only one entry left in these maps. + + std::vector allocation_offsets( + allocation_sizes.size(), std::numeric_limits::max()); + auto mem_events = create_and_sort_mem_events(allocation_sizes, allocation_lifetimes); + uint64_t max_offset{0}; + for (const auto& mem_event : mem_events) { + uint64_t alloc_offset; + uint64_t new_offset, new_size; + if (mem_event.type == EventType::Allocate) { + auto it = free_size_to_offset.lower_bound(mem_event.size); + if (it == free_size_to_offset.end()) { + // If there is no contiguous block of the size requested + // allocate a new one. + alloc_offset = max_offset; + max_offset += mem_event.size; + } else { + // If we have found a block of the size we want + // 1. change the block by allocating out of it. + // 1.1 Erase the entire block + // 1.2 Erase the reverse map entries + // 2. If block still has space left insert the remainder back in map. + // Including reverse map entries. + alloc_offset = it->second; + new_offset = alloc_offset + mem_event.size; + new_size = it->first - mem_event.size; + free_size_to_offset.erase(it); + free_start_offset_to_size_iter.erase(alloc_offset); + free_end_offset_to_size_iter.erase(alloc_offset + it->first); + if (new_size > 0) { + auto ref_it = free_size_to_offset.emplace(new_size, new_offset).first; + free_start_offset_to_size_iter.emplace(new_offset, ref_it); + free_end_offset_to_size_iter.emplace(new_offset + new_size, ref_it); + } + } + allocation_offsets[mem_event.allocation_id] = alloc_offset; + } else { + // 1. Check if freed block is adjacent to an existing free block + // at its end boundary. This is done by checking + // free_end_offset_to_size_iter. + // If we find such a block, remove it and adjust size of + // the block being freed. + // 2. Similarly check if freed block is adjacent to an existing + // free block at start boundary. This is done by checking + // free_start_offset_to_size_iter. + // If we find such a block, remove it and adjust size of + // the block being freed. + // 3. Insert the freed block in map. + auto freed_offset = allocation_offsets[mem_event.allocation_id]; + auto freed_size = mem_event.size; + auto end_offset = freed_offset + freed_size; + // Merge when another free block exist at the end of this block + auto end_it = free_start_offset_to_size_iter.find(end_offset); + if (end_it != free_start_offset_to_size_iter.end()) { + auto merge_block_iter = end_it->second; + auto merge_block_size = merge_block_iter->first; + freed_size += merge_block_size; + free_size_to_offset.erase(merge_block_iter); + free_start_offset_to_size_iter.erase(end_it); + // If the block is being merged then also remove it from + // free_end_offset_to_size_iter + free_end_offset_to_size_iter.erase(end_offset + merge_block_size); + } + // Merge when freed block exist at the end of another free block + auto start_it = free_end_offset_to_size_iter.find(freed_offset); + if (start_it != free_end_offset_to_size_iter.end()) { + auto merge_block_iter = start_it->second; + auto merge_block_size = merge_block_iter->first; + freed_size += merge_block_size; + freed_offset -= merge_block_size; + free_size_to_offset.erase(merge_block_iter); + free_end_offset_to_size_iter.erase(start_it); + // If the block is being merged then also remove it from + // free_start_offset_to_size_iter + free_start_offset_to_size_iter.erase(freed_offset); + } + auto freed_block_it = + free_size_to_offset.emplace(freed_size, freed_offset).first; + free_start_offset_to_size_iter.emplace(freed_offset, freed_block_it); + free_end_offset_to_size_iter.emplace( + freed_offset + freed_size, freed_block_it); + } + } + TORCH_CHECK(validate_allocation_plan(mem_events, allocation_offsets), + "ProfilingAllocator: Allocation plan invalid."); + return allocation_offsets; +} + +} // namespace + +void AllocationPlan::clear() { + allocation_sizes.clear(); + allocation_lifetimes.clear(); + allocation_offsets.clear(); +} + +void AllocationPlanner::record_allocation( + const uint64_t size, const void* ptr) { + if (validation_mode_) { + validation_success = validation_success && validate_allocation(size, ptr); + return; + } + allocation_plan_->allocation_sizes.push_back(size); + allocation_plan_->allocation_lifetimes.push_back( + std::numeric_limits::max()); + allocation_ptr_to_id_[ptr] = allocation_id_; + allocation_id_++; +} + +void AllocationPlanner::record_free(const void* ptr) { + if (validation_mode_) { + validation_success = validation_success && validate_free(ptr); + return; + } + auto it = allocation_ptr_to_id_.find(ptr); + if (it == allocation_ptr_to_id_.end()) { + // Free being recorded was allocated outside of WithProfileAllocationGuard + return; + } + auto id = it->second; + TORCH_CHECK(id < allocation_plan_->allocation_lifetimes.size(), + "Allocation must have been recorded during record_allocation."); + allocation_plan_->allocation_lifetimes[id] = allocation_id_; +} + +bool AllocationPlanner::validate_allocation( + const uint64_t size, const void* ptr) { + if (allocation_id_ >= allocation_plan_->allocation_sizes.size() || + allocation_plan_->allocation_sizes[allocation_id_] != size) { + TORCH_WARN( + "Allocation request does not match plan:", + "Allocation id:", + allocation_id_, + ", Number of recorded allocations:", + allocation_plan_->allocation_sizes.size(), + ", Recorded size of the requested allocation:", + allocation_plan_->allocation_sizes[allocation_id_], + ", but got:", + size); + + return false; + } + allocation_ptr_to_id_[ptr] = allocation_id_; + allocation_id_++; + return true; +} + +bool AllocationPlanner::validate_free(const void* ptr) { + auto it = allocation_ptr_to_id_.find(ptr); + if (it == allocation_ptr_to_id_.end()) { + // Allocation that was made outside the validation scope is being freed here + return true; + } + auto id = (*it).second; + TORCH_CHECK(id < allocation_plan_->allocation_lifetimes.size(), + "Allocation must have been recorded during validate_allocation."); + auto lifetime_id = allocation_plan_->allocation_lifetimes[id]; + return (lifetime_id == allocation_id_); +} + +void AllocationPlanner::formulate_plan() { + allocation_plan_->allocation_offsets = + formulate_greedy_allocation_plan( + allocation_plan_->allocation_sizes, allocation_plan_->allocation_lifetimes); + allocation_plan_->total_size = 0; + for (auto i = 0; i < allocation_plan_->allocation_sizes.size(); ++i) { + if (allocation_plan_->allocation_lifetimes[i] == + std::numeric_limits::max()) { + continue; + } + auto limit = allocation_plan_->allocation_offsets[i] + allocation_plan_->allocation_sizes[i]; + allocation_plan_->total_size = std::max(allocation_plan_->total_size, limit); + } +} + +void AllocationPlanner::clear() { + allocation_plan_->clear(); + allocation_ptr_to_id_.clear(); +} + +void CPUProfilingAllocator::set_plan(const AllocationPlan* plan) { + TORCH_CHECK(plan != nullptr, "Allocation plan is nullptr."); + plan_ = plan; + allocation_id_ = 0; + allocation_ptr_to_id_.clear(); + if (current_size_ < plan->total_size) { + // Free existing memory and reallocate for larger size. + c10::free_cpu(blob_); + blob_ = c10::alloc_cpu(plan->total_size); + current_size_ = plan->total_size; + } +} + +void CPUProfilingAllocator::unset_plan() { + allocation_id_ = 0; + allocation_ptr_to_id_.clear(); + plan_ = nullptr; +} + +void* CPUProfilingAllocator::allocate(const size_t bytes) { + TORCH_CHECK(bytes == plan_->allocation_sizes[allocation_id_], + "Got allocation request that does not match with the plan."); + if (plan_->allocation_lifetimes[allocation_id_] == + std::numeric_limits::max()) { + // This allocation is not managed by ProfilingAllocator. + allocation_id_++; + return c10::alloc_cpu(bytes); + } + void* ptr = + reinterpret_cast(blob_) + + plan_->allocation_offsets[allocation_id_]; + allocation_ptr_to_id_[ptr] = allocation_id_; + allocation_id_++; + return ptr; +} + +void CPUProfilingAllocator::free(void* const ptr) { + auto it = allocation_ptr_to_id_.find(ptr); + if (it == allocation_ptr_to_id_.end()) { + // Either + // 1. Allocation that was made outside the validation scope is being freed here + // or + // 2. Allocation that is not managed by profiling allocator is being freed. + // Example of the second type + // Tensor out; + // for (....) { + // { + // CPUProfilingAllocator + // out = ...some op (This also frees previous memory held by out) + // } + // out is used.. + // } + c10::free_cpu(ptr); + return; + } + auto id = it->second; + TORCH_CHECK(id < plan_->allocation_lifetimes.size(), + "Freeing allocation that is not accordingly to the plan."); + auto lifetime_id = plan_->allocation_lifetimes[id]; + TORCH_CHECK( + lifetime_id == allocation_id_, + "Lifetime of allocations do not match: allocation_id ", + id, + ", expected:", + lifetime_id, + ", got:", + allocation_id_); +} + +CPUProfilingAllocator::~CPUProfilingAllocator() { + c10::free_cpu(blob_); +} + +WithProfileAllocationsGuard::WithProfileAllocationsGuard( + AllocationPlan* plan) { + // Nesting of allocation profiling does not seem meaningful. + TORCH_CHECK(allocation_planner == nullptr, + "Nesting profiling allocations is not supported."); + planner_ = std::make_unique(plan); + planner_->clear(); + allocation_planner = planner_.get(); +} + +WithProfileAllocationsGuard::~WithProfileAllocationsGuard() { + planner_->formulate_plan(); + allocation_planner = nullptr; +} + +WithValidateAllocationPlanGuard::WithValidateAllocationPlanGuard( + AllocationPlan* plan, bool* success) { + // Nesting of allocation profiling does not seem meaningful. + TORCH_CHECK(allocation_planner == nullptr, + "Nesting profiling allocations is not supported."); + planner_ = std::make_unique(plan, true); + success_ = success; + allocation_planner = planner_.get(); +} + +WithValidateAllocationPlanGuard::~WithValidateAllocationPlanGuard() { + *success_ = planner_->validation_success; + allocation_planner = nullptr; +} + +AllocationPlanner* GetThreadLocalAllocationPlanner() { + return allocation_planner; +} + +WithProfilingAllocatorGuard::WithProfilingAllocatorGuard( + CPUProfilingAllocator* allocator, const AllocationPlan* plan) { + // Nesting of profiling allocator is not supported. + TORCH_CHECK(profiling_allocator == nullptr, + "Nesting profiling allocators is not supported."); + profiling_allocator = allocator; + profiling_allocator->set_plan(plan); +} + +WithProfilingAllocatorGuard::~WithProfilingAllocatorGuard() { + profiling_allocator->unset_plan(); + profiling_allocator = nullptr; +} + +CPUProfilingAllocator* GetThreadLocalProfilingAllocator() { + return profiling_allocator; +} + +} // namespace c10 diff --git a/c10/mobile/CPUProfilingAllocator.h b/c10/mobile/CPUProfilingAllocator.h new file mode 100644 index 0000000000000..4a7e79fe28575 --- /dev/null +++ b/c10/mobile/CPUProfilingAllocator.h @@ -0,0 +1,149 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace c10 { + +/* + * Given a sequence of allocations in a thread, AllocationPlan records + * 1. size of each allocation + * 2. Lifetime of each allocation. + * 3. allocation offsets: Memory offset for each allocation in a single blob of memory + * 4. Total size of a blob of memory required to satisfy all the allocations. + */ +class C10_API AllocationPlan { + private: + // Records size of each allocation by their sequential allocation ids. + std::vector allocation_sizes; + // This maps one allocation id (X) to another allocation id (Y). + // Allocation X is alive until allocation Y. From allocation Y onwards + // allocation X is not referenced. + // Thus Y is the id of the first allocation after X is freed. + // NB: When an allocation is recorded, along with recording its size, + // we also set the lifetime to be numeric_limits::max() + // This is to track allocations that are made during the scope of + // profiling but were not freed until after the scope ended. + // Such allocations are not managed by profiling allocator. + std::vector allocation_lifetimes; + // Maps an allocation to some offset in a blob of memory. + std::vector allocation_offsets; + uint64_t total_size{0}; + void clear(); + friend class AllocationPlanner; + friend class CPUProfilingAllocator; +}; + +/* + * Map of memory ptr to allocation id. This is auxiliary information only + * used to establish lifetime of allocations. + */ +class C10_API AllocationPlanner { + private: + AllocationPlan* allocation_plan_{nullptr}; + // Maps allocated ptr to its allocation id. + // This is used when freeing the memory to lookup the allocation id + // in order to establish the lifetime of a particular allocation. + ska::flat_hash_map allocation_ptr_to_id_; + uint64_t allocation_id_{0}; + bool validation_mode_{false}; + + bool validate_allocation(const uint64_t size, const void* ptr); + bool validate_free(const void* ptr); + public: + bool validation_success{true}; + + AllocationPlanner() = delete; + AllocationPlanner(AllocationPlan* plan, bool validate = false) : + allocation_plan_(plan), validation_mode_(validate) {} + void record_allocation(const uint64_t size, const void* ptr); + void record_free(const void* ptr); + void formulate_plan(); + void clear(); +}; + +// NOT THREAD SAFE profiling allocator. +class C10_API CPUProfilingAllocator { + private: + const AllocationPlan* plan_{nullptr}; + uint64_t allocation_id_{0}; + uint64_t current_size_{0}; + void* blob_{nullptr}; + ska::flat_hash_map allocation_ptr_to_id_; + public: + ~CPUProfilingAllocator(); + void set_plan(const AllocationPlan* plan); + void unset_plan(); + void* allocate(const size_t bytes); + void free(void* const ptr); +}; + +/* + * Usage: Profile allocations made by one run of the model. + * AllocationPlan plan; + * { + * WithProfileAllocationGuard profile_guard(&plan); + * module.forward(...); + * } + * plan now contains allocation plan. + */ +class C10_API WithProfileAllocationsGuard { + public: + WithProfileAllocationsGuard(AllocationPlan* plan); + ~WithProfileAllocationsGuard(); + private: + std::unique_ptr planner_; +}; + +/* + * Usage: Validate allocation plan made with WithProfileAllocationGuard + * bool plan_validation_success, success = true; + * for (some number of representative inputs) + * { + * WithValidateAllocationPlanGuard(&plan, &plan_validation_success); + * module.forward(...); + * success = success && plan_validation_success; + * } + * success == true means allocations are according to plan + * else for some inputs allocation pattern changed. + */ +class C10_API WithValidateAllocationPlanGuard { + public: + WithValidateAllocationPlanGuard(AllocationPlan* plan, bool* success); + ~WithValidateAllocationPlanGuard(); + private: + std::unique_ptr planner_; + bool* success_; +}; + +AllocationPlanner* GetThreadLocalAllocationPlanner(); + +/* + * Usage: Allocate tensors accordingly to allocation plan + * First make allocation plan. + * See WithProfileAllocationsGuard usage. + * Second validate allocation plan. + * See WithValidateAllocationPlanGuard usage. + * CPUProfilingAllocator profiling_allocator; + * { + * WithProfilingAllocatorGuard allocator_guard(&profiling_allocator, &plan); + * module.forward(...); + * } + */ +class C10_API WithProfilingAllocatorGuard { + public: + WithProfilingAllocatorGuard( + CPUProfilingAllocator* allocator, const AllocationPlan* plan); + ~WithProfilingAllocatorGuard(); +}; + +CPUProfilingAllocator* GetThreadLocalProfilingAllocator(); + +} // namespace c10 diff --git a/c10/test/core/impl/SizesAndStrides_test.cpp b/c10/test/core/impl/SizesAndStrides_test.cpp new file mode 100644 index 0000000000000..94e90c42feffb --- /dev/null +++ b/c10/test/core/impl/SizesAndStrides_test.cpp @@ -0,0 +1,399 @@ +#include + +#include + +using namespace c10; +using namespace c10::impl; + +static void checkData(const SizesAndStrides& sz, IntArrayRef sizes, IntArrayRef strides) { + EXPECT_EQ(sizes.size(), strides.size()) << "bad test case: size() of sizes and strides don't match"; + EXPECT_EQ(sz.size(), sizes.size()); + + int idx = 0; + for (auto x: sizes) { + EXPECT_EQ(sz.size_at_unchecked(idx), x) << "index: " << idx; + EXPECT_EQ(sz.size_at(idx), x) << "index: " << idx; + EXPECT_EQ(sz.sizes_data()[idx], x) << "index: " << idx; + EXPECT_EQ(*(sz.sizes_begin() + idx), x) << "index: " << idx; + idx++; + } + EXPECT_EQ(sz.sizes_arrayref(), sizes); + + idx = 0; + for (auto x: strides) { + EXPECT_EQ(sz.stride_at_unchecked(idx), x) << "index: " << idx; + EXPECT_EQ(sz.stride_at(idx), x) << "index: " << idx; + EXPECT_EQ(sz.strides_data()[idx], x) << "index: " << idx; + EXPECT_EQ(*(sz.strides_begin() + idx), x) << "index: " << idx; + + idx++; + } + EXPECT_EQ(sz.strides_arrayref(), strides); +} + +TEST(SizesAndStridesTest, DefaultConstructor) { + SizesAndStrides sz; + checkData(sz, {0}, {1}); + // Can't test size_at() out of bounds because it just asserts for now. +} + +TEST(SizesAndStridesTest, SetSizes) { + SizesAndStrides sz; + sz.set_sizes({5, 6, 7, 8}); + checkData(sz, {5, 6, 7, 8}, {1, 0, 0, 0}); +} + +TEST(SizesAndStridesTest, Resize) { + SizesAndStrides sz; + + sz.resize(2); + + // Small to small growing. + checkData(sz, {0, 0}, {1, 0}); + + // Small to small growing, again. + sz.resize(5); + checkData(sz, {0, 0, 0, 0, 0}, {1, 0, 0, 0, 0}); + + for (int ii = 0; ii < sz.size(); ++ii) { + sz.size_at_unchecked(ii) = ii + 1; + sz.stride_at_unchecked(ii) = 2 * (ii + 1); + } + + checkData(sz, {1, 2, 3, 4, 5}, {2, 4, 6, 8, 10}); + + // Small to small, shrinking. + sz.resize(4); + checkData(sz, {1, 2, 3, 4}, {2, 4, 6, 8}); + + // Small to small with no size change. + sz.resize(4); + checkData(sz, {1, 2, 3, 4}, {2, 4, 6, 8}); + + // Small to small, growing back so that we can confirm that our "new" + // data really does get zeroed. + sz.resize(5); + checkData(sz, {1, 2, 3, 4, 0}, {2, 4, 6, 8, 0}); + + // Small to big. + sz.resize(6); + + checkData(sz, {1, 2, 3, 4, 0, 0}, {2, 4, 6, 8, 0, 0}); + + sz.size_at_unchecked(5) = 6; + sz.stride_at_unchecked(5) = 12; + + checkData(sz, {1, 2, 3, 4, 0, 6}, {2, 4, 6, 8, 0, 12}); + + // Big to big, growing. + sz.resize(7); + + checkData(sz, {1, 2, 3, 4, 0, 6, 0}, {2, 4, 6, 8, 0, 12, 0}); + + // Big to big with no size change. + sz.resize(7); + + checkData(sz, {1, 2, 3, 4, 0, 6, 0}, {2, 4, 6, 8, 0, 12, 0}); + + sz.size_at_unchecked(6) = 11; + sz.stride_at_unchecked(6) = 22; + + checkData(sz, {1, 2, 3, 4, 0, 6, 11}, {2, 4, 6, 8, 0, 12, 22}); + + // Big to big, shrinking. + sz.resize(6); + checkData(sz, {1, 2, 3, 4, 0, 6}, {2, 4, 6, 8, 0, 12}); + + // Grow back to make sure "new" elements get zeroed in big mode too. + sz.resize(7); + checkData(sz, {1, 2, 3, 4, 0, 6, 0}, {2, 4, 6, 8, 0, 12, 0}); + + // Finally, big to small. + + // Give it different data than it had when it was small to avoid + // getting it right by accident (i.e., because of leftover inline + // storage when going small to big). + for (int ii = 0; ii < sz.size(); ++ii) { + sz.size_at_unchecked(ii) = ii - 1; + sz.stride_at_unchecked(ii) = 2 * (ii - 1); + } + + checkData(sz, {-1, 0, 1, 2, 3, 4, 5}, {-2, 0, 2, 4, 6, 8, 10}); + + sz.resize(5); + checkData(sz, {-1, 0, 1, 2, 3}, {-2, 0, 2, 4, 6}); +} + +TEST(SizesAndStridesTest, SetAtIndex) { + SizesAndStrides sz; + + sz.resize(5); + sz.size_at(4) = 42; + sz.stride_at(4) = 23; + + checkData(sz, {0, 0, 0, 0, 42}, {1, 0, 0, 0, 23}); + + sz.resize(6); + sz.size_at(5) = 43; + sz.stride_at(5) = 24; + + checkData(sz, {0, 0, 0, 0, 42, 43}, {1, 0, 0, 0, 23, 24}); +} + +TEST(SizesAndStridesTest, SetAtIterator) { + SizesAndStrides sz; + + sz.resize(5); + *(sz.sizes_begin() + 4) = 42; + *(sz.strides_begin() + 4) = 23; + + checkData(sz, {0, 0, 0, 0, 42}, {1, 0, 0, 0, 23}); + + sz.resize(6); + *(sz.sizes_begin() + 5) = 43; + *(sz.strides_begin() + 5) = 24; + + checkData(sz, {0, 0, 0, 0, 42, 43}, {1, 0, 0, 0, 23, 24}); +} + +TEST(SizesAndStridesTest, SetViaData) { + SizesAndStrides sz; + + sz.resize(5); + *(sz.sizes_data() + 4) = 42; + *(sz.strides_data() + 4) = 23; + + checkData(sz, {0, 0, 0, 0, 42}, {1, 0, 0, 0, 23}); + + sz.resize(6); + *(sz.sizes_data() + 5) = 43; + *(sz.strides_data() + 5) = 24; + + checkData(sz, {0, 0, 0, 0, 42, 43}, {1, 0, 0, 0, 23, 24}); +} + +static SizesAndStrides makeSmall(int offset = 0) { + SizesAndStrides small; + small.resize(3); + for (int ii = 0; ii < small.size(); ++ii) { + small.size_at_unchecked(ii) = ii + 1 + offset; + small.stride_at_unchecked(ii) = 2 * (ii + 1 + offset); + } + + return small; +} + +static SizesAndStrides makeBig(int offset = 0) { + SizesAndStrides big; + big.resize(8); + for (int ii = 0; ii < big.size(); ++ii) { + big.size_at_unchecked(ii) = ii - 1 + offset; + big.stride_at_unchecked(ii) = 2 * (ii - 1 + offset); + } + + return big; +} + +static void checkSmall(const SizesAndStrides& sm, int offset = 0) { + std::vector sizes(3), strides(3); + for (int ii = 0; ii < 3; ++ii) { + sizes[ii] = ii + 1 + offset; + strides[ii] = 2 * (ii + 1 + offset); + } + checkData(sm, sizes, strides); +} + +static void checkBig(const SizesAndStrides& big, int offset = 0) { + std::vector sizes(8), strides(8); + for (int ii = 0; ii < 8; ++ii) { + sizes[ii] = ii - 1 + offset; + strides[ii] = 2 * (ii - 1 + offset); + } + checkData(big, sizes, strides); +} + +TEST(SizesAndStridesTest, MoveConstructor) { + SizesAndStrides empty; + + SizesAndStrides movedEmpty(std::move(empty)); + + EXPECT_EQ(empty.size(), 0); + EXPECT_EQ(movedEmpty.size(), 1); + checkData(movedEmpty, {0}, {1}); + + SizesAndStrides small = makeSmall(); + checkSmall(small); + + SizesAndStrides movedSmall(std::move(small)); + checkSmall(movedSmall); + EXPECT_EQ(small.size(), 0); + + SizesAndStrides big = makeBig(); + checkBig(big); + + SizesAndStrides movedBig(std::move(big)); + checkBig(movedBig); + EXPECT_EQ(big.size(), 0); +} + +TEST(SizesAndStridesTest, CopyConstructor) { + SizesAndStrides empty; + + SizesAndStrides copiedEmpty(empty); + + EXPECT_EQ(empty.size(), 1); + EXPECT_EQ(copiedEmpty.size(), 1); + checkData(empty, {0}, {1}); + checkData(copiedEmpty, {0}, {1}); + + SizesAndStrides small = makeSmall(); + checkSmall(small); + + SizesAndStrides copiedSmall(small); + checkSmall(copiedSmall); + checkSmall(small); + + SizesAndStrides big = makeBig(); + checkBig(big); + + SizesAndStrides copiedBig(big); + checkBig(big); + checkBig(copiedBig); +} + +TEST(SizesAndStridesTest, CopyAssignmentSmallToSmall) { + SizesAndStrides smallTarget = makeSmall(); + SizesAndStrides smallCopyFrom = makeSmall(1); + + checkSmall(smallTarget); + checkSmall(smallCopyFrom, 1); + + smallTarget = smallCopyFrom; + + checkSmall(smallTarget, 1); + checkSmall(smallCopyFrom, 1); +} + +TEST(SizesAndStridesTest, MoveAssignmentSmallToSmall) { + SizesAndStrides smallTarget = makeSmall(); + SizesAndStrides smallMoveFrom = makeSmall(1); + + checkSmall(smallTarget); + checkSmall(smallMoveFrom, 1); + + smallTarget = std::move(smallMoveFrom); + + checkSmall(smallTarget, 1); + EXPECT_EQ(smallMoveFrom.size(), 0); +} + +TEST(SizesAndStridesTest, CopyAssignmentSmallToBig) { + SizesAndStrides bigTarget = makeBig(); + SizesAndStrides smallCopyFrom = makeSmall(); + + checkBig(bigTarget); + checkSmall(smallCopyFrom); + + bigTarget = smallCopyFrom; + + checkSmall(bigTarget); + checkSmall(smallCopyFrom); +} + +TEST(SizesAndStridesTest, MoveAssignmentSmallToBig) { + SizesAndStrides bigTarget = makeBig(); + SizesAndStrides smallMoveFrom = makeSmall(); + + checkBig(bigTarget); + checkSmall(smallMoveFrom); + + bigTarget = std::move(smallMoveFrom); + + checkSmall(bigTarget); + EXPECT_EQ(smallMoveFrom.size(), 0); +} + +TEST(SizesAndStridesTest, CopyAssignmentBigToBig) { + SizesAndStrides bigTarget = makeBig(); + SizesAndStrides bigCopyFrom = makeBig(1); + + checkBig(bigTarget); + checkBig(bigCopyFrom, 1); + + bigTarget = bigCopyFrom; + + checkBig(bigTarget, 1); + checkBig(bigCopyFrom, 1); +} + +TEST(SizesAndStridesTest, MoveAssignmentBigToBig) { + SizesAndStrides bigTarget = makeBig(); + SizesAndStrides bigMoveFrom = makeBig(1); + + checkBig(bigTarget); + checkBig(bigMoveFrom, 1); + + bigTarget = std::move(bigMoveFrom); + + checkBig(bigTarget, 1); + EXPECT_EQ(bigMoveFrom.size(), 0); +} + +TEST(SizesAndStridesTest, CopyAssignmentBigToSmall) { + SizesAndStrides smallTarget = makeSmall(); + SizesAndStrides bigCopyFrom = makeBig(); + + checkSmall(smallTarget); + checkBig(bigCopyFrom); + + smallTarget = bigCopyFrom; + + checkBig(smallTarget); + checkBig(bigCopyFrom); +} + +TEST(SizesAndStridesTest, MoveAssignmentBigToSmall) { + SizesAndStrides smallTarget = makeSmall(); + SizesAndStrides bigMoveFrom = makeBig(); + + checkSmall(smallTarget); + checkBig(bigMoveFrom); + + smallTarget = std::move(bigMoveFrom); + + checkBig(smallTarget); + EXPECT_EQ(bigMoveFrom.size(), 0); +} + +TEST(SizesAndStridesTest, CopyAssignmentSelf) { + SizesAndStrides small = makeSmall(); + SizesAndStrides big = makeBig(); + + checkSmall(small); + checkBig(big); + + small = small; + checkSmall(small); + + big = big; + checkBig(big); +} + +// Avoid failures due to -Wall -Wself-move. +static void selfMove(SizesAndStrides& x, SizesAndStrides& y) { + x = std::move(y); +} + +TEST(SizesAndStridesTest, MoveAssignmentSelf) { + SizesAndStrides small = makeSmall(); + SizesAndStrides big = makeBig(); + + checkSmall(small); + checkBig(big); + + selfMove(small, small); + checkSmall(small); + + selfMove(big, big); + checkBig(big); +} diff --git a/c10/test/util/Metaprogramming_test.cpp b/c10/test/util/Metaprogramming_test.cpp index 0f55814bf6f5c..63613980079d1 100644 --- a/c10/test/util/Metaprogramming_test.cpp +++ b/c10/test/util/Metaprogramming_test.cpp @@ -243,14 +243,36 @@ namespace test_tuple_take { TEST(MetaprogrammingTest, TupleTake_nonemptyPrefix) { auto x = std::make_tuple(0, "HEY", 2.0); - auto y = tuple_take, 2>(x); + auto y = tuple_take(x); auto z = std::make_tuple(0, "HEY"); EXPECT_EQ(y, z); } TEST(MetaprogrammingTest, TupleTake_fullPrefix) { auto x = std::make_tuple(0, "HEY", 2.0); - auto y = tuple_take, 3>(x); + auto y = tuple_take(x); + EXPECT_EQ(x, y); + } + + TEST(MetaprogrammingTest, TupleTake_negative) { + auto x = std::make_tuple(0, "HEY", 2.0); + auto y = tuple_take(x); + auto z = std::make_tuple("HEY", 2.0); + EXPECT_EQ(y, z); + } +} + +namespace test_tuple_slice { + TEST(MetaprogrammingTest, TupleSlice_middle) { + auto x = std::make_tuple(0, "HEY", 2.0, false); + auto y = tuple_slice(x); + auto z = std::make_tuple("HEY", 2.0); + EXPECT_EQ(y, z); + } + + TEST(MetaprogrammingTest, TupleSlice_full) { + auto x = std::make_tuple(0, "HEY", 2.0); + auto y = tuple_slice(x); EXPECT_EQ(x, y); } } @@ -454,4 +476,22 @@ namespace test_tuple_concat { } } +namespace test_concat_iseq { + using std::index_sequence; + using std::integer_sequence; + static_assert(std::is_same, concat_iseq_t<>>::value, ""); + static_assert(std::is_same, concat_iseq_t>>::value, ""); + static_assert(std::is_same, concat_iseq_t, index_sequence<>>>::value, ""); + static_assert(std::is_same, concat_iseq_t>>::value, ""); + static_assert(std::is_same, concat_iseq_t, index_sequence<>>>::value, ""); + static_assert(std::is_same, concat_iseq_t, index_sequence<4>>>::value, ""); + static_assert(std::is_same, concat_iseq_t, index_sequence<4>, index_sequence<>>>::value, ""); + static_assert(std::is_same, concat_iseq_t, index_sequence<2>>>::value, ""); + static_assert(std::is_same, concat_iseq_t, index_sequence<4, 2>, index_sequence<>>>::value, ""); + static_assert(std::is_same, concat_iseq_t, index_sequence<4, 2>, index_sequence<9>>>::value, ""); + + static_assert(std::is_same, concat_iseq_t, integer_sequence>>::value, ""); +} + + } diff --git a/c10/test/util/bfloat16_test.cpp b/c10/test/util/bfloat16_test.cpp index d08f512053aba..af00bab99c5bb 100644 --- a/c10/test/util/bfloat16_test.cpp +++ b/c10/test/util/bfloat16_test.cpp @@ -87,7 +87,7 @@ namespace { } TEST(BFloat16Math, Addition) { - // This test verifies that if only first 7 bits of float's mantisa are + // This test verifies that if only first 7 bits of float's mantissa are // changed after addition, we should have no loss in precision. // input bits @@ -108,8 +108,8 @@ namespace { EXPECT_EQ(res, expected); } - TEST(BFloat16Math, Substraction) { - // This test verifies that if only first 7 bits of float's mantisa are + TEST(BFloat16Math, Subtraction) { + // This test verifies that if only first 7 bits of float's mantissa are // changed after subtraction, we should have no loss in precision. // input bits diff --git a/c10/test/util/intrusive_ptr_test.cpp b/c10/test/util/intrusive_ptr_test.cpp index 34b96d71288c4..9df5b004a094c 100644 --- a/c10/test/util/intrusive_ptr_test.cpp +++ b/c10/test/util/intrusive_ptr_test.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -693,21 +694,21 @@ TEST(IntrusivePtrTest, Equality_Nullptr) { EXPECT_FALSE(var1 != var2); } -TEST(IntrusivePtrTest, Nonequality) { +TEST(IntrusivePtrTest, Inequality) { intrusive_ptr var1 = make_intrusive(); intrusive_ptr var2 = make_intrusive(); EXPECT_TRUE(var1 != var2); EXPECT_FALSE(var1 == var2); } -TEST(IntrusivePtrTest, Nonequality_NullptrLeft) { +TEST(IntrusivePtrTest, Inequality_NullptrLeft) { intrusive_ptr var1; intrusive_ptr var2 = make_intrusive(); EXPECT_TRUE(var1 != var2); EXPECT_FALSE(var1 == var2); } -TEST(IntrusivePtrTest, Nonequality_NullptrRight) { +TEST(IntrusivePtrTest, Inequality_NullptrRight) { intrusive_ptr var1 = make_intrusive(); intrusive_ptr var2; EXPECT_TRUE(var1 != var2); @@ -1652,6 +1653,21 @@ TEST(WeakIntrusivePtrTest, givenPtr_whenLocking_thenReturnsCorrectObject) { EXPECT_EQ(var.ptr.get(), locked.get()); } +TEST(WeakIntrusivePtrTest, expiredPtr_whenLocking_thenReturnsNullType) { + IntrusiveAndWeak var = make_weak_intrusive(); + // reset the intrusive_ptr to test if weak pointer still valid + var.ptr.reset(); + EXPECT_TRUE(var.weak.expired()); + intrusive_ptr locked = var.weak.lock(); + EXPECT_FALSE(locked.defined()); +} + +TEST(WeakIntrusivePtrTest, weakNullPtr_locking) { + auto weak_ptr = make_invalid_weak(); + intrusive_ptr locked = weak_ptr.lock(); + EXPECT_FALSE(locked.defined()); +} + TEST( WeakIntrusivePtrTest, givenValidPtr_whenMoveAssigning_thenPointsToSameObject) { @@ -1671,6 +1687,15 @@ TEST( EXPECT_TRUE(obj1.weak.expired()); } +TEST( + WeakIntrusivePtrTest, + vector_insert_weak_intrusive) { + std::vector> priorWorks; + std::vector> wips; + wips.push_back(make_intrusive()); + priorWorks.insert(priorWorks.end(), wips.begin(), wips.end()); + EXPECT_EQ(priorWorks.size(), 1); +} TEST( WeakIntrusivePtrTest, givenInvalidPtr_whenMoveAssigning_thenNewInstanceIsValid) { @@ -2462,28 +2487,28 @@ TEST(WeakIntrusivePtrTest, Equality_Invalid) { EXPECT_FALSE(var1 != var2); } -TEST(WeakIntrusivePtrTest, Nonequality) { +TEST(WeakIntrusivePtrTest, Inequality) { IntrusiveAndWeak var1 = make_intrusive(); IntrusiveAndWeak var2 = make_intrusive(); EXPECT_TRUE(var1.weak != var2.weak); EXPECT_FALSE(var1.weak == var2.weak); } -TEST(WeakIntrusivePtrTest, Nonequality_InvalidLeft) { +TEST(WeakIntrusivePtrTest, Inequality_InvalidLeft) { weak_intrusive_ptr var1 = make_invalid_weak(); IntrusiveAndWeak var2 = make_intrusive(); EXPECT_TRUE(var1 != var2.weak); EXPECT_FALSE(var1 == var2.weak); } -TEST(WeakIntrusivePtrTest, Nonequality_InvalidRight) { +TEST(WeakIntrusivePtrTest, Inequality_InvalidRight) { IntrusiveAndWeak var1 = make_intrusive(); weak_intrusive_ptr var2 = make_invalid_weak(); EXPECT_TRUE(var1.weak != var2); EXPECT_FALSE(var1.weak == var2); } -TEST(WeakIntrusivePtrTest, Nonequality_WeakOnly) { +TEST(WeakIntrusivePtrTest, Inequality_WeakOnly) { weak_intrusive_ptr var1 = make_weak_only(); weak_intrusive_ptr var2 = make_weak_only(); EXPECT_TRUE(var1 != var2); diff --git a/c10/test/util/irange_test.cpp b/c10/test/util/irange_test.cpp new file mode 100644 index 0000000000000..d210a9c74bacc --- /dev/null +++ b/c10/test/util/irange_test.cpp @@ -0,0 +1,58 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#include + +#include + +using namespace ::testing; + +TEST(irange_test, range_test) { + std::vector test_vec; + for(const auto i : c10::irange(4, 11)){ + test_vec.push_back(i); + } + const std::vector correct = {{4,5,6,7,8,9,10}}; + ASSERT_EQ(test_vec, correct); +} + +TEST(irange_test, end_test) { + std::vector test_vec; + for(const auto i : c10::irange(5)){ + test_vec.push_back(i); + } + const std::vector correct = {{0, 1, 2, 3, 4}}; + ASSERT_EQ(test_vec, correct); +} + +TEST(irange_test, neg_range_test) { + std::vector test_vec; + for(const auto i : c10::irange(-2, 3)){ + test_vec.push_back(i); + } + const std::vector correct = {{-2,-1,0,1,2}}; + ASSERT_EQ(test_vec, correct); +} + +TEST(irange, empty_reverse_range_two_inputs){ + std::vector test_vec; + for(const auto i : c10::irange(3, -3)){ + test_vec.push_back(i); + if(i>20){ //Cap the number of elements we add if something goes wrong + break; + } + } + const std::vector correct = {}; + ASSERT_EQ(test_vec, correct); +} + +TEST(irange, empty_reverse_range_one_input){ + std::vector test_vec; + for(const auto i : c10::irange(-3)){ + test_vec.push_back(i); + if(i>20){ //Cap the number of elements we add if something goes wrong + break; + } + } + const std::vector correct = {}; + ASSERT_EQ(test_vec, correct); +} diff --git a/c10/test/util/optional_test.cpp b/c10/test/util/optional_test.cpp new file mode 100644 index 0000000000000..c5d1a30ef1dc4 --- /dev/null +++ b/c10/test/util/optional_test.cpp @@ -0,0 +1,80 @@ +#include + +#include + +#include +#include +#include + +namespace { + +template +class OptionalTest : public ::testing::Test { + public: + using optional = c10::optional; +}; + +template +T getSampleValue(); + +template<> +bool getSampleValue() { + return true; +} + +template<> +uint64_t getSampleValue() { + return 42; +} + +template<> +std::string getSampleValue() { + return "hello"; +} + + +using OptionalTypes = ::testing::Types< + // 32-bit scalar optimization. + bool, + // Trivially destructible but not 32-bit scalar. + uint64_t, + // Non-trivial destructor. + std::string + >; + + +TYPED_TEST_CASE(OptionalTest, OptionalTypes); + +TYPED_TEST(OptionalTest, Empty) { + typename TestFixture::optional empty; + + EXPECT_FALSE((bool)empty); + EXPECT_FALSE(empty.has_value()); + + EXPECT_THROW(empty.value(), c10::bad_optional_access); +} + +TYPED_TEST(OptionalTest, Initialized) { + using optional = typename TestFixture::optional; + + const auto val = getSampleValue(); + optional opt((val)); + auto copy(opt), moveFrom1(opt), moveFrom2(opt); + optional move(std::move(moveFrom1)); + optional copyAssign; + copyAssign = opt; + optional moveAssign; + moveAssign = std::move(moveFrom2); + + std::array opts = {&opt, ©, ©Assign, &move, &moveAssign}; + for (auto* popt : opts) { + auto& opt = *popt; + EXPECT_TRUE((bool)opt); + EXPECT_TRUE(opt.has_value()); + + EXPECT_EQ(opt.value(), val); + EXPECT_EQ(*opt, val); + } +} + +} // namespace diff --git a/c10/util/ArrayRef.h b/c10/util/ArrayRef.h index 513b83891cf13..ee40c572187cb 100644 --- a/c10/util/ArrayRef.h +++ b/c10/util/ArrayRef.h @@ -161,8 +161,7 @@ class ArrayRef final { return Length == RHS.Length && std::equal(begin(), end(), RHS.begin()); } - /// slice(n, m) - Chop off the first N elements of the array, and keep M - /// elements in the array. + /// slice(n, m) - Take M elements of the array starting at element N C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef slice(size_t N, size_t M) const { TORCH_CHECK( N + M <= size(), diff --git a/c10/util/BFloat16-inl.h b/c10/util/BFloat16-inl.h index da6ce38595527..878a7e307e559 100644 --- a/c10/util/BFloat16-inl.h +++ b/c10/util/BFloat16-inl.h @@ -7,15 +7,44 @@ namespace c10 { /// Constructors inline C10_HOST_DEVICE BFloat16::BFloat16(float value) { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + x = __bfloat16_as_ushort(__float2bfloat16(value)); +#else // RNE by default x = detail::round_to_nearest_even(value); +#endif } /// Implicit conversions inline C10_HOST_DEVICE BFloat16::operator float() const { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + return __bfloat162float(*reinterpret_cast(&x)); +#else return detail::f32_from_bits(x); +#endif } +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 +inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) { + x = *reinterpret_cast(&value); +} +inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const { + return *reinterpret_cast(&x); +} +#endif + +// CUDA intrinsics + +#if defined(__CUDACC__) || defined(__HIPCC__) +inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __ldg(reinterpret_cast(ptr)); +#else + return *ptr; +#endif +} +#endif + /// Arithmetic inline C10_HOST_DEVICE BFloat16 operator+(const BFloat16& a, const BFloat16& b) { @@ -30,7 +59,7 @@ inline C10_HOST_DEVICE BFloat16 operator*(const BFloat16& a, const BFloat16& b) return static_cast(a) * static_cast(b); } -inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b) { +inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b) __ubsan_ignore_float_divide_by_zero__ { return static_cast(a) / static_cast(b); } diff --git a/c10/util/BFloat16.h b/c10/util/BFloat16.h index 375b1086e0739..0bd115d568f6c 100644 --- a/c10/util/BFloat16.h +++ b/c10/util/BFloat16.h @@ -7,6 +7,10 @@ #include #include +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 +#include +#endif + namespace c10 { namespace detail { @@ -84,6 +88,11 @@ struct alignas(2) BFloat16 { constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) : x(bits){}; inline C10_HOST_DEVICE BFloat16(float value); inline C10_HOST_DEVICE operator float() const; + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value); + explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const; +#endif }; } // namespace c10 diff --git a/c10/util/Backtrace.cpp b/c10/util/Backtrace.cpp index 09e41b99109ea..17e4923fe4e54 100644 --- a/c10/util/Backtrace.cpp +++ b/c10/util/Backtrace.cpp @@ -9,12 +9,8 @@ #include #ifdef _MSC_VER -#ifndef WIN32_LEAN_AND_MEAN -#define WIN32_LEAN_AND_MEAN -#endif +#include #include -#include -#include #pragma comment(lib, "Dbghelp.lib") #endif diff --git a/c10/util/Bitset.h b/c10/util/Bitset.h index e849563e60fe7..964146be05e7b 100644 --- a/c10/util/Bitset.h +++ b/c10/util/Bitset.h @@ -64,7 +64,7 @@ struct bitset final { bitset cur = *this; size_t index = cur.find_first_set(); while (0 != index) { - // -1 because find_first_set() is not one-indiced. + // -1 because find_first_set() is not one-indexed. index -= 1; func(index); cur.unset(index); @@ -73,7 +73,7 @@ struct bitset final { } private: - // Return the index of the first set bit. The returned index is one-indiced + // Return the index of the first set bit. The returned index is one-indexed // (i.e. if the very first bit is set, this function returns '1'), and a return // of '0' means that there was no bit set. size_t find_first_set() const { diff --git a/c10/util/C++17.h b/c10/util/C++17.h index 9329ab3b854cf..c7d9a8c1df1e6 100644 --- a/c10/util/C++17.h +++ b/c10/util/C++17.h @@ -24,6 +24,10 @@ #error You need C++14 to compile PyTorch #endif +#if defined(_WIN32) && (defined(min) || defined(max)) +# error Macro clash with min and max -- define NOMINMAX when compiling your program on Windows +#endif + /* * This header adds some polyfills with C++17 functionality */ @@ -258,6 +262,11 @@ struct _if_constexpr final { * Note: In Example 3, both branches return int, so func() returns int. This is not necessary. * If func() had a return type of "auto", then both branches could return different * types, say func() could return int and func() could return string. + * + * Note: if_constexpr is *eager* w.r.t. template expansion - meaning this + * polyfill does not behave like a true "if statement at compilation time". + * The `_` trick above only defers typechecking, which happens after templates + * have been expanded. (Of course this is all that's necessary for many use cases). */ template decltype(auto) if_constexpr(ThenCallback&& thenCallback, ElseCallback&& elseCallback) { diff --git a/c10/util/Exception.cpp b/c10/util/Exception.cpp index 3f2c34ffae6c1..1c63633263432 100644 --- a/c10/util/Exception.cpp +++ b/c10/util/Exception.cpp @@ -76,6 +76,14 @@ void Error::add_context(std::string new_msg) { refresh_what(); } +namespace detail { + +void torchCheckFail(const char *func, const char *file, uint32_t line, const std::string& msg) { + throw ::c10::Error({func, file, line}, msg); +} + +} // namespace detail + namespace Warning { namespace { diff --git a/c10/util/Exception.h b/c10/util/Exception.h index 4b55c562130ca..ebd1e872251ee 100644 --- a/c10/util/Exception.h +++ b/c10/util/Exception.h @@ -181,6 +181,12 @@ class C10_API EnforceFiniteError : public Error { using Error::Error; }; +// Used in Onnxifi backend lowering. These turn into +// ExitException when they cross to Python. +class C10_API OnnxfiBackendSystemError : public Error { + using Error::Error; +}; + // A utility function to return an exception std::string by prepending its // exception type before its what() content C10_API std::string GetExceptionString(const std::exception& e); @@ -188,7 +194,7 @@ C10_API std::string GetExceptionString(const std::exception& e); namespace detail { // Return x if it is non-empty; otherwise return y. -inline std::string if_empty_then(std::string x, std::string y) { +inline std::string if_empty_then(const std::string& x, const std::string& y) { if (x.empty()) { return y; } else { @@ -318,29 +324,47 @@ inline std::string if_empty_then(std::string x, std::string y) { TORCH_CHECK_WITH_MSG(error_t, cond, "", __VA_ARGS__) #ifdef STRIP_ERROR_MESSAGES -#define TORCH_CHECK_WITH_MSG(error_t, cond, type, ...) \ - if (C10_UNLIKELY_OR_CONST(!(cond))) { \ - C10_THROW_ERROR(Error, \ - #cond #type " CHECK FAILED at " \ - C10_STRINGIZE(__FILE__) \ - ); \ +#define TORCH_CHECK_MSG(cond, type, ...) \ + (#cond #type " CHECK FAILED at " \ + C10_STRINGIZE(__FILE__)) +#define TORCH_CHECK_WITH_MSG(error_t, cond, type, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + C10_THROW_ERROR(Error, \ + TORCH_CHECK_MSG(cond, type, __VA_ARGS__) \ + ); \ } #else +#define TORCH_CHECK_MSG(cond, type, ...) \ + ::c10::detail::if_empty_then( \ + ::c10::str(__VA_ARGS__), \ + "Expected " #cond " to be true, but got false. " \ + "(Could this error message be improved? If so, " \ + "please report an enhancement request to PyTorch.)" \ + ) #define TORCH_CHECK_WITH_MSG(error_t, cond, type, ...) \ if (C10_UNLIKELY_OR_CONST(!(cond))) { \ C10_THROW_ERROR(error_t, \ - ::c10::detail::if_empty_then( \ - ::c10::str(__VA_ARGS__), \ - "Expected " #cond " to be true, but got false. " \ - "(Could this error message be improved? If so, " \ - "please report an enhancement request to PyTorch.)" \ - ) \ + TORCH_CHECK_MSG(cond, type, __VA_ARGS__) \ ); \ } #endif -#define TORCH_CHECK(cond, ...) TORCH_CHECK_WITH(Error, cond, __VA_ARGS__) -// An utility macro that does what `TORCH_CHECK` does if compiled in the host code, +namespace c10 { +namespace detail { + +[[noreturn]] C10_API void torchCheckFail(const char *func, const char *file, uint32_t line, const std::string& msg); + +} // namespace detail +} // namespace 10 + +#define TORCH_CHECK(cond, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + ::c10::detail::torchCheckFail( \ + __func__, __FILE__, static_cast(__LINE__), \ + TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \ + } + +// An utility macro that does what `TORCH_CHECK` does if compiled in the host code, // otherwise does nothing. Supposed to be used in the code shared between host and // device code as an alternative for `TORCH_CHECK`. #if defined(__CUDACC__) || defined(__HIPCC__) @@ -381,18 +405,30 @@ inline std::string if_empty_then(std::string x, std::string y) { // Report a warning to the user. Accepts an arbitrary number of extra // arguments which are concatenated into the warning message using operator<< // +#ifdef STRIP_ERROR_MESSAGES +#define TORCH_WARN(...) \ + ::c10::Warning::warn({__func__, __FILE__, static_cast(__LINE__)}, {}, false) +#else #define TORCH_WARN(...) \ ::c10::Warning::warn({__func__, __FILE__, static_cast(__LINE__)}, ::c10::str(__VA_ARGS__), false) +#endif // Report a warning to the user only once. Accepts an arbitrary number of extra // arguments which are concatenated into the warning message using operator<< // +#ifdef STRIP_ERROR_MESSAGES +#define TORCH_WARN_ONCE(...) \ + C10_UNUSED static const auto C10_ANONYMOUS_VARIABLE(torch_warn_once_) = [&] { \ + ::c10::Warning::warn({__func__, __FILE__, static_cast(__LINE__)}, {}, false); \ + return true; \ + }() +#else #define TORCH_WARN_ONCE(...) \ C10_UNUSED static const auto C10_ANONYMOUS_VARIABLE(torch_warn_once_) = [&] { \ ::c10::Warning::warn({__func__, __FILE__, static_cast(__LINE__)}, ::c10::str(__VA_ARGS__), false); \ return true; \ }() - +#endif // ---------------------------------------------------------------------------- // Deprecated macros diff --git a/c10/util/Flags.h b/c10/util/Flags.h index 6bfe62507fcdb..b4352510c9973 100644 --- a/c10/util/Flags.h +++ b/c10/util/Flags.h @@ -4,7 +4,7 @@ /* Commandline flags support for C10. * * This is a portable commandline flags tool for c10, so we can optionally - * choose to use gflags or a lightweighted custom implementation if gflags is + * choose to use gflags or a lightweight custom implementation if gflags is * not possible on a certain platform. If you have gflags installed, set the * macro C10_USE_GFLAGS will seamlessly route everything to gflags. * diff --git a/c10/util/FunctionRef.h b/c10/util/FunctionRef.h index a3730476b734a..b3b9930cbbb50 100644 --- a/c10/util/FunctionRef.h +++ b/c10/util/FunctionRef.h @@ -18,6 +18,10 @@ #pragma once +#include +#include +#include + namespace c10 { /// An efficient, type-erasing, non-owning reference to a callable. This is diff --git a/c10/util/Half-inl.h b/c10/util/Half-inl.h index 0a1421d3c9ad8..6fa3309d0a5ca 100644 --- a/c10/util/Half-inl.h +++ b/c10/util/Half-inl.h @@ -66,7 +66,7 @@ inline C10_HOST_DEVICE Half operator*(const Half& a, const Half& b) { return static_cast(a) * static_cast(b); } -inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b) { +inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b) __ubsan_ignore_float_divide_by_zero__ { return static_cast(a) / static_cast(b); } diff --git a/c10/util/Half.h b/c10/util/Half.h index 8f8dd3467367a..01562acea7045 100644 --- a/c10/util/Half.h +++ b/c10/util/Half.h @@ -328,7 +328,9 @@ namespace detail { const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); const uint32_t nonsign = exp_bits + mantissa_bits; - return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); + return static_cast( + (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign) + ); } } // namespace detail @@ -372,10 +374,11 @@ struct alignas(4) complex { Half imag() const { return imag_; } - inline complex(c10::complex value) - : real_(value.real()), imag_(value.imag()) {} - inline complex(c10::complex value) + explicit inline complex(c10::complex value) : real_(value.real()), imag_(value.imag()) {} + explicit inline complex(c10::complex value) + : real_(static_cast(value.real())), + imag_(static_cast(value.imag())) {} inline operator c10::complex() const { return {real_, imag_}; } diff --git a/c10/util/Logging.h b/c10/util/Logging.h index acab3cfecd23a..6fa7e93f26d81 100644 --- a/c10/util/Logging.h +++ b/c10/util/Logging.h @@ -284,7 +284,7 @@ BINARY_COMP_HELPER(LessEquals, <=) * Very lightweight logging for the first time API usage. It's beneficial for * tracking of individual functionality usage in larger applications. * - * In order to ensure light-weightness of logging, we utilize static variable + * In order to ensure light-weightedness of logging, we utilize static variable * trick - LogAPIUsage will be invoked only once and further invocations will * just do an atomic check. * diff --git a/c10/util/Metaprogramming.h b/c10/util/Metaprogramming.h index ee52520973777..a56b43afa852e 100644 --- a/c10/util/Metaprogramming.h +++ b/c10/util/Metaprogramming.h @@ -130,6 +130,30 @@ decltype(auto) filter_map(const Mapper& mapper, Args&&... args) { } +/** + * make_offset_index_sequence + * Like make_index_sequence, but starting from Start instead of 0. + * + * Example: + * make_offset_index_sequence<10, 3> == std::index_sequence<10, 11, 12> + */ +template +struct make_offset_index_sequence_impl + : make_offset_index_sequence_impl +{ + static_assert(static_cast(Start) >= 0, "make_offset_index_sequence: Start < 0"); + static_assert(static_cast(N) >= 0, "make_offset_index_sequence: N < 0"); +}; + +template +struct make_offset_index_sequence_impl { + typedef std::index_sequence type; +}; + +template +using make_offset_index_sequence = typename make_offset_index_sequence_impl::type; + + /** * Use tuple_elements to extract a position-indexed subset of elements * from the argument tuple into a result tuple. @@ -138,22 +162,58 @@ decltype(auto) filter_map(const Mapper& mapper, Args&&... args) { * std::tuple t = std::make_tuple(0, "HEY", 2.0); * std::tuple result = tuple_elements(t, std::index_sequence<0, 2>()); */ -template -constexpr auto tuple_elements(Tuple t, std::index_sequence) { - return std::tuple...>(std::get(t)...); +template +constexpr auto tuple_elements(Tuple t, std::index_sequence) { + return std::tuple...>(std::get(t)...); } /** - * Use tuple_take to extract the first n elements from the argument tuple - * into a result tuple. + * Use tuple_take to extract the first or last n elements from the argument + * tuple into a result tuple. * * Example: * std::tuple t = std::make_tuple(0, "HEY", 2.0); - * std::tuple result = tuple_take(t); + * std::tuple first_two = tuple_take(t); + * std::tuple last_two = tuple_take(t); + */ +template +struct TupleTake {}; + +template +struct TupleTake= 0, void>> { + static auto call(Tuple t) { + constexpr size_t size = std::tuple_size(); + static_assert(N <= size, "tuple_take: N > size"); + return tuple_elements(t, std::make_index_sequence{}); + } +}; + +template +struct TupleTake> { + static auto call(Tuple t) { + constexpr size_t size = std::tuple_size(); + static_assert(-N <= size, "tuple_take: -N > size"); + return tuple_elements(t, make_offset_index_sequence{}); + } +}; + +template +auto tuple_take(Tuple t) { + return TupleTake::call(t); +} + +/** + * Use tuple_slice to extract a contiguous subtuple from the argument. + * + * Example: + * std::tuple t = std::make_tuple(0, "HEY", 2.0, false); + * std::tuple middle_two = tuple_slice(t); */ -template -constexpr auto tuple_take(Tuple t) { - return tuple_elements(t, std::make_index_sequence{}); +template +constexpr auto tuple_slice(Tuple t) { + constexpr size_t size = std::tuple_size(); + static_assert(Start + N <= size, "tuple_slice: Start + N > size"); + return tuple_elements(t, make_offset_index_sequence{}); } @@ -249,4 +309,29 @@ template } +/** + * Concatenate multiple integer sequences + * Example: + * concat_iseq_t, std::index_sequence<4, 2>, std::index_sequence<5>> + * == std::index_sequence<2, 5, 3, 4, 2, 5> + */ +template struct concat_iseq { + static_assert(false_t::value, "In concat_iseq, the T arguments each must be std::integer_sequence<...> with the same IntType."); +}; +template<> +struct concat_iseq<> { + using type = std::index_sequence<>; +}; +template +struct concat_iseq> { + using type = std::integer_sequence; +}; +template +struct concat_iseq, std::integer_sequence, TailISeqs...> { + using type = typename concat_iseq, TailISeqs...>::type; +}; +template +using concat_iseq_t = typename concat_iseq::type; + + }} diff --git a/c10/util/Optional.cpp b/c10/util/Optional.cpp index 7389393e662fd..dd91391719b45 100644 --- a/c10/util/Optional.cpp +++ b/c10/util/Optional.cpp @@ -1 +1,10 @@ #include + +#include + +// CUDA 9.2 and below fail while trying to compile default move constructor +// see https://github.com/pytorch/csprng/issues/84 +#if (!defined(__CUDA_ARCH__) || !defined(CUDA_VERSION) || CUDA_VERSION > 9200) +static_assert(C10_IS_TRIVIALLY_COPYABLE(c10::optional), "c10::optional should be trivially copyable"); +static_assert(C10_IS_TRIVIALLY_COPYABLE(c10::optional), "c10::optional should be trivially copyable"); +#endif \ No newline at end of file diff --git a/c10/util/Optional.h b/c10/util/Optional.h index 915fd9047da4e..440adb6a1654d 100644 --- a/c10/util/Optional.h +++ b/c10/util/Optional.h @@ -10,21 +10,22 @@ // From https://github.com/akrzemi1/Optional // // C10 -// - Move to `c10` namespace. -// - Remove macro use in line 478 because the nvcc device compiler cannot handle +// - Move file to `c10` namespace. +// - Remove macro use in line 478 because the nvcc device compiler cannot handle it // it. -// - revise constructor logic so that it is consistent with c++ 17 standard documented -// here in (8): https://en.cppreference.com/w/cpp/utility/optional/optional, and -// could be able to support initialization of optionals from convertible type U, also -// remove two old constructors optional(const T&) and optional(T&&) as it could be -// handled by the template case with default template argument. -// - `constexpr struct in_place_t {} in_place{}` is moved to `c10/util/in_place.h`, +// - Revise constructor logic so that it is 1) consistent with c++ 17 standard documented +// here in (8): https://en.cppreference.com/w/cpp/utility/optional/optional, and 2) +// able to support initialization of optionals from convertible type U. +// - Remove the constructors for `optional(const T&)` and `optional(T&&)`, as they can be +// handled by the template case with the default template argument. +// - Move `constexpr struct in_place_t {} in_place{}` to `c10/util/in_place.h` // so that it can also be used in `c10/util/variant.h`. -// - Remove special cases for pre-c++14 compilers to make code simpler +// - Remove special cases for pre-c++14 compilers to make code simpler. #ifndef C10_UTIL_OPTIONAL_H_ #define C10_UTIL_OPTIONAL_H_ +#include #include #include @@ -35,6 +36,8 @@ #include #include +#include + #define TR2_OPTIONAL_REQUIRES(...) \ typename std::enable_if<__VA_ARGS__::value, bool>::type = false @@ -137,10 +140,9 @@ constexpr struct trivial_init_t { // 20.5.7, Disengaged state indicator struct nullopt_t { - struct init {}; - constexpr explicit nullopt_t(init) {} + constexpr explicit nullopt_t(int) {} }; -constexpr nullopt_t nullopt{nullopt_t::init()}; +constexpr nullopt_t nullopt {0}; // 20.5.8, class bad_optional_access class bad_optional_access : public std::logic_error { @@ -185,8 +187,22 @@ struct optional_base { constexpr optional_base() noexcept : init_(false), storage_(trivial_init){}; + explicit constexpr optional_base(const optional_base& v) : init_(v.init_), storage_(trivial_init) { + if (init_) { + ::new (dataptr()) T(v.storage_.value_); + } + } + explicit constexpr optional_base(const T& v) : init_(true), storage_(v) {} + explicit constexpr optional_base(optional_base&& v) noexcept( + std::is_nothrow_move_constructible::value) + : init_(v.init_), storage_(trivial_init) { + if (init_) { + ::new (dataptr()) T(std::move(v.storage_.value_)); + } + } + explicit constexpr optional_base(T&& v) : init_(true), storage_(constexpr_move(v)) {} @@ -204,10 +220,52 @@ struct optional_base { Args&&... args) : init_(true), storage_(il, std::forward(args)...) {} + optional_base& operator=(const optional_base& rhs) { + if (init_ && !rhs.init_) { + clear(); + } else if (!init_ && rhs.init_) { + init_ = true; + ::new (dataptr()) T(rhs.storage_.value_); + } else if (init_ && rhs.init_) { + storage_.value_ = rhs.storage_.value_; + } + return *this; + } + + optional_base& operator=(optional_base&& rhs) noexcept( + std::is_nothrow_move_assignable::value && + std::is_nothrow_move_constructible::value) { + if (init_ && !rhs.init_) { + clear(); + } else if (!init_ && rhs.init_) { + init_ = true; + ::new (dataptr()) T(std::move(rhs.storage_.value_)); + } else if (init_ && rhs.init_) { + storage_.value_ = std::move(rhs.storage_.value_); + } + return *this; + } + ~optional_base() { if (init_) storage_.value_.T::~T(); } + + private: + typename std::remove_const::type* dataptr() { + return std::addressof(storage_.value_); + } + + constexpr const T* dataptr() const { + return detail_::static_addressof(storage_.value_); + } + + void clear() noexcept { + if (init_) { + dataptr()->~T(); + } + init_ = false; + } }; template @@ -218,6 +276,20 @@ struct constexpr_optional_base { constexpr constexpr_optional_base() noexcept : init_(false), storage_(trivial_init){}; + explicit constexpr constexpr_optional_base(const constexpr_optional_base& v) : init_(v.init_), storage_(trivial_init) { + if (init_) { + ::new (dataptr()) T(v.storage_.value_); + } + } + + explicit constexpr constexpr_optional_base(constexpr_optional_base&& v) noexcept( + std::is_nothrow_move_constructible::value) + : init_(v.init_), storage_(trivial_init) { + if (init_) { + ::new (dataptr()) T(std::move(v.storage_.value_)); + } + } + explicit constexpr constexpr_optional_base(const T& v) : init_(true), storage_(v) {} @@ -239,23 +311,134 @@ struct constexpr_optional_base { : init_(true), storage_(il, std::forward(args)...) {} ~constexpr_optional_base() = default; + + constexpr_optional_base& operator=(const constexpr_optional_base& rhs) { + if (init_ && !rhs.init_) { + clear(); + } else if (!init_ && rhs.init_) { + init_ = true; + ::new (dataptr()) T(rhs.storage_.value_); + } else if (init_ && rhs.init_) { + storage_.value_ = rhs.storage_.value_; + } + return *this; + } + + constexpr_optional_base& operator=(constexpr_optional_base&& rhs) noexcept( + std::is_nothrow_move_assignable::value && + std::is_nothrow_move_constructible::value) { + if (init_ && !rhs.init_) { + clear(); + } else if (!init_ && rhs.init_) { + init_ = true; + ::new (dataptr()) T(std::move(rhs.storage_.value_)); + } else if (init_ && rhs.init_) { + storage_.value_ = std::move(rhs.storage_.value_); + } + return *this; + } + + private: + typename std::remove_const::type* dataptr() { + return std::addressof(storage_.value_); + } + + constexpr const T* dataptr() const { + return detail_::static_addressof(storage_.value_); + } + + void clear() noexcept { + init_ = false; + } +}; + +// HACK: Optimization for trivially copyable types. The mainline +// implementation fails to have trivial copy/move operations in these +// cases, and we care about them, so just implement that directly. +template +struct trivially_copyable_optimization_optional_base { + bool init_; + constexpr_storage_t storage_; + + constexpr trivially_copyable_optimization_optional_base() noexcept + : init_(false), storage_(trivial_init) {} + + explicit constexpr trivially_copyable_optimization_optional_base(const T& v) + : init_(true), storage_(v) {} + + explicit constexpr trivially_copyable_optimization_optional_base(T&& v) + : init_(true), storage_(constexpr_move(v)) {} + + template + explicit constexpr trivially_copyable_optimization_optional_base(in_place_t, Args&&... args) + : init_(true), storage_(constexpr_forward(args)...) {} + + template < + class U, + class... Args, + TR2_OPTIONAL_REQUIRES(std::is_constructible>)> + constexpr explicit trivially_copyable_optimization_optional_base( + in_place_t, + std::initializer_list il, + Args&&... args) + : init_(true), storage_(il, std::forward(args)...) {} + + ~trivially_copyable_optimization_optional_base() = default; }; +// CUDA 9.2 and below fail while trying to compile default move constructor +// see https://github.com/pytorch/csprng/issues/84 +#if (!defined(__CUDA_ARCH__) || !defined(CUDA_VERSION) || CUDA_VERSION > 9200) template using OptionalBase = typename std::conditional< - std::is_trivially_destructible::value, // if possible - constexpr_optional_base::type>, // use base with trivial destructor - optional_base::type>>::type; + std::is_trivially_destructible::value && + C10_IS_TRIVIALLY_COPYABLE(T) && + // Avoid using is_trivially_copy_{constructible,assignable} + // because old GCC versions don't support them. Also, + // is_trivially_copyable seems not to do what I expect, so check + // trivially_copyable_optimization_optional_base directly. + std::is_copy_constructible>::value && + std::is_copy_assignable>::value, + trivially_copyable_optimization_optional_base, + typename std::conditional< + std::is_trivially_destructible::value, // if possible + constexpr_optional_base::type>, // use base with trivial destructor + optional_base::type>>::type>::type; +#else +template +using OptionalBase = typename std::conditional< + std::is_trivially_destructible::value, // if possible + constexpr_optional_base::type>, // use base with trivial destructor + optional_base::type>>::type; +#endif template class optional : private OptionalBase { +// CUDA 9.2 and below fail while trying to compile default move constructor +// see https://github.com/pytorch/csprng/issues/84 +#if (!defined(__CUDA_ARCH__) || !defined(CUDA_VERSION) || CUDA_VERSION > 9200) template // re-declaration for nvcc on Windows. using OptionalBase = typename std::conditional< - std::is_trivially_destructible::value, // if possible - constexpr_optional_base::type>, // use base with trivial destructor - optional_base::type>>::type; + std::is_trivially_destructible::value && + C10_IS_TRIVIALLY_COPYABLE(U) && + std::is_copy_constructible>::value && + std::is_copy_assignable>::value, + trivially_copyable_optimization_optional_base, + typename std::conditional< + std::is_trivially_destructible::value, // if possible + constexpr_optional_base::type>, // use base with trivial destructor + optional_base::type>>::type>::type; +#else + template + using OptionalBase = typename std::conditional< + std::is_trivially_destructible::value, // if possible + constexpr_optional_base::type>, // use base with trivial destructor + optional_base::type>>::type; +#endif static_assert( !std::is_same::type, nullopt_t>::value, @@ -313,21 +496,21 @@ class optional : private OptionalBase { constexpr optional() noexcept : OptionalBase(){}; constexpr optional(nullopt_t) noexcept : OptionalBase(){}; - optional(const optional& rhs) : OptionalBase() { - if (rhs.initialized()) { - ::new (static_cast(dataptr())) T(*rhs); - OptionalBase::init_ = true; - } - } + optional(const optional& rhs) = default; +// CUDA 9.2 and below fail while trying to compile default move constructor +// see https://github.com/pytorch/csprng/issues/84 +#if (!defined(__CUDA_ARCH__) || !defined(CUDA_VERSION) || CUDA_VERSION > 9200) + optional(optional&& rhs) = default; +#else optional(optional&& rhs) noexcept( - std::is_nothrow_move_constructible::value) - : OptionalBase() { + std::is_nothrow_move_constructible::value) { if (rhs.initialized()) { ::new (static_cast(dataptr())) T(std::move(*rhs)); OptionalBase::init_ = true; } } +#endif // see https://github.com/akrzemi1/Optional/issues/16 // and https://en.cppreference.com/w/cpp/utility/optional/optional, @@ -381,27 +564,9 @@ class optional : private OptionalBase { return *this; } - optional& operator=(const optional& rhs) { - if (initialized() == true && rhs.initialized() == false) - clear(); - else if (initialized() == false && rhs.initialized() == true) - initialize(*rhs); - else if (initialized() == true && rhs.initialized() == true) - contained_val() = *rhs; - return *this; - } + optional& operator=(const optional& rhs) = default; - optional& operator=(optional&& rhs) noexcept( - std::is_nothrow_move_assignable::value&& - std::is_nothrow_move_constructible::value) { - if (initialized() == true && rhs.initialized() == false) - clear(); - else if (initialized() == false && rhs.initialized() == true) - initialize(std::move(*rhs)); - else if (initialized() == true && rhs.initialized() == true) - contained_val() = std::move(*rhs); - return *this; - } + optional& operator=(optional&& rhs) = default; template auto operator=(U&& v) -> typename std::enable_if< @@ -514,6 +679,22 @@ class optional : private OptionalBase { } }; +template +constexpr T value_or_else(const optional& v, F&& func) { + static_assert(std::is_convertible::return_type, T>::value, + "func parameters must be a callable that returns a type convertible to the value stored in the optional"); + return v.has_value() ? *v : detail_::convert(std::forward(func)()); +} + +template +constexpr T value_or_else(optional&& v, F&& func) { + static_assert(std::is_convertible::return_type, T>::value, + "func parameters must be a callable that returns a type convertible to the value stored in the optional"); + return v.has_value() + ? constexpr_move(std::move(v).contained_val()) + : detail_::convert(std::forward(func)()); +} + // XXX: please refrain from using optional, since it is being against with // the optional standard in c++ 17, see the debate and the details here: diff --git a/c10/util/SmallVector.h b/c10/util/SmallVector.h index 076a1d4010651..9b32d8edfe7f1 100644 --- a/c10/util/SmallVector.h +++ b/c10/util/SmallVector.h @@ -832,7 +832,7 @@ SmallVectorImpl& SmallVectorImpl::operator=( // If we have to grow to have enough elements, destroy the current elements. // This allows us to avoid copying them during the grow. - // FIXME: don't do this if they're efficiently moveable. + // FIXME: don't do this if they're efficiently movable. if (this->capacity() < RHSSize) { // Destroy current elements. this->destroy_range(this->begin(), this->end()); diff --git a/c10/util/StringUtil.cpp b/c10/util/StringUtil.cpp index 673a6f68fc053..244ca0582bd5d 100644 --- a/c10/util/StringUtil.cpp +++ b/c10/util/StringUtil.cpp @@ -18,6 +18,14 @@ std::string StripBasename(const std::string& full_path) { } } +std::string ExcludeFileExtension(const std::string& file_name) { + const char sep = '.'; + auto end_index = file_name.find_last_of(sep) == std::string::npos + ? -1 + : file_name.find_last_of(sep); + return file_name.substr(0, end_index); +} + } // namespace detail std::ostream& operator<<(std::ostream& out, const SourceLocation& loc) { diff --git a/c10/util/StringUtil.h b/c10/util/StringUtil.h index d2744f1fbdc5a..12343ce5add87 100644 --- a/c10/util/StringUtil.h +++ b/c10/util/StringUtil.h @@ -17,6 +17,8 @@ namespace detail { // Obtains the base name from a full path. C10_API std::string StripBasename(const std::string& full_path); +C10_API std::string ExcludeFileExtension(const std::string& full_path); + template struct CanonicalizeStrTypes { using type = const T&; diff --git a/c10/util/ThreadLocalDebugInfo.cpp b/c10/util/ThreadLocalDebugInfo.cpp index 20d473667a8d2..473f4a273f359 100644 --- a/c10/util/ThreadLocalDebugInfo.cpp +++ b/c10/util/ThreadLocalDebugInfo.cpp @@ -7,14 +7,13 @@ thread_local std::shared_ptr debug_info = nullptr; } /* static */ -std::shared_ptr ThreadLocalDebugInfo::get( - DebugInfoKind kind) { - auto cur = debug_info; +DebugInfoBase* ThreadLocalDebugInfo::get(DebugInfoKind kind) { + ThreadLocalDebugInfo* cur = debug_info.get(); while (cur) { if (cur->kind_ == kind) { - return cur->info_; + return cur->info_.get(); } - cur = cur->parent_info_; + cur = cur->parent_info_.get(); } return nullptr; } diff --git a/c10/util/ThreadLocalDebugInfo.h b/c10/util/ThreadLocalDebugInfo.h index 9620cfb9fdea0..a1d167d0652d0 100644 --- a/c10/util/ThreadLocalDebugInfo.h +++ b/c10/util/ThreadLocalDebugInfo.h @@ -13,6 +13,7 @@ enum class C10_API_ENUM DebugInfoKind : uint8_t { PRODUCER_INFO = 0, MOBILE_RUNTIME_INFO, PROFILER_STATE, + INFERENCE_CONTEXT, // for inference usage TEST_INFO, // used only in tests TEST_INFO_2, // used only in tests @@ -32,7 +33,7 @@ class C10_API DebugInfoBase { // profiling, etc) class C10_API ThreadLocalDebugInfo { public: - static std::shared_ptr get(DebugInfoKind kind); + static DebugInfoBase* get(DebugInfoKind kind); // Get current ThreadLocalDebugInfo static std::shared_ptr current(); diff --git a/c10/util/TypeCast.h b/c10/util/TypeCast.h index a39e211b873a3..85513ecc5e2f9 100644 --- a/c10/util/TypeCast.h +++ b/c10/util/TypeCast.h @@ -44,7 +44,7 @@ struct static_cast_with_inter_type { // Note: Converting from negative float values to unsigned integer types is // undefined behavior in C++, and current CPU and GPU compilers exhibit // divergent behavior. Casting from negative float values to signed -// integer types and then to unsigned integer types is not undefiend, +// integer types and then to unsigned integer types is not undefined, // however, so this cast improves the consistency of type conversions // to uint8 across compilers. // Further note: Type conversions across compilers still have other undefined @@ -170,3 +170,5 @@ To checked_convert(From f, const char* name) { } } // namespace c10 + +// Trigger tests for D25440771. TODO: Remove this line any time you want. diff --git a/c10/util/Unicode.cpp b/c10/util/Unicode.cpp new file mode 100644 index 0000000000000..e6d41bc731867 --- /dev/null +++ b/c10/util/Unicode.cpp @@ -0,0 +1,49 @@ +#include + +namespace c10 { +#if defined(_WIN32) +std::wstring u8u16(const std::string& str) { + if (str.empty()) { + return std::wstring(); + } + int size_needed = MultiByteToWideChar( + CP_UTF8, 0, str.c_str(), static_cast(str.size()), NULL, 0); + TORCH_CHECK(size_needed > 0, "Error converting the content to Unicode"); + std::wstring wstr(size_needed, 0); + MultiByteToWideChar( + CP_UTF8, + 0, + str.c_str(), + static_cast(str.size()), + &wstr[0], + size_needed); + return wstr; +} +std::string u16u8(const std::wstring& wstr) { + if (wstr.empty()) { + return std::string(); + } + int size_needed = WideCharToMultiByte( + CP_UTF8, + 0, + wstr.c_str(), + static_cast(wstr.size()), + NULL, + 0, + NULL, + NULL); + TORCH_CHECK(size_needed > 0, "Error converting the content to UTF8"); + std::string str(size_needed, 0); + WideCharToMultiByte( + CP_UTF8, + 0, + wstr.c_str(), + static_cast(wstr.size()), + &str[0], + size_needed, + NULL, + NULL); + return str; +} +#endif +} // namespace c10 \ No newline at end of file diff --git a/c10/util/Unicode.h b/c10/util/Unicode.h new file mode 100644 index 0000000000000..ccdea5a62c1a0 --- /dev/null +++ b/c10/util/Unicode.h @@ -0,0 +1,14 @@ +#pragma once + +#if defined(_WIN32) +#include +#include +#include +#endif + +namespace c10 { +#if defined(_WIN32) +C10_API std::wstring u8u16(const std::string& str); +C10_API std::string u16u8(const std::wstring& wstr); +#endif +} diff --git a/c10/util/UniqueVoidPtr.h b/c10/util/UniqueVoidPtr.h index cf5187153b133..c4e3158ae10e9 100644 --- a/c10/util/UniqueVoidPtr.h +++ b/c10/util/UniqueVoidPtr.h @@ -10,7 +10,7 @@ using DeleterFnPtr = void (*)(void*); namespace detail { // Does not delete anything -CAFFE2_API void deleteNothing(void*); +TORCH_API void deleteNothing(void*); // A detail::UniqueVoidPtr is an owning smart pointer like unique_ptr, but // with three major differences: diff --git a/c10/util/complex.h b/c10/util/complex.h index 9c63a2b296fb3..d4d5525170af1 100644 --- a/c10/util/complex.h +++ b/c10/util/complex.h @@ -61,7 +61,7 @@ namespace c10 { // Since we only support float and double, on will use `complex& operator=(T x)` // - Copy assignment operator and converting assignment operator // - There is no specialization of converting assignment operators, which type is -// convertible is soly depend on whether the scalar type is convertable +// convertible is solely dependent on whether the scalar type is convertible // // In addition to the standard assignment, we also provide assignment operators with std and thrust // @@ -262,7 +262,7 @@ struct alignas(sizeof(T) * 2) complex { return real() || imag(); } - constexpr T real() const { + C10_HOST_DEVICE constexpr T real() const { return real_; } constexpr void real(T value) { diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h index 237c3a7dc0ad8..790d97ee39941 100644 --- a/c10/util/intrusive_ptr.h +++ b/c10/util/intrusive_ptr.h @@ -5,6 +5,11 @@ #include #include +namespace pybind11 { +template +class class_; +} + namespace c10 { class intrusive_ptr_target; namespace raw { @@ -14,13 +19,17 @@ namespace raw { namespace intrusive_ptr { inline void incref(intrusive_ptr_target * self); } + + // constructor tag used by intrusive_ptr constructors + struct DontIncreaseRefcount {}; } /** * intrusive_ptr is an alternative to shared_ptr that has better * performance because it does the refcounting intrusively * (i.e. in a member of the object itself). * Your class T needs to inherit from intrusive_ptr_target to allow it to be - * used in an intrusive_ptr. + * used in an intrusive_ptr. Your class's constructor should not allow + *`this` to escape to other threads or create an intrusive_ptr from `this`. */ // Note [Stack allocated intrusive_ptr_target safety] @@ -97,7 +106,8 @@ class C10_API intrusive_ptr_target { refcount_.load() == 0, "Tried to destruct an intrusive_ptr_target that still has intrusive_ptr to it"); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - weakcount_.load() == 0, + // See ~intrusive_ptr for optimization that will frequently result in 1 at destruction time. + weakcount_.load() == 1 || weakcount_.load() == 0, "Tried to destruct an intrusive_ptr_target that still has weak_intrusive_ptr to it"); #if defined(_MSC_VER) && !defined(__clang__) # pragma warning(pop) @@ -149,6 +159,29 @@ TTarget* assign_ptr_(TTarget* rhs) { return rhs; } } + +// Increment needs to be acquire-release to make use_count() and +// unique() reliable. +inline size_t atomic_refcount_increment(std::atomic& refcount) { + return refcount.fetch_add(1, std::memory_order_acq_rel) + 1; +} + +// weak_use_count() is only used for testing, so we don't need it to +// be reliable. Relaxed should be fine. +inline size_t atomic_weakcount_increment(std::atomic& weakcount) { + return weakcount.fetch_add(1, std::memory_order_relaxed) + 1; +} + +// Both decrements need to be acquire-release for correctness. See +// e.g. std::shared_ptr implementation. +inline size_t atomic_refcount_decrement(std::atomic& refcount) { + return refcount.fetch_sub(1, std::memory_order_acq_rel) - 1; +} + +inline size_t atomic_weakcount_decrement(std::atomic& weakcount) { + return weakcount.fetch_sub(1, std::memory_order_acq_rel) - 1; +} + } // namespace detail template @@ -173,7 +206,7 @@ class intrusive_ptr final { "NullType must have a constexpr singleton() method"); #endif static_assert( - std::is_same::value, + std::is_base_of::type>::value, "NullType::singleton() must return a element_type* pointer"); TTarget* target_; @@ -182,9 +215,19 @@ class intrusive_ptr final { friend class intrusive_ptr; friend class weak_intrusive_ptr; + // Make pybind11::class_ be a friend class of intrusive_ptr, so that custom + // smart holder in pybind11 could access the private constructor of + // intrusive_ptr(T*) which took the ownership of the object. This is required + // by customer holder macro PYBIND11_DECLARE_HOLDER_TYPE, where it uses + // intrusive_ptr(TTarget*) to initialize and take ownership of the object. For + // details, see + // https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers + template + friend class pybind11::class_; + void retain_() { if (target_ != NullType::singleton()) { - size_t new_refcount = ++target_->refcount_; + size_t new_refcount = detail::atomic_refcount_increment(target_->refcount_); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( new_refcount != 1, "intrusive_ptr: Cannot increase refcount after it reached zero."); @@ -192,7 +235,7 @@ class intrusive_ptr final { } void reset_() noexcept { - if (target_ != NullType::singleton() && --target_->refcount_ == 0) { + if (target_ != NullType::singleton() && detail::atomic_refcount_decrement(target_->refcount_) == 0) { // justification for const_cast: release_resources is basically a destructor // and a destructor always mutates the object, even for const objects. const_cast*>(target_)->release_resources(); @@ -200,23 +243,45 @@ class intrusive_ptr final { // See comment above about weakcount. As long as refcount>0, // weakcount is one larger than the actual number of weak references. // So we need to decrement it here. - if (--target_->weakcount_ == 0) { + if (target_->weakcount_.load(std::memory_order_acquire) == 1 || + detail::atomic_weakcount_decrement(target_->weakcount_) == 0) { delete target_; } } target_ = NullType::singleton(); } + // raw pointer constructors are not public because we shouldn't make + // intrusive_ptr out of raw pointers except from inside the make_intrusive(), + // reclaim() and weak_intrusive_ptr::lock() implementations. + // This constructor will not increase the ref counter for you. - // This is not public because we shouldn't make intrusive_ptr out of raw - // pointers except from inside the make_intrusive() and - // weak_intrusive_ptr::lock() implementations - explicit intrusive_ptr(TTarget* target) noexcept : target_(target) {} + // We use the tagged dispatch mechanism to explicitly mark this constructor + // to not increase the refcount + explicit intrusive_ptr(TTarget* target, raw::DontIncreaseRefcount) noexcept + : target_(target) {} + + // This constructor will increase the ref counter for you. + // This constructor will be used by the make_intrusive(), and also pybind11, + // which wrap the intrusive_ptr holder around the raw pointer and incref + // correspondingly (pybind11 requires raw pointer constructor to incref by + // default). + explicit intrusive_ptr(TTarget* target) + : intrusive_ptr(target, raw::DontIncreaseRefcount{}) { + if (target_ != NullType::singleton()) { + // We can't use retain_(), because we also have to increase weakcount + // and because we allow raising these values from 0, which retain_() + // has an assertion against. + detail::atomic_refcount_increment(target_->refcount_); + detail::atomic_weakcount_increment(target_->weakcount_); + } + } public: using element_type = TTarget; - intrusive_ptr() noexcept : intrusive_ptr(NullType::singleton()) {} + intrusive_ptr() noexcept + : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {} intrusive_ptr(intrusive_ptr&& rhs) noexcept : target_(rhs.target_) { rhs.target_ = NullType::singleton(); @@ -313,14 +378,14 @@ class intrusive_ptr final { if (target_ == NullType::singleton()) { return 0; } - return target_->refcount_.load(); + return target_->refcount_.load(std::memory_order_acquire); } size_t weak_use_count() const noexcept { if (target_ == NullType::singleton()) { return 0; } - return target_->weakcount_.load(); + return target_->weakcount_.load(std::memory_order_acquire); } bool unique() const noexcept { @@ -347,17 +412,30 @@ class intrusive_ptr final { * passed in *must* have been created using intrusive_ptr::release(). */ static intrusive_ptr reclaim(TTarget* owning_ptr) { - return intrusive_ptr(owning_ptr); + return intrusive_ptr(owning_ptr, raw::DontIncreaseRefcount{}); } + /** + * Allocate a heap object with args and wrap it inside a intrusive_ptr and + * incref. This is a helper function to let make_intrusive() access private + * intrusive_ptr constructors. + */ template static intrusive_ptr make(Args&&... args) { - auto result = intrusive_ptr(new TTarget(std::forward(args)...)); - // We can't use retain_(), because we also have to increase weakcount - // and because we allow raising these values from 0, which retain_() - // has an assertion against. - ++result.target_->refcount_; - ++result.target_->weakcount_; + auto result = intrusive_ptr(new TTarget(std::forward(args)...), raw::DontIncreaseRefcount{}); + + // We just created result.target_, so we know no other thread has + // access to it, so we know we needn't care about memory ordering. + // (On x86_64, a store with memory_order_relaxed generates a plain old + // `mov`, whereas an atomic increment does a lock-prefixed `add`, which is + // much more expensive: https://godbolt.org/z/eKPzj8.) + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + result.target_->refcount_ == 0 && result.target_->weakcount_ == 0, + "intrusive_ptr: Newly-created target had non-zero refcounts. Does its " + "constructor do something strange like incref or create an intrusive_ptr" + "from `this`?"); + result.target_->refcount_.store(1, std::memory_order_relaxed); + result.target_->weakcount_.store(1, std::memory_order_relaxed); return result; } @@ -431,7 +509,7 @@ class weak_intrusive_ptr final { "NullType must have a constexpr singleton() method"); #endif static_assert( - std::is_same::value, + std::is_base_of::type>::value, "NullType::singleton() must return a element_type* pointer"); TTarget* target_; @@ -441,7 +519,7 @@ class weak_intrusive_ptr final { void retain_() { if (target_ != NullType::singleton()) { - size_t new_weakcount = ++target_->weakcount_; + size_t new_weakcount = detail::atomic_weakcount_increment(target_->weakcount_); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( new_weakcount != 1, "weak_intrusive_ptr: Cannot increase weakcount after it reached zero."); @@ -449,7 +527,7 @@ class weak_intrusive_ptr final { } void reset_() noexcept { - if (target_ != NullType::singleton() && --target_->weakcount_ == 0) { + if (target_ != NullType::singleton() && detail::atomic_weakcount_decrement(target_->weakcount_) == 0) { delete target_; } target_ = NullType::singleton(); @@ -518,6 +596,12 @@ class weak_intrusive_ptr final { return operator=(rhs); } + weak_intrusive_ptr& operator=(const intrusive_ptr& rhs) & noexcept { + weak_intrusive_ptr tmp(rhs); + swap(tmp); + return *this; + } + template weak_intrusive_ptr& operator=( const weak_intrusive_ptr& rhs) & { @@ -568,14 +652,14 @@ class weak_intrusive_ptr final { if (target_ == NullType::singleton()) { return 0; } - return target_->refcount_.load(); // refcount, not weakcount! + return target_->refcount_.load(std::memory_order_acquire); // refcount, not weakcount! } size_t weak_use_count() const noexcept { if (target_ == NullType::singleton()) { return 0; } - return target_->weakcount_.load(); + return target_->weakcount_.load(std::memory_order_acquire); } bool expired() const noexcept { @@ -583,15 +667,20 @@ class weak_intrusive_ptr final { } intrusive_ptr lock() const noexcept { - auto refcount = target_->refcount_.load(); - do { - if (refcount == 0) { - // Object already destructed, no strong references left anymore. - // Return nullptr. - return intrusive_ptr(NullType::singleton()); - } - } while (!target_->refcount_.compare_exchange_weak(refcount, refcount + 1)); - return intrusive_ptr(target_); + if (expired()) { + return intrusive_ptr(); + } else { + auto refcount = target_->refcount_.load(std::memory_order_seq_cst); + do { + if (refcount == 0) { + // Object already destructed, no strong references left anymore. + // Return nullptr. + return intrusive_ptr(); + } + } while (!target_->refcount_.compare_exchange_weak(refcount, refcount + 1)); + return intrusive_ptr( + target_, raw::DontIncreaseRefcount{}); + } } /** @@ -611,7 +700,7 @@ class weak_intrusive_ptr final { /** * Takes an owning (but must be weakly referenced) pointer to TTarget* and * creates a weak_intrusive_ptr that takes over ownership. - * Thas means the weakcount is not increased. + * This means that the weakcount is not increased. * This is the counter-part to weak_intrusive_ptr::release() and the pointer * passed in *must* have been created using weak_intrusive_ptr::release(). */ @@ -692,7 +781,7 @@ namespace intrusive_ptr { // NullType::singleton to this function inline void incref(intrusive_ptr_target* self) { if (self) { - ++self->refcount_; + detail::atomic_refcount_increment(self->refcount_); } } @@ -726,7 +815,7 @@ namespace intrusive_ptr { namespace weak_intrusive_ptr { inline void incref(weak_intrusive_ptr_target* self) { - ++self->weakcount_; + detail::atomic_weakcount_increment(self->weakcount_); } inline void decref(weak_intrusive_ptr_target* self) { diff --git a/c10/util/irange.h b/c10/util/irange.h new file mode 100644 index 0000000000000..e9895076624ab --- /dev/null +++ b/c10/util/irange.h @@ -0,0 +1,77 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#pragma once + +#include + +#include +#include + +namespace c10 { + +namespace detail { + +template {}, int> = 0> +struct integer_iterator : std::iterator { + explicit integer_iterator(I value) : value(value) {} + + I operator*() const { return value; } + + I const* operator->() const { return &value; } + + integer_iterator& operator++() { + ++value; + return *this; + } + + integer_iterator operator++(int) { + const auto copy = *this; + ++*this; + return copy; + } + + bool operator==(const integer_iterator& other) const { + return value == other.value; + } + + bool operator!=(const integer_iterator& other) const { + return value != other.value; + } + + protected: + I value; +}; + +} // namespace detail + +template {}, int> = 0> +struct integer_range { + public: + integer_range(I begin, I end) : begin_(begin), end_(end) {} + detail::integer_iterator begin() const { return begin_; } + detail::integer_iterator end() const { return end_; } + + private: + detail::integer_iterator begin_; + detail::integer_iterator end_; +}; + +/// Creates an integer range for the half-open interval [begin, end) +/// If end<=begin, then the range is empty +template ::value, bool> = true> +integer_range irange(Integer begin, Integer end) { + //If end<=begin then the range is empty; we can achieve this effect by + //choosing the larger of {begin, end} as the loop terminator + return {begin, std::max(begin, end)}; +} + +/// Creates an integer range for the half-open interval [0, end) +/// If end<=begin, then the range is empty +template ::value, bool> = true> +integer_range irange(Integer end) { + //If end<=begin then the range is empty; we can achieve this effect by + //choosing the larger of {0, end} as the loop terminator + return {Integer(), std::max(Integer(), end)}; +} + +} // namespace torch diff --git a/c10/util/llvmMathExtras.h b/c10/util/llvmMathExtras.h index 8def126c29aa9..2c4fbf8a501bf 100644 --- a/c10/util/llvmMathExtras.h +++ b/c10/util/llvmMathExtras.h @@ -14,9 +14,10 @@ #define LLVM_SUPPORT_MATHEXTRAS_H #include - #include #include #include + #include + #include #include #include #include @@ -547,26 +548,26 @@ /// (32 bit edition.) /// Ex. Log2_32(32) == 5, Log2_32(1) == 0, Log2_32(0) == -1, Log2_32(6) == 2 inline unsigned Log2_32(uint32_t Value) { - return 31 - countLeadingZeros(Value); + return static_cast(31 - countLeadingZeros(Value)); } /// Return the floor log base 2 of the specified value, -1 if the value is zero. /// (64 bit edition.) inline unsigned Log2_64(uint64_t Value) { - return 63 - countLeadingZeros(Value); + return static_cast(63 - countLeadingZeros(Value)); } /// Return the ceil log base 2 of the specified value, 32 if the value is zero. /// (32 bit edition). /// Ex. Log2_32_Ceil(32) == 5, Log2_32_Ceil(1) == 0, Log2_32_Ceil(6) == 3 inline unsigned Log2_32_Ceil(uint32_t Value) { - return 32 - countLeadingZeros(Value - 1); + return static_cast(32 - countLeadingZeros(Value - 1)); } /// Return the ceil log base 2 of the specified value, 64 if the value is zero. /// (64 bit edition.) inline unsigned Log2_64_Ceil(uint64_t Value) { - return 64 - countLeadingZeros(Value - 1); + return static_cast(64 - countLeadingZeros(Value - 1)); } /// Return the greatest common divisor of the values using Euclid's algorithm. @@ -589,6 +590,7 @@ /// This function takes a 32-bit integer and returns the bit equivalent float. inline float BitsToFloat(uint32_t Bits) { + //TODO: Use bit_cast once C++20 becomes available. float F; static_assert(sizeof(uint32_t) == sizeof(float), "Unexpected type sizes"); memcpy(&F, &Bits, sizeof(Bits)); diff --git a/c10/util/math_compat.h b/c10/util/math_compat.h index 7d1a7b6438503..786e09c4c895b 100644 --- a/c10/util/math_compat.h +++ b/c10/util/math_compat.h @@ -59,6 +59,20 @@ namespace std { throw std::runtime_error("std::hypot is not implemented on older Android"); } + // TODO: this function needs to be implemented and tested. Currently just throw an error. + inline float igamma(float x, float y) { + throw std::runtime_error("igamma is not implemented on older Android"); + } + inline double igamma(double x, double y) { + throw std::runtime_error("igamma is not implemented on older Android"); + } + inline float igammac(float x, float y) { + throw std::runtime_error("igammac is not implemented on older Android"); + } + inline double igammac(double x, double y) { + throw std::runtime_error("igammac is not implemented on older Android"); + } + // TODO: this function needs to be implemented and tested. Currently just throw an error. inline float nextafter(float x, float y) { throw std::runtime_error("std::nextafter is not implemented on older Android"); @@ -66,7 +80,7 @@ namespace std { inline double nextafter(double x, double y) { throw std::runtime_error("std::nextafter is not implemented on older Android"); } - + // TODO: this function needs to be implemented and tested. Currently just throw an error. inline float exp2(float x) { throw std::runtime_error("std::exp2 is not implemented on older Android"); diff --git a/c10/util/quint4x2.h b/c10/util/quint4x2.h new file mode 100644 index 0000000000000..c2502b561409e --- /dev/null +++ b/c10/util/quint4x2.h @@ -0,0 +1,18 @@ +#pragma once +#include + +#include + +namespace c10 { + +/** + * quint4x2 is for un-signed 4 bit quantized Tensors that are packed to byte boundary. + */ +struct alignas(1) quint4x2 { + using underlying = uint8_t; + uint8_t val_; + quint4x2() = default; + C10_HOST_DEVICE explicit quint4x2(uint8_t val) : val_(val) {} +}; + +} // namespace c10 diff --git a/c10/util/quint8.h b/c10/util/quint8.h index 0c0476ae6af5a..3aeb09b6397c0 100644 --- a/c10/util/quint8.h +++ b/c10/util/quint8.h @@ -6,7 +6,7 @@ namespace c10 { /** - * qint8 is for signed 8 bit quantized Tensors + * quint8 is for unsigned 8 bit quantized Tensors */ struct alignas(1) quint8 { using underlying = uint8_t; diff --git a/c10/util/typeid.cpp b/c10/util/typeid.cpp index e97eaa8439799..79c093cbeb31c 100644 --- a/c10/util/typeid.cpp +++ b/c10/util/typeid.cpp @@ -14,42 +14,41 @@ namespace detail { C10_EXPORT void _ThrowRuntimeTypeLogicError(const string& msg) { // In earlier versions it used to be std::abort() but it's a bit hard-core // for a library - AT_ERROR(msg); + TORCH_CHECK(false, msg); } +} // namespace detail +[[noreturn]] void TypeMeta::error_unsupported_typemeta(caffe2::TypeMeta dtype) { + TORCH_CHECK(false, "Unsupported TypeMeta in ATen: ", dtype, " (please report this error)"); +} -} // namespace detail +// see TypeMeta::addTypeMetaData +std::atomic TypeMeta::nextTypeIndex(NumScalarTypes); -template <> -EXPORT_IF_NOT_GCC const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance< - detail::_Uninitialized>() noexcept { - static constexpr detail::TypeMetaData singleton = detail::TypeMetaData( - 0, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - TypeIdentifier::uninitialized(), - "nullptr (uninitialized)"); - return &singleton; +// fixed length array of TypeMetaData instances +detail::TypeMetaData* TypeMeta::typeMetaDatas() { + static detail::TypeMetaData instances[MaxTypeIndex + 1] = { +#define SCALAR_TYPE_META(T, name) \ + /* ScalarType::name */ \ + detail::TypeMetaData( \ + sizeof(T), \ + detail::_PickNew(), \ + detail::_PickPlacementNew(), \ + detail::_PickCopy(), \ + detail::_PickPlacementDelete(), \ + detail::_PickDelete(), \ + TypeIdentifier::Get(), \ + c10::util::get_fully_qualified_type_name()), +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SCALAR_TYPE_META) +#undef SCALAR_TYPE_META + // The remainder of the array is padded with TypeMetaData blanks. + // The first of these is the entry for ScalarType::Undefined. + // The rest are consumed by CAFFE_KNOWN_TYPE entries. + }; + return instances; } -CAFFE_KNOWN_TYPE(uint8_t) -CAFFE_KNOWN_TYPE(int8_t) -CAFFE_KNOWN_TYPE(int16_t) -CAFFE_KNOWN_TYPE(int) -CAFFE_KNOWN_TYPE(int64_t) -CAFFE_KNOWN_TYPE(at::Half) -CAFFE_KNOWN_TYPE(float) -CAFFE_KNOWN_TYPE(double) -CAFFE_KNOWN_TYPE(c10::complex) -CAFFE_KNOWN_TYPE(c10::complex) -CAFFE_KNOWN_TYPE(c10::complex) -// 11 = undefined type id -// 12 = Tensor (defined in tensor.cc) CAFFE_KNOWN_TYPE(std::string) -CAFFE_KNOWN_TYPE(bool) CAFFE_KNOWN_TYPE(uint16_t) CAFFE_KNOWN_TYPE(char) CAFFE_KNOWN_TYPE(std::unique_ptr) @@ -61,7 +60,7 @@ CAFFE_KNOWN_TYPE(bool*) CAFFE_KNOWN_TYPE(char*) CAFFE_KNOWN_TYPE(int*) -// For some of the compilers, long is definied separately from int32_t and +// For some of the compilers, long is defined separately from int32_t and // int64_t. As a result we will need to actually define them separately. // It is recommended that one does NOT use long - use int32_t and int64_t // explicitly. Explicit long type annotation may go away in the future. @@ -79,14 +78,11 @@ using _guard_long_unique = std::conditional_t< _guard_long_unique_dummy, T>; } // namespace detail + CAFFE_KNOWN_TYPE(detail::_guard_long_unique); CAFFE_KNOWN_TYPE(detail::_guard_long_unique>) CAFFE_KNOWN_TYPE(float*) CAFFE_KNOWN_TYPE(at::Half*) -CAFFE_KNOWN_TYPE(c10::qint8) -CAFFE_KNOWN_TYPE(c10::quint8) -CAFFE_KNOWN_TYPE(c10::qint32) -CAFFE_KNOWN_TYPE(at::BFloat16) } // namespace caffe2 diff --git a/c10/util/typeid.h b/c10/util/typeid.h index 62a0bdfc66440..635e9aa5fc331 100644 --- a/c10/util/typeid.h +++ b/c10/util/typeid.h @@ -21,17 +21,14 @@ #include #include #include -#include #include #include #include #include -#include -#include -#include -#include #include +#include + /* * TypeIdentifier is a small type containing an id. * Types must be registered using CAFFE_KNOWN_TYPE() for them to have a type id. @@ -66,7 +63,7 @@ namespace caffe2 { */ class C10_API TypeIdentifier final : public at::IdWrapper { - public: +public: friend std::ostream& operator<<(std::ostream& stream, TypeIdentifier typeId); friend constexpr bool operator<(TypeIdentifier lhs, TypeIdentifier rhs); @@ -86,9 +83,8 @@ class C10_API TypeIdentifier final return TypeIdentifier(c10::util::type_index{0}); } - private: +private: constexpr explicit TypeIdentifier(c10::util::type_index id) : IdWrapper(id) {} - friend class TypeMeta; // TODO Is this friend an issue? }; // Allow usage in std::map / std::set @@ -125,7 +121,16 @@ struct TypeMetaData final { using PlacementDelete = void(void*, size_t); using Delete = void(void*); - TypeMetaData() = delete; + constexpr TypeMetaData() noexcept + : itemsize_(0), + new_(nullptr), + placementNew_(nullptr), + copy_(nullptr), + placementDelete_(nullptr), + delete_(nullptr), + id_(TypeIdentifier::uninitialized()), + name_("nullptr (uninitialized)") {} + constexpr TypeMetaData( size_t itemsize, New* newFn, @@ -135,14 +140,14 @@ struct TypeMetaData final { Delete* deleteFn, TypeIdentifier id, c10::string_view name) noexcept - : itemsize_(itemsize), - new_(newFn), - placementNew_(placementNew), - copy_(copy), - placementDelete_(placementDelete), - delete_(deleteFn), - id_(id), - name_(name) {} + : itemsize_(itemsize), + new_(newFn), + placementNew_(placementNew), + copy_(copy), + placementDelete_(placementDelete), + delete_(deleteFn), + id_(id), + name_(name) {} size_t itemsize_; New* new_; @@ -293,25 +298,24 @@ inline constexpr TypeMetaData::Delete* _PickDelete() noexcept { return &_Delete; } -template -inline C10_TYPENAME_CONSTEXPR TypeMetaData _makeTypeMetaDataInstance() { - C10_HOST_CONSTEXPR_VAR auto typeId = TypeIdentifier::Get(); - C10_TYPENAME_CONSTEXPR auto typeName = c10::util::get_fully_qualified_type_name(); - - return {sizeof(T), - _PickNew(), - _PickPlacementNew(), - _PickCopy(), - _PickPlacementDelete(), - _PickDelete(), - typeId, - typeName}; -} - class _Uninitialized final {}; } // namespace detail +// +// note: this is outside TypeMeta bc gcc seems to have trouble +// with scalarTypeItemSizes as a constexpr static member used by +// a public inline instance method +// + +// item sizes for TypeMeta::itemsize() fast path +static constexpr uint8_t scalarTypeItemSizes[NumScalarTypes] = { +#define SCALAR_TYPE_SIZE(T, name) sizeof(T), + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SCALAR_TYPE_SIZE) +#undef SCALAR_TYPE_SIZE + 0, // Undefined +}; + /** * TypeMeta is a thin class that allows us to store the type of a container such * as a blob, or the data type of a tensor, with a unique run-time id. It also @@ -337,17 +341,22 @@ class C10_API TypeMeta final { TypeMeta(const TypeMeta& src) noexcept = default; /** - * Assignment operator. + * Assignment operators. */ TypeMeta& operator=(const TypeMeta& src) noexcept = default; TypeMeta(TypeMeta&& rhs) noexcept = default; - private: + inline TypeMeta& operator=(ScalarType scalar_type) noexcept { + index_ = static_cast(scalar_type); + return *this; + } + +private: // TypeMeta can only be created by Make, making sure that we do not // create incorrectly mixed up TypeMeta objects. - explicit TypeMeta(const detail::TypeMetaData* data) noexcept - : data_(data) { + explicit TypeMeta(const uint16_t index) noexcept + : index_(index) { } public: @@ -355,48 +364,66 @@ class C10_API TypeMeta final { * Returns the type id. */ TypeIdentifier id() const noexcept { - return data_->id_; + return data().id_; + } + /** + * true if we represent some ScalarType type + */ + inline bool isScalarType() const noexcept { + return index_ < NumScalarTypes; + } + /** + * true if we represent ScalarType scalar_type + */ + inline bool isScalarType(ScalarType scalar_type) const noexcept { + return index_ == static_cast(scalar_type); } /** * Returns the size of the item. */ - size_t itemsize() const noexcept { - return data_->itemsize_; + inline size_t itemsize() const noexcept { + if (C10_LIKELY(isScalarType())) { + return scalarTypeItemSizes[index_]; + } + return data().itemsize_; } + /** + * Returns the new function pointer for individual items. + */ New* newFn() const noexcept { - return data_->new_; + return data().new_; } /** * Returns the placement new function pointer for individual items. */ PlacementNew* placementNew() const noexcept { - return data_->placementNew_; + return data().placementNew_; } /** * Returns the typed copy function pointer for individual iterms. */ Copy* copy() const noexcept { - return data_->copy_; + return data().copy_; } /** * Returns the destructor function pointer for individual items. */ PlacementDelete* placementDelete() const noexcept { - return data_->placementDelete_; + return data().placementDelete_; } Delete* deleteFn() const noexcept { - return data_->delete_; + return data().delete_; } /** * Returns a printable name for the type. */ c10::string_view name() const noexcept { - return data_->name_; + return data().name_; } friend bool operator==( - const TypeMeta& lhs, - const TypeMeta& rhs) noexcept; + const TypeMeta lhs, + const TypeMeta rhs) noexcept; template bool Match() const noexcept { @@ -411,7 +438,7 @@ class C10_API TypeMeta final { } template - static C10_TYPENAME_CONSTEXPR c10::string_view TypeName() noexcept { + static c10::string_view TypeName() noexcept { return c10::util::get_fully_qualified_type_name(); } @@ -436,35 +463,105 @@ class C10_API TypeMeta final { #pragma GCC diagnostic ignored "-Wunknown-warning-option" #pragma GCC diagnostic ignored "-Wundefined-var-template" #endif - return TypeMeta(_typeMetaDataInstance()); + return TypeMeta(_typeMetaData()); #ifndef _MSC_VER #pragma GCC diagnostic pop #endif } - private: - const detail::TypeMetaData* data_; + /** + * convert ScalarType enum values to TypeMeta handles + */ + static inline caffe2::TypeMeta fromScalarType(ScalarType scalar_type) { + const size_t index = static_cast(scalar_type); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + index < NumScalarTypes, + "Unrecognized Scalartype ", scalar_type, " (please report this error)"); + return TypeMeta(index); + } + + /** + * convert TypeMeta handles to ScalarType enum values + */ + inline ScalarType toScalarType() { + if (C10_LIKELY(isScalarType())) { + return static_cast(index_); + } + error_unsupported_typemeta(*this); + } + +private: + [[noreturn]] static void error_unsupported_typemeta(caffe2::TypeMeta dtype); + + // hard limit number of registered types + // note: constexpr provokes Windows compilation error "member may not be initialized" + // static constexpr size_t MaxTypeIndex = UINT8_MAX; + #define MaxTypeIndex UINT8_MAX + + static std::atomic nextTypeIndex; + + static detail::TypeMetaData* typeMetaDatas(); template - C10_API static const detail::TypeMetaData* _typeMetaDataInstance() noexcept; + static uint16_t addTypeMetaData() { + const uint16_t index = nextTypeIndex++; + TORCH_CHECK(index <= MaxTypeIndex, + "Maximum number of CAFFE_KNOWN_TYPE declarations has been exceeded. ", + "Please report this issue."); + typeMetaDatas()[index] = detail::TypeMetaData{ + sizeof(T), + detail::_PickNew(), + detail::_PickPlacementNew(), + detail::_PickCopy(), + detail::_PickPlacementDelete(), + detail::_PickDelete(), + TypeIdentifier::Get(), + c10::util::get_fully_qualified_type_name()}; + return index; + } + + // specializations return indexes into typeMetaDataInstances() + template + C10_API static uint16_t _typeMetaData() noexcept; + + // + // TypeMeta just wraps this index + // + + uint16_t index_; + + inline const detail::TypeMetaData& data() const { + return typeMetaDatas()[index_]; + } }; +// specializations of TypeMeta::_typeMetaData for ScalarType types + +#define DEFINE_SCALAR_METADATA_INSTANCE(T, name) \ + template <> \ + constexpr uint16_t TypeMeta::_typeMetaData() noexcept { \ + return static_cast(ScalarType::name); \ + } +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_METADATA_INSTANCE) +#undef DEFINE_SCALAR_METADATA_INSTANCE + template <> -C10_EXPORT const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance< - detail::_Uninitialized>() noexcept; +C10_EXPORT constexpr uint16_t TypeMeta::_typeMetaData() noexcept { + return static_cast(ScalarType::Undefined); +} inline TypeMeta::TypeMeta() noexcept - : data_(_typeMetaDataInstance()) { + : index_(_typeMetaData()) { } inline bool operator==( - const TypeMeta& lhs, - const TypeMeta& rhs) noexcept { - return (lhs.data_ == rhs.data_); + const TypeMeta lhs, + const TypeMeta rhs) noexcept { + return (lhs.index_ == rhs.index_); } inline bool operator!=( - const TypeMeta& lhs, - const TypeMeta& rhs) noexcept { + const TypeMeta lhs, + const TypeMeta rhs) noexcept { return !operator==(lhs, rhs); } @@ -499,13 +596,11 @@ inline std::ostream& operator<<( #define EXPORT_IF_NOT_GCC #endif -#define CAFFE_KNOWN_TYPE(T) \ - template <> \ - EXPORT_IF_NOT_GCC const detail::TypeMetaData* \ - TypeMeta::_typeMetaDataInstance() noexcept { \ - static C10_TYPENAME_CONSTEXPR detail::TypeMetaData singleton = \ - detail::_makeTypeMetaDataInstance(); \ - return &singleton; \ +#define CAFFE_KNOWN_TYPE(T) \ + template <> \ + EXPORT_IF_NOT_GCC uint16_t TypeMeta::_typeMetaData() noexcept { \ + static const uint16_t index = addTypeMetaData(); \ + return index; \ } } // namespace caffe2 diff --git a/c10/util/win32-headers.h b/c10/util/win32-headers.h new file mode 100644 index 0000000000000..ef79c9f372b1a --- /dev/null +++ b/c10/util/win32-headers.h @@ -0,0 +1,57 @@ +#pragma once + +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif +#ifndef NOKERNEL +#define NOKERNEL +#endif +#ifndef NOUSER +#define NOUSER +#endif +#ifndef NOSERVICE +#define NOSERVICE +#endif +#ifndef NOSOUND +#define NOSOUND +#endif +#ifndef NOMCX +#define NOMCX +#endif +#ifndef NOGDI +#define NOGDI +#endif +#ifndef NOMSG +#define NOMSG +#endif +#ifndef NOMB +#define NOMB +#endif +#ifndef NOCLIPBOARD +#define NOCLIPBOARD +#endif + +#include +#include + +#undef VOID +#undef DELETE +#undef IN +#undef THIS +#undef CONST +#undef NAN +#undef UNKNOWN +#undef NONE +#undef ANY +#undef IGNORE +#undef STRICT +#undef GetObject +#undef CreateSemaphore +#undef Yield +#undef RotateRight32 +#undef RotateLeft32 +#undef RotateRight64 +#undef RotateLeft64 diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 6ea848bd32e56..440735ed592d8 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -6,7 +6,7 @@ if(USE_VULKAN) include(../cmake/VulkanCodegen.cmake) endif() -# ---[ MSVC OpenMP modification +# ---[ MSVC OpenMP modification if(MSVC) include(../cmake/public/utils.cmake) endif() @@ -111,7 +111,7 @@ endif() add_subdirectory(core) add_subdirectory(serialize) add_subdirectory(utils) -if(BUILD_CAFFE2) +if(BUILD_CAFFE2 OR (NOT USE_FBGEMM)) add_subdirectory(perfkernels) endif() @@ -291,26 +291,41 @@ endif() if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) if(USE_DISTRIBUTED) - add_library(process_group_agent "${TORCH_SRC_DIR}/csrc/distributed/rpc/process_group_agent.cpp" "${TORCH_SRC_DIR}/csrc/distributed/rpc/process_group_agent.h") - target_link_libraries(process_group_agent PRIVATE torch c10d fmt::fmt-header-only) - add_dependencies(process_group_agent torch c10d) # Define this target even if we're building without TensorPipe, to make life # easier to other targets that depend on this. However, in that case, by not # setting the USE_TENSORPIPE compile definition, this target will just end # up being empty. Downstream targets should also add a #ifdef guard. - add_library(tensorpipe_agent - "${TORCH_SRC_DIR}/csrc/distributed/rpc/tensorpipe_agent.cpp" - "${TORCH_SRC_DIR}/csrc/distributed/rpc/tensorpipe_agent.h" - "${TORCH_SRC_DIR}/csrc/distributed/rpc/tensorpipe_utils.cpp" - "${TORCH_SRC_DIR}/csrc/distributed/rpc/tensorpipe_utils.h" - ) - target_link_libraries(tensorpipe_agent PRIVATE torch c10d tensorpipe fmt::fmt-header-only) - add_dependencies(tensorpipe_agent torch c10d) - if(USE_TENSORPIPE) - target_compile_definitions(tensorpipe_agent PUBLIC USE_TENSORPIPE) - target_link_libraries(tensorpipe_agent PRIVATE tensorpipe) - add_dependencies(tensorpipe_agent tensorpipe) + if(NOT WIN32) + add_library(process_group_agent "${TORCH_SRC_DIR}/csrc/distributed/rpc/process_group_agent.cpp" "${TORCH_SRC_DIR}/csrc/distributed/rpc/process_group_agent.h") + target_link_libraries(process_group_agent PRIVATE torch c10d fmt::fmt-header-only) + add_dependencies(process_group_agent torch c10d) + + add_library(tensorpipe_agent + "${TORCH_SRC_DIR}/csrc/distributed/rpc/macros.h" + "${TORCH_SRC_DIR}/csrc/distributed/rpc/tensorpipe_agent.cpp" + "${TORCH_SRC_DIR}/csrc/distributed/rpc/tensorpipe_agent.h" + "${TORCH_SRC_DIR}/csrc/distributed/rpc/tensorpipe_utils.cpp" + "${TORCH_SRC_DIR}/csrc/distributed/rpc/tensorpipe_utils.h" + ) + target_link_libraries(tensorpipe_agent PRIVATE torch c10d tensorpipe fmt::fmt-header-only) + add_dependencies(tensorpipe_agent torch c10d) + if(USE_TENSORPIPE) + if(USE_CUDA) + target_compile_definitions(tensorpipe_agent PUBLIC USE_CUDA) + endif() + + if(USE_ROCM) + target_compile_definitions(tensorpipe_agent PRIVATE + USE_ROCM + __HIP_PLATFORM_HCC__ + ) + endif() + + target_compile_definitions(tensorpipe_agent PUBLIC USE_TENSORPIPE) + target_link_libraries(tensorpipe_agent PRIVATE tensorpipe) + add_dependencies(tensorpipe_agent tensorpipe) + endif() endif() endif() @@ -326,9 +341,6 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) set(GENERATED_CXX_TORCH "${TORCH_SRC_DIR}/csrc/autograd/generated/Functions.cpp" - "${TORCH_SRC_DIR}/csrc/jit/generated/generated_unboxing_wrappers_0.cpp" - "${TORCH_SRC_DIR}/csrc/jit/generated/generated_unboxing_wrappers_1.cpp" - "${TORCH_SRC_DIR}/csrc/jit/generated/generated_unboxing_wrappers_2.cpp" ) if(NOT INTERN_DISABLE_AUTOGRAD) @@ -388,11 +400,13 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) COMMAND "${PYTHON_EXECUTABLE}" tools/setup_helpers/generate_code.py --declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml" + --native-functions-path "aten/src/ATen/native/native_functions.yaml" --nn-path "aten/src" $<$:--disable-autograd> $<$:--selected-op-list-path="${SELECTED_OP_LIST}"> --force_schema_registration DEPENDS + "${TORCH_ROOT}/aten/src/ATen/native/native_functions.yaml" "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml" "${TOOLS_PATH}/autograd/templates/VariableType.h" "${TOOLS_PATH}/autograd/templates/VariableType.cpp" @@ -416,10 +430,6 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) "${TOOLS_PATH}/autograd/gen_variable_factories.py" "${TOOLS_PATH}/autograd/gen_variable_type.py" "${TOOLS_PATH}/autograd/load_derivatives.py" - "${TOOLS_PATH}/autograd/nested_dict.py" - "${TOOLS_PATH}/autograd/utils.py" - "${TOOLS_PATH}/jit/gen_unboxing_wrappers.py" - "${TOOLS_PATH}/jit/templates/generated_unboxing_wrappers.cpp" WORKING_DIRECTORY "${TORCH_ROOT}") @@ -463,6 +473,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) # This one needs to be unconditionally added as Functions.cpp is also unconditionally added list(APPEND TORCH_SRCS ${TORCH_SRC_DIR}/csrc/autograd/FunctionsManual.cpp + ${TORCH_SRC_DIR}/csrc/utils/out_types.cpp ) if(NOT INTERN_DISABLE_AUTOGRAD) @@ -493,63 +504,17 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) PROPERTIES COMPILE_FLAGS "-DC10_DISABLE_LEGACY_IMPORT" ) endif() - if(USE_DISTRIBUTED) + if(USE_DISTRIBUTED AND NOT WIN32) append_filelist("libtorch_distributed_sources" TORCH_SRCS) endif() endif() + if(USE_CUDA OR USE_ROCM) + append_filelist("libtorch_cuda_core_sources" Caffe2_GPU_HIP_JIT_FUSERS_SRCS) + endif() + if(USE_CUDA) - list(APPEND Caffe2_GPU_SRCS - ${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp - ${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp - ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp - ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/arith.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/compute_at.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/codegen.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/dispatch.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/expr_evaluator.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/executor.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/executor_kernel_arg.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/executor_launch_params.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/executor_utils.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/fusion.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/graph_fuser.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/index_compute.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/instrumentation.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_base_nodes.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_cloner.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_graphviz.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_nodes.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_iostream.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/iter_visitor.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_cache.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_ir.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_ir_builder.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_index.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_insert_syncs.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_loops.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_thread_predicate.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_unroll.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_utils.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_validation.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower2device.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/manager.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/mutator.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/parser.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/partition.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/predicate_compute.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/register_interface.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/scheduler.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/shape_inference.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/tensor_view.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_iter.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_replay.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_rfactor.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/type.cpp - ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/cuda_codegen.cpp - ) + list(APPEND Caffe2_GPU_SRCS ${Caffe2_GPU_HIP_JIT_FUSERS_SRCS}) add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS}) if(MSVC) # Delay load nvcuda.dll so we can import torch compiled with cuda on a CPU-only machine @@ -567,13 +532,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) endif() if(USE_ROCM) - list(APPEND Caffe2_HIP_SRCS - ${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp - ${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp - ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp - ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp - ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/cuda_codegen.cpp - ) + list(APPEND Caffe2_HIP_SRCS ${Caffe2_GPU_HIP_JIT_FUSERS_SRCS}) if(USE_NCCL) list(APPEND Caffe2_HIP_SRCS ${TORCH_SRC_DIR}/csrc/cuda/nccl.cpp) @@ -765,6 +724,49 @@ elseif(USE_CUDA) endif() endif() +if(USE_CUDA OR USE_ROCM) + if(USE_CUDA) + set(TORCHLIB_FLAVOR torch_cuda) + elseif(USE_ROCM) + set(TORCHLIB_FLAVOR torch_hip) + endif() + + # The list of NVFUSER runtime files + list(APPEND NVFUSER_RUNTIME_FILES + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/block_reduction.cu + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/broadcast.cu + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/fp16_support.cu + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/grid_reduction.cu + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/helpers.cu + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/random_numbers.cu + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/tensor.cu + ) + + file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/include/nvfuser_resources") + + # "stringify" NVFUSER runtime sources + # (generate C++ header files embedding the original input as a string literal) + set(NVFUSER_STRINGIFY_TOOL "${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/tools/stringify_file.py") + foreach(src ${NVFUSER_RUNTIME_FILES}) + get_filename_component(filename ${src} NAME_WE) + set(dst "${CMAKE_BINARY_DIR}/include/nvfuser_resources/${filename}.h") + add_custom_command( + COMMENT "Stringify NVFUSER runtime source file" + OUTPUT ${dst} + DEPENDS ${src} + COMMAND ${PYTHON_EXECUTABLE} ${NVFUSER_STRINGIFY_TOOL} -i ${src} -o ${dst} + ) + add_custom_target(nvfuser_rt_${filename} DEPENDS ${dst}) + add_dependencies(${TORCHLIB_FLAVOR} nvfuser_rt_${filename}) + + # also generate the resource headers during the configuration step + # (so tools like clang-tidy can run w/o requiring a real build) + execute_process(COMMAND + ${PYTHON_EXECUTABLE} ${NVFUSER_STRINGIFY_TOOL} -i ${src} -o ${dst}) + endforeach() + + target_include_directories(${TORCHLIB_FLAVOR} PRIVATE "${CMAKE_BINARY_DIR}/include") +endif() if(NOT MSVC AND USE_XNNPACK) TARGET_LINK_LIBRARIES(torch_cpu PRIVATE fxdiv) @@ -825,6 +827,11 @@ endif() target_include_directories(torch_cpu PRIVATE ${TORCH_ROOT}/third_party/miniz-2.0.8) + if(USE_KINETO) + target_include_directories(torch_cpu PRIVATE + ${TORCH_ROOT}/third_party/kineto/libkineto/include + ${TORCH_ROOT}/third_party/kineto/libkineto/src) + endif() install(DIRECTORY "${TORCH_SRC_DIR}/csrc" DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch @@ -838,10 +845,10 @@ endif() DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch) - if(BUILD_TEST AND NOT USE_ROCM) + if(BUILD_TEST) add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit) add_subdirectory(${TORCH_ROOT}/test/cpp/tensorexpr ${CMAKE_BINARY_DIR}/test_tensorexpr) - if(USE_DISTRIBUTED) + if(USE_DISTRIBUTED AND NOT WIN32) add_subdirectory(${TORCH_ROOT}/test/cpp/rpc ${CMAKE_BINARY_DIR}/test_cpp_rpc) endif() endif() @@ -893,9 +900,7 @@ endif() DESTINATION share/cmake/Torch) if(USE_DISTRIBUTED) - if(NOT MSVC) - add_subdirectory(${TORCH_SRC_DIR}/lib/c10d lib_c10d) - endif() + add_subdirectory(${TORCH_SRC_DIR}/lib/c10d lib_c10d) endif() @@ -970,6 +975,14 @@ if(USE_DISTRIBUTED) target_compile_definitions(torch_cpu PRIVATE USE_DISTRIBUTED ) + # Pass USE_RPC in order to reduce use of + # #if defined(USE_DISTRIBUTED) && !defined(_WIN32) + # need to be removed when RPC is supported + if(NOT WIN32) + target_compile_definitions(torch_cpu PRIVATE + USE_RPC + ) + endif() # Pass USE_TENSORPIPE to torch_cpu as some parts of rpc/utils.cpp # can only be compiled with USE_TENSORPIPE is set. if(USE_TENSORPIPE) @@ -1017,7 +1030,11 @@ if($ENV{TH_BINARY_BUILD}) # # These linker commands do not work on OS X, do not attempt this there. # (It shouldn't matter anyway, though, because OS X has dropped CUDA support) - set_target_properties(torch_cpu PROPERTIES LINK_FLAGS "-Wl,--undefined=mkl_lapack_slaed0 -Wl,--undefined=mkl_lapack_dlaed0 -Wl,--undefined=mkl_lapack_dormql -Wl,--undefined=mkl_lapack_sormql") + foreach(_symb slaed0 daled0 dormql sormql zheevd cheevd) + STRING(APPEND _undefined_link_flags " -Wl,--undefined=mkl_lapack_${_symb}") + endforeach(_symb) + set_target_properties(torch_cpu PROPERTIES LINK_FLAGS ${_undefined_link_flags}) + endif() endif() @@ -1180,6 +1197,18 @@ if(USE_CUDA) endif() +# ---[ Metal(OSX) modification +if(APPLE AND USE_PYTORCH_METAL) + if(NOT INTERN_BUILD_MOBILE) + include(../cmake/Metal.cmake) + # We need to link the system frameworks explicitly + find_library(metal NAMES Metal) + find_library(mps NAMES MetalPerformanceShaders) + find_library(foundation NAMES Foundation) + find_library(accelerate NAMES Accelerate) + target_link_libraries(torch_cpu PUBLIC ${metal} ${mps} ${foundation} ${accelerate}) + endif() +endif() # Note [Global dependencies] # Some libraries (e.g. OpenMPI) like to dlopen plugins after they're initialized, @@ -1247,7 +1276,13 @@ endif() if(BUILD_STATIC_RUNTIME_BENCHMARK) add_subdirectory(${TORCH_ROOT}/benchmarks/static_runtime ${PROJECT_BINARY_DIR}/bin) add_executable(static_runtime_bench "${STATIC_RUNTIME_BENCHMARK_SRCS}") + add_executable(static_runtime_test "${STATIC_RUNTIME_TEST_SRCS}") target_link_libraries(static_runtime_bench torch_library benchmark) + target_link_libraries(static_runtime_test torch_library gtest_main) +endif() + +if(BUILD_TENSOREXPR_BENCHMARK) + add_subdirectory(${TORCH_ROOT}/benchmarks/cpp/tensorexpr ${CMAKE_BINARY_DIR}/tensorexpr_bench) endif() if(BUILD_MOBILE_BENCHMARK) @@ -1280,8 +1315,8 @@ if(BUILD_TEST) foreach(test_src ${ATen_VEC256_TEST_SRCS}) foreach(i RANGE ${NUM_CPU_CAPABILITY_NAMES}) get_filename_component(test_name ${test_src} NAME_WE) - list(GET CPU_CAPABILITY_NAMES ${i} CPU_CAPABILITY) - list(GET CPU_CAPABILITY_FLAGS ${i} FLAGS) + list(GET CPU_CAPABILITY_NAMES ${i} CPU_CAPABILITY) + list(GET CPU_CAPABILITY_FLAGS ${i} FLAGS) separate_arguments(FLAGS UNIX_COMMAND "${FLAGS}") add_executable(${test_name}_${CPU_CAPABILITY} "${test_src}") target_link_libraries(${test_name}_${CPU_CAPABILITY} torch_library gtest_main) @@ -1291,7 +1326,7 @@ if(BUILD_TEST) target_compile_definitions(${test_name}_${CPU_CAPABILITY} PRIVATE CPU_CAPABILITY=${CPU_CAPABILITY} CPU_CAPABILITY_${CPU_CAPABILITY}) target_compile_options(${test_name}_${CPU_CAPABILITY} PRIVATE ${FLAGS}) if(NOT MSVC) - target_compile_options(${test_name}_${CPU_CAPABILITY} PRIVATE -Wno-ignored-qualifiers) + target_compile_options(${test_name}_${CPU_CAPABILITY} PRIVATE -Wno-ignored-qualifiers) endif(NOT MSVC) add_test(NAME ${test_name}_${CPU_CAPABILITY} COMMAND $) endforeach() diff --git a/caffe2/contrib/aten/README.md b/caffe2/contrib/aten/README.md index 377a1f780271c..593079ef13933 100644 --- a/caffe2/contrib/aten/README.md +++ b/caffe2/contrib/aten/README.md @@ -1,6 +1,6 @@ # An ATen operator for Caffe2 -[ATen](https://github.com/zdevito/aten) is a simple tensor library thats exposes the Tensor operations in Torch +ATen is a simple tensor library thats exposes the Tensor operations in Torch and PyTorch directly in C++14. This library provides a generated wrapper around the ATen API that makes these functions available in Caffe2 as an operator. It also makes it accessible using the ToffeeIR. @@ -8,8 +8,8 @@ ToffeeIR. ### Example Usage in Caffe2 -First identify a function in ATen you want to call in [Functions.h](https://github.com/zdevito/ATen/blob/master/doc/Functions.h), -[Tensor.h](https://github.com/zdevito/ATen/blob/master/doc/Tensor.h), or [Type.h](https://github.com/zdevito/ATen/blob/master/doc/Type.h). +First identify a function in ATen you want to call in Functions.h, +Tensor.h, or Type.h. We will call the `pow` operator: diff --git a/caffe2/contrib/aten/aten_op.cc b/caffe2/contrib/aten/aten_op.cc index 9e7479141ad41..dba68d21c2dd1 100644 --- a/caffe2/contrib/aten/aten_op.cc +++ b/caffe2/contrib/aten/aten_op.cc @@ -6,13 +6,17 @@ namespace caffe2 { namespace internal { at::Tensor index_with_uint8_handling( const at::Tensor& self, - at::TensorList indices) { + const torch::List>& indices) { // Support BC only for the simplest case of mask indexing - if (indices.size() == 1 && indices[0].scalar_type() == at::kByte) { - TORCH_WARN( - "Indexing with uint8 mask tensor in ATenOp is now deprecated," - " please use a bool mask instead."); - return at::index(self, {indices[0].to(at::kBool)}); + if (indices.size() == 1) { + c10::optional first = indices[0]; + if (first.has_value() + && first->scalar_type() == at::kByte) { + TORCH_WARN( + "Indexing with uint8 mask tensor in ATenOp is now deprecated," + " please use a bool mask instead."); + return at::index(self, {first->to(at::kBool)}); + } } return at::index(self, indices); } diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h index 97c575ea58db1..cd1ce7651b482 100644 --- a/caffe2/contrib/aten/aten_op_template.h +++ b/caffe2/contrib/aten/aten_op_template.h @@ -19,9 +19,9 @@ namespace caffe2 { using at::Half; // for AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ...) namespace internal { -CAFFE2_API at::Tensor index_with_uint8_handling( +TORCH_API at::Tensor index_with_uint8_handling( const at::Tensor& self, - at::TensorList indices); + const torch::List>& indices); } template @@ -86,6 +86,16 @@ class ATenOp : public Operator { std::vector peekSlice(size_t i, size_t len, size_t N) { std::vector results; + results.reserve(len); + for (size_t ii = i; ii < i + len; ++ii) { + results.push_back(peek(ii, N)); + } + return results; + } + + torch::List> peekSliceOptionals(size_t i, size_t len, size_t N) { + torch::List> results; + results.reserve(len); for (size_t ii = i; ii < i + len; ++ii) { results.push_back(peek(ii, N)); } diff --git a/caffe2/contrib/aten/docs/pytorch_to_caffe2.md b/caffe2/contrib/aten/docs/pytorch_to_caffe2.md index 85c275bb51781..c3f615ee37b94 100644 --- a/caffe2/contrib/aten/docs/pytorch_to_caffe2.md +++ b/caffe2/contrib/aten/docs/pytorch_to_caffe2.md @@ -6,7 +6,7 @@ operators that haven't been standardized yet, or custom `torch.autograd.Function are specific to a network. To bridge this gap, we provide an experimental operator in ONNX that allows you to directly access PyTorch's tensor functions using the ATen library. -[ATen](https://github.com/zdevito/aten) is the underlying C++ library that PyTorch uses to do tensor operations. Caffe2 has an [ATen operator](https://github.com/caffe2/caffe2/tree/master/caffe2/contrib/aten) +[ATen](https://github.com/pytorch/pytorch/tree/master/aten) is the underlying C++ library that PyTorch uses to do tensor operations. Caffe2 has an [ATen operator](https://github.com/pytorch/pytorch/tree/master/caffe2/contrib/aten) that can run these tensor functions in a Caffe2 network after importing them through ONNX. This guide explains how to configure Caffe2 and modify your PyTorch program to use @@ -61,8 +61,8 @@ We can add a `symbolic` method to it like so: The function `graph.at` adds a new ATen op the computation graph. You can call any ATen function using this facility. To do so, -first identify a function in ATen you want to call in [Functions.h](https://github.com/zdevito/ATen/blob/master/doc/Functions.h), -[Tensor.h](https://github.com/zdevito/ATen/blob/master/doc/Tensor.h), or [Type.h](https://github.com/zdevito/ATen/blob/master/doc/Type.h). +first identify a function in ATen you want to call in Functions.h, +Tensor.h, or Type.h. As an example, we might want to call the `pow` operator: @@ -86,9 +86,9 @@ To call methods of ATen's `Type` objects, you provide an additional string attri that determines the type. For instance, `ones` creates a new constant tensor of all ones: ``` class Type { - ... - virtual Tensor ones(IntArrayRef size) const; - ... + ... + virtual Tensor ones(IntArrayRef size) const; + ... }; ``` diff --git a/caffe2/contrib/aten/gen_op.py b/caffe2/contrib/aten/gen_op.py index 703bf3ec167f0..64d3de547bb7d 100755 --- a/caffe2/contrib/aten/gen_op.py +++ b/caffe2/contrib/aten/gen_op.py @@ -20,7 +20,7 @@ import argparse import os from copy import deepcopy -from typing import Dict, List +from typing import Dict, List, Set parser = argparse.ArgumentParser() parser.add_argument("--template_dir", default=".", help="where template.h is") @@ -68,7 +68,7 @@ def value_has_tensors(v): def value_is_tensor_type(v): - return value_has_tensors(v) and v['dynamic_type'] != 'TensorList' + return value_has_tensors(v) and v['dynamic_type'] not in ['TensorList', 'const c10::List> &'] # for each aten type, how do we handle a return value of that type? @@ -208,7 +208,7 @@ def self_as_first_argument(arguments): def get_num_inputs(o): args = 0 for a in o['arguments']: - if a['type'] == 'TensorList': + if a['type'] in ['TensorList', 'const c10::List> &']: return '*' elif value_has_tensors(a): args += 1 @@ -236,12 +236,12 @@ def emit_assignments(o, env): decls = yaml.load(read(os.path.join(args.yaml_dir, 'Declarations.yaml')), Loader=Loader) factory_methods = find_factory_methods(decls) filtered = [expanded for o in decls for expanded in expand(o) if supports(expanded, factory_methods)] - top_env = { + top_env: Dict[str, List] = { 'mappings': [], 'implementations': [], 'cases': [], - } # type: Dict[str, List] - seen = set() + } + seen: Set[str] = set() key = 0 for o in filtered: # [DESCRIPTORS] @@ -277,23 +277,28 @@ def emit_assignments(o, env): # e.g. "Float" is at::kFloat assert('Type' in o['method_of']) - static_tensor_inputs = sum(arg['type'] != 'TensorList' and value_is_tensor_type(arg) for arg in o['arguments']) - has_tensorlist = any(arg['type'] == 'TensorList' for arg in o['arguments']) + static_tensor_inputs = sum(arg['type'] not in ['TensorList', 'const c10::List> &'] and value_is_tensor_type(arg) for arg in o['arguments']) + has_tensorlist = any(arg['type'] in ['TensorList', 'const c10::List> &'] for arg in o['arguments']) if has_tensorlist: - tensorlist_idx = [i for i, arg in enumerate(o['arguments']) if arg['type'] == 'TensorList'][0] + tensorlist_idx = [i for i, arg in enumerate(o['arguments']) if arg['type'] in ['TensorList', 'const c10::List> &']][0] real_inputs = 0 for i, arg in enumerate(o['arguments']): env['arguments'].append(arg['name']) - # Emulate logic in gen_unboxing_wrappers.py. Pretend the flat argument - # list is a stack where the end is the top. + # Pretend the flat argument list is a stack where the end is the top. view_length = 'InputSize()' if has_tensorlist and i < tensorlist_idx else static_tensor_inputs if arg['type'] == 'TensorList': # NOTE: do not advance real_inputs here. After this we will - # switch to indexing the "stack" from the end as if we only had + # switch to indexing the "stack" from the end env['statements'].append( 'auto {} = peekSlice({}, InputSize() - {}, InputSize());' .format(arg['name'], real_inputs, static_tensor_inputs)) + elif arg['type'] == 'const c10::List> &': + # NOTE: do not advance real_inputs here. After this we will + # switch to indexing the "stack" from the end + env['statements'].append( + 'auto {} = peekSliceOptionals({}, InputSize() - {}, InputSize());' + .format(arg['name'], real_inputs, static_tensor_inputs)) elif value_is_tensor_type(arg): # load tensor inputs from Caffe2 env['statements'].append( diff --git a/caffe2/contrib/fakelowp/elementwise_fp16_fake_op.cc b/caffe2/contrib/fakelowp/elementwise_fp16_fake_op.cc index 7debb5f7cf7e4..c616688681786 100644 --- a/caffe2/contrib/fakelowp/elementwise_fp16_fake_op.cc +++ b/caffe2/contrib/fakelowp/elementwise_fp16_fake_op.cc @@ -22,7 +22,21 @@ int getSizeFromDims(const std::vector& dims) { template struct FP16PairWiseCPUFunctor : public OP { + template bool Forward( + const std::vector& A_dims, + const std::vector& B_dims, + const TIn* A, + const TIn* B, + TOut* C, + CPUContext* context) const { + OP::Forward(A_dims, B_dims, A, B, C, context); + + return true; + } + + template<> + bool Forward( const std::vector& A_dims, const std::vector& B_dims, const float* A, @@ -54,7 +68,7 @@ OPERATOR_SCHEMA(SumFakeFp16).NumInputs(1, INT_MAX).NumOutputs(1, INT_MAX); REGISTER_CPU_OPERATOR( AddFakeFp16, BinaryElementwiseOp< - TensorTypes, + TensorTypes, CPUContext, FP16PairWiseCPUFunctor>>); OPERATOR_SCHEMA(AddFakeFp16).NumInputs(2).NumOutputs(1); diff --git a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h index 4aef84663adc9..ddeea5d5f56cf 100644 --- a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h +++ b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h @@ -189,7 +189,7 @@ class LayerNormFakeFp16Op final : public Operator { int Nout = X.numel(); std::vector inv_scalev(Nout, inv_scale); - std::vector offsetv(Nout, Y_offset - 128.0); + std::vector offsetv(Nout, Y_offset); uint8_t* Y_uint8_data = Y_int8->t.template mutable_data(); fake_fp16::fma_fp16(Nout, Y_fp16.data(), inv_scalev.data(), offsetv.data()); @@ -200,7 +200,6 @@ class LayerNormFakeFp16Op final : public Operator { for (int i = 0; i < Nout; i++) { float halfRes = offsetv[i]; halfRes = round(halfRes); - halfRes = halfRes + 128.0; if (std::isinf(halfRes)) { if (halfRes > 0) { halfRes = qmax; diff --git a/caffe2/contrib/fakelowp/test/test_batchmatmul_nnpi_fp16.py b/caffe2/contrib/fakelowp/test/test_batchmatmul_nnpi_fp16.py index 94a76fed85f5d..d6e5c5db6d2a4 100644 --- a/caffe2/contrib/fakelowp/test/test_batchmatmul_nnpi_fp16.py +++ b/caffe2/contrib/fakelowp/test/test_batchmatmul_nnpi_fp16.py @@ -1,8 +1,3 @@ - - - - - import numpy as np import unittest import caffe2.python.fakelowp.init_shared_libs # noqa @@ -11,6 +6,7 @@ from caffe2.python import core, workspace from caffe2.python.onnx.onnxifi import onnxifi_caffe2_net from caffe2.python.fakelowp.test_utils import print_test_debug_info +import datetime from hypothesis import given, settings import hypothesis.strategies as st import caffe2.python.serialized_test.serialized_test_util as serial @@ -29,7 +25,7 @@ class TestBatchMatMul(serial.SerializedTestCase): trans_b=st.booleans(), run_ints=st.booleans() ) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_batch_matmul(self, M, K, N, C, rand_seed, trans_a, trans_b, run_ints): np.random.seed(rand_seed) workspace.ResetWorkspace() diff --git a/caffe2/contrib/fakelowp/test/test_batchnorm_nnpi_fp16.py b/caffe2/contrib/fakelowp/test/test_batchnorm_nnpi_fp16.py index 7b1b5f0701715..56ac6733f13d7 100644 --- a/caffe2/contrib/fakelowp/test/test_batchnorm_nnpi_fp16.py +++ b/caffe2/contrib/fakelowp/test/test_batchnorm_nnpi_fp16.py @@ -1,8 +1,3 @@ - - - - - import numpy as np import unittest @@ -15,6 +10,7 @@ from caffe2.python.onnx.onnxifi import onnxifi_caffe2_net from caffe2.python.fakelowp.test_utils import print_test_debug_info import caffe2.python.serialized_test.serialized_test_util as serial +import datetime core.GlobalInit(["caffe2", "--glow_global_fp16=1", "--glow_global_fused_scale_offset_fp16=1", @@ -46,7 +42,7 @@ class BatchnormTest(serial.SerializedTestCase): size=st.integers(2, 30), input_channels=st.integers(2, 40), batch_size=st.integers(2, 20)) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_bn(self, seed, size, input_channels, batch_size): workspace.ResetWorkspace() np.random.seed(seed) diff --git a/caffe2/contrib/fakelowp/test/test_chunking.py b/caffe2/contrib/fakelowp/test/test_chunking.py new file mode 100644 index 0000000000000..306b5c3b3f02b --- /dev/null +++ b/caffe2/contrib/fakelowp/test/test_chunking.py @@ -0,0 +1,142 @@ +# Must happen before importing caffe2.python.* +import caffe2.python.fakelowp.init_shared_libs # noqa +import datetime +import numpy as np +from hypothesis import given, settings, example +from hypothesis import strategies as st +from caffe2.python import core, workspace +from caffe2.python.onnx.onnxifi import onnxifi_caffe2_net +from caffe2.python.fakelowp.test_utils import print_test_debug_info +import caffe2.python.serialized_test.serialized_test_util as serial + +# Test that parallel chunks behave the same way as the serial one + +workspace.GlobalInit( + [ + "caffe2", + "--glow_global_fp16=1", + "--glow_global_fused_scale_offset_fp16=1", + "--glow_global_force_sls_fp16_accum=1", + "--glow_nnpi_num_parallel_chunks=2", + "--glow_use_dag_optimizer=false", + "--glow_dump_graph=true", + ] +) + +class Fusions(serial.SerializedTestCase): + def _get_scale_zp(self, tensor): + tensor_max = np.max(tensor) + tensor_min = min(0, np.min(tensor)) + scale = np.float32(np.float16((tensor_max - tensor_min) / 255.0)) + if scale < 1e-6: + scale = 1e-6 + zero_point = 0 - tensor_min / scale + zero_point = int(round(np.clip(zero_point, 0, 255.0))) + return (scale, zero_point) + + @given( + scale=st.floats(1e-4, 1e2), + zp=st.integers(-128, 128), + rand_seed=st.integers(0, 65534), + m=st.integers(32, 64), + k=st.integers(1000, 6000), + n=st.integers(200, 600), + ) + # @example(m=64, k=5423, n=553, scale=1e-3, zp=120, rand_seed=1) + @settings(deadline=datetime.timedelta(seconds=1000), max_examples=1) + def test_ParallelFC(self, m, k, n, scale, zp, rand_seed): + np.random.seed(rand_seed) + workspace.ResetWorkspace() + + # Y = W_T * X + b + X_fp32 = np.random.uniform(-1, 1, size=(m, k)).astype(np.float16) \ + .astype(np.float32) + + W_fp32 = np.random.uniform(-1, 1, size=(n, k)).astype(np.float32) + b_fp32 = np.zeros((n,), dtype=np.float32) + + X_scale, X_zero_point = self._get_scale_zp(X_fp32) + + workspace.FeedBlob("X", X_fp32) + workspace.FeedBlob("W", W_fp32) + workspace.FeedBlob("b", b_fp32) + + workspace.RunOperatorOnce( + core.CreateOperator( + "Int8FCPackWeight", + ["W"], + ["W_int8"], + engine="DNNLOWP", + save_unpacked_weights=True, + in_scale=X_scale, + ) + ) + + ref_net = core.Net("net") + ref_net.Int8QuantizeNNPI( + ["X"], + ["X_int8"], + Y_scale=X_scale, + Y_zero_point=X_zero_point + ) + ref_net.Int8FCFakeAcc32NNPI( + ["X_int8", "W_int8", "b"], + ["Y_int8"], + Y_scale=X_scale, + Y_zero_point=X_zero_point, + ) + ref_net.Int8Relu( + ["Y_int8"], + ["Y_relu"], + Y_zero_point=X_zero_point, + Y_scale=X_scale, + ) + ref_net.Int8DequantizeNNPI( + ["Y_relu"], + ["Y"] + ) + ref_net.Proto().external_output.append("Y") + + # run ref_net + workspace.RunNetOnce(ref_net) + Y_fbgemm = workspace.FetchBlob("Y") + + # run onnxifi net + ref_net.Proto().op[0].type = "Int8Quantize" + ref_net.Proto().op[1].type = "Int8FC" + ref_net.Proto().op[2].type = "Int8Relu" + ref_net.Proto().op[3].type = "Int8Dequantize" + net_onnxified = onnxifi_caffe2_net( + ref_net.Proto(), + {}, + debug=True, + adjust_batch=False, + use_onnx=False, + weight_names=["W_int8", "b"], + ) + num_onnxified_ops = sum( + 1 if o.type == "Onnxifi" else 0 for o in net_onnxified.op + ) + print(net_onnxified) + np.testing.assert_equal(num_onnxified_ops, 1) + workspace.CreateNet(net_onnxified) + workspace.RunNet(net_onnxified.name) + Y_glow = workspace.FetchBlob("Y") + + if not np.allclose(Y_glow, Y_fbgemm): + diff_Y = np.abs(Y_glow - Y_fbgemm) + print_test_debug_info( + "int8_fc", + { + "seed": rand_seed, + "n": n, + "X": X_fp32, + "W": W_fp32, + "b": b_fp32, + "Y_fbgemm": Y_fbgemm, + "Y_glow": Y_glow, + "diff": diff_Y, + "maxdiff": diff_Y.max(axis=1), + }, + ) + assert 0 diff --git a/caffe2/contrib/fakelowp/test/test_deq_swish_quant_nnpi.py b/caffe2/contrib/fakelowp/test/test_deq_swish_quant_nnpi.py index b7a9fc810cfcf..7ee160e196027 100644 --- a/caffe2/contrib/fakelowp/test/test_deq_swish_quant_nnpi.py +++ b/caffe2/contrib/fakelowp/test/test_deq_swish_quant_nnpi.py @@ -1,11 +1,11 @@ - - import numpy as np import caffe2.python.fakelowp.init_shared_libs # noqa from caffe2.python import core, workspace from caffe2.python.onnx.onnxifi import onnxifi_caffe2_net from caffe2.python.fakelowp.test_utils import print_test_debug_info import caffe2.python.serialized_test.serialized_test_util as serial +import datetime +from hypothesis import settings core.GlobalInit(["caffe2", "--caffe2_log_level=-3", "--glow_global_fp16=1"]) @@ -24,6 +24,7 @@ def _sigmoid(self, x): def _swish(self, x): return np.float32(x) * self._sigmoid(x) + @settings(deadline=datetime.timedelta(seconds=10)) def test_swish_int8(self): np.random.seed(0) workspace.ResetWorkspace() diff --git a/caffe2/contrib/fakelowp/test/test_fc_nnpi_fp16.py b/caffe2/contrib/fakelowp/test/test_fc_nnpi_fp16.py index 7a68af63a84b1..d9c2bd37daebc 100644 --- a/caffe2/contrib/fakelowp/test/test_fc_nnpi_fp16.py +++ b/caffe2/contrib/fakelowp/test/test_fc_nnpi_fp16.py @@ -1,8 +1,3 @@ - - - - - import numpy as np import unittest @@ -14,6 +9,7 @@ from caffe2.python import workspace from caffe2.python.onnx.onnxifi import onnxifi_caffe2_net from caffe2.python.fakelowp.test_utils import print_test_debug_info +import datetime import caffe2.python.serialized_test.serialized_test_util as serial core.GlobalInit(["caffe2", "--caffe2_log_level=-3", "--glow_global_fp16=1"]) @@ -23,7 +19,7 @@ class FCTest(serial.SerializedTestCase): @given(seed=st.integers(0, 65534)) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_clip(self, seed): np.random.seed(seed) m, n, k = 8, 8, 8 @@ -48,7 +44,7 @@ def test_clip(self, seed): ) workspace.GlobalInit( ['caffe2', '--caffe2_log_level=0', '--glow_global_fp16=1', - '--glow_clip_fp16']) + '--glow_clip_fp16', '--glow_global_fp16_constants=1']) workspace.SwitchWorkspace("glow_test_ws", True) workspace.ResetWorkspace() W0 = np.full((n, k), 65536.0, dtype) @@ -82,7 +78,7 @@ def test_clip(self, seed): n=st.integers(4, 50), seed=st.integers(0, 65534) ) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_fc_exercise(self, m, k, n, seed): """ Test that the matmul engine is working, this doesn't test precision @@ -147,7 +143,7 @@ def test_fc_exercise(self, m, k, n, seed): assert(0) @given(seed=st.integers(0, 65534)) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_fc_numeric_cases(self, seed): """ Test numerics, use examples found from the unit test. Use Fp16FCAcc16NNPI as a reference. @@ -272,7 +268,7 @@ def test_fc_numeric_cases(self, seed): seed=st.integers(0, 65534), use_packed=st.integers(0, 2) ) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_fc_num0(self, seed, m, k, n, use_packed): """ Test numerics, fix a dimension and determine the ranges of error. Use Fp16FCAcc16 as a reference. diff --git a/caffe2/contrib/fakelowp/test/test_fusions.py b/caffe2/contrib/fakelowp/test/test_fusions.py index 45757badba43f..3e22d7c5937be 100644 --- a/caffe2/contrib/fakelowp/test/test_fusions.py +++ b/caffe2/contrib/fakelowp/test/test_fusions.py @@ -1,7 +1,6 @@ - - # Must happen before importing caffe2.python.* import caffe2.python.fakelowp.init_shared_libs # noqa +import datetime import numpy as np from hypothesis import given, settings from hypothesis import strategies as st @@ -27,8 +26,8 @@ class Fusions(serial.SerializedTestCase): size=st.integers(1, 100000), rand_seed=st.integers(0, 65534), ) - @settings(deadline=None) - def Skip_test_tanhquantize(self, scale, zp, size, rand_seed): + @settings(deadline=datetime.timedelta(seconds=10)) + def test_tanhquantize(self, scale, zp, size, rand_seed): np.random.seed(rand_seed) workspace.ResetWorkspace() diff --git a/caffe2/contrib/fakelowp/test/test_int8_ops_nnpi.py b/caffe2/contrib/fakelowp/test/test_int8_ops_nnpi.py index 5a91a00706ff5..1507f41a48611 100644 --- a/caffe2/contrib/fakelowp/test/test_int8_ops_nnpi.py +++ b/caffe2/contrib/fakelowp/test/test_int8_ops_nnpi.py @@ -1,5 +1,3 @@ - - import caffe2.python.fakelowp.init_shared_libs # noqa import numpy as np from caffe2.python import core, workspace @@ -7,8 +5,14 @@ from hypothesis import given, strategies as st, settings from caffe2.python.fakelowp.test_utils import print_test_debug_info import caffe2.python.serialized_test.serialized_test_util as serial +import datetime -core.GlobalInit(["caffe2", "--caffe2_log_level=-3", "--glow_global_fp16=1"]) +core.GlobalInit(["caffe2", + "--caffe2_log_level=-3", + "--glow_global_fp16=1", + "--glow_clip_quant_range_to_fp16=1", + "--glow_global_fp16_constants=1" + ]) class Int8OpsTest(serial.SerializedTestCase): @@ -27,7 +31,7 @@ def _get_scale_zp(self, tensor): rand_seed=st.integers(0, 65534), non_zero_offset=st.booleans() ) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=50)) def test_int8_quantize(self, n, rand_seed, non_zero_offset): print("n={}, rand_seed={}".format(n, rand_seed)) np.random.seed(rand_seed) @@ -128,7 +132,7 @@ def test_int8_quantize(self, n, rand_seed, non_zero_offset): rand_seed=st.integers(0, 65534), quantize_bias=st.sampled_from([False]), ) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=50)) def test_int8_fc( self, n, m, k, rand_seed, quantize_bias, f ): @@ -229,7 +233,7 @@ def test_int8_fc( n=st.integers(1, 4), rand_seed=st.integers(0, 65534) ) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_int8_small_input(self, n, rand_seed): print("n={}, rand_seed={}".format(n, rand_seed)) np.random.seed(rand_seed) diff --git a/caffe2/contrib/fakelowp/test/test_int8_quant.py b/caffe2/contrib/fakelowp/test/test_int8_quant.py index 02095286e1ee1..2770dc7bef046 100644 --- a/caffe2/contrib/fakelowp/test/test_int8_quant.py +++ b/caffe2/contrib/fakelowp/test/test_int8_quant.py @@ -1,12 +1,12 @@ - - # Must happen before importing caffe2.python.* import caffe2.python.fakelowp.init_shared_libs # noqa +import datetime import numpy as np from caffe2.proto import caffe2_pb2 from caffe2.python import core, workspace from caffe2.python.onnx.onnxifi import onnxifi_caffe2_net import caffe2.python.serialized_test.serialized_test_util as serial +from hypothesis import settings workspace.GlobalInit( [ @@ -18,6 +18,7 @@ ) class QuantTest(serial.SerializedTestCase): + @settings(deadline=datetime.timedelta(seconds=10)) def test_dequantize(self): pred_net = caffe2_pb2.NetDef() pred_net.name = "pred" @@ -60,6 +61,7 @@ def test_dequantize(self): Y_glow = workspace.FetchBlob("Y") np.testing.assert_equal(Y_ref, Y_glow) + @settings(deadline=datetime.timedelta(seconds=20)) def test_quantize(self): pred_net = caffe2_pb2.NetDef() pred_net.name = "pred" diff --git a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py index 9ff0986116b6d..f992c6f9e1fcc 100644 --- a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py +++ b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py @@ -1,8 +1,3 @@ - - - - - import numpy as np import caffe2.python.fakelowp.init_shared_libs # noqa from caffe2.proto import caffe2_pb2 @@ -13,6 +8,7 @@ from hypothesis import given, settings from hypothesis import strategies as st import caffe2.python.serialized_test.serialized_test_util as serial +import datetime core.GlobalInit(["caffe2", "--glow_global_fp16=1", @@ -30,8 +26,8 @@ class LayerNorm(serial.SerializedTestCase): size=st.integers(min_value=2, max_value=128), epsilon=st.floats(min_value=1e-4, max_value=1e-3), elementwise_affine=st.booleans()) - @settings(deadline=None) - def Skip_test_layernorm(self, seed, batch_size, size, epsilon, elementwise_affine): + @settings(deadline=datetime.timedelta(seconds=10)) + def test_layernorm(self, seed, batch_size, size, epsilon, elementwise_affine): np.random.seed(seed) # Reset the workspace workspace.ResetWorkspace() @@ -144,9 +140,9 @@ def _layernorm_transform(self, X): size=st.integers(min_value=2, max_value=128), epsilon=st.floats(min_value=1e-4, max_value=1e-3), elementwise_affine=st.booleans()) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) # re-enable when T74553975 gets fixed - def Skip_test_fused_ln_quantize(self, seed, batch_size, size, epsilon, elementwise_affine): + def test_fused_ln_quantize(self, seed, batch_size, size, epsilon, elementwise_affine): np.random.seed(seed) # Reset the workspace diff --git a/caffe2/contrib/fakelowp/test/test_op_nnpi_fp16.py b/caffe2/contrib/fakelowp/test/test_op_nnpi_fp16.py index e8512b4dcd74f..8a5a2aaeaae70 100644 --- a/caffe2/contrib/fakelowp/test/test_op_nnpi_fp16.py +++ b/caffe2/contrib/fakelowp/test/test_op_nnpi_fp16.py @@ -1,11 +1,7 @@ - - - - - import numpy as np import caffe2.python.fakelowp.init_shared_libs # noqa +import datetime from hypothesis import given, settings from hypothesis import strategies as st from caffe2.proto import caffe2_pb2 @@ -103,22 +99,22 @@ def _test_binary_op_graph(self, name, seed): assert(0) @given(seed=st.integers(0, 65534)) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_add_graph(self, seed): self._test_binary_op_graph("Add", seed) @given(seed=st.integers(0, 65534)) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_sub_graph(self, seed): self._test_binary_op_graph("Sub", seed) @given(seed=st.integers(0, 65534)) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_mul_graph(self, seed): self._test_binary_op_graph("Mul", seed) @given(seed=st.integers(0, 65534)) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_div_graph(self, seed): self._test_binary_op_graph("Div", seed) @@ -199,7 +195,7 @@ def _test_op_w_ulp_error(self, seed, opname, regions, atol=0, err_threshold=2): # Once hypothesis.testing version is updated, we can re-enable # testing with different hypothesis examples. @given(seed=st.integers(0, 65534)) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_sigmoid(self, seed): np.random.seed(seed) opname = "Sigmoid" @@ -213,7 +209,7 @@ def test_sigmoid(self, seed): # Once hypothesis.testing version is updated, we can re-enable # testing with different hypothesis examples. @given(seed=st.integers(0, 65534)) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_tanh(self, seed): np.random.seed(seed) opname = "Tanh" @@ -230,7 +226,7 @@ def test_tanh(self, seed): # testing with different hypothesis examples. # TODO: move atol to 1e-8 once we get a non-lowered swish implementation @given(seed=st.integers(0, 65534)) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_swish(self, seed): np.random.seed(seed) opname = "Swish" @@ -243,7 +239,7 @@ def test_swish(self, seed): # Once hypothesis.testing version is updated, we can re-enable # testing with different hypothesis examples. @given(seed=st.integers(0, 65534)) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_logit(self, seed): np.random.seed(seed) workspace.ResetWorkspace() @@ -309,7 +305,7 @@ def test_logit(self, seed): class ReluTest(serial.SerializedTestCase): @given(seed=st.integers(0, 65534)) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def relu_test(self, inputs, gc, dc, seed): np.random.seed(seed) inputs = np.random.rand(1).astype(np.float32) diff --git a/caffe2/contrib/fakelowp/test/test_sls_4bit_nnpi_fp16.py b/caffe2/contrib/fakelowp/test/test_sls_4bit_nnpi_fp16.py index a8d6640fa58e3..489bfbc37f4f5 100644 --- a/caffe2/contrib/fakelowp/test/test_sls_4bit_nnpi_fp16.py +++ b/caffe2/contrib/fakelowp/test/test_sls_4bit_nnpi_fp16.py @@ -1,8 +1,3 @@ - - - - - import numpy as np import unittest @@ -16,6 +11,7 @@ from caffe2.python.onnx.onnxifi import onnxifi_caffe2_net from caffe2.python.fakelowp.test_utils import print_test_debug_info import caffe2.python.serialized_test.serialized_test_util as serial +import datetime workspace.GlobalInit(["caffe2", "--glow_global_fp16=1", "--glow_global_fused_scale_offset_fp16=1", @@ -24,7 +20,7 @@ class SparseLengthsSum4BitFakeNNPIFp16Test(serial.SerializedTestCase): @given(seed=st.integers(0, 65535)) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_slws_fused_4bit_rowwise_all_same(self, seed): np.random.seed(seed) workspace.ResetWorkspace() @@ -118,7 +114,7 @@ def test_slws_fused_4bit_rowwise_all_same(self, seed): batch_size=st.integers(1, 32), max_weight=st.integers(0, 1), ) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_slws_fused_4bit_rowwise(self, seed, num_rows, embedding_dim, batch_size, max_weight): workspace.ResetWorkspace() np.random.seed(seed) diff --git a/caffe2/contrib/fakelowp/test/test_sls_8bit_nnpi_fp16.py b/caffe2/contrib/fakelowp/test/test_sls_8bit_nnpi_fp16.py index f8fd03cbfb730..041dcce97dbf8 100644 --- a/caffe2/contrib/fakelowp/test/test_sls_8bit_nnpi_fp16.py +++ b/caffe2/contrib/fakelowp/test/test_sls_8bit_nnpi_fp16.py @@ -1,9 +1,8 @@ - - import unittest # Must happen before importing caffe2.python.* import caffe2.python.fakelowp.init_shared_libs # noqa +import datetime import numpy as np from hypothesis import given, settings from hypothesis import strategies as st @@ -99,7 +98,7 @@ def Skip_test_SLS_NonQuantized_fp16(self): assert 0 @given(seed=st.integers(0, 65535)) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_slws_fused_8bit_rowwise_all_same(self, seed): # Comment out for predictable debugging np.random.seed(seed) @@ -207,7 +206,7 @@ def test_slws_fused_8bit_rowwise_all_same(self, seed): batch_size=st.integers(1, 5), max_weight=st.integers(0, 100), ) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_slws_fused_8bit_rowwise(self, seed, num_rows, embedding_dim, batch_size, max_weight): np.random.seed(seed) workspace.ResetWorkspace() @@ -315,7 +314,7 @@ def test_slws_fused_8bit_rowwise(self, seed, num_rows, embedding_dim, batch_size # Simple test to aid debugging order of operations # Minimize the case to an SLS that adds two rows @given(seed=st.integers(0, 65535)) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_small_sls(self, seed): np.random.seed(seed) workspace.ResetWorkspace() @@ -419,6 +418,147 @@ def test_small_sls(self, seed): ) assert 0 + @given(seed=st.integers(0, 65535)) + @settings(deadline=datetime.timedelta(seconds=10)) + def test_sls_layernorm(self, seed): + np.random.seed(seed) + workspace.ResetWorkspace() + + n = 2 + DIM = 3 + data = 4 * (np.random.random_sample((n, DIM)) + 1).astype(np.float32) + + lengths = np.array([n], dtype=np.int32) + indices = np.array(range(n), dtype=np.int64) + weights = np.random.uniform(low=0.01, high=0.5, size=[n]).astype(np.float32) + + pred_net = caffe2_pb2.NetDef() + pred_net.name = "pred" + pred_net.external_input.extend( + ["quantized_data", "weights", "indices", "lengths"] + ) + pred_net.external_output.append("Y_norm") + pred_net.external_output.append("Y_mean") + pred_net.external_output.append("Y_std") + + pred_net.op.add().CopyFrom( + core.CreateOperator( + "SparseLengthsWeightedSumFused8BitRowwise", + ["quantized_data", "weights", "indices", "lengths"], + ["Y"], + ) + ) + + pred_net.op.add().CopyFrom( + core.CreateOperator( + "LayerNorm", + ["Y"], + ["Y_norm", "Y_mean", "Y_std"], + epsilon=1e-4, + ) + ) + + ref_net = caffe2_pb2.NetDef() + ref_net.name = "ref" + ref_net.external_input.extend( + ["quantized_data", "weights", "indices", "lengths"] + ) + ref_net.external_output.append("Y_norm") + ref_net.external_output.append("Y_mean") + ref_net.external_output.append("Y_std") + + ref_net.op.add().CopyFrom( + core.CreateOperator( + "SparseLengthsWeightedSumFused8BitRowwiseFakeFP16NNPI", + ["quantized_data", "weights", "indices", "lengths"], + ["Y"], + ) + ) + + ref_net.op.add().CopyFrom( + core.CreateOperator( + "LayerNormFakeFP16NNPI", + ["Y"], + ["Y_norm", "Y_mean", "Y_std"], + epsilon=1e-4, + axis=1, + elementwise_affine=False + ) + ) + + workspace.FeedBlob("data", data) + workspace.RunOperatorOnce( + core.CreateOperator( + "FloatToFused8BitRowwiseQuantized", ["data"], ["quantized_data"] + ) + ) + + quantized_data = workspace.FetchBlob("quantized_data") + + onnxified_net = onnxifi_caffe2_net( + pred_net, + {}, + max_batch_size=1, + max_seq_size=n, + debug=True, + adjust_batch=True, + use_onnx=False, + ) + print("before", pred_net) + print("after", onnxified_net) + workspace.FeedBlob("indices", indices) + workspace.FeedBlob("lengths", lengths) + workspace.FeedBlob("weights", weights) + + workspace.CreateNet(onnxified_net) + workspace.CreateNet(ref_net) + + workspace.RunNet(onnxified_net.name) + Y_glow = workspace.FetchBlob("Y_norm") + Y_mean_glow = workspace.FetchBlob("Y_mean") + Y_std_glow = workspace.FetchBlob("Y_std") + + workspace.RunNet(ref_net.name) + Y = workspace.FetchBlob("Y") + print("pre normalization", Y) + Y_ref = workspace.FetchBlob("Y_norm") + Y_mean_ref = workspace.FetchBlob("Y_mean") + Y_std_ref = workspace.FetchBlob("Y_std") + + # print(Y_ref, Y_glow) + # print(Y_ref.shape, Y_glow.shape) + + diff = np.abs(Y_ref - Y_glow) + max_err = np.max(diff, axis=1) + num_offenders = (max_err > 0).sum() + if num_offenders > 0: + np.set_printoptions(precision=12) + print( + "ref", + Y_ref.astype(np.float16).astype(np.float32), + "glow", + Y_glow.astype(np.float16).astype(np.float32), + ) + print_test_debug_info( + "slws_fused_8bit_rowwise_inv_scale", + { + "seed": seed, + "indices": indices, + "data": data, + "quantized_data": quantized_data, + "lengths": lengths, + "weights": weights, + "Y_norm_glow": Y_glow, + "Y_norm_ref": Y_ref, + "Y_mean_glow": Y_mean_glow, + "Y_std_glow": Y_std_glow, + "Y_mean_ref": Y_mean_ref, + "Y_std_ref": Y_std_ref, + "diff": diff, + "rowwise_diff": np.max(diff, axis=1), + }, + ) + assert 0 if __name__ == '__main__': diff --git a/caffe2/contrib/fakelowp/test/test_sls_8bit_nnpi_fp32.py b/caffe2/contrib/fakelowp/test/test_sls_8bit_nnpi_fp32.py index 207403f1bd0da..971bf8412f4ce 100644 --- a/caffe2/contrib/fakelowp/test/test_sls_8bit_nnpi_fp32.py +++ b/caffe2/contrib/fakelowp/test/test_sls_8bit_nnpi_fp32.py @@ -1,9 +1,8 @@ - - import unittest # Must happen before importing caffe2.python.* import caffe2.python.fakelowp.init_shared_libs # noqa +import datetime import numpy as np from hypothesis import given, settings from hypothesis import strategies as st @@ -32,7 +31,7 @@ class SparseLengthsSum8BitFakeNNPIFp32Test(serial.SerializedTestCase): batch_size=st.integers(1, 5), max_weight=st.integers(0, 100), ) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_slws_fused_8bit_rowwise_acc32_nnpi( self, seed, num_rows, embedding_dim, batch_size, max_weight ): @@ -148,7 +147,7 @@ def test_slws_fused_8bit_rowwise_acc32_nnpi( @given(seed=st.integers(0, 65535)) - @settings(deadline=None) + @settings(deadline=datetime.timedelta(seconds=10)) def test_small_sls_acc32(self, seed): workspace.GlobalInit( [ diff --git a/caffe2/contrib/fakelowp/unary_fp16_fake_op.cc b/caffe2/contrib/fakelowp/unary_fp16_fake_op.cc index f3fa7ee4e9be6..5a67507c42ca3 100644 --- a/caffe2/contrib/fakelowp/unary_fp16_fake_op.cc +++ b/caffe2/contrib/fakelowp/unary_fp16_fake_op.cc @@ -6,9 +6,7 @@ C10_DECLARE_bool(caffe2_fbgemm_fake_fp16_clamp); -namespace caffe2 { - -namespace { +namespace fake_fp16{ auto sig_lut = std::vector{ 0.0000e+00f, 0.0000e+00f, 0.0000e+00f, 0.0000e+00f, 0.0000e+00f, 0.0000e+00f, 0.0000e+00f, 5.9605e-08f, 5.9605e-08f, 5.9605e-08f, @@ -101,14 +99,6 @@ at::Half CalcSigmoidByLUT(at::Half x) { return at::Half(res1 + res2); } -OpSchema::Cost CostInferenceForRelu( - const OperatorDef& def, - const vector& in) { - struct OpSchema::Cost cost = PointwiseCostInference<0>(def, in); - cost.params_bytes = 0; - return cost; -} - const int TANH_LINEAR_MAX_VALUE = 10048; const int TANH_ASYMPTOTE_MIN_VALUE = 17538; @@ -346,31 +336,6 @@ at::Half CalcTanhByPolynomial(at::Half input) { return tanhResult; } -struct SigmoidEmulatorFunctor { - bool operator()( - const int N, - const float* X, - float* Y, - CPUContext* /* unused */) const { - for (int i = 0; i < N; i++) { - Y[i] = CalcSigmoidByLUT((at::Half)X[i]); - } - return true; - } -}; - -struct TanhEmulatorFunctor { - bool operator()( - const int N, - const float* X, - float* Y, - CPUContext* /* unused */) const { - for (int i = 0; i < N; i++) { - Y[i] = CalcTanhByLUT((at::Half)X[i]); - } - return true; - } -}; static const float swishLutKnot[] = { -0.000000025618f, -0.000000027492f, -0.000000029503f, -0.000000031660f, @@ -501,7 +466,7 @@ at::Half CalcSwishByLUT(at::Half x) { float f_x = x; float f_one_over_delta = one_over_delta; float f_a_one_over_delta = -a_one_over_delta; - fake_fp16::fma_fp16(1, &f_x, &f_one_over_delta, &f_a_one_over_delta); + fma_fp16(1, &f_x, &f_one_over_delta, &f_a_one_over_delta); at::Half bin_calc = f_a_one_over_delta; uint32_t bin = bin_calc < 0 ? 0 : (uint32_t)floor(bin_calc); @@ -516,7 +481,7 @@ at::Half CalcSwishByLUT(at::Half x) { float f_delta = delta; float f_bin = at::Half(bin); float f_a = a; - fake_fp16::fma_fp16(1, &f_delta, &f_bin, &f_a); + fma_fp16(1, &f_delta, &f_bin, &f_a); at::Half bin_x = at::Half(f_a); at::Half p = at::Half(x - bin_x) * one_over_delta; @@ -525,7 +490,7 @@ at::Half CalcSwishByLUT(at::Half x) { float f_p = -p; float lutVal = at::Half(swishLutKnot[bin]); - fake_fp16::fma_fp16(1, &f_p, &lutVal, &lutVal); + fma_fp16(1, &f_p, &lutVal, &lutVal); at::Half res2 = lutVal; return at::Half(res1 + res2); @@ -555,7 +520,45 @@ at::Half CalcLogit(at::Half input, float eps) { } } -} // namespace +} // namespace fake_fp16 + +namespace caffe2 { +using namespace fake_fp16; + + +struct SigmoidEmulatorFunctor { + bool operator()( + const int N, + const float* X, + float* Y, + CPUContext* /* unused */) const { + for (int i = 0; i < N; i++) { + Y[i] = CalcSigmoidByLUT((at::Half)X[i]); + } + return true; + } +}; + +struct TanhEmulatorFunctor { + bool operator()( + const int N, + const float* X, + float* Y, + CPUContext* /* unused */) const { + for (int i = 0; i < N; i++) { + Y[i] = CalcTanhByLUT((at::Half)X[i]); + } + return true; + } +}; + +OpSchema::Cost CostInferenceForRelu( + const OperatorDef& def, + const vector& in) { + struct OpSchema::Cost cost = PointwiseCostInference<0>(def, in); + cost.params_bytes = 0; + return cost; +} REGISTER_CPU_OPERATOR( ReluFakeFp16, diff --git a/caffe2/contrib/fakelowp/unary_fp16_fake_op.h b/caffe2/contrib/fakelowp/unary_fp16_fake_op.h index cf62625445d9d..8f4c4cbc1851e 100644 --- a/caffe2/contrib/fakelowp/unary_fp16_fake_op.h +++ b/caffe2/contrib/fakelowp/unary_fp16_fake_op.h @@ -63,3 +63,10 @@ struct TanhFakeIdealFp16Functor { }; } // namespace caffe2 + +namespace fake_fp16 { + +at::Half CalcSigmoidByLUT(at::Half x); +at::Half CalcTanhByLUT(at::Half input); + +} // namespace fake_fp16 diff --git a/caffe2/contrib/gloo/common.h b/caffe2/contrib/gloo/common.h index 0c56bf932c770..f258775685bfe 100644 --- a/caffe2/contrib/gloo/common.h +++ b/caffe2/contrib/gloo/common.h @@ -11,7 +11,7 @@ namespace caffe2 { namespace gloo { -CAFFE2_API void signalFailure(Blob* status_blob, std::exception& exception); +TORCH_API void signalFailure(Blob* status_blob, std::exception& exception); struct createDeviceAttr { // "tcp" or "ibverbs" @@ -22,7 +22,7 @@ struct createDeviceAttr { std::string interface; }; -CAFFE2_API std::shared_ptr<::gloo::transport::Device> createDevice( +TORCH_API std::shared_ptr<::gloo::transport::Device> createDevice( const createDeviceAttr attr); // Captures the parameters passed to Gloo. diff --git a/caffe2/contrib/gloo/gloo_test.py b/caffe2/contrib/gloo/gloo_test.py index fbca9b8fe64c3..5ae066f5e3ca2 100644 --- a/caffe2/contrib/gloo/gloo_test.py +++ b/caffe2/contrib/gloo/gloo_test.py @@ -27,7 +27,6 @@ op_engine = 'GLOO' - class TemporaryDirectory: def __enter__(self): self.tmpdir = tempfile.mkdtemp() diff --git a/caffe2/contrib/gloo/store_handler.h b/caffe2/contrib/gloo/store_handler.h index 00b651c2d66af..a68f01eac25d1 100644 --- a/caffe2/contrib/gloo/store_handler.h +++ b/caffe2/contrib/gloo/store_handler.h @@ -8,7 +8,7 @@ namespace caffe2 { namespace gloo { -class CAFFE2_API StoreHandlerWrapper : public ::gloo::rendezvous::Store { +class TORCH_API StoreHandlerWrapper : public ::gloo::rendezvous::Store { public: explicit StoreHandlerWrapper(StoreHandler& handler) : handler_(handler) {} diff --git a/caffe2/contrib/nccl/cuda_nccl_gpu.cc b/caffe2/contrib/nccl/cuda_nccl_gpu.cc index 31cd55d08578c..ef2b9ab37ea09 100644 --- a/caffe2/contrib/nccl/cuda_nccl_gpu.cc +++ b/caffe2/contrib/nccl/cuda_nccl_gpu.cc @@ -28,13 +28,8 @@ class NCCLContext { // get stream priorities int lo_pri, hi_pri; CUDA_ENFORCE(cudaDeviceGetStreamPriorityRange(&lo_pri, &hi_pri)); -#ifndef __HIP_PLATFORM_HCC__ CUDA_ENFORCE(cudaStreamCreateWithPriority( &streams_[i], cudaStreamNonBlocking, hi_pri)); -#else - CUDA_ENFORCE(cudaStreamCreateWithFlags( - &streams_[i], cudaStreamNonBlocking)); -#endif // __HIP_PLATFORM_HCC__ CUDA_ENFORCE(cudaEventCreateWithFlags( &events_[i], cudaEventDefault | cudaEventDisableTiming)); } diff --git a/caffe2/contrib/opencl/context.h b/caffe2/contrib/opencl/context.h index ce788a39a7cdc..15bfda2203f06 100644 --- a/caffe2/contrib/opencl/context.h +++ b/caffe2/contrib/opencl/context.h @@ -59,7 +59,7 @@ class OpenCLContext final { template inline void - CopyItems(const TypeMeta& meta, size_t n, const void* src, void* dst) { + CopyItems(const TypeMeta meta, size_t n, const void* src, void* dst) { CAFFE_ENFORCE(!meta.copy(), "OpenCLContext requires fundamental types."); CopyBytes(n * meta.itemsize(), src, dst); } diff --git a/caffe2/contrib/tensorboard/tensorboard_exporter.py b/caffe2/contrib/tensorboard/tensorboard_exporter.py index ef12ce563cde3..a9a1651a9b99d 100644 --- a/caffe2/contrib/tensorboard/tensorboard_exporter.py +++ b/caffe2/contrib/tensorboard/tensorboard_exporter.py @@ -7,7 +7,6 @@ import copy import logging import os -import six from caffe2.proto import caffe2_pb2 from caffe2.python import core, workspace @@ -93,7 +92,7 @@ def _get_blob_names(ops): def _remap_keys(m, f): - m2 = {f(key): value for key, value in six.iteritems(m)} + m2 = {f(key): value for key, value in m.items()} m.clear() m.update(m2) diff --git a/caffe2/contrib/tensorrt/tensorrt_tranformer.h b/caffe2/contrib/tensorrt/tensorrt_tranformer.h index ec7786e6ee03d..4d4e92dbf4bc0 100644 --- a/caffe2/contrib/tensorrt/tensorrt_tranformer.h +++ b/caffe2/contrib/tensorrt/tensorrt_tranformer.h @@ -14,12 +14,12 @@ namespace caffe2 { -CAFFE2_API void BuildInitializationList( +TORCH_API void BuildInitializationList( Workspace* ws, ::ONNX_NAMESPACE::GraphProto* g, std::unordered_set* initialization_list); -class CAFFE2_API TensorRTTransformer { +class TORCH_API TensorRTTransformer { public: TensorRTTransformer( size_t max_batch_size, diff --git a/caffe2/core/blob_serialization.cc b/caffe2/core/blob_serialization.cc index dd422c5b44ccc..fcf08eebfa8aa 100644 --- a/caffe2/core/blob_serialization.cc +++ b/caffe2/core/blob_serialization.cc @@ -26,17 +26,6 @@ C10_DEFINE_bool( false, "Serialize BOOL, UINT8, INT8, UINT16, INT16, INT64, FLOAT16 tensors using byte_data field instead of int32"); -#ifdef _MSC_VER -// It's MSVC, so we just have to guess ... and allow an override -#ifdef FOLLY_ENDIAN_BE -constexpr auto kIsLittleEndian = false; -#else -constexpr auto kIsLittleEndian = true; -#endif -#else -constexpr auto kIsLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; -#endif - namespace caffe2 { /** * @brief StringSerializer is the serializer for String. @@ -420,7 +409,7 @@ void DeserializeBlob(const BlobProto& blob_proto, Blob* result) { // === Local helper functions === // Get dimensions from Tensor proto -static std::vector DimsFromTensorProto(const TensorProto& proto) { +std::vector DimsFromTensorProto(const TensorProto& proto) { std::vector dims; dims.reserve(proto.dims().size()); for (const int64_t d : proto.dims()) { @@ -430,7 +419,7 @@ static std::vector DimsFromTensorProto(const TensorProto& proto) { } // Get number of elements from Tensor proto -static int64_t NumelFromTensorProto(const TensorProto& tensor_proto) { +int64_t NumelFromTensorProto(const TensorProto& tensor_proto) { int64_t numel = 1; for (const int64_t d : tensor_proto.dims()) { numel *= d; @@ -439,7 +428,7 @@ static int64_t NumelFromTensorProto(const TensorProto& tensor_proto) { } // Get data type from Tensor proto -static TypeMeta GetDataType(const TensorProto& tensor_proto) { +TypeMeta GetDataType(const TensorProto& tensor_proto) { TypeMeta dtype; if (tensor_proto.data_type() != TensorProto_DataType_UNDEFINED) { dtype = DataTypeToTypeMeta(tensor_proto.data_type()); @@ -459,7 +448,7 @@ static at::TensorOptions TensorOptionsFromProto( .device(OptionToDevice(tensor_proto.device_detail())); } -static std::unique_ptr ContextFromProto( +std::unique_ptr ContextFromProto( const TensorProto& tensor_proto) { auto device = OptionToDevice(tensor_proto.device_detail()); return CreateContext(device); diff --git a/caffe2/core/blob_serialization.h b/caffe2/core/blob_serialization.h index 5309314af0c77..43c05498d1622 100644 --- a/caffe2/core/blob_serialization.h +++ b/caffe2/core/blob_serialization.h @@ -17,6 +17,17 @@ C10_DECLARE_int(caffe2_tensor_chunk_size); C10_DECLARE_int(caffe2_max_tensor_serializer_threads); C10_DECLARE_bool(caffe2_serialize_fp16_as_bytes); +#ifdef _MSC_VER +// It's MSVC, so we just have to guess ... and allow an override +#ifdef FOLLY_ENDIAN_BE +constexpr auto kIsLittleEndian = false; +#else +constexpr auto kIsLittleEndian = true; +#endif +#else +constexpr auto kIsLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; +#endif + namespace caffe2 { constexpr auto kTensorBlobType = "Tensor"; @@ -29,7 +40,7 @@ constexpr auto kChunkIdSeparator = "#%"; * approaches for specific classes. Acceptor should take care of writing data * to the actual storage. */ -CAFFE2_API void SerializeBlob( +TORCH_API void SerializeBlob( const Blob& blob, const string& name, BlobSerializerBase::SerializationAcceptor acceptor, @@ -45,15 +56,15 @@ CAFFE2_API void SerializeBlob( * * NOTE: this function doesn't do chunking and might break with big tensors. */ -CAFFE2_API string SerializeBlob(const Blob& blob, const string& name); +TORCH_API string SerializeBlob(const Blob& blob, const string& name); /** * Deserializes from a string containing either BlobProto or TensorProto. If * the deserialization fails, the content in the blob should no longer be * trusted. */ -CAFFE2_API void DeserializeBlob(const string& content, Blob* result); -CAFFE2_API void DeserializeBlob(const BlobProto& proto, Blob* result); +TORCH_API void DeserializeBlob(const string& content, Blob* result); +TORCH_API void DeserializeBlob(const BlobProto& proto, Blob* result); /* * Get an empty Tensor from the TensorProto given the meta data in proto (data @@ -75,7 +86,7 @@ CAFFE2_API void DeserializeBlob(const BlobProto& proto, Blob* result); * these function calls. e.g. mutable_data will allocate memory on the first * call and it will return a pointer to the allocated memory on later calls. */ -CAFFE2_API Tensor EmptyTensorFromProto(const TensorProto& proto); +TORCH_API Tensor EmptyTensorFromProto(const TensorProto& proto); /** * @brief TensorSerializer is the serializer for Tensors. @@ -83,7 +94,7 @@ CAFFE2_API Tensor EmptyTensorFromProto(const TensorProto& proto); * TensorSerializer takes in a blob that contains a Tensor, and serializes it * into a TensorProto protocol buffer. */ -class CAFFE2_API TensorSerializer : public BlobSerializerBase { +class TORCH_API TensorSerializer : public BlobSerializerBase { public: TensorSerializer() {} ~TensorSerializer() override {} @@ -125,7 +136,7 @@ class CAFFE2_API TensorSerializer : public BlobSerializerBase { * tensor, change the TensorProto's corresponding fields before calling * Deserialize. */ -class CAFFE2_API TensorDeserializer : public BlobDeserializerBase { +class TORCH_API TensorDeserializer : public BlobDeserializerBase { public: void Deserialize(const BlobProto& proto, Blob* blob) override; @@ -229,7 +240,7 @@ inline void CopyFromProtoWithCast( // Converts MessageLite to string while also checking that SerializeAsString // succeeds. Pass description of class/function of the call if you'd // like it appended to the error message. -CAFFE2_API std::string SerializeAsString_EnforceCheck( +TORCH_API std::string SerializeAsString_EnforceCheck( const google::protobuf::MessageLite&, const char* error_location = nullptr); @@ -239,6 +250,14 @@ inline std::string SerializeBlobProtoAsString_EnforceCheck( return SerializeAsString_EnforceCheck(blob, blob.name().c_str()); } +int64_t NumelFromTensorProto(const TensorProto& tensor_proto); + +std::vector DimsFromTensorProto(const TensorProto& proto); + +TypeMeta GetDataType(const TensorProto& tensor_proto); + +std::unique_ptr ContextFromProto(const TensorProto& tensor_proto); + } // namespace caffe2 #endif // CAFFE2_CORE_BLOB_SERIALIZATION_H_ diff --git a/caffe2/core/blob_serializer_base.h b/caffe2/core/blob_serializer_base.h index ad282f31fe27d..969fb92240cae 100644 --- a/caffe2/core/blob_serializer_base.h +++ b/caffe2/core/blob_serializer_base.h @@ -78,7 +78,7 @@ inline unique_ptr CreateSerializer(TypeIdentifier id) { * @brief BlobDeserializerBase is an abstract class that deserializes a blob * from a BlobProto or a TensorProto. */ -class CAFFE2_API BlobDeserializerBase { +class TORCH_API BlobDeserializerBase { public: virtual ~BlobDeserializerBase() {} diff --git a/caffe2/core/blob_stats.h b/caffe2/core/blob_stats.h index e05b45183fb50..547897ed5f6ad 100644 --- a/caffe2/core/blob_stats.h +++ b/caffe2/core/blob_stats.h @@ -41,6 +41,6 @@ namespace BlobStat { * Return size in bytes of the blob, if available for a blob of given type. * If not available, return 0. */ -CAFFE2_API size_t sizeBytes(const Blob& blob); +TORCH_API size_t sizeBytes(const Blob& blob); } } diff --git a/caffe2/core/common.h b/caffe2/core/common.h index 076d83b1236d6..1b71eab0aa951 100644 --- a/caffe2/core/common.h +++ b/caffe2/core/common.h @@ -124,18 +124,18 @@ class SkipIndices<> { // linked. This function should not be used in static initialization functions // as the underlying boolean variable is going to be switched on when one // loads libtorch_gpu.so. -CAFFE2_API bool HasCudaRuntime(); -CAFFE2_API bool HasHipRuntime(); +TORCH_API bool HasCudaRuntime(); +TORCH_API bool HasHipRuntime(); namespace internal { // Sets the Cuda Runtime flag that is used by HasCudaRuntime(). You should // never use this function - it is only used by the Caffe2 gpu code to notify // Caffe2 core that cuda runtime has been loaded. -CAFFE2_API void SetCudaRuntimeFlag(); -CAFFE2_API void SetHipRuntimeFlag(); +TORCH_API void SetCudaRuntimeFlag(); +TORCH_API void SetHipRuntimeFlag(); } // namespace internal // Returns which setting Caffe2 was configured and built with (exported from // CMake) -CAFFE2_API const std::map& GetBuildOptions(); +TORCH_API const std::map& GetBuildOptions(); } // namespace caffe2 diff --git a/caffe2/core/common_gpu.h b/caffe2/core/common_gpu.h index b38c38cdc781c..8d27766c629c1 100644 --- a/caffe2/core/common_gpu.h +++ b/caffe2/core/common_gpu.h @@ -132,6 +132,8 @@ CAFFE2_CUDA_API int GetGPUIDForPointer(const void* ptr); /** * Gets the device property for the given device. This function is thread safe. + * The initial run on this function is ~1ms/device; however, the results are + * cached so subsequent runs should be much faster. */ CAFFE2_CUDA_API const cudaDeviceProp& GetDeviceProperty(const int device); diff --git a/caffe2/core/context.h b/caffe2/core/context.h index 422a12acb118a..d5fe10820152c 100644 --- a/caffe2/core/context.h +++ b/caffe2/core/context.h @@ -15,6 +15,13 @@ #include +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) +#include +#include +#else +#include "caffe2/core/distributions_stubs.h" +#endif + C10_DECLARE_bool(caffe2_report_cpu_memory_usage); namespace caffe2 { @@ -23,7 +30,7 @@ namespace caffe2 { * A function to generate a random number seed that is unique in a best-effort * basis, using an ever-incrementing seed and the current time. */ -CAFFE2_API uint32_t RandomNumberSeed(); +TORCH_API uint32_t RandomNumberSeed(); /** * The CPU Context, representing the bare minimum of what a Context class in @@ -37,9 +44,14 @@ CAFFE2_API uint32_t RandomNumberSeed(); * computation it has. * */ -class CAFFE2_API CPUContext final : public BaseContext { +class TORCH_API CPUContext final : public BaseContext { public: +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) + typedef at::CPUGeneratorImpl rand_gen_type; +#else typedef std::mt19937 rand_gen_type; +#endif + CPUContext() {} explicit CPUContext(const DeviceOption& option) : random_seed_(option.has_random_seed() ? option.random_seed() : 1701), @@ -66,11 +78,11 @@ class CAFFE2_API CPUContext final : public BaseContext { inline void FinishDeviceComputation() override {} - inline rand_gen_type& RandGenerator() { + inline rand_gen_type* RandGenerator() { if (!random_generator_.get()) { random_generator_.reset(new rand_gen_type(RandSeed())); } - return *random_generator_.get(); + return random_generator_.get(); } inline uint32_t RandSeed() { @@ -119,7 +131,7 @@ class CAFFE2_API CPUContext final : public BaseContext { template inline void - CopyItems(const TypeMeta& meta, size_t n, const void* src, void* dst) { + CopyItems(const TypeMeta meta, size_t n, const void* src, void* dst) { if (meta.copy()) { meta.copy()(src, dst, n); } else { diff --git a/caffe2/core/context_base.h b/caffe2/core/context_base.h index bad6872819ded..dfc1504e2092d 100644 --- a/caffe2/core/context_base.h +++ b/caffe2/core/context_base.h @@ -33,7 +33,7 @@ class BaseContext; * functions in the BaseContext class. * TODO: add docs after this is finalized. */ -class CAFFE2_API BaseContext { +class TORCH_API BaseContext { public: virtual ~BaseContext() noexcept {} @@ -104,7 +104,7 @@ class CAFFE2_API BaseContext { } void CopyItemsSameDevice( - const caffe2::TypeMeta& meta, + const caffe2::TypeMeta meta, size_t n, const void* src, void* dst) { @@ -117,7 +117,7 @@ class CAFFE2_API BaseContext { } void CopyItemsFromCPU( - const caffe2::TypeMeta& meta, + const caffe2::TypeMeta meta, size_t n, const void* src, void* dst) { @@ -130,7 +130,7 @@ class CAFFE2_API BaseContext { } void CopyItemsToCPU( - const caffe2::TypeMeta& meta, + const caffe2::TypeMeta meta, size_t n, const void* src, void* dst) { diff --git a/caffe2/core/context_gpu.h b/caffe2/core/context_gpu.h index c0930b1a0e615..7406132f87887 100644 --- a/caffe2/core/context_gpu.h +++ b/caffe2/core/context_gpu.h @@ -279,7 +279,7 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext { template inline void - CopyItems(const TypeMeta& meta, size_t n, const void* src, void* dst) { + CopyItems(const TypeMeta meta, size_t n, const void* src, void* dst) { CAFFE_ENFORCE(!meta.copy(), "CUDAContext requires fundamental types."); CopyBytes(n * meta.itemsize(), src, dst); } diff --git a/caffe2/core/db.h b/caffe2/core/db.h index 2d04b3c224cae..97657793a70aa 100644 --- a/caffe2/core/db.h +++ b/caffe2/core/db.h @@ -19,7 +19,7 @@ enum Mode { READ, WRITE, NEW }; /** * An abstract class for the cursor of the database while reading. */ -class CAFFE2_API Cursor { +class TORCH_API Cursor { public: Cursor() {} virtual ~Cursor() {} @@ -60,7 +60,7 @@ class CAFFE2_API Cursor { /** * An abstract class for the current database transaction while writing. */ -class CAFFE2_API Transaction { +class TORCH_API Transaction { public: Transaction() {} virtual ~Transaction() {} @@ -79,7 +79,7 @@ class CAFFE2_API Transaction { /** * An abstract class for accessing a database of key-value pairs. */ -class CAFFE2_API DB { +class TORCH_API DB { public: DB(const string& /*source*/, Mode mode) : mode_(mode) {} virtual ~DB() {} @@ -143,7 +143,7 @@ inline bool DBExists(const string& db_type, const string& full_db_name) { /** * A reader wrapper for DB that also allows us to serialize it. */ -class CAFFE2_API DBReader { +class TORCH_API DBReader { public: friend class DBReaderSerializer; DBReader() {} @@ -296,7 +296,7 @@ class CAFFE2_API DBReader { C10_DISABLE_COPY_AND_ASSIGN(DBReader); }; -class CAFFE2_API DBReaderSerializer : public BlobSerializerBase { +class TORCH_API DBReaderSerializer : public BlobSerializerBase { public: /** * Serializes a DBReader. Note that this blob has to contain DBReader, @@ -309,7 +309,7 @@ class CAFFE2_API DBReaderSerializer : public BlobSerializerBase { BlobSerializerBase::SerializationAcceptor acceptor) override; }; -class CAFFE2_API DBReaderDeserializer : public BlobDeserializerBase { +class TORCH_API DBReaderDeserializer : public BlobDeserializerBase { public: void Deserialize(const BlobProto& proto, Blob* blob) override; }; diff --git a/caffe2/core/distributions_stubs.h b/caffe2/core/distributions_stubs.h new file mode 100644 index 0000000000000..cb5b43be03790 --- /dev/null +++ b/caffe2/core/distributions_stubs.h @@ -0,0 +1,75 @@ +#ifndef CAFFE2_CORE_DISTRIBUTIONS_STUBS_H_ +#define CAFFE2_CORE_DISTRIBUTIONS_STUBS_H_ + +#include + +/** + * This file provides distributions compatible with + * ATen/core/DistributionsHelper.h but backed with the std RNG implementation + * instead of the ATen one. + * + * Caffe2 mobile builds currently do not depend on all of ATen so this is + * required to allow using the faster ATen RNG for normal builds but keep the + * build size small on mobile. RNG performance typically doesn't matter on + * mobile builds since the models are small and rarely using random + * initialization. + */ + +namespace at { +namespace { + +template +struct distribution_adapter { + template + C10_HOST_DEVICE inline distribution_adapter(Args... args) + : distribution_(std::forward(args)...) {} + + template + C10_HOST_DEVICE inline R operator()(RNG generator) { + return distribution_(*generator); + } + + private: + T distribution_; +}; + +template +struct uniform_int_from_to_distribution + : distribution_adapter> { + C10_HOST_DEVICE inline uniform_int_from_to_distribution( + uint64_t range, + int64_t base) + : distribution_adapter>( + base, + // std is inclusive, at is exclusive + base + range - 1) {} +}; + +template +using uniform_real_distribution = + distribution_adapter>; + +template +using normal_distribution = + distribution_adapter>; + +template +using bernoulli_distribution = + distribution_adapter; + +template +using exponential_distribution = + distribution_adapter>; + +template +using cauchy_distribution = + distribution_adapter>; + +template +using lognormal_distribution = + distribution_adapter>; + +} // namespace +} // namespace at + +#endif // CAFFE2_CORE_DISTRIBUTIONS_STUBS_H_ diff --git a/caffe2/core/event.cc b/caffe2/core/event.cc index b643385709358..919ff11b0aa1a 100644 --- a/caffe2/core/event.cc +++ b/caffe2/core/event.cc @@ -2,19 +2,19 @@ namespace caffe2 { -CAFFE2_API EventCreateFunction Event::event_creator_[MaxDeviceTypes]; -CAFFE2_API EventRecordFunction Event::event_recorder_[MaxDeviceTypes]; -CAFFE2_API EventWaitFunction +TORCH_API EventCreateFunction Event::event_creator_[MaxDeviceTypes]; +TORCH_API EventRecordFunction Event::event_recorder_[MaxDeviceTypes]; +TORCH_API EventWaitFunction Event::event_waiter_[MaxDeviceTypes][MaxDeviceTypes]; -CAFFE2_API EventFinishFunction Event::event_finisher_[MaxDeviceTypes]; +TORCH_API EventFinishFunction Event::event_finisher_[MaxDeviceTypes]; -CAFFE2_API EventQueryFunction Event::event_querier_[MaxDeviceTypes]; -CAFFE2_API EventErrorMessageFunction +TORCH_API EventQueryFunction Event::event_querier_[MaxDeviceTypes]; +TORCH_API EventErrorMessageFunction Event::event_err_msg_getter_[MaxDeviceTypes]; -CAFFE2_API EventSetFinishedFunction +TORCH_API EventSetFinishedFunction Event::event_finished_setter_[MaxDeviceTypes]; -CAFFE2_API EventResetFunction Event::event_resetter_[MaxDeviceTypes]; -CAFFE2_API EventSetCallbackFunction +TORCH_API EventResetFunction Event::event_resetter_[MaxDeviceTypes]; +TORCH_API EventSetCallbackFunction Event::event_callback_setter_[MaxDeviceTypes]; namespace { diff --git a/caffe2/core/event.h b/caffe2/core/event.h index 77e3b19175046..0bbb701ecb9e1 100644 --- a/caffe2/core/event.h +++ b/caffe2/core/event.h @@ -55,7 +55,7 @@ typedef void (*EventResetFunction)(Event*); typedef std::function EventCallbackFunction; typedef void (*EventSetCallbackFunction)(Event*, EventCallbackFunction); -class CAFFE2_API Event { +class TORCH_API Event { public: explicit Event(const DeviceOption& option) : event_(), type_(option.device_type()), option_(option) { diff --git a/caffe2/core/export_caffe2_op_to_c10.h b/caffe2/core/export_caffe2_op_to_c10.h index e73b354140c4d..814ee05c7e7ac 100644 --- a/caffe2/core/export_caffe2_op_to_c10.h +++ b/caffe2/core/export_caffe2_op_to_c10.h @@ -8,6 +8,8 @@ #include #include #include +#include +#include #include namespace caffe2 { @@ -178,7 +180,7 @@ inline FunctionSchema make_function_schema_for_c10(const char* schema_str) { #define C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(OperatorName) \ namespace caffe2 { \ namespace _c10_ops { \ - CAFFE2_API const FunctionSchema& schema_##OperatorName(); \ + TORCH_API const FunctionSchema& schema_##OperatorName(); \ } \ } @@ -191,48 +193,52 @@ inline FunctionSchema make_function_schema_for_c10(const char* schema_str) { ::caffe2::detail::make_function_schema_for_c10(OperatorSchema); \ return schema; \ } \ + TORCH_LIBRARY_FRAGMENT(_caffe2, m) { \ + m.def(::caffe2::detail::make_function_schema_for_c10(OperatorSchema)); \ + } \ } \ } #define C10_EXPORT_CAFFE2_OP_TO_C10_CPU_KERNEL_ONLY( \ OperatorName, OperatorClass) \ /* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */ \ - static auto registry_##OperatorName##_##__COUNTER__ = \ - ::c10::RegisterOperators().op( \ - ::caffe2::_c10_ops::schema_##OperatorName(), \ - ::c10::RegisterOperators::options() \ - .kernel<&::caffe2::detail::call_caffe2_op_from_c10< \ - ::caffe2::_c10_ops::schema_##OperatorName, \ - OperatorClass>>(::c10::DispatchKey::CPU)); - -#define C10_EXPORT_CAFFE2_OP_TO_C10_CPU( \ - OperatorName, OperatorSchema, OperatorClass) \ - C10_EXPORT_CAFFE2_OP_TO_C10_SCHEMA_ONLY(OperatorName, OperatorSchema) \ + TORCH_LIBRARY_IMPL(_caffe2, CPU, m) { \ + m.impl("_caffe2::" #OperatorName, \ + torch::CppFunction::makeFromBoxedFunction< \ + ::caffe2::detail::call_caffe2_op_from_c10< \ + ::caffe2::_c10_ops::schema_##OperatorName, \ + OperatorClass>>()); \ + } + +#define C10_EXPORT_CAFFE2_OP_TO_C10_CPU( \ + OperatorName, OperatorSchema, OperatorClass) \ + C10_EXPORT_CAFFE2_OP_TO_C10_SCHEMA_ONLY(OperatorName, OperatorSchema) \ C10_EXPORT_CAFFE2_OP_TO_C10_CPU_KERNEL_ONLY(OperatorName, OperatorClass) #define C10_EXPORT_CAFFE2_OP_TO_C10_CUDA(OperatorName, OperatorClass) \ /* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */ \ - static auto registry_##OperatorName##_##__COUNTER__ = \ - ::c10::RegisterOperators().op( \ - ::caffe2::_c10_ops::schema_##OperatorName(), \ - ::c10::RegisterOperators::options() \ - .kernel<&::caffe2::detail::call_caffe2_op_from_c10< \ - ::caffe2::_c10_ops::schema_##OperatorName, \ - OperatorClass>>(::c10::DispatchKey::CUDA)); + TORCH_LIBRARY_IMPL(_caffe2, CUDA, m) { \ + m.impl("_caffe2::" #OperatorName, \ + torch::CppFunction::makeFromBoxedFunction< \ + ::caffe2::detail::call_caffe2_op_from_c10< \ + ::caffe2::_c10_ops::schema_##OperatorName, \ + OperatorClass>>()); \ + } + // You should never manually call the C10_EXPORT_CAFFE2_OP_TO_C10_HIP macro . // The C10_EXPORT_CAFFE2_OP_TO_C10_CUDA macro from above will be automatically // rewritten to C10_EXPORT_CAFFE2_OP_TO_C10_HIP by hipify . #define C10_EXPORT_CAFFE2_OP_TO_C10_HIP(OperatorName, OperatorClass) \ /* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */ \ - static auto registry_##OperatorName##_##__COUNTER__ = \ - ::c10::RegisterOperators().op( \ - ::caffe2::_c10_ops::schema_##OperatorName(), \ - ::c10::RegisterOperators() \ - .options() \ - .kernel<&::caffe2::detail::call_caffe2_op_from_c10< \ - ::caffe2::_c10_ops::schema_##OperatorName, \ - OperatorClass>>(::c10::DispatchKey::HIP)); + TORCH_LIBRARY_IMPL(_caffe2, HIP, m) { \ + m.impl("_caffe2::" #OperatorName, \ + torch::CppFunction::makeFromBoxedFunction< \ + ::caffe2::detail::call_caffe2_op_from_c10< \ + ::caffe2::_c10_ops::schema_##OperatorName, \ + OperatorClass>>()); \ + } + #else // Don't use c10 dispatcher on mobile because of binary size diff --git a/caffe2/core/graph.h b/caffe2/core/graph.h index 6162b089ace16..dfee4b7deade6 100644 --- a/caffe2/core/graph.h +++ b/caffe2/core/graph.h @@ -16,7 +16,7 @@ namespace transform { /** * Graph representation of an operator. */ -struct CAFFE2_API Node { +struct TORCH_API Node { public: // Empty constructor for resize Node() {} @@ -45,7 +45,7 @@ struct CAFFE2_API Node { /** * Graph representation of a Netdef. */ -struct CAFFE2_API Graph { +struct TORCH_API Graph { public: /** * Given a subgraph, gets all of the parents of the subgraph, as well as @@ -155,7 +155,7 @@ struct CAFFE2_API Graph { // Adds an operator def to a netdef. // Returns the ptr, if you want to add anything extra (such as device_option) -CAFFE2_API OperatorDef* AddOp( +TORCH_API OperatorDef* AddOp( NetDef* netdef_ptr, string op_type, std::vector inputs, @@ -168,12 +168,12 @@ CAFFE2_API OperatorDef* AddOp( * For example, if we wanted to match an operator to Conv or FC, we can give: * "Conv|FC" as the type() of that op. */ -CAFFE2_API bool MatchStrings(string p, string s); +TORCH_API bool MatchStrings(string p, string s); /** * This ensures that each named arg that exists in the pattern exists in g_op, * is equal in value. */ -CAFFE2_API bool MatchArguments(const OperatorDef& p_op, const OperatorDef& g_op); +TORCH_API bool MatchArguments(const OperatorDef& p_op, const OperatorDef& g_op); } // namespace caffe2 diff --git a/caffe2/core/init.h b/caffe2/core/init.h index 634b6012ef053..8d0fbd3f1557e 100644 --- a/caffe2/core/init.h +++ b/caffe2/core/init.h @@ -8,7 +8,7 @@ namespace caffe2 { namespace internal { -class CAFFE2_API Caffe2InitializeRegistry { +class TORCH_API Caffe2InitializeRegistry { public: typedef bool (*InitFunction)(int*, char***); // Registry() is defined in .cpp file to make registration work across @@ -96,12 +96,12 @@ class CAFFE2_API Caffe2InitializeRegistry { }; } // namespace internal -CAFFE2_API bool unsafeRunCaffe2InitFunction( +TORCH_API bool unsafeRunCaffe2InitFunction( const char* name, int* pargc = nullptr, char*** pargv = nullptr); -class CAFFE2_API InitRegisterer { +class TORCH_API InitRegisterer { public: InitRegisterer( internal::Caffe2InitializeRegistry::InitFunction function, @@ -128,9 +128,9 @@ class CAFFE2_API InitRegisterer { /** * @brief Determine whether GlobalInit has already been run */ -CAFFE2_API bool GlobalInitAlreadyRun(); +TORCH_API bool GlobalInitAlreadyRun(); -class CAFFE2_API GlobalInitIsCalledGuard { +class TORCH_API GlobalInitIsCalledGuard { public: GlobalInitIsCalledGuard() { if (!GlobalInitAlreadyRun()) { @@ -165,7 +165,7 @@ class CAFFE2_API GlobalInitIsCalledGuard { * * GlobalInit is also thread-safe and can be called concurrently. */ -CAFFE2_API bool GlobalInit(int* pargc, char*** argv); +TORCH_API bool GlobalInit(int* pargc, char*** argv); /** * @brief Initialize the global environment without command line arguments @@ -174,6 +174,6 @@ CAFFE2_API bool GlobalInit(int* pargc, char*** argv); * On mobile devices, use this global init, since we cannot pass the * command line options to caffe2, no arguments are passed. */ -CAFFE2_API bool GlobalInit(); +TORCH_API bool GlobalInit(); } // namespace caffe2 #endif // CAFFE2_CORE_INIT_H_ diff --git a/caffe2/core/macros.h.in b/caffe2/core/macros.h.in index 60525f0f4be68..dd9f9902be1ff 100644 --- a/caffe2/core/macros.h.in +++ b/caffe2/core/macros.h.in @@ -54,10 +54,16 @@ static_assert( // Useful build settings that are recorded in the compiled binary #define CAFFE2_BUILD_STRINGS { \ + {"TORCH_VERSION", "${TORCH_VERSION}"}, \ + {"CXX_COMPILER", "${CMAKE_CXX_COMPILER}"}, \ {"CXX_FLAGS", "${CMAKE_CXX_FLAGS}"}, \ {"BUILD_TYPE", "${CMAKE_BUILD_TYPE}"}, \ - {"BLAS", "${BLAS}"}, \ + {"BLAS_INFO", "${BLAS_INFO}"}, \ + {"LAPACK_INFO", "${LAPACK_INFO}"}, \ {"USE_CUDA", "${USE_CUDA}"}, \ + {"CUDA_VERSION", "${CUDA_VERSION}"}, \ + {"USE_CUDNN", "${USE_CUDNN}"}, \ + {"CUDNN_VERSION", "${CUDNN_VERSION}"}, \ {"USE_NCCL", "${USE_NCCL}"}, \ {"USE_MPI", "${USE_MPI}"}, \ {"USE_GFLAGS", "${USE_GFLAGS}"}, \ diff --git a/caffe2/core/memonger.cc b/caffe2/core/memonger.cc index 994c97c6c6925..9aca98f9eb1a3 100644 --- a/caffe2/core/memonger.cc +++ b/caffe2/core/memonger.cc @@ -6,6 +6,19 @@ #include "caffe2/utils/proto_utils.h" namespace caffe2 { + +void run_schema_check(const NetDef& net) { + for (auto& op : net.op()) { + auto* schema = OpSchemaRegistry::Schema(op.type()); + if (schema) { + CAFFE_ENFORCE( + schema->Verify(op), + "Operator def did not pass schema checking: ", + ProtoDebugString(op)); + } + } +} + namespace memonger { NetDef optimize_inference_net( @@ -16,6 +29,10 @@ NetDef optimize_inference_net( return net; } + // Memonger modifies the graph. Do an early schema check here to make sure + // the operators are valid + run_schema_check(net); + std::vector ops; for (auto& op : net.op()) { if (op.type() == "RecurrentNetwork") { @@ -138,13 +155,20 @@ class ComputeBlobRecyclingForDag { const string& namescope, const std::unordered_set& dont_share_blob_names, const std::unordered_map>& blob_shapes) { + + // Memonger modifies the graph. Do an early schema check here to make sure + // the operators are valid + run_schema_check(net); // Construct the set of input blobs. std::unordered_set heads_blobs_set(heads.begin(), heads.end()); // Construct the set of output blobs we want to optimize. + // Blobs not eligible for sharing are filtered out for (const int op_index : op_indices) { for (const auto& output : net.op(op_index).output()) { - optim_op_outputs_.insert(output); + if (has_key(shareable_blob_names, output) && !has_key(dont_share_blob_names, output)) { + optim_op_outputs_.insert(output); + } } } @@ -182,10 +206,13 @@ class ComputeBlobRecyclingForDag { } // The main recursive call. Here we do start DFS in the operator graph - // from the input blobs. + // from the input blobs. Note that the input ordering does not indicate + // operator graph ordering. To avoid traversing children operators first, + // traversal begins from root ops and then recursively children ops are + // visited. for (const auto& input_blob : heads) { for (const int op_index : blob_to_ops_[input_blob]) { - if (!op_visited_[op_index]) { + if (!op_visited_[op_index] && !op_inputs_[op_index]) { vector> free_blobs; std::unordered_set tokens{tokens_counter_++}; process_op( @@ -254,6 +281,12 @@ class ComputeBlobRecyclingForDag { apply_recurrent_blob_assignments(optimized_net.mutable_op(i)); } + // Special handling for AsyncIf ops, where internal nets can + // refer to memongered blobs + if (optimized_net.op(i).type() == "AsyncIf") { + apply_asyncif_blob_assignments(optimized_net.mutable_op(i)); + } + for (int j = 0; j < optimized_net.op(i).input_size(); ++j) { const string& input_name = get_blob_or_mapped_blob(optimized_net.op(i).input(j)); @@ -313,6 +346,39 @@ class ComputeBlobRecyclingForDag { } } + void apply_asyncif_blob_assignments(OperatorDef* op) { + for (int i = 0; i < op->arg_size(); i++) { + Argument* arg = op->mutable_arg(i); + const string& name = arg->name(); + if (name == "then_net" || name == "else_net") { + NetDef* step_net_ref = arg->mutable_n(); + NetDef optimized_net = apply_assignments(*step_net_ref); + + // update external inputs and outputs mappings as well + // for this internal net + std::vector optim_external_inputs; + for (auto& blob_name : optimized_net.external_input()) { + optim_external_inputs.push_back(get_blob_or_mapped_blob(blob_name)); + } + optimized_net.mutable_external_input()->Clear(); + for (const auto& blob_name : optim_external_inputs) { + optimized_net.add_external_input(blob_name); + } + + std::vector optim_external_outputs; + for (auto& blob_name : optimized_net.external_output()) { + optim_external_outputs.push_back(get_blob_or_mapped_blob(blob_name)); + } + optimized_net.mutable_external_output()->Clear(); + for (const auto& blob_name : optim_external_outputs) { + optimized_net.add_external_output(blob_name); + } + + step_net_ref->CopyFrom(optimized_net); + } + } + } + template inline bool has_key(const std::unordered_map& in_map, const K& key) { return in_map.find(key) != in_map.end(); diff --git a/caffe2/core/memonger.h b/caffe2/core/memonger.h index 83270fa26cb8b..b015a23f02854 100644 --- a/caffe2/core/memonger.h +++ b/caffe2/core/memonger.h @@ -8,13 +8,17 @@ #include "caffe2/proto/caffe2_pb.h" namespace caffe2 { + +// op schema check +TORCH_API void run_schema_check(const NetDef& net); + namespace memonger { -CAFFE2_API NetDef optimize_inference_net( +TORCH_API NetDef optimize_inference_net( const NetDef& net, const std::set& static_blobs); -CAFFE2_API NetDef compute_blob_recycling_for_dag( +TORCH_API NetDef compute_blob_recycling_for_dag( const NetDef& net, const std::vector& heads, const std::vector& op_indices, diff --git a/caffe2/core/module.h b/caffe2/core/module.h index 88f8730d675a7..bb5dceb22cad5 100644 --- a/caffe2/core/module.h +++ b/caffe2/core/module.h @@ -23,7 +23,7 @@ namespace caffe2 { * different modules. Currently, we only store the name and a simple * description of what this module does. */ -class CAFFE2_API ModuleSchema { +class TORCH_API ModuleSchema { public: ModuleSchema(const char* name, const char* description); }; @@ -41,12 +41,12 @@ class CAFFE2_API ModuleSchema { * the reason we do not include ".so" is for cross-platform compatibility * on platforms like mac os. */ -CAFFE2_API const CaffeMap& CurrentModules(); +TORCH_API const CaffeMap& CurrentModules(); /** * @brief Checks whether a module is already present in the current binary. */ -CAFFE2_API bool HasModule(const string& name); +TORCH_API bool HasModule(const string& name); /** * @brief Load a module. @@ -56,7 +56,7 @@ CAFFE2_API bool HasModule(const string& name); * full path option to only experimental modules. * filename: (optional) a filename that serves as a hint to load the module. */ -CAFFE2_API void LoadModule(const string& name, const string& filename=""); +TORCH_API void LoadModule(const string& name, const string& filename=""); #define CAFFE2_MODULE(name, description) \ diff --git a/caffe2/core/net.h b/caffe2/core/net.h index 49333b1afe853..0726d8e8c6c90 100644 --- a/caffe2/core/net.h +++ b/caffe2/core/net.h @@ -34,7 +34,7 @@ class Workspace; // Net is a thin struct that owns all the operators together with the operator // contexts. -class CAFFE2_API NetBase : public Observable { +class TORCH_API NetBase : public Observable { public: NetBase(const std::shared_ptr& net_def, Workspace* ws); virtual ~NetBase() noexcept {} @@ -135,7 +135,7 @@ class CAFFE2_API NetBase : public Observable { C10_DISABLE_COPY_AND_ASSIGN(NetBase); }; -class CAFFE2_API ExecutorHelper { +class TORCH_API ExecutorHelper { public: ExecutorHelper() {} virtual TaskThreadPoolBase* GetPool(const DeviceOption& option) const; @@ -161,14 +161,14 @@ C10_DECLARE_REGISTRY( * created net object to the workspace's net map, while this function returns * a standalone net object. */ -CAFFE2_API unique_ptr CreateNet(const NetDef& net_def, Workspace* ws); -CAFFE2_API unique_ptr CreateNet( +TORCH_API unique_ptr CreateNet(const NetDef& net_def, Workspace* ws); +TORCH_API unique_ptr CreateNet( const std::shared_ptr& net_def, Workspace* ws); -CAFFE2_API void AddGlobalNetObserverCreator(NetObserverCreator creator); +TORCH_API void AddGlobalNetObserverCreator(NetObserverCreator creator); -CAFFE2_API void ClearGlobalNetObservers(); +TORCH_API void ClearGlobalNetObservers(); } // namespace caffe2 diff --git a/caffe2/core/net_async_base.h b/caffe2/core/net_async_base.h index 20e3a69826cdc..b80ef9872c8b6 100644 --- a/caffe2/core/net_async_base.h +++ b/caffe2/core/net_async_base.h @@ -57,13 +57,13 @@ struct ExecutionOptions { bool run_root_tasks_inline_ = false; }; -struct CAFFE2_API AsyncNetCancelled : public std::exception { +struct TORCH_API AsyncNetCancelled : public std::exception { const char* what() const noexcept override { return "Cancelled"; } }; -class CAFFE2_API AsyncNetBase : public NetBase { +class TORCH_API AsyncNetBase : public NetBase { public: AsyncNetBase(const std::shared_ptr& net_def, Workspace* ws); ~AsyncNetBase() override; diff --git a/caffe2/core/net_async_scheduling.h b/caffe2/core/net_async_scheduling.h index 3751669933dd9..7a557ceb8a1ce 100644 --- a/caffe2/core/net_async_scheduling.h +++ b/caffe2/core/net_async_scheduling.h @@ -5,7 +5,7 @@ namespace caffe2 { -class CAFFE2_API AsyncSchedulingNet : public AsyncNetBase { +class TORCH_API AsyncSchedulingNet : public AsyncNetBase { public: AsyncSchedulingNet( const std::shared_ptr& net_def, diff --git a/caffe2/core/net_async_tracing.h b/caffe2/core/net_async_tracing.h index 43665b1d80ab8..33e91c7a007fc 100644 --- a/caffe2/core/net_async_tracing.h +++ b/caffe2/core/net_async_tracing.h @@ -29,7 +29,7 @@ C10_DECLARE_int(caffe2_net_async_tracing_nth); namespace caffe2 { namespace tracing { -struct CAFFE2_API TracerEvent { +struct TORCH_API TracerEvent { int op_id_ = -1; int task_id_ = -1; int stream_id_ = -1; @@ -70,7 +70,7 @@ struct TracingConfig { int64_t trace_for_n_ms = 1000; // 1sec }; -class CAFFE2_API Tracer { +class TORCH_API Tracer { public: Tracer( const NetBase* net, @@ -111,7 +111,7 @@ class CAFFE2_API Tracer { friend class TracerGuard; }; -class CAFFE2_API TracerGuard { +class TORCH_API TracerGuard { public: TracerGuard() {} @@ -142,16 +142,16 @@ class CAFFE2_API TracerGuard { // Extract the shard id from name of the form "...shard:123..." // Return -1 if there is no shard found -CAFFE2_API int extractShardId(const std::string& name); +TORCH_API int extractShardId(const std::string& name); // Check if the net name is white-listed for tracing (specified via a command // line flag) -CAFFE2_API bool isTraceableNetName(const std::string& net_name); +TORCH_API bool isTraceableNetName(const std::string& net_name); -CAFFE2_API std::shared_ptr create( +TORCH_API std::shared_ptr create( const NetBase* net, const std::string& net_name); -CAFFE2_API bool startIter(const std::shared_ptr& tracer); +TORCH_API bool startIter(const std::shared_ptr& tracer); } // namespace tracing diff --git a/caffe2/core/net_parallel.h b/caffe2/core/net_parallel.h index 756637f1bc537..60030c3a1d4e3 100644 --- a/caffe2/core/net_parallel.h +++ b/caffe2/core/net_parallel.h @@ -10,7 +10,7 @@ namespace caffe2 { class ParallelNetExecutorHelper; -class CAFFE2_API ParallelNet : public NetBase { +class TORCH_API ParallelNet : public NetBase { public: ParallelNet(const std::shared_ptr& net_def, Workspace* ws); diff --git a/caffe2/core/net_simple.h b/caffe2/core/net_simple.h index 5b8bc29be4dfa..c6b25eab4c57c 100644 --- a/caffe2/core/net_simple.h +++ b/caffe2/core/net_simple.h @@ -16,7 +16,7 @@ namespace caffe2 { // This is the very basic structure you need to run a network - all it // does is simply to run everything in sequence. If you want more fancy control // such as a DAG-like execution, check out other better net implementations. -class CAFFE2_API SimpleNet : public NetBase { +class TORCH_API SimpleNet : public NetBase { public: SimpleNet(const std::shared_ptr& net_def, Workspace* ws); bool SupportsAsync() override { diff --git a/caffe2/core/nomnigraph/include/nomnigraph/Graph/Graph.h b/caffe2/core/nomnigraph/include/nomnigraph/Graph/Graph.h index 5033417d8d78d..eb6f1d7c4d95d 100644 --- a/caffe2/core/nomnigraph/include/nomnigraph/Graph/Graph.h +++ b/caffe2/core/nomnigraph/include/nomnigraph/Graph/Graph.h @@ -40,11 +40,11 @@ class Node; // \brief Edge within a Graph. template -class Edge : public StorageType { +class Edge : public ::StorageType { public: using NodeRef = typename Graph::NodeRef; Edge(NodeRef tail, NodeRef head, U... args) - : StorageType(std::forward(args)...), + : ::StorageType(std::forward(args)...), tail_(tail), head_(head) { DEBUG_PRINT("Creating instance of Edge: %p\n", this); @@ -74,17 +74,17 @@ class Edge : public StorageType { // \brief Node within a Graph. template -class Node : public StorageType, public Notifier> { +class Node : public ::StorageType, public Notifier> { public: using NodeRef = typename Graph::NodeRef; using EdgeRef = typename Graph::EdgeRef; /// \brief Create a node with data. - explicit Node(T&& data) : StorageType(std::move(data)) { + explicit Node(T&& data) : ::StorageType(std::move(data)) { DEBUG_PRINT("Creating instance of Node: %p\n", this); } /// \brief Create an empty node. - explicit Node() : StorageType() {} + explicit Node() : ::StorageType() {} Node(Node&&) = default; Node(const Node&) = delete; Node& operator=(const Node&) = delete; diff --git a/caffe2/core/nomnigraph/include/nomnigraph/Representations/Compiler.h b/caffe2/core/nomnigraph/include/nomnigraph/Representations/Compiler.h index d8e9c1090b3ed..d5a019fc15e1a 100644 --- a/caffe2/core/nomnigraph/include/nomnigraph/Representations/Compiler.h +++ b/caffe2/core/nomnigraph/include/nomnigraph/Representations/Compiler.h @@ -8,7 +8,7 @@ namespace nom { namespace repr { -class CAFFE2_API Value { +class TORCH_API Value { public: enum class ValueKind { Value, Instruction, Data }; Value(ValueKind K) : kind_(K) {} @@ -22,7 +22,7 @@ class CAFFE2_API Value { const ValueKind kind_; }; -class CAFFE2_API Data : public Value { +class TORCH_API Data : public Value { public: Data() : Value(ValueKind::Data) {} static bool classof(const Value* V) { @@ -41,7 +41,7 @@ class CAFFE2_API Data : public Value { size_t version_ = 0; }; -class CAFFE2_API Instruction : public Value { +class TORCH_API Instruction : public Value { public: /// \brief All the different types of execution. enum class Opcode { @@ -66,7 +66,7 @@ class CAFFE2_API Instruction : public Value { Opcode op_; }; -class CAFFE2_API Terminator : public Instruction { +class TORCH_API Terminator : public Instruction { public: Terminator(Instruction::Opcode op) : Instruction(op) {} @@ -80,17 +80,17 @@ class CAFFE2_API Terminator : public Instruction { } }; -class CAFFE2_API Branch : public Terminator { +class TORCH_API Branch : public Terminator { public: Branch() : Terminator(Instruction::Opcode::Branch) {} }; -class CAFFE2_API Return : public Terminator { +class TORCH_API Return : public Terminator { public: Return() : Terminator(Instruction::Opcode::Return) {} }; -class CAFFE2_API Phi : public Instruction { +class TORCH_API Phi : public Instruction { public: Phi() : Instruction(Instruction::Opcode::Phi) {} }; diff --git a/caffe2/core/nomnigraph/include/nomnigraph/Representations/NeuralNet.h b/caffe2/core/nomnigraph/include/nomnigraph/Representations/NeuralNet.h index 812fea7be7c3b..e3eb90afc4f5f 100644 --- a/caffe2/core/nomnigraph/include/nomnigraph/Representations/NeuralNet.h +++ b/caffe2/core/nomnigraph/include/nomnigraph/Representations/NeuralNet.h @@ -41,7 +41,7 @@ class NeuralNetData; /// a saved void* pointer for external use. Derived classes /// add richer semantics to the annotation and it is encouraged /// to use them. -class CAFFE2_API Annotation { +class TORCH_API Annotation { public: enum class AnnotationKind { Generic, Caffe2 }; @@ -57,7 +57,7 @@ class CAFFE2_API Annotation { const AnnotationKind kind_; }; -class CAFFE2_API NeuralNetOperator : public Instruction { +class TORCH_API NeuralNetOperator : public Instruction { public: /// Discriminator for LLVM-style RTTI (isa<>) enum class NNKind { @@ -132,7 +132,7 @@ class CAFFE2_API NeuralNetOperator : public Instruction { std::unique_ptr extraAnnotation_; }; -class CAFFE2_API NeuralNetData : public Data { +class TORCH_API NeuralNetData : public Data { public: /// Discriminator for LLVM-style RTTI (isa<>) enum class NNDataKind { Generic, Tensor }; @@ -155,7 +155,7 @@ class CAFFE2_API NeuralNetData : public Data { NNDataKind kind_; }; -class CAFFE2_API Tensor : public NeuralNetData { +class TORCH_API Tensor : public NeuralNetData { public: enum class DataType { Generic, Float, Half, Int8 }; enum class Layout { Generic, NCHW, NHWC }; @@ -208,21 +208,21 @@ class CAFFE2_API Tensor : public NeuralNetData { #include "nomnigraph/Generated/OpClasses.h" -class CAFFE2_API While : public NeuralNetOperator { +class TORCH_API While : public NeuralNetOperator { public: While() : NeuralNetOperator(NNKind::While, Opcode::Branch) {} NOMNIGRAPH_DEFINE_NN_RTTI(While); ~While() {} }; -class CAFFE2_API NNPhi : public NeuralNetOperator { +class TORCH_API NNPhi : public NeuralNetOperator { public: NNPhi() : NeuralNetOperator(NNKind::NNPhi, Opcode::Phi) {} NOMNIGRAPH_DEFINE_NN_RTTI(NNPhi); ~NNPhi() {} }; -class CAFFE2_API GenericOperator : public NeuralNetOperator { +class TORCH_API GenericOperator : public NeuralNetOperator { public: GenericOperator() : NeuralNetOperator(NNKind::GenericOperator) {} GenericOperator(std::string name) @@ -244,7 +244,7 @@ using NNGraph = nom::Graph>; using NNSubgraph = nom::Subgraph>; using NNCFGraph = nom::repr::ControlFlowGraph; -struct CAFFE2_API NNModule { +struct TORCH_API NNModule { NNGraph dataFlow; NNCFGraph controlFlow; std::unordered_set inputs; @@ -464,41 +464,41 @@ NNGraph::NodeRef convertNode(NNGraph& g, NNGraph::NodeRef node) { } /// NeuralNetData specific helpers. -CAFFE2_API bool hasProducer(NNGraph::NodeRef n); -CAFFE2_API NNGraph::NodeRef getProducer(NNGraph::NodeRef n); -CAFFE2_API bool hasConsumer(NNGraph::NodeRef n); -CAFFE2_API std::vector getConsumers(NNGraph::NodeRef n); +TORCH_API bool hasProducer(NNGraph::NodeRef n); +TORCH_API NNGraph::NodeRef getProducer(NNGraph::NodeRef n); +TORCH_API bool hasConsumer(NNGraph::NodeRef n); +TORCH_API std::vector getConsumers(NNGraph::NodeRef n); -CAFFE2_API bool hasInputs(NNGraph::NodeRef n); -CAFFE2_API std::vector getInputs(NNGraph::NodeRef n); -CAFFE2_API std::vector getOutputs(NNGraph::NodeRef n); +TORCH_API bool hasInputs(NNGraph::NodeRef n); +TORCH_API std::vector getInputs(NNGraph::NodeRef n); +TORCH_API std::vector getOutputs(NNGraph::NodeRef n); -CAFFE2_API std::set getInputs(const NNSubgraph& sg); -CAFFE2_API std::set getOutputs(const NNSubgraph& sg); +TORCH_API std::set getInputs(const NNSubgraph& sg); +TORCH_API std::set getOutputs(const NNSubgraph& sg); // Get the name of the node regardless of underlying type. -CAFFE2_API std::string getName(NNGraph::NodeRef n); +TORCH_API std::string getName(NNGraph::NodeRef n); // Replace the producer of the first argument with the second argument -CAFFE2_API void replaceProducer( +TORCH_API void replaceProducer( NNGraph::NodeRef tensorNode, NNGraph::NodeRef newProducer); // Set all consumers of first argument to consume the second argument -CAFFE2_API void replaceAllUsesWith( +TORCH_API void replaceAllUsesWith( NNGraph::NodeRef oldTensorNode, NNGraph::NodeRef newTensorNode); // Set the second argument to consume the inputs of the first argument -CAFFE2_API void replaceAsConsumer( +TORCH_API void replaceAsConsumer( NNGraph::NodeRef oldConsumer, NNGraph::NodeRef newConsumer); // Create an output tensor node -CAFFE2_API NNGraph::NodeRef +TORCH_API NNGraph::NodeRef createOutput(NNModule* nn, NNGraph::NodeRef producer, std::string name); // Hack for windows compiler. template -CAFFE2_API NNGraph::NodeRef createOperator(NNModule* nn, Args... args); +TORCH_API NNGraph::NodeRef createOperator(NNModule* nn, Args... args); // Create an operator template @@ -506,7 +506,7 @@ NNGraph::NodeRef createOperator(NNModule* nn, Args... args) { return nn->dataFlow.createNode(util::make_unique(args...)); } -CAFFE2_API void coalesceInsertedDataDependencies(repr::NNModule* m); +TORCH_API void coalesceInsertedDataDependencies(repr::NNModule* m); template struct C10_EXPORT NodeHelper {}; @@ -517,12 +517,12 @@ using NNMatchPredicate = nom::matcher::MatchPredicate; // Commonly used node predicate. // The node has a single output and the output has a single consumer. -CAFFE2_API bool hasSingleOutputAndConsumer(NNGraph::NodeRef nodeRef); +TORCH_API bool hasSingleOutputAndConsumer(NNGraph::NodeRef nodeRef); // The node has a unique consumer (there may be multiple edges from output // to the single consumer). -CAFFE2_API bool hasUniqueConsumer(NNGraph::NodeRef nodeRef); +TORCH_API bool hasUniqueConsumer(NNGraph::NodeRef nodeRef); -CAFFE2_API NNMatchPredicate matchExternalTensorNode(); +TORCH_API NNMatchPredicate matchExternalTensorNode(); } // namespace nn diff --git a/caffe2/core/nomnigraph/tests/test_util.h b/caffe2/core/nomnigraph/tests/test_util.h index f19f75f75b628..f60e73f9005d6 100644 --- a/caffe2/core/nomnigraph/tests/test_util.h +++ b/caffe2/core/nomnigraph/tests/test_util.h @@ -102,9 +102,9 @@ class TestRandom { * return labelMap; * }); */ -CAFFE2_API nom::Graph createGraph(); +TORCH_API nom::Graph createGraph(); -CAFFE2_API nom::Graph createGraphWithCycle(); +TORCH_API nom::Graph createGraphWithCycle(); std::map BBPrinter(typename nom::repr::NNCFGraph::NodeRef node); @@ -112,9 +112,9 @@ std::map cfgEdgePrinter(typename nom::repr::NNCFGraph: std::map NNPrinter(typename nom::repr::NNGraph::NodeRef node); -CAFFE2_API nom::Graph::NodeRef createTestNode( +TORCH_API nom::Graph::NodeRef createTestNode( nom::Graph& g); -CAFFE2_API std::map TestNodePrinter( +TORCH_API std::map TestNodePrinter( nom::Graph::NodeRef node); #endif // NOM_TESTS_TEST_UTIL_H diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h index ea9ae7892a23e..8b2a6b571098d 100644 --- a/caffe2/core/operator.h +++ b/caffe2/core/operator.h @@ -48,10 +48,10 @@ struct FunctionSchema; namespace caffe2 { -class CAFFE2_API OperatorBase; +class TORCH_API OperatorBase; typedef ObserverBase OperatorObserver; -class CAFFE2_API OperatorBase : public Observable { +class TORCH_API OperatorBase : public Observable { public: explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws); @@ -1246,7 +1246,7 @@ struct DispatchHelper, ExtraArgs...> { template \ struct DispatchHelper, ExtraArgs...> { \ template \ - static bool call(Op* op, const TypeMeta& meta) { \ + static bool call(Op* op, const TypeMeta meta) { \ static_assert( \ !std::is_same::value, \ "GenericTensorImplementation must be the last in TensorTypes list"); \ @@ -1269,7 +1269,7 @@ struct DispatchHelper, ExtraArgs...> { template \ struct DispatchHelper, ExtraArgs...> { \ template \ - static bool call(Op* /* unused */, const TypeMeta& meta) { \ + static bool call(Op* /* unused */, const TypeMeta meta) { \ CAFFE_THROW("Unsupported type of tensor: ", meta.name()); \ } \ template \ @@ -1287,7 +1287,7 @@ struct DispatchHelper, ExtraArgs...> { TensorTypes, \ ExtraArgs...> { \ template \ - static bool call(Op* op, const TypeMeta&) { \ + static bool call(Op* op, const TypeMeta) { \ return op->template DoRunWithOtherType(); \ } \ template \ @@ -1325,9 +1325,9 @@ typedef c10::Registry< std::unique_ptr, const OperatorDef&, Workspace*>* (*RegistryFunction)(); -CAFFE2_API std::map* gDeviceTypeRegistry(); +TORCH_API std::map* gDeviceTypeRegistry(); -struct CAFFE2_API DeviceTypeRegisterer { +struct TORCH_API DeviceTypeRegisterer { explicit DeviceTypeRegisterer(DeviceType type, RegistryFunction func) { if (gDeviceTypeRegistry()->count(type)) { std::cerr << "Device type " << DeviceTypeName(type) @@ -1446,7 +1446,7 @@ C10_DECLARE_REGISTRY( // You should not need to use this class. struct StaticLinkingProtector { StaticLinkingProtector() { - const int registered_ops = CPUOperatorRegistry()->Keys().size(); + const auto registered_ops = CPUOperatorRegistry()->Keys().size(); // Note: this is a check failure instead of an exception, because if // the linking is wrong, Caffe2 won't be able to run properly anyway, // so it's better to fail loud. @@ -1467,7 +1467,7 @@ struct StaticLinkingProtector { // specific engines that only implement a subset of the features required by // the original operator schema. // TODO(jiayq): make more feature-complete exception message. -class CAFFE2_API UnsupportedOperatorFeature : public std::exception { +class TORCH_API UnsupportedOperatorFeature : public std::exception { public: UnsupportedOperatorFeature(const string& msg) : msg_(msg) {} const char* what() const noexcept override { @@ -1488,12 +1488,12 @@ class CAFFE2_API UnsupportedOperatorFeature : public std::exception { // Creates an operator with the given operator definition. // Throws on error and never returns nullptr -CAFFE2_API unique_ptr CreateOperator( +TORCH_API unique_ptr CreateOperator( const OperatorDef& operator_def, Workspace* ws, int net_position = OperatorBase::kNoNetPositionSet); -CAFFE2_API const std::string OpRegistryKey( +TORCH_API const std::string OpRegistryKey( const std::string& op_type, const std::string& engine = ""); @@ -1505,50 +1505,50 @@ using PerOpEnginePrefType = CaffeMap>; // {device_type -> EnginePrefType} using GlobalEnginePrefType = CaffeMap; -CAFFE2_API void SetPerOpEnginePref( +TORCH_API void SetPerOpEnginePref( const PerOpEnginePrefType& per_op_engine_pref); -CAFFE2_API void SetGlobalEnginePref( +TORCH_API void SetGlobalEnginePref( const GlobalEnginePrefType& global_engine_pref); -CAFFE2_API void SetEnginePref( +TORCH_API void SetEnginePref( const PerOpEnginePrefType& per_op_engine_pref, const GlobalEnginePrefType& global_engine_pref); -CAFFE2_API void SetOpEnginePref( +TORCH_API void SetOpEnginePref( const std::string& op_type, const CaffeMap& op_pref); -CAFFE2_API void LoadInt8TensorInfoOfBlob( +TORCH_API void LoadInt8TensorInfoOfBlob( std::vector* scale, std::vector* offset, uint32_t* axis, const Blob* b); -CAFFE2_API TensorShape GetTensorShapeOfBlob(const Blob* b); +TORCH_API TensorShape GetTensorShapeOfBlob(const Blob* b); -CAFFE2_API TensorShapes InferBlobShapesAndTypes( +TORCH_API TensorShapes InferBlobShapesAndTypes( CaffeMap& blob_desc, const vector& nets); -CAFFE2_API TensorShapes InferBlobShapesAndTypesFromWorkspace( +TORCH_API TensorShapes InferBlobShapesAndTypesFromWorkspace( Workspace* ws, const vector& nets); -CAFFE2_API TensorShapes InferBlobShapesAndTypesFromMap( +TORCH_API TensorShapes InferBlobShapesAndTypesFromMap( const CaffeMap>& blob_dimensions, const vector& nets); -CAFFE2_API TensorShapes InferBlobShapesAndTypesFromMap( +TORCH_API TensorShapes InferBlobShapesAndTypesFromMap( const CaffeMap>& blob_dimensions, const CaffeMap& blob_types, const vector& nets); -CAFFE2_API std::map> +TORCH_API std::map> ValidateTensorDevices(OperatorBase& op, const OperatorDef& op_def); // Get a set of registered operator names -CAFFE2_API std::set GetRegisteredOperators(); +TORCH_API std::set GetRegisteredOperators(); // Operator logging capabilities -CAFFE2_API void SetOperatorLogger( +TORCH_API void SetOperatorLogger( std::function tracer); std::function GetOperatorLogger(); diff --git a/caffe2/core/operator_gradient.h b/caffe2/core/operator_gradient.h index b444c285b2dc6..5c8d97a38fd25 100644 --- a/caffe2/core/operator_gradient.h +++ b/caffe2/core/operator_gradient.h @@ -14,7 +14,7 @@ namespace caffe2 { * a sparse blob, its gradient name should be written into indice_ for * the sparse indices and value_ for the values. */ -struct CAFFE2_API GradientWrapper { +struct TORCH_API GradientWrapper { string dense_; string indices_; string values_; @@ -33,7 +33,7 @@ struct CAFFE2_API GradientWrapper { /** * A struct that holds the gradient operators and related gradient maps. */ -struct CAFFE2_API GradientOpsMeta { +struct TORCH_API GradientOpsMeta { vector ops_; vector g_input_; @@ -44,7 +44,7 @@ struct CAFFE2_API GradientOpsMeta { : ops_(ops), g_input_(v) {} }; -class CAFFE2_API GradientMakerBase { +class TORCH_API GradientMakerBase { public: GradientMakerBase( const OperatorDef& def, @@ -256,7 +256,7 @@ class CAFFE2_API GradientMakerBase { * that the gradient computation should not flow through it at all, and throws * an error if it is called. */ -class CAFFE2_API NoGradient : public GradientMakerBase { +class TORCH_API NoGradient : public GradientMakerBase { using GradientMakerBase::GradientMakerBase; vector GetGradientDefs() override { return vector(); @@ -328,7 +328,7 @@ C10_DECLARE_REGISTRY( /** * @brief Gets the GradientOpsMeta for the given operator def. */ -CAFFE2_API GradientOpsMeta GetGradientForOp( +TORCH_API GradientOpsMeta GetGradientForOp( const OperatorDef& def, const vector& g_output); diff --git a/caffe2/core/operator_schema.cc b/caffe2/core/operator_schema.cc index 3009ba4206c0b..9ff8dfd0eaa61 100644 --- a/caffe2/core/operator_schema.cc +++ b/caffe2/core/operator_schema.cc @@ -307,8 +307,8 @@ OpSchema::Arg(const char* name, const char* description, bool required) { } #define DEFINE_STANDARG_ARG(name, str) \ - CAFFE2_API const char* OpSchema::Arg_##name = #str; \ - CAFFE2_API OpSchema& OpSchema::Arg##name(const char* description) { \ + TORCH_API const char* OpSchema::Arg_##name = #str; \ + TORCH_API OpSchema& OpSchema::Arg##name(const char* description) { \ return Arg(#str, description, true); \ } diff --git a/caffe2/core/operator_schema.h b/caffe2/core/operator_schema.h index deca56a5a88d2..00834fa338b30 100644 --- a/caffe2/core/operator_schema.h +++ b/caffe2/core/operator_schema.h @@ -37,7 +37,7 @@ constexpr int kCannotComputeNumOutputs = -1; * OPERATOR_SCHEMA(name) * .NumInputs(2).NumOutputs(1).AllowInplace({{0, 0}}); */ -class CAFFE2_API OpSchema { +class TORCH_API OpSchema { public: OpSchema() : OpSchema("unknown", "unknown", 0) {} OpSchema(const string& type, const string& file, const int line); @@ -339,7 +339,7 @@ class CAFFE2_API OpSchema { return inplace_enforced_(x, y); } - CAFFE2_API friend std::ostream& operator<<(std::ostream& out, const OpSchema& schema); + TORCH_API friend std::ostream& operator<<(std::ostream& out, const OpSchema& schema); const std::vector& args() const { return args_; @@ -457,7 +457,7 @@ class CAFFE2_API OpSchema { /** * @brief A registry to hold all the operator schemas. */ -class CAFFE2_API OpSchemaRegistry { +class TORCH_API OpSchemaRegistry { public: static OpSchema& NewSchema(const string& key, const string& file, const int line) { diff --git a/caffe2/core/plan_executor.cc b/caffe2/core/plan_executor.cc index 3f70e96fffc89..06e27ef5be540 100644 --- a/caffe2/core/plan_executor.cc +++ b/caffe2/core/plan_executor.cc @@ -17,10 +17,18 @@ C10_DEFINE_bool( "If used we will handle exceptions in executor threads. " "This avoids SIGABRT but may cause process to deadlock"); +C10_DEFINE_int( + caffe2_plan_executor_exception_timeout, + 60, + "Number of seconds to wait for concurrent threads to stop on exception" + "before terminating."); + namespace caffe2 { namespace { +// ExceptionWrapper holds an exception. If exception pointers are being used, +// it'll hold the original exception pointer otherwise just the message. class ExceptionWrapper { public: ExceptionWrapper() : hasException_(false) {} @@ -39,6 +47,10 @@ class ExceptionWrapper { #endif } + const std::string& what() const { + return exceptionMsg_; + } + operator bool() { return hasException_; } @@ -51,6 +63,34 @@ class ExceptionWrapper { std::string exceptionMsg_; }; +// ExceptionWrapperTerminate terminates the program with the specified +// exception. This preserves the exception ptr and ExceptionTracer will +// correctly grab it on exit. +class ExceptionWrapperTerminate { + public: + explicit ExceptionWrapperTerminate(ExceptionWrapper&& ew) + : ew_(std::move(ew)) {} + + ~ExceptionWrapperTerminate() { + ew_.rethrowException(); + } + + private: + ExceptionWrapper ew_; +}; + +// ScopeExitGuard runs the provided function when it's destructed. +class ScopeExitGuard { + public: + explicit ScopeExitGuard(std::function&& f) : f_(std::move(f)) {} + ~ScopeExitGuard() { + f_(); + } + + private: + std::function f_; +}; + struct NetDefInfo { const NetDef* netDef; // in order to keep the "override existing nets" on the top-level workflow, @@ -61,44 +101,6 @@ struct NetDefInfo { using NetDefMap = std::unordered_map; -struct Reporter { - struct ReporterInstance { - std::mutex report_mutex; - std::condition_variable report_cv; - std::thread report_thread; - ReporterInstance(int intervalMillis, bool* done, std::function f) { - auto interval = std::chrono::milliseconds(intervalMillis); - auto reportWorker = [=]() { - std::unique_lock lk(report_mutex); - do { - report_cv.wait_for(lk, interval, [&]() { return *done; }); - f(); - } while (!*done); - }; - report_thread = std::thread(reportWorker); - } - }; - - void start(int64_t intervalMillis, std::function f) { - instances_.emplace_back(new ReporterInstance(intervalMillis, &done, f)); - } - - ~Reporter() { - done = true; - for (auto& instance : instances_) { - if (!instance->report_thread.joinable()) { - continue; - } - instance->report_cv.notify_all(); - instance->report_thread.join(); - } - } - - private: - std::vector> instances_; - bool done{false}; -}; - // Returns a function that returns `true` if we should continue // iterating, given the current iteration count. std::function getContinuationTest( @@ -131,8 +133,8 @@ std::function getContinuationTest( // if the blob doesn't exist or is not initialized, return false inline bool getShouldStop(const Blob* b) { if (!b || - b->meta().id() == - TypeIdentifier::uninitialized()) { // not exist or uninitialized + b->meta() == + ScalarType::Undefined) { // not exist or uninitialized return false; } @@ -355,6 +357,23 @@ struct CompiledExecutionStep { }; } + void Fail(const std::exception& ex) { + { + std::lock_guard guard(exception_mutex_); + if (!first_exception_) { + LOG(ERROR) << "Substep exception:\n" << c10::GetExceptionString(ex); + first_exception_ = ExceptionWrapper(ex); + } + gotFailure = true; + } + Cancel(); + } + + ExceptionWrapper FirstException() { + std::lock_guard guard(exception_mutex_); + return first_exception_; + } + // Cancel attempts to cancel the running nets in a best effort way. If the net // or op type does IO and doesn't implement cancellation it may not be // possible to cancel leading to execution getting stuck on error. @@ -387,6 +406,9 @@ struct CompiledExecutionStep { private: std::unique_ptr localWorkspace_; + + std::mutex exception_mutex_; // protects first_exception_ + ExceptionWrapper first_exception_; }; void ExecutionStepWrapper::Cancel() { @@ -404,6 +426,65 @@ std::unique_ptr ExecutionStepWrapper::doCompile() { ws_id_injector_)); } +struct Reporter { + struct ReporterInstance { + std::mutex report_mutex; + std::condition_variable report_cv; + std::thread report_thread; + ExceptionWrapper exception; + + ReporterInstance( + int intervalMillis, + std::atomic* done, + std::function f, + ExecutionStepWrapper::CompiledGuard* compiledStep) { + auto interval = std::chrono::milliseconds(intervalMillis); + auto reportWorker = [=]() { + std::unique_lock lk(report_mutex); + do { + report_cv.wait_for(lk, interval, [&]() { return done->load(); }); + try { + f(); + } catch (const std::exception& ex) { + LOG(ERROR) << "Reporter instance exception:\n" + << c10::GetExceptionString(ex); + if (!FLAGS_caffe2_handle_executor_threads_exceptions) { + throw; + } + (*compiledStep)->Fail(ex); + done->store(true); + } + } while (!done->load()); + }; + report_thread = std::thread(reportWorker); + } + }; + + explicit Reporter(ExecutionStepWrapper::CompiledGuard* compiledStep) + : compiledStep_(compiledStep) {} + + void start(int64_t intervalMillis, std::function f) { + instances_.emplace_back( + new ReporterInstance(intervalMillis, &done_, f, compiledStep_)); + } + + ~Reporter() { + done_ = true; + for (auto& instance : instances_) { + if (!instance->report_thread.joinable()) { + continue; + } + instance->report_cv.notify_all(); + instance->report_thread.join(); + } + } + + private: + std::vector> instances_; + std::atomic done_{false}; + ExecutionStepWrapper::CompiledGuard* compiledStep_; +}; + #define CHECK_SHOULD_STOP(step, shouldStop) \ if (getShouldStop(shouldStop)) { \ VLOG(1) << "Execution step " << step.name() << " stopped by " \ @@ -419,7 +500,7 @@ bool ExecuteStepRecursive(ExecutionStepWrapper& stepWrapper) { std::unique_ptr reporter; if (step.has_report_net() || compiledStep->reportSubsteps.size() > 0) { - reporter = std::make_unique(); + reporter = std::make_unique(&compiledStep); auto* reportNet = compiledStep->reportNet; if (reportNet) { VLOG(1) << "Starting reporter net"; @@ -460,9 +541,16 @@ bool ExecuteStepRecursive(ExecutionStepWrapper& stepWrapper) { << " with " << step.substep().size() << " concurrent substeps"; std::atomic next_substep{0}; - std::mutex exception_mutex; - ExceptionWrapper first_exception; + std::condition_variable cv; + std::mutex exception_mutex; // protects done + int done{0}; auto worker = [&]() { + ScopeExitGuard on_exit([&] { + std::lock_guard guard(exception_mutex); + done += 1; + cv.notify_all(); + }); + auto num_substeps = compiledStep->recurringSubsteps.size(); int substep_id = next_substep++ % num_substeps; if (compiledStep->gotFailure) { @@ -474,14 +562,7 @@ bool ExecuteStepRecursive(ExecutionStepWrapper& stepWrapper) { compiledStep->gotFailure = true; } } catch (const std::exception& ex) { - std::lock_guard guard(exception_mutex); - if (!first_exception) { - first_exception = ExceptionWrapper(ex); - LOG(ERROR) << "Parallel worker exception:\n" - << c10::GetExceptionString(ex); - } - compiledStep->gotFailure = true; - compiledStep->Cancel(); + compiledStep->Fail(ex); if (!FLAGS_caffe2_handle_executor_threads_exceptions) { // In complex plans other threads might get stuck if another // one fails. So we let exception to go out of thread which @@ -492,6 +573,8 @@ bool ExecuteStepRecursive(ExecutionStepWrapper& stepWrapper) { } }; + std::unique_lock guard(exception_mutex); + std::vector threads; auto numThreads = compiledStep->recurringSubsteps.size(); if (step.has_num_concurrent_instances()) { @@ -500,6 +583,24 @@ bool ExecuteStepRecursive(ExecutionStepWrapper& stepWrapper) { for (size_t i = 0; i < numThreads; ++i) { threads.emplace_back(worker); } + + auto workersDone = [&] { return done == numThreads; }; + + // If we get an exception, try to wait for all threads to stop + // gracefully. + cv.wait( + guard, [&] { return workersDone() || compiledStep->gotFailure; }); + cv.wait_for( + guard, + std::chrono::seconds(FLAGS_caffe2_plan_executor_exception_timeout), + [&] { return workersDone(); }); + auto first_exception = compiledStep->FirstException(); + if (!workersDone() && first_exception) { + LOG(ERROR) << "failed to stop concurrent workers after exception: " + << first_exception.what(); + ExceptionWrapperTerminate(std::move(first_exception)); + } + for (auto& thread : threads) { thread.join(); } @@ -527,7 +628,11 @@ bool ExecuteStepRecursive(ExecutionStepWrapper& stepWrapper) { } } } - return true; + + if (auto first_exception = compiledStep->FirstException()) { + first_exception.rethrowException(); + } + return !compiledStep->gotFailure; } #undef CHECK_SHOULD_STOP diff --git a/caffe2/core/plan_executor_test.cc b/caffe2/core/plan_executor_test.cc index 86f145d72a096..c26361bee788c 100644 --- a/caffe2/core/plan_executor_test.cc +++ b/caffe2/core/plan_executor_test.cc @@ -18,6 +18,51 @@ static std::atomic cancelCount{0}; static std::atomic stuckRun{false}; } // namespace +class StuckBlockingOp final : public Operator { + public: + StuckBlockingOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws) {} + + bool RunOnDevice() override { + // StuckBlockingOp runs and notifies ErrorOp. + stuckRun = true; + + while (!cancelled_) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + return true; + } + + void Cancel() override { + LOG(INFO) << "cancelled StuckBlockingOp."; + cancelCount += 1; + cancelled_ = true; + } + + private: + std::atomic cancelled_{false}; +}; + +REGISTER_CPU_OPERATOR(StuckBlocking, StuckBlockingOp); +OPERATOR_SCHEMA(StuckBlocking).NumInputs(0).NumOutputs(0); + +class NoopOp final : public Operator { + public: + NoopOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws) {} + + bool RunOnDevice() override { + // notify Error op we've ran. + stuckRun = true; + return true; + } +}; + +REGISTER_CPU_OPERATOR(Noop, NoopOp); +OPERATOR_SCHEMA(Noop).NumInputs(0).NumOutputs(0); + + class StuckAsyncOp final : public Operator { public: StuckAsyncOp(const OperatorDef& operator_def, Workspace* ws) @@ -55,7 +100,7 @@ class ErrorOp final : public Operator { : Operator(operator_def, ws) {} bool RunOnDevice() override { - // Wait for StuckAsyncOp to run first. + // Wait for StuckAsyncOp or StuckBlockingOp to run first. while (!stuckRun) { std::this_thread::sleep_for(std::chrono::milliseconds(10)); } @@ -67,6 +112,29 @@ class ErrorOp final : public Operator { REGISTER_CPU_OPERATOR(Error, ErrorOp); OPERATOR_SCHEMA(Error).NumInputs(0).NumOutputs(0); +static std::atomic blockingErrorRuns{0}; +class BlockingErrorOp final : public Operator { + public: + BlockingErrorOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws) {} + + bool RunOnDevice() override { + // First n op executions should block and then start throwing errors. + if (blockingErrorRuns.fetch_sub(1) >= 1) { + LOG(INFO) << "blocking"; + while (true) { + std::this_thread::sleep_for(std::chrono::hours(10)); + } + } else { + LOG(INFO) << "throwing"; + throw TestError(); + } + } +}; + +REGISTER_CPU_OPERATOR(BlockingError, BlockingErrorOp); +OPERATOR_SCHEMA(BlockingError).NumInputs(0).NumOutputs(0); + PlanDef parallelErrorPlan() { PlanDef plan_def; @@ -100,11 +168,81 @@ PlanDef parallelErrorPlan() { return plan_def; } +PlanDef parallelErrorPlanWithCancellableStuckNet() { + // Set a plan with two nets: one stuck net with blocking operator that never + // returns; one error net with error op that throws. + PlanDef plan_def; + + auto* stuck_blocking_net = plan_def.add_network(); + stuck_blocking_net->set_name("stuck_blocking_net"); + { + auto* op = stuck_blocking_net->add_op(); + op->set_type("StuckBlocking"); + } + + auto* error_net = plan_def.add_network(); + error_net->set_name("error_net"); + { + auto* op = error_net->add_op(); + op->set_type("Error"); + } + + auto* execution_step = plan_def.add_execution_step(); + execution_step->set_concurrent_substeps(true); + { + auto* substep = execution_step->add_substep(); + substep->add_network(stuck_blocking_net->name()); + } + { + auto* substep = execution_step->add_substep(); + substep->add_network(error_net->name()); + } + + return plan_def; +} + +PlanDef reporterErrorPlanWithCancellableStuckNet() { + // Set a plan with a concurrent net and a reporter net: one stuck net with + // blocking operator that never returns; one reporter net with error op + // that throws. + PlanDef plan_def; + + auto* stuck_blocking_net = plan_def.add_network(); + stuck_blocking_net->set_name("stuck_blocking_net"); + { + auto* op = stuck_blocking_net->add_op(); + op->set_type("StuckBlocking"); + } + + auto* error_net = plan_def.add_network(); + error_net->set_name("error_net"); + { + auto* op = error_net->add_op(); + op->set_type("Error"); + } + + auto* execution_step = plan_def.add_execution_step(); + execution_step->set_concurrent_substeps(true); + { + auto* substep = execution_step->add_substep(); + substep->add_network(stuck_blocking_net->name()); + } + { + auto* substep = execution_step->add_substep(); + substep->set_run_every_ms(1); + substep->add_network(error_net->name()); + } + + return plan_def; +} + struct HandleExecutorThreadExceptionsGuard { - HandleExecutorThreadExceptionsGuard() { + HandleExecutorThreadExceptionsGuard(int timeout = 60) { globalInit({ "caffe2", "--caffe2_handle_executor_threads_exceptions=1", + "--caffe2_plan_executor_exception_timeout=" + + caffe2::to_string(timeout), }); } @@ -133,12 +271,136 @@ struct HandleExecutorThreadExceptionsGuard { TEST(PlanExecutorTest, ErrorAsyncPlan) { HandleExecutorThreadExceptionsGuard guard; + cancelCount = 0; PlanDef plan_def = parallelErrorPlan(); Workspace ws; ASSERT_THROW(ws.RunPlan(plan_def), TestError); ASSERT_EQ(cancelCount, 1); } +// death tests not supported on mobile +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) +TEST(PlanExecutorTest, BlockingErrorPlan) { + // TSAN doesn't play nicely with death tests +#if defined(__has_feature) +#if __has_feature(thread_sanitizer) + return; +#endif +#endif + + ASSERT_DEATH( + [] { + HandleExecutorThreadExceptionsGuard guard(/*timeout=*/1); + + PlanDef plan_def; + + std::string plan_def_template = R"DOC( + network { + name: "net" + op { + type: "BlockingError" + } + } + execution_step { + num_concurrent_instances: 2 + substep { + network: "net" + } + } + )DOC"; + + CAFFE_ENFORCE( + TextFormat::ParseFromString(plan_def_template, &plan_def)); + Workspace ws; + blockingErrorRuns = 1; + ws.RunPlan(plan_def); + FAIL() << "shouldn't have reached this point"; + }(), + "failed to stop concurrent workers after exception: test error"); +} +#endif + +TEST(PlanExecutorTest, ErrorPlanWithCancellableStuckNet) { + HandleExecutorThreadExceptionsGuard guard; + + cancelCount = 0; + PlanDef plan_def = parallelErrorPlanWithCancellableStuckNet(); + Workspace ws; + + ASSERT_THROW(ws.RunPlan(plan_def), TestError); + ASSERT_EQ(cancelCount, 1); +} + +TEST(PlanExecutorTest, ReporterErrorPlanWithCancellableStuckNet) { + HandleExecutorThreadExceptionsGuard guard; + + cancelCount = 0; + PlanDef plan_def = reporterErrorPlanWithCancellableStuckNet(); + Workspace ws; + + ASSERT_THROW(ws.RunPlan(plan_def), TestError); + ASSERT_EQ(cancelCount, 1); +} + +PlanDef shouldStopWithCancelPlan() { + // Set a plan with a looping net with should_stop_blob set and a concurrent + // net that throws an error. The error should cause should_stop to return + // false and end the concurrent net. + PlanDef plan_def; + + auto* should_stop_net = plan_def.add_network(); + { + auto* op = should_stop_net->add_op(); + op->set_type("Noop"); + } + should_stop_net->set_name("should_stop_net"); + should_stop_net->set_type("async_scheduling"); + + auto* error_net = plan_def.add_network(); + error_net->set_name("error_net"); + { + auto* op = error_net->add_op(); + op->set_type("Error"); + } + + auto* execution_step = plan_def.add_execution_step(); + execution_step->set_concurrent_substeps(true); + { + auto* substep = execution_step->add_substep(); + execution_step->set_concurrent_substeps(true); + substep->set_name("concurrent_should_stop"); + substep->set_should_stop_blob("should_stop_blob"); + auto* substep2 = substep->add_substep(); + substep2->set_name("should_stop_net"); + substep2->add_network(should_stop_net->name()); + substep2->set_num_iter(10); + } + { + auto* substep = execution_step->add_substep(); + substep->set_name("error_step"); + substep->add_network(error_net->name()); + } + + return plan_def; +} + +TEST(PlanExecutorTest, ShouldStopWithCancel) { + HandleExecutorThreadExceptionsGuard guard; + + stuckRun = false; + PlanDef plan_def = shouldStopWithCancelPlan(); + Workspace ws; + + Blob* blob = ws.CreateBlob("should_stop_blob"); + Tensor* tensor = BlobGetMutableTensor(blob, CPU); + const vector& shape{1}; + tensor->Resize(shape); + tensor->mutable_data()[0] = false; + + ASSERT_THROW(ws.RunPlan(plan_def), TestError); + ASSERT_TRUE(stuckRun); +} + } // namespace caffe2 #endif diff --git a/caffe2/core/stats.h b/caffe2/core/stats.h index f037ca6e17560..a2ba948cc8cf4 100644 --- a/caffe2/core/stats.h +++ b/caffe2/core/stats.h @@ -11,7 +11,7 @@ namespace caffe2 { -class CAFFE2_API StatValue { +class TORCH_API StatValue { std::atomic v_{0}; public: @@ -28,7 +28,7 @@ class CAFFE2_API StatValue { } }; -struct CAFFE2_API ExportedStatValue { +struct TORCH_API ExportedStatValue { std::string key; int64_t value; std::chrono::time_point ts; @@ -40,7 +40,7 @@ struct CAFFE2_API ExportedStatValue { using ExportedStatList = std::vector; using ExportedStatMap = std::unordered_map; -CAFFE2_API ExportedStatMap toMap(const ExportedStatList& stats); +TORCH_API ExportedStatMap toMap(const ExportedStatList& stats); /** * @brief Holds a map of atomic counters keyed by name. @@ -114,7 +114,7 @@ CAFFE2_API ExportedStatMap toMap(const ExportedStatList& stats); * structure by calling StatRegistry::update(). * */ -class CAFFE2_API StatRegistry { +class TORCH_API StatRegistry { std::mutex mutex_; std::unordered_map> stats_; @@ -153,7 +153,7 @@ class CAFFE2_API StatRegistry { ~StatRegistry(); }; -struct CAFFE2_API Stat { +struct TORCH_API Stat { std::string groupName; std::string name; Stat(const std::string& gn, const std::string& n) : groupName(gn), name(n) {} @@ -164,7 +164,7 @@ struct CAFFE2_API Stat { } }; -class CAFFE2_API ExportedStat : public Stat { +class TORCH_API ExportedStat : public Stat { StatValue* value_; public: @@ -181,7 +181,7 @@ class CAFFE2_API ExportedStat : public Stat { } }; -class CAFFE2_API AvgExportedStat : public ExportedStat { +class TORCH_API AvgExportedStat : public ExportedStat { private: ExportedStat count_; @@ -200,7 +200,7 @@ class CAFFE2_API AvgExportedStat : public ExportedStat { } }; -class CAFFE2_API StdDevExportedStat : public ExportedStat { +class TORCH_API StdDevExportedStat : public ExportedStat { // Uses an offset (first_) to remove issue of cancellation // Variance is then (sumsqoffset_ - (sumoffset_^2) / count_) / (count_ - 1) private: @@ -234,7 +234,7 @@ class CAFFE2_API StdDevExportedStat : public ExportedStat { } }; -class CAFFE2_API DetailedExportedStat : public ExportedStat { +class TORCH_API DetailedExportedStat : public ExportedStat { private: std::vector details_; @@ -258,7 +258,7 @@ class CAFFE2_API DetailedExportedStat : public ExportedStat { } }; -class CAFFE2_API StaticStat : public Stat { +class TORCH_API StaticStat : public Stat { private: StatValue* value_; diff --git a/caffe2/core/tensor.cc b/caffe2/core/tensor.cc index 1123d4970fad0..b56cdca0da37b 100644 --- a/caffe2/core/tensor.cc +++ b/caffe2/core/tensor.cc @@ -164,10 +164,12 @@ void ReinitializeTensor( if (tensor->dtype() == options.dtype()) { tensor->raw_mutable_data(); } else { - C10_LOG_FIRST_N(WARNING, 1) - << "Changing the data type of Tensor is discouraged." - << " Attempt to change data type from: " << tensor->dtype() - << " to: " << options.dtype(); + // This C10 logging API is not thread-safe, and should not be called here + // This can lead to a memory corruption in glog. + // C10_LOG_FIRST_N(WARNING, 1) + // << "Changing the data type of Tensor is discouraged." + // << " Attempt to change data type from: " << tensor->dtype() + // << " to: " << options.dtype(); // create a new Tensor when the data_type doesn't match *tensor = caffe2::empty(dims, options); } diff --git a/caffe2/core/tensor.h b/caffe2/core/tensor.h index 18a7be64d670f..77b8d2b5cb56d 100644 --- a/caffe2/core/tensor.h +++ b/caffe2/core/tensor.h @@ -26,7 +26,7 @@ using at::UndefinedTensorImpl; * * NB: See TensorImpl for documentation on these methods. */ -class CAFFE2_API Tensor final { +class TORCH_API Tensor final { private: enum Unsafe { IDoWantAliasing }; Tensor(const Tensor& other, Unsafe _) : impl_(other.getIntrusivePtr()) {} @@ -70,7 +70,7 @@ class CAFFE2_API Tensor final { explicit Tensor(at::Device device) : impl_(c10::make_intrusive( Storage::create_legacy(device), - c10::computeDispatchKey(at::device(device).layout(at::kStrided)), + c10::computeDispatchKey(c10::nullopt, at::kStrided, device), TypeMeta())) {} /** @@ -299,14 +299,14 @@ class CAFFE2_API Tensor final { void ShareExternalPointer( void* src, - const TypeMeta& data_type, + const TypeMeta data_type, size_t nbytes = 0, MemoryDeleter d = nullptr) const { CAFFE_ENFORCE_WITH_CALLER( impl_->is_contiguous(), "Right now ShareExternalPointer is only supported for contiguous Tensor."); CAFFE_ENFORCE_WITH_CALLER( - data_type.id() != caffe2::TypeIdentifier::uninitialized(), + data_type != ScalarType::Undefined, "To share with a raw external pointer you need to pass in an " "initialized data_type(TypeMeta)."); impl_.get()->ShareExternalPointer( @@ -315,7 +315,7 @@ class CAFFE2_API Tensor final { void ShareExternalPointer( at::DataPtr&& data_ptr, - const TypeMeta& data_type, + const TypeMeta data_type, size_t nbytes) { impl_.get()->ShareExternalPointer(std::move(data_ptr), data_type, nbytes); } @@ -342,7 +342,7 @@ class CAFFE2_API Tensor final { return impl_.get()->data(); } - inline void* raw_mutable_data(const TypeMeta& meta) const { + inline void* raw_mutable_data(const TypeMeta meta) const { return impl_.get()->raw_mutable_data(meta); } @@ -358,7 +358,7 @@ class CAFFE2_API Tensor final { inline void* raw_mutable_data() const { const auto& data_type = impl_->dtype(); CAFFE_ENFORCE_WITH_CALLER( - data_type.id() != caffe2::TypeIdentifier::uninitialized(), + data_type != ScalarType::Undefined, "Calling raw_mutable_data() without meta, but the current meta is " "of unknown type."); return raw_mutable_data(data_type); @@ -469,7 +469,7 @@ class CAFFE2_API Tensor final { /** * Returns the TypeMeta object associated with the current data type. */ - inline const TypeMeta& dtype() const { + inline const TypeMeta dtype() const { return impl_->dtype(); } @@ -477,7 +477,7 @@ class CAFFE2_API Tensor final { * (To be deprecated) Returns the TypeMeta object associated with the current * data type. */ - inline const TypeMeta& meta() const { + inline const TypeMeta meta() const { return impl_->dtype(); } @@ -530,10 +530,10 @@ class CAFFE2_API Tensor final { * this will not do anything if the * Tensor already has correct size and data type */ -CAFFE2_API void +TORCH_API void ReinitializeTensor(Tensor* t, at::IntArrayRef dims, at::TensorOptions options); -CAFFE2_API void ReinitializeAndCopyFrom( +TORCH_API void ReinitializeAndCopyFrom( Tensor* t, at::TensorOptions options, const Tensor& src, @@ -564,7 +564,7 @@ void TensorVectorResize( DeviceType type); // Tensor factory function -CAFFE2_API Tensor empty(at::IntArrayRef dims, at::TensorOptions options); +TORCH_API Tensor empty(at::IntArrayRef dims, at::TensorOptions options); /** * @brief Creates a CPU tensor, and fills its contents with the given values. @@ -585,7 +585,7 @@ Tensor TensorCPUFromValues(at::IntArrayRef dims, at::ArrayRef values) { vector GetTensorInfo(const void* c, size_t* capacity, DeviceOption* device); -class CAFFE2_API TensorPrinter { +class TORCH_API TensorPrinter { public: explicit TensorPrinter( const std::string& tensor_name = "", diff --git a/caffe2/core/test_utils.h b/caffe2/core/test_utils.h index 47226c9232d88..89f21c133255a 100644 --- a/caffe2/core/test_utils.h +++ b/caffe2/core/test_utils.h @@ -18,13 +18,13 @@ namespace caffe2 { namespace testing { // Asserts that the values of two tensors are the same. -CAFFE2_API void assertTensorEquals( +TORCH_API void assertTensorEquals( const TensorCPU& tensor1, const TensorCPU& tensor2, float eps = 1e-6); // Asserts that two float values are close within epsilon. -CAFFE2_API void assertNear(float value1, float value2, float epsilon); +TORCH_API void assertNear(float value1, float value2, float epsilon); // Asserts that the numeric values of a tensor is equal to a data vector. template @@ -55,23 +55,23 @@ void assertTensor( } // Asserts a list of tensors presented in two workspaces are equal. -CAFFE2_API void assertTensorListEquals( +TORCH_API void assertTensorListEquals( const std::vector& tensorNames, const Workspace& workspace1, const Workspace& workspace2); // Read a tensor from the workspace. -CAFFE2_API const caffe2::Tensor& getTensor( +TORCH_API const caffe2::Tensor& getTensor( const caffe2::Workspace& workspace, const std::string& name); // Create a new tensor in the workspace. -CAFFE2_API caffe2::Tensor* createTensor( +TORCH_API caffe2::Tensor* createTensor( const std::string& name, caffe2::Workspace* workspace); // Create a new operator in the net. -CAFFE2_API caffe2::OperatorDef* createOperator( +TORCH_API caffe2::OperatorDef* createOperator( const std::string& type, const std::vector& inputs, const std::vector& outputs, @@ -154,7 +154,7 @@ caffe2::Tensor* createTensorAndConstantFill( } // Concise util class to mutate a net in a chaining fashion. -class CAFFE2_API NetMutator { +class TORCH_API NetMutator { public: explicit NetMutator(caffe2::NetDef* net) : net_(net) {} @@ -184,7 +184,7 @@ class CAFFE2_API NetMutator { }; // Concise util class to mutate a workspace in a chaining fashion. -class CAFFE2_API WorkspaceMutator { +class TORCH_API WorkspaceMutator { public: explicit WorkspaceMutator(caffe2::Workspace* workspace) : workspace_(workspace) {} diff --git a/caffe2/core/transform.h b/caffe2/core/transform.h index 723e14789d627..7f8971c89406f 100644 --- a/caffe2/core/transform.h +++ b/caffe2/core/transform.h @@ -31,7 +31,7 @@ namespace caffe2 { * own transform, write your implementations for PatternRule, ValidatorRule, and * ReplaceRule. */ -class CAFFE2_API Transform { +class TORCH_API Transform { public: Transform() {} @@ -148,7 +148,7 @@ class CAFFE2_API Transform { }; // Creates a Transform based on a key, which should be defined in registry. -CAFFE2_API unique_ptr CreateTransform(string key); +TORCH_API unique_ptr CreateTransform(string key); C10_DECLARE_REGISTRY(TransformRegistry, Transform); #define REGISTER_TRANSFORM(name, ...) \ @@ -156,14 +156,14 @@ C10_DECLARE_REGISTRY(TransformRegistry, Transform); // Create a Transform object from registry, // and immediately apply it to a Netdef. -CAFFE2_API NetDef ApplyTransform(const string& key, const NetDef& netdef); +TORCH_API NetDef ApplyTransform(const string& key, const NetDef& netdef); // Create a Transform object from registry, apply it to a NetDef. // Will only return the transformed net if it is faster than the old net. // This will run the init net first, will run the two nets warmup_runs times. // Then, we will take the average time of main_runs runs, and only keep the // transformed net if it is faster by a factor of improvement_threshold. -CAFFE2_API NetDef ApplyTransformIfFaster( +TORCH_API NetDef ApplyTransformIfFaster( const string& key, const NetDef& netdef, const NetDef& init_netdef, diff --git a/caffe2/core/types.cc b/caffe2/core/types.cc index d1007fe76e863..c738fc50a2888 100644 --- a/caffe2/core/types.cc +++ b/caffe2/core/types.cc @@ -8,7 +8,7 @@ namespace caffe2 { -TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta) { +TensorProto::DataType TypeMetaToDataType(const TypeMeta meta) { static_assert( sizeof(int) == 4, "int in this compiler does not equal to 4 bytes."); static std::map data_type_map{ @@ -36,7 +36,7 @@ TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta) { it == data_type_map.end() ? TensorProto_DataType_UNDEFINED : it->second); } -const TypeMeta& DataTypeToTypeMeta(const TensorProto::DataType& dt) { +const TypeMeta DataTypeToTypeMeta(const TensorProto::DataType& dt) { static std::map type_meta_map{ {TensorProto_DataType_FLOAT, TypeMeta::Make()}, {TensorProto_DataType_INT32, TypeMeta::Make()}, diff --git a/caffe2/core/types.h b/caffe2/core/types.h index c0e8d7bbfb3d8..7a74abe4fac94 100644 --- a/caffe2/core/types.h +++ b/caffe2/core/types.h @@ -47,10 +47,10 @@ inline int32_t GetDimFromOrderString(const std::string& str) { inline constexpr char NameScopeSeparator() { return '/'; } // From TypeMeta to caffe2::DataType protobuffer enum. -CAFFE2_API TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta); +TORCH_API TensorProto::DataType TypeMetaToDataType(const TypeMeta meta); // From caffe2::DataType protobuffer enum to TypeMeta -CAFFE2_API const TypeMeta& DataTypeToTypeMeta(const TensorProto::DataType& dt); +TORCH_API const TypeMeta DataTypeToTypeMeta(const TensorProto::DataType& dt); } // namespace caffe2 diff --git a/caffe2/core/workspace.h b/caffe2/core/workspace.h index 793b5f611d089..25805d0418730 100644 --- a/caffe2/core/workspace.h +++ b/caffe2/core/workspace.h @@ -24,7 +24,7 @@ namespace caffe2 { class NetBase; -struct CAFFE2_API StopOnSignal { +struct TORCH_API StopOnSignal { StopOnSignal() : handler_(std::make_shared( SignalHandler::Action::STOP, @@ -44,7 +44,7 @@ struct CAFFE2_API StopOnSignal { * runtime: (1) all blobs, and (2) all instantiated networks. It is the owner of * all these objects and deals with the scaffolding logistics. */ -class CAFFE2_API Workspace { +class TORCH_API Workspace { public: typedef std::function ShouldContinue; typedef CaffeMap > BlobMap; diff --git a/caffe2/distributed/file_store_handler.cc b/caffe2/distributed/file_store_handler.cc index 5a749c304d2bb..5a34e53b69472 100644 --- a/caffe2/distributed/file_store_handler.cc +++ b/caffe2/distributed/file_store_handler.cc @@ -122,6 +122,16 @@ int64_t FileStoreHandler::add( return 0; } +int64_t FileStoreHandler::getNumKeys() { + CHECK(false) << "getNumKeys not implemented for FileStoreHandler"; + return 0; +} + +bool FileStoreHandler::deleteKey(const std::string& /* unused */) { + CHECK(false) << "deleteKey not implemented for FileStoreHandler"; + return false; +} + bool FileStoreHandler::check(const std::vector& names) { std::vector paths; for (const auto& name : names) { diff --git a/caffe2/distributed/file_store_handler.h b/caffe2/distributed/file_store_handler.h index b58b156e51b03..43d86fed8c57f 100644 --- a/caffe2/distributed/file_store_handler.h +++ b/caffe2/distributed/file_store_handler.h @@ -4,7 +4,7 @@ namespace caffe2 { -class CAFFE2_API FileStoreHandler : public StoreHandler { +class TORCH_API FileStoreHandler : public StoreHandler { public: explicit FileStoreHandler(const std::string& path, const std::string& prefix); virtual ~FileStoreHandler(); @@ -17,6 +17,10 @@ class CAFFE2_API FileStoreHandler : public StoreHandler { virtual int64_t add(const std::string& name, int64_t value) override; + virtual bool deleteKey(const std::string& key) override; + + virtual int64_t getNumKeys() override; + virtual bool check(const std::vector& names) override; virtual void wait( diff --git a/caffe2/distributed/redis_store_handler.cc b/caffe2/distributed/redis_store_handler.cc index 7caaa6c79de7a..e424c0e719fd8 100644 --- a/caffe2/distributed/redis_store_handler.cc +++ b/caffe2/distributed/redis_store_handler.cc @@ -76,6 +76,16 @@ int64_t RedisStoreHandler::add(const std::string& name, int64_t value) { return reply->integer; } +int64_t RedisStoreHandler::getNumKeys() { + CHECK(false) << "getNumKeys not implemented for RedisStoreHandler"; + return 0; +} + +bool RedisStoreHandler::deleteKey(const std::string& /* unused */) { + CHECK(false) << "deleteKey not implemented for RedisStoreHandler"; + return false; +} + bool RedisStoreHandler::check(const std::vector& names) { std::vector args; args.push_back("EXISTS"); diff --git a/caffe2/distributed/redis_store_handler.h b/caffe2/distributed/redis_store_handler.h index 0caa888a6629f..1ff75918cd8c8 100644 --- a/caffe2/distributed/redis_store_handler.h +++ b/caffe2/distributed/redis_store_handler.h @@ -10,7 +10,7 @@ extern "C" { namespace caffe2 { -class CAFFE2_API RedisStoreHandler : public StoreHandler { +class TORCH_API RedisStoreHandler : public StoreHandler { public: explicit RedisStoreHandler(std::string& host, int port, std::string& prefix); virtual ~RedisStoreHandler(); @@ -23,6 +23,10 @@ class CAFFE2_API RedisStoreHandler : public StoreHandler { virtual int64_t add(const std::string& name, int64_t value) override; + virtual int64_t getNumKeys() override; + + virtual bool deleteKey(const std::string& key) override; + virtual bool check(const std::vector& names) override; virtual void wait( diff --git a/caffe2/distributed/store_handler.h b/caffe2/distributed/store_handler.h index e11ea57aea3de..d4d9b80b49293 100644 --- a/caffe2/distributed/store_handler.h +++ b/caffe2/distributed/store_handler.h @@ -10,7 +10,7 @@ namespace caffe2 { -class CAFFE2_API StoreHandler { +class TORCH_API StoreHandler { public: static constexpr std::chrono::milliseconds kDefaultTimeout = std::chrono::seconds(30); @@ -41,6 +41,16 @@ class CAFFE2_API StoreHandler { */ virtual int64_t add(const std::string& name, int64_t value) = 0; + /* + * Returns the number of keys in this store. + */ + virtual int64_t getNumKeys() = 0; + + /* + * Removes the specified key from the store. + */ + virtual bool deleteKey(const std::string& key) = 0; + /* * Check if a keys exist in the store. */ @@ -57,7 +67,7 @@ class CAFFE2_API StoreHandler { /* * The backing store is no longer available. It may have been deleted. */ -struct CAFFE2_API StoreHandlerNotAvailableException +struct TORCH_API StoreHandlerNotAvailableException : public std::runtime_error { explicit StoreHandlerNotAvailableException(const std::string& msg) : std::runtime_error(msg) {} @@ -70,7 +80,7 @@ struct CAFFE2_API StoreHandlerNotAvailableException /* * Timeout accessing the store. */ -struct CAFFE2_API StoreHandlerTimeoutException : public std::runtime_error { +struct TORCH_API StoreHandlerTimeoutException : public std::runtime_error { explicit StoreHandlerTimeoutException(const std::string& msg) : std::runtime_error(msg) {} }; diff --git a/caffe2/ideep/ideep_utils.h b/caffe2/ideep/ideep_utils.h index 947d1b337ab3d..b1b3aae3a8ee8 100644 --- a/caffe2/ideep/ideep_utils.h +++ b/caffe2/ideep/ideep_utils.h @@ -1,7 +1,7 @@ #pragma once #include // For caffe2 macros. - +#include // All caffe2 ideep related headers #include #include diff --git a/caffe2/ideep/utils/ideep_context.h b/caffe2/ideep/utils/ideep_context.h index 823b4bec16bd7..d0f1207a08f69 100644 --- a/caffe2/ideep/utils/ideep_context.h +++ b/caffe2/ideep/utils/ideep_context.h @@ -91,7 +91,7 @@ class IDEEPContext final : public BaseContext { template inline void - CopyItems(const TypeMeta& meta, size_t n, const void* src, void* dst) { + CopyItems(const TypeMeta meta, size_t n, const void* src, void* dst) { if (meta.copy()) { meta.copy()(src, dst, n); } else { diff --git a/caffe2/mobile/contrib/libvulkan-stub/include/vulkan/vulkan.h b/caffe2/mobile/contrib/libvulkan-stub/include/vulkan/vulkan.h index cc0ad1f72d010..04495fa0cd729 100644 --- a/caffe2/mobile/contrib/libvulkan-stub/include/vulkan/vulkan.h +++ b/caffe2/mobile/contrib/libvulkan-stub/include/vulkan/vulkan.h @@ -6,7 +6,7 @@ extern "C" { #endif /* -** Copyright (c) 2015-2016 The Khronos Group Inc. +** Copyright (c) 2015-2017 The Khronos Group Inc. ** ** Licensed under the Apache License, Version 2.0 (the "License"); ** you may not use this file except in compliance with the License. @@ -28,22 +28,22 @@ extern "C" { #define VK_VERSION_1_0 1 -#include "vk_platform.h" +#include "./vk_platform.h" #define VK_MAKE_VERSION(major, minor, patch) \ (((major) << 22) | ((minor) << 12) | (patch)) // DEPRECATED: This define has been removed. Specific version defines (e.g. VK_API_VERSION_1_0), or the VK_MAKE_VERSION macro, should be used instead. -//#define VK_API_VERSION VK_MAKE_VERSION(1, 0, 0) +//#define VK_API_VERSION VK_MAKE_VERSION(1, 0, 0) // Patch version should always be set to 0 // Vulkan 1.0 version number -#define VK_API_VERSION_1_0 VK_MAKE_VERSION(1, 0, 0) +#define VK_API_VERSION_1_0 VK_MAKE_VERSION(1, 0, 0)// Patch version should always be set to 0 #define VK_VERSION_MAJOR(version) ((uint32_t)(version) >> 22) #define VK_VERSION_MINOR(version) (((uint32_t)(version) >> 12) & 0x3ff) #define VK_VERSION_PATCH(version) ((uint32_t)(version) & 0xfff) // Version of this file -#define VK_HEADER_VERSION 29 +#define VK_HEADER_VERSION 59 #define VK_NULL_HANDLE 0 @@ -145,6 +145,8 @@ typedef enum VkResult { VK_ERROR_INCOMPATIBLE_DISPLAY_KHR = -1000003001, VK_ERROR_VALIDATION_FAILED_EXT = -1000011001, VK_ERROR_INVALID_SHADER_NV = -1000012000, + VK_ERROR_OUT_OF_POOL_MEMORY_KHR = -1000069000, + VK_ERROR_INVALID_EXTERNAL_HANDLE_KHR = -1000072003, VK_RESULT_BEGIN_RANGE = VK_ERROR_FRAGMENTED_POOL, VK_RESULT_END_RANGE = VK_INCOMPLETE, VK_RESULT_RANGE_SIZE = (VK_INCOMPLETE - VK_ERROR_FRAGMENTED_POOL + 1), @@ -220,12 +222,117 @@ typedef enum VkStructureType { VK_STRUCTURE_TYPE_DEDICATED_ALLOCATION_IMAGE_CREATE_INFO_NV = 1000026000, VK_STRUCTURE_TYPE_DEDICATED_ALLOCATION_BUFFER_CREATE_INFO_NV = 1000026001, VK_STRUCTURE_TYPE_DEDICATED_ALLOCATION_MEMORY_ALLOCATE_INFO_NV = 1000026002, + VK_STRUCTURE_TYPE_TEXTURE_LOD_GATHER_FORMAT_PROPERTIES_AMD = 1000041000, + VK_STRUCTURE_TYPE_RENDER_PASS_MULTIVIEW_CREATE_INFO_KHX = 1000053000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MULTIVIEW_FEATURES_KHX = 1000053001, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MULTIVIEW_PROPERTIES_KHX = 1000053002, VK_STRUCTURE_TYPE_EXTERNAL_MEMORY_IMAGE_CREATE_INFO_NV = 1000056000, VK_STRUCTURE_TYPE_EXPORT_MEMORY_ALLOCATE_INFO_NV = 1000056001, VK_STRUCTURE_TYPE_IMPORT_MEMORY_WIN32_HANDLE_INFO_NV = 1000057000, VK_STRUCTURE_TYPE_EXPORT_MEMORY_WIN32_HANDLE_INFO_NV = 1000057001, VK_STRUCTURE_TYPE_WIN32_KEYED_MUTEX_ACQUIRE_RELEASE_INFO_NV = 1000058000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2_KHR = 1000059000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2_KHR = 1000059001, + VK_STRUCTURE_TYPE_FORMAT_PROPERTIES_2_KHR = 1000059002, + VK_STRUCTURE_TYPE_IMAGE_FORMAT_PROPERTIES_2_KHR = 1000059003, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_IMAGE_FORMAT_INFO_2_KHR = 1000059004, + VK_STRUCTURE_TYPE_QUEUE_FAMILY_PROPERTIES_2_KHR = 1000059005, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PROPERTIES_2_KHR = 1000059006, + VK_STRUCTURE_TYPE_SPARSE_IMAGE_FORMAT_PROPERTIES_2_KHR = 1000059007, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SPARSE_IMAGE_FORMAT_INFO_2_KHR = 1000059008, + VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_FLAGS_INFO_KHX = 1000060000, + VK_STRUCTURE_TYPE_BIND_BUFFER_MEMORY_INFO_KHX = 1000060001, + VK_STRUCTURE_TYPE_BIND_IMAGE_MEMORY_INFO_KHX = 1000060002, + VK_STRUCTURE_TYPE_DEVICE_GROUP_RENDER_PASS_BEGIN_INFO_KHX = 1000060003, + VK_STRUCTURE_TYPE_DEVICE_GROUP_COMMAND_BUFFER_BEGIN_INFO_KHX = 1000060004, + VK_STRUCTURE_TYPE_DEVICE_GROUP_SUBMIT_INFO_KHX = 1000060005, + VK_STRUCTURE_TYPE_DEVICE_GROUP_BIND_SPARSE_INFO_KHX = 1000060006, + VK_STRUCTURE_TYPE_DEVICE_GROUP_PRESENT_CAPABILITIES_KHX = 1000060007, + VK_STRUCTURE_TYPE_IMAGE_SWAPCHAIN_CREATE_INFO_KHX = 1000060008, + VK_STRUCTURE_TYPE_BIND_IMAGE_MEMORY_SWAPCHAIN_INFO_KHX = 1000060009, + VK_STRUCTURE_TYPE_ACQUIRE_NEXT_IMAGE_INFO_KHX = 1000060010, + VK_STRUCTURE_TYPE_DEVICE_GROUP_PRESENT_INFO_KHX = 1000060011, + VK_STRUCTURE_TYPE_DEVICE_GROUP_SWAPCHAIN_CREATE_INFO_KHX = 1000060012, VK_STRUCTURE_TYPE_VALIDATION_FLAGS_EXT = 1000061000, + VK_STRUCTURE_TYPE_VI_SURFACE_CREATE_INFO_NN = 1000062000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_GROUP_PROPERTIES_KHX = 1000070000, + VK_STRUCTURE_TYPE_DEVICE_GROUP_DEVICE_CREATE_INFO_KHX = 1000070001, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_EXTERNAL_IMAGE_FORMAT_INFO_KHR = 1000071000, + VK_STRUCTURE_TYPE_EXTERNAL_IMAGE_FORMAT_PROPERTIES_KHR = 1000071001, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_EXTERNAL_BUFFER_INFO_KHR = 1000071002, + VK_STRUCTURE_TYPE_EXTERNAL_BUFFER_PROPERTIES_KHR = 1000071003, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ID_PROPERTIES_KHR = 1000071004, + VK_STRUCTURE_TYPE_EXTERNAL_MEMORY_BUFFER_CREATE_INFO_KHR = 1000072000, + VK_STRUCTURE_TYPE_EXTERNAL_MEMORY_IMAGE_CREATE_INFO_KHR = 1000072001, + VK_STRUCTURE_TYPE_EXPORT_MEMORY_ALLOCATE_INFO_KHR = 1000072002, + VK_STRUCTURE_TYPE_IMPORT_MEMORY_WIN32_HANDLE_INFO_KHR = 1000073000, + VK_STRUCTURE_TYPE_EXPORT_MEMORY_WIN32_HANDLE_INFO_KHR = 1000073001, + VK_STRUCTURE_TYPE_MEMORY_WIN32_HANDLE_PROPERTIES_KHR = 1000073002, + VK_STRUCTURE_TYPE_MEMORY_GET_WIN32_HANDLE_INFO_KHR = 1000073003, + VK_STRUCTURE_TYPE_IMPORT_MEMORY_FD_INFO_KHR = 1000074000, + VK_STRUCTURE_TYPE_MEMORY_FD_PROPERTIES_KHR = 1000074001, + VK_STRUCTURE_TYPE_MEMORY_GET_FD_INFO_KHR = 1000074002, + VK_STRUCTURE_TYPE_WIN32_KEYED_MUTEX_ACQUIRE_RELEASE_INFO_KHR = 1000075000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_EXTERNAL_SEMAPHORE_INFO_KHR = 1000076000, + VK_STRUCTURE_TYPE_EXTERNAL_SEMAPHORE_PROPERTIES_KHR = 1000076001, + VK_STRUCTURE_TYPE_EXPORT_SEMAPHORE_CREATE_INFO_KHR = 1000077000, + VK_STRUCTURE_TYPE_IMPORT_SEMAPHORE_WIN32_HANDLE_INFO_KHR = 1000078000, + VK_STRUCTURE_TYPE_EXPORT_SEMAPHORE_WIN32_HANDLE_INFO_KHR = 1000078001, + VK_STRUCTURE_TYPE_D3D12_FENCE_SUBMIT_INFO_KHR = 1000078002, + VK_STRUCTURE_TYPE_SEMAPHORE_GET_WIN32_HANDLE_INFO_KHR = 1000078003, + VK_STRUCTURE_TYPE_IMPORT_SEMAPHORE_FD_INFO_KHR = 1000079000, + VK_STRUCTURE_TYPE_SEMAPHORE_GET_FD_INFO_KHR = 1000079001, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PUSH_DESCRIPTOR_PROPERTIES_KHR = 1000080000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES_KHR = 1000083000, + VK_STRUCTURE_TYPE_PRESENT_REGIONS_KHR = 1000084000, + VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR = 1000085000, + VK_STRUCTURE_TYPE_OBJECT_TABLE_CREATE_INFO_NVX = 1000086000, + VK_STRUCTURE_TYPE_INDIRECT_COMMANDS_LAYOUT_CREATE_INFO_NVX = 1000086001, + VK_STRUCTURE_TYPE_CMD_PROCESS_COMMANDS_INFO_NVX = 1000086002, + VK_STRUCTURE_TYPE_CMD_RESERVE_SPACE_FOR_COMMANDS_INFO_NVX = 1000086003, + VK_STRUCTURE_TYPE_DEVICE_GENERATED_COMMANDS_LIMITS_NVX = 1000086004, + VK_STRUCTURE_TYPE_DEVICE_GENERATED_COMMANDS_FEATURES_NVX = 1000086005, + VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_W_SCALING_STATE_CREATE_INFO_NV = 1000087000, + VK_STRUCTURE_TYPE_SURFACE_CAPABILITIES_2_EXT = 1000090000, + VK_STRUCTURE_TYPE_DISPLAY_POWER_INFO_EXT = 1000091000, + VK_STRUCTURE_TYPE_DEVICE_EVENT_INFO_EXT = 1000091001, + VK_STRUCTURE_TYPE_DISPLAY_EVENT_INFO_EXT = 1000091002, + VK_STRUCTURE_TYPE_SWAPCHAIN_COUNTER_CREATE_INFO_EXT = 1000091003, + VK_STRUCTURE_TYPE_PRESENT_TIMES_INFO_GOOGLE = 1000092000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MULTIVIEW_PER_VIEW_ATTRIBUTES_PROPERTIES_NVX = 1000097000, + VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_SWIZZLE_STATE_CREATE_INFO_NV = 1000098000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DISCARD_RECTANGLE_PROPERTIES_EXT = 1000099000, + VK_STRUCTURE_TYPE_PIPELINE_DISCARD_RECTANGLE_STATE_CREATE_INFO_EXT = 1000099001, + VK_STRUCTURE_TYPE_HDR_METADATA_EXT = 1000105000, + VK_STRUCTURE_TYPE_SHARED_PRESENT_SURFACE_CAPABILITIES_KHR = 1000111000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_EXTERNAL_FENCE_INFO_KHR = 1000112000, + VK_STRUCTURE_TYPE_EXTERNAL_FENCE_PROPERTIES_KHR = 1000112001, + VK_STRUCTURE_TYPE_EXPORT_FENCE_CREATE_INFO_KHR = 1000113000, + VK_STRUCTURE_TYPE_IMPORT_FENCE_WIN32_HANDLE_INFO_KHR = 1000114000, + VK_STRUCTURE_TYPE_EXPORT_FENCE_WIN32_HANDLE_INFO_KHR = 1000114001, + VK_STRUCTURE_TYPE_FENCE_GET_WIN32_HANDLE_INFO_KHR = 1000114002, + VK_STRUCTURE_TYPE_IMPORT_FENCE_FD_INFO_KHR = 1000115000, + VK_STRUCTURE_TYPE_FENCE_GET_FD_INFO_KHR = 1000115001, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SURFACE_INFO_2_KHR = 1000119000, + VK_STRUCTURE_TYPE_SURFACE_CAPABILITIES_2_KHR = 1000119001, + VK_STRUCTURE_TYPE_SURFACE_FORMAT_2_KHR = 1000119002, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VARIABLE_POINTER_FEATURES_KHR = 1000120000, + VK_STRUCTURE_TYPE_IOS_SURFACE_CREATE_INFO_MVK = 1000122000, + VK_STRUCTURE_TYPE_MACOS_SURFACE_CREATE_INFO_MVK = 1000123000, + VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR = 1000127000, + VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR = 1000127001, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SAMPLER_FILTER_MINMAX_PROPERTIES_EXT = 1000130000, + VK_STRUCTURE_TYPE_SAMPLER_REDUCTION_MODE_CREATE_INFO_EXT = 1000130001, + VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR = 1000146000, + VK_STRUCTURE_TYPE_IMAGE_MEMORY_REQUIREMENTS_INFO_2_KHR = 1000146001, + VK_STRUCTURE_TYPE_IMAGE_SPARSE_MEMORY_REQUIREMENTS_INFO_2_KHR = 1000146002, + VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR = 1000146003, + VK_STRUCTURE_TYPE_SPARSE_IMAGE_MEMORY_REQUIREMENTS_2_KHR = 1000146004, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BLEND_OPERATION_ADVANCED_FEATURES_EXT = 1000148000, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BLEND_OPERATION_ADVANCED_PROPERTIES_EXT = 1000148001, + VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_ADVANCED_STATE_CREATE_INFO_EXT = 1000148002, + VK_STRUCTURE_TYPE_PIPELINE_COVERAGE_TO_COLOR_STATE_CREATE_INFO_NV = 1000149000, + VK_STRUCTURE_TYPE_PIPELINE_COVERAGE_MODULATION_STATE_CREATE_INFO_NV = 1000152000, VK_STRUCTURE_TYPE_BEGIN_RANGE = VK_STRUCTURE_TYPE_APPLICATION_INFO, VK_STRUCTURE_TYPE_END_RANGE = VK_STRUCTURE_TYPE_LOADER_DEVICE_CREATE_INFO, VK_STRUCTURE_TYPE_RANGE_SIZE = (VK_STRUCTURE_TYPE_LOADER_DEVICE_CREATE_INFO - VK_STRUCTURE_TYPE_APPLICATION_INFO + 1), @@ -513,6 +620,7 @@ typedef enum VkImageLayout { VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL = 7, VK_IMAGE_LAYOUT_PREINITIALIZED = 8, VK_IMAGE_LAYOUT_PRESENT_SRC_KHR = 1000001002, + VK_IMAGE_LAYOUT_SHARED_PRESENT_KHR = 1000111000, VK_IMAGE_LAYOUT_BEGIN_RANGE = VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_END_RANGE = VK_IMAGE_LAYOUT_PREINITIALIZED, VK_IMAGE_LAYOUT_RANGE_SIZE = (VK_IMAGE_LAYOUT_PREINITIALIZED - VK_IMAGE_LAYOUT_UNDEFINED + 1), @@ -578,6 +686,7 @@ typedef enum VkPolygonMode { VK_POLYGON_MODE_FILL = 0, VK_POLYGON_MODE_LINE = 1, VK_POLYGON_MODE_POINT = 2, + VK_POLYGON_MODE_FILL_RECTANGLE_NV = 1000153000, VK_POLYGON_MODE_BEGIN_RANGE = VK_POLYGON_MODE_FILL, VK_POLYGON_MODE_END_RANGE = VK_POLYGON_MODE_POINT, VK_POLYGON_MODE_RANGE_SIZE = (VK_POLYGON_MODE_POINT - VK_POLYGON_MODE_FILL + 1), @@ -678,6 +787,52 @@ typedef enum VkBlendOp { VK_BLEND_OP_REVERSE_SUBTRACT = 2, VK_BLEND_OP_MIN = 3, VK_BLEND_OP_MAX = 4, + VK_BLEND_OP_ZERO_EXT = 1000148000, + VK_BLEND_OP_SRC_EXT = 1000148001, + VK_BLEND_OP_DST_EXT = 1000148002, + VK_BLEND_OP_SRC_OVER_EXT = 1000148003, + VK_BLEND_OP_DST_OVER_EXT = 1000148004, + VK_BLEND_OP_SRC_IN_EXT = 1000148005, + VK_BLEND_OP_DST_IN_EXT = 1000148006, + VK_BLEND_OP_SRC_OUT_EXT = 1000148007, + VK_BLEND_OP_DST_OUT_EXT = 1000148008, + VK_BLEND_OP_SRC_ATOP_EXT = 1000148009, + VK_BLEND_OP_DST_ATOP_EXT = 1000148010, + VK_BLEND_OP_XOR_EXT = 1000148011, + VK_BLEND_OP_MULTIPLY_EXT = 1000148012, + VK_BLEND_OP_SCREEN_EXT = 1000148013, + VK_BLEND_OP_OVERLAY_EXT = 1000148014, + VK_BLEND_OP_DARKEN_EXT = 1000148015, + VK_BLEND_OP_LIGHTEN_EXT = 1000148016, + VK_BLEND_OP_COLORDODGE_EXT = 1000148017, + VK_BLEND_OP_COLORBURN_EXT = 1000148018, + VK_BLEND_OP_HARDLIGHT_EXT = 1000148019, + VK_BLEND_OP_SOFTLIGHT_EXT = 1000148020, + VK_BLEND_OP_DIFFERENCE_EXT = 1000148021, + VK_BLEND_OP_EXCLUSION_EXT = 1000148022, + VK_BLEND_OP_INVERT_EXT = 1000148023, + VK_BLEND_OP_INVERT_RGB_EXT = 1000148024, + VK_BLEND_OP_LINEARDODGE_EXT = 1000148025, + VK_BLEND_OP_LINEARBURN_EXT = 1000148026, + VK_BLEND_OP_VIVIDLIGHT_EXT = 1000148027, + VK_BLEND_OP_LINEARLIGHT_EXT = 1000148028, + VK_BLEND_OP_PINLIGHT_EXT = 1000148029, + VK_BLEND_OP_HARDMIX_EXT = 1000148030, + VK_BLEND_OP_HSL_HUE_EXT = 1000148031, + VK_BLEND_OP_HSL_SATURATION_EXT = 1000148032, + VK_BLEND_OP_HSL_COLOR_EXT = 1000148033, + VK_BLEND_OP_HSL_LUMINOSITY_EXT = 1000148034, + VK_BLEND_OP_PLUS_EXT = 1000148035, + VK_BLEND_OP_PLUS_CLAMPED_EXT = 1000148036, + VK_BLEND_OP_PLUS_CLAMPED_ALPHA_EXT = 1000148037, + VK_BLEND_OP_PLUS_DARKER_EXT = 1000148038, + VK_BLEND_OP_MINUS_EXT = 1000148039, + VK_BLEND_OP_MINUS_CLAMPED_EXT = 1000148040, + VK_BLEND_OP_CONTRAST_EXT = 1000148041, + VK_BLEND_OP_INVERT_OVG_EXT = 1000148042, + VK_BLEND_OP_RED_EXT = 1000148043, + VK_BLEND_OP_GREEN_EXT = 1000148044, + VK_BLEND_OP_BLUE_EXT = 1000148045, VK_BLEND_OP_BEGIN_RANGE = VK_BLEND_OP_ADD, VK_BLEND_OP_END_RANGE = VK_BLEND_OP_MAX, VK_BLEND_OP_RANGE_SIZE = (VK_BLEND_OP_MAX - VK_BLEND_OP_ADD + 1), @@ -694,6 +849,8 @@ typedef enum VkDynamicState { VK_DYNAMIC_STATE_STENCIL_COMPARE_MASK = 6, VK_DYNAMIC_STATE_STENCIL_WRITE_MASK = 7, VK_DYNAMIC_STATE_STENCIL_REFERENCE = 8, + VK_DYNAMIC_STATE_VIEWPORT_W_SCALING_NV = 1000087000, + VK_DYNAMIC_STATE_DISCARD_RECTANGLE_EXT = 1000099000, VK_DYNAMIC_STATE_BEGIN_RANGE = VK_DYNAMIC_STATE_VIEWPORT, VK_DYNAMIC_STATE_END_RANGE = VK_DYNAMIC_STATE_STENCIL_REFERENCE, VK_DYNAMIC_STATE_RANGE_SIZE = (VK_DYNAMIC_STATE_STENCIL_REFERENCE - VK_DYNAMIC_STATE_VIEWPORT + 1), @@ -817,6 +974,47 @@ typedef enum VkSubpassContents { VK_SUBPASS_CONTENTS_MAX_ENUM = 0x7FFFFFFF } VkSubpassContents; +typedef enum VkObjectType { + VK_OBJECT_TYPE_UNKNOWN = 0, + VK_OBJECT_TYPE_INSTANCE = 1, + VK_OBJECT_TYPE_PHYSICAL_DEVICE = 2, + VK_OBJECT_TYPE_DEVICE = 3, + VK_OBJECT_TYPE_QUEUE = 4, + VK_OBJECT_TYPE_SEMAPHORE = 5, + VK_OBJECT_TYPE_COMMAND_BUFFER = 6, + VK_OBJECT_TYPE_FENCE = 7, + VK_OBJECT_TYPE_DEVICE_MEMORY = 8, + VK_OBJECT_TYPE_BUFFER = 9, + VK_OBJECT_TYPE_IMAGE = 10, + VK_OBJECT_TYPE_EVENT = 11, + VK_OBJECT_TYPE_QUERY_POOL = 12, + VK_OBJECT_TYPE_BUFFER_VIEW = 13, + VK_OBJECT_TYPE_IMAGE_VIEW = 14, + VK_OBJECT_TYPE_SHADER_MODULE = 15, + VK_OBJECT_TYPE_PIPELINE_CACHE = 16, + VK_OBJECT_TYPE_PIPELINE_LAYOUT = 17, + VK_OBJECT_TYPE_RENDER_PASS = 18, + VK_OBJECT_TYPE_PIPELINE = 19, + VK_OBJECT_TYPE_DESCRIPTOR_SET_LAYOUT = 20, + VK_OBJECT_TYPE_SAMPLER = 21, + VK_OBJECT_TYPE_DESCRIPTOR_POOL = 22, + VK_OBJECT_TYPE_DESCRIPTOR_SET = 23, + VK_OBJECT_TYPE_FRAMEBUFFER = 24, + VK_OBJECT_TYPE_COMMAND_POOL = 25, + VK_OBJECT_TYPE_SURFACE_KHR = 1000000000, + VK_OBJECT_TYPE_SWAPCHAIN_KHR = 1000001000, + VK_OBJECT_TYPE_DISPLAY_KHR = 1000002000, + VK_OBJECT_TYPE_DISPLAY_MODE_KHR = 1000002001, + VK_OBJECT_TYPE_DEBUG_REPORT_CALLBACK_EXT = 1000011000, + VK_OBJECT_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_KHR = 1000085000, + VK_OBJECT_TYPE_OBJECT_TABLE_NVX = 1000086000, + VK_OBJECT_TYPE_INDIRECT_COMMANDS_LAYOUT_NVX = 1000086001, + VK_OBJECT_TYPE_BEGIN_RANGE = VK_OBJECT_TYPE_UNKNOWN, + VK_OBJECT_TYPE_END_RANGE = VK_OBJECT_TYPE_COMMAND_POOL, + VK_OBJECT_TYPE_RANGE_SIZE = (VK_OBJECT_TYPE_COMMAND_POOL - VK_OBJECT_TYPE_UNKNOWN + 1), + VK_OBJECT_TYPE_MAX_ENUM = 0x7FFFFFFF +} VkObjectType; + typedef VkFlags VkInstanceCreateFlags; typedef enum VkFormatFeatureFlagBits { @@ -834,6 +1032,9 @@ typedef enum VkFormatFeatureFlagBits { VK_FORMAT_FEATURE_BLIT_DST_BIT = 0x00000800, VK_FORMAT_FEATURE_SAMPLED_IMAGE_FILTER_LINEAR_BIT = 0x00001000, VK_FORMAT_FEATURE_SAMPLED_IMAGE_FILTER_CUBIC_BIT_IMG = 0x00002000, + VK_FORMAT_FEATURE_TRANSFER_SRC_BIT_KHR = 0x00004000, + VK_FORMAT_FEATURE_TRANSFER_DST_BIT_KHR = 0x00008000, + VK_FORMAT_FEATURE_SAMPLED_IMAGE_FILTER_MINMAX_BIT_EXT = 0x00010000, VK_FORMAT_FEATURE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF } VkFormatFeatureFlagBits; typedef VkFlags VkFormatFeatureFlags; @@ -857,6 +1058,8 @@ typedef enum VkImageCreateFlagBits { VK_IMAGE_CREATE_SPARSE_ALIASED_BIT = 0x00000004, VK_IMAGE_CREATE_MUTABLE_FORMAT_BIT = 0x00000008, VK_IMAGE_CREATE_CUBE_COMPATIBLE_BIT = 0x00000010, + VK_IMAGE_CREATE_BIND_SFR_BIT_KHX = 0x00000040, + VK_IMAGE_CREATE_2D_ARRAY_COMPATIBLE_BIT_KHR = 0x00000020, VK_IMAGE_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF } VkImageCreateFlagBits; typedef VkFlags VkImageCreateFlags; @@ -894,6 +1097,7 @@ typedef VkFlags VkMemoryPropertyFlags; typedef enum VkMemoryHeapFlagBits { VK_MEMORY_HEAP_DEVICE_LOCAL_BIT = 0x00000001, + VK_MEMORY_HEAP_MULTI_INSTANCE_BIT_KHX = 0x00000002, VK_MEMORY_HEAP_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF } VkMemoryHeapFlagBits; typedef VkFlags VkMemoryHeapFlags; @@ -918,6 +1122,7 @@ typedef enum VkPipelineStageFlagBits { VK_PIPELINE_STAGE_HOST_BIT = 0x00004000, VK_PIPELINE_STAGE_ALL_GRAPHICS_BIT = 0x00008000, VK_PIPELINE_STAGE_ALL_COMMANDS_BIT = 0x00010000, + VK_PIPELINE_STAGE_COMMAND_PROCESS_BIT_NVX = 0x00020000, VK_PIPELINE_STAGE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF } VkPipelineStageFlagBits; typedef VkFlags VkPipelineStageFlags; @@ -1010,6 +1215,8 @@ typedef enum VkPipelineCreateFlagBits { VK_PIPELINE_CREATE_DISABLE_OPTIMIZATION_BIT = 0x00000001, VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT = 0x00000002, VK_PIPELINE_CREATE_DERIVATIVE_BIT = 0x00000004, + VK_PIPELINE_CREATE_VIEW_INDEX_FROM_DEVICE_INDEX_BIT_KHX = 0x00000008, + VK_PIPELINE_CREATE_DISPATCH_BASE_KHX = 0x00000010, VK_PIPELINE_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF } VkPipelineCreateFlagBits; typedef VkFlags VkPipelineCreateFlags; @@ -1056,6 +1263,11 @@ typedef VkFlags VkPipelineDynamicStateCreateFlags; typedef VkFlags VkPipelineLayoutCreateFlags; typedef VkFlags VkShaderStageFlags; typedef VkFlags VkSamplerCreateFlags; + +typedef enum VkDescriptorSetLayoutCreateFlagBits { + VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR = 0x00000001, + VK_DESCRIPTOR_SET_LAYOUT_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF +} VkDescriptorSetLayoutCreateFlagBits; typedef VkFlags VkDescriptorSetLayoutCreateFlags; typedef enum VkDescriptorPoolCreateFlagBits { @@ -1072,6 +1284,12 @@ typedef enum VkAttachmentDescriptionFlagBits { VK_ATTACHMENT_DESCRIPTION_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF } VkAttachmentDescriptionFlagBits; typedef VkFlags VkAttachmentDescriptionFlags; + +typedef enum VkSubpassDescriptionFlagBits { + VK_SUBPASS_DESCRIPTION_PER_VIEW_ATTRIBUTES_BIT_NVX = 0x00000001, + VK_SUBPASS_DESCRIPTION_PER_VIEW_POSITION_X_ONLY_BIT_NVX = 0x00000002, + VK_SUBPASS_DESCRIPTION_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF +} VkSubpassDescriptionFlagBits; typedef VkFlags VkSubpassDescriptionFlags; typedef enum VkAccessFlagBits { @@ -1092,12 +1310,17 @@ typedef enum VkAccessFlagBits { VK_ACCESS_HOST_WRITE_BIT = 0x00004000, VK_ACCESS_MEMORY_READ_BIT = 0x00008000, VK_ACCESS_MEMORY_WRITE_BIT = 0x00010000, + VK_ACCESS_COMMAND_PROCESS_READ_BIT_NVX = 0x00020000, + VK_ACCESS_COMMAND_PROCESS_WRITE_BIT_NVX = 0x00040000, + VK_ACCESS_COLOR_ATTACHMENT_READ_NONCOHERENT_BIT_EXT = 0x00080000, VK_ACCESS_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF } VkAccessFlagBits; typedef VkFlags VkAccessFlags; typedef enum VkDependencyFlagBits { VK_DEPENDENCY_BY_REGION_BIT = 0x00000001, + VK_DEPENDENCY_VIEW_LOCAL_BIT_KHX = 0x00000002, + VK_DEPENDENCY_DEVICE_GROUP_BIT_KHX = 0x00000004, VK_DEPENDENCY_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF } VkDependencyFlagBits; typedef VkFlags VkDependencyFlags; @@ -1143,6 +1366,27 @@ typedef enum VkStencilFaceFlagBits { } VkStencilFaceFlagBits; typedef VkFlags VkStencilFaceFlags; +typedef struct VkApplicationInfo { + VkStructureType sType; + const void* pNext; + const char* pApplicationName; + uint32_t applicationVersion; + const char* pEngineName; + uint32_t engineVersion; + uint32_t apiVersion; +} VkApplicationInfo; + +typedef struct VkInstanceCreateInfo { + VkStructureType sType; + const void* pNext; + VkInstanceCreateFlags flags; + const VkApplicationInfo* pApplicationInfo; + uint32_t enabledLayerCount; + const char* const* ppEnabledLayerNames; + uint32_t enabledExtensionCount; + const char* const* ppEnabledExtensionNames; +} VkInstanceCreateInfo; + typedef void* (VKAPI_PTR *PFN_vkAllocationFunction)( void* pUserData, size_t size, @@ -1172,29 +1416,6 @@ typedef void (VKAPI_PTR *PFN_vkInternalFreeNotification)( VkInternalAllocationType allocationType, VkSystemAllocationScope allocationScope); -typedef void (VKAPI_PTR *PFN_vkVoidFunction)(void); - -typedef struct VkApplicationInfo { - VkStructureType sType; - const void* pNext; - const char* pApplicationName; - uint32_t applicationVersion; - const char* pEngineName; - uint32_t engineVersion; - uint32_t apiVersion; -} VkApplicationInfo; - -typedef struct VkInstanceCreateInfo { - VkStructureType sType; - const void* pNext; - VkInstanceCreateFlags flags; - const VkApplicationInfo* pApplicationInfo; - uint32_t enabledLayerCount; - const char* const* ppEnabledLayerNames; - uint32_t enabledExtensionCount; - const char* const* ppEnabledExtensionNames; -} VkInstanceCreateInfo; - typedef struct VkAllocationCallbacks { void* pUserData; PFN_vkAllocationFunction pfnAllocation; @@ -1435,6 +1656,7 @@ typedef struct VkPhysicalDeviceMemoryProperties { VkMemoryHeap memoryHeaps[VK_MAX_MEMORY_HEAPS]; } VkPhysicalDeviceMemoryProperties; +typedef void (VKAPI_PTR *PFN_vkVoidFunction)(void); typedef struct VkDeviceQueueCreateInfo { VkStructureType sType; const void* pNext; @@ -2360,7 +2582,7 @@ typedef void (VKAPI_PTR *PFN_vkCmdDraw)(VkCommandBuffer commandBuffer, uint32_t typedef void (VKAPI_PTR *PFN_vkCmdDrawIndexed)(VkCommandBuffer commandBuffer, uint32_t indexCount, uint32_t instanceCount, uint32_t firstIndex, int32_t vertexOffset, uint32_t firstInstance); typedef void (VKAPI_PTR *PFN_vkCmdDrawIndirect)(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset, uint32_t drawCount, uint32_t stride); typedef void (VKAPI_PTR *PFN_vkCmdDrawIndexedIndirect)(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset, uint32_t drawCount, uint32_t stride); -typedef void (VKAPI_PTR *PFN_vkCmdDispatch)(VkCommandBuffer commandBuffer, uint32_t x, uint32_t y, uint32_t z); +typedef void (VKAPI_PTR *PFN_vkCmdDispatch)(VkCommandBuffer commandBuffer, uint32_t groupCountX, uint32_t groupCountY, uint32_t groupCountZ); typedef void (VKAPI_PTR *PFN_vkCmdDispatchIndirect)(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset); typedef void (VKAPI_PTR *PFN_vkCmdCopyBuffer)(VkCommandBuffer commandBuffer, VkBuffer srcBuffer, VkBuffer dstBuffer, uint32_t regionCount, const VkBufferCopy* pRegions); typedef void (VKAPI_PTR *PFN_vkCmdCopyImage)(VkCommandBuffer commandBuffer, VkImage srcImage, VkImageLayout srcImageLayout, VkImage dstImage, VkImageLayout dstImageLayout, uint32_t regionCount, const VkImageCopy* pRegions); @@ -2996,9 +3218,9 @@ VKAPI_ATTR void VKAPI_CALL vkCmdDrawIndexedIndirect( VKAPI_ATTR void VKAPI_CALL vkCmdDispatch( VkCommandBuffer commandBuffer, - uint32_t x, - uint32_t y, - uint32_t z); + uint32_t groupCountX, + uint32_t groupCountY, + uint32_t groupCountZ); VKAPI_ATTR void VKAPI_CALL vkCmdDispatchIndirect( VkCommandBuffer commandBuffer, @@ -3197,6 +3419,20 @@ VK_DEFINE_NON_DISPATCHABLE_HANDLE(VkSurfaceKHR) typedef enum VkColorSpaceKHR { VK_COLOR_SPACE_SRGB_NONLINEAR_KHR = 0, + VK_COLOR_SPACE_DISPLAY_P3_NONLINEAR_EXT = 1000104001, + VK_COLOR_SPACE_EXTENDED_SRGB_LINEAR_EXT = 1000104002, + VK_COLOR_SPACE_DCI_P3_LINEAR_EXT = 1000104003, + VK_COLOR_SPACE_DCI_P3_NONLINEAR_EXT = 1000104004, + VK_COLOR_SPACE_BT709_LINEAR_EXT = 1000104005, + VK_COLOR_SPACE_BT709_NONLINEAR_EXT = 1000104006, + VK_COLOR_SPACE_BT2020_LINEAR_EXT = 1000104007, + VK_COLOR_SPACE_HDR10_ST2084_EXT = 1000104008, + VK_COLOR_SPACE_DOLBYVISION_EXT = 1000104009, + VK_COLOR_SPACE_HDR10_HLG_EXT = 1000104010, + VK_COLOR_SPACE_ADOBERGB_LINEAR_EXT = 1000104011, + VK_COLOR_SPACE_ADOBERGB_NONLINEAR_EXT = 1000104012, + VK_COLOR_SPACE_PASS_THROUGH_EXT = 1000104013, + VK_COLOR_SPACE_EXTENDED_SRGB_NONLINEAR_EXT = 1000104014, VK_COLOR_SPACE_BEGIN_RANGE_KHR = VK_COLOR_SPACE_SRGB_NONLINEAR_KHR, VK_COLOR_SPACE_END_RANGE_KHR = VK_COLOR_SPACE_SRGB_NONLINEAR_KHR, VK_COLOR_SPACE_RANGE_SIZE_KHR = (VK_COLOR_SPACE_SRGB_NONLINEAR_KHR - VK_COLOR_SPACE_SRGB_NONLINEAR_KHR + 1), @@ -3208,6 +3444,8 @@ typedef enum VkPresentModeKHR { VK_PRESENT_MODE_MAILBOX_KHR = 1, VK_PRESENT_MODE_FIFO_KHR = 2, VK_PRESENT_MODE_FIFO_RELAXED_KHR = 3, + VK_PRESENT_MODE_SHARED_DEMAND_REFRESH_KHR = 1000111000, + VK_PRESENT_MODE_SHARED_CONTINUOUS_REFRESH_KHR = 1000111001, VK_PRESENT_MODE_BEGIN_RANGE_KHR = VK_PRESENT_MODE_IMMEDIATE_KHR, VK_PRESENT_MODE_END_RANGE_KHR = VK_PRESENT_MODE_FIFO_RELAXED_KHR, VK_PRESENT_MODE_RANGE_SIZE_KHR = (VK_PRESENT_MODE_FIFO_RELAXED_KHR - VK_PRESENT_MODE_IMMEDIATE_KHR + 1), @@ -3299,6 +3537,11 @@ VK_DEFINE_NON_DISPATCHABLE_HANDLE(VkSwapchainKHR) #define VK_KHR_SWAPCHAIN_SPEC_VERSION 68 #define VK_KHR_SWAPCHAIN_EXTENSION_NAME "VK_KHR_swapchain" + +typedef enum VkSwapchainCreateFlagBitsKHR { + VK_SWAPCHAIN_CREATE_BIND_SFR_BIT_KHX = 0x00000001, + VK_SWAPCHAIN_CREATE_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkSwapchainCreateFlagBitsKHR; typedef VkFlags VkSwapchainCreateFlagsKHR; typedef struct VkSwapchainCreateInfoKHR { @@ -3599,7 +3842,7 @@ VKAPI_ATTR VkBool32 VKAPI_CALL vkGetPhysicalDeviceXcbPresentationSupportKHR( #define VK_KHR_wayland_surface 1 #include -#define VK_KHR_WAYLAND_SURFACE_SPEC_VERSION 5 +#define VK_KHR_WAYLAND_SURFACE_SPEC_VERSION 6 #define VK_KHR_WAYLAND_SURFACE_EXTENSION_NAME "VK_KHR_wayland_surface" typedef VkFlags VkWaylandSurfaceCreateFlagsKHR; @@ -3697,7 +3940,7 @@ VKAPI_ATTR VkResult VKAPI_CALL vkCreateAndroidSurfaceKHR( #define VK_KHR_win32_surface 1 #include -#define VK_KHR_WIN32_SURFACE_SPEC_VERSION 5 +#define VK_KHR_WIN32_SURFACE_SPEC_VERSION 6 #define VK_KHR_WIN32_SURFACE_EXTENSION_NAME "VK_KHR_win32_surface" typedef VkFlags VkWin32SurfaceCreateFlagsKHR; @@ -3732,426 +3975,2480 @@ VKAPI_ATTR VkBool32 VKAPI_CALL vkGetPhysicalDeviceWin32PresentationSupportKHR( #define VK_KHR_SAMPLER_MIRROR_CLAMP_TO_EDGE_EXTENSION_NAME "VK_KHR_sampler_mirror_clamp_to_edge" -#define VK_EXT_debug_report 1 -VK_DEFINE_NON_DISPATCHABLE_HANDLE(VkDebugReportCallbackEXT) - -#define VK_EXT_DEBUG_REPORT_SPEC_VERSION 3 -#define VK_EXT_DEBUG_REPORT_EXTENSION_NAME "VK_EXT_debug_report" -#define VK_STRUCTURE_TYPE_DEBUG_REPORT_CREATE_INFO_EXT VK_STRUCTURE_TYPE_DEBUG_REPORT_CALLBACK_CREATE_INFO_EXT +#define VK_KHR_get_physical_device_properties2 1 +#define VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_SPEC_VERSION 1 +#define VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME "VK_KHR_get_physical_device_properties2" +typedef struct VkPhysicalDeviceFeatures2KHR { + VkStructureType sType; + void* pNext; + VkPhysicalDeviceFeatures features; +} VkPhysicalDeviceFeatures2KHR; -typedef enum VkDebugReportObjectTypeEXT { - VK_DEBUG_REPORT_OBJECT_TYPE_UNKNOWN_EXT = 0, - VK_DEBUG_REPORT_OBJECT_TYPE_INSTANCE_EXT = 1, - VK_DEBUG_REPORT_OBJECT_TYPE_PHYSICAL_DEVICE_EXT = 2, - VK_DEBUG_REPORT_OBJECT_TYPE_DEVICE_EXT = 3, - VK_DEBUG_REPORT_OBJECT_TYPE_QUEUE_EXT = 4, - VK_DEBUG_REPORT_OBJECT_TYPE_SEMAPHORE_EXT = 5, - VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT = 6, - VK_DEBUG_REPORT_OBJECT_TYPE_FENCE_EXT = 7, - VK_DEBUG_REPORT_OBJECT_TYPE_DEVICE_MEMORY_EXT = 8, - VK_DEBUG_REPORT_OBJECT_TYPE_BUFFER_EXT = 9, - VK_DEBUG_REPORT_OBJECT_TYPE_IMAGE_EXT = 10, - VK_DEBUG_REPORT_OBJECT_TYPE_EVENT_EXT = 11, - VK_DEBUG_REPORT_OBJECT_TYPE_QUERY_POOL_EXT = 12, - VK_DEBUG_REPORT_OBJECT_TYPE_BUFFER_VIEW_EXT = 13, - VK_DEBUG_REPORT_OBJECT_TYPE_IMAGE_VIEW_EXT = 14, - VK_DEBUG_REPORT_OBJECT_TYPE_SHADER_MODULE_EXT = 15, - VK_DEBUG_REPORT_OBJECT_TYPE_PIPELINE_CACHE_EXT = 16, - VK_DEBUG_REPORT_OBJECT_TYPE_PIPELINE_LAYOUT_EXT = 17, - VK_DEBUG_REPORT_OBJECT_TYPE_RENDER_PASS_EXT = 18, - VK_DEBUG_REPORT_OBJECT_TYPE_PIPELINE_EXT = 19, - VK_DEBUG_REPORT_OBJECT_TYPE_DESCRIPTOR_SET_LAYOUT_EXT = 20, - VK_DEBUG_REPORT_OBJECT_TYPE_SAMPLER_EXT = 21, - VK_DEBUG_REPORT_OBJECT_TYPE_DESCRIPTOR_POOL_EXT = 22, - VK_DEBUG_REPORT_OBJECT_TYPE_DESCRIPTOR_SET_EXT = 23, - VK_DEBUG_REPORT_OBJECT_TYPE_FRAMEBUFFER_EXT = 24, - VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_POOL_EXT = 25, - VK_DEBUG_REPORT_OBJECT_TYPE_SURFACE_KHR_EXT = 26, - VK_DEBUG_REPORT_OBJECT_TYPE_SWAPCHAIN_KHR_EXT = 27, - VK_DEBUG_REPORT_OBJECT_TYPE_DEBUG_REPORT_EXT = 28, - VK_DEBUG_REPORT_OBJECT_TYPE_BEGIN_RANGE_EXT = VK_DEBUG_REPORT_OBJECT_TYPE_UNKNOWN_EXT, - VK_DEBUG_REPORT_OBJECT_TYPE_END_RANGE_EXT = VK_DEBUG_REPORT_OBJECT_TYPE_DEBUG_REPORT_EXT, - VK_DEBUG_REPORT_OBJECT_TYPE_RANGE_SIZE_EXT = (VK_DEBUG_REPORT_OBJECT_TYPE_DEBUG_REPORT_EXT - VK_DEBUG_REPORT_OBJECT_TYPE_UNKNOWN_EXT + 1), - VK_DEBUG_REPORT_OBJECT_TYPE_MAX_ENUM_EXT = 0x7FFFFFFF -} VkDebugReportObjectTypeEXT; +typedef struct VkPhysicalDeviceProperties2KHR { + VkStructureType sType; + void* pNext; + VkPhysicalDeviceProperties properties; +} VkPhysicalDeviceProperties2KHR; -typedef enum VkDebugReportErrorEXT { - VK_DEBUG_REPORT_ERROR_NONE_EXT = 0, - VK_DEBUG_REPORT_ERROR_CALLBACK_REF_EXT = 1, - VK_DEBUG_REPORT_ERROR_BEGIN_RANGE_EXT = VK_DEBUG_REPORT_ERROR_NONE_EXT, - VK_DEBUG_REPORT_ERROR_END_RANGE_EXT = VK_DEBUG_REPORT_ERROR_CALLBACK_REF_EXT, - VK_DEBUG_REPORT_ERROR_RANGE_SIZE_EXT = (VK_DEBUG_REPORT_ERROR_CALLBACK_REF_EXT - VK_DEBUG_REPORT_ERROR_NONE_EXT + 1), - VK_DEBUG_REPORT_ERROR_MAX_ENUM_EXT = 0x7FFFFFFF -} VkDebugReportErrorEXT; +typedef struct VkFormatProperties2KHR { + VkStructureType sType; + void* pNext; + VkFormatProperties formatProperties; +} VkFormatProperties2KHR; +typedef struct VkImageFormatProperties2KHR { + VkStructureType sType; + void* pNext; + VkImageFormatProperties imageFormatProperties; +} VkImageFormatProperties2KHR; -typedef enum VkDebugReportFlagBitsEXT { - VK_DEBUG_REPORT_INFORMATION_BIT_EXT = 0x00000001, - VK_DEBUG_REPORT_WARNING_BIT_EXT = 0x00000002, - VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT = 0x00000004, - VK_DEBUG_REPORT_ERROR_BIT_EXT = 0x00000008, - VK_DEBUG_REPORT_DEBUG_BIT_EXT = 0x00000010, - VK_DEBUG_REPORT_FLAG_BITS_MAX_ENUM_EXT = 0x7FFFFFFF -} VkDebugReportFlagBitsEXT; -typedef VkFlags VkDebugReportFlagsEXT; +typedef struct VkPhysicalDeviceImageFormatInfo2KHR { + VkStructureType sType; + const void* pNext; + VkFormat format; + VkImageType type; + VkImageTiling tiling; + VkImageUsageFlags usage; + VkImageCreateFlags flags; +} VkPhysicalDeviceImageFormatInfo2KHR; + +typedef struct VkQueueFamilyProperties2KHR { + VkStructureType sType; + void* pNext; + VkQueueFamilyProperties queueFamilyProperties; +} VkQueueFamilyProperties2KHR; -typedef VkBool32 (VKAPI_PTR *PFN_vkDebugReportCallbackEXT)( - VkDebugReportFlagsEXT flags, - VkDebugReportObjectTypeEXT objectType, - uint64_t object, - size_t location, - int32_t messageCode, - const char* pLayerPrefix, - const char* pMessage, - void* pUserData); +typedef struct VkPhysicalDeviceMemoryProperties2KHR { + VkStructureType sType; + void* pNext; + VkPhysicalDeviceMemoryProperties memoryProperties; +} VkPhysicalDeviceMemoryProperties2KHR; +typedef struct VkSparseImageFormatProperties2KHR { + VkStructureType sType; + void* pNext; + VkSparseImageFormatProperties properties; +} VkSparseImageFormatProperties2KHR; -typedef struct VkDebugReportCallbackCreateInfoEXT { - VkStructureType sType; - const void* pNext; - VkDebugReportFlagsEXT flags; - PFN_vkDebugReportCallbackEXT pfnCallback; - void* pUserData; -} VkDebugReportCallbackCreateInfoEXT; +typedef struct VkPhysicalDeviceSparseImageFormatInfo2KHR { + VkStructureType sType; + const void* pNext; + VkFormat format; + VkImageType type; + VkSampleCountFlagBits samples; + VkImageUsageFlags usage; + VkImageTiling tiling; +} VkPhysicalDeviceSparseImageFormatInfo2KHR; -typedef VkResult (VKAPI_PTR *PFN_vkCreateDebugReportCallbackEXT)(VkInstance instance, const VkDebugReportCallbackCreateInfoEXT* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkDebugReportCallbackEXT* pCallback); -typedef void (VKAPI_PTR *PFN_vkDestroyDebugReportCallbackEXT)(VkInstance instance, VkDebugReportCallbackEXT callback, const VkAllocationCallbacks* pAllocator); -typedef void (VKAPI_PTR *PFN_vkDebugReportMessageEXT)(VkInstance instance, VkDebugReportFlagsEXT flags, VkDebugReportObjectTypeEXT objectType, uint64_t object, size_t location, int32_t messageCode, const char* pLayerPrefix, const char* pMessage); +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceFeatures2KHR)(VkPhysicalDevice physicalDevice, VkPhysicalDeviceFeatures2KHR* pFeatures); +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceProperties2KHR)(VkPhysicalDevice physicalDevice, VkPhysicalDeviceProperties2KHR* pProperties); +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceFormatProperties2KHR)(VkPhysicalDevice physicalDevice, VkFormat format, VkFormatProperties2KHR* pFormatProperties); +typedef VkResult (VKAPI_PTR *PFN_vkGetPhysicalDeviceImageFormatProperties2KHR)(VkPhysicalDevice physicalDevice, const VkPhysicalDeviceImageFormatInfo2KHR* pImageFormatInfo, VkImageFormatProperties2KHR* pImageFormatProperties); +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceQueueFamilyProperties2KHR)(VkPhysicalDevice physicalDevice, uint32_t* pQueueFamilyPropertyCount, VkQueueFamilyProperties2KHR* pQueueFamilyProperties); +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceMemoryProperties2KHR)(VkPhysicalDevice physicalDevice, VkPhysicalDeviceMemoryProperties2KHR* pMemoryProperties); +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceSparseImageFormatProperties2KHR)(VkPhysicalDevice physicalDevice, const VkPhysicalDeviceSparseImageFormatInfo2KHR* pFormatInfo, uint32_t* pPropertyCount, VkSparseImageFormatProperties2KHR* pProperties); #ifndef VK_NO_PROTOTYPES -VKAPI_ATTR VkResult VKAPI_CALL vkCreateDebugReportCallbackEXT( - VkInstance instance, - const VkDebugReportCallbackCreateInfoEXT* pCreateInfo, - const VkAllocationCallbacks* pAllocator, - VkDebugReportCallbackEXT* pCallback); +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceFeatures2KHR( + VkPhysicalDevice physicalDevice, + VkPhysicalDeviceFeatures2KHR* pFeatures); -VKAPI_ATTR void VKAPI_CALL vkDestroyDebugReportCallbackEXT( - VkInstance instance, - VkDebugReportCallbackEXT callback, - const VkAllocationCallbacks* pAllocator); +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceProperties2KHR( + VkPhysicalDevice physicalDevice, + VkPhysicalDeviceProperties2KHR* pProperties); -VKAPI_ATTR void VKAPI_CALL vkDebugReportMessageEXT( - VkInstance instance, - VkDebugReportFlagsEXT flags, - VkDebugReportObjectTypeEXT objectType, - uint64_t object, - size_t location, - int32_t messageCode, - const char* pLayerPrefix, - const char* pMessage); +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceFormatProperties2KHR( + VkPhysicalDevice physicalDevice, + VkFormat format, + VkFormatProperties2KHR* pFormatProperties); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetPhysicalDeviceImageFormatProperties2KHR( + VkPhysicalDevice physicalDevice, + const VkPhysicalDeviceImageFormatInfo2KHR* pImageFormatInfo, + VkImageFormatProperties2KHR* pImageFormatProperties); + +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceQueueFamilyProperties2KHR( + VkPhysicalDevice physicalDevice, + uint32_t* pQueueFamilyPropertyCount, + VkQueueFamilyProperties2KHR* pQueueFamilyProperties); + +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceMemoryProperties2KHR( + VkPhysicalDevice physicalDevice, + VkPhysicalDeviceMemoryProperties2KHR* pMemoryProperties); + +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceSparseImageFormatProperties2KHR( + VkPhysicalDevice physicalDevice, + const VkPhysicalDeviceSparseImageFormatInfo2KHR* pFormatInfo, + uint32_t* pPropertyCount, + VkSparseImageFormatProperties2KHR* pProperties); #endif -#define VK_NV_glsl_shader 1 -#define VK_NV_GLSL_SHADER_SPEC_VERSION 1 -#define VK_NV_GLSL_SHADER_EXTENSION_NAME "VK_NV_glsl_shader" +#define VK_KHR_shader_draw_parameters 1 +#define VK_KHR_SHADER_DRAW_PARAMETERS_SPEC_VERSION 1 +#define VK_KHR_SHADER_DRAW_PARAMETERS_EXTENSION_NAME "VK_KHR_shader_draw_parameters" -#define VK_IMG_filter_cubic 1 -#define VK_IMG_FILTER_CUBIC_SPEC_VERSION 1 -#define VK_IMG_FILTER_CUBIC_EXTENSION_NAME "VK_IMG_filter_cubic" +#define VK_KHR_maintenance1 1 +#define VK_KHR_MAINTENANCE1_SPEC_VERSION 1 +#define VK_KHR_MAINTENANCE1_EXTENSION_NAME "VK_KHR_maintenance1" +typedef VkFlags VkCommandPoolTrimFlagsKHR; -#define VK_AMD_rasterization_order 1 -#define VK_AMD_RASTERIZATION_ORDER_SPEC_VERSION 1 -#define VK_AMD_RASTERIZATION_ORDER_EXTENSION_NAME "VK_AMD_rasterization_order" +typedef void (VKAPI_PTR *PFN_vkTrimCommandPoolKHR)(VkDevice device, VkCommandPool commandPool, VkCommandPoolTrimFlagsKHR flags); +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkTrimCommandPoolKHR( + VkDevice device, + VkCommandPool commandPool, + VkCommandPoolTrimFlagsKHR flags); +#endif -typedef enum VkRasterizationOrderAMD { - VK_RASTERIZATION_ORDER_STRICT_AMD = 0, - VK_RASTERIZATION_ORDER_RELAXED_AMD = 1, - VK_RASTERIZATION_ORDER_BEGIN_RANGE_AMD = VK_RASTERIZATION_ORDER_STRICT_AMD, - VK_RASTERIZATION_ORDER_END_RANGE_AMD = VK_RASTERIZATION_ORDER_RELAXED_AMD, - VK_RASTERIZATION_ORDER_RANGE_SIZE_AMD = (VK_RASTERIZATION_ORDER_RELAXED_AMD - VK_RASTERIZATION_ORDER_STRICT_AMD + 1), - VK_RASTERIZATION_ORDER_MAX_ENUM_AMD = 0x7FFFFFFF -} VkRasterizationOrderAMD; +#define VK_KHR_external_memory_capabilities 1 +#define VK_LUID_SIZE_KHR 8 +#define VK_KHR_EXTERNAL_MEMORY_CAPABILITIES_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME "VK_KHR_external_memory_capabilities" + + +typedef enum VkExternalMemoryHandleTypeFlagBitsKHR { + VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD_BIT_KHR = 0x00000001, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_BIT_KHR = 0x00000002, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT_BIT_KHR = 0x00000004, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_TEXTURE_BIT_KHR = 0x00000008, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_TEXTURE_KMT_BIT_KHR = 0x00000010, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP_BIT_KHR = 0x00000020, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE_BIT_KHR = 0x00000040, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkExternalMemoryHandleTypeFlagBitsKHR; +typedef VkFlags VkExternalMemoryHandleTypeFlagsKHR; + +typedef enum VkExternalMemoryFeatureFlagBitsKHR { + VK_EXTERNAL_MEMORY_FEATURE_DEDICATED_ONLY_BIT_KHR = 0x00000001, + VK_EXTERNAL_MEMORY_FEATURE_EXPORTABLE_BIT_KHR = 0x00000002, + VK_EXTERNAL_MEMORY_FEATURE_IMPORTABLE_BIT_KHR = 0x00000004, + VK_EXTERNAL_MEMORY_FEATURE_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkExternalMemoryFeatureFlagBitsKHR; +typedef VkFlags VkExternalMemoryFeatureFlagsKHR; + +typedef struct VkExternalMemoryPropertiesKHR { + VkExternalMemoryFeatureFlagsKHR externalMemoryFeatures; + VkExternalMemoryHandleTypeFlagsKHR exportFromImportedHandleTypes; + VkExternalMemoryHandleTypeFlagsKHR compatibleHandleTypes; +} VkExternalMemoryPropertiesKHR; + +typedef struct VkPhysicalDeviceExternalImageFormatInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagBitsKHR handleType; +} VkPhysicalDeviceExternalImageFormatInfoKHR; -typedef struct VkPipelineRasterizationStateRasterizationOrderAMD { - VkStructureType sType; - const void* pNext; - VkRasterizationOrderAMD rasterizationOrder; -} VkPipelineRasterizationStateRasterizationOrderAMD; +typedef struct VkExternalImageFormatPropertiesKHR { + VkStructureType sType; + void* pNext; + VkExternalMemoryPropertiesKHR externalMemoryProperties; +} VkExternalImageFormatPropertiesKHR; + +typedef struct VkPhysicalDeviceExternalBufferInfoKHR { + VkStructureType sType; + const void* pNext; + VkBufferCreateFlags flags; + VkBufferUsageFlags usage; + VkExternalMemoryHandleTypeFlagBitsKHR handleType; +} VkPhysicalDeviceExternalBufferInfoKHR; +typedef struct VkExternalBufferPropertiesKHR { + VkStructureType sType; + void* pNext; + VkExternalMemoryPropertiesKHR externalMemoryProperties; +} VkExternalBufferPropertiesKHR; +typedef struct VkPhysicalDeviceIDPropertiesKHR { + VkStructureType sType; + void* pNext; + uint8_t deviceUUID[VK_UUID_SIZE]; + uint8_t driverUUID[VK_UUID_SIZE]; + uint8_t deviceLUID[VK_LUID_SIZE_KHR]; + uint32_t deviceNodeMask; + VkBool32 deviceLUIDValid; +} VkPhysicalDeviceIDPropertiesKHR; -#define VK_AMD_shader_trinary_minmax 1 -#define VK_AMD_SHADER_TRINARY_MINMAX_SPEC_VERSION 1 -#define VK_AMD_SHADER_TRINARY_MINMAX_EXTENSION_NAME "VK_AMD_shader_trinary_minmax" +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceExternalBufferPropertiesKHR)(VkPhysicalDevice physicalDevice, const VkPhysicalDeviceExternalBufferInfoKHR* pExternalBufferInfo, VkExternalBufferPropertiesKHR* pExternalBufferProperties); -#define VK_AMD_shader_explicit_vertex_parameter 1 -#define VK_AMD_SHADER_EXPLICIT_VERTEX_PARAMETER_SPEC_VERSION 1 -#define VK_AMD_SHADER_EXPLICIT_VERTEX_PARAMETER_EXTENSION_NAME "VK_AMD_shader_explicit_vertex_parameter" +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceExternalBufferPropertiesKHR( + VkPhysicalDevice physicalDevice, + const VkPhysicalDeviceExternalBufferInfoKHR* pExternalBufferInfo, + VkExternalBufferPropertiesKHR* pExternalBufferProperties); +#endif +#define VK_KHR_external_memory 1 +#define VK_KHR_EXTERNAL_MEMORY_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_MEMORY_EXTENSION_NAME "VK_KHR_external_memory" +#define VK_QUEUE_FAMILY_EXTERNAL_KHR (~0U-1) -#define VK_EXT_debug_marker 1 -#define VK_EXT_DEBUG_MARKER_SPEC_VERSION 3 -#define VK_EXT_DEBUG_MARKER_EXTENSION_NAME "VK_EXT_debug_marker" +typedef struct VkExternalMemoryImageCreateInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagsKHR handleTypes; +} VkExternalMemoryImageCreateInfoKHR; -typedef struct VkDebugMarkerObjectNameInfoEXT { - VkStructureType sType; - const void* pNext; - VkDebugReportObjectTypeEXT objectType; - uint64_t object; - const char* pObjectName; -} VkDebugMarkerObjectNameInfoEXT; +typedef struct VkExternalMemoryBufferCreateInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagsKHR handleTypes; +} VkExternalMemoryBufferCreateInfoKHR; -typedef struct VkDebugMarkerObjectTagInfoEXT { +typedef struct VkExportMemoryAllocateInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagsKHR handleTypes; +} VkExportMemoryAllocateInfoKHR; + + + +#ifdef VK_USE_PLATFORM_WIN32_KHR +#define VK_KHR_external_memory_win32 1 +#define VK_KHR_EXTERNAL_MEMORY_WIN32_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_MEMORY_WIN32_EXTENSION_NAME "VK_KHR_external_memory_win32" + +typedef struct VkImportMemoryWin32HandleInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagBitsKHR handleType; + HANDLE handle; + LPCWSTR name; +} VkImportMemoryWin32HandleInfoKHR; + +typedef struct VkExportMemoryWin32HandleInfoKHR { VkStructureType sType; const void* pNext; - VkDebugReportObjectTypeEXT objectType; - uint64_t object; - uint64_t tagName; - size_t tagSize; - const void* pTag; -} VkDebugMarkerObjectTagInfoEXT; + const SECURITY_ATTRIBUTES* pAttributes; + DWORD dwAccess; + LPCWSTR name; +} VkExportMemoryWin32HandleInfoKHR; -typedef struct VkDebugMarkerMarkerInfoEXT { +typedef struct VkMemoryWin32HandlePropertiesKHR { VkStructureType sType; - const void* pNext; - const char* pMarkerName; - float color[4]; -} VkDebugMarkerMarkerInfoEXT; + void* pNext; + uint32_t memoryTypeBits; +} VkMemoryWin32HandlePropertiesKHR; +typedef struct VkMemoryGetWin32HandleInfoKHR { + VkStructureType sType; + const void* pNext; + VkDeviceMemory memory; + VkExternalMemoryHandleTypeFlagBitsKHR handleType; +} VkMemoryGetWin32HandleInfoKHR; -typedef VkResult (VKAPI_PTR *PFN_vkDebugMarkerSetObjectTagEXT)(VkDevice device, VkDebugMarkerObjectTagInfoEXT* pTagInfo); -typedef VkResult (VKAPI_PTR *PFN_vkDebugMarkerSetObjectNameEXT)(VkDevice device, VkDebugMarkerObjectNameInfoEXT* pNameInfo); -typedef void (VKAPI_PTR *PFN_vkCmdDebugMarkerBeginEXT)(VkCommandBuffer commandBuffer, VkDebugMarkerMarkerInfoEXT* pMarkerInfo); -typedef void (VKAPI_PTR *PFN_vkCmdDebugMarkerEndEXT)(VkCommandBuffer commandBuffer); -typedef void (VKAPI_PTR *PFN_vkCmdDebugMarkerInsertEXT)(VkCommandBuffer commandBuffer, VkDebugMarkerMarkerInfoEXT* pMarkerInfo); + +typedef VkResult (VKAPI_PTR *PFN_vkGetMemoryWin32HandleKHR)(VkDevice device, const VkMemoryGetWin32HandleInfoKHR* pGetWin32HandleInfo, HANDLE* pHandle); +typedef VkResult (VKAPI_PTR *PFN_vkGetMemoryWin32HandlePropertiesKHR)(VkDevice device, VkExternalMemoryHandleTypeFlagBitsKHR handleType, HANDLE handle, VkMemoryWin32HandlePropertiesKHR* pMemoryWin32HandleProperties); #ifndef VK_NO_PROTOTYPES -VKAPI_ATTR VkResult VKAPI_CALL vkDebugMarkerSetObjectTagEXT( +VKAPI_ATTR VkResult VKAPI_CALL vkGetMemoryWin32HandleKHR( VkDevice device, - VkDebugMarkerObjectTagInfoEXT* pTagInfo); + const VkMemoryGetWin32HandleInfoKHR* pGetWin32HandleInfo, + HANDLE* pHandle); -VKAPI_ATTR VkResult VKAPI_CALL vkDebugMarkerSetObjectNameEXT( +VKAPI_ATTR VkResult VKAPI_CALL vkGetMemoryWin32HandlePropertiesKHR( VkDevice device, - VkDebugMarkerObjectNameInfoEXT* pNameInfo); + VkExternalMemoryHandleTypeFlagBitsKHR handleType, + HANDLE handle, + VkMemoryWin32HandlePropertiesKHR* pMemoryWin32HandleProperties); +#endif +#endif /* VK_USE_PLATFORM_WIN32_KHR */ -VKAPI_ATTR void VKAPI_CALL vkCmdDebugMarkerBeginEXT( - VkCommandBuffer commandBuffer, - VkDebugMarkerMarkerInfoEXT* pMarkerInfo); +#define VK_KHR_external_memory_fd 1 +#define VK_KHR_EXTERNAL_MEMORY_FD_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_MEMORY_FD_EXTENSION_NAME "VK_KHR_external_memory_fd" -VKAPI_ATTR void VKAPI_CALL vkCmdDebugMarkerEndEXT( - VkCommandBuffer commandBuffer); +typedef struct VkImportMemoryFdInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagBitsKHR handleType; + int fd; +} VkImportMemoryFdInfoKHR; -VKAPI_ATTR void VKAPI_CALL vkCmdDebugMarkerInsertEXT( - VkCommandBuffer commandBuffer, - VkDebugMarkerMarkerInfoEXT* pMarkerInfo); +typedef struct VkMemoryFdPropertiesKHR { + VkStructureType sType; + void* pNext; + uint32_t memoryTypeBits; +} VkMemoryFdPropertiesKHR; + +typedef struct VkMemoryGetFdInfoKHR { + VkStructureType sType; + const void* pNext; + VkDeviceMemory memory; + VkExternalMemoryHandleTypeFlagBitsKHR handleType; +} VkMemoryGetFdInfoKHR; + + +typedef VkResult (VKAPI_PTR *PFN_vkGetMemoryFdKHR)(VkDevice device, const VkMemoryGetFdInfoKHR* pGetFdInfo, int* pFd); +typedef VkResult (VKAPI_PTR *PFN_vkGetMemoryFdPropertiesKHR)(VkDevice device, VkExternalMemoryHandleTypeFlagBitsKHR handleType, int fd, VkMemoryFdPropertiesKHR* pMemoryFdProperties); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkGetMemoryFdKHR( + VkDevice device, + const VkMemoryGetFdInfoKHR* pGetFdInfo, + int* pFd); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetMemoryFdPropertiesKHR( + VkDevice device, + VkExternalMemoryHandleTypeFlagBitsKHR handleType, + int fd, + VkMemoryFdPropertiesKHR* pMemoryFdProperties); #endif -#define VK_AMD_gcn_shader 1 -#define VK_AMD_GCN_SHADER_SPEC_VERSION 1 -#define VK_AMD_GCN_SHADER_EXTENSION_NAME "VK_AMD_gcn_shader" +#ifdef VK_USE_PLATFORM_WIN32_KHR +#define VK_KHR_win32_keyed_mutex 1 +#define VK_KHR_WIN32_KEYED_MUTEX_SPEC_VERSION 1 +#define VK_KHR_WIN32_KEYED_MUTEX_EXTENSION_NAME "VK_KHR_win32_keyed_mutex" +typedef struct VkWin32KeyedMutexAcquireReleaseInfoKHR { + VkStructureType sType; + const void* pNext; + uint32_t acquireCount; + const VkDeviceMemory* pAcquireSyncs; + const uint64_t* pAcquireKeys; + const uint32_t* pAcquireTimeouts; + uint32_t releaseCount; + const VkDeviceMemory* pReleaseSyncs; + const uint64_t* pReleaseKeys; +} VkWin32KeyedMutexAcquireReleaseInfoKHR; -#define VK_NV_dedicated_allocation 1 -#define VK_NV_DEDICATED_ALLOCATION_SPEC_VERSION 1 -#define VK_NV_DEDICATED_ALLOCATION_EXTENSION_NAME "VK_NV_dedicated_allocation" -typedef struct VkDedicatedAllocationImageCreateInfoNV { - VkStructureType sType; - const void* pNext; - VkBool32 dedicatedAllocation; -} VkDedicatedAllocationImageCreateInfoNV; +#endif /* VK_USE_PLATFORM_WIN32_KHR */ -typedef struct VkDedicatedAllocationBufferCreateInfoNV { - VkStructureType sType; - const void* pNext; - VkBool32 dedicatedAllocation; -} VkDedicatedAllocationBufferCreateInfoNV; +#define VK_KHR_external_semaphore_capabilities 1 +#define VK_KHR_EXTERNAL_SEMAPHORE_CAPABILITIES_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_SEMAPHORE_CAPABILITIES_EXTENSION_NAME "VK_KHR_external_semaphore_capabilities" + + +typedef enum VkExternalSemaphoreHandleTypeFlagBitsKHR { + VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD_BIT_KHR = 0x00000001, + VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_BIT_KHR = 0x00000002, + VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT_BIT_KHR = 0x00000004, + VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE_BIT_KHR = 0x00000008, + VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT_KHR = 0x00000010, + VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkExternalSemaphoreHandleTypeFlagBitsKHR; +typedef VkFlags VkExternalSemaphoreHandleTypeFlagsKHR; + +typedef enum VkExternalSemaphoreFeatureFlagBitsKHR { + VK_EXTERNAL_SEMAPHORE_FEATURE_EXPORTABLE_BIT_KHR = 0x00000001, + VK_EXTERNAL_SEMAPHORE_FEATURE_IMPORTABLE_BIT_KHR = 0x00000002, + VK_EXTERNAL_SEMAPHORE_FEATURE_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkExternalSemaphoreFeatureFlagBitsKHR; +typedef VkFlags VkExternalSemaphoreFeatureFlagsKHR; + +typedef struct VkPhysicalDeviceExternalSemaphoreInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalSemaphoreHandleTypeFlagBitsKHR handleType; +} VkPhysicalDeviceExternalSemaphoreInfoKHR; -typedef struct VkDedicatedAllocationMemoryAllocateInfoNV { +typedef struct VkExternalSemaphorePropertiesKHR { + VkStructureType sType; + void* pNext; + VkExternalSemaphoreHandleTypeFlagsKHR exportFromImportedHandleTypes; + VkExternalSemaphoreHandleTypeFlagsKHR compatibleHandleTypes; + VkExternalSemaphoreFeatureFlagsKHR externalSemaphoreFeatures; +} VkExternalSemaphorePropertiesKHR; + + +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceExternalSemaphorePropertiesKHR)(VkPhysicalDevice physicalDevice, const VkPhysicalDeviceExternalSemaphoreInfoKHR* pExternalSemaphoreInfo, VkExternalSemaphorePropertiesKHR* pExternalSemaphoreProperties); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceExternalSemaphorePropertiesKHR( + VkPhysicalDevice physicalDevice, + const VkPhysicalDeviceExternalSemaphoreInfoKHR* pExternalSemaphoreInfo, + VkExternalSemaphorePropertiesKHR* pExternalSemaphoreProperties); +#endif + +#define VK_KHR_external_semaphore 1 +#define VK_KHR_EXTERNAL_SEMAPHORE_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_SEMAPHORE_EXTENSION_NAME "VK_KHR_external_semaphore" + + +typedef enum VkSemaphoreImportFlagBitsKHR { + VK_SEMAPHORE_IMPORT_TEMPORARY_BIT_KHR = 0x00000001, + VK_SEMAPHORE_IMPORT_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkSemaphoreImportFlagBitsKHR; +typedef VkFlags VkSemaphoreImportFlagsKHR; + +typedef struct VkExportSemaphoreCreateInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalSemaphoreHandleTypeFlagsKHR handleTypes; +} VkExportSemaphoreCreateInfoKHR; + + + +#ifdef VK_USE_PLATFORM_WIN32_KHR +#define VK_KHR_external_semaphore_win32 1 +#define VK_KHR_EXTERNAL_SEMAPHORE_WIN32_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_SEMAPHORE_WIN32_EXTENSION_NAME "VK_KHR_external_semaphore_win32" + +typedef struct VkImportSemaphoreWin32HandleInfoKHR { + VkStructureType sType; + const void* pNext; + VkSemaphore semaphore; + VkSemaphoreImportFlagsKHR flags; + VkExternalSemaphoreHandleTypeFlagBitsKHR handleType; + HANDLE handle; + LPCWSTR name; +} VkImportSemaphoreWin32HandleInfoKHR; + +typedef struct VkExportSemaphoreWin32HandleInfoKHR { + VkStructureType sType; + const void* pNext; + const SECURITY_ATTRIBUTES* pAttributes; + DWORD dwAccess; + LPCWSTR name; +} VkExportSemaphoreWin32HandleInfoKHR; + +typedef struct VkD3D12FenceSubmitInfoKHR { VkStructureType sType; const void* pNext; - VkImage image; - VkBuffer buffer; -} VkDedicatedAllocationMemoryAllocateInfoNV; + uint32_t waitSemaphoreValuesCount; + const uint64_t* pWaitSemaphoreValues; + uint32_t signalSemaphoreValuesCount; + const uint64_t* pSignalSemaphoreValues; +} VkD3D12FenceSubmitInfoKHR; + +typedef struct VkSemaphoreGetWin32HandleInfoKHR { + VkStructureType sType; + const void* pNext; + VkSemaphore semaphore; + VkExternalSemaphoreHandleTypeFlagBitsKHR handleType; +} VkSemaphoreGetWin32HandleInfoKHR; +typedef VkResult (VKAPI_PTR *PFN_vkImportSemaphoreWin32HandleKHR)(VkDevice device, const VkImportSemaphoreWin32HandleInfoKHR* pImportSemaphoreWin32HandleInfo); +typedef VkResult (VKAPI_PTR *PFN_vkGetSemaphoreWin32HandleKHR)(VkDevice device, const VkSemaphoreGetWin32HandleInfoKHR* pGetWin32HandleInfo, HANDLE* pHandle); -#define VK_AMD_draw_indirect_count 1 -#define VK_AMD_DRAW_INDIRECT_COUNT_SPEC_VERSION 1 -#define VK_AMD_DRAW_INDIRECT_COUNT_EXTENSION_NAME "VK_AMD_draw_indirect_count" +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkImportSemaphoreWin32HandleKHR( + VkDevice device, + const VkImportSemaphoreWin32HandleInfoKHR* pImportSemaphoreWin32HandleInfo); -typedef void (VKAPI_PTR *PFN_vkCmdDrawIndirectCountAMD)(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset, VkBuffer countBuffer, VkDeviceSize countBufferOffset, uint32_t maxDrawCount, uint32_t stride); -typedef void (VKAPI_PTR *PFN_vkCmdDrawIndexedIndirectCountAMD)(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset, VkBuffer countBuffer, VkDeviceSize countBufferOffset, uint32_t maxDrawCount, uint32_t stride); +VKAPI_ATTR VkResult VKAPI_CALL vkGetSemaphoreWin32HandleKHR( + VkDevice device, + const VkSemaphoreGetWin32HandleInfoKHR* pGetWin32HandleInfo, + HANDLE* pHandle); +#endif +#endif /* VK_USE_PLATFORM_WIN32_KHR */ + +#define VK_KHR_external_semaphore_fd 1 +#define VK_KHR_EXTERNAL_SEMAPHORE_FD_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_SEMAPHORE_FD_EXTENSION_NAME "VK_KHR_external_semaphore_fd" + +typedef struct VkImportSemaphoreFdInfoKHR { + VkStructureType sType; + const void* pNext; + VkSemaphore semaphore; + VkSemaphoreImportFlagsKHR flags; + VkExternalSemaphoreHandleTypeFlagBitsKHR handleType; + int fd; +} VkImportSemaphoreFdInfoKHR; + +typedef struct VkSemaphoreGetFdInfoKHR { + VkStructureType sType; + const void* pNext; + VkSemaphore semaphore; + VkExternalSemaphoreHandleTypeFlagBitsKHR handleType; +} VkSemaphoreGetFdInfoKHR; + + +typedef VkResult (VKAPI_PTR *PFN_vkImportSemaphoreFdKHR)(VkDevice device, const VkImportSemaphoreFdInfoKHR* pImportSemaphoreFdInfo); +typedef VkResult (VKAPI_PTR *PFN_vkGetSemaphoreFdKHR)(VkDevice device, const VkSemaphoreGetFdInfoKHR* pGetFdInfo, int* pFd); #ifndef VK_NO_PROTOTYPES -VKAPI_ATTR void VKAPI_CALL vkCmdDrawIndirectCountAMD( - VkCommandBuffer commandBuffer, - VkBuffer buffer, - VkDeviceSize offset, - VkBuffer countBuffer, - VkDeviceSize countBufferOffset, - uint32_t maxDrawCount, - uint32_t stride); +VKAPI_ATTR VkResult VKAPI_CALL vkImportSemaphoreFdKHR( + VkDevice device, + const VkImportSemaphoreFdInfoKHR* pImportSemaphoreFdInfo); -VKAPI_ATTR void VKAPI_CALL vkCmdDrawIndexedIndirectCountAMD( +VKAPI_ATTR VkResult VKAPI_CALL vkGetSemaphoreFdKHR( + VkDevice device, + const VkSemaphoreGetFdInfoKHR* pGetFdInfo, + int* pFd); +#endif + +#define VK_KHR_push_descriptor 1 +#define VK_KHR_PUSH_DESCRIPTOR_SPEC_VERSION 1 +#define VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME "VK_KHR_push_descriptor" + +typedef struct VkPhysicalDevicePushDescriptorPropertiesKHR { + VkStructureType sType; + void* pNext; + uint32_t maxPushDescriptors; +} VkPhysicalDevicePushDescriptorPropertiesKHR; + + +typedef void (VKAPI_PTR *PFN_vkCmdPushDescriptorSetKHR)(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipelineBindPoint, VkPipelineLayout layout, uint32_t set, uint32_t descriptorWriteCount, const VkWriteDescriptorSet* pDescriptorWrites); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkCmdPushDescriptorSetKHR( VkCommandBuffer commandBuffer, - VkBuffer buffer, - VkDeviceSize offset, - VkBuffer countBuffer, - VkDeviceSize countBufferOffset, - uint32_t maxDrawCount, - uint32_t stride); + VkPipelineBindPoint pipelineBindPoint, + VkPipelineLayout layout, + uint32_t set, + uint32_t descriptorWriteCount, + const VkWriteDescriptorSet* pDescriptorWrites); #endif -#define VK_AMD_negative_viewport_height 1 -#define VK_AMD_NEGATIVE_VIEWPORT_HEIGHT_SPEC_VERSION 0 -#define VK_AMD_NEGATIVE_VIEWPORT_HEIGHT_EXTENSION_NAME "VK_AMD_negative_viewport_height" +#define VK_KHR_16bit_storage 1 +#define VK_KHR_16BIT_STORAGE_SPEC_VERSION 1 +#define VK_KHR_16BIT_STORAGE_EXTENSION_NAME "VK_KHR_16bit_storage" +typedef struct VkPhysicalDevice16BitStorageFeaturesKHR { + VkStructureType sType; + void* pNext; + VkBool32 storageBuffer16BitAccess; + VkBool32 uniformAndStorageBuffer16BitAccess; + VkBool32 storagePushConstant16; + VkBool32 storageInputOutput16; +} VkPhysicalDevice16BitStorageFeaturesKHR; -#define VK_AMD_gpu_shader_half_float 1 -#define VK_AMD_GPU_SHADER_HALF_FLOAT_SPEC_VERSION 1 -#define VK_AMD_GPU_SHADER_HALF_FLOAT_EXTENSION_NAME "VK_AMD_gpu_shader_half_float" -#define VK_AMD_shader_ballot 1 -#define VK_AMD_SHADER_BALLOT_SPEC_VERSION 0 -#define VK_AMD_SHADER_BALLOT_EXTENSION_NAME "VK_AMD_shader_ballot" +#define VK_KHR_incremental_present 1 +#define VK_KHR_INCREMENTAL_PRESENT_SPEC_VERSION 1 +#define VK_KHR_INCREMENTAL_PRESENT_EXTENSION_NAME "VK_KHR_incremental_present" +typedef struct VkRectLayerKHR { + VkOffset2D offset; + VkExtent2D extent; + uint32_t layer; +} VkRectLayerKHR; -#define VK_IMG_format_pvrtc 1 -#define VK_IMG_FORMAT_PVRTC_SPEC_VERSION 1 -#define VK_IMG_FORMAT_PVRTC_EXTENSION_NAME "VK_IMG_format_pvrtc" +typedef struct VkPresentRegionKHR { + uint32_t rectangleCount; + const VkRectLayerKHR* pRectangles; +} VkPresentRegionKHR; +typedef struct VkPresentRegionsKHR { + VkStructureType sType; + const void* pNext; + uint32_t swapchainCount; + const VkPresentRegionKHR* pRegions; +} VkPresentRegionsKHR; -#define VK_NV_external_memory_capabilities 1 -#define VK_NV_EXTERNAL_MEMORY_CAPABILITIES_SPEC_VERSION 1 -#define VK_NV_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME "VK_NV_external_memory_capabilities" -typedef enum VkExternalMemoryHandleTypeFlagBitsNV { - VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_BIT_NV = 0x00000001, - VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT_BIT_NV = 0x00000002, - VK_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_IMAGE_BIT_NV = 0x00000004, - VK_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_IMAGE_KMT_BIT_NV = 0x00000008, - VK_EXTERNAL_MEMORY_HANDLE_TYPE_FLAG_BITS_MAX_ENUM_NV = 0x7FFFFFFF -} VkExternalMemoryHandleTypeFlagBitsNV; -typedef VkFlags VkExternalMemoryHandleTypeFlagsNV; +#define VK_KHR_descriptor_update_template 1 +VK_DEFINE_NON_DISPATCHABLE_HANDLE(VkDescriptorUpdateTemplateKHR) -typedef enum VkExternalMemoryFeatureFlagBitsNV { - VK_EXTERNAL_MEMORY_FEATURE_DEDICATED_ONLY_BIT_NV = 0x00000001, - VK_EXTERNAL_MEMORY_FEATURE_EXPORTABLE_BIT_NV = 0x00000002, - VK_EXTERNAL_MEMORY_FEATURE_IMPORTABLE_BIT_NV = 0x00000004, - VK_EXTERNAL_MEMORY_FEATURE_FLAG_BITS_MAX_ENUM_NV = 0x7FFFFFFF -} VkExternalMemoryFeatureFlagBitsNV; -typedef VkFlags VkExternalMemoryFeatureFlagsNV; +#define VK_KHR_DESCRIPTOR_UPDATE_TEMPLATE_SPEC_VERSION 1 +#define VK_KHR_DESCRIPTOR_UPDATE_TEMPLATE_EXTENSION_NAME "VK_KHR_descriptor_update_template" -typedef struct VkExternalImageFormatPropertiesNV { - VkImageFormatProperties imageFormatProperties; - VkExternalMemoryFeatureFlagsNV externalMemoryFeatures; - VkExternalMemoryHandleTypeFlagsNV exportFromImportedHandleTypes; - VkExternalMemoryHandleTypeFlagsNV compatibleHandleTypes; -} VkExternalImageFormatPropertiesNV; +typedef enum VkDescriptorUpdateTemplateTypeKHR { + VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_DESCRIPTOR_SET_KHR = 0, + VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR = 1, + VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_BEGIN_RANGE_KHR = VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_DESCRIPTOR_SET_KHR, + VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_END_RANGE_KHR = VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR, + VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_RANGE_SIZE_KHR = (VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR - VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_DESCRIPTOR_SET_KHR + 1), + VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_MAX_ENUM_KHR = 0x7FFFFFFF +} VkDescriptorUpdateTemplateTypeKHR; -typedef VkResult (VKAPI_PTR *PFN_vkGetPhysicalDeviceExternalImageFormatPropertiesNV)(VkPhysicalDevice physicalDevice, VkFormat format, VkImageType type, VkImageTiling tiling, VkImageUsageFlags usage, VkImageCreateFlags flags, VkExternalMemoryHandleTypeFlagsNV externalHandleType, VkExternalImageFormatPropertiesNV* pExternalImageFormatProperties); +typedef VkFlags VkDescriptorUpdateTemplateCreateFlagsKHR; + +typedef struct VkDescriptorUpdateTemplateEntryKHR { + uint32_t dstBinding; + uint32_t dstArrayElement; + uint32_t descriptorCount; + VkDescriptorType descriptorType; + size_t offset; + size_t stride; +} VkDescriptorUpdateTemplateEntryKHR; + +typedef struct VkDescriptorUpdateTemplateCreateInfoKHR { + VkStructureType sType; + void* pNext; + VkDescriptorUpdateTemplateCreateFlagsKHR flags; + uint32_t descriptorUpdateEntryCount; + const VkDescriptorUpdateTemplateEntryKHR* pDescriptorUpdateEntries; + VkDescriptorUpdateTemplateTypeKHR templateType; + VkDescriptorSetLayout descriptorSetLayout; + VkPipelineBindPoint pipelineBindPoint; + VkPipelineLayout pipelineLayout; + uint32_t set; +} VkDescriptorUpdateTemplateCreateInfoKHR; + + +typedef VkResult (VKAPI_PTR *PFN_vkCreateDescriptorUpdateTemplateKHR)(VkDevice device, const VkDescriptorUpdateTemplateCreateInfoKHR* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkDescriptorUpdateTemplateKHR* pDescriptorUpdateTemplate); +typedef void (VKAPI_PTR *PFN_vkDestroyDescriptorUpdateTemplateKHR)(VkDevice device, VkDescriptorUpdateTemplateKHR descriptorUpdateTemplate, const VkAllocationCallbacks* pAllocator); +typedef void (VKAPI_PTR *PFN_vkUpdateDescriptorSetWithTemplateKHR)(VkDevice device, VkDescriptorSet descriptorSet, VkDescriptorUpdateTemplateKHR descriptorUpdateTemplate, const void* pData); +typedef void (VKAPI_PTR *PFN_vkCmdPushDescriptorSetWithTemplateKHR)(VkCommandBuffer commandBuffer, VkDescriptorUpdateTemplateKHR descriptorUpdateTemplate, VkPipelineLayout layout, uint32_t set, const void* pData); #ifndef VK_NO_PROTOTYPES -VKAPI_ATTR VkResult VKAPI_CALL vkGetPhysicalDeviceExternalImageFormatPropertiesNV( - VkPhysicalDevice physicalDevice, - VkFormat format, - VkImageType type, - VkImageTiling tiling, - VkImageUsageFlags usage, - VkImageCreateFlags flags, - VkExternalMemoryHandleTypeFlagsNV externalHandleType, - VkExternalImageFormatPropertiesNV* pExternalImageFormatProperties); +VKAPI_ATTR VkResult VKAPI_CALL vkCreateDescriptorUpdateTemplateKHR( + VkDevice device, + const VkDescriptorUpdateTemplateCreateInfoKHR* pCreateInfo, + const VkAllocationCallbacks* pAllocator, + VkDescriptorUpdateTemplateKHR* pDescriptorUpdateTemplate); + +VKAPI_ATTR void VKAPI_CALL vkDestroyDescriptorUpdateTemplateKHR( + VkDevice device, + VkDescriptorUpdateTemplateKHR descriptorUpdateTemplate, + const VkAllocationCallbacks* pAllocator); + +VKAPI_ATTR void VKAPI_CALL vkUpdateDescriptorSetWithTemplateKHR( + VkDevice device, + VkDescriptorSet descriptorSet, + VkDescriptorUpdateTemplateKHR descriptorUpdateTemplate, + const void* pData); + +VKAPI_ATTR void VKAPI_CALL vkCmdPushDescriptorSetWithTemplateKHR( + VkCommandBuffer commandBuffer, + VkDescriptorUpdateTemplateKHR descriptorUpdateTemplate, + VkPipelineLayout layout, + uint32_t set, + const void* pData); #endif -#define VK_NV_external_memory 1 -#define VK_NV_EXTERNAL_MEMORY_SPEC_VERSION 1 -#define VK_NV_EXTERNAL_MEMORY_EXTENSION_NAME "VK_NV_external_memory" +#define VK_KHR_shared_presentable_image 1 +#define VK_KHR_SHARED_PRESENTABLE_IMAGE_SPEC_VERSION 1 +#define VK_KHR_SHARED_PRESENTABLE_IMAGE_EXTENSION_NAME "VK_KHR_shared_presentable_image" -typedef struct VkExternalMemoryImageCreateInfoNV { - VkStructureType sType; - const void* pNext; - VkExternalMemoryHandleTypeFlagsNV handleTypes; -} VkExternalMemoryImageCreateInfoNV; +typedef struct VkSharedPresentSurfaceCapabilitiesKHR { + VkStructureType sType; + void* pNext; + VkImageUsageFlags sharedPresentSupportedUsageFlags; +} VkSharedPresentSurfaceCapabilitiesKHR; -typedef struct VkExportMemoryAllocateInfoNV { + +typedef VkResult (VKAPI_PTR *PFN_vkGetSwapchainStatusKHR)(VkDevice device, VkSwapchainKHR swapchain); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkGetSwapchainStatusKHR( + VkDevice device, + VkSwapchainKHR swapchain); +#endif + +#define VK_KHR_external_fence_capabilities 1 +#define VK_KHR_EXTERNAL_FENCE_CAPABILITIES_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_FENCE_CAPABILITIES_EXTENSION_NAME "VK_KHR_external_fence_capabilities" + + +typedef enum VkExternalFenceHandleTypeFlagBitsKHR { + VK_EXTERNAL_FENCE_HANDLE_TYPE_OPAQUE_FD_BIT_KHR = 0x00000001, + VK_EXTERNAL_FENCE_HANDLE_TYPE_OPAQUE_WIN32_BIT_KHR = 0x00000002, + VK_EXTERNAL_FENCE_HANDLE_TYPE_OPAQUE_WIN32_KMT_BIT_KHR = 0x00000004, + VK_EXTERNAL_FENCE_HANDLE_TYPE_SYNC_FD_BIT_KHR = 0x00000008, + VK_EXTERNAL_FENCE_HANDLE_TYPE_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkExternalFenceHandleTypeFlagBitsKHR; +typedef VkFlags VkExternalFenceHandleTypeFlagsKHR; + +typedef enum VkExternalFenceFeatureFlagBitsKHR { + VK_EXTERNAL_FENCE_FEATURE_EXPORTABLE_BIT_KHR = 0x00000001, + VK_EXTERNAL_FENCE_FEATURE_IMPORTABLE_BIT_KHR = 0x00000002, + VK_EXTERNAL_FENCE_FEATURE_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkExternalFenceFeatureFlagBitsKHR; +typedef VkFlags VkExternalFenceFeatureFlagsKHR; + +typedef struct VkPhysicalDeviceExternalFenceInfoKHR { + VkStructureType sType; + const void* pNext; + VkExternalFenceHandleTypeFlagBitsKHR handleType; +} VkPhysicalDeviceExternalFenceInfoKHR; + +typedef struct VkExternalFencePropertiesKHR { VkStructureType sType; - const void* pNext; - VkExternalMemoryHandleTypeFlagsNV handleTypes; -} VkExportMemoryAllocateInfoNV; + void* pNext; + VkExternalFenceHandleTypeFlagsKHR exportFromImportedHandleTypes; + VkExternalFenceHandleTypeFlagsKHR compatibleHandleTypes; + VkExternalFenceFeatureFlagsKHR externalFenceFeatures; +} VkExternalFencePropertiesKHR; +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceExternalFencePropertiesKHR)(VkPhysicalDevice physicalDevice, const VkPhysicalDeviceExternalFenceInfoKHR* pExternalFenceInfo, VkExternalFencePropertiesKHR* pExternalFenceProperties); -#ifdef VK_USE_PLATFORM_WIN32_KHR -#define VK_NV_external_memory_win32 1 -#define VK_NV_EXTERNAL_MEMORY_WIN32_SPEC_VERSION 1 -#define VK_NV_EXTERNAL_MEMORY_WIN32_EXTENSION_NAME "VK_NV_external_memory_win32" +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceExternalFencePropertiesKHR( + VkPhysicalDevice physicalDevice, + const VkPhysicalDeviceExternalFenceInfoKHR* pExternalFenceInfo, + VkExternalFencePropertiesKHR* pExternalFenceProperties); +#endif -typedef struct VkImportMemoryWin32HandleInfoNV { +#define VK_KHR_external_fence 1 +#define VK_KHR_EXTERNAL_FENCE_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_FENCE_EXTENSION_NAME "VK_KHR_external_fence" + + +typedef enum VkFenceImportFlagBitsKHR { + VK_FENCE_IMPORT_TEMPORARY_BIT_KHR = 0x00000001, + VK_FENCE_IMPORT_FLAG_BITS_MAX_ENUM_KHR = 0x7FFFFFFF +} VkFenceImportFlagBitsKHR; +typedef VkFlags VkFenceImportFlagsKHR; + +typedef struct VkExportFenceCreateInfoKHR { VkStructureType sType; const void* pNext; - VkExternalMemoryHandleTypeFlagsNV handleType; - HANDLE handle; -} VkImportMemoryWin32HandleInfoNV; + VkExternalFenceHandleTypeFlagsKHR handleTypes; +} VkExportFenceCreateInfoKHR; -typedef struct VkExportMemoryWin32HandleInfoNV { + + +#ifdef VK_USE_PLATFORM_WIN32_KHR +#define VK_KHR_external_fence_win32 1 +#define VK_KHR_EXTERNAL_FENCE_WIN32_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_FENCE_WIN32_EXTENSION_NAME "VK_KHR_external_fence_win32" + +typedef struct VkImportFenceWin32HandleInfoKHR { + VkStructureType sType; + const void* pNext; + VkFence fence; + VkFenceImportFlagsKHR flags; + VkExternalFenceHandleTypeFlagBitsKHR handleType; + HANDLE handle; + LPCWSTR name; +} VkImportFenceWin32HandleInfoKHR; + +typedef struct VkExportFenceWin32HandleInfoKHR { VkStructureType sType; const void* pNext; const SECURITY_ATTRIBUTES* pAttributes; DWORD dwAccess; -} VkExportMemoryWin32HandleInfoNV; + LPCWSTR name; +} VkExportFenceWin32HandleInfoKHR; +typedef struct VkFenceGetWin32HandleInfoKHR { + VkStructureType sType; + const void* pNext; + VkFence fence; + VkExternalFenceHandleTypeFlagBitsKHR handleType; +} VkFenceGetWin32HandleInfoKHR; -typedef VkResult (VKAPI_PTR *PFN_vkGetMemoryWin32HandleNV)(VkDevice device, VkDeviceMemory memory, VkExternalMemoryHandleTypeFlagsNV handleType, HANDLE* pHandle); + +typedef VkResult (VKAPI_PTR *PFN_vkImportFenceWin32HandleKHR)(VkDevice device, const VkImportFenceWin32HandleInfoKHR* pImportFenceWin32HandleInfo); +typedef VkResult (VKAPI_PTR *PFN_vkGetFenceWin32HandleKHR)(VkDevice device, const VkFenceGetWin32HandleInfoKHR* pGetWin32HandleInfo, HANDLE* pHandle); #ifndef VK_NO_PROTOTYPES -VKAPI_ATTR VkResult VKAPI_CALL vkGetMemoryWin32HandleNV( +VKAPI_ATTR VkResult VKAPI_CALL vkImportFenceWin32HandleKHR( VkDevice device, - VkDeviceMemory memory, - VkExternalMemoryHandleTypeFlagsNV handleType, + const VkImportFenceWin32HandleInfoKHR* pImportFenceWin32HandleInfo); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetFenceWin32HandleKHR( + VkDevice device, + const VkFenceGetWin32HandleInfoKHR* pGetWin32HandleInfo, HANDLE* pHandle); #endif #endif /* VK_USE_PLATFORM_WIN32_KHR */ -#ifdef VK_USE_PLATFORM_WIN32_KHR -#define VK_NV_win32_keyed_mutex 1 -#define VK_NV_WIN32_KEYED_MUTEX_SPEC_VERSION 1 -#define VK_NV_WIN32_KEYED_MUTEX_EXTENSION_NAME "VK_NV_win32_keyed_mutex" +#define VK_KHR_external_fence_fd 1 +#define VK_KHR_EXTERNAL_FENCE_FD_SPEC_VERSION 1 +#define VK_KHR_EXTERNAL_FENCE_FD_EXTENSION_NAME "VK_KHR_external_fence_fd" -typedef struct VkWin32KeyedMutexAcquireReleaseInfoNV { - VkStructureType sType; - const void* pNext; - uint32_t acquireCount; - const VkDeviceMemory* pAcquireSyncs; - const uint64_t* pAcquireKeys; - const uint32_t* pAcquireTimeoutMilliseconds; - uint32_t releaseCount; - const VkDeviceMemory* pReleaseSyncs; - const uint64_t* pReleaseKeys; -} VkWin32KeyedMutexAcquireReleaseInfoNV; +typedef struct VkImportFenceFdInfoKHR { + VkStructureType sType; + const void* pNext; + VkFence fence; + VkFenceImportFlagsKHR flags; + VkExternalFenceHandleTypeFlagBitsKHR handleType; + int fd; +} VkImportFenceFdInfoKHR; +typedef struct VkFenceGetFdInfoKHR { + VkStructureType sType; + const void* pNext; + VkFence fence; + VkExternalFenceHandleTypeFlagBitsKHR handleType; +} VkFenceGetFdInfoKHR; -#endif /* VK_USE_PLATFORM_WIN32_KHR */ -#define VK_EXT_validation_flags 1 -#define VK_EXT_VALIDATION_FLAGS_SPEC_VERSION 1 -#define VK_EXT_VALIDATION_FLAGS_EXTENSION_NAME "VK_EXT_validation_flags" +typedef VkResult (VKAPI_PTR *PFN_vkImportFenceFdKHR)(VkDevice device, const VkImportFenceFdInfoKHR* pImportFenceFdInfo); +typedef VkResult (VKAPI_PTR *PFN_vkGetFenceFdKHR)(VkDevice device, const VkFenceGetFdInfoKHR* pGetFdInfo, int* pFd); +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkImportFenceFdKHR( + VkDevice device, + const VkImportFenceFdInfoKHR* pImportFenceFdInfo); -typedef enum VkValidationCheckEXT { - VK_VALIDATION_CHECK_ALL_EXT = 0, - VK_VALIDATION_CHECK_BEGIN_RANGE_EXT = VK_VALIDATION_CHECK_ALL_EXT, - VK_VALIDATION_CHECK_END_RANGE_EXT = VK_VALIDATION_CHECK_ALL_EXT, - VK_VALIDATION_CHECK_RANGE_SIZE_EXT = (VK_VALIDATION_CHECK_ALL_EXT - VK_VALIDATION_CHECK_ALL_EXT + 1), - VK_VALIDATION_CHECK_MAX_ENUM_EXT = 0x7FFFFFFF -} VkValidationCheckEXT; +VKAPI_ATTR VkResult VKAPI_CALL vkGetFenceFdKHR( + VkDevice device, + const VkFenceGetFdInfoKHR* pGetFdInfo, + int* pFd); +#endif + +#define VK_KHR_get_surface_capabilities2 1 +#define VK_KHR_GET_SURFACE_CAPABILITIES_2_SPEC_VERSION 1 +#define VK_KHR_GET_SURFACE_CAPABILITIES_2_EXTENSION_NAME "VK_KHR_get_surface_capabilities2" + +typedef struct VkPhysicalDeviceSurfaceInfo2KHR { + VkStructureType sType; + const void* pNext; + VkSurfaceKHR surface; +} VkPhysicalDeviceSurfaceInfo2KHR; + +typedef struct VkSurfaceCapabilities2KHR { + VkStructureType sType; + void* pNext; + VkSurfaceCapabilitiesKHR surfaceCapabilities; +} VkSurfaceCapabilities2KHR; + +typedef struct VkSurfaceFormat2KHR { + VkStructureType sType; + void* pNext; + VkSurfaceFormatKHR surfaceFormat; +} VkSurfaceFormat2KHR; + + +typedef VkResult (VKAPI_PTR *PFN_vkGetPhysicalDeviceSurfaceCapabilities2KHR)(VkPhysicalDevice physicalDevice, const VkPhysicalDeviceSurfaceInfo2KHR* pSurfaceInfo, VkSurfaceCapabilities2KHR* pSurfaceCapabilities); +typedef VkResult (VKAPI_PTR *PFN_vkGetPhysicalDeviceSurfaceFormats2KHR)(VkPhysicalDevice physicalDevice, const VkPhysicalDeviceSurfaceInfo2KHR* pSurfaceInfo, uint32_t* pSurfaceFormatCount, VkSurfaceFormat2KHR* pSurfaceFormats); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkGetPhysicalDeviceSurfaceCapabilities2KHR( + VkPhysicalDevice physicalDevice, + const VkPhysicalDeviceSurfaceInfo2KHR* pSurfaceInfo, + VkSurfaceCapabilities2KHR* pSurfaceCapabilities); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetPhysicalDeviceSurfaceFormats2KHR( + VkPhysicalDevice physicalDevice, + const VkPhysicalDeviceSurfaceInfo2KHR* pSurfaceInfo, + uint32_t* pSurfaceFormatCount, + VkSurfaceFormat2KHR* pSurfaceFormats); +#endif + +#define VK_KHR_variable_pointers 1 +#define VK_KHR_VARIABLE_POINTERS_SPEC_VERSION 1 +#define VK_KHR_VARIABLE_POINTERS_EXTENSION_NAME "VK_KHR_variable_pointers" + +typedef struct VkPhysicalDeviceVariablePointerFeaturesKHR { + VkStructureType sType; + void* pNext; + VkBool32 variablePointersStorageBuffer; + VkBool32 variablePointers; +} VkPhysicalDeviceVariablePointerFeaturesKHR; + + + +#define VK_KHR_dedicated_allocation 1 +#define VK_KHR_DEDICATED_ALLOCATION_SPEC_VERSION 3 +#define VK_KHR_DEDICATED_ALLOCATION_EXTENSION_NAME "VK_KHR_dedicated_allocation" + +typedef struct VkMemoryDedicatedRequirementsKHR { + VkStructureType sType; + void* pNext; + VkBool32 prefersDedicatedAllocation; + VkBool32 requiresDedicatedAllocation; +} VkMemoryDedicatedRequirementsKHR; + +typedef struct VkMemoryDedicatedAllocateInfoKHR { + VkStructureType sType; + const void* pNext; + VkImage image; + VkBuffer buffer; +} VkMemoryDedicatedAllocateInfoKHR; + + + +#define VK_KHR_storage_buffer_storage_class 1 +#define VK_KHR_STORAGE_BUFFER_STORAGE_CLASS_SPEC_VERSION 1 +#define VK_KHR_STORAGE_BUFFER_STORAGE_CLASS_EXTENSION_NAME "VK_KHR_storage_buffer_storage_class" + + +#define VK_KHR_relaxed_block_layout 1 +#define VK_KHR_RELAXED_BLOCK_LAYOUT_SPEC_VERSION 1 +#define VK_KHR_RELAXED_BLOCK_LAYOUT_EXTENSION_NAME "VK_KHR_relaxed_block_layout" + + +#define VK_KHR_get_memory_requirements2 1 +#define VK_KHR_GET_MEMORY_REQUIREMENTS_2_SPEC_VERSION 1 +#define VK_KHR_GET_MEMORY_REQUIREMENTS_2_EXTENSION_NAME "VK_KHR_get_memory_requirements2" + +typedef struct VkBufferMemoryRequirementsInfo2KHR { + VkStructureType sType; + const void* pNext; + VkBuffer buffer; +} VkBufferMemoryRequirementsInfo2KHR; + +typedef struct VkImageMemoryRequirementsInfo2KHR { + VkStructureType sType; + const void* pNext; + VkImage image; +} VkImageMemoryRequirementsInfo2KHR; + +typedef struct VkImageSparseMemoryRequirementsInfo2KHR { + VkStructureType sType; + const void* pNext; + VkImage image; +} VkImageSparseMemoryRequirementsInfo2KHR; + +typedef struct VkMemoryRequirements2KHR { + VkStructureType sType; + void* pNext; + VkMemoryRequirements memoryRequirements; +} VkMemoryRequirements2KHR; + +typedef struct VkSparseImageMemoryRequirements2KHR { + VkStructureType sType; + void* pNext; + VkSparseImageMemoryRequirements memoryRequirements; +} VkSparseImageMemoryRequirements2KHR; + + +typedef void (VKAPI_PTR *PFN_vkGetImageMemoryRequirements2KHR)(VkDevice device, const VkImageMemoryRequirementsInfo2KHR* pInfo, VkMemoryRequirements2KHR* pMemoryRequirements); +typedef void (VKAPI_PTR *PFN_vkGetBufferMemoryRequirements2KHR)(VkDevice device, const VkBufferMemoryRequirementsInfo2KHR* pInfo, VkMemoryRequirements2KHR* pMemoryRequirements); +typedef void (VKAPI_PTR *PFN_vkGetImageSparseMemoryRequirements2KHR)(VkDevice device, const VkImageSparseMemoryRequirementsInfo2KHR* pInfo, uint32_t* pSparseMemoryRequirementCount, VkSparseImageMemoryRequirements2KHR* pSparseMemoryRequirements); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkGetImageMemoryRequirements2KHR( + VkDevice device, + const VkImageMemoryRequirementsInfo2KHR* pInfo, + VkMemoryRequirements2KHR* pMemoryRequirements); + +VKAPI_ATTR void VKAPI_CALL vkGetBufferMemoryRequirements2KHR( + VkDevice device, + const VkBufferMemoryRequirementsInfo2KHR* pInfo, + VkMemoryRequirements2KHR* pMemoryRequirements); + +VKAPI_ATTR void VKAPI_CALL vkGetImageSparseMemoryRequirements2KHR( + VkDevice device, + const VkImageSparseMemoryRequirementsInfo2KHR* pInfo, + uint32_t* pSparseMemoryRequirementCount, + VkSparseImageMemoryRequirements2KHR* pSparseMemoryRequirements); +#endif + +#define VK_EXT_debug_report 1 +VK_DEFINE_NON_DISPATCHABLE_HANDLE(VkDebugReportCallbackEXT) + +#define VK_EXT_DEBUG_REPORT_SPEC_VERSION 8 +#define VK_EXT_DEBUG_REPORT_EXTENSION_NAME "VK_EXT_debug_report" +#define VK_STRUCTURE_TYPE_DEBUG_REPORT_CREATE_INFO_EXT VK_STRUCTURE_TYPE_DEBUG_REPORT_CALLBACK_CREATE_INFO_EXT +#define VK_DEBUG_REPORT_OBJECT_TYPE_DEBUG_REPORT_EXT VK_DEBUG_REPORT_OBJECT_TYPE_DEBUG_REPORT_CALLBACK_EXT_EXT + + +typedef enum VkDebugReportObjectTypeEXT { + VK_DEBUG_REPORT_OBJECT_TYPE_UNKNOWN_EXT = 0, + VK_DEBUG_REPORT_OBJECT_TYPE_INSTANCE_EXT = 1, + VK_DEBUG_REPORT_OBJECT_TYPE_PHYSICAL_DEVICE_EXT = 2, + VK_DEBUG_REPORT_OBJECT_TYPE_DEVICE_EXT = 3, + VK_DEBUG_REPORT_OBJECT_TYPE_QUEUE_EXT = 4, + VK_DEBUG_REPORT_OBJECT_TYPE_SEMAPHORE_EXT = 5, + VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_BUFFER_EXT = 6, + VK_DEBUG_REPORT_OBJECT_TYPE_FENCE_EXT = 7, + VK_DEBUG_REPORT_OBJECT_TYPE_DEVICE_MEMORY_EXT = 8, + VK_DEBUG_REPORT_OBJECT_TYPE_BUFFER_EXT = 9, + VK_DEBUG_REPORT_OBJECT_TYPE_IMAGE_EXT = 10, + VK_DEBUG_REPORT_OBJECT_TYPE_EVENT_EXT = 11, + VK_DEBUG_REPORT_OBJECT_TYPE_QUERY_POOL_EXT = 12, + VK_DEBUG_REPORT_OBJECT_TYPE_BUFFER_VIEW_EXT = 13, + VK_DEBUG_REPORT_OBJECT_TYPE_IMAGE_VIEW_EXT = 14, + VK_DEBUG_REPORT_OBJECT_TYPE_SHADER_MODULE_EXT = 15, + VK_DEBUG_REPORT_OBJECT_TYPE_PIPELINE_CACHE_EXT = 16, + VK_DEBUG_REPORT_OBJECT_TYPE_PIPELINE_LAYOUT_EXT = 17, + VK_DEBUG_REPORT_OBJECT_TYPE_RENDER_PASS_EXT = 18, + VK_DEBUG_REPORT_OBJECT_TYPE_PIPELINE_EXT = 19, + VK_DEBUG_REPORT_OBJECT_TYPE_DESCRIPTOR_SET_LAYOUT_EXT = 20, + VK_DEBUG_REPORT_OBJECT_TYPE_SAMPLER_EXT = 21, + VK_DEBUG_REPORT_OBJECT_TYPE_DESCRIPTOR_POOL_EXT = 22, + VK_DEBUG_REPORT_OBJECT_TYPE_DESCRIPTOR_SET_EXT = 23, + VK_DEBUG_REPORT_OBJECT_TYPE_FRAMEBUFFER_EXT = 24, + VK_DEBUG_REPORT_OBJECT_TYPE_COMMAND_POOL_EXT = 25, + VK_DEBUG_REPORT_OBJECT_TYPE_SURFACE_KHR_EXT = 26, + VK_DEBUG_REPORT_OBJECT_TYPE_SWAPCHAIN_KHR_EXT = 27, + VK_DEBUG_REPORT_OBJECT_TYPE_DEBUG_REPORT_CALLBACK_EXT_EXT = 28, + VK_DEBUG_REPORT_OBJECT_TYPE_DISPLAY_KHR_EXT = 29, + VK_DEBUG_REPORT_OBJECT_TYPE_DISPLAY_MODE_KHR_EXT = 30, + VK_DEBUG_REPORT_OBJECT_TYPE_OBJECT_TABLE_NVX_EXT = 31, + VK_DEBUG_REPORT_OBJECT_TYPE_INDIRECT_COMMANDS_LAYOUT_NVX_EXT = 32, + VK_DEBUG_REPORT_OBJECT_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_KHR_EXT = 1000085000, + VK_DEBUG_REPORT_OBJECT_TYPE_BEGIN_RANGE_EXT = VK_DEBUG_REPORT_OBJECT_TYPE_UNKNOWN_EXT, + VK_DEBUG_REPORT_OBJECT_TYPE_END_RANGE_EXT = VK_DEBUG_REPORT_OBJECT_TYPE_INDIRECT_COMMANDS_LAYOUT_NVX_EXT, + VK_DEBUG_REPORT_OBJECT_TYPE_RANGE_SIZE_EXT = (VK_DEBUG_REPORT_OBJECT_TYPE_INDIRECT_COMMANDS_LAYOUT_NVX_EXT - VK_DEBUG_REPORT_OBJECT_TYPE_UNKNOWN_EXT + 1), + VK_DEBUG_REPORT_OBJECT_TYPE_MAX_ENUM_EXT = 0x7FFFFFFF +} VkDebugReportObjectTypeEXT; + + +typedef enum VkDebugReportFlagBitsEXT { + VK_DEBUG_REPORT_INFORMATION_BIT_EXT = 0x00000001, + VK_DEBUG_REPORT_WARNING_BIT_EXT = 0x00000002, + VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT = 0x00000004, + VK_DEBUG_REPORT_ERROR_BIT_EXT = 0x00000008, + VK_DEBUG_REPORT_DEBUG_BIT_EXT = 0x00000010, + VK_DEBUG_REPORT_FLAG_BITS_MAX_ENUM_EXT = 0x7FFFFFFF +} VkDebugReportFlagBitsEXT; +typedef VkFlags VkDebugReportFlagsEXT; + +typedef VkBool32 (VKAPI_PTR *PFN_vkDebugReportCallbackEXT)( + VkDebugReportFlagsEXT flags, + VkDebugReportObjectTypeEXT objectType, + uint64_t object, + size_t location, + int32_t messageCode, + const char* pLayerPrefix, + const char* pMessage, + void* pUserData); + +typedef struct VkDebugReportCallbackCreateInfoEXT { + VkStructureType sType; + const void* pNext; + VkDebugReportFlagsEXT flags; + PFN_vkDebugReportCallbackEXT pfnCallback; + void* pUserData; +} VkDebugReportCallbackCreateInfoEXT; + + +typedef VkResult (VKAPI_PTR *PFN_vkCreateDebugReportCallbackEXT)(VkInstance instance, const VkDebugReportCallbackCreateInfoEXT* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkDebugReportCallbackEXT* pCallback); +typedef void (VKAPI_PTR *PFN_vkDestroyDebugReportCallbackEXT)(VkInstance instance, VkDebugReportCallbackEXT callback, const VkAllocationCallbacks* pAllocator); +typedef void (VKAPI_PTR *PFN_vkDebugReportMessageEXT)(VkInstance instance, VkDebugReportFlagsEXT flags, VkDebugReportObjectTypeEXT objectType, uint64_t object, size_t location, int32_t messageCode, const char* pLayerPrefix, const char* pMessage); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkCreateDebugReportCallbackEXT( + VkInstance instance, + const VkDebugReportCallbackCreateInfoEXT* pCreateInfo, + const VkAllocationCallbacks* pAllocator, + VkDebugReportCallbackEXT* pCallback); + +VKAPI_ATTR void VKAPI_CALL vkDestroyDebugReportCallbackEXT( + VkInstance instance, + VkDebugReportCallbackEXT callback, + const VkAllocationCallbacks* pAllocator); + +VKAPI_ATTR void VKAPI_CALL vkDebugReportMessageEXT( + VkInstance instance, + VkDebugReportFlagsEXT flags, + VkDebugReportObjectTypeEXT objectType, + uint64_t object, + size_t location, + int32_t messageCode, + const char* pLayerPrefix, + const char* pMessage); +#endif + +#define VK_NV_glsl_shader 1 +#define VK_NV_GLSL_SHADER_SPEC_VERSION 1 +#define VK_NV_GLSL_SHADER_EXTENSION_NAME "VK_NV_glsl_shader" + + +#define VK_EXT_depth_range_unrestricted 1 +#define VK_EXT_DEPTH_RANGE_UNRESTRICTED_SPEC_VERSION 1 +#define VK_EXT_DEPTH_RANGE_UNRESTRICTED_EXTENSION_NAME "VK_EXT_depth_range_unrestricted" + + +#define VK_IMG_filter_cubic 1 +#define VK_IMG_FILTER_CUBIC_SPEC_VERSION 1 +#define VK_IMG_FILTER_CUBIC_EXTENSION_NAME "VK_IMG_filter_cubic" + + +#define VK_AMD_rasterization_order 1 +#define VK_AMD_RASTERIZATION_ORDER_SPEC_VERSION 1 +#define VK_AMD_RASTERIZATION_ORDER_EXTENSION_NAME "VK_AMD_rasterization_order" + + +typedef enum VkRasterizationOrderAMD { + VK_RASTERIZATION_ORDER_STRICT_AMD = 0, + VK_RASTERIZATION_ORDER_RELAXED_AMD = 1, + VK_RASTERIZATION_ORDER_BEGIN_RANGE_AMD = VK_RASTERIZATION_ORDER_STRICT_AMD, + VK_RASTERIZATION_ORDER_END_RANGE_AMD = VK_RASTERIZATION_ORDER_RELAXED_AMD, + VK_RASTERIZATION_ORDER_RANGE_SIZE_AMD = (VK_RASTERIZATION_ORDER_RELAXED_AMD - VK_RASTERIZATION_ORDER_STRICT_AMD + 1), + VK_RASTERIZATION_ORDER_MAX_ENUM_AMD = 0x7FFFFFFF +} VkRasterizationOrderAMD; + +typedef struct VkPipelineRasterizationStateRasterizationOrderAMD { + VkStructureType sType; + const void* pNext; + VkRasterizationOrderAMD rasterizationOrder; +} VkPipelineRasterizationStateRasterizationOrderAMD; + + + +#define VK_AMD_shader_trinary_minmax 1 +#define VK_AMD_SHADER_TRINARY_MINMAX_SPEC_VERSION 1 +#define VK_AMD_SHADER_TRINARY_MINMAX_EXTENSION_NAME "VK_AMD_shader_trinary_minmax" + + +#define VK_AMD_shader_explicit_vertex_parameter 1 +#define VK_AMD_SHADER_EXPLICIT_VERTEX_PARAMETER_SPEC_VERSION 1 +#define VK_AMD_SHADER_EXPLICIT_VERTEX_PARAMETER_EXTENSION_NAME "VK_AMD_shader_explicit_vertex_parameter" + + +#define VK_EXT_debug_marker 1 +#define VK_EXT_DEBUG_MARKER_SPEC_VERSION 4 +#define VK_EXT_DEBUG_MARKER_EXTENSION_NAME "VK_EXT_debug_marker" + +typedef struct VkDebugMarkerObjectNameInfoEXT { + VkStructureType sType; + const void* pNext; + VkDebugReportObjectTypeEXT objectType; + uint64_t object; + const char* pObjectName; +} VkDebugMarkerObjectNameInfoEXT; + +typedef struct VkDebugMarkerObjectTagInfoEXT { + VkStructureType sType; + const void* pNext; + VkDebugReportObjectTypeEXT objectType; + uint64_t object; + uint64_t tagName; + size_t tagSize; + const void* pTag; +} VkDebugMarkerObjectTagInfoEXT; + +typedef struct VkDebugMarkerMarkerInfoEXT { + VkStructureType sType; + const void* pNext; + const char* pMarkerName; + float color[4]; +} VkDebugMarkerMarkerInfoEXT; + + +typedef VkResult (VKAPI_PTR *PFN_vkDebugMarkerSetObjectTagEXT)(VkDevice device, const VkDebugMarkerObjectTagInfoEXT* pTagInfo); +typedef VkResult (VKAPI_PTR *PFN_vkDebugMarkerSetObjectNameEXT)(VkDevice device, const VkDebugMarkerObjectNameInfoEXT* pNameInfo); +typedef void (VKAPI_PTR *PFN_vkCmdDebugMarkerBeginEXT)(VkCommandBuffer commandBuffer, const VkDebugMarkerMarkerInfoEXT* pMarkerInfo); +typedef void (VKAPI_PTR *PFN_vkCmdDebugMarkerEndEXT)(VkCommandBuffer commandBuffer); +typedef void (VKAPI_PTR *PFN_vkCmdDebugMarkerInsertEXT)(VkCommandBuffer commandBuffer, const VkDebugMarkerMarkerInfoEXT* pMarkerInfo); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkDebugMarkerSetObjectTagEXT( + VkDevice device, + const VkDebugMarkerObjectTagInfoEXT* pTagInfo); + +VKAPI_ATTR VkResult VKAPI_CALL vkDebugMarkerSetObjectNameEXT( + VkDevice device, + const VkDebugMarkerObjectNameInfoEXT* pNameInfo); + +VKAPI_ATTR void VKAPI_CALL vkCmdDebugMarkerBeginEXT( + VkCommandBuffer commandBuffer, + const VkDebugMarkerMarkerInfoEXT* pMarkerInfo); + +VKAPI_ATTR void VKAPI_CALL vkCmdDebugMarkerEndEXT( + VkCommandBuffer commandBuffer); + +VKAPI_ATTR void VKAPI_CALL vkCmdDebugMarkerInsertEXT( + VkCommandBuffer commandBuffer, + const VkDebugMarkerMarkerInfoEXT* pMarkerInfo); +#endif + +#define VK_AMD_gcn_shader 1 +#define VK_AMD_GCN_SHADER_SPEC_VERSION 1 +#define VK_AMD_GCN_SHADER_EXTENSION_NAME "VK_AMD_gcn_shader" + + +#define VK_NV_dedicated_allocation 1 +#define VK_NV_DEDICATED_ALLOCATION_SPEC_VERSION 1 +#define VK_NV_DEDICATED_ALLOCATION_EXTENSION_NAME "VK_NV_dedicated_allocation" + +typedef struct VkDedicatedAllocationImageCreateInfoNV { + VkStructureType sType; + const void* pNext; + VkBool32 dedicatedAllocation; +} VkDedicatedAllocationImageCreateInfoNV; + +typedef struct VkDedicatedAllocationBufferCreateInfoNV { + VkStructureType sType; + const void* pNext; + VkBool32 dedicatedAllocation; +} VkDedicatedAllocationBufferCreateInfoNV; + +typedef struct VkDedicatedAllocationMemoryAllocateInfoNV { + VkStructureType sType; + const void* pNext; + VkImage image; + VkBuffer buffer; +} VkDedicatedAllocationMemoryAllocateInfoNV; + + + +#define VK_AMD_draw_indirect_count 1 +#define VK_AMD_DRAW_INDIRECT_COUNT_SPEC_VERSION 1 +#define VK_AMD_DRAW_INDIRECT_COUNT_EXTENSION_NAME "VK_AMD_draw_indirect_count" + +typedef void (VKAPI_PTR *PFN_vkCmdDrawIndirectCountAMD)(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset, VkBuffer countBuffer, VkDeviceSize countBufferOffset, uint32_t maxDrawCount, uint32_t stride); +typedef void (VKAPI_PTR *PFN_vkCmdDrawIndexedIndirectCountAMD)(VkCommandBuffer commandBuffer, VkBuffer buffer, VkDeviceSize offset, VkBuffer countBuffer, VkDeviceSize countBufferOffset, uint32_t maxDrawCount, uint32_t stride); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkCmdDrawIndirectCountAMD( + VkCommandBuffer commandBuffer, + VkBuffer buffer, + VkDeviceSize offset, + VkBuffer countBuffer, + VkDeviceSize countBufferOffset, + uint32_t maxDrawCount, + uint32_t stride); + +VKAPI_ATTR void VKAPI_CALL vkCmdDrawIndexedIndirectCountAMD( + VkCommandBuffer commandBuffer, + VkBuffer buffer, + VkDeviceSize offset, + VkBuffer countBuffer, + VkDeviceSize countBufferOffset, + uint32_t maxDrawCount, + uint32_t stride); +#endif + +#define VK_AMD_negative_viewport_height 1 +#define VK_AMD_NEGATIVE_VIEWPORT_HEIGHT_SPEC_VERSION 1 +#define VK_AMD_NEGATIVE_VIEWPORT_HEIGHT_EXTENSION_NAME "VK_AMD_negative_viewport_height" + + +#define VK_AMD_gpu_shader_half_float 1 +#define VK_AMD_GPU_SHADER_HALF_FLOAT_SPEC_VERSION 1 +#define VK_AMD_GPU_SHADER_HALF_FLOAT_EXTENSION_NAME "VK_AMD_gpu_shader_half_float" + + +#define VK_AMD_shader_ballot 1 +#define VK_AMD_SHADER_BALLOT_SPEC_VERSION 1 +#define VK_AMD_SHADER_BALLOT_EXTENSION_NAME "VK_AMD_shader_ballot" + + +#define VK_AMD_texture_gather_bias_lod 1 +#define VK_AMD_TEXTURE_GATHER_BIAS_LOD_SPEC_VERSION 1 +#define VK_AMD_TEXTURE_GATHER_BIAS_LOD_EXTENSION_NAME "VK_AMD_texture_gather_bias_lod" + +typedef struct VkTextureLODGatherFormatPropertiesAMD { + VkStructureType sType; + void* pNext; + VkBool32 supportsTextureGatherLODBiasAMD; +} VkTextureLODGatherFormatPropertiesAMD; + + + +#define VK_KHX_multiview 1 +#define VK_KHX_MULTIVIEW_SPEC_VERSION 1 +#define VK_KHX_MULTIVIEW_EXTENSION_NAME "VK_KHX_multiview" + +typedef struct VkRenderPassMultiviewCreateInfoKHX { + VkStructureType sType; + const void* pNext; + uint32_t subpassCount; + const uint32_t* pViewMasks; + uint32_t dependencyCount; + const int32_t* pViewOffsets; + uint32_t correlationMaskCount; + const uint32_t* pCorrelationMasks; +} VkRenderPassMultiviewCreateInfoKHX; + +typedef struct VkPhysicalDeviceMultiviewFeaturesKHX { + VkStructureType sType; + void* pNext; + VkBool32 multiview; + VkBool32 multiviewGeometryShader; + VkBool32 multiviewTessellationShader; +} VkPhysicalDeviceMultiviewFeaturesKHX; + +typedef struct VkPhysicalDeviceMultiviewPropertiesKHX { + VkStructureType sType; + void* pNext; + uint32_t maxMultiviewViewCount; + uint32_t maxMultiviewInstanceIndex; +} VkPhysicalDeviceMultiviewPropertiesKHX; + + + +#define VK_IMG_format_pvrtc 1 +#define VK_IMG_FORMAT_PVRTC_SPEC_VERSION 1 +#define VK_IMG_FORMAT_PVRTC_EXTENSION_NAME "VK_IMG_format_pvrtc" + + +#define VK_NV_external_memory_capabilities 1 +#define VK_NV_EXTERNAL_MEMORY_CAPABILITIES_SPEC_VERSION 1 +#define VK_NV_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME "VK_NV_external_memory_capabilities" + + +typedef enum VkExternalMemoryHandleTypeFlagBitsNV { + VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_BIT_NV = 0x00000001, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT_BIT_NV = 0x00000002, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_IMAGE_BIT_NV = 0x00000004, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_IMAGE_KMT_BIT_NV = 0x00000008, + VK_EXTERNAL_MEMORY_HANDLE_TYPE_FLAG_BITS_MAX_ENUM_NV = 0x7FFFFFFF +} VkExternalMemoryHandleTypeFlagBitsNV; +typedef VkFlags VkExternalMemoryHandleTypeFlagsNV; + +typedef enum VkExternalMemoryFeatureFlagBitsNV { + VK_EXTERNAL_MEMORY_FEATURE_DEDICATED_ONLY_BIT_NV = 0x00000001, + VK_EXTERNAL_MEMORY_FEATURE_EXPORTABLE_BIT_NV = 0x00000002, + VK_EXTERNAL_MEMORY_FEATURE_IMPORTABLE_BIT_NV = 0x00000004, + VK_EXTERNAL_MEMORY_FEATURE_FLAG_BITS_MAX_ENUM_NV = 0x7FFFFFFF +} VkExternalMemoryFeatureFlagBitsNV; +typedef VkFlags VkExternalMemoryFeatureFlagsNV; + +typedef struct VkExternalImageFormatPropertiesNV { + VkImageFormatProperties imageFormatProperties; + VkExternalMemoryFeatureFlagsNV externalMemoryFeatures; + VkExternalMemoryHandleTypeFlagsNV exportFromImportedHandleTypes; + VkExternalMemoryHandleTypeFlagsNV compatibleHandleTypes; +} VkExternalImageFormatPropertiesNV; + + +typedef VkResult (VKAPI_PTR *PFN_vkGetPhysicalDeviceExternalImageFormatPropertiesNV)(VkPhysicalDevice physicalDevice, VkFormat format, VkImageType type, VkImageTiling tiling, VkImageUsageFlags usage, VkImageCreateFlags flags, VkExternalMemoryHandleTypeFlagsNV externalHandleType, VkExternalImageFormatPropertiesNV* pExternalImageFormatProperties); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkGetPhysicalDeviceExternalImageFormatPropertiesNV( + VkPhysicalDevice physicalDevice, + VkFormat format, + VkImageType type, + VkImageTiling tiling, + VkImageUsageFlags usage, + VkImageCreateFlags flags, + VkExternalMemoryHandleTypeFlagsNV externalHandleType, + VkExternalImageFormatPropertiesNV* pExternalImageFormatProperties); +#endif + +#define VK_NV_external_memory 1 +#define VK_NV_EXTERNAL_MEMORY_SPEC_VERSION 1 +#define VK_NV_EXTERNAL_MEMORY_EXTENSION_NAME "VK_NV_external_memory" + +typedef struct VkExternalMemoryImageCreateInfoNV { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagsNV handleTypes; +} VkExternalMemoryImageCreateInfoNV; + +typedef struct VkExportMemoryAllocateInfoNV { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagsNV handleTypes; +} VkExportMemoryAllocateInfoNV; + + + +#ifdef VK_USE_PLATFORM_WIN32_KHR +#define VK_NV_external_memory_win32 1 +#define VK_NV_EXTERNAL_MEMORY_WIN32_SPEC_VERSION 1 +#define VK_NV_EXTERNAL_MEMORY_WIN32_EXTENSION_NAME "VK_NV_external_memory_win32" + +typedef struct VkImportMemoryWin32HandleInfoNV { + VkStructureType sType; + const void* pNext; + VkExternalMemoryHandleTypeFlagsNV handleType; + HANDLE handle; +} VkImportMemoryWin32HandleInfoNV; + +typedef struct VkExportMemoryWin32HandleInfoNV { + VkStructureType sType; + const void* pNext; + const SECURITY_ATTRIBUTES* pAttributes; + DWORD dwAccess; +} VkExportMemoryWin32HandleInfoNV; + + +typedef VkResult (VKAPI_PTR *PFN_vkGetMemoryWin32HandleNV)(VkDevice device, VkDeviceMemory memory, VkExternalMemoryHandleTypeFlagsNV handleType, HANDLE* pHandle); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkGetMemoryWin32HandleNV( + VkDevice device, + VkDeviceMemory memory, + VkExternalMemoryHandleTypeFlagsNV handleType, + HANDLE* pHandle); +#endif +#endif /* VK_USE_PLATFORM_WIN32_KHR */ + +#ifdef VK_USE_PLATFORM_WIN32_KHR +#define VK_NV_win32_keyed_mutex 1 +#define VK_NV_WIN32_KEYED_MUTEX_SPEC_VERSION 1 +#define VK_NV_WIN32_KEYED_MUTEX_EXTENSION_NAME "VK_NV_win32_keyed_mutex" + +typedef struct VkWin32KeyedMutexAcquireReleaseInfoNV { + VkStructureType sType; + const void* pNext; + uint32_t acquireCount; + const VkDeviceMemory* pAcquireSyncs; + const uint64_t* pAcquireKeys; + const uint32_t* pAcquireTimeoutMilliseconds; + uint32_t releaseCount; + const VkDeviceMemory* pReleaseSyncs; + const uint64_t* pReleaseKeys; +} VkWin32KeyedMutexAcquireReleaseInfoNV; + + +#endif /* VK_USE_PLATFORM_WIN32_KHR */ + +#define VK_KHX_device_group 1 +#define VK_MAX_DEVICE_GROUP_SIZE_KHX 32 +#define VK_KHX_DEVICE_GROUP_SPEC_VERSION 1 +#define VK_KHX_DEVICE_GROUP_EXTENSION_NAME "VK_KHX_device_group" + + +typedef enum VkPeerMemoryFeatureFlagBitsKHX { + VK_PEER_MEMORY_FEATURE_COPY_SRC_BIT_KHX = 0x00000001, + VK_PEER_MEMORY_FEATURE_COPY_DST_BIT_KHX = 0x00000002, + VK_PEER_MEMORY_FEATURE_GENERIC_SRC_BIT_KHX = 0x00000004, + VK_PEER_MEMORY_FEATURE_GENERIC_DST_BIT_KHX = 0x00000008, + VK_PEER_MEMORY_FEATURE_FLAG_BITS_MAX_ENUM_KHX = 0x7FFFFFFF +} VkPeerMemoryFeatureFlagBitsKHX; +typedef VkFlags VkPeerMemoryFeatureFlagsKHX; + +typedef enum VkMemoryAllocateFlagBitsKHX { + VK_MEMORY_ALLOCATE_DEVICE_MASK_BIT_KHX = 0x00000001, + VK_MEMORY_ALLOCATE_FLAG_BITS_MAX_ENUM_KHX = 0x7FFFFFFF +} VkMemoryAllocateFlagBitsKHX; +typedef VkFlags VkMemoryAllocateFlagsKHX; + +typedef enum VkDeviceGroupPresentModeFlagBitsKHX { + VK_DEVICE_GROUP_PRESENT_MODE_LOCAL_BIT_KHX = 0x00000001, + VK_DEVICE_GROUP_PRESENT_MODE_REMOTE_BIT_KHX = 0x00000002, + VK_DEVICE_GROUP_PRESENT_MODE_SUM_BIT_KHX = 0x00000004, + VK_DEVICE_GROUP_PRESENT_MODE_LOCAL_MULTI_DEVICE_BIT_KHX = 0x00000008, + VK_DEVICE_GROUP_PRESENT_MODE_FLAG_BITS_MAX_ENUM_KHX = 0x7FFFFFFF +} VkDeviceGroupPresentModeFlagBitsKHX; +typedef VkFlags VkDeviceGroupPresentModeFlagsKHX; + +typedef struct VkMemoryAllocateFlagsInfoKHX { + VkStructureType sType; + const void* pNext; + VkMemoryAllocateFlagsKHX flags; + uint32_t deviceMask; +} VkMemoryAllocateFlagsInfoKHX; + +typedef struct VkBindBufferMemoryInfoKHX { + VkStructureType sType; + const void* pNext; + VkBuffer buffer; + VkDeviceMemory memory; + VkDeviceSize memoryOffset; + uint32_t deviceIndexCount; + const uint32_t* pDeviceIndices; +} VkBindBufferMemoryInfoKHX; + +typedef struct VkBindImageMemoryInfoKHX { + VkStructureType sType; + const void* pNext; + VkImage image; + VkDeviceMemory memory; + VkDeviceSize memoryOffset; + uint32_t deviceIndexCount; + const uint32_t* pDeviceIndices; + uint32_t SFRRectCount; + const VkRect2D* pSFRRects; +} VkBindImageMemoryInfoKHX; + +typedef struct VkDeviceGroupRenderPassBeginInfoKHX { + VkStructureType sType; + const void* pNext; + uint32_t deviceMask; + uint32_t deviceRenderAreaCount; + const VkRect2D* pDeviceRenderAreas; +} VkDeviceGroupRenderPassBeginInfoKHX; + +typedef struct VkDeviceGroupCommandBufferBeginInfoKHX { + VkStructureType sType; + const void* pNext; + uint32_t deviceMask; +} VkDeviceGroupCommandBufferBeginInfoKHX; + +typedef struct VkDeviceGroupSubmitInfoKHX { + VkStructureType sType; + const void* pNext; + uint32_t waitSemaphoreCount; + const uint32_t* pWaitSemaphoreDeviceIndices; + uint32_t commandBufferCount; + const uint32_t* pCommandBufferDeviceMasks; + uint32_t signalSemaphoreCount; + const uint32_t* pSignalSemaphoreDeviceIndices; +} VkDeviceGroupSubmitInfoKHX; + +typedef struct VkDeviceGroupBindSparseInfoKHX { + VkStructureType sType; + const void* pNext; + uint32_t resourceDeviceIndex; + uint32_t memoryDeviceIndex; +} VkDeviceGroupBindSparseInfoKHX; + +typedef struct VkDeviceGroupPresentCapabilitiesKHX { + VkStructureType sType; + const void* pNext; + uint32_t presentMask[VK_MAX_DEVICE_GROUP_SIZE_KHX]; + VkDeviceGroupPresentModeFlagsKHX modes; +} VkDeviceGroupPresentCapabilitiesKHX; + +typedef struct VkImageSwapchainCreateInfoKHX { + VkStructureType sType; + const void* pNext; + VkSwapchainKHR swapchain; +} VkImageSwapchainCreateInfoKHX; + +typedef struct VkBindImageMemorySwapchainInfoKHX { + VkStructureType sType; + const void* pNext; + VkSwapchainKHR swapchain; + uint32_t imageIndex; +} VkBindImageMemorySwapchainInfoKHX; + +typedef struct VkAcquireNextImageInfoKHX { + VkStructureType sType; + const void* pNext; + VkSwapchainKHR swapchain; + uint64_t timeout; + VkSemaphore semaphore; + VkFence fence; + uint32_t deviceMask; +} VkAcquireNextImageInfoKHX; + +typedef struct VkDeviceGroupPresentInfoKHX { + VkStructureType sType; + const void* pNext; + uint32_t swapchainCount; + const uint32_t* pDeviceMasks; + VkDeviceGroupPresentModeFlagBitsKHX mode; +} VkDeviceGroupPresentInfoKHX; + +typedef struct VkDeviceGroupSwapchainCreateInfoKHX { + VkStructureType sType; + const void* pNext; + VkDeviceGroupPresentModeFlagsKHX modes; +} VkDeviceGroupSwapchainCreateInfoKHX; + + +typedef void (VKAPI_PTR *PFN_vkGetDeviceGroupPeerMemoryFeaturesKHX)(VkDevice device, uint32_t heapIndex, uint32_t localDeviceIndex, uint32_t remoteDeviceIndex, VkPeerMemoryFeatureFlagsKHX* pPeerMemoryFeatures); +typedef VkResult (VKAPI_PTR *PFN_vkBindBufferMemory2KHX)(VkDevice device, uint32_t bindInfoCount, const VkBindBufferMemoryInfoKHX* pBindInfos); +typedef VkResult (VKAPI_PTR *PFN_vkBindImageMemory2KHX)(VkDevice device, uint32_t bindInfoCount, const VkBindImageMemoryInfoKHX* pBindInfos); +typedef void (VKAPI_PTR *PFN_vkCmdSetDeviceMaskKHX)(VkCommandBuffer commandBuffer, uint32_t deviceMask); +typedef VkResult (VKAPI_PTR *PFN_vkGetDeviceGroupPresentCapabilitiesKHX)(VkDevice device, VkDeviceGroupPresentCapabilitiesKHX* pDeviceGroupPresentCapabilities); +typedef VkResult (VKAPI_PTR *PFN_vkGetDeviceGroupSurfacePresentModesKHX)(VkDevice device, VkSurfaceKHR surface, VkDeviceGroupPresentModeFlagsKHX* pModes); +typedef VkResult (VKAPI_PTR *PFN_vkAcquireNextImage2KHX)(VkDevice device, const VkAcquireNextImageInfoKHX* pAcquireInfo, uint32_t* pImageIndex); +typedef void (VKAPI_PTR *PFN_vkCmdDispatchBaseKHX)(VkCommandBuffer commandBuffer, uint32_t baseGroupX, uint32_t baseGroupY, uint32_t baseGroupZ, uint32_t groupCountX, uint32_t groupCountY, uint32_t groupCountZ); +typedef VkResult (VKAPI_PTR *PFN_vkGetPhysicalDevicePresentRectanglesKHX)(VkPhysicalDevice physicalDevice, VkSurfaceKHR surface, uint32_t* pRectCount, VkRect2D* pRects); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkGetDeviceGroupPeerMemoryFeaturesKHX( + VkDevice device, + uint32_t heapIndex, + uint32_t localDeviceIndex, + uint32_t remoteDeviceIndex, + VkPeerMemoryFeatureFlagsKHX* pPeerMemoryFeatures); + +VKAPI_ATTR VkResult VKAPI_CALL vkBindBufferMemory2KHX( + VkDevice device, + uint32_t bindInfoCount, + const VkBindBufferMemoryInfoKHX* pBindInfos); + +VKAPI_ATTR VkResult VKAPI_CALL vkBindImageMemory2KHX( + VkDevice device, + uint32_t bindInfoCount, + const VkBindImageMemoryInfoKHX* pBindInfos); + +VKAPI_ATTR void VKAPI_CALL vkCmdSetDeviceMaskKHX( + VkCommandBuffer commandBuffer, + uint32_t deviceMask); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetDeviceGroupPresentCapabilitiesKHX( + VkDevice device, + VkDeviceGroupPresentCapabilitiesKHX* pDeviceGroupPresentCapabilities); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetDeviceGroupSurfacePresentModesKHX( + VkDevice device, + VkSurfaceKHR surface, + VkDeviceGroupPresentModeFlagsKHX* pModes); + +VKAPI_ATTR VkResult VKAPI_CALL vkAcquireNextImage2KHX( + VkDevice device, + const VkAcquireNextImageInfoKHX* pAcquireInfo, + uint32_t* pImageIndex); + +VKAPI_ATTR void VKAPI_CALL vkCmdDispatchBaseKHX( + VkCommandBuffer commandBuffer, + uint32_t baseGroupX, + uint32_t baseGroupY, + uint32_t baseGroupZ, + uint32_t groupCountX, + uint32_t groupCountY, + uint32_t groupCountZ); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetPhysicalDevicePresentRectanglesKHX( + VkPhysicalDevice physicalDevice, + VkSurfaceKHR surface, + uint32_t* pRectCount, + VkRect2D* pRects); +#endif + +#define VK_EXT_validation_flags 1 +#define VK_EXT_VALIDATION_FLAGS_SPEC_VERSION 1 +#define VK_EXT_VALIDATION_FLAGS_EXTENSION_NAME "VK_EXT_validation_flags" + + +typedef enum VkValidationCheckEXT { + VK_VALIDATION_CHECK_ALL_EXT = 0, + VK_VALIDATION_CHECK_SHADERS_EXT = 1, + VK_VALIDATION_CHECK_BEGIN_RANGE_EXT = VK_VALIDATION_CHECK_ALL_EXT, + VK_VALIDATION_CHECK_END_RANGE_EXT = VK_VALIDATION_CHECK_SHADERS_EXT, + VK_VALIDATION_CHECK_RANGE_SIZE_EXT = (VK_VALIDATION_CHECK_SHADERS_EXT - VK_VALIDATION_CHECK_ALL_EXT + 1), + VK_VALIDATION_CHECK_MAX_ENUM_EXT = 0x7FFFFFFF +} VkValidationCheckEXT; typedef struct VkValidationFlagsEXT { VkStructureType sType; const void* pNext; - uint32_t disabledValidationCheckCount; - VkValidationCheckEXT* pDisabledValidationChecks; -} VkValidationFlagsEXT; + uint32_t disabledValidationCheckCount; + VkValidationCheckEXT* pDisabledValidationChecks; +} VkValidationFlagsEXT; + + + +#ifdef VK_USE_PLATFORM_VI_NN +#define VK_NN_vi_surface 1 +#define VK_NN_VI_SURFACE_SPEC_VERSION 1 +#define VK_NN_VI_SURFACE_EXTENSION_NAME "VK_NN_vi_surface" + +typedef VkFlags VkViSurfaceCreateFlagsNN; + +typedef struct VkViSurfaceCreateInfoNN { + VkStructureType sType; + const void* pNext; + VkViSurfaceCreateFlagsNN flags; + void* window; +} VkViSurfaceCreateInfoNN; + + +typedef VkResult (VKAPI_PTR *PFN_vkCreateViSurfaceNN)(VkInstance instance, const VkViSurfaceCreateInfoNN* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkSurfaceKHR* pSurface); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkCreateViSurfaceNN( + VkInstance instance, + const VkViSurfaceCreateInfoNN* pCreateInfo, + const VkAllocationCallbacks* pAllocator, + VkSurfaceKHR* pSurface); +#endif +#endif /* VK_USE_PLATFORM_VI_NN */ + +#define VK_EXT_shader_subgroup_ballot 1 +#define VK_EXT_SHADER_SUBGROUP_BALLOT_SPEC_VERSION 1 +#define VK_EXT_SHADER_SUBGROUP_BALLOT_EXTENSION_NAME "VK_EXT_shader_subgroup_ballot" + + +#define VK_EXT_shader_subgroup_vote 1 +#define VK_EXT_SHADER_SUBGROUP_VOTE_SPEC_VERSION 1 +#define VK_EXT_SHADER_SUBGROUP_VOTE_EXTENSION_NAME "VK_EXT_shader_subgroup_vote" + + +#define VK_KHX_device_group_creation 1 +#define VK_KHX_DEVICE_GROUP_CREATION_SPEC_VERSION 1 +#define VK_KHX_DEVICE_GROUP_CREATION_EXTENSION_NAME "VK_KHX_device_group_creation" + +typedef struct VkPhysicalDeviceGroupPropertiesKHX { + VkStructureType sType; + void* pNext; + uint32_t physicalDeviceCount; + VkPhysicalDevice physicalDevices[VK_MAX_DEVICE_GROUP_SIZE_KHX]; + VkBool32 subsetAllocation; +} VkPhysicalDeviceGroupPropertiesKHX; + +typedef struct VkDeviceGroupDeviceCreateInfoKHX { + VkStructureType sType; + const void* pNext; + uint32_t physicalDeviceCount; + const VkPhysicalDevice* pPhysicalDevices; +} VkDeviceGroupDeviceCreateInfoKHX; + + +typedef VkResult (VKAPI_PTR *PFN_vkEnumeratePhysicalDeviceGroupsKHX)(VkInstance instance, uint32_t* pPhysicalDeviceGroupCount, VkPhysicalDeviceGroupPropertiesKHX* pPhysicalDeviceGroupProperties); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkEnumeratePhysicalDeviceGroupsKHX( + VkInstance instance, + uint32_t* pPhysicalDeviceGroupCount, + VkPhysicalDeviceGroupPropertiesKHX* pPhysicalDeviceGroupProperties); +#endif + +#define VK_NVX_device_generated_commands 1 +VK_DEFINE_NON_DISPATCHABLE_HANDLE(VkObjectTableNVX) +VK_DEFINE_NON_DISPATCHABLE_HANDLE(VkIndirectCommandsLayoutNVX) + +#define VK_NVX_DEVICE_GENERATED_COMMANDS_SPEC_VERSION 3 +#define VK_NVX_DEVICE_GENERATED_COMMANDS_EXTENSION_NAME "VK_NVX_device_generated_commands" + + +typedef enum VkIndirectCommandsTokenTypeNVX { + VK_INDIRECT_COMMANDS_TOKEN_TYPE_PIPELINE_NVX = 0, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_DESCRIPTOR_SET_NVX = 1, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_INDEX_BUFFER_NVX = 2, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_VERTEX_BUFFER_NVX = 3, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_PUSH_CONSTANT_NVX = 4, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_DRAW_INDEXED_NVX = 5, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_DRAW_NVX = 6, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_DISPATCH_NVX = 7, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_BEGIN_RANGE_NVX = VK_INDIRECT_COMMANDS_TOKEN_TYPE_PIPELINE_NVX, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_END_RANGE_NVX = VK_INDIRECT_COMMANDS_TOKEN_TYPE_DISPATCH_NVX, + VK_INDIRECT_COMMANDS_TOKEN_TYPE_RANGE_SIZE_NVX = (VK_INDIRECT_COMMANDS_TOKEN_TYPE_DISPATCH_NVX - VK_INDIRECT_COMMANDS_TOKEN_TYPE_PIPELINE_NVX + 1), + VK_INDIRECT_COMMANDS_TOKEN_TYPE_MAX_ENUM_NVX = 0x7FFFFFFF +} VkIndirectCommandsTokenTypeNVX; + +typedef enum VkObjectEntryTypeNVX { + VK_OBJECT_ENTRY_TYPE_DESCRIPTOR_SET_NVX = 0, + VK_OBJECT_ENTRY_TYPE_PIPELINE_NVX = 1, + VK_OBJECT_ENTRY_TYPE_INDEX_BUFFER_NVX = 2, + VK_OBJECT_ENTRY_TYPE_VERTEX_BUFFER_NVX = 3, + VK_OBJECT_ENTRY_TYPE_PUSH_CONSTANT_NVX = 4, + VK_OBJECT_ENTRY_TYPE_BEGIN_RANGE_NVX = VK_OBJECT_ENTRY_TYPE_DESCRIPTOR_SET_NVX, + VK_OBJECT_ENTRY_TYPE_END_RANGE_NVX = VK_OBJECT_ENTRY_TYPE_PUSH_CONSTANT_NVX, + VK_OBJECT_ENTRY_TYPE_RANGE_SIZE_NVX = (VK_OBJECT_ENTRY_TYPE_PUSH_CONSTANT_NVX - VK_OBJECT_ENTRY_TYPE_DESCRIPTOR_SET_NVX + 1), + VK_OBJECT_ENTRY_TYPE_MAX_ENUM_NVX = 0x7FFFFFFF +} VkObjectEntryTypeNVX; + + +typedef enum VkIndirectCommandsLayoutUsageFlagBitsNVX { + VK_INDIRECT_COMMANDS_LAYOUT_USAGE_UNORDERED_SEQUENCES_BIT_NVX = 0x00000001, + VK_INDIRECT_COMMANDS_LAYOUT_USAGE_SPARSE_SEQUENCES_BIT_NVX = 0x00000002, + VK_INDIRECT_COMMANDS_LAYOUT_USAGE_EMPTY_EXECUTIONS_BIT_NVX = 0x00000004, + VK_INDIRECT_COMMANDS_LAYOUT_USAGE_INDEXED_SEQUENCES_BIT_NVX = 0x00000008, + VK_INDIRECT_COMMANDS_LAYOUT_USAGE_FLAG_BITS_MAX_ENUM_NVX = 0x7FFFFFFF +} VkIndirectCommandsLayoutUsageFlagBitsNVX; +typedef VkFlags VkIndirectCommandsLayoutUsageFlagsNVX; + +typedef enum VkObjectEntryUsageFlagBitsNVX { + VK_OBJECT_ENTRY_USAGE_GRAPHICS_BIT_NVX = 0x00000001, + VK_OBJECT_ENTRY_USAGE_COMPUTE_BIT_NVX = 0x00000002, + VK_OBJECT_ENTRY_USAGE_FLAG_BITS_MAX_ENUM_NVX = 0x7FFFFFFF +} VkObjectEntryUsageFlagBitsNVX; +typedef VkFlags VkObjectEntryUsageFlagsNVX; + +typedef struct VkDeviceGeneratedCommandsFeaturesNVX { + VkStructureType sType; + const void* pNext; + VkBool32 computeBindingPointSupport; +} VkDeviceGeneratedCommandsFeaturesNVX; + +typedef struct VkDeviceGeneratedCommandsLimitsNVX { + VkStructureType sType; + const void* pNext; + uint32_t maxIndirectCommandsLayoutTokenCount; + uint32_t maxObjectEntryCounts; + uint32_t minSequenceCountBufferOffsetAlignment; + uint32_t minSequenceIndexBufferOffsetAlignment; + uint32_t minCommandsTokenBufferOffsetAlignment; +} VkDeviceGeneratedCommandsLimitsNVX; + +typedef struct VkIndirectCommandsTokenNVX { + VkIndirectCommandsTokenTypeNVX tokenType; + VkBuffer buffer; + VkDeviceSize offset; +} VkIndirectCommandsTokenNVX; + +typedef struct VkIndirectCommandsLayoutTokenNVX { + VkIndirectCommandsTokenTypeNVX tokenType; + uint32_t bindingUnit; + uint32_t dynamicCount; + uint32_t divisor; +} VkIndirectCommandsLayoutTokenNVX; + +typedef struct VkIndirectCommandsLayoutCreateInfoNVX { + VkStructureType sType; + const void* pNext; + VkPipelineBindPoint pipelineBindPoint; + VkIndirectCommandsLayoutUsageFlagsNVX flags; + uint32_t tokenCount; + const VkIndirectCommandsLayoutTokenNVX* pTokens; +} VkIndirectCommandsLayoutCreateInfoNVX; + +typedef struct VkCmdProcessCommandsInfoNVX { + VkStructureType sType; + const void* pNext; + VkObjectTableNVX objectTable; + VkIndirectCommandsLayoutNVX indirectCommandsLayout; + uint32_t indirectCommandsTokenCount; + const VkIndirectCommandsTokenNVX* pIndirectCommandsTokens; + uint32_t maxSequencesCount; + VkCommandBuffer targetCommandBuffer; + VkBuffer sequencesCountBuffer; + VkDeviceSize sequencesCountOffset; + VkBuffer sequencesIndexBuffer; + VkDeviceSize sequencesIndexOffset; +} VkCmdProcessCommandsInfoNVX; + +typedef struct VkCmdReserveSpaceForCommandsInfoNVX { + VkStructureType sType; + const void* pNext; + VkObjectTableNVX objectTable; + VkIndirectCommandsLayoutNVX indirectCommandsLayout; + uint32_t maxSequencesCount; +} VkCmdReserveSpaceForCommandsInfoNVX; + +typedef struct VkObjectTableCreateInfoNVX { + VkStructureType sType; + const void* pNext; + uint32_t objectCount; + const VkObjectEntryTypeNVX* pObjectEntryTypes; + const uint32_t* pObjectEntryCounts; + const VkObjectEntryUsageFlagsNVX* pObjectEntryUsageFlags; + uint32_t maxUniformBuffersPerDescriptor; + uint32_t maxStorageBuffersPerDescriptor; + uint32_t maxStorageImagesPerDescriptor; + uint32_t maxSampledImagesPerDescriptor; + uint32_t maxPipelineLayouts; +} VkObjectTableCreateInfoNVX; + +typedef struct VkObjectTableEntryNVX { + VkObjectEntryTypeNVX type; + VkObjectEntryUsageFlagsNVX flags; +} VkObjectTableEntryNVX; + +typedef struct VkObjectTablePipelineEntryNVX { + VkObjectEntryTypeNVX type; + VkObjectEntryUsageFlagsNVX flags; + VkPipeline pipeline; +} VkObjectTablePipelineEntryNVX; + +typedef struct VkObjectTableDescriptorSetEntryNVX { + VkObjectEntryTypeNVX type; + VkObjectEntryUsageFlagsNVX flags; + VkPipelineLayout pipelineLayout; + VkDescriptorSet descriptorSet; +} VkObjectTableDescriptorSetEntryNVX; + +typedef struct VkObjectTableVertexBufferEntryNVX { + VkObjectEntryTypeNVX type; + VkObjectEntryUsageFlagsNVX flags; + VkBuffer buffer; +} VkObjectTableVertexBufferEntryNVX; + +typedef struct VkObjectTableIndexBufferEntryNVX { + VkObjectEntryTypeNVX type; + VkObjectEntryUsageFlagsNVX flags; + VkBuffer buffer; + VkIndexType indexType; +} VkObjectTableIndexBufferEntryNVX; + +typedef struct VkObjectTablePushConstantEntryNVX { + VkObjectEntryTypeNVX type; + VkObjectEntryUsageFlagsNVX flags; + VkPipelineLayout pipelineLayout; + VkShaderStageFlags stageFlags; +} VkObjectTablePushConstantEntryNVX; + + +typedef void (VKAPI_PTR *PFN_vkCmdProcessCommandsNVX)(VkCommandBuffer commandBuffer, const VkCmdProcessCommandsInfoNVX* pProcessCommandsInfo); +typedef void (VKAPI_PTR *PFN_vkCmdReserveSpaceForCommandsNVX)(VkCommandBuffer commandBuffer, const VkCmdReserveSpaceForCommandsInfoNVX* pReserveSpaceInfo); +typedef VkResult (VKAPI_PTR *PFN_vkCreateIndirectCommandsLayoutNVX)(VkDevice device, const VkIndirectCommandsLayoutCreateInfoNVX* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkIndirectCommandsLayoutNVX* pIndirectCommandsLayout); +typedef void (VKAPI_PTR *PFN_vkDestroyIndirectCommandsLayoutNVX)(VkDevice device, VkIndirectCommandsLayoutNVX indirectCommandsLayout, const VkAllocationCallbacks* pAllocator); +typedef VkResult (VKAPI_PTR *PFN_vkCreateObjectTableNVX)(VkDevice device, const VkObjectTableCreateInfoNVX* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkObjectTableNVX* pObjectTable); +typedef void (VKAPI_PTR *PFN_vkDestroyObjectTableNVX)(VkDevice device, VkObjectTableNVX objectTable, const VkAllocationCallbacks* pAllocator); +typedef VkResult (VKAPI_PTR *PFN_vkRegisterObjectsNVX)(VkDevice device, VkObjectTableNVX objectTable, uint32_t objectCount, const VkObjectTableEntryNVX* const* ppObjectTableEntries, const uint32_t* pObjectIndices); +typedef VkResult (VKAPI_PTR *PFN_vkUnregisterObjectsNVX)(VkDevice device, VkObjectTableNVX objectTable, uint32_t objectCount, const VkObjectEntryTypeNVX* pObjectEntryTypes, const uint32_t* pObjectIndices); +typedef void (VKAPI_PTR *PFN_vkGetPhysicalDeviceGeneratedCommandsPropertiesNVX)(VkPhysicalDevice physicalDevice, VkDeviceGeneratedCommandsFeaturesNVX* pFeatures, VkDeviceGeneratedCommandsLimitsNVX* pLimits); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkCmdProcessCommandsNVX( + VkCommandBuffer commandBuffer, + const VkCmdProcessCommandsInfoNVX* pProcessCommandsInfo); + +VKAPI_ATTR void VKAPI_CALL vkCmdReserveSpaceForCommandsNVX( + VkCommandBuffer commandBuffer, + const VkCmdReserveSpaceForCommandsInfoNVX* pReserveSpaceInfo); + +VKAPI_ATTR VkResult VKAPI_CALL vkCreateIndirectCommandsLayoutNVX( + VkDevice device, + const VkIndirectCommandsLayoutCreateInfoNVX* pCreateInfo, + const VkAllocationCallbacks* pAllocator, + VkIndirectCommandsLayoutNVX* pIndirectCommandsLayout); + +VKAPI_ATTR void VKAPI_CALL vkDestroyIndirectCommandsLayoutNVX( + VkDevice device, + VkIndirectCommandsLayoutNVX indirectCommandsLayout, + const VkAllocationCallbacks* pAllocator); + +VKAPI_ATTR VkResult VKAPI_CALL vkCreateObjectTableNVX( + VkDevice device, + const VkObjectTableCreateInfoNVX* pCreateInfo, + const VkAllocationCallbacks* pAllocator, + VkObjectTableNVX* pObjectTable); + +VKAPI_ATTR void VKAPI_CALL vkDestroyObjectTableNVX( + VkDevice device, + VkObjectTableNVX objectTable, + const VkAllocationCallbacks* pAllocator); + +VKAPI_ATTR VkResult VKAPI_CALL vkRegisterObjectsNVX( + VkDevice device, + VkObjectTableNVX objectTable, + uint32_t objectCount, + const VkObjectTableEntryNVX* const* ppObjectTableEntries, + const uint32_t* pObjectIndices); + +VKAPI_ATTR VkResult VKAPI_CALL vkUnregisterObjectsNVX( + VkDevice device, + VkObjectTableNVX objectTable, + uint32_t objectCount, + const VkObjectEntryTypeNVX* pObjectEntryTypes, + const uint32_t* pObjectIndices); + +VKAPI_ATTR void VKAPI_CALL vkGetPhysicalDeviceGeneratedCommandsPropertiesNVX( + VkPhysicalDevice physicalDevice, + VkDeviceGeneratedCommandsFeaturesNVX* pFeatures, + VkDeviceGeneratedCommandsLimitsNVX* pLimits); +#endif + +#define VK_NV_clip_space_w_scaling 1 +#define VK_NV_CLIP_SPACE_W_SCALING_SPEC_VERSION 1 +#define VK_NV_CLIP_SPACE_W_SCALING_EXTENSION_NAME "VK_NV_clip_space_w_scaling" + +typedef struct VkViewportWScalingNV { + float xcoeff; + float ycoeff; +} VkViewportWScalingNV; + +typedef struct VkPipelineViewportWScalingStateCreateInfoNV { + VkStructureType sType; + const void* pNext; + VkBool32 viewportWScalingEnable; + uint32_t viewportCount; + const VkViewportWScalingNV* pViewportWScalings; +} VkPipelineViewportWScalingStateCreateInfoNV; + + +typedef void (VKAPI_PTR *PFN_vkCmdSetViewportWScalingNV)(VkCommandBuffer commandBuffer, uint32_t firstViewport, uint32_t viewportCount, const VkViewportWScalingNV* pViewportWScalings); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkCmdSetViewportWScalingNV( + VkCommandBuffer commandBuffer, + uint32_t firstViewport, + uint32_t viewportCount, + const VkViewportWScalingNV* pViewportWScalings); +#endif + +#define VK_EXT_direct_mode_display 1 +#define VK_EXT_DIRECT_MODE_DISPLAY_SPEC_VERSION 1 +#define VK_EXT_DIRECT_MODE_DISPLAY_EXTENSION_NAME "VK_EXT_direct_mode_display" + +typedef VkResult (VKAPI_PTR *PFN_vkReleaseDisplayEXT)(VkPhysicalDevice physicalDevice, VkDisplayKHR display); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkReleaseDisplayEXT( + VkPhysicalDevice physicalDevice, + VkDisplayKHR display); +#endif + +#ifdef VK_USE_PLATFORM_XLIB_XRANDR_EXT +#define VK_EXT_acquire_xlib_display 1 +#include + +#define VK_EXT_ACQUIRE_XLIB_DISPLAY_SPEC_VERSION 1 +#define VK_EXT_ACQUIRE_XLIB_DISPLAY_EXTENSION_NAME "VK_EXT_acquire_xlib_display" + +typedef VkResult (VKAPI_PTR *PFN_vkAcquireXlibDisplayEXT)(VkPhysicalDevice physicalDevice, Display* dpy, VkDisplayKHR display); +typedef VkResult (VKAPI_PTR *PFN_vkGetRandROutputDisplayEXT)(VkPhysicalDevice physicalDevice, Display* dpy, RROutput rrOutput, VkDisplayKHR* pDisplay); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkAcquireXlibDisplayEXT( + VkPhysicalDevice physicalDevice, + Display* dpy, + VkDisplayKHR display); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetRandROutputDisplayEXT( + VkPhysicalDevice physicalDevice, + Display* dpy, + RROutput rrOutput, + VkDisplayKHR* pDisplay); +#endif +#endif /* VK_USE_PLATFORM_XLIB_XRANDR_EXT */ + +#define VK_EXT_display_surface_counter 1 +#define VK_EXT_DISPLAY_SURFACE_COUNTER_SPEC_VERSION 1 +#define VK_EXT_DISPLAY_SURFACE_COUNTER_EXTENSION_NAME "VK_EXT_display_surface_counter" +#define VK_STRUCTURE_TYPE_SURFACE_CAPABILITIES2_EXT VK_STRUCTURE_TYPE_SURFACE_CAPABILITIES_2_EXT + + +typedef enum VkSurfaceCounterFlagBitsEXT { + VK_SURFACE_COUNTER_VBLANK_EXT = 0x00000001, + VK_SURFACE_COUNTER_FLAG_BITS_MAX_ENUM_EXT = 0x7FFFFFFF +} VkSurfaceCounterFlagBitsEXT; +typedef VkFlags VkSurfaceCounterFlagsEXT; + +typedef struct VkSurfaceCapabilities2EXT { + VkStructureType sType; + void* pNext; + uint32_t minImageCount; + uint32_t maxImageCount; + VkExtent2D currentExtent; + VkExtent2D minImageExtent; + VkExtent2D maxImageExtent; + uint32_t maxImageArrayLayers; + VkSurfaceTransformFlagsKHR supportedTransforms; + VkSurfaceTransformFlagBitsKHR currentTransform; + VkCompositeAlphaFlagsKHR supportedCompositeAlpha; + VkImageUsageFlags supportedUsageFlags; + VkSurfaceCounterFlagsEXT supportedSurfaceCounters; +} VkSurfaceCapabilities2EXT; + + +typedef VkResult (VKAPI_PTR *PFN_vkGetPhysicalDeviceSurfaceCapabilities2EXT)(VkPhysicalDevice physicalDevice, VkSurfaceKHR surface, VkSurfaceCapabilities2EXT* pSurfaceCapabilities); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkGetPhysicalDeviceSurfaceCapabilities2EXT( + VkPhysicalDevice physicalDevice, + VkSurfaceKHR surface, + VkSurfaceCapabilities2EXT* pSurfaceCapabilities); +#endif + +#define VK_EXT_display_control 1 +#define VK_EXT_DISPLAY_CONTROL_SPEC_VERSION 1 +#define VK_EXT_DISPLAY_CONTROL_EXTENSION_NAME "VK_EXT_display_control" + + +typedef enum VkDisplayPowerStateEXT { + VK_DISPLAY_POWER_STATE_OFF_EXT = 0, + VK_DISPLAY_POWER_STATE_SUSPEND_EXT = 1, + VK_DISPLAY_POWER_STATE_ON_EXT = 2, + VK_DISPLAY_POWER_STATE_BEGIN_RANGE_EXT = VK_DISPLAY_POWER_STATE_OFF_EXT, + VK_DISPLAY_POWER_STATE_END_RANGE_EXT = VK_DISPLAY_POWER_STATE_ON_EXT, + VK_DISPLAY_POWER_STATE_RANGE_SIZE_EXT = (VK_DISPLAY_POWER_STATE_ON_EXT - VK_DISPLAY_POWER_STATE_OFF_EXT + 1), + VK_DISPLAY_POWER_STATE_MAX_ENUM_EXT = 0x7FFFFFFF +} VkDisplayPowerStateEXT; + +typedef enum VkDeviceEventTypeEXT { + VK_DEVICE_EVENT_TYPE_DISPLAY_HOTPLUG_EXT = 0, + VK_DEVICE_EVENT_TYPE_BEGIN_RANGE_EXT = VK_DEVICE_EVENT_TYPE_DISPLAY_HOTPLUG_EXT, + VK_DEVICE_EVENT_TYPE_END_RANGE_EXT = VK_DEVICE_EVENT_TYPE_DISPLAY_HOTPLUG_EXT, + VK_DEVICE_EVENT_TYPE_RANGE_SIZE_EXT = (VK_DEVICE_EVENT_TYPE_DISPLAY_HOTPLUG_EXT - VK_DEVICE_EVENT_TYPE_DISPLAY_HOTPLUG_EXT + 1), + VK_DEVICE_EVENT_TYPE_MAX_ENUM_EXT = 0x7FFFFFFF +} VkDeviceEventTypeEXT; + +typedef enum VkDisplayEventTypeEXT { + VK_DISPLAY_EVENT_TYPE_FIRST_PIXEL_OUT_EXT = 0, + VK_DISPLAY_EVENT_TYPE_BEGIN_RANGE_EXT = VK_DISPLAY_EVENT_TYPE_FIRST_PIXEL_OUT_EXT, + VK_DISPLAY_EVENT_TYPE_END_RANGE_EXT = VK_DISPLAY_EVENT_TYPE_FIRST_PIXEL_OUT_EXT, + VK_DISPLAY_EVENT_TYPE_RANGE_SIZE_EXT = (VK_DISPLAY_EVENT_TYPE_FIRST_PIXEL_OUT_EXT - VK_DISPLAY_EVENT_TYPE_FIRST_PIXEL_OUT_EXT + 1), + VK_DISPLAY_EVENT_TYPE_MAX_ENUM_EXT = 0x7FFFFFFF +} VkDisplayEventTypeEXT; + +typedef struct VkDisplayPowerInfoEXT { + VkStructureType sType; + const void* pNext; + VkDisplayPowerStateEXT powerState; +} VkDisplayPowerInfoEXT; + +typedef struct VkDeviceEventInfoEXT { + VkStructureType sType; + const void* pNext; + VkDeviceEventTypeEXT deviceEvent; +} VkDeviceEventInfoEXT; + +typedef struct VkDisplayEventInfoEXT { + VkStructureType sType; + const void* pNext; + VkDisplayEventTypeEXT displayEvent; +} VkDisplayEventInfoEXT; + +typedef struct VkSwapchainCounterCreateInfoEXT { + VkStructureType sType; + const void* pNext; + VkSurfaceCounterFlagsEXT surfaceCounters; +} VkSwapchainCounterCreateInfoEXT; + + +typedef VkResult (VKAPI_PTR *PFN_vkDisplayPowerControlEXT)(VkDevice device, VkDisplayKHR display, const VkDisplayPowerInfoEXT* pDisplayPowerInfo); +typedef VkResult (VKAPI_PTR *PFN_vkRegisterDeviceEventEXT)(VkDevice device, const VkDeviceEventInfoEXT* pDeviceEventInfo, const VkAllocationCallbacks* pAllocator, VkFence* pFence); +typedef VkResult (VKAPI_PTR *PFN_vkRegisterDisplayEventEXT)(VkDevice device, VkDisplayKHR display, const VkDisplayEventInfoEXT* pDisplayEventInfo, const VkAllocationCallbacks* pAllocator, VkFence* pFence); +typedef VkResult (VKAPI_PTR *PFN_vkGetSwapchainCounterEXT)(VkDevice device, VkSwapchainKHR swapchain, VkSurfaceCounterFlagBitsEXT counter, uint64_t* pCounterValue); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkDisplayPowerControlEXT( + VkDevice device, + VkDisplayKHR display, + const VkDisplayPowerInfoEXT* pDisplayPowerInfo); + +VKAPI_ATTR VkResult VKAPI_CALL vkRegisterDeviceEventEXT( + VkDevice device, + const VkDeviceEventInfoEXT* pDeviceEventInfo, + const VkAllocationCallbacks* pAllocator, + VkFence* pFence); + +VKAPI_ATTR VkResult VKAPI_CALL vkRegisterDisplayEventEXT( + VkDevice device, + VkDisplayKHR display, + const VkDisplayEventInfoEXT* pDisplayEventInfo, + const VkAllocationCallbacks* pAllocator, + VkFence* pFence); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetSwapchainCounterEXT( + VkDevice device, + VkSwapchainKHR swapchain, + VkSurfaceCounterFlagBitsEXT counter, + uint64_t* pCounterValue); +#endif + +#define VK_GOOGLE_display_timing 1 +#define VK_GOOGLE_DISPLAY_TIMING_SPEC_VERSION 1 +#define VK_GOOGLE_DISPLAY_TIMING_EXTENSION_NAME "VK_GOOGLE_display_timing" + +typedef struct VkRefreshCycleDurationGOOGLE { + uint64_t refreshDuration; +} VkRefreshCycleDurationGOOGLE; + +typedef struct VkPastPresentationTimingGOOGLE { + uint32_t presentID; + uint64_t desiredPresentTime; + uint64_t actualPresentTime; + uint64_t earliestPresentTime; + uint64_t presentMargin; +} VkPastPresentationTimingGOOGLE; + +typedef struct VkPresentTimeGOOGLE { + uint32_t presentID; + uint64_t desiredPresentTime; +} VkPresentTimeGOOGLE; + +typedef struct VkPresentTimesInfoGOOGLE { + VkStructureType sType; + const void* pNext; + uint32_t swapchainCount; + const VkPresentTimeGOOGLE* pTimes; +} VkPresentTimesInfoGOOGLE; + + +typedef VkResult (VKAPI_PTR *PFN_vkGetRefreshCycleDurationGOOGLE)(VkDevice device, VkSwapchainKHR swapchain, VkRefreshCycleDurationGOOGLE* pDisplayTimingProperties); +typedef VkResult (VKAPI_PTR *PFN_vkGetPastPresentationTimingGOOGLE)(VkDevice device, VkSwapchainKHR swapchain, uint32_t* pPresentationTimingCount, VkPastPresentationTimingGOOGLE* pPresentationTimings); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkGetRefreshCycleDurationGOOGLE( + VkDevice device, + VkSwapchainKHR swapchain, + VkRefreshCycleDurationGOOGLE* pDisplayTimingProperties); + +VKAPI_ATTR VkResult VKAPI_CALL vkGetPastPresentationTimingGOOGLE( + VkDevice device, + VkSwapchainKHR swapchain, + uint32_t* pPresentationTimingCount, + VkPastPresentationTimingGOOGLE* pPresentationTimings); +#endif + +#define VK_NV_sample_mask_override_coverage 1 +#define VK_NV_SAMPLE_MASK_OVERRIDE_COVERAGE_SPEC_VERSION 1 +#define VK_NV_SAMPLE_MASK_OVERRIDE_COVERAGE_EXTENSION_NAME "VK_NV_sample_mask_override_coverage" + + +#define VK_NV_geometry_shader_passthrough 1 +#define VK_NV_GEOMETRY_SHADER_PASSTHROUGH_SPEC_VERSION 1 +#define VK_NV_GEOMETRY_SHADER_PASSTHROUGH_EXTENSION_NAME "VK_NV_geometry_shader_passthrough" + + +#define VK_NV_viewport_array2 1 +#define VK_NV_VIEWPORT_ARRAY2_SPEC_VERSION 1 +#define VK_NV_VIEWPORT_ARRAY2_EXTENSION_NAME "VK_NV_viewport_array2" + + +#define VK_NVX_multiview_per_view_attributes 1 +#define VK_NVX_MULTIVIEW_PER_VIEW_ATTRIBUTES_SPEC_VERSION 1 +#define VK_NVX_MULTIVIEW_PER_VIEW_ATTRIBUTES_EXTENSION_NAME "VK_NVX_multiview_per_view_attributes" + +typedef struct VkPhysicalDeviceMultiviewPerViewAttributesPropertiesNVX { + VkStructureType sType; + void* pNext; + VkBool32 perViewPositionAllComponents; +} VkPhysicalDeviceMultiviewPerViewAttributesPropertiesNVX; + + + +#define VK_NV_viewport_swizzle 1 +#define VK_NV_VIEWPORT_SWIZZLE_SPEC_VERSION 1 +#define VK_NV_VIEWPORT_SWIZZLE_EXTENSION_NAME "VK_NV_viewport_swizzle" + + +typedef enum VkViewportCoordinateSwizzleNV { + VK_VIEWPORT_COORDINATE_SWIZZLE_POSITIVE_X_NV = 0, + VK_VIEWPORT_COORDINATE_SWIZZLE_NEGATIVE_X_NV = 1, + VK_VIEWPORT_COORDINATE_SWIZZLE_POSITIVE_Y_NV = 2, + VK_VIEWPORT_COORDINATE_SWIZZLE_NEGATIVE_Y_NV = 3, + VK_VIEWPORT_COORDINATE_SWIZZLE_POSITIVE_Z_NV = 4, + VK_VIEWPORT_COORDINATE_SWIZZLE_NEGATIVE_Z_NV = 5, + VK_VIEWPORT_COORDINATE_SWIZZLE_POSITIVE_W_NV = 6, + VK_VIEWPORT_COORDINATE_SWIZZLE_NEGATIVE_W_NV = 7, + VK_VIEWPORT_COORDINATE_SWIZZLE_BEGIN_RANGE_NV = VK_VIEWPORT_COORDINATE_SWIZZLE_POSITIVE_X_NV, + VK_VIEWPORT_COORDINATE_SWIZZLE_END_RANGE_NV = VK_VIEWPORT_COORDINATE_SWIZZLE_NEGATIVE_W_NV, + VK_VIEWPORT_COORDINATE_SWIZZLE_RANGE_SIZE_NV = (VK_VIEWPORT_COORDINATE_SWIZZLE_NEGATIVE_W_NV - VK_VIEWPORT_COORDINATE_SWIZZLE_POSITIVE_X_NV + 1), + VK_VIEWPORT_COORDINATE_SWIZZLE_MAX_ENUM_NV = 0x7FFFFFFF +} VkViewportCoordinateSwizzleNV; + +typedef VkFlags VkPipelineViewportSwizzleStateCreateFlagsNV; + +typedef struct VkViewportSwizzleNV { + VkViewportCoordinateSwizzleNV x; + VkViewportCoordinateSwizzleNV y; + VkViewportCoordinateSwizzleNV z; + VkViewportCoordinateSwizzleNV w; +} VkViewportSwizzleNV; + +typedef struct VkPipelineViewportSwizzleStateCreateInfoNV { + VkStructureType sType; + const void* pNext; + VkPipelineViewportSwizzleStateCreateFlagsNV flags; + uint32_t viewportCount; + const VkViewportSwizzleNV* pViewportSwizzles; +} VkPipelineViewportSwizzleStateCreateInfoNV; + + + +#define VK_EXT_discard_rectangles 1 +#define VK_EXT_DISCARD_RECTANGLES_SPEC_VERSION 1 +#define VK_EXT_DISCARD_RECTANGLES_EXTENSION_NAME "VK_EXT_discard_rectangles" + + +typedef enum VkDiscardRectangleModeEXT { + VK_DISCARD_RECTANGLE_MODE_INCLUSIVE_EXT = 0, + VK_DISCARD_RECTANGLE_MODE_EXCLUSIVE_EXT = 1, + VK_DISCARD_RECTANGLE_MODE_BEGIN_RANGE_EXT = VK_DISCARD_RECTANGLE_MODE_INCLUSIVE_EXT, + VK_DISCARD_RECTANGLE_MODE_END_RANGE_EXT = VK_DISCARD_RECTANGLE_MODE_EXCLUSIVE_EXT, + VK_DISCARD_RECTANGLE_MODE_RANGE_SIZE_EXT = (VK_DISCARD_RECTANGLE_MODE_EXCLUSIVE_EXT - VK_DISCARD_RECTANGLE_MODE_INCLUSIVE_EXT + 1), + VK_DISCARD_RECTANGLE_MODE_MAX_ENUM_EXT = 0x7FFFFFFF +} VkDiscardRectangleModeEXT; + +typedef VkFlags VkPipelineDiscardRectangleStateCreateFlagsEXT; + +typedef struct VkPhysicalDeviceDiscardRectanglePropertiesEXT { + VkStructureType sType; + void* pNext; + uint32_t maxDiscardRectangles; +} VkPhysicalDeviceDiscardRectanglePropertiesEXT; + +typedef struct VkPipelineDiscardRectangleStateCreateInfoEXT { + VkStructureType sType; + const void* pNext; + VkPipelineDiscardRectangleStateCreateFlagsEXT flags; + VkDiscardRectangleModeEXT discardRectangleMode; + uint32_t discardRectangleCount; + const VkRect2D* pDiscardRectangles; +} VkPipelineDiscardRectangleStateCreateInfoEXT; + + +typedef void (VKAPI_PTR *PFN_vkCmdSetDiscardRectangleEXT)(VkCommandBuffer commandBuffer, uint32_t firstDiscardRectangle, uint32_t discardRectangleCount, const VkRect2D* pDiscardRectangles); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkCmdSetDiscardRectangleEXT( + VkCommandBuffer commandBuffer, + uint32_t firstDiscardRectangle, + uint32_t discardRectangleCount, + const VkRect2D* pDiscardRectangles); +#endif + +#define VK_EXT_swapchain_colorspace 1 +#define VK_EXT_SWAPCHAIN_COLOR_SPACE_SPEC_VERSION 3 +#define VK_EXT_SWAPCHAIN_COLOR_SPACE_EXTENSION_NAME "VK_EXT_swapchain_colorspace" + + +#define VK_EXT_hdr_metadata 1 +#define VK_EXT_HDR_METADATA_SPEC_VERSION 1 +#define VK_EXT_HDR_METADATA_EXTENSION_NAME "VK_EXT_hdr_metadata" + +typedef struct VkXYColorEXT { + float x; + float y; +} VkXYColorEXT; + +typedef struct VkHdrMetadataEXT { + VkStructureType sType; + const void* pNext; + VkXYColorEXT displayPrimaryRed; + VkXYColorEXT displayPrimaryGreen; + VkXYColorEXT displayPrimaryBlue; + VkXYColorEXT whitePoint; + float maxLuminance; + float minLuminance; + float maxContentLightLevel; + float maxFrameAverageLightLevel; +} VkHdrMetadataEXT; + + +typedef void (VKAPI_PTR *PFN_vkSetHdrMetadataEXT)(VkDevice device, uint32_t swapchainCount, const VkSwapchainKHR* pSwapchains, const VkHdrMetadataEXT* pMetadata); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR void VKAPI_CALL vkSetHdrMetadataEXT( + VkDevice device, + uint32_t swapchainCount, + const VkSwapchainKHR* pSwapchains, + const VkHdrMetadataEXT* pMetadata); +#endif + +#ifdef VK_USE_PLATFORM_IOS_MVK +#define VK_MVK_ios_surface 1 +#define VK_MVK_IOS_SURFACE_SPEC_VERSION 2 +#define VK_MVK_IOS_SURFACE_EXTENSION_NAME "VK_MVK_ios_surface" + +typedef VkFlags VkIOSSurfaceCreateFlagsMVK; + +typedef struct VkIOSSurfaceCreateInfoMVK { + VkStructureType sType; + const void* pNext; + VkIOSSurfaceCreateFlagsMVK flags; + const void* pView; +} VkIOSSurfaceCreateInfoMVK; + + +typedef VkResult (VKAPI_PTR *PFN_vkCreateIOSSurfaceMVK)(VkInstance instance, const VkIOSSurfaceCreateInfoMVK* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkSurfaceKHR* pSurface); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkCreateIOSSurfaceMVK( + VkInstance instance, + const VkIOSSurfaceCreateInfoMVK* pCreateInfo, + const VkAllocationCallbacks* pAllocator, + VkSurfaceKHR* pSurface); +#endif +#endif /* VK_USE_PLATFORM_IOS_MVK */ + +#ifdef VK_USE_PLATFORM_MACOS_MVK +#define VK_MVK_macos_surface 1 +#define VK_MVK_MACOS_SURFACE_SPEC_VERSION 2 +#define VK_MVK_MACOS_SURFACE_EXTENSION_NAME "VK_MVK_macos_surface" + +typedef VkFlags VkMacOSSurfaceCreateFlagsMVK; + +typedef struct VkMacOSSurfaceCreateInfoMVK { + VkStructureType sType; + const void* pNext; + VkMacOSSurfaceCreateFlagsMVK flags; + const void* pView; +} VkMacOSSurfaceCreateInfoMVK; + + +typedef VkResult (VKAPI_PTR *PFN_vkCreateMacOSSurfaceMVK)(VkInstance instance, const VkMacOSSurfaceCreateInfoMVK* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkSurfaceKHR* pSurface); + +#ifndef VK_NO_PROTOTYPES +VKAPI_ATTR VkResult VKAPI_CALL vkCreateMacOSSurfaceMVK( + VkInstance instance, + const VkMacOSSurfaceCreateInfoMVK* pCreateInfo, + const VkAllocationCallbacks* pAllocator, + VkSurfaceKHR* pSurface); +#endif +#endif /* VK_USE_PLATFORM_MACOS_MVK */ + +#define VK_EXT_sampler_filter_minmax 1 +#define VK_EXT_SAMPLER_FILTER_MINMAX_SPEC_VERSION 1 +#define VK_EXT_SAMPLER_FILTER_MINMAX_EXTENSION_NAME "VK_EXT_sampler_filter_minmax" + + +typedef enum VkSamplerReductionModeEXT { + VK_SAMPLER_REDUCTION_MODE_WEIGHTED_AVERAGE_EXT = 0, + VK_SAMPLER_REDUCTION_MODE_MIN_EXT = 1, + VK_SAMPLER_REDUCTION_MODE_MAX_EXT = 2, + VK_SAMPLER_REDUCTION_MODE_BEGIN_RANGE_EXT = VK_SAMPLER_REDUCTION_MODE_WEIGHTED_AVERAGE_EXT, + VK_SAMPLER_REDUCTION_MODE_END_RANGE_EXT = VK_SAMPLER_REDUCTION_MODE_MAX_EXT, + VK_SAMPLER_REDUCTION_MODE_RANGE_SIZE_EXT = (VK_SAMPLER_REDUCTION_MODE_MAX_EXT - VK_SAMPLER_REDUCTION_MODE_WEIGHTED_AVERAGE_EXT + 1), + VK_SAMPLER_REDUCTION_MODE_MAX_ENUM_EXT = 0x7FFFFFFF +} VkSamplerReductionModeEXT; + +typedef struct VkSamplerReductionModeCreateInfoEXT { + VkStructureType sType; + const void* pNext; + VkSamplerReductionModeEXT reductionMode; +} VkSamplerReductionModeCreateInfoEXT; + +typedef struct VkPhysicalDeviceSamplerFilterMinmaxPropertiesEXT { + VkStructureType sType; + void* pNext; + VkBool32 filterMinmaxSingleComponentFormats; + VkBool32 filterMinmaxImageComponentMapping; +} VkPhysicalDeviceSamplerFilterMinmaxPropertiesEXT; + + + +#define VK_AMD_gpu_shader_int16 1 +#define VK_AMD_GPU_SHADER_INT16_SPEC_VERSION 1 +#define VK_AMD_GPU_SHADER_INT16_EXTENSION_NAME "VK_AMD_gpu_shader_int16" + + +#define VK_AMD_mixed_attachment_samples 1 +#define VK_AMD_MIXED_ATTACHMENT_SAMPLES_SPEC_VERSION 1 +#define VK_AMD_MIXED_ATTACHMENT_SAMPLES_EXTENSION_NAME "VK_AMD_mixed_attachment_samples" + + +#define VK_EXT_shader_stencil_export 1 +#define VK_EXT_SHADER_STENCIL_EXPORT_SPEC_VERSION 1 +#define VK_EXT_SHADER_STENCIL_EXPORT_EXTENSION_NAME "VK_EXT_shader_stencil_export" + + +#define VK_EXT_blend_operation_advanced 1 +#define VK_EXT_BLEND_OPERATION_ADVANCED_SPEC_VERSION 2 +#define VK_EXT_BLEND_OPERATION_ADVANCED_EXTENSION_NAME "VK_EXT_blend_operation_advanced" + + +typedef enum VkBlendOverlapEXT { + VK_BLEND_OVERLAP_UNCORRELATED_EXT = 0, + VK_BLEND_OVERLAP_DISJOINT_EXT = 1, + VK_BLEND_OVERLAP_CONJOINT_EXT = 2, + VK_BLEND_OVERLAP_BEGIN_RANGE_EXT = VK_BLEND_OVERLAP_UNCORRELATED_EXT, + VK_BLEND_OVERLAP_END_RANGE_EXT = VK_BLEND_OVERLAP_CONJOINT_EXT, + VK_BLEND_OVERLAP_RANGE_SIZE_EXT = (VK_BLEND_OVERLAP_CONJOINT_EXT - VK_BLEND_OVERLAP_UNCORRELATED_EXT + 1), + VK_BLEND_OVERLAP_MAX_ENUM_EXT = 0x7FFFFFFF +} VkBlendOverlapEXT; + +typedef struct VkPhysicalDeviceBlendOperationAdvancedFeaturesEXT { + VkStructureType sType; + void* pNext; + VkBool32 advancedBlendCoherentOperations; +} VkPhysicalDeviceBlendOperationAdvancedFeaturesEXT; + +typedef struct VkPhysicalDeviceBlendOperationAdvancedPropertiesEXT { + VkStructureType sType; + void* pNext; + uint32_t advancedBlendMaxColorAttachments; + VkBool32 advancedBlendIndependentBlend; + VkBool32 advancedBlendNonPremultipliedSrcColor; + VkBool32 advancedBlendNonPremultipliedDstColor; + VkBool32 advancedBlendCorrelatedOverlap; + VkBool32 advancedBlendAllOperations; +} VkPhysicalDeviceBlendOperationAdvancedPropertiesEXT; + +typedef struct VkPipelineColorBlendAdvancedStateCreateInfoEXT { + VkStructureType sType; + const void* pNext; + VkBool32 srcPremultiplied; + VkBool32 dstPremultiplied; + VkBlendOverlapEXT blendOverlap; +} VkPipelineColorBlendAdvancedStateCreateInfoEXT; + + + +#define VK_NV_fragment_coverage_to_color 1 +#define VK_NV_FRAGMENT_COVERAGE_TO_COLOR_SPEC_VERSION 1 +#define VK_NV_FRAGMENT_COVERAGE_TO_COLOR_EXTENSION_NAME "VK_NV_fragment_coverage_to_color" + +typedef VkFlags VkPipelineCoverageToColorStateCreateFlagsNV; + +typedef struct VkPipelineCoverageToColorStateCreateInfoNV { + VkStructureType sType; + const void* pNext; + VkPipelineCoverageToColorStateCreateFlagsNV flags; + VkBool32 coverageToColorEnable; + uint32_t coverageToColorLocation; +} VkPipelineCoverageToColorStateCreateInfoNV; + + + +#define VK_NV_framebuffer_mixed_samples 1 +#define VK_NV_FRAMEBUFFER_MIXED_SAMPLES_SPEC_VERSION 1 +#define VK_NV_FRAMEBUFFER_MIXED_SAMPLES_EXTENSION_NAME "VK_NV_framebuffer_mixed_samples" + + +typedef enum VkCoverageModulationModeNV { + VK_COVERAGE_MODULATION_MODE_NONE_NV = 0, + VK_COVERAGE_MODULATION_MODE_RGB_NV = 1, + VK_COVERAGE_MODULATION_MODE_ALPHA_NV = 2, + VK_COVERAGE_MODULATION_MODE_RGBA_NV = 3, + VK_COVERAGE_MODULATION_MODE_BEGIN_RANGE_NV = VK_COVERAGE_MODULATION_MODE_NONE_NV, + VK_COVERAGE_MODULATION_MODE_END_RANGE_NV = VK_COVERAGE_MODULATION_MODE_RGBA_NV, + VK_COVERAGE_MODULATION_MODE_RANGE_SIZE_NV = (VK_COVERAGE_MODULATION_MODE_RGBA_NV - VK_COVERAGE_MODULATION_MODE_NONE_NV + 1), + VK_COVERAGE_MODULATION_MODE_MAX_ENUM_NV = 0x7FFFFFFF +} VkCoverageModulationModeNV; + +typedef VkFlags VkPipelineCoverageModulationStateCreateFlagsNV; + +typedef struct VkPipelineCoverageModulationStateCreateInfoNV { + VkStructureType sType; + const void* pNext; + VkPipelineCoverageModulationStateCreateFlagsNV flags; + VkCoverageModulationModeNV coverageModulationMode; + VkBool32 coverageModulationTableEnable; + uint32_t coverageModulationTableCount; + const float* pCoverageModulationTable; +} VkPipelineCoverageModulationStateCreateInfoNV; + + + +#define VK_NV_fill_rectangle 1 +#define VK_NV_FILL_RECTANGLE_SPEC_VERSION 1 +#define VK_NV_FILL_RECTANGLE_EXTENSION_NAME "VK_NV_fill_rectangle" + + +#define VK_EXT_post_depth_coverage 1 +#define VK_EXT_POST_DEPTH_COVERAGE_SPEC_VERSION 1 +#define VK_EXT_POST_DEPTH_COVERAGE_EXTENSION_NAME "VK_EXT_post_depth_coverage" + +#define VK_EXT_shader_viewport_index_layer 1 +#define VK_EXT_SHADER_VIEWPORT_INDEX_LAYER_SPEC_VERSION 1 +#define VK_EXT_SHADER_VIEWPORT_INDEX_LAYER_EXTENSION_NAME "VK_EXT_shader_viewport_index_layer" #ifdef __cplusplus diff --git a/caffe2/mpi/mpi_common.h b/caffe2/mpi/mpi_common.h index fab89edbdd72e..ab04afb869756 100644 --- a/caffe2/mpi/mpi_common.h +++ b/caffe2/mpi/mpi_common.h @@ -34,7 +34,7 @@ MPI_DATATYPE_WRAPPER(double, MPI_DOUBLE) #undef MPI_DATATYPE_WRAPPER // For all Caffe MPI calls, we will wrap it inside an MPI mutex lock guard. -CAFFE2_API std::mutex& MPIMutex(); +TORCH_API std::mutex& MPIMutex(); #define MPI_CHECK(condition) \ do { \ @@ -54,23 +54,23 @@ CAFFE2_API std::mutex& MPIMutex(); * @brief Gets the global MPI communicator used by Caffe2. In default, this * is MPI_COMM_WORLD unless you call SetGlobalMPIComm(). */ -CAFFE2_API MPI_Comm GlobalMPIComm(); +TORCH_API MPI_Comm GlobalMPIComm(); /** * @brief Sets the global MPI communicator. Caffe2 takes over the ownership * of the passed in communicator. */ -CAFFE2_API void SetGlobalMPIComm(MPI_Comm new_comm); +TORCH_API void SetGlobalMPIComm(MPI_Comm new_comm); /** * @brief A helper function to return the size of the given communicator. */ -CAFFE2_API int MPICommSize(MPI_Comm comm); +TORCH_API int MPICommSize(MPI_Comm comm); /** * @brief A helper function to return the rank of the given communicator. */ -CAFFE2_API int MPICommRank(MPI_Comm comm); +TORCH_API int MPICommRank(MPI_Comm comm); /** * @brief A simple wrapper over an MPI common world. diff --git a/caffe2/observers/profile_observer.h b/caffe2/observers/profile_observer.h index 89cd83fb33e78..8f397101c15bc 100644 --- a/caffe2/observers/profile_observer.h +++ b/caffe2/observers/profile_observer.h @@ -46,7 +46,7 @@ class ProfileCounter { float run_time_ = 0.0f; }; -class CAFFE2_API ProfileOperatorObserver final +class TORCH_API ProfileOperatorObserver final : public ProfileCounter, public ObserverBase { public: @@ -94,7 +94,7 @@ class CAFFE2_API ProfileOperatorObserver final void Stop() override; }; -class CAFFE2_API ProfileObserver final : public OperatorAttachingNetObserver< +class TORCH_API ProfileObserver final : public OperatorAttachingNetObserver< ProfileOperatorObserver, ProfileObserver> { public: diff --git a/caffe2/observers/runcnt_observer.h b/caffe2/observers/runcnt_observer.h index 76a0e40e12d7d..93bf4e4eefe55 100644 --- a/caffe2/observers/runcnt_observer.h +++ b/caffe2/observers/runcnt_observer.h @@ -9,7 +9,7 @@ namespace caffe2 { class RunCountNetObserver; -class CAFFE2_API RunCountOperatorObserver final +class TORCH_API RunCountOperatorObserver final : public ObserverBase { public: explicit RunCountOperatorObserver(OperatorBase* op) = delete; @@ -27,7 +27,7 @@ class CAFFE2_API RunCountOperatorObserver final RunCountNetObserver* netObserver_; }; -class CAFFE2_API RunCountNetObserver final +class TORCH_API RunCountNetObserver final : public OperatorAttachingNetObserver< RunCountOperatorObserver, RunCountNetObserver> { diff --git a/caffe2/observers/time_observer.h b/caffe2/observers/time_observer.h index fa54e1f9cd566..84de8efd26c59 100644 --- a/caffe2/observers/time_observer.h +++ b/caffe2/observers/time_observer.h @@ -14,7 +14,7 @@ namespace caffe2 { class TimeObserver; -class CAFFE2_API TimeCounter { +class TORCH_API TimeCounter { public: explicit TimeCounter() {} inline float average_time() const { @@ -28,7 +28,7 @@ class CAFFE2_API TimeCounter { int iterations_ = 0; }; -class CAFFE2_API TimeOperatorObserver final +class TORCH_API TimeOperatorObserver final : public TimeCounter, public ObserverBase { public: @@ -46,7 +46,7 @@ class CAFFE2_API TimeOperatorObserver final void Stop() override; }; -class CAFFE2_API TimeObserver final +class TORCH_API TimeObserver final : public TimeCounter, public OperatorAttachingNetObserver { public: diff --git a/caffe2/onnx/backend.h b/caffe2/onnx/backend.h index 05d56787f4bb5..5fec9dcb28db5 100644 --- a/caffe2/onnx/backend.h +++ b/caffe2/onnx/backend.h @@ -25,7 +25,7 @@ using ::ONNX_NAMESPACE::ValueInfoProto; using ValueInfoMap = std::unordered_map; -class CAFFE2_API ConversionContext { +class TORCH_API ConversionContext { public: ConversionContext(const ValueInfoMap& value_infos, int opset_version) : value_infos_(value_infos), opset_version_(opset_version) {} @@ -44,7 +44,7 @@ class CAFFE2_API ConversionContext { // \brief This struct holds the converted ops after the onnx->c2 conversion. // Notice that for RNN ops, it may create ops in init_net. Hence we have the // `init_ops` field. -struct CAFFE2_API Caffe2Ops { +struct TORCH_API Caffe2Ops { ::google::protobuf::RepeatedPtrField init_ops; ::google::protobuf::RepeatedPtrField ops; ::google::protobuf::RepeatedPtrField interface_blobs; @@ -52,7 +52,7 @@ struct CAFFE2_API Caffe2Ops { // A convenient class to query attributes of a NodeProto. Note that the // NodeProto can not be modified during the query of OnnxAttributes object -class CAFFE2_API OnnxAttributes { +class TORCH_API OnnxAttributes { public: OnnxAttributes(const NodeProto& node); @@ -120,7 +120,7 @@ template <> const TensorProto* OnnxAttributes::get(const std::string& key) const; // convenient class for onnx node -struct CAFFE2_API OnnxNode { +struct TORCH_API OnnxNode { OnnxNode(const NodeProto& node_in) : node(node_in), attributes(node_in) {} const NodeProto& node; @@ -128,7 +128,7 @@ struct CAFFE2_API OnnxNode { OnnxAttributes attributes; }; -class CAFFE2_API Caffe2Backend { +class TORCH_API Caffe2Backend { public: // Since we still have this Python-C++ hybrid flow, we will need to take the // DummyName generator from Python as a pointer. In this case, Python env owns diff --git a/caffe2/onnx/backend_rep.h b/caffe2/onnx/backend_rep.h index eb91ea63b909d..861899532f1d7 100644 --- a/caffe2/onnx/backend_rep.h +++ b/caffe2/onnx/backend_rep.h @@ -9,7 +9,7 @@ namespace caffe2 { namespace onnx { -class CAFFE2_API Caffe2BackendRep { +class TORCH_API Caffe2BackendRep { public: void Run( const caffe2::Predictor::TensorList& inputs, diff --git a/caffe2/onnx/helper.h b/caffe2/onnx/helper.h index c310aa46935d9..5f706b297389e 100644 --- a/caffe2/onnx/helper.h +++ b/caffe2/onnx/helper.h @@ -14,7 +14,7 @@ using ::ONNX_NAMESPACE::AttributeProto; using ::ONNX_NAMESPACE::NodeProto; // \brief This class generates unique dummy names -class CAFFE2_API DummyName { +class TORCH_API DummyName { public: std::string NewDummyName(); @@ -98,7 +98,7 @@ ::ONNX_NAMESPACE::TensorProto MakeTensor( return ret; } -CAFFE2_API NodeProto MakeNode( +TORCH_API NodeProto MakeNode( const std::string& type, const std::vector& inputs, const std::vector& outputs, diff --git a/caffe2/onnx/offline_tensor.h b/caffe2/onnx/offline_tensor.h index 094df7d9b7cf3..9c6b85d17ce56 100644 --- a/caffe2/onnx/offline_tensor.h +++ b/caffe2/onnx/offline_tensor.h @@ -7,7 +7,7 @@ namespace caffe2 { #ifndef C10_MOBILE -struct CAFFE2_API OfflineTensor { +struct TORCH_API OfflineTensor { // A shell tensor to record shape and dtype Tensor shape_tensor{CPU}; diff --git a/caffe2/onnx/onnx_exporter.cc b/caffe2/onnx/onnx_exporter.cc index d7c0d2ffc304c..72e25033466d2 100644 --- a/caffe2/onnx/onnx_exporter.cc +++ b/caffe2/onnx/onnx_exporter.cc @@ -1,5 +1,6 @@ #include "caffe2/onnx/onnx_exporter.h" #include "caffe2/core/logging.h" +#include "caffe2/core/memonger.h" #include "caffe2/core/tensor_impl.h" #include "caffe2/onnx/helper.h" #include "caffe2/proto/caffe2_legacy.pb.h" @@ -141,10 +142,10 @@ void ssaRewriteForIfOp( std::vector if_external_output; std::unordered_set if_inputs, if_outputs; - for (const auto& input: op->input()) { + for (const auto& input : op->input()) { if_inputs.insert(input); } - for (const auto& output: op->output()) { + for (const auto& output : op->output()) { if_outputs.insert(output); } @@ -256,7 +257,6 @@ void revertRenamedExternalOutputForIfOp( } } } - } // namespace ::ONNX_NAMESPACE::TensorProto::DataType Caffe2TypeToOnnxType( @@ -305,7 +305,8 @@ void rewriteSubnet( std::unordered_map SsaRewrite( caffe2::NetDef* init_net, - caffe2::NetDef* pred_net) { + caffe2::NetDef* pred_net, + bool PreserveInPlaceOps) { std::unordered_map input_mapping; std::unordered_map blob_versions; @@ -325,6 +326,9 @@ std::unordered_map SsaRewrite( std::set is_initialized_tensor; if (pred_net) { + // Ssa rewriting modifies the net, check if the net passes schema check + run_schema_check(*pred_net); + std::unordered_set external_outputs; for (const auto& input : pred_net->external_input()) { // Create identical mapping for now. This shall be removed eventually. @@ -353,7 +357,25 @@ std::unordered_map SsaRewrite( } } - for (auto& output : *op.mutable_output()) { + for (int out_idx = 0; out_idx < op.output_size(); out_idx++) { + auto& output = *op.mutable_output(out_idx); + + // restore in-place settings + bool is_inplace = false; + if (PreserveInPlaceOps) { + for (int in_idx = 0; in_idx < op.input_size(); in_idx++) { + auto* schema = OpSchemaRegistry::Schema(op.type()); + if (schema && schema->inplace_enforced(in_idx, out_idx)) { + output = op.input(in_idx); + is_inplace = true; + break; + } + } + } + if (is_inplace) { + continue; + } + auto it = blob_versions.find(output); if (it != blob_versions.end()) { if (op.type() != "If" && op.type() != "AsyncIf") { @@ -408,6 +430,8 @@ std::unordered_map SsaRewrite( } } } + // run schema check again + run_schema_check(*pred_net); return input_mapping; } @@ -844,25 +868,28 @@ ConvertedResult OnnxExporter::CreateConvPoolNodes( const auto& input_size = shapes.at(node.input(0)); const auto& output_size = shapes.at(node.output(0)); CAFFE_ENFORCE_EQ(output_size.dims().size(), 4); - if (!global && // global pool does not care about legacy pad - legacy_pad_attr.i() != static_cast(caffe2::LegacyPadding::NOTSET)) { + if (!global && // global pool does not care about legacy pad + legacy_pad_attr.i() != + static_cast(caffe2::LegacyPadding::NOTSET)) { if (legacy_pad_attr.i() == static_cast(caffe2::LegacyPadding::VALID)) { CAFFE_ENFORCE(!attrs.count("pads")); attrs.emplace("auto_pad", MakeAttribute("auto_pad", "VALID")); - } else if (legacy_pad_attr.i() == + } else if ( + legacy_pad_attr.i() == static_cast(caffe2::LegacyPadding::SAME)) { CAFFE_ENFORCE(!attrs.count("pads")); // default behavior in Caffe2 is SAME_UPPER // https://github.com/caffe2/caffe2/blob/master/caffe2/operators/conv_pool_op_base.h#L39 attrs.emplace("auto_pad", MakeAttribute("auto_pad", "SAME_UPPER")); - } else if (legacy_pad_attr.i() == + } else if ( + legacy_pad_attr.i() == static_cast(caffe2::LegacyPadding::CAFFE_LEGACY_POOLING)) { - // The problem here is that, Pool op in Caffe may add an additional pixel, - // if the last part is smaller than stride. So we use the explicit padding - // to replace legacy_pad. pad[end] = output_size[start + 2] * - // stride[start] - pad[start] - 1 + kernel[start] - input[start + 2] end = - // start + len(pad) / 2 + // The problem here is that, Pool op in Caffe may add an additional + // pixel, if the last part is smaller than stride. So we use the + // explicit padding to replace legacy_pad. pad[end] = output_size[start + // + 2] * stride[start] - pad[start] - 1 + kernel[start] - input[start + + // 2]; end = start + len(pad) / 2 LOG(WARNING) << "Converting legacy padding to explicit padding."; auto* pads_attr = attrs.at("pads").mutable_ints(); auto& strides_attr = attrs.at("strides").ints(); @@ -874,7 +901,8 @@ ConvertedResult OnnxExporter::CreateConvPoolNodes( pads_attr->Set(i + 2, tmp_pad); } } else { - LOG(ERROR) << "Don't know how to handle the legacy_pad:" << legacy_pad_attr.i(); + LOG(ERROR) << "Don't know how to handle the legacy_pad:" + << legacy_pad_attr.i(); CAFFE_THROW("Failed to handle legacy padding in pool operator!"); } } @@ -1002,16 +1030,16 @@ ConvertedResult OnnxExporter::CreateMergeDimNodes( } const auto reshaped = dummy_->NewDummyName(); - nodes.emplace_back(MakeNode("Reshape", - { x, const_tensors.back().name() }, - { reshaped })); + nodes.emplace_back( + MakeNode("Reshape", {x, const_tensors.back().name()}, {reshaped})); - nodes.emplace_back(MakeNode("Squeeze", - { reshaped }, - { y }, - std::vector{ - MakeAttribute("axes", std::vector{ 0 }), - })); + nodes.emplace_back(MakeNode( + "Squeeze", + {reshaped}, + {y}, + std::vector{ + MakeAttribute("axes", std::vector{0}), + })); return result; } @@ -1067,67 +1095,68 @@ ConvertedResult OnnxExporter::CreateChannelShuffleNodes( ConvertedResult OnnxExporter::CreateReduceMeanNodes( const caffe2::OperatorDef& def, const std::unordered_map& shapes) { - CAFFE_ENFORCE_GE(def.input_size(), 1); - CAFFE_ENFORCE_LE(def.input_size(), 2); - CAFFE_ENFORCE_EQ(def.input_size(), 1, "Input \"lengths\" is not supported."); - CAFFE_ENFORCE_GE(def.output_size(), 1); - const auto& x = def.input(0); - const auto& y = def.output(0); - const auto& dims = shapes.at(x).dims(); - - ConvertedResult result; - auto& nodes = result.first; - std::unordered_map args; - for (const auto& a : def.arg()) { - args.emplace(a.name(), &a); - } - - std::vector axes; - int64_t keepdims = 1; + CAFFE_ENFORCE_GE(def.input_size(), 1); + CAFFE_ENFORCE_LE(def.input_size(), 2); + CAFFE_ENFORCE_EQ(def.input_size(), 1, "Input \"lengths\" is not supported."); + CAFFE_ENFORCE_GE(def.output_size(), 1); + const auto& x = def.input(0); + const auto& y = def.output(0); + const auto& dims = shapes.at(x).dims(); - if (def.type() == "ReduceMean") { - // axes - auto it = args.find("axes"); - if (it == args.end()) { - axes.resize(dims.size()); - std::iota(axes.begin(), axes.end(), 0); - } else { - axes.assign(it->second->ints().begin(), it->second->ints().end()); - } + ConvertedResult result; + auto& nodes = result.first; + std::unordered_map args; + for (const auto& a : def.arg()) { + args.emplace(a.name(), &a); + } - // keepdims - it = args.find("keepdims"); - if (it != args.end()) { - keepdims = it->second->i(); - } + std::vector axes; + int64_t keepdims = 1; + + if (def.type() == "ReduceMean") { + // axes + auto it = args.find("axes"); + if (it == args.end()) { + axes.resize(dims.size()); + std::iota(axes.begin(), axes.end(), 0); } else { - // num_reduce_dim - auto it = args.find("num_reduce_dim"); - const int64_t num_reduce_dim = it == args.end() ? 1 : it->second->i(); - CAFFE_ENFORCE_LE(num_reduce_dim, dims.size()); - axes.resize(num_reduce_dim); - - int64_t start_dim = 0; - if (def.type() == "ReduceFrontMean") { - start_dim = 0; - } else if (def.type() == "ReduceBackMean") { - start_dim = dims.size() - axes.size(); - } - std::iota(axes.begin(), axes.end(), start_dim); + axes.assign(it->second->ints().begin(), it->second->ints().end()); + } - keepdims = 0; + // keepdims + it = args.find("keepdims"); + if (it != args.end()) { + keepdims = it->second->i(); + } + } else { + // num_reduce_dim + auto it = args.find("num_reduce_dim"); + const int64_t num_reduce_dim = it == args.end() ? 1 : it->second->i(); + CAFFE_ENFORCE_LE(num_reduce_dim, dims.size()); + axes.resize(num_reduce_dim); + + int64_t start_dim = 0; + if (def.type() == "ReduceFrontMean") { + start_dim = 0; + } else if (def.type() == "ReduceBackMean") { + start_dim = dims.size() - axes.size(); } + std::iota(axes.begin(), axes.end(), start_dim); - nodes.emplace_back(MakeNode("ReduceMean", - { x }, - { y }, - { - MakeAttribute("axes", axes), - MakeAttribute("keepdims", keepdims), - }, - def.name())); + keepdims = 0; + } + + nodes.emplace_back(MakeNode( + "ReduceMean", + {x}, + {y}, + { + MakeAttribute("axes", axes), + MakeAttribute("keepdims", keepdims), + }, + def.name())); - return result; + return result; } ConvertedResult OnnxExporter::CreateUpsampleNodes( @@ -1300,11 +1329,10 @@ ConvertedResult OnnxExporter::CreateGemmNodes( const auto inner = DimProd(x_shape, axis, x_shape.dims().size()); gemm_x_input = dummy_->NewDummyName(); - const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, - std::vector{ -1, inner })); - nodes.emplace_back(MakeNode("Reshape", - { x, const_tensors.back().name() }, - { gemm_x_input })); + const_tensors.emplace_back( + CreateOnnxShapeTensor(dummy_, std::vector{-1, inner})); + nodes.emplace_back( + MakeNode("Reshape", {x, const_tensors.back().name()}, {gemm_x_input})); } it = args.find("axis_w"); @@ -1317,20 +1345,20 @@ ConvertedResult OnnxExporter::CreateGemmNodes( auto outer = DimProd(w_shape, 0, axis_w); auto inner = DimProd(w_shape, axis_w, w_shape.dims().size()); auto reshaped_w = dummy_->NewDummyName(); - const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, - std::vector{ outer, inner })); - nodes.emplace_back(MakeNode("Reshape", - { w, const_tensors.back().name() }, - { reshaped_w })); + const_tensors.emplace_back( + CreateOnnxShapeTensor(dummy_, std::vector{outer, inner})); + nodes.emplace_back( + MakeNode("Reshape", {w, const_tensors.back().name()}, {reshaped_w})); w = reshaped_w; } auto gemm_y_output = axis > 1 ? dummy_->NewDummyName() : y; - nodes.emplace_back(MakeNode("Gemm", - { gemm_x_input, w, b }, - { gemm_y_output }, - { MakeAttribute("transB", 1L) }, - def.name())); + nodes.emplace_back(MakeNode( + "Gemm", + {gemm_x_input, w, b}, + {gemm_y_output}, + {MakeAttribute("transB", 1L)}, + def.name())); // capture the outer shape if needed. if (axis > 1) { @@ -1338,26 +1366,26 @@ ConvertedResult OnnxExporter::CreateGemmNodes( nodes.emplace_back(MakeNode("Shape", {x}, {x_shape})); const auto x_shape_outer = dummy_->NewDummyName(); - nodes.emplace_back(MakeNode("Slice", - { x_shape }, - { x_shape_outer }, - std::vector{ - MakeAttribute("starts", std::vector{ 0 }), - MakeAttribute("ends", std::vector{ axis }), - })); + nodes.emplace_back(MakeNode( + "Slice", + {x_shape}, + {x_shape_outer}, + std::vector{ + MakeAttribute("starts", std::vector{0}), + MakeAttribute("ends", std::vector{axis}), + })); const auto y_shape = dummy_->NewDummyName(); - const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, { -1 })); - nodes.emplace_back(MakeNode("Concat", - { x_shape_outer, const_tensors.back().name() }, - { y_shape }, - std::vector{ - MakeAttribute("axis", static_cast(0)), - })); - - nodes.emplace_back(MakeNode("Reshape", - { gemm_y_output, y_shape }, - { y })); + const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, {-1})); + nodes.emplace_back(MakeNode( + "Concat", + {x_shape_outer, const_tensors.back().name()}, + {y_shape}, + std::vector{ + MakeAttribute("axis", static_cast(0)), + })); + + nodes.emplace_back(MakeNode("Reshape", {gemm_y_output, y_shape}, {y})); } return result; @@ -1374,7 +1402,7 @@ void OnnxExporter::InitOpToTensorProto( const Argument* values = nullptr; const Argument* shape = nullptr; - for (const auto& arg: op.arg()) { + for (const auto& arg : op.arg()) { if (arg.name() == "values") { values = &arg; } else if (arg.name() == "shape") { @@ -1386,7 +1414,7 @@ void OnnxExporter::InitOpToTensorProto( CAFFE_ENFORCE(shape); // Set dims - for (const auto i: shape->ints()) { + for (const auto i : shape->ints()) { tensor->add_dims(i); } diff --git a/caffe2/onnx/onnx_exporter.h b/caffe2/onnx/onnx_exporter.h index 726e65440eecf..c0040e5e3d161 100644 --- a/caffe2/onnx/onnx_exporter.h +++ b/caffe2/onnx/onnx_exporter.h @@ -31,14 +31,15 @@ void rewriteSubnet( // Rewrite Caffe2 nets into SSA forms. Notice that we will preserve the external // output names for predict net. -CAFFE2_API std::unordered_map SsaRewrite( +TORCH_API std::unordered_map SsaRewrite( caffe2::NetDef* init_net, - caffe2::NetDef* pred_net); + caffe2::NetDef* pred_net, + bool PreserveInPlaceOps = true); ::ONNX_NAMESPACE::TensorProto::DataType Caffe2TypeToOnnxType( caffe2::TensorProto::DataType t); -class CAFFE2_API OnnxExporter { +class TORCH_API OnnxExporter { using SpecialOpConverter = ConvertedResult (OnnxExporter::*)( const caffe2::OperatorDef&, const std::unordered_map&); diff --git a/caffe2/onnx/ssa_test.cc b/caffe2/onnx/ssa_test.cc index 5e6553ee8e005..e4902306ab5a1 100644 --- a/caffe2/onnx/ssa_test.cc +++ b/caffe2/onnx/ssa_test.cc @@ -38,7 +38,7 @@ TEST(SsaTest, ConvReluInplace) { EXPECT_EQ("Y", net.external_output(0)); } -TEST(SsaTest, FC_FC_FC_InPlace_Output) { +TEST(SsaTest, FC_Relu_FC_InPlace_Output) { caffe2::NetDef net; auto* op = net.add_op(); op->set_type("FC"); @@ -47,10 +47,8 @@ TEST(SsaTest, FC_FC_FC_InPlace_Output) { op->add_input("b0"); op->add_output("Y"); op = net.add_op(); - op->set_type("FC"); + op->set_type("Relu"); op->add_input("Y"); - op->add_input("W1"); - op->add_input("b1"); op->add_output("Y"); op = net.add_op(); op->set_type("FC"); diff --git a/caffe2/operators/async_net_barrier_op.cc b/caffe2/operators/async_net_barrier_op.cc new file mode 100644 index 0000000000000..25d10e673eac1 --- /dev/null +++ b/caffe2/operators/async_net_barrier_op.cc @@ -0,0 +1,50 @@ +#include "caffe2/operators/async_net_barrier_op.h" + +namespace caffe2 { + +namespace { +std::pair, std::vector> +asyncBarrierOpDevInfer(const OperatorDef& def) { + auto op_device = + def.has_device_option() ? def.device_option() : DeviceOption(); + ArgumentHelper helper(def); + auto cross_device = helper.GetSingleArgument("cross_device", 0); + std::vector opt; + for (int i = 0; i < def.input().size(); ++i) { + if (cross_device == 1) { + DeviceOption dev; + dev.set_device_type(op_device.device_type()); + dev.set_device_id(i); + opt.push_back(dev); + } else { + opt.push_back(op_device); + } + } + return std::make_pair(opt, opt); +} +} + +OPERATOR_SCHEMA(AsyncNetBarrier) + .NumInputs(1, INT_MAX) + .NumOutputs(1, INT_MAX) + .IdenticalTypeAndShape() + .InputsCanCrossDevices() + .AllowOneToOneInplace() + .DeviceInferenceFunction(asyncBarrierOpDevInfer) + .SetDoc(R"DOC( +This is a pretty much no-op operator, since it's only purposes is make sure that +async_scheduling will schedule certian operations earlier than others. + +Exaple where this operator can work well - mixture of data-parallel and model- +parallel training, where one wants to force that all copies are started before +data-parallel part starts. +)DOC") + .Arg( + "cross_device", + "Specifies either inputs should be across different devices in dev inference options"); + +SHOULD_NOT_DO_GRADIENT(AsyncNetBarrier); +REGISTER_CPU_OPERATOR(AsyncNetBarrier, AsyncNetBarrierOp); + + +} // namespace caffe2 diff --git a/caffe2/operators/async_net_barrier_op.cu b/caffe2/operators/async_net_barrier_op.cu new file mode 100644 index 0000000000000..b516c4c14177f --- /dev/null +++ b/caffe2/operators/async_net_barrier_op.cu @@ -0,0 +1,8 @@ +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/async_net_barrier_op.h" + +namespace caffe2 { + +REGISTER_CUDA_OPERATOR(AsyncNetBarrier, AsyncNetBarrierOp); + +} // namespace caffe2 diff --git a/caffe2/operators/async_net_barrier_op.h b/caffe2/operators/async_net_barrier_op.h new file mode 100644 index 0000000000000..9b44db317a7ab --- /dev/null +++ b/caffe2/operators/async_net_barrier_op.h @@ -0,0 +1,30 @@ +#ifndef CAFFE2_OPERATORS_ASYNC_BARRIER_OP_H_ +#define CAFFE2_OPERATORS_ASYNC_BARRIER_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/export_caffe2_op_to_c10.h" +#include "caffe2/core/operator.h" + +namespace caffe2 { + +template +class AsyncNetBarrierOp : public Operator { + public: + USE_OPERATOR_CONTEXT_FUNCTIONS; + USE_SIMPLE_CTOR_DTOR(AsyncNetBarrierOp) + + bool RunOnDevice() override { + // This is a pretty much no-op operator, since it's only purposes is make + // sure that async_scheduling will schedule certian operations earlier than + // others. + // + // Exaple where this operator can work well - mixture of data-parallel and + // model parallel training, where one wants to force that all copies are + // started before data-parallel part starts. + return true; + } +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_ASYNC_BARRIER_OP_H_ diff --git a/caffe2/operators/bbox_transform_op.cc b/caffe2/operators/bbox_transform_op.cc index 2a6e2d2c7bf47..8c17621c1490d 100644 --- a/caffe2/operators/bbox_transform_op.cc +++ b/caffe2/operators/bbox_transform_op.cc @@ -205,24 +205,4 @@ C10_EXPORT_CAFFE2_OP_TO_C10_CPU( "Tensor output_1" ")", BBoxTransformOpFloatCPU); - - C10_EXPORT_CAFFE2_OP_TO_C10_CPU( - BBoxTransform2, - "__caffe2::BBoxTransform(" - "Tensor rois, " - "Tensor deltas, " - "Tensor im_info, " - "float[] weights, " - "bool apply_scale, " - "bool rotated, " - "bool angle_bound_on, " - "int angle_bound_lo, " - "int angle_bound_hi, " - "float clip_angle_thresh, " - "bool legacy_plus_one" - ") -> (" - "Tensor output_0, " - "Tensor output_1" - ")", - BBoxTransformOpFloatCPU); // clang-format on diff --git a/caffe2/operators/box_with_nms_limit_op.cc b/caffe2/operators/box_with_nms_limit_op.cc index f014616615f32..76d7c87dbb5b6 100644 --- a/caffe2/operators/box_with_nms_limit_op.cc +++ b/caffe2/operators/box_with_nms_limit_op.cc @@ -5,8 +5,9 @@ namespace caffe2 { template <> -bool BoxWithNMSLimitOp::RunOnDevice() { - const auto& tscores = Input(0); +template +bool BoxWithNMSLimitOp::DoRunWithType() { +const auto& tscores = Input(0); const auto& tboxes = Input(1); const int box_dim = rotated_ ? 5 : 4; @@ -35,18 +36,19 @@ bool BoxWithNMSLimitOp::RunOnDevice() { int num_boxes_classes = get_box_cls_index(num_classes - 1) + 1; CAFFE_ENFORCE_EQ(num_boxes_classes * box_dim, tboxes.size(1)); + // Default value for batch_size and batch_splits int batch_size = 1; - vector batch_splits_default(1, tscores.size(0)); - const float* batch_splits_data = batch_splits_default.data(); + vector batch_splits_default(1, tscores.size(0)); + const T* batch_splits_data = batch_splits_default.data(); if (InputSize() > 2) { // tscores and tboxes have items from multiple images in a batch. Get the // corresponding batch splits from input. const auto& tbatch_splits = Input(2); CAFFE_ENFORCE_EQ(tbatch_splits.dim(), 1); batch_size = tbatch_splits.size(0); - batch_splits_data = tbatch_splits.data(); + batch_splits_data = tbatch_splits.data(); } - Eigen::Map batch_splits(batch_splits_data, batch_size); + Eigen::Map> batch_splits(batch_splits_data, batch_size); CAFFE_ENFORCE_EQ(batch_splits.sum(), N); auto* out_scores = Output(0, {0}, at::dtype()); @@ -65,7 +67,7 @@ bool BoxWithNMSLimitOp::RunOnDevice() { vector total_keep_per_batch(batch_size); int offset = 0; for (int b = 0; b < batch_splits.size(); ++b) { - int num_boxes = batch_splits(b); + int num_boxes = batch_splits[b]; Eigen::Map scores( tscores.data() + offset * tscores.size(1), num_boxes, @@ -336,31 +338,4 @@ C10_EXPORT_CAFFE2_OP_TO_C10_CPU( ")", caffe2::BoxWithNMSLimitOp); -C10_EXPORT_CAFFE2_OP_TO_C10_CPU( - BoxWithNMSLimit2, - "__caffe2::BoxWithNMSLimit(" - "Tensor scores, " - "Tensor boxes, " - "Tensor batch_splits, " - "float score_thresh, " - "float nms, " - "int detections_per_im, " - "bool soft_nms_enabled, " - "str soft_nms_method, " - "float soft_nms_sigma, " - "float soft_nms_min_score_thres, " - "bool rotated, " - "bool cls_agnostic_bbox_reg, " - "bool input_boxes_include_bg_cls, " - "bool output_classes_include_bg_cls, " - "bool legacy_plus_one " - ") -> (" - "Tensor scores, " - "Tensor boxes, " - "Tensor classes, " - "Tensor batch_splits, " - "Tensor keeps, " - "Tensor keeps_size" - ")", - caffe2::BoxWithNMSLimitOp); // clang-format on diff --git a/caffe2/operators/box_with_nms_limit_op.h b/caffe2/operators/box_with_nms_limit_op.h index 0d61f68f3340c..0527e24424b50 100644 --- a/caffe2/operators/box_with_nms_limit_op.h +++ b/caffe2/operators/box_with_nms_limit_op.h @@ -60,7 +60,16 @@ class BoxWithNMSLimitOp final : public Operator { ~BoxWithNMSLimitOp() {} - bool RunOnDevice() override; + bool RunOnDevice() override { + if (InputSize() > 2) { + return DispatchHelper>::call(this, Input(2)); + } else { + return DoRunWithType(); + } + } + + template + bool DoRunWithType(); protected: // TEST.SCORE_THRESH diff --git a/caffe2/operators/bucketize_op.cu b/caffe2/operators/bucketize_op.cu index 1d48013e771d2..5d3049f239fb5 100644 --- a/caffe2/operators/bucketize_op.cu +++ b/caffe2/operators/bucketize_op.cu @@ -15,7 +15,7 @@ __global__ void BucketizeOpKernel( CUDA_1D_KERNEL_LOOP(i, N) { int32_t low = -1, high = M; while (high - low > 1) { - int32_t median = (high + low) / 2; + const int32_t median = low + (high - low) / 2; if (bounds[median] < X[i]) { low = median; } else { diff --git a/caffe2/operators/channel_shuffle_op.cu b/caffe2/operators/channel_shuffle_op.cu index 2c5a4e5e7f9fb..34f07afb6dd95 100644 --- a/caffe2/operators/channel_shuffle_op.cu +++ b/caffe2/operators/channel_shuffle_op.cu @@ -33,7 +33,7 @@ __global__ void ChannelShuffleNCHWKernel( template __global__ void -ChannelShuffleNHWCKernel(const int G, const int K, const float* X, float* Y) { +ChannelShuffleNHWCKernel(const int G, const int K, const T* X, T* Y) { __shared__ T sdata[kSharedSize]; const int C = G * K; const int offset = blockIdx.x * C; diff --git a/caffe2/operators/counter_ops.h b/caffe2/operators/counter_ops.h index aea013621fa08..bc90f9c933942 100644 --- a/caffe2/operators/counter_ops.h +++ b/caffe2/operators/counter_ops.h @@ -9,7 +9,7 @@ namespace caffe2 { template -class CAFFE2_API Counter { +class TORCH_API Counter { public: explicit Counter(T count) : count_(count) {} bool countDown() { diff --git a/caffe2/operators/create_scope_op.h b/caffe2/operators/create_scope_op.h index b5d75a8c434ca..474b1c105499c 100644 --- a/caffe2/operators/create_scope_op.h +++ b/caffe2/operators/create_scope_op.h @@ -20,7 +20,7 @@ namespace detail { * Keeps track of forward and backward gradient workspaces in stack, * reuses previously created workspaces, non-thread safe */ -class CAFFE2_API WorkspaceStack { +class TORCH_API WorkspaceStack { public: explicit WorkspaceStack() : parent_ws_(nullptr), top_(-1) {} diff --git a/caffe2/operators/cross_entropy_op.h b/caffe2/operators/cross_entropy_op.h index 932ed0d33372d..ec587b70e01cd 100644 --- a/caffe2/operators/cross_entropy_op.h +++ b/caffe2/operators/cross_entropy_op.h @@ -125,7 +125,7 @@ class WeightedSigmoidCrossEntropyWithLogitsGradientOp final }; template -class CAFFE2_API CrossEntropyOp final : public Operator { +class TORCH_API CrossEntropyOp final : public Operator { public: USE_SIMPLE_CTOR_DTOR(CrossEntropyOp); USE_OPERATOR_CONTEXT_FUNCTIONS; @@ -140,7 +140,7 @@ class CAFFE2_API CrossEntropyOp final : public Operator { }; template -class CAFFE2_API CrossEntropyGradientOp final : public Operator { +class TORCH_API CrossEntropyGradientOp final : public Operator { public: USE_SIMPLE_CTOR_DTOR(CrossEntropyGradientOp); USE_OPERATOR_CONTEXT_FUNCTIONS; diff --git a/caffe2/operators/dataset_ops.cc b/caffe2/operators/dataset_ops.cc index ec2e93bc96e88..c311ad23e4ed9 100644 --- a/caffe2/operators/dataset_ops.cc +++ b/caffe2/operators/dataset_ops.cc @@ -407,7 +407,7 @@ class UnPackRecordsOp : public Operator { // Precomputer the output sizes to avoid resizing std::vector> outputDims(numTensors); - std::vector metas(numTensors); + std::vector metas(numTensors); CAFFE_ENFORCE( numRows > 0 || InputSize() > 1, @@ -428,7 +428,7 @@ class UnPackRecordsOp : public Operator { // Checks to ensure that dimensions/sizes match CAFFE_ENFORCE_EQ(outputDims[j].size(), input.dim()); - CAFFE_ENFORCE(*metas[j] == input.dtype()); + CAFFE_ENFORCE(metas[j] == input.dtype()); // We look from first dimension, because we concat on the first. for (int k = 1; k < input.dim(); ++k) { CAFFE_ENFORCE_EQ(input.sizes()[k], outputDims[j][k]); @@ -442,7 +442,7 @@ class UnPackRecordsOp : public Operator { std::vector destinations(numTensors); for (int i = 0; i < numTensors; ++i) { Output(i)->Resize(outputDims[i]); - destinations[i] = Output(i)->raw_mutable_data(*metas[i]); + destinations[i] = Output(i)->raw_mutable_data(metas[i]); } for (int i = 0; i < numRows; ++i) { @@ -450,7 +450,7 @@ class UnPackRecordsOp : public Operator { const auto& input = tensors[i][j]; context_.CopyItemsSameDevice( - *metas[j], + metas[j], input.numel(), input.raw_data() /* src */, destinations[j] /* dst */ @@ -468,7 +468,7 @@ class UnPackRecordsOp : public Operator { void getShapeAndMetaFromInput( const Shared2DTensorVectorPtr& inputs, std::vector>& outputDims, - std::vector& metas) { + std::vector& metas) { const auto& inputZero = inputs->at(0); const auto numTensors = inputZero.size(); @@ -479,13 +479,13 @@ class UnPackRecordsOp : public Operator { for (int i = 0; i < numTensors; ++i) { outputDims[i] = inputZero[i].sizes().vec(); outputDims[i][0] = 0; - metas[i] = &inputZero[i].dtype(); + metas[i] = inputZero[i].dtype(); } } void getShapeAndMetaFromPrototypeBlobs( std::vector>& outputDims, - std::vector& metas) { + std::vector& metas) { const auto numTensors = fields_.size(); CAFFE_ENFORCE_EQ(numTensors, InputSize() - 1); CAFFE_ENFORCE_EQ(numTensors, OutputSize()); @@ -493,7 +493,7 @@ class UnPackRecordsOp : public Operator { const auto& input = Input(i + 1); outputDims[i] = input.sizes().vec(); outputDims[i][0] = 0; - metas[i] = &input.dtype(); + metas[i] = input.dtype(); } } @@ -987,10 +987,9 @@ class CollectTensorOp final : public Operator { // append pos = numVisited_; } else { - auto& gen = context_.RandGenerator(); // uniform between [0, numVisited_] - std::uniform_int_distribution uniformDist(0, numVisited_); - pos = uniformDist(gen); + at::uniform_int_from_to_distribution uniformDist(numVisited_+1, 0); + pos = uniformDist(context_.RandGenerator()); if (pos >= numToCollect_) { // discard pos = -1; diff --git a/caffe2/operators/dataset_ops.h b/caffe2/operators/dataset_ops.h index 70a294e14136f..fc890014dbb24 100644 --- a/caffe2/operators/dataset_ops.h +++ b/caffe2/operators/dataset_ops.h @@ -146,7 +146,7 @@ class TreeWalker { return size; } - inline const TypeMeta& meta() const { + inline const TypeMeta meta() const { return walker_.input(fieldId_).dtype(); } diff --git a/caffe2/operators/dropout_op.cc b/caffe2/operators/dropout_op.cc index 0ddc3f2e4d703..102dd30abc281 100644 --- a/caffe2/operators/dropout_op.cc +++ b/caffe2/operators/dropout_op.cc @@ -17,15 +17,15 @@ bool DropoutOp::RunOnDevice() { float scale = 1. / (1. - ratio_); // mask=true means keep, and mask=false means not keep, so we will // generate probability depending on 1-ratio. - std::bernoulli_distribution dist(1. - ratio_); + at::bernoulli_distribution dist(1. - ratio_); const float* Xdata = X.data(); float* Ydata = Y->template mutable_data(); auto mask = Output(1, X.sizes(), at::dtype()); bool* mask_data = mask->template mutable_data(); - auto& gen = context_.RandGenerator(); + auto* gen = context_.RandGenerator(); for (int i = 0; i < X.numel(); ++i) { - mask_data[i] = dist(gen); + mask_data[i] = dist(gen) > 0.5; Ydata[i] = Xdata[i] * scale * mask_data[i]; } return true; diff --git a/caffe2/operators/elementwise_ops_utils.cc b/caffe2/operators/elementwise_ops_utils.cc index 5bb6c768ea3e0..0f76a1b35aa4d 100644 --- a/caffe2/operators/elementwise_ops_utils.cc +++ b/caffe2/operators/elementwise_ops_utils.cc @@ -53,7 +53,10 @@ std::vector ComputeBinaryBroadcastForwardDims( for (; i >= 0 && j >= 0; --k) { const int A_dim = A_dims[i]; const int B_dim = B_dims[j]; - CAFFE_ENFORCE(A_dim == B_dim || A_dim == 1 || B_dim == 1); + CAFFE_ENFORCE( + A_dim == B_dim || A_dim == 1 || B_dim == 1, + "A_dim: ", A_dim , ",B_dim: ", B_dim + ); if (A_dim == 0 || B_dim == 0) { C_dims[k] = 0; } else { diff --git a/caffe2/operators/elementwise_ops_utils.h b/caffe2/operators/elementwise_ops_utils.h index 93ef4006e9d24..104e7a818ba3f 100644 --- a/caffe2/operators/elementwise_ops_utils.h +++ b/caffe2/operators/elementwise_ops_utils.h @@ -10,20 +10,20 @@ namespace caffe2 { namespace elementwise_ops_utils { -CAFFE2_API std::tuple +TORCH_API std::tuple ComputeLegacyBroadcastSizes(const Tensor& A, const Tensor& B, int axis); -CAFFE2_API std::vector ComputeBinaryBroadcastForwardDims( +TORCH_API std::vector ComputeBinaryBroadcastForwardDims( const std::vector& A_dims, const std::vector& B_dims); -CAFFE2_API void ComputeBinaryBroadcastBackwardAxes( +TORCH_API void ComputeBinaryBroadcastBackwardAxes( const std::vector& A_dims, const std::vector& B_dims, std::vector* A_axes, std::vector* B_axes); -CAFFE2_API void ComputeBinaryBroadcastBackwardDims( +TORCH_API void ComputeBinaryBroadcastBackwardDims( const std::vector& A_dims, const std::vector& B_dims, std::vector* A_back_dims, diff --git a/caffe2/operators/filler_op.h b/caffe2/operators/filler_op.h index a34078fd1c2a3..7e01a01792bf5 100644 --- a/caffe2/operators/filler_op.h +++ b/caffe2/operators/filler_op.h @@ -92,6 +92,7 @@ class FillerOp : public Operator { } shape.insert(shape.end(), extra_shape_.begin(), extra_shape_.end()); output->Resize(shape); + shape_ = shape; } else { output->Resize(shape_); } diff --git a/caffe2/operators/gather_ranges_to_dense_op.h b/caffe2/operators/gather_ranges_to_dense_op.h index 217a61b25129d..a58d9484e5ca7 100644 --- a/caffe2/operators/gather_ranges_to_dense_op.h +++ b/caffe2/operators/gather_ranges_to_dense_op.h @@ -88,11 +88,11 @@ class GatherRangesToDenseOp final : public Operator { CAFFE_ENFORCE_EQ( ranges.size(1), lengths_.size(), - "Nummber of ranges should match number of lengths"); + "Number of ranges should match number of lengths"); CAFFE_ENFORCE_EQ( ranges.size(1), OutputSize(), - "Nummber of ranges should match number of outputs"); + "Number of ranges should match number of outputs"); CAFFE_ENFORCE_EQ( ranges.size(2), 2, "Ranges last dimension should be of size 2"); diff --git a/caffe2/operators/generate_proposals_op.cc b/caffe2/operators/generate_proposals_op.cc index f97854d5372b6..79f85602ecde9 100644 --- a/caffe2/operators/generate_proposals_op.cc +++ b/caffe2/operators/generate_proposals_op.cc @@ -415,25 +415,6 @@ SHOULD_NOT_DO_GRADIENT(GenerateProposalsCPP); } // namespace caffe2 // clang-format off -C10_EXPORT_CAFFE2_OP_TO_C10_CPU( - GenerateProposals2, - "__caffe2::GenerateProposals(" - "Tensor scores, " - "Tensor bbox_deltas, " - "Tensor im_info, " - "Tensor anchors, " - "float spatial_scale, " - "int pre_nms_topN, " - "int post_nms_topN, " - "float nms_thresh, " - "float min_size, " - "bool angle_bound_on, " - "int angle_bound_lo, " - "int angle_bound_hi, " - "float clip_angle_thresh, " - "bool legacy_plus_one" - ") -> (Tensor output_0, Tensor output_1)", - caffe2::GenerateProposalsOp); C10_EXPORT_CAFFE2_OP_TO_C10_CPU( GenerateProposals, "_caffe2::GenerateProposals(" diff --git a/caffe2/operators/generate_proposals_op.h b/caffe2/operators/generate_proposals_op.h index 0b239a3160564..b783b3db437b5 100644 --- a/caffe2/operators/generate_proposals_op.h +++ b/caffe2/operators/generate_proposals_op.h @@ -49,7 +49,7 @@ class ConstTensorView { // anchors: predefined anchors, size(A, 4) // Return: all_anchors_vec: (H * W, A * 4) // Need to reshape to (H * W * A, 4) to match the format in python -CAFFE2_API ERMatXf ComputeAllAnchors( +TORCH_API ERMatXf ComputeAllAnchors( const TensorCPU& anchors, int height, int width, @@ -59,7 +59,7 @@ CAFFE2_API ERMatXf ComputeAllAnchors( // spatial location, only computes anchors for the already sorted and filtered // positions after NMS is applied to avoid unnecessary computation. // `order` is a raveled array of sorted indices in (A, H, W) format. -CAFFE2_API ERArrXXf ComputeSortedAnchors( +TORCH_API ERArrXXf ComputeSortedAnchors( const Eigen::Map& anchors, int height, int width, diff --git a/caffe2/operators/generate_proposals_op_util_nms_gpu.h b/caffe2/operators/generate_proposals_op_util_nms_gpu.h index 10d081f1f38e1..697a1ddc542a9 100644 --- a/caffe2/operators/generate_proposals_op_util_nms_gpu.h +++ b/caffe2/operators/generate_proposals_op_util_nms_gpu.h @@ -23,7 +23,7 @@ namespace utils { // by NMS // Those tensors will be resized to the necessary size // context : current CUDA context -CAFFE2_API void nms_gpu_upright( +TORCH_API void nms_gpu_upright( const float* d_desc_sorted_boxes, const int N, const float thresh, @@ -42,7 +42,7 @@ struct RotatedBox { // d_desc_sorted_boxes : pixel coordinates of proposed bounding boxes // size: (N,5), format: [x_ct; y_ctr; width; height; angle] // the boxes are sorted by scores in descending order -CAFFE2_API void nms_gpu_rotated( +TORCH_API void nms_gpu_rotated( const float* d_desc_sorted_boxes, const int N, const float thresh, @@ -52,7 +52,7 @@ CAFFE2_API void nms_gpu_rotated( TensorCPU& host_delete_mask, CUDAContext* context); -CAFFE2_API void nms_gpu( +TORCH_API void nms_gpu( const float* d_desc_sorted_boxes, const int N, const float thresh, diff --git a/caffe2/operators/heatmap_max_keypoint_op.cc b/caffe2/operators/heatmap_max_keypoint_op.cc index 420bb89d941a8..c4c31e432a083 100644 --- a/caffe2/operators/heatmap_max_keypoint_op.cc +++ b/caffe2/operators/heatmap_max_keypoint_op.cc @@ -172,12 +172,4 @@ C10_EXPORT_CAFFE2_OP_TO_C10_CPU( ") -> Tensor keypoints", HeatmapMaxKeypointOpFloatCPU); -C10_EXPORT_CAFFE2_OP_TO_C10_CPU( - HeatmapMaxKeypoint2, - "__caffe2::HeatmapMaxKeypoint(" - "Tensor heatmaps, " - "Tensor bboxes_in, " - "bool should_output_softmax = True" - ") -> Tensor keypoints", - HeatmapMaxKeypointOpFloatCPU); // clang-format on diff --git a/caffe2/operators/index_ops.h b/caffe2/operators/index_ops.h index 890753caf2fe9..2f5705cb4c263 100644 --- a/caffe2/operators/index_ops.h +++ b/caffe2/operators/index_ops.h @@ -18,7 +18,7 @@ using int64_tValue = int64_t; struct IndexBase { public: - IndexBase(int64_tValue maxElements, const TypeMeta& type) + IndexBase(int64_tValue maxElements, const TypeMeta type) : maxElements_{maxElements}, meta_(type), frozen_{false} {} void Freeze() { @@ -35,7 +35,7 @@ struct IndexBase { virtual ~IndexBase() {} - const TypeMeta& Type() const { + const TypeMeta Type() const { return meta_; } diff --git a/caffe2/operators/last_n_window_collector.cc b/caffe2/operators/last_n_window_collector.cc index 1b141b6500cd2..8b14c834b8d33 100644 --- a/caffe2/operators/last_n_window_collector.cc +++ b/caffe2/operators/last_n_window_collector.cc @@ -142,6 +142,30 @@ OPERATOR_SCHEMA(LastNWindowCollector) .NumInputs({3, 4, 5}) .NumOutputs(2, 3) .EnforceInplace({{0, 0}, {1, 1}, {4, 2}}) + .TensorInferenceFunction([](const OperatorDef& def, + const vector& in) { + auto output_size = def.output_size(); + vector out(output_size); + const ArgumentHelper helper(def); + const auto num_to_collect = + helper.GetSingleArgument("num_to_collect", -1); + + const auto data_dims = GetDimsVector(in[2]); + vector last_n_shape(data_dims.size()); + last_n_shape[0] = num_to_collect; + std::copy(data_dims.begin() + 1, data_dims.end(), last_n_shape.begin() + 1); + out[0] = CreateTensorShape(last_n_shape, in[2].data_type()); + + out[1] = in[1]; + + if (output_size > 2) { + vector num_visited_shape(1); + num_visited_shape[0] = 1; + out[2] = CreateTensorShape(num_visited_shape, TensorProto::INT64); + } + + return out; + }) .SetDoc(R"DOC( Collect the last N rows from input data. The purpose is to keep track of data accross batches, so for example suppose the LastNWindowCollector is called diff --git a/caffe2/operators/lengths_reducer_fused_nbit_rowwise_ops.cc b/caffe2/operators/lengths_reducer_fused_nbit_rowwise_ops.cc index 67b1cddee6135..facf01a8b84df 100644 --- a/caffe2/operators/lengths_reducer_fused_nbit_rowwise_ops.cc +++ b/caffe2/operators/lengths_reducer_fused_nbit_rowwise_ops.cc @@ -1,5 +1,6 @@ #include "caffe2/operators/lengths_reducer_fused_nbit_rowwise_ops.h" #include "c10/util/Registry.h" +#include "caffe2/core/export_caffe2_op_to_c10.h" namespace caffe2 { @@ -611,3 +612,12 @@ fp16 scale and bias), and where rows are pruned. NO_GRADIENT(SparseLengthsMean2BitRowwiseSparse); } // namespace caffe2 + +C10_EXPORT_CAFFE2_OP_TO_C10_CPU( + SparseLengthsSum8BitRowwiseSparse, + "_caffe2::SparseLengthsSum8BitRowwiseSparse(" + "Tensor data, " + "Tensor indices, " + "Tensor lengths, " + "Tensor compressed_indices_mapping) -> Tensor output", + caffe2::SparseLengthsNBitRowwiseSparseOp<8>); diff --git a/caffe2/operators/lengths_reducer_fused_nbit_rowwise_ops.h b/caffe2/operators/lengths_reducer_fused_nbit_rowwise_ops.h index dfdb40a51efac..58be36fb9bc56 100644 --- a/caffe2/operators/lengths_reducer_fused_nbit_rowwise_ops.h +++ b/caffe2/operators/lengths_reducer_fused_nbit_rowwise_ops.h @@ -331,8 +331,9 @@ class SparseLengthsNBitRowwiseSparseOp final : public Operator { !(with_weights && is_mean), "Cannot have with_weights and is_mean a the same time"); - SparseLengthsNBitRowwiseSparseOp(const OperatorDef& def, Workspace* ws) - : Operator(def, ws) {} + template + explicit SparseLengthsNBitRowwiseSparseOp(Args&&... args) + : Operator(std::forward(args)...) {} ~SparseLengthsNBitRowwiseSparseOp() override {} bool RunOnDevice() override { diff --git a/caffe2/operators/load_save_op_util.h b/caffe2/operators/load_save_op_util.h index b99bf73dc49e0..f0978d9fc2d0e 100644 --- a/caffe2/operators/load_save_op_util.h +++ b/caffe2/operators/load_save_op_util.h @@ -26,32 +26,32 @@ struct BlobState { is_tensor(is_tensor) {} }; -CAFFE2_API std::string buildBlobNameFromDbKey( +TORCH_API std::string buildBlobNameFromDbKey( const std::string& dbKey, const std::string& strip_prefix = "", const std::string& add_prefix = ""); // We are tracking sizes of already read tensor parts while reading data // chunks. This way we can make sure that all chunks were loaded in the end. -CAFFE2_API void ProcessBlob( +TORCH_API void ProcessBlob( Blob* blob, const BlobProto& proto, std::unordered_map* blob_states_ptr, const std::string& key, int* loaded_blobs); -CAFFE2_API void prepareBlob( +TORCH_API void prepareBlob( Blob* blob, std::unordered_map* blob_states_ptr, const std::string& key); -CAFFE2_API void updateBlobStates( +TORCH_API void updateBlobStates( const BlobProto& proto, std::unordered_map* blob_states_ptr, const std::string& key, int* loaded_blobs); -CAFFE2_API void validateBlobStates( +TORCH_API void validateBlobStates( const std::unordered_map& blob_states); } // namespace load_save_op_util diff --git a/caffe2/operators/locally_connected_op_util.h b/caffe2/operators/locally_connected_op_util.h index e9eb90035f55c..d1fd77fa055c2 100644 --- a/caffe2/operators/locally_connected_op_util.h +++ b/caffe2/operators/locally_connected_op_util.h @@ -35,7 +35,7 @@ struct CUDAConvNetShapeParams { int Y_W; }; -CAFFE2_API void SetColumnBufferShape( +TORCH_API void SetColumnBufferShape( int N, int kernel_dim, int output_image_size, @@ -46,7 +46,7 @@ CAFFE2_API void SetColumnBufferShape( std::vector* column_transposed_dims, std::vector* column_axes); -CAFFE2_API void SetYBufferShape( +TORCH_API void SetYBufferShape( int N, int M, int output_image_size, diff --git a/caffe2/operators/mean_op.h b/caffe2/operators/mean_op.h index f16914f4a8949..beb0b0440505d 100644 --- a/caffe2/operators/mean_op.h +++ b/caffe2/operators/mean_op.h @@ -65,9 +65,11 @@ class MeanOp final : public Operator { bool RunOnDevice() override { if (Input(0).template IsType()) { return DoRunWithType(); + } else if (Input(0).template IsType()) { + return DoRunWithType(); } else { CAFFE_THROW( - "Mean operator only supports 32-bit float, but", + "Mean operator only supports 32-bit float or 64-bit double, but", " input was of type ", Input(0).dtype().name()); } @@ -111,9 +113,11 @@ class MeanGradientOp : public Operator { bool RunOnDevice() override { if (Input(0).template IsType()) { return DoRunWithType(); + } else if (Input(0).template IsType()) { + return DoRunWithType(); } else { CAFFE_THROW( - "Mean operator only supports 32-bit float, but", + "Mean operator only supports 32-bit float or 64-bit double, but", " input was of type ", Input(0).dtype().name()); } diff --git a/caffe2/operators/mish_op.cc b/caffe2/operators/mish_op.cc index 3a0eb72141a66..12f1433fc0602 100644 --- a/caffe2/operators/mish_op.cc +++ b/caffe2/operators/mish_op.cc @@ -12,39 +12,44 @@ namespace caffe2 { template <> template bool MishFunctor:: -operator()(const int N, const T* X, T* Y, CPUContext* /* context */) const { +operator()(const int N, const T* X, T* Y, CPUContext* context) const { ConstEigenVectorArrayMap X_arr(X, N); - EigenVectorArrayMap(Y, N) = X_arr * (T(1) + X_arr.exp()).log().tanh(); + EigenVectorArrayMap Y_arr(Y, N); + math::Exp(N, X, Y, context); + math::Log1p(N, Y, Y, context); + Y_arr = X_arr * Y_arr.tanh(); return true; } template <> template bool MishGradientOp::DoRunWithType() { - auto& Xin = Input(X); - auto& Yin = Input(Y); - auto& DYin = Input(DY); - - CAFFE_ENFORCE_EQ(Xin.numel(), Yin.numel()); - CAFFE_ENFORCE_EQ(DYin.numel(), Yin.numel()); - auto* DXout = Output(DX, Yin.sizes(), at::dtype()); - - const float* Xdata = Xin.template data(); - const float* Ydata = Yin.template data(); - const float* dYdata = DYin.template data(); - float* dXdata = DXout->template mutable_data(); - - EigenVectorArrayMap dXvec(dXdata, DXout->numel()); - ConstEigenVectorArrayMap Xvec(Xdata, Xin.numel()); - ConstEigenVectorArrayMap Yvec(Ydata, Yin.numel()); - ConstEigenVectorArrayMap dYvec(dYdata, DYin.numel()); - - // w = e^(3x) + 4*e^2x + e^x * (6 + 4x) + 4(1 + x) - // q = (e^x + 1)^2 + 1 - // dX = dY * e^x * w / q^2 - dXvec = dYvec * - (T(4) * (Xvec+T(1)) * (-T(3)*Xvec).exp() + T(4)*(-Xvec).exp() + T(1) + (T(4)*Xvec+T(6))*(-T(2)*Xvec).exp()) / - (T(1) + T(4)*(-Xvec).exp() + T(8)*(-T(2)*Xvec).exp() + T(8)*(-T(3)*Xvec).exp() + T(4)*(-T(4)*Xvec).exp()); + const auto& X = Input(INPUT); + const auto& Y = Input(OUTPUT); + const auto& dY = Input(OUTPUT_GRAD); + + CAFFE_ENFORCE_EQ(X.numel(), Y.numel()); + CAFFE_ENFORCE_EQ(dY.numel(), Y.numel()); + auto* dX = Output(INPUT_GRAD, X.sizes(), at::dtype()); + + const T* X_data = X.template data(); + const T* Y_data = Y.template data(); + const T* dY_data = dY.template data(); + T* dX_data = dX->template mutable_data(); + + const int64_t N = X.numel(); + ConstEigenVectorArrayMap X_arr(X_data, N); + ConstEigenVectorArrayMap Y_arr(Y_data, N); + ConstEigenVectorArrayMap dY_arr(dY_data, N); + EigenVectorArrayMap dX_arr(dX_data, N); + + math::Exp(N, X_data, dX_data, &context_); + math::Log1p(N, dX_data, dX_data, &context_); + math::Tanh(N, dX_data, dX_data, &context_); + dX_arr = dY_arr * + (dX_arr + + X_arr * (T(1) - dX_arr.square()) * T(0.5) * + ((X_arr * T(0.5)).tanh() + T(1))); return true; } @@ -52,7 +57,7 @@ bool MishGradientOp::DoRunWithType() { REGISTER_CPU_OPERATOR( Mish, UnaryElementwiseOp< - TensorTypes, + TensorTypes, CPUContext, MishFunctor>); REGISTER_CPU_OPERATOR(MishGradient, MishGradientOp); @@ -70,11 +75,7 @@ tensor elementwise. .Input(0, "X", "1D input tensor") .Output(0, "Y", "1D output tensor"); // Input: X, Y, dY, output: dX -OPERATOR_SCHEMA(MishGradient) - .NumInputs(3) - .NumOutputs(1) - .AllowInplace({{2, 0}}) - .SetDoc(R"DOC( +OPERATOR_SCHEMA(MishGradient).NumInputs(3).NumOutputs(1).SetDoc(R"DOC( MishGradient takes X, Y and dY and uses this to update dX according to the chain rule and derivatives of the Mish function. )DOC"); diff --git a/caffe2/operators/mish_op.h b/caffe2/operators/mish_op.h index 819f1ff065378..caf5c45a28c8d 100644 --- a/caffe2/operators/mish_op.h +++ b/caffe2/operators/mish_op.h @@ -22,12 +22,12 @@ class MishGradientOp final : public Operator { bool DoRunWithType(); bool RunOnDevice() override { - return DispatchHelper>::call(this, Input(X)); + return DispatchHelper>::call(this, Input(INPUT)); } - protected: - INPUT_TAGS(X, Y, DY); - OUTPUT_TAGS(DX); + private: + INPUT_TAGS(INPUT, OUTPUT, OUTPUT_GRAD); + OUTPUT_TAGS(INPUT_GRAD); }; } // namespace caffe2 diff --git a/caffe2/operators/mod_op.cc b/caffe2/operators/mod_op.cc index 48c1eea5a415b..8faaac51572f5 100644 --- a/caffe2/operators/mod_op.cc +++ b/caffe2/operators/mod_op.cc @@ -25,8 +25,6 @@ bool ModOp::DoRunWithType() { return true; } -namespace { - REGISTER_CPU_OPERATOR(Mod, ModOp); OPERATOR_SCHEMA(Mod) .NumInputs(1) @@ -95,5 +93,4 @@ X after running op: .Output(0, "Y", "*(type: Tensor``)* Output tensor of data with modulo operation applied."); SHOULD_NOT_DO_GRADIENT(ModOp); -} // namespace } // namespace caffe2 diff --git a/caffe2/operators/mod_op.cu b/caffe2/operators/mod_op.cu new file mode 100644 index 0000000000000..90043a29389f4 --- /dev/null +++ b/caffe2/operators/mod_op.cu @@ -0,0 +1,63 @@ +#include "caffe2/operators/mod_op.h" + +#include "caffe2/core/context_gpu.h" + +namespace caffe2 { + +namespace { + +template +__global__ void ModOpSimpleKernel(const int N, const int64_t divisor_, + const T* data_ptr, T* output_ptr) { + CUDA_1D_KERNEL_LOOP(i, N) { + output_ptr[i] = data_ptr[i] % divisor_; + } +} + + +template +__global__ void ModOpKernel(const int N, const int64_t divisor_, + const T* data_ptr, T* output_ptr) { + CUDA_1D_KERNEL_LOOP(i, N) { + output_ptr[i] = data_ptr[i] % divisor_; + if (output_ptr[i] && ((output_ptr[i] > 0) != (divisor_ > 0))) { + output_ptr[i] += divisor_; + } + } +} + +} // namespace + +template <> +template +bool ModOp::DoRunWithType() { + auto& data = Input(DATA); + auto N = data.numel(); + const auto* data_ptr = data.template data(); + + auto* output = Output(0, data.sizes(), at::dtype()); + auto* output_ptr = output->template mutable_data(); + + if (sign_follow_divisor_) { + ModOpKernel<<< + CAFFE_GET_BLOCKS(N), + CAFFE_CUDA_NUM_THREADS, + 0, + context_.cuda_stream()>>>( + N, divisor_, data_ptr, output_ptr); + } else { + ModOpSimpleKernel<<< + CAFFE_GET_BLOCKS(N), + CAFFE_CUDA_NUM_THREADS, + 0, + context_.cuda_stream()>>>( + N, divisor_, data_ptr, output_ptr); + } + + return true; + +} + +REGISTER_CUDA_OPERATOR(Mod, ModOp); + +} // namespace caffe2 diff --git a/caffe2/operators/numpy_tile_op.h b/caffe2/operators/numpy_tile_op.h index 8a39b40df0f82..ac9886ec503ad 100644 --- a/caffe2/operators/numpy_tile_op.h +++ b/caffe2/operators/numpy_tile_op.h @@ -92,7 +92,7 @@ class NumpyTileOp : public Operator { private: void DoTile( - const TypeMeta& meta, + const TypeMeta meta, int item_size, int outer_dim, int inner_dim, diff --git a/caffe2/operators/pad_op.h b/caffe2/operators/pad_op.h index fc138998ecc46..8ba352cf242fd 100644 --- a/caffe2/operators/pad_op.h +++ b/caffe2/operators/pad_op.h @@ -16,7 +16,7 @@ enum class PadMode { EDGE = 2, // pads with the edge values, with string "edge" }; -CAFFE2_API PadMode StringToPadMode(const string&); +TORCH_API PadMode StringToPadMode(const string&); template class PadImageOp final : public ConvPoolOpBase { diff --git a/caffe2/operators/reservoir_sampling.cc b/caffe2/operators/reservoir_sampling.cc index 32378623dbfe2..5ca129e45428d 100644 --- a/caffe2/operators/reservoir_sampling.cc +++ b/caffe2/operators/reservoir_sampling.cc @@ -150,10 +150,9 @@ class ReservoirSamplingOp final : public Operator { // append pos = *num_visited; } else { - auto& gen = context_.RandGenerator(); // uniform between [0, num_visited] - std::uniform_int_distribution uniformDist(0, *num_visited); - pos = uniformDist(gen); + at::uniform_int_from_to_distribution uniformDist(*num_visited+1, 0); + pos = uniformDist(context_.RandGenerator()); if (pos >= numToCollect_) { // discard pos = -1; diff --git a/caffe2/operators/rnn/recurrent_network_executor.h b/caffe2/operators/rnn/recurrent_network_executor.h index 95197ee245332..eecccf7774926 100644 --- a/caffe2/operators/rnn/recurrent_network_executor.h +++ b/caffe2/operators/rnn/recurrent_network_executor.h @@ -476,7 +476,7 @@ std::unique_ptr createRNNExecutor( std::string timestep_blob, ArgumentHelper rnn_args); -class CAFFE2_API ThreadedRecurrentNetworkExecutor : public RecurrentNetworkExecutorBase { +class TORCH_API ThreadedRecurrentNetworkExecutor : public RecurrentNetworkExecutorBase { public: ThreadedRecurrentNetworkExecutor( const NetDef& step_net_def, diff --git a/caffe2/operators/rnn/recurrent_network_op.h b/caffe2/operators/rnn/recurrent_network_op.h index 8484b6813fd28..86b6e4531820d 100644 --- a/caffe2/operators/rnn/recurrent_network_op.h +++ b/caffe2/operators/rnn/recurrent_network_op.h @@ -46,7 +46,7 @@ struct Link { int32_t window{1}; }; -struct CAFFE2_API ScratchWorkspaces { +struct TORCH_API ScratchWorkspaces { std::vector> stepWorkspaces; std::shared_ptr sharedBlobsWs = nullptr; }; @@ -59,7 +59,7 @@ inline void UpdateTimestepBlob(Workspace* ws, std::string blob_name, int t) { t; } -CAFFE2_API std::map GetRecurrentMapping( +TORCH_API std::map GetRecurrentMapping( const std::vector& links, bool backward); @@ -158,15 +158,15 @@ void initializeRecurrentInput( } } -CAFFE2_API void PrependOps(std::vector ops, NetDef* netdef); +TORCH_API void PrependOps(std::vector ops, NetDef* netdef); -CAFFE2_API void AddApplyLinkOps( +TORCH_API void AddApplyLinkOps( const vector& links, std::string timestep, const DeviceOption& device_option, NetDef* netdef); -CAFFE2_API void extractLinks( +TORCH_API void extractLinks( OperatorBase* op, const std::string& internalArg, const std::string& externalArg, @@ -174,7 +174,7 @@ CAFFE2_API void extractLinks( const std::string& windowArg, std::vector* links); -CAFFE2_API NetDef +TORCH_API NetDef extractNetDef(const OperatorDef& op, const std::string& argName); } // namespace detail diff --git a/caffe2/operators/roi_align_gradient_op.cc b/caffe2/operators/roi_align_gradient_op.cc index 7f3b1155e1b38..6a9b2bab0ec39 100644 --- a/caffe2/operators/roi_align_gradient_op.cc +++ b/caffe2/operators/roi_align_gradient_op.cc @@ -191,7 +191,7 @@ void ROIAlignBackwardFeature( } // namespace template <> -bool RoIAlignGradientOp::RunOnDevice() { +C10_EXPORT bool RoIAlignGradientOp::RunOnDevice() { auto& X = Input(0); // Input data to pool auto& R = Input(1); // RoIs auto& dY = Input(2); // Gradient of net w.r.t. output of "forward" op diff --git a/caffe2/operators/roi_align_gradient_op.cu b/caffe2/operators/roi_align_gradient_op.cu index babf06d759eb8..09f56e3269e78 100644 --- a/caffe2/operators/roi_align_gradient_op.cu +++ b/caffe2/operators/roi_align_gradient_op.cu @@ -190,7 +190,7 @@ __global__ void RoIAlignBackwardFeature( } // namespace template <> -bool RoIAlignGradientOp::RunOnDevice() { +C10_EXPORT bool RoIAlignGradientOp::RunOnDevice() { auto& X = Input(0); // Input data to pool auto& R = Input(1); // RoIs auto& dY = Input(2); // Gradient of net w.r.t. output of "forward" op diff --git a/caffe2/operators/roi_align_op.cc b/caffe2/operators/roi_align_op.cc index 997eb1404b2e0..02120137465f1 100644 --- a/caffe2/operators/roi_align_op.cc +++ b/caffe2/operators/roi_align_op.cc @@ -84,7 +84,7 @@ std::vector> MakeBilinearInterpolationParams( } // namespace template <> -bool RoIAlignOp::RunOnDeviceWithOrderNCHW( +C10_EXPORT bool RoIAlignOp::RunOnDeviceWithOrderNCHW( int64_t N, int64_t C, int64_t H, @@ -170,7 +170,7 @@ bool RoIAlignOp::RunOnDeviceWithOrderNCHW( } template <> -bool RoIAlignOp::RunOnDeviceWithOrderNHWC( +C10_EXPORT bool RoIAlignOp::RunOnDeviceWithOrderNHWC( int64_t N, int64_t C, int64_t H, @@ -313,16 +313,3 @@ C10_EXPORT_CAFFE2_OP_TO_C10_CPU( ") -> Tensor", caffe2::RoIAlignCPUOp); -C10_EXPORT_CAFFE2_OP_TO_C10_CPU( - RoIAlign2, - "__caffe2::RoIAlign(" - " Tensor features," - " Tensor rois," - " str order," - " float spatial_scale," - " int pooled_h," - " int pooled_w," - " int sampling_ratio," - " bool aligned" - ") -> Tensor", - caffe2::RoIAlignCPUOp); diff --git a/caffe2/operators/roi_align_op.cu b/caffe2/operators/roi_align_op.cu index 62d7842e2ae3d..4d0edd3a408c1 100644 --- a/caffe2/operators/roi_align_op.cu +++ b/caffe2/operators/roi_align_op.cu @@ -149,7 +149,7 @@ __global__ void RoIAlignForward( } // namespace template <> -bool RoIAlignOp::RunOnDevice() { +C10_EXPORT bool RoIAlignOp::RunOnDevice() { auto& X = Input(0); // Input data to pool auto& R = Input(1); // RoIs // RoI pooled data diff --git a/caffe2/operators/roi_align_rotated_gradient_op.cu b/caffe2/operators/roi_align_rotated_gradient_op.cu index 1ca0b73c72fa3..cc16a828858ff 100644 --- a/caffe2/operators/roi_align_rotated_gradient_op.cu +++ b/caffe2/operators/roi_align_rotated_gradient_op.cu @@ -198,7 +198,7 @@ __global__ void RoIAlignRotatedBackward( } // namespace template <> -bool RoIAlignRotatedGradientOp::RunOnDevice() { +C10_EXPORT bool RoIAlignRotatedGradientOp::RunOnDevice() { auto& X = Input(0); // Input data to pool auto& R = Input(1); // RoIs auto& dY = Input(2); // Gradient of net w.r.t. output of "forward" op diff --git a/caffe2/operators/roi_align_rotated_op.cc b/caffe2/operators/roi_align_rotated_op.cc index c94d0f11bd1fe..5491cc3e597fb 100644 --- a/caffe2/operators/roi_align_rotated_op.cc +++ b/caffe2/operators/roi_align_rotated_op.cc @@ -291,7 +291,7 @@ void ROIAlignRotatedForward( } // namespace template <> -bool RoIAlignRotatedOp::RunOnDevice() { +C10_EXPORT bool RoIAlignRotatedOp::RunOnDevice() { auto& X = Input(0); // Input data to pool auto& R = Input(1); // RoIs @@ -424,17 +424,4 @@ C10_EXPORT_CAFFE2_OP_TO_C10_CPU( ") -> Tensor", RoIAlignRotatedOpFloatCPU); -C10_EXPORT_CAFFE2_OP_TO_C10_CPU( - RoIAlignRotated2, - "__caffe2::RoIAlignRotated(" - "Tensor features, " - "Tensor rois, " - "str order, " - "float spatial_scale, " - "int pooled_h, " - "int pooled_w, " - "int sampling_ratio, " - "bool aligned" - ") -> Tensor", - RoIAlignRotatedOpFloatCPU); // clang-format on diff --git a/caffe2/operators/roi_align_rotated_op.cu b/caffe2/operators/roi_align_rotated_op.cu index 96e4797c597c5..67c1d38f51b4d 100644 --- a/caffe2/operators/roi_align_rotated_op.cu +++ b/caffe2/operators/roi_align_rotated_op.cu @@ -158,7 +158,7 @@ __global__ void RoIAlignRotatedForward( } // namespace template <> -bool RoIAlignRotatedOp::RunOnDevice() { +C10_EXPORT bool RoIAlignRotatedOp::RunOnDevice() { auto& X = Input(0); // Input data to pool auto& R = Input(1); // RoIs diff --git a/caffe2/operators/roi_pool_op.cc b/caffe2/operators/roi_pool_op.cc index 95a6cbfa386c1..d0018b03f4a6a 100644 --- a/caffe2/operators/roi_pool_op.cc +++ b/caffe2/operators/roi_pool_op.cc @@ -8,7 +8,7 @@ using std::max; using std::min; template <> -bool RoIPoolOp::RunOnDevice() { +C10_EXPORT bool RoIPoolOp::RunOnDevice() { const auto& X = Input(0); // Input data to pool const auto& R = Input(1); // RoIs auto* Y = Output(0); // RoI pooled data diff --git a/caffe2/operators/roi_pool_op.cu b/caffe2/operators/roi_pool_op.cu index af479f8a5881b..7c1ef13166230 100644 --- a/caffe2/operators/roi_pool_op.cu +++ b/caffe2/operators/roi_pool_op.cu @@ -167,7 +167,7 @@ bool RoIPoolOp::RunOnDevice() { } template <> -bool RoIPoolGradientOp::RunOnDevice() { +C10_EXPORT bool RoIPoolGradientOp::RunOnDevice() { auto& X = Input(0); // Input data to pool auto& R = Input(1); // RoIs auto& A = Input(2); // argmaxes diff --git a/caffe2/operators/segment_reduction_op.h b/caffe2/operators/segment_reduction_op.h index 55650f036a19c..97618a8e51022 100644 --- a/caffe2/operators/segment_reduction_op.h +++ b/caffe2/operators/segment_reduction_op.h @@ -2006,13 +2006,13 @@ i.e. `len(LENGTHS)`. Other dimensions are inherited from the input tensor. "OUTPUT", "Aggregated output tensor. Has the first dimension of K " "(the number of segments)."); - schema.TensorInferenceFunction( + schema.TensorInferenceFunction(OpSchema::NeedsAllInputShapes( [](const OperatorDef&, const std::vector& input_types) { std::vector out(1); out[0] = input_types[0]; out[0].set_dims(0, input_types[Reducer::kInputCount + 1].dims(0)); return out; - }); + })); ReducerDef::PopulateSchema(schema); schema.CostInferenceFunction( diff --git a/caffe2/operators/self_binning_histogram_op.cc b/caffe2/operators/self_binning_histogram_op.cc index 8cecf0267ea3c..111abd18094cd 100644 --- a/caffe2/operators/self_binning_histogram_op.cc +++ b/caffe2/operators/self_binning_histogram_op.cc @@ -35,7 +35,11 @@ OPERATOR_SCHEMA(SelfBinningHistogram) "logspace_start", "A float that's used as the starting point for logarithmic spacing. " "Since logarithmic spacing cannot contain <=0 values this value will " - "be used to represent all such values."); + "be used to represent all such values.") + .Arg( + "abs", + "Apply abs() on every input value." + ); SHOULD_NOT_DO_GRADIENT(SelfBinningHistogram); } // namespace caffe2 diff --git a/caffe2/operators/self_binning_histogram_op.h b/caffe2/operators/self_binning_histogram_op.h index d29d02b2deb97..6fb6c8f14a085 100644 --- a/caffe2/operators/self_binning_histogram_op.h +++ b/caffe2/operators/self_binning_histogram_op.h @@ -19,7 +19,8 @@ class SelfBinningHistogramOp final : public Operator { bin_spacing_(this->template GetSingleArgument( "bin_spacing", "linear")), - logspace_start_(this->template GetSingleArgument("logspace_start", 1e-24)) + logspace_start_(this->template GetSingleArgument("logspace_start", 1e-24)), + abs_(this->template GetSingleArgument("abs", false)) { CAFFE_ENFORCE_GE( num_bins_, 1, "Number of bins must be greater than or equal to 1."); @@ -64,13 +65,14 @@ class SelfBinningHistogramOp final : public Operator { total_count += N; const auto* x_data = x.template data(); for (int64_t data_idx = 0; data_idx < N; data_idx++) { + const T val = this->abs_ ? abs(x_data[data_idx]) : x_data[data_idx]; if (!first_seen) { - max = x_data[data_idx]; - min = x_data[data_idx]; + max = val; + min = val; first_seen = true; } else { - max = std::max(x_data[data_idx], max); - min = std::min(x_data[data_idx], min); + max = std::max(val, max); + min = std::min(val, min); } } } @@ -130,10 +132,11 @@ class SelfBinningHistogramOp final : public Operator { const int64_t N = x.numel(); const auto* x_data = x.template data(); for (int64_t data_idx = 0; data_idx < N; data_idx++) { + const T val = this->abs_ ? abs(x_data[data_idx]) : x_data[data_idx]; const auto bisection_it = std::upper_bound( histogram_values_data, histogram_values_data + num_edges_, - x_data[data_idx]); + val); const int bisection_idx = bisection_it - histogram_values_data; if (bisection_idx > 0 && bisection_idx < num_edges_) { histogram_counts_data[bisection_idx - 1]++; @@ -156,6 +159,7 @@ class SelfBinningHistogramOp final : public Operator { int num_edges_; std::string bin_spacing_; float logspace_start_; + bool abs_; // automatically apply abs() on the input values void CheckInputs() { const auto& input_zero = Input(0); diff --git a/caffe2/operators/slice_op.cc b/caffe2/operators/slice_op.cc index 7acf854ba9da3..f9fd393032619 100644 --- a/caffe2/operators/slice_op.cc +++ b/caffe2/operators/slice_op.cc @@ -17,7 +17,7 @@ Produces a slice of the input tensor. - Start and end indices are either passed as two 1D input tensors or using the `starts` and `ends` arguments. -- If a negative value is passed for any of the start or end indices, it represents the number of elements before the end of that dimension. End indices are non-inclusive unless negative (end index -1 means up to and including the last element). +- If a negative value is passed for any of the start or end indices, it represents |value| - 1 elements before the end of that dimension. End indices are non-inclusive unless negative (end index -1 means up to and including the last element). Github Links: - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/slice_op.cc @@ -67,11 +67,11 @@ print("Y:", workspace.FetchBlob("Y")) .Input( 1, "starts", - "(*Tensor``*): 1D tensor of start-indices for each dimension of data") + "(*Tensor``*): 1D tensor of start-indices for each dimension of data (dimensions following the sliced one might be omitted)") .Input( 2, "ends", - "(*Tensor``*): 1D tensor of end-indices for each dimension of data") + "(*Tensor``*): 1D tensor of end-indices for each dimension of data (dimensions following the sliced one might be omitted)") .Arg("starts", "(*Tuple(int)*): list of starting indices") .Arg("ends", "(*Tuple(int)*): list of ending indices") .TensorInferenceFunction([](const OperatorDef& def, @@ -90,9 +90,10 @@ print("Y:", workspace.FetchBlob("Y")) for (int i = 0; i < data.dims_size(); ++i) { if (i >= starts.size()) { + dst_sizes[i] = data.dims(i); continue; } - if (data.dims_size() > 0) { + if (data.dims(i) > 0) { auto start = starts[i]; auto end = ends[i]; if (start < 0) { diff --git a/caffe2/operators/slice_op.cu b/caffe2/operators/slice_op.cu index 7a843fee3a527..184385310c9c1 100644 --- a/caffe2/operators/slice_op.cu +++ b/caffe2/operators/slice_op.cu @@ -74,22 +74,23 @@ bool SliceImplGpu( if (i >= starts.numel()) { starts_idx[i] = 0; ends_idx[i] = data.size(i); + dst_sizes[i] = data.size(i); continue; } if (data.size(i) > 0) { auto start = starts_data[i]; auto end = ends_data[i]; if (start < 0) { - start = data.sizes()[i] + 1 + start; + start = data.size(i) + 1 + start; } if (end < 0) { - end = data.sizes()[i] + 1 + end; + end = data.size(i) + 1 + end; } - if (start > data.sizes()[i]) { - start = data.sizes()[i]; + if (start > data.size(i)) { + start = data.size(i); } - if (end > data.sizes()[i]) { - end = data.sizes()[i]; + if (end > data.size(i)) { + end = data.size(i); } CAFFE_ENFORCE_GE(start, 0); CAFFE_ENFORCE_GE(end, 0); @@ -115,7 +116,7 @@ bool SliceImplGpu( // for now only supports slicing in 1 dimension int dim = -1; for (int i = 0; i < data.dim(); ++i) { - if (starts_idx[i] > 0 || ends_idx[i] < data.sizes()[i]) { + if (starts_idx[i] > 0 || ends_idx[i] < data.size(i)) { CAFFE_ENFORCE_EQ( dim, -1, "Currently only possible to slice in 1 dimension."); dim = i; @@ -154,7 +155,7 @@ bool SliceImplGpu( size_t src_nbytes = data.nbytes(); size_t dst_nbytes = output->nbytes(); - size_t src_block_size = unit * data.sizes()[dim]; + size_t src_block_size = unit * data.size(dim); size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]); size_t src_offset = unit * starts_idx[dim]; @@ -187,7 +188,7 @@ bool SliceImplGpu( size_t dst_nbytes = gdata->nbytes(); size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]); - size_t dst_block_size = unit * data.sizes()[dim]; + size_t dst_block_size = unit * data.size(dim); size_t dst_offset = unit * starts_idx[dim]; if (num_blocks == 0 || dst_block_size == 0) { diff --git a/caffe2/operators/slice_op.h b/caffe2/operators/slice_op.h index 8d1990e54c38f..9706472315b61 100644 --- a/caffe2/operators/slice_op.h +++ b/caffe2/operators/slice_op.h @@ -33,23 +33,24 @@ bool SliceImpl( for (int i = 0; i < data.dim(); ++i) { if (i >= starts.numel()) { starts_idx[i] = 0; - ends_idx[i] = data.sizes()[i]; + ends_idx[i] = data.size(i); + dst_sizes[i] = data.size(i); continue; } - if (data.sizes()[i] > 0) { + if (data.size(i) > 0) { auto start = starts_data[i]; auto end = ends_data[i]; if (start < 0) { - start = data.sizes()[i] + 1 + start; + start = data.size(i) + 1 + start; } if (end < 0) { - end = data.sizes()[i] + 1 + end; + end = data.size(i) + 1 + end; } - if (start > data.sizes()[i]) { - start = data.sizes()[i]; + if (start > data.size(i)) { + start = data.size(i); } - if (end > data.sizes()[i]) { - end = data.sizes()[i]; + if (end > data.size(i)) { + end = data.size(i); } CAFFE_ENFORCE_GE(start, 0); CAFFE_ENFORCE_GE(end, 0); @@ -78,7 +79,7 @@ bool SliceImpl( // for now only supports slicing in 1 dimension int dim = -1; for (int i = 0; i < data.dim(); ++i) { - if (starts_idx[i] > 0 || ends_idx[i] < data.sizes()[i]) { + if (starts_idx[i] > 0 || ends_idx[i] < data.size(i)) { CAFFE_ENFORCE_EQ( dim, -1, "Currently only possible to slice in 1 dimension."); dim = i; @@ -117,7 +118,7 @@ bool SliceImpl( size_t src_nbytes = data.nbytes(); size_t dst_nbytes = output->nbytes(); - size_t src_block_size = unit * data.sizes()[dim]; + size_t src_block_size = unit * data.size(dim); size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]); size_t src_offset = unit * starts_idx[dim]; @@ -155,7 +156,7 @@ bool SliceImpl( size_t dst_nbytes = gdata->nbytes(); size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]); - size_t dst_block_size = unit * data.sizes()[dim]; + size_t dst_block_size = unit * data.size(dim); size_t dst_offset = unit * starts_idx[dim]; if (num_blocks == 0 || dst_block_size == 0) { diff --git a/caffe2/operators/sparse_dropout_with_replacement_op.cc b/caffe2/operators/sparse_dropout_with_replacement_op.cc index 109860a788d7d..a9bd46ef27f1e 100644 --- a/caffe2/operators/sparse_dropout_with_replacement_op.cc +++ b/caffe2/operators/sparse_dropout_with_replacement_op.cc @@ -26,12 +26,12 @@ bool SparseDropoutWithReplacementOp::RunOnDevice() { X.numel(), "Inconsistent input data. Number of elements should match total length."); - std::bernoulli_distribution dist(1. - ratio_); - auto& gen = context_.RandGenerator(); + at::bernoulli_distribution dist(1. - ratio_); + auto* gen = context_.RandGenerator(); int32_t total_output_length = 0; vector selected(Lengths.numel(), true); for (int i = 0; i < Lengths.numel(); ++i) { - if (dist(gen)) { + if (dist(gen) > 0.5) { output_lengths_data[i] = input_lengths_data[i]; } else { // Replace with a single dropout value. Even if input length is 0. diff --git a/caffe2/operators/sparse_lp_regularizer_op.h b/caffe2/operators/sparse_lp_regularizer_op.h index 95a33e05f3a43..b2e19655a95b1 100644 --- a/caffe2/operators/sparse_lp_regularizer_op.h +++ b/caffe2/operators/sparse_lp_regularizer_op.h @@ -6,7 +6,7 @@ namespace caffe2 { template -class CAFFE2_API SparseLpRegularizerOp final : public Operator { +class TORCH_API SparseLpRegularizerOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template diff --git a/caffe2/operators/sparse_normalize_op.cc b/caffe2/operators/sparse_normalize_op.cc index 516955c108e1c..85534c0238b3f 100644 --- a/caffe2/operators/sparse_normalize_op.cc +++ b/caffe2/operators/sparse_normalize_op.cc @@ -2,6 +2,12 @@ #include "caffe2/core/tensor.h" #include "caffe2/utils/eigen_utils.h" +#include "caffe2/utils/cpuid.h" + +#ifdef USE_FBGEMM +#include "fbgemm/FbgemmConvert.h" +#endif + namespace caffe2 { template <> @@ -46,6 +52,62 @@ bool SparseNormalizeOp::DoRunWithType() { return true; } +template <> +bool SparseNormalizeOp::RunOnDevice() { + return DispatchHelper>::call( + this, Input(INDICES)); +} + +inline void Float16ToFloat_ref(const at::Half* in, float* out, size_t N) { + for (size_t i = 0; i < N; ++i) { + out[i] = in[i]; + } +} + +template <> +template +bool SparseNormalizeOp::DoRunWithType() { + const auto* indices = Input(INDICES).template data(); + const auto* paramIn = Input(PARAM).template data(); + auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data(); + const float kEps = 1e-12f; + + // n: number of sparse embeddings to be normalized + auto n = Input(INDICES).numel(); + if (n == 0) { + return true; + } + // embedding length, e.g. 32, 64, 128 + auto block_size = Input(PARAM).size_from_dim(1); + vector row_vec_fp32(block_size); + auto out_data = row_vec_fp32.data(); + for (int i = 0; i < n; ++i) { + auto idx = indices[i]; + auto offsetIdx = idx * block_size; +#ifdef USE_FBGEMM + if (GetCpuId().avx2()) { + fbgemm::Float16ToFloat_avx2( + reinterpret_cast(paramIn + offsetIdx), + out_data, + block_size); + } else { + Float16ToFloat_ref(paramIn + offsetIdx, out_data, block_size); + } +#else + Float16ToFloat_ref(paramIn + offsetIdx, out_data, block_size); +#endif + ConstEigenVectorMap xVec_fp32(row_vec_fp32.data(), block_size); + float norm = xVec_fp32.template lpNorm<2>(); + if (use_max_norm_ && norm <= norm_) { + continue; + } + auto Y = paramOut + offsetIdx; + EigenVectorArrayMap(Y, block_size) *= + static_cast(norm_ / (norm + kEps)); + } + return true; +} + REGISTER_CPU_OPERATOR(SparseNormalize, SparseNormalizeOp); OPERATOR_SCHEMA(SparseNormalize) .NumInputs(2, 3) @@ -74,4 +136,33 @@ Given a sparse matrix, apply max_norm or constant_norm sparse regularization. )DOC"); SHOULD_NOT_DO_GRADIENT(SparseNormalize); + +REGISTER_CPU_OPERATOR(Float16SparseNormalize, SparseNormalizeOp); +OPERATOR_SCHEMA(Float16SparseNormalize) + .NumInputs(2, 3) + .NumOutputs(1) + .Input(0, "param", "Parameters to be normalized") + .Input(1, "indices", "Sparse indices") + .Input( + 2, + "grad", + "Gradient computed (optional - not used, this argument is for backwards compatibility)") + .Output(0, "output_param", "Normalized parameters") + .EnforceOneToOneInplace() + .Arg( + "use_max_norm", + "A bool variable to control whether to use max norm \ + or constant norm. When use_max_norm = false, constant norm is used so that \ + all the embedding vectors are scaled to have a L2 norm equals to A \ + (see blow argument norm=A). If use_max_norm = true, \ + max norm is used so that embedding is scaled so that its l2 norm is no larger \ + than A. If an embedding's norm is less than A originally, \ + the embedding is left unchanged.\ + The default is True.") + .Arg("norm", "L2 norm of the embedding. The default is 1.0.") + .SetDoc(R"DOC( +Given a sparse matrix, apply max_norm or constant_norm sparse regularization. +)DOC"); + +SHOULD_NOT_DO_GRADIENT(Float16SparseNormalize); } // namespace caffe2 diff --git a/caffe2/operators/sparse_normalize_op.h b/caffe2/operators/sparse_normalize_op.h index de2fba437c2fe..44434b2ba8b6f 100644 --- a/caffe2/operators/sparse_normalize_op.h +++ b/caffe2/operators/sparse_normalize_op.h @@ -6,7 +6,7 @@ namespace caffe2 { template -class CAFFE2_API SparseNormalizeOp final : public Operator { +class TORCH_API SparseNormalizeOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template diff --git a/caffe2/operators/sparse_to_dense_mask_op.cc b/caffe2/operators/sparse_to_dense_mask_op.cc index d968112c9ecc2..b842d09e068d0 100644 --- a/caffe2/operators/sparse_to_dense_mask_op.cc +++ b/caffe2/operators/sparse_to_dense_mask_op.cc @@ -45,21 +45,21 @@ Convert sparse representations to dense with given indices. Transforms a sparse representation of map represented as `indices` vector and `values` tensor into a compacted tensor where the first dimension -corresponds to each id provided in mask argument. Missing values are filled with -the value of `default_value`. After running this op: +corresponds to each id provided in the mask argument. Missing values are filled +with the value of `default_value`. After running this op: output[j, :] = values[i] // where mask[j] == indices[i] output[j, ...] = default_value // when mask[j] doesn't appear in indices -If `lengths` is provided and not empty, and extra "batch" dimension is prepended +If `lengths` is provided and not empty, an extra "batch" dimension is prepended to the output. -`values` and `default_value` can have additional matching dimensions, operation -is performed on the entire subtensor in thise case. +`values` and `default_value` can have additional matching dimensions +(the operation is performed on the entire subtensor in this case). -For example, if `lengths` is supplied and `values` is 1-D vector of floats and -`default_value` is a float scalar, the output is going to be a float matrix -of size `len(lengths) X len(mask)` +For example, if `lengths` is supplied and `values` is a 1-D vector of floats +and `default_value` is a float scalar, the output is going to be a float +matrix of size `len(lengths) X len(mask)`. )DOC") .Arg( "mask", @@ -67,6 +67,10 @@ of size `len(lengths) X len(mask)` .Arg( "return_presence_mask", "bool whether to return presence mask, false by default") + .Arg( + "max_skipped_indices", + "int argument representing the maximum number of invalid row ids that " + "can be skipped before returning an error. 50 by default") .Input(0, "indices", "1-D int32/int64 tensor of concatenated ids of data") .Input(1, "values", "Data tensor, first dimension has to match `indices`") .Input( @@ -117,3 +121,18 @@ class GetSparseToDenseMaskGradient : public GradientMakerBase { REGISTER_GRADIENT(SparseToDenseMask, GetSparseToDenseMaskGradient); } // namespace } // namespace caffe2 + +// clang-format off +C10_EXPORT_CAFFE2_OP_TO_C10_CPU( + SparseToDenseMask, + "_caffe2::SparseToDenseMask(" + "Tensor indices, " + "Tensor values, " + "Tensor default_value, " + "Tensor? lengths, " + "int[] mask, " + "bool? return_presence_mask = False, " + "int? max_skipped_indices = 50" + ") -> (Tensor output, Tensor presence_mask)", + caffe2::SparseToDenseMaskOp); +// clang-format on diff --git a/caffe2/operators/sparse_to_dense_mask_op.h b/caffe2/operators/sparse_to_dense_mask_op.h index 8ed589c6d734c..26213c0cff33c 100644 --- a/caffe2/operators/sparse_to_dense_mask_op.h +++ b/caffe2/operators/sparse_to_dense_mask_op.h @@ -5,10 +5,13 @@ #include #include #include "caffe2/core/context.h" +#include "caffe2/core/export_caffe2_op_to_c10.h" #include "caffe2/core/operator.h" #include "caffe2/core/tensor.h" #include "caffe2/utils/math.h" +C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(SparseToDenseMask); + namespace caffe2 { template diff --git a/caffe2/operators/string_ops.cc b/caffe2/operators/string_ops.cc index 76fedeb488f83..7339d772f473a 100644 --- a/caffe2/operators/string_ops.cc +++ b/caffe2/operators/string_ops.cc @@ -71,6 +71,17 @@ struct EndsWith { std::string suffix_; }; +struct StrEquals { + explicit StrEquals(OperatorBase& op) + : text_(op.GetSingleArgument("text", "")) {} + bool operator()(const std::string& str) { + return str == text_; + } + + private: + std::string text_; +}; + struct Prefix { explicit Prefix(OperatorBase& op) : length_(op.GetSingleArgument("length", 3)) {} @@ -108,6 +119,9 @@ REGISTER_CPU_OPERATOR( REGISTER_CPU_OPERATOR( StringEndsWith, StringElementwiseOp>); +REGISTER_CPU_OPERATOR( + StringEquals, + StringElementwiseOp>); REGISTER_CPU_OPERATOR(StringJoin, StringJoinOp); OPERATOR_SCHEMA(StringPrefix) @@ -164,6 +178,17 @@ Returns tensor of boolean of the same dimension of input. .Input(0, "strings", "Tensor of std::string.") .Output(0, "bools", "Tensor of bools of same shape as input."); +OPERATOR_SCHEMA(StringEquals) + .NumInputs(1) + .NumOutputs(1) + .SetDoc(R"DOC( +Performs equality check on each string in the input tensor. +Returns tensor of booleans of the same dimension as input. +)DOC") + .Arg("text", "The text to check input strings equality against.") + .Input(0, "strings", "Tensor of std::string.") + .Output(0, "bools", "Tensor of bools of same shape as input."); + OPERATOR_SCHEMA(StringJoin) .NumInputs(1) .NumOutputs(1) @@ -187,6 +212,7 @@ SHOULD_NOT_DO_GRADIENT(StringPrefix); SHOULD_NOT_DO_GRADIENT(StringSuffix); SHOULD_NOT_DO_GRADIENT(StringStartsWith); SHOULD_NOT_DO_GRADIENT(StringEndsWith); +SHOULD_NOT_DO_GRADIENT(StringEquals); SHOULD_NOT_DO_GRADIENT(StringJoin); } } // namespace caffe2 diff --git a/caffe2/operators/text_file_reader_utils.h b/caffe2/operators/text_file_reader_utils.h index 558c73342882a..01b4743a91c14 100644 --- a/caffe2/operators/text_file_reader_utils.h +++ b/caffe2/operators/text_file_reader_utils.h @@ -9,13 +9,13 @@ namespace caffe2 { -struct CAFFE2_API Token { +struct TORCH_API Token { int startDelimId; const char* start; const char* end; }; -class CAFFE2_API TokenizedString { +class TORCH_API TokenizedString { // holder for strings that have been modified std::vector> modifiedStrings_; std::vector tokens_; @@ -31,7 +31,7 @@ class CAFFE2_API TokenizedString { friend class Tokenizer; }; -class CAFFE2_API Tokenizer { +class TORCH_API Tokenizer { private: int startDelimId_; // state of the tokenizer @@ -48,18 +48,18 @@ class CAFFE2_API Tokenizer { void next(char* start, char* end, TokenizedString& tokenized); }; -struct CAFFE2_API CharRange { +struct TORCH_API CharRange { char* start; char* end; }; -struct CAFFE2_API StringProvider { +struct TORCH_API StringProvider { virtual void operator()(CharRange&) = 0; virtual void reset() = 0; virtual ~StringProvider() {} }; -class CAFFE2_API BufferedTokenizer { +class TORCH_API BufferedTokenizer { public: BufferedTokenizer(const Tokenizer& t, StringProvider* p, int numPasses = 1) : provider_(p), tokenizer_(t), tokenIndex_(0), numPasses_(numPasses) {} @@ -104,7 +104,7 @@ class CAFFE2_API BufferedTokenizer { int pass_{0}; }; -class CAFFE2_API FileReader : public StringProvider { +class TORCH_API FileReader : public StringProvider { public: explicit FileReader(const std::string& path, size_t bufferSize = 65536); ~FileReader(); diff --git a/caffe2/operators/tile_op.cc b/caffe2/operators/tile_op.cc index 40684c50575b1..b0d797fce7ff7 100644 --- a/caffe2/operators/tile_op.cc +++ b/caffe2/operators/tile_op.cc @@ -71,7 +71,7 @@ bool TileOp::DoRunWithType() { // size from axis up const int inner_size = X.size_from_dim(axis); - const TypeMeta& meta = X.dtype(); + const TypeMeta meta = X.dtype(); const int item_size = X.itemsize(); const char* X_ptr = reinterpret_cast(X.raw_data()); char* Y_ptr = reinterpret_cast(Y->raw_mutable_data(meta)); diff --git a/caffe2/operators/unsafe_coalesce.cc b/caffe2/operators/unsafe_coalesce.cc new file mode 100644 index 0000000000000..263c96b68572e --- /dev/null +++ b/caffe2/operators/unsafe_coalesce.cc @@ -0,0 +1,27 @@ +#include "caffe2/operators/unsafe_coalesce.h" + +namespace caffe2 { + +OPERATOR_SCHEMA(UnsafeCoalesce) + .NumInputsOutputs([](int inputs, int outputs) { + return inputs + 1 == outputs; + }) + .AllowInplace([](int input, int output) { return input == output; }) + .SetDoc(R"DOC( +Coalesce the N inputs into N outputs and a single coalesced output blob. +This allows operations that operate over multiple small kernels (e.g. +biases in a deep CNN) to be coalesced into a single larger operation, +amortizing the kernel launch overhead, synchronization costs for +distributed computation, etc. +The operator: +- computes the total size of the coalesced blob by summing the input sizes +- allocates the coalesced output blob as the total size +- copies the input vectors into the coalesced blob, at the correct offset. +- aliases each Output(i) to- point into the coalesced blob, at the corresponding offset for Input(i). +This is 'unsafe' as the output vectors are aliased, so use with +caution. +)DOC"); + +REGISTER_CPU_OPERATOR(UnsafeCoalesce, UnsafeCoalesceOp); + +} // namespace caffe2 diff --git a/caffe2/operators/unsafe_coalesce.cu b/caffe2/operators/unsafe_coalesce.cu new file mode 100644 index 0000000000000..234ce69bde3a4 --- /dev/null +++ b/caffe2/operators/unsafe_coalesce.cu @@ -0,0 +1,8 @@ +#include "caffe2/operators/unsafe_coalesce.h" +#include "caffe2/core/context_gpu.h" + +namespace caffe2 { + +REGISTER_CUDA_OPERATOR(UnsafeCoalesce, UnsafeCoalesceOp); + +} diff --git a/caffe2/operators/unsafe_coalesce.h b/caffe2/operators/unsafe_coalesce.h new file mode 100644 index 0000000000000..bb0f58a655589 --- /dev/null +++ b/caffe2/operators/unsafe_coalesce.h @@ -0,0 +1,69 @@ +#ifndef CAFFE2_OPERATORS_UNSAFE_COALESCE_OP_H_ +#define CAFFE2_OPERATORS_UNSAFE_COALESCE_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/export_caffe2_op_to_c10.h" +#include "caffe2/core/operator.h" + + +namespace caffe2 { + +template +class UnsafeCoalesceOp final : public Operator { + public: + USE_OPERATOR_CONTEXT_FUNCTIONS; + using Operator::Operator; + + bool RunOnDevice() override { + size_t coalesced_size = 0; + for (int i = 0; i < InputSize(); ++i) { + // For now only float type is supported + CAFFE_ENFORCE( + Input(i).dtype().template Match(), + "Must only coalesce float type, error at input: ", + i); + } + + for (int i = 0; i < InputSize(); ++i) { + coalesced_size += Input(i).numel(); + } + auto* coalesced = Output(OutputSize() - 1, coalesced_size, at::dtype()); + auto coalesced_data = coalesced->template mutable_data(); + + size_t coalesced_offset = 0; + for (auto i = 0; i < InputSize(); ++i) { + const auto num_elems = Input(i).numel(); + auto input_sizes = Input(i).sizes().vec(); + // Don't do anything if both tensors are already pointing on the same data + auto input_data = Input(i).template data(); + if (input_data != coalesced_data + coalesced_offset) { + // Make sure that we don't run operation on the same tensor + CAFFE_ENFORCE_NE( + input_data - Input(i).unsafeGetTensorImpl()->storage_offset(), + coalesced_data - + Output(OutputSize() - 1) + ->unsafeGetTensorImpl() + ->storage_offset(), + "Tensors used in UnsafeCoalesce operator cannot share storage, unless it's inplace operation"); + context_.CopyItemsSameDevice( + Input(i).dtype(), + num_elems, + input_data, + coalesced_data + coalesced_offset); + + // Note: this could cause Input(i) to free it's data if + // Output(i) and Input(i) alias each other. This is safe on a + // GPU (as the copy will happen-before the free), but it's + // worth mentioning. + OperatorBase::SetOutputTensor(i, coalesced->Alias()); + Output(i)->unsafeGetTensorImpl()->set_storage_offset(coalesced_offset); + Output(i)->Resize(input_sizes); + } + coalesced_offset += num_elems; + } + return true; + } +}; +} // namespace caffe2 + +#endif /* CAFFE2_OPERATORS_UNSAFE_COALESCE_OP_H_ */ diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc index 3d3f133b1de7a..9abcf5ab0b86f 100644 --- a/caffe2/operators/utility_ops.cc +++ b/caffe2/operators/utility_ops.cc @@ -59,6 +59,7 @@ REGISTER_CPU_OPERATOR(GatherRanges, GatherRangesOp); REGISTER_CPU_OPERATOR(LengthsGather, LengthsGatherOp); REGISTER_CPU_OPERATOR(LengthsToSegmentIds, LengthsToSegmentIdsOp); REGISTER_CPU_OPERATOR(LengthsToRanges, LengthsToRangesOp); +REGISTER_CPU_OPERATOR(LengthsToOffsets, LengthsToOffsetsOp); REGISTER_CPU_OPERATOR(SegmentIdsToLengths, SegmentIdsToLengthsOp); REGISTER_CPU_OPERATOR(SegmentIdsToRanges, SegmentIdsToRangesOp); REGISTER_CPU_OPERATOR(LengthsToWeights, LengthsToWeightsOp); @@ -522,20 +523,20 @@ Another output LENGTHS represents each example length within OUTPUT "LENGTHS", "1-D tensor of size N with lengths over gathered data" " for each row in a batch. sum(LENGTHS) == OUTPUT.size()") - .TensorInferenceFunction([](const OperatorDef& /* unused */, - const vector& in) { - std::vector out(2); - - int total = 1; - for (auto d : in[0].dims()) { - total *= d; - } - out[0].add_dims(total); - out[0].set_data_type(in[0].data_type()); - out[1].add_dims(in[1].dims(0)); - out[1].set_data_type(in[1].data_type()); - return out; - }); + .TensorInferenceFunction(OpSchema::NeedsAllInputShapes( + [](const OperatorDef& /* unused */, const vector& in) { + std::vector out(2); + + int total = 1; + for (auto d : in[0].dims()) { + total *= d; + } + out[0].add_dims(total); + out[0].set_data_type(in[0].data_type()); + out[1].add_dims(in[1].dims(0)); + out[1].set_data_type(in[1].data_type()); + return out; + })); OPERATOR_SCHEMA(LengthsGather) .NumInputs(3) @@ -636,6 +637,30 @@ For example, `[1, 3, 0, 2]` transforms into `[[0, 1], [1, 3], [4, 0], [4, 2]]`. "ranges", "2D tensor of shape len(lengths) X 2 and the same type as `lengths`"); +OPERATOR_SCHEMA(LengthsToOffsets) + .NumInputs(1) + .NumOutputs(1) + .SetDoc(R"DOC( +Given a vector of segment lengths, returns a vector of offsets from these lengths, +which will have the same size as the input vector. Output is going to have +the same type as input. For long tensors explicit casting from int32 to int64 +might be necessary prior to this op. + +For example, `[1, 3, 0, 2]` transforms into `[0, 1, 4, 4]`. +)DOC") + .Input(0, "lengths", "1D tensor of int32 or int64 segment lengths.") + .Output(0, "offsets", "1D tensor of the same shape and type as `lengths`") + .TensorInferenceFunction([](const OperatorDef& def, + const vector& in) { + const ArgumentHelper args(def); + bool include_last_offset = + args.GetSingleArgument("include_last_offset", false); + vector out_shape(in[0].dims().begin(), in[0].dims().end()); + out_shape[0] += include_last_offset ? 1 : 0; + return vector{ + CreateTensorShape(out_shape, in[0].data_type())}; + }); + OPERATOR_SCHEMA(SegmentIdsToLengths) .NumInputs(1, 2) .NumOutputs(1) diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h index a82b5666fb7b2..bdc9c0bfbfd9d 100644 --- a/caffe2/operators/utility_ops.h +++ b/caffe2/operators/utility_ops.h @@ -918,6 +918,45 @@ class LengthsToRangesOp : public Operator { } }; +template +class LengthsToOffsetsOp : public Operator { + public: + USE_OPERATOR_CONTEXT_FUNCTIONS; + + template + explicit LengthsToOffsetsOp(Args&&... args) + : Operator(std::forward(args)...), + include_last_offset_(this->template GetSingleArgument( + "include_last_offset", + false)) {} + + bool RunOnDevice() override { + auto& input = Input(0); + auto* output = Output(0); + auto* input_data = input.template data(); + + CAFFE_ENFORCE(input.sizes().size() == 1, "Input must be a vector."); + auto size = input.numel(); + + output->Resize(size + (include_last_offset_ ? 1 : 0)); + auto* output_data = output->template mutable_data(); + + int32_t offset = 0; + for (int i = 0; i < size; ++i) { + auto len = input_data[i]; + output_data[i] = offset; + offset += len; + } + if (include_last_offset_) { + output_data[size] = offset; + } + return true; + } + + private: + bool include_last_offset_; +}; + template class SegmentIdsToLengthsOp : public Operator { public: diff --git a/caffe2/opt/annotations.h b/caffe2/opt/annotations.h index 9bc1f1e313764..89ff7c38a438e 100644 --- a/caffe2/opt/annotations.h +++ b/caffe2/opt/annotations.h @@ -7,7 +7,7 @@ namespace caffe2 { -class CAFFE2_API Caffe2Annotation : public nom::repr::Annotation { +class TORCH_API Caffe2Annotation : public nom::repr::Annotation { public: Caffe2Annotation() : Annotation(AnnotationKind::Caffe2) {} Caffe2Annotation(std::string device) diff --git a/caffe2/opt/backend_cutting.cc b/caffe2/opt/backend_cutting.cc index e1f7808d48b94..45f46ab483308 100644 --- a/caffe2/opt/backend_cutting.cc +++ b/caffe2/opt/backend_cutting.cc @@ -352,9 +352,13 @@ void DumpGraph(NNGraph* g, const std::string& fname) { }; std::ofstream out(fname.c_str()); - out << nom::converters::convertToDotString(g, nnprinter); - out.close(); + if (out) { + out << nom::converters::convertToDotString(g, nnprinter); + } else { + LOG(ERROR) << "Cannot create nomnigraph dump file: " << fname; + } } + caffe2::NetDef OptimizeForBackend( caffe2::NetDef& net, std::function supports, diff --git a/caffe2/opt/backend_cutting.h b/caffe2/opt/backend_cutting.h index c4c0a68575a21..5b4df14db2e94 100644 --- a/caffe2/opt/backend_cutting.h +++ b/caffe2/opt/backend_cutting.h @@ -8,8 +8,8 @@ namespace caffe2 { namespace opt { -CAFFE2_API void DumpGraph(nom::repr::NNGraph* g, const std::string& fname); -CAFFE2_API caffe2::NetDef OptimizeForBackend( +TORCH_API void DumpGraph(nom::repr::NNGraph* g, const std::string& fname); +TORCH_API caffe2::NetDef OptimizeForBackend( caffe2::NetDef& net, std::function supports, std::function transform_func, diff --git a/caffe2/opt/backend_transformer_base.cc b/caffe2/opt/backend_transformer_base.cc index 7bb27fca92abf..9090e0b5277b5 100644 --- a/caffe2/opt/backend_transformer_base.cc +++ b/caffe2/opt/backend_transformer_base.cc @@ -177,6 +177,6 @@ void BackendTransformerBase::dumpNet( const std::string& fname) const { NetDef shape_net(pred_net); addShapeToNet(shape_net, shape_hints); - WriteProtoToTextFile(shape_net, fname); + WriteProtoToTextFile(shape_net, fname, false); } } // namespace caffe2 diff --git a/caffe2/opt/bound_shape_inference_test.cc b/caffe2/opt/bound_shape_inference_test.cc index 95302ca5ccc40..f9c9b6acf034b 100644 --- a/caffe2/opt/bound_shape_inference_test.cc +++ b/caffe2/opt/bound_shape_inference_test.cc @@ -270,6 +270,30 @@ TEST(BoundShapeInference, LengthsRangeFill) { TensorProto_DataType_INT32); } + +TEST(BoundShapeInference, ConstantFill) { + NetDef net; + net.add_op()->CopyFrom( + CreateOperatorDef("ConstantFill", "", {"X"}, {"Y"}, {})); + ShapeInfoMap shape_map; + BoundShapeSpec spec(20, 1000); + BoundShapeInferencer eng(spec); + shape_map.emplace( + "X", + makeTensorInfo( + {TensorBoundShape_DimType_BATCH, + TensorBoundShape_DimType_CONSTANT}, + {20, 1024})); + eng.InferBoundShapeAndType(net, shape_map, nullptr); + const auto& out_shape = eng.shape_info(); + verifyShapeInfo( + out_shape, + "Y", + {TensorBoundShape_DimType_BATCH, TensorBoundShape_DimType_CONSTANT}, + {20, 1024}, + TensorProto_DataType_FLOAT); +} + // https://github.com/pytorch/pytorch/issues/40861 TEST(BoundShapeInference, DISABLED_ON_WINDOWS(Reshape)) { NetDef net; diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc index d8fe956a0ddd3..8ef5de06b02e7 100644 --- a/caffe2/opt/bound_shape_inferencer.cc +++ b/caffe2/opt/bound_shape_inferencer.cc @@ -234,7 +234,8 @@ TensorShape& BoundShapeInferencer::CheckAndSetTensorBoundShape( bool is_quantized, bool allow_existing_shape, float scale, - int offset) { + int offset, + bool in_place_op) { auto rt = shape_info_.emplace(name, ShapeInfo()); ShapeInfo& shape_info = rt.first->second; TensorShape& shape = shape_info.shape; @@ -246,8 +247,8 @@ TensorShape& BoundShapeInferencer::CheckAndSetTensorBoundShape( shape_info.q_info.offset.push_back(offset); shape_info.q_info.axis = 1; } - // If the shape information exists in shape_info_ already - if (!rt.second) { + // If the shape information exists in shape_info_ already and we want to compare old/new shapes + if (!rt.second && !in_place_op) { // Check dim size consistency CAFFE_ENFORCE_EQ( shape.dims_size(), @@ -290,13 +291,19 @@ TensorShape& BoundShapeInferencer::CheckAndSetTensorBoundShape( return shape; } // If shape information does not exist in shape_info_, + // or shape info is not final, // set shape info according to inputs. - shape_info.setDimType(t); - shape.mutable_dims()->Clear(); - for (const auto d : bound_dims) { - shape.add_dims(d); + if (!shape_info.getShapeIsFinal()) { + shape_info.setDimType(t); + shape.mutable_dims()->Clear(); + for (const auto d : bound_dims) { + shape.add_dims(d); + } + shape.set_data_type(type); + if (in_place_op) { + shape_info.setShapeIsFinal(true); + } } - shape.set_data_type(type); return shape; } @@ -315,6 +322,12 @@ void BoundShapeInferencer::InferGivenTensorFill(const OperatorDef& op) { if (it != shape_info_.end()) { it->second.setDimType(std::vector( it->second.shape.dims_size(), TensorBoundShape_DimType_CONSTANT)); + if (op.type() == "ConstantFill" && op.input_size() >= 1) { + auto it_input = shape_info_.find(op.input(0)); + if (it_input != shape_info_.end()) { + it->second.setDimType(it_input->second.getDimType()); + } + } } } @@ -851,7 +864,12 @@ void BoundShapeInferencer::InferTile(const OperatorDef& op) { false); } -void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) { +void BoundShapeInferencer::InferCommonOp( + const OperatorDef& op, + const OpSchema* schema, + bool bypass_input_check, + bool in_place_op +) { // First, we need to check that all the input shape/types are already // presented try { @@ -859,25 +877,30 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) { types_with_independent_output_shape = {"Int8GenQuantParams", "Int8QuantSchemeBlobFill", "ComputeEqualizationScale"}; + const static std::unordered_set + pruning_ops = {"RowwisePruneI64", "RowwisePruneI32"}; std::vector input_shapes; for (const auto& input : op.input()) { const auto it = shape_info_.find(input); if (it == shape_info_.end() && - !types_with_independent_output_shape.count(op.type())) { + !types_with_independent_output_shape.count(op.type()) && !bypass_input_check) { LOG(WARNING) << "Cannot find shape info for " << input << ". Skipping " << op.type(); return; } - if (types_with_independent_output_shape.count(op.type())) { + if (types_with_independent_output_shape.count(op.type()) || (bypass_input_check && it == shape_info_.end())) { TensorShape input_shape; input_shapes.emplace_back(std::move(input_shape)); - } else { input_shapes.emplace_back(it->second.shape); } } - const OpSchema* schema = OpSchemaRegistry::Schema(op.type()); + // Schema can be pre-defined. + // If not predefined, get the schema for the op. + if (schema == nullptr) { + schema = OpSchemaRegistry::Schema(op.type()); + } CAFFE_ENFORCE(schema); std::vector output_shapes; output_shapes = schema->InferTensor(op, input_shapes); @@ -923,7 +946,7 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) { for (int i = 0; i < output_shapes.size(); i++) { const auto& shape = output_shapes[i]; - if (infered_data_type == TensorProto::UNDEFINED) { + if (infered_data_type == TensorProto::UNDEFINED || pruning_ops.find(op.type()) != pruning_ops.end()) { infered_data_type = shape.data_type(); } if (shape.unknown_shape()) { @@ -937,7 +960,8 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) { is_quantized, false, scale, - offset); + offset, + in_place_op); } } catch (const caffe2::EnforceNotMet& e) { LOG(ERROR) << "Enforce not met while inferring shapes for " << op.type() diff --git a/caffe2/opt/bound_shape_inferencer.h b/caffe2/opt/bound_shape_inferencer.h index cf034789f0c0c..54290c5a86278 100644 --- a/caffe2/opt/bound_shape_inferencer.h +++ b/caffe2/opt/bound_shape_inferencer.h @@ -15,7 +15,7 @@ namespace caffe2 { // max_seq_size is the upper bound of length of every item in a batch. // Upper bound of length of a batch of items should be max_batch_size * // max_seq_size. -struct CAFFE2_API BoundShapeSpec { +struct TORCH_API BoundShapeSpec { explicit BoundShapeSpec(int64_t b, int64_t q) : max_batch_size(b), max_seq_size(q), @@ -86,7 +86,7 @@ class BoundShapeInferencerBase { bool extract_feature_len_; }; -class CAFFE2_API BoundShapeInferencer : public BoundShapeInferencerBase { +class TORCH_API BoundShapeInferencer : public BoundShapeInferencerBase { public: explicit BoundShapeInferencer(const BoundShapeSpec& spec) : BoundShapeInferencerBase(spec) {} @@ -107,7 +107,8 @@ class CAFFE2_API BoundShapeInferencer : public BoundShapeInferencerBase { bool is_quantized, bool allow_existing_shape = false, float scale = 1, - int offset = 0); + int offset = 0, + bool in_place_op = false); TensorShape& SetTensorBoundShapeIfNotExist( const std::string& name, @@ -136,7 +137,7 @@ class CAFFE2_API BoundShapeInferencer : public BoundShapeInferencerBase { // Standard shape/type inference using op schema registered shape inference // function - void InferCommonOp(const OperatorDef& op); + void InferCommonOp(const OperatorDef& op, const OpSchema* schema = nullptr, bool bypass_input_check = false, bool in_place_op = false); // Initialize private parameters, such as shape_info, extract_feature_len_ // This is called at the beginning of InferBoundShapeAndType() @@ -148,7 +149,7 @@ class CAFFE2_API BoundShapeInferencer : public BoundShapeInferencerBase { int64_t current_max_batch_size_{0}; }; -CAFFE2_API std::shared_ptr getBoundShapeInferencer( +TORCH_API std::shared_ptr getBoundShapeInferencer( const BoundShapeSpec& spec); C10_DECLARE_SHARED_REGISTRY( diff --git a/caffe2/opt/converter.h b/caffe2/opt/converter.h index 5cd69f189d7a9..734189930a38e 100644 --- a/caffe2/opt/converter.h +++ b/caffe2/opt/converter.h @@ -13,38 +13,38 @@ namespace caffe2 { -CAFFE2_API void injectDataEdgeIndicators(caffe2::NetDef* net); -CAFFE2_API void removeDataEdgeIndicators(caffe2::NetDef* net); +TORCH_API void injectDataEdgeIndicators(caffe2::NetDef* net); +TORCH_API void removeDataEdgeIndicators(caffe2::NetDef* net); // Default conversion to a NNModule // Optionally strict -- which checks for various input and output conditions. // Optionally this function will update a vector that maps operators in the // netdef positionally to NodeRefs in the resultant NNModule. -CAFFE2_API nom::repr::NNModule convertToNNModule( +TORCH_API nom::repr::NNModule convertToNNModule( const caffe2::NetDef& net, bool strict = false, std::vector* = nullptr); -CAFFE2_API caffe2::NetDef convertToCaffe2Proto(nom::repr::NNModule&); +TORCH_API caffe2::NetDef convertToCaffe2Proto(nom::repr::NNModule&); // Pass in an oldNet to copy all the attributes of that network. // Be warned that transformations that modify the graph's inputs or outputs // are not reflected in changes to external_input or external_output. -CAFFE2_API caffe2::NetDef convertToCaffe2Proto( +TORCH_API caffe2::NetDef convertToCaffe2Proto( nom::repr::NNModule&, const caffe2::NetDef& oldNet); // Use these functions instead of the registry directly. -CAFFE2_API std::unique_ptr +TORCH_API std::unique_ptr convertToNeuralNetOperator(const caffe2::OperatorDef& op); -CAFFE2_API caffe2::OperatorDef convertToOperatorDef( +TORCH_API caffe2::OperatorDef convertToOperatorDef( const nom::repr::NNGraph::NodeRef& instrNode); // If the annotation doesn't exist, attempt to add it -CAFFE2_API Caffe2Annotation* getOrAddCaffe2Annotation( +TORCH_API Caffe2Annotation* getOrAddCaffe2Annotation( nom::repr::NNGraph::NodeRef& instrNode); -class CAFFE2_API Converter { +class TORCH_API Converter { public: explicit Converter() = default; virtual std::unique_ptr diff --git a/caffe2/opt/custom/in_batch_broadcast.cc b/caffe2/opt/custom/in_batch_broadcast.cc index 99668a6c82b5a..8406e2d29a58e 100644 --- a/caffe2/opt/custom/in_batch_broadcast.cc +++ b/caffe2/opt/custom/in_batch_broadcast.cc @@ -112,8 +112,8 @@ void inBatchBroadcast( setShape(blob, new_blob); const auto rit = reversed.find(blob); if (rit != reversed.end()) { - const auto& orignal_input = rit->second; - setShape(orignal_input, ""); + const auto& original_input = rit->second; + setShape(original_input, ""); } } diff --git a/caffe2/opt/device.h b/caffe2/opt/device.h index daa634de0563f..b2425ccd6c5c5 100644 --- a/caffe2/opt/device.h +++ b/caffe2/opt/device.h @@ -4,7 +4,7 @@ namespace caffe2 { namespace opt { -CAFFE2_API void insertCopies( +TORCH_API void insertCopies( nom::repr::NNModule* nn, std::function supported, std::function copyToFn, diff --git a/caffe2/opt/distributed.h b/caffe2/opt/distributed.h index 27e57f5a2a176..8089612f846c1 100644 --- a/caffe2/opt/distributed.h +++ b/caffe2/opt/distributed.h @@ -16,7 +16,7 @@ namespace caffe2 { /// /// Throws an exception if the passed in blobMap contains /// blobs that are not present in the NNModule. -CAFFE2_API nom::repr::NNModule convertToNNModule( +TORCH_API nom::repr::NNModule convertToNNModule( caffe2::NetDef&, std::map); @@ -24,10 +24,10 @@ CAFFE2_API nom::repr::NNModule convertToNNModule( /// if you already have an NNModule. /// You probably don't want to use these /// if you can use convertToNNModule instead. -CAFFE2_API void addBlobDeviceOptions( +TORCH_API void addBlobDeviceOptions( std::map blobMap, nom::repr::NNModule* nn); -CAFFE2_API void injectDataEdgeIndicators(nom::repr::NNModule* nn); -CAFFE2_API void removeDataEdgeIndicators(nom::repr::NNModule* nn); +TORCH_API void injectDataEdgeIndicators(nom::repr::NNModule* nn); +TORCH_API void removeDataEdgeIndicators(nom::repr::NNModule* nn); } // namespace caffe2 diff --git a/caffe2/opt/fakefp16_transform.cc b/caffe2/opt/fakefp16_transform.cc index 424056bd2c809..d4cc76b00c87c 100644 --- a/caffe2/opt/fakefp16_transform.cc +++ b/caffe2/opt/fakefp16_transform.cc @@ -299,8 +299,8 @@ void fakeFp16Transform(NetDef* net) { FLAGS_fake_fp16_conversion_use_fp16_acc, FLAGS_fake_fp16_conversion_use_nnpi); - auto blacklist_pos = glow::ParseNetPositionList(FLAGS_onnxifi_blacklist); - auto blacklist_type = glow::ParseBlackListOps(FLAGS_onnxifi_blacklist_ops); + auto blocklist_pos = glow::ParseNetPositionList(FLAGS_onnxifi_blacklist); + auto blocklist_type = glow::ParseBlockListOps(FLAGS_onnxifi_blacklist_ops); // A hack to only do fakefp16 transformation for operators which will be // lowered to ONNXIFI. @@ -320,7 +320,7 @@ void fakeFp16Transform(NetDef* net) { auto* op = net->mutable_op(i); auto net_pos = ArgumentHelper::GetSingleArgument(*op, "net_pos", -1); - if (blacklist_pos.count(net_pos) || blacklist_type.count(op->type())) { + if (blocklist_pos.count(net_pos) || blocklist_type.count(op->type())) { continue; } auto it = kFakeFp16OpConversionMap.find(op->type()); diff --git a/caffe2/opt/fakefp16_transform.h b/caffe2/opt/fakefp16_transform.h index 969738191bb7b..22729a0585be7 100644 --- a/caffe2/opt/fakefp16_transform.h +++ b/caffe2/opt/fakefp16_transform.h @@ -12,14 +12,14 @@ namespace caffe2 { namespace opt { // Mapping from fp32 ops to fakefp16 ops -CAFFE2_API std::unordered_map getFakeFp16OpMapping( +TORCH_API std::unordered_map getFakeFp16OpMapping( bool use_fp16_acc = false, bool use_nnpi = false); -CAFFE2_API void fakeFp16FuseOps(NetDef* net); +TORCH_API void fakeFp16FuseOps(NetDef* net); // Transform normal fp32 operators to fakefp16 operators. -CAFFE2_API void fakeFp16Transform(NetDef* net); +TORCH_API void fakeFp16Transform(NetDef* net); } // namespace opt } // namespace caffe2 diff --git a/caffe2/opt/fusion.h b/caffe2/opt/fusion.h index 0973ade54b383..7dde163553065 100644 --- a/caffe2/opt/fusion.h +++ b/caffe2/opt/fusion.h @@ -25,7 +25,7 @@ namespace opt { using namespace nom; -CAFFE2_API void fuseConvBN(repr::NNModule* nn, caffe2::Workspace* ws); +TORCH_API void fuseConvBN(repr::NNModule* nn, caffe2::Workspace* ws); // Generic activation fusion helper. // diff --git a/caffe2/opt/glow_net_transform.cc b/caffe2/opt/glow_net_transform.cc index f021d263106d6..29fc6e553d72a 100644 --- a/caffe2/opt/glow_net_transform.cc +++ b/caffe2/opt/glow_net_transform.cc @@ -13,11 +13,6 @@ C10_DEFINE_bool( true, "Attach AdjustBatch ops at input/outputs of the Onnxifi ops"); -C10_DEFINE_bool( - onnxifi_loop_test_mode, - false, - "For test purpose only. Build a dummy net just to test the functionality"); - C10_DEFINE_bool( enforce_fp32_inputs_into_fp16, false, @@ -93,7 +88,7 @@ std::unordered_set ParseNetPositionList(const std::string& str) { return net_position_list; } -std::unordered_set ParseBlackListOps(const std::string& str) { +std::unordered_set ParseBlockListOps(const std::string& str) { std::unordered_set ops; if (str.empty()) { return ops; @@ -112,13 +107,14 @@ void onnxifi( const std::vector& input_names, const std::vector& output_names, const std::vector& weight_names, - const std::unordered_set& blacklist, - const ShapeInfoMap& shape_hints, + const std::unordered_set& blocklist, + const ShapeInfoMap& shape_hints_max_bs, bool use_onnx, size_t max_batch_size, size_t max_seq_size, bool load_model_by_blob, - bool predictor_net_ssa_rewritten) { + bool predictor_net_ssa_rewritten, + const std::unordered_map &shape_hints_per_bs) { // Split SparseLengthsSumSparse so that we can lower the SparseLengthsSum part splitSparseLengthsSumSparse(net, *ws); @@ -146,11 +142,11 @@ void onnxifi( opts.load_model_by_blob = load_model_by_blob; opts.enforce_fp32_inputs_into_fp16 = FLAGS_enforce_fp32_inputs_into_fp16; opts.merge_fp32_inputs_into_fp16 = FLAGS_merge_fp32_inputs_into_fp16; - opts.loop_test = FLAGS_onnxifi_loop_test_mode; opts.predictor_net_ssa_rewritten = predictor_net_ssa_rewritten; opts.timeout = FLAGS_onnxifi_timeout_ms; + opts.shape_hints_per_bs = shape_hints_per_bs; - ShapeInfoMap more_shape_hints = shape_hints; + ShapeInfoMap more_shape_hints = shape_hints_max_bs; if (!FLAGS_onnxifi_shape_hints.empty()) { parseShapeInfoMapFromString(FLAGS_onnxifi_shape_hints, more_shape_hints); } @@ -158,19 +154,19 @@ void onnxifi( // Before applying backlist, make sure the ops in the net all have an net_pos; caffe2::BackendTransformerBase::annotateOpIndex(net); - // Parse the blacklist - auto more_blacklist = ParseNetPositionList(FLAGS_onnxifi_blacklist); - for (const auto& b : blacklist) { - more_blacklist.emplace(b); + // Parse the blocklist + auto more_blocklist = ParseNetPositionList(FLAGS_onnxifi_blacklist); + for (const auto& b : blocklist) { + more_blocklist.emplace(b); } // ONNX mode will change the op order so it doesn't apply here if (!opts.use_onnx) { - auto blacklisted_ops = ParseBlackListOps(FLAGS_onnxifi_blacklist_ops); + auto blocklisted_ops = ParseBlockListOps(FLAGS_onnxifi_blacklist_ops); for (const auto& op : net->op()) { - if (blacklisted_ops.count(op.type())) { + if (blocklisted_ops.count(op.type())) { ArgumentHelper helper(op); - more_blacklist.emplace(helper.GetSingleArgument(op, kNetPos, -1)); + more_blocklist.emplace(helper.GetSingleArgument(op, kNetPos, -1)); } } } @@ -183,7 +179,7 @@ void onnxifi( // 1. for specified op, we find its input and outputs. // 2. for each input and output, we create a new copy op and attach it as an // input to the copy. - // 3. we blacklist these new copy operators from onnxification. This forces + // 3. we blocklist these new copy operators from onnxification. This forces // these intermediate tensors to also become outputs of the onnxifi op. // 4. we put the right arguments on the copy ops so TensorObserver can print // out the values. @@ -217,14 +213,11 @@ void onnxifi( AddArgument(kNetPos, pos, ©_op); AddArgument("observe_input_tensors", 1, ©_op); net->add_op()->CopyFrom(copy_op); - more_blacklist.emplace(pos); + more_blocklist.emplace(pos); } OnnxifiTransformer ts(opts); - ts.transform(ws, net, weight_names, more_shape_hints, more_blacklist); - if (FLAGS_onnxifi_debug_mode) { - WriteProtoToTextFile(*net, "debug_transformed_net.pb_txt"); - } + ts.transform(ws, net, weight_names, more_shape_hints, more_blocklist); // Cleanup the input from the workspace for (const auto& i : input_names) { diff --git a/caffe2/opt/glow_net_transform.h b/caffe2/opt/glow_net_transform.h index 7e2eedec90aa1..774cb7dd7c1ca 100644 --- a/caffe2/opt/glow_net_transform.h +++ b/caffe2/opt/glow_net_transform.h @@ -16,7 +16,7 @@ namespace caffe2 { namespace glow { /// Onnxifi transformation on the net and workspace. We also /// needed the input data/shape to populate the shape. In addition, we take a \p -/// blacklist to control and mask what ops we want to consider in onnxifi +/// blocklist to control and mask what ops we want to consider in onnxifi /// process. We can also set whether to use ONNX proto or C2 proto through /// ONNXIFI interface. void onnxifi( @@ -25,16 +25,17 @@ void onnxifi( const std::vector& input_names, const std::vector& output_names, const std::vector& weight_names, - const std::unordered_set& blacklist, - const ShapeInfoMap& shape_hints, + const std::unordered_set& blocklist, + const ShapeInfoMap& shape_hints_max_bs, bool use_onnx, size_t max_batch_size = 0, size_t max_seq_size = 0, bool load_model_by_blob = false, - bool predictor_net_ssa_rewritten = false); + bool predictor_net_ssa_rewritten = false, + const std::unordered_map &shape_hints_per_bs = {}); std::unordered_set ParseNetPositionList(const std::string& str); -std::unordered_set ParseBlackListOps(const std::string& str); +std::unordered_set ParseBlockListOps(const std::string& str); } // namespace glow } // namespace caffe2 diff --git a/caffe2/opt/mobile.cc b/caffe2/opt/mobile.cc index adbbbd19a1e36..c54b70405f074 100644 --- a/caffe2/opt/mobile.cc +++ b/caffe2/opt/mobile.cc @@ -99,7 +99,7 @@ void fuseNNPACKConvRelu(repr::NNModule* nn) { return false; } caffe2::string algo = "AUTO"; - for (const auto arg : op.arg()) { + for (const auto &arg : op.arg()) { if (arg.name() == "algo") { algo = arg.s(); } diff --git a/caffe2/opt/mobile.h b/caffe2/opt/mobile.h index 78e98763a32ea..d31a3f8212c25 100644 --- a/caffe2/opt/mobile.h +++ b/caffe2/opt/mobile.h @@ -7,8 +7,8 @@ namespace caffe2 { namespace opt { -CAFFE2_API void addNNPACK(nom::repr::NNModule* nn, bool low_memory = false); -CAFFE2_API void fuseNNPACKConvRelu(nom::repr::NNModule* nn); +TORCH_API void addNNPACK(nom::repr::NNModule* nn, bool low_memory = false); +TORCH_API void fuseNNPACKConvRelu(nom::repr::NNModule* nn); } // namespace opt } // namespace caffe2 diff --git a/caffe2/opt/onnx_convert.h b/caffe2/opt/onnx_convert.h index 707d41321f77e..89bf209c37d2d 100644 --- a/caffe2/opt/onnx_convert.h +++ b/caffe2/opt/onnx_convert.h @@ -1,6 +1,6 @@ #include "caffe2/core/common.h" -class CAFFE2_API OnnxAnnotation : public nom::repr::Annotation { +class TORCH_API OnnxAnnotation : public nom::repr::Annotation { public: OnnxAnnotation() : Annotation(AnnotationKind::Onnx) {} OnnxAnnotation(std::string device) @@ -30,8 +30,8 @@ class CAFFE2_API OnnxAnnotation : public nom::repr::Annotation { caffe2::OperatorDef* OpDef = nullptr; }; -CAFFE2_API nom::repr::NNModule convertToNNModule(caffe2::NetDef &net, std::unordered_map* blobMapOut = nullptr); +TORCH_API nom::repr::NNModule convertToNNModule(caffe2::NetDef &net, std::unordered_map* blobMapOut = nullptr); -CAFFE2_API caffe2::NetDef convertToOnnxProto(nom::repr::NNModule&); +TORCH_API caffe2::NetDef convertToOnnxProto(nom::repr::NNModule&); -CAFFE2_API std::unique_ptr convertToOperatorDef(caffe2::OperatorDef op); +TORCH_API std::unique_ptr convertToOperatorDef(caffe2::OperatorDef op); diff --git a/caffe2/opt/onnxifi_op.cc b/caffe2/opt/onnxifi_op.cc index 158f9b7a7ed8d..624e91f3780f2 100644 --- a/caffe2/opt/onnxifi_op.cc +++ b/caffe2/opt/onnxifi_op.cc @@ -300,6 +300,46 @@ details::OutputReshapeInfo OnnxifiOp::initOutputReshapeInfo() return output_reshape_info; } +template <> +template +void OnnxifiOp::fillOutputReshapeInfo( + const DimContainer& real_shape, + c10::ArrayRef max_shape, + details::OutputReshapeInfo &output_reshape_info, + int currentIndex) { + CAFFE_ENFORCE_EQ(real_shape.size(), max_shape.size()); + const auto dim_size = real_shape.size(); + auto& begin = output_reshape_info.begins[currentIndex]; + begin.Resize(dim_size); + int32_t* begin_ptr = begin.template mutable_data(); + auto& end = output_reshape_info.ends[currentIndex]; + end.Resize(dim_size); + int32_t* end_ptr = end.template mutable_data(); + int32_t mismatch = 0; + for (int j = 0; j < dim_size; ++j) { + CAFFE_ENFORCE_GE( + max_shape[j], + real_shape[j], + "It is weird that max shape of ", + output_names_[currentIndex], + " is smaller than real shape at dim ", + j, + " (", + max_shape[j], + " vs ", + real_shape[j], + ")"); + begin_ptr[j] = 0; + if (max_shape[j] >= real_shape[j]) { + end_ptr[j] = real_shape[j]; + mismatch += j; + } else { + end_ptr[j] = -1; + } + } + output_reshape_info.fast_path[currentIndex] = !mismatch; +} + template <> int OnnxifiOp::extractOutputBatchSizes() { if (use_onnx_ || !adjust_output_batch_) { @@ -337,77 +377,55 @@ int OnnxifiOp::extractOutputBatchSizes() { return current_batch_size; } - auto it = - output_reshape_info_.emplace(current_batch_size, initOutputReshapeInfo()); - auto& output_reshape_info = it.first->second; - BoundShapeSpec spec(dims[0], max_seq_size_); - auto bound_shape_inferencer = - BoundShapeInferencerRegistry()->Create("C10", spec); - for (int i = 0; i < InputSize(); ++i) { - at::IntArrayRef dim0; - bool quantized = false; - if (this->template InputIsType(i)) { - const auto& input_tensor_int8 = - this->template Input(i); - const auto& t0 = input_tensor_int8.t; - dim0 = t0.sizes(); - quantized = true; - } else { - const auto& t0 = Input(i); - dim0 = t0.sizes(); - } - TensorShape shape; - for (const auto d : dim0) { - shape.add_dims(d); - } - std::vector dim_type( - shape.dims_size(), TensorBoundShape_DimType_CONSTANT); - if (dim_type.size()) { - dim_type[0] = TensorBoundShape_DimType_BATCH; + auto& output_reshape_info = output_reshape_info_.emplace(current_batch_size, initOutputReshapeInfo()).first->second; + + if (use_passed_output_shapes_) { + auto shape_info_it = output_shapes_per_bs_.find(current_batch_size); + CAFFE_ENFORCE(shape_info_it != output_shapes_per_bs_.end(), "Unable to find outputs shapes for bs=", current_batch_size); + CAFFE_ENFORCE_EQ(shape_info_it->second.size(), OutputSize()); + + for (int i = 0; i < OutputSize(); ++i) { + fillOutputReshapeInfo(shape_info_it->second[i], output_shapes_max_bs_[i], output_reshape_info, i); } - input_shape_info_[input_names_[i]] = - ShapeInfo(dim_type, std::move(shape), quantized); - } - bound_shape_inferencer->InferBoundShapeAndType( - netdef_, input_shape_info_, nullptr, false); - const auto& shape_info = bound_shape_inferencer->shape_info(); - for (int i = 0; i < OutputSize(); ++i) { - const auto it = shape_info.find(output_names_[i]); - CAFFE_ENFORCE(it != shape_info.end()); - const auto& real_shape = it->second.shape; - const auto& max_shape = output_shapes_[i]; - CAFFE_ENFORCE_EQ(real_shape.dims_size(), max_shape.size()); - const auto dim_size = real_shape.dims_size(); - auto& begin = output_reshape_info.begins[i]; - begin.Resize(dim_size); - int32_t* begin_ptr = begin.template mutable_data(); - auto& end = output_reshape_info.ends[i]; - end.Resize(dim_size); - int32_t* end_ptr = end.template mutable_data(); - int32_t mismatch = 0; - for (int j = 0; j < dim_size; ++j) { - CAFFE_ENFORCE_GE( - max_shape[j], - real_shape.dims(j), - "It is weird that max shape of ", - output_names_[i], - " is smaller than real shape at dim ", - j, - " (", - max_shape[j], - " vs ", - real_shape.dims(j), - ")"); - begin_ptr[j] = 0; - if (max_shape[j] >= real_shape.dims(j)) { - end_ptr[j] = real_shape.dims(j); - mismatch += j; + } else { + BoundShapeSpec spec(dims[0], max_seq_size_); + auto bound_shape_inferencer = + BoundShapeInferencerRegistry()->Create("C10", spec); + for (int i = 0; i < InputSize(); ++i) { + at::IntArrayRef dim0; + bool quantized = false; + if (this->template InputIsType(i)) { + const auto& input_tensor_int8 = + this->template Input(i); + const auto& t0 = input_tensor_int8.t; + dim0 = t0.sizes(); + quantized = true; } else { - end_ptr[j] = -1; + const auto& t0 = Input(i); + dim0 = t0.sizes(); + } + TensorShape shape; + for (const auto d : dim0) { + shape.add_dims(d); + } + std::vector dim_type( + shape.dims_size(), TensorBoundShape_DimType_CONSTANT); + if (dim_type.size()) { + dim_type[0] = TensorBoundShape_DimType_BATCH; } + input_shape_info_[input_names_[i]] = + ShapeInfo(dim_type, std::move(shape), quantized); + } + bound_shape_inferencer->InferBoundShapeAndType( + netdef_, input_shape_info_, nullptr, false); + const auto& shape_info = bound_shape_inferencer->shape_info(); + for (int i = 0; i < OutputSize(); ++i) { + const auto find_res = shape_info.find(output_names_[i]); + CAFFE_ENFORCE(find_res != shape_info.end()); + fillOutputReshapeInfo(find_res->second.shape.dims(), output_shapes_max_bs_[i], output_reshape_info, i); } - output_reshape_info.fast_path[i] = !mismatch; } + return current_batch_size; } @@ -439,8 +457,10 @@ void OnnxifiOp::adjustOutputBatchSizes(int current_batch_size) { } template <> -void OnnxifiOp::setOutputShapeAndType(int output_idx) { - tensor_dims_int64_.clear(); +void OnnxifiOp::setOutputShapeAndType( + int output_idx, + c10::SmallVector& tensor_dims_int64) { + tensor_dims_int64.clear(); std::vector tensor_dims; uint64_t type = ONNXIFI_DATATYPE_FLOAT32; const auto it = output_shape_hints_.find(output_idx); @@ -458,7 +478,7 @@ void OnnxifiOp::setOutputShapeAndType(int output_idx) { tensor_descriptor.dimensions = tensor_dims.size(); CAFFE_ENFORCE( tensor_descriptor.dimensions != 0, tensor_descriptor.name, " has 0 dim"); - auto& output_shape = output_shapes_[output_idx]; + auto& output_shape = output_shapes_max_bs_[output_idx]; output_shape.clear(); output_shape.insert( output_shape.begin(), tensor_dims.cbegin(), tensor_dims.cend()); @@ -466,14 +486,14 @@ void OnnxifiOp::setOutputShapeAndType(int output_idx) { std::copy( tensor_dims.cbegin(), tensor_dims.cend(), - std::back_inserter(tensor_dims_int64_)); + std::back_inserter(tensor_dims_int64)); // Setup the output C2 tensor if (!info.quantized) { // Normal Tensor auto* output_tensor = Output( output_idx, - tensor_dims_int64_, + tensor_dims_int64, at::dtype(OnnxifiTypeToDataType(type)).device(CPU)); setOutputTensorDescriptorTypeAndBuffer( type, output_tensor, &tensor_descriptor); @@ -481,7 +501,7 @@ void OnnxifiOp::setOutputShapeAndType(int output_idx) { // single quantizer, output Int8Tensor auto* output_tensor = this->template Output(output_idx); - output_tensor->t.Resize(tensor_dims_int64_); + output_tensor->t.Resize(tensor_dims_int64); setOutputTensorDescriptorTypeAndBuffer( type, &output_tensor->t, &tensor_descriptor); tensor_descriptor.quantizationParams = 1; @@ -524,8 +544,9 @@ bool OnnxifiOp::RunOnDevice() { } CAFFE_ENFORCE_EQ(output_desc_.size(), OutputSize()); + c10::SmallVector tensor_dims_int64; for (unsigned i = 0U; i < OutputSize(); ++i) { - setOutputShapeAndType(i); + setOutputShapeAndType(i, tensor_dims_int64); } bool ext_supported = false; onnxMemoryFenceV1 input_fence; diff --git a/caffe2/opt/onnxifi_op.h b/caffe2/opt/onnxifi_op.h index f19403a14e58b..caffae6328277 100644 --- a/caffe2/opt/onnxifi_op.h +++ b/caffe2/opt/onnxifi_op.h @@ -4,6 +4,7 @@ #include "onnx/onnx_pb.h" +#include "c10/util/Exception.h" #include "c10/util/SmallVector.h" #include "caffe2/core/context.h" #include "caffe2/core/logging.h" @@ -18,7 +19,7 @@ namespace caffe2 { namespace details { /// Provides slicing info for the outputs. All the vector members should be of -/// the same size as number of outpus of the Onnxifi op. +/// the same size as number of outputs of the Onnxifi op. struct OutputReshapeInfo { std::vector begins; std::vector ends; @@ -54,6 +55,7 @@ class OnnxifiOp final : public Operator { timeout_(this->template GetSingleArgument("timeout", 0)), nominal_batch_idx_( this->template GetSingleArgument("nominal_batch_idx", 0)), + use_passed_output_shapes_(this->template GetSingleArgument("use_passed_output_shapes", 0)), adjust_quantized_offset_(this->template GetSingleArgument( "adjust_quantized_offset", 128)) { @@ -65,7 +67,7 @@ class OnnxifiOp final : public Operator { CAFFE_ENFORCE(!onnx_model_str.empty(), "onnx_model cannot be empty"); if (use_glow_aot_) { auto netdef_str = - this->template GetSingleArgument("netdef_str", ""); + this->template GetSingleArgument("netdef_str", ""); CAFFE_ENFORCE(ParseProtoFromLargeString(netdef_str, &netdef_)); } else if (!use_onnx_) { CAFFE_ENFORCE(ParseProtoFromLargeString(onnx_model_str, &netdef_)); @@ -85,7 +87,7 @@ class OnnxifiOp final : public Operator { all_offsets_.reserve(ws->Blobs().size()); all_scales_.reserve(ws->Blobs().size()); input_shapes_.resize(input_names_.size()); - output_shapes_.resize(output_names_.size()); + output_shapes_max_bs_.resize(output_names_.size()); quantized_outputs_.resize(output_names_.size(), false); int output_idx = 0; ArgumentHelper helper(operator_def); @@ -126,6 +128,35 @@ class OnnxifiOp final : public Operator { adjust_quantized_offset_ = 0; } + LOG(INFO) << "use_onnx_=" << use_onnx_ + << ", use_glow_aot_=" << use_glow_aot_ + << ", use_passed_output_shapes_=" << use_passed_output_shapes_; + + if (use_passed_output_shapes_) { + // Populate output_shapes_per_bs_ + for (int bs = 1; bs < max_batch_size_; ++bs) { + auto output_shapes_tp = helper.GetRepeatedArgument("output_shapes_bs_" + caffe2::to_string(bs)); + auto output_qshapes_tp = helper.GetRepeatedArgument("output_qshapes_bs_" + caffe2::to_string(bs)); + CAFFE_ENFORCE_EQ(output_names_.size(), output_shapes_tp.size() + output_qshapes_tp.size()); + + std::unordered_map name_to_shape; + for (const auto& output_shape_tp : output_shapes_tp) { + name_to_shape.emplace(output_shape_tp.name(), details::TensorInfo{output_shape_tp}); + } + for (const auto& output_qshape_tp : output_qshapes_tp) { + name_to_shape.emplace(output_qshape_tp.name(), details::TensorInfo{output_qshape_tp}); + } + + for (output_idx = 0; output_idx < output_names_.size(); ++output_idx) { + auto it = name_to_shape.find(output_names_[output_idx]); + CAFFE_ENFORCE(it != name_to_shape.end()); + output_shapes_per_bs_[bs].push_back({}); + auto &output_shapes = output_shapes_per_bs_[bs].back(); + std::copy(it->second.dims.cbegin(), it->second.dims.cend(), std::back_inserter(output_shapes)); + } + } + } + // Get output resizing hints adjust_output_batch_ = this->template GetSingleArgument("adjust_output_batch", 0); @@ -165,7 +196,13 @@ class OnnxifiOp final : public Operator { } #endif private: - void setOutputShapeAndType(int output_idx); + // Second argument is a cache vector to avoid repeated reallocation. + // The existence of this is not ideal, which is purely due to the fact that + // we use int64_t for c2::tensor dim but uint64_t for onnxDesciptor dim. + // Maybe we should just use int64_t. + void setOutputShapeAndType( + int output_idx, + c10::SmallVector& tensor_dims_int64); void buildPropertyList( const OperatorDef& /* unused */, @@ -187,7 +224,7 @@ class OnnxifiOp final : public Operator { this->template GetRepeatedArgument("initializers"); // Build the Onnxifi engine auto backend_index = - this->template GetSingleArgument("backend_id", use_onnx_ ? 1 : 0); + this->template GetSingleArgument("backend_id", use_onnx_ ? 1 : 0); // If using Glow AOT, override the backend_id to 1, since it uses a custom // ONNX format, and that's the id we use for the ONNX backend. if (use_glow_aot_) { @@ -266,18 +303,24 @@ class OnnxifiOp final : public Operator { static const uint64_t auxPropertiesListAOT[] = { ONNXIFI_OPTIMIZATION_AOT, ONNXIFI_GRAPH_PROPERTY_NONE}; - CAFFE_ENFORCE_EQ( - lib_->onnxInitGraph( - backend, - use_glow_aot_ ? auxPropertiesListAOT : nullptr, - onnx_model_str.size(), - (const void*)(onnx_model_str.c_str()), - weight_descs.size(), - weight_descs.data(), - &graph, - static_cast(max_seq_size_), - defered_blob_reader), - ONNXIFI_STATUS_SUCCESS); + auto ret = lib_->onnxInitGraph( + backend, + use_glow_aot_ ? auxPropertiesListAOT : nullptr, + onnx_model_str.size(), + (const void*)(onnx_model_str.c_str()), + weight_descs.size(), + weight_descs.data(), + &graph, + static_cast(max_seq_size_), + defered_blob_reader); + if (ret != ONNXIFI_STATUS_SUCCESS) { + if (ret == ONNXIFI_STATUS_FATAL_ERROR) { + C10_THROW_ERROR( + OnnxfiBackendSystemError, "Fatal error during onnxInitGraph"); + } else { + CAFFE_THROW("onnxInitGraph failed"); + } + } return std::make_shared( backend_id, backend, graph, lib_, std::move(weight_shape_info)); @@ -326,6 +369,14 @@ class OnnxifiOp final : public Operator { #endif } + /// Helper method for extractOutputBatchSizes(), used to deduplicate code of populating output reshape infos + template + void fillOutputReshapeInfo( + const DimContainer& real_shape, + c10::ArrayRef max_shape, + details::OutputReshapeInfo &output_reshape_info, + int index); + /// Extract output batch size. If the output batch size is going to be at /// max_batch_size_, return true indicating that no output shape adjustment is /// needed. Otherwise, return false. @@ -411,7 +462,7 @@ class OnnxifiOp final : public Operator { int nominal_batch_idx_{0}; // We bind the op input/output by position while ONNXIFI binds input/output by - // names. In addition, op input/output names can be writtten by, for example, + // names. In addition, op input/output names can be written by, for example, // memonger. We cache the original input/output name of ONNX object here and // bind them by position. std::vector input_names_; @@ -421,16 +472,14 @@ class OnnxifiOp final : public Operator { NetDef netdef_; std::vector> input_shapes_; - std::vector> output_shapes_; + std::vector> output_shapes_max_bs_; + + // Mapping of batch sizes to output shapes + std::unordered_map>> output_shapes_per_bs_; // Indicate if i-th output is a quantized tensor std::vector quantized_outputs_; - // A cache vector to avoid repeated reallocation. The existence of this is not - // ideal, which is purely due to the factor that we use int64_t for c2::tensor - // dim but uint64_t for onnxDesciptor dim. Maybe we should just use int64_t - c10::SmallVector tensor_dims_int64_; - // This is for multi group quantization info std::vector> all_scales_; std::vector> all_offsets_; @@ -442,6 +491,9 @@ class OnnxifiOp final : public Operator { // max_batch_size std::unordered_map input_shape_info_; + // Whether we should use passed output shape hints or do shape inference + const bool use_passed_output_shapes_{false}; + // Whether we need to resize outputs or not bool adjust_output_batch_{false}; diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc index 9166153cf6931..2dd8c8d2d8b4e 100644 --- a/caffe2/opt/onnxifi_transformer.cc +++ b/caffe2/opt/onnxifi_transformer.cc @@ -403,235 +403,6 @@ void mergeFp32InputsAndConvertToFp16( } } -NetDef buildLoopTestNet( - const NetDef& net, - const std::unordered_set& initialization_list, - std::unordered_map* shape_hints, - size_t batch_size) { - NetDef net_dummy; - - // Add non-weigh inputs only - for (const auto& i : net.external_input()) { - if (!initialization_list.count(i)) { - net_dummy.add_external_input(i); - } - } - for (const auto& o : net.external_output()) { - net_dummy.add_external_output(o); - } - - // Now categorize the inputs into the following groups. We don't support - // handling of 3d inputs yet, but it can be done easily by converting n-d - // inputs into 2-d with Reshape or ReduceSum - std::unordered_set batched_2d_inputs; - std::unordered_set other_2d_inputs; - std::unordered_set all_1d_inputs; - auto addCast = [&net_dummy]( - const std::string& i, - std::string& in, - caffe2::TensorProto::DataType dtype) mutable { - int multiplier = 1; - if (dtype != caffe2::TensorProto::FLOAT) { - in += "_fp32"; - net_dummy.add_op()->CopyFrom(CreateOperatorDef( - "Clip", - "", - {i}, - {in}, - {MakeArgument("min", 0.0), MakeArgument("max", 1.0)})); - if (dtype == caffe2::TensorProto::INT8 || - dtype == caffe2::TensorProto::UINT8) { - multiplier = sizeof(float) / sizeof(int8_t); - } else if ( - dtype == caffe2::TensorProto::INT16 || - dtype == caffe2::TensorProto::UINT16 || - dtype == caffe2::TensorProto::FLOAT16) { - multiplier = sizeof(float) / sizeof(int16_t); - } else if (dtype == caffe2::TensorProto::INT64) { - // Special case, it should really be 0.5 - multiplier = 0; - } - } - return multiplier; - }; - auto adjustDim = [](int d, int m, TensorShape& shape) { - if (m > 1) { - CAFFE_ENFORCE_EQ(shape.dims(d) % m, 0); - shape.set_dims(d, shape.dims(d) / m); - } else if (m == 0) { - shape.set_dims(d, shape.dims(d) * 2); - } - shape.set_data_type(caffe2::TensorProto::FLOAT); - }; - size_t dim2 = 0; - for (const auto& i : net_dummy.external_input()) { - auto it = shape_hints->find(i); - CAFFE_ENFORCE( - it != shape_hints->end(), "Cannot find shape info for input ", i); - auto& shape = it->second.shape; - std::string in = i; - // Trick here: since backend like glow doesn't support non-float - // arithmatics, we need to be creative and bitcast non-float data type into - // float while maintaining the same bit lengths. We do this by changing the - // shape dim. So that we will always load the same amount of bits onto the - // backend. To avoid numeric complication, we add a Clip. - if (shape.dims_size() == 2) { - auto m = addCast(i, in, shape.data_type()); - adjustDim(1, m, shape); - if (shape.dims(0) == batch_size) { - batched_2d_inputs.emplace(in); - dim2 += shape.dims(1); - } else { - other_2d_inputs.emplace(in); - } - } else if (shape.dims_size() == 1) { - auto m = addCast(i, in, shape.data_type()); - adjustDim(0, m, shape); - all_1d_inputs.emplace(in); - } else { - const std::string fin = i + "_flatten"; - net_dummy.add_op()->CopyFrom( - CreateOperatorDef("Flatten", "", {i}, {fin}, {})); - in = fin; - auto m = addCast(fin, in, shape.data_type()); - auto last = shape.dims_size() - 1; - adjustDim(last, m, shape); - size_t ndim = 1; - for (unsigned k = 1; k < shape.dims_size(); ++k) { - ndim *= shape.dims(k); - } - if (shape.dims(0) == batch_size) { - batched_2d_inputs.emplace(in); - dim2 += ndim; - } else { - other_2d_inputs.emplace(in); - } - } - } - - // Add adjusted shape hints - auto* shape_arg = net_dummy.add_arg(); - auto* qshape_arg = net_dummy.add_arg(); - shape_arg->set_name("input_shape_info"); - qshape_arg->set_name("input_qshape_info"); - for (const auto& i : net_dummy.external_input()) { - auto info = shape_hints->at(i); - if (!info.is_quantized) { - shape_arg->mutable_tensors()->Add()->CopyFrom( - wrapShapeInfoIntoTensorProto(i, info)); - } else { - qshape_arg->mutable_qtensors()->Add()->CopyFrom( - wrapShapeInfoIntoQTensorProto(i, info)); - } - } - - // Collect all the input together into a 2d tensor of {batch_size, X} - std::vector concat2d_batched( - batched_2d_inputs.begin(), batched_2d_inputs.end()); - const std::string concat_out = "batch_2d_concat"; - net_dummy.add_op()->CopyFrom(CreateOperatorDef( - "Concat", - "", - concat2d_batched, - {concat_out, "batch_2d_concat_split_info"}, - {MakeArgument("axis", 1)})); - std::vector scalars; - for (const auto& i : other_2d_inputs) { - std::string o = i + "_reduced"; - net_dummy.add_op()->CopyFrom(CreateOperatorDef( - "ReduceSum", - "", - {i}, - {o}, - {MakeArgument>("axes", {0, 1}), - MakeArgument("keepdims", 0)})); - scalars.emplace_back(std::move(o)); - } - for (const auto& i : all_1d_inputs) { - std::string o = i + "_reduced"; - net_dummy.add_op()->CopyFrom(CreateOperatorDef( - "ReduceSum", - "", - {i}, - {o}, - {MakeArgument>("axes", {0}), - MakeArgument("keepdims", 0)})); - scalars.emplace_back(std::move(o)); - } - const std::string summed = "summed"; - net_dummy.add_op()->CopyFrom( - CreateOperatorDef("Sum", "", scalars, {summed}, {})); - const std::string out = "result_out"; - net_dummy.add_op()->CopyFrom(CreateOperatorDef( - "Add", - "", - {concat_out, summed}, - {out}, - {MakeArgument("broadcast", 1)})); - - for (const auto& o : net_dummy.external_output()) { - const auto it = shape_hints->find(o); - CAFFE_ENFORCE( - it != shape_hints->end(), "Cannot find shape info for output ", o); - const auto& shape = it->second.shape; - // TODO: all doable but I'm lazy - if (shape.data_type() != caffe2::TensorProto::FLOAT) { - CAFFE_THROW("We need a Cast op to match the output data type"); - } - if (shape.dims_size() == 2) { - if (shape.dims(0) == batch_size) { - if (shape.dims(1) > dim2) { - CAFFE_THROW( - "We need Tile op to match the output dim ", - shape.dims(1), - " vs ", - dim2); - } else if (shape.dims(1) == dim2) { - net_dummy.add_op()->CopyFrom( - CreateOperatorDef("Copy", "", {out}, {o}, {})); - } else { - net_dummy.add_op()->CopyFrom(CreateOperatorDef( - "Slice", - "", - {out}, - {o}, - {MakeArgument>("starts", {0, 0}), - MakeArgument>( - "ends", {-1, static_cast(shape.dims(1))})})); - } - } - } else if (shape.dims_size() == 1) { - if (shape.dims(0) == batch_size) { - const std::string oi = o + "_pre"; - net_dummy.add_op()->CopyFrom(CreateOperatorDef( - "Slice", - "", - {out}, - {oi}, - {MakeArgument>("starts", {0, 0}), - MakeArgument>("ends", {-1, 1})})); - net_dummy.add_op()->CopyFrom(CreateOperatorDef( - "Reshape", - "", - {oi}, - {o}, - {MakeArgument>( - "shape", {static_cast(batch_size)})})); - } else { - CAFFE_THROW( - "We need Slice and Tile op to match the output dim ", - shape.dims(0), - " vs ", - batch_size); - } - } else { - CAFFE_THROW("Only support 1D/2D outputs for now"); - } - } - - return net_dummy; -} - } // namespace void splitSparseLengthsSumSparse(NetDef* net, const Workspace& ws) { @@ -735,12 +506,38 @@ OnnxifiTransformer::~OnnxifiTransformer() { } } +bool OnnxifiTransformer::canPassOutputShapeHintsPerBs( + const OperatorDef& op, + const std::unordered_map& shape_hints_per_bs) const { + if (shape_hints_per_bs.empty()) { + return false; + } + + for (int bs = 1; bs < opts_.bound_shape_spec.max_batch_size; ++bs) { + auto shape_hints_search = shape_hints_per_bs.find(bs); + if (shape_hints_search == shape_hints_per_bs.end()) { + return false; + } + const auto& shape_hints = shape_hints_search->second; + + for (int output_idx = 0; output_idx < op.output_size(); ++output_idx) { + auto shape_hint_search = shape_hints.find(op.output(output_idx)); + if (shape_hint_search == shape_hints.end()) { + return false; + } + } + } + + return true; +} + OperatorDef OnnxifiTransformer::buildOnnxifiOp( const std::string& onnx_model_str, const std::unordered_set& initialization_list, const std::vector& external_inputs, const std::vector& external_outputs, - const std::unordered_map& shape_hints) { + const ShapeInfoMap& shape_hints_max_bs, + const std::unordered_map &shape_hints_per_bs) { OperatorDef op; op.set_type("Onnxifi"); auto* onnx_model_arg = op.add_arg(); @@ -778,9 +575,9 @@ OperatorDef OnnxifiTransformer::buildOnnxifiOp( int nominal_batch_idx{0}; for (const auto& input : external_inputs) { if (!initialization_list.count(input)) { - const auto it = shape_hints.find(input); + const auto it = shape_hints_max_bs.find(input); CAFFE_ENFORCE( - it != shape_hints.end(), "Input shape for ", input, " not found"); + it != shape_hints_max_bs.end(), "Input shape for ", input, " not found"); const auto& info = it->second; if (info.getDimType(0) == TensorBoundShape_DimType_BATCH && getBlob1stDimSize(info) == max_batch_size) { @@ -791,15 +588,15 @@ OperatorDef OnnxifiTransformer::buildOnnxifiOp( } } - // Add output size hints + // Add output size hints for max batch size auto* output_shape_info_arg = op.add_arg(); output_shape_info_arg->set_name("output_shape_info"); auto* output_qshape_info_arg = op.add_arg(); output_qshape_info_arg->set_name("output_qshape_info"); for (int i = 0; i < op.output_size(); ++i) { const auto& o = op.output(i); - const auto it = shape_hints.find(o); - if (it != shape_hints.end()) { + const auto it = shape_hints_max_bs.find(o); + if (it != shape_hints_max_bs.end()) { if (!it->second.is_quantized) { output_shape_info_arg->mutable_tensors()->Add()->CopyFrom( wrapShapeInfoIntoTensorProto(o, it->second)); @@ -811,6 +608,33 @@ OperatorDef OnnxifiTransformer::buildOnnxifiOp( } } + // Add output size hints per batch size + if (canPassOutputShapeHintsPerBs(op, shape_hints_per_bs)) { + VLOG(2) << "Passing in output shape hints for batch sizes in [1, " << opts_.bound_shape_spec.max_batch_size << ")"; + AddArgument("use_passed_output_shapes", 1, &op); + + for (int bs = 1; bs < opts_.bound_shape_spec.max_batch_size; ++bs) { + auto* output_shape_arg = op.add_arg(); + output_shape_arg->set_name("output_shapes_bs_" + caffe2::to_string(bs)); + auto* output_qshape_arg = op.add_arg(); + output_qshape_arg->set_name("output_qshapes_bs_" + caffe2::to_string(bs)); + + const auto& shape_hints = shape_hints_per_bs.find(bs)->second; + + for (int output_idx = 0; output_idx < op.output_size(); ++output_idx) { + const auto& output_name = op.output(output_idx); + const auto& shape_hint = shape_hints.find(output_name)->second; + if (!shape_hint.is_quantized) { + output_shape_arg->mutable_tensors()->Add()->CopyFrom(wrapShapeInfoIntoTensorProto(output_name, shape_hint)); + } else { + output_shape_arg->mutable_qtensors()->Add()->CopyFrom(wrapShapeInfoIntoQTensorProto(output_name, shape_hint)); + } + } + } + } else { + AddArgument("use_passed_output_shapes", 0, &op); + } + // Tell Onnxifi op that the model is in onnx or c2 proto format AddArgument("use_onnx", opts_.use_onnx ? 1 : 0, &op); @@ -838,11 +662,14 @@ OperatorDef OnnxifiTransformer::buildOnnxifiOp( NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2( const caffe2::NetDef& net, const std::unordered_set& weights_in_ws, - const ShapeInfoMap& shape_hints) { + const ShapeInfoMap& shape_hints_max_bs, + const std::unordered_map &shape_hints_per_bs) { int onnxifi_op_id = onnxifi_op_id_; if (opts_.debug) { WriteProtoToTextFile( - net, "debug_original_net_" + c10::to_string(onnxifi_op_id) + ".pb_txt"); + net, + "debug_original_net_" + c10::to_string(onnxifi_op_id) + ".pb_txt", + false); } if (opts_.min_ops > net.op_size()) { return net; @@ -874,8 +701,8 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2( op.type() == "SparseLengthsWeightedSumFused4BitRowwise") ? 1 : 0; - const auto& indices_hint = shape_hints.at(op.input(1 + weighted)); - const auto& lengths_hint = shape_hints.at(op.input(2 + weighted)); + const auto& indices_hint = shape_hints_max_bs.at(op.input(1 + weighted)); + const auto& lengths_hint = shape_hints_max_bs.at(op.input(2 + weighted)); const auto& indices_shape = indices_hint.shape; const auto& lengths_shape = lengths_hint.shape; if ((indices_hint.getDimType(0) == @@ -916,29 +743,17 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2( onnxifi_net.clear_external_input(); for (const auto& i : total_inputs_vec) { onnxifi_net.add_external_input(i); - auto info = shape_hints.at(i); + auto info = shape_hints_max_bs.at(i); if (!info.is_quantized) { shape_arg->mutable_tensors()->Add()->CopyFrom( - wrapShapeInfoIntoTensorProto(i, shape_hints.at(i))); + wrapShapeInfoIntoTensorProto(i, shape_hints_max_bs.at(i))); } else { qshape_arg->mutable_qtensors()->Add()->CopyFrom( - wrapShapeInfoIntoQTensorProto(i, shape_hints.at(i))); + wrapShapeInfoIntoQTensorProto(i, shape_hints_max_bs.at(i))); } } - // Rewrite the net into a dummy in loop test mode - ShapeInfoMap new_shape_hints; - if (opts_.loop_test) { - new_shape_hints = shape_hints; - onnxifi_net = buildLoopTestNet( - onnxifi_net, - initialization_list, - &new_shape_hints, - opts_.bound_shape_spec.max_batch_size); - initialization_list.clear(); - } - - // Add parition info + // Add partition info for (const auto& p : partition_infos_) { onnxifi_net.add_partition_info()->CopyFrom(p); } @@ -963,17 +778,20 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2( initialization_list, onnxifi_net_inputs, onnxifi_net_outputs, - opts_.loop_test ? new_shape_hints : shape_hints); + shape_hints_max_bs, + shape_hints_per_bs); NetDef net_opt = composeResultNet(onnxifi_op); // Debugging stuff if (opts_.debug) { WriteProtoToTextFile( onnxifi_net, - "debug_onnxifi_net_" + c10::to_string(onnxifi_op_id) + ".pb_txt"); + "debug_onnxifi_net_" + c10::to_string(onnxifi_op_id) + ".pb_txt", + false); WriteProtoToTextFile( net_opt, - "debug_optimized_net_" + c10::to_string(onnxifi_op_id) + ".pb_txt"); + "debug_optimized_net_" + c10::to_string(onnxifi_op_id) + ".pb_txt", + false); } return net_opt; } @@ -983,7 +801,8 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx( const std::unordered_set& weights_in_ws, Workspace* ws, onnx::OnnxExporter* exporter, - ShapeInfoMap* shape_hints) { + ShapeInfoMap* shape_hints_max_bs, + const std::unordered_map &shape_hints_per_bs) { if (opts_.min_ops > net.op_size()) { return net; } @@ -1007,7 +826,7 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx( TensorShape shape; shape.mutable_dims()->CopyFrom(t.dims()); auto ret = shape_hints_onnx_.emplace(t.name(), std::move(shape)); - shape_hints->emplace( + shape_hints_max_bs->emplace( std::piecewise_construct, std::forward_as_tuple(ret.first->first), std::forward_as_tuple( @@ -1082,13 +901,14 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx( initialization_list, onnxifi_net_inputs, onnxifi_net_outputs, - *shape_hints); + *shape_hints_max_bs, + shape_hints_per_bs); NetDef net_opt = composeResultNet(onnxifi_op); // Debugging stuff if (opts_.debug) { - WriteProtoToTextFile(onnx_model, "debug_onnxifi_net.onnx_txt"); - WriteProtoToTextFile(net_opt, "debug_optimized_net.pb_txt"); + WriteProtoToTextFile(onnx_model, "debug_onnxifi_net.onnx_txt", false); + WriteProtoToTextFile(net_opt, "debug_optimized_net.pb_txt", false); } return net_opt; } @@ -1375,11 +1195,11 @@ void OnnxifiTransformer::applyFilteringRules( blocklistCpuPartition(net, blocklisted_ops); } -void OnnxifiTransformer::getBackendId() { +std::vector OnnxifiTransformer::getBackendId() { idx_ = 0; if (opts_.use_onnx) { - return; + return backend_ids_; } // Try to find a backend that support Caffe2 proto. Note that this is quite // opportunistic as we don't officially support Caffe2 proto. @@ -1394,26 +1214,28 @@ void OnnxifiTransformer::getBackendId() { break; } } + return backend_ids_; } NetDef OnnxifiTransformer::TransformViaC2( NetDef* pred_net, const std::unordered_set& weights, const std::unordered_set& blocklisted_ops, - const ShapeInfoMap& shape_hints) { + const ShapeInfoMap& shape_hints_max_bs, + const std::unordered_map &shape_hints_per_bs) { onnxBackendID backend_id = backend_ids_[idx_]; auto c2_supports = [this, - &shape_hints, + &shape_hints_max_bs, &blocklisted_ops, backend_id, &weights](const caffe2::OperatorDef& op) { - return supportOpC2(op, shape_hints, weights, blocklisted_ops, backend_id); + return supportOpC2(op, shape_hints_max_bs, weights, blocklisted_ops, backend_id); }; auto c2_converter = - [this, &weights, &shape_hints](const caffe2::NetDef& net) { - return SubnetToOnnxifiOpViaC2(net, weights, shape_hints); + [this, &weights, &shape_hints_max_bs, &shape_hints_per_bs](const caffe2::NetDef& net) { + return SubnetToOnnxifiOpViaC2(net, weights, shape_hints_max_bs, shape_hints_per_bs); }; return opt::OptimizeForBackend( @@ -1425,7 +1247,8 @@ NetDef OnnxifiTransformer::TransformViaOnnx( NetDef* pred_net, const std::unordered_set& weights, const std::unordered_set& blocklisted_ops, - ShapeInfoMap* shape_hints) { + ShapeInfoMap* shape_hints_max_bs, + const std::unordered_map &shape_hints_per_bs) { onnxBackendID backend_id = backend_ids_[idx_]; // function to tell whether the ONNXIFI backend supports a given C2 op or not @@ -1439,9 +1262,9 @@ NetDef OnnxifiTransformer::TransformViaOnnx( // the same exporter throughout the process to avoid duplicated dummy name // generation onnx::OnnxExporter exporter2(nullptr); - auto onnx_converter = [this, ws, &weights, shape_hints, &exporter2]( + auto onnx_converter = [this, ws, &weights, shape_hints_max_bs, &exporter2, &shape_hints_per_bs]( const caffe2::NetDef& net) mutable { - return SubnetToOnnxifiOpViaOnnx(net, weights, ws, &exporter2, shape_hints); + return SubnetToOnnxifiOpViaOnnx(net, weights, ws, &exporter2, shape_hints_max_bs, shape_hints_per_bs); }; return opt::OptimizeForBackend( @@ -1467,7 +1290,7 @@ void OnnxifiTransformer::transform( CAFFE_ENFORCE(pred_net, "Predict net cannot be nullptr"); if (opts_.debug) { - WriteProtoToTextFile(*pred_net, "debug_pre_ssa_net.pb_txt"); + WriteProtoToTextFile(*pred_net, "debug_pre_ssa_net.pb_txt", false); } // Get model id and reset Onnxifi op id to 0 @@ -1501,17 +1324,17 @@ void OnnxifiTransformer::transform( // blob for output is created. This causes problem if inferShape uses original // ws since it does not expect the output blob to be present. Workspace mapped_ws(ws, input_mapping_); - ShapeInfoMap shape_hints = inferShapes( + ShapeInfoMap shape_hints_max_bs = inferShapes( &mapped_ws, pred_net, shape_hints_mapped, opts_.bound_shape_spec); if (opts_.use_onnx) { - shape_hints_onnx_ = stripShapeInfoMap(shape_hints); + shape_hints_onnx_ = stripShapeInfoMap(shape_hints_max_bs); } if (opts_.enforce_fp32_inputs_into_fp16) { - enforceFp32InputsToFp16(weights, pred_net, &shape_hints); + enforceFp32InputsToFp16(weights, pred_net, &shape_hints_max_bs); } if (opts_.merge_fp32_inputs_into_fp16) { mergeFp32InputsAndConvertToFp16( - opts_.bound_shape_spec.max_batch_size, weights, pred_net, &shape_hints); + opts_.bound_shape_spec.max_batch_size, weights, pred_net, &shape_hints_max_bs); } if (opts_.debug) { @@ -1522,7 +1345,7 @@ void OnnxifiTransformer::transform( for (const auto& w : weights) { w_arg->add_strings(w); } - dumpNet(ssa_net, shape_hints, "debug_ssa_net.pb_txt"); + dumpNet(ssa_net, shape_hints_max_bs, "debug_ssa_net.pb_txt"); } extractPartitionInfo(*pred_net); @@ -1532,13 +1355,13 @@ void OnnxifiTransformer::transform( // Apply some filtering rules std::unordered_set new_blocklisted_ops( blocklisted_ops.begin(), blocklisted_ops.end()); - applyFilteringRules(*pred_net, shape_hints, weights, &new_blocklisted_ops); + applyFilteringRules(*pred_net, shape_hints_max_bs, weights, &new_blocklisted_ops); // Transform the net NetDef net_opt = opts_.use_onnx ? TransformViaOnnx( - ws, pred_net, weights, new_blocklisted_ops, &shape_hints) - : TransformViaC2(pred_net, weights, new_blocklisted_ops, shape_hints); + ws, pred_net, weights, new_blocklisted_ops, &shape_hints_max_bs, opts_.shape_hints_per_bs) + : TransformViaC2(pred_net, weights, new_blocklisted_ops, shape_hints_max_bs, opts_.shape_hints_per_bs); // Need to figure out a proper place to handle device option net_opt.mutable_device_option()->CopyFrom(pred_net->device_option()); @@ -1546,9 +1369,9 @@ void OnnxifiTransformer::transform( pred_net->Swap(&net_opt); - addShapeToNet(*pred_net, shape_hints); + addShapeToNet(*pred_net, shape_hints_max_bs); if (opts_.debug) { - WriteProtoToTextFile(*pred_net, "debug_full_opt_net.pb_txt"); + WriteProtoToTextFile(*pred_net, "debug_full_opt_net.pb_txt", false); } } diff --git a/caffe2/opt/onnxifi_transformer.h b/caffe2/opt/onnxifi_transformer.h index 9af168a4d20e8..d1af1731013d8 100644 --- a/caffe2/opt/onnxifi_transformer.h +++ b/caffe2/opt/onnxifi_transformer.h @@ -18,7 +18,7 @@ class OnnxExporter; // Split SparseLengthsSumSparse into SparseLengthsSumSparseLookup + // SparseLengthsSum -CAFFE2_API void splitSparseLengthsSumSparse(NetDef* net, const Workspace& ws); +TORCH_API void splitSparseLengthsSumSparse(NetDef* net, const Workspace& ws); struct OnnxifiTransformerOptions final : public BackendTransformOptions { explicit OnnxifiTransformerOptions() : BackendTransformOptions() {} @@ -39,17 +39,17 @@ struct OnnxifiTransformerOptions final : public BackendTransformOptions { // fp16 or not bool merge_fp32_inputs_into_fp16{false}; - // Enter loop test mode - bool loop_test{false}; - // Whether the net has been ssaRewritten bool predictor_net_ssa_rewritten{false}; // Inference timeout int timeout{0}; + + // Mapping of batch sizes to shape infos + std::unordered_map shape_hints_per_bs; }; -class CAFFE2_API OnnxifiTransformer final : public BackendTransformerBase { +class TORCH_API OnnxifiTransformer final : public BackendTransformerBase { public: explicit OnnxifiTransformer(const OnnxifiTransformerOptions& opts); ~OnnxifiTransformer() override; @@ -61,6 +61,17 @@ class CAFFE2_API OnnxifiTransformer final : public BackendTransformerBase { const ShapeInfoMap& shape_hints, const std::unordered_set& blocklisted_ops) override; + // Query whether an operator is supported by passing C2 protobuf + bool supportOpC2( + const caffe2::OperatorDef& op, + const ShapeInfoMap& shape_hints, + const std::unordered_set& weights, + const std::unordered_set& blocklisted_ops, + onnxBackendID backend_id) const; + + // Determine backend id + std::vector getBackendId(); + private: // Since we create new tensors during the conversion process, we actually need // into inject them into the original workspace @@ -72,13 +83,21 @@ class CAFFE2_API OnnxifiTransformer final : public BackendTransformerBase { const std::unordered_set& weights_in_ws, Workspace* ws, onnx::OnnxExporter* exporter, - ShapeInfoMap* shape_hints); + ShapeInfoMap* shape_hints_max_bs, + const std::unordered_map &shape_hints_per_bs); // Convert a cutoff subgraph net to an Onnxifi op caffe2::NetDef SubnetToOnnxifiOpViaC2( const caffe2::NetDef& net, const std::unordered_set& weights_in_ws, - const ShapeInfoMap& shape_hints); + const ShapeInfoMap& shape_hints_max_bs, + const std::unordered_map &shape_hints_per_bs); + + // Check that output shape hints are present to ensure we can pass them to + // OnnxifiOp + bool canPassOutputShapeHintsPerBs( + const OperatorDef& op, + const std::unordered_map& shape_hints_per_bs) const; // We already have all the ops and external inputs and outputs! OperatorDef buildOnnxifiOp( @@ -86,14 +105,16 @@ class CAFFE2_API OnnxifiTransformer final : public BackendTransformerBase { const std::unordered_set& initialization_list, const std::vector& external_inputs, const std::vector& external_outputs, - const std::unordered_map& shape_hints); + const ShapeInfoMap& shape_hints_max_bs, + const std::unordered_map &shape_hints_per_bs); // Transform by passing C2 proto to backend NetDef TransformViaC2( NetDef* pred_net, const std::unordered_set& weights, const std::unordered_set& blocklisted_ops, - const ShapeInfoMap& shape_hints); + const ShapeInfoMap& shape_hints_max_bs, + const std::unordered_map &shape_hints_per_bs); // Transform by passing ONNX proto to backend NetDef TransformViaOnnx( @@ -101,15 +122,8 @@ class CAFFE2_API OnnxifiTransformer final : public BackendTransformerBase { NetDef* pred_net, const std::unordered_set& weights, const std::unordered_set& blocklisted_ops, - ShapeInfoMap* shape_hints); - - // Query whether an operator is supported by passing C2 protobuf - bool supportOpC2( - const caffe2::OperatorDef& op, - const ShapeInfoMap& shape_hints, - const std::unordered_set& weights, - const std::unordered_set& blocklisted_ops, - onnxBackendID backend_id) const; + ShapeInfoMap* shape_hints_max_bs, + const std::unordered_map &shape_hints_per_bs); // Query whether an operator is supported by passing ONNX protobuf bool supportOpOnnx( @@ -141,9 +155,6 @@ class CAFFE2_API OnnxifiTransformer final : public BackendTransformerBase { const std::unordered_set& weights, std::unordered_set* blocklisted_ops) const; - // Determine backend id - void getBackendId(); - // Extract partition info from the original net void extractPartitionInfo(const NetDef& net); diff --git a/caffe2/opt/optimize_ideep.h b/caffe2/opt/optimize_ideep.h index 85b86bfbc2718..280ef1bab6f93 100644 --- a/caffe2/opt/optimize_ideep.h +++ b/caffe2/opt/optimize_ideep.h @@ -8,7 +8,7 @@ namespace caffe2 { namespace opt { -CAFFE2_API void OptimizeForMkldnn( +TORCH_API void OptimizeForMkldnn( nom::repr::NNModule* nn, caffe2::Workspace* ws, bool training_mode = false); diff --git a/caffe2/opt/optimizer.h b/caffe2/opt/optimizer.h index 326f371b5725e..72245b413e999 100644 --- a/caffe2/opt/optimizer.h +++ b/caffe2/opt/optimizer.h @@ -8,8 +8,8 @@ namespace caffe2 { namespace opt { -CAFFE2_API NetDef optimize(NetDef net, Workspace* ws, int level = 1); -CAFFE2_API NetDef optimize(NetDef net, int level = 1); +TORCH_API NetDef optimize(NetDef net, Workspace* ws, int level = 1); +TORCH_API NetDef optimize(NetDef net, int level = 1); } // namespace opt } // namespace caffe2 diff --git a/caffe2/opt/passes.h b/caffe2/opt/passes.h index fc15dcad13fe7..b2ef81c2d4247 100644 --- a/caffe2/opt/passes.h +++ b/caffe2/opt/passes.h @@ -21,7 +21,7 @@ namespace caffe2 { * use a different registry and inherit from WorkspaceOptimizationPass. */ -class CAFFE2_API OptimizationPass { +class TORCH_API OptimizationPass { public: OptimizationPass(NNModule* nn) : nn_(nn) {} virtual void run() = 0; @@ -31,7 +31,7 @@ class CAFFE2_API OptimizationPass { NNModule* nn_; }; -class CAFFE2_API WorkspaceOptimizationPass : public OptimizationPass { +class TORCH_API WorkspaceOptimizationPass : public OptimizationPass { public: WorkspaceOptimizationPass(NNModule* nn, Workspace* ws) : OptimizationPass(nn), ws_(ws) {} virtual ~WorkspaceOptimizationPass() {} diff --git a/caffe2/opt/shape_info.cc b/caffe2/opt/shape_info.cc index dfcdeb0356bdc..7e3ac1b15dc9b 100644 --- a/caffe2/opt/shape_info.cc +++ b/caffe2/opt/shape_info.cc @@ -2,6 +2,7 @@ #include "caffe2/core/operator.h" #include "caffe2/core/tensor_int8.h" #include "caffe2/utils/string_utils.h" +#include namespace caffe2 { diff --git a/caffe2/opt/shape_info.h b/caffe2/opt/shape_info.h index 3376922f166fb..b843963db73b1 100644 --- a/caffe2/opt/shape_info.h +++ b/caffe2/opt/shape_info.h @@ -4,7 +4,7 @@ namespace caffe2 { -struct CAFFE2_API QShapeInfo { +struct TORCH_API QShapeInfo { QShapeInfo(float o = 0, float s = 1, uint32_t a = 1) { offset.clear(); scale.clear(); @@ -18,7 +18,7 @@ struct CAFFE2_API QShapeInfo { vector scale; }; -struct CAFFE2_API ShapeInfo { +struct TORCH_API ShapeInfo { ShapeInfo(bool q = false) : is_quantized(q) {} ShapeInfo( std::vector&& t, @@ -95,6 +95,14 @@ struct CAFFE2_API ShapeInfo { } } + bool getShapeIsFinal() { + return shape_is_final; + } + + void setShapeIsFinal(bool flag) { + shape_is_final = flag; + } + TensorShape shape; // quantization related information @@ -106,6 +114,9 @@ struct CAFFE2_API ShapeInfo { // dim_type.size == shape.dims.size std::vector dim_type; bool dim_type_is_set = false; + // a flag to indicate whether the shape is final and cannot be changed + // eg: input/output of in-place ops + bool shape_is_final = false; }; using ShapeInfoMap = std::unordered_map; @@ -122,23 +133,23 @@ bool operator==(const ShapeInfo& lhs, const ShapeInfo& rhs); // since they are already inserted as CONSTANT, it will take effect here. // For SEQ typed tensors, there are only a few of them and they will be // handled by BoundShapeInferencer. -CAFFE2_API ShapeInfo constructShapeInfoWithDefaultDimType( +TORCH_API ShapeInfo constructShapeInfoWithDefaultDimType( TensorShape shape, TensorBoundShape_DimType defaultFirstDimType = TensorBoundShape_DimType_BATCH); -CAFFE2_API void parseShapeInfoMapFromString(const std::string&, ShapeInfoMap&); +TORCH_API void parseShapeInfoMapFromString(const std::string&, ShapeInfoMap&); // Extract shape info from tensorBoundShapes to a ShapeInfoMap. // Change shape according to new max_batch_size and max_feature_len // at the same time if necessary. -CAFFE2_API ShapeInfoMap extractShapeInfoFromTensorBoundShapes( +TORCH_API ShapeInfoMap extractShapeInfoFromTensorBoundShapes( TensorBoundShapes tensor_bound_shapes, int64_t new_max_batch_size = -1, int64_t new_max_feature_len = -1); // In-place modify TensorBoundShape to change shape size based on type -CAFFE2_API void changeTensorBoundShapes( +TORCH_API void changeTensorBoundShapes( TensorBoundShape& tensor_shape_and_type, const int64_t old_batch_size, const int64_t old_seq_size, @@ -146,7 +157,7 @@ CAFFE2_API void changeTensorBoundShapes( const int64_t new_seq_size); // In-place modify TensorShape's shape at a specific dimension -CAFFE2_API void modifyTensorShapeDimSize( +TORCH_API void modifyTensorShapeDimSize( TensorShape* tensor_shape, int dim_index, const int64_t old_size, diff --git a/caffe2/opt/tvm_transformer.h b/caffe2/opt/tvm_transformer.h index 8ff29baee8436..6a4a34507f9a3 100644 --- a/caffe2/opt/tvm_transformer.h +++ b/caffe2/opt/tvm_transformer.h @@ -13,7 +13,7 @@ struct TvmTransformOptions final : public BackendTransformOptions { bool profiling_based_jit{false}; }; -class CAFFE2_API TvmTransformer final : public BackendTransformerBase { +class TORCH_API TvmTransformer final : public BackendTransformerBase { public: explicit TvmTransformer(const TvmTransformOptions& opts) : BackendTransformerBase(), opts_(opts) {} @@ -68,7 +68,7 @@ class CAFFE2_API TvmTransformer final : public BackendTransformerBase { }; // Helper function to clean up a net and run tvm transform. -CAFFE2_API void tvmTransform( +TORCH_API void tvmTransform( NetDef* net, Workspace* ws, const std::vector& input_names, @@ -84,7 +84,7 @@ CAFFE2_API void tvmTransform( bool tvm_profiling_based_jit, bool debug); -CAFFE2_API void cleanUpPredictNet( +TORCH_API void cleanUpPredictNet( NetDef* net, const std::vector& input_names, const std::vector& output_names, diff --git a/caffe2/perfkernels/adagrad.h b/caffe2/perfkernels/adagrad.h index 765f8154c8791..12cd41056ec3e 100644 --- a/caffe2/perfkernels/adagrad.h +++ b/caffe2/perfkernels/adagrad.h @@ -99,8 +99,9 @@ inline void adagrad_update_prefetch_inlined( __m256 gi = _mm256_loadu_ps(g + i); __m256 hi = _mm256_loadu_ps(h + i); __m256 wi = _mm256_loadu_ps(w + i); -#ifdef __AVX2__ +#ifdef __FMA__ gi = _mm256_fmadd_ps(_mm256_set1_ps(weight_decay), wi, gi); + #else gi = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(weight_decay), wi), gi); #endif diff --git a/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_idx_avx2.cc b/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_idx_avx2.cc index bc4b7730dced3..b32efc9eae418 100644 --- a/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_idx_avx2.cc +++ b/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_idx_avx2.cc @@ -17,7 +17,7 @@ static bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_float_float__avx2_fma( const int64_t data_size, const float* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool normalize_by_lengths, float* out) { @@ -401,7 +401,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_float_float_false__avx2_fma( const int64_t data_size, const float* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool normalize_by_lengths, float* out) { @@ -425,7 +425,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_float_float_true__avx2_fma( const int64_t data_size, const float* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool normalize_by_lengths, float* out) { @@ -883,7 +883,7 @@ static bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_half_float__avx2_fma( const int64_t data_size, const at::Half* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool normalize_by_lengths, float* out) { @@ -1387,7 +1387,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_half_float_false__avx2_fma( const int64_t data_size, const at::Half* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool normalize_by_lengths, float* out) { @@ -1410,7 +1410,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_half_float_true__avx2_fma( const int64_t data_size, const at::Half* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool normalize_by_lengths, float* out) { @@ -1987,7 +1987,7 @@ static bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( const int64_t data_size, const uint8_t* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool normalize_by_lengths, float* out) { @@ -2514,7 +2514,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_uint8_t_float_false__avx2_fma( const int64_t data_size, const uint8_t* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool normalize_by_lengths, float* out) { @@ -2538,7 +2538,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_uint8_t_float_true__avx2_fma( const int64_t data_size, const uint8_t* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, bool normalize_by_lengths, float* out) { diff --git a/caffe2/perfkernels/embedding_lookup_idx.cc b/caffe2/perfkernels/embedding_lookup_idx.cc index 1216c6b77cdf7..d9b2f0627bc4a 100644 --- a/caffe2/perfkernels/embedding_lookup_idx.cc +++ b/caffe2/perfkernels/embedding_lookup_idx.cc @@ -23,7 +23,7 @@ static bool EmbeddingLookupGenericSlowIdx( const int64_t data_size, const InType* input, const IndexType* indices, - const int64_t* offsets, + const IndexType* offsets, const float* weights, // optional, can be null for sum reducer const float* scale_bias, // optional scale & bias params for uint8 input bool normalize_by_lengths, @@ -85,7 +85,7 @@ static bool EmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const InType* input, \ const IndexType* indices, \ - const int64_t* offsets, \ + const IndexType* offsets, \ const float* weights, \ const float* scale_bias, \ bool normalize_by_lengths, \ @@ -118,7 +118,7 @@ static bool EmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const InType* input, \ const IndexType* indices, \ - const int64_t* offsets, \ + const IndexType* offsets, \ const float* weights, \ const float* scale_bias, \ bool normalize_by_lengths, \ @@ -163,7 +163,7 @@ static bool EmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const InType* input, \ const IndexType* indices, \ - const int64_t* offsets, \ + const IndexType* offsets, \ const float* weights, \ const float* scale_bias, \ bool normalize_by_lengths, \ diff --git a/caffe2/perfkernels/embedding_lookup_idx.h b/caffe2/perfkernels/embedding_lookup_idx.h index 9092b2759357d..67573fb21fa30 100644 --- a/caffe2/perfkernels/embedding_lookup_idx.h +++ b/caffe2/perfkernels/embedding_lookup_idx.h @@ -48,7 +48,7 @@ void EmbeddingLookupIdx( const std::int64_t data_size, const InType* input, const IndexType* indices, - const int64_t* offsets, + const IndexType* offsets, const float* weights, // optional, can be null for non-weighted sum const float* scale_bias, // optional scale & bias params for uint8 input bool normalize_by_lengths, diff --git a/caffe2/perfkernels/embedding_lookup_idx_avx2.cc b/caffe2/perfkernels/embedding_lookup_idx_avx2.cc index eb61353866c70..329598b84d4d6 100644 --- a/caffe2/perfkernels/embedding_lookup_idx_avx2.cc +++ b/caffe2/perfkernels/embedding_lookup_idx_avx2.cc @@ -17,7 +17,7 @@ static bool EmbeddingLookupIdx_int32_t_float_float__avx2_fma( const int64_t data_size, const float* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, @@ -402,7 +402,7 @@ bool EmbeddingLookupIdx_int32_t_float_float_false__avx2_fma( const int64_t data_size, const float* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, @@ -427,7 +427,7 @@ bool EmbeddingLookupIdx_int32_t_float_float_true__avx2_fma( const int64_t data_size, const float* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, @@ -891,7 +891,7 @@ static bool EmbeddingLookupIdx_int32_t_half_float__avx2_fma( const int64_t data_size, const at::Half* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, @@ -1396,7 +1396,7 @@ bool EmbeddingLookupIdx_int32_t_half_float_false__avx2_fma( const int64_t data_size, const at::Half* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, @@ -1421,7 +1421,7 @@ bool EmbeddingLookupIdx_int32_t_half_float_true__avx2_fma( const int64_t data_size, const at::Half* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, @@ -2005,7 +2005,7 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( const int64_t data_size, const uint8_t* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, @@ -2523,7 +2523,7 @@ bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__avx2_fma( const int64_t data_size, const uint8_t* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, @@ -2548,7 +2548,7 @@ bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__avx2_fma( const int64_t data_size, const uint8_t* input, const int* indices, - const int64_t* offsets, + const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc index 28972e4f49a12..99a41d742d1f7 100644 --- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc +++ b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc @@ -22,7 +22,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx( const int64_t data_size, const InType* input, const IndexType* indices, - const int64_t* offsets, + const IndexType* offsets, const float* weights, // optional, can be null for sum reducer bool normalize_by_lengths, OutType* out) { @@ -88,7 +88,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const uint8_t* input, \ const IndexType* indices, \ - const int64_t* offsets, \ + const IndexType* offsets, \ const float* weights, \ bool normalize_by_lengths, \ OutType* out) { \ @@ -118,7 +118,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const uint8_t* input, \ const IndexType* indices, \ - const int64_t* offsets, \ + const IndexType* offsets, \ const float* weights, \ bool normalize_by_lengths, \ OutType* out) { \ @@ -160,7 +160,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const uint8_t* input, \ const IndexType* indices, \ - const int64_t* offsets, \ + const IndexType* offsets, \ const float* weights, \ bool normalize_by_lengths, \ OutType* out) { \ diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h index 9970c8671d0af..f7422bd7b7522 100644 --- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h +++ b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h @@ -50,7 +50,7 @@ void Fused8BitRowwiseEmbeddingLookupIdx( const std::int64_t data_size, const InType* input, const IndexType* indices, - const int64_t* offsets, + const IndexType* offsets, const float* weights, // optional, can be null for non-weighted sum bool normalize_by_lengths, OutType* out); diff --git a/caffe2/perfkernels/hp_emblookup_codegen.py b/caffe2/perfkernels/hp_emblookup_codegen.py index 75b0c8b583bea..402f3bb92a415 100644 --- a/caffe2/perfkernels/hp_emblookup_codegen.py +++ b/caffe2/perfkernels/hp_emblookup_codegen.py @@ -450,7 +450,7 @@ def compute(InType, use_weights, isa): args.append(" const " + InType + "* input,") args.append(" const " + IndexType + "* indices,") if opts.use_offsets: - args.append(" const int64_t* offsets,") + args.append(" const " + IndexType + "* offsets,") else: args.append(" const int* lengths,") args.append(" const float* weights,") diff --git a/caffe2/predictor/emulator/data_filler.h b/caffe2/predictor/emulator/data_filler.h index b893a18e56a78..e3021f624e450 100644 --- a/caffe2/predictor/emulator/data_filler.h +++ b/caffe2/predictor/emulator/data_filler.h @@ -144,7 +144,7 @@ class TestDataRandomFiller : public DataRandomFiller { }; // Convenient helpers to fill data to workspace. -CAFFE2_API void fillRandomNetworkInputs( +TORCH_API void fillRandomNetworkInputs( const NetDef& net, const std::vector>>& inputDims, const std::vector>& inputTypes, diff --git a/caffe2/predictor/predictor.h b/caffe2/predictor/predictor.h index fd16eb5934fe2..f49de2062cd9e 100644 --- a/caffe2/predictor/predictor.h +++ b/caffe2/predictor/predictor.h @@ -7,7 +7,7 @@ namespace caffe2 { -class CAFFE2_API Predictor { +class TORCH_API Predictor { public: using TensorList = std::vector; using TensorMap = std::unordered_map; diff --git a/caffe2/predictor/predictor_config.h b/caffe2/predictor/predictor_config.h index 243729b044e9a..ad3711e3873ce 100644 --- a/caffe2/predictor/predictor_config.h +++ b/caffe2/predictor/predictor_config.h @@ -17,7 +17,7 @@ using PredictorParameters = std::map>; /** * Stores parameters nessasary for creating a PredictorInterface object. */ -struct CAFFE2_API PredictorConfig { +struct TORCH_API PredictorConfig { // A map of parameter name to Tensor object. Predictor is supposed to // guarantee constness of all these Tensor objects. std::shared_ptr parameters; @@ -41,14 +41,14 @@ struct CAFFE2_API PredictorConfig { std::shared_ptr ws; }; -CAFFE2_API Workspace makeWorkspace(std::shared_ptr parameters); +TORCH_API Workspace makeWorkspace(std::shared_ptr parameters); -CAFFE2_API PredictorConfig makePredictorConfig( +TORCH_API PredictorConfig makePredictorConfig( const MetaNetDef& net, Workspace* parent = nullptr, bool run_init = true); -CAFFE2_API PredictorConfig makePredictorConfig( +TORCH_API PredictorConfig makePredictorConfig( const NetDef& init_net, const NetDef& run_net, Workspace* parent = nullptr, diff --git a/caffe2/predictor/predictor_test.cc b/caffe2/predictor/predictor_test.cc index 9e3bf8ceab93a..5448614a832d5 100644 --- a/caffe2/predictor/predictor_test.cc +++ b/caffe2/predictor/predictor_test.cc @@ -187,7 +187,7 @@ TEST_F(PredictorTest, SimpleBatchSized) { EXPECT_EQ(output.front().sizes().size(), 2); EXPECT_EQ(output.front().size(0), 1); EXPECT_EQ(output.front().size(1), 10); - EXPECT_NEAR(output.front().data()[4], 0.1209, 1E-4); + EXPECT_NEAR(output.front().data()[4], 4.9556, 1E-4); } TEST_F(PredictorTest, SimpleBatchSizedMapInput) { @@ -202,7 +202,7 @@ TEST_F(PredictorTest, SimpleBatchSizedMapInput) { EXPECT_EQ(output.front().sizes().size(), 2); EXPECT_EQ(output.front().size(0), 1); EXPECT_EQ(output.front().size(1), 10); - EXPECT_NEAR(output.front().data()[4], 0.1209, 1E-4); + EXPECT_NEAR(output.front().data()[4], 4.9556, 1E-4); } } // namespace caffe2 diff --git a/caffe2/predictor/predictor_utils.cc b/caffe2/predictor/predictor_utils.cc index e38d51d5f3d26..44b28688a7e8f 100644 --- a/caffe2/predictor/predictor_utils.cc +++ b/caffe2/predictor/predictor_utils.cc @@ -9,7 +9,7 @@ namespace caffe2 { namespace predictor_utils { -CAFFE2_API const NetDef& getNet( +TORCH_API const NetDef& getNet( const MetaNetDef& def, const std::string& name) { for (const auto& n : def.nets()) { diff --git a/caffe2/predictor/predictor_utils.h b/caffe2/predictor/predictor_utils.h index 8c9cb4a5792d4..e7405e68f9995 100644 --- a/caffe2/predictor/predictor_utils.h +++ b/caffe2/predictor/predictor_utils.h @@ -8,18 +8,18 @@ namespace caffe2 { namespace predictor_utils { -CAFFE2_API const NetDef& getNet(const MetaNetDef& def, const std::string& name); +TORCH_API const NetDef& getNet(const MetaNetDef& def, const std::string& name); const ::google::protobuf::RepeatedPtrField<::std::string>& getBlobs( const MetaNetDef& def, const std::string& name); -CAFFE2_API std::unique_ptr extractMetaNetDef( +TORCH_API std::unique_ptr extractMetaNetDef( db::Cursor* cursor, const std::string& key); // Extract the MetaNetDef from `db`, and run the global init net on the // `master` workspace. -CAFFE2_API std::unique_ptr runGlobalInitialization( +TORCH_API std::unique_ptr runGlobalInitialization( std::unique_ptr db, Workspace* master); diff --git a/caffe2/predictor/transforms.cc b/caffe2/predictor/transforms.cc index 03653d8ea2a44..72a6098b7e959 100644 --- a/caffe2/predictor/transforms.cc +++ b/caffe2/predictor/transforms.cc @@ -90,7 +90,7 @@ void RenameOutputs( void RenameInputsInChildren( const string& from, const string& to, - std::shared_ptr net, + caffe2::NetDef* net, std::unordered_map>& children) { VLOG(2) << "RenameInputsInChildren (from=" << from << ", to=" << to << ")"; if (children.count(from) == 0) { @@ -106,7 +106,7 @@ void RenameInputsInChildren( void RenameOutputInParents( const std::string& from, const std::string& to, - std::shared_ptr net, + caffe2::NetDef* net, std::unordered_map>& parents) { VLOG(2) << "RenameOutputInParents (from=" << from << ", to=" << to << ")"; if (parents.count(from) == 0) { @@ -225,7 +225,13 @@ bool FoundOpCandidate( // extra complexity is handled in FoundOpCandidate. void RemoveOpsByType(InferenceGraph& graph, const std::string& op_type) { int num_removed = 0; - std::shared_ptr net = graph.predict_net_def; + NetDef* net = graph.predict_net_def.get(); + for (auto& op : net->op()) { + if (op.type() == "RecurrentNetwork") { + LOG(INFO) << "RemoveOpsByType does not support RecurrentNetwork yet"; + return; + } + } std::unordered_set inputs( graph.input_names.begin(), graph.input_names.end()); @@ -239,7 +245,7 @@ void RemoveOpsByType(InferenceGraph& graph, const std::string& op_type) { for (const auto& o : graph.output_names) { net->add_external_output(o); } - onnx::SsaRewrite(nullptr, net.get()); + onnx::SsaRewrite(nullptr, net); // clear external_outputs net->mutable_external_output()->Clear(); graph.predictor_net_ssa_rewritten = true; diff --git a/caffe2/proto/CMakeLists.txt b/caffe2/proto/CMakeLists.txt index 9dc4b4a86cb78..ba6b696dde4ba 100644 --- a/caffe2/proto/CMakeLists.txt +++ b/caffe2/proto/CMakeLists.txt @@ -10,14 +10,14 @@ add_library(Caffe2_PROTO OBJECT ${Caffe2_PROTO_HEADERS} ${Caffe2_PROTO_SRCS}) if(MSVC) if(BUILD_SHARED_LIBS) - set(Caffe2_API_DEFINE "-DCAFFE2_API=__declspec(dllexport)") + set(TORCH_API_DEFINE "-DTORCH_API=__declspec(dllexport)") else() - set(Caffe2_API_DEFINE "-DCAFFE2_API=") + set(TORCH_API_DEFINE "-DTORCH_API=") endif() else() - set(Caffe2_API_DEFINE "-DCAFFE2_API=") + set(TORCH_API_DEFINE "-DTORCH_API=") endif() target_compile_definitions( - Caffe2_PROTO PRIVATE ${Caffe2_API_DEFINE}) + Caffe2_PROTO PRIVATE ${TORCH_API_DEFINE}) install(FILES ${Caffe2_PROTO_HEADERS} DESTINATION include/caffe2/proto) diff --git a/caffe2/proto/caffe2.proto b/caffe2/proto/caffe2.proto index 16aa33900efd9..76acbf201f716 100644 --- a/caffe2/proto/caffe2.proto +++ b/caffe2/proto/caffe2.proto @@ -186,6 +186,9 @@ message TensorBoundShape { } repeated DimType dim_type = 2; // dim_type.size() == shape.dims.size() optional string name = 3; + // a flag to indicate whether the shape is final and cannot be changed + // eg: input/output of in-place ops + optional bool shape_is_final = 4; } message TensorBoundShapes { @@ -236,7 +239,6 @@ enum DeviceTypeProto { PROTO_XLA = 9; // XLA / TPU // Change the following number if you add more devices in the code. PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 10; - PROTO_ONLY_FOR_TEST = 20901; // This device type is only for test. } // Device-specific options. We do not distinguish DeviceOption protos for diff --git a/caffe2/proto/caffe2_pb.h b/caffe2/proto/caffe2_pb.h index 6bd886b559c18..fc8acab2d62ab 100644 --- a/caffe2/proto/caffe2_pb.h +++ b/caffe2/proto/caffe2_pb.h @@ -15,9 +15,8 @@ constexpr DeviceType IDEEP = DeviceType::IDEEP; constexpr DeviceType HIP = DeviceType::HIP; constexpr DeviceType COMPILE_TIME_MAX_DEVICE_TYPES = DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES; -constexpr DeviceType ONLY_FOR_TEST = DeviceType::ONLY_FOR_TEST; -inline CAFFE2_API DeviceType ProtoToType(const caffe2::DeviceTypeProto p) { +inline TORCH_API DeviceType ProtoToType(const caffe2::DeviceTypeProto p) { switch (p) { case caffe2::PROTO_CPU: return DeviceType::CPU; @@ -35,8 +34,6 @@ inline CAFFE2_API DeviceType ProtoToType(const caffe2::DeviceTypeProto p) { return DeviceType::HIP; case caffe2::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES: return DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES; - case caffe2::PROTO_ONLY_FOR_TEST: - return DeviceType::ONLY_FOR_TEST; default: AT_ERROR( "Unknown device:", @@ -47,11 +44,11 @@ inline CAFFE2_API DeviceType ProtoToType(const caffe2::DeviceTypeProto p) { } } -inline CAFFE2_API DeviceType ProtoToType(int p) { +inline TORCH_API DeviceType ProtoToType(int p) { return ProtoToType(static_cast(p)); } -inline CAFFE2_API DeviceTypeProto TypeToProto(const DeviceType& t) { +inline TORCH_API DeviceTypeProto TypeToProto(const DeviceType& t) { switch (t) { case DeviceType::CPU: return caffe2::PROTO_CPU; @@ -69,8 +66,6 @@ inline CAFFE2_API DeviceTypeProto TypeToProto(const DeviceType& t) { return caffe2::PROTO_HIP; case DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES: return caffe2::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES; - case DeviceType::ONLY_FOR_TEST: - return caffe2::PROTO_ONLY_FOR_TEST; default: AT_ERROR( "Unknown device:", @@ -81,7 +76,7 @@ inline CAFFE2_API DeviceTypeProto TypeToProto(const DeviceType& t) { } } -inline CAFFE2_API caffe2::DeviceOption DeviceToOption( +inline TORCH_API caffe2::DeviceOption DeviceToOption( const at::Device& device) { caffe2::DeviceOption option; auto type = device.type(); @@ -102,7 +97,6 @@ inline CAFFE2_API caffe2::DeviceOption DeviceToOption( case DeviceType::MKLDNN: case DeviceType::IDEEP: case DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES: - case DeviceType::ONLY_FOR_TEST: break; default: AT_ERROR( @@ -115,7 +109,7 @@ inline CAFFE2_API caffe2::DeviceOption DeviceToOption( return option; } -inline CAFFE2_API at::Device OptionToDevice(const caffe2::DeviceOption option) { +inline TORCH_API at::Device OptionToDevice(const caffe2::DeviceOption option) { auto type = option.device_type(); int32_t id = -1; switch (type) { diff --git a/caffe2/proto/caffe2_pb2.pyi b/caffe2/proto/caffe2_pb2.pyi new file mode 100644 index 0000000000000..060f60fc6c884 --- /dev/null +++ b/caffe2/proto/caffe2_pb2.pyi @@ -0,0 +1,18 @@ + +# Defined in caffe2/proto/caffe2_pb2.h +class DeviceType: + ... + +CPU: DeviceType = ... +CUDA: DeviceType = ... +OPENGL: DeviceType = ... +OPENCL: DeviceType = ... +MKLDNN: DeviceType = ... +IDEEP: DeviceType = ... +HIP: DeviceType = ... + +class NetDef: + ... + +class OperatorDef: + ... \ No newline at end of file diff --git a/caffe2/python/__init__.py b/caffe2/python/__init__.py index 8582eff9ce19b..1262f97588ad6 100644 --- a/caffe2/python/__init__.py +++ b/caffe2/python/__init__.py @@ -12,7 +12,6 @@ caffe2_pb2.IDEEP = caffe2_pb2.PROTO_IDEEP caffe2_pb2.HIP = caffe2_pb2.PROTO_HIP caffe2_pb2.COMPILE_TIME_MAX_DEVICE_TYPES = caffe2_pb2.PROTO_COMPILE_TIME_MAX_DEVICE_TYPES -caffe2_pb2.ONLY_FOR_TEST = caffe2_pb2.PROTO_ONLY_FOR_TEST if platform.system() == 'Windows': is_conda = os.path.exists(os.path.join(sys.prefix, 'conda-meta')) diff --git a/caffe2/python/_import_c_extension.py b/caffe2/python/_import_c_extension.py index d6754adc20fd5..32b9ec34d1f83 100644 --- a/caffe2/python/_import_c_extension.py +++ b/caffe2/python/_import_c_extension.py @@ -5,16 +5,6 @@ import sys from caffe2.python import extension_loader -# NOTE: we have to import python protobuf here **before** we load cpp extension. -# Otherwise it breaks under certain build conditions if cpp implementation of -# protobuf is used. Presumably there's some registry in protobuf library and -# python side has to initialize the dictionary first, before static -# initialization in python extension does so. Otherwise, duplicated protobuf -# descriptors will be created and it can lead to obscure errors like -# "Parameter to MergeFrom() must be instance of same class: -# expected caffe2.NetDef got caffe2.NetDef." -import caffe2.proto - # We will first try to load the gpu-enabled caffe2. If it fails, we will then # attempt to load the cpu version. The cpu backend is the minimum required, so # if that still fails, we will exit loud. diff --git a/caffe2/python/benchmarks/concat_benchmark.py b/caffe2/python/benchmarks/concat_benchmark.py new file mode 100644 index 0000000000000..d32def6841c3e --- /dev/null +++ b/caffe2/python/benchmarks/concat_benchmark.py @@ -0,0 +1,31 @@ +import argparse + +import numpy as np +from caffe2.python import core, workspace + + +def benchmark_concat(num_inputs, input_dim, axis, add_axis, iterations): + input_names = [f"input{i}" for i in range(num_inputs)] + for n in input_names: + workspace.FeedBlob(n, np.random.randn(*input_dim).astype(np.float32)) + + net = core.Net("benchmark_net") + net.Concat(input_names, ["output", "split_info"], axis=axis, add_axis=add_axis) + workspace.CreateNet(net) + + runtimes = workspace.BenchmarkNet(net.Name(), 1, iterations, True) + print(f"{num_inputs * np.prod(input_dim) * 4 / runtimes[1] / 1e6} GB/s") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="minimal benchmark for concat.") + parser.add_argument("--num_inputs", type=int, default=2) + parser.add_argument("--input_dim", nargs="+", type=int, required=True) + parser.add_argument("--axis", type=int, default=-1) + parser.add_argument("--add_axis", type=int, default=0) + parser.add_argument("--iterations", type=int, default=64) + args, extra_args = parser.parse_known_args() + core.GlobalInit(["python"] + extra_args) + benchmark_concat( + args.num_inputs, args.input_dim, args.axis, args.add_axis, args.iterations + ) diff --git a/caffe2/python/benchmarks/sparse_lengths_sum_nbit_benchmark.py b/caffe2/python/benchmarks/sparse_lengths_sum_nbit_benchmark.py index 1b683be0d51ef..b4cb8f2da0b4e 100644 --- a/caffe2/python/benchmarks/sparse_lengths_sum_nbit_benchmark.py +++ b/caffe2/python/benchmarks/sparse_lengths_sum_nbit_benchmark.py @@ -5,7 +5,7 @@ import hypothesis.strategies as st import numpy as np -from caffe2.python import core, dyndep, workspace +from caffe2.python import core, workspace def benchmark_sparse_lengths_sum( diff --git a/caffe2/python/benchmarks/sparse_normalize_benchmark.py b/caffe2/python/benchmarks/sparse_normalize_benchmark.py new file mode 100644 index 0000000000000..91bb3a3448662 --- /dev/null +++ b/caffe2/python/benchmarks/sparse_normalize_benchmark.py @@ -0,0 +1,121 @@ +import argparse +import datetime + +# import hypothesis.strategies as st +import numpy as np +from caffe2.python import core, workspace + + +def benchmark_sparse_normalize( + categorical_limit, + embedding_size, + average_len, + batch_size, + iterations, + flush_cache, + fp16, +): + print("Preparing lookup table. " + str(datetime.datetime.now())) + + # We will use a constant, but non-trivial value so we save initialization + # time. + data = np.ones([categorical_limit, embedding_size], dtype=np.float32) + data *= 17.01 + + init_net = core.Net("init_net") + if fp16: + op = core.CreateOperator("FloatToHalf", "X", "X_fp16") + init_net.Proto().op.extend([op]) + l3_cache_size = 30 * 2 ** 20 // 4 + + # In order to produce truly random lengths and indices, we will embed a + # Python operator in the net to generate them. + def f(_, outputs): + lengths = np.random.randint( + int(average_len * 0.75), int(average_len * 1.25), batch_size + ).astype(np.int32) + indices = np.random.randint(0, categorical_limit, np.sum(lengths)).astype( + np.int64 + ) + outputs[0].feed(indices) + + workspace.FeedBlob("X", data) + workspace.FeedBlob("huge_blob", np.random.randn(l3_cache_size).astype(np.float32)) + + print("Data has shape {} {}".format(data.shape, datetime.datetime.now())) + + init_net.Python(f)([], ["indices"]) + workspace.RunNetOnce(init_net) + + net = core.Net("mynet") + op = core.CreateOperator( + "Float16SparseNormalize" if fp16 else "SparseNormalize", + ["X_fp16", "indices"] if fp16 else ["X", "indices"], + "X_fp16" if fp16 else "X", + ) + net.Proto().external_input.append("X") + net.Proto().external_input.append("X_fp16") + net.Proto().external_input.append("indices") + net.Proto().op.extend([op]) + if flush_cache: + net.Scale("huge_blob", "huge_blob_2x", value=2.0) + + workspace.CreateNet(net) + + # Set random seed, so that repeated runs will keep the same sequence of + # random indices. + np.random.seed(1701) + + print("Preparation finished. " + str(datetime.datetime.now())) + + runtimes = workspace.BenchmarkNet(net.Name(), 1, iterations, True) + + print("{} ms".format(runtimes[2 if flush_cache else 1])) + print("indice_size: " + str(workspace.FetchBlob("indices").size)) + print( + "{} GB/sec".format( + (2 if fp16 else 4) + * embedding_size + * workspace.FetchBlob("indices").size + / runtimes[2 if flush_cache else 1] + / 1e6 + ) + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="minimal benchmark for sparse lengths sum." + ) + parser.add_argument( + "-e", "--embedding-size", type=int, default=600000, help="Lookup table size." + ) + parser.add_argument( + "--embedding-dim", type=int, default=128, help="Embedding dimension." + ) + parser.add_argument( + "--average-len", + type=int, + default=27, + help="Sparse feature average lengths, default is 27", + ) + parser.add_argument("--batch_size", type=int, default=100, help="The batch size.") + parser.add_argument( + "-i", "--iteration", type=int, default=100, help="The number of iterations." + ) + parser.add_argument( + "--flush-cache", action="store_true", help="If true, flush cache" + ) + parser.add_argument("--fp16", action="store_true", help="If true, use fp16") + args, extra_args = parser.parse_known_args() + core.GlobalInit(["python"] + extra_args) + + benchmark_sparse_normalize( + args.embedding_size, + args.embedding_dim, + args.average_len, + args.batch_size, + args.iteration, + args.flush_cache, + args.fp16, + ) diff --git a/caffe2/python/brew.py b/caffe2/python/brew.py index 0e050ec32c442..69a4561aae100 100644 --- a/caffe2/python/brew.py +++ b/caffe2/python/brew.py @@ -24,6 +24,7 @@ from caffe2.python.helpers.nonlinearity import * from caffe2.python.helpers.normalization import * from caffe2.python.helpers.pooling import * +from caffe2.python.helpers.quantization import * from caffe2.python.helpers.tools import * from caffe2.python.helpers.train import * @@ -52,6 +53,9 @@ class HelperWrapper(object): 'concat': concat, 'depth_concat': depth_concat, 'sum': sum, + 'reduce_sum': reduce_sum, + 'sub': sub, + 'arg_min': arg_min, 'transpose': transpose, 'iter': iter, 'accuracy': accuracy, @@ -65,10 +69,13 @@ class HelperWrapper(object): 'add_weight_decay': add_weight_decay, 'elementwise_linear': elementwise_linear, 'layer_norm': layer_norm, + 'mat_mul' : mat_mul, 'batch_mat_mul' : batch_mat_mul, 'cond' : cond, 'loop' : loop, 'db_input' : db_input, + 'fused_8bit_rowwise_quantized_to_float' : fused_8bit_rowwise_quantized_to_float, + 'sparse_lengths_sum_4bit_rowwise_sparse': sparse_lengths_sum_4bit_rowwise_sparse, } def __init__(self, wrapped): diff --git a/caffe2/python/checkpoint.py b/caffe2/python/checkpoint.py index 9d7797fc3adac..872a66c7bd1f7 100644 --- a/caffe2/python/checkpoint.py +++ b/caffe2/python/checkpoint.py @@ -22,8 +22,7 @@ -@context.define_context() -class Job(object): +class Job(context.Managed): """ A Job defines three TaskGroups: the `init_group`, the `epoch_group` and the `exit_group` which will be run by a JobRunner. @@ -97,11 +96,13 @@ def compile(self, session_class): self.exit_group = session_class.compile(self.exit_group) def __enter__(self): + super(Job, self).__enter__() self.epoch_group.__enter__() return self def __exit__(self, *args): self.epoch_group.__exit__() + super(Job, self).__exit__(*args) def add_stop_condition(self, output): if isinstance(output, core.BlobReference): diff --git a/caffe2/python/compatibility.py b/caffe2/python/compatibility.py deleted file mode 100644 index 9d615a3083337..0000000000000 --- a/caffe2/python/compatibility.py +++ /dev/null @@ -1,8 +0,0 @@ -from six import PY2, PY3 - -if PY2: - import collections - container_abcs = collections -elif PY3: - import collections.abc - container_abcs = collections.abc diff --git a/caffe2/python/context.py b/caffe2/python/context.py index 28815bb7f36b0..ce9b312855e67 100644 --- a/caffe2/python/context.py +++ b/caffe2/python/context.py @@ -1,19 +1,15 @@ ## @package context # Module caffe2.python.context - - - - +import inspect import threading -import six +import functools class _ContextInfo(object): - def __init__(self, cls, allow_default, arg_name): + def __init__(self, cls, allow_default): self.cls = cls self.allow_default = allow_default - self.arg_name = arg_name self._local_stack = threading.local() @property @@ -43,14 +39,10 @@ class _ContextRegistry(object): def __init__(self): self._ctxs = {} - def register(self, ctx_info): - assert isinstance(ctx_info, _ContextInfo) - assert (ctx_info.cls not in self._ctxs), ( - 'Context %s already registered' % ctx_info.cls) - self._ctxs[ctx_info.cls] = ctx_info - def get(self, cls): - assert cls in self._ctxs, 'Context %s not registered.' % cls + if cls not in self._ctxs: + assert issubclass(cls, Managed), "must be a context managed class, got {}".format(cls) + self._ctxs[cls] = _ContextInfo(cls, allow_default=issubclass(cls, DefaultManaged)) return self._ctxs[cls] @@ -62,62 +54,53 @@ def _context_registry(): return _CONTEXT_REGISTRY -def __enter__(self): - if self._prev_enter is not None: - self._prev_enter() - _context_registry().get(self._ctx_class).enter(self) - return self +def _get_managed_classes(obj): + return [ + cls for cls in inspect.getmro(obj.__class__) + if issubclass(cls, Managed) and cls != Managed and cls != DefaultManaged + ] -def __exit__(self, *args): - _context_registry().get(self._ctx_class).exit(self) - if self._prev_exit is not None: - self._prev_exit(*args) +class Managed(object): + """ + Managed makes the inheritted class a context managed class. -def __call__(self, func): - @six.wraps(func) - def wrapper(*args, **kwargs): - with self: - return func(*args, **kwargs) - return wrapper - - -@classmethod -def _current(cls, value=None, required=True): - return _get_active_context(cls, value, required) - - -class define_context(object): - def __init__(self, arg_name=None, allow_default=False): - self.arg_name = arg_name - self.allow_default = allow_default + class Foo(Managed): ... - def __call__(self, cls): - assert not hasattr(cls, '_ctx_class'), ( - '%s parent class (%s) already defines context.' % ( - cls, cls._ctx_class)) - cls._ctx_class = cls + with Foo() as f: + assert f == Foo.current() + """ - _context_registry().register( - _ContextInfo(cls, self.allow_default, self.arg_name) - ) + @classmethod + def current(cls, value=None, required=True): + ctx_info = _context_registry().get(cls) + if value is not None: + assert isinstance(value, cls), ( + 'Wrong context type. Expected: %s, got %s.' % (cls, type(value))) + return value + return ctx_info.get_active(required=required) - cls._prev_enter = cls.__enter__ if hasattr(cls, '__enter__') else None - cls._prev_exit = cls.__exit__ if hasattr(cls, '__exit__') else None + def __enter__(self): + for cls in _get_managed_classes(self): + _context_registry().get(cls).enter(self) + return self - cls.__enter__ = __enter__ - cls.__exit__ = __exit__ - cls.__call__ = __call__ - cls.current = _current + def __exit__(self, *args): + for cls in _get_managed_classes(self): + _context_registry().get(cls).exit(self) - return cls + def __call__(self, func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + with self: + return func(*args, **kwargs) + return wrapper -def _get_active_context(cls, val=None, required=True): - ctx_info = _context_registry().get(cls) - if val is not None: - assert isinstance(val, cls), ( - 'Wrong context type. Expected: %s, got %s.' % (cls, type(val))) - return val - return ctx_info.get_active(required=required) +class DefaultManaged(Managed): + """ + DefaultManaged is similar to Managed but if there is no parent when + current() is called it makes a new one. + """ + pass diff --git a/caffe2/python/context.pyi b/caffe2/python/context.pyi new file mode 100644 index 0000000000000..f1ff173baa71e --- /dev/null +++ b/caffe2/python/context.pyi @@ -0,0 +1,13 @@ +from typing import Optional, TypeVar, Type + +T = TypeVar('T') + +class Managed: + @classmethod + def current(cls: Type[T], value: Optional[T] = None, required: bool = True) -> T: ... + + def __call__(self, func: T) -> T: ... + + def __enter__(self: T) -> T: ... + +class DefaultManaged(Managed): ... diff --git a/caffe2/python/context_test.py b/caffe2/python/context_test.py index 6c259d326a194..0ca36e49ac807 100644 --- a/caffe2/python/context_test.py +++ b/caffe2/python/context_test.py @@ -7,8 +7,13 @@ from threading import Thread -@context.define_context() -class MyContext(object): +class MyContext(context.Managed): + pass + +class DefaultMyContext(context.DefaultManaged): + pass + +class ChildMyContext(MyContext): pass @@ -37,3 +42,26 @@ def testMultiThreaded(self): @MyContext() def testDecorator(self): self.assertIsNotNone(MyContext.current()) + + def testNonDefaultCurrent(self): + with self.assertRaises(AssertionError): + MyContext.current() + + ctx = MyContext() + self.assertEqual(MyContext.current(value=ctx), ctx) + + self.assertIsNone(MyContext.current(required=False)) + + def testDefaultCurrent(self): + self.assertIsInstance(DefaultMyContext.current(), DefaultMyContext) + + def testNestedContexts(self): + with MyContext() as ctx1: + with DefaultMyContext() as ctx2: + self.assertEqual(DefaultMyContext.current(), ctx2) + self.assertEqual(MyContext.current(), ctx1) + + def testChildClasses(self): + with ChildMyContext() as ctx: + self.assertEqual(ChildMyContext.current(), ctx) + self.assertEqual(MyContext.current(), ctx) diff --git a/caffe2/python/convert.py b/caffe2/python/convert.py index 18033661a69e1..b4b37811de107 100644 --- a/caffe2/python/convert.py +++ b/caffe2/python/convert.py @@ -5,6 +5,3 @@ -from caffe2.proto import caffe2_pb2, torch_pb2 - -import caffe2.python._import_c_extension as C diff --git a/caffe2/python/convert_test.py b/caffe2/python/convert_test.py index a1dc52aad2d9f..d9d82bf5e6c4a 100644 --- a/caffe2/python/convert_test.py +++ b/caffe2/python/convert_test.py @@ -3,10 +3,8 @@ -from caffe2.python import convert, workspace -from caffe2.proto import caffe2_pb2, torch_pb2 +from caffe2.python import workspace import unittest -import numpy as np class TestOperator(unittest.TestCase): def setUp(self): diff --git a/caffe2/python/core.py b/caffe2/python/core.py index 6d7c503e2c818..b0add7d39f625 100644 --- a/caffe2/python/core.py +++ b/caffe2/python/core.py @@ -1307,7 +1307,7 @@ def get_remapped_str(blob_str): def control_op_remap(op, prefix, blob_remap): net_arg_names = [] - if op.type == "If": + if op.type == "If" or op.type == "AsyncIf": net_arg_names = ['then_net', 'else_net'] else: net_arg_names = ['loop_net', 'cond_net'] @@ -1327,6 +1327,7 @@ def control_op_remap(op, prefix, blob_remap): 'RecurrentNetworkGradient': recurrent_network_op_remap, 'If': control_op_remap, 'While': control_op_remap, + 'AsyncIf': control_op_remap, } @@ -1475,13 +1476,12 @@ def __init__(self, name_or_proto): self._external_input_map.update(list(self._net.external_input)) # Set the next name index properly. - existing_names = set( - sum( - [list(op.input) for op in self._net.op], [] - ) + sum( - existing_outputs, [] - ) - ) + existing_names = set() + for op in self._net.op: + existing_names.update(list(op.input)) + for output in existing_outputs: + existing_names.update(output) + for outs in existing_outputs: self._op_outputs.update(outs) @@ -1506,6 +1506,10 @@ def __init__(self, name_or_proto): # make sure that this net name hasn't been used before self._net.name = Net._get_next_net_name(name) + # a map between prefix and ID for fast generation of blob names + self._next_blob_name_ids = {} + + def AppendNet(self, net, device_option=None): assert isinstance(net, Net) for i in net.Proto().external_input: @@ -1524,7 +1528,8 @@ def AppendNet(self, net, device_option=None): ops = net.Proto().op if device_option is not None: ops = [copy.deepcopy(op) for op in ops] - map(lambda x: x.device_option.CopyFrom(device_option), ops) + for op in ops: + op.device_option.CopyFrom(device_option) for op in ops: if op.type == "RecurrentNetwork": for arg in op.arg: @@ -1936,12 +1941,14 @@ def NextName(self, prefix=None, output_id=None): output_name = output_name_base if output_id is not None: output_name += ':' + str(output_id) - index = 2 + key = output_name + index = self._next_blob_name_ids.get(key, 2) while self.BlobIsDefined(str(ScopedBlobReference(output_name))): output_name = output_name_base + '_' + str(index) if output_id is not None: output_name += ':' + str(output_id) index += 1 + self._next_blob_name_ids[key] = index else: output_name = self._net.name + '_blob_' + str(self._next_name_index) self._next_name_index += 1 @@ -2024,10 +2031,14 @@ def AddArgument(self, arg_name, arg_value): def AddExternalInput(self, *inputs): assert len(inputs) > 0 refs = [] + input_name_set = set() for input in inputs: input_name = str(input) - assert str(input) not in self._external_input_map, ( - 'Net already contains an input named %s' % input_name) + assert ( + input_name not in self._external_input_map + and input_name not in input_name_set + ), ("Net already contains an input named %s" % input_name) + input_name_set.add(input_name) for input in inputs: input_name = str(input) self._net.external_input.extend([input_name]) @@ -2087,7 +2098,7 @@ def set_input_record(self, input_record): self._input_record = input_record for blob in self._input_record.field_blobs(): - if blob not in self.external_inputs: + if not self.is_external_input(blob): self.AddExternalInput(blob) return self._input_record @@ -2332,6 +2343,9 @@ def make_builder(t): ) def is_external_input(self, blob): + if self._recreate_lookup_tables: + self._RecreateLookupTables() + name = str(blob) return name in self._external_input_map diff --git a/caffe2/python/core_gradients_test.py b/caffe2/python/core_gradients_test.py index 3674b7aa45852..293eccca0dd4b 100644 --- a/caffe2/python/core_gradients_test.py +++ b/caffe2/python/core_gradients_test.py @@ -3,7 +3,6 @@ -from future.utils import bytes_to_native_str from hypothesis import given, settings import hypothesis.strategies as st import unittest diff --git a/caffe2/python/core_test.py b/caffe2/python/core_test.py index b0f5b11f0d1c0..523f14c6135f5 100644 --- a/caffe2/python/core_test.py +++ b/caffe2/python/core_test.py @@ -232,8 +232,44 @@ def test_mask_clone_update_external_list(self): "external output not matched", ) + def test_control_op_remap(self): + # Subnets under If/AsyncIf operators should get name remapping when cloned + n = core.Net("original") + then_net = core.Net("a") + then_net.FC(["inputA"], "fc_a") + else_net = core.Net("b") + else_net.FC(["inputB"], "fc_b") + n.If( + inputs=[], + outputs=[], + then_net=then_net.Proto(), + else_net=else_net.Proto(), + ) + copied = n.Clone("copied", blob_remap={"inputA": "inputX"}) + if_op = copied._net.op[0] + self.assertEqual(if_op.arg[0].n.op[0].input, ["inputX"]) + self.assertEqual(if_op.arg[1].n.op[0].input, ["inputB"]) + class TestExternalInputs(test_util.TestCase): + def testAddExternalInputShouldRaiseIfDuplicate(self): + net = core.Net("test") + net.AddExternalInput( + schema.Struct(("x", schema.Scalar(np.float))), + ) + with self.assertRaises(AssertionError): + net.AddExternalInput( + schema.Struct(("x", schema.Scalar(np.float))), + ) + + def testAddExternalInputShouldRaiseIfDuplicateInSameCall(self): + net = core.Net("test") + with self.assertRaises(AssertionError): + net.AddExternalInput( + schema.Struct(("x", schema.Scalar(np.float))), + schema.Struct(("x", schema.Scalar(np.float))), + ) + def testSetInputRecordWithBlobs(self): net = core.Net("test") record = schema.NewRecord(net, schema.Struct( diff --git a/caffe2/python/data_parallel_model.py b/caffe2/python/data_parallel_model.py index 95abb7159d42d..8537e1ee3cf19 100644 --- a/caffe2/python/data_parallel_model.py +++ b/caffe2/python/data_parallel_model.py @@ -1056,7 +1056,7 @@ def sumN(*dev_indices): """Create a Sum op for 2 or more blobs on different devices. Saves the result on the first device. - Arguments: + Args: dev_indices -- a list of device indices, which can be translated into CUDA identifiers with model._devices """ diff --git a/caffe2/python/dataio_test.py b/caffe2/python/dataio_test.py index 0c45fb50aed9c..ac1c72284fbfd 100644 --- a/caffe2/python/dataio_test.py +++ b/caffe2/python/dataio_test.py @@ -6,7 +6,6 @@ from caffe2.python.dataio import ( CompositeReader, CompositeReaderBuilder, - Reader, ReaderBuilder, ReaderWithDelay, ReaderWithLimit, @@ -29,7 +28,6 @@ import shutil import unittest import tempfile -import time def make_source_dataset(ws, size=100, offset=0, name=None): diff --git a/caffe2/python/experiment_util.py b/caffe2/python/experiment_util.py index 822a0a2950bac..6084312df84f7 100644 --- a/caffe2/python/experiment_util.py +++ b/caffe2/python/experiment_util.py @@ -10,7 +10,6 @@ import logging import socket import abc -import six from collections import OrderedDict from future.utils import viewkeys, viewvalues @@ -26,7 +25,7 @@ class ExternalLogger(object): - six.add_metaclass(abc.ABCMeta) + __metaclass__ = abc.ABCMeta @abc.abstractmethod def set_runtime_args(self, runtime_args): diff --git a/caffe2/python/fused_8bit_rowwise_conversion_ops_test.py b/caffe2/python/fused_8bit_rowwise_conversion_ops_test.py index a7e5d714b63c0..5058e73fe311f 100644 --- a/caffe2/python/fused_8bit_rowwise_conversion_ops_test.py +++ b/caffe2/python/fused_8bit_rowwise_conversion_ops_test.py @@ -30,7 +30,7 @@ def floats_to_bytes(floats): if isinstance(as_bytes[0], int): byte_matrix[i] = list(as_bytes) else: - byte_matrix[i] = list(map(ord, as_bytes)) + byte_matrix[i] = [ord(i) for i in as_bytes] return byte_matrix diff --git a/caffe2/python/gradient_checker.py b/caffe2/python/gradient_checker.py index afb8d50714929..5f116bd6107c4 100644 --- a/caffe2/python/gradient_checker.py +++ b/caffe2/python/gradient_checker.py @@ -5,6 +5,7 @@ +import os import numpy as np from caffe2.python import core, workspace, net_drawer @@ -292,8 +293,20 @@ def CheckSimple( if ensure_outputs_are_inferred: self._assertInferTensorChecks(op, grad_ops) + full_grad_check = os.getenv('CAFFE2_FULL_GRAD_CHECK') == '1' + dims_to_check = inputs[input_to_check].size for current_dim in range(dims_to_check): + # Grad check is very expensive (as it involves running the op from + # scratch for each of the input tensor elements). Thus, let's + # run it by default only on a small subset of dimensions. Here we + # apply very scientific approach: the first and the last 3 elements + # of each tensor. Pass CAFFE2_FULL_GRAD_CHECK=1 env var to enable + # the full check + if not full_grad_check and current_dim >= 3 and \ + current_dim + 3 < dims_to_check: + grad_estimate.flat[current_dim] = grad.flat[current_dim] + continue # Positive gradient inputs[input_to_check].flat[current_dim] += self._stepsize pos_loss, _ = self.GetLossAndGrad( diff --git a/caffe2/python/helpers/algebra.py b/caffe2/python/helpers/algebra.py index 948c55ac88ceb..2b626677b029c 100644 --- a/caffe2/python/helpers/algebra.py +++ b/caffe2/python/helpers/algebra.py @@ -18,9 +18,31 @@ def sum(model, blob_in, blob_out, **kwargs): return model.net.Sum(blob_in, blob_out, **kwargs) +def reduce_sum(model, blob_in, blob_out, **kwargs): + """ReduceSum""" + return model.net.ReduceSum(blob_in, blob_out, **kwargs) + + +def sub(model, blob_in, blob_out, **kwargs): + """Subtract""" + return model.net.Sub(blob_in, blob_out, **kwargs) + + +def mat_mul(model, blob_in, blob_out, **kwargs): + """Matrix multiplication""" + return model.net.MatMul(blob_in, blob_out, **kwargs) + + +def arg_min(model, blob_in, blob_out, **kwargs): + """ArgMin""" + return model.net.ArgMin(blob_in, blob_out, **kwargs) + def batch_mat_mul(model, blob_in, blob_out, enable_tensor_core=False, **kwargs): if enable_tensor_core: kwargs['engine'] = 'TENSORCORE' return model.net.BatchMatMul(blob_in, blob_out, **kwargs) + +def sparse_lengths_sum_4bit_rowwise_sparse(model, blob_in, blob_out, **kwargs): + return model.net.SparseLengthsSum4BitRowwiseSparse(blob_in, blob_out, **kwargs) diff --git a/caffe2/python/helpers/quantization.py b/caffe2/python/helpers/quantization.py new file mode 100644 index 0000000000000..4e7a6da32436d --- /dev/null +++ b/caffe2/python/helpers/quantization.py @@ -0,0 +1,9 @@ +# @package quantization +# Module caffe2.python.helpers.quantization + + +def fused_8bit_rowwise_quantized_to_float( + model, blob_in, blob_out +): + """Fused8BitRowwiseQuantizedToFloat""" + return model.net.Fused8BitRowwiseQuantizedToFloat(blob_in, blob_out) diff --git a/caffe2/python/hypothesis_test.py b/caffe2/python/hypothesis_test.py index 045677f8422a9..33e065779d6cf 100644 --- a/caffe2/python/hypothesis_test.py +++ b/caffe2/python/hypothesis_test.py @@ -1,7 +1,3 @@ - - - - import numpy as np import copy import time @@ -18,6 +14,21 @@ dyndep.InitOpsLibrary('@/caffe2/caffe2/fb/optimizers:sgd_simd_ops') +if workspace.has_gpu_support: + # NOTE: During GPU stress tests, the number of workers exceeds the number + # of GPUs which results in flakiness from GPU contention. As a + # result, deadlines are not enforced on CUDA runs. + _hypothesis_settings = settings + + def settings(**kwargs): + if 'deadline' in kwargs: + kwargs['deadline'] = None + kwargs.setdefault('max_examples', 50) + + def wrapped(f): + return _hypothesis_settings(**kwargs)(f) + return wrapped + def sigmoid(x): return 1.0 / (1.0 + np.exp(-x)) @@ -331,6 +342,7 @@ def prod(xs): @unittest.skipIf(not workspace.has_gpu_support, "Skipping test due to no gpu present.") + @settings(deadline=None) @given(hidden_size=st.integers(min_value=1, max_value=3), num_layers=st.integers(min_value=1, max_value=3), bidirectional=st.booleans(), @@ -994,6 +1006,38 @@ def op_ref(x): inputs=[np.array(lengths, dtype=np.int32)], reference=op_ref) + @given( + lengths=st.lists( + st.integers(min_value=0, max_value=10), min_size=0, max_size=10 + ), + include_last_offset=st.booleans(), + **hu.gcs_cpu_only + ) + @settings(deadline=None) + def test_lengths_to_offsets(self, lengths, include_last_offset, gc, dc): + op = core.CreateOperator( + "LengthsToOffsets", + ["lengths"], + ["ranges"], + include_last_offset=include_last_offset, + ) + + def op_ref(x): + if not x.size: + arr = [x.reshape(0)] + else: + arr = [np.concatenate(([0], np.cumsum(x)[:-1]))] + if include_last_offset: + arr[0] = np.concatenate((arr[0], np.array([np.sum(x)]))) + return tuple(arr) + + self.assertReferenceChecks( + device_option=gc, + op=op, + inputs=[np.array(lengths, dtype=np.int32)], + reference=op_ref, + ) + @given(prediction=hu.arrays(dims=[10, 3], elements=hu.floats(allow_nan=False, allow_infinity=False, @@ -2683,7 +2727,7 @@ def histogram(X): Y[X >= upper_bound] = num_buckets + 1 Y[(X >= lower_bound) & (X < upper_bound)] = \ ((X[(X >= lower_bound) & (X < upper_bound)] - lower_bound) / - segment + 1).astype(np.int32) + segment + 1).astype(np.int32) for i in range(Y.shape[0]): for j in range(Y.shape[1]): diff --git a/caffe2/python/hypothesis_test_util.py b/caffe2/python/hypothesis_test_util.py index 2000e269969ea..0fc489d772731 100644 --- a/caffe2/python/hypothesis_test_util.py +++ b/caffe2/python/hypothesis_test_util.py @@ -50,16 +50,11 @@ import logging import numpy as np import os -import six import struct def is_sandcastle(): - if os.getenv('SANDCASTLE') == '1': - return True - elif os.getenv('TW_JOB_USER') == 'sandcastle': - return True - return False + return os.getenv('SANDCASTLE') == '1' or os.getenv('TW_JOB_USER') == 'sandcastle' def is_travis(): @@ -113,7 +108,7 @@ def floats(*args, **kwargs): max_examples=50, min_satisfying_examples=1, verbosity=hypothesis.Verbosity.verbose, - deadline=1000)) + deadline=10000)) hypothesis.settings.register_profile( "dev", settings( @@ -121,7 +116,8 @@ def floats(*args, **kwargs): database=None, max_examples=10, min_satisfying_examples=1, - verbosity=hypothesis.Verbosity.verbose)) + verbosity=hypothesis.Verbosity.verbose, + deadline=10000)) hypothesis.settings.register_profile( "debug", settings( @@ -129,7 +125,8 @@ def floats(*args, **kwargs): database=None, max_examples=1000, min_satisfying_examples=1, - verbosity=hypothesis.Verbosity.verbose)) + verbosity=hypothesis.Verbosity.verbose, + deadline=50000)) hypothesis.settings.load_profile( 'sandcastle' if is_sandcastle() else os.getenv('CAFFE2_HYPOTHESIS_PROFILE', @@ -750,5 +747,5 @@ def assertRunOpRaises( if regexp is None: self.assertRaises(exception, workspace.RunOperatorOnce, op) else: - six.assertRaisesRegex( - self, exception, regexp, workspace.RunOperatorOnce, op) + self.assertRaisesRegex( + exception, regexp, workspace.RunOperatorOnce, op) diff --git a/caffe2/python/ideep/conv_op_test.py b/caffe2/python/ideep/conv_op_test.py index ae4473ea4864e..7c5a0026c113a 100644 --- a/caffe2/python/ideep/conv_op_test.py +++ b/caffe2/python/ideep/conv_op_test.py @@ -4,7 +4,6 @@ import unittest -import sys import hypothesis.strategies as st from hypothesis import given, settings import numpy as np diff --git a/caffe2/python/ideep/convfusion_op_test.py b/caffe2/python/ideep/convfusion_op_test.py index 18ce574b623b3..a0a782ab8a03e 100644 --- a/caffe2/python/ideep/convfusion_op_test.py +++ b/caffe2/python/ideep/convfusion_op_test.py @@ -5,8 +5,7 @@ import unittest import hypothesis.strategies as st -from hypothesis import given, settings -import copy +from hypothesis import given import numpy as np import math from caffe2.proto import caffe2_pb2 diff --git a/caffe2/python/ideep/dropout_op_test.py b/caffe2/python/ideep/dropout_op_test.py index 33b0a52a74216..5b07333758dd9 100644 --- a/caffe2/python/ideep/dropout_op_test.py +++ b/caffe2/python/ideep/dropout_op_test.py @@ -7,8 +7,6 @@ from hypothesis import given import hypothesis.strategies as st import numpy as np - -from caffe2.proto import caffe2_pb2 from caffe2.python import core, workspace import caffe2.python.hypothesis_test_util as hu import caffe2.python.ideep_test_util as mu diff --git a/caffe2/python/ideep/order_switch_op_test.py b/caffe2/python/ideep/order_switch_op_test.py index a259e01bab102..39ede0d214fe9 100644 --- a/caffe2/python/ideep/order_switch_op_test.py +++ b/caffe2/python/ideep/order_switch_op_test.py @@ -10,7 +10,6 @@ import caffe2.python.ideep_test_util as mu from hypothesis import given, settings -from caffe2.proto import caffe2_pb2 from caffe2.python import core, workspace diff --git a/caffe2/python/ideep/shape_op_test.py b/caffe2/python/ideep/shape_op_test.py index 47114832f85d3..1beb24bc88038 100644 --- a/caffe2/python/ideep/shape_op_test.py +++ b/caffe2/python/ideep/shape_op_test.py @@ -7,7 +7,6 @@ import hypothesis.strategies as st from hypothesis import given, settings import numpy as np -from caffe2.proto import caffe2_pb2 from caffe2.python import core, workspace import caffe2.python.hypothesis_test_util as hu import caffe2.python.ideep_test_util as mu diff --git a/caffe2/python/ideep/spatial_bn_op_test.py b/caffe2/python/ideep/spatial_bn_op_test.py index 618a0e7fbfc3a..97efafa720570 100644 --- a/caffe2/python/ideep/spatial_bn_op_test.py +++ b/caffe2/python/ideep/spatial_bn_op_test.py @@ -7,9 +7,8 @@ import hypothesis.strategies as st import numpy as np import unittest -from caffe2.python import brew, core, workspace +from caffe2.python import core, workspace import caffe2.python.hypothesis_test_util as hu -from caffe2.python.model_helper import ModelHelper import caffe2.python.ideep_test_util as mu diff --git a/caffe2/python/ideep/test_ideep_net.py b/caffe2/python/ideep/test_ideep_net.py index aa1c5bc260fa1..42feeed001220 100644 --- a/caffe2/python/ideep/test_ideep_net.py +++ b/caffe2/python/ideep/test_ideep_net.py @@ -9,7 +9,6 @@ import numpy as np import argparse import time -import os.path def GetArgumentParser(): diff --git a/caffe2/python/ideep/transform_ideep_net.py b/caffe2/python/ideep/transform_ideep_net.py index 962d4051718b0..2d0f35a7406f2 100644 --- a/caffe2/python/ideep/transform_ideep_net.py +++ b/caffe2/python/ideep/transform_ideep_net.py @@ -6,7 +6,6 @@ import argparse import copy import json -import os.path import numpy as np diff --git a/caffe2/python/ideep/transpose_op_test.py b/caffe2/python/ideep/transpose_op_test.py index 8b324ed964aeb..f8b784822a077 100644 --- a/caffe2/python/ideep/transpose_op_test.py +++ b/caffe2/python/ideep/transpose_op_test.py @@ -7,7 +7,6 @@ import hypothesis.strategies as st from hypothesis import given, settings import numpy as np -from caffe2.proto import caffe2_pb2 from caffe2.python import core, workspace import caffe2.python.hypothesis_test_util as hu import caffe2.python.ideep_test_util as mu diff --git a/caffe2/python/ideep_test_util.py b/caffe2/python/ideep_test_util.py index 7129ed14ba743..0cc643317c934 100644 --- a/caffe2/python/ideep_test_util.py +++ b/caffe2/python/ideep_test_util.py @@ -14,7 +14,6 @@ import hypothesis.strategies as st from caffe2.proto import caffe2_pb2 -from caffe2.python import workspace from caffe2.python import hypothesis_test_util as hu cpu_do = hu.cpu_do diff --git a/caffe2/python/layer_model_helper.py b/caffe2/python/layer_model_helper.py index 7c3dda3b320c0..6a5a3c82dd30c 100644 --- a/caffe2/python/layer_model_helper.py +++ b/caffe2/python/layer_model_helper.py @@ -17,12 +17,10 @@ from caffe2.python.optimizer import get_param_device, Optimizer from caffe2.python.regularizer import Regularizer, RegularizationBy from caffe2.python.layers import layers -from caffe2.proto import caffe2_pb2 from future.utils import viewitems, viewvalues import logging import numpy as np -import six import copy logger = logging.getLogger(__name__) @@ -125,7 +123,7 @@ def filter_metrics_schema(self, white_set): def add_ad_hoc_plot_blob(self, blob, dtype=None): assert isinstance( - blob, (six.string_types, core.BlobReference) + blob, (str, core.BlobReference) ), "expect type str or BlobReference, but got {}".format(type(blob)) dtype = dtype or (np.float, (1, )) self.add_metric_field(str(blob), schema.Scalar(dtype, blob)) @@ -173,7 +171,7 @@ def initializer(blob_name): def add_global_constant( self, name, array=None, dtype=None, initializer=None ): - assert isinstance(name, six.string_types), ( + assert isinstance(name, str), ( 'name should be a string as we are using it as map key') # This is global namescope for constants. They will be created in all # init_nets and there should be very few of them. @@ -310,7 +308,7 @@ def create_param(self, param_name, shape, initializer, optimizer=None, ps_param=None, regularizer=None): if isinstance(param_name, core.BlobReference): param_name = str(param_name) - elif isinstance(param_name, six.string_types): + elif isinstance(param_name, str): # Parameter name will be equal to current Namescope that got # resolved with the respect of parameter sharing of the scopes. param_name = parameter_sharing_context.get_parameter_name( @@ -750,6 +748,6 @@ def breakdown_map(self, breakdown_map): # TODO(xlwang): provide more rich feature information in breakdown_map; # and change the assertion accordingly assert isinstance(breakdown_map, dict) - assert all(isinstance(k, six.string_types) for k in breakdown_map) + assert all(isinstance(k, str) for k in breakdown_map) assert sorted(breakdown_map.values()) == list(range(len(breakdown_map))) self._breakdown_map = breakdown_map diff --git a/caffe2/python/layer_parameter_sharing_test.py b/caffe2/python/layer_parameter_sharing_test.py index 518412b9e90c9..8e1831a2ff350 100644 --- a/caffe2/python/layer_parameter_sharing_test.py +++ b/caffe2/python/layer_parameter_sharing_test.py @@ -9,7 +9,6 @@ ) from caffe2.python.optimizer import AdagradOptimizer, AdamOptimizer from caffe2.python.layer_test_util import LayersTestCase -import six class ParameterSharingTest(LayersTestCase): @@ -116,7 +115,7 @@ def test_layer_shared_parameter_name_different_shapes(self): self.assertEquals(self.model.layers[-1].w, 'global_scope/fc/w') - with six.assertRaisesRegex(self, ValueError, 'Got inconsistent shapes .*'): + with self.assertRaisesRegex(ValueError, 'Got inconsistent shapes .*'): self.model.FC( self.model.input_feature_schema.float_features, output_dims + 1 diff --git a/caffe2/python/layers/functional.py b/caffe2/python/layers/functional.py index c6d156fd68cec..bc47c474ac8f6 100644 --- a/caffe2/python/layers/functional.py +++ b/caffe2/python/layers/functional.py @@ -11,7 +11,6 @@ ) import caffe2.proto.caffe2_pb2 as caffe2_pb2 import numpy as np -import six import logging logger = logging.getLogger(__name__) @@ -31,7 +30,7 @@ def __init__(self, model, input_record, output_names_or_num, function, self._kwargs = kwargs return_struct = ( isinstance(output_names_or_num, list) or - (isinstance(output_names_or_num, six.integer_types) and + (isinstance(output_names_or_num, int) and output_names_or_num != 1) ) diff --git a/caffe2/python/layers/last_n_window_collector.py b/caffe2/python/layers/last_n_window_collector.py index a16b731a2f78f..5e6874b4cca03 100644 --- a/caffe2/python/layers/last_n_window_collector.py +++ b/caffe2/python/layers/last_n_window_collector.py @@ -1,10 +1,6 @@ ## @package last_n_window_collector # Module caffe2.python.layers.last_n_window_collector - - - - from caffe2.python import core, schema from caffe2.python.layers.layers import ModelLayer diff --git a/caffe2/python/layers/merge_id_lists.py b/caffe2/python/layers/merge_id_lists.py index 68c27b5875678..b076cd8c5e750 100644 --- a/caffe2/python/layers/merge_id_lists.py +++ b/caffe2/python/layers/merge_id_lists.py @@ -16,7 +16,7 @@ class MergeIdLists(ModelLayer): """Merge multiple ID_LISTs into a single ID_LIST - Arguments: + Args: model: A layer model instance input_record: Tuple (Struct) of ID_LIST features to be merged diff --git a/caffe2/python/layers/sampling_trainable_mixin.py b/caffe2/python/layers/sampling_trainable_mixin.py index 403cc5a4a51cf..79c928d212528 100644 --- a/caffe2/python/layers/sampling_trainable_mixin.py +++ b/caffe2/python/layers/sampling_trainable_mixin.py @@ -6,10 +6,9 @@ import abc -import six -class SamplingTrainableMixin(six.with_metaclass(abc.ABCMeta, object)): +class SamplingTrainableMixin(metaclass=abc.ABCMeta): def __init__(self, *args, **kwargs): super(SamplingTrainableMixin, self).__init__(*args, **kwargs) diff --git a/caffe2/python/layers/tags.py b/caffe2/python/layers/tags.py index 5161ee2e1a967..613fdbe8f45d9 100644 --- a/caffe2/python/layers/tags.py +++ b/caffe2/python/layers/tags.py @@ -5,13 +5,12 @@ -import six +import functools from caffe2.python import context -@context.define_context(allow_default=True) -class TagContext(object): +class TagContext(context.DefaultManaged): """ Scope driven way to provide tags to the layers. """ @@ -61,7 +60,7 @@ class Tags(object): COMPONENT = 'component:' PIPELINE = 'pipeline:' """ - Indicate it's a dense layer or dense param init, + Indicate it's a dense layer or dense param init, but we use hogwild across multiple trainers """ HOGWILD_DENSE = "hogwild_dense" @@ -105,7 +104,7 @@ def __exit__(self, type, value, traceback): TagContext.current().remove_tags(self.tags) def __call__(self, func): - @six.wraps(func) + @functools.wraps(func) def wrapper(*args, **kwargs): with self: return func(*args, **kwargs) diff --git a/caffe2/python/lazy_dyndep_test.py b/caffe2/python/lazy_dyndep_test.py index 1441facd3a6f7..d8d08b3ffdf99 100644 --- a/caffe2/python/lazy_dyndep_test.py +++ b/caffe2/python/lazy_dyndep_test.py @@ -60,7 +60,7 @@ class TestLazyDynDepAllCompare(hu.HypothesisTestCase): @given( d=st.integers(1, 5), n=st.integers(2, 11), num_procs=st.integers(1, 8) ) - @settings(deadline=10000) + @settings(deadline=None) def test_allcompare(self, d, n, num_procs): dims = [] for _ in range(d): diff --git a/caffe2/python/memonger_test.py b/caffe2/python/memonger_test.py index 8584e8d5e4cce..a54cf95302e51 100644 --- a/caffe2/python/memonger_test.py +++ b/caffe2/python/memonger_test.py @@ -1,8 +1,3 @@ - - - - - import numpy as np from caffe2.python import workspace, memonger, core, model_helper, brew @@ -456,6 +451,89 @@ def test_forward_optim_tree_harder(self, input_dim, output_dim, batch_size): np.testing.assert_almost_equal(loss1, optimized_loss1) np.testing.assert_almost_equal(loss2, optimized_loss2) + # This test reproduces scenario where dag traversal for finding + # shared blobs was not always starting from ops with in degree of 0 + @settings(deadline=10000) + def test_forward_optim_tree_dag_traversal(self): + input_dim = 4 + output_dim = 4 + batch_size = 4 + + m = model_helper.ModelHelper() + m.Proto().type = "dag" + m.Proto().num_workers = 4 + + with core.NameScope("name_x"): + fc1 = brew.fc(m, "data", "fc1", dim_in=input_dim, dim_out=output_dim) + fc2 = brew.fc(m, fc1, "fc2", dim_in=output_dim, dim_out=output_dim) + + fc3 = brew.fc(m, fc2, "fc3", dim_in=output_dim, dim_out=output_dim) + fc4 = brew.fc(m, fc3, "fc4", dim_in=output_dim, dim_out=output_dim) + fc5 = brew.fc(m, fc4, "fc5", dim_in=output_dim, dim_out=output_dim) + + # Branch + fc3b = brew.fc(m, fc2, "fc3b", dim_in=output_dim, dim_out=output_dim) + fc4b = brew.fc(m, fc3b, "fc4b", dim_in=output_dim, dim_out=output_dim) + fc5b = brew.fc(m, fc4b, "fc5b", dim_in=output_dim, dim_out=output_dim) + + fc5sum = brew.sum(m, [fc5, fc5b], "fc5sum") + + fc5.Relu([], fc5sum) \ + .Softmax([], "pred1") \ + .LabelCrossEntropy(["label"], ["xent1"]) \ + .AveragedLoss([], "loss1") + fc6 = brew.fc(m, fc5, "fc6", dim_in=output_dim, dim_out=output_dim) + fc6.Relu([], fc6) \ + .Softmax([], "pred2") \ + .LabelCrossEntropy(["label"], ["xent2"]) \ + .AveragedLoss([], "loss2") + + blobs_before = count_blobs(m.net.Proto()) + # adding name_x/fc5_w as heads (which belongs to non-root op) + # to make sure that dag traversal always starts from root ops + optim_proto = memonger.optimize_inference_for_dag( + m.net, ["name_x/fc5_w", "name_x/data"], "name_x" + ) + blobs_after = count_blobs(optim_proto) + self.assertLess(blobs_after, blobs_before) + + # This is specifically to verify the op schema check being done in memonger + def test_forward_optim_tree_enforce_inplace_op_invalid(self): + m = model_helper.ModelHelper() + m.Proto().type = "dag" + m.Proto().num_workers = 4 + + net = m.net + net.IndexFreeze("A", "B") # enforce inplace op + net.Sum(["B", "B"], "C") + net.Relu("C", "D") + net.Sum(["D", "D"], "E") + + with self.assertRaises(RuntimeError): + memonger.optimize_inference_for_dag(net, ["A"], "") + + # Here inplace op is specifically a root op to repro the scenario where dag + # memonger could treat all the output blobs as shareable blobs and fails + # assertion of input blob with the same name not allowed to share + def test_forward_optim_tree_enforce_inplace_op_valid_and_as_head(self): + m = model_helper.ModelHelper() + m.Proto().type = "dag" + m.Proto().num_workers = 4 + + net = m.net + net.IndexFreeze("A", "A") # enforce inplace op + net.Sum(["A", "A"], "B") + net.Relu("B", "C") + net.Relu("C", "D") + net.Sum(["D", "D"], "E") + + blobs_before = count_blobs(m.net.Proto()) + optim_proto = memonger.optimize_inference_for_dag( + net, ["A"], "" + ) + blobs_after = count_blobs(optim_proto) + self.assertLess(blobs_after, blobs_before) + def test_rnn(self): from caffe2.python import rnn_cell T = 5 diff --git a/caffe2/python/mkl/mkl_LRN_op_test.py b/caffe2/python/mkl/mkl_LRN_op_test.py index 2b084bea591b7..fddb20e6bb14d 100644 --- a/caffe2/python/mkl/mkl_LRN_op_test.py +++ b/caffe2/python/mkl/mkl_LRN_op_test.py @@ -5,7 +5,7 @@ import unittest import hypothesis.strategies as st -from hypothesis import given, settings +from hypothesis import given import numpy as np from caffe2.python import core, workspace import caffe2.python.hypothesis_test_util as hu diff --git a/caffe2/python/mkl/mkl_LRN_speed_test.py b/caffe2/python/mkl/mkl_LRN_speed_test.py index ae42902d91021..c192137dc28c9 100644 --- a/caffe2/python/mkl/mkl_LRN_speed_test.py +++ b/caffe2/python/mkl/mkl_LRN_speed_test.py @@ -6,7 +6,7 @@ import numpy as np from caffe2.proto import caffe2_pb2 -from caffe2.python import cnn, core, workspace, test_util +from caffe2.python import core, workspace, test_util @unittest.skipIf(not workspace.C.has_mkldnn, "Skipping as we do not have mkldnn.") diff --git a/caffe2/python/mkl/mkl_conv_op_test.py b/caffe2/python/mkl/mkl_conv_op_test.py index f1fe7b0623182..74c4f2c6cde99 100644 --- a/caffe2/python/mkl/mkl_conv_op_test.py +++ b/caffe2/python/mkl/mkl_conv_op_test.py @@ -5,7 +5,7 @@ import unittest import hypothesis.strategies as st -from hypothesis import given, settings +from hypothesis import given import numpy as np from caffe2.python import core, workspace import caffe2.python.hypothesis_test_util as hu diff --git a/caffe2/python/mkl/mkl_fc_op_test.py b/caffe2/python/mkl/mkl_fc_op_test.py index 01786d55c3376..180d93f265703 100644 --- a/caffe2/python/mkl/mkl_fc_op_test.py +++ b/caffe2/python/mkl/mkl_fc_op_test.py @@ -5,7 +5,7 @@ import unittest import hypothesis.strategies as st -from hypothesis import given, settings +from hypothesis import given import numpy as np from caffe2.python import core, workspace import caffe2.python.hypothesis_test_util as hu diff --git a/caffe2/python/mkl/mkl_fc_speed_test.py b/caffe2/python/mkl/mkl_fc_speed_test.py index 85f5605e96766..243e49c2f8f8b 100644 --- a/caffe2/python/mkl/mkl_fc_speed_test.py +++ b/caffe2/python/mkl/mkl_fc_speed_test.py @@ -6,7 +6,7 @@ import numpy as np from caffe2.proto import caffe2_pb2 -from caffe2.python import cnn, core, workspace, test_util +from caffe2.python import core, workspace, test_util @unittest.skipIf(not workspace.C.has_mkldnn, "Skipping as we do not have mkldnn.") diff --git a/caffe2/python/mkl/mkl_fill_op_test.py b/caffe2/python/mkl/mkl_fill_op_test.py index 26a9b7131b0b0..f233275786f7f 100644 --- a/caffe2/python/mkl/mkl_fill_op_test.py +++ b/caffe2/python/mkl/mkl_fill_op_test.py @@ -5,8 +5,7 @@ import unittest import hypothesis.strategies as st -from hypothesis import given, settings -import numpy as np +from hypothesis import given from caffe2.python import core, workspace import caffe2.python.hypothesis_test_util as hu import caffe2.python.mkl_test_util as mu diff --git a/caffe2/python/mkl/mkl_pool_speed_test.py b/caffe2/python/mkl/mkl_pool_speed_test.py index b25e0f915cc7e..aa43aed97a09c 100644 --- a/caffe2/python/mkl/mkl_pool_speed_test.py +++ b/caffe2/python/mkl/mkl_pool_speed_test.py @@ -6,7 +6,7 @@ import numpy as np from caffe2.proto import caffe2_pb2 -from caffe2.python import cnn, core, workspace, test_util +from caffe2.python import core, workspace, test_util @unittest.skipIf(not workspace.C.has_mkldnn, "Skipping as we do not have mkldnn.") diff --git a/caffe2/python/mkl/mkl_sbn_op_test.py b/caffe2/python/mkl/mkl_sbn_op_test.py index 2ac9080ce670e..86856b130d637 100644 --- a/caffe2/python/mkl/mkl_sbn_op_test.py +++ b/caffe2/python/mkl/mkl_sbn_op_test.py @@ -5,7 +5,7 @@ import unittest import hypothesis.strategies as st -from hypothesis import given, settings +from hypothesis import given import numpy as np from caffe2.python import core, workspace import caffe2.python.hypothesis_test_util as hu diff --git a/caffe2/python/mkl/mkl_sbn_speed_test.py b/caffe2/python/mkl/mkl_sbn_speed_test.py index 3b3b71d1c997c..05885ceca5756 100644 --- a/caffe2/python/mkl/mkl_sbn_speed_test.py +++ b/caffe2/python/mkl/mkl_sbn_speed_test.py @@ -6,7 +6,7 @@ import numpy as np from caffe2.proto import caffe2_pb2 -from caffe2.python import cnn, core, workspace, test_util +from caffe2.python import core, workspace, test_util @unittest.skipIf(not workspace.C.has_mkldnn, "Skipping as we do not have mkldnn.") diff --git a/caffe2/python/mkl/mkl_speed_test.py b/caffe2/python/mkl/mkl_speed_test.py index 9a7310a484d14..ab2e4428519a0 100644 --- a/caffe2/python/mkl/mkl_speed_test.py +++ b/caffe2/python/mkl/mkl_speed_test.py @@ -6,7 +6,7 @@ import numpy as np from caffe2.proto import caffe2_pb2 -from caffe2.python import cnn, core, workspace, test_util +from caffe2.python import core, workspace, test_util @unittest.skipIf(not workspace.C.has_mkldnn, "Skipping as we do not have mkldnn.") diff --git a/caffe2/python/mkl/rewrite_graph.py b/caffe2/python/mkl/rewrite_graph.py index 3a88a3deecccc..b52501584064d 100644 --- a/caffe2/python/mkl/rewrite_graph.py +++ b/caffe2/python/mkl/rewrite_graph.py @@ -6,7 +6,6 @@ import copy from caffe2.proto import caffe2_pb2 from caffe2.python import core -import caffe2.python._import_c_extension as C def rewrite_init_net_simple(net): diff --git a/caffe2/python/model_helper.py b/caffe2/python/model_helper.py index a5a4865c0ec1d..5eb81d898b33f 100644 --- a/caffe2/python/model_helper.py +++ b/caffe2/python/model_helper.py @@ -21,7 +21,6 @@ from itertools import chain import logging -import six # _known_working_ops are operators that do not need special care. @@ -199,7 +198,7 @@ def create_param(self, param_name, shape, initializer, tags=None): # ParameterSharing will be applied. if isinstance(param_name, core.BlobReference): param_name = str(param_name) - elif isinstance(param_name, six.string_types): + elif isinstance(param_name, str): # Parameter name will be equal to current Namescope that got # resolved with the respect of parameter sharing of the scopes. param_name = parameter_sharing_context.get_parameter_name( diff --git a/caffe2/python/modeling/initializers.py b/caffe2/python/modeling/initializers.py index b3e4b1a44dd70..ba4236d046544 100644 --- a/caffe2/python/modeling/initializers.py +++ b/caffe2/python/modeling/initializers.py @@ -6,8 +6,6 @@ from caffe2.python.core import DataType, BlobReference, ScopedBlobReference from caffe2.python.modeling.parameter_info import ParameterInfo -import six - class Initializer(object): ''' @@ -47,7 +45,7 @@ class ExternalInitializer(object): def create_param(self, param_name, init_net, shape): if isinstance(param_name, BlobReference): param = BlobReference(str(param_name), init_net) - elif isinstance(param_name, six.string_types): + elif isinstance(param_name, str): param = ScopedBlobReference(param_name, init_net) else: raise TypeError("Unsupported type for param_name") diff --git a/caffe2/python/modeling/net_modifier.py b/caffe2/python/modeling/net_modifier.py index e824c828e4bdb..c0545fad08f51 100644 --- a/caffe2/python/modeling/net_modifier.py +++ b/caffe2/python/modeling/net_modifier.py @@ -4,10 +4,9 @@ import abc -import six -class NetModifier(six.with_metaclass(abc.ABCMeta, object)): +class NetModifier(metaclass=abc.ABCMeta): """ An abstraction class for supporting modifying a generated net. Inherited classes should implement the modify_net method where diff --git a/caffe2/python/net_builder.py b/caffe2/python/net_builder.py index 70dcdec11a58a..fd525ed4766a1 100644 --- a/caffe2/python/net_builder.py +++ b/caffe2/python/net_builder.py @@ -10,11 +10,10 @@ from caffe2.python.control_ops_util import add_if_op, add_while_op -@context.define_context() -class NetBuilder(object): +class NetBuilder(context.Managed): """ Scope-driven mechanism for building nets, loops and conditional blocks. - Arguments: + Args: name: NetBuilder's name initial_scope: list of blobs that are available for reading/writing Example: @@ -138,6 +137,8 @@ def get(self): return self._children def __exit__(self, etype, *args): + super(NetBuilder, self).__exit__(etype, *args) + if self._use_control_ops and len(self._children) > 0: _children = self._children self._reset_children() diff --git a/caffe2/python/nomnigraph_test.py b/caffe2/python/nomnigraph_test.py index 3d9adc6964860..bd9d10fcbae13 100644 --- a/caffe2/python/nomnigraph_test.py +++ b/caffe2/python/nomnigraph_test.py @@ -3,7 +3,7 @@ -from caffe2.python import core, workspace, test_util +from caffe2.python import core, test_util from caffe2.proto import caffe2_pb2 import caffe2.python.nomnigraph as ng diff --git a/caffe2/python/normalizer_context.py b/caffe2/python/normalizer_context.py index a85b993b4502b..9559024bbcd39 100644 --- a/caffe2/python/normalizer_context.py +++ b/caffe2/python/normalizer_context.py @@ -10,8 +10,7 @@ ModifierContext, UseModifierBase) -@context.define_context(allow_default=True) -class NormalizerContext(ModifierContext): +class NormalizerContext(ModifierContext, context.DefaultManaged): """ provide context to allow param_info to have different normalizers """ diff --git a/caffe2/python/onnx/backend.py b/caffe2/python/onnx/backend.py index d0f768e42eebd..193a6f217f933 100644 --- a/caffe2/python/onnx/backend.py +++ b/caffe2/python/onnx/backend.py @@ -5,14 +5,7 @@ To run this, you will need to have Caffe2 installed as well. """ - - - - - -import os import collections -from subprocess import Popen, PIPE import sys import zipfile import itertools @@ -23,16 +16,13 @@ # importing onnx first, which will cause it to go out and pick up the # system protobuf. import onnx.backend - -import caffe2 from caffe2.python import core, workspace, rnn_cell, gru_cell -from caffe2.python.compatibility import container_abcs from caffe2.python.model_helper import ModelHelper from caffe2.proto import caffe2_pb2 import caffe2.python.utils import numpy as np import onnx -from onnx import checker, GraphProto, TensorProto, AttributeProto, ModelProto +from onnx import TensorProto import onnx.numpy_helper import onnx.defs import onnx.optimizer @@ -42,7 +32,6 @@ from caffe2.python.onnx.workspace import Workspace from caffe2.python.onnx.backend_rep import Caffe2Rep -from caffe2.python.onnx.backend_cpp_rep import Caffe2CppRep import caffe2.python._import_c_extension as C @@ -705,7 +694,13 @@ def prepare(cls, model, device='CPU', raw_values_dict=None, **kwargs): else: opset_version = 1 - model = onnx.shape_inference.infer_shapes(model) + # Prior to onnx version update to onnx-1.8.0, errors caused by failures in + # in the onnx shape inference call were being supressed. Hence a try-catch block + # is added around the infer_shapes call to avoid these failures and preserve status + try: + model = onnx.shape_inference.infer_shapes(model) + except RuntimeError: + warnings.warn("ShapeInferenceWarning: Inferred shape and existing shape differ in rank") ws = Workspace() device_option = get_device_option(Device(device)) @@ -775,7 +770,7 @@ def _onnx_node_to_caffe2_op(cls, init_model, pred_model, node_def, opset_version ops = translator(init_model, pred_model, OnnxNode(node_def), opset_version) if isinstance(ops, Caffe2Ops): return ops - if not isinstance(ops, container_abcs.Iterable): + if not isinstance(ops, collections.abc.Iterable): ops = [ops] return Caffe2Ops(ops, [], []) @@ -873,7 +868,13 @@ def _graph_to_net(cls, onnx_graph, opset_version): def _onnx_model_to_caffe2_net(cls, onnx_model, device, opset_version, include_initializers): device_option = get_device_option(Device(device)) - onnx_model = onnx.utils.polish_model(onnx_model) + # Prior to onnx version update to onnx-1.8.0, errors caused by failures in + # in the onnx shape inference call were being supressed. Hence a try-catch block + # is added around the infer_shapes call to avoid these failures and preserve status + try: + onnx_model = onnx.utils.polish_model(onnx_model) + except RuntimeError: + warnings.warn("ShapeInferenceWarning: Inferred shape and existing shape differ in rank") init_model = cls.optimize_onnx(onnx_model, init=True) pred_model = cls.optimize_onnx(onnx_model, predict=True) diff --git a/caffe2/python/onnx/bin/conversion.py b/caffe2/python/onnx/bin/conversion.py index 126eef8a84704..7e469e514a738 100644 --- a/caffe2/python/onnx/bin/conversion.py +++ b/caffe2/python/onnx/bin/conversion.py @@ -9,8 +9,7 @@ from caffe2.proto import caffe2_pb2 import click -import numpy as np -from onnx import checker, ModelProto +from onnx import ModelProto from caffe2.python.onnx.backend import Caffe2Backend as c2 import caffe2.python.onnx.frontend as c2_onnx diff --git a/caffe2/python/onnx/frontend.py b/caffe2/python/onnx/frontend.py index ee3c30949ff74..b5121602aff5b 100644 --- a/caffe2/python/onnx/frontend.py +++ b/caffe2/python/onnx/frontend.py @@ -10,22 +10,18 @@ - +import collections import itertools import logging import re from caffe2.python import core as caffe2_core -from caffe2.python.compatibility import container_abcs -from caffe2.proto import caffe2_legacy_pb2 -from enum import Enum -from onnx import (defs, checker, helper, numpy_helper, mapping, - ModelProto, GraphProto, NodeProto, AttributeProto, TensorProto, OperatorSetIdProto) -from onnx.helper import make_tensor, make_tensor_value_info, make_attribute, make_model +from onnx import (checker, helper, numpy_helper, mapping, + GraphProto, NodeProto, TensorProto, OperatorSetIdProto) +from onnx.helper import make_tensor_value_info, make_model import numpy as np from caffe2.python.onnx.helper import c2_native_run_net -from caffe2.python.onnx.error import Unsupported import caffe2.python._import_c_extension as C @@ -156,7 +152,7 @@ def caffe2_op_to_onnx_node(cls, op_def, shapes): const_tensors = [] if isinstance(nodes, tuple): nodes, const_tensors = nodes - if not isinstance(nodes, container_abcs.Iterable): + if not isinstance(nodes, collections.abc.Iterable): nodes = [nodes] return nodes, const_tensors diff --git a/caffe2/python/onnx/helper.py b/caffe2/python/onnx/helper.py index 7f8f1a6d346ab..6e73a5d5c95d6 100644 --- a/caffe2/python/onnx/helper.py +++ b/caffe2/python/onnx/helper.py @@ -9,9 +9,6 @@ from onnx.backend.base import namedtupledict from caffe2.python.onnx.workspace import Workspace -import caffe2.python._import_c_extension as C - -import io import logging import time diff --git a/caffe2/python/onnx/onnxifi.py b/caffe2/python/onnx/onnxifi.py index a04e7e4554b93..3e67c4948b1f1 100644 --- a/caffe2/python/onnx/onnxifi.py +++ b/caffe2/python/onnx/onnxifi.py @@ -11,9 +11,7 @@ from caffe2.proto import caffe2_pb2 -from caffe2.python import core, workspace import caffe2.python._import_c_extension as C -import numpy as np def onnxifi_caffe2_net( diff --git a/caffe2/python/onnx/test_onnxifi.py b/caffe2/python/onnx/test_onnxifi.py index 7eafccaec9e44..4316149d5bf67 100644 --- a/caffe2/python/onnx/test_onnxifi.py +++ b/caffe2/python/onnx/test_onnxifi.py @@ -3,16 +3,14 @@ -import json import numpy as np -import os import time import unittest import onnx import onnx.defs from onnx.backend.base import namedtupledict -from onnx.helper import make_node, make_graph, make_tensor, make_tensor_value_info, make_model +from onnx.helper import make_node, make_graph, make_tensor_value_info, make_model from caffe2.proto import caffe2_pb2 from caffe2.python import core, workspace from caffe2.python.models.download import ModelDownloader diff --git a/caffe2/python/onnx/tests/c2_ref_test.py b/caffe2/python/onnx/tests/c2_ref_test.py index d253b06658a3c..aab5a04a169cc 100644 --- a/caffe2/python/onnx/tests/c2_ref_test.py +++ b/caffe2/python/onnx/tests/c2_ref_test.py @@ -6,9 +6,7 @@ -import json import os -import six import unittest from caffe2.python import core @@ -18,7 +16,7 @@ from onnx.helper import make_node, make_graph, make_tensor, make_tensor_value_info, make_model from caffe2.python.onnx.helper import c2_native_run_net, c2_native_run_op -from onnx import defs, mapping +from onnx import mapping import caffe2.python.onnx.frontend as c2_onnx import caffe2.python.onnx.backend as c2 @@ -44,9 +42,8 @@ def test_check_arguments(self): b2.convert_node(node_def.SerializeToString()) bad_node_def = make_node("Add", inputs=["X", "Y"], outputs=["Z"], foo=42, bar=56) - with six.assertRaisesRegex(self, - RuntimeError, - "Don't know how to map unexpected argument (foo|bar)"): + with self.assertRaisesRegex(RuntimeError, + "Don't know how to map unexpected argument (foo|bar)"): b2.convert_node(bad_node_def.SerializeToString()) def test_dynamicslice_3inputs_graph(self): diff --git a/caffe2/python/onnx/tests/conversion_test.py b/caffe2/python/onnx/tests/conversion_test.py index 86cdddcd16924..1bb457491b85e 100644 --- a/caffe2/python/onnx/tests/conversion_test.py +++ b/caffe2/python/onnx/tests/conversion_test.py @@ -6,7 +6,6 @@ import json -import six import tempfile import textwrap import traceback @@ -82,9 +81,9 @@ def test_caffe2_to_onnx_value_info(self): caffe2_net.flush() args = [caffe2_net.name, '--output', output.name] - six.assertRaisesRegex(self, Exception, - 'value info', - self._run_command, caffe2_to_onnx, args) + self.assertRaisesRegex(Exception, + 'value info', + self._run_command, caffe2_to_onnx, args) args.extend([ '--value-info', @@ -110,12 +109,12 @@ def test_onnx_to_caffe2(self): [node_def], "test", [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3)), - helper.make_tensor_value_info("W", TensorProto.FLOAT, (3, 2))], - [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (2, 2))], + helper.make_tensor_value_info("W", TensorProto.FLOAT, (1, 3))], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (2, 3))], initializer=[helper.make_tensor("W", TensorProto.FLOAT, - [3, 2], - np.zeros((3, 2)).flatten().astype(float))]) + [1, 3], + np.zeros((1, 3)).flatten().astype(float))]) model_def = helper.make_model(graph_def, producer_name='onnx-to-caffe2-test') onnx_model.write(model_def.SerializeToString()) onnx_model.flush() @@ -156,7 +155,7 @@ def test_onnx_to_caffe2_zipfile(self): initializer=[helper.make_tensor("W", TensorProto.FLOAT, [3, 2], - b'__EXTERNAL', + W.tobytes(), raw=True)]) model_def = helper.make_model(graph_def, producer_name='onnx-to-caffe2-test') onnx_model.writestr('__MODEL_PROTO', model_def.SerializeToString()) @@ -222,11 +221,11 @@ def _make_fake_loop_op(self, body_nodes, input_types, output_types): # lcd is a dummy loop-carried dependency that only exists because # right now the schema checker is broken and assumes a variadic # input needs at least one value. - graph_inputs = [helper.make_tensor_value_info("i", TensorProto.INT32, ()), - helper.make_tensor_value_info("cond", TensorProto.BOOL, ())] + graph_inputs = [helper.make_tensor_value_info("i", TensorProto.INT64, (1,)), + helper.make_tensor_value_info("cond", TensorProto.BOOL, (1,))] for type, shape, name in input_types: graph_inputs.append(helper.make_tensor_value_info("_" + name, type, shape)) - graph_outputs = [helper.make_tensor_value_info("cond", TensorProto.BOOL, ())] + graph_outputs = [helper.make_tensor_value_info("cond", TensorProto.BOOL, (1,))] for type, shape, name in output_types: graph_outputs.append(helper.make_tensor_value_info("_" + name, type, shape)) body_graph = helper.make_graph(body_nodes, "body_graph", graph_inputs, diff --git a/caffe2/python/onnx/tests/onnx_backend_test.py b/caffe2/python/onnx/tests/onnx_backend_test.py index e4de0a19c07a5..e8b718a5a2be3 100644 --- a/caffe2/python/onnx/tests/onnx_backend_test.py +++ b/caffe2/python/onnx/tests/onnx_backend_test.py @@ -13,7 +13,7 @@ import caffe2.python.onnx.backend as c2 -from caffe2.python import core, workspace +from caffe2.python import core core.SetEnginePref({}, {}) # This is a pytest magic variable to load extra plugins @@ -129,6 +129,21 @@ '|test_pow_types.*' ')') +# Temporarily skip some ONNX backend tests due to updates in opset 13. +backend_test.exclude('(test_if_.*' # added support for sequence type inputs + '|test_if_seq_.*' # added support for sequence type inputs + '|test_logsoftmax_.*' # axis attr default value changed from 1 to -1 + '|test_loop11_.*' # seg fault issue + '|test_loop13_seq_.*' # no support for sequence inputs for scan input + '|test_reduce_sum_.*' # axes is now an input (not attr), added noop_with_empty_axes + '|test_softmax_.*' # axis attr default value changed from 1 to -1 + '|test_split_variable_parts_.*' # axes is now an input (not attr) + '|test_squeeze_.*' # axes is now an input (not attr) + '|test_unsqueeze_.*' # axes is now an input (not attr) + '|test_MaxPool1d_stride_padding_dilation_.*' + '|test_MaxPool2d_stride_padding_dilation_.*' + ')') + # Skip vgg to speed up CI if 'JENKINS_URL' in os.environ: backend_test.exclude(r'(test_vgg19|test_vgg)') diff --git a/caffe2/python/onnx/tests/ssa_test.py b/caffe2/python/onnx/tests/ssa_test.py index d34d4a0e52876..96f954037178b 100644 --- a/caffe2/python/onnx/tests/ssa_test.py +++ b/caffe2/python/onnx/tests/ssa_test.py @@ -7,11 +7,10 @@ import copy -import onnx import numpy as np from caffe2.proto import caffe2_pb2 from caffe2.python import core -from onnx import helper, TensorProto +from onnx import TensorProto import caffe2.python.onnx.frontend as c2_onnx from caffe2.python.onnx.helper import c2_native_run_net diff --git a/caffe2/python/onnx/tests/test_utils.py b/caffe2/python/onnx/tests/test_utils.py index d224daf05ba3e..bebfc1012957d 100644 --- a/caffe2/python/onnx/tests/test_utils.py +++ b/caffe2/python/onnx/tests/test_utils.py @@ -6,7 +6,6 @@ -import os import unittest import numpy as np diff --git a/caffe2/python/operator_fp_exceptions_test.py b/caffe2/python/operator_fp_exceptions_test.py index 3a1ebcd4ec67f..f039ef09f637f 100644 --- a/caffe2/python/operator_fp_exceptions_test.py +++ b/caffe2/python/operator_fp_exceptions_test.py @@ -3,7 +3,6 @@ from caffe2.python import core, workspace -from caffe2.proto import caffe2_pb2 from caffe2.python.test_util import TestCase import numpy as np diff --git a/caffe2/python/operator_test/activation_ops_test.py b/caffe2/python/operator_test/activation_ops_test.py index 132bee879f6d2..7e5c5f423606d 100644 --- a/caffe2/python/operator_test/activation_ops_test.py +++ b/caffe2/python/operator_test/activation_ops_test.py @@ -263,5 +263,32 @@ def gelu_ref(X): ensure_outputs_are_inferred=True) + @given(n=st.integers(0, 6), m=st.integers(4, 6), + seed=st.integers(0, 1000), **hu.gcs_cpu_only) + def test_mish(self, n, m, gc, dc, seed): + np.random.seed(seed) + X = np.random.rand(n, m).astype(np.float32) + + def mish_ref(X): + return (X * np.tanh(np.log1p(np.exp(X))),) + + op = core.CreateOperator( + "Mish", + ["X"], + ["Y"] + ) + + self.assertReferenceChecks( + device_option=gc, + op=op, + inputs=[X], + reference=mish_ref, + ensure_outputs_are_inferred=True, + ) + + self.assertGradientChecks( + gc, op, [X], 0, [0], ensure_outputs_are_inferred=True) + + if __name__ == "__main__": unittest.main() diff --git a/caffe2/python/operator_test/async_net_barrier_test.py b/caffe2/python/operator_test/async_net_barrier_test.py new file mode 100644 index 0000000000000..e2c0ea0ccc1a4 --- /dev/null +++ b/caffe2/python/operator_test/async_net_barrier_test.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 + +import caffe2.python.hypothesis_test_util as hu +import hypothesis.strategies as st +import numpy as np +from caffe2.python import core +from hypothesis import given + + +class TestAsyncNetBarrierOp(hu.HypothesisTestCase): + @given( + n=st.integers(1, 5), + shape=st.lists(st.integers(0, 5), min_size=1, max_size=3), + **hu.gcs + ) + def test_async_net_barrier_op(self, n, shape, dc, gc): + test_inputs = [(100 * np.random.random(shape)).astype(np.float32) for _ in range(n)] + test_input_blobs = ["x_{}".format(i) for i in range(n)] + + barrier_op = core.CreateOperator( + "AsyncNetBarrier", + test_input_blobs, + test_input_blobs, + device_option=gc, + ) + + def reference_func(*args): + self.assertEquals(len(args), n) + return args + + self.assertReferenceChecks(gc, barrier_op, test_inputs, reference_func) diff --git a/caffe2/python/operator_test/blobs_queue_db_test.py b/caffe2/python/operator_test/blobs_queue_db_test.py index 6cf8170b34f89..88197d16d70b5 100644 --- a/caffe2/python/operator_test/blobs_queue_db_test.py +++ b/caffe2/python/operator_test/blobs_queue_db_test.py @@ -3,7 +3,6 @@ -import unittest import numpy as np import caffe2.proto.caffe2_pb2 as caffe2_pb2 diff --git a/caffe2/python/operator_test/boolean_mask_test.py b/caffe2/python/operator_test/boolean_mask_test.py index 05b8212242e4b..38fe43899990a 100644 --- a/caffe2/python/operator_test/boolean_mask_test.py +++ b/caffe2/python/operator_test/boolean_mask_test.py @@ -2,7 +2,6 @@ -from caffe2.proto import caffe2_pb2 from caffe2.python import core import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial diff --git a/caffe2/python/operator_test/bucketize_op_test.py b/caffe2/python/operator_test/bucketize_op_test.py index bf9af112a5b01..2eb2acf87902b 100644 --- a/caffe2/python/operator_test/bucketize_op_test.py +++ b/caffe2/python/operator_test/bucketize_op_test.py @@ -2,10 +2,9 @@ -from caffe2.python import core, dyndep +from caffe2.python import core from hypothesis import given import caffe2.python.hypothesis_test_util as hu -import hypothesis.strategies as st import numpy as np diff --git a/caffe2/python/operator_test/concat_split_op_test.py b/caffe2/python/operator_test/concat_split_op_test.py index 1927b4eac78fc..ac83681f08bf6 100644 --- a/caffe2/python/operator_test/concat_split_op_test.py +++ b/caffe2/python/operator_test/concat_split_op_test.py @@ -3,8 +3,7 @@ -from caffe2.proto import caffe2_pb2 -from caffe2.python import core, workspace +from caffe2.python import core import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial from hypothesis import given, settings diff --git a/caffe2/python/operator_test/conv_test.py b/caffe2/python/operator_test/conv_test.py index ae54cd37a91da..e600aa2c9ee96 100644 --- a/caffe2/python/operator_test/conv_test.py +++ b/caffe2/python/operator_test/conv_test.py @@ -2,7 +2,6 @@ import collections import functools -import os import unittest import caffe2.python._import_c_extension as C diff --git a/caffe2/python/operator_test/cosine_embedding_criterion_op_test.py b/caffe2/python/operator_test/cosine_embedding_criterion_op_test.py index 04bfbbe6f4f60..d979407321a40 100644 --- a/caffe2/python/operator_test/cosine_embedding_criterion_op_test.py +++ b/caffe2/python/operator_test/cosine_embedding_criterion_op_test.py @@ -3,7 +3,6 @@ -from hypothesis import given import hypothesis.strategies as st import numpy as np diff --git a/caffe2/python/operator_test/crf_test.py b/caffe2/python/operator_test/crf_test.py index b75e7b7b1a104..4d7b90c431a6a 100644 --- a/caffe2/python/operator_test/crf_test.py +++ b/caffe2/python/operator_test/crf_test.py @@ -9,7 +9,6 @@ import caffe2.python.hypothesis_test_util as hu import hypothesis.strategies as st from hypothesis import given, settings -import unittest class TestCRFOp(hu.HypothesisTestCase): diff --git a/caffe2/python/operator_test/cross_entropy_ops_test.py b/caffe2/python/operator_test/cross_entropy_ops_test.py index d1852e7dd9e83..c88f93503a15e 100644 --- a/caffe2/python/operator_test/cross_entropy_ops_test.py +++ b/caffe2/python/operator_test/cross_entropy_ops_test.py @@ -9,7 +9,6 @@ import numpy as np import unittest -import os def sigmoid(x): return 1.0 / (1.0 + np.exp(-x)) diff --git a/caffe2/python/operator_test/ctc_beam_search_decoder_op_test.py b/caffe2/python/operator_test/ctc_beam_search_decoder_op_test.py index 1dda7166e65a4..29440c00a4b3c 100644 --- a/caffe2/python/operator_test/ctc_beam_search_decoder_op_test.py +++ b/caffe2/python/operator_test/ctc_beam_search_decoder_op_test.py @@ -4,7 +4,6 @@ from caffe2.python import core -from caffe2.python.test_util import caffe2_flaky from collections import defaultdict, Counter from hypothesis import given, settings import caffe2.python.hypothesis_test_util as hu diff --git a/caffe2/python/operator_test/cudnn_recurrent_test.py b/caffe2/python/operator_test/cudnn_recurrent_test.py index db1b826cfe414..ef4433a41a185 100644 --- a/caffe2/python/operator_test/cudnn_recurrent_test.py +++ b/caffe2/python/operator_test/cudnn_recurrent_test.py @@ -4,7 +4,6 @@ from caffe2.python import model_helper, workspace, core, rnn_cell -from caffe2.proto import caffe2_pb2 from future.utils import viewitems import numpy as np diff --git a/caffe2/python/operator_test/dataset_ops_test.py b/caffe2/python/operator_test/dataset_ops_test.py index 96d93dc5effb8..a7e01570a22ab 100644 --- a/caffe2/python/operator_test/dataset_ops_test.py +++ b/caffe2/python/operator_test/dataset_ops_test.py @@ -1,32 +1,32 @@ +import functools +import operator +import string - - - +import hypothesis.strategies as st import numpy as np -from caffe2.python import core, workspace, dataset +import numpy.testing as npt +from caffe2.python import core, dataset, workspace from caffe2.python.dataset import Const from caffe2.python.schema import ( - List, Field, Struct, Scalar, Map, from_blob_list, FetchRecord, NewRecord, - FeedRecord + FeedRecord, + FetchRecord, + Field, + List, + Map, + NewRecord, + Scalar, + Struct, + from_blob_list, ) from caffe2.python.test_util import TestCase - -import numpy.testing as npt - -import string - from hypothesis import given -import hypothesis.strategies as st def _assert_arrays_equal(actual, ref, err_msg): - if ref.dtype.kind in ('S', 'O', 'U'): + if ref.dtype.kind in ("S", "O", "U"): np.testing.assert_array_equal(actual, ref, err_msg=err_msg) else: - np.testing.assert_allclose( - actual, ref, atol=1e-4, - rtol=1e-4, err_msg=err_msg - ) + np.testing.assert_allclose(actual, ref, atol=1e-4, rtol=1e-4, err_msg=err_msg) def _assert_records_equal(actual, ref): @@ -34,11 +34,12 @@ def _assert_records_equal(actual, ref): assert isinstance(ref, Field) b1 = actual.field_blobs() b2 = ref.field_blobs() - assert (len(b1) == len(b2)), 'Records have different lengths: %d vs. %d' % ( - len(b1), len(b2) + assert len(b1) == len(b2), "Records have different lengths: %d vs. %d" % ( + len(b1), + len(b2), ) for name, d1, d2 in zip(ref.field_names(), b1, b2): - _assert_arrays_equal(d1, d2, err_msg='Mismatch in field %s.' % name) + _assert_arrays_equal(d1, d2, err_msg="Mismatch in field %s." % name) @st.composite @@ -47,7 +48,7 @@ def _sparse_features_map(draw, num_records, **kwargs): st.lists( st.integers(min_value=1, max_value=10), min_size=num_records, - max_size=num_records + max_size=num_records, ) ) @@ -58,7 +59,7 @@ def _sparse_features_map(draw, num_records, **kwargs): st.integers(min_value=1, max_value=100), min_size=sparse_maps_total_length, max_size=sparse_maps_total_length, - unique=True + unique=True, ) ) @@ -66,7 +67,7 @@ def _sparse_features_map(draw, num_records, **kwargs): st.lists( st.integers(min_value=1, max_value=10), min_size=sparse_maps_total_length, - max_size=sparse_maps_total_length + max_size=sparse_maps_total_length, ) ) @@ -77,7 +78,7 @@ def _sparse_features_map(draw, num_records, **kwargs): st.lists( st.integers(min_value=1, max_value=9223372036854775807), min_size=total_sparse_values_lengths, - max_size=total_sparse_values_lengths + max_size=total_sparse_values_lengths, ) ) @@ -95,7 +96,7 @@ def _dense_features_map(draw, num_records, **kwargs): st.lists( st.integers(min_value=1, max_value=10), min_size=num_records, - max_size=num_records + max_size=num_records, ) ) @@ -106,14 +107,12 @@ def _dense_features_map(draw, num_records, **kwargs): st.integers(min_value=1, max_value=100), min_size=total_length, max_size=total_length, - unique=True + unique=True, ) ) float_values = draw( - st.lists(st.floats(), - min_size=total_length, - max_size=total_length) + st.lists(st.floats(), min_size=total_length, max_size=total_length) ) return [float_lengths, float_keys, float_values] @@ -123,22 +122,20 @@ def _dense_features_map(draw, num_records, **kwargs): def _dataset(draw, min_elements=3, max_elements=10, **kwargs): schema = Struct( # Dense Features Map - ('floats', Map( - Scalar(np.int32), Scalar(np.float32) - )), + ("floats", Map(Scalar(np.int32), Scalar(np.float32))), # Sparse Features Map - ('int_lists', Map( - Scalar(np.int32), - List(Scalar(np.int64)), - )), + ( + "int_lists", + Map( + Scalar(np.int32), + List(Scalar(np.int64)), + ), + ), # Complex Type - ('text', Scalar(str)), + ("text", Scalar(str)), ) - num_records = draw( - st.integers(min_value=min_elements, - max_value=max_elements) - ) + num_records = draw(st.integers(min_value=min_elements, max_value=max_elements)) raw_dense_features_map_contents = draw(_dense_features_map(num_records)) @@ -149,13 +146,17 @@ def _dataset(draw, min_elements=3, max_elements=10, **kwargs): st.lists( st.text(alphabet=string.ascii_lowercase), min_size=num_records, - max_size=num_records + max_size=num_records, ) ) ] # Concatenate all raw contents to a single one - contents_raw = raw_dense_features_map_contents + raw_sparse_features_map_contents + raw_text_contents + contents_raw = ( + raw_dense_features_map_contents + + raw_sparse_features_map_contents + + raw_text_contents + ) contents = from_blob_list(schema, contents_raw) @@ -172,31 +173,28 @@ def test_pack_unpack(self, input): dataset_fields = schema.field_names() - for pack_to_single_shared_ptr in (True, False): - net = core.Net('pack_unpack_net') + net = core.Net("pack_unpack_net") batch = NewRecord(net, contents) FeedRecord(batch, contents) packed = net.PackRecords( - batch.field_blobs(), 1, + batch.field_blobs(), + 1, fields=dataset_fields, - pack_to_single_shared_ptr=pack_to_single_shared_ptr + pack_to_single_shared_ptr=pack_to_single_shared_ptr, ) unpacked = packed.UnPackRecords( - [], len(dataset_fields), - fields=dataset_fields + [], len(dataset_fields), fields=dataset_fields ) workspace.RunNetOnce(net) - for initial_tensor, unpacked_tensor in zip( - batch.field_blobs(), unpacked - ): + for initial_tensor, unpacked_tensor in zip(batch.field_blobs(), unpacked): npt.assert_array_equal( workspace.FetchBlob(initial_tensor), - workspace.FetchBlob(unpacked_tensor) + workspace.FetchBlob(unpacked_tensor), ) def test_dataset_ops(self): @@ -207,35 +205,38 @@ def test_dataset_ops(self): """ schema = Struct( # fixed size vector, which will be stored as a matrix when batched - ('dense', Scalar((np.float32, 3))), + ("dense", Scalar((np.float32, 3))), # could represent a feature map from feature ID to float value - ('floats', Map( - Scalar(np.int32), Scalar(np.float32) - )), + ("floats", Map(Scalar(np.int32), Scalar(np.float32))), # could represent a multi-valued categorical feature map - ('int_lists', Map( - Scalar(np.int32), - List(Scalar(np.int64)), - )), + ( + "int_lists", + Map( + Scalar(np.int32), + List(Scalar(np.int64)), + ), + ), # could represent a multi-valued, weighted categorical feature map ( - 'id_score_pairs', Map( + "id_score_pairs", + Map( Scalar(np.int32), Map( Scalar(np.int64), Scalar(np.float32), - keys_name='ids', - values_name='scores' + keys_name="ids", + values_name="scores", ), - ) + ), ), # additional scalar information ( - 'metadata', Struct( - ('user_id', Scalar(np.int64)), - ('user_embed', Scalar((np.float32, 2))), - ('query', Scalar(str)), - ) + "metadata", + Struct( + ("user_id", Scalar(np.int64)), + ("user_embed", Scalar((np.float32, 2))), + ("query", Scalar(str)), + ), ), ) """ @@ -244,26 +245,24 @@ def test_dataset_ops(self): written as a tensor. """ expected_fields = [ - ('dense', (np.float32, 3)), - ('floats:lengths', np.int32), - ('floats:values:keys', np.int32), - ('floats:values:values', np.float32), - ('int_lists:lengths', np.int32), - ('int_lists:values:keys', np.int32), - ('int_lists:values:values:lengths', np.int32), - ('int_lists:values:values:values', np.int64), - ('id_score_pairs:lengths', np.int32), - ('id_score_pairs:values:keys', np.int32), - ('id_score_pairs:values:values:lengths', np.int32), - ('id_score_pairs:values:values:values:ids', np.int64), - ('id_score_pairs:values:values:values:scores', np.float32), - ('metadata:user_id', np.int64), - ('metadata:user_embed', (np.float32, 2)), - ('metadata:query', str), + ("dense", (np.float32, 3)), + ("floats:lengths", np.int32), + ("floats:values:keys", np.int32), + ("floats:values:values", np.float32), + ("int_lists:lengths", np.int32), + ("int_lists:values:keys", np.int32), + ("int_lists:values:values:lengths", np.int32), + ("int_lists:values:values:values", np.int64), + ("id_score_pairs:lengths", np.int32), + ("id_score_pairs:values:keys", np.int32), + ("id_score_pairs:values:values:lengths", np.int32), + ("id_score_pairs:values:values:values:ids", np.int64), + ("id_score_pairs:values:values:values:scores", np.float32), + ("metadata:user_id", np.int64), + ("metadata:user_embed", (np.float32, 2)), + ("metadata:query", str), ] - zipped = zip( - expected_fields, schema.field_names(), schema.field_types() - ) + zipped = zip(expected_fields, schema.field_names(), schema.field_types()) for (ref_name, ref_type), name, dtype in zipped: self.assertEquals(ref_name, name) self.assertEquals(np.dtype(ref_type), dtype) @@ -295,7 +294,7 @@ def test_dataset_ops(self): # metadata [123, 234, 456], # user_id [[0.2, 0.8], [0.5, 0.5], [0.7, 0.3]], # user_embed - ['dog posts', 'friends who like to', 'posts about ca'], # query + ["dog posts", "friends who like to", "posts about ca"], # query ] # convert the above content to ndarrays, checking against the schema contents = from_blob_list(schema, contents_raw) @@ -305,8 +304,8 @@ def test_dataset_ops(self): Then, a Writer is used to append these entries to the dataset. """ ds = dataset.Dataset(schema) - net = core.Net('init') - with core.NameScope('init'): + net = core.Net("init") + with core.NameScope("init"): ds.init_empty(net) content_blobs = NewRecord(net, contents) @@ -337,7 +336,7 @@ def test_dataset_ops(self): [11.1], # id score pairs [123], [[0.2, 0.8]], - ['dog posts'], # metadata + ["dog posts"], # metadata ), ( [[2.1, 2.2, 2.3]], # dense @@ -355,7 +354,7 @@ def test_dataset_ops(self): [21.1, 22.1, 22.2], [234], [[0.5, 0.5]], - ['friends who like to'], # metadata + ["friends who like to"], # metadata ), ( [[3.1, 3.2, 3.3]], # dense @@ -373,11 +372,11 @@ def test_dataset_ops(self): [31.1, 31.2, 32.1, 32.2, 32.3], # id score list [456], [[0.7, 0.3]], - ['posts about ca'], # metadata + ["posts about ca"], # metadata ), # after the end of the dataset, we will keep getting empty vectors - ([], ) * 16, - ([], ) * 16, + ([],) * 16, + ([],) * 16, ] entries = [from_blob_list(schema, e) for e in entries_raw] """ @@ -385,8 +384,8 @@ def test_dataset_ops(self): We will run `read` net multiple times and assert that we are reading the entries the way we stated above. """ - read_init_net = core.Net('read_init') - read_next_net = core.Net('read_next') + read_init_net = core.Net("read_init") + read_next_net = core.Net("read_next") reader = ds.reader(read_init_net) should_continue, batch = reader.read_record(read_next_net) @@ -407,11 +406,11 @@ def test_dataset_ops(self): Where we will process the dataset a little and store it in a second dataset. We can reuse the same Reader since it supports reset. """ - reset_net = core.Net('reset_net') + reset_net = core.Net("reset_net") reader.reset(reset_net) read_step, batch = reader.execution_step() """ We will add the line number * 1000 to the feature ids. """ - process_net = core.Net('process') + process_net = core.Net("process") line_no = Const(process_net, 0, dtype=np.int32) const_one = Const(process_net, 1000, dtype=np.int32) process_net.Add([line_no, const_one], [line_no]) @@ -419,19 +418,19 @@ def test_dataset_ops(self): process_net.Print(field, []) process_net.Add([field, line_no], field, broadcast=1, axis=0) """ Lets create a second dataset and append to it. """ - ds2 = dataset.Dataset(schema, name='dataset2') + ds2 = dataset.Dataset(schema, name="dataset2") ds2.init_empty(reset_net) writer = ds2.writer(reset_net) writer.write_record(process_net, batch) # commit is not necessary for DatasetWriter but will add it for # generality of the example - commit_net = core.Net('commit') + commit_net = core.Net("commit") writer.commit(commit_net) """ Time to create and run a plan which will do the processing """ - plan = core.Plan('process') - plan.AddStep(core.execution_step('reset', reset_net)) + plan = core.Plan("process") + plan.AddStep(core.execution_step("reset", reset_net)) plan.AddStep(read_step.AddNet(process_net)) - plan.AddStep(core.execution_step('commit', commit_net)) + plan.AddStep(core.execution_step("commit", commit_net)) workspace.RunPlan(plan) """ Now we should have dataset2 populated. @@ -446,18 +445,18 @@ def test_dataset_ops(self): You can create a new schema from pieces of another schema and reuse the same data. """ - subschema = Struct(('top_level', schema.int_lists.values)) + subschema = Struct(("top_level", schema.int_lists.values)) int_list_contents = contents.int_lists.values.field_names() self.assertEquals(len(subschema.field_names()), len(int_list_contents)) """ 7. Random Access a dataset """ - read_init_net = core.Net('read_init') - read_next_net = core.Net('read_next') + read_init_net = core.Net("read_init") + read_next_net = core.Net("read_next") idx = np.array([2, 1, 0]) - indices_blob = Const(read_init_net, idx, name='indices') + indices_blob = Const(read_init_net, idx, name="indices") reader = ds.random_reader(read_init_net, indices_blob) reader.computeoffset(read_init_net) @@ -480,11 +479,11 @@ def test_dataset_ops(self): 8. Random Access a dataset with loop_over = true """ - read_init_net = core.Net('read_init') - read_next_net = core.Net('read_next') + read_init_net = core.Net("read_init") + read_next_net = core.Net("read_next") idx = np.array([2, 1, 0]) - indices_blob = Const(read_init_net, idx, name='indices') + indices_blob = Const(read_init_net, idx, name="indices") reader = ds.random_reader(read_init_net, indices_blob, loop_over=True) reader.computeoffset(read_init_net) @@ -506,11 +505,11 @@ def test_dataset_ops(self): before shuffling the chunks. """ - read_init_net = core.Net('read_init') - read_next_net = core.Net('read_next') + read_init_net = core.Net("read_init") + read_next_net = core.Net("read_next") reader = ds.random_reader(read_init_net) - reader.sort_and_shuffle(read_init_net, 'int_lists:lengths', 1, 2) + reader.sort_and_shuffle(read_init_net, "int_lists:lengths", 1, 2) reader.computeoffset(read_init_net) should_continue, batch = reader.read_record(read_next_net) @@ -531,7 +530,7 @@ def test_dataset_ops(self): """ Trim a dataset """ - trim_net = core.Net('trim_ds') + trim_net = core.Net("trim_ds") ds.trim(trim_net, multiple_of=2) workspace.RunNetOnce(trim_net) trimmed = FetchRecord(ds.content()) @@ -540,67 +539,108 @@ def test_dataset_ops(self): self.assertEquals(EXPECTED_SIZES, actual_sizes) def test_last_n_window_ops(self): - collect_net = core.Net('collect_net') + collect_net = core.Net("collect_net") collect_net.GivenTensorFill( [], - 'input', + "input", shape=[3, 2], values=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], ) - input_array =\ - np.array(list(range(1, 7)), dtype=np.float32).reshape(3, 2) + input_array = np.array(list(range(1, 7)), dtype=np.float32).reshape(3, 2) - workspace.CreateBlob('output') - workspace.FeedBlob('next', np.array(0, dtype=np.int32)) + workspace.CreateBlob("output") + workspace.FeedBlob("next", np.array(0, dtype=np.int32)) collect_net.LastNWindowCollector( - ['output', 'next', 'input'], - ['output', 'next'], + ["output", "next", "input"], + ["output", "next"], num_to_collect=7, ) - plan = core.Plan('collect_data') - plan.AddStep( - core.execution_step('collect_data', [collect_net], - num_iter=1) - ) + plan = core.Plan("collect_data") + plan.AddStep(core.execution_step("collect_data", [collect_net], num_iter=1)) workspace.RunPlan(plan) - reference_result = workspace.FetchBlob('output') + reference_result = workspace.FetchBlob("output") npt.assert_array_equal(input_array, reference_result) - plan = core.Plan('collect_data') - plan.AddStep( - core.execution_step('collect_data', [collect_net], - num_iter=2) - ) + plan = core.Plan("collect_data") + plan.AddStep(core.execution_step("collect_data", [collect_net], num_iter=2)) workspace.RunPlan(plan) - reference_result = workspace.FetchBlob('output') - npt.assert_array_equal(input_array[[1, 2, 2, 0, 1, 2, 0]], - reference_result) + reference_result = workspace.FetchBlob("output") + npt.assert_array_equal(input_array[[1, 2, 2, 0, 1, 2, 0]], reference_result) - plan = core.Plan('collect_data') - plan.AddStep( - core.execution_step('collect_data', [collect_net], - num_iter=3) - ) + plan = core.Plan("collect_data") + plan.AddStep(core.execution_step("collect_data", [collect_net], num_iter=3)) workspace.RunPlan(plan) - reference_result = workspace.FetchBlob('output') - npt.assert_array_equal(input_array[[2, 0, 1, 2, 2, 0, 1]], - reference_result) + reference_result = workspace.FetchBlob("output") + npt.assert_array_equal(input_array[[2, 0, 1, 2, 2, 0, 1]], reference_result) + + def test_last_n_window_ops_shape_inference(self): + collect_net = core.Net("collect_net") + collect_net.GivenTensorFill( + [], + "input", + shape=[3, 2], + values=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + ) + + workspace.CreateBlob("output") + workspace.FeedBlob("next", np.array(0, dtype=np.int32)) + collect_net.LastNWindowCollector( + ["output", "next", "input"], + ["output", "next"], + num_to_collect=7, + ) + (shapes, types) = workspace.InferShapesAndTypes([collect_net]) + workspace.RunNetOnce(collect_net) + + self.assertTrue( + np.array_equal( + shapes["output"], np.array([7, workspace.blobs["output"].shape[1]]) + ) + ) + + def test_last_n_window_ops_shape_inference_4d_input(self): + input_shape = [3, 2, 4, 5] + collect_net = core.Net("collect_net") + collect_net.GivenTensorFill( + [], + "input", + shape=input_shape, + values=[ + float(val) for val in range(functools.reduce(operator.mul, input_shape)) + ], + ) + + workspace.CreateBlob("output") + workspace.FeedBlob("next", np.array(0, dtype=np.int32)) + collect_net.LastNWindowCollector( + ["output", "next", "input"], + ["output", "next"], + num_to_collect=7, + ) + (shapes, types) = workspace.InferShapesAndTypes([collect_net]) + workspace.RunNetOnce(collect_net) + + self.assertTrue( + np.array_equal( + shapes["output"], np.array([7, *list(workspace.blobs["output"].shape[1:])]) + ) + ) def test_collect_tensor_ops(self): - init_net = core.Net('init_net') - blobs = ['blob_1', 'blob_2', 'blob_3'] + init_net = core.Net("init_net") + blobs = ["blob_1", "blob_2", "blob_3"] bvec_map = {} - ONE = init_net.ConstantFill([], 'ONE', shape=[1, 2], value=1) + ONE = init_net.ConstantFill([], "ONE", shape=[1, 2], value=1) for b in blobs: init_net.ConstantFill([], [b], shape=[1, 2], value=0) - bvec_map[b] = b + '_vec' + bvec_map[b] = b + "_vec" init_net.CreateTensorVector([], [bvec_map[b]]) - reader_net = core.Net('reader_net') + reader_net = core.Net("reader_net") for b in blobs: reader_net.Add([b, ONE], [b]) - collect_net = core.Net('collect_net') + collect_net = core.Net("collect_net") num_to_collect = 1000 max_example_to_cover = 100000 bvec = [bvec_map[b] for b in blobs] @@ -610,25 +650,24 @@ def test_collect_tensor_ops(self): num_to_collect=num_to_collect, ) - print('Collect Net Proto: {}'.format(collect_net.Proto())) + print("Collect Net Proto: {}".format(collect_net.Proto())) - plan = core.Plan('collect_data') - plan.AddStep(core.execution_step('collect_init', init_net)) + plan = core.Plan("collect_data") + plan.AddStep(core.execution_step("collect_init", init_net)) plan.AddStep( core.execution_step( - 'collect_data', [reader_net, collect_net], - num_iter=max_example_to_cover + "collect_data", [reader_net, collect_net], num_iter=max_example_to_cover ) ) workspace.RunPlan(plan) # concat the collected tensors - concat_net = core.Net('concat_net') + concat_net = core.Net("concat_net") bconcated_map = {} bsize_map = {} for b in blobs: - bconcated_map[b] = b + '_concated' - bsize_map[b] = b + '_size' + bconcated_map[b] = b + "_concated" + bsize_map[b] = b + "_size" concat_net.ConcatTensorVector([bvec_map[b]], [bconcated_map[b]]) concat_net.TensorVectorSize([bvec_map[b]], [bsize_map[b]]) @@ -637,19 +676,16 @@ def test_collect_tensor_ops(self): # check data reference_result = workspace.FetchBlob(bconcated_map[blobs[0]]) self.assertEqual( - reference_result.shape, - (min(num_to_collect, max_example_to_cover), 2) + reference_result.shape, (min(num_to_collect, max_example_to_cover), 2) ) size = workspace.FetchBlob(bsize_map[blobs[0]]) self.assertEqual(tuple(), size.shape) self.assertEqual(min(num_to_collect, max_example_to_cover), size.item()) hist, _ = np.histogram( - reference_result[:, 0], - bins=10, - range=(1, max_example_to_cover) + reference_result[:, 0], bins=10, range=(1, max_example_to_cover) ) - print('Sample histogram: {}'.format(hist)) + print("Sample histogram: {}".format(hist)) self.assertTrue(all(hist > 0.6 * (num_to_collect / 10))) for i in range(1, len(blobs)): @@ -659,4 +695,5 @@ def test_collect_tensor_ops(self): if __name__ == "__main__": import unittest + unittest.main() diff --git a/caffe2/python/operator_test/deform_conv_test.py b/caffe2/python/operator_test/deform_conv_test.py index f6ad0e38e73cc..67289de5e924a 100644 --- a/caffe2/python/operator_test/deform_conv_test.py +++ b/caffe2/python/operator_test/deform_conv_test.py @@ -1,6 +1,5 @@ -import os import unittest import caffe2.python.hypothesis_test_util as hu diff --git a/caffe2/python/operator_test/depthwise_3x3_conv_test.py b/caffe2/python/operator_test/depthwise_3x3_conv_test.py index 2d6d6429f8335..cdfffce288dde 100644 --- a/caffe2/python/operator_test/depthwise_3x3_conv_test.py +++ b/caffe2/python/operator_test/depthwise_3x3_conv_test.py @@ -5,7 +5,7 @@ import numpy as np import caffe2.python.hypothesis_test_util as hu -from caffe2.python import core, dyndep, utils, workspace +from caffe2.python import core, utils from hypothesis import given, settings import hypothesis.strategies as st diff --git a/caffe2/python/operator_test/distance_op_test.py b/caffe2/python/operator_test/distance_op_test.py index e948fdae9673a..5b46548e072b5 100644 --- a/caffe2/python/operator_test/distance_op_test.py +++ b/caffe2/python/operator_test/distance_op_test.py @@ -6,7 +6,6 @@ from caffe2.python import core import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial -from hypothesis import given import hypothesis.strategies as st import numpy as np diff --git a/caffe2/python/operator_test/elementwise_linear_op_test.py b/caffe2/python/operator_test/elementwise_linear_op_test.py index ac0dc3dd0975b..2bd85625a3d9d 100644 --- a/caffe2/python/operator_test/elementwise_linear_op_test.py +++ b/caffe2/python/operator_test/elementwise_linear_op_test.py @@ -4,7 +4,6 @@ from caffe2.python import core -from hypothesis import given import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial import hypothesis.strategies as st diff --git a/caffe2/python/operator_test/elementwise_ops_test.py b/caffe2/python/operator_test/elementwise_ops_test.py index ed7a09eb08571..31f70086de7be 100644 --- a/caffe2/python/operator_test/elementwise_ops_test.py +++ b/caffe2/python/operator_test/elementwise_ops_test.py @@ -10,7 +10,6 @@ import numpy as np import unittest -import os class TestElementwiseOps(hu.HypothesisTestCase): @@ -296,62 +295,6 @@ def test_cbrt_grad(self, X, in_place, gc, dc): ensure_outputs_are_inferred=True, ) - @given(n=st.integers(0, 6), m=st.integers(4, 6), - seed=st.integers(0, 1000), **hu.gcs_cpu_only) - def test_mish(self, n, m, gc, dc, seed): - np.random.seed(seed) - X = np.random.rand(n, m).astype(np.float32) - - def mish(X): - return [X * np.tanh(np.log(1 + np.exp(X)))] - - op = core.CreateOperator( - "Mish", - ["X"], - ["Z"] - ) - - self.assertReferenceChecks( - device_option=gc, - op=op, - inputs=[X], - reference=mish, - ensure_outputs_are_inferred=True, - ) - - self.assertGradientChecks( - gc, op, [X], 0, [0], stepsize=1e-4, threshold=1e-2, - ensure_outputs_are_inferred=True) - - @given(n=st.integers(0, 6), m=st.integers(4, 6), - seed=st.integers(0, 1000), **hu.gcs_cpu_only) - def test_mish_gradient_inplace(self, n, m, gc, dc, seed): - np.random.seed(seed) - - def mish(X): - return [X * np.tanh(np.log(1 + np.exp(X)))] - - def mish_gradient(X, Y, dY): - w = np.exp(3 * X) + 4 * np.exp(2 * X) + (6 + 4 * X) * np.exp(X) + 4 * (1 + X) - sigma2 = np.square(np.square(np.exp(X) + 1) + 1) - return [dY * np.exp(X) * w / sigma2] - # return [dY * (Y + np.divide(1. - Y, 1. + np.exp(-X)))] - - X = np.random.rand(n, m).astype(np.float32) - Y = mish(X)[0] - dY = np.random.rand(n, m).astype(np.float32) - op = core.CreateOperator( - "MishGradient", - ["X", "Y", "grad"], - "grad" - ) - - self.assertReferenceChecks( - device_option=gc, - op=op, - inputs=[X, Y, dY], - reference=mish_gradient, - ) @given(n=st.integers(0, 6), m=st.integers(4, 6), seed=st.integers(0, 1000), **hu.gcs) diff --git a/caffe2/python/operator_test/enforce_finite_op_test.py b/caffe2/python/operator_test/enforce_finite_op_test.py index b843bfdc95b9b..8150977945a25 100644 --- a/caffe2/python/operator_test/enforce_finite_op_test.py +++ b/caffe2/python/operator_test/enforce_finite_op_test.py @@ -8,7 +8,6 @@ from caffe2.python import core, workspace import caffe2.python.hypothesis_test_util as hu -import hypothesis.strategies as st class TestEnforceFinite(hu.HypothesisTestCase): diff --git a/caffe2/python/operator_test/expand_op_test.py b/caffe2/python/operator_test/expand_op_test.py index 0d198b1aff144..aba2c1106da37 100644 --- a/caffe2/python/operator_test/expand_op_test.py +++ b/caffe2/python/operator_test/expand_op_test.py @@ -3,7 +3,7 @@ -from caffe2.python import core, workspace +from caffe2.python import core from hypothesis import given, settings import caffe2.python.hypothesis_test_util as hu diff --git a/caffe2/python/operator_test/feature_maps_ops_test.py b/caffe2/python/operator_test/feature_maps_ops_test.py index 19fa329c93891..5a20b63166be5 100644 --- a/caffe2/python/operator_test/feature_maps_ops_test.py +++ b/caffe2/python/operator_test/feature_maps_ops_test.py @@ -2,7 +2,7 @@ -from caffe2.python import core, workspace, dyndep +from caffe2.python import core, workspace from caffe2.python.test_util import TestCase import numpy as np diff --git a/caffe2/python/operator_test/fused_nbit_rowwise_conversion_ops_test.py b/caffe2/python/operator_test/fused_nbit_rowwise_conversion_ops_test.py index 12d0b0265afbc..b7cb5f68351f5 100644 --- a/caffe2/python/operator_test/fused_nbit_rowwise_conversion_ops_test.py +++ b/caffe2/python/operator_test/fused_nbit_rowwise_conversion_ops_test.py @@ -46,7 +46,7 @@ def int8_to_bytes(int8s): if isinstance(as_bytes[0], int): byte_matrix[i] = list(as_bytes) else: - byte_matrix[i] = list(map(ord, as_bytes)) + byte_matrix[i] = [ord(i) for i in as_bytes] return byte_matrix diff --git a/caffe2/python/operator_test/glu_op_test.py b/caffe2/python/operator_test/glu_op_test.py index f38df09ec9fb1..7b7a33dcd90a4 100644 --- a/caffe2/python/operator_test/glu_op_test.py +++ b/caffe2/python/operator_test/glu_op_test.py @@ -6,7 +6,7 @@ from caffe2.python import core import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial -from hypothesis import assume, given, settings, HealthCheck +from hypothesis import given, settings import hypothesis.strategies as st import numpy as np diff --git a/caffe2/python/operator_test/group_conv_test.py b/caffe2/python/operator_test/group_conv_test.py index 62aba236d5bae..8e864bb421526 100644 --- a/caffe2/python/operator_test/group_conv_test.py +++ b/caffe2/python/operator_test/group_conv_test.py @@ -12,7 +12,6 @@ import caffe2.python.hypothesis_test_util as hu import unittest -import os class TestGroupConvolution(hu.HypothesisTestCase): diff --git a/caffe2/python/operator_test/gru_test.py b/caffe2/python/operator_test/gru_test.py index 99444f39ac266..1a7db2634989d 100644 --- a/caffe2/python/operator_test/gru_test.py +++ b/caffe2/python/operator_test/gru_test.py @@ -16,7 +16,6 @@ import hypothesis.strategies as st import numpy as np import unittest -import os def gru_unit(*args, **kwargs): diff --git a/caffe2/python/operator_test/hyperbolic_ops_test.py b/caffe2/python/operator_test/hyperbolic_ops_test.py index 90a8197e7ccfc..c0a1e8f49f5ad 100644 --- a/caffe2/python/operator_test/hyperbolic_ops_test.py +++ b/caffe2/python/operator_test/hyperbolic_ops_test.py @@ -4,7 +4,6 @@ from caffe2.python import core -from hypothesis import given import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial import hypothesis.strategies as st diff --git a/caffe2/python/operator_test/im2col_col2im_test.py b/caffe2/python/operator_test/im2col_col2im_test.py index 760228382bc60..42cb1deaf8aeb 100644 --- a/caffe2/python/operator_test/im2col_col2im_test.py +++ b/caffe2/python/operator_test/im2col_col2im_test.py @@ -10,9 +10,6 @@ import hypothesis.strategies as st import numpy as np -import unittest -import os - class TestReduceFrontSum(hu.HypothesisTestCase): @given(batch_size=st.integers(1, 3), diff --git a/caffe2/python/operator_test/image_input_op_test.py b/caffe2/python/operator_test/image_input_op_test.py index 0de1f0ad048bb..6bed69af9ae08 100644 --- a/caffe2/python/operator_test/image_input_op_test.py +++ b/caffe2/python/operator_test/image_input_op_test.py @@ -13,7 +13,7 @@ from PIL import Image import numpy as np import shutil -import six +import io import sys import tempfile @@ -134,7 +134,7 @@ def create_test(output_dir, width, height, default_bound, minsize, crop, means, img_array = np.random.random_integers( 0, 255, [height, width, 3]).astype(np.uint8) img_obj = Image.fromarray(img_array) - img_str = six.BytesIO() + img_str = io.BytesIO() img_obj.save(img_str, 'PNG') # Create a random bounding box for every other image diff --git a/caffe2/python/operator_test/instance_norm_test.py b/caffe2/python/operator_test/instance_norm_test.py index fb4f3c935ba8e..efce9d7001fe7 100644 --- a/caffe2/python/operator_test/instance_norm_test.py +++ b/caffe2/python/operator_test/instance_norm_test.py @@ -11,7 +11,6 @@ import caffe2.python.serialized_test.serialized_test_util as serial import unittest -import os class TestInstanceNorm(serial.SerializedTestCase): diff --git a/caffe2/python/operator_test/jsd_ops_test.py b/caffe2/python/operator_test/jsd_ops_test.py index 6ed2db2e88c28..f205d8e650b2e 100644 --- a/caffe2/python/operator_test/jsd_ops_test.py +++ b/caffe2/python/operator_test/jsd_ops_test.py @@ -4,7 +4,6 @@ from caffe2.python import core -from hypothesis import given import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial import hypothesis.strategies as st diff --git a/caffe2/python/operator_test/layer_norm_op_test.py b/caffe2/python/operator_test/layer_norm_op_test.py index 62e94afe9e7d7..d402cce4c4f98 100644 --- a/caffe2/python/operator_test/layer_norm_op_test.py +++ b/caffe2/python/operator_test/layer_norm_op_test.py @@ -13,7 +13,6 @@ import hypothesis.strategies as st import numpy as np -import os import torch import unittest diff --git a/caffe2/python/operator_test/learning_rate_op_test.py b/caffe2/python/operator_test/learning_rate_op_test.py index bdce6a4c78f78..8d17c0c7ef08e 100644 --- a/caffe2/python/operator_test/learning_rate_op_test.py +++ b/caffe2/python/operator_test/learning_rate_op_test.py @@ -50,7 +50,7 @@ def ref(iter): def test_hill_learning_rate_op(self, gc, dc): iter = np.random.randint(low=1, high=1e5, size=1) - num_iter = int(np.random.randint(low=1e2, high=1e3, size=1)) + num_iter = int(np.random.randint(low=1e2, high=1e8, size=1)) start_multiplier = 1e-4 gamma = 1.0 power = 0.5 diff --git a/caffe2/python/operator_test/lengths_pad_op_test.py b/caffe2/python/operator_test/lengths_pad_op_test.py index 626ec0542b7da..cda2f7da323ef 100644 --- a/caffe2/python/operator_test/lengths_pad_op_test.py +++ b/caffe2/python/operator_test/lengths_pad_op_test.py @@ -4,7 +4,6 @@ from caffe2.python import core -from hypothesis import given import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial import hypothesis.strategies as st diff --git a/caffe2/python/operator_test/lengths_reducer_fused_nbit_rowwise_ops_test.py b/caffe2/python/operator_test/lengths_reducer_fused_nbit_rowwise_ops_test.py index fc4e89e2545bd..49b0ba7ec22c4 100644 --- a/caffe2/python/operator_test/lengths_reducer_fused_nbit_rowwise_ops_test.py +++ b/caffe2/python/operator_test/lengths_reducer_fused_nbit_rowwise_ops_test.py @@ -3,7 +3,7 @@ import caffe2.python.hypothesis_test_util as hu import hypothesis.strategies as st import numpy as np -from caffe2.python import core, dyndep, workspace +from caffe2.python import core, workspace from hypothesis import given diff --git a/caffe2/python/operator_test/lengths_tile_op_test.py b/caffe2/python/operator_test/lengths_tile_op_test.py index e0a5f96095882..441fcc7478357 100644 --- a/caffe2/python/operator_test/lengths_tile_op_test.py +++ b/caffe2/python/operator_test/lengths_tile_op_test.py @@ -4,7 +4,6 @@ from caffe2.python import core -from hypothesis import given import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial import hypothesis.strategies as st diff --git a/caffe2/python/operator_test/lengths_top_k_ops_test.py b/caffe2/python/operator_test/lengths_top_k_ops_test.py index b8b082a021253..d2d55c531ec0a 100644 --- a/caffe2/python/operator_test/lengths_top_k_ops_test.py +++ b/caffe2/python/operator_test/lengths_top_k_ops_test.py @@ -19,7 +19,7 @@ def test_lengths_top_k_op(self, N, K, gc, dc): lens = np.random.randint(low=1, high=2 * K + 1, size=N).astype(np.int32) X = [] for i in lens: - X.extend(map(lambda x: x / 100.0, range(0, 6 * i, 6))) + X.extend(x / 100.0 for x in range(0, 6 * i, 6)) X = np.array(X, dtype=np.float32) op = core.CreateOperator("LengthsTopK", ["X", "Y"], ["values", "indices"], k=K) diff --git a/caffe2/python/operator_test/loss_ops_test.py b/caffe2/python/operator_test/loss_ops_test.py index 24cb65ac96f8a..f6a07ead3cf99 100644 --- a/caffe2/python/operator_test/loss_ops_test.py +++ b/caffe2/python/operator_test/loss_ops_test.py @@ -4,7 +4,6 @@ from caffe2.python import core -from hypothesis import given import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial import hypothesis.strategies as st diff --git a/caffe2/python/operator_test/matmul_op_test.py b/caffe2/python/operator_test/matmul_op_test.py index b8cef19b24dfe..8b4001a574ac6 100644 --- a/caffe2/python/operator_test/matmul_op_test.py +++ b/caffe2/python/operator_test/matmul_op_test.py @@ -9,8 +9,6 @@ from hypothesis import assume, given, settings import hypothesis.strategies as st - -from caffe2.proto import caffe2_pb2 from caffe2.python import core import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial diff --git a/caffe2/python/operator_test/mean_op_test.py b/caffe2/python/operator_test/mean_op_test.py index 5830089f8e9bb..ee2c6fc8fbf7a 100644 --- a/caffe2/python/operator_test/mean_op_test.py +++ b/caffe2/python/operator_test/mean_op_test.py @@ -6,8 +6,6 @@ from caffe2.python import core import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial - -from hypothesis import given import hypothesis.strategies as st import numpy as np import unittest diff --git a/caffe2/python/operator_test/mod_op_test.py b/caffe2/python/operator_test/mod_op_test.py index 914bffd2067c6..03ff766c11e43 100644 --- a/caffe2/python/operator_test/mod_op_test.py +++ b/caffe2/python/operator_test/mod_op_test.py @@ -1,12 +1,7 @@ - - - - - import numpy from caffe2.python import core -from hypothesis import given +from hypothesis import given, settings import caffe2.python.hypothesis_test_util as hu import hypothesis.strategies as st @@ -16,7 +11,8 @@ @st.composite def _data(draw): return draw( - hu.tensor(dtype=np.int64, + hu.tensor( + dtype=np.int64, elements=st.integers( min_value=np.iinfo(np.int64).min, max_value=np.iinfo(np.int64).max ) @@ -25,6 +21,7 @@ def _data(draw): class TestMod(hu.HypothesisTestCase): + @settings(deadline=None) @given( data=_data(), divisor=st.integers( @@ -32,7 +29,7 @@ class TestMod(hu.HypothesisTestCase): ), inplace=st.booleans(), sign_follow_divisor=st.booleans(), - **hu.gcs_cpu_only + **hu.gcs ) def test_mod( self, data, divisor, inplace, sign_follow_divisor, gc, dc diff --git a/caffe2/python/operator_test/moments_op_test.py b/caffe2/python/operator_test/moments_op_test.py index 3b270df254ce9..bee44e360e3f1 100644 --- a/caffe2/python/operator_test/moments_op_test.py +++ b/caffe2/python/operator_test/moments_op_test.py @@ -4,7 +4,6 @@ from caffe2.python import core -from hypothesis import given import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial diff --git a/caffe2/python/operator_test/numpy_tile_op_test.py b/caffe2/python/operator_test/numpy_tile_op_test.py index a202581f808c9..c32aa99470db4 100644 --- a/caffe2/python/operator_test/numpy_tile_op_test.py +++ b/caffe2/python/operator_test/numpy_tile_op_test.py @@ -9,7 +9,7 @@ import hypothesis.strategies as st import unittest -from caffe2.python import core, workspace +from caffe2.python import core import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial diff --git a/caffe2/python/operator_test/onnx_while_test.py b/caffe2/python/operator_test/onnx_while_test.py index 4cff53b87d6ef..5ad9c277239d1 100644 --- a/caffe2/python/operator_test/onnx_while_test.py +++ b/caffe2/python/operator_test/onnx_while_test.py @@ -3,7 +3,7 @@ from caffe2.proto import caffe2_pb2 -from caffe2.python import core, workspace +from caffe2.python import core import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial from hypothesis import given, settings diff --git a/caffe2/python/operator_test/pack_rnn_sequence_op_test.py b/caffe2/python/operator_test/pack_rnn_sequence_op_test.py index 9a76e6b847a5a..eceb1e5ba6a94 100644 --- a/caffe2/python/operator_test/pack_rnn_sequence_op_test.py +++ b/caffe2/python/operator_test/pack_rnn_sequence_op_test.py @@ -4,7 +4,6 @@ from caffe2.python import core -from hypothesis import given import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial import hypothesis.strategies as st diff --git a/caffe2/python/operator_test/pad_test.py b/caffe2/python/operator_test/pad_test.py index 6d4e6bbdcd08f..788c4035dd5f5 100644 --- a/caffe2/python/operator_test/pad_test.py +++ b/caffe2/python/operator_test/pad_test.py @@ -5,8 +5,6 @@ from caffe2.python import core import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial - -from hypothesis import given import hypothesis.strategies as st import numpy as np import unittest diff --git a/caffe2/python/operator_test/percentile_op_test.py b/caffe2/python/operator_test/percentile_op_test.py index d81b0a963185a..40c4192e21e9f 100644 --- a/caffe2/python/operator_test/percentile_op_test.py +++ b/caffe2/python/operator_test/percentile_op_test.py @@ -3,7 +3,7 @@ -from caffe2.python import core, workspace, dyndep +from caffe2.python import core, workspace import caffe2.python.hypothesis_test_util as hu import numpy as np diff --git a/caffe2/python/operator_test/rand_quantization_op_test.py b/caffe2/python/operator_test/rand_quantization_op_test.py index e244f77149e1f..a702ab41577f9 100644 --- a/caffe2/python/operator_test/rand_quantization_op_test.py +++ b/caffe2/python/operator_test/rand_quantization_op_test.py @@ -6,7 +6,6 @@ import numpy as np import struct import unittest -import os from hypothesis import given, example import hypothesis.strategies as st diff --git a/caffe2/python/operator_test/recurrent_network_test.py b/caffe2/python/operator_test/recurrent_network_test.py index 13650e6cad4e9..33ada4d6881c8 100644 --- a/caffe2/python/operator_test/recurrent_network_test.py +++ b/caffe2/python/operator_test/recurrent_network_test.py @@ -11,9 +11,6 @@ import hypothesis.strategies as st import numpy as np -import os -import unittest - class RecurrentNetworkTest(serial.SerializedTestCase): @given(T=st.integers(1, 4), n=st.integers(1, 5), diff --git a/caffe2/python/operator_test/reduce_ops_test.py b/caffe2/python/operator_test/reduce_ops_test.py index 727631befe89d..7b79b3b81aed1 100644 --- a/caffe2/python/operator_test/reduce_ops_test.py +++ b/caffe2/python/operator_test/reduce_ops_test.py @@ -11,7 +11,6 @@ import hypothesis.strategies as st import numpy as np import itertools as it -import unittest class TestReduceOps(serial.SerializedTestCase): diff --git a/caffe2/python/operator_test/reduction_ops_test.py b/caffe2/python/operator_test/reduction_ops_test.py index 7d4287df66097..6a99f2b27d429 100644 --- a/caffe2/python/operator_test/reduction_ops_test.py +++ b/caffe2/python/operator_test/reduction_ops_test.py @@ -3,7 +3,6 @@ -from caffe2.proto import caffe2_pb2 from caffe2.python import core, workspace from hypothesis import assume, given, settings import caffe2.python.hypothesis_test_util as hu diff --git a/caffe2/python/operator_test/reshape_ops_test.py b/caffe2/python/operator_test/reshape_ops_test.py index a42f00bbf82f4..dc90b6815f017 100644 --- a/caffe2/python/operator_test/reshape_ops_test.py +++ b/caffe2/python/operator_test/reshape_ops_test.py @@ -3,7 +3,6 @@ import numpy as np -import six from numpy.testing import assert_array_equal from caffe2.python import core, workspace diff --git a/caffe2/python/operator_test/roi_align_rotated_op_test.py b/caffe2/python/operator_test/roi_align_rotated_op_test.py index c74157a039b03..ea835acead617 100644 --- a/caffe2/python/operator_test/roi_align_rotated_op_test.py +++ b/caffe2/python/operator_test/roi_align_rotated_op_test.py @@ -3,7 +3,6 @@ -from caffe2.proto import caffe2_pb2 from caffe2.python import core, workspace from hypothesis import given import caffe2.python.hypothesis_test_util as hu diff --git a/caffe2/python/operator_test/self_binning_histogram_test.py b/caffe2/python/operator_test/self_binning_histogram_test.py index 14a37872ee5a5..afcf5ea57e3e7 100644 --- a/caffe2/python/operator_test/self_binning_histogram_test.py +++ b/caffe2/python/operator_test/self_binning_histogram_test.py @@ -8,9 +8,10 @@ class TestSelfBinningHistogramBase(object): - def __init__(self, bin_spacing, dtype): + def __init__(self, bin_spacing, dtype, abs=False): self.bin_spacing = bin_spacing self.dtype = dtype + self.abs = abs def _check_histogram(self, arrays, num_bins, expected_values=None, expected_counts=None): # Check that sizes match and counts add up. @@ -20,28 +21,39 @@ def _check_histogram(self, arrays, num_bins, expected_values=None, expected_coun self.assertTrue(np.size(counts) == num_bins) self.assertTrue(np.sum(counts) == sum([np.size(array) for array in arrays])) - + # Check counts if expected_counts is None: # Check that counts are correct for the returned values if expected_counts is not given. expected_counts = np.zeros(num_bins, dtype='i') for array in arrays: - for i in array: + for input_val in array: + input_val = abs(input_val) if self.abs else input_val found = False for pos in range(np.size(values)): - if values[pos] > i: + if values[pos] > input_val: found = True break - self.assertTrue(found, "input array must fit inside values array") + self.assertTrue(found, f"input value must fit inside values array: " + f"input={input_val}, last_value={values[-1]}") if self.bin_spacing == "linear": - self.assertTrue(pos > 0, "first value should be the smallest") + self.assertTrue(pos > 0, + f"input should not be smaller than the first bin value: " + f"input={input_val}, 1st bin value={values[pos]}") if pos == 0: self.assertEqual(self.bin_spacing, "logarithmic") expected_counts[pos] += 1 else: expected_counts[pos - 1] += 1 self.assertTrue(np.array_equal(expected_counts, counts), f"expected:{expected_counts}\ncounts:{counts}") + # Check values if expected_values is not None: - self.assertTrue(np.array_equal(expected_values, values), f"expected:{expected_values}\ncounts:{values}") + self.assertTrue(np.allclose(expected_values, values, rtol=1e-02, atol=1e-05), + f"expected:{expected_values}\nvalues:{values}") + # Ideally, the output values are sorted in a non-decreasing order. + for idx in range(len(values) - 1): + self.assertTrue(values[idx] <= values[idx + 1]) + if self.abs: + self.assertTrue(values[0] >= 0) def _run_single_op_net(self, arrays, num_bins, logspacing_start=None): @@ -57,6 +69,7 @@ def _run_single_op_net(self, arrays, num_bins, logspacing_start=None): num_bins=num_bins, bin_spacing=self.bin_spacing, logspacing_start=logspacing_start, + abs=self.abs ) else: net.SelfBinningHistogram( @@ -64,6 +77,7 @@ def _run_single_op_net(self, arrays, num_bins, logspacing_start=None): ["histogram_values", "histogram_counts"], num_bins=num_bins, bin_spacing=self.bin_spacing, + abs=self.abs ) workspace.RunNetOnce(net) @@ -82,10 +96,25 @@ def test_histogram_device_consistency(self, rows, cols, gc, dc): def test_histogram_bin_to_fewer(self): X = np.array([-2.0, -2.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 6.0, 9.0], dtype=self.dtype) + if self.bin_spacing == 'linear': + if not self.abs: + expected_values = [-2., 0.2, 2.4, 4.6, 6.8, 9.] + expected_counts = [5, 2, 2, 1, 1, 0] + else: + expected_values = [0., 1.8, 3.6, 5.4, 7.2, 9.] + expected_counts = [4, 4, 1, 1, 1, 0] + else: + expected_values = [1.e-24, 9.8e-20, 9.6e-15, 9.4e-10, 9.2e-05, 9.] + if not self.abs: + expected_counts = [5, 0, 0, 0, 6, 0] + else: + expected_counts = [3, 0, 0, 0, 8, 0] self._run_single_op_net([X], 5) self._check_histogram( [X], 6, + expected_values=expected_values, + expected_counts=expected_counts ) def test_histogram_bin_to_more(self): @@ -99,10 +128,20 @@ def test_histogram_bin_to_more(self): def test_histogram_bin_to_two(self): """This test roughly tests [min,max+EPSILON] and [N,0]""" X = np.array([-2.0, -2.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 6.0, 9.0], dtype=self.dtype) + if self.bin_spacing == 'linear': + if not self.abs: + expected_values = [-2., 9.] + else: + expected_values = [0., 9.] + else: + expected_values = [1.e-24, 9.] + expected_counts = [11, 0] self._run_single_op_net([X], 1) self._check_histogram( [X], 2, + expected_values=expected_values, + expected_counts=expected_counts ) def test_histogram_min_max_equal(self): @@ -129,7 +168,7 @@ def test_histogram_min_max_equal(self): def test_histogram_min_max_equal_nonzero(self): X = np.array([1., 1., 1., 1., 1.], dtype=self.dtype) logspacing_start = 1e-24 - self._run_single_op_net([X], 3, 1e-24) + self._run_single_op_net([X], 3, logspacing_start) self._check_histogram( [X], 4, @@ -143,33 +182,58 @@ def test_histogram_empty_input_tensor(self): self._check_histogram( [X], 2, + expected_values=[0., 0.], + expected_counts=[0, 0] ) self._run_single_op_net([X], 10) self._check_histogram( [X], 11, + expected_values=[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + expected_counts=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] ) def test_histogram_multi_input(self): X1 = np.array([-2.0, -2.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 6.0, 9.0], dtype=self.dtype) X2 = np.array([-5.0, -3.0, 7, 7, 0.0, 1.0, 2.0, -3.0, 4.0, 6.0, 9.0], dtype=self.dtype) + if self.bin_spacing == 'linear': + if not self.abs: + expected_values = [-5., -2.2, 0.6, 3.4, 6.2, 9.] + expected_counts = [3, 6, 5, 4, 4, 0] + else: + expected_values = [0., 1.8, 3.6, 5.4, 7.2, 9.] + expected_counts = [6, 7, 3, 4, 2, 0] + else: + expected_values = [1.e-24, 9.8e-20, 9.6e-15, 9.4e-10, 9.2e-05, 9.] + if not self.abs: + expected_counts = [9, 0, 0, 0, 13, 0] + else: + expected_counts = [4, 0, 0, 0, 18, 0] self._run_single_op_net([X1, X2], 5) self._check_histogram( [X1, X2], 6, + expected_values=expected_values, + expected_counts=expected_counts ) def test_histogram_very_small_range_for_stride_underflow(self): """Tests a large number of bins for a very small range of values. - This test uses float type. 1-e38 is very small, and with 1M bins, it + This test uses float type. 1-e302 is very small, and with 1M bins, it causes numeric underflow. This test is to show that this is handled. + + Note: this test was flaky due to how compiler and OS handls floats. + Previously, 1-e38 does not induce overflow and cuases test error for some + combinations of compiler and OS. Now 1-e302 should be small enough. """ - X = np.array([0, 1e-38], dtype='f') - self._run_single_op_net([X], 1000000) + X = np.array([0, 1e-302], dtype='f') + large_bin_number = 1000000 + self._run_single_op_net([X], large_bin_number) self._check_histogram( [X], - 1000001, + large_bin_number + 1, + expected_counts=[2] + [0] * large_bin_number # [2, 0, 0, ..., 0] ) @@ -200,6 +264,35 @@ def __init__(self, *args, **kwargs): TestSelfBinningHistogramBase.__init__(self, bin_spacing="logarithmic", dtype='f') hu.HypothesisTestCase.__init__(self, *args, **kwargs) +class TestSelfBinningHistogramLinearWithAbs(TestSelfBinningHistogramBase, hu.HypothesisTestCase): + def __init__(self, *args, **kwargs): + TestSelfBinningHistogramBase.__init__(self, bin_spacing="linear", dtype='d', abs=True) + hu.HypothesisTestCase.__init__(self, *args, **kwargs) + +class TestSelfBinningHistogramLogarithmicWithAbs(TestSelfBinningHistogramBase, hu.HypothesisTestCase): + def __init__(self, *args, **kwargs): + TestSelfBinningHistogramBase.__init__(self, bin_spacing="logarithmic", dtype='d', abs=True) + hu.HypothesisTestCase.__init__(self, *args, **kwargs) + +class TestSelfBinningHistogramLinearFloatWithAbs(TestSelfBinningHistogramBase, hu.HypothesisTestCase): + def __init__(self, *args, **kwargs): + TestSelfBinningHistogramBase.__init__(self, bin_spacing="linear", dtype='f', abs=True) + hu.HypothesisTestCase.__init__(self, *args, **kwargs) + +class TestSelfBinningHistogramLogarithmicFloatWithAbs(TestSelfBinningHistogramBase, hu.HypothesisTestCase): + def __init__(self, *args, **kwargs): + TestSelfBinningHistogramBase.__init__(self, bin_spacing="logarithmic", dtype='f', abs=True) + hu.HypothesisTestCase.__init__(self, *args, **kwargs) + +class TestSelfBinningHistogramLinearWithNoneAbs(TestSelfBinningHistogramBase, hu.HypothesisTestCase): + def __init__(self, *args, **kwargs): + TestSelfBinningHistogramBase.__init__(self, bin_spacing="linear", dtype='d', abs=None) + hu.HypothesisTestCase.__init__(self, *args, **kwargs) + +class TestSelfBinningHistogramLinearFloatWithNoneAbs(TestSelfBinningHistogramBase, hu.HypothesisTestCase): + def __init__(self, *args, **kwargs): + TestSelfBinningHistogramBase.__init__(self, bin_spacing="linear", dtype='f', abs=None) + hu.HypothesisTestCase.__init__(self, *args, **kwargs) if __name__ == "__main__": global_options = ["caffe2"] diff --git a/caffe2/python/operator_test/sequence_ops_test.py b/caffe2/python/operator_test/sequence_ops_test.py index 4609473f91f08..65c0669abfb00 100644 --- a/caffe2/python/operator_test/sequence_ops_test.py +++ b/caffe2/python/operator_test/sequence_ops_test.py @@ -11,7 +11,6 @@ import hypothesis.strategies as st import numpy as np import unittest -import os def _gen_test_add_padding(with_pad_data=True, diff --git a/caffe2/python/operator_test/sparse_normalize_test.py b/caffe2/python/operator_test/sparse_normalize_test.py index ecc4ae0c8d22d..e9777dad36dc4 100644 --- a/caffe2/python/operator_test/sparse_normalize_test.py +++ b/caffe2/python/operator_test/sparse_normalize_test.py @@ -1,19 +1,14 @@ +from __future__ import absolute_import, division, print_function, unicode_literals - - - - +import caffe2.python.hypothesis_test_util as hu import hypothesis -from hypothesis import given, settings, HealthCheck import hypothesis.strategies as st import numpy as np - from caffe2.python import core -import caffe2.python.hypothesis_test_util as hu +from hypothesis import HealthCheck, given, settings class TestSparseNormalize(hu.HypothesisTestCase): - @staticmethod def ref_normalize(param_in, use_max_norm, norm): param_norm = np.linalg.norm(param_in) + 1e-12 @@ -24,31 +19,44 @@ def ref_normalize(param_in, use_max_norm, norm): # Suppress filter_too_much health check. # Likely caused by `assume` call falling through too often. @settings(suppress_health_check=[HealthCheck.filter_too_much]) - @given(inputs=hu.tensors(n=2, min_dim=2, max_dim=2), - use_max_norm=st.booleans(), - norm=st.floats(min_value=1.0, max_value=4.0), - data_strategy=st.data(), - **hu.gcs_cpu_only) - def test_sparse_normalize(self, inputs, use_max_norm, norm, - data_strategy, gc, dc): + @given( + inputs=hu.tensors(n=2, min_dim=2, max_dim=2), + use_max_norm=st.booleans(), + norm=st.floats(min_value=1.0, max_value=4.0), + data_strategy=st.data(), + use_fp16=st.booleans(), + **hu.gcs_cpu_only + ) + def test_sparse_normalize( + self, inputs, use_max_norm, norm, data_strategy, use_fp16, gc, dc + ): param, grad = inputs param += 0.02 * np.sign(param) param[param == 0.0] += 0.02 + if use_fp16: + param = param.astype(np.float16) + grad = grad.astype(np.float16) + # Create an indexing array containing values that are lists of indices, # which index into param indices = data_strategy.draw( - hu.tensor(dtype=np.int64, min_dim=1, max_dim=1, - elements=st.sampled_from(np.arange(param.shape[0]))), + hu.tensor( + dtype=np.int64, + min_dim=1, + max_dim=1, + elements=st.sampled_from(np.arange(param.shape[0])), + ) ) - hypothesis.note('indices.shape: %s' % str(indices.shape)) + hypothesis.note("indices.shape: %s" % str(indices.shape)) # For now, the indices must be unique - hypothesis.assume(np.array_equal(np.unique(indices.flatten()), - np.sort(indices.flatten()))) + hypothesis.assume( + np.array_equal(np.unique(indices.flatten()), np.sort(indices.flatten())) + ) op1 = core.CreateOperator( - "SparseNormalize", + "Float16SparseNormalize" if use_fp16 else "SparseNormalize", ["param", "indices"], ["param"], use_max_norm=use_max_norm, @@ -59,7 +67,7 @@ def test_sparse_normalize(self, inputs, use_max_norm, norm, grad = grad[indices] op2 = core.CreateOperator( - "SparseNormalize", + "Float16SparseNormalize" if use_fp16 else "SparseNormalize", ["param", "indices", "grad"], ["param"], use_max_norm=use_max_norm, @@ -69,20 +77,22 @@ def test_sparse_normalize(self, inputs, use_max_norm, norm, def ref_sparse_normalize(param, indices, grad=None): param_out = np.copy(param) for _, index in enumerate(indices): - param_out[index] = self.ref_normalize( - param[index], - use_max_norm, - norm, - ) + param_out[index] = self.ref_normalize(param[index], use_max_norm, norm) return (param_out,) # self.assertDeviceChecks(dc, op, [param, indices], [0]) self.assertReferenceChecks( - gc, op1, [param, indices], - ref_sparse_normalize + gc, + op1, + [param, indices], + ref_sparse_normalize, + threshold=1e-2 if use_fp16 else 1e-4, ) self.assertReferenceChecks( - gc, op2, [param, indices, grad], - ref_sparse_normalize + gc, + op2, + [param, indices, grad], + ref_sparse_normalize, + threshold=1e-2 if use_fp16 else 1e-4, ) diff --git a/caffe2/python/operator_test/spatial_bn_op_test.py b/caffe2/python/operator_test/spatial_bn_op_test.py index 35f7bd2a5e294..21a530346329c 100644 --- a/caffe2/python/operator_test/spatial_bn_op_test.py +++ b/caffe2/python/operator_test/spatial_bn_op_test.py @@ -3,7 +3,6 @@ -from caffe2.proto import caffe2_pb2 from caffe2.python import brew, core, utils, workspace import caffe2.python.hip_test_util as hiputl import caffe2.python.hypothesis_test_util as hu diff --git a/caffe2/python/operator_test/square_root_divide_op_test.py b/caffe2/python/operator_test/square_root_divide_op_test.py index 5bd6cb1d08f86..51f328c95f5f2 100644 --- a/caffe2/python/operator_test/square_root_divide_op_test.py +++ b/caffe2/python/operator_test/square_root_divide_op_test.py @@ -5,7 +5,6 @@ from caffe2.python import core from functools import partial -from hypothesis import given from hypothesis import strategies as st import caffe2.python.hypothesis_test_util as hu diff --git a/caffe2/python/operator_test/string_ops_test.py b/caffe2/python/operator_test/string_ops_test.py index eedb57be1d6c4..a0c56a6866664 100644 --- a/caffe2/python/operator_test/string_ops_test.py +++ b/caffe2/python/operator_test/string_ops_test.py @@ -119,6 +119,33 @@ def string_ends_with_ref(strings): [strings], string_ends_with_ref) + @given(strings=st.text(alphabet=['a', 'b'])) + @settings(deadline=1000) + def test_string_equals(self, strings): + text = "" + if strings: + text = strings[0] + + strings = np.array( + [str(a) for a in strings], dtype=np.object + ) + + def string_equals_ref(strings): + return ( + np.array([a == text for a in strings], dtype=bool), + ) + + op = core.CreateOperator( + 'StringEquals', + ['strings'], + ['bools'], + text=text) + self.assertReferenceChecks( + hu.cpu_do, + op, + [strings], + string_equals_ref) + if __name__ == "__main__": import unittest unittest.main() diff --git a/caffe2/python/operator_test/torch_integration_test.py b/caffe2/python/operator_test/torch_integration_test.py index 9bec647642403..d6185125e3ae7 100644 --- a/caffe2/python/operator_test/torch_integration_test.py +++ b/caffe2/python/operator_test/torch_integration_test.py @@ -1,12 +1,12 @@ +import struct +import unittest + import caffe2.python.hypothesis_test_util as hu import hypothesis.strategies as st import numpy as np -import struct import torch -import unittest - from caffe2.python import core, workspace from hypothesis import given, settings from scipy.stats import norm @@ -77,7 +77,7 @@ def create_bbox_transform_inputs(roi_counts, num_classes, rotated): def bytes_to_floats(byte_matrix): floats = np.empty([np.shape(byte_matrix)[0], 1], dtype=np.float32) for i, byte_values in enumerate(byte_matrix): - floats[i], = struct.unpack('f', bytearray(byte_values)) + (floats[i],) = struct.unpack("f", bytearray(byte_values)) return floats @@ -85,12 +85,12 @@ def floats_to_bytes(floats): byte_matrix = np.empty([np.shape(floats)[0], 4], dtype=np.uint8) for i, value in enumerate(floats): assert isinstance(value, np.float32), (value, floats) - as_bytes = struct.pack('f', value) + as_bytes = struct.pack("f", value) # In Python3 bytes will be a list of int, in Python2 a list of string if isinstance(as_bytes[0], int): byte_matrix[i] = list(as_bytes) else: - byte_matrix[i] = list(map(ord, as_bytes)) + byte_matrix[i] = [ord(i) for i in as_bytes] return byte_matrix @@ -180,6 +180,7 @@ def bbox_transform_ref(): rotated=st.booleans(), angle_bound_on=st.booleans(), clip_angle_thresh=st.sampled_from([-1.0, 1.0]), + batch_splits_dtype=st.sampled_from([torch.float32, torch.int32]), **hu.gcs_cpu_only ) def test_box_with_nms_limits( @@ -189,6 +190,7 @@ def test_box_with_nms_limits( rotated, angle_bound_on, clip_angle_thresh, + batch_splits_dtype, gc, dc, ): @@ -250,7 +252,7 @@ def box_with_nms_limit_ref(): outputs = torch.ops._caffe2.BoxWithNMSLimit( torch.tensor(class_prob), torch.tensor(pred_bbox), - torch.tensor(batch_splits), + torch.tensor(batch_splits, dtype=batch_splits_dtype), score_thresh=float(score_thresh), nms=float(nms_thresh), detections_per_im=int(topk_per_image), @@ -268,6 +270,69 @@ def box_with_nms_limit_ref(): for o, o_ref in zip(outputs, output_refs): torch.testing.assert_allclose(o, o_ref) + @given( + dim_1=st.integers(min_value=10, max_value=10), + dim_2=st.integers(min_value=3, max_value=3), + dim_3=st.integers(min_value=2, max_value=2), + ) + def test_sparse_to_dense_mask(self, dim_1, dim_2, dim_3): + indices = np.array([i + 1 for i in range(dim_1)]).astype(np.int32) + values = np.random.rand(dim_1, dim_2, dim_3).astype(np.float32) + default_value = np.zeros((dim_2, dim_3)).astype(np.float32) + mask = [2, 4, 9] + + def sparse_to_dense_mask_ref(return_presence_mask=False): + ref_op = core.CreateOperator( + "SparseToDenseMask", + ["indices", "values", "default_value"], + ["output", "presence_mask"], + mask=mask, + return_presence_mask=return_presence_mask, + ) + workspace.FeedBlob("indices", indices) + workspace.FeedBlob("values", values) + workspace.FeedBlob("default_value", default_value) + workspace.RunOperatorOnce(ref_op) + + if return_presence_mask: + return ( + workspace.FetchBlob("output"), + workspace.FetchBlob("presence_mask"), + ) + + return workspace.FetchBlob("output") + + # Testing return_presence_mask = False + output = sparse_to_dense_mask_ref() + output = torch.tensor(output) + + a, _ = torch.ops._caffe2.SparseToDenseMask( + torch.tensor(indices), + torch.tensor(values), + torch.tensor(default_value), + None, + mask=mask, + ) + + torch.testing.assert_allclose(output, a) + + # Testing return_presence_mask = True + output, presence_mask = sparse_to_dense_mask_ref(return_presence_mask=True) + output = torch.tensor(output) + presence_mask = torch.tensor(presence_mask) + + a, b = torch.ops._caffe2.SparseToDenseMask( + torch.tensor(indices), + torch.tensor(values), + torch.tensor(default_value), + None, + mask=mask, + return_presence_mask=True, + ) + + torch.testing.assert_allclose(output, a) + torch.testing.assert_allclose(presence_mask, b) + @given( A=st.integers(min_value=4, max_value=4), H=st.integers(min_value=10, max_value=10), @@ -380,7 +445,7 @@ def inference_lstm_ref(): return ( workspace.FetchBlob("output"), workspace.FetchBlob("hidden"), - workspace.FetchBlob("cell") + workspace.FetchBlob("cell"), ) output, hidden, cell = inference_lstm_ref() @@ -526,7 +591,7 @@ def rand_rotated_roi(): np.random.rand() * H, np.random.rand() * W, np.random.rand() * H, - np.random.rand() * 360 - 180 + np.random.rand() * 360 - 180, ] ).astype(np.float32) @@ -613,18 +678,19 @@ def test_collect_and_distribute_fpn_rpn_proposals_op(self, roi_counts): for x, y in zip(fpn_outputs, all_outputs[1:]): torch.testing.assert_allclose(x, y) - @given(X=hu.tensor(), - fast_gelu=st.booleans()) + @given(X=hu.tensor(), fast_gelu=st.booleans()) def _test_gelu_op(self, X, fast_gelu, device): def _gelu_ref(_X): - return (_X * norm.cdf(_X).astype(np.float32), ) - expected_output, = _gelu_ref(X) + return (_X * norm.cdf(_X).astype(np.float32),) + + (expected_output,) = _gelu_ref(X) actual_output = torch.ops._caffe2.Gelu(torch.tensor(X), fast_gelu) rtol = 1e-3 if fast_gelu else 1e-4 atol = 1e-5 torch.testing.assert_allclose( - expected_output, actual_output.cpu(), rtol=rtol, atol=atol) + expected_output, actual_output.cpu(), rtol=rtol, atol=atol + ) def test_gelu_op(self): self._test_gelu_op(device="cpu") @@ -633,13 +699,11 @@ def test_gelu_op(self): def test_gelu_op_cuda(self): self._test_gelu_op(device="cuda") - - @given(inputs=hu.lengths_tensor( - dtype=np.float32, - min_value=1, - max_value=5, - allow_empty=True, - )) + @given( + inputs=hu.lengths_tensor( + dtype=np.float32, min_value=1, max_value=5, allow_empty=True + ) + ) def _test_lengths_op(self, inputs, ref_op_name, torch_op, device): data, lengths = inputs @@ -652,7 +716,8 @@ def _lengths_ref(X, Y): expected_output = _lengths_ref(data, lengths) actual_output = torch_op( - torch.tensor(data), torch.tensor(lengths, dtype=torch.int32)) + torch.tensor(data), torch.tensor(lengths, dtype=torch.int32) + ) torch.testing.assert_allclose(expected_output, actual_output.cpu()) @@ -691,8 +756,12 @@ def _test_resize_nearest_op(self, device): def _resize_nearest_ref(X): ref_op = core.CreateOperator( - "ResizeNearest", ["X"], ["Y"], - width_scale=2.0, height_scale=1.5, order="NCHW", + "ResizeNearest", + ["X"], + ["Y"], + width_scale=2.0, + height_scale=1.5, + order="NCHW", ) workspace.FeedBlob("X", X) workspace.RunOperatorOnce(ref_op) @@ -701,7 +770,9 @@ def _resize_nearest_ref(X): expected_output = _resize_nearest_ref(data) actual_output = torch.ops._caffe2.ResizeNearest( torch.tensor(data).to(device), - order="NCHW", width_scale=2.0, height_scale=1.5, + order="NCHW", + width_scale=2.0, + height_scale=1.5, ) torch.testing.assert_allclose(expected_output, actual_output.cpu()) @@ -716,9 +787,7 @@ def test_resize_nearest_op_cuda(self): @given(input_data=hu.tensor(min_dim=2, max_dim=2)) def test_Fused8BitRowwiseQuantizedToFloat(self, input_data): QuantizeOp = core.CreateOperator( - "FloatToFused8BitRowwiseQuantized", - ["input_data"], - ["quantized_data"], + "FloatToFused8BitRowwiseQuantized", ["input_data"], ["quantized_data"] ) workspace.FeedBlob("input_data", input_data) @@ -741,16 +810,15 @@ def test_piecewise_linear_op(self, binary_input): num_dims = 3 data = np.random.rand(1024, num_dims).astype(np.float32) slopes = np.zeros(4 * num_dims).astype(np.float32) - bounds = np.sort(np.random.rand(5, num_dims).astype(np.float32), axis=0).flatten('F') + bounds = np.sort( + np.random.rand(5, num_dims).astype(np.float32), axis=0 + ).flatten("F") intercepts = np.random.rand(4 * num_dims).astype(np.float32) def _piecewise_linear_ref(X): ref_op = core.CreateOperator( "PiecewiseLinearTransform", - ["data", - "bounds", - "slopes", - "intercepts"], + ["data", "bounds", "slopes", "intercepts"], ["calibrated"], binary=binary_input, ) @@ -763,7 +831,12 @@ def _piecewise_linear_ref(X): expected_output = _piecewise_linear_ref(data) actual_output = torch.ops._caffe2.PiecewiseLinearTransform( - torch.tensor(data), bounds.tolist(), slopes.tolist(), intercepts.tolist(), binary_input) + torch.tensor(data), + bounds.tolist(), + slopes.tolist(), + intercepts.tolist(), + binary_input, + ) torch.testing.assert_allclose(torch.tensor(expected_output), actual_output) @@ -790,9 +863,7 @@ def test_index_hash_op(self): data = np.random.randint(low=0, high=1000, size=(4, 4, 4)) def _index_hash_ref(X): - ref_op = core.CreateOperator( - "IndexHash", ["X"], ["Y"], seed=0, modulo=100 - ) + ref_op = core.CreateOperator("IndexHash", ["X"], ["Y"], seed=0, modulo=100) workspace.FeedBlob("X", X) workspace.RunOperatorOnce(ref_op) return workspace.FetchBlob("Y") @@ -817,33 +888,32 @@ def _bucketize_ref(X): return workspace.FetchBlob("Y") expected_output = _bucketize_ref(data) - actual_output = torch.ops._caffe2.Bucketize( - torch.tensor(data), boundaries - ) + actual_output = torch.ops._caffe2.Bucketize(torch.tensor(data), boundaries) torch.testing.assert_allclose(expected_output, actual_output.cpu()) - @given(X=hu.tensor(), - eps=st.floats(min_value=1e-4, max_value=1e-2), - ) + @given(X=hu.tensor(), eps=st.floats(min_value=1e-4, max_value=1e-2)) def test_logit(self, X, eps): def ref(X, eps): - ref_op = core.CreateOperator('Logit', ["X"], ["Y"], eps=eps) + ref_op = core.CreateOperator("Logit", ["X"], ["Y"], eps=eps) workspace.FeedBlob("X", X) workspace.RunOperatorOnce(ref_op) return workspace.FetchBlob("Y") + expected_output = ref(X, eps) - actual_output = torch.ops._caffe2.Logit( - torch.tensor(X), eps - ) + actual_output = torch.ops._caffe2.Logit(torch.tensor(X), eps) torch.testing.assert_allclose(expected_output, actual_output.cpu()) def test_percentile(self): - original_values = np.array([[3., 5., 3], [5., 1., 6.]]).astype(np.float32) - value_to_pct = np.array([[3, 0.2], [5, 0.5], [1, 0.3], [3, 0.6]]).astype(np.float32) + original_values = np.array([[3.0, 5.0, 3], [5.0, 1.0, 6.0]]).astype(np.float32) + value_to_pct = np.array([[3, 0.2], [5, 0.5], [1, 0.3], [3, 0.6]]).astype( + np.float32 + ) lengths = np.array([2, 1, 1]).astype(np.int32) def _percentile_ref(original_values, value_to_pct, lengths): - ref_op = core.CreateOperator('Percentile', ["original_values", "value_to_pct", "lengths"], ["Y"]) + ref_op = core.CreateOperator( + "Percentile", ["original_values", "value_to_pct", "lengths"], ["Y"] + ) workspace.FeedBlob("original_values", original_values) workspace.FeedBlob("value_to_pct", value_to_pct) workspace.FeedBlob("lengths", lengths) @@ -852,7 +922,9 @@ def _percentile_ref(original_values, value_to_pct, lengths): expected_output = _percentile_ref(original_values, value_to_pct, lengths) actual_output = torch.ops._caffe2.Percentile( - torch.tensor(original_values), torch.Tensor(value_to_pct), torch.Tensor(lengths).int() + torch.tensor(original_values), + torch.Tensor(value_to_pct), + torch.Tensor(lengths).int(), ) torch.testing.assert_allclose(expected_output, actual_output.cpu()) @@ -862,7 +934,9 @@ def test_batch_bucket_one_hot_op(self): boundaries = np.array([0.1, 2.5, 1, 3.1, 4.5]).astype(np.float32) def _batch_bucket_one_hot_ref(data, lengths, boundaries): - ref_op = core.CreateOperator('BatchBucketOneHot', ["data", "lengths", "boundaries"], ["Y"]) + ref_op = core.CreateOperator( + "BatchBucketOneHot", ["data", "lengths", "boundaries"], ["Y"] + ) workspace.FeedBlob("data", data) workspace.FeedBlob("lengths", lengths) workspace.FeedBlob("boundaries", boundaries) @@ -921,26 +995,43 @@ def test_gather_ranges_to_dense_op(self): def test_merge_id_lists(self, lengths_0, lengths_1): def _merge_id_lists(lengths, values): ref_op = core.CreateOperator( - 'MergeIdLists', + "MergeIdLists", ["lengths_0", "values_0", "lengths_1", "values_1"], - ["merged_lengths", "merged_values"] + ["merged_lengths", "merged_values"], ) workspace.FeedBlob("lengths_0", lengths[0]) workspace.FeedBlob("values_0", values[0]) workspace.FeedBlob("lengths_1", lengths[1]) workspace.FeedBlob("values_1", values[1]) workspace.RunOperatorOnce(ref_op) - return workspace.FetchBlob("merged_lengths"), workspace.FetchBlob("merged_values") + return ( + workspace.FetchBlob("merged_lengths"), + workspace.FetchBlob("merged_values"), + ) - lengths = [np.array([lengths_0]).astype(np.int32), np.array([lengths_1]).astype(np.int32)] + lengths = [ + np.array([lengths_0]).astype(np.int32), + np.array([lengths_1]).astype(np.int32), + ] values = [ - np.random.choice(np.arange(0, 10), size=lengths_0, replace=False).astype(np.int32), - np.random.choice(np.arange(10, 20), size=lengths_1, replace=False).astype(np.int32) + np.random.choice(np.arange(0, 10), size=lengths_0, replace=False).astype( + np.int32 + ), + np.random.choice(np.arange(10, 20), size=lengths_1, replace=False).astype( + np.int32 + ), ] - expected_merged_lengths, expected_merged_values = _merge_id_lists(lengths, values) + expected_merged_lengths, expected_merged_values = _merge_id_lists( + lengths, values + ) output_merged_lengths, output_merged_values = torch.ops._caffe2.MergeIdLists( - [torch.tensor(lengths[0]), torch.tensor(values[0]), torch.tensor(lengths[1]), torch.tensor(values[1])] + [ + torch.tensor(lengths[0]), + torch.tensor(values[0]), + torch.tensor(lengths[1]), + torch.tensor(values[1]), + ] ) torch.testing.assert_allclose(expected_merged_lengths, output_merged_lengths) torch.testing.assert_allclose(expected_merged_values, output_merged_values) @@ -1003,18 +1094,11 @@ def test_learning_rate(self): def test_pack_segments(self): s = torch.rand(3, 3, 3) lengths = torch.tensor([2, 1]) - packed_tensor, _ = torch.ops._caffe2.PackSegments( - lengths, - s, - ) + packed_tensor, _ = torch.ops._caffe2.PackSegments(lengths, s) self.assertEqual(packed_tensor.numpy().shape, (2, 2, 3, 3)) - unpacked_tensor = torch.ops._caffe2.UnpackSegments( - lengths, - packed_tensor, - ) + unpacked_tensor = torch.ops._caffe2.UnpackSegments(lengths, packed_tensor) torch.testing.assert_allclose(s, unpacked_tensor) - -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/caffe2/python/operator_test/unsafe_coalesce_test.py b/caffe2/python/operator_test/unsafe_coalesce_test.py new file mode 100644 index 0000000000000..36f10cf1b4263 --- /dev/null +++ b/caffe2/python/operator_test/unsafe_coalesce_test.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 + +import caffe2.python.hypothesis_test_util as hu +import hypothesis.strategies as st +import numpy as np +import numpy.testing as npt +from caffe2.python import core, workspace +from hypothesis import given + + +class TestUnsafeCoalesceOp(hu.HypothesisTestCase): + @given( + n=st.integers(1, 5), + shape=st.lists(st.integers(0, 5), min_size=1, max_size=3), + **hu.gcs + ) + def test_unsafe_coalesce_op(self, n, shape, dc, gc): + workspace.ResetWorkspace() + test_inputs = [(100 * np.random.random(shape)).astype(np.float32) for _ in range(n)] + test_input_blobs = ["x_{}".format(i) for i in range(n)] + + coalesce_op = core.CreateOperator( + "UnsafeCoalesce", + test_input_blobs, + test_input_blobs + ["shared_memory_blob"], + device_option=gc, + ) + + def reference_func(*args): + self.assertEquals(len(args), n) + return list(args) + [np.concatenate([x.flatten() for x in args])] + + self.assertReferenceChecks(gc, coalesce_op, test_inputs, reference_func) + + @given( + n=st.integers(1, 5), + shape=st.lists(st.integers(1, 5), min_size=1, max_size=3), + seed=st.integers(0, 65535), + **hu.gcs + ) + def test_unsafe_coalesce_op_blob_sharing(self, n, shape, seed, dc, gc): + workspace.ResetWorkspace() + # Can make debugging of the test more predictable + np.random.seed(seed) + test_inputs = [(np.random.random(shape)).astype(np.float32) for _ in range(n)] + test_input_blobs = ["x_{}".format(i) for i in range(n)] + + coalesce_op = core.CreateOperator( + "UnsafeCoalesce", + test_input_blobs, + test_input_blobs + ["shared_memory_blob"], + device_option=gc, + ) + for name, value in zip(test_input_blobs, test_inputs): + workspace.FeedBlob(name, value, device_option=gc) + + workspace.RunOperatorOnce(coalesce_op) + blob_value = workspace.blobs["shared_memory_blob"] + npt.assert_almost_equal( + blob_value, + np.concatenate([x.flatten() for x in test_inputs]), + decimal=4 + ) + # np.random generates values in range [0, 1), so -2 is outside of range + blob_value.fill(-2.0) + self.assertTrue((blob_value != workspace.blobs["shared_memory_blob"]).all()) + workspace.FeedBlob("shared_memory_blob", blob_value, device_option=gc) + + # All blobs preserved shape, but got overwritted to -2 + for name, value in zip(test_input_blobs, test_inputs): + self.assertEqual(value.shape, workspace.blobs[name].shape) + self.assertTrue((value != workspace.blobs[name]).all()) + self.assertTrue((workspace.blobs[name] == -2).all()) + + # It should be OK to reuse operator as long as it's blob shapes are not changing + workspace.RunOperatorOnce(coalesce_op) diff --git a/caffe2/python/operator_test/utility_ops_test.py b/caffe2/python/operator_test/utility_ops_test.py index 241d1e4c1b565..aeefbf596afeb 100644 --- a/caffe2/python/operator_test/utility_ops_test.py +++ b/caffe2/python/operator_test/utility_ops_test.py @@ -11,7 +11,6 @@ import hypothesis.strategies as st import numpy as np import random -import six class TestUtilityOps(serial.SerializedTestCase): @@ -474,7 +473,7 @@ def test_range(self, gc, dc): names[len(inputs) - 1], ["Y"] ) - with six.assertRaisesRegex(self, RuntimeError, 'Step size cannot be 0'): + with self.assertRaisesRegex(RuntimeError, 'Step size cannot be 0'): self.assertReferenceChecks( device_option=gc, op=op, diff --git a/caffe2/python/optimizer.py b/caffe2/python/optimizer.py index 9a2f9f5414209..bbea8d42ed4d4 100644 --- a/caffe2/python/optimizer.py +++ b/caffe2/python/optimizer.py @@ -571,6 +571,7 @@ def __init__( output_effective_lr_and_update=False, pruning_options=None, swa_options=None, + ema_options=None, weight_scale=None, counter_halflife=-1, **kwargs @@ -596,6 +597,7 @@ def __init__( self._process_pruning_options(pruning_options) self._process_swa_options(swa_options) + self._process_ema_options(ema_options) def _process_swa_options(self, swa_options): self.swa_enabled = True if swa_options else False @@ -606,6 +608,14 @@ def _process_swa_options(self, swa_options): self.swa_feedback_step = swa_options.get("swa_feedback_step", None) self.swa_feedback_end_it = swa_options.get("swa_feedback_end_it", None) + def _process_ema_options(self, ema_options): + self.ema_enabled = True if ema_options else False + if self.ema_enabled: + self.ema_start = ema_options.get("ema_start", None) + self.ema_end = ema_options.get("ema_end", None) + self.ema_step = ema_options.get("ema_step", None) + self.ema_alpha = ema_options.get("ema_alpha", None) + def _process_pruning_options(self, pruning_options): self.use_mask = False @@ -1045,6 +1055,22 @@ def _run(self, net, param_init_net, param_info): feedback_step=self.swa_feedback_step, feedback_end=self.swa_feedback_end_it, ) + + if self.ema_enabled: + param_ema = str(param) + "_ema" + if not param_init_net.BlobIsDefined(param_ema): + param_init_net.ConstantFill([param], param_ema, value=0.0) + self._aux_params.local.append(param_ema) + + net.EMA( + [param, param_ema, lr_iteration], + [param, param_ema], + ema_start=self.ema_start, + ema_end=self.ema_end, + ema_step=self.ema_step, + ema_alpha=self.ema_alpha, + ) + if self.weight_scale: net.WeightScale( [param, lr_iteration], diff --git a/caffe2/python/optimizer_context.py b/caffe2/python/optimizer_context.py index d1593f4403837..b214d136f61a2 100644 --- a/caffe2/python/optimizer_context.py +++ b/caffe2/python/optimizer_context.py @@ -13,8 +13,7 @@ DEFAULT_OPTIM = 'DEFAULT' -@context.define_context(allow_default=True) -class OptimizerContext(ModifierContext): +class OptimizerContext(ModifierContext, context.DefaultManaged): """ provide context to allow param_info to have different optimizers """ diff --git a/caffe2/python/optimizer_test_util.py b/caffe2/python/optimizer_test_util.py index 02276b08c1766..beb8a37818327 100644 --- a/caffe2/python/optimizer_test_util.py +++ b/caffe2/python/optimizer_test_util.py @@ -8,7 +8,6 @@ import unittest import numpy as np from caffe2.python import brew, core, workspace, cnn, optimizer -from caffe2.proto import caffe2_pb2 from caffe2.python.modeling.initializers import ( Initializer, PseudoFP16Initializer) diff --git a/caffe2/python/parallel_workers.py b/caffe2/python/parallel_workers.py index 4ee446610bdbf..067f4794a89f9 100644 --- a/caffe2/python/parallel_workers.py +++ b/caffe2/python/parallel_workers.py @@ -38,7 +38,6 @@ import atexit import time import collections -import six import traceback from abc import ABCMeta, abstractmethod @@ -110,7 +109,7 @@ def put_metric(self, key, value, count=True): class State(): - six.add_metaclass(ABCMeta) + __metaclass__ = ABCMeta @abstractmethod def start(self): diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index 2923b98c565ff..65a246e4a39c3 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -118,7 +118,7 @@ static_assert( sizeof(int) == sizeof(int32_t), "We make an assumption that int is always int32 for numpy " "type mapping."); -int CaffeToNumpyType(const TypeMeta& meta) { +int CaffeToNumpyType(const TypeMeta meta) { #ifdef USE_NUMPY static std::map numpy_type_map{ {TypeMeta::Id(), NPY_BOOL}, @@ -143,7 +143,7 @@ int CaffeToNumpyType(const TypeMeta& meta) { #endif // USE_NUMPY } -const TypeMeta& NumpyTypeToCaffe(int numpy_type) { +const TypeMeta NumpyTypeToCaffe(int numpy_type) { #ifdef USE_NUMPY static std::map caffe_type_map{ {NPY_BOOL, TypeMeta::Make()}, diff --git a/caffe2/python/pybind_state.h b/caffe2/python/pybind_state.h index b8f9dbaf37194..6513f216a9bee 100644 --- a/caffe2/python/pybind_state.h +++ b/caffe2/python/pybind_state.h @@ -103,8 +103,8 @@ static_assert( "We make an assumption that int is always int32 for numpy " "type mapping."); -int CaffeToNumpyType(const TypeMeta& dtype); -const TypeMeta& NumpyTypeToCaffe(int numpy_type); +int CaffeToNumpyType(const TypeMeta dtype); +const TypeMeta NumpyTypeToCaffe(int numpy_type); class TensorFetcher : public BlobFetcherBase { public: @@ -114,7 +114,7 @@ class TensorFetcher : public BlobFetcherBase { // Checks whether the data with type `dtype` needs to be copied in the context // of `tensor` - bool NeedsCopy(const Tensor* tensor, const TypeMeta& dtype) const { + bool NeedsCopy(const Tensor* tensor, const TypeMeta dtype) const { #ifdef USE_NUMPY return tensor->GetDeviceType() != CPU || CaffeToNumpyType(dtype) == NPY_OBJECT; @@ -200,9 +200,9 @@ class TensorFeeder : public BlobFeederBase { auto g = MakeGuard([&]() { Py_XDECREF(array); }); const auto npy_type = PyArray_TYPE(array); - const TypeMeta& dtype = NumpyTypeToCaffe(npy_type); + const TypeMeta dtype = NumpyTypeToCaffe(npy_type); CAFFE_ENFORCE( - dtype.id() != TypeIdentifier::uninitialized(), + dtype != ScalarType::Undefined, "This numpy data type is not supported: ", PyArray_TYPE(array), "."); @@ -232,7 +232,6 @@ class TensorFeeder : public BlobFeederBase { for (int i = 0; i < tensor.numel(); ++i) { char* str; Py_ssize_t strSize; -#if PY_MAJOR_VERSION > 2 if (PyBytes_Check(input[i])) { CAFFE_ENFORCE( PyBytes_AsStringAndSize(input[i], &str, &strSize) != -1, @@ -246,11 +245,6 @@ class TensorFeeder : public BlobFeederBase { } else { CAFFE_THROW("Unsupported python object type passed into ndarray."); } -#else - CAFFE_ENFORCE( - PyBytes_AsStringAndSize(input[i], &str, &strSize) != -1, - "Unsupported python object type passed into ndarray."); -#endif // PY_MAJOR_VERSION > 2 outPtr[i] = std::string(str, strSize); } break; @@ -342,18 +336,12 @@ class PythonOpBase : public Operator { try { builder_call = loads(py::bytes(pickled)).cast(); } catch (const py::error_already_set& e) { -#if PY_MAJOR_VERSION >= 3 LOG(INFO) << "Cannot unpickle python operator: " << e.what(); LOG(INFO) << "Try latin1 encoding for python3 run"; // to use the `_a` literal for arguments using namespace pybind11::literals; builder_call = loads(py::bytes(pickled), "encoding"_a = "latin1") .template cast(); -#else - // for py2, simply re-throw the exception, as there is no encoding - // argument for pickle.loads - throw; -#endif } CAFFE_ENFORCE(builder_call); CAFFE_ENFORCE_EQ(py::len(builder_call), 3); diff --git a/caffe2/python/pybind_state_dlpack.cc b/caffe2/python/pybind_state_dlpack.cc index 7b1ec2b8e1418..a7204481224f1 100644 --- a/caffe2/python/pybind_state_dlpack.cc +++ b/caffe2/python/pybind_state_dlpack.cc @@ -14,7 +14,7 @@ const DLDeviceType* CaffeToDLDeviceType(int device_type) { return it == dl_device_type_map.end() ? nullptr : &it->second; } -const DLDataType* CaffeToDLType(const TypeMeta& meta) { +const DLDataType* CaffeToDLType(const TypeMeta meta) { static std::map dl_type_map{ {TypeMeta::Id(), DLDataType{0, 8, 1}}, {TypeMeta::Id(), DLDataType{0, 16, 1}}, @@ -30,7 +30,7 @@ const DLDataType* CaffeToDLType(const TypeMeta& meta) { return it == dl_type_map.end() ? nullptr : &it->second; } -const TypeMeta& DLTypeToCaffe(const DLDataType& dl_type) { +const TypeMeta DLTypeToCaffe(const DLDataType& dl_type) { try { if (dl_type.lanes != 1) { throw std::invalid_argument("invalid type"); diff --git a/caffe2/python/pybind_state_dlpack.h b/caffe2/python/pybind_state_dlpack.h index 54f3157e7634d..bcdbc50a61d44 100644 --- a/caffe2/python/pybind_state_dlpack.h +++ b/caffe2/python/pybind_state_dlpack.h @@ -16,9 +16,9 @@ namespace py = pybind11; const DLDeviceType* CaffeToDLDeviceType(int device_type); -const DLDataType* CaffeToDLType(const TypeMeta& meta); +const DLDataType* CaffeToDLType(const TypeMeta meta); -const TypeMeta& DLTypeToCaffe(const DLDataType& dl_type); +const TypeMeta DLTypeToCaffe(const DLDataType& dl_type); // TODO: remove context template @@ -40,7 +40,7 @@ class DLPackWrapper { if (tensor->numel() <= 0) { tensor->Resize(0); } - if (tensor->dtype().id() == TypeIdentifier::uninitialized()) { + if (tensor->dtype() == ScalarType::Undefined) { // treat uninitialized tensor as float tensor tensor->template mutable_data(); } diff --git a/caffe2/python/pybind_state_ideep.cc b/caffe2/python/pybind_state_ideep.cc index 8d09b0aaa326e..bbeaf524f055d 100644 --- a/caffe2/python/pybind_state_ideep.cc +++ b/caffe2/python/pybind_state_ideep.cc @@ -97,7 +97,7 @@ class IDeepFetcher : public BlobFetcherBase { }; class IDeepFeeder : public BlobFeederBase { - itensor::data_type type_transform(const TypeMeta &meta) { + itensor::data_type type_transform(const TypeMeta meta) { if (meta == TypeMeta::Make()) return itensor::data_type::f32; else if (meta == TypeMeta::Make()) @@ -119,10 +119,10 @@ class IDeepFeeder : public BlobFeederBase { PyArrayObject *array = PyArray_GETCONTIGUOUS(original_array); auto g = MakeGuard([&]() { Py_XDECREF(array); }); const auto npy_type = PyArray_TYPE(array); - const TypeMeta &meta = NumpyTypeToCaffe(npy_type); + const TypeMeta meta = NumpyTypeToCaffe(npy_type); CAFFE_ENFORCE_NE( - meta.id(), - TypeIdentifier::uninitialized(), + meta, + ScalarType::Undefined, "This numpy data type is not supported: ", PyArray_TYPE(array), "."); @@ -172,7 +172,7 @@ class IDeepFeeder : public BlobFeederBase { auto g = MakeGuard([&]() { Py_XDECREF(array); }); const auto npy_type = PyArray_TYPE(array); - const TypeMeta &meta = NumpyTypeToCaffe(npy_type); + const TypeMeta meta = NumpyTypeToCaffe(npy_type); // TODO: if necessary, use dispatcher. if ((in_place && blob->IsType()) diff --git a/caffe2/python/python_op_test.py b/caffe2/python/python_op_test.py index 893671b96f45e..4b39adc3f36a5 100644 --- a/caffe2/python/python_op_test.py +++ b/caffe2/python/python_op_test.py @@ -8,7 +8,6 @@ from hypothesis import given, settings import hypothesis.strategies as st import numpy as np -import six class CustomError(Exception): @@ -55,12 +54,12 @@ def f(inputs, _): def test_exception(self): op = CreatePythonOperator(MainOpFunctionThatThrowsCustomError, [], []) - with six.assertRaisesRegex(self, CustomError, "This is an intentional exception."): + with self.assertRaisesRegex(CustomError, "This is an intentional exception."): workspace.RunOperatorOnce(op) def test_exception_builder(self): op = CreatePythonOperator(MainOpFunctionThatThrowsCustomErrorInBuilder, [], []) - with six.assertRaisesRegex(self, CustomError, "This is an intentional exception in builder."): + with self.assertRaisesRegex(CustomError, "This is an intentional exception in builder."): workspace.RunOperatorOnce(op) @given(x=hu.tensor()) diff --git a/caffe2/python/regularizer_context.py b/caffe2/python/regularizer_context.py index 5d79e138b6b79..27dc378189619 100644 --- a/caffe2/python/regularizer_context.py +++ b/caffe2/python/regularizer_context.py @@ -10,8 +10,7 @@ ModifierContext, UseModifierBase) -@context.define_context(allow_default=True) -class RegularizerContext(ModifierContext): +class RegularizerContext(ModifierContext, context.DefaultManaged): """ provide context to allow param_info to have different regularizers """ diff --git a/caffe2/python/rnn/lstm_comparison.py b/caffe2/python/rnn/lstm_comparison.py index dee96413dbe54..34fddbc1a66e8 100644 --- a/caffe2/python/rnn/lstm_comparison.py +++ b/caffe2/python/rnn/lstm_comparison.py @@ -2,7 +2,6 @@ -from caffe2.proto import caffe2_pb2 from caffe2.python import workspace, core, lstm_benchmark, utils from copy import copy diff --git a/caffe2/python/rnn_cell.py b/caffe2/python/rnn_cell.py index e16bfaaf491e1..f6da5e126119d 100644 --- a/caffe2/python/rnn_cell.py +++ b/caffe2/python/rnn_cell.py @@ -7,11 +7,9 @@ import functools import inspect -import itertools import logging import numpy as np import random -import six from future.utils import viewkeys from caffe2.proto import caffe2_pb2 @@ -32,7 +30,7 @@ def _RectifyName(blob_reference_or_name): if blob_reference_or_name is None: return None - if isinstance(blob_reference_or_name, six.string_types): + if isinstance(blob_reference_or_name, str): return core.ScopedBlobReference(blob_reference_or_name) if not isinstance(blob_reference_or_name, core.BlobReference): raise Exception("Unknown blob reference type") @@ -42,7 +40,7 @@ def _RectifyName(blob_reference_or_name): def _RectifyNames(blob_references_or_names): if blob_references_or_names is None: return None - return list(map(_RectifyName, blob_references_or_names)) + return [_RectifyName(i) for i in blob_references_or_names] class RNNCell(object): @@ -236,7 +234,7 @@ def get_state_names(self): ''' Returns recurrent state names with self.name scoping applied ''' - return list(map(self.scope, self.get_state_names_override())) + return [self.scope(name) for name in self.get_state_names_override()] def get_state_names_override(self): ''' diff --git a/caffe2/python/schema.py b/caffe2/python/schema.py index fb7cadf42847c..4f72a8cc1ffb8 100644 --- a/caffe2/python/schema.py +++ b/caffe2/python/schema.py @@ -30,7 +30,6 @@ from six import StringIO logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) FIELD_SEPARATOR = ':' @@ -98,6 +97,8 @@ class Field(object): """Represents an abstract field type in a dataset. """ + __slots__ = ("_parent", "_field_offsets") + def __init__(self, children): """Derived classes must call this after their initialization.""" self._parent = (None, 0) @@ -204,6 +205,8 @@ class List(Field): the parent domain. """ + __slots__ = ("lengths", "_items") + def __init__(self, values, lengths_blob=None): if isinstance(lengths_blob, Field): assert isinstance(lengths_blob, Scalar) @@ -213,7 +216,7 @@ def __init__(self, values, lengths_blob=None): self._items = _normalize_field(values) self.lengths._set_parent(self, 0) self._items._set_parent(self, 1) - Field.__init__(self, [self.lengths, self._items]) + super(List, self).__init__([self.lengths, self._items]) def field_names(self): value_fields = self._items.field_names() @@ -281,13 +284,16 @@ class ListWithEvicted(List): This class is similar with List, but containing extra field evicted_values for LRU Hashing. """ + + __slots__ = ("_evicted_values",) + def __init__(self, values, lengths_blob=None, evicted_values=None): if isinstance(evicted_values, Field): assert isinstance(evicted_values, Scalar) self._evicted_values = _normalize_field(evicted_values) else: self._evicted_values = Scalar(np.int64, evicted_values) - List.__init__(self, values, lengths_blob=lengths_blob) + super(ListWithEvicted, self).__init__(values, lengths_blob=lengths_blob) def field_names(self): value_fields = self._items.field_names() @@ -323,8 +329,8 @@ def _pprint_impl(self, indent, str_buffer): self.lengths._pprint_impl(indent=indent + 2, str_buffer=str_buffer) str_buffer.write(' ' * (indent + 1) + "_items=\n") self._items._pprint_impl(indent=indent + 2, str_buffer=str_buffer) - str_buffer.write(' ' * (indent + 1) + "_evicted_Values=\n") - self._items._pprint_impl(indent=indent + 2, str_buffer=str_buffer) + str_buffer.write(' ' * (indent + 1) + "_evicted_values=\n") + self._evicted_values._pprint_impl(indent=indent + 2, str_buffer=str_buffer) str_buffer.write(' ' * indent + ")\n") @@ -362,6 +368,8 @@ class Struct(Field): """Represents a named list of fields sharing the same domain. """ + __slots__ = ("fields", "_frozen") + def __init__(self, *fields): """ fields is a list of tuples in format of (name, field). The name is a string of nested name, e.g., `a`, `a:b`, `a:b:c`. For example @@ -408,7 +416,7 @@ def __init__(self, *fields): self.fields[name] = self.fields[name] + field for id, (_, field) in enumerate(viewitems(self.fields)): field._set_parent(self, id) - Field.__init__(self, viewvalues(self.fields)) + super(Struct, self).__init__(viewvalues(self.fields)) self._frozen = True def _struct_from_nested_name(self, nested_name, field): @@ -534,7 +542,7 @@ def __getattr__(self, item): if item.startswith('__'): raise AttributeError(item) try: - return self.__dict__['fields'][item] + return super(Struct, self).__getattribute__("fields")[item] except KeyError: raise AttributeError(item) @@ -710,10 +718,12 @@ class Scalar(Field): a conversion to numpy.ndarray is attempted. """ + __slots__ = ("_metadata", "dtype", "_original_dtype", "_blob") + def __init__(self, dtype=None, blob=None, metadata=None): self._metadata = None self.set(dtype, blob, metadata, unsafe=True) - Field.__init__(self, []) + super(Scalar, self).__init__([]) def field_names(self): return [''] @@ -970,6 +980,8 @@ def from_dtype(dtype, _outer_shape=()): class _SchemaNode(object): """This is a private class used to represent a Schema Node""" + __slots__ = ("name", "children", "type_str", "field") + def __init__(self, name, type_str=''): self.name = name self.children = [] diff --git a/caffe2/python/schema_test.py b/caffe2/python/schema_test.py index dca19a127ef2f..bb9536e4430b3 100644 --- a/caffe2/python/schema_test.py +++ b/caffe2/python/schema_test.py @@ -10,6 +10,31 @@ import pickle import random +class TestField(unittest.TestCase): + def testInitShouldSetEmptyParent(self): + f = schema.Field([]) + self.assertTupleEqual(f._parent, (None, 0)) + + def testInitShouldSetFieldOffsets(self): + f = schema.Field([ + schema.Scalar(dtype=np.int32), + schema.Struct( + ('field1', schema.Scalar(dtype=np.int32)), + ('field2', schema.List(schema.Scalar(dtype=str))), + ), + schema.Scalar(dtype=np.int32), + schema.Struct( + ('field3', schema.Scalar(dtype=np.int32)), + ('field4', schema.List(schema.Scalar(dtype=str))) + ), + schema.Scalar(dtype=np.int32), + ]) + self.assertListEqual(f._field_offsets, [0, 1, 4, 5, 8, 9]) + + def testInitShouldSetFieldOffsetsIfNoChildren(self): + f = schema.Field([]) + self.assertListEqual(f._field_offsets, [0]) + class TestDB(unittest.TestCase): def testPicklable(self): diff --git a/caffe2/python/scope_test.py b/caffe2/python/scope_test.py index 9bd69eb329026..bf3c8e9a0d06d 100644 --- a/caffe2/python/scope_test.py +++ b/caffe2/python/scope_test.py @@ -4,7 +4,6 @@ from caffe2.python import scope, core, workspace -from caffe2.proto import caffe2_pb2 import unittest import threading diff --git a/caffe2/python/session.py b/caffe2/python/session.py index de3b09931a302..fb2b57c4f5eec 100644 --- a/caffe2/python/session.py +++ b/caffe2/python/session.py @@ -192,7 +192,7 @@ def _compile_task_group(cls, task_group, setup_net_list=None): task = task_group.to_task() plan = core.Plan('task_group_plan') plan.AddStep(task.get_step()) - return (plan, task.output_list(), task.workspace_type) + return (plan, task.output_list(), task.workspace_type()) def _run_compiled(self, compiled): plan, output_list, workspace_type = compiled diff --git a/caffe2/python/task.py b/caffe2/python/task.py index f1b25ee260922..332dec0d16c41 100644 --- a/caffe2/python/task.py +++ b/caffe2/python/task.py @@ -1,10 +1,6 @@ ## @package task # Module caffe2.python.task - - - - from caffe2.python import core, context from caffe2.python.schema import Field, from_blob_list from collections import defaultdict @@ -23,8 +19,7 @@ def _merge_node_kwargs(a, b): return c -@context.define_context(allow_default=True) -class Cluster(object): +class Cluster(context.DefaultManaged): """ Context that keeps track of all the node names used. Users shouldn't have to use them directly, since a Cluster is automatically @@ -57,8 +52,7 @@ def __repr__(self): self.nodes(), self.node_kwargs()) -@context.define_context(allow_default=True) -class Node(object): +class Node(context.DefaultManaged): """ A Node context is used to indicate that all Tasks instantiated within will run on the given node name. (Only the name of the node actually counts.) @@ -162,8 +156,7 @@ def add_setup_steps(step, init_nets, exit_nets, name): return core.execution_step(name, steps) -@context.define_context(allow_default=False) -class TaskGroup(object): +class TaskGroup(context.Managed): """ Context that gathers tasks which will run concurrently, potentially on multiple nodes. All tasks in the same node will share the same workspace @@ -354,7 +347,9 @@ def workspace_type(self): def __repr__(self): return "TaskGroup(tasks={}, workspace_type={}, remote_nets={})".format( - self.tasks(), self.workspace_type(), self.remote_nets()) + self._tasks + self._tasks_to_add, + self.workspace_type(), + self.remote_nets()) class TaskOutput(object): @@ -442,8 +437,7 @@ def __repr__(self): return "TaskOutputList(outputs={})".format(self.outputs) -@context.define_context() -class Task(object): +class Task(context.Managed): """ A Task is composed of an execution step and zero or more outputs. Tasks are executed in the context of a TaskGroup, which, in turn, can @@ -542,6 +536,8 @@ def __init__( self._num_instances = num_instances def __enter__(self): + super(Task, self).__enter__() + # temporarily remove from _tasks_to_add to ensure correct order if self.group is not None: self.group._tasks_to_add.remove(self) @@ -553,6 +549,8 @@ def __enter__(self): return self def __exit__(self, type, value, traceback): + super(Task, self).__exit__(type, value, traceback) + self._net_builder.__exit__(type, value, traceback) if type is None: self.set_step(self._net_builder) diff --git a/caffe2/python/task_test.py b/caffe2/python/task_test.py index c44e93a3704ca..31adb41a0ac99 100644 --- a/caffe2/python/task_test.py +++ b/caffe2/python/task_test.py @@ -1,8 +1,3 @@ - - - - - import unittest from caffe2.python import task @@ -22,3 +17,8 @@ def testRepr(self): ] for obj, want in cases: self.assertEqual(obj.__repr__(), want) + + def testEffectlessRepr(self): + task_group = task.TaskGroup() + _repr = task_group.__repr__() + self.assertFalse(task_group._already_used) diff --git a/caffe2/python/test/executor_test_util.py b/caffe2/python/test/executor_test_util.py index ba10247eaa2ef..abf63626a7fad 100644 --- a/caffe2/python/test/executor_test_util.py +++ b/caffe2/python/test/executor_test_util.py @@ -14,7 +14,6 @@ import time import numpy as np -from hypothesis import settings CI_MAX_EXAMPLES = 2 diff --git a/caffe2/python/test/inference_lstm_op_test.py b/caffe2/python/test/inference_lstm_op_test.py index 20caab9ba78b2..768827bd8876c 100644 --- a/caffe2/python/test/inference_lstm_op_test.py +++ b/caffe2/python/test/inference_lstm_op_test.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 -import inspect import hypothesis.strategies as st import numpy as np import torch -from caffe2.python import core, workspace +from caffe2.python import core from caffe2.python.test_util import TestCase from hypothesis import given, settings from torch import nn diff --git a/caffe2/python/test/python_protobuf_test.py b/caffe2/python/test/python_protobuf_test.py index 7790e0f6d8f5c..a407f33fe2537 100644 --- a/caffe2/python/test/python_protobuf_test.py +++ b/caffe2/python/test/python_protobuf_test.py @@ -5,9 +5,6 @@ # make sure we use cpp implementation of protobuf import os os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp" - -# import cpp extension first -from caffe2.python import core # then import protobuf from caffe2.proto import caffe2_pb2, metanet_pb2 diff --git a/caffe2/python/trt/test_pt_onnx_trt.py b/caffe2/python/trt/test_pt_onnx_trt.py index 96f1ad76f6b73..5e6abb5c4d0b4 100644 --- a/caffe2/python/trt/test_pt_onnx_trt.py +++ b/caffe2/python/trt/test_pt_onnx_trt.py @@ -15,17 +15,13 @@ import os import unittest -from typing import List, Any from PIL import Image import numpy as np import torch -from torch.onnx import OperatorExportTypes import torchvision.models as models import pycuda.driver as cuda -# This import causes pycuda to automatically manage CUDA context creation and cleanup. -import pycuda.autoinit import tensorrt as trt TRT_LOGGER = trt.Logger(trt.Logger.WARNING) diff --git a/caffe2/python/trt/test_trt.py b/caffe2/python/trt/test_trt.py index 39d37ca9fa0a9..2782cca7c13ff 100644 --- a/caffe2/python/trt/test_trt.py +++ b/caffe2/python/trt/test_trt.py @@ -7,7 +7,7 @@ from caffe2.python import core, workspace import onnx import onnx.defs -from onnx.helper import make_node, make_graph, make_tensor, make_tensor_value_info, make_model +from onnx.helper import make_node, make_graph, make_tensor_value_info, make_model from onnx.backend.base import namedtupledict from caffe2.python.models.download import ModelDownloader import caffe2.python.onnx.backend as c2 @@ -16,7 +16,6 @@ from caffe2.python.onnx.tests.test_utils import TestCase import numpy as np import os.path -import json import time import unittest import tarfile diff --git a/caffe2/python/trt/transform.py b/caffe2/python/trt/transform.py index 0936941aac039..1b201007daab8 100644 --- a/caffe2/python/trt/transform.py +++ b/caffe2/python/trt/transform.py @@ -12,9 +12,7 @@ from caffe2.proto import caffe2_pb2 -from caffe2.python.onnx.helper import c2_native_run_net, c2_native_run_op -from caffe2.python import core, workspace -import caffe2.python.onnx.frontend as c2_front +from caffe2.python import workspace import caffe2.python._import_c_extension as C import numpy as np diff --git a/caffe2/python/utils.py b/caffe2/python/utils.py index 947dd9bf296d7..289d107303fab 100644 --- a/caffe2/python/utils.py +++ b/caffe2/python/utils.py @@ -6,12 +6,12 @@ from caffe2.proto import caffe2_pb2 -from caffe2.python.compatibility import container_abcs from future.utils import viewitems from google.protobuf.message import DecodeError, Message from google.protobuf import text_format import sys +import collections import copy import functools import numpy as np @@ -126,7 +126,7 @@ def MakeArgument(key, value): """Makes an argument based on the value type.""" argument = caffe2_pb2.Argument() argument.name = key - iterable = isinstance(value, container_abcs.Iterable) + iterable = isinstance(value, collections.abc.Iterable) # Fast tracking common use case where a float32 array of tensor parameters # needs to be serialized. The entire array is guaranteed to have the same diff --git a/caffe2/python/workspace.py b/caffe2/python/workspace.py index 99983e84f0976..0aa46ee2d4b35 100644 --- a/caffe2/python/workspace.py +++ b/caffe2/python/workspace.py @@ -335,7 +335,7 @@ def StringifyNetName(name): def GetNetName(net): if isinstance(net, basestring): return net - if type(net).__name__ == "Net": + if type(net).__name__ == "Net" or type(net).__name__ == "NetWithShapeInference": return net.Name() if isinstance(net, caffe2_pb2.NetDef): return net.name diff --git a/caffe2/quantization/server/compute_equalization_scale.cc b/caffe2/quantization/server/compute_equalization_scale.cc index 6e2f73ebd8402..2ed6b4c954e61 100644 --- a/caffe2/quantization/server/compute_equalization_scale.cc +++ b/caffe2/quantization/server/compute_equalization_scale.cc @@ -17,11 +17,9 @@ bool ComputeEqualizationScaleOp::RunOnDevice() { const int64_t N = W.size_to_dim(1); const int64_t K = W.size_from_dim(1); auto* S = Output(0, K, at::dtype()); - auto* S_INV = Output(1, K, at::dtype()); const float* X_data = X.template data(); const float* W_data = W.template data(); float* S_data = S->template mutable_data(); - float* S_INV_data = S_INV->template mutable_data(); float WcolMax, XcolMax; for (int64_t j = 0; j < K; j++) { @@ -38,10 +36,8 @@ bool ComputeEqualizationScaleOp::RunOnDevice() { } if (WcolMax == 0 || XcolMax == 0) { S_data[j] = 1; - S_INV_data[j] = 1; } else { S_data[j] = std::sqrt(WcolMax / XcolMax); - S_INV_data[j] = 1 / S_data[j]; } } return true; @@ -50,10 +46,10 @@ bool ComputeEqualizationScaleOp::RunOnDevice() { REGISTER_CPU_OPERATOR(ComputeEqualizationScale, ComputeEqualizationScaleOp); OPERATOR_SCHEMA(ComputeEqualizationScale) .NumInputs(2) - .NumOutputs(2) + .NumOutputs(1) .SetDoc(R"DOC( Given a weight matrix W and input matrix X, the output S is the equalization parameter -vector computed from W and X, and S_INV = 1 / S +vector computed from W and X S is computed by: S[j] = max(abs(W[][j])) == 0 || max(abs(X[][j])) == 0 ? 1 : @@ -62,19 +58,12 @@ S[j] = max(abs(W[][j])) == 0 || max(abs(X[][j])) == 0 ? 1 : )DOC") .TensorInferenceFunction([](const OperatorDef& /* def */, const vector& in) { - vector out(2); - - if (in[0].unknown_shape() || in[1].unknown_shape()) { - out[0].set_unknown_shape(true); - out[1].set_unknown_shape(true); - return out; - } + vector out(1); const int64_t K = size_from_dim_(1, GetDimsVector(in[1])); vector s_shape(2); s_shape[0] = 1; s_shape[1] = K; out[0] = CreateTensorShape(s_shape, TensorProto_DataType_FLOAT); - out[1] = CreateTensorShape(s_shape, TensorProto_DataType_FLOAT); return out; }) .Input( @@ -86,10 +75,6 @@ S[j] = max(abs(W[][j])) == 0 || max(abs(X[][j])) == 0 ? 1 : 0, "S", "Scale computed that will be multiplied to the columns of input.") - .Output( - 1, - "S_INV", - "Scale inverse that will be multiplied to the columns of weight.") .SetDoc( R"DOC(Operator to compute equalization scale given the input data and weight)DOC"); diff --git a/caffe2/quantization/server/compute_equalization_scale_test.py b/caffe2/quantization/server/compute_equalization_scale_test.py index 74d34c5502d3f..bad21d7cafd7e 100644 --- a/caffe2/quantization/server/compute_equalization_scale_test.py +++ b/caffe2/quantization/server/compute_equalization_scale_test.py @@ -38,19 +38,17 @@ def test_compute_equalization_scale(self, m, n, k, rnd_seed, gc, dc): def ref_compute_equalization_scale(X, W): S = np.ones([X.shape[1]]) - S_INV = np.ones([X.shape[1]]) for j in range(W.shape[1]): WcolMax = np.absolute(W[:, j]).max() XcolMax = np.absolute(X[:, j]).max() if WcolMax and XcolMax: S[j] = np.sqrt(WcolMax / XcolMax) - S_INV[j] = 1 / S[j] - return S, S_INV + return S net = core.Net("test") ComputeEqualizationScaleOp = core.CreateOperator( - "ComputeEqualizationScale", ["X", "W"], ["S", "S_INV"] + "ComputeEqualizationScale", ["X", "W"], ["S"] ) net.Proto().op.extend([ComputeEqualizationScaleOp]) @@ -59,16 +57,14 @@ def ref_compute_equalization_scale(X, W): self.ws.run(net) S = self.ws.blobs["S"].fetch() - S_INV = self.ws.blobs["S_INV"].fetch() - S_ref, S_INV_ref = ref_compute_equalization_scale(X, W) + S_ref = ref_compute_equalization_scale(X, W) np.testing.assert_allclose(S, S_ref, atol=1e-3, rtol=1e-3) - np.testing.assert_allclose(S_INV, S_INV_ref, atol=1e-3, rtol=1e-3) def test_compute_equalization_scale_shape_inference(self): X = np.array([[1, 2], [2, 4], [6, 7]]).astype(np.float32) W = np.array([[2, 3], [5, 4], [8, 2]]).astype(np.float32) ComputeEqualizationScaleOp = core.CreateOperator( - "ComputeEqualizationScale", ["X", "W"], ["S", "S_INV"] + "ComputeEqualizationScale", ["X", "W"], ["S"] ) workspace.FeedBlob("X", X) workspace.FeedBlob("W", W) @@ -81,9 +77,7 @@ def test_compute_equalization_scale_shape_inference(self): blob_types={"X": core.DataType.FLOAT, "W": core.DataType.FLOAT}, ) assert ( - "S" in shapes and "S" in types and "S_INV" in shapes and "S_INV" in types + "S" in shapes and "S" in types ), "Failed to infer the shape or type of output" self.assertEqual(shapes["S"], [1, 2]) - self.assertEqual(shapes["S_INV"], [1, 2]) self.assertEqual(types["S"], core.DataType.FLOAT) - self.assertEqual(types["S_INV"], core.DataType.FLOAT) diff --git a/caffe2/quantization/server/elementwise_sum_relu_op.cc b/caffe2/quantization/server/elementwise_sum_relu_op.cc index df4b726c73060..dbb14c0c5ce84 100644 --- a/caffe2/quantization/server/elementwise_sum_relu_op.cc +++ b/caffe2/quantization/server/elementwise_sum_relu_op.cc @@ -42,11 +42,13 @@ class SumReluOp : public SumOp { bool RunOnDevice() override { if (Input(0).template IsType()) { return DoRunWithType(); + } else if (Input(0).template IsType()) { + return DoRunWithType(); } else if (Input(0).template IsType()) { return DoRunWithType(); } else { CAFFE_THROW( - "Sum operator only supports 32-bit float and ints, but", + "Sum operator only supports 32-bit float, 64-bit double and ints, but", " input was of type ", Input(0).dtype().name()); } diff --git a/caffe2/quantization/server/fbgemm_pack_op.cc b/caffe2/quantization/server/fbgemm_pack_op.cc index e05b93bad1696..25e396cf8919b 100644 --- a/caffe2/quantization/server/fbgemm_pack_op.cc +++ b/caffe2/quantization/server/fbgemm_pack_op.cc @@ -211,7 +211,7 @@ void QuantizeConvBias( if (use_fp16) { bdata_local.resize(bias.numel()); fbgemm::RoundToFloat16( - bdata, bdata_local.data(), bias.numel(), 1 /* FLAGS_caffe2_fbgemm_fake_fp16_clamp */); + bdata, bdata_local.data(), bias.numel(), false /* FLAGS_caffe2_fbgemm_fake_fp16_clamp */); bdata = bdata_local.data(); } b_quantized.resize(bias.numel()); diff --git a/caffe2/quantization/server/fully_connected_dnnlowp_op.cc b/caffe2/quantization/server/fully_connected_dnnlowp_op.cc index c7e6804c1dcfe..4a5a6e6b7ad0f 100644 --- a/caffe2/quantization/server/fully_connected_dnnlowp_op.cc +++ b/caffe2/quantization/server/fully_connected_dnnlowp_op.cc @@ -190,6 +190,9 @@ bool FullyConnectedDNNLowPOp::RunOnDevice() { if (!dequantize_output_) { Y_int32_.resize(Y->size()); + if (Y_int32_.size() < Y_int32_.capacity() / 2) { + Y_int32_.shrink_to_fit(); + } DoNothing<> doNothingObj{}; if (quantize_channelwise_ || filter_qparams_[0].zero_point) { @@ -443,6 +446,9 @@ bool FullyConnectedDNNLowPOp::RunOnDevice() { #endif Y_int32_.resize(Y->size()); + if (Y_int32_.size() < Y_int32_.capacity() / 2) { + Y_int32_.shrink_to_fit(); + } for (int i = 0; i < M; ++i) { for (int j = 0; j < N; ++j) { int32_t sum = 0; diff --git a/caffe2/quantization/server/fully_connected_dnnlowp_op_test.py b/caffe2/quantization/server/fully_connected_dnnlowp_op_test.py index f1939e198b84e..3a8b0c14931e5 100644 --- a/caffe2/quantization/server/fully_connected_dnnlowp_op_test.py +++ b/caffe2/quantization/server/fully_connected_dnnlowp_op_test.py @@ -6,7 +6,7 @@ import hypothesis.strategies as st import numpy as np from caffe2.python import core, dyndep, workspace -from caffe2.quantization.server import dnnlowp_pybind11, utils as dnnlowp_utils +from caffe2.quantization.server import utils as dnnlowp_utils from caffe2.quantization.server.dnnlowp_test_utils import ( avoid_vpmaddubsw_overflow_fc, check_quantized_results_close, diff --git a/caffe2/quantization/server/int8_gen_quant_params_min_max.cc b/caffe2/quantization/server/int8_gen_quant_params_min_max.cc new file mode 100644 index 0000000000000..76a2bb747242c --- /dev/null +++ b/caffe2/quantization/server/int8_gen_quant_params_min_max.cc @@ -0,0 +1,37 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#include "caffe2/quantization/server/int8_gen_quant_params_min_max.h" +#include +#include "caffe2/quantization/server/int8_gen_quant_params.h" + +namespace caffe2 { +using namespace std; +using namespace dnnlowp; + +REGISTER_CPU_OPERATOR( + Int8GenQuantParamsMinMax, + Int8GenQuantParamsMinMaxOp); +OPERATOR_SCHEMA(Int8GenQuantParamsMinMax) + .NumInputs(2, 3) + .NumOutputs(1) + .TensorInferenceFunction([](const OperatorDef& /* def */, + const vector& /* in */) { + vector out(1); + out[0].set_data_type(TensorProto_DataType_FLOAT); + out[0].add_dims(1); + return out; + }) + .Input(0, "min", "The lower bound of the tensor to be quantized.") + .Input(1, "max", "The upper bound of the tensor to be quantized.") + .Input( + 2, + "quant_scheme", + "(Optional) Int8QuantSchemeBlob that specifies the quantization kind and preserve_sparsity options when generating the quant params. We only use preserve_sparsity in this op which is default to be false.") + .Output( + 0, + "quant_param", + "Int8QuantParamsBlob that contains the scale and zero_point info in TensorQuantizationParams type.") + .SetDoc( + R"DOC(Operator wrapper for generating int8 tensor quantization parameters given lower and upper bound of the input tensor)DOC"); + +} // namespace caffe2 diff --git a/caffe2/quantization/server/int8_gen_quant_params_min_max.h b/caffe2/quantization/server/int8_gen_quant_params_min_max.h new file mode 100644 index 0000000000000..ada6a46a8dece --- /dev/null +++ b/caffe2/quantization/server/int8_gen_quant_params_min_max.h @@ -0,0 +1,50 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#pragma once +#include "caffe2/quantization/server/caffe2_dnnlowp_utils.h" +#include "caffe2/quantization/server/dnnlowp.h" +#include "caffe2/quantization/server/int8_gen_quant_params.h" +#include + + +namespace caffe2 { +using namespace std; +using dnnlowp::TensorQuantizationParams; + +template +class Int8GenQuantParamsMinMaxOp final : public Operator { + public: + USE_OPERATOR_CONTEXT_FUNCTIONS; + Int8GenQuantParamsMinMaxOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws) {} + bool RunOnDevice() override { + // Generate Int8 quant params based on the input data (last N samples of the + // activations) and the quant scheme + const float min = + OperatorBase::Input(0, CPU).template data()[0]; + const float max = + OperatorBase::Input(1, CPU).template data()[0]; + bool preserve_sparsity = false; + if (InputSize() == 3){ + const auto* quant_scheme = + this->template Input>(2).get(); + preserve_sparsity = quant_scheme->preserve_sparsity_; + } + dnnlowp::QuantizationFactory* qfactory = + dnnlowp::QuantizationFactory::GetDefaultInstance(); + TensorQuantizationParams qparam = qfactory->ChooseQuantizationParams( + min, + max, + 8, + preserve_sparsity); + auto* output_qparam = + this->template Output>(0); + output_qparam->reset( + new Int8QuantParamsBlob(qparam.scale, qparam.zero_point)); + LOG_EVERY_N(INFO, 1) << "scale and bias are " << qparam.scale << "," << qparam.zero_point; + return true; + } + +}; // class Int8GenQuantParamsOp + +} // namespace caffe2 diff --git a/caffe2/quantization/server/int8_gen_quant_params_min_max_test.py b/caffe2/quantization/server/int8_gen_quant_params_min_max_test.py new file mode 100644 index 0000000000000..dd27074db5c41 --- /dev/null +++ b/caffe2/quantization/server/int8_gen_quant_params_min_max_test.py @@ -0,0 +1,83 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + + + +import caffe2.python.hypothesis_test_util as hu +import hypothesis.strategies as st +import numpy as np +from caffe2.python import core, workspace +from caffe2.quantization.server import dnnlowp_pybind11 +from hypothesis import given, settings + + +class TestInt8GenQuantParamsMinMaxOperator(hu.HypothesisTestCase): + @settings(max_examples=20, deadline=None) + @given( + n=st.integers(10, 10), + m=st.integers(10, 10), + preserve_sparsity=st.booleans(), + rnd_seed=st.integers(1, 5), + **hu.gcs_cpu_only + ) + def test_int8_gen_quant_params_min_max_op( + self, n, m, preserve_sparsity, rnd_seed, gc, dc + ): + X_min = 0 if preserve_sparsity else -77 + X_max = X_min + 255 + np.random.seed(rnd_seed) + X = np.round(np.random.rand(n, m) * (X_max - X_min) + X_min).astype( + np.float32 + ) + # Calculate X_qparam + hist, bin_edges = np.histogram(X.flatten(), bins=2048) + X_qparam = dnnlowp_pybind11.ChooseStaticQuantizationParams( + np.min(X), np.max(X), hist, preserve_sparsity, 8, "MIN_MAX_QUANTIZATION" + ) + + # Build a net to generate X's qparam using the Int8GenQuantParamsMinMax op + workspace.FeedBlob("X", X, device_option=gc) + workspace.FeedBlob("X_min", np.array([np.min(X)]), device_option=gc) + workspace.FeedBlob("X_max", np.array([np.max(X)]), device_option=gc) + dnnlowp_pybind11.CreateInt8QuantSchemeBlob( + "quant_scheme", "MIN_MAX_QUANTIZATION", preserve_sparsity + ) + assert workspace.HasBlob( + "quant_scheme" + ), "Failed to create the quant_scheme blob in current workspace" + + gen_quant_params_net = core.Net("gen_quant_params_min_max") + gen_quant_params_op = core.CreateOperator( + "Int8GenQuantParamsMinMax", + ["X_min", "X_max", "quant_scheme"], + ["quant_param"], + device_option=gc, + ) + gen_quant_params_net.Proto().op.extend([gen_quant_params_op]) + assert workspace.RunNetOnce( + gen_quant_params_net + ), "Failed to run the gen_quant_params net" + scale, zero_point = dnnlowp_pybind11.ObserveInt8QuantParamsBlob("quant_param") + + shapes, types = workspace.InferShapesAndTypes( + [gen_quant_params_net], + blob_dimensions={"X": [n, m], "X_min": [1], "X_max": [1], "quant_scheme": [1]}, + blob_types={"X": core.DataType.FLOAT, "X_min": core.DataType.FLOAT, "X_max": core.DataType.FLOAT, "quant_scheme": core.DataType.STRING} + ) + self.assertEqual(shapes["quant_param"], [1]) + self.assertEqual(types["quant_param"], core.DataType.FLOAT) + + np.testing.assert_equal(scale, X_qparam.scale) + np.testing.assert_equal(zero_point, X_qparam.zero_point) diff --git a/caffe2/quantization/server/norm_minimization.cc b/caffe2/quantization/server/norm_minimization.cc index 94e655e56da2e..a8d0d3da0dbe5 100644 --- a/caffe2/quantization/server/norm_minimization.cc +++ b/caffe2/quantization/server/norm_minimization.cc @@ -14,6 +14,10 @@ namespace dnnlowp { #undef NDEBUG +// Use fp16_min as the small scale cutoff because we don't want to use scales in fp16 subnormal range. +// This is to be consistent with Glow and FakeLowP implementation for NNPI. +constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; + static float GetNorm(float begin, float end, float density, NormMinimization::Kind kind) { float norm = 0; @@ -57,7 +61,8 @@ TensorQuantizationParams NormMinimization::NonlinearQuantizationParamsSearch( vector bins_f(dnnlowp::adjust_hist_to_include_zero(hist, &min, &max)); int nbins = bins_f.size(); float bin_width = (max - min) / nbins; - if (bin_width == 0) { + float scale = (max - min) / float((1 << precision) - 1); + if (bin_width == 0 || scale < SMALL_SCALE_THRESHOLD) { QuantizationFactory* qfactory = QuantizationFactory::GetDefaultInstance(); return qfactory->ChooseQuantizationParams( min, max, precision, preserve_sparsity); @@ -190,6 +195,12 @@ TensorQuantizationParams NormMinimization::ChooseQuantizationParams( int nbins = bins_f.size(); float bin_width = (max - min) / nbins; + float scale = (max - min) / float((1 << precision) - 1); + if (bin_width == 0 || scale < SMALL_SCALE_THRESHOLD) { + QuantizationFactory* qfactory = QuantizationFactory::GetDefaultInstance(); + return qfactory->ChooseQuantizationParams( + min, max, precision, preserve_sparsity); + } int dst_nbins = 1 << precision; int zero_bin = round(-min / bin_width); diff --git a/caffe2/quantization/server/quantize_dnnlowp_op.cc b/caffe2/quantization/server/quantize_dnnlowp_op.cc index 64cc6b83fb6e0..da047d0bfdf2c 100644 --- a/caffe2/quantization/server/quantize_dnnlowp_op.cc +++ b/caffe2/quantization/server/quantize_dnnlowp_op.cc @@ -1,10 +1,6 @@ #include "quantize_dnnlowp_op.h" #include "dnnlowp_op.h" -#ifdef _OPENMP -#include -#endif - #include "caffe2/core/tensor_int8.h" #include "caffe2/quantization/server/int8_gen_quant_params.h" #include "caffe2_dnnlowp_utils.h" @@ -67,16 +63,7 @@ bool QuantizeDNNLowPOp::RunOnDevice() { const float* in_data = Input(0).template data(); T* out_data = output->t.template mutable_data(); -#ifdef _OPENMP -#pragma omp parallel -#endif - { - int i_begin, i_end; - tie(i_begin, i_end) = Get1DPartition( - Input(0).numel(), dnnlowp_get_num_threads(), dnnlowp_get_thread_num()); - fbgemm::Quantize( - in_data + i_begin, out_data + i_begin, i_end - i_begin, in_qparams); - } + fbgemm::Quantize(in_data, out_data, Input(0).numel(), in_qparams); PropagateOutputTensorQuantizationParams(this, 0, in_qparams); diff --git a/caffe2/queue/blobs_queue.h b/caffe2/queue/blobs_queue.h index 5ad5c93513170..a60cc1570c44d 100644 --- a/caffe2/queue/blobs_queue.h +++ b/caffe2/queue/blobs_queue.h @@ -20,7 +20,7 @@ namespace caffe2 { // Containing blobs are owned by the workspace. // On read, we swap out the underlying data for the blob passed in for blobs -class CAFFE2_API BlobsQueue : public std::enable_shared_from_this { +class TORCH_API BlobsQueue : public std::enable_shared_from_this { public: BlobsQueue( Workspace* ws, diff --git a/caffe2/requirements.txt b/caffe2/requirements.txt index 7c0367da1d855..aa8d2be43aa5d 100644 --- a/caffe2/requirements.txt +++ b/caffe2/requirements.txt @@ -2,4 +2,3 @@ numpy enum34 pyyaml requests -typing diff --git a/caffe2/serialize/crc_alt.h b/caffe2/serialize/crc_alt.h index be51083fec0e7..e7c986ff89fb8 100644 --- a/caffe2/serialize/crc_alt.h +++ b/caffe2/serialize/crc_alt.h @@ -680,12 +680,12 @@ uint32_t crc32_combine(uint32_t crcA, uint32_t crcB, size_t lengthB) // put operator for one zero bit in odd odd[0] = Polynomial; // CRC-32 polynomial - for (int i = 1; i < CrcBits; i++) + for (uint32_t i = 1; i < CrcBits; i++) odd[i] = 1 << (i - 1); // put operator for two zero bits in even // same as gf2_matrix_square(even, odd); - for (int i = 0; i < CrcBits; i++) + for (uint32_t i = 0; i < CrcBits; i++) { uint32_t vec = odd[i]; even[i] = 0; @@ -695,7 +695,7 @@ uint32_t crc32_combine(uint32_t crcA, uint32_t crcB, size_t lengthB) } // put operator for four zero bits in odd // same as gf2_matrix_square(odd, even); - for (int i = 0; i < CrcBits; i++) + for (uint32_t i = 0; i < CrcBits; i++) { uint32_t vec = even[i]; odd[i] = 0; @@ -711,7 +711,7 @@ uint32_t crc32_combine(uint32_t crcA, uint32_t crcB, size_t lengthB) for (; lengthB > 0; lengthB >>= 1) { // same as gf2_matrix_square(a, b); - for (int i = 0; i < CrcBits; i++) + for (uint32_t i = 0; i < CrcBits; i++) { uint32_t vec = b[i]; a[i] = 0; diff --git a/caffe2/serialize/file_adapter.h b/caffe2/serialize/file_adapter.h index 416208ec05445..ee68b794967d2 100644 --- a/caffe2/serialize/file_adapter.h +++ b/caffe2/serialize/file_adapter.h @@ -10,7 +10,7 @@ namespace caffe2 { namespace serialize { -class CAFFE2_API FileAdapter final : public ReadAdapterInterface { +class TORCH_API FileAdapter final : public ReadAdapterInterface { public: C10_DISABLE_COPY_AND_ASSIGN(FileAdapter); explicit FileAdapter(const std::string& file_name); diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index 63f5a34aa23b1..3d9701274ba3d 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -65,7 +65,7 @@ PyTorchStreamReader::PyTorchStreamReader(std::istream* in) } PyTorchStreamReader::PyTorchStreamReader( - std::unique_ptr in) + std::shared_ptr in) : ar_(std::make_unique()), in_(std::move(in)) { init(); } @@ -306,6 +306,7 @@ void PyTorchStreamWriter::setup(const string& file_name) { file_name, std::ofstream::out | std::ofstream::trunc | std::ofstream::binary); valid("opening archive ", file_name.c_str()); + TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened."); writer_func_ = [this](const void* buf, size_t nbytes) -> size_t { file_stream_.write(static_cast(buf), nbytes); return !file_stream_ ? 0 : nbytes; diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index 2e841d0ad8248..87c3151bbb760 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -12,6 +12,7 @@ #include "caffe2/serialize/istream_adapter.h" #include "caffe2/serialize/read_adapter_interface.h" +#include "caffe2/serialize/versions.h" extern "C" { typedef struct mz_zip_archive mz_zip_archive; @@ -90,73 +91,11 @@ typedef struct mz_zip_archive mz_zip_archive; namespace caffe2 { namespace serialize { -constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L; -constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L; - -// Versions (i.e. why was the version number bumped?) - -// Note [Dynamic Versions and torch.jit.save vs. torch.save] -// -// Our versioning scheme has a "produced file format version" which -// describes how an archive is to be read. The version written in an archive -// is at least this current produced file format version, but may be greater -// if it includes certain symbols. We refer to these conditional versions -// as "dynamic," since they are identified at runtime. -// -// Dynamic versioning is useful when an operator's semantics are updated. -// When using torch.jit.save we want those semantics to be preserved. If -// we bumped the produced file format version on every change, however, -// then older versions of PyTorch couldn't read even simple archives, like -// a single tensor, from newer versions of PyTorch. Instead, we -// assign dynamic versions to these changes that override the -// produced file format version as needed. That is, when the semantics -// of torch.div changed it was assigned dynamic version 4, and when -// torch.jit.saving modules that use torch.div those archives also have -// (at least) version 4. This prevents earlier versions of PyTorch -// from accidentally performing the wrong kind of division. Modules -// that don't use torch.div or other operators with dynamic versions -// can write the produced file format version, and these programs will -// run as expected on earlier versions of PyTorch. -// -// While torch.jit.save attempts to preserve operator semantics, -// torch.save does not. torch.save is analogous to pickling Python, so -// a function that uses torch.div will have different behavior if torch.saved -// and torch.loaded across PyTorch versions. From a technical perspective, -// torch.save ignores dynamic versioning. - -// 1. Initial version -// 2. Removed op_version_set version numbers -// 3. Added type tags to pickle serialization of container types -// 4. (Dynamic) Stopped integer division using torch.div -// (a versioned symbol preserves the historic behavior of versions 1--3) -// 5. (Dynamic) Stops torch.full inferring a floating point dtype -// when given bool or integer fill values. -constexpr uint64_t kProducedFileFormatVersion = 0x3L; - -// the version we write when the archive contains bytecode. -// It must be higher or eq to kProducedFileFormatVersion. -// Because torchscript changes is likely introduce bytecode change. -// If kProducedFileFormatVersion is increased, kProducedBytecodeVersion -// should be increased too. The relationship is: -// kMaxSupportedFileFormatVersion >= (most likely ==) kProducedBytecodeVersion -// >= kProducedFileFormatVersion -constexpr uint64_t kProducedBytecodeVersion = 0x4L; - -static_assert(kProducedBytecodeVersion >= kProducedFileFormatVersion, - "kProducedBytecodeVersion must be higher or equal to kProducedFileFormatVersion."); - -// Introduce kMinSupportedBytecodeVersion for limited backward compatibility -// support of bytecode. If -// kMinSupportedBytecodeVersion <= model_version <= kProducedBytecodeVersion (in loader), -// we should support this model_version. For example, we provide a wrapper to -// handle an updated operator. -constexpr uint64_t kMinSupportedBytecodeVersion = 0x3L; - -class CAFFE2_API PyTorchStreamReader final { +class TORCH_API PyTorchStreamReader final { public: explicit PyTorchStreamReader(const std::string& file_name); explicit PyTorchStreamReader(std::istream* in); - explicit PyTorchStreamReader(std::unique_ptr in); + explicit PyTorchStreamReader(std::shared_ptr in); // return dataptr, size std::tuple getRecord(const std::string& name); @@ -180,11 +119,11 @@ class CAFFE2_API PyTorchStreamReader final { std::unique_ptr ar_; std::string archive_name_; std::string archive_name_plus_slash_; - std::unique_ptr in_; + std::shared_ptr in_; int64_t version_; }; -class CAFFE2_API PyTorchStreamWriter final { +class TORCH_API PyTorchStreamWriter final { public: explicit PyTorchStreamWriter(std::string archive_name); explicit PyTorchStreamWriter( diff --git a/caffe2/serialize/istream_adapter.h b/caffe2/serialize/istream_adapter.h index b7a0444e5f632..8960d5535c885 100644 --- a/caffe2/serialize/istream_adapter.h +++ b/caffe2/serialize/istream_adapter.h @@ -9,7 +9,7 @@ namespace caffe2 { namespace serialize { // this is a reader implemented by std::istream -class CAFFE2_API IStreamAdapter final : public ReadAdapterInterface { +class TORCH_API IStreamAdapter final : public ReadAdapterInterface { public: C10_DISABLE_COPY_AND_ASSIGN(IStreamAdapter); explicit IStreamAdapter(std::istream* istream); diff --git a/caffe2/serialize/read_adapter_interface.h b/caffe2/serialize/read_adapter_interface.h index 556c0051cfae5..0a6b5b74a762e 100644 --- a/caffe2/serialize/read_adapter_interface.h +++ b/caffe2/serialize/read_adapter_interface.h @@ -11,7 +11,7 @@ namespace serialize { // this is the interface for the (file/stream/memory) reader in // PyTorchStreamReader. with this interface, we can extend the support // besides standard istream -class CAFFE2_API ReadAdapterInterface { +class TORCH_API ReadAdapterInterface { public: virtual size_t size() const = 0; virtual size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") diff --git a/caffe2/serialize/versions.h b/caffe2/serialize/versions.h new file mode 100644 index 0000000000000..4da4b2c50305b --- /dev/null +++ b/caffe2/serialize/versions.h @@ -0,0 +1,68 @@ +#pragma once + +namespace caffe2 { +namespace serialize { + +constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L; +constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L; + +// Versions (i.e. why was the version number bumped?) + +// Note [Dynamic Versions and torch.jit.save vs. torch.save] +// +// Our versioning scheme has a "produced file format version" which +// describes how an archive is to be read. The version written in an archive +// is at least this current produced file format version, but may be greater +// if it includes certain symbols. We refer to these conditional versions +// as "dynamic," since they are identified at runtime. +// +// Dynamic versioning is useful when an operator's semantics are updated. +// When using torch.jit.save we want those semantics to be preserved. If +// we bumped the produced file format version on every change, however, +// then older versions of PyTorch couldn't read even simple archives, like +// a single tensor, from newer versions of PyTorch. Instead, we +// assign dynamic versions to these changes that override the +// produced file format version as needed. That is, when the semantics +// of torch.div changed it was assigned dynamic version 4, and when +// torch.jit.saving modules that use torch.div those archives also have +// (at least) version 4. This prevents earlier versions of PyTorch +// from accidentally performing the wrong kind of division. Modules +// that don't use torch.div or other operators with dynamic versions +// can write the produced file format version, and these programs will +// run as expected on earlier versions of PyTorch. +// +// While torch.jit.save attempts to preserve operator semantics, +// torch.save does not. torch.save is analogous to pickling Python, so +// a function that uses torch.div will have different behavior if torch.saved +// and torch.loaded across PyTorch versions. From a technical perspective, +// torch.save ignores dynamic versioning. + +// 1. Initial version +// 2. Removed op_version_set version numbers +// 3. Added type tags to pickle serialization of container types +// 4. (Dynamic) Stopped integer division using torch.div +// (a versioned symbol preserves the historic behavior of versions 1--3) +// 5. (Dynamic) Stops torch.full inferring a floating point dtype +// when given bool or integer fill values. +constexpr uint64_t kProducedFileFormatVersion = 0x3L; + +// the version we write when the archive contains bytecode. +// It must be higher or eq to kProducedFileFormatVersion. +// Because torchscript changes is likely introduce bytecode change. +// If kProducedFileFormatVersion is increased, kProducedBytecodeVersion +// should be increased too. The relationship is: +// kMaxSupportedFileFormatVersion >= (most likely ==) kProducedBytecodeVersion +// >= kProducedFileFormatVersion +constexpr uint64_t kProducedBytecodeVersion = 0x4L; + +static_assert(kProducedBytecodeVersion >= kProducedFileFormatVersion, + "kProducedBytecodeVersion must be higher or equal to kProducedFileFormatVersion."); + +// Introduce kMinSupportedBytecodeVersion for limited backward compatibility +// support of bytecode. If +// kMinSupportedBytecodeVersion <= model_version <= kProducedBytecodeVersion (in loader), +// we should support this model_version. For example, we provide a wrapper to +// handle an updated operator. +constexpr uint64_t kMinSupportedBytecodeVersion = 0x3L; +} // namespace serialize +} // namespace caffe2 diff --git a/caffe2/sgd/adadelta_op_gpu.cu b/caffe2/sgd/adadelta_op_gpu.cu index efccaf7ba63e6..92416b503bcd8 100644 --- a/caffe2/sgd/adadelta_op_gpu.cu +++ b/caffe2/sgd/adadelta_op_gpu.cu @@ -47,6 +47,7 @@ void AdadeltaUpdate( CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(N, w, g, h, d, epsilon, decay, lr, nw, nh, nd); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } // namespace @@ -154,6 +155,7 @@ class CUDASparseAdadeltaOp final : public Operator { paramOut, momentOut, momentDeltaOut); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } diff --git a/caffe2/sgd/adagrad_fused_op_gpu.cu b/caffe2/sgd/adagrad_fused_op_gpu.cu index 814a24c74183f..539b9919e8e7d 100644 --- a/caffe2/sgd/adagrad_fused_op_gpu.cu +++ b/caffe2/sgd/adagrad_fused_op_gpu.cu @@ -308,69 +308,132 @@ __global__ void rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel( const float LR = lr[0]; // num_indices blocks, each block process one index - int sorted_linear_indice_id = blockIdx.x; // the index of sorted_linear_ind + int sorted_linear_indice_id; + if (ExactBlock) { + sorted_linear_indice_id = + blockIdx.x * blockDim.y + threadIdx.y; // the index of sorted_linear_ind + } else { + sorted_linear_indice_id = blockIdx.x; // the index of sorted_linear_ind + } if (sorted_linear_indice_id >= num_indices) { // don't have warp divergence when embedding dim is multiple of 32 return; } + // the index row in the embedding table + SIndex index = sorted_linear_ind_data[sorted_linear_indice_id]; + // check if this thread block is responsible for this whole linear index bool linear_index_start = (sorted_linear_indice_id == 0 || - sorted_linear_ind_data[sorted_linear_indice_id - 1] != - sorted_linear_ind_data[sorted_linear_indice_id]); + sorted_linear_ind_data[sorted_linear_indice_id - 1] != index); if (!linear_index_start) { // don't have warp divergence when embedding dim is multiple of 32 return; } - // the index row in the embedding table - SIndex index = sorted_linear_ind_data[sorted_linear_indice_id]; - // find the num of duplicated indices. - int num_dup = 1; - while (sorted_linear_indice_id + num_dup < num_indices && - sorted_linear_ind_data[sorted_linear_indice_id + num_dup] == index) { - num_dup += 1; - } + if (ExactBlock) { + // find the num of duplicated indices. + int num_dup = 1; + while (true) { + int segment_continue = 0; + if (sorted_linear_indice_id + num_dup + threadIdx.x < num_indices) { + segment_continue = + sorted_linear_ind_data[sorted_linear_indice_id + num_dup + threadIdx.x] == + index; + } +#ifndef __HIP_PLATFORM_HCC__ + int32_t num_dup_incr = __popc(__ballot_sync(0xFFFFFFFF, segment_continue)); +#else + int32_t num_dup_incr = __popc(__ballot(segment_continue)); +#endif + num_dup += num_dup_incr; + if (num_dup_incr != kWarpSize) { + break; + } + } - // TODO: Tuning NumThreads for sum_squares - typedef cub::BlockReduce BlockReduce; - __shared__ BlockReduce::TempStorage temp_storage; - int valid = min(block_size, blockDim.x); + float sum_squares = 0.0; + extern __shared__ float x_ij[]; - float sum_squares = 0.0; - __shared__ float row_sum_squares_avg; - extern __shared__ float x_ij[]; + // we need to avoid index collision for the threads in the same block. + // Different threadIdx.y works on different `index`. + int sm_offset = threadIdx.y * block_size; - for (int i = threadIdx.x; i < block_size; i += blockDim.x) { - // i: index in the embedding dimension - float t_x_ij = 0.0; + for (int i = threadIdx.x; i < block_size; i += blockDim.x) { + // i: index in the embedding dimension + float t_x_ij = 0.0; - for (int dup_id = 0; dup_id < num_dup; dup_id++) { - int group = sorted_seg_id_data[sorted_linear_indice_id + dup_id]; - t_x_ij += grad[group * block_size + i]; + for (int dup_id = 0; dup_id < num_dup; dup_id++) { + int group = sorted_seg_id_data[sorted_linear_indice_id + dup_id]; + t_x_ij += grad[group * block_size + i]; + } + t_x_ij += weight_decay * + rand_factor.convertTypeFromParamToTarget(param[index * block_size + i]); + sum_squares += t_x_ij * t_x_ij; + + x_ij[sm_offset + i] = t_x_ij; } - t_x_ij += weight_decay * - rand_factor.convertTypeFromParamToTarget(param[index * block_size + i]);; - sum_squares += t_x_ij * t_x_ij; - x_ij[i] = t_x_ij; - } - float reduce_result = BlockReduce(temp_storage).Sum(sum_squares, valid); - if (threadIdx.x == 0) { - row_sum_squares_avg = reduce_result / static_cast(block_size); - float mom_new = param_mom[index] + static_cast(row_sum_squares_avg); + // We have a strong assumption that blockDim.x = 32, which is equal to the warp size. + float row_sum_squares_avg = warpReduceAllSum(sum_squares) / static_cast(block_size); + float mom_new = param_mom[index] + row_sum_squares_avg; param_mom[index] = mom_new; - } - __syncthreads(); - // update param - float step = LR / (sqrtf(param_mom[index]) + epsilon); - for (int i = threadIdx.x; i < block_size; i += blockDim.x) { - const size_t paramIdx = index * block_size + i; // index for param - param[paramIdx] = - rand_factor.convertTypeFromTargetToParam(param[paramIdx] + x_ij[i] * step); + // update param + float step = LR / (sqrtf(mom_new) + epsilon); + for (int i = threadIdx.x; i < block_size; i += blockDim.x) { + const size_t paramIdx = index * block_size + i; // index for param + param[paramIdx] = rand_factor.convertTypeFromTargetToParam( + rand_factor.convertTypeFromParamToTarget(param[paramIdx]) + x_ij[sm_offset + i] * step); + } + } else { + // find the num of duplicated indices. + int num_dup = 1; + while (sorted_linear_indice_id + num_dup < num_indices && + sorted_linear_ind_data[sorted_linear_indice_id + num_dup] == index) { + num_dup += 1; + } + + // TODO: Tuning NumThreads for sum_squares + typedef cub::BlockReduce BlockReduce; + __shared__ BlockReduce::TempStorage temp_storage; + int valid = min(block_size, blockDim.x); + + float sum_squares = 0.0; + __shared__ float row_sum_squares_avg; + extern __shared__ float x_ij[]; + + for (int i = threadIdx.x; i < block_size; i += blockDim.x) { + // i: index in the embedding dimension + float t_x_ij = 0.0; + + for (int dup_id = 0; dup_id < num_dup; dup_id++) { + int group = sorted_seg_id_data[sorted_linear_indice_id + dup_id]; + t_x_ij += grad[group * block_size + i]; + } + t_x_ij += weight_decay * + rand_factor.convertTypeFromParamToTarget(param[index * block_size + i]); + sum_squares += t_x_ij * t_x_ij; + x_ij[i] = t_x_ij; + } + float reduce_result = BlockReduce(temp_storage).Sum(sum_squares, valid); + + if (threadIdx.x == 0) { + row_sum_squares_avg = reduce_result / static_cast(block_size); + float mom_new = param_mom[index] + row_sum_squares_avg; + param_mom[index] = mom_new; + } + __syncthreads(); + + // update param + float step = LR / (sqrtf(param_mom[index]) + epsilon); + for (int i = threadIdx.x; i < block_size; i += blockDim.x) { + const size_t paramIdx = index * block_size + i; // index for param + param[paramIdx] = rand_factor.convertTypeFromTargetToParam( + rand_factor.convertTypeFromParamToTarget(param[paramIdx]) + x_ij[i] * step); + } } } @@ -570,8 +633,12 @@ class CUDASparseAdagradFusedWithSparseLengthsSumGradientOp final is_mean ? grad_buffer_.template mutable_data() : NULL; if (is_mean) { gradient_mean_kernel - <<>>( + <<>>( grad, lengths, grad_buffer_data, block_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } if (block_size <= maxThreads) { @@ -594,6 +661,7 @@ class CUDASparseAdagradFusedWithSparseLengthsSumGradientOp final is_mean ? grad_buffer_data : grad, lr, weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { // calling cuda kernel with ExactBlock = false sparse_adagrad_fused_length_sum_gradient_kernel< @@ -612,6 +680,7 @@ class CUDASparseAdagradFusedWithSparseLengthsSumGradientOp final is_mean ? grad_buffer_data : grad, lr, weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } return true; } @@ -753,6 +822,7 @@ class CUDASparseAdagradFusedWithSparseLengthsWeightedSumGradientOp final out_weight_grads, lr, weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else if (block_size > 64) { sparse_adagrad_fused_length_weighted_sum_gradient_kernel< IndexType, @@ -772,6 +842,7 @@ class CUDASparseAdagradFusedWithSparseLengthsWeightedSumGradientOp final out_weight_grads, lr, weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else if (block_size > 32) { sparse_adagrad_fused_length_weighted_sum_gradient_kernel< IndexType, @@ -791,6 +862,7 @@ class CUDASparseAdagradFusedWithSparseLengthsWeightedSumGradientOp final out_weight_grads, lr, weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { sparse_adagrad_fused_length_weighted_sum_gradient_kernel< IndexType, @@ -810,6 +882,7 @@ class CUDASparseAdagradFusedWithSparseLengthsWeightedSumGradientOp final out_weight_grads, lr, weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } return true; } @@ -934,8 +1007,12 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp final is_mean ? grad_buffer_.template mutable_data() : NULL; if (is_mean) { gradient_mean_kernel - <<>>( + <<>>( grad, lengths, grad_buffer_data, block_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } // 0: nearest rounding @@ -968,6 +1045,7 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp final lr, seed, weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { rowwise_sparse_adagrad_fused_length_sum_gradient_kernel< IndexType, @@ -987,6 +1065,7 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp final lr, seed, weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } else { if (round_option_) { @@ -1012,6 +1091,7 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp final lr, seed, weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { rowwise_sparse_adagrad_fused_length_sum_gradient_kernel< IndexType, @@ -1035,6 +1115,7 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp final lr, seed, weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } return true; @@ -1172,6 +1253,7 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientExactOp final 0, context_.cuda_stream()>>>( grad, lengths, grad_buffer_data, block_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } sorted_linear_ind_buffer_.ResizeLike(indicesInput); @@ -1179,13 +1261,11 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientExactOp final sorted_seg_id_buffer_.ResizeLike(indicesInput); linear_index_weight_offsets_dedup_kernel - <<>>( + <<>>( indices, prefix_sum_length_data, seg_id_buffer_.template mutable_data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); sort_pairs_wrapper( num_indices, @@ -1206,60 +1286,141 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsSumGradientExactOp final seed.y = maxThreads * block_size; } - CAFFE_ENFORCE_LE(block_size, 10240, - "Block size is too big and will exceed the max size of the shared memory"); - if (round_option_ == STOCHASTIC) { - rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel< - IndexType, - TParam, - T, - false, - STOCHASTIC> - <<>>( - prefix_sum_length_data, - N, - block_size, - num_lengths, - num_indices, - epsilon_, - paramOut, - momentOut, - indices, - is_mean ? grad_buffer_data : grad, - sorted_linear_ind_buffer_.template data(), - sorted_seg_id_buffer_.template data(), - lr, - seed, - weight_decay_); + if (block_size <= maxThreads / 2 && block_size % 32 == 0) { + // Fast path when the embedding dimension is a multiple of 32, using + // WarpReduce. + constexpr int kWarpNum = 8; + const dim3 threads(kWarpSize, kWarpNum); + const dim3 blocks((num_indices + kWarpNum - 1) / kWarpNum); + CAFFE_ENFORCE_LE( + kWarpNum * kWarpSize, + maxThreads, + "the total number of threads in a block should be smaller than or equal to maxThreads"); + + const int sm_size = block_size * kWarpNum * sizeof(float); + // Maximum shared memory allocated per thread block is 48 KB on Maxwell/Pascal + CAFFE_ENFORCE_LE( + sm_size, + 1024 * 48, + "Block size is too big and will exceed the max size of the shared memory"); + + if (round_option_ == STOCHASTIC) { + rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel< + IndexType, + TParam, + T, + true, + STOCHASTIC> + <<>>( + prefix_sum_length_data, + N, + block_size, + num_lengths, + num_indices, + epsilon_, + paramOut, + momentOut, + indices, + is_mean ? grad_buffer_data : grad, + sorted_linear_ind_buffer_.template data(), + sorted_seg_id_buffer_.template data(), + lr, + seed, + weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel< + IndexType, + TParam, + T, + true, + NEAREST> + <<>>( + prefix_sum_length_data, + N, + block_size, + num_lengths, + num_indices, + epsilon_, + paramOut, + momentOut, + indices, + is_mean ? grad_buffer_data : grad, + sorted_linear_ind_buffer_.template data(), + sorted_seg_id_buffer_.template data(), + lr, + seed, + weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } } else { - rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel< - IndexType, - TParam, - T, - false, - NEAREST> - <<>>( - prefix_sum_length_data, - N, - block_size, - num_lengths, - num_indices, - epsilon_, - paramOut, - momentOut, - indices, - is_mean ? grad_buffer_data : grad, - sorted_linear_ind_buffer_.template data(), - sorted_seg_id_buffer_.template data(), - lr, - seed, - weight_decay_); + const int sm_size = block_size * sizeof(float); + // Maximum shared memory allocated per thread block is 48 KB on Maxwell/Pascal + CAFFE_ENFORCE_LE( + sm_size, + 1024 * 48, + "Block size is too big and will exceed the max size of the shared memory"); + if (round_option_ == STOCHASTIC) { + rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel< + IndexType, + TParam, + T, + false, + STOCHASTIC> + <<>>( + prefix_sum_length_data, + N, + block_size, + num_lengths, + num_indices, + epsilon_, + paramOut, + momentOut, + indices, + is_mean ? grad_buffer_data : grad, + sorted_linear_ind_buffer_.template data(), + sorted_seg_id_buffer_.template data(), + lr, + seed, + weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel< + IndexType, + TParam, + T, + false, + NEAREST> + <<>>( + prefix_sum_length_data, + N, + block_size, + num_lengths, + num_indices, + epsilon_, + paramOut, + momentOut, + indices, + is_mean ? grad_buffer_data : grad, + sorted_linear_ind_buffer_.template data(), + sorted_seg_id_buffer_.template data(), + lr, + seed, + weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } } return true; @@ -1408,6 +1569,7 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsWeightedSumGradientOp final out_weight_grads, lr, weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else if (block_size > 64) { rowwise_sparse_adagrad_fused_length_weighted_sum_gradient_kernel< IndexType, @@ -1427,6 +1589,7 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsWeightedSumGradientOp final out_weight_grads, lr, weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else if (block_size > 32) { rowwise_sparse_adagrad_fused_length_weighted_sum_gradient_kernel< IndexType, @@ -1446,6 +1609,7 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsWeightedSumGradientOp final out_weight_grads, lr, weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { rowwise_sparse_adagrad_fused_length_weighted_sum_gradient_kernel< IndexType, @@ -1465,6 +1629,7 @@ class CUDARowWiseSparseAdagradFusedWithSparseLengthsWeightedSumGradientOp final out_weight_grads, lr, weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } return true; diff --git a/caffe2/sgd/adagrad_fused_op_gpu.cuh b/caffe2/sgd/adagrad_fused_op_gpu.cuh index 9a5f53bead122..e695dac37e4d5 100644 --- a/caffe2/sgd/adagrad_fused_op_gpu.cuh +++ b/caffe2/sgd/adagrad_fused_op_gpu.cuh @@ -26,6 +26,27 @@ namespace caffe2 { +constexpr int kWarpSize = 32; + +template +inline __device__ T shfl_xor(const T val, int laneMask, int width = kWarpSize) { +#ifndef __HIP_PLATFORM_HCC__ + return __shfl_xor_sync(0xffffffff, val, laneMask, width); +#else + return __shfl_xor(val, laneMask, width); +#endif +} + +/// Sums a register value across all warp threads +template +inline __device__ T warpReduceAllSum(T val) { +#pragma unroll + for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { + val += shfl_xor(val, mask); + } + return val; +} + enum roundOption : int { NEAREST = 0, STOCHASTIC = 1 }; template diff --git a/caffe2/sgd/adagrad_op.h b/caffe2/sgd/adagrad_op.h index 0f336a56dbc44..f4272a03bbd77 100644 --- a/caffe2/sgd/adagrad_op.h +++ b/caffe2/sgd/adagrad_op.h @@ -232,11 +232,17 @@ class SparseAdagradOp final : public Operator { last_block_size_ = block_size; if (std::is_same::value) { kernel_i32_ = fbgemm::GenerateSparseAdaGrad( - block_size, /*rowwise=*/false, /*prefetch=*/16, weight_decay_); + block_size, + /*rowwise=*/false, + /*prefetch=*/16, + weight_decay_ != 0.0f); } else { CAFFE_ENFORCE((std::is_same::value)); kernel_i64_ = fbgemm::GenerateSparseAdaGrad( - block_size, /*rowwise=*/false, /*prefetch=*/16, weight_decay_); + block_size, + /*rowwise=*/false, + /*prefetch=*/16, + weight_decay_ != 0.0f); } } @@ -250,7 +256,8 @@ class SparseAdagradOp final : public Operator { momentOut, reinterpret_cast(indices), epsilon_, - lr[0]); + lr[0], + weight_decay_); } else { num_rows_processed = kernel_i64_( n, @@ -260,7 +267,8 @@ class SparseAdagradOp final : public Operator { momentOut, reinterpret_cast(indices), epsilon_, - lr[0]); + lr[0], + weight_decay_); } if (num_rows_processed < n) { CAFFE_ENFORCE_GE( @@ -340,7 +348,7 @@ class SparseAdagradOp final : public Operator { protected: float epsilon_; - float weight_decay_; + const float weight_decay_; #if defined(USE_FBGEMM) && !defined(__NVCC__) fbgemm::SparseAdaGradSignature::Type kernel_i32_; fbgemm::SparseAdaGradSignature::Type kernel_i64_; @@ -421,11 +429,17 @@ class RowWiseSparseAdagradOp final : public Operator { last_block_size_ = block_size; if (std::is_same::value) { kernel_i32_ = fbgemm::GenerateSparseAdaGrad( - block_size, /*rowwise=*/true, /*prefetch=*/16, weight_decay_); + block_size, + /*rowwise=*/true, + /*prefetch=*/16, + weight_decay_ != 0.0f); } else { CAFFE_ENFORCE((std::is_same::value)); kernel_i64_ = fbgemm::GenerateSparseAdaGrad( - block_size, /*rowwise=*/true, /*prefetch=*/16, weight_decay_); + block_size, + /*rowwise=*/true, + /*prefetch=*/16, + weight_decay_ != 0.0f); } } @@ -439,7 +453,8 @@ class RowWiseSparseAdagradOp final : public Operator { moment, reinterpret_cast(indices), epsilon_, - lr[0]); + lr[0], + weight_decay_); } else { num_rows_processed = kernel_i64_( n, @@ -449,7 +464,8 @@ class RowWiseSparseAdagradOp final : public Operator { moment, reinterpret_cast(indices), epsilon_, - lr[0]); + lr[0], + weight_decay_); } if (num_rows_processed < n) { @@ -527,7 +543,7 @@ class RowWiseSparseAdagradOp final : public Operator { protected: float epsilon_; - float weight_decay_; + const float weight_decay_; #if defined(USE_FBGEMM) && !defined(__NVCC__) fbgemm::SparseAdaGradSignature::Type kernel_i32_; fbgemm::SparseAdaGradSignature::Type kernel_i64_; diff --git a/caffe2/sgd/adagrad_op_gpu.cu b/caffe2/sgd/adagrad_op_gpu.cu index b49cf2880a905..8abb3376ca875 100644 --- a/caffe2/sgd/adagrad_op_gpu.cu +++ b/caffe2/sgd/adagrad_op_gpu.cu @@ -44,6 +44,7 @@ void adagrad_update( 0, context->cuda_stream()>>>( N, w, g, h, nw, nh, epsilon, decay, lr, weight_decay); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -188,6 +189,7 @@ class CUDASparseAdagradOp final : public Operator { Input(GRAD).template data(), Input(LR).template data(), weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } @@ -231,6 +233,7 @@ bool RowWiseSparseAdagradOp::DoRunWithType() { Input(GRAD).template data(), Input(LR).template data(), weight_decay_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } diff --git a/caffe2/sgd/adam_op_gpu.cu b/caffe2/sgd/adam_op_gpu.cu index 30b1b5ce8c674..42ab975faacb3 100644 --- a/caffe2/sgd/adam_op_gpu.cu +++ b/caffe2/sgd/adam_op_gpu.cu @@ -47,6 +47,7 @@ void adam_update( 0, context->cuda_stream()>>>( N, g, m, v, ng, nm, nv, beta1, beta2, eps_hat, correction, lr); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } __global__ void AdamCompute( @@ -94,6 +95,7 @@ void adam_compute( 0, context->cuda_stream()>>>( N, w, g, m, v, nw, nm, nv, beta1, beta2, eps_hat, correction, lr); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } __global__ void AdamComputeOutputGrad( @@ -143,6 +145,7 @@ void adam_compute_output_grad( 0, context->cuda_stream()>>>( N, w, g, m, v, nw, nm, nv, ng, beta1, beta2, eps_hat, correction, lr); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -333,6 +336,7 @@ bool SparseAdamOp::DoRunWithType() { correction, Input(LR).template data(), iter); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD)); SparseAdamOutputGradKernel @@ -354,6 +358,7 @@ bool SparseAdamOp::DoRunWithType() { correction, Input(LR).template data(), iter); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } return true; @@ -398,6 +403,7 @@ bool RowWiseSparseAdamOp::DoRunWithType() { Input(GRAD).template data(), correction, Input(LR).template data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD)); RowWiseSparseAdamOutputGradKernel @@ -418,6 +424,7 @@ bool RowWiseSparseAdamOp::DoRunWithType() { Input(GRAD).template data(), correction, Input(LR).template data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } return true; diff --git a/caffe2/sgd/fp16_momentum_sgd_op.cu b/caffe2/sgd/fp16_momentum_sgd_op.cu index 8ec1c85fd5480..985f4f2864d15 100644 --- a/caffe2/sgd/fp16_momentum_sgd_op.cu +++ b/caffe2/sgd/fp16_momentum_sgd_op.cu @@ -215,6 +215,7 @@ void fp16_momentum_sgd_update( nesterov, weight_decay, reinterpret_cast(param)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); // not setting N to N/2 } else { FP16MomentumSGDFP32Kernel<<< @@ -232,6 +233,7 @@ void fp16_momentum_sgd_update( nesterov, weight_decay, reinterpret_cast(param)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); // not setting N to N/2 } diff --git a/caffe2/sgd/fp32_momentum_sgd_op.cu b/caffe2/sgd/fp32_momentum_sgd_op.cu index c7947dac440a2..1ae9015bc6d94 100644 --- a/caffe2/sgd/fp32_momentum_sgd_op.cu +++ b/caffe2/sgd/fp32_momentum_sgd_op.cu @@ -108,6 +108,7 @@ void fp32_momentum_sgd_update( nesterov, weight_decay, reinterpret_cast(param)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); // not setting N to N/2 // TODO_ check float performance vs float2 } diff --git a/caffe2/sgd/lars_op_gpu.cu b/caffe2/sgd/lars_op_gpu.cu index 9c4322167dbe6..2a1d6c79e8332 100644 --- a/caffe2/sgd/lars_op_gpu.cu +++ b/caffe2/sgd/lars_op_gpu.cu @@ -31,6 +31,7 @@ void LarsOp::ComputeLearningRate( float* lr_rescaled) { ComputeLearningRateKernel<<<1, 1, 0, context_.cuda_stream()>>>( wd, trust, lr_max, offset, lr_min, X_norm, dX_norm, lr_rescaled); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } REGISTER_CUDA_OPERATOR(Lars, LarsOp); diff --git a/caffe2/sgd/learning_rate_op.h b/caffe2/sgd/learning_rate_op.h index 3ba6bef39e639..fb0998a65d714 100644 --- a/caffe2/sgd/learning_rate_op.h +++ b/caffe2/sgd/learning_rate_op.h @@ -62,7 +62,7 @@ class LearningRateOp final : public Operator { active_period, inactive_period, active_first); } else if (policy == "hill") { int64_t num_iter = - this->template GetSingleArgument(arg_prefix + "num_iter", 0); + this->template GetSingleArgument(arg_prefix + "num_iter", 0); DCHECK_GT(num_iter, 0); T start_multiplier = this->template GetSingleArgument( arg_prefix + "start_multiplier", 0.); @@ -198,9 +198,9 @@ class LearningRateOp final : public Operator { } else if (policy == "constantThenLinearWarmup") { T start_warmup_multiplier = this->template GetSingleArgument( arg_prefix + "start_warmup_multiplier", 0.1); - int64_t constant_warmup_num_iter = this->template GetSingleArgument( + int64_t constant_warmup_num_iter = this->template GetSingleArgument( arg_prefix + "constant_warmup_num_iter", 10000000); - int64_t linear_warmup_num_iter = this->template GetSingleArgument( + int64_t linear_warmup_num_iter = this->template GetSingleArgument( arg_prefix + "linear_warmup_num_iter", 10000000); return new ConstantThenLinearWarmupLearningRate( start_warmup_multiplier, @@ -209,9 +209,9 @@ class LearningRateOp final : public Operator { } else if (policy == "compositeCyclical") { T start_warmup_multiplier = this->template GetSingleArgument( arg_prefix + "start_warmup_multiplier", 0.1); - int64_t constant_warmup_num_iter = this->template GetSingleArgument( + int64_t constant_warmup_num_iter = this->template GetSingleArgument( arg_prefix + "constant_warmup_num_iter", 10000000); - int64_t linear_warmup_num_iter = this->template GetSingleArgument( + int64_t linear_warmup_num_iter = this->template GetSingleArgument( arg_prefix + "linear_warmup_num_iter", 10000000); T cyclical_max_lr = this->template GetSingleArgument( arg_prefix + "cyclical_max_lr", 0.05); @@ -245,9 +245,9 @@ class LearningRateOp final : public Operator { } else if (policy == "compositeCosine") { T start_warmup_multiplier = this->template GetSingleArgument( arg_prefix + "start_warmup_multiplier", 0.1); - int64_t constant_warmup_num_iter = this->template GetSingleArgument( + int64_t constant_warmup_num_iter = this->template GetSingleArgument( arg_prefix + "constant_warmup_num_iter", 10000000); - int64_t linear_warmup_num_iter = this->template GetSingleArgument( + int64_t linear_warmup_num_iter = this->template GetSingleArgument( arg_prefix + "linear_warmup_num_iter", 10000000); T cosine_max_lr = this->template GetSingleArgument( arg_prefix + "cosine_max_lr", 0.5); diff --git a/caffe2/sgd/momentum_sgd_op_gpu.cu b/caffe2/sgd/momentum_sgd_op_gpu.cu index ebf0abae54b9b..e8eb00654e65a 100644 --- a/caffe2/sgd/momentum_sgd_op_gpu.cu +++ b/caffe2/sgd/momentum_sgd_op_gpu.cu @@ -82,12 +82,14 @@ void momentum_sgd_update( CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(N, g, m, ng, nm, lr, momentum, param); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { MomentumSGDKernel <<cuda_stream()>>>(N, g, m, ng, nm, lr, momentum, param); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } @@ -150,6 +152,7 @@ bool SparseMomentumSGDUpdateOp::DoRunWithType() { Input(GRAD).template data(), Output(OUTPUT_GRAD)->template mutable_data(), Input(LR).template data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } diff --git a/caffe2/sgd/rmsprop_op_gpu.cu b/caffe2/sgd/rmsprop_op_gpu.cu index fd293e240308b..d6f236739084b 100644 --- a/caffe2/sgd/rmsprop_op_gpu.cu +++ b/caffe2/sgd/rmsprop_op_gpu.cu @@ -43,6 +43,7 @@ void rmsprop_update( CUDAContext* context) { RmsPropUpdate<<cuda_stream()>>>( N, g, ms, mom, ng, nms, nmom, decay, momentum, epsilon, lr); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } diff --git a/caffe2/sgd/rowwise_adagrad_fused.h b/caffe2/sgd/rowwise_adagrad_fused.h index 5430b97f6e9a9..56be3c15453e9 100644 --- a/caffe2/sgd/rowwise_adagrad_fused.h +++ b/caffe2/sgd/rowwise_adagrad_fused.h @@ -46,7 +46,7 @@ inline float compute_square_average_with_weight_decay_inlined_( for (; i + kSize <= len; i += kSize) { __m256 ai = _mm256_loadu_ps(a + i); __m256 wi = _mm256_loadu_ps(w + i); -#ifdef __AVX2__ +#ifdef __FMA__ ai = _mm256_fmadd_ps(weight_decay_v, wi, ai); #else ai = _mm256_add_ps(_mm256_mul_ps(weight_decay_v, wi), ai); @@ -84,7 +84,7 @@ inline float compute_square_average_with_weight_decay_inlined_( __m256 ai = _mm256_loadu_ps(a + i); __m128i whi = _mm_loadu_si128(reinterpret_cast(w + i)); __m256 wi = _mm256_cvtph_ps(whi); -#ifdef __AVX2__ +#ifdef __FMA__ ai = _mm256_fmadd_ps(weight_decay_v, wi, ai); #else ai = _mm256_add_ps(_mm256_mul_ps(weight_decay_v, wi), ai); @@ -952,7 +952,7 @@ struct rowwise_adagrad_update_inlined { __m256 gi = _mm256_loadu_ps(g + i); __m256 wi = _mm256_loadu_ps(w + i); if (weight_decay != 0.0f) { -#ifdef __AVX2__ +#ifdef __FMA__ gi = _mm256_fmadd_ps(weight_decay_v, wi, gi); #else gi = _mm256_add_ps(_mm256_mul_ps(weight_decay_v, wi), gi); diff --git a/caffe2/sgd/yellowfin_op_gpu.cu b/caffe2/sgd/yellowfin_op_gpu.cu index 89f9c547fa68a..cb62ae4335577 100644 --- a/caffe2/sgd/yellowfin_op_gpu.cu +++ b/caffe2/sgd/yellowfin_op_gpu.cu @@ -32,6 +32,7 @@ void YellowFinOp::GetLrMu() { // Finding root of cubic formula for YF's Single Step GetLrMuKernel<<<1, 1, 0, context_.cuda_stream()>>>( g_norm2_max_deb_, g_norm2_min_deb_, distance_deb_, variance_, mu_, lr_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); MovingAverage(1, mu_, mu_avg_, mu_avg_out_, mu_deb_); MovingAverage(1, lr_, lr_avg_, lr_avg_out_, lr_deb_); } @@ -78,6 +79,7 @@ void YellowFinOp::MomentumSgdUpdate() { param_out_, moment_out_, nesterov_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } REGISTER_CUDA_OPERATOR(YellowFin, YellowFinOp); diff --git a/caffe2/transforms/common_subexpression_elimination.h b/caffe2/transforms/common_subexpression_elimination.h index fdec50a11e8ec..6e54f8185d556 100644 --- a/caffe2/transforms/common_subexpression_elimination.h +++ b/caffe2/transforms/common_subexpression_elimination.h @@ -25,7 +25,7 @@ namespace caffe2 { * * TODO(benz): Fix the error to not match nodes that write to external output. */ -class CAFFE2_API CommonSubexpressionEliminationTransform : public Transform { +class TORCH_API CommonSubexpressionEliminationTransform : public Transform { public: CommonSubexpressionEliminationTransform() { SetPatternMatchType(SORTED_WRT_EXECUTION_ORDER); diff --git a/caffe2/transforms/conv_to_nnpack_transform.h b/caffe2/transforms/conv_to_nnpack_transform.h index 8563732f225e2..0e19989aee644 100644 --- a/caffe2/transforms/conv_to_nnpack_transform.h +++ b/caffe2/transforms/conv_to_nnpack_transform.h @@ -7,7 +7,7 @@ namespace caffe2 { -class CAFFE2_API ConvToNNPackTransform : public SingleOpTransform { +class TORCH_API ConvToNNPackTransform : public SingleOpTransform { protected: // Specify what the op needs to be to match the pattern. bool MatchOperator(const OperatorDef& op) override { diff --git a/caffe2/transforms/pattern_net_transform.h b/caffe2/transforms/pattern_net_transform.h index 397258fbd4fc2..95638f4a839c5 100644 --- a/caffe2/transforms/pattern_net_transform.h +++ b/caffe2/transforms/pattern_net_transform.h @@ -15,7 +15,7 @@ namespace caffe2 { * and this Transform will find subgraphs which fit the pattern net, * and replace it with the replace net. */ -class CAFFE2_API PatternNetTransform : public Transform { +class TORCH_API PatternNetTransform : public Transform { public: PatternNetTransform(const NetDef& pattern_net, const NetDef& replace_net) : p_(transform::Graph(pattern_net)), r_(transform::Graph(replace_net)) { diff --git a/caffe2/transforms/single_op_transform.h b/caffe2/transforms/single_op_transform.h index 45f93cbbd8b99..096c06423db0f 100644 --- a/caffe2/transforms/single_op_transform.h +++ b/caffe2/transforms/single_op_transform.h @@ -15,7 +15,7 @@ namespace caffe2 { * Transforms which derive from SingleOpTransform need to override: * ReplaceOperator and MatchOperator. */ -class CAFFE2_API SingleOpTransform : public Transform { +class TORCH_API SingleOpTransform : public Transform { protected: bool PatternRule( const transform::Graph& g, diff --git a/caffe2/utils/CMakeLists.txt b/caffe2/utils/CMakeLists.txt index 798985953b89f..62190501cdac0 100644 --- a/caffe2/utils/CMakeLists.txt +++ b/caffe2/utils/CMakeLists.txt @@ -1,9 +1,13 @@ if((NOT BUILD_CAFFE2) OR (INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE)) list(APPEND Caffe2_CPU_SRCS utils/string_utils.cc - utils/threadpool/pthreadpool-cpp.cc utils/threadpool/ThreadPool.cc ) + + if(USE_PTHREADPOOL AND NOT USE_INTERNAL_PTHREADPOOL_IMPL) + list(APPEND Caffe2_CPU_SRCS utils/threadpool/pthreadpool-cpp.cc) + endif() + if(NOT BUILD_CAFFE2) list(APPEND Caffe2_CPU_SRCS utils/proto_wrap.cc diff --git a/caffe2/utils/GpuDefs.cuh b/caffe2/utils/GpuDefs.cuh index 46d8058c84b54..be591cc95b92e 100644 --- a/caffe2/utils/GpuDefs.cuh +++ b/caffe2/utils/GpuDefs.cuh @@ -7,16 +7,9 @@ namespace caffe2 { // Static definition of GPU warp size for unrolling and code generation -#ifdef __CUDA_ARCH__ -#if __CUDA_ARCH__ <= 800 -constexpr int kWarpSize = 32; -#else -#error Unknown __CUDA_ARCH__; please define parameters for compute capability -#endif // __CUDA_ARCH__ types -#elif defined(__HIP_PLATFORM_HCC__) +#if defined(__HIP_PLATFORM_HCC__) constexpr int kWarpSize = warpSize; // = 64 (Defined in hip_runtime.h) #else -// dummy value for host compiler constexpr int kWarpSize = 32; #endif // __CUDA_ARCH__ diff --git a/caffe2/utils/bench_utils.h b/caffe2/utils/bench_utils.h index b879ccc1eb510..59997edad58d4 100644 --- a/caffe2/utils/bench_utils.h +++ b/caffe2/utils/bench_utils.h @@ -23,7 +23,7 @@ namespace caffe2 { -CAFFE2_API uint32_t wipe_cache(); +TORCH_API uint32_t wipe_cache(); } // namespace caffe2 diff --git a/caffe2/utils/cpuid.cc b/caffe2/utils/cpuid.cc index b2e6b89a5cbd8..7ef47dd757c32 100644 --- a/caffe2/utils/cpuid.cc +++ b/caffe2/utils/cpuid.cc @@ -7,10 +7,10 @@ const CpuId& GetCpuId() { return cpuid_singleton; } -CAFFE2_API uint32_t CpuId::f1c_ = 0; -CAFFE2_API uint32_t CpuId::f1d_ = 0; -CAFFE2_API uint32_t CpuId::f7b_ = 0; -CAFFE2_API uint32_t CpuId::f7c_ = 0; +TORCH_API uint32_t CpuId::f1c_ = 0; +TORCH_API uint32_t CpuId::f1d_ = 0; +TORCH_API uint32_t CpuId::f7b_ = 0; +TORCH_API uint32_t CpuId::f7c_ = 0; CpuId::CpuId() { #ifdef _MSC_VER diff --git a/caffe2/utils/cpuid.h b/caffe2/utils/cpuid.h index 7cc09009fd20f..598e1bdaea8e7 100644 --- a/caffe2/utils/cpuid.h +++ b/caffe2/utils/cpuid.h @@ -12,7 +12,7 @@ namespace caffe2 { class CpuId; -CAFFE2_API const CpuId& GetCpuId(); +TORCH_API const CpuId& GetCpuId(); /////////////////////////////////////////////////////////////////////////////// // Implementation of CpuId that is borrowed from folly. @@ -137,10 +137,10 @@ class CpuId { #undef X private: - CAFFE2_API static uint32_t f1c_; - CAFFE2_API static uint32_t f1d_; - CAFFE2_API static uint32_t f7b_; - CAFFE2_API static uint32_t f7c_; + TORCH_API static uint32_t f1c_; + TORCH_API static uint32_t f1d_; + TORCH_API static uint32_t f7b_; + TORCH_API static uint32_t f7c_; }; } // namespace caffe2 diff --git a/caffe2/utils/eigen_utils.h b/caffe2/utils/eigen_utils.h index 83e7cb2317bbc..77b6eb64f095b 100644 --- a/caffe2/utils/eigen_utils.h +++ b/caffe2/utils/eigen_utils.h @@ -73,17 +73,31 @@ using EArrXf = Eigen::ArrayXf; using EArrXd = Eigen::ArrayXd; using EArrXi = Eigen::ArrayXi; using EArrXb = EArrXt; +using EArrXI32 = EArrXt; +using EArrXU16 = EArrXt; +using EArrXU8 = EArrXt; +using EArr3U8 = Eigen::Array; // 2-d array, column major template using EArrXXt = Eigen::Array; using EArrXXf = Eigen::ArrayXXf; +using EArrXXI32 = EArrXXt; +using EArrXXU16 = EArrXXt; +using EArrXXU8 = EArrXXt; +using EArrXXi = EArrXXt; // 2-d array, row major template using ERArrXXt = Eigen::Array; using ERArrXXf = ERArrXXt; +using ERArrXXI32t = ERArrXXt; +using ERArrXXU16t = ERArrXXt; +using ERArrXXU8t = ERArrXXt; +using ERArrXXi = ERArrXXt; +using ERArrXXi64t = ERArrXXt; +using ERArrXXi32t = ERArrXXt; // 1-d vector template @@ -100,6 +114,8 @@ template using EMatXt = Eigen::Matrix; using EMatXd = Eigen::MatrixXd; using EMatXf = Eigen::MatrixXf; +using EMatXU8 = EMatXt; +using EMatXU16 = EMatXt; // 2-d matrix, row major template @@ -107,6 +123,7 @@ using ERMatXt = Eigen::Matrix; using ERMatXd = ERMatXt; using ERMatXf = ERMatXt; +using ERMatXU8 = ERMatXt; namespace utils { diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h index 4ad285d50a273..07911a3c3d58d 100644 --- a/caffe2/utils/math.h +++ b/caffe2/utils/math.h @@ -28,7 +28,7 @@ class Tensor; // An empty class as a placeholder for a math function that has no specific // engine specified. -class CAFFE2_API DefaultEngine {}; +class TORCH_API DefaultEngine {}; namespace math { @@ -118,7 +118,7 @@ C10_DECLARE_BINARY_OP(BitwiseXor) // Broadcasts X with X_dims to Y with Y_dims. template -CAFFE2_API void Broadcast( +TORCH_API void Broadcast( const int X_ndim, const int* X_dims, const int Y_ndim, @@ -130,7 +130,7 @@ CAFFE2_API void Broadcast( // Computes inv_std from variance. template -CAFFE2_API void InvStd( +TORCH_API void InvStd( const int N, const T epsilon, const T* var, @@ -140,7 +140,7 @@ CAFFE2_API void InvStd( // Adds batch sub-tensors elementwise to output. Stripe is the stripe length // and N is the number of elements to add (size of Y). template -CAFFE2_API void AddStripedBatch( +TORCH_API void AddStripedBatch( const int N, const T* first, T* y, @@ -151,24 +151,24 @@ CAFFE2_API void AddStripedBatch( // Compute the row-wise max of a N*D matrix X, and write it to a N // dimensional vector y. template -CAFFE2_API void +TORCH_API void RowwiseMax(const int N, const int D, const T* x, T* y, Context* context); // Compute the column-wise max of a N*D matrix X, and write it to a D // dimensional vector y. template -CAFFE2_API void +TORCH_API void ColwiseMax(const int N, const int D, const T* x, T* y, Context* context); // Elemwise maximum of vector x and scalar alpha. y[i] = max(x[i], alpha) template -CAFFE2_API void +TORCH_API void Maximum(const int N, const float alpha, const T* x, T* y, Context* context); // Decaf gemm provides a simpler interface to the gemm functions, with the // limitation that the data has to be contiguous in memory. template -CAFFE2_API void Gemm( +TORCH_API void Gemm( const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B, const int M, @@ -185,7 +185,7 @@ CAFFE2_API void Gemm( // We also provide a gemm that has explicit lda, ldb and ldc specified. // In most cases you probably want to use the function above, though. template -CAFFE2_API void GemmEx( +TORCH_API void GemmEx( const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B, const int M, @@ -203,7 +203,7 @@ CAFFE2_API void GemmEx( // GemmBatched provides a simple abstraction into library routines template -CAFFE2_API void GemmBatched( +TORCH_API void GemmBatched( const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B, const int batch_size, @@ -219,7 +219,7 @@ CAFFE2_API void GemmBatched( TensorProto::DataType math_type = TensorProto_DataType_FLOAT); template -CAFFE2_API void GemmStridedBatched( +TORCH_API void GemmStridedBatched( const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B, const int batch_size, @@ -242,7 +242,7 @@ CAFFE2_API void GemmStridedBatched( // CblasNoTrans: x is an N dim vector and y is an M dim vector. // CblasTrans: x is an M dim vector and y is an N dim vector. template -CAFFE2_API void Gemv( +TORCH_API void Gemv( const CBLAS_TRANSPOSE trans_A, const int M, const int N, @@ -255,13 +255,13 @@ CAFFE2_API void Gemv( TensorProto::DataType math_type = TensorProto_DataType_FLOAT); template -CAFFE2_API void +TORCH_API void RandUniform(const size_t n, const T a, const T b, T* r, Context* context); // Generate n values that sum up to a fixed sum // and subject to a restriction a <= x <= b for each x generated template -CAFFE2_API void RandFixedSum( +TORCH_API void RandFixedSum( const size_t n, const T a, const T b, @@ -270,7 +270,7 @@ CAFFE2_API void RandFixedSum( Context* context); template -CAFFE2_API void RandUniformUnique( +TORCH_API void RandUniformUnique( const size_t n, const T a, const T b, @@ -282,21 +282,21 @@ CAFFE2_API void RandUniformUnique( // Generate n values from synthetic data distribution, // define by unique accesses and stack distances template -CAFFE2_API void +TORCH_API void RandSyntheticData(const size_t n, const T a, const T b, T* r, Context* context); template -CAFFE2_API void +TORCH_API void RandGaussian(const size_t n, const T mean, const T std, T* r, Context* context); // Dot matrix of vector a and b, and writes the result to a single value y. template -CAFFE2_API void +TORCH_API void Dot(const int N, const T* a, const T* b, T* y, Context* context); // Sum of vector x, and writes the result to a single value y. template -CAFFE2_API void Sum( +TORCH_API void Sum( const int N, const T* x, T* y, @@ -305,7 +305,7 @@ CAFFE2_API void Sum( // Sum of squares of vector x, and writes the result to a single value y. template -CAFFE2_API void SumSqr( +TORCH_API void SumSqr( const int N, const T* x, T* y, @@ -315,7 +315,7 @@ CAFFE2_API void SumSqr( // Select does index selection of the rows a N*D matrix x, and gives the N // dimensional vector y that contains the selected data. template -CAFFE2_API void Select( +TORCH_API void Select( const int N, const int D, const T* x, @@ -329,7 +329,7 @@ CAFFE2_API void Select( // For NCHW order, groups doesn't make any difference because we're doing Im2Col // for each N and C is the slowest moving dimension among CHW. template -CAFFE2_API void Im2Col( +TORCH_API void Im2Col( const int channels, const int height, const int width, @@ -350,7 +350,7 @@ CAFFE2_API void Im2Col( // groups must be 1 for GPU template -CAFFE2_API void Im2ColNd( +TORCH_API void Im2ColNd( const int N, const int img_size, const int col_size, @@ -371,7 +371,7 @@ CAFFE2_API void Im2ColNd( // For NCHW order, groups doesn't make any difference because we're doing Im2Col // for each N and C is the slowest moving dimension among CHW. template -CAFFE2_API void Col2Im( +TORCH_API void Col2Im( const int channels, const int height, const int width, @@ -396,7 +396,7 @@ CAFFE2_API void Col2Im( // For NCHW order, groups doesn't make any difference because we're doing Im2Col // for each N and C is the slowest moving dimension among CHW. template -CAFFE2_API void Col2ImNd( +TORCH_API void Col2ImNd( const int N, const int img_size, const int col_size, @@ -414,7 +414,7 @@ CAFFE2_API void Col2ImNd( // Applies a per-channel bias value to each channel of the input // image. image_size is H * W template -CAFFE2_API void BiasCHW( +TORCH_API void BiasCHW( const T* bias, const T* bias_multiplier, const int bias_channels, @@ -423,7 +423,7 @@ CAFFE2_API void BiasCHW( Context* context); template -CAFFE2_API void CopyMatrix( +TORCH_API void CopyMatrix( const size_t item_size, const int M, const int N, @@ -435,7 +435,7 @@ CAFFE2_API void CopyMatrix( TypeMeta::Copy copy = nullptr); template -CAFFE2_API void CopyMatrix( +TORCH_API void CopyMatrix( const int M, const int N, const T* A, @@ -445,7 +445,7 @@ CAFFE2_API void CopyMatrix( Context* context); template -CAFFE2_API void CopyMatrix( +TORCH_API void CopyMatrix( const int M, const int N, const T* A, @@ -457,7 +457,7 @@ CAFFE2_API void CopyMatrix( Context* context); template -CAFFE2_API void CopyVector(const int N, const T* A, T* B, Context* context); +TORCH_API void CopyVector(const int N, const T* A, T* B, Context* context); } // namespace math } // namespace caffe2 diff --git a/caffe2/utils/math/broadcast.cu b/caffe2/utils/math/broadcast.cu index 97f7cb500fc12..7d7a2535743c6 100644 --- a/caffe2/utils/math/broadcast.cu +++ b/caffe2/utils/math/broadcast.cu @@ -83,6 +83,7 @@ __global__ void AffineChannelNHWCCUDAKernel( AffineChannelNCHWCUDAKernel \ <<cuda_stream()>>>( \ C, M, HxW, X, scale, bias, Y); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } \ template <> \ CAFFE2_CUDA_EXPORT void AffineChannel( \ @@ -100,6 +101,7 @@ __global__ void AffineChannelNHWCCUDAKernel( CAFFE_CUDA_NUM_THREADS, \ 0, \ context->cuda_stream()>>>(C, X, scale, bias, Y); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(float) #undef CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL diff --git a/caffe2/utils/math/broadcast.h b/caffe2/utils/math/broadcast.h index 67e37d1bd9156..16b98c749ce32 100644 --- a/caffe2/utils/math/broadcast.h +++ b/caffe2/utils/math/broadcast.h @@ -8,7 +8,7 @@ namespace caffe2 { namespace math { template -CAFFE2_API void AffineChannel( +TORCH_API void AffineChannel( const int N, const int C, const int HxW, diff --git a/caffe2/utils/math/elementwise.cc b/caffe2/utils/math/elementwise.cc index 15a6dceca0e94..aaba32662099e 100644 --- a/caffe2/utils/math/elementwise.cc +++ b/caffe2/utils/math/elementwise.cc @@ -46,6 +46,8 @@ DELEGATE_SIMPLE_UNARY_FUNCTION( VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE) DELEGATE_SIMPLE_UNARY_FUNCTION(float, Log, vsLn) DELEGATE_SIMPLE_UNARY_FUNCTION(double, Log, vdLn) +DELEGATE_SIMPLE_UNARY_FUNCTION(float, Log1p, vsLog1p) +DELEGATE_SIMPLE_UNARY_FUNCTION(double, Log1p, vdLog1p) DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sin, vsSin) DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sin, vdSin) DELEGATE_SIMPLE_UNARY_FUNCTION(float, Asin, vsAsin) @@ -155,6 +157,8 @@ DELEGATE_SIMPLE_UNARY_FUNCTION(float, Exp, exp) DELEGATE_SIMPLE_UNARY_FUNCTION(double, Exp, exp) DELEGATE_SIMPLE_UNARY_FUNCTION(float, Log, log) DELEGATE_SIMPLE_UNARY_FUNCTION(double, Log, log) +DELEGATE_SIMPLE_UNARY_FUNCTION(float, Log1p, log1p) +DELEGATE_SIMPLE_UNARY_FUNCTION(double, Log1p, log1p) DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sin, sin) DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sin, sin) DELEGATE_SIMPLE_UNARY_FUNCTION(float, Asin, asin) diff --git a/caffe2/utils/math/elementwise.cu b/caffe2/utils/math/elementwise.cu index 11e61f41e06cb..778147a7b9eb9 100644 --- a/caffe2/utils/math/elementwise.cu +++ b/caffe2/utils/math/elementwise.cu @@ -340,6 +340,7 @@ CAFFE2_SPECIALIZED_CUDA_SET(at::BFloat16) } DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Exp, expf) DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Log, logf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Log1p, log1pf) DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sin, sinf) DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Asin, asinf) DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cos, cosf) @@ -416,6 +417,7 @@ DELEGATE_CUDA_POWX(float, powf) SinCosCUDAKernel \ <<cuda_stream()>>>( \ N, X, S, C); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } \ } CAFFE2_SPECIALIZED_CUDA_SINCOS(float) @@ -442,6 +444,7 @@ CAFFE2_SPECIALIZED_CUDA_SINCOS(double) ScaleCUDAKernel \ <<cuda_stream()>>>( \ N, alpha, X, Y); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } \ } \ template <> \ @@ -463,6 +466,7 @@ CAFFE2_SPECIALIZED_CUDA_SINCOS(double) ScaleCUDAKernel \ <<cuda_stream()>>>( \ N, alpha, X, Y); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } \ } DELEGATE_CUDA_SCALE(float, cublasSscal) @@ -500,6 +504,7 @@ DELEGATE_CUDA_SCALE(double, cublasDscal) ScaleCUDAKernel \ <<cuda_stream()>>>( \ N, alpha, X, Y); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } \ } \ template <> \ @@ -529,6 +534,7 @@ DELEGATE_CUDA_SCALE(double, cublasDscal) ScaleCUDAKernel \ <<cuda_stream()>>>( \ N, alpha, X, Y); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } \ } DELEGATE_CUDA_SCALE_EX(float, double, CUDA_R_32F, CUDA_R_64F, CUDA_R_64F) @@ -550,6 +556,7 @@ DELEGATE_CUDA_SCALE_EX(float, at::Half, CUDA_R_32F, CUDA_R_16F, CUDA_R_32F) ScaleCUDAKernel \ <<cuda_stream()>>>( \ N, alpha, X, Y); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } \ } \ template <> \ @@ -564,6 +571,7 @@ DELEGATE_CUDA_SCALE_EX(float, at::Half, CUDA_R_32F, CUDA_R_16F, CUDA_R_32F) ScaleCUDAKernel \ <<cuda_stream()>>>( \ N, *alpha, X, Y); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } \ } CAFFE2_SPECIALIZED_CUDA_SCALE(std::int32_t, std::int32_t) @@ -851,6 +859,7 @@ DELEGATE_CUDA_AXPY_EX(float, at::Half, CUDA_R_32F, CUDA_R_16F, CUDA_R_32F) AxpyCUDAKernel \ <<cuda_stream()>>>( \ N, alpha, X, Y); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } \ template <> \ CAFFE2_CUDA_EXPORT void Axpy( \ @@ -863,6 +872,7 @@ DELEGATE_CUDA_AXPY_EX(float, at::Half, CUDA_R_32F, CUDA_R_16F, CUDA_R_32F) AxpyCUDAKernel \ <<cuda_stream()>>>( \ N, alpha, X, Y); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } CAFFE2_SPECIALIZED_CUDA_AXPY(float, double) CAFFE2_SPECIALIZED_CUDA_AXPY(float, at::Half) @@ -883,6 +893,7 @@ CAFFE2_SPECIALIZED_CUDA_AXPY(float, at::Half) AxpbyCUDAKernel \ <<cuda_stream()>>>( \ N, alpha, X, beta, Y); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } \ template <> \ CAFFE2_CUDA_EXPORT void Axpby( \ @@ -896,6 +907,7 @@ CAFFE2_SPECIALIZED_CUDA_AXPY(float, at::Half) AxpbyCUDAKernel \ <<cuda_stream()>>>( \ N, alpha, X, beta, Y); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } CAFFE2_SPECIALIZED_CUDA_AXPBY(float, float) CAFFE2_SPECIALIZED_CUDA_AXPBY(float, double) diff --git a/caffe2/utils/math/elementwise.h b/caffe2/utils/math/elementwise.h index a6c77bb70c006..794709359f3cf 100644 --- a/caffe2/utils/math/elementwise.h +++ b/caffe2/utils/math/elementwise.h @@ -8,67 +8,69 @@ namespace caffe2 { namespace math { template -CAFFE2_API void Exp(int N, const T* X, T* Y, Context* context); +TORCH_API void Exp(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Log(int N, const T* X, T* Y, Context* context); +TORCH_API void Log(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Sin(int N, const T* X, T* Y, Context* context); +TORCH_API void Log1p(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Asin(int N, const T* X, T* Y, Context* context); +TORCH_API void Sin(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Cos(int N, const T* X, T* Y, Context* context); +TORCH_API void Asin(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Acos(int N, const T* X, T* Y, Context* context); +TORCH_API void Cos(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Tan(int N, const T* X, T* Y, Context* context); +TORCH_API void Acos(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Atan(int N, const T* X, T* Y, Context* context); +TORCH_API void Tan(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Sinh(int N, const T* X, T* Y, Context* context); +TORCH_API void Atan(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Cosh(int N, const T* X, T* Y, Context* context); +TORCH_API void Sinh(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void SinCos(int N, const T* X, T* S, T* C, Context* context); +TORCH_API void Cosh(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Tanh(int N, const T* X, T* Y, Context* context); +TORCH_API void SinCos(int N, const T* X, T* S, T* C, Context* context); template -CAFFE2_API void Abs(int N, const T* X, T* Y, Context* context); +TORCH_API void Tanh(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Sqr(int N, const T* X, T* Y, Context* context); +TORCH_API void Abs(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Sqrt(int N, const T* X, T* Y, Context* context); +TORCH_API void Sqr(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Rsqrt(int N, const T* X, T* Y, Context* context); +TORCH_API void Sqrt(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Cube(int N, const T* X, T* Y, Context* context); +TORCH_API void Rsqrt(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Cbrt(int N, const T* X, T* Y, Context* context); +TORCH_API void Cube(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Neg(int N, const T* X, T* Y, Context* context); +TORCH_API void Cbrt(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Sign(int N, const T* X, T* Y, Context* context); +TORCH_API void Neg(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Not(int N, const T* X, T* Y, Context* context); +TORCH_API void Sign(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Powx(int N, const T* A, const T b, T* Y, Context* context); +TORCH_API void Not(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Inv(int N, const T* X, T* Y, Context* context); +TORCH_API void Powx(int N, const T* A, const T b, T* Y, Context* context); template -CAFFE2_API void Erf(int N, const T* X, T* Y, Context* context); +TORCH_API void Inv(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void CdfNorm(int N, const T* X, T* Y, Context* context); +TORCH_API void Erf(int N, const T* X, T* Y, Context* context); +template +TORCH_API void CdfNorm(int N, const T* X, T* Y, Context* context); template -CAFFE2_API void Set(std::int64_t N, T alpha, T* X, Context* context); +TORCH_API void Set(std::int64_t N, T alpha, T* X, Context* context); template -CAFFE2_API void +TORCH_API void Scale(std::int64_t N, TAlpha alpha, const TData* X, TData* Y, Context* context); // Different from the Scale function above, if alpha is passed in as a pointer, // we will assume that it lives on the Context device, for example on GPU. template -CAFFE2_API void Scale( +TORCH_API void Scale( std::int64_t N, const TAlpha* alpha, const TData* X, @@ -76,58 +78,58 @@ CAFFE2_API void Scale( Context* context); template -CAFFE2_API void Add(int N, const T* A, const T* B, T* C, Context* context); +TORCH_API void Add(int N, const T* A, const T* B, T* C, Context* context); template -CAFFE2_API void Sub(int N, const T* A, const T* B, T* C, Context* context); +TORCH_API void Sub(int N, const T* A, const T* B, T* C, Context* context); template -CAFFE2_API void Mul(int N, const T* A, const T* B, T* C, Context* context); +TORCH_API void Mul(int N, const T* A, const T* B, T* C, Context* context); template -CAFFE2_API void Div(int N, const T* A, const T* B, T* C, Context* context); +TORCH_API void Div(int N, const T* A, const T* B, T* C, Context* context); template -CAFFE2_API void Min(int N, const T* A, const T* B, T* C, Context* context); +TORCH_API void Min(int N, const T* A, const T* B, T* C, Context* context); template -CAFFE2_API void Max(int N, const T* A, const T* B, T* C, Context* context); +TORCH_API void Max(int N, const T* A, const T* B, T* C, Context* context); template -CAFFE2_API void And(int N, const T* A, const T* B, T* C, Context* context); +TORCH_API void And(int N, const T* A, const T* B, T* C, Context* context); template -CAFFE2_API void Or(int N, const T* A, const T* B, T* C, Context* context); +TORCH_API void Or(int N, const T* A, const T* B, T* C, Context* context); template -CAFFE2_API void Xor(int N, const T* A, const T* B, T* C, Context* context); +TORCH_API void Xor(int N, const T* A, const T* B, T* C, Context* context); template -CAFFE2_API void +TORCH_API void BitwiseAnd(int N, const T* A, const T* B, T* C, Context* context); template -CAFFE2_API void +TORCH_API void BitwiseOr(int N, const T* A, const T* B, T* C, Context* context); template -CAFFE2_API void +TORCH_API void BitwiseXor(int N, const T* A, const T* B, T* C, Context* context); template -CAFFE2_API void EQ(int N, const T* A, const T* B, bool* C, Context* context); +TORCH_API void EQ(int N, const T* A, const T* B, bool* C, Context* context); template -CAFFE2_API void NE(int N, const T* A, const T* B, bool* C, Context* context); +TORCH_API void NE(int N, const T* A, const T* B, bool* C, Context* context); template -CAFFE2_API void LT(int N, const T* A, const T* B, bool* C, Context* context); +TORCH_API void LT(int N, const T* A, const T* B, bool* C, Context* context); template -CAFFE2_API void LE(int N, const T* A, const T* B, bool* C, Context* context); +TORCH_API void LE(int N, const T* A, const T* B, bool* C, Context* context); template -CAFFE2_API void GT(int N, const T* A, const T* B, bool* C, Context* context); +TORCH_API void GT(int N, const T* A, const T* B, bool* C, Context* context); template -CAFFE2_API void GE(int N, const T* A, const T* B, bool* C, Context* context); +TORCH_API void GE(int N, const T* A, const T* B, bool* C, Context* context); template -CAFFE2_API void +TORCH_API void Axpy(std::int64_t N, TAlpha alpha, const TData* X, TData* Y, Context* context); // Different from the Axpy function above, if alpha is passed in // as a pointer, we will assume that it lives on the Context device, // for example on GPU. template -CAFFE2_API void Axpy( +TORCH_API void Axpy( std::int64_t N, const TAlpha* alpha, const TData* X, @@ -135,7 +137,7 @@ CAFFE2_API void Axpy( Context* context); template -CAFFE2_API void Axpby( +TORCH_API void Axpby( std::int64_t N, TAlpha alpha, const TData* X, @@ -144,7 +146,7 @@ CAFFE2_API void Axpby( Context* context); template -CAFFE2_API void Axpby( +TORCH_API void Axpby( std::int64_t N, const TAlpha* alpha, const TData* X, diff --git a/caffe2/utils/math/reduce.cu b/caffe2/utils/math/reduce.cu index 9e811b19fe8db..d1969262c6771 100644 --- a/caffe2/utils/math/reduce.cu +++ b/caffe2/utils/math/reduce.cu @@ -156,6 +156,7 @@ void ReduceTensorCUDAImpl( ReduceTensorCUDAKernel <<cuda_stream()>>>( inner_size, X_strides, Y_dims, reducer, init, alpha, X, Y); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -188,12 +189,14 @@ void ReduceTensorCUDA( RowwiseReduceCUDAKernel <<cuda_stream()>>>( cols, reducer, init, alpha, X, Y); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return; } if (utils::IsColwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) { ColwiseReduceCUDAKernel <<cuda_stream()>>>( rows, cols, reducer, init, alpha, X, Y); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return; } int M; @@ -393,6 +396,7 @@ void MomentsCUDAImpl( MomentsCUDAKernel <<cuda_stream()>>>( inner_size, X_strides, Y_dims, X, mean, var); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } template @@ -430,12 +434,14 @@ void MomentsCUDA( RowwiseMomentsCUDAKernel <<cuda_stream()>>>( cols, X, mean, var); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return; } if (utils::IsColwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) { ColwiseMomentsCUDAKernel <<cuda_stream()>>>( rows, cols, X, mean, var); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return; } int M; diff --git a/caffe2/utils/math/reduce.cuh b/caffe2/utils/math/reduce.cuh index 937cd5075278b..0c43ad45a379a 100644 --- a/caffe2/utils/math/reduce.cuh +++ b/caffe2/utils/math/reduce.cuh @@ -21,12 +21,16 @@ using BlockReduce2D = cub:: if (size >= 128) { \ Func \ <<>>(__VA_ARGS__); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } else if (size >= 64) { \ Func<<>>(__VA_ARGS__); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } else if (size >= 32) { \ Func<<>>(__VA_ARGS__); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } else { \ Func<<>>(__VA_ARGS__); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } \ } while (false) @@ -36,15 +40,19 @@ using BlockReduce2D = cub:: if (size >= 128) { \ Func \ <<>>(__VA_ARGS__); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } else if (size >= 64) { \ Func \ <<>>(__VA_ARGS__); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } else if (size >= 32) { \ Func \ <<>>(__VA_ARGS__); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } else { \ Func \ <<>>(__VA_ARGS__); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } \ } while (false) diff --git a/caffe2/utils/math/reduce.h b/caffe2/utils/math/reduce.h index 7f8b83578280c..52d056d105bee 100644 --- a/caffe2/utils/math/reduce.h +++ b/caffe2/utils/math/reduce.h @@ -11,11 +11,11 @@ class Tensor; namespace math { template -CAFFE2_API void +TORCH_API void ReduceMin(const int N, const T* X, T* y, Tensor* scratch_ptr, Context* context); template -CAFFE2_API void +TORCH_API void ReduceMax(const int N, const T* X, T* y, Tensor* scratch_ptr, Context* context); // In all of the reduce functions, X_dims and Y_dims should have ndim elements. @@ -25,7 +25,7 @@ ReduceMax(const int N, const T* X, T* y, Tensor* scratch_ptr, Context* context); // Y = alpha * ReduceMin(X) template -CAFFE2_API void ReduceMin( +TORCH_API void ReduceMin( const int ndim, const int* X_dims, const int* Y_dims, @@ -36,7 +36,7 @@ CAFFE2_API void ReduceMin( // Y = alpha * ReduceMax(X) template -CAFFE2_API void ReduceMax( +TORCH_API void ReduceMax( const int ndim, const int* X_dims, const int* Y_dims, @@ -47,7 +47,7 @@ CAFFE2_API void ReduceMax( // Y = alpha * ReduceSum(X) template -CAFFE2_API void ReduceSum( +TORCH_API void ReduceSum( const int ndim, const int* X_dims, const int* Y_dims, @@ -58,7 +58,7 @@ CAFFE2_API void ReduceSum( // Y = alpha * ReduceMean(X) template -CAFFE2_API void ReduceMean( +TORCH_API void ReduceMean( const int ndim, const int* X_dims, const int* Y_dims, @@ -69,7 +69,7 @@ CAFFE2_API void ReduceMean( // Y = alpha * ReduceL1(X) template -CAFFE2_API void ReduceL1( +TORCH_API void ReduceL1( const int ndim, const int* X_dims, const int* Y_dims, @@ -80,7 +80,7 @@ CAFFE2_API void ReduceL1( // Y = alpha * ReduceL2(X) template -CAFFE2_API void ReduceL2( +TORCH_API void ReduceL2( const int ndim, const int* X_dims, const int* Y_dims, @@ -91,7 +91,7 @@ CAFFE2_API void ReduceL2( // Computes mean and variance over axes. template -CAFFE2_API void Moments( +TORCH_API void Moments( const int ndims, const int* X_dims, const int* Y_dims, diff --git a/caffe2/utils/math/transpose.cu b/caffe2/utils/math/transpose.cu index a02a8d207a0ec..4474d38311ad2 100644 --- a/caffe2/utils/math/transpose.cu +++ b/caffe2/utils/math/transpose.cu @@ -68,6 +68,7 @@ void BatchTranspose2DCUDAImpl( BatchTranspose2DCUDAKernel <<cuda_stream()>>>( N, H, W, dh, dw, X, Y); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } #define DELEGATE_TRANSPOSE_2D_CUDA_IMPL(TIndex, TData, CuBLASFunc) \ @@ -106,6 +107,7 @@ void BatchTranspose2DCUDAImpl( dim3(kTileDim, kBlockRows), \ 0, \ context->cuda_stream()>>>(N, H, W, dh, dw, X, Y); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ } \ } DELEGATE_TRANSPOSE_2D_CUDA_IMPL(std::int32_t, float, cublasSgeam) @@ -157,6 +159,7 @@ void TransposeCUDAImpl( TransposeCUDAKernel <<cuda_stream()>>>( size, X_strides, Y_dims, X, Y); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } } // namespace diff --git a/caffe2/utils/math/transpose.h b/caffe2/utils/math/transpose.h index a01caa2aaf10a..ca3d7fd859b67 100644 --- a/caffe2/utils/math/transpose.h +++ b/caffe2/utils/math/transpose.h @@ -9,7 +9,7 @@ namespace math { // Transpose tensor X with dims by axes and write the result to tensor Y. template -CAFFE2_API void Transpose( +TORCH_API void Transpose( int ndim, const TIndex* dims, const int* axes, @@ -18,11 +18,11 @@ CAFFE2_API void Transpose( Context* context); template -CAFFE2_API void +TORCH_API void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y, Context* context); template -CAFFE2_API void +TORCH_API void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y, Context* context); } // namespace math diff --git a/caffe2/utils/math/utils.h b/caffe2/utils/math/utils.h index 473db41a979fe..88b9f7c2efab3 100644 --- a/caffe2/utils/math/utils.h +++ b/caffe2/utils/math/utils.h @@ -61,35 +61,35 @@ MATH_UTILS_DECL bool IsAGeZeroAndALtB(const int a, const int b) { // Increase the index digits by one based on dims. template -CAFFE2_API void +TORCH_API void IncreaseIndexInDims(int ndim, const TIndex* dims, TIndex* index); // Get index value from dims and index digits. template -CAFFE2_API TIndex +TORCH_API TIndex GetIndexFromDims(const int n, const TIndex* dims, const TIndex* index); // Checks if the input permutation is an identity permutation; -CAFFE2_API bool IsIdentityPermutation(const int n, const int* perm); +TORCH_API bool IsIdentityPermutation(const int n, const int* perm); -CAFFE2_API bool +TORCH_API bool CheckReduceDims(const int ndim, const int* X_dims, const int* Y_dims); -CAFFE2_API bool IsRowwiseReduce( +TORCH_API bool IsRowwiseReduce( const int ndim, const int* X_dims, const int* Y_dims, int* rows, int* cols); -CAFFE2_API bool IsColwiseReduce( +TORCH_API bool IsColwiseReduce( const int ndim, const int* X_dims, const int* Y_dims, int* rows, int* cols); -CAFFE2_API bool IsBothEndsReduce( +TORCH_API bool IsBothEndsReduce( const int ndim, const int* X_dims, const int* Y_dims, @@ -99,7 +99,7 @@ CAFFE2_API bool IsBothEndsReduce( // Computest the broadcast binary operation dims. template -CAFFE2_API void ComputeBroadcastBinaryOpDims( +TORCH_API void ComputeBroadcastBinaryOpDims( const int A_ndim, const TIndex* A_dims, const int B_ndim, @@ -108,7 +108,7 @@ CAFFE2_API void ComputeBroadcastBinaryOpDims( TIndex* B_broadcast_dims, TIndex* C_broadcast_dims); -CAFFE2_API bool IsRowwiseBroadcastBinaryOp( +TORCH_API bool IsRowwiseBroadcastBinaryOp( const int ndim, const int* A_dims, const int* B_dims, @@ -116,7 +116,7 @@ CAFFE2_API bool IsRowwiseBroadcastBinaryOp( int* cols, bool* broadcast_1st); -CAFFE2_API bool IsColwiseBroadcastBinaryOp( +TORCH_API bool IsColwiseBroadcastBinaryOp( const int ndim, const int* A_dims, const int* B_dims, @@ -124,7 +124,7 @@ CAFFE2_API bool IsColwiseBroadcastBinaryOp( int* cols, bool* broadcast_1st); -CAFFE2_API bool IsBothEndsBroadcastBinaryOp( +TORCH_API bool IsBothEndsBroadcastBinaryOp( const int ndim, const int* A_dims, const int* B_dims, @@ -133,19 +133,19 @@ CAFFE2_API bool IsBothEndsBroadcastBinaryOp( int* nxt, bool* broadcast_1st); -CAFFE2_API bool IsBatchTranspose2D(const int ndim, const int* axes); +TORCH_API bool IsBatchTranspose2D(const int ndim, const int* axes); -CAFFE2_API void ComputeTransposeAxesForReduceOp( +TORCH_API void ComputeTransposeAxesForReduceOp( const int num_dims, const int num_reduce_axes, const int* reduce_axes, int* transpose_axes); -CAFFE2_API void +TORCH_API void ComputeTransposeAxesForReduceOp(const int ndim, const int* dims, int* axes); template -CAFFE2_API void ComputeTransposedStrides( +TORCH_API void ComputeTransposedStrides( int ndim, const TIndex* dims, const int* axes, diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc index 9cf59c20ad190..339d9ef8ff460 100644 --- a/caffe2/utils/math_cpu.cc +++ b/caffe2/utils/math_cpu.cc @@ -110,7 +110,7 @@ C10_EXPORT void Gemm( return; default: LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for trans_B"; - return; // The line above calls `abort()`. Should never reach here. + return; // The line above calls `abort()`. Should never reach here. } } case CblasTrans: { @@ -127,7 +127,7 @@ C10_EXPORT void Gemm( return; default: LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for trans_B"; - return; // The line above calls `abort()`. Should never reach here. + return; // The line above calls `abort()`. Should never reach here. } } default: @@ -177,7 +177,7 @@ C10_EXPORT void GemmEx( return; default: LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for trans_B"; - return; // The line above calls `abort()`. Should never reach here. + return; // The line above calls `abort()`. Should never reach here. } } case CblasTrans: { @@ -201,7 +201,7 @@ C10_EXPORT void GemmEx( return; default: LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for trans_B"; - return; // The line above calls `abort()`. Should never reach here. + return; // The line above calls `abort()`. Should never reach here. } } default: @@ -1065,11 +1065,23 @@ DEFINE_BROADCAST_BITWISE_BINARY_FUNCTION(BitwiseXor, std::bit_xor) #undef DELEGATE_BROADCAST_BINARY_FUNCTION +namespace { +// incrementIfNotMax increments the number if the value is not max for that +// datatype. This ensures that the value never overflows. +template +inline T incrementIfNotMax(T a) { + if (a == std::numeric_limits::max()) { + return a; + } + return a + 1; +} +} // namespace + #define CAFFE2_RAND_UNIFORM_REAL(T) \ template <> \ C10_EXPORT void RandUniform( \ const size_t n, const T a, const T b, T* r, CPUContext* context) { \ - std::uniform_real_distribution distribution(a, b); \ + at::uniform_real_distribution distribution(a, b); \ for (size_t i = 0; i < n; ++i) { \ r[i] = distribution(context->RandGenerator()); \ } \ @@ -1078,14 +1090,15 @@ CAFFE2_RAND_UNIFORM_REAL(float); CAFFE2_RAND_UNIFORM_REAL(double); #undef CAFFE2_RAND_UNIFORM_REAL -#define CAFFE2_RAND_UNIFORM_CHAR(T) \ - template <> \ - C10_EXPORT void RandUniform( \ - const size_t n, const T a, const T b, T* r, CPUContext* context) { \ - std::uniform_int_distribution distribution((short)a, (short)b); \ - for (size_t i = 0; i < n; ++i) { \ - r[i] = static_cast(distribution(context->RandGenerator())); \ - } \ +#define CAFFE2_RAND_UNIFORM_CHAR(T) \ + template <> \ + C10_EXPORT void RandUniform( \ + const size_t n, const T a, const T b, T* r, CPUContext* context) { \ + at::uniform_int_from_to_distribution distribution( \ + incrementIfNotMax(b - a), a); \ + for (size_t i = 0; i < n; ++i) { \ + r[i] = static_cast(distribution(context->RandGenerator())); \ + } \ } CAFFE2_RAND_UNIFORM_CHAR(int8_t); CAFFE2_RAND_UNIFORM_CHAR(uint8_t); @@ -1095,7 +1108,10 @@ CAFFE2_RAND_UNIFORM_CHAR(uint8_t); template <> \ C10_EXPORT void RandUniform( \ const size_t n, const T a, const T b, T* r, CPUContext* context) { \ - std::uniform_int_distribution distribution(a, b); \ + at::uniform_int_from_to_distribution distribution( \ + incrementIfNotMax( \ + static_cast(b) - static_cast(a)), \ + a); \ for (size_t i = 0; i < n; ++i) { \ r[i] = distribution(context->RandGenerator()); \ } \ @@ -1135,7 +1151,7 @@ CAFFE2_RAND_UNIFORM_INT(uint64_t); auto remaining_numbers = n - 1 - i; \ double mean = (sum - current_sum) / (remaining_numbers + 1); \ double stdev = std::min(mean - a, b - mean); \ - std::normal_distribution distribution{mean, stdev / 4.0}; \ + at::normal_distribution distribution{mean, stdev / 4.0}; \ T value, remaining_sum_test; \ do { \ value = distribution(context->RandGenerator()); \ @@ -1350,7 +1366,8 @@ CAFFE2_RAND_SYNTHETIC_DATA(uint64_t); CAFFE_ENFORCE_EQ( \ m, avoid_set.size(), "AC10_EXPORT void should be unique"); \ } \ - std::uniform_int_distribution distribution(a, b); \ + at::uniform_int_from_to_distribution distribution( \ + incrementIfNotMax(b - a), a); \ T v = 0; \ for (size_t i = 0; i < n; ++i) { \ do { \ @@ -1372,7 +1389,7 @@ C10_EXPORT void RandGaussian( const float std, float* r, CPUContext* context) { - std::normal_distribution distribution(mean, std); + at::normal_distribution distribution(mean, std); for (size_t i = 0; i < n; ++i) { r[i] = distribution(context->RandGenerator()); } diff --git a/caffe2/utils/proto_utils.cc b/caffe2/utils/proto_utils.cc index e8e42e1bbf3e3..70484643cded0 100644 --- a/caffe2/utils/proto_utils.cc +++ b/caffe2/utils/proto_utils.cc @@ -53,7 +53,6 @@ C10_EXPORT bool IsCPUDeviceType(int device_type) { PROTO_CPU, PROTO_MKLDNN, PROTO_IDEEP, - PROTO_ONLY_FOR_TEST, }; return cpu_types.count(device_type); } @@ -217,10 +216,17 @@ C10_EXPORT bool ReadProtoFromTextFile(const char* filename, Message* proto) { C10_EXPORT void WriteProtoToTextFile( const Message& proto, - const char* filename) { + const char* filename, + bool throwIfError) { int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); FileOutputStream* output = new FileOutputStream(fd); - CAFFE_ENFORCE(google::protobuf::TextFormat::Print(proto, output)); + if(!google::protobuf::TextFormat::Print(proto, output)) { + if (throwIfError) { + CAFFE_THROW("Cannot write proto to text file: ", filename); + } else { + LOG(ERROR) << "Cannot write proto to text file: " << filename; + } + } delete output; close(fd); } diff --git a/caffe2/utils/proto_utils.h b/caffe2/utils/proto_utils.h index b5e1b52f742d2..47f6c2534f006 100644 --- a/caffe2/utils/proto_utils.h +++ b/caffe2/utils/proto_utils.h @@ -23,27 +23,27 @@ using ::google::protobuf::MessageLite; // Note that we can't use DeviceType_Name, because that is only available in // protobuf-full, and some platforms (like mobile) may want to use // protobuf-lite instead. -CAFFE2_API std::string DeviceTypeName(const int32_t& d); +TORCH_API std::string DeviceTypeName(const int32_t& d); -CAFFE2_API int DeviceId(const DeviceOption& option); +TORCH_API int DeviceId(const DeviceOption& option); // Returns if the two DeviceOptions are pointing to the same device. -CAFFE2_API bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs); +TORCH_API bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs); -CAFFE2_API bool IsCPUDeviceType(int device_type); -CAFFE2_API bool IsGPUDeviceType(int device_type); +TORCH_API bool IsCPUDeviceType(int device_type); +TORCH_API bool IsGPUDeviceType(int device_type); // Common interfaces that reads file contents into a string. -CAFFE2_API bool ReadStringFromFile(const char* filename, string* str); -CAFFE2_API bool WriteStringToFile(const string& str, const char* filename); +TORCH_API bool ReadStringFromFile(const char* filename, string* str); +TORCH_API bool WriteStringToFile(const string& str, const char* filename); // Common interfaces that are supported by both lite and full protobuf. -CAFFE2_API bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto); +TORCH_API bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto); inline bool ReadProtoFromBinaryFile(const string filename, MessageLite* proto) { return ReadProtoFromBinaryFile(filename.c_str(), proto); } -CAFFE2_API void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename); +TORCH_API void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename); inline void WriteProtoToBinaryFile(const MessageLite& proto, const string& filename) { return WriteProtoToBinaryFile(proto, filename.c_str()); @@ -60,9 +60,9 @@ inline bool ParseFromString(const string& spec, MessageLite* proto) { } // namespace TextFormat -CAFFE2_API string ProtoDebugString(const MessageLite& proto); +TORCH_API string ProtoDebugString(const MessageLite& proto); -CAFFE2_API bool ParseProtoFromLargeString(const string& str, MessageLite* proto); +TORCH_API bool ParseProtoFromLargeString(const string& str, MessageLite* proto); // Text format MessageLite wrappers: these functions do nothing but just // allowing things to compile. It will produce a runtime error if you are using @@ -80,13 +80,15 @@ inline bool ReadProtoFromTextFile(const string filename, MessageLite* proto) { inline void WriteProtoToTextFile( const MessageLite& /*proto*/, - const char* /*filename*/) { + const char* /*filename*/, + bool throwIfError = true) { LOG(FATAL) << "If you are running lite version, you should not be " << "calling any text-format protobuffers."; } inline void WriteProtoToTextFile(const MessageLite& proto, - const string& filename) { - return WriteProtoToTextFile(proto, filename.c_str()); + const string& filename, + bool throwIfError = true) { + return WriteProtoToTextFile(proto, filename.c_str(), throwIfError); } inline bool ReadProtoFromFile(const char* filename, MessageLite* proto) { @@ -103,21 +105,21 @@ inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) { using ::google::protobuf::Message; namespace TextFormat { -CAFFE2_API bool ParseFromString(const string& spec, Message* proto); +TORCH_API bool ParseFromString(const string& spec, Message* proto); } // namespace TextFormat -CAFFE2_API string ProtoDebugString(const Message& proto); +TORCH_API string ProtoDebugString(const Message& proto); -CAFFE2_API bool ParseProtoFromLargeString(const string& str, Message* proto); +TORCH_API bool ParseProtoFromLargeString(const string& str, Message* proto); -CAFFE2_API bool ReadProtoFromTextFile(const char* filename, Message* proto); +TORCH_API bool ReadProtoFromTextFile(const char* filename, Message* proto); inline bool ReadProtoFromTextFile(const string filename, Message* proto) { return ReadProtoFromTextFile(filename.c_str(), proto); } -CAFFE2_API void WriteProtoToTextFile(const Message& proto, const char* filename); -inline void WriteProtoToTextFile(const Message& proto, const string& filename) { - return WriteProtoToTextFile(proto, filename.c_str()); +TORCH_API void WriteProtoToTextFile(const Message& proto, const char* filename, bool throwIfError = true); +inline void WriteProtoToTextFile(const Message& proto, const string& filename, bool throwIfError = true) { + return WriteProtoToTextFile(proto, filename.c_str(), throwIfError); } // Read Proto from a file, letting the code figure out if it is text or binary. @@ -187,8 +189,8 @@ inline OperatorDef CreateOperatorDef( engine); } -CAFFE2_API bool HasOutput(const OperatorDef& op, const std::string& output); -CAFFE2_API bool HasInput(const OperatorDef& op, const std::string& input); +TORCH_API bool HasOutput(const OperatorDef& op, const std::string& output); +TORCH_API bool HasInput(const OperatorDef& op, const std::string& input); /** * @brief A helper class to index into arguments. @@ -297,36 +299,36 @@ class C10_EXPORT ArgumentHelper { // Helper methods to get an argument from OperatorDef or NetDef given argument // name. Throws if argument does not exist. -CAFFE2_API const Argument& GetArgument(const OperatorDef& def, const string& name); -CAFFE2_API const Argument& GetArgument(const NetDef& def, const string& name); +TORCH_API const Argument& GetArgument(const OperatorDef& def, const string& name); +TORCH_API const Argument& GetArgument(const NetDef& def, const string& name); // Helper methods to get an argument from OperatorDef or NetDef given argument // name. Returns nullptr if argument does not exist. -CAFFE2_API const Argument* GetArgumentPtr(const OperatorDef& def, const string& name); -CAFFE2_API const Argument* GetArgumentPtr(const NetDef& def, const string& name); +TORCH_API const Argument* GetArgumentPtr(const OperatorDef& def, const string& name); +TORCH_API const Argument* GetArgumentPtr(const NetDef& def, const string& name); // Helper methods to query a boolean argument flag from OperatorDef or NetDef // given argument name. If argument does not exist, return default value. // Throws if argument exists but the type is not boolean. -CAFFE2_API bool GetFlagArgument( +TORCH_API bool GetFlagArgument( const OperatorDef& def, const string& name, bool default_value = false); -CAFFE2_API bool GetFlagArgument( +TORCH_API bool GetFlagArgument( const NetDef& def, const string& name, bool default_value = false); -CAFFE2_API Argument* GetMutableArgument( +TORCH_API Argument* GetMutableArgument( const string& name, const bool create_if_missing, OperatorDef* def); -CAFFE2_API Argument* GetMutableArgument( +TORCH_API Argument* GetMutableArgument( const string& name, const bool create_if_missing, NetDef* def); template -CAFFE2_API Argument MakeArgument(const string& name, const T& value); +TORCH_API Argument MakeArgument(const string& name, const T& value); template inline void AddArgument(const string& name, const T& value, Def* def) { @@ -345,7 +347,7 @@ bool inline operator==(const DeviceOption& dl, const DeviceOption& dr) { // - Going through list of ops in order, all op inputs must be outputs // from other ops, or registered as external inputs. // - All external outputs must be outputs of some operators. -CAFFE2_API void cleanupExternalInputsAndOutputs(NetDef* net); +TORCH_API void cleanupExternalInputsAndOutputs(NetDef* net); } // namespace caffe2 diff --git a/caffe2/utils/proto_wrap.cc b/caffe2/utils/proto_wrap.cc index eb06524cae841..6899a5d818539 100644 --- a/caffe2/utils/proto_wrap.cc +++ b/caffe2/utils/proto_wrap.cc @@ -9,7 +9,7 @@ namespace ONNX_NAMESPACE { // ONNX wrapper functions for protobuf's GetEmptyStringAlreadyInited() function // used to avoid duplicated global variable in the case when protobuf // is built with hidden visibility. -CAFFE2_API const ::std::string& GetEmptyStringAlreadyInited() { +TORCH_API const ::std::string& GetEmptyStringAlreadyInited() { return ::google::protobuf::internal::GetEmptyStringAlreadyInited(); } @@ -20,7 +20,7 @@ namespace caffe2 { // Caffe2 wrapper functions for protobuf's GetEmptyStringAlreadyInited() function // used to avoid duplicated global variable in the case when protobuf // is built with hidden visibility. -CAFFE2_API const ::std::string& GetEmptyStringAlreadyInited() { +TORCH_API const ::std::string& GetEmptyStringAlreadyInited() { return ::google::protobuf::internal::GetEmptyStringAlreadyInited(); } @@ -35,7 +35,7 @@ namespace torch { // Caffe2 wrapper functions for protobuf's GetEmptyStringAlreadyInited() function // used to avoid duplicated global variable in the case when protobuf // is built with hidden visibility. -CAFFE2_API const ::std::string& GetEmptyStringAlreadyInited() { +TORCH_API const ::std::string& GetEmptyStringAlreadyInited() { return ::google::protobuf::internal::GetEmptyStringAlreadyInited(); } diff --git a/caffe2/utils/proto_wrap.h b/caffe2/utils/proto_wrap.h index 92cb2b4227a3a..bcbce663c0d2a 100644 --- a/caffe2/utils/proto_wrap.h +++ b/caffe2/utils/proto_wrap.h @@ -7,7 +7,7 @@ namespace caffe2 { // A wrapper function to shut down protobuf library (this is needed in ASAN // testing and valgrind cases to avoid protobuf appearing to "leak" memory). -CAFFE2_API void ShutdownProtobufLibrary(); +TORCH_API void ShutdownProtobufLibrary(); } // namespace caffe2 diff --git a/caffe2/utils/signal_handler.h b/caffe2/utils/signal_handler.h index c773bdd4393ad..9e0bc2ad2f157 100644 --- a/caffe2/utils/signal_handler.h +++ b/caffe2/utils/signal_handler.h @@ -11,7 +11,7 @@ namespace caffe2 { -class CAFFE2_API SignalHandler { +class TORCH_API SignalHandler { public: enum class Action { NONE, @@ -38,8 +38,8 @@ class CAFFE2_API SignalHandler { // This works by setting up certain fatal signal handlers. Previous fatal // signal handlers will still be called when the signal is raised. Defaults // to being off. -CAFFE2_API void setPrintStackTracesOnFatalSignal(bool print); -CAFFE2_API bool printStackTracesOnFatalSignal(); +TORCH_API void setPrintStackTracesOnFatalSignal(bool print); +TORCH_API bool printStackTracesOnFatalSignal(); #endif // defined(CAFFE2_SUPPORTS_SIGNAL_HANDLER) } // namespace caffe2 diff --git a/caffe2/utils/smart_tensor_printer.h b/caffe2/utils/smart_tensor_printer.h index 48e1e47cf8efc..e6d96ef37ae01 100644 --- a/caffe2/utils/smart_tensor_printer.h +++ b/caffe2/utils/smart_tensor_printer.h @@ -8,7 +8,7 @@ namespace caffe2 { // explicit specify the type of the tensor while calling the Print() method. // It also supports a convenience function with a default constructed printer as // a static method. -class CAFFE2_API SmartTensorPrinter { +class TORCH_API SmartTensorPrinter { public: // The proliferation of constructors is to give the feature parity with // TensorPrinter diff --git a/caffe2/utils/string_utils.h b/caffe2/utils/string_utils.h index bd13b723eda35..e959a467da025 100644 --- a/caffe2/utils/string_utils.h +++ b/caffe2/utils/string_utils.h @@ -9,17 +9,17 @@ namespace caffe2 { -CAFFE2_API std::vector +TORCH_API std::vector split(char separator, const std::string& string, bool ignore_empty = false); -CAFFE2_API std::string trim(const std::string& str); +TORCH_API std::string trim(const std::string& str); -CAFFE2_API size_t editDistance( +TORCH_API size_t editDistance( const std::string& s1, const std::string& s2, size_t max_distance = 0); -CAFFE2_API inline bool StartsWith( +TORCH_API inline bool StartsWith( const std::string& str, const std::string& prefix) { return str.length() >= prefix.length() && @@ -27,7 +27,7 @@ CAFFE2_API inline bool StartsWith( prefix.end(); } -CAFFE2_API inline bool EndsWith( +TORCH_API inline bool EndsWith( const std::string& full, const std::string& ending) { if (full.length() >= ending.length()) { @@ -39,7 +39,7 @@ CAFFE2_API inline bool EndsWith( } } -CAFFE2_API int32_t editDistanceHelper( +TORCH_API int32_t editDistanceHelper( const char* s1, size_t s1_len, const char* s2, diff --git a/caffe2/utils/threadpool/ThreadPool.cc b/caffe2/utils/threadpool/ThreadPool.cc index 00ecfb3ed64ce..6010a86ab1239 100644 --- a/caffe2/utils/threadpool/ThreadPool.cc +++ b/caffe2/utils/threadpool/ThreadPool.cc @@ -15,6 +15,8 @@ C10_DEFINE_int(caffe2_threadpool_android_cap, true, ""); // Whether or not threadpool caps apply to iOS C10_DEFINE_int(caffe2_threadpool_ios_cap, true, ""); +C10_DEFINE_int(pthreadpool_size, 0, "Override the default thread pool size."); + namespace caffe2 { size_t getDefaultNumThreads() { @@ -69,6 +71,11 @@ size_t getDefaultNumThreads() { break; } } + + if (FLAGS_pthreadpool_size) { + // Always give precedence to explicit setting. + numThreads = FLAGS_pthreadpool_size; + } return numThreads; } diff --git a/caffe2/utils/threadpool/ThreadPool.h b/caffe2/utils/threadpool/ThreadPool.h index 5165764fe9380..951b8f7f6befd 100644 --- a/caffe2/utils/threadpool/ThreadPool.h +++ b/caffe2/utils/threadpool/ThreadPool.h @@ -29,8 +29,8 @@ constexpr size_t kCacheLineSize = 64; // misaligned intrinsics, no SSE instructions shall be involved in // the ThreadPool implementation. // Note: alignas is disabled because some compilers do not deal with -// CAFFE2_API and alignas annotations at the same time. -class CAFFE2_API /*alignas(kCacheLineSize)*/ ThreadPool { +// TORCH_API and alignas annotations at the same time. +class TORCH_API /*alignas(kCacheLineSize)*/ ThreadPool { public: static std::unique_ptr defaultThreadPool(); ThreadPool(int numThreads); diff --git a/caffe2/utils/threadpool/pthreadpool.h b/caffe2/utils/threadpool/pthreadpool.h index 27935febe45e7..0c6cc3661e05f 100644 --- a/caffe2/utils/threadpool/pthreadpool.h +++ b/caffe2/utils/threadpool/pthreadpool.h @@ -8,6 +8,25 @@ #include // for size_t #include // for uint32_t +#ifdef USE_PTHREADPOOL +// This is a hack. +// Mainly introduced here because +// 1. NNPACK can be compiled to use internal legacy threadpool implementation because much of C2 depends on that. +// 2. Then if we want to use NNPACK in PyTorch, which uses new pthreadpool, then we will supply new pthreadpool pointer +// to NNPACK. This will not work if NNPACK is compiled with internal legacy threadpool. Thus this guard +// along with changes in pthreadpool_impl.cc allows us to override that behavior. +// It enables us to use NNPACK from pytorch using `caffe2::pthreadpool_()` +namespace caffe2 { +class WithCastToNewThreadPool { + public: + explicit WithCastToNewThreadPool(bool use_new_threadpool); + ~WithCastToNewThreadPool(); + private: + bool use_new_threadpool_; +}; +} +#endif + typedef struct pthreadpool* legacy_pthreadpool_t; typedef void (*legacy_pthreadpool_function_1d_t)(void*, size_t); diff --git a/caffe2/utils/threadpool/pthreadpool_impl.cc b/caffe2/utils/threadpool/pthreadpool_impl.cc index 66326eef7a7b0..8165ae3571ca7 100644 --- a/caffe2/utils/threadpool/pthreadpool_impl.cc +++ b/caffe2/utils/threadpool/pthreadpool_impl.cc @@ -1,6 +1,21 @@ #include "caffe2/utils/threadpool/pthreadpool.h" +#include "caffe2/utils/threadpool/pthreadpool-cpp.h" #include "caffe2/utils/threadpool/ThreadPool.h" +#ifdef USE_PTHREADPOOL +namespace caffe2 { +namespace { +static thread_local bool using_new_threadpool{false}; +} +WithCastToNewThreadPool::WithCastToNewThreadPool(bool use_new_threadpool) { + use_new_threadpool_ = using_new_threadpool; + using_new_threadpool = use_new_threadpool; +} +WithCastToNewThreadPool::~WithCastToNewThreadPool() { + using_new_threadpool = use_new_threadpool_; +} +} +#endif // // External API @@ -19,12 +34,25 @@ void legacy_pthreadpool_compute_1d( } return; } +#ifdef USE_PTHREADPOOL + if (caffe2::using_new_threadpool) { + pthreadpool_parallelize_1d(threadpool, function, argument, range, 0u); + } else { + reinterpret_cast(threadpool) + ->run( + [function, argument](int threadId, size_t workId) { + function(argument, workId); + }, + range); + } +#else reinterpret_cast(threadpool) ->run( [function, argument](int threadId, size_t workId) { function(argument, workId); }, range); +#endif } void legacy_pthreadpool_parallelize_1d( diff --git a/caffe2/video/video_decoder.h b/caffe2/video/video_decoder.h index 5286d52dc7dba..a091142389d63 100644 --- a/caffe2/video/video_decoder.h +++ b/caffe2/video/video_decoder.h @@ -477,11 +477,11 @@ class VideoDecoder { Callback& callback); }; -CAFFE2_API void FreeDecodedData( +TORCH_API void FreeDecodedData( std::vector>& sampledFrames, std::vector>& sampledAudio); -CAFFE2_API bool DecodeMultipleClipsFromVideo( +TORCH_API bool DecodeMultipleClipsFromVideo( const char* video_buffer, const std::string& video_filename, const int encoded_size, diff --git a/caffe2/video/video_io.h b/caffe2/video/video_io.h index a25e87e61a601..beefd7b0782d9 100644 --- a/caffe2/video/video_io.h +++ b/caffe2/video/video_io.h @@ -12,7 +12,7 @@ namespace caffe2 { -CAFFE2_API void ClipTransformRGB( +TORCH_API void ClipTransformRGB( const unsigned char* buffer_rgb, const int crop_size, const int length_rgb, @@ -27,7 +27,7 @@ CAFFE2_API void ClipTransformRGB( const std::vector& inv_std_rgb, float* transformed_clip); -CAFFE2_API void ClipTransformOpticalFlow( +TORCH_API void ClipTransformOpticalFlow( const unsigned char* buffer_rgb, const int crop_size, const int length_of, diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index 63e2d9f4d9344..a9d2e4f50e457 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -110,6 +110,12 @@ if(INTERN_BUILD_ATEN_OPS) endif(MSVC) endif(CXX_AVX2_FOUND) + if(CXX_VSX_FOUND) + SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_VSX_CPU_DEFINITION") + LIST(APPEND CPU_CAPABILITY_NAMES "VSX") + LIST(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} ${CXX_VSX_FLAGS}") + endif(CXX_VSX_FOUND) + list(LENGTH CPU_CAPABILITY_NAMES NUM_CPU_CAPABILITY_NAMES) math(EXPR NUM_CPU_CAPABILITY_NAMES "${NUM_CPU_CAPABILITY_NAMES}-1") @@ -144,7 +150,7 @@ if(INTERN_BUILD_ATEN_OPS) endforeach() list(APPEND ATen_CPU_SRCS ${cpu_kernel_cpp}) - file(GLOB all_python "${CMAKE_CURRENT_LIST_DIR}/../tools/codegen/*.py") + file(GLOB_RECURSE all_python "${CMAKE_CURRENT_LIST_DIR}/../tools/codegen/*.py") set(GEN_ROCM_FLAG) if(USE_ROCM) @@ -167,7 +173,7 @@ if(INTERN_BUILD_ATEN_OPS) endif() execute_process( COMMAND - "${PYTHON_EXECUTABLE}" ${CMAKE_CURRENT_LIST_DIR}/../tools/code_analyzer/gen_op_registration_whitelist.py + "${PYTHON_EXECUTABLE}" ${CMAKE_CURRENT_LIST_DIR}/../tools/code_analyzer/gen_op_registration_allowlist.py --op-dependency "${OP_DEPENDENCY}" --root-ops "${SELECTED_OP_LIST}" OUTPUT_VARIABLE OP_REGISTRATION_WHITELIST @@ -178,9 +184,6 @@ if(INTERN_BUILD_ATEN_OPS) --force_schema_registration --op_registration_whitelist ${OP_REGISTRATION_WHITELIST}) endif() - if(USE_VULKAN) - set(GEN_VULKAN_FLAGS --vulkan) - endif() set(GEN_COMMAND "${PYTHON_EXECUTABLE}" -m tools.codegen.gen diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 028098f61d360..968456c404903 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -46,6 +46,8 @@ endif() # 3. If MSVC_Z7_OVERRIDE is ON, then /Zi and /ZI will be replaced with /Z7 # for Debug and RelWithDebInfo builds if(MSVC) + # skip unwanted includes from windows.h + add_definitions(-DWIN32_LEAN_AND_MEAN) foreach(flag_var CMAKE_C_FLAGS CMAKE_C_FLAGS_RELEASE CMAKE_C_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_RELEASE CMAKE_CXX_FLAGS_MINSIZEREL) @@ -107,6 +109,8 @@ if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO) endif() # ---[ BLAS + +# setting default preferred BLAS options if not already present. if(NOT INTERN_BUILD_MOBILE) set(BLAS "MKL" CACHE STRING "Selected BLAS library") else() @@ -146,7 +150,6 @@ elseif(BLAS STREQUAL "MKL") set(CAFFE2_USE_MKL ON) else() message(WARNING "MKL could not be found. Defaulting to Eigen") - set(BLAS "Eigen" CACHE STRING "Selected BLAS library") set(CAFFE2_USE_EIGEN_FOR_BLAS ON) set(CAFFE2_USE_MKL OFF) endif() @@ -173,9 +176,6 @@ if(NOT INTERN_BUILD_MOBILE) find_package(BLAS) if(NOT BLAS_FOUND) set(USE_BLAS 0) - set(BLAS "" CACHE STRING "Selected BLAS library") - else() - set(BLAS BLAS_INFO CACHE STRING "Selected BLAS library") endif() endif() @@ -190,6 +190,26 @@ if(NOT INTERN_BUILD_MOBILE) endif() set(AT_MKL_ENABLED 1) endif() +elseif(INTERN_USE_EIGEN_BLAS) + # Eigen BLAS for Mobile + set(USE_BLAS 1) + include(${CMAKE_CURRENT_LIST_DIR}/External/EigenBLAS.cmake) + list(APPEND Caffe2_DEPENDENCY_LIBS eigen_blas) +endif() + +# ---[ FFTW +set(AT_FFTW_ENABLED 0) +set(USE_FFTW OFF) +if(USE_FFTW OR NOT MKL_FOUND) + find_library(LIBFFTW3 fftw3) + if(LIBFFTW3) + find_path(FFTW3_INCLUDE_DIR NAMES fftw3.h ONLY_CMAKE_FIND_ROOT_PATH) + if(FFTW3_INCLUDE_DIR) + SET(AT_FFTW_ENABLED 1) + SET(USE_FFTW ON) + include_directories(${FFTW3_INCLUDE_DIR}) + endif() + endif() endif() # ---[ Dependencies @@ -222,7 +242,7 @@ if(USE_NNPACK OR USE_QNNPACK OR USE_PYTORCH_QNNPACK OR USE_XNNPACK) set(DISABLE_NNPACK_AND_FAMILY ON) endif() else() - if(NOT IOS AND NOT (CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux|Darwin)$")) + if(NOT IOS AND NOT (CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux|Darwin|Windows)$")) message(WARNING "Target platform \"${CMAKE_SYSTEM_NAME}\" is not supported in {Q/X}NNPACK. " "Supported platforms are Android, iOS, Linux, and macOS. " @@ -244,6 +264,13 @@ if(USE_NNPACK OR USE_QNNPACK OR USE_PYTORCH_QNNPACK OR USE_XNNPACK) caffe2_update_option(USE_PYTORCH_QNNPACK OFF) caffe2_update_option(USE_XNNPACK OFF) else() + # Disable unsupported NNPack combinations with MSVC + if(MSVC) + caffe2_update_option(USE_NNPACK OFF) + caffe2_update_option(USE_QNNPACK OFF) + caffe2_update_option(USE_PYTORCH_QNNPACK OFF) + endif() + set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party") if(NOT DEFINED CPUINFO_SOURCE_DIR) @@ -266,18 +293,18 @@ else() set(DISABLE_NNPACK_AND_FAMILY ON) endif() +if(USE_QNNPACK AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" AND CMAKE_SYSTEM_NAME STREQUAL "Darwin") + message(WARNING + "QNNPACK does not compile for Apple Silicon. " + "Turn this warning off by explicit USE_QNNPACK=OFF.") + caffe2_update_option(USE_QNNPACK OFF) +endif() + set(CONFU_DEPENDENCIES_SOURCE_DIR ${PROJECT_BINARY_DIR}/confu-srcs CACHE PATH "Confu-style dependencies source directory") set(CONFU_DEPENDENCIES_BINARY_DIR ${PROJECT_BINARY_DIR}/confu-deps CACHE PATH "Confu-style dependencies binary directory") -# ---[ Eigen BLAS for Mobile -if(INTERN_BUILD_MOBILE AND INTERN_USE_EIGEN_BLAS) - set(USE_BLAS 1) - include(${CMAKE_CURRENT_LIST_DIR}/External/EigenBLAS.cmake) - list(APPEND Caffe2_DEPENDENCY_LIBS eigen_blas) -endif() - # ---[ pthreadpool # Only add a dependency on pthreadpool if we are on a mobile build # or are building any of the libraries in the {Q/X}NNPACK family. @@ -503,6 +530,13 @@ if(USE_XNNPACK AND NOT USE_SYSTEM_XNNPACK) "${CONFU_DEPENDENCIES_BINARY_DIR}/XNNPACK") set_property(TARGET XNNPACK PROPERTY POSITION_INDEPENDENT_CODE ON) + # Workaround for https://github.com/pytorch/pytorch/issues/47292 + if(CMAKE_BUILD_TYPE STREQUAL "Debug" AND CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.5.0)) + # Compiling qu8-requantization/precise-psimd.c without any optimization flags on gcc-7.4 or older i + # Fails with internal compiler error + # Workaround by forcing -O1 for XNNPACK (i.e. build it with RelWithDebInfo) + set_property(TARGET XNNPACK APPEND_STRING PROPERTY COMPILE_FLAGS "-O1") + endif() endif() include_directories(SYSTEM ${XNNPACK_INCLUDE_DIR}) @@ -700,7 +734,6 @@ else() caffe2_update_option(USE_FAKELOWP OFF) endif() - # ---[ LMDB if(USE_LMDB) find_package(LMDB) @@ -1179,8 +1212,9 @@ if(USE_ROCM) list(APPEND HIP_CXX_FLAGS -std=c++14) if(CMAKE_BUILD_TYPE MATCHES Debug) - list(APPEND HIP_CXX_FLAGS -g) + list(APPEND HIP_CXX_FLAGS -g2) list(APPEND HIP_CXX_FLAGS -O0) + list(APPEND HIP_HIPCC_FLAGS -fdebug-info-for-profiling) endif(CMAKE_BUILD_TYPE MATCHES Debug) set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS}) @@ -1192,7 +1226,7 @@ if(USE_ROCM) endforeach() set(Caffe2_HIP_INCLUDE - ${thrust_INCLUDE_DIRS} ${hipcub_INCLUDE_DIRS} ${rocprim_INCLUDE_DIRS} ${miopen_INCLUDE_DIRS} ${rocblas_INCLUDE_DIRS} ${rocrand_INCLUDE_DIRS} ${hiprand_INCLUDE_DIRS} ${roctracer_INCLUDE_DIRS} ${hip_INCLUDE_DIRS} ${hcc_INCLUDE_DIRS} ${hsa_INCLUDE_DIRS} $ ${Caffe2_HIP_INCLUDE}) + $ ${Caffe2_HIP_INCLUDE}) # This is needed for library added by hip_add_library (same for hip_add_executable) hip_include_directories(${Caffe2_HIP_INCLUDE}) @@ -1253,10 +1287,7 @@ if(USE_CUDA) endif() if(USE_GLOO) - if(MSVC) - message(WARNING "Gloo can not be used on Windows.") - caffe2_update_option(USE_GLOO OFF) - elseif(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) + if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) message(WARNING "Gloo can only be used on 64-bit systems.") caffe2_update_option(USE_GLOO OFF) else() @@ -1311,6 +1342,10 @@ if(USE_DISTRIBUTED AND USE_TENSORPIPE) if(MSVC) message(WARNING "Tensorpipe cannot be used on Windows.") else() + if(USE_CUDA) + set(TP_USE_CUDA ON CACHE BOOL "" FORCE) + set(TP_ENABLE_CUDA_IPC ON CACHE BOOL "" FORCE) + endif() set(TP_BUILD_LIBUV ON CACHE BOOL "" FORCE) set(TP_ENABLE_SHM OFF CACHE BOOL "" FORCE) set(TP_ENABLE_CMA OFF CACHE BOOL "" FORCE) @@ -1480,9 +1515,7 @@ if(NOT INTERN_BUILD_MOBILE) if(MSVC) # we want to respect the standard, and we are bored of those **** . add_definitions(-D_CRT_SECURE_NO_DEPRECATE=1) - # skip unwanted includes from windows.h - add_definitions(-DWIN32_LEAN_AND_MEAN) - list(APPEND CUDA_NVCC_FLAGS "-Xcompiler" "/wd4819" "-Xcompiler" "/wd4503" "-Xcompiler" "/wd4190" "-Xcompiler" "/wd4244" "-Xcompiler" "/wd4251" "-Xcompiler" "/wd4275" "-Xcompiler" "/wd4522") + list(APPEND CUDA_NVCC_FLAGS "-Xcompiler=/wd4819,/wd4503,/wd4190,/wd4244,/wd4251,/wd4275,/wd4522") endif() if(NOT MSVC) @@ -1507,7 +1540,8 @@ if(NOT INTERN_BUILD_MOBILE) if(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5) message(STATUS "Found CUDA with FP16 support, compiling with torch.cuda.HalfTensor") - list(APPEND CUDA_NVCC_FLAGS "-DCUDA_HAS_FP16=1" "-D__CUDA_NO_HALF_OPERATORS__" "-D__CUDA_NO_HALF_CONVERSIONS__" "-D__CUDA_NO_HALF2_OPERATORS__") + list(APPEND CUDA_NVCC_FLAGS "-DCUDA_HAS_FP16=1" "-D__CUDA_NO_HALF_OPERATORS__" "-D__CUDA_NO_HALF_CONVERSIONS__" + "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "-D__CUDA_NO_HALF2_OPERATORS__") add_compile_options(-DCUDA_HAS_FP16=1) else() message(STATUS "Could not find CUDA with FP16 support, compiling without torch.CudaHalfTensor") @@ -1604,6 +1638,7 @@ if(NOT INTERN_BUILD_MOBILE) add_compile_options(-DUSE_GCC_GET_CPUID) endif() + find_package(VSX) # checks VSX find_package(AVX) # checks AVX and AVX2 # we don't set -mavx and -mavx2 flags globally, but only for specific files @@ -1626,6 +1661,9 @@ if(NOT INTERN_BUILD_MOBILE) find_package(LAPACK) if(LAPACK_FOUND) set(USE_LAPACK 1) + list(APPEND Caffe2_PRIVATE_DEPENDENCY_LIBS ${LAPACK_LIBRARIES}) + else() + set(USE_LAPACK 0) endif() if(NOT USE_CUDA) @@ -1733,7 +1771,8 @@ endif() # # End ATen checks # - +set(TEMP_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) +set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE) add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/fmt) # Disable compiler feature checks for `fmt`. @@ -1746,3 +1785,46 @@ add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/fmt) set_target_properties(fmt-header-only PROPERTIES INTERFACE_COMPILE_FEATURES "") list(APPEND Caffe2_DEPENDENCY_LIBS fmt::fmt-header-only) +set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE) + +# ---[ Kineto +if(USE_KINETO) + if(USE_KINETO AND NOT TARGET kineto) + set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party" CACHE STRING "") + set(KINETO_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/kineto/libkineto" CACHE STRING "") + set(KINETO_BUILD_TESTS OFF CACHE BOOL "") + set(KINETO_LIBRARY_TYPE "static" CACHE STRING "") + set(CUDA_SOURCE_DIR "${CUDA_TOOLKIT_ROOT_DIR}" CACHE STRING "") + + message(STATUS "Configuring Kineto dependency:") + message(STATUS " KINETO_SOURCE_DIR = ${KINETO_SOURCE_DIR}") + message(STATUS " KINETO_BUILD_TESTS = ${KINETO_BUILD_TESTS}") + message(STATUS " KINETO_LIBRARY_TYPE = ${KINETO_LIBRARY_TYPE}") + message(STATUS " CUDA_SOURCE_DIR = ${CUDA_SOURCE_DIR}") + + if(EXISTS ${CUDA_SOURCE_DIR}/extras/CUPTI/include) + set(CUPTI_INCLUDE_DIR "${CUDA_SOURCE_DIR}/extras/CUPTI/include") + elseif(EXISTS ${CUDA_SOURCE_DIR}/include/cupti.h) + set(CUPTI_INCLUDE_DIR "${CUDA_SOURCE_DIR}/include") + endif() + + if((NOT DEFINED CUDA_cupti_LIBRARY) OR (${CUDA_cupti_LIBRARY} STREQUAL "CUDA_cupti_LIBRARY-NOTFOUND")) + if(EXISTS ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti_static.a) + set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti_static.a") + elseif(EXISTS ${CUDA_SOURCE_DIR}/lib64/libcupti_static.a) + set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/lib64/libcupti_static.a") + elseif(EXISTS ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti.so) + set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti.so") + elseif(EXISTS ${CUDA_SOURCE_DIR}/lib64/libcupti.so) + set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/lib64/libcupti.so") + endif() + endif() + message(STATUS " CUDA_cupti_LIBRARY = ${CUDA_cupti_LIBRARY}") + message(STATUS " CUPTI_INCLUDE_DIR = ${CUPTI_INCLUDE_DIR}") + + add_subdirectory("${KINETO_SOURCE_DIR}") + message(STATUS "Configured Kineto as a dependency.") + endif() + + list(APPEND Caffe2_DEPENDENCY_LIBS kineto) +endif() diff --git a/cmake/External/nnpack.cmake b/cmake/External/nnpack.cmake index 84244dc864c36..b1dcd728e6905 100644 --- a/cmake/External/nnpack.cmake +++ b/cmake/External/nnpack.cmake @@ -27,7 +27,7 @@ endif() # (2) Anything but x86, x86-64, ARM, ARM64 - unsupported ############################################################################## if(CMAKE_SYSTEM_PROCESSOR) - if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|x86_64|armv5te|armv7-a|armv7l|aarch64)$") + if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|x86_64|armv5te|armv7-a|armv7l|arm64|aarch64)$") message(WARNING "NNPACK is not supported on ${CMAKE_SYSTEM_PROCESSOR} processors. " "The only supported architectures are x86, x86-64, ARM, and ARM64. " "Turn this warning off by USE_NNPACK=OFF.") diff --git a/cmake/Metal.cmake b/cmake/Metal.cmake new file mode 100644 index 0000000000000..e3124609c1790 --- /dev/null +++ b/cmake/Metal.cmake @@ -0,0 +1,45 @@ +if(NOT APPLE) + return() +endif() + +if(NOT USE_PYTORCH_METAL) + return() +endif() + +if(IOS OR INTERN_BUILD_MOBILE) + return() +endif() + +set(OSX_PLATFORM "MacOSX.platform") +exec_program(/usr/bin/xcode-select ARGS -print-path OUTPUT_VARIABLE CMAKE_XCODE_DEVELOPER_DIR) +set(XCODE_POST_43_ROOT "${CMAKE_XCODE_DEVELOPER_DIR}/Platforms/${OSX_PLATFORM}/Developer") +set(XCODE_PRE_43_ROOT "/Developer/Platforms/${OSX_PLATFORM}/Developer") +if(NOT DEFINED CMAKE_OSX_DEVELOPER_ROOT) + if(EXISTS ${XCODE_POST_43_ROOT}) + set(CMAKE_OSX_DEVELOPER_ROOT ${XCODE_POST_43_ROOT}) + elseif(EXISTS ${XCODE_PRE_43_ROOT}) + set(CMAKE_OSX_DEVELOPER_ROOT ${XCODE_PRE_43_ROOT}) + endif(EXISTS ${XCODE_POST_43_ROOT}) +endif(NOT DEFINED CMAKE_OSX_DEVELOPER_ROOT) +set(CMAKE_OSX_DEVELOPER_ROOT ${CMAKE_OSX_DEVELOPER_ROOT} CACHE PATH "Location of OSX SDKs root directory") + +if(NOT DEFINED CMAKE_OSX_SDK_ROOT) + file(GLOB _CMAKE_OSX_SDKS "${CMAKE_OSX_DEVELOPER_ROOT}/SDKs/*") + if(_CMAKE_OSX_SDKS) + list(SORT _CMAKE_OSX_SDKS) + list(REVERSE _CMAKE_OSX_SDKS) + list(GET _CMAKE_OSX_SDKS 0 CMAKE_OSX_SDK_ROOT) + message(STATUS "_CMAKE_OSX_SDKS: ${_CMAKE_OSX_SDKS}") + else(_CMAKE_OSX_SDKS) + message(FATAL_ERROR "No OSX SDK's found in default search path ${CMAKE_OSX_DEVELOPER_ROOT}.") + endif(_CMAKE_OSX_SDKS) + message(STATUS "Toolchain using default OSX SDK: ${CMAKE_OSX_SDK_ROOT}") +endif(NOT DEFINED CMAKE_OSX_SDK_ROOT) +set(CMAKE_OSX_SDK_ROOT ${CMAKE_OSX_SDK_ROOT} CACHE PATH "Location of the selected OSX SDK") +set(CMAKE_FRAMEWORK_PATH + ${CMAKE_OSX_SDK_ROOT}/System/Library/Frameworks + ${CMAKE_OSX_SDK_ROOT}/System/Library/PrivateFrameworks + ${CMAKE_OSX_SDK_ROOT}/Developer/Library/Frameworks +) +message(STATUS "CMAKE_FRAMEWORK_PATH: ${CMAKE_FRAMEWORK_PATH}") +set(CMAKE_FIND_FRAMEWORK FIRST) diff --git a/cmake/MiscCheck.cmake b/cmake/MiscCheck.cmake index 4e2b873e7afe8..1d53184951d77 100644 --- a/cmake/MiscCheck.cmake +++ b/cmake/MiscCheck.cmake @@ -43,6 +43,9 @@ if(NOT INTERN_BUILD_MOBILE) # important because with ASAN you might need to help the compiled library find # some dynamic libraries. cmake_push_check_state(RESET) + if(CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64)$") + list(APPEND CMAKE_REQUIRED_FLAGS "-arch ${CMAKE_HOST_SYSTEM_PROCESSOR}") + endif() CHECK_C_SOURCE_RUNS(" int main() { return 0; } " COMPILER_WORKS) @@ -306,8 +309,9 @@ endif() # Also, we will turn off deprecated-declarations # due to protobuf. -if(IOS) +if(IOS AND (${IOS_ARCH} MATCHES "armv7*")) add_definitions("-mfpu=neon-fp16") + add_definitions("-arch" ${IOS_ARCH}) add_definitions("-Wno-deprecated-declarations") endif() diff --git a/cmake/Modules/FindARM.cmake b/cmake/Modules/FindARM.cmake index 2dcb2a24f208b..2e55087160357 100644 --- a/cmake/Modules/FindARM.cmake +++ b/cmake/Modules/FindARM.cmake @@ -41,17 +41,22 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux") ENDIF (OMAP4_TRUE) ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Darwin") + IF(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" AND NOT CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64") + set(NEON_FOUND true CACHE BOOL "NEON available on ARM64") + ENDIF() EXEC_PROGRAM("/usr/sbin/sysctl -n machdep.cpu.features" OUTPUT_VARIABLE CPUINFO) - #neon instruction can be found on the majority part of modern ARM processor - STRING(REGEX REPLACE "^.*(neon).*$" "\\1" NEON_THERE ${CPUINFO}) - STRING(COMPARE EQUAL "neon" "${NEON_THERE}" NEON_TRUE) - IF (NEON_TRUE) - set(NEON_FOUND true CACHE BOOL "NEON available on host") - ELSE (NEON_TRUE) - set(NEON_FOUND false CACHE BOOL "NEON available on host") - ENDIF (NEON_TRUE) + IF(NOT CPUINFO STREQUAL "") + #neon instruction can be found on the majority part of modern ARM processor + STRING(REGEX REPLACE "^.*(neon).*$" "\\1" NEON_THERE ${CPUINFO}) + STRING(COMPARE EQUAL "neon" "${NEON_THERE}" NEON_TRUE) + IF (NEON_TRUE) + set(NEON_FOUND true CACHE BOOL "NEON available on host") + ELSE (NEON_TRUE) + set(NEON_FOUND false CACHE BOOL "NEON available on host") + ENDIF (NEON_TRUE) + ENDIF() ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Windows") # TODO diff --git a/cmake/Modules/FindBLAS.cmake b/cmake/Modules/FindBLAS.cmake index e93e98a6095d4..e8f5d7c950920 100644 --- a/cmake/Modules/FindBLAS.cmake +++ b/cmake/Modules/FindBLAS.cmake @@ -83,7 +83,7 @@ MACRO(Check_Fortran_Libraries LIBRARIES _prefix _name _flags _list) check_fortran_function_exists(${_name} ${_prefix}${_combined_name}_WORKS) else (CMAKE_Fortran_COMPILER_WORKS) check_function_exists("${_name}_" ${_prefix}${_combined_name}_WORKS) - endif (CMAKE_Fortran_COMPILER_WORKS) + endif(CMAKE_Fortran_COMPILER_WORKS) set(CMAKE_REQUIRED_LIBRARIES) mark_as_advanced(${_prefix}${_combined_name}_WORKS) set(_libraries_work ${${_prefix}${_combined_name}_WORKS}) @@ -117,33 +117,46 @@ if((NOT BLAS_LIBRARIES) if (BLAS_LIBRARIES) set(BLAS_INFO "accelerate") set(BLAS_IS_ACCELERATE 1) - endif (BLAS_LIBRARIES) + endif(BLAS_LIBRARIES) endif() if((NOT BLAS_LIBRARIES) AND ((NOT WITH_BLAS) OR (WITH_BLAS STREQUAL "veclib"))) - check_fortran_libraries( - BLAS_LIBRARIES - BLAS - sgemm - "" - "vecLib") - if (BLAS_LIBRARIES) - set(BLAS_INFO "veclib") - endif (BLAS_LIBRARIES) + FIND_PACKAGE(vecLib) + if(vecLib_FOUND) + SET(BLAS_INFO "veclib") + else() + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "vecLib") + if (BLAS_LIBRARIES) + set(BLAS_INFO "veclib") + endif(BLAS_LIBRARIES) + endif() endif() if((NOT BLAS_LIBRARIES) AND ((NOT WITH_BLAS) OR (WITH_BLAS STREQUAL "open"))) - check_fortran_libraries( - BLAS_LIBRARIES - BLAS - sgemm - "" - "openblas") - if(BLAS_LIBRARIES) - set(BLAS_INFO "open") - endif(BLAS_LIBRARIES) + FIND_PACKAGE(OpenBLAS) + if(OpenBLAS_FOUND) + SET(BLAS_INFO "open") + SET(BLAS_LIBRARIES ${OpenBLAS_LIB}) + SET(BLAS_INCLUDE_DIR ${OpenBLAS_INCLUDE_DIR}) + SET(BLAS_VERSION ${OpenBLAS_VERSION}) + else() + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "openblas") + if(BLAS_LIBRARIES) + set(BLAS_INFO "open") + endif(BLAS_LIBRARIES) + endif() endif() if((NOT BLAS_LIBRARIES) @@ -153,7 +166,7 @@ if((NOT BLAS_LIBRARIES) BLAS sgemm "" - "openblas;pthread") + "openblas;pthread;m") if(BLAS_LIBRARIES) set(BLAS_INFO "open") endif(BLAS_LIBRARIES) @@ -182,7 +195,7 @@ if((NOT BLAS_LIBRARIES) "goto2;gfortran") if (BLAS_LIBRARIES) set(BLAS_INFO "goto") - endif (BLAS_LIBRARIES) + endif(BLAS_LIBRARIES) endif() if((NOT BLAS_LIBRARIES) @@ -195,7 +208,7 @@ if((NOT BLAS_LIBRARIES) "goto2;gfortran;pthread") if (BLAS_LIBRARIES) set(BLAS_INFO "goto") - endif (BLAS_LIBRARIES) + endif(BLAS_LIBRARIES) endif() if((NOT BLAS_LIBRARIES) @@ -208,7 +221,7 @@ if((NOT BLAS_LIBRARIES) "acml;gfortran") if (BLAS_LIBRARIES) set(BLAS_INFO "acml") - endif (BLAS_LIBRARIES) + endif(BLAS_LIBRARIES) endif() if((NOT BLAS_LIBRARIES) @@ -222,21 +235,26 @@ if((NOT BLAS_LIBRARIES) "blis") if (BLAS_LIBRARIES) set(BLAS_INFO "FLAME") - endif (BLAS_LIBRARIES) + endif(BLAS_LIBRARIES) endif() # BLAS in ATLAS library? (http://math-atlas.sourceforge.net/) if((NOT BLAS_LIBRARIES) AND ((NOT WITH_BLAS) OR (WITH_BLAS STREQUAL "atlas"))) - check_fortran_libraries( - BLAS_LIBRARIES - BLAS - sgemm - "" - "ptf77blas;atlas;gfortran") - if (BLAS_LIBRARIES) - set(BLAS_INFO "atlas") - endif (BLAS_LIBRARIES) + FIND_PACKAGE(Atlas) + if(Atlas_FOUND) + SET(BLAS_INFO "atlas") + else() + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "ptf77blas;atlas;gfortran") + if (BLAS_LIBRARIES) + set(BLAS_INFO "atlas") + endif(BLAS_LIBRARIES) + endif() endif() # Generic BLAS library? @@ -250,7 +268,7 @@ if((NOT BLAS_LIBRARIES) "blas") if (BLAS_LIBRARIES) set(BLAS_INFO "generic") - endif (BLAS_LIBRARIES) + endif(BLAS_LIBRARIES) endif() # Determine if blas was compiled with the f2c conventions @@ -287,7 +305,7 @@ int main() { SET(BLAS_F2C TRUE) ELSE (BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) SET(BLAS_F2C FALSE) - ENDIF (BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) + ENDIF(BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) CHECK_C_SOURCE_RUNS(" #include #include @@ -303,7 +321,7 @@ int main() { SET(BLAS_USE_CBLAS_DOT TRUE) ELSE (BLAS_USE_CBLAS_DOT) SET(BLAS_USE_CBLAS_DOT FALSE) - ENDIF (BLAS_USE_CBLAS_DOT) + ENDIF(BLAS_USE_CBLAS_DOT) SET(CMAKE_REQUIRED_LIBRARIES) ENDIF(BLAS_LIBRARIES) @@ -317,10 +335,10 @@ endif(BLAS_LIBRARIES) IF (NOT BLAS_FOUND AND BLAS_FIND_REQUIRED) message(FATAL_ERROR "Cannot find a library with BLAS API. Please specify library location.") -ENDIF (NOT BLAS_FOUND AND BLAS_FIND_REQUIRED) +ENDIF(NOT BLAS_FOUND AND BLAS_FIND_REQUIRED) IF(NOT BLAS_FIND_QUIETLY) IF(BLAS_FOUND) - MESSAGE(STATUS "Found a library with BLAS API (${BLAS_INFO}).") + MESSAGE(STATUS "Found a library with BLAS API (${BLAS_INFO}). Full path: (${BLAS_LIBRARIES})") ELSE(BLAS_FOUND) MESSAGE(STATUS "Cannot find a library with BLAS API. Not using BLAS.") ENDIF(BLAS_FOUND) diff --git a/cmake/Modules/FindLAPACK.cmake b/cmake/Modules/FindLAPACK.cmake index c057f207132f1..b0e607d905872 100644 --- a/cmake/Modules/FindLAPACK.cmake +++ b/cmake/Modules/FindLAPACK.cmake @@ -123,6 +123,30 @@ if(BLAS_FOUND) IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "open")) SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) check_function_exists("cheev_" OPEN_LAPACK_WORKS) + if(OPEN_LAPACK_WORKS) + check_function_exists("cgesdd_" LAPACK_CGESDD_WORKS) + if(NOT LAPACK_CGESDD_WORKS) + find_library(GFORTRAN_LIBRARY + NAMES libgfortran.a gfortran + PATHS /usr/lib/gcc/aarch64-linux-gnu/9/ + /usr/lib/gcc/x86_64-redhat-linux/9/ + /usr/lib/gcc/aarch64-linux-gnu/8/ + /usr/lib/gcc/x86_64-redhat-linux/8/ + /usr/lib/gcc/aarch64-linux-gnu/7/ + /usr/lib/gcc/x86_64-redhat-linux/7/ + ) + list(APPEND CMAKE_REQUIRED_LIBRARIES "${GFORTRAN_LIBRARY}") + unset(LAPACK_CGESDD_WORKS CACHE) + check_function_exists("cgesdd_" LAPACK_CGESDD_WORKS) + if(LAPACK_CGESDD_WORKS) + list(APPEND LAPACK_LIBRARIES "${GFORTRAN_LIBRARY}") + else() + message(WARNING "OpenBlas has been compiled with Lapack support, but cgesdd can not be used") + set(OPEN_LAPACK_WORKS NO) + endif() + endif() + endif() + set(CMAKE_REQUIRED_LIBRARIES) if(OPEN_LAPACK_WORKS) SET(LAPACK_INFO "open") diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index c66fb19b65457..c9f787a1721ad 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -67,12 +67,14 @@ SET(MKLDNN_BUILD_EXAMPLES FALSE CACHE BOOL "" FORCE) SET(MKLDNN_LIBRARY_TYPE STATIC CACHE STRING "" FORCE) SET(DNNL_ENABLE_PRIMITIVE_CACHE TRUE CACHE BOOL "" FORCE) IF(MKLDNN_USE_NATIVE_ARCH) # Disable HostOpts in MKLDNN unless MKLDNN_USE_NATIVE_ARCH is set. - SET(ARCH_OPT_FLAGS "HostOpts" CACHE STRING "" FORCE) + SET(MKLDNN_ARCH_OPT_FLAGS "HostOpts" CACHE STRING "" FORCE) ELSE() IF(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") - SET(ARCH_OPT_FLAGS "-msse4" CACHE STRING "" FORCE) + IF(CPU_INTEL) + SET(MKLDNN_ARCH_OPT_FLAGS "-msse4" CACHE STRING "" FORCE) + ENDIF() ELSE() - SET(ARCH_OPT_FLAGS "" CACHE STRING "" FORCE) + SET(MKLDNN_ARCH_OPT_FLAGS "" CACHE STRING "" FORCE) ENDIF() ENDIF() diff --git a/cmake/Modules/FindVSX.cmake b/cmake/Modules/FindVSX.cmake new file mode 100644 index 0000000000000..74691f9240fb2 --- /dev/null +++ b/cmake/Modules/FindVSX.cmake @@ -0,0 +1,35 @@ + +IF(CMAKE_SYSTEM_NAME MATCHES "Linux") + message("-- ") + EXEC_PROGRAM(LD_SHOW_AUXV=1 ARGS "/bin/true" OUTPUT_VARIABLE bintrue) + if(bintrue MATCHES "AT_PLATFORM:[ \\t\\n\\r]*([a-zA-Z0-9_]+)[ \\t\\n\\r]*") + if(CMAKE_MATCH_COUNT GREATER 0) + string(TOLOWER ${CMAKE_MATCH_1} platform) + if(${platform} MATCHES "^power") + message("-- POWER Platform: ${platform}") + SET(POWER_COMP TRUE CACHE BOOL "power ") + SET(CXX_VSX_FLAGS "${CXX_VSX_FLAGS} -mcpu=${platform} -mtune=${platform}" ) + endif() + endif() + endif() + SET(VSX_CODE " #include + int main() { + float __attribute__((aligned(16))) vptr_y[8] = { 1.0f,2.f,3.f,4.f,4.f,3.f,2.f,1.f }; + __vector float v_result = vec_add(vec_vsx_ld(0, vptr_y), vec_vsx_ld(16, vptr_y)); + return 0; + }") + #check_cxx_compiler_flag(-mvsx vsx_flag) + SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) + SET(CMAKE_REQUIRED_FLAGS "-mvsx") + CHECK_C_SOURCE_COMPILES("${VSX_CODE}" C_VSX_FOUND) + CHECK_CXX_SOURCE_COMPILES("${VSX_CODE}" CXX_VSX_FOUND) + SET(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) + if(CXX_VSX_FOUND) + message("-- VSX flag was set.") + SET(CXX_VSX_FLAGS "${CXX_VSX_FLAGS} -mvsx" ) + elseif(POWER_COMP) + message(WARNING "-- VSX flag was not set.") + endif() + message("-- ") +endif() + diff --git a/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake b/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake index aa210cae64733..ec8a732e30702 100644 --- a/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake +++ b/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake @@ -1427,7 +1427,7 @@ macro(CUDA_WRAP_SRCS cuda_target format generated_files) set(CUDA_HOST_SHARED_FLAGS) endif() - macro(_filter_blacklisted_host_flags CUDA_FLAGS) + macro(_filter_blocklisted_host_flags CUDA_FLAGS) string(REGEX REPLACE "[ \t]+" ";" ${CUDA_FLAGS} "${${CUDA_FLAGS}}") foreach(_blacklisted ${CUDA_PROPAGATE_HOST_FLAGS_BLACKLIST}) list(REMOVE_ITEM ${CUDA_FLAGS} "${_blacklisted}") @@ -1439,7 +1439,7 @@ macro(CUDA_WRAP_SRCS cuda_target format generated_files) # always need to set the SHARED_FLAGS, though. if(CUDA_PROPAGATE_HOST_FLAGS) set(_cuda_C_FLAGS "${CMAKE_${CUDA_C_OR_CXX}_FLAGS}") - _filter_blacklisted_host_flags(_cuda_C_FLAGS) + _filter_blocklisted_host_flags(_cuda_C_FLAGS) set(_cuda_host_flags "set(CMAKE_HOST_FLAGS ${_cuda_C_FLAGS} ${CUDA_HOST_SHARED_FLAGS})") else() set(_cuda_host_flags "set(CMAKE_HOST_FLAGS ${CUDA_HOST_SHARED_FLAGS})") @@ -1465,7 +1465,7 @@ macro(CUDA_WRAP_SRCS cuda_target format generated_files) endif() endif() set(_cuda_C_FLAGS "${CMAKE_${CUDA_C_OR_CXX}_FLAGS_${config_upper}}") - _filter_blacklisted_host_flags(_cuda_C_FLAGS) + _filter_blocklisted_host_flags(_cuda_C_FLAGS) if(_cuda_fix_g3) string(REPLACE "-g3" "-g" _cuda_C_FLAGS "${_cuda_C_FLAGS}") endif() diff --git a/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake b/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake index c17dfa7514179..7f22d476d2fbe 100644 --- a/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake +++ b/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake @@ -84,9 +84,20 @@ endif() if(CUDA_VERSION VERSION_GREATER "10.5") list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ampere") - list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0" "8.0+PTX") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0") list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.0") + if(CUDA_VERSION VERSION_LESS "11.1") + set(CUDA_LIMIT_GPU_ARCHITECTURE "8.0") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0+PTX") + endif() +endif() + +if(NOT CUDA_VERSION VERSION_LESS "11.1") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6" "8.6+PTX") + list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.6") + set(CUDA_LIMIT_GPU_ARCHITECUTRE "8.6") + if(CUDA_VERSION VERSION_LESS "12.0") set(CUDA_LIMIT_GPU_ARCHITECTURE "9.0") endif() diff --git a/cmake/ProtoBuf.cmake b/cmake/ProtoBuf.cmake index 32cab7557f3b7..d8a2c279aee47 100644 --- a/cmake/ProtoBuf.cmake +++ b/cmake/ProtoBuf.cmake @@ -39,7 +39,7 @@ macro(custom_protobuf_find) set(CMAKE_POSITION_INDEPENDENT_CODE ON) if(MSVC) - foreach(flag_var + foreach(flag_var CMAKE_C_FLAGS CMAKE_C_FLAGS_RELEASE CMAKE_C_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_RELEASE CMAKE_CXX_FLAGS_MINSIZEREL) if(${flag_var} MATCHES "/Z[iI7]") @@ -172,8 +172,8 @@ function(caffe2_protobuf_generate_cpp_py srcs_var hdrs_var python_var) list(APPEND ${hdrs_var} "${CMAKE_CURRENT_BINARY_DIR}/${fil_we}.pb.h") list(APPEND ${python_var} "${CMAKE_CURRENT_BINARY_DIR}/${fil_we}_pb2.py") - # Add CAFFE2_API prefix to protobuf classes and methods in all cases - set(DLLEXPORT_STR "dllexport_decl=CAFFE2_API:") + # Add TORCH_API prefix to protobuf classes and methods in all cases + set(DLLEXPORT_STR "dllexport_decl=TORCH_API:") # Note: the following depends on PROTOBUF_PROTOC_EXECUTABLE. This # is done to make sure protoc is built before attempting to @@ -196,7 +196,7 @@ function(caffe2_protobuf_generate_cpp_py srcs_var hdrs_var python_var) # If we remove all reference to these pb.h files from external # libraries and binaries this rewrite can be removed. - COMMAND ${CMAKE_COMMAND} -DFILENAME=${CMAKE_CURRENT_BINARY_DIR}/${fil_we}.pb.h -DNAMESPACES=caffe\;caffe2\;onnx\;torch -P ${PROJECT_SOURCE_DIR}/cmake/ProtoBufPatch.cmake + COMMAND ${CMAKE_COMMAND} -DFILENAME=${CMAKE_CURRENT_BINARY_DIR}/${fil_we}.pb.h -DNAMESPACES=caffe\;caffe2\;onnx\;torch -DLOCAL_PROTOBUF=${CAFFE2_LINK_LOCAL_PROTOBUF} -P ${PROJECT_SOURCE_DIR}/cmake/ProtoBufPatch.cmake DEPENDS ${CAFFE2_PROTOC_EXECUTABLE} ${abs_fil} COMMENT "Running C++/Python protocol buffer compiler on ${fil}" VERBATIM ) @@ -209,6 +209,7 @@ function(caffe2_protobuf_generate_cpp_py srcs_var hdrs_var python_var) COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}" COMMAND ${CAFFE2_PROTOC_EXECUTABLE} -I${PROJECT_SOURCE_DIR} --cpp_out=${DLLEXPORT_STR}${PROJECT_BINARY_DIR} ${abs_fil} COMMAND ${CAFFE2_PROTOC_EXECUTABLE} -I${PROJECT_SOURCE_DIR} --python_out "${PROJECT_BINARY_DIR}" ${abs_fil} + COMMAND ${CMAKE_COMMAND} -DFILENAME=${CMAKE_CURRENT_BINARY_DIR}/${fil_we}.pb.h -DNAMESPACES=caffe\;caffe2\;onnx\;torch -DLOCAL_PROTOBUF=${CAFFE2_LINK_LOCAL_PROTOBUF} -P ${PROJECT_SOURCE_DIR}/cmake/ProtoBufPatch.cmake DEPENDS ${CAFFE2_PROTOC_EXECUTABLE} ${abs_fil} COMMENT "Running C++/Python protocol buffer compiler on ${fil}" VERBATIM ) endif() diff --git a/cmake/ProtoBufPatch.cmake b/cmake/ProtoBufPatch.cmake index 2124b61897998..704dcd7da1545 100644 --- a/cmake/ProtoBufPatch.cmake +++ b/cmake/ProtoBufPatch.cmake @@ -1,41 +1,83 @@ # CMake file to replace the string contents in ONNX, Caffe, and Caffe2 proto. # Usage example: -# cmake -DFILENAME=caffe2.pb.h -P ProtoBufPatch.cmake +# cmake -DFILENAME=caffe2.pb.h -DLOCAL_PROTOBUF=ON -P ProtoBufPatch.cmake file(READ ${FILENAME} content) -# protobuf-3.6.0 pattern -string( - REPLACE - "::google::protobuf::internal::GetEmptyStringAlreadyInited" - "GetEmptyStringAlreadyInited" - content - "${content}") +if(LOCAL_PROTOBUF) + # protobuf-3.6.0 pattern + string( + REPLACE + "::google::protobuf::internal::GetEmptyStringAlreadyInited" + "GetEmptyStringAlreadyInited" + content + "${content}") -# protobuf-3.8.0+ pattern -string( - REPLACE - "::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited" - "GetEmptyStringAlreadyInited" - content - "${content}") + # protobuf-3.8.0+ pattern + string( + REPLACE + "::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited" + "GetEmptyStringAlreadyInited" + content + "${content}") -string( - REPLACE - "PROTOBUF_CONSTEXPR" - "" - content - "${content}") + string( + REPLACE + "PROTOBUF_CONSTEXPR" + "" + content + "${content}") -# https://github.com/protocolbuffers/protobuf/commit/0400cca3236de1ca303af38bf81eab332d042b7c -# changes PROTOBUF_CONSTEXPR to constexpr, which breaks windows -# build. -string( - REGEX REPLACE - "static constexpr ([^ ]+) ([^ ]+) =" - "static \\1 const \\2 =" - content - "${content}") + # https://github.com/protocolbuffers/protobuf/commit/0400cca3236de1ca303af38bf81eab332d042b7c + # changes PROTOBUF_CONSTEXPR to constexpr, which breaks windows + # build. + string( + REGEX REPLACE + "static constexpr ([^ ]+) ([^ ]+) =" + "static \\1 const \\2 =" + content + "${content}") + + foreach(ns ${NAMESPACES}) + # Insert "const ::std::string& GetEmptyStringAlreadyInited();" within + # the namespace and make sure we only do it once in the file. Unfortunately + # using string(REPLACE ...) doesn't work because it will replace at all + # locations and there might be multiple declarations of the namespace + # depending on how the proto is structured. + set(search "namespace ${ns} {") + string(LENGTH "${search}" search_len) + string(FIND "${content}" "${search}" pos) + if(${pos} GREATER -1) + math(EXPR pos "${pos}+${search_len}") + string(SUBSTRING "${content}" 0 ${pos} content_pre) + string(SUBSTRING "${content}" ${pos} -1 content_post) + string( + CONCAT + content + "${content_pre}" + " const ::std::string& GetEmptyStringAlreadyInited(); " + "${content_post}") + endif() + endforeach() + + # The moving constructor is defined in the header file, which will cause + # a link error that claims that the vftable is not found. Luckily, we + # could move the definition into the source file to solve the problem. + list(LENGTH NAMESPACES ns_count) + if("${FILENAME}" MATCHES ".pb.h" AND ns_count EQUAL 1) + string(REPLACE ".pb.h" ".pb.cc" SOURCE_FILENAME ${FILENAME}) + file(READ ${SOURCE_FILENAME} content_cc_origin) + + string(REGEX MATCHALL "([a-zA-Z_]+)\\([a-zA-Z_]+&& from\\) noexcept[^}]*}" content_cc "${content}") + string(REGEX REPLACE "};" "}\n" content_cc "${content_cc}") + string(REGEX REPLACE "([a-zA-Z_]+)\\([a-zA-Z_]+&& from\\) noexcept" " \\1::\\1(\\1&& from) noexcept" content_cc "${content_cc}") + set(content_cc "${content_cc_origin}\nnamespace ${NAMESPACES} {\n#if LANG_CXX11\n${content_cc}\n#endif\n}") + + string(REGEX REPLACE "([a-zA-Z_]+)\\([a-zA-Z_]+&& from\\) noexcept([^}]*)}" "\\1(\\1&& from) noexcept;" content "${content}") + + file(WRITE ${SOURCE_FILENAME} "${content_cc}") + endif() +endif() # constexpr int TensorBoundShape_DimType_DimType_ARRAYSIZE = TensorBoundShape_DimType_DimType_MAX + 1; # throws @@ -53,44 +95,4 @@ string( content "${content}") -foreach(ns ${NAMESPACES}) - # Insert "const ::std::string& GetEmptyStringAlreadyInited();" within - # the namespace and make sure we only do it once in the file. Unfortunately - # using string(REPLACE ...) doesn't work because it will replace at all - # locations and there might be multiple declarations of the namespace - # depending on how the proto is structured. - set(search "namespace ${ns} {") - string(LENGTH "${search}" search_len) - string(FIND "${content}" "${search}" pos) - if(${pos} GREATER -1) - math(EXPR pos "${pos}+${search_len}") - string(SUBSTRING "${content}" 0 ${pos} content_pre) - string(SUBSTRING "${content}" ${pos} -1 content_post) - string( - CONCAT - content - "${content_pre}" - " const ::std::string& GetEmptyStringAlreadyInited(); " - "${content_post}") - endif() -endforeach() - -# The moving constructor is defined in the header file, which will cause -# a link error that claims that the vftable is not found. Luckily, we -# could move the definition into the source file to solve the problem. -list(LENGTH NAMESPACES ns_count) -if("${FILENAME}" MATCHES ".pb.h" AND ns_count EQUAL 1) - string(REPLACE ".pb.h" ".pb.cc" SOURCE_FILENAME ${FILENAME}) - file(READ ${SOURCE_FILENAME} content_cc_origin) - - string(REGEX MATCHALL "([a-zA-Z_]+)\\([a-zA-Z_]+&& from\\) noexcept[^}]*}" content_cc "${content}") - string(REGEX REPLACE "};" "}\n" content_cc "${content_cc}") - string(REGEX REPLACE "([a-zA-Z_]+)\\([a-zA-Z_]+&& from\\) noexcept" " \\1::\\1(\\1&& from) noexcept" content_cc "${content_cc}") - set(content_cc "${content_cc_origin}\nnamespace ${NAMESPACES} {\n#if LANG_CXX11\n${content_cc}\n#endif\n}") - - string(REGEX REPLACE "([a-zA-Z_]+)\\([a-zA-Z_]+&& from\\) noexcept([^}]*)}" "\\1(\\1&& from) noexcept;" content "${content}") - - file(WRITE ${SOURCE_FILENAME} "${content_cc}") -endif() - file(WRITE ${FILENAME} "${content}") diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 3d4da7f061762..b91ef59c9bf14 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -9,7 +9,6 @@ function(caffe2_print_configuration_summary) message(STATUS " C++ compiler : ${CMAKE_CXX_COMPILER}") message(STATUS " C++ compiler id : ${CMAKE_CXX_COMPILER_ID}") message(STATUS " C++ compiler version : ${CMAKE_CXX_COMPILER_VERSION}") - message(STATUS " BLAS : ${BLAS}") message(STATUS " CXX flags : ${CMAKE_CXX_FLAGS}") message(STATUS " Build type : ${CMAKE_BUILD_TYPE}") get_directory_property(tmp DIRECTORY ${PROJECT_SOURCE_DIR} COMPILE_DEFINITIONS) @@ -24,6 +23,7 @@ function(caffe2_print_configuration_summary) message(STATUS " BUILD_CAFFE2_OPS : ${BUILD_CAFFE2_OPS}") message(STATUS " BUILD_CAFFE2_MOBILE : ${BUILD_CAFFE2_MOBILE}") message(STATUS " BUILD_STATIC_RUNTIME_BENCHMARK: ${BUILD_STATIC_RUNTIME_BENCHMARK}") + message(STATUS " BUILD_TENSOREXPR_BENCHMARK: ${BUILD_TENSOREXPR_BENCHMARK}") message(STATUS " BUILD_BINARY : ${BUILD_BINARY}") message(STATUS " BUILD_CUSTOM_PROTOBUF : ${BUILD_CUSTOM_PROTOBUF}") if(${CAFFE2_LINK_LOCAL_PROTOBUF}) @@ -44,12 +44,23 @@ function(caffe2_print_configuration_summary) message(STATUS " Python site-packages: ${PYTHON_SITE_PACKAGES}") endif() message(STATUS " BUILD_SHARED_LIBS : ${BUILD_SHARED_LIBS}") + message(STATUS " CAFFE2_USE_MSVC_STATIC_RUNTIME : ${CAFFE2_USE_MSVC_STATIC_RUNTIME}") message(STATUS " BUILD_TEST : ${BUILD_TEST}") message(STATUS " BUILD_JNI : ${BUILD_JNI}") message(STATUS " BUILD_MOBILE_AUTOGRAD : ${BUILD_MOBILE_AUTOGRAD}") - + if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") + message(STATUS " CROSS_COMPILING_MACOSX : ${CROSS_COMPILING_MACOSX}") + endif() message(STATUS " INTERN_BUILD_MOBILE : ${INTERN_BUILD_MOBILE}") + message(STATUS " USE_BLAS : ${USE_BLAS}") + if(${USE_BLAS}) + message(STATUS " BLAS : ${BLAS_INFO}") + endif() + message(STATUS " USE_LAPACK : ${USE_LAPACK}") + if(${USE_LAPACK}) + message(STATUS " LAPACK : ${LAPACK_INFO}") + endif() message(STATUS " USE_ASAN : ${USE_ASAN}") message(STATUS " USE_CPP_CODE_COVERAGE : ${USE_CPP_CODE_COVERAGE}") message(STATUS " USE_CUDA : ${USE_CUDA}") @@ -92,6 +103,7 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}") message(STATUS " USE_FBGEMM : ${USE_FBGEMM}") message(STATUS " USE_FAKELOWP : ${USE_FAKELOWP}") + message(STATUS " USE_KINETO : ${USE_KINETO}") message(STATUS " USE_FFMPEG : ${USE_FFMPEG}") message(STATUS " USE_GFLAGS : ${USE_GFLAGS}") message(STATUS " USE_GLOG : ${USE_GLOG}") @@ -106,6 +118,8 @@ function(caffe2_print_configuration_summary) message(STATUS " LMDB version : ${LMDB_VERSION}") endif() message(STATUS " USE_METAL : ${USE_METAL}") + message(STATUS " USE_PYTORCH_METAL : ${USE_PYTORCH_METAL}") + message(STATUS " USE_FFTW : ${USE_FFTW}") message(STATUS " USE_MKL : ${CAFFE2_USE_MKL}") message(STATUS " USE_MKLDNN : ${USE_MKLDNN}") if(${CAFFE2_USE_MKLDNN}) @@ -126,6 +140,12 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_OPENMP : ${USE_OPENMP}") message(STATUS " USE_TBB : ${USE_TBB}") message(STATUS " USE_VULKAN : ${USE_VULKAN}") + if(${USE_VULKAN}) + message(STATUS " USE_VULKAN_FP16_INFERENCE : ${USE_VULKAN_FP16_INFERENCE}") + message(STATUS " USE_VULKAN_RELAXED_PRECISION : ${USE_VULKAN_RELAXED_PRECISION}") + message(STATUS " USE_VULKAN_SHADERC_RUNTIME : ${USE_VULKAN_SHADERC_RUNTIME}") + message(STATUS " USE_VULKAN_WRAPPER : ${USE_VULKAN_WRAPPER}") + endif() message(STATUS " USE_PROF : ${USE_PROF}") message(STATUS " USE_QNNPACK : ${USE_QNNPACK}") message(STATUS " USE_PYTORCH_QNNPACK : ${USE_PYTORCH_QNNPACK}") diff --git a/cmake/TorchConfig.cmake.in b/cmake/TorchConfig.cmake.in index 2c3b75941aa2f..8ed7314c55dba 100644 --- a/cmake/TorchConfig.cmake.in +++ b/cmake/TorchConfig.cmake.in @@ -13,6 +13,34 @@ # and the following imported targets: # # torch +macro(append_torchlib_if_found) + foreach (_arg ${ARGN}) + find_library(${_arg}_LIBRARY ${_arg} PATHS "${TORCH_INSTALL_PREFIX}/lib") + if(${_arg}_LIBRARY) + list(APPEND TORCH_LIBRARIES ${${_arg}_LIBRARY}) + else() + message(WARNING "static library ${${_arg}_LIBRARY} not found.") + endif() + endforeach() +endmacro() + +macro(append_wholearchive_lib_if_found) + foreach (_arg ${ARGN}) + find_library(${_arg}_LIBRARY ${_arg} PATHS "${TORCH_INSTALL_PREFIX}/lib") + if(${_arg}_LIBRARY) + if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + list(APPEND TORCH_LIBRARIES "-Wl,-force_load,${${_arg}_LIBRARY}") + elseif(MSVC) + list(APPEND TORCH_LIBRARIES "-WHOLEARCHIVE:${${_arg}_LIBRARY}") + else() + # gcc + list(APPEND TORCH_LIBRARIES "-Wl,--whole-archive ${${_arg}_LIBRARY} -Wl,--no-whole-archive") + endif() + else() + message(WARNING "static library ${${_arg}_LIBRARY} not found.") + endif() + endforeach() +endmacro() include(FindPackageHandleStandardArgs) @@ -39,58 +67,69 @@ endif() if(@BUILD_SHARED_LIBS@) find_package(Caffe2 REQUIRED PATHS ${CMAKE_CURRENT_LIST_DIR}/../Caffe2) set(TORCH_LIBRARIES torch ${Caffe2_MAIN_LIBS}) + append_torchlib_if_found(c10) else() add_library(torch STATIC IMPORTED) # set imported_location at the bottom - set(TORCH_LIBRARIES torch) -endif() - -find_library(C10_LIBRARY c10 PATHS "${TORCH_INSTALL_PREFIX}/lib") -list(APPEND TORCH_LIBRARIES ${C10_LIBRARY}) - -# We need manually add dependent libraries when they are not linked into the -# shared library. -# TODO: this list might be incomplete. -if(NOT @BUILD_SHARED_LIBS@) - find_library(TORCH_CPU_LIBRARY torch_cpu PATHS "${TORCH_INSTALL_PREFIX}/lib") - list(APPEND TORCH_LIBRARIES ${TORCH_CPU_LIBRARY}) - + #library need whole archive + append_wholearchive_lib_if_found(torch torch_cpu) + if(@USE_CUDA@) + append_wholearchive_lib_if_found(torch_cuda c10_cuda) + endif() + + # We need manually add dependent libraries when they are not linked into the + # shared library. + # TODO: this list might be incomplete. + append_torchlib_if_found(c10) + if(@BUILD_CAFFE2@) + append_torchlib_if_found(Caffe2_perfkernels_avx512 Caffe2_perfkernels_avx2 Caffe2_perfkernels_avx) + endif() + if(@USE_NNPACK@) - find_library(NNPACK_LIBRARY nnpack PATHS "${TORCH_INSTALL_PREFIX}/lib") - list(APPEND TORCH_LIBRARIES ${NNPACK_LIBRARY}) + append_torchlib_if_found(nnpack) endif() if(@USE_PYTORCH_QNNPACK@) - find_library(PYTORCH_QNNPACK_LIBRARY pytorch_qnnpack PATHS "${TORCH_INSTALL_PREFIX}/lib") - list(APPEND TORCH_LIBRARIES ${PYTORCH_QNNPACK_LIBRARY}) + append_torchlib_if_found(pytorch_qnnpack) + endif() + + if(@USE_QNNPACK@) + append_torchlib_if_found(qnnpack) endif() if(@USE_XNNPACK@) - find_library(XNNPACK_LIBRARY XNNPACK PATHS "${TORCH_INSTALL_PREFIX}/lib") - list(APPEND TORCH_LIBRARIES ${XNNPACK_LIBRARY}) + append_torchlib_if_found(XNNPACK) endif() + append_torchlib_if_found(caffe2_protos protobuf-lite protobuf protoc) + append_torchlib_if_found(onnx onnx_proto) + + append_torchlib_if_found(foxi_loader fmt) + append_torchlib_if_found(clog cpuinfo) + if(NOT @USE_INTERNAL_PTHREADPOOL_IMPL@) - find_library(PTHREADPOOL_LIBRARY pthreadpool PATHS "${TORCH_INSTALL_PREFIX}/lib") - list(APPEND TORCH_LIBRARIES ${PTHREADPOOL_LIBRARY}) + append_torchlib_if_found(pthreadpool) endif() - if(@INTERN_USE_EIGEN_BLAS@) - find_library(EIGEN_BLAS_LIBRARY eigen_blas PATHS "${TORCH_INSTALL_PREFIX}/lib") - list(APPEND TORCH_LIBRARIES ${EIGEN_BLAS_LIBRARY}) + append_torchlib_if_found(eigen_blas) + + if(@USE_FBGEMM@) + append_torchlib_if_found(fbgemm) endif() - find_library(CPUINFO_LIBRARY cpuinfo PATHS "${TORCH_INSTALL_PREFIX}/lib") - list(APPEND TORCH_LIBRARIES ${CPUINFO_LIBRARY}) + if(@USE_MKLDNN@) + append_torchlib_if_found(dnnl mkldnn) + endif() - find_library(CLOG_LIBRARY clog PATHS "${TORCH_INSTALL_PREFIX}/lib") - list(APPEND TORCH_LIBRARIES ${CLOG_LIBRARY}) + append_torchlib_if_found(sleef asmjit) endif() if(@USE_CUDA@) if(MSVC) - set(NVTOOLEXT_HOME "C:/Program Files/NVIDIA Corporation/NvToolsExt") - if($ENV{NVTOOLEXT_HOME}) - set(NVTOOLEXT_HOME $ENV{NVTOOLEXT_HOME}) + if(NOT NVTOOLEXT_HOME) + set(NVTOOLEXT_HOME "C:/Program Files/NVIDIA Corporation/NvToolsExt") + endif() + if(DEFINED ENV{NVTOOLSEXT_PATH}) + set(NVTOOLEXT_HOME $ENV{NVTOOLSEXT_PATH}) endif() set(TORCH_CUDA_LIBRARIES ${NVTOOLEXT_HOME}/lib/x64/nvToolsExt64_1.lib @@ -112,8 +151,10 @@ if(@USE_CUDA@) ${LIBNVTOOLSEXT} ${CUDA_LIBRARIES}) endif() - find_library(C10_CUDA_LIBRARY c10_cuda PATHS "${TORCH_INSTALL_PREFIX}/lib") - list(APPEND TORCH_CUDA_LIBRARIES ${C10_CUDA_LIBRARY}) + if(@BUILD_SHARED_LIBS@) + find_library(C10_CUDA_LIBRARY c10_cuda PATHS "${TORCH_INSTALL_PREFIX}/lib") + list(APPEND TORCH_CUDA_LIBRARIES ${C10_CUDA_LIBRARY}) + endif() list(APPEND TORCH_LIBRARIES ${TORCH_CUDA_LIBRARIES}) endif() diff --git a/cmake/Utils.cmake b/cmake/Utils.cmake deleted file mode 100644 index 3838b1fe95496..0000000000000 --- a/cmake/Utils.cmake +++ /dev/null @@ -1,240 +0,0 @@ -################################################################################################ -# Exclude and prepend functionalities -function(exclude OUTPUT INPUT) -set(EXCLUDES ${ARGN}) -foreach(EXCLUDE ${EXCLUDES}) - list(REMOVE_ITEM INPUT "${EXCLUDE}") -endforeach() -set(${OUTPUT} ${INPUT} PARENT_SCOPE) -endfunction(exclude) - -function(prepend OUTPUT PREPEND) -set(OUT "") -foreach(ITEM ${ARGN}) - list(APPEND OUT "${PREPEND}${ITEM}") -endforeach() -set(${OUTPUT} ${OUT} PARENT_SCOPE) -endfunction(prepend) - - -################################################################################################ -# Clears variables from list -# Usage: -# caffe_clear_vars() -macro(caffe_clear_vars) - foreach(_var ${ARGN}) - unset(${_var}) - endforeach() -endmacro() - -################################################################################################ -# Prints list element per line -# Usage: -# caffe_print_list() -function(caffe_print_list) - foreach(e ${ARGN}) - message(STATUS ${e}) - endforeach() -endfunction() - -################################################################################################ -# Reads set of version defines from the header file -# Usage: -# caffe_parse_header( ..) -macro(caffe_parse_header FILENAME FILE_VAR) - set(vars_regex "") - set(__parnet_scope OFF) - set(__add_cache OFF) - foreach(name ${ARGN}) - if("${name}" STREQUAL "PARENT_SCOPE") - set(__parnet_scope ON) - elseif("${name}" STREQUAL "CACHE") - set(__add_cache ON) - elseif(vars_regex) - set(vars_regex "${vars_regex}|${name}") - else() - set(vars_regex "${name}") - endif() - endforeach() - if(EXISTS "${FILENAME}") - file(STRINGS "${FILENAME}" ${FILE_VAR} REGEX "#define[ \t]+(${vars_regex})[ \t]+[0-9]+" ) - else() - unset(${FILE_VAR}) - endif() - foreach(name ${ARGN}) - if(NOT "${name}" STREQUAL "PARENT_SCOPE" AND NOT "${name}" STREQUAL "CACHE") - if(${FILE_VAR}) - if(${FILE_VAR} MATCHES ".+[ \t]${name}[ \t]+([0-9]+).*") - string(REGEX REPLACE ".+[ \t]${name}[ \t]+([0-9]+).*" "\\1" ${name} "${${FILE_VAR}}") - else() - set(${name} "") - endif() - if(__add_cache) - set(${name} ${${name}} CACHE INTERNAL "${name} parsed from ${FILENAME}" FORCE) - elseif(__parnet_scope) - set(${name} "${${name}}" PARENT_SCOPE) - endif() - else() - unset(${name} CACHE) - endif() - endif() - endforeach() -endmacro() - -################################################################################################ -# Reads single version define from the header file and parses it -# Usage: -# caffe_parse_header_single_define( ) -function(caffe_parse_header_single_define LIBNAME HDR_PATH VARNAME) - set(${LIBNAME}_H "") - if(EXISTS "${HDR_PATH}") - file(STRINGS "${HDR_PATH}" ${LIBNAME}_H REGEX "^#define[ \t]+${VARNAME}[ \t]+\"[^\"]*\".*$" LIMIT_COUNT 1) - endif() - - if(${LIBNAME}_H) - string(REGEX REPLACE "^.*[ \t]${VARNAME}[ \t]+\"([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_MAJOR "${${LIBNAME}_H}") - string(REGEX REPLACE "^.*[ \t]${VARNAME}[ \t]+\"[0-9]+\\.([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_MINOR "${${LIBNAME}_H}") - string(REGEX REPLACE "^.*[ \t]${VARNAME}[ \t]+\"[0-9]+\\.[0-9]+\\.([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_PATCH "${${LIBNAME}_H}") - set(${LIBNAME}_VERSION_MAJOR ${${LIBNAME}_VERSION_MAJOR} ${ARGN} PARENT_SCOPE) - set(${LIBNAME}_VERSION_MINOR ${${LIBNAME}_VERSION_MINOR} ${ARGN} PARENT_SCOPE) - set(${LIBNAME}_VERSION_PATCH ${${LIBNAME}_VERSION_PATCH} ${ARGN} PARENT_SCOPE) - set(${LIBNAME}_VERSION_STRING "${${LIBNAME}_VERSION_MAJOR}.${${LIBNAME}_VERSION_MINOR}.${${LIBNAME}_VERSION_PATCH}" PARENT_SCOPE) - - # append a TWEAK version if it exists: - set(${LIBNAME}_VERSION_TWEAK "") - if("${${LIBNAME}_H}" MATCHES "^.*[ \t]${VARNAME}[ \t]+\"[0-9]+\\.[0-9]+\\.[0-9]+\\.([0-9]+).*$") - set(${LIBNAME}_VERSION_TWEAK "${CMAKE_MATCH_1}" ${ARGN} PARENT_SCOPE) - endif() - if(${LIBNAME}_VERSION_TWEAK) - set(${LIBNAME}_VERSION_STRING "${${LIBNAME}_VERSION_STRING}.${${LIBNAME}_VERSION_TWEAK}" ${ARGN} PARENT_SCOPE) - else() - set(${LIBNAME}_VERSION_STRING "${${LIBNAME}_VERSION_STRING}" ${ARGN} PARENT_SCOPE) - endif() - endif() -endfunction() - -################################################################################################ -# Parses a version string that might have values beyond major, minor, and patch -# and set version variables for the library. -# Usage: -# caffe2_parse_version_str( ) -function(caffe2_parse_version_str LIBNAME VERSIONSTR) - string(REGEX REPLACE "^([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_MAJOR "${VERSIONSTR}") - string(REGEX REPLACE "^[0-9]+\\.([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_MINOR "${VERSIONSTR}") - string(REGEX REPLACE "[0-9]+\\.[0-9]+\\.([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_PATCH "${VERSIONSTR}") - set(${LIBNAME}_VERSION_MAJOR ${${LIBNAME}_VERSION_MAJOR} ${ARGN} PARENT_SCOPE) - set(${LIBNAME}_VERSION_MINOR ${${LIBNAME}_VERSION_MINOR} ${ARGN} PARENT_SCOPE) - set(${LIBNAME}_VERSION_PATCH ${${LIBNAME}_VERSION_PATCH} ${ARGN} PARENT_SCOPE) - set(${LIBNAME}_VERSION "${${LIBNAME}_VERSION_MAJOR}.${${LIBNAME}_VERSION_MINOR}.${${LIBNAME}_VERSION_PATCH}" PARENT_SCOPE) -endfunction() - -### -# Removes common indentation from a block of text to produce code suitable for -# setting to `python -c`, or using with pycmd. This allows multiline code to be -# nested nicely in the surrounding code structure. -# -# This function respsects PYTHON_EXECUTABLE if it defined, otherwise it uses -# `python` and hopes for the best. An error will be thrown if it is not found. -# -# Args: -# outvar : variable that will hold the stdout of the python command -# text : text to remove indentation from -# -function(dedent outvar text) - # Use PYTHON_EXECUTABLE if it is defined, otherwise default to python - if("${PYTHON_EXECUTABLE}" STREQUAL "") - set(_python_exe "python") - else() - set(_python_exe "${PYTHON_EXECUTABLE}") - endif() - set(_fixup_cmd "import sys; from textwrap import dedent; print(dedent(sys.stdin.read()))") - file(WRITE "${CMAKE_BINARY_DIR}/indented.txt" "${text}") - execute_process( - COMMAND "${_python_exe}" -c "${_fixup_cmd}" - INPUT_FILE "${CMAKE_BINARY_DIR}/indented.txt" - RESULT_VARIABLE _dedent_exitcode - OUTPUT_VARIABLE _dedent_text) - if(NOT _dedent_exitcode EQUAL 0) - message(ERROR " Failed to remove indentation from: \n\"\"\"\n${text}\n\"\"\" - Python dedent failed with error code: ${_dedent_exitcode}") - message(FATAL_ERROR " Python dedent failed with error code: ${_dedent_exitcode}") - endif() - # Remove supurflous newlines (artifacts of print) - string(STRIP "${_dedent_text}" _dedent_text) - set(${outvar} "${_dedent_text}" PARENT_SCOPE) -endfunction() - - -function(pycmd_no_exit outvar exitcode cmd) - # Use PYTHON_EXECUTABLE if it is defined, otherwise default to python - if("${PYTHON_EXECUTABLE}" STREQUAL "") - set(_python_exe "python") - else() - set(_python_exe "${PYTHON_EXECUTABLE}") - endif() - # run the actual command - execute_process( - COMMAND "${_python_exe}" -c "${cmd}" - RESULT_VARIABLE _exitcode - OUTPUT_VARIABLE _output) - # Remove supurflous newlines (artifacts of print) - string(STRIP "${_output}" _output) - set(${outvar} "${_output}" PARENT_SCOPE) - set(${exitcode} "${_exitcode}" PARENT_SCOPE) -endfunction() - - -### -# Helper function to run `python -c ""` and capture the results of stdout -# -# Runs a python command and populates an outvar with the result of stdout. -# Common indentation in the text of `cmd` is removed before the command is -# executed, so the caller does not need to worry about indentation issues. -# -# This function respsects PYTHON_EXECUTABLE if it defined, otherwise it uses -# `python` and hopes for the best. An error will be thrown if it is not found. -# -# Args: -# outvar : variable that will hold the stdout of the python command -# cmd : text representing a (possibly multiline) block of python code -# -function(pycmd outvar cmd) - dedent(_dedent_cmd "${cmd}") - pycmd_no_exit(_output _exitcode "${_dedent_cmd}") - - if(NOT _exitcode EQUAL 0) - message(ERROR " Failed when running python code: \"\"\"\n${_dedent_cmd}\n\"\"\"") - message(FATAL_ERROR " Python command failed with error code: ${_exitcode}") - endif() - # Remove supurflous newlines (artifacts of print) - string(STRIP "${_output}" _output) - set(${outvar} "${_output}" PARENT_SCOPE) -endfunction() - -### -# Helper function to print out everything that cmake knows about a target -# -# Copied from https://stackoverflow.com/questions/32183975/how-to-print-all-the-properties-of-a-target-in-cmake -# This isn't called anywhere, but it's very useful when debugging cmake -# NOTE: This doesn't work for INTERFACE_LIBRARY or INTERFACE_LINK_LIBRARY targets - -function(print_target_properties tgt) - if(NOT TARGET ${tgt}) - message("There is no target named '${tgt}'") - return() - endif() - - # Get a list of all cmake properties TODO cache this lazily somehow - execute_process(COMMAND cmake --help-property-list OUTPUT_VARIABLE CMAKE_PROPERTY_LIST) - string(REGEX REPLACE ";" "\\\\;" CMAKE_PROPERTY_LIST "${CMAKE_PROPERTY_LIST}") - string(REGEX REPLACE "\n" ";" CMAKE_PROPERTY_LIST "${CMAKE_PROPERTY_LIST}") - - foreach(prop ${CMAKE_PROPERTY_LIST}) - string(REPLACE "" "${CMAKE_BUILD_TYPE}" prop ${prop}) - get_property(propval TARGET ${tgt} PROPERTY ${prop} SET) - if(propval) - get_target_property(propval ${tgt} ${prop}) - message("${tgt} ${prop} = ${propval}") - endif() - endforeach(prop) -endfunction(print_target_properties) diff --git a/cmake/iOS.cmake b/cmake/iOS.cmake index cb31605f95430..a43874b6ffea0 100644 --- a/cmake/iOS.cmake +++ b/cmake/iOS.cmake @@ -165,11 +165,11 @@ elseif(IOS_PLATFORM STREQUAL "WATCHOS") set(DEFAULT_IOS_ARCH "armv7k;arm64_32") endif() -set(IOS_ARCH ${DEFAULT_IOS_ARCH} CACHE string "Build architecture for iOS") -set(CMAKE_OSX_ARCHITECTURES ${IOS_ARCH} CACHE string "Build architecture for iOS") +set(IOS_ARCH ${DEFAULT_IOS_ARCH} CACHE STRING "Build architecture for iOS") +set(CMAKE_OSX_ARCHITECTURES ${IOS_ARCH} CACHE STRING "Build architecture for iOS") # Set the find root to the iOS developer roots and to user defined paths -set(CMAKE_FIND_ROOT_PATH ${CMAKE_IOS_DEVELOPER_ROOT} ${CMAKE_IOS_SDK_ROOT} ${CMAKE_PREFIX_PATH} CACHE string "iOS find search path root") +set(CMAKE_FIND_ROOT_PATH ${CMAKE_IOS_DEVELOPER_ROOT} ${CMAKE_IOS_SDK_ROOT} ${CMAKE_PREFIX_PATH} CACHE STRING "iOS find search path root") # default to searching for frameworks first set(CMAKE_FIND_FRAMEWORK FIRST) diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 386280cbb4ffd..f5e77dc7f52a0 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -154,6 +154,7 @@ if(HIP_FOUND) ### Remove setting of Flags when FindHIP.CMake PR #558 is accepted.### set(hip_DIR ${HIP_PATH}/lib/cmake/hip) + set(hsa-runtime64_DIR ${ROCM_PATH}/lib/cmake/hsa-runtime64) set(AMDDeviceLibs_DIR ${ROCM_PATH}/lib/cmake/AMDDeviceLibs) set(amd_comgr_DIR ${ROCM_PATH}/lib/cmake/amd_comgr) set(rocrand_DIR ${ROCRAND_PATH}/lib/cmake/rocrand) @@ -168,6 +169,7 @@ if(HIP_FOUND) set(rocthrust_DIR ${ROCTHRUST_PATH}/lib/cmake/rocthrust) find_package_and_print_version(hip REQUIRED) + find_package_and_print_version(hsa-runtime64 REQUIRED) find_package_and_print_version(amd_comgr REQUIRED) find_package_and_print_version(rocrand REQUIRED) find_package_and_print_version(hiprand REQUIRED) @@ -203,9 +205,4 @@ if(HIP_FOUND) # roctx is part of roctracer find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCTRACER_PATH}/lib) set(roctracer_INCLUDE_DIRS ${ROCTRACER_PATH}/include) - - # Necessary includes for building PyTorch since we include HIP headers that depend on hcc/hsa headers. - set(hcc_INCLUDE_DIRS ${HCC_PATH}/include) - set(hsa_INCLUDE_DIRS ${HSA_PATH}/include) - endif() diff --git a/cmake/public/cuda.cmake b/cmake/public/cuda.cmake index 8b60915f7e007..a418724f6256a 100644 --- a/cmake/public/cuda.cmake +++ b/cmake/public/cuda.cmake @@ -474,11 +474,13 @@ foreach(diag cc_clobber_ignored integer_sign_change useless_using_declaration unsigned_compare_with_zero declared_but_not_referenced bad_friend_decl) - list(APPEND CUDA_NVCC_FLAGS -Xcudafe --diag_suppress=${diag}) + list(APPEND SUPPRESS_WARNING_FLAGS --diag_suppress=${diag}) endforeach() +string(REPLACE ";" "," SUPPRESS_WARNING_FLAGS "${SUPPRESS_WARNING_FLAGS}") +list(APPEND CUDA_NVCC_FLAGS -Xcudafe ${SUPPRESS_WARNING_FLAGS}) # Set C++14 support -set(CUDA_PROPAGATE_HOST_FLAGS_BLACKLIST "-Werror") +set(CUDA_PROPAGATE_HOST_FLAGS_BLOCKLIST "-Werror") if(MSVC) list(APPEND CUDA_NVCC_FLAGS "--Werror" "cross-execution-space-call") list(APPEND CUDA_NVCC_FLAGS "--no-host-device-move-forward") @@ -490,7 +492,7 @@ endif() # OpenMP flags for NVCC with Clang-cl if("${CMAKE_CXX_SIMULATE_ID}" STREQUAL "MSVC" AND "${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") - list(APPEND CUDA_PROPAGATE_HOST_FLAGS_BLACKLIST "-Xclang" "-fopenmp") + list(APPEND CUDA_PROPAGATE_HOST_FLAGS_BLOCKLIST "-Xclang" "-fopenmp") if(MSVC_TOOLSET_VERSION LESS 142) list(APPEND CUDA_NVCC_FLAGS "-Xcompiler" "-openmp") else() diff --git a/cmake/public/utils.cmake b/cmake/public/utils.cmake index 0425eaee46f5f..f7455f078040c 100644 --- a/cmake/public/utils.cmake +++ b/cmake/public/utils.cmake @@ -1,3 +1,185 @@ +################################################################################################ +# Exclude and prepend functionalities +function(exclude OUTPUT INPUT) +set(EXCLUDES ${ARGN}) +foreach(EXCLUDE ${EXCLUDES}) + list(REMOVE_ITEM INPUT "${EXCLUDE}") +endforeach() +set(${OUTPUT} ${INPUT} PARENT_SCOPE) +endfunction(exclude) + +function(prepend OUTPUT PREPEND) +set(OUT "") +foreach(ITEM ${ARGN}) + list(APPEND OUT "${PREPEND}${ITEM}") +endforeach() +set(${OUTPUT} ${OUT} PARENT_SCOPE) +endfunction(prepend) + + +################################################################################################ +# Clears variables from list +# Usage: +# caffe_clear_vars() +macro(caffe_clear_vars) + foreach(_var ${ARGN}) + unset(${_var}) + endforeach() +endmacro() + +################################################################################################ +# Prints list element per line +# Usage: +# caffe_print_list() +function(caffe_print_list) + foreach(e ${ARGN}) + message(STATUS ${e}) + endforeach() +endfunction() + +################################################################################################ +# Reads set of version defines from the header file +# Usage: +# caffe_parse_header( ..) +macro(caffe_parse_header FILENAME FILE_VAR) + set(vars_regex "") + set(__parnet_scope OFF) + set(__add_cache OFF) + foreach(name ${ARGN}) + if("${name}" STREQUAL "PARENT_SCOPE") + set(__parnet_scope ON) + elseif("${name}" STREQUAL "CACHE") + set(__add_cache ON) + elseif(vars_regex) + set(vars_regex "${vars_regex}|${name}") + else() + set(vars_regex "${name}") + endif() + endforeach() + if(EXISTS "${FILENAME}") + file(STRINGS "${FILENAME}" ${FILE_VAR} REGEX "#define[ \t]+(${vars_regex})[ \t]+[0-9]+" ) + else() + unset(${FILE_VAR}) + endif() + foreach(name ${ARGN}) + if(NOT "${name}" STREQUAL "PARENT_SCOPE" AND NOT "${name}" STREQUAL "CACHE") + if(${FILE_VAR}) + if(${FILE_VAR} MATCHES ".+[ \t]${name}[ \t]+([0-9]+).*") + string(REGEX REPLACE ".+[ \t]${name}[ \t]+([0-9]+).*" "\\1" ${name} "${${FILE_VAR}}") + else() + set(${name} "") + endif() + if(__add_cache) + set(${name} ${${name}} CACHE INTERNAL "${name} parsed from ${FILENAME}" FORCE) + elseif(__parnet_scope) + set(${name} "${${name}}" PARENT_SCOPE) + endif() + else() + unset(${name} CACHE) + endif() + endif() + endforeach() +endmacro() + +################################################################################################ +# Parses a version string that might have values beyond major, minor, and patch +# and set version variables for the library. +# Usage: +# caffe2_parse_version_str( ) +function(caffe2_parse_version_str LIBNAME VERSIONSTR) + string(REGEX REPLACE "^([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_MAJOR "${VERSIONSTR}") + string(REGEX REPLACE "^[0-9]+\\.([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_MINOR "${VERSIONSTR}") + string(REGEX REPLACE "[0-9]+\\.[0-9]+\\.([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_PATCH "${VERSIONSTR}") + set(${LIBNAME}_VERSION_MAJOR ${${LIBNAME}_VERSION_MAJOR} ${ARGN} PARENT_SCOPE) + set(${LIBNAME}_VERSION_MINOR ${${LIBNAME}_VERSION_MINOR} ${ARGN} PARENT_SCOPE) + set(${LIBNAME}_VERSION_PATCH ${${LIBNAME}_VERSION_PATCH} ${ARGN} PARENT_SCOPE) + set(${LIBNAME}_VERSION "${${LIBNAME}_VERSION_MAJOR}.${${LIBNAME}_VERSION_MINOR}.${${LIBNAME}_VERSION_PATCH}" PARENT_SCOPE) +endfunction() + +### +# Removes common indentation from a block of text to produce code suitable for +# setting to `python -c`, or using with pycmd. This allows multiline code to be +# nested nicely in the surrounding code structure. +# +# This function respsects PYTHON_EXECUTABLE if it defined, otherwise it uses +# `python` and hopes for the best. An error will be thrown if it is not found. +# +# Args: +# outvar : variable that will hold the stdout of the python command +# text : text to remove indentation from +# +function(dedent outvar text) + # Use PYTHON_EXECUTABLE if it is defined, otherwise default to python + if("${PYTHON_EXECUTABLE}" STREQUAL "") + set(_python_exe "python") + else() + set(_python_exe "${PYTHON_EXECUTABLE}") + endif() + set(_fixup_cmd "import sys; from textwrap import dedent; print(dedent(sys.stdin.read()))") + file(WRITE "${CMAKE_BINARY_DIR}/indented.txt" "${text}") + execute_process( + COMMAND "${_python_exe}" -c "${_fixup_cmd}" + INPUT_FILE "${CMAKE_BINARY_DIR}/indented.txt" + RESULT_VARIABLE _dedent_exitcode + OUTPUT_VARIABLE _dedent_text) + if(NOT _dedent_exitcode EQUAL 0) + message(ERROR " Failed to remove indentation from: \n\"\"\"\n${text}\n\"\"\" + Python dedent failed with error code: ${_dedent_exitcode}") + message(FATAL_ERROR " Python dedent failed with error code: ${_dedent_exitcode}") + endif() + # Remove supurflous newlines (artifacts of print) + string(STRIP "${_dedent_text}" _dedent_text) + set(${outvar} "${_dedent_text}" PARENT_SCOPE) +endfunction() + + +function(pycmd_no_exit outvar exitcode cmd) + # Use PYTHON_EXECUTABLE if it is defined, otherwise default to python + if("${PYTHON_EXECUTABLE}" STREQUAL "") + set(_python_exe "python") + else() + set(_python_exe "${PYTHON_EXECUTABLE}") + endif() + # run the actual command + execute_process( + COMMAND "${_python_exe}" -c "${cmd}" + RESULT_VARIABLE _exitcode + OUTPUT_VARIABLE _output) + # Remove supurflous newlines (artifacts of print) + string(STRIP "${_output}" _output) + set(${outvar} "${_output}" PARENT_SCOPE) + set(${exitcode} "${_exitcode}" PARENT_SCOPE) +endfunction() + + +### +# Helper function to run `python -c ""` and capture the results of stdout +# +# Runs a python command and populates an outvar with the result of stdout. +# Common indentation in the text of `cmd` is removed before the command is +# executed, so the caller does not need to worry about indentation issues. +# +# This function respsects PYTHON_EXECUTABLE if it defined, otherwise it uses +# `python` and hopes for the best. An error will be thrown if it is not found. +# +# Args: +# outvar : variable that will hold the stdout of the python command +# cmd : text representing a (possibly multiline) block of python code +# +function(pycmd outvar cmd) + dedent(_dedent_cmd "${cmd}") + pycmd_no_exit(_output _exitcode "${_dedent_cmd}") + + if(NOT _exitcode EQUAL 0) + message(ERROR " Failed when running python code: \"\"\"\n${_dedent_cmd}\n\"\"\"") + message(FATAL_ERROR " Python command failed with error code: ${_exitcode}") + endif() + # Remove supurflous newlines (artifacts of print) + string(STRIP "${_output}" _output) + set(${outvar} "${_output}" PARENT_SCOPE) +endfunction() + + ############################################################################## # Macro to update cached options. macro(caffe2_update_option variable value) @@ -156,21 +338,6 @@ function(caffe2_hip_binary_target target_name_or_src) target_include_directories(${__target} PRIVATE ${Caffe2_HIP_INCLUDE}) endfunction() -############################################################################## -# Multiplex between loading executables for CUDA versus HIP (AMD Software Stack). -# Usage: -# torch_cuda_based_add_executable(cuda_target) -# -macro(torch_cuda_based_add_executable cuda_target) - if(USE_ROCM) - hip_add_executable(${cuda_target} ${ARGN}) - elseif(USE_CUDA) - cuda_add_executable(${cuda_target} ${ARGN}) - else() - - endif() -endmacro() - ############################################################################## # Multiplex between adding libraries for CUDA versus HIP (AMD Software Stack). @@ -222,6 +389,11 @@ endmacro() function(torch_compile_options libname) set_property(TARGET ${libname} PROPERTY CXX_STANDARD 14) + # ---[ Check if warnings should be errors. + if(WERROR) + target_compile_options(${libname} PRIVATE -Werror) + endif() + if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) # until they can be unified, keep these lists synced with setup.py if(MSVC) @@ -274,7 +446,7 @@ function(torch_compile_options libname) if(MSVC) elseif(WERROR) - target_compile_options(${libname} PRIVATE -Werror -Wno-strict-overflow) + target_compile_options(${libname} PRIVATE -Wno-strict-overflow) endif() endif() @@ -292,11 +464,6 @@ function(torch_compile_options libname) # Use -O2 for release builds (-O3 doesn't improve perf, and -Os results in perf regression) target_compile_options(${libname} PRIVATE "$<$,$>:-O2>") - # ---[ Check if warnings should be errors. - # TODO: Dedupe with WERROR check above - if(WERROR) - target_compile_options(${libname} PRIVATE -Werror) - endif() endfunction() @@ -317,3 +484,4 @@ function(torch_set_target_props libname) set_target_properties(${libname} PROPERTIES STATIC_LIBRARY_FLAGS_DEBUG "/NODEFAULTLIB:${VCOMP_LIB}d") endif() endfunction() + diff --git a/codecov.yml b/codecov.yml index 79a3cd8057b19..2d12cb5c84b41 100644 --- a/codecov.yml +++ b/codecov.yml @@ -3,5 +3,25 @@ coverage: project: default: threshold: 1% +codecov: + notify: + # Code coverage is collected by 5 configs: codecov_test[12], onnx[12] and windows_test1 + after_n_builds: 5 +comment: + layout: "diff" + behavior: once + require_changes: true + require_base: yes + require_head: yes + after_n_builds: 5 + branches: + - "master" +# Disable inline comments that this code is not covered +github_checks: + annotations: false fixes: - "/opt/conda/lib/python3.8/site-packages/::project/" + - "C:/Users/circleci/project/build/win_tmp/build/::project/" +ignore: + - "caffe2" + - "third_party" diff --git a/docker.Makefile b/docker.Makefile index 18acced1de8d9..6b843fa9c1b30 100644 --- a/docker.Makefile +++ b/docker.Makefile @@ -1,52 +1,66 @@ -DOCKER_REGISTRY = docker.io -DOCKER_ORG = $(shell docker info 2>/dev/null | sed '/Username:/!d;s/.* //') -DOCKER_IMAGE = pytorch -DOCKER_FULL_NAME = $(DOCKER_REGISTRY)/$(DOCKER_ORG)/$(DOCKER_IMAGE) +DOCKER_REGISTRY = docker.io +DOCKER_ORG = $(shell docker info 2>/dev/null | sed '/Username:/!d;s/.* //') +DOCKER_IMAGE = pytorch +DOCKER_FULL_NAME = $(DOCKER_REGISTRY)/$(DOCKER_ORG)/$(DOCKER_IMAGE) ifeq ("$(DOCKER_ORG)","") $(warning WARNING: No docker user found using results from whoami) -DOCKER_ORG = $(shell whoami) +DOCKER_ORG = $(shell whoami) endif -BASE_RUNTIME = ubuntu:18.04 -BASE_DEVEL = nvidia/cuda:11.0-cudnn8-devel-ubuntu18.04 +CUDA_VERSION = 11.0 +CUDNN_VERSION = 8 +BASE_RUNTIME = ubuntu:18.04 +BASE_DEVEL = nvidia/cuda:$(CUDA_VERSION)-cudnn$(CUDNN_VERSION)-devel-ubuntu18.04 # The conda channel to use to install pytorch / torchvision -INSTALL_CHANNEL = pytorch +INSTALL_CHANNEL = pytorch -PYTHON_VERSION = 3.7 +PYTHON_VERSION = 3.7 +PYTORCH_VERSION = $(shell git describe --tags) # Can be either official / dev -BUILD_TYPE = dev -BUILD_PROGRESS = auto -BUILD_ARGS = --build-arg BASE_IMAGE=$(BASE_IMAGE) --build-arg PYTHON_VERSION=$(PYTHON_VERSION) --build-arg INSTALL_CHANNEL=$(INSTALL_CHANNEL) -DOCKER_BUILD = DOCKER_BUILDKIT=1 docker build --progress=$(BUILD_PROGRESS) --target $(BUILD_TYPE) -t $(DOCKER_FULL_NAME):$(DOCKER_TAG) $(BUILD_ARGS) . -DOCKER_PUSH = docker push $(DOCKER_FULL_NAME):$(DOCKER_TAG) +BUILD_TYPE = dev +BUILD_PROGRESS = auto +BUILD_ARGS = --build-arg BASE_IMAGE=$(BASE_IMAGE) \ + --build-arg PYTHON_VERSION=$(PYTHON_VERSION) \ + --build-arg CUDA_VERSION=$(CUDA_VERSION) \ + --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION) \ + --build-arg INSTALL_CHANNEL=$(INSTALL_CHANNEL) +EXTRA_DOCKER_BUILD_FLAGS ?= +DOCKER_BUILD = DOCKER_BUILDKIT=1 \ + docker build \ + --progress=$(BUILD_PROGRESS) \ + $(EXTRA_DOCKER_BUILD_FLAGS) \ + --target $(BUILD_TYPE) \ + -t $(DOCKER_FULL_NAME):$(DOCKER_TAG) \ + $(BUILD_ARGS) . +DOCKER_PUSH = docker push $(DOCKER_FULL_NAME):$(DOCKER_TAG) .PHONY: all all: devel-image .PHONY: devel-image devel-image: BASE_IMAGE := $(BASE_DEVEL) -devel-image: DOCKER_TAG := $(shell git describe --tags)-devel +devel-image: DOCKER_TAG := $(PYTORCH_VERSION)-devel devel-image: $(DOCKER_BUILD) .PHONY: devel-image devel-push: BASE_IMAGE := $(BASE_DEVEL) -devel-push: DOCKER_TAG := $(shell git describe --tags)-devel +devel-push: DOCKER_TAG := $(PYTORCH_VERSION)-devel devel-push: $(DOCKER_PUSH) .PHONY: runtime-image runtime-image: BASE_IMAGE := $(BASE_RUNTIME) -runtime-image: DOCKER_TAG := $(shell git describe --tags)-runtime +runtime-image: DOCKER_TAG := $(PYTORCH_VERSION)-runtime runtime-image: $(DOCKER_BUILD) docker tag $(DOCKER_FULL_NAME):$(DOCKER_TAG) $(DOCKER_FULL_NAME):latest .PHONY: runtime-image runtime-push: BASE_IMAGE := $(BASE_RUNTIME) -runtime-push: DOCKER_TAG := $(shell git describe --tags)-runtime +runtime-push: DOCKER_TAG := $(PYTORCH_VERSION)-runtime runtime-push: $(DOCKER_PUSH) diff --git a/docker/caffe2/jenkins/centos-rocm/Dockerfile b/docker/caffe2/jenkins/centos-rocm/Dockerfile index 19fdaa176b921..5c2dacf304fcf 100644 --- a/docker/caffe2/jenkins/centos-rocm/Dockerfile +++ b/docker/caffe2/jenkins/centos-rocm/Dockerfile @@ -33,7 +33,6 @@ ENV PATH /opt/rocm/bin:$PATH ENV PATH /opt/rocm/hcc/bin:$PATH ENV PATH /opt/rocm/hip/bin:$PATH ENV PATH /opt/rocm/opencl/bin:$PATH -ENV HIP_PLATFORM hcc ENV LC_ALL en_US.utf8 ENV LANG en_US.utf8 diff --git a/docker/caffe2/jenkins/common/install_python.sh b/docker/caffe2/jenkins/common/install_python.sh index 48a47b2711071..19633d451ab3d 100755 --- a/docker/caffe2/jenkins/common/install_python.sh +++ b/docker/caffe2/jenkins/common/install_python.sh @@ -135,11 +135,6 @@ if [ -z "${INSTALL_SETUPTOOLS}" ]; then pip install -U pip setuptools!=38.5.2 fi -# tornado 5.0 requires Python 2.7.9+ or 3.4+ -if [[ $($PYTHON -c 'import sys; print(int(sys.version_info <= (2, 7, 9) or sys.version_info <= (3, 4)))' == 1) ]]; then - pip install 'tornado<5' -fi - # Need networkx 2.0 because bellmand_ford was moved in 2.1 . Scikit-image by # defaults installs the most recent networkx version, so we install this lower # version explicitly before scikit-image pulls it in as a dependency diff --git a/docker/caffe2/jenkins/ubuntu-rocm/Dockerfile b/docker/caffe2/jenkins/ubuntu-rocm/Dockerfile index dbec35e06c81b..f6624d077cfce 100644 --- a/docker/caffe2/jenkins/ubuntu-rocm/Dockerfile +++ b/docker/caffe2/jenkins/ubuntu-rocm/Dockerfile @@ -60,7 +60,6 @@ ENV PATH /opt/rocm/bin:$PATH ENV PATH /opt/rocm/hcc/bin:$PATH ENV PATH /opt/rocm/hip/bin:$PATH ENV PATH /opt/rocm/opencl/bin:$PATH -ENV HIP_PLATFORM hcc ENV LANG C.UTF-8 ENV LC_ALL C.UTF-8 diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000000000..471f0aa9f8885 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,2 @@ +Please see the [Writing documentation section of CONTRIBUTING.md](../CONTRIBUTING.md#writing-documentation) +for details on both writing and building the docs. diff --git a/docs/caffe2/process.py b/docs/caffe2/process.py index 9fa37e5fbb5aa..3b94b9d38502a 100644 --- a/docs/caffe2/process.py +++ b/docs/caffe2/process.py @@ -1,20 +1,21 @@ -#!/usr/bin/env python2 +#!/usr/bin/env python3 ## @package process # Module doxygen.process # Script to insert preamble for doxygen and regen API docs -import glob, os, shutil +import os +import shutil # Module caffe2...caffe2.python.control_test -def insert(originalfile,first_line,description): - with open(originalfile,'r') as f: +def insert(originalfile, first_line, description): + with open(originalfile, 'r') as f: f1 = f.readline() - if(f1.find(first_line)<0): + if(f1.find(first_line) < 0): docs = first_line + description + f1 - with open('newfile.txt','w') as f2: + with open('newfile.txt', 'w') as f2: f2.write(docs) f2.write(f.read()) - os.rename('newfile.txt',originalfile) + os.rename('newfile.txt', originalfile) else: print('already inserted') @@ -29,15 +30,15 @@ def insert(originalfile,first_line,description): for file in files: if (file.endswith(".py") and not file.endswith("_test.py") and not file.endswith("__.py")): filepath = os.path.join(root, file) - print("filepath: " + filepath) + print(("filepath: " + filepath)) directory = os.path.dirname(filepath)[2:] - directory = directory.replace("/",".") - print "directory: " + directory + directory = directory.replace("/", ".") + print("directory: " + directory) name = os.path.splitext(file)[0] first_line = "## @package " + name description = "\n# Module " + directory + "." + name + "\n" - print first_line,description - insert(filepath,first_line,description) + print(first_line, description) + insert(filepath, first_line, description) if os.path.exists("doxygen/doxygen-python"): print("Looks like you ran this before, so we need to cleanup those old files...") diff --git a/docs/cpp/requirements.txt b/docs/cpp/requirements.txt index 452aa3eadad04..731a0475be798 100644 --- a/docs/cpp/requirements.txt +++ b/docs/cpp/requirements.txt @@ -1,5 +1,5 @@ sphinx==3.1.2 -breathe==4.19.2 +breathe==4.25.0 exhale==0.2.3 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme bs4 diff --git a/docs/cpp/source/check-doxygen.sh b/docs/cpp/source/check-doxygen.sh index b258a41214187..6df0e68e280a3 100755 --- a/docs/cpp/source/check-doxygen.sh +++ b/docs/cpp/source/check-doxygen.sh @@ -20,6 +20,7 @@ python -m tools.codegen.gen python tools/setup_helpers/generate_code.py \ --declarations-path build/aten/src/ATen/Declarations.yaml \ + --native-functions-path aten/src/ATen/native/native_functions.yaml \ --nn-path aten/src popd diff --git a/docs/cpp/source/index.rst b/docs/cpp/source/index.rst index 2bfbe63f47c6e..39c63ddd5d7bf 100644 --- a/docs/cpp/source/index.rst +++ b/docs/cpp/source/index.rst @@ -1,20 +1,20 @@ PyTorch C++ API =============== -These pages provide documentation for the public portions of the PyTorch C++ +These pages provide the documentation for the public portions of the PyTorch C++ API. This API can roughly be divided into five parts: -- **ATen**: The foundational tensor and mathematical operation library on which all else is built; -- **Autograd**: Augments ATen with automatic differentiation; -- **C++ Frontend**: High level constructs for training and evaluation of machine learning models; -- **TorchScript**: An interface to the TorchScript JIT compiler and interpreter; +- **ATen**: The foundational tensor and mathematical operation library on which all else is built. +- **Autograd**: Augments ATen with automatic differentiation. +- **C++ Frontend**: High level constructs for training and evaluation of machine learning models. +- **TorchScript**: An interface to the TorchScript JIT compiler and interpreter. - **C++ Extensions**: A means of extending the Python API with custom C++ and CUDA routines. -Together, these building blocks form a research and +Combining, these building blocks form a research and production ready C++ library for tensor computation and dynamic neural networks with strong emphasis on GPU acceleration as well as fast CPU performance. It is currently in use at Facebook in research and -production; we look forward to welcoming more users of the PyTorch C++ API. +production; we are looking forward to welcome more users of the PyTorch C++ API. .. warning:: @@ -76,7 +76,7 @@ C++ Frontend ------------ The PyTorch C++ frontend provides a high level, pure C++ modeling interface for -neural network and general machine learning research and production use cases, +neural network and general ML(Machine Learning) research and production use cases, largely following the Python API in design and provided functionality. The C++ frontend includes the following: @@ -119,7 +119,7 @@ expanded on a continuous and active basis. TorchScript ----------- -TorchScript a representation of a PyTorch model that can be understood, +TorchScript is a representation of a PyTorch model that can be understood, compiled and serialized by the TorchScript compiler. Fundamentally, TorchScript is a programming language in its own right. It is a subset of Python using the PyTorch API. The C++ interface to TorchScript encompasses three primary pieces of @@ -150,7 +150,7 @@ CUDA to accelerate research in vanilla PyTorch setups. The C++ extension API does not add any new functionality to the PyTorch C++ API. Instead, it provides integration with Python setuptools as well as JIT compilation mechanisms that allow access to ATen, the autograd and other C++ APIs from -Python. To learn more about the C++ extension API, see +Python. To learn more about the C++ extension API, go through `this tutorial `_. Contents @@ -183,4 +183,4 @@ Acknowledgements This documentation website for the PyTorch C++ universe has been enabled by the `Exhale `_ project and generous investment of time and effort by its maintainer, `svenevs `_. -We thank Stephen for his work and his help with the PyTorch C++ documentation. +We thank Stephen for his work and his efforts providing help with the PyTorch C++ documentation. diff --git a/docs/cpp/source/notes/tensor_cuda_stream.rst b/docs/cpp/source/notes/tensor_cuda_stream.rst new file mode 100644 index 0000000000000..9d9e14704ff96 --- /dev/null +++ b/docs/cpp/source/notes/tensor_cuda_stream.rst @@ -0,0 +1,276 @@ +Tensor CUDA Stream API +====================== + +A `CUDA Stream`_ is a linear sequence of execution that belongs to a specific CUDA device. +The PyTorch C++ API supports CUDA streams with the CUDAStream class and useful helper functions to make streaming operations easy. +You can find them in `CUDAStream.h`_. This note provides more details on how to use Pytorch C++ CUDA Stream APIs. + +.. _CUDA Stream: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#streams +.. _CUDAStream.h: https://pytorch.org/cppdocs/api/file_c10_cuda_CUDAStream.h.html#file-c10-cuda-cudastream-h +.. _CUDAStreamGuard.h: https://pytorch.org/cppdocs/api/structc10_1_1cuda_1_1_c_u_d_a_stream_guard.html + +Acquiring CUDA stream +********************* + +Pytorch's C++ API provides the following ways to acquire CUDA stream: + +1. Acquire a new stream from the CUDA stream pool, streams are preallocated from the pool and returned in a round-robin fashion. + +.. code-block:: cpp + + CUDAStream getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1); + +.. tip:: + + You can request a stream from the high priority pool by setting isHighPriority to true, or a stream for a specific device + by setting device index (defaulting to the current CUDA stream's device index). + +2. Acquire the default CUDA stream for the passed CUDA device, or for the current device if no device index is passed. + +.. code-block:: cpp + + CUDAStream getDefaultCUDAStream(DeviceIndex device_index = -1); + +.. tip:: + + The default stream is where most computation occurs when you aren't explicitly using streams. + +3. Acquire the current CUDA stream, for the CUDA device with index ``device_index``, or for the current device if no device index is passed. + +.. code-block:: cpp + + CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1); + +.. tip:: + + The current CUDA stream will usually be the default CUDA stream for the device, but it may be different if someone + called ``setCurrentCUDAStream`` or used ``StreamGuard`` or ``CUDAStreamGuard``. + + + +Set CUDA stream +*************** + +Pytorch's C++ API provides the following ways to set CUDA stream: + +1. Set the current stream on the device of the passed in stream to be the passed in stream. + +.. code-block:: cpp + + void setCurrentCUDAStream(CUDAStream stream); + +.. attention:: + + This function may have nosthing to do with the current device. It only changes the current stream on the stream's device. + We recommend using ``CUDAStreamGuard``, instead, since it switches to the stream's device and makes it the current stream on that device. + ``CUDAStreamGuard`` will also restore the current device and stream when it's destroyed + +2. Use ``CUDAStreamGuard`` to switch to a CUDA stream within a scope, it is defined in `CUDAStreamGuard.h`_ + +.. tip:: + + Use ``CUDAMultiStreamGuard`` if you need to set streams on multiple CUDA devices. + +CUDA Stream Usage Examples +************************** + +1. Acquiring and setting CUDA stream on the same device + +.. code-block:: cpp + + // This example shows how to acquire and set CUDA stream on the same device. + // `at::cuda::setCurrentCUDAStream` is used to set current CUDA stream + + // create a tensor on device 0 + torch::Tensor tensor0 = torch::ones({2, 2}, torch::device(torch::kCUDA)); + // get a new CUDA stream from CUDA stream pool on device 0 + at::cuda::CUDAStream myStream = at::cuda::getStreamFromPool(); + // set current CUDA stream from default stream to `myStream` on device 0 + at::cuda::setCurrentCUDAStream(myStream); + // sum() on tensor0 uses `myStream` as current CUDA stream + tensor0.sum(); + + // get the default CUDA stream on device 0 + at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(); + // set current CUDA stream back to default CUDA stream on devide 0 + at::cuda::setCurrentCUDAStream(defaultStream); + // sum() on tensor0 uses `defaultStream` as current CUDA stream + tensor0.sum(); + +.. code-block:: cpp + + // This example is the same as previous example, but explicitly specify device + // index and use CUDA stream guard to set current CUDA stream + + // create a tensor on device 0 + torch::Tensor tensor0 = torch::ones({2, 2}, torch::device(torch::kCUDA)); + // get a new stream from CUDA stream pool on device 0 + at::cuda::CUDAStream myStream = at::cuda::getStreamFromPool(false, 0); + // set the current CUDA stream to `myStream` within the scope using CUDA stream guard + { + at::cuda::CUDAStreamGuard guard(myStream); + // current CUDA stream is `myStream` from here till the end of bracket. + // sum() on tensor0 uses `myStream` as current CUDA stream + tensor0.sum(); + } + // current CUDA stream is reset to default CUDA stream after CUDA stream guard is destroyed + // sum() on tensor0 uses default CUDA stream on device 0 as current CUDA stream + tensor0.sum(); + +.. attention:: + + Above code is running on the same CUDA device. `setCurrentCUDAStream` will always set current CUDA stream on current device, + but note that `setCurrentCUDASteram` actually set current stream on the device of passed in CUDA stream. + + +2. Acquiring and setting CUDA streams on multiple devices. + +.. code-block:: cpp + + // This example shows how to acquire and set CUDA stream on two devices. + + // acquire new CUDA streams from CUDA stream pool on device 0 and device 1 + at::cuda::CUDAStream myStream0 = at::cuda::getStreamFromPool(false, 0); + at::cuda::CUDAStream myStream1 = at::cuda::getStreamFromPool(false, 1); + + // set current CUDA stream to `myStream0` on device 0 + at::cuda::setCurrentCUDAStream(myStream0); + // set current CUDA stream to `myStream1` on device 1 + at::cuda::setCurrentCUDAStream(myStream1); + + // create a tensor on device 0, no need to specify device index since + // current device index is 0 + torch::Tensor tensor0 = torch::ones({2, 2}, torch::device(at::kCUDA)); + // sum() on tensor0 use `myStream0` as current CUDA stream on device 0 + tensor0.sum(); + + // change the current device index to 1 by using CUDA device guard within a braket scope + { + at::cuda::CUDAGuard device_guard{1}; + // create a tensor on device 1 + torch::Tensor tensor1 = torch::ones({2, 2}, torch::device(at::kCUDA)); + // sum() on tensor 1 uses `myStream1` as current CUDA stream on device 1 + tensor1.sum(); + } + + // current device is reset to device 0 after device_guard is destroyed + + // acquire a new CUDA stream on device 1 + at::cuda::CUDAStream myStream1_1 = at::cuda::getStreamFromPool(false, 1); + // create a new tensor on device 1 + torch::Tensor tensor1 = torch::ones({2, 2}, torch::device({torch::kCUDA, 1})); + + // change the current device index to 1 and current CUDA stream on device 1 + // to `myStream1_1` using CUDA stream guard within a scope + { + at::cuda::CUDAStreamGuard stream_guard(myStream1_1); + // sum() on tensor1 use `myStream1_1` as current CUDA stream on device 1 + tensor1.sum(); + } + + // current device is reset to device 0 and current CUDA stream on device 1 is + // reset to `myStream1` + + // sum() on tensor1 uses `myStream1` as current CUDA stream on device 1 + tensor1.sum(); + + +3. Working with CUDA multistream guard + +.. code-block:: cpp + + // This example shows how to use CUDA multistream guard to set + // two streams on two devices at the same time. + + // create two tensor, one on device 0, one on device 1 + torch::Tensor tensor0 = torch::ones({2, 2}, torch::device({torch::kCUDA, 0})); + torch::Tensor tensor1 = torch::ones({2, 2}, torch::device({torch::kCUDA, 1})); + + // acquire new CUDA streams from CUDA stream pool on device 0 and device 1 + at::cuda::CUDAStream myStream0 = at::cuda::getStreamFromPool(false, 0); + at::cuda::CUDAStream myStream1 = at::cuda::getStreamFromPool(false, 1); + + // set current CUDA stream on device 0 to `myStream0` and + // set current CUDA stream on device 1 to `myStream1` CUDA using multistream guard + { + at::cuda::CUDAMultiStreamGuard multi_guard({myStream0, myStream1}); + + // sum() on tensor0 uses `myStream0` as current CUDA stream on device 0 + tensor0.sum(); + // sum() on tensor1 uses `myStream1` as current CUDA stream on device 1 + tensor1.sum(); + } + + // current CUDA stream on device 0 is reset to default CUDA stream on device 0 + // current CUDA stream on device 1 is reset to default CUDA stream on device 1 + + // sum() on tensor0 uses default CUDA stream as current CUDA stream on device 0 + tensor0.sum(); + // sum() on tensor1 uses defualt CUDA stream as current CUDA stream on device 1 + tensor1.sum(); + +.. attention:: + ``CUDAMultiStreamGuard`` does not change current device index, it only changes the stream on + each passed in stream's device. Other than scope controlling, this guard is equivalent to + calling ``setCurrentCUDAStream`` on each passed in stream. + +4. A skeleton example for handling CUDA streams on multiple devices + +.. code-block:: cpp + + // This is a skeleton example that shows how to handle CUDA streams on multiple devices + // Suppose you want to do work on the non-default stream on two devices simultaneously, and we + // already have streams on both devices in two vectors. The following code shows three ways + // of acquiring and setting the streams. + + // Usage 0: acquire CUDA stream and set current CUDA stream with `setCurrentCUDAStream` + // Create a CUDA stream vector `streams0` on device 0 + std::vector streams0 = + {at::cuda::getDefaultCUDAStream(), at::cuda::getStreamFromPool()}; + // set current stream as `streams0[0]` on device 0 + at::cuda::setCurrentCUDAStream(streams0[0]); + + // create a CUDA stream vector `streams1` on device using CUDA device guard + std::vector streams1; + { + // device index is set to 1 within this scope + at::cuda::CUDAGuard device_guard(1); + streams1.push_back(at::cuda::getDefaultCUDAStream()); + streams1.push_back(at::cuda::getStreamFromPool()); + } + // device index is reset to 0 after device_guard is destroyed + + // set current stream as `streams1[0]` on device 1 + at::cuda::setCurrentCUDAStream(streams1[0]); + + + // Usage 1: use CUDA device guard to change the current device index only + { + at::cuda::CUDAGuard device_guard(1); + + // current device index is changed to 1 within scope + // current CUDA stream is still `streams1[0]` on device 1, no change + } + // current device index is reset to 0 after `device_guard` is destroyed + + + // Usage 2: use CUDA stream guard to change both current device index and current CUDA stream. + { + at::cuda::CUDAStreamGuard stream_guard(streams1[1]); + + // current device index and current CUDA stream are set to 1 and `streams1[1]` within scope + } + // current device index and current CUDA stream are reset to 0 and `streams0[0]` after + // stream_guard is destroyed + + + // Usage 3: use CUDA multi-stream guard to change multiple streams on multiple devices + { + // This is the same as calling `torch::cuda::setCurrentCUDAStream` on both streams + at::cuda::CUDAMultiStreamGuard multi_guard({streams0[1], streams1[1]}); + + // current device index is not change, still 0 + // current CUDA stream on device 0 and device 1 are set to `streams0[1]` and `streams1[1]` + } + // current CUDA stream on device 0 and device 1 are reset to `streams0[0]` and `streams1[0]` + // after `multi_guard` is destroyed. diff --git a/docs/libtorch.rst b/docs/libtorch.rst index 1b7f9a95356b2..f5e1abda42182 100644 --- a/docs/libtorch.rst +++ b/docs/libtorch.rst @@ -5,8 +5,8 @@ The core of pytorch does not depend on Python. A CMake-based build system compiles the C++ source code into a shared object, libtorch.so. -Building libtorch ------------------ +Building libtorch using Python +------------------------------ You can use a python script/module located in tools package to build libtorch :: @@ -34,3 +34,16 @@ To produce libtorch.a rather than libtorch.so, set the environment variable `BUI To use ninja rather than make, set `CMAKE_GENERATOR="-GNinja" CMAKE_INSTALL="ninja install"`. Note that we are working on eliminating tools/build_pytorch_libs.sh in favor of a unified cmake build. + +Building libtorch using CMake +-------------------------------------- + +You can build C++ libtorch.so directly with cmake. For example, to build a Release version from the master branch and install it in the directory specified by CMAKE_INSTALL_PREFIX below, you can use +:: + git clone -b master --recurse-submodule https://github.com/pytorch/pytorch.git + mkdir pytorch-build + cd pytorch-build + cmake -DBUILD_SHARED_LIBS:BOOL=ON -DCMAKE_BUILD_TYPE:STRING=Release -DPYTHON_EXECUTABLE:PATH=`which python3` -DCMAKE_INSTALL_PREFIX:PATH=../pytorch-install ../pytorch + cmake --build . --target install + +To use release branch v1.6.0, for example, replace ``master`` with ``v1.6.0``. You will get errors if you do not have needed dependencies such as Python3's PyYAML package. diff --git a/docs/source/__config__.rst b/docs/source/__config__.rst index e4a6ac8904939..adbfebe560a72 100644 --- a/docs/source/__config__.rst +++ b/docs/source/__config__.rst @@ -2,6 +2,7 @@ torch.__config__ =================================== .. automodule:: torch.__config__ +.. currentmodule:: torch.__config__ .. autofunction:: show .. autofunction:: parallel_info diff --git a/docs/source/benchmark_utils.rst b/docs/source/benchmark_utils.rst new file mode 100644 index 0000000000000..8e46d017cf1cd --- /dev/null +++ b/docs/source/benchmark_utils.rst @@ -0,0 +1,12 @@ +.. role:: hidden + :class: hidden-section + +Benchmark Utils - torch.utils.benchmark +================================================== + +.. automodule:: torch.utils.benchmark +.. currentmodule:: torch.utils.benchmark + +.. autoclass:: Timer + :members: + :show-inheritance: diff --git a/docs/source/community/persons_of_interest.rst b/docs/source/community/persons_of_interest.rst index c152ca6165715..f346fbe994e6c 100644 --- a/docs/source/community/persons_of_interest.rst +++ b/docs/source/community/persons_of_interest.rst @@ -25,7 +25,6 @@ torch.* torch.nn ~~~~~~~~ -- Thomas Viehmann (`t-vi `__) - Adam Paszke (`apaszke `__) - Greg Chanan (`gchanan `__) - Soumith Chintala (`soumith `__) diff --git a/docs/source/conf.py b/docs/source/conf.py index fe1e2260be727..610f6efa08409 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -161,7 +161,7 @@ # TODO: verify this works as expected release = 'master' -# Customized html_title here. +# Customized html_title here. # Default is " ".join(project, release, "documentation") if not set if RELEASE: # remove hash (start with 'a') from version number if any @@ -192,6 +192,9 @@ # Disable docstring inheritance autodoc_inherit_docstrings = False +# Disable displaying type annotations, these can be very verbose +autodoc_typehints = 'none' + # -- katex javascript in header # @@ -253,9 +256,9 @@ def setup(app): add_css(css_file) # From PyTorch 1.5, we now use autogenerated files to document classes and -# functions. This breaks older references since +# functions. This breaks older references since # https://docs.pytorch.org/torch.html#torch.flip -# moved to +# moved to # https://docs.pytorch.org/torch/generated/torchflip.html # which breaks older links from blog posts, stack overflow answers and more. # To mitigate that, we add an id="torch.flip" in an appropriated place @@ -278,7 +281,7 @@ def visit_reference(self, node): # to autogenerated content anchor = ref_anchor[1] txt = node.parent.astext() - if txt == anchor or txt == anchor.split('.')[-1]: + if txt == anchor or txt == anchor.split('.')[-1]: self.body.append('

'.format(ref_anchor[1])) return old_call(self, node) Klass.visit_reference = visit_reference diff --git a/docs/source/data.rst b/docs/source/data.rst index 9ba88f02c31ff..e7b00d23f521f 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -331,6 +331,8 @@ are compatible with Windows while using multi-process data loading: ``__main__`` check. This ensures that they are available in worker processes. (this is needed since functions are pickled as references only, not ``bytecode``.) +.. _data-loading-randomness: + Randomness in multi-process data loading """""""""""""""""""""""""""""""""""""""""" @@ -403,6 +405,7 @@ Example:: .. autoclass:: TensorDataset .. autoclass:: ConcatDataset .. autoclass:: ChainDataset +.. autoclass:: BufferedShuffleDataset .. autoclass:: Subset .. autofunction:: torch.utils.data.get_worker_info .. autofunction:: torch.utils.data.random_split diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index a248d3e4ca831..b35a34fc02655 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -52,12 +52,22 @@ MPI supports CUDA only if the implementation used to build PyTorch supports it. Backends that come with PyTorch ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -PyTorch distributed currently only supports Linux. By default, the Gloo and NCCL backends -are built and included in PyTorch distributed (NCCL only when building with CUDA). -MPI is an -optional backend that can only be included if you build PyTorch from source. (e.g. -building PyTorch on a host that has MPI installed.) +PyTorch distributed package supports Linux (stable), MacOS (stable), and Windows (prototype). +By default for Linux, the Gloo and NCCL backends are built and included in PyTorch +distributed (NCCL only when building with CUDA). MPI is an optional backend that can only be +included if you build PyTorch from source. (e.g.building PyTorch on a host that has MPI +installed.) +.. warning :: + As of PyTorch v1.7, Windows support for the distributed package only covers collective + communications with Gloo backend, `FileStore`, and `DistributedDataParallel`. Therefore, + the `init_method` argument in :func:`init_process_group` must point to a file. This works + for both local and shared file systems: + + - Local file system, ``init_method="file:///d:/tmp/some_file"`` + - Shared file system, ``init_method="file://////{machine_name}/{share_folder_name}/some_file"`` + + Similarly, if you directly pass in a `store` argument, it must be a ``FileStore`` instance. Which backend to use? ^^^^^^^^^^^^^^^^^^^^^ @@ -260,6 +270,31 @@ The machine with rank 0 will be used to set up all connections. This is the default method, meaning that ``init_method`` does not have to be specified (or can be ``env://``). +Distributed Key-Value Store +--------------------------- + +The distributed package comes with a distributed key-value store, which can be +used to share information between processes in the group as well as to +initialize the distributed pacakge in +:func:`torch.distributed.init_process_group` (by explicitly creating the store +as an alternative to specifying ``init_method``.) There are 3 choices for +Key-Value Stores: :class:`~torch.distributed.TCPStore`, +:class:`~torch.distributed.FileStore`, and :class:`~torch.distributed.HashStore`. + +.. autoclass:: Store +.. autoclass:: TCPStore +.. autoclass:: HashStore +.. autoclass:: FileStore +.. autoclass:: PrefixStore + +.. autofunction:: torch.distributed.Store.set +.. autofunction:: torch.distributed.Store.get +.. autofunction:: torch.distributed.Store.add +.. autofunction:: torch.distributed.Store.wait +.. autofunction:: torch.distributed.Store.num_keys +.. autofunction:: torch.distributed.Store.delete_key +.. autofunction:: torch.distributed.Store.set_timeout + Groups ------ @@ -295,27 +330,59 @@ as they should never be created manually, but they are guaranteed to support two Synchronous and asynchronous collective operations -------------------------------------------------- -Every collective operation function supports the following two kinds of operations: - -synchronous operation - the default mode, when ``async_op`` is set to False. -when the function returns, it is guaranteed that -the collective operation is performed (not necessarily completed if it's a CUDA op since all -CUDA ops are asynchronous), and any further function calls depending on the data of the -collective operation can be called. In the synchronous mode, the collective function does not -return anything - -asynchronous operation - when ``async_op`` is set to True. The collective operation function +Every collective operation function supports the following two kinds of operations, +depending on the setting of the ``async_op`` flag passed into the collective: + +**Synchronous operation** - the default mode, when ``async_op`` is set to ``False``. +When the function returns, it is guaranteed that +the collective operation is performed. In the case of CUDA operations, it is not guaranteed +that the CUDA operation is completed, since CUDA operations are asynchronous. For CPU collectives, any +further function calls utilizing the output of the collective call will behave as expected. For CUDA collectives, +function calls utilizing the output on the same CUDA stream will behave as expected. Users must take care of +synchronization under the scenario of running under different streams. For details on CUDA semantics such as stream +synchronization, see `CUDA Semantics `__. +See the below script to see examples of differences in these semantics for CPU and CUDA operations. + +**Asynchronous operation** - when ``async_op`` is set to True. The collective operation function returns a distributed request object. In general, you don't need to create it manually and it is guaranteed to support two methods: -* ``is_completed()`` - returns True if the operation has finished -* ``wait()`` - will block the process until the operation is finished. +* ``is_completed()`` - in the case of CPU collectives, returns ``True`` if completed. In the case of CUDA operations, + returns ``True`` if the operation has been successfully enqueued onto a CUDA stream and the output can be utilized on the + default stream without further synchronization. +* ``wait()`` - in the case of CPU collectives, will block the process until the operation is completed. In the case + of CUDA collectives, will block until the operation has been successfully enqueued onto a CUDA stream and the + output can be utilized on the default stream without further synchronization. + +**Example** + +The following code can serve as a reference regarding semantics for CUDA operations when using distributed collectives. +It shows the explicit need to synchronize when using collective outputs on different CUDA streams: + +:: + + # Code runs on each rank. + dist.init_process_group("nccl", rank=rank, world_size=2) + output = torch.tensor([rank]).cuda(rank) + s = torch.cuda.Stream() + handle = dist.all_reduce(output, async_op=True) + # Wait ensures the operation is enqueued, but not necessarily complete. + handle.wait() + # Using result on non-default stream. + with torch.cuda.stream(s): + s.wait_stream(torch.cuda.default_stream()) + output.add_(100) + if rank == 0: + # if the explicit call to wait_stream was omitted, the output below will be + # non-deterministically 1 or 101, depending on whether the allreduce overwrote + # the value after the add completed. + print(output) Collective functions -------------------- -.. autofunction:: broadcast +.. autofunction:: broadcast .. autofunction:: broadcast_object_list @@ -333,6 +400,8 @@ Collective functions .. autofunction:: scatter +.. autofunction:: scatter_object_list + .. autofunction:: reduce_scatter .. autofunction:: all_to_all diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index aebc390383681..fe09626e60d84 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -167,6 +167,15 @@ Probability distributions - torch.distributions :undoc-members: :show-inheritance: +:hidden:`Kumaraswamy` +~~~~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: torch.distributions.kumaraswamy +.. autoclass:: Kumaraswamy + :members: + :undoc-members: + :show-inheritance: + :hidden:`Laplace` ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/fft.rst b/docs/source/fft.rst index ab50bd271d322..dfce2503f70d1 100644 --- a/docs/source/fft.rst +++ b/docs/source/fft.rst @@ -1,31 +1,38 @@ .. role:: hidden :class: hidden-section -.. _torch-fft-module: - torch.fft ========= Discrete Fourier transforms and related functions. -To use these functions the torch.fft module must be imported since its name -conflicts with the :func:`torch.fft` function. - .. automodule:: torch.fft :noindex: .. currentmodule:: torch.fft -Functions ---------- +Fast Fourier Transforms +----------------------- .. autofunction:: fft .. autofunction:: ifft +.. autofunction:: fft2 +.. autofunction:: ifft2 .. autofunction:: fftn .. autofunction:: ifftn .. autofunction:: rfft .. autofunction:: irfft +.. autofunction:: rfft2 +.. autofunction:: irfft2 .. autofunction:: rfftn .. autofunction:: irfftn .. autofunction:: hfft .. autofunction:: ihfft + +Helper Functions +---------------- + +.. autofunction:: fftfreq +.. autofunction:: rfftfreq +.. autofunction:: fftshift +.. autofunction:: ifftshift diff --git a/docs/source/fx.rst b/docs/source/fx.rst new file mode 100644 index 0000000000000..0a8e00f6d44c2 --- /dev/null +++ b/docs/source/fx.rst @@ -0,0 +1,48 @@ +.. currentmodule:: torch.fx + +torch.fx +============= + +Overview +-------- +.. automodule:: torch.fx + +Limitations of Symbolic Tracing +------------------------------- + +TODO + +Writing Transformations +----------------------- + +TODO + +Debugging Transformations +------------------------- + +TODO + +API Reference +------------- + +.. autofunction:: torch.fx.symbolic_trace + +.. autofunction:: torch.fx.wrap + +.. autoclass:: torch.fx.GraphModule + :members: + + .. automethod:: __init__ + +.. autoclass:: torch.fx.Graph + :members: + + .. automethod:: __init__ + +.. autoclass:: torch.fx.Node + :members: + +.. autoclass:: torch.fx.Tracer + :members: + +.. autoclass:: torch.fx.Proxy diff --git a/docs/source/index.rst b/docs/source/index.rst index 54a82a07b1a92..a334bffab01ef 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -10,6 +10,25 @@ PyTorch documentation PyTorch is an optimized tensor library for deep learning using GPUs and CPUs. +Features described in this documentation are classified by release status: + + *Stable:* These features will be maintained long-term and there should generally + be no major performance limitations or gaps in documentation. + We also expect to maintain backwards compatibility (although + breaking changes can happen and notice will be given one release ahead + of time). + + *Beta:* Features are tagged as Beta because the API may change based on + user feedback, because the performance needs to improve, or because + coverage across operators is not yet complete. For Beta features, we are + committing to seeing the feature through to the Stable classification. + We are not, however, committing to backwards compatibility. + + *Prototype:* These features are typically not available as part of + binary distributions like PyPI or Conda, except sometimes behind run-time + flags, and are at an early stage for feedback and testing. + + .. toctree:: :glob: :maxdepth: 1 @@ -42,9 +61,12 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs. torch.distributions torch.fft futures + fx torch.hub torch.jit torch.linalg + torch.overrides + profiler nn.init onnx optim @@ -54,6 +76,7 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs. torch.random sparse storage + torch.utils.benchmark torch.utils.bottleneck torch.utils.checkpoint torch.utils.cpp_extension @@ -71,7 +94,7 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs. :maxdepth: 1 :caption: Libraries - torchaudio + torchaudio torchtext torchvision TorchElastic diff --git a/docs/source/jit.rst b/docs/source/jit.rst index 45ba6fa18d80e..ccd37738277f0 100644 --- a/docs/source/jit.rst +++ b/docs/source/jit.rst @@ -45,15 +45,18 @@ Creating TorchScript Code script trace + script_if_tracing trace_module fork wait ScriptModule ScriptFunction + freeze save load ignore unused + isinstance Mixing Tracing and Scripting ---------------------------- @@ -544,10 +547,10 @@ best practices? cpu_model = gpu_model.cpu() sample_input_cpu = sample_input_gpu.cpu() - traced_cpu = torch.jit.trace(traced_cpu, sample_input_cpu) + traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu) torch.jit.save(traced_cpu, "cpu.pth") - traced_gpu = torch.jit.trace(traced_gpu, sample_input_gpu) + traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu) torch.jit.save(traced_gpu, "gpu.pth") # ... later, when using the model: diff --git a/docs/source/jit_language_reference.rst b/docs/source/jit_language_reference.rst index 4cca46fdc0057..d52c00e147f39 100644 --- a/docs/source/jit_language_reference.rst +++ b/docs/source/jit_language_reference.rst @@ -63,7 +63,7 @@ net models. In particular, TorchScript supports: :header: "Type", "Description" "``Tensor``", "A PyTorch tensor of any dtype, dimension, or backend" - "``Tuple[T0, T1, ...]``", "A tuple containing subtypes ``T0``, ``T1``, etc. (e.g. ``Tuple[Tensor, Tensor]``)" + "``Tuple[T0, T1, ..., TN]``", "A tuple containing subtypes ``T0``, ``T1``, etc. (e.g. ``Tuple[Tensor, Tensor]``)" "``bool``", "A boolean value" "``int``", "A scalar integer" "``float``", "A scalar floating point number" @@ -72,6 +72,7 @@ net models. In particular, TorchScript supports: "``Optional[T]``", "A value which is either None or type ``T``" "``Dict[K, V]``", "A dict with key type ``K`` and value type ``V``. Only ``str``, ``int``, and ``float`` are allowed as key types." "``T``", "A `TorchScript Class`_" + "``E``", "A `TorchScript Enum`_" "``NamedTuple[T0, T1, ...]``", "A :func:`collections.namedtuple ` tuple type" Unlike Python, each variable in TorchScript function must have a single static type. @@ -130,6 +131,7 @@ These types and features from the :mod:`typing` module are unavailble in TorchSc ":any:`typing.Any`", ":any:`typing.Any` is currently in development but not yet released" ":any:`typing.NoReturn`", "Not implemented" ":any:`typing.Union`", "Unlikely to be implemented (however :any:`typing.Optional` is supported)" + ":any:`typing.Sequence`", "Not implemented" ":any:`typing.Callable`", "Not implemented" ":any:`typing.Literal`", "Not implemented" ":any:`typing.ClassVar`", "Not implemented" @@ -271,6 +273,7 @@ Example (refining types on parameters and locals): module = torch.jit.script(M(2)) module = torch.jit.script(M(None)) + .. _TorchScript Class: .. _TorchScript Classes: .. _torchscript-classes: @@ -346,6 +349,37 @@ like any other TorchScript type: print(sum_pair(p)) +.. _TorchScript Enum: +.. _TorchScript Enums: +.. _torchscript-enums: + +TorchScript Enums +^^^^^^^^^^^^^^^^^^^ + +Python enums can be used in TorchScript without any extra annotation or code: + +:: + + from enum import Enum + + + class Color(Enum): + RED = 1 + GREEN = 2 + + @torch.jit.script + def enum_fn(x: Color, y: Color) -> bool: + if x == Color.RED: + return True + + return x == y + +After an enum is defined, it can be used in both TorchScript and Python interchangeably +like any other TorchScript type. The type of the values of an enum must be ``int``, +``float``, or ``str``. All values must be of the same type; heterogenous types for enum +values are not supported. + + Named Tuples ^^^^^^^^^^^^ Types produced by :func:`collections.namedtuple ` can be used in TorchScript. diff --git a/docs/source/jit_unsupported.rst b/docs/source/jit_unsupported.rst index 8bf3e78d672a5..7368abad1e300 100644 --- a/docs/source/jit_unsupported.rst +++ b/docs/source/jit_unsupported.rst @@ -87,6 +87,5 @@ we suggest using :meth:`torch.jit.trace`. * :class:`torch.nn.RNN` * :class:`torch.nn.AdaptiveLogSoftmaxWithLoss` * :class:`torch.autograd.Function` - * :class:`torch.autograd.no_grad` * :class:`torch.autograd.enable_grad` * :class:`torch.Generator` diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index 834b6a60ac93f..f592eac72aea5 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -12,5 +12,18 @@ Common linear algebra operations. Functions --------- +.. autofunction:: cholesky +.. autofunction:: cond .. autofunction:: det +.. autofunction:: slogdet +.. autofunction:: eigh +.. autofunction:: eigvalsh +.. autofunction:: matrix_rank .. autofunction:: norm +.. autofunction:: pinv +.. autofunction:: svd +.. autofunction:: solve +.. autofunction:: tensorinv +.. autofunction:: tensorsolve +.. autofunction:: inv +.. autofunction:: qr diff --git a/docs/source/mobile_optimizer.rst b/docs/source/mobile_optimizer.rst index 3067a2db43799..bb11abf82dbac 100644 --- a/docs/source/mobile_optimizer.rst +++ b/docs/source/mobile_optimizer.rst @@ -5,14 +5,14 @@ torch.utils.mobile_optimizer This API is in beta and may change in the near future. Torch mobile supports ``torch.mobile_optimizer.optimize_for_mobile`` utility to run a list of optimization pass with modules in eval mode. -The method takes the following parameters: a torch.jit.ScriptModule object, a blacklisting optimization set and a preserved method list +The method takes the following parameters: a torch.jit.ScriptModule object, a blocklisting optimization set and a preserved method list -By default, if optimization blacklist is None or empty, ``optimize_for_mobile`` will run the following optimizations: - - **Conv2D + BatchNorm fusion** (blacklisting option `MobileOptimizerType::CONV_BN_FUSION`): This optimization pass folds ``Conv2d-BatchNorm2d`` into ``Conv2d`` in ``forward`` method of this module and all its submodules. The weight and bias of the ``Conv2d`` are correspondingly updated. - - **Insert and Fold prepacked ops** (blacklisting option `MobileOptimizerType::INSERT_FOLD_PREPACK_OPS`): This optimization pass rewrites the graph to replace 2D convolutions and linear ops with their prepacked counterparts. Prepacked ops are stateful ops in that, they require some state to be created, such as weight prepacking and use this state, i.e. prepacked weights, during op execution. XNNPACK is one such backend that provides prepacked ops, with kernels optimized for mobile platforms (such as ARM CPUs). Prepacking of weight enables efficient memory access and thus faster kernel execution. At the moment ``optimize_for_mobile`` pass rewrites the graph to replace ``Conv2D/Linear`` with 1) op that pre-packs weight for XNNPACK conv2d/linear ops and 2) op that takes pre-packed weight and activation as input and generates output activations. Since 1 needs to be done only once, we fold the weight pre-packing such that it is done only once at model load time. This pass of the ``optimize_for_mobile`` does 1 and 2 and then folds, i.e. removes, weight pre-packing ops. +By default, if optimization blocklist is None or empty, ``optimize_for_mobile`` will run the following optimizations: + - **Conv2D + BatchNorm fusion** (blocklisting option `MobileOptimizerType::CONV_BN_FUSION`): This optimization pass folds ``Conv2d-BatchNorm2d`` into ``Conv2d`` in ``forward`` method of this module and all its submodules. The weight and bias of the ``Conv2d`` are correspondingly updated. + - **Insert and Fold prepacked ops** (blocklisting option `MobileOptimizerType::INSERT_FOLD_PREPACK_OPS`): This optimization pass rewrites the graph to replace 2D convolutions and linear ops with their prepacked counterparts. Prepacked ops are stateful ops in that, they require some state to be created, such as weight prepacking and use this state, i.e. prepacked weights, during op execution. XNNPACK is one such backend that provides prepacked ops, with kernels optimized for mobile platforms (such as ARM CPUs). Prepacking of weight enables efficient memory access and thus faster kernel execution. At the moment ``optimize_for_mobile`` pass rewrites the graph to replace ``Conv2D/Linear`` with 1) op that pre-packs weight for XNNPACK conv2d/linear ops and 2) op that takes pre-packed weight and activation as input and generates output activations. Since 1 needs to be done only once, we fold the weight pre-packing such that it is done only once at model load time. This pass of the ``optimize_for_mobile`` does 1 and 2 and then folds, i.e. removes, weight pre-packing ops. - **ReLU/Hardtanh fusion**: XNNPACK ops support fusion of clamping. That is clamping of output activation is done as part of the kernel, including for 2D convolution and linear op kernels. Thus clamping effectively comes for free. Thus any op that can be expressed as clamping op, such as ``ReLU`` or ``hardtanh``, can be fused with previous ``Conv2D`` or ``linear`` op in XNNPACK. This pass rewrites graph by finding ``ReLU/hardtanh`` ops that follow XNNPACK ``Conv2D/linear`` ops, written by the previous pass, and fuses them together. - - **Dropout removal** (blacklisting option `MobileOptimizerType::REMOVE_DROPOUT`): This optimization pass removes ``dropout`` and ``dropout_`` nodes from this module when training is false. - - **Conv packed params hoisting** (blacklisting option `MobileOptimizerType::HOIST_CONV_PACKED_PARAMS`): This optimization pass moves convolution packed params to the root module, so that the convolution structs can be deleted. This decreases model size without impacting numerics. + - **Dropout removal** (blocklisting option `MobileOptimizerType::REMOVE_DROPOUT`): This optimization pass removes ``dropout`` and ``dropout_`` nodes from this module when training is false. + - **Conv packed params hoisting** (blocklisting option `MobileOptimizerType::HOIST_CONV_PACKED_PARAMS`): This optimization pass moves convolution packed params to the root module, so that the convolution structs can be deleted. This decreases model size without impacting numerics. ``optimize_for_mobile`` will also invoke freeze_module pass which only preserves ``forward`` method. If you have other method to that needed to be preserved, add them into the preserved method list and pass into the method. diff --git a/docs/source/name_inference.rst b/docs/source/name_inference.rst index ccbb8c0c54d3d..2606c82280b8a 100644 --- a/docs/source/name_inference.rst +++ b/docs/source/name_inference.rst @@ -151,6 +151,7 @@ If you don't see an operation listed here, but it would help your use case, plea ":meth:`Tensor.matmul`, :func:`torch.matmul`",:ref:`contracts_away_dims-doc` ":meth:`Tensor.mean`, :func:`torch.mean`",:ref:`removes_dimensions-doc` ":meth:`Tensor.median`, :func:`torch.median`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.nanmedian`, :func:`torch.nanmedian`",:ref:`removes_dimensions-doc` ":meth:`Tensor.mm`, :func:`torch.mm`",:ref:`contracts_away_dims-doc` ":meth:`Tensor.mode`, :func:`torch.mode`",:ref:`removes_dimensions-doc` ":meth:`Tensor.mul`, :func:`torch.mul`",:ref:`unifies_names_from_inputs-doc` diff --git a/docs/source/nn.functional.rst b/docs/source/nn.functional.rst index 416121cec8d61..17b0e0a80b360 100644 --- a/docs/source/nn.functional.rst +++ b/docs/source/nn.functional.rst @@ -496,6 +496,11 @@ Vision functions .. autofunction:: pixel_shuffle +:hidden:`pixel_unshuffle` +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pixel_unshuffle + :hidden:`pad` ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/nn.rst b/docs/source/nn.rst index 8d195c04037c3..74f7994447a18 100644 --- a/docs/source/nn.rst +++ b/docs/source/nn.rst @@ -21,6 +21,7 @@ These are the basic building block for graphs :template: classtemplate.rst ~parameter.Parameter + ~parameter.UninitializedParameter Containers ---------------------------------- @@ -37,6 +38,17 @@ Containers ParameterList ParameterDict +Global Hooks For Module + +.. currentmodule:: torch.nn.modules.module +.. autosummary:: + :toctree: generated + :nosignatures: + + register_module_forward_pre_hook + register_module_forward_hook + register_module_backward_hook + .. currentmodule:: torch Convolution Layers @@ -53,6 +65,12 @@ Convolution Layers nn.ConvTranspose1d nn.ConvTranspose2d nn.ConvTranspose3d + nn.LazyConv1d + nn.LazyConv2d + nn.LazyConv3d + nn.LazyConvTranspose1d + nn.LazyConvTranspose2d + nn.LazyConvTranspose3d nn.Unfold nn.Fold @@ -207,6 +225,7 @@ Linear Layers nn.Identity nn.Linear nn.Bilinear + nn.LazyLinear Dropout Layers -------------- @@ -280,10 +299,21 @@ Vision Layers :template: classtemplate.rst nn.PixelShuffle + nn.PixelUnshuffle nn.Upsample nn.UpsamplingNearest2d nn.UpsamplingBilinear2d +Shuffle Layers +---------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + nn.ChannelShuffle + DataParallel Layers (multi-GPU, distributed) -------------------------------------------- @@ -363,3 +393,14 @@ Quantized Functions Quantization refers to techniques for performing computations and storing tensors at lower bitwidths than floating point precision. PyTorch supports both per tensor and per channel asymmetric linear quantization. To learn more how to use quantized functions in PyTorch, please refer to the :ref:`quantization-doc` documentation. + +Lazy Modules Initialization +--------------------------- + +.. currentmodule:: torch +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + nn.modules.lazy.LazyModuleMixin diff --git a/docs/source/notes/autograd.rst b/docs/source/notes/autograd.rst index d455f76b8c45b..625ffa1ba2382 100644 --- a/docs/source/notes/autograd.rst +++ b/docs/source/notes/autograd.rst @@ -214,80 +214,278 @@ proper thread locking code to ensure the hooks are thread safe. .. _complex_autograd-doc: Autograd for Complex Numbers -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +---------------------------- + +The short version: + +- When you use PyTorch to differentiate any function :math:`f(z)` with complex domain and/or codomain, + the gradients are computed under the assumption that the function is a part of a larger real-valued + loss function :math:`g(input)=L`. The gradient computed is :math:`\frac{\partial L}{\partial z^*}` + (note the conjugation of z), the negative of which is precisely the direction of steepest descent + used in Gradient Descent algorithm.. Thus, all the existing optimizers work out of + the box with complex parameters. +- This convention matches TensorFlow's convention for complex + differentiation, but is different from JAX (which computes + :math:`\frac{\partial L}{\partial z}`). +- If you have a real-to-real function which internally uses complex + operations, the convention here doesn't matter: you will always get + the same result that you would have gotten if it had been implemented + with only real operations. + +If you are curious about the mathematical details, or want to know how +to define complex derivatives in PyTorch, read on. + +What are complex derivatives? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -**What notion of complex derivative does PyTorch use?** -******************************************************* +The mathematical definition of complex-differentiability takes the +limit definition of a derivative and generalizes it to operate on +complex numbers. For a function :math:`f: ℂ → ℂ`, we can write: -PyTorch follows `JAX's `_ -convention for autograd for Complex Numbers. + .. math:: + f'(z) = \lim_{h \to 0, h \in C} \frac{f(z+h) - f(z)}{h} + +In order for this limit to exist, not only must :math:`u` and :math:`v` must be +real differentiable (as above), but :math:`f` must also satisfy the Cauchy-Riemann `equations +`_. In +other words: the limit computed for real and imaginary steps (:math:`h`) +must be equal. This is a more restrictive condition. + +The complex differentiable functions are commonly known as holomorphic +functions. They are well behaved, have all the nice properties that +you've seen from real differentiable functions, but are practically of no +use in the optimization world. For optimization problems, only real valued objective +functions are used in the research community since complex numbers are not part of any +ordered field and so having complex valued loss does not make much sense. + +It also turns out that no interesting real-valued objective fulfill the +Cauchy-Riemann equations. So the theory with homomorphic function cannot be +used for optimization and most people therefore use the Wirtinger calculus. + +Wirtinger Calculus comes in picture ... +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +So, we have this great theory of complex differentiability and +holomorphic functions, and we can’t use any of it at all, because many +of the commonly used functions are not holomorphic. What’s a poor +mathematician to do? Well, Wirtinger observed that even if :math:`f(z)` +isn’t holomorphic, one could rewrite it as a two variable function +:math:`f(z, z*)` which is always holomorphic. This is because real and +imaginary of the components of :math:`z` can be expressed in terms of +:math:`z` and :math:`z^*` as: -Suppose we have a function :math:`F: ℂ → ℂ` which we can decompose into functions u and v -which compute the real and imaginary parts of the function: + .. math:: + \begin{aligned} + Re(z) &= \frac {z + z^*}{2} \\ + Im(z) &= \frac {z - z^*}{2j} + \end{aligned} + +Wirtinger calculus suggests to study :math:`f(z, z^*)` instead, which is +guaranteed to be holomorphic if :math:`f` was real differentiable (another +way to think of it is as a change of coordinate system, from :math:`f(x, y)` +to :math:`f(z, z^*)`.) This function has partial derivatives +:math:`\frac{\partial }{\partial z}` and :math:`\frac{\partial}{\partial z^{*}}`. +We can use the chain rule to establish a +relationship between these partial derivatives and the partial +derivatives w.r.t., the real and imaginary components of :math:`z`. - .. code:: + .. math:: + \begin{aligned} + \frac{\partial }{\partial x} &= \frac{\partial z}{\partial x} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial x} * \frac{\partial }{\partial z^*} \\ + &= \frac{\partial }{\partial z} + \frac{\partial }{\partial z^*} \\ + \\ + \frac{\partial }{\partial y} &= \frac{\partial z}{\partial y} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial y} * \frac{\partial }{\partial z^*} \\ + &= 1j * (\frac{\partial }{\partial z} - \frac{\partial }{\partial z^*}) + \end{aligned} + +From the above equations, we get: + + .. math:: + \begin{aligned} + \frac{\partial }{\partial z} &= 1/2 * (\frac{\partial }{\partial x} - 1j * \frac{\partial }{\partial y}) \\ + \frac{\partial }{\partial z^*} &= 1/2 * (\frac{\partial }{\partial x} + 1j * \frac{\partial }{\partial y}) + \end{aligned} - def F(z): - x, y = real(z), imag(z) - return u(x, y) + v(x, y) * 1j +which is the classic definition of Wirtinger calculus that you would find on `Wikipedia `_. -where :math:`1j` is a unit imaginary number. +There are a lot of beautiful consequences of this change. -We define the :math:`JVP` for function :math:`F` at :math:`(x, y)` applied to a tangent -vector :math:`c+dj \in C` as: +- For one, the Cauchy-Riemann equations translate into simply saying that :math:`\frac{\partial f}{\partial z^*} = 0` (that is to say, the function :math:`f` can be written + entirely in terms of :math:`z`, without making reference to :math:`z^*`). +- Another important (and somewhat counterintuitive) result, as we’ll see later, is that when we do optimization on a real-valued loss, the step we should + take while making variable update is given by :math:`\frac{\partial Loss}{\partial z^*}` (not :math:`\frac{\partial Loss}{\partial z}`). - .. math:: \begin{bmatrix} 1 & 1j \end{bmatrix} * J * \begin{bmatrix} c \\ d \end{bmatrix} +For more reading, check out: https://arxiv.org/pdf/0906.4835.pdf -where +How is Wirtinger Calculus useful in optimization? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Researchers in audio and other fields, more commonly, use gradient +descent to optimize real valued loss functions with complex variables. +Typically, these people treat the real and imaginary values as separate +channels that can be updated. For a step size :math:`s/2` and loss +:math:`L`, we can write the following equations in :math:`ℝ^2`: .. math:: - J = \begin{bmatrix} - \frac{\partial u(x, y)}{\partial x} & \frac{\partial u(x, y)}{\partial y}\\ - \frac{\partial v(x, y)}{\partial x} & \frac{\partial v(x, y)}{\partial y} \end{bmatrix} \\ + \begin{aligned} + x_{n+1} &= x_n - (s/2) * \frac{\partial L}{\partial x} \\ + y_{n+1} &= y_n - (s/2) * \frac{\partial L}{\partial y} + \end{aligned} + +How do these equations translate into complex space :math:`ℂ`? -This is similar to the definition of the JVP for a function defined from :math:`R^2 → R^2`, and the multiplication -with :math:`[1, 1j]^T` is used to identify the result as a complex number. + .. math:: + \begin{aligned} + z_{n+1} &= x_n - (s/2) * \frac{\partial L}{\partial x} + 1j * (y_n - (s/2) * \frac{\partial L}{\partial y}) + &= z_n - s * 1/2 * (\frac{\partial L}{\partial x} + j \frac{\partial L}{\partial y}) + &= z_n - s * \frac{\partial L}{\partial z^*} + \end{aligned} + +Something very interesting has happened: Wirtinger calculus tells us +that we can simplify the complex variable update formula above to only +refer to the conjugate Wirtinger derivative +:math:`\frac{\partial L}{\partial z^*}`, giving us exactly the step we take in optimization. + +Because the conjugate Wirtinger derivative gives us exactly the correct step for a real valued loss function, PyTorch gives you this derivative +when you differentiate a function with a real valued loss. + +How does PyTorch compute the conjugate Wirtinger derivative? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Typically, our derivative formulas take in `grad_output` as an input, +representing the incoming Vector-Jacobian product that we’ve already +computed, aka, :math:`\frac{\partial L}{\partial s^*}`, where :math:`L` +is the loss of the entire computation (producing a real loss) and +:math:`s` is the output of our function. The goal here is to compute +:math:`\frac{\partial L}{\partial z^*}`, where :math:`z` is the input of +the function. It turns out that in the case of real loss, we can +get away with *only* calculating :math:`\frac{\partial L}{\partial z^*}`, +even though the chain rule implies that we also need to +have access to :math:`\frac{\partial L}{\partial z^*}`. If you want +to skip this derivation, look at the last equation in this section +and then skip to the next section. + +Let’s continue working with :math:`f: ℂ → ℂ` defined as +:math:`f(z) = f(x+yj) = u(x, y) + v(x, y)j`. As discussed above, +autograd’s gradient convention is centered around optimization for real +valued loss functions, so let’s assume :math:`f` is a part of larger +real valued loss function :math:`g`. Using chain rule, we can write: -We define the :math:`VJP` of :math:`F` at :math:`(x, y)` for a cotangent vector :math:`c+dj \in C` as: + .. math:: + \frac{\partial L}{\partial z^*} = \frac{\partial L}{\partial u} * \frac{\partial u}{\partial z^*} + \frac{\partial L}{\partial v} * \frac{\partial v}{\partial z^*} + :label: [1] - .. math:: \begin{bmatrix} c & -d \end{bmatrix} * J * \begin{bmatrix} 1 \\ -1j \end{bmatrix} +Now using Wirtinger derivative definition, we can write: -In PyTorch, the `VJP` is mostly what we care about, as it is the computation performed when we do backward -mode automatic differentiation. Notice that d and :math:`1j` are negated in the formula above. Please look at -the `JAX docs `_ -to get explanation for the negative signs in the formula. + .. math:: + \begin{aligned} + \frac{\partial L}{\partial s} = 1/2 * (\frac{\partial L}{\partial u} - \frac{\partial L}{\partial v} j) \\ + \frac{\partial L}{\partial s^*} = 1/2 * (\frac{\partial L}{\partial u} + \frac{\partial L}{\partial v} j) + \end{aligned} -**What happens if I call backward() on a complex scalar?** -******************************************************************************* +It should be noted here that since :math:`u` and :math:`v` are real +functions, and :math:`L` is real by our assumption that :math:`f` is a +part of a real valued function, we have: -The gradient for a complex function is computed assuming the input function is a holomorphic function. -This is because for general :math:`ℂ → ℂ` functions, the Jacobian has 4 real-valued degrees of freedom -(as in the `2x2` Jacobian matrix above), so we can’t hope to represent all of them with in a complex number. -However, for holomorphic functions, the gradient can be fully represented with complex numbers due to the -Cauchy-Riemann equations that ensure that `2x2` Jacobians have the special form of a scale-and-rotate -matrix in the complex plane, i.e. the action of a single complex number under multiplication. And so, we can -obtain that gradient using backward which is just a call to `vjp` with covector `1.0`. + .. math:: + (\frac{\partial L}{\partial s})^* = \frac{\partial L}{\partial s^*} + :label: [2] -The net effect of this assumption is that the partial derivatives of the imaginary part of the function -(:math:`v(x, y)` above) are discarded for :func:`torch.autograd.backward` on a complex scalar -(e.g., this is equivalent to dropping the imaginary part of the loss before performing a backwards). +i.e., :math:`\frac{\partial L}{\partial s}` equals to :math:`grad\_output^*`. -For any other desired behavior, you can specify the covector `grad_output` in :func:`torch.autograd.backward` call accordingly. +Solving the above equations for :math:`\frac{\partial L}{\partial u}` and :math:`\frac{\partial L}{\partial v}`, we get: -**How are the JVP and VJP defined for cross-domain functions?** -*************************************************************** + .. math:: + \begin{aligned} + \frac{\partial L}{\partial u} = \frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*} \\ + \frac{\partial L}{\partial v} = -1j * (\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}) + \end{aligned} + :label: [3] + +Substituting :eq:`[3]` in :eq:`[1]`, we get: + + .. math:: + \begin{aligned} + \frac{\partial L}{\partial z^*} &= (\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}) * \frac{\partial u}{\partial z^*} - 1j * (\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}) * \frac{\partial v}{\partial z^*} \\ + &= \frac{\partial L}{\partial s} * (\frac{\partial u}{\partial z^*} + \frac{\partial v}{\partial z^*} j) + \frac{\partial L}{\partial s^*} * (\frac{\partial u}{\partial z^*} - \frac{\partial v}{\partial z^*} j) \\ + &= \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s} * \frac{\partial (u + vj)^*}{\partial z^*} \\ + &= \frac{\partial L}{\partial s} * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial s^*}{\partial z^*} \\ + \end{aligned} + +Using :eq:`[2]`, we get: + + .. math:: + \begin{aligned} + \frac{\partial L}{\partial z^*} &= (\frac{\partial L}{\partial s^*})^* * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * (\frac{\partial s}{\partial z})^* \\ + &= \boxed{ (grad\_output)^* * \frac{\partial s}{\partial z^*} + grad\_output * {(\frac{\partial s}{\partial z})}^* } \\ + \end{aligned} + :label: [4] + +This last equation is the important one for writing your own gradients, +as it decomposes our derivative formula into a simpler one that is easy +to compute by hand. + +How can I write my own derivative formula for a complex function? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The above boxed equation gives us the general formula for all +derivatives on complex functions. However, we still need to +compute :math:`\frac{\partial s}{\partial z}` and :math:`\frac{\partial s}{\partial z^*}`. +There are two ways you could do this: + + - The first way is to just use the definition of Wirtinger derivatives directly and calculate :math:`\frac{\partial s}{\partial z}` and :math:`\frac{\partial s}{\partial z^*}` by + using :math:`\frac{\partial s}{\partial x}` and :math:`\frac{\partial s}{\partial y}` + (which you can compute in the normal way). + - The second way is to use the change of variables trick and rewrite :math:`f(z)` as a two variable function :math:`f(z, z^*)`, and compute + the conjugate Wirtinger derivatives by treating :math:`z` and :math:`z^*` as independent variables. This is often easier; for example, if the function in question is holomorphic, only :math:`z` will be used (and :math:`\frac{\partial s}{\partial z^*}` will be zero). + +Let's consider the function :math:`f(z = x + yj) = c * z = c * (x+yj)` as an example, where :math:`c \in ℝ`. + +Using the first way to compute the Wirtinger derivatives, we have. + +.. math:: + \begin{aligned} + \frac{\partial s}{\partial z} &= 1/2 * (\frac{\partial s}{\partial x} - \frac{\partial s}{\partial y} j) \\ + &= 1/2 * (c - (c * 1j) * 1j) \\ + &= c \\ + \\ + \\ + \frac{\partial s}{\partial z^*} &= 1/2 * (\frac{\partial s}{\partial x} + \frac{\partial s}{\partial y} j) \\ + &= 1/2 * (c + (c * 1j) * 1j) \\ + &= 0 \\ + \end{aligned} + +Using :eq:`[4]`, and `grad\_output = 1.0` (which is the default grad output value used when :func:`backward` is called on a scalar output in PyTorch), we get: + + .. math:: + \frac{\partial L}{\partial z^*} = 1 * 0 + 1 * c = c + +Using the second way to compute Wirtinger derivatives, we directly get: + + .. math:: + \begin{aligned} + \frac{\partial s}{\partial z} &= \frac{\partial (c*z)}{\partial z} \\ + &= c \\ + \frac{\partial s}{\partial z^*} &= \frac{\partial (c*z)}{\partial z^*} \\ + &= 0 + \end{aligned} -Based on formulas above and the behavior we expect to see (going from :math:`ℂ → ℝ^2 → ℂ` should be an identity), -we use the formula given below for cross-domain functions. +And using :eq:`[4]` again, we get :math:`\frac{\partial L}{\partial z^*} = c`. As you can see, the second way involves lesser calculations, and comes +in more handy for faster calculations. -The :math:`JVP` and :math:`VJP` for a :math:`f1: ℂ → ℝ^2` are defined as: +What about cross-domain functions? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - .. math:: JVP = J * \begin{bmatrix} c \\ d \end{bmatrix} +Some functions map from complex inputs to real outputs, or vice versa. +These functions form a special case of :eq:`[4]`, which we can derive using the +chain rule: - .. math:: VJP = \begin{bmatrix} c & d \end{bmatrix} * J * \begin{bmatrix} 1 \\ -1j \end{bmatrix} + - For :math:`f: ℂ → ℝ`, we get: -The :math:`JVP` and :math:`VJP` for a :math:`f1: ℝ^2 → ℂ` are defined as: + .. math:: + \frac{\partial L}{\partial z^*} = 2 * grad\_output * \frac{\partial s}{\partial z^{*}} - .. math:: JVP = \begin{bmatrix} 1 & 1j \end{bmatrix} * J * \begin{bmatrix} c \\ d \end{bmatrix} \\ \\ + - For :math:`f: ℝ → ℂ`, we get: - .. math:: VJP = \begin{bmatrix} c & -d \end{bmatrix} * J + .. math:: + \frac{\partial L}{\partial z^*} = 2 * Re(grad\_out^* * \frac{\partial s}{\partial z^{*}}) diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index 230426be8695f..34ee143a77d57 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -65,7 +65,7 @@ available on new NVIDIA GPUs since Ampere, internally to compute matmul (matrix and batched matrix multiplies) and convolutions. TF32 tensor cores are designed to achieve better performance on matmul and convolutions on -`torch.float32` tensors by truncating input data to have 10 bits of mantissa, and accumulating +`torch.float32` tensors by rounding input data to have 10 bits of mantissa, and accumulating results with FP32 precision, maintaining FP32 dynamic range. matmuls and convolutions are controlled separately, and their corresponding flags can be accessed at: @@ -104,7 +104,7 @@ To get an idea of the precision and speed, see the example code below: ab_fp32 = a @ b # takes 0.11s on GA100 error = (ab_fp32 - ab_full).abs().max() # 0.0031 relative_error = error / mean # 0.000039 - + From the above example, we can see that with TF32 enabled, the speed is ~7x faster, relative error compared to double precision is approximately 2 orders of magnitude larger. If the full FP32 precision is needed, users can disable TF32 by: @@ -114,6 +114,13 @@ is needed, users can disable TF32 by: torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False +To toggle the TF32 flags off in C++, you can do + +.. code:: C++ + + at::globalContext().setAllowTF32CuBLAS(false); + at::globalContext().setAllowTF32CuDNN(false); + For more information about TF32, see: - `TensorFloat-32`_ @@ -189,6 +196,41 @@ necessary synchronization when data is moved around, as explained above. However, when using non-default streams, it is the user's responsibility to ensure proper synchronization. +.. _bwd-cuda-stream-semantics: + +Stream semantics of backward passes +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Internally, each backward CUDA op runs on the same stream that was used for its corresponding forward op. + +When manually supplying CUDA tensor(s) as a backward pass's initial gradient(s) (e.g., +:func:`autograd.backward(..., grad_tensors=initial_grads)`, +:func:`autograd.grad(..., grad_outputs=initial_grads)`, or +:meth:`tensor.backward(..., gradient=initial_grad)`), +the acts of + +1. populating the initial gradient(s) and +2. invoking the backward pass + +have the same stream-semantics relationship as any pair of ops:: + + # Safe, populating initial_grad and invoking backward are in the same stream context + with torch.cuda.stream(strm): + loss.backward(gradient=torch.ones_like(loss)) + + # Unsafe, populating initial_grad and invoking backward are in different stream contexts, + # without synchronization + initial_grad = torch.ones_like(loss) + with torch.cuda.stream(strm): + loss.backward(gradient=initial_grad) + + # Safe, with synchronization + initial_grad = torch.ones_like(loss) + strm.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(strm): + initial_grad.record_stream(strm) + loss.backward(gradient=initial_grad) + .. _CUDA stream: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#streams .. _cuda-memory-management: @@ -214,13 +256,17 @@ complete snapshot of the memory allocator state via :meth:`~torch.cuda.memory_snapshot`, which can help you understand the underlying allocation patterns produced by your code. +Use of a caching allocator can interfere with memory checking tools such as +``cuda-memcheck``. To debug memory errors using ``cuda-memcheck``, set +``PYTORCH_NO_CUDA_MEMORY_CACHING=1`` in your environment to disable caching. + .. _cufft-plan-cache: cuFFT plan cache ---------------- For each CUDA device, an LRU cache of cuFFT plans is used to speed up repeatedly -running FFT methods (e.g., :func:`torch.fft`) on CUDA tensors of same geometry +running FFT methods (e.g., :func:`torch.fft.fft`) on CUDA tensors of same geometry with same configuration. Because some cuFFT plans may allocate GPU memory, these caches have a maximum capacity. @@ -399,7 +445,7 @@ The difference between :class:`~torch.nn.parallel.DistributedDataParallel` and uses multiprocessing where a process is created for each GPU, while :class:`~torch.nn.DataParallel` uses multithreading. By using multiprocessing, each GPU has its dedicated process, this avoids the performance overhead caused -by GIL of Python interpreter. +by GIL of Python interpreter. -If you use :class:`~torch.nn.parallel.DistributedDataParallel`, you could use +If you use :class:`~torch.nn.parallel.DistributedDataParallel`, you could use `torch.distributed.launch` utility to launch your program, see :ref:`distributed-launch`. diff --git a/docs/source/notes/ddp.rst b/docs/source/notes/ddp.rst index 66cb5d67eea64..b7b2676bb2065 100644 --- a/docs/source/notes/ddp.rst +++ b/docs/source/notes/ddp.rst @@ -143,7 +143,7 @@ the structure of the code. ProcessGroup ------------ -- `ProcessGroup.hpp `__: +- `ProcessGroup.hpp `__: contains the abstract API of all process group implementations. The ``c10d`` library provides 3 implementations out of the box, namely, `ProcessGroupGloo`, `ProcessGroupNCCL`, and `ProcessGroupMPI`. @@ -152,13 +152,13 @@ ProcessGroup and ``ProcessGroup::allreduce()`` to sum gradients. -- `Store.hpp `__: +- `Store.hpp `__: assists the rendezvous service for process group instances to find each other. DistributedDataParallel ----------------------- -- `distributed.py `__: +- `distributed.py `__: is the Python entry point for DDP. It implements the initialization steps and the ``forward`` function for the ``nn.parallel.DistributedDataParallel`` module which call into C++ libraries. Its ``_sync_param`` function performs @@ -167,12 +167,12 @@ DistributedDataParallel all other processes. The inter-process parameter synchronization happens in ``Reducer.cpp``. -- `comm.h `__: +- `comm.h `__: implements the coalesced broadcast helper function which is invoked to broadcast model states during initialization and synchronize model buffers before the forward pass. -- `reducer.h `__: +- `reducer.h `__: provides the core implementation for gradient synchronization in the backward pass. It has three entry point functions: diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst index d26aca1fee1af..50100720b33b3 100644 --- a/docs/source/notes/extending.rst +++ b/docs/source/notes/extending.rst @@ -247,6 +247,8 @@ This is how a ``Linear`` module can be implemented:: self.input_features, self.output_features, self.bias is not None ) +.. _extending-torch: + Extending :mod:`torch` ---------------------- @@ -605,13 +607,13 @@ provides a developer-facing API for ensuring full support for changes without warning in the future. First, to get a listing of all overridable functions, use -``torch._overrides.get_overridable_functions``. This returns a dictionary whose +``torch.overrides._get_overridable_functions``. This returns a dictionary whose keys are namespaces in the ``PyTorch`` Python API and whose values are a list of functions in that namespace that can be overriden. For example, let's print the names of the first 5 functions in ``torch.nn.functional`` that can be overriden:: - >>> from torch._overrides import get_overridable_functions + >>> from torch.overrides import get_overridable_functions >>> func_dict = get_overridable_functions() >>> nn_funcs = func_dict[torch.nn.functional] >>> print([f.__name__ for f in nn_funcs[:5]) @@ -622,20 +624,20 @@ This listing of functions makes it possible to iterate over all overridable functions, however in practice this is not enough to write tests for all of these functions without laboriously and manually copying the signature of each function for each test. To ease this process, the -``torch._overrides.get_testing_overrides`` function returns a dictionary mapping +``torch.overrides._get_testing_overrides`` function returns a dictionary mapping overridable functions in the ``PyTorch`` API to dummy lambda functions that have the same signature as the original function but unconditionally return -1. These functions are most useful to use with ``inspect`` to analyze the function signature of the original ``PyTorch`` function:: >>> import inspect - >>> from torch._overrides import get_testing_overrides + >>> from torch.overrides import get_testing_overrides >>> override_dict = get_testing_overrides() >>> dummy_add = override_dict[torch.add] >>> inspect.signature(dummy_add) -Finally, ``torch._overrides.get_ignored_functions`` returns a tuple of functions +Finally, ``torch.overrides.get_ignored_functions`` returns a tuple of functions that explicitly cannot be overrided by ``__torch_function__``. This list can be useful to confirm that a function that isn't present in the dictionary returned by ``get_overridable_functions`` cannot be overriden. diff --git a/docs/source/notes/randomness.rst b/docs/source/notes/randomness.rst index 6f629074a4f42..84ffb43f8c0c3 100644 --- a/docs/source/notes/randomness.rst +++ b/docs/source/notes/randomness.rst @@ -1,3 +1,5 @@ +.. _reproducibility: + Reproducibility =============== @@ -30,6 +32,14 @@ CPU and CUDA):: import torch torch.manual_seed(0) +Python +------ + +For custom operators, you might need to set python seed as well:: + + inport random + random.seed(0) + Random number generators in other libraries ------------------------------------------- If you or any of the libraries you are using rely on NumPy, you can seed the global @@ -121,3 +131,22 @@ CUDA RNN and LSTM In some versions of CUDA, RNNs and LSTM networks may have non-deterministic behavior. See :meth:`torch.nn.RNN` and :meth:`torch.nn.LSTM` for details and workarounds. +DataLoader +.......... + +DataLoader will reseed workers following :ref:`data-loading-randomness` algorithm. +Use :meth:`worker_init_fn` to preserve reproducibility:: + + def seed_worker(worker_id): + worker_seed = torch.initial_seed() % 2**32 + numpy.random.seed(worker_seed) + random.seed(worker_seed) + + DataLoader( + train_dataset, + batch_size=batch_size, + num_workers=num_workers, + worker_init_fn=seed_worker + ) + + diff --git a/docs/source/notes/windows.rst b/docs/source/notes/windows.rst index 443d3849582f9..cc195e7a93a95 100644 --- a/docs/source/notes/windows.rst +++ b/docs/source/notes/windows.rst @@ -20,14 +20,15 @@ MKL and MAGMA. Here are the steps to build with them. REM Download MAGMA files REM version available: + REM 2.5.4 (CUDA 10.1 10.2 11.0 11.1) x (Debug Release) REM 2.5.3 (CUDA 10.1 10.2 11.0) x (Debug Release) REM 2.5.2 (CUDA 9.2 10.0 10.1 10.2) x (Debug Release) REM 2.5.1 (CUDA 9.2 10.0 10.1 10.2) x (Debug Release) REM 2.5.0 (CUDA 9.0 9.2 10.0 10.1) x (Debug Release) REM 2.4.0 (CUDA 8.0 9.2) x (Release) - set CUDA_PREFIX=cuda92 + set CUDA_PREFIX=cuda101 set CONFIG=release - curl -k https://s3.amazonaws.com/ossci-windows/magma_2.5.1_%CUDA_PREFIX%_%CONFIG%.7z -o magma.7z + curl -k https://s3.amazonaws.com/ossci-windows/magma_2.5.4_%CUDA_PREFIX%_%CONFIG%.7z -o magma.7z 7z x -aoa magma.7z -omagma REM Setting essential environment variables diff --git a/docs/source/onnx.rst b/docs/source/onnx.rst index 3c07486b0e899..a1f88a64c83d0 100644 --- a/docs/source/onnx.rst +++ b/docs/source/onnx.rst @@ -249,6 +249,233 @@ E.g.: :: out = model(*inputs) torch.onnx.export(model, inputs, 'loop_and_list.onnx', opset_version=11, example_outputs=out) +Write PyTorch model in Torch way +-------------------------------- + +PyTorch models can be written using numpy manipulations, but this is not proper when we convert to the ONNX model. +For the trace-based exporter, tracing treats the numpy values as the constant node, +therefore it calculates the wrong result if we change the input. +So the PyTorch model need implement using torch operators. +For example, do not use numpy operators on numpy tensors: :: + + np.concatenate((x, y, z), axis=1) + +do not convert to numpy types: :: + + y = x.astype(np.int) + +Always use torch tensors and torch operators: torch.concat, etc. +In addition, Dropout layer need defined in init function so that inferencing can handle it properly, i.e., :: + + class MyModule(nn.Module): + def __init__(self): + self.dropout = nn.Dropout(0.5) + + def forward(self, x): + x = self.dropout(x) + +Using dictionaries to handle Named Arguments as model inputs +------------------------------------------------------------ + +There are two ways to handle models which consist of named parameters or keyword arguments as inputs: + +* The first method is to pass all the inputs in the same order as required by the model and pass None + values for the keyword arguments that do not require a value to be passed + +* The second and more intuitive method is to represent the keyword arguments as key-value pairs where + the key represents the name of the argument in the model signature and the value represents the value + of the argument to be passed + +For example, in the model: :: + + class Model(torch.nn.Module): + def forward(self, x, y=None, z=None): + if y is not None: + return x + y + if z is not None: + return x + z + return x + m = Model() + x = torch.randn(2, 3) + z = torch.randn(2, 3) + +There are two ways of exporting the model: + +* Not using a dictionary for the keyword arguments and passing all the inputs in the same order + as required by the model :: + + torch.onnx.export(model, (x, None, z), ‘test.onnx’) + +* Using a dictionary to represent the keyword arguments. This dictionary is always passed in + addition to the non-keyword arguments and is always the last argument in the args tuple. :: + + torch.onnx.export(model, (x, {'y': None, 'z': z}), ‘test.onnx’) + +For cases in which there are no keyword arguments, models can be exported with either an +empty or no dictionary. For example, :: + + torch.onnx.export(model, (x, {}), ‘test.onnx’) + or + torch.onnx.export(model, (x, ), ‘test.onnx’) + +An exception to this rule are cases in which the last input is also of a dictionary type. +In these cases it is mandatory to have an empty dictionary as the last argument in the +args tuple. For example, :: + + class Model(torch.nn.Module): + def forward(self, k, x): + ... + return x + m = Model() + k = torch.randn(2, 3)   + x = {torch.tensor(1.): torch.randn(2, 3)} + +Without the presence of the empty dictionary, the export call assumes that the +‘x’ input is intended to represent the optional dictionary consisting of named arguments. +In order to prevent this from being an issue a constraint is placed to provide an empty +dictionary as the last input in the tuple args in such cases. +The new call would look like this. :: + + torch.onnx.export(model, (k, x, {}), ‘test.onnx’) + + +Indexing +-------- + +Tensor indexing in PyTorch is very flexible and complicated. +There are two categories of indexing. Both are largely supported in exporting today. +If you are experiencing issues exporting indexing that belongs to the supported patterns below, +please double check that you are exporting with the latest opset (opset_version=12). + +Getter +~~~~~~ + +This type of indexing occurs on the RHS. Export is supported for ONNX opset version >= 9. E.g.: :: + + data = torch.randn(3, 4) + index = torch.tensor([1, 2]) + + # RHS indexing is supported in ONNX opset >= 11. + class RHSIndexing(torch.nn.Module): + def forward(self, data, index): + return data[index] + + out = RHSIndexing()(data, index) + + torch.onnx.export(RHSIndexing(), (data, index), 'indexing.onnx', opset_version=9) + + # onnxruntime + import onnxruntime + sess = onnxruntime.InferenceSession('indexing.onnx') + out_ort = sess.run(None, { + sess.get_inputs()[0].name: data.numpy(), + sess.get_inputs()[1].name: index.numpy(), + }) + + assert torch.all(torch.eq(out, torch.tensor(out_ort))) + +Below is the list of supported patterns for RHS indexing. :: + + # Scalar indices + data[0, 1] + + # Slice indices + data[:3] + + # Tensor indices + data[torch.tensor([[1, 2], [2, 3]])] + data[torch.tensor([2, 3]), torch.tensor([1, 2])] + data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] + data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] + + # Ellipsis + # Not supported in scripting + # i.e. torch.jit.script(model) will fail if model contains this pattern. + # Export is supported under tracing + # i.e. torch.onnx.export(model) + data[...] + + # The combination of above + data[2, ..., torch.tensor([2, 1, 3]), 2:4, torch.tensor([[1], [2]])] + + # Boolean mask (supported for ONNX opset version >= 11) + data[data != 1] + +And below is the list of unsupported patterns for RHS indexing. :: + + # Tensor indices that includes negative values. + data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])] + +Setter +~~~~~~ + +In code, this type of indexing occurs on the LHS. +Export is supported for ONNX opset version >= 11. E.g.: :: + + data = torch.zeros(3, 4) + new_data = torch.arange(4).to(torch.float32) + + # LHS indexing is supported in ONNX opset >= 11. + class LHSIndexing(torch.nn.Module): + def forward(self, data, new_data): + data[1] = new_data + return data + + out = LHSIndexing()(data, new_data) + + data = torch.zeros(3, 4) + new_data = torch.arange(4).to(torch.float32) + torch.onnx.export(LHSIndexing(), (data, new_data), 'inplace_assign.onnx', opset_version=11) + + # onnxruntime + import onnxruntime + sess = onnxruntime.InferenceSession('inplace_assign.onnx') + out_ort = sess.run(None, { + sess.get_inputs()[0].name: torch.zeros(3, 4).numpy(), + sess.get_inputs()[1].name: new_data.numpy(), + }) + + assert torch.all(torch.eq(out, torch.tensor(out_ort))) + +Below is the list of supported patterns for LHS indexing. :: + + # Scalar indices + data[0, 1] = new_data + + # Slice indices + data[:3] = new_data + + # Tensor indices + # If more than one tensor are used as indices, only consecutive 1-d tensor indices are supported. + data[torch.tensor([[1, 2], [2, 3]])] = new_data + data[torch.tensor([2, 3]), torch.tensor([1, 2])] = new_data + + # Ellipsis + # Not supported to export in script modules + # i.e. torch.onnx.export(torch.jit.script(model)) will fail if model contains this pattern. + # Export is supported under tracing + # i.e. torch.onnx.export(model) + data[...] = new_data + + # The combination of above + data[2, ..., torch.tensor([2, 1, 3]), 2:4] += update + + # Boolean mask + data[data != 1] = new_data + +And below is the list of unsupported patterns for LHS indexing. :: + + # Multiple tensor indices if any has rank >= 2 + data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data + + # Multiple tensor indices that are not consecutive + data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data + + # Tensor indices that includes negative values. + data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data + +If you are experiencing issues exporting indexing that belongs to the above supported patterns, please double check that +you are exporting with the latest opset (opset_version=12). TorchVision support ------------------- @@ -348,6 +575,7 @@ The following operators are supported: * glu * group_norm * gt +* hardswish * hardtanh * im2col * index_copy @@ -527,7 +755,7 @@ but intuitively the interface they provide looks like this:: ONNX outputs whose values correspond to the original PyTorch return values of the autograd Function (or None if an output is not supported by ONNX). - Arguments: + Args: g (Graph): graph to write the ONNX representation into inputs (Value...): list of values representing the variables which contain the inputs for this function @@ -554,7 +782,7 @@ but intuitively the interface they provide looks like this:: The set of operators and the inputs/attributes they take is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md - Arguments: + Args: opname (string): The ONNX operator name, e.g., `Abs` or `Add`. args (Value...): The inputs to the operator; usually provided as arguments to the `symbolic` definition. @@ -649,16 +877,16 @@ This mode is used to export all operators as regular ONNX operators. This is the Example torch ir graph: - graph(%0 : Float(2:12, 3:4, 4:1)): - %3 : Float(2:12, 3:4, 4:1) = aten:exp(%0) - %4 : Float(2:12, 3:4, 4:1) = aten:div(%0, %3) + graph(%0 : Float(2, 3, 4, strides=[12, 4, 1])): + %3 : Float(2, 3, 4, strides=[12, 4, 1]) = aten:exp(%0) + %4 : Float(2, 3, 4, strides=[12, 4, 1]) = aten:div(%0, %3) return (%4) Is exported as: - graph(%0 : Float(2:12, 3:4, 4:1)): - %1 : Float(2:12, 3:4, 4:1) = onnx:Exp(%0) - %2 : Float(2:12, 3:4, 4:1) = onnx:Div(%0, %1) + graph(%0 : Float(2, 3, 4, strides=[12, 4, 1])): + %1 : Float(2, 3, 4, strides=[12, 4, 1]) = onnx:Exp(%0) + %2 : Float(2, 3, 4, strides=[12, 4, 1]) = onnx:Div(%0, %1) return (%2) @@ -668,16 +896,16 @@ This mode is used to export all operators as ATen ops, and avoid conversion to O Example torch ir graph: - graph(%0 : Float(2:12, 3:4, 4:1)): - %3 : Float(2:12, 3:4, 4:1) = aten::exp(%0) - %4 : Float(2:12, 3:4, 4:1) = aten::div(%0, %3) + graph(%0 : Float(2, 3, 4, strides=[12, 4, 1])): + %3 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::exp(%0) + %4 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::div(%0, %3) return (%4) Is exported as: - graph(%0 : Float(2:12, 3:4, 4:1)): - %1 : Float(2:12, 3:4, 4:1) = aten::ATen[operator="exp"](%0) - %2 : Float(2:12, 3:4, 4:1) = aten::ATen[operator="div"](%0, %1) + graph(%0 : Float(2, 3, 4, strides=[12, 4, 1])): + %1 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::ATen[operator="exp"](%0) + %2 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::ATen[operator="div"](%0, %1) return (%2) ONNX_ATEN_FALLBACK @@ -707,7 +935,7 @@ To export a raw ir. :: Example torch ir graph: - graph(%x.1 : Float(1:1)): + graph(%x.1 : Float(1, strides=[1])): %1 : Tensor = aten::exp(%x.1) %2 : Tensor = aten::div(%x.1, %1) %y.1 : Tensor[] = prim::ListConstruct(%2) @@ -715,7 +943,7 @@ To export a raw ir. :: is exported as: - graph(%x.1 : Float(1:1)): + graph(%x.1 : Float(1, strides=[1])): %1 : Tensor = aten::exp(%x.1) %2 : Tensor = aten::div(%x.1, %1) %y.1 : Tensor[] = prim::ListConstruct(%2) @@ -729,18 +957,18 @@ enables users to register and implement the operator as part of their runtime ba Example torch ir graph: - graph(%0 : Float(2:12, 3:4, 4:1), - %1 : Float(2:12, 3:4, 4:1)): - %6 : Float(2:12, 3:4, 4:1) = foo_namespace::bar(%0, %1) # custom op - %7 : Float(2:12, 3:4, 4:1) = aten::div(%6, %0) # registered op + graph(%0 : Float(2, 3, 4, strides=[12, 4, 1]), + %1 : Float(2, 3, 4, strides=[12, 4, 1])): + %6 : Float(2, 3, 4, strides=[12, 4, 1]) = foo_namespace::bar(%0, %1) # custom op + %7 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::div(%6, %0) # registered op return (%7)) is exported as: - graph(%0 : Float(2:12, 3:4, 4:1), - %1 : Float(2:12, 3:4, 4:1)): - %2 : Float(2:12, 3:4, 4:1) = foo_namespace::bar(%0, %1) # custom op - %3 : Float(2:12, 3:4, 4:1) = onnx::Div(%2, %0) # registered op + graph(%0 : Float(2, 3, 4, strides=[12, 4, 1]), + %1 : Float(2, 3, 4, strides=[12, 4, 1])): + %2 : Float(2, 3, 4, strides=[12, 4, 1]) = foo_namespace::bar(%0, %1) # custom op + %3 : Float(2, 3, 4, strides=[12, 4, 1]) = onnx::Div(%2, %0) # registered op return (%3 @@ -807,32 +1035,7 @@ Q: Does ONNX support implicit scalar datatype casting? Q: Is tensor in-place indexed assignment like `data[index] = new_data` supported? - Yes, this is supported now for ONNX opset version >= 11. E.g.: :: - - data = torch.zeros(3, 4) - new_data = torch.arange(4).to(torch.float32) - - # Assigning to left hand side indexing is supported in ONNX opset >= 11. - class InPlaceIndexedAssignment(torch.nn.Module): - def forward(self, data, new_data): - data[1] = new_data - return data - - out = InPlaceIndexedAssignment()(data, new_data) - - data = torch.zeros(3, 4) - new_data = torch.arange(4).to(torch.float32) - torch.onnx.export(InPlaceIndexedAssignment(), (data, new_data), 'inplace_assign.onnx', opset_version=11) - - # onnxruntime - import onnxruntime - sess = onnxruntime.InferenceSession('inplace_assign.onnx') - out_ort = sess.run(None, { - sess.get_inputs()[0].name: torch.zeros(3, 4).numpy(), - sess.get_inputs()[1].name: new_data.numpy(), - }) - - assert torch.all(torch.eq(out, torch.tensor(out_ort))) + Yes, this is supported for ONNX opset version >= 11. Please checkout `Indexing`_. Q: Is tensor list exportable to ONNX? diff --git a/docs/source/optim.rst b/docs/source/optim.rst index 2ecb75fd83519..936206a5e2ded 100644 --- a/docs/source/optim.rst +++ b/docs/source/optim.rst @@ -269,7 +269,7 @@ and start to collect SWA averages of the parameters at epoch 160: >>> optimizer.zero_grad() >>> loss_fn(model(input), target).backward() >>> optimizer.step() ->>> if i > swa_start: +>>> if epoch > swa_start: >>> swa_model.update_parameters(model) >>> swa_scheduler.step() >>> else: diff --git a/docs/source/profiler.rst b/docs/source/profiler.rst new file mode 100644 index 0000000000000..6541f08c4feb1 --- /dev/null +++ b/docs/source/profiler.rst @@ -0,0 +1,17 @@ +.. currentmodule:: torch.profiler + +torch.profiler +============== + +Overview +-------- +.. automodule:: torch.profiler + + +API Reference +------------- + +.. autoclass:: torch.profiler.profile + :members: + +.. autofunction:: torch.profiler.schedule diff --git a/docs/source/quantization-support.rst b/docs/source/quantization-support.rst index a1b13e1ada1f5..f782d51b50277 100644 --- a/docs/source/quantization-support.rst +++ b/docs/source/quantization-support.rst @@ -105,7 +105,7 @@ Fused modules are provided for common patterns in CNNs. Combining several operations together (like convolution and relu) allows for better quantization accuracy - + * `torch.nn.intrinsic` — float versions of the modules, can be swapped with quantized version 1 to 1: @@ -172,7 +172,6 @@ Layers for the quantization-aware training * :func:`~torch.quantization.fuse_modules` * Functions for graph mode quantization: - * :func:`~torch.quantization.quantize_jit` - Function for graph mode post training static quantization * :func:`~torch.quantization.quantize_dynamic_jit` - Function for graph mode post training dynamic quantization @@ -255,7 +254,6 @@ Quantized version of standard NN layers. * :class:`~torch.nn.quantized.Conv3d` — 3D convolution * :class:`~torch.nn.quantized.Linear` — Linear (fully-connected) layer * :class:`~torch.nn.MaxPool2d` — 2D max pooling -* :class:`~torch.nn.quantized.ReLU` — Rectified linear unit * :class:`~torch.nn.quantized.ReLU6` — Rectified linear unit with cut-off at quantized representation of 6 * :class:`~torch.nn.quantized.ELU` — ELU @@ -294,7 +292,6 @@ quantization output parameters) * :func:`~torch.nn.quantized.functional.interpolate` — Down-/up- sampler * :func:`~torch.nn.quantized.functional.linear` — Linear (fully-connected) op * :func:`~torch.nn.quantized.functional.max_pool2d` — 2D max pooling -* :func:`~torch.nn.quantized.functional.relu` — Rectified linear unit * :func:`~torch.nn.quantized.functional.elu` — ELU * :func:`~torch.nn.quantized.functional.hardsigmoid` — Hardsigmoid * :func:`~torch.nn.quantized.functional.hardswish` — Hardswish @@ -324,5 +321,3 @@ Quantized dtypes and quantization schemes * :attr:`torch.quint8` — 8-bit unsigned integer * :attr:`torch.qint8` — 8-bit signed integer * :attr:`torch.qint32` — 32-bit signed integer - - diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index b597fa9f51f3c..1cac90ffab869 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -45,8 +45,8 @@ The corresponding implementation is chosen automatically based on the PyTorch bu .. note:: - PyTorch 1.3 doesn't provide quantized operator implementations on CUDA yet - - this is direction of future work. Move the model to CPU in order to test the + At the moment PyTorch doesn't provide quantized operator implementations on CUDA - + this is the direction for future work. Move the model to CPU in order to test the quantized functionality. Quantization-aware training (through :class:`~torch.quantization.FakeQuantize`) @@ -77,6 +77,356 @@ The corresponding implementation is chosen automatically based on the PyTorch bu ``torch.backends.quantized.engine = 'qnnpack'`` +Quantization API Summary +--------------------------------------- + +PyTorch provides two different modes of quantization: Eager Mode Quantization and FX Graph Mode Quantization. + +Eager Mode Quantization is a beta feature. User needs to do fusion and specify where quantization and dequantization happens manually, also it only supports modules and not functionals. + +FX Graph Mode Quantization is a new automated quantization framework in PyTorch, and currently it's a prototype feature. It improves upon Eager Mode Quantization by adding support for functionals and automating the quantization process. Although people might need to refactor the model a bit to make the model compatible with FX Graph Mode Quantization (symbolically traceable with torch.fx). + +Eager Mode Quantization +^^^^^^^^^^^^^^^^^^^^^^^ + +There are three types of quantization supported in Eager Mode Quantization: + +1. dynamic quantization (weights quantized with activations read/stored in + floating point and quantized for compute.) +2. static quantization (weights quantized, activations quantized, calibration + required post training) +3. quantization aware training (weights quantized, activations quantized, + quantization numerics modeled during training) + +Please see our `Introduction to Quantization on Pytorch +`_ blog post +for a more comprehensive overview of the tradeoffs between these quantization +types. + +Dynamic Quantization +~~~~~~~~~~~~~~~~~~~~ + +This is the simplest to apply form of quantization where the weights are +quantized ahead of time but the activations are dynamically quantized +during inference. This is used for situations where the model execution time +is dominated by loading weights from memory rather than computing the matrix +multiplications. This is true for for LSTM and Transformer type models with +small batch size. + +Diagram:: + + # original model + # all tensors and computations are in floating point + previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32 + / + linear_weight_fp32 + + # dynamically quantized model + # linear and conv weights are in int8 + previous_layer_fp32 -- linear_int8_w_fp32_inp -- activation_fp32 -- next_layer_fp32 + / + linear_weight_int8 + +API example:: + + import torch + + # define a floating point model + class M(torch.nn.Module): + def __init__(self): + super(M, self).__init__() + self.fc = torch.nn.Linear(4, 4) + + def forward(self, x): + x = self.fc(x) + return x + + # create a model instance + model_fp32 = M() + # create a quantized model instance + model_int8 = torch.quantization.quantize_dynamic( + model_fp32, # the original model + {torch.nn.Linear}, # a set of layers to dynamically quantize + dtype=torch.qint8) # the target dtype for quantized weights + + # run the model + input_fp32 = torch.randn(4, 4, 4, 4) + res = model_int8(input_fp32) + +To learn more about dynamic quantization please see our `dynamic quantization tutorial +`_. + +Static Quantization +~~~~~~~~~~~~~~~~~~~ + +Static quantization quantizes the weights and activations of the model. It +fuses activations into preceding layers where possible. It requires +calibration with a representative dataset to determine optimal quantization +parameters for activations. Post Training Quantization is typically used when +both memory bandwidth and compute savings are important with CNNs being a +typical use case. Static quantization is also known as Post Training +Quantization or PTQ. + +Diagram:: + + # original model + # all tensors and computations are in floating point + previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32 + / + linear_weight_fp32 + + # statically quantized model + # weights and activations are in int8 + previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8 + / + linear_weight_int8 + +API Example:: + + import torch + + # define a floating point model where some layers could be statically quantized + class M(torch.nn.Module): + def __init__(self): + super(M, self).__init__() + # QuantStub converts tensors from floating point to quantized + self.quant = torch.quantization.QuantStub() + self.conv = torch.nn.Conv2d(1, 1, 1) + self.relu = torch.nn.ReLU() + # DeQuantStub converts tensors from quantized to floating point + self.dequant = torch.quantization.DeQuantStub() + + def forward(self, x): + # manually specify where tensors will be converted from floating + # point to quantized in the quantized model + x = self.quant(x) + x = self.conv(x) + x = self.relu(x) + # manually specify where tensors will be converted from quantized + # to floating point in the quantized model + x = self.dequant(x) + return x + + # create a model instance + model_fp32 = M() + + # model must be set to eval mode for static quantization logic to work + model_fp32.eval() + + # attach a global qconfig, which contains information about what kind + # of observers to attach. Use 'fbgemm' for server inference and + # 'qnnpack' for mobile inference. Other quantization configurations such + # as selecting symmetric or assymetric quantization and MinMax or L2Norm + # calibration techniques can be specified here. + model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm') + + # Fuse the activations to preceding layers, where applicable. + # This needs to be done manually depending on the model architecture. + # Common fusions include `conv + relu` and `conv + batchnorm + relu` + model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv', 'relu']]) + + # Prepare the model for static quantization. This inserts observers in + # the model that will observe activation tensors during calibration. + model_fp32_prepared = torch.quantization.prepare(model_fp32_fused) + + # calibrate the prepared model to determine quantization parameters for activations + # in a real world setting, the calibration would be done with a representative dataset + input_fp32 = torch.randn(4, 1, 4, 4) + model_fp32_prepared(input_fp32) + + # Convert the observed model to a quantized model. This does several things: + # quantizes the weights, computes and stores the scale and bias value to be + # used with each activation tensor, and replaces key operators with quantized + # implementations. + model_int8 = torch.quantization.convert(model_fp32_prepared) + + # run the model, relevant calculations will happen in int8 + res = model_int8(input_fp32) + +To learn more about static quantization, please see the `static quantization tutorial +`_. + +Quantization Aware Training +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Quantization Aware Training models the effects of quantization during training +allowing for higher accuracy compared to other quantization methods. During +training, all calculations are done in floating point, with fake_quant modules +modeling the effects of quantization by clamping and rounding to simulate the +effects of INT8. After model conversion, weights and +activations are quantized, and activations are fused into the preceding layer +where possible. It is commonly used with CNNs and yields a higher accuracy +compared to static quantization. Quantization Aware Training is also known as +QAT. + +Diagram:: + + # original model + # all tensors and computations are in floating point + previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32 + / + linear_weight_fp32 + + # model with fake_quants for modeling quantization numerics during training + previous_layer_fp32 -- fq -- linear_fp32 -- activation_fp32 -- fq -- next_layer_fp32 + / + linear_weight_fp32 -- fq + + # quantized model + # weights and activations are in int8 + previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8 + / + linear_weight_int8 + +API Example:: + + import torch + + # define a floating point model where some layers could benefit from QAT + class M(torch.nn.Module): + def __init__(self): + super(M, self).__init__() + # QuantStub converts tensors from floating point to quantized + self.quant = torch.quantization.QuantStub() + self.conv = torch.nn.Conv2d(1, 1, 1) + self.bn = torch.nn.BatchNorm2d(1) + self.relu = torch.nn.ReLU() + # DeQuantStub converts tensors from quantized to floating point + self.dequant = torch.quantization.DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + x = self.dequant(x) + return x + + # create a model instance + model_fp32 = M() + + # model must be set to train mode for QAT logic to work + model_fp32.train() + + # attach a global qconfig, which contains information about what kind + # of observers to attach. Use 'fbgemm' for server inference and + # 'qnnpack' for mobile inference. Other quantization configurations such + # as selecting symmetric or assymetric quantization and MinMax or L2Norm + # calibration techniques can be specified here. + model_fp32.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') + + # fuse the activations to preceding layers, where applicable + # this needs to be done manually depending on the model architecture + model_fp32_fused = torch.quantization.fuse_modules(model_fp32, + [['conv', 'bn', 'relu']]) + + # Prepare the model for QAT. This inserts observers and fake_quants in + # the model that will observe weight and activation tensors during calibration. + model_fp32_prepared = torch.quantization.prepare_qat(model_fp32_fused) + + # run the training loop (not shown) + training_loop(model_fp32_prepared) + + # Convert the observed model to a quantized model. This does several things: + # quantizes the weights, computes and stores the scale and bias value to be + # used with each activation tensor, fuses modules where appropriate, + # and replaces key operators with quantized implementations. + model_fp32_prepared.eval() + model_int8 = torch.quantization.convert(model_fp32_prepared) + + # run the model, relevant calculations will happen in int8 + res = model_int8(input_fp32) + +To learn more about quantization aware training, please see the `QAT +tutorial +`_. + +(Prototype) FX Graph Mode Quantization +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Quantization types supported by FX Graph Mode can be classified in two ways: + +1. +- Post Training Quantization (apply quantization after training, quantization parameters are calculated based on sample calibration data) +- Quantization Aware Training (simulate quantization during training so that the quantization parameters can be learned together with the model using training data) + +2. +- Weight Only Quantization (only weight is statically quantized) +- Dynamic Quantization (weight is statically quantized, activation is dynamically quantized) +- Static Quantization (both weight and activations are statically quantized) + +These two ways of classification are independent, so theoretically we can have 6 different types of quantization. + +The supported quantization types in FX Graph Mode Quantization are: +- Post Training Quantization + + - Weight Only Quantization + - Dynamic Quantization + - Static Quantization + +- Quantization Aware Training + + - Static Quantization + + +There are multiple quantization types in post training quantization (weight only, dynamic and static) and the configuration is done through `qconfig_dict` (an argument of the `prepare_fx` function). + +API Example:: + + import torch.quantization.quantize_fx as quantize_fx + import copy + + model_fp = UserModel(...) + + # + # post training dynamic/weight_only quantization + # + + # we need to deepcopy if we still want to keep model_fp unchanged after quantization since quantization apis change the input model + model_to_quantize = copy.deepcopy(model_fp) + model_to_quantize.eval() + qconfig_dict = {"": torch.quantization.default_dynamic_qconfig} + # prepare + model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict) + # no calibration needed when we only have dynamici/weight_only quantization + # quantize + model_quantized = quantize_fx.convert_fx(model_prepared) + + # + # post training static quantization + # + + model_to_quantize = copy.deepcopy(model_fp) + qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')} + model_to_quantize.eval() + # prepare + model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict) + # calibrate (not shown) + # quantize + model_quantized = quantize_fx.convert_fx(model_prepared) + + # + # quantization aware training for static quantization + # + + model_to_quantize = copy.deepcopy(model_fp) + qconfig_dict = {"": torch.quantization.get_default_qat_qconfig('qnnpack')} + model_to_quantize.train() + # prepare + model_prepared = quantize_fx.prepare_qat_fx(model_to_qunatize, qconfig_dict) + # training loop (not shown) + # quantize + model_quantized = quantize_fx.convert_fx(model_prepared) + + # + # fusion + # + model_to_quantize = copy.deepcopy(model_fp) + model_fused = quantize_fx.fuse_fx(model_to_quantize) + +Please see the following tutorials for more information about FX Graph Mode Quantization: +- FX Graph Mode Post Training Static Quantization (TODO: link) +- FX Graph Mode Post Training Dynamic Quantization (TODO: link) + Quantized Tensors --------------------------------------- @@ -121,79 +471,8 @@ cover typical CNN and RNN models torch.nn.quantized torch.nn.quantized.dynamic -Quantization Workflows ----------------------- - -PyTorch provides three approaches to quantize models. - -.. _quantization tutorials: - https://pytorch.org/tutorials/#quantization-experimental - -1. Post Training Dynamic Quantization: This is the simplest to apply form of - quantization where the weights are quantized ahead of time but the - activations are dynamically quantized during inference. This is used - for situations where the model execution time is dominated by loading - weights from memory rather than computing the matrix multiplications. - This is true for for LSTM and Transformer type models with small - batch size. Applying dynamic quantization to a whole model can be - done with a single call to :func:`torch.quantization.quantize_dynamic()`. - See the `quantization tutorials`_ -2. Post Training Static Quantization: This is the most commonly used form of - quantization where the weights are quantized ahead of time and the - scale factor and bias for the activation tensors is pre-computed - based on observing the behavior of the model during a calibration - process. Post Training Quantization is typically when both memory bandwidth - and compute savings are important with CNNs being a typical use case. - The general process for doing post training quantization is: - - - - 1. Prepare the model: - - a. Specify where the activations are quantized and dequantized explicitly - by adding QuantStub and DeQuantStub modules. - b. Ensure that modules are not reused. - c. Convert any operations that require requantization into modules - - 2. Fuse operations like conv + relu or conv+batchnorm + relu together to - improve both model accuracy and performance. - - 3. Specify the configuration of the quantization methods \'97 such as - selecting symmetric or asymmetric quantization and MinMax or - L2Norm calibration techniques. - 4. Use the :func:`torch.quantization.prepare` to insert modules - that will observe activation tensors during calibration - 5. Calibrate the model by running inference against a calibration - dataset - 6. Finally, convert the model itself with the - torch.quantization.convert() method. This does several things: it - quantizes the weights, computes and stores the scale and bias - value to be used each activation tensor, and replaces key - operators quantized implementations. - - See the `quantization tutorials`_ - - -3. Quantization Aware Training: In the rare cases where post training - quantization does not provide adequate accuracy training can be done - with simulated quantization using the - :class:`torch.quantization.FakeQuantize`. Computations will take place in - FP32 but with values clamped and rounded to simulate the effects of INT8 - quantization. The sequence of steps is very similar. - - - 1. Steps (1) and (2) are identical. - - 3. Specify the configuration of the fake quantization methods \'97 such as - selecting symmetric or asymmetric quantization and MinMax or Moving Average - or L2Norm calibration techniques. - 4. Use the :func:`torch.quantization.prepare_qat` to insert modules - that will simulate quantization during training. - 5. Train or fine tune the model. - 6. Identical to step (6) for post training quantization - - See the `quantization tutorials`_ - +Quantization Customizations +--------------------------- While default implementations of observers to select the scale factor and bias based on observed tensor data are provided, developers can provide their own @@ -218,9 +497,15 @@ prior to quantization. This is because currently quantization works on a module by module basis. Specifically, for all quantization techniques, the user needs to: 1. Convert any operations that require output requantization (and thus have - additional parameters) from functionals to module form. + additional parameters) from functionals to module form (for example, + using ``torch.nn.ReLU`` instead of ``torch.nn.functional.relu``). 2. Specify which parts of the model need to be quantized either by assigning - ```.qconfig`` attributes on submodules or by specifying ``qconfig_dict`` + ``.qconfig`` attributes on submodules or by specifying ``qconfig_dict``. + For example, setting ``model.conv1.qconfig = None`` means that the + ``model.conv`` layer will not be quantized, and setting + ``model.linear1.qconfig = custom_qconfig`` means that the quantization + settings for ``model.linear1`` will be using ``custom_qconfig`` instead + of the global qconfig. For static quantization techniques which quantize activations, the user needs to do the following in addition: @@ -238,6 +523,78 @@ to do the following in addition: to be fused. We currently support the following fusions: [Conv, Relu], [Conv, BatchNorm], [Conv, BatchNorm, Relu], [Linear, Relu] +Best Practices +-------------- + +1. Set the ``reduce_range`` argument on observers to `True` if you are using the + ``fbgemm`` backend. This argument prevents overflow on some int8 instructions + by reducing the range of quantized data type by 1 bit. + +Common Errors +--------------------------------------- + +Passing a non-quantized Tensor into a quantized kernel +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If you see an error similar to:: + + RuntimeError: Could not run 'quantized::some_operator' with arguments from the 'CPU' backend... + +This means that you are trying to pass a non-quantized Tensor to a quantized +kernel. A common workaround is to use ``torch.quantization.QuantStub`` to +quantize the tensor. This needs to be done manually in Eager mode quantization. +An e2e example:: + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.quant = torch.quantization.QuantStub() + self.conv = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + # during the convert step, this will be replaced with a + # `quantize_per_tensor` call + x = self.quant(x) + x = self.conv(x) + return x + +Passing a quantized Tensor into a non-quantized kernel +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If you see an error similar to:: + + RuntimeError: Could not run 'aten::thnn_conv2d_forward' with arguments from the 'QuantizedCPU' backend. + +This means that you are trying to pass a quantized Tensor to a non-quantized +kernel. A common workaround is to use ``torch.quantization.DeQuantStub`` to +dequantize the tensor. This needs to be done manually in Eager mode quantization. +An e2e example:: + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.quant = torch.quantization.QuantStub() + self.conv1 = torch.nn.Conv2d(1, 1, 1) + # this module will not be quantized (see `qconfig = None` logic below) + self.conv2 = torch.nn.Conv2d(1, 1, 1) + self.dequant = torch.quantization.DeQuantStub() + + def forward(self, x): + # during the convert step, this will be replaced with a + # `quantize_per_tensor` call + x = self.quant(x) + x = self.conv1(x) + # during the convert step, this will be replaced with a + # `dequantize` call + x = self.dequant(x) + x = self.conv2(x) + return x + + m = M() + m.qconfig = some_qconfig + # turn off quantization for conv2 + m.conv2.qconfig = None + Modules that provide quantization functions and classes ------------------------------------------------------- diff --git a/docs/source/rpc.rst b/docs/source/rpc.rst index 1e4788c99634b..1d786710d15ca 100644 --- a/docs/source/rpc.rst +++ b/docs/source/rpc.rst @@ -113,8 +113,6 @@ and move it to the desired devices on the callee if necessary. The RPC package also provides decorators which allow applications to specify how a given function should be treated on the callee side. -.. warning:: - The ``rpc.functions`` package is a prototype feature and subject to change. .. autofunction:: torch.distributed.rpc.functions.async_execution @@ -142,9 +140,6 @@ to configure the backend's behavior. TensorPipe Backend """""""""""""""""" -.. warning:: - The TensorPipe backend is a **beta feature**. - The TensorPipe agent, which is the default, leverages `the TensorPipe library `_, which provides a natively point-to-point communication primitive specifically suited for machine learning @@ -192,6 +187,10 @@ Example:: Process Group Backend """"""""""""""""""""" +.. warning :: + The Process Group Backend will be deprecated soon, we recommend using the + TensorPipe Backend instead. + The Process Group agent instantiates a process group from the :mod:`~torch.distributed` module and utilizes its point-to-point communication capabilities to send RPC messages. Internally, the process @@ -301,3 +300,5 @@ to use `the profiler `__ - `Implementing a Parameter Server using Distributed RPC Framework `__ - `Combining Distributed DataParallel with Distributed RPC Framework `__ - `Profiling RPC-based Workloads `__ +- `Implementing batch RPC processing `__ +- `Distributed Pipeline Parallel `__ diff --git a/docs/source/rpc/distributed_autograd.rst b/docs/source/rpc/distributed_autograd.rst index 4a4f7855a7339..61af22b9486f5 100644 --- a/docs/source/rpc/distributed_autograd.rst +++ b/docs/source/rpc/distributed_autograd.rst @@ -270,7 +270,7 @@ As an example the complete code with distributed autograd would be as follows: # Retrieve the gradients from the context. dist_autograd.get_gradients(context_id) -The distributed autograd graph with dependencies would be as follows: +The distributed autograd graph with dependencies would be as follows (t5.sum() excluded for simplicity): .. image:: ../_static/img/distributed_autograd/distributed_dependencies_computed.png diff --git a/docs/source/rpc/rref.rst b/docs/source/rpc/rref.rst index 822f10a32f8cd..3d51971110389 100644 --- a/docs/source/rpc/rref.rst +++ b/docs/source/rpc/rref.rst @@ -13,7 +13,7 @@ Background ^^^^^^^^^^ RRef stands for Remote REFerence. It is a reference of an object which is -located on the local or a remote worker, and transparently handles reference +located on the local or remote worker, and transparently handles reference counting under the hood. Conceptually, it can be considered as a distributed shared pointer. Applications can create an RRef by calling :meth:`~torch.distributed.rpc.remote`. Each RRef is owned by the callee worker @@ -42,9 +42,9 @@ Assumptions RRef protocol is designed with the following assumptions. - **Transient Network Failures**: The RRef design handles transient - network failures by retrying messages. Node crashes or permanent network - partition is beyond the scope. When those incidents occur, the application - may take down all workers, revert to the previous checkpoint, and resume + network failures by retrying messages. It cannot handle node crashes or + permanent network partitions. When those incidents occur, the application + should take down all workers, revert to the previous checkpoint, and resume training. - **Non-idempotent UDFs**: We assume the user functions (UDF) provided to :meth:`~torch.distributed.rpc.rpc_sync`, @@ -87,12 +87,12 @@ The only requirement is that any ``UserRRef`` must notify the owner upon destruction. Hence, we need the first guarantee: -**G1. The owner will be notified when any ``UserRRef`` is deleted.** +**G1. The owner will be notified when any UserRRef is deleted.** As messages might come delayed or out-of-order, we need one more guarantee to make sure the delete message is not processed too soon. If A sends a message to -B that involves an RRef, we call the RRef on A the parent RRef and the RRef on B -the child RRef. +B that involves an RRef, we call the RRef on A (the parent RRef) and the RRef on B +(the child RRef). **G2. Parent RRef will NOT be deleted until the child RRef is confirmed by the owner.** @@ -125,19 +125,19 @@ possible that the child ``UserRRef`` may be deleted before the owner knows its parent ``UserRRef``. Consider the following example, where the ``OwnerRRef`` forks to A, then A forks -to Y, and Y forks to Z.: +to Y, and Y forks to Z: .. code:: OwnerRRef -> A -> Y -> Z If all of Z's messages, including the delete message, are processed by the -owner before all messages from Y, the owner will learn Z's deletion before -knowing Y. Nevertheless, this does not cause any problem. Because, at least -one of Y's ancestors will be alive (in this case, A) and it will +owner before Y's messages. the owner will learn of Z's deletion befores +knowing Y exists. Nevertheless, this does not cause any problem. Because, at least +one of Y's ancestors will be alive (A) and it will prevent the owner from deleting the ``OwnerRRef``. More specifically, if the -owner does not know Y, A cannot be deleted due to **G2**, and the owner knows A -as the owner is A's parent. +owner does not know Y, A cannot be deleted due to **G2**, and the owner knows A +since it is A's parent. Things get a little trickier if the RRef is created on a user: diff --git a/docs/source/sparse.rst b/docs/source/sparse.rst index 13084fba0861d..2a2c0b1daabd7 100644 --- a/docs/source/sparse.rst +++ b/docs/source/sparse.rst @@ -1,146 +1,522 @@ -.. currentmodule:: torch.sparse +.. currentmodule:: torch .. _sparse-docs: torch.sparse ============ +Introduction +++++++++++++ + +PyTorch provides :class:`torch.Tensor` to represent a +multi-dimensional array containing elements of a single data type. By +default, array elements are stored contiguously in memory leading to +efficient implementations of various array processing algorithms that +relay on the fast access to array elements. However, there exists an +important class of multi-dimensional arrays, so-called sparse arrays, +where the contiguous memory storage of array elements turns out to be +suboptimal. Sparse arrays have a property of having a vast portion of +elements being equal to zero which means that a lot of memory as well +as processor resources can be spared if only the non-zero elements are +stored or/and processed. Various sparse storage formats (`such as COO, +CSR/CSC, LIL, etc.`__) have been developed that are optimized for a +particular structure of non-zero elements in sparse arrays as well as +for specific operations on the arrays. + +__ https://en.wikipedia.org/wiki/Sparse_matrix + +.. note:: + + When talking about storing only non-zero elements of a sparse + array, the usage of adjective "non-zero" is not strict: one is + allowed to store also zeros in the sparse array data + structure. Hence, in the following, we use "specified elements" for + those array elements that are actually stored. In addition, the + unspecified elements are typically assumed to have zero value, but + not only, hence we use the term "fill value" to denote such + elements. + +.. note:: + + Using a sparse storage format for storing sparse arrays can be + advantageous only when the size and sparsity levels of arrays are + high. Otherwise, for small-sized or low-sparsity arrays using the + contiguous memory storage format is likely the most efficient + approach. + .. warning:: - This API is in beta and may change in the near future. + The PyTorch API of sparse tensors is in beta and may change in the near future. + +.. _sparse-coo-docs: -Torch supports sparse tensors in COO(rdinate) format, which can -efficiently store and process tensors for which the majority of elements -are zeros. +Sparse COO tensors +++++++++++++++++++ -A sparse tensor is represented as a pair of dense tensors: a tensor -of values and a 2D tensor of indices. A sparse tensor can be constructed -by providing these two tensors, as well as the size of the sparse tensor -(which cannot be inferred from these tensors!) Suppose we want to define -a sparse tensor with the entry 3 at location (0, 2), entry 4 at -location (1, 0), and entry 5 at location (1, 2). We would then write: +Currently, PyTorch implements the so-called Coordinate format, or COO +format, as the default sparse storage format for storing sparse +tensors. In COO format, the specified elements are stored as tuples +of element indices and the corresponding values. In particular, - >>> i = torch.LongTensor([[0, 1, 1], - [2, 0, 2]]) - >>> v = torch.FloatTensor([3, 4, 5]) - >>> torch.sparse.FloatTensor(i, v, torch.Size([2,3])).to_dense() - 0 0 3 - 4 0 5 - [torch.FloatTensor of size 2x3] + - the indices of specified elements are collected in ``indices`` + tensor of size ``(ndim, nse)`` and with element type + ``torch.int64``, -Note that the input to LongTensor is NOT a list of index tuples. If you want + - the corresponding values are collected in ``values`` tensor of + size ``(nse,)`` and with an arbitrary integer or floating point + number element type, + +where ``ndim`` is the dimensionality of the tensor and ``nse`` is the +number of specified elements. + +.. note:: + + The memory consumption of a sparse COO tensor is at least ``(ndim * + 8 + ) * nse`` bytes (plus a constant + overhead from storing other tensor data). + + The memory consumption of a strided tensor is at least + ``product() * ``. + + For example, the memory consumption of a 10 000 x 10 000 tensor + with 100 000 non-zero 32-bit floating point numbers is at least + ``(2 * 8 + 4) * 100 000 = 2 000 000`` bytes when using COO tensor + layout and ``10 000 * 10 000 * 4 = 400 000 000`` bytes when using + the default strided tensor layout. Notice the 200 fold memory + saving from using the COO storage format. + +Construction +------------ + +A sparse COO tensor can be constructed by providing the two tensors of +indices and values, as well as the size of the sparse tensor (when it +cannot be inferred from the indices and values tensors) to a function +:func:`torch.sparse_coo_tensor`. + +Suppose we want to define a sparse tensor with the entry 3 at location +(0, 2), entry 4 at location (1, 0), and entry 5 at location (1, 2). +Unspecified elements are assumed to have the same value, fill value, +which is zero by default. We would then write: + + >>> i = [[0, 1, 1], + [2, 0, 2]] + >>> v = [3, 4, 5] + >>> s = torch.sparse_coo_tensor(i, v, (2, 3)) + >>> s + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3, 4, 5]), + size=(2, 3), nnz=3, layout=torch.sparse_coo) + >>> s.to_dense() + tensor([[0, 0, 3], + [4, 0, 5]]) + +Note that the input ``i`` is NOT a list of index tuples. If you want to write your indices this way, you should transpose before passing them to the sparse constructor: - >>> i = torch.LongTensor([[0, 2], [1, 0], [1, 2]]) - >>> v = torch.FloatTensor([3, 4, 5 ]) - >>> torch.sparse.FloatTensor(i.t(), v, torch.Size([2,3])).to_dense() - 0 0 3 - 4 0 5 - [torch.FloatTensor of size 2x3] - -You can also construct hybrid sparse tensors, where only the first n -dimensions are sparse, and the rest of the dimensions are dense. - - >>> i = torch.LongTensor([[2, 4]]) - >>> v = torch.FloatTensor([[1, 3], [5, 7]]) - >>> torch.sparse.FloatTensor(i, v).to_dense() - 0 0 - 0 0 - 1 3 - 0 0 - 5 7 - [torch.FloatTensor of size 5x2] - -An empty sparse tensor can be constructed by specifying its size: - - >>> torch.sparse.FloatTensor(2, 3) - SparseFloatTensor of size 2x3 with indices: - [torch.LongTensor with no dimension] - and values: - [torch.FloatTensor with no dimension] - -SparseTensor has the following invariants: - 1. sparse_dim + dense_dim = len(SparseTensor.shape) - 2. SparseTensor._indices().shape = (sparse_dim, nnz) - 3. SparseTensor._values().shape = (nnz, SparseTensor.shape[sparse_dim:]) - -Since SparseTensor._indices() is always a 2D tensor, the smallest sparse_dim = 1. -Therefore, representation of a SparseTensor of sparse_dim = 0 is simply a dense tensor. + >>> i = [[0, 2], [1, 0], [1, 2]] + >>> v = [3, 4, 5 ] + >>> s = torch.sparse_coo_tensor(list(zip(*i)), v, (2, 3)) + >>> # Or another equivalent formulation to get s + >>> s = torch.sparse_coo_tensor(torch.tensor(i).t(), v, (2, 3)) + >>> torch.sparse_coo_tensor(i.t(), v, torch.Size([2,3])).to_dense() + tensor([[0, 0, 3], + [4, 0, 5]]) + +An empty sparse COO tensor can be constructed by specifying its size +only: + + >>> torch.sparse_coo_tensor(size=(2, 3)) + tensor(indices=tensor([], size=(2, 0)), + values=tensor([], size=(0,)), + size=(2, 3), nnz=0, layout=torch.sparse_coo) + +.. _sparse-hybrid-coo-docs: + +Hybrid sparse COO tensors +------------------------- + +Pytorch implements an extension of sparse tensors with scalar values +to sparse tensors with (contiguous) tensor values. Such tensors are +called hybrid tensors. + +PyTorch hybrid COO tensor extends the sparse COO tensor by allowing +the ``values`` tensor to be a multi-dimensional tensor so that we +have: + + - the indices of specified elements are collected in ``indices`` + tensor of size ``(sparse_dims, nse)`` and with element type + ``torch.int64``, + + - the corresponding (tensor) values are collected in ``values`` + tensor of size ``(nse, dense_dims)`` and with an arbitrary integer + or floating point number element type. + +.. note:: + + We use (M + K)-dimensional tensor to denote a N-dimensional hybrid + sparse tensor, where M and K are the numbers of sparse and dense + dimensions, respectively, such that M + K == N holds. + +Suppose we want to create a (2 + 1)-dimensional tensor with the entry +[3, 4] at location (0, 2), entry [5, 6] at location (1, 0), and entry +[7, 8] at location (1, 2). We would write + + >>> i = [[0, 1, 1], + [2, 0, 2]] + >>> v = [[3, 4], [5, 6], [7, 8]] + >>> s = torch.sparse_coo_tensor(i, v, (2, 3, 2)) + >>> s + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([[3, 4], + [5, 6], + [7, 8]]), + size=(2, 3, 2), nnz=3, layout=torch.sparse_coo) + + >>> s.to_dense() + tensor([[[0, 0], + [0, 0], + [3, 4]], + [[5, 6], + [0, 0], + [7, 8]]]) + +In general, if ``s`` is a sparse COO tensor and ``M = +s.sparse_dim()``, ``K = s.dense_dim()``, then we have the following +invariants: + + - ``M + K == len(s.shape) == s.ndim`` - dimensionality of a tensor + is the sum of the number of sparse and dense dimensions, + - ``s.indices().shape == (M, nse)`` - sparse indices are stored + explicitly, + - ``s.values().shape == (nse,) + s.shape[M : M + K]`` - the values + of a hybrid tensor are K-dimensional tensors, + - ``s.values().layout == torch.strided`` - values are stored as + strided tensors. + +.. note:: + + Dense dimensions always follow sparse dimensions, that is, mixing + of dense and sparse dimensions is not supported. + +.. _sparse-uncoalesced-coo-docs: + +Uncoalesced sparse COO tensors +------------------------------ + +PyTorch sparse COO tensor format permits *uncoalesced* sparse tensors, +where there may be duplicate coordinates in the indices; in this case, +the interpretation is that the value at that index is the sum of all +duplicate value entries. For example, one can specify multiple values, +``3`` and ``4``, for the same index ``1``, that leads to an 1-D +uncoalesced tensor: + + >>> i = [[1, 1]] + >>> v = [3, 4] + >>> s=torch.sparse_coo_tensor(i, v, (3,)) + >>> s + tensor(indices=tensor([[1, 1]]), + values=tensor( [3, 4]), + size=(3,), nnz=2, layout=torch.sparse_coo) + +while the coalescing process will accumulate the multi-valued elements +into a single value using summation: + + >>> s.coalesce() + tensor(indices=tensor([[1]]), + values=tensor([7]), + size=(3,), nnz=1, layout=torch.sparse_coo) + +In general, the output of :meth:`torch.Tensor.coalesce` method is a +sparse tensor with the following properties: + +- the indices of specified tensor elements are unique, +- the indices are sorted in lexicographical order, +- :meth:`torch.Tensor.is_coalesced()` returns ``True``. + +.. note:: + + For the most part, you shouldn't have to care whether or not a + sparse tensor is coalesced or not, as most operations will work + identically given a coalesced or uncoalesced sparse tensor. + + However, some operations can be implemented more efficiently on + uncoalesced tensors, and some on coalesced tensors. + + For instance, addition of sparse COO tensors is implemented by + simply concatenating the indices and values tensors: + + >>> a = torch.sparse_coo_tensor([[1, 1]], [5, 6], (2,)) + >>> b = torch.sparse_coo_tensor([[0, 0]], [7, 8], (2,)) + >>> a + b + tensor(indices=tensor([[0, 0, 1, 1]]), + values=tensor([7, 8, 5, 6]), + size=(2,), nnz=4, layout=torch.sparse_coo) + + If you repeatedly perform an operation that can produce duplicate + entries (e.g., :func:`torch.Tensor.add`), you should occasionally + coalesce your sparse tensors to prevent them from growing too large. + + On the other hand, the lexicographical ordering of indices can be + advantageous for implementing algorithms that involve many element + selection operations, such as slicing or matrix products. + + +Working with sparse COO tensors +------------------------------- + +Let's consider the following example: + + >>> i = [[0, 1, 1], + [2, 0, 2]] + >>> v = [[3, 4], [5, 6], [7, 8]] + >>> s = torch.sparse_coo_tensor(i, v, (2, 3, 2)) + +As mentioned above, a sparse COO tensor is a :class:`torch.Tensor` +instance and to distinguish it from the `Tensor` instances that use +some other layout, on can use :attr:`torch.Tensor.is_sparse` or +:attr:`torch.Tensor.layout` properties: + + >>> isinstance(s, torch.Tensor) + True + >>> s.is_sparse + True + >>> s.layout == torch.sparse_coo + True + +The number of sparse and dense dimensions can be acquired using +methods :meth:`torch.Tensor.sparse_dim` and +:meth:`torch.Tensor.dense_dim`, respectively. For instance: + + >>> s.sparse_dim(), s.dense_dim() + (2, 1) + + +If ``s`` is a sparse COO tensor then its COO format data can be +acquired using methods :meth:`torch.Tensor.indices()` and +:meth:`torch.Tensor.values()`. + +.. note:: + + Currently, one can acquire the COO format data only when the tensor + instance is coalesced: + + >>> s.indices() + RuntimeError: Cannot get indices on an uncoalesced tensor, please call .coalesce() first + + For acquiring the COO format data of an uncoalesced tensor, use + :func:`torch.Tensor._values()` and :func:`torch.Tensor._indices()`: + + >>> s._indices() + tensor([[0, 1, 1], + [2, 0, 2]]) + + .. See https://github.com/pytorch/pytorch/pull/45695 for a new API. + +Constructing a new sparse COO tensor results a tensor that is not +coalesced: + + >>> s.is_coalesced() + False + +but one can construct a coalesced copy of a sparse COO tensor using +the :meth:`torch.Tensor.coalesce` method: + + >>> s2 = s.coalesce() + >>> s2.indices() + tensor([[0, 1, 1], + [2, 0, 2]]) + +When working with uncoalesced sparse COO tensors, one must take into +an account the additive nature of uncoalesced data: the values of the +same indices are the terms of a sum that evaluation gives the value of +the corresponding tensor element. For example, the scalar +multiplication on an uncoalesced sparse tensor could be implemented by +multiplying all the uncoalesced values with the scalar because ``c * +(a + b) == c * a + c * b`` holds. However, any nonlinear operation, +say, a square root, cannot be implemented by applying the operation to +uncoalesced data because ``sqrt(a + b) == sqrt(a) + sqrt(b)`` does not +hold in general. + +Slicing (with positive step) of a sparse COO tensor is supported only +for dense dimensions. Indexing is supported for both sparse and dense +dimensions: + + >>> s[1] + tensor(indices=tensor([[0, 2]]), + values=tensor([[5, 6], + [7, 8]]), + size=(3, 2), nnz=2, layout=torch.sparse_coo) + >>> s[1, 0, 1] + tensor(6) + >>> s[1, 0, 1:] + tensor([6]) + + +In PyTorch, the fill value of a sparse tensor cannot be specified +explicitly and is assumed to be zero in general. However, there exists +operations that may interpret the fill value differently. For +instance, :func:`torch.sparse.softmax` computes the softmax with the +assumption that the fill value is negative infinity. + +.. See https://github.com/Quansight-Labs/rfcs/tree/pearu/rfc-fill-value/RFC-0004-sparse-fill-value for a new API + +Supported Linear Algebra operations ++++++++++++++++++++++++++++++++++++ + +The following table summarizes supported Linear Algebra operations on +sparse matrices where the operands layouts may vary. Here +``T[layout]`` denotes a tensor with a given layout. Similarly, +``M[layout]`` denotes a matrix (2-D PyTorch tensor), and ``V[layout]`` +denotes a vector (1-D PyTorch tensor). In addition, ``f`` denotes a +scalar (float or 0-D PyTorch tensor), ``*`` is element-wise +multiplication, and ``@`` is matrix multiplication. + +.. csv-table:: + :header: "PyTorch operation", "Sparse grad?", "Layout signature" + :widths: 20, 5, 60 + :delim: ; + + :func:`torch.mv`;no; ``M[sparse_coo] @ V[strided] -> V[strided]`` + :func:`torch.matmul`; no; ``M[sparse_coo] @ M[strided] -> M[strided]`` + :func:`torch.mm`; no; ``M[sparse_coo] @ M[strided] -> M[strided]`` + :func:`torch.sparse.mm`; yes; ``M[sparse_coo] @ M[strided] -> M[strided]`` + :func:`torch.smm`; no; ``M[sparse_coo] @ M[strided] -> M[sparse_coo]`` + :func:`torch.hspmm`; no; ``M[sparse_coo] @ M[strided] -> M[hybrid sparse_coo]`` + :func:`torch.bmm`; no; ``T[sparse_coo] @ T[strided] -> T[strided]`` + :func:`torch.addmm`; no; ``f * M[strided] + f * (M[sparse_coo] @ M[strided]) -> M[strided]`` + :func:`torch.sparse.addmm`; yes; ``f * M[strided] + f * (M[sparse_coo] @ M[strided]) -> M[strided]`` + :func:`torch.sspaddmm`; no; ``f * M[sparse_coo] + f * (M[sparse_coo] @ M[strided]) -> M[sparse_coo]`` + :func:`torch.lobpcg`; no; ``GENEIG(M[sparse_coo]) -> M[strided], M[strided]`` + :func:`torch.pca_lowrank`; yes; ``PCA(M[sparse_coo]) -> M[strided], M[strided], M[strided]`` + :func:`torch.svd_lowrank`; yes; ``SVD(M[sparse_coo]) -> M[strided], M[strided], M[strided]`` + +where "Sparse grad?" column indicates if the PyTorch operation supports +backward with respect to sparse matrix argument. All PyTorch operations, +except :func:`torch.smm`, support backward with respect to strided +matrix arguments. .. note:: - Our sparse tensor format permits *uncoalesced* sparse tensors, where - there may be duplicate coordinates in the indices; in this case, - the interpretation is that the value at that index is the sum of all - duplicate value entries. Uncoalesced tensors permit us to implement - certain operators more efficiently. - - For the most part, you shouldn't have to care whether or not a - sparse tensor is coalesced or not, as most operations will work - identically given a coalesced or uncoalesced sparse tensor. - However, there are two cases in which you may need to care. - - First, if you repeatedly perform an operation that can produce - duplicate entries (e.g., :func:`torch.sparse.FloatTensor.add`), you - should occasionally coalesce your sparse tensors to prevent - them from growing too large. - - Second, some operators will produce different values depending on - whether or not they are coalesced or not (e.g., - :func:`torch.sparse.FloatTensor._values` and - :func:`torch.sparse.FloatTensor._indices`, as well as - :func:`torch.Tensor.sparse_mask`). These operators are - prefixed by an underscore to indicate that they reveal internal - implementation details and should be used with care, since code - that works with coalesced sparse tensors may not work with - uncoalesced sparse tensors; generally speaking, it is safest - to explicitly coalesce before working with these operators. - - For example, suppose that we wanted to implement an operator - by operating directly on :func:`torch.sparse.FloatTensor._values`. - Multiplication by a scalar can be implemented in the obvious way, - as multiplication distributes over addition; however, square root - cannot be implemented directly, since ``sqrt(a + b) != sqrt(a) + - sqrt(b)`` (which is what would be computed if you were given an - uncoalesced tensor.) - -.. class:: FloatTensor() - - .. method:: add - .. method:: add_ - .. method:: clone - .. method:: dim - .. method:: div - .. method:: div_ - .. method:: get_device - .. method:: hspmm - .. method:: mm - .. method:: mul - .. method:: mul_ - .. method:: narrow_copy - .. method:: resizeAs_ - .. method:: size - .. method:: spadd - .. method:: spmm - .. method:: sspaddmm - .. method:: sspmm - .. method:: sub - .. method:: sub_ - .. method:: t_ - .. method:: to_dense - .. method:: transpose - .. method:: transpose_ - .. method:: zero_ - .. method:: coalesce - .. method:: is_coalesced - .. method:: _indices - .. method:: _values - .. method:: _nnz - -Functions ----------------------------------- + Currently, PyTorch does not support matrix multiplication with the + layout signature ``M[strided] @ M[sparse_coo]``. However, + applications can still compute this using the matrix relation ``D @ + S == (S.t() @ D.t()).t()``. +.. class:: Tensor() + :noindex: + + The following methods are specific to :ref:`sparse tensors `: + + .. autoattribute:: is_sparse + .. automethod:: dense_dim + .. automethod:: sparse_dim + .. automethod:: sparse_mask + .. automethod:: sparse_resize_ + .. automethod:: sparse_resize_and_clear_ + .. automethod:: to_dense + .. automethod:: to_sparse + .. The following methods are specific to :ref:`sparse COO tensors `: + .. automethod:: coalesce + .. automethod:: is_coalesced + .. automethod:: indices + .. automethod:: values + +The following :class:`torch.Tensor` methods support :ref:`sparse COO +tensors `: + +:meth:`~torch.Tensor.add` +:meth:`~torch.Tensor.add_` +:meth:`~torch.Tensor.addmm` +:meth:`~torch.Tensor.addmm_` +:meth:`~torch.Tensor.any` +:meth:`~torch.Tensor.asin` +:meth:`~torch.Tensor.asin_` +:meth:`~torch.Tensor.arcsin` +:meth:`~torch.Tensor.arcsin_` +:meth:`~torch.Tensor.bmm` +:meth:`~torch.Tensor.clone` +:meth:`~torch.Tensor.deg2rad` +:meth:`~torch.Tensor.deg2rad_` +:meth:`~torch.Tensor.detach` +:meth:`~torch.Tensor.detach_` +:meth:`~torch.Tensor.dim` +:meth:`~torch.Tensor.div` +:meth:`~torch.Tensor.div_` +:meth:`~torch.Tensor.floor_divide` +:meth:`~torch.Tensor.floor_divide_` +:meth:`~torch.Tensor.get_device` +:meth:`~torch.Tensor.index_select` +:meth:`~torch.Tensor.isnan` +:meth:`~torch.Tensor.log1p` +:meth:`~torch.Tensor.log1p_` +:meth:`~torch.Tensor.mm` +:meth:`~torch.Tensor.mul` +:meth:`~torch.Tensor.mul_` +:meth:`~torch.Tensor.mv` +:meth:`~torch.Tensor.narrow_copy` +:meth:`~torch.Tensor.neg` +:meth:`~torch.Tensor.neg_` +:meth:`~torch.Tensor.negative` +:meth:`~torch.Tensor.negative_` +:meth:`~torch.Tensor.numel` +:meth:`~torch.Tensor.rad2deg` +:meth:`~torch.Tensor.rad2deg_` +:meth:`~torch.Tensor.resize_as_` +:meth:`~torch.Tensor.size` +:meth:`~torch.Tensor.pow` +:meth:`~torch.Tensor.square` +:meth:`~torch.Tensor.smm` +:meth:`~torch.Tensor.sspaddmm` +:meth:`~torch.Tensor.sub` +:meth:`~torch.Tensor.sub_` +:meth:`~torch.Tensor.t` +:meth:`~torch.Tensor.t_` +:meth:`~torch.Tensor.transpose` +:meth:`~torch.Tensor.transpose_` +:meth:`~torch.Tensor.zero_` + + +Sparse tensor functions ++++++++++++++++++++++++ + +.. autofunction:: torch.sparse_coo_tensor +.. autofunction:: torch.sparse.sum .. autofunction:: torch.sparse.addmm .. autofunction:: torch.sparse.mm -.. autofunction:: torch.sparse.sum +.. autofunction:: torch.sspaddmm +.. autofunction:: torch.hspmm +.. autofunction:: torch.smm +.. autofunction:: torch.sparse.softmax +.. autofunction:: torch.sparse.log_softmax + +Other functions ++++++++++++++++ + +The following :mod:`torch` functions support :ref:`sparse COO tensors `: + +:func:`~torch.cat` +:func:`~torch.dstack` +:func:`~torch.empty` +:func:`~torch.empty_like` +:func:`~torch.hstack` +:func:`~torch.index_select` +:func:`~torch.is_complex` +:func:`~torch.is_floating_point` +:func:`~torch.is_nonzero` +:func:`~torch.is_same_size` +:func:`~torch.is_signed` +:func:`~torch.is_tensor` +:func:`~torch.lobpcg` +:func:`~torch.mm` +:func:`~torch.native_norm` +:func:`~torch.pca_lowrank` +:func:`~torch.select` +:func:`~torch.stack` +:func:`~torch.svd_lowrank` +:func:`~torch.unsqueeze` +:func:`~torch.vstack` +:func:`~torch.zeros` +:func:`~torch.zeros_like` diff --git a/docs/source/tensor_view.rst b/docs/source/tensor_view.rst index ce3f4ff5edaf4..059a76e2b28db 100644 --- a/docs/source/tensor_view.rst +++ b/docs/source/tensor_view.rst @@ -73,6 +73,8 @@ For reference, here’s a full list of view ops in PyTorch: - :meth:`~torch.Tensor.unbind` - :meth:`~torch.Tensor.split` - :meth:`~torch.Tensor.split_with_sizes` +- :meth:`~torch.Tensor.swapaxes` +- :meth:`~torch.Tensor.swapdims` - :meth:`~torch.Tensor.chunk` - :meth:`~torch.Tensor.indices` (sparse tensor only) - :meth:`~torch.Tensor.values` (sparse tensor only) diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 7cd1a88f82b38..1baf34dd955ea 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -189,6 +189,8 @@ view of a storage and defines numeric operations on it. .. automethod:: addcmul_ .. automethod:: addmm .. automethod:: addmm_ + .. automethod:: sspaddmm + :noindex: .. automethod:: addmv .. automethod:: addmv_ .. automethod:: addr @@ -212,6 +214,8 @@ view of a storage and defines numeric operations on it. .. automethod:: arctan_ .. automethod:: atan2 .. automethod:: atan2_ + .. automethod:: all + .. automethod:: any .. automethod:: backward :noindex: .. automethod:: baddbmm @@ -231,6 +235,7 @@ view of a storage and defines numeric operations on it. .. automethod:: bmm .. automethod:: bool .. automethod:: byte + .. automethod:: broadcast_to .. automethod:: cauchy_ .. automethod:: ceil .. automethod:: ceil_ @@ -247,6 +252,8 @@ view of a storage and defines numeric operations on it. .. automethod:: contiguous .. automethod:: copy_ .. automethod:: conj + .. automethod:: copysign + .. automethod:: copysign_ .. automethod:: cos .. automethod:: cos_ .. automethod:: cosh @@ -263,12 +270,15 @@ view of a storage and defines numeric operations on it. .. automethod:: cummax .. automethod:: cummin .. automethod:: cumprod + .. automethod:: cumprod_ .. automethod:: cumsum + .. automethod:: cumsum_ .. automethod:: data_ptr .. automethod:: deg2rad .. automethod:: dequantize .. automethod:: det .. automethod:: dense_dim + :noindex: .. automethod:: detach :noindex: .. automethod:: detach_ @@ -308,13 +318,14 @@ view of a storage and defines numeric operations on it. .. automethod:: exponential_ .. automethod:: fix .. automethod:: fix_ - .. automethod:: fft .. automethod:: fill_ .. automethod:: flatten .. automethod:: flip .. automethod:: fliplr .. automethod:: flipud .. automethod:: float + .. automethod:: float_power + .. automethod:: float_power_ .. automethod:: floor .. automethod:: floor_ .. automethod:: floor_divide @@ -346,7 +357,10 @@ view of a storage and defines numeric operations on it. .. automethod:: hypot_ .. automethod:: i0 .. automethod:: i0_ - .. automethod:: ifft + .. automethod:: igamma + .. automethod:: igamma_ + .. automethod:: igammac + .. automethod:: igammac_ .. automethod:: index_add_ .. automethod:: index_add .. automethod:: index_copy_ @@ -357,10 +371,11 @@ view of a storage and defines numeric operations on it. .. automethod:: index_put .. automethod:: index_select .. automethod:: indices + :noindex: + .. automethod:: inner .. automethod:: int .. automethod:: int_repr .. automethod:: inverse - .. automethod:: irfft .. automethod:: isclose .. automethod:: isfinite .. automethod:: isinf @@ -377,12 +392,15 @@ view of a storage and defines numeric operations on it. .. automethod:: is_shared .. automethod:: is_signed .. autoattribute:: is_sparse + :noindex: .. automethod:: istft .. automethod:: isreal .. automethod:: item .. automethod:: kthvalue .. automethod:: lcm .. automethod:: lcm_ + .. automethod:: ldexp + .. automethod:: ldexp_ .. automethod:: le .. automethod:: le_ .. automethod:: less_equal @@ -436,11 +454,16 @@ view of a storage and defines numeric operations on it. .. automethod:: maximum .. automethod:: mean .. automethod:: median + .. automethod:: nanmedian .. automethod:: min .. automethod:: minimum .. automethod:: mm + .. automethod:: smm + :noindex: .. automethod:: mode .. automethod:: movedim + .. automethod:: moveaxis + .. automethod:: msort .. automethod:: mul .. automethod:: mul_ .. automethod:: multiply @@ -453,6 +476,8 @@ view of a storage and defines numeric operations on it. .. automethod:: narrow .. automethod:: narrow_copy .. automethod:: ndimension + .. automethod:: nan_to_num + .. automethod:: nan_to_num_ .. automethod:: ne .. automethod:: ne_ .. automethod:: not_equal @@ -492,6 +517,7 @@ view of a storage and defines numeric operations on it. .. automethod:: q_per_channel_axis .. automethod:: rad2deg .. automethod:: random_ + .. automethod:: ravel .. automethod:: reciprocal .. automethod:: reciprocal_ .. automethod:: record_stream @@ -512,7 +538,6 @@ view of a storage and defines numeric operations on it. .. automethod:: resize_as_ .. automethod:: retain_grad :noindex: - .. automethod:: rfft .. automethod:: roll .. automethod:: rot90 .. automethod:: round @@ -536,6 +561,8 @@ view of a storage and defines numeric operations on it. .. automethod:: sgn_ .. automethod:: sin .. automethod:: sin_ + .. automethod:: sinc + .. automethod:: sinc_ .. automethod:: sinh .. automethod:: sinh_ .. automethod:: asinh @@ -548,7 +575,9 @@ view of a storage and defines numeric operations on it. .. automethod:: sort .. automethod:: split .. automethod:: sparse_mask + :noindex: .. automethod:: sparse_dim + :noindex: .. automethod:: sqrt .. automethod:: sqrt_ .. automethod:: square @@ -568,9 +597,13 @@ view of a storage and defines numeric operations on it. .. automethod:: sum .. automethod:: sum_to_size .. automethod:: svd + .. automethod:: swapaxes + .. automethod:: swapdims .. automethod:: symeig .. automethod:: t .. automethod:: t_ + .. automethod:: tensor_split + .. automethod:: tile .. automethod:: to .. automethod:: to_mkldnn .. automethod:: take @@ -585,6 +618,7 @@ view of a storage and defines numeric operations on it. .. automethod:: tolist .. automethod:: topk .. automethod:: to_sparse + :noindex: .. automethod:: trace .. automethod:: transpose .. automethod:: transpose_ @@ -607,16 +641,12 @@ view of a storage and defines numeric operations on it. .. automethod:: unsqueeze .. automethod:: unsqueeze_ .. automethod:: values + :noindex: .. automethod:: var .. automethod:: vdot .. automethod:: view .. automethod:: view_as .. automethod:: where + .. automethod:: xlogy + .. automethod:: xlogy_ .. automethod:: zero_ - -.. class:: BoolTensor() - - The following methods are unique to :class:`torch.BoolTensor`. - - .. automethod:: all - .. automethod:: any diff --git a/docs/source/torch.nn.quantized.rst b/docs/source/torch.nn.quantized.rst index a9aaa51a33bf4..aeb3b55cd5fda 100644 --- a/docs/source/torch.nn.quantized.rst +++ b/docs/source/torch.nn.quantized.rst @@ -1,14 +1,12 @@ torch.nn.quantized ------------------ -This module implements the quantized versions of the nn layers such as -~`torch.nn.Conv2d` and `torch.nn.ReLU`. +This module implements the quantized versions of the nn modules and functionals. Functional interface ~~~~~~~~~~~~~~~~~~~~ .. automodule:: torch.nn.quantized.functional -.. autofunction:: relu .. autofunction:: linear .. autofunction:: conv1d .. autofunction:: conv2d @@ -25,11 +23,6 @@ Functional interface .. automodule:: torch.nn.quantized -ReLU -~~~~~~~~~~~~~~~ -.. autoclass:: ReLU - :members: - ReLU6 ~~~~~~~~~~~~~~~ .. autoclass:: ReLU6 @@ -119,5 +112,3 @@ InstanceNorm3d ~~~~~~~~~~~~~~~ .. autoclass:: InstanceNorm3d :members: - - diff --git a/docs/source/torch.overrides.rst b/docs/source/torch.overrides.rst new file mode 100644 index 0000000000000..0630b60c4b177 --- /dev/null +++ b/docs/source/torch.overrides.rst @@ -0,0 +1,27 @@ +.. currentmodule:: torch.overrides + +torch.overrides +--------------- + +This module exposes various helper functions for the ``__torch_function__`` +protocol. See :ref:`extending-torch` for more detail on the +``__torch_function__`` protocol. + +Functions +~~~~~~~~~ + +.. autofunction:: get_ignored_functions + +.. autofunction:: get_overridable_functions + +.. autofunction:: get_testing_overrides + +.. autofunction:: handle_torch_function + +.. autofunction:: has_torch_function + +.. autofunction:: is_tensor_like + +.. autofunction:: is_tensor_method_or_property + +.. autofunction:: wrap_torch_function diff --git a/docs/source/torch.rst b/docs/source/torch.rst index beab6c449df11..922e1434bae1e 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -1,7 +1,7 @@ torch ===== The torch package contains data structures for multi-dimensional -tensors and mathematical operations over these are defined. +tensors and defines mathematical operations over these tensors. Additionally, it provides many utilities for efficient serializing of Tensors and arbitrary types, and other useful utilities. @@ -85,20 +85,29 @@ Indexing, Slicing, Joining, Mutating Ops cat chunk + column_stack dstack gather hstack index_select masked_select movedim + moveaxis narrow nonzero reshape + row_stack + scatter + scatter_add split squeeze stack + swapaxes + swapdims t take + tensor_split + tile transpose unbind unsqueeze @@ -276,6 +285,7 @@ Pointwise Ops clamp clip conj + copysign cos cosh deg2rad @@ -289,11 +299,13 @@ Pointwise Ops exp2 expm1 fix + float_power floor floor_divide fmod frac imag + ldexp lerp lgamma log @@ -309,9 +321,12 @@ Pointwise Ops logit hypot i0 + igamma + igammac mul multiply mvlgamma + nan_to_num neg negative nextafter @@ -327,6 +342,7 @@ Pointwise Ops sign signbit sin + sinc sinh sqrt square @@ -336,6 +352,7 @@ Pointwise Ops tanh true_divide trunc + xlogy Reduction Ops ~~~~~~~~~~~~~~~~~~~~~~ @@ -347,12 +364,15 @@ Reduction Ops argmin amax amin + all + any max min dist logsumexp mean median + nanmedian mode norm nansum @@ -400,6 +420,7 @@ Comparison Ops not_equal sort topk + msort Spectral Ops @@ -408,10 +429,6 @@ Spectral Ops :toctree: generated :nosignatures: - fft - ifft - rfft - irfft stft istft bartlett_window @@ -434,6 +451,8 @@ Other Operations bincount block_diag broadcast_tensors + broadcast_to + broadcast_shapes bucketize cartesian_prod cdist @@ -453,12 +472,14 @@ Other Operations flip fliplr flipud + kron rot90 gcd histc meshgrid lcm logcumsumexp + ravel renorm repeat_interleave roll @@ -494,6 +515,7 @@ BLAS and LAPACK Operations eig geqrf ger + inner inverse det logdet @@ -536,3 +558,4 @@ Utilities set_deterministic is_deterministic vmap + _assert diff --git a/ios/LibTorch.podspec b/ios/LibTorch.podspec index 17e9fb26afa18..b90cf6aff5d6e 100644 --- a/ios/LibTorch.podspec +++ b/ios/LibTorch.podspec @@ -1,6 +1,6 @@ Pod::Spec.new do |s| s.name = 'LibTorch' - s.version = '1.6.0' + s.version = '1.7.1' s.authors = 'PyTorch Team' s.license = { :type => 'BSD' } s.homepage = 'https://github.com/pytorch/pytorch' diff --git a/mode/aibench_caffe2_android b/mode/aibench_caffe2_android deleted file mode 100644 index fbf755b784386..0000000000000 --- a/mode/aibench_caffe2_android +++ /dev/null @@ -1,3 +0,0 @@ ---config -caffe2.strip_glog=0 -@fbsource//fbandroid/mode/gnustl diff --git a/mode/aibench_pytorch_android b/mode/aibench_pytorch_android deleted file mode 100644 index 2572d24d3032e..0000000000000 --- a/mode/aibench_pytorch_android +++ /dev/null @@ -1,5 +0,0 @@ ---config -user.ndk_cxxflags='-g1' ---config -pt.disable_per_op_profiling=0 -@fbsource//fbandroid/mode/ndk_libcxx diff --git a/modules/detectron/group_spatial_softmax_op.cu b/modules/detectron/group_spatial_softmax_op.cu index 92e89ae5acc2c..a37a3fba55a73 100644 --- a/modules/detectron/group_spatial_softmax_op.cu +++ b/modules/detectron/group_spatial_softmax_op.cu @@ -112,6 +112,7 @@ bool GroupSpatialSoftmaxOp::RunOnDevice() { GroupSpatialSoftmaxKernel<<>>( N, A, W, H, Xdata, Pdata, num_classes_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } @@ -158,11 +159,13 @@ bool GroupSpatialSoftmaxGradientOp::RunOnDevice() { SumProbsKernel<<>>( N, A, W, H, Ydata, dYdata, sum_probs_data, num_classes_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); // Step 2: dX[i] = dX[i] - s SubSumKernel<<>>( N, A, W, H, sum_probs_.data(), dXdata, num_classes_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); // Step 3: dX[i] = Y[i] * dX[i] math::Mul(Y.size(), dXdata, Ydata, dXdata, &context_); diff --git a/modules/detectron/ps_roi_pool_op.cu b/modules/detectron/ps_roi_pool_op.cu index 1ba418be5c990..68e4ec377d622 100644 --- a/modules/detectron/ps_roi_pool_op.cu +++ b/modules/detectron/ps_roi_pool_op.cu @@ -253,6 +253,7 @@ bool PSRoIPoolOp::RunOnDevice() { output_size, X.data(), spatial_scale_, X.dim32(1), X.dim32(2), X.dim32(3), pooled_height_, pooled_width_, R.data(), output_dim_, group_size_, Y->mutable_data(), A->mutable_data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } @@ -276,6 +277,7 @@ bool PSRoIPoolGradientOp::RunOnDevice() { dY.size(), dY.data(), A.data(), R.dim32(0), spatial_scale_, X.dim32(1), X.dim32(2), X.dim32(3), pooled_height_, pooled_width_, output_dim_, dX->mutable_data(), R.data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } diff --git a/modules/detectron/roi_pool_f_op.cu b/modules/detectron/roi_pool_f_op.cu index 62948f7eacbe4..b261911b95a16 100644 --- a/modules/detectron/roi_pool_f_op.cu +++ b/modules/detectron/roi_pool_f_op.cu @@ -149,6 +149,7 @@ bool RoIPoolFOp::RunOnDevice() { output_size, X.data(), spatial_scale_, X.dim32(1), X.dim32(2), X.dim32(3), pooled_height_, pooled_width_, R.data(), Y->mutable_data(), A->mutable_data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } @@ -173,6 +174,7 @@ bool RoIPoolFGradientOp::RunOnDevice() { dY.size(), dY.data(), A.data(), R.dim32(0), spatial_scale_, X.dim32(1), X.dim32(2), X.dim32(3), pooled_height_, pooled_width_, dX->mutable_data(), R.data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } return true; } diff --git a/modules/detectron/select_smooth_l1_loss_op.cu b/modules/detectron/select_smooth_l1_loss_op.cu index 9065bfc7afbea..ce68fcff634d6 100644 --- a/modules/detectron/select_smooth_l1_loss_op.cu +++ b/modules/detectron/select_smooth_l1_loss_op.cu @@ -129,6 +129,7 @@ bool SelectSmoothL1LossOp::RunOnDevice() { M, Y_hat.data(), Y.data(), L.data(), buff_.mutable_data(), S.data(), beta_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); // Sum of all losses // al := sum_i l_i @@ -175,6 +176,7 @@ bool SelectSmoothL1LossGradientOp::RunOnDevice() { D, H, W, M, Y_hat.data(), Y.data(), L.data(), d_Y_hat->mutable_data(), d_avg_loss.data(), scale_, S.data(), beta_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } diff --git a/modules/detectron/sigmoid_cross_entropy_loss_op.cu b/modules/detectron/sigmoid_cross_entropy_loss_op.cu index d69a7b41dc33b..bb86560fcb01f 100644 --- a/modules/detectron/sigmoid_cross_entropy_loss_op.cu +++ b/modules/detectron/sigmoid_cross_entropy_loss_op.cu @@ -93,6 +93,8 @@ bool SigmoidCrossEntropyLossOp::RunOnDevice() { T.data(), losses_.mutable_data(), counts_.mutable_data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + float* avg_loss_data = avg_loss->mutable_data(); math::Sum( losses_.size(), losses_.data(), avg_loss_data, &context_); @@ -106,6 +108,7 @@ bool SigmoidCrossEntropyLossOp::RunOnDevice() { CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(normalizer_.size(), normalizer_data, 1e-5); + C10_CUDA_KERNEL_LAUNCH_CHECK(); math::Div( 1, avg_loss_data, normalizer_data, avg_loss_data, &context_); } @@ -135,6 +138,7 @@ bool SigmoidCrossEntropyLossGradientOp::RunOnDevice() { T.data(), dX->mutable_data(), counts_.mutable_data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); if (normalize_) { float* normalizer_data = normalizer_.mutable_data(); math::Sum( @@ -145,6 +149,7 @@ bool SigmoidCrossEntropyLossGradientOp::RunOnDevice() { CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(normalizer_.size(), normalizer_data, 1e-5); + C10_CUDA_KERNEL_LAUNCH_CHECK(); math::Div( 1, d_avg_loss.data(), diff --git a/modules/detectron/sigmoid_focal_loss_op.cu b/modules/detectron/sigmoid_focal_loss_op.cu index 5b130c8dfc1fb..e6f2dea21b5df 100644 --- a/modules/detectron/sigmoid_focal_loss_op.cu +++ b/modules/detectron/sigmoid_focal_loss_op.cu @@ -134,6 +134,7 @@ bool SigmoidFocalLossOp::RunOnDevice() { N, D, H, W, X.data(), T.data(), wp.data(), gamma_, alpha_, num_classes_, losses_.mutable_data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); math::Sum( losses_.size(), losses_.data(), avg_loss_data, &context_); @@ -165,6 +166,7 @@ bool SigmoidFocalLossGradientOp::RunOnDevice() { N, D, H, W, X.data(), T.data(), dX->mutable_data(), wp.data(), gamma_, alpha_, num_classes_, d_avg_loss.data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); math::Scale( dX->size(), scale_, diff --git a/modules/detectron/smooth_l1_loss_op.cu b/modules/detectron/smooth_l1_loss_op.cu index 1a3e8b78b53f1..ea835a4bc2b97 100644 --- a/modules/detectron/smooth_l1_loss_op.cu +++ b/modules/detectron/smooth_l1_loss_op.cu @@ -102,6 +102,7 @@ bool SmoothL1LossOp::RunOnDevice() { context_.cuda_stream()>>>( buff_.size(), buff_.data(), buff_.mutable_data(), beta_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); // Element-wise weighted smooth l1 loss (can be used to specify a per-element // loss weight) @@ -164,6 +165,8 @@ bool SmoothL1LossGradientOp::RunOnDevice() { context_.cuda_stream()>>>( buff_.size(), buff_.data(), d_Y_hat->mutable_data(), d_avg_loss.data(), scale_ / N, beta_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + // Element-wise scale by alpha_in and alpha_out math::Mul( d_Y_hat->size(), d_Y_hat->data(), alpha_in.data(), diff --git a/modules/detectron/softmax_focal_loss_op.cu b/modules/detectron/softmax_focal_loss_op.cu index 93635269f176c..b7f8d2423ebc0 100644 --- a/modules/detectron/softmax_focal_loss_op.cu +++ b/modules/detectron/softmax_focal_loss_op.cu @@ -176,6 +176,7 @@ bool SoftmaxFocalLossOp::RunOnDevice() { <<>>( N, A, H, W, Xdata, P->mutable_data(), num_classes_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); // Compute loss for each x,y location const int* Tdata = T.data(); @@ -184,6 +185,7 @@ bool SoftmaxFocalLossOp::RunOnDevice() { 0, context_.cuda_stream()>>>( N, A, H, W, P->data(), Tdata, losses_.mutable_data(), Wdata, gamma_, alpha_, num_classes_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); // sum the losses float* avg_loss_data = avg_loss->mutable_data(); @@ -227,6 +229,8 @@ bool SoftmaxFocalLossGradientOp::RunOnDevice() { 0, context_.cuda_stream()>>>( N, A, H, W, Pdata, Tdata, buff_.mutable_data(), Wdata, gamma_, alpha_, num_classes_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + // Compute the gradient with the weights const float* Bdata = buff_.data(); SoftmaxFocalLossGradientKernel @@ -234,6 +238,7 @@ bool SoftmaxFocalLossGradientOp::RunOnDevice() { 0, context_.cuda_stream()>>>( N, D, H, W, Pdata, Tdata, Bdata, d_avg_loss.data(), dX->mutable_data(), num_classes_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); math::Scale( dX->size(), scale_, diff --git a/modules/detectron/spatial_narrow_as_op.cu b/modules/detectron/spatial_narrow_as_op.cu index 97ddc492eb074..ff8b5632e80a8 100644 --- a/modules/detectron/spatial_narrow_as_op.cu +++ b/modules/detectron/spatial_narrow_as_op.cu @@ -115,6 +115,7 @@ bool SpatialNarrowAsOp::DoRunWithType() { out_width, A.template data(), C->template mutable_data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } @@ -152,6 +153,7 @@ bool SpatialNarrowAsGradientOp::DoRunWithType() { out_width, dC.template data(), dA->template mutable_data()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } diff --git a/modules/detectron/upsample_nearest_op.cu b/modules/detectron/upsample_nearest_op.cu index 38af4254f9221..0ea32e348c0b3 100644 --- a/modules/detectron/upsample_nearest_op.cu +++ b/modules/detectron/upsample_nearest_op.cu @@ -164,6 +164,8 @@ bool UpsampleNearestOp::RunOnDevice() { upscale<<>>( input_data, output_data, no_elements, scale_, d1, d2, d3); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + return true; } @@ -209,6 +211,7 @@ bool UpsampleNearestGradientOp::RunOnDevice() { math::Set(no_elements, 0.f, gradInput_data, &context_); downscale<<>>( gradInput_data, gradOutput_data, no_elements, scale_, d1, d2, d3); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return true; } diff --git a/modules/observers/perf_observer.cc b/modules/observers/perf_observer.cc index d64bdf1446023..c15eda5dda000 100644 --- a/modules/observers/perf_observer.cc +++ b/modules/observers/perf_observer.cc @@ -18,6 +18,9 @@ defined(TARGET_IPHONE_SIMULATOR) #endif #ifdef _WIN32 +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif #include #endif diff --git a/mypy-strict.ini b/mypy-strict.ini index 95a8d599606e4..00545679e8f19 100644 --- a/mypy-strict.ini +++ b/mypy-strict.ini @@ -2,11 +2,12 @@ # rules. The intention is for this config file to be used to ENFORCE that # people are using mypy on codegen files. # -# For now, only code_template.py is covered this way +# For now, only code_template.py and benchmark utils Timer are covered this way [mypy] python_version = 3.6 +cache_dir = .mypy_cache/strict strict_optional = True show_column_numbers = True warn_no_return = True @@ -29,4 +30,26 @@ warn_return_any = True implicit_reexport = False strict_equality = True -files = tools/codegen/gen.py +files = tools/codegen/gen.py, + tools/autograd/*.py, + tools/pyi/*.py, + torch/utils/benchmark/utils/common.py, + torch/utils/benchmark/utils/timer.py, + torch/utils/benchmark/utils/valgrind_wrapper/*.py, + torch/utils/_pytree.py + +# Specifically enable imports of benchmark utils. As more of `torch` becomes +# strict compliant, those modules can be enabled as well. +[mypy-torch.utils.benchmark.utils.*] +follow_imports = normal + +# Don't follow imports as much of `torch` is not strict compliant. +[mypy-torch] +follow_imports = skip + +[mypy-torch.*] +follow_imports = skip + +# Missing stub. +[mypy-numpy] +ignore_missing_imports = True diff --git a/mypy.ini b/mypy.ini index a7c82cb693590..8b45ee3154ed9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,10 +1,12 @@ # This is the PyTorch MyPy config file (note: don't change this line! - # test_run_mypy in test/test_type_hints.py uses this string) [mypy] +cache_dir = .mypy_cache/normal warn_unused_configs = True warn_redundant_casts = True show_error_codes = True check_untyped_defs = True +follow_imports = silent # # Note: test/ still has syntax errors so can't be added @@ -17,11 +19,18 @@ check_untyped_defs = True files = torch, caffe2, + test/type_hint_tests, + test/test_bundled_images.py, + test/test_bundled_inputs.py, test/test_complex.py, + test/test_dataset.py, + test/test_expecttest.py, test/test_futures.py, + test/test_numpy_interop.py, test/test_torch.py, test/test_type_hints.py, - test/test_type_info.py + test/test_type_info.py, + test/test_utils.py # Minimum version supported - variable annotations were introduced @@ -50,160 +59,33 @@ check_untyped_defs = False [mypy-torch._torch_docs] ignore_errors = True -[mypy-torch.distributed.*] -ignore_errors = True - -[mypy-torch.testing._internal.codegen.*] +[mypy-torch.backends._nnapi.*] ignore_errors = True [mypy-torch.testing._internal.hypothesis_utils.*] ignore_errors = True -[mypy-torch.testing._internal.common_nn.*] -ignore_errors = True - [mypy-torch.testing._internal.common_quantization.*] ignore_errors = True -[mypy-torch.testing._internal.common_utils.*] -ignore_errors = True - [mypy-torch.testing._internal.generated.*] ignore_errors = True [mypy-torch.testing._internal.distributed.*] ignore_errors = True -[mypy-torch.quantization.observer] -ignore_errors = True - -[mypy-torch.quantization.stubs] -ignore_errors = True - -[mypy-torch.quantization.fake_quantize] -ignore_errors = True - -[mypy-torch.quantization.quantize_jit] -ignore_errors = True - -[mypy-torch.quantization._numeric_suite] -ignore_errors = True - - -[mypy-torch.quantization.quantize_fx] -ignore_errors = True - -[mypy-torch.quantization.fx.*] -ignore_errors = True - -[mypy-torch.quasirandom] -ignore_errors = True - -[mypy-torch.distributions.*] -ignore_errors = True - -[mypy-torch._tensor_str] -ignore_errors = True - -[mypy-torch.nn.modules.batchnorm] -ignore_errors = True - [mypy-torch.nn.modules.container] ignore_errors = True -[mypy-torch.nn.modules.conv] -ignore_errors = True - -[mypy-torch.nn.modules.fold] -ignore_errors = True - -[mypy-torch.nn.modules.instancenorm] -ignore_errors = True - -[mypy-torch.nn.modules.linear] -ignore_errors = True - -[mypy-torch.nn.modules.loss] -ignore_errors = True - -[mypy-torch.nn.modules.module] -ignore_errors = True - -[mypy-torch.nn.modules.normalization] -ignore_errors = True - -[mypy-torch.nn.modules.padding] -ignore_errors = True - [mypy-torch.nn.modules.pooling] ignore_errors = True -[mypy-torch.nn.modules.rnn] -ignore_errors = True - -[mypy-torch.nn.modules.sparse] -ignore_errors = True - [mypy-torch.nn.parallel._functions] ignore_errors = True -[mypy-torch.nn.parallel.comm] -ignore_errors = True - -[mypy-torch.nn.quantized.functional] -ignore_errors = True - -[mypy-torch.nn.quantized.modules] -ignore_errors = True - -[mypy-torch.nn.quantized.modules.activation] -ignore_errors = True - -[mypy-torch.nn.quantized.modules.normalization] -ignore_errors = True - -[mypy-torch.nn.quantized.modules.utils] -ignore_errors = True - -[mypy-torch.nn.qat.modules.activations] -ignore_errors = True - -[mypy-torch.nn.qat.modules.conv] -ignore_errors = True - -[mypy-torch.nn.quantized.dynamic.modules.linear] -ignore_errors = True - -[mypy-torch.nn.quantized.modules.conv] -ignore_errors = True - -[mypy-torch.nn.quantized.modules.functional_modules] -ignore_errors = True - -[mypy-torch.cuda] -ignore_errors = True - -[mypy-torch.cuda.amp.*] -ignore_errors = True - -[mypy-torch.cuda.comm] -ignore_errors = True - -[mypy-torch.cuda.nccl] -ignore_errors = True - -[mypy-torch._lobpcg] -ignore_errors = True - [mypy-torch._appdirs] ignore_errors = True -[mypy-torch.storage] -ignore_errors = True - -[mypy-torch._utils] -ignore_errors = True - [mypy-torch._overrides] ignore_errors = True @@ -213,102 +95,36 @@ ignore_errors = True [mypy-torch.contrib._tensorboard_vis] ignore_errors = True -[mypy-torch.utils.data._utils.worker] -ignore_errors = True - -[mypy-torch.utils.data.distributed] -ignore_errors = True - [mypy-torch.nn.utils.prune] ignore_errors = True -[mypy-torch.nn.cpp] -ignore_errors = True - [mypy-torch.utils.show_pickle] ignore_errors = True [mypy-torch.utils.hipify.hipify_python] ignore_errors = True -[mypy-torch.autograd._functions.tensor] -ignore_errors = True - -[mypy-torch.autograd.function] -ignore_errors = True - -[mypy-torch.autograd.functional] -ignore_errors = True - -[mypy-torch.autograd.profiler] -ignore_errors = True - -[mypy-torch.autograd.gradcheck] -ignore_errors = True - -[mypy-torch.autograd.anomaly_mode] -ignore_errors = True - -[mypy-torch.autograd.variable] +[mypy-torch.utils.benchmark.examples.*] ignore_errors = True [mypy-torch.nn.quantized.modules.batchnorm] ignore_errors = True -[mypy-torch.nn.intrinsic.quantized.modules.conv_relu] -ignore_errors = True - -[mypy-torch.nn.intrinsic.quantized.modules.bn_relu] -ignore_errors = True - -[mypy-torch.nn.intrinsic.quantized.modules.linear_relu] -ignore_errors = True - [mypy-torch.nn.intrinsic.qat.modules.conv_fused] ignore_errors = True -[mypy-torch.onnx.operators] -ignore_errors = True - -[mypy-torch.onnx.symbolic_opset8] -ignore_errors = True - -[mypy-torch.onnx.symbolic_opset9] -ignore_errors = True - -[mypy-torch.onnx.symbolic_opset11] -ignore_errors = True - -[mypy-torch.onnx.symbolic_caffe2] -ignore_errors = True - -[mypy-torch.onnx.symbolic_helper] -ignore_errors = True - -[mypy-torch.onnx.symbolic_registry] -ignore_errors = True - -[mypy-torch.onnx.utils] -ignore_errors = True - -[mypy-torch.multiprocessing] -ignore_errors = True - -[mypy-torch.multiprocessing.reductions] -ignore_errors = True - -[mypy-torch.multiprocessing.queue] -ignore_errors = True - [mypy-torch.multiprocessing.pool] ignore_errors = True -[mypy-torch.multiprocessing.spawn] -ignore_errors = True - [mypy-torch.overrides] ignore_errors = True +# +# Adding type annotations to caffe2 is probably not worth the effort +# only work on this if you have a specific reason for it, otherwise +# leave these ignores as they are. +# + [mypy-caffe2.python.*] ignore_errors = True diff --git a/requirements-flake8.txt b/requirements-flake8.txt new file mode 100644 index 0000000000000..1e2ba252556f3 --- /dev/null +++ b/requirements-flake8.txt @@ -0,0 +1,8 @@ +flake8==3.8.2 +flake8-bugbear==20.1.4 +flake8-comprehensions==3.3.0 +flake8-executable==2.0.4 +flake8-pyi==20.5.0 +mccabe +pycodestyle==2.6.0 +pyflakes==2.2.0 diff --git a/requirements.txt b/requirements.txt index 07127f738ff92..759baf3984c3f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,4 @@ requests setuptools six typing_extensions -dataclasses +dataclasses; python_version<"3.7" diff --git a/scripts/build_ios.sh b/scripts/build_ios.sh index 75ee1ed2001a1..b4d77878282e3 100755 --- a/scripts/build_ios.sh +++ b/scripts/build_ios.sh @@ -92,6 +92,12 @@ CMAKE_ARGS+=("-DUSE_LEVELDB=OFF") CMAKE_ARGS+=("-DUSE_MPI=OFF") CMAKE_ARGS+=("-DUSE_NUMPY=OFF") CMAKE_ARGS+=("-DUSE_NNPACK=OFF") +CMAKE_ARGS+=("-DUSE_MKLDNN=OFF") + +# Metal +if [ "${USE_PYTORCH_METAL:-}" == "1" ]; then + CMAKE_ARGS+=("-DUSE_PYTORCH_METAL=ON") +fi # pthreads CMAKE_ARGS+=("-DCMAKE_THREAD_LIBS_INIT=-lpthread") diff --git a/scripts/model_zoo/update-models-from-caffe2.py b/scripts/model_zoo/update-models-from-caffe2.py index fb582a047bc62..d3e46e449d8ab 100644 --- a/scripts/model_zoo/update-models-from-caffe2.py +++ b/scripts/model_zoo/update-models-from-caffe2.py @@ -6,15 +6,12 @@ import caffe2.python.workspace as c2_workspace import glob import json -import math import numpy as np import onnx import caffe2.python.onnx.frontend import caffe2.python.onnx.backend import os import shutil -import subprocess -import sys import tarfile import tempfile @@ -25,7 +22,6 @@ from caffe2.python.models.download import downloadFromURLToFile, getURLFromName, deleteDirectory from caffe2.proto import caffe2_pb2 from onnx import numpy_helper -from filechunkio import FileChunkIO """A script converting Caffe2 models to ONNX, and updating ONNX model zoos. diff --git a/scripts/onnx/test.sh b/scripts/onnx/test.sh index 77f9c8b9f16ea..6918619588598 100755 --- a/scripts/onnx/test.sh +++ b/scripts/onnx/test.sh @@ -23,12 +23,13 @@ do done set -- "${UNKNOWN[@]}" # leave UNKNOWN -pip install pytest scipy hypothesis - if [[ $PARALLEL == 1 ]]; then pip install pytest-xdist fi +pip install pytest scipy hypothesis # these may not be necessary +pip install pytest-cov # installing since `coverage run -m pytest ..` doesn't work + # realpath might not be available on MacOS script_path=$(python -c "import os; import sys; print(os.path.realpath(sys.argv[1]))" "${BASH_SOURCE[0]}") top_dir=$(dirname $(dirname $(dirname "$script_path"))) @@ -38,6 +39,10 @@ test_paths=( args=() args+=("-v") +args+=("--cov") +args+=("--cov-report") +args+=("xml:test/coverage.xml") +args+=("--cov-append") if [[ $PARALLEL == 1 ]]; then args+=("-n") args+=("3") @@ -51,6 +56,7 @@ pytest "${args[@]}" \ --ignore "$top_dir/test/onnx/test_custom_ops.py" \ --ignore "$top_dir/test/onnx/test_models_onnxruntime.py" \ --ignore "$top_dir/test/onnx/test_utility_funs.py" \ + --ignore "$top_dir/test/onnx/test_pytorch_onnx_shape_inference.py" \ "${test_paths[@]}" # onnxruntime only support py3 @@ -66,8 +72,13 @@ if [[ "$BUILD_ENVIRONMENT" == *ort_test1* ]]; then fi if [[ "$BUILD_ENVIRONMENT" == *ort_test2* ]]; then # Update the loop for new opsets - for i in $(seq 10 12); do + for i in $(seq 10 13); do pytest "${args[@]}" \ "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset$i" done + pytest "${args[@]}" \ + "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference" fi + +# Our CI expects both coverage.xml and .coverage to be within test/ +mv .coverage test/.coverage diff --git a/scripts/release_notes/categorize.py b/scripts/release_notes/categorize.py new file mode 100644 index 0000000000000..985d11f2e2bdd --- /dev/null +++ b/scripts/release_notes/categorize.py @@ -0,0 +1,133 @@ +import argparse +import os +import textwrap +from common import categories, topics, CommitDataCache +from commitlist import CommitList + +class Categorizer: + def __init__(self, path, category='Uncategorized'): + self.cache = CommitDataCache() + self.commits = CommitList.from_existing(path) + + # Special categories: 'Uncategorized' + # All other categories must be real + self.category = category + + def categorize(self): + commits = self.commits.filter(self.category) + i = 0 + while i < len(commits): + cur_commit = commits[i] + next_commit = commits[i + 1] if i + 1 < len(commits) else None + jump_to = self.handle_commit(cur_commit, i + 1, len(commits), commits) + + # Increment counter + if jump_to is not None: + i = jump_to + elif next_commit is None: + i = len(commits) + else: + i = commits.index(next_commit) + + def features(self, commit): + return self.cache.get(commit.commit_hash) + + def potential_reverts_of(self, commit, commits): + if 'Updating submodules' in commit.title: + return [] + index = commits.index(commit) + # -8 to remove the (#35011) + cleaned_title = commit.title[:-10] + # NB: the index + 2 is sketch + return {(index + 2 + delta): cand for delta, cand in enumerate(commits[index + 1:]) + if cleaned_title in cand.title and + commit.commit_hash != cand.commit_hash} + + def handle_commit(self, commit, i, total, commits): + potential_reverts = self.potential_reverts_of(commit, commits) + if potential_reverts: + potential_reverts = f'!!!POTENTIAL REVERTS!!!: {potential_reverts}' + else: + potential_reverts = "" + + features = self.features(commit) + + breaking_alarm = "" + if 'topic: bc-breaking' in features.labels: + breaking_alarm += "!!!!!! BC BREAKING !!!!!!" + + if 'module: deprecation' in features.labels: + breaking_alarm += "!!!!!! DEPRECATION !!!!!!" + + os.system('clear') + view = textwrap.dedent(f'''\ +[{i}/{total}] +================================================================================ +{features.title} + +{features.body} + +Files changed: {features.files_changed} + +Labels: {features.labels} + +{potential_reverts} {breaking_alarm} + +Current category: {commit.category} + +Select from: {', '.join(categories)} + + ''') + print(view) + cat_choice = None + while cat_choice is None: + value = input('category> ').strip() + if len(value) == 0: + cat_choice = commit.category + continue + choices = [cat for cat in categories + if cat.startswith(value)] + if len(choices) != 1: + print(f'Possible matches: {choices}, try again') + continue + cat_choice = choices[0] + print(f'\nSelected: {cat_choice}') + print(f'\nCurrent topic: {commit.topic}') + print(f'''Select from: {', '.join(topics)}''') + topic_choice = None + while topic_choice is None: + value = input('topic> ').strip() + if len(value) == 0: + topic_choice = commit.topic + continue + choices = [cat for cat in topics + if cat.startswith(value)] + if len(choices) != 1: + print(f'Possible matches: {choices}, try again') + continue + topic_choice = choices[0] + print(f'\nSelected: {topic_choice}') + self.update_commit(commit, cat_choice, topic_choice) + return None + + def update_commit(self, commit, category, topic): + assert category in categories + assert topic in topics + commit.category = category + commit.topic = topic + self.commits.write_to_disk() + +def main(): + parser = argparse.ArgumentParser(description='Tool to help categorize commits') + parser.add_argument('--category', type=str, default='Uncategorized', + help='Which category to filter by. "Uncategorized", None, or a category name') + parser.add_argument('--file', help='The location of the commits CSV', + default='results/commitlist.csv') + + args = parser.parse_args() + categorizer = Categorizer(args.file, args.category) + categorizer.categorize() + + +if __name__ == '__main__': + main() diff --git a/scripts/release_notes/commitlist.py b/scripts/release_notes/commitlist.py new file mode 100644 index 0000000000000..552641f546743 --- /dev/null +++ b/scripts/release_notes/commitlist.py @@ -0,0 +1,181 @@ +import argparse +from common import run, topics +from collections import defaultdict +import os +import csv +import pprint +from common import CommitDataCache +import re + + +""" +Example Usages + +Create a new commitlist for consumption by categorize.py. +Said commitlist contains commits between v1.5.0 and f5bc91f851. + + python commitlist.py --create_new tags/v1.5.0 f5bc91f851 + +Update the existing commitlist to commit bfcb687b9c. + + python commitlist.py --update_to bfcb687b9c + +""" + +class Commit: + def __init__(self, commit_hash, category, topic, title): + self.commit_hash = commit_hash + self.category = category + self.topic = topic + self.title = title + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + return self.commit_hash == other.commit_hash and \ + self.category == other.category and \ + self.topic == other.topic and \ + self.title == other.title + + def __repr__(self): + return f'Commit({self.commit_hash}, {self.category}, {self.topic}, {self.title})' + +class CommitList: + # NB: Private ctor. Use `from_existing` or `create_new`. + def __init__(self, path, commits): + self.path = path + self.commits = commits + + @staticmethod + def from_existing(path): + commits = CommitList.read_from_disk(path) + return CommitList(path, commits) + + @staticmethod + def create_new(path, base_version, new_version): + if os.path.exists(path): + raise ValueError('Attempted to create a new commitlist but one exists already!') + commits = CommitList.get_commits_between(base_version, new_version) + return CommitList(path, commits) + + @staticmethod + def read_from_disk(path): + with open(path) as csvfile: + reader = csv.reader(csvfile) + rows = list(row for row in reader) + assert all(len(row) >= 4 for row in rows) + return [Commit(*row[:4]) for row in rows] + + def write_to_disk(self): + path = self.path + rows = self.commits + with open(path, 'w') as csvfile: + writer = csv.writer(csvfile) + for commit in rows: + writer.writerow([commit.commit_hash, commit.category, commit.topic, commit.title]) + + @staticmethod + def get_commits_between(base_version, new_version): + cmd = f'git merge-base {base_version} {new_version}' + rc, merge_base, _ = run(cmd) + assert rc == 0 + + # Returns a list of something like + # b33e38ec47 Allow a higher-precision step type for Vec256::arange (#34555) + cmd = f'git log --reverse --oneline {merge_base}..{new_version}' + rc, commits, _ = run(cmd) + assert rc == 0 + + log_lines = commits.split('\n') + hashes, titles = zip(*[log_line.split(' ', 1) for log_line in log_lines]) + return [Commit(commit_hash, 'Uncategorized', 'Untopiced', title) for commit_hash, title in zip(hashes, titles)] + + def filter(self, *, category=None, topic=None): + commits = self.commits + if category is not None: + commits = [commit for commit in commits if commit.category == category] + if topic is not None: + commits = [commit for commit in commits if commit.topic == topic] + return commits + + def update_to(self, new_version): + last_hash = self.commits[-1].commit_hash + new_commits = CommitList.get_commits_between(last_hash, new_version) + self.commits += new_commits + + def stat(self): + counts = defaultdict(lambda: defaultdict(int)) + for commit in self.commits: + counts[commit.category][commit.topic] += 1 + return counts + + +def create_new(path, base_version, new_version): + commits = CommitList.create_new(path, base_version, new_version) + commits.write_to_disk() + +def update_existing(path, new_version): + commits = CommitList.from_existing(path) + commits.update_to(new_version) + commits.write_to_disk() + +def to_markdown(commit_list, category): + def cleanup_title(commit): + match = re.match(r'(.*) \(#\d+\)', commit.title) + if match is None: + return commit.title + return match.group(1) + + cdc = CommitDataCache() + lines = [f'\n## {category}\n'] + for topic in topics: + lines.append(f'### {topic}\n') + commits = commit_list.filter(category=category, topic=topic) + for commit in commits: + result = cleanup_title(commit) + maybe_pr_number = cdc.get(commit.commit_hash).pr_number + if maybe_pr_number is None: + result = f'- {result} ({commit.commit_hash})\n' + else: + result = f'- {result} ([#{maybe_pr_number}](https://github.com/pytorch/pytorch/pull/{maybe_pr_number}))\n' + lines.append(result) + return lines + +def main(): + parser = argparse.ArgumentParser(description='Tool to create a commit list') + + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument('--create_new', nargs=2) + group.add_argument('--update_to') + group.add_argument('--stat', action='store_true') + group.add_argument('--export_markdown', action='store_true') + + parser.add_argument('--path', default='results/commitlist.csv') + args = parser.parse_args() + + if args.create_new: + create_new(args.path, args.create_new[0], args.create_new[1]) + return + if args.update_to: + update_existing(args.path, args.update_to) + return + if args.stat: + commits = CommitList.from_existing(args.path) + stats = commits.stat() + pprint.pprint(stats) + return + if args.export_markdown: + commits = CommitList.from_existing(args.path) + categories = list(commits.stat().keys()) + lines = [] + for category in categories: + lines += to_markdown(commits, category) + filename = f'results/result.md' + os.makedirs(os.path.dirname(filename), exist_ok=True) + with open(filename, 'w') as f: + f.writelines(lines) + return + assert False + +if __name__ == '__main__': + main() diff --git a/scripts/release_notes/common.py b/scripts/release_notes/common.py new file mode 100644 index 0000000000000..4312d71b544cc --- /dev/null +++ b/scripts/release_notes/common.py @@ -0,0 +1,196 @@ +from collections import namedtuple +from os.path import expanduser +import locale +import subprocess +import re +import requests +import os +import json + +categories = [ + 'Uncategorized', + 'distributed', + 'mobile', + 'jit', + 'visualization', + 'onnx', + 'caffe2', + 'quantization', + 'amd', + 'benchmark', + 'profiler', + 'dispatcher', + 'releng', + 'fx', + 'code_coverage', + 'vulkan', + 'skip', + 'cpp_frontend', + 'python_frontend', + 'complex_frontend', + 'vmap_frontend', + 'autograd_frontend', + 'build_frontend', + 'memory_format_frontend', + 'foreach_frontend', +] + +topics = [ + 'bc_breaking', + 'deprecations', + 'new_features', + 'improvements', + 'bug_fixes', + 'performance', + 'docs', + 'devs', + 'Untopiced', +] + + +Features = namedtuple('Features', [ + 'title', + 'body', + 'pr_number', + 'files_changed', + 'labels', +]) + + +def dict_to_features(dct): + return Features( + title=dct['title'], + body=dct['body'], + pr_number=dct['pr_number'], + files_changed=dct['files_changed'], + labels=dct['labels']) + + +def features_to_dict(features): + return dict(features._asdict()) + + +def run(command): + """Returns (return-code, stdout, stderr)""" + p = subprocess.Popen(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=True) + output, err = p.communicate() + rc = p.returncode + enc = locale.getpreferredencoding() + output = output.decode(enc) + err = err.decode(enc) + return rc, output.strip(), err.strip() + + +def commit_body(commit_hash): + cmd = f'git log -n 1 --pretty=format:%b {commit_hash}' + ret, out, err = run(cmd) + return out if ret == 0 else None + + +def commit_title(commit_hash): + cmd = f'git log -n 1 --pretty=format:%s {commit_hash}' + ret, out, err = run(cmd) + return out if ret == 0 else None + + +def commit_files_changed(commit_hash): + cmd = f'git diff-tree --no-commit-id --name-only -r {commit_hash}' + ret, out, err = run(cmd) + return out.split('\n') if ret == 0 else None + + +def parse_pr_number(body, commit_hash, title): + regex = r'Pull Request resolved: https://github.com/pytorch/pytorch/pull/([0-9]+)' + matches = re.findall(regex, body) + if len(matches) == 0: + if 'revert' not in title.lower() and 'updating submodules' not in title.lower(): + print(f'[{commit_hash}: {title}] Could not parse PR number, ignoring PR') + return None + if len(matches) > 1: + print(f'[{commit_hash}: {title}] Got two PR numbers, using the first one') + return matches[0] + return matches[0] + + +def get_ghstack_token(): + pattern = 'github_oauth = (.*)' + with open(expanduser('~/.ghstackrc'), 'r+') as f: + config = f.read() + matches = re.findall(pattern, config) + if len(matches) == 0: + raise RuntimeError("Can't find a github oauth token") + return matches[0] + +token = get_ghstack_token() +headers = {"Authorization": f"token {token}"} + +def run_query(query): + request = requests.post('https://api.github.com/graphql', json={'query': query}, headers=headers) + if request.status_code == 200: + return request.json() + else: + raise Exception("Query failed to run by returning code of {}. {}".format(request.status_code, query)) + + +def gh_labels(pr_number): + query = f""" + {{ + repository(owner: "pytorch", name: "pytorch") {{ + pullRequest(number: {pr_number}) {{ + labels(first: 10) {{ + edges {{ + node {{ + name + }} + }} + }} + }} + }} + }} + """ + query = run_query(query) + edges = query['data']['repository']['pullRequest']['labels']['edges'] + return [edge['node']['name'] for edge in edges] + + +def get_features(commit_hash, return_dict=False): + title, body, files_changed = ( + commit_title(commit_hash), + commit_body(commit_hash), + commit_files_changed(commit_hash)) + pr_number = parse_pr_number(body, commit_hash, title) + labels = [] + if pr_number is not None: + labels = gh_labels(pr_number) + result = Features(title, body, pr_number, files_changed, labels) + if return_dict: + return features_to_dict(result) + return result + +class CommitDataCache: + def __init__(self, path='results/data.json'): + self.path = path + self.data = {} + if os.path.exists(path): + self.data = self.read_from_disk() + + def get(self, commit): + if commit not in self.data.keys(): + # Fetch and cache the data + self.data[commit] = get_features(commit) + self.write_to_disk() + return self.data[commit] + + def read_from_disk(self): + with open(self.path, 'r') as f: + data = json.load(f) + data = {commit: dict_to_features(dct) + for commit, dct in data.items()} + return data + + def write_to_disk(self): + data = {commit: features._asdict() for commit, features in self.data.items()} + with open(self.path, 'w') as f: + json.dump(data, f) + diff --git a/scripts/release_notes/requirements.txt b/scripts/release_notes/requirements.txt new file mode 100644 index 0000000000000..945b116ad3c50 --- /dev/null +++ b/scripts/release_notes/requirements.txt @@ -0,0 +1 @@ +PyGithub diff --git a/scripts/release_notes/test_release_notes.py b/scripts/release_notes/test_release_notes.py new file mode 100644 index 0000000000000..898db48c29295 --- /dev/null +++ b/scripts/release_notes/test_release_notes.py @@ -0,0 +1,45 @@ +import unittest +import tempfile +from commitlist import CommitList + +class TestCommitList(unittest.TestCase): + def test_create_new(self): + with tempfile.TemporaryDirectory() as tempdir: + commit_list_path = f'{tempdir}/commitlist.csv' + commit_list = CommitList.create_new(commit_list_path, 'v1.5.0', '7543e7e558') + self.assertEqual(len(commit_list.commits), 2143) + self.assertEqual(commit_list.commits[0].commit_hash, '7335f079ab') + self.assertTrue(commit_list.commits[0].title.startswith('[pt][quant] qmul and qadd')) + self.assertEqual(commit_list.commits[-1].commit_hash, '7543e7e558') + self.assertTrue(commit_list.commits[-1].title.startswith('Migrate minall, max, maxall')) + + def test_read_write(self): + with tempfile.TemporaryDirectory() as tempdir: + commit_list_path = f'{tempdir}/commitlist.csv' + initial = CommitList.create_new(commit_list_path, 'v1.5.0', '7543e7e558') + initial.write_to_disk() + + expected = CommitList.from_existing(commit_list_path) + expected.commits[-2].category = 'foobar' + expected.write_to_disk() + + commit_list = CommitList.from_existing(commit_list_path) + for commit, expected in zip(commit_list.commits, expected.commits): + self.assertEqual(commit, expected) + + def test_update_to(self): + with tempfile.TemporaryDirectory() as tempdir: + commit_list_path = f'{tempdir}/commitlist.csv' + initial = CommitList.create_new(commit_list_path, 'v1.5.0', '7543e7e558') + initial.commits[-2].category = 'foobar' + self.assertEqual(len(initial.commits), 2143) + initial.write_to_disk() + + commit_list = CommitList.from_existing(commit_list_path) + commit_list.update_to('5702a28b26') + self.assertEqual(len(commit_list.commits), 2143 + 4) + self.assertEqual(commit_list.commits[-5], initial.commits[-1]) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/xcode_build.rb b/scripts/xcode_build.rb index 801ad34a64fd4..810c23352fdd8 100644 --- a/scripts/xcode_build.rb +++ b/scripts/xcode_build.rb @@ -62,10 +62,13 @@ project.save sdk = nil +arch = nil if options[:platform] == 'SIMULATOR' sdk = 'iphonesimulator' + arch = 'x86_64' elsif options[:platform] == 'OS' sdk = 'iphoneos' + arch = 'arm64' else raise "unsupported platform #{options[:platform]}" end @@ -76,4 +79,5 @@ end # run xcodebuild -exec "xcodebuild clean build -project #{xcodeproj_path} -target #{target.name} -sdk #{sdk} -configuration Release PROVISIONING_PROFILE_SPECIFIER=#{profile}" +exec "xcodebuild clean build -project #{xcodeproj_path} -target #{target.name} -sdk #{sdk} -configuration Release PROVISIONING_PROFILE_SPECIFIER=#{profile} -arch #{arch}" + diff --git a/setup.py b/setup.py index 059188875e777..50983a89ad557 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,9 @@ # USE_FBGEMM=0 # disables the FBGEMM build # +# USE_KINETO=1 +# enables experimental usage of libkineto +# # USE_NUMPY=0 # disables the NumPy build # @@ -61,6 +64,9 @@ # BUILD_CAFFE2_OPS=0 # disable Caffe2 operators build # +# BUILD_CAFFE2=0 +# disable Caffe2 build +# # USE_IBVERBS # toggle features related to distributed support # @@ -162,7 +168,8 @@ # When turned on, the following cmake variables will be toggled as well: # USE_SYSTEM_CPUINFO=ON USE_SYSTEM_SLEEF=ON BUILD_CUSTOM_PROTOBUF=OFF - +# This future is needed to print Python2 EOL message +from __future__ import print_function import sys if sys.version_info < (3,): print("Python 2 has reached end-of-life and is no longer supported by PyTorch.") @@ -172,17 +179,14 @@ sys.exit(-1) import platform -python_min_version = (3, 6, 1) -python_min_version_str = '.'.join((str(num) for num in python_min_version)) -python_max_version = (3, 9, 0) -python_max_version_str = '.'.join((str(num) for num in python_max_version)) -if sys.version_info < python_min_version or sys.version_info >= python_max_version: - print("You are using Python {}. Python >={},<{} is required.".format(platform.python_version(), - python_min_version_str, - python_max_version_str)) +python_min_version = (3, 6, 2) +python_min_version_str = '.'.join(map(str, python_min_version)) +if sys.version_info < python_min_version: + print("You are using Python {}. Python >={} is required.".format(platform.python_version(), + python_min_version_str)) sys.exit(-1) -from setuptools import setup, Extension, distutils, find_packages +from setuptools import setup, Extension, find_packages from collections import defaultdict from distutils import core from distutils.core import Distribution @@ -193,6 +197,7 @@ import distutils.sysconfig import filecmp import shutil +import subprocess import os import json import glob @@ -306,7 +311,6 @@ def check_file(f): 'benchmark', 'CMakeLists.txt')) check_pydep('yaml', 'pyyaml') - check_pydep('typing', 'typing') build_caffe2(version=version, cmake_python_library=cmake_python_library, @@ -323,8 +327,16 @@ def check_file(f): # Use copies instead of symbolic files. # Windows has very poor support for them. - sym_files = ['tools/shared/_utils_internal.py'] - orig_files = ['torch/_utils_internal.py'] + sym_files = [ + 'tools/shared/_utils_internal.py', + 'torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h', + 'torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h', + ] + orig_files = [ + 'torch/_utils_internal.py', + 'third_party/valgrind-headers/callgrind.h', + 'third_party/valgrind-headers/valgrind.h', + ] for sym_file, orig_file in zip(sym_files, orig_files): same = False if os.path.exists(sym_file): @@ -340,7 +352,10 @@ def check_file(f): ################################################################################ # the list of runtime dependencies required by this built package -install_requires = ['future', 'typing_extensions', 'dataclasses'] +install_requires = [ + 'typing_extensions', + 'dataclasses; python_version < "3.7"' +] missing_pydep = ''' Missing build dependency: Unable to `import {importname}`. @@ -356,6 +371,42 @@ def check_pydep(importname, module): class build_ext(setuptools.command.build_ext.build_ext): + + # Copy libiomp5.dylib inside the wheel package on OS X + def _embed_libiomp(self): + if not IS_DARWIN: + return + lib_dir = os.path.join(self.build_lib, 'torch', 'lib') + libtorch_cpu_path = os.path.join(lib_dir, 'libtorch_cpu.dylib') + if not os.path.exists(libtorch_cpu_path): + return + # Parse libtorch_cpu load commands + otool_cmds = subprocess.check_output(['otool', '-l', libtorch_cpu_path]).decode('utf-8').split('\n') + rpaths, libs = [], [] + for idx, line in enumerate(otool_cmds): + if line.strip() == 'cmd LC_LOAD_DYLIB': + lib_name = otool_cmds[idx + 2].strip() + assert lib_name.startswith('name ') + libs.append(lib_name.split(' ', 1)[1].rsplit('(', 1)[0][:-1]) + + if line.strip() == 'cmd LC_RPATH': + rpath = otool_cmds[idx + 2].strip() + assert rpath.startswith('path ') + rpaths.append(rpath.split(' ', 1)[1].rsplit('(', 1)[0][:-1]) + + omp_lib_name = 'libiomp5.dylib' + if os.path.join('@rpath', omp_lib_name) not in libs: + return + + # Copy libiomp5 from rpath locations + for rpath in rpaths: + source_lib = os.path.join(rpath, omp_lib_name) + if not os.path.exists(source_lib): + continue + target_lib = os.path.join(self.build_lib, 'torch', 'lib', omp_lib_name) + self.copy_file(source_lib, target_lib) + break + def run(self): # Report build options. This is run after the build completes so # `CMakeCache.txt` exists and we can get an # accurate report on what is used and what is not. @@ -396,7 +447,7 @@ def run(self): else: report('-- Building without distributed package') - # Do not use clang to compile exensions if `-fstack-clash-protection` is defined + # Do not use clang to compile extensions if `-fstack-clash-protection` is defined # in system CFLAGS system_c_flags = distutils.sysconfig.get_config_var('CFLAGS') if IS_LINUX and '-fstack-clash-protection' in system_c_flags and 'clang' in os.environ.get('CC', ''): @@ -405,6 +456,8 @@ def run(self): # It's an old-style class in Python 2.7... setuptools.command.build_ext.build_ext.run(self) + self._embed_libiomp() + # Copy the essential export library to compile C++ extensions. if IS_WINDOWS: build_temp = self.build_temp @@ -617,13 +670,13 @@ def configure_extension_build(): extra_link_args += ['-g'] - def make_relative_rpath(path): + def make_relative_rpath_args(path): if IS_DARWIN: - return '-Wl,-rpath,@loader_path/' + path + return ['-Wl,-rpath,@loader_path/' + path] elif IS_WINDOWS: - return '' + return [] else: - return '-Wl,-rpath,$ORIGIN/' + path + return ['-Wl,-rpath,$ORIGIN/' + path] ################################################################################ # Declare extensions and package @@ -638,7 +691,7 @@ def make_relative_rpath(path): extra_compile_args=main_compile_args + extra_compile_args, include_dirs=[], library_dirs=library_dirs, - extra_link_args=extra_link_args + main_link_args + [make_relative_rpath('lib')]) + extra_link_args=extra_link_args + main_link_args + make_relative_rpath_args('lib')) extensions.append(C) if not IS_WINDOWS: @@ -727,6 +780,7 @@ def print_box(msg): with open(os.path.join(cwd, "README.md"), encoding="utf-8") as f: long_description = f.read() + version_range_max = max(sys.version_info[1], 8) + 1 setup( name=package_name, version=version, @@ -838,6 +892,7 @@ def print_box(msg): 'include/torch/csrc/jit/serialization/*.h', 'include/torch/csrc/jit/python/*.h', 'include/torch/csrc/jit/testing/*.h', + 'include/torch/csrc/jit/tensorexpr/*.h', 'include/torch/csrc/onnx/*.h', 'include/torch/csrc/utils/*.h', 'include/pybind11/*.h', @@ -861,6 +916,9 @@ def print_box(msg): 'share/cmake/Gloo/*.cmake', 'share/cmake/Tensorpipe/*.cmake', 'share/cmake/Torch/*.cmake', + 'utils/benchmark/utils/*.cpp', + 'utils/benchmark/utils/valgrind_wrapper/*.cpp', + 'utils/benchmark/utils/valgrind_wrapper/*.h', ], 'caffe2': [ 'python/serialized_test/data/operator_test/*.zip', @@ -870,7 +928,7 @@ def print_box(msg): download_url='https://github.com/pytorch/pytorch/tags', author='PyTorch Team', author_email='packages@pytorch.org', - python_requires='>={},<{}'.format(python_min_version_str, python_max_version_str), + python_requires='>={}'.format(python_min_version_str), # PyPI package information. classifiers=[ 'Development Status :: 5 - Production/Stable', @@ -886,7 +944,7 @@ def print_box(msg): 'Topic :: Software Development :: Libraries :: Python Modules', 'Programming Language :: C++', 'Programming Language :: Python :: 3', - ] + ['Programming Language :: Python :: 3.{}' for i in range(python_min_version[1], python_max_version[1])], + ] + ['Programming Language :: Python :: 3.{}'.format(i) for i in range(python_min_version[1], version_range_max)], license='BSD-3', keywords='pytorch machine learning', ) diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index 92c4f1060a648..d863b06b19b5b 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -27,84 +27,38 @@ # NB: function name DOES NOT include overload name! allow_list = [ ("c10_experimental", datetime.date(2222, 1, 1)), + # Internal + ("static", datetime.date(9999, 1, 1)), # Internal, profiler-specific ops ("profiler::_call_end_callbacks_on_jit_fut*", datetime.date(9999, 1, 1)), ("profiler::_record_function_enter", datetime.date(9999, 1, 1)), - ("tensorexpr::Group", datetime.date(2020, 9, 9)), - ("aten::append*", datetime.date(2020, 4, 15)), - ("aten::_min", datetime.date(2020, 9, 9)), - ("aten::_max", datetime.date(2020, 9, 9)), - ("aten::amax", datetime.date(2020, 10, 9)), - ("aten::amin", datetime.date(2020, 10, 9)), - ("aten::min_values", datetime.date(2020, 10, 9)), - ("aten::max_values", datetime.date(2020, 10, 9)), - ("aten::split_with_sizes", datetime.date(2020, 7, 29)), - ("aten::eq", datetime.date(2020, 7, 30)), - ("aten::log", datetime.date(2020, 7, 30)), - ("aten::__and__", datetime.date(2020, 7, 30)), - ("aten::__or__", datetime.date(2020, 7, 30)), - ("aten::__xor__", datetime.date(2020, 7, 30)), - ("aten::add", datetime.date(2020, 7, 30)), - ("aten::__upsample_bilinear", datetime.date(2020, 7, 30)), - ("aten::hash", datetime.date(2020, 7, 30)), - ("aten::divmod", datetime.date(2020, 7, 30)), - ("aten::sorted", datetime.date(2020, 8, 30)), - ("aten::__contains__", datetime.date(2020, 7, 30)), - ("aten::ne", datetime.date(2020, 7, 30)), - ("aten::index", datetime.date(2020, 7, 30)), - ("aten::isnan", datetime.date(2020, 7, 30)), - ("aten::pow", datetime.date(2020, 7, 30)), - ("aten::atan2", datetime.date(2020, 7, 30)), - ("aten::copy_", datetime.date(2020, 7, 30)), - ("aten::sort", datetime.date(2020, 7, 30)), - ("aten::_convolution", datetime.date(2020, 10, 15)), - ("aten::cudnn_convolution", datetime.date(2020, 10, 15)), - ("aten::cudnn_convolution_transpose", datetime.date(2020, 10, 15)), - ("aten::_convolution_double_backward", datetime.date(2020, 10, 15)), - ("aten::cudnn_convolution_backward_input", datetime.date(2020, 10, 15)), - ("aten::cudnn_convolution_backward", datetime.date(2020, 10, 15)), - ("aten::cudnn_convolution_backward_weight", datetime.date(2020, 10, 15)), - ("aten::cudnn_convolution_transpose_backward", datetime.date(2020, 10, 15)), - ("aten::cudnn_convolution_transpose_backward_input", datetime.date(2020, 10, 15)), - ("aten::cudnn_convolution_transpose_backward_weight", datetime.date(2020, 10, 15)), - ("aten::_cudnn_init_dropout_state", datetime.date(2020, 7, 30)), - ("aten::sparse_coo_tensor", datetime.date(2020, 7, 30)), - ("aten::_sparse_coo_tensor_with_dims", datetime.date(2020, 7, 30)), - ("aten::_sparse_coo_tensor_with_dims_and_tensors", datetime.date(2020, 7, 30)), - ("aten::__lshift__", datetime.date(2020, 7, 30)), - ("aten::__rshift__", datetime.date(2020, 7, 30)), - ("aten::__round_to_zero_floordiv", datetime.date(2020, 7, 30)), - ("aten::gcd", datetime.date(2020, 7, 30)), - ("aten::unflatten", datetime.date(2020, 8, 14)), - ("aten::linalg_outer", datetime.date(2020, 8, 30)), - # WARNING: overload name here doesn't do anything - ("aten::linalg_outer.out", datetime.date(2020, 8, 30)), - ("aten::linalg_norm", datetime.date(2020, 9, 30)), - ("aten::linalg_norm.ord_str", datetime.date(2020, 9, 30)), - ("aten::linalg_norm.out", datetime.date(2020, 9, 30)), - ("aten::linalg_norm.ord_str_out", datetime.date(2020, 9, 30)), - ("aten::_compute_linear_combination", datetime.date(2020, 9, 1)), - ("aten::linspace", datetime.date(2020, 9, 30)), - ("aten::linspace.out", datetime.date(2020, 9, 30)), - ("aten::logspace", datetime.date(2020, 9, 30)), - ("aten::logspace.out", datetime.date(2020, 9, 30)), - ("__getstate__", datetime.date(2020, 9, 11), "Conv[23]dPackedParams"), - ("_caffe2::LearningRate", datetime.date(2020, 10, 1)), - ("aten::_var", datetime.date(2020, 10, 1)), - ("aten::_std", datetime.date(2020, 10, 1)), - ("aten::_foreach_add_", datetime.date(2020, 10, 1)), - ("aten::stft", datetime.date(2020, 10, 1)), - ("aten::istft", datetime.date(2020, 10, 1)), - ("prim::MakeTestTensor", datetime.date(2020, 10, 1)), - ("preprocess", datetime.date(2020, 10, 1)), - ("compile", datetime.date(2020, 10, 1)), - ("execute", datetime.date(2020, 10, 1)), - ("aten::_addr", datetime.date(2020, 10, 31)), - ("aten::_addr_", datetime.date(2020, 10, 31)), - ("aten::_addr.out", datetime.date(2020, 10, 31)), + ("aten::_qr_helper", datetime.date(2021, 1, 31)), + ("aten::fft", datetime.date(2021, 1, 31)), + ("aten::ifft", datetime.date(2021, 1, 31)), + ("aten::irfft", datetime.date(2021, 1, 31)), + ("aten::rfft", datetime.date(2021, 1, 31)), + ("aten::_svd_helper", datetime.date(2021, 1, 31)), + ("aten::_cudnn_rnn_flatten_weight", datetime.date(2020, 12, 31)), + ("aten::_cudnn_rnn", datetime.date(2020, 12, 31)), + ("aten::_cudnn_rnn_backward", datetime.date(2020, 12, 31)), + ("aten::quantile", datetime.date(2021, 1, 31)), + ("aten::nanquantile", datetime.date(2021, 1, 31)), + ("aten::_fft_with_size", datetime.date(2021, 1, 31)), + ("aten::thnn_conv_depthwise2d_backward", datetime.date(2021, 1, 31)), + ("aten::slow_conv3d_backward", datetime.date(2021, 1, 31)), + ("aten::thnn_conv2d_backward", datetime.date(2021, 1, 31)), + ("aten::slow_conv_transpose3d_backward", datetime.date(2021, 1, 31)), + ("aten::slow_conv_transpose2d_backward", datetime.date(2021, 1, 31)), + ("aten::set_", datetime.date(2021, 1, 31)), + ("aten::native_layer_norm", datetime.date(2021, 1, 31)), + ("aten::native_layer_norm_backward", datetime.date(2021, 1, 31)), + ("aten::sort", datetime.date(2021, 1, 31)), + ("aten::sort_out", datetime.date(2021, 1, 31)), + ("aten::elu_backward", datetime.date(2021, 1, 31)), + ("aten::_multinomial_alias_setup", datetime.date(2021, 1, 31)), + ("aten::_multinomial_alias_draw", datetime.date(2021, 1, 31)), ] - def allow_listed(schema, allow_list): for item in allow_list: if item[1] < datetime.date.today(): @@ -124,6 +78,7 @@ def allow_listed(schema, allow_list): dont_parse_list = [ ("_TorchScriptTesting.*", datetime.date(2099, 9, 17)), ("test_backend", datetime.date(2099, 9, 17)), + ("dist_c10d", datetime.date(2021, 1, 30)), ] diff --git a/test/benchmark_utils/callgrind_artifacts.json b/test/benchmark_utils/callgrind_artifacts.json new file mode 100644 index 0000000000000..e59f0524a139f --- /dev/null +++ b/test/benchmark_utils/callgrind_artifacts.json @@ -0,0 +1,1187 @@ +{ + "baseline_inclusive": [ + "6746 /home/rdonnelly/mc/conda-bld/compilers_linux-64_1534865402226/work/.build/src/glibc-2.12.2/csu/../sysdeps/x86_64/elf/start.S:0x00000000001c3ce2 [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "6746 /tmp/build/80754af9/python_1599604603603/work/Modules/main.c:Py_Main [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "6746 /tmp/build/80754af9/python_1599604603603/work/Programs/python.c:main [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "6746 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_EvalCode [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "6746 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_EvalCodeEx [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "6746 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_EvalFrameEx [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "6746 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:_PyEval_EvalFrameDefault [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "6746 /tmp/build/80754af9/python_1599604603603/work/Python/pythonrun.c:PyRun_AnyFileExFlags [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "6746 /tmp/build/80754af9/python_1599604603603/work/Python/pythonrun.c:PyRun_FileExFlags [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "6746 /tmp/build/80754af9/python_1599604603603/work/Python/pythonrun.c:PyRun_SimpleFileExFlags [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "6746 /tmp/build/80754af9/python_1599604603603/work/Python/pythonrun.c:run_mod [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "6746 ???:(below main) [/usr/lib64/libc-2.28.so]", + "6746 ???:0x0000000000001050 [/usr/lib64/ld-2.28.so]", + "2407 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:call_function [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "1206 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:_PyObject_FastCallKeywords [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "1196 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:_PyObject_FastCallDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "1180 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_SetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "1019 /tmp/build/80754af9/python_1599604603603/work/Objects/methodobject.c:_PyCFunction_FastCallKeywords [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "1013 /tmp/build/80754af9/python_1599604603603/work/Objects/methodobject.c:_PyCFunction_FastCallDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "881 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:type_call [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "867 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:lookdict_unicode_nodummy [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "862 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_GetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "789 /tmp/build/80754af9/python_1599604603603/work/Objects/rangeobject.c:range_new [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "686 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_GetAttr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "632 /tmp/build/80754af9/python_1599604603603/work/Objects/moduleobject.c:module_getattro [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "590 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_GenericGetAttr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "584 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:_PyObject_GenericGetAttrWithDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "561 /tmp/build/80754af9/python_1599604603603/work/Modules/timemodule.c:time_sleep [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "261 /home/test_user/miniconda3/envs/throwaway/include/pybind11/pybind11.h:pybind11::cpp_function::dispatcher(_object*, _object*, _object*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "209 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_RestoreThread [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "207 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_RichCompareBool [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "196 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:PyObject_GetIter [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "195 /tmp/build/80754af9/python_1599604603603/work/Objects/rangeobject.c:rangeiter_next [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "192 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:_PyStack_AsTuple [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "180 /tmp/build/80754af9/python_1599604603603/work/Objects/longobject.c:PyLong_FromLong [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "177 /tmp/build/80754af9/python_1599604603603/work/Objects/rangeobject.c:range_iter [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "167 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_RichCompare [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "167 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_RestoreThread", + "157 /tmp/build/80754af9/python_1599604603603/work/Objects/tupleobject.c:PyTuple_New [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "129 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:PyNumber_Subtract [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "113 /tmp/build/80754af9/python_1599604603603/work/Objects/obmalloc.c:PyObject_Malloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "112 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_SaveThread [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "111 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:_PyType_Lookup [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "100 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:_Py_CheckFunctionResult [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "98 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_SaveThread", + "98 /tmp/build/80754af9/python_1599604603603/work/Python/pytime.c:_PyTime_FromSecondsObject [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "94 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:PyNumber_FloorDivide [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "93 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:_PyObject_New [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "90 ???:pthread_mutex_lock [/usr/lib64/libpthread-2.28.so]", + "87 /tmp/build/80754af9/python_1599604603603/work/Objects/longobject.c:PyLong_AsLong [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "81 /tmp/build/80754af9/python_1599604603603/work/Modules/gcmodule.c:_PyObject_GC_NewVar [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "80 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:binary_op1 [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "75 ???:pthread_mutex_unlock [/usr/lib64/libpthread-2.28.so]", + "72 /tmp/build/80754af9/python_1599604603603/work/Objects/obmalloc.c:PyObject_Free [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "67 /tmp/build/80754af9/python_1599604603603/work/Objects/longobject.c:long_sub [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]" + ], + "baseline_exclusive": [ + "1394 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:_PyEval_EvalFrameDefault [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "867 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:lookdict_unicode_nodummy [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "710 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_SetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "338 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_GetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "182 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:call_function [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "180 /tmp/build/80754af9/python_1599604603603/work/Objects/longobject.c:PyLong_FromLong [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "177 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:_PyObject_GenericGetAttrWithDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "134 /tmp/build/80754af9/python_1599604603603/work/Objects/rangeobject.c:range_new [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "113 /tmp/build/80754af9/python_1599604603603/work/Objects/obmalloc.c:PyObject_Malloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "111 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:_PyType_Lookup [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "110 /tmp/build/80754af9/python_1599604603603/work/Objects/methodobject.c:_PyCFunction_FastCallDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "104 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_RichCompare [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "95 /tmp/build/80754af9/python_1599604603603/work/Objects/rangeobject.c:rangeiter_next [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "90 ???:pthread_mutex_lock [/usr/lib64/libpthread-2.28.so]", + "85 /home/test_user/miniconda3/envs/throwaway/include/pybind11/pybind11.h:pybind11::cpp_function::dispatcher(_object*, _object*, _object*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "74 /tmp/build/80754af9/python_1599604603603/work/Python/pytime.c:_PyTime_FromSecondsObject [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "70 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:_PyObject_FastCallDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "66 /tmp/build/80754af9/python_1599604603603/work/Objects/obmalloc.c:_PyObject_Free [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "66 ???:__pthread_mutex_unlock_usercnt [/usr/lib64/libpthread-2.28.so]", + "64 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:_Py_CheckFunctionResult [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "63 /tmp/build/80754af9/python_1599604603603/work/Objects/longobject.c:long_richcompare [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "62 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:PyNumber_Subtract [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "61 /tmp/build/80754af9/python_1599604603603/work/Objects/tupleobject.c:PyTuple_New [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "54 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_GetAttr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "54 /tmp/build/80754af9/python_1599604603603/work/Python/errors.c:PyErr_Restore [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "51 /tmp/build/80754af9/python_1599604603603/work/Modules/timemodule.c:time_sleep [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "49 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:lookdict_unicode [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "47 /tmp/build/80754af9/python_1599604603603/work/Objects/longobject.c:long_sub [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "47 /tmp/build/80754af9/python_1599604603603/work/Objects/rangeobject.c:range_iter [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "47 /tmp/build/80754af9/python_1599604603603/work/Python/getargs.c:PyArg_UnpackTuple [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "45 /tmp/build/80754af9/python_1599604603603/work/Objects/longobject.c:PyLong_AsLongAndOverflow [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "44 /home/test_user/miniconda3/envs/throwaway/include/pybind11/detail/internals.h:pybind11::detail::get_internals() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "44 /tmp/build/80754af9/python_1599604603603/work/Objects/tupleobject.c:_PyObject_FastCallDict", + "44 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:type_call [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "44 ???:pthread_cond_signal@@GLIBC_2.3.2 [/usr/lib64/libpthread-2.28.so]", + "42 /tmp/build/80754af9/python_1599604603603/work/Objects/longobject.c:PyLong_AsLong [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "42 /tmp/build/80754af9/python_1599604603603/work/Objects/moduleobject.c:module_getattro [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "40 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_RichCompareBool [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "39 /tmp/build/80754af9/python_1599604603603/work/Objects/longobject.c:long_div [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "35 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:_PyStack_AsTuple [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "35 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_RestoreThread [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "35 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_RestoreThread", + "35 /tmp/build/80754af9/python_1599604603603/work/Python/errors.c:PyErr_Occurred [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "34 /tmp/build/80754af9/python_1599604603603/work/Python/pytime.c:_PyTime_AsTimeval [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "31 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:PyNumber_Add [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "31 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:binary_op1 [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "26 /tmp/build/80754af9/python_1599604603603/work/Python/pystate.c:_PyThreadState_UncheckedGet [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "25 /tmp/build/80754af9/python_1599604603603/work/Objects/longobject.c:long_add [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "22 /tmp/build/80754af9/python_1599604603603/work/Modules/gcmodule.c:_PyObject_GC_NewVar [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "22 /tmp/build/80754af9/python_1599604603603/work/Python/pytime.c:_PyTime_GetMonotonicClock [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "21 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_SaveThread", + "21 /usr/include/c++/8/bits/stl_vector.h:pybind11::cpp_function::dispatcher(_object*, _object*, _object*)", + "20 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:_PyObject_New [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "20 ???:clock_gettime [/usr/lib64/libc-2.28.so]", + "19 /tmp/build/80754af9/python_1599604603603/work/Modules/gcmodule.c:_PyObject_GC_Malloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "19 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:PyObject_GetIter [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "18 /tmp/build/80754af9/python_1599604603603/work/Objects/capsule.c:PyCapsule_GetPointer [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "17 /home/test_user/miniconda3/envs/throwaway/include/pybind11/cast.h:pybind11::cpp_function::dispatcher(_object*, _object*, _object*)", + "17 /tmp/build/80754af9/python_1599604603603/work/Objects/floatobject.c:PyFloat_AsDouble [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "17 /tmp/build/80754af9/python_1599604603603/work/Objects/rangeobject.c:range_dealloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "15 /tmp/build/80754af9/python_1599604603603/work/Modules/gcmodule.c:PyObject_GC_UnTrack [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "15 ???:__memset_avx2_unaligned_erms [/usr/lib64/libc-2.28.so]", + "14 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:PyNumber_FloorDivide [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "13 /tmp/build/80754af9/python_1599604603603/work/Objects/frameobject.c:PyFrame_BlockSetup [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "13 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:object_init [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "13 /usr/include/c++/8/bits/stl_bvector.h:pybind11::cpp_function::dispatcher(_object*, _object*, _object*)", + "12 /home/test_user/miniconda3/envs/throwaway/include/pybind11/pytypes.h:pybind11::cpp_function::dispatcher(_object*, _object*, _object*)", + "11 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:PyNumber_Index [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "11 /tmp/build/80754af9/python_1599604603603/work/Objects/frameobject.c:PyFrame_BlockPop [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "11 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_SaveThread [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "11 ???:select [/usr/lib64/libc-2.28.so]", + "11 build/../torch/csrc/Module.cpp:void pybind11::cpp_function::initialize, c10::TensorOptions const&)", + "5130822 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "5114822 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "4964822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "4943822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "4682822 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", + "4660822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "4597822 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "4586822 /data/users/test_user/repos/pytorch/build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "4372822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "4352822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "4091822 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", + "4069822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", + "4006822 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", + "3995822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "3905822 build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "3831822 /data/users/test_user/repos/pytorch/build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&)", + "3742822 build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "3718822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/native/TensorFactories.cpp:at::native::ones(c10::ArrayRef, c10::TensorOptions const&)", + "3715822 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "3702822 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::full(c10::ArrayRef, c10::Scalar, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "2526822 /data/users/test_user/repos/pytorch/build/aten/src/ATen/Functions.cpp:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "2438822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "2422822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "2209822 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const", + "2198822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "2183822 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "2178822 /data/users/test_user/repos/pytorch/build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "1934822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "1917822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "1704822 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2", + "1693822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2", + "1678822 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2", + "1673822 /data/users/test_user/repos/pytorch/build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "1669822 build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "1658822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "1433822 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "1112000 /data/users/test_user/repos/pytorch/build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::fill_(c10::Scalar) const", + "1098500 build/../torch/csrc/autograd/generated/python_torch_functions.cpp:torch::autograd::THPVariable_ones(_object*, _object*, _object*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "1062157 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_SetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "1039000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor::fill_(c10::Scalar) const", + "1016000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "939977 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:subtype_dealloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "813000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const", + "786000 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_ >, at::Tensor& (at::Tensor&, c10::Scalar)>::call(c10::OperatorKernel*, at::Tensor&, c10::Scalar)", + "785000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_ >, at::Tensor& (at::Tensor&, c10::Scalar)>::call(c10::OperatorKernel*, at::Tensor&, c10::Scalar)", + "783000 /data/users/test_user/repos/pytorch/build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::fill__Scalar(at::Tensor&, c10::Scalar)", + "767977 /data/users/test_user/repos/pytorch/build/../torch/csrc/autograd/python_variable.cpp:THPVariable_dealloc(THPVariable*)", + "764977 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_dealloc(THPVariable*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "758000 build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::fill__Scalar(at::Tensor&, c10::Scalar) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "686977 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_clear(THPVariable*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "676822 /data/users/test_user/repos/pytorch/build/../c10/core/CPUAllocator.cpp:c10::DefaultCPUAllocator::allocate(unsigned long) const", + "643822 build/../c10/core/CPUAllocator.cpp:c10::DefaultCPUAllocator::allocate(unsigned long) const [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "643000 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/native/Fill.cpp:at::native::fill_(at::Tensor&, c10::Scalar)", + "643000 build/../aten/src/ATen/native/Fill.cpp:at::native::fill_(at::Tensor&, c10::Scalar) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "642000 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/record_function.cpp:at::RecordFunction::RecordFunction(at::RecordScope)", + "642000 build/../aten/src/ATen/native/Fill.cpp:at::native::fill_out(at::Tensor&, c10::Scalar) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "596977 build/../c10/util/intrusive_ptr.h:THPVariable_clear(THPVariable*)", + "508822 build/../c10/core/CPUAllocator.cpp:c10::alloc_cpu(unsigned long) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "488000 build/../torch/csrc/utils/python_arg_parser.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "486000 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::RecordFunction(at::RecordScope) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "461822 ???:posix_memalign [/usr/lib64/libc-2.28.so]", + "434822 ???:_mid_memalign [/usr/lib64/libc-2.28.so]", + "429000 /data/users/test_user/repos/pytorch/build/../torch/csrc/utils/python_arg_parser.cpp:torch::PythonArgParser::raw_parse(_object*, _object*, _object*, _object**)", + "421000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::PythonArgParser::raw_parse(_object*, _object*, _object*, _object**) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "408000 ???:__tls_get_addr [/usr/lib64/ld-2.28.so]", + "389822 ???:_int_memalign [/usr/lib64/libc-2.28.so]", + "388193 ???:_int_free [/usr/lib64/libc-2.28.so]", + "386000 /data/users/test_user/repos/pytorch/build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::is_complex() const", + "366000 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/record_function.cpp:at::RecordFunction::~RecordFunction()", + "361000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::FunctionSignature::parse(_object*, _object*, _object*, _object**, bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "352977 ???:free [/usr/lib64/libc-2.28.so]", + "350977 /data/users/test_user/repos/pytorch/build/../c10/core/TensorImpl.cpp:c10::TensorImpl::release_resources()", + "315000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor::is_complex() const", + "302000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:bool c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor const&) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "281000 /data/users/test_user/repos/pytorch/build/aten/src/ATen/core/TensorBody.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", + "273000 /data/users/test_user/repos/pytorch/build/../c10/core/impl/LocalDispatchKeySet.cpp:c10::impl::tls_local_dispatch_key_set()", + "273000 build/../c10/core/impl/LocalDispatchKeySet.cpp:c10::impl::tls_local_dispatch_key_set() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "261000 /home/nwani/m3/conda-bld/compilers_linux-64_1560109574129/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/del_ops.cc:operator delete(void*, unsigned long) [/home/test_user/miniconda3/envs/throwaway/lib/libstdc++.so.6.0.26]", + "255000 /home/nwani/m3/conda-bld/compilers_linux-64_1560109574129/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/del_op.cc:operator delete(void*) [/home/test_user/miniconda3/envs/throwaway/lib/libstdc++.so.6.0.26]", + "236977 /usr/include/c++/8/bits/unique_ptr.h:c10::TensorImpl::release_resources()", + "231000 /data/users/test_user/repos/pytorch/build/../c10/core/TensorImpl.h:c10::TensorImpl::~TensorImpl()", + "228977 /data/users/test_user/repos/pytorch/build/../c10/core/CPUAllocator.cpp:c10::DefaultCPUAllocator::ReportAndDelete(void*)", + "228977 build/../c10/core/CPUAllocator.cpp:c10::DefaultCPUAllocator::ReportAndDelete(void*) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "223663 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_GetAttr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "220000 /home/test_user/miniconda3/envs/throwaway/include/pybind11/pybind11.h:pybind11::gil_scoped_release::~gil_scoped_release() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "209209 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_RestoreThread [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "205609 /tmp/build/80754af9/python_1599604603603/work/Objects/moduleobject.c:module_getattro [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "197500 /data/users/test_user/repos/pytorch/build/../torch/csrc/autograd/utils/wrap_outputs.h:torch::autograd::utils::wrap(at::Tensor)", + "196000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "192000 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::~RecordFunction() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "192000 build/../c10/core/Device.h:c10::Device::validate() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "191567 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_GenericGetAttr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "190500 build/../torch/csrc/autograd/utils/wrap_outputs.h:torch::autograd::utils::wrap(at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "189561 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:_PyObject_GenericGetAttrWithDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "187000 build/../c10/core/TensorImpl.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", + "182816 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_GetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "181000 /data/users/test_user/repos/pytorch/build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&)", + "179500 /data/users/test_user/repos/pytorch/build/../torch/csrc/autograd/python_variable.cpp:THPVariable_Wrap(at::Tensor)", + "178000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "173500 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_Wrap(at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "171000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "170175 ???:_int_malloc [/usr/lib64/libc-2.28.so]", + "169000 /data/users/test_user/repos/pytorch/build/../c10/core/TensorImpl.h:c10::TensorImpl::empty_tensor_restride(c10::MemoryFormat)", + "168000 /home/nwani/m3/conda-bld/compilers_linux-64_1560109574129/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/new_op.cc:operator new(unsigned long) [/home/test_user/miniconda3/envs/throwaway/lib/libstdc++.so.6.0.26]", + "167167 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_RestoreThread", + "166000 build/../c10/core/ScalarType.h:c10::typeMetaToScalarType(caffe2::TypeMeta) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "158000 /data/users/test_user/repos/pytorch/build/../c10/core/Allocator.cpp:c10::memoryProfilingEnabled()", + "155000 /data/users/test_user/repos/pytorch/build/../c10/core/TensorOptions.h:c10::TensorOptions::computeDispatchKey() const", + "154000 /data/users/test_user/repos/pytorch/build/../c10/core/TensorOptions.h:c10::TensorOptions::merge_in(c10::TensorOptions) const", + "154000 build/../c10/core/TensorOptions.h:c10::TensorOptions::merge_in(c10::TensorOptions) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "153000 build/../aten/src/ATen/native/Fill.cpp:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "152000 build/../c10/core/Allocator.cpp:c10::memoryProfilingEnabled() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "151000 /home/test_user/miniconda3/envs/throwaway/include/pybind11/pybind11.h:pybind11::gil_scoped_release::gil_scoped_release(bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "148500 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_NewWithVar(_typeobject*, at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "147000 build/../c10/util/typeid.h:c10::typeMetaToScalarType(caffe2::TypeMeta)", + "138000 /data/users/test_user/repos/pytorch/build/../c10/core/impl/VirtualGuardImpl.h:c10::impl::InlineDeviceGuard::InlineDeviceGuard(c10::Device)", + "138000 build/../c10/core/TensorImpl.h:c10::TensorImpl::empty_tensor_restride(c10::MemoryFormat) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "137000 /data/users/test_user/repos/pytorch/build/../c10/core/TensorImpl.h:c10::TensorImpl::set_sizes_contiguous(c10::ArrayRef)", + "135000 ???:malloc [/usr/lib64/libc-2.28.so]", + "123000 build/../c10/core/TensorImpl.h:c10::TensorImpl::set_sizes_contiguous(c10::ArrayRef) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "121500 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:PyType_GenericAlloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "120000 /data/users/test_user/repos/pytorch/build/../c10/util/ThreadLocalDebugInfo.cpp:c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind)", + "120000 build/../aten/src/ATen/record_function.cpp:__tls_init [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "116000 build/../c10/util/intrusive_ptr.h:c10::TensorImpl::~TensorImpl()", + "112112 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_SaveThread [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "107000 build/../c10/core/CPUAllocator.cpp:c10::ProfiledCPUMemoryReporter::New(void*, unsigned long) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "105977 build/../c10/core/CPUAllocator.cpp:c10::free_cpu(void*) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "104000 /data/users/test_user/repos/pytorch/build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgs::scalartype(int) [clone .isra.180]", + "104000 build/../c10/core/CPUAllocator.cpp:c10::ProfiledCPUMemoryReporter::Delete(void*) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "104000 build/../c10/core/impl/VirtualGuardImpl.h:c10::impl::InlineDeviceGuard::InlineDeviceGuard(c10::Device) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "104000 build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgs::scalartype(int) [clone .isra.180] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "101000 build/../c10/core/TensorImpl.h:c10::TensorImpl::~TensorImpl() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "100000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:bool c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor const&) const", + "98098 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_SaveThread", + "95000 /data/users/test_user/repos/pytorch/build/../torch/csrc/autograd/variable.h:torch::autograd::make_variable(at::Tensor, bool, bool)", + "95000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "94000 build/../c10/core/StorageImpl.h:c10::TensorImpl::release_resources()", + "92821 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:lookdict_unicode_nodummy [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "91000 /data/users/test_user/repos/pytorch/build/../torch/csrc/tensor/python_tensor.cpp:torch::tensors::get_default_scalar_type()", + "91000 build/../torch/csrc/tensor/python_tensor.cpp:torch::tensors::get_default_scalar_type() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "90090 ???:pthread_mutex_lock [/usr/lib64/libpthread-2.28.so]", + "90000 /data/users/test_user/repos/pytorch/build/../c10/core/TensorImpl.h:c10::VariableVersion::VersionCounter::~VersionCounter()", + "90000 /data/users/test_user/repos/pytorch/build/../c10/util/Optional.h:c10::OptionalDeviceGuard::OptionalDeviceGuard(c10::optional)", + "90000 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::end() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "90000 build/../c10/core/TensorImpl.h:c10::VariableVersion::VersionCounter::~VersionCounter() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "88000 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_ >, bool (at::Tensor const&)>::call(c10::OperatorKernel*, at::Tensor const&)" + ], + "ones_no_data_exclusive": [ + "408000 ???:__tls_get_addr [/usr/lib64/ld-2.28.so]", + "388193 ???:_int_free [/usr/lib64/libc-2.28.so]", + "274000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::FunctionSignature::parse(_object*, _object*, _object*, _object**, bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "264000 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::RecordFunction(at::RecordScope) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "192000 build/../c10/core/Device.h:c10::Device::validate() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "169855 ???:_int_malloc [/usr/lib64/libc-2.28.so]", + "154000 build/../c10/core/TensorOptions.h:c10::TensorOptions::merge_in(c10::TensorOptions) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "148561 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:_PyEval_EvalFrameDefault [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "135000 ???:malloc [/usr/lib64/libc-2.28.so]", + "116000 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:subtype_dealloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "105000 build/../c10/core/impl/LocalDispatchKeySet.cpp:c10::impl::tls_local_dispatch_key_set() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "102000 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::~RecordFunction() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "102000 build/../c10/core/ScalarType.h:c10::typeMetaToScalarType(caffe2::TypeMeta) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "92821 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:lookdict_unicode_nodummy [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "92431 ???:_int_memalign [/usr/lib64/libc-2.28.so]", + "92000 ???:free [/usr/lib64/libc-2.28.so]", + "92000 build/../c10/core/TensorImpl.h:c10::TensorImpl::empty_tensor_restride(c10::MemoryFormat) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "90090 ???:pthread_mutex_lock [/usr/lib64/libpthread-2.28.so]", + "90000 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::end() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "84338 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_GetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "84000 build/../c10/util/SmallVector.h:at::RecordFunction::RecordFunction(at::RecordScope)", + "78000 build/../c10/core/TensorOptions.h:c10::computeDispatchKey(c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "78000 build/../torch/csrc/autograd/generated/python_torch_functions.cpp:torch::autograd::THPVariable_ones(_object*, _object*, _object*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "74710 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_SetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "72000 /usr/include/c++/8/bits/stl_vector.h:at::RecordFunction::~RecordFunction()", + "67000 build/../c10/util/intrusive_ptr.h:c10::intrusive_ptr::reset_() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "66066 ???:__pthread_mutex_unlock_usercnt [/usr/lib64/libpthread-2.28.so]", + "64110 /tmp/build/80754af9/python_1599604603603/work/Objects/methodobject.c:_PyCFunction_FastCallDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "64000 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "64000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::FunctionParameter::check(_object*, std::vector >&, int) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "61182 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:call_function [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "60061 /tmp/build/80754af9/python_1599604603603/work/Objects/tupleobject.c:PyTuple_New [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "59177 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:_PyObject_GenericGetAttrWithDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "59000 build/../c10/util/Optional.h:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "59000 build/../torch/csrc/utils/python_arg_parser.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "57000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::PythonArgParser::raw_parse(_object*, _object*, _object*, _object**) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "55000 /tmp/build/80754af9/python_1599604603603/work/Objects/tupleobject.c:tupledealloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "54000 /usr/include/c++/8/bits/stl_vector.h:at::RecordFunction::RecordFunction(at::RecordScope)", + "52000 build/../c10/util/Optional.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "50000 build/../c10/util/ThreadLocalDebugInfo.cpp:c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "50000 build/../c10/util/typeid.cpp:caffe2::detail::TypeMetaData const* caffe2::TypeMeta::_typeMetaDataInstance() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "49049 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:lookdict_unicode [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "49000 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:PyType_GenericAlloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "48000 build/../aten/src/ATen/record_function.cpp:__tls_init [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "48000 build/../c10/util/SmallVector.h:at::RecordFunction::~RecordFunction()", + "45015 /tmp/build/80754af9/python_1599604603603/work/Modules/gcmodule.c:PyObject_GC_UnTrack [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "45000 ???:_mid_memalign [/usr/lib64/libc-2.28.so]", + "44044 ???:pthread_cond_signal@@GLIBC_2.3.2 [/usr/lib64/libpthread-2.28.so]", + "44000 build/../c10/core/CPUAllocator.cpp:c10::alloc_cpu(unsigned long) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "44000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "43000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", + "43000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", + "42000 build/../c10/util/typeid.h:c10::typeMetaToScalarType(caffe2::TypeMeta)", + "41000 build/../aten/src/ATen/native/Fill.cpp:at::native::fill_out(at::Tensor&, c10::Scalar) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "41000 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::full(c10::ArrayRef, c10::Scalar, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "40000 build/../c10/core/TensorOptions.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "39000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "37111 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:_PyType_Lookup [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "36613 /tmp/build/80754af9/python_1599604603603/work/Objects/obmalloc.c:PyObject_Malloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "36000 /usr/include/c++/8/bits/stl_construct.h:at::RecordFunction::~RecordFunction()", + "36000 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_clear(THPVariable*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "36000 build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgs::intlist(int) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "35035 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_RestoreThread [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "35035 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_RestoreThread", + "35000 build/../c10/core/TensorOptions.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "35000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "35000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", + "34000 /tmp/build/80754af9/python_1599604603603/work/Objects/weakrefobject.c:PyObject_ClearWeakRefs [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "34000 build/../c10/core/impl/InlineDeviceGuard.h:c10::impl::InlineDeviceGuard::InlineDeviceGuard(c10::Device)", + "33066 /tmp/build/80754af9/python_1599604603603/work/Objects/obmalloc.c:_PyObject_Free [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "33000 /home/nwani/m3/conda-bld/compilers_linux-64_1560109574129/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/new_op.cc:operator new(unsigned long) [/home/test_user/miniconda3/envs/throwaway/lib/libstdc++.so.6.0.26]", + "33000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "33000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "33000 build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "33000 build/../c10/core/TensorImpl.h:c10::TensorImpl::set_sizes_contiguous(c10::ArrayRef) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "33000 build/../torch/csrc/autograd/variable.h:torch::autograd::make_variable(at::Tensor, bool, bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "33000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "32000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "32000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "32000 build/../c10/core/Allocator.cpp:c10::memoryProfilingEnabled() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "31000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "31000 build/../c10/util/SmallVector.h:c10::TensorImpl::empty_tensor_restride(c10::MemoryFormat)", + "30000 build/../aten/src/ATen/core/dispatch/Dispatcher.cpp:c10::Dispatcher::singleton() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "30000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:bool c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor const&) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "30000 build/../aten/src/ATen/native/Fill.cpp:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "30000 build/../aten/src/ATen/record_function.cpp:at::(anonymous namespace)::manager() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "30000 build/../c10/core/impl/VirtualGuardImpl.h:c10::impl::InlineDeviceGuard::InlineDeviceGuard(c10::Device) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "28000 build/../c10/core/CPUAllocator.cpp:c10::ProfiledCPUMemoryReporter::New(void*, unsigned long) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "28000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::compute_contiguous() const [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "27015 ???:__memset_avx2_unaligned_erms [/usr/lib64/libc-2.28.so]", + "27000 ???:posix_memalign [/usr/lib64/libc-2.28.so]", + "27000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const", + "26000 build/../c10/core/TensorImpl.h:c10::TensorImpl::data() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "26000 build/../c10/core/TensorOptions.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "26000 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "25000 /usr/include/c++/8/bits/stl_vector.h:torch::PythonArgs::intlist(int)", + "25000 build/../c10/core/CPUAllocator.cpp:c10::ProfiledCPUMemoryReporter::Delete(void*) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "25000 build/../c10/util/intrusive_ptr.h:c10::TensorImpl::~TensorImpl()", + "25000 build/../c10/util/typeid.cpp:caffe2::detail::TypeMetaData const* caffe2::TypeMeta::_typeMetaDataInstance() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "25000 build/../c10/util/typeid.cpp:caffe2::detail::TypeMetaData const* caffe2::TypeMeta::_typeMetaDataInstance() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "25000 build/../c10/util/typeid.cpp:caffe2::detail::TypeMetaData const* caffe2::TypeMeta::_typeMetaDataInstance() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "25000 build/../c10/util/typeid.cpp:caffe2::detail::TypeMetaData const* caffe2::TypeMeta::_typeMetaDataInstance() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "25000 build/../c10/util/typeid.cpp:caffe2::detail::TypeMetaData const* caffe2::TypeMeta::_typeMetaDataInstance() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "25000 build/../c10/util/typeid.cpp:caffe2::detail::TypeMetaData const* caffe2::TypeMeta::_typeMetaDataInstance() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "25000 build/../torch/csrc/Exceptions.cpp:torch::PyWarningHandler::~PyWarningHandler() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "25000 build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgs::device(int) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "25000 build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::fill__Scalar(at::Tensor&, c10::Scalar) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "24000 build/../c10/core/DispatchKey.cpp:c10::getAutogradKeyFromBackend(c10::DispatchKey) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "24000 build/aten/src/ATen/Functions.cpp:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "24000 build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "23000 build/../aten/src/ATen/core/LegacyTypeDispatch.h:at::AutoNonVariableTypeMode::AutoNonVariableTypeMode(bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "23000 build/../c10/core/CPUAllocator.cpp:c10::DefaultCPUAllocator::allocate(unsigned long) const [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "23000 build/../c10/core/TensorOptions.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.484] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "22044 /home/test_user/miniconda3/envs/throwaway/include/pybind11/detail/internals.h:pybind11::detail::get_internals() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "22000 /usr/include/c++/8/bits/shared_ptr_base.h:c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind)", + "22000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", + "22000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", + "22000 build/../c10/core/TensorOptions.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.469] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "21021 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_SaveThread", + "20035 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:_PyStack_AsTuple [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "20000 build/../c10/util/Optional.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "20000 build/../torch/csrc/autograd/generated/variable_factories.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "20000 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_Wrap(at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "20000 build/../torch/csrc/autograd/variable.cpp:torch::autograd::impl::get_autograd_meta(at::Tensor const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "19019 /tmp/build/80754af9/python_1599604603603/work/Modules/gcmodule.c:_PyObject_GC_Malloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "19000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "19000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "19000 build/../aten/src/ATen/native/TypeProperties.cpp:at::native::is_complex(at::Tensor const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "19000 build/../c10/util/SmallVector.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "18054 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_GetAttr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "18000 /usr/include/c++/8/bits/shared_ptr_base.h:at::RecordFunction::~RecordFunction()", + "18000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor::fill_(c10::Scalar) const", + "18000 build/../aten/src/ATen/native/TensorFactories.h:at::native::check_size_nonnegative(c10::ArrayRef) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "18000 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_NewWithVar(_typeobject*, at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "17000 /home/test_user/miniconda3/envs/throwaway/include/pybind11/pybind11.h:pybind11::gil_scoped_release::gil_scoped_release(bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "17000 /usr/include/c++/8/new:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "17000 build/../torch/csrc/tensor/python_tensor.cpp:torch::tensors::get_default_dispatch_key() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "17000 build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::fill_(c10::Scalar) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "16064 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:_Py_CheckFunctionResult [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "16000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "16000 build/../c10/util/Exception.cpp:c10::Warning::set_warning_handler(c10::WarningHandler*) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "16000 build/../c10/util/intrusive_ptr.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", + "16000 build/../c10/util/intrusive_ptr.h:torch::autograd::make_variable(at::Tensor, bool, bool)", + "16000 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "15000 /usr/include/c++/8/new:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "15000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "15000 build/../c10/core/ScalarType.h:at::native::is_complex(at::Tensor const&)", + "15000 build/../c10/core/TensorOptions.h:c10::TensorOptions::computeDispatchKey() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "15000 build/../c10/core/impl/LocalDispatchKeySet.cpp:c10::impl::ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(c10::DispatchKeySet) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "15000 build/../c10/util/intrusive_ptr.h:THPVariable_clear(THPVariable*)", + "15000 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_dealloc(THPVariable*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "15000 build/../torch/csrc/autograd/variable.cpp:torch::autograd::impl::try_get_grad_accumulator(at::Tensor const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "15000 build/../torch/csrc/utils/object_ptr.h:torch::FunctionSignature::parse(_object*, _object*, _object*, _object**, bool)", + "14042 /tmp/build/80754af9/python_1599604603603/work/Objects/moduleobject.c:module_getattro [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "14000 build/../c10/core/CPUAllocator.cpp:c10::DefaultCPUAllocator::ReportAndDelete(void*) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "14000 build/../c10/core/ScalarType.h:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda()", + "14000 build/../c10/core/TensorImpl.h:c10::TensorImpl::~TensorImpl() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "14000 build/../c10/util/Optional.h:at::ones(c10::ArrayRef, c10::TensorOptions const&)", + "14000 build/../c10/util/Optional.h:c10::computeDispatchKey(c10::optional, c10::optional, c10::optional)", + "14000 build/../c10/util/typeid.h:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda()", + "14000 build/../c10/util/typeid.h:at::native::is_complex(at::Tensor const&)", + "14000 build/aten/src/ATen/core/TensorBody.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "14000 build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::is_complex() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "13000 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "13000 build/../c10/core/TensorOptions.h:at::ones(c10::ArrayRef, c10::TensorOptions const&)", + "13000 build/../c10/util/SmallVector.h:c10::TensorImpl::set_sizes_contiguous(c10::ArrayRef)", + "13000 build/../torch/csrc/Exceptions.cpp:torch::PyWarningHandler::PyWarningHandler() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "13000 build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgs::scalartype(int) [clone .isra.180] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "12000 /tmp/build/80754af9/python_1599604603603/work/Modules/gcmodule.c:PyObject_GC_Del [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "12000 /usr/include/c++/8/bits/shared_ptr_base.h:at::RecordFunction::RecordFunction(at::RecordScope)", + "12000 /usr/include/c++/8/new:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", + "12000 /usr/include/c++/8/new:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", + "12000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:bool c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor const&) const", + "12000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "12000 build/../c10/core/TensorImpl.h:c10::TensorImpl::compute_contiguous() const", + "12000 build/../c10/util/intrusive_ptr.h:c10::intrusive_ptr >::reset_() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "11011 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_SaveThread [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "11000 /home/test_user/miniconda3/envs/throwaway/include/pybind11/pybind11.h:pybind11::gil_scoped_release::~gil_scoped_release() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "11000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const", + "11000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2", + "11000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "11000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", + "11000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "11000 build/../c10/util/Optional.h:c10::OptionalDeviceGuard::OptionalDeviceGuard(c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "11000 build/../torch/csrc/autograd/utils/wrap_outputs.h:torch::autograd::utils::wrap(at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "11000 build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "10000 build/../c10/core/CPUAllocator.cpp:c10::profiledCPUMemoryReporter() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "10000 build/../c10/util/Exception.cpp:c10::Warning::get_warning_handler() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "10000 build/../c10/util/intrusive_ptr.h:at::native::full(c10::ArrayRef, c10::Scalar, c10::TensorOptions const&)", + "10000 build/../c10/util/intrusive_ptr.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "10000 build/../c10/util/llvmMathExtras.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "9009 ???:pthread_mutex_unlock [/usr/lib64/libpthread-2.28.so]", + "9000 /usr/include/c++/8/new:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "9000 /usr/include/c++/8/new:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", + "9000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "9000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "9000 build/../c10/core/Device.h:at::native::fill_out(at::Tensor&, c10::Scalar)", + "9000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::release_resources() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "9000 build/../c10/core/TensorOptions.h:c10::TensorOptions::TensorOptions() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "9000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const", + "9000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2", + "9000 build/../c10/util/Optional.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "9000 build/../c10/util/intrusive_ptr.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "8000 /usr/include/c++/8/bits/atomic_base.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "8000 /usr/include/c++/8/bits/stl_vector.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "8000 /usr/include/c++/8/bits/unique_ptr.h:c10::TensorImpl::release_resources()", + "8000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor::is_complex() const", + "8000 build/../aten/src/ATen/detail/CPUGuardImpl.h:at::detail::CPUGuardImpl::getDevice() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "8000 build/../c10/core/CPUAllocator.cpp:c10::GetCPUAllocator() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "8000 build/../c10/core/DeviceGuard.h:c10::OptionalDeviceGuard::OptionalDeviceGuard(c10::optional)", + "8000 build/../c10/core/DispatchKeySet.h:c10::DispatchKeySet::has(c10::DispatchKey) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "8000 build/../c10/core/StorageImpl.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "8000 build/../c10/core/impl/DeviceGuardImplInterface.h:c10::impl::getDeviceGuardImpl(c10::DeviceType) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "8000 build/../c10/core/impl/VirtualGuardImpl.h:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&)", + "8000 build/../c10/util/Optional.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "8000 build/../c10/util/Optional.h:c10::TensorOptions::computeDispatchKey() const", + "8000 build/../c10/util/SmallVector.h:c10::TensorImpl::~TensorImpl()", + "8000 build/../c10/util/llvmMathExtras.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "8000 build/../c10/util/llvmMathExtras.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "8000 build/aten/src/ATen/core/TensorBody.h:at::native::fill_out(at::Tensor&, c10::Scalar)", + "7035 /tmp/build/80754af9/python_1599604603603/work/Python/errors.c:PyErr_Occurred [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "7000 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:_PyObject_GetDictPtr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "7000 build/../c10/core/Scalar.h:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda()", + "7000 build/../c10/core/ScalarType.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.469]", + "7000 build/../c10/core/StorageImpl.h:c10::TensorImpl::release_resources()", + "7000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "7000 build/../c10/core/impl/VirtualGuardImpl.h:c10::optional_base >::~optional_base()", + "7000 build/../c10/util/intrusive_ptr.h:torch::autograd::utils::wrap(at::Tensor)", + "7000 build/../c10/util/llvmMathExtras.h:at::Tensor::fill_(c10::Scalar) const", + "7000 build/../c10/util/llvmMathExtras.h:at::Tensor::is_complex() const", + "7000 build/aten/src/ATen/core/TensorBody.h:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda()", + "6026 /tmp/build/80754af9/python_1599604603603/work/Python/pystate.c:_PyThreadState_UncheckedGet [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "6006 /tmp/build/80754af9/python_1599604603603/work/Python/pystate.c:PyThreadState_Swap [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "6000 /home/nwani/m3/conda-bld/compilers_linux-64_1560109574129/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/del_op.cc:operator delete(void*) [/home/test_user/miniconda3/envs/throwaway/lib/libstdc++.so.6.0.26]", + "6000 /home/nwani/m3/conda-bld/compilers_linux-64_1560109574129/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/del_ops.cc:operator delete(void*, unsigned long) [/home/test_user/miniconda3/envs/throwaway/lib/libstdc++.so.6.0.26]", + "6000 /usr/include/c++/8/bits/atomic_base.h:THPVariable_clear(THPVariable*)", + "6000 /usr/include/c++/8/bits/move.h:torch::PythonArgs::intlist(int)", + "6000 /usr/include/c++/8/bits/shared_ptr_base.h:c10::memoryProfilingEnabled()", + "6000 /usr/include/c++/8/bits/stl_iterator.h:torch::FunctionSignature::parse(_object*, _object*, _object*, _object**, bool)", + "6000 /usr/include/c++/8/bits/unique_ptr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "6000 /usr/include/c++/8/bits/unique_ptr.h:c10::TensorImpl::~TensorImpl()", + "6000 /usr/include/c++/8/tuple:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "6000 build/../aten/src/ATen/record_function.h:at::RecordFunction::RecordFunction(at::RecordScope)", + "6000 build/../c10/core/Allocator.cpp:c10::GetAllocator(c10::DeviceType const&) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "6000 build/../c10/core/Device.h:at::detail::CPUGuardImpl::getDevice() const", + "6000 build/../c10/core/DispatchKeySet.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", + "6000 build/../c10/core/DispatchKeySet.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "6000 build/../c10/core/ScalarType.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.484]", + "6000 build/../c10/core/TensorImpl.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", + "6000 build/../c10/core/TensorImpl.h:at::Tensor::device() const", + "6000 build/../c10/core/TensorOptions.h:torch::utils::maybe_initialize_cuda(c10::TensorOptions const&)", + "6000 build/../c10/util/Optional.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "6000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "6000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2", + "6000 build/../c10/util/SmallVector.h:c10::TensorImpl::compute_contiguous() const", + "6000 build/../c10/util/TypeCast.h:float c10::checked_convert(double, char const*)", + "6000 build/../c10/util/intrusive_ptr.h:THPVariable_Wrap(at::Tensor)", + "6000 build/../c10/util/intrusive_ptr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "6000 build/../torch/csrc/tensor/python_tensor.cpp:torch::tensors::get_default_scalar_type() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "6000 build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgParser::raw_parse(_object*, _object*, _object*, _object**)", + "5000 /tmp/build/80754af9/python_1599604603603/_build_env/x86_64-conda_cos6-linux-gnu/sysroot/usr/include/bits/string3.h:PyType_GenericAlloc", + "5000 /usr/include/c++/8/bits/atomic_base.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "5000 build/../aten/src/ATen/DeviceGuard.h:at::TypeDefault::fill__Scalar(at::Tensor&, c10::Scalar)", + "5000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "5000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2", + "5000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::Tensor::fill_(c10::Scalar) const", + "5000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "5000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "5000 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::(anonymous namespace)::infer_full_options(c10::Scalar, c10::TensorOptions const&) [clone .isra.262] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "5000 build/../c10/core/Device.h:torch::PythonArgs::device(int)", + "5000 build/../c10/core/DispatchKeySet.h:at::Tensor::fill_(c10::Scalar) const", + "5000 build/../c10/core/DispatchKeySet.h:at::Tensor::is_complex() const", + "5000 build/../c10/core/TensorImpl.h:at::Tensor::is_quantized() const", + "5000 build/../c10/core/TensorOptions.h:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&)", + "5000 build/../c10/core/impl/LocalDispatchKeySet.cpp:c10::impl::ExcludeDispatchKeyGuard::~ExcludeDispatchKeyGuard() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "5000 build/../c10/util/Optional.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "5000 build/../c10/util/intrusive_ptr.h:c10::TensorImpl::release_resources()", + "5000 build/../torch/csrc/utils/cuda_lazy_init.h:torch::utils::maybe_initialize_cuda(c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "4004 ???:__errno_location [/usr/lib64/libpthread-2.28.so]", + "4000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", + "4000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", + "4000 /usr/include/c++/8/bits/atomic_base.h:at::native::full(c10::ArrayRef, c10::Scalar, c10::TensorOptions const&)", + "4000 /usr/include/c++/8/bits/atomic_base.h:c10::impl::getDeviceGuardImpl(c10::DeviceType)", + "4000 /usr/include/c++/8/bits/atomic_base.h:torch::autograd::make_variable(at::Tensor, bool, bool)", + "4000 /usr/include/c++/8/bits/move.h:c10::TensorImpl::release_resources()", + "4000 /usr/include/c++/8/bits/unique_ptr.h:c10::TensorImpl::set_autograd_meta(std::unique_ptr >) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "4000 /usr/include/c++/8/cmath:float c10::checked_convert(double, char const*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "4000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::Tensor::is_complex() const", + "4000 build/../c10/core/Allocator.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "4000 build/../c10/core/Device.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "4000 build/../c10/core/DeviceGuard.h:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&)", + "4000 build/../c10/core/DispatchKeySet.h:c10::impl::ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(c10::DispatchKeySet)", + "4000 build/../c10/core/TensorImpl.h:torch::autograd::impl::set_pyobj(at::Tensor const&, _object*)", + "4000 build/../c10/util/Optional.h:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "4000 build/../c10/util/Optional.h:at::TypeDefault::fill__Scalar(at::Tensor&, c10::Scalar)", + "4000 build/../c10/util/Optional.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.484]", + "4000 build/../c10/util/SmallVector.h:c10::TensorImpl::sizes() const [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "4000 build/../c10/util/UniqueVoidPtr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "4000 build/../c10/util/intrusive_ptr.h:THPVariable_NewWithVar(_typeobject*, at::Tensor)", + "4000 build/../torch/csrc/autograd/variable.cpp:torch::autograd::impl::pyobj(at::Tensor const&)", + "4000 build/../torch/csrc/autograd/variable.cpp:torch::autograd::impl::set_pyobj(at::Tensor const&, _object*)", + "4000 build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::device() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "3006 /tmp/build/80754af9/python_1599604603603/work/Objects/methodobject.c:_PyCFunction_FastCallKeywords [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "3006 /tmp/build/80754af9/python_1599604603603/work/Objects/obmalloc.c:PyObject_Free [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "3000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const", + "3000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2", + "3000 /usr/include/c++/8/array:bool c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor const&) const", + "3000 /usr/include/c++/8/bits/atomic_base.h:c10::intrusive_ptr::reset_()", + "3000 /usr/include/c++/8/bits/shared_ptr_base.h:THPVariable_clear(THPVariable*)", + "3000 /usr/include/c++/8/bits/stl_numeric.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "3000 /usr/include/c++/8/bits/stl_vector.h:torch::PyWarningHandler::PyWarningHandler()", + "3000 /usr/include/c++/8/bits/unique_ptr.h:torch::autograd::make_variable(at::Tensor, bool, bool)", + "3000 /usr/include/c++/8/tuple:c10::DefaultCPUAllocator::allocate(unsigned long) const", + "3000 build/../aten/src/ATen/core/LegacyTypeDispatch.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "3000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_ >, bool (at::Tensor const&)>::call(c10::OperatorKernel*, at::Tensor const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "3000 build/../c10/core/Backend.h:torch::PythonArgs::device(int)", + "3000 build/../c10/core/Backend.h:torch::tensors::get_default_dispatch_key()", + "3000 build/../c10/core/Device.h:c10::DefaultCPUAllocator::allocate(unsigned long) const", + "3000 build/../c10/core/Device.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "3000 build/../c10/core/DispatchKeySet.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "3000 build/../c10/core/DispatchKeySet.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "3000 build/../c10/core/Scalar.h:at::native::ones(c10::ArrayRef, c10::TensorOptions const&)", + "3000 build/../c10/core/TensorImpl.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "3000 build/../c10/core/TensorImpl.h:c10::VariableVersion::VersionCounter::~VersionCounter() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "3000 build/../c10/util/Optional.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&)", + "3000 build/../c10/util/Optional.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.469]", + "3000 build/../c10/util/Optional.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "3000 build/../c10/util/intrusive_ptr.h:THPVariable_dealloc(THPVariable*)", + "3000 build/../c10/util/typeid.h:c10::TensorImpl::data() const", + "3000 build/../c10/util/typeid.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.469]", + "3000 build/../c10/util/typeid.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.484]", + "3000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::PythonArgParser::check_deprecated(torch::FunctionSignature const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "3000 build/aten/src/ATen/core/TensorBody.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "3000 build/aten/src/ATen/core/TensorBody.h:torch::autograd::make_variable(at::Tensor, bool, bool)", + "2006 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_GenericGetAttr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "2000 /usr/include/c++/8/array:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const", + "2000 /usr/include/c++/8/bits/atomic_base.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", + "2000 /usr/include/c++/8/bits/move.h:c10::TensorImpl::set_autograd_meta(std::unique_ptr >)", + "2000 /usr/include/c++/8/bits/shared_ptr_base.h:torch::autograd::impl::try_get_grad_accumulator(at::Tensor const&)", + "2000 /usr/include/c++/8/bits/stl_algobase.h:torch::PythonArgs::intlist(int)", + "2000 /usr/include/c++/8/bits/stl_vector.h:torch::FunctionSignature::parse(_object*, _object*, _object*, _object**, bool)", + "2000 /usr/include/c++/8/bits/stl_vector.h:torch::PyWarningHandler::~PyWarningHandler()", + "2000 /usr/include/c++/8/bits/stl_vector.h:torch::PythonArgParser::raw_parse(_object*, _object*, _object*, _object**)", + "2000 /usr/include/c++/8/ext/new_allocator.h:torch::PythonArgs::intlist(int)", + "2000 /usr/include/c++/8/new:c10::OptionalDeviceGuard::OptionalDeviceGuard(c10::optional)", + "2000 /usr/include/c++/8/new:c10::computeDispatchKey(c10::optional, c10::optional, c10::optional)", + "2000 /usr/include/c++/8/tuple:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "2000 build/../aten/src/ATen/Context.cpp:at::getCPUAllocator() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "2000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_ >, at::Tensor& (at::Tensor&, c10::Scalar)>::call(c10::OperatorKernel*, at::Tensor&, c10::Scalar)", + "2000 build/../aten/src/ATen/core/dispatch/OperatorEntry.h:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const", + "2000 build/../aten/src/ATen/detail/CPUGuardImpl.h:at::detail::CPUGuardImpl::uncheckedSetDevice(c10::Device) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "2000 build/../aten/src/TH/THAllocator.cpp:getTHDefaultAllocator() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "2000 build/../c10/core/Allocator.h:c10::DefaultCPUAllocator::allocate(unsigned long) const" + ], + "ones_with_data_inclusive": [ + "9395109 /home/rdonnelly/mc/conda-bld/compilers_linux-64_1534865402226/work/.build/src/glibc-2.12.2/csu/../sysdeps/x86_64/elf/start.S:0x00000000001c3ce2 [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "9395109 /tmp/build/80754af9/python_1599604603603/work/Modules/main.c:Py_Main [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "9395109 /tmp/build/80754af9/python_1599604603603/work/Programs/python.c:main [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "9395109 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_EvalCode [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "9395109 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_EvalCodeEx [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "9395109 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_EvalFrameEx [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "9395109 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:_PyEval_EvalFrameDefault [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "9395109 /tmp/build/80754af9/python_1599604603603/work/Python/pythonrun.c:PyRun_AnyFileExFlags [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "9395109 /tmp/build/80754af9/python_1599604603603/work/Python/pythonrun.c:PyRun_FileExFlags [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "9395109 /tmp/build/80754af9/python_1599604603603/work/Python/pythonrun.c:PyRun_SimpleFileExFlags [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "9395109 /tmp/build/80754af9/python_1599604603603/work/Python/pythonrun.c:run_mod [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "9395109 ???:(below main) [/usr/lib64/libc-2.28.so]", + "9395109 ???:0x0000000000001050 [/usr/lib64/ld-2.28.so]", + "7767596 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:call_function [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "7705208 /tmp/build/80754af9/python_1599604603603/work/Objects/methodobject.c:_PyCFunction_FastCallKeywords [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "7702202 /tmp/build/80754af9/python_1599604603603/work/Objects/methodobject.c:_PyCFunction_FastCallDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "7463189 /data/users/test_user/repos/pytorch/build/../torch/csrc/autograd/generated/python_torch_functions.cpp:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "5458967 build/../torch/csrc/autograd/generated/variable_factories.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "5288967 /data/users/test_user/repos/pytorch/build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::TensorOptions const&)", + "5177967 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "5161967 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "5011967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "4990967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "4729967 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", + "4707967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "4644967 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "4633967 /data/users/test_user/repos/pytorch/build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "4419967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "4399967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "4138967 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", + "4116967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", + "4053967 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", + "4042967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "3952967 build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "3878967 /data/users/test_user/repos/pytorch/build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&)", + "3789967 build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "3765967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/native/TensorFactories.cpp:at::native::ones(c10::ArrayRef, c10::TensorOptions const&)", + "3762967 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "3749967 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::full(c10::ArrayRef, c10::Scalar, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "2573967 /data/users/test_user/repos/pytorch/build/aten/src/ATen/Functions.cpp:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "2485967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "2469967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "2256967 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const", + "2245967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "2230967 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "2225967 /data/users/test_user/repos/pytorch/build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "1981967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "1964967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "1751967 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2", + "1740967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2", + "1725967 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2", + "1720967 /data/users/test_user/repos/pytorch/build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "1716967 build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "1705967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "1475967 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "1307993 build/../torch/csrc/autograd/generated/python_torch_functions.cpp:torch::autograd::THPVariable_ones(_object*, _object*, _object*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "1112000 /data/users/test_user/repos/pytorch/build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::fill_(c10::Scalar) const", + "1067166 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_SetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "1039000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor::fill_(c10::Scalar) const", + "1016000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "944986 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:subtype_dealloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "813000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const", + "786000 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_ >, at::Tensor& (at::Tensor&, c10::Scalar)>::call(c10::OperatorKernel*, at::Tensor&, c10::Scalar)", + "785000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_ >, at::Tensor& (at::Tensor&, c10::Scalar)>::call(c10::OperatorKernel*, at::Tensor&, c10::Scalar)", + "783000 /data/users/test_user/repos/pytorch/build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::fill__Scalar(at::Tensor&, c10::Scalar)", + "772986 /data/users/test_user/repos/pytorch/build/../torch/csrc/autograd/python_variable.cpp:THPVariable_dealloc(THPVariable*)", + "769986 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_dealloc(THPVariable*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "758000 build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::fill__Scalar(at::Tensor&, c10::Scalar) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "686996 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_clear(THPVariable*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "670967 /data/users/test_user/repos/pytorch/build/../c10/core/CPUAllocator.cpp:c10::DefaultCPUAllocator::allocate(unsigned long) const", + "643000 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/native/Fill.cpp:at::native::fill_(at::Tensor&, c10::Scalar)", + "643000 build/../aten/src/ATen/native/Fill.cpp:at::native::fill_(at::Tensor&, c10::Scalar) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "642000 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/record_function.cpp:at::RecordFunction::RecordFunction(at::RecordScope)", + "642000 build/../aten/src/ATen/native/Fill.cpp:at::native::fill_out(at::Tensor&, c10::Scalar) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "637967 build/../c10/core/CPUAllocator.cpp:c10::DefaultCPUAllocator::allocate(unsigned long) const [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "596996 build/../c10/util/intrusive_ptr.h:THPVariable_clear(THPVariable*)", + "502967 build/../c10/core/CPUAllocator.cpp:c10::alloc_cpu(unsigned long) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "488000 build/../torch/csrc/utils/python_arg_parser.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "486000 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::RecordFunction(at::RecordScope) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "455967 ???:posix_memalign [/usr/lib64/libc-2.28.so]", + "447282 ???:_int_free [/usr/lib64/libc-2.28.so]", + "441225 ???:free [/usr/lib64/libc-2.28.so]", + "432000 ???:__tls_get_addr [/usr/lib64/ld-2.28.so]", + "429000 /data/users/test_user/repos/pytorch/build/../torch/csrc/utils/python_arg_parser.cpp:torch::PythonArgParser::raw_parse(_object*, _object*, _object*, _object**)", + "428967 ???:_mid_memalign [/usr/lib64/libc-2.28.so]", + "421000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::PythonArgParser::raw_parse(_object*, _object*, _object*, _object**) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "386000 /data/users/test_user/repos/pytorch/build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::is_complex() const", + "383967 ???:_int_memalign [/usr/lib64/libc-2.28.so]", + "366000 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/record_function.cpp:at::RecordFunction::~RecordFunction()", + "361000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::FunctionSignature::parse(_object*, _object*, _object*, _object**, bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "350996 /data/users/test_user/repos/pytorch/build/../c10/core/TensorImpl.cpp:c10::TensorImpl::release_resources()", + "345229 /home/nwani/m3/conda-bld/compilers_linux-64_1560109574129/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/del_op.cc:operator delete(void*) [/home/test_user/miniconda3/envs/throwaway/lib/libstdc++.so.6.0.26]", + "315000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor::is_complex() const", + "308686 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_GetAttr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "302000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:bool c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor const&) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "290632 /tmp/build/80754af9/python_1599604603603/work/Objects/moduleobject.c:module_getattro [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "288883 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_GetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "281000 /data/users/test_user/repos/pytorch/build/aten/src/ATen/core/TensorBody.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", + "277000 /data/users/test_user/repos/pytorch/build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgs::intlist(int)", + "276590 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_GenericGetAttr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "274584 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:_PyObject_GenericGetAttrWithDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "273000 /data/users/test_user/repos/pytorch/build/../c10/core/impl/LocalDispatchKeySet.cpp:c10::impl::tls_local_dispatch_key_set()", + "273000 build/../c10/core/impl/LocalDispatchKeySet.cpp:c10::impl::tls_local_dispatch_key_set() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "261000 /home/nwani/m3/conda-bld/compilers_linux-64_1560109574129/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/del_ops.cc:operator delete(void*, unsigned long) [/home/test_user/miniconda3/envs/throwaway/lib/libstdc++.so.6.0.26]", + "236996 /usr/include/c++/8/bits/unique_ptr.h:c10::TensorImpl::release_resources()", + "231000 /data/users/test_user/repos/pytorch/build/../c10/core/TensorImpl.h:c10::TensorImpl::~TensorImpl()", + "228996 /data/users/test_user/repos/pytorch/build/../c10/core/CPUAllocator.cpp:c10::DefaultCPUAllocator::ReportAndDelete(void*)", + "228996 build/../c10/core/CPUAllocator.cpp:c10::DefaultCPUAllocator::ReportAndDelete(void*) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "222000 /home/nwani/m3/conda-bld/compilers_linux-64_1560109574129/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/new_op.cc:operator new(unsigned long) [/home/test_user/miniconda3/envs/throwaway/lib/libstdc++.so.6.0.26]", + "220000 /home/test_user/miniconda3/envs/throwaway/include/pybind11/pybind11.h:pybind11::gil_scoped_release::~gil_scoped_release() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "209209 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_RestoreThread [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "200993 /data/users/test_user/repos/pytorch/build/../torch/csrc/autograd/utils/wrap_outputs.h:torch::autograd::utils::wrap(at::Tensor)", + "200000 /data/users/test_user/repos/pytorch/build/../c10/core/TensorImpl.h:c10::TensorImpl::empty_tensor_restride(c10::MemoryFormat)", + "196000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "193993 build/../torch/csrc/autograd/utils/wrap_outputs.h:torch::autograd::utils::wrap(at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "192000 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::~RecordFunction() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "192000 build/../c10/core/Device.h:c10::Device::validate() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "187000 build/../c10/core/TensorImpl.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", + "182993 /data/users/test_user/repos/pytorch/build/../torch/csrc/autograd/python_variable.cpp:THPVariable_Wrap(at::Tensor)", + "181000 /data/users/test_user/repos/pytorch/build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&)", + "178000 ???:malloc [/usr/lib64/libc-2.28.so]", + "178000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "176993 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_Wrap(at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "171000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "170404 ???:_int_malloc [/usr/lib64/libc-2.28.so]", + "170000 build/../c10/core/TensorImpl.h:c10::TensorImpl::empty_tensor_restride(c10::MemoryFormat) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "167167 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_RestoreThread", + "166000 build/../c10/core/ScalarType.h:c10::typeMetaToScalarType(caffe2::TypeMeta) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "159000 /data/users/test_user/repos/pytorch/build/../c10/core/TensorImpl.h:c10::TensorImpl::set_sizes_contiguous(c10::ArrayRef)", + "158000 /data/users/test_user/repos/pytorch/build/../c10/core/Allocator.cpp:c10::memoryProfilingEnabled()", + "155000 /data/users/test_user/repos/pytorch/build/../c10/core/TensorOptions.h:c10::TensorOptions::computeDispatchKey() const", + "154000 /data/users/test_user/repos/pytorch/build/../c10/core/TensorOptions.h:c10::TensorOptions::merge_in(c10::TensorOptions) const", + "154000 build/../c10/core/TensorOptions.h:c10::TensorOptions::merge_in(c10::TensorOptions) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "153000 build/../aten/src/ATen/native/Fill.cpp:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "152000 build/../c10/core/Allocator.cpp:c10::memoryProfilingEnabled() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "151993 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_NewWithVar(_typeobject*, at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "151000 /home/test_user/miniconda3/envs/throwaway/include/pybind11/pybind11.h:pybind11::gil_scoped_release::gil_scoped_release(bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "147000 build/../c10/util/typeid.h:c10::typeMetaToScalarType(caffe2::TypeMeta)", + "146000 build/../c10/core/TensorImpl.h:c10::TensorImpl::set_sizes_contiguous(c10::ArrayRef) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "138000 /data/users/test_user/repos/pytorch/build/../c10/core/impl/VirtualGuardImpl.h:c10::impl::InlineDeviceGuard::InlineDeviceGuard(c10::Device)", + "134049 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:lookdict_unicode [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "124993 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:PyType_GenericAlloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "120000 /data/users/test_user/repos/pytorch/build/../c10/util/ThreadLocalDebugInfo.cpp:c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind)", + "120000 build/../aten/src/ATen/record_function.cpp:__tls_init [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "116000 build/../c10/util/intrusive_ptr.h:c10::TensorImpl::~TensorImpl()", + "113888 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:lookdict_unicode_nodummy [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "112112 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_SaveThread [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "107000 build/../c10/core/CPUAllocator.cpp:c10::ProfiledCPUMemoryReporter::New(void*, unsigned long) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "105996 build/../c10/core/CPUAllocator.cpp:c10::free_cpu(void*) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "104000 /data/users/test_user/repos/pytorch/build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgs::scalartype(int) [clone .isra.180]", + "104000 build/../c10/core/CPUAllocator.cpp:c10::ProfiledCPUMemoryReporter::Delete(void*) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "104000 build/../c10/core/impl/VirtualGuardImpl.h:c10::impl::InlineDeviceGuard::InlineDeviceGuard(c10::Device) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "104000 build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgs::scalartype(int) [clone .isra.180] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "101000 build/../c10/core/TensorImpl.h:c10::TensorImpl::~TensorImpl() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "100000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:bool c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor const&) const", + "98098 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_SaveThread", + "95000 /data/users/test_user/repos/pytorch/build/../torch/csrc/autograd/variable.h:torch::autograd::make_variable(at::Tensor, bool, bool)", + "95000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "94000 build/../c10/core/StorageImpl.h:c10::TensorImpl::release_resources()", + "93229 /usr/include/c++/8/ext/new_allocator.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)" + ], + "ones_with_data_exclusive": [ + "447282 ???:_int_free [/usr/lib64/libc-2.28.so]", + "432000 ???:__tls_get_addr [/usr/lib64/ld-2.28.so]", + "274000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::FunctionSignature::parse(_object*, _object*, _object*, _object**, bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "264000 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::RecordFunction(at::RecordScope) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "192000 build/../c10/core/Device.h:c10::Device::validate() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "178000 ???:malloc [/usr/lib64/libc-2.28.so]", + "170048 ???:_int_malloc [/usr/lib64/libc-2.28.so]", + "154000 build/../c10/core/TensorOptions.h:c10::TensorOptions::merge_in(c10::TensorOptions) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "140561 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:_PyEval_EvalFrameDefault [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "134049 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:lookdict_unicode [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "116000 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:subtype_dealloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "115000 ???:free [/usr/lib64/libc-2.28.so]", + "113888 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:lookdict_unicode_nodummy [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "106000 build/../c10/core/TensorImpl.h:c10::TensorImpl::empty_tensor_restride(c10::MemoryFormat) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "105000 build/../c10/core/impl/LocalDispatchKeySet.cpp:c10::impl::tls_local_dispatch_key_set() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "102000 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::~RecordFunction() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "102000 build/../c10/core/ScalarType.h:c10::typeMetaToScalarType(caffe2::TypeMeta) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "92506 ???:_int_memalign [/usr/lib64/libc-2.28.so]", + "90090 ???:pthread_mutex_lock [/usr/lib64/libpthread-2.28.so]", + "90000 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::end() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "84338 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_GetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "84000 build/../c10/util/SmallVector.h:at::RecordFunction::RecordFunction(at::RecordScope)", + "78000 build/../c10/core/TensorOptions.h:c10::computeDispatchKey(c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "78000 build/../torch/csrc/autograd/generated/python_torch_functions.cpp:torch::autograd::THPVariable_ones(_object*, _object*, _object*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "74710 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_SetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "72000 /usr/include/c++/8/bits/stl_vector.h:at::RecordFunction::~RecordFunction()", + "67000 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "67000 build/../c10/util/intrusive_ptr.h:c10::intrusive_ptr::reset_() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "66066 ???:__pthread_mutex_unlock_usercnt [/usr/lib64/libpthread-2.28.so]", + "64110 /tmp/build/80754af9/python_1599604603603/work/Objects/methodobject.c:_PyCFunction_FastCallDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "64000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::FunctionParameter::check(_object*, std::vector >&, int) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "61182 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:call_function [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "59177 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:_PyObject_GenericGetAttrWithDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "59000 build/../c10/util/Optional.h:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "59000 build/../torch/csrc/utils/python_arg_parser.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "57000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::PythonArgParser::raw_parse(_object*, _object*, _object*, _object**) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "56000 build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgs::intlist(int) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "55000 /tmp/build/80754af9/python_1599604603603/work/Objects/tupleobject.c:tupledealloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "54000 /usr/include/c++/8/bits/stl_vector.h:at::RecordFunction::RecordFunction(at::RecordScope)", + "52000 build/../c10/util/Optional.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "50000 build/../c10/util/ThreadLocalDebugInfo.cpp:c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "50000 build/../c10/util/typeid.cpp:caffe2::detail::TypeMetaData const* caffe2::TypeMeta::_typeMetaDataInstance() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "49000 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:PyType_GenericAlloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "48000 build/../aten/src/ATen/record_function.cpp:__tls_init [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "48000 build/../c10/util/SmallVector.h:at::RecordFunction::~RecordFunction()", + "45015 /tmp/build/80754af9/python_1599604603603/work/Modules/gcmodule.c:PyObject_GC_UnTrack [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "45000 ???:_mid_memalign [/usr/lib64/libc-2.28.so]", + "44061 /tmp/build/80754af9/python_1599604603603/work/Objects/tupleobject.c:PyTuple_New [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "44044 ???:pthread_cond_signal@@GLIBC_2.3.2 [/usr/lib64/libpthread-2.28.so]", + "44000 /home/nwani/m3/conda-bld/compilers_linux-64_1560109574129/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/new_op.cc:operator new(unsigned long) [/home/test_user/miniconda3/envs/throwaway/lib/libstdc++.so.6.0.26]", + "44000 build/../c10/core/CPUAllocator.cpp:c10::alloc_cpu(unsigned long) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "44000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "44000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::compute_contiguous() const [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "43000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", + "43000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", + "42000 build/../c10/util/typeid.h:c10::typeMetaToScalarType(caffe2::TypeMeta)", + "41000 build/../aten/src/ATen/native/Fill.cpp:at::native::fill_out(at::Tensor&, c10::Scalar) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "41000 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::full(c10::ArrayRef, c10::Scalar, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "41000 build/../c10/core/TensorImpl.h:c10::TensorImpl::set_sizes_contiguous(c10::ArrayRef) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "40106 /tmp/build/80754af9/python_1599604603603/work/Objects/obmalloc.c:PyObject_Malloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "40000 build/../c10/core/TensorOptions.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "39000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "38056 /tmp/build/80754af9/python_1599604603603/work/Objects/obmalloc.c:_PyObject_Free [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "37111 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:_PyType_Lookup [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "36000 /usr/include/c++/8/bits/stl_construct.h:at::RecordFunction::~RecordFunction()", + "36000 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_clear(THPVariable*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "35035 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_RestoreThread [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "35035 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_RestoreThread", + "35000 build/../c10/core/TensorOptions.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "35000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "35000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", + "34000 /tmp/build/80754af9/python_1599604603603/work/Objects/weakrefobject.c:PyObject_ClearWeakRefs [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "34000 build/../c10/core/impl/InlineDeviceGuard.h:c10::impl::InlineDeviceGuard::InlineDeviceGuard(c10::Device)", + "33000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "33000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "33000 build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "33000 build/../torch/csrc/autograd/variable.h:torch::autograd::make_variable(at::Tensor, bool, bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "33000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "32000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "32000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "32000 build/../c10/core/Allocator.cpp:c10::memoryProfilingEnabled() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "31000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "30000 build/../aten/src/ATen/core/dispatch/Dispatcher.cpp:c10::Dispatcher::singleton() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "30000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:bool c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor const&) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "30000 build/../aten/src/ATen/native/Fill.cpp:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "30000 build/../aten/src/ATen/record_function.cpp:at::(anonymous namespace)::manager() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "30000 build/../c10/core/impl/VirtualGuardImpl.h:c10::impl::InlineDeviceGuard::InlineDeviceGuard(c10::Device) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "30000 build/../c10/util/SmallVector.h:c10::TensorImpl::empty_tensor_restride(c10::MemoryFormat)", + "28000 build/../c10/core/CPUAllocator.cpp:c10::ProfiledCPUMemoryReporter::New(void*, unsigned long) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "27015 ???:__memset_avx2_unaligned_erms [/usr/lib64/libc-2.28.so]", + "27000 ???:posix_memalign [/usr/lib64/libc-2.28.so]", + "27000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const", + "26000 /usr/include/c++/8/bits/stl_vector.h:torch::PythonArgs::intlist(int)", + "26000 build/../c10/core/TensorImpl.h:c10::TensorImpl::data() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "26000 build/../c10/core/TensorOptions.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "26000 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "25000 build/../aten/src/ATen/native/TensorFactories.h:at::native::check_size_nonnegative(c10::ArrayRef) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "25000 build/../c10/core/CPUAllocator.cpp:c10::ProfiledCPUMemoryReporter::Delete(void*) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "25000 build/../c10/util/intrusive_ptr.h:c10::TensorImpl::~TensorImpl()", + "25000 build/../c10/util/typeid.cpp:caffe2::detail::TypeMetaData const* caffe2::TypeMeta::_typeMetaDataInstance() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "25000 build/../c10/util/typeid.cpp:caffe2::detail::TypeMetaData const* caffe2::TypeMeta::_typeMetaDataInstance() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "25000 build/../c10/util/typeid.cpp:caffe2::detail::TypeMetaData const* caffe2::TypeMeta::_typeMetaDataInstance() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "25000 build/../c10/util/typeid.cpp:caffe2::detail::TypeMetaData const* caffe2::TypeMeta::_typeMetaDataInstance() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "25000 build/../c10/util/typeid.cpp:caffe2::detail::TypeMetaData const* caffe2::TypeMeta::_typeMetaDataInstance() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "25000 build/../c10/util/typeid.cpp:caffe2::detail::TypeMetaData const* caffe2::TypeMeta::_typeMetaDataInstance() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "25000 build/../torch/csrc/Exceptions.cpp:torch::PyWarningHandler::~PyWarningHandler() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "25000 build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgs::device(int) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "25000 build/../torch/csrc/utils/python_numbers.h:torch::PythonArgs::intlist(int)", + "25000 build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::fill__Scalar(at::Tensor&, c10::Scalar) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "24000 build/../c10/core/DispatchKey.cpp:c10::getAutogradKeyFromBackend(c10::DispatchKey) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "24000 build/aten/src/ATen/Functions.cpp:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "24000 build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "23000 build/../aten/src/ATen/core/LegacyTypeDispatch.h:at::AutoNonVariableTypeMode::AutoNonVariableTypeMode(bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "23000 build/../c10/core/CPUAllocator.cpp:c10::DefaultCPUAllocator::allocate(unsigned long) const [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "23000 build/../c10/core/TensorOptions.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.484] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "22044 /home/test_user/miniconda3/envs/throwaway/include/pybind11/detail/internals.h:pybind11::detail::get_internals() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "22000 /usr/include/c++/8/bits/shared_ptr_base.h:c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind)", + "22000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", + "22000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", + "22000 build/../c10/core/TensorOptions.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.469] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "21021 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_SaveThread", + "20035 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:_PyStack_AsTuple [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "20000 build/../c10/util/Optional.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "20000 build/../torch/csrc/autograd/generated/variable_factories.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "20000 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_Wrap(at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "20000 build/../torch/csrc/autograd/variable.cpp:torch::autograd::impl::get_autograd_meta(at::Tensor const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "19019 /tmp/build/80754af9/python_1599604603603/work/Modules/gcmodule.c:_PyObject_GC_Malloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "19000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "19000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "19000 build/../aten/src/ATen/native/TypeProperties.cpp:at::native::is_complex(at::Tensor const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "19000 build/../c10/util/SmallVector.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "18054 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_GetAttr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "18000 /tmp/build/80754af9/python_1599604603603/work/Objects/longobject.c:PyLong_AsLongLongAndOverflow [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "18000 /usr/include/c++/8/bits/shared_ptr_base.h:at::RecordFunction::~RecordFunction()", + "18000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor::fill_(c10::Scalar) const", + "18000 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_NewWithVar(_typeobject*, at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "17010 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:PyType_IsSubtype [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "17000 /home/test_user/miniconda3/envs/throwaway/include/pybind11/pybind11.h:pybind11::gil_scoped_release::gil_scoped_release(bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "17000 /usr/include/c++/8/new:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "17000 build/../torch/csrc/tensor/python_tensor.cpp:torch::tensors::get_default_dispatch_key() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "17000 build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::fill_(c10::Scalar) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "16064 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:_Py_CheckFunctionResult [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "16000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "16000 build/../c10/util/Exception.cpp:c10::Warning::set_warning_handler(c10::WarningHandler*) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "16000 build/../c10/util/intrusive_ptr.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", + "16000 build/../c10/util/intrusive_ptr.h:torch::autograd::make_variable(at::Tensor, bool, bool)", + "16000 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "15000 /usr/include/c++/8/new:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "15000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "15000 build/../c10/core/ScalarType.h:at::native::is_complex(at::Tensor const&)", + "15000 build/../c10/core/TensorOptions.h:c10::TensorOptions::computeDispatchKey() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "15000 build/../c10/core/impl/LocalDispatchKeySet.cpp:c10::impl::ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(c10::DispatchKeySet) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "15000 build/../c10/util/intrusive_ptr.h:THPVariable_clear(THPVariable*)", + "15000 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_dealloc(THPVariable*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "15000 build/../torch/csrc/autograd/variable.cpp:torch::autograd::impl::try_get_grad_accumulator(at::Tensor const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "15000 build/../torch/csrc/utils/object_ptr.h:torch::FunctionSignature::parse(_object*, _object*, _object*, _object**, bool)", + "14042 /tmp/build/80754af9/python_1599604603603/work/Objects/moduleobject.c:module_getattro [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "14000 build/../c10/core/CPUAllocator.cpp:c10::DefaultCPUAllocator::ReportAndDelete(void*) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "14000 build/../c10/core/ScalarType.h:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda()", + "14000 build/../c10/core/TensorImpl.h:c10::TensorImpl::~TensorImpl() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "14000 build/../c10/util/Optional.h:at::ones(c10::ArrayRef, c10::TensorOptions const&)", + "14000 build/../c10/util/Optional.h:c10::computeDispatchKey(c10::optional, c10::optional, c10::optional)", + "14000 build/../c10/util/typeid.h:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda()", + "14000 build/../c10/util/typeid.h:at::native::is_complex(at::Tensor const&)", + "14000 build/aten/src/ATen/core/TensorBody.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "14000 build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::is_complex() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "13000 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "13000 build/../c10/core/TensorOptions.h:at::ones(c10::ArrayRef, c10::TensorOptions const&)", + "13000 build/../torch/csrc/Exceptions.cpp:torch::PyWarningHandler::PyWarningHandler() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "13000 build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgs::scalartype(int) [clone .isra.180] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "13000 build/../torch/csrc/utils/tensor_numpy.cpp:torch::utils::is_numpy_int(_object*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "12000 /tmp/build/80754af9/python_1599604603603/work/Modules/gcmodule.c:PyObject_GC_Del [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "12000 /usr/include/c++/8/bits/shared_ptr_base.h:at::RecordFunction::RecordFunction(at::RecordScope)", + "12000 /usr/include/c++/8/new:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", + "12000 /usr/include/c++/8/new:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", + "12000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:bool c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor const&) const", + "12000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "12000 build/../c10/core/TensorImpl.h:c10::TensorImpl::compute_contiguous() const", + "12000 build/../c10/util/SmallVector.h:c10::TensorImpl::set_sizes_contiguous(c10::ArrayRef)", + "12000 build/../c10/util/intrusive_ptr.h:c10::intrusive_ptr >::reset_() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "11011 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_SaveThread [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "11000 /home/test_user/miniconda3/envs/throwaway/include/pybind11/pybind11.h:pybind11::gil_scoped_release::~gil_scoped_release() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "11000 /usr/include/c++/8/bits/stl_algobase.h:torch::PythonArgs::intlist(int)", + "11000 /usr/include/c++/8/ext/new_allocator.h:torch::PythonArgs::intlist(int)", + "11000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const", + "11000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2", + "11000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "11000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", + "11000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "11000 build/../c10/util/Optional.h:c10::OptionalDeviceGuard::OptionalDeviceGuard(c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "11000 build/../torch/csrc/autograd/utils/wrap_outputs.h:torch::autograd::utils::wrap(at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "11000 build/../torch/csrc/jit/frontend/tracer.cpp:torch::jit::tracer::getTracingState() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "11000 build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "10000 build/../c10/core/CPUAllocator.cpp:c10::profiledCPUMemoryReporter() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "10000 build/../c10/util/Exception.cpp:c10::Warning::get_warning_handler() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "10000 build/../c10/util/intrusive_ptr.h:at::native::full(c10::ArrayRef, c10::Scalar, c10::TensorOptions const&)", + "10000 build/../c10/util/intrusive_ptr.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "10000 build/../c10/util/llvmMathExtras.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "9009 ???:pthread_mutex_unlock [/usr/lib64/libpthread-2.28.so]", + "9000 /usr/include/c++/8/new:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "9000 /usr/include/c++/8/new:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", + "9000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "9000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "9000 build/../c10/core/Device.h:at::native::fill_out(at::Tensor&, c10::Scalar)", + "9000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::release_resources() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "9000 build/../c10/core/TensorOptions.h:c10::TensorOptions::TensorOptions() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "9000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const", + "9000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2", + "9000 build/../c10/util/Optional.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "9000 build/../c10/util/intrusive_ptr.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "8000 /home/nwani/m3/conda-bld/compilers_linux-64_1560109574129/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/del_op.cc:operator delete(void*) [/home/test_user/miniconda3/envs/throwaway/lib/libstdc++.so.6.0.26]", + "8000 /usr/include/c++/8/bits/atomic_base.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "8000 /usr/include/c++/8/bits/unique_ptr.h:c10::TensorImpl::release_resources()", + "8000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor::is_complex() const", + "8000 build/../aten/src/ATen/detail/CPUGuardImpl.h:at::detail::CPUGuardImpl::getDevice() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "8000 build/../c10/core/CPUAllocator.cpp:c10::GetCPUAllocator() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "8000 build/../c10/core/DeviceGuard.h:c10::OptionalDeviceGuard::OptionalDeviceGuard(c10::optional)", + "8000 build/../c10/core/DispatchKeySet.h:c10::DispatchKeySet::has(c10::DispatchKey) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "8000 build/../c10/core/StorageImpl.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "8000 build/../c10/core/impl/DeviceGuardImplInterface.h:c10::impl::getDeviceGuardImpl(c10::DeviceType) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "8000 build/../c10/core/impl/VirtualGuardImpl.h:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&)", + "8000 build/../c10/util/Optional.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "8000 build/../c10/util/Optional.h:c10::TensorOptions::computeDispatchKey() const", + "8000 build/../c10/util/SmallVector.h:c10::TensorImpl::compute_contiguous() const", + "8000 build/../c10/util/SmallVector.h:c10::TensorImpl::~TensorImpl()", + "8000 build/../c10/util/llvmMathExtras.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "8000 build/../c10/util/llvmMathExtras.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "8000 build/aten/src/ATen/core/TensorBody.h:at::native::fill_out(at::Tensor&, c10::Scalar)", + "7035 /tmp/build/80754af9/python_1599604603603/work/Python/errors.c:PyErr_Occurred [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "7000 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:_PyObject_GetDictPtr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "7000 /usr/include/c++/8/bits/stl_numeric.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "7000 /usr/include/c++/8/bits/stl_vector.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "7000 build/../c10/core/Scalar.h:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda()", + "7000 build/../c10/core/ScalarType.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.469]", + "7000 build/../c10/core/StorageImpl.h:c10::TensorImpl::release_resources()", + "7000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "7000 build/../c10/core/impl/VirtualGuardImpl.h:c10::optional_base >::~optional_base()", + "7000 build/../c10/util/intrusive_ptr.h:torch::autograd::utils::wrap(at::Tensor)", + "7000 build/../c10/util/llvmMathExtras.h:at::Tensor::fill_(c10::Scalar) const", + "7000 build/../c10/util/llvmMathExtras.h:at::Tensor::is_complex() const", + "7000 build/aten/src/ATen/core/TensorBody.h:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda()", + "6026 /tmp/build/80754af9/python_1599604603603/work/Python/pystate.c:_PyThreadState_UncheckedGet [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "6006 /tmp/build/80754af9/python_1599604603603/work/Python/pystate.c:PyThreadState_Swap [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "6000 /home/nwani/m3/conda-bld/compilers_linux-64_1560109574129/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/del_ops.cc:operator delete(void*, unsigned long) [/home/test_user/miniconda3/envs/throwaway/lib/libstdc++.so.6.0.26]", + "6000 /usr/include/c++/8/bits/atomic_base.h:THPVariable_clear(THPVariable*)", + "6000 /usr/include/c++/8/bits/move.h:torch::PythonArgs::intlist(int)", + "6000 /usr/include/c++/8/bits/shared_ptr_base.h:c10::memoryProfilingEnabled()", + "6000 /usr/include/c++/8/bits/stl_iterator.h:torch::FunctionSignature::parse(_object*, _object*, _object*, _object**, bool)", + "6000 /usr/include/c++/8/bits/unique_ptr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "6000 /usr/include/c++/8/bits/unique_ptr.h:c10::TensorImpl::~TensorImpl()", + "6000 /usr/include/c++/8/tuple:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "6000 build/../aten/src/ATen/record_function.h:at::RecordFunction::RecordFunction(at::RecordScope)", + "6000 build/../c10/core/Allocator.cpp:c10::GetAllocator(c10::DeviceType const&) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "6000 build/../c10/core/Device.h:at::detail::CPUGuardImpl::getDevice() const", + "6000 build/../c10/core/DispatchKeySet.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", + "6000 build/../c10/core/DispatchKeySet.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "6000 build/../c10/core/ScalarType.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.484]", + "6000 build/../c10/core/TensorImpl.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", + "6000 build/../c10/core/TensorImpl.h:at::Tensor::device() const", + "6000 build/../c10/core/TensorOptions.h:torch::utils::maybe_initialize_cuda(c10::TensorOptions const&)", + "6000 build/../c10/util/Optional.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "6000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "6000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2", + "6000 build/../c10/util/TypeCast.h:float c10::checked_convert(double, char const*)", + "6000 build/../c10/util/intrusive_ptr.h:THPVariable_Wrap(at::Tensor)", + "6000 build/../c10/util/intrusive_ptr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "6000 build/../torch/csrc/tensor/python_tensor.cpp:torch::tensors::get_default_scalar_type() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "6000 build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgParser::raw_parse(_object*, _object*, _object*, _object**)", + "5000 /tmp/build/80754af9/python_1599604603603/_build_env/x86_64-conda_cos6-linux-gnu/sysroot/usr/include/bits/string3.h:PyType_GenericAlloc", + "5000 /usr/include/c++/8/bits/atomic_base.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "5000 build/../aten/src/ATen/DeviceGuard.h:at::TypeDefault::fill__Scalar(at::Tensor&, c10::Scalar)", + "5000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "5000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2", + "5000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::Tensor::fill_(c10::Scalar) const", + "5000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "5000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "5000 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::(anonymous namespace)::infer_full_options(c10::Scalar, c10::TensorOptions const&) [clone .isra.262] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "5000 build/../c10/core/Device.h:torch::PythonArgs::device(int)", + "5000 build/../c10/core/DispatchKeySet.h:at::Tensor::fill_(c10::Scalar) const", + "5000 build/../c10/core/DispatchKeySet.h:at::Tensor::is_complex() const", + "5000 build/../c10/core/TensorImpl.h:at::Tensor::is_quantized() const", + "5000 build/../c10/core/TensorOptions.h:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&)", + "5000 build/../c10/core/impl/LocalDispatchKeySet.cpp:c10::impl::ExcludeDispatchKeyGuard::~ExcludeDispatchKeyGuard() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "5000 build/../c10/util/Optional.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "5000 build/../c10/util/intrusive_ptr.h:c10::TensorImpl::release_resources()", + "5000 build/../torch/csrc/utils/cuda_lazy_init.h:torch::utils::maybe_initialize_cuda(c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "4004 ???:__errno_location [/usr/lib64/libpthread-2.28.so]", + "4000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", + "4000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", + "4000 /usr/include/c++/8/bits/atomic_base.h:at::native::full(c10::ArrayRef, c10::Scalar, c10::TensorOptions const&)", + "4000 /usr/include/c++/8/bits/atomic_base.h:c10::impl::getDeviceGuardImpl(c10::DeviceType)", + "4000 /usr/include/c++/8/bits/atomic_base.h:torch::autograd::make_variable(at::Tensor, bool, bool)", + "4000 /usr/include/c++/8/bits/move.h:c10::TensorImpl::release_resources()", + "4000 /usr/include/c++/8/bits/unique_ptr.h:c10::TensorImpl::set_autograd_meta(std::unique_ptr >) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "4000 /usr/include/c++/8/cmath:float c10::checked_convert(double, char const*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "4000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::Tensor::is_complex() const", + "4000 build/../c10/core/Allocator.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "4000 build/../c10/core/Device.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "4000 build/../c10/core/DeviceGuard.h:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&)", + "4000 build/../c10/core/DispatchKeySet.h:c10::impl::ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(c10::DispatchKeySet)", + "4000 build/../c10/core/TensorImpl.h:torch::autograd::impl::set_pyobj(at::Tensor const&, _object*)", + "4000 build/../c10/util/Optional.h:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "4000 build/../c10/util/Optional.h:at::TypeDefault::fill__Scalar(at::Tensor&, c10::Scalar)", + "4000 build/../c10/util/Optional.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.484]", + "4000 build/../c10/util/SmallVector.h:c10::TensorImpl::sizes() const [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "4000 build/../c10/util/UniqueVoidPtr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "4000 build/../c10/util/intrusive_ptr.h:THPVariable_NewWithVar(_typeobject*, at::Tensor)", + "4000 build/../torch/csrc/autograd/variable.cpp:torch::autograd::impl::pyobj(at::Tensor const&)", + "4000 build/../torch/csrc/autograd/variable.cpp:torch::autograd::impl::set_pyobj(at::Tensor const&, _object*)", + "4000 build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::device() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "3006 /tmp/build/80754af9/python_1599604603603/work/Objects/methodobject.c:_PyCFunction_FastCallKeywords [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "3006 /tmp/build/80754af9/python_1599604603603/work/Objects/obmalloc.c:PyObject_Free [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "3000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const", + "3000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2", + "3000 /usr/include/c++/8/array:bool c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor const&) const", + "3000 /usr/include/c++/8/bits/atomic_base.h:c10::intrusive_ptr::reset_()", + "3000 /usr/include/c++/8/bits/shared_ptr_base.h:THPVariable_clear(THPVariable*)", + "3000 /usr/include/c++/8/bits/stl_vector.h:torch::PyWarningHandler::PyWarningHandler()", + "3000 /usr/include/c++/8/bits/unique_ptr.h:torch::autograd::make_variable(at::Tensor, bool, bool)", + "3000 /usr/include/c++/8/ext/new_allocator.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "3000 /usr/include/c++/8/tuple:c10::DefaultCPUAllocator::allocate(unsigned long) const", + "3000 build/../aten/src/ATen/core/LegacyTypeDispatch.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "3000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_ >, bool (at::Tensor const&)>::call(c10::OperatorKernel*, at::Tensor const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "3000 build/../c10/core/Backend.h:torch::PythonArgs::device(int)", + "3000 build/../c10/core/Backend.h:torch::tensors::get_default_dispatch_key()", + "3000 build/../c10/core/Device.h:c10::DefaultCPUAllocator::allocate(unsigned long) const", + "3000 build/../c10/core/Device.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "3000 build/../c10/core/DispatchKeySet.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "3000 build/../c10/core/DispatchKeySet.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "3000 build/../c10/core/Scalar.h:at::native::ones(c10::ArrayRef, c10::TensorOptions const&)", + "3000 build/../c10/core/TensorImpl.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "3000 build/../c10/core/TensorImpl.h:c10::VariableVersion::VersionCounter::~VersionCounter() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "3000 build/../c10/util/Optional.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&)", + "3000 build/../c10/util/Optional.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.469]", + "3000 build/../c10/util/Optional.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", + "3000 build/../c10/util/intrusive_ptr.h:THPVariable_dealloc(THPVariable*)", + "3000 build/../c10/util/typeid.h:c10::TensorImpl::data() const", + "3000 build/../c10/util/typeid.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.469]", + "3000 build/../c10/util/typeid.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.484]", + "3000 build/../torch/csrc/utils/object_ptr.h:torch::PythonArgs::intlist(int)", + "3000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::PythonArgParser::check_deprecated(torch::FunctionSignature const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "3000 build/aten/src/ATen/core/TensorBody.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "3000 build/aten/src/ATen/core/TensorBody.h:torch::autograd::make_variable(at::Tensor, bool, bool)", + "2006 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_GenericGetAttr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", + "2000 /usr/include/c++/8/array:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const", + "2000 /usr/include/c++/8/bits/atomic_base.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", + "2000 /usr/include/c++/8/bits/move.h:c10::TensorImpl::set_autograd_meta(std::unique_ptr >)", + "2000 /usr/include/c++/8/bits/shared_ptr_base.h:torch::autograd::impl::try_get_grad_accumulator(at::Tensor const&)", + "2000 /usr/include/c++/8/bits/stl_vector.h:torch::FunctionSignature::parse(_object*, _object*, _object*, _object**, bool)", + "2000 /usr/include/c++/8/bits/stl_vector.h:torch::PyWarningHandler::~PyWarningHandler()", + "2000 /usr/include/c++/8/bits/stl_vector.h:torch::PythonArgParser::raw_parse(_object*, _object*, _object*, _object**)", + "2000 /usr/include/c++/8/new:c10::OptionalDeviceGuard::OptionalDeviceGuard(c10::optional)", + "2000 /usr/include/c++/8/new:c10::computeDispatchKey(c10::optional, c10::optional, c10::optional)", + "2000 /usr/include/c++/8/tuple:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "2000 build/../aten/src/ATen/Context.cpp:at::getCPUAllocator() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "2000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_ >, at::Tensor& (at::Tensor&, c10::Scalar)>::call(c10::OperatorKernel*, at::Tensor&, c10::Scalar)", + "2000 build/../aten/src/ATen/core/dispatch/OperatorEntry.h:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const", + "2000 build/../aten/src/ATen/detail/CPUGuardImpl.h:at::detail::CPUGuardImpl::uncheckedSetDevice(c10::Device) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "2000 build/../aten/src/TH/THAllocator.cpp:getTHDefaultAllocator() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "2000 build/../c10/core/Allocator.h:c10::DefaultCPUAllocator::allocate(unsigned long) const" + ] +} \ No newline at end of file diff --git a/test/benchmark_utils/test_benchmark_utils.py b/test/benchmark_utils/test_benchmark_utils.py new file mode 100644 index 0000000000000..e15a0fec338b7 --- /dev/null +++ b/test/benchmark_utils/test_benchmark_utils.py @@ -0,0 +1,955 @@ +import json +import os +import re +import textwrap +import timeit +from typing import Any, List, Tuple +import unittest + +import torch +import torch.utils.benchmark as benchmark_utils +from torch.testing._internal.common_utils import TestCase, run_tests, IS_SANDCASTLE, IS_WINDOWS, slowTest +from torch.testing._internal import expecttest +import numpy as np + + +CALLGRIND_ARTIFACTS: str = os.path.join( + os.path.split(os.path.abspath(__file__))[0], + "callgrind_artifacts.json" +) + + +def generate_callgrind_artifacts() -> None: + """Regenerate `callgrind_artifacts.json` + + Unlike the expect tests, regenerating callgrind counts will produce a + large diff since build directories and conda/pip directories are included + in the instruction string. It is also not 100% deterministic (due to jitter + from Python) and takes over a minute to run. As a result, running this + function is manual. + """ + print("Regenerating callgrind artifact.") + + stats_no_data = benchmark_utils.Timer( + "y = torch.ones(())" + ).collect_callgrind(number=1000) + + stats_with_data = benchmark_utils.Timer( + "y = torch.ones((1,))" + ).collect_callgrind(number=1000) + + user = os.getenv("USER") + + def to_entry(fn_counts): + return [f"{c} {fn.replace(f'/{user}/', '/test_user/')}" for c, fn in fn_counts] + + artifacts = { + "baseline_inclusive": to_entry(stats_no_data.baseline_inclusive_stats), + "baseline_exclusive": to_entry(stats_no_data.baseline_exclusive_stats), + "ones_no_data_inclusive": to_entry(stats_no_data.stmt_inclusive_stats), + "ones_no_data_exclusive": to_entry(stats_no_data.stmt_exclusive_stats), + "ones_with_data_inclusive": to_entry(stats_with_data.stmt_inclusive_stats), + "ones_with_data_exclusive": to_entry(stats_with_data.stmt_exclusive_stats), + } + + with open(CALLGRIND_ARTIFACTS, "wt") as f: + json.dump(artifacts, f, indent=4) + + +def load_callgrind_artifacts() -> Tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats]: + """Hermetic artifact to unit test Callgrind wrapper. + + In addition to collecting counts, this wrapper provides some facilities for + manipulating and displaying the collected counts. The results of several + measurements are stored in callgrind_artifacts.json. + + While FunctionCounts and CallgrindStats are pickleable, the artifacts for + testing are stored in raw string form for easier inspection and to avoid + baking any implementation details into the artifact itself. + """ + with open(CALLGRIND_ARTIFACTS, "rt") as f: + artifacts = json.load(f) + + pattern = re.compile(r"^\s*([0-9]+)\s(.+)$") + + def to_function_counts( + count_strings: List[str], + inclusive: bool + ) -> benchmark_utils.FunctionCounts: + data: List[benchmark_utils.FunctionCount] = [] + for cs in count_strings: + # Storing entries as f"{c} {fn}" rather than [c, fn] adds some work + # reviving the artifact, but it makes the json much easier to read. + match = pattern.search(cs) + assert match is not None + c, fn = match.groups() + data.append(benchmark_utils.FunctionCount(count=int(c), function=fn)) + + return benchmark_utils.FunctionCounts( + tuple(sorted(data, reverse=True)), + inclusive=inclusive) + + baseline_inclusive = to_function_counts(artifacts["baseline_inclusive"], True) + baseline_exclusive = to_function_counts(artifacts["baseline_exclusive"], False) + + stats_no_data = benchmark_utils.CallgrindStats( + benchmark_utils.TaskSpec("y = torch.ones(())", "pass"), + number_per_run=1000, + built_with_debug_symbols=True, + baseline_inclusive_stats=baseline_inclusive, + baseline_exclusive_stats=baseline_exclusive, + stmt_inclusive_stats=to_function_counts(artifacts["ones_no_data_inclusive"], True), + stmt_exclusive_stats=to_function_counts(artifacts["ones_no_data_exclusive"], False), + ) + + stats_with_data = benchmark_utils.CallgrindStats( + benchmark_utils.TaskSpec("y = torch.ones((1,))", "pass"), + number_per_run=1000, + built_with_debug_symbols=True, + baseline_inclusive_stats=baseline_inclusive, + baseline_exclusive_stats=baseline_exclusive, + stmt_inclusive_stats=to_function_counts(artifacts["ones_with_data_inclusive"], True), + stmt_exclusive_stats=to_function_counts(artifacts["ones_with_data_exclusive"], False), + ) + + return stats_no_data, stats_with_data + + +class MyModule(torch.nn.Module): + def forward(self, x): + return x + 1 + + +class TestBenchmarkUtils(TestCase): + def regularizeAndAssertExpectedInline( + self, x: Any, + expect: str, + indent: int = 12 + ) -> None: + x_str: str = re.sub( + "object at 0x[0-9a-fA-F]+>", + "object at 0xXXXXXXXXXXXX>", + x if isinstance(x, str) else repr(x) + ) + if "\n" in x_str: + # Indent makes the reference align at the call site. + x_str = textwrap.indent(x_str, " " * indent) + + self.assertExpectedInline(x_str, expect, skip=1) + + def test_timer(self): + timer = benchmark_utils.Timer( + stmt="torch.ones(())", + ) + sample = timer.timeit(5).median + self.assertIsInstance(sample, float) + + median = timer.blocked_autorange(min_run_time=0.01).median + self.assertIsInstance(median, float) + + # We set a very high threshold to avoid flakiness in CI. + # The internal algorithm is tested in `test_adaptive_timer` + median = timer.adaptive_autorange(threshold=0.5).median + + # Test that multi-line statements work properly. + median = benchmark_utils.Timer( + stmt=""" + with torch.no_grad(): + y = x + 1""", + setup=""" + x = torch.ones((1,), requires_grad=True) + for _ in range(5): + x = x + 1.0""", + ).timeit(5).median + self.assertIsInstance(sample, float) + + @slowTest + @unittest.skipIf(IS_SANDCASTLE, "C++ timing is OSS only.") + def test_cpp_timer(self): + timer = benchmark_utils.Timer( + "torch::Tensor y = x + 1;", + setup="torch::Tensor x = torch::empty({1});", + timer=timeit.default_timer, + language=benchmark_utils.Language.CPP, + ) + t = timer.timeit(10) + self.assertIsInstance(t.median, float) + + class _MockTimer: + _seed = 0 + + _timer_noise_level = 0.05 + _timer_cost = 100e-9 # 100 ns + + _function_noise_level = 0.05 + _function_costs = ( + ("pass", 8e-9), + ("cheap_fn()", 4e-6), + ("expensive_fn()", 20e-6), + ("with torch.no_grad():\n y = x + 1", 10e-6), + ) + + def __init__(self, stmt, setup, timer, globals): + self._random_state = np.random.RandomState(seed=self._seed) + self._mean_cost = {k: v for k, v in self._function_costs}[stmt] + + def sample(self, mean, noise_level): + return max(self._random_state.normal(mean, mean * noise_level), 5e-9) + + def timeit(self, number): + return sum([ + # First timer invocation + self.sample(self._timer_cost, self._timer_noise_level), + + # Stmt body + self.sample(self._mean_cost * number, self._function_noise_level), + + # Second timer invocation + self.sample(self._timer_cost, self._timer_noise_level), + ]) + + def test_adaptive_timer(self): + class MockTimer(benchmark_utils.Timer): + _timer_cls = self._MockTimer + + class _MockCudaTimer(self._MockTimer): + # torch.cuda.synchronize is much more expensive than + # just timeit.default_timer + _timer_cost = 10e-6 + + _function_costs = ( + self._MockTimer._function_costs[0], + self._MockTimer._function_costs[1], + + # GPU should be faster once there is enough work. + ("expensive_fn()", 5e-6), + ) + + class MockCudaTimer(benchmark_utils.Timer): + _timer_cls = _MockCudaTimer + + m = MockTimer("pass").blocked_autorange(min_run_time=10) + self.regularizeAndAssertExpectedInline( + m, + """\ + + pass + Median: 7.98 ns + IQR: 0.52 ns (7.74 to 8.26) + 125 measurements, 10000000 runs per measurement, 1 thread""" + ) + + self.regularizeAndAssertExpectedInline( + MockTimer("pass").adaptive_autorange(), + """\ + + pass + Median: 7.86 ns + IQR: 0.71 ns (7.63 to 8.34) + 6 measurements, 1000000 runs per measurement, 1 thread""" + ) + + # Check against strings so we can reuse expect infra. + self.regularizeAndAssertExpectedInline(m.mean, """8.001365835795602e-09""") + self.regularizeAndAssertExpectedInline(m.median, """7.983151323215967e-09""") + self.regularizeAndAssertExpectedInline(len(m.times), """125""") + self.regularizeAndAssertExpectedInline(m.number_per_run, """10000000""") + + self.regularizeAndAssertExpectedInline( + MockTimer("cheap_fn()").blocked_autorange(min_run_time=10), + """\ + + cheap_fn() + Median: 3.98 us + IQR: 0.27 us (3.85 to 4.12) + 252 measurements, 10000 runs per measurement, 1 thread""" + ) + + self.regularizeAndAssertExpectedInline( + MockTimer("cheap_fn()").adaptive_autorange(), + """\ + + cheap_fn() + Median: 4.16 us + IQR: 0.22 us (4.04 to 4.26) + 4 measurements, 1000 runs per measurement, 1 thread""" + ) + + self.regularizeAndAssertExpectedInline( + MockTimer("expensive_fn()").blocked_autorange(min_run_time=10), + """\ + + expensive_fn() + Median: 19.97 us + IQR: 1.35 us (19.31 to 20.65) + 501 measurements, 1000 runs per measurement, 1 thread""" + ) + + self.regularizeAndAssertExpectedInline( + MockTimer("expensive_fn()").adaptive_autorange(), + """\ + + expensive_fn() + Median: 20.79 us + IQR: 1.09 us (20.20 to 21.29) + 4 measurements, 1000 runs per measurement, 1 thread""" + ) + + self.regularizeAndAssertExpectedInline( + MockCudaTimer("pass").blocked_autorange(min_run_time=10), + """\ + + pass + Median: 7.92 ns + IQR: 0.43 ns (7.75 to 8.17) + 13 measurements, 100000000 runs per measurement, 1 thread""" + ) + + self.regularizeAndAssertExpectedInline( + MockCudaTimer("pass").adaptive_autorange(), + """\ + + pass + Median: 7.75 ns + IQR: 0.57 ns (7.56 to 8.13) + 4 measurements, 10000000 runs per measurement, 1 thread""" + ) + + self.regularizeAndAssertExpectedInline( + MockCudaTimer("cheap_fn()").blocked_autorange(min_run_time=10), + """\ + + cheap_fn() + Median: 4.04 us + IQR: 0.30 us (3.90 to 4.19) + 25 measurements, 100000 runs per measurement, 1 thread""" + ) + + self.regularizeAndAssertExpectedInline( + MockCudaTimer("cheap_fn()").adaptive_autorange(), + """\ + + cheap_fn() + Median: 4.09 us + IQR: 0.38 us (3.90 to 4.28) + 4 measurements, 100000 runs per measurement, 1 thread""" + ) + + self.regularizeAndAssertExpectedInline( + MockCudaTimer("expensive_fn()").blocked_autorange(min_run_time=10), + """\ + + expensive_fn() + Median: 4.98 us + IQR: 0.31 us (4.83 to 5.13) + 20 measurements, 100000 runs per measurement, 1 thread""" + ) + + self.regularizeAndAssertExpectedInline( + MockCudaTimer("expensive_fn()").adaptive_autorange(), + """\ + + expensive_fn() + Median: 5.01 us + IQR: 0.28 us (4.87 to 5.15) + 4 measurements, 10000 runs per measurement, 1 thread""" + ) + + # Make sure __repr__ is reasonable for + # multi-line / label / sub_label / description, but we don't need to + # check numerics. + multi_line_stmt = """ + with torch.no_grad(): + y = x + 1 + """ + + self.regularizeAndAssertExpectedInline( + MockTimer(multi_line_stmt).blocked_autorange(), + """\ + + stmt: + with torch.no_grad(): + y = x + 1 + + Median: 10.06 us + IQR: 0.54 us (9.73 to 10.27) + 20 measurements, 1000 runs per measurement, 1 thread""" + ) + + self.regularizeAndAssertExpectedInline( + MockTimer(multi_line_stmt, sub_label="scalar_add").blocked_autorange(), + """\ + + stmt: (scalar_add) + with torch.no_grad(): + y = x + 1 + + Median: 10.06 us + IQR: 0.54 us (9.73 to 10.27) + 20 measurements, 1000 runs per measurement, 1 thread""" + ) + + self.regularizeAndAssertExpectedInline( + MockTimer( + multi_line_stmt, + label="x + 1 (no grad)", + sub_label="scalar_add", + ).blocked_autorange(), + """\ + + x + 1 (no grad): scalar_add + Median: 10.06 us + IQR: 0.54 us (9.73 to 10.27) + 20 measurements, 1000 runs per measurement, 1 thread""" + ) + + self.regularizeAndAssertExpectedInline( + MockTimer( + multi_line_stmt, + setup="setup_fn()", + sub_label="scalar_add", + ).blocked_autorange(), + """\ + + stmt: (scalar_add) + with torch.no_grad(): + y = x + 1 + + setup: setup_fn() + Median: 10.06 us + IQR: 0.54 us (9.73 to 10.27) + 20 measurements, 1000 runs per measurement, 1 thread""" + ) + + self.regularizeAndAssertExpectedInline( + MockTimer( + multi_line_stmt, + setup=""" + x = torch.ones((1,), requires_grad=True) + for _ in range(5): + x = x + 1.0""", + sub_label="scalar_add", + description="Multi-threaded scalar math!", + num_threads=16, + ).blocked_autorange(), + """\ + + stmt: (scalar_add) + with torch.no_grad(): + y = x + 1 + + Multi-threaded scalar math! + setup: + x = torch.ones((1,), requires_grad=True) + for _ in range(5): + x = x + 1.0 + + Median: 10.06 us + IQR: 0.54 us (9.73 to 10.27) + 20 measurements, 1000 runs per measurement, 16 threads""" + ) + + @slowTest + @unittest.skipIf(IS_WINDOWS, "Valgrind is not supported on Windows.") + @unittest.skipIf(IS_SANDCASTLE, "Valgrind is OSS only.") + def test_collect_callgrind(self): + with self.assertRaisesRegex( + ValueError, + r"`collect_callgrind` requires that globals be wrapped " + r"in `CopyIfCallgrind` so that serialization is explicit." + ): + benchmark_utils.Timer( + "pass", + globals={"x": 1} + ).collect_callgrind(collect_baseline=False) + + with self.assertRaisesRegex( + # Subprocess raises AttributeError (from pickle), + # _ValgrindWrapper re-raises as generic OSError. + OSError, "AttributeError: Can't get attribute 'MyModule'" + ): + benchmark_utils.Timer( + "model(1)", + globals={"model": benchmark_utils.CopyIfCallgrind(MyModule())} + ).collect_callgrind(collect_baseline=False) + + + @torch.jit.script + def add_one(x): + return x + 1 + + timer = benchmark_utils.Timer( + "y = add_one(x) + k", + setup="x = torch.ones((1,))", + globals={ + "add_one": benchmark_utils.CopyIfCallgrind(add_one), + "k": benchmark_utils.CopyIfCallgrind(5), + "model": benchmark_utils.CopyIfCallgrind( + MyModule(), + setup=f"""\ + import sys + sys.path.append({repr(os.path.split(os.path.abspath(__file__))[0])}) + from test_benchmark_utils import MyModule + """ + ) + } + ) + + # Don't collect baseline to speed up unit test by ~30 seconds. + stats = timer.collect_callgrind(number=1000, collect_baseline=False) + counts = stats.counts(denoise=False) + + self.assertIsInstance(counts, int) + self.assertGreater(counts, 0) + + from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import wrapper_singleton + self.assertIsNone( + wrapper_singleton()._bindings_module, + "JIT'd bindings are only for back testing." + ) + + @slowTest + @unittest.skipIf(IS_WINDOWS, "Valgrind is not supported on Windows.") + @unittest.skipIf(IS_SANDCASTLE, "Valgrind is OSS only.") + def test_collect_cpp_callgrind(self): + timer = benchmark_utils.Timer( + "x += 1;", + setup="torch::Tensor x = torch::ones({1});", + timer=timeit.default_timer, + language="c++", + ) + stats = [ + timer.collect_callgrind() + for _ in range(3) + ] + counts = [s.counts() for s in stats] + + self.assertGreater( + min(counts), 0, "No stats were collected") + self.assertEqual( + min(counts), max(counts), "C++ Callgrind should be deterministic") + + for s in stats: + self.assertEqual( + s.counts(denoise=True), s.counts(denoise=False), + "De-noising should not apply to C++.") + + def test_manipulate_callgrind_stats(self): + stats_no_data, stats_with_data = load_callgrind_artifacts() + + # Mock `torch.set_printoptions(linewidth=160)` + wide_linewidth = benchmark_utils.FunctionCounts( + stats_no_data.stats(inclusive=False)._data, False, 160) + + for l in repr(wide_linewidth).splitlines(keepends=False): + self.assertLessEqual(len(l), 160) + + self.assertEqual( + # `delta` is just a convenience method. + stats_with_data.delta(stats_no_data)._data, + (stats_with_data.stats() - stats_no_data.stats())._data + ) + + deltas = stats_with_data.as_standardized().delta(stats_no_data.as_standardized()) + + def custom_transforms(fn: str): + fn = re.sub(re.escape("/usr/include/c++/8/bits/"), "", fn) + fn = re.sub(r"build/../", "", fn) + fn = re.sub(".+" + re.escape("libsupc++"), "libsupc++", fn) + return fn + + self.regularizeAndAssertExpectedInline( + stats_no_data, + """\ + + y = torch.ones(()) + All Noisy symbols removed + Instructions: 8869966 8728096 + Baseline: 6682 5766 + 1000 runs per measurement, 1 thread""", + ) + + self.regularizeAndAssertExpectedInline( + stats_no_data.counts(), + """8869966""", + ) + + self.regularizeAndAssertExpectedInline( + stats_no_data.counts(denoise=True), + """8728096""", + ) + + self.regularizeAndAssertExpectedInline( + stats_no_data.stats(), + """\ + + 408000 ???:__tls_get_addr [/usr/lib64/ld-2.28.so] + 388193 ???:_int_free [/usr/lib64/libc-2.28.so] + 274000 build/../torch/csrc/utils/python ... rch/torch/lib/libtorch_python.so] + 264000 build/../aten/src/ATen/record_fu ... ytorch/torch/lib/libtorch_cpu.so] + 192000 build/../c10/core/Device.h:c10:: ... epos/pytorch/torch/lib/libc10.so] + 169855 ???:_int_malloc [/usr/lib64/libc-2.28.so] + 154000 build/../c10/core/TensorOptions. ... ytorch/torch/lib/libtorch_cpu.so] + 147167 /tmp/build/80754af9/python_15996 ... da3/envs/throwaway/bin/python3.6] + 135000 ???:malloc [/usr/lib64/libc-2.28.so] + ... + -62 /tmp/build/80754af9/python_15996 ... da3/envs/throwaway/bin/python3.6] + -63 /tmp/build/80754af9/python_15996 ... da3/envs/throwaway/bin/python3.6] + -70 /tmp/build/80754af9/python_15996 ... da3/envs/throwaway/bin/python3.6] + -74 /tmp/build/80754af9/python_15996 ... da3/envs/throwaway/bin/python3.6] + -85 /home/test_user/miniconda3/envs/ ... rch/torch/lib/libtorch_python.so] + -95 /tmp/build/80754af9/python_15996 ... da3/envs/throwaway/bin/python3.6] + -104 /tmp/build/80754af9/python_15996 ... da3/envs/throwaway/bin/python3.6] + -134 /tmp/build/80754af9/python_15996 ... da3/envs/throwaway/bin/python3.6] + -180 /tmp/build/80754af9/python_15996 ... da3/envs/throwaway/bin/python3.6] + + Total: 8863284""", + ) + + self.regularizeAndAssertExpectedInline( + stats_no_data.stats(inclusive=True), + """\ + + 8952420 ???:0x0000000000001050 [/usr/lib64/ld-2.28.so] + 8952420 ???:(below main) [/usr/lib64/libc-2.28.so] + 8952420 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] + 8952420 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] + 8952420 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] + 8952420 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] + 8952420 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] + 8952420 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] + 8952420 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] + ... + -195 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] + -196 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] + -207 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] + -261 /home/test_user/miniconda3/envs/ ... ch/torch/lib/libtorch_python.so] + -561 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] + -789 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] + -881 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] + -1196 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] + -1206 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6]""", + ) + + self.regularizeAndAssertExpectedInline( + wide_linewidth, + """\ + + 408000 ???:__tls_get_addr [/usr/lib64/ld-2.28.so] + 388193 ???:_int_free [/usr/lib64/libc-2.28.so] + 274000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::FunctionSignature ... bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so] + 264000 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::RecordFun ... ordScope) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so] + 192000 build/../c10/core/Device.h:c10::Device::validate() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so] + 169855 ???:_int_malloc [/usr/lib64/libc-2.28.so] + 154000 build/../c10/core/TensorOptions.h:c10::TensorOptions::merge_in(c10::Tens ... ns) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so] + 147167 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:_PyEval_EvalFrameDefault [/home/test_user/miniconda3/envs/throwaway/bin/python3.6] + 135000 ???:malloc [/usr/lib64/libc-2.28.so] + ... + -62 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:PyNumber_Subtract [/home/test_user/miniconda3/envs/throwaway/bin/python3.6] + -63 /tmp/build/80754af9/python_1599604603603/work/Objects/longobject.c:long_richcompare [/home/test_user/miniconda3/envs/throwaway/bin/python3.6] + -70 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:_PyObject_FastCallDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6] + -74 /tmp/build/80754af9/python_1599604603603/work/Python/pytime.c:_PyTime_FromSecondsObject [/home/test_user/miniconda3/envs/throwaway/bin/python3.6] + -85 /home/test_user/miniconda3/envs/throwaway/include/pybind11/pybind11.h:py ... ject*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so] + -95 /tmp/build/80754af9/python_1599604603603/work/Objects/rangeobject.c:rangeiter_next [/home/test_user/miniconda3/envs/throwaway/bin/python3.6] + -104 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_RichCompare [/home/test_user/miniconda3/envs/throwaway/bin/python3.6] + -134 /tmp/build/80754af9/python_1599604603603/work/Objects/rangeobject.c:range_new [/home/test_user/miniconda3/envs/throwaway/bin/python3.6] + -180 /tmp/build/80754af9/python_1599604603603/work/Objects/longobject.c:PyLong_FromLong [/home/test_user/miniconda3/envs/throwaway/bin/python3.6] + + Total: 8863284""" # noqa + ) + + self.regularizeAndAssertExpectedInline( + stats_no_data.as_standardized().stats(), + """\ + + 408000 ???:__tls_get_addr + 388193 ???:_int_free + 274000 build/../torch/csrc/utils/python ... ject*, _object*, _object**, bool) + 264000 build/../aten/src/ATen/record_fu ... ::RecordFunction(at::RecordScope) + 192000 build/../c10/core/Device.h:c10::Device::validate() + 169855 ???:_int_malloc + 154000 build/../c10/core/TensorOptions. ... erge_in(c10::TensorOptions) const + 147167 Python/ceval.c:_PyEval_EvalFrameDefault + 135000 ???:malloc + ... + -62 Objects/abstract.c:PyNumber_Subtract + -63 Objects/longobject.c:long_richcompare + -70 Objects/abstract.c:_PyObject_FastCallDict + -74 Python/pytime.c:_PyTime_FromSecondsObject + -85 /home/test_user/miniconda3/envs/ ... her(_object*, _object*, _object*) + -95 Objects/rangeobject.c:rangeiter_next + -104 Objects/object.c:PyObject_RichCompare + -134 Objects/rangeobject.c:range_new + -180 Objects/longobject.c:PyLong_FromLong + + Total: 8863284""", + ) + + self.regularizeAndAssertExpectedInline( + deltas, + """\ + + 85000 Objects/dictobject.c:lookdict_unicode + 59089 ???:_int_free + 43000 ???:malloc + 25000 build/../torch/csrc/utils/python ... :torch::PythonArgs::intlist(int) + 24000 ???:__tls_get_addr + 23000 ???:free + 21067 Objects/dictobject.c:lookdict_unicode_nodummy + 20000 build/../torch/csrc/utils/python ... :torch::PythonArgs::intlist(int) + 18000 Objects/longobject.c:PyLong_AsLongLongAndOverflow + ... + 2000 /home/nwani/m3/conda-bld/compile ... del_op.cc:operator delete(void*) + 1000 /usr/include/c++/8/bits/stl_vector.h:torch::PythonArgs::intlist(int) + 193 ???:_int_malloc + 75 ???:_int_memalign + -1000 build/../c10/util/SmallVector.h: ... _contiguous(c10::ArrayRef) + -1000 build/../c10/util/SmallVector.h: ... nsor_restride(c10::MemoryFormat) + -1000 /usr/include/c++/8/bits/stl_vect ... es(_object*, _object*, _object*) + -8000 Python/ceval.c:_PyEval_EvalFrameDefault + -16000 Objects/tupleobject.c:PyTuple_New + + Total: 432917""", + ) + + self.regularizeAndAssertExpectedInline(len(deltas), """35""") + + self.regularizeAndAssertExpectedInline( + deltas.transform(custom_transforms), + """\ + + 85000 Objects/dictobject.c:lookdict_unicode + 59089 ???:_int_free + 43000 ???:malloc + 25000 torch/csrc/utils/python_numbers.h:torch::PythonArgs::intlist(int) + 24000 ???:__tls_get_addr + 23000 ???:free + 21067 Objects/dictobject.c:lookdict_unicode_nodummy + 20000 torch/csrc/utils/python_arg_parser.h:torch::PythonArgs::intlist(int) + 18000 Objects/longobject.c:PyLong_AsLongLongAndOverflow + ... + 2000 c10/util/SmallVector.h:c10::TensorImpl::compute_contiguous() const + 1000 stl_vector.h:torch::PythonArgs::intlist(int) + 193 ???:_int_malloc + 75 ???:_int_memalign + -1000 stl_vector.h:torch::autograd::TH ... es(_object*, _object*, _object*) + -1000 c10/util/SmallVector.h:c10::Tens ... _contiguous(c10::ArrayRef) + -1000 c10/util/SmallVector.h:c10::Tens ... nsor_restride(c10::MemoryFormat) + -8000 Python/ceval.c:_PyEval_EvalFrameDefault + -16000 Objects/tupleobject.c:PyTuple_New + + Total: 432917""", + ) + + self.regularizeAndAssertExpectedInline( + deltas.filter(lambda fn: fn.startswith("???")), + """\ + + 59089 ???:_int_free + 43000 ???:malloc + 24000 ???:__tls_get_addr + 23000 ???:free + 193 ???:_int_malloc + 75 ???:_int_memalign + + Total: 149357""", + ) + + self.regularizeAndAssertExpectedInline( + deltas[:5], + """\ + + 85000 Objects/dictobject.c:lookdict_unicode + 59089 ???:_int_free + 43000 ???:malloc + 25000 build/../torch/csrc/utils/python_ ... h:torch::PythonArgs::intlist(int) + 24000 ???:__tls_get_addr + + Total: 236089""", + ) + + def test_compare(self): + # Simulate several approaches. + costs = ( + # overhead_optimized_fn() + (1e-6, 1e-9), + + # compute_optimized_fn() + (3e-6, 5e-10), + + # special_case_fn() [square inputs only] + (1e-6, 4e-10), + ) + + sizes = ( + (16, 16), + (16, 128), + (128, 128), + (4096, 1024), + (2048, 2048), + ) + + # overhead_optimized_fn() + class _MockTimer_0(self._MockTimer): + _function_costs = tuple( + (f"fn({i}, {j})", costs[0][0] + costs[0][1] * i * j) + for i, j in sizes + ) + + class MockTimer_0(benchmark_utils.Timer): + _timer_cls = _MockTimer_0 + + # compute_optimized_fn() + class _MockTimer_1(self._MockTimer): + _function_costs = tuple( + (f"fn({i}, {j})", costs[1][0] + costs[1][1] * i * j) + for i, j in sizes + ) + + class MockTimer_1(benchmark_utils.Timer): + _timer_cls = _MockTimer_1 + + # special_case_fn() + class _MockTimer_2(self._MockTimer): + _function_costs = tuple( + (f"fn({i}, {j})", costs[2][0] + costs[2][1] * i * j) + for i, j in sizes if i == j + ) + + class MockTimer_2(benchmark_utils.Timer): + _timer_cls = _MockTimer_2 + + results = [] + for i, j in sizes: + results.append( + MockTimer_0( + f"fn({i}, {j})", + label="fn", + description=f"({i}, {j})", + sub_label="overhead_optimized", + ).blocked_autorange(min_run_time=10) + ) + + results.append( + MockTimer_1( + f"fn({i}, {j})", + label="fn", + description=f"({i}, {j})", + sub_label="compute_optimized", + ).blocked_autorange(min_run_time=10) + ) + + if i == j: + results.append( + MockTimer_2( + f"fn({i}, {j})", + label="fn", + description=f"({i}, {j})", + sub_label="special_case (square)", + ).blocked_autorange(min_run_time=10) + ) + + def rstrip_lines(s: str) -> str: + # VSCode will rstrip the `expected` string literal whether you like + # it or not. So we have to rstrip the compare table as well. + return "\n".join([i.rstrip() for i in s.splitlines(keepends=False)]) + + compare = benchmark_utils.Compare(results) + self.regularizeAndAssertExpectedInline( + rstrip_lines(str(compare).strip()), + """\ + [------------------------------------------------- fn ------------------------------------------------] + | (16, 16) | (16, 128) | (128, 128) | (4096, 1024) | (2048, 2048) + 1 threads: -------------------------------------------------------------------------------------------- + overhead_optimized | 1.3 | 3.0 | 17.4 | 4174.4 | 4174.4 + compute_optimized | 3.1 | 4.0 | 11.2 | 2099.3 | 2099.3 + special_case (square) | 1.1 | | 7.5 | | 1674.7 + + Times are in microseconds (us).""" + ) + + compare.trim_significant_figures() + self.regularizeAndAssertExpectedInline( + rstrip_lines(str(compare).strip()), + """\ + [------------------------------------------------- fn ------------------------------------------------] + | (16, 16) | (16, 128) | (128, 128) | (4096, 1024) | (2048, 2048) + 1 threads: -------------------------------------------------------------------------------------------- + overhead_optimized | 1 | 3.0 | 17 | 4200 | 4200 + compute_optimized | 3 | 4.0 | 11 | 2100 | 2100 + special_case (square) | 1 | | 8 | | 1700 + + Times are in microseconds (us).""" + ) + + compare.colorize() + columnwise_colored_actual = rstrip_lines(str(compare).strip()) + columnwise_colored_expected = textwrap.dedent( + """\ + [------------------------------------------------- fn ------------------------------------------------] + | (16, 16) | (16, 128) | (128, 128) | (4096, 1024) | (2048, 2048) + 1 threads: -------------------------------------------------------------------------------------------- + overhead_optimized | 1 | \x1b[92m\x1b[1m 3.0 \x1b[0m\x1b[0m | \x1b[2m\x1b[91m 17 \x1b[0m\x1b[0m | 4200 | \x1b[2m\x1b[91m 4200 \x1b[0m\x1b[0m + compute_optimized | \x1b[2m\x1b[91m 3 \x1b[0m\x1b[0m | 4.0 | 11 | \x1b[92m\x1b[1m 2100 \x1b[0m\x1b[0m | 2100 + special_case (square) | \x1b[92m\x1b[1m 1 \x1b[0m\x1b[0m | | \x1b[92m\x1b[1m 8 \x1b[0m\x1b[0m | | \x1b[92m\x1b[1m 1700 \x1b[0m\x1b[0m + + Times are in microseconds (us).""" # noqa + ) + + compare.colorize(rowwise=True) + rowwise_colored_actual = rstrip_lines(str(compare).strip()) + rowwise_colored_expected = textwrap.dedent( + """\ + [------------------------------------------------- fn ------------------------------------------------] + | (16, 16) | (16, 128) | (128, 128) | (4096, 1024) | (2048, 2048) + 1 threads: -------------------------------------------------------------------------------------------- + overhead_optimized | \x1b[92m\x1b[1m 1 \x1b[0m\x1b[0m | \x1b[2m\x1b[91m 3.0 \x1b[0m\x1b[0m | \x1b[31m\x1b[1m 17 \x1b[0m\x1b[0m | \x1b[31m\x1b[1m 4200 \x1b[0m\x1b[0m | \x1b[31m\x1b[1m 4200 \x1b[0m\x1b[0m + compute_optimized | \x1b[92m\x1b[1m 3 \x1b[0m\x1b[0m | 4.0 | \x1b[2m\x1b[91m 11 \x1b[0m\x1b[0m | \x1b[31m\x1b[1m 2100 \x1b[0m\x1b[0m | \x1b[31m\x1b[1m 2100 \x1b[0m\x1b[0m + special_case (square) | \x1b[92m\x1b[1m 1 \x1b[0m\x1b[0m | | \x1b[31m\x1b[1m 8 \x1b[0m\x1b[0m | | \x1b[31m\x1b[1m 1700 \x1b[0m\x1b[0m + + Times are in microseconds (us).""" # noqa + ) + + def print_new_expected(s: str) -> None: + print(f'{"":>12}"""\\', end="") + for l in s.splitlines(keepends=False): + print("\n" + textwrap.indent(repr(l)[1:-1], " " * 12), end="") + print('"""\n') + + if expecttest.ACCEPT: + # expecttest does not currently support non-printable characters, + # so these two entries have to be updated manually. + if columnwise_colored_actual != columnwise_colored_expected: + print("New columnwise coloring:\n") + print_new_expected(columnwise_colored_actual) + + if rowwise_colored_actual != rowwise_colored_expected: + print("New rowwise coloring:\n") + print_new_expected(rowwise_colored_actual) + + self.assertEqual(columnwise_colored_actual, columnwise_colored_expected) + self.assertEqual(rowwise_colored_actual, rowwise_colored_expected) + + @unittest.skipIf(IS_WINDOWS and os.getenv("VC_YEAR") == "2019", "Random seed only accepts int32") + def test_fuzzer(self): + fuzzer = benchmark_utils.Fuzzer( + parameters=[ + benchmark_utils.FuzzedParameter( + "n", minval=1, maxval=16, distribution="loguniform")], + tensors=[benchmark_utils.FuzzedTensor("x", size=("n",))], + seed=0, + ) + + expected_results = [ + (0.7821, 0.0536, 0.9888, 0.1949, 0.5242, 0.1987, 0.5094), + (0.7166, 0.5961, 0.8303, 0.005), + ] + + for i, (tensors, _, _) in enumerate(fuzzer.take(2)): + x = tensors["x"] + self.assertEqual( + x, torch.Tensor(expected_results[i]), rtol=1e-3, atol=1e-3) + + +if __name__ == '__main__': + run_tests() diff --git a/test/cpp/api/CMakeLists.txt b/test/cpp/api/CMakeLists.txt index e9bcae7f47a5b..3e8467be97344 100644 --- a/test/cpp/api/CMakeLists.txt +++ b/test/cpp/api/CMakeLists.txt @@ -14,6 +14,7 @@ set(TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/memory.cpp ${TORCH_API_TEST_DIR}/misc.cpp ${TORCH_API_TEST_DIR}/module.cpp + ${TORCH_API_TEST_DIR}/moduledict.cpp ${TORCH_API_TEST_DIR}/modulelist.cpp ${TORCH_API_TEST_DIR}/modules.cpp ${TORCH_API_TEST_DIR}/parameterdict.cpp diff --git a/test/cpp/api/autograd.cpp b/test/cpp/api/autograd.cpp index 8a8aa75541ac0..3f79c771c2be3 100644 --- a/test/cpp/api/autograd.cpp +++ b/test/cpp/api/autograd.cpp @@ -7,6 +7,7 @@ #include using namespace torch::autograd; +using namespace torch::test; #define ASSERT_VARIABLE_EQ(a,b) ASSERT_TRUE(torch::allclose((a),(b))) #define EXPECT_VARIABLE_EQ(a,b) EXPECT_TRUE(torch::allclose((a),(b))) @@ -154,6 +155,39 @@ TEST(AutogradAPITests, RetainGrad) { ASSERT_VARIABLE_EQ(input * 18, input.grad()); } +TEST(AutogradAPITests, AnomalyMode) { + // Needs to have backtrace as warning and then throw an error + torch::autograd::DetectAnomalyGuard detect_anomaly; + { + WarningCapture warnings; + auto x = torch::tensor({5.0}, torch::requires_grad()); + auto y = x * x; + auto z = y * y; + y += 1; + ASSERT_THROWS_WITH(z.backward(), "inplace"); + ASSERT_TRUE( + warnings.str().find("Traceback of forward") != std::string::npos); + } + { + WarningCapture warnings; + // Double backward + auto x = torch::tensor({0.0}, torch::requires_grad()); + auto y = x.pow(1.5); + auto gr = + grad({y}, {x}, {}, /*retain_graph=*/true, /*create_backward=*/true); + ASSERT_THROWS_WITH(grad({gr[0]}, {x}, {torch::tensor({0.0})});, "returned nan"); + auto msgs = warnings.messages(); + ASSERT_EQ(msgs.size(), 2); + ASSERT_TRUE( + msgs[0].find("Traceback of forward call that caused the error") != + std::string::npos); + ASSERT_TRUE( + msgs[1].find( + "Traceback of forward call that induced the previous calculation") != + std::string::npos); + } +} + TEST(CustomAutogradTest, CustomFunction) { struct MyFunction : public Function { static Variable forward(AutogradContext *ctx, Variable var1, int mul, Variable var2) { @@ -211,7 +245,7 @@ TEST(CustomAutogradTest, FunctionReturnsUndefined) { }; auto x = torch::ones(1, torch::requires_grad()); - + MyFunction::apply(x).backward(); ASSERT_FALSE(x.grad().defined()); @@ -647,6 +681,8 @@ TEST(CustomAutogradTest, ReentrantPriority) { ASSERT_EQ(order.size(), 10); ASSERT_EQ(std::count(order.begin(), order.end(), 1), 9); ASSERT_EQ(order.back(), 0); + // Clear static variable in case test get executed in a loop + order.clear(); } TEST(CustomAutogradTest, Hooks) { @@ -728,6 +764,36 @@ TEST(CustomAutogradTest, HookNone) { ASSERT_TRUE(was_called); } +TEST(CustomAutogradTest, BackwardWithInputs) { + Variable x = torch::randn({5,5}, torch::requires_grad()); + Variable y = torch::randn({5,5}, torch::requires_grad()); + Variable z = x * x + x * y + y * y; + Variable x_grad_expected = 2 * x + y; + Variable y_grad_expected = x + 2 * y; + + z.backward(torch::ones({5, 5}), false, false, {x}); + + ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected); + ASSERT_FALSE(y.grad().defined()); +} + +TEST(CustomAutogradTest, BackwardWithEmptyInputs) { + Variable x = torch::randn({5,5}, torch::requires_grad()); + Variable y = torch::randn({5,5}, torch::requires_grad()); + Variable z = x * x + x * y + y * y; + Variable x_grad_expected = 2 * x + y; + Variable y_grad_expected = x + 2 * y; + ASSERT_THROWS_WITH(z.backward(torch::ones({5, 5}), false, false, std::vector{}), "cannot be empty"); +} + +TEST(CustomAutogradTest, BackwardWithNonLeafInputs) { + Variable x = torch::randn({5,5}, torch::requires_grad()); + Variable y = torch::randn({5,5}, torch::requires_grad()); + Variable z = x * x; + Variable w = z + x * y + y * y; + ASSERT_THROWS_WITH(w.backward(torch::ones({5, 5}), false, false, {z}), "is not a leaf Tensor"); +} + // TODO add these tests if needed // test_once_differentiable // test_sparse_backward diff --git a/test/cpp/api/fft.cpp b/test/cpp/api/fft.cpp index f8f9d5f1d906b..e78e358862e63 100644 --- a/test/cpp/api/fft.cpp +++ b/test/cpp/api/fft.cpp @@ -4,16 +4,6 @@ #include -// Tests that the fft function can be called as usual -TEST(FFTTest, unclobbered_fft) { - auto t = torch::randn({64, 2}, torch::dtype(torch::kDouble)); - torch::fft(t, 1); -} - -// Clobbers torch::fft the function with torch::fft the namespace -#include - - // Naive DFT of a 1 dimensional tensor torch::Tensor naive_dft(torch::Tensor x, bool forward=true) { TORCH_INTERNAL_ASSERT(x.dim() == 1); diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp index 4efdb122efc86..d4f353f5607f4 100644 --- a/test/cpp/api/functional.cpp +++ b/test/cpp/api/functional.cpp @@ -246,6 +246,18 @@ TEST_F(FunctionalTest, SmoothL1LossDefaultOptions) { ASSERT_TRUE(input.sizes() == input.grad().sizes()); } +TEST_F(FunctionalTest, SmoothL1LossBeta) { + auto input = torch::tensor({0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true)); + auto target = torch::tensor({0., 1., 5.}, torch::kFloat); + auto output = + F::smooth_l1_loss(input, target, /*reduction=*/torch::kMean, /*beta=*/0.5); + auto expected = torch::tensor(1.67, torch::kFloat); + auto s = output.sum(); + s.backward(); + ASSERT_TRUE(output.allclose(expected)); + ASSERT_TRUE(input.sizes() == input.grad().sizes()); +} + TEST_F(FunctionalTest, SmoothL1LossNoReduction) { auto input = torch::tensor({0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({0., 1., 5.}, torch::kFloat); @@ -670,6 +682,56 @@ TEST_F(FunctionalTest, TripletMarginLoss) { ASSERT_TRUE(output.allclose(expected, 1e-04)); } +TEST_F(FunctionalTest, TripletMarginWithDistanceLossDefaultParity) { + // Check that if we use torch::pairwise_distance with the default + // TripletMarginLoss options as our distance function, the outputs + // are equal (i.e., equal under defaults). + + std::vector + reductions = {torch::kSum, torch::kMean, torch::kNone}; + std::vector margins = {0.5, 1.0, 1.5}; + std::vector swaps = {true, false}; + + for (auto& reduction : reductions) { + for (auto& margin : margins) { + for (const auto& swap : swaps) { + auto anchor = + torch::randn({100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); + auto positive = + torch::randn({100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); + auto negative = + torch::randn({100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); + + auto basicOptions = F::TripletMarginLossFuncOptions() + .reduction(reduction) + .margin(margin) + .swap(swap); + auto distanceOptions = + F::TripletMarginWithDistanceLossFuncOptions() + .reduction(reduction) + .margin(margin) + .swap(swap); + TripletMarginLoss basicLoss(basicOptions); + TripletMarginWithDistanceLoss distanceLoss(distanceOptions); + + auto basicOutput = + F::triplet_margin_loss(anchor, positive, negative, basicOptions); + auto distanceOutput = F::triplet_margin_with_distance_loss( + anchor, positive, negative, distanceOptions); + + ASSERT_TRUE(distanceOutput.allclose(basicOutput, 1e-6, 1e-6)); + + // handle for torch::kNone reduction + auto sum = distanceOutput.sum(); + sum.backward(); + ASSERT_EQ(anchor.sizes(), anchor.grad().sizes()); + ASSERT_EQ(positive.sizes(), positive.grad().sizes()); + ASSERT_EQ(negative.sizes(), negative.grad().sizes()); + } + } + } +} + TEST_F(FunctionalTest, NLLLoss) { auto input = torch::tensor({{-0.1315, -3.1315, -2.5315}, {-3.7038, -0.1038, -2.6038}, @@ -1425,6 +1487,23 @@ TEST_F(FunctionalTest, PixelShuffle) { ASSERT_TRUE(y.allclose(y_exp)); } +TEST_F(FunctionalTest, PixelUnshuffle) { + auto x = torch::tensor( + {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}}, + torch::kFloat); + auto y_exp = torch::tensor( + {{{{-17, 19}, {-1, 2}}, + {{7, 14}, {-3, 1}}, + {{0, -2}, {-12, 14}}, + {{-15, 0}, {-3, 9}}}}, + torch::kFloat); + auto y = F::pixel_unshuffle(x, 2); + + ASSERT_EQ(y.ndimension(), 4); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2})); + ASSERT_TRUE(y.allclose(y_exp)); +} + TEST_F(FunctionalTest, Softplus) { const auto size = 3; for (const auto beta : {0.5, 1.0, 2.0}) { diff --git a/test/cpp/api/misc.cpp b/test/cpp/api/misc.cpp index 160075d0d268d..a8d6320e9533d 100644 --- a/test/cpp/api/misc.cpp +++ b/test/cpp/api/misc.cpp @@ -82,3 +82,11 @@ TEST_F(AutogradTest, CanPassCustomGradientInputs) { z.sum().backward(torch::ones({}) * 2); ASSERT_TRUE(x.grad().allclose(y * 2)); } + +TEST(UtilsTest, AmbiguousOperatorDefaults) { + auto tmp = at::empty({}, at::kCPU); + at::_test_ambiguous_defaults(tmp); + at::_test_ambiguous_defaults(tmp, 1); + at::_test_ambiguous_defaults(tmp, 1, 1); + at::_test_ambiguous_defaults(tmp, 2, "2"); +} diff --git a/test/cpp/api/moduledict.cpp b/test/cpp/api/moduledict.cpp new file mode 100644 index 0000000000000..3451421cfeb9c --- /dev/null +++ b/test/cpp/api/moduledict.cpp @@ -0,0 +1,313 @@ +#include +#include +#include +#include +#include + +#include + +using namespace torch::nn; +using namespace torch::test; + +struct ModuleDictTest : torch::test::SeedingFixture {}; + +TEST_F(ModuleDictTest, ConstructsFromList) { + struct M : Module { + explicit M(int value_) : value(value_) {} + int value; + }; + + std::vector>> list = { + {"module_1", std::make_shared(1)}, + {"module_2", std::make_shared(2)}, + {"module_3", std::make_shared(3)} + }; + ModuleDict dict(list); + ASSERT_EQ(dict->size(), 3); +} + +TEST_F(ModuleDictTest, ConstructsFromordereddict) { + struct M : Module { + explicit M(int value_) : value(value_) {} + int value; + }; + + torch::OrderedDict> ordereddict = { + {"module_1", std::make_shared(1)}, + {"module_2", std::make_shared(2)}, + {"module_3", std::make_shared(3)}, + }; + ModuleDict dict(ordereddict); + ASSERT_EQ(dict->size(), 3); +} + +TEST_F(ModuleDictTest, UpdatePopClearContains) { + struct M : Module { + explicit M(int value_) : value(value_) {} + int value; + }; + + ModuleDict dict; + ASSERT_TRUE(dict->empty()); + // Update by List + std::vector>> list1 = { + {"module_1", std::make_shared(1)} + }; + dict->update(list1); + ASSERT_EQ(dict->size(), 1); + ASSERT_TRUE(dict->contains("module_1")); + // Update by OrderedDict + torch::OrderedDict> ordereddict = { + {"module_2", std::make_shared(2)} + }; + dict->update(ordereddict); + ASSERT_EQ(dict->size(), 2); + ASSERT_TRUE(dict->contains("module_2")); + // Update by another ModuleDict + std::vector>>list2 = { + {"module_3", std::make_shared(3)} + }; + ModuleDict updatedict(list2); + dict->update(*updatedict); + ASSERT_EQ(dict->size(), 3); + ASSERT_TRUE(dict->contains("module_3")); + // Pop + dict->pop("module_1"); + ASSERT_EQ(dict->size(), 2); + // Pop unexist + ASSERT_THROWS_WITH(dict->pop("module_4"), " 'module_4' is not defined"); + // Clear + dict->clear(); + ASSERT_EQ(dict->size(), 0); +} + +TEST_F(ModuleDictTest, UpdateExist) { + struct M : Module { + explicit M(int value_) : value(value_) {} + int value; + }; + std::vector>> list1 = { + {"module_1", std::make_shared(1)}, + {"module_2", std::make_shared(2)} + }; + ModuleDict dict(list1); + ASSERT_EQ(dict->at("module_2").value, 2); + // Update by list + std::vector>> list2 = { + {"module_2", std::make_shared(0)}, + {"module_3", std::make_shared(3)} + }; + dict->update(list2); + ASSERT_EQ(dict->size(), 3); + ASSERT_EQ(dict->at("module_2").value, 0); + // Update by ordereddict + torch::OrderedDict> ordereddict = { + {"module_3", std::make_shared(0)}, + {"module_4", std::make_shared(4)} + }; + dict->update(ordereddict); + ASSERT_EQ(dict->size(), 4); + ASSERT_EQ(dict->at("module_3").value, 0); + // Update by ModuleDict + std::vector>> list3 = { + {"module_4", std::make_shared(0)}, + {"module_1", std::make_shared(0)} + }; + ModuleDict dict2(list3); + dict->update(*dict2); + ASSERT_EQ(dict->size(), 4); + ASSERT_EQ(dict->at("module_1").value, 0); + ASSERT_EQ(dict->at("module_4").value, 0); +} + +TEST_F(ModuleDictTest, Keys) { + struct M : Module { + explicit M(int value_) : value(value_) {} + int value; + }; + + torch::OrderedDict> ordereddict = { + {"linear", Linear(10, 3).ptr()}, + {"conv", Conv2d(1, 2, 3).ptr()}, + {"dropout", Dropout(0.5).ptr()}, + }; + ModuleDict dict(ordereddict); + const auto& keys = dict->keys(); + std::vector expected{"linear", "conv", "dropout"}; + ASSERT_EQ(keys, expected); + ASSERT_THROWS_WITH(dict["batch"], " 'batch' is not defined"); + + ASSERT_TRUE(dict["linear"]->as()); + ASSERT_TRUE(dict["conv"]->as()); + ASSERT_TRUE(dict["dropout"]->as()); +} + +TEST_F(ModuleDictTest, Values) { + struct M : Module { + explicit M(int value_) : value(value_) {} + int value; + }; + + torch::OrderedDict> ordereddict = { + {"module_1", std::make_shared(1)}, + {"module_2", std::make_shared(2)}, + }; + ModuleDict dict(ordereddict); + const auto& values = dict->values(); + const auto& expected = ordereddict.values(); + ASSERT_EQ(values, expected); + ASSERT_TRUE(std::equal( + dict->begin(), + dict->end(), + ordereddict.begin(), + [](const auto& lhs, + const auto& rhs) { + return lhs.value().get() == rhs.value().get(); + })); +} + +TEST_F(ModuleDictTest, SanityCheckForHoldingStandardModules) { + torch::OrderedDict> ordereddict = { + {"linear", Linear(10, 3).ptr()}, + {"conv", Conv2d(1, 2, 3).ptr()}, + {"dropout", Dropout(0.5).ptr()}, + {"batch", BatchNorm2d(5).ptr()}, + {"embedding", Embedding(4, 10).ptr()}, + {"lstm", LSTM(4, 5).ptr()} + }; + ModuleDict dict(ordereddict); +} + +TEST_F(ModuleDictTest, HasReferenceSemantics) { + torch::OrderedDict> ordereddict = { + {"linear1", Linear(2, 3).ptr()}, + {"linear2", Linear(3, 4).ptr()}, + {"linear3", Linear(4, 5).ptr()}, + }; + ModuleDict first(ordereddict); + ModuleDict second(ordereddict); + + ASSERT_EQ(first->size(), second->size()); + ASSERT_TRUE(std::equal( + first->begin(), + first->end(), + second->begin(), + [](const auto& lhs, + const auto& rhs) { + return lhs.value().get() == rhs.value().get(); + })); +} + +void iscloneable_helper(torch::Device device) { + torch::OrderedDict> ordereddict = { + {"linear", Linear(2, 3).ptr()}, + {"relu", Functional(torch::relu).ptr()}, + {"batch", BatchNorm1d(3).ptr()}, + }; + ModuleDict dict(ordereddict); + dict->to(device); + ModuleDict clone = std::dynamic_pointer_cast(dict->clone(device)); + ASSERT_EQ(dict->size(), clone->size()); + + for (auto it = dict->begin(), it_c = clone->begin(); it != dict->end(); ++it, ++it_c) { + // The key should be same + ASSERT_EQ(it->key(), it_c->key()); + // The modules should be the same kind (type). + ASSERT_EQ(it->value()->name(), it_c->value()->name()); + // But not pointer-equal (distinct objects). + ASSERT_NE(it->value(), it_c->value()); + } + + // Verify that the clone is deep, i.e. parameters of modules are cloned too. + torch::NoGradGuard no_grad; + + auto params1 = dict->named_parameters(); + auto params2 = clone->named_parameters(); + ASSERT_EQ(params1.size(), params2.size()); + for (auto& param : params1) { + ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()])); + ASSERT_EQ(param->device(), params2[param.key()].device()); + ASSERT_TRUE(param->allclose(params2[param.key()])); + param->add_(2); + } + for (auto& param : params1) { + ASSERT_FALSE(param->allclose(params2[param.key()])); + } +} + +TEST_F(ModuleDictTest, IsCloneable) { + iscloneable_helper(torch::kCPU); +} + +TEST_F(ModuleDictTest, IsCloneable_CUDA) { + iscloneable_helper({torch::kCUDA, 0}); +} + +TEST_F(ModuleDictTest, RegistersElementsAsSubmodules) { + torch::OrderedDict> ordereddict1 = { + {"linear", Linear(10, 3).ptr()}, + {"conv", Conv2d(1, 2, 3).ptr()}, + {"test", Dropout(0.5).ptr()}, + }; + ModuleDict dict(ordereddict1); + + auto modules = dict->children(); + ASSERT_TRUE(modules[0]->as()); + ASSERT_TRUE(modules[1]->as()); + ASSERT_TRUE(modules[2]->as()); + + // Update Existing + torch::OrderedDict> ordereddict2 = { + {"lstm", LSTM(4, 5).ptr()}, + {"test", BatchNorm2d(5).ptr()} + }; + dict->update(ordereddict2); + + modules = dict->children(); + ASSERT_TRUE(modules[0]->as()); + ASSERT_TRUE(modules[1]->as()); + // Keep Order + ASSERT_TRUE(modules[2]->as()); + ASSERT_TRUE(modules[3]->as()); +} + +TEST_F(ModuleDictTest, CloneToDevice_CUDA) { + torch::OrderedDict> ordereddict = { + {"linear", Linear(2, 3).ptr()}, + {"relu", Functional(torch::relu).ptr()}, + {"batch", BatchNorm1d(3).ptr()}, + }; + ModuleDict dict(ordereddict); + torch::Device device(torch::kCUDA, 0); + ModuleDict clone = + std::dynamic_pointer_cast(dict->clone(device)); + for (const auto& p : clone->parameters()) { + ASSERT_EQ(p.device(), device); + } + for (const auto& b : clone->buffers()) { + ASSERT_EQ(b.device(), device); + } +} + +TEST_F(ModuleDictTest, PrettyPrintModuleDict) { + torch::OrderedDict> ordereddict = { + {"linear", Linear(10, 3).ptr()}, + {"conv", Conv2d(1, 2, 3).ptr()}, + {"dropout", Dropout(0.5).ptr()}, + {"batch", BatchNorm2d(5).ptr()}, + {"embedding", Embedding(4, 10).ptr()}, + {"lstm", LSTM(4, 5).ptr()} + }; + ModuleDict dict(ordereddict); + + ASSERT_EQ( + c10::str(dict), + "torch::nn::ModuleDict(\n" + " (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n" + " (conv): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n" + " (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n" + " (batch): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n" + " (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n" + " (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n" + ")"); +} diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp index 4777cf0b54bca..f24f8b42a19bd 100644 --- a/test/cpp/api/modules.cpp +++ b/test/cpp/api/modules.cpp @@ -2085,6 +2085,115 @@ TEST_F(ModulesTest, TripletMarginLoss) { ASSERT_EQ(anchor.sizes(), anchor.grad().sizes()); } +TEST_F(ModulesTest, TripletMarginWithDistanceLossDefaultParity) { + // Check that if we use torch::pairwise_distance with the default + // TripletMarginLoss options as our distance function, the outputs + // are equal (i.e., equal under defaults). + + std::vector + reductions = {torch::kSum, torch::kMean, torch::kNone}; + std::vector margins = {0.5, 1.0, 1.5}; + std::vector swaps = {true, false}; + + for (auto& reduction : reductions) { + for (auto& margin : margins) { + for (const auto swap : swaps) { + auto anchor = + torch::randn({100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); + auto positive = + torch::randn({100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); + auto negative = + torch::randn({100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); + + auto basicOptions = TripletMarginLossOptions() + .reduction(reduction) + .margin(margin) + .swap(swap); + auto distanceOptions = + TripletMarginWithDistanceLossOptions() + .reduction(reduction) + .margin(margin) + .swap(swap); + TripletMarginLoss basicLoss(basicOptions); + TripletMarginWithDistanceLoss distanceLoss(distanceOptions); + + auto basicOutput = basicLoss->forward(anchor, positive, negative); + auto distanceOutput = distanceLoss->forward(anchor, positive, negative); + auto basicOperatorOutput = basicLoss(anchor, positive, negative); + auto distanceOperatorOutput = distanceLoss(anchor, positive, negative); + + ASSERT_TRUE(distanceOutput.allclose(basicOutput, 1e-6, 1e-6)); + ASSERT_TRUE(distanceOperatorOutput.allclose(distanceOutput, 1e-6, 1e-6)); + ASSERT_TRUE(distanceOperatorOutput.allclose(basicOperatorOutput, 1e-6, 1e-6)); + + // handle for torch::kNone reduction + auto sum = distanceOutput.sum(); + sum.backward(); + ASSERT_EQ(anchor.sizes(), anchor.grad().sizes()); + ASSERT_EQ(positive.sizes(), positive.grad().sizes()); + ASSERT_EQ(negative.sizes(), negative.grad().sizes()); + } + } + } +} + +TEST_F(ModulesTest, TripletMarginWithDistanceLossFunctionalParity) { + // Check for parity between F::triplet_margin_with_distance_loss and + // TripletMarginWithDistanceLoss. + auto pairwise_distance = [&](const torch::Tensor& x, const torch::Tensor& y) { + return torch::pairwise_distance(x, y); + }; + auto cosine_distance = [&](const torch::Tensor& x, + const torch::Tensor& y) { + return 1.0 - torch::cosine_similarity(x, y); + }; + std::vector + distance_functions = {pairwise_distance, cosine_distance}; + + std::vector + reductions = {torch::kSum, torch::kMean, torch::kNone}; + std::vector margins = {0.5, 1.0, 1.5}; + std::vector swaps = {true, false}; + + for (auto& function : distance_functions) { + for (auto& reduction : reductions) { + for (auto& margin : margins) { + for (const auto swap : swaps) { + auto moduleOptions = + TripletMarginWithDistanceLossOptions() + .distance_function(function) + .reduction(reduction) + .margin(margin) + .swap(swap); + auto functionOptions = + torch::nn::functional::TripletMarginWithDistanceLossFuncOptions() + .distance_function(function) + .reduction(reduction) + .margin(margin) + .swap(swap); + + auto anchor = torch::randn( + {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); + auto positive = torch::randn( + {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); + auto negative = torch::randn( + {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); + + TripletMarginWithDistanceLoss distanceLoss(moduleOptions); + + auto moduleOutput = distanceLoss->forward(anchor, positive, negative); + auto moduleOperatorOutput = distanceLoss(anchor, positive, negative); + auto functionOutput = torch::nn::functional::triplet_margin_with_distance_loss( + anchor, positive, negative, functionOptions); + + ASSERT_TRUE(moduleOutput.allclose(functionOutput, 1e-6, 1e-6)); + ASSERT_TRUE(moduleOperatorOutput.allclose(functionOutput, 1e-6, 1e-6)); + } + } + } + } +} + TEST_F(ModulesTest, NLLLoss) { NLLLoss loss; auto input = torch::tensor({{-0.1315, -3.1315, -2.5315}, @@ -2652,6 +2761,24 @@ TEST_F(ModulesTest, PixelShuffle) { ASSERT_TRUE(y.allclose(y_exp)); } +TEST_F(ModulesTest, PixelUnshuffle) { + PixelUnshuffle module(/*downscale_factor=*/2); + auto x = torch::tensor( + {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}}, + torch::kFloat); + auto y_exp = torch::tensor( + {{{{-17, 19}, {-1, 2}}, + {{7, 14}, {-3, 1}}, + {{0, -2}, {-12, 14}}, + {{-15, 0}, {-3, 9}}}}, + torch::kFloat); + auto y = module(x); + + ASSERT_EQ(y.ndimension(), 4); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2})); + ASSERT_TRUE(y.allclose(y_exp)); +} + TEST_F(ModulesTest, Softplus) { const auto size = 3; for (const auto beta : {0.5, 1.0, 2.0}) { @@ -3529,9 +3656,9 @@ TEST_F(ModulesTest, PrettyPrintIdentity) { } TEST_F(ModulesTest, PrettyPrintFlatten) { - ASSERT_EQ(c10::str(Flatten()), + ASSERT_EQ(c10::str(Flatten()), "torch::nn::Flatten(start_dim=1, end_dim=-1)"); - ASSERT_EQ(c10::str(Flatten(FlattenOptions().start_dim(2).end_dim(4))), + ASSERT_EQ(c10::str(Flatten(FlattenOptions().start_dim(2).end_dim(4))), "torch::nn::Flatten(start_dim=2, end_dim=4)"); } @@ -4394,6 +4521,20 @@ TEST_F(ModulesTest, PrettyPrintTripletMarginLoss) { "torch::nn::TripletMarginLoss(margin=3, p=2, eps=1e-06, swap=false)"); } +TEST_F(ModulesTest, PrettyPrintTripletMarginWithDistanceLoss) { + auto distanceOptions = TripletMarginWithDistanceLossOptions() + .distance_function([&](const torch::Tensor& x, + const torch::Tensor& y) { + return torch::pairwise_distance(x, y, 2.0, 1e-6); + }) + .margin(1.5) + .swap(true) + .reduction(torch::kMean); + ASSERT_EQ( + c10::str(TripletMarginWithDistanceLoss(distanceOptions)), + "torch::nn::TripletMarginWithDistanceLoss(margin=1.5, swap=true)"); +} + TEST_F(ModulesTest, PrettyPrintNLLLoss) { ASSERT_EQ( c10::str(NLLLoss()), "torch::nn::NLLLoss()"); @@ -4641,6 +4782,12 @@ TEST_F(ModulesTest, PrettyPrintPixelShuffle) { "torch::nn::PixelShuffle(upscale_factor=5)"); } +TEST_F(ModulesTest, PrettyPrintPixelUnshuffle) { + ASSERT_EQ( + c10::str(PixelUnshuffle(PixelUnshuffleOptions(5))), + "torch::nn::PixelUnshuffle(downscale_factor=5)"); +} + TEST_F(ModulesTest, PrettyPrintSoftplus) { ASSERT_EQ(c10::str(Softplus()), "torch::nn::Softplus(beta=1, threshold=20)"); diff --git a/test/cpp/api/rnn.cpp b/test/cpp/api/rnn.cpp index c51c38660c2dd..e5be819b48044 100644 --- a/test/cpp/api/rnn.cpp +++ b/test/cpp/api/rnn.cpp @@ -14,8 +14,13 @@ bool test_RNN_xor(Func&& model_maker, bool cuda = false) { auto nhid = 32; auto model = std::make_shared(); auto l1 = model->add(Linear(1, nhid), "l1"); - auto rnn = model->add(model_maker(nhid), "rnn"); - auto lo = model->add(Linear(nhid, 1), "lo"); + auto rnn_model = model_maker(nhid); + auto rnn = model->add(rnn_model, "rnn"); + auto nout = nhid; + if (rnn_model.get()->options_base.proj_size() > 0) { + nout = rnn_model.get()->options_base.proj_size(); + } + auto lo = model->add(Linear(nout, 1), "lo"); torch::optim::Adam optimizer(model->parameters(), 1e-2); auto forward_op = [&](torch::Tensor x) { @@ -44,7 +49,6 @@ bool test_RNN_xor(Func&& model_maker, bool cuda = false) { torch::rand({nlen, bs, 1}, backend).round().to(torch::kFloat32); auto labels = inputs.sum(0).detach(); inputs.set_requires_grad(true); - auto outputs = forward_op(inputs); torch::Tensor loss = torch::mse_loss(outputs, labels); @@ -90,6 +94,35 @@ void check_lstm_sizes(std::tuple(), 0); } +void check_lstm_sizes_proj(std::tuple> lstm_output) { + // Expect the LSTM to have 32 outputs and 3 layers, with an input of batch + // 10 and 16 time steps (10 x 16 x n) + + torch::Tensor output = std::get<0>(lstm_output); + std::tuple state = std::get<1>(lstm_output); + torch::Tensor hx = std::get<0>(state); + torch::Tensor cx = std::get<1>(state); + + ASSERT_EQ(output.ndimension(), 3); + ASSERT_EQ(output.size(0), 10); + ASSERT_EQ(output.size(1), 16); + ASSERT_EQ(output.size(2), 32); + + ASSERT_EQ(hx.ndimension(), 3); + ASSERT_EQ(hx.size(0), 3); // layers + ASSERT_EQ(hx.size(1), 16); // Batchsize + ASSERT_EQ(hx.size(2), 32); // 32 hidden dims + + ASSERT_EQ(cx.ndimension(), 3); + ASSERT_EQ(cx.size(0), 3); // layers + ASSERT_EQ(cx.size(1), 16); // Batchsize + ASSERT_EQ(cx.size(2), 64); // 64 cell dims + + // Something is in the hiddens + ASSERT_GT(hx.norm().item(), 0); + ASSERT_GT(cx.norm().item(), 0); +} + struct RNNTest : torch::test::SeedingFixture {}; TEST_F(RNNTest, CheckOutputSizes) { @@ -118,6 +151,33 @@ TEST_F(RNNTest, CheckOutputSizes) { ASSERT_GT(diff.abs().sum().item(), 1e-3); } +TEST_F(RNNTest, CheckOutputSizesProj) { + LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2).proj_size(32)); + // Input size is: sequence length, batch size, input size + auto x = torch::randn({10, 16, 128}, torch::requires_grad()); + auto output = model->forward(x); + auto y = x.mean(); + + y.backward(); + check_lstm_sizes_proj(output); + + auto next = model->forward(x, std::get<1>(output)); + + check_lstm_sizes_proj(next); + + auto output_hx = std::get<0>(std::get<1>(output)); + auto output_cx = std::get<1>(std::get<1>(output)); + + auto next_hx = std::get<0>(std::get<1>(next)); + auto next_cx = std::get<1>(std::get<1>(next)); + + torch::Tensor diff = next_hx - output_hx; + // Hiddens changed + ASSERT_GT(diff.abs().sum().item(), 1e-3); + diff = next_cx - output_cx; + ASSERT_GT(diff.abs().sum().item(), 1e-3); +} + TEST_F(RNNTest, CheckOutputValuesMatchPyTorch) { torch::manual_seed(0); // Make sure the outputs match pytorch outputs @@ -192,6 +252,11 @@ TEST_F(RNNTest, EndToEndLSTM) { [](int s) { return LSTM(LSTMOptions(s, s).num_layers(2)); })); } +TEST_F(RNNTest, EndToEndLSTMProj) { + ASSERT_TRUE(test_RNN_xor( + [](int s) { return LSTM(LSTMOptions(s, s).num_layers(2).proj_size(s / 2)); })); +} + TEST_F(RNNTest, EndToEndGRU) { ASSERT_TRUE( test_RNN_xor([](int s) { return GRU(GRUOptions(s, s).num_layers(2)); })); @@ -235,11 +300,45 @@ TEST_F(RNNTest, Sizes_CUDA) { ASSERT_GT(diff.abs().sum().item(), 1e-3); } +TEST_F(RNNTest, SizesProj_CUDA) { + torch::manual_seed(0); + LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2).proj_size(32)); + model->to(torch::kCUDA); + auto x = + torch::randn({10, 16, 128}, torch::requires_grad().device(torch::kCUDA)); + auto output = model->forward(x); + auto y = x.mean(); + + y.backward(); + check_lstm_sizes_proj(output); + + auto next = model->forward(x, std::get<1>(output)); + + check_lstm_sizes_proj(next); + + auto output_hx = std::get<0>(std::get<1>(output)); + auto output_cx = std::get<1>(std::get<1>(output)); + + auto next_hx = std::get<0>(std::get<1>(next)); + auto next_cx = std::get<1>(std::get<1>(next)); + + torch::Tensor diff = next_hx - output_hx; + // Hiddens changed + ASSERT_GT(diff.abs().sum().item(), 1e-3); + diff = next_cx - output_cx; + ASSERT_GT(diff.abs().sum().item(), 1e-3); +} + TEST_F(RNNTest, EndToEndLSTM_CUDA) { ASSERT_TRUE(test_RNN_xor( [](int s) { return LSTM(LSTMOptions(s, s).num_layers(2)); }, true)); } +TEST_F(RNNTest, EndToEndLSTMProj_CUDA) { + ASSERT_TRUE(test_RNN_xor( + [](int s) { return LSTM(LSTMOptions(s, s).num_layers(2).proj_size(s / 2)); }, true)); +} + TEST_F(RNNTest, EndToEndGRU_CUDA) { ASSERT_TRUE(test_RNN_xor( [](int s) { return GRU(GRUOptions(s, s).num_layers(2)); }, true)); @@ -258,6 +357,9 @@ TEST_F(RNNTest, PrettyPrintRNNs) { ASSERT_EQ( c10::str(LSTM(LSTMOptions(128, 64).num_layers(3).dropout(0.2))), "torch::nn::LSTM(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false)"); + ASSERT_EQ( + c10::str(LSTM(LSTMOptions(128, 64).num_layers(3).dropout(0.2).proj_size(32))), + "torch::nn::LSTM(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false, proj_size=32)"); ASSERT_EQ( c10::str(GRU(GRUOptions(128, 64).num_layers(3).dropout(0.5))), "torch::nn::GRU(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.5, bidirectional=false)"); @@ -503,6 +605,55 @@ TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) { } } +TEST_F(RNNTest, BidirectionalMultilayerLSTMProj_CPU_vs_CUDA) { + // Create two LSTMs with the same options + auto opt = LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true).proj_size(2); + LSTM lstm_cpu {opt}; + LSTM lstm_cuda {opt}; + + // Copy weights and biases from CPU LSTM to CUDA LSTM + { + at::NoGradGuard guard; + for (const auto& param : lstm_cpu->named_parameters(/*recurse=*/false)) { + lstm_cuda->named_parameters()[param.key()].copy_(lstm_cpu->named_parameters()[param.key()]); + } + } + + lstm_cpu->flatten_parameters(); + lstm_cuda->flatten_parameters(); + + // Move LSTM to CUDA + lstm_cuda->to(torch::kCUDA); + + auto options = torch::TensorOptions() + .dtype(torch::kFloat32).requires_grad(false); + auto input_cpu = torch::tensor({1, 2, 3, 4, 5, 6}, options) + .reshape({3, 1, 2}); + auto input_cuda = torch::tensor({1, 2, 3, 4, 5, 6}, options) + .reshape({3, 1, 2}).to(torch::kCUDA); + + // Call forward on both LSTMs + auto output_cpu = lstm_cpu->forward(input_cpu); + auto output_cuda = lstm_cuda->forward(input_cuda); + + output_cpu = lstm_output_to_device(output_cpu, torch::kCPU); + + // Assert that the output and state are equal on CPU and CUDA + ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim()); + for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) { + ASSERT_EQ(std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i)); + } + for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) { + for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) { + for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) { + ASSERT_NEAR( + std::get<0>(output_cpu)[i][j][k].item(), + std::get<0>(output_cuda)[i][j][k].item(), 1e-5); + } + } + } +} + TEST_F(RNNTest, UsePackedSequenceAsInput) { { torch::manual_seed(0); diff --git a/test/cpp/api/tensor_indexing.cpp b/test/cpp/api/tensor_indexing.cpp index efb153fbf4817..03600c5c882e6 100644 --- a/test/cpp/api/tensor_indexing.cpp +++ b/test/cpp/api/tensor_indexing.cpp @@ -83,27 +83,27 @@ TEST(TensorIndexingTest, TestNoIndices) { ASSERT_THROWS_WITH(tensor.index_put_(indices, value), "Passing an empty index list to Tensor::index_put_() is not valid syntax"); } -TEST(TensorIndexingTest, TestAdvancedIndexingWithArrayRefOfTensor) { +TEST(TensorIndexingTest, TestAdvancedIndexingWithListOfTensor) { { torch::Tensor tensor = torch::randn({20, 20}); torch::Tensor index = torch::arange(10, torch::kLong).cpu(); - torch::Tensor result_with_array_ref = tensor.index(at::ArrayRef({index})); + torch::Tensor result = at::index(tensor, {index}); torch::Tensor result_with_init_list = tensor.index({index}); - ASSERT_TRUE(result_with_array_ref.equal(result_with_init_list)); + ASSERT_TRUE(result.equal(result_with_init_list)); } { torch::Tensor tensor = torch::randn({20, 20}); torch::Tensor index = torch::arange(10, torch::kLong).cpu(); - torch::Tensor result_with_array_ref = tensor.index_put_(at::ArrayRef({index}), torch::ones({20})); + torch::Tensor result = at::index_put_(tensor, {index}, torch::ones({20})); torch::Tensor result_with_init_list = tensor.index_put_({index}, torch::ones({20})); - ASSERT_TRUE(result_with_array_ref.equal(result_with_init_list)); + ASSERT_TRUE(result.equal(result_with_init_list)); } { torch::Tensor tensor = torch::randn({20, 20}); torch::Tensor index = torch::arange(10, torch::kLong).cpu(); - torch::Tensor result_with_array_ref = tensor.index_put_(at::ArrayRef({index}), torch::ones({1, 20})); + torch::Tensor result = at::index_put_(tensor, {index}, torch::ones({1, 20})); torch::Tensor result_with_init_list = tensor.index_put_({index}, torch::ones({1, 20})); - ASSERT_TRUE(result_with_array_ref.equal(result_with_init_list)); + ASSERT_TRUE(result.equal(result_with_init_list)); } } @@ -173,7 +173,7 @@ TEST(TensorIndexingTest, TestBoolIndices) { TEST(TensorIndexingTest, TestBoolIndicesAccumulate) { auto mask = torch::zeros({10}, torch::kBool); auto y = torch::ones({10, 10}); - y.index_put_({mask}, y.index({mask}), /*accumulate=*/true); + y.index_put_({mask}, {y.index({mask})}, /*accumulate=*/true); assert_tensor_equal(y, torch::ones({10, 10})); } diff --git a/test/cpp/api/tensor_options.cpp b/test/cpp/api/tensor_options.cpp index 5de56139702a9..d6b347a5d754f 100644 --- a/test/cpp/api/tensor_options.cpp +++ b/test/cpp/api/tensor_options.cpp @@ -111,8 +111,8 @@ TEST(DeviceTest, ParsesCorrectlyFromString) { device = Device("hip"); ASSERT_EQ(device, Device(DeviceType::HIP)); - device = Device("hip:321"); - ASSERT_EQ(device, Device(DeviceType::HIP, 321)); + device = Device("hip:123"); + ASSERT_EQ(device, Device(DeviceType::HIP, 123)); std::vector badnesses = { "", "cud:1", "cuda:", "cpu::1", ":1", "3", "tpu:4", "??"}; diff --git a/test/cpp/dist_autograd/CMakeLists.txt b/test/cpp/dist_autograd/CMakeLists.txt index 5d23602881f07..9969c63e16d57 100644 --- a/test/cpp/dist_autograd/CMakeLists.txt +++ b/test/cpp/dist_autograd/CMakeLists.txt @@ -1,4 +1,4 @@ -if(USE_DISTRIBUTED) +if(USE_DISTRIBUTED AND NOT WIN32) set(DIST_AUTOGRAD_TEST_DIR "${TORCH_ROOT}/test/cpp/dist_autograd") set(DIST_AUTOGRAD_TEST_SOURCES ${TORCH_ROOT}/test/cpp/common/main.cpp diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index b8f6ef1952263..2e22cd646813f 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -19,12 +19,9 @@ endif() # Build the cpp gtest binary containing the cpp-only tests. set(JIT_TEST_SRCS - ${JIT_TEST_ROOT}/gtest.cpp ${JIT_TEST_ROOT}/test_alias_analysis.cpp ${JIT_TEST_ROOT}/test_argument_spec.cpp ${JIT_TEST_ROOT}/test_autodiff.cpp - ${JIT_TEST_ROOT}/test_base.cpp - ${JIT_TEST_ROOT}/test_base.h ${JIT_TEST_ROOT}/test_class_import.cpp ${JIT_TEST_ROOT}/test_class_parser.cpp ${JIT_TEST_ROOT}/test_class_type.cpp @@ -100,8 +97,6 @@ elseif(USE_ROCM) ${PYTORCH_HIP_HCC_LIBRARIES} ${TORCH_CUDA_LIBRARIES}) - target_link_libraries(test_jit PRIVATE caffe2_gpu) - target_compile_definitions(test_jit PRIVATE USE_ROCM) endif() diff --git a/test/cpp/jit/README.md b/test/cpp/jit/README.md index a3e92403201f3..ef5ea2d910be0 100644 --- a/test/cpp/jit/README.md +++ b/test/cpp/jit/README.md @@ -1,69 +1,44 @@ # JIT C++ Tests -## How to add a new test +## Adding a new test First, create a new test file. Test files should have be placed in this directory, with a name that starts with `test_`, like `test_foo.cpp`. -Here is an example test file you can copy-paste. +In general a single test suite + +Add your test file to the `JIT_TEST_SRCS` list in `test/cpp/jit/CMakeLists.txt`. + +A test file may look like: ```cpp -#include +#include -// Tests go in torch::jit -namespace torch { -namespace jit { +using namespace ::torch::jit -// 1. Test cases are void() functions. -// 2. They start with the prefix `test` -void testCaseOne() { - // ... +TEST(FooTest, BarBaz) { + // ... } -void testCaseTwo() { - // ... -} -} +// Append '_CUDA' to the test case name will automatically filter it out if CUDA +// is not compiled. +TEST(FooTest, NeedsAGpu_CUDA) { + // ... } -``` -Then, register your test in `tests.h`: -```cpp -// Add to TH_FORALL_TESTS_CUDA instead for CUDA-requiring tests -#define TH_FORALL_TESTS(_) \ - _(ADFormulas) \ - _(Attributes) \ - ... - _(CaseOne) // note that the `test` prefix is omitted. - _(CaseTwo) -``` - -We glob all the test files together in `CMakeLists.txt` so that you don't -have to edit it every time you add a test. Unfortunately, this means that in -order to get the build to pick up your new test file, you need to re-run -cmake: -``` -python setup.py build --cmake +// Similarly, if only one GPU is detected, tests with `_MultiCUDA` at the end +// will not be run. +TEST(FooTest, NeedsMultipleGpus_MultiCUDA) { + // ... +} ``` -## Why do we have two different test runners? -We have two different ways of running our cpp tests: -1. With `gtest`, from a standalone binary. -2. With Python, from `TestJit.test_cpp` and `TestJit.test_cpp_cuda` (in - `test/test_jit.py`) - -We want both because we need to test things from a pure-C++ environment and -with all our various Python patch-points enabled. - -## How do I run the tests? +## Building and running the tests The following commands assume you are in PyTorch root. -1. With `gtest`: - ```bash - # (re)build the test binary - ninja build/bin/test_jit - # run - build/bin/test_jit --gtest_filter='glob_style_filter*' - ``` -2. With Python: - ``` - python test/test_jit.py TestJit.test_cpp TestJit.test_cpp_cuda - ``` +```bash +# ... Build PyTorch from source, e.g. +python setup.py develop +# (re)build just the binary +ninja -C build bin/test_jit +# run tests +build/bin/test_jit --gtest_filter='glob_style_filter*' +``` diff --git a/test/cpp/jit/gtest.cpp b/test/cpp/jit/gtest.cpp deleted file mode 100644 index e0e512be43526..0000000000000 --- a/test/cpp/jit/gtest.cpp +++ /dev/null @@ -1,23 +0,0 @@ -#include - -#include - -namespace torch { -namespace jit { - -#define JIT_GTEST(name) \ - TEST(JitTest, name) { \ - test##name(); \ - } -TH_FORALL_TESTS(JIT_GTEST) -#undef JIT_TEST - -#define JIT_GTEST_CUDA(name) \ - TEST(JitTest, name##_CUDA) { \ - test##name(); \ - } -TH_FORALL_TESTS_CUDA(JIT_GTEST_CUDA) -#undef JIT_TEST_CUDA - -} // namespace jit -} // namespace torch diff --git a/test/cpp/jit/test_alias_analysis.cpp b/test/cpp/jit/test_alias_analysis.cpp index e700ee5406163..58078c9716428 100644 --- a/test/cpp/jit/test_alias_analysis.cpp +++ b/test/cpp/jit/test_alias_analysis.cpp @@ -906,14 +906,15 @@ graph(): } TEST(WildcardsTest, Basic) { - RegisterOperators reg({Operator( - "prim::returns_wildcard(Tensor a) -> Tensor(*)", - [](Stack* stack) {}, - aliasAnalysisFromSchema()), - Operator( - "prim::writes(Tensor(z!) a) -> Tensor(a)", - [](Stack* stack) {}, - aliasAnalysisFromSchema())}); + RegisterOperators reg( + {Operator( + "prim::returns_wildcard(Tensor a) -> Tensor(*)", + [](Stack* stack) {}, + aliasAnalysisFromSchema()), + Operator( + "prim::writes(Tensor(z!) a) -> Tensor(a)", + [](Stack* stack) {}, + aliasAnalysisFromSchema())}); const auto returns_wildcard = Symbol::fromQualString("prim::returns_wildcard"); const auto writes = Symbol::fromQualString("prim::writes"); diff --git a/test/cpp/jit/test_argument_spec.cpp b/test/cpp/jit/test_argument_spec.cpp index bf40761fc468b..b653abe1cbbf8 100644 --- a/test/cpp/jit/test_argument_spec.cpp +++ b/test/cpp/jit/test_argument_spec.cpp @@ -50,21 +50,23 @@ TEST(ArgumentSpecTest, CompleteArgumentSpec_CUDA) { auto const GF = at::CUDA(at::kFloat); auto const GD = at::CUDA(at::kDouble); - auto list = createStack({var(CF, {1}, true), - var(CD, {1, 2}, false), - var(GF, {}, true), - var(GD, {4, 5, 6}, false), - undef()}); + auto list = createStack( + {var(CF, {1}, true), + var(CD, {1, 2}, false), + var(GF, {}, true), + var(GD, {4, 5, 6}, false), + undef()}); // make sure we have some non-standard strides list[1].toTensor().transpose_(0, 1); // same list but different backing values - auto list2 = createStack({var(CF, {1}, true), - var(CD, {1, 2}, false), - var(GF, {}, true), - var(GD, {4, 5, 6}, false), - undef()}); + auto list2 = createStack( + {var(CF, {1}, true), + var(CD, {1, 2}, false), + var(GF, {}, true), + var(GD, {4, 5, 6}, false), + undef()}); list2[1].toTensor().transpose_(0, 1); CompleteArgumentSpec a(true, list); @@ -142,21 +144,23 @@ TEST(ArgumentSpecTest, Basic_CUDA) { ArgumentSpecCreator arg_spec_creator(*graph); - auto list = createStack({var(CF, {1}, true), - var(CD, {1, 2}, false), - var(GF, {}, true), - var(GD, {4, 5, 6}, false), - undef()}); + auto list = createStack( + {var(CF, {1}, true), + var(CD, {1, 2}, false), + var(GF, {}, true), + var(GD, {4, 5, 6}, false), + undef()}); // make sure we have some non-standard strides list[1].toTensor().transpose_(0, 1); // same list but different backing values - auto list2 = createStack({var(CF, {1}, true), - var(CD, {1, 2}, false), - var(GF, {}, true), - var(GD, {4, 5, 6}, false), - undef()}); + auto list2 = createStack( + {var(CF, {1}, true), + var(CD, {1, 2}, false), + var(GF, {}, true), + var(GD, {4, 5, 6}, false), + undef()}); list2[1].toTensor().transpose_(0, 1); ArgumentSpec a = arg_spec_creator.create(true, list); diff --git a/test/cpp/jit/test_autodiff.cpp b/test/cpp/jit/test_autodiff.cpp index 3993c63b1708e..38ddfee5fdd28 100644 --- a/test/cpp/jit/test_autodiff.cpp +++ b/test/cpp/jit/test_autodiff.cpp @@ -81,6 +81,7 @@ variable_list grad( grad_outputs, true, false, + false, fmap(inputs, get_edge)); } diff --git a/test/cpp/jit/test_base.cpp b/test/cpp/jit/test_base.cpp deleted file mode 100644 index 338577fbd8336..0000000000000 --- a/test/cpp/jit/test_base.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#include -#include - -#include "torch/csrc/jit/runtime/custom_operator.h" - -namespace torch { -namespace jit { -inline c10::AliasAnalysisKind aliasAnalysisFromSchema() { - return c10::AliasAnalysisKind::FROM_SCHEMA; -} - -namespace { -RegisterOperators reg({ - // This operator is intended to be used in JIT analysis and transformation - // pass unit tests in which Values with type Tensor are often required. It - // should not be used in situations in which the graph is actually executed - // because it always produces empty Tensors. - Operator( - "prim::MakeTestTensor() -> Tensor", - [](Stack* stack) { push(stack, at::Tensor()); }, - aliasAnalysisFromSchema()), -}); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/jit/test_base.h b/test/cpp/jit/test_base.h deleted file mode 100644 index 25f9e9f36cde1..0000000000000 --- a/test/cpp/jit/test_base.h +++ /dev/null @@ -1,47 +0,0 @@ -#pragma once - -// This file defines assertion macros that work in both gtest and non-gtest -// builds, and has some common includes. -#include "torch/csrc/jit/ir/ir.h" -#include "torch/csrc/jit/runtime/operator.h" - -#if defined(USE_GTEST) -#include -#include -#else -#include "c10/util/Exception.h" -// Temporary: we are going to remove these polyfills entirely. -// But for now avoid redefining them if they are already defined in gtest. -// (ASSERT_EQ is a proxy for whether gtest is already present) -#ifndef ASSERT_EQ -#define ASSERT_EQ(x, y) TORCH_INTERNAL_ASSERT((x) == (y)) -#define ASSERT_NE(x, y) TORCH_INTERNAL_ASSERT((x) != (y)) -#define ASSERT_TRUE TORCH_INTERNAL_ASSERT -#define ASSERT_FALSE(x) ASSERT_TRUE(!(x)) -#define ASSERT_THROWS_WITH(statement, substring) \ - try { \ - (void)statement; \ - ASSERT_TRUE(false); \ - } catch (const std::exception& e) { \ - ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \ - } -#define ASSERT_ANY_THROW(statement) \ - { \ - bool threw = false; \ - try { \ - (void)statement; \ - } catch (const std::exception& e) { \ - threw = true; \ - } \ - ASSERT_TRUE(threw); \ - } -#endif // ndef(ASSERT_EQ) - -#endif // defined(USE_GTEST) - -static inline bool isSandcastle() { - return ( - (std::getenv("SANDCASTLE")) || - (std::getenv("TW_JOB_USER") && - std::string(std::getenv("TW_JOB_USER")) == "sandcastle")); -} diff --git a/test/cpp/jit/test_class_parser.cpp b/test/cpp/jit/test_class_parser.cpp index a5b19f63fd3f3..2f7f06d3802b9 100644 --- a/test/cpp/jit/test_class_parser.cpp +++ b/test/cpp/jit/test_class_parser.cpp @@ -1,6 +1,5 @@ #include -#include #include #include diff --git a/test/cpp/jit/test_class_type.cpp b/test/cpp/jit/test_class_type.cpp index c00aafcc526b1..21229594d56d0 100644 --- a/test/cpp/jit/test_class_type.cpp +++ b/test/cpp/jit/test_class_type.cpp @@ -1,11 +1,12 @@ -#include +#include + #include #include namespace torch { namespace jit { -void testClassTypeAddRemoveAttr() { +TEST(ClassTypeTest, AddRemoveAttr) { auto cu = std::make_shared(); auto cls = ClassType::create("foo.bar", cu, true); cls->addAttribute("attr1", TensorType::get(), true); @@ -32,12 +33,12 @@ void testClassTypeAddRemoveAttr() { cls->addAttribute("attr1", IntType::get()); } -void testClassTypeAddRemoveConstant() { +TEST(ClassTypeTest, AddRemoveConstant) { auto cu = std::make_shared(); auto cls = ClassType::create("foo.bar", cu); cls->addConstant("const1", IValue(1)); cls->addConstant("const2", IValue(2)); - cls->addConstant("const3", IValue(2)); + cls->addConstant("const3", IValue(3)); ASSERT_EQ(cls->numConstants(), 3); ASSERT_TRUE(cls->hasConstant("const1")); ASSERT_TRUE(cls->hasConstant("const2")); @@ -46,7 +47,7 @@ void testClassTypeAddRemoveConstant() { ASSERT_EQ(cls->getConstant("const1").toInt(), 1); ASSERT_EQ(cls->getConstant("const2").toInt(), 2); - ASSERT_EQ(cls->getConstant("const2").toInt(), 3); + ASSERT_EQ(cls->getConstant("const3").toInt(), 3); cls->unsafeRemoveConstant("const2"); ASSERT_TRUE(cls->hasConstant("const1")); diff --git a/test/cpp/jit/test_constant_pooling.cpp b/test/cpp/jit/test_constant_pooling.cpp index c8cb58e1886a9..6f81e5db907bd 100644 --- a/test/cpp/jit/test_constant_pooling.cpp +++ b/test/cpp/jit/test_constant_pooling.cpp @@ -79,14 +79,36 @@ graph(): ConstantPooling(graph); testing::FileCheck() .check_count( - "Float(2:1, requires_grad=0, device=cpu) = prim::Constant", + "Float(2, strides=[1], requires_grad=0, device=cpu) = prim::Constant", 1, /*exactly*/ true) ->check_count( - "Long(2:1, requires_grad=0, device=cpu) = prim::Constant", + "Long(2, strides=[1], requires_grad=0, device=cpu) = prim::Constant", 1, /*exactly*/ true) ->run(*graph); } + +TEST(ConstantPoolingTest, DictConstantPooling) { + auto graph = std::make_shared(); + parseIR( + R"IR( +graph(): + %0 : int = prim::Constant[value=1]() # test/elias.py:6:9 + %1 : int = prim::Constant[value=2]() # test/elias.py:6:12 + %a.1 : Dict(int, int) = prim::DictConstruct(%0, %1) + %b.1 : Dict(int, int) = prim::DictConstruct(%1, %1) + return (%a.1, %b.1) + )IR", + &*graph); + ConstantPropagation(graph); + ConstantPooling(graph); + testing::FileCheck() + .check_count( + "Dict(int, int) = prim::Constant", + 2, + /*exactly*/ true) + ->run(*graph); +} } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_custom_class.cpp b/test/cpp/jit/test_custom_class.cpp index a96a3b4a5635f..aad95d61a5e6a 100644 --- a/test/cpp/jit/test_custom_class.cpp +++ b/test/cpp/jit/test_custom_class.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include @@ -44,5 +45,89 @@ TEST(CustomClassTest, TorchbindIValueAPI) { test_with_obj(new_stack_ivalue, "boo"); } +class TorchBindTestClass : public torch::jit::CustomClassHolder { + public: + std::string get() { + return "Hello, I am your test custom class"; + } +}; + +constexpr char class_doc_string[] = R"( + I am docstring for TorchBindTestClass + Args: + What is an argument? Oh never mind, I don't take any. + + Return: + How would I know? I am just a holder of some meaningless test methods. + )"; +constexpr char method_doc_string[] = + "I am docstring for TorchBindTestClass get_with_docstring method"; + +namespace { +static auto reg = + torch::class_( + "_TorchBindTest", + "_TorchBindTestClass", + class_doc_string) + .def("get", &TorchBindTestClass::get) + .def("get_with_docstring", &TorchBindTestClass::get, method_doc_string); + +} // namespace + +// Tests DocString is properly propagated when defining CustomClasses. +TEST(CustomClassTest, TestDocString) { + auto class_type = getCustomClass( + "__torch__.torch.classes._TorchBindTest._TorchBindTestClass"); + AT_ASSERT(class_type); + AT_ASSERT(class_type->doc_string() == class_doc_string); + + AT_ASSERT(class_type->getMethod("get").doc_string().empty()); + AT_ASSERT( + class_type->getMethod("get_with_docstring").doc_string() == + method_doc_string); +} + +TEST(CustomClassTest, Serialization) { + script::Module m("m"); + + // test make_custom_class API + auto custom_class_obj = make_custom_class>( + std::vector{"foo", "bar"}); + m.register_attribute( + "s", + custom_class_obj.type(), + custom_class_obj, + /*is_parameter=*/false); + m.define(R"( + def forward(self): + return self.s.return_a_tuple() + )"); + + auto test_with_obj = [](script::Module& mod) { + auto res = mod.run_method("forward"); + auto tup = res.toTuple(); + AT_ASSERT(tup->elements().size() == 2); + auto i = tup->elements()[1].toInt(); + AT_ASSERT(i == 123); + }; + + auto frozen_m = torch::jit::freeze_module(m.clone()); + + test_with_obj(m); + test_with_obj(frozen_m); + + std::ostringstream oss; + m.save(oss); + std::istringstream iss(oss.str()); + caffe2::serialize::IStreamAdapter adapter{&iss}; + auto loaded_module = torch::jit::load(iss, torch::kCPU); + + std::ostringstream oss_frozen; + frozen_m.save(oss_frozen); + std::istringstream iss_frozen(oss_frozen.str()); + caffe2::serialize::IStreamAdapter adapter_frozen{&iss_frozen}; + auto loaded_frozen_module = torch::jit::load(iss_frozen, torch::kCPU); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_custom_class_registrations.cpp b/test/cpp/jit/test_custom_class_registrations.cpp index f563120bbc6c3..fc2d83d764094 100644 --- a/test/cpp/jit/test_custom_class_registrations.cpp +++ b/test/cpp/jit/test_custom_class_registrations.cpp @@ -33,6 +33,14 @@ struct Foo : torch::CustomClassHolder { } }; +struct LambdaInit : torch::CustomClassHolder { + int x, y; + LambdaInit(int x_, int y_) : x(x_), y(y_) {} + int64_t diff() { + return this->x - this->y; + } +}; + struct NoInit : torch::CustomClassHolder { int64_t x; }; @@ -202,6 +210,16 @@ TORCH_LIBRARY(_TorchScriptTesting, m) { .def("add", &Foo::add) .def("combine", &Foo::combine); + m.class_("_LambdaInit") + .def(torch::init([](int64_t x, int64_t y, bool swap) { + if (swap) { + return c10::make_intrusive(y, x); + } else { + return c10::make_intrusive(x, y); + } + })) + .def("diff", &LambdaInit::diff); + m.class_("_NoInit").def( "get_x", [](const c10::intrusive_ptr& self) { return self->x; }); diff --git a/test/cpp/jit/test_fuser.cpp b/test/cpp/jit/test_fuser.cpp index ef595215b8820..bff3ef4a32cd7 100644 --- a/test/cpp/jit/test_fuser.cpp +++ b/test/cpp/jit/test_fuser.cpp @@ -1,43 +1,40 @@ #include -#include -#include "ATen/core/interned_strings.h" -#include "torch/csrc/autograd/generated/variable_factories.h" -#include "torch/csrc/autograd/variable.h" -#include "torch/csrc/jit/codegen/fuser/interface.h" -#include "torch/csrc/jit/frontend/tracer.h" -#include "torch/csrc/jit/ir/alias_analysis.h" -#include "torch/csrc/jit/ir/attributes.h" -#include "torch/csrc/jit/ir/irparser.h" -#include "torch/csrc/jit/passes/common_subexpression_elimination.h" -#include "torch/csrc/jit/passes/constant_propagation.h" -#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h" -#include "torch/csrc/jit/passes/dead_code_elimination.h" -#include "torch/csrc/jit/passes/graph_fuser.h" -#include "torch/csrc/jit/passes/lower_grad_of.h" -#include "torch/csrc/jit/passes/lower_tuples.h" -#include "torch/csrc/jit/passes/requires_grad_analysis.h" -#include "torch/csrc/jit/passes/shape_analysis.h" -#include "torch/csrc/jit/passes/utils/subgraph_utils.h" -#include "torch/csrc/jit/runtime/argument_spec.h" -#include "torch/csrc/jit/runtime/autodiff.h" -#include "torch/csrc/jit/runtime/custom_operator.h" -#include "torch/csrc/jit/runtime/interpreter.h" -#include "torch/csrc/jit/runtime/symbolic_script.h" -#include "torch/csrc/jit/serialization/import.h" - -#include "torch/csrc/autograd/engine.h" -#include "torch/csrc/autograd/variable.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include -#include "ATen/core/ivalue.h" -#include "torch/csrc/jit/api/module.h" -#include "torch/csrc/jit/frontend/ir_emitter.h" -#include "torch/csrc/jit/runtime/graph_executor.h" -#include "onnx/onnx_pb.h" - -#include +#include #include @@ -57,6 +54,9 @@ namespace torch { namespace jit { TEST(FuserTest, TestSimple_CUDA) { +#if defined(FBCODE_CAFFE2) + return; +#endif const auto graph_string = R"IR( graph(%0 : Tensor, %1 : Tensor): @@ -77,6 +77,9 @@ TEST(FuserTest, TestSimple_CUDA) { } TEST(FuserTest, TestOne_CUDA) { +#if defined(FBCODE_CAFFE2) + return; +#endif auto testOne = [&](int ti, int tj) { const auto graph_string = R"IR( graph(%0 : Tensor, @@ -134,6 +137,9 @@ TEST(FuserTest, TestOne_CUDA) { } TEST(FuserTest, FusedConcat_CUDA) { +#if defined(FBCODE_CAFFE2) + return; +#endif const auto graph_string0 = R"IR( graph(%0 : Tensor, %1 : Tensor): @@ -177,6 +183,9 @@ TEST(FuserTest, FusedConcat_CUDA) { } TEST(FuserTest, FusionAliasing) { +#if defined(FBCODE_CAFFE2) + return; +#endif const auto graph_string = R"IR( graph(%0 : Tensor, %1 : Tensor): @@ -202,6 +211,10 @@ TEST(FuserTest, FusionAliasing) { } TEST(FuserTest, KernelCaching) { +#if defined(FBCODE_CAFFE2) + return; +#endif + // Constructs two functionally equivalent graphs const auto graph0_string = R"IR( graph(%0 : Float(2, 3, 4), diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index d18becfa66412..5086cb5b5148e 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -1,6 +1,5 @@ #if defined(USE_CUDA) - -#include +#include #include #include @@ -8,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -32,7 +32,7 @@ namespace torch { namespace jit { -using namespace torch::jit::fuser; +using namespace torch::jit::fuser::cuda; namespace { @@ -59,8 +59,13 @@ TensorView* makeConcreteTensor( // We can uncomment the below statement to test all tests with contiguous // tensors. return makeContigTensor(nDims, dtype); std::vector dom; - for (size_t i = 0; i < sizes.size(); i++) - dom.push_back(new IterDomain(new Int(0), new Int(sizes[i]))); + for (size_t i = 0; i < sizes.size(); i++) { + if (sizes[i] >= 0) { + dom.push_back(new IterDomain(new Int(0), new Int(sizes[i]))); + } else { + dom.push_back(new IterDomain(new Int(0), new Int())); + } + } return new TensorView(new TensorDomain(dom), dtype); } @@ -93,7 +98,7 @@ void checkIntValue( // (These tests exercise IrGraphGenerator through a non-trivial IR, // to make sure that it runs w/o crashing. The actual output is not // validated) -void testGPU_IrGraphGenerator() { +TEST(NVFuserTest, IrGraphGenerator_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -145,7 +150,7 @@ void testGPU_IrGraphGenerator() { .empty()); } -void testGPU_FusionDispatch() { +TEST(NVFuserTest, FusionDispatch_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -160,7 +165,7 @@ void testGPU_FusionDispatch() { } // Evaluate basic scalar operations with constant values -void testGPU_FusionExprEvalConstants() { +TEST(NVFuserTest, FusionExprEvalConstants_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -177,7 +182,7 @@ void testGPU_FusionExprEvalConstants() { } // Evaluate basic scalar operations with bound values -void testGPU_FusionExprEvalBindings() { +TEST(NVFuserTest, FusionExprEvalBindings_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -222,7 +227,7 @@ void testGPU_FusionExprEvalBindings() { } // Evaluate expressions in a simple IR -void testGPU_FusionExprEvalBasic() { +TEST(NVFuserTest, FusionExprEvalBasic_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -278,7 +283,7 @@ void testGPU_FusionExprEvalBasic() { } // Evaluate expressions in a more complex IR -void testGPU_FusionExprEvalComplex() { +TEST(NVFuserTest, FusionExprEvalComplex_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -330,7 +335,7 @@ void testGPU_FusionExprEvalComplex() { } // Evaluate expressions post lowering -void testGPU_FusionExprEvalPostLower() { +TEST(NVFuserTest, FusionExprEvalPostLower_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -387,7 +392,7 @@ void testGPU_FusionExprEvalPostLower() { checkIntValue(evaluator, tid_x, 128); } -void testGPU_FusionClear() { +TEST(NVFuserTest, FusionClear_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -457,7 +462,7 @@ void testGPU_FusionClear() { at::Tensor input1 = at::randn({16, 8, 8}, options); at::Tensor input2 = at::randn_like(input1); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input1, input2}); @@ -467,7 +472,7 @@ void testGPU_FusionClear() { TORCH_CHECK(output_ref.equal(outputs[0])); } -void testGPU_FusionCopy() { +TEST(NVFuserTest, FusionCopy_CUDA) { Fusion original_fusion; // Create the test IR @@ -541,7 +546,7 @@ void testGPU_FusionCopy() { ASSERT_EQ(original_kernel, clone_kernel); } -void testGPU_FusionMove() { +TEST(NVFuserTest, FusionMove_CUDA) { Fusion fusion; // Create the test IR @@ -611,7 +616,7 @@ void testGPU_FusionMove() { ASSERT_EQ(lowered_ir.str(), moved_lowered_ir.str()); } -void testGPU_FusionSimpleArith() { +TEST(NVFuserTest, FusionSimpleArith_CUDA) { std::stringstream ss1, ss2; Fusion fusion; @@ -640,7 +645,7 @@ void testGPU_FusionSimpleArith() { "Error where explicit add nodes don't match implicit add nodes."); } -void testGPU_FusionSimpleTypePromote() { +TEST(NVFuserTest, FusionSimpleTypePromote_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -663,7 +668,7 @@ class ZeroMutator : public OptOutMutator { } }; -void testGPU_FusionMutator() { +TEST(NVFuserTest, FusionMutator_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -681,7 +686,7 @@ void testGPU_FusionMutator() { TORCH_CHECK(flhs->value().value() == 0.f); } -void testGPU_FusionRegister() { +TEST(NVFuserTest, FusionRegister_CUDA) { Fusion fusion; FusionGuard fg(&fusion); Float* v1 = new Float{1.f}; @@ -712,7 +717,7 @@ struct DummyExpr : public Expr { DummyExpr& operator=(DummyExpr&& other) = delete; }; -void testGPU_FusionTopoSort() { +TEST(NVFuserTest, FusionTopoSort_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -779,7 +784,7 @@ void testGPU_FusionTopoSort() { TORCH_CHECK(fusion.origin(v6)->name() == 3); } -void testGPU_FusionTensor() { +TEST(NVFuserTest, FusionTensor_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); Fusion fusion; @@ -843,7 +848,7 @@ void testGPU_FusionTensor() { } } -void testGPU_FusionFilterVals() { +TEST(NVFuserTest, FusionFilterVals_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -881,7 +886,7 @@ void testGPU_FusionFilterVals() { "Not expecting any results"); } -void testGPU_FusionTVSplit() { +TEST(NVFuserTest, FusionTVSplit_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -907,7 +912,7 @@ void testGPU_FusionTVSplit() { static_cast(inner->extent())->value().value() == 2); } -void testGPU_FusionTVMerge() { +TEST(NVFuserTest, FusionTVMerge_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -925,7 +930,7 @@ void testGPU_FusionTVMerge() { tv->getRootDomain()[2]->extent()); } -void testGPU_FusionTVReorder() { +TEST(NVFuserTest, FusionTVReorder_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -972,7 +977,7 @@ void testGPU_FusionTVReorder() { TORCH_CHECK(ref[1]->sameAs(tv->axis(1))); } -void testGPU_FusionEquality() { +TEST(NVFuserTest, FusionEquality_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1013,7 +1018,7 @@ void testGPU_FusionEquality() { TORCH_CHECK(!neg1->sameAs(neg2)); } -void testGPU_FusionDependency() { +TEST(NVFuserTest, FusionDependency_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1083,15 +1088,15 @@ void testGPU_FusionDependency() { TORCH_CHECK(dep_chain.empty()); } -void testGPU_FusionParser() { +TEST(NVFuserTest, FusionParser_CUDA) { auto g = std::make_shared(); const auto graph0_string = R"IR( - graph(%0 : Float(2:1), - %1 : Float(2:1)): - %c0 : Float(2:1) = aten::mul(%0, %1) - %d0 : Float(2:1) = aten::mul(%c0, %0) + graph(%0 : Float(2, strides=[1]), + %1 : Float(2, strides=[1])): + %c0 : Float(2, strides=[1]) = aten::mul(%0, %1) + %d0 : Float(2, strides=[1]) = aten::mul(%c0, %0) return (%d0))IR"; - torch::jit::parseIR(graph0_string, g.get()); + parseIR(graph0_string, g.get()); // strides are not yet supported in the irparser. for (auto val : g->block()->inputs()) { @@ -1105,12 +1110,12 @@ void testGPU_FusionParser() { } } - auto fusion = fuser::cuda::parseJitIR(g); + auto fusion = parseJitIR(g); FusionGuard fg(fusion.get()); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({16}, options); at::Tensor input2 = at::randn({16}, options); - fuser::cuda::scheduleFusion(fusion.get(), {input1, input2}); + scheduleFusion(fusion.get(), {input1, input2}); // CONSIDER: // 1. this can be moved to a dedicated "golden" file @@ -1156,14 +1161,14 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Te TORCH_CHECK(false); } - cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(fusion.get()); auto outputs = fe.runFusion({input1, input2}); at::Tensor output_ref = input1 * input2 * input1; TORCH_CHECK(output_ref.equal(outputs[0])); } -void testGPU_FusionForLoop() { +TEST(NVFuserTest, FusionForLoop_CUDA) { // TODO(kir): re-enable this test // due to the current "GpuLower guard" approach, we can only create // kernel IR during GpuLower::lower() @@ -1204,7 +1209,7 @@ void testGPU_FusionForLoop() { #endif } -void testGPU_FusionCodeGen() { +TEST(NVFuserTest, FusionCodeGen_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1231,7 +1236,7 @@ void testGPU_FusionCodeGen() { at::Tensor output = at::empty({16, 8, 8}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({}, {output}); @@ -1241,7 +1246,7 @@ void testGPU_FusionCodeGen() { TORCH_CHECK(output_ref.equal(output)); } -void testGPU_FusionCodeGen2() { +TEST(NVFuserTest, FusionCodeGen2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1273,7 +1278,7 @@ void testGPU_FusionCodeGen2() { at::Tensor input1 = at::randn({16, 8, 8}, options); at::Tensor input2 = at::randn_like(input1); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input1, input2}); @@ -1283,7 +1288,7 @@ void testGPU_FusionCodeGen2() { TORCH_CHECK(output_ref.equal(outputs[0])); } -void testGPU_FusionSimplePWise() { +TEST(NVFuserTest, FusionSimplePWise_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // dimensionality of the problem @@ -1330,7 +1335,7 @@ void testGPU_FusionSimplePWise() { at::Tensor input2 = at::rand_like(input1); at::Tensor output = at::empty_like(input1); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input1, input2}, {output}); @@ -1340,7 +1345,7 @@ void testGPU_FusionSimplePWise() { TORCH_CHECK(output_ref.equal(output)); } -void testGPU_FusionExecKernel() { +TEST(NVFuserTest, FusionExecKernel_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1381,7 +1386,7 @@ void testGPU_FusionExecKernel() { at::Tensor input1 = at::ones({1, 128}, options); at::Tensor input2 = at::ones_like(input1); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input1, input2}); @@ -1394,7 +1399,7 @@ int ceilDiv_(int a, int b) { return (a + b - 1) / b; } -void testGPU_FusionAdvancedComputeAt() { +TEST(NVFuserTest, FusionAdvancedComputeAt1_CUDA) { // Case 1 // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 @@ -1403,74 +1408,74 @@ void testGPU_FusionAdvancedComputeAt() { // tv5 = tv3 + tv2 // tv6 = tv5 + tv4 // tv7 = tv1 + tv4 - { - Fusion fusion; - FusionGuard fg(&fusion); + Fusion fusion; + FusionGuard fg(&fusion); - TensorView* tv0 = makeDummyTensor(2); - fusion.addInput(tv0); + TensorView* tv0 = makeDummyTensor(2); + fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Float(0.5)); - TensorView* tv2 = mul(tv1, new Float(-1.0)); - TensorView* tv3 = add(tv1, new Float(3.0)); - TensorView* tv4 = mul(tv1, new Float(2.0)); - TensorView* tv5 = add(tv3, tv2); + TensorView* tv1 = mul(tv0, new Float(0.5)); + TensorView* tv2 = mul(tv1, new Float(-1.0)); + TensorView* tv3 = add(tv1, new Float(3.0)); + TensorView* tv4 = mul(tv1, new Float(2.0)); + TensorView* tv5 = add(tv3, tv2); - TensorView* tv6 = add(tv5, tv4); - TensorView* tv7 = add(tv1, tv4); + TensorView* tv6 = add(tv5, tv4); + TensorView* tv7 = add(tv1, tv4); - fusion.addOutput(tv6); - fusion.addOutput(tv7); + fusion.addOutput(tv6); + fusion.addOutput(tv7); - // Lets setup to actually run - tv7->merge(0); - tv7->split(0, 128); - tv7->split(0, 4); + // Lets setup to actually run + tv7->merge(0); + tv7->split(0, 128); + tv7->split(0, 4); - tv7->axis(0)->parallelize(ParallelType::BIDx); + tv7->axis(0)->parallelize(ParallelType::BIDx); - tv0->computeAt(tv7, 1); - - TORCH_CHECK(tv1->hasComputeAt() && tv1->nDims() == 3); - TORCH_CHECK(tv2->getComputeAtView() == tv5 && tv2->nDims() == 3); - TORCH_CHECK(tv3->getComputeAtView() == tv5 && tv3->nDims() == 3); - TORCH_CHECK(tv4->hasComputeAt() && tv4->nDims() == 3); - TORCH_CHECK(tv5->getComputeAtView() == tv6 && tv5->nDims() == 3); - TORCH_CHECK(tv6->getComputeAtView() == tv7 && tv6->nDims() == 3); - TORCH_CHECK(!tv7->hasComputeAt()); - - for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } + tv0->computeAt(tv7, 1); + + TORCH_CHECK(tv1->hasComputeAt() && tv1->nDims() == 3); + TORCH_CHECK(tv2->getComputeAtView() == tv5 && tv2->nDims() == 3); + TORCH_CHECK(tv3->getComputeAtView() == tv5 && tv3->nDims() == 3); + TORCH_CHECK(tv4->hasComputeAt() && tv4->nDims() == 3); + TORCH_CHECK(tv5->getComputeAtView() == tv6 && tv5->nDims() == 3); + TORCH_CHECK(tv6->getComputeAtView() == tv7 && tv6->nDims() == 3); + TORCH_CHECK(!tv7->hasComputeAt()); + + for (Val* val : fusion.vals()) { + if (!fusion.hasInput(val) && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); } + } - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({129, 127}, options); + at::Tensor t0 = at::randn({129, 127}, options); - auto t1 = t0.mul({0.5}); - auto t2 = t1.mul({-1.0}); - auto t3 = t1.add({3.0}); - auto t4 = t1.mul({2.0}); - auto t5 = t3.add(t2); - auto t6 = t5.add(t4); - auto t7 = t1.add(t4); + auto t1 = t0.mul({0.5}); + auto t2 = t1.mul({-1.0}); + auto t3 = t1.add({3.0}); + auto t4 = t1.mul({2.0}); + auto t5 = t3.add(t2); + auto t6 = t5.add(t4); + auto t7 = t1.add(t4); - at::Tensor kernel_tv6 = at::empty_like(t0, options); - at::Tensor kernel_tv7 = at::empty_like(t0, options); + at::Tensor kernel_tv6 = at::empty_like(t0, options); + at::Tensor kernel_tv7 = at::empty_like(t0, options); - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); - fe.runFusion({t0}, {kernel_tv6, kernel_tv7}); + FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({t0}, {kernel_tv6, kernel_tv7}); - TORCH_CHECK(at::allclose(kernel_tv6, t6)); - TORCH_CHECK(at::allclose(kernel_tv7, t7)); - } + TORCH_CHECK(at::allclose(kernel_tv6, t6)); + TORCH_CHECK(at::allclose(kernel_tv7, t7)); +} +TEST(NVFuserTest, FusionAdvancedComputeAt2_CUDA) { // Case 2 // tv1 = tv0 * -1 // tv2 = tv0 + 3 @@ -1478,222 +1483,255 @@ void testGPU_FusionAdvancedComputeAt() { // tv4 = tv2 + tv1 // tv5 = tv4 + tv3 // tv6 = tv5 + tv3 - { - Fusion fusion; - FusionGuard fg(&fusion); + Fusion fusion; + FusionGuard fg(&fusion); - TensorView* tv0 = makeDummyTensor(2); - fusion.addInput(tv0); + TensorView* tv0 = makeDummyTensor(2); + fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Float(-1.0)); - TensorView* tv2 = add(tv0, new Float(3.0)); - TensorView* tv3 = mul(tv0, new Float(2.0)); - TensorView* tv4 = add(tv2, tv1); + TensorView* tv1 = mul(tv0, new Float(-1.0)); + TensorView* tv2 = add(tv0, new Float(3.0)); + TensorView* tv3 = mul(tv0, new Float(2.0)); + TensorView* tv4 = add(tv2, tv1); - TensorView* tv5 = add(tv4, tv3); - TensorView* tv6 = add(tv5, tv3); + TensorView* tv5 = add(tv4, tv3); + TensorView* tv6 = add(tv5, tv3); - fusion.addOutput(tv5); - fusion.addOutput(tv6); + fusion.addOutput(tv5); + fusion.addOutput(tv6); - // Lets setup to actually run - tv6->merge(0); - tv6->split(0, 128); - tv6->split(0, 4); + // Lets setup to actually run + tv6->merge(0); + tv6->split(0, 128); + tv6->split(0, 4); - tv6->axis(0)->parallelize(ParallelType::BIDx); + tv6->axis(0)->parallelize(ParallelType::BIDx); - tv0->computeAt(tv6, 1); + tv0->computeAt(tv6, 1); - for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); + for (Val* val : fusion.vals()) { + if (!fusion.hasInput(val) && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); } + } - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({129, 127}, options); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({129, 127}, options); - auto t1 = t0.mul({-1.0}); - auto t2 = t0.add({3.0}); - auto t3 = t0.mul({2.0}); - auto t4 = t2.add(t1); - auto t5 = t4.add(t3); - auto t6 = t5.add(t3); + auto t1 = t0.mul({-1.0}); + auto t2 = t0.add({3.0}); + auto t3 = t0.mul({2.0}); + auto t4 = t2.add(t1); + auto t5 = t4.add(t3); + auto t6 = t5.add(t3); - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); - auto outputs = fe.runFusion({t0}); + FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0}); - TORCH_CHECK(at::allclose(outputs[0], t5)); - TORCH_CHECK(at::allclose(outputs[1], t6)); - } + TORCH_CHECK(at::allclose(outputs[0], t5)); + TORCH_CHECK(at::allclose(outputs[1], t6)); +} +TEST(NVFuserTest, FusionAdvancedComputeAt3_CUDA) { // Case 3 // T2 = T1 * 0.979361 // T3 = T2 * T0 - { - Fusion fusion; - FusionGuard fg(&fusion); + Fusion fusion; + FusionGuard fg(&fusion); - TensorView* tv0 = makeDummyTensor(4); - fusion.addInput(tv0); + TensorView* tv0 = makeDummyTensor(4); + fusion.addInput(tv0); - TensorView* tv1 = makeDummyTensor(4); - fusion.addInput(tv1); + TensorView* tv1 = makeDummyTensor(4); + fusion.addInput(tv1); - TensorView* tv2 = mul(tv1, new Float(.979361)); - TensorView* tv3 = mul(tv2, tv0); + TensorView* tv2 = mul(tv1, new Float(.979361)); + TensorView* tv3 = mul(tv2, tv0); - fusion.addOutput(tv3); + fusion.addOutput(tv3); - // Lets setup to actually run - while (tv3->nDims() > 1) - tv3->merge(0); - tv3->split(0, 128); - tv3->split(0, 4); + // Lets setup to actually run + while (tv3->nDims() > 1) + tv3->merge(0); + tv3->split(0, 128); + tv3->split(0, 4); - tv0->computeAt(tv3, 1); - tv1->computeAt(tv3, 1); + tv0->computeAt(tv3, 1); + tv1->computeAt(tv3, 1); - tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); - for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); + for (Val* val : fusion.vals()) { + if (!fusion.hasInput(val) && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); } + } - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({129, 127, 63, 65}, options); - at::Tensor t1 = at::rand_like(t0, options); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({129, 127, 63, 65}, options); + at::Tensor t1 = at::rand_like(t0, options); - auto t2 = t1.mul({0.979361}); - auto t3 = t2.mul(t0); + auto t2 = t1.mul({0.979361}); + auto t3 = t2.mul(t0); - at::Tensor kernel_tv3 = at::empty_like(t0, options); + at::Tensor kernel_tv3 = at::empty_like(t0, options); - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); - fe.runFusion({t0, t1}, {kernel_tv3}); + FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({t0, t1}, {kernel_tv3}); - TORCH_CHECK(at::allclose(kernel_tv3, t3)); - } + TORCH_CHECK(at::allclose(kernel_tv3, t3)); +} +TEST(NVFuserTest, FusionAdvancedComputeAt4_CUDA) { // Case 4 // T4 = T2 - T3 // T5 = T1 + T4 // T6 = T5 - T0 - { - Fusion fusion; - FusionGuard fg(&fusion); + Fusion fusion; + FusionGuard fg(&fusion); - TensorView* tv0 = makeDummyTensor(4); - fusion.addInput(tv0); + TensorView* tv0 = makeDummyTensor(4); + fusion.addInput(tv0); - TensorView* tv1 = makeDummyTensor(4); - fusion.addInput(tv1); + TensorView* tv1 = makeDummyTensor(4); + fusion.addInput(tv1); - TensorView* tv2 = makeDummyTensor(4); - fusion.addInput(tv2); + TensorView* tv2 = makeDummyTensor(4); + fusion.addInput(tv2); - TensorView* tv3 = makeDummyTensor(4); - fusion.addInput(tv3); + TensorView* tv3 = makeDummyTensor(4); + fusion.addInput(tv3); - TensorView* tv4 = sub(tv2, tv3); - TensorView* tv5 = add(tv1, tv4); - TensorView* tv6 = sub(tv5, tv0); + TensorView* tv4 = sub(tv2, tv3); + TensorView* tv5 = add(tv1, tv4); + TensorView* tv6 = sub(tv5, tv0); - fusion.addOutput(tv6); + fusion.addOutput(tv6); - // Lets setup to actually run - while (tv6->nDims() > 1) - tv6->merge(0); - tv6->split(0, 128); - tv6->split(0, 4); + // Lets setup to actually run + while (tv6->nDims() > 1) + tv6->merge(0); + tv6->split(0, 128); + tv6->split(0, 4); - tv0->computeAt(tv6, 1); - tv1->computeAt(tv6, 1); - tv2->computeAt(tv6, 1); - tv3->computeAt(tv6, 1); + tv0->computeAt(tv6, 1); + tv1->computeAt(tv6, 1); + tv2->computeAt(tv6, 1); + tv3->computeAt(tv6, 1); - tv6->axis(0)->parallelize(ParallelType::BIDx); + tv6->axis(0)->parallelize(ParallelType::BIDx); - for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); + for (Val* val : fusion.vals()) { + if (!fusion.hasInput(val) && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); } + } - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({129, 127, 63, 65}, options); - at::Tensor t1 = at::rand_like(t0, options); - at::Tensor t2 = at::rand_like(t0, options); - at::Tensor t3 = at::rand_like(t0, options); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({129, 127, 63, 65}, options); + at::Tensor t1 = at::rand_like(t0, options); + at::Tensor t2 = at::rand_like(t0, options); + at::Tensor t3 = at::rand_like(t0, options); - auto t4 = t2.sub(t3); - auto t5 = t1.add(t4); - auto t6 = t5.sub(t0); + auto t4 = t2.sub(t3); + auto t5 = t1.add(t4); + auto t6 = t5.sub(t0); - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); - auto outputs = fe.runFusion({t0, t1, t2, t3}); + FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0, t1, t2, t3}); - TORCH_CHECK(at::allclose(outputs[0], t6)); - } + TORCH_CHECK(at::allclose(outputs[0], t6)); +} +TEST(NVFuserTest, FusionAdvancedComputeAt5_CUDA) { // Case 5 // tv2 = tv0 + 2.0 // tv3 = tv1 * tv2 - { - Fusion fusion; - FusionGuard fg(&fusion); + Fusion fusion; + FusionGuard fg(&fusion); - // Set up your input tensor views - TensorView* tv0 = makeDummyTensor(2); - fusion.addInput(tv0); - TensorView* tv1 = makeDummyTensor(2); - fusion.addInput(tv1); - TensorView* tv2 = add(tv0, new Float(2.0)); - TensorView* tv3 = mul(tv1, tv2); - fusion.addOutput(tv3); + // Set up your input tensor views + TensorView* tv0 = makeDummyTensor(2); + fusion.addInput(tv0); + TensorView* tv1 = makeDummyTensor(2); + fusion.addInput(tv1); + TensorView* tv2 = add(tv0, new Float(2.0)); + TensorView* tv3 = mul(tv1, tv2); + fusion.addOutput(tv3); - tv3->merge(0); - tv3->split(-1, 8); - tv3->split(-1, 4); + tv3->merge(0); + tv3->split(-1, 8); + tv3->split(-1, 4); - tv2->computeAt(tv3, 1); - tv2->split(-1, 4); // Kernel will break without this split - tv3->axis(0)->parallelize(ParallelType::BIDx); + tv2->computeAt(tv3, 1); + tv3->axis(0)->parallelize(ParallelType::BIDx); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({63, 65}, options); - at::Tensor t1 = at::rand_like(t0, options); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({63, 65}, options); + at::Tensor t1 = at::rand_like(t0, options); - auto t2 = t0.add(2.0); - auto t3 = t1.mul(t2); + auto t2 = t0.add(2.0); + auto t3 = t1.mul(t2); - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); - auto outputs = fe.runFusion({t0, t1}); + FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0, t1}); - TORCH_CHECK(at::allclose(outputs[0], t3)); - } + TORCH_CHECK(at::allclose(outputs[0], t3)); +} + +TEST(NVFuserTest, FusionAdvancedComputeAt6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeDummyTensor(2); + fusion.addInput(tv0); + TensorView* tv1 = makeDummyTensor(2); + fusion.addInput(tv1); + TensorView* tv2 = add(tv0, new Float(2.0)); + TensorView* tv3 = mul(tv1, tv2); + fusion.addOutput(tv3); + + tv2->merge(0); + tv2->split(-1, 8); + tv2->split(-1, 4); + tv3->merge(0); + tv3->split(-1, 8); + + tv2->computeAt(tv3, 1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({63, 65}, options); + at::Tensor t1 = at::rand_like(t0, options); + + auto t2 = t0.add(2.0); + auto t3 = t1.mul(t2); + + FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0, t1}); + + TORCH_CHECK(at::allclose(outputs[0], t3)); } -void testGPU_FusionComputeAtMultiConsumers() { +TEST(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) { // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 // tv3 = tv2 * -2 @@ -1745,7 +1783,7 @@ void testGPU_FusionComputeAtMultiConsumers() { at::Tensor kernel_tv2 = at::empty_like(t0, options); at::Tensor kernel_tv3 = at::empty_like(t0, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0}, {kernel_tv2, kernel_tv3}); @@ -1754,7 +1792,7 @@ void testGPU_FusionComputeAtMultiConsumers() { } // Similar to ComputeAtMultiConsumers, but with a common consumer. -void testGPU_FusionComputeAtCommonConsumer1() { +TEST(NVFuserTest, FusionComputeAtCommonConsumer1_CUDA) { // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 // tv3 = tv2 * -2 @@ -1816,7 +1854,7 @@ void testGPU_FusionComputeAtCommonConsumer1() { at::Tensor kernel_tv4 = at::empty_like(t0, options); at::Tensor kernel_tv5 = at::empty_like(t0, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0}, {kernel_tv3, kernel_tv4, kernel_tv5}); @@ -1825,7 +1863,7 @@ void testGPU_FusionComputeAtCommonConsumer1() { TORCH_CHECK(at::allclose(kernel_tv5, t5)); } -void testGPU_FusionComputeAtCommonConsumer2() { +TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) { // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 // tv3 = tv2 * -1 @@ -1903,7 +1941,7 @@ void testGPU_FusionComputeAtCommonConsumer2() { at::Tensor kernel_tv5 = at::empty_like(t0, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0}, {kernel_tv5}); @@ -1912,7 +1950,7 @@ void testGPU_FusionComputeAtCommonConsumer2() { // Similar to the above common consumer test but adds an additional // tensor that has no common consumer with the other tensors. -void testGPU_FusionComputeAtCommonConsumer3() { +TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) { // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 // tv3 = tv2 * -1 @@ -2000,7 +2038,7 @@ void testGPU_FusionComputeAtCommonConsumer3() { at::Tensor kernel_tv5 = at::empty_like(t0, options); at::Tensor kernel_tv6 = at::empty_like(t0, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0}, {kernel_tv5, kernel_tv6}); @@ -2010,7 +2048,7 @@ void testGPU_FusionComputeAtCommonConsumer3() { // Similar to ComputeAtCommonConsumer1 but with an addtiona ltensor // that does not have data dependency with the consumer. -void testGPU_FusionComputeAtNoCommonConsumer() { +TEST(NVFuserTest, FusionComputeAtNoCommonConsumer_CUDA) { // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 // tv3 = tv1 * -2 @@ -2073,7 +2111,7 @@ void testGPU_FusionComputeAtNoCommonConsumer() { at::Tensor kernel_tv5 = at::empty_like(t0, options); at::Tensor kernel_tv6 = at::empty_like(t0, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0}, {kernel_tv3, kernel_tv4, kernel_tv5, kernel_tv6}); @@ -2102,7 +2140,7 @@ void checkConcretized( } // namespace -void testGPU_FusionBCastConcretizeBasic() { +TEST(NVFuserTest, FusionBCastConcretizeBasic_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2132,7 +2170,7 @@ void testGPU_FusionBCastConcretizeBasic() { checkConcretized(tv2_0, 0, tv1, 1, false); } -void testGPU_FusionBCastConcretizeRfactor() { +TEST(NVFuserTest, FusionBCastConcretizeRfactor_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2181,7 +2219,7 @@ void checkIdProvedEquivalent( } // namespace -void testGPU_FusionProveIdEqBasic() { +TEST(NVFuserTest, FusionProveIdEqBasic_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2206,7 +2244,7 @@ void testGPU_FusionProveIdEqBasic() { checkIdProvedEquivalent(tv0, 0, tv1, 1, false); } -void testGPU_FusionProveIdEqRfactor() { +TEST(NVFuserTest, FusionProveIdEqRfactor_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2236,7 +2274,7 @@ void testGPU_FusionProveIdEqRfactor() { checkIdProvedEquivalent(tv3, 0, tv0, 0, true); } -void testGPU_FusionScalarInputs() { +TEST(NVFuserTest, FusionScalarInputs_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2309,7 +2347,7 @@ void testGPU_FusionScalarInputs() { at::Scalar test(fl0); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion( {t0, @@ -2323,7 +2361,7 @@ void testGPU_FusionScalarInputs() { TORCH_CHECK(at::allclose(kernel_tv4, t4)); } -void testGPU_FusionLoopUnroll() { +TEST(NVFuserTest, FusionLoopUnroll_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2368,7 +2406,7 @@ void testGPU_FusionLoopUnroll() { at::Tensor input0 = at::rand({129, 13, 3}, options); at::Tensor input1 = at::rand({129, 13, 3}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input0, input1}); @@ -2495,7 +2533,7 @@ void test_op( if (fusion.isStochastic()) at::manual_seed(0); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion(aten_inputs_ivalues, output_vect); cudaDeviceSynchronize(); @@ -2564,7 +2602,7 @@ void test_op( std::make_index_sequence{}); } -void testGPU_FusionUnaryOps() { +TEST(NVFuserTest, FusionUnaryOps_CUDA) { using OpTuple = std::tuple; @@ -2638,17 +2676,18 @@ void testGPU_FusionUnaryOps() { std::make_tuple(std::make_pair(ValType::TensorView, DataType::Float))); } -void testGPU_FusionBinaryOps() { +TEST(NVFuserTest, FusionBinaryOps_CUDA) { using AtenFuncSig = at::Tensor (*)(const at::Tensor&, const at::Tensor&); using OpTuple = std::tuple; // see [Note: explicit tuple type for uniform initialization list] - std::vector logic_ops{OpTuple{at::eq, BinaryOpType::Eq, "eq"}, - OpTuple{at::ge, BinaryOpType::GE, "ge"}, - OpTuple{at::gt, BinaryOpType::GT, "gt"}, - OpTuple{at::le, BinaryOpType::LE, "le"}, - OpTuple{at::lt, BinaryOpType::LT, "lt"}, - OpTuple{at::ne, BinaryOpType::NE, "ne"}}; + std::vector logic_ops{ + OpTuple{at::eq, BinaryOpType::Eq, "eq"}, + OpTuple{at::ge, BinaryOpType::GE, "ge"}, + OpTuple{at::gt, BinaryOpType::GT, "gt"}, + OpTuple{at::le, BinaryOpType::LE, "le"}, + OpTuple{at::lt, BinaryOpType::LT, "lt"}, + OpTuple{at::ne, BinaryOpType::NE, "ne"}}; std::for_each(logic_ops.begin(), logic_ops.end(), [](OpTuple& op) { test_op( @@ -2738,7 +2777,7 @@ void testGPU_FusionBinaryOps() { std::make_pair(ValType::Scalar, DataType::Float))); } -void testGPU_FusionTernaryOps() { +TEST(NVFuserTest, FusionTernaryOps_CUDA) { test_op( /*blocks*/ 640, /*threads*/ 64, @@ -2787,7 +2826,7 @@ void testGPU_FusionTernaryOps() { std::make_pair(ValType::TensorView, DataType::Float))); } -void testGPU_FusionCompoundOps() { +TEST(NVFuserTest, FusionCompoundOps_CUDA) { test_op( /*blocks*/ 640, /*threads*/ 64, @@ -2826,7 +2865,7 @@ void testGPU_FusionCompoundOps() { std::make_pair(ValType::Scalar, DataType::Float))); } -void testGPU_FusionCastOps() { +TEST(NVFuserTest, FusionCastOps_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2850,7 +2889,7 @@ void testGPU_FusionCastOps() { std::array inputs = {input1}; const at::ArrayRef input_ivalues(inputs); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion(input_ivalues); @@ -2868,7 +2907,7 @@ void testGPU_FusionCastOps() { // We want split/merge/reorder all tested both on and off rfactor domains, also // want compute at into the rfactor domain, and into its consumer -void testGPU_FusionRFactorReplay() { +TEST(NVFuserTest, FusionRFactorReplay_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2961,7 +3000,7 @@ void testGPU_FusionRFactorReplay() { // Start off simple, block on the outer dim // block stride + thread all reduce + unrolling on inner dim -void testGPU_FusionReduction() { +TEST(NVFuserTest, FusionReduction_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3011,7 +3050,7 @@ void testGPU_FusionReduction() { at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {cg_output}); @@ -3019,7 +3058,7 @@ void testGPU_FusionReduction() { TORCH_CHECK(aten_output.allclose(cg_output)); } -void testGPU_FusionReduction2() { +TEST(NVFuserTest, FusionReduction2_CUDA) { { Fusion fusion; FusionGuard fg(&fusion); @@ -3081,13 +3120,10 @@ void testGPU_FusionReduction2() { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); - c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); - AT_CUDA_CHECK(cudaStreamSynchronize(stream)); - auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(outputs[0])); } @@ -3134,19 +3170,16 @@ void testGPU_FusionReduction2() { at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {cg_output}); - c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); - AT_CUDA_CHECK(cudaStreamSynchronize(stream)); - auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(cg_output)); } } -void testGPU_FusionReduction3() { +TEST(NVFuserTest, FusionReduction3_CUDA) { { Fusion fusion; FusionGuard fg(&fusion); @@ -3205,19 +3238,16 @@ void testGPU_FusionReduction3() { at::Tensor t4 = at::rand({numel_x}, options); auto t5 = t3.mul(t4); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t1, t4}); - c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); - AT_CUDA_CHECK(cudaStreamSynchronize(stream)); - TORCH_CHECK( t5.allclose(outputs[0]), "Error of: ", t5.sub(outputs[0]).abs().max()); } } -void testGPU_FusionReduction4() { +TEST(NVFuserTest, FusionReduction4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3258,7 +3288,7 @@ void testGPU_FusionReduction4() { at::Tensor cg_output = at::empty({bidy, tidx}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {cg_output}); @@ -3269,7 +3299,7 @@ void testGPU_FusionReduction4() { aten_output.sub(cg_output).abs().max()); } -void testGPU_FusionReduction5() { +TEST(NVFuserTest, FusionReduction5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3322,7 +3352,7 @@ void testGPU_FusionReduction5() { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y, numel_z}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); @@ -3330,7 +3360,7 @@ void testGPU_FusionReduction5() { TORCH_CHECK(aten_output.allclose(outputs[0])); } -void testGPU_FusionReductionTFT() { +TEST(NVFuserTest, FusionReductionTFT_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3376,18 +3406,15 @@ void testGPU_FusionReductionTFT() { at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {cg_output}); - c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); - AT_CUDA_CHECK(cudaStreamSynchronize(stream)); - auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(cg_output)); } -void testGPU_FusionBranches() { +TEST(NVFuserTest, FusionBranches_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3414,7 +3441,7 @@ void testGPU_FusionBranches() { at::Tensor t1 = at::randn({x, y}, options); at::Tensor t2 = at::randn({x, y}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; tv6->merge(0); tv6->split(0, 128); tv6->split(0, 4); @@ -3444,7 +3471,7 @@ void testGPU_FusionBranches() { TORCH_CHECK(t6.allclose(outputs[0])); } -void testGPU_FusionSimpleBCast() { +TEST(NVFuserTest, FusionSimpleBCast_CUDA) { { Fusion fusion; FusionGuard fg(&fusion); @@ -3491,7 +3518,7 @@ void testGPU_FusionSimpleBCast() { at::Tensor t6 = t4.expand({x, y, z}); at::Tensor t7 = t5.add(t6); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t2, t3}); @@ -3547,7 +3574,7 @@ void testGPU_FusionSimpleBCast() { at::Tensor cg_output = at::empty({x, y, z}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0, t1, t4}, {cg_output}); @@ -3597,7 +3624,7 @@ void testGPU_FusionSimpleBCast() { at::Tensor cg_output = at::empty({x, y, z}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0, t2}, {cg_output}); @@ -3647,7 +3674,7 @@ void testGPU_FusionSimpleBCast() { at::Tensor cg_output = at::empty({x, y, z}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0, t1}, {cg_output}); @@ -3696,7 +3723,7 @@ void testGPU_FusionSimpleBCast() { at::Tensor cg_output = at::empty({m, k, n}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0, t1}, {cg_output}); @@ -3708,7 +3735,7 @@ void testGPU_FusionSimpleBCast() { } } -void testGPU_FusionComplexBCast() { +TEST(NVFuserTest, FusionComplexBCast_CUDA) { { Fusion fusion; FusionGuard fg(&fusion); @@ -3755,7 +3782,7 @@ void testGPU_FusionComplexBCast() { auto t4 = t0.div(2.0).unsqueeze(-1).expand({y, z}) * t3; auto t7 = t4.unsqueeze(0).expand({x, y, z}) + t6; - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t3, t6}); @@ -3803,7 +3830,7 @@ void testGPU_FusionComplexBCast() { at::Tensor t4 = at::randn({x, y}, options); auto t5 = t3.add(t4); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t4}); @@ -3811,176 +3838,170 @@ void testGPU_FusionComplexBCast() { } } -void testGPU_FusionAdvancedIndexing() { - // Merging left to right is still broken in some instances. Indexing can't - // complete because we assume we can simply traverse consumer->producer in the - // index/extent map, but this case breaks this assumption. - { - Fusion fusion; - FusionGuard fg(&fusion); +TEST(NVFuserTest, FusionAdvancedIndexing1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); - int w = 3, x = 4, y = 7, z = 8; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + int w = 3, x = 4, y = 7, z = 8; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto tv0 = makeDummyTensor(3); - auto tv1 = makeDummyTensor(4); - fusion.addInput(tv0); - fusion.addInput(tv1); + auto tv0 = makeDummyTensor(3); + auto tv1 = makeDummyTensor(4); + fusion.addInput(tv0); + fusion.addInput(tv1); - auto tv2 = add(tv0, new Float(1.0)); - auto tv3 = broadcast(tv2, {true, false, false, false}); - auto tv4 = add(tv3, tv1); + auto tv2 = add(tv0, new Float(1.0)); + auto tv3 = broadcast(tv2, {true, false, false, false}); + auto tv4 = add(tv3, tv1); - fusion.addOutput(tv4); + fusion.addOutput(tv4); - tv4->merge(0); - tv4->merge(0); - tv4->merge(0); + tv4->merge(0); + tv4->merge(0); + tv4->merge(0); - tv4->split(0, 128); - tv4->split(0, 4); + tv4->split(0, 128); + tv4->split(0, 4); - tv2->computeAt(tv4, 1); + tv2->computeAt(tv4, 1); - tv4->axis(0)->parallelize(ParallelType::BIDx); - tv4->axis(1)->parallelize(ParallelType::Unroll); - tv4->axis(2)->parallelize(ParallelType::TIDx); + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::Unroll); + tv4->axis(2)->parallelize(ParallelType::TIDx); - tv3->axis(1)->parallelize(ParallelType::Unroll); - tv3->axis(2)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::Unroll); + tv3->axis(2)->parallelize(ParallelType::TIDx); - tv2->axis(1)->parallelize(ParallelType::Unroll); - tv2->axis(2)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::Unroll); + tv2->axis(2)->parallelize(ParallelType::TIDx); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; - at::Tensor t0 = at::randn({x, y, z}, options); - at::Tensor t1 = at::randn({w, x, y, z}, options); + at::Tensor t0 = at::randn({x, y, z}, options); + at::Tensor t1 = at::randn({w, x, y, z}, options); - fe.compileFusion(&fusion); - auto outputs = fe.runFusion({t0, t1}); + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0, t1}); - auto t3 = t0.add(1.0); - auto t4 = t3.add(t1); + auto t3 = t0.add(1.0); + auto t4 = t3.add(t1); - TORCH_CHECK(t4.allclose(outputs[0])); - } + TORCH_CHECK(t4.allclose(outputs[0])); +} - // Merging right to left actually does work. - { - Fusion fusion; - FusionGuard fg(&fusion); +TEST(NVFuserTest, FusionAdvancedIndexing2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); - int w = 3, x = 4, y = 7, z = 8; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + int w = 3, x = 4, y = 7, z = 8; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto tv0 = makeDummyTensor(3); - auto tv1 = makeDummyTensor(4); - fusion.addInput(tv0); - fusion.addInput(tv1); + auto tv0 = makeDummyTensor(3); + auto tv1 = makeDummyTensor(4); + fusion.addInput(tv0); + fusion.addInput(tv1); - auto tv2 = add(tv0, new Float(1.0)); - auto tv3 = broadcast(tv2, {true, false, false, false}); - auto tv4 = add(tv3, tv1); + auto tv2 = add(tv0, new Float(1.0)); + auto tv3 = broadcast(tv2, {true, false, false, false}); + auto tv4 = add(tv3, tv1); - fusion.addOutput(tv4); + fusion.addOutput(tv4); - tv4->merge(-2); - tv4->merge(-2); - tv4->merge(-2); + tv4->merge(-2); + tv4->merge(-2); + tv4->merge(-2); - tv4->split(0, 128); - tv4->split(0, 4); + tv4->split(0, 128); + tv4->split(0, 4); - tv2->computeAt(tv4, 1); + tv2->computeAt(tv4, 1); - tv4->axis(0)->parallelize(ParallelType::BIDx); - tv4->axis(1)->parallelize(ParallelType::Unroll); - tv4->axis(2)->parallelize(ParallelType::TIDx); + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::Unroll); + tv4->axis(2)->parallelize(ParallelType::TIDx); - tv3->axis(1)->parallelize(ParallelType::Unroll); - tv3->axis(2)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::Unroll); + tv3->axis(2)->parallelize(ParallelType::TIDx); - tv2->axis(1)->parallelize(ParallelType::Unroll); - tv2->axis(2)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::Unroll); + tv2->axis(2)->parallelize(ParallelType::TIDx); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; - at::Tensor t0 = at::randn({x, y, z}, options); - at::Tensor t1 = at::randn({w, x, y, z}, options); + at::Tensor t0 = at::randn({x, y, z}, options); + at::Tensor t1 = at::randn({w, x, y, z}, options); - fe.compileFusion(&fusion); - auto outputs = fe.runFusion({t0, t1}); + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0, t1}); - auto t3 = t0.add(1.0); - auto t4 = t3.add(t1); + auto t3 = t0.add(1.0); + auto t4 = t3.add(t1); - TORCH_CHECK(t4.allclose(outputs[0])); - } - // Same issue as the first one in this section - { - Fusion fusion; - FusionGuard fg(&fusion); + TORCH_CHECK(t4.allclose(outputs[0])); +} - int w = 3, x = 4, y = 7, z = 8; +TEST(NVFuserTest, FusionAdvancedIndexing3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); - auto tv0 = makeDummyTensor(3); - auto tv1 = makeDummyTensor(4); - fusion.addInput(tv0); - fusion.addInput(tv1); + int w = 3, x = 4, y = 7, z = 8; - auto tv2 = add(tv0, new Float(1.0)); - auto tv3 = add(tv2, tv1); - fusion.addOutput(tv3); + auto tv0 = makeDummyTensor(3); + auto tv1 = makeDummyTensor(4); + fusion.addInput(tv0); + fusion.addInput(tv1); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({x, y, z}, options); - at::Tensor t1 = at::randn({w, x, y, z}, options); + auto tv2 = add(tv0, new Float(1.0)); + auto tv3 = add(tv2, tv1); + fusion.addOutput(tv3); - fuser::cuda::scheduleFusion(&fusion, {t0, t1}); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({x, y, z}, options); + at::Tensor t1 = at::randn({w, x, y, z}, options); - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); - auto outputs = fe.runFusion({t0, t1}); + scheduleFusion(&fusion, {t0, t1}); - auto t2 = t0.add(1.0); - auto t3 = t2.add(t1); + FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0, t1}); - TORCH_CHECK(t3.allclose(outputs[0])); - } + auto t2 = t0.add(1.0); + auto t3 = t2.add(t1); - { - Fusion fusion; - FusionGuard fg(&fusion); + TORCH_CHECK(t3.allclose(outputs[0])); +} - // Set up your input tensor views - TensorView* tv0 = makeConcreteTensor({10, 20}); - fusion.addInput(tv0); - TensorView* tv1 = makeConcreteTensor({10, 10, 20}); - fusion.addInput(tv1); +TEST(NVFuserTest, FusionAdvancedIndexing4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); - TensorView* tv2 = add(tv0, new Float(1)); - TensorView* tv3 = broadcast(tv2, {true, false, false}); - TensorView* tv4 = add(tv3, tv1); - fusion.addOutput(tv4); + // Set up your input tensor views + TensorView* tv0 = makeConcreteTensor({10, 20}); + fusion.addInput(tv0); + TensorView* tv1 = makeConcreteTensor({10, 10, 20}); + fusion.addInput(tv1); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({10, 20}, options); - at::Tensor t1 = at::randn({10, 10, 20}, options); + TensorView* tv2 = add(tv0, new Float(1)); + TensorView* tv3 = broadcast(tv2, {true, false, false}); + TensorView* tv4 = add(tv3, tv1); + fusion.addOutput(tv4); - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); - auto outputs = fe.runFusion({t0, t1}); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({10, 20}, options); + at::Tensor t1 = at::randn({10, 10, 20}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0, t1}); - auto t2 = t0.add(1.0); - auto t3 = t2.add(t1); + auto t2 = t0.add(1.0); + auto t3 = t2.add(t1); - TORCH_CHECK(t3.allclose(outputs[0])); - } + TORCH_CHECK(t3.allclose(outputs[0])); } // Test a simple Gemm but also play around with fusion executor features -void testGPU_FusionSimpleGemm() { +TEST(NVFuserTest, FusionSimpleGemm_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4046,15 +4067,13 @@ void testGPU_FusionSimpleGemm() { at::Tensor t0 = at::randn({M, K}, options); at::Tensor t1 = at::randn({K, N}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); // Lets specify a few bounds in launch params to make sure it works - fe.runFusion( - {t0, t1}, torch::jit::fuser::cuda::LaunchParams(1, -1, -1, 32, 4, 4)); + fe.runFusion({t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4)); // Make sure bad launch params throws - ASSERT_ANY_THROW(fe.runFusion( - {t0, t1}, torch::jit::fuser::cuda::LaunchParams(1, 2, 3, 4, 5, 6))); + ASSERT_ANY_THROW(fe.runFusion({t0, t1}, LaunchParams(1, 2, 3, 4, 5, 6))); // Don't specify any launch params auto outputs = fe.runFusion({t0, t1}); @@ -4067,7 +4086,7 @@ void testGPU_FusionSimpleGemm() { } // Softmax with a 1D tensor. Parallelized only with a single thread block. -void testGPU_FusionSoftmax1D() { +TEST(NVFuserTest, FusionSoftmax1D_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4112,7 +4131,7 @@ void testGPU_FusionSoftmax1D() { at::Tensor cg_output = at::empty({dimx}, options); at::Tensor t3_output = at::empty_like(cg_output, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0}, {cg_output}); @@ -4124,7 +4143,7 @@ void testGPU_FusionSoftmax1D() { } // Softmax with a 1D tensor with input normalization. -void testGPU_FusionSoftmax1DNormalized() { +TEST(NVFuserTest, FusionSoftmax1DNormalized_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4166,13 +4185,14 @@ void testGPU_FusionSoftmax1DNormalized() { sub_tv3->computeAt(sum_exp_rf_tv9, -1); sub_tv3_copy->computeAt(output_tv7, -1); - TensorView* tensors_to_parallelize[] = {max_val_tv1, - bcast_max_tv2, - sum_exp_tv5, - bcast_sum_tv6, - output_tv7, - max_val_rf_tv8, - sum_exp_rf_tv9}; + TensorView* tensors_to_parallelize[] = { + max_val_tv1, + bcast_max_tv2, + sum_exp_tv5, + bcast_sum_tv6, + output_tv7, + max_val_rf_tv8, + sum_exp_rf_tv9}; for (auto tv : tensors_to_parallelize) { tv->axis(-1)->parallelize(ParallelType::TIDx); @@ -4182,7 +4202,7 @@ void testGPU_FusionSoftmax1DNormalized() { at::Tensor t0 = at::randn({dimx}, options); at::Tensor t3_output = at::empty({dimx}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0}); @@ -4195,7 +4215,7 @@ void testGPU_FusionSoftmax1DNormalized() { // Softmax with a 3D tensor, where the inner-most 3rd dimension is // normalized. Pallelized with multiple thread blocks. -void testGPU_FusionSoftmax3D() { +TEST(NVFuserTest, FusionSoftmax3D_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4243,7 +4263,7 @@ void testGPU_FusionSoftmax3D() { at::Tensor t0 = at::randn({dimx, dimy, dimz}, options); at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options); at::Tensor t3_output = at::empty_like(cg_output, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0}, {cg_output}); @@ -4255,7 +4275,7 @@ void testGPU_FusionSoftmax3D() { } // Softmax with a 3D tensor with input normalization. -void testGPU_FusionSoftmax3DNormalized() { +TEST(NVFuserTest, FusionSoftmax3DNormalized_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4300,13 +4320,14 @@ void testGPU_FusionSoftmax3DNormalized() { sub_tv3->computeAt(sum_exp_rf_tv9, -1); sub_tv3_copy->computeAt(output_tv7, -1); - TensorView* tensors_to_parallelize[] = {max_val_tv1, - bcast_max_tv2, - sum_exp_tv5, - bcast_sum_tv6, - output_tv7, - max_val_rf_tv8, - sum_exp_rf_tv9}; + TensorView* tensors_to_parallelize[] = { + max_val_tv1, + bcast_max_tv2, + sum_exp_tv5, + bcast_sum_tv6, + output_tv7, + max_val_rf_tv8, + sum_exp_rf_tv9}; for (auto tv : tensors_to_parallelize) { tv->axis(0)->parallelize(ParallelType::BIDx); @@ -4318,7 +4339,7 @@ void testGPU_FusionSoftmax3DNormalized() { at::Tensor t0 = at::randn({dimx, dimy, dimz}, options); at::Tensor t3_output = at::empty({dimx, dimy, dimz}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0}); @@ -4329,7 +4350,7 @@ void testGPU_FusionSoftmax3DNormalized() { t2.sub(outputs[0]).abs().max()); } -void testGPU_FusionSoftmaxComputeAt() { +TEST(NVFuserTest, FusionSoftmaxComputeAt_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4355,7 +4376,7 @@ void testGPU_FusionSoftmaxComputeAt() { } // Similar to FusionReduction but uses grid reduction -void testGPU_FusionGridReduction1() { +TEST(NVFuserTest, FusionGridReduction1_CUDA) { const int gdimx = 32; const int bdimx = 128; @@ -4404,7 +4425,7 @@ void testGPU_FusionGridReduction1() { at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {cg_output}); @@ -4413,7 +4434,7 @@ void testGPU_FusionGridReduction1() { } // Same test as the above but uses BIDy and TIDx for reduction -void testGPU_FusionGridReduction2() { +TEST(NVFuserTest, FusionGridReduction2_CUDA) { const int gdimy = 32; const int bdimx = 128; @@ -4459,7 +4480,7 @@ void testGPU_FusionGridReduction2() { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); @@ -4468,7 +4489,7 @@ void testGPU_FusionGridReduction2() { } // Same test but uses BIDy and BIDz for reduction. No TID used. -void testGPU_FusionGridReduction3dim1() { +TEST(NVFuserTest, FusionGridReduction3dim1_CUDA) { const int gdimz = 32; const int gdimy = 128; @@ -4515,7 +4536,7 @@ void testGPU_FusionGridReduction3dim1() { at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {cg_output}); @@ -4524,7 +4545,7 @@ void testGPU_FusionGridReduction3dim1() { } // Same as testGPU_FusionGridReduction3dim1 but reduces dimension 0 -void testGPU_FusionGridReduction3dim0() { +TEST(NVFuserTest, FusionGridReduction3dim0_CUDA) { const int rdim = 0; const int gdimy = 128; const int gdimz = 32; @@ -4568,7 +4589,7 @@ void testGPU_FusionGridReduction3dim0() { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); @@ -4577,7 +4598,7 @@ void testGPU_FusionGridReduction3dim0() { } // This is similar to the FusionReduction, but swaps BIDx and TIDx -void testGPU_FusionGridReduction4() { +TEST(NVFuserTest, FusionGridReduction4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4630,7 +4651,7 @@ void testGPU_FusionGridReduction4() { at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {cg_output}); @@ -4640,7 +4661,7 @@ void testGPU_FusionGridReduction4() { // Grid reduction with 2D thread blocks but only TIDx and BIDx are // mapped to a reduction dim -void testGPU_FusionGridReduction5() { +TEST(NVFuserTest, FusionGridReduction5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4683,7 +4704,7 @@ void testGPU_FusionGridReduction5() { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); @@ -4692,7 +4713,7 @@ void testGPU_FusionGridReduction5() { } // Similar to FusionGridReduction1 but with 3D tensors -void testGPU_FusionGridReduction6() { +TEST(NVFuserTest, FusionGridReduction6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4745,7 +4766,7 @@ void testGPU_FusionGridReduction6() { at::Tensor input = at::rand({numel_x, numel_y, numel_z}, options); at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {cg_output}); @@ -4753,7 +4774,7 @@ void testGPU_FusionGridReduction6() { TORCH_CHECK(aten_output.allclose(cg_output)); } -void testGPU_FusionNonRedAxisBind() { +TEST(NVFuserTest, FusionNonRedAxisBind_CUDA) { int bid_x = 3; int tid_x = 2; int red_dim = 0; @@ -4776,7 +4797,7 @@ void testGPU_FusionNonRedAxisBind() { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({16, bid_x * tid_x}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); @@ -4788,7 +4809,7 @@ void testGPU_FusionNonRedAxisBind() { aten_output.sub(outputs[0]).abs().max()); } -void testGPU_FusionSplitBCast() { +TEST(NVFuserTest, FusionSplitBCast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4831,12 +4852,12 @@ void testGPU_FusionSplitBCast() { at::Tensor t1 = at::randn({32, 32, 128}, options); at::Tensor cg_output = at::empty({32, 32, 128}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0, t1}, {cg_output}); } -void testGPU_FusionBCastInnerDim() { +TEST(NVFuserTest, FusionBCastInnerDim_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4850,7 +4871,7 @@ void testGPU_FusionBCastInnerDim() { TORCH_CHECK(!tv2->axis(0)->isReduction() && tv2->axis(1)->isBroadcast()); } -void testGPU_FusionBCastReduce() { +TEST(NVFuserTest, FusionBCastReduce_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4866,7 +4887,7 @@ void testGPU_FusionBCastReduce() { // Multiple consumer reduction with computeAt // https://github.com/csarofeen/pytorch/issues/110 -void testGPU_FusionReductionMultiConsumer() { +TEST(NVFuserTest, FusionReductionMultiConsumer_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(2); @@ -4883,77 +4904,76 @@ void testGPU_FusionReductionMultiConsumer() { tv1->getThisComputeAtAxis() == 2 && tv1->getRelativeComputeAtAxis() == 2); } -void testGPU_FusionComputeAtExprOrder() { - { - for (int i = 0; i < 2; ++i) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeDummyTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, new Float(1)); - auto tv2 = add(tv0, new Float(1)); - TensorView* tv3 = add(tv1, tv2); - if (i == 0) { - tv1->computeAt(tv3, -1); - fusion.addOutput(tv2); - } else { - tv2->computeAt(tv3, -1); - fusion.addOutput(tv1); - } - fusion.addOutput(tv3); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::rand({100}, options); - - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); - auto outputs = fe.runFusion({input}); - - auto aten_output = (input + 1) * 2; - TORCH_CHECK( - aten_output.allclose(outputs[1]), - "Error of: ", - aten_output.sub(outputs[1]).abs().max()); - } - } - { +TEST(NVFuserTest, FusionComputeAtExprOrder1_CUDA) { + for (int i = 0; i < 2; ++i) { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views - TensorView* tv0 = makeDummyTensor(2); + TensorView* tv0 = makeDummyTensor(1); fusion.addInput(tv0); auto tv1 = add(tv0, new Float(1)); auto tv2 = add(tv0, new Float(1)); TensorView* tv3 = add(tv1, tv2); + if (i == 0) { + tv1->computeAt(tv3, -1); + fusion.addOutput(tv2); + } else { + tv2->computeAt(tv3, -1); + fusion.addOutput(tv1); + } fusion.addOutput(tv3); - tv3->split(-1, 32); - - tv1->computeAt(tv3, -1); - tv2->computeAt(tv3, -2); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::rand({100, 100}, options); - at::Tensor output = at::empty_like(input, options); + at::Tensor input = at::rand({100}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); - fe.runFusion({input}, {output}); + auto outputs = fe.runFusion({input}); auto aten_output = (input + 1) * 2; TORCH_CHECK( - aten_output.allclose(output), + aten_output.allclose(outputs[1]), "Error of: ", - aten_output.sub(output).abs().max()); + aten_output.sub(outputs[1]).abs().max()); } } -void testGPU_FusionZeroDimComputeAt() { +TEST(NVFuserTest, FusionComputeAtExprOrder2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeDummyTensor(2); + fusion.addInput(tv0); + + auto tv1 = add(tv0, new Float(1)); + auto tv2 = add(tv0, new Float(1)); + TensorView* tv3 = add(tv1, tv2); + fusion.addOutput(tv3); + + tv3->split(-1, 32); + + tv1->computeAt(tv3, -1); + tv2->computeAt(tv3, -2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::rand({100, 100}, options); + at::Tensor output = at::empty_like(input, options); + + FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({input}, {output}); + + auto aten_output = (input + 1) * 2; + TORCH_CHECK( + aten_output.allclose(output), + "Error of: ", + aten_output.sub(output).abs().max()); +} + +TEST(NVFuserTest, FusionZeroDimComputeAt_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4969,7 +4989,7 @@ void testGPU_FusionZeroDimComputeAt() { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({100}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); @@ -4980,7 +5000,7 @@ void testGPU_FusionZeroDimComputeAt() { aten_output.sub(outputs[0]).abs().max()); } -void testGPU_FusionZeroDimBroadcast() { +TEST(NVFuserTest, FusionZeroDimBroadcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5004,7 +5024,7 @@ void testGPU_FusionZeroDimBroadcast() { at::Tensor input2 = at::rand({10, 10}, options); at::Tensor output = at::empty({}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input1, input2}, {output}); @@ -5016,7 +5036,7 @@ void testGPU_FusionZeroDimBroadcast() { aten_output.sub(output).abs().max()); } -void testGPU_FusionZeroDimReduction() { +TEST(NVFuserTest, FusionZeroDimReduction_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5042,7 +5062,7 @@ void testGPU_FusionZeroDimReduction() { at::Tensor input = at::rand({1000}, options); at::Tensor output = at::empty({}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {output}); @@ -5053,7 +5073,7 @@ void testGPU_FusionZeroDimReduction() { aten_output.sub(output).abs().max()); } -void testGPU_FusionBCastAfterReduce() { +TEST(NVFuserTest, FusionBCastAfterReduce_CUDA) { Fusion fusion; FusionGuard fg(&fusion); const int tidx = 128; @@ -5093,7 +5113,7 @@ void testGPU_FusionBCastAfterReduce() { at::Tensor t0 = at::randn({x, y}, options); at::Tensor t4 = at::randn({x, y}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t4}); @@ -5104,7 +5124,7 @@ void testGPU_FusionBCastAfterReduce() { TORCH_CHECK(t5.allclose(outputs[0], 1e-5, 1e-5)); } -void testGPU_FusionReductionScheduler() { +TEST(NVFuserTest, FusionReductionScheduler_CUDA) { constexpr int bid_x = 80; constexpr int tid_x = 4096; constexpr int red_dim = 1; @@ -5125,11 +5145,11 @@ void testGPU_FusionReductionScheduler() { at::Tensor input = at::randn({bid_x, tid_x}, options); // Apply reduction heuristic - auto reduction_params = cuda::getReductionHeuristics(&fusion, {input}, tv1); + auto reduction_params = getReductionHeuristics(&fusion, {input}, tv1); TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - cuda::scheduleReduction(&fusion, reduction_params.value(), tv1, {}); + scheduleReduction(&fusion, reduction_params.value(), tv1, {}); - cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); // no broadcasting needed, omitting the last optional argument; auto outputs = fe.runFusion({input}, reduction_params.value().lparams); @@ -5142,7 +5162,7 @@ void testGPU_FusionReductionScheduler() { } // Simple reduction parallelized on a symbolic size. -void testGPU_FusionSymbolicReduction() { +TEST(NVFuserTest, FusionSymbolicReduction_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5181,18 +5201,16 @@ void testGPU_FusionSymbolicReduction() { // How many threads to use for the block reduction int runtime_threadIdx_dim = 128; - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion( - {input}, - torch::jit::fuser::cuda::LaunchParams( - -1, -1, -1, runtime_threadIdx_dim, -1, -1)); + {input}, LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1)); auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(outputs[0])); } -void testGPU_FusionReductionSchedulerMultiDimNonFastest() { +TEST(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) { const std::vector red_dims = {0, 2}; // Copy is because CodeGen requires int and Pytorch requires int64_t // for a vector of reduction dimensions @@ -5216,11 +5234,11 @@ void testGPU_FusionReductionSchedulerMultiDimNonFastest() { at::Tensor cg_output = at::empty(tensor_dims_out, options); // Apply reduction heuristic - auto reduction_params = cuda::getReductionHeuristics(&fusion, {input}, tv1); + auto reduction_params = getReductionHeuristics(&fusion, {input}, tv1); TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - cuda::scheduleReduction(&fusion, reduction_params.value(), tv1, {}); + scheduleReduction(&fusion, reduction_params.value(), tv1, {}); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}, reduction_params.value().lparams); @@ -5232,7 +5250,7 @@ void testGPU_FusionReductionSchedulerMultiDimNonFastest() { aten_output.sub(outputs[0]).abs().max()); } -void testGPU_FusionReductionSchedulerMultiDimFastest() { +TEST(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) { const std::vector red_dims = {1, 3}; // Copy is because CodeGen requires int and Pytorch requires int64_t // for a vector of reduction dimensions @@ -5254,11 +5272,11 @@ void testGPU_FusionReductionSchedulerMultiDimFastest() { at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::randn(tensor_dims_in, options); - auto reduction_params = cuda::getReductionHeuristics(&fusion, {input}, tv1); + auto reduction_params = getReductionHeuristics(&fusion, {input}, tv1); TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - cuda::scheduleReduction(&fusion, reduction_params.value(), tv1, {}); + scheduleReduction(&fusion, reduction_params.value(), tv1, {}); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}, reduction_params.value().lparams); @@ -5270,12 +5288,16 @@ void testGPU_FusionReductionSchedulerMultiDimFastest() { aten_output.sub(outputs[0]).abs().max()); } -void testGPU_FusionReductionSchedulerDimShmoo() { +TEST(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) { std::vector fp16_usage = {true, false}; std::vector red_axis = {1, 0}; std::vector output_dims = {320, 640}; std::vector red_dims; + // Making sure we get deterministic results + // (see https://github.com/csarofeen/pytorch/issues/399) + at::manual_seed(0); + // Tried to cut down the number iterations with just // doing every other power of 2. for (int i = 1; i <= 1024 * 1024; i <<= 2) { @@ -5293,7 +5315,7 @@ void testGPU_FusionReductionSchedulerDimShmoo() { makeDummyTensor(2, (fp16 ? DataType::Half : DataType::Float)); fusion.addInput(tv0); - torch::jit::fuser::Val* tv0_cast = nullptr; + Val* tv0_cast = nullptr; if (fp16) { tv0_cast = castOp(DataType::Float, tv0); } @@ -5323,13 +5345,12 @@ void testGPU_FusionReductionSchedulerDimShmoo() { outputs_of_red.push_back(tv1_cast); } - auto reduction_params = - cuda::getReductionHeuristics(&fusion, {input}, tv1); + auto reduction_params = getReductionHeuristics(&fusion, {input}, tv1); TORCH_CHECK(reduction_params.has_value(), "Reduction is not found!"); - cuda::scheduleReduction( + scheduleReduction( &fusion, reduction_params.value(), tv1, outputs_of_red); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = @@ -5346,7 +5367,7 @@ void testGPU_FusionReductionSchedulerDimShmoo() { } } -void testGPU_FusionCacheBefore() { +TEST(NVFuserTest, FusionCacheBefore_CUDA) { // TVM Cache Write Fusion fusion; FusionGuard fg(&fusion); @@ -5376,7 +5397,7 @@ void testGPU_FusionCacheBefore() { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({M, N}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); @@ -5387,7 +5408,7 @@ void testGPU_FusionCacheBefore() { aten_output.sub(outputs[0]).abs().sum()); } -void testGPU_FusionCacheAfter() { +TEST(NVFuserTest, FusionCacheAfter_CUDA) { // TVM Cache Read Fusion fusion; FusionGuard fg(&fusion); @@ -5417,7 +5438,7 @@ void testGPU_FusionCacheAfter() { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({M, N}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); @@ -5428,7 +5449,7 @@ void testGPU_FusionCacheAfter() { aten_output.sub(outputs[0]).abs().sum()); } -void testGPU_FusionCacheIndirect() { +TEST(NVFuserTest, FusionCacheIndirect_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5466,7 +5487,7 @@ void testGPU_FusionCacheIndirect() { at::Tensor in2 = at::rand({M, N}, options); at::Tensor in3 = at::rand({M, N}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({in0, in1, in2, in3}); @@ -5477,7 +5498,7 @@ void testGPU_FusionCacheIndirect() { aten_output.sub(outputs[0]).abs().sum()); } -void testGPU_FusionCacheBcast() { +TEST(NVFuserTest, FusionCacheBcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5525,7 +5546,7 @@ void testGPU_FusionCacheBcast() { at::Tensor t0 = at::randn({M}, options); at::Tensor t1 = at::randn({N}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t1}); @@ -5536,7 +5557,7 @@ void testGPU_FusionCacheBcast() { aten_output.sub(outputs[0]).abs().max()); } -void testGPU_FusionCacheComplex() { +TEST(NVFuserTest, FusionCacheComplex_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5579,7 +5600,7 @@ void testGPU_FusionCacheComplex() { at::Tensor input1 = at::rand({N, N}, options); at::Tensor input2 = at::rand({N}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input1, input2}); @@ -5591,7 +5612,7 @@ void testGPU_FusionCacheComplex() { aten_output.sub(outputs[0]).abs().sum()); } -void testGPU_FusionCacheMultiConsumer() { +TEST(NVFuserTest, FusionCacheMultiConsumer_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5621,7 +5642,7 @@ void testGPU_FusionCacheMultiConsumer() { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({N}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); @@ -5636,7 +5657,7 @@ void testGPU_FusionCacheMultiConsumer() { aten_output.sub(outputs[1]).abs().sum()); } -void testGPU_FusionSmem() { +TEST(NVFuserTest, FusionSmem_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5679,7 +5700,7 @@ void testGPU_FusionSmem() { at::Tensor t0 = at::randn({M, N}, options); at::Tensor t1 = at::randn({M, N}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t1}); @@ -5691,7 +5712,7 @@ void testGPU_FusionSmem() { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0); } -void testGPU_FusionSmemReduce() { +TEST(NVFuserTest, FusionSmemReduce_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5729,7 +5750,7 @@ void testGPU_FusionSmemReduce() { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({M, K, N}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0}); @@ -5742,7 +5763,7 @@ void testGPU_FusionSmemReduce() { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(24) == 1); } -void testGPU_FusionSmemBlockGemm() { +TEST(NVFuserTest, FusionSmemBlockGemm_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5793,7 +5814,7 @@ void testGPU_FusionSmemBlockGemm() { at::Tensor t0 = at::randn({M, K}, options); at::Tensor t1 = at::randn({K, N}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t1}); @@ -5805,7 +5826,7 @@ void testGPU_FusionSmemBlockGemm() { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0); } -void testGPU_FusionSmemBlockGemmCache() { +TEST(NVFuserTest, FusionSmemBlockGemmCache_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5879,7 +5900,7 @@ void testGPU_FusionSmemBlockGemmCache() { at::Tensor t0 = at::randn({M, K}, options); at::Tensor t1 = at::randn({K, N}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t1}); @@ -5891,7 +5912,477 @@ void testGPU_FusionSmemBlockGemmCache() { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0); } -void testGPU_FusionSmemDynamicReductionSymbolic() { +TEST(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* x = makeDummyTensor(2); + fusion.addInput(x); + TensorView* max_val = + reductionOp(BinaryOpType::Max, {-1}, new Float(FLT_MIN), x); // (M) + TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B) + TensorView* x_max_sub = sub(x, bcast_max); // (M, N) + TensorView* exp = unaryOp(UnaryOpType::Exp, x_max_sub); // (M, N) + TensorView* sum_exp = sum(exp, {-1}); // (M, R) + TensorView* bcast_sum = broadcast(sum_exp, {false, true}); // (M, B) + TensorView* softmax = div(exp, bcast_sum); // (M, N) + fusion.addOutput(softmax); + + // Read Input into Shared Memory + // Load Input + Pwise into shared memory + auto cache_x = x->cache_after(); + cache_x->setMemoryType(MemoryType::Shared); + exp->setMemoryType(MemoryType::Shared); + + std::vector all_tensors( + {x, + cache_x, + max_val, + bcast_max, + x_max_sub, + exp, + sum_exp, + bcast_sum, + softmax}); + + auto tidx = new Int(); + fusion.addInput(tidx); + + for (auto tensor : all_tensors) { + tensor->split(-1, tidx); + } + + auto sum_exp_rf = sum_exp->rFactor({1}); + all_tensors.push_back(sum_exp_rf); + + // computeAt + x->computeAt(x_max_sub, 1); + exp->computeAt(softmax, 1); + x_max_sub->computeAt(exp, 2); + + softmax->axis(0)->parallelize(ParallelType::BIDx); + for (auto tensor : all_tensors) { + tensor->axis(-1)->parallelize(ParallelType::TIDx); + } + + const size_t dimx = 1024; + const size_t dimy = 4096; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({dimx, dimy}, options); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0, 128}); + + auto t1 = at::_softmax(t0, -1, false); + TORCH_CHECK( + t1.allclose(outputs[0], 1e-5, 1e-5), + "Error of: ", + t1.sub(outputs[0]).abs().max()); +} + +TEST(NVFuserTest, FusionPersistentSoftmaxLocalSmem_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int pixels_per_thread = 64; + const int TIDX = 128; + const int static_size = pixels_per_thread * TIDX; + + TensorView* sx = makeConcreteTensor({-1, static_size}); + TensorView* dx = makeDummyTensor(2); + fusion.addInput(sx); + fusion.addInput(dx); + + TensorView* max_sx = + reductionOp(BinaryOpType::Max, {-1}, new Float(FLT_MIN), sx); // (M) + TensorView* max_dx = + reductionOp(BinaryOpType::Max, {-1}, new Float(FLT_MIN), dx); // (M) + + // Reduction => merge local and shared memory TensorViews + TensorView* max_val = binaryOp(BinaryOpType::Max, max_sx, max_dx); + TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B) + + TensorView* sx_max_sub = sub(sx, bcast_max); // (M, N) + TensorView* dx_max_sub = sub(dx, bcast_max); // (M, N) + + TensorView* sx_exp = unaryOp(UnaryOpType::Exp, sx_max_sub); // (M, N) + TensorView* dx_exp = unaryOp(UnaryOpType::Exp, dx_max_sub); // (M, N) + + TensorView* sx_sum_exp = sum(sx_exp, {-1}); // (M, R) + TensorView* dx_sum_exp = sum(dx_exp, {-1}); // (M, R) + + // Reduction => merge local and shared memory TensorViews + TensorView* sum_exp = binaryOp(BinaryOpType::Add, sx_sum_exp, dx_sum_exp); + TensorView* bcast_sum = broadcast(sum_exp, {false, true}); // (M, B) + + TensorView* sx_softmax = div(sx_exp, bcast_sum); // (M, N) + TensorView* dx_softmax = div(dx_exp, bcast_sum); // (M, N) + fusion.addOutput(sx_softmax); + fusion.addOutput(dx_softmax); + + auto sx_cache = sx->cache_after(); + auto dx_cache = dx->cache_after(); + dx_cache->setMemoryType(MemoryType::Shared); + dx_exp->setMemoryType(MemoryType::Shared); + + // Reduction and Broadcast Tensors common to both memory TVs + std::vector common_tensors( + {max_val, sum_exp, bcast_max, bcast_sum}); + + // Static Local Memory TVs + std::vector static_tensors( + {sx, sx_cache, max_sx, sx_max_sub, sx_exp, sx_sum_exp, sx_softmax}); + + // Dynamic Local Memory TVs + std::vector dynamic_tensors( + {dx, dx_cache, max_dx, dx_max_sub, dx_exp, dx_sum_exp, dx_softmax}); + + std::vector all_tensors; + all_tensors.insert( + all_tensors.end(), common_tensors.begin(), common_tensors.end()); + all_tensors.insert( + all_tensors.end(), static_tensors.begin(), static_tensors.end()); + all_tensors.insert( + all_tensors.end(), dynamic_tensors.begin(), dynamic_tensors.end()); + + // M => M + // M, N => M, N/128, 128 + for (auto tensor : all_tensors) { + if (tensor->nDims() > 1) { + tensor->split(-1, TIDX); + } + } + + auto sx_sum_exp_rf = sx_sum_exp->rFactor({1}); + auto dx_sum_exp_rf = dx_sum_exp->rFactor({1}); + all_tensors.push_back(sx_sum_exp_rf); + all_tensors.push_back(dx_sum_exp_rf); + + // computeAt + sx->computeAt(sx_max_sub, 1); + dx->computeAt(dx_max_sub, 1); + + sx_exp->computeAt(sx_softmax, 1); + dx_exp->computeAt(dx_softmax, 1); + + sx_max_sub->computeAt(sx_exp, 2); + dx_max_sub->computeAt(dx_exp, 2); + + sx_softmax->axis(0)->parallelize(ParallelType::BIDx); + dx_softmax->axis(0)->parallelize(ParallelType::BIDx); + for (auto tensor : all_tensors) { + if (tensor->nDims() > 1) { + tensor->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + const size_t dimx = 1024; + const size_t dimy = 16384; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor in = at::randn({dimx, dimy}, options); + at::Tensor static_in = in.narrow(1, 0, static_size); + at::Tensor dynamic_in = in.narrow(1, static_size, dimy - static_size); + + at::Tensor out = at::zeros({dimx, dimy}, options); + at::Tensor static_out = out.narrow(1, 0, static_size); + at::Tensor dynamic_out = out.narrow(1, static_size, dimy - static_size); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = + fe.runFusion({static_in, dynamic_in}, {static_out, dynamic_out}); + + auto t1 = at::_softmax(in, -1, false); + TORCH_CHECK( + t1.allclose(out, 1e-5, 1e-5), "Error of: ", t1.sub(out).abs().max()); +} + +TEST(NVFuserTest, FusionPersistentBatchNormLocalShared_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int pixels_per_thread = 64; + const int TIDX = 128; + const int static_size = pixels_per_thread * TIDX; + + TensorView* sx = makeConcreteTensor({-1, static_size}); + TensorView* dx = makeDummyTensor(2); + fusion.addInput(sx); + fusion.addInput(dx); + + Float* gamma = new Float(); + Float* beta = new Float(); + Float* eps = new Float(); + Int* N = new Int(); + fusion.addInput(gamma); + fusion.addInput(beta); + fusion.addInput(eps); + fusion.addInput(N); + + // Reduction + auto sx_sum = sum(sx, {-1}); // (M, R) + auto dx_sum = sum(dx, {-1}); // (M, R) + // Reduction => merge local and shared memory TensorViews + auto x_sum = binaryOp(BinaryOpType::Add, sx_sum, dx_sum); + + // Broadcast + auto x_sum_bcast = broadcast(x_sum, {false, true}); // (M, B) + // Pwise + auto x_mean = div(x_sum_bcast, N); // (M, B) + + auto sx_mean_sub = sub(sx, x_mean); // (M, N) + auto dx_mean_sub = sub(dx, x_mean); // (M, N) + + auto sx_mean_sub_pow = mul(sx_mean_sub, sx_mean_sub); // (M, N) + auto dx_mean_sub_pow = mul(dx_mean_sub, dx_mean_sub); // (M, N) + + // Reduction + auto sx_var_sum = sum(sx_mean_sub_pow, {-1}); // (M, R) + auto dx_var_sum = sum(dx_mean_sub_pow, {-1}); // (M, R) + // Reduction => merge local and shared memory TensorViews + auto var_sum = binaryOp(BinaryOpType::Add, sx_var_sum, dx_var_sum); + + // Broadcast + auto var_sum_bcast = broadcast(var_sum, {false, true}); // (M, B) + // Pwise + auto var = div(var_sum_bcast, N); // (M, B) + auto var_eps = add(var, eps); // (M, B) + auto rvar = unaryOp(UnaryOpType::Rsqrt, var_eps); // (M, B) + + auto sx_norm = mul(sx_mean_sub, rvar); + auto dx_norm = mul(dx_mean_sub, rvar); + + auto sx_norm_gamma = mul(sx_norm, gamma); + auto dx_norm_gamma = mul(dx_norm, gamma); + + auto sx_norm_gamma_beta = add(sx_norm_gamma, beta); + auto dx_norm_gamma_beta = add(dx_norm_gamma, beta); + fusion.addOutput(sx_norm_gamma_beta); + fusion.addOutput(dx_norm_gamma_beta); + + // Read Input into Shared Memory + // Read Input minus Input_Mean into Shared Memory + auto sx_cache = sx->cache_after(); + auto dx_cache = dx->cache_after(); + dx_cache->setMemoryType(MemoryType::Shared); + dx_mean_sub->setMemoryType(MemoryType::Shared); + + std::vector common_tensors( + {x_sum, x_sum_bcast, x_mean, var_sum, var_sum_bcast, var, var_eps, rvar}); + + std::vector static_tensors( + {sx, + sx_cache, + sx_sum, + sx_mean_sub, + sx_mean_sub_pow, + sx_var_sum, + sx_norm, + sx_norm_gamma, + sx_norm_gamma_beta}); + + std::vector dynamic_tensors( + {dx, + dx_cache, + dx_sum, + dx_mean_sub, + dx_mean_sub_pow, + dx_var_sum, + dx_norm, + dx_norm_gamma, + dx_norm_gamma_beta}); + + std::vector all_tensors; + all_tensors.insert( + all_tensors.end(), common_tensors.begin(), common_tensors.end()); + all_tensors.insert( + all_tensors.end(), static_tensors.begin(), static_tensors.end()); + all_tensors.insert( + all_tensors.end(), dynamic_tensors.begin(), dynamic_tensors.end()); + + // M => M + // M, N => M, N/128, 128 + for (auto tensor : all_tensors) { + if (tensor->nDims() > 1) { + tensor->split(-1, TIDX); + } + } + + // Local Sum => Block Broadcast + TensorView* sx_sum_rf = sx_sum->rFactor({1}); + TensorView* sx_var_sum_rf = sx_var_sum->rFactor({1}); + TensorView* dx_sum_rf = dx_sum->rFactor({1}); + TensorView* dx_var_sum_rf = dx_var_sum->rFactor({1}); + all_tensors.push_back(sx_sum_rf); + all_tensors.push_back(sx_var_sum_rf); + all_tensors.push_back(dx_sum_rf); + all_tensors.push_back(dx_var_sum_rf); + + // ComputeAt + sx->computeAt(sx_mean_sub_pow, 1); + dx->computeAt(dx_mean_sub_pow, 1); + + var_sum->computeAt(rvar, 1); + + sx_mean_sub_pow->computeAt(sx_var_sum_rf, 2); + dx_mean_sub_pow->computeAt(dx_var_sum_rf, 2); + + sx_norm->computeAt(sx_norm_gamma_beta, 2); + dx_norm->computeAt(dx_norm_gamma_beta, 2); + + sx_norm_gamma_beta->axis(0)->parallelize(ParallelType::BIDx); + dx_norm_gamma_beta->axis(0)->parallelize(ParallelType::BIDx); + for (auto tensor : all_tensors) { + if (tensor->nDims() > 1) { + tensor->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + const int dimx = 1024; + const int dimy = 16384; + const float kGamma = 1.0f; + const float kBeta = 0.0f; + const float kEps = 1e-5; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor in = at::randn({dimx, dimy}, options); + at::Tensor static_in = in.narrow(1, 0, static_size); + at::Tensor dynamic_in = in.narrow(1, static_size, dimy - static_size); + + at::Tensor out = at::zeros({dimx, dimy}, options); + at::Tensor static_out = out.narrow(1, 0, static_size); + at::Tensor dynamic_out = out.narrow(1, static_size, dimy - static_size); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion( + {static_in, dynamic_in, kGamma, kBeta, kEps, dimy}, + {static_out, dynamic_out}); + + auto at_mu = at::mean(in, -1).unsqueeze(1); + auto at_var = at::var(in, -1).unsqueeze(1); + auto at_rvar = at::rsqrt(at::add(at_var, kEps)); + auto at_norm = at::mul(at::sub(in, at_mu), at_rvar); + auto at_norm_gamma_beta = at::add(at::mul(at_norm, kGamma), kBeta); + TORCH_CHECK( + at_norm_gamma_beta.allclose(out, 1e-3, 1e-3), + "Error of: ", + at_norm_gamma_beta.sub(out).abs().max()); +} + +TEST(NVFuserTest, FusionSmemDynamicPersistentBatchNorm_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + auto x = makeDummyTensor(2); + Float* gamma = new Float(); + Float* beta = new Float(); + Float* eps = new Float(); + Int* N = new Int(); + fusion.addInput(x); + fusion.addInput(gamma); + fusion.addInput(beta); + fusion.addInput(eps); + fusion.addInput(N); + + // Reduction + auto x_sum = sum(x, {-1}); // (M, R) + // Broadcast + auto x_sum_bcast = broadcast(x_sum, {false, true}); // (M, B) + // Pwise + auto x_mean = div(x_sum_bcast, N); // (M, B) + auto x_mean_sub = sub(x, x_mean); // (M, N) + auto x_mean_sub_pow = mul(x_mean_sub, x_mean_sub); // (M, N) + // Reduction + auto var_sum = sum(x_mean_sub_pow, {-1}); // (M, R) + // Broadcast + auto var_sum_bcast = broadcast(var_sum, {false, true}); // (M, B) + // Pwise + auto var = div(var_sum_bcast, N); // (M, B) + auto var_eps = add(var, eps); // (M, B) + auto rvar = unaryOp(UnaryOpType::Rsqrt, var_eps); // (M, B) + auto norm = mul(x_mean_sub, rvar); + auto norm_gamma = mul(norm, gamma); + auto norm_gamma_beta = add(norm_gamma, beta); + fusion.addOutput(norm_gamma_beta); + + // Read Input into Shared Memory + // Read Input minus Input_Mean into Shared Memory + auto cache_x = x->cache_after(); + cache_x->setMemoryType(MemoryType::Shared); + x_mean_sub->setMemoryType(MemoryType::Shared); + + std::vector all_tensors( + {x_sum, + x_mean, + cache_x, + x_sum_bcast, + x_mean_sub, + x_mean_sub_pow, + var_sum, + var_sum_bcast, + var, + var_eps, + rvar, + norm, + norm_gamma, + norm_gamma_beta}); + + auto tidx = new Int(); + fusion.addInput(tidx); + + for (auto tensor : all_tensors) { + tensor->split(-1, tidx); + } + norm_gamma->split(1, 1); + norm_gamma_beta->split(1, 1); + + // Local Sum => Block Broadcast + TensorView* x_sum_rf = x_sum->rFactor({1}); + TensorView* var_sum_rf = var_sum->rFactor({1}); + all_tensors.push_back(x_sum_rf); + all_tensors.push_back(var_sum_rf); + + // ComputeAt + x->computeAt(x_mean_sub_pow, 1); + var_sum->computeAt(rvar, 1); + x_mean_sub_pow->computeAt(var_sum_rf, 2); + norm->computeAt(norm_gamma_beta, 2); + + for (auto tv : all_tensors) { + tv->axis(0)->parallelize(ParallelType::BIDx); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + + const int dimx = 128; + const int dimy = 2048; + const float kGamma = 1.0f; + const float kBeta = 0.0f; + const float kEps = 1e-5; + const int TIDX = 128; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({dimx, dimy}, options); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0, kGamma, kBeta, kEps, dimy, TIDX}); + + auto at_mu = at::mean(t0, -1).unsqueeze(1); + auto at_var = at::var(t0, -1).unsqueeze(1); + auto at_rvar = at::rsqrt(at::add(at_var, kEps)); + auto at_norm = at::mul(at::sub(t0, at_mu), at_rvar); + auto at_norm_gamma_beta = at::add(at::mul(at_norm, kGamma), kBeta); + TORCH_CHECK( + at_norm_gamma_beta.allclose(outputs[0], 1e-3, 1e-3), + "Error of: ", + at_norm_gamma_beta.sub(outputs[0]).abs().max()); +} + +TEST(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5925,12 +6416,10 @@ void testGPU_FusionSmemDynamicReductionSymbolic() { // How many threads to use for the block reduction constexpr int runtime_threadIdx_dim = 128; - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion( - {input}, - torch::jit::fuser::cuda::LaunchParams( - -1, -1, -1, runtime_threadIdx_dim, -1, -1)); + {input}, LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1)); auto aten_output = input.sum({1}); TORCH_CHECK( @@ -5940,7 +6429,7 @@ void testGPU_FusionSmemDynamicReductionSymbolic() { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0); } -void testGPU_FusionSmemDynamicReductionSymbolicArg() { +TEST(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5984,12 +6473,11 @@ void testGPU_FusionSmemDynamicReductionSymbolicArg() { // How many threads to use for the block reduction constexpr int runtime_threadIdx_dim = 128; - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion( {t0, runtime_threadIdx_dim}, - torch::jit::fuser::cuda::LaunchParams( - -1, -1, -1, runtime_threadIdx_dim, -1, -1)); + LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1)); at::Tensor aten_output = sum(t0, {1}); TORCH_CHECK( @@ -6000,7 +6488,7 @@ void testGPU_FusionSmemDynamicReductionSymbolicArg() { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(24) == 1); } -void testGPU_FusionSmemDynamicPwiseMulSymbolicArgWAR() { +TEST(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6044,11 +6532,10 @@ void testGPU_FusionSmemDynamicPwiseMulSymbolicArgWAR() { at::Tensor t0 = at::randn({M, K}, options); at::Tensor t1 = at::randn({K, N}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); - auto outputs = fe.runFusion( - {t0, t1, BSX}, - torch::jit::fuser::cuda::LaunchParams(-1, -1, -1, BSX, -1, -1)); + auto outputs = + fe.runFusion({t0, t1, BSX}, LaunchParams(-1, -1, -1, BSX, -1, -1)); at::Tensor aten_output = mul(t0.unsqueeze(2), t1.unsqueeze(0)); TORCH_CHECK( @@ -6059,7 +6546,7 @@ void testGPU_FusionSmemDynamicPwiseMulSymbolicArgWAR() { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(22) == 1); } -void testGPU_FusionSmemDynamicTiledGemm() { +TEST(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6164,7 +6651,7 @@ void testGPU_FusionSmemDynamicTiledGemm() { at::Tensor A = at::randn({M, K}, options); at::Tensor B = at::randn({K, N}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; // Generate CUDA and compile with nvRTC fe.compileFusion(&fusion); @@ -6185,7 +6672,7 @@ void testGPU_FusionSmemDynamicTiledGemm() { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(41) == 1); } -void testGPU_FusionGlobalIntermediate() { +TEST(NVFuserTest, FusionGlobalIntermediate_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6219,12 +6706,10 @@ void testGPU_FusionGlobalIntermediate() { // How many threads to use for the block reduction constexpr int runtime_threadIdx_dim = 128; - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion( - {input}, - torch::jit::fuser::cuda::LaunchParams( - -1, -1, -1, runtime_threadIdx_dim, -1, -1)); + {input}, LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1)); auto aten_output = input.sum({1}); TORCH_CHECK( @@ -6233,7 +6718,7 @@ void testGPU_FusionGlobalIntermediate() { aten_output.sub(outputs[0]).abs().max()); } -void testGPU_FusionGlobalIntermediateDefaultSchedule() { +TEST(NVFuserTest, FusionGlobalIntermediateDefaultSchedule_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6262,7 +6747,7 @@ void testGPU_FusionGlobalIntermediateDefaultSchedule() { at::Tensor in2 = at::rand({M, N}, options); at::Tensor in3 = at::rand({M, N}, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({in0, in1, in2, in3}); @@ -6273,7 +6758,7 @@ void testGPU_FusionGlobalIntermediateDefaultSchedule() { aten_output.sub(outputs[0]).abs().sum()); } -void testGPU_FusionConstCheck() { +TEST(NVFuserTest, FusionConstCheck_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6290,7 +6775,7 @@ void testGPU_FusionConstCheck() { TORCH_CHECK(one_x4->isConstScalar()); } -void testGPU_FusionUnrollWithAlloc() { +TEST(NVFuserTest, FusionUnrollWithAlloc_CUDA) { const std::vector tensor_dims_in = {128, 128}; Fusion fusion; FusionGuard fg(&fusion); @@ -6325,7 +6810,7 @@ void testGPU_FusionUnrollWithAlloc() { tv1->computeAt(tv2_rf, -1); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); @@ -6338,7 +6823,7 @@ void testGPU_FusionUnrollWithAlloc() { } // Test isZeroInt -void testGPU_FusionIsZeroInt() { +TEST(NVFuserTest, FusionIsZeroInt_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6351,7 +6836,7 @@ void testGPU_FusionIsZeroInt() { } // Test isOneInt -void testGPU_FusionIsOneInt() { +TEST(NVFuserTest, FusionIsOneInt_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6366,7 +6851,7 @@ void testGPU_FusionIsOneInt() { // This is to verify no cycle of computeAt is created. A more complex // variation of this pattern appears in one of the Python tests // (test_random_topo). -void testGPU_FusionComputeAtNonterminatingOutput() { +TEST(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6401,7 +6886,7 @@ void testGPU_FusionComputeAtNonterminatingOutput() { at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand(100, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); @@ -6430,7 +6915,7 @@ void testGPU_FusionComputeAtNonterminatingOutput() { return; } -void testGPU_FusionTraversalOrder1() { +TEST(NVFuserTest, FusionTraversalOrder1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6449,7 +6934,7 @@ void testGPU_FusionTraversalOrder1() { tv1->computeAt(tv3, -1); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -6478,7 +6963,7 @@ void testGPU_FusionTraversalOrder1() { t4.sub(cg_output_tv4).abs().max()); } -void testGPU_FusionTraversalOrder2() { +TEST(NVFuserTest, FusionTraversalOrder2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6501,7 +6986,7 @@ void testGPU_FusionTraversalOrder2() { tv1->computeAt(tv5, -1); tv3->computeAt(tv5, -1); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -6531,7 +7016,7 @@ void testGPU_FusionTraversalOrder2() { t5.sub(cg_output_tv5).abs().max()); } -void testGPU_FusionTraversalOrder3() { +TEST(NVFuserTest, FusionTraversalOrder3_CUDA) { for (int i = 0; i < 2; ++i) { Fusion fusion; FusionGuard fg(&fusion); @@ -6568,7 +7053,7 @@ void testGPU_FusionTraversalOrder3() { compute_at_outer->computeAt(tv5, -2); compute_at_inner->computeAt(tv5, -1); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -6599,7 +7084,7 @@ void testGPU_FusionTraversalOrder3() { } } -void testGPU_FusionTraversalOrder4() { +TEST(NVFuserTest, FusionTraversalOrder4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6624,7 +7109,7 @@ void testGPU_FusionTraversalOrder4() { tv1->computeAt(tv2, -1); tv5->computeAt(tv6, -1); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -6663,7 +7148,7 @@ void testGPU_FusionTraversalOrder4() { t7.sub(cg_output_tv7).abs().max()); } -void testGPU_FusionTraversalOrder5() { +TEST(NVFuserTest, FusionTraversalOrder5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6682,7 +7167,7 @@ void testGPU_FusionTraversalOrder5() { tv2->computeAt(tv5, -1); tv4->computeAt(tv5, -1); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -6713,7 +7198,7 @@ void testGPU_FusionTraversalOrder5() { t5.sub(cg_output_tv5).abs().max()); } -void testGPU_FusionTraversalOrder6() { +TEST(NVFuserTest, FusionTraversalOrder6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6735,7 +7220,7 @@ void testGPU_FusionTraversalOrder6() { tv1->computeAt(tv3, -1); tv2->computeAt(tv3, -2); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -6755,7 +7240,7 @@ void testGPU_FusionTraversalOrder6() { t4.sub(cg_output_tv4).abs().max()); } -void testGPU_FusionTraversalOrder7() { +TEST(NVFuserTest, FusionTraversalOrder7_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6783,7 +7268,7 @@ void testGPU_FusionTraversalOrder7() { tv2->computeAt(tv5, -4); tv4->computeAt(tv5, -3); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -6804,7 +7289,7 @@ void testGPU_FusionTraversalOrder7() { } // Test predication of grid reduction -void testGPU_FusionThreadPredicate() { +TEST(NVFuserTest, FusionThreadPredicate_CUDA) { const int gdimx = 4; const int bdimx = 128; @@ -6850,7 +7335,7 @@ void testGPU_FusionThreadPredicate() { at::Tensor cg_output_tv2 = at::empty({numel_x}, options); at::Tensor cg_output_tv3 = at::empty_like(input, options); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {cg_output_tv3, cg_output_tv2}); @@ -6860,7 +7345,7 @@ void testGPU_FusionThreadPredicate() { TORCH_CHECK(aten_output_tv3.allclose(cg_output_tv3)); } -void testGPU_FusionLSTMCell() { +TEST(NVFuserTest, FusionLSTMCell_CUDA) { const int hidden_features = 512; const int batch_size = 64; @@ -6930,9 +7415,9 @@ void testGPU_FusionLSTMCell() { auto at_cy = at_forgetgate.mul(at_cx).add(at_ingate.mul(at_cellgate)); auto at_hy = at_outgate.mul(at_cy.tanh()); - fuser::cuda::scheduleFusion(&fusion, c10::ArrayRef(inputs)); + scheduleFusion(&fusion, c10::ArrayRef(inputs)); - torch::jit::fuser::cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion(c10::ArrayRef(inputs)); @@ -6940,7 +7425,7 @@ void testGPU_FusionLSTMCell() { TORCH_CHECK(at_hy.allclose(outputs[1], 1e-4, 1e-7)); } -void testGPU_FusionComputeAtMultiBCast() { +TEST(NVFuserTest, FusionComputeAtMultiBCast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6958,7 +7443,7 @@ void testGPU_FusionComputeAtMultiBCast() { ASSERT_ANY_THROW(tv1->computeAt(tv3, -1)); } -void testGPU_FusionReductionHalf() { +TEST(NVFuserTest, FusionReductionHalf_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6988,14 +7473,14 @@ void testGPU_FusionReductionHalf() { tv_entries.begin(), tv_entries.end()); auto reduction_params = - cuda::getReductionHeuristics(&fusion, {input}, reduction_tv); + getReductionHeuristics(&fusion, {input}, reduction_tv); TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - cuda::scheduleReduction( + scheduleReduction( &fusion, reduction_params.value(), reduction_tv, tvOutputsOfReduction); TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - cuda::FusionExecutor fe; + FusionExecutor fe; fe.compileFusion(&fusion); // no broadcasting needed, omitting the last optional argument; auto outputs = fe.runFusion({input}, reduction_params.value().lparams); @@ -7011,14 +7496,14 @@ void testGPU_FusionReductionHalf() { aten_output.sub(outputs[0]).abs().max()); } -void testGPU_FusionInputsIdLookup() { +TEST(NVFuserTest, FusionInputsIdLookup_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({16, 8, 8}, options); at::Tensor t1 = at::randn({8, 8}, options); at::Tensor t2 = at::randn({6, 4}, options); // create a cache with max size 2; - auto inputs_id_lookup = torch::jit::fuser::cuda::InputsIdLookup(2); + auto inputs_id_lookup = InputsIdLookup(2); // testing basic function, same encoding for identical inputs auto id_0 = inputs_id_lookup.lookupId({t0, t1, 5.0}); @@ -7049,6 +7534,94 @@ void testGPU_FusionInputsIdLookup() { TORCH_CHECK(id_1_relook.eviction == false); } +TEST(NVFuserTest, FusionGroupGuardSimpleTensor_CUDA) { + std::vector sizes_vec({16, 8, 8}); + std::vector strides_vec({64, 8, 1}); + auto tensor_type = TensorType::create( + at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + // pass with identical shape + auto t0 = at::randn({16, 8, 8}, options); + TORCH_CHECK(complyWith(t0, tensor_type)); + + // pass with dynamic shape + auto t1 = at::randn({16, 16, 8}, options); + TORCH_CHECK(complyWith(t1, tensor_type)); + + // rank failure + auto t5 = at::randn({16, 8, 8, 8}, options); + TORCH_CHECK(!complyWith(t5, tensor_type)); + + // broadcasting semantic change failure + auto t2 = at::randn({16, 1, 8}, options); + TORCH_CHECK(!complyWith(t2, tensor_type)); + + // contiguity failure via slicing + auto t3 = t0.slice(1, 0, 8, 2); + TORCH_CHECK(!complyWith(t3, tensor_type)); + + // contiguity failure via slicing + auto t4 = t0.slice(2, 0, 8, 2); + TORCH_CHECK(!complyWith(t4, tensor_type)); +} + +TEST(NVFuserTest, FusionGroupGuardBroadcastTensor_CUDA) { + std::vector sizes_vec({16, 1, 8}); + std::vector strides_vec({8, 8, 1}); + auto tensor_type = TensorType::create( + at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + // broadcasting semantic change + auto t0 = at::randn({16, 8, 8}, options); + TORCH_CHECK(!complyWith(t0, tensor_type)); + + // dtype failure + auto t1 = at::randn({16, 1, 8}, options.dtype(at::kHalf)); + TORCH_CHECK(!complyWith(t1, tensor_type)); + + // dtype failure + auto t2 = at::randn({16, 1, 8}, options); + TORCH_CHECK(complyWith(t2, tensor_type)); + + // device inconsistency shouldn't fail + auto t3 = at::randn({16, 1, 8}, options.device(at::kCPU, 0)); + TORCH_CHECK(complyWith(t3, tensor_type)); +} + +TEST(NVFuserTest, FusionGroupGuardPermutedTensor_CUDA) { + std::vector sizes_vec({16, 8, 8}); + std::vector strides_vec({64, 1, 8}); + auto tensor_type = TensorType::create( + at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + // failing permutation + auto t0 = at::randn({16, 8, 8}, options); + TORCH_CHECK(!complyWith(t0, tensor_type)); + + // passing with dynamic shape + auto t1 = t0.permute({0, 2, 1}); + TORCH_CHECK(complyWith(t1, tensor_type)); +} + +TEST(NVFuserTest, FusionGroupGuardRelaxedCheck_CUDA) { + std::vector sizes_vec({16, 8, 8}); + std::vector strides_vec({128, 16, 1}); + auto tensor_type = TensorType::create( + at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + // contiguity check passes although it differs + auto t0 = at::randn({16, 16, 8}, options); + TORCH_CHECK(complyWith(t0, tensor_type)); + + // passing with dynamic shape + auto t1 = t0.slice(1, 0, 16, 2); + TORCH_CHECK(complyWith(t1, tensor_type)); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_graph_executor.cpp b/test/cpp/jit/test_graph_executor.cpp index 992cde217a900..a0566ce807f4c 100644 --- a/test/cpp/jit/test_graph_executor.cpp +++ b/test/cpp/jit/test_graph_executor.cpp @@ -1,11 +1,15 @@ -#include "test/cpp/jit/test_base.h" +#include + #include "test/cpp/jit/test_utils.h" #include "torch/csrc/jit/runtime/graph_executor.h" +#include "torch/jit.h" +#include "torch/script.h" +#include "torch/torch.h" namespace torch { namespace jit { -void testGraphExecutor() { +TEST(GraphExecutorTest, Basic_CUDA) { constexpr int batch_size = 4; constexpr int input_size = 256; @@ -28,5 +32,40 @@ void testGraphExecutor() { ASSERT_TRUE(almostEqual(stack[1].toTensor(), r1)); } +TEST(GraphExecutorTest, runAsync_executor) { + /* + TODO: there are some problem with C++ parsing script program involving + fork. Use the test module below for now. + issue about this: github.com/pytorch/pytorch/issues/46368 + The test module file is generated by following: + class DemoModule(torch.nn.Module): + def forward(self): + r1 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100)) + r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100)) + return r1.wait() + r2.wait() + demo = DemoModule() + torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pth') + */ + std::string filePath(__FILE__); + auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); + testModelFile.append("test_interpreter_async.pt"); + auto module = load(testModelFile); + auto graph = module.get_method("forward").graph(); + GraphExecutor graphExecutor(graph, ""); + auto asyncCounter = 0; + std::mutex mtx; + // a dummy executor which actually use at::launch, but add up a counter + auto launcher = [&](std::function f) { + mtx.lock(); + ++asyncCounter; + mtx.unlock(); + at::launch(move(f)); + }; + std::vector stack; + stack.push_back(module._ivalue()); + graphExecutor.runAsync(stack, launcher)->wait(); + ASSERT_TRUE(asyncCounter > 0); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_inliner.cpp b/test/cpp/jit/test_inliner.cpp index 2153a03893192..702f5bd97573b 100644 --- a/test/cpp/jit/test_inliner.cpp +++ b/test/cpp/jit/test_inliner.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -36,18 +36,16 @@ struct InlinerGuard { bool oldState_; }; -void testInliner() { - { - // disable automatic inlining so we can test it manually - InlinerGuard guard(/*shouldInline=*/false); +TEST(InlinerTest, Basic) { + // disable automatic inlining so we can test it manually + InlinerGuard guard(/*shouldInline=*/false); - CompilationUnit cu(testSource); - auto& fn = cu.get_function("foo3"); + CompilationUnit cu(testSource); + auto& fn = cu.get_function("foo3"); - auto g = fn.graph(); - Inline(*g); - FileCheck().check_count("prim::Print", 3)->run(*g); - } + auto g = fn.graph(); + Inline(*g); + FileCheck().check_count("prim::Print", 3)->run(*g); } } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_interface.cpp b/test/cpp/jit/test_interface.cpp index b256e2328ceb8..04a532459426a 100644 --- a/test/cpp/jit/test_interface.cpp +++ b/test/cpp/jit/test_interface.cpp @@ -1,5 +1,5 @@ +#include -#include #include #include @@ -44,7 +44,7 @@ static void import_libs( si.loadType(QualifiedName(class_name)); } -void testModuleInterfaceSerialization() { +TEST(InterfaceTest, ModuleInterfaceSerialization) { auto cu = std::make_shared(); Module parentMod("parentMod", cu); Module subMod("subMod", cu); diff --git a/test/cpp/jit/test_interpreter.cpp b/test/cpp/jit/test_interpreter.cpp index 5977b0c0494a9..edbc11965709e 100644 --- a/test/cpp/jit/test_interpreter.cpp +++ b/test/cpp/jit/test_interpreter.cpp @@ -1,107 +1,126 @@ -#include "test/cpp/jit/test_base.h" +#include + +#include #include "test/cpp/jit/test_utils.h" +#include "torch/jit.h" +#include "torch/script.h" +#include "torch/torch.h" -#include namespace torch { namespace jit { -void testTypeCheck() { - { +class TypeCheckTest : public ::testing::Test { + protected: + TypeCheckTest() : interp(makeInterp()) {} + + InterpreterState interp; + + private: + static InterpreterState makeInterp() { auto graph = std::make_shared(); std::unordered_map vmap; parseIR( R"IR( graph(%a.1 : Tensor, %b.1 : Tensor): - %t0 : Float(2:2, 2:1, device=cpu, requires_grad=1), %t1 : Float(3:3, 3:1), %type_matched : bool = prim::TypeCheck(%a.1, %b.1) + %t0 : Float(2, 2, strides=[2, 1], device=cpu, requires_grad=1), %t1 : Float(3, 3, strides=[3, 1]), %type_matched : bool = prim::TypeCheck[types=[Float(2, 2, strides=[2, 1], device=cpu, requires_grad=1), Float(3, 3, strides=[3, 1])]](%a.1, %b.1) return (%t0, %t1, %type_matched) )IR", &*graph, vmap); Code function(graph, ""); - InterpreterState interp(function); - { - // TypeCheck yields to true! Shape, grad and device matches. - auto a = at::zeros({2, 2}, at::kFloat); - auto b = at::ones({3, 3}, at::kFloat); - a.set_requires_grad(true); - a = a.to(at::kCPU); - std::vector stack({a, b}); - interp.run(stack); - ASSERT_TRUE(exactlyEqual(stack[0].toTensor(), a)); - ASSERT_TRUE(exactlyEqual(stack[1].toTensor(), b)); - ASSERT_TRUE(stack[2].toBool()); - } - { - auto a = at::zeros({2, 2}, at::kFloat); - auto b = at::ones({2, 2}, at::kFloat); // Size mismatch - a.set_requires_grad(true); - a = a.to(at::kCPU); - std::vector stack({a, b}); - interp.run(stack); - ASSERT_FALSE(stack[2].toBool()); - } - { - auto a = at::zeros({2, 2}, at::kFloat); - auto b = at::ones({3, 3}, at::kFloat); - a = a.to(at::kCPU); - a.set_requires_grad(false); // Gradient mismatch - std::vector stack({a, b}); - interp.run(stack); - ASSERT_FALSE(stack[2].toBool()); - } - { - auto a = at::zeros({2, 2}, at::kFloat); - auto b = at::ones({3, 3}, at::kFloat); - a = a.to(at::kCPU); - a.set_requires_grad(true); - a = a.to(at::kInt); // Scalar type mismatch - std::vector stack({a, b}); - interp.run(stack); - ASSERT_FALSE(stack[2].toBool()); - } - { - auto a = at::zeros({2, 2}, at::kFloat); - auto b = at::ones({3, 3}, at::kFloat); - a.set_requires_grad(true); - a = a.to(at::kCUDA); // Device mismatch - std::vector stack({a, b}); - interp.run(stack); - ASSERT_FALSE(stack[2].toBool()); - } + return InterpreterState(function); } +}; - try { // Test empty Typecheck raises an internal assertion - auto graph = std::make_shared(); - std::unordered_map vmap; - parseIR( - R"IR( -graph(%a.1 : Tensor, - %b.1 : Tensor): - %type_matched : bool = prim::TypeCheck() - return (%type_matched) - )IR", - &*graph, - vmap); - } catch (const std::exception& e) { - } - try { // Test for assertion if num_inputs + 1 != num_outputs - auto graph = std::make_shared(); - std::unordered_map vmap; - parseIR( - R"IR( -graph(%a.1 : Tensor, - %b.1 : Tensor): - %type_matched : bool = prim::TypeCheck(%a.1) - return (%type_matched) - )IR", - &*graph, - vmap); - } catch (const std::exception& e) { - } +TEST_F(TypeCheckTest, MatchingType) { + // TypeCheck yields to true! Shape, grad and device matches. + auto a = at::zeros({2, 2}, at::kFloat); + auto b = at::ones({3, 3}, at::kFloat); + a.set_requires_grad(true); + a = a.to(at::kCPU); + std::vector stack({a, b}); + interp.run(stack); + ASSERT_TRUE(exactlyEqual(stack[0].toTensor(), a)); + ASSERT_TRUE(exactlyEqual(stack[1].toTensor(), b)); + ASSERT_TRUE(stack[2].toBool()); +} + +TEST_F(TypeCheckTest, SizeMismatch) { + auto a = at::zeros({2, 2}, at::kFloat); + auto b = at::ones({2, 2}, at::kFloat); // Size mismatch + a.set_requires_grad(true); + a = a.to(at::kCPU); + std::vector stack({a, b}); + interp.run(stack); + ASSERT_FALSE(stack[2].toBool()); +} + +TEST_F(TypeCheckTest, GradientMismatch) { + auto a = at::zeros({2, 2}, at::kFloat); + auto b = at::ones({3, 3}, at::kFloat); + a = a.to(at::kCPU); + a.set_requires_grad(false); // Gradient mismatch + std::vector stack({a, b}); + interp.run(stack); + ASSERT_FALSE(stack[2].toBool()); } -void testInterp() { + +TEST_F(TypeCheckTest, ScalarTypeMismatch) { + auto a = at::zeros({2, 2}, at::kFloat); + auto b = at::ones({3, 3}, at::kFloat); + a = a.to(at::kCPU); + a.set_requires_grad(true); + a = a.to(at::kInt); // Scalar type mismatch + std::vector stack({a, b}); + interp.run(stack); + ASSERT_FALSE(stack[2].toBool()); +} + +TEST_F(TypeCheckTest, DeviceMismatch_CUDA) { + auto a = at::zeros({2, 2}, at::kFloat); + auto b = at::ones({3, 3}, at::kFloat); + a.set_requires_grad(true); + a = a.to(at::kCUDA); // Device mismatch + std::vector stack({a, b}); + interp.run(stack); + ASSERT_FALSE(stack[2].toBool()); +} + +// TODO: These tests weren't doing anything. +// TEST(TypeCheckErrorTest, EmptyCheckRaises) { +// // Test empty Typecheck raises an internal assertion +// auto graph = std::make_shared(); +// std::unordered_map vmap; +// EXPECT_ANY_THROW(parseIR( +// R"IR( +// graph(%a.1 : Tensor, +// %b.1 : Tensor): +// %type_matched : bool = prim::TypeCheck() +// return (%type_matched) +// )IR", +// &*graph, +// vmap)); +// } + +// TODO: These tests weren't doing anything. +// TEST(TypeCheckErrorTest, WrongInputOutputCountRaises) { +// // Test for assertion if num_inputs + 1 != num_outputs +// auto graph = std::make_shared(); +// std::unordered_map vmap; +// EXPECT_ANY_THROW(parseIR( +// R"IR( +// graph(%a.1 : Tensor, +// %b.1 : Tensor): +// %type_matched : bool = prim::TypeCheck(%a.1) +// return (%type_matched) +// )IR", +// &*graph, +// vmap)); +// } + +TEST(InterpreterTest, Basic_CUDA) { constexpr int batch_size = 4; constexpr int input_size = 256; constexpr int seq_len = 32; @@ -123,5 +142,41 @@ void testInterp() { ASSERT_TRUE(exactlyEqual(outputs[0], hx)); ASSERT_TRUE(exactlyEqual(outputs[1], cx)); } + +TEST(InterpreterTest, runAsyncBasicTest) { + /* + TODO: there are some problem with C++ parsing script program involving + fork. Use the test module below for now. + issue about this: github.com/pytorch/pytorch/issues/46368 + The test module file is generated by following: + class DemoModule(torch.nn.Module): + def forward(self): + r1 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100)) + r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100)) + return r1.wait() + r2.wait() + demo = DemoModule() + torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pth') + */ + std::string filePath(__FILE__); + auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); + testModelFile.append("test_interpreter_async.pt"); + auto model = load(testModelFile); + auto graph = model.get_method("forward").graph(); + Code function(graph, ""); + auto asyncCounter = 0; + std::mutex mtx; + // a dummy executor which actually use at::launch, but add up a counter + auto launcher = [&](std::function f) { + mtx.lock(); + ++asyncCounter; + mtx.unlock(); + at::launch(f); + }; + std::vector stack; + stack.push_back(model._ivalue()); + InterpreterState interp(function, launcher); + interp.runAsync(stack)->wait(); + ASSERT_TRUE(asyncCounter > 0); +} } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_interpreter_async.pt b/test/cpp/jit/test_interpreter_async.pt new file mode 100644 index 0000000000000..fdeb3e611054f Binary files /dev/null and b/test/cpp/jit/test_interpreter_async.pt differ diff --git a/test/cpp/jit/test_ir.cpp b/test/cpp/jit/test_ir.cpp index a05ff70061bf9..2423bbf0c7736 100644 --- a/test/cpp/jit/test_ir.cpp +++ b/test/cpp/jit/test_ir.cpp @@ -1,11 +1,12 @@ -#include "test/cpp/jit/test_base.h" +#include + #include "test/cpp/jit/test_utils.h" #include "torch/csrc/jit/ir/irparser.h" namespace torch { namespace jit { -void testAttributes() { +TEST(IRTest, Attributes) { Graph g; auto one = attr::alpha; auto two = attr::device; @@ -33,7 +34,7 @@ void testAttributes() { ASSERT_EQ(attr2.f(one), 5); } -void testBlocks() { +TEST(IRTest, Blocks) { auto g = std::make_shared(); const auto graph_string = R"IR( graph(%a : Tensor, @@ -92,7 +93,7 @@ void testBlocks() { ->run(*g2); } -void testCommonAncestor() { +TEST(IRTest, CommonAncestor) { std::string input_str = R"( graph(%x : Tensor, %a.1 : bool, diff --git a/test/cpp/jit/test_irparser.cpp b/test/cpp/jit/test_irparser.cpp index a71b64a7b85bd..6db8ba26639d9 100644 --- a/test/cpp/jit/test_irparser.cpp +++ b/test/cpp/jit/test_irparser.cpp @@ -1,7 +1,8 @@ +#include + #include #include #include -#include "test/cpp/jit/test_base.h" #include #include @@ -38,52 +39,52 @@ static void checkRoundtrip(const std::string& s) { AT_ASSERT(original == parsed); } -void testIRParser() { - { - auto graph = std::make_shared(); - std::unordered_map vmap; - parseIR( - R"IR( +TEST(IRParserTest, Basic) { + auto graph = std::make_shared(); + std::unordered_map vmap; + parseIR( + R"IR( graph(%0 : Tensor, %1 : Tensor): %2 : Tensor = foo::add(%0, %1) %res, %3 = foo::mul(%0, %2) %x, %y = foo::combine(%res, %2, %3) return (%x, %y, %res))IR", - &*graph, - vmap); + &*graph, + vmap); - AT_ASSERT(graph->inputs().size() == 2); - AT_ASSERT(graph->outputs().size() == 3); - Value* x = graph->outputs()[0]; - Value* y = graph->outputs()[1]; - Value* res = graph->outputs()[2]; - Value* t0 = graph->inputs()[0]; - Value* t1 = graph->inputs()[1]; - AT_ASSERT(vmap["x"] == x); - AT_ASSERT(vmap["y"] == y); - AT_ASSERT(vmap["res"] == res); - AT_ASSERT(vmap["0"] == t0); - AT_ASSERT(vmap["1"] == t1); - AT_ASSERT(x->node() == y->node()); - Node* comb = x->node(); - Value* t2 = comb->inputs()[1]; - Value* t3 = comb->inputs()[2]; - AT_ASSERT(vmap["2"] == t2); - AT_ASSERT(vmap["3"] == t3); - AT_ASSERT(comb->kind().toQualString() == std::string("foo::combine")); - AT_ASSERT(comb->outputs() == std::vector({x, y})); - AT_ASSERT(comb->inputs() == std::vector({res, t2, t3})); - Node* mul = res->node(); - AT_ASSERT(mul->kind().toQualString() == std::string("foo::mul")); - AT_ASSERT(mul->inputs() == std::vector({t0, t2})); - AT_ASSERT(mul->outputs() == std::vector({res, t3})); - Node* add = t2->node(); - AT_ASSERT(add->kind().toQualString() == std::string("foo::add")); - AT_ASSERT(add->inputs() == std::vector({t0, t1})); - AT_ASSERT(add->outputs() == std::vector({t2})); - } - { - checkRoundtrip(R"IR( + AT_ASSERT(graph->inputs().size() == 2); + AT_ASSERT(graph->outputs().size() == 3); + Value* x = graph->outputs()[0]; + Value* y = graph->outputs()[1]; + Value* res = graph->outputs()[2]; + Value* t0 = graph->inputs()[0]; + Value* t1 = graph->inputs()[1]; + AT_ASSERT(vmap["x"] == x); + AT_ASSERT(vmap["y"] == y); + AT_ASSERT(vmap["res"] == res); + AT_ASSERT(vmap["0"] == t0); + AT_ASSERT(vmap["1"] == t1); + AT_ASSERT(x->node() == y->node()); + Node* comb = x->node(); + Value* t2 = comb->inputs()[1]; + Value* t3 = comb->inputs()[2]; + AT_ASSERT(vmap["2"] == t2); + AT_ASSERT(vmap["3"] == t3); + AT_ASSERT(comb->kind().toQualString() == std::string("foo::combine")); + AT_ASSERT(comb->outputs() == std::vector({x, y})); + AT_ASSERT(comb->inputs() == std::vector({res, t2, t3})); + Node* mul = res->node(); + AT_ASSERT(mul->kind().toQualString() == std::string("foo::mul")); + AT_ASSERT(mul->inputs() == std::vector({t0, t2})); + AT_ASSERT(mul->outputs() == std::vector({res, t3})); + Node* add = t2->node(); + AT_ASSERT(add->kind().toQualString() == std::string("foo::add")); + AT_ASSERT(add->inputs() == std::vector({t0, t1})); + AT_ASSERT(add->outputs() == std::vector({t2})); +} + +TEST(IRParserTest, NestedBlock) { + checkRoundtrip(R"IR( graph(): %0 : Tensor = a::a() block0(): @@ -95,9 +96,10 @@ graph(): %3 : Tensor = d::d() return (%3) )IR"); - } - { - checkRoundtrip(R"IR( +} + +TEST(IRParserTest, If) { + checkRoundtrip(R"IR( graph(%0 : Tensor, %1 : Tensor, %2 : Tensor): @@ -114,9 +116,10 @@ graph(%0 : Tensor, %11 : Tensor = aten::add(%5, %3, %10) return (%11) )IR"); - } - { - checkRoundtrip(R"IR( +} + +TEST(IRParserTest, If2) { + checkRoundtrip(R"IR( graph(%0 : Tensor, %1 : Tensor, %2 : Tensor): @@ -133,40 +136,43 @@ graph(%0 : Tensor, %11 : Tensor = aten::add(%5, %3, %10) return (%11) )IR"); - } - { - auto graph = std::make_shared(); - parseIR( - R"IR( +} + +TEST(IRParserTest, InferredTypeIsTensor) { + auto graph = std::make_shared(); + parseIR( + R"IR( graph(%a): return (%a))IR", - &*graph); - AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(TensorType::get())); - } - { - // Check that parser correctly handles values reusing the same name. - auto graph = std::make_shared(); - parseIR( - R"IR( + &*graph); + AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(TensorType::get())); +} + +TEST(IRParserTest, ValueReuse) { + // Check that parser correctly handles values reusing the same name. + auto graph = std::make_shared(); + parseIR( + R"IR( graph(%x): %x = a::a(%x) %x = b::b(%x) return (%x))IR", - &*graph); - Value* x0 = graph->inputs()[0]; - Value* x2 = graph->outputs()[0]; - Node* b = x2->node(); - Value* x1 = b->inputs()[0]; - Node* a = x1->node(); - AT_ASSERT(a->inputs() == std::vector({x0})); - AT_ASSERT(a->outputs() == std::vector({x1})); - AT_ASSERT(b->inputs() == std::vector({x1})); - AT_ASSERT(b->outputs() == std::vector({x2})); - } - { - // Check that parser handles attributes and types. - checkRoundtrip( - R"IR( + &*graph); + Value* x0 = graph->inputs()[0]; + Value* x2 = graph->outputs()[0]; + Node* b = x2->node(); + Value* x1 = b->inputs()[0]; + Node* a = x1->node(); + AT_ASSERT(a->inputs() == std::vector({x0})); + AT_ASSERT(a->outputs() == std::vector({x1})); + AT_ASSERT(b->inputs() == std::vector({x1})); + AT_ASSERT(b->outputs() == std::vector({x2})); +} + +TEST(IRParserTest, Attributes) { + // Check that parser handles attributes and types. + checkRoundtrip( + R"IR( graph(%0 : Tensor, %1 : Tensor, %2 : Tensor): @@ -176,160 +182,152 @@ graph(%0 : Tensor, %8 : string = z::z() return (%7) )IR"); - } +} - { - checkRoundtrip( - R"IR( +TEST(IRParserTest, OptionalTypes) { + checkRoundtrip( + R"IR( graph(%0 : Tensor, %1 : Tensor, %2 : Tensor): %3 : int? = prim::Constant() return (%3) )IR"); - } +} - { - checkRoundtrip( - R"IR( +TEST(IRParserTest, StarTensor) { + checkRoundtrip( + R"IR( graph(%0 : Tensor, %1 : Tensor, %2 : Tensor): %3 : Float(*, *, *) = prim::Constant() return (%3) )IR"); - } +} - { - checkRoundtrip( - R"IR( +TEST(IRParserTest, UnshapedTensor) { + checkRoundtrip( + R"IR( graph(%0 : Tensor, %1 : Tensor, %2 : Tensor): %3 : Long() = prim::Constant() return (%3) )IR"); - } +} - { - checkRoundtrip( - R"IR( +TEST(IRParserTest, ShapedTensor) { + checkRoundtrip( + R"IR( graph(%0 : Tensor, %1 : Tensor, %2 : Tensor): %3 : Double(4, 4, 5) = prim::Constant() return (%3) )IR"); - } +} - { - checkRoundtrip( - R"IR( +TEST(IRParserTest, NestedContrainer) { + checkRoundtrip( + R"IR( graph(): %0 : float[] = prim::Constant[value=[1., 2., 3.]]() %1 : str[] = prim::Constant[value=["ab", "cd", "ef"]]() %2 : (float[], str[]) = prim::TupleConstruct(%0, %1) return (%2) )IR"); - } +} - { - bool error_thrown = false; - try { - checkRoundtrip( - R"IR( +TEST(IRParserTest, MalformedShapeAnnotation) { + EXPECT_ANY_THROW(checkRoundtrip( + R"IR( graph(%0 : Tensor, %1 : Tensor, %2 : Tensor): %3 : Double(4!, 4, 5) = prim::Constant() return (%3) -)IR"); - } catch (const std::exception& error) { - error_thrown = true; - } - AT_ASSERT(error_thrown); - } +)IR")); +} - { - auto graph = std::make_shared(); - const std::string& text = - R"IR( +TEST(IRParserTest, FileCheck) { + auto graph = std::make_shared(); + const std::string& text = + R"IR( graph(%a): # CHECK: return return (%a))IR"; - parseIR(text, &*graph); - AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(TensorType::get())); - torch::jit::testing::FileCheck().run(text, *graph); - } + parseIR(text, &*graph); + AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(TensorType::get())); + torch::jit::testing::FileCheck().run(text, *graph); +} - { - auto graph = std::make_shared(); - std::unordered_map vmap; - parseIR( - R"IR( +TEST(IRParserTest, Strides) { + auto graph = std::make_shared(); + std::unordered_map vmap; + parseIR( + R"IR( graph(%a : Float(4, 5), - %b : Float(4:5, 5:1), + %b : Float(4, 5, strides=[5, 1]), %c : Double(*, *)): return (%a) )IR", - &*graph, - vmap); - Value* a = graph->inputs()[0]; - Value* b = graph->inputs()[1]; - Value* c = graph->inputs()[2]; + &*graph, + vmap); + Value* a = graph->inputs()[0]; + Value* b = graph->inputs()[1]; + Value* c = graph->inputs()[2]; - auto a_type = a->type()->cast(); - auto a_sizes = *a_type->sizes().concrete_sizes(); - auto a_strides = a_type->strides().concrete_sizes(); - AT_ASSERT(a_sizes[0] == 4 && a_sizes[1] == 5); - AT_ASSERT(a_strides == c10::nullopt); + auto a_type = a->type()->cast(); + auto a_sizes = *a_type->sizes().concrete_sizes(); + auto a_strides = a_type->strides().concrete_sizes(); + AT_ASSERT(a_sizes[0] == 4 && a_sizes[1] == 5); + AT_ASSERT(a_strides == c10::nullopt); - auto b_type = b->type()->cast(); - auto b_sizes = *b_type->sizes().concrete_sizes(); - auto b_strides = *(b_type->strides().sizes()); - AT_ASSERT(b_sizes[0] == 4 && b_sizes[1] == 5); - AT_ASSERT(*b_strides[0] == 5 && *b_strides[1] == 1); + auto b_type = b->type()->cast(); + auto b_sizes = *b_type->sizes().concrete_sizes(); + auto b_strides = *(b_type->strides().sizes()); + AT_ASSERT(b_sizes[0] == 4 && b_sizes[1] == 5); + AT_ASSERT(*b_strides[0] == 5 && *b_strides[1] == 1); - auto c_type = c->type()->cast(); - AT_ASSERT(*c_type->sizes().size() == 2); - AT_ASSERT(c_type->sizes().concrete_sizes() == c10::nullopt); - AT_ASSERT(c_type->strides().concrete_sizes() == c10::nullopt); - } - { - auto graph = std::make_shared(); - std::unordered_map vmap; - bool error_thrown = false; - try { - parseIR( - R"IR( -graph(%a : Float(4:5, 5)): + auto c_type = c->type()->cast(); + AT_ASSERT(*c_type->sizes().size() == 2); + AT_ASSERT(c_type->sizes().concrete_sizes() == c10::nullopt); + AT_ASSERT(c_type->strides().concrete_sizes() == c10::nullopt); +} + +TEST(IRParserTest, MalformedStrides) { + auto graph = std::make_shared(); + std::unordered_map vmap; + bool error_thrown = false; + EXPECT_ANY_THROW(parseIR( + R"IR( +graph(%a : Float(4, strides=[5], 5)): return (%a) )IR", - &*graph, - vmap); - } catch (const std::exception& error) { - error_thrown = true; - } - AT_ASSERT(error_thrown); - } - { - checkRoundtrip( - R"IR( + &*graph, + vmap)); +} + +TEST(IRParserTest, TensorShapes) { + checkRoundtrip( + R"IR( graph(%a : Float(4, 5), - %b : Float(4:5, 5:1), + %b : Float(4, 5, strides=[5, 1]), %c : Double(*, *)): return (%a) )IR"); - } - { - checkRoundtrip( - R"IR( +} + +TEST(IRParserTest, DeviceAndRequiresGradTensors) { + checkRoundtrip( + R"IR( graph(%a : Float(*, *, device=cpu), %b : Float(*, *, requires_grad=1), %c : Long(5, 10, requires_grad=1, device=cpu), %d : Float(5, requires_grad=0, device=cuda:2), - %e : Long(4:6, 3:2, 2:1, requires_grad=0, device=cuda:1), + %e : Long(4, 3, 1, strides=[6, 2, 1], requires_grad=0, device=cuda:1), %f : Float(), %g : Float(device=cpu), %h : Float(requires_grad=1), @@ -337,41 +335,45 @@ graph(%a : Float(*, *, device=cpu), %j : Double(*, *, requires_grad=0)): return (%a) )IR"); - } - { - auto graph = std::make_shared(); - parseIR( - R"IR( +} + +TEST(IRParserTest, ListConstant) { + auto graph = std::make_shared(); + parseIR( + R"IR( graph(): %d : int[] = prim::Constant[value=[1,2,3]]() return (%d) )IR", - &*graph); - Node* n = graph->outputs()[0]->node(); - AT_ASSERT(n->kind() == prim::Constant); - AT_ASSERT(n->kindOf(attr::value) == AttributeKind::ival); - const auto& genericList = n->ival(attr::value).toList(); - std::vector int_vals; - for (const IValue& ival : genericList) { - int_vals.push_back(ival.toInt()); - } - AT_ASSERT(int_vals.size() == 3); - AT_ASSERT(int_vals[0] == 1 && int_vals[1] == 2 && int_vals[2] == 3); + &*graph); + Node* n = graph->outputs()[0]->node(); + AT_ASSERT(n->kind() == prim::Constant); + AT_ASSERT(n->kindOf(attr::value) == AttributeKind::ival); + const auto& genericList = n->ival(attr::value).toList(); + std::vector int_vals; + for (const IValue& ival : genericList) { + int_vals.push_back(ival.toInt()); } - { - checkRoundtrip( - R"IR( + AT_ASSERT(int_vals.size() == 3); + AT_ASSERT(int_vals[0] == 1 && int_vals[1] == 2 && int_vals[2] == 3); +} + +TEST(IRParserTest, PartialStarTensor) { + checkRoundtrip( + R"IR( graph(%x : Float(10, *, 10)): return (%x) )IR"); - checkRoundtrip( - R"IR( +} + +TEST(IRParserTest, ComplexTensorAttributes) { + checkRoundtrip( + R"IR( graph(%x : Double(*, 200, *, requires_grad=1, device=cuda:1), %b : Float(5, *, requires_grad=1), %c : Long(*, 10, device=cpu)): return (%x) )IR"); - } } } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_jit_type.cpp b/test/cpp/jit/test_jit_type.cpp index 16c69ccd05fd5..8fd14d2525eb8 100644 --- a/test/cpp/jit/test_jit_type.cpp +++ b/test/cpp/jit/test_jit_type.cpp @@ -1,4 +1,5 @@ -#include +#include + #include #include #include "torch/csrc/jit/ir/ir.h" @@ -7,7 +8,7 @@ namespace torch { namespace jit { -void testUnifyTypes() { +TEST(JitTypeTest, UnifyTypes) { auto bool_tensor = TensorType::get()->withScalarType(at::kBool); auto opt_bool_tensor = OptionalType::create(bool_tensor); auto unified_opt_bool = unifyTypes(bool_tensor, opt_bool_tensor); @@ -17,7 +18,7 @@ void testUnifyTypes() { TORCH_INTERNAL_ASSERT(!tensor->isSubtypeOf(opt_bool_tensor)); auto unified = unifyTypes(opt_bool_tensor, tensor); TORCH_INTERNAL_ASSERT(unified); - auto elem = (*unified)->expect()->getElementType(); + auto elem = (*unified)->expectRef().getElementType(); TORCH_INTERNAL_ASSERT(elem->isSubtypeOf(TensorType::get())); auto opt_tuple_none_int = OptionalType::create( diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp index 814654dfc6973..e31b0f519f1bc 100644 --- a/test/cpp/jit/test_lite_interpreter.cpp +++ b/test/cpp/jit/test_lite_interpreter.cpp @@ -1,20 +1,30 @@ +#include + #include -#include #include #include #include #include +#include #include #include #include #include +#define ASSERT_THROWS_WITH(statement, substring) \ + try { \ + (void)statement; \ + ASSERT_TRUE(false); \ + } catch (const std::exception& e) { \ + ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \ + } + // Tests go in torch::jit namespace torch { namespace jit { -void testLiteInterpreterUpsampleNearest2d() { +TEST(LiteInterpreterTest, UpsampleNearest2d) { Module m("m"); m.define(R"( def forward(self, input: Tensor, scale:float): @@ -37,7 +47,26 @@ void testLiteInterpreterUpsampleNearest2d() { ASSERT_TRUE(resd.equal(refd)); } -void testLiteInterpreterAdd() { +TEST(LiteInterpreterTest, CheckAttrAccess) { + Module m("m"); + m.register_attribute("mobile_optimized", BoolType::get(), true); + + std::stringstream ss; + m._save_for_mobile(ss); + mobile::Module bc = _load_for_mobile(ss); + bool mobile_optimized = bc.attr("mobile_optimized", false).toBool(); + + AT_ASSERT(mobile_optimized); + m.setattr("mobile_optimized", false); + ss = std::stringstream(); + m._save_for_mobile(ss); + bc = _load_for_mobile(ss); + mobile_optimized = bc.attr("mobile_optimized", false).toBool(); + + AT_ASSERT(!mobile_optimized); +} + +TEST(LiteInterpreterTest, Add) { Module m("m"); m.register_parameter("foo", torch::ones({}), false); // TODO: support default param val, which was pushed in @@ -71,7 +100,7 @@ void testLiteInterpreterAdd() { AT_ASSERT(resd == refd); } -void testLiteInterpreterConv() { +TEST(LiteInterpreterTest, Conv) { auto s = std::getenv("PYTORCH_TEST_WITH_TSAN"); if (s && strcmp(s, "1") == 0) return; @@ -103,7 +132,7 @@ void testLiteInterpreterConv() { outputref[0][0][0][0].item() == output[0][0][0][0].item()); } -void testLiteInterpreterInline() { +TEST(LiteInterpreterTest, Inline) { Module m("m"); m.define(R"JIT( def foo1(self, x): @@ -123,7 +152,7 @@ void testLiteInterpreterInline() { AT_ASSERT(output.toTensor().item() == 7.0); } -void testLiteInterpreterTuple() { +TEST(LiteInterpreterTest, Tuple) { Module m("m"); m.define(R"JIT( def foo(self, x): @@ -141,7 +170,7 @@ void testLiteInterpreterTuple() { AT_ASSERT(output.toTuple()->elements()[1].toInt() == 2); } -void testLiteInterpreterDict() { +TEST(LiteInterpreterTest, Dict) { Module m("m"); m.define(R"JIT( def foo(self, x): @@ -159,7 +188,7 @@ void testLiteInterpreterDict() { AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2); } -void testLiteInterpreterPrimOverload() { +TEST(LiteInterpreterTest, PrimOverload) { /* // temporarily disabled script::Module m("m"); @@ -178,7 +207,7 @@ void testLiteInterpreterPrimOverload() { */ } -void testLiteInterpreterPrim() { +TEST(LiteInterpreterTest, Prim) { Module m("m"); m.define(R"JIT( def forward(self, x): @@ -204,7 +233,7 @@ void testLiteInterpreterPrim() { AT_ASSERT(resi == refi); } -void testLiteInterpreterPrimScalar() { +TEST(LiteInterpreterTest, PrimScalar) { Module m("m"); m.define(R"JIT( def forward(self, x): @@ -230,7 +259,7 @@ void testLiteInterpreterPrimScalar() { AT_ASSERT(resi == refi); } -void testLiteInterpreterLoadOrigJit() { +TEST(LiteInterpreterTest, LoadOrigJit) { Module m("m"); m.register_parameter("foo", torch::ones({}), false); m.define(R"( @@ -243,7 +272,7 @@ void testLiteInterpreterLoadOrigJit() { ASSERT_THROWS_WITH(_load_for_mobile(ss), "file not found"); } -void testLiteInterpreterWrongMethodName() { +TEST(LiteInterpreterTest, WrongMethodName) { Module m("m"); m.register_parameter("foo", torch::ones({}), false); m.define(R"( @@ -260,7 +289,7 @@ void testLiteInterpreterWrongMethodName() { ASSERT_THROWS_WITH(bc.get_method("forward")(inputs), "is not defined"); } -void testLiteInterpreterSetState() { +TEST(LiteInterpreterTest, SetState) { Module m("m"); m.register_parameter("foo", torch::ones({}), false); m.define(R"( @@ -308,7 +337,7 @@ class TorchBindLiteInterpreterTestStruct } }; -void testLiteInterpreterBuiltinFunction() { +TEST(LiteInterpreterTest, BuiltinFunction) { script::Module m("m"); auto custom_class_obj = make_custom_class(); @@ -328,7 +357,7 @@ void testLiteInterpreterBuiltinFunction() { AT_ASSERT(str == expected); } -void testLiteInterpreterModuleInfoBasic() { +TEST(LiteInterpreterTest, ModuleInfoBasic) { Module m("M"); m.define(R"JIT( def forward(self, x): @@ -357,7 +386,7 @@ void testLiteInterpreterModuleInfoBasic() { AT_ASSERT(module_debug_info_set == expected_result); } -void testLiteInterpreterNotSavingModuleInfo() { +TEST(LiteInterpreterTest, NotSaveModuleInfo) { Module m("M"); m.define(R"JIT( def forward(self, x): @@ -380,7 +409,7 @@ void testLiteInterpreterNotSavingModuleInfo() { } } -void testLiteInterpreterOneSubmoduleModuleInfo() { +TEST(LiteInterpreterTest, OneSubmoduleModuleInfo) { Module a("A"); a.define(R"JIT( def forward(self, x): @@ -416,7 +445,7 @@ void testLiteInterpreterOneSubmoduleModuleInfo() { AT_ASSERT(module_debug_info_set == expected_result); } -void testLiteInterpreterTwoSubmodulesModuleInfo() { +TEST(LiteInterpreterTest, TwoSubmodulesModuleInfo) { Module a("A"); a.define(R"JIT( def forward(self, x): @@ -458,7 +487,7 @@ void testLiteInterpreterTwoSubmodulesModuleInfo() { AT_ASSERT(module_debug_info_set == expected_result); } -void testLiteInterpreterSequentialModuleInfo() { +TEST(LiteInterpreterTest, SequentialModuleInfo) { Module a("A"); a.define(R"JIT( def forward(self, x): @@ -495,12 +524,35 @@ void testLiteInterpreterSequentialModuleInfo() { } } + // class A(nn.Module): + // def __init__(self): + // super(A, self).__init__() + + // def forward(self, x): + // return x + 1 + + // class B(nn.Module): + // def __init__(self): + // super(B, self).__init__() + + // def forward(self, x): + // return x + 2 + + // class C(nn.Module): + // def __init__(self): + // super(C, self).__init__() + // self.A0 = A() + // self.B0 = B() + + // def forward(self, x): + // return self.A0.forward(self.B0.forward(x)) + std::unordered_set expected_result( {"top(C).A0(A).forward", "top(C).B0(B).forward"}); AT_ASSERT(module_debug_info_set == expected_result); } -void testLiteInterpreterHierarchyModuleInfo() { +TEST(LiteInterpreterTest, HierarchyModuleInfo) { Module a("A"); a.define(R"JIT( def forward(self, x): @@ -540,13 +592,15 @@ void testLiteInterpreterHierarchyModuleInfo() { // There are 3 module information strings here. // "top(C).forward": for the add operator in top. // "top(C).B0(B).forward": for the add operator in B0. - // "top(C).B0(B).A0(A).forward": for the add operator in A0. + // "top(C).B0(B).forward.A0(A).forward": for the add operator in A0. std::unordered_set expected_result( - {"top(C).forward", "top(C).B0(B).forward", "top(C).B0(B).A0(A).forward"}); + {"top(C).forward", + "top(C).B0(B).forward", + "top(C).B0(B).forward.A0(A).forward"}); AT_ASSERT(module_debug_info_set == expected_result); } -void testLiteInterpreterDuplicatedClassTypeModuleInfo() { +TEST(LiteInterpreterTest, DuplicatedClassTypeModuleInfo) { Module a("A"); a.define(R"JIT( def forward(self, x): @@ -578,15 +632,33 @@ void testLiteInterpreterDuplicatedClassTypeModuleInfo() { } } - // The current approach is not able to distinguish between A0 and A1, - // which have the same class type. Hence, it only records module - // information for A1. + // class A(nn.Module): + // def __init__(self): + // super(A, self).__init__() + + // def forward(self, x): + // return x + 5 + + // class B(nn.Module): + // def __init__(self): + // super(B, self).__init__() + // self.A0 = A() + // self.A1 = A() + + // def forward(self, x): + // return self.A0.forward(x) + self.A1.forward(x) + + // There are 3 module information strings here. + // "top(B).forward": for the add operator in top. + // "top(B).A0(A).forward": for the add operator in A0. + // "top(B).A1(A).forward": for the add operator in A1. + std::unordered_set expected_result( - {"top(B).forward", "top(B).A1(A).forward"}); + {"top(B).forward", "top(B).A0(A).forward", "top(B).A1(A).forward"}); AT_ASSERT(module_debug_info_set == expected_result); } -void testLiteInterpreterEval() { +TEST(LiteInterpreterTest, Eval) { std::vector inputs; Module m("m"); @@ -619,7 +691,7 @@ void testLiteInterpreterEval() { outputref[0][0][0][0].item() == output[0][0][0][0].item()); } -void testLiteInterpreterFindWrongMethodName() { +TEST(LiteInterpreterTest, FindWrongMethodName) { Module m("m"); m.register_parameter("foo", torch::ones({}), false); m.define(R"( @@ -633,7 +705,7 @@ void testLiteInterpreterFindWrongMethodName() { ASSERT_TRUE(bc.find_method("forward") == c10::nullopt); } -void testLiteInterpreterFindAndRunMethod() { +TEST(LiteInterpreterTest, FindAndRunMethod) { Module m("m"); m.register_parameter("foo", torch::ones({}), false); m.define(R"( @@ -663,7 +735,7 @@ void testLiteInterpreterFindAndRunMethod() { AT_ASSERT(resd == refd); } -void testLiteInterpreterRunMethodVariadic() { +TEST(LiteInterpreterTest, RunMethodVariadic) { Module m("m"); m.register_parameter("foo", torch::ones({}), false); m.define(R"( @@ -686,6 +758,60 @@ void testLiteInterpreterRunMethodVariadic() { AT_ASSERT(resd == refd); } +TEST(LiteInterpreterTest, ExtraFiles) { + const auto script = R"JIT( + def forward(self): + x = torch.rand(5, 5) + x = x.mm(x) + return x + )JIT"; + + auto module = + std::make_shared("Module", std::make_shared()); + module->define(script); + std::ostringstream oss; + std::unordered_map extra_files; + extra_files["metadata.json"] = "abc"; + module->_save_for_mobile(oss, extra_files); + + std::istringstream iss(oss.str()); + caffe2::serialize::IStreamAdapter adapter{&iss}; + std::unordered_map loaded_extra_files; + loaded_extra_files["metadata.json"] = ""; + auto loaded_module = + torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files); + ASSERT_EQ(loaded_extra_files["metadata.json"], "abc"); +} + +TEST(LiteInterpreterTest, OpNameExportFetchRootOperators) { + torch::jit::Module m("m"); + m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false); + m.register_parameter("bias", torch::ones({20}), false); + m.define(R"( + def forward(self, input): + x1 = torch.zeros(2, 2) + x2 = torch.empty_like(torch.empty(2, 2)) + x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) + return (x1, x2, x3) + )"); + m.eval(); + + std::stringstream ss; + m._save_for_mobile(ss); + + torch::jit::mobile::Module ptl_model = torch::jit::_load_for_mobile(ss); + std::set operator_names = + torch::jit::mobile::_export_operator_list(ptl_model); + std::set expected_operator_names = { + "aten::_convolution", + "aten::empty.memory_format", + "aten::empty_like", + "aten::zeros", + }; + EXPECT_EQ(operator_names, expected_operator_names) + << "Expected the root operator lists to be the same"; +} + namespace { static auto reg = torch::class_( diff --git a/test/cpp/jit/test_lite_trainer.cpp b/test/cpp/jit/test_lite_trainer.cpp index b70c4db62c702..9a988ecb2db18 100644 --- a/test/cpp/jit/test_lite_trainer.cpp +++ b/test/cpp/jit/test_lite_trainer.cpp @@ -1,5 +1,6 @@ +#include + #include -#include #include #include #include @@ -16,7 +17,7 @@ namespace torch { namespace jit { -void testLiteInterpreterParams() { +TEST(LiteTrainerTest, Params) { Module m("m"); m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false); m.define(R"( @@ -74,7 +75,7 @@ void testLiteInterpreterParams() { AT_ASSERT(parameters[0].item() == bc_parameters[0].item()); } -void testMobileNamedParameters() { +TEST(MobileTest, NamedParameters) { Module m("m"); m.register_parameter("foo", torch::ones({}), false); m.define(R"( @@ -99,7 +100,7 @@ void testMobileNamedParameters() { } } -void testMobileSaveLoadData() { +TEST(MobileTest, SaveLoadData) { Module m("m"); m.register_parameter("foo", torch::ones({}), false); m.define(R"( @@ -127,7 +128,7 @@ void testMobileSaveLoadData() { } } -void testMobileSaveLoadParameters() { +TEST(MobileTest, SaveLoadParameters) { Module m("m"); m.register_parameter("foo", torch::ones({}), false); m.define(R"( @@ -157,7 +158,7 @@ void testMobileSaveLoadParameters() { } } -void testMobileSaveLoadParametersEmpty() { +TEST(MobileTest, SaveLoadParametersEmpty) { Module m("m"); m.define(R"( def add_it(self, x): @@ -180,7 +181,7 @@ void testMobileSaveLoadParametersEmpty() { AT_ASSERT(mobile_params.size() == 0); } -void testLiteSGD() { +TEST(LiteTrainerTest, SGD) { Module m("m"); m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false); m.define(R"( @@ -253,7 +254,7 @@ struct DummyDataset : torch::data::datasets::Dataset { }; } // namespace -void testLiteSequentialSampler() { +TEST(LiteTrainerTest, SequentialSampler) { // test that sampler can be used with dataloader const int kBatchSize = 10; auto data_loader = diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index 92baba1168daa..3256c897ace91 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -1,60 +1,57 @@ +#include + #include #include #include #include -#include "test/cpp/jit/test_base.h" -#include "test/cpp/jit/test_utils.h" - +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include +#include +#include #include -#include "torch/csrc/autograd/generated/variable_factories.h" -#include "torch/csrc/autograd/variable.h" -#include "torch/csrc/jit/codegen/fuser/interface.h" -#include "torch/csrc/jit/frontend/code_template.h" -#include "torch/csrc/jit/frontend/tracer.h" -#include "torch/csrc/jit/ir/alias_analysis.h" -#include "torch/csrc/jit/ir/attributes.h" -#include "torch/csrc/jit/ir/irparser.h" -#include "torch/csrc/jit/ir/scope.h" -#include "torch/csrc/jit/jit_log.h" -#include "torch/csrc/jit/passes/bailout_graph.h" -#include "torch/csrc/jit/passes/common_subexpression_elimination.h" -#include "torch/csrc/jit/passes/constant_propagation.h" -#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h" -#include "torch/csrc/jit/passes/dead_code_elimination.h" -#include "torch/csrc/jit/passes/graph_fuser.h" -#include "torch/csrc/jit/passes/guard_elimination.h" -#include "torch/csrc/jit/passes/inline_autodiff_subgraphs.h" -#include "torch/csrc/jit/passes/insert_guards.h" -#include "torch/csrc/jit/passes/liveness.h" -#include "torch/csrc/jit/passes/loop_unrolling.h" -#include "torch/csrc/jit/passes/lower_grad_of.h" -#include "torch/csrc/jit/passes/lower_tuples.h" -#include "torch/csrc/jit/passes/pass_manager.h" -#include "torch/csrc/jit/passes/requires_grad_analysis.h" -#include "torch/csrc/jit/passes/shape_analysis.h" -#include "torch/csrc/jit/passes/utils/subgraph_utils.h" -#include "torch/csrc/jit/runtime/argument_spec.h" -#include "torch/csrc/jit/runtime/autodiff.h" -#include "torch/csrc/jit/runtime/custom_operator.h" -#include "torch/csrc/jit/runtime/interpreter.h" -#include "torch/csrc/jit/runtime/symbolic_script.h" -#include "torch/csrc/jit/serialization/import.h" - -#include "torch/csrc/autograd/engine.h" -#include "torch/csrc/autograd/variable.h" - +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include +#include +#include +#include +#include #include +#include #include -#include "torch/csrc/jit/api/module.h" -#include "torch/csrc/jit/frontend/ir_emitter.h" -#include "torch/csrc/jit/runtime/profiling_record.h" -#include "torch/jit.h" - -#include "onnx/onnx_pb.h" +#include #include #include @@ -64,6 +61,7 @@ #include #include #include +#include #include #include #include @@ -92,7 +90,7 @@ std::ostream& operator<<(std::ostream& out, const std::vector& list) { return out; } -void testInternedStrings() { +TEST(InternedStringsTest, Basic) { ASSERT_EQ(prim::Param, Symbol::prim("Param")); ASSERT_EQ(prim::Return, Symbol::prim("Return")); ASSERT_EQ(prim::Return.toUnqualString(), std::string("Return")); @@ -108,7 +106,7 @@ void testInternedStrings() { ASSERT_EQ(Symbol(symstart + 2).toUnqualString(), std::string("What2")); } -void testFromQualString() { +TEST(FromQualStringTest, Basic) { ASSERT_EQ(Symbol::fromQualString("prim::Param"), Symbol::prim("Param")); ASSERT_EQ(Symbol::fromQualString("aten::mm"), Symbol::aten("mm")); ASSERT_EQ(Symbol::fromQualString("onnx::LSTM"), Symbol::onnx("LSTM")); @@ -138,7 +136,7 @@ void testFromQualString() { } } -void testTHNNConv() { +TEST(THNNConvTest, Basic) { std::vector input_size = {4, 3, 15, 17}; // B x C x H x W std::vector kernel_size = {3, 5}; std::vector stride = {1, 2}; @@ -233,7 +231,7 @@ void testTHNNConv() { assertAllClose(tensor_grads_out, expected_tensor_grads_out); } -void testATenNativeBatchNorm() { +TEST(ATenNativeBatchNormTest, Basic) { // aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor // running_mean, Tensor running_var, bool training, float momentum, float eps) // -> (Tensor, Tensor, Tensor) @@ -365,7 +363,11 @@ void testATenNativeBatchNorm() { assertAllClose(tensor_grads_out, expected_tensor_grads_out); } -void testCustomFusion() { +TEST(CustomFusionTest, Basic) { +#if defined(FBCODE_CAFFE2) + return; +#endif + auto graph_string = R"IR( graph(%0 : Float(2, 3, 4), %1 : Float(2, 3, 4)): @@ -399,7 +401,11 @@ void testCustomFusion() { AT_ASSERT(hits == 2); } -void testCustomFusionNestedBlocks() { +TEST(CustomFusionTest, NestedBlocks) { +#if defined(FBCODE_CAFFE2) + return; +#endif + auto graph_string = R"IR( graph(%0 : Float(2, 3, 4), %1 : Float(2, 3, 4), @@ -461,7 +467,8 @@ static const auto cf_examples = R"JIT( i += 1 return a )JIT"; -void testControlFlow() { + +TEST(ControlFlowTest, Basic) { auto cu = compile(cf_examples); auto run = [&](const std::string& name, std::vector stack) { @@ -484,170 +491,176 @@ void testControlFlow() { ASSERT_EQ(256, run_binary("while_test", 2, 0)); } -void testProto() { +TEST(ProtoTest, Basic) { ::ONNX_NAMESPACE::ModelProto proto; proto.set_producer_name("foo"); } // test a few features that are not directly used in schemas yet -void testSchemaParser() { +TEST(SchemaParserTest, NestedArrays) { // nested arrays auto s = parseSchema("at::what(int[][4] foo) -> ()"); ASSERT_TRUE(s.arguments().at(0).N() == 4); ASSERT_TRUE(IntType::get()->isSubtypeOf(s.arguments() .at(0) .type() - ->expect() - ->getElementType() - ->expect() - ->getElementType())); + ->expectRef() + .getElementType() + ->expectRef() + .getElementType())); auto s2 = parseSchema("at::what(int[][] foo) -> ()"); ASSERT_TRUE(IntType::get()->isSubtypeOf(s2.arguments() .at(0) .type() - ->expect() - ->getElementType() - ->expect() - ->getElementType())); + ->expectRef() + .getElementType() + ->expectRef() + .getElementType())); +} +TEST(SchemaParserTest, NamedReturns) { // named returns parseSchema("at::what(Tensor! i_will_be_written_to) -> ()"); auto s3 = parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)"); ASSERT_TRUE(s3.returns().at(0).name() == "the_return"); ASSERT_TRUE(s3.returns().at(1).name() == "the_return2"); +} +TEST(SchemaParserTest, Futures) { // futures auto s4 = parseSchema("at::what(Future(int) foo) -> ()"); ASSERT_TRUE(IntType::get()->isSubtypeOf( - s4.arguments().at(0).type()->expect()->getElementType())); + s4.arguments().at(0).type()->expectRef().getElementType())); +} +TEST(SchemaParserTest, AnnotatedAliasSets) { // test tensor with annotated alias sets parseSchema("at::what(Tensor(a) foo) -> (Tensor(a))"); +} - { - const auto s = parseSchema( - "at::what(Tensor(b|c)[](a!) list, Tensor(c) element)" - " -> (Tensor(b|c)[](a!))"); - - // The list itself is annotated with `a` - const auto& aliasInfo = *s.arguments().at(0).alias_info(); - ASSERT_TRUE( - aliasInfo.beforeSets() == - std::unordered_set{Symbol::fromQualString("alias::a")}); - ASSERT_TRUE(aliasInfo.isWrite()); - - // Check the contained types - ASSERT_TRUE(!aliasInfo.containedTypes().empty()); - const auto& containedAliasInfo = aliasInfo.containedTypes()[0]; - const auto expected = std::unordered_set{ - Symbol::fromQualString("alias::b"), - Symbol::fromQualString("alias::c"), - }; - ASSERT_TRUE(containedAliasInfo.beforeSets() == expected); - ASSERT_TRUE(containedAliasInfo.afterSets() == expected); - ASSERT_FALSE(containedAliasInfo.isWrite()); - } - { - const auto s = parseSchema( - "at::what(Tensor(b -> b|c)[](a!) list, Tensor(c) element)" - " -> (Tensor(b|c)[](a!))"); - - // The list itself is annotated with `a` - const auto& aliasInfo = *s.arguments().at(0).alias_info(); - ASSERT_EQ( - aliasInfo.beforeSets(), - std::unordered_set{Symbol::fromQualString("alias::a")}); - ASSERT_EQ( - aliasInfo.afterSets(), - std::unordered_set{Symbol::fromQualString("alias::a")}); - ASSERT_TRUE(aliasInfo.isWrite()); - ASSERT_EQ(aliasInfo.containedTypes().size(), 1); - - // Check the contained types - ASSERT_TRUE(!aliasInfo.containedTypes().empty()); - const auto& containedAliasInfo = aliasInfo.containedTypes()[0]; - const auto expectedBefore = std::unordered_set{ - Symbol::fromQualString("alias::b"), - }; - const auto expectedAfter = std::unordered_set{ - Symbol::fromQualString("alias::b"), Symbol::fromQualString("alias::c")}; - ASSERT_TRUE(containedAliasInfo.beforeSets() == expectedBefore); - ASSERT_TRUE(containedAliasInfo.afterSets() == expectedAfter); - ASSERT_FALSE(containedAliasInfo.isWrite()); - } +TEST(SchemaParserTest, BeforeAfterSets) { + const auto s = parseSchema( + "at::what(Tensor(b|c)[](a!) list, Tensor(c) element)" + " -> (Tensor(b|c)[](a!))"); + + // The list itself is annotated with `a` + const auto& aliasInfo = *s.arguments().at(0).alias_info(); + ASSERT_TRUE( + aliasInfo.beforeSets() == + std::unordered_set{Symbol::fromQualString("alias::a")}); + ASSERT_TRUE(aliasInfo.isWrite()); + + // Check the contained types + ASSERT_TRUE(!aliasInfo.containedTypes().empty()); + const auto& containedAliasInfo = aliasInfo.containedTypes()[0]; + const auto expected = std::unordered_set{ + Symbol::fromQualString("alias::b"), + Symbol::fromQualString("alias::c"), + }; + ASSERT_TRUE(containedAliasInfo.beforeSets() == expected); + ASSERT_TRUE(containedAliasInfo.afterSets() == expected); + ASSERT_FALSE(containedAliasInfo.isWrite()); } -void testTopologicalIndex() { - { - Graph graph; - auto node1 = graph.create(prim::AutogradZero); - auto node2 = graph.create(prim::AutogradZero); - auto node3 = graph.create(prim::AutogradZero); - auto node4 = graph.create(prim::AutogradZero); - - graph.appendNode(node4); - graph.prependNode(node1); - node2->insertAfter(node1); - node3->insertBefore(node4); - - // nodes should be in numerical order - ASSERT_TRUE(node1->isBefore(node2)); - ASSERT_TRUE(node1->isBefore(node3)); - ASSERT_TRUE(node1->isBefore(node4)); - ASSERT_TRUE(node2->isAfter(node1)); - ASSERT_TRUE(node2->isBefore(node3)); - ASSERT_TRUE(node2->isBefore(node4)); - ASSERT_FALSE(node3->isBefore(node1)); - ASSERT_FALSE(node3->isBefore(node2)); - ASSERT_FALSE(node3->isAfter(node4)); - - // Built up a block structure - // node3 - // /\ ... - // A B block1 - // \ ... - // C block2 - auto block1 = node3->addBlock(); - auto A = graph.create(prim::AutogradZero); - block1->appendNode(A); - auto B = graph.create(prim::AutogradZero); - block1->appendNode(B); - auto block2 = B->addBlock(); - auto C = graph.create(prim::AutogradZero); - block2->appendNode(C); - - // Check isAfter on different block levels - ASSERT_TRUE(node1->isBefore(A)); - ASSERT_TRUE(A->isBefore(B)); - ASSERT_TRUE(A->isBefore(C)); - - // make sure things don't blow up on deletions - node2->destroy(); - auto node2p = graph.create(prim::AutogradZero); - node2p->insertAfter(node1); - ASSERT_TRUE(node1->isBefore(node2p)); - ASSERT_TRUE(node2p->isBefore(node3)); +TEST(SchemaParserTest, BeforeAfterSets2) { + const auto s = parseSchema( + "at::what(Tensor(b -> b|c)[](a!) list, Tensor(c) element)" + " -> (Tensor(b|c)[](a!))"); + + // The list itself is annotated with `a` + const auto& aliasInfo = *s.arguments().at(0).alias_info(); + ASSERT_EQ( + aliasInfo.beforeSets(), + std::unordered_set{Symbol::fromQualString("alias::a")}); + ASSERT_EQ( + aliasInfo.afterSets(), + std::unordered_set{Symbol::fromQualString("alias::a")}); + ASSERT_TRUE(aliasInfo.isWrite()); + ASSERT_EQ(aliasInfo.containedTypes().size(), 1); + + // Check the contained types + ASSERT_TRUE(!aliasInfo.containedTypes().empty()); + const auto& containedAliasInfo = aliasInfo.containedTypes()[0]; + const auto expectedBefore = std::unordered_set{ + Symbol::fromQualString("alias::b"), + }; + const auto expectedAfter = std::unordered_set{ + Symbol::fromQualString("alias::b"), Symbol::fromQualString("alias::c")}; + ASSERT_TRUE(containedAliasInfo.beforeSets() == expectedBefore); + ASSERT_TRUE(containedAliasInfo.afterSets() == expectedAfter); + ASSERT_FALSE(containedAliasInfo.isWrite()); +} + +TEST(TopologicalIndexTest, Basic) { + Graph graph; + auto node1 = graph.create(prim::AutogradZero); + auto node2 = graph.create(prim::AutogradZero); + auto node3 = graph.create(prim::AutogradZero); + auto node4 = graph.create(prim::AutogradZero); + + graph.appendNode(node4); + graph.prependNode(node1); + node2->insertAfter(node1); + node3->insertBefore(node4); + + // nodes should be in numerical order + ASSERT_TRUE(node1->isBefore(node2)); + ASSERT_TRUE(node1->isBefore(node3)); + ASSERT_TRUE(node1->isBefore(node4)); + ASSERT_TRUE(node2->isAfter(node1)); + ASSERT_TRUE(node2->isBefore(node3)); + ASSERT_TRUE(node2->isBefore(node4)); + ASSERT_FALSE(node3->isBefore(node1)); + ASSERT_FALSE(node3->isBefore(node2)); + ASSERT_FALSE(node3->isAfter(node4)); + + // Built up a block structure + // node3 + // /\ ... + // A B block1 + // \ ... + // C block2 + auto block1 = node3->addBlock(); + auto A = graph.create(prim::AutogradZero); + block1->appendNode(A); + auto B = graph.create(prim::AutogradZero); + block1->appendNode(B); + auto block2 = B->addBlock(); + auto C = graph.create(prim::AutogradZero); + block2->appendNode(C); + + // Check isAfter on different block levels + ASSERT_TRUE(node1->isBefore(A)); + ASSERT_TRUE(A->isBefore(B)); + ASSERT_TRUE(A->isBefore(C)); + + // make sure things don't blow up on deletions + node2->destroy(); + auto node2p = graph.create(prim::AutogradZero); + node2p->insertAfter(node1); + ASSERT_TRUE(node1->isBefore(node2p)); + ASSERT_TRUE(node2p->isBefore(node3)); +} + +TEST(TopologicalIndexTest, Reindex) { + // Induce reindexing to test that path + Graph graph; + std::map nodes; + + auto anchor = graph.create(prim::AutogradZero); + graph.appendNode(anchor); + // Inserting to the same place a lot will trigger reindexing + for (auto i = 0; i < 100; ++i) { + auto n = graph.create(prim::AutogradZero); + n->insertAfter(anchor); + nodes[i] = n; } - { - // Induce reindexing to test that path - Graph graph; - std::map nodes; - - auto anchor = graph.create(prim::AutogradZero); - graph.appendNode(anchor); - // Inserting to the same place a lot will trigger reindexing - for (auto i = 0; i < 100; ++i) { - auto n = graph.create(prim::AutogradZero); - n->insertAfter(anchor); - nodes[i] = n; - } - // Nodes should be in reverse order - for (auto i = 0; i < 100; ++i) { - for (auto j = i + 1; j < 100; ++j) { - ASSERT_TRUE(nodes[i]->isAfter(nodes[j])); - } + // Nodes should be in reverse order + for (auto i = 0; i < 100; ++i) { + for (auto j = i + 1; j < 100; ++j) { + ASSERT_TRUE(nodes[i]->isAfter(nodes[j])); } } } @@ -708,12 +721,40 @@ void checkTracedInputs(const TracedTestInputs& inputs) { TORCH_CHECK(found_mul); } +static bool bad_scope = false; +template +std::unique_ptr checkScopeCallback( + const at::RecordFunction& fn) { + if (fn.scope() == scope) { + ++(*cnt); + } else { + bad_scope = true; + } + return nullptr; +} + +template +void pushScopedCallback() { + at::addGlobalCallback( + at::RecordFunctionCallback(checkScopeCallback) + .scopes({scope})); +} + +// These cannot be function-local because that would prohibit them +// from being used as template arguments prior to C++17. +static size_t fun_cnt; +static size_t ts_fun_cnt; +static size_t user_scope_cnt; + void checkScopeCallbacks() { - bool found_function_scope = false; - bool found_method_scope = false; - bool found_user_scope = false; + static bool found_function_scope; + static bool found_method_scope; + static bool found_user_scope; + found_function_scope = false; + found_method_scope = false; + found_user_scope = false; at::addGlobalCallback(at::RecordFunctionCallback( - [&](const at::RecordFunction& fn) { + [](const at::RecordFunction& fn) -> std::unique_ptr { if (fn.scope() == at::RecordScope::FUNCTION && std::string(fn.name().str()) == "test_function") { found_function_scope = true; @@ -726,31 +767,16 @@ void checkScopeCallbacks() { std::string(fn.name().str()) == "test_user_scope") { found_user_scope = true; } - }, - [](const at::RecordFunction&) {})); - - bool bad_scope = false; - auto pushScopedCallback = [&](at::RecordScope scope, size_t& cnt) { - at::addGlobalCallback( - at::RecordFunctionCallback( - [&bad_scope, &cnt, scope](const at::RecordFunction& fn) { - if (fn.scope() == scope) { - ++cnt; - } else { - bad_scope = true; - } - return true; - }, - [](const at::RecordFunction&) {}) - .scopes({scope})); - }; + return nullptr; + })); - size_t fun_cnt = 0; - pushScopedCallback(at::RecordScope::FUNCTION, fun_cnt); - size_t ts_fun_cnt = 0; - pushScopedCallback(at::RecordScope::TORCHSCRIPT_FUNCTION, ts_fun_cnt); - size_t user_scope_cnt = 0; - pushScopedCallback(at::RecordScope::USER_SCOPE, user_scope_cnt); + bad_scope = false; + fun_cnt = 0; + pushScopedCallback(); + ts_fun_cnt = 0; + pushScopedCallback(); + user_scope_cnt = 0; + pushScopedCallback(); TORCH_CHECK(at::hasCallbacks()); @@ -770,33 +796,41 @@ void checkScopeCallbacks() { TORCH_CHECK(found_user_scope); } -void testRecordFunction() { +static bool should_run = false; + +static bool shouldRunCallback(const RecordFunctionCallback&) { + return should_run; +} + +static TracedTestInputs traced_inputs; +static std::unordered_set ts_names; + +std::unique_ptr tracedInputsCallback( + const RecordFunction& fn) { + if (fn.scope() == RecordScope::FUNCTION) { + auto inputs = fn.inputs(); + std::vector> sizes; + for (const auto& input : inputs) { + if (input.isTensor()) { + sizes.push_back(input.toTensor().sizes().vec()); + } else if (input.isScalar()) { + sizes.push_back(std::vector()); + } + } + traced_inputs.push_back(std::make_tuple(fn.name().str(), sizes)); + } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) { + ts_names.insert(fn.name().str()); + } + return nullptr; +} + +TEST(RecordFunctionTest, TracedTestInputs) { // disabling the inlining of method calls GraphOptimizerEnabledGuard opt_guard(false); // [(fn, [[sizes], [sizes], ...]), ...] - TracedTestInputs traced_inputs; - std::unordered_set ts_names; addGlobalCallback( - RecordFunctionCallback( - [&](const RecordFunction& fn) { - if (fn.scope() == RecordScope::FUNCTION) { - auto inputs = fn.inputs(); - std::vector> sizes; - for (const auto& input : inputs) { - if (input.isTensor()) { - sizes.push_back(input.toTensor().sizes().vec()); - } else if (input.isScalar()) { - sizes.push_back(std::vector()); - } - } - traced_inputs.push_back(std::make_tuple(fn.name().str(), sizes)); - } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) { - ts_names.insert(fn.name().str()); - } - }, - [](const RecordFunction&) {}) - .needsInputs(true)); + RecordFunctionCallback(tracedInputsCallback).needsInputs(true)); TracedTestInputs eager_inputs, jit_inputs; { @@ -817,37 +851,42 @@ void testRecordFunction() { traced_inputs.clear(); } - TORCH_CHECK(ts_names.size() == 2); TORCH_CHECK(ts_names.find("forward") != ts_names.end()); TORCH_CHECK(ts_names.find("foo") != ts_names.end()); checkTracedInputs(eager_inputs); checkTracedInputs(jit_inputs); at::clearCallbacks(); +} + +static int sampled_cb_ctr = 0; +std::unique_ptr sampledCallback(const RecordFunction& fn) { + if (std::string(fn.name().str()) == "test") { + ++sampled_cb_ctr; + } + return nullptr; +} + +static int non_sampled_cb_ctr = 0; +std::unique_ptr nonSampledCallback(const RecordFunction& fn) { + if (std::string(fn.name().str()) == "test") { + ++non_sampled_cb_ctr; + } + return nullptr; +} + +TEST(RecordFunctionTest, SampledCallbacks) { + // disabling the inlining of method calls + GraphOptimizerEnabledGuard opt_guard(false); // test sampled callbacks - int sampled_cb_ctr = 0; - auto setup_sampled_callback = [&sampled_cb_ctr](double sampling_prob) { - return addGlobalCallback(RecordFunctionCallback( - [&sampled_cb_ctr](const RecordFunction& fn) { - if (std::string(fn.name().str()) == "test") { - ++sampled_cb_ctr; - } - return true; - }, - [](const RecordFunction&) {}) - .samplingProb(sampling_prob)); + sampled_cb_ctr = 0; + auto setup_sampled_callback = [](double sampling_prob) { + return addGlobalCallback( + RecordFunctionCallback(sampledCallback).samplingProb(sampling_prob)); }; - int non_sampled_cb_ctr = 0; - addGlobalCallback(RecordFunctionCallback( - [&non_sampled_cb_ctr](const RecordFunction& fn) { - if (std::string(fn.name().str()) == "test") { - ++non_sampled_cb_ctr; - } - return true; - }, - [](const RecordFunction&) {})); + addGlobalCallback(RecordFunctionCallback(nonSampledCallback)); auto handle = setup_sampled_callback(0.5); @@ -882,17 +921,22 @@ void testRecordFunction() { // test the scope of the callbacks checkScopeCallbacks(); clearCallbacks(); +} + +TEST(RecordFunctionTest, RecordFunctionGuard) { + // disabling the inlining of method calls + GraphOptimizerEnabledGuard opt_guard(false); + + static std::vector fn_names; + static std::mutex guard_mtx; // check record function guard - std::vector fn_names; - std::mutex mtx; addGlobalCallback(RecordFunctionCallback( - [&fn_names, &mtx](const RecordFunction& fn) { - std::lock_guard lock(mtx); + [](const RecordFunction& fn) -> std::unique_ptr { + std::lock_guard lock(guard_mtx); fn_names.push_back(fn.name().str()); - return true; - }, - [](const RecordFunction&) {})); + return nullptr; + })); { RecordFunctionGuard g1(false); { @@ -911,18 +955,26 @@ void testRecordFunction() { TORCH_CHECK(fn_names.size() == 1); TORCH_CHECK(fn_names[0] == "B"); clearCallbacks(); +} - // test add/remove - std::vector ids; - auto add_remove_test_add_cb = [&ids](size_t id) { - return addGlobalCallback(RecordFunctionCallback( - [&ids, id](const RecordFunction& fn) { ids.push_back(id); }, - [](const RecordFunction&) {})); - }; +static std::vector ids; + +template +auto add_remove_test_add_cb() { + return addGlobalCallback(RecordFunctionCallback( + [](const RecordFunction& fn) -> std::unique_ptr { + ids.push_back(id); + return nullptr; + })); +} + +TEST(RecordFunctionTest, Callbacks) { + // disabling the inlining of method calls + GraphOptimizerEnabledGuard opt_guard(false); - auto h1 = add_remove_test_add_cb(1); - auto h2 = add_remove_test_add_cb(2); - auto h3 = add_remove_test_add_cb(3); + auto h1 = add_remove_test_add_cb<1>(); + auto h2 = add_remove_test_add_cb<2>(); + auto h3 = add_remove_test_add_cb<3>(); { RECORD_USER_SCOPE("test"); } @@ -953,9 +1005,7 @@ void testRecordFunction() { // thread local / global callbacks ids.clear(); - addGlobalCallback(RecordFunctionCallback( - [&ids](const RecordFunction& fn) { ids.push_back(1); }, - [](const RecordFunction&) {})); + add_remove_test_add_cb<1>(); { RECORD_USER_SCOPE("test"); } @@ -963,10 +1013,12 @@ void testRecordFunction() { TORCH_CHECK(ids[0] == 1); ids.clear(); - auto th = std::thread([&ids]() { + auto th = std::thread([]() { addThreadLocalCallback(RecordFunctionCallback( - [&ids](const RecordFunction& fn) { ids.push_back(2); }, - [](const RecordFunction&) {})); + [](const RecordFunction& fn) -> std::unique_ptr { + ids.push_back(2); + return nullptr; + })); { RECORD_USER_SCOPE("test_thread"); } }); @@ -991,22 +1043,20 @@ void testRecordFunction() { }; ids.clear(); { // START: global test - const int test_val = 123; - const std::string test_str = "test str"; addGlobalCallback(RecordFunctionCallback( - [test_val, test_str, &ids](const RecordFunction& /* unused */) { + [](const RecordFunction& + /* unused */) -> std::unique_ptr { auto ctx = std::make_unique(); - ctx->a = test_val; - ctx->b = test_str; + ctx->a = 123; + ctx->b = "test_str"; ids.push_back(1); return ctx; }, - [test_val, test_str]( - const RecordFunction& /* unused */, ObserverContext* ctx_ptr) { + [](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) { auto ctx = dynamic_cast(ctx_ptr); TORCH_CHECK(ctx_ptr != nullptr); - TORCH_CHECK(ctx->a == test_val); - TORCH_CHECK(ctx->b == test_str); + TORCH_CHECK(ctx->a == 123); + TORCH_CHECK(ctx->b == "test_str"); })); { RECORD_USER_SCOPE("test"); } @@ -1016,23 +1066,23 @@ void testRecordFunction() { ids.clear(); } // END: global test { // START: thread local test - auto ctx_th = std::thread([&ids]() { + auto ctx_th = std::thread([]() { const int test_val = 234; const std::string test_str = "test thread str"; addThreadLocalCallback(RecordFunctionCallback( - [test_val, test_str, &ids](const RecordFunction& /* unused */) { + [](const RecordFunction& + /* unused */) -> std::unique_ptr { auto ctx = std::make_unique(); - ctx->a = test_val; - ctx->b = test_str; + ctx->a = 234; + ctx->b = "test_thread_str"; ids.push_back(2); return ctx; }, - [test_val, test_str]( - const RecordFunction& /* unused */, ObserverContext* ctx_ptr) { + [](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) { auto ctx = dynamic_cast(ctx_ptr); TORCH_CHECK(ctx_ptr != nullptr); - TORCH_CHECK(ctx->a == test_val); - TORCH_CHECK(ctx->b == test_str); + TORCH_CHECK(ctx->a == 234); + TORCH_CHECK(ctx->b == "test_thread_str"); })); // Will call both global and thread local callbacks. @@ -1046,18 +1096,21 @@ void testRecordFunction() { } // END: thread local test clearCallbacks(); +} - // test should_run +TEST(RecordFunctionTest, ShouldRun) { + // disabling the inlining of method calls + GraphOptimizerEnabledGuard opt_guard(false); - bool ran = false; - bool should_run = false; + should_run = false; + static bool ran = false; addGlobalCallback( RecordFunctionCallback( - [&ran](const RecordFunction& fn) { ran = true; }, - [](const RecordFunction&) {}) - .setShouldRun([&should_run](const RecordFunctionCallback&) { - return should_run; - })); + [](const RecordFunction& fn) -> std::unique_ptr { + ran = true; + return nullptr; + }) + .setShouldRun(shouldRunCallback)); { RECORD_USER_SCOPE("test"); } @@ -1070,47 +1123,85 @@ void testRecordFunction() { TORCH_CHECK(ran); clearCallbacks(); +} + +TEST(RecordFunctionTest, Basic) { + // disabling the inlining of method calls + GraphOptimizerEnabledGuard opt_guard(false); + + static std::string recorded_op; + static bool has_ids = false; // test propagation of TLS callbacks std::thread t([]() { RecordFunctionGuard enable_rec_fn; - std::string recorded_op; auto handle = addThreadLocalCallback(RecordFunctionCallback( - [&recorded_op](const RecordFunction& fn) { + [](const RecordFunction& fn) -> std::unique_ptr { recorded_op = fn.name().str(); - }, - [](const RecordFunction&) {})); + return nullptr; + })); ThreadLocalState state; std::thread t_child([state]() { ThreadLocalStateGuard g_tls(state); RECORD_USER_SCOPE("test_in_thread"); }); t_child.join(); - TORCH_CHECK(recorded_op == "test_in_thread"); + EXPECT_EQ(recorded_op, "test_in_thread"); removeCallback(handle); }); t.join(); clearCallbacks(); // test set ids - bool has_ids = false; addGlobalCallback( RecordFunctionCallback( - [&has_ids](const RecordFunction& fn) { has_ids = fn.handle() > 0; }, - [](const RecordFunction&) {}) + [](const RecordFunction& fn) -> std::unique_ptr { + has_ids = fn.handle() > 0; + return nullptr; + }) .needsIds(true)); { RECORD_USER_SCOPE("test"); } TORCH_CHECK(has_ids); clearCallbacks(); has_ids = false; addGlobalCallback(RecordFunctionCallback( - [&has_ids](const RecordFunction& fn) { has_ids = fn.handle() > 0; }, - [](const RecordFunction&) {})); + [](const RecordFunction& fn) -> std::unique_ptr { + has_ids = fn.handle() > 0; + return nullptr; + })); { RECORD_USER_SCOPE("test"); } TORCH_CHECK(!has_ids); clearCallbacks(); } +TEST(RecordFunctionTest, OperatorNameOverload) { + static std::set operator_names; + at::addGlobalCallback(at::RecordFunctionCallback( + [](const at::RecordFunction& fn) + -> std::unique_ptr { + c10::optional op_name = + fn.operator_name(); + if (op_name.has_value()) { + operator_names.insert(c10::toString(*op_name)); + } else { + operator_names.insert("No Operator Name"); + } + return nullptr; + }) + .scopes({at::RecordScope::FUNCTION})); + auto t = torch::randn({1, 2, 3}, at::kCPU); + t.set_requires_grad(false); + auto t2 = t.pow(2); + + at::clearCallbacks(); + EXPECT_TRUE(operator_names.count("No Operator Name") == 0) + << "Expected that all traced operators had an associated OperatorName object"; + EXPECT_TRUE(operator_names.count("aten::randn") == 1) + << "Expected aten::randn to have been called and recorded, but it was not"; + EXPECT_TRUE(operator_names.count("aten::pow.Tensor_Scalar") == 1) + << "Expected aten::pow.Tensor_Scalar to have been called and recorded, but it was not"; +} + class TestThreadLocalDebugInfo : public c10::DebugInfoBase { public: int getModelId() const { @@ -1128,15 +1219,16 @@ class TestThreadLocalDebugInfo : public c10::DebugInfoBase { }; void checkDebugInfo(c10::DebugInfoKind kind, int model_id) { - auto debug_info = c10::ThreadLocalDebugInfo::get(kind); + auto* debug_info = c10::ThreadLocalDebugInfo::get(kind); TORCH_CHECK(debug_info != nullptr); - auto* test_debug_info = - dynamic_cast(debug_info.get()); + auto* test_debug_info = dynamic_cast(debug_info); TORCH_CHECK(test_debug_info != nullptr); TORCH_CHECK(test_debug_info->getModelId() == model_id); } -void testThreadLocalDebugInfo() { +TEST(ThreadLocalDebugInfoTest, Basic) { + static std::atomic done{false}; + TORCH_CHECK( c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr); auto debug_info = std::make_shared(); @@ -1149,10 +1241,9 @@ void testThreadLocalDebugInfo() { // check that thread local debug info is propagated through fork calls TORCH_CHECK( c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr); - std::atomic done{false}; { c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info); - at::launch([&done]() { + at::launch([]() { checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); done = true; }); @@ -1165,12 +1256,11 @@ void testThreadLocalDebugInfo() { c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr); done = false; auto handle = addGlobalCallback(RecordFunctionCallback( - [&done](const RecordFunction&) { + [](const RecordFunction&) -> std::unique_ptr { checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); done = true; - return true; - }, - [](const RecordFunction&) {})); + return nullptr; + })); { c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info); auto t = torch::randn({1, 2, 3}, at::kCPU); @@ -1196,7 +1286,7 @@ void testThreadLocalDebugInfo() { checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314); done = false; - at::launch([&done]() { + at::launch([]() { checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314); done = true; @@ -1209,7 +1299,7 @@ void testThreadLocalDebugInfo() { } } -void testFallbackGraphs() { +TEST(FallbackGraphsTest, Basic) { static const auto nestGraphIntoFallbackGraph = [](const std::shared_ptr& graph) { ProfilingRecord::removeProfileCounter(graph->block()); @@ -1285,35 +1375,36 @@ void testFallbackGraphs() { } } -void testAutogradProfiler() { - constexpr int batch_size = 4; - constexpr int input_size = 256; - constexpr int seq_len = 32; - - int hidden_size = 2 * input_size; - auto input = torch::randn({seq_len, batch_size, input_size}, at::kCPU); - auto hx = torch::randn({batch_size, hidden_size}, at::kCPU); - auto cx = torch::randn({batch_size, hidden_size}, at::kCPU); - auto w_ih = t_def(torch::randn({4 * hidden_size, input_size}, at::kCPU)); - auto w_hh = t_def(torch::randn({4 * hidden_size, hidden_size}, at::kCPU)); - - std::stringstream ss; - { - RecordProfile guard(ss); - for (size_t i = 0; i < 100; ++i) { - std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh); - } - } - - std::string result = ss.str(); - size_t count = 0; - for (size_t pos = 0; (pos = result.find("tanh", pos)) != std::string::npos; - count++, pos++) { - } - TORCH_CHECK(count == 200); -} - -void testNoneSchemaMatch() { +// TODO this test wasn't running and is broken. +// TEST(AutogradProfilerTest, Basic) { +// constexpr int batch_size = 4; +// constexpr int input_size = 256; +// constexpr int seq_len = 32; + +// int hidden_size = 2 * input_size; +// auto input = torch::randn({seq_len, batch_size, input_size}, at::kCPU); +// auto hx = torch::randn({batch_size, hidden_size}, at::kCPU); +// auto cx = torch::randn({batch_size, hidden_size}, at::kCPU); +// auto w_ih = t_def(torch::randn({4 * hidden_size, input_size}, at::kCPU)); +// auto w_hh = t_def(torch::randn({4 * hidden_size, hidden_size}, at::kCPU)); + +// std::stringstream ss; +// { +// RecordProfile guard(ss); +// for (size_t i = 0; i < 100; ++i) { +// std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh); +// } +// } + +// std::string result = ss.str(); +// size_t count = 0; +// for (size_t pos = 0; (pos = result.find("tanh", pos)) != std::string::npos; +// count++, pos++) { +// } +// ASSERT_EQ((count, 200); +// } + +TEST(NoneSchemaMatchTest, Basic) { RegisterOperators reg({ Operator( "prim::test_none() -> int?", @@ -1348,40 +1439,6 @@ void testNoneSchemaMatch() { AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1); } -void testModuleDefine() { - Module m("m"); - m.register_parameter("foo", torch::ones({}), false); - m.define(R"( - def add_it(self, x, b : int = 4): - return self.foo + x + b - )"); - auto result = m.run_method("add_it", torch::ones({})); - AT_ASSERT(result.toTensor().item() == 6); -} - -void testModuleConversion() { - Module m("test"); - { - // test cuda to cpu for params and buffers - m.register_parameter("foo", torch::ones({}, at::kCUDA), false); - m.register_buffer("bar", torch::ones({}, at::kCUDA)); - - m.to(at::kCUDA); - m.to(at::kCPU); - AT_ASSERT(m.attr("foo").toTensor().device().is_cpu()); - AT_ASSERT(m.attr("bar").toTensor().device().is_cpu()); - } - { - // test cpu to cuda for params and buffers - m.register_parameter("foo", torch::ones({}), false); - m.register_buffer("bar", torch::ones({})); - - m.to(at::kCUDA); - AT_ASSERT(m.attr("foo").toTensor().device().is_cuda()); - AT_ASSERT(m.attr("bar").toTensor().device().is_cuda()); - } -} - static int testPassValue = 0; void fakePass(std::shared_ptr& g) { testPassValue++; @@ -1390,7 +1447,7 @@ void fakePass(std::shared_ptr& g) { RegisterPass p(fakePass); -void testPassManagement() { +TEST(PassManagementTest, Basic) { std::shared_ptr graph = std::make_shared(); parseIR( R"IR( @@ -1447,14 +1504,17 @@ size_t countNodes( return count; } -void testLoopPeeler() { - // peel all loops - auto true_pred = [](Node* n) { return true; }; - auto is_loop = [](Node* n) { return n->kind() == prim::Loop; }; +bool true_pred(Node* n) { + return true; +}; +bool is_loop(Node* n) { + return n->kind() == prim::Loop; +}; + +TEST(LoopPeelerTest, NoInductionVariableUse) { // do not use an induction variable explicitly - { - static const auto str_func_def = R"JIT( + static const auto str_func_def = R"JIT( def test_peel_n_times(): sum = 0 for i in range(10): @@ -1462,41 +1522,41 @@ void testLoopPeeler() { return sum )JIT"; - auto cu = compile(str_func_def); - auto& f = cu->get_function("test_peel_n_times"); - auto stack = createStack({}); - // peeling loop once - { - LoopsPeeler peeler(true_pred, 1); - auto copy = f.graph()->copy(); - peeler.run(copy); - int num_loops = - std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); - ASSERT_EQ(num_loops, 2); - Code code(copy, ""); - InterpreterState interpreter{code}; - interpreter.run(stack); - ASSERT_EQ(stack.back().toInt(), 20); - } + auto cu = compile(str_func_def); + auto& f = cu->get_function("test_peel_n_times"); + auto stack = createStack({}); + // peeling loop once + { + LoopsPeeler peeler(true_pred, 1); + auto copy = f.graph()->copy(); + peeler.run(copy); + int num_loops = + std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); + ASSERT_EQ(num_loops, 2); + Code code(copy, ""); + InterpreterState interpreter{code}; + interpreter.run(stack); + ASSERT_EQ(stack.back().toInt(), 20); + } - // test peeling more than one iteration - { - LoopsPeeler peeler(true_pred, 3); - auto copy = f.graph()->copy(); - peeler.run(copy); - int num_loops = - std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); - ASSERT_EQ(num_loops, 2); - Code code(copy, ""); - InterpreterState interpreter{code}; - interpreter.run(stack); - ASSERT_EQ(stack.back().toInt(), 20); - } + // test peeling more than one iteration + { + LoopsPeeler peeler(true_pred, 3); + auto copy = f.graph()->copy(); + peeler.run(copy); + int num_loops = + std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); + ASSERT_EQ(num_loops, 2); + Code code(copy, ""); + InterpreterState interpreter{code}; + interpreter.run(stack); + ASSERT_EQ(stack.back().toInt(), 20); } +} +TEST(LoopPeelerTest, YesInductionVariableUse) { // uses the induction variable - { - static const auto str_func_def = R"JIT( + static const auto str_func_def = R"JIT( def test_peel_n_times(): sum = 0 for i in range(10): @@ -1504,41 +1564,41 @@ void testLoopPeeler() { return sum )JIT"; - auto cu = compile(str_func_def); - auto& f = cu->get_function("test_peel_n_times"); - auto stack = createStack({}); - // peeling loop once - { - LoopsPeeler peeler(true_pred, 1); - auto copy = f.graph()->copy(); - peeler.run(copy); - int num_loops = - std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); - ASSERT_EQ(num_loops, 2); - Code code(copy, ""); - InterpreterState interpreter{code}; - interpreter.run(stack); - ASSERT_EQ(stack.back().toInt(), 45); - } + auto cu = compile(str_func_def); + auto& f = cu->get_function("test_peel_n_times"); + auto stack = createStack({}); + // peeling loop once + { + LoopsPeeler peeler(true_pred, 1); + auto copy = f.graph()->copy(); + peeler.run(copy); + int num_loops = + std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); + ASSERT_EQ(num_loops, 2); + Code code(copy, ""); + InterpreterState interpreter{code}; + interpreter.run(stack); + ASSERT_EQ(stack.back().toInt(), 45); + } - // test peeling more than one iteration - { - LoopsPeeler peeler(true_pred, 3); - auto copy = f.graph()->copy(); - peeler.run(copy); - int num_loops = - std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); - ASSERT_EQ(num_loops, 2); - Code code(copy, ""); - InterpreterState interpreter{code}; - interpreter.run(stack); - ASSERT_EQ(stack.back().toInt(), 45); - } + // test peeling more than one iteration + { + LoopsPeeler peeler(true_pred, 3); + auto copy = f.graph()->copy(); + peeler.run(copy); + int num_loops = + std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); + ASSERT_EQ(num_loops, 2); + Code code(copy, ""); + InterpreterState interpreter{code}; + interpreter.run(stack); + ASSERT_EQ(stack.back().toInt(), 45); } +} +TEST(LoopPeelerTest, LoopWithTerminationCondition) { // tests with explicit termination conditions - { - static const auto str_func_def = R"JIT( + static const auto str_func_def = R"JIT( def test_with_cond_times(): sum = 0 i = 0 @@ -1548,44 +1608,44 @@ void testLoopPeeler() { return sum )JIT"; - // the peel changes the termination condition to false - // so the original loop doesn't run - auto cu = compile(str_func_def); - auto& f = cu->get_function("test_with_cond_times"); - auto stack = createStack({}); - // peeling 5 iterations should update the termination - // condition to false - { - LoopsPeeler peeler(true_pred, 5); - auto copy = f.graph()->copy(); - peeler.run(copy); - int num_loops = - std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); - ASSERT_EQ(num_loops, 2); - Code code(copy, ""); - InterpreterState interpreter{code}; - interpreter.run(stack); - ASSERT_EQ(stack.back().toInt(), 3); - } - - // the termination condition remains true - { - LoopsPeeler peeler(true_pred, 1); - auto copy = f.graph()->copy(); - peeler.run(copy); - int num_loops = - std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); - ASSERT_EQ(num_loops, 2); - Code code(copy, ""); - InterpreterState interpreter{code}; - interpreter.run(stack); - ASSERT_EQ(stack.back().toInt(), 3); - } + // the peel changes the termination condition to false + // so the original loop doesn't run + auto cu = compile(str_func_def); + auto& f = cu->get_function("test_with_cond_times"); + auto stack = createStack({}); + // peeling 5 iterations should update the termination + // condition to false + { + LoopsPeeler peeler(true_pred, 5); + auto copy = f.graph()->copy(); + peeler.run(copy); + int num_loops = + std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); + ASSERT_EQ(num_loops, 2); + Code code(copy, ""); + InterpreterState interpreter{code}; + interpreter.run(stack); + ASSERT_EQ(stack.back().toInt(), 3); } - // tests simple nested loops + // the termination condition remains true { - static const auto str_func_def = R"JIT( + LoopsPeeler peeler(true_pred, 1); + auto copy = f.graph()->copy(); + peeler.run(copy); + int num_loops = + std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop); + ASSERT_EQ(num_loops, 2); + Code code(copy, ""); + InterpreterState interpreter{code}; + interpreter.run(stack); + ASSERT_EQ(stack.back().toInt(), 3); + } +} + +// tests simple nested loops +TEST(LoopPeelerTest, SimpleNestedLoops) { + static const auto str_func_def = R"JIT( def test_nested_loops(): sum = 0 i = 0 @@ -1595,35 +1655,35 @@ void testLoopPeeler() { return sum )JIT"; - auto cu = compile(str_func_def); - auto& f = cu->get_function("test_nested_loops"); - auto stack = createStack({}); + auto cu = compile(str_func_def); + auto& f = cu->get_function("test_nested_loops"); + auto stack = createStack({}); - { - LoopsPeeler peeler(true_pred, 1); - auto copy = f.graph()->copy(); - peeler.run(copy); - ASSERT_EQ(countNodes(copy, is_loop), 5); - Code code(copy, ""); - InterpreterState interpreter{code}; - interpreter.run(stack); - ASSERT_EQ(stack.back().toInt(), 900); - } - - { - LoopsPeeler peeler(true_pred, 5); - auto copy = f.graph()->copy(); - peeler.run(copy); - ASSERT_EQ(countNodes(copy, is_loop), 5); - Code code(copy, ""); - InterpreterState interpreter{code}; - interpreter.run(stack); - ASSERT_EQ(stack.back().toInt(), 900); - } + { + LoopsPeeler peeler(true_pred, 1); + auto copy = f.graph()->copy(); + peeler.run(copy); + ASSERT_EQ(countNodes(copy, is_loop), 5); + Code code(copy, ""); + InterpreterState interpreter{code}; + interpreter.run(stack); + ASSERT_EQ(stack.back().toInt(), 900); } { - static const auto str_func_def = R"JIT( + LoopsPeeler peeler(true_pred, 5); + auto copy = f.graph()->copy(); + peeler.run(copy); + ASSERT_EQ(countNodes(copy, is_loop), 5); + Code code(copy, ""); + InterpreterState interpreter{code}; + interpreter.run(stack); + ASSERT_EQ(stack.back().toInt(), 900); + } +} + +TEST(LoopPeelerTest, SimpleNestedLoops2) { + static const auto str_func_def = R"JIT( def test_nested_loops(): sum = 0 i = 0 @@ -1635,34 +1695,33 @@ void testLoopPeeler() { return sum )JIT"; - auto cu = compile(str_func_def); - auto& f = cu->get_function("test_nested_loops"); - auto stack = createStack({}); - { - LoopsPeeler peeler(true_pred, 1); - auto copy = f.graph()->copy(); - peeler.run(copy); - ASSERT_EQ(countNodes(copy, is_loop), 5); - Code code(copy, ""); - InterpreterState interpreter{code}; - interpreter.run(stack); - ASSERT_EQ(stack.back().toInt(), 3); - } + auto cu = compile(str_func_def); + auto& f = cu->get_function("test_nested_loops"); + auto stack = createStack({}); + { + LoopsPeeler peeler(true_pred, 1); + auto copy = f.graph()->copy(); + peeler.run(copy); + ASSERT_EQ(countNodes(copy, is_loop), 5); + Code code(copy, ""); + InterpreterState interpreter{code}; + interpreter.run(stack); + ASSERT_EQ(stack.back().toInt(), 3); + } - { - LoopsPeeler peeler(true_pred, 5); - auto copy = f.graph()->copy(); - peeler.run(copy); - ASSERT_EQ(countNodes(copy, is_loop), 5); - Code code(copy, ""); - InterpreterState interpreter{code}; - interpreter.run(stack); - ASSERT_EQ(stack.back().toInt(), 3); - } + { + LoopsPeeler peeler(true_pred, 5); + auto copy = f.graph()->copy(); + peeler.run(copy); + ASSERT_EQ(countNodes(copy, is_loop), 5); + Code code(copy, ""); + InterpreterState interpreter{code}; + interpreter.run(stack); + ASSERT_EQ(stack.back().toInt(), 3); } } -void testInsertAndEliminateRedundantGuards() { +TEST(InsertAndEliminateRedundantGuardsTest, Basic) { static const auto basic_example = R"JIT( def basic(x, y): a = x + y @@ -1692,7 +1751,7 @@ void testInsertAndEliminateRedundantGuards() { }); ASSERT_NE(guard, nodes.end()); ASSERT_EQ( - guard->input()->type()->expect()->sizes().size(), + guard->input()->type()->expectRef().sizes().size(), c10::nullopt); checkShape(*guard, {2, 3}, false); auto is_guard = [](Node* n) { return n->kind() == prim::Guard; }; @@ -1705,7 +1764,7 @@ void testInsertAndEliminateRedundantGuards() { ASSERT_EQ(num_guards, 2); } -void testInsertBailOuts() { +TEST(InsertBailOutsTest, Basic) { static const auto basic_example = R"JIT( def basic_loop(x, y): @@ -1754,7 +1813,7 @@ void testInsertBailOuts() { } } -void testProfiler() { +TEST(ProfilerTest, Basic) { constexpr int batch_size = 4; constexpr int input_size = 256; @@ -1780,8 +1839,8 @@ void testProfiler() { is.run(stack); // profiled types are stored as attributes and show up in the dump, e.g. - // Tensor = prim::profile[profiled_type=Double(4:256, 256:1, requires_grad=0, - // device=cpu) + // Tensor = prim::profile[profiled_type=Double(4, 256, strides=[256, 1], + // requires_grad=0, device=cpu) testing::FileCheck() .check("Tensor = prim::profile[profiled_type") ->check_same("256") @@ -1804,7 +1863,7 @@ void testProfiler() { checkShape(tanh_n->inputs().at(0)->node()->ty(attr::profiled_type), eltwise); } -void testCallStack() { +TEST(CallStackTest, Basic) { const auto text = R"( def ham(x): return x/7 @@ -1835,7 +1894,7 @@ def foo(x): ASSERT_TRUE(n->callstack()); auto callstack_vector = (*n->callstack())->vec(); ASSERT_EQ(callstack_vector.size(), 1); - ASSERT_EQ(callstack_vector[0].first, &cu->get_function("bar")); + ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("bar")); break; } case 7: { @@ -1845,8 +1904,8 @@ def foo(x): ASSERT_TRUE(n->callstack()); auto callstack_vector = (*n->callstack())->vec(); ASSERT_EQ(callstack_vector.size(), 2); - ASSERT_EQ(callstack_vector[0].first, &cu->get_function("baz")); - ASSERT_EQ(callstack_vector[1].first, &cu->get_function("ham")); + ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("baz")); + ASSERT_EQ(std::get<0>(callstack_vector[1]), &cu->get_function("ham")); break; } case 11: { @@ -1875,12 +1934,12 @@ def foo(x): ASSERT_TRUE(n->callstack()); auto callstack_vector = (*n->callstack())->vec(); ASSERT_EQ(callstack_vector.size(), 1); - ASSERT_EQ(callstack_vector[0].first, &cu->get_function("ham")); + ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("ham")); } } } -void testCallStackCaching() { +TEST(CallStackTest, Caching) { const auto text = R"( def a(x): @@ -1923,7 +1982,7 @@ def c(x): ASSERT_TRUE(callstack_objects.at("a1") == callstack_objects.at("a2")); } -void testAutogradSymbols() { +TEST(AutogradSymbolsTest, Basic) { Symbol sym = Symbol::fromQualString("aten::test_symbol"); Graph graph; auto node = graph.create(sym); @@ -1942,7 +2001,7 @@ void testAutogradSymbols() { TORCH_CHECK(!canRunWithAutograd(node)); } -void testDefaultArgTypeHinting() { +TEST(DefaultArgTypeHintingTest, Basic) { const auto text_non_hinted = R"( def a(x, y=1): @@ -1968,184 +2027,182 @@ def a(x, y:int=1): auto cu = compile(text_hinted); } -void testFutures() { - // Basic set case. - { - auto f1 = c10::make_intrusive(IntType::get()); - ASSERT_FALSE(f1->completed()); - ASSERT_FALSE(f1->hasValue()); - int32_t sat1 = 0; - int32_t sat2 = 0; - f1->addCallback([&]() { ++sat1; }); - f1->markCompleted(43); - ASSERT_TRUE(f1->completed()); - ASSERT_TRUE(f1->hasValue()); - ASSERT_FALSE(f1->hasError()); - ASSERT_EQ(sat1, 1); - ASSERT_EQ(f1->constValue().toInt(), 43); - ASSERT_EQ(f1->value().toInt(), 43); - f1->addCallback([&]() { ++sat2; }); - ASSERT_EQ(sat1, 1); - ASSERT_EQ(sat2, 1); - } +// Basic set case. +TEST(FuturesTest, Basic) { + auto f1 = c10::make_intrusive(IntType::get()); + ASSERT_FALSE(f1->completed()); + ASSERT_FALSE(f1->hasValue()); + int32_t sat1 = 0; + int32_t sat2 = 0; + f1->addCallback([&]() { ++sat1; }); + f1->markCompleted(43); + ASSERT_TRUE(f1->completed()); + ASSERT_TRUE(f1->hasValue()); + ASSERT_FALSE(f1->hasError()); + ASSERT_EQ(sat1, 1); + ASSERT_EQ(f1->constValue().toInt(), 43); + ASSERT_EQ(f1->value().toInt(), 43); + f1->addCallback([&]() { ++sat2; }); + ASSERT_EQ(sat1, 1); + ASSERT_EQ(sat2, 1); +} - // Basic error cases. - { - auto f1 = c10::make_intrusive(IntType::get()); - int sat1 = 0; - int sat2 = 0; - f1->addCallback([&]() { ++sat1; }); - f1->setError( - std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed"))); - ASSERT_EQ(sat1, 1); - ASSERT_TRUE(f1->completed()); - ASSERT_TRUE(f1->hasError()); - ASSERT_FALSE(f1->hasValue()); - try { - (void)f1->value(); - ASSERT_TRUE(false); // Supposed to throw. - } catch (const std::exception& e) { - ASSERT_TRUE(strcmp(e.what(), "Failed") == 0); - } - f1->addCallback([&]() { ++sat2; }); - ASSERT_EQ(sat1, 1); - ASSERT_EQ(sat2, 1); - f1->setErrorIfNeeded( - std::make_exception_ptr(c10::ivalue::Future::FutureError("Dup"))); - ASSERT_TRUE(strcmp(f1->tryRetrieveErrorMessage().c_str(), "Failed") == 0); - ASSERT_EQ(sat1, 1); - ASSERT_EQ(sat2, 1); +// Basic error cases. +TEST(FuturesTest, Error) { + auto f1 = c10::make_intrusive(IntType::get()); + int sat1 = 0; + int sat2 = 0; + f1->addCallback([&]() { ++sat1; }); + f1->setError( + std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed"))); + ASSERT_EQ(sat1, 1); + ASSERT_TRUE(f1->completed()); + ASSERT_TRUE(f1->hasError()); + ASSERT_FALSE(f1->hasValue()); + try { + (void)f1->value(); + ASSERT_TRUE(false); // Supposed to throw. + } catch (const std::exception& e) { + ASSERT_TRUE(strcmp(e.what(), "Failed") == 0); } + f1->addCallback([&]() { ++sat2; }); + ASSERT_EQ(sat1, 1); + ASSERT_EQ(sat2, 1); + f1->setErrorIfNeeded( + std::make_exception_ptr(c10::ivalue::Future::FutureError("Dup"))); + ASSERT_TRUE(strcmp(f1->tryRetrieveErrorMessage().c_str(), "Failed") == 0); + ASSERT_EQ(sat1, 1); + ASSERT_EQ(sat2, 1); +} - // then - { - auto f1 = c10::make_intrusive(IntType::get()); - auto f2 = f1->then( - [f1]() -> IValue { return f1->constValue().toInt() + 1; }, - IntType::get()); - auto f3 = f2->then( - [f2]() -> IValue { return f2->constValue().toInt() * 3; }, - IntType::get()); - bool done = false; - f3->addCallback([f3, &done]() { - ASSERT_EQ(f3->constValue().toInt(), (42 + 1) * 3); - done = true; - }); - ASSERT_FALSE(done); - f1->markCompleted(42); - ASSERT_TRUE(done); - } +// then +TEST(FuturesTest, Then) { + auto f1 = c10::make_intrusive(IntType::get()); + auto f2 = f1->then( + [f1]() -> IValue { return f1->constValue().toInt() + 1; }, + IntType::get()); + auto f3 = f2->then( + [f2]() -> IValue { return f2->constValue().toInt() * 3; }, + IntType::get()); + bool done = false; + f3->addCallback([f3, &done]() { + ASSERT_EQ(f3->constValue().toInt(), (42 + 1) * 3); + done = true; + }); + ASSERT_FALSE(done); + f1->markCompleted(42); + ASSERT_TRUE(done); +} - // collectAll() - { - auto s1 = c10::make_intrusive(IntType::get()); - auto s2 = c10::make_intrusive(IntType::get()); - auto s3 = c10::make_intrusive(IntType::get()); - - // Empty case - c10::List> futures( - FutureType::create(IntType::get())); - auto c1 = collectAll(futures); - ASSERT_TRUE(c1->completed()); - ASSERT_EQ(c1->value().toList().size(), 0); - ASSERT_TRUE( - *(c1->value().toList().elementType()) == - *FutureType::create(IntType::get())); - - // 1-element, initially not completed. - futures.push_back(s1); - auto c2 = collectAll(futures); - ASSERT_FALSE(c2->completed()); - s1->markCompleted(5); - ASSERT_TRUE(c2->completed()); - ASSERT_EQ(c2->value().toList().size(), 1); - ASSERT_TRUE( - *(c2->value().toList().elementType()) == - *FutureType::create(IntType::get())); - ASSERT_EQ(c2->value().toList().get(0).toFuture()->value().toInt(), 5); - - // 1-element, already completed - auto c3 = collectAll(futures); - ASSERT_TRUE(c3->completed()); - ASSERT_EQ(c3->value().toList().size(), 1); - ASSERT_EQ(c3->value().toList().get(0).toFuture()->value().toInt(), 5); - - // 3 elements. - futures.push_back(s2); - futures.push_back(s3); - auto c4 = collectAll(futures); - ASSERT_FALSE(c4->completed()); - s3->markCompleted(7); - ASSERT_FALSE(c4->completed()); - s2->markCompleted(6); - ASSERT_TRUE(c4->completed()); - ASSERT_EQ(c4->value().toList().size(), 3); - ASSERT_EQ(c4->value().toList().get(0).toFuture()->value().toInt(), 5); - ASSERT_EQ(c4->value().toList().get(1).toFuture()->value().toInt(), 6); - ASSERT_EQ(c4->value().toList().get(2).toFuture()->value().toInt(), 7); - ASSERT_TRUE( - *(c4->value().toList().elementType()) == - *FutureType::create(IntType::get())); - - // Handle exception in the list. - auto s4 = c10::make_intrusive(IntType::get()); - futures.push_back(s4); - auto c5 = collectAll(futures); - ASSERT_FALSE(c5->completed()); - s4->setError( - std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed"))); - ASSERT_TRUE(c5->completed()); - ASSERT_EQ(c5->value().toList().size(), 4); - try { - (void)c5->value().toList().get(3).toFuture()->value(); - ASSERT_TRUE(false); // supposed to throw - } catch (const std::exception& e) { - ASSERT_EQ(std::string(e.what()), "Failed"); - } +// collectAll() +TEST(FuturesTest, CollectAll) { + auto s1 = c10::make_intrusive(IntType::get()); + auto s2 = c10::make_intrusive(IntType::get()); + auto s3 = c10::make_intrusive(IntType::get()); + + // Empty case + c10::List> futures( + FutureType::create(IntType::get())); + auto c1 = collectAll(futures); + ASSERT_TRUE(c1->completed()); + ASSERT_EQ(c1->value().toList().size(), 0); + ASSERT_TRUE( + *(c1->value().toList().elementType()) == + *FutureType::create(IntType::get())); + + // 1-element, initially not completed. + futures.push_back(s1); + auto c2 = collectAll(futures); + ASSERT_FALSE(c2->completed()); + s1->markCompleted(5); + ASSERT_TRUE(c2->completed()); + ASSERT_EQ(c2->value().toList().size(), 1); + ASSERT_TRUE( + *(c2->value().toList().elementType()) == + *FutureType::create(IntType::get())); + ASSERT_EQ(c2->value().toList().get(0).toFuture()->value().toInt(), 5); + + // 1-element, already completed + auto c3 = collectAll(futures); + ASSERT_TRUE(c3->completed()); + ASSERT_EQ(c3->value().toList().size(), 1); + ASSERT_EQ(c3->value().toList().get(0).toFuture()->value().toInt(), 5); + + // 3 elements. + futures.push_back(s2); + futures.push_back(s3); + auto c4 = collectAll(futures); + ASSERT_FALSE(c4->completed()); + s3->markCompleted(7); + ASSERT_FALSE(c4->completed()); + s2->markCompleted(6); + ASSERT_TRUE(c4->completed()); + ASSERT_EQ(c4->value().toList().size(), 3); + ASSERT_EQ(c4->value().toList().get(0).toFuture()->value().toInt(), 5); + ASSERT_EQ(c4->value().toList().get(1).toFuture()->value().toInt(), 6); + ASSERT_EQ(c4->value().toList().get(2).toFuture()->value().toInt(), 7); + ASSERT_TRUE( + *(c4->value().toList().elementType()) == + *FutureType::create(IntType::get())); + + // Handle exception in the list. + auto s4 = c10::make_intrusive(IntType::get()); + futures.push_back(s4); + auto c5 = collectAll(futures); + ASSERT_FALSE(c5->completed()); + s4->setError( + std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed"))); + ASSERT_TRUE(c5->completed()); + ASSERT_EQ(c5->value().toList().size(), 4); + try { + (void)c5->value().toList().get(3).toFuture()->value(); + ASSERT_TRUE(false); // supposed to throw + } catch (const std::exception& e) { + ASSERT_EQ(std::string(e.what()), "Failed"); } +} - // collectAny() - { - auto s1 = c10::make_intrusive(IntType::get()); +// collectAny() +TEST(FuturesTest, CollectAny) { + auto s1 = c10::make_intrusive(IntType::get()); - // Empty case - c10::List> futures( - FutureType::create(IntType::get())); - auto c1 = collectAny(futures); - ASSERT_TRUE(c1->completed()); - - // 1 element, not yet satisfied - futures.push_back(s1); - auto c2 = collectAny(futures); - ASSERT_FALSE(c2->completed()); - s1->markCompleted(5); - ASSERT_TRUE(c2->completed()); - ASSERT_TRUE(c2->value().isInt()); - ASSERT_EQ(c2->value().toInt(), 5); - - // 1 element already satisfied. - auto c3 = collectAny(futures); - ASSERT_TRUE(c3->completed()); - ASSERT_TRUE(c3->value().isInt()); - ASSERT_EQ(c3->value().toInt(), 5); - - // 2 elements - futures.clear(); - auto s2 = c10::make_intrusive(IntType::get()); - auto s3 = c10::make_intrusive(IntType::get()); - futures.push_back(s2); - futures.push_back(s3); - auto c4 = collectAny(futures); - ASSERT_FALSE(c4->completed()); - s3->markCompleted(7); - ASSERT_TRUE(c4->completed()); - ASSERT_EQ(c4->value().toInt(), 7); - s2->markCompleted(1); - ASSERT_EQ(c4->value().toInt(), 7); - } + // Empty case + c10::List> futures( + FutureType::create(IntType::get())); + auto c1 = collectAny(futures); + ASSERT_TRUE(c1->completed()); + + // 1 element, not yet satisfied + futures.push_back(s1); + auto c2 = collectAny(futures); + ASSERT_FALSE(c2->completed()); + s1->markCompleted(5); + ASSERT_TRUE(c2->completed()); + ASSERT_TRUE(c2->value().isInt()); + ASSERT_EQ(c2->value().toInt(), 5); + + // 1 element already satisfied. + auto c3 = collectAny(futures); + ASSERT_TRUE(c3->completed()); + ASSERT_TRUE(c3->value().isInt()); + ASSERT_EQ(c3->value().toInt(), 5); + + // 2 elements + futures.clear(); + auto s2 = c10::make_intrusive(IntType::get()); + auto s3 = c10::make_intrusive(IntType::get()); + futures.push_back(s2); + futures.push_back(s3); + auto c4 = collectAny(futures); + ASSERT_FALSE(c4->completed()); + s3->markCompleted(7); + ASSERT_TRUE(c4->completed()); + ASSERT_EQ(c4->value().toInt(), 7); + s2->markCompleted(1); + ASSERT_EQ(c4->value().toInt(), 7); } -void testTLSFutureCallbacks() { +TEST(TLSFutureCallbacksTest, Basic) { // cb that verifies the profiler is enabled auto profilerEnabledCb = []() { ASSERT_TRUE(torch::autograd::profiler::profilerEnabled()); @@ -2153,7 +2210,7 @@ void testTLSFutureCallbacks() { // test running callbacks with propagation of TLS state. { // Enable the profiler in this thread - torch::autograd::profiler::enableProfiler( + torch::autograd::profiler::enableProfilerLegacy( torch::autograd::profiler::ProfilerConfig( torch::autograd::profiler::ProfilerState::CPU, false, false)); auto s1 = c10::make_intrusive(IntType::get()); @@ -2162,12 +2219,12 @@ void testTLSFutureCallbacks() { // Since we join here, we can ensure that all callbacks corresponding to // markCompleted() have finished. t.join(); - torch::autograd::profiler::disableProfiler(); + torch::autograd::profiler::disableProfilerLegacy(); } // then() with TLS State { // Enable the profiler in this thread - torch::autograd::profiler::enableProfiler( + torch::autograd::profiler::enableProfilerLegacy( torch::autograd::profiler::ProfilerConfig( torch::autograd::profiler::ProfilerState::CPU, false, false)); auto s1 = c10::make_intrusive(IntType::get()); @@ -2180,31 +2237,32 @@ void testTLSFutureCallbacks() { std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); }); t.join(); s2->wait(); - torch::autograd::profiler::disableProfiler(); + torch::autograd::profiler::disableProfilerLegacy(); } } -void testProfilerDisableInCallback() { +TEST(ProfilerDisableInCallbackTest, Basic) { // cb that verifies the profiler is enabled auto profilerEnabledCb = []() { ASSERT_TRUE(torch::autograd::profiler::profilerEnabled()); }; - torch::autograd::profiler::enableProfiler( + torch::autograd::profiler::enableProfilerLegacy( torch::autograd::profiler::ProfilerConfig( torch::autograd::profiler::ProfilerState::CPU, false, false)); auto s1 = c10::make_intrusive(IntType::get()); - s1->addCallback(wrapPropagateTLSState([&profilerEnabledCb] { + auto verifyProfilerCb = wrapPropagateTLSState([&profilerEnabledCb] { // Ensure the profiler is still enabled in this thread. profilerEnabledCb(); auto t1 = torch::ones({2, 2}); auto t2 = torch::ones({2, 2}); torch::add(t1, t2); // Don't cleanup TLSState, and just consolidate. + auto opts = torch::autograd::profiler::ProfilerDisableOptions(false, true); auto thread_event_lists = - torch::autograd::profiler::disableProfiler(false, true); + torch::autograd::profiler::disableProfilerLegacy(std::move(opts)); // Ensure that the events from this thread are still profiled and we obtain // the expected in events in our consolidated list when calling - // disableProfiler(). + // disableProfilerLegacy(). bool found_ones = false; bool found_add = false; for (const auto& li : thread_event_lists) { @@ -2215,17 +2273,35 @@ void testProfilerDisableInCallback() { found_ones = true; } } + if (found_add && found_ones) { + break; + } } ASSERT_TRUE(found_ones); ASSERT_TRUE(found_add); - })); + }); + + s1->addCallback(verifyProfilerCb); // Disable the profiler, but do not consolidate results in the main thread. - torch::autograd::profiler::disableProfiler(true, false); + auto opts = torch::autograd::profiler::ProfilerDisableOptions(true, false); + torch::autograd::profiler::disableProfilerLegacy(std::move(opts)); std::thread t([s1 = std::move(s1)]() { s1->markCompleted(at::IValue(1)); }); t.join(); + + // Similar to above test, but verifies correctness in the case where + // continuation runs on the main thread. + torch::autograd::profiler::enableProfilerLegacy( + torch::autograd::profiler::ProfilerConfig( + torch::autograd::profiler::ProfilerState::CPU, false, false)); + s1 = c10::make_intrusive(IntType::get()); + s1->addCallback(verifyProfilerCb); + // Runs callback inline + s1->markCompleted(at::IValue(1)); + opts = torch::autograd::profiler::ProfilerDisableOptions(true, false); + torch::autograd::profiler::disableProfilerLegacy(std::move(opts)); } -void testIValueKWargs() { +TEST(IValueKWargsTest, Basic) { const auto text = R"( def foo(a : int, b : int, c : int = 4): return a + 2*b + 3*c @@ -2235,5 +2311,70 @@ void testIValueKWargs() { ASSERT_EQ(result.toInt(), 19); } +TEST(ComputeFlopsTest, Basic) { + uint64_t flops = 0; + + // Test unknown operator + std::unordered_map extra_args; + flops = computeFlops(std::string("aten::unknown"), extra_args); + ASSERT_EQ(flops, 0); + + // Test aten::conv2d + extra_args.clear(); + std::vector input_sizes = {4, 5, 6, 7}; + std::vector weight_sizes = {3, 5, 2, 1}; + extra_args["input_size"] = at::IValue(at::IntArrayRef(input_sizes)); + extra_args["weight_size"] = at::IValue(at::IntArrayRef(weight_sizes)); + extra_args["stride"] = 1; + extra_args["dilation"] = 0; + extra_args["groups"] = 1; + flops = computeFlops(std::string("aten::conv2d"), extra_args); + ASSERT_EQ(flops, 10080); + + // Test aten::conv2d fail + extra_args.clear(); + input_sizes = {4, 5, 6, 7}; + weight_sizes = {4, 5, 6}; + extra_args["input_size"] = at::IValue(at::IntArrayRef(input_sizes)); + extra_args["weight_size"] = at::IValue(at::IntArrayRef(weight_sizes)); + flops = computeFlops(std::string("aten::conv2d"), extra_args); + ASSERT_EQ(flops, 0); + + // Test aten::conv2d fail 2 + extra_args.clear(); + input_sizes = {4, 5, 6, 7}; + extra_args["input_size"] = at::IValue(at::IntArrayRef(input_sizes)); + flops = computeFlops(std::string("aten::conv2d"), extra_args); + ASSERT_EQ(flops, 0); + + // Test aten::mm + extra_args.clear(); + std::vector mat1_sizes = {3, 4, 5, 6}; + std::vector mat2_sizes = {6, 5, 4, 3}; + extra_args["mat1_size"] = at::IValue(at::IntArrayRef(mat1_sizes)); + extra_args["mat2_size"] = at::IValue(at::IntArrayRef(mat2_sizes)); + flops = computeFlops(std::string("aten::mm"), extra_args); + ASSERT_EQ(flops, 21600); + + // Test mm out of range + extra_args.clear(); + flops = computeFlops(std::string("aten::mm"), extra_args); + ASSERT_EQ(flops, 0); + + // Test aten::add.Tensor + extra_args.clear(); + std::vector mat_sizes = {3, 4, 5, 6}; + extra_args["mat_size"] = at::IValue(at::IntArrayRef(mat_sizes)); + flops = computeFlops(std::string("aten::add.Tensor"), extra_args); + ASSERT_EQ(flops, 360); + + // Test aten::mul.Tensor + extra_args.clear(); + mat_sizes = {3, 4, 5, 6}; + extra_args["mat_size"] = at::IValue(at::IntArrayRef(mat_sizes)); + flops = computeFlops(std::string("aten::mul.Tensor"), extra_args); + ASSERT_EQ(flops, 360); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_mobile_type_parser.cpp b/test/cpp/jit/test_mobile_type_parser.cpp index 989d16794bd2f..7e24e5dc65bc1 100644 --- a/test/cpp/jit/test_mobile_type_parser.cpp +++ b/test/cpp/jit/test_mobile_type_parser.cpp @@ -1,5 +1,6 @@ -#include "test/cpp/jit/test_base.h" -//#include +#include + +#include namespace c10 { // std::string serializeType(const Type &t); @@ -8,50 +9,74 @@ TypePtr parseType(const std::string& pythonStr); namespace torch { namespace jit { -void testMobileTypeParser() { +TEST(MobileTypeParserTest, Empty) { std::string empty_ps(""); ASSERT_ANY_THROW(c10::parseType(empty_ps)); +} +TEST(MobileTypeParserTest, RoundTripAnnotationStr) { std::string int_ps("int"); auto int_tp = c10::parseType(int_ps); std::string int_tps = int_tp->annotation_str(); ASSERT_EQ(int_ps, int_tps); +} +TEST(MobileTypeParserTest, NestedContainersAnnotationStr) { std::string tuple_ps( "Tuple[str, Optional[float], Dict[str, List[Tensor]], int]"); auto tuple_tp = c10::parseType(tuple_ps); std::string tuple_tps = tuple_tp->annotation_str(); ASSERT_EQ(tuple_ps, tuple_tps); +} +TEST(MobileTypeParserTest, NestedContainersAnnotationStrWithSpaces) { + std::string tuple_ps( + "Tuple[str, Optional[float], Dict[str, List[Tensor]], int]"); std::string tuple_space_ps( "Tuple[ str, Optional[float], Dict[str, List[Tensor ]] , int]"); auto tuple_space_tp = c10::parseType(tuple_space_ps); // tuple_space_tps should not have weird white spaces std::string tuple_space_tps = tuple_space_tp->annotation_str(); ASSERT_EQ(tuple_ps, tuple_space_tps); +} +TEST(MobileTypeParserTest, TypoRaises) { std::string typo_token("List[tensor]"); ASSERT_ANY_THROW(c10::parseType(typo_token)); +} +TEST(MobileTypeParserTest, MismatchBracketRaises) { std::string mismatch1("List[Tensor"); ASSERT_ANY_THROW(c10::parseType(mismatch1)); +} +TEST(MobileTypeParserTest, MismatchBracketRaises2) { std::string mismatch2("List[[Tensor]"); ASSERT_ANY_THROW(c10::parseType(mismatch2)); +} +TEST(MobileTypeParserTest, DictWithoutValueRaises) { std::string mismatch3("Dict[Tensor]"); ASSERT_ANY_THROW(c10::parseType(mismatch3)); +} +TEST(MobileTypeParserTest, ListArgCountMismatchRaises) { // arg count mismatch std::string mismatch4("List[int, str]"); ASSERT_ANY_THROW(c10::parseType(mismatch4)); +} +TEST(MobileTypeParserTest, DictArgCountMismatchRaises) { std::string trailing_commm("Dict[str,]"); ASSERT_ANY_THROW(c10::parseType(trailing_commm)); +} +TEST(MobileTypeParserTest, ValidTypeWithExtraStuffRaises) { std::string extra_stuff("int int"); ASSERT_ANY_THROW(c10::parseType(extra_stuff)); +} +TEST(MobileTypeParserTest, NonIdentifierRaises) { std::string non_id("(int)"); ASSERT_ANY_THROW(c10::parseType(non_id)); } diff --git a/test/cpp/jit/test_module_api.cpp b/test/cpp/jit/test_module_api.cpp index 386addd9fbeca..c77d89af5afaa 100644 --- a/test/cpp/jit/test_module_api.cpp +++ b/test/cpp/jit/test_module_api.cpp @@ -1,4 +1,5 @@ -#include +#include + #include #include @@ -42,7 +43,45 @@ static void import_libs( si.loadType(QualifiedName(class_name)); } -void testModuleClone() { +TEST(ModuleAPITest, MethodRunAsync) { + // Module m("m"); + // m.define(R"( + // def forward(self): + // r1 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100)) + // r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100)) + // return r1.wait() + r2.wait() + // )"); + std::string filePath(__FILE__); + auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); + // borrow model file from TEST(GraphExecutorTest, runAsync_executor) + testModelFile.append("test_interpreter_async.pt"); + auto m = load(testModelFile); + + auto counter = 0; + std::mutex mtx; + + auto launcher = [&](std::function f) { + mtx.lock(); + ++counter; + mtx.unlock(); + at::launch(move(f)); + }; + + auto method = m.get_method("forward"); + + std::vector stack; + auto kwargs = std::unordered_map(); + auto future = method.run_async(stack, kwargs, launcher); + + future->wait(); + + // expect 2 forks and 2 wait callbacks being excuted on provided taskLauncher + // but ivalue::Future would be marked completed and release wait before + // finishing all callbacks + ASSERT_GE(counter, 2); +} + +TEST(ModuleAPITest, Clone) { auto cu = std::make_shared(); // creating child module auto child = ClassType::create("child", cu, true); @@ -71,7 +110,7 @@ void testModuleClone() { ASSERT_EQ(Module(p2.attr("c2").toObject()).attr(attr_name).toInt(), 3); } -void testModuleCloneWithModuleInterface() { +TEST(ModuleAPITest, CloneWithModuleInterface) { auto cu = std::make_shared(); // define a initial module with two submods share same interface @@ -115,7 +154,7 @@ void testModuleCloneWithModuleInterface() { ASSERT_NE(clonedMod.type(), parentMod.type()); } -void testModuleCopy() { +TEST(ModuleAPITest, Copy) { auto cu = std::make_shared(); auto cls = ClassType::create("foo.bar", cu, true); auto attr_name = "attr"; @@ -144,7 +183,7 @@ void testModuleCopy() { ASSERT_EQ(m3.attr(attr_name).toInt(), 3); } -void testModuleDeepcopy() { +TEST(ModuleAPITest, DeepCopy) { auto cu = std::make_shared(); auto cls = ClassType::create("foo.bar", cu, true); auto str_attr = "str_attr"; @@ -203,7 +242,7 @@ void testModuleDeepcopy() { ASSERT_TRUE(t1.equal(t3)); } -void testModuleDeepcopyString() { +TEST(ModuleAPITest, DeepCopyString) { auto cu = std::make_shared(); auto cls = ClassType::create("foo.bar", cu, true); auto attr1 = "attr1"; @@ -219,7 +258,7 @@ void testModuleDeepcopyString() { ASSERT_EQ(copied.attr(attr1).toString()->string(), original_str); } -void testModuleDeepcopyAliasing() { +TEST(ModuleAPITest, DeepCopyPreservesAliasing) { // check deepcopy preserves aliasing auto cu = std::make_shared(); auto cls = ClassType::create("foo.bar", cu, true); @@ -256,7 +295,7 @@ void testModuleDeepcopyAliasing() { ASSERT_TRUE(copied_attr3.isAliasOf(copied_attr4)); } -void testModuleConstant() { +TEST(ModuleAPITest, Constants) { auto cu = std::make_shared(); auto cls = ClassType::create("foo.bar", cu, true); auto attr_name = "attr"; @@ -272,7 +311,7 @@ void testModuleConstant() { ASSERT_EQ(m.attr(const_name).toInt(), 3); } -void testModuleParameter() { +TEST(ModuleAPITest, Parameters) { auto cu = std::make_shared(); auto cls = ClassType::create("foo.bar", cu, true); Module m(cu, cls); @@ -291,5 +330,39 @@ void testModuleParameter() { ASSERT_TRUE(m.hasattr("none_param2")); } +TEST(ModuleAPITest, Define) { + Module m("m"); + m.register_parameter("foo", torch::ones({}), false); + m.define(R"( + def add_it(self, x, b : int = 4): + return self.foo + x + b + )"); + auto result = m.run_method("add_it", torch::ones({})); + AT_ASSERT(result.toTensor().item() == 6); +} + +TEST(ModuleAPITest, To_CUDA) { + Module m("test"); + { + // test cuda to cpu for params and buffers + m.register_parameter("foo", torch::ones({}, at::kCUDA), false); + m.register_buffer("bar", torch::ones({}, at::kCUDA)); + + m.to(at::kCUDA); + m.to(at::kCPU); + AT_ASSERT(m.attr("foo").toTensor().device().is_cpu()); + AT_ASSERT(m.attr("bar").toTensor().device().is_cpu()); + } + { + // test cpu to cuda for params and buffers + m.register_parameter("foo", torch::ones({}), false); + m.register_buffer("bar", torch::ones({})); + + m.to(at::kCUDA); + AT_ASSERT(m.attr("foo").toTensor().device().is_cuda()); + AT_ASSERT(m.attr("bar").toTensor().device().is_cuda()); + } +} + } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_peephole_optimize.cpp b/test/cpp/jit/test_peephole_optimize.cpp index 5382d556613d5..9985faa6e9bd3 100644 --- a/test/cpp/jit/test_peephole_optimize.cpp +++ b/test/cpp/jit/test_peephole_optimize.cpp @@ -1,4 +1,5 @@ -#include +#include + #include #include @@ -8,47 +9,48 @@ namespace torch { namespace jit { -void testPeepholeOptimize() { - // test is / is not none optimization - { - auto graph = std::make_shared(); - parseIR( - R"IR( +TEST(PeepholeOptimizeTest, IsAndIsNot) +// test is / is not none optimization +{ + auto graph = std::make_shared(); + parseIR( + R"IR( graph(%0 : int): %1 : None = prim::Constant() %2 : bool = aten::__is__(%0, %1) %3 : bool = aten::__isnot__(%0, %1) return (%2, %3) )IR", - graph.get()); - PeepholeOptimize(graph); - testing::FileCheck() - .check_not("aten::__is__") - ->check_not("aten::__isnot__") - ->run(*graph); - } - { - auto graph = std::make_shared(); - parseIR( - R"IR( + graph.get()); + PeepholeOptimize(graph); + testing::FileCheck() + .check_not("aten::__is__") + ->check_not("aten::__isnot__") + ->run(*graph); +} + +TEST(PeepholeOptimizeTest, IsAndIsNot2) { + auto graph = std::make_shared(); + parseIR( + R"IR( graph(%0: int?): %1 : None = prim::Constant() %2 : bool = aten::__is__(%0, %1) %3 : bool = aten::__isnot__(%0, %1) return (%2, %3) )IR", - graph.get()); - PeepholeOptimize(graph); - testing::FileCheck() - .check("aten::__is__") - ->check("aten::__isnot__") - ->run(*graph); - } + graph.get()); + PeepholeOptimize(graph); + testing::FileCheck() + .check("aten::__is__") + ->check("aten::__isnot__") + ->run(*graph); +} - { - auto graph = std::make_shared(); - parseIR( - R"IR( +TEST(PeepholeOptimizeTest, IsAndIsNot3) { + auto graph = std::make_shared(); + parseIR( + R"IR( graph(%0: int?): %1 : Tensor = prim::AutogradZero() %2 : None = prim::Constant() @@ -56,48 +58,49 @@ graph(%0: int?): %5 : bool = aten::__isnot__(%1, %2) return (%4, %5) )IR", - graph.get()); - PeepholeOptimize(graph); - testing::FileCheck() - .check("aten::__is__") - ->check_not("aten::__isnot__") - ->run(*graph); - } + graph.get()); + PeepholeOptimize(graph); + testing::FileCheck() + .check("aten::__is__") + ->check_not("aten::__isnot__") + ->run(*graph); +} - // test unwrap optional - { - auto graph = std::make_shared(); - parseIR( - R"IR( +TEST(PeepholeOptimizeTest, UnwrapOptional) +// test unwrap optional +{ + auto graph = std::make_shared(); + parseIR( + R"IR( graph(): %1 : Float(*, *, *) = prim::Constant() %2 : bool = aten::_unwrap_optional(%1) %3 : bool = prim::unchecked_unwrap_optional(%1) return (%2, %3) )IR", - graph.get()); - PeepholeOptimize(graph); - testing::FileCheck().check_not("unwrap")->run(*graph); - } - { - auto graph = std::make_shared(); - parseIR( - R"IR( + graph.get()); + PeepholeOptimize(graph); + testing::FileCheck().check_not("unwrap")->run(*graph); +} + +TEST(PeepholeOptimizeTest, UnwrapOptional2) { + auto graph = std::make_shared(); + parseIR( + R"IR( graph(%1 : Float(*, *, *)?): %2 : bool = aten::_unwrap_optional(%1) %3 : bool = prim::unchecked_unwrap_optional(%1) return (%2, %3) )IR", - graph.get()); - PeepholeOptimize(graph); - testing::FileCheck().check_count("unwrap", 2)->run(*graph); - } + graph.get()); + PeepholeOptimize(graph); + testing::FileCheck().check_count("unwrap", 2)->run(*graph); +} - // tests addmm fusion - { - auto graph = std::make_shared(); - parseIR( - R"IR( +TEST(PeepholeOptimizeTest, AddMMFusion) { + auto graph = std::make_shared(); + parseIR( + R"IR( graph( %0 : Float(2, 3, 4), %1 : Float(2, 3, 4), @@ -108,10 +111,9 @@ graph(%1 : Float(*, *, *)?): %6 : Tensor = aten::add(%5, %2, %3) return (%6) )IR", - graph.get()); - FuseAddMM(graph); - testing::FileCheck().check("addmm")->run(*graph); - } + graph.get()); + FuseAddMM(graph); + testing::FileCheck().check("addmm")->run(*graph); } } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_qualified_name.cpp b/test/cpp/jit/test_qualified_name.cpp index 0f387bb542ed9..80028ada85650 100644 --- a/test/cpp/jit/test_qualified_name.cpp +++ b/test/cpp/jit/test_qualified_name.cpp @@ -1,68 +1,70 @@ +#include #include #include -#include "test/cpp/jit/test_base.h" using c10::QualifiedName; namespace torch { namespace jit { -void testQualifiedName() { - { - // Test prefix construction - auto foo = QualifiedName("foo"); - auto bar = QualifiedName(foo, "bar"); - auto baz = QualifiedName(bar, "baz"); - ASSERT_EQ(baz.qualifiedName(), "foo.bar.baz"); - ASSERT_EQ(baz.prefix(), "foo.bar"); - ASSERT_EQ(baz.name(), "baz"); - auto nullstate = QualifiedName(); - ASSERT_EQ(nullstate.qualifiedName(), ""); - ASSERT_EQ(nullstate.prefix(), ""); - ASSERT_EQ(nullstate.name(), ""); - } - { - // Test dotted construction - auto foo = QualifiedName("foo.bar.baz"); - ASSERT_EQ(foo.qualifiedName(), "foo.bar.baz"); - ASSERT_EQ(foo.prefix(), "foo.bar"); - ASSERT_EQ(foo.name(), "baz"); +TEST(QualifiedNameTest, PrefixConstruction) { + // Test prefix construction + auto foo = QualifiedName("foo"); + auto bar = QualifiedName(foo, "bar"); + auto baz = QualifiedName(bar, "baz"); + ASSERT_EQ(baz.qualifiedName(), "foo.bar.baz"); + ASSERT_EQ(baz.prefix(), "foo.bar"); + ASSERT_EQ(baz.name(), "baz"); + auto nullstate = QualifiedName(); + ASSERT_EQ(nullstate.qualifiedName(), ""); + ASSERT_EQ(nullstate.prefix(), ""); + ASSERT_EQ(nullstate.name(), ""); +} + +TEST(QualifiedNameTest, DottedConstruction) { + // Test dotted construction + auto foo = QualifiedName("foo.bar.baz"); + ASSERT_EQ(foo.qualifiedName(), "foo.bar.baz"); + ASSERT_EQ(foo.prefix(), "foo.bar"); + ASSERT_EQ(foo.name(), "baz"); + + auto bar = QualifiedName("bar"); + ASSERT_EQ(bar.qualifiedName(), "bar"); + ASSERT_EQ(bar.prefix(), ""); + ASSERT_EQ(bar.name(), "bar"); +} + +TEST(QualifiedNameTest, BadInputRaises) { + // throw some bad inputs at it + ASSERT_ANY_THROW(QualifiedName("foo..bar")); + ASSERT_ANY_THROW(QualifiedName(".foo.bar")); + ASSERT_ANY_THROW(QualifiedName("foo.bar.")); + ASSERT_ANY_THROW(QualifiedName("")); +} + +TEST(QualifiedNameTest, Equality) { + // test equality api + auto foo1 = QualifiedName("foo.bar.baz"); + auto foo2 = QualifiedName("foo.bar.baz"); + auto foo3 = QualifiedName("bar.bar.baz"); + ASSERT_EQ(foo1, foo2); + ASSERT_NE(foo1, foo3); + auto bar1 = QualifiedName("sup"); + auto bar2 = QualifiedName("sup"); + ASSERT_EQ(foo1, foo2); +} - auto bar = QualifiedName("bar"); - ASSERT_EQ(bar.qualifiedName(), "bar"); - ASSERT_EQ(bar.prefix(), ""); - ASSERT_EQ(bar.name(), "bar"); - } - { - // throw some bad inputs at it - ASSERT_ANY_THROW(QualifiedName("foo..bar")); - ASSERT_ANY_THROW(QualifiedName(".foo.bar")); - ASSERT_ANY_THROW(QualifiedName("foo.bar.")); - ASSERT_ANY_THROW(QualifiedName("")); - } - { - // test equality api - auto foo1 = QualifiedName("foo.bar.baz"); - auto foo2 = QualifiedName("foo.bar.baz"); - auto foo3 = QualifiedName("bar.bar.baz"); - ASSERT_EQ(foo1, foo2); - ASSERT_NE(foo1, foo3); - auto bar1 = QualifiedName("sup"); - auto bar2 = QualifiedName("sup"); - ASSERT_EQ(foo1, foo2); - } - { - // test prefix api - auto foo1 = QualifiedName("foo.bar.baz"); - auto foo2 = QualifiedName("foo.bar"); - auto foo3 = QualifiedName("bar.bar.baz"); - auto foo4 = QualifiedName("foo.bar"); - ASSERT_TRUE(foo2.isPrefixOf(foo1)); - ASSERT_TRUE(foo2.isPrefixOf(foo4)); - ASSERT_TRUE(foo4.isPrefixOf(foo2)); - ASSERT_FALSE(foo1.isPrefixOf(foo2)); - ASSERT_FALSE(foo2.isPrefixOf(foo3)); - } +TEST(QualifiedNameTest, IsPrefixOf) { + // test prefix api + auto foo1 = QualifiedName("foo.bar.baz"); + auto foo2 = QualifiedName("foo.bar"); + auto foo3 = QualifiedName("bar.bar.baz"); + auto foo4 = QualifiedName("foo.bar"); + ASSERT_TRUE(foo2.isPrefixOf(foo1)); + ASSERT_TRUE(foo2.isPrefixOf(foo4)); + ASSERT_TRUE(foo4.isPrefixOf(foo2)); + ASSERT_FALSE(foo1.isPrefixOf(foo2)); + ASSERT_FALSE(foo2.isPrefixOf(foo3)); } } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_save_load.cpp b/test/cpp/jit/test_save_load.cpp index 05940845d1720..e102a6ff767cd 100644 --- a/test/cpp/jit/test_save_load.cpp +++ b/test/cpp/jit/test_save_load.cpp @@ -1,4 +1,5 @@ -#include +#include + #include #include @@ -12,10 +13,10 @@ namespace torch { namespace jit { -// Tests that an extra file written explicitly has precedence over -// extra files written by a hook -// TODO: test for the warning, too -void testExtraFilesHookPreference() { +TEST(SerializationTest, ExtraFilesHookPreference) { + // Tests that an extra file written explicitly has precedence over + // extra files written by a hook + // TODO: test for the warning, too const auto script = R"JIT( def forward(self): x = torch.rand(5, 5) @@ -43,52 +44,50 @@ void testExtraFilesHookPreference() { ASSERT_EQ(loaded_extra_files["metadata.json"], "abc"); } -void testSaveExtraFilesHook() { +TEST(SerializationTest, ExtraFileHooksNoSecret) { // no secrets + std::stringstream ss; { - std::stringstream ss; - { - Module m("__torch__.m"); - ExtraFilesMap extra; - extra["metadata.json"] = "abc"; - m.save(ss, extra); - } - ss.seekg(0); - { - ExtraFilesMap extra; - extra["metadata.json"] = ""; - extra["secret.json"] = ""; - jit::load(ss, c10::nullopt, extra); - ASSERT_EQ(extra["metadata.json"], "abc"); - ASSERT_EQ(extra["secret.json"], ""); - } + Module m("__torch__.m"); + ExtraFilesMap extra; + extra["metadata.json"] = "abc"; + m.save(ss, extra); } - // some secret + ss.seekg(0); { - std::stringstream ss; - { - SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap { - return {{"secret.json", "topsecret"}}; - }); - Module m("__torch__.m"); - ExtraFilesMap extra; - extra["metadata.json"] = "abc"; - m.save(ss, extra); - SetExportModuleExtraFilesHook(nullptr); - } - ss.seekg(0); - { - ExtraFilesMap extra; - extra["metadata.json"] = ""; - extra["secret.json"] = ""; - jit::load(ss, c10::nullopt, extra); - ASSERT_EQ(extra["metadata.json"], "abc"); - ASSERT_EQ(extra["secret.json"], "topsecret"); - } + ExtraFilesMap extra; + extra["metadata.json"] = ""; + extra["secret.json"] = ""; + jit::load(ss, c10::nullopt, extra); + ASSERT_EQ(extra["metadata.json"], "abc"); + ASSERT_EQ(extra["secret.json"], ""); } } -void testTypeTags() { +TEST(SerializationTest, ExtraFileHooksWithSecret) { + std::stringstream ss; + { + SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap { + return {{"secret.json", "topsecret"}}; + }); + Module m("__torch__.m"); + ExtraFilesMap extra; + extra["metadata.json"] = "abc"; + m.save(ss, extra); + SetExportModuleExtraFilesHook(nullptr); + } + ss.seekg(0); + { + ExtraFilesMap extra; + extra["metadata.json"] = ""; + extra["secret.json"] = ""; + jit::load(ss, c10::nullopt, extra); + ASSERT_EQ(extra["metadata.json"], "abc"); + ASSERT_EQ(extra["secret.json"], "topsecret"); + } +} + +TEST(SerializationTest, TypeTags) { auto list = c10::List>(); list.push_back(c10::List({1, 2, 3})); list.push_back(c10::List({4, 5, 6})); @@ -121,5 +120,33 @@ void testTypeTags() { } } +TEST(SerializationTest, TestJitStream_CUDA) { + torch::jit::Module model; + std::vector inputs; + // Deserialize the ScriptModule from a file using torch::jit::load(). + // Load the scripted model. This should have been generated by tests_setup.py + // Refer: TorchSaveJitStream_CUDA in test/cpp/jit/tests_setup.py + model = torch::jit::load("saved_stream_model.pt"); + + auto output = model.forward(inputs); + auto list_of_elements = output.toTuple()->elements(); + auto is_stream_s = list_of_elements[0].toBool(); + + // a,b: These are the two input tensors + // c: This is output tensor generated by the operation torch.cat(a,b) + auto a = list_of_elements[1].toTensor(); + auto b = list_of_elements[2].toTensor(); + auto c = list_of_elements[3].toTensor(); + // op: this is used to verify if the cat operation produced the same results + // as that on the GPU with torch.cat + auto op = at::cat({a, b}, 0); + + // Check if the stream is set + ASSERT_TRUE(is_stream_s); + // Check if the sizes of the outputs (op and c) is same on the GPU and CPU + ASSERT_EQ(op.sizes(), c.sizes()); + // Check if both the output tensors are equal + ASSERT_TRUE(op.equal(c)); +} } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_schema_matching.cpp b/test/cpp/jit/test_schema_matching.cpp index bea7d14dcaf2d..aeeb173b26783 100644 --- a/test/cpp/jit/test_schema_matching.cpp +++ b/test/cpp/jit/test_schema_matching.cpp @@ -1,8 +1,9 @@ +#include + #include +#include #include #include -#include "test/cpp/jit/test_base.h" -#include "torch/csrc/jit/runtime/custom_operator.h" #include #include @@ -10,80 +11,79 @@ namespace torch { namespace jit { -void testSchemaMatching() { - { - RegisterOperators reg({ - Operator( - "aten::test_vartype(t[] a, t b) -> (t)", - [](Stack* stack) { - c10::List list; - double a; - pop(stack, list, a); - push(stack, a); - }, - c10::AliasAnalysisKind::FROM_SCHEMA), - }); - Module m("m"); - m.define(R"( +TEST(SchemaMatchingTest, VarType) { + RegisterOperators reg({ + Operator( + "aten::test_vartype(t[] a, t b) -> (t)", + [](Stack* stack) { + c10::List list; + double a; + pop(stack, list, a); + push(stack, a); + }, + c10::AliasAnalysisKind::FROM_SCHEMA), + }); + Module m("m"); + m.define(R"( def test(self): a = (1.0, 2.0) return torch.test_vartype(a, 2.0) )"); - auto result = m.run_method("test"); - TORCH_INTERNAL_ASSERT(result.toDouble() == 2.0); + auto result = m.run_method("test"); + TORCH_INTERNAL_ASSERT(result.toDouble() == 2.0); - const std::string error_example = R"JIT( + const std::string error_example = R"JIT( def test_2(self): a = (1.0, 2.0) non_float = (1, 1) return torch.test_vartype(a, non_float) )JIT"; - std::string err = ""; - try { - m.define(error_example); - } catch (const std::exception& e) { - err = e.what(); - } - TORCH_INTERNAL_ASSERT( - err.find("previously matched to type") != std::string::npos); + std::string err = ""; + try { + m.define(error_example); + } catch (const std::exception& e) { + err = e.what(); } - { - RegisterOperators reg({ - Operator( - "aten::test_vartype2(t a, t[] b) -> (t[])", - [](Stack* stack) { - double a; - c10::List list; - pop(stack, a, list); - push(stack, a); - }, - AliasAnalysisKind::FROM_SCHEMA), - }); - Module m("m"); - m.define(R"JIT( + TORCH_INTERNAL_ASSERT( + err.find("previously matched to type") != std::string::npos); +} + +TEST(SchemaMatchingTest, VarType2) { + RegisterOperators reg({ + Operator( + "aten::test_vartype2(t a, t[] b) -> (t[])", + [](Stack* stack) { + double a; + c10::List list; + pop(stack, a, list); + push(stack, a); + }, + AliasAnalysisKind::FROM_SCHEMA), + }); + Module m("m"); + m.define(R"JIT( def test(self): a = (1.0, 2.0) return torch.test_vartype2(3.0, a) )JIT"); - auto result = m.run_method("test"); - TORCH_INTERNAL_ASSERT(result.toDouble() == 3.0); + auto result = m.run_method("test"); + TORCH_INTERNAL_ASSERT(result.toDouble() == 3.0); - static const auto error_exam2 = R"JIT( + static const auto error_exam2 = R"JIT( def test_2(self): a = (1, 2) return torch.test_vartype2(3.0, a) )JIT"; - std::string err = ""; - try { - m.define(error_exam2); - } catch (const std::exception& e) { - err = e.what(); - } - TORCH_INTERNAL_ASSERT( - err.find("previously matched to type") != std::string::npos); + std::string err = ""; + try { + m.define(error_exam2); + } catch (const std::exception& e) { + err = e.what(); } + TORCH_INTERNAL_ASSERT( + err.find("previously matched to type") != std::string::npos); } } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_subgraph_matcher.cpp b/test/cpp/jit/test_subgraph_matcher.cpp index 2e398db44e95f..01e8293fcaeda 100644 --- a/test/cpp/jit/test_subgraph_matcher.cpp +++ b/test/cpp/jit/test_subgraph_matcher.cpp @@ -1,11 +1,12 @@ -#include "test/cpp/jit/test_base.h" +#include + #include "test/cpp/jit/test_utils.h" #include "torch/csrc/jit/ir/subgraph_matcher.h" namespace torch { namespace jit { -void testTrivial1() { +TEST(SubgraphMatcherTest, Trivial1) { Graph graph, pattern; parseIR( R"IR( @@ -22,7 +23,7 @@ graph(%0): AT_ASSERT(!findPatternMatches(pattern, graph).empty()); } -void testTrivial2() { +TEST(SubgraphMatcherTest, Trivial2) { Graph graph; auto* g_in = graph.addInput(); auto* g_tanh = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/1)); @@ -45,7 +46,7 @@ void testTrivial2() { } } -void testTrivial3() { +TEST(SubgraphMatcherTest, Trivial3) { Graph graph, pattern; parseIR( R"IR( @@ -64,7 +65,7 @@ graph(%a, %b): AT_ASSERT(!findPatternMatches(pattern, graph).empty()); } -void testTrivial4() { +TEST(SubgraphMatcherTest, Trivial4) { Graph graph; auto* g_in0 = graph.addInput(); auto* g_in1 = graph.addInput(); @@ -92,7 +93,7 @@ void testTrivial4() { } } -void testLinear1() { +TEST(SubgraphMatcherTest, Linear1) { Graph graph, pattern; parseIR( R"IR( @@ -114,7 +115,7 @@ graph(%0): AT_ASSERT(!findPatternMatches(pattern, graph).empty()); } -void testLinear2() { +TEST(SubgraphMatcherTest, Linear2) { Graph graph; auto* g_in = graph.addInput(); @@ -164,7 +165,7 @@ void testLinear2() { * | * eee */ -void testDiamond1() { +TEST(SubgraphMatcherTest, Diamond1) { Graph graph, pattern1, pattern2; parseIR( R"IR( @@ -215,7 +216,7 @@ graph(%0): * | * o1 */ -void testDiamond2() { +TEST(SubgraphMatcherTest, Diamond2) { Graph graph; auto* g_in = graph.addInput(); @@ -253,7 +254,7 @@ void testDiamond2() { } } -void testXPattern() { +TEST(SubgraphMatcherTest, XPattern) { Graph graph, pattern; parseIR( R"IR( @@ -280,7 +281,7 @@ graph(%0, %1): AT_ASSERT(!findPatternMatches(pattern, graph).empty()); } -void testMultipleMatches() { +TEST(SubgraphMatcherTest, MultipleMatches) { Graph graph, pattern; parseIR( R"IR( @@ -301,7 +302,7 @@ graph(%t0): AT_ASSERT(matches.size() == 4); } -void testOverlappingMatches() { +TEST(SubgraphMatcherTest, OverlappingMatches) { Graph graph, pattern; parseIR( R"IR( @@ -323,7 +324,7 @@ graph(%t0): AT_ASSERT(matches.size() == 3); } -void testMatchInBasicBlocks1() { +TEST(SubgraphMatcherTest, MatchInBasicBlocks1) { Graph graph; parseIR( R"IR( @@ -360,7 +361,7 @@ graph(%x, %y): AT_ASSERT(findPatternMatches(pattern1, graph).size() == 0); } -void testMatchInBasicBlocks2() { +TEST(SubgraphMatcherTest, MatchInBasicBlocks2) { Graph graph; parseIR( R"IR( @@ -395,7 +396,7 @@ graph(%x, %y): AT_ASSERT(findPatternMatches(pattern1, graph).size() == 0); } -void testMatchesAttributes() { +TEST(SubgraphMatcherTest, MatchesAttributes) { Graph graph; parseIR( R"IR( @@ -479,13 +480,14 @@ graph(%a, %b): } } -void testBadPattern() { +TEST(SubgraphMatcherTest, BadPattern) { Graph graph, pattern1, pattern2; parseIR( R"IR( -graph(%0): - %a = a::aaa(%0) - return (%a))IR", +graph(%x): + %y = my::op1(%x) + %z = my::op2(%x) + return (%y, %z))IR", &graph); parseIR( @@ -497,6 +499,7 @@ graph(%x): -> (%z) return (%y))IR", &pattern1); + // No support for patterns with subblocks ASSERT_ANY_THROW(findPatternMatches(pattern1, graph)); parseIR( @@ -506,25 +509,58 @@ graph(%x): %z = my::op2(%x) return (%y, %z))IR", &pattern2); + // Not supported multi-output pattern, because not the whole pattern is + // covered by a traversal up from the first output (`%z = ...` is not + // visited). See the note "Multi-output Patterns" in subgraph_matcher.h. ASSERT_ANY_THROW(findPatternMatches(pattern2, graph)); } -void testSubgraphMatching() { - testTrivial1(); - testTrivial2(); - testTrivial3(); - testTrivial4(); - testLinear1(); - testLinear2(); - testDiamond1(); - testDiamond2(); - testXPattern(); - testMultipleMatches(); - testOverlappingMatches(); - testMatchInBasicBlocks1(); - testMatchInBasicBlocks2(); - testMatchesAttributes(); - testBadPattern(); +TEST(SubgraphMatcherTest, MultiOutput) { + { + Graph graph, pattern; + parseIR( + R"IR( +graph(%0): + %a = a::aaa(%0) + %b = b::bbb(%a) + %c = c::ccc(%a, %b) + %x = a::aaa(%c) + %y = b::bbb(%x) + %z = d::ddd(%x, %y) + return (%y))IR", + &graph); + parseIR( + R"IR( +graph(%0): + %a = a::aaa(%0) + %b = b::bbb(%a) + return (%b, %a))IR", + &pattern); + AT_ASSERT(findPatternMatches(pattern, graph).size() == 2); + } + { + Graph graph, pattern; + parseIR( + R"IR( +graph(%0, %1): + %a1, %a2 = a::aaa(%0, %1) + %b = b::bbb(%a1) + %c = c::ccc(%b) + + %x1, %x2 = a::aaa(%c, %a2) + %y = b::bbb(%x1) + %z = d::ddd(%y) + return (%z))IR", + &graph); + parseIR( + R"IR( +graph(%0, %1): + %a1, %a2 = a::aaa(%0, %1) + %b = b::bbb(%a1) + return (%b, %a2))IR", + &pattern); + AT_ASSERT(findPatternMatches(pattern, graph).size() == 2); + } } } // namespace jit diff --git a/test/cpp/jit/test_subgraph_rewriter.cpp b/test/cpp/jit/test_subgraph_rewriter.cpp index 9799dfdb97b2d..048223ba8600c 100644 --- a/test/cpp/jit/test_subgraph_rewriter.cpp +++ b/test/cpp/jit/test_subgraph_rewriter.cpp @@ -1,4 +1,5 @@ -#include +#include + #include #include #include @@ -8,7 +9,7 @@ namespace torch { namespace jit { using namespace testing; -void testFilterMatch() { +TEST(SubgraphRewriterTest, FilterMatch) { auto graph = std::make_shared(); parseIR( @@ -80,7 +81,7 @@ graph(%a, %b): } } -void testFilterNoMatch() { +TEST(SubgraphRewriterTest, FilterNoMatch) { auto graph = std::make_shared(); parseIR( R"IR( @@ -121,10 +122,127 @@ graph(%a, %b): FileCheck().check("c::ccc")->check_not("d::ddd")->run(*graph); } -void testSubgraphRewriter() { - testFilterMatch(); - testFilterNoMatch(); -} +TEST(SubgraphRewriterTest, MultiOutput) { + { + auto graph = std::make_shared(); + + // Basic multi-output pattern rewriting + parseIR( + R"IR( +graph(%0, %1): + %a1, %a2 = a::aaa(%0, %1) + %b = b::bbb(%a1) + %c = c::ccc(%b) + + %x1, %x2 = a::aaa(%c, %a2) + %y = b::bbb(%x1) + %z = d::ddd(%y) + return (%z))IR", + graph.get()); + + std::string pattern = R"IR( +graph(%0, %1): + %a1, %a2 = a::aaa(%0, %1) + %b = b::bbb(%a1) + return (%b, %a2))IR"; + + std::string replacement = R"IR( +graph(%a, %b): + %x, %y = ab::ababab(%a, %b) + return (%x, %y))IR"; + + SubgraphRewriter rewriter; + rewriter.RegisterRewritePattern(pattern, replacement); + + auto g = graph->copy(); + rewriter.runOnGraph(g); + FileCheck().check("ab::ababab")->check("ab::ababab")->run(*g); + } + { + auto graph = std::make_shared(); + + // Mimic a real model case + parseIR( + R"IR( + graph(%k, %m, %x1, %x2, %x3, %x4, %y1, %y2, %y3, %y4): + %a1 = aa::aaa(%x1, %k) + %b1_1, %b1_2 = bb::bbb(%y1, %a1) + %a2 = aa::aaa(%x2, %k) + %b2_1, %b2_2 = bb::bbb(%y2, %a2) + %a3 = aa::aaa(%x3, %k) + %b3_1, %b3_2 = bb::bbb(%y3, %a3) + %a4 = aa::aaa(%x4, %k) + %b4_1, %b4_2 = bb::bbb(%y4, %a4) + %c = cc::ccc(%b4_1) + %d1 = dd::ddd(%b1_2, %m) + %e1 = ee::eee(%b1_1, %d1) + %d2 = dd::ddd(%b2_2, %m) + %e2 = ee::eee(%b2_1, %d2) + %d3 = dd::ddd(%b3_2, %m) + %e3 = ee::eee(%b3_1, %d3) + %d4 = dd::ddd(%b4_2, %m) + %e4 = ee::eee(%b4_1, %d4) + return (%d1, %d2, %d3, %d4, %e1, %e2, %e3, %e4) + )IR", + graph.get()); + + std::string pattern = R"IR( + graph(%a, %b, %c, %d): + %y0 = aa::aaa(%b, %c) + %y1, %y2 = bb::bbb(%a, %y0) + %y3 = dd::ddd(%y2, %d) + return (%y3, %y1))IR"; + + std::string replacement = R"IR( + graph(%a, %b, %c, %d): + %x, %y = ab::ababab(%a, %b, %c, %d) + return (%x, %y))IR"; + + SubgraphRewriter rewriter; + rewriter.RegisterRewritePattern(pattern, replacement); + + auto g = graph->copy(); + rewriter.runOnGraph(g); + FileCheck().check("ab::ababab")->check("ab::ababab")->run(*g); + } + { + auto graph = std::make_shared(); + + // A case where no rewriting should occur due to data dependencies + parseIR( + R"IR( + graph(%x, %y): + %a = aa::aaa(%x) + %b = bb::bbb(%a) + %e = ee::eee(%b) + %c = cc::ccc(%y) + %d = dd::ddd(%b, %c) + %f = ff::fff(%b, %d) + return (%f) + )IR", + graph.get()); + + std::string pattern = R"IR( + graph(%a, %c): + %b = bb::bbb(%a) + %d = dd::ddd(%b, %c) + return (%d, %b))IR"; + + std::string replacement = R"IR( + graph(%a, %c): + %d, %b = db::fused(%a, %c) + return (%d, %b))IR"; + + SubgraphRewriter rewriter; + rewriter.RegisterRewritePattern(pattern, replacement); + auto g = graph->copy(); + rewriter.runOnGraph(g); + // We should not perform the replacement on the given graph due to data + // dependency constraints: the output %b is used in %e, which precedes one + // def of the input %c. + FileCheck().check_not("db::fused")->run(*g); + } +} } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_subgraph_utils.cpp b/test/cpp/jit/test_subgraph_utils.cpp index e1f86cc349799..cadb143cf5de9 100644 --- a/test/cpp/jit/test_subgraph_utils.cpp +++ b/test/cpp/jit/test_subgraph_utils.cpp @@ -1,4 +1,5 @@ -#include "test/cpp/jit/test_base.h" +#include + #include "test/cpp/jit/test_utils.h" #include "torch/csrc/jit/passes/common_subexpression_elimination.h" @@ -7,7 +8,7 @@ namespace torch { namespace jit { -void testSubgraphUtils() { +TEST(SubgraphUtilsTest, Basic) { auto graph = build_lstm(); EliminateCommonSubexpression(graph); @@ -37,7 +38,7 @@ void testSubgraphUtils() { ASSERT_EQ(originalNodes.size(), newNodes.size()); } -void testSubgraphUtilsVmap() { +TEST(SubgraphUtilsTest, Vmap) { auto graph = std::make_shared(); std::unordered_map parse_map; @@ -92,5 +93,34 @@ graph(%a : Tensor, %b : Tensor, %c : Tensor): ASSERT_TRUE(vmap2.at(new_tanh_out)->node()->kind() == aten::tanh); } +TEST(SubgraphUtilsTest, GraphName) { + auto graph = std::make_shared(); + + std::unordered_map parse_map; + parseIR( + R"IR( +graph(%a : Tensor, %b : Tensor, %c : Tensor): + %x : Tensor = aten::tanh(%a) + %y : Tensor = aten::mul(%a, %b) + %p : Tensor = aten::div(%c, %b) + %q1 : Tensor = aten::mul(%p, %a) + %q2 : Tensor = aten::tanh(%q1) + %q3 : Tensor = aten::tanh(%q2) + %q4 : Tensor = aten::tanh(%q3) + %q5 : Tensor = aten::tanh(%q4) + return (%x, %y, %q5))IR", + &*graph, + parse_map); + std::string ref_full_name = "graph_tanh_mul_div_mul_tanh_tanh_tanh_tanh"; + std::string full_name = + SubgraphUtils::generateNameForGraph(graph, 80, "graph"); + ASSERT_EQ(full_name, ref_full_name); + + std::string truncated_name = + SubgraphUtils::generateNameForGraph(graph, 10, "graph"); + + ASSERT_LE(truncated_name.size(), ref_full_name.size()); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_utils.cpp b/test/cpp/jit/test_utils.cpp index d87e8201615dd..6f626756db744 100644 --- a/test/cpp/jit/test_utils.cpp +++ b/test/cpp/jit/test_utils.cpp @@ -1,6 +1,9 @@ +#include + #include #include #include +#include namespace torch { namespace jit { @@ -137,5 +140,22 @@ std::pair lstm( return {hy, cy}; } +inline c10::AliasAnalysisKind aliasAnalysisFromSchema() { + return c10::AliasAnalysisKind::FROM_SCHEMA; +} + +namespace { +RegisterOperators reg({ + // This operator is intended to be used in JIT analysis and transformation + // pass unit tests in which Values with type Tensor are often required. It + // should not be used in situations in which the graph is actually executed + // because it always produces empty Tensors. + Operator( + "prim::MakeTestTensor() -> Tensor", + [](Stack* stack) { push(stack, at::Tensor()); }, + aliasAnalysisFromSchema()), +}); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_utils.h b/test/cpp/jit/test_utils.h index 6e6b82fff4424..109f7253deea2 100644 --- a/test/cpp/jit/test_utils.h +++ b/test/cpp/jit/test_utils.h @@ -1,7 +1,6 @@ #pragma once #include -#include "test/cpp/jit/test_base.h" #include "torch/csrc/jit/ir/irparser.h" #include "torch/csrc/jit/runtime/autodiff.h" #include "torch/csrc/jit/runtime/interpreter.h" diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h deleted file mode 100644 index a058326c20506..0000000000000 --- a/test/cpp/jit/tests.h +++ /dev/null @@ -1,242 +0,0 @@ -#pragma once - -/** - * See README.md for instructions on how to add a new test. - */ -#include -#include - -namespace torch { -namespace jit { -#define TH_FORALL_TESTS(_) \ - _(Attributes) \ - _(Blocks) \ - _(CallStack) \ - _(CallStackCaching) \ - _(ControlFlow) \ - _(IValueKWargs) \ - _(CustomFusion) \ - _(SchemaMatching) \ - _(FromQualString) \ - _(InternedStrings) \ - _(PassManagement) \ - _(Proto) \ - _(SchemaParser) \ - _(TopologicalIndex) \ - _(SubgraphUtils) \ - _(SubgraphUtilsVmap) \ - _(IRParser) \ - _(THNNConv) \ - _(ATenNativeBatchNorm) \ - _(NoneSchemaMatch) \ - _(UnifyTypes) \ - _(Profiler) \ - _(FallbackGraphs) \ - _(InsertAndEliminateRedundantGuards) \ - _(LoopPeeler) \ - _(InsertBailOuts) \ - _(PeepholeOptimize) \ - _(RecordFunction) \ - _(ThreadLocalDebugInfo) \ - _(SubgraphMatching) \ - _(SubgraphRewriter) \ - _(ModuleClone) \ - _(ModuleConstant) \ - _(ModuleParameter) \ - _(ModuleCopy) \ - _(ModuleDeepcopy) \ - _(ModuleDeepcopyString) \ - _(ModuleDeepcopyAliasing) \ - _(ModuleDefine) \ - _(QualifiedName) \ - _(ExtraFilesHookPreference) \ - _(SaveExtraFilesHook) \ - _(TypeTags) \ - _(CustomFusionNestedBlocks) \ - _(ModuleInterfaceSerialization) \ - _(ModuleCloneWithModuleInterface) \ - _(ClassTypeAddRemoveAttr) \ - _(Inliner) \ - _(LiteInterpreterAdd) \ - _(LiteInterpreterConv) \ - _(LiteInterpreterInline) \ - _(LiteInterpreterTuple) \ - _(LiteInterpreterUpsampleNearest2d) \ - _(CommonAncestor) \ - _(AutogradSymbols) \ - _(DefaultArgTypeHinting) \ - _(Futures) \ - _(TLSFutureCallbacks) \ - _(ProfilerDisableInCallback) \ - _(MobileTypeParser) \ - _(LiteInterpreterBuiltinFunction) \ - _(LiteInterpreterPrim) \ - _(LiteInterpreterPrimScalar) \ - _(LiteInterpreterLoadOrigJit) \ - _(LiteInterpreterWrongMethodName) \ - _(LiteInterpreterParams) \ - _(LiteInterpreterSetState) \ - _(LiteInterpreterModuleInfoBasic) \ - _(LiteInterpreterNotSavingModuleInfo) \ - _(LiteInterpreterOneSubmoduleModuleInfo) \ - _(LiteInterpreterTwoSubmodulesModuleInfo) \ - _(LiteInterpreterSequentialModuleInfo) \ - _(LiteInterpreterHierarchyModuleInfo) \ - _(LiteInterpreterDuplicatedClassTypeModuleInfo) \ - _(LiteInterpreterEval) \ - _(LiteInterpreterDict) \ - _(LiteInterpreterFindAndRunMethod) \ - _(LiteInterpreterFindWrongMethodName) \ - _(MobileNamedParameters) \ - _(MobileSaveLoadData) \ - _(MobileSaveLoadParameters) \ - _(MobileSaveLoadParametersEmpty) \ - _(LiteSGD) \ - _(LiteSequentialSampler) - -#if defined(USE_CUDA) -#define TH_FORALL_TESTS_CUDA(_) \ - _(GraphExecutor) \ - _(ModuleConversion) \ - _(Interp) \ - _(TypeCheck) \ - _(GPU_IrGraphGenerator) \ - _(GPU_FusionDispatch) \ - _(GPU_FusionClear) \ - _(GPU_FusionCopy) \ - _(GPU_FusionMove) \ - _(GPU_FusionSimpleArith) \ - _(GPU_FusionExprEvalConstants) \ - _(GPU_FusionExprEvalBindings) \ - _(GPU_FusionExprEvalBasic) \ - _(GPU_FusionExprEvalComplex) \ - _(GPU_FusionExprEvalPostLower) \ - _(GPU_FusionSimpleTypePromote) \ - _(GPU_FusionMutator) \ - _(GPU_FusionRegister) \ - _(GPU_FusionTopoSort) \ - _(GPU_FusionTensor) \ - _(GPU_FusionFilterVals) \ - _(GPU_FusionTVSplit) \ - _(GPU_FusionTVMerge) \ - _(GPU_FusionTVReorder) \ - _(GPU_FusionEquality) \ - _(GPU_FusionParser) \ - _(GPU_FusionDependency) \ - _(GPU_FusionCodeGen) \ - _(GPU_FusionCodeGen2) \ - _(GPU_FusionSimplePWise) \ - _(GPU_FusionExecKernel) \ - _(GPU_FusionForLoop) \ - _(GPU_FusionLoopUnroll) \ - _(GPU_FusionUnaryOps) \ - _(GPU_FusionBinaryOps) \ - _(GPU_FusionTernaryOps) \ - _(GPU_FusionCompoundOps) \ - _(GPU_FusionCastOps) \ - _(GPU_FusionAdvancedComputeAt) \ - _(GPU_FusionComputeAtMultiConsumers) \ - _(GPU_FusionComputeAtCommonConsumer1) \ - _(GPU_FusionComputeAtCommonConsumer2) \ - _(GPU_FusionComputeAtCommonConsumer3) \ - _(GPU_FusionComputeAtNoCommonConsumer) \ - _(GPU_FusionScalarInputs) \ - _(GPU_FusionBCastConcretizeBasic) \ - _(GPU_FusionBCastConcretizeRfactor) \ - _(GPU_FusionProveIdEqBasic) \ - _(GPU_FusionProveIdEqRfactor) \ - _(GPU_FusionRFactorReplay) \ - _(GPU_FusionReduction) \ - _(GPU_FusionReduction2) \ - _(GPU_FusionReduction3) \ - _(GPU_FusionReduction4) \ - _(GPU_FusionReduction5) \ - _(GPU_FusionReductionTFT) \ - _(GPU_FusionSimpleBCast) \ - _(GPU_FusionComplexBCast) \ - _(GPU_FusionAdvancedIndexing) \ - _(GPU_FusionSimpleGemm) \ - _(GPU_FusionSoftmax1D) \ - _(GPU_FusionSoftmax1DNormalized) \ - _(GPU_FusionSoftmax3D) \ - _(GPU_FusionSoftmax3DNormalized) \ - _(GPU_FusionSoftmaxComputeAt) \ - _(GPU_FusionGridReduction1) \ - _(GPU_FusionGridReduction2) \ - _(GPU_FusionGridReduction3dim1) \ - _(GPU_FusionGridReduction3dim0) \ - _(GPU_FusionGridReduction4) \ - _(GPU_FusionGridReduction5) \ - _(GPU_FusionGridReduction6) \ - _(GPU_FusionNonRedAxisBind) \ - _(GPU_FusionBCastInnerDim) \ - _(GPU_FusionBCastReduce) \ - _(GPU_FusionSplitBCast) \ - _(GPU_FusionComputeAtExprOrder) \ - _(GPU_FusionZeroDimComputeAt) \ - _(GPU_FusionZeroDimBroadcast) \ - _(GPU_FusionZeroDimReduction) \ - _(GPU_FusionReductionMultiConsumer) \ - _(GPU_FusionBCastAfterReduce) \ - _(GPU_FusionReductionScheduler) \ - _(GPU_FusionReductionSchedulerMultiDimNonFastest) \ - _(GPU_FusionReductionSchedulerMultiDimFastest) \ - _(GPU_FusionReductionSchedulerDimShmoo) \ - _(GPU_FusionCacheBefore) \ - _(GPU_FusionCacheAfter) \ - _(GPU_FusionCacheIndirect) \ - _(GPU_FusionCacheBcast) \ - _(GPU_FusionCacheComplex) \ - _(GPU_FusionCacheMultiConsumer) \ - _(GPU_FusionSmem) \ - _(GPU_FusionSmemReduce) \ - _(GPU_FusionSmemBlockGemm) \ - _(GPU_FusionSmemBlockGemmCache) \ - _(GPU_FusionSmemDynamicReductionSymbolic) \ - _(GPU_FusionSmemDynamicReductionSymbolicArg) \ - _(GPU_FusionSmemDynamicPwiseMulSymbolicArgWAR) \ - _(GPU_FusionSmemDynamicTiledGemm) \ - _(GPU_FusionGlobalIntermediate) \ - _(GPU_FusionGlobalIntermediateDefaultSchedule) \ - _(GPU_FusionConstCheck) \ - _(GPU_FusionSymbolicReduction) \ - _(GPU_FusionUnrollWithAlloc) \ - _(GPU_FusionIsZeroInt) \ - _(GPU_FusionIsOneInt) \ - _(GPU_FusionComputeAtNonterminatingOutput) \ - _(GPU_FusionTraversalOrder1) \ - _(GPU_FusionTraversalOrder2) \ - _(GPU_FusionTraversalOrder3) \ - _(GPU_FusionTraversalOrder4) \ - _(GPU_FusionTraversalOrder5) \ - _(GPU_FusionTraversalOrder6) \ - _(GPU_FusionTraversalOrder7) \ - _(GPU_FusionBranches) \ - _(GPU_FusionThreadPredicate) \ - _(GPU_FusionLSTMCell) \ - _(GPU_FusionComputeAtMultiBCast) \ - _(GPU_FusionReductionHalf) \ - _(GPU_FusionInputsIdLookup) -#else -#define TH_FORALL_TESTS_CUDA(_) \ - _(GraphExecutor) \ - _(ModuleConversion) \ - _(Interp) \ - _(TypeCheck) -#endif - -#define DECLARE_JIT_TEST(name) void test##name(); -TH_FORALL_TESTS(DECLARE_JIT_TEST) -TH_FORALL_TESTS_CUDA(DECLARE_JIT_TEST) -#undef DECLARE_JIT_TEST - -// This test is special since it requires prior setup in python. -// So it is not part of the general test list (which is shared between the gtest -// and python test runners), but is instead invoked manually by the -// torch_python_test.cpp -void testEvalModeForLoadedModule(); -void testSerializationInterop(); -void testTorchSaveError(); - -} // namespace jit -} // namespace torch diff --git a/test/cpp/jit/tests_setup.py b/test/cpp/jit/tests_setup.py index 68871d1c21d21..928a06d9b5a0b 100644 --- a/test/cpp/jit/tests_setup.py +++ b/test/cpp/jit/tests_setup.py @@ -63,11 +63,38 @@ def setup(self): torch.save(value, self.path, _use_new_zipfile_serialization=False) +class TorchSaveJitStream_CUDA(FileSetup): + path = 'saved_stream_model.pt' + + def setup(self): + if not torch.cuda.is_available(): + return + + class Model(torch.nn.Module): + def forward(self): + device_index = torch.cuda._current_device() + s = torch.jit.cuda.Stream(device_index, 0) + a = torch.rand(3, 4, device="cuda") + b = torch.rand(3, 4, device="cuda") + + with torch.jit.cuda.stream(s): + is_stream_s = torch.cuda.current_stream(s.device_index()).id() == s.id() + c = torch.cat((a, b), 0).to("cuda") + s.synchronize() + return is_stream_s, a, b, c + + model = Model() + + # Script the model and save + script_model = torch.jit.script(model) + torch.jit.save(script_model, self.path) + tests = [ EvalModeForLoadedModule(), SerializationInterop(), TorchSaveError(), + TorchSaveJitStream_CUDA() ] def setup(): diff --git a/test/cpp/rpc/e2e_test_base.h b/test/cpp/rpc/e2e_test_base.h index 9d3ab71c0cfcf..cea5079b1a4eb 100644 --- a/test/cpp/rpc/e2e_test_base.h +++ b/test/cpp/rpc/e2e_test_base.h @@ -28,7 +28,7 @@ class TestE2EBase : public ::testing::Test { autogradContainer = getDistAutogradContainer(); // Setup server store. - store = std::make_shared( + store = c10::make_intrusive( serverAddress, 0, numWorkers, true, std::chrono::seconds(10)); buildRpcAgent(); @@ -64,19 +64,21 @@ class TestE2EBase : public ::testing::Test { ScriptRemoteCall scriptRemoteCall( op, {t1, t2, 1}, ownerRRef->rrefId(), ownerRRef->rrefId()); - auto fm = autograd::sendMessageWithAutograd( + auto jitFuture = autograd::sendMessageWithAutograd( *rpcAgent, rpcAgent->getWorkerInfo("worker"), std::move(scriptRemoteCall).toMessage(), false); - ownerRRef->registerOwnerCreationFuture(fm); + ownerRRef->registerOwnerCreationFuture(jitFuture); // Builtin operators does not return py::object, and hence does not require // GIL for destructing the potentially deleted OwerRRef. - fm->addCallback( - [ownerRRefId = ownerRRef->rrefId()](const FutureMessage& fm) { - callback::finishCreatingOwnerRRef(fm, ownerRRefId); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback( + [wp, ownerRRefId = ownerRRef->rrefId()]() { + auto jitFuture = wp.lock(); + callback::finishCreatingOwnerRRef(*jitFuture, ownerRRefId); }); return ownerRRef; } @@ -89,12 +91,14 @@ class TestE2EBase : public ::testing::Test { // Send the RPC and return result. auto response = autograd::sendMessageWithAutograd( - *rpcAgent, - rpcAgent->getWorkerInfo("worker"), - std::move(scriptCall).toMessage()) - ->wait(); + *rpcAgent, + rpcAgent->getWorkerInfo("worker"), + std::move(scriptCall).toMessage()); + response->waitAndThrow(); + MessageType messageType = MessageType::FORWARD_AUTOGRAD_RESP; - auto wrappedResponse = deserializeResponse(response, messageType); + auto wrappedResponse = deserializeResponse( + std::move(*response->value().toCustomClass()), messageType); return static_cast(*wrappedResponse).value().toTensor(); } @@ -147,7 +151,7 @@ class TestE2EBase : public ::testing::Test { std::shared_ptr rpcAgent; static const size_t numIters; static const size_t numWorkers; - std::shared_ptr store; + c10::intrusive_ptr store; static const char* serverAddress; }; diff --git a/test/cpp/rpc/test_e2e_process_group.cpp b/test/cpp/rpc/test_e2e_process_group.cpp index d509a4606fa1b..01bed87687a4c 100644 --- a/test/cpp/rpc/test_e2e_process_group.cpp +++ b/test/cpp/rpc/test_e2e_process_group.cpp @@ -19,10 +19,11 @@ class TestE2EProcessGroup : public TestE2EBase { options.devices.push_back( ::c10d::ProcessGroupGloo::createDeviceForHostname(serverAddress)); std::chrono::milliseconds rpcTimeout(30000); + options.timeout = rpcTimeout; // Initialize server rpc agent. - auto pg = - std::make_shared(store, 0, numWorkers, options); + auto pg = c10::make_intrusive( + store, 0, numWorkers, options); rpcAgent = std::make_shared( "worker", diff --git a/test/cpp/rpc/test_e2e_tensorpipe.cpp b/test/cpp/rpc/test_e2e_tensorpipe.cpp index 8fecf6dffb75e..b7b5e1d91c1bb 100644 --- a/test/cpp/rpc/test_e2e_tensorpipe.cpp +++ b/test/cpp/rpc/test_e2e_tensorpipe.cpp @@ -11,7 +11,6 @@ namespace torch { namespace distributed { namespace rpc { - #ifdef USE_TENSORPIPE class TestE2ETensorPipe : public TestE2EBase { @@ -23,8 +22,8 @@ class TestE2ETensorPipe : public TestE2EBase { float rpcTimeout = 30; // Initialize server rpc agent. - auto pg = - std::make_shared(store, 0, numWorkers, options); + auto pg = c10::make_intrusive( + store, 0, numWorkers, options); TensorPipeRpcBackendOptions opts( /*numWorkerThreads=*/std::max(16U, std::thread::hardware_concurrency()), @@ -49,6 +48,15 @@ class TestE2ETensorPipe : public TestE2EBase { // challenging and we don't have a good solution yet. TEST_F(TestE2ETensorPipe, TestTrainingLoop) { runTrainingLoop(); + // Ensure the tensorpipe internal state is cleared up. + auto tensorpipeAgent = std::static_pointer_cast(rpcAgent); + // Wait a while for async RPCs to propagate through (ex: dist autograd + // cleanup) + while (tensorpipeAgent->numPendingResponses() != 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + ASSERT_EQ(0, tensorpipeAgent->numPendingResponses()); + ASSERT_EQ(0, tensorpipeAgent->timeoutMapSize()); } #endif diff --git a/test/cpp/tensorexpr/CMakeLists.txt b/test/cpp/tensorexpr/CMakeLists.txt index a2922045adffd..f1ed3743d7424 100644 --- a/test/cpp/tensorexpr/CMakeLists.txt +++ b/test/cpp/tensorexpr/CMakeLists.txt @@ -1,16 +1,43 @@ set(TENSOREXPR_TEST_ROOT ${TORCH_ROOT}/test/cpp/tensorexpr) -file(GLOB TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_*.cpp) -set(TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_SRCS} PARENT_SCOPE) +set(TENSOREXPR_TEST_SRCS + ${TENSOREXPR_TEST_ROOT}/test_aten.cpp + ${TENSOREXPR_TEST_ROOT}/test_boundsinference.cpp + ${TENSOREXPR_TEST_ROOT}/test_conv.cpp + ${TENSOREXPR_TEST_ROOT}/test_expr.cpp + ${TENSOREXPR_TEST_ROOT}/test_ir_printer.cpp + ${TENSOREXPR_TEST_ROOT}/test_kernel.cpp + ${TENSOREXPR_TEST_ROOT}/test_loopnest.cpp + ${TENSOREXPR_TEST_ROOT}/test_memdependency.cpp + ${TENSOREXPR_TEST_ROOT}/test_reductions.cpp + ${TENSOREXPR_TEST_ROOT}/test_registerizer.cpp + ${TENSOREXPR_TEST_ROOT}/test_simplify.cpp + ${TENSOREXPR_TEST_ROOT}/test_te_fuser_pass.cpp + ${TENSOREXPR_TEST_ROOT}/test_train.cpp + ${TENSOREXPR_TEST_ROOT}/test_train_impl.cpp + ${TENSOREXPR_TEST_ROOT}/test_type.cpp +) + +if(USE_CUDA) + list(APPEND TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_cuda.cpp) +endif() + +if(USE_LLVM AND LLVM_FOUND) + list(APPEND TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_llvm.cpp) +endif() add_executable(test_tensorexpr ${TORCH_ROOT}/test/cpp/common/main.cpp - ${TENSOREXPR_TEST_ROOT}/gtest.cpp ${TENSOREXPR_TEST_ROOT}/padded_buffer.cpp ${TENSOREXPR_TEST_SRCS}) target_link_libraries(test_tensorexpr PRIVATE torch gtest) target_include_directories(test_tensorexpr PRIVATE ${ATen_CPU_INCLUDE}) +target_compile_definitions(test_tensorexpr PRIVATE USE_GTEST) + +add_executable(tutorial_tensorexpr ${TENSOREXPR_TEST_ROOT}/tutorial.cpp) +target_link_libraries(tutorial_tensorexpr PRIVATE torch) +target_include_directories(tutorial_tensorexpr PRIVATE ${ATen_CPU_INCLUDE}) if(USE_CUDA) target_link_libraries(test_tensorexpr PRIVATE @@ -18,23 +45,34 @@ if(USE_CUDA) ${CUDA_NVRTC_LIB} ${CUDA_CUDA_LIB} ${TORCH_CUDA_LIBRARIES}) - target_compile_definitions(test_tensorexpr PRIVATE USE_CUDA) + + target_link_libraries(tutorial_tensorexpr PRIVATE + ${CUDA_LIBRARIES} + ${CUDA_NVRTC_LIB} + ${CUDA_CUDA_LIB} + ${TORCH_CUDA_LIBRARIES}) + target_compile_definitions(tutorial_tensorexpr PRIVATE USE_CUDA) elseif(USE_ROCM) target_link_libraries(test_tensorexpr PRIVATE ${ROCM_HIPRTC_LIB} ${PYTORCH_HIP_HCC_LIBRARIES} ${TORCH_CUDA_LIBRARIES}) - - target_link_libraries(test_tensorexpr PRIVATE caffe2_gpu) - target_compile_definitions(test_tensorexpr PRIVATE USE_ROCM) + + target_link_libraries(tutorial_tensorexpr PRIVATE + ${ROCM_HIPRTC_LIB} + ${PYTORCH_HIP_HCC_LIBRARIES} + ${TORCH_CUDA_LIBRARIES}) + target_compile_definitions(tutorial_tensorexpr PRIVATE USE_ROCM) endif() if(INSTALL_TEST) install(TARGETS test_tensorexpr DESTINATION bin) + install(TARGETS tutorial_tensorexpr DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) install(FILES $ DESTINATION bin OPTIONAL) + install(FILES $ DESTINATION bin OPTIONAL) endif() endif() diff --git a/test/cpp/tensorexpr/gtest.cpp b/test/cpp/tensorexpr/gtest.cpp deleted file mode 100644 index 14a5e81f95da9..0000000000000 --- a/test/cpp/tensorexpr/gtest.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include - -#include - -namespace torch { -namespace jit { - -#define TENSOREXPR_GTEST(name) \ - TEST(TensorExprTest, name) { \ - test##name(); \ - } -TH_FORALL_TENSOREXPR_TESTS(TENSOREXPR_GTEST) -#undef TENSOREXPR_GTEST - -#ifdef TORCH_ENABLE_LLVM -#define TENSOREXPR_GTEST_LLVM(name) \ - TEST(TensorExprTest, name##_LLVM) { \ - test##name(); \ - } -TH_FORALL_TENSOREXPR_TESTS_LLVM(TENSOREXPR_GTEST_LLVM) -#undef TENSOREXPR_GTEST_LLVM -#endif - -#ifdef USE_CUDA -#define TENSOREXPR_GTEST_CUDA(name) \ - TEST(TensorExprTest, name##_CUDA) { \ - test##name(); \ - } -TH_FORALL_TENSOREXPR_TESTS_CUDA(TENSOREXPR_GTEST_CUDA) -#undef TENSOREXPR_GTEST_CUDA -#endif - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_aten.cpp b/test/cpp/tensorexpr/test_aten.cpp index 3ccc484c84201..1183b2e85b227 100644 --- a/test/cpp/tensorexpr/test_aten.cpp +++ b/test/cpp/tensorexpr/test_aten.cpp @@ -2,6 +2,8 @@ #include #include +#include + #include #include "test/cpp/tensorexpr/padded_buffer.h" #include "test/cpp/tensorexpr/test_base.h" @@ -12,16 +14,16 @@ namespace jit { using namespace torch::jit::tensorexpr; -void testATen_cast_Float() { +TEST(ATen, _cast_Float) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); + ExprHandle load_a = a_buf.load(index); ExprHandle to_float = Cast::make(kFloat, load_a); - Stmt* store_b = Store::make(b_buf, {index}, to_float, 1); + Stmt* store_b = b_buf.store({index}, to_float); Stmt* stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); @@ -31,7 +33,7 @@ void testATen_cast_Float() { a_v(i) = i; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); ir_eval(a_v, b_v); for (int i = 0; i < kTotalSize; ++i) { @@ -40,16 +42,16 @@ void testATen_cast_Float() { } } -void testATennegInt() { +TEST(ATen, negInt) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); + ExprHandle load_a = a_buf.load(index); ExprHandle to_float = Sub::make(0, load_a); - Stmt* store_b = Store::make(b_buf, {index}, to_float, 1); + Stmt* store_b = b_buf.store({index}, to_float); Stmt* stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); @@ -59,7 +61,7 @@ void testATennegInt() { a_v(i) = i; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); ir_eval(a_v, b_v); for (int i = 0; i < kTotalSize; ++i) { @@ -68,16 +70,16 @@ void testATennegInt() { } } -void testATennegFloat() { +TEST(ATen, negFloat) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); + ExprHandle load_a = a_buf.load(index); ExprHandle to_float = Sub::make(0, load_a); - Stmt* store_b = Store::make(b_buf, {index}, to_float, 1); + Stmt* store_b = b_buf.store({index}, to_float); Stmt* stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); @@ -87,7 +89,7 @@ void testATennegFloat() { a_v(i) = i; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); ir_eval(a_v, b_v); for (int i = 0; i < kTotalSize; ++i) { @@ -96,19 +98,19 @@ void testATennegFloat() { } } -void testATenaddInt() { +TEST(ATen, addInt) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); - Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kInt)); - Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}, kInt)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); + Placeholder c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kInt)); + Placeholder d_buf(BufHandle("D", {ExprHandle(kTotalSize)}, kInt)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - ExprHandle load_b = Load::make(b_buf, {index}, 1); - ExprHandle load_c = Load::make(c_buf, {index}, 1); - Stmt* store_d = Store::make(d_buf, {index}, load_a + load_b * load_c, 1); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + Stmt* store_d = d_buf.store({index}, load_a + load_b * load_c); Stmt* stmt = For::make(index, 0, kTotalSize, store_d); PaddedBuffer a_v(kTotalSize); @@ -122,7 +124,7 @@ void testATenaddInt() { c_v(i) = 3 * i + 2; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf, d_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); ir_eval(a_v, b_v, c_v, d_v); for (int i = 0; i < kTotalSize; ++i) { @@ -133,19 +135,19 @@ void testATenaddInt() { } } -void testATenaddFloat() { +TEST(ATen, addFloat) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); - Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); - Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder d_buf(BufHandle("D", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - ExprHandle load_b = Load::make(b_buf, {index}, 1); - ExprHandle load_c = Load::make(c_buf, {index}, 1); - Stmt* store_d = Store::make(d_buf, {index}, load_a + load_b * load_c, 1); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + Stmt* store_d = d_buf.store({index}, load_a + load_b * load_c); Stmt* stmt = For::make(index, 0, kTotalSize, store_d); PaddedBuffer a_v(kTotalSize); @@ -159,7 +161,7 @@ void testATenaddFloat() { c_v(i) = 3 * i + 2; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf, d_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); ir_eval(a_v, b_v, c_v, d_v); for (int i = 0; i < kTotalSize; ++i) { @@ -170,19 +172,19 @@ void testATenaddFloat() { } } -void testATensubInt() { +TEST(ATen, subInt) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); - Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kInt)); - Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}, kInt)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); + Placeholder c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kInt)); + Placeholder d_buf(BufHandle("D", {ExprHandle(kTotalSize)}, kInt)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - ExprHandle load_b = Load::make(b_buf, {index}, 1); - ExprHandle load_c = Load::make(c_buf, {index}, 1); - Stmt* store_d = Store::make(d_buf, {index}, load_a - load_b * load_c, 1); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + Stmt* store_d = d_buf.store({index}, load_a - load_b * load_c); Stmt* stmt = For::make(index, 0, kTotalSize, store_d); PaddedBuffer a_v(kTotalSize); @@ -196,7 +198,7 @@ void testATensubInt() { c_v(i) = 3 * i + 2; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf, d_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); ir_eval(a_v, b_v, c_v, d_v); for (int i = 0; i < kTotalSize; ++i) { @@ -207,19 +209,19 @@ void testATensubInt() { } } -void testATensubFloat() { +TEST(ATen, subFloat) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); - Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); - Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder d_buf(BufHandle("D", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - ExprHandle load_b = Load::make(b_buf, {index}, 1); - ExprHandle load_c = Load::make(c_buf, {index}, 1); - Stmt* store_d = Store::make(d_buf, {index}, load_a - load_b * load_c, 1); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + Stmt* store_d = d_buf.store({index}, load_a - load_b * load_c); Stmt* stmt = For::make(index, 0, kTotalSize, store_d); PaddedBuffer a_v(kTotalSize); @@ -233,7 +235,7 @@ void testATensubFloat() { c_v(i) = 3 * i + 2; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf, d_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); ir_eval(a_v, b_v, c_v, d_v); for (int i = 0; i < kTotalSize; ++i) { @@ -244,20 +246,19 @@ void testATensubFloat() { } } -void testATenlerp() { +TEST(ATen, lerp) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); - Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); - Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder d_buf(BufHandle("D", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - ExprHandle load_b = Load::make(b_buf, {index}, 1); - ExprHandle load_c = Load::make(c_buf, {index}, 1); - Stmt* store_d = - Store::make(d_buf, {index}, load_a + load_c * (load_b - load_a), 1); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + Stmt* store_d = d_buf.store({index}, load_a + load_c * (load_b - load_a)); Stmt* stmt = For::make(index, 0, kTotalSize, store_d); PaddedBuffer a_v(kTotalSize); @@ -271,7 +272,7 @@ void testATenlerp() { c_v(i) = 3 * i + 2; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf, d_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); ir_eval(a_v, b_v, c_v, d_v); for (int i = 0; i < kTotalSize; ++i) { @@ -282,22 +283,21 @@ void testATenlerp() { } } -void testATenaddcmulInt() { +TEST(ATen, addcmulInt) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); - Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kInt)); - Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}, kInt)); - Buffer e_buf(BufHandle("E", {ExprHandle(kTotalSize)}, kInt)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); + Placeholder c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kInt)); + Placeholder d_buf(BufHandle("D", {ExprHandle(kTotalSize)}, kInt)); + Placeholder e_buf(BufHandle("E", {ExprHandle(kTotalSize)}, kInt)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - ExprHandle load_b = Load::make(b_buf, {index}, 1); - ExprHandle load_c = Load::make(c_buf, {index}, 1); - ExprHandle load_d = Load::make(d_buf, {index}, 1); - Stmt* store_e = - Store::make(e_buf, {index}, load_a + load_b * load_c * load_d, 1); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + ExprHandle load_d = d_buf.load(index); + Stmt* store_e = e_buf.store({index}, load_a + load_b * load_c * load_d); Stmt* stmt = For::make(index, 0, kTotalSize, store_e); PaddedBuffer a_v(kTotalSize); @@ -313,7 +313,7 @@ void testATenaddcmulInt() { d_v(i) = 5 * i + 3; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf, d_buf, e_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf, e_buf}); ir_eval(a_v, b_v, c_v, d_v, e_v); for (int i = 0; i < kTotalSize; ++i) { @@ -325,22 +325,21 @@ void testATenaddcmulInt() { } } -void testATenaddcmulFloat() { +TEST(ATen, addcmulFloat) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); - Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); - Buffer d_buf(BufHandle("D", {ExprHandle(kTotalSize)}, kFloat)); - Buffer e_buf(BufHandle("E", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder d_buf(BufHandle("D", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder e_buf(BufHandle("E", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - ExprHandle load_b = Load::make(b_buf, {index}, 1); - ExprHandle load_c = Load::make(c_buf, {index}, 1); - ExprHandle load_d = Load::make(d_buf, {index}, 1); - Stmt* store_e = - Store::make(e_buf, {index}, load_a + load_b * load_c * load_d, 1); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + ExprHandle load_d = d_buf.load(index); + Stmt* store_e = e_buf.store({index}, load_a + load_b * load_c * load_d); Stmt* stmt = For::make(index, 0, kTotalSize, store_e); PaddedBuffer a_v(kTotalSize); @@ -356,7 +355,7 @@ void testATenaddcmulFloat() { d_v(i) = 5 * i + 3; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf, d_buf, e_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf, e_buf}); ir_eval(a_v, b_v, c_v, d_v, e_v); for (int i = 0; i < kTotalSize; ++i) { @@ -368,17 +367,17 @@ void testATenaddcmulFloat() { } } -void testATenmulInt() { +TEST(ATen, mulInt) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); - Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kInt)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); + Placeholder c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kInt)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - ExprHandle load_b = Load::make(b_buf, {index}, 1); - Stmt* store_c = Store::make(c_buf, {index}, load_a * load_b, 1); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + Stmt* store_c = c_buf.store({index}, load_a * load_b); Stmt* stmt = For::make(index, 0, kTotalSize, store_c); PaddedBuffer a_v(kTotalSize); @@ -390,7 +389,7 @@ void testATenmulInt() { b_v(i) = 2 * i + 1; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); ir_eval(a_v, b_v, c_v); for (int i = 0; i < kTotalSize; ++i) { @@ -400,17 +399,17 @@ void testATenmulInt() { } } -void testATenmulFloat() { +TEST(ATen, mulFloat) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); - Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - ExprHandle load_b = Load::make(b_buf, {index}, 1); - Stmt* store_c = Store::make(c_buf, {index}, load_a * load_b, 1); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + Stmt* store_c = c_buf.store({index}, load_a * load_b); Stmt* stmt = For::make(index, 0, kTotalSize, store_c); PaddedBuffer a_v(kTotalSize); @@ -422,7 +421,7 @@ void testATenmulFloat() { b_v(i) = 2 * i + 1; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); ir_eval(a_v, b_v, c_v); for (int i = 0; i < kTotalSize; ++i) { @@ -432,17 +431,17 @@ void testATenmulFloat() { } } -void testATendivInt() { +TEST(ATen, divInt) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); - Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kInt)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); + Placeholder c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kInt)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - ExprHandle load_b = Load::make(b_buf, {index}, 1); - Stmt* store_c = Store::make(c_buf, {index}, load_a / load_b, 1); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + Stmt* store_c = c_buf.store({index}, load_a / load_b); Stmt* stmt = For::make(index, 0, kTotalSize, store_c); PaddedBuffer a_v(kTotalSize); @@ -454,7 +453,7 @@ void testATendivInt() { b_v(i) = i + 1; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); ir_eval(a_v, b_v, c_v); for (int i = 0; i < kTotalSize; ++i) { @@ -464,17 +463,17 @@ void testATendivInt() { } } -void testATendivFloat() { +TEST(ATen, divFloat) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); - Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - ExprHandle load_b = Load::make(b_buf, {index}, 1); - Stmt* store_c = Store::make(c_buf, {index}, load_a / load_b, 1); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + Stmt* store_c = c_buf.store({index}, load_a / load_b); Stmt* stmt = For::make(index, 0, kTotalSize, store_c); PaddedBuffer a_v(kTotalSize); @@ -486,7 +485,7 @@ void testATendivFloat() { b_v(i) = i + 1; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); ir_eval(a_v, b_v, c_v); for (int i = 0; i < kTotalSize; ++i) { @@ -496,18 +495,17 @@ void testATendivFloat() { } } -void testATenmaxInt() { +TEST(ATen, maxInt) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); - Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kInt)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); + Placeholder c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kInt)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - ExprHandle load_b = Load::make(b_buf, {index}, 1); - Stmt* store_c = - Store::make(c_buf, {index}, Max::make(load_a, load_b, true), 1); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + Stmt* store_c = c_buf.store({index}, Max::make(load_a, load_b, true)); Stmt* stmt = For::make(index, 0, kTotalSize, store_c); PaddedBuffer a_v(kTotalSize); @@ -519,7 +517,7 @@ void testATenmaxInt() { b_v(i) = 2 * i + 1; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); ir_eval(a_v, b_v, c_v); for (int i = 0; i < kTotalSize; ++i) { @@ -529,18 +527,17 @@ void testATenmaxInt() { } } -void testATenmaxFloat() { +TEST(ATen, maxFloat) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); - Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - ExprHandle load_b = Load::make(b_buf, {index}, 1); - Stmt* store_c = - Store::make(c_buf, {index}, Max::make(load_a, load_b, true), 1); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + Stmt* store_c = c_buf.store({index}, Max::make(load_a, load_b, true)); Stmt* stmt = For::make(index, 0, kTotalSize, store_c); PaddedBuffer a_v(kTotalSize); @@ -552,7 +549,7 @@ void testATenmaxFloat() { b_v(i) = 2 * i + 1; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); ir_eval(a_v, b_v, c_v); for (int i = 0; i < kTotalSize; ++i) { @@ -562,18 +559,17 @@ void testATenmaxFloat() { } } -void testATenminInt() { +TEST(ATen, minInt) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); - Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kInt)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); + Placeholder c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kInt)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - ExprHandle load_b = Load::make(b_buf, {index}, 1); - Stmt* store_c = - Store::make(c_buf, {index}, Min::make(load_a, load_b, true), 1); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + Stmt* store_c = c_buf.store({index}, Min::make(load_a, load_b, true)); Stmt* stmt = For::make(index, 0, kTotalSize, store_c); PaddedBuffer a_v(kTotalSize); @@ -585,7 +581,7 @@ void testATenminInt() { b_v(i) = 2 * i + 1; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); ir_eval(a_v, b_v, c_v); for (int i = 0; i < kTotalSize; ++i) { @@ -595,18 +591,17 @@ void testATenminInt() { } } -void testATenminFloat() { +TEST(ATen, minFloat) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); - Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - ExprHandle load_b = Load::make(b_buf, {index}, 1); - Stmt* store_c = - Store::make(c_buf, {index}, Min::make(load_a, load_b, true), 1); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + Stmt* store_c = c_buf.store({index}, Min::make(load_a, load_b, true)); Stmt* stmt = For::make(index, 0, kTotalSize, store_c); PaddedBuffer a_v(kTotalSize); @@ -618,7 +613,7 @@ void testATenminFloat() { b_v(i) = 2 * i + 1; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); ir_eval(a_v, b_v, c_v); for (int i = 0; i < kTotalSize; ++i) { @@ -631,12 +626,12 @@ void testATenminFloat() { void __ubsan_ignore_float_divide_by_zero__ testATenreciprocal() { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - Stmt* store_b = Store::make(b_buf, {index}, FloatImm::make(1.0f) / load_a, 1); + ExprHandle load_a = a_buf.load(index); + Stmt* store_b = b_buf.store({index}, FloatImm::make(1.0f) / load_a); Stmt* stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); @@ -646,7 +641,7 @@ void __ubsan_ignore_float_divide_by_zero__ testATenreciprocal() { a_v(i) = i; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); ir_eval(a_v, b_v); for (int i = 0; i < kTotalSize; ++i) { @@ -655,15 +650,15 @@ void __ubsan_ignore_float_divide_by_zero__ testATenreciprocal() { } } -void testATenreluInt() { +TEST(ATen, reluInt) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - Stmt* store_b = Store::make(b_buf, {index}, Max::make(load_a, 0, false), 1); + ExprHandle load_a = a_buf.load(index); + Stmt* store_b = b_buf.store({index}, Max::make(load_a, 0, false)); Stmt* stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); @@ -673,7 +668,7 @@ void testATenreluInt() { a_v(i) = i - 64; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); ir_eval(a_v, b_v); for (int i = 0; i < kTotalSize; ++i) { @@ -682,19 +677,17 @@ void testATenreluInt() { } } -void testATenreluFloat() { +TEST(ATen, reluFloat) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - Stmt* store_b = Store::make( - b_buf, - {index}, - Max::make(load_a, 0, false), // relu does not propagate nans - 1); + ExprHandle load_a = a_buf.load(index); + Stmt* store_b = b_buf.store( + {index}, Max::make(load_a, 0, false) // relu does not propagate nans + ); Stmt* stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); @@ -704,7 +697,7 @@ void testATenreluFloat() { a_v(i) = i - 64; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); ir_eval(a_v, b_v); for (int i = 0; i < kTotalSize; ++i) { @@ -713,15 +706,15 @@ void testATenreluFloat() { } } -void testATenlogFloat() { +TEST(ATen, logFloat) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - Stmt* store_b = Store::make(b_buf, {index}, log(load_a), 1); + ExprHandle load_a = a_buf.load(index); + Stmt* store_b = b_buf.store({index}, log(load_a)); Stmt* stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); @@ -731,7 +724,7 @@ void testATenlogFloat() { a_v(i) = i + 10; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); ir_eval(a_v, b_v); for (int i = 0; i < kTotalSize; ++i) { @@ -740,15 +733,47 @@ void testATenlogFloat() { } } -void testATenlog10Float() { +TEST(ATen, fastLogFloat) { + KernelScope kernel_scope; + const int kTotalSize = 128 * 128; + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + Stmt* store_b = b_buf.store({index}, fast_log(load_a)); + Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = at::randn({1}).item().to(); + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (int i = 0; i < kTotalSize; ++i) { + auto test = b_v(i); + auto ref = std::log(a_v(i)); + if (std::isnan(ref)) { + ASSERT_EQ(std::isnan(test), true); + } else { + ASSERT_FLOAT_EQ(test, ref); + } + } +} + +TEST(ATen, log10Float) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - Stmt* store_b = Store::make(b_buf, {index}, log10(load_a), 1); + ExprHandle load_a = a_buf.load(index); + Stmt* store_b = b_buf.store({index}, log10(load_a)); Stmt* stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); @@ -758,7 +783,7 @@ void testATenlog10Float() { a_v(i) = i + 10; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); ir_eval(a_v, b_v); for (int i = 0; i < kTotalSize; ++i) { @@ -767,15 +792,15 @@ void testATenlog10Float() { } } -void testATenlog2Float() { +TEST(ATen, log2Float) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - Stmt* store_b = Store::make(b_buf, {index}, log2(load_a), 1); + ExprHandle load_a = a_buf.load(index); + Stmt* store_b = b_buf.store({index}, log2(load_a)); Stmt* stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); @@ -785,7 +810,7 @@ void testATenlog2Float() { a_v(i) = i + 10; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); ir_eval(a_v, b_v); for (int i = 0; i < kTotalSize; ++i) { @@ -794,15 +819,15 @@ void testATenlog2Float() { } } -void testATenexpFloat() { +TEST(ATen, expFloat) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - Stmt* store_b = Store::make(b_buf, {index}, exp(load_a), 1); + ExprHandle load_a = a_buf.load(index); + Stmt* store_b = b_buf.store({index}, exp(load_a)); Stmt* stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); @@ -812,7 +837,7 @@ void testATenexpFloat() { a_v(i) = i / 10.0f; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); ir_eval(a_v, b_v); for (int i = 0; i < kTotalSize; ++i) { @@ -821,15 +846,15 @@ void testATenexpFloat() { } } -void testATenerfFloat() { +TEST(ATen, erfFloat) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - Stmt* store_b = Store::make(b_buf, {index}, erf(load_a), 1); + ExprHandle load_a = a_buf.load(index); + Stmt* store_b = b_buf.store({index}, erf(load_a)); Stmt* stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); @@ -839,7 +864,7 @@ void testATenerfFloat() { a_v(i) = i / 10.0f; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); ir_eval(a_v, b_v); for (int i = 0; i < kTotalSize; ++i) { @@ -848,15 +873,15 @@ void testATenerfFloat() { } } -void testATencosFloat() { +TEST(ATen, cosFloat) { KernelScope kernel_scope; const int kTotalSize = 128; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make(a_buf, {index}, 1); - Stmt* store_b = Store::make(b_buf, {index}, cos(load_a), 1); + ExprHandle load_a = a_buf.load(index); + Stmt* store_b = b_buf.store({index}, cos(load_a)); Stmt* stmt = For::make(index, 0, kTotalSize, store_b); PaddedBuffer a_v(kTotalSize); @@ -866,7 +891,7 @@ void testATencosFloat() { a_v(i) = i / 10.0f; } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); ir_eval(a_v, b_v); for (int i = 0; i < kTotalSize; ++i) { @@ -875,156 +900,131 @@ void testATencosFloat() { } } -void testATeneqInt() { +TEST(ATen, eqInt) { KernelScope kernel_scope; constexpr int N = 128; - Buffer a(BufHandle("A", {N}, kInt)); - Buffer b(BufHandle("B", {N}, kInt)); - Buffer c(BufHandle("C", {N}, kInt)); + Placeholder a(BufHandle("A", {N}, kInt)); + Placeholder b(BufHandle("B", {N}, kInt)); + Placeholder c(BufHandle("C", {N}, kInt)); std::vector a_buffer(N, 1); std::vector b_buffer(N, 1); std::vector c_buffer(N, 0); - auto mask = IntImm::make(1); VarHandle i("i", kInt); auto memcpy_expr = For::make( i, 0, N, - Store::make( - c, + c.store( {i}, CompareSelect::make( - Load::make(a, {i}, mask), - Load::make(b, {i}, mask), - CompareSelectOperation::kEQ), - mask)); + a.load(i), b.load(i), CompareSelectOperation::kEQ))); - SimpleIREvaluator ir_eval(memcpy_expr, a, b, c); + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); ir_eval(a_buffer, b_buffer, c_buffer); assertAllEqual(c_buffer, 1); } -void testATengeInt() { +TEST(ATen, geInt) { KernelScope kernel_scope; constexpr int N = 128; - Buffer a(BufHandle("A", {N}, kInt)); - Buffer b(BufHandle("B", {N}, kInt)); - Buffer c(BufHandle("C", {N}, kInt)); + Placeholder a(BufHandle("A", {N}, kInt)); + Placeholder b(BufHandle("B", {N}, kInt)); + Placeholder c(BufHandle("C", {N}, kInt)); std::vector a_buffer(N, 5); std::vector b_buffer(N, 5); std::vector c_buffer(N, 0); - auto mask = IntImm::make(1); VarHandle i("i", kInt); auto memcpy_expr = For::make( i, 0, N, - Store::make( - c, + c.store( {i}, CompareSelect::make( - Load::make(a, {i}, mask), - Load::make(b, {i}, mask), - CompareSelectOperation::kGE), - mask)); + a.load(i), b.load(i), CompareSelectOperation::kGE))); - SimpleIREvaluator ir_eval(memcpy_expr, a, b, c); + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); ir_eval(a_buffer, b_buffer, c_buffer); assertAllEqual(c_buffer, 1); } -void testATengtInt() { +TEST(ATen, gtInt) { KernelScope kernel_scope; constexpr int N = 128; - Buffer a(BufHandle("A", {N}, kInt)); - Buffer b(BufHandle("B", {N}, kInt)); - Buffer c(BufHandle("C", {N}, kInt)); + Placeholder a(BufHandle("A", {N}, kInt)); + Placeholder b(BufHandle("B", {N}, kInt)); + Placeholder c(BufHandle("C", {N}, kInt)); std::vector a_buffer(N, 6); std::vector b_buffer(N, 3); std::vector c_buffer(N, 0); - auto mask = IntImm::make(1); VarHandle i("i", kInt); auto memcpy_expr = For::make( i, 0, N, - Store::make( - c, + c.store( {i}, CompareSelect::make( - Load::make(a, {i}, mask), - Load::make(b, {i}, mask), - CompareSelectOperation::kGT), - mask)); + a.load(i), b.load(i), CompareSelectOperation::kGT))); - SimpleIREvaluator ir_eval(memcpy_expr, a, b, c); + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); ir_eval(a_buffer, b_buffer, c_buffer); assertAllEqual(c_buffer, 1); } -void testATenleInt() { +TEST(ATen, leInt) { KernelScope kernel_scope; constexpr int N = 128; - Buffer a(BufHandle("A", {N}, kInt)); - Buffer b(BufHandle("B", {N}, kInt)); - Buffer c(BufHandle("C", {N}, kInt)); + Placeholder a(BufHandle("A", {N}, kInt)); + Placeholder b(BufHandle("B", {N}, kInt)); + Placeholder c(BufHandle("C", {N}, kInt)); std::vector a_buffer(N, 5); std::vector b_buffer(N, 5); std::vector c_buffer(N, 0); - auto mask = IntImm::make(1); VarHandle i("i", kInt); auto memcpy_expr = For::make( i, 0, N, - Store::make( - c, + c.store( {i}, CompareSelect::make( - Load::make(a, {i}, mask), - Load::make(b, {i}, mask), - CompareSelectOperation::kLE), - mask)); + a.load(i), b.load(i), CompareSelectOperation::kLE))); - SimpleIREvaluator ir_eval(memcpy_expr, a, b, c); + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); ir_eval(a_buffer, b_buffer, c_buffer); assertAllEqual(c_buffer, 1); } -void testATenltInt() { +TEST(ATen, ltInt) { KernelScope kernel_scope; constexpr int N = 128; - Buffer a(BufHandle("A", {N}, kInt)); - Buffer b(BufHandle("B", {N}, kInt)); - Buffer c(BufHandle("C", {N}, kInt)); + Placeholder a(BufHandle("A", {N}, kInt)); + Placeholder b(BufHandle("B", {N}, kInt)); + Placeholder c(BufHandle("C", {N}, kInt)); std::vector a_buffer(N, 5); std::vector b_buffer(N, 5); std::vector c_buffer(N, 1); - auto mask = IntImm::make(1); VarHandle i("i", kInt); auto memcpy_expr = For::make( i, 0, N, - Store::make( - c, + c.store( {i}, CompareSelect::make( - Load::make(a, {i}, mask), - Load::make(b, {i}, mask), - CompareSelectOperation::kLT), - mask)); + a.load(i), b.load(i), CompareSelectOperation::kLT))); - SimpleIREvaluator ir_eval(memcpy_expr, a, b, c); + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); ir_eval(a_buffer, b_buffer, c_buffer); assertAllEqual(c_buffer, 0); diff --git a/test/cpp/tensorexpr/test_boundsinference.cpp b/test/cpp/tensorexpr/test_boundsinference.cpp index 98d3d4127da8b..7d1c0820ab18d 100644 --- a/test/cpp/tensorexpr/test_boundsinference.cpp +++ b/test/cpp/tensorexpr/test_boundsinference.cpp @@ -1,15 +1,14 @@ -#include #include #include #include #include +#include + #include #include #include -#include #include -#include #include #include #include @@ -41,7 +40,7 @@ static void verifyConstBounds( } } -void testBoundsInference_1() { +TEST(BoundsInference, _1) { // Verify that bounds inference works for the following example: // for i in 0..100: // b[i] = a[i] @@ -49,9 +48,9 @@ void testBoundsInference_1() { // {{b, kStore, 0, 99}, {a, kLoad, 0, 99}} KernelScope kernel_scope; ExprHandle n(100); - Buffer a(BufHandle("a", {n}, kFloat)); + Placeholder a(BufHandle("a", {n}, kFloat)); Tensor* b = - Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a(i); }); + Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a.load(i); }); LoopNest l({b}); auto bounds_info = inferBounds(l.root_stmt()); @@ -66,7 +65,7 @@ void testBoundsInference_1() { verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 99}}); } -void testBoundsInference_2() { +TEST(BoundsInference, _2) { // Verify that bounds inference works for the following example: // for i in 0..n: // b[i] = a[i] @@ -74,9 +73,9 @@ void testBoundsInference_2() { // {{b, kStore, 0, n-1}, {a, kLoad, 0, n-1}} KernelScope kernel_scope; VarHandle n("n", kInt); - Buffer a(BufHandle("a", {n}, kFloat)); + Placeholder a(BufHandle("a", {n}, kFloat)); Tensor* b = - Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a(i); }); + Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a.load(i); }); LoopNest l({b}); auto bounds_info = inferBounds(l.root_stmt()); @@ -91,7 +90,7 @@ void testBoundsInference_2() { verifyConstBounds(bounds_info.at(b->buf())[0], {{0, -1}}); } -void testBoundsInference_3() { +TEST(BoundsInference, _3) { // Verify that bounds inference works for the following example: // for i in 0..100: // b[i] = a[i] * a[i+10] @@ -99,9 +98,10 @@ void testBoundsInference_3() { // {{b, kStore, 0, 99}, {a, kLoad, 0, 109}} KernelScope kernel_scope; ExprHandle n(100); - Buffer a(BufHandle("a", {n + 10}, kFloat)); - Tensor* b = Compute( - "b", {{n, "i"}}, [&](const VarHandle& i) { return a(i) * a(i + 10); }); + Placeholder a(BufHandle("a", {n + 10}, kFloat)); + Tensor* b = Compute("b", {{n, "i"}}, [&](const VarHandle& i) { + return a.load(i) * a.load(i + 10); + }); LoopNest l({b}); auto bounds_info = inferBounds(l.root_stmt()); @@ -116,7 +116,7 @@ void testBoundsInference_3() { verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 99}}); } -void testBoundsInference_4() { +TEST(BoundsInference, _4) { // Verify that bounds inference works for the following example: // // for y in 0..200: @@ -128,14 +128,14 @@ void testBoundsInference_4() { KernelScope kernel_scope; ExprHandle W(320); ExprHandle H(200); - Buffer a(BufHandle("a", {H, W}, kFloat)); + Placeholder a(BufHandle("a", {H, W}, kFloat)); Tensor* b = Compute( "b", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) { return x * y; }); Tensor* c = Compute( "c", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) { - return a(y, x) * b->call(y, x); + return a.load(y, x) * b->call(y, x); }); LoopNest l({c}); std::vector loops = l.getLoopStmtsFor(c); @@ -193,7 +193,7 @@ void testBoundsInference_4() { } } -void testBoundsInference_5() { +TEST(BoundsInference, _5) { // Verify that bounds inference works for the following example: // for i in 0..100: // b[i] = a[i] @@ -207,9 +207,9 @@ void testBoundsInference_5() { // b[i_tail + (100/16)*16] = a[i_tail + (100/16)*16]; KernelScope kernel_scope; ExprHandle n(100); - Buffer a(BufHandle("a", {n}, kFloat)); + Placeholder a(BufHandle("a", {n}, kFloat)); Tensor* b = - Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a(i); }); + Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a.load(i); }); LoopNest l({b}); For* outer; @@ -246,7 +246,7 @@ void testBoundsInference_5() { } } -void testBoundsInference_6() { +TEST(BoundsInference, _6) { // Verify that bounds inference works for the following example: // // for y in 0..200: @@ -260,14 +260,14 @@ void testBoundsInference_6() { ExprHandle H(200); ExprHandle CW(32); ExprHandle CH(20); - Buffer a(BufHandle("a", {H, W}, kFloat)); + Placeholder a(BufHandle("a", {H, W}, kFloat)); Tensor* b = Compute( "b", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) { return x * y; }); Tensor* c = Compute( "c", {{CH, "y"}, {CW, "x"}}, [&](const VarHandle& y, const VarHandle& x) { - return a(y + 100, x + 100) * b->call(y * 2, x * 5); + return a.load(y + 100, x + 100) * b->call(y * 2, x * 5); }); LoopNest l({c}); std::vector loops = l.getLoopStmtsFor(c); @@ -325,14 +325,14 @@ void testBoundsInference_6() { } } -void testBoundsInferenceNonOverlapping() { +TEST(BoundsInference, Adjacent) { KernelScope kernel_scope; - ExprHandle H(3); - Buffer a(BufHandle("a", {10}, kFloat)); + ExprHandle H(6); + Placeholder a(BufHandle("a", {20}, kFloat)); Tensor* b = - Compute("b", {{H, "x"}}, [&](const VarHandle& x) { return a(x); }); + Compute("b", {{H, "x"}}, [&](const VarHandle& x) { return a.load(x); }); Tensor* c = Compute( - "c", {{H, "x"}}, [&](const VarHandle& x) { return a(x + H + 1); }); + "c", {{H, "x"}}, [&](const VarHandle& x) { return a.load(x + H); }); LoopNest l({b, c}); std::vector loops = NodeFinder::find(l.root_stmt()); @@ -341,487 +341,388 @@ void testBoundsInferenceNonOverlapping() { auto bounds_info = inferBounds(loops[0]); ASSERT_EQ(bounds_info.size(), 2); - // reads from a[0:2], writes to b[0:2] + // reads from a[0:5], writes to b[0:5] ASSERT_EQ(bounds_info.at(a.data()).size(), 1); ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.data())[0], {{0, 2}}); + verifyConstBounds(bounds_info.at(a.data())[0], {{0, 5}}); ASSERT_EQ(bounds_info.at(b->buf()).size(), 1); ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 2}}); + verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 5}}); } { // Infer bounds on the inner loop scope auto bounds_info = inferBounds(loops[1]); ASSERT_EQ(bounds_info.size(), 2); - // reads from a[0+4:2+4], writes to c[0:2] + // reads from a[0+6:5+6], writes to c[0:5] ASSERT_EQ(bounds_info.at(a.data()).size(), 1); ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.data())[0], {{4, 6}}); + verifyConstBounds(bounds_info.at(a.data())[0], {{6, 11}}); ASSERT_EQ(bounds_info.at(c->buf()).size(), 1); ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 2}}); + verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 5}}); } { // Infer bounds on the high level program. auto bounds_info = inferBounds(l.root_stmt()); ASSERT_EQ(bounds_info.size(), 3); - // Should be union of above 2 bounds. - ASSERT_EQ(bounds_info.at(a.data()).size(), 2); - ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.data())[0], {{0, 2}}); + // Should be union of above 2 bounds, but this time the bounds of A can be + // merged. + ASSERT_EQ(bounds_info.at(a.data()).size(), 1); ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.data())[1], {{4, 6}}); + verifyConstBounds(bounds_info.at(a.data())[0], {{0, 11}}); ASSERT_EQ(bounds_info.at(b->buf()).size(), 1); ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 2}}); + verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 5}}); ASSERT_EQ(bounds_info.at(c->buf()).size(), 1); ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 2}}); + verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 5}}); } } -void testBoundsInferenceAdjacent() { +TEST(BoundsInference, MultipleTopLoopLoad) { KernelScope kernel_scope; - ExprHandle H(6); - Buffer a(BufHandle("a", {20}, kFloat)); + Placeholder a(BufHandle("a", {100}, kFloat)); Tensor* b = - Compute("b", {{H, "x"}}, [&](const VarHandle& x) { return a(x); }); - Tensor* c = - Compute("c", {{H, "x"}}, [&](const VarHandle& x) { return a(x + H); }); - LoopNest l({b, c}); - std::vector loops = NodeFinder::find(l.root_stmt()); + Compute("b", {{64, "x"}}, [&](const VarHandle& x) { return a.load(x); }); + Tensor* c = Compute( + "c", {{32, "x"}}, [&](const VarHandle& x) { return a.load(x + 10); }); + Tensor* d = Compute( + "d", {{96, "x"}}, [&](const VarHandle& x) { return a.load(x + 2); }); + LoopNest l({b, c, d}); - { - // Infer bounds on the top-level loop scope - auto bounds_info = inferBounds(loops[0]); - ASSERT_EQ(bounds_info.size(), 2); + auto bounds_info = inferBounds(l.root_stmt()); - // reads from a[0:5], writes to b[0:5] - ASSERT_EQ(bounds_info.at(a.data()).size(), 1); - ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.data())[0], {{0, 5}}); + ASSERT_EQ(bounds_info.size(), 4); - ASSERT_EQ(bounds_info.at(b->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 5}}); - } + // a only read. { - // Infer bounds on the inner loop scope - auto bounds_info = inferBounds(loops[1]); - ASSERT_EQ(bounds_info.size(), 2); - - // reads from a[0+6:5+6], writes to c[0:5] - ASSERT_EQ(bounds_info.at(a.data()).size(), 1); - ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.data())[0], {{6, 11}}); + auto bounds = bounds_info[a.data()]; + ASSERT_EQ(bounds.size(), 1); + // One dimension. + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kLoad); + // Bounds: + // start: Min of the 3 load bounds = Min of loop starts + offset = 0+0 (b). + // stop: Max of the 3 load bounds = Max of loop stops + offset - 1 = + // 96 + 2 - 1 (d). + verifyConstBounds(bound, {{0, 97}}); + } - ASSERT_EQ(bounds_info.at(c->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 5}}); + // b, c, d only written. + { + auto bounds = bounds_info[b->buf()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // Just the loop extents for b. + verifyConstBounds(bound, {{0, 63}}); } { - // Infer bounds on the high level program. - auto bounds_info = inferBounds(l.root_stmt()); - ASSERT_EQ(bounds_info.size(), 3); + auto bounds = bounds_info[c->buf()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // Just the loop extents for c. + verifyConstBounds(bound, {{0, 31}}); + } + { + auto bounds = bounds_info[d->buf()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // Just the loop extents for d. + verifyConstBounds(bound, {{0, 95}}); + } +} - // Should be union of above 2 bounds, but this time the bounds of A can be - // merged. - ASSERT_EQ(bounds_info.at(a.data()).size(), 1); - ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.data())[0], {{0, 11}}); +TEST(BoundsInference, MultipleTopLoopStore) { + KernelScope kernel_scope; + BufHandle a("a", {100}, kFloat); + BufHandle b("b", {100}, kFloat); + BufHandle c("c", {100}, kFloat); + BufHandle d("d", {100}, kFloat); + VarHandle x("x", kInt); - ASSERT_EQ(bounds_info.at(b->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 5}}); + // Same as above but the offsets are on the Store now. + // Can't do this through ComputeAPI without transforms we don't have yet. + Stmt* stmt = Block::make( + {For::make(x, 0, 64, Store::make(b, {x}, Load::make(a, {x}, 1), 1)), + For::make(x, 0, 32, Store::make(c, {x + 10}, Load::make(a, {x}, 1), 1)), + For::make(x, 0, 96, Store::make(d, {x + 2}, Load::make(a, {x}, 1), 1))}); - ASSERT_EQ(bounds_info.at(c->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 5}}); + auto bounds_info = inferBounds(stmt); + + ASSERT_EQ(bounds_info.size(), 4); + + // a only read. + { + auto bounds = bounds_info[a.node()]; + ASSERT_EQ(bounds.size(), 1); + // One dimension. + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kLoad); + // Bounds: there are no offsets, so this is just the max loop bounds. + verifyConstBounds(bound, {{0, 95}}); + } + + // b, c, d only written. + { + auto bounds = bounds_info[b.node()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // This should be equivalent to {offset, extent + offset} for the b loop. + // b loop has no offset, so just the loop extents. + verifyConstBounds(bound, {{0, 63}}); + } + { + auto bounds = bounds_info[c.node()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // This should be equivalent to {offset, extent + offset} for the c loop. + // Offset is 10, extent is 32-1. + verifyConstBounds(bound, {{10, 41}}); + } + { + auto bounds = bounds_info[d.node()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // This should be equivalent to {offset, extent + offset} for the d loop. + // Offset is 2, extent is 96-1. + verifyConstBounds(bound, {{2, 97}}); } } -void testMergeInferredBounds() { +TEST(BoundsInference, CacheReads) { KernelScope kernel_scope; - Buffer a(BufHandle("a", {10}, kFloat)); - - // There are seven cases to consider in mergeTensorAccesses(A, B) - // * A is lower than B and does not overlap. - // * A is higher than B and does not overlap. - // * A overlaps B on both ends. - // * B overlaps A on both ends. - // * A overlaps B on the lower end. (equiv to B overlaps A on upper end). - // * A overlaps B on the upper end. (likewise covers reverse) - // * A and B are the same range. - - BoundsInfo info; - // Test no overlap, both ways. - info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(3)}}); - info[a.data()].push_back({kLoad, {new IntImm(5)}, {new IntImm(7)}}); - info[a.data()].push_back({kLoad, {new IntImm(9)}, {new IntImm(9)}}); - BoundsInfo res = mergeTensorAccesses(info); - ASSERT_EQ(res.size(), 1); - ASSERT_EQ(res[a.data()].size(), 3); - - ASSERT_EQ(res.at(a.data())[0].kind, kLoad); - ASSERT_EQ(res.at(a.data())[1].kind, kLoad); - ASSERT_EQ(res.at(a.data())[2].kind, kLoad); - verifyConstBounds(res.at(a.data())[0], {{1, 3}}); - verifyConstBounds(res.at(a.data())[1], {{5, 7}}); - verifyConstBounds(res.at(a.data())[2], {{9, 9}}); - - // Test full overlap, A over B. - info.clear(); - info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(7)}}); - info[a.data()].push_back({kLoad, {new IntImm(3)}, {new IntImm(6)}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 1); - verifyConstBounds(res.at(a.data())[0], {{1, 7}}); - - // B over A. - info.clear(); - info[a.data()].push_back({kLoad, {new IntImm(3)}, {new IntImm(6)}}); - info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(7)}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 1); - verifyConstBounds(res.at(a.data())[0], {{1, 7}}); - - // Test partial overlap on the low end, A over B. - info.clear(); - info[a.data()].push_back({kLoad, {new IntImm(5)}, {new IntImm(7)}}); - info[a.data()].push_back({kLoad, {new IntImm(3)}, {new IntImm(6)}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 1); - verifyConstBounds(res.at(a.data())[0], {{3, 7}}); - - // Test partial overlap on the high end. - info.clear(); - info[a.data()].push_back({kLoad, {new IntImm(2)}, {new IntImm(5)}}); - info[a.data()].push_back({kLoad, {new IntImm(4)}, {new IntImm(6)}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 1); - verifyConstBounds(res.at(a.data())[0], {{2, 6}}); - - // Test equality is deduped. - info.clear(); - info[a.data()].push_back({kLoad, {new IntImm(4)}, {new IntImm(6)}}); - info[a.data()].push_back({kLoad, {new IntImm(4)}, {new IntImm(6)}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 1); - verifyConstBounds(res.at(a.data())[0], {{4, 6}}); + + Tensor* A = Compute( + "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor* B = Compute( + "B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 30, j + 3); + }); + Tensor* C = Compute( + "C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 10, j + 20) + A->call(i + 30, j + 40); + }); + + LoopNest l({B, C}); + auto bounds_info_before = inferBounds(l.root_stmt()); + + Stmt* j_loop = l.getLoopStmtsFor(B)[1]; + l.cacheAccesses(A->buf(), "A_local", j_loop); + + auto bounds_info_after = inferBounds(l.root_stmt()); + + // CacheAccesses should not change existing bounds, but add a new one for the + // cache. + for (auto& pair : bounds_info_after) { + auto beforeIt = bounds_info_before.find(pair.first); + if (beforeIt != bounds_info_before.end()) { + // Same number of TensorAccessBoundInfos. + ASSERT_EQ(pair.second.size(), beforeIt->second.size()); + + for (size_t i = 0; i < pair.second.size(); ++i) { + TensorAccessBoundsInfo& after = pair.second[i]; + TensorAccessBoundsInfo& before = beforeIt->second[i]; + // Same number of dimensions. + ASSERT_EQ(before.start.size(), after.start.size()); + + // Bounds are equal. + for (size_t j = 0; j < before.start.size(); ++j) { + ASSERT_TRUE(exprEquals(before.start[j], after.start[j])); + ASSERT_TRUE(exprEquals(before.stop[j], after.stop[j])); + } + } + } else { + // This should be the cache. + ASSERT_EQ(pair.first->name_hint(), "A_local"); + // Should have both a load and a store. + ASSERT_EQ(pair.second.size(), 2); + TensorAccessBoundsInfo& first = pair.second[0]; + TensorAccessBoundsInfo& second = pair.second[1]; + + ASSERT_NE(first.kind, second.kind); + // 2 dimensions. + ASSERT_EQ(first.start.size(), second.start.size()); + ASSERT_EQ(first.start.size(), 2); + + // bounds for load and store are equal. + for (size_t j = 0; j < first.start.size(); ++j) { + ASSERT_TRUE(exprEquals(first.start[j], second.start[j])); + ASSERT_TRUE(exprEquals(first.stop[j], second.stop[j])); + } + } + } } -void testMergeInferredLoadStoreDiff() { +TEST(BoundsInference, Flattened) { KernelScope kernel_scope; - Buffer a(BufHandle("a", {10}, kFloat)); - - // Loads and Stores do not merge: - BoundsInfo info; - info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(7)}}); - info[a.data()].push_back({kStore, {new IntImm(3)}, {new IntImm(9)}}); - - BoundsInfo res = mergeTensorAccesses(info); - ASSERT_EQ(res.size(), 1); - ASSERT_EQ(res[a.data()].size(), 2); - ASSERT_EQ(res.at(a.data())[0].kind, kLoad); - ASSERT_EQ(res.at(a.data())[1].kind, kStore); - verifyConstBounds(res.at(a.data())[0], {{1, 7}}); - verifyConstBounds(res.at(a.data())[1], {{3, 9}}); - - // Do merge around the other kind of access: - info.clear(); - info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(3)}}); - info[a.data()].push_back({kStore, {new IntImm(3)}, {new IntImm(4)}}); - info[a.data()].push_back({kLoad, {new IntImm(3)}, {new IntImm(5)}}); - info[a.data()].push_back({kStore, {new IntImm(4)}, {new IntImm(8)}}); - info[a.data()].push_back({kLoad, {new IntImm(5)}, {new IntImm(7)}}); - res = mergeTensorAccesses(info); - - ASSERT_EQ(res[a.data()].size(), 2); - verifyConstBounds(res.at(a.data())[0], {{1, 7}}); - verifyConstBounds(res.at(a.data())[1], {{3, 8}}); + Tensor* b = Compute( + "b", + {{3, "z"}, {4, "y"}, {5, "x"}}, + [&](const VarHandle& z, const VarHandle& y, const VarHandle& x) { + return x * y + z; + }); + + LoopNest l({b}); + // Flatten indices. + l.prepareForCodegen(); + auto bounds_info = inferBounds(l.root_stmt()); + + // There's only one buffer. + ASSERT_EQ(bounds_info.size(), 1); + auto& TABI = bounds_info[b->buf()][0]; + ASSERT_EQ(TABI.kind, TensorAccessKind::kStore); + // Flattened bounds should have a single dimension. + ASSERT_EQ(TABI.start.size(), 1); + ASSERT_EQ(TABI.stop.size(), 1); + + // Bounds should be 0 -> (3*4*5)-1 + ASSERT_TRUE(exprEquals(TABI.start[0], new IntImm(0))); + ASSERT_TRUE(exprEquals(TABI.stop[0], new IntImm(3 * 4 * 5 - 1))); } -void testMergeInferred2DBounds() { +void testGetPotentialHazards() { KernelScope kernel_scope; - Buffer a(BufHandle("a", {10, 10}, kFloat)); - - // Non overlapping in both dimensions: - BoundsInfo info; - info[a.data()].push_back( - {kLoad, {new IntImm(1), new IntImm(1)}, {new IntImm(3), new IntImm(3)}}); - info[a.data()].push_back( - {kLoad, {new IntImm(5), new IntImm(5)}, {new IntImm(9), new IntImm(9)}}); - - BoundsInfo res = mergeTensorAccesses(info); - ASSERT_EQ(res.size(), 1); - ASSERT_EQ(res[a.data()].size(), 2); - ASSERT_EQ(res.at(a.data())[0].kind, kLoad); - ASSERT_EQ(res.at(a.data())[1].kind, kLoad); - verifyConstBounds(res.at(a.data())[0], {{1, 3}, {1, 3}}); - verifyConstBounds(res.at(a.data())[1], {{5, 9}, {5, 9}}); - - // Overlapping in a single dimension should mean we cannot merge. - // First dimension: - info.clear(); - info[a.data()].push_back( - {kLoad, {new IntImm(1), new IntImm(1)}, {new IntImm(3), new IntImm(3)}}); - info[a.data()].push_back( - {kLoad, {new IntImm(2), new IntImm(5)}, {new IntImm(9), new IntImm(9)}}); - - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 2); - verifyConstBounds(res.at(a.data())[0], {{1, 3}, {1, 3}}); - verifyConstBounds(res.at(a.data())[1], {{2, 9}, {5, 9}}); - - // Second dimension: - info.clear(); - info[a.data()].push_back( - {kLoad, {new IntImm(1), new IntImm(1)}, {new IntImm(3), new IntImm(3)}}); - info[a.data()].push_back( - {kLoad, {new IntImm(5), new IntImm(2)}, {new IntImm(9), new IntImm(9)}}); - - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 2); - verifyConstBounds(res.at(a.data())[0], {{1, 3}, {1, 3}}); - verifyConstBounds(res.at(a.data())[1], {{5, 9}, {2, 9}}); - - // Overlapping in both dimensions: - // {1-6, 1-3) | {4-9, 2,7} => {1,9, 1,7} - // TODO: this will overestimate and we should fix it. - info.clear(); - info[a.data()].push_back( - {kLoad, {new IntImm(1), new IntImm(1)}, {new IntImm(6), new IntImm(3)}}); - info[a.data()].push_back( - {kLoad, {new IntImm(4), new IntImm(2)}, {new IntImm(9), new IntImm(7)}}); - - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 1); - verifyConstBounds(res.at(a.data())[0], {{1, 9}, {1, 7}}); + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + using namespace analysis; + + { + /* + * A[0] = B[0]; + * B[0] = 3; WAR on B + * A[0] = B[0]; WAW on A, RAW on B + * C[0] = 5; + */ + + Store* store1 = Store::make(a, {0}, Load::make(b, {0}, 1), 1); + Store* store2 = Store::make(b, {0}, 3, 1); + Store* store3 = Store::make(a, {0}, Load::make(b, {0}, 1), 1); + Store* store4 = Store::make(c, {0}, 5, 1); + Stmt* stmt = Block::make({store1, store2, store3, store4}); + + MemDependencyChecker analyzer; + stmt->accept(&analyzer); + + ASSERT_EQ( + HazardKind::WriteAfterRead, + getPotentialHazards(analyzer, store1, store2)); + + ASSERT_EQ( + HazardKind::ReadAfterWrite, + getPotentialHazards(analyzer, store2, store3)); + + ASSERT_EQ( + HazardKind::WriteAfterWrite, + getPotentialHazards(analyzer, store1, store3)); + + // Fourth store has no dependencies + ASSERT_EQ( + HazardKind::NoDependency, + getPotentialHazards(analyzer, store1, store4)); + ASSERT_EQ( + HazardKind::NoDependency, + getPotentialHazards(analyzer, store2, store4)); + ASSERT_EQ( + HazardKind::NoDependency, + getPotentialHazards(analyzer, store3, store4)); + } } -void testMergeAdjacentBounds() { +void testGetPotentialHazardsLoopNoHazard() { KernelScope kernel_scope; - Buffer a(BufHandle("a", {10}, kFloat)); - - // Adjacent but not overlapping bounds can be merged. - // e.g. {1-4} | {5-9} => {1-9} - BoundsInfo info; - info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(4)}}); - info[a.data()].push_back({kLoad, {new IntImm(5)}, {new IntImm(9)}}); - BoundsInfo res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 1); - verifyConstBounds(res.at(a.data())[0], {{1, 9}}); - - // And on the other side: - info.clear(); - info[a.data()].push_back({kLoad, {new IntImm(5)}, {new IntImm(9)}}); - info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(4)}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 1); - verifyConstBounds(res.at(a.data())[0], {{1, 9}}); - - // One space gap is enough to prevent merging: - info.clear(); - info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(4)}}); - info[a.data()].push_back({kLoad, {new IntImm(6)}, {new IntImm(9)}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 2); - verifyConstBounds(res.at(a.data())[0], {{1, 4}}); - verifyConstBounds(res.at(a.data())[1], {{6, 9}}); -} -std::pair boundAsStringPair( - TensorAccessBoundsInfo& info, - size_t idx = 0) { - std::ostringstream start, stop; - start << *info.start[idx]; - stop << *info.stop[idx]; - return {start.str(), stop.str()}; + Tensor* A = Compute( + "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor* B = Compute( + "B", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return (i + 1) * (j + 1); + }); + + LoopNest l({A, B}); + + using namespace analysis; + + MemDependencyChecker analyzer; + l.root_stmt()->accept(&analyzer); + + For* loopRootA = l.getLoopStmtsFor(A)[0]; + For* loopRootB = l.getLoopStmtsFor(B)[0]; + + // No dependencies between loops. + ASSERT_EQ( + HazardKind::NoDependency, + getPotentialHazards(analyzer, loopRootA, loopRootB)); } -void testMergeSymbolicBounds() { +void testGetPotentialHazardsLoopCall() { KernelScope kernel_scope; - Buffer a(BufHandle("a", {10}, kFloat)); - VarHandle W("W", kInt); - VarHandle X("X", kInt); - VarHandle Y("Y", kInt); - VarHandle Z("Z", kInt); - - // Can do nothing with fully symbolic bounds: - BoundsInfo info; - info[a.data()].push_back({kLoad, {W.node()}, {Z.node()}}); - info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}}); - BoundsInfo res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 2); - - // Can merge if the difference between bounds is constant and enclosing. - // {X-Y} | {X-5 - Y+10} => {X-5 - Y+10} - info.clear(); - info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}}); - info[a.data()].push_back({kLoad, - {new Sub(X.node(), new IntImm(5))}, - {new Add(Y.node(), new IntImm(10))}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 1); - - // Cannot merge otherwise. - // {X-Y} | {X+5 - Y+10} => could be 2 groups if Y < X+5. - info.clear(); - info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}}); - info[a.data()].push_back({kLoad, - {new Add(X.node(), new IntImm(5))}, - {new Add(Y.node(), new IntImm(10))}}); - - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 2); - - // Can't merge if there's a gap of at least one element: - info.clear(); - info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(4)}}); - info[a.data()].push_back({kLoad, {new IntImm(6)}, {Y.node()}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 2); - - // Can't even though the high of the first bound is above the low of the - // second, X can == 6 and Y can == 4 so this can't merge in all cases. - info.clear(); - info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(6)}}); - info[a.data()].push_back({kLoad, {new IntImm(4)}, {Y.node()}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 2); - - // If either side is equal, they must be overlapping. - info.clear(); - info[a.data()].push_back({kLoad, {X.node()}, {Z.node()}}); - info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 1); - auto pair = boundAsStringPair(res[a.data()][0]); - ASSERT_EQ(pair.first, "X"); - ASSERT_EQ(pair.second, "Max(Y, Z, 1)"); - - info.clear(); - info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}}); - info[a.data()].push_back({kLoad, {Z.node()}, {Y.node()}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 1); - pair = boundAsStringPair(res[a.data()][0]); - ASSERT_EQ(pair.first, "Min(X, Z, 1)"); - ASSERT_EQ(pair.second, "Y"); - - // If either side is only one apart, they must be adjacent. - info.clear(); - info[a.data()].push_back( - {kLoad, {new Add(X.node(), new IntImm(1))}, {Z.node()}}); - info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 1); - pair = boundAsStringPair(res[a.data()][0]); - ASSERT_EQ(pair.first, "X"); - ASSERT_EQ(pair.second, "Max(Y, Z, 1)"); - - info.clear(); - info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}}); - info[a.data()].push_back( - {kLoad, {Z.node()}, {new Sub(Y.node(), new IntImm(1))}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 1); - pair = boundAsStringPair(res[a.data()][0]); - ASSERT_EQ(pair.first, "Min(X, Z, 1)"); - ASSERT_EQ(pair.second, "Y"); - - // If either side is 2 apart, they may not be overlapping. - // in this case if Y == X+1 they don't overlap. - info.clear(); - info[a.data()].push_back( - {kLoad, {new Add(X.node(), new IntImm(2))}, {Z.node()}}); - info[a.data()].push_back( - {kLoad, {X.node()}, {new Sub(Y.node(), new IntImm(1))}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 2); - - // In this case they may not overlap if X == Y. - info.clear(); - info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}}); - info[a.data()].push_back( - {kLoad, {Z.node()}, {new Sub(Y.node(), new IntImm(2))}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 2); + + Tensor* A = Compute( + "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor* B = Compute( + "B", {{64, "i"}, {64, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i, j) + 5; + }); + + LoopNest l({A, B}); + + using namespace analysis; + + MemDependencyChecker analyzer; + l.root_stmt()->accept(&analyzer); + + For* loopRootA = l.getLoopStmtsFor(A)[0]; + For* loopRootB = l.getLoopStmtsFor(B)[0]; + + ASSERT_EQ( + HazardKind::ReadAfterWrite, + getPotentialHazards(analyzer, loopRootA, loopRootB)); } -void testMergeSymbolicAdjacent() { +void testGetPotentialHazardsLoopSplit() { KernelScope kernel_scope; - Buffer a(BufHandle("a", {10}, kFloat)); - VarHandle X("X", kInt); - VarHandle Y("Y", kInt); - - BoundsInfo info; - // Can merge if a range is adjacent: - // {X-5} | {6-Y} => {X-Y} - info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(5)}}); - info[a.data()].push_back({kLoad, {new IntImm(6)}, {Y.node()}}); - BoundsInfo res = mergeTensorAccesses(info); - - ASSERT_EQ(res[a.data()].size(), 1); - auto pair = boundAsStringPair(res[a.data()][0]); - ASSERT_EQ(pair.first, "X"); - ASSERT_EQ(pair.second, "Y"); - - info.clear(); - info[a.data()].push_back({kLoad, {new IntImm(6)}, {Y.node()}}); - info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(5)}}); - res = mergeTensorAccesses(info); - - ASSERT_EQ(res[a.data()].size(), 1); - pair = boundAsStringPair(res[a.data()][0]); - ASSERT_EQ(pair.first, "X"); - ASSERT_EQ(pair.second, "Y"); - - info.clear(); - info[a.data()].push_back({kLoad, {new IntImm(5)}, {Y.node()}}); - info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(6)}}); - res = mergeTensorAccesses(info); - - ASSERT_EQ(res[a.data()].size(), 1); - pair = boundAsStringPair(res[a.data()][0]); - ASSERT_EQ(pair.first, "X"); - ASSERT_EQ(pair.second, "Y"); - - info.clear(); - info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(6)}}); - info[a.data()].push_back({kLoad, {new IntImm(5)}, {Y.node()}}); - res = mergeTensorAccesses(info); - - ASSERT_EQ(res[a.data()].size(), 1); - pair = boundAsStringPair(res[a.data()][0]); - ASSERT_EQ(pair.first, "X"); - ASSERT_EQ(pair.second, "Y"); - - // If either the lower or upper bound is adjacent the range then they must - // overlap, even if we don't know the extent. - info.clear(); - info[a.data()].push_back({kLoad, {new IntImm(6)}, {X.node()}}); - info[a.data()].push_back({kLoad, {new IntImm(5)}, {Y.node()}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 1); - pair = boundAsStringPair(res[a.data()][0]); - ASSERT_EQ(pair.first, "5"); - ASSERT_EQ(pair.second, "Max(X, Y, 1)"); - - info.clear(); - info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(6)}}); - info[a.data()].push_back({kLoad, {Y.node()}, {new IntImm(5)}}); - res = mergeTensorAccesses(info); - ASSERT_EQ(res[a.data()].size(), 1); - pair = boundAsStringPair(res[a.data()][0]); - ASSERT_EQ(pair.first, "Min(X, Y, 1)"); - ASSERT_EQ(pair.second, "6"); + + Tensor* A = Compute( + "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + + LoopNest l({A}); + For *outer, *inner, *tail; + + // Splitting with tail by something offset creates a tail which also writes to + // A. + l.splitWithTail(l.getLoopStmtsFor(A)[0], 5, &outer, &inner, &tail); + + using namespace analysis; + + MemDependencyChecker analyzer; + l.root_stmt()->accept(&analyzer); + + ASSERT_EQ( + HazardKind::WriteAfterWrite, getPotentialHazards(analyzer, outer, tail)); } } // namespace jit diff --git a/test/cpp/tensorexpr/test_conv.cpp b/test/cpp/tensorexpr/test_conv.cpp new file mode 100644 index 0000000000000..f49cac72e6930 --- /dev/null +++ b/test/cpp/tensorexpr/test_conv.cpp @@ -0,0 +1,87 @@ +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +namespace te = torch::jit::tensorexpr; +namespace F = torch::nn::functional; + +TEST(Conv, Conv2D) { + te::KernelScope kernel_scope; + + // Input dimensions. + constexpr int N = 1; + constexpr int C = 3; + constexpr int H = 11; + constexpr int W = 11; + + // Filter dimensions. + constexpr int K = 8; + constexpr int R = 3; + constexpr int S = 3; + + // Output dims. + constexpr int OH = H - R + 1; + constexpr int OW = W - S + 1; + + // Compute reference result. + at::Tensor input = torch::randn({N, C, H, W}); + at::Tensor filter = torch::randn({K, C, R, S}); + at::Tensor ref = F::conv2d(input, filter); + + // Double check the output size is as expected. + ASSERT_EQ(ref.size(0), N); + ASSERT_EQ(ref.size(1), K); + ASSERT_EQ(ref.size(2), OH); + ASSERT_EQ(ref.size(3), OW); + + te::Placeholder inputB(te::BufHandle("input", {N, C, H, W}, te::kFloat)); + te::Placeholder filterB(te::BufHandle("filter", {K, C, R, S}, te::kFloat)); + + te::Tensor* conv = te::Reduce( + "conv", + {{N, "n"}, {K, "k"}, {OH, "oh"}, {OW, "ow"}}, + te::Sum(), + // FIXME: We have to use a `std::vector` parameter here and then unpack + // it, because we don't have an overload allowing for an arbitrary number + // of ExprHandle/VarHandle parameters. + [&](const std::vector& v) { + auto const& n = v[0]; + auto const& k = v[1]; + auto const& oh = v[2]; + auto const& ow = v[3]; + auto const& c = v[4]; + auto const& r = v[5]; + auto const& s = v[6]; + // FIXME: We have to use `call` and construct a `std::vector` here + // because the `operator()` overload is only specialized for a small + // number of arguments. + return inputB.load(n, c, oh + r, ow + s) * filterB.load(k, c, r, s); + }, + // FIXME: If you forget one of the reduction dims, you get a segfault. + // Could that be caught by a verifier? + {{C, "c"}, {R, "r"}, {S, "s"}}); + + // FIXME: It'd be nice to have a single header that pulls in things like + // LoopNest, IRSimplifier, etc. + te::LoopNest loop({conv}); + loop.prepareForCodegen(); + te::Stmt* s = loop.root_stmt(); + s = te::IRSimplifier::simplify(s); + + at::Tensor result = at::empty_like(ref); + te::SimpleIREvaluator cg(s, {inputB, filterB, conv}); + cg.call( + {input.data_ptr(), + filter.data_ptr(), + result.data_ptr()}); + + ASSERT_TRUE(at::allclose(ref, result, 1e-3, 1e-3)); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp index 2ad70e158ebfc..44323416b18f6 100644 --- a/test/cpp/tensorexpr/test_cuda.cpp +++ b/test/cpp/tensorexpr/test_cuda.cpp @@ -1,14 +1,15 @@ #ifdef USE_CUDA +#include #include #include -#include "test/cpp/tensorexpr/test_base.h" -#include +#include + +#include "test/cpp/tensorexpr/test_base.h" #include #include "test/cpp/tensorexpr/padded_buffer.h" -#include "torch/csrc/jit/tensorexpr/buffer.h" #include "torch/csrc/jit/tensorexpr/cuda_codegen.h" #include "torch/csrc/jit/tensorexpr/ir_simplifier.h" #include "torch/csrc/jit/tensorexpr/loopnest.h" @@ -25,14 +26,14 @@ using namespace torch::jit::tensorexpr; using namespace torch::jit::tensorexpr; template -void testCudaTestVectorAdd01_impl() { +static void testCudaTestVectorAdd01_impl() { KernelScope kernel_scope; const int num_iter = 3; const int block_count = 16; const int block_size = 128; Dtype dtype = ToDtype(); - Buffer a_buf("a", dtype, {num_iter, block_count, block_size}); - Buffer b_buf("b", dtype, {num_iter, block_count, block_size}); + Placeholder a_buf("a", dtype, {num_iter, block_count, block_size}); + Placeholder b_buf("b", dtype, {num_iter, block_count, block_size}); Tensor* c = Compute( "c", { @@ -41,7 +42,7 @@ void testCudaTestVectorAdd01_impl() { {block_size, "t_id"}, }, [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { - return a_buf(n, b_id, t_id) + b_buf(n, b_id, t_id); + return a_buf.load(n, b_id, t_id) + b_buf.load(n, b_id, t_id); }); LoopNest l({c}); std::vector loops = l.getLoopStmtsFor(c); @@ -91,13 +92,13 @@ float sigmoid(float x) { return 1.0f / (1.0f + expf(-0.0f - x)); } -void testCudaSigmoid() { +TEST(Cuda, Sigmoid_CUDA) { KernelScope kernel_scope; const int num_iter = 3; const int block_count = 16; const int block_size = 128; Dtype dtype = ToDtype(); - Buffer a_buf("a", dtype, {num_iter, block_count, block_size}); + Placeholder a_buf("a", dtype, {num_iter, block_count, block_size}); Tensor* c = Compute( "c", { @@ -106,7 +107,7 @@ void testCudaSigmoid() { {block_size, "t_id"}, }, [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { - return sigmoid(sigmoid(a_buf(n, b_id, t_id))); + return sigmoid(sigmoid(a_buf.load(n, b_id, t_id))); }); LoopNest l({c}); std::vector loops = l.getLoopStmtsFor(c); @@ -146,7 +147,7 @@ void testCudaSigmoid() { cudaFree(c_dev); } -void testCudaTestVectorAdd01() { +TEST(Cuda, TestVectorAdd01_CUDA) { // floating types. testCudaTestVectorAdd01_impl(); testCudaTestVectorAdd01_impl(); @@ -162,14 +163,14 @@ void testCudaTestVectorAdd01() { static void testCudaTestVectorAdd02_impl(int N, int block_size) { KernelScope kernel_scope; - Buffer a_buf("a", kFloat, {N}); - Buffer b_buf("b", kFloat, {N}); + Placeholder a_buf("a", kFloat, {N}); + Placeholder b_buf("b", kFloat, {N}); Tensor* c = Compute( "c", { {N, "N"}, }, - [&](const VarHandle& n) { return a_buf(n) + b_buf(n); }); + [&](const VarHandle& n) { return a_buf.load(n) + b_buf.load(n); }); LoopNest l({c}); For* n_outer; For* n_inner; @@ -216,17 +217,17 @@ static void testCudaTestVectorAdd02_impl(int N, int block_size) { cudaFree(c_dev); } -void testCudaTestVectorAdd02() { +TEST(Cuda, TestVectorAdd02_CUDA) { testCudaTestVectorAdd02_impl(1024, 128); testCudaTestVectorAdd02_impl(1030, 128); } -void testCudaHalfCast() { +TEST(Cuda, HalfCast_CUDA) { KernelScope ks; auto half = ToDtype(); - Buffer a("a", half, {4}); + Placeholder a("a", half, {4}); Tensor* b = Compute("b", {{4, "n"}}, [&](const VarHandle& i) { - return Cast::make(kFloat, a(i)); + return Cast::make(kFloat, a.load(i)); }); LoopNest l({b}); @@ -260,16 +261,16 @@ void testCudaHalfCast() { cudaFree(bDev); } -void testCudaDynamicShape2D() { +TEST(Cuda, DynamicShape2D_CUDA) { KernelScope kernel_scope; auto testWithSize = [](int32_t M, int32_t N) { VarHandle m("m", kInt); VarHandle n("n", kInt); - Buffer a(BufHandle("a", {m, n}, kFloat)); - Buffer b(BufHandle("b", {m, n}, kFloat)); + Placeholder a(BufHandle("a", {m, n}, kFloat)); + Placeholder b(BufHandle("b", {m, n}, kFloat)); Tensor* c = Compute( "c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) { - return a(i, j) + b(i, j); + return a.load(i, j) + b.load(i, j); }); LoopNest l({c}); l.prepareForCodegen(); @@ -323,7 +324,7 @@ void testCudaDynamicShape2D() { testWithSize(27, 13); } -void testCudaTestRand01() { +TEST(Cuda, TestRand01_CUDA) { KernelScope kernel_scope; const int num_iter = 3; const int block_count = 16; @@ -382,13 +383,13 @@ void testCudaTestRand01() { cudaFree(c_dev); } -void testCudaDynamicShapeSplit() { +TEST(Cuda, DynamicShapeSplit_CUDA) { KernelScope ks; constexpr int N = 4096; VarHandle n("n", kInt); - Buffer a(BufHandle("a", {n}, kFloat)); - Tensor* b = - Compute("b", {{n, "n"}}, [&](const VarHandle& i) { return a(i) * 2.0f; }); + Placeholder a(BufHandle("a", {n}, kFloat)); + Tensor* b = Compute( + "b", {{n, "n"}}, [&](const VarHandle& i) { return a.load(i) * 2.0f; }); LoopNest l({b}); For* outer; For* inner; @@ -433,11 +434,11 @@ void testCudaDynamicShapeSplit() { cudaFree(bDev); } -void testCudaOneBlockOneThreadGlobalReduce1() { +TEST(Cuda, OneBlockOneThreadGlobalReduce1_CUDA) { const static int N = 1024; KernelScope kernel_scope; - Buffer data_buf("data", kFloat, {N}); - Buffer output_buf("output", kFloat, {1}); + Placeholder data_buf("data", kFloat, {N}); + Placeholder output_buf("output", kFloat, {1}); // The test adds the following code for trivial reduction: // for (int bidx = 0; bidx < 1; bidx++) { // blockIdx.x @@ -449,12 +450,12 @@ void testCudaOneBlockOneThreadGlobalReduce1() { // } // } - Store* init_store = Store::make(output_buf, {0}, 0.f, 1); + Store* init_store = output_buf.store({0}, 0.f); VarHandle i1("i1", kInt); - ExprHandle load_data = Load::make(data_buf, {i1}, 1); - ExprHandle load_output = Load::make(output_buf, {0}, 1); + ExprHandle load_data = Load::make(BufHandle(data_buf.data()), {i1}, 1); + ExprHandle load_output = Load::make(BufHandle(output_buf.data()), {0}, 1); ExprHandle add_value = load_output + load_data; - Store* store_output = Store::make(output_buf, {0}, add_value, 1); + Store* store_output = output_buf.store({0}, add_value); For* for_output = For::make(i1, 0, N, store_output); Stmt* reduce_block = Block::make({init_store, for_output}); VarHandle thread_idx("tidx", kInt); @@ -500,7 +501,7 @@ void testCudaOneBlockOneThreadGlobalReduce1() { cudaFree(output_dev); } -void testCudaOneBlockMultiThreadGlobalReduce1() { +TEST(Cuda, OneBlockMultiThreadGlobalReduce1_CUDA) { const static int N = 1024; KernelScope kernel_scope; @@ -515,10 +516,10 @@ void testCudaOneBlockMultiThreadGlobalReduce1() { // b[0] = b[0] + a[t] // implied atomic // clang-format on - Buffer a_buf("a", kFloat, {N}); - Buffer b_buf("b", kFloat, {1}); + Placeholder a_buf("a", kFloat, {N}); + Placeholder b_buf("b", kFloat, {1}); - Store* init_store = Store::make(b_buf, {0}, 0.f, 1); + Store* init_store = b_buf.store({0}, 0.f); VarHandle t("t", kInt); VarHandle b("b", kInt); @@ -534,10 +535,10 @@ void testCudaOneBlockMultiThreadGlobalReduce1() { // for t in 0..1024: // thread-idx // b[0] = b[0] + a[t] // implied atomic - ExprHandle load_a = Load::make(a_buf, {t}, 1); - ExprHandle load_b = Load::make(b_buf, {0}, 1); + ExprHandle load_a = Load::make(BufHandle(a_buf.data()), {t}, 1); + ExprHandle load_b = Load::make(BufHandle(b_buf.data()), {0}, 1); ExprHandle add_value = load_b + load_a; - Store* store_b = Store::make(b_buf, {0}, add_value, 1); + Store* store_b = b_buf.store({0}, add_value); For* for_b = For::make(t, 0, N, store_b, thread_idx_options); Stmt* reduce_block = Block::make({for_init, for_b}); @@ -578,7 +579,7 @@ void testCudaOneBlockMultiThreadGlobalReduce1() { cudaFree(b_dev); } -void testCudaNoThreadIdxWrite_1() { +TEST(Cuda, NoThreadIdxWrite_1_CUDA) { KernelScope kernel_scope; // This test does the following reduction: @@ -597,8 +598,8 @@ void testCudaNoThreadIdxWrite_1() { // covered by its own thread-idx const static int N = 1024; - Buffer a_buf("a", kFloat, {2}); - Buffer b_buf("b", kFloat, {N}); + Placeholder a_buf("a", kFloat, {2}); + Placeholder b_buf("b", kFloat, {N}); VarHandle k("k", kInt); VarHandle l("l", kInt); @@ -608,15 +609,15 @@ void testCudaNoThreadIdxWrite_1() { // a[0] = 0 // for n in 0..2: // a[0] = a[0] + n - Store* store_a0_0 = Store::make(a_buf, {0}, 0.f, 1); - ExprHandle load_a0 = Load::make(a_buf, {0}, 1); + Store* store_a0_0 = a_buf.store({0}, 0.f); + ExprHandle load_a0 = Load::make(BufHandle(a_buf.data()), {0}, 1); ExprHandle v1 = load_a0 + n; - Store* store_a0_v1 = Store::make(a_buf, {0}, v1, 1); + Store* store_a0_v1 = a_buf.store({0}, v1); For* loop_a_0 = For::make(n, 0, 2, store_a0_v1); // for m in 0..1024: // thread-idx // b[m] = m - Store* store_bm_m = Store::make(b_buf, {m}, m + 0.f, 1); + Store* store_bm_m = b_buf.store({m}, m + 0.f); LoopOptions thread_idx_options; thread_idx_options.set_gpu_thread_index(0); For* loop_b_1 = For::make(m, 0, N, store_bm_m, thread_idx_options); @@ -624,10 +625,10 @@ void testCudaNoThreadIdxWrite_1() { // a[1] = 1 // for l in 0..2: // a[1] = a[1] + l - Store* store_a1_1 = Store::make(a_buf, {1}, 1.f, 1); - ExprHandle load_a1 = Load::make(a_buf, {1}, 1); + Store* store_a1_1 = a_buf.store({1}, 1.f); + ExprHandle load_a1 = a_buf.load(1); ExprHandle v2 = load_a1 + l; - Store* store_a1_v2 = Store::make(a_buf, {1}, v2, 1); + Store* store_a1_v2 = a_buf.store({1}, v2); For* loop_a_1 = For::make(l, 0, 2, store_a1_v2); Stmt* reduce_block = @@ -675,7 +676,7 @@ void testCudaNoThreadIdxWrite_1() { cudaFree(b_dev); } -void testCudaSharedMemReduce_1() { +TEST(Cuda, SharedMemReduce_1_CUDA) { // FIXME: this test is flaky in CI. KernelScope kernel_scope; // This test does the following: @@ -699,8 +700,8 @@ void testCudaSharedMemReduce_1() { LoopOptions block_idx_opt; block_idx_opt.set_gpu_block_index(0); - Buffer a("a", kFloat, {1, M, N}); - Buffer b("b", kFloat, {1}); + Placeholder a("a", kFloat, {1, M, N}); + Placeholder b("b", kFloat, {1}); VarHandle k("k", kInt); VarHandle m("m", kInt); VarHandle n("n", kInt); @@ -729,7 +730,8 @@ void testCudaSharedMemReduce_1() { // for n in 0..64: // thread_idx // c(n) = c(n) + a(k, m, n) ExprHandle load_cn = Load::make(kFloat, c, {n}, 1); - ExprHandle a_kmn = Load::make(a, {k * (M * N) + m * N + n}, 1); + ExprHandle a_kmn = + Load::make(BufHandle(a.data()), {k * (M * N) + m * N + n}, 1); ExprHandle v_add = load_cn + a_kmn; Store* store_cn_v = Store::make(c, {n}, v_add); For* loop_n2 = For::make(n, 0, N, store_cn_v, thread_idx_opt); @@ -741,12 +743,12 @@ void testCudaSharedMemReduce_1() { // b(k) = 0 // for n in 0..64: // thread_idx // b(k) = b(k) + c(n) - Store* store_bk_0 = Store::make(b, {k}, 0.f, 1); + Store* store_bk_0 = b.store({k}, 0.f); block.push_back(store_bk_0); - ExprHandle load_bk = Load::make(b, {k}, 1); + ExprHandle load_bk = b.load(k); ExprHandle load_cn = Load::make(kFloat, c, {n}, 1); ExprHandle v_add = load_bk + load_cn; - Store* store_bk = Store::make(b, {k}, v_add, 1); + Store* store_bk = b.store({k}, v_add); For* loop_n3 = For::make(n, 0, N, store_bk, thread_idx_opt); block.push_back(loop_n3); } @@ -769,14 +771,14 @@ void testCudaSharedMemReduce_1() { // Check the c write is not masked, but the d write is. const std::string& verification_pattern = R"IR( -# CHECK: c_ = 0 +# CHECK: c_1 = 0 # CHECK: for (int m = 0; m < 128 -# CHECK: c_ = c_ + +# CHECK: c_1 = c_1 + # CHECK: __syncthreads(); # CHECK: if (threadIdx.x<1 # CHECK: b[blockIdx.x] = # CHECK: __syncthreads(); -# CHECK: atomicAdd(&b[blockIdx.x], c_) +# CHECK: atomicAdd(&b[blockIdx.x], c_1) )IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -814,7 +816,7 @@ void testCudaSharedMemReduce_1() { cudaFree(b_dev); } -void testCudaLocalMemReduce_1() { +TEST(Cuda, LocalMemReduce_1_CUDA) { KernelScope kernel_scope; // This test does the following: // for k in 0..1: // block-idx @@ -835,8 +837,8 @@ void testCudaLocalMemReduce_1() { LoopOptions block_idx_opt; block_idx_opt.set_gpu_block_index(0); - Buffer a("a", kFloat, {1, M, N}); - Buffer b("b", kFloat, {1}); + Placeholder a("a", kFloat, {1, M, N}); + Placeholder b("b", kFloat, {1}); VarHandle k("k", kInt); VarHandle m("m", kInt); VarHandle n("n", kInt); @@ -848,7 +850,7 @@ void testCudaLocalMemReduce_1() { std::vector block_k; { // b(k) = 0 - Store* store_bk_0 = Store::make(b, {k}, 0.f, 1); + Store* store_bk_0 = b.store({k}, 0.f); block_k.push_back(store_bk_0); } std::vector block_n; @@ -866,7 +868,7 @@ void testCudaLocalMemReduce_1() { // for m in 0..128: // c(0) = c(0) + a(k, m, n) ExprHandle load_c0 = Load::make(kFloat, c, {0}, 1); - ExprHandle a_kmn = Load::make(a, {k * (M * N) + m * N + n}, 1); + ExprHandle a_kmn = a.load(k * (M * N) + m * N + n); ExprHandle v_add = load_c0 + a_kmn; Store* store_c0_v = Store::make(c, {0}, v_add); For* loop_m = For::make(m, 0, M, store_c0_v); @@ -874,10 +876,10 @@ void testCudaLocalMemReduce_1() { } { // b(k) = b(k) + c(0) - ExprHandle load_bk = Load::make(b, {k}, 1); + ExprHandle load_bk = b.load(k); ExprHandle load_c0 = Load::make(kFloat, c, {0}, 1); ExprHandle v_add = load_bk + load_c0; - Store* store_bk = Store::make(b, {k}, v_add, 1); + Store* store_bk = b.store({k}, v_add); block_n.push_back(store_bk); } { @@ -927,12 +929,12 @@ void testCudaLocalMemReduce_1() { cudaFree(b_dev); } -void testCudaHalfSupport() { +TEST(Cuda, HalfSupport_CUDA) { KernelScope ks; auto half = ToDtype(); - Buffer a("a", half, {4}); + Placeholder a("a", half, {4}); Tensor* b = Compute("b", {{4, "n"}}, [&](const VarHandle& i) { - return Cast::make(half, ExprHandle(2.0f) * a(i)); + return Cast::make(half, ExprHandle(2.0f) * a.load(i)); }); Tensor* c = Compute("c", {{4, "n"}}, [&](const VarHandle& i) { @@ -985,11 +987,117 @@ void testCudaHalfSupport() { cudaFree(dDev); } -void testCudaPrioritizeDependents() { +TEST(Cuda, HalfPropagation_CUDA) { KernelScope kernel_scope; - Buffer a("a", kFloat, {10}); - Buffer b("b", kFloat, {12}); - Buffer c("c", kFloat, {12}); + auto half = ToDtype(); + Placeholder a("a", half, {4}); + Tensor* relu = Compute("relu", {{4, "n"}}, [&](const VarHandle& i) { + return Max::make(a.load(i), ExprHandle(new HalfImm(0)), true); + }); + + LoopNest l({relu}); + l.prepareForCodegen(); + Stmt* s = l.root_stmt(); + CudaCodeGen cg(s, {a, relu}); + + std::ostringstream oss; + oss << *cg.stmt(); + + // Check the types used by the Max are Float. + const std::string& verification_pattern = + R"IR( +# CHECK: for ( +# CHECK: float v = float(a[n]); +# CHECK: relu[n] = half(Max(v, 0.f +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector aData(4, 2.0f); + std::vector reluData(4, 0.0f); + at::Half* aDev = nullptr; + at::Half* reluDev = nullptr; + auto aSize = aData.size() * sizeof(aData[0]); + auto reluSize = reluData.size() * sizeof(reluData[0]); + + cudaMalloc(&aDev, aSize); + cudaMalloc(&reluDev, reluSize); + cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice); + cudaMemcpy(reluDev, reluData.data(), reluSize, cudaMemcpyHostToDevice); + cudaDeviceSynchronize(); + + cg.call({aDev, reluDev}); + cudaMemcpy(reluData.data(), reluDev, reluSize, cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + + assertAllEqual(aData, reluData); + + cudaFree(aDev); + cudaFree(reluDev); +} + +TEST(Cuda, UnusedHalfArgument_CUDA) { + KernelScope kernel_scope; + Placeholder a("a", kFloat, {4}); + auto half = ToDtype(); + Placeholder b("b", half, {4}); + Tensor* relu = Compute("relu", {{4, "n"}}, [&](const VarHandle& i) { + return Max::make(a.load(i), ExprHandle(new FloatImm(0)), true); + }); + + LoopNest l({relu}); + l.prepareForCodegen(); + Stmt* s = l.root_stmt(); + CudaCodeGen cg(s, {a, b, relu}); + + std::ostringstream oss; + oss << *cg.stmt(); + + // Check the types used by the Max are Float. + const std::string& verification_pattern = + R"IR( +# CHECK: for ( +# CHECK: float v = a[n]; +# CHECK: relu[n] = Max(v, 0.f +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // Sanity Cbeck; + std::vector aData(4, 2.0f); + std::vector bData(4, 2.0f); + std::vector reluData(4, 0.0f); + at::Half* aDev = nullptr; + at::Half* bDev = nullptr; + at::Half* reluDev = nullptr; + auto aSize = aData.size() * sizeof(aData[0]); + auto bSize = bData.size() * sizeof(bData[0]); + auto reluSize = reluData.size() * sizeof(reluData[0]); + + cudaMalloc(&aDev, aSize); + cudaMalloc(&bDev, bSize); + cudaMalloc(&reluDev, reluSize); + cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice); + cudaMemcpy(bDev, bData.data(), bSize, cudaMemcpyHostToDevice); + cudaMemcpy(reluDev, reluData.data(), reluSize, cudaMemcpyHostToDevice); + cudaDeviceSynchronize(); + + cg.call({aDev, bDev, reluDev}); + cudaMemcpy(reluData.data(), reluDev, reluSize, cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + + assertAllEqual(aData, reluData); + + cudaFree(aDev); + cudaFree(bDev); + cudaFree(reluDev); +} + +TEST(Cuda, PrioritizeDependents_CUDA) { + KernelScope kernel_scope; + Placeholder a("a", kFloat, {10}); + Placeholder b("b", kFloat, {12}); + Placeholder c("c", kFloat, {12}); LoopOptions block_idx_opt; block_idx_opt.set_gpu_block_index(0); @@ -1002,13 +1110,13 @@ void testCudaPrioritizeDependents() { * c[i] = (i < 10 ? a[i] + b[i] : b[i]); * } */ - ExprHandle load_a = Load::make(a, {i}, 1); - ExprHandle load_b = Load::make(b, {i}, 1); + ExprHandle load_a = Load::make(BufHandle(a.data()), {i}, 1); + ExprHandle load_b = Load::make(BufHandle(b.data()), {i}, 1); ExprHandle cmp = CompareSelect::make(i, 10, CompareSelectOperation::kLT); ExprHandle ite = IfThenElse::make(cmp, Add::make(load_a, load_b), load_b); - For* loop = For::make( - i, 0, 12, Block::make({Store::make(c, {i}, ite, 1)}), block_idx_opt); + For* loop = + For::make(i, 0, 12, Block::make({c.store({i}, ite)}), block_idx_opt); CudaCodeGen cuda_cg(loop, a, b, c); @@ -1059,16 +1167,17 @@ void testCudaPrioritizeDependents() { /// Tests the case where there are two loops which have different extents bound /// to the same block dimension. We must mask the smaller extent loop body. -void testCudaMaskBlockDim() { +TEST(Cuda, MaskBlockDim_CUDA) { KernelScope kernel_scope; int A_SIZE = 100; int B_SIZE = 50; - Buffer a_buf("a", kFloat, {A_SIZE}); - Buffer b_buf("b", kFloat, {B_SIZE}); - Tensor* c = Compute( - "c", {{A_SIZE, "i"}}, [&](const VarHandle& i) { return a_buf(i) + 10; }); + Placeholder a_buf("a", kFloat, {A_SIZE}); + Placeholder b_buf("b", kFloat, {B_SIZE}); + Tensor* c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) { + return a_buf.load(i) + 10; + }); Tensor* d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) { - return a_buf(i) + b_buf(i); + return a_buf.load(i) + b_buf.load(i); }); LoopNest l({c, d}); @@ -1151,16 +1260,17 @@ void testCudaMaskBlockDim() { /// Tests the case with two loops, which have different extents that are bound /// to the same thread dimension. This is the same as the above - the smaller /// rank write should be masked. But this time we also need to syncthreads. -void testCudaMaskThreadDim() { +TEST(Cuda, MaskThreadDim_CUDA) { KernelScope kernel_scope; int A_SIZE = 50; int B_SIZE = 100; - Buffer a_buf("a", kFloat, {A_SIZE}); - Buffer b_buf("b", kFloat, {B_SIZE}); - Tensor* c = Compute( - "c", {{A_SIZE, "i"}}, [&](const VarHandle& i) { return a_buf(i) + 10; }); + Placeholder a_buf("a", kFloat, {A_SIZE}); + Placeholder b_buf("b", kFloat, {B_SIZE}); + Tensor* c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) { + return a_buf.load(i) + 10; + }); Tensor* d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) { - return a_buf(i / 2) + b_buf(i); + return a_buf.load(i / 2) + b_buf.load(i); }); LoopNest l({c, d}); @@ -1245,16 +1355,17 @@ void testCudaMaskThreadDim() { /// in distinct dimensions. // Note: this is an extremely dumb pattern which we should never see, but is a // useful edge case to make sure we've got things covered. -void testCudaMaskMultiBlockDim() { +TEST(Cuda, MaskMultiBlockDim_CUDA) { KernelScope kernel_scope; int A_SIZE = 100; int B_SIZE = 50; - Buffer a_buf("a", kFloat, {A_SIZE}); - Buffer b_buf("b", kFloat, {B_SIZE}); - Tensor* c = Compute( - "c", {{A_SIZE, "i"}}, [&](const VarHandle& i) { return a_buf(i) + 10; }); + Placeholder a_buf("a", kFloat, {A_SIZE}); + Placeholder b_buf("b", kFloat, {B_SIZE}); + Tensor* c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) { + return a_buf.load(i) + 10; + }); Tensor* d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) { - return a_buf(i) + b_buf(i); + return a_buf.load(i) + b_buf.load(i); }); LoopNest l({c, d}); @@ -1338,16 +1449,17 @@ void testCudaMaskMultiBlockDim() { /// distinct. // Note: this is an extremely dumb pattern which we should never see, but is a // useful edge case to make sure we've got things covered. -void testCudaMaskBlockAndThreadDim() { +TEST(Cuda, MaskBlockAndThreadDim_CUDA) { KernelScope kernel_scope; int A_SIZE = 100; int B_SIZE = 50; - Buffer a_buf("a", kFloat, {A_SIZE}); - Buffer b_buf("b", kFloat, {B_SIZE}); - Tensor* c = Compute( - "c", {{A_SIZE, "i"}}, [&](const VarHandle& i) { return a_buf(i) + 10; }); + Placeholder a_buf("a", kFloat, {A_SIZE}); + Placeholder b_buf("b", kFloat, {B_SIZE}); + Tensor* c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) { + return a_buf.load(i) + 10; + }); Tensor* d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) { - return a_buf(i) + b_buf(i); + return a_buf.load(i) + b_buf.load(i); }); LoopNest l({c, d}); @@ -1429,24 +1541,24 @@ void testCudaMaskBlockAndThreadDim() { /// Tests the case where the loopnest has two loops of depth two: each with the /// outer loop bound to blockDim.x and the inner loop bound to threadDim.x. In /// this case all writes with a rank smaller than the max should be masked. -void testCudaMaskMultiDim() { +TEST(Cuda, MaskMultiDim_CUDA) { KernelScope kernel_scope; int OUTER_SIZE = 10; int A_SIZE = 100; int B_SIZE = 50; - Buffer a_buf("a", kFloat, {OUTER_SIZE, A_SIZE}); - Buffer b_buf("b", kFloat, {OUTER_SIZE, B_SIZE}); + Placeholder a_buf("a", kFloat, {OUTER_SIZE, A_SIZE}); + Placeholder b_buf("b", kFloat, {OUTER_SIZE, B_SIZE}); Tensor* c = Compute( "C", {{OUTER_SIZE, "i"}, {A_SIZE, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return ExprHandle(2) * a_buf(i, j); + return ExprHandle(2) * a_buf.load(i, j); }); Tensor* d = Compute( "D", {{OUTER_SIZE, "i"}, {B_SIZE, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return c->call(i, j * 2) + b_buf(i, j); + return c->call(i, j * 2) + b_buf.load(i, j); }); LoopNest l({c, d}); @@ -1559,24 +1671,24 @@ void testCudaMaskMultiDim() { // Tests the case where loop extents are symbolic and not known at compile time. // In this case both stores must be masked against the extent of the other loop, // incase it is larger. -void testCudaMaskMultiDimSymbolic() { +TEST(Cuda, MaskMultiDimSymbolic_CUDA) { KernelScope kernel_scope; VarHandle OUTER_SIZE("OUTER_SIZE", kInt); VarHandle A_SIZE("A_SIZE", kInt); VarHandle B_SIZE("B_SIZE", kInt); - Buffer a_buf("a", kFloat, {OUTER_SIZE, A_SIZE}); - Buffer b_buf("b", kFloat, {OUTER_SIZE, B_SIZE}); + Placeholder a_buf("a", kFloat, {OUTER_SIZE, A_SIZE}); + Placeholder b_buf("b", kFloat, {OUTER_SIZE, B_SIZE}); Tensor* c = Compute( "C", {{OUTER_SIZE, "i"}, {A_SIZE, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return ExprHandle(2) * a_buf(i, j); + return ExprHandle(2) * a_buf.load(i, j); }); Tensor* d = Compute( "D", {{OUTER_SIZE, "i"}, {B_SIZE, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return c->call(i, j * 2) + b_buf(i, j); + return c->call(i, j * 2) + b_buf.load(i, j); }); LoopNest l({c, d}); @@ -1695,15 +1807,15 @@ void testCudaMaskMultiDimSymbolic() { // bound to the block dimension. Internally the inner loops have different // extents but are bound to the same thread dimension. The smaller loop should // be masked. -void testCudaMaskCompoundInnerLoop() { +TEST(Cuda, MaskCompoundInnerLoop_CUDA) { KernelScope kernel_scope; int OUTER_SIZE = 10; int A_SIZE = 100; int B_SIZE = 50; - Buffer a_buf("a", kFloat, {OUTER_SIZE, A_SIZE}); - Buffer b_buf("b", kFloat, {OUTER_SIZE, B_SIZE}); - Buffer c_buf("c", kFloat, {OUTER_SIZE, A_SIZE}); - Buffer d_buf("d", kFloat, {OUTER_SIZE, B_SIZE}); + Placeholder a_buf("a", kFloat, {OUTER_SIZE, A_SIZE}); + Placeholder b_buf("b", kFloat, {OUTER_SIZE, B_SIZE}); + Placeholder c_buf("c", kFloat, {OUTER_SIZE, A_SIZE}); + Placeholder d_buf("d", kFloat, {OUTER_SIZE, B_SIZE}); // Can't build this using Compute and transforms yet. LoopOptions blockBound; @@ -1723,13 +1835,13 @@ void testCudaMaskCompoundInnerLoop() { j, 0, A_SIZE, - Store::make(c_buf, {i, j}, ExprHandle(2) * a_buf(i, j), 1), + c_buf.store({i, j}, ExprHandle(2) * a_buf.load(i, j)), threadBound), For::make( k, 0, B_SIZE, - Store::make(d_buf, {i, k}, c_buf(i, k * 2) + b_buf(i, k), 1), + d_buf.store({i, k}, c_buf.load(i, k * 2) + b_buf.load(i, k)), threadBound)}), blockBound); @@ -1832,17 +1944,17 @@ void testCudaMaskCompoundInnerLoop() { // Tests the case with two loops fused into a common parent, which is not bound // to any block or thread dimension - however it's two inner loops are bound to -// the first thread dimenions. This should work just like the MaskThreadDim test -// where the bigger loop is unmasked but the smaller is masked. -void testCudaMaskInnerLoopOneBlock() { +// the first thread dimensions. This should work just like the MaskThreadDim +// test where the bigger loop is unmasked but the smaller is masked. +TEST(Cuda, MaskInnerLoopOneBlock_CUDA) { KernelScope kernel_scope; int OUTER_SIZE = 10; int A_SIZE = 100; int B_SIZE = 50; - Buffer a_buf("a", kFloat, {OUTER_SIZE, A_SIZE}); - Buffer b_buf("b", kFloat, {OUTER_SIZE, B_SIZE}); - Buffer c_buf("c", kFloat, {OUTER_SIZE, A_SIZE}); - Buffer d_buf("d", kFloat, {OUTER_SIZE, B_SIZE}); + Placeholder a_buf("a", kFloat, {OUTER_SIZE, A_SIZE}); + Placeholder b_buf("b", kFloat, {OUTER_SIZE, B_SIZE}); + Placeholder c_buf("c", kFloat, {OUTER_SIZE, A_SIZE}); + Placeholder d_buf("d", kFloat, {OUTER_SIZE, B_SIZE}); // Can't build this using Compute and transforms yet. LoopOptions blockBound; @@ -1862,13 +1974,13 @@ void testCudaMaskInnerLoopOneBlock() { j, 0, A_SIZE, - Store::make(c_buf, {i, j}, ExprHandle(2) * a_buf(i, j), 1), + c_buf.store({i, j}, ExprHandle(2) * a_buf.load(i, j)), threadBound), For::make( k, 0, B_SIZE, - Store::make(d_buf, {i, k}, c_buf(i, k * 2) + b_buf(i, k), 1), + d_buf.store({i, k}, c_buf.load(i, k * 2) + b_buf.load(i, k)), threadBound)})); stmt = FlattenIndexes(stmt); @@ -1973,24 +2085,24 @@ void testCudaMaskInnerLoopOneBlock() { // size, but with internal loops bound to different thread rank (ie x and y). In // this case both bodies must be masked against the other dimension being > 0. // Note: this is a bit degenerate no one would actually write this for perf. -void testCudaMaskMultiDimMultiAxis() { +TEST(Cuda, MaskMultiDimMultiAxis_CUDA) { KernelScope kernel_scope; int OUTER_SIZE = 10; int A_SIZE = 30; int B_SIZE = 15; - Buffer a_buf("a", kFloat, {OUTER_SIZE, A_SIZE}); - Buffer b_buf("b", kFloat, {OUTER_SIZE, B_SIZE}); + Placeholder a_buf("a", kFloat, {OUTER_SIZE, A_SIZE}); + Placeholder b_buf("b", kFloat, {OUTER_SIZE, B_SIZE}); Tensor* c = Compute( "C", {{OUTER_SIZE, "i"}, {A_SIZE, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return ExprHandle(2) * a_buf(i, j); + return ExprHandle(2) * a_buf.load(i, j); }); Tensor* d = Compute( "D", {{OUTER_SIZE, "i"}, {B_SIZE, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return c->call(i, j * 2) + b_buf(i, j); + return c->call(i, j * 2) + b_buf.load(i, j); }); LoopNest l({c, d}); @@ -2103,25 +2215,25 @@ void testCudaMaskMultiDimMultiAxis() { // Tests the case with two loop nests, each bound to both Block and Thread but // the second loop is smaller in both cases - the second store must be masked // for both the block and thread dimension. -void testCudaMaskMultiDimMultiLevel() { +TEST(Cuda, MaskMultiDimMultiLevel_CUDA) { KernelScope kernel_scope; int OUTER_A_SIZE = 10; int OUTER_B_SIZE = 5; int A_SIZE = 30; int B_SIZE = 15; - Buffer a_buf("a", kFloat, {OUTER_A_SIZE, A_SIZE}); - Buffer b_buf("b", kFloat, {OUTER_B_SIZE, B_SIZE}); + Placeholder a_buf("a", kFloat, {OUTER_A_SIZE, A_SIZE}); + Placeholder b_buf("b", kFloat, {OUTER_B_SIZE, B_SIZE}); Tensor* c = Compute( "C", {{OUTER_A_SIZE, "i"}, {A_SIZE, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return ExprHandle(2) * a_buf(i, j); + return ExprHandle(2) * a_buf.load(i, j); }); Tensor* d = Compute( "D", {{OUTER_B_SIZE, "i"}, {B_SIZE, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return c->call(i, j * 2) + b_buf(i, j); + return c->call(i, j * 2) + b_buf.load(i, j); }); LoopNest l({c, d}); diff --git a/test/cpp/tensorexpr/test_expr.cpp b/test/cpp/tensorexpr/test_expr.cpp index e94e70aa6b388..30100f1eb65e3 100644 --- a/test/cpp/tensorexpr/test_expr.cpp +++ b/test/cpp/tensorexpr/test_expr.cpp @@ -1,14 +1,14 @@ -#include "test/cpp/tensorexpr/test_base.h" - -#include "test/cpp/tensorexpr/padded_buffer.h" -#include "test/cpp/tensorexpr/test_utils.h" -#include "torch/csrc/jit/tensorexpr/buffer.h" -#include "torch/csrc/jit/tensorexpr/eval.h" -#include "torch/csrc/jit/tensorexpr/function.h" -#include "torch/csrc/jit/tensorexpr/ir.h" -#include "torch/csrc/jit/tensorexpr/ir_printer.h" -#include "torch/csrc/jit/tensorexpr/loopnest.h" -#include "torch/csrc/jit/tensorexpr/tensor.h" +#include + +#include + +#include +#include +#include +#include +#include +#include +#include #include #include @@ -22,7 +22,7 @@ using namespace torch::jit::tensorexpr; using SimpleIRExprEval = ExprEval; -void testExprBasicValueTest() { +TEST(Expr, BasicValueTest) { KernelScope kernel_scope; ExprHandle a = IntImm::make(2), b = IntImm::make(3); ExprHandle c = Add::make(a, b); @@ -30,7 +30,7 @@ void testExprBasicValueTest() { ASSERT_EQ(eval.value(), 5); } -void testExprBasicValueTest02() { +TEST(Expr, BasicValueTest02) { KernelScope kernel_scope; ExprHandle a(2.0f); ExprHandle b(3.0f); @@ -41,7 +41,7 @@ void testExprBasicValueTest02() { ASSERT_EQ(eval.value(), -4.0f); } -void testExprLetTest01() { +TEST(Expr, LetTest01) { KernelScope kernel_scope; VarHandle x("x", kFloat); ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); @@ -50,7 +50,7 @@ void testExprLetTest01() { ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); } -void testExprLetTest02() { +TEST(Expr, LetTest02) { KernelScope kernel_scope; VarHandle x("x", kFloat); VarHandle y("y", kFloat); @@ -62,18 +62,18 @@ void testExprLetTest02() { ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4 * 6)); } -void testExprLetStmtTest01() { +TEST(Expr, LetStmtTest01) { KernelScope kernel_scope; - Buffer a_buf("a", kFloat, {1}); - Buffer b_buf("b", kFloat, {1}); + Placeholder a_buf("a", kFloat, {1}); + Placeholder b_buf("b", kFloat, {1}); - ExprHandle load_a = Load::make(a_buf, {0}, 1); + ExprHandle load_a = a_buf.load(0); VarHandle var = VarHandle("v", kFloat); Stmt* let_store = Let::make(var, load_a); - Stmt* store_b = Store::make(b_buf, {0}, var, 1); + Stmt* store_b = b_buf.store({0}, var); Block* block = Block::make({let_store, store_b}); - SimpleIREvaluator eval(block, a_buf, b_buf); + SimpleIREvaluator eval(block, {a_buf, b_buf}); PaddedBuffer a_v(1); PaddedBuffer b_v(1); @@ -86,7 +86,7 @@ void testExprLetStmtTest01() { ExpectAllNear(b_v, b_ref, 1e-5); } -void testExprIntTest() { +TEST(Expr, IntTest) { KernelScope kernel_scope; VarHandle x("x", kInt); ExprHandle body = ExprHandle(2) + (x * ExprHandle(3) + ExprHandle(4)); @@ -95,7 +95,7 @@ void testExprIntTest() { ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); } -void testExprFloatTest() { +TEST(Expr, FloatTest) { KernelScope kernel_scope; VarHandle x("x", kFloat); ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); @@ -104,7 +104,7 @@ void testExprFloatTest() { ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); } -void testExprByteTest() { +TEST(Expr, ByteTest) { KernelScope kernel_scope; VarHandle x("x", kByte); ExprHandle body = ExprHandle((uint8_t)2) + @@ -114,7 +114,7 @@ void testExprByteTest() { ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); } -void testExprCharTest() { +TEST(Expr, CharTest) { KernelScope kernel_scope; VarHandle x("x", kChar); ExprHandle body = ExprHandle((int8_t)2) + @@ -124,7 +124,7 @@ void testExprCharTest() { ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); } -void testExprShortTest() { +TEST(Expr, ShortTest) { KernelScope kernel_scope; VarHandle x("x", kShort); ExprHandle body = ExprHandle((int16_t)2) + @@ -134,7 +134,7 @@ void testExprShortTest() { ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); } -void testExprLongTest() { +TEST(Expr, LongTest) { KernelScope kernel_scope; VarHandle x("x", kLong); ExprHandle body = ExprHandle((int64_t)2) + @@ -144,7 +144,7 @@ void testExprLongTest() { ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); } -void testExprHalfTest() { +TEST(Expr, HalfTest) { KernelScope kernel_scope; VarHandle x("x", kHalf); ExprHandle body = ExprHandle((at::Half)2) + @@ -154,7 +154,7 @@ void testExprHalfTest() { ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); } -void testExprDoubleTest() { +TEST(Expr, DoubleTest) { KernelScope kernel_scope; VarHandle x("x", kDouble); ExprHandle body = ExprHandle((double)2) + @@ -164,33 +164,15 @@ void testExprDoubleTest() { ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); } -void testExprDisallowBoolArithmetic() { - KernelScope kernel_scope; - VarHandle x("x", kBool); - VarHandle y("y", kBool); - std::string error{"arithmetic binary operations on Bool not supported"}; - ASSERT_THROWS_WITH((x + y), error); - ASSERT_THROWS_WITH((x - y), error); - ASSERT_THROWS_WITH((x * y), error); - ASSERT_THROWS_WITH((x / y), error); - ASSERT_THROWS_WITH((x & y), error); - ASSERT_THROWS_WITH((x | y), error); - ASSERT_THROWS_WITH((x ^ y), error); - ASSERT_THROWS_WITH((x << y), error); - ASSERT_THROWS_WITH((x >> y), error); - ASSERT_THROWS_WITH(Max::make(x, y, /*propagate_nans=*/true), error); - ASSERT_THROWS_WITH(Min::make(x, y, /*propagate_nans=*/true), error); -} - -void testExprVectorAdd01() { +TEST(Expr, VectorAdd01) { KernelScope kernel_scope; const int kVectorSize = 8; const int kVectorCount = 128; const int kTotalSize = kVectorSize * kVectorCount; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); - Buffer c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder c_buf(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); /* Build the following: @@ -201,17 +183,14 @@ void testExprVectorAdd01() { } */ VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = Load::make( - a_buf, + ExprHandle load_a = a_buf.loadWithMask( {Ramp::make(index * kVectorSize, 1, kVectorSize)}, Broadcast::make(1, kVectorSize)); - ExprHandle load_b = Load::make( - b_buf, + ExprHandle load_b = b_buf.loadWithMask( {Ramp::make(index * kVectorSize, 1, kVectorSize)}, Broadcast::make(1, kVectorSize)); ExprHandle value = load_a + load_b; - Stmt* store_c = Store::make( - c_buf, + Stmt* store_c = c_buf.storeWithMask( {Ramp::make(index * kVectorSize, 1, kVectorSize)}, value, Broadcast::make(1, kVectorSize)); @@ -230,38 +209,33 @@ void testExprVectorAdd01() { b_v(i) = i * i * 4; c_ref(i) = a_v(i) + b_v(i); } - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); ir_eval(a_v, b_v, c_v); ExpectAllNear(c_v, c_ref, 1e-5); } -void testExprCompareSelectEQ() { +TEST(Expr, CompareSelectEQ) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kInt)); - Buffer b(BufHandle("B", {N}, kInt)); - Buffer c(BufHandle("C", {N}, kInt)); + Placeholder a(BufHandle("A", {N}, kInt)); + Placeholder b(BufHandle("B", {N}, kInt)); + Placeholder c(BufHandle("C", {N}, kInt)); std::vector a_buffer(N, 1); std::vector b_buffer(N, 1); std::vector c_buffer(N, 0); std::vector c_ref(N, 0); - auto mask = IntImm::make(1); VarHandle i("i", kInt); auto memcpy_expr = For::make( i, 0, N, - Store::make( - c, + c.store( {i}, CompareSelect::make( - Load::make(a, {i}, mask), - Load::make(b, {i}, mask), - CompareSelectOperation::kEQ), - mask)); + a.load(i), b.load(i), CompareSelectOperation::kEQ))); - SimpleIREvaluator ir_eval(memcpy_expr, a, b, c); + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); ir_eval(a_buffer, b_buffer, c_buffer); ASSERT_EQ(a_buffer.size(), N); @@ -273,7 +247,7 @@ void testExprCompareSelectEQ() { assertAllEqual(c_buffer, 1); } -void testExprCompareSelectDtypes() { +TEST(Expr, CompareSelectDtypes) { // LHS and RHS expressions should have the same dtype, but this dtype could // differ from the dtype of the return values (but dtypes of true and false // return values should be the same). @@ -282,15 +256,14 @@ void testExprCompareSelectDtypes() { // result = ((int)lhs == (int)rhs) ? (float)retval1 : (float)retval2 KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kInt)); - Buffer b(BufHandle("B", {N}, kInt)); - Buffer c(BufHandle("C", {N}, kFloat)); + Placeholder a(BufHandle("A", {N}, kInt)); + Placeholder b(BufHandle("B", {N}, kInt)); + Placeholder c(BufHandle("C", {N}, kFloat)); std::vector a_buffer(N, 1); std::vector b_buffer(N, 1); std::vector c_buffer(N, 0.0f); std::vector c_ref(N, 3.14f); - auto mask = IntImm::make(1); VarHandle i("i", kInt); // C[i] = (A[i] == B[i]) ? 3.14f : 2.78f // A and B are int, C is float. @@ -298,18 +271,16 @@ void testExprCompareSelectDtypes() { i, 0, N, - Store::make( - c, + c.store( {i}, CompareSelect::make( - Load::make(a, {i}, mask), - Load::make(b, {i}, mask), + a.load(i), + b.load(i), FloatImm::make(3.14f), FloatImm::make(2.78f), - CompareSelectOperation::kEQ), - mask)); + CompareSelectOperation::kEQ))); - SimpleIREvaluator ir_eval(select_expr, a, b, c); + SimpleIREvaluator ir_eval(select_expr, {a, b, c}); ir_eval(a_buffer, b_buffer, c_buffer); ASSERT_EQ(a_buffer.size(), N); @@ -321,21 +292,19 @@ void testExprCompareSelectDtypes() { ExpectAllNear(c_buffer, c_ref, 1e-7); } -void testExprIntrinsicsDtypes() { +TEST(Expr, IntrinsicsDtypes) { KernelScope kernel_scope; constexpr int N = 256; - Buffer a(BufHandle("A", {N}, kDouble)); - Buffer b(BufHandle("B", {N}, kDouble)); + Placeholder a(BufHandle("A", {N}, kDouble)); + Placeholder b(BufHandle("B", {N}, kDouble)); std::vector a_buffer(N, -10.0); std::vector b_buffer(N, 0.0); std::vector b_ref(N, 10.0); - auto mask = IntImm::make(1); VarHandle i("i", kInt); - auto fabs_expr = For::make( - i, 0, N, Store::make(b, {i}, fabs(Load::make(a, {i}, mask)), mask)); + auto abs_expr = For::make(i, 0, N, b.store({i}, tensorexpr::abs(a.load(i)))); - SimpleIREvaluator ir_eval(fabs_expr, a, b); + SimpleIREvaluator ir_eval(abs_expr, {a, b}); ir_eval(a_buffer, b_buffer); ASSERT_EQ(a_buffer.size(), N); @@ -345,7 +314,7 @@ void testExprIntrinsicsDtypes() { ExpectAllNear(b_buffer, b_ref, 1e-7); } -void testExprSubstitute01() { +TEST(Expr, Substitute01) { KernelScope kernel_scope; const Var* x = new Var("x", kFloat); const Var* y = new Var("y", kFloat); @@ -366,7 +335,7 @@ void testExprSubstitute01() { ASSERT_EQ(e2_str, e2_ref_str); } -void testExprMath01() { +TEST(Expr, Math01) { KernelScope kernel_scope; ExprHandle v = sin(ExprHandle(1.0f)); @@ -380,7 +349,7 @@ void testExprMath01() { ASSERT_NEAR(res, v_ref, 1e-6); } -void testExprUnaryMath01() { +TEST(Expr, UnaryMath01) { KernelScope kernel_scope; struct TestConfig { std::function func; @@ -408,7 +377,7 @@ void testExprUnaryMath01() { [](float v) { return std::tanh(v); }}, {[](const ExprHandle& v) { return exp(v); }, [](float v) { return std::exp(v); }}, - {[](const ExprHandle& v) { return fabs(v); }, + {[](const ExprHandle& v) { return tensorexpr::abs(v); }, [](float v) { return std::fabs(v); }}, {[](const ExprHandle& v) { return log(v); }, [](float v) { return std::log(v); }}, @@ -439,9 +408,15 @@ void testExprUnaryMath01() { SimpleIRExprEval eval(v); ASSERT_NEAR(eval.value(), v_ref, 1e-6); } + + for (float input_v : {std::nan("1"), 0., .5}) { + ExprHandle v = FloatImm::make(input_v); + SimpleIRExprEval eval(Intrinsics::make(kIsNan, v)); + ASSERT_NEAR(eval.value(), std::isnan(input_v), 0); + } } -void testExprBinaryMath01() { +TEST(Expr, BinaryMath01) { KernelScope kernel_scope; struct TestConfig { std::function func; @@ -465,7 +440,7 @@ void testExprBinaryMath01() { } } -void testExprBitwiseOps() { +TEST(Expr, BitwiseOps) { KernelScope kernel_scope; ExprHandle a(59); ExprHandle b(11); @@ -477,19 +452,19 @@ void testExprBitwiseOps() { ASSERT_EQ(eval.value(), 11); } -void testExprDynamicShapeAdd() { +TEST(Expr, DynamicShapeAdd) { KernelScope kernel_scope; auto testWithSize = [](int32_t size) { VarHandle n("n", kInt); - Buffer a(BufHandle("a", {n}, kFloat)); - Buffer b(BufHandle("b", {n}, kFloat)); - Buffer c(BufHandle("c", {n}, kFloat)); + Placeholder a(BufHandle("a", {n}, kFloat)); + Placeholder b(BufHandle("b", {n}, kFloat)); + Placeholder c(BufHandle("c", {n}, kFloat)); VarHandle i("i", kInt); - Stmt* s = For::make(i, 0, n, Store::make(c, {i}, a(i) + b(i), 1)); + Stmt* s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); std::vector aData(size, 1.0f); std::vector bData(size, 2.0f); std::vector cData(size, 0.0f); - SimpleIREvaluator(s, a, b, c, n)(aData, bData, cData, size); + SimpleIREvaluator(s, {a, b, c, n})(aData, bData, cData, size); ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); }; testWithSize(1); @@ -501,16 +476,14 @@ void testCond01() { KernelScope kernel_scope; const int N = 16; PaddedBuffer a_v(N); - Buffer a_buf("a", kFloat, {N}); + Placeholder a_buf("a", kFloat, {N}); VarHandle index = VarHandle("index", kInt); - Stmt* assign_x2 = - Store::make(BufHandle(a_buf.data()), {index}, cast(index) * 2, 1); - Stmt* assign_x3 = - Store::make(BufHandle(a_buf.data()), {index}, cast(index) * 3, 1); + Stmt* assign_x2 = a_buf.store({index}, cast(index) * 2); + Stmt* assign_x3 = a_buf.store({index}, cast(index) * 3); ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ); Stmt* assign = Cond::make(even_cond, assign_x2, assign_x3); Stmt* for_stmt = For::make(index, 0, N, assign); - SimpleIREvaluator(for_stmt, a_buf)(a_v); + SimpleIREvaluator(for_stmt, {a_buf})(a_v); PaddedBuffer a_ref(N); for (int i = 0; i < N; i++) { @@ -564,31 +537,31 @@ void testStmtClone() { KernelScope kernel_scope; const int N = 16; - Buffer a_buf("a", kInt, {N}); + Placeholder a_buf("a", kInt, {N}); VarHandle index = VarHandle("index", kInt); - Stmt* body = Store::make(BufHandle(a_buf.data()), {index}, 5, 1); + Stmt* body = a_buf.store({index}, 5); Stmt* loop = For::make(index, 0, N, body); Stmt* cloned_loop = Stmt::clone(loop); std::vector orig_loop_results(N); std::vector cloned_loop_results(N); - SimpleIREvaluator(loop, a_buf)(orig_loop_results); - SimpleIREvaluator(cloned_loop, a_buf)(cloned_loop_results); + SimpleIREvaluator(loop, {a_buf})(orig_loop_results); + SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results); assertAllEqual(orig_loop_results, 5); assertAllEqual(cloned_loop_results, 5); // Let's add another assign to the body in the cloned loop and verify that the // original statement hasn't changed while the cloned one has. - Stmt* body_addition = Store::make(BufHandle(a_buf.data()), {index}, 33, 1); + Stmt* body_addition = a_buf.store({index}, 33); Block* cloned_body = static_cast(static_cast(cloned_loop)->body()); cloned_body->append_stmt(body_addition); std::vector orig_loop_results_after_mutation(N); std::vector cloned_loop_results_after_mutation(N); - SimpleIREvaluator(loop, a_buf)(orig_loop_results_after_mutation); - SimpleIREvaluator(cloned_loop, a_buf)(cloned_loop_results_after_mutation); + SimpleIREvaluator(loop, {a_buf})(orig_loop_results_after_mutation); + SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results_after_mutation); assertAllEqual(orig_loop_results_after_mutation, 5); assertAllEqual(cloned_loop_results_after_mutation, 33); diff --git a/test/cpp/tensorexpr/test_ir_printer.cpp b/test/cpp/tensorexpr/test_ir_printer.cpp index 4f8e9c6125372..2600df6cdbcff 100644 --- a/test/cpp/tensorexpr/test_ir_printer.cpp +++ b/test/cpp/tensorexpr/test_ir_printer.cpp @@ -1,3 +1,5 @@ +#include + #include #include "test/cpp/tensorexpr/test_base.h" @@ -14,7 +16,7 @@ namespace jit { using namespace torch::jit::tensorexpr; -void testIRPrinterBasicValueTest() { +TEST(IRPrinter, BasicValueTest) { KernelScope kernel_scope; ExprHandle a = IntImm::make(2), b = IntImm::make(3); ExprHandle c = Add::make(a, b); @@ -24,7 +26,7 @@ void testIRPrinterBasicValueTest() { ASSERT_EQ(ss.str(), "2 + 3"); } -void testIRPrinterBasicValueTest02() { +TEST(IRPrinter, BasicValueTest02) { KernelScope kernel_scope; ExprHandle a(2.0f); ExprHandle b(3.0f); @@ -37,7 +39,7 @@ void testIRPrinterBasicValueTest02() { ASSERT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)"); } -void testIRPrinterCastTest() { +TEST(IRPrinter, CastTest) { KernelScope kernel_scope; VarHandle x("x", kHalf); VarHandle y("y", kFloat); @@ -49,7 +51,7 @@ void testIRPrinterCastTest() { ASSERT_EQ(ss.str(), "2.f + (float(x) * 3.f + 4.f * y)"); } -void testIRPrinterFunctionName() { +TEST(IRPrinter, FunctionName) { KernelScope kernel_scope; int M = 4; int N = 20; diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index ab916d370e828..902c2a7011973 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -1,8 +1,9 @@ +#include + #include #include #include #include -#include #include #include #include @@ -18,14 +19,73 @@ namespace jit { using namespace torch::indexing; using namespace torch::jit::tensorexpr; -void testKernel_1() { +TEST(Kernel, InliningIntermediates) { + // here, each mul has only one use, so it should be completely inlined + { + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[3, 1], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %one : int = prim::Constant[value=1]() + %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) + %5: Float(5, 3, strides=[3, 1]) = aten::add(%4, %1, %one) + return (%5))IR"; + KernelScope kernel_scope; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + auto stmt = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *stmt; + torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str()); + } + { + const auto graph_template = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=${device}), + %1 : Float(5, 3, strides=[3, 1], device=${device})): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %one : int = prim::Constant[value=1]() + %3 : Float(5, 3, strides=[3, 1]) = aten::sub(%0, %2, %one) + %4 : Float(5, 3, strides=[3, 1]) = aten::add(%3, %0, %one) + %5 : Float(5, 3, strides=[3, 1]) = aten::div(%3, %0) + return (%4, %5))IR"; + for (bool use_cuda : {false, true}) { + if (!torch::cuda::is_available() && use_cuda) { + continue; + } + + KernelScope kernel_scope; + TemplateEnv env; + env.s("device", use_cuda ? "cuda:0" : "cpu"); + const auto graph_string = format(graph_template, env); + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + auto device = use_cuda ? kCUDA : kCPU; + TensorExprKernel k(graph); + auto stmt = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *stmt; + // aten_mul only has one use, inlined completely + torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str()); + + // aten_sub should be removed in cuda, exist in cpu + // 5 uses: allocate, initialize, free and two reads + size_t num_out1_uses = use_cuda ? 0 : 5; + torch::jit::testing::FileCheck() + .check_count("aten_sub", num_out1_uses, /*exactly*/ true) + ->run(oss.str()); + } + } +} + +TEST(Kernel, _1) { KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%0 : Float(5:3,3:1, device=cpu), - %1 : Float(5:3,3:1, device=cpu)): - %2 : Float(5:3,3:1) = aten::mul(%0, %1) - %3 : Float(5:3,3:1) = aten::mul(%0, %2) + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[3, 1], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) return (%3))IR"; auto graph = std::make_shared(); parseIR(graph_string, &*graph); @@ -57,14 +117,14 @@ void testKernel_1() { } } -void testKernel_2() { +TEST(Kernel, _2) { KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%0 : Float(5:3,3:1, device=cpu), - %1 : Float(5:1,3:5, device=cpu)): - %2 : Float(5:3,3:1) = aten::mul(%0, %1) - %3 : Float(5:3,3:1) = aten::mul(%0, %2) + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[1, 5], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) return (%3))IR"; auto graph = std::make_shared(); parseIR(graph_string, &*graph); @@ -97,14 +157,14 @@ void testKernel_2() { } } -void testKernel_3() { +TEST(Kernel, _3) { KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%0 : Float(5:3,3:1, device=cpu), - %1 : Float(5:12,3:2, device=cpu)): - %2 : Float(5:3,3:1) = aten::mul(%0, %1) - %3 : Float(5:3,3:1) = aten::mul(%0, %2) + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[12, 2], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) return (%3))IR"; auto graph = std::make_shared(); parseIR(graph_string, &*graph); @@ -137,15 +197,17 @@ void testKernel_3() { } } -void testKernel_4() { +TEST(Kernel, DISABLED_Shape_Inference) { + // disabled: doesn't do stride propagation, and isn't being used currently + // Test TensorExpr shape inference capabilities: it should only require shapes // for the inputs { KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%0 : Float(5:3, 3:1, device=cpu), - %1 : Float(5:12, 3:2, device=cpu)): + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[12, 2], device=cpu)): %2 : Tensor = aten::mul(%0, %1) %3 : Tensor = aten::mul(%0, %2) return (%3))IR"; @@ -183,8 +245,8 @@ void testKernel_4() { KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%0 : Float(8:8, 8:1, device=cpu), - %1 : Float(8:8, 8:1, device=cpu)): + graph(%0 : Float(8, 8, strides=[8, 1], device=cpu), + %1 : Float(8, 8, strides=[8, 1], device=cpu)): %2 : Tensor = aten::mul(%0, %1) %3 : Tensor, %4 : Tensor = prim::ConstantChunk[dim=1,chunks=2](%2) %r : Tensor = aten::mul(%3, %4) @@ -224,9 +286,9 @@ void testKernel_4() { KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%a : Float(4:2, 2:1, device=cpu), - %b : Float(4:6, 3:2, 2:1, device=cpu), - %c : Float(3:4, 2:2, 2:1, device=cpu)): + graph(%a : Float(4, 2, strides=[2, 1], device=cpu), + %b : Float(4, 3, 2, strides=[6, 2, 1], device=cpu), + %c : Float(3, 2, 2, strides=[4, 2, 1], device=cpu)): %one : int = prim::Constant[value=1]() %minus_one : int = prim::Constant[value=-1]() %three : int = prim::Constant[value=3]() @@ -287,9 +349,9 @@ void testKernel_4() { KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%a : Float(5:6, 3:2, 2:1, device=cpu), - %b : Float(5:14, 7:2, 2:1, device=cpu), - %c : Float(5:18, 9:2, 2:1, device=cpu)): + graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), + %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), + %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)): %dim : int = prim::Constant[value=1]() %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) %r : Tensor = aten::cat(%inputs, %dim) # new size: [5,19,2] @@ -336,6 +398,109 @@ void testKernel_4() { CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } } + { + // Test that we throw an error when input list for aten::cat is empty + KernelScope kernel_scope; + + const auto graph_string = R"IR( + graph(): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct() + %r : Tensor = aten::cat(%inputs, %dim) + return (%r))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + auto compile = [&]() { + TensorExprKernel k(graph); + k.getCodeGenStmt(); + }; + ASSERT_THROWS_WITH(compile(), "Empty input list is passed to aten::cat"); + } + { + // Test that we throw an error when 'dim' passed to aten::cat is invalid + KernelScope kernel_scope; + + const auto ir_dim_99 = R"IR( + graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), + %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)): + %dim : int = prim::Constant[value=99]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b) + %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim) + return (%r))IR"; + const auto ir_dim_minus_6 = R"IR( + graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), + %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)): + %dim : int = prim::Constant[value=-6]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b) + %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim) + return (%r))IR"; + + auto compile = [](const std::string& graph_string) { + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + k.getCodeGenStmt(); + }; + ASSERT_THROWS_WITH(compile(ir_dim_99), "Invalid index"); + ASSERT_THROWS_WITH(compile(ir_dim_minus_6), "Invalid index"); + } +} + +TEST(Kernel, CatInputTypesPromotion) { + { + // Test that we properly promote input types for aten::cat + KernelScope kernel_scope; + + const auto graph_string = R"IR( + graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), + %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), + %c : Double(5, 9, 2, strides=[18, 2, 1], device=cpu)): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) + %r : Double(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim) + return (%r))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kDouble)); + auto ref = at::cat({a, b, c}, 1); + + TensorExprKernel k(graph); + std::vector inputs = {a, b, c}; + Stmt* s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NEXT: for +# CHECK-NEXT: aten_cat)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + + // Check sizes + CHECK_EQ(o.sizes().size(), ref.sizes().size()); + CHECK_EQ(o.dtype(), ref.dtype()); + size_t num_el = 1; + for (size_t idx = 0; idx < ref.sizes().size(); idx++) { + CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); + num_el *= ref.sizes()[idx]; + } + + // Check the contents + for (size_t i = 0; i < num_el; i++) { + CHECK_EQ(((double*)o.data_ptr())[i], ((double*)ref.data_ptr())[i]); + } + } } namespace { @@ -361,10 +526,15 @@ at::Tensor iotaTensor(IntArrayRef sizes, const at::TensorOptions& options) { } // namespace -void testKernelSumAllAxes() { +TEST(Kernel, DISABLED_SumAllAxes) { + // [zero-dim tensors] + // NNC does not yet handle zero-dim tensors. aten::sum with no axis + // input returns a zero-dim tensors, so these tests must be disabled + // until we add support for zero-dim tensors. + // Test lowering of sum on all axes. const auto graph_template = R"IR( - graph(%0 : Float(5:3,3:1, device=cpu)): + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): %1 : ${dtype} %2 : Tensor = aten::sum(%0, %1) return (%2))IR"; @@ -408,14 +578,27 @@ void testKernelSumAllAxes() { } } -void testKernelSumOneAxis() { +std::string li_to_str(at::ArrayRef li) { + std::stringstream out; + bool first = true; + for (auto elem : li) { + if (!first) { + out << ", "; + } + out << elem; + first = false; + } + return out.str(); +} + +TEST(Kernel, SumOneAxis) { // Test lowering of sum on one axis. const auto graph_template = R"IR( - graph(%0 : Float(5:3,3:1, device=cpu)): + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): %1 : int[] = prim::Constant[value=[${dim}]]() %2 : bool = prim::Constant[value=${keepdim}]() %3 : ${dtype} - %4 : Tensor = aten::sum(%0, %1, %2, %3) + %4 : ${out_dtype}(${size}, strides=[${strides}], device=cpu) = aten::sum(%0, %1, %2, %3) return (%4))IR"; auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); @@ -427,17 +610,23 @@ void testKernelSumOneAxis() { env.d("dim", dim); env.d("keepdim", keepdim); env.s("dtype", dtypeConstant(scalar_type)); - const auto graph_string = format(graph_template, env); - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto o = at::empty({}, TensorOptions(kCPU)); c10::optional dtype; if (scalar_type != ScalarType::None) { dtype = static_cast(scalar_type); } auto ref = a.sum({dim}, /*keepdim=*/keepdim, /*dtype=*/dtype); + if (scalar_type == ScalarType::None) { + env.s("out_dtype", "Float"); + } else { + env.s("out_dtype", "Double"); + } + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + const auto graph_string = format(graph_template, env); + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto o = at::empty({}, TensorOptions(kCPU)); TensorExprKernel k(graph); std::vector inputs = {a}; Stmt* s = k.getCodeGenStmt(); @@ -448,9 +637,10 @@ void testKernelSumOneAxis() { // Check the IR we produced const std::string& verification_pattern = R"IR( -# CHECK: int v = 0 -# CHECK: int v_1 = 0 -# CHECK: input1)IR"; +# CHECK: for (int v = 0; v < +# CHECK-NEXT: sum +# CHECK-NEXT: for (int v_1 = 0; v_1 < +# CHECK-NEXT: sum)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); std::vector stack = fmap(inputs); @@ -464,16 +654,16 @@ void testKernelSumOneAxis() { } } -void testKernelSumMultipleAxes() { +TEST(Kernel, SumMultipleAxes) { // Test lowering of sum on multiple axes. const auto graph_template = R"IR( - graph(%0 : Float(2:18,3:6,2:3,3:1, device=cpu)): + graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], device=cpu)): %1 : int = prim::Constant[value=${dim1}]() %2 : int = prim::Constant[value=${dim2}]() %3 : int[] = prim::ListConstruct(%1, %2) %4 : bool = prim::Constant[value=${keepdim}]() %5 : ${dtype} - %6 : Tensor = aten::sum(%0, %3, %4, %5) + %6 : Float(${size}, strides=[${strides}]) = aten::sum(%0, %3, %4, %5) return (%6))IR"; auto a = iotaTensor({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat)); @@ -488,13 +678,17 @@ void testKernelSumMultipleAxes() { env.d("dim2", dim2); env.d("keepdim", keepdim); env.s("dtype", dtypeConstant(ScalarType::None)); + auto o = at::empty({}, TensorOptions(kCPU)); + auto ref = a.sum(IntArrayRef{dim1, dim2}, /*keepdim=*/keepdim); + + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + const auto graph_string = format(graph_template, env); auto graph = std::make_shared(); parseIR(graph_string, &*graph); - auto o = at::empty({}, TensorOptions(kCPU)); - auto ref = a.sum(IntArrayRef{dim1, dim2}, /*keepdim=*/keepdim); TensorExprKernel k(graph); std::vector inputs = {a}; Stmt* s = k.getCodeGenStmt(); @@ -509,7 +703,7 @@ void testKernelSumMultipleAxes() { # CHECK: int v_1 = 0 # CHECK: int v_2 = 0 # CHECK: int v_3 = 0 -# CHECK: input1)IR"; +# CHECK: sum)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); std::vector stack = fmap(inputs); @@ -523,5 +717,323 @@ void testKernelSumMultipleAxes() { } } +// This test and the following ones testing Softmax only tests with dim set +// to one of the valid input dimensions. It does not test with dim=None +// because that is supposed to be deprecated. +TEST(Kernel, Softmax2D) { + const auto graph_template = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): + %1 : int = prim::Constant[value=${dim}]() + %2 : int = prim::Constant[value=7]() + %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2) + return (%3))IR"; + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + + const std::string& verification_template = + R"IR( + # CHECK: for (int i${other_dim} = 0; i${other_dim} < ${other_dim_size} + # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_max + # CHECK: for (int i${other_dim}_1 = 0; i${other_dim}_1 < ${other_dim_size} + # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_sum + # CHECK: for (int i0_2 = 0; i0_2 < 5 + # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3 + # CHECK-NEXT: aten_softmax)IR"; + + for (auto log_softmax : {false, true}) { + for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) { + auto softmax_dim_size = a.sizes()[softmax_dim]; + auto other_dim = (softmax_dim + 1) % a.dim(); + auto ref = + log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); + KernelScope kernel_scope; + TemplateEnv env; + env.d("dim", softmax_dim); + env.s("op", log_softmax ? "log_softmax" : "softmax"); + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + + const auto graph_string = format(graph_template, env); + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + std::vector inputs = {a}; + Stmt* s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + TemplateEnv ver_env; + ver_env.d("other_dim", other_dim); + ver_env.d("other_dim_size", a.sizes()[other_dim]); + ver_env.d("softmax_dim", softmax_dim); + ver_env.d("softmax_dim_size", softmax_dim_size); + const auto verification_pattern = format(verification_template, ver_env); + + // verication sting temporarily disabled until + // inlining of exp() is benchmarked and determined + // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + auto output = stack[0].toTensor(); + ASSERT_EQ(output.sizes(), ref.sizes()); + ASSERT_TRUE(at::allclose(output, ref)); + } + } +} + +TEST(Kernel, Softmax3D) { + const auto graph_template = R"IR( + graph(%0 : Float(3, 4, 5, strides=[20, 5, 1], device=cpu)): + %1 : int = prim::Constant[value=${dim}]() + %2 : int = prim::Constant[value=7]() + %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2) + return (%3))IR"; + + auto a = at::rand({3, 4, 5}, TensorOptions(kCPU).dtype(at::kFloat)); + + const std::string& verification_template = + R"IR( + # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size} + # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size} + # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_max + # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size} + # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size} + # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_sum + # CHECK: for (int i0_2 = 0; i0_2 < 3 + # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 4 + # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 5 + # CHECK-NEXT: aten_softmax)IR"; + + for (auto log_softmax : {false, true}) { + for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) { + auto softmax_dim_size = a.sizes()[softmax_dim]; + std::vector other_dims; + for (int i = 0; i < a.dim(); ++i) { + if (i != softmax_dim) { + other_dims.push_back(i); + } + } + auto ref = + log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); + + KernelScope kernel_scope; + TemplateEnv env; + env.d("dim", softmax_dim); + env.s("op", log_softmax ? "log_softmax" : "softmax"); + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + + const auto graph_string = format(graph_template, env); + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + std::vector inputs = {a}; + Stmt* s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + TemplateEnv ver_env; + ver_env.d("dim1", other_dims[0]); + ver_env.d("dim1_size", a.sizes()[other_dims[0]]); + ver_env.d("dim2", other_dims[1]); + ver_env.d("dim2_size", a.sizes()[other_dims[1]]); + ver_env.d("softmax_dim", softmax_dim); + ver_env.d("softmax_dim_size", softmax_dim_size); + const auto verification_pattern = format(verification_template, ver_env); + + // verication sting temporarily disabled until + // inlining of exp() is benchmarked and determined + // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + auto output = stack[0].toTensor(); + + ASSERT_EQ(output.sizes(), ref.sizes()); + ASSERT_TRUE(at::allclose(output, ref)); + } + } +} + +TEST(Kernel, Softmax4D) { + const auto graph_template = R"IR( + graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], device=cpu)): + %1 : int = prim::Constant[value=${dim}]() + %2 : int = prim::Constant[value=7]() + %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2) + return (%3))IR"; + + auto a = at::rand({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + + const std::string& verification_template = + R"IR( + # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size} + # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size} + # CHECK-NEXT: for (int i${dim3} = 0; i${dim3} < ${dim3_size} + # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_max + # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size} + # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size} + # CHECK-NEXT: for (int i${dim3}_1 = 0; i${dim3}_1 < ${dim3_size} + # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_sum + # CHECK: for (int i0_2 = 0; i0_2 < 2 + # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3 + # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 2 + # CHECK-NEXT: for (int i3_2 = 0; i3_2 < 3 + # CHECK-NEXT: aten_softmax)IR"; + + for (auto log_softmax : {false, true}) { + for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) { + auto softmax_dim_size = a.sizes()[softmax_dim]; + std::vector other_dims; + for (int i = 0; i < a.dim(); ++i) { + if (i != softmax_dim) { + other_dims.push_back(i); + } + } + auto ref = + log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); + + KernelScope kernel_scope; + TemplateEnv env; + env.d("dim", softmax_dim); + env.s("op", log_softmax ? "log_softmax" : "softmax"); + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + + const auto graph_string = format(graph_template, env); + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + std::vector inputs = {a}; + Stmt* s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + TemplateEnv ver_env; + ver_env.d("dim1", other_dims[0]); + ver_env.d("dim1_size", a.sizes()[other_dims[0]]); + ver_env.d("dim2", other_dims[1]); + ver_env.d("dim2_size", a.sizes()[other_dims[1]]); + ver_env.d("dim3", other_dims[2]); + ver_env.d("dim3_size", a.sizes()[other_dims[2]]); + ver_env.d("softmax_dim", softmax_dim); + ver_env.d("softmax_dim_size", softmax_dim_size); + const auto verification_pattern = format(verification_template, ver_env); + + // verication sting temporarily disabled until + // inlining of exp() is benchmarked and determined + // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + auto output = stack[0].toTensor(); + ASSERT_EQ(output.sizes(), ref.sizes()); + ASSERT_TRUE(at::allclose(output, ref)); + } + } +} + +TEST(Kernel, DISABLED_InlineProducerIntoReduction) { + // see : [zero-dim tensors] + KernelScope kernel_scope; + + // Inline producer (mul) into reduction (sum). + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[3, 1], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : int = prim::Constant[value=7]() + %4 : Float(5, 3, strides=[3, 1]) = aten::sum(%2, %3) + return (%4))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + Stmt* s = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *s; + + // Check the IR we produced. + // We should have only one loop in the end. + const std::string& verification_pattern = + R"IR( + # CHECK: for (int v = 0; v < 5; + # CHECK-NEXT: for (int v_1 = 0; v_1 < 3; + # CHECK-NEXT: sum + # CHECK-NOT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + std::vector inputs = {a, b}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto ref = (a * b).sum(at::kDouble); + ASSERT_TRUE(at::allclose(o, ref)); +} + +TEST(Kernel, DISABLED_InlineReductionIntoConsumer) { + // see : [zero-dim tensors] + + KernelScope kernel_scope; + + // Inline producer (mul %2) into reduction (sum %4) but DO NOT + // inline the reduction into consumer (mul %4). + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[3, 1], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : int = prim::Constant[value=6]() + %4 : Float(5, 3, strides=[3, 1]) = aten::sum(%2, %3) + %5 : Float(5, 3, strides=[3, 1]) = aten::mul(%2, %4) + return (%5))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + Stmt* s = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *s; + + // Check the IR we produced. + // We should have two loops in the end. + const std::string& verification_pattern = + R"IR( + # CHECK: for (int v = 0; v < 5; + # CHECK-NEXT: for (int v_1 = 0; v_1 < 3; + # CHECK-NEXT: sum + # CHECK: for (int v_2 = 0; v_2 < 5; + # CHECK-NEXT: for (int v_3 = 0; v_3 < 3; + # CHECK-NEXT: aten_mul + # CHECK-NOT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + std::vector inputs = {a, b}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto ref = (a * b).sum(at::kFloat) * (a * b); + ASSERT_TRUE(at::allclose(o, ref)); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp index ee8540eb58c45..7afb839dc7e09 100644 --- a/test/cpp/tensorexpr/test_llvm.cpp +++ b/test/cpp/tensorexpr/test_llvm.cpp @@ -1,11 +1,11 @@ #ifdef TORCH_ENABLE_LLVM +#include + #include "test/cpp/tensorexpr/test_base.h" #include "test/cpp/tensorexpr/padded_buffer.h" #include "test/cpp/tensorexpr/test_utils.h" -#include "torch/csrc/jit/tensorexpr/buffer.h" #include "torch/csrc/jit/tensorexpr/eval.h" -#include "torch/csrc/jit/tensorexpr/function.h" #include "torch/csrc/jit/tensorexpr/ir.h" #include "torch/csrc/jit/tensorexpr/ir_printer.h" #include "torch/csrc/jit/tensorexpr/ir_simplifier.h" @@ -35,7 +35,7 @@ using LLVMExprEval = ExprEval; _(at::Half, Half, 0.128f) #define IMM_TEST(Type, Name, Val) \ - void testLLVM##Name##ImmTest() { \ + TEST(LLVM, Name##ImmTest) { \ KernelScope kernel_scope; \ auto a = Name##Imm::make(Val); \ LLVMExprEval cg(a); \ @@ -49,7 +49,7 @@ TEST_LLVM_SCALAR_TYPES(IMM_TEST) #undef IMM_TEST #define ADD_TEST(Type, Name, Val) \ - void testLLVM##Name##AddTest() { \ + TEST(LLVM, Name##AddTest) { \ KernelScope kernel_scope; \ auto a = Name##Imm::make(Val); \ auto b = Name##Imm::make(Val * 2); \ @@ -65,7 +65,7 @@ TEST_LLVM_SCALAR_TYPES(ADD_TEST) #undef ADD_TEST #define SUB_TEST(Type, Name, Val) \ - void testLLVM##Name##SubTest() { \ + TEST(LLVM, Name##SubTest) { \ KernelScope kernel_scope; \ auto a = Name##Imm::make(Val * 2); \ auto b = Name##Imm::make(Val); \ @@ -81,7 +81,7 @@ TEST_LLVM_SCALAR_TYPES(SUB_TEST) #undef SUB_TEST #define MUL_TEST(Type, Name, Val) \ - void testLLVM##Name##MulTest() { \ + TEST(LLVM, Name##MulTest) { \ KernelScope kernel_scope; \ auto a = Name##Imm::make(Val); \ auto b = Name##Imm::make((Type)4); \ @@ -97,7 +97,7 @@ TEST_LLVM_SCALAR_TYPES(MUL_TEST) #undef MUL_TEST #define DIV_TEST(Type, Name, Val) \ - void testLLVM##Name##DivTest() { \ + TEST(LLVM, Name##DivTest) { \ KernelScope kernel_scope; \ auto a = Name##Imm::make((Type)6); \ auto b = Name##Imm::make((Type)3); \ @@ -112,7 +112,7 @@ TEST_LLVM_SCALAR_TYPES(MUL_TEST) TEST_LLVM_SCALAR_TYPES(DIV_TEST) #undef DIV_TEST -void testLLVMIntToFloatCastTest() { +TEST(LLVM, IntToFloatCastTest) { KernelScope kernel_scope; auto a = IntImm::make(2); auto b = Cast::make(kFloat, a); @@ -120,7 +120,7 @@ void testLLVMIntToFloatCastTest() { ASSERT_EQ(cg.value(), 2.0); } -void testLLVMFloatToIntCastTest() { +TEST(LLVM, FloatToIntCastTest) { KernelScope kernel_scope; auto a = FloatImm::make(2.0); auto b = Cast::make(kInt, a); @@ -128,7 +128,7 @@ void testLLVMFloatToIntCastTest() { ASSERT_EQ(cg.value(), 2); } -void testLLVMIntToLongCastTest() { +TEST(LLVM, IntToLongCastTest) { KernelScope kernel_scope; auto a = IntImm::make(12345); auto b = Cast::make(kLong, a); @@ -136,7 +136,7 @@ void testLLVMIntToLongCastTest() { ASSERT_EQ(cg.value(), 12345); } -void testLLVMByteToCharCastTest() { +TEST(LLVM, ByteToCharCastTest) { KernelScope kernel_scope; auto a = ByteImm::make(250); auto b = Cast::make(kChar, a); @@ -144,7 +144,7 @@ void testLLVMByteToCharCastTest() { ASSERT_EQ(cg.value(), (int8_t)250); } -void testLLVMHalfToLongCastTest() { +TEST(LLVM, HalfToLongCastTest) { KernelScope kernel_scope; auto a = HalfImm::make(2.0); auto b = Cast::make(kLong, a); @@ -152,7 +152,7 @@ void testLLVMHalfToLongCastTest() { ASSERT_EQ(cg.value(), 2); } -void testLLVMByteToDoubleCastTest() { +TEST(LLVM, ByteToDoubleCastTest) { KernelScope kernel_scope; auto a = ByteImm::make(2); auto b = Cast::make(kDouble, a); @@ -160,20 +160,105 @@ void testLLVMByteToDoubleCastTest() { ASSERT_EQ(cg.value(), 2); } -void testLLVMLetTest01() { +TEST(LLVM, BitCast) { + constexpr int16_t ref16 = 1337; + constexpr int32_t ref32 = 1337; + constexpr int64_t ref64 = 1337; + at::Half reff16 = 1337.0f; + constexpr float reff32 = 1337.0f; + constexpr double reff64 = 1337.0f; + + // this is broken + /*{ + KernelScope kernel_scope; + at::Half k_; + at::Half* k = &k_; + *reinterpret_cast(k) = ref16; + auto a = HalfImm::make(k); + auto b = BitCast::make(kShort, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), ref16); + }*/ + + { + KernelScope kernel_scope; + float k = raw_bitcast(ref32); + auto a = FloatImm::make(k); + auto b = BitCast::make(kInt, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), ref32); + } + + { + KernelScope kernel_scope; + double k = raw_bitcast(ref64); + auto a = DoubleImm::make(k); + auto b = BitCast::make(kLong, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), ref64); + } + + { + KernelScope kernel_scope; + int64_t k = raw_bitcast(reff64); + auto a = LongImm::make(k); + auto b = BitCast::make(kDouble, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), reff64); + } + + { + KernelScope kernel_scope; + int32_t k = raw_bitcast(reff32); + auto a = IntImm::make(k); + auto b = BitCast::make(kFloat, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), reff32); + } +} + +TEST(LLVM, fastLogFloat) { + KernelScope kernel_scope; + const int kTotalSize = 128 * 128; + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + Stmt* store_b = b_buf.store({index}, fast_log(load_a)); + Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = at::randn({1}).item().to(); + } + + LLVMCodeGen ir_eval(stmt, {a_buf, b_buf}); + ir_eval.call({a_v, b_v}); + + for (int i = 0; i < kTotalSize; ++i) { + auto test = b_v(i); + auto ref = std::log(a_v(i)); + if (std::isnan(ref)) { + ASSERT_EQ(std::isnan(test), true); + } else { + ASSERT_FLOAT_EQ(test, ref); + } + } +} + +TEST(LLVM, LetTest01) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {1}, kFloat)); + Placeholder a(BufHandle("A", {1}, kFloat)); std::vector v = {1, 0}; std::vector args({v.data()}); VarHandle x("x", kFloat); auto block = Block::make({ Let::make(x, 3.f), - Store::make( - a, - {IntImm::make(0)}, - ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)), - IntImm::make(1)), + a.store({0}, ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f))), }); LLVMCodeGen cg(block, {a}); @@ -181,57 +266,52 @@ void testLLVMLetTest01() { ASSERT_EQ(v[0], 2.f + 3.f * 3.f + 4.f); } -void testLLVMLetTest02() { +TEST(LLVM, LetTest02) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {1}, kFloat)); + Placeholder a(BufHandle("A", {1}, kFloat)); std::vector v = {1, 0}; std::vector args({v.data()}); VarHandle x("x", kFloat); VarHandle y("y", kFloat); - auto block = Block::make({ - Let::make(x, 3.f), - Let::make(y, 6.f), - Store::make( - a, - {IntImm::make(0)}, - ExprHandle(2.f) + (x * ExprHandle(3.f) + y * ExprHandle(4.f)), - IntImm::make(1)), - }); + auto block = Block::make( + {Let::make(x, 3.f), + Let::make(y, 6.f), + a.store( + {IntImm::make(0)}, + ExprHandle(2.f) + (x * ExprHandle(3.f) + y * ExprHandle(4.f)))}); LLVMCodeGen cg(block, {a}); ASSERT_EQ(cg.value(args), 0); ASSERT_EQ(v[0], 2.f + 3.f * 3.f + 6.f * 4.f); } -void testLLVMLetTestMultitype() { +TEST(LLVM, LetTestMultitype) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {1}, kDouble)); + Placeholder a(BufHandle("A", {1}, kDouble)); std::vector v = {1, 0}; std::vector args({v.data()}); VarHandle x("x", kByte); VarHandle y("y", kHalf); - auto block = Block::make({ - Let::make(x, 3), - Let::make(y, 6.f), - Store::make( - a, - {IntImm::make(0)}, - Cast::make( - kDouble, - ExprHandle(2.f) + (x * ExprHandle(3.f) + y * ExprHandle(4.f))), - IntImm::make(1)), - }); + auto block = Block::make( + {Let::make(x, 3), + Let::make(y, 6.f), + a.store( + {0}, + Cast::make( + kDouble, + ExprHandle(2.f) + + (x * ExprHandle(3.f) + y * ExprHandle(4.f))))}); LLVMCodeGen cg(block, {a}); ASSERT_EQ(cg.value(args), 0); ASSERT_EQ(v[0], 2.f + 3 * 3.f + 6.f * 4.f); } -void testLLVMBufferTest() { +TEST(LLVM, BufferTest) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {32}, kFloat)); + Placeholder a(BufHandle("A", {32}, kFloat)); std::vector v(5); std::vector args({v.data()}); auto rv = IntImm::make(0); @@ -239,16 +319,16 @@ void testLLVMBufferTest() { ASSERT_EQ(cg.value(args), 0); } -void testLLVMBlockTest() { +TEST(LLVM, BlockTest) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {32}, kInt)); + Placeholder a(BufHandle("A", {32}, kInt)); std::vector v = {1, 2}; std::vector args({v.data()}); auto block = Block::make({ - Store::make(a, {IntImm::make(0)}, IntImm::make(3), IntImm::make(1)), - Store::make(a, {IntImm::make(1)}, IntImm::make(4), IntImm::make(1)), - Store::make(a, {IntImm::make(0)}, IntImm::make(4), IntImm::make(1)), + a.store({0}, 3), + a.store({1}, 4), + a.store({0}, 4), }); LLVMCodeGen cg(block, {a}); @@ -257,18 +337,14 @@ void testLLVMBlockTest() { ASSERT_EQ(v[1], 4); } -void testLLVMLoadStoreTest() { +TEST(LLVM, LoadStoreTest) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {1}, kInt)); - Buffer b(BufHandle("B", {1}, kInt)); + Placeholder a(BufHandle("A", {1}, kInt)); + Placeholder b(BufHandle("B", {1}, kInt)); std::vector a_buffer = {42}; std::vector b_buffer = {-11}; - auto store = Store::make( - b, - {IntImm::make(0)}, - Load::make(a, {IntImm::make(0)}, IntImm::make(1)), - IntImm::make(1)); + auto store = b.store({0}, a.load(0)); LLVMCodeGen cg(store, {a, b}); std::vector args({a_buffer.data(), b_buffer.data()}); ASSERT_EQ(cg.value(args), 0); @@ -276,23 +352,16 @@ void testLLVMLoadStoreTest() { ASSERT_EQ(b_buffer[0], 42); } -void testLLVMIfThenElseTest() { +TEST(LLVM, IfThenElseTest) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {1}, kInt)); - Buffer b(BufHandle("B", {1}, kInt)); - Buffer c(BufHandle("C", {1}, kInt)); + Placeholder a(BufHandle("A", {1}, kInt)); + Placeholder b(BufHandle("B", {1}, kInt)); + Placeholder c(BufHandle("C", {1}, kInt)); std::vector a_buffer = {42}; std::vector b_buffer = {-11}; std::vector c_buffer = {1}; - auto store = Store::make( - b, - {IntImm::make(0)}, - IfThenElse::make( - Load::make(c, {IntImm::make(0)}, IntImm::make(1)), // cond - Load::make(a, {IntImm::make(0)}, IntImm::make(1)), // then - IntImm::make(0)), // else - IntImm::make(1)); + auto store = b.store({0}, IfThenElse::make(c.load(0), a.load(0), 0)); LLVMCodeGen cg(store, {a, b, c}); std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); ASSERT_EQ(cg.value(args), 0); @@ -300,17 +369,117 @@ void testLLVMIfThenElseTest() { ASSERT_EQ(b_buffer[0], 42); } -void testLLVMVecLoadStoreTest() { +// if (x < 10) x = x + 1 +TEST(LLVM, CondNoFalseBlockTest) { + KernelScope kernel_scope; + + Placeholder x(BufHandle("X", {1}, kInt)); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = Cond::make(cmp, x.store({0}, x.load(0) + 1), nullptr); + + for (int32_t x_value : {0, 10, 20}) { + std::vector x_buffer = {x_value}; + std::vector args({x_buffer.data()}); + LLVMCodeGen cg(cond, {x}); + ASSERT_EQ(cg.value(args), 0); + if (x_value < 10) { + ASSERT_EQ(x_buffer[0], x_value + 1); + } else { + ASSERT_EQ(x_buffer[0], x_value); + } + } +} + +// if (x < 10) { +// x = x + 1; +// } else { +// x = x - 1; +// } +TEST(LLVM, CondTest) { + KernelScope kernel_scope; + + Placeholder x(BufHandle("X", {1}, kInt)); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = + Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); + auto block = Block::make({ + cond, + x.store({0}, x.load(0) * 2), + }); + + for (int32_t x_value : {0, 10, 20}) { + std::vector x_buffer = {x_value}; + std::vector args({x_buffer.data()}); + LLVMCodeGen cg(block, {x}); + ASSERT_EQ(cg.value(args), 0); + if (x_value < 10) { + ASSERT_EQ(x_buffer[0], (x_value + 1) * 2); + } else { + ASSERT_EQ(x_buffer[0], (x_value - 1) * 2); + } + } +} + +// if (x < 10) { +// if (x > 5) { +// x = x + 1; +// } else { +// x = x - 1; +// } +// } else { +// if (x <= 15) { +// x = x + 2; +// } else { +// x = x - 2; +// } +// } +TEST(LLVM, CondNestedTest) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {1}, kInt)); - Buffer b(BufHandle("B", {1}, kInt)); + + Placeholder x(BufHandle("X", {1}, kInt)); + auto true_cmp = + CompareSelect::make(x.load(0), 5, CompareSelectOperation::kGT); + auto true_cond = Cond::make( + true_cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); + auto false_cmp = + CompareSelect::make(x.load(0), 15, CompareSelectOperation::kLE); + auto false_cond = Cond::make( + false_cmp, x.store({0}, x.load(0) + 2), x.store({0}, x.load(0) - 2)); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = Cond::make(cmp, true_cond, false_cond); + + for (int32_t x_value : {0, 8, 15, 20}) { + std::vector x_buffer = {x_value}; + std::vector args({x_buffer.data()}); + LLVMCodeGen cg(cond, {x}); + ASSERT_EQ(cg.value(args), 0); + if (x_value < 10) { + if (x_value > 5) { + ASSERT_EQ(x_buffer[0], x_value + 1); + } else { + ASSERT_EQ(x_buffer[0], x_value - 1); + } + } else { + if (x_value <= 15) { + ASSERT_EQ(x_buffer[0], x_value + 2); + } else { + ASSERT_EQ(x_buffer[0], x_value - 2); + } + } + } +} + +TEST(LLVM, VecLoadStoreTest) { + KernelScope kernel_scope; + Placeholder a(BufHandle("A", {1}, kInt)); + Placeholder b(BufHandle("B", {1}, kInt)); std::vector a_buffer = {1, 1, 1, 1}; std::vector b_buffer = {2, 2, 2, 2}; - auto store = Store::make( - b, + auto store = b.storeWithMask( {Ramp::make(0, 1, 4)}, - Load::make(a, {Ramp::make(0, 1, 4)}, Broadcast::make(IntImm::make(1), 4)), + a.loadWithMask( + {Ramp::make(0, 1, 4)}, Broadcast::make(IntImm::make(1), 4)), Broadcast::make(IntImm::make(1), 4)); LLVMCodeGen cg(store, {a, b}); std::vector args({a_buffer.data(), b_buffer.data()}); @@ -326,18 +495,16 @@ void testLLVMVecLoadStoreTest() { } #define FLOAT_INTRINSICS_TEST(Name, Lanes) \ - void testLLVMVecFloat_##Name##Lane##Lanes##Test() { \ + TEST(LLVM, VecFloat_##Name##Lane##Lanes##Test) { \ KernelScope kernel_scope; \ - Buffer a(BufHandle("A", {1}, kFloat)); \ - Buffer b(BufHandle("B", {1}, kFloat)); \ + Placeholder a(BufHandle("A", {1}, kFloat)); \ + Placeholder b(BufHandle("B", {1}, kFloat)); \ float val = 0.5f; \ std::vector a_buffer(Lanes, val); \ std::vector b_buffer(Lanes, val); \ - auto store = Store::make( \ - b, \ + auto store = b.storeWithMask( \ {Ramp::make(0, 1, Lanes)}, \ - Name(Load::make( \ - a, \ + Name(a.loadWithMask( \ {Ramp::make(0, 1, Lanes)}, \ Broadcast::make(IntImm::make(1), Lanes))), \ Broadcast::make(IntImm::make(1), Lanes)); \ @@ -371,18 +538,16 @@ FLOAT_INTRINSICS_TEST(lgamma, 8) #undef FLOAT_INTRINSICS_TEST #define DOUBLE_INTRINSICS_TEST(Name, Lanes) \ - void testLLVMVecDouble_##Name##Lane##Lanes##Test() { \ + TEST(LLVM, VecDouble_##Name##Lane##Lanes##Test) { \ KernelScope kernel_scope; \ - Buffer a(BufHandle("A", {1}, kDouble)); \ - Buffer b(BufHandle("B", {1}, kDouble)); \ + Placeholder a(BufHandle("A", {1}, kDouble)); \ + Placeholder b(BufHandle("B", {1}, kDouble)); \ float val = 0.5f; \ std::vector a_buffer(Lanes, val); \ std::vector b_buffer(Lanes, val); \ - auto store = Store::make( \ - b, \ + auto store = b.storeWithMask( \ {Ramp::make(0, 1, Lanes)}, \ - Name(Load::make( \ - a, \ + Name(a.loadWithMask( \ {Ramp::make(0, 1, Lanes)}, \ Broadcast::make(IntImm::make(1), Lanes))), \ Broadcast::make(IntImm::make(1), Lanes)); \ @@ -415,18 +580,17 @@ DOUBLE_INTRINSICS_TEST(expm1, 4) DOUBLE_INTRINSICS_TEST(lgamma, 4) #undef DOUBLE_INTRINSICS_TEST -void testLLVMVectorizerLoadStoreTest() { +TEST(LLVM, VectorizerLoadStoreTest) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {1}, kInt)); + Placeholder a(BufHandle("A", {1}, kInt)); - Tensor* c = Compute("c", {{4, "i"}}, [&](const VarHandle& i) { - return Load::make(a, {i}, 1); - }); + Tensor* c = + Compute("c", {{4, "i"}}, [&](const VarHandle& i) { return a.load(i); }); - Buffer c_buf(BufHandle(c->func_var())); + Placeholder c_buf(BufHandle(c->buf())); LoopNest l({c}); Stmt* s = l.root_stmt(); - l.vectorize(dynamic_cast(s)->front()); + l.vectorize(dynamic_cast(dynamic_cast(s)->front())); ASSERT_TRUE(dynamic_cast(dynamic_cast(s)->front()) == nullptr); @@ -439,18 +603,42 @@ void testLLVMVectorizerLoadStoreTest() { assertAllEqual(c_vec, 21); } -void testLLVMMemcpyTest() { +TEST(LLVM, VectorizeBitCast) { + KernelScope kernel_scope; + Placeholder a(BufHandle("A", {128}, kInt)); + + Tensor* c = Compute("c", {{128, "i"}}, [&](const VarHandle& i) { + return bitcast(a.load(i)); + }); + + Placeholder c_buf(BufHandle(c->buf())); + LoopNest l({c}); + Stmt* s = l.root_stmt(); + l.vectorize(dynamic_cast(dynamic_cast(s)->front())); + ASSERT_TRUE(dynamic_cast(dynamic_cast(s)->front()) == nullptr); + + LLVMCodeGen cg(s, {a, c_buf}); + + std::vector a_vec(128); + std::vector c_vec(128); + for (auto i = 0; i < 128; ++i) { + a_vec[i] = raw_bitcast(1337.f); + } + std::vector args({a_vec.data(), c_vec.data()}); + ASSERT_EQ(cg.value(args), 0); + assertAllEqual(c_vec, 1337.f); +} + +TEST(LLVM, MemcpyTest) { KernelScope kernel_scope; constexpr int N = 32; - Buffer a(BufHandle("A", {N}, kInt)); - Buffer b(BufHandle("B", {N}, kInt)); + Placeholder a(BufHandle("A", {N}, kInt)); + Placeholder b(BufHandle("B", {N}, kInt)); std::vector a_buffer(N, 42); std::vector b_buffer(N, 0); - auto mask = IntImm::make(1); VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, Store::make(b, {i}, Load::make(a, {i}, mask), mask)); + auto expr = For::make(i, 0, N, b.store({i}, a.load(i))); LLVMCodeGen cg(expr, {a, b}); @@ -463,15 +651,14 @@ void testLLVMMemcpyTest() { assertAllEqual(b_buffer, 42); } -void testLLVMBzeroTest() { +TEST(LLVM, BzeroTest) { KernelScope kernel_scope; constexpr int N = 32; - Buffer b(BufHandle("B", {N}, kInt)); + Placeholder b(BufHandle("B", {N}, kInt)); std::vector b_buffer(N, 11); - auto mask = IntImm::make(1); VarHandle i("i", kInt); - auto expr = For::make(i, 0, N, Store::make(b, {i}, IntImm::make(0), mask)); + auto expr = For::make(i, 0, N, b.store({i}, 0)); LLVMCodeGen cg(expr, {b}); @@ -482,27 +669,18 @@ void testLLVMBzeroTest() { assertAllEqual(b_buffer, 0); } -void testLLVMElemwiseAdd() { +TEST(LLVM, ElemwiseAdd) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kInt)); - Buffer b(BufHandle("B", {N}, kInt)); - Buffer c(BufHandle("C", {N}, kInt)); + Placeholder a(BufHandle("A", {N}, kInt)); + Placeholder b(BufHandle("B", {N}, kInt)); + Placeholder c(BufHandle("C", {N}, kInt)); std::vector a_buffer(N, 41); std::vector b_buffer(N, 1); std::vector c_buffer(N, 1); - auto mask = IntImm::make(1); VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - Store::make( - c, - {i}, - Add::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask)), - mask)); + auto expr = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i)))); LLVMCodeGen cg(expr, {a, b, c}); @@ -517,24 +695,18 @@ void testLLVMElemwiseAdd() { assertAllEqual(c_buffer, 42); } -void testLLVMElemwiseAddFloat() { +TEST(LLVM, ElemwiseAddFloat) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kFloat)); - Buffer b(BufHandle("B", {N}, kFloat)); - Buffer c(BufHandle("C", {N}, kFloat)); + Placeholder a(BufHandle("A", {N}, kFloat)); + Placeholder b(BufHandle("B", {N}, kFloat)); + Placeholder c(BufHandle("C", {N}, kFloat)); std::vector a_buffer(N, 41); std::vector b_buffer(N, 1); std::vector c_buffer(N, 1); - auto mask = IntImm::make(1); VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - Store::make( - c, {i}, Load::make(a, {i}, mask) + Load::make(b, {i}, mask), mask)); + auto expr = For::make(i, 0, N, c.store({i}, a.load(i) + b.load(i))); LLVMCodeGen cg(expr, {a, b, c}); @@ -549,11 +721,11 @@ void testLLVMElemwiseAddFloat() { assertAllEqual(c_buffer, 42.0f); } -void testLLVMElemwiseLog10Float() { +TEST(LLVM, ElemwiseLog10Float) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kFloat)); - Buffer b(BufHandle("B", {N}, kFloat)); + Placeholder a(BufHandle("A", {N}, kFloat)); + Placeholder b(BufHandle("B", {N}, kFloat)); std::vector a_buffer(N, 10.0f); std::vector b_buffer(N, 2.0f); @@ -563,10 +735,9 @@ void testLLVMElemwiseLog10Float() { i, 0, N / 4, - Store::make( - b, + b.storeWithMask( {Ramp::make(i * 4, 1, 4)}, - log10(Load::make(a, {Ramp::make(i * 4, 1, 4)}, mask)), + log10(a.loadWithMask({Ramp::make(i * 4, 1, 4)}, mask)), mask)); LLVMCodeGen cg(expr, {a, b}); @@ -580,11 +751,11 @@ void testLLVMElemwiseLog10Float() { assertAllEqual(b_buffer, 1.0f); } -void testLLVMElemwiseLog1pFloat() { +TEST(LLVM, ElemwiseLog1pFloat) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kFloat)); - Buffer b(BufHandle("B", {N}, kFloat)); + Placeholder a(BufHandle("A", {N}, kFloat)); + Placeholder b(BufHandle("B", {N}, kFloat)); std::vector a_buffer(N, expf(3.0f) - 1); std::vector b_buffer(N, 42.0f); @@ -594,10 +765,9 @@ void testLLVMElemwiseLog1pFloat() { i, 0, N / 4, - Store::make( - b, + b.storeWithMask( {Ramp::make(i * 4, 1, 4)}, - log1p(Load::make(a, {Ramp::make(i * 4, 1, 4)}, mask)), + log1p(a.loadWithMask({Ramp::make(i * 4, 1, 4)}, mask)), mask)); LLVMCodeGen cg(expr, {a, b}); @@ -611,27 +781,19 @@ void testLLVMElemwiseLog1pFloat() { ExpectAllNear(b_buffer, 3.0f, 1e-5f); } -void testLLVMElemwiseMaxInt() { +TEST(LLVM, ElemwiseMaxInt) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kInt)); - Buffer b(BufHandle("B", {N}, kInt)); - Buffer c(BufHandle("C", {N}, kInt)); + Placeholder a(BufHandle("A", {N}, kInt)); + Placeholder b(BufHandle("B", {N}, kInt)); + Placeholder c(BufHandle("C", {N}, kInt)); std::vector a_buffer(N, 41); std::vector b_buffer(N, 1); std::vector c_buffer(N, 1); - auto mask = IntImm::make(1); VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - Store::make( - c, - {i}, - Max::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false), - mask)); + auto expr = + For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); LLVMCodeGen cg(expr, {a, b, c}); @@ -646,27 +808,19 @@ void testLLVMElemwiseMaxInt() { assertAllEqual(c_buffer, 41); } -void testLLVMElemwiseMinInt() { +TEST(LLVM, ElemwiseMinInt) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kInt)); - Buffer b(BufHandle("B", {N}, kInt)); - Buffer c(BufHandle("C", {N}, kInt)); + Placeholder a(BufHandle("A", {N}, kInt)); + Placeholder b(BufHandle("B", {N}, kInt)); + Placeholder c(BufHandle("C", {N}, kInt)); std::vector a_buffer(N, 41); std::vector b_buffer(N, 1); std::vector c_buffer(N, 1); - auto mask = IntImm::make(1); VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - Store::make( - c, - {i}, - Min::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false), - mask)); + auto expr = + For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); LLVMCodeGen cg(expr, {a, b, c}); @@ -681,27 +835,19 @@ void testLLVMElemwiseMinInt() { assertAllEqual(c_buffer, 1); } -void testLLVMElemwiseMaxFloat() { +TEST(LLVM, ElemwiseMaxFloat) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kFloat)); - Buffer b(BufHandle("B", {N}, kFloat)); - Buffer c(BufHandle("C", {N}, kFloat)); + Placeholder a(BufHandle("A", {N}, kFloat)); + Placeholder b(BufHandle("B", {N}, kFloat)); + Placeholder c(BufHandle("C", {N}, kFloat)); std::vector a_buffer(N, 41); std::vector b_buffer(N, 1); std::vector c_buffer(N, 1); - auto mask = IntImm::make(1); VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - Store::make( - c, - {i}, - Max::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false), - mask)); + auto expr = + For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); LLVMCodeGen cg(expr, {a, b, c}); @@ -716,27 +862,19 @@ void testLLVMElemwiseMaxFloat() { assertAllEqual(c_buffer, 41.0f); } -void testLLVMElemwiseMaxNaNFloat() { +TEST(LLVM, ElemwiseMaxNaNFloat) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kFloat)); - Buffer b(BufHandle("B", {N}, kFloat)); - Buffer c(BufHandle("C", {N}, kFloat)); + Placeholder a(BufHandle("A", {N}, kFloat)); + Placeholder b(BufHandle("B", {N}, kFloat)); + Placeholder c(BufHandle("C", {N}, kFloat)); std::vector a_buffer(N, NAN); std::vector b_buffer(N, 1); std::vector c_buffer(N, 1); - auto mask = IntImm::make(1); VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - Store::make( - c, - {i}, - Max::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false), - mask)); + auto expr = + For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); LLVMCodeGen cg(expr, {a, b, c}); @@ -752,27 +890,19 @@ void testLLVMElemwiseMaxNaNFloat() { } } -void testLLVMElemwiseMinFloat() { +TEST(LLVM, ElemwiseMinFloat) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kFloat)); - Buffer b(BufHandle("B", {N}, kFloat)); - Buffer c(BufHandle("C", {N}, kFloat)); + Placeholder a(BufHandle("A", {N}, kFloat)); + Placeholder b(BufHandle("B", {N}, kFloat)); + Placeholder c(BufHandle("C", {N}, kFloat)); std::vector a_buffer(N, 41); std::vector b_buffer(N, 1); std::vector c_buffer(N, 1); - auto mask = IntImm::make(1); VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - Store::make( - c, - {i}, - Min::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false), - mask)); + auto expr = + For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); LLVMCodeGen cg(expr, {a, b, c}); @@ -787,27 +917,19 @@ void testLLVMElemwiseMinFloat() { assertAllEqual(c_buffer, 1.0f); } -void testLLVMElemwiseMinNaNFloat() { +TEST(LLVM, ElemwiseMinNaNFloat) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kFloat)); - Buffer b(BufHandle("B", {N}, kFloat)); - Buffer c(BufHandle("C", {N}, kFloat)); + Placeholder a(BufHandle("A", {N}, kFloat)); + Placeholder b(BufHandle("B", {N}, kFloat)); + Placeholder c(BufHandle("C", {N}, kFloat)); std::vector a_buffer(N, NAN); std::vector b_buffer(N, 1); std::vector c_buffer(N, 1); - auto mask = IntImm::make(1); VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - Store::make( - c, - {i}, - Min::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask), false), - mask)); + auto expr = + For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); LLVMCodeGen cg(expr, {a, b, c}); @@ -823,27 +945,18 @@ void testLLVMElemwiseMinNaNFloat() { } } -void testLLVMElemwiseMod() { +TEST(LLVM, ElemwiseMod) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kInt)); - Buffer b(BufHandle("B", {N}, kInt)); - Buffer c(BufHandle("C", {N}, kInt)); + Placeholder a(BufHandle("A", {N}, kInt)); + Placeholder b(BufHandle("B", {N}, kInt)); + Placeholder c(BufHandle("C", {N}, kInt)); std::vector a_buffer(N, 41); std::vector b_buffer(N, 23); std::vector c_buffer(N, 18); - auto mask = IntImm::make(1); VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - Store::make( - c, - {i}, - Mod::make(Load::make(a, {i}, mask), Load::make(b, {i}, mask)), - mask)); + auto expr = For::make(i, 0, N, c.store({i}, Mod::make(a.load(i), b.load(i)))); LLVMCodeGen cg(expr, {a, b, c}); @@ -858,12 +971,12 @@ void testLLVMElemwiseMod() { assertAllEqual(c_buffer, 18); } -void testLLVMCompareSelectIntEQ() { +TEST(LLVM, CompareSelectIntEQ) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kInt)); - Buffer b(BufHandle("B", {N}, kInt)); - Buffer c(BufHandle("C", {N}, kInt)); + Placeholder a(BufHandle("A", {N}, kInt)); + Placeholder b(BufHandle("B", {N}, kInt)); + Placeholder c(BufHandle("C", {N}, kInt)); std::vector a_buffer(N, 1); std::vector b_buffer(N, 1); std::vector c_buffer(N, 0); @@ -874,20 +987,15 @@ void testLLVMCompareSelectIntEQ() { c_ref[i] = 0; } - auto mask = IntImm::make(1); VarHandle i("i", kInt); auto expr = For::make( i, 0, N, - Store::make( - c, + c.store( {i}, CompareSelect::make( - Load::make(a, {i}, mask), - Load::make(b, {i}, mask), - CompareSelectOperation::kEQ), - mask)); + a.load(i), b.load(i), CompareSelectOperation::kEQ))); LLVMCodeGen cg(expr, {a, b, c}); @@ -904,30 +1012,25 @@ void testLLVMCompareSelectIntEQ() { } } -void testLLVMCompareSelectFloatEQ() { +TEST(LLVM, CompareSelectFloatEQ) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kFloat)); - Buffer b(BufHandle("B", {N}, kFloat)); - Buffer c(BufHandle("C", {N}, kInt)); + Placeholder a(BufHandle("A", {N}, kFloat)); + Placeholder b(BufHandle("B", {N}, kFloat)); + Placeholder c(BufHandle("C", {N}, kInt)); std::vector a_buffer(N, 1.0f); std::vector b_buffer(N, 1.0f); std::vector c_buffer(N, 0); - auto mask = IntImm::make(1); VarHandle i("i", kInt); auto expr = For::make( i, 0, N, - Store::make( - c, + c.store( {i}, CompareSelect::make( - Load::make(a, {i}, mask), - Load::make(b, {i}, mask), - CompareSelectOperation::kEQ), - mask)); + a.load(i), b.load(i), CompareSelectOperation::kEQ))); LLVMCodeGen cg(expr, {a, b, c}); @@ -943,12 +1046,12 @@ void testLLVMCompareSelectFloatEQ() { assertAllEqual(c_buffer, 1); } -void testLLVMCompareSelectByteGT() { +TEST(LLVM, CompareSelectByteGT) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kByte)); - Buffer b(BufHandle("B", {N}, kByte)); - Buffer c(BufHandle("C", {N}, kInt)); + Placeholder a(BufHandle("A", {N}, kByte)); + Placeholder b(BufHandle("B", {N}, kByte)); + Placeholder c(BufHandle("C", {N}, kInt)); std::vector a_buffer(N, 0); std::vector b_buffer(N, 0); std::vector c_buffer(N, 0); @@ -959,20 +1062,15 @@ void testLLVMCompareSelectByteGT() { c_ref[i] = 1; } - auto mask = IntImm::make(1); VarHandle i("i", kInt); auto expr = For::make( i, 0, N, - Store::make( - c, + c.store( {i}, CompareSelect::make( - Load::make(a, {i}, mask), - Load::make(b, {i}, mask), - CompareSelectOperation::kGT), - mask)); + a.load(i), b.load(i), CompareSelectOperation::kGT))); LLVMCodeGen cg(expr, {a, b, c}); @@ -989,31 +1087,26 @@ void testLLVMCompareSelectByteGT() { } } -void testLLVMCompareSelectByteGE() { +TEST(LLVM, CompareSelectByteGE) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kByte)); - Buffer b(BufHandle("B", {N}, kByte)); - Buffer c(BufHandle("C", {N}, kInt)); + Placeholder a(BufHandle("A", {N}, kByte)); + Placeholder b(BufHandle("B", {N}, kByte)); + Placeholder c(BufHandle("C", {N}, kInt)); std::vector a_buffer(N, 0); std::vector b_buffer(N, 0); std::vector c_buffer(N, 0); std::vector c_ref(N, 1); - auto mask = IntImm::make(1); VarHandle i("i", kInt); auto expr = For::make( i, 0, N, - Store::make( - c, + c.store( {i}, CompareSelect::make( - Load::make(a, {i}, mask), - Load::make(b, {i}, mask), - CompareSelectOperation::kGE), - mask)); + a.load(i), b.load(i), CompareSelectOperation::kGE))); LLVMCodeGen cg(expr, {a, b, c}); @@ -1030,12 +1123,12 @@ void testLLVMCompareSelectByteGE() { } } -void testLLVMCompareSelectByteLT() { +TEST(LLVM, CompareSelectByteLT) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kByte)); - Buffer b(BufHandle("B", {N}, kByte)); - Buffer c(BufHandle("C", {N}, kInt)); + Placeholder a(BufHandle("A", {N}, kByte)); + Placeholder b(BufHandle("B", {N}, kByte)); + Placeholder c(BufHandle("C", {N}, kInt)); std::vector a_buffer(N, 0); std::vector b_buffer(N, 128); std::vector c_buffer(N, 0); @@ -1046,20 +1139,15 @@ void testLLVMCompareSelectByteLT() { c_ref[i] = 0; } - auto mask = IntImm::make(1); VarHandle i("i", kInt); auto expr = For::make( i, 0, N, - Store::make( - c, + c.store( {i}, CompareSelect::make( - Load::make(a, {i}, mask), - Load::make(b, {i}, mask), - CompareSelectOperation::kLT), - mask)); + a.load(i), b.load(i), CompareSelectOperation::kLT))); LLVMCodeGen cg(expr, {a, b, c}); @@ -1076,31 +1164,26 @@ void testLLVMCompareSelectByteLT() { } } -void testLLVMCompareSelectByteLE() { +TEST(LLVM, CompareSelectByteLE) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kByte)); - Buffer b(BufHandle("B", {N}, kByte)); - Buffer c(BufHandle("C", {N}, kInt)); + Placeholder a(BufHandle("A", {N}, kByte)); + Placeholder b(BufHandle("B", {N}, kByte)); + Placeholder c(BufHandle("C", {N}, kInt)); std::vector a_buffer(N, 0); std::vector b_buffer(N, 128); std::vector c_buffer(N, 0); std::vector c_ref(N, 1); - auto mask = IntImm::make(1); VarHandle i("i", kInt); auto expr = For::make( i, 0, N, - Store::make( - c, + c.store( {i}, CompareSelect::make( - Load::make(a, {i}, mask), - Load::make(b, {i}, mask), - CompareSelectOperation::kLE), - mask)); + a.load(i), b.load(i), CompareSelectOperation::kLE))); LLVMCodeGen cg(expr, {a, b, c}); @@ -1117,19 +1200,18 @@ void testLLVMCompareSelectByteLE() { } } -void testLLVMStoreFloat() { +TEST(LLVM, StoreFloat) { KernelScope kernel_scope; - Buffer result(BufHandle("result", {1}, kFloat)); + Placeholder result(BufHandle("result", {1}, kFloat)); std::vector result_buffer = {0.0f}; - auto expr = Store::make( - result, {IntImm::make(0)}, FloatImm::make(3.14f), IntImm::make(1)); + auto expr = result.store({0}, FloatImm::make(3.14f)); LLVMCodeGen cg(expr, {result}); std::vector args({result_buffer.data()}); ASSERT_EQ(cg.value(args), 0); ASSERT_EQ(result_buffer[0], 3.14f); } -void testLLVMSimpleMath01() { +TEST(LLVM, SimpleMath01) { KernelScope kernel_scope; const int N = 1024; Tensor* tensor = Compute("f", {{N, "i"}}, [](const VarHandle& i) { @@ -1137,7 +1219,7 @@ void testLLVMSimpleMath01() { }); LoopNest l({tensor}); Stmt* stmt = l.root_stmt(); - Buffer f_buf(BufHandle(tensor->func_var())); + Placeholder f_buf(BufHandle(tensor->buf())); LLVMCodeGen cg(stmt, {f_buf}); PaddedBuffer f_v(N, "f_v"); @@ -1151,16 +1233,16 @@ void testLLVMSimpleMath01() { ExpectAllNear(f_v, f_ref, 1e-5); } -void testLLVMComputeMul() { +TEST(LLVM, ComputeMul) { KernelScope kernel_scope; const int N = 1024; - Buffer a(BufHandle("a", {N}, kFloat)); - Buffer b(BufHandle("b", {N}, kFloat)); + Placeholder a(BufHandle("a", {N}, kFloat)); + Placeholder b(BufHandle("b", {N}, kFloat)); Tensor* c = Compute("c", {{N, "i"}}, [&](const VarHandle& i) { - return Load::make(a, {i}, 1) * Load::make(b, {i}, 1); + return a.load(i) * b.load(i); }); - Buffer c_buf(BufHandle(c->func_var())); + Placeholder c_buf(BufHandle(c->buf())); LoopNest l({c}); Stmt* s = l.root_stmt(); @@ -1174,19 +1256,18 @@ void testLLVMComputeMul() { assertAllEqual(c_vec, 42.0f); } -void testLLVMBroadcastAdd() { +TEST(LLVM, BroadcastAdd) { KernelScope kernel_scope; const int M = 32; const int N = 1024; - Buffer a(BufHandle("a", {M, N}, kFloat)); - Buffer b(BufHandle("b", {N}, kFloat)); + Placeholder a(BufHandle("a", {M, N}, kFloat)); + Placeholder b(BufHandle("b", {N}, kFloat)); Tensor* c = Compute( "c", {{M, "i"}, {N, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - ExprHandle mask(1); - return Load::make(a, {i, j}, mask) + Load::make(b, {j}, mask); + return a.load(i, j) + b.load(j); }); - Buffer c_buf(BufHandle(c->func_var())); + Placeholder c_buf(BufHandle(c->buf())); LoopNest l({c}); l.prepareForCodegen(); Stmt* s = l.root_stmt(); @@ -1208,7 +1289,7 @@ void testLLVMBroadcastAdd() { } } -void testLLVMBitwiseOps() { +TEST(LLVM, BitwiseOps) { KernelScope kernel_scope; auto a = IntImm::make(59); auto b = IntImm::make(11); @@ -1221,15 +1302,15 @@ void testLLVMBitwiseOps() { ASSERT_EQ(cg.value(), 11); } -void testLLVMDynamicShapeAdd() { +TEST(LLVM, DynamicShapeAdd) { KernelScope kernel_scope; auto testWithSize = [](int32_t size) { VarHandle n("n", kInt); - Buffer a(BufHandle("a", {n}, kFloat)); - Buffer b(BufHandle("b", {n}, kFloat)); - Buffer c(BufHandle("c", {n}, kFloat)); + Placeholder a(BufHandle("a", {n}, kFloat)); + Placeholder b(BufHandle("b", {n}, kFloat)); + Placeholder c(BufHandle("c", {n}, kFloat)); VarHandle i("i", kInt); - Stmt* s = For::make(i, 0, n, Store::make(c, {i}, a(i) + b(i), 1)); + Stmt* s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); std::vector aData(size, 1.0f); std::vector bData(size, 2.0f); std::vector cData(size, 0.0f); @@ -1243,15 +1324,15 @@ void testLLVMDynamicShapeAdd() { testWithSize(37); } -void testLLVMBindDynamicShapeAdd() { +TEST(LLVM, BindDynamicShapeAdd) { KernelScope kernel_scope; auto testWithSize = [](int32_t size) { VarHandle n("n", kInt); - Buffer a(BufHandle("a", {n}, kFloat)); - Buffer b(BufHandle("b", {n}, kFloat)); - Buffer c(BufHandle("c", {n}, kFloat)); + Placeholder a(BufHandle("a", {n}, kFloat)); + Placeholder b(BufHandle("b", {n}, kFloat)); + Placeholder c(BufHandle("c", {n}, kFloat)); VarHandle i("i", kInt); - Stmt* s = For::make(i, 0, n, Store::make(c, {i}, a(i) + b(i), 1)); + Stmt* s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); std::vector aData(size, 1.0f); std::vector bData(size, 2.0f); std::vector cData(size, 0.0f); @@ -1264,14 +1345,15 @@ void testLLVMBindDynamicShapeAdd() { testWithSize(37); } -void testLLVMTensorDynamicShapeAdd() { +TEST(LLVM, TensorDynamicShapeAdd) { KernelScope kernel_scope; auto testWithSize = [](int32_t size) { VarHandle n("n", kInt); - Buffer a(BufHandle("a", {n}, kFloat)); - Buffer b(BufHandle("b", {n}, kFloat)); - Tensor* c = Compute( - "c", {{n, "n"}}, [&](const VarHandle& i) { return a(i) + b(i); }); + Placeholder a(BufHandle("a", {n}, kFloat)); + Placeholder b(BufHandle("b", {n}, kFloat)); + Tensor* c = Compute("c", {{n, "n"}}, [&](const VarHandle& i) { + return a.load(i) + b.load(i); + }); LoopNest l({c}); Stmt* s = l.root_stmt(); LLVMCodeGen cg(s, {a, b, c, n}); @@ -1286,16 +1368,16 @@ void testLLVMTensorDynamicShapeAdd() { testWithSize(37); } -void testLLVMDynamicShape2D() { +TEST(LLVM, DynamicShape2D) { KernelScope kernel_scope; auto testWithSize = [](int32_t M, int32_t N) { VarHandle m("m", kInt); VarHandle n("n", kInt); - Buffer a(BufHandle("a", {m, n}, kFloat)); - Buffer b(BufHandle("b", {m, n}, kFloat)); + Placeholder a(BufHandle("a", {m, n}, kFloat)); + Placeholder b(BufHandle("b", {m, n}, kFloat)); Tensor* c = Compute( "c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) { - return a(i, j) + b(i, j); + return a.load(i, j) + b.load(i, j); }); LoopNest l({c}); l.prepareForCodegen(); @@ -1312,7 +1394,7 @@ void testLLVMDynamicShape2D() { testWithSize(37, 11); } -void testLLVMEmptyStmt() { +TEST(LLVM, EmptyStmt) { KernelScope kernel_scope; Stmt* s = new Block({}); @@ -1321,9 +1403,9 @@ void testLLVMEmptyStmt() { // Just don't crash. } -void testLLVMEliminatedStmt() { +TEST(LLVM, EliminatedStmt) { KernelScope kernel_scope; - Buffer a(BufHandle("a", {1}, kFloat)); + Placeholder a(BufHandle("a", {1}, kFloat)); Tensor* c = Compute("c", {{0, "m"}}, [&](const VarHandle& m) { return m; }); @@ -1337,14 +1419,14 @@ void testLLVMEliminatedStmt() { cg.call({aData, cData}); } -void testLLVMSimpleReduction() { +TEST(LLVM, SimpleReduction) { KernelScope kernel_scope; int M = 128; int N = 64; const int kTotalSize = M * N; - Buffer a("a", kFloat, {1, M, N}); + Placeholder a("a", kFloat, {1, M, N}); // TODO: why doesn't implicit vector work? std::vector axis = {DimArg(1)}; @@ -1376,14 +1458,14 @@ void testLLVMSimpleReduction() { ExpectAllNear(b_v, b_ref, 1e-5); } -void testLLVMRFactorReduction() { +TEST(LLVM, RFactorReduction) { KernelScope kernel_scope; int M = 128; int N = 64; const int kTotalSize = M * N; - Buffer a("a", kFloat, {1, M, N}); + Placeholder a("a", kFloat, {1, M, N}); // TODO: why doesn't implicit vector work? std::vector axis = {DimArg(1)}; @@ -1425,44 +1507,40 @@ void testLLVMRFactorReduction() { ExpectAllNear(b_v, b_ref, 1e-5); } -// TODO: disabled since this doesn't work. -void DISABLED_testLLVMRFactorVectorizedReduction() { +TEST(LLVM, RFactorVectorizedReduction) { KernelScope kernel_scope; int M = 128; int N = 64; const int kTotalSize = M * N; - Buffer a("a", kFloat, {1, M, N}); + Placeholder a("a", kFloat, {1, M, N}); - // TODO: why doesn't implicit vector work? - std::vector axis = {DimArg(1)}; - std::vector reduce_axis = {DimArg(M), DimArg(N)}; - Tensor* b = Reduce("sum", axis, Sum(), a, reduce_axis); + Tensor* b = Reduce("sum", {{1, "K"}}, Sum(), a, {{M, "M"}, {N, "N"}}); LoopNest loopnest({b}); std::vector loops = loopnest.getLoopStmtsFor(b); For* loop_k = loops.at(0); For* loop_m = loops.at(1); For* loop_n = loops.at(2); - loopnest.reorderAxis(loop_n, loop_m); - loops = loopnest.getLoopStmtsFor(b); + loopnest.rfactor(b->body(), loop_n->var()); + + loops = NodeFinder::find(loopnest.root_stmt()); loop_k = loops.at(0); - loop_n = loops.at(1); + // loop 1 is the initializer of tmp_buf loop_m = loops.at(2); - // Case-III reductions - loopnest.rfactor(b->body(), loop_n->var()); - loopnest.prepareForCodegen(); - Stmt* s = loopnest.root_stmt(); - s = IRSimplifier::simplify(s); + loop_n = loops.at(3); + loopnest.reorderAxis(loop_n, loop_m); - Block* root_block = dynamic_cast(s); - auto I = root_block->begin(); - ++I; + // Case-III reductions + loops = NodeFinder::find(loopnest.root_stmt()); + // Vectorize initializer of tmp_buf + loopnest.vectorize(loops[1]); + // Vectorize producer of tmp_buf + loopnest.vectorize(loops[2]); - For* outer_loop = dynamic_cast(*I); - loopnest.vectorize(outer_loop); + loopnest.prepareForCodegen(); - s = IRSimplifier::simplify(s); + Stmt* s = IRSimplifier::simplify(loopnest.root_stmt()); LLVMCodeGen cg(s, {a, b}); PaddedBuffer a_v(1, M, N, "a_v"); @@ -1483,6 +1561,94 @@ void DISABLED_testLLVMRFactorVectorizedReduction() { ExpectAllNear(b_v, b_ref, 1e-5); } +TEST(LLVM, VectorizedGEMM) { + KernelScope ks; + + int M = 32; + int N = 32; + int K = 48; + + Placeholder AP(BufHandle("A", {M, K}, kFloat)); + Placeholder BP(BufHandle("B", {K, N}, kFloat)); + Tensor* CT = Reduce( + "gemm", + {{M, "M"}, {N, "N"}}, + Sum(), + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { + return AP.load(m, k) * BP.load(k, n); + }, + {{K, "K"}}); + LoopNest loop({CT}); + + { + auto const& loops = loop.getLoopStmtsFor(CT); + For* m = loops[0]; + For* mo; + For* mi; + loop.splitWithMask(m, 16, &mo, &mi); + } + { + auto const& loops = loop.getLoopStmtsFor(CT); + For* n = loops[2]; + For* no; + For* ni; + loop.splitWithMask(n, 16, &no, &ni); + } + // mo, mi, no, ni, k -> + // mo, no, mi, ni, k + { + auto const& loops = loop.getLoopStmtsFor(CT); + For* mi = loops[1]; + For* no = loops[2]; + loop.reorderAxis(mi, no); + } + // mo, no, mi, ni, k -> + // mo, no, mi, k, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + For* ni = loops[3]; + For* k = loops[4]; + loop.reorderAxis(ni, k); + } + // mo, no, mi, k, ni -> + // mo, no, k, mi, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + For* mi = loops[2]; + For* k = loops[3]; + loop.reorderAxis(mi, k); + } + { + auto loops = NodeFinder::find(loop.root_stmt()); + loop.vectorize(loops[3]); + loop.vectorize(loops.back()); + } + + loop.prepareForCodegen(); + + Stmt* s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + LLVMCodeGen cg(s, {AP, BP, CT}); + + PaddedBuffer a_v(M, K, "a_v"); + PaddedBuffer b_v(K, N, "b_v"); + PaddedBuffer c_v(M, N, "c_v"); + PaddedBuffer c_ref(M, N, "c_ref"); + + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + c_ref(m, n) = 0.f; + for (int k = 0; k < K; k++) { + c_ref(m, n) += a_v(m, k) * b_v(k, n); + } + } + } + + cg.call({a_v, b_v, c_v}); + + ExpectAllNear(c_v, c_ref, 1e-5); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp index 201a7e57820b3..89ad7eb2aecbb 100644 --- a/test/cpp/tensorexpr/test_loopnest.cpp +++ b/test/cpp/tensorexpr/test_loopnest.cpp @@ -1,3 +1,5 @@ +#include + #include #include #include @@ -7,9 +9,7 @@ #include #include #include -#include #include -#include #include #include #include @@ -22,7 +22,7 @@ namespace jit { using namespace torch::jit::tensorexpr; -void testExprSimple01() { +TEST(LoopNest, ExprSimple01) { KernelScope kernel_scope; Tensor* tensor = Compute( "f", {{16, "X"}, {5, "y"}}, [](const VarHandle& x, const VarHandle& y) { @@ -41,7 +41,7 @@ void testExprSimple01() { l.splitWithTail(x_outer, 2, &x_2, &x_1, &x_tail_2); } -void testExprLower01() { +TEST(LoopNest, ExprLower01) { KernelScope kernel_scope; Tensor* tensor = Compute( "f", {{16, "x"}, {5, "y"}}, [](const VarHandle& x, const VarHandle& y) { @@ -55,7 +55,7 @@ void testExprLower01() { ASSERT_LT(oss.str().size(), 200); } -void testExprSimple02() { +TEST(LoopNest, ExprSimple02) { KernelScope kernel_scope; auto func = [](const ExprHandle& x, const ExprHandle& y) { return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; @@ -110,7 +110,7 @@ void testExprSimple02() { PaddedBuffer f_ref(26, 5, "f_res"); stmt = FlattenIndexes(stmt); - SimpleIREvaluator ir_eval(stmt, tensor); + SimpleIREvaluator ir_eval(stmt, {tensor}); ir_eval(f_v); for (int x = 0; x < 26; x++) { @@ -151,7 +151,7 @@ void assertForRanges( } } -void testExprSliceHeadWithLoopOptions() { +TEST(LoopNest, ExprSliceHeadWithLoopOptions) { KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); @@ -173,7 +173,7 @@ void testExprSliceHeadWithLoopOptions() { ASSERT_TRUE(head->loop_options().isDefault()); } -void testExprSliceTailWithLoopOptions() { +TEST(LoopNest, ExprSliceTailWithLoopOptions) { KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); @@ -200,7 +200,7 @@ void testExprSliceTailWithLoopOptions() { ASSERT_TRUE(tail_tail->loop_options().isDefault()); } -void testExprSliceHeadWhenFactorEqualsSize() { +TEST(LoopNest, ExprSliceHeadWhenFactorEqualsSize) { // When factor equals the For loop's original size, keep using the original // For loop. KernelScope kernel_scope; @@ -221,7 +221,7 @@ void testExprSliceHeadWhenFactorEqualsSize() { assertForRanges(body, {{0, 10}}); } -void testExprSliceHeadWhenFactorLargerThanSize() { +TEST(LoopNest, ExprSliceHeadWhenFactorLargerThanSize) { KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); @@ -240,7 +240,7 @@ void testExprSliceHeadWhenFactorLargerThanSize() { assertForRanges(body, {{0, 10}}); } -void testExprSliceHead() { +TEST(LoopNest, ExprSliceHead) { KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); @@ -261,7 +261,7 @@ void testExprSliceHead() { assertForRanges(body, {{0, 4}, {4, 10}}); } -void testExprSliceHeadWithNonZeroStart() { +TEST(LoopNest, ExprSliceHeadWithNonZeroStart) { KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); @@ -286,7 +286,7 @@ void testExprSliceHeadWithNonZeroStart() { assertForRanges(body, {{0, 6}, {6, 8}, {8, 10}}); } -void testExprSliceTailWhenFactorEqualsSize() { +TEST(LoopNest, ExprSliceTailWhenFactorEqualsSize) { // When factor equals the For loop's original size, keep using the original // For loop. KernelScope kernel_scope; @@ -307,7 +307,7 @@ void testExprSliceTailWhenFactorEqualsSize() { assertForRanges(body, {{0, 10}}); } -void testExprSliceTailWhenFactorLargerThanSize() { +TEST(LoopNest, ExprSliceTailWhenFactorLargerThanSize) { // When factor equals the For loop's original size, keep using the original // For loop. KernelScope kernel_scope; @@ -328,7 +328,7 @@ void testExprSliceTailWhenFactorLargerThanSize() { assertForRanges(body, {{0, 10}}); } -void testExprSliceTail() { +TEST(LoopNest, ExprSliceTail) { KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); @@ -349,7 +349,7 @@ void testExprSliceTail() { assertForRanges(body, {{0, 6}, {6, 10}}); } -void testExprSplitAndSlice() { +TEST(LoopNest, ExprSplitAndSlice) { // 0: splitWithTail // 1: sliceTail on inner loop // 2: sliceHead on outer loop @@ -408,7 +408,7 @@ void testExprSplitAndSlice() { assertForRanges(loop->body(), {{0, 19}, {19, 21}}); } -void testExprSliceAndNormalize() { +TEST(LoopNest, ExprSliceAndNormalize) { // 0: sliceHead // 1: normalize tail KernelScope kernel_scope; @@ -439,7 +439,7 @@ T evalExpr(const ExprHandle& expr, const VarHandle& var, T value) { return eval.value(value); } -void testExprSliceWithVariableDimension() { +TEST(LoopNest, ExprSliceWithVariableDimension) { auto testWithDimension = [](int dimension, const std::vector>& expected_for_ranges) { @@ -478,7 +478,7 @@ void testExprSliceWithVariableDimension() { testWithDimension(10, {{0, 2}, {2, 8}, {8, 10}}); } -void testExprSplitWithTail() { +TEST(LoopNest, ExprSplitWithTail) { KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); @@ -513,7 +513,7 @@ void testExprSplitWithTail() { assertForRange(loop, 0, 12); } -void testExprSplitWithTailNone() { +TEST(LoopNest, ExprSplitWithTailNone) { KernelScope kernel_scope; auto func = [](const ExprHandle& x, const ExprHandle& y) { return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; @@ -560,7 +560,7 @@ void testExprSplitWithTailNone() { PaddedBuffer f_v(24, 5, "f_v"); PaddedBuffer f_ref(24, 5, "f_res"); - SimpleIREvaluator ir_eval(stmt, tensor); + SimpleIREvaluator ir_eval(stmt, {tensor}); ir_eval(f_v); for (int x = 0; x < 24; x++) { @@ -573,15 +573,15 @@ void testExprSplitWithTailNone() { } } -void testExprSplitWithMask01() { +TEST(LoopNest, ExprSplitWithMask01) { KernelScope kernel_scope; const int M = 26; const int N = 5; - Buffer a_buf("a", kFloat, {M, N}); - Buffer b_buf("b", kFloat, {M, N}); + Placeholder a_buf("a", kFloat, {M, N}); + Placeholder b_buf("b", kFloat, {M, N}); Tensor* tensor = Compute( "f", {{M, "m"}, {N, "n"}}, [&](const ExprHandle& m, const ExprHandle& n) { - return a_buf(m, n) + b_buf(m, n) + 1.0f; + return a_buf.load(m, n) + b_buf.load(m, n) + 1.0f; }); For* n_outer; For* n_inner; @@ -604,20 +604,20 @@ void testExprSplitWithMask01() { } } - SimpleIREvaluator(stmt, a_buf, b_buf, tensor)(a_v, b_v, c_v); + SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); ExpectAllNear(c_v, c_ref, 1e-5); } // Tests the case where we split a loop cleanly multiple times, we should not // insert any masks. -void testExprSplitWithMaskRepeatedNoMask() { +TEST(LoopNest, ExprSplitWithMaskRepeatedNoMask) { KernelScope kernel_scope; const int M = 64; - Buffer a_buf("a", kFloat, {M}); - Buffer b_buf("b", kFloat, {M}); + Placeholder a_buf("a", kFloat, {M}); + Placeholder b_buf("b", kFloat, {M}); Tensor* tensor = Compute("f", {{M, "m"}}, [&](const ExprHandle& m) { - return a_buf(m) + b_buf(m) + 1.0f; + return a_buf.load(m) + b_buf.load(m) + 1.0f; }); LoopNest l({tensor}); @@ -643,13 +643,13 @@ void testExprSplitWithMaskRepeatedNoMask() { torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } -void testSplitWithTailWithLoopOptions() { +TEST(LoopNest, SplitWithTailWithLoopOptions) { KernelScope kernel_scope; const int M = 21; - Buffer a_buf("a", kFloat, {M}); - Buffer b_buf("b", kFloat, {M}); + Placeholder a_buf("a", kFloat, {M}); + Placeholder b_buf("b", kFloat, {M}); Tensor* tensor = Compute("f", {{M, "m"}}, [&](const ExprHandle& m) { - return a_buf(m) + b_buf(m) + 1.0f; + return a_buf.load(m) + b_buf.load(m) + 1.0f; }); For *outer, *inner, *tail; @@ -673,13 +673,13 @@ void testSplitWithTailWithLoopOptions() { ASSERT_TRUE(tail->loop_options().isDefault()); } -void testSplitWithMaskWithLoopOptions() { +TEST(LoopNest, SplitWithMaskWithLoopOptions) { KernelScope kernel_scope; const int M = 21; - Buffer a_buf("a", kFloat, {M}); - Buffer b_buf("b", kFloat, {M}); + Placeholder a_buf("a", kFloat, {M}); + Placeholder b_buf("b", kFloat, {M}); Tensor* tensor = Compute("f", {{M, "m"}}, [&](const ExprHandle& m) { - return a_buf(m) + b_buf(m) + 1.0f; + return a_buf.load(m) + b_buf.load(m) + 1.0f; }); For *outer, *inner; @@ -696,18 +696,18 @@ void testSplitWithMaskWithLoopOptions() { ASSERT_TRUE(inner->loop_options().isDefault()); } -void testScheduleBroadcastAddBuffer() { +TEST(LoopNest, ScheduleBroadcastAddBuffer) { KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; - Buffer a_buf("a", kFloat, {M, N}); - Buffer b_buf("b", kFloat, {N, K}); + Placeholder a_buf("a", kFloat, {M, N}); + Placeholder b_buf("b", kFloat, {N, K}); Tensor* c = Compute( "broadcast_add", {{M, "m"}, {N, "n"}, {K, "k"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf(m, n) + b_buf(n, k); + return a_buf.load(m, n) + b_buf.load(n, k); }); LoopNest l({c}); Stmt* stmt = l.root_stmt(); @@ -729,7 +729,7 @@ void testScheduleBroadcastAddBuffer() { b_v.Backup(); PaddedBuffer c_v(M, N, K, "c_buf"); - SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c}); ir_eval(a_v, b_v, c_v); a_v.CheckBackup(); @@ -745,18 +745,18 @@ void testScheduleBroadcastAddBuffer() { ExpectAllNear(c_v, c_ref, 1e-5); } -void testScheduleFunctionCall01() { +TEST(LoopNest, ScheduleFunctionCall01) { KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; - Buffer a_buf("a", kFloat, {M, N}); - Buffer b_buf("b", kFloat, {N, K}); + Placeholder a_buf("a", kFloat, {M, N}); + Placeholder b_buf("b", kFloat, {N, K}); Tensor* c = Compute( "broadcast_add", {{M, "m"}, {N, "n"}, {K, "k"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf(m, n) + b_buf(n, k); + return a_buf.load(m, n) + b_buf.load(n, k); }); Tensor* d = Compute( "d", @@ -796,33 +796,33 @@ void testScheduleFunctionCall01() { } } - SimpleIREvaluator eval(stmt, a_buf, b_buf, d); + SimpleIREvaluator eval(stmt, {a_buf, b_buf, d}); eval(a_v, b_v, d_v); ExpectAllNear(d_v, d_ref, 1e-5); } -void testScheduleInlineSimple() { +TEST(LoopNest, ScheduleInlineSimple) { KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; - Buffer a_buf("a", kFloat, {M, N}); - Buffer b_buf("b", kFloat, {N, K}); - Buffer c_buf("c", kFloat, {M, N}); - Buffer d_buf("d", kFloat, {M, K}); + Placeholder a_buf("a", kFloat, {M, N}); + Placeholder b_buf("b", kFloat, {N, K}); + Placeholder c_buf("c", kFloat, {M, N}); + Placeholder d_buf("d", kFloat, {M, K}); Tensor* x = Compute( "x", {{M, "m1"}, {N, "n1"}, {K, "k1"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf(m, n) * b_buf(n, k); + return a_buf.load(m, n) * b_buf.load(n, k); }); Tensor* y = Compute( "y", {{M, "m2"}, {N, "n2"}, {K, "k2"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c_buf(m, n) * d_buf(m, k) + x->call(m, n, k); + return c_buf.load(m, n) * d_buf.load(m, k) + x->call(m, n, k); }); LoopNest l1({y}); @@ -835,8 +835,8 @@ void testScheduleInlineSimple() { Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt()); Stmt* stmt2 = IRSimplifier::simplify(l2.root_stmt()); - SimpleIREvaluator eval1(stmt1, a_buf, b_buf, c_buf, d_buf, y); - SimpleIREvaluator eval2(stmt2, a_buf, b_buf, c_buf, d_buf, y); + SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, c_buf, d_buf, y}); + SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, c_buf, d_buf, y}); PaddedBuffer a_v(M, N); PaddedBuffer b_v(N, K); @@ -888,22 +888,22 @@ void InlineFunc01Helper(const std::vector& inline_order) { const int M = 4; const int N = 5; const int K = 6; - Buffer a_buf("a", kFloat, {M, N}); - Buffer b_buf("b", kFloat, {N, K}); - Buffer c_buf("c", kFloat, {M, N}); - Buffer d_buf("d", kFloat, {M, K}); + Placeholder a_buf("a", kFloat, {M, N}); + Placeholder b_buf("b", kFloat, {N, K}); + Placeholder c_buf("c", kFloat, {M, N}); + Placeholder d_buf("d", kFloat, {M, K}); Tensor* x = Compute( "x", {{M, "m1"}, {N, "n1"}, {K, "k1"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf(m, n) * b_buf(n, k); + return a_buf.load(m, n) * b_buf.load(n, k); }); Tensor* y = Compute( "y", {{M, "m2"}, {N, "n2"}, {K, "k2"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c_buf(m, n) * d_buf(m, k) + x->call(m, n, k); + return c_buf.load(m, n) * d_buf.load(m, k) + x->call(m, n, k); }); Tensor* z = Compute( "z", @@ -966,7 +966,7 @@ void InlineFunc01Helper(const std::vector& inline_order) { } } - SimpleIREvaluator eval(stmt, a_buf, b_buf, c_buf, d_buf, z); + SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z}); eval(a_v, b_v, c_v, d_v, z_v); ExpectAllNear(z_v, z_ref, 1e-5); } @@ -976,8 +976,9 @@ void InlineFunc01Helper(const std::vector& inline_order) { "z", {{M, "m3"}, {N, "n3"}, {K, "k3"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf(m, n) * b_buf(n, k) + - (c_buf(m, n) * d_buf(m, k) + a_buf(m, n) * b_buf(n, k)); + return a_buf.load(m, n) * b_buf.load(n, k) + + (c_buf.load(m, n) * d_buf.load(m, k) + + a_buf.load(m, n) * b_buf.load(n, k)); }); LoopNest l2({z2}); l2.prepareForCodegen(); @@ -992,7 +993,7 @@ void InlineFunc01Helper(const std::vector& inline_order) { } } -void testScheduleInlineFunc01() { +TEST(LoopNest, ScheduleInlineFunc01) { InlineFunc01Helper({"x", "y"}); InlineFunc01Helper({"y", "x"}); InlineFunc01Helper({"x"}); @@ -1001,7 +1002,7 @@ void testScheduleInlineFunc01() { } // Make sure we cache random vars if we should. -void testScheduleInlineRandom() { +TEST(LoopNest, ScheduleInlineRandom) { KernelScope kernel_scope; const int M = 4; const int N = 5; @@ -1041,7 +1042,7 @@ void testScheduleInlineRandom() { } // Make sure we don't cache random vars that are not being inlined. -void testScheduleInlineRandomUnrelated() { +TEST(LoopNest, ScheduleInlineRandomUnrelated) { KernelScope kernel_scope; const int M = 4; const int N = 5; @@ -1082,7 +1083,7 @@ void testScheduleInlineRandomUnrelated() { // Make sure we generate the right number of random values == the dimensionality // of the production tensor. -void testScheduleInlineRandomLowerDimensions() { +TEST(LoopNest, ScheduleInlineRandomLowerDimensions) { KernelScope kernel_scope; const int M = 4; const int N = 5; @@ -1119,19 +1120,19 @@ void testScheduleInlineRandomLowerDimensions() { } // Make sure we don't screw up intrinsics thinking they're rand. -void testScheduleInlineIntrinsics() { +TEST(LoopNest, ScheduleInlineIntrinsics) { KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; - Buffer a_buf("a", kFloat, {M, N}); - Buffer b_buf("b", kFloat, {N, K}); + Placeholder a_buf("a", kFloat, {M, N}); + Placeholder b_buf("b", kFloat, {N, K}); Tensor* x = Compute( "x", {{M, "m1"}, {N, "n1"}, {K, "k1"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf(m, n) * b_buf(n, k); + return a_buf.load(m, n) * b_buf.load(n, k); }); Tensor* y = Compute( "y", @@ -1164,8 +1165,8 @@ void testScheduleInlineIntrinsics() { Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt()); Stmt* stmt2 = IRSimplifier::simplify(l2.root_stmt()); - SimpleIREvaluator eval1(stmt1, a_buf, b_buf, y); - SimpleIREvaluator eval2(stmt2, a_buf, b_buf, y); + SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); + SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); PaddedBuffer y_1(M, N, K); PaddedBuffer y_2(M, N, K); @@ -1180,7 +1181,7 @@ void testScheduleInlineIntrinsics() { } // Make sure we can handle rand and non-rand intrinsics. -void testScheduleInlineRandWithIntrinsics() { +TEST(LoopNest, ScheduleInlineRandWithIntrinsics) { KernelScope kernel_scope; const int M = 4; const int N = 5; @@ -1219,7 +1220,7 @@ void testScheduleInlineRandWithIntrinsics() { } // Split a Compute then inline it into another compute. -void testScheduleSplitAThenInline() { +TEST(LoopNest, ScheduleSplitAThenInline) { KernelScope kernel_scope; Tensor* a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); @@ -1238,7 +1239,7 @@ void testScheduleSplitAThenInline() { } // Split a Compute then inline another Compute into it. -void testScheduleSplitBThenInline() { +TEST(LoopNest, ScheduleSplitBThenInline) { KernelScope kernel_scope; Tensor* a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); @@ -1258,7 +1259,7 @@ void testScheduleSplitBThenInline() { Stmt* s = IRSimplifier::simplify(l.root_stmt()); std::vector output(6, 0); - SimpleIREvaluator eval(s, b); + SimpleIREvaluator eval(s, {b}); eval(output); for (int i = 0; i < 6; ++i) { @@ -1267,7 +1268,7 @@ void testScheduleSplitBThenInline() { } // Split a Compute twice then inline it. -void testScheduleSplitTwiceThenInline() { +TEST(LoopNest, ScheduleSplitTwiceThenInline) { KernelScope kernel_scope; Tensor* a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); @@ -1287,7 +1288,7 @@ void testScheduleSplitTwiceThenInline() { } // Inline a Compute, then split. -void testScheduleInlineThenSplit() { +TEST(LoopNest, ScheduleInlineThenSplit) { KernelScope kernel_scope; Tensor* a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); @@ -1307,7 +1308,7 @@ void testScheduleInlineThenSplit() { l.prepareForCodegen(); Stmt* s = IRSimplifier::simplify(l.root_stmt()); std::vector output(6, 0); - SimpleIREvaluator eval(s, b); + SimpleIREvaluator eval(s, {b}); eval(output); for (int i = 0; i < 6; ++i) { @@ -1316,7 +1317,7 @@ void testScheduleInlineThenSplit() { } // Split a Compute, inline it, then split the result. -void testScheduleSplitInlineThenSplit() { +TEST(LoopNest, ScheduleSplitInlineThenSplit) { KernelScope kernel_scope; Tensor* a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); @@ -1338,7 +1339,7 @@ void testScheduleSplitInlineThenSplit() { l.prepareForCodegen(); Stmt* s = IRSimplifier::simplify(l.root_stmt()); std::vector output(16, 0); - SimpleIREvaluator eval(s, b); + SimpleIREvaluator eval(s, {b}); eval(output); for (int i = 0; i < 16; ++i) { @@ -1347,7 +1348,7 @@ void testScheduleSplitInlineThenSplit() { } // Oversplit a loop that is simplified out after inlining. -void testScheduleSplitInlineSimplify() { +TEST(LoopNest, ScheduleSplitInlineSimplify) { KernelScope kernel_scope; Tensor* a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return ExprHandle(4) * i - ExprHandle(2) * i; @@ -1367,7 +1368,7 @@ void testScheduleSplitInlineSimplify() { } // Inline a Compute with two consumers. -void testScheduleInlineThreeMixedOnce() { +TEST(LoopNest, ScheduleInlineThreeMixedOnce) { KernelScope kernel_scope; Tensor* a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); @@ -1386,7 +1387,7 @@ void testScheduleInlineThreeMixedOnce() { Stmt* s = IRSimplifier::simplify(l.root_stmt()); std::vector output(4 * 3, 0); - SimpleIREvaluator eval(s, c); + SimpleIREvaluator eval(s, {c}); eval(output); for (int k = 0; k < 4; ++k) { @@ -1397,7 +1398,7 @@ void testScheduleInlineThreeMixedOnce() { } // Inline Compute A into B, then inline B into C. -void testScheduleInlineThreeMixedTwice() { +TEST(LoopNest, ScheduleInlineThreeMixedTwice) { KernelScope kernel_scope; Tensor* a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); @@ -1417,7 +1418,7 @@ void testScheduleInlineThreeMixedTwice() { Stmt* s = IRSimplifier::simplify(l.root_stmt()); std::vector output(4 * 3, 0); - SimpleIREvaluator eval(s, c); + SimpleIREvaluator eval(s, {c}); eval(output); for (int k = 0; k < 4; ++k) { @@ -1428,7 +1429,7 @@ void testScheduleInlineThreeMixedTwice() { } // Inline a Compute that is both a producer and consumer. -void testScheduleInlineThreeMixedInner() { +TEST(LoopNest, ScheduleInlineThreeMixedInner) { KernelScope kernel_scope; Tensor* a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); @@ -1447,7 +1448,7 @@ void testScheduleInlineThreeMixedInner() { Stmt* s = IRSimplifier::simplify(l.root_stmt()); std::vector output(4 * 3, 0); - SimpleIREvaluator eval(s, c); + SimpleIREvaluator eval(s, {c}); eval(output); for (int k = 0; k < 4; ++k) { @@ -1458,7 +1459,7 @@ void testScheduleInlineThreeMixedInner() { } // Split 3 Computes, then inline the first two into the last. -void testScheduleInlineThreeMixedSplit() { +TEST(LoopNest, ScheduleInlineThreeMixedSplit) { KernelScope kernel_scope; Tensor* a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); @@ -1483,17 +1484,60 @@ void testScheduleInlineThreeMixedSplit() { ASSERT_THROWS_WITH(l.computeInline(a->buf()), "compound indices"); } -void testScheduleFuserStyle() { +// Check that inlining works for output tensors too +TEST(LoopNest, ScheduleInlineOutputTensors) { + KernelScope kernel_scope; + const int M = 4; + const int N = 5; + const int K = 6; + + Tensor* x = Compute( + "x", + {{M, "m1"}, {N, "n1"}, {K, "k1"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return m * n * k; + }); + Tensor* y = Compute( + "y", + {{M, "m2"}, {N, "n2"}, {K, "k2"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return x->call(m, n, k) + m; + }); + + LoopNest l1({x, y}); + l1.computeInline(x->buf()); + + // would normally compare results but Rand isn't implemented in the + // SimpleIREvaluator, even if we could seed it. + Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt()); + std::ostringstream oss; + oss << *stmt1; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for (int m1 = 0; m1 < 4; m1++) +# CHECK: for (int n1 = 0; n1 < 5; n1++) +# CHECK: for (int k1 = 0; k1 < 6; k1++) +# CHECK: x[m1, n1, k1] = (n1 * m1) * k1; +# CHECK: for (int m2 = 0; m2 < 4; m2++) +# CHECK: for (int n2 = 0; n2 < 5; n2++) +# CHECK: for (int k2 = 0; k2 < 6; k2++) +# CHECK: y[m2, n2, k2] = (n2 * m2) * k2 + m2;)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(LoopNest, ScheduleFuserStyle) { KernelScope kernel_scope; const int kVectorSize = 8; const int kVectorCount = 128; const int kTotalSize = kVectorSize * kVectorCount; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Tensor* b = Compute( "f", {{kTotalSize, "i"}}, [&](const std::vector& axes) { - return a_buf(axes[0]) + 11.0f; + return a_buf.load(axes[0]) + 11.0f; }); Tensor* c = Compute( @@ -1508,7 +1552,7 @@ void testScheduleFuserStyle() { std::vector a_data(kTotalSize, 7.0f); std::vector b_data(kTotalSize, 0.0f); std::vector c_data(kTotalSize, 0.0f); - SimpleIREvaluator(s, a_buf, b, c)(a_data, b_data, c_data); + SimpleIREvaluator(s, {a_buf, b, c})(a_data, b_data, c_data); for (int i = 0; i < kTotalSize; i++) { ASSERT_EQ(b_data[i], 18.0f); @@ -1516,25 +1560,25 @@ void testScheduleFuserStyle() { } } -void testScheduleFuserThreeArg() { +TEST(LoopNest, ScheduleFuserThreeArg) { KernelScope kernel_scope; const int kVectorSize = 8; const int kVectorCount = 128; const int kTotalSize = kVectorSize * kVectorCount; - Buffer a(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Buffer b(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); - Buffer c(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); - Buffer d(BufHandle("D", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder a(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder b(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder c(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); + Placeholder d(BufHandle("D", {ExprHandle(kTotalSize)}, kFloat)); Tensor* e = Compute("e", {{kTotalSize, "i"}}, [&](const VarHandle& i) { - return a(i) + b(i); + return a.load(i) + b.load(i); }); Tensor* f = Compute("f", {{kTotalSize, "i"}}, [&](const VarHandle& i) { - return (*e)(i) + c(i); + return e->call(i) + c.load(i); }); Tensor* g = Compute("g", {{kTotalSize, "i"}}, [&](const VarHandle& i) { - return (*f)(i) + d(i); + return f->call(i) + d.load(i); }); LoopNest l({g}); @@ -1548,23 +1592,23 @@ void testScheduleFuserThreeArg() { std::vector c_data(kTotalSize, 3.0f); std::vector d_data(kTotalSize, 4.0f); std::vector g_data(kTotalSize, 0.0f); - SimpleIREvaluator(s, a, b, c, d, g)(a_data, b_data, c_data, d_data, g_data); + SimpleIREvaluator(s, {a, b, c, d, g})(a_data, b_data, c_data, d_data, g_data); for (int i = 0; i < kTotalSize; i++) { ASSERT_EQ(g_data[i], 10.0f); } } -void testScheduleDynamicShape2D() { +TEST(LoopNest, ScheduleDynamicShape2D) { KernelScope kernel_scope; auto testWithSize = [](int32_t M, int32_t N) { VarHandle m("m", kInt); VarHandle n("n", kInt); - Buffer a(BufHandle("a", {m, n}, kFloat)); - Buffer b(BufHandle("b", {m, n}, kFloat)); + Placeholder a(BufHandle("a", {m, n}, kFloat)); + Placeholder b(BufHandle("b", {m, n}, kFloat)); Tensor* c = Compute( "c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) { - return a(i, j) + b(i, j); + return a.load(i, j) + b.load(i, j); }); LoopNest l({c}); Stmt* s = l.root_stmt(); @@ -1580,7 +1624,7 @@ void testScheduleDynamicShape2D() { testWithSize(37, 11); } -void testLoopNestComputeAt_1() { +TEST(LoopNest, LoopNestComputeAt_1) { // Verify that compute_at works on the following example: // // for (int i_a = 0; i_a < N; i_a++) { @@ -1633,7 +1677,7 @@ void testLoopNestComputeAt_1() { assertAllEqual(b_data, b_ref); } -void testLoopNestComputeAt_2() { +TEST(LoopNest, LoopNestComputeAt_2) { // Verify that compute_at works on the following example: // // for (int py = 0; py < H+1; py++) { @@ -1686,7 +1730,7 @@ void testLoopNestComputeAt_2() { const std::string& verification_pattern = R"IR( # CHECK: for (int cy = 0; cy < H; cy++) -# CHECK: Allocate(temp, int, {2, W + 1}) +# CHECK: Allocate(temp, int, {2 * (W + 1)}) # CHECK: for # CHECK: for # CHECK: for (int cx = 0; cx < W; cx++) @@ -1718,7 +1762,7 @@ void testLoopNestComputeAt_2() { R"IR( # CHECK: for (int cy = 0; cy < H; cy++) # CHECK: for (int cx = 0; cx < W; cx++) -# CHECK: Allocate(temp, int, {2, 2}) +# CHECK: Allocate(temp, int, {4}) # CHECK: for # CHECK: for # CHECK-NOT: prod[ @@ -1735,7 +1779,7 @@ void testLoopNestComputeAt_2() { } } -void testLoopNestComputeAt_3() { +TEST(LoopNest, LoopNestComputeAt_3) { // Verify that compute_at works on the following example: // // A(x,y) = x*y @@ -1803,7 +1847,7 @@ void testLoopNestComputeAt_3() { # CHECK: for (int cx = 0; cx < W; cx++) # CHECK: C[ # CHECK: for (int dy = 0; dy < H; dy++) -# CHECK: Allocate(temp, int, {1, W}) +# CHECK: Allocate(temp, int, {W}) # CHECK: for (int dx = 0; dx < W; dx++) # CHECK-NOT: A[)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -1840,7 +1884,7 @@ void testLoopNestComputeAt_3() { # CHECK: C[ # CHECK: for (int dy = 0; dy < H; dy++) # CHECK: for (int dx = 0; dx < W; dx++) -# CHECK: Allocate(temp, int, {1, 1}) +# CHECK: Allocate(temp, int, {1}) # CHECK-NOT: A[)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -1853,7 +1897,7 @@ void testLoopNestComputeAt_3() { } } -void testLoopNestComputeAt_4() { +TEST(LoopNest, LoopNestComputeAt_4) { // TODO: Verify that computeAt works with reduction axis } @@ -1873,7 +1917,7 @@ class LoopOrderHelper : public IRVisitor { } }; -void testLoopNestReorderAxis1() { +TEST(LoopNest, LoopNestReorderAxis1) { KernelScope kernel_scope; Tensor* tensor = Compute( "f", {{2, "x"}, {3, "y"}}, [](const VarHandle& x, const VarHandle& y) { @@ -1922,7 +1966,7 @@ void testLoopNestReorderAxis1() { ASSERT_EQ(oss1.str(), oss2.str()); } -void testLoopNestReorderPartialAxes() { +TEST(LoopNest, LoopNestReorderPartialAxes) { KernelScope kernel_scope; Tensor* tensor = Compute( "f", @@ -1970,7 +2014,7 @@ void testLoopNestReorderPartialAxes() { } } -void testLoopNestReorderInternalAxis() { +TEST(LoopNest, LoopNestReorderInternalAxis) { KernelScope kernel_scope; Tensor* tensor = Compute( "f", @@ -2007,7 +2051,7 @@ void testLoopNestReorderInternalAxis() { } } -void testLoopNestReorderEnclosingAxis() { +TEST(LoopNest, LoopNestReorderEnclosingAxis) { KernelScope kernel_scope; Tensor* tensor = Compute( "f", @@ -2043,7 +2087,7 @@ void testLoopNestReorderEnclosingAxis() { } } -void testLoopNestReorderSameAxis() { +TEST(LoopNest, LoopNestReorderSameAxis) { KernelScope kernel_scope; Tensor* tensor = Compute( "f", {{2, "x"}, {3, "y"}}, [](const VarHandle& x, const VarHandle& y) { @@ -2062,7 +2106,7 @@ void testLoopNestReorderSameAxis() { ASSERT_EQ(oss.str(), oss2.str()); } -void testLoopNestReorderExtraStatements() { +TEST(LoopNest, LoopNestReorderExtraStatements) { /* We're going for a structure like this: * for x in ... * Stmt 1 @@ -2084,16 +2128,19 @@ void testLoopNestReorderExtraStatements() { }); LoopNest l({tensor}); - Buffer extra(BufHandle("res", {6, 3}, kFloat)); + Placeholder extra(BufHandle("res", {6, 3}, kFloat)); auto loops = l.getLoopStmtsFor(tensor); VarHandle i = VarHandle(loops[0]->var()); - Stmt* store_1 = Store::make(extra, {i, 0}, ExprHandle(1.f), 1); - Stmt* store_2 = Store::make(extra, {i, 1}, ExprHandle(2.f), 1); + Stmt* store_1 = + Store::make(BufHandle(extra.data()), {i, 0}, ExprHandle(1.f), 1); + Stmt* store_2 = + Store::make(BufHandle(extra.data()), {i, 1}, ExprHandle(2.f), 1); // stmt 3 is the Function body. - Stmt* store_3 = Store::make(extra, {i, 2}, ExprHandle(4.f), 1); + Stmt* store_3 = + Store::make(BufHandle(extra.data()), {i, 2}, ExprHandle(4.f), 1); loops[0]->body()->prepend_stmt(store_1); loops[1]->body()->prepend_stmt(store_2); @@ -2224,16 +2271,16 @@ void LoopNestReorderTestHelper( [](const std::vector&) { return -1; }); LoopNest l({c}); - Buffer extra(BufHandle("extra", {5}, kInt)); + Placeholder extra(BufHandle("extra", {5}, kInt)); auto loops = l.getLoopStmtsFor(c); int j = 0; for (auto* l : loops) { // Add an increment at each layer of the loop which counts the number of // times the loop executes. - Load* load = new Load(extra, {new IntImm(j)}, new IntImm(1)); + Load* load = new Load(extra.data(), {new IntImm(j)}, new IntImm(1)); Add* add = new Add(load, new IntImm(1)); - Stmt* store = Store::make(extra, {j}, ExprHandle(add), 1); + Stmt* store = new Store(extra.data(), {new IntImm(j)}, add, new IntImm(1)); if (prepend) { l->body()->prepend_stmt(store); } @@ -2297,7 +2344,7 @@ void LoopNestReorderTestHelper( } } -void testLoopNestReorderLongStringOfPreOrphans() { +TEST(LoopNest, LoopNestReorderLongStringOfPreOrphans) { for (int i = 0; i < 5; ++i) { for (int j = 0; j < 5; ++j) { // skip noops, since we check the loop isn't the same after reordering. @@ -2308,7 +2355,7 @@ void testLoopNestReorderLongStringOfPreOrphans() { } } -void testLoopNestReorderLongStringOfPostOrphans() { +TEST(LoopNest, LoopNestReorderLongStringOfPostOrphans) { for (int i = 0; i < 5; ++i) { for (int j = 0; j < 5; ++j) { // skip noops, since we check the loop isn't the same after reordering. @@ -2319,7 +2366,7 @@ void testLoopNestReorderLongStringOfPostOrphans() { } } -void testLoopNestReorderLongStringFull() { +TEST(LoopNest, LoopNestReorderLongStringFull) { for (int i = 0; i < 5; ++i) { for (int j = 0; j < 5; ++j) { // skip noops, since we check the loop isn't the same after reordering. @@ -2330,27 +2377,27 @@ void testLoopNestReorderLongStringFull() { } } -void testLoopNestReorderInternalLoopNest() { +TEST(LoopNest, LoopNestReorderInternalLoopNest) { KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; - Buffer a_buf("a", kFloat, {M, N}); - Buffer b_buf("b", kFloat, {N, K}); - Buffer c_buf("c", kFloat, {M, N}); - Buffer d_buf("d", kFloat, {M, K}); + Placeholder a_buf("a", kFloat, {M, N}); + Placeholder b_buf("b", kFloat, {N, K}); + Placeholder c_buf("c", kFloat, {M, N}); + Placeholder d_buf("d", kFloat, {M, K}); Tensor* x = Compute( "x", {{M, "m1"}, {N, "n1"}, {K, "k1"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf(m, n) * b_buf(n, k); + return a_buf.load(m, n) * b_buf.load(n, k); }); Tensor* y = Compute( "y", {{M, "m2"}, {N, "n2"}, {K, "k2"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c_buf(m, n) * d_buf(m, k) + x->call(m, n, k); + return c_buf.load(m, n) * d_buf.load(m, k) + x->call(m, n, k); }); Tensor* z = Compute( "z", @@ -2430,13 +2477,13 @@ void testLoopNestReorderInternalLoopNest() { } } - SimpleIREvaluator eval(stmt, a_buf, b_buf, c_buf, d_buf, z); + SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z}); eval(a_v, b_v, c_v, d_v, z_v); ExpectAllNear(z_v, z_ref, 1e-5); } } -void testOuterLoopVectorization() { +TEST(LoopNest, OuterLoopVectorization) { KernelScope kernel_scope; Tensor* tensor = Compute( "f", {{8, "X"}, {8, "y"}}, [](const VarHandle& x, const VarHandle& y) { @@ -2481,7 +2528,7 @@ std::string constantUpperBoundLoopIR(int upper_bound_val) { } // namespace -void testUnroll() { +TEST(LoopNest, Unroll) { const std::string actual = constantUpperBoundLoopIR(3); const std::string& verification_pattern = R"IR( @@ -2492,7 +2539,7 @@ void testUnroll() { torch::jit::testing::FileCheck().run(verification_pattern, actual); } -void testUnrollOuter() { +TEST(LoopNest, UnrollOuter) { KernelScope kernel_scope; ExprHandle outer_bound(3); ExprHandle inner_bound(4); @@ -2521,7 +2568,7 @@ void testUnrollOuter() { torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } -void testUnrollInner() { +TEST(LoopNest, UnrollInner) { KernelScope kernel_scope; ExprHandle outer_bound(3); ExprHandle inner_bound(4); @@ -2548,7 +2595,7 @@ void testUnrollInner() { torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } -void testUnrollMultipleStatements() { +TEST(LoopNest, UnrollMultipleStatements) { KernelScope kernel_scope; const int kTotalSize = 3; BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); @@ -2559,8 +2606,9 @@ void testUnrollMultipleStatements() { x, 0, kTotalSize, - Block::make({Store::make(a_buf, {x}, x * 2), - Store::make(b_buf, {x}, Load::make(a_buf, {x}, 1))})); + Block::make( + {Store::make(a_buf, {x}, x * 2), + Store::make(b_buf, {x}, Load::make(a_buf, {x}, 1))})); Block::make({f}); Stmt* unrolled = nullptr; LoopNest::unroll(f, &unrolled); @@ -2578,7 +2626,7 @@ void testUnrollMultipleStatements() { torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } -void testUnrollEmpty() { +TEST(LoopNest, UnrollEmpty) { const std::string actual = constantUpperBoundLoopIR(0); const std::string& verification_pattern = R"IR( # CHECK-NOT: A[ @@ -2587,7 +2635,7 @@ void testUnrollEmpty() { torch::jit::testing::FileCheck().run(verification_pattern, actual); } -void testNoUnroll() { +TEST(LoopNest, NoUnroll) { KernelScope kernel_scope; VarHandle upper_bound("N", kInt); Tensor* A = Compute( @@ -2599,7 +2647,7 @@ void testNoUnroll() { LoopNest::unroll(loops[0], &unrolled), "non-constant loop"); } -void testUnrollWithLet() { +TEST(LoopNest, UnrollWithLet) { KernelScope kernel_scope; const int kTotalSize = 3; BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); @@ -2611,9 +2659,10 @@ void testUnrollWithLet() { x, 0, kTotalSize, - Block::make({Let::make(e, 7), - Store::make(a_buf, {x}, e), - Store::make(b_buf, {x}, e + 1)})); + Block::make( + {Let::make(e, 7), + Store::make(a_buf, {x}, e), + Store::make(b_buf, {x}, e + 1)})); Block::make({f}); Stmt* unrolled = nullptr; LoopNest::unroll(f, &unrolled); @@ -2633,7 +2682,7 @@ void testUnrollWithLet() { std::vector a_v(kTotalSize, 0); std::vector b_v(kTotalSize, 0); - SimpleIREvaluator eval(unrolled, a_buf, b_buf); + SimpleIREvaluator eval(unrolled, {a_buf, b_buf}); eval(a_v, b_v); for (int i = 0; i < kTotalSize; ++i) { ASSERT_EQ(a_v[i], 7); @@ -2641,7 +2690,7 @@ void testUnrollWithLet() { } } -void testNormalizeStartPositive() { +TEST(LoopNest, NormalizeStartPositive) { KernelScope kernel_scope; // Input IR: @@ -2653,9 +2702,9 @@ void testNormalizeStartPositive() { BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); VarHandle x("x", kInt); - auto for_body = - Block::make({Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x}, 1), 1), - Store::make(b_buf, {x}, x * 2)}); + auto for_body = Block::make( + {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x}, 1), 1), + Store::make(b_buf, {x}, x * 2)}); auto for_stmt = For::make(x, 50, 100, for_body); Block::make({for_stmt}); @@ -2674,7 +2723,7 @@ void testNormalizeStartPositive() { torch::jit::testing::FileCheck().run(expected_ir, oss.str()); } -void testNormalizeStartNegative() { +TEST(LoopNest, NormalizeStartNegative) { KernelScope kernel_scope; // Input IR: @@ -2707,7 +2756,7 @@ void testNormalizeStartNegative() { torch::jit::testing::FileCheck().run(expected_ir, oss.str()); } -void testNormalizeStartZero() { +TEST(LoopNest, NormalizeStartZero) { KernelScope kernel_scope; // Input IR: @@ -2721,9 +2770,9 @@ void testNormalizeStartZero() { BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); VarHandle x("x", kInt); - auto for_body = - Block::make({Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x}, 1), 1), - Store::make(b_buf, {x}, x * 2)}); + auto for_body = Block::make( + {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x}, 1), 1), + Store::make(b_buf, {x}, x * 2)}); auto for_stmt = For::make(x, 0, 100, for_body); Block::make({for_stmt}); @@ -2742,7 +2791,7 @@ void testNormalizeStartZero() { torch::jit::testing::FileCheck().run(expected_ir, oss.str()); } -void testNormalizeStartVariable() { +TEST(LoopNest, NormalizeStartVariable) { KernelScope kernel_scope; // Input IR: @@ -2756,9 +2805,9 @@ void testNormalizeStartVariable() { BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); - auto for_body = - Block::make({Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x}, 1), 1), - Store::make(b_buf, {x}, x * 2)}); + auto for_body = Block::make( + {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x}, 1), 1), + Store::make(b_buf, {x}, x * 2)}); auto for_stmt = For::make(x, y, 100, for_body); Block::make({for_stmt}); @@ -2777,7 +2826,7 @@ void testNormalizeStartVariable() { torch::jit::testing::FileCheck().run(expected_ir, oss.str()); } -void testNormalizeOnNestedOuterLoop() { +TEST(LoopNest, NormalizeOnNestedOuterLoop) { KernelScope kernel_scope; // Input IR: @@ -2815,7 +2864,7 @@ void testNormalizeOnNestedOuterLoop() { torch::jit::testing::FileCheck().run(expected_ir, oss.str()); } -void testNormalizeOnNestedInnerLoop() { +TEST(LoopNest, NormalizeOnNestedInnerLoop) { KernelScope kernel_scope; // Input IR: @@ -2853,14 +2902,14 @@ void testNormalizeOnNestedInnerLoop() { torch::jit::testing::FileCheck().run(expected_ir, oss.str()); } -void testNormalizeAndSplitWithTail() { +TEST(LoopNest, NormalizeAndSplitWithTail) { KernelScope kernel_scope; // Create a dummy tensor to construct LoopNest. ExprHandle n(100); - Buffer a(BufHandle("a", {n}, kFloat)); + Placeholder a(BufHandle("a", {n}, kFloat)); Tensor* b = - Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a(i); }); + Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a.load(i); }); LoopNest l({b}); // Input IR: @@ -2902,13 +2951,304 @@ void testNormalizeAndSplitWithTail() { torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str()); } -void testDetectInlineRankMismatch() { +TEST(LoopNest, FlattenSimpleLoopNest2D) { + KernelScope kernel_scope; + + // Input IR: + // for (int i = 0; i < 10; i++) { + // for (int j = 0; j < 5; j++) { + // A[i,j] = i * j; + // } + // } + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j, 1)}); + auto inner_for = For::make(j, 0, 5, for_body); + auto outer_for = For::make(i, 0, 10, inner_for); + Block::make({outer_for}); + + std::vector loops = {outer_for, inner_for}; + For* flattened = nullptr; + bool success = LoopNest::flatten(loops, &flattened); + ASSERT_TRUE(success); + + auto result = IRSimplifier::simplify(flattened); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) { + # CHECK: A[i_flat / 5, i_flat % 5] = + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + { + SimpleIREvaluator eval1(loops[0], {a_buf}); + PaddedBuffer inp1(10, 5); + eval1(inp1); + SimpleIREvaluator eval2(flattened, {a_buf}); + PaddedBuffer inp2(10, 5); + eval2(inp2); + ExpectAllNear(inp1, inp2, 1e-5); + } +} + +TEST(LoopNest, FlattenSimpleLoopNest3D) { + KernelScope kernel_scope; + + // Input IR: + // for (int i = 0; i < 10; i++) { + // for (int j = 0; j < 5; j++) { + // for (int k = 0; k < 7; k++) { + // A[i,j,k] = i + j * k; + // } + // } + // } + BufHandle a_buf("A", {10, 5, 7}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j, k}, i + j * k, 1)}); + auto for1 = For::make(k, 0, 7, for_body); + auto for2 = For::make(j, 0, 5, for1); + auto for3 = For::make(i, 0, 10, for2); + Block::make({for3}); + + std::vector loops = {for3, for2, for1}; + For* flattened = nullptr; + bool success = LoopNest::flatten(loops, &flattened); + ASSERT_TRUE(success); + + auto result = IRSimplifier::simplify(flattened); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int i_flat = 0; i_flat < 350; i_flat++) { + # CHECK: A[i_flat / 35, (i_flat / 7) % 5, i_flat % 7] = + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + { + SimpleIREvaluator eval1(loops[0], {a_buf}); + PaddedBuffer inp1(10, 5, 7); + eval1(inp1); + SimpleIREvaluator eval2(flattened, {a_buf}); + PaddedBuffer inp2(10, 5, 7); + eval2(inp2); + ExpectAllNear(inp1, inp2, 1e-5); + } +} + +TEST(LoopNest, FlattenLoopNestAfterNormalize) { + KernelScope kernel_scope; + + // Input IR: + // for (int i = 2; i < 10; i++) { + // for (int j = 3; j < 15; j++) { + // A[i - 2,j - 3] = i * j; + // } + // } + BufHandle a_buf("A", {8, 12}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i - 2, j - 3}, i * j, 1)}); + auto inner_for = For::make(j, 3, 15, for_body); + auto outer_for = For::make(i, 2, 10, inner_for); + Block::make({outer_for}); + + std::vector loops = {outer_for, inner_for}; + For* flattened = nullptr; + bool success = LoopNest::flatten(loops, &flattened); + ASSERT_TRUE(success); + + auto result = IRSimplifier::simplify(flattened); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int i_flat = 0; i_flat < 96; i_flat++) { + # CHECK: A[i_flat / 12, i_flat % 12] = + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + { + SimpleIREvaluator eval1(loops[0], {a_buf}); + PaddedBuffer inp1(8, 12); + eval1(inp1); + SimpleIREvaluator eval2(flattened, {a_buf}); + PaddedBuffer inp2(8, 12); + eval2(inp2); + ExpectAllNear(inp1, inp2, 1e-5); + } +} + +TEST(LoopNest, FlattenImperfectLoopNest) { + KernelScope kernel_scope; + + // Input IR: + // for (int i = 0; i < 10; i++) { + // A[i, i] = 0; + // for (int j = 0; j < 15; j++) { + // A[i,j] = i * j; + // } + // } + // Do not flatten. + + BufHandle a_buf("A", {10, 15}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j, 1)}); + auto inner_for = For::make(j, 0, 15, for_body); + auto outer_for = For::make( + i, 0, 10, Block::make({Store::make(a_buf, {i, i}, 0, 1), inner_for})); + Block::make({outer_for}); + + std::vector loops = {outer_for, inner_for}; + For* flattened = nullptr; + bool success = LoopNest::flatten(loops, &flattened); + ASSERT_FALSE(success); + + auto result = IRSimplifier::simplify(flattened); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int i = 0; i < 10; i++) { + # CHECK-NEXT: A[i, i] = + # CHECK-NEXT: for (int j = 0; j < 15; j++) { + # CHECK-NEXT: A[i, j] = + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(LoopNest, FlattenReductionLoopNest) { + KernelScope kernel_scope; + + // Input IR: + // for (int i = 0; i < 10; i++) { + // S[i] = 0; + // for (int j = 0; j < 15; j++) { + // S[i] = S[i] + A[i,j]; + // } + // } + // Do not flatten. + + BufHandle a_buf("A", {10, 15}, kInt); + BufHandle s_buf("S", {10}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make( + s_buf, + {i}, + Load::make(s_buf, {i}, 1) + Load::make(a_buf, {i, j}, 1), + 1)}); + auto inner_for = For::make(j, 0, 15, for_body); + auto outer_for = For::make( + i, 0, 10, Block::make({Store::make(s_buf, {i}, 0, 1), inner_for})); + Block::make({outer_for}); + + std::vector loops = {outer_for, inner_for}; + For* flattened = nullptr; + bool success = LoopNest::flatten(loops, &flattened); + ASSERT_FALSE(success); + + auto result = IRSimplifier::simplify(flattened); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int i = 0; i < 10; i++) { + # CHECK-NEXT: S[i] = + # CHECK-NEXT: for (int j = 0; j < 15; j++) { + # CHECK-NEXT: S[i] = (S[i]) + (A[i, j]) + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(LoopNest, FlattenReductionLoopNestFromTensor) { + KernelScope kernel_scope; + const int M = 3; + const int N = 7; + VarHandle m("m", kInt); + VarHandle n("n", kInt); + Placeholder b(BufHandle("b", {m, n}, kFloat)); + Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}}); + LoopNest loop({c}); + auto loops = loop.getLoopStmtsFor(c); + For* flattened; + bool success = LoopNest::flatten(loops, &flattened); + ASSERT_FALSE(success); + + auto result = IRSimplifier::simplify(flattened); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int m = 0; m < 3; m++) { + # CHECK-NEXT: sum[m] = + # CHECK-NEXT: for (int n = 0; n < 7; n++) { + # CHECK-NEXT: sum[m] = + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(LoopNest, FlattenIncorrectLoopsAsInput) { + KernelScope kernel_scope; + + // Input IR: + // for (int i = 0; i < 10; i++) { + // for (int j = 0; j < 5; j++) { + // A[i,j] = i * j; + // } + // } + // for (int x = 0; x < 10; x++) { + // for (int y = 0; y < 5; y++) { + // A[x,y] = A[x,y] + x + y; + // } + // } + // Flatten({For_i, For_y}) => should not succeed + + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j, 1)}); + auto inner_for1 = For::make(j, 0, 5, for_body1); + auto outer_for1 = For::make(i, 0, 10, inner_for1); + auto for_body2 = Block::make( + {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}, 1) + x + y, 1)}); + auto inner_for2 = For::make(y, 0, 5, for_body2); + auto outer_for2 = For::make(x, 0, 10, inner_for2); + Block::make({outer_for1, outer_for2}); + + std::vector loops = {outer_for1, inner_for2}; + For* flattened = nullptr; + bool success = LoopNest::flatten(loops, &flattened); + ASSERT_FALSE(success); + + auto result = IRSimplifier::simplify(flattened); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int i = 0; i < 10; i++) { + # CHECK-NEXT: for (int j = 0; j < 5; j++) { + # CHECK-NEXT: A[i, j] = i * j + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(LoopNest, DetectInlineRankMismatch) { KernelScope kernel_scope; const int kTotalSize = 8; - Buffer a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Tensor* a = Compute( - "a", {{kTotalSize, "i"}}, [&](const VarHandle& i) { return a_buf(i); }); + Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); + Tensor* a = Compute("a", {{kTotalSize, "i"}}, [&](const VarHandle& i) { + return a_buf.load(i); + }); Tensor* reshape = Compute( "reshape", {{kTotalSize / 2, "i"}, {2, "j"}}, @@ -2916,7 +3256,484 @@ void testDetectInlineRankMismatch() { LoopNest l({reshape}); ASSERT_THROWS_WITH( l.computeInline(l.getLoopBodyFor(a)), - "Buffer indexed access is inconsistent with its rank"); + "Placeholder indexed access is inconsistent with its rank"); +} + +TEST(LoopNest, CacheReadsSimple) { + KernelScope kernel_scope; + + Tensor* A = Compute( + "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor* B = Compute( + "B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 30, j + 3); + }); + Tensor* C = Compute( + "C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 10, j + 20) + A->call(i + 30, j + 40); + }); + + LoopNest l({B, C}); + Stmt* j_loop = l.getLoopStmtsFor(B)[1]; + l.cacheAccesses(A->buf(), "A_local", j_loop); + + l.prepareForCodegen(); + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + + // just this once: verify the whole thing. + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(A, int, {4096}); +#CHECK: for (int i +#CHECK: for (int j +#CHECK: A[ +#CHECK: } +#CHECK: } +#CHECK: for (int i_1 +#CHECK: Allocate(A_local, int, {10}); +#CHECK: for (int j_1 +#CHECK: A_local[j_1] = A[ +#CHECK: } +#CHECK: for (int j_2 +#CHECK: B[10 * i_1 + j_2] = A_local[j_2]; +#CHECK: } +#CHECK: Free(A_local); +#CHECK: } +#CHECK: for (int i_2 +#CHECK: for (int j_3 +#CHECK: C[ +#CHECK: } +#CHECK: } +#CHECK: Free(A); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + SimpleIREvaluator cg(l.root_stmt(), {B, C}); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 30) * (j + 3); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +TEST(LoopNest, CacheReadsOuter) { + KernelScope kernel_scope; + + Tensor* A = Compute( + "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor* B = Compute( + "B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 30, j + 40) + A->call(i + 31, j + 41); + }); + Tensor* C = Compute( + "C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 10, j + 20) + A->call(i + 30, j + 40); + }); + + LoopNest l({B, C}); + Stmt* i_loop = l.getLoopStmtsFor(B)[0]; + l.cacheAccesses(A->buf(), "A_local", i_loop); + + l.prepareForCodegen(); + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(A_local, int, {231}); +#CHECK: A_local[j_1 + 11 * i_1] = +#CHECK: B[10 * i_2 + j_2] = (A_local[(j_2 + 11 * i_2) + 12]) + (A_local[j_2 + 11 * i_2]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + SimpleIREvaluator cg(l.root_stmt(), {B, C}); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +TEST(LoopNest, CacheReadsInternal) { + KernelScope kernel_scope; + + Tensor* A = Compute( + "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor* B = Compute( + "B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 30, j + 40) + A->call(i + 31, j + 41); + }); + Tensor* C = Compute( + "C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 10, j + 20) + A->call(i + 30, j + 40); + }); + + LoopNest l({B, C}); + Stmt* j_loop = l.getLoopStmtsFor(B)[1]; + l.cacheAccesses(A->buf(), "A_local", j_loop); + l.prepareForCodegen(); + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(A_local, int, {22}); +#CHECK: A_local[j_1 + 11 * i_2] = +#CHECK: B[10 * i_1 + j_2] = (A_local[j_2]) + (A_local[j_2 + 12]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + SimpleIREvaluator cg(l.root_stmt(), {B, C}); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +TEST(LoopNest, CacheReadsInner) { + KernelScope kernel_scope; + + Tensor* A = Compute( + "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + // note im changing the offset of the first arg of the first call to A. + Tensor* B = Compute( + "B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 34, j + 40) + A->call(i + 30, j + 41); + }); + Tensor* C = Compute( + "C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 10, j + 20) + A->call(i + 30, j + 40); + }); + + LoopNest l({B, C}); + Stmt* body = l.getLoopBodyFor(B); + l.cacheAccesses(A->buf(), "A_local", body); + l.prepareForCodegen(); + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(A_local, int, {10}); +#CHECK: A_local[2 * i_2 + j_2] = +#CHECK: B[10 * i_1 + j_1] = (A_local[8]) + (A_local[1]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + SimpleIREvaluator cg(l.root_stmt(), {B, C}); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 34) * (j + 40) + (i + 30) * (j + 41); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +TEST(LoopNest, CacheWritesSimple) { + KernelScope kernel_scope; + + Tensor* A = Compute( + "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor* B = Compute( + "B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 30, j + 40) + A->call(i + 31, j + 41); + }); + Tensor* C = Compute( + "C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i + 10, j + 20) + A->call(i + 30, j + 40); + }); + + LoopNest l({B, C}); + Stmt* a_loop = l.getLoopStmtsFor(A)[1]; + l.cacheAccesses(A->buf(), "A_local", a_loop); + + l.prepareForCodegen(); + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(A_local, int, {64}); +#CHECK: for (int j = 0; j < 64 +#CHECK: A_local[j] = i * j; +#CHECK: for (int j_1 = 0; j_1 < 64 +#CHECK: A[64 * i + j_1] = A_local[ +#CHECK: Free(A_local); +#CHECK-NOT: A_local + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + SimpleIREvaluator cg(l.root_stmt(), {B, C}); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +TEST(LoopNest, DeadStoreElimination) { + KernelScope kernel_scope; + VarHandle y("y", kInt); + VarHandle x("x_tail", kInt); + BufHandle f("f", {26, 5}, kFloat); + BufHandle g("g", {26, 5}, kFloat); + ExprHandle x_outer_end = 5; + ExprHandle x_2 = x + x_outer_end * 4; + For* stmt1 = For::make( + x, + 0, + 5, + For::make( + y, + 0, + 5, + Block::make({ + Store::make(f, {x_2, y}, (x_2 + y), 1), + Store::make(g, {x_2, y}, (x_2 * y), 1), + }))); + Stmt* stmt = Block::make({stmt1}); + + // Will eliminate if not used by an output. + LoopNest loop(stmt, {f.node()}, {}, {}); + loop.eliminateDeadStores(); + + std::ostringstream oss; + oss << *loop.root_stmt(); + + const std::string& expected_ir = + R"IR( +#CHECK: f[x_tail + 5 * 4, y] +#CHECK-NOT: g[x_tail + 5 * 4, y] + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + // But won't eliminate if used by different outputs. + LoopNest loop2(stmt, {f.node(), g.node()}, {}, {}); + loop2.eliminateDeadStores(); + + oss.clear(); + oss << *loop2.root_stmt(); + + const std::string& expected_ir2 = + R"IR( +#CHECK: f[x_tail + 5 * 4, y] +#CHECK: g[x_tail + 5 * 4, y] + )IR"; + torch::jit::testing::FileCheck().run(expected_ir2, oss.str()); +} + +TEST(LoopNest, DeadStoreEliminationWithIntermediates) { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + BufHandle f("f", {26 * 5}, kFloat); + BufHandle g("g", {26 * 5}, kFloat); + BufHandle h("h", {26, 5}, kFloat); + ExprHandle x_outer_end = 5; + ExprHandle x_2 = x + x_outer_end * 4; + For* stmt1 = For::make(x, 0, 26 * 5, Store::make(f, {x}, x)); + For* stmt2 = For::make(z, 0, 26 * 5, Store::make(g, {z}, z + 1)); + For* stmt3 = For::make( + x, + 0, + 5, + For::make( + y, + 0, + 5, + Block::make({ + Store::make(h, {x, y}, Load::make(f, {x * y}, 1), 1), + }))); + Stmt* stmt = Block::make({stmt1, stmt2, stmt3}); + + // Will eliminate the write to g, but not f since it used by the producer of + // h. + LoopNest loop(stmt, {h.node()}, {}, {}); + loop.eliminateDeadStores(); + + std::ostringstream oss; + oss << *loop.root_stmt(); + + const std::string& expected_ir = + R"IR( + #CHECK: f[x] = x; + #CHECK-NOT: g[z] = + #CHECK: h[x, y] = f[x * y]; + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + // Sanity check won't eliminate if g is an output. + LoopNest loop2(stmt, {h.node(), g.node()}, {}, {}); + loop2.eliminateDeadStores(); + + oss.clear(); + oss << *loop2.root_stmt(); + + const std::string& expected_ir2 = + R"IR( + #CHECK: f[x] = x; + #CHECK: g[z] = z + 1; + #CHECK: h[x, y] = f[x * y]; + )IR"; + torch::jit::testing::FileCheck().run(expected_ir2, oss.str()); +} + +TEST(LoopNest, CompoundTensorSimple) { + KernelScope kernel_scope; + + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j, 1)}); + auto inner_for1 = For::make(j, 0, 5, for_body1); + auto outer_for1 = For::make(i, 0, 10, inner_for1); + auto for_body2 = Block::make( + {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}, 1) + x + y, 1)}); + auto inner_for2 = For::make(y, 0, 5, for_body2); + auto outer_for2 = For::make(x, 0, 10, inner_for2); + Block* body = Block::make({outer_for1, outer_for2}); + + Tensor* A = new CompoundTensor(a_buf.node(), {i.node(), j.node()}, body); + + LoopNest l({A}); + l.prepareForCodegen(); + + std::vector a_data(50, 0); + + Stmt* s = IRSimplifier::simplify(l.root_stmt()); + SimpleIREvaluator cg(s, {A}); + + std::vector a_ref(50, 0); + + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 5; ++j) { + a_ref[i * 5 + j] = (i * j) + i + j; + } + } + cg.call({a_data}); + + assertAllEqual(a_data, a_ref); +} + +TEST(LoopNest, CompoundTensorUsed) { + KernelScope kernel_scope; + + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j, 1)}); + auto inner_for1 = For::make(j, 0, 5, for_body1); + auto outer_for1 = For::make(i, 0, 10, inner_for1); + auto for_body2 = Block::make( + {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}, 1) + x + y, 1)}); + auto inner_for2 = For::make(y, 0, 5, for_body2); + auto outer_for2 = For::make(x, 0, 10, inner_for2); + Block* body = Block::make({outer_for1, outer_for2}); + + Tensor* A = new CompoundTensor(a_buf.node(), {i.node(), j.node()}, body); + Tensor* B = Compute( + "B", {{10, "i"}, {3, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + return A->call(i, j + 1) + A->call(i, j + 2); + }); + + LoopNest l({B}); + ASSERT_FALSE(l.computeInline(A->buf())); + l.prepareForCodegen(); + + std::vector a_data(50, 0); + std::vector b_data(50, 0); + + Stmt* s = IRSimplifier::simplify(l.root_stmt()); + SimpleIREvaluator cg(s, {B}); + + std::vector b_ref(50, 0); + + auto AT = [](int i, int j) { return i * j + i + j; }; + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 3; ++j) { + b_ref[i * 3 + j] = AT(i, j + 1) + AT(i, j + 2); + } + } + cg.call({b_data}); + + assertAllEqual(b_data, b_ref); } } // namespace jit diff --git a/test/cpp/tensorexpr/test_memdependency.cpp b/test/cpp/tensorexpr/test_memdependency.cpp new file mode 100644 index 0000000000000..7105b9d0ff204 --- /dev/null +++ b/test/cpp/tensorexpr/test_memdependency.cpp @@ -0,0 +1,3186 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +// Test helper function used to determine if two regions of a buffer have an +// overlap. No Overlap & partial overlap is obvious. Contains means A is +// larger and fully encloses B, while ContainedOrEqual is the reverse. Equal +// ranges are ContainedOrEqual. +TEST(MemDependency, BoundOverlap) { + KernelScope kernel_scope; + + using namespace analysis; + + auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); }; + + // Sanity check 3 overlap cases. + ASSERT_EQ(ContainedOrEqual, boundOverlap(CB(0, 0), CB(0, 0))); + ASSERT_EQ(PartialOverlap, boundOverlap(CB(0, 3), CB(2, 5))); + ASSERT_EQ(NoOverlap, boundOverlap(CB(0, 0), CB(1, 1))); + + // Partial overlap works in either order. + ASSERT_EQ(PartialOverlap, boundOverlap(CB(0, 10), CB(7, 14))); + ASSERT_EQ(PartialOverlap, boundOverlap(CB(7, 14), CB(0, 10))); + + // Total Overlap works when one bound encloses the other, and returns which. + ASSERT_EQ(Contains, boundOverlap(CB(2, 15), CB(7, 9))); + ASSERT_EQ(ContainedOrEqual, boundOverlap(CB(2, 15), CB(0, 16))); + + // Total overlap works when the bounds are an identical range, returns + // ContainedOrEqual. + ASSERT_EQ(ContainedOrEqual, boundOverlap(CB(2, 15), CB(2, 15))); + + // Total overlap when only one end of the bound matches. + ASSERT_EQ(Contains, boundOverlap(CB(2, 15), CB(2, 10))); + ASSERT_EQ(Contains, boundOverlap(CB(2, 15), CB(3, 15))); + ASSERT_EQ(Contains, boundOverlap(CB(0, 10), CB(0, 9))); + ASSERT_EQ(ContainedOrEqual, boundOverlap(CB(2, 10), CB(2, 15))); + ASSERT_EQ(ContainedOrEqual, boundOverlap(CB(3, 15), CB(2, 15))); + + // No overlap when a < b. + ASSERT_EQ(NoOverlap, boundOverlap(CB(0, 2), CB(5, 10))); + ASSERT_EQ(NoOverlap, boundOverlap(CB(2, 2), CB(3, 3))); + ASSERT_EQ(NoOverlap, boundOverlap(CB(100, 120), CB(130, 130))); + + // No overlap when a > b. + ASSERT_EQ(NoOverlap, boundOverlap(CB(5, 10), CB(0, 2))); + ASSERT_EQ(NoOverlap, boundOverlap(CB(3, 3), CB(2, 2))); + ASSERT_EQ(NoOverlap, boundOverlap(CB(130, 130), CB(100, 120))); + + // No overlap when adjacent. + ASSERT_EQ(NoOverlap, boundOverlap(CB(0, 100), CB(101, 120))); + ASSERT_EQ(NoOverlap, boundOverlap(CB(2, 3), CB(0, 1))); + + // Partial overlap when middle bounds match. + ASSERT_EQ(PartialOverlap, boundOverlap(CB(0, 100), CB(100, 120))); + ASSERT_EQ(PartialOverlap, boundOverlap(CB(0, 2), CB(2, 4))); + ASSERT_EQ(PartialOverlap, boundOverlap(CB(100, 120), CB(0, 100))); + ASSERT_EQ(PartialOverlap, boundOverlap(CB(2, 3), CB(1, 2))); + + // Total overlap when one bound is single length over one end of the other. + ASSERT_EQ(Contains, boundOverlap(CB(2, 15), CB(15, 15))); + ASSERT_EQ(Contains, boundOverlap(CB(2, 15), CB(2, 2))); + ASSERT_EQ(ContainedOrEqual, boundOverlap(CB(2, 2), CB(2, 15))); + ASSERT_EQ(ContainedOrEqual, boundOverlap(CB(15, 15), CB(2, 15))); +} + +TEST(MemDependency, BoundOverlapSymbolic) { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + VarHandle w("w", kInt); + + using namespace analysis; + + auto CB = [](ExprHandle s, ExprHandle e) { + return Bound(s.node(), e.node()); + }; + + // Sanity check cases where the start and end is symbolic but the diff is + // constant. + ASSERT_EQ(ContainedOrEqual, boundOverlap(CB(x, x), CB(x, x))); + ASSERT_EQ(PartialOverlap, boundOverlap(CB(x, x + 3), CB(x + 2, x + 5))); + ASSERT_EQ(NoOverlap, boundOverlap(CB(x, x), CB(x + 1, x + 1))); + + // We can't infer the sign of y, so cannot tell whether adding y is larger or + // smaller than y/2. + ASSERT_EQ(PartialOverlap, boundOverlap(CB(x, x + y), CB(x, x + y / 2))); + + // No information about this bound, have to take the most conservative option: + // there may be an overlap. + ASSERT_EQ(PartialOverlap, boundOverlap(CB(x, y), CB(z, w))); + + // Math on opaque terms works. + ASSERT_EQ(ContainedOrEqual, boundOverlap(CB(x + w, y - z), CB(x + w, y - z))); + // Even requiring simplification. + ASSERT_EQ(ContainedOrEqual, boundOverlap(CB(x - w - w, y), CB(x - w * 2, y))); +} + +// Tests the helper function for overlap of multi dimensional indices bounds. +// This uses boundOverlap on each dimension and return the "lowest" kind of +// overlap. +TEST(MemDependency, BoundOverlapMultiDim) { + KernelScope kernel_scope; + + using namespace analysis; + + auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); }; + + // Sanity check one dimensional cases. + ASSERT_EQ(ContainedOrEqual, overlaps({CB(0, 0)}, {CB(0, 0)})); + ASSERT_EQ(NoOverlap, overlaps({CB(0, 2)}, {CB(5, 10)})); + ASSERT_EQ(PartialOverlap, overlaps({CB(0, 100)}, {CB(100, 120)})); + + // Total overlap in 3 dims. + ASSERT_EQ( + ContainedOrEqual, + overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 4)})); + ASSERT_EQ( + ContainedOrEqual, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 10)})); + + // Total overlap in 2 dims, no overlap in another. + ASSERT_EQ( + NoOverlap, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(5, 10)})); + + // Total overlap in 2 dims, partial overlap in another. + ASSERT_EQ( + PartialOverlap, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(0, 5), CB(5, 10)})); + // This case is most important, so verify the overlap in any dim. (dim 2) + ASSERT_EQ( + PartialOverlap, + overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(2, 6), CB(0, 5)})); + // Dim 1. + ASSERT_EQ( + PartialOverlap, + overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(1, 3), CB(0, 5), CB(0, 5)})); + // Total overlap in 1 dim, partial in 2. + ASSERT_EQ( + PartialOverlap, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(0, 5), CB(5, 10)})); + // Total overlap, partial overlap, no overlap. + ASSERT_EQ( + NoOverlap, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(11, 15), CB(0, 5)})); + + // Total overlap (B) in 2 dims, total overlap (A) in another. + ASSERT_EQ( + Contains, + overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 4)})); + + // Total overlap (A) in 2 dims, total overlap (B) in another. + ASSERT_EQ( + Contains, + overlaps( + {CB(0, 12), CB(0, 15), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 14)})); + + // Total (B), No Overlap, Total (A). + ASSERT_EQ( + NoOverlap, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 6), CB(11, 15), CB(1, 2)})); +} + +// Test the helper we use to subtract bounds: returns the regions(s) of A which +// remain after removing the region of B. +TEST(MemDependency, BoundSubtract) { + KernelScope kernel_scope; + + using namespace analysis; + + auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); }; + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + // One element subtract. + ASSERT_EQ(subtractBound(CB(0, 0), CB(0, 0)).size(), 0); + ASSERT_EQ(subtractBound(CB(5, 5), CB(5, 5)).size(), 0); + + // No Overlap. + ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(2, 2)), {CB(5, 5)})); + ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(0, 4)), {CB(5, 5)})); + + // one side overlap. + ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(4, 7)), {CB(1, 3)})); + ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(5, 7)), {CB(0, 4)})); + ASSERT_TRUE(EQ(subtractBound(CB(4, 5), CB(1, 4)), {CB(5, 5)})); + ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 4)), {CB(5, 5)})); + + // both sides overlap. + ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 7)), {})); + ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(5, 7)), {})); + + // internal overlap. + ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(2, 3)), {CB(1, 1), CB(4, 5)})); + ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(2, 4)), {CB(0, 1), CB(5, 5)})); +} + +TEST(MemDependency, BoundSubtractSymbolic) { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + VarHandle w("w", kInt); + + using namespace analysis; + + auto CB = [](ExprHandle s, ExprHandle e) { + return Bound(s.node(), e.node()); + }; + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + // One element subtract. + ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(x, x)), {})); + ASSERT_TRUE(EQ(subtractBound(CB(x + 1, x + 1), CB(x + 1, x + 1)), {})); + ASSERT_TRUE(EQ(subtractBound(CB(x * 2, x * 2), CB(x * 2, x * 2)), {})); + + // Subtract constant range low. + ASSERT_TRUE( + EQ(subtractBound(CB(x, x + 10), CB(x, x + 4)), {CB(x + 5, x + 10)})); + // Subtract constant range high. + ASSERT_TRUE( + EQ(subtractBound(CB(x, x + 10), CB(x + 6, x + 12)), {CB(x, x + 5)})); + // Subtract constant range total overlap. + ASSERT_TRUE(EQ(subtractBound(CB(x, x + 10), CB(x, x + 10)), {})); + ASSERT_TRUE(EQ(subtractBound(CB(x + 2, x + 10), CB(x, x + 12)), {})); + // Subtract constant range internal. + ASSERT_TRUE( + EQ(subtractBound(CB(x, x + 10), CB(x + 3, x + 7)), + {CB(x, x + 2), CB(x + 8, x + 10)})); + + // Size is inferable but not constant, only works with a single var. + ASSERT_TRUE(EQ(subtractBound(CB(0, x), CB(0, x * 2)), {})); + ASSERT_TRUE(EQ(subtractBound(CB(0, x * 2), CB(0, x - 1)), {CB(x, x * 2)})); + + // Size is not inferable. + ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(z, w)), {CB(x, y)})); + ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(x, z)), {CB(x, y)})); + ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(0, x)), {CB(x, y)})); + ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(0, 0)), {CB(x, x)})); +} + +// Tests the helper function that does subtraction, but for multi dimensional +// indices bounds. +TEST(MemDependency, BoundSubtractMultiDim) { + KernelScope kernel_scope; + + using namespace analysis; + + auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); }; + auto EQ = [](std::vector x, std::vector y) { + if (x.size() != y.size()) { + return false; + } + for (auto i = 0; i < x.size(); ++i) { + if (!indexBoundsEquals(x[i], y[i])) { + return false; + } + } + return true; + }; + + // sanity check one dimension. + ASSERT_TRUE(EQ(subtractIndicesBounds({CB(0, 9)}, {CB(0, 9)}), {})); + ASSERT_TRUE(EQ(subtractIndicesBounds({CB(3, 9)}, {CB(0, 12)}), {})); + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, 12)}, {CB(0, 9)}), {{CB(10, 12)}})); + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, 12)}, {CB(3, 12)}), {{CB(0, 2)}})); + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(0, 9)}, {CB(1, 8)}), {{CB(0, 0)}, {CB(9, 9)}})); + + // Multi dim total overlap. + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 9), CB(0, 2)}), {})); + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 10), CB(0, 20)}), {})); + + // Mutli dim one way partial in dim 1. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 3), CB(0, 2)}), + {{CB(4, 9), CB(0, 2)}})); + + // Mutli dim one way partial in dim 2. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, 9), CB(0, 20)}, {CB(0, 9), CB(0, 10)}), + {{CB(0, 9), CB(11, 20)}})); + + // Partial overlap in 2 dims. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8)}), + {{CB(0, 1), CB(0, 5)}, {CB(2, 5), CB(0, 1)}})); + + // Partial overlap in 3 dims. + ASSERT_TRUE( + EQ(subtractIndicesBounds( + {CB(0, 5), CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8), CB(2, 8)}), + {{CB(0, 1), CB(0, 5), CB(0, 5)}, + {CB(2, 5), CB(0, 1), CB(0, 5)}, + {CB(2, 5), CB(2, 5), CB(0, 1)}})); +} + +// Tests the multi dimensional subtraction code for bounds that cannot be fully +// materialized. +TEST(MemDependency, BoundSubtractMultiDimSymbolic) { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + using namespace analysis; + + auto CB = [](ExprHandle s, ExprHandle e) { + return Bound(s.node(), e.node()); + }; + + auto EQ = [](std::vector x, std::vector y) { + if (x.size() != y.size()) { + return false; + } + for (auto i = 0; i < x.size(); ++i) { + if (!indexBoundsEquals(x[i], y[i])) { + return false; + } + } + return true; + }; + + // Cannot determine overlaps. + ASSERT_TRUE(EQ(subtractIndicesBounds({CB(x, x)}, {CB(0, 0)}), {{CB(x, x)}})); + + // Various total Overlaps. + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(x, x), CB(x, x)}, {CB(x, x), CB(x, x)}), {})); + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(x, y), CB(x, y)}, {CB(x, y), CB(x, y)}), {})); + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(x, x), CB(y, y)}, {CB(x, x), CB(y, y)}), {})); + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(0, y)}), {})); + + // one-way overlap in first dim. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x - 5), CB(0, y)}), + {{CB(x - 4, x), CB(0, y)}})); + // second dim. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(5, y)}), + {{CB(0, x), CB(0, 4)}})); + + // Internal overlap in first dim. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(2, x - 5), CB(0, y)}), + {{CB(0, 1), CB(0, y)}, {CB(x - 4, x), CB(0, y)}})); + // second dim. + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(10, y - 10)}), + {{CB(0, x), CB(0, 9)}, {CB(0, x), CB(y - 9, y)}})); + + // Overlap in both dimensions. + ASSERT_TRUE( + EQ(subtractIndicesBounds( + {CB(0, x), CB(0, y)}, {CB(5, x - 5), CB(10, y - 10)}), + { + {CB(0, 4), CB(0, y)}, + {CB(x - 4, x), CB(0, y)}, + {CB(0, x), CB(0, 9)}, + {CB(0, x), CB(y - 9, y)}, + })); +} + +// Simple check that the analyzer does anything at all... +TEST(MemDependency, MemDependencyCheckerSimple) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + + analysis::MemDependencyChecker analyzer; + + /* + * A[0] = 3; + * B[0] = A[0] + 1; + */ + + Store* aStore = Store::make(a, {0}, 3, 1); + Store* bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}, 1), 1), 1); + + Stmt* stmt = Block::make({aStore, bStore}); + + stmt->accept(&analyzer); + + ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); + ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore)); + // sanity check, but anything that depends directly must depend indirectly. + ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aStore)); +} + +// Check that there is a difference between direct and indirect dependence. +TEST(MemDependency, MemDependencyCheckerMultiStmt) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + BufHandle c("C", {1}, kInt); + + analysis::MemDependencyChecker analyzer; + + /* + * A[0] = 3; + * B[0] = A[0]; + * C[0] = B[0] + 1; + */ + + Store* aStore = Store::make(a, {0}, 3, 1); + Store* bStore = Store::make(b, {0}, Load::make(a, {0}, 1), 1); + Store* cStore = Store::make(c, {0}, Add::make(Load::make(b, {0}, 1), 1), 1); + + Stmt* stmt = Block::make({aStore, bStore, cStore}); + + stmt->accept(&analyzer); + + // C depends on A indirectly. + ASSERT_FALSE(analyzer.dependsDirectly(cStore, aStore)); + ASSERT_TRUE(analyzer.dependsIndirectly(cStore, aStore)); + + // C depends on B directly, which depends on A directly. + ASSERT_TRUE(analyzer.dependsDirectly(cStore, bStore)); + ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); + + // Dependency goes top to bottom only. + ASSERT_FALSE(analyzer.dependsIndirectly(bStore, cStore)); + ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore)); + ASSERT_FALSE(analyzer.dependsIndirectly(aStore, cStore)); +} + +// Verify that we do filter writes that are totally overlapped by later writes. +TEST(MemDependency, MemDependencyCheckerOverlap) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + + analysis::MemDependencyChecker analyzer; + + /* + * A[0] = 3; + * A[0] = 6; + * B[0] = A[0] + 1; + */ + + Store* aStore = Store::make(a, {0}, 3, 1); + Store* a2Store = Store::make(a, {0}, 6, 1); + Store* bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}, 1), 1), 1); + + Stmt* stmt = Block::make({aStore, a2Store, bStore}); + + stmt->accept(&analyzer); + + // B store depends on second A store but not first since it is completely + // overlapped. + ASSERT_TRUE(analyzer.dependsIndirectly(bStore, a2Store)); + ASSERT_FALSE(analyzer.dependsIndirectly(bStore, aStore)); + + // No dependency between either A store. + ASSERT_FALSE(analyzer.dependsIndirectly(aStore, a2Store)); + ASSERT_FALSE(analyzer.dependsIndirectly(a2Store, aStore)); +} + +// Verify that bounds match loop iterations, and that dependencies progress +// across loop scopes. +TEST(MemDependency, MemDependencyCheckerLoop) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + MemDependencyChecker analyzer; + + /* + * for (int x = 0; x < 10; ++x) { + * A[x] = x; + * } + * B[0] = A[0] + 1; + */ + + Store* aStore = Store::make(a, {x}, x, 1); + Stmt* loop = For::make(x, 0, 10, aStore); + Store* bStore = Store::make(b, {0}, Add::make(Load::make(a, {4}, 1), 1), 1); + + Stmt* stmt = Block::make({loop, bStore}); + + stmt->accept(&analyzer); + + // Same A->B dependency. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); + + // B depends on the loop. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); + // A is in the loop but does not depend on any loop iteration. + ASSERT_FALSE(analyzer.dependsIndirectly(aStore, loop)); + + auto aStoreAccess = analyzer.accessFor(aStore); + ASSERT_NE(aStoreAccess, nullptr); + + // It should have bounds covering the range of x: 0 <= x < 10. + ASSERT_TRUE(indexBoundsEquals( + aStoreAccess->bounds(), {Bound(new IntImm(0), new IntImm(9))})); +} + +// Reductions should promote dependencies as well. +TEST(MemDependency, MemDependencyCheckerLoopReduce) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + MemDependencyChecker analyzer; + + /* + * A[0] = 0; + * for (int x = 0; x < 10; ++x) { + * A[0] = A[x] + 1; + * } + * B[0] = A[0]; + */ + + Store* aInit = Store::make(a, {0}, 0, 1); + ExprHandle reduce = + ExprHandle(Sum()(a.node(), ExprHandle(1), {x.node()}, {x.node()})); + Store* aReduce = Store::make(a, {0}, reduce, 1); + Stmt* loop = For::make(x, 0, 10, aReduce); + Store* bStore = Store::make(b, {0}, Load::make(a, {0}, 1), 1); + + Stmt* stmt = Block::make({aInit, loop, bStore}); + + stmt->accept(&analyzer); + + // B -> A. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce)); + + // B depends indirectly on the intializer of A, since the reduction depends + // on it. + ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit)); + ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit)); + + ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit)); + + // B depends on the loop. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); + // A is in the loop and depends on other iterations. + ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop)); + + // The loop contents depend on the initializer too. + ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit)); + + // Find loads within the reduction: + auto reduceLoads = NodeFinder::find(reduce.node()); + // Pull out the access for the load inside the loop. + for (auto* load : reduceLoads) { + auto loopLoad = analyzer.accessFor(load); + // It should have 10 element long bounds. + ASSERT_TRUE(indexBoundsEquals( + loopLoad->bounds(), {Bound(new IntImm(0), new IntImm(9))})); + } +} + +// Lowering a reduction doesn't affect dependency analysis. +TEST(MemDependency, MemDependencyCheckerLoopReduceExpanded) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + MemDependencyChecker analyzer; + + /* + * A[0] = 0; + * for (int x = 0; x < 10; ++x) { + * A[0] = A[x] + 1; + * } + * B[0] = A[0]; + */ + + Store* aInit = Store::make(a, {0}, 0, 1); + ExprHandle aLoad = Load::make(a, {x}, 1); + Store* aReduce = Store::make(a, {0}, Add::make(aLoad, 1)); + Stmt* loop = For::make(x, 0, 10, aReduce); + Store* bStore = Store::make(b, {0}, Load::make(a, {0}, 1), 1); + + Stmt* stmt = Block::make({aInit, loop, bStore}); + + stmt->accept(&analyzer); + + // B -> A. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce)); + + // B depends indirectly on the intializer of A, since the reduction depends + // on it. + ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit)); + ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit)); + + ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit)); + + // B depends on the loop. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); + // A is in the loop and depends on other iterations. + ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop)); + + // The loop contents depend on the initializer too. + ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit)); + + // Pull out the access for the store inside the loop. + auto loopLoad = analyzer.accessFor(aLoad.node()); + // It should have 10 element long bounds. + ASSERT_TRUE(indexBoundsEquals( + loopLoad->bounds(), {Bound(new IntImm(0), new IntImm(9))})); +} + +// Can determine dependencies of outputs, through to inputs. +TEST(MemDependency, MemDependencyCheckerInputsOutputs) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + // initialize analyzer with inputs and outputs. + analysis::MemDependencyChecker analyzer({a}, {b}); + + // Here's a Relu. + /* + * for (int x = 0; x < 10; ++x) { + * B[x] = Max(A[x], 0); + * } + */ + + ExprHandle aLoad = Load::make(a, {x}, 1); + Store* bStore = Store::make(b, {x}, Max::make(aLoad, 0, true), 1); + Stmt* loop = For::make(x, 0, 10, bStore); + + Stmt* stmt = Block::make({loop}); + + stmt->accept(&analyzer); + + // Output depends indirectly on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + // aLoad depends directly on the input A. + ASSERT_TRUE(analyzer.dependsDirectly(aLoad.node(), a.node())); + // bStore therefore depends directly on the input A. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, a.node())); + // The output depends directly on the store. + ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore)); + + // Check AccessInfo based overloads. + auto input = analyzer.input(a.node()); + auto output = analyzer.output(b.node()); + + // Output depends indirectly on input. + ASSERT_TRUE(analyzer.dependsIndirectly(output, input)); + // Not directly. + ASSERT_FALSE(analyzer.dependsDirectly(output, input)); + // Not in reverse order. + ASSERT_FALSE(analyzer.dependsIndirectly(input, output)); + + // output -> bStore -> bLoad -> input. + auto storeAccess = analyzer.accessFor(bStore); + auto loadAccess = analyzer.accessFor(aLoad.node()); + + ASSERT_TRUE(analyzer.dependsDirectly(output, storeAccess)); + ASSERT_TRUE(analyzer.dependsDirectly(loadAccess, input)); +} + +// Can tell if an output does not depend on an input. +TEST(MemDependency, MemDependencyCheckerOutputDoesntDepend) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + // initialize analyzer with inputs and outputs. + analysis::MemDependencyChecker analyzer({a}, {b}); + + // Here's a dumb Relu. + /* + * for (int x = 0; x < 10; ++x) { + * B[x] = Max(x, 0); + * } + */ + + Store* bStore = Store::make(b, {x}, Max::make(x, 0, true), 1); + Stmt* loop = For::make(x, 0, 10, bStore); + + Stmt* stmt = Block::make({loop}); + + stmt->accept(&analyzer); + + // Output does not depend indirectly on input. + ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), a.node())); + + // The output still depends directly on the store. + ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore)); + + // Check AccessInfo based overloads. + auto input = analyzer.input(a.node()); + auto output = analyzer.output(b.node()); + + // Output does not depend indirectly on input. + ASSERT_FALSE(analyzer.dependsIndirectly(output, input)); +} + +// Verify different loop extents produce accesses with different bounds, and +// that later accesses find dependencies that overlap their entire bound range. +TEST(MemDependency, MemDependencyCheckerLoopBounds) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + BufHandle c("C", {10}, kInt); + VarHandle x("x", kInt); + using namespace analysis; + + MemDependencyChecker analyzer({a}, {c}); + + // This enables using the execution order of the loops to determine if some + // loops are self dependent or not. + analyzer.allowLoopExecutionOrderAnalysis(); + + /* + * for (int x = 1; x < 10; ++x) { + * B[x] = A[x]; + * } + * for (int x = 1; x < 9; ++x) { + * B[x] = B[x] * 2; + * } + * for (int x = 3; x < 4; ++x) { + * C[x] = A[x]; + * } + * for (int x = 0; x < 10; ++x) { + * C[x] = B[x]; + * } + */ + + std::vector stmts( + {For::make(x, 1, 10, Store::make(b, {x}, Load::make(a, {x}, 1), 1)), + For::make( + x, + 1, + 9, + Store::make(b, {x}, Mul::make(Load::make(b, {x}, 1), 2), 1)), + For::make(x, 3, 4, Store::make(c, {x}, Load::make(a, {x}, 1), 1)), + For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}, 1), 1))}); + + Stmt* stmt = Block::make(stmts); + + stmt->accept(&analyzer); + + auto input = analyzer.input(a.node()); + auto output = analyzer.output(c.node()); + + // sanity check Output -> Input. + ASSERT_TRUE(analyzer.dependsIndirectly(output, input)); + + // Check the For loop dependencies: + + // Last write to C depends on both writes to B since they contain the last + // write to at least one element. + ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[1])); + ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[0])); + + // The last write to C does not depend on the other write to C. + ASSERT_FALSE(analyzer.dependsIndirectly(stmts[3], stmts[2])); + + auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); }; + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + /* 0. Input: A[(0, 9)] - dependents: 1 5 + * 1. Load: A[(1, 9)] - depends on: 0 - dependents: 2 + * 2. Store: B[(1, 9)] - depends on: 1 - dependents: 3 7 + * 3. Load: B[(1, 8)] - depends on: 2 - dependents: 4 + * 4. Store: B[(1, 8)] - depends on: 3 - dependents: 7 + * 5. Load: A[(3, 3)] - depends on: 0 - dependents: 6 + * 6. Store: C[(3, 3)] - depends on: 5 + * 7. Load: B[(0, 9)] - depends on: 2 4 - dependents: 8 + * 8. Store: C[(0, 9)] - depends on: 7 - dependents: 9 + * 9. Output: C[(0, 9)] - depends on: 8 + */ + + // Now let's look at the bounds of each access. + // There are 9 accesses in this Stmt, so this is exhaustive, we wont do this + // much. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 10); + const Var* aVar = a.node()->base_handle(); + const Var* bVar = b.node()->base_handle(); + const Var* cVar = c.node()->base_handle(); + + // The first access is the input A. + ASSERT_EQ(history[0]->type(), AccessType::Input); + ASSERT_EQ(history[0]->var(), aVar); + // It has the bounds of the producing Input. + ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)})); + // sanity check the input we retrieved earlier matches. + ASSERT_EQ(history[0], input); + + // The second access is the load of A in the first loop. + ASSERT_EQ(history[1]->type(), AccessType::Load); + ASSERT_EQ(history[1]->var(), aVar); + // It has the bounds of the loop, i.e. start == 1. + ASSERT_TRUE(EQ(history[1]->bounds(), {CB(1, 9)})); + // It reads from A, so it should have a dependency on the last write to this + // range - with is the input. + ASSERT_EQ(history[1]->dependencies().size(), 1); + ASSERT_TRUE(history[1]->hasDependency(history[0])); + + // The third access is the store into B in the first loop. + ASSERT_EQ(history[2]->type(), AccessType::Store); + ASSERT_EQ(history[2]->var(), bVar); + // It also has the bounds of the loop, i.e. start == 1. + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)})); + // The previous load is in its RHS, so it depends on it. + ASSERT_EQ(history[2]->dependencies().size(), 1); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + + // The third access is the load from B in the second loop. + ASSERT_EQ(history[3]->type(), AccessType::Load); + ASSERT_EQ(history[3]->var(), bVar); + // It has the bounds of the second loop, i.e. >= 1 < 9. + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 8)})); + // It reads from B in a smaller range, so should depend on the previous + // store. + ASSERT_EQ(history[3]->dependencies().size(), 1); + ASSERT_TRUE(history[3]->hasDependency(history[2])); + + // The fourth: the store to B in the second loop. + ASSERT_EQ(history[4]->type(), AccessType::Store); + ASSERT_EQ(history[4]->var(), bVar); + // It also has the bounds of the second loop. + ASSERT_TRUE(EQ(history[4]->bounds(), {CB(1, 8)})); + // The previous load is in its RHS, so it depends on it as before. + ASSERT_EQ(history[4]->dependencies().size(), 1); + ASSERT_TRUE(history[4]->hasDependency(history[3])); + + // The fifth access is the load is from the 3rd loop, and skips previous B + // accesses. + ASSERT_EQ(history[5]->type(), AccessType::Load); + ASSERT_EQ(history[5]->var(), aVar); + // It has the bounds of the third loop: >= 3 < 4. + ASSERT_TRUE(EQ(history[5]->bounds(), {CB(3, 3)})); + // It depends on the last thing to write to A, which is the A input. + ASSERT_EQ(history[5]->dependencies().size(), 1); + ASSERT_TRUE(history[5]->hasDependency(history[0])); + + // Sixth: the store into the output C. + ASSERT_EQ(history[6]->type(), AccessType::Store); + ASSERT_EQ(history[6]->var(), cVar); + // It also has the bounds of the third loop. + ASSERT_TRUE(EQ(history[6]->bounds(), {CB(3, 3)})); + // The previous load is in its RHS, so it depends on it as always. + ASSERT_EQ(history[6]->dependencies().size(), 1); + ASSERT_TRUE(history[6]->hasDependency(history[5])); + + // The seventh access is the load of B in the fourth loop. + ASSERT_EQ(history[7]->type(), AccessType::Load); + ASSERT_EQ(history[7]->var(), bVar); + // It has the bounds of the final loop, >= 0 < 10 + ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)})); + // The bounds of this read are larger than the bounds of the previous write, + // so it depends on both previous Stores to B. + ASSERT_EQ(history[7]->dependencies().size(), 2); + ASSERT_TRUE(history[7]->hasDependency(history[2])); + ASSERT_TRUE(history[7]->hasDependency(history[4])); + + // Eight: the final store into the output C. + ASSERT_EQ(history[8]->type(), AccessType::Store); + ASSERT_EQ(history[8]->var(), cVar); + // It also has the bounds of the final loop. + ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)})); + // The previous load is in its RHS, so it depends on it as always. + ASSERT_EQ(history[8]->dependencies().size(), 1); + ASSERT_TRUE(history[8]->hasDependency(history[7])); + + // The last access represents the output Buf. + ASSERT_EQ(history[9]->type(), AccessType::Output); + ASSERT_EQ(history[9]->var(), cVar); + // It has the bounds of the output Buf. + ASSERT_TRUE(EQ(history[9]->bounds(), {CB(0, 9)})); + // sanity check the input we retrieved earlier matches. + ASSERT_EQ(history[9], output); + // It depends on the last write to C only. + ASSERT_EQ(history[9]->dependencies().size(), 1); + ASSERT_TRUE(history[9]->hasDependency(history[8])); +} + +// Verify that we can still infer bounds when the loop var is offset. +TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + MemDependencyChecker analyzer({a}, {b}); + + // This enables using the execution order of the loops to determine if some + // loops are self dependent or not. + analyzer.allowLoopExecutionOrderAnalysis(); + + /* + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + * for (int x = 0; x < 9; x++) { + * A[x] = A[x + 1]; + * } + * for (int x = 0; x < 9; x++) { + * A[9 - x] = A[8 - x]; + * } + * for (int x = 0; x < 10; x++) { + * A[x] = A[9 - x]; + * } + * for (int x = 0; x < 10; x++) { + * B[x] = A[x]; + * } + */ + + Stmt* stmt = Block::make( + {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}, 1), 1)), + For::make(x, 0, 9, Store::make(a, {x}, Load::make(a, {x + 1}, 1), 1)), + For::make( + x, + 0, + 9, + Store::make( + a, + {ExprHandle(9) - x}, + Load::make(a, {ExprHandle(8) - x}, 1), + 1)), + For::make( + x, + 0, + 10, + Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}, 1), 1)), + For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}, 1), 1))}); + + stmt->accept(&analyzer); + + // Sanity check output depends on Input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); }; + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + /* 0. Input: A[(0, 9)] - dependents: 1 + * 1. Load: A[(0, 8)] - depends on: 0 2 - dependents: 2 + * 2. Store: A[(1, 9)] - depends on: 1 - dependents: 1 3 + * 3. Load: A[(1, 9)] - depends on: 2 - dependents: 4 + * 4. Store: A[(0, 8)] - depends on: 3 - dependents: 5 7 + * 5. Load: A[(0, 8)] - depends on: 4 - dependents: 6 + * 6. Store: A[(1, 9)] - depends on: 5 - dependents: 7 + * 7. Load: A[(0, 9)] - depends on: 4 6 8 - dependents: 8 + * 8. Store: A[(0, 9)] - depends on: 7 - dependents: 7 9 + * 9. Load: A[(0, 9)] - depends on: 8 - dependents: 10 + * 10. Store: B[(0, 9)] - depends on: 9 - dependents: 11 + * 11. Output: B[(0, 9)] - depends on: 10 + */ + + // Now let's look at the bounds of each access. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 12); + const Var* aVar = a.node()->base_handle(); + const Var* bVar = b.node()->base_handle(); + + // The first access is the input A. + ASSERT_EQ(history[0]->type(), AccessType::Input); + ASSERT_EQ(history[0]->var(), aVar); + // It has the bounds of the producing Input. + ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)})); + + // The second access is the load A[x-1]. + ASSERT_EQ(history[1]->type(), AccessType::Load); + ASSERT_EQ(history[1]->var(), aVar); + // It has the bounds of the loop modified by the offset of each index, in + // this case -1. + ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 8)})); + // It depends on the input, but also the store in the same loop, since + // different interations of the loop depend on each other. + ASSERT_EQ(history[1]->dependencies().size(), 2); + ASSERT_TRUE(history[1]->hasDependency(history[0])); + ASSERT_TRUE(history[1]->hasDependency(history[2])); + + // The third access is the Store to A[x] in the first loop. + ASSERT_EQ(history[2]->type(), AccessType::Store); + ASSERT_EQ(history[2]->var(), aVar); + // It has no offset on x, so should have the same bounds as the loop. + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)})); + + // The fourth access is the load A[x+1] in the second loop. + ASSERT_EQ(history[3]->type(), AccessType::Load); + ASSERT_EQ(history[3]->var(), aVar); + // It has the bounds of the loop (0 <= x < 9) modified by the offset of each + // index, in this case 1. + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 9)})); + // This load totally overlaps the previous write to A, so it depends only on + // it and not the input. + ASSERT_EQ(history[3]->dependencies().size(), 1); + ASSERT_TRUE(history[3]->hasDependency(history[2])); + + // The fifth access is the store to A[x] in the second loop. + ASSERT_EQ(history[4]->type(), AccessType::Store); + ASSERT_EQ(history[4]->var(), aVar); + // It has no offset on x, so should have the same bounds as the loop. + ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, 8)})); + + // The sixth access is the load to A[8 - x] in the third loop. + ASSERT_EQ(history[5]->type(), AccessType::Load); + ASSERT_EQ(history[5]->var(), aVar); + // It has the bounds of the loop (0 <= x < 9) modified by the offset of each + // index, in this case 8 - x. + // This access has a negative stride, which will be normalized. + ASSERT_TRUE(EQ(history[5]->bounds(), {CB(0, 8)})); + // This load totally overlaps the most recent write to A, so it depends only + // on it and not the input or the first write to A. + ASSERT_EQ(history[5]->dependencies().size(), 1); + ASSERT_TRUE(history[5]->hasDependency(history[4])); + + // The seventh access is the store to A[9 - x] in the third loop. + ASSERT_EQ(history[6]->type(), AccessType::Store); + ASSERT_EQ(history[6]->var(), aVar); + // This store has a negative stride on it's indices, but is notmalized + // internally. + ASSERT_TRUE(EQ(history[6]->bounds(), {CB(1, 9)})); + + // The eighth access is the load A[9-x] in the second loop. + ASSERT_EQ(history[7]->type(), AccessType::Load); + ASSERT_EQ(history[7]->var(), aVar); + // It has the bounds of the loop (0 <= x < 9), modified by the offset 9 - x, + // which esstentially traverses the loop backwards. + ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)})); + // This Load has three write dependencies: + ASSERT_EQ(history[7]->dependencies().size(), 3); + // * The previous store (#6) for elements 1-9 + ASSERT_TRUE(history[7]->hasDependency(history[6])); + // * An earlier store (#4) covering element 0 + ASSERT_TRUE(history[7]->hasDependency(history[4])); + // * A future store inside this loop, since this loop modifies the buffer + // in a non distinct way (due to the load and store having different access + // strides). + ASSERT_TRUE(history[7]->hasDependency(history[8])); + + // The ninth access is the store to A[x] in the fourth loop. + ASSERT_EQ(history[8]->type(), AccessType::Store); + ASSERT_EQ(history[8]->var(), aVar); + // This store has a negative stride on it's indices, but is notmalized + // internally. + ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)})); + + // The tenth and 11th acceses are the copy from A[x] to B[x]. + ASSERT_EQ(history[9]->type(), AccessType::Load); + ASSERT_EQ(history[9]->var(), aVar); + ASSERT_EQ(history[10]->type(), AccessType::Store); + ASSERT_EQ(history[10]->var(), bVar); + + // The last access represents the output Buf. + ASSERT_EQ(history[11]->type(), AccessType::Output); + ASSERT_EQ(history[11]->var(), bVar); + // It has the bounds of the output Buf. + ASSERT_TRUE(EQ(history[11]->bounds(), {CB(0, 9)})); + // It depends on the last write to B only. + ASSERT_EQ(history[11]->dependencies().size(), 1); + ASSERT_TRUE(history[11]->hasDependency(history[10])); + + // ok that's enough of that. +} + +// Check many different cases of loop self dependency - when a load within a +// loop is dependent on a Store later in the same loop but in different +// iteration. This is affected by whether or not we can trust the execution +// order of the loop. +TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + + using namespace analysis; + + // This check assumes that the Stmt has a single Store with a single Load on + // the RHS. + auto isSelfDependent = + [](const std::vector>& history) -> bool { + return history.front()->hasDependency(history.back()); + }; + + { + /* for (int y = 0; y < 10; y++) { + * A[y] = (A[y]) + 1; + * } */ + + // Not self dependent since all loop iterations use a different y. + + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + y, + 0, + 10, + Block::make( + {Store::make(a, {y}, Add::make(Load::make(a, {y}, 1), 1), 1)})); + + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int y = 0; y < 10; y++) { + * A[y + 1] = (A[y + 1]) + 1; + * } + */ + + // Not self dependent due to different y (with offset). + + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + y, + 0, + 10, + Block::make({Store::make( + a, {y + 1}, Add::make(Load::make(a, {y + 1}, 1), 1), 1)})); + + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + */ + + // Is self dependent since all loops use a common constant element of A. + + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)})); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[0] = (B[0]) + x; + * } + */ + + // Is not self dependent beacause there is no store to the buffer that is + // read. + + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(b, {0}, 1), x), 1)})); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[y] = (A[y]) + x; + * } + */ + + // Is self dependent since all loops use a common symbolic element of A. + + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {y}, Add::make(Load::make(a, {y}, 1), x), 1)})); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x] = A[x + 1]; + * } + */ + + // In this case it depends if we are considering execution order. + + MemDependencyChecker analyzer; + + Stmt* stmt = + For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1}, 1), 1)); + stmt->accept(&analyzer); + + // With analysis of order disabled, this is self dependent since the read + // from X+1 and the write to X+1 could be in reverse order. + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x] = A[x + 1]; + * } + */ + + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + + Stmt* stmt = + For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1}, 1), 1)); + stmt->accept(&analyzer); + + // If order analysis is enabled, this is not dependent since the read for + // each element occurs before the write to that element. + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + */ + + MemDependencyChecker analyzer; + + Stmt* stmt = + For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + */ + + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + + Stmt* stmt = + For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}, 1), 1)); + stmt->accept(&analyzer); + + // In this case, even with order analysis the Load is dependent on the + // Store, since the write to X occurs before the read from X. + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 9; x++) { + * A[9 - x] = A[8 - x]; + * } + */ + + // Still works if the execution order is reversed, so long as the read + // comes before the write. + + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + + Stmt* stmt = For::make( + x, + 3, + 10, + Store::make( + a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}, 1), 1)); + stmt->accept(&analyzer); + + // However here was can determine the A store is earlier in the order than + // the load. + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 9; x++) { + * A[8 - x] = A[9 - x]; + * } + */ + + // But not if it doesn't. + + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + + Stmt* stmt = For::make( + x, + 3, + 10, + Store::make( + a, {ExprHandle(8) - x}, Load::make(a, {ExprHandle(9) - x}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 9; x++) { + * A[9 - x] = A[8 - x]; + * } + */ + + // And not if we're not relying on execution order. + + MemDependencyChecker analyzer; + + Stmt* stmt = For::make( + x, + 3, + 10, + Store::make( + a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 3; x < 10; x++) { + * A[x - 2] = A[x - 1]; + * } + */ + + // Forward order but negative indices. + + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + + Stmt* stmt = For::make( + x, 3, 10, Store::make(a, {x - 2}, Load::make(a, {x - 1}, 1), 1)); + stmt->accept(&analyzer); + + // However here was can determine the A store is earlier in the order than + // the load. + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2]; + * } + */ + + // With an access stride. + + MemDependencyChecker analyzer; + // Execution order doesn't matter since the read and the write are totally + // distinct. + + Stmt* stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 + 1]; + * } + */ + + // Here we can use the common stride of the accesses to determine they are + // distinct. + // Note, this is the only place (loop self depedency) we use this stride + // to avoid unnecessary depedence. + + MemDependencyChecker analyzer; + // Execution order doesn't matter since the read and the write are totally + // distinct. + + Stmt* stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 1}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 - 1]; + * } + */ + + // same if the read is behind the write so long as they are distinct. + + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 1}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 + 2]; + * } + */ + + // But not if the offset is in the stride. + + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 2}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 - 2]; + * } + */ + + // Works with negative offsets too. + + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 2}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 + 7]; + * } + */ + + // Detects accesses are distinct when offset is large but not a multiple + // of stride. + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 7}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 + 4]; + * } + */ + + // Works with offsets which are multiples of the stride. + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 4}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 6] = A[x * 6 + 5]; + * } + */ + + // detects accesses are distinct with large strides when the offset is + // within. + + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, 0, 10, Store::make(a, {x * 6}, Load::make(a, {x * 6 + 5}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 6]; + * } + */ + + // detects accesses are overlapping when stride is different but a + // multiple. + + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 4] = A[x * 2]; + * } + */ + + // still works when the read axis is the smaller stride. + + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, 0, 10, Store::make(a, {x * 4}, Load::make(a, {x * 2}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 6 + 1]; + * } + */ + + // detects accesses are distinct when stride is different but a multiple + // and there is an offset. + + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 1}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 6 + 4]; + * } + */ + + // The smaller stride determines whether there is overlap. + + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 4}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2 + 3] = A[x * 6]; + * } + */ + + // The smaller stride determines whether there is overlap, not the larger. + + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, 0, 10, Store::make(a, {x * 2 + 3}, Load::make(a, {x * 6}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 3 + 1]; + * } + */ + + // If they have strides with no common muliple > 1, they overlap. + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 3 + 1}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x] = A[x + 10]; + * } + */ + + // If the offset is greater than the size of the loop, they can't overlap. + + MemDependencyChecker analyzer; + Stmt* stmt = + For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 10}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x] = A[9 - x]; + * } + */ + + // If they have different execution orders they may overlap. + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, + 0, + 10, + Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[19 - x * 2]; + * } + */ + + // Or they may not, depending on their start offset and strides. + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, + 0, + 10, + Store::make(a, {x * 2}, Load::make(a, {ExprHandle(19) - x * 2}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x / 2] = A[x / 2]; + * } + */ + + // If the stride is not monotonic, they overlap. + + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x / 2] = A[x / 2] + 1; + * } + */ + + // If the stride is not monotonic, they overlap - even with an offset. + MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2 + 1}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x % 2] = A[x % 2]; + * } + */ + + // Mod too... + + analysis::MemDependencyChecker analyzer; + Stmt* stmt = For::make( + x, + 0, + 10, + Store::make( + a, {Mod::make(x, 2)}, Load::make(a, {Mod::make(x, 2)}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = y; x < z; x++) { + * A[x] = A[x + 1]; + * } + */ + + // Still works with symbolic loop extents. + + { + MemDependencyChecker analyzer; + Stmt* stmt = + For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + Stmt* stmt = + For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1}, 1), 1)); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + } +} + +// Verify that a strided access still works. +// TODO: actually this only works because of the size of the ranges, revist this +// test after strided overlap is implemented. +TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { + KernelScope kernel_scope; + BufHandle a("A", {20}, kInt); + BufHandle b("B", {20}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + using namespace analysis; + MemDependencyChecker analyzer({a.node()}, {b.node()}); + Stmt* stmt = Block::make( + {For::make( + x, + 0, + 10, + Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}, 1), 1)), + For::make( + x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}, 1), 1)) + + }); + stmt->accept(&analyzer); + + // Sanity check output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // Output has 2 dependencies... the store in each loop. + auto outputAccess = analyzer.output(b.node()); + ASSERT_EQ(outputAccess->dependencies().size(), 2); +} + +/* TODO(nickg) - this test will fail due to the lack of stride math in Bound +TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { + KernelScope kernel_scope; + BufHandle a("A", {20}, kInt); + BufHandle b("B", {20}, kInt); + BufHandle c("C", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + analysis::MemDependencyChecker analyzer({a.node()}, {c.node()}); + Stmt* stmt = Block::make( + {For::make( + x, + 0, + 10, + Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}, 1), 1)), + For::make( + x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}, 1), 1)), + For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}, 1), 1)) + + }); + stmt->accept(&analyzer); + + std::cout << *stmt << "\n"; + for (auto& wi : analyzer.getHistory()) { + wi->print(); + } + } +}*/ + +// analysis on Stmts using Cond. +TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + BufHandle c("C", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + using namespace analysis; + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * if (y<5 ? 1 : 0) { + * C[0] = (B[0]) + 1; + * } else { + * C[0] = (B[1]) + 1; + * } + */ + + // Future usages may depend on accesses in both branches of a condition. + + MemDependencyChecker analyzer({a, b}, {c}); + Stmt* stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}, 1), 1)), + Cond::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + Store::make(c, {0}, Add::make(Load::make(b, {0}, 1), 1), 1), + Store::make(c, {0}, Add::make(Load::make(b, {1}, 1), 1), 1))}); + + stmt->accept(&analyzer); + + // Output C should have 3 dependencies, each of the three stores. + auto outputAccess = analyzer.output(c.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 3); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * if (y<5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * C[x] = B[x]; + * } + * } else { + * for (int x = 0; x < 10; x++) { + * C[x] = (B[x]) + 1; + * } + * } + */ + + // Future usages may depend on accesses in both branches of a condition. + + MemDependencyChecker analyzer({a, b}, {c}); + Stmt* stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}, 1), 1)), + Cond::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}, 1), 1)), + For::make( + x, + 0, + 10, + Store::make( + c, {x}, Add::make(Load::make(b, {x}, 1), 1), 1)))}); + + stmt->accept(&analyzer); + + // Output C should have 3 dependencies, each of the three stores. + auto outputAccess = analyzer.output(c.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 3); + + // TODO(nickg): actually since the true and false branch cover the total + // range of the first store this should have 2 dependencies, but we don't + // do that yet. + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * if (y<5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * C[x] = (B[x]) + 1; + * } + * } + */ + + // Only has true branch. + + MemDependencyChecker analyzer({a, b}, {c}); + Stmt* stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}, 1), 1)), + Cond::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + For::make( + x, + 0, + 10, + Store::make(c, {x}, Add::make(Load::make(b, {x}, 1), 1), 1)), + nullptr)}); + + stmt->accept(&analyzer); + + // Output C should have 3 dependencies, each of the three stores. + auto outputAccess = analyzer.output(c.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 2); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * if (y<5 ? 1 : 0) { + * } else { + * for (int x = 0; x < 10; x++) { + * C[x] = (B[x]) + 1; + * } + * } + */ + + // Only has false branch. + + MemDependencyChecker analyzer({a, b}, {c}); + Stmt* stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}, 1), 1)), + Cond::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + nullptr, + For::make( + x, + 0, + 10, + Store::make( + c, {x}, Add::make(Load::make(b, {x}, 1), 1), 1)))}); + + stmt->accept(&analyzer); + + // Output C should have 3 dependencies, each of the three stores. + auto outputAccess = analyzer.output(c.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 2); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * if (C[0]<5 ? 1 : 0) { + * C[0] = 5; + * } + */ + + // Cond's Condition depends on a previous access. + + MemDependencyChecker analyzer({a}, {c}); + Store* initStore = Store::make(c, {x}, Load::make(a, {x}, 1), 1); + ExprHandle conditionalLoad = Load::make(c, {0}, 1); + Stmt* stmt = Block::make( + {For::make(x, 0, 10, initStore), + Cond::make( + CompareSelect::make( + conditionalLoad, 5, CompareSelectOperation::kLT), + Store::make(c, {0}, 5, 1), + nullptr)}); + + stmt->accept(&analyzer); + + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + + ASSERT_TRUE(analyzer.dependsDirectly(conditionalLoad.node(), initStore)); + ASSERT_FALSE(analyzer.dependsDirectly(conditionalLoad.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(conditionalLoad.node(), a.node())); + } +} + +// Stmts using IfThenElse. +TEST(MemDependency, MemDependencyCheckerIfThenElse) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + BufHandle c("C", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + using namespace analysis; + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * C[0] = (y < 5 ? (B[0]) + 1 : (B[1]) + 1; + */ + + // Future usages may depend on accesses in both branches of a condition. + + MemDependencyChecker analyzer({a, b}, {c}); + Store* ifStore = Store::make( + c, + {0}, + IfThenElse::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + Add::make(Load::make(b, {0}, 1), 1), + Add::make(Load::make(b, {1}, 1), 1)), + 1); + Stmt* stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}, 1), 1)), + ifStore}); + + stmt->accept(&analyzer); + + // Output C should have 2 dependencies, each of the two stores. + auto outputAccess = analyzer.output(c.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 2); + + // Now we need to check the Store containing the IfThenElse. + auto ifStoreAccess = analyzer.accessFor(ifStore); + + // It should have 2 dependencies. + ASSERT_EQ(ifStoreAccess->dependencies().size(), 2); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * C[0] = (y < 5 ? (B[0]) + 1 : 42; + */ + + // If the load appears in only one side of an IfThenElse the output may be + // dependent on it. + + MemDependencyChecker analyzer({a, b}, {c}); + Store* ifStore = Store::make( + c, + {0}, + IfThenElse::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + Add::make(Load::make(b, {0}, 1), 1), + 42), + 1); + Stmt* stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}, 1), 1)), + ifStore}); + + stmt->accept(&analyzer); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = (x < 5 ? B[x] : A[x]; + * } + */ + + // In this case C is dependent on both A and B. + + // TODO: in cases like this it would be possible to split the range of B + // into two bounds, one dependent on A and one depenent on B. We'd need to + // examine conditions relative to previously encountered loop variables. I'm + // uncertain if this would be helpful. + + MemDependencyChecker analyzer({a, b}, {c}); + Store* ifStore = Store::make( + c, + {0}, + IfThenElse::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + Load::make(b, {x}, 1), + Load::make(a, {x}, 1)), + 1); + Stmt* stmt = Block::make({For::make(x, 0, 10, ifStore)}); + + stmt->accept(&analyzer); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } +} + +// Cutting a loop with single elem writes +TEST(MemDependency, MemDependencyCheckerCutLoop) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + { + /* for (int x = 0; x < 10; x++) { + * B[x] = A[x]; + * } + * B[5] = 100; + */ + + // Cutting a loop with single element writes. + + MemDependencyChecker analyzer({a}, {b}); + Stmt* stmt = Block::make( + {For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}, 1), 1)), + Store::make(b, {5}, 100, 1)}); + + stmt->accept(&analyzer); + + // Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // Output has 2 depdenencies. + auto outputAccess = analyzer.output(b.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 2); + } + + { + /* for (int x = 0; x < 10; x++) { + * B[x] = A[x]; + * } + * for (int x = 4; x < 7; x++) { + * B[x] = B[x] + 3; + * } + * B[5] = 100; + * B[6] = 101; + * B[7] = 102; + */ + + // Cutting a loop with a smaller loop but then totally overlap that second + // loop with one element writes. + + MemDependencyChecker analyzer({a}, {b}); + For* firstLoop = + For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}, 1), 1)); + Store* secondStore = + Store::make(b, {x}, Add::make(Load::make(b, {x}, 1), 1), 3); + For* secondLoop = For::make(x, 4, 7, secondStore); + + Stmt* stmt = Block::make( + {firstLoop, + secondLoop, + Store::make(b, {4}, 100, 1), + Store::make(b, {5}, 101, 1), + Store::make(b, {6}, 102, 1)}); + + stmt->accept(&analyzer); + + // Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // Output has 4 depdenencies. + auto outputAccess = analyzer.output(b.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 4); + + // Second loop depends on first loop. + ASSERT_TRUE(analyzer.dependsDirectly(secondLoop, firstLoop)); + + // Output does not depend on second loop or store. + ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondLoop)); + ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondStore)); + } +} + +// Dynamic shapes (load in indices). +TEST(MemDependency, MemDependencyCheckerDynamicShapes) { + KernelScope kernel_scope; + BufHandle a("A", {100}, kInt); + BufHandle b("B", {100}, kInt); + BufHandle c("C", {100}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + auto CB = [](ExprHandle s, ExprHandle e) { + return Bound(s.node(), e.node()); + }; + + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + { + /* for (int x = 0; x < B[0]; x++) { + * C[x] = A[x]; + * } + */ + MemDependencyChecker analyzer({a, b}, {c}); + Stmt* stmt = Block::make({For::make( + x, + 0, + Load::make(b, {0}, 1), + Store::make(c, {x}, Load::make(a, {x}, 1), 1))}); + + stmt->accept(&analyzer); + + /* 0. Input: B[(0, 99)] - dependents: 2 + * 1. Input: A[(0, 99)] - dependents: 3 + * 2. Load: B[(0, 0)] - depends on: 0 - dependents: 3 4 + * 3. Load: A[(0, (B[0]) - 1)] - depends on: 1 2 - dependents: 4 + * 4. Store: C[(0, (B[0]) - 1)] - depends on: 2 3 - dependents: 5 + * 5. Output: C[(0, 99)] - depends on: 4 + */ + + // Output dependent on A input. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + // Also dependent on B input to determine the size of the region written. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 6); + + // The accesses in the loop depend on the load in the stop condition. + ASSERT_TRUE(history[4]->hasDependency(history[2])); + ASSERT_TRUE(history[3]->hasDependency(history[2])); + + // Make a load from B to compare against. + ExprHandle loadFromB = Load::make(b, {0}, 1); + + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, loadFromB - 1)})); + ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, loadFromB - 1)})); + } + + { + /* for (int x = B[0]; x < B[1]; x++) { + * C[x] = A[x]; + * } + */ + MemDependencyChecker analyzer({a, b}, {c}); + Stmt* stmt = Block::make({For::make( + x, + Load::make(b, {0}, 1), + Load::make(b, {1}, 1), + Store::make(c, {x}, Load::make(a, {x}, 1), 1))}); + + stmt->accept(&analyzer); + + /* 0. Input: B[(0, 99)] - dependents: 2 3 + * 1. Input: A[(0, 99)] - dependents: 4 + * 2. Load: B[(0, 0)] - depends on: 0 - dependents: 4 5 + * 3. Load: B[(1, 1)] - depends on: 0 - dependents: 4 5 + * 4. Load: A[(B[0], (B[1]) - 1)] - depends on: 1 2 3 - dependents: 5 + * 5. Store: C[(B[0], (B[1]) - 1)] - depends on: 2 3 4 - dependents: 6 + * 6. Output: C[(0, 99)] - depends on: 5 + */ + + // Sanity check output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 7); + + // The accesses in the loop depend on the load in the start condition. + ASSERT_TRUE(history[5]->hasDependency(history[2])); + ASSERT_TRUE(history[4]->hasDependency(history[2])); + + // also the stop condition. + ASSERT_TRUE(history[5]->hasDependency(history[3])); + ASSERT_TRUE(history[4]->hasDependency(history[3])); + + // Make loads from B to compare against. + ExprHandle loadFromB0 = Load::make(b, {0}, 1); + ExprHandle loadFromB1 = Load::make(b, {1}, 1); + ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromB0, loadFromB1 - 1)})); + ASSERT_TRUE(EQ(history[5]->bounds(), {CB(loadFromB0, loadFromB1 - 1)})); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[B[x]]; + * } + */ + MemDependencyChecker analyzer({a, b}, {c}); + Stmt* stmt = Block::make({For::make( + x, + 0, + 10, + Store::make(c, {x}, Load::make(a, {Load::make(b, {x}, 1)}, 1), 1))}); + + stmt->accept(&analyzer); + + /* 0. Input: B[(0, 99)] - dependents: 2 + * 1. Input: A[(0, 99)] - dependents: 3 + * 2. Load: B[(0, 9)] - depends on: 0 - dependents: 3 4 + * 3. Load: A[(B[0], B[9])] - depends on: 1 2 - dependents: 4 + * 4. Store: C[(0, 9)] - depends on: 2 3 - dependents: 5 + * 5. Output: C[(0, 99)] - depends on: 4 + */ + + // Sanity check output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 6); + + // The store depends on both loads, the load of A depends on the load of B. + ASSERT_TRUE(history[4]->hasDependency(history[2])); + ASSERT_TRUE(history[4]->hasDependency(history[3])); + + ASSERT_TRUE(history[3]->hasDependency(history[2])); + + // The loads in the indices depend on the relevant input buffer. + ASSERT_TRUE(history[3]->hasDependency(history[1])); + ASSERT_TRUE(history[2]->hasDependency(history[0])); + + // The load from B has the loop bounds. + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); + + // The load from A has bounds B[0] to B[9]. + ExprHandle loadFromB0 = Load::make(b, {0}, 1); + ExprHandle loadFromB9 = Load::make(b, {9}, 1); + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromB0, loadFromB9)})); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[B[x]] = A[x]; + * } + */ + MemDependencyChecker analyzer({a, b}, {c}); + Stmt* stmt = Block::make({For::make( + x, + 0, + 10, + Store::make(c, {Load::make(b, {x}, 1)}, Load::make(a, {x}, 1), 1))}); + + stmt->accept(&analyzer); + + /* 0. Input: B[(0, 99)] - dependents: 3 + * 1. Input: A[(0, 99)] - dependents: 2 + * 2. Load: A[(0, 9)] - depends on: 1 - dependents: 4 + * 3. Load: B[(0, 9)] - depends on: 0 - dependents: 4 + * 4. Store: C[(B[0], B[9])] - depends on: 2 3 - dependents: 5 + * 5. Output: C[(0, 99)] - depends on: 4 + */ + // Sanity check output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 6); + + // The store depends on both loads, neither load is dependent. + ASSERT_TRUE(history[4]->hasDependency(history[2])); + ASSERT_TRUE(history[4]->hasDependency(history[3])); + + ASSERT_FALSE(history[3]->hasDependency(history[2])); + ASSERT_FALSE(history[2]->hasDependency(history[3])); + + // The loads each depend on their relevant input. (but accesses are in a + // different order than the last case). + ASSERT_TRUE(history[3]->hasDependency(history[0])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + + // The load from B has the loop bounds. + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, 9)})); + + // And so does the load from A. + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[B[A[x]]] = x; + * } + */ + MemDependencyChecker analyzer({a, b}, {c}); + Stmt* stmt = Block::make({For::make( + x, + 0, + 10, + Store::make(c, {Load::make(b, {Load::make(a, {x}, 1)}, 1)}, x, 1))}); + + stmt->accept(&analyzer); + + /* 0. Input: B[(0, 99)] - dependents: 3 + * 1. Input: A[(0, 99)] - dependents: 2 + * 2. Load: A[(0, 9)] - depends on: 1 - dependents: 3 4 + * 3. Load: B[(A[0], A[9])] - depends on: 0 2 - dependents: 4 + * 4. Store: C[(B[A[0]], B[A[9]])] - depends on: 2 3 - dependents: 5 + * 5. Output: C[(0, 99)] - depends on: 4 + */ + + // Sanity check output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 6); + + // The store depends on both loads. + ASSERT_TRUE(history[4]->hasDependency(history[2])); + ASSERT_TRUE(history[4]->hasDependency(history[3])); + + // The outer load depends on the inner. + ASSERT_TRUE(history[3]->hasDependency(history[2])); + + // The loads each depend on their relevant input. (but accesses are in a + // different order than the last case). + ASSERT_TRUE(history[3]->hasDependency(history[0])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + + // The load from A has the loop bounds. + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); + // The load from B as bounds A[0] to A[9]. + ExprHandle loadFromA0 = Load::make(a, {0}, 1); + ExprHandle loadFromA9 = Load::make(a, {9}, 1); + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromA0, loadFromA9)})); + + // The store has bounds of B[A[0]] to B[A[9]]. + ExprHandle loadFromBA0 = Load::make(b, {loadFromA0}, 1); + ExprHandle loadFromBA9 = Load::make(b, {loadFromA9}, 1); + ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromBA0, loadFromBA9)})); + } +} + +// Verify multi dimensional bounds work. +TEST(MemDependency, MemDependencyCheckerMultiDim) { + KernelScope kernel_scope; + int M = 10, N = 9, K = 12; + BufHandle a("A", {M, N, K}, kInt); + BufHandle b("B", {M, N, K}, kInt); + BufHandle c("C", {M, K}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + + using namespace analysis; + + auto CB = [](ExprHandle s, ExprHandle e) { + return Bound(s.node(), e.node()); + }; + + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + { + /* for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 9; y++) { + * for (int z = 0; z < 12; z++) { + * B[x, y, z] = A[x, y, z]; + * } + * } + * } + */ + // Full range. + + MemDependencyChecker analyzer({a}, {b}); + Stmt* stmt = Block::make({For::make( + x, + 0, + M, + For::make( + y, + 0, + N, + For::make( + z, + 0, + K, + Store::make(b, {x, y, z}, Load::make(a, {x, y, z}, 1), 1))))}); + + stmt->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // 4 accesses: input, load, store, output. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 4); + + // Simple chain from input to output. + ASSERT_TRUE(history[3]->hasDependency(history[2])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + ASSERT_TRUE(history[1]->hasDependency(history[0])); + + ASSERT_TRUE( + EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); + ASSERT_TRUE( + EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); + } + + { + /* for (int x = 0; x < 5; x++) { + * for (int y = 0; y < 5; y++) { + * for (int z = 0; z < 5; z++) { + * B[x, y, z] = A[x, y, z]; + * } + * } + * } + */ + // Partial range. + + MemDependencyChecker analyzer({a}, {b}); + Stmt* stmt = Block::make({For::make( + x, + 0, + 5, + For::make( + y, + 0, + 5, + For::make( + z, + 0, + 5, + Store::make(b, {x, y, z}, Load::make(a, {x, y, z}, 1), 1))))}); + + stmt->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // 4 accesses: input, load, store, output. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 4); + + // Simple chain from input to output. + ASSERT_TRUE(history[3]->hasDependency(history[2])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + ASSERT_TRUE(history[1]->hasDependency(history[0])); + + ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)})); + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)})); + } + + { + /* for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 12; y++) { + * B[x, 0, y] = A[x, 0, y]; + * } + * } + */ + + // Partial loops. + + MemDependencyChecker analyzer({a}, {b}); + Stmt* stmt = Block::make({For::make( + x, + 0, + N, + For::make( + y, + 0, + K, + Store::make(b, {x, 0, y}, Load::make(a, {x, 0, y}, 1), 1)))}); + + stmt->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // 4 accesses: input, load, store, output. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 4); + + // Simple chain from input to output. + ASSERT_TRUE(history[3]->hasDependency(history[2])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + ASSERT_TRUE(history[1]->hasDependency(history[0])); + + ASSERT_TRUE( + EQ(history[1]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)})); + ASSERT_TRUE( + EQ(history[2]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)})); + } + + { + /* for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 100; y++) { + * for (int z = 0; z < 12; z++) { + * B[x, 0, z] = (A[x, 0, z]) + (C[x, z]); + * } + * } + * } + */ + + // Loops that don't correspond to an index, bufs with different + // dimensionality. + + MemDependencyChecker analyzer({a, c}, {b}); + Stmt* stmt = Block::make({For::make( + x, + 0, + M, + For::make( + y, + 0, + 100, + For::make( + z, + 0, + K, + Store::make( + b, + {x, 0, z}, + Add::make( + Load::make(a, {x, 0, z}, 1), Load::make(c, {x, z}, 1)), + 1))))}); + + stmt->accept(&analyzer); + + // Sanity test: Output depends on both inputs. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), c.node())); + + // 6 accesses: 2 inputs, 2 loads, store, output. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 6); + + // Simple chain from input to output over the A buf. + // history[0] is the C input, history[3] is the load from C. + ASSERT_TRUE(history[5]->hasDependency(history[4])); + ASSERT_TRUE(history[4]->hasDependency(history[2])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + // The store also depends on the load from the C input. + ASSERT_TRUE(history[4]->hasDependency(history[3])); + ASSERT_TRUE(history[3]->hasDependency(history[0])); + + // A Buf accesses. + ASSERT_TRUE( + EQ(history[4]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)})); + ASSERT_TRUE( + EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)})); + + // C buf access. + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, K - 1)})); + } + + { + /* for (int x = 0; x < 9; x++) { + * for (int y = 0; y < 10; y++) { + * for (int z = 0; z < 12; z++) { + * B[x, 0, 0] = (B[x, y, z]) + (A[x, y, z]); + * } + * } + * } + */ + // Multi-dim reductions. + + MemDependencyChecker analyzer({a}, {b}); + Stmt* stmt = Block::make({For::make( + x, + 0, + M, + For::make( + y, + 0, + N, + For::make( + z, + 0, + K, + Store::make( + b, + {x, 0, 0}, + Add::make( + Load::make(b, {x, y, z}, 1), + Load::make(a, {x, y, z}, 1)), + 1))))}); + + stmt->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // 4 accesses: input, 2 loads, store, output. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 5); + + // Simple chain from input to output. + ASSERT_TRUE(history[4]->hasDependency(history[3])); + ASSERT_TRUE(history[3]->hasDependency(history[2])); + ASSERT_TRUE(history[3]->hasDependency(history[1])); + ASSERT_TRUE(history[2]->hasDependency(history[0])); + + // The load from B depends on the store to B. + ASSERT_TRUE(history[1]->hasDependency(history[3])); + + ASSERT_TRUE( + EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); + ASSERT_TRUE( + EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, 0)})); + } +} + +// Various tests using the external Compute/Reduce API. +TEST(MemDependency, MemDependencyCheckerComputeAPI) { + KernelScope kernel_scope; + + using namespace analysis; + + /* for (int m = 0; m < 4; m++) { + * for (int n = 0; n < 5; n++) { + * for (int k = 0; k < 6; k++) { + * broadcast_add[m, n, k] = (a[m, n]) + (b[n, k]); + * } + * } + * } + * for (int m_1 = 0; m_1 < 4; m_1++) { + * for (int n_1 = 0; n_1 < 5; n_1++) { + * for (int k_1 = 0; k_1 < 6; k_1++) { + * d[m_1, n_1, k_1] = (broadcast_add(m_1, n_1, k_1)) + float(1); + * } + * } + * } + */ + + // Can determine if 2 loops created by Compute are dependent. + Placeholder a_buf("a", kFloat, {4, 5}); + Placeholder b_buf("b", kFloat, {5, 6}); + Tensor* c = Compute( + "broadcast_add", + {{4, "m"}, {5, "n"}, {6, "k"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) + b_buf.load(n, k); + }); + Tensor* d = Compute( + "d", + {{4, "m"}, {5, "n"}, {6, "k"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return c->call(m, n, k) + 1; + }); + + LoopNest l({d}); + + MemDependencyChecker analyzer({a_buf.data(), b_buf.data()}, {d->buf()}); + + l.root_stmt()->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(d->buf(), a_buf.data())); + ASSERT_TRUE(analyzer.dependsIndirectly(d->buf(), b_buf.data())); + + // Second loop depends on first loop. + auto* c_loop = l.getLoopStmtsFor(c)[0]; + auto* d_loop = l.getLoopStmtsFor(d)[0]; + ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop)); +} + +TEST(MemDependency, MemDependencyCheckerComputeInline) { + KernelScope kernel_scope; + + using namespace analysis; + + /* for (int m = 0; m < 4; m++) { + * for (int n = 0; n < 5; n++) { + * for (int k = 0; k < 6; k++) { + * d[m, n, k] = ((a[m, n]) + (b[n, k])) + float(1); + * } + * } + * } + */ + + // Check inlining affects the number of accesses returned. + + Placeholder a_buf("a", kFloat, {4, 5}); + Placeholder b_buf("b", kFloat, {5, 6}); + Tensor* c = Compute( + "broadcast_add", + {{4, "m"}, {5, "n"}, {6, "k"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) + b_buf.load(n, k); + }); + Tensor* d = Compute( + "d", + {{4, "m"}, {5, "n"}, {6, "k"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return c->call(m, n, k) + 1; + }); + + LoopNest l({d}); + l.computeInline(c->buf()); + + MemDependencyChecker analyzer({a_buf.data(), b_buf.data()}, {d->buf()}); + l.root_stmt()->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(d->buf(), a_buf.data())); + ASSERT_TRUE(analyzer.dependsIndirectly(d->buf(), b_buf.data())); + + // broadcast_add tensor should not appear in trace at all. + for (auto& wi : analyzer.getHistory()) { + ASSERT_NE(wi->var(), c->buf()->base_handle()); + } +} + +TEST(MemDependency, MemDependencyCheckerComputeSplit) { + KernelScope kernel_scope; + + using namespace analysis; + // Split an axis, so the number of loops != the number of dimensions. + + Placeholder a_buf("a", kFloat, {4, 5}); + Placeholder b_buf("b", kFloat, {5, 6}); + Tensor* c = Compute( + "broadcast_add", + {{4, "m"}, {5, "n"}, {6, "k"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) + b_buf.load(n, k); + }); + + LoopNest l({c}); + + MemDependencyChecker analyzer_before( + {a_buf.data(), b_buf.data()}, {c->buf()}); + l.root_stmt()->accept(&analyzer_before); + + For *o, *i, *t; + l.splitWithTail(l.getLoopStmtsFor(c)[0], 2, &o, &i, &t); + + MemDependencyChecker analyzer_after({a_buf.data(), b_buf.data()}, {c->buf()}); + Stmt* stmt = IRSimplifier::simplify(l.root_stmt()); + stmt->accept(&analyzer_after); + + // Splitting should not change accesses at all. + auto history_before = analyzer_before.getHistory(); + auto history_after = analyzer_after.getHistory(); + + ASSERT_EQ(history_before.size(), history_after.size()); + + for (size_t i = 0; i < history_before.size(); ++i) { + ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); + ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); + ASSERT_EQ( + history_before[i]->bounds().size(), history_after[i]->bounds().size()); + ASSERT_TRUE(indexBoundsEquals( + history_before[i]->bounds(), history_after[i]->bounds())); + ASSERT_EQ( + history_before[i]->dependencies().size(), + history_after[i]->dependencies().size()); + ASSERT_EQ( + history_before[i]->dependents().size(), + history_after[i]->dependents().size()); + } +} + +TEST(MemDependency, MemDependencyCheckerComputeReorder) { + KernelScope kernel_scope; + + using namespace analysis; + // Reorder an axis, so the loop order doesn't match the indexing order. + + Placeholder a_buf("a", kFloat, {4, 5}); + Placeholder b_buf("b", kFloat, {5, 6}); + Tensor* c = Compute( + "broadcast_add", + {{4, "m"}, {5, "n"}, {6, "k"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) + b_buf.load(n, k); + }); + + LoopNest l({c}); + + MemDependencyChecker analyzer_before( + {a_buf.data(), b_buf.data()}, {c->buf()}); + l.root_stmt()->accept(&analyzer_before); + + auto loops = l.getLoopStmtsFor(c); + l.reorderAxis(loops[0], loops[1]); + + MemDependencyChecker analyzer_after({a_buf.data(), b_buf.data()}, {c->buf()}); + Stmt* stmt = IRSimplifier::simplify(l.root_stmt()); + stmt->accept(&analyzer_after); + + // Reordering should not change accesses at all. + auto history_before = analyzer_before.getHistory(); + auto history_after = analyzer_after.getHistory(); + + ASSERT_EQ(history_before.size(), history_after.size()); + + for (size_t i = 0; i < history_before.size(); ++i) { + ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); + ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); + ASSERT_EQ( + history_before[i]->bounds().size(), history_after[i]->bounds().size()); + ASSERT_TRUE(indexBoundsEquals( + history_before[i]->bounds(), history_after[i]->bounds())); + ASSERT_EQ( + history_before[i]->dependencies().size(), + history_after[i]->dependencies().size()); + ASSERT_EQ( + history_before[i]->dependents().size(), + history_after[i]->dependents().size()); + } +} + +TEST(MemDependency, MemDependencyCheckerComputeReduce) { + KernelScope kernel_scope; + + using namespace analysis; + /* for (int l2 = 0; l2 < 2; l2++) { + * for (int n1 = 0; n1 < 3; n1++) { + * for (int m1 = 0; m1 < 6; m1++) { + * scale[l2, n1, m1] = (b[l2, n1, m1]) * (a[l2, n1, m1]); + * } + * } + * } + * for (int l1 = 0; l1 < 2; l1++) { + * sum[l1] = float(0); + * for (int n1_1 = 0; n1_1 < 3; n1_1++) { + * for (int m1_1 = 0; m1_1 < 6; m1_1++) { + * sum[l1] = ReduceOp(sum, (sum[l1]) + (scale(l1, n1_1, m1_1)), + * out_args={l1}, reduce_args={n1, m1}); + * } + * } + * } + */ + + // Can determine dependencies of a Reduction. + + Placeholder a(BufHandle("a", {2, 3, 6}, kFloat)); + Placeholder b(BufHandle("b", {2, 3, 6}, kFloat)); + + Tensor* c = Compute( + "scale", + {{2, "l2"}, {3, "n1"}, {6, "m1"}}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor* d = Reduce("sum", {{2, "l1"}}, Sum(), c, {{3, "n1"}, {6, "m1"}}); + LoopNest l({d}); + + MemDependencyChecker analyzer({a.data(), b.data()}, {d->buf()}); + + l.root_stmt()->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(d->buf(), a.data())); + ASSERT_TRUE(analyzer.dependsIndirectly(d->buf(), b.data())); + + // Second loop depends on first loop. + auto* c_loop = l.getLoopStmtsFor(c)[0]; + auto* d_loop = l.getLoopStmtsFor(d)[0]; + ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop)); + + // Reduction depends on both inputs. + auto reduces = NodeFinder::find(l.root_stmt()); + ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], a.data())); + ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], b.data())); +} + +TEST(MemDependency, MemDependencyCheckerComputeGEMM) { + KernelScope kernel_scope; + int M = 1024; + int N = 1024; + int K = 2048; + using namespace analysis; + + Placeholder AP(BufHandle("A", {M, K}, kFloat)); + Placeholder BP(BufHandle("B", {K, N}, kFloat)); + Tensor* CT = Reduce( + "gemm", + {{M, "M"}, {N, "N"}}, + Sum(), + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { + return AP.load(m, k) * BP.load(k, n); + }, + {{K, "K"}}); + LoopNest loop({CT}); + + { + auto const& loops = loop.getLoopStmtsFor(CT); + For* m = loops[0]; + For* mo; + For* mi; + loop.splitWithMask(m, 4, &mo, &mi); + } + { + auto const& loops = loop.getLoopStmtsFor(CT); + For* n = loops[2]; + For* no; + For* ni; + loop.splitWithMask(n, 16, &no, &ni); + } + // mo, mi, no, ni, k -> + // mo, no, mi, ni, k + { + auto const& loops = loop.getLoopStmtsFor(CT); + For* mi = loops[1]; + For* no = loops[2]; + loop.reorderAxis(mi, no); + } + // mo, no, mi, ni, k -> + // mo, no, mi, k, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + For* ni = loops[3]; + For* k = loops[4]; + loop.reorderAxis(ni, k); + } + // mo, no, mi, k, ni -> + // mo, no, k, mi, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + For* mi = loops[2]; + For* k = loops[3]; + loop.reorderAxis(mi, k); + } + { + auto const& loops = loop.getLoopStmtsFor(CT); + loop.cacheAccesses(CT->buf(), "C_regs", loops[2]); + } + + MemDependencyChecker analyzer_unlowered( + loop.getInputBufs(), loop.getOutputBufs()); + + MemDependencyChecker analyzer_lowered( + loop.getInputBufs(), loop.getOutputBufs()); + + // Test both unlowered and lowered form. + { + Stmt* stmt = IRSimplifier::simplify(loop.root_stmt()); + stmt->accept(&analyzer_unlowered); + + // Outputs depend on inputs. + ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT->buf(), AP.data())); + ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT->buf(), BP.data())); + + // The last write to gemm should cover the total bound of the output. + std::shared_ptr outputAccess = + analyzer_unlowered.output(CT->buf()); + // A single dependency. + ASSERT_EQ(outputAccess->dependencies().size(), 1); + + // dependencies is a set with 1 element, so can just deref begin(). + std::shared_ptr gemmStore = + outputAccess->dependencies().begin()->second; + // Check its a store. + ASSERT_EQ(gemmStore->type(), AccessType::Store); + + ASSERT_TRUE(indexBoundsEquals(outputAccess->bounds(), gemmStore->bounds())); + + // Likewise the first read from each input cover the entire range of the + // input. + auto aInput = analyzer_unlowered.input(AP.data()); + auto bInput = analyzer_unlowered.input(BP.data()); + + // A single dependent each. + ASSERT_EQ(aInput->dependents().size(), 1); + ASSERT_EQ(bInput->dependents().size(), 1); + + // They're both loads. + std::shared_ptr aLoad = aInput->dependents().begin()->second; + std::shared_ptr bLoad = bInput->dependents().begin()->second; + ASSERT_EQ(aLoad->type(), AccessType::Load); + ASSERT_EQ(bLoad->type(), AccessType::Load); + + ASSERT_TRUE(indexBoundsEquals(aInput->bounds(), aLoad->bounds())); + ASSERT_TRUE(indexBoundsEquals(bInput->bounds(), bLoad->bounds())); + } + + loop.prepareForCodegen(); + + // now check lowered dependency graph. + { + Stmt* stmt = IRSimplifier::simplify(loop.root_stmt()); + stmt->accept(&analyzer_lowered); + + // Lowering will change the dimensionality of all bounds due to index + // flattening and will insert Allocates and Frees. + + auto history_before = analyzer_unlowered.getHistory(); + auto history_after = analyzer_lowered.getHistory(); + + ASSERT_EQ(history_before.size() + 2, history_after.size()); + + // Filter out the alloc/free; + auto isAllocFree = [](const auto& info) { + return info->type() == AccessType::Alloc || + info->type() == AccessType::Free; + }; + history_after.erase( + std::remove_if(history_after.begin(), history_after.end(), isAllocFree), + history_after.end()); + + ASSERT_EQ(history_before.size(), history_after.size()); + + for (size_t i = 0; i < history_before.size(); ++i) { + ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); + ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); + + if (history_before[i]->dependencies().size() != + history_after[i]->dependencies().size()) { + // Must depend on an Alloc. + ASSERT_TRUE(std::any_of( + history_after[i]->dependencies().begin(), + history_after[i]->dependencies().end(), + [](const auto& pair) { + return pair.second->type() == AccessType::Alloc; + })); + + ASSERT_EQ( + history_before[i]->dependencies().size() + 1, + history_after[i]->dependencies().size()); + } + + if (history_before[i]->dependents().size() != + history_after[i]->dependents().size()) { + // Must depend on an Free. + ASSERT_TRUE(std::any_of( + history_after[i]->dependents().begin(), + history_after[i]->dependents().end(), + [](const auto& pair) { + return pair.second->type() == AccessType::Free; + })); + + ASSERT_EQ( + history_before[i]->dependents().size() + 1, + history_after[i]->dependents().size()); + } + + // Inputs and outputs are not flattened, only accesses. + if (history_before[i]->type() == AccessType::Input || + history_before[i]->type() == AccessType::Output) { + ASSERT_EQ( + history_before[i]->bounds().size(), + history_after[i]->bounds().size()); + ASSERT_TRUE(indexBoundsEquals( + history_before[i]->bounds(), history_after[i]->bounds())); + } else { + ASSERT_EQ(history_after[i]->bounds().size(), 1); + const Expr* flat_bounds = new IntImm(1); + + for (auto& b : history_before[i]->bounds()) { + flat_bounds = new Mul(flat_bounds, new Add(b.end, new IntImm(1))); + + ASSERT_TRUE(exprEquals(b.start, history_after[i]->bounds()[0].start)); + } + + flat_bounds = IRSimplifier::simplify(flat_bounds); + const Expr* after_bounds = IRSimplifier::simplify( + new Add(history_after[i]->bounds()[0].end, new IntImm(1))); + ASSERT_TRUE(exprEquals(flat_bounds, after_bounds)); + } + } + } +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp index 2af8e33d39813..f69217df9bdea 100644 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ b/test/cpp/tensorexpr/test_reductions.cpp @@ -1,20 +1,22 @@ +#include + #include #include #include #include #include -#include "test/cpp/tensorexpr/test_base.h" - -#include "test/cpp/tensorexpr/padded_buffer.h" -#include "torch/csrc/jit/tensorexpr/analysis.h" -#include "torch/csrc/jit/tensorexpr/buffer.h" -#include "torch/csrc/jit/tensorexpr/eval.h" -#include "torch/csrc/jit/tensorexpr/function.h" -#include "torch/csrc/jit/tensorexpr/ir.h" -#include "torch/csrc/jit/tensorexpr/ir_printer.h" -#include "torch/csrc/jit/tensorexpr/ir_simplifier.h" -#include "torch/csrc/jit/tensorexpr/loopnest.h" -#include "torch/csrc/jit/tensorexpr/tensor.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace torch { namespace jit { @@ -22,10 +24,10 @@ namespace jit { using namespace torch::jit::tensorexpr; // Sum an array to a single value. -void testReduceSum1D() { +TEST(Reductions, ReduceSum1D) { KernelScope kernel_scope; - Buffer b(BufHandle("b", {10}, kFloat)); + Placeholder b(BufHandle("b", {10}, kFloat)); std::vector in(10); for (int j = 0; j < 10; ++j) { in[j] = j; @@ -45,7 +47,7 @@ void testReduceSum1D() { ASSERT_EQ(out[0], 45); } // Sum a 2D tensor to a 1D tensor with dynamic shapes. -void testReduceSum2D() { +TEST(Reductions, ReduceSum2D) { KernelScope kernel_scope; const int M = 3; @@ -54,7 +56,7 @@ void testReduceSum2D() { VarHandle m("m", kInt); VarHandle n("n", kInt); - Buffer b(BufHandle("b", {m, n}, kFloat)); + Placeholder b(BufHandle("b", {m, n}, kFloat)); std::vector in(M * N); for (int i = 0; i < M; ++i) { for (int j = 0; j < N; ++j) { @@ -86,13 +88,13 @@ void testReduceSum2D() { // Sum a 3D tensor to both a 2D and 1D tensor, then reduce the 2D tensor flat to // check our work. -void testReduceSum3D() { +TEST(Reductions, ReduceSum3D) { KernelScope kernel_scope; const int M = 10; VarHandle m("m", kInt); - Buffer b(BufHandle("b", {2, 3, m}, kFloat)); + Placeholder b(BufHandle("b", {2, 3, m}, kFloat)); Tensor* c = Reduce("sum", {{2, "l"}, {3, "n"}}, Sum(), b, {{m, "m"}}); LoopNest loop({c}); @@ -140,7 +142,7 @@ void testReduceSum3D() { } // This is the same as just reducing the original result across that axis. - Buffer c_buf(BufHandle(c->func_var())); + Placeholder c_buf(BufHandle(c->buf())); Tensor* e = Reduce("sum3", {{2, "l"}}, Sum(), c_buf, {{3, "m"}}); LoopNest loop3({e}); loop3.prepareForCodegen(); @@ -156,12 +158,12 @@ void testReduceSum3D() { } // Sum a large (10 D) Tensor 5 dimensions in. -void testReduceSum10D() { +TEST(Reductions, ReduceSum10D) { KernelScope kernel_scope; - Buffer in_(BufHandle("in_", {2, 3, 2, 3, 2, 3, 2, 3, 2, 3}, kFloat)); + Placeholder in_(BufHandle("in_", {2, 3, 2, 3, 2, 3, 2, 3, 2, 3}, kFloat)); const int InputSize = 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3; - Buffer out_(BufHandle("out_", {2, 3, 2, 3, 2}, kFloat)); + Placeholder out_(BufHandle("out_", {2, 3, 2, 3, 2}, kFloat)); const int OutputSize = 2 * 3 * 2 * 3 * 2; std::vector in(InputSize, 1.f); @@ -189,13 +191,13 @@ void testReduceSum10D() { } // Reduce via Mul rather than Add using a custom Reducer. -void testReduceProduct() { +TEST(Reductions, ReduceProduct) { KernelScope kernel_scope; const int M = 4; const int N = 4; - Buffer b(BufHandle("b", {M, N}, kFloat)); + Placeholder b(BufHandle("b", {M, N}, kFloat)); std::vector in(M * N); for (int i = 0; i < M; ++i) { for (int j = 0; j < N; ++j) { @@ -229,10 +231,10 @@ void testReduceProduct() { } // Maximum reductions. -void testReduceMax() { +TEST(Reductions, ReduceMax) { KernelScope kernel_scope; - Buffer in_(BufHandle("b", {10}, kFloat)); + Placeholder in_(BufHandle("b", {10}, kFloat)); std::vector in(10); std::vector out(1, -1.f); @@ -252,7 +254,7 @@ void testReduceMax() { ASSERT_EQ(out[0], 9); - Buffer in2_(BufHandle("b", {2, 5}, kFloat)); + Placeholder in2_(BufHandle("b", {2, 5}, kFloat)); std::vector out2(2, -1.f); Tensor* m2d = Reduce("max", {{2, "n"}}, Maximum(kFloat), in2_, {{5, "m"}}); @@ -270,11 +272,11 @@ void testReduceMax() { } // Minimum reduction, with custom initialization. -void testReduceMinCustomInitializer() { +TEST(Reductions, ReduceMinCustomInitializer) { KernelScope kernel_scope; VarHandle minInit("minInit", kFloat); - Buffer in_(BufHandle("b", {10}, kFloat)); + Placeholder in_(BufHandle("b", {10}, kFloat)); std::vector in(10); std::vector out(1, -1.f); @@ -286,7 +288,7 @@ void testReduceMinCustomInitializer() { "min", {}, Minimum(ExprHandle(minInit)), - [&](ParameterList& v) { return in_.call(v); }, + [&](ParameterList& v) { return in_.load(v); }, {{10, "m"}}); LoopNest loop({min}); @@ -308,11 +310,11 @@ void testReduceMinCustomInitializer() { // Example implementation of Any/All. // TODO: this is very awkward without logical And/Or operators. -void testReduceAnyAll() { +TEST(Reductions, ReduceAnyAll) { KernelScope kernel_scope; VarHandle searchValue("searchValue", kInt); - Buffer b(BufHandle("b", {4, 10}, kInt)); + Placeholder b(BufHandle("b", {4, 10}, kInt)); Reducer anyEqSV(ExprHandle(0), [](ExprHandle a, ExprHandle b) { return CompareSelect::make(a, 1, 1, b, kEQ); @@ -323,7 +325,7 @@ void testReduceAnyAll() { {{4, "i"}}, anyEqSV, [&](const auto& i, const auto& j) { - return CompareSelect::make(b(i, j), searchValue, kEQ); + return CompareSelect::make(b.load(i, j), searchValue, kEQ); }, {{10, "j"}}); @@ -366,7 +368,7 @@ void testReduceAnyAll() { {{4, "i"}}, allGTSV, [&](const auto& i, const auto& j) { - return CompareSelect::make(b(i, j), searchValue, kGT); + return CompareSelect::make(b.load(i, j), searchValue, kGT); }, {{10, "j"}}); @@ -394,11 +396,11 @@ void testReduceAnyAll() { ASSERT_EQ(out[3], 1); } -void testReduceMatmul2D() { +TEST(Reductions, ReduceMatmul2D) { KernelScope kernel_scope; - Buffer tA(BufHandle("tA", {3, 2}, kFloat)); - Buffer tB(BufHandle("tB", {2, 3}, kFloat)); + Placeholder tA(BufHandle("tA", {3, 2}, kFloat)); + Placeholder tB(BufHandle("tB", {2, 3}, kFloat)); std::vector tA_(6); std::vector tB_(6); @@ -416,7 +418,7 @@ void testReduceMatmul2D() { {{3, "m"}, {3, "n"}}, Sum(), [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return tA(m, k) * tB(k, n); + return tA.load(m, k) * tB.load(k, n); }, {{2, "k"}}); @@ -436,10 +438,10 @@ void testReduceMatmul2D() { } } -void testReduceRfactorLike() { +TEST(Reductions, ReduceRfactorLike) { KernelScope kernel_scope; - Buffer in(BufHandle("in", {10, 10}, kFloat)); + Placeholder in(BufHandle("in", {10, 10}, kFloat)); std::vector in_(100); for (int i = 0; i < 100; ++i) { in_[i] = i; @@ -448,7 +450,7 @@ void testReduceRfactorLike() { std::vector out(1, -1.f); Tensor* l1 = Reduce("l1", {{10, "i"}}, Sum(), in, {{10, "j"}}); - Buffer in_rf(BufHandle(l1->func_var())); + Placeholder in_rf(BufHandle(l1->buf())); Tensor* l2 = Reduce("l2", {}, Sum(), in_rf, {{10, "i"}}); @@ -463,21 +465,21 @@ void testReduceRfactorLike() { ASSERT_EQ(out[0], 99 * 50); } -void testReduceAsProducer() { +TEST(Reductions, ReduceAsProducer) { KernelScope kernel_scope; const int M = 10; VarHandle m("m", kInt); - Buffer a(BufHandle("a", {2, 3}, kFloat)); - Buffer b(BufHandle("b", {2, 3, m}, kFloat)); + Placeholder a(BufHandle("a", {2, 3}, kFloat)); + Placeholder b(BufHandle("b", {2, 3, m}, kFloat)); Tensor* c = Reduce("sum", {{2, "l1"}, {3, "n1"}}, Sum(), b, {{m, "m1"}}); Tensor* d = Compute( "scale", {{2, "l2"}, {3, "n1"}}, [&](const VarHandle& l, const VarHandle& n) { - return c->call(l, n) * a(l, n); + return c->call(l, n) * a.load(l, n); }); LoopNest loop({d}); loop.prepareForCodegen(); @@ -507,20 +509,20 @@ void testReduceAsProducer() { } } -void testReduceAsConsumer() { +TEST(Reductions, ReduceAsConsumer) { KernelScope kernel_scope; const int M = 10; VarHandle m("m", kInt); - Buffer a(BufHandle("a", {2, 3, m}, kFloat)); - Buffer b(BufHandle("b", {2, 3, m}, kFloat)); + Placeholder a(BufHandle("a", {2, 3, m}, kFloat)); + Placeholder b(BufHandle("b", {2, 3, m}, kFloat)); Tensor* c = Compute( "scale", {{2, "l2"}, {3, "n1"}, {m, "m1"}}, [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b(l, n, m) * a(l, n, m); + return b.load(l, n, m) * a.load(l, n, m); }); Tensor* d = Reduce("sum", {{2, "l1"}}, Sum(), c, {{3, "n1"}, {m, "m1"}}); LoopNest loop({d}); @@ -556,10 +558,10 @@ void testReduceAsConsumer() { } } -void testSplitReduceAxis() { +TEST(Reductions, SplitReduceAxis) { KernelScope kernel_scope; - Buffer in(BufHandle("in", {16, 8}, kFloat)); + Placeholder in(BufHandle("in", {16, 8}, kFloat)); std::vector in_(16 * 8); for (int i = 0; i < 16; ++i) { @@ -590,10 +592,10 @@ void testSplitReduceAxis() { } } -void testSplitNonReduceAxis() { +TEST(Reductions, SplitNonReduceAxis) { KernelScope kernel_scope; - Buffer in(BufHandle("in", {16, 8}, kFloat)); + Placeholder in(BufHandle("in", {16, 8}, kFloat)); std::vector in_(16 * 8); for (int i = 0; i < 16; ++i) { @@ -628,7 +630,7 @@ void testSplitNonReduceAxis() { } } -void testReorderedReductionInitializer() { +TEST(Reductions, ReorderedReductionInitializer) { KernelScope kernel_scope; /* From the quip: for k in 0..1: // blockIdx @@ -637,7 +639,7 @@ void testReorderedReductionInitializer() { SumOp(c(k, n), 0, a(k, m, n), {m}) */ - Buffer in(BufHandle("in", {1, 12, 6}, kFloat)); + Placeholder in(BufHandle("in", {1, 12, 6}, kFloat)); std::vector in_(12 * 6, 1.f); Tensor* tensor_ = Reduce("sum", {{1, "k"}, {12, "n"}}, Sum(), in, {{6, "m"}}); @@ -677,7 +679,7 @@ void testReorderedReductionInitializer() { } } -void testReduceRfactor() { +TEST(Reductions, ReduceRfactor) { KernelScope kernel_scope; const int M = 10; @@ -685,7 +687,7 @@ void testReduceRfactor() { VarHandle m("m", kInt); VarHandle n("n", kInt); - Buffer b(BufHandle("b", {m, n}, kFloat)); + Placeholder b(BufHandle("b", {m, n}, kFloat)); std::vector in(M * N); for (int j = 0; j < M * N; ++j) { in[j] = j; @@ -710,7 +712,7 @@ void testReduceRfactor() { ASSERT_EQ(out[0], 4950); } -void testReduce3DRfactorInternal() { +TEST(Reductions, Reduce3DRfactorInternal) { KernelScope kernel_scope; const int M = 10; @@ -720,7 +722,7 @@ void testReduce3DRfactorInternal() { VarHandle n("n", kInt); VarHandle k("k", kInt); - Buffer b(BufHandle("b", {m, n, k}, kFloat)); + Placeholder b(BufHandle("b", {m, n, k}, kFloat)); std::vector in(M * N * K); for (int j = 0; j < M * N * K; ++j) { in[j] = j; @@ -745,7 +747,7 @@ void testReduce3DRfactorInternal() { ASSERT_EQ(out[0], 499500); } -void testReduce3DRfactorInner() { +TEST(Reductions, Reduce3DRfactorInner) { KernelScope kernel_scope; const int M = 10; @@ -755,7 +757,7 @@ void testReduce3DRfactorInner() { VarHandle n("n", kInt); VarHandle k("k", kInt); - Buffer b(BufHandle("b", {m, n, k}, kFloat)); + Placeholder b(BufHandle("b", {m, n, k}, kFloat)); std::vector in(M * N * K); for (int j = 0; j < M * N * K; ++j) { in[j] = j; @@ -780,7 +782,7 @@ void testReduce3DRfactorInner() { ASSERT_EQ(out[0], 499500); } -void testReduce3DRfactorOuter() { +TEST(Reductions, Reduce3DRfactorOuter) { KernelScope kernel_scope; const int M = 10; @@ -790,7 +792,7 @@ void testReduce3DRfactorOuter() { VarHandle n("n", kInt); VarHandle k("k", kInt); - Buffer b(BufHandle("b", {m, n, k}, kFloat)); + Placeholder b(BufHandle("b", {m, n, k}, kFloat)); std::vector in(M * N * K); for (int j = 0; j < M * N * K; ++j) { in[j] = j; @@ -814,7 +816,7 @@ void testReduce3DRfactorOuter() { ASSERT_EQ(out[0], 499500); } -void testReduce3DRfactorWithOuter() { +TEST(Reductions, Reduce3DRfactorWithOuter) { KernelScope kernel_scope; const int L = 5; @@ -826,7 +828,7 @@ void testReduce3DRfactorWithOuter() { VarHandle n("n", kInt); VarHandle k("k", kInt); - Buffer b(BufHandle("b", {l, m, n, k}, kFloat)); + Placeholder b(BufHandle("b", {l, m, n, k}, kFloat)); std::vector in(L * M * N * K); for (int j = 0; j < M * N * K; ++j) { in[j] = j; @@ -851,7 +853,7 @@ void testReduce3DRfactorWithOuter() { ASSERT_EQ(out[0], 7750); } -void testReduce3DRfactorRepeated() { +TEST(Reductions, Reduce3DRfactorRepeated) { KernelScope kernel_scope; const int M = 5; @@ -861,7 +863,7 @@ void testReduce3DRfactorRepeated() { VarHandle n("n", kInt); VarHandle k("k", kInt); - Buffer b(BufHandle("b", {m, n, k}, kFloat)); + Placeholder b(BufHandle("b", {m, n, k}, kFloat)); std::vector in(M * N * K); for (int j = 0; j < M * N * K; ++j) { in[j] = j; @@ -899,7 +901,7 @@ void testReduce3DRfactorRepeated() { } } -void testReduceRfactorInsertionPoint() { +TEST(Reductions, ReduceRfactorInsertionPoint) { KernelScope kernel_scope; const int M = 10; @@ -907,7 +909,7 @@ void testReduceRfactorInsertionPoint() { VarHandle m("m", kInt); VarHandle n("n", kInt); - Buffer b(BufHandle("b", {m, n}, kFloat)); + Placeholder b(BufHandle("b", {m, n}, kFloat)); std::vector in(M * N); for (int j = 0; j < M * N; ++j) { in[j] = j; @@ -932,7 +934,7 @@ void testReduceRfactorInsertionPoint() { ASSERT_EQ(out[0], 4950); } -void testReduce3DRfactorInsertionPoint() { +TEST(Reductions, Reduce3DRfactorInsertionPoint) { KernelScope kernel_scope; const int M = 10; @@ -942,7 +944,7 @@ void testReduce3DRfactorInsertionPoint() { VarHandle n("n", kInt); VarHandle k("k", kInt); - Buffer b(BufHandle("b", {m, n, k}, kFloat)); + Placeholder b(BufHandle("b", {m, n, k}, kFloat)); std::vector in(M * N * K); for (int j = 0; j < M * N * K; ++j) { in[j] = j; @@ -966,10 +968,10 @@ void testReduce3DRfactorInsertionPoint() { ASSERT_EQ(out[0], 4950); } -void testReduceRepeatedInternalRfactor() { +TEST(Reductions, ReduceRepeatedInternalRfactor) { KernelScope kernel_scope; - Buffer in_(BufHandle("in_", {2, 3, 4, 5, 6}, kFloat)); + Placeholder in_(BufHandle("in_", {2, 3, 4, 5, 6}, kFloat)); const int InputSize = 2 * 3 * 4 * 5 * 6; std::vector in(InputSize, 1.f); @@ -1013,14 +1015,14 @@ void testReduceRepeatedInternalRfactor() { } // Split a reduction axis with a tail loop. -void testReduceSplitTail() { +TEST(Reductions, ReduceSplitTail) { KernelScope kernel_scope; const int M = 10; const int N = 10; const int K = 10; - Buffer b(BufHandle("b", {M, N, K}, kFloat)); + Placeholder b(BufHandle("b", {M, N, K}, kFloat)); std::vector in(M * N * K); for (int j = 0; j < M * N * K; ++j) { in[j] = j; @@ -1047,13 +1049,13 @@ void testReduceSplitTail() { } // Split a reduction axis cleanly so there is no tail loop. -void testReduceSplitNoTail() { +TEST(Reductions, ReduceSplitNoTail) { KernelScope kernel_scope; const int M = 10; const int N = 10; const int K = 10; - Buffer b(BufHandle("b", {M, N, K}, kFloat)); + Placeholder b(BufHandle("b", {M, N, K}, kFloat)); std::vector in(M * N * K); for (int j = 0; j < M * N * K; ++j) { in[j] = j; @@ -1081,14 +1083,14 @@ void testReduceSplitNoTail() { // Split a reduction axis with only a tail loop (the split loop will be size 0 // and eliminated out). -void testReduceOverSplitTail() { +TEST(Reductions, ReduceOverSplitTail) { KernelScope kernel_scope; const int M = 10; const int N = 10; const int K = 10; - Buffer b(BufHandle("b", {M, N, K}, kFloat)); + Placeholder b(BufHandle("b", {M, N, K}, kFloat)); std::vector in(M * N * K); for (int j = 0; j < M * N * K; ++j) { in[j] = j; @@ -1115,14 +1117,14 @@ void testReduceOverSplitTail() { } // Split a reduction axis with a mask. -void testReduceSplitMask() { +TEST(Reductions, ReduceSplitMask) { KernelScope kernel_scope; const int M = 10; const int N = 10; const int K = 10; - Buffer b(BufHandle("b", {M, N, K}, kFloat)); + Placeholder b(BufHandle("b", {M, N, K}, kFloat)); std::vector in(M * N * K); for (int j = 0; j < M * N * K; ++j) { in[j] = j; @@ -1149,13 +1151,13 @@ void testReduceSplitMask() { } // Split a reduction axis cleanly not requiring a mask. -void testReduceSplitNoMask() { +TEST(Reductions, ReduceSplitNoMask) { KernelScope kernel_scope; const int M = 10; const int N = 10; const int K = 10; - Buffer b(BufHandle("b", {M, N, K}, kFloat)); + Placeholder b(BufHandle("b", {M, N, K}, kFloat)); std::vector in(M * N * K); for (int j = 0; j < M * N * K; ++j) { in[j] = j; @@ -1182,14 +1184,14 @@ void testReduceSplitNoMask() { } // Split a reduction axis with all logic in the mask. -void testReduceOverSplitMask() { +TEST(Reductions, ReduceOverSplitMask) { KernelScope kernel_scope; const int M = 10; const int N = 10; const int K = 10; - Buffer b(BufHandle("b", {M, N, K}, kFloat)); + Placeholder b(BufHandle("b", {M, N, K}, kFloat)); std::vector in(M * N * K); for (int j = 0; j < M * N * K; ++j) { in[j] = j; @@ -1217,7 +1219,7 @@ void testReduceOverSplitMask() { // Test an rfactor when there are two ReduceOps in the graph due to a // splitWithTail. -void testReduceSplitRfactor() { +TEST(Reductions, ReduceSplitRfactor) { KernelScope kernel_scope; const int M = 2; @@ -1225,7 +1227,7 @@ void testReduceSplitRfactor() { const int K = 10; const int SPLIT_FACTOR = 4; - Buffer b(BufHandle("b", {M, N, K}, kFloat)); + Placeholder b(BufHandle("b", {M, N, K}, kFloat)); std::vector in(M * N * K); for (int m = 0; m < M; ++m) { for (int j = 0; j < N * K; ++j) { @@ -1257,14 +1259,14 @@ void testReduceSplitRfactor() { // Test an rfactor which ends up being eliminated since the total loop size is // smaller than the split factor. -void testReduceOverSplitRfactor() { +TEST(Reductions, ReduceOverSplitRfactor) { KernelScope kernel_scope; const int N = 10; const int K = 10; const int SPLIT_FACTOR = 16; - Buffer b(BufHandle("b", {N, K}, kFloat)); + Placeholder b(BufHandle("b", {N, K}, kFloat)); std::vector in(N * K); for (int j = 0; j < N * K; ++j) { in[j] = j; @@ -1308,18 +1310,18 @@ void testReduceOverSplitRfactor() { // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } -void testReduceInlineReduction() { +TEST(Reductions, ReduceInlineReduction) { KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; - Buffer a_buf("a", kFloat, {M}); - Buffer b_buf("b", kFloat, {M, N, K}); + Placeholder a_buf("a", kFloat, {M}); + Placeholder b_buf("b", kFloat, {M, N, K}); Tensor* x = Reduce("x", {{M, "m1"}}, Sum(), b_buf, {{N, "n1"}, {K, "k1"}}); Tensor* y = Compute("y", {{M, "m2"}}, [&](const VarHandle& m) { - return a_buf(m) + x->call(m); + return a_buf.load(m) + x->call(m); }); PaddedBuffer a_v(M); @@ -1337,24 +1339,24 @@ void testReduceInlineReduction() { } LoopNest l1({y}); - ASSERT_THROWS_WITH( - l1.computeInline(x->buf()), "cannot inline a reduction computation"); + // Cannot inline a reduction computation + ASSERT_FALSE(l1.computeInline(x->buf())); } -void testReduceInlineConsumer() { +TEST(Reductions, ReduceInlineConsumer) { KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; - Buffer a_buf("a", kFloat, {M, N, K}); - Buffer b_buf("b", kFloat, {M, N, K}); + Placeholder a_buf("a", kFloat, {M, N, K}); + Placeholder b_buf("b", kFloat, {M, N, K}); Tensor* x = Compute( "x", {{M, "m1"}, {N, "n1"}, {K, "k1"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf(m, n, k) + b_buf(m, n, k); + return a_buf.load(m, n, k) + b_buf.load(m, n, k); }); Tensor* y = Reduce("y", {{M, "m2"}}, Sum(), x, {{N, "n2"}, {K, "k2"}}); @@ -1380,8 +1382,8 @@ void testReduceInlineConsumer() { Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt()); Stmt* stmt2 = IRSimplifier::simplify(l2.root_stmt()); - SimpleIREvaluator eval1(stmt1, a_buf, b_buf, y); - SimpleIREvaluator eval2(stmt2, a_buf, b_buf, y); + SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); + SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); PaddedBuffer y_1(M); PaddedBuffer y_2(M); @@ -1395,20 +1397,20 @@ void testReduceInlineConsumer() { ASSERT_GT(oss1.str().size(), oss2.str().size()); } -void testReduceInlineReducerInternal() { +TEST(Reductions, ReduceInlineReducerInternal) { KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; - Buffer a_buf("a", kFloat, {M, N, K}); - Buffer b_buf("b", kFloat, {M, N, K}); + Placeholder a_buf("a", kFloat, {M, N, K}); + Placeholder b_buf("b", kFloat, {M, N, K}); Tensor* x = Compute( "x", {{M, "m1"}, {N, "n1"}, {K, "k1"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf(m, n, k) + b_buf(m, n, k); + return a_buf.load(m, n, k) + b_buf.load(m, n, k); }); Reducer minimum(ExprHandle(0.f), [&](ExprHandle a, ExprHandle b) { @@ -1438,8 +1440,8 @@ void testReduceInlineReducerInternal() { Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt()); Stmt* stmt2 = IRSimplifier::simplify(l2.root_stmt()); - SimpleIREvaluator eval1(stmt1, a_buf, b_buf, y); - SimpleIREvaluator eval2(stmt2, a_buf, b_buf, y); + SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); + SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); PaddedBuffer y_1(M); PaddedBuffer y_2(M); @@ -1453,5 +1455,519 @@ void testReduceInlineReducerInternal() { ASSERT_GT(oss1.str().size(), oss2.str().size()); } +TEST(Reductions, ReductionCacheAccessesOuter) { + KernelScope kernel_scope; + + int L = 4; + int N = 3; + int M = 2; + + Placeholder a(BufHandle("a", {L, N, M}, kFloat)); + Placeholder b(BufHandle("b", {L, N, M}, kFloat)); + + Tensor* c = Compute( + "scale", + {{L, "l2"}, {N, "n1"}, {M, "m1"}}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor* d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}}); + + Tensor* e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d->call(l); + }); + + LoopNest l({e}); + + Stmt* d_loop = l.getLoopStmtsFor(d)[1]; + l.cacheAccesses(d->buf(), "d_local", d_loop); + l.prepareForCodegen(); + + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(d_local, float, {1}); +#CHECK: sum[l1] = 0 +#CHECK: d_local[0] = 0 +#CHECK: for (int n1 +#CHECK: for (int m1 +#CHECK: d_local[0] = (d_local[0]) + (scale[ +#CHECK: } +#CHECK: } +#CHECK: sum[l1] = (sum[l1]) + (d_local[0]) +#CHECK: Free(d_local); +#CHECK-NOT: d_local + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(Reductions, ReductionCacheAccessesInner) { + KernelScope kernel_scope; + + int L = 4; + int N = 3; + int M = 2; + + Placeholder a(BufHandle("a", {L, N, M}, kFloat)); + Placeholder b(BufHandle("b", {L, N, M}, kFloat)); + + Tensor* c = Compute( + "scale", + {{L, "l2"}, {N, "n1"}, {M, "m1"}}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor* d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}}); + + Tensor* e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d->call(l); + }); + + LoopNest l({e}); + + Stmt* d_loop = l.getLoopStmtsFor(d)[2]; + l.cacheAccesses(d->buf(), "d_local", d_loop); + l.prepareForCodegen(); + + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: sum[l1] = 0 +#CHECK: for (int n1 +#CHECK: Allocate(d_local, float, {1}); +#CHECK: d_local[0] = 0 +#CHECK: for (int m1 +#CHECK: d_local[0] = (d_local[0]) + (scale[ +#CHECK: } +#CHECK: sum[l1] = (sum[l1]) + (d_local[0]) +#CHECK: Free(d_local); +#CHECK: } +#CHECK-NOT: d_local + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(Reductions, ReductionCacheBodyAccess) { + KernelScope kernel_scope; + + Placeholder a(BufHandle("a", {24, 32, 12}, kFloat)); + Placeholder b(BufHandle("b", {24, 32, 12}, kFloat)); + + Tensor* c = Compute( + "scale", + {{24, "l2"}, {32, "n1"}, {12, "m1"}}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor* d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}}); + + Tensor* e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d->call(l); + }); + + LoopNest l({e}); + + Stmt* d_loop = l.getLoopStmtsFor(d)[1]; + l.cacheAccesses(c->buf(), "scale_local", d_loop); + + l.prepareForCodegen(); + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(scale_local, float, {384}); +#CHECK: for (int j = 0; j < 32; j++) { +#CHECK: for (int k = 0; k < 12; k++) { +#CHECK: scale_local[k + 12 * j] = scale[(k + 384 * l1) + 12 * j]; +#CHECK: sum[l1] = (sum[l1]) + (scale_local[12 * n1_1 + m1_1]); +#CHECK: Free(scale_local); +#CHECK: scale_1[l] = (b[l]) * (sum[l]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(Reductions, ReductionCacheConsumerAccess) { + KernelScope kernel_scope; + + Placeholder a(BufHandle("a", {24, 32, 12}, kFloat)); + Placeholder b(BufHandle("b", {24, 32, 12}, kFloat)); + + Tensor* c = Compute( + "scale", + {{24, "l2"}, {32, "n1"}, {12, "m1"}}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor* d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}}); + + Tensor* e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d->call(l); + }); + + LoopNest l({e}); + + For* outer; + For* inner; + l.splitWithMask(l.getLoopStmtsFor(e)[0], 4, &outer, &inner); + + Stmt* e_loop = l.getLoopStmtsFor(e)[1]; + l.cacheAccesses(d->buf(), "sum_local", e_loop); + l.prepareForCodegen(); + + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: sum[l1] = (sum[l1]) + (scale[ +#CHECK: Allocate(sum_local, float, {4}); +#CHECK: for (int i = 0; i < 4 +#CHECK: sum_local[i] = sum[i + 4 * l_outer]; +#CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(Reductions, ReductionSplitCacheConsumerAccess) { + KernelScope kernel_scope; + + Placeholder a(BufHandle("a", {24, 32, 12}, kFloat)); + Placeholder b(BufHandle("b", {24, 32, 12}, kFloat)); + + Tensor* c = Compute( + "scale", + {{24, "l2"}, {32, "n1"}, {12, "m1"}}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor* d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}}); + + Tensor* e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d->call(l); + }); + + LoopNest l({e}); + + For* outer; + For* inner; + + // Split outer reduction axis. + l.splitWithMask(l.getLoopStmtsFor(d)[0], 4, &outer, &inner); + + // Split reduction consumer. + l.splitWithMask(l.getLoopStmtsFor(e)[0], 4, &outer, &inner); + + l.cacheAccesses(d->buf(), "sum_local", inner); + l.prepareForCodegen(); + + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + // reduction changes but cache does not. + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: sum[l1_inner + 4 * l1_outer] = (sum[l1_inner + 4 * l1_outer]) + (scale[((12 * n1_1 + 384 * l1_inner) + m1_1) + 1536 * l1_outer]); +#CHECK: Allocate(sum_local, float, {4}); +#CHECK: for (int i = 0; i < 4 +#CHECK: sum_local[i] = sum[i + 4 * l_outer]; +#CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(Reductions, ReductionReorderCacheConsumerAccess) { + KernelScope kernel_scope; + + Placeholder a(BufHandle("a", {24, 32, 12}, kFloat)); + Placeholder b(BufHandle("b", {24, 32, 12}, kFloat)); + + Tensor* c = Compute( + "scale", + {{24, "l2"}, {32, "n1"}, {12, "m1"}}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor* d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}}); + + Tensor* e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d->call(l); + }); + + LoopNest l({e}); + + For* outer; + For* inner; + + // reorder outer reduction axes. + auto loops = l.getLoopStmtsFor(d); + l.reorderAxis(loops[0], loops[1]); + + // Split reduction consumer. + l.splitWithMask(l.getLoopStmtsFor(e)[0], 4, &outer, &inner); + + l.cacheAccesses(d->buf(), "sum_local", inner); + l.prepareForCodegen(); + + Stmt* result = IRSimplifier::simplify(l.root_stmt()); + + // neither reduction body not cache changes. + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( +#CHECK: sum[l1] = (sum[l1]) + (scale[(12 * n1_1 + m1_1) + 384 * l1]); +#CHECK: Allocate(sum_local, float, {4}); +#CHECK: for (int i = 0; i < 4 +#CHECK: sum_local[i] = sum[i + 4 * l_outer]; +#CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(Reductions, ReductionRfactorCacheTempOuter) { + KernelScope kernel_scope; + + const int M = 10; + const int N = 10; + const int K = 10; + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle k("k", kInt); + + Placeholder b(BufHandle("B", {m, n, k}, kFloat)); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "a"}, {n, "b"}, {k, "c"}}); + LoopNest loop({c}); + auto reduces = NodeFinder::find(loop.root_stmt()); + loop.rfactor(reduces[0], reduces[0]->reduce_args()[1]); + + reduces = NodeFinder::find(loop.root_stmt()); + std::vector loops = NodeFinder::find(loop.root_stmt()); + loop.cacheAccesses(reduces[0]->accumulator(), "tmp2", loops[2]); + loop.prepareForCodegen(); + Stmt* s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + std::ostringstream oss; + oss << *s; + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(tmp_buf, float, {n}); +#CHECK: for (int a = 0; a < m +#CHECK: Allocate(tmp2, float, {n}); +#CHECK: for (int i = 0; i < n +#CHECK: tmp2[i] = 0 +#CHECK: } +#CHECK: for (int b = 0; b < n +#CHECK: for (int c +#CHECK: tmp2[b] = (tmp2[b]) + (B[ +#CHECK: } +#CHECK: } +#CHECK: for (int i = 0; i < n +#CHECK: tmp_buf[i] = (tmp_buf[i]) + (tmp2[i]); +#CHECK: } +#CHECK: Free(tmp2); +#CHECK-NOT: tmp2 + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + SimpleIREvaluator cg(s, {b, c, m, n, k}); + + cg.call({in, out, M, N, K}); + ASSERT_EQ(out[0], 499500); +} + +TEST(Reductions, ReductionRfactorCacheTempInner) { + KernelScope kernel_scope; + + const int M = 10; + const int N = 10; + const int K = 10; + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle k("k", kInt); + + Placeholder b(BufHandle("B", {m, n, k}, kFloat)); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "a"}, {n, "b"}, {k, "c"}}); + LoopNest loop({c}); + auto reduces = NodeFinder::find(loop.root_stmt()); + loop.rfactor(reduces[0], reduces[0]->reduce_args()[1]); + + reduces = NodeFinder::find(loop.root_stmt()); + std::vector loops = NodeFinder::find(loop.root_stmt()); + loop.cacheAccesses(reduces[0]->accumulator(), "tmp2", loops[3]); + loop.prepareForCodegen(); + Stmt* s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + std::ostringstream oss; + oss << *s; + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(tmp_buf, float, {n}); +#CHECK: for (int a = 0; a < m +#CHECK: for (int b = 0; b < n +#CHECK: Allocate(tmp2, float, {1}); +#CHECK: tmp2[0] = 0 +#CHECK: for (int c +#CHECK: tmp2[0] = (tmp2[0]) + (B[ +#CHECK: } +#CHECK: tmp_buf[b] = (tmp_buf[b]) + (tmp2[0]); +#CHECK: Free(tmp2); +#CHECK-NOT: tmp2 + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + SimpleIREvaluator cg(s, {b, c, m, n, k}); + + cg.call({in, out, M, N, K}); + ASSERT_EQ(out[0], 499500); +} + +TEST(Reductions, ReductionVectorize) { + KernelScope kernel_scope; + + std::vector in_(8 * 8); + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + in_[i * 8 + j] = i; + } + } + std::vector out_before(8, -1.f); + std::vector out_after(8, -1.f); + + Placeholder in(BufHandle("in", {8, 8}, kFloat)); + + Tensor* tensor = Reduce("sum", {{8, "m"}}, Sum(), in, {{8, "n"}}); + LoopNest l_before({tensor}); + l_before.prepareForCodegen(); + SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor}); + cg_before.call({in_, out_before}); + + LoopNest l({tensor}); + l.vectorize(l.getLoopStmtsFor(tensor)[0]); + + Stmt* s = l.root_stmt(); + s = IRSimplifier::simplify(s); + + std::ostringstream oss; + oss << *s; + const std::string& expected_ir = + R"IR( +#CHECK: sum[Ramp(0, 1, 8)] = Broadcast(0.f, 8); +#CHECK: for (int n = 0; n < 8; n++) { +#CHECK: sum[Ramp(0, 1, 8)] = ReduceOp((sum[Ramp(0, 1, 8)]) + (in[Ramp(n, 8, 8)]), out_args={Ramp(0, 1, 8)}, reduce_args={n}); +#CHECK: } + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + // Vectorizing should not change result. + l.prepareForCodegen(); + s = IRSimplifier::simplify(l.root_stmt()); + SimpleIREvaluator cg_after(s, {in, tensor}); + cg_after.call({in_, out_after}); + for (int i = 0; i < 8; ++i) { + ASSERT_EQ(out_before[i], out_after[i]); + } +} + +TEST(Reductions, ReductionVectorizeInner) { + KernelScope kernel_scope; + + Placeholder in(BufHandle("in", {8, 8}, kFloat)); + + Tensor* tensor = Reduce("sum", {{8, "m"}}, Sum(), in, {{8, "n"}}); + LoopNest l({tensor}); + + ASSERT_THROWS_WITH( + l.vectorize(l.getLoopStmtsFor(tensor)[1]), "reduction axis"); +} + +TEST(Reductions, ReductionVectorizeRfactor) { + KernelScope kernel_scope; + + std::vector in_(8 * 8); + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + in_[i * 8 + j] = i; + } + } + std::vector out_before(1, -1.f); + std::vector out_after(1, -1.f); + + Placeholder in(BufHandle("in", {8, 8}, kFloat)); + + Tensor* tensor = Reduce("sum", {}, Sum(), in, {{8, "m"}, {8, "n"}}); + + LoopNest l_before({tensor}); + l_before.prepareForCodegen(); + SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor}); + cg_before.call({in_, out_before}); + + LoopNest l({tensor}); + ASSERT_THROWS_WITH( + l.vectorize(l.getLoopStmtsFor(tensor)[1]), "reduction axis"); + + // But if we rfactor this so it's not a reduce axis we can vectorize that + // loop. + std::vector loops = l.getLoopStmtsFor(tensor); + auto v = loops.at(1)->var(); + l.rfactor(tensor->body(), v); + + loops = NodeFinder::find(l.root_stmt()); + l.vectorize(loops[2]); + + Stmt* s = l.root_stmt(); + s = IRSimplifier::simplify(s); + + std::ostringstream oss; + oss << *s; + const std::string& expected_ir = + R"IR( +#CHECK: sum = 0.f; +#CHECK: for (int n = 0; n < 8; n++) { +#CHECK: tmp_buf[n] = 0.f; +#CHECK: } +#CHECK: for (int m = 0; m < 8; m++) { +#CHECK: tmp_buf[Ramp(0, 1, 8)] = ReduceOp((tmp_buf[Ramp(0, 1, 8)]) + (in[Ramp(8 * m, 1, 8)]), out_args={Ramp(0, 1, 8)}, reduce_args={m}); +#CHECK: } +#CHECK: for (int n = 0; n < 8; n++) { +#CHECK: sum = ReduceOp((sum) + (tmp_buf[n]), out_args={}, reduce_args={n}); +#CHECK: } + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + // Vectorizing should not change result. + l.prepareForCodegen(); + s = IRSimplifier::simplify(l.root_stmt()); + SimpleIREvaluator cg_after(s, {in, tensor}); + cg_after.call({in_, out_after}); + + ASSERT_EQ(out_before[0], out_after[0]); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/test_registerizer.cpp b/test/cpp/tensorexpr/test_registerizer.cpp index e7a28f1fb277b..bfddb285bb345 100644 --- a/test/cpp/tensorexpr/test_registerizer.cpp +++ b/test/cpp/tensorexpr/test_registerizer.cpp @@ -1,3 +1,4 @@ +#include #include "test/cpp/tensorexpr/test_base.h" #include "test/cpp/tensorexpr/test_utils.h" @@ -11,18 +12,18 @@ namespace jit { using namespace torch::jit::tensorexpr; // Can replace a simple scalar access with a local variable. -void testRegisterizerSimple() { +TEST(Registerizer, RegisterizerSimple) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {1}, kInt)); + BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = - Block::make({Store::make(a, {0}, 0, 1), - For::make( - x, - 0, - 10, - Block::make({Store::make( - a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)}))}); + Stmt* stmt = Block::make( + {Store::make(a, {0}, 0, 1), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)}))}); /* * A[0] = 0; @@ -31,14 +32,14 @@ void testRegisterizerSimple() { * } */ - registerize(stmt); + stmt = registerize(stmt); /* - * int A_ = 0; + * int A_1 = 0; * for (int x = 0; x < 10; x++) { - * A_ = x + A_; + * A_1 = x + A_1; * } - * A[0] = A_; + * A[0] = A_1; */ std::ostringstream oss; @@ -46,28 +47,28 @@ void testRegisterizerSimple() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = 0; +# CHECK: int A_1 = 0; # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A[ -# CHECK: A_ = -# CHECK: A[0] = A_;)IR"; +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } // Won't do replacement of a loop access. -void testRegisterizerLoop() { +TEST(Registerizer, RegisterizerLoop) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {10}, kInt)); + BufHandle a("A", {10}, kInt); VarHandle x("x", kInt); - Stmt* stmt = - Block::make({Store::make(a, {0}, 0, 1), - For::make( - x, - 0, - 10, - Block::make({Store::make( - a, {x}, Add::make(Load::make(a, {x}, 1), x), 1)}))}); + Stmt* stmt = Block::make( + {Store::make(a, {0}, 0, 1), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {x}, Add::make(Load::make(a, {x}, 1), x), 1)}))}); /* * A[0] = 0; @@ -77,7 +78,7 @@ void testRegisterizerLoop() { */ // No change. - registerize(stmt); + stmt = registerize(stmt); /* * A[0] = 0; @@ -96,40 +97,40 @@ void testRegisterizerLoop() { # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A_ # CHECK: A[x] = -# CHECK-NOT: A[0] = A_;)IR"; +# CHECK-NOT: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } // Won't replace even if the load is a fixed scalar, since the store could // invalidate it. -void testRegisterizerLoopFixedLoad() { +TEST(Registerizer, RegisterizerLoopFixedLoad) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {1}, kInt)); + BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); - Stmt* stmt = - Block::make({Store::make(a, {0}, 0, 1), - For::make( - x, - 0, - 10, - Block::make({Store::make( - a, {x}, Add::make(Load::make(a, {0}, 1), x), 1)}))}); + Stmt* stmt = Block::make( + {Store::make(a, {0}, 0, 1), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {x}, Add::make(Load::make(a, {0}, 1), x), 1)}))}); /* * A[0] = 0; * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; + * A[x] = (A[0]) + x; * } */ // No change. - registerize(stmt); + stmt = registerize(stmt); /* * A[0] = 0; * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; + * A[x] = (A[0]) + x; * } */ @@ -143,15 +144,262 @@ void testRegisterizerLoopFixedLoad() { # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A_ # CHECK: A[x] = -# CHECK-NOT: A[0] = A_;)IR"; +# CHECK-NOT: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// We can registerize accesses that occur entirely within inner scopes, even if +// they depend on the loop var. +TEST(Registerizer, RegisterizerLoopInternal) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make({For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), x), 1), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), x), 1)}))}); + + /* + * for (int x = 0; x < 10; x++) { + * A[x] = (A[x]) + x; + * A[x] = (A[x]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * for (int x = 0; x < 10; x++) { + * int A_1 = A[x]; + * A_1 = A_1 + x; + * A_1 = A_1 + x; + * A[x] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: int A_1 = A[x]; +# CHECK: A_1 = A_1 + x; +# CHECK: A_1 = A_1 + x; +# CHECK: A[x] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// An access can be overlapped by another read in the same Expr. In this case +// B[z] and B[y] overlap and prevent registerization of both accesses. +TEST(Registerizer, RegisterizerLoopInternalLoadOverlap) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + Stmt* stmt = Block::make({For::make( + x, + 0, + 10, + Store::make( + a, + {x}, + Add::make(Load::make(b, {y}, 1), Load::make(b, {z}, 1)), + 1))}); + + /* + * for (int x = 0; x < 10; x++) { + * A[x] = (B[y]) + (B[z]); + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +TEST(Registerizer, RegisterizerLoopInternalRepeated) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {1}, 1), x), 1), + Store::make(a, {0}, Add::make(Load::make(a, {1}, 1), x), 1)})), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {1}, 1), x), 1), + Store::make(a, {0}, Add::make(Load::make(a, {1}, 1), x), 1)})) + + }); + + /* + * for (int x = 0; x < 10; x++) { + * A[0] = x + (A[1]); + * A[0] = x + (A[1]); + * } + * for (int x = 0; x < 10; x++) { + * A[0] = x + (A[1]); + * A[0] = x + (A[1]); + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[1]; + * int A_2 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_2 = x + A_1; + * A_2 = x + A_1; + * } + * for (int x = 0; x < 10; x++) { + * A_2 = x + A_1; + * A_2 = x + A_1; + * } + * A[0] = A_2; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[1]; +# CHECK: int A_2 = A[0]; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: A_2 = x + A_1; +# CHECK: A_2 = x + A_1; +# CHECK: } +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: A_2 = x + A_1; +# CHECK: A_2 = x + A_1; +# CHECK: } +# CHECK-NOT: A[1] +# CHECK: A[0] = A_2; +# CHECK-NOT: A[1] +# CHECK: })IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } +TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapLoopVar) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {x}, 1), x), 1), + Store::make(a, {0}, Add::make(Load::make(a, {x}, 1), x), 1)})), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {x}, 1), x), 1), + Store::make(a, {0}, Add::make(Load::make(a, {x}, 1), x), 1)})) + + }); + + /* + * for (int x = 0; x < 10; x++) { + * A[0] = (A[x]) + x; + * A[0] = (A[x]) + x; + * } + * for (int x = 0; x < 10; x++) { + * A[0] = (A[x]) + x; + * A[0] = (A[x]) + x; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = Block::make( + {For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(x, Load::make(a, {y}, 1)), 1), + Store::make(a, {0}, Add::make(x, Load::make(a, {y}, 1)), 1)})), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(x, Load::make(a, {y}, 1)), 1), + Store::make(a, {0}, Add::make(x, Load::make(a, {y}, 1)), 1)})) + + }); + + /* + * for (int x = 0; x < 10; x++) { + * A[0] = (A[x]) + x; + * A[0] = (A[x]) + x; + * } + * for (int x = 0; x < 10; x++) { + * A[0] = (A[x]) + x; + * A[0] = (A[x]) + x; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + // Will registerize multiple accesses of different items of the same buffer. -void testRegisterizerMultiVar() { +TEST(Registerizer, RegisterizerMultiVar) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {2}, kInt)); + BufHandle a("A", {2}, kInt); VarHandle x("x", kInt); Stmt* stmt = Block::make({ Store::make(a, {0}, 0, 1), @@ -174,17 +422,17 @@ void testRegisterizerMultiVar() { * } */ - registerize(stmt); + stmt = registerize(stmt); /* - * int A_ = 0; - * int A__1 = 0; + * int A_1 = 0; + * int A_2 = 0; * for (int x = 0; x < 10; x++) { - * A__1 = x + A__1; - * A_ = A_ - x; + * A_2 = x + A_2; + * A_1 = A_1 - x; * } - * A[1] = A__1; - * A[0] = A_; + * A[1] = A_2; + * A[0] = A_1; */ std::ostringstream oss; @@ -192,23 +440,23 @@ void testRegisterizerMultiVar() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = 0; -# CHECK: int A__1 = 0; +# CHECK: int A_1 = 0; +# CHECK: int A_2 = 0; # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A[ -# CHECK: A_ = -# CHECK: A__1 = -# CHECK: A[1] = A__1 -# CHECK: A[0] = A_;)IR"; +# CHECK: A_1 = +# CHECK: A_2 = +# CHECK: A[1] = A_2 +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } // Will registerize the valid accesses while skipping invalid replacements. -void testRegisterizerVariableLoad() { +TEST(Registerizer, RegisterizerVariableLoad) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {1}, kInt)); - Buffer b(BufHandle("B", {10}, kInt)); + BufHandle a("A", {1}, kInt); + BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); VarHandle x2("x", kInt); Stmt* stmt = Block::make( @@ -234,17 +482,17 @@ void testRegisterizerVariableLoad() { * } */ - registerize(stmt); + stmt = registerize(stmt); /* - * int A_ = 0; + * int A_1 = 0; * for (int x = 0; x < 10; x++) { * B[x] = x; * } * for (int x_1 = 0; x_1 < 10; x_1++) { - * A_ = A_ + (B[x_1]); + * A_1 = A_1 + (B[x_1]); * } - * A[0] = A_; + * A[0] = A_1; */ std::ostringstream oss; @@ -252,32 +500,32 @@ void testRegisterizerVariableLoad() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = 0; +# CHECK: int A_1 = 0; # CHECK: for (int x = 0; x < 10; x++) # CHECK: B[x] = x # CHECK: for (int x_1 = 0; x_1 < 10; x_1++) # CHECK-NOT: A[ -# CHECK: A_ = -# CHECK: A[0] = A_;)IR"; +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } // Can registerize variable accesses so long as the variable does not change. -void testRegisterizerSymbolicIndices() { +TEST(Registerizer, RegisterizerSymbolicIndices) { KernelScope kernel_scope; VarHandle i("i", kInt); VarHandle N("N", kInt); - Buffer a(BufHandle("A", {N}, kInt)); + BufHandle a("A", {N}, kInt); VarHandle x("x", kInt); - Stmt* stmt = - Block::make({Store::make(a, {i}, 0, 1), - For::make( - x, - 0, - 10, - Block::make({Store::make( - a, {i}, Add::make(Load::make(a, {i}, 1), x), 1)}))}); + Stmt* stmt = Block::make( + {Store::make(a, {i}, 0, 1), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {i}, Add::make(Load::make(a, {i}, 1), x), 1)}))}); /* * A[i] = 0; @@ -286,14 +534,14 @@ void testRegisterizerSymbolicIndices() { * } */ - registerize(stmt); + stmt = registerize(stmt); /* - * int A_ = 0; + * int A_1 = 0; * for (int x = 0; x < 10; x++) { - * A_ = x + A_; + * A_1 = x + A_1; * } - * A[i] = A_; + * A[i] = A_1; */ std::ostringstream oss; @@ -301,50 +549,19 @@ void testRegisterizerSymbolicIndices() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = 0; +# CHECK: int A_1 = 0; # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A[ -# CHECK: A_ = -# CHECK: A[i] = A_;)IR"; +# CHECK: A_1 = +# CHECK: A[i] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } -// Will not registerize if a variable usage of the sclar may overlap the target -// scalar. -// TODO: we can support this by writing back to the buffer before the variable -// access, but we'd need temporal analysis of dependencies which we don't have -// yet. Will have to fix soon though. -void testRegisterizerEarlyStop() { - KernelScope kernel_scope; - Buffer a(BufHandle("A", {1}, kInt)); - VarHandle x("x", kInt); - Stmt* stmt = Block::make( - {Store::make(a, {0}, 0, 1), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)})), - For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}, 1), 1))}); - - std::ostringstream before; - before << *stmt; - - // No change. - registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - // Can registerize accesses dependent on multiple loop vars. -void testRegisterizerMultiLoop() { +TEST(Registerizer, RegisterizerMultiLoop) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {1}, kInt)); + BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); Stmt* stmt = Block::make( @@ -372,16 +589,16 @@ void testRegisterizerMultiLoop() { * } */ - registerize(stmt); + stmt = registerize(stmt); /* - * int A_ = 0; + * int A_1 = 0; * for (int x = 0; x < 10; x++) { * for (int y = 0; y < 10; y++) { - * A_ = x * y + y * A_l + * A_1 = x * y + y * A_1; * } * } - * A[0] = A_; + * A[0] = A_1; */ std::ostringstream oss; @@ -389,20 +606,20 @@ void testRegisterizerMultiLoop() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = 0; +# CHECK: int A_1 = 0; # CHECK: for (int x = 0; x < 10; x++) # CHECK: for (int y = 0; y < 10; y++) # CHECK-NOT: A[ -# CHECK: A_ = -# CHECK: A[0] = A_;)IR"; +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } // Can registerize correctly if scalars already exist in the program. -void testRegisterizerRepeated() { +TEST(Registerizer, RegisterizerRepeated) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {2}, kInt)); + BufHandle a("A", {2}, kInt); VarHandle x("x", kInt); Stmt* stmt = Block::make({ Store::make(a, {0}, 0, 1), @@ -418,23 +635,24 @@ void testRegisterizerRepeated() { // Registerize manually to make sure we only replace a single target. { - RegisterizerAnalysis analysis; + registerizer::RegisterizerAnalysis analysis; stmt->accept(&analysis); auto candidates = analysis.getCandidates(); ASSERT_EQ(candidates.size(), 2); - RegisterizerReplacer replacer(candidates.front()); + candidates.pop_back(); + registerizer::RegisterizerReplacer replacer(candidates); stmt = stmt->accept_mutator(&replacer); } // Re-analyze and replace the second target. { - RegisterizerAnalysis analysis; + registerizer::RegisterizerAnalysis analysis; stmt->accept(&analysis); auto candidates = analysis.getCandidates(); ASSERT_EQ(candidates.size(), 1); - RegisterizerReplacer replacer(candidates.front()); + registerizer::RegisterizerReplacer replacer(candidates); stmt = stmt->accept_mutator(&replacer); } @@ -443,22 +661,22 @@ void testRegisterizerRepeated() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = 0; -# CHECK: int A__1 = 0; +# CHECK: int A_1 = 0; +# CHECK: int A_1_1 = 0; # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A[ -# CHECK: A_ = -# CHECK: A__1 = -# CHECK: A[1] = A__1 -# CHECK: A[0] = A_;)IR"; +# CHECK: A_1 = +# CHECK: A_1_1 = +# CHECK: A[1] = A_1_1; +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } -// Can registerize rthe load of A. -void testRegisterizerNoLoads() { +// Can registerize the load of A. +TEST(Registerizer, RegisterizerNoLoads) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {1}, kInt)); + BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); Stmt* stmt = Block::make( {Store::make(a, {0}, 0, 1), @@ -472,14 +690,14 @@ void testRegisterizerNoLoads() { * } */ - registerize(stmt); + stmt = registerize(stmt); /* - * int A_ = 0; + * int A_1 = 0; * for (int x = 0; x < 10; x++) { - * A_ = x + 1; + * A_1 = x + 1; * } - * A[0] = A_; + * A[0] = A_1; */ std::ostringstream oss; @@ -487,29 +705,29 @@ void testRegisterizerNoLoads() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = 0; +# CHECK: int A_1 = 0; # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A[ -# CHECK: A_ = -# CHECK: A[0] = A_;)IR"; +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } // Can registerize the load of A but not the store of B. -void testRegisterizerNoRepeatedStores() { +TEST(Registerizer, RegisterizerNoRepeatedStores) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {1}, kInt)); - Buffer b(BufHandle("B", {10}, kInt)); + BufHandle a("A", {1}, kInt); + BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); - Stmt* stmt = - Block::make({Store::make(a, {0}, 0, 1), - For::make( - x, - 0, - 10, - Block::make({Store::make( - b, {x}, Add::make(Load::make(a, {0}, 1), x), 1)}))}); + Stmt* stmt = Block::make( + {Store::make(a, {0}, 0, 1), + For::make( + x, + 0, + 10, + Block::make({Store::make( + b, {x}, Add::make(Load::make(a, {0}, 1), x), 1)}))}); /* * A[0] = 0; @@ -518,17 +736,17 @@ void testRegisterizerNoRepeatedStores() { * } */ - registerize(stmt); + stmt = registerize(stmt); // TODO: its unnecessary to reorder the initializer of A[0], but it's not // actually worse so lets not worry for now. /* - * int A_ = 0; + * int A_1 = 0; * for (int x = 0; x < 10; x++) { - * B[x] = x + A_; + * B[x] = x + A_1; * } - * A[0] = A_; + * A[0] = A_1; */ std::ostringstream oss; @@ -536,19 +754,19 @@ void testRegisterizerNoRepeatedStores() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = 0; +# CHECK: int A_1 = 0; # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A_ # CHECK: B[x] = -# CHECK: A[0] = A_;)IR"; +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } // Won't registerize if there are multiple accesses which may overlap. -void testRegisterizerMultiVarOverlap() { +TEST(Registerizer, RegisterizerMultiVarOverlap) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {2}, kInt)); + BufHandle a("A", {2}, kInt); VarHandle x("x", kInt); Stmt* stmt = Block::make({ Store::make(a, {0}, 0, 1), @@ -567,7 +785,7 @@ void testRegisterizerMultiVarOverlap() { before << *stmt; // No change. - registerize(stmt); + stmt = registerize(stmt); std::ostringstream after; after << *stmt; @@ -575,15 +793,15 @@ void testRegisterizerMultiVarOverlap() { ASSERT_EQ(before.str(), after.str()); } -void testRegisterizerAllocs() { +TEST(Registerizer, RegisterizerAllocs) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {2}, kInt)); - Buffer b(BufHandle("B", {1}, kInt)); - Buffer c(BufHandle("C", {1}, kInt)); + BufHandle a("A", {2}, kInt); + BufHandle b("B", {1}, kInt); + BufHandle c("C", {1}, kInt); VarHandle x("x", kInt); - VarHandle b_(b.data()->base_handle()); + VarHandle b_(b.node()->base_handle()); Stmt* stmt = Block::make( {Allocate::make(b_, kInt, {Load::make(c, {0}, 1)}), @@ -609,19 +827,19 @@ void testRegisterizerAllocs() { * Free(B); */ - registerize(stmt); + stmt = registerize(stmt); /* - * int C_ = C[0]; + * int C_1 = C[0]; * Allocate(B, int, {C_}); - * int A_ = C_; - * int B_ = 0; + * int A_1 = C_1; + * int B_1 = 0; * for (int x = 0; x < 10; x++) { - * B_ = B_ + x; - * A_ = C_; + * B_1 = B_1 + x; + * A_1 = C_1; * } - * B[0] = B_; - * A[0] = A_; + * B[0] = B_1; + * A[0] = A_1; * Free(B); */ @@ -630,23 +848,23 @@ void testRegisterizerAllocs() { const std::string& verification_pattern = R"IR( -# CHECK: int C_ = C[0]; +# CHECK: int C_1 = C[0]; # CHECK: Allocate(B -# CHECK: int A_ = C_; -# CHECK: int B_ = 0; +# CHECK: int A_1 = C_1; +# CHECK: int B_1 = 0; # CHECK: for (int x = 0; x < 10; x++) -# CHECK: B_ = -# CHECK: A_ = C_ -# CHECK: B[0] = B_; -# CHECK: A[0] = A_; +# CHECK: B_1 = +# CHECK: A_1 = C_ +# CHECK: B[0] = B_1; +# CHECK: A[0] = A_1; # CHECK: Free(B)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } -void testRegisterizerNoInitializer() { +TEST(Registerizer, RegisterizerNoInitializer) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {1}, kInt)); + BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); Stmt* stmt = Block::make({For::make( x, @@ -661,14 +879,14 @@ void testRegisterizerNoInitializer() { * } */ - registerize(stmt); + stmt = registerize(stmt); /* - * int A_ = A[0]; + * int A_1 = A[0]; * for (int x = 0; x < 10; x++) { - * A_ = x + A_; + * A_1 = x + A_1; * } - * A[0] = A_; + * A[0] = A_1; */ std::ostringstream oss; @@ -676,80 +894,110 @@ void testRegisterizerNoInitializer() { const std::string& verification_pattern = R"IR( -# CHECK: int A_ = A[0]; +# CHECK: int A_1 = A[0]; # CHECK: for (int x = 0; x < 10; x++) # CHECK-NOT: A[ -# CHECK: A_ = -# CHECK: A[0] = A_;)IR"; +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } -void testRegisterizerLoadThenStore() { +TEST(Registerizer, RegisterizerNoInitializerLoopVar) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {1}, kInt)); - Buffer b(BufHandle("B", {1}, kInt)); + BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); Stmt* stmt = Block::make({For::make( x, 0, 10, - Block::make({Store::make(b, {0}, Add::make(Load::make(a, {0}, 1), x), 1), - Store::make(a, {0}, Load::make(b, {0}, 1), 1)}))}); + Block::make( + {Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), x), 1)}))}); /* * for (int x = 0; x < 10; x++) { - * B[0] = (A[0]) + x; - * A[0] = B[0]; + * A[x] = (A[x]) + x; * } */ - registerize(stmt); - - /* - * int A_ = A[0]; - * int B_ = B[0]; - * for (int x = 0; x < 10; x++) { - * B_ = x + A_; - * A_ = B_; - * } - * B[0] = B_; - * A[0] = A_; - */ + std::ostringstream before; + before << *stmt; - std::ostringstream oss; - oss << *stmt; + // No change. + stmt = registerize(stmt); - const std::string& verification_pattern = - R"IR( -# CHECK: int A_ = A[0]; -# CHECK: int B_ = B[0]; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: B[ -# CHECK: B_ = -# CHECK-NOT: A[ -# CHECK: A_ = B_ -# CHECK: B[0] = B_ -# CHECK: A[0] = A_;)IR"; + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +TEST(Registerizer, RegisterizerLoadThenStore) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make({For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {0}, Add::make(Load::make(a, {0}, 1), x), 1), + Store::make(a, {0}, Load::make(b, {0}, 1), 1)}))}); + + /* + * for (int x = 0; x < 10; x++) { + * B[0] = (A[0]) + x; + * A[0] = B[0]; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[0]; + * int B_1 = B[0]; + * for (int x = 0; x < 10; x++) { + * B_1 = x + A_1; + * A_1 = B_1; + * } + * B[0] = B_1; + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[0]; +# CHECK: int B_1 = B[0]; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: B[ +# CHECK: B_1 = +# CHECK-NOT: A[ +# CHECK: A_1 = B_ +# CHECK: B[0] = B_ +# CHECK: A[0] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } -void testRegisterizerParallelized() { +TEST(Registerizer, RegisterizerParallelized) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {1}, kInt)); + BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); LoopOptions loopOpts; loopOpts.set_gpu_block_index(0); - Stmt* stmt = - Block::make({Store::make(a, {0}, 0, 1), - For::make( - x, - 0, - 10, - Block::make({Store::make( - a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)}), - loopOpts)}); + Stmt* stmt = Block::make( + {Store::make(a, {0}, 0, 1), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)}), + loopOpts)}); /* * A[0] = 0; @@ -763,49 +1011,345 @@ void testRegisterizerParallelized() { "Registerization must occur after parallelism flattening"); } -void testRegisterizerConditions() { +// Should be able to registerize this since the scalar would exist before the +// branch. +TEST(Registerizer, RegisterizerConditionAfter) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {5}, kInt)); + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); VarHandle x("x", kInt); - Stmt* stmt = Block::make({For::make( - x, - 0, - 10, - Block::make({ - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make( - a, - {x}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}, 1), x), - Add::make(Load::make(a, {x - 5}, 1), x)), - 1), - Store::make( - a, - {x - 5}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}, 1), x), - Add::make(Load::make(a, {x - 5}, 1), x)), - 1)), - }))}); + + Stmt* stmt = Block::make( + {Store::make(a, {x}, Load::make(b, {x}, 1), 1), + Store::make(c, {x}, Load::make(a, {x}, 1), 1), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + nullptr)}); + + /* + * A[x] = B[x]; + * C[x] = A[x]; + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = B[x]; + * C[x] = A_1; + * if (x<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = B[x]; +# CHECK: C[x] = A_1; +# CHECK: if ( +# CHECK: A_1 = A_1 + 1; +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Should be able to registerize this since the scalar exists in the same form +// after the branch and there is no overlap. +TEST(Registerizer, RegisterizerConditionBefore) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + nullptr), + Store::make(a, {x}, Load::make(b, {x}, 1), 1), + Store::make(c, {x}, Load::make(a, {x}, 1), 1)}); + + /* + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + * A[x] = B[x]; + * C[x] = A[x]; + */ + + stmt = registerize(stmt); + + /* + * int A_ 1 = A[x]; + * if (x<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A_1 = B[x]; + * C[x] = A_1; + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: if ( +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A_1 = B[x]; +# CHECK: C[x] = A_1; +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Should be able to registerize this as the combination of the two above rules. +TEST(Registerizer, RegisterizerConditionInside) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make( + {Store::make(a, {x}, Load::make(b, {x}, 1), 1), + Store::make(c, {x}, Load::make(a, {x}, 1), 1), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + nullptr), + Store::make(b, {x}, Load::make(a, {x}, 1), 1), + Store::make(a, {x}, Load::make(c, {x}, 1), 1)}); + + /* + * A[x] = B[x]; + * C[x] = A[x]; + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + * B[x] = A[x]; + * A[x] = C[x]; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = B[x]; + * C[x] = A_1; + * if (x<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * B[x] = A_1; + * A_1 = C[x]; + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = B[x]; +# CHECK: C[x] = A_1; +# CHECK: if ( +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: B[x] = A_1; +# CHECK: A_1 = C[x]; +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// An example where an access is cut by an overlapping access inside a +// condition, and both sides are large enough to be registerized but cannot be +// because there is no safe place to put the initializer or finalizer. +TEST(Registerizer, RegisterizerConditionInsideOverlap1) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + Stmt* stmt = Block::make( + {Store::make(a, {x}, Load::make(b, {x}, 1), 1), + Store::make(c, {x}, Load::make(a, {x}, 1), 1), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({ + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + Store::make(a, {0}, 3, 1), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + }), + nullptr), + Store::make(b, {x}, Load::make(a, {x}, 1), 1), + Store::make(a, {x}, Load::make(c, {x}, 1), 1)}); + + /* + * A[x] = B[x]; + * C[x] = A[x]; + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * A[0] = 3; + * A[x] = (A[x]) + 1; + * } + * B[x] = A[x]; + * A[x] = C[x]; + */ + + // The A[0] store overlaps, A[x] cutting the region that can be registerized + // into two groups. + // Each group has 2 loads and 2 stores however, so we could registerize it, + // but the first group would need to be finalized inside the condition block, + // the second would need to be initialized inside the condition block. There's + // no safe place to put these that's visible to the other uses in the group + // and so neither registerization is possible. std::ostringstream before; before << *stmt; - /* for (int x = 0; x < 10; x++) { - * if (x<5 ? 1 : 0) { - * A[x] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x); - * } else { - * A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x); - * } + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Same as the above, but the access group before the condition (and after the +// condition) are large enough to be registerized without needing the access +// from the loop. Registerization occurs but does not include any accesses in +// the condition, and the first group must be finalized before the Cond, the +// second initialized after it. +TEST(Registerizer, RegisterizerConditionInsideOverlap2) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + Stmt* stmt = Block::make( + {Store::make(a, {x}, Load::make(b, {x}, 1), 1), + Store::make(a, {x}, Load::make(b, {x + 1}, 1), 1), + Store::make(c, {x}, Load::make(a, {x}, 1), 1), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({ + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + Store::make(a, {0}, 3, 1), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + }), + nullptr), + Store::make(b, {x}, Load::make(a, {x}, 1), 1), + Store::make(b, {x + 1}, Load::make(a, {x}, 1), 1), + Store::make(a, {x}, Load::make(c, {x}, 1), 1)}); + + /* + * A[x] = B[x]; + * A[x] = B[x + 1]; + * C[x] = A[x]; + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * A[0] = 3; + * A[x] = (A[x]) + 1; + * } + * B[x] = A[x]; + * B[x + 1] = A[x]; + * A[x] = C[x]; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = B[x]; // A_1 initializer + * A_1 = B[x + 1]; // + * C[x] = A_1; // + * A[x] = A_1; // A_1 finalizer + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * A[0] = 3; + * A[x] = (A[x]) + 1; + * } + * int A_2 = A[x]; // A_2 initialier + * B[x] = A_2; // + * B[x + 1] = A_2; // + * A_2 = C[x]; // + * A[x] = A_2; // A_2 finalizer + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = B[x]; +# CHECK: A_1 = B[x + 1]; +# CHECK: C[x] = A_1; +# CHECK: A[x] = A_1; +# CHECK: if ( +# CHECK-NOT: A_1 = A_1 + 1; +# CHECK: A[x] = (A[x] +# CHECK: A[0] = +# CHECK: A[x] = (A[x] +# CHECK: } +# CHECK: int A_2 = A[x]; +# CHECK: B[x] = A_2; +# CHECK: B[x + 1] = A_2; +# CHECK: A_2 = C[x]; +# CHECK: A[x] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// When accesses are within conditional blocks they are not visible to the wider +// program, because we don't know if the branch would be taken and if it isn't +// the accesses in it don't need to be valid (think size checks on the index). +// In this case the accesses cannot be registerized. +TEST(Registerizer, RegisterizerConditionHidden) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + nullptr), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kGT), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + * if (x>5 ? 1 : 0) { + * A[x] = (A[x]) + 1; * } */ + std::ostringstream before; + before << *stmt; + // No change. - registerize(stmt); + stmt = registerize(stmt); std::ostringstream after; after << *stmt; @@ -813,5 +1357,2461 @@ void testRegisterizerConditions() { ASSERT_EQ(before.str(), after.str()); } +// But... if the same access is found in a non conditional scope, that means +// that that access is valid in the higher scope (or at least if its not it's +// the user's fault). It "unhides" the conditional accesses, allowing +// registerization to occur. +TEST(Registerizer, RegisterizerConditionUnhidden) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + nullptr), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kGT), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + * A[x] = (A[x]) + 1; <-- this is doing the unhiding. + * if (x>5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * if (x<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A_1 = A_1 + 1; + * if (x>5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: if (x<5 +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A_1 = A_1 + 1; +# CHECK: if (x>5 +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize a load that occurs in the condition of a Cond. +TEST(Registerizer, RegisterizerCondCondition) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make( + {Store::make(a, {x}, Load::make(b, {x}, 1), 1), + Store::make(c, {x}, Load::make(a, {x}, 1), 1), + Cond::make( + CompareSelect::make( + Load::make(a, {x}, 1), 5, CompareSelectOperation::kLT), + Store::make(c, {x}, Add::make(Load::make(c, {x}, 1), 1), 1), + nullptr)}); + + /* + * A[x] = B[x]; + * C[x] = A[x]; + * if ((A[x])<5 ? 1 : 0) { + * C[x] = (C[x]) + 1; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = B[x]; + * int C_1 = A_1; + * if (A_1<5 ? 1 : 0) { + * C_1 = C_1 + 1; + * } + * C[x] = C_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = B[x]; +# CHECK: int C_1 = A_1; +# CHECK: if (A_1<5 +# CHECK: C_1 = C_1 + 1; +# CHECK: C[x] = C_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Appearing in the condition of a Cond makes it visible to the enclosing scope, +// and so we can registerize internal usages. +TEST(Registerizer, RegisterizerCondConditionUnhidden) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make({Cond::make( + CompareSelect::make( + Load::make(a, {x}, 1), 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 1), 1), + Store::make(a, {x}, Add::make(Load::make(a, {x}, 1), 10), 1))}); + + /* + * if ((A[x])<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } else { + * A[x] = (A[x]) + 10; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * if (A_1<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } else { + * A_1 = A_1 + 10; + * } + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: if (A_1<5 +# CHECK: A_1 = A_1 + 1; +# CHECK: } else { +# CHECK: A_1 = A_1 + 10; +# CHECK: } +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Conditional hiding also works for IfThenElse exprs. +TEST(Registerizer, RegisterizerIfThenElseHidden) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + Stmt* stmt = Block::make( + {Store::make( + b, + {y}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}, 1), 1), + Add::make(Load::make(a, {x + 1}, 1), 2)), + 1), + Store::make( + b, + {y + 1}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}, 1), 1), + Add::make(Load::make(a, {x + 1}, 1), 2)), + 1)}); + + /* + * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); + * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Conditional unhiding also works for IfThenElse exprs. +TEST(Registerizer, RegisterizerIfThenElseUnhidden) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + Stmt* stmt = Block::make({ + Store::make(a, {x}, 0, 1), + Store::make( + b, + {y}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}, 1), 1), + Add::make(Load::make(a, {x + 1}, 1), 2)), + 1), + Store::make( + b, + {y + 1}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}, 1), 1), + Add::make(Load::make(a, {x + 1}, 1), 2)), + 1), + }); + + /* + * A[x] = 0; + * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); + * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); + * B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); +# CHECK: B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Nested IfThenElse exprs can't promote to higher level scopes. +TEST(Registerizer, RegisterizerIfThenElseNested) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + BufHandle d("D", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make({Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + IfThenElse::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Load::make(d, {x}, 1), + Load::make(b, {x}, 1)), + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kEQ), + Load::make(c, {x}, 1), + Load::make(d, {x}, 1))), + 1)}); + + /* + * A[x] = IfThenElse(x<3 ? 1 : 0, + * IfThenElse(x==2 ? 1 : 0, D[x], B[x]), + * IfThenElse(x==5 ? 1 : 0, C[x], D[x])); + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Cannot registerize an access completely contained within an IfThenElse +// branch, since it is not a Stmt and cannot hold variable definitions. We need +// to check that we don't promote the initializer/finalizer to the enclosing +// Block. +TEST(Registerizer, RegisterizerIfThenElseInternal) { + KernelScope kernel_scope; + // Making these floats so they don't get simplified to a single access. + BufHandle a("A", {5}, kFloat); + BufHandle b("B", {5}, kFloat); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make({Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + Add::make(Load::make(b, {x}, 1), Load::make(b, {x}, 1)), + Load::make(b, {x}, 1)), + 1)}); + + /* + * A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]); + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); + + // If this was a Cond instead of an IfThenElse then we could registerize the + // two accesses to B[x] in the True branch. + + // Actually lets verify that. + + stmt = Block::make({Cond::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + Store::make( + a, {x}, Add::make(Load::make(b, {x}, 1), Load::make(b, {x}, 1)), 1), + Store::make(a, {x}, Load::make(b, {x}, 1), 1))}); + + /* + * if (x<3 ? 1 : 0) { + * A[x] = (B[x]) + (B[x]); + * } else { + * A[x] = B[x]; + * } + */ + + stmt = registerize(stmt); + + /* + * if (x<3 ? 1 : 0) { + * float B_1 = B[x]; + * A[x] = B_1 + B_1; + * } else { + * A[x] = B[x]; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: int +# CHECK-NOT: float +# CHECK: if (x<3 +# CHECK: float B_1 = +# CHECK: A[x] = B_1 + B_1 +# CHECK: } else { +# CHECK: A[x] = B[x] +# CHECK: } +# CHECK-NOT: A[x] +# CHECK-NOT: B[x])IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize a load that occurs in the condition of an IfThenElse; +TEST(Registerizer, RegisterizerIfThenElseCondition) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make( + {Store::make(a, {x}, Load::make(a, {x}, 1), 1), + Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make( + Load::make(a, {x}, 1), 5, CompareSelectOperation::kLT), + Load::make(b, {0}, 1), + Load::make(c, {0}, 1)), + 1)}); + + /* + * A[x] = A[x]; <---- just here so there are enough accesses to combine. + * A[x] = IfThenElse((A[x])<5 ? 1 : 0, B[0], C[0]); + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * A_1 = A_1; + * A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]); + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]); +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Appearing in the condition of a Cond makes it visible to the enclosing scope, +// and so we can registerize internal usages. +TEST(Registerizer, RegisterizerIfThenElseConditionUnhidden) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make({Store::make( + b, + {x}, + IfThenElse::make( + CompareSelect::make( + Load::make(a, {x}, 1), 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}, 1), 1), + Add::make(Load::make(a, {x}, 1), 10)), + 1)}); + + /* + * B[x] = IfThenElse((A[x])<5 ? 1 : 0, (A[x]) + 1, (A[x]) + 10); + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10); + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Cannot promote accesses internal to IfThenElse branches even if the enclosing +// scope if conditional. +TEST(Registerizer, RegisterizerConditionBranchOnly) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make({For::make( + x, + 0, + 10, + Block::make({ + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}, 1), x), + Add::make(Load::make(a, {x - 5}, 1), x)), + 1), + Store::make( + a, + {x - 5}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}, 1), x), + Add::make(Load::make(a, {x - 5}, 1), x)), + 1)), + }))}); + + std::ostringstream before; + before << *stmt; + + /* for (int x = 0; x < 10; x++) { + * if (x<5 ? 1 : 0) { + * A[x] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x); + * } else { + * A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x); + * } + * } + */ + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// We can registerize an IfThenElse that appears in the condition branch of a +// Cond. This is a weird but valid thing to do. +TEST(Registerizer, RegisterizerCondIfThenElse) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + Stmt* stmt = Block::make({Cond::make( + CompareSelect::make( + IfThenElse::make( + CompareSelect::make( + Load::make(a, {x}, 1), 5, CompareSelectOperation::kLT), + Load::make(a, {x}, 1), + Load::make(b, {x}, 1)), + x, + CompareSelectOperation::kEQ), + Store::make(c, {x}, Add::make(Load::make(c, {x}, 1), 1), 1), + nullptr)}); + + /* + * if ((IfThenElse((A[x])<5 ? 1 : 0, A[x], B[x]))==x ? 1 : 0) { + * C[x] = (C[x]) + 1; + * } + */ + + stmt = registerize(stmt); + + // access to A can be registerized, but not B or C + + /* + * int A_1 = A[x]; + * if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]))==x ? 1 : 0) { + * C[x] = (C[x]) + 1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x] +# CHECK: C[x] = (C[x]) + 1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize a conditional access in the RHS of a store unhidden by it's +// LHS, and hoist it out of a loop. +TEST(Registerizer, RegisterizerIfThenElseLoop) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + Stmt* stmt = For::make( + y, + 0, + 10, + Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + Load::make(a, {x}, 1), + Load::make(b, {y}, 1)), + 1)); + + /* + * for (int y = 0; y < 10; y++) { + * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], B[y]); + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * for (int y = 0; y < 10; y++) { + * A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]); + * } + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: for ( +# CHECK: A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]); +# CHECK: } +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Cannot registerize if the RHS overlaps the access creating visibility. +TEST(Registerizer, RegisterizerIfThenElseLoopCut) { + KernelScope kernel_scope; + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + Stmt* stmt = Block::make({For::make( + y, + 0, + 10, + Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + Load::make(a, {x}, 1), + Load::make(a, {y}, 1)), + 1))}); + + /* + * for (int y = 0; y < 10; y++) { + * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], A[y]); + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Simple case where an access is cut by an overlapping access later in the +// program, we can registerize up until the overlap. +TEST(Registerizer, RegisterizerPartialAfter) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0}, 0, 1), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)})), + For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}, 1), 1))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + x; + * } + * A[0] = A_1; + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for ( +# CHECK: A_1 = A_1 + x; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: for ( +# CHECK: A[x] = A[x - 1]; +# CHECK: } +# CHECK-NOT: A)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// We can registerize an access which overlaps a previous access, the +// initializer must be inserted after the previous access. +TEST(Registerizer, RegisterizerPartialBefore) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}, 1), 1)), + Store::make(a, {0}, 0, 1), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)}))}); + + /* + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + x; + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: int +# CHECK: for ( +# CHECK: A[x] = A[x - 1]; +# CHECK: } +# CHECK: int A_1 = 0; +# CHECK: for ( +# CHECK: A_1 = A_1 + x; +# CHECK: } +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// The combination of the previous two tests, an access is cut by an overlapping +// access in both directions. +TEST(Registerizer, RegisterizerPartialInside) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x1("x1", kInt); + VarHandle x2("x2", kInt); + VarHandle x3("x3", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0}, 2, 1), + For::make( + x1, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x1), 1)), + For::make( + x2, 1, 10, Store::make(a, {x2}, Load::make(a, {x2 - 1}, 1), 1)), + For::make( + x3, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x3), 1))}); + + /* + * A[0] = 2; + * for (int x1 = 0; x1 < 10; x1++) { + * A[0] = (A[0]) + x1; + * } + * for (int x2 = 1; x2 < 10; x2++) { + * A[x2] = A[x2 - 1]; + * } + * for (int x3 = 0; x3 < 10; x3++) { + * A[0] = (A[0]) + x3; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 2; + * for (int x1 = 0; x1 < 10; x1++) { + * A_1 = A_1 + x1; + * } + * A[0] = A_1; + * for (int x2 = 1; x2 < 10; x2++) { + * A[x2] = A[x2 - 1]; + * } + * int A_2 = A[0]; + * for (int x3 = 0; x3 < 10; x3++) { + * A_2 = A_2 + x3; + * } + * A[0] = A_2; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 2; +# CHECK: for ( +# CHECK: A_1 = A_1 + x1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: for ( +# CHECK: A[x2] = +# CHECK: } +# CHECK: int A_2 = A[0]; +# CHECK: for ( +# CHECK: A_2 = A_2 + x3; +# CHECK: } +# CHECK: A[0] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// An element could be registerized program wide but is cut by a conditional +// access, we should break this into two scalars and write back to the buffer +// before the condition. +TEST(Registerizer, RegisterizerPartialCondition) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0}, 2, 1), + For::make( + x, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Load::make(a, {x - 1}, 1), 1), + nullptr), + For::make( + x, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1))}); + + /* + * A[0] = 2; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + * if (x<5 ? 1 : 0) { + * A[x] = A[x - 1]; + * } + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 2; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + x; + * } + * A[0] = A_1; + * if (x<5 ? 1 : 0) { + * A[x] = A[x - 1]; + * } + * int A_2 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_2 = A_2 + x; + * } + * A[0] = A_2; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 2; +# CHECK: for ( +# CHECK: A_1 = A_1 + x; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: if ( +# CHECK: A[x] = +# CHECK: } +# CHECK: int A_2 = A[0]; +# CHECK: for ( +# CHECK: A_2 = A_2 + x; +# CHECK: } +# CHECK: A[0] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Tests case where an access is cut by an internal conditional access which +// itself is registerized. +TEST(Registerizer, RegisterizerPartialConditionInternalCut) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0}, 1, 1), + Store::make(a, {0}, 3, 1), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({Store::make(a, {x}, 1, 1), Store::make(a, {x}, 3, 1)}), + nullptr), + Store::make(a, {0}, 4, 1), + Store::make(a, {0}, 6, 1)}); + + /* + * A[0] = 1; + * A[0] = 3; + * if (x<5 ? 1 : 0) { + * A[x] = 1; + * A[x] = 3; + * } + * A[0] = 4; + * A[0] = 6; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 1; + * A_1 = 3; + * A[0] = A_1; + * if (x<5 ? 1 : 0) { + * int A_2 = 1; + * A_2 = 3; + * A[x] = A_2; + * } + * int A_3 = 4; + * A_3 = 6; + * A[0] = A_3; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 1; +# CHECK: A_1 = 3 +# CHECK: A[0] = A_1; +# CHECK: if ( +# CHECK: int A_2 = 1; +# CHECK: A_2 = 3; +# CHECK: A[x] = A_2; +# CHECK: } +# CHECK: int A_3 = 4; +# CHECK: A_3 = 6; +# CHECK: A[0] = A_3;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// First statment in condition closes outer access, but can be registerized with +// later statements. +TEST(Registerizer, RegisterizerPartialConditionInternalStart) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0}, 1, 1), + Store::make(a, {0}, 3, 1), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({Store::make(a, {x}, 1, 1), Store::make(a, {x}, 3, 1)}), + nullptr), + Store::make(a, {x}, 4, 1), + Store::make(a, {x}, 6, 1)}); + + /* + * A[0] = 1; + * A[0] = 3; + * if (x<5 ? 1 : 0) { + * A[x] = 1; + * A[x] = 3; + * } + * A[x] = 4; + * A[x] = 6; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 1; + * A_1 = 3; + * A[0] = A_1; + * int A_2 = A[x]; <--- must read from the input here. + * if (x<5 ? 1 : 0) { + * A_2 = 1; + * A_2 = 3; + * } + * A_2 = 4; + * A_2 = 6; + * A[x] = A_2; + */ + + // TODO: I suppose we could refactor with a conditional initializier? + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 1; +# CHECK: A_1 = 3 +# CHECK: A[0] = A_1; +# CHECK: int A_2 = A[x]; +# CHECK: if ( +# CHECK: A_2 = 1; +# CHECK: A_2 = 3; +# CHECK: } +# CHECK: A_2 = 4; +# CHECK: A_2 = 6; +# CHECK: A[x] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// An access cuts two open overlaps and creates four scalar variables. +TEST(Registerizer, RegisterizerPartialOverlapsTwo) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {1}, Load::make(a, {0}, 1), 1), + Store::make(a, {0}, Load::make(a, {1}, 1), 1), + Store::make(a, {0}, Load::make(a, {1}, 1), 1), + For::make(x, 1, 10, Store::make(a, {x}, x, 1)), + Store::make(a, {1}, Load::make(a, {0}, 1), 1), + Store::make(a, {0}, Load::make(a, {1}, 1), 1), + Store::make(a, {0}, Load::make(a, {1}, 1), 1)}); + + /* + * A[1] = A[0]; + * A[0] = A[1]; + * A[0] = A[1]; + * for (int x = 1; x < 10; x++) { + * A[x] = x; + * } + * A[1] = A[0]; + * A[0] = A[1]; + * A[0] = A[1]; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[0]; + * int A_2 = A_1; + * A_1 = A_2; + * A_1 = A_2; + * A[1] = A_2; + * A[0] = A_1; + * for (int x = 1; x < 10; x++) { + * A[x] = x; + * } + * int A_3 = A[0]; + * int A_4 = A_3; + * A_3 = A_4; + * A_3 = A_4; + * A[1] = A_4; + * A[0] = A_3; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[0]; +# CHECK: int A_2 = A_1; +# CHECK: A_1 = A_2; +# CHECK: A_1 = A_2; +# CHECK: A[1] = A_2; +# CHECK: A[0] = A_1; +# CHECK: for ( +# CHECK: A[x] = x; +# CHECK: } +# CHECK: int A_3 = A[0]; +# CHECK: int A_4 = A_3; +# CHECK: A_3 = A_4; +# CHECK: A_3 = A_4; +# CHECK: A[1] = A_4; +# CHECK: A[0] = A_3;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Nested blocks will automatically be flattened and do not provent +// registerization of enclosed accesses. +TEST(Registerizer, RegisterizerNestedBlocks) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 2), 1)}), + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 3), 1), + Block::make({Store::make( + a, {0}, Add::make(Load::make(a, {0}, 1), 4), 1)})})}); + + /* + * A[0] = (A[0]) + 1; + * { + * A[0] = (A[0]) + 2; + * } + * { + * A[0] = (A[0]) + 3; + * { + * A[0] = (A[0]) + 4; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[0]; + * A_1 = A_1 + 1; + * A_1 = A_1 + 2; + * A_1 = A_1 + 3; + * A_1 = A_1 + 4; + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[0]; +# CHECK: A_1 = A_1 + 1; +# CHECK: A_1 = A_1 + 2; +# CHECK: A_1 = A_1 + 3; +# CHECK: A_1 = A_1 + 4; +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// The access can be registerized internally to a condition, but must ensure +// that both initializer and finalizer are within the same condition. +TEST(Registerizer, RegisterizerNestedConditions) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make({Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr)}), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * if (x==2 ? 1 : 0) { + * + * A[0] = (A[0]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x<5 ? 1 : 0) { + * int A_1 = A[0]; + * A_1 = A_1 + 1; + * if (x==2 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A[0] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x<5 +# CHECK: int A_1 = A[0]; +# CHECK: A_1 = A_1 + 1; +# CHECK: if (x==2 +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// If an access exists outside the scope of the condition then we can lift +// nested conditional usages into the same scalar. +TEST(Registerizer, RegisterizerNestedConditionsUnhidden) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make( + {Store::make(a, {1}, 1, 1), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr)}), + nullptr)}); + + /* + * A[0] = (A[0]) + 1; + * if (x<5 ? 1 : 0) { + * A[1] = 1; + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[0]; + * A_1 = A_1 + 1; + * if (x<5 ? 1 : 0) { + * A[1] = 1; + * if (x==2 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[0]; +# CHECK: A_1 = A_1 + 1; +# CHECK: if (x<5 +# CHECK: A[1] = 1; +# CHECK: if (x==2 +# CHECK: A_1 = A_1 + 1; +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Registerizer, RegisterizerNestedConditionsHiddenFirst) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr)}), + nullptr)}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * if (x<5 ? 1 : 0) { + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); + + stmt = registerize(stmt); +} + +TEST(Registerizer, RegisterizerNestedConditionsHiddenSecond) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr)}), + nullptr), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * } + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); + + stmt = registerize(stmt); +} + +// If an access is cut by another access internal to a condition block, it still +// cuts the access. +TEST(Registerizer, RegisterizerNestedConditionsCut) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make( + {Store::make(a, {x}, 1, 1), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr)}), + nullptr)}); + + /* + * A[0] = (A[0]) + 1; + * if (x<5 ? 1 : 0) { + * A[x] = 1; + * if (x==2 ? 1 : 0) { + * + * A[0] = (A[0]) + 1; + * } + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +TEST(Registerizer, RegisterizerNestedConditionLoopHidden) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {x}, 0, 1), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + nullptr)}))}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * for (int x = 0; x < 10; x++) { + * B[x] = 0; <-- this is only here to prevent Loop/Cond reordering. + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Three loops and four element regions, three of which should be registerized +// at different levels of the IR. +TEST(Registerizer, RegisterizerNestedConditionThreeDeep) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {4}, 0, 1), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kGT), + Cond::make( + CompareSelect::make(x, 3, CompareSelectOperation::kGT), + Block::make({ + Cond::make( + CompareSelect::make(x, 4, CompareSelectOperation::kGT), + Block::make({ + Store::make( + a, {1}, Add::make(Load::make(a, {1}, 1), 1), 1), + Store::make( + a, {2}, Add::make(Load::make(a, {2}, 1), 1), 1), + Store::make( + a, {3}, Add::make(Load::make(a, {3}, 1), 1), 1), + Store::make( + a, {4}, Add::make(Load::make(a, {4}, 1), 1), 1), + Store::make( + a, {1}, Add::make(Load::make(a, {1}, 1), 1), 1), + }), + nullptr), + Store::make(a, {2}, Add::make(Load::make(a, {2}, 1), 1), 1), + }), + nullptr), + nullptr)}); + + /* + * A[4] = 0; + * if (x>2 ? 1 : 0) { + * if (x>3 ? 1 : 0) { + * if (x>4 ? 1 : 0) { + * A[1] = (A[1]) + 1; + * A[2] = (A[2]) + 1; + * A[3] = (A[3]) + 1; + * A[4] = (A[4]) + 1; + * A[1] = (A[1]) + 1; + * } + * A[2] = (A[2]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * if (x>2 ? 1 : 0) { + * if (x>3 ? 1 : 0) { + * int A_3 = A[2]; + * if (x>4 ? 1 : 0) { + * int A_2 = A[1]; + * A_2 = A_2 + 1; + * A_3 = A_3 + 1; + * A[3] = (A[3]) + 1; + * A_1 = A_1 + 1; + * A_2 = A_2 + 1; + * A[1] = A_2; + * } + * A_3 = A_3 + 1; + * A[2] = A_3; + * } + * } + * A[4] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: if (x>2 ? 1 : 0) { +# CHECK: if (x>3 ? 1 : 0) { +# CHECK: int A_3 = A[2]; +# CHECK: if (x>4 ? 1 : 0) { +# CHECK: int A_2 = A[1]; +# CHECK: A_2 = A_2 + 1; +# CHECK: A_3 = A_3 + 1; +# CHECK: A[3] = (A[3]) + 1; +# CHECK: A_1 = A_1 + 1; +# CHECK: A_2 = A_2 + 1; +# CHECK: A[1] = A_2; +# CHECK: } +# CHECK: A_3 = A_3 + 1; +# CHECK: A[2] = A_3; +# CHECK: } +# CHECK: } +# CHECK: A[4] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can replace a simple scalar access with a local variable even when that +// variable is an outer loop var. +TEST(Registerizer, RegisterizerNestedLoopSimple) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = Block::make({For::make( + y, + 0, + 10, + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {y}, Add::make(Load::make(a, {y}, 1), x), 1)})))}); + + /* + * for (int y = 0; y < 10; y++) { + * for (int x = 0; x < 10; x++) { + * A[y] = (A[y]) + x; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * for (int y = 0; y < 10; y++) { + * int A_1 = A[y]; + * for (int x = 0; x < 10; x++) { + * A_1 = x + A_1; + * } + * A[y] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int y +# CHECK: int A_1 = A[y]; +# CHECK: for (int x +# CHECK: A_1 = x + A_1; +# CHECK: } +# CHECK: A[y] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Test the positive case of the hiddenAccess split, where an internal +// conditional access can be hoisted up through a loop to match an existing +// access in a higher scope and the two can be registerized. +TEST(Registerizer, RegisterizerHiddenAccessYes) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Block::make( + {Store::make(a, {0}, 0, 1), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {x}, 0, 1), + Cond::make( + CompareSelect::make(x, 3, CompareSelectOperation::kEQ), + For::make( + y, + 0, + 10, + Store::make( + a, + {0}, + Add::make(Load::make(a, {0}, 1), 1), + 1)), + nullptr)}))}), + nullptr)}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = 0; + * if (x==3 ? 1 : 0) { + * for (int y = 0; y < 10; y++) { + * A[0] = (A[0]) + 1; + * } + * } + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x==2 ? 1 : 0) { + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = 0; + * if (x==3 ? 1 : 0) { + * for (int y = 0; y < 10; y++) { + * A_1 = A_1 + 1; + * } + * } + * } + * A[0] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x==2 +# CHECK: int A_1 = 0; +# CHECK: for (int x +# CHECK: B[x] = 0; +# CHECK: if (x==3 +# CHECK: for (int y +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Test the negative case of the hiddenAccess split, where the hoisted access is +// never unhidden at a higher scope and registerization occurs at the lower +// scope. +TEST(Registerizer, RegisterizerHiddenAccessNo) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Block::make({For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {x}, 0, 1), + Cond::make( + CompareSelect::make(x, 3, CompareSelectOperation::kEQ), + For::make( + y, + 0, + 10, + Store::make( + a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1)), + nullptr)}))}), + nullptr)}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = 0; + * if (x==3 ? 1 : 0) { + * for (int y = 0; y < 10; y++) { + * A[0] = (A[0]) + 1; + * } + * } + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x==2 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * B[x] = 0; + * if (x==3 ? 1 : 0) { + * int A_1 = A[0]; + * for (int y = 0; y < 10; y++) { + * A_1 = A_1 + 1; + * } + * A[0] = A_1; + * } + * } + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x==2 +# CHECK: for (int x +# CHECK: B[x] = 0; +# CHECK: if (x==3 +# CHECK: int A_1 = A[0]; +# CHECK: for (int y +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: } +# CHECK: } +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// In this case the conditional access must be hoisted by two loops, there are +// two accesses here one is unhidden and the other isnt. A[0] can be +// registerized but B[0] cannot. +TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Block::make( + {Store::make(a, {0}, 0, 1), + For::make( + x, + 0, + 10, + For::make( + y, + 0, + 10, + Block::make({Cond::make( + CompareSelect::make(y, 3, CompareSelectOperation::kEQ), + Block::make( + {Store::make( + a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1), + Store::make( + b, + {0}, + Add::make(Load::make(b, {0}, 1), 1), + 1)}), + nullptr)})))}), + nullptr)}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * if (y==3 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * B[0] = (B[0]) + 1; + * } + * } + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x==2 ? 1 : 0) { + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * if (y==3 ? 1 : 0) { + * A_1 = A_1 + 1; + * B[0] = (B[0]) + 1; + * } + * } + * } + * A[0] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x==2 +# CHECK: int A_1 = 0; +# CHECK: for (int x +# CHECK: for (int y +# CHECK: if (y==3 +# CHECK: A_1 = A_1 + 1; +# CHECK: B[0] = (B[0]) + 1; +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Accesses are registerized inside two conditions, but the immeidate parent is +// not a condition. +TEST(Registerizer, RegisterizerTwoConditionalLoops) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + For::make( + x, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1)), + nullptr), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kGT), + For::make( + x, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1)), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + 1; + * } + * } + * if (x>5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x<5 ? 1 : 0) { + * int A_1 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + 1; + * } + * A[0] = A_1; + * } + * if (x>5 ? 1 : 0) { + * int A_2 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_2 = A_2 + 1; + * } + * A[0] = A_2; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x<5 +# CHECK: int A_1 = A[0]; +# CHECK: for (int x +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: } +# CHECK: if (x>5 +# CHECK: int A_2 = A[0]; +# CHECK: for (int x +# CHECK: A_2 = A_2 + 1; +# CHECK: } +# CHECK: A[0] = A_2; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Accesses are registerized inside two conditions, cut in the middle. +TEST(Registerizer, RegisterizerTwoConditionalLoopsCut) { + KernelScope kernel_scope; + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + For::make( + x, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1)), + nullptr), + For::make(x, 0, 10, Store::make(a, {x}, 1, 1)), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kGT), + For::make( + x, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), 1), 1)), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + 1; + * } + * } + * for (int x = 0; x < 10; x++) { + * A[x] = 1; + * } + * if (x>5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x<5 ? 1 : 0) { + * int A_1 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + 1; + * } + * A[0] = A_1; + * } + * for (int x = 0; x < 10; x++) { + * A[x] = 1; + * } + * if (x>5 ? 1 : 0) { + * int A_2 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_2 = A_2 + 1; + * } + * A[0] = A_2; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x<5 +# CHECK: int A_1 = A[0]; +# CHECK: for (int x +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: } +# CHECK: for (int x +# CHECK: A[x] = 1; +# CHECK: if (x>5 +# CHECK: int A_2 = A[0]; +# CHECK: for (int x +# CHECK: A_2 = A_2 + 1; +# CHECK: } +# CHECK: A[0] = A_2; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// references a Let var in a local scope which cannot be hoisted out of the +// loop. +TEST(Registerizer, RegisterizerLoopLetVar) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = Block::make({For::make( + x, + 0, + 10, + Block::make( + {Let::make(y, 30), + Store::make(a, {y}, Add::make(x, Load::make(a, {y}, 1)), 1)}))}); + + /* + * for (int x = 0; x < 10; x++) { + * int y = 30; + * A[y] = x + (A[y]); + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// references a Let var in an outer scope that does not prevent hoisting the +// initializer. +TEST(Registerizer, RegisterizerLoopLetVarOuter) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = Block::make( + {Let::make(y, 30), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {y}, Add::make(x, Load::make(a, {y}, 1)), 1)}))}); + + /* + * int y = 30; + * for (int x = 0; x < 10; x++) { + * A[y] = x + (A[y]); + * } + */ + + stmt = registerize(stmt); + + /* + * int y = 30; + * int A_1 = A[y]; + * for (int x = 0; x < 10; x++) { + * A_1 = x + A_1; + * } + * A[y] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int y = 30; +# CHECK: int A_1 = A[y]; +# CHECK: for (int x +# CHECK: A_1 = x + A_1; +# CHECK: A[y] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Okay so the registerizer generally goes after index flattening, but just in +// case. Test multi index registerization. +TEST(Registerizer, RegisterizerMultiDim) { + KernelScope kernel_scope; + BufHandle a("A", {3, 4, 5}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0, 1, 2}, 0, 1), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0, 1, 2}, Add::make(Load::make(a, {0, 1, 2}, 1), x), 1)}))}); + + /* + * A[0, 1, 2] = 0; + * for (int x = 0; x < 10; x++) { + * A[0, 1, 2] = (A[0, 1, 2]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = x + A_1; + * } + * A[0, 1, 2] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A[0, 1, 2] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Wont registerize if only some dims match, but will still registerize distinct +// elements. +TEST(Registerizer, RegisterizerMultiDimPartial) { + KernelScope kernel_scope; + BufHandle a("A", {3, 4, 5}, kInt); + VarHandle x("x", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0, 1, 2}, 0, 1), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0, 2, 2}, Add::make(Load::make(a, {0, 1, 4}, 1), x), 1)}))}); + + /* + * A[0, 1, 2] = 0; + * for (int x = 0; x < 10; x++) { + * A[0, 2, 2] = (A[0, 1, 4]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * A[0, 1, 2] = 0; + * int A_1 = A[0, 1, 4]; + * int A_2 = A[0, 2, 2]; + * for (int x = 0; x < 10; x++) { + * A_2 = x + A_1; + * } + * A[0, 2, 2] = A_2; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: A[0, 1, 2] = 0; +# CHECK: int A_1 = A[0, 1, 4]; +# CHECK: int A_2 = A[0, 2, 2]; +# CHECK: for ( +# CHECK: A_2 = x + A_1; +# CHECK: A[0, 2, 2] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// If they could overlap across all dimensions we cannot registerize. +TEST(Registerizer, RegisterizerMultiDimOverlap) { + KernelScope kernel_scope; + BufHandle a("A", {3, 4, 5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0, 1, 2}, 0, 1), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 2}, 1), x), 1)}))}); + + /* + * A[0, 1, 2] = 0; + * for (int x = 0; x < 10; x++) { + * A[0, x, 2] = (A[y, 2, 2]) + x; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// But, if one dimension is known to be distinct they do not overlap. +TEST(Registerizer, RegisterizerMultiDimPartialOverlap) { + KernelScope kernel_scope; + BufHandle a("A", {3, 4, 5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + Stmt* stmt = Block::make( + {Store::make(a, {0, 1, 2}, 0, 1), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 4}, 1), x), 1)}))}); + + /* + * A[0, 1, 2] = 0; <---- 2nd dim overlaps with store. + * for (int x = 0; x < 10; x++) { + * A[0, x, 2] = (A[y, 2, 4]) + x; <---- 3rd dim has constant diff. + * } + */ + + stmt = registerize(stmt); + + /* + * A[0, 1, 2] = 0; + * int A_1 = A[y, 2, 4]; + * for (int x = 0; x < 10; x++) { + * A[0, x, 2] = A_1 + x; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: A[0, 1, 2] = 0; +# CHECK: int A_1 = A[y, 2, 4]; +# CHECK: for ( +# CHECK: A[0, x, 2] = A_1 + x; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// A 3D reduction with different input dimensionality. +TEST(Registerizer, RegisterizerMultiDim3DReduction1) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10, 10}, kInt); + BufHandle c("C", {10, 10, 10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + Stmt* stmt = For::make( + x, + 0, + 10, + For::make( + y, + 0, + 10, + For::make( + z, + 0, + 10, + Store::make( + c, + {x, y, z}, + Add::make( + Load::make(c, {x, y, z}, 1), + Mul::make( + Load::make(b, {x, y}, 1), Load::make(a, {x}, 1))), + 1)))); + + /* + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * for (int z = 0; z < 10; z++) { + * C[x, y, z] = (C[x, y, z]) + (B[x, y]) * (A[x]); + * } + * } + * } + */ + + // We can registerize the A and B access since they can be hoisted before + // hitting a dependent loop var. + + stmt = registerize(stmt); + + /* + * for (int x = 0; x < 10; x++) { + * int A_1 = A[x]; + * for (int y = 0; y < 10; y++) { + * int B_1 = B[x, y]; + * for (int z = 0; z < 10; z++) { + * C[x, y, z] = A_1 * B_1 + (C[x, y, z]); + * } + * } + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int x +# CHECK: int A_1 = A[x]; +# CHECK: for (int y +# CHECK: int B_1 = B[x, y]; +# CHECK: for (int z +# CHECK: C[x, y, z] = A_1 * B_1 + (C[x, y, z]); +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// A 3D reduction with the same smaller dimensionality using different loop +// vars. +TEST(Registerizer, RegisterizerMultiDim3DReduction2) { + KernelScope kernel_scope; + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + BufHandle c("C", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + Stmt* stmt = For::make( + x, + 0, + 10, + For::make( + y, + 0, + 10, + For::make( + z, + 0, + 10, + Store::make( + c, + {x}, + Add::make( + Load::make(c, {x}, 1), + Mul::make(Load::make(b, {y}, 1), Load::make(a, {x}, 1))), + 1)))); + + /* + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * for (int z = 0; z < 10; z++) { + * C[x] = (C[x]) + (B[y]) * (A[x]); + * } + * } + * } + */ + + // We can registerize all accesses, the A and C access can be hoisted to the + // outer loop since they depend only on it's loop var while the B can only be + // raised to the loop of y. + + stmt = registerize(stmt); + + /* + * for (int x = 0; x < 10; x++) { + * int A_1 = A[x]; + * int C_1 = C[x]; + * for (int y = 0; y < 10; y++) { + * int B_1 = B[y]; + * for (int z = 0; z < 10; z++) { + * C_1 = B_1 * A_1 + C_1; + * } + * } + * C[x] = C_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int x +# CHECK: int A_1 = A[x]; +# CHECK: int C_1 = C[x]; +# CHECK: for (int y +# CHECK: int B_1 = B[y]; +# CHECK: for (int z +# CHECK: C_1 = B_1 * A_1 + C_1; +# CHECK: } +# CHECK: } +# CHECK: C[x] = C_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp index f0185884fc58d..c34812e463825 100644 --- a/test/cpp/tensorexpr/test_simplify.cpp +++ b/test/cpp/tensorexpr/test_simplify.cpp @@ -1,9 +1,10 @@ -#include "test/cpp/tensorexpr/test_base.h" +#include +#include -#include "test/cpp/tensorexpr/test_utils.h" -#include "torch/csrc/jit/tensorexpr/hash_provider.h" -#include "torch/csrc/jit/tensorexpr/ir_simplifier.h" -#include "torch/csrc/jit/tensorexpr/loopnest.h" +#include +#include +#include +#include #include @@ -71,7 +72,7 @@ using SimpleIRExprEval = ExprEval; ASSERT_EQ(node_->op_type(), kRand); \ } -void testConstantFoldSimple() { +TEST(Simplify, ConstantFoldSimple) { KernelScope kernel_scope; ExprHandle a(2.0f); ExprHandle b(3.0f); @@ -85,7 +86,7 @@ void testConstantFoldSimple() { ASSERT_EQ(eval.value(), 5.f); } -void testConstantFoldTwoLayer() { +TEST(Simplify, ConstantFoldTwoLayer) { KernelScope kernel_scope; ExprHandle a(2.0f); ExprHandle b(3.0f); @@ -101,7 +102,7 @@ void testConstantFoldTwoLayer() { ASSERT_EQ(eval.value(), -4.f); } -void testConstantFoldShifts() { +TEST(Simplify, ConstantFoldShifts) { KernelScope kernel_scope; ExprHandle a(7); ExprHandle b(2); @@ -116,7 +117,7 @@ void testConstantFoldShifts() { ASSERT_EQ(eval.value(), 7 << (4 - 3)); } -void testConstantFoldBitwise() { +TEST(Simplify, ConstantFoldBitwise) { KernelScope kernel_scope; ExprHandle a(59); ExprHandle b(22); @@ -131,7 +132,7 @@ void testConstantFoldBitwise() { ASSERT_EQ(eval.value(), (59 ^ 22) & 101); } -void testConstantFoldMultiOp() { +TEST(Simplify, ConstantFoldMultiOp) { KernelScope kernel_scope; ExprHandle a(2.0f); ExprHandle b(3.0f); @@ -150,7 +151,7 @@ void testConstantFoldMultiOp() { ASSERT_EQ(eval.value(), ref.value()); } -void testConstantFoldMinMax() { +TEST(Simplify, ConstantFoldMinMax) { KernelScope kernel_scope; ExprHandle a(12.0f); ExprHandle b(15.0f); @@ -169,7 +170,7 @@ void testConstantFoldMinMax() { ASSERT_EQ(eval.value(), 15.f); } -void testConstantFoldIntrinsics() { +TEST(Simplify, ConstantFoldIntrinsics) { KernelScope kernel_scope; ExprHandle a(2.0f); ExprHandle b(3.0f); @@ -179,7 +180,7 @@ void testConstantFoldIntrinsics() { ExprHandle modHandle = Intrinsics::make(kFmod, c, sinHandle); ExprHandle logHandle = Intrinsics::make(kLog10, modHandle); ExprHandle rndHandle = Intrinsics::make(kRound, logHandle); - ExprHandle fn = Intrinsics::make(kFabs, rndHandle); + ExprHandle fn = Intrinsics::make(kAbs, rndHandle); ExprHandle newF = IRSimplifier::simplify(fn); ASSERT_NE(newF.AsNode(), nullptr); @@ -191,7 +192,7 @@ void testConstantFoldIntrinsics() { ASSERT_EQ(eval.value(), ref.value()); } -void testConstantFoldCastToBool() { +TEST(Simplify, ConstantFoldCastToBool) { KernelScope kernel_scope; ExprHandle f = Cast::make(kBool, IntImm::make(0)); ExprHandle newF = IRSimplifier::simplify(f); @@ -199,7 +200,7 @@ void testConstantFoldCastToBool() { ASSERT_EQ(eval.value(), false); } -void testConstantFoldWithVar() { +TEST(Simplify, ConstantFoldWithVar) { KernelScope kernel_scope; { VarHandle x("x", kInt); @@ -230,7 +231,7 @@ void testConstantFoldWithVar() { } } -void testConditionalSelectFoldSimple() { +TEST(Simplify, ConditionalSelectFoldSimple) { KernelScope kernel_scope; ExprHandle a(3.0f); ExprHandle b(4.0f); @@ -277,7 +278,7 @@ void testConditionalSelectFoldSimple() { } } -void testConditionalSelectFoldTwoLayer() { +TEST(Simplify, ConditionalSelectFoldTwoLayer) { KernelScope kernel_scope; ExprHandle a(3.0f); ExprHandle b(2.0f); @@ -325,7 +326,7 @@ void testConditionalSelectFoldTwoLayer() { } } -void testConditionalSelectFoldWithVar() { +TEST(Simplify, ConditionalSelectFoldWithVar) { KernelScope kernel_scope; VarHandle x("x", kFloat); ExprHandle f = x < 4.f; @@ -346,7 +347,7 @@ void testConditionalSelectFoldWithVar() { } } -void testUnFoldableExpr() { +TEST(Simplify, UnFoldableExpr) { KernelScope kernel_scope; VarHandle x("x", kFloat); VarHandle y("y", kFloat); @@ -364,7 +365,7 @@ void testUnFoldableExpr() { ASSERT_EQ(eval.value(), 9 + 10); } -void testHashSimple() { +TEST(Simplify, HashSimple) { KernelScope kernel_scope; VarHandle x("x", kFloat); ExprHandle a(2.0f); @@ -385,7 +386,7 @@ void testHashSimple() { ASSERT_NE(hash_a, hash_f); } -void testHashEquivalence() { +TEST(Simplify, HashEquivalence) { KernelScope kernel_scope; VarHandle x("x", kFloat); VarHandle y("y", kFloat); @@ -422,7 +423,7 @@ void testHashEquivalence() { ASSERT_NE(hasher.hash(f5.node()), (size_t)0); } -void testHashEquivalenceRand() { +TEST(Simplify, HashEquivalenceRand) { KernelScope kernel_scope; ExprHandle f = Intrinsics::make(kRand, kFloat) + Intrinsics::make(kRand, kInt); @@ -442,7 +443,7 @@ void testHashEquivalenceRand() { ASSERT_NE(hash_l, hash_r); } -void testHashEquivalenceAfterFolding() { +TEST(Simplify, HashEquivalenceAfterFolding) { KernelScope kernel_scope; VarHandle x("x", kFloat); ExprHandle a(2.0f); @@ -468,7 +469,7 @@ void testHashEquivalenceAfterFolding() { ASSERT_EQ(hash_l_n, hash_r_n); } -void testHashDifferenceTypes() { +TEST(Simplify, HashDifferenceTypes) { KernelScope kernel_scope; HashProvider hasher; @@ -501,12 +502,12 @@ void testHashDifferenceTypes() { ASSERT_EQ(hasher.hash(ff1.node()), hasher.hash(ff2.node())); } -void testHashLargeExpression() { +TEST(Simplify, HashLargeExpression) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kInt)); - Buffer b(BufHandle("B", {N}, kInt)); - Buffer c(BufHandle("C", {N}, kInt)); + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); auto memcpy_stmt = For::make( @@ -522,8 +523,8 @@ void testHashLargeExpression() { CompareSelectOperation::kEQ), mask)); - Buffer d(BufHandle("D", {1}, kInt)); - Buffer e(BufHandle("E", {1}, kInt)); + BufHandle d("D", {1}, kInt); + BufHandle e("E", {1}, kInt); auto store_ramp_stmt = Store::make( e, {Ramp::make(0, 1, 4)}, @@ -552,12 +553,12 @@ void testHashLargeExpression() { ASSERT_NE(hash_t, hash_f); } -void testHashForLoopOptions() { +TEST(Simplify, HashForLoopOptions) { KernelScope kernel_scope; constexpr int N = 1024; - Buffer a(BufHandle("A", {N}, kInt)); - Buffer b(BufHandle("B", {N}, kInt)); - Buffer c(BufHandle("C", {N}, kInt)); + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); auto for_stmt = For::make( @@ -596,7 +597,7 @@ void testHashForLoopOptions() { } /// (2 + x) + 4 => x + 6 -void testSimplifyAdd() { +TEST(Simplify, SimplifyAdd) { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -618,7 +619,7 @@ void testSimplifyAdd() { } /// (2 - x) - 4 => -2 - x -void testSimplifySub() { +TEST(Simplify, SimplifySub) { KernelScope kernel_scope; VarHandle x("x", kInt); ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4); @@ -634,21 +635,21 @@ void testSimplifySub() { ASSERT_EQ(rhs->name_hint(), "x"); } -/// 2 * (1 - x) - 4 => -2 * (x + 3) -void testSimplifyMultiLayer() { +/// 2 * (1 - x) - 4 => 2 * (-3 - x) +TEST(Simplify, SimplifyMultiLayer) { KernelScope kernel_scope; VarHandle x("x", kInt); ExprHandle body = ExprHandle(2) * ((ExprHandle(1) - x) - ExprHandle(4)); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), -2); - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_IMM_WITH_VAL(Int, add->rhs(), 3); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); + IS_IMM_WITH_VAL(Int, sub->lhs(), -3); + IS_VAR_WITH_NAME(sub->rhs(), "x"); } /// 2 * (3 * x) - (x * 4) => 2 * x -void testSimplifyMultiTerm() { +TEST(Simplify, SimplifyMultiTerm) { KernelScope kernel_scope; VarHandle x("x", kInt); ExprHandle body = @@ -666,7 +667,7 @@ void testSimplifyMultiTerm() { } /// 2 * (3 * (long)x) - (x * 4) => 2 * x -void testSimplifyCasts() { +TEST(Simplify, SimplifyCasts) { KernelScope kernel_scope; VarHandle x("x", kLong); ExprHandle body = @@ -684,7 +685,7 @@ void testSimplifyCasts() { } /// (x + 0) * 1 => x -void testSimplifyEliminatesNoOps() { +TEST(Simplify, SimplifyEliminatesNoOps) { KernelScope kernel_scope; VarHandle x("x", kInt); ExprHandle body = (x + ExprHandle(0)) * 1; @@ -696,7 +697,7 @@ void testSimplifyEliminatesNoOps() { } /// Cannot simplify this. -void testSimplifyMultiVar() { +TEST(Simplify, SimplifyMultiVar) { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -719,7 +720,7 @@ void testSimplifyMultiVar() { } // x + 2 + y => x + y + 2 -void testSimplifyReorderings() { +TEST(Simplify, DISABLED_SimplifyReorderings) { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -736,7 +737,7 @@ void testSimplifyReorderings() { } /// y + x * 0 => y -void testSimplifyEliminatesVar() { +TEST(Simplify, SimplifyEliminatesVar) { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -746,7 +747,7 @@ void testSimplifyEliminatesVar() { IS_VAR_WITH_NAME(simplified.node(), "y"); } -void testSimplifyAdds() { +TEST(Simplify, SimplifyAdds) { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -776,16 +777,16 @@ void testSimplifyAdds() { } { - // (x - y) + (x - y) => -2 * (y - x) + // (x - y) + (x - y) => 2 * (x - y) ExprHandle body = (x - y) + (x - y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), -2); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "y"); - IS_VAR_WITH_NAME(rhs->rhs(), "x"); + IS_VAR_WITH_NAME(rhs->lhs(), "x"); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); } { @@ -817,7 +818,7 @@ void testSimplifyAdds() { } } -void testSimplifyMuls() { +TEST(Simplify, SimplifyMuls) { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -937,7 +938,7 @@ void testSimplifyMuls() { } // Sub an expr from itself will result in zero. -void testSimplifySubs() { +TEST(Simplify, SimplifySubs) { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -964,15 +965,15 @@ void testSimplifySubs() { } { - // (x + y) - 2 * (x + y) => -1 * (x + y) + // (x + y) - 2 * (x + y) => -1 * x - y ExprHandle body = (x + y) - ExprHandle(2) * (x + y); ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Mul, sub->lhs(), mul); IS_IMM_WITH_VAL(Int, mul->lhs(), -1); - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "y"); - IS_VAR_WITH_NAME(add->rhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + IS_VAR_WITH_NAME(sub->rhs(), "y"); } { @@ -1060,9 +1061,50 @@ void testSimplifySubs() { ExprHandle simplified = IRSimplifier::simplify(body); IS_IMM_WITH_VAL(Int, simplified.node(), 2); } + + { + // Sub where result is negative. + ExprHandle body = x - (x + 1); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), -1); + } + + { + // Sub where result is positive due to negative scalar on RHS. + ExprHandle body = x - (x - 1); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 1); + } + + { + // Term - Polynomial sub where RHS must be negated. + ExprHandle body = (x * 2) - (x * 2 + 1); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), -1); + } + + { + // Term - Polynomial sub where the result is a Term. + ExprHandle body = (y * x * 2) - (x * y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // Term - Polynomial sub where the result is a Polynomial. + ExprHandle body = (x * 2) - (x + 1); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + + IS_VAR_WITH_NAME(sub->lhs(), "x"); + IS_IMM_WITH_VAL(Int, sub->rhs(), 1); + } } -void testSimplifyDiv() { +TEST(Simplify, SimplifyDiv) { KernelScope kernel_scope; VarHandle x("x", kInt); @@ -1079,17 +1121,135 @@ void testSimplifyDiv() { IS_VAR_WITH_NAME(simplified.node(), "x"); } +} + +TEST(Simplify, SimplifyMod) { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); { - ExprHandle body = x / x; + // Constant folding works. + ExprHandle body = ExprHandle(10) % 8; ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 2); + } - IS_IMM_WITH_VAL(Int, simplified.node(), 1); + { + // x % x => 0 + ExprHandle body = x % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // 0 % x => 0 + ExprHandle body = ExprHandle(0) % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // x % 1 => 0 + ExprHandle body = x % 1; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // Doesn't change unknown mods. + // x % y => x % y + ExprHandle body = x % y; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "x"); + IS_VAR_WITH_NAME(mod->rhs(), "y"); + } + + { + // don't touch if RHS is unknown. + // 4 % x => 4 % x + ExprHandle body = ExprHandle(4) % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_IMM_WITH_VAL(Int, mod->lhs(), 4); + IS_VAR_WITH_NAME(mod->rhs(), "x"); + } + + { + // don't touch if LHS is unknown. + // x % 4 => x % 4 + ExprHandle body = x % 4; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "x"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 4); + } + + { + // if LHS is a multiple of RHS, mod is zero. + // 2 * x % x => 0 + ExprHandle body = (x * 2) % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // true even if the multiple is not constant. + // x * y % x => 0 + ExprHandle body = (x * y) % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // true with multiple unknown values in LHS. + // x * y * z % x => 0 + ExprHandle body = (x * y * z) % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // true if the denom is compound. + // x * y * z % y * z => 0 + ExprHandle body = (x * y * z) % (y * z); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // Sanity check true with scalars that are multiples. + // 12 * x % 4 => 0 + ExprHandle body = (x * 12) % 4; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // Sanity check not true if the smaller scalar is on LHS. + // 4 * x % 12 => 4 * x % 12 + ExprHandle body = (x * 4) % 12; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_NODE_WITH_NAME(Mul, mod->lhs(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 4); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 12); + } + + { + // Both scalar and symbolic in multiple. + // (6 * x * y) % (3 * x * y) => 0 + ExprHandle body = (ExprHandle(6) * x * y) % (x * y * 3); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); } } // Test that mixing ops together simplifies as expected. -void testSimplifyMultiOp() { +TEST(Simplify, SimplifyMultiOp) { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -1158,7 +1318,7 @@ void testSimplifyMultiOp() { } // Test that chaining many ops together works as expected. -void testSimplifyManyOps() { +TEST(Simplify, SimplifyManyOps) { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -1206,7 +1366,7 @@ void testSimplifyManyOps() { } } -void testSimplifyFactorization() { +TEST(Simplify, SimplifyFactorization) { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -1274,12 +1434,12 @@ void testSimplifyFactorization() { { // Factorization with common divider but different signs. - // (-2 * x) + (4 * y) => -2 * (x - 2 * y) - ExprHandle body = (ExprHandle(-2) * x + ExprHandle(4) * y); + // (2 * x) + (-4 * y) => 2 * (x - 2 * y) + ExprHandle body = (ExprHandle(2) * x + ExprHandle(-4) * y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), -2); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); IS_VAR_WITH_NAME(sub->lhs(), "x"); @@ -1287,10 +1447,54 @@ void testSimplifyFactorization() { IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); IS_VAR_WITH_NAME(mul2->rhs(), "y"); } + + { + // Factorization with all negative numbers. + // (-2 * x) + (-4 * y) => 2 * (-1 * x - 2 * y) + ExprHandle body = ExprHandle(-2) * x + ExprHandle(-4) * y; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + + IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); + IS_NODE_WITH_NAME(Mul, sub->lhs(), mul2); + IS_IMM_WITH_VAL(Int, mul2->lhs(), -1); + IS_VAR_WITH_NAME(mul2->rhs(), "x"); + IS_NODE_WITH_NAME(Mul, sub->rhs(), mul3); + IS_IMM_WITH_VAL(Int, mul3->lhs(), 2); + IS_VAR_WITH_NAME(mul3->rhs(), "y"); + } + + { + // The following test ensures that there in no infinite recursion during + // factorization when negative numbers are involved. + VarHandle a("a", kInt); + VarHandle b("b", kInt); + VarHandle c("c", kInt); + VarHandle d("d", kInt); + VarHandle e("e", kInt); + VarHandle f("f", kInt); + VarHandle g("g", kInt); + VarHandle h("h", kInt); + + ExprHandle body = ExprHandle(0) + (ExprHandle(1024) * a) + + (ExprHandle(-1) * b) + (ExprHandle(-1) * c) + (ExprHandle(1) * d) + + (ExprHandle(1) * e) + (ExprHandle(32) * f) + (ExprHandle(-1024) * g) + + (ExprHandle(-32) * h); + ExprHandle simplified = IRSimplifier::simplify(body); + + // We only check for the top level nodes here, since the main purpose + // here is ensure that this simplification completes. + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 1024); + IS_VAR_WITH_NAME(mul->rhs(), "g"); + } } // (4 * x + y + z * 2) + (4 * x + y + z * 4) => 2 * (y + 3 * z + 4 * x) -void testSimplifyFactorizeUneven() { +TEST(Simplify, SimplifyFactorizeUneven) { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -1317,7 +1521,7 @@ void testSimplifyFactorizeUneven() { // (x * y) + (2 * x) * (x + y) => 3 * (x * y) + 2 * (x * x) // This is kind of a placeholder test for variable factorization. -void testSimplifyDeeperTerms() { +TEST(Simplify, SimplifyDeeperTerms) { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -1341,7 +1545,7 @@ void testSimplifyDeeperTerms() { // Tests the difference between two less trivial expressions. // (m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) => 1 -void testSimplifyDeeperDifference() { +TEST(Simplify, SimplifyDeeperDifference) { KernelScope kernel_scope; VarHandle n("n", kInt); VarHandle n_1("n_1", kInt); @@ -1355,7 +1559,7 @@ void testSimplifyDeeperDifference() { // Test constant folding into the difference between expressions. // 2 + char((m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n)) => 3 -void testSimplifyFoldComplexDifference() { +TEST(Simplify, SimplifyFoldComplexDifference) { KernelScope kernel_scope; VarHandle n("n", kInt); VarHandle n_1("n_1", kInt); @@ -1370,7 +1574,7 @@ void testSimplifyFoldComplexDifference() { IS_IMM_WITH_VAL(Int, simplified.node(), 3); } -void testSimplifyIfComponents() { +TEST(Simplify, SimplifyIfComponents) { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -1392,7 +1596,7 @@ void testSimplifyIfComponents() { IS_VAR_WITH_NAME(ifexpr->false_value(), "y"); } -void testSimplifyOpaqueTerms() { +TEST(Simplify, SimplifyOpaqueTerms) { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -1418,7 +1622,7 @@ void testSimplifyOpaqueTerms() { } } -void testSimplifySymbolicMinMax() { +TEST(Simplify, SimplifySymbolicMinMax) { KernelScope kernel_scope; { @@ -1454,7 +1658,7 @@ void testSimplifySymbolicMinMax() { } } -void testSimplifyNestedMax() { +TEST(Simplify, SimplifyNestedMax) { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -1757,7 +1961,7 @@ void testSimplifyNestedMax() { } } -void testSimplifyNestedMin() { +TEST(Simplify, SimplifyNestedMin) { KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -2060,11 +2264,11 @@ void testSimplifyNestedMin() { } } -void testSimplifyWontReorderFloat() { +TEST(Simplify, SimplifyWontReorderFloat) { KernelScope kernel_scope; { - // 3 * (3 * x) - 3 * (3 * y) => -9 * (y - x) + // 3 * (3 * x) - 3 * (3 * y) => 9 * (x - y) // This is an expression we can simplify. VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -2074,10 +2278,10 @@ void testSimplifyWontReorderFloat() { ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), -9); + IS_IMM_WITH_VAL(Int, mul->lhs(), 9); IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); - IS_VAR_WITH_NAME(sub->lhs(), "y"); - IS_VAR_WITH_NAME(sub->rhs(), "x"); + IS_VAR_WITH_NAME(sub->lhs(), "x"); + IS_VAR_WITH_NAME(sub->rhs(), "y"); } { @@ -2172,7 +2376,7 @@ void testSimplifyWontReorderFloat() { } } -void testSimplifyRoundModPattern() { +TEST(Simplify, SimplifyRoundModPattern) { KernelScope kernel_scope; { @@ -2380,7 +2584,7 @@ void testSimplifyRoundModPattern() { } } -void testSimplifyRoundModPatternFactorization() { +TEST(Simplify, SimplifyRoundModPatternFactorization) { KernelScope kernel_scope; { @@ -2439,7 +2643,7 @@ void testSimplifyRoundModPatternFactorization() { } } -void testSimplifyRoundModPatternMultivar() { +TEST(Simplify, SimplifyRoundModPatternMultivar) { KernelScope kernel_scope; { @@ -2496,7 +2700,7 @@ void testSimplifyRoundModPatternMultivar() { } } -void testSimplifyDivisionScalarFactorization() { +TEST(Simplify, SimplifyDivisionScalarFactorization) { KernelScope kernel_scope; { @@ -2568,7 +2772,7 @@ void testSimplifyDivisionScalarFactorization() { } } -void testSimplifyConstantBranches() { +TEST(Simplify, SimplifyConstantBranches) { KernelScope kernel_scope; { @@ -2626,14 +2830,14 @@ void testSimplifyConstantBranches() { } } -void testSimplifyConstantCond() { +TEST(Simplify, SimplifyConstantCond) { KernelScope kernel_scope; { // If the condition is constant true then take the true_value. // 1 ? A[0] = 1 : B[0] = 1 => A[0] = 1 - Buffer a(BufHandle("A", {1}, kInt)); - Buffer b(BufHandle("B", {1}, kInt)); + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); ExprHandle condition(1); Stmt* true_val = Store::make(a, {0}, 1, 1); Stmt* false_val = Store::make(b, {0}, 1, 1); @@ -2648,8 +2852,8 @@ void testSimplifyConstantCond() { { // If the condition is constant false then take the false_value. // 0 ? A[0] = 1 : B[0] = 1 => B[0] = 1 - Buffer a(BufHandle("A", {1}, kInt)); - Buffer b(BufHandle("B", {1}, kInt)); + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); ExprHandle condition(0); Stmt* true_val = Store::make(a, {0}, 1, 1); Stmt* false_val = Store::make(b, {0}, 1, 1); @@ -2665,8 +2869,8 @@ void testSimplifyConstantCond() { // condition is simplified before checking. // (x-x) ? A[0] = 1 : B[0] = 1 => B[0] = 1 VarHandle x("x", kInt); - Buffer a(BufHandle("A", {1}, kInt)); - Buffer b(BufHandle("B", {1}, kInt)); + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); ExprHandle condition(x - x); Stmt* true_val = Store::make(a, {0}, 1, 1); Stmt* false_val = Store::make(b, {0}, 1, 1); @@ -2682,7 +2886,7 @@ void testSimplifyConstantCond() { // If both branches are the same then don't do the condition. // x ? A[0] = x : A[0] = x => A[0] = x VarHandle x("x", kInt); - Buffer a(BufHandle("A", {1}, kInt)); + BufHandle a("A", {1}, kInt); ExprHandle condition(x - x); Stmt* true_val = Store::make(a, {0}, x, 1); Stmt* false_val = Store::make(a, {0}, x, 1); @@ -2698,7 +2902,7 @@ void testSimplifyConstantCond() { // If both branches simplify to the same thing it still works. // x ? (x + x) : (2 * x) => x VarHandle x("x", kInt); - Buffer a(BufHandle("A", {1}, kInt)); + BufHandle a("A", {1}, kInt); ExprHandle condition(x - x); Stmt* true_val = Store::make(a, {0}, ExprHandle(2) * x, 1); Stmt* false_val = Store::make(a, {0}, x + x, 1); @@ -2714,7 +2918,7 @@ void testSimplifyConstantCond() { // But not if they dont // x ? x : (2 * x) => x ? x : (2 * x) VarHandle x("x", kInt); - Buffer a(BufHandle("A", {1}, kInt)); + BufHandle a("A", {1}, kInt); ExprHandle condition(x); Stmt* true_val = Store::make(a, {0}, x, 1); Stmt* false_val = Store::make(a, {0}, ExprHandle(2) * x, 1); @@ -2738,7 +2942,7 @@ void testSimplifyConstantCond() { } } -void testSimplifyEliminateEmptyCond() { +TEST(Simplify, SimplifyEliminateEmptyCond) { KernelScope kernel_scope; // If the branches are empty in different ways, eliminate. { @@ -2766,13 +2970,196 @@ void testSimplifyEliminateEmptyCond() { } } -void testSimplifyEliminateZeroLengthFor() { +TEST(Simplify, SimplifyConstantComparisons) { + KernelScope kernel_scope; + + auto ComparisonTest = + [](ExprHandle a, ExprHandle b, CompareSelectOperation op, int result) { + ExprHandle body = CompareSelect::make(a, b, op); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), result); + }; + + // Equals. + ComparisonTest(2, 2, kEQ, 1); + ComparisonTest(1, 2, kEQ, 0); + ComparisonTest(2, 1, kEQ, 0); + + // Greater than. + ComparisonTest(2, 2, kGT, 0); + ComparisonTest(1, 2, kGT, 0); + ComparisonTest(2, 1, kGT, 1); + + // Greater or Equal. + ComparisonTest(2, 2, kGE, 1); + ComparisonTest(1, 2, kGE, 0); + ComparisonTest(2, 1, kGE, 1); + + // Less Than. + ComparisonTest(2, 2, kLT, 0); + ComparisonTest(1, 2, kLT, 1); + ComparisonTest(2, 1, kLT, 0); + + // Less or Equal. + ComparisonTest(2, 2, kLE, 1); + ComparisonTest(1, 2, kLE, 1); + ComparisonTest(2, 1, kLE, 0); + + // Not equal. + ComparisonTest(2, 2, kNE, 0); + ComparisonTest(1, 2, kNE, 1); + ComparisonTest(2, 1, kNE, 1); + + // With specified results: + ExprHandle body = CompareSelect::make(2, 2, 5, 42, kNE); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 42); +} + +TEST(Simplify, SimplifySymbolicComparisons) { + KernelScope kernel_scope; + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + auto TookTrueBranch = [](ExprHandle a) { IS_IMM_WITH_VAL(Int, a.node(), 1); }; + auto TookFalseBranch = [](ExprHandle a) { + IS_IMM_WITH_VAL(Int, a.node(), 0); + }; + + // EQ + + // x == x => 1 + ExprHandle body = CompareSelect::make(x, x, kEQ); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x == x+1 => 0 + body = CompareSelect::make(x, x + 1, kEQ); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x == x * 2 cannot simplify since we don't know x is nonzero. + body = CompareSelect::make(x, x * 2, kEQ); + IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); + + // x == x * 1 => 1 + body = CompareSelect::make(x, x * 1, kEQ); + TookTrueBranch(IRSimplifier::simplify(body)); + + { + // x == y => x == y + body = CompareSelect::make(x, y, kEQ); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp); + ASSERT_EQ(cmp->compare_select_op(), kEQ); + IS_VAR_WITH_NAME(cmp->lhs(), "x"); + IS_VAR_WITH_NAME(cmp->rhs(), "y"); + } + + { + // x == 5 => x == 5 + body = CompareSelect::make(x, 5, kEQ); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp); + ASSERT_EQ(cmp->compare_select_op(), kEQ); + IS_VAR_WITH_NAME(cmp->lhs(), "x"); + IS_IMM_WITH_VAL(Int, cmp->rhs(), 5); + } + + // GT + + // x+1 > x => 1 + body = CompareSelect::make(x + 1, x, kGT); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x > x + 1 => 0 + body = CompareSelect::make(x, x + 1, kGT); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x > x - 1 => 1 + body = CompareSelect::make(x, x - 1, kGT); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x - 1 > x => 0 + body = CompareSelect::make(x - 1, x, kGT); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x > x => 0 + body = CompareSelect::make(x, x, kGT); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x * 2 > x => x * 2 > x + // since we don't know the sign of x. + body = CompareSelect::make(x * 2, x, kGT); + IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); + + // GE + + // x+1 >= x => 1 + body = CompareSelect::make(x + 1, x, kGE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x >= x + 1 => 0 + body = CompareSelect::make(x, x + 1, kGE); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x >= x => 1 + body = CompareSelect::make(x, x, kGE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x * 2 >= x => x * 2 >= x + // since we don't know the sign of x. + body = CompareSelect::make(x * 2, x, kGE); + IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); + + // LT + + // x+1 < x => 0 + body = CompareSelect::make(x + 1, x, kLT); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x < x + 1 => 1 + body = CompareSelect::make(x, x + 1, kLT); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x < x => 0 + body = CompareSelect::make(x, x, kLT); + TookFalseBranch(IRSimplifier::simplify(body)); + + // LE + + // x+1 <= x => 0 + body = CompareSelect::make(x + 1, x, kLE); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x <= x + 1 => 1 + body = CompareSelect::make(x, x + 1, kLE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x <= x => 1 + body = CompareSelect::make(x, x, kLE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // NE + + // x+1 != x => 1 + body = CompareSelect::make(x + 1, x, kNE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x != x + 1 => 1 + body = CompareSelect::make(x, x + 1, kNE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x != x => 0 + body = CompareSelect::make(x, x, kNE); + TookFalseBranch(IRSimplifier::simplify(body)); +} + +TEST(Simplify, SimplifyEliminateZeroLengthFor) { KernelScope kernel_scope; { // Will eliminate zero loop For. - Buffer a(BufHandle("A", {4}, kInt)); - Buffer c(BufHandle("C", {4}, kInt)); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); auto body = @@ -2784,8 +3171,8 @@ void testSimplifyEliminateZeroLengthFor() { { // still works if start is not zero. - Buffer a(BufHandle("A", {4}, kInt)); - Buffer c(BufHandle("C", {4}, kInt)); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); auto body = @@ -2798,8 +3185,8 @@ void testSimplifyEliminateZeroLengthFor() { { // works if both terms are variable. VarHandle x("x", kInt); - Buffer a(BufHandle("A", {4}, kInt)); - Buffer c(BufHandle("C", {4}, kInt)); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); auto body = @@ -2812,8 +3199,8 @@ void testSimplifyEliminateZeroLengthFor() { { // works if one term simplifies down. VarHandle x("x", kInt); - Buffer a(BufHandle("A", {4}, kInt)); - Buffer c(BufHandle("C", {4}, kInt)); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); auto body = For::make( @@ -2825,8 +3212,8 @@ void testSimplifyEliminateZeroLengthFor() { { // Sanity check does nothing if the condition is not met. - Buffer a(BufHandle("A", {4}, kInt)); - Buffer c(BufHandle("C", {4}, kInt)); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); auto body = @@ -2836,13 +3223,13 @@ void testSimplifyEliminateZeroLengthFor() { } } -void testSimplifyOneLoopFor() { +TEST(Simplify, SimplifyOneLoopFor) { KernelScope kernel_scope; { // Will remove the loop if the body is run once. - Buffer a(BufHandle("A", {4}, kInt)); - Buffer c(BufHandle("C", {4}, kInt)); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); auto body = @@ -2856,8 +3243,8 @@ void testSimplifyOneLoopFor() { { // still works if start is not zero. - Buffer a(BufHandle("A", {4}, kInt)); - Buffer c(BufHandle("C", {4}, kInt)); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); auto body = @@ -2872,8 +3259,8 @@ void testSimplifyOneLoopFor() { { // works if both terms are variable. VarHandle x("x", kInt); - Buffer a(BufHandle("A", {4}, kInt)); - Buffer c(BufHandle("C", {4}, kInt)); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); auto body = For::make( @@ -2888,8 +3275,8 @@ void testSimplifyOneLoopFor() { { // works if one term simplifies down. VarHandle x("x", kInt); - Buffer a(BufHandle("A", {4}, kInt)); - Buffer c(BufHandle("C", {4}, kInt)); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); auto body = For::make( @@ -2903,8 +3290,8 @@ void testSimplifyOneLoopFor() { { // Sanity check does nothing if the condition is not met. - Buffer a(BufHandle("A", {4}, kInt)); - Buffer c(BufHandle("C", {4}, kInt)); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); auto body = @@ -2914,13 +3301,13 @@ void testSimplifyOneLoopFor() { } } -void testSimplifyForWontLoseLoopOptions() { +TEST(Simplify, SimplifyForWontLoseLoopOptions) { KernelScope kernel_scope; { // Sanity check does nothing if the condition is not met. - Buffer a(BufHandle("A", {4}, kInt)); - Buffer c(BufHandle("C", {4}, kInt)); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); LoopOptions options; @@ -2934,13 +3321,13 @@ void testSimplifyForWontLoseLoopOptions() { } } -void testSimplifyMultilevelFor() { +TEST(Simplify, SimplifyMultilevelFor) { KernelScope kernel_scope; { // Multiple layers of For will be simplified out. - Buffer a(BufHandle("A", {4}, kInt)); - Buffer c(BufHandle("C", {4}, kInt)); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); VarHandle j("j", kInt); @@ -2956,8 +3343,8 @@ void testSimplifyMultilevelFor() { { // Will maintain an outer loop if the inner loop is eliminated. - Buffer a(BufHandle("A", {4}, kInt)); - Buffer c(BufHandle("C", {4}, kInt)); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); VarHandle j("j", kInt); @@ -2979,8 +3366,8 @@ void testSimplifyMultilevelFor() { { // Will maintain inner loop if outer loops is eliminated. - Buffer a(BufHandle("A", {4}, kInt)); - Buffer c(BufHandle("C", {4}, kInt)); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); VarHandle j("j", kInt); @@ -2999,11 +3386,11 @@ void testSimplifyMultilevelFor() { } } -void testSimplifyForCleansUp() { +TEST(Simplify, SimplifyForCleansUp) { KernelScope kernel_scope; { - Buffer a("a", kFloat, {1, 12, 1}); + Placeholder a("a", kFloat, {1, 12, 1}); VarHandle x("x", kInt); Tensor* b = Compute( "x", @@ -3028,7 +3415,7 @@ void testSimplifyForCleansUp() { } } -void testSimplifyEliminateEmptyFor() { +TEST(Simplify, SimplifyEliminateEmptyFor) { KernelScope kernel_scope; { @@ -3045,13 +3432,13 @@ void testSimplifyEliminateEmptyFor() { } } -void testSimplifyFlattenBlock() { +TEST(Simplify, SimplifyFlattenBlock) { KernelScope kernel_scope; { // Flatten multiple blocks down to one. // { { { stmt1, stmt2 } } } => { stmt1, stmt2 } - Buffer a(BufHandle("A", {1}, kInt)); + BufHandle a("A", {1}, kInt); Store* store1 = Store::make(a, {0}, 1, 1); Store* store2 = Store::make(a, {0}, 0, 1); @@ -3074,7 +3461,7 @@ void testSimplifyFlattenBlock() { { // Flatten multiple sub blocks containing statements. // { { stmt1 }, { stmt2 } } => { stmt1, stmt2 } - Buffer a(BufHandle("A", {1}, kInt)); + BufHandle a("A", {1}, kInt); Store* store1 = Store::make(a, {0}, 1, 1); Store* store2 = Store::make(a, {0}, 0, 1); @@ -3097,7 +3484,7 @@ void testSimplifyFlattenBlock() { { // Flatten sub blocks with different depths. // { stmt1 , { { stmt2 } } } => { stmt1, stmt2 } - Buffer a(BufHandle("A", {1}, kInt)); + BufHandle a("A", {1}, kInt); Store* store1 = Store::make(a, {0}, 1, 1); Store* store2 = Store::make(a, {0}, 0, 1); @@ -3130,7 +3517,7 @@ void testSimplifyFlattenBlock() { } } -void testSimplifyEliminateZeroLengthAlloc() { +TEST(Simplify, SimplifyEliminateZeroLengthAlloc) { KernelScope kernel_scope; { @@ -3204,7 +3591,7 @@ void testSimplifyEliminateZeroLengthAlloc() { } } -void testDontSimplifyRand() { +TEST(Simplify, DontSimplifyRand) { KernelScope kernel_scope; { @@ -3238,11 +3625,11 @@ void testDontSimplifyRand() { } } -void testSimplifyReorderForCond() { +TEST(Simplify, SimplifyReorderForCond) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {4}, kInt)); - Buffer b(BufHandle("B", {1}, kInt)); - Buffer c(BufHandle("C", {4}, kInt)); + BufHandle a("A", {4}, kInt); + BufHandle b("B", {1}, kInt); + BufHandle c("C", {4}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); VarHandle j("j", kInt); @@ -3438,10 +3825,10 @@ void testSimplifyReorderForCond() { } } -void testSimplifyFuseConditions() { +TEST(Simplify, SimplifyFuseConditions) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {2}, kInt)); - Buffer b(BufHandle("B", {2}, kInt)); + BufHandle a("A", {2}, kInt); + BufHandle b("B", {2}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); VarHandle j("j", kInt); @@ -3552,24 +3939,25 @@ void testSimplifyFuseConditions() { // Can't fuse, CompareSelect results are different. // Actually we totally could if we normalized CompareSelect results, but // TODO for later. - auto body = Block::make({Cond::make( - CompareSelect::make( - i, - 10, - new IntImm(1), - new IntImm(0), - CompareSelectOperation::kLT), - Store::make(a, {0}, i, mask), - nullptr), - Cond::make( - CompareSelect::make( - j, - 10, - new IntImm(2), - new IntImm(0), - CompareSelectOperation::kLT), - Store::make(a, {1}, i, mask), - nullptr)}); + auto body = Block::make( + {Cond::make( + CompareSelect::make( + i, + 10, + new IntImm(1), + new IntImm(0), + CompareSelectOperation::kLT), + Store::make(a, {0}, i, mask), + nullptr), + Cond::make( + CompareSelect::make( + j, + 10, + new IntImm(2), + new IntImm(0), + CompareSelectOperation::kLT), + Store::make(a, {1}, i, mask), + nullptr)}); Stmt* simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); @@ -3768,20 +4156,21 @@ void testSimplifyFuseConditions() { { // Can fuse if the conditions simplify to the same thing. - auto body = Block::make({Cond::make( - CompareSelect::make( - i * 2, - ExprHandle(87) % ExprHandle(11), - CompareSelectOperation::kLT), - Store::make(a, {0}, i, mask), - nullptr), - Cond::make( - CompareSelect::make( - i * 2, - ExprHandle(300) / ExprHandle(30), - CompareSelectOperation::kLT), - Store::make(a, {1}, i, mask), - nullptr)}); + auto body = Block::make( + {Cond::make( + CompareSelect::make( + i * 2, + ExprHandle(87) % ExprHandle(11), + CompareSelectOperation::kLT), + Store::make(a, {0}, i, mask), + nullptr), + Cond::make( + CompareSelect::make( + i * 2, + ExprHandle(300) / ExprHandle(30), + CompareSelectOperation::kLT), + Store::make(a, {1}, i, mask), + nullptr)}); Stmt* simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 1); @@ -3794,9 +4183,9 @@ void testSimplifyFuseConditions() { { // Can fuse non-CompareSelects. // if (i) { X } if (i) { Y } => if (i) { X; Y } - auto body = - Block::make({Cond::make(i, Store::make(a, {0}, i, mask), nullptr), - Cond::make(i, Store::make(a, {1}, i, mask), nullptr)}); + auto body = Block::make( + {Cond::make(i, Store::make(a, {0}, i, mask), nullptr), + Cond::make(i, Store::make(a, {1}, i, mask), nullptr)}); Stmt* simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); @@ -3809,9 +4198,9 @@ void testSimplifyFuseConditions() { { // Sanity check wont fuse different non-CompareSelects. - auto body = - Block::make({Cond::make(i, Store::make(a, {0}, i, mask), nullptr), - Cond::make(j, Store::make(a, {1}, i, mask), nullptr)}); + auto body = Block::make( + {Cond::make(i, Store::make(a, {0}, i, mask), nullptr), + Cond::make(j, Store::make(a, {1}, i, mask), nullptr)}); Stmt* simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); @@ -3823,9 +4212,9 @@ void testSimplifyFuseConditions() { { // Sanity check constant condition elimination still occurs when merging is // possible. - auto body = - Block::make({Cond::make(1, Store::make(a, {0}, i, mask), nullptr), - Cond::make(1, Store::make(a, {1}, i, mask), nullptr)}); + auto body = Block::make( + {Cond::make(1, Store::make(a, {0}, i, mask), nullptr), + Cond::make(1, Store::make(a, {1}, i, mask), nullptr)}); Stmt* simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 2); @@ -3856,18 +4245,19 @@ void testSimplifyFuseConditions() { } } -void testSimplifySyncThreads() { +TEST(Simplify, SimplifySyncThreads) { KernelScope kernel_scope; - Buffer a(BufHandle("A", {4}, kInt)); + BufHandle a("A", {4}, kInt); auto mask = IntImm::make(1); VarHandle i("i", kInt); { // Merge two inner SyncThreads. - auto body = Block::make({Store::make(a, {0}, 1, 1), - new SyncThreads(), - new SyncThreads(), - Store::make(a, {1}, 0, 1)}); + auto body = Block::make( + {Store::make(a, {0}, 1, 1), + new SyncThreads(), + new SyncThreads(), + Store::make(a, {1}, 0, 1)}); Stmt* simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); ASSERT_EQ(block->nstmts(), 3); @@ -3891,13 +4281,14 @@ void testSimplifySyncThreads() { { // Merge many inner SyncThreads. - auto body = Block::make({Store::make(a, {0}, 1, 1), - new SyncThreads(), - new SyncThreads(), - new SyncThreads(), - new SyncThreads(), - new SyncThreads(), - Store::make(a, {1}, 0, 1)}); + auto body = Block::make( + {Store::make(a, {0}, 1, 1), + new SyncThreads(), + new SyncThreads(), + new SyncThreads(), + new SyncThreads(), + new SyncThreads(), + Store::make(a, {1}, 0, 1)}); Stmt* simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); @@ -3910,13 +4301,14 @@ void testSimplifySyncThreads() { { // Merge multiple outer SyncThreads. - auto body = Block::make({new SyncThreads(), - new SyncThreads(), - Store::make(a, {1}, 0, 1), - new SyncThreads(), - new SyncThreads(), - new SyncThreads(), - new SyncThreads()}); + auto body = Block::make( + {new SyncThreads(), + new SyncThreads(), + Store::make(a, {1}, 0, 1), + new SyncThreads(), + new SyncThreads(), + new SyncThreads(), + new SyncThreads()}); Stmt* simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); @@ -3927,15 +4319,16 @@ void testSimplifySyncThreads() { { // Merge multiple sections; - auto body = Block::make({Store::make(a, {0}, 1, 1), - new SyncThreads(), - new SyncThreads(), - Store::make(a, {1}, 0, 1), - Store::make(a, {2}, 0, 1), - new SyncThreads(), - new SyncThreads(), - new SyncThreads(), - Store::make(a, {3}, 0, 1)}); + auto body = Block::make( + {Store::make(a, {0}, 1, 1), + new SyncThreads(), + new SyncThreads(), + Store::make(a, {1}, 0, 1), + Store::make(a, {2}, 0, 1), + new SyncThreads(), + new SyncThreads(), + new SyncThreads(), + Store::make(a, {3}, 0, 1)}); Stmt* simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Block, simplified, block); @@ -3950,7 +4343,7 @@ void testSimplifySyncThreads() { } } -void testSimplifyRampSubBroadcast() { +TEST(Simplify, SimplifyRampSubBroadcast) { KernelScope kernel_scope; int num_lanes = 4; ExprHandle ramp = Ramp::make(ExprHandle(0), ExprHandle(6), num_lanes); @@ -3964,7 +4357,7 @@ void testSimplifyRampSubBroadcast() { ASSERT_EQ(newRamp->lanes(), num_lanes); } -void testSimplifyBroadcastTermExpander() { +TEST(Simplify, SimplifyBroadcastTermExpander) { KernelScope kernel_scope; int num_lanes = 8; ExprHandle bc0 = Broadcast::make(ExprHandle(0), num_lanes); @@ -3974,7 +4367,7 @@ void testSimplifyBroadcastTermExpander() { // relevant path in TermExpander::mutate. The two bc1 terms are brought // together and simplified to 2 * bc1, which then needs to make 2 multi-lane. ExprHandle simplified = IRSimplifier::simplify(bc1 + (bc0 / bc2) + bc1); - Buffer buf(BufHandle("buf", {num_lanes}, kInt)); + BufHandle buf("buf", {num_lanes}, kInt); // The result isn't fully simplified currently and thus would be brittle to // match. Observe its value instead. auto store = Store::make( @@ -3982,7 +4375,7 @@ void testSimplifyBroadcastTermExpander() { {Ramp::make(0, 1, num_lanes)}, simplified, Broadcast::make(ExprHandle(1), num_lanes)); - SimpleIREvaluator eval(store, buf); + SimpleIREvaluator eval(store, {buf}); std::vector output(num_lanes); eval(output); for (int i = 0; i < num_lanes; ++i) { diff --git a/test/cpp/tensorexpr/test_te_fuser_pass.cpp b/test/cpp/tensorexpr/test_te_fuser_pass.cpp index 826cf72093468..adaf14593cf70 100644 --- a/test/cpp/tensorexpr/test_te_fuser_pass.cpp +++ b/test/cpp/tensorexpr/test_te_fuser_pass.cpp @@ -1,3 +1,5 @@ +#include + #include #include #include @@ -24,18 +26,17 @@ struct WithCPUFuser { bool cpuFuserEnabled; }; -void testFuserPass_1() { +TEST(TEFuserPass, FuserPass_1) { WithCPUFuser cf; - KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%0 : Float(128:1, device=cpu), - %1 : Float(128:1, device=cpu)): + graph(%0 : Float(128, strides=[1], device=cpu), + %1 : Float(128, strides=[1], device=cpu)): %12 : int = prim::Constant[value=1]() - %2.1 : Float(128:1, device=cpu) = aten::mul(%0, %1) - %2 : Float(128:1, device=cpu) = aten::mul(%2.1, %1) - %3 : Float(128:1, device=cpu) = aten::add_(%2, %1, %12) - %4 : Float(128:1, device=cpu) = aten::mul(%2, %1) - %5 : Float(128:1, device=cpu) = aten::add(%2, %4, %12) + %2.1 : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) + %2 : Float(128, strides=[1], device=cpu) = aten::mul(%2.1, %1) + %3 : Float(128, strides=[1], device=cpu) = aten::add_(%2, %1, %12) + %4 : Float(128, strides=[1], device=cpu) = aten::mul(%2, %1) + %5 : Float(128, strides=[1], device=cpu) = aten::add(%2, %4, %12) return (%5))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); @@ -51,17 +52,16 @@ void testFuserPass_1() { ->run(*g); } -void testFuserPass_2() { +TEST(TEFuserPass, FuserPass_2) { WithCPUFuser cf; - KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%0 : Float(128:1, device=cpu), - %1 : Float(128:1, device=cpu)): + graph(%0 : Float(128, strides=[1], device=cpu), + %1 : Float(128, strides=[1], device=cpu)): %12 : int = prim::Constant[value=1]() - %a : Float(128:1, device=cpu) = aten::mul(%0, %1) - %b : Float(128:1, device=cpu) = aten::add(%0, %1, %12) - %c : Float(128:1, device=cpu) = aten::add_(%b, %1, %12) - %d : Float(128:1, device=cpu) = aten::mul(%c, %a) + %a : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) + %b : Float(128, strides=[1], device=cpu) = aten::add(%0, %1, %12) + %c : Float(128, strides=[1], device=cpu) = aten::add_(%b, %1, %12) + %d : Float(128, strides=[1], device=cpu) = aten::mul(%c, %a) return (%d))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); @@ -76,13 +76,12 @@ void testFuserPass_2() { ->run(*g); } -void testFuserPass_3() { +TEST(TEFuserPass, FuserPass_3) { WithCPUFuser cf; - KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%x : Float(128:1, device=cpu), - %y : Float(128:1, device=cpu)): - %r : Float(128:1, device=cpu) = aten::mul(%x, %y) + graph(%x : Float(128, strides=[1], device=cpu), + %y : Float(128, strides=[1], device=cpu)): + %r : Float(128, strides=[1], device=cpu) = aten::mul(%x, %y) return (%r))IR"; { auto g = std::make_shared(); @@ -106,8 +105,7 @@ void testFuserPass_3() { } } -void testFuserPass_0DimInput() { - KernelScope kernel_scope; +TEST(TEFuserPass, FuserPass_0DimInput) { const auto graph_string = R"IR( graph(%x : Float(device=cuda), %y : Float(device=cuda)): @@ -125,13 +123,12 @@ void testFuserPass_0DimInput() { testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); } -void testFuserPass_UnfusibleDevice() { +TEST(TEFuserPass, FuserPass_UnfusibleDevice) { WithCPUFuser cf(false); - KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%x : Float(10:1, device=cpu), - %y : Float(10:1, device=cpu)): - %a : Float(10:1, device=cpu) = aten::mul(%x, %y) + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(10, strides=[1], device=cpu)): + %a : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y) return (%a))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); @@ -143,9 +140,8 @@ void testFuserPass_UnfusibleDevice() { testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); } -void testFuserPass_UnknownShapes() { +TEST(TEFuserPass, FuserPass_UnknownShapes) { WithCPUFuser cf; - KernelScope kernel_scope; const auto graph_string = R"IR( graph(%x : Tensor, %y : Tensor): @@ -162,17 +158,16 @@ void testFuserPass_UnknownShapes() { testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); } -void testFuserPass_Multidevice() { +TEST(TEFuserPass, FuserPass_Multidevice) { { WithCPUFuser cf; - KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%x : Float(10:1, device=cpu), - %y : Float(20:1, device=cpu), - %z : Float(30:1, device=cpu)): + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(30, strides=[1], device=cpu)): %dim : int = prim::Constant[value=0]() %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Tensor = aten::cat(%xyz_list, %dim) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) return (%cat))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); @@ -185,14 +180,13 @@ void testFuserPass_Multidevice() { } { WithCPUFuser cf; - KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%x : Float(10:1, device=cpu), - %y : Float(20:1, device=cuda:0), - %z : Float(30:1, device=cpu)): + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cuda:0), + %z : Float(30, strides=[1], device=cpu)): %dim : int = prim::Constant[value=0]() %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Tensor = aten::cat(%xyz_list, %dim) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) return (%cat))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); @@ -206,15 +200,14 @@ void testFuserPass_Multidevice() { } { WithCPUFuser cf; - KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%x : Float(10:1, device=cpu), - %y : Float(20:1, device=cpu), - %z : Float(10:1, device=cuda:0)): + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(10, strides=[1], device=cuda:0)): %dim : int = prim::Constant[value=0]() %xy_list : Tensor[] = prim::ListConstruct(%x, %y) - %xy_cat : Tensor = aten::cat(%xy_list, %dim) - %r : Tensor = aten::mul(%xy_cat, %z) + %xy_cat : Float(30, strides=[1], device=cpu) = aten::cat(%xy_list, %dim) + %r : Float(30, strides=[1], device=cpu) = aten::mul(%xy_cat, %z) return (%r))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); @@ -228,15 +221,14 @@ void testFuserPass_Multidevice() { } { WithCPUFuser cf; - KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%x : Float(10:1, device=cpu), - %y : Float(20:1, device=cpu), - %z : Float(10:1, device=cuda:0)): + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(10, strides=[1], device=cuda:0)): %z2 : Tensor = aten::mul(%z, %z) %dim : int = prim::Constant[value=0]() %xy_list : Tensor[] = prim::ListConstruct(%x, %y, %z2) - %cat : Tensor = aten::cat(%xy_list, %dim) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xy_list, %dim) return (%cat))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); @@ -250,11 +242,10 @@ void testFuserPass_Multidevice() { } { WithCPUFuser cf; - KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%x : Float(10:1, device=cpu), - %y : Float(20:1, device=cuda:0)): - %r : Tensor = aten::mul(%x, %y) + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cuda:0)): + %r : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y) return (%r))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); @@ -267,14 +258,13 @@ void testFuserPass_Multidevice() { } { WithCPUFuser cf; - KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%x : Float(10:1, device=cuda:0), - %y : Float(20:1, device=cuda:1), - %z : Float(20:1, device=cpu)): - %x2 : Tensor = aten::mul(%x, %x) - %y2 : Tensor = aten::mul(%y, %y) - %z2 : Tensor = aten::mul(%z, %z) + graph(%x : Float(10, strides=[1], device=cuda:0), + %y : Float(20, strides=[1], device=cuda:1), + %z : Float(20, strides=[1], device=cpu)): + %x2 : Float(10, strides=[1], device=cpu) = aten::mul(%x, %x) + %y2 : Float(10, strides=[1], device=cpu) = aten::mul(%y, %y) + %z2 : Float(10, strides=[1], device=cpu) = aten::mul(%z, %z) return (%x2, %y2, %z2))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); @@ -288,14 +278,13 @@ void testFuserPass_Multidevice() { } } -void testFuserPass_MergeGroups() { +TEST(TEFuserPass, FuserPass_MergeGroups) { WithCPUFuser cf; - KernelScope kernel_scope; const auto graph_string = R"IR( - graph(%a : Float(128:1, device=cpu), - %b : Float(128:1, device=cpu)): - %x : Float(128:1, device=cpu) = aten::mul(%a, %a) - %y : Float(128:1, device=cpu) = aten::mul(%b, %b) + graph(%a : Float(128, strides=[1], device=cpu), + %b : Float(128, strides=[1], device=cpu)): + %x : Float(128, strides=[1], device=cpu) = aten::mul(%a, %a) + %y : Float(128, strides=[1], device=cpu) = aten::mul(%b, %b) return (%x, %y))IR"; auto g = std::make_shared(); torch::jit::parseIR(graph_string, g.get()); @@ -311,9 +300,8 @@ void testFuserPass_MergeGroups() { ->run(*g); } -void testFuserPass_UnknownShapesIgnored() { +TEST(TEFuserPass, FuserPass_UnknownShapesIgnored) { WithCPUFuser cf; - KernelScope kernel_scope; const auto graph_string = R"IR( graph(%x : Float(device=cpu), %y : Float(device=cpu)): @@ -330,5 +318,55 @@ void testFuserPass_UnknownShapesIgnored() { testing::FileCheck().check("prim::TensorExprGroup")->run(*g); } +TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Bool(8, strides=[1], device=cpu), + %y : Bool(8, strides=[1], device=cpu)): + %a : Bool(8, strides=[1], device=cpu) = aten::__and__(%x, %y) + %b : Tensor = aten::__or__(%a, %y) + return (%b) + )IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 2); + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); +} + +TEST(TEFuserPass, FuserPass_Where) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(8, strides=[1], device=cpu), + %y : Float(8, strides=[1], device=cpu), + %z : Float(8, strides=[1], device=cpu)): + %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y) + %b : Float(8, strides=[1], device=cpu) = aten::where(%cond, %y, %z) + return (%b) + )IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 2); + testing::FileCheck().check("prim::TensorExprGroup")->run(*g); +} + +TEST(TEFuserPass, FuserPass_WhereList) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(8, strides=[1], device=cpu), + %y : Float(8, strides=[1], device=cpu), + %z : Float(8, strides=[1], device=cpu)): + %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y) + %b : Tensor[] = aten::where(%cond) + return (%b) + )IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 2); + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/test_train.cpp b/test/cpp/tensorexpr/test_train.cpp index aa2426050324e..5bfe68d58bda6 100644 --- a/test/cpp/tensorexpr/test_train.cpp +++ b/test/cpp/tensorexpr/test_train.cpp @@ -1,10 +1,10 @@ -#include "test/cpp/tensorexpr/test_train.h" +#include + #include "test/cpp/tensorexpr/padded_buffer.h" #include "test/cpp/tensorexpr/test_base.h" +#include "test/cpp/tensorexpr/test_train.h" #include "test/cpp/tensorexpr/test_utils.h" -#include "torch/csrc/jit/tensorexpr/buffer.h" #include "torch/csrc/jit/tensorexpr/eval.h" -#include "torch/csrc/jit/tensorexpr/function.h" #include "torch/csrc/jit/tensorexpr/ir.h" #include "torch/csrc/jit/tensorexpr/ir_printer.h" #include "torch/csrc/jit/tensorexpr/loopnest.h" @@ -48,7 +48,7 @@ struct T { } }; -void testTrainBasic() { +TEST(Train, TrainBasic) { { VGraph graph; auto A = graph.create_tensor({"K"}); @@ -56,7 +56,7 @@ void testTrainBasic() { auto C = call("mul", {A, B})[0]; Stmt* s; - std::map inputs; + std::map inputs; std::map bindings; std::map vbindings; @@ -85,7 +85,7 @@ void testTrainBasic() { auto dA = grad(D, A, ones); Stmt* s; - std::map inputs; + std::map inputs; std::map bindings; std::map vbindings; @@ -117,7 +117,7 @@ void testTrainBasic() { auto C = A + B; Stmt* s; - std::map inputs; + std::map inputs; std::map bindings; std::map vbindings; @@ -146,7 +146,7 @@ void testTrainBasic() { auto dA = D.grad(A, ones); Stmt* s; - std::map inputs; + std::map inputs; std::map bindings; std::map vbindings; @@ -181,7 +181,7 @@ void testTrainBasic() { auto dC = (C * C).grad(B, ones); Stmt* s; - std::map inputs; + std::map inputs; std::map bindings; std::map vbindings; @@ -209,7 +209,7 @@ void testTrainBasic() { auto X = T(g, {"K"}); auto Y = X.sum(); Stmt* s; - std::map inputs; + std::map inputs; std::map bindings; std::map vbindings; @@ -229,7 +229,7 @@ void testTrainBasic() { auto Y = X.sum(); auto Z = Y.broadcast_like(X); Stmt* s; - std::map inputs; + std::map inputs; std::map bindings; std::map vbindings; @@ -266,7 +266,7 @@ void testTrainBasic() { auto new_W = W - W_grad; Stmt* s; - std::map inputs; + std::map inputs; std::map bindings; std::map vbindings; @@ -304,14 +304,15 @@ void testTrainBasic() { for (auto i = 0; i < 100; ++i) { std::generate(X_.begin(), X_.end(), gen); - cg.call({X_.data(), - W_ref_.data(), - W_.data(), - one_.data(), - K_.data(), - LR_.data(), - W_.data(), - N}); + cg.call( + {X_.data(), + W_ref_.data(), + W_.data(), + one_.data(), + K_.data(), + LR_.data(), + W_.data(), + N}); } // Less than 1% difference after running regression for (auto i = 0; i < W_.size(); ++i) { diff --git a/test/cpp/tensorexpr/test_train.h b/test/cpp/tensorexpr/test_train.h index 39674933aa9c5..16ff667860d0b 100644 --- a/test/cpp/tensorexpr/test_train.h +++ b/test/cpp/tensorexpr/test_train.h @@ -37,7 +37,7 @@ VTensor* grad(VTensor* y, VTensor* x, VTensor* j); std::string dot(const VGraph& g); std::tuple< torch::jit::tensorexpr::Stmt*, - std::map, + std::map, std::map, std::map> to_tensorexpr(const VGraph& graph, std::vector outputs = {}); diff --git a/test/cpp/tensorexpr/test_train_impl.cpp b/test/cpp/tensorexpr/test_train_impl.cpp index 1636b583cef9f..2beffa2e09899 100644 --- a/test/cpp/tensorexpr/test_train_impl.cpp +++ b/test/cpp/tensorexpr/test_train_impl.cpp @@ -1,8 +1,6 @@ #include "test/cpp/tensorexpr/test_train.h" #include "test/cpp/tensorexpr/test_utils.h" -#include "torch/csrc/jit/tensorexpr/buffer.h" #include "torch/csrc/jit/tensorexpr/eval.h" -#include "torch/csrc/jit/tensorexpr/function.h" #include "torch/csrc/jit/tensorexpr/ir.h" #include "torch/csrc/jit/tensorexpr/ir_printer.h" #include "torch/csrc/jit/tensorexpr/loopnest.h" @@ -171,7 +169,7 @@ VTensor* grad(VTensor* y, VTensor* x, VTensor* j) { TORCH_CHECK(g, ss.str()); } auto g_outs = g(op->inputs, grad_inputs); - for (auto i = 0; i < g_outs.size(); ++i) { + for (auto i = 0U; i < g_outs.size(); ++i) { auto input = op->inputs[i]; if (need_grad.find(input) != need_grad.end()) { if (grad_map.find(input) != grad_map.end()) { @@ -200,7 +198,7 @@ VOp::VOp( VGraph* graph_) : inputs(inputs_), graph(graph_) { method = &VMethod::get(name); - for (auto i = 0; i < num_outputs; ++i) { + for (auto i = 0U; i < num_outputs; ++i) { outputs.emplace_back(graph->create_tensor({})); outputs.back()->op = this; } @@ -305,8 +303,9 @@ REGISTER_METHOD( }, [](const std::vector& inputs, const std::vector& ginputs) -> std::vector { - return {call("mul", {ginputs[0], inputs[1]})[0], - call("mul", {ginputs[0], inputs[0]})[0]}; + return { + call("mul", {ginputs[0], inputs[1]})[0], + call("mul", {ginputs[0], inputs[0]})[0]}; }, [](const std::vector& inputs) -> std::vector> { @@ -331,8 +330,9 @@ REGISTER_METHOD( const std::vector& ginputs) -> std::vector { auto b_2 = call("mul", {inputs[1], inputs[1]})[0]; auto a_div_b_2 = call("div", {inputs[0], b_2})[0]; - return {call("div", {ginputs[0], inputs[1]})[0], - call("mul", {ginputs[0], call("neg", {a_div_b_2})[0]})[0]}; + return { + call("div", {ginputs[0], inputs[1]})[0], + call("mul", {ginputs[0], call("neg", {a_div_b_2})[0]})[0]}; }, [](const std::vector& inputs) -> std::vector> { @@ -408,7 +408,7 @@ std::string dot(const VGraph& g) { std::tuple< Stmt*, - std::map, + std::map, std::map, std::map> to_tensorexpr(const VGraph& graph, std::vector outputs) { @@ -458,7 +458,7 @@ to_tensorexpr(const VGraph& graph, std::vector outputs) { return order; }; - std::map inputs; + std::map inputs; std::map bindings; std::map vbindings; @@ -481,10 +481,10 @@ to_tensorexpr(const VGraph& graph, std::vector outputs) { if (vars.size() == 0) { vars.emplace_back(IntImm::make(1)); } - Buffer inpB(BufHandle(get_name(id), exprs, kFloat)); + Placeholder inpB(BufHandle(get_name(id), exprs, kFloat)); auto inpT = Compute("input" + get_name(id), vars, [&](const VarHandle& i) { - return Load::make(inpB, {i}, 1); + return Load::make(BufHandle(inpB.data()), {i}, 1); }); inputs.emplace(&t, inpB); bindings.emplace(&t, inpT); @@ -499,7 +499,7 @@ to_tensorexpr(const VGraph& graph, std::vector outputs) { } auto outs = vop->method->lower(inps, vop->inputs, vbindings); TORCH_CHECK(outs.size() == vop->outputs.size()); - for (auto i = 0; i < outs.size(); ++i) { + for (auto i = 0U; i < outs.size(); ++i) { bindings[vop->outputs[i]] = outs[i]; } } diff --git a/test/cpp/tensorexpr/test_type.cpp b/test/cpp/tensorexpr/test_type.cpp index 1dc0d1ccee35d..71ad0f5149ace 100644 --- a/test/cpp/tensorexpr/test_type.cpp +++ b/test/cpp/tensorexpr/test_type.cpp @@ -1,4 +1,6 @@ -#include "test/cpp/tensorexpr/test_base.h" +#include + +#include "torch/csrc/jit/tensorexpr/eval.h" #include "torch/csrc/jit/tensorexpr/ir.h" #include "torch/csrc/jit/tensorexpr/tensor.h" @@ -6,7 +8,7 @@ namespace torch { namespace jit { using namespace torch::jit::tensorexpr; -void testTypeTest01() { +TEST(Type, Test01) { KernelScope kernel_scope; { Dtype dt1 = kInt; @@ -41,7 +43,116 @@ void testTypeTest01() { } } -void testTypePropagation() { +TEST(Type, BitCasting) { + { + KernelScope kernel_scope; + VarHandle x("x", kFloat); + ExprHandle y = bitcast(x); + ASSERT_EQ(y.dtype(), kInt); + } + { + KernelScope kernel_scope; + VarHandle x("x", kInt); + ExprHandle y = bitcast(x); + ASSERT_EQ(y.dtype(), kFloat); + } + { + KernelScope kernel_scope; + VarHandle x("x", kShort); + ExprHandle y = bitcast(x); + ASSERT_EQ(y.dtype(), kHalf); + } + { + KernelScope kernel_scope; + VarHandle x("x", kHalf); + ExprHandle y = bitcast(x); + ASSERT_EQ(y.dtype(), kShort); + } + + constexpr int16_t ref16 = 1337; + constexpr int32_t ref32 = 1337; + constexpr int64_t ref64 = 1337; + at::Half reff16 = 1337.0f; + constexpr float reff32 = 1337.0f; + constexpr double reff64 = 1337.0f; + using SimpleIRExprEval = ExprEval; + // this is broken + /*{ + KernelScope kernel_scope; + at::Half k_; + at::Half* k = &k_; + *reinterpret_cast(k) = ref16; + auto a = HalfImm::make(*k); + auto b = BitCast::make(kShort, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), ref16); + }*/ + + { + KernelScope kernel_scope; + float k = raw_bitcast(ref32); + auto a = FloatImm::make(k); + auto b = BitCast::make(kInt, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), ref32); + } + + { + KernelScope kernel_scope; + double k = raw_bitcast(ref64); + auto a = DoubleImm::make(k); + auto b = BitCast::make(kLong, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), ref64); + } + + { + KernelScope kernel_scope; + int64_t k = raw_bitcast(reff64); + auto a = LongImm::make(k); + auto b = BitCast::make(kDouble, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), reff64); + } + + { + KernelScope kernel_scope; + int32_t k = raw_bitcast(reff32); + auto a = IntImm::make(k); + auto b = BitCast::make(kFloat, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), reff32); + } + + // This segfaults :( + /*{ + KernelScope kernel_scope; + VarHandle x("x", kDouble); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + KernelScope kernel_scope; + VarHandle x("x", kFloat); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + KernelScope kernel_scope; + VarHandle x("x", kLong); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + KernelScope kernel_scope; + VarHandle x("x", kShort); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + KernelScope kernel_scope; + VarHandle x("x", kInt); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + }*/ +} + +TEST(Type, Propagation) { // Same types: { KernelScope kernel_scope; diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h deleted file mode 100644 index 34eeaa0de19a1..0000000000000 --- a/test/cpp/tensorexpr/tests.h +++ /dev/null @@ -1,471 +0,0 @@ -#pragma once - -/** - * See README.md for instructions on how to add a new test. - */ -#include -#include - -namespace torch { -namespace jit { - -#define TH_FORALL_TENSOREXPR_TESTS(_) \ - _(ExprBasicValueTest) \ - _(ExprBasicValueTest02) \ - _(ExprLetTest01) \ - _(ExprLetStmtTest01) \ - _(ExprLetTest02) \ - _(ExprIntTest) \ - _(ExprFloatTest) \ - _(ExprByteTest) \ - _(ExprCharTest) \ - _(ExprShortTest) \ - _(ExprLongTest) \ - _(ExprHalfTest) \ - _(ExprDoubleTest) \ - _(ExprDisallowBoolArithmetic) \ - _(ExprVectorAdd01) \ - _(ExprCompareSelectEQ) \ - _(ExprCompareSelectDtypes) \ - _(ExprIntrinsicsDtypes) \ - _(ExprSubstitute01) \ - _(ExprMath01) \ - _(ExprUnaryMath01) \ - _(ExprBinaryMath01) \ - _(ExprDynamicShapeAdd) \ - _(ExprBitwiseOps) \ - _(IRPrinterBasicValueTest) \ - _(IRPrinterBasicValueTest02) \ - _(IRPrinterCastTest) \ - _(IRPrinterFunctionName) \ - _(ExprSimple01) \ - _(ExprLower01) \ - _(ExprSimple02) \ - _(ExprSliceHead) \ - _(ExprSliceHeadWhenFactorEqualsSize) \ - _(ExprSliceHeadWhenFactorLargerThanSize) \ - _(ExprSliceHeadWithLoopOptions) \ - _(ExprSliceHeadWithNonZeroStart) \ - _(ExprSliceTail) \ - _(ExprSliceTailWhenFactorEqualsSize) \ - _(ExprSliceTailWhenFactorLargerThanSize) \ - _(ExprSliceTailWithLoopOptions) \ - _(ExprSliceAndNormalize) \ - _(ExprSliceWithVariableDimension) \ - _(ExprSplitAndSlice) \ - _(ExprSplitWithTail) \ - _(ExprSplitWithTailNone) \ - _(ExprSplitWithMask01) \ - _(ExprSplitWithMaskRepeatedNoMask) \ - _(SplitWithTailWithLoopOptions) \ - _(SplitWithMaskWithLoopOptions) \ - _(ScheduleBroadcastAddBuffer) \ - _(ScheduleFunctionCall01) \ - _(ScheduleInlineSimple) \ - _(ScheduleInlineFunc01) \ - _(ScheduleInlineRandom) \ - _(ScheduleInlineRandomUnrelated) \ - _(ScheduleInlineRandomLowerDimensions) \ - _(ScheduleInlineIntrinsics) \ - _(ScheduleInlineRandWithIntrinsics) \ - _(ScheduleSplitAThenInline) \ - _(ScheduleSplitBThenInline) \ - _(ScheduleSplitTwiceThenInline) \ - _(ScheduleInlineThenSplit) \ - _(ScheduleSplitInlineThenSplit) \ - _(ScheduleSplitInlineSimplify) \ - _(ScheduleInlineThreeMixedOnce) \ - _(ScheduleInlineThreeMixedTwice) \ - _(ScheduleInlineThreeMixedInner) \ - _(ScheduleInlineThreeMixedSplit) \ - _(ScheduleFuserStyle) \ - _(ScheduleFuserThreeArg) \ - _(ScheduleDynamicShape2D) \ - _(ReduceSum1D) \ - _(ReduceSum2D) \ - _(ReduceSum3D) \ - _(ReduceSum10D) \ - _(ReduceProduct) \ - _(ReduceMax) \ - _(ReduceMinCustomInitializer) \ - _(ReduceAnyAll) \ - _(ReduceMatmul2D) \ - _(ReduceRfactorLike) \ - _(ReduceAsProducer) \ - _(ReduceAsConsumer) \ - _(SplitReduceAxis) \ - _(SplitNonReduceAxis) \ - _(ReorderedReductionInitializer) \ - _(ReduceRfactor) \ - _(Reduce3DRfactorInternal) \ - _(Reduce3DRfactorInner) \ - _(Reduce3DRfactorOuter) \ - _(Reduce3DRfactorWithOuter) \ - _(Reduce3DRfactorRepeated) \ - _(ReduceRfactorInsertionPoint) \ - _(Reduce3DRfactorInsertionPoint) \ - _(ReduceRepeatedInternalRfactor) \ - _(ReduceSplitTail) \ - _(ReduceSplitNoTail) \ - _(ReduceOverSplitTail) \ - _(ReduceSplitMask) \ - _(ReduceSplitNoMask) \ - _(ReduceOverSplitMask) \ - _(ReduceSplitRfactor) \ - _(ReduceOverSplitRfactor) \ - _(ReduceInlineReduction) \ - _(ReduceInlineConsumer) \ - _(ReduceInlineReducerInternal) \ - _(TypeTest01) \ - _(TypePropagation) \ - _(Cond01) \ - _(IfThenElse01) \ - _(IfThenElse02) \ - _(IfThenElse03) \ - _(ATen_cast_Float) \ - _(ATennegInt) \ - _(ATennegFloat) \ - _(ATenaddInt) \ - _(ATenaddFloat) \ - _(ATensubInt) \ - _(ATensubFloat) \ - _(ATenlerp) \ - _(ATenaddcmulInt) \ - _(ATenaddcmulFloat) \ - _(ATenmulInt) \ - _(ATenmulFloat) \ - _(ATendivInt) \ - _(ATendivFloat) \ - _(ATenmaxInt) \ - _(ATenmaxFloat) \ - _(ATenminInt) \ - _(ATenminFloat) \ - _(ATenreciprocal) \ - _(ATenreluInt) \ - _(ATenreluFloat) \ - _(ATenlogFloat) \ - _(ATenlog10Float) \ - _(ATenlog2Float) \ - _(ATenexpFloat) \ - _(ATenerfFloat) \ - _(ATencosFloat) \ - _(ATeneqInt) \ - _(ATengeInt) \ - _(ATengtInt) \ - _(ATenleInt) \ - _(ATenltInt) \ - _(ConstantFoldSimple) \ - _(ConstantFoldTwoLayer) \ - _(ConstantFoldShifts) \ - _(ConstantFoldBitwise) \ - _(ConstantFoldMultiOp) \ - _(ConstantFoldMinMax) \ - _(ConstantFoldIntrinsics) \ - _(ConstantFoldCastToBool) \ - _(ConstantFoldWithVar) \ - _(ConditionalSelectFoldSimple) \ - _(ConditionalSelectFoldTwoLayer) \ - _(ConditionalSelectFoldWithVar) \ - _(UnFoldableExpr) \ - _(HashSimple) \ - _(HashEquivalence) \ - _(HashEquivalenceRand) \ - _(HashEquivalenceAfterFolding) \ - _(HashDifferenceTypes) \ - _(HashLargeExpression) \ - _(HashForLoopOptions) \ - _(SimplifyAdd) \ - _(SimplifySub) \ - _(SimplifyMultiLayer) \ - _(SimplifyMultiTerm) \ - _(SimplifyCasts) \ - _(SimplifyEliminatesNoOps) \ - _(SimplifyMultiVar) \ - _(SimplifyEliminatesVar) \ - _(SimplifyAdds) \ - _(SimplifyMuls) \ - _(SimplifySubs) \ - _(SimplifyDiv) \ - _(SimplifyMultiOp) \ - _(SimplifyManyOps) \ - _(SimplifyFactorization) \ - _(SimplifyFactorizeUneven) \ - _(SimplifyDeeperTerms) \ - _(SimplifyDeeperDifference) \ - _(SimplifyFoldComplexDifference) \ - _(SimplifyIfComponents) \ - _(SimplifyOpaqueTerms) \ - _(SimplifySymbolicMinMax) \ - _(SimplifyNestedMax) \ - _(SimplifyNestedMin) \ - _(SimplifyWontReorderFloat) \ - _(SimplifyRoundModPattern) \ - _(SimplifyRoundModPatternFactorization) \ - _(SimplifyRoundModPatternMultivar) \ - _(SimplifyDivisionScalarFactorization) \ - _(SimplifyConstantBranches) \ - _(SimplifyConstantCond) \ - _(SimplifyEliminateEmptyCond) \ - _(SimplifyEliminateZeroLengthFor) \ - _(SimplifyOneLoopFor) \ - _(SimplifyForWontLoseLoopOptions) \ - _(SimplifyMultilevelFor) \ - _(SimplifyForCleansUp) \ - _(SimplifyEliminateEmptyFor) \ - _(SimplifyFlattenBlock) \ - _(SimplifyEliminateZeroLengthAlloc) \ - _(DontSimplifyRand) \ - _(SimplifyReorderForCond) \ - _(SimplifyFuseConditions) \ - _(SimplifySyncThreads) \ - _(SimplifyRampSubBroadcast) \ - _(SimplifyBroadcastTermExpander) \ - _(RegisterizerSimple) \ - _(RegisterizerLoop) \ - _(RegisterizerLoopFixedLoad) \ - _(RegisterizerMultiVar) \ - _(RegisterizerVariableLoad) \ - _(RegisterizerSymbolicIndices) \ - _(RegisterizerEarlyStop) \ - _(RegisterizerMultiLoop) \ - _(RegisterizerRepeated) \ - _(RegisterizerNoLoads) \ - _(RegisterizerNoRepeatedStores) \ - _(RegisterizerMultiVarOverlap) \ - _(RegisterizerAllocs) \ - _(RegisterizerNoInitializer) \ - _(RegisterizerLoadThenStore) \ - _(RegisterizerParallelized) \ - _(RegisterizerConditions) \ - _(StmtClone) \ - _(BoundsInference_1) \ - _(BoundsInference_2) \ - _(BoundsInference_3) \ - _(BoundsInference_4) \ - _(BoundsInference_5) \ - _(BoundsInference_6) \ - _(BoundsInferenceNonOverlapping) \ - _(BoundsInferenceAdjacent) \ - _(MergeInferredBounds) \ - _(MergeInferredLoadStoreDiff) \ - _(MergeInferred2DBounds) \ - _(MergeAdjacentBounds) \ - _(MergeSymbolicBounds) \ - _(MergeSymbolicAdjacent) \ - _(LoopNestComputeAt_1) \ - _(LoopNestComputeAt_2) \ - _(LoopNestComputeAt_3) \ - _(LoopNestComputeAt_4) \ - _(LoopNestReorderAxis1) \ - _(LoopNestReorderPartialAxes) \ - _(LoopNestReorderInternalAxis) \ - _(LoopNestReorderEnclosingAxis) \ - _(LoopNestReorderSameAxis) \ - _(LoopNestReorderExtraStatements) \ - _(LoopNestReorderLongStringOfPreOrphans) \ - _(LoopNestReorderLongStringOfPostOrphans) \ - _(LoopNestReorderLongStringFull) \ - _(LoopNestReorderInternalLoopNest) \ - _(OuterLoopVectorization) \ - _(Unroll) \ - _(UnrollOuter) \ - _(UnrollInner) \ - _(UnrollMultipleStatements) \ - _(UnrollEmpty) \ - _(NoUnroll) \ - _(UnrollWithLet) \ - _(NormalizeStartPositive) \ - _(NormalizeStartNegative) \ - _(NormalizeStartZero) \ - _(NormalizeStartVariable) \ - _(NormalizeOnNestedOuterLoop) \ - _(NormalizeOnNestedInnerLoop) \ - _(NormalizeAndSplitWithTail) \ - _(DetectInlineRankMismatch) \ - _(Kernel_1) \ - _(Kernel_2) \ - _(Kernel_3) \ - _(Kernel_4) \ - _(KernelSumAllAxes) \ - _(KernelSumOneAxis) \ - _(KernelSumMultipleAxes) \ - _(FuserPass_1) \ - _(FuserPass_2) \ - _(FuserPass_3) \ - _(FuserPass_0DimInput) \ - _(FuserPass_UnfusibleDevice) \ - _(FuserPass_UnknownShapes) \ - _(FuserPass_UnknownShapesIgnored) \ - _(FuserPass_Multidevice) \ - _(FuserPass_MergeGroups) \ - _(TrainBasic) - -#define TH_FORALL_TENSOREXPR_TESTS_LLVM(_) \ - _(LLVMByteImmTest) \ - _(LLVMCharImmTest) \ - _(LLVMShortImmTest) \ - _(LLVMIntImmTest) \ - _(LLVMLongImmTest) \ - _(LLVMFloatImmTest) \ - _(LLVMDoubleImmTest) \ - _(LLVMHalfImmTest) \ - _(LLVMByteAddTest) \ - _(LLVMCharAddTest) \ - _(LLVMShortAddTest) \ - _(LLVMIntAddTest) \ - _(LLVMLongAddTest) \ - _(LLVMFloatAddTest) \ - _(LLVMDoubleAddTest) \ - _(LLVMHalfAddTest) \ - _(LLVMByteSubTest) \ - _(LLVMCharSubTest) \ - _(LLVMShortSubTest) \ - _(LLVMIntSubTest) \ - _(LLVMLongSubTest) \ - _(LLVMFloatSubTest) \ - _(LLVMDoubleSubTest) \ - _(LLVMHalfSubTest) \ - _(LLVMByteMulTest) \ - _(LLVMCharMulTest) \ - _(LLVMShortMulTest) \ - _(LLVMIntMulTest) \ - _(LLVMLongMulTest) \ - _(LLVMFloatMulTest) \ - _(LLVMDoubleMulTest) \ - _(LLVMHalfMulTest) \ - _(LLVMByteDivTest) \ - _(LLVMCharDivTest) \ - _(LLVMShortDivTest) \ - _(LLVMIntDivTest) \ - _(LLVMLongDivTest) \ - _(LLVMFloatDivTest) \ - _(LLVMDoubleDivTest) \ - _(LLVMHalfDivTest) \ - _(LLVMIntToFloatCastTest) \ - _(LLVMFloatToIntCastTest) \ - _(LLVMIntToLongCastTest) \ - _(LLVMByteToCharCastTest) \ - _(LLVMHalfToLongCastTest) \ - _(LLVMByteToDoubleCastTest) \ - _(LLVMLetTest01) \ - _(LLVMLetTest02) \ - _(LLVMLetTestMultitype) \ - _(LLVMBufferTest) \ - _(LLVMBlockTest) \ - _(LLVMLoadStoreTest) \ - _(LLVMVecLoadStoreTest) \ - _(LLVMVecFloat_acosLane4Test) \ - _(LLVMVecFloat_asinLane4Test) \ - _(LLVMVecFloat_atanLane4Test) \ - _(LLVMVecFloat_coshLane4Test) \ - _(LLVMVecFloat_sinhLane4Test) \ - _(LLVMVecFloat_tanhLane4Test) \ - _(LLVMVecFloat_erfLane4Test) \ - _(LLVMVecFloat_erfcLane4Test) \ - _(LLVMVecFloat_expm1Lane4Test) \ - _(LLVMVecFloat_lgammaLane4Test) \ - _(LLVMVecFloat_acosLane8Test) \ - _(LLVMVecFloat_asinLane8Test) \ - _(LLVMVecFloat_atanLane8Test) \ - _(LLVMVecFloat_coshLane8Test) \ - _(LLVMVecFloat_sinhLane8Test) \ - _(LLVMVecFloat_tanhLane8Test) \ - _(LLVMVecFloat_erfLane8Test) \ - _(LLVMVecFloat_erfcLane8Test) \ - _(LLVMVecFloat_expm1Lane8Test) \ - _(LLVMVecFloat_lgammaLane8Test) \ - _(LLVMVecDouble_acosLane2Test) \ - _(LLVMVecDouble_asinLane2Test) \ - _(LLVMVecDouble_atanLane2Test) \ - _(LLVMVecDouble_coshLane2Test) \ - _(LLVMVecDouble_sinhLane2Test) \ - _(LLVMVecDouble_tanhLane2Test) \ - _(LLVMVecDouble_erfLane2Test) \ - _(LLVMVecDouble_erfcLane2Test) \ - _(LLVMVecDouble_expm1Lane2Test) \ - _(LLVMVecDouble_lgammaLane2Test) \ - _(LLVMVecDouble_acosLane4Test) \ - _(LLVMVecDouble_asinLane4Test) \ - _(LLVMVecDouble_atanLane4Test) \ - _(LLVMVecDouble_coshLane4Test) \ - _(LLVMVecDouble_sinhLane4Test) \ - _(LLVMVecDouble_tanhLane4Test) \ - _(LLVMVecDouble_erfLane4Test) \ - _(LLVMVecDouble_erfcLane4Test) \ - _(LLVMVecDouble_expm1Lane4Test) \ - _(LLVMVecDouble_lgammaLane4Test) \ - _(LLVMMemcpyTest) \ - _(LLVMBzeroTest) \ - _(LLVMElemwiseAdd) \ - _(LLVMElemwiseAddFloat) \ - _(LLVMElemwiseLog10Float) \ - _(LLVMElemwiseLog1pFloat) \ - _(LLVMElemwiseMaxInt) \ - _(LLVMElemwiseMinInt) \ - _(LLVMElemwiseMaxFloat) \ - _(LLVMElemwiseMaxNaNFloat) \ - _(LLVMElemwiseMinFloat) \ - _(LLVMElemwiseMinNaNFloat) \ - _(LLVMElemwiseMod) \ - _(LLVMCompareSelectIntEQ) \ - _(LLVMCompareSelectFloatEQ) \ - _(LLVMCompareSelectByteGT) \ - _(LLVMCompareSelectByteGE) \ - _(LLVMCompareSelectByteLT) \ - _(LLVMCompareSelectByteLE) \ - _(LLVMStoreFloat) \ - _(LLVMSimpleMath01) \ - _(LLVMComputeMul) \ - _(LLVMBroadcastAdd) \ - _(LLVMBitwiseOps) \ - _(LLVMDynamicShapeAdd) \ - _(LLVMBindDynamicShapeAdd) \ - _(LLVMTensorDynamicShapeAdd) \ - _(LLVMDynamicShape2D) \ - _(LLVMEmptyStmt) \ - _(LLVMEliminatedStmt) \ - _(LLVMIfThenElseTest) \ - _(LLVMVectorizerLoadStoreTest) \ - _(LLVMSimpleReduction) \ - _(LLVMRFactorReduction) - -// _(LLVMRFactorVectorizedReduction) - -#define TH_FORALL_TENSOREXPR_TESTS_CUDA(_) \ - _(CudaTestVectorAdd01) \ - _(CudaTestVectorAdd02) \ - _(CudaDynamicShape2D) \ - _(CudaDynamicShapeSplit) \ - _(CudaOneBlockOneThreadGlobalReduce1) \ - _(CudaOneBlockMultiThreadGlobalReduce1) \ - _(CudaNoThreadIdxWrite_1) \ - _(CudaLocalMemReduce_1) \ - _(CudaSharedMemReduce_1) \ - _(CudaTestRand01) \ - _(CudaSigmoid) \ - _(CudaHalfCast) \ - _(CudaHalfSupport) \ - _(CudaPrioritizeDependents) \ - _(CudaMaskBlockDim) \ - _(CudaMaskThreadDim) \ - _(CudaMaskMultiBlockDim) \ - _(CudaMaskBlockAndThreadDim) \ - _(CudaMaskMultiDim) \ - _(CudaMaskMultiDimSymbolic) \ - _(CudaMaskCompoundInnerLoop) \ - _(CudaMaskInnerLoopOneBlock) \ - _(CudaMaskMultiDimMultiAxis) \ - _(CudaMaskMultiDimMultiLevel) - -#define DECLARE_TENSOREXPR_TEST(name) void test##name(); -TH_FORALL_TENSOREXPR_TESTS(DECLARE_TENSOREXPR_TEST) -#ifdef TORCH_ENABLE_LLVM -TH_FORALL_TENSOREXPR_TESTS_LLVM(DECLARE_TENSOREXPR_TEST) -#endif -#ifdef USE_CUDA -TH_FORALL_TENSOREXPR_TESTS_CUDA(DECLARE_TENSOREXPR_TEST) -#endif -#undef DECLARE_TENSOREXPR_TEST - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/tutorial.cpp b/test/cpp/tensorexpr/tutorial.cpp new file mode 100644 index 0000000000000..31e05549186e5 --- /dev/null +++ b/test/cpp/tensorexpr/tutorial.cpp @@ -0,0 +1,394 @@ +// *** Tensor Expressions *** +// +// This tutorial covers basics of NNC's tensor expressions, shows basic APIs to +// work with them, and outlines how they are used in the overall TorchScript +// compilation pipeline. This doc is permanently a "work in progress" since NNC +// is under active development and things change fast. +// +// This Tutorial's code is compiled in the standard pytorch build, and the +// executable can be found in `build/bin/tutorial_tensorexpr`. +// +// *** What is NNC *** +// +// NNC stands for Neural Net Compiler. It is a component of TorchScript JIT +// and it performs on-the-fly code generation for kernels, which are often a +// combination of multiple aten (torch) operators. +// +// When the JIT interpreter executes a torchscript model, it automatically +// extracts subgraphs from the torchscript IR graph for which specialized code +// can be JIT generated. This usually improves performance as the 'combined' +// kernel created from the subgraph could avoid unnecessary memory traffic that +// is unavoidable when the subgraph is interpreted as-is, operator by operator. +// This optimization is often referred to as 'fusion'. Relatedly, the process of +// finding and extracting subgraphs suitable for NNC code generation is done by +// a JIT pass called 'fuser'. +// +// *** What is TE *** +// +// TE stands for Tensor Expressions. TE is a commonly used approach for +// compiling kernels performing tensor (~matrix) computation. The idea behind it +// is that operators are represented as a mathematical formula describing what +// computation they do (as TEs) and then the TE engine can perform mathematical +// simplification and other optimizations using those formulas and eventually +// generate executable code that would produce the same results as the original +// sequence of operators, but more efficiently. +// +// NNC's design and implementation of TE was heavily inspired by Halide and TVM +// projects. +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +using namespace torch::jit::tensorexpr; + +int main(int argc, char* argv[]) { + // Memory management for tensor expressions is currently done with memory + // arenas. That is, whenever an object is created it registers itself in an + // arena and the object is kept alive as long as the arena is alive. When the + // arena gets destructed, it deletes all objects registered in it. + // + // The easiest way to set up a memory arena is to use `KernelScope` class - it + // is a resource guard that creates a new arena on construction and restores + // the previously set arena on destruction. + // + // We will create a kernel scope here, and thus we'll set up a mem arena for + // the entire tutorial. + KernelScope kernel_scope; + + std::cout << "*** Structure of tensor expressions ***" << std::endl; + { + // A tensor expression is a tree of expressions. Each expression has a type, + // and that type defines what sub-expressions it the current expression has. + // For instance, an expression of type 'Mul' would have a type 'kMul' and + // two subexpressions: LHS and RHS. Each of these two sub-expressions could + // also be a 'Mul' or some other expression. + // + // Let's construct a simple TE: + Expr* lhs = new IntImm(5); + Expr* rhs = new Var("x", kInt); + Expr* mul = new Mul(lhs, rhs); + std::cout << "Tensor expression: " << *mul << std::endl; + // Prints: Tensor expression: 5 * x + + // Here we created an expression representing a 5*x computation, where x is + // an int variable. + + // Another, probably a more convenient, way to construct tensor expressions + // is to use so called expression handles (as opposed to raw expressions + // like we did in the previous example). Expression handles overload common + // operations and allow us to express the same semantics in a more natural + // way: + ExprHandle l = 1; + ExprHandle r = Var::make("x", kInt); + ExprHandle m = l * r; + std::cout << "Tensor expression: " << *m.node() << std::endl; + // Prints: Tensor expression: 1 * x + + // In a similar fashion we could construct arbitrarily complex expressions + // using mathematical and logical operations, casts between various data + // types, and a bunch of intrinsics. + ExprHandle a = Var::make("a", kInt); + ExprHandle b = Var::make("b", kFloat); + ExprHandle c = Var::make("c", kFloat); + ExprHandle x = ExprHandle(5) * a + b / (sigmoid(c) - 3.0f); + std::cout << "Tensor expression: " << *x.node() << std::endl; + // Prints: Tensor expression: float(5 * a) + b / ((sigmoid(c)) - 3.f) + + // An ultimate purpose of tensor expressions is to optimize tensor + // computations, and in order to represent accesses to tensors data, there + // is a special kind of expression - a load. + // To construct a load we need two pieces: the base and the indices. The + // base of a load is a Buf expression, which could be thought of as a + // placeholder similar to Var, but with dimensions info. + // + // Let's construct a simple load: + BufHandle A("A", {ExprHandle(64), ExprHandle(32)}, kInt); + ExprHandle i = Var::make("i", kInt), j = Var::make("j", kInt); + ExprHandle load = Load::make(A.dtype(), A, {i, j}, /* mask= */ 1); + std::cout << "Tensor expression: " << *load.node() << std::endl; + // Prints: Tensor expression: A[i, j] + } + + std::cout << "*** Tensors, Functions, and Placeholders ***" << std::endl; + { + // A tensor computation is represented by objects of Tensor class and + // consists of the following pieces: + // - domain, which is specified by a Buf expression + // - an expression (or several expressions if we want to perform several + // independent computations over the same domain) for its elements, as a + // function of indices + // + // TODO: Update this section once Tensor/Function cleanup is done + std::vector dims = { + new IntImm(64), new IntImm(32)}; // IntImm stands for Integer Immediate + // and represents an integer constant + + // Next we need to create arguments. The arguments are Vars, and they play + // role of placeholders. The computation that the tensor would describe + // would use these arguments. + const Var* i = new Var("i", kInt); + const Var* j = new Var("j", kInt); + std::vector args = {i, j}; + + // Now we can define the body of the tensor computation using these + // arguments. + Expr* body = new Mul(i, j); + + // Finally, we pass all these pieces together to Tensor constructor: + Tensor* X = new Tensor("X", dims, args, body); + std::cout << "Tensor computation: " << *X << std::endl; + // Prints: Tensor computation: Tensor X(i[64], j[32]) = i * j + + // Similarly to how we provide a more convenient way of using handles for + // constructing Exprs, Tensors also have a more convenient API for + // construction. It is based on Compute API, which takes a name, + // dimensions, and a lambda specifying the computation body: + Tensor* Z = Compute( + "Z", + {{64, "i"}, {32, "j"}}, + [](const VarHandle& i, const VarHandle& j) { return i / j; }); + std::cout << "Tensor computation: " << *Z << std::endl; + // Prints: Tensor computation: Tensor Z(i[64], j[32]) = i / j + + // Tensors might access other tensors and external placeholders in their + // expressions. It can be done like so: + Placeholder P("P", kFloat, {64, 32}); + Tensor* R = Compute( + "R", + {{64, "i"}, {32, "j"}}, + [&](const VarHandle& i, const VarHandle& j) { + return Z->call(i, j) * P.load(i, j); + }); + std::cout << "Tensor computation: " << *R << std::endl; + // Prints: Tensor computation: Tensor R(i[64], j[32]) = Z(i, j) * P[i, j] + + // Placeholders could be thought of as external tensors, i.e. tensors for + // which we don't have the element expression. In other words, for `Tensor` + // we know an expression specifying how its elements can be computed (a + // mathematical formula). For external tensors, or placeholders, we don't + // have such an expression. They need to be considered as coming to us as + // inputs from outside - we can only load data from them. + // + // Also note that we use 'call' to construct an access to an element of a + // Tensor and we use 'load' for accessing elements of an external tensor + // through its Placeholder. This is an implementation detail and could be + // changed in future. + + // TODO: Show how reductions are represented and constructed + } + + std::cout << "*** Loopnests and Statements ***" << std::endl; + { + // Creating a tensor expression is the first step to generate an executable + // code for it. A next step is to represent it as a loop nest and apply + // various loop transformations in order to get an optimal implementation. + // In Halide's or TVM's terms the first step was to define the algorithm of + // computation (what to compute?) and now we are getting to the schedule of + // the computation (how to compute?). + // + // Let's create a simple tensor expression and construct a loop nest for it. + Placeholder A("A", kFloat, {64, 32}); + Placeholder B("B", kFloat, {64, 32}); + Tensor* X = Compute( + "X", + {{64, "i"}, {32, "j"}}, + [&](const VarHandle& i, const VarHandle& j) { + return A.load(i, j) + B.load(i, j); + }); + Tensor* Y = Compute( + "Y", + {{64, "i"}, {32, "j"}}, + [&](const VarHandle& i, const VarHandle& j) { + return sigmoid(X->call(i, j)); + }); + std::cout << "Tensor computation X: " << *X + << "Tensor computation Y: " << *Y << std::endl; + // Prints: + // Tensor computation X: Tensor X(i[64], j[32]) = (A[i, j]) + (B[i, j]) + // Tensor computation Y: Tensor Y(i[64], j[32]) = sigmoid(X(i, j)) + + // Creating a loop nest is as quite simple, we just need to specify what are + // the output tensors in our computation and LoopNest object will + // automatically pull all tensor dependencies: + LoopNest loopnest({Y}); + + // An IR used in LoopNest is based on tensor statements, represented by + // `Stmt` class. Statements are used to specify the loop nest structure, and + // to take a sneak peek at them, let's print out what we got right after + // creating our LoopNest object: + std::cout << *loopnest.root_stmt() << std::endl; + // Prints: + // { + // for (int i = 0; i < 64; i++) { + // for (int j = 0; j < 32; j++) { + // X[i, j] = (A[i, j]) + (B[i, j]); + // } + // } + // for (int i_1 = 0; i_1 < 64; i_1++) { + // for (int j_1 = 0; j_1 < 32; j_1++) { + // Y[i_1, j_1] = sigmoid(X(i_1, j_1)); + // } + // } + // } + + // To introduce statements let's first look at their three main types (in + // fact, there are more than 3 types, but the other types would be easy to + // understand once the overall structure is clear): + // 1) Block + // 2) For + // 3) Store + // + // A `Block` statement is simply a list of other statements. + // A `For` is a statement representing one axis of computation. It contains + // an index variable (Var), boundaries of the axis (start and end - both are + // `Expr`s), and a `Block` statement body. + // A `Store` represents an assignment to a tensor element. It contains a Buf + // representing the target tensor, a list of expressions for indices of the + // element, and the value to be stored, which is an arbitrary expression. + + // Once we've constructed the loop nest, we can apply various tranformations + // to it. To begin with, let's inline computation of X into computation of Y + // and see what happens to our statements. + loopnest.computeInline(loopnest.getLoopBodyFor(X)); + std::cout << *loopnest.root_stmt() << std::endl; + // Prints: + // { + // for (int i = 0; i < 64; i++) { + // for (int j = 0; j < 32; j++) { + // Y[i, j] = sigmoid((A[i, j]) + (B[i, j])); + // } + // } + // } + // + // As you can see, the first two loops have disappeared and the expression + // for X[i,j] has been inserted into the Y[i,j] computation. + + // Loop transformations can be composed, so we can do something else with + // our loop nest now. Let's split the inner loop with a factor of 9, for + // instance. + std::vector loops = loopnest.getLoopStmtsFor(Y); + For* j_outer; + For* j_inner; + For* j_tail; + int split_factor = 9; + loopnest.splitWithTail( + loops[1], // loops[0] is the outer loop, loops[1] is inner + split_factor, + &j_outer, // These are handles that we would be using for + &j_inner, // further transformations + &j_tail); + std::cout << *loopnest.root_stmt() << std::endl; + // Prints: + // { + // for (int i = 0; i < 64; i++) { + // for (int j_outer = 0; j_outer < (32 - 0) / 9; j_outer++) { + // for (int j_inner = 0; j_inner < 9; j_inner++) { + // Y[i, j_outer * 9 + j_inner] = sigmoid((A[i, j_outer * 9 + ... + // } + // } + // for (int j_tail = 0; j_tail < (32 - 0) % 9; j_tail++) { + // Y[i, j_tail + ((32 - 0) / 9) * 9] = sigmoid((A[i, j_tail + ... + // } + // } + // } + + // TODO: List all available transformations + // TODO: Show how statements can be constructed manually + } + + std::cout << "*** Codegen ***" << std::endl; + { + // An ultimate goal of tensor expressions is to be provide a mechanism to + // execute a given computation in the fastest possible way. So far we've + // looked at how we could describe what computation we're interested in, but + // we haven't looked at how to actually execute it. So far all we've been + // dealing with was just symbols with no actual data associated, in this + // section we would look at how we can bridge that gap. + + // Let's start by constructing a simple computation for us to work with: + Placeholder A("A", kInt, {64, 32}); + Placeholder B("B", kInt, {64, 32}); + Tensor* X = Compute( + "X", + {{64, "i"}, {32, "j"}}, + [&](const VarHandle& i, const VarHandle& j) { + return A.load(i, j) + B.load(i, j); + }); + + // And let's lower it to a loop nest, as we did in the previous section: + LoopNest loopnest({X}); + std::cout << *loopnest.root_stmt() << std::endl; + // Prints: + // { + // for (int i = 0; i < 64; i++) { + // for (int j = 0; j < 32; j++) { + // X[i, j] = (A[i, j]) + (B[i, j]); + // } + // } + + // Now imagine that we have two actual tensors 64x32 that we want sum + // together, how do we pass those tensors to the computation and how do we + // carry it out? + // + // Codegen object is aimed at providing exactly that functionality. Codegen + // is an abstract class and concrete codegens are derived from it. + // Currently, we have three codegens: + // 1) Simple Evaluator, + // 2) LLVM Codegen for CPU, + // 3) CUDA Codegen. + // In this example we will be using Simple Evaluator, since it's available + // everywhere. + + // To create a codegen, we need to provide the statement - it specifies the + // computation we want to perform - and a list of placeholders and tensors + // used in the computation. The latter part is crucial since that's the only + // way the codegen could use to correlate symbols in the statement to actual + // data arrays that we will be passing when we will actually be performing + // the computation. + // + // Let's create a Simple IR Evaluator codegen for our computation: + SimpleIREvaluator ir_eval(loopnest.root_stmt(), {A, B, X}); + + // We are using the simplest codegen and in it almost no work is done at the + // construction step. Real codegens such as CUDA and LLVM perform + // compilation during that stage so that when we're about to run the + // computation everything is ready. + + // Let's now create some inputs and run our computation with them: + std::vector data_A(64 * 32, 3); // This will be the input A + std::vector data_B(64 * 32, 5); // This will be the input B + std::vector data_X(64 * 32, 0); // This will be used for the result + + // Now let's invoke our codegen to perform the computation on our data. We + // need to provide as many arguments as how many placeholders and tensors we + // passed at the codegen construction time. A position in these lists would + // define how real data arrays from the latter call (these arguments are + // referred to as 'CallArg's in our codebase) correspond to symbols + // (placeholders and tensors) used in the tensor expressions we constructed + // (these are referred to as 'BufferArg'). + // Thus, we will provide three arguments: data_A, data_B, and data_X. data_A + // contains data for the placeholder A, data_B - for the placeholder B, and + // data_X would be used for contents of tensor X. + ir_eval(data_A, data_B, data_X); + + // Let's print one of the elements from each array to verify that the + // computation did happen: + std::cout << "A[10] = " << data_A[10] << std::endl + << "B[10] = " << data_B[10] << std::endl + << "X[10] = A[10] + B[10] = " << data_X[10] << std::endl; + // Prints: + // A[10] = 3 + // B[10] = 5 + // X[10] = A[10] + B[10] = 8 + } + + // TODO: Show how TorchScript IR is translated to TE + return 0; +} diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md index b7ec61a5a9589..55d3f33f32b2d 100644 --- a/test/cpp_api_parity/parity-tracker.md +++ b/test/cpp_api_parity/parity-tracker.md @@ -88,11 +88,11 @@ torch::nn::GRU|Yes|No torch::nn::RNNCell|Yes|No torch::nn::LSTMCell|Yes|No torch::nn::GRUCell|Yes|No -torch::nn::Transformer|No|No +torch::nn::Transformer|Yes|No torch::nn::TransformerEncoder|No|No torch::nn::TransformerDecoder|No|No -torch::nn::TransformerEncoderLayer|No|No -torch::nn::TransformerDecoderLayer|No|No +torch::nn::TransformerEncoderLayer|Yes|No +torch::nn::TransformerDecoderLayer|Yes|No torch::nn::Identity|Yes|No torch::nn::Linear|Yes|No torch::nn::Bilinear|Yes|No @@ -125,6 +125,7 @@ torch::nn::CosineEmbeddingLoss|Yes|No torch::nn::MultiMarginLoss|Yes|No torch::nn::TripletMarginLoss|Yes|No torch::nn::PixelShuffle|Yes|No +torch::nn::PixelUnshuffle|Yes|No torch::nn::Upsample|Yes|No torch::nn::DataParallel|No|No torch::nn::parallel::DistributedDataParallel|No|No diff --git a/test/cpp_extensions/cpp_c10d_extension.cpp b/test/cpp_extensions/cpp_c10d_extension.cpp index 188484cf9248f..d01d07f208d73 100644 --- a/test/cpp_extensions/cpp_c10d_extension.cpp +++ b/test/cpp_extensions/cpp_c10d_extension.cpp @@ -23,96 +23,96 @@ ProcessGroupTest::ProcessGroupTest(int rank, int size) ProcessGroupTest::~ProcessGroupTest() {} -std::shared_ptr ProcessGroupTest::broadcast( +c10::intrusive_ptr ProcessGroupTest::broadcast( std::vector& tensors, const BroadcastOptions& opts) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr ProcessGroupTest::allreduce( +c10::intrusive_ptr ProcessGroupTest::allreduce( std::vector& tensors, const AllreduceOptions& opts) { - return std::make_shared(); + return c10::make_intrusive(); } -std::shared_ptr ProcessGroupTest::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupTest::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support allreduce_coalesced"); } -std::shared_ptr ProcessGroupTest::reduce( +c10::intrusive_ptr ProcessGroupTest::reduce( std::vector& tensors, const ReduceOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support reduce"); } -std::shared_ptr ProcessGroupTest::allgather( +c10::intrusive_ptr ProcessGroupTest::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support allgather"); } -std::shared_ptr ProcessGroupTest::allgather_base( +c10::intrusive_ptr ProcessGroupTest::allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support allgather_base"); } -std::shared_ptr ProcessGroupTest::barrier( +c10::intrusive_ptr ProcessGroupTest::barrier( const BarrierOptions& opts) { - throw std::runtime_error("ProcessGroupTest does not support barrier"); + return c10::make_intrusive(); } -std::shared_ptr ProcessGroupTest::gather( +c10::intrusive_ptr ProcessGroupTest::gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support gather"); } -std::shared_ptr ProcessGroupTest::scatter( +c10::intrusive_ptr ProcessGroupTest::scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support scatter"); } -std::shared_ptr ProcessGroupTest::reduce_scatter( +c10::intrusive_ptr ProcessGroupTest::reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts) { throw std::runtime_error("ProcessGroupTest does not support reduce_scatter"); } -std::shared_ptr ProcessGroupTest::send( +c10::intrusive_ptr ProcessGroupTest::send( std::vector& tensors, int dstRank, int tag) { throw std::runtime_error("ProcessGroupTest does not support send"); } -std::shared_ptr ProcessGroupTest::recv( +c10::intrusive_ptr ProcessGroupTest::recv( std::vector& tensors, int srcRank, int tag) { throw std::runtime_error("ProcessGroupTest does not support recv"); } -std::shared_ptr ProcessGroupTest::recvAnysource( +c10::intrusive_ptr ProcessGroupTest::recvAnysource( std::vector& tensor, int tag) { throw std::runtime_error("ProcessGroupTest does not support recvAnysource"); } -std::shared_ptr ProcessGroupTest::createProcessGroupTest( - const std::shared_ptr<::c10d::Store>& store, +c10::intrusive_ptr ProcessGroupTest::createProcessGroupTest( + const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::duration& timeout) { - return std::make_shared(rank, size); + return c10::make_intrusive(rank, size); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { diff --git a/test/cpp_extensions/cpp_c10d_extension.hpp b/test/cpp_extensions/cpp_c10d_extension.hpp index d8dffcd20327d..6b5070e306e94 100644 --- a/test/cpp_extensions/cpp_c10d_extension.hpp +++ b/test/cpp_extensions/cpp_c10d_extension.hpp @@ -41,67 +41,67 @@ class ProcessGroupTest : public ProcessGroup { explicit ProcessGroupTest(int rank = -1, int size = -1); virtual ~ProcessGroupTest(); - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& data, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, - int tag); + int tag) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, - int tag); + int tag) override; - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensor, - int tag); + int tag) override; // Create a new ProcessGroupTest instance - static std::shared_ptr createProcessGroupTest( - const std::shared_ptr<::c10d::Store>& store, + static c10::intrusive_ptr createProcessGroupTest( + const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::duration& timeout); diff --git a/test/cpp_extensions/cuda_extension.cu b/test/cpp_extensions/cuda_extension.cu index 29511af8a0ed3..0c23d89df889d 100644 --- a/test/cpp_extensions/cuda_extension.cu +++ b/test/cpp_extensions/cuda_extension.cu @@ -6,6 +6,7 @@ #include #include +#include #include @@ -26,4 +27,5 @@ void sigmoid_add_cuda(const float* x, const float* y, float* output, int size) { const int threads = 1024; const int blocks = (size + threads - 1) / threads; sigmoid_add_kernel<<>>(x, y, output, size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } diff --git a/test/cpp_extensions/cuda_extension_kernel.cu b/test/cpp_extensions/cuda_extension_kernel.cu index 6602199898631..4a942b0a20aff 100644 --- a/test/cpp_extensions/cuda_extension_kernel.cu +++ b/test/cpp_extensions/cuda_extension_kernel.cu @@ -1,5 +1,6 @@ #include #include +#include #include @@ -20,4 +21,5 @@ void sigmoid_add_cuda(const float* x, const float* y, float* output, int size) { const int threads = 1024; const int blocks = (size + threads - 1) / threads; sigmoid_add_kernel<<>>(x, y, output, size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } diff --git a/test/cpp_extensions/cuda_extension_kernel2.cu b/test/cpp_extensions/cuda_extension_kernel2.cu index 817bdf64ac8ed..ddb240e5d067a 100644 --- a/test/cpp_extensions/cuda_extension_kernel2.cu +++ b/test/cpp_extensions/cuda_extension_kernel2.cu @@ -1,5 +1,6 @@ #include #include +#include #include @@ -20,4 +21,5 @@ void tanh_add_cuda(const float* x, const float* y, float* output, int size) { const int threads = 1024; const int blocks = (size + threads - 1) / threads; tanh_add_kernel<<>>(x, y, output, size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } diff --git a/test/cpp_extensions/msnpu_extension.cpp b/test/cpp_extensions/msnpu_extension.cpp index 62f046e0037cd..ea67910f96da9 100644 --- a/test/cpp_extensions/msnpu_extension.cpp +++ b/test/cpp_extensions/msnpu_extension.cpp @@ -20,9 +20,10 @@ Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) { return Tensor(std::move(tensor_impl)); } -Tensor empty_override(IntArrayRef size, const TensorOptions& options, c10::optional optional_memory_format) { +Tensor empty_override(IntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, + c10::optional pin_memory, c10::optional optional_memory_format) { test_int = 0; - return get_tensor(options.dtype(), size); + return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), size); } Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) { @@ -52,10 +53,10 @@ std::tuple fake_convolution_backward( } TORCH_LIBRARY_IMPL(aten, MSNPU, m) { - m.impl_UNBOXED("empty.memory_format", empty_override); - m.impl_UNBOXED("add.Tensor", add_override); - m.impl_UNBOXED("convolution_overrideable", fake_convolution); - m.impl_UNBOXED("convolution_backward_overrideable", fake_convolution_backward); + m.impl("empty.memory_format", empty_override); + m.impl("add.Tensor", add_override); + m.impl("convolution_overrideable", fake_convolution); + m.impl("convolution_backward_overrideable", fake_convolution_backward); } // TODO: Extend this to exercise multi-device setting. In that case, diff --git a/test/cpp_extensions/rng_extension.cpp b/test/cpp_extensions/rng_extension.cpp index bf16a840dfc95..f3ab91fb3cab7 100644 --- a/test/cpp_extensions/rng_extension.cpp +++ b/test/cpp_extensions/rng_extension.cpp @@ -22,6 +22,8 @@ struct TestCPUGenerator : public c10::GeneratorImpl { void set_current_seed(uint64_t seed) override { throw std::runtime_error("not implemented"); } uint64_t current_seed() const override { throw std::runtime_error("not implemented"); } uint64_t seed() override { throw std::runtime_error("not implemented"); } + void set_state(const c10::TensorImpl& new_state) override { throw std::runtime_error("not implemented"); } + c10::intrusive_ptr get_state() const override { throw std::runtime_error("not implemented"); } TestCPUGenerator* clone_impl() const override { throw std::runtime_error("not implemented"); } static DeviceType device_type() { return DeviceType::CPU; } @@ -54,9 +56,9 @@ size_t getInstanceCount() { } TORCH_LIBRARY_IMPL(aten, CustomRNGKeyId, m) { - m.impl_UNBOXED("aten::random_.from", random_from_to); - m.impl_UNBOXED("aten::random_.to", random_to); - m.impl_UNBOXED("aten::random_", random_); + m.impl("aten::random_.from", random_from_to); + m.impl("aten::random_.to", random_to); + m.impl("aten::random_", random_); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { diff --git a/test/cpp_extensions/setup.py b/test/cpp_extensions/setup.py index fb18dd2dd67a1..b3bee8f6ff495 100644 --- a/test/cpp_extensions/setup.py +++ b/test/cpp_extensions/setup.py @@ -4,6 +4,7 @@ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME +from torch.testing._internal.common_utils import IS_WINDOWS if sys.platform == 'win32': vc_version = os.getenv('VCToolsVersion', '') @@ -28,7 +29,7 @@ extra_compile_args=CXX_FLAGS), ] -if torch.cuda.is_available() and CUDA_HOME is not None: +if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None): extension = CUDAExtension( 'torch_test_cpp_extension.cuda', [ 'cuda_extension.cpp', @@ -38,22 +39,16 @@ extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': ['-O2']}) ext_modules.append(extension) -elif torch.cuda.is_available() and ROCM_HOME is not None: - from torch.utils.hipify import hipify_python - this_dir = os.path.dirname(os.path.abspath(__file__)) - hipify_python.hipify( - project_directory=this_dir, - output_directory=this_dir, - includes="./*", - show_detailed=True, - is_pytorch_extension=True,) - extension = CUDAExtension( - 'torch_test_cpp_extension.cuda', [ - 'cuda_extension.cpp', - 'hip/hip_extension_kernel.hip', - 'hip/hip_extension_kernel2.hip', - ]) - ext_modules.append(extension) + +if not IS_WINDOWS: # MSVC has bug compiling this example + if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None): + extension = CUDAExtension( + 'torch_test_cpp_extension.torch_library', [ + 'torch_library.cu' + ], + extra_compile_args={'cxx': CXX_FLAGS, + 'nvcc': ['-O2']}) + ext_modules.append(extension) setup( name='torch_test_cpp_extension', diff --git a/test/cpp_extensions/torch_library.cu b/test/cpp_extensions/torch_library.cu new file mode 100644 index 0000000000000..1bfbb8ab95caa --- /dev/null +++ b/test/cpp_extensions/torch_library.cu @@ -0,0 +1,9 @@ +#include + +bool logical_and(bool a, bool b) { return a && b; } + +TORCH_LIBRARY(torch_library, m) { + m.def("logical_and", &logical_and); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/test/custom_backend/backend.py b/test/custom_backend/backend.py index 17e399d320a7f..8b48ed0a4108e 100644 --- a/test/custom_backend/backend.py +++ b/test/custom_backend/backend.py @@ -33,7 +33,7 @@ def to_custom_backend(module): Returns: The module, lowered so that it can run on TestBackend. """ - lowered_module = torch._C._jit_to_backend("custom_backend", module._c, {"forward": {"": ""}}) + lowered_module = torch._C._jit_to_backend("custom_backend", module, {"forward": {"": ""}}) return lowered_module diff --git a/test/custom_operator/op.cpp b/test/custom_operator/op.cpp index 6e9a8eb7ab895..dd8ca4344bc1e 100644 --- a/test/custom_operator/op.cpp +++ b/test/custom_operator/op.cpp @@ -61,17 +61,9 @@ torch::Tensor custom_op_with_autograd( return CustomOpAutogradFunction::apply(var1, mul, var2, var3); } -static auto registry = - torch::RegisterOperators() - // We parse the schema for the user. - .op("custom::op", &custom_op) - .op("custom::op2", &custom_op2) - - // User provided schema. Among other things, allows defaulting values, - // because we cannot infer default values from the signature. It also - // gives arguments meaningful names. - .op("custom::op_with_defaults(Tensor tensor, float scalar = 1, int repeat = 1) -> Tensor[]", - &custom_op) - - .op("custom::op_with_autograd(Tensor var1, int mul, Tensor var2, Tensor? var3=None) -> Tensor", - &custom_op_with_autograd); +TORCH_LIBRARY_FRAGMENT(custom, m) { + m.def("op", custom_op); + m.def("op2", custom_op2); + m.def("op_with_defaults(Tensor tensor, float scalar = 1, int repeat = 1) -> Tensor[]", custom_op); + m.def("op_with_autograd(Tensor var1, int mul, Tensor var2, Tensor? var3=None) -> Tensor", custom_op_with_autograd); +} diff --git a/test/distributed/pipeline/sync/LICENSE b/test/distributed/pipeline/sync/LICENSE new file mode 100644 index 0000000000000..e52be240fdc98 --- /dev/null +++ b/test/distributed/pipeline/sync/LICENSE @@ -0,0 +1,27 @@ +Copyright 2019-2020 Kakao Brain + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/test/distributed/pipeline/sync/__init__.py b/test/distributed/pipeline/sync/__init__.py new file mode 100644 index 0000000000000..94cd5bcb415e0 --- /dev/null +++ b/test/distributed/pipeline/sync/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +# tests/__init__.py makes pytest can import the application without custom sys.path or PYTHONPATH. +# See also: https://docs.pytest.org/en/latest/goodpractices.html diff --git a/test/distributed/pipeline/sync/conftest.py b/test/distributed/pipeline/sync/conftest.py new file mode 100644 index 0000000000000..561c41d11350c --- /dev/null +++ b/test/distributed/pipeline/sync/conftest.py @@ -0,0 +1,53 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import tempfile +import torch +from torch.distributed import rpc + + +@pytest.fixture(autouse=True) +def manual_seed_zero(): + torch.manual_seed(0) + + +@pytest.fixture(scope="session") +def cuda_sleep(): + # Warm-up CUDA. + torch.empty(1, device="cuda") + + # From test/test_cuda.py in PyTorch. + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + torch.cuda._sleep(1000000) + end.record() + end.synchronize() + cycles_per_ms = 1000000 / start.elapsed_time(end) + + def cuda_sleep(seconds): + torch.cuda._sleep(int(seconds * cycles_per_ms * 1000)) + + return cuda_sleep + + +def pytest_report_header(): + return f"torch: {torch.__version__}" + +@pytest.fixture +def setup_rpc(scope="session"): + file = tempfile.NamedTemporaryFile() + rpc.init_rpc( + name="worker0", + rank=0, + world_size=1, + rpc_backend_options=rpc.TensorPipeRpcBackendOptions( + init_method="file://{}".format(file.name), + ) + ) + yield + rpc.shutdown() diff --git a/test/distributed/pipeline/sync/skip/__init__.py b/test/distributed/pipeline/sync/skip/__init__.py new file mode 100644 index 0000000000000..ab03724cafbf5 --- /dev/null +++ b/test/distributed/pipeline/sync/skip/__init__.py @@ -0,0 +1,6 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. diff --git a/test/distributed/pipeline/sync/skip/test_api.py b/test/distributed/pipeline/sync/skip/test_api.py new file mode 100644 index 0000000000000..e4a053131069e --- /dev/null +++ b/test/distributed/pipeline/sync/skip/test_api.py @@ -0,0 +1,45 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import copy + +from torch import nn + +from torch.distributed.pipeline.sync.skip import Namespace, skippable, stash + + +def test_namespace_difference(): + ns1 = Namespace() + ns2 = Namespace() + assert ns1 != ns2 + + +def test_namespace_copy(): + ns = Namespace() + assert copy.copy(ns) == ns + assert copy.copy(ns) is not ns + + +def test_skippable_repr(): + @skippable(stash=["hello"]) + class Hello(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(1, 1, 1) + + def forward(self, x): + yield stash("hello", x) + return self.conv(x) # noqa + + m = Hello() + assert ( + repr(m) + == """ +@skippable(Hello( + (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1)) +)) +""".strip() + ) diff --git a/test/distributed/pipeline/sync/skip/test_gpipe.py b/test/distributed/pipeline/sync/skip/test_gpipe.py new file mode 100644 index 0000000000000..885564ca1840a --- /dev/null +++ b/test/distributed/pipeline/sync/skip/test_gpipe.py @@ -0,0 +1,108 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import torch +from torch import nn + +from torch.distributed.pipeline.sync import Pipe +from torch.distributed.pipeline.sync.skip import pop, skippable, stash +from torch.distributed.pipeline.sync.skip.portal import PortalBlue, PortalCopy, PortalOrange +from torch.testing._internal.distributed.pipeline.utils import convert_to_balance + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") +@pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"]) +@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) +def test_1to3(balance, checkpoint, setup_rpc): + if torch.cuda.device_count() < len(balance): + pytest.skip("at least %d cuda devices required" % len(balance)) + + @skippable(stash=["1to3"]) + class Layer1(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, 1) + + def forward(self, input): + yield stash("1to3", input) + output = self.conv(input) + return output # noqa + + class Layer2(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, 1) + + def forward(self, input): + output = self.conv(input) + return output + + @skippable(pop=["1to3"]) + class Layer3(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, 1) + + def forward(self, input): + skip_1to3 = yield pop("1to3") + output = self.conv(input) + skip_1to3 + return output + + model = nn.Sequential(Layer1(), Layer2(), Layer3()) + model = convert_to_balance(model, balance) + model = Pipe(model, chunks=3, checkpoint=checkpoint) + + in_device = model.devices[0] + out_device = model.devices[-1] + + input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True) + output = model(input) + loss = output.local_value().mean() + loss.backward() + + assert torch.allclose(output.local_value().norm(), torch.tensor(1039.0, device=out_device), atol=6e-1) + assert torch.allclose(input.grad.norm(), torch.tensor(0.0004533053, device=in_device)) + + +def test_none_skip(setup_rpc): + @skippable(stash=["none"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("none", None) + return input # noqa + + @skippable(pop=["none"]) + class Pop(nn.Module): + def forward(self, input): + none = yield pop("none") + assert none is None + return input + + model = nn.Sequential(Stash(), Pop()) + model = Pipe(model, chunks=5) + + input = torch.rand(10, requires_grad=True) + output = model(input) + + def assert_grad_fn_is_not_portal(grad_fn, visited=None): + if visited is None: + visited = set() + if grad_fn in visited or grad_fn is None: + return + + assert not isinstance(grad_fn, PortalBlue._backward_cls) + assert not isinstance(grad_fn, PortalCopy._backward_cls) + assert not isinstance(grad_fn, PortalOrange._backward_cls) + + visited.add(grad_fn) + for next_grad_fn, _ in grad_fn.next_functions: + assert_grad_fn_is_not_portal(next_grad_fn, visited) + + assert_grad_fn_is_not_portal(output.local_value().grad_fn) + + output.local_value().sum().backward() + assert input.grad.mean().item() == 1 diff --git a/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py b/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py new file mode 100644 index 0000000000000..8275f25e22225 --- /dev/null +++ b/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py @@ -0,0 +1,111 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +from torch import nn + +from torch.distributed.pipeline.sync.skip import Namespace, pop, skippable, stash +from torch.distributed.pipeline.sync.skip.layout import inspect_skip_layout + + +class Pass(nn.Module): + def forward(self, input): + return input + + +@skippable(stash=["foo"]) +class StashFoo(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input # noqa + + +@skippable(pop=["foo"]) +class PopFoo(nn.Module): + def forward(self, input): + foo = yield stash("foo") + return input + foo + + +@skippable(stash=["bar"]) +class StashBar(nn.Module): + def forward(self, input): + yield stash("bar", input) + return input # noqa + + +@skippable(pop=["bar"]) +class PopBar(nn.Module): + def forward(self, input): + bar = yield pop("bar") + return input + bar + + +def test_no_skippables(): + p1 = nn.Sequential(Pass()) + p2 = nn.Sequential(Pass()) + + layout = inspect_skip_layout([p1, p2]) + policy = [list(layout.copy_policy(i)) for i in range(2)] + + assert policy == [[], []] + + +def test_inner_partition(): + p1 = nn.Sequential(StashFoo(), PopFoo()) + p2 = nn.Sequential(Pass()) + + layout = inspect_skip_layout([p1, p2]) + policy = [list(layout.copy_policy(i)) for i in range(2)] + + assert policy == [[], []] + + +def test_adjoining_partitions(): + p1 = nn.Sequential(StashFoo()) + p2 = nn.Sequential(PopFoo()) + + layout = inspect_skip_layout([p1, p2]) + policy = [list(layout.copy_policy(i)) for i in range(2)] + + assert policy == [[], [(0, None, "foo")]] + + +def test_far_partitions(): + p1 = nn.Sequential(StashFoo()) + p2 = nn.Sequential(Pass()) + p3 = nn.Sequential(PopFoo()) + + layout = inspect_skip_layout([p1, p2, p3]) + policy = [list(layout.copy_policy(i)) for i in range(3)] + + assert policy == [[], [], [(0, None, "foo")]] + + +def test_pop_2_from_different_partitions(): + p1 = nn.Sequential(StashFoo()) + p2 = nn.Sequential(StashBar()) + p3 = nn.Sequential(PopBar(), PopFoo()) + + layout = inspect_skip_layout([p1, p2, p3]) + policy = [list(layout.copy_policy(i)) for i in range(3)] + + # p3 pops 'bar' before 'foo', but the plan is sorted by source partition index. + assert policy == [[], [], [(0, None, "foo"), (1, None, "bar")]] + + +def test_namespace(): + ns1 = Namespace() + ns2 = Namespace() + + p1 = nn.Sequential(StashFoo().isolate(ns1)) + p2 = nn.Sequential(StashFoo().isolate(ns2)) + p3 = nn.Sequential(PopFoo().isolate(ns2), PopFoo().isolate(ns1)) + + layout = inspect_skip_layout([p1, p2, p3]) + policy = [list(layout.copy_policy(i)) for i in range(3)] + + # p3 pops 'bar' before 'foo', but the plan is sorted by source partition index. + assert policy == [[], [], [(0, ns1, "foo"), (1, ns2, "foo")]] diff --git a/test/distributed/pipeline/sync/skip/test_leak.py b/test/distributed/pipeline/sync/skip/test_leak.py new file mode 100644 index 0000000000000..c8e348fb6fc25 --- /dev/null +++ b/test/distributed/pipeline/sync/skip/test_leak.py @@ -0,0 +1,126 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import torch +from torch import nn + +from torch.distributed.pipeline.sync import Pipe, is_checkpointing, is_recomputing +from torch.distributed.pipeline.sync.skip import pop, skippable, stash +from torch.distributed.pipeline.sync.skip.tracker import current_skip_tracker + + +@skippable(stash=["skip"]) +class Stash(nn.Module): + def forward(self, input): + yield stash("skip", input) + return input # noqa + + +@skippable(pop=["skip"]) +class Pop(nn.Module): + def forward(self, input): + skip = yield pop("skip") + return input + skip + + +@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) +@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"]) +def test_delete_portal_tensor(train, checkpoint, setup_rpc): + # Without checkpointing: + # +- Stash --+ +--- Pop ----+ - - - layers + # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function + # +----------+ +------------+ + # + # With checkpointing: + # +- Stash --+ +--- Pop ----+ +--- Pop'----+ +- Stash'--+ + # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 | + # +----------+ +------------+ +------------+ +----------+ + + def portal_tensor_life_is(tensor_life, skip_tracker=None): + if skip_tracker is None: + skip_tracker = current_skip_tracker() + + # Get the current portal. + portal = list(skip_tracker.portals.values())[0] + + if tensor_life == 0: + return portal.tensor_life == 0 and portal.tensor is None + else: + return portal.tensor_life == tensor_life and portal.tensor is not None + + # Check the portal tensor after 'Stash'. + stash_ = Stash() + + @stash_.register_forward_hook + def check_portal_tensor_after_stash(*_): + if is_checkpointing(): + assert portal_tensor_life_is(2) + elif is_recomputing(): + assert portal_tensor_life_is(0) + else: + assert portal_tensor_life_is(1) + + pop_ = Pop() + + @pop_.register_forward_hook + def check_portal_tensor_after_pop(*_): + if is_checkpointing(): + assert portal_tensor_life_is(1) + elif is_recomputing(): + assert portal_tensor_life_is(0) + else: + assert portal_tensor_life_is(0) + + class NoPortalTensorAtBackward(nn.Module): + class F(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + ctx.skip_tracker = current_skip_tracker() + return input.detach() + + @staticmethod + def backward(ctx, grad): + assert portal_tensor_life_is(0, skip_tracker=ctx.skip_tracker) + return grad + + def forward(self, input): + return self.F.apply(input) + + model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_) + model = Pipe(model, chunks=2, checkpoint=checkpoint) + + input = torch.rand(10, requires_grad=True) + + if train: + model.train() + output = model(input).local_value() + output.norm().backward() + else: + model.eval() + with torch.no_grad(): + model(input) + + +@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) +def test_no_portal_without_pipe(train, monkeypatch, setup_rpc): + def deny(*args, **kwargs): + raise AssertionError("tried to create Portal without Pipe") + + monkeypatch.setattr("torch.distributed.pipeline.sync.skip.portal.Portal.__init__", deny) + + model = nn.Sequential(Stash(), Pop()) + + input = torch.rand(10, requires_grad=True) + + if train: + model.train() + output = model(input) + output.norm().backward() + else: + model.eval() + with torch.no_grad(): + model(input) diff --git a/test/distributed/pipeline/sync/skip/test_portal.py b/test/distributed/pipeline/sync/skip/test_portal.py new file mode 100644 index 0000000000000..c637ac86f5813 --- /dev/null +++ b/test/distributed/pipeline/sync/skip/test_portal.py @@ -0,0 +1,155 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import torch + +from torch.distributed.pipeline.sync.dependency import fork, join +from torch.distributed.pipeline.sync.skip.portal import Portal +from torch.distributed.pipeline.sync.stream import default_stream + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") +def test_copy_returns_on_next_device(): + portal = Portal(torch.rand(1), tensor_life=1) + + prev_stream = default_stream(torch.device("cpu")) + next_stream = default_stream(torch.device("cuda")) + + phony = torch.zeros(0, requires_grad=True) + assert phony.device.type == "cpu" + + phony = portal.copy(prev_stream, next_stream, phony) + assert phony.device.type == "cuda" + + +def test_blue_orange(): + tensor1 = torch.rand(1, requires_grad=True) + tensor2 = torch.rand(1, requires_grad=True) + + # Same with: output = tensor1*2 + tensor2 + # + # +----------------------+ + # | | + # tensor2 -- PortalBlue -+ +- PortalOrange -+ + # | | | + # tensor1 ------------ Join -- Fork --- Mul --- Add -- output + # + main = tensor1 + portal = Portal(tensor2, tensor_life=2) + phony = portal.blue() + main = join(main, phony) + main, phony = fork(main) + sub = portal.orange(phony) + output = main * 2 + sub + + output.backward() + + assert torch.allclose(tensor1.grad, torch.tensor([2.0])) + assert torch.allclose(tensor2.grad, torch.tensor([1.0])) + + +def test_blue_orange_not_requires_grad(): + tensor1 = torch.rand(1, requires_grad=True) + tensor2 = torch.rand(1) + + # Same with: output = tensor1*2 + tensor2 + # + # +----------------------+ + # | | + # tensor2 -- PortalBlue -+ +- PortalOrange -+ + # | | | + # tensor1 ------------ Join -- Fork --- Mul --- Add -- output + # + main = tensor1 + portal = Portal(tensor2, tensor_life=2) + phony = portal.blue() + main = join(main, phony) + main, phony = fork(main) + sub = portal.orange(phony) + output = main * 2 + sub + + output.backward() + + assert torch.allclose(tensor1.grad, torch.tensor([2.0])) + assert tensor2.grad is None + + +def test_use_grad(): + tensor = torch.rand(1, requires_grad=True) + portal = Portal(tensor, tensor_life=1) + + portal.put_grad(tensor) + assert portal.use_grad() is tensor + + # Gradient in a portal is ephemeral. + with pytest.raises(RuntimeError): + portal.use_grad() + + +class TestTensorLife: + @pytest.fixture + def new_portal(self): + portal = None + + def new_portal(tensor_life): + nonlocal portal + tensor = torch.rand(1, requires_grad=True) + portal = Portal(tensor, tensor_life) + return portal, tensor + + yield new_portal + + # A test using this fixture must exhaust the tensor in the portal. + with pytest.raises(RuntimeError): + portal.check_tensor_life() + assert portal.tensor is None + + def test_tensor_life_0(self, new_portal): + portal, tensor = new_portal(0) + assert portal.tensor is None + + def test_tensor_life_1(self, new_portal): + portal, tensor = new_portal(1) + assert portal.tensor is tensor + + portal.blue() + + def test_tensor_life_2(self, new_portal): + portal, tensor = new_portal(2) + assert portal.tensor is tensor + + phony = portal.blue() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + + def test_tensor_life_3(self, new_portal): + portal, tensor = new_portal(3) + assert portal.tensor is tensor + + phony = portal.blue() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + + def test_tensor_life_4(self, new_portal): + portal, tensor = new_portal(4) + assert portal.tensor is tensor + + phony = portal.blue() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + portal.blue() + + def test_tensor_life_3_plus_1(self, new_portal): + portal, tensor = new_portal(3) + assert portal.tensor is tensor + + phony = portal.blue() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + + another_tensor = torch.rand(1, requires_grad=True) + portal.put_tensor(another_tensor, tensor_life=1) + portal.blue() diff --git a/test/distributed/pipeline/sync/skip/test_stash_pop.py b/test/distributed/pipeline/sync/skip/test_stash_pop.py new file mode 100644 index 0000000000000..6961f645b128e --- /dev/null +++ b/test/distributed/pipeline/sync/skip/test_stash_pop.py @@ -0,0 +1,136 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import torch +from torch import nn + +from torch.distributed.pipeline.sync.skip import pop, skippable, stash +from torch.distributed.pipeline.sync.skip.tracker import SkipTracker, use_skip_tracker + + +@pytest.fixture(autouse=True) +def skip_tracker(): + skip_tracker = SkipTracker() + with use_skip_tracker(skip_tracker): + yield skip_tracker + + +def test_stash(skip_tracker): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input * 2 # noqa + + l1 = Stash() + + assert len(skip_tracker.tensors) == 0 + + with use_skip_tracker(skip_tracker): + l1(torch.tensor(42)) + + assert len(skip_tracker.tensors) == 1 + + +def test_pop(): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input * 2 # noqa + + @skippable(pop=["foo"]) + class Pop(nn.Module): + def forward(self, input): + foo = yield pop("foo") + return foo # noqa + + l1 = Stash() + l2 = Pop() + + output = l2(l1(torch.tensor(42))) + + assert output.item() == 42 + + +def test_declare_but_not_use(): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + return input * 2 + + @skippable(pop=["foo"]) + class Pop(nn.Module): + def forward(self, input): + return input * 3 + + l1 = Stash() + l2 = Pop() + + with pytest.raises(RuntimeError): + l1(torch.tensor(42)) + + with pytest.raises(RuntimeError): + l2(torch.tensor(42)) + + +def test_stash_not_declared(): + @skippable() + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input * 2 # noqa + + l1 = Stash() + + with pytest.raises(RuntimeError): + l1(torch.tensor(42)) + + +def test_pop_not_declared(): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input * 2 # noqa + + @skippable() + class Pop(nn.Module): + def forward(self, input): + foo = yield pop("foo") + return foo # noqa + + l1 = Stash() + l2 = Pop() + + latent = l1(torch.tensor(42)) + + with pytest.raises(RuntimeError): + l2(latent) + + +def test_pop_not_stashed(): + @skippable(pop=["foo"]) + class Pop(nn.Module): + def forward(self, input): + yield pop("foo") + + l1 = Pop() + + with pytest.raises(RuntimeError): + l1(torch.tensor(42)) + + +def test_stash_none(): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", None) + return input * 2 # noqa + + l1 = Stash() + l1(torch.tensor(42)) diff --git a/test/distributed/pipeline/sync/skip/test_tracker.py b/test/distributed/pipeline/sync/skip/test_tracker.py new file mode 100644 index 0000000000000..e8036c77340c1 --- /dev/null +++ b/test/distributed/pipeline/sync/skip/test_tracker.py @@ -0,0 +1,127 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +from queue import Queue +import threading + +import pytest +import torch +from torch import nn + +from torch.distributed.pipeline.sync.checkpoint import enable_checkpointing, enable_recomputing +from torch.distributed.pipeline.sync.microbatch import Batch +from torch.distributed.pipeline.sync.skip import pop, skippable, stash +from torch.distributed.pipeline.sync.skip.layout import SkipLayout +from torch.distributed.pipeline.sync.skip.tracker import SkipTracker, SkipTrackerThroughPotals, current_skip_tracker + + +def test_default_skip_tracker(): + q = Queue() + + def f(): + q.put(current_skip_tracker()) + + t = threading.Thread(target=f) + t.start() + t.join() + + skip_tracker = q.get() + + assert type(skip_tracker) is SkipTracker + assert type(skip_tracker) is not SkipTrackerThroughPotals + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") +def test_default_skip_tracker_by_data_parallel(): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input * 2 # noqa + + @skippable(pop=["foo"]) + class Pop(nn.Module): + def forward(self, input): + foo = yield pop("foo") + return foo + + model = nn.Sequential(Stash(), Pop()) + model = nn.DataParallel(model, device_ids=[0, 0], output_device=0) + + input = torch.rand(10, device=0) + output = model(input) + + assert torch.allclose(output, input) + + +def test_reuse_portal(): + skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) + skip_tracker = SkipTrackerThroughPotals(skip_layout) + + batch = Batch(torch.tensor([1.0])) + a = torch.tensor([2.0]) + b = torch.tensor([2.0]) + + skip_tracker.save(batch, None, "test", a) + portal = skip_tracker.portals[(None, "test")] + + skip_tracker.save(batch, None, "test", b) + assert portal is skip_tracker.portals[(None, "test")] + + +def test_no_copy_no_portal(): + skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "copy"): (0, 1), (None, "not_copy"): (0, 0)}) + skip_tracker = SkipTrackerThroughPotals(skip_layout) + + batch = Batch(torch.tensor([1.0])) + a = torch.tensor([2.0]) + b = torch.tensor([2.0]) + + skip_tracker.save(batch, None, "copy", a) + skip_tracker.save(batch, None, "not_copy", b) + + assert (None, "copy") in skip_tracker.portals + assert (None, "copy") not in skip_tracker.tensors + assert (None, "not_copy") in skip_tracker.tensors + assert (None, "not_copy") not in skip_tracker.portals + + +def test_tensor_life_without_checkpointing(): + skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) + skip_tracker = SkipTrackerThroughPotals(skip_layout) + + batch = Batch(torch.tensor([1.0])) + tensor = torch.tensor([2.0]) + + skip_tracker.save(batch, None, "test", tensor) + assert skip_tracker.portals[(None, "test")].tensor_life == 1 + + skip_tracker.load(batch, None, "test") + assert skip_tracker.portals[(None, "test")].tensor_life == 0 + + +def test_tensor_life_with_checkpointing(): + skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) + skip_tracker = SkipTrackerThroughPotals(skip_layout) + + batch = Batch(torch.tensor([1.0])) + tensor = torch.tensor([2.0]) + + with enable_checkpointing(): + skip_tracker.save(batch, None, "test", tensor) + assert skip_tracker.portals[(None, "test")].tensor_life == 2 + + with enable_checkpointing(): + skip_tracker.load(batch, None, "test") + assert skip_tracker.portals[(None, "test")].tensor_life == 1 + + with enable_recomputing(): + skip_tracker.load(batch, None, "test") + assert skip_tracker.portals[(None, "test")].tensor_life == 0 + + with enable_recomputing(): + skip_tracker.save(batch, None, "test", tensor) + assert skip_tracker.portals[(None, "test")].tensor_life == 0 diff --git a/test/distributed/pipeline/sync/skip/test_verify_skippables.py b/test/distributed/pipeline/sync/skip/test_verify_skippables.py new file mode 100644 index 0000000000000..6f9dd510493d6 --- /dev/null +++ b/test/distributed/pipeline/sync/skip/test_verify_skippables.py @@ -0,0 +1,152 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest +from torch import nn + +from torch.distributed.pipeline.sync.skip import Namespace, skippable, verify_skippables + + +def test_matching(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer2(nn.Module): + pass + + verify_skippables(nn.Sequential(Layer1(), Layer2())) + + +def test_stash_not_pop(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1())) + assert "no module declared 'foo' as poppable but stashed" in str(e.value) + + +def test_pop_unknown(): + @skippable(pop=["foo"]) + class Layer1(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1())) + assert "'0' declared 'foo' as poppable but it was not stashed" in str(e.value) + + +def test_stash_again(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + @skippable(stash=["foo"]) + class Layer2(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer3(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) + assert "'1' redeclared 'foo' as stashable" in str(e.value) + + +def test_pop_again(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer2(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer3(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) + assert "'2' redeclared 'foo' as poppable" in str(e.value) + + +def test_stash_pop_together_different_names(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + @skippable(pop=["foo"], stash=["bar"]) + class Layer2(nn.Module): + pass + + @skippable(pop=["bar"]) + class Layer3(nn.Module): + pass + + verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) + + +def test_stash_pop_together_same_name(): + @skippable(stash=["foo"], pop=["foo"]) + class Layer1(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1())) + assert "'0' declared 'foo' both as stashable and as poppable" in str(e.value) + + +def test_double_stash_pop(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer2(nn.Module): + pass + + @skippable(stash=["foo"]) + class Layer3(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer4(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3(), Layer4())) + assert "'2' redeclared 'foo' as stashable" in str(e.value) + assert "'3' redeclared 'foo' as poppable" in str(e.value) + + +def test_double_stash_pop_but_isolated(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer2(nn.Module): + pass + + @skippable(stash=["foo"]) + class Layer3(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer4(nn.Module): + pass + + ns1 = Namespace() + ns2 = Namespace() + + verify_skippables( + nn.Sequential(Layer1().isolate(ns1), Layer2().isolate(ns1), Layer3().isolate(ns2), Layer4().isolate(ns2),) + ) diff --git a/test/distributed/pipeline/sync/test_balance.py b/test/distributed/pipeline/sync/test_balance.py new file mode 100644 index 0000000000000..5aa9ec6d454e9 --- /dev/null +++ b/test/distributed/pipeline/sync/test_balance.py @@ -0,0 +1,225 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import time + +import pytest +import torch +from torch import nn + +from torch.distributed.pipeline.sync._balance import balance_by_size, balance_by_time, blockpartition +from torch.distributed.pipeline.sync._balance.profile import layerwise_sandbox + +skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") + +devices = ["cpu"] +if torch.cuda.is_available(): + devices.append("cuda") + + +def test_blockpartition(): + assert blockpartition.solve([1, 2, 3, 4, 5, 6], partitions=2) == [[1, 2, 3, 4], [5, 6]] + + +def test_blockpartition_zeros(): + assert blockpartition.solve([0, 0], partitions=2) == [[0], [0]] + + +def test_blockpartition_non_positive_partitions(): + with pytest.raises(ValueError): + blockpartition.solve([42], partitions=0) + with pytest.raises(ValueError): + blockpartition.solve([42], partitions=-1) + + +def test_blockpartition_short_sequence(): + with pytest.raises(ValueError): + blockpartition.solve([], partitions=1) + with pytest.raises(ValueError): + blockpartition.solve([42], partitions=2) + + +@pytest.mark.parametrize("device", devices) +@pytest.mark.skip(reason="Flaky due to time.sleep()") +def test_balance_by_time(device): + class Delay(nn.Module): + def __init__(self, seconds): + super().__init__() + self.seconds = seconds + + def forward(self, x): + time.sleep(self.seconds) + return x + + model = nn.Sequential(*[Delay(i / 10) for i in [1, 2, 3, 4, 5, 6]]) + sample = torch.rand(1) + balance = balance_by_time(2, model, sample, device=device) + assert balance == [4, 2] + + +def test_balance_by_time_loop_resets_input(): + # nn.Flatten was introduced at PyTorch 1.2.0. + class Flatten(nn.Module): + def forward(self, x): + return x.flatten(1) + + model = nn.Sequential(nn.Conv2d(3, 2, 1), Flatten(), nn.Linear(128, 10)) + sample = torch.rand(10, 3, 8, 8) + balance = balance_by_time(2, model, sample, device="cpu") + assert balance == [1, 2] + + +@skip_if_no_cuda +def test_balance_by_size_latent(): + class Expand(nn.Module): + def __init__(self, times): + super().__init__() + self.times = times + + def forward(self, x): + for i in range(self.times): + x = x + torch.rand_like(x, requires_grad=True) + return x + + sample = torch.rand(10, 100, 100) + + model = nn.Sequential(*[Expand(i) for i in [1, 2, 3, 4, 5, 6]]) + balance = balance_by_size(2, model, sample) + assert balance == [4, 2] + + model = nn.Sequential(*[Expand(i) for i in [6, 5, 4, 3, 2, 1]]) + balance = balance_by_size(2, model, sample) + assert balance == [2, 4] + + +@skip_if_no_cuda +def test_balance_by_size_param(): + model = nn.Sequential(*[nn.Linear(i + 1, i + 2) for i in range(6)]) + sample = torch.rand(7, 1) + balance = balance_by_size(2, model, sample, param_scale=100) + assert balance == [4, 2] + + model = nn.Sequential(*[nn.Linear(i + 2, i + 1) for i in reversed(range(6))]) + sample = torch.rand(1, 7) + balance = balance_by_size(2, model, sample, param_scale=100) + assert balance == [2, 4] + + +@skip_if_no_cuda +def test_balance_by_size_param_scale(): + class Tradeoff(nn.Module): + def __init__(self, param_size, latent_size): + super().__init__() + self.fc = nn.Linear(param_size, param_size) + self.latent_size = latent_size + + def forward(self, x): + for i in range(self.latent_size): + x = x + torch.rand_like(x, requires_grad=True) + return x + + model = nn.Sequential( + Tradeoff(param_size=1, latent_size=6), + Tradeoff(param_size=2, latent_size=5), + Tradeoff(param_size=3, latent_size=4), + Tradeoff(param_size=4, latent_size=3), + Tradeoff(param_size=5, latent_size=2), + Tradeoff(param_size=6, latent_size=1), + ) + + sample = torch.rand(1, requires_grad=True) + + balance = balance_by_size(2, model, sample, param_scale=0) + assert balance == [2, 4] + + balance = balance_by_size(2, model, sample, param_scale=100) + assert balance == [4, 2] + + +@pytest.mark.parametrize("device", devices) +def test_layerwise_sandbox(device): + model = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3)) + model.eval() + + for layer in layerwise_sandbox(model, torch.device(device)): + assert layer.training + assert all(p.device.type == device for p in layer.parameters()) + + assert all(not l.training for l in model) + assert all(p.device.type == "cpu" for p in model.parameters()) + + +@pytest.mark.parametrize("device", devices) +def test_sandbox_during_profiling(device): + model = nn.Sequential(nn.BatchNorm2d(3)) + + before = {k: v.clone() for k, v in model.state_dict().items()} + + sample = torch.rand(1, 3, 10, 10) + balance_by_time(1, model, sample, device=device) + + after = model.state_dict() + + assert before.keys() == after.keys() + for key, value in before.items(): + assert torch.allclose(after[key], value), key + + +def test_not_training(): + class AssertTraining(nn.Module): + def forward(self, x): + assert self.training + return x + + model = nn.Sequential(AssertTraining()) + + model.eval() + assert not model.training + + sample = torch.rand(1) + balance_by_time(1, model, sample, device="cpu") + + assert not model.training + + +def test_balance_by_time_tuple(): + class Twin(nn.Module): + def forward(self, x): + return x, x.detach() + + class Add(nn.Module): + def forward(self, a_b): + a, b = a_b + return a + b + + model = nn.Sequential(Twin(), Add()) + sample = torch.rand(1, requires_grad=True) + balance_by_time(1, model, sample, device="cpu") + + +@skip_if_no_cuda +def test_balance_by_size_tuple(): + class Twin(nn.Module): + def forward(self, x): + return x, x.detach() + + class Add(nn.Module): + def forward(self, a_b): + a, b = a_b + return a + b + + model = nn.Sequential(Twin(), Add()) + sample = torch.rand(1, requires_grad=True) + balance_by_size(1, model, sample) + + +def test_already_has_grad(): + model = nn.Sequential(nn.Conv2d(3, 3, 1)) + sample = torch.rand(1, 3, 32, 32) + model(sample).norm().backward() + + with pytest.raises(ValueError, match="some parameter already has gradient"): + balance_by_time(1, model, sample, device="cpu") diff --git a/test/distributed/pipeline/sync/test_bugs.py b/test/distributed/pipeline/sync/test_bugs.py new file mode 100644 index 0000000000000..580e58bf58bcb --- /dev/null +++ b/test/distributed/pipeline/sync/test_bugs.py @@ -0,0 +1,139 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import torch +from torch import nn +import torch.nn.functional as F + +from torch.distributed.pipeline.sync import Pipe + + +def test_python_autograd_function(setup_rpc): + # A Python autograd function might fail with this error: + # + # RuntimeError: Returning Variables sharing storage with other Variables + # that require grad is not supported in Python functions. Please submit a + # feature request if you hit this error. + # + # It doesn't look like an essential restriction. But it happens on the + # current PyTorch version. To avoid it, we should detach the tensor before + # returning by identity autograd functions, such as Wait, Fork, and Join. + # + class Identity(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, grad): + return grad + + class M(nn.Module): + def forward(self, input): + return Identity.apply(input) + + model = nn.Sequential(M(), M()) + model = Pipe(model, checkpoint="always") + + x = torch.rand(42) + y = model(x) + assert torch.allclose(x, y.local_value()) + + +def test_exception_no_hang(setup_rpc): + # In v0.0.2, once a failed partition receives a normal message + # (non-closing) for the next micro-batch, a hang occured. The reason was + # that a failed partition didn't call in_queue.task_done() on a normal + # message. So the former partition was blocked at out_queue.join() for the + # next of next micro-batch. + class ExpectedException(Exception): + pass + + class Pass(nn.Module): + def forward(self, x): + return x + + class Raise(nn.Module): + def forward(self, x): + raise ExpectedException() + + model = nn.Sequential(Pass(), Pass(), Raise()) + model = Pipe(model, chunks=3) + + with pytest.raises(ExpectedException): + model(torch.rand(3)) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required") +def test_tuple_wait(cuda_sleep, setup_rpc): + # In v0.0.3, Wait is applied to only the first tensor on a micro-batch. + # Under this behavior, if checkpointing was disabled, there's a possibility + # that gradient accumulations on other tensors are not synchronized + # properly to the copy stream. + class Sleep(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x.detach() + + @staticmethod + def backward(ctx, grad): + with torch.cuda.device(grad.device): + cuda_sleep(0.05) + return grad + + class Layer1(nn.Module): + def __init__(self): + super().__init__() + self.ones = nn.Parameter(torch.ones(32, 3, 32, 32, requires_grad=True)) + + def forward(self, pair): + a, b = pair + a = a * self.ones + return a * 1, b * 2, b * 3 + + class Layer2(nn.Module): + def __init__(self): + super().__init__() + self.ones = nn.Parameter(torch.ones(32, 3, 32, 32, requires_grad=True)) + + def forward(self, triple): + a, b, c = triple + a = a * self.ones + b = Sleep.apply(b) + return a + b + c + + model = nn.Sequential(Layer1().cuda(0), Layer2().cuda(1)) + model = Pipe(model, chunks=32, checkpoint="never") + + a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) + b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) + + y = model((a, b)) + y.local_value().norm().backward() + + torch.cuda.synchronize(0) + torch.cuda.synchronize(1) + + assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000)) + + +def test_parallel_randoms(setup_rpc): + class Dropouts(nn.Module): + def forward(self, x): + for _ in range(100): + x = F.dropout(x, p=0.001) + return x + + model = nn.Sequential(Dropouts(), Dropouts()) + + x = torch.rand(10, 10, requires_grad=True) + model = Pipe(model, chunks=10, checkpoint="always") + y = model(x) + y = y.local_value() + y.norm().backward() + + assert y.to(torch.bool).tolist() == x.grad.to(torch.bool).tolist() diff --git a/test/distributed/pipeline/sync/test_checkpoint.py b/test/distributed/pipeline/sync/test_checkpoint.py new file mode 100644 index 0000000000000..18553cba2f666 --- /dev/null +++ b/test/distributed/pipeline/sync/test_checkpoint.py @@ -0,0 +1,158 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +from functools import partial + +import pytest +import torch +from torch import nn +import torch.cuda + +from torch.distributed.pipeline.sync.checkpoint import Checkpointing, checkpoint, is_checkpointing, is_recomputing +from torch.distributed.pipeline.sync.dependency import fork, join +from torch.distributed.pipeline.sync.microbatch import Batch + +devices = ["cpu"] +if torch.cuda.is_available(): + devices.append("cuda") + + +@pytest.mark.parametrize("device", devices) +def test_serial_checkpoints(device): + # Copied from https://github.com/pytorch/pytorch/pull/18568. + timeline = [] + + class Log(torch.autograd.Function): + @staticmethod + def forward(ctx, name, x): + ctx.name = name + timeline.append(f"{name}:forward") + return x.detach() + + @staticmethod + def backward(ctx, grad_output): + name = ctx.name + timeline.append(f"{name}:backward") + return None, grad_output + + a = torch.rand(1, device=device, requires_grad=True) + b = torch.rand(1, device=device, requires_grad=True) + + # Increase the next function sequence number. + _ = a + 1 + 2 + 3 + 4 + 5 + + a = checkpoint(partial(Log.apply, "a"), a) + + a, phony = fork(a) + b = join(b, phony) + + b = checkpoint(partial(Log.apply, "b"), b) + + c = torch.cat((a, b)) + + out = c.sum() + + # +--> {a} --Checkpoint(Log)--> {a} + # {out} --Sum--> {c} --Cat ^-----------------------------+ + # +--> {b} --Checkpoint(Log)--> {b} --First--> {b} + out.backward() + + assert timeline == ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"] + # |----------------------| |-----------------------| |-----------------------| + # forward pass Checkpoint(Log[b]) Checkpoint(Log[a]) + + +def test_not_requires_grad(): + x = Batch(torch.rand(1, requires_grad=False)) + assert not x[0].requires_grad + + def f(x): + return x * 2 + + chk = Checkpointing(f, x) + x = chk.checkpoint() + assert x[0].requires_grad + + chk.recompute(x) + assert x[0].requires_grad + + x.tensor.backward() + + +def test_not_requires_grad_with_parameter(): + x = torch.rand(1, requires_grad=False) + a = torch.rand(1, requires_grad=True) + + def f(x): + return x * a + + y = checkpoint(f, x) + y.backward() + + assert a.grad is not None + + +@pytest.mark.parametrize("device", devices) +def test_random_in_checkpoint(device): + dropout = nn.Dropout(p=0.5) + + torch.manual_seed(0) + x = torch.randn(3, 3, device=device, requires_grad=True) + y = dropout(x) + y.norm().backward() + + torch.manual_seed(0) + chk_x = torch.randn(3, 3, device=device, requires_grad=True) + chk_y = checkpoint(dropout, chk_x) + chk_y.norm().backward() + + assert torch.allclose(x.grad, chk_x.grad) + + +def test_detect_checkpointing_recomputing(): + logs = [] + + class Detect(nn.Module): + def forward(self, input): + logs.append((is_checkpointing(), is_recomputing())) + return input + + model = Detect() + input = torch.rand(1, requires_grad=True) + + output = checkpoint(model, input) + output.backward() + + assert logs == [(True, False), (False, True)] + + +def test_detect_checkpointing_recomputing_without_checkpoint(): + logs = [] + + class Detect(nn.Module): + def forward(self, input): + logs.append((is_checkpointing(), is_recomputing())) + return input + + model = Detect() + input = torch.rand(1, requires_grad=True) + + output = model(input) + output.backward() + + assert logs == [(False, False)] + + +def test_non_grad_output(): + class ForkNonGrad(nn.Module): + def forward(self, input): + return (input * 2, torch.rand(1)) + + model = ForkNonGrad() + input = torch.rand(1, requires_grad=True) + + output = checkpoint(model, input) + output[0].backward() diff --git a/test/distributed/pipeline/sync/test_copy.py b/test/distributed/pipeline/sync/test_copy.py new file mode 100644 index 0000000000000..9ee792fd27d45 --- /dev/null +++ b/test/distributed/pipeline/sync/test_copy.py @@ -0,0 +1,68 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import torch + +from torch.distributed.pipeline.sync.copy import Copy, Wait +from torch.distributed.pipeline.sync.stream import CPUStream, current_stream, get_device, is_cuda, new_stream, use_stream + +skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") + + +def _test_copy_wait(prev_stream, next_stream, cuda_sleep=None): + device = get_device(prev_stream) + + with use_stream(prev_stream): + if is_cuda(prev_stream): + cuda_sleep(0.5) + x = torch.ones(100, device=device, requires_grad=True) + + (y,) = Copy.apply(prev_stream, next_stream, x) + (y,) = Wait.apply(prev_stream, next_stream, x) + + with use_stream(next_stream): + assert torch.allclose(y.sum(), torch.tensor(100.0, device=device)) + y.norm().backward() + with use_stream(prev_stream): + assert torch.allclose(x.grad.sum(), torch.tensor(10.0, device=device)) + + +def test_copy_wait_cpu_cpu(): + prev_stream = CPUStream + next_stream = CPUStream + _test_copy_wait(prev_stream, next_stream) + + +@skip_if_no_cuda +def test_copy_wait_cpu_cuda(cuda_sleep): + prev_stream = CPUStream + next_stream = current_stream(torch.device("cuda")) + _test_copy_wait(prev_stream, next_stream, cuda_sleep) + + +@skip_if_no_cuda +def test_copy_wait_cuda_cpu(cuda_sleep): + prev_stream = current_stream(torch.device("cuda")) + next_stream = CPUStream + _test_copy_wait(prev_stream, next_stream, cuda_sleep) + + +@skip_if_no_cuda +def test_copy_wait_cuda_cuda(cuda_sleep): + prev_stream = current_stream(torch.device("cuda")) + next_stream = new_stream(torch.device("cuda")) + _test_copy_wait(prev_stream, next_stream, cuda_sleep) + + +def test_wait_multiple_tensors(): + a = torch.rand(1, requires_grad=True) + b = torch.rand(1, requires_grad=True) + + a, b = Wait.apply(CPUStream, CPUStream, a, b) + + assert a.grad_fn is b.grad_fn + assert a.grad_fn.__class__ is Wait._backward_cls diff --git a/test/distributed/pipeline/sync/test_deferred_batch_norm.py b/test/distributed/pipeline/sync/test_deferred_batch_norm.py new file mode 100644 index 0000000000000..cf3f86654804e --- /dev/null +++ b/test/distributed/pipeline/sync/test_deferred_batch_norm.py @@ -0,0 +1,192 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +from copy import deepcopy +from itertools import chain + +import pytest +import torch +from torch import nn, optim + +from torch.distributed.pipeline.sync.batchnorm import DeferredBatchNorm + +CHUNKS = 4 + + +def tilt_dist(input): + # Tilt variance by channel. + rgb = input.transpose(0, 1) + rgb[0] *= 1 + rgb[1] *= 10 + rgb[2] *= 100 + + # Tilt mean by single batch. + for i, single in enumerate(input): + single += 2 ** i + + return input + + +def chunked_forward(model, input, chunks=CHUNKS): + output_chunks = [] + + for chunk in input.chunk(chunks): + output_chunks.append(model(chunk)) + + return torch.cat(output_chunks) + + +@pytest.mark.parametrize("chunks", [1, 4]) +@pytest.mark.parametrize("input_requires_grad", [True, False]) +def test_transparency(chunks, input_requires_grad): + bn = nn.BatchNorm2d(3) + dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=chunks) + + input1 = torch.rand(16, 3, 224, 224) + input1 = tilt_dist(input1) + input2 = input1.clone() + input1.requires_grad = input_requires_grad + input2.requires_grad = input_requires_grad + + output1 = chunked_forward(bn, input1, chunks=chunks) + output2 = chunked_forward(dbn, input2, chunks=chunks) + + assert torch.allclose(output1, output2, atol=1e-4) + + output1.mean().backward() + output2.mean().backward() + + assert torch.allclose(bn.weight.grad, dbn.weight.grad, atol=1e-4) + + if input_requires_grad: + assert input1.grad is not None + assert input2.grad is not None + assert torch.allclose(input1.grad, input2.grad, atol=1e-4) + + +@pytest.mark.parametrize("momentum", [0.1, None]) +def test_running_stats(momentum): + bn = nn.BatchNorm2d(3, momentum=momentum) + dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) + + input = torch.rand(16, 3, 224, 224) + input = tilt_dist(input) + + bn(input) + chunked_forward(dbn, input) + + assert torch.allclose(bn.running_mean, dbn.running_mean, atol=1e-4) + assert torch.allclose(bn.running_var, dbn.running_var, atol=1e-4) + + +def test_convert_deferred_batch_norm(): + bn = nn.BatchNorm2d(3, track_running_stats=False) + bn = DeferredBatchNorm.convert_deferred_batch_norm(bn, chunks=CHUNKS) + assert type(bn) is nn.BatchNorm2d # because of track_running_stats=False + + dbn = DeferredBatchNorm(3, chunks=CHUNKS) + dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS) + assert dbn is dbn_again + + dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS + 1) + assert dbn is not dbn_again # because of different chunks + + +def test_eval(): + bn = nn.BatchNorm2d(3) + dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) + + input = torch.rand(16, 3, 224, 224) + input = tilt_dist(input) + + bn(input) + chunked_forward(dbn, input) + + bn.eval() + dbn.eval() + + assert torch.allclose(bn(input), dbn(input), atol=1e-4) + + +def test_optimize(): + bn = nn.BatchNorm2d(3) + dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) + + opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=1.0) + + for i in range(5): + input = torch.rand(16, 3, 224, 224) + input = tilt_dist(input) + + # train + y = bn(input) + a = y.sum() + a.backward() + + y = chunked_forward(dbn, input) + b = y.sum() + b.backward() + + opt.step() + + # eval + bn.eval() + dbn.eval() + + with torch.no_grad(): + assert torch.allclose(bn(input), dbn(input), atol=1e-1 * (10 ** i)) + + +def test_conv_bn(): + bn = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3)) + dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) + + input = torch.rand(16, 3, 224, 224) + input = tilt_dist(input) + + opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=0.1) + + # 1st step + a = bn(input) + b = chunked_forward(dbn, input) + + # Outputs are different. (per-mini-batch vs. per-micro-batch) + assert not torch.allclose(a, b) + + a.sum().backward() + b.sum().backward() + opt.step() + opt.zero_grad() + + # Conv layers are also trained differently because of their different outputs. + assert not torch.allclose(bn[0].weight, dbn[0].weight) + + # But BNs track identical running stats. + assert torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4) + assert torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3) + + # 2nd step + a = bn(input) + b = chunked_forward(dbn, input) + a.sum().backward() + b.sum().backward() + + # BNs can't track identical running stats due to the different conv layers. + assert not torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4) + assert not torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3) + + +def test_input_requiring_grad(): + dbn = DeferredBatchNorm(3, chunks=CHUNKS) + + input = torch.rand(16, 3, 224, 224) + input = tilt_dist(input) + input.requires_grad = True + + chunked_forward(dbn, input) + + assert not dbn.sum.requires_grad + assert dbn.sum.grad_fn is None diff --git a/test/distributed/pipeline/sync/test_dependency.py b/test/distributed/pipeline/sync/test_dependency.py new file mode 100644 index 0000000000000..e1dcfd35defa4 --- /dev/null +++ b/test/distributed/pipeline/sync/test_dependency.py @@ -0,0 +1,144 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import weakref + +import pytest +import torch + +from torch.distributed.pipeline.sync.dependency import Fork, Join, fork, join + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") +def test_fork_join(): + logs = [] + + class Log(torch.autograd.Function): + @staticmethod + def forward(ctx, number, tensor): + ctx.number = number + return tensor.detach() + + @staticmethod + def backward(ctx, grad): + logs.append(ctx.number) + return None, grad + + a = torch.rand(1, device="cpu", requires_grad=True) + b = torch.rand(1, device="cuda", requires_grad=True) + + a = Log.apply(1, a) + + a, phony = fork(a) + b = join(a, phony) + + b = Log.apply(2, b) + b = b.to("cpu") + + (a + b).backward() + + assert logs == [2, 1] + + +def test_fork_join_enable_grad(): + x = torch.rand(1, requires_grad=True) + + with torch.enable_grad(): + x2, p = fork(x) + + assert p.requires_grad + assert x2 is not x + x = x2 + + assert x.requires_grad + assert p.requires_grad + assert x.grad_fn.__class__ is Fork._backward_cls + assert p.grad_fn.__class__ is Fork._backward_cls + + with torch.enable_grad(): + x2 = join(x, p) + + assert x2 is not x + x = x2 + + assert x.requires_grad + assert x.grad_fn.__class__ is Join._backward_cls + + +def test_fork_join_no_grad(monkeypatch): + def do_not_apply(*args): + raise AssertionError("Function.apply called") + + monkeypatch.setattr("torch.autograd.Function.apply", do_not_apply) + + x = torch.rand(1, requires_grad=True) + + with torch.no_grad(): + x2, p = fork(x) + + assert not p.requires_grad + assert x2 is x + x = x2 + + with torch.no_grad(): + x2 = join(x, p) + + assert x2 is x + x = x2 + + +def test_fork_leak(): + leak = None + + class F(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, grad): + nonlocal leak + leak = weakref.ref(ctx) + return grad + + x = torch.rand(1, requires_grad=True) + x = F.apply(x) + x, phony = fork(x) + x = join(x, phony) + + x.backward() + del x, phony + + assert leak() is None + + +def test_join_when_fork_not_requires_grad(): + x = torch.rand(2, 1) + a, b = x.chunk(2) + + assert not a.requires_grad + a, p = fork(a) + assert not a.requires_grad + assert not p.requires_grad + + assert not b.requires_grad + b = join(b, p) + assert not b.requires_grad + + +def test_join_when_fork_requires_grad(): + x = torch.rand(2, 1) + a, b = x.chunk(2) + + a.requires_grad_() + assert a.requires_grad + a, p = fork(a) + assert a.requires_grad + assert p.requires_grad + + assert not b.requires_grad + b = join(b, p) + assert b.requires_grad diff --git a/test/distributed/pipeline/sync/test_inplace.py b/test/distributed/pipeline/sync/test_inplace.py new file mode 100644 index 0000000000000..4720454892bc9 --- /dev/null +++ b/test/distributed/pipeline/sync/test_inplace.py @@ -0,0 +1,71 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import torch +from torch import nn + +from torch.distributed.pipeline.sync import Pipe + + +def test_inplace_on_requires_grad(setup_rpc): + model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True)) + model = Pipe(model, checkpoint="always") + + x = torch.rand(1) + y = model(x).local_value() + + message = r"a leaf Variable that requires grad .* used in an in-place operation." + with pytest.raises(RuntimeError, match=message): + y.backward() + + +@pytest.mark.xfail(strict=True) +def test_inplace_on_not_requires_grad(setup_rpc): + # In-place operation on a tensor not requiring grad doesn't cause a + # RuntimeError. Currently, we cannot detect this case. + model = nn.Sequential(nn.ReLU(inplace=True)) + model = Pipe(model, [1], devices=["cpu"], checkpoint="always") + + x = torch.rand(1) + y = model(x).local_value() + del model + + message = r"a leaf Variable that requires grad .* used in an in-place operation." + with pytest.raises(RuntimeError, match=message): + y.backward() + + +@pytest.mark.xfail(strict=True) +def test_inplace_incorrect_grad(setup_rpc): + class M(nn.Module): + def forward(self, foo_bar): + # 'foo' requires grad but 'bar' does not. In-place operation on + # 'bar' won't cause a RuntimeError. + foo, bar = foo_bar + + # add_(1) is not idempotent, in contrast to relu_(). If it is + # executed multiple times, it will accumulates each difference onto + # 'bar'. + bar.add_(1) + + # 'bar' is still captured by checkpointing. 'foo' will get + # incorrect grad. + return foo * bar + + model = nn.Sequential(M()) + model = Pipe(model, [1], devices=["cpu"], checkpoint="always") + + foo = torch.tensor([1.0], requires_grad=True) + bar = torch.tensor([1.0]) + + output = model((foo, bar)).local_value() + del model + output.backward() + + # The gradient of 'foo' should be 2, but it is 3 actually because + # bar.add_(1) was executed twice due to checkpointing. + assert foo.grad.item() == 2.0 diff --git a/test/distributed/pipeline/sync/test_microbatch.py b/test/distributed/pipeline/sync/test_microbatch.py new file mode 100644 index 0000000000000..914e9e8e8ae2a --- /dev/null +++ b/test/distributed/pipeline/sync/test_microbatch.py @@ -0,0 +1,138 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import torch +import torch.cuda + +from torch.distributed.pipeline.sync.microbatch import Batch, check, gather, scatter + + +def test_batch_atomic(): + x = torch.tensor(42) + b = Batch(x) + + assert b.atomic + + assert b.tensor is x + with pytest.raises(AttributeError): + b.tensors + + assert list(b) == [x] + assert len(b) == 1 + assert b[0] is x + + +def test_batch_non_atomic(): + x, y = torch.tensor(42), torch.tensor(21) + b = Batch((x, y)) + + assert not b.atomic + + with pytest.raises(AttributeError): + b.tensor + assert b.tensors == (x, y) + + assert list(b) == [x, y] + assert len(b) == 2 + assert b[0] is x + assert b[1] is y + + +def test_batch_call(): + a = Batch(torch.tensor(42)) + b = Batch((torch.tensor(42), torch.tensor(21))) + + def f(x): + return x + + assert a.call(f).atomic + assert not b.call(f).atomic + + +def test_batch_setitem_by_index(): + a = Batch(torch.tensor(42)) + b = Batch((torch.tensor(42), torch.tensor(21))) + + a[0] = torch.tensor(0) + b[0] = torch.tensor(0) + + assert a.atomic + assert a[0].item() == 0 + + assert not b.atomic + assert len(b) == 2 + assert b[0].item() == 0 + assert b[1].item() == 21 + + +def test_batch_setitem_by_slice(): + a = Batch(torch.tensor(42)) + b = Batch((torch.tensor(42), torch.tensor(21))) + + a[:] = (torch.tensor(0),) + b[:] = (torch.tensor(0),) + + assert a.atomic + assert a[0].item() == 0 + + assert not b.atomic + assert len(b) == 1 + assert b[0].item() == 0 + + +def test_check(): + check(torch.tensor(42)) + check((torch.tensor(4), torch.tensor(2))) + + with pytest.raises(TypeError): + check(42) + + with pytest.raises(TypeError): + check("str") + + with pytest.raises(TypeError): + check((torch.tensor(4), 2)) + + +def test_gather_tensors(): + a = torch.zeros(1, 1) + b = torch.zeros(1, 1) + + ab = gather([Batch(a), Batch(b)]) + + assert ab.size() == (2, 1) + + +def test_gather_tuples(): + a = (torch.zeros(1, 1), torch.zeros(2, 2)) + b = (torch.zeros(1, 1), torch.zeros(2, 2)) + + ab = gather([Batch(a), Batch(b)]) + + assert isinstance(ab, tuple) + assert ab[0].size() == (2, 1) + assert ab[1].size() == (4, 2) + + +def test_scatter_tensor(): + ab = torch.zeros(2, 1) + + a, b = scatter(ab, chunks=2) + + assert a.tensor.size() == (1, 1) + assert b.tensor.size() == (1, 1) + + +def test_scatter_tuple(): + ab = (torch.zeros(2, 1), torch.zeros(4, 2)) + + a, b = scatter(ab, chunks=2) + + assert a.tensors[0].size() == (1, 1) + assert b.tensors[0].size() == (1, 1) + assert a.tensors[1].size() == (2, 2) + assert b.tensors[1].size() == (2, 2) diff --git a/test/distributed/pipeline/sync/test_phony.py b/test/distributed/pipeline/sync/test_phony.py new file mode 100644 index 0000000000000..5d06d465589de --- /dev/null +++ b/test/distributed/pipeline/sync/test_phony.py @@ -0,0 +1,50 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import torch + +from torch.distributed.pipeline.sync.phony import get_phony + + +def test_phony_size(): + p = get_phony(torch.device("cpu"), requires_grad=False) + assert p.size() == (0,) + + +def test_phony_requires_grad(): + p1 = get_phony(torch.device("cpu"), requires_grad=True) + p2 = get_phony(torch.device("cpu"), requires_grad=False) + assert p1.requires_grad + assert not p2.requires_grad + + +def test_cached_phony(): + p1 = get_phony(torch.device("cpu"), requires_grad=True) + p2 = get_phony(torch.device("cpu"), requires_grad=True) + assert p1 is p2 + + p3 = get_phony(torch.device("cpu"), requires_grad=False) + p4 = get_phony(torch.device("cpu"), requires_grad=False) + assert p3 is p4 + + assert p1 is not p3 + + +def test_phony_in_autograd_function(): + class Phonify(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + phony = get_phony(input.device, requires_grad=False) + return phony.detach() + + x = torch.rand(1, requires_grad=True) + + p1 = Phonify.apply(x) + p2 = get_phony(torch.device("cpu"), requires_grad=True) + + assert p1 is not p2 + assert p1.grad_fn is not None + assert p2.grad_fn is None diff --git a/test/distributed/pipeline/sync/test_pipe.py b/test/distributed/pipeline/sync/test_pipe.py new file mode 100644 index 0000000000000..c01822a477850 --- /dev/null +++ b/test/distributed/pipeline/sync/test_pipe.py @@ -0,0 +1,629 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +from collections import OrderedDict +from copy import deepcopy +import time + +import pytest +import torch +from torch import nn + +from torch.distributed.pipeline.sync import Pipe + +skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") + + +def test_parameters(): + model = nn.Sequential(nn.Linear(1, 1)) + pipe = Pipe(model, chunks=1) + assert list(pipe.parameters()) != [] + + +def test_public_attrs(): + class MyString: + def __init__(self, value): + self.value = value + + def __str__(self): + return self.value + + model = nn.Sequential(nn.Linear(1, 1)) + pipe = Pipe(model, chunks=42.000, checkpoint=MyString("always")) + + assert pipe.devices == [torch.device("cpu")] + assert pipe.chunks == 42 + assert isinstance(pipe.chunks, int) + assert pipe.checkpoint == "always" + assert isinstance(pipe.checkpoint, str) + + +def test_sequential_like(): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + + model = nn.Sequential(a, b) + model = Pipe(model) + + assert len(model) == 2 + assert list(model) == [a, b] + + assert model[0] is a + assert model[1] is b + with pytest.raises(IndexError): + _ = model[2] + + assert model[-1] is b + assert model[-2] is a + +def test_chunks_less_than_1(): + model = nn.Sequential(nn.Linear(1, 1)) + + with pytest.raises(ValueError): + Pipe(model, chunks=0) + + with pytest.raises(ValueError): + Pipe(model, chunks=-1) + +def test_batch_size_indivisible(setup_rpc): + model = nn.Sequential(nn.Linear(1, 1)) + model = Pipe(model, chunks=4) + + with pytest.warns(None) as record: + model(torch.rand(7, 1)) + + # Indivisible batch size is legal. + assert not record + + +def test_batch_size_small(setup_rpc): + model = nn.Sequential(nn.Linear(1, 1)) + model = Pipe(model, chunks=4) + + with pytest.warns(None) as record: + model(torch.rand(2, 1)) + + # Batch size smaller than chunks is legal. + assert not record + + +def test_checkpoint_mode(setup_rpc): + def count_grad_fn(grad_fn, name, visited=None): + if visited is None: + visited = set() + if grad_fn in visited: + return 0 + visited.add(grad_fn) + + if grad_fn is None: + return 0 + if grad_fn.__class__.__name__ == name: + return 1 + + counter = 0 + for next_grad_fn, _ in grad_fn.next_functions: + counter += count_grad_fn(next_grad_fn, name, visited=visited) + return counter + + model = nn.Sequential(nn.Linear(1, 1)) + input = torch.rand(2, 1) + + always = Pipe(model, chunks=2, checkpoint="always") + except_last = Pipe(model, chunks=2, checkpoint="except_last") + never = Pipe(model, chunks=2, checkpoint="never") + + always_output = always(input) + except_last_output = except_last(input) + never_output = never(input) + + assert count_grad_fn(always_output.local_value().grad_fn, "CheckpointBackward") == 2 + assert count_grad_fn(except_last_output.local_value().grad_fn, "CheckpointBackward") == 1 + assert count_grad_fn(never_output.local_value().grad_fn, "CheckpointBackward") == 0 + + +def test_checkpoint_mode_invalid(): + model = nn.Sequential(nn.Linear(1, 1)) + + with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"): + Pipe(model, chunks=2, checkpoint="INVALID_CHECKPOINT") + + +def test_checkpoint_mode_when_chunks_1(): + model = nn.Sequential(nn.Linear(1, 1)) + + # All checkpoint modes are fine. + Pipe(model, chunks=1, checkpoint="except_last") + Pipe(model, chunks=1, checkpoint="always") + Pipe(model, chunks=1, checkpoint="never") + + +def test_checkpoint_eval(setup_rpc): + model = nn.Sequential(nn.Linear(1, 1)) + model = Pipe(model, chunks=2) + input = torch.rand(2, 1) + + def find_grad_fn(grad_fn, name): + if grad_fn is None: + return False + if grad_fn.__class__.__name__ == name: + return True + for next_grad_fn, _ in grad_fn.next_functions: + if find_grad_fn(next_grad_fn, name): + return True + return False + + model.train() + train_output = model(input) + assert find_grad_fn(train_output.local_value().grad_fn, "CheckpointBackward") + assert find_grad_fn(train_output.local_value().grad_fn, "RecomputeBackward") + + model.eval() + eval_output = model(input) + assert not find_grad_fn(eval_output.local_value().grad_fn, "CheckpointBackward") + assert not find_grad_fn(eval_output.local_value().grad_fn, "RecomputeBackward") + + +def test_checkpoint_non_float_input(setup_rpc): + class ForkNonFloat(nn.Module): + def forward(self, input): + return (input * 2, torch.tensor([False])) + + class JoinNonFloat(nn.Module): + def forward(self, input): + return input[0] * 2 + + model = nn.Sequential(ForkNonFloat(), JoinNonFloat()) + model = Pipe(model, chunks=1, checkpoint="always") + + input = torch.rand(1, requires_grad=True) + output = model(input) + output.backward() + + +def test_no_grad(setup_rpc): + model = nn.Sequential(nn.Linear(1, 1)) + model = Pipe(model, chunks=2) + input = torch.rand(2, 1) + + latent = None + + def hook(module, input, output): + _ = module + _ = input + + nonlocal latent + latent = output + + partition = model.partitions[0] + partition.register_forward_hook(hook) + + with torch.no_grad(): + model(input) + + assert latent.grad_fn is None + + +def test_exception(setup_rpc): + class ExpectedException(Exception): + pass + + class Raise(nn.Module): + def forward(self, *_): + raise ExpectedException() + + model = nn.Sequential(Raise()) + model = Pipe(model, chunks=1) + + with pytest.raises(ExpectedException): + model(torch.rand(1)) + + +def test_exception_early_stop_asap(setup_rpc): + """Even the first partitions have finished to process, the partition before + the failed partition should be killed as soon as possible. + """ + + class ExpectedException(Exception): + pass + + class Pass(nn.Module): + def forward(self, x): + return x + + counter = 0 + + class Counter(nn.Module): + def forward(self, x): + time.sleep(0.1) + + nonlocal counter + counter += 1 + + return x + + class Raise(nn.Module): + def forward(self, x): + raise ExpectedException() + + model = nn.Sequential(Pass(), Pass(), Counter(), Raise()) + model = Pipe(model, chunks=3) + + with pytest.raises(ExpectedException): + model(torch.rand(3)) + + # If the early stop doesn't work, it would be 3 instead. + assert counter == 2 + + +def test_nested_input(setup_rpc): + class NestedInput(nn.Module): + def __init__(self): + super().__init__() + self.fc_a = nn.Linear(1, 1) + self.fc_b = nn.Linear(1, 1) + + def forward(self, inp): + return inp + + model = nn.Sequential(NestedInput()) + model = Pipe(model, chunks=2) + + a = torch.rand(10, 1, requires_grad=True) + b = torch.rand(10, 1, requires_grad=True) + + # TypeError: expected Tensor, but got tuple + with pytest.raises(TypeError): + model((a, (a, b))).local_value() + + # TypeError: expected Tensor, but got list + with pytest.raises(TypeError): + model((a, [a, b])).local_value() + + +def test_input_pair(setup_rpc): + class Two(nn.Module): + def __init__(self): + super().__init__() + self.fc_a = nn.Linear(1, 1) + self.fc_b = nn.Linear(1, 1) + + def forward(self, a_and_b): + a, b = a_and_b + return (self.fc_a(a), self.fc_b(b)) + + model = nn.Sequential(Two()) + model = Pipe(model, chunks=2) + + a = torch.rand(10, 1, requires_grad=True) + b = torch.rand(10, 1, requires_grad=True) + + a_out, b_out = model((a, b)).local_value() + loss = (a_out + b_out).mean() + loss.backward() + + assert a.grad is not None + assert b.grad is not None + + # Test with list. + a.grad = None + b.grad = None + a_out, b_out = model([a, b]).local_value() + loss = (a_out + b_out).mean() + loss.backward() + + assert a.grad is not None + assert b.grad is not None + + + +def test_input_singleton(setup_rpc): + class One(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(1, 1) + + def forward(self, only_a): + (a,) = only_a + return (self.fc(a),) + + model = nn.Sequential(One()) + model = Pipe(model, chunks=2) + + a = torch.rand(10, 1, requires_grad=True) + + (a_out,) = model((a,)).local_value() + loss = a_out.mean() + loss.backward() + + assert all(p.grad is not None for p in model.parameters()) + assert a.grad is not None + + # Test with list + a.grad = None + for p in model.parameters(): + p.grad = None + + (a_out,) = model([a]).local_value() + loss = a_out.mean() + loss.backward() + + assert all(p.grad is not None for p in model.parameters()) + assert a.grad is not None + + +def test_input_varargs(setup_rpc): + model = nn.Sequential(nn.Linear(1, 1)) + model = Pipe(model) + + a = torch.rand(1) + b = torch.rand(1) + + # TypeError: forward() takes 2 positional arguments but 3 were given + with pytest.raises(TypeError): + model(a, b) + + +def test_non_tensor(setup_rpc): + class NonTensor(nn.Module): + def forward(self, _): + return "hello" + + model = nn.Sequential(NonTensor()) + model = Pipe(model) + x = torch.rand(1) + + # TypeError: expected Tensor as element 0 in argument 0, but got str + with pytest.raises(TypeError): + model(x) + + # TypeError: expected Tensor to scatter, but got str + with pytest.raises(TypeError): + model("hello") + + +def test_non_tensor_sequence(setup_rpc): + class NonTensorTuple(nn.Module): + def forward(self, x): + return (x, "hello") + + model = nn.Sequential(NonTensorTuple()) + model = Pipe(model) + x = torch.rand(1) + + # TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1 + with pytest.raises(TypeError): + model(x) + + # TypeError: expected Tensor to scatter, but got str + with pytest.raises(TypeError): + model((x, "hello")) + + # TypeError: expected Tensor to scatter, but got str + with pytest.raises(TypeError): + model([x, "hello"]) + + +@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) +def test_deferred_batch_norm(checkpoint, setup_rpc): + bn = nn.BatchNorm2d(3) + pipe_bn = deepcopy(bn) + pipe = Pipe( + nn.Sequential(pipe_bn), chunks=2, checkpoint=checkpoint, deferred_batch_norm=True + ) + + x = torch.rand(4, 3, 10, 10) + pipe(x).local_value().mean().backward() + bn(x).mean().backward() + + assert torch.allclose(pipe[0].running_mean, bn.running_mean, atol=1e-4) + assert torch.allclose(pipe[0].running_var, bn.running_var, atol=1e-4) + + +@pytest.mark.parametrize("checkpoint", ["never", "always"]) +def test_deferred_batch_norm_params(checkpoint, setup_rpc): + bn = nn.BatchNorm2d(3) + pipe_bn = deepcopy(bn) + pipe = Pipe( + nn.Sequential(pipe_bn), chunks=1, checkpoint=checkpoint, deferred_batch_norm=True + ) + + x = torch.rand(4, 3, 10, 10) + pipe(x).local_value().mean().backward() + bn(x).mean().backward() + + assert pipe[0].weight.grad is not None + assert pipe[0].bias.grad is not None + + assert torch.allclose(pipe[0].weight.grad, bn.weight.grad, atol=1e-4) + assert torch.allclose(pipe[0].bias.grad, bn.bias.grad, atol=1e-4) + + +def test_devices(): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + c = nn.Linear(1, 1) + + # There are extra two devices. + model = nn.Sequential(a, b, c) + model = Pipe(model) + + cpu = torch.device("cpu") + # Extra devices must be discarded. + assert model.devices == [cpu, cpu, cpu] + + +def test_partitions(): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + + model = nn.Sequential(a, b) + model = Pipe(model) + + assert isinstance(model.partitions, nn.ModuleList) + assert isinstance(model.partitions[0], nn.Sequential) + assert isinstance(model.partitions[1], nn.Sequential) + + assert "partitions.0.0.weight" in model.state_dict() + + +def test_deny_moving(): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + + model = nn.Sequential(a, b) + model = Pipe(model) + + # Moving is denied. + with pytest.raises(TypeError): + model.cuda() + + with pytest.raises(TypeError): + model.cpu() + + with pytest.raises(TypeError): + model.to(torch.device("cuda")) + + with pytest.raises(TypeError): + model.to(0) + + with pytest.raises(TypeError): + model.to("cuda") + + with pytest.raises(TypeError): + model.to(device=0) + + with pytest.raises(TypeError): + model.to(torch.rand(1)) + + with pytest.raises(TypeError): + model.to(tensor=torch.rand(1)) + + # Casting is allowed. + model.half() + model.to(torch.double) + model.to(dtype=torch.float) + + +def test_empty_module(setup_rpc): + # Empty sequential module is not illegal. + model = nn.Sequential() + model = Pipe(model) + + assert model(torch.tensor(42)).local_value() == torch.tensor(42) + assert model((torch.tensor(42),)).local_value() == (torch.tensor(42),) + + # But only tensor or tensors is legal in Pipe. + with pytest.raises(TypeError): + model(42) + + +def test_named_children(): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + + model = nn.Sequential(OrderedDict([("a", a), ("b", b)])) + model = Pipe(model) + + names = set(n for n, _ in model.named_modules()) + assert "partitions.0.a" in names + assert "partitions.1.b" in names + + # Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires + # several methods in its namespace. + with pytest.raises(AttributeError): + model.a + + +def test_verify_module_non_sequential(): + with pytest.raises(TypeError, match="module must be nn.Sequential to be partitioned"): + Pipe(nn.Module()) + + +def test_verify_module_duplicate_children(): + conv = nn.Conv2d(3, 3, 1) + model = nn.Sequential(conv, conv) + + with pytest.raises(ValueError, match="module with duplicate children is not supported"): + Pipe(model) + + +@skip_if_no_cuda +def test_verify_module_params_on_same_device(): + class Surrogate(nn.Module): + def __init__(self, param1, param2): + super().__init__() + self.param1 = param1 + self.param2 = param2 + + conv1 = nn.Conv2d(3, 3, 1) + conv2 = nn.Conv2d(3, 3, 1) + model = nn.Sequential(Surrogate(conv1, conv2.cuda())) + + with pytest.raises( + ValueError, + match=r'should have all parameters on a single device, please use .to\(\)' + ' to place the module on a single device'): + Pipe(model) + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need atleast two GPUs") +def test_verify_nested_modules(setup_rpc): + model = nn.Sequential( + nn.Sequential( + nn.Linear(32, 16).cuda(0), + nn.Linear(16, 8).cuda(0) + ), + nn.Sequential( + nn.Linear(8, 4).cuda(1), + nn.Linear(4, 2).cuda(1) + ), + ) + + pipe = Pipe(model) + out = pipe(torch.rand(10, 32).cuda(0)) + assert out.local_value().device == torch.device("cuda:1") + assert out.local_value().size() == torch.Size([10, 2]) + +def test_verify_module_duplicate_parameters_on_same_device(): + class Surrogate(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + conv = nn.Conv2d(3, 3, 1) + model = nn.Sequential(Surrogate(conv), Surrogate(conv)) + + Pipe(model) + + +def test_forward_lockstep(setup_rpc): + timeline = [] + + class DelayedLog(nn.Module): + def __init__(self, j, seconds): + super().__init__() + self.i = 0 + self.j = j + self.seconds = seconds + + def forward(self, x): + time.sleep(self.seconds) + + timeline.append((self.i, self.j)) + self.i += 1 + + return x + + model = nn.Sequential(DelayedLog(0, seconds=0), DelayedLog(1, seconds=0.1)) + model = Pipe(model, chunks=3) + model(torch.rand(3, 1)) + + # Expected timeline: (Logs are recorded at !) + # + # Partition #0: 0! 1! 2! + # Partition #1: 000! 111! 222! + # + assert timeline == [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1)] diff --git a/test/distributed/pipeline/sync/test_pipeline.py b/test/distributed/pipeline/sync/test_pipeline.py new file mode 100644 index 0000000000000..ef583e2df8708 --- /dev/null +++ b/test/distributed/pipeline/sync/test_pipeline.py @@ -0,0 +1,29 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +from torch.distributed.pipeline.sync.pipeline import _clock_cycles + + +def test_clock_cycles(): + assert list(_clock_cycles(1, 1)) == [[(0, 0)]] + assert list(_clock_cycles(1, 3)) == [[(0, 0)], [(0, 1)], [(0, 2)]] + assert list(_clock_cycles(3, 1)) == [[(0, 0)], [(1, 0)], [(2, 0)]] + + assert list(_clock_cycles(3, 3)) == [ # noqa + [(0, 0)], + [(1, 0), (0, 1)], + [(2, 0), (1, 1), (0, 2)], + [(2, 1), (1, 2)], + [(2, 2)], + ] + + assert list(_clock_cycles(4, 2)) == [ # noqa + [(0, 0)], + [(1, 0), (0, 1)], + [(2, 0), (1, 1)], + [(3, 0), (2, 1)], + [(3, 1)], + ] diff --git a/test/distributed/pipeline/sync/test_stream.py b/test/distributed/pipeline/sync/test_stream.py new file mode 100644 index 0000000000000..fcb4409632108 --- /dev/null +++ b/test/distributed/pipeline/sync/test_stream.py @@ -0,0 +1,188 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import torch + +from torch.distributed.pipeline.sync.stream import ( + CPUStream, + current_stream, + default_stream, + get_device, + is_cuda, + new_stream, + record_stream, + use_device, + use_stream, + wait_stream, +) + +skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") + + +class TestNewStream: + def test_new_stream_cpu(self): + stream = new_stream(torch.device("cpu")) + assert stream is CPUStream + + @skip_if_no_cuda + def test_new_stream_cuda(self): + stream = new_stream(torch.device("cuda")) + assert isinstance(stream, torch.cuda.Stream) + assert stream != torch.cuda.default_stream() + + +class TestCurrentStream: + def test_current_stream_cpu(self): + stream = current_stream(torch.device("cpu")) + assert stream is CPUStream + + @skip_if_no_cuda + def test_current_stream_cuda(self): + stream = current_stream(torch.device("cuda")) + assert isinstance(stream, torch.cuda.Stream) + assert stream == torch.cuda.current_stream() + + +class TestDefaultStream: + def test_default_stream_cpu(self): + stream = default_stream(torch.device("cpu")) + assert stream is CPUStream + + @skip_if_no_cuda + def test_default_stream_cuda(self): + stream = default_stream(torch.device("cuda")) + assert isinstance(stream, torch.cuda.Stream) + assert stream == torch.cuda.default_stream() + + +class TestUseDevice: + def test_use_device_cpu(self): + with use_device(torch.device("cpu")): + pass + + @skip_if_no_cuda + def test_use_device_cuda(self): + with use_device(torch.device("cuda")): + pass + + +class TestUseStream: + def test_use_stream_cpu(self): + with use_stream(CPUStream): + pass + + @skip_if_no_cuda + def test_use_stream_cuda(self): + stream = new_stream(torch.device("cuda")) + with use_stream(stream): + assert current_stream(torch.device("cuda")) == stream + + +class TestGetDevice: + def test_get_device_cpu(self): + assert get_device(CPUStream).type == "cpu" + + @skip_if_no_cuda + def test_get_device_cuda(self): + stream = current_stream(torch.device("cuda")) + assert get_device(stream).type == "cuda" + + +class TestWaitStream: + def _test_wait_stream(self, source, target, cuda_sleep=None): + with use_stream(target): + if is_cuda(target): + cuda_sleep(0.5) + x = torch.ones(100, 100, device=get_device(target)) + + wait_stream(source, target) + + with use_stream(source): + assert x.sum().item() == 10000 + + def test_wait_stream_cpu_cpu(self): + source = CPUStream + target = CPUStream + self._test_wait_stream(source, target) + + @skip_if_no_cuda + def test_wait_stream_cpu_cuda(self, cuda_sleep): + source = CPUStream + target = new_stream(torch.device("cuda")) + self._test_wait_stream(source, target, cuda_sleep) + + @skip_if_no_cuda + def test_wait_stream_cuda_cpu(self, cuda_sleep): + source = new_stream(torch.device("cuda")) + target = CPUStream + self._test_wait_stream(source, target, cuda_sleep) + + @skip_if_no_cuda + def test_wait_stream_cuda_cuda(self, cuda_sleep): + source = current_stream(torch.device("cuda")) + target = new_stream(torch.device("cuda")) + self._test_wait_stream(source, target, cuda_sleep) + + +class TestRecordStream: + def test_record_stream_cpu(self): + # It should silently ignore CPU tensors. + x = torch.rand(1, device=torch.device("cpu")) + record_stream(x, CPUStream) + + @skip_if_no_cuda + def test_record_stream_cuda(self, cuda_sleep): + # This test detects unexpected block reallocation. For reliable test, + # the stream to allocate tensors is isolated. The allocator will not + # reuse free blocks which were allocated from another stream. + stream_alloc = new_stream(torch.device("cuda")) + with torch.cuda.stream(stream_alloc): + x = torch.rand(1, device=torch.device("cuda")) + + stream = new_stream(torch.device("cuda")) + record_stream(x, stream) + with use_stream(stream): + cuda_sleep(0.5) + + # 'x' is deleted at Python's perspective. But the block of 'x' is still + # required for 'stream'. 'y' shouldn't be allocated to the block. + data_ptr = x.data_ptr() + del x + stream_alloc.synchronize() + with torch.cuda.stream(stream_alloc): + y = torch.rand(1, device=torch.device("cuda")) + assert y.data_ptr() != data_ptr + + # Pause Python until 'stream' finishes tasks queued. Now the block of + # 'x' is free to be reallocated. + wait_stream(CPUStream, stream) + with torch.cuda.stream(stream_alloc): + z = torch.rand(1, device=torch.device("cuda")) + assert z.data_ptr() == data_ptr + + @skip_if_no_cuda + def test_record_stream_shifted_view(self, cuda_sleep): + # Issue: https://github.com/pytorch/pytorch/issues/27366 + stream_alloc = new_stream(torch.device("cuda")) + with torch.cuda.stream(stream_alloc): + x = torch.rand(2, device=torch.device("cuda")) + + y = x[1:] + assert y.data_ptr() > x.data_ptr() + + stream = new_stream(torch.device("cuda")) + with use_stream(stream): + cuda_sleep(0.5) + record_stream(y, stream) + + data_ptr = x.data_ptr() + del x, y + + stream_alloc.synchronize() + with torch.cuda.stream(stream_alloc): + z = torch.rand(2, device=torch.device("cuda")) + assert z.data_ptr() != data_ptr diff --git a/test/distributed/pipeline/sync/test_transparency.py b/test/distributed/pipeline/sync/test_transparency.py new file mode 100644 index 0000000000000..a9d44e50d32bf --- /dev/null +++ b/test/distributed/pipeline/sync/test_transparency.py @@ -0,0 +1,43 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import torch +from torch import nn + +from torch.distributed.pipeline.sync import Pipe + + +def test_simple_linears(setup_rpc): + def sum_grad(parameters): + return sum([p.grad.sum() for p in parameters if p.grad is not None]) + + def zero_grad(parameters): + for p in parameters: + p.grad = None + + inputs = torch.rand(8, 1) + model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 4), nn.Linear(4, 2), nn.Linear(2, 1),) + + # Without Pipe + outputs = model(inputs) + loss = outputs.mean() + loss.backward() + + grad_without_pipe = sum_grad(model.parameters()) + + zero_grad(model.parameters()) + + # With Pipe + model = Pipe(model, chunks=4) + + outputs = model(inputs).local_value() + loss = outputs.mean() + loss.backward() + + grad_with_pipe = sum_grad(model.parameters()) + + # Both grads should be identical. + assert torch.allclose(grad_with_pipe, grad_without_pipe) diff --git a/test/distributed/pipeline/sync/test_worker.py b/test/distributed/pipeline/sync/test_worker.py new file mode 100644 index 0000000000000..5d3791a946d8c --- /dev/null +++ b/test/distributed/pipeline/sync/test_worker.py @@ -0,0 +1,166 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import threading +import time + +import pytest +import torch + +from torch.distributed.pipeline.sync.microbatch import Batch +from torch.distributed.pipeline.sync.stream import CPUStream +from torch.distributed.pipeline.sync.worker import Task, spawn_workers +from torch.testing._internal.common_utils import TEST_WITH_TSAN + + +class fake_device: + """A test double for :class:`torch.device`. Every fake device is different + with each other. + """ + + type = "fake" + index = None + + +@pytest.mark.skipif(TEST_WITH_TSAN, reason="False positive in TSAN") +def test_join_running_workers(): + count = 0 + + def counter(): + nonlocal count + time.sleep(0.1) + count += 1 + return Batch(()) + + with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, out_queues): + + def call_in_worker(i, f): + task = Task(CPUStream, compute=f, finalize=None) + in_queues[i].put(task) + + for i in range(10): + call_in_worker(i, counter) + + # There's no nondeterminism because 'spawn_workers' joins all running + # workers. + assert count == 10 + + +@pytest.mark.skipif(TEST_WITH_TSAN, reason="False positive in TSAN") +def test_join_running_workers_with_exception(): + class ExpectedException(Exception): + pass + + count = 0 + + def counter(): + nonlocal count + time.sleep(0.1) + count += 1 + return Batch(()) + + with pytest.raises(ExpectedException): + with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, out_queues): + + def call_in_worker(i, f): + task = Task(CPUStream, compute=f, finalize=None) + in_queues[i].put(task) + + for i in range(10): + call_in_worker(i, counter) + + raise ExpectedException + + # There's no nondeterminism because only 1 task can be placed in input + # queues. + assert count == 10 + + +def test_compute_multithreading(): + """Task.compute should be executed on multiple threads.""" + thread_ids = set() + + def log_thread_id(): + thread_id = threading.current_thread().ident + thread_ids.add(thread_id) + return Batch(()) + + with spawn_workers([fake_device() for _ in range(2)]) as (in_queues, out_queues): + for i in range(2): + t = Task(CPUStream, compute=log_thread_id, finalize=None) + in_queues[i].put(t) + for i in range(2): + out_queues[i].get() + + assert len(thread_ids) == 2 + + +def test_compute_success(): + """Task.compute returns (True, (task, batch)) on success.""" + + def _42(): + return Batch(torch.tensor(42)) + + with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): + t = Task(CPUStream, compute=_42, finalize=None) + in_queues[0].put(t) + ok, (task, batch) = out_queues[0].get() + + assert ok + assert task is t + assert isinstance(batch, Batch) + assert batch[0].item() == 42 + + +def test_compute_exception(): + """Task.compute returns (False, exc_info) on failure.""" + + def zero_div(): + 0 / 0 + + with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): + t = Task(CPUStream, compute=zero_div, finalize=None) + in_queues[0].put(t) + ok, exc_info = out_queues[0].get() + + assert not ok + assert isinstance(exc_info, tuple) + assert issubclass(exc_info[0], ZeroDivisionError) + + +@pytest.mark.parametrize("grad_mode", [True, False]) +def test_grad_mode(grad_mode): + def detect_grad_enabled(): + x = torch.rand(1, requires_grad=torch.is_grad_enabled()) + return Batch(x) + + with torch.set_grad_enabled(grad_mode): + with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): + task = Task(CPUStream, compute=detect_grad_enabled, finalize=None) + in_queues[0].put(task) + + ok, (_, batch) = out_queues[0].get() + + assert ok + assert batch[0].requires_grad == grad_mode + + +def test_worker_per_device(): + cpu = torch.device("cpu") + cpu0 = torch.device("cpu", index=0) + fake1 = fake_device() + fake2 = fake_device() + + with spawn_workers([cpu, cpu, cpu0, fake1, fake2]) as (in_queues, out_queues): + assert len(in_queues) == len(out_queues) == 5 + + # 0: cpu, 1: cpu, 2: cpu0 + assert in_queues[0] is in_queues[1] is in_queues[2] + assert out_queues[0] is out_queues[1] is out_queues[2] + + # 3: fake1, 4: fake2 + assert in_queues[3] is not in_queues[4] + assert out_queues[3] is not out_queues[4] diff --git a/test/distributed/test_c10d.py b/test/distributed/test_c10d.py old mode 100644 new mode 100755 index a81bc53f175a9..5492d6a9c3b2f --- a/test/distributed/test_c10d.py +++ b/test/distributed/test_c10d.py @@ -1,6 +1,6 @@ - import copy import math +import operator import os import random import signal @@ -9,44 +9,59 @@ import threading import time import unittest -from datetime import timedelta -from sys import platform from contextlib import contextmanager - -from itertools import groupby, product +from datetime import timedelta from functools import reduce -import operator +from itertools import groupby, product +from sys import platform import torch -from torch._six import string_classes -import torch.testing._internal.common_utils as common -from torch import nn -import torch.nn.functional as F import torch.distributed as c10d import torch.distributed as dist +import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default +import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD +import torch.nn.functional as F +import torch.testing._internal.common_utils as common +from torch import nn +from torch._six import string_classes from torch.nn.parallel import DistributedDataParallel - -from torch.testing._internal.common_distributed import MultiProcessTestCase, \ - requires_gloo, requires_nccl, requires_nccl_version, \ - skip_if_not_multigpu, skip_if_lt_x_gpu, get_timeout, skip_if_rocm, \ - simple_sparse_reduce_tests - -from torch.testing._internal.common_utils import TestCase, load_tests, run_tests, \ - retry_on_connect_failures, ADDRESS_IN_USE, CONNECT_TIMEOUT, TEST_WITH_TSAN +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + requires_gloo, + requires_nccl, + requires_nccl_version, + skip_if_not_multigpu, + skip_if_lt_x_gpu, + get_timeout, + skip_if_rocm, + simple_sparse_reduce_tests, + skip_if_win32, + create_device, +) +from torch.testing._internal.common_utils import ( + TestCase, + load_tests, + run_tests, + retry_on_connect_failures, + ADDRESS_IN_USE, + CONNECT_TIMEOUT, + TEST_WITH_TSAN, + slowTest, +) # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings load_tests = load_tests if not c10d.is_available(): - print('c10d not available, skipping tests', file=sys.stderr) + print("c10d not available, skipping tests", file=sys.stderr) sys.exit(0) -if platform == 'darwin': - LOOPBACK = 'lo0' +if platform == "darwin": + LOOPBACK = "lo0" else: - LOOPBACK = 'lo' + LOOPBACK = "lo" def gpus_for_rank(world_size): @@ -59,7 +74,9 @@ def gpus_for_rank(world_size): gpus_per_process = torch.cuda.device_count() // world_size gpus_for_rank = [] for rank in range(world_size): - gpus_for_rank.append(visible_devices[rank * gpus_per_process: (rank + 1) * gpus_per_process]) + gpus_for_rank.append( + visible_devices[rank * gpus_per_process : (rank + 1) * gpus_per_process] + ) return gpus_for_rank @@ -92,7 +109,7 @@ def simple_reduce_tests(rank, world_size): # that the output changes accordingly. for i in range(4): vin = rank | (1 << i) - vout = (1 << i) + vout = 1 << i tests.append( ( c10d.ReduceOp.BAND, @@ -139,27 +156,29 @@ def simple_coalesced_reduce_tests(rank, world_size): [torch.tensor([rank + 1]), torch.tensor([(rank + 1) ** 2])], [ torch.tensor([float(world_size * (world_size + 1) / 2)]), - torch.tensor([float(world_size * (world_size + 1) * (2 * world_size + 1) / 6)]) - ] + torch.tensor( + [float(world_size * (world_size + 1) * (2 * world_size + 1) / 6)] + ), + ], ), ( c10d.ReduceOp.PRODUCT, [torch.tensor([rank + 1.0]), torch.tensor([rank + 2.0])], [ torch.tensor([float(math.factorial(world_size))]), - torch.tensor([float(math.factorial(world_size + 1))]) - ] + torch.tensor([float(math.factorial(world_size + 1))]), + ], ), ( c10d.ReduceOp.MIN, [torch.tensor([rank + x]) for x in [0.0, 1.0]], - [torch.tensor([0.0]), torch.tensor([1.0])] + [torch.tensor([0.0]), torch.tensor([1.0])], ), ( c10d.ReduceOp.MAX, [torch.tensor([rank + x]) for x in [1.0, 2.0]], - [torch.tensor([world_size]), torch.tensor([world_size + 1.0])] - ) + [torch.tensor([world_size]), torch.tensor([world_size + 1.0])], + ), ] @@ -205,6 +224,7 @@ def _test_set_get(self, fs): fs.add("key3", 4) fs.add("key3", 5) fs.add("key3", 6) + self.assertEqual(fs.num_keys(), self.num_keys_total) self.assertEqual(b"6", fs.get("key")) self.assertEqual(b"value0", fs.get("key0")) self.assertEqual(b"value1", fs.get("key1")) @@ -214,6 +234,14 @@ def _test_set_get(self, fs): def test_set_get(self): self._test_set_get(self._create_store()) + # This is the number of keys used in test_set_get. Adding this as a class + # property instead of hardcoding in the test since some Store + # implementations will have differing number of keys. In the base case, + # there will be 5 keys: key, key0, key1, key2, key3. + @property + def num_keys_total(self): + return 5 + class FileStoreTest(TestCase, StoreTestBase): def setUp(self): @@ -257,13 +285,17 @@ def create_tcp_store(addr): class TCPStoreTest(TestCase, StoreTestBase): def _create_store(self): - store = create_tcp_store('localhost') + store = create_tcp_store("localhost") store.set_timeout(timedelta(seconds=300)) return store def test_address_already_in_use(self): - with self.assertRaisesRegex(RuntimeError, "^Address already in use$"): - addr = 'localhost' + if sys.platform == "win32": + err_msg_reg = "Only one usage of each socket address*" + else: + err_msg_reg = "^Address already in use$" + with self.assertRaisesRegex(RuntimeError, err_msg_reg): + addr = "localhost" port = common.find_free_port() # Use noqa to silence flake8. @@ -272,17 +304,58 @@ def test_address_already_in_use(self): store1 = c10d.TCPStore(addr, port, 1, True) # noqa: F841 store2 = c10d.TCPStore(addr, port, 1, True) # noqa: F841 + # The TCPStore has 6 keys in test_set_get. It contains the 5 keys added by + # the user and one additional key used for coordinate all the workers. + @property + def num_keys_total(self): + return 6 + + def _test_numkeys_delkeys(self, fs): + # We start off with one init key in the store to coordinate workers + self.assertEqual(fs.num_keys(), 1) + fs.add("key", 1) + fs.add("key", 2) + fs.add("key", 3) + fs.set("key0", "value0") + fs.add("key3", 1) + fs.set("key1", "value1") + self.assertEqual(fs.num_keys(), 5) + fs.delete_key("key") + self.assertEqual(fs.num_keys(), 4) + fs.set_timeout(timedelta(seconds=2)) + with self.assertRaises(RuntimeError): + fs.get("key") + fs.delete_key("key0") + fs.delete_key("key3") + self.assertEqual(fs.num_keys(), 2) + fs.set("key4", "value2") + self.assertEqual(fs.num_keys(), 3) + self.assertEqual(b"value1", fs.get("key1")) + self.assertEqual(b"value2", fs.get("key4")) + + # https://github.com/pytorch/pytorch/issues/46064 <- takes 5+ min to finish + @slowTest + def test_numkeys_delkeys(self): + self._test_numkeys_delkeys(self._create_store()) + class PrefixTCPStoreTest(TestCase, StoreTestBase): def setUp(self): super(PrefixTCPStoreTest, self).setUp() - self.tcpstore = create_tcp_store('localhost') + self.tcpstore = create_tcp_store("localhost") self.prefix = "test_prefix" self.tcpstore.set_timeout(timedelta(seconds=300)) def _create_store(self): return c10d.PrefixStore(self.prefix, self.tcpstore) + # The PrefixTCPStore has 6 keys in test_set_get. It contains the 5 keys + # added by the user and one additional key used for coordinate all the + # workers. + @property + def num_keys_total(self): + return 6 + class MyPythonStore(c10d.Store): def __init__(self): @@ -326,16 +399,16 @@ def test_set_get(self): class RendezvousTest(TestCase): def test_unknown_handler(self): with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"): - c10d.rendezvous('invalid://') + c10d.rendezvous("invalid://") class RendezvousEnvTest(TestCase): @retry_on_connect_failures + @requires_nccl() def test_common_errors(self): - # TODO remove this hack - if not hasattr(c10d, "ProcessGroupNCCL"): - raise unittest.SkipTest("C10D is not built with NCCL process group," - " skipping test") + if torch.cuda.device_count() == 0: + raise unittest.SkipTest("No GPUs available, skipping test") + vars = { "WORLD_SIZE": "1", "RANK": "0", @@ -366,71 +439,71 @@ def withouts(d, keys): d.pop(key) return d - with Env(without(vars, 'WORLD_SIZE')): - with self.assertRaisesRegex(ValueError, 'WORLD_SIZE expected'): - gen = c10d.rendezvous('env://') + with Env(without(vars, "WORLD_SIZE")): + with self.assertRaisesRegex(ValueError, "WORLD_SIZE expected"): + gen = c10d.rendezvous("env://") next(gen) - c10d.init_process_group(backend='nccl', world_size=1) + c10d.init_process_group(backend="nccl", world_size=1) self.assertEqual(c10d.get_rank(), 0) self.assertEqual(c10d.get_world_size(), 1) c10d.destroy_process_group() - with Env(without(vars, 'RANK')): - with self.assertRaisesRegex(ValueError, 'RANK expected'): - gen = c10d.rendezvous('env://') + with Env(without(vars, "RANK")): + with self.assertRaisesRegex(ValueError, "RANK expected"): + gen = c10d.rendezvous("env://") next(gen) - c10d.init_process_group(backend='nccl', rank=0) + c10d.init_process_group(backend="nccl", rank=0) self.assertEqual(c10d.get_rank(), 0) self.assertEqual(c10d.get_world_size(), 1) c10d.destroy_process_group() - with Env(withouts(vars, ['RANK', 'WORLD_SIZE'])): - c10d.init_process_group(backend='nccl', rank=0, world_size=1) + with Env(withouts(vars, ["RANK", "WORLD_SIZE"])): + c10d.init_process_group(backend="nccl", rank=0, world_size=1) self.assertEqual(c10d.get_rank(), 0) self.assertEqual(c10d.get_world_size(), 1) c10d.destroy_process_group() with Env(vars): - c10d.init_process_group(backend='nccl') + c10d.init_process_group(backend="nccl") self.assertEqual(c10d.get_rank(), 0) self.assertEqual(c10d.get_world_size(), 1) c10d.destroy_process_group() - with Env(without(vars, 'MASTER_ADDR')): - with self.assertRaisesRegex(ValueError, 'MASTER_ADDR expected'): - gen = c10d.rendezvous('env://') + with Env(without(vars, "MASTER_ADDR")): + with self.assertRaisesRegex(ValueError, "MASTER_ADDR expected"): + gen = c10d.rendezvous("env://") next(gen) - with Env(without(vars, 'MASTER_PORT')): - with self.assertRaisesRegex(ValueError, 'MASTER_PORT expected'): - gen = c10d.rendezvous('env://') + with Env(without(vars, "MASTER_PORT")): + with self.assertRaisesRegex(ValueError, "MASTER_PORT expected"): + gen = c10d.rendezvous("env://") next(gen) - with Env(without(vars, 'WORLD_SIZE')): - gen = c10d.rendezvous('env://?world_size={}'.format(1)) + with Env(without(vars, "WORLD_SIZE")): + gen = c10d.rendezvous("env://?world_size={}".format(1)) _, _, size = next(gen) self.assertEqual(size, 1) - with Env(without(vars, 'RANK')): - gen = c10d.rendezvous('env://?rank={}'.format(0)) + with Env(without(vars, "RANK")): + gen = c10d.rendezvous("env://?rank={}".format(0)) _, rank, _ = next(gen) self.assertEqual(rank, 0) - with Env(withouts(vars, ['RANK', 'WORLD_SIZE'])): - gen = c10d.rendezvous('env://?rank={}&world_size={}'.format(0, 1)) + with Env(withouts(vars, ["RANK", "WORLD_SIZE"])): + gen = c10d.rendezvous("env://?rank={}&world_size={}".format(0, 1)) _, rank, size = next(gen) self.assertEqual(rank, 0) self.assertEqual(size, 1) @retry_on_connect_failures def test_nominal(self): - os.environ['WORLD_SIZE'] = '1' - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = str(common.find_free_port()) + os.environ["WORLD_SIZE"] = "1" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(common.find_free_port()) # Single rank - os.environ['RANK'] = '0' - gen0 = c10d.rendezvous('env://') + os.environ["RANK"] = "0" + gen0 = c10d.rendezvous("env://") store0, rank0, size0 = next(gen0) self.assertEqual(0, rank0) self.assertEqual(1, size0) @@ -443,19 +516,19 @@ def test_nominal(self): class RendezvousFileTest(TestCase): def test_common_errors(self): - with self.assertRaisesRegex(ValueError, 'path missing'): - gen = c10d.rendezvous('file://?rank=0&world_size=1') + with self.assertRaisesRegex(ValueError, "path missing"): + gen = c10d.rendezvous("file://?rank=0&world_size=1") next(gen) - with self.assertRaisesRegex(ValueError, 'rank parameter missing'): - gen = c10d.rendezvous('file:///tmp/foo?world_size=1') + with self.assertRaisesRegex(ValueError, "rank parameter missing"): + gen = c10d.rendezvous("file:///tmp/foo?world_size=1") next(gen) - with self.assertRaisesRegex(ValueError, 'size parameter missing'): - gen = c10d.rendezvous('file:///tmp/foo?rank=0') + with self.assertRaisesRegex(ValueError, "size parameter missing"): + gen = c10d.rendezvous("file:///tmp/foo?rank=0") next(gen) def test_nominal(self): with tempfile.NamedTemporaryFile(delete=False) as file: - url = 'file://%s?world_size=%d' % (file.name, 2) + url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2' gen0 = c10d.rendezvous(url + "&rank=0") store0, rank0, size0 = next(gen0) self.assertEqual(0, rank0) @@ -474,23 +547,23 @@ def test_nominal(self): self.assertEqual(b"value1", store0.get("key1")) +@skip_if_win32() class RendezvousTCPTest(TestCase): - def create_tcp_url(self): addr = "localhost" port = common.find_free_port() - url = 'tcp://%s:%d?world_size=%d' % (addr, port, 1) + url = "tcp://%s:%d?world_size=%d" % (addr, port, 1) return url def test_common_errors(self): - with self.assertRaisesRegex(ValueError, 'port number missing'): - gen = c10d.rendezvous('tcp://127.0.0.1?rank=0&world_size=1') + with self.assertRaisesRegex(ValueError, "port number missing"): + gen = c10d.rendezvous("tcp://127.0.0.1?rank=0&world_size=1") next(gen) - with self.assertRaisesRegex(ValueError, 'rank parameter missing'): - gen = c10d.rendezvous('tcp://127.0.0.1:23456?world_size=1') + with self.assertRaisesRegex(ValueError, "rank parameter missing"): + gen = c10d.rendezvous("tcp://127.0.0.1:23456?world_size=1") next(gen) - with self.assertRaisesRegex(ValueError, 'size parameter missing'): - gen = c10d.rendezvous('tcp://127.0.0.1:23456?rank=0') + with self.assertRaisesRegex(ValueError, "size parameter missing"): + gen = c10d.rendezvous("tcp://127.0.0.1:23456?rank=0") next(gen) @retry_on_connect_failures @@ -528,8 +601,12 @@ class TimeoutTest(TestCase): def _test_store_timeout(self, backend, init_method, c2p): try: c10d.distributed_c10d.init_process_group( - backend=backend, init_method=init_method, world_size=1, rank=0, - timeout=timedelta(seconds=1)) + backend=backend, + init_method=init_method, + world_size=1, + rank=0, + timeout=timedelta(seconds=1), + ) default_store = c10d.distributed_c10d._get_default_store() tik = time.time() with self.assertRaisesRegex(RuntimeError, "Timeout"): @@ -544,16 +621,20 @@ def _test_store_timeout(self, backend, init_method, c2p): def _init_methods(self): f = tempfile.NamedTemporaryFile(delete=False) - yield "file://%s" % f.name - f.close() - yield "tcp://127.0.0.1:%d" % common.find_free_port() + if sys.platform == "win32": + yield "file:///%s" % f.name.replace("\\", "/") + f.close() + else: + yield "file://%s" % f.name + f.close() + yield "tcp://127.0.0.1:%d" % common.find_free_port() def _test_default_store_timeout(self, backend): for init_method in self._init_methods(): c2p = [] t = threading.Thread( - target=self._test_store_timeout, - args=(backend, init_method, c2p)) + target=self._test_store_timeout, args=(backend, init_method, c2p) + ) t.daemon = True t.start() t.join(5) @@ -571,24 +652,34 @@ def _test_default_store_timeout(self, backend): @requires_nccl() @retry_on_connect_failures def test_default_store_timeout_nccl(self): - self._test_default_store_timeout('nccl') + if torch.cuda.device_count() == 0: + raise unittest.SkipTest("No GPUs available, skipping test") + self._test_default_store_timeout("nccl") @requires_gloo() @retry_on_connect_failures def test_default_store_timeout_gloo(self): - self._test_default_store_timeout('gloo') + self._test_default_store_timeout("gloo") @requires_gloo() -@unittest.skipIf(TEST_WITH_TSAN, "TSAN is not fork-safe since we're forking in a multi-threaded environment") +@unittest.skipIf( + TEST_WITH_TSAN, + "TSAN is not fork-safe since we're forking in a multi-threaded environment", +) class ProcessGroupGlooTest(MultiProcessTestCase): def setUp(self): super(ProcessGroupGlooTest, self).setUp() - self._fork_processes() + + # For Windows platform, Python does not support fork, change it to spawn here. + if sys.platform == "win32": + self._spawn_processes() + else: + self._fork_processes() def opts(self, threads=2): opts = c10d.ProcessGroupGloo.Options() - opts.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)] + opts.devices = [create_device(interface=LOOPBACK)] opts.timeout = 5.0 opts.threads = threads return opts @@ -598,8 +689,8 @@ def test_multi_device_constructor(self): opts = c10d.ProcessGroupGloo.Options() opts.timeout = 5.0 opts.devices = [ - c10d.ProcessGroupGloo.create_device(interface=LOOPBACK), - c10d.ProcessGroupGloo.create_device(interface=LOOPBACK), + create_device(interface=LOOPBACK), + create_device(interface=LOOPBACK), ] pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, opts) @@ -713,7 +804,9 @@ def test_broadcast_basics_cuda(self): def _test_broadcast_stress(self, inputs): store = c10d.FileStore(self.file_name, self.world_size) - pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts(threads=8)) + pg = c10d.ProcessGroupGloo( + store, self.rank, self.world_size, self.opts(threads=8) + ) work_handles = [ pg.broadcast(inputs[i], root=(i % self.world_size)) for i in range(len(inputs)) @@ -721,9 +814,7 @@ def _test_broadcast_stress(self, inputs): for i, work_handle in enumerate(work_handles): work_handle.wait() self.assertEqual( - torch.tensor([ - (i * self.world_size) + (i % self.world_size) - ]), + torch.tensor([(i * self.world_size) + (i % self.world_size)]), inputs[i], msg=("Mismatch in iteration %d" % i), ) @@ -733,9 +824,10 @@ def test_broadcast_stress(self): self._test_broadcast_stress(inputs) @skip_if_not_multigpu - @skip_if_rocm def test_broadcast_stress_cuda(self): - inputs = [torch.tensor([i * self.world_size + self.rank]).cuda() for i in range(1000)] + inputs = [ + torch.tensor([i * self.world_size + self.rank]).cuda() for i in range(1000) + ] self._test_broadcast_stress(inputs) def test_allreduce_checks(self): @@ -789,7 +881,9 @@ def _test_allreduce_basics(self, fn): x = fn(torch.tensor([self.rank + 1.0])) work = pg.allreduce(x) work.wait() - self.assertEqual(torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]), x) + self.assertEqual( + torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]), x + ) def test_allreduce_basics(self): self._test_allreduce_basics(lambda t: t.clone()) @@ -800,16 +894,20 @@ def test_allreduce_basics_cuda(self): def _test_allreduce_stress(self, inputs): store = c10d.FileStore(self.file_name, self.world_size) - pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts(threads=8)) + pg = c10d.ProcessGroupGloo( + store, self.rank, self.world_size, self.opts(threads=8) + ) work_handles = [pg.allreduce(inputs[i]) for i in range(len(inputs))] for i, work_handle in enumerate(work_handles): work_handle.wait() # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType( - torch.tensor([ - (i * self.world_size) + - (self.world_size * (self.world_size - 1) / 2) - ]), + torch.tensor( + [ + (i * self.world_size) + + (self.world_size * (self.world_size - 1) / 2) + ] + ), inputs[i], msg=("Mismatch in iteration %d" % i), ) @@ -878,15 +976,25 @@ def test_allreduce_coalesced_basics(self): def _test_allreduce_coalesced_stress(self, inputs): store = c10d.FileStore(self.file_name, self.world_size) - pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts(threads=8)) + pg = c10d.ProcessGroupGloo( + store, self.rank, self.world_size, self.opts(threads=8) + ) work_handles = [pg.allreduce_coalesced(input) for input in inputs] for i, work_handle in enumerate(work_handles): work_handle.wait() # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType( - 2 * [torch.tensor([(i * self.world_size) + (self.world_size * (self.world_size - 1) / 2)])], + 2 + * [ + torch.tensor( + [ + (i * self.world_size) + + (self.world_size * (self.world_size - 1) / 2) + ] + ) + ], inputs[i], - msg="Mismatch in interation {}".format(i) + msg="Mismatch in interation {}".format(i), ) def test_allreduce_coalesced_stress(self): @@ -926,9 +1034,8 @@ def _test_sparse_allreduce_basics(self, fn): for num_inputs_per_rank in [1, 2]: tests = simple_sparse_reduce_tests( - self.rank, - self.world_size, - num_inputs=num_inputs_per_rank) + self.rank, self.world_size, num_inputs=num_inputs_per_rank + ) for (inputs, outputs) in tests: tensors = [fn(input) for input in inputs] work = pg.allreduce(tensors) @@ -940,7 +1047,6 @@ def test_sparse_allreduce_basics(self): self._test_sparse_allreduce_basics(lambda t: t) @skip_if_not_multigpu - @skip_if_rocm def test_sparse_allreduce_basics_cuda(self): self._test_sparse_allreduce_basics(lambda t: t.clone().cuda()) @@ -962,12 +1068,16 @@ def test_scatter_checks(self): opts.rootRank = self.world_size pg.scatter([t1], [], opts) - with self.assertRaisesRegex(ValueError, "requires a single-element output tensor list"): + with self.assertRaisesRegex( + ValueError, "requires a single-element output tensor list" + ): opts = c10d.ScatterOptions() opts.rootRank = 0 pg.scatter([], [], opts) - with self.assertRaisesRegex(ValueError, "requires a single-element output tensor list"): + with self.assertRaisesRegex( + ValueError, "requires a single-element output tensor list" + ): opts = c10d.ScatterOptions() opts.rootRank = 0 pg.scatter([t1, t1], [], opts) @@ -985,13 +1095,17 @@ def test_scatter_checks(self): desired_list_size = self.world_size incorrect_list_size = self.world_size - 1 err_str = "Incorrect input list size {}. Input list size should be {}" - with self.assertRaisesRegex(ValueError, err_str.format(incorrect_list_size, desired_list_size)): + with self.assertRaisesRegex( + ValueError, err_str.format(incorrect_list_size, desired_list_size) + ): opts = c10d.ScatterOptions() opts.rootRank = self.rank pg.scatter([t1], [[t1] * incorrect_list_size], opts) incorrect_list_size = self.world_size + 1 - with self.assertRaisesRegex(ValueError, err_str.format(incorrect_list_size, desired_list_size)): + with self.assertRaisesRegex( + ValueError, err_str.format(incorrect_list_size, desired_list_size) + ): opts = c10d.ScatterOptions() opts.rootRank = self.rank pg.scatter([t1], [[t1] * incorrect_list_size], opts) @@ -1043,7 +1157,9 @@ def test_scatter_basics_cuda(self): def _test_scatter_stress(self, inputs, fn): store = c10d.FileStore(self.file_name, self.world_size) - pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts(threads=8)) + pg = c10d.ProcessGroupGloo( + store, self.rank, self.world_size, self.opts(threads=8) + ) outputs = [ [fn(torch.tensor([-1])) for _ in range(self.world_size)] for _ in range(len(inputs)) @@ -1054,7 +1170,9 @@ def _test_scatter_stress(self, inputs, fn): opts = c10d.ScatterOptions() opts.rootRank = root if root == self.rank: - work = pg.scatter([outputs[i][root]], [[fn(e) for e in inputs[i]]], opts) + work = pg.scatter( + [outputs[i][root]], [[fn(e) for e in inputs[i]]], opts + ) else: work = pg.scatter([outputs[i][root]], [], opts) work_handles.append(work) @@ -1104,22 +1222,30 @@ def test_gather_checks(self): opts.rootRank = self.world_size pg.gather([], [t1], opts) - with self.assertRaisesRegex(ValueError, "requires a single-element input tensor list"): + with self.assertRaisesRegex( + ValueError, "requires a single-element input tensor list" + ): opts = c10d.GatherOptions() opts.rootRank = 0 pg.gather([], [], opts) - with self.assertRaisesRegex(ValueError, "requires a single-element input tensor list"): + with self.assertRaisesRegex( + ValueError, "requires a single-element input tensor list" + ): opts = c10d.GatherOptions() opts.rootRank = 0 pg.gather([], [t1, t1], opts) - with self.assertRaisesRegex(ValueError, "requires a single-element output list"): + with self.assertRaisesRegex( + ValueError, "requires a single-element output list" + ): opts = c10d.GatherOptions() opts.rootRank = self.rank pg.gather([], [t1], opts) - with self.assertRaisesRegex(ValueError, "requires a single-element output list"): + with self.assertRaisesRegex( + ValueError, "requires a single-element output list" + ): opts = c10d.GatherOptions() opts.rootRank = self.rank pg.gather([[t1] * self.world_size, [t1] * self.world_size], [t1], opts) @@ -1127,13 +1253,17 @@ def test_gather_checks(self): desired_list_size = self.world_size incorrect_list_size = self.world_size - 1 err_str = "Incorrect output list size {}. Output list size should be {}" - with self.assertRaisesRegex(ValueError, err_str.format(incorrect_list_size, desired_list_size)): + with self.assertRaisesRegex( + ValueError, err_str.format(incorrect_list_size, desired_list_size) + ): opts = c10d.GatherOptions() opts.rootRank = self.rank pg.gather([[t1] * incorrect_list_size], [t1], opts) incorrect_list_size = self.world_size + 1 - with self.assertRaisesRegex(ValueError, err_str.format(incorrect_list_size, desired_list_size)): + with self.assertRaisesRegex( + ValueError, err_str.format(incorrect_list_size, desired_list_size) + ): opts = c10d.GatherOptions() opts.rootRank = self.rank pg.gather([[t1] * incorrect_list_size], [t1], opts) @@ -1187,17 +1317,17 @@ def test_gather_basics_cuda(self): def _test_gather_stress(self, inputs, fn): store = c10d.FileStore(self.file_name, self.world_size) - pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts(threads=8)) + pg = c10d.ProcessGroupGloo( + store, self.rank, self.world_size, self.opts(threads=8) + ) work_handles = [] outputs = [ - [ - [fn(torch.tensor([-1])) for _ in range(self.world_size)] - ] for _ in range(len(inputs)) + [[fn(torch.tensor([-1])) for _ in range(self.world_size)]] + for _ in range(len(inputs)) ] expected_outputs = [ - [ - [torch.tensor([i + j]) for j in range(self.world_size)] - ] for i in range(len(inputs)) + [[torch.tensor([i + j]) for j in range(self.world_size)]] + for i in range(len(inputs)) ] for i in range(len(inputs)): for root in range(self.world_size): @@ -1217,7 +1347,7 @@ def _test_gather_stress(self, inputs, fn): self.assertEqual( expected_outputs[iter], outputs[iter], - msg=("Mismatch in iteration %d for root %d" % (iter, root)) + msg=("Mismatch in iteration %d for root %d" % (iter, root)), ) def test_gather_stress(self): @@ -1225,7 +1355,6 @@ def test_gather_stress(self): self._test_gather_stress(inputs, lambda t: t.clone()) @skip_if_not_multigpu - @skip_if_rocm def test_gather_stress_cuda(self): inputs = [torch.tensor([i + self.rank]).cuda() for i in range(1000)] self._test_gather_stress(inputs, lambda t: t.clone().cuda()) @@ -1241,10 +1370,14 @@ def test_allgather_checks(self): with self.assertRaisesRegex(ValueError, "requires non-empty input tensor list"): pg.allgather([], []) - with self.assertRaisesRegex(ValueError, "requires input/output tensor lists to have the same length"): + with self.assertRaisesRegex( + ValueError, "requires input/output tensor lists to have the same length" + ): pg.allgather([], [t1]) - with self.assertRaisesRegex(ValueError, "requires input/output tensor lists to have the same length"): + with self.assertRaisesRegex( + ValueError, "requires input/output tensor lists to have the same length" + ): pg.allgather([[t1] * self.world_size, [t1] * self.world_size], [t1]) with self.assertRaisesRegex(ValueError, "invalid output tensor list"): @@ -1254,16 +1387,20 @@ def test_allgather_checks(self): pg.allgather([[t1] * (self.world_size + 1)], [t1]) with self.assertRaisesRegex(ValueError, "invalid tensor type"): - pg.allgather([[t1, t1] * (self.world_size), [t1, t1] * (self.world_size)], [t1, t2]) + pg.allgather( + [[t1, t1] * (self.world_size), [t1, t1] * (self.world_size)], [t1, t2] + ) with self.assertRaisesRegex(ValueError, "invalid tensor size"): - pg.allgather([[t1, t1] * (self.world_size), [t1, t1] * (self.world_size)], [t1, t3]) + pg.allgather( + [[t1, t1] * (self.world_size), [t1, t1] * (self.world_size)], [t1, t3] + ) with self.assertRaisesRegex(ValueError, "invalid tensor type"): - pg.allgather([([t1, t2] * (self.world_size))[:self.world_size]], [t1]) + pg.allgather([([t1, t2] * (self.world_size))[: self.world_size]], [t1]) with self.assertRaisesRegex(ValueError, "invalid tensor size"): - pg.allgather([([t1, t3] * (self.world_size))[:self.world_size]], [t1]) + pg.allgather([([t1, t3] * (self.world_size))[: self.world_size]], [t1]) def _test_allgather_basics(self, fn): store = c10d.FileStore(self.file_name, self.world_size) @@ -1271,18 +1408,14 @@ def _test_allgather_basics(self, fn): # Run with N input tensor per rank for n in [1, 2, 3]: - input = [ - fn(torch.tensor([n * self.rank + i])) for i in range(n) - ] + input = [fn(torch.tensor([n * self.rank + i])) for i in range(n)] output = [ - [ - fn(torch.tensor([-1])) for _ in range(n * self.world_size) - ] for _ in range(n) + [fn(torch.tensor([-1])) for _ in range(n * self.world_size)] + for _ in range(n) ] expected_output = [ - [ - torch.tensor([i]) for i in range(n * self.world_size) - ] for _ in range(n) + [torch.tensor([i]) for i in range(n * self.world_size)] + for _ in range(n) ] work = pg.allgather(output, input) work.wait() @@ -1297,17 +1430,17 @@ def test_allgather_basics_cuda(self): def _test_allgather_stress(self, inputs, fn): store = c10d.FileStore(self.file_name, self.world_size) - pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts(threads=8)) + pg = c10d.ProcessGroupGloo( + store, self.rank, self.world_size, self.opts(threads=8) + ) work_handles = [] outputs = [ - [ - [fn(torch.tensor([-1])) for _ in range(self.world_size)] - ] for _ in range(len(inputs)) + [[fn(torch.tensor([-1])) for _ in range(self.world_size)]] + for _ in range(len(inputs)) ] expected_outputs = [ - [ - [torch.tensor([i + j]) for j in range(self.world_size)] - ] for i in range(len(inputs)) + [[torch.tensor([i + j]) for j in range(self.world_size)]] + for i in range(len(inputs)) ] for i in range(len(inputs)): work = pg.allgather(outputs[i], [fn(inputs[i])]) @@ -1326,7 +1459,6 @@ def test_allgather_stress(self): self._test_allgather_stress(inputs, lambda t: t.clone()) @skip_if_not_multigpu - @skip_if_rocm def test_allgather_stress_cuda(self): inputs = [torch.tensor([i + self.rank]).cuda() for i in range(1000)] self._test_allgather_stress(inputs, lambda t: t.clone().cuda()) @@ -1341,30 +1473,30 @@ def test_allgather_coalesced_checks(self): # One of output tensors does not match input list. dummy_output_lists[0] = [torch.zeros([0], dtype=torch.float32)] - with self.assertRaisesRegex(ValueError, - "invalid size of output tensor at index 0"): + with self.assertRaisesRegex( + ValueError, "invalid size of output tensor at index 0" + ): c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg) # One of output tensors does not match input list. dummy_output_lists[0] = [torch.zeros([1], dtype=torch.float64)] - with self.assertRaisesRegex(ValueError, - "invalid tensor type at index 0"): + with self.assertRaisesRegex(ValueError, "invalid tensor type at index 0"): c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg) # Output lists have too many elements dummy_output_lists = [ - [ - torch.zeros([1], dtype=torch.float32) - ] for _ in range(self.world_size + 1) + [torch.zeros([1], dtype=torch.float32)] for _ in range(self.world_size + 1) ] - with self.assertRaisesRegex(ValueError, - "output lists should be equal to world size"): + with self.assertRaisesRegex( + ValueError, "output lists should be equal to world size" + ): c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg) # Output is not a list of lists. dummy_output_lists = [torch.zeros([0], dtype=torch.float32)] - with self.assertRaisesRegex(RuntimeError, - "Invalid function argument.*output_tensor_lists"): + with self.assertRaisesRegex( + RuntimeError, "Invalid function argument.*output_tensor_lists" + ): c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg) def test_reduce_checks(self): @@ -1391,7 +1523,9 @@ def test_reduce_checks(self): opts.rootTensor = 1 pg.reduce([t1], opts) - with self.assertRaisesRegex(ValueError, "requires a single-element tensor list"): + with self.assertRaisesRegex( + ValueError, "requires a single-element tensor list" + ): opts = c10d.ReduceOptions() opts.rootRank = self.rank opts.rootTensor = 0 @@ -1421,7 +1555,9 @@ def test_reduce_basics_cuda(self): def _test_reduce_stress(self, inputs): store = c10d.FileStore(self.file_name, self.world_size) - pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts(threads=8)) + pg = c10d.ProcessGroupGloo( + store, self.rank, self.world_size, self.opts(threads=8) + ) work_handles = [] outputs = [] for i in range(len(inputs)): @@ -1440,10 +1576,12 @@ def _test_reduce_stress(self, inputs): if root == self.rank: # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType( - torch.tensor([ - (iter * self.world_size) + - (self.world_size * (self.world_size - 1) / 2) - ]), + torch.tensor( + [ + (iter * self.world_size) + + (self.world_size * (self.world_size - 1) / 2) + ] + ), outputs[i], msg=("Mismatch in iteration %d with root rank %d" % (iter, root)), ) @@ -1453,7 +1591,6 @@ def test_reduce_stress(self): self._test_reduce_stress(inputs) @skip_if_not_multigpu - @skip_if_rocm def test_reduce_stress_cuda(self): inputs = [torch.tensor([i + self.rank]).cuda() for i in range(1000)] self._test_reduce_stress(inputs) @@ -1514,34 +1651,40 @@ def test_barrier_implies_wait(self): for i, tensor in enumerate(tensors): self.assertEqual(torch.full(size, float(i * self.world_size)), tensor) + @skip_if_win32() def test_round_robin(self): num_process_groups = 2 store = c10d.FileStore(self.file_name, self.world_size) - pg = c10d._round_robin_process_groups([ - c10d.ProcessGroupGloo( - c10d.PrefixStore(str(i), store), - self.rank, - self.world_size) - for i in range(num_process_groups) - ]) + pg = c10d._round_robin_process_groups( + [ + c10d.ProcessGroupGloo( + c10d.PrefixStore(str(i), store), self.rank, self.world_size + ) + for i in range(num_process_groups) + ] + ) # Run a few collectives so that we have called each process group for _ in range(num_process_groups + 1): tensor = torch.full([100, 100], float(self.rank)) pg.broadcast(tensor, root=0).wait() - self.assertEqual(torch.full([100, 100], 0.), tensor) + self.assertEqual(torch.full([100, 100], 0.0), tensor) + @skip_if_win32() def test_round_robin_create_destroy(self): store = c10d.FileStore(self.file_name, self.world_size) def create(num, prefix): - return c10d._round_robin_process_groups([ - c10d.ProcessGroupGloo( - c10d.PrefixStore("%s/%d" % (prefix, i), store), - self.rank, - self.world_size) - for i in range(num) - ]) + return c10d._round_robin_process_groups( + [ + c10d.ProcessGroupGloo( + c10d.PrefixStore("%s/%d" % (prefix, i), store), + self.rank, + self.world_size, + ) + for i in range(num) + ] + ) # Run create/use/destroy twice for i in range(2): @@ -1554,12 +1697,33 @@ def create(num, prefix): del pg -@requires_nccl() +class ProcessGroupNCCLNoGPUTest(TestCase): + MAIN_PROCESS_RANK = 0 + + def setUp(self): + self.rank = self.MAIN_PROCESS_RANK + self.world_size = 1 + self.file = tempfile.NamedTemporaryFile(delete=False) + self.num_gpus = torch.cuda.device_count() + if self.num_gpus > 0: + raise unittest.SkipTest("GPUs are available, skipping test") + + def tearDown(self): + pass + + @requires_nccl() + def test_init_no_gpus(self): + store = c10d.FileStore(self.file.name, self.world_size) + with self.assertRaisesRegex( + RuntimeError, "ProcessGroupNCCL is only supported with GPUs, no GPUs found!" + ): + c10d.ProcessGroupNCCL(store, self.rank, self.world_size) + + @unittest.skipIf( TEST_WITH_TSAN, "TSAN is not fork-safe since we're forking in a multi-threaded environment", ) -@skip_if_rocm class ProcessGroupNCCLTest(TestCase): MAIN_PROCESS_RANK = 0 @@ -1574,6 +1738,7 @@ def setUp(self): def tearDown(self): pass + @requires_nccl() def test_empty_tensors(self): store = c10d.FileStore(self.file.name, self.world_size) pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) @@ -1598,6 +1763,7 @@ def test_empty_tensors(self): pg.reduce_scatter(ys, xs).wait() self.assertEqual(0, ys[0].numel()) + @requires_nccl() def test_broadcast_ops(self): store = c10d.FileStore(self.file.name, self.world_size) pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) @@ -1620,6 +1786,7 @@ def broadcast(xs, rootRank, rootTensor): for i in range(self.num_gpus): self.assertEqual(tensors[i], tensors[rt]) + @requires_nccl() def test_allreduce_ops(self): store = c10d.FileStore(self.file.name, self.world_size) pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) @@ -1641,7 +1808,8 @@ def allreduce(tensors, op): # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType( torch.tensor([float(self.num_gpus * (self.num_gpus + 1) / 2)]), - tensors[i]) + tensors[i], + ) # Product tensors = [] @@ -1653,8 +1821,8 @@ def allreduce(tensors, op): for i in range(self.num_gpus): # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType( - torch.tensor([float(math.factorial(self.num_gpus))]), - tensors[i]) + torch.tensor([float(math.factorial(self.num_gpus))]), tensors[i] + ) # Min tensors = [] @@ -1678,9 +1846,12 @@ def allreduce(tensors, op): self.assertEqual(torch.tensor([self.num_gpus]), tensors[i]) for op in (c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR): - with self.assertRaisesRegex(RuntimeError, "Cannot use " + str(op) + " with NCCL"): + with self.assertRaisesRegex( + RuntimeError, "Cannot use " + str(op) + " with NCCL" + ): allreduce(tensors, op) + @requires_nccl() def test_reduce_ops(self): store = c10d.FileStore(self.file.name, self.world_size) pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) @@ -1705,12 +1876,16 @@ def reduce(xs, rootRank, rootTensor, op=None): # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType( torch.tensor([float(self.num_gpus * (self.num_gpus + 1) / 2)]), - tensors[rt]) + tensors[rt], + ) for op in (c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR): - with self.assertRaisesRegex(RuntimeError, "Cannot use " + str(op) + " with NCCL"): + with self.assertRaisesRegex( + RuntimeError, "Cannot use " + str(op) + " with NCCL" + ): reduce(tensors, self.rank, rt, op) + @requires_nccl() def test_allgather_ops(self): store = c10d.FileStore(self.file.name, self.world_size) pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) @@ -1736,6 +1911,7 @@ def allgather(output_ts, input_ts): for s_idx, t in enumerate(device_ts): self.assertEqual(torch.tensor([s_idx]), t) + @requires_nccl() def test_reduce_scatter_ops(self): store = c10d.FileStore(self.file.name, self.world_size) pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) @@ -1749,10 +1925,7 @@ def reduce_scatter(outputs, input_lists, op): virtual_rank = self.rank * self.world_size virtual_world_size = self.num_gpus * self.world_size - output = [ - torch.tensor([0]).cuda(i) - for i in range(self.num_gpus) - ] + output = [torch.tensor([0]).cuda(i) for i in range(self.num_gpus)] # 0 1 2 # 0 [0..11] [1..12] @@ -1772,10 +1945,12 @@ def reduce_scatter(outputs, input_lists, op): reduce_scatter(output, tensor_lists, c10d.ReduceOp.SUM) for i in range(self.num_gpus): - expected = torch.tensor([ - float(self.num_gpus * (self.num_gpus - 1) / 2) + - (virtual_rank + i) * virtual_world_size - ]) + expected = torch.tensor( + [ + float(self.num_gpus * (self.num_gpus - 1) / 2) + + (virtual_rank + i) * virtual_world_size + ] + ) # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType(expected, output[i]) @@ -1798,9 +1973,9 @@ def reduce_scatter(outputs, input_lists, op): # Product tensor_lists = [ [ - torch.tensor([ - (self.rank * self.num_gpus + i + j) % virtual_world_size + 1 - ]).cuda(i) + torch.tensor( + [(self.rank * self.num_gpus + i + j) % virtual_world_size + 1] + ).cuda(i) for j in range(virtual_world_size) ] for i in range(self.num_gpus) @@ -1813,6 +1988,7 @@ def reduce_scatter(outputs, input_lists, op): # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType(expected, output[i]) + @requires_nccl() def test_barrier(self): store = c10d.FileStore(self.file.name, self.world_size) pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) @@ -1841,8 +2017,8 @@ def allreduce(tensors): for j in range(i): # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType( - torch.tensor([float(i * (i + 1) / 2)]), - tensors_list[i - 2][j]) + torch.tensor([float(i * (i + 1) / 2)]), tensors_list[i - 2][j] + ) class Net(nn.Module): @@ -1867,8 +2043,9 @@ def __init__(self, gpus): self.fc2 = nn.Linear(10, 50, bias=False).to(gpus[1]) self.fc3 = nn.Linear(50, 4, bias=False).to(gpus[1]) self.relu = nn.ReLU() - self.no_grad_param = nn.Parameter(torch.tensor([2, 2]).long(), - requires_grad=False).to(gpus[0]) + self.no_grad_param = nn.Parameter( + torch.tensor([2, 2]).long(), requires_grad=False + ).to(gpus[0]) def forward(self, x): dev0 = self.fc1.weight.device @@ -1887,8 +2064,9 @@ def __init__(self, gpus): self.fc3 = nn.Linear(50, 4, bias=False).to(gpus[2]) self.fc4 = nn.Linear(4, 4, bias=False).to(gpus[3]) self.relu = nn.ReLU() - self.no_grad_param = nn.Parameter(torch.tensor([2, 2]).long(), - requires_grad=False).to(gpus[0]) + self.no_grad_param = nn.Parameter( + torch.tensor([2, 2]).long(), requires_grad=False + ).to(gpus[0]) def forward(self, x): dev0 = self.fc1.weight.device @@ -1910,10 +2088,18 @@ def __init__(self, gpus, layouts, dtypes): self.layer_gpus = gpus else: gpus = [gpus] * 4 - self.conv0 = torch.nn.Conv2d(8, 16, (2, 2)).to(device=gpus[0], memory_format=layouts[0], dtype=dtypes[0]) - self.conv1 = torch.nn.Conv2d(16, 32, (2, 2)).to(device=gpus[1], memory_format=layouts[1], dtype=dtypes[1]) - self.conv2 = torch.nn.Conv2d(32, 16, (2, 2)).to(device=gpus[2], memory_format=layouts[2], dtype=dtypes[2]) - self.conv3 = torch.nn.Conv2d(16, 8, (2, 2)).to(device=gpus[3], memory_format=layouts[3], dtype=dtypes[3]) + self.conv0 = torch.nn.Conv2d(8, 16, (2, 2)).to( + device=gpus[0], memory_format=layouts[0], dtype=dtypes[0] + ) + self.conv1 = torch.nn.Conv2d(16, 32, (2, 2)).to( + device=gpus[1], memory_format=layouts[1], dtype=dtypes[1] + ) + self.conv2 = torch.nn.Conv2d(32, 16, (2, 2)).to( + device=gpus[2], memory_format=layouts[2], dtype=dtypes[2] + ) + self.conv3 = torch.nn.Conv2d(16, 8, (2, 2)).to( + device=gpus[3], memory_format=layouts[3], dtype=dtypes[3] + ) def forward(self, x): x = x.to(self.dtypes[0]) @@ -1955,11 +2141,17 @@ def forward(self, x): return F.softmax(self.embedding(x), dim=1) -@unittest.skipIf(TEST_WITH_TSAN, "TSAN is not fork-safe since we're forking in a multi-threaded environment") +@unittest.skipIf( + TEST_WITH_TSAN, + "TSAN is not fork-safe since we're forking in a multi-threaded environment", +) class DistributedDataParallelTest(MultiProcessTestCase): def setUp(self): super(DistributedDataParallelTest, self).setUp() - self._fork_processes() + if sys.platform == "win32": + self._spawn_processes() + else: + self._fork_processes() def tearDown(self): # DistributedDataParallel test doesn't seem to call FileStore destructor @@ -1975,14 +2167,21 @@ def world_size(self): return 2 def _prepare_single_device_module( - self, process_group, devices, device_ids, global_batch_size, gradient_as_bucket_view=False): + self, + process_group, + devices, + device_ids, + global_batch_size, + gradient_as_bucket_view=False, + ): model = Net() ddp_model = DistributedDataParallel( copy.deepcopy(model).to(devices[0]), device_ids=device_ids, process_group=process_group, bucket_cap_mb=0.001, - gradient_as_bucket_view=gradient_as_bucket_view) + gradient_as_bucket_view=gradient_as_bucket_view, + ) model.to(devices[0]) @@ -1991,10 +2190,18 @@ def _prepare_single_device_module( return model, ddp_model, input, target - def _prepare_multi_device_module(self, process_group, devices, device_ids, global_batch_size, gradient_as_bucket_view=False): + def _prepare_multi_device_module( + self, + process_group, + devices, + device_ids, + global_batch_size, + gradient_as_bucket_view=False, + ): self.assertTrue( len(devices) == 2 or len(devices) == 4, - "unexpected devices for ddp tests {}".format(devices)) + "unexpected devices for ddp tests {}".format(devices), + ) if len(devices) == 2: model = DoubleGpuNet(devices) elif len(devices) == 4: @@ -2005,14 +2212,22 @@ def _prepare_multi_device_module(self, process_group, devices, device_ids, globa device_ids=device_ids, process_group=process_group, bucket_cap_mb=0.001, - gradient_as_bucket_view=gradient_as_bucket_view) + gradient_as_bucket_view=gradient_as_bucket_view, + ) input = torch.randn(global_batch_size, 2).cuda(devices[0]) target = torch.randn(global_batch_size, 4) return model, ddp_model, input, target - def _test_ddp_with_process_group(self, process_group, devices, device_ids, multi_device=False, gradient_as_bucket_view=False): + def _test_ddp_with_process_group( + self, + process_group, + devices, + device_ids, + multi_device=False, + gradient_as_bucket_view=False, + ): """ Note: we pass down `device_ids` all the way to DistributedDataParallel as part of the test. Below you find tests that either use a list of @@ -2024,13 +2239,21 @@ def _test_ddp_with_process_group(self, process_group, devices, device_ids, multi global_batch_size = self.world_size * local_batch_size if multi_device: - model, ddp_model, input, target = \ - self._prepare_multi_device_module( - process_group, devices, device_ids, global_batch_size, gradient_as_bucket_view) + model, ddp_model, input, target = self._prepare_multi_device_module( + process_group, + devices, + device_ids, + global_batch_size, + gradient_as_bucket_view, + ) else: - model, ddp_model, input, target = \ - self._prepare_single_device_module( - process_group, devices, device_ids, global_batch_size, gradient_as_bucket_view) + model, ddp_model, input, target = self._prepare_single_device_module( + process_group, + devices, + device_ids, + global_batch_size, + gradient_as_bucket_view, + ) def step_model(model, input, target): model.train() @@ -2050,14 +2273,22 @@ def update_parameters(model): step_model(model, input, target) # DDP training, DDP scatters subsets of input_cpu to nodes/GPUs - step_model(ddp_model, - input[self.rank * local_batch_size: (self.rank + 1) * local_batch_size], - target[self.rank * local_batch_size: (self.rank + 1) * local_batch_size]) + step_model( + ddp_model, + input[ + self.rank * local_batch_size : (self.rank + 1) * local_batch_size + ], + target[ + self.rank * local_batch_size : (self.rank + 1) * local_batch_size + ], + ) # Update weights and run a second iteration to shake out errors update_parameters(model) update_parameters(ddp_model) - self.assertEqual(len(list(model.parameters())), len(list(ddp_model.parameters()))) + self.assertEqual( + len(list(model.parameters())), len(list(ddp_model.parameters())) + ) for i, j in zip(model.parameters(), ddp_model.parameters()): self.assertEqual(i, j) @@ -2065,12 +2296,18 @@ def update_parameters(model): torch.manual_seed(1337 + iteration) input = input[torch.randperm(global_batch_size)] - def _test_gloo_backend(self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False): + def _test_gloo_backend( + self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False + ): store = c10d.FileStore(self.file_name, self.world_size) options = c10d.ProcessGroupGloo.Options() - options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)] - process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options) - self._test_ddp_with_process_group(process_group, devices, device_ids, multi_device, gradient_as_bucket_view) + options.devices = [create_device(interface=LOOPBACK)] + process_group = c10d.ProcessGroupGloo( + store, self.rank, self.world_size, options + ) + self._test_ddp_with_process_group( + process_group, devices, device_ids, multi_device, gradient_as_bucket_view + ) @requires_gloo() def test_gloo_backend_cpu_module(self): @@ -2108,14 +2345,17 @@ def test_gloo_backend_4gpu_module(self): devices = [torch.device("cuda:" + str(i)) for i in int_devices] self._test_gloo_backend(devices, [], multi_device=True) - def _test_nccl_backend(self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False): + def _test_nccl_backend( + self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False + ): store = c10d.FileStore(self.file_name, self.world_size) process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) - self._test_ddp_with_process_group(process_group, devices, device_ids, multi_device, gradient_as_bucket_view) + self._test_ddp_with_process_group( + process_group, devices, device_ids, multi_device, gradient_as_bucket_view + ) @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_nccl_backend_1gpu_module_device_ids_integer_list(self): int_devices = gpus_for_rank(self.world_size)[self.rank][:1] devices = [torch.device("cuda:" + str(i)) for i in int_devices] @@ -2123,7 +2363,6 @@ def test_nccl_backend_1gpu_module_device_ids_integer_list(self): @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_nccl_backend_1gpu_module_device_ids_torch_device_list(self): int_devices = gpus_for_rank(self.world_size)[self.rank][:1] devices = [torch.device("cuda:" + str(i)) for i in int_devices] @@ -2131,7 +2370,6 @@ def test_nccl_backend_1gpu_module_device_ids_torch_device_list(self): @requires_nccl() @skip_if_lt_x_gpu(4) - @skip_if_rocm def test_nccl_backend_2gpu_module(self): int_devices = gpus_for_rank(self.world_size)[self.rank][:2] devices = [torch.device("cuda:" + str(i)) for i in int_devices] @@ -2139,7 +2377,6 @@ def test_nccl_backend_2gpu_module(self): @requires_nccl() @skip_if_lt_x_gpu(8) - @skip_if_rocm def test_nccl_backend_4gpu_module(self): int_devices = gpus_for_rank(self.world_size)[self.rank][:4] devices = [torch.device("cuda:" + str(i)) for i in int_devices] @@ -2147,7 +2384,6 @@ def test_nccl_backend_4gpu_module(self): @requires_nccl() @skip_if_lt_x_gpu(4) - @skip_if_rocm def test_ddp_multi_device_module_config(self): gpus = gpus_for_rank(self.world_size)[self.rank] @@ -2159,22 +2395,29 @@ def test_ddp_multi_device_module_config(self): gpus = gpus[:2] model = DoubleGpuNet(gpus) - with self.assertRaisesRegex(AssertionError, "output_device .* single-device GPU"): + with self.assertRaisesRegex( + AssertionError, "output_device .* single-device GPU" + ): ddp_model = DistributedDataParallel( - model, output_device=gpus[1], process_group=process_group) + model, output_device=gpus[1], process_group=process_group + ) with self.assertRaisesRegex(AssertionError, "device_ids .* single-device GPU"): ddp_model = DistributedDataParallel( - model, device_ids=gpus, process_group=process_group) + model, device_ids=gpus, process_group=process_group + ) - with self.assertRaisesRegex(AssertionError, "input module must be on the same type of devices"): + with self.assertRaisesRegex( + AssertionError, "input module must be on the same type of devices" + ): model.fc1 = model.fc1.cpu() ddp_model = DistributedDataParallel(model, process_group=process_group) model = model.cpu() with self.assertRaisesRegex(AssertionError, "device_ids .* single-device GPU"): ddp_model = DistributedDataParallel( - model, device_ids=gpus, process_group=process_group) + model, device_ids=gpus, process_group=process_group + ) def _test_fp16(self, gradient_as_bucket_view=False): store = c10d.FileStore(self.file_name, self.world_size) @@ -2188,13 +2431,13 @@ def _test_fp16(self, gradient_as_bucket_view=False): device_ids=[gpus[0]], process_group=process_group, bucket_cap_mb=0.001, - gradient_as_bucket_view=gradient_as_bucket_view + gradient_as_bucket_view=gradient_as_bucket_view, ) # Input 2**15, so that the gradients will overflow with a # world_size of 2, unless we normalize the gradient by the # world_size before the reduction - input = torch.tensor([[2**15]]).cuda(gpus[0]).half() + input = torch.tensor([[2 ** 15]]).cuda(gpus[0]).half() # Step model ddp_model.train() @@ -2202,19 +2445,15 @@ def _test_fp16(self, gradient_as_bucket_view=False): loss = output.sum() loss.backward() - self.assertFalse( - any(torch.isinf(p.grad).any() for p in ddp_model.parameters()) - ) + self.assertFalse(any(torch.isinf(p.grad).any() for p in ddp_model.parameters())) @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_fp16(self): self._test_fp16() @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_fp16_grad_is_view(self): self._test_fp16(gradient_as_bucket_view=True) @@ -2260,7 +2499,9 @@ def forward(self, x, fn): batch_size = 4 criterion = nn.CrossEntropyLoss() input = torch.rand([batch_size, 2], dtype=torch.float) - target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id) + target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to( + device_id + ) # Always run "backward" to ensure the reducer is called by autograd. # If we don't correctly capture the output tensors from the return value, @@ -2309,16 +2550,26 @@ def test(box, unbox): @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_arbitrary_forward_return_value(self): self._test_arbitrary_forward_return_value() @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_arbitrary_forward_return_value_grad_is_view(self): self._test_arbitrary_forward_return_value(gradient_as_bucket_view=True) + @requires_nccl() + @skip_if_not_multigpu + def test_ddp_with_lazy_parameters(self): + store = c10d.FileStore(self.file_name, self.world_size) + process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) + with self.assertRaisesRegex( + RuntimeError, "Modules with uninitialized parameters" + ): + DistributedDataParallel( + torch.nn.LazyLinear(10), process_group=process_group + ) + def _test_find_unused_parameters_kwarg(self, gradient_as_bucket_view=False): """ Note: this test can be sped up by only running it on a CPU module @@ -2347,9 +2598,13 @@ def forward(self, x): batch_size = 4 criterion = nn.CrossEntropyLoss() input = torch.rand([batch_size, 2], dtype=torch.float) - target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id) + target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to( + device_id + ) - def test_find_unused_parameters(find_unused_parameters, test_default=False, gradient_as_bucket_view=False): + def test_find_unused_parameters( + find_unused_parameters, test_default=False, gradient_as_bucket_view=False + ): if test_default: model = DistributedDataParallel( FindUnusedParametersModule().float().to(device_id), @@ -2375,35 +2630,40 @@ def test_find_unused_parameters(find_unused_parameters, test_default=False, grad # trigger an error when `backward` is called (because fc3 is an unused # parameter and will therefore be marked ready twice). try: - test_find_unused_parameters(True, gradient_as_bucket_view=gradient_as_bucket_view) + test_find_unused_parameters( + True, gradient_as_bucket_view=gradient_as_bucket_view + ) except Exception as ex: self.assertTrue( - str(ex).startswith("Expected to mark a variable ready only once.")) + str(ex).startswith("Expected to mark a variable ready only once.") + ) else: self.fail("Expected exception") # Then test that the default behavior can be overridden by setting # `find_unused_parameters=False`. try: - test_find_unused_parameters(False, gradient_as_bucket_view=gradient_as_bucket_view) + test_find_unused_parameters( + False, gradient_as_bucket_view=gradient_as_bucket_view + ) except Exception as ex: self.fail("Unexpected exception: %s" % ex) # Test find_unused_parameters defaults to False try: - test_find_unused_parameters(True, test_default=True, gradient_as_bucket_view=gradient_as_bucket_view) + test_find_unused_parameters( + True, test_default=True, gradient_as_bucket_view=gradient_as_bucket_view + ) except Exception as ex: self.fail("Unexpected exception: %s" % ex) @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_find_unused_parameters_kwarg(self): self._test_find_unused_parameters_kwarg() @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_find_unused_parameters_kwarg_grad_is_view(self): self._test_find_unused_parameters_kwarg(gradient_as_bucket_view=True) @@ -2413,6 +2673,7 @@ def _test_global_local_unused_params_grad(self, gradient_as_bucket_view=False): 1) DDP does not touch the grad of globally unused parameters. 2) DDP does update the grad of locally unused parameters. """ + class GlobalLocalUnusedParamModule(nn.Module): def __init__(self): super(GlobalLocalUnusedParamModule, self).__init__() @@ -2488,6 +2749,7 @@ def test_find_unused_parameters_when_unused_parameters_empty(self): This unit test creates a module that uses all parameters in rank = 0, and has unused parameters in other ranks. """ + class FindUnusedParamModule(nn.Module): def __init__(self): super(FindUnusedParamModule, self).__init__() @@ -2574,7 +2836,9 @@ def forward(self, x): batch_size = 4 criterion = nn.CrossEntropyLoss() input = torch.rand([batch_size, 2], dtype=torch.float) - target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id) + target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to( + device_id + ) # Compute loss and gradients for both outputs output1, output2 = model(input) @@ -2585,19 +2849,16 @@ def forward(self, x): @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_multiple_outputs_multiple_backward(self): self._test_multiple_outputs_multiple_backward() @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_multiple_outputs_multiple_backward_grad_is_view(self): self._test_multiple_outputs_multiple_backward(gradient_as_bucket_view=True) @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_no_grad(self): """ Note: this test can be sped up by only running it on a CPU module @@ -2644,7 +2905,9 @@ def check_no_grads(): # No parameter should have their gradient set. check_no_grads() - def _test_accumulate_gradients_no_sync(self, num_iters=2, ddp_comm_hook=None, gradient_as_bucket_view=False): + def _test_accumulate_gradients_no_sync( + self, num_iters=2, ddp_comm_hook=None, gradient_as_bucket_view=False + ): """ This is the recommended way to implement accumulate grads. If ``ddp_comm_hook`` input was specified, it will also register that hook @@ -2663,7 +2926,7 @@ def _test_accumulate_gradients_no_sync(self, num_iters=2, ddp_comm_hook=None, gr ) if ddp_comm_hook is not None: - ddp_model._register_comm_hook(process_group, ddp_comm_hook) + ddp_model.register_comm_hook(process_group, ddp_comm_hook) def step_model(model, input, target): model.train() @@ -2709,7 +2972,6 @@ def step_model(model, input, target): @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_accumulate_gradients_no_sync(self): """ Runs _test_accumulate_gradients_no_sync using default inputs @@ -2718,7 +2980,6 @@ def test_accumulate_gradients_no_sync(self): @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_accumulate_gradients_no_sync_grad_is_view(self): """ Runs _test_accumulate_gradients_no_sync using default inputs @@ -2727,7 +2988,6 @@ def test_accumulate_gradients_no_sync_grad_is_view(self): @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_accumulate_gradients_no_sync_allreduce_hook(self): """ Runs multiple iterations on _test_accumulate_gradients_no_sync @@ -2747,7 +3007,6 @@ def allreduce_hook( @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_accumulate_gradients_no_sync_allreduce_with_then_hook(self): """ Runs multiple iterations on _test_accumulate_gradients_no_sync using allreduce @@ -2785,9 +3044,9 @@ def _test_accumulate_gradients_module(self, gradient_as_bucket_view=False): process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) global_batch_size = self.world_size - model, ddp_model, input, target = \ - self._prepare_single_device_module( - process_group, devices, devices, global_batch_size, gradient_as_bucket_view) + model, ddp_model, input, target = self._prepare_single_device_module( + process_group, devices, devices, global_batch_size, gradient_as_bucket_view + ) def step_model(model, input, target): model.train() @@ -2810,15 +3069,17 @@ def step_model(model, input, target): # Skip gradients sync without calling prepare_for_backward step_model( ddp_model.module, - input[self.rank:(self.rank + 1)], - target[self.rank:(self.rank + 1)]) + input[self.rank : (self.rank + 1)], + target[self.rank : (self.rank + 1)], + ) for i, j in zip(model.parameters(), ddp_model.parameters()): self.assertNotEqual(i.grad, j.grad) else: step_model( ddp_model, - input[self.rank:(self.rank + 1)], - target[self.rank:(self.rank + 1)]) + input[self.rank : (self.rank + 1)], + target[self.rank : (self.rank + 1)], + ) for i, j in zip(model.parameters(), ddp_model.parameters()): # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType(i.grad, j.grad) @@ -2829,13 +3090,11 @@ def step_model(model, input, target): @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_accumulate_gradients_module(self): self._test_accumulate_gradients_module() @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_accumulate_gradients_module_with_grad_is_view(self): self._test_accumulate_gradients_module(gradient_as_bucket_view=True) @@ -2928,7 +3187,6 @@ def forward(self, x): @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_failure_recovery(self): store = c10d.FileStore(self.file_name, self.world_size) process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) @@ -2966,7 +3224,9 @@ def forward(self, x): batch_size = 4 criterion = nn.CrossEntropyLoss() input = torch.rand([batch_size, 2], dtype=torch.float) - target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id) + target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to( + device_id + ) for _ in range(6): output = ddp(input) @@ -2986,12 +3246,28 @@ def forward(self, x): ) input = torch.rand([batch_size, 2], dtype=torch.float) - target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id) + target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to( + device_id + ) for _ in range(6): output = ddp(input) loss = criterion(output, target) loss.backward() + @requires_nccl() + @skip_if_not_multigpu + def test_pass_default_pg(self): + dist.init_process_group( + "nccl", + init_method=f"file://{self.file_name}", + world_size=self.world_size, + rank=self.rank, + ) + + default_pg = c10d.distributed_c10d._get_default_group() + dist.destroy_process_group(default_pg) + self.assertFalse(dist.is_initialized()) + @requires_nccl() @skip_if_not_multigpu def test_save_load_checkpoint(self): @@ -2999,7 +3275,7 @@ def test_save_load_checkpoint(self): "gloo", init_method=f"file://{self.file_name}", world_size=self.world_size, - rank=self.rank + rank=self.rank, ) class TestModel(nn.Module): @@ -3051,7 +3327,9 @@ def train_loop(model, optimizer, iterations): optimizer_withoutload = torch.optim.SGD(ddp_withoutload.parameters(), lr=0.001) input = torch.rand([batch_size, 2], dtype=torch.float) - target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id) + target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to( + device_id + ) # run the model for 6 iterations, with a checkpoint in the middle train_loop(ddp_withload, optimizer_withload, 3) @@ -3065,19 +3343,21 @@ def train_loop(model, optimizer, iterations): for p in ddp_withload.parameters(): with torch.no_grad(): p.zero_() - map_location = {'cuda:%d' % 0: 'cuda:%d' % self.rank} + map_location = {"cuda:%d" % 0: "cuda:%d" % self.rank} ddp_withload.load_state_dict( - torch.load(checkpoint_path, map_location=map_location)) + torch.load(checkpoint_path, map_location=map_location) + ) train_loop(ddp_withload, optimizer_withload, 3) # re-run the model with the same inputs for 6 iterations with no checkpoint train_loop(ddp_withoutload, optimizer_withoutload, 6) - for p_withload, p_withoutload in zip(ddp_withload.parameters(), ddp_withoutload.parameters()): + for p_withload, p_withoutload in zip( + ddp_withload.parameters(), ddp_withoutload.parameters() + ): self.assertEqual(p_withload, p_withoutload) - def _run_and_verify_sparse_gradients(self, vanilla_model, ddp_model): mult = 2 batch_size = mult * self.world_size @@ -3131,17 +3411,25 @@ def _test_grad_layout(self, replica_devices, layer_devs, local_batch_size): # Carry out some trials with small buckets and some with big buckets. bucketsizes = (0.000001, 25) # Tuples of lists. Each list describes per-layer characteristics for one trial. - layer_formats = ([torch.contiguous_format] * 4, - [torch.channels_last] * 2 + [torch.contiguous_format] * 2, - [torch.channels_last] * 4) - layer_dtypes = ([torch.float] * 4, - [torch.float] * 2 + [torch.half] * 2, - [torch.half] * 4) + layer_formats = ( + [torch.contiguous_format] * 4, + [torch.channels_last] * 2 + [torch.contiguous_format] * 2, + [torch.channels_last] * 4, + ) + layer_dtypes = ( + [torch.float] * 4, + [torch.float] * 2 + [torch.half] * 2, + [torch.half] * 4, + ) input_dev = layer_devs[0] if isinstance(layer_devs, list) else layer_devs target_dev = layer_devs[-1] if isinstance(layer_devs, list) else layer_devs - input = torch.randn((global_batch_size, 8, 8, 8), device=input_dev, dtype=torch.float) - target = torch.randn((global_batch_size, 8, 4, 4), device=target_dev, dtype=torch.float) + input = torch.randn( + (global_batch_size, 8, 8, 8), device=input_dev, dtype=torch.float + ) + target = torch.randn( + (global_batch_size, 8, 4, 4), device=target_dev, dtype=torch.float + ) local_batch_start = self.rank * local_batch_size local_batch_end = (self.rank + 1) * local_batch_size @@ -3150,30 +3438,42 @@ def _test_grad_layout(self, replica_devices, layer_devs, local_batch_size): @contextmanager def first_bucket_size(ddp_bucket_mb): old_DEFAULT_FIRST_BUCKET_BYTES = dist._DEFAULT_FIRST_BUCKET_BYTES - dist._DEFAULT_FIRST_BUCKET_BYTES = int(ddp_bucket_mb * 1.e6) + dist._DEFAULT_FIRST_BUCKET_BYTES = int(ddp_bucket_mb * 1.0e6) try: yield finally: dist._DEFAULT_FIRST_BUCKET_BYTES = old_DEFAULT_FIRST_BUCKET_BYTES - with torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False): - for formats, dtypes, bucketsize in product(layer_formats, layer_dtypes, bucketsizes): + with torch.backends.cudnn.flags( + enabled=True, deterministic=True, benchmark=False + ): + for formats, dtypes, bucketsize in product( + layer_formats, layer_dtypes, bucketsizes + ): with first_bucket_size(bucketsize): - model_msg = "rank = {} formats = {} dtypes = {} bucketsize = {} ".format(self.rank, formats, - dtypes, bucketsize) + model_msg = ( + "rank = {} formats = {} dtypes = {} bucketsize = {} ".format( + self.rank, formats, dtypes, bucketsize + ) + ) try: m = ConvNet(layer_devs, formats, dtypes) - m_ddp = DistributedDataParallel(copy.deepcopy(m), - device_ids=replica_devices, - process_group=process_group, - bucket_cap_mb=bucketsize) + m_ddp = DistributedDataParallel( + copy.deepcopy(m), + device_ids=replica_devices, + process_group=process_group, + bucket_cap_mb=bucketsize, + ) opt = torch.optim.SGD(m.parameters(), lr=0.1) opt_ddp = torch.optim.SGD(m_ddp.parameters(), lr=0.1) has_half = any(p.dtype is torch.half for p in m.parameters()) - tol = 1.e-3 if has_half else 1.e-5 + tol = 1.0e-3 if has_half else 1.0e-5 except BaseException: # Prints case-specific debugging info to narrow down failing case. - print("Caught exception during model creation for " + model_msg, flush=True) + print( + "Caught exception during model creation for " + model_msg, + flush=True, + ) raise # 3 iters: First iter creates grads, second iter retests after rebucketing, # third iter tries zeroed grads. @@ -3182,19 +3482,38 @@ def first_bucket_size(ddp_bucket_mb): named_msg = iter_msg try: F.mse_loss(m(input).float(), target).backward() - F.mse_loss(m_ddp(input[local_batch_start: local_batch_end]).float(), - target[local_batch_start: local_batch_end]).backward() - for i, ((layer_name, m_child), m_ddp_child) in enumerate(zip(m.named_children(), - m_ddp.module.children())): + F.mse_loss( + m_ddp(input[local_batch_start:local_batch_end]).float(), + target[local_batch_start:local_batch_end], + ).backward() + for i, ((layer_name, m_child), m_ddp_child) in enumerate( + zip(m.named_children(), m_ddp.module.children()) + ): named_msg = layer_name + ".weight" + " " + iter_msg - self.assertTrue(m_child.weight.grad.is_contiguous(memory_format=formats[i]), - named_msg) - self.assertTrue(m_ddp_child.weight.grad.is_contiguous(memory_format=formats[i]), - named_msg) - for j, ((param_name, p), p_ddp) in enumerate(zip(m_child.named_parameters(), - m_ddp_child.parameters())): - named_msg = layer_name + "." + param_name + " " + iter_msg - self.assertEqual(p.grad, p_ddp.grad, rtol=tol, atol=tol) + self.assertTrue( + m_child.weight.grad.is_contiguous( + memory_format=formats[i] + ), + named_msg, + ) + self.assertTrue( + m_ddp_child.weight.grad.is_contiguous( + memory_format=formats[i] + ), + named_msg, + ) + for j, ((param_name, p), p_ddp) in enumerate( + zip( + m_child.named_parameters(), + m_ddp_child.parameters(), + ) + ): + named_msg = ( + layer_name + "." + param_name + " " + iter_msg + ) + self.assertEqual( + p.grad, p_ddp.grad, rtol=tol, atol=tol + ) opt.step() opt_ddp.step() if it == 0: @@ -3206,12 +3525,14 @@ def first_bucket_size(ddp_bucket_mb): m_ddp.zero_grad() except BaseException: # Makes sure we still get info if an error occurred somewhere other than the asserts. - print("Caught exception during iterations at " + named_msg, flush=True) + print( + "Caught exception during iterations at " + named_msg, + flush=True, + ) raise @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_grad_layout_1devicemodule_1replicaperprocess(self): dev0 = torch.device("cuda:" + str(gpus_for_rank(self.world_size)[self.rank][0])) # Tells DDP to use just one device. @@ -3221,10 +3542,11 @@ def test_grad_layout_1devicemodule_1replicaperprocess(self): local_batch_size = 8 self._test_grad_layout(replica_devices, layer_devs, local_batch_size) - @unittest.skipIf(True, "Reenable when DDP with multiple GPUs per process is confirmed to work") + @unittest.skipIf( + True, "Re-enable when DDP with multiple GPUs per process is confirmed to work" + ) @requires_nccl() @skip_if_lt_x_gpu(4) - @skip_if_rocm def test_grad_layout_1devicemodule_2replicaperprocess(self): int_devices = gpus_for_rank(self.world_size)[self.rank][:2] dev0 = torch.device("cuda:" + str(int_devices[0])) @@ -3252,22 +3574,32 @@ def test_grad_layout_2devicemodule(self): @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_param_layout_mismatch_error(self): store = c10d.FileStore(self.file_name, self.world_size) process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) dev0 = torch.device("cuda:" + str(gpus_for_rank(self.world_size)[self.rank][0])) layer_devs = dev0 - layer_formats = [torch.contiguous_format] * 4 if self.rank == 0 else [torch.channels_last] * 4 + layer_formats = ( + [torch.contiguous_format] * 4 + if self.rank == 0 + else [torch.channels_last] * 4 + ) layer_dtypes = [torch.float] * 4 m = ConvNet(layer_devs, layer_formats, layer_dtypes) if self.rank == 0: - m_ddp = DistributedDataParallel(m, device_ids=[dev0], process_group=process_group) + m_ddp = DistributedDataParallel( + m, device_ids=[dev0], process_group=process_group + ) else: - with self.assertRaisesRegex(RuntimeError, ".* appears not to match strides of the same param in process 0"): - m_ddp = DistributedDataParallel(m, device_ids=[dev0], process_group=process_group) + with self.assertRaisesRegex( + RuntimeError, + ".* appears not to match strides of the same param in process 0", + ): + m_ddp = DistributedDataParallel( + m, device_ids=[dev0], process_group=process_group + ) @requires_gloo() def test_ddp_comm_hook_future_passing_cpu(self): @@ -3284,13 +3616,32 @@ def test_ddp_comm_hook_future_passing_cpu(self): ) # Register DDP Communication Hook - cpu_model._register_comm_hook(None, self._simple_hook) + cpu_model.register_comm_hook(None, self._simple_hook) # check whether the grads are equal to what then callback returns. # without the comm_hook, result would be 0.25 * torch.ones(2, 2). self._run_and_verify_hook(cpu_model, 8, 2 * torch.ones(2, 2)) - def _gpu_model_with_ddp_comm_hook(self, process_group, hook=None, gradient_as_bucket_view=False): + def _gpu_model_with_ddp_comm_hook( + self, process_group, hook=None, gradient_as_bucket_view=False, state=None + ): + device_id = gpus_for_rank(self.world_size)[self.rank][0] + gpu_model = DistributedDataParallel( + ModuleForDdpCommHook().to(device_id), + device_ids=[device_id], + process_group=process_group, + gradient_as_bucket_view=gradient_as_bucket_view, + ) + + # Register a DDP communication hook if any. + if hook is not None: + gpu_model.register_comm_hook(state, hook) + + return gpu_model + + def _gpu_model_with_builtin_ddp_comm_hook( + self, process_group, hook=None, gradient_as_bucket_view=False + ): device_id = gpus_for_rank(self.world_size)[self.rank][0] gpu_model = DistributedDataParallel( ModuleForDdpCommHook().to(device_id), @@ -3299,9 +3650,9 @@ def _gpu_model_with_ddp_comm_hook(self, process_group, hook=None, gradient_as_bu gradient_as_bucket_view=gradient_as_bucket_view, ) - # Register DDP Communication Hook if defined + # Register a built-in DDP communication hook if defined if hook is not None: - gpu_model._register_comm_hook(None, hook) + gpu_model._register_builtin_comm_hook(hook) return gpu_model @@ -3345,7 +3696,6 @@ def test_ddp_comm_hook_future_passing_gpu_gloo(self): @requires_nccl() @skip_if_lt_x_gpu(2) - @skip_if_rocm def test_ddp_comm_hook_future_passing_gpu_nccl(self): """ This unit test verifies whether the Future object is passed properly using nccl backend. @@ -3364,7 +3714,7 @@ def test_ddp_comm_hook_future_passing_gpu_nccl(self): def _test_ddp_comm_hook_allreduce_hook_nccl(self, gradient_as_bucket_view=False): """ This unit test verifies whether a DDP communication hook that just calls - allreduce gives the same result result with the case of no hook registered. + allreduce gives the same result with the case of no hook registered. Without the then callback, the future_value in reducer is no longer a PyObject, and this unit test verifies future_value is properly checked. """ @@ -3376,26 +3726,121 @@ def allreduce_hook(state: object, bucket: dist._GradBucket) -> torch._C.Future: return process_group.allreduce(tensors).get_future() # Get GPU model with allreduce_hook registered. - gpu_model = self._gpu_model_with_ddp_comm_hook(process_group, allreduce_hook, gradient_as_bucket_view) + gpu_model = self._gpu_model_with_ddp_comm_hook( + process_group, allreduce_hook, gradient_as_bucket_view + ) # check whether the grads are equal to what DDP without hook would return. self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2)) + def _test_default_ddp_comm_hooks_nccl(self, gradient_as_bucket_view=False): + """ + This unit test verifies whether default Python DDP communication hooks ALLREDUCE and FP16_COMPRESS + can give the same result with the case of no hook registered. + """ + store = c10d.FileStore(self.file_name, self.world_size) + process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) + + # For these default DDP comm hooks, the only state is process group. + state = process_group + for hook in [default.allreduce_hook, default.fp16_compress_hook]: + # Get GPU model with the hook registered. + # The first arg 'process_group' is used for initializing the test environment, + # so it cannot be replaced by 'state', although they have the same value. + gpu_model = self._gpu_model_with_ddp_comm_hook( + process_group, hook, gradient_as_bucket_view, state + ) + + # check whether the grads are equal to what DDP without hook would return. + self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2)) + + def _test_powerSGD_ddp_comm_hook_nccl(self, gradient_as_bucket_view=False): + """ + This unit test verifies whether Python DDP communication hook POWER_SGD + can give the same result with the case of no hook registered. + """ + store = c10d.FileStore(self.file_name, self.world_size) + process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) + + # Get GPU model with the hook registered. + # Test the hook with different algorithmic configs. + for use_error_feedback, warm_start in product([True, False], [True, False]): + state = powerSGD.PowerSGDState( + process_group=process_group, + matrix_approximation_rank=1, + use_error_feedback=use_error_feedback, + warm_start=warm_start, + ) + for hook in [powerSGD.powerSGD_hook, powerSGD.batched_powerSGD_hook]: + gpu_model = self._gpu_model_with_ddp_comm_hook( + process_group, hook, gradient_as_bucket_view, state + ) + + # check whether the grads are equal to what DDP without hook would return. + self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2)) + + def _test_builtin_ddp_comm_hooks_nccl(self, gradient_as_bucket_view=False): + """ + This unit test verifies whether built-in C++ DDP communication hooks ALLREDUCE and FP16_COMPRESS + can give the same result with the case of no hook registered. + """ + store = c10d.FileStore(self.file_name, self.world_size) + process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) + + for comm_hook_type in [ + dist.BuiltinCommHookType.ALLREDUCE, + dist.BuiltinCommHookType.FP16_COMPRESS, + ]: + # Get GPU model with the built-in communication hook. + gpu_model = self._gpu_model_with_builtin_ddp_comm_hook( + process_group, comm_hook_type, gradient_as_bucket_view + ) + + # check whether the grads are equal to what DDP without hook would return. + self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2)) + @requires_nccl() @skip_if_lt_x_gpu(2) - @skip_if_rocm def test_ddp_comm_hook_allreduce_hook_nccl(self): self._test_ddp_comm_hook_allreduce_hook_nccl() @requires_nccl() @skip_if_lt_x_gpu(2) - @skip_if_rocm + def test_default_ddp_comm_hooks_nccl(self): + self._test_default_ddp_comm_hooks_nccl() + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_builtin_ddp_comm_hooks_nccl(self): + self._test_builtin_ddp_comm_hooks_nccl() + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_powerSGD_ddp_comm_hook_nccl(self): + self._test_powerSGD_ddp_comm_hook_nccl() + + @requires_nccl() + @skip_if_lt_x_gpu(2) def test_ddp_comm_hook_allreduce_hook_nccl_grad_is_view(self): self._test_ddp_comm_hook_allreduce_hook_nccl(gradient_as_bucket_view=True) @requires_nccl() @skip_if_lt_x_gpu(2) - @skip_if_rocm + def test_default_ddp_comm_hooks_nccl_is_view(self): + self._test_default_ddp_comm_hooks_nccl(gradient_as_bucket_view=True) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_builtin_ddp_comm_hooks_nccl_grad_is_view(self): + self._test_builtin_ddp_comm_hooks_nccl(gradient_as_bucket_view=True) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_powerSGD_ddp_comm_hook_nccl_grad_is_view(self): + self._test_powerSGD_ddp_comm_hook_nccl(gradient_as_bucket_view=True) + + @requires_nccl() + @skip_if_lt_x_gpu(2) def test_ddp_comm_hook_allreduce_with_then_hook_nccl(self): """ This unit test verifies whether a DDP communication hook that calls allreduce and then @@ -3439,10 +3884,12 @@ def test_ddp_invalid_comm_hook_init(self): store = c10d.FileStore(self.file_name, self.world_size) process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size) - model = DistributedDataParallel(ModuleForDdpCommHook(), process_group=process_group) + model = DistributedDataParallel( + ModuleForDdpCommHook(), process_group=process_group + ) with self.assertRaisesRegex(TypeError, "Communication hook must be callable."): - model._register_comm_hook(state=None, hook=1) + model.register_comm_hook(state=None, hook=1) with self.assertRaisesRegex( ValueError, "bucket annotation should be dist._GradBucket." @@ -3451,7 +3898,7 @@ def test_ddp_invalid_comm_hook_init(self): def comm_hook(state: object, bucket: int) -> torch.futures.Future: return torch.futures.Future() - model._register_comm_hook(state=None, hook=comm_hook) + model.register_comm_hook(state=None, hook=comm_hook) @requires_gloo() def test_ddp_invalid_comm_hook_return_type(self): @@ -3463,7 +3910,9 @@ def test_ddp_invalid_comm_hook_return_type(self): store = c10d.FileStore(self.file_name, self.world_size) process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size) - model = DistributedDataParallel(ModuleForDdpCommHook(), process_group=process_group) + model = DistributedDataParallel( + ModuleForDdpCommHook(), process_group=process_group + ) with self.assertRaisesRegex( ValueError, @@ -3473,7 +3922,7 @@ def test_ddp_invalid_comm_hook_return_type(self): def comm_hook(state: object, bucket: dist._GradBucket) -> int: return torch.futures.Future() - model._register_comm_hook(state=None, hook=comm_hook) + model.register_comm_hook(state=None, hook=comm_hook) with self.assertRaisesRegex( RuntimeError, @@ -3483,7 +3932,7 @@ def comm_hook(state: object, bucket: dist._GradBucket) -> int: def comm_hook(state: object, bucket: dist._GradBucket): return 1 - model._register_comm_hook(state=None, hook=comm_hook) + model.register_comm_hook(state=None, hook=comm_hook) # Run forward output = model(8, self.rank) @@ -3500,19 +3949,22 @@ def test_ddp_comm_hook_register_just_once(self): store = c10d.FileStore(self.file_name, self.world_size) process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size) - model = DistributedDataParallel(ModuleForDdpCommHook(), process_group=process_group) + model = DistributedDataParallel( + ModuleForDdpCommHook(), process_group=process_group + ) def dummy_hook(state, bucket): fut = torch.futures.Future() fut.set_result(bucket.get_tensors()) return fut - model._register_comm_hook(None, dummy_hook) + model.register_comm_hook(None, dummy_hook) with self.assertRaisesRegex( - RuntimeError, "register_comm_hook can only be called once." + RuntimeError, + "register_comm_hook or register_builtin_comm_hook can only be called once.", ): - model._register_comm_hook(None, dummy_hook) + model.register_comm_hook(None, dummy_hook) @requires_gloo() def test_ddp_comm_hook_sparse_gradients(self): @@ -3534,7 +3986,9 @@ def test_ddp_comm_hook_sparse_gradients(self): # "get_future" API does not support gloo backend, see GH Issue #42048. # Instead, we wait for an allreduce work, and write its result to a Future. - def allreduce_hook_gloo(state: object, bucket: dist._GradBucket) -> torch.futures.Future: + def allreduce_hook_gloo( + state: object, bucket: dist._GradBucket + ) -> torch.futures.Future: # Prepare allreduced grad bucket tensors by running an async work. work = process_group.allreduce(bucket.get_tensors()) work.wait() @@ -3543,7 +3997,7 @@ def allreduce_hook_gloo(state: object, bucket: dist._GradBucket) -> torch.future fut.set_result([t / self.world_size for t in bucket.get_tensors()]) return fut - ddp_model._register_comm_hook(None, allreduce_hook_gloo) + ddp_model.register_comm_hook(None, allreduce_hook_gloo) self._run_and_verify_sparse_gradients(vanilla_model, ddp_model) @@ -3597,19 +4051,23 @@ def test_multi_dtype_multi_bucket(self): model = self._create_mixed_precision_model() parameters = [list(model.parameters())] group_by_dtype = groupby( - range(len(parameters[0])), - key=lambda i: parameters[0][i].dtype) + range(len(parameters[0])), key=lambda i: parameters[0][i].dtype + ) buckets = [list(indices) for _, indices in group_by_dtype] dist.Reducer(parameters, buckets, self.process_group) def _create_reducer_for_models(self, models, find_unused_parameters=False): parameters = [list(model.parameters()) for model in models] group_by_dtype = groupby( - range(len(parameters[0])), - key=lambda i: parameters[0][i].dtype) + range(len(parameters[0])), key=lambda i: parameters[0][i].dtype + ) buckets = [list(indices) for _, indices in group_by_dtype] - return dist.Reducer(parameters, buckets, self.process_group, - find_unused_parameters=find_unused_parameters) + return dist.Reducer( + parameters, + buckets, + self.process_group, + find_unused_parameters=find_unused_parameters, + ) def test_forward_backward_single_replica(self): batch_size = 10 @@ -3751,8 +4209,10 @@ def test_multi_limit_multi_dtype(self): self.assertEqual([[0], [1], [2, 4], [3, 5]], result) -@skip_if_rocm -@unittest.skipIf(TEST_WITH_TSAN, "TSAN is not fork-safe since we're forking in a multi-threaded environment") +@unittest.skipIf( + TEST_WITH_TSAN, + "TSAN is not fork-safe since we're forking in a multi-threaded environment", +) class NcclErrorHandlingTest(MultiProcessTestCase): def setUp(self): super(NcclErrorHandlingTest, self).setUp() @@ -3781,13 +4241,16 @@ def op_timeout_sec(self): def world_size(self): return 3 + @property + def blocking_wait_error_msg(self): + return "Caught collective operation timeout" + def _run_all_reduce(self, pg): pg.allreduce(torch.rand(10).cuda(self.rank)) @requires_nccl() @requires_nccl_version(2400, "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm def test_nccl_errors_nonblocking(self): store = c10d.FileStore(self.file_name, self.world_size) process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) @@ -3814,16 +4277,17 @@ def _test_nccl_errors_blocking(self, func): store, self.rank, self.world_size, - timeout=timedelta(seconds=self.op_timeout_sec)) + timeout=timedelta(seconds=self.op_timeout_sec), + ) process_group.allreduce(torch.rand(10).cuda(self.rank)) if self.rank == 0: work = process_group.allreduce(torch.rand(10).cuda(self.rank)) - with self.assertRaisesRegex(RuntimeError, "Operation timed out!"): + with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg): # Operation would time out in blocking mode. work.wait() - # Run some GPU operations to make sure cuda does not stuck to - # run new events. It was observed cuda could stuck if not - # aborting nccl communicators before throwing Operation timed out + # Run some GPU operations to make sure cuda has not gotten stuck. + # It was observed cuda could get stuck if NCCL communicators were + # not properly aborted before throwing RuntimeError. a = torch.rand(10).cuda(self.rank) elif self.rank == 1: # Clean up structures (ex: files for FileStore before going down) @@ -3839,42 +4303,36 @@ def _test_nccl_errors_blocking(self, func): @requires_nccl() @requires_nccl_version(2400, "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm def test_nccl_errors_blocking_clean_exit(self): self._test_nccl_errors_blocking(lambda: sys.exit(0)) @requires_nccl() @requires_nccl_version(2400, "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm def test_nccl_errors_blocking_nonzero_exit(self): self._test_nccl_errors_blocking(lambda: sys.exit(1)) @requires_nccl() @requires_nccl_version(2400, "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm def test_nccl_errors_blocking_abort(self): self._test_nccl_errors_blocking(lambda: os.abort()) @requires_nccl() @requires_nccl_version(2400, "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm def test_nccl_errors_blocking_sigkill(self): self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGKILL)) @requires_nccl() @requires_nccl_version(2400, "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm def test_nccl_errors_blocking_sigterm(self): self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGTERM)) @requires_nccl() @requires_nccl_version(2400, "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm def test_nccl_blocking_wait_with_barrier(self): os.environ["NCCL_BLOCKING_WAIT"] = "1" store = c10d.FileStore(self.file_name, self.world_size) @@ -3882,10 +4340,11 @@ def test_nccl_blocking_wait_with_barrier(self): store, self.rank, self.world_size, - timeout=timedelta(seconds=self.op_timeout_sec)) + timeout=timedelta(seconds=self.op_timeout_sec), + ) process_group.barrier().wait() if self.rank == 0: - with self.assertRaisesRegex(RuntimeError, "Operation timed out!"): + with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg): # This should timeout process_group.barrier().wait() @@ -3897,17 +4356,16 @@ def _run_invalid_nccl_blocking_wait_env(self, val): @requires_nccl() @skip_if_lt_x_gpu(3) - @skip_if_rocm def test_invalid_nccl_blocking_wait_env(self): - self._run_invalid_nccl_blocking_wait_env('abc') - self._run_invalid_nccl_blocking_wait_env('-1') - self._run_invalid_nccl_blocking_wait_env('2147483647') - self._run_invalid_nccl_blocking_wait_env('4294967295') + self._run_invalid_nccl_blocking_wait_env("abc") + self._run_invalid_nccl_blocking_wait_env("-1") + self._run_invalid_nccl_blocking_wait_env("2147483647") + self._run_invalid_nccl_blocking_wait_env("4294967295") def _wait_for_comm_abort(self, process_group): - ''' + """ Waits for the watchdog thread to abort communicators for the process group. - ''' + """ while True: try: process_group.allreduce(torch.rand(10).cuda(self.rank)) @@ -3916,25 +4374,26 @@ def _wait_for_comm_abort(self, process_group): return else: raise e - time.sleep(0.1) + time.sleep(1) @requires_nccl() @skip_if_lt_x_gpu(3) - @skip_if_rocm def test_nccl_timeout(self): store = c10d.FileStore(self.file_name, self.world_size) os.environ["NCCL_BLOCKING_WAIT"] = "1" # Initialize process_group. timeout = 1 - process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timedelta(seconds=timeout)) + process_group = c10d.ProcessGroupNCCL( + store, self.rank, self.world_size, timeout=timedelta(seconds=timeout) + ) process_group.allreduce(torch.rand(10).cuda(self.rank)).wait() if self.rank == 0: # This should timeout in about 1 second. start = time.time() # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out. - with self.assertRaisesRegex(RuntimeError, "Operation timed out!"): + with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg): process_group.allreduce(torch.rand(10).cuda(self.rank)).wait() else: # Sleep to ensure timeout. @@ -3943,11 +4402,17 @@ def test_nccl_timeout(self): self._wait_for_comm_abort(process_group) -@unittest.skipIf(TEST_WITH_TSAN, "TSAN is not fork-safe since we're forking in a multi-threaded environment") +@unittest.skipIf( + TEST_WITH_TSAN, + "TSAN is not fork-safe since we're forking in a multi-threaded environment", +) class CommTest(MultiProcessTestCase): def setUp(self): super(CommTest, self).setUp() - self._fork_processes() + if sys.platform == "win32": + self._spawn_processes() + else: + self._fork_processes() def tearDown(self): super(CommTest, self).tearDown() @@ -3989,17 +4454,14 @@ def _test_broadcast_coalesced(self, process_group, device, root_rank): self.assertNotEqual(tensors, target) c10d._broadcast_coalesced( - process_group, - tensors, - buffer_size=256, - src=root_rank) + process_group, tensors, buffer_size=256, src=root_rank + ) if self.rank != root_rank: self.assertEqual(tensors, target) @requires_nccl() @skip_if_not_multigpu - @skip_if_rocm def test_broadcast_coalesced_nccl(self): store = c10d.FileStore(self.file_name, self.world_size) process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) @@ -4013,8 +4475,10 @@ def test_broadcast_coalesced_nccl(self): def test_broadcast_coalesced_gloo_cuda(self): store = c10d.FileStore(self.file_name, self.world_size) options = c10d.ProcessGroupGloo.Options() - options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)] - process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options) + options.devices = [create_device(interface=LOOPBACK)] + process_group = c10d.ProcessGroupGloo( + store, self.rank, self.world_size, options + ) device = torch.device("cuda:%d" % self.rank) ranks = list(range(self.world_size)) for root_rank in ranks: @@ -4024,15 +4488,139 @@ def test_broadcast_coalesced_gloo_cuda(self): def test_broadcast_coalesced_gloo_cpu(self): store = c10d.FileStore(self.file_name, self.world_size) options = c10d.ProcessGroupGloo.Options() - options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)] - process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options) + options.devices = [create_device(interface=LOOPBACK)] + process_group = c10d.ProcessGroupGloo( + store, self.rank, self.world_size, options + ) device = torch.device("cpu") ranks = list(range(self.world_size)) for root_rank in ranks: self._test_broadcast_coalesced(process_group, device, root_rank) + @requires_nccl() + @skip_if_lt_x_gpu(4) + def test_nccl_barrier(self): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="nccl", + rank=self.rank, + world_size=self.world_size, + store=store) + + t = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank) + c10d.all_reduce(t) + expected_tensor = torch.tensor([3] * 10).cuda(2 * self.rank) + self.assertEqual(expected_tensor, t) + + # Test with new_group + pg = c10d.new_group([0, 1]) + t = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank) + pg.allreduce(t).wait() + self.assertEqual(expected_tensor, t) + + pg = c10d.new_group([0]) + if self.rank == 0: + t = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank) + expected_tensor = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank) + pg.allreduce(t).wait() + self.assertEqual(expected_tensor, t) + + pg = c10d.new_group([1]) + if self.rank == 1: + t = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank) + expected_tensor = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank) + pg.allreduce(t).wait() + self.assertEqual(expected_tensor, t) + + @requires_nccl() + @skip_if_lt_x_gpu(4) + def test_nccl_barrier_timeout(self): + store = c10d.FileStore(self.file_name, self.world_size) + if self.rank == 0: + with self.assertRaisesRegex(RuntimeError, "Timed out initializing process group"): + c10d.init_process_group( + backend="nccl", + rank=self.rank, + world_size=self.world_size, + store=store, + timeout=timedelta(seconds=1)) + + @requires_nccl() + @skip_if_lt_x_gpu(4) + def test_nccl_barrier_timeout_new_group(self): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="nccl", + rank=self.rank, + world_size=self.world_size, + store=store, + timeout=timedelta(seconds=1)) + + if self.rank == 0: + with self.assertRaisesRegex(RuntimeError, "Timed out initializing process group"): + c10d.new_group([0, 1], timeout=timedelta(seconds=1)) + + with self.assertRaisesRegex(RuntimeError, "Timed out initializing process group"): + c10d.new_group([0], timeout=timedelta(seconds=1)) + + @requires_nccl() + @skip_if_lt_x_gpu(4) + def test_nccl_barrier_timeout_new_group_non_member(self): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="nccl", + rank=self.rank, + world_size=self.world_size, + store=store, + timeout=timedelta(seconds=1)) + + if self.rank == 1: + with self.assertRaisesRegex(RuntimeError, "Timed out initializing process group"): + c10d.new_group([0, 1], timeout=timedelta(seconds=1)) + + with self.assertRaisesRegex(RuntimeError, "Timed out initializing process group"): + c10d.new_group([0], timeout=timedelta(seconds=1)) + + @requires_nccl() + @skip_if_not_multigpu + def test_nccl_barrier_device_ids(self): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="nccl", + rank=self.rank, + world_size=self.world_size, + store=store) + + c10d.barrier(device_ids=[self.rank]) + + @requires_nccl() + @skip_if_not_multigpu + def test_nccl_barrier_device_ids_function_argument(self): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="nccl", + rank=self.rank, + world_size=self.world_size, + store=store) + + with self.assertRaisesRegex(RuntimeError, "Invalid function argument"): + c10d.barrier(device_ids=self.rank) + + @requires_gloo() + def test_gloo_barrier_device_ids(self): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="gloo", + rank=self.rank, + world_size=self.world_size, + store=store) + + with self.assertRaisesRegex(RuntimeError, "device_ids not supported"): + c10d.barrier(device_ids=[self.rank]) -if __name__ == '__main__': - assert not torch.cuda._initialized, "test_distributed must not have initialized CUDA context on main process" +if __name__ == "__main__": + assert ( + not torch.cuda._initialized + ), "test_distributed must not have initialized CUDA context on main process" run_tests() diff --git a/test/distributed/test_c10d_spawn.py b/test/distributed/test_c10d_spawn.py index d0bf00b8a08a9..0bba18bf3ab92 100644 --- a/test/distributed/test_c10d_spawn.py +++ b/test/distributed/test_c10d_spawn.py @@ -10,8 +10,10 @@ import torch.nn as nn from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU -from torch.testing._internal.common_distributed import requires_gloo -from torch.testing._internal.common_utils import TestCase, load_tests, run_tests, skipIfRocm +from torch.testing._internal.common_distributed import requires_gloo, \ + create_device +from torch.testing._internal.common_utils import TestCase, load_tests, \ + run_tests, skipIfRocm from torch.testing._internal.common_utils import NO_MULTIPROCESSING_SPAWN, TEST_WITH_TSAN @@ -39,7 +41,7 @@ class ProcessGroupShareTensorTest(TestCase): @classmethod def opts(cls, threads=2): opts = c10d.ProcessGroupGloo.Options() - opts.devices = [c10d.ProcessGroupGloo.create_device(interface="lo")] + opts.devices = [create_device(interface='lo')] opts.timeout = 5.0 opts.threads = threads return opts @@ -229,9 +231,18 @@ def tearDown(self): def _test_base(self, net, inp, check_allclose=True): store = c10d.FileStore(self.file.name, self.world_size) process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size) + if inp[0].is_cuda: + num_gpus = torch.cuda.device_count() + batch_size = inp[0].size(0) + # batch_size must be evenly divisible by num_gpus_used, take the largest one + num_gpus_used = [i for i in range(1, num_gpus + 1) if batch_size % i == 0][-1] + device_ids = list(range(num_gpus_used)) + else: + device_ids = None ddp = nn.parallel.DistributedDataParallel( copy.deepcopy(net), + device_ids=device_ids, process_group=process_group ) @@ -270,7 +281,7 @@ def test_cuda(self): def test_rnn(self): # This test is inspired by the bug reported in # https://github.com/pytorch/pytorch/issues/36268 - BATCH_SIZE = 4 + BATCH_SIZE = 12 # Divisible by 2, 3, 4 INPUT_DIM = 256 OUTPUT_DIM = 256 HIDDEN_DIM = 256 diff --git a/test/distributed/test_data_parallel.py b/test/distributed/test_data_parallel.py index 99a10906462a8..f3161a1f8cb1e 100644 --- a/test/distributed/test_data_parallel.py +++ b/test/distributed/test_data_parallel.py @@ -18,6 +18,7 @@ torch.set_default_dtype(torch.double) +NO_NCCL = not hasattr(torch.distributed, "ProcessGroupNCCL") class TestDataParallel(TestCase): @@ -78,6 +79,13 @@ def step(model): for p1, p2 in zip(model.parameters(), model_dp.parameters()): self.assertTrue(p1.allclose(p2)) + @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") + def test_data_parallel_lazy_linear(self): + + with self.assertRaisesRegex(RuntimeError, 'Modules with uninitialized parameters'): + model_dp = torch.nn.DataParallel(torch.nn.LazyLinear(10).to(0)) + model_dp(torch.rand(10, 10).to(0)) + @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") def test_parallel_apply(self): l1 = nn.Linear(10, 5).to("cuda:0", torch.float) @@ -507,6 +515,25 @@ def forward(self, input): self.assertEqual(out.get_device(), 0) self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0) + @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") + def test_data_parallel_module_zero_inputs(self): + class TestModule(nn.Module): + def forward(self): + t = torch.eye(2, 3, device='cuda:0') + return t + (1 - t) + + def test_helper(output, expected): + self.assertEqual(output.get_device(), 0) + self.assertEqual(output, expected) + + expected = torch.ones(2, 3, device='cuda:0') + model = TestModule() + + test_helper(nn.DataParallel(model, [0])(), expected) + test_helper(nn.DataParallel(model, [0, 1])(), expected) + test_helper(dp.data_parallel(model, None, [0]), expected) + test_helper(dp.data_parallel(model, (), [0, 1]), expected) + @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") def test_data_parallel_device_args(self): cuda0 = torch.device('cuda:0') @@ -571,6 +598,25 @@ def test_scatter_cpu(self): def test_scatter_gpu(self): self._test_scatter(torch.randn((4, 4)).cuda()) + @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed") + @unittest.skipIf(NO_NCCL, "NCCL needed") + def test_data_parallel_complex(self): + # We expect complex parameters to be broadcast by view_as_real, e.g. move from C to R^2 + class Cplx(torch.nn.Module): + def __init__(self): + super().__init__() + self.cplx = torch.nn.Parameter(torch.zeros(1, 10, dtype=torch.cfloat).cuda()) + + def forward(self, x): + return x + self.cplx + + cplx = torch.nn.DataParallel(Cplx().cuda()) + input = torch.rand(1, 10, dtype=torch.cfloat).cuda() + result = cplx(input) + # 2 is the extra real view dimension here + self.assertEqual(result.size(), torch.Size([1, 10, 2])) + self.assertEqual(result, torch.view_as_real(input)) + def _test_gather(self, output_device): inputs = ( torch.randn(2, 4, device='cuda:0', requires_grad=True), diff --git a/test/distributed/test_distributed_fork.py b/test/distributed/test_distributed_fork.py index 293eba1f278a6..84d23e71af951 100644 --- a/test/distributed/test_distributed_fork.py +++ b/test/distributed/test_distributed_fork.py @@ -54,7 +54,7 @@ def setUp(self): WORLD_SIZE = os.environ["WORLD_SIZE"] dist.init_process_group(init_method=INIT_METHOD, backend="mpi") - class TestMPI(DistributedTest._DistTestBase): + class TestMPIWithFork(TestCase, DistributedTest._DistTestBase): pass elif BACKEND == "test": diff --git a/test/distributed/test_jit_c10d.py b/test/distributed/test_jit_c10d.py new file mode 100644 index 0000000000000..ca43c25c644a1 --- /dev/null +++ b/test/distributed/test_jit_c10d.py @@ -0,0 +1,232 @@ +import unittest +import tempfile +from sys import platform +import torch +import torch.distributed as c10d +import time +from datetime import timedelta +from typing import List + +import torch.testing._internal.common_utils as common +from torch.testing._internal.common_distributed import requires_nccl, skip_if_rocm_single_process +from torch.testing._internal.common_utils import load_tests, TEST_WITH_TSAN, run_tests, IS_WINDOWS +from torch.testing._internal.jit_utils import JitTestCase + +# load_tests from common_utils is used to automatically filter tests for +# sharding on sandcastle. This line silences flake warnings +load_tests = load_tests + +if not c10d.is_available(): + print('c10d not available, skipping tests', file=sys.stderr) + sys.exit(0) + +if platform == 'darwin': + LOOPBACK = 'lo0' +else: + LOOPBACK = 'lo' + +def unique_process_group_name(prefix): + # Append timestamp to process group name to make it unique, so + # that when tests run multiple times or in parallel there + # wouldn't be name conflicts. + now = int(time.time() * 1000) + return "%s_%d" % (prefix, now) + +def _create_tcp_store(): + addr = "localhost" + port = common.find_free_port() + timeout = timedelta(minutes=5) + timeout_millisecond = int(timeout / timedelta(milliseconds=1)) + return torch.classes.dist_c10d.TCPStore(addr, port, 1, True, timeout_millisecond) + + +@unittest.skipIf( + TEST_WITH_TSAN, + "TSAN is not fork-safe since we're forking in a multi-threaded environment", +) +@unittest.skipIf(IS_WINDOWS, "TCPStore not available on Windows") +class ProcessGroupNCCLJitTest(JitTestCase): + MAIN_PROCESS_RANK = 0 + + def setUp(self): + self.rank = self.MAIN_PROCESS_RANK + self.world_size = 1 + self.file = tempfile.NamedTemporaryFile(delete=False) + self.num_gpus = torch.cuda.device_count() + if self.num_gpus < 2: + raise unittest.SkipTest("NCCL test requires 2+ GPUs") + + def _create_nccl_pg(self, name_prefix): + tcp_store = _create_tcp_store() + opts = torch.classes.dist_c10d.ProcessGroupNCCLOptions(0, True) + + name = unique_process_group_name(name_prefix) + + return torch.classes.dist_c10d.ProcessGroupNCCL(tcp_store, self.rank, self.world_size, opts, name) + + def _create_nccl_pg_as_base_process_group(self, name): + tcp_store = _create_tcp_store() + + return torch.classes.dist_c10d.frontend().new_process_group_helper( + self.world_size, self.rank, [], "nccl", tcp_store, name, 0) + + @requires_nccl() + @skip_if_rocm_single_process + def test_init_process_group_nccl_torchbind(self): + self._create_nccl_pg("raw_process_group_nccl_torchbind") + + @requires_nccl() + @skip_if_rocm_single_process + def test_process_group_nccl_torchbind_alltoall(self): + nccl_pg = self._create_nccl_pg("process_group_nccl_as_base_class") + + input = torch.rand(16).cuda() + output = torch.rand(16).cuda() + + @torch.jit.script + def run_pg_nccl_alltoall( + pg: torch.classes.dist_c10d.ProcessGroupNCCL, + output: torch.Tensor, + input: torch.Tensor + ): + output_split_sizes: List[int] = [] + input_split_sizes: List[int] = [] + work = pg.alltoall_base(output, input, output_split_sizes, input_split_sizes) + work.wait() + return work.result() + + run_pg_nccl_alltoall(nccl_pg, output, input) + + @requires_nccl() + @skip_if_rocm_single_process + def test_init_process_group_nccl_as_base_process_group_torchbind(self): + name = unique_process_group_name("creation_test_process_group") + self._create_nccl_pg_as_base_process_group(name) + + @requires_nccl() + @skip_if_rocm_single_process + def test_process_group_nccl_as_base_process_group_torchbind_alltoall(self): + name = unique_process_group_name("alltoall_test_process_group") + nccl_pg = self._create_nccl_pg_as_base_process_group(name) + + input = torch.rand(16).cuda() + output = torch.rand(16).cuda() + + @torch.jit.script + def run_pg_nccl_alltoall( + pg: torch.classes.dist_c10d.ProcessGroup, + output: torch.Tensor, + input: torch.Tensor + ): + output_split_sizes: List[int] = [] + input_split_sizes: List[int] = [] + work = pg.alltoall_base(output, input, output_split_sizes, input_split_sizes) + work.wait() + return work.result() + + run_pg_nccl_alltoall(nccl_pg, output, input) + + @requires_nccl() + @skip_if_rocm_single_process + def test_process_group_nccl_serialization(self): + class TestModule(torch.nn.Module): + def __init__(self, pg_nccl): + super(TestModule, self).__init__() + self.pg = pg_nccl + + def forward(self, input: torch.Tensor): + if self.pg is None: + return input + 1 + else: + return input + 2 + + pg_nccl = self._create_nccl_pg("nccl_process_group_as_module_member") + self.checkModule(TestModule(pg_nccl), (torch.rand((2, 3)),)) + + +class StoreTest(JitTestCase): + def setUp(self): + super(StoreTest, self).setUp() + self.file = tempfile.NamedTemporaryFile(delete=False) + self.filestore = torch.classes.dist_c10d.FileStore(self.file.name, 1) + self.prefix = "test_prefix" + + def test_create_file_store(self): + # test FileStore creation in JIT + @torch.jit.script + def create_file_store( + path: str, + num_workers: int + ) -> torch.classes.dist_c10d.FileStore: + return torch.classes.dist_c10d.FileStore(path, num_workers) + + create_file_store(self.file.name, 1) + + def test_create_prefix_store(self): + # test PrefixStore creation in JIT + @torch.jit.script + def create_prefix_file_store( + store: torch.classes.dist_c10d.Store, + prefix: str + ) -> torch.classes.dist_c10d.PrefixStore: + return torch.classes.dist_c10d.PrefixStore(prefix, store) + + create_prefix_file_store(self.filestore, self.prefix) + + +@unittest.skipIf(IS_WINDOWS, "TCPStore not available on Windows") +class C10dFrontendJitTest(JitTestCase): + def setUp(self): + self.rank = 0 + self.world_size = 1 + self.file = tempfile.NamedTemporaryFile(delete=False) + self.num_gpus = torch.cuda.device_count() + if self.num_gpus < 2: + raise unittest.SkipTest("NCCL test requires 2+ GPUs") + + @requires_nccl() + @skip_if_rocm_single_process + def test_frontend_singleton(self): + frontend1 = torch.classes.dist_c10d.frontend() + frontend2 = torch.classes.dist_c10d.frontend() + + tcp_store = _create_tcp_store() + + pg_name = unique_process_group_name("singleton_test_process_group") + + ProcessGroupNCCL1 = frontend1.new_process_group_helper( + self.world_size, self.rank, [], "nccl", tcp_store, pg_name, 0) + + ProcessGroupNCCL2 = frontend2.get_process_group_by_name(pg_name) + self.assertEqual(frontend2.get_name_of_process_group(ProcessGroupNCCL2), pg_name) + +@unittest.skipIf(IS_WINDOWS, "TCPStore not available on Windows") +class C10dProcessGroupSerialization(JitTestCase): + def setUp(self): + self.num_gpus = torch.cuda.device_count() + if self.num_gpus < 2: + raise unittest.SkipTest("NCCL test requires 2+ GPUs") + + @requires_nccl() + @skip_if_rocm_single_process + def test_process_group_as_module_member(self): + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + tcp_store = _create_tcp_store() + + name = unique_process_group_name("module_member_process_group") + self.pg = torch.classes.dist_c10d.frontend().new_process_group_helper( + 1, 0, [], "nccl", tcp_store, name, 0) + + def forward(self, input: torch.Tensor): + if self.pg is None: + return input + 1 + else: + return input + 2 + + self.checkModule(TestModule(), (torch.rand((2, 3)),)) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributions/test_constraints.py b/test/distributions/test_constraints.py new file mode 100644 index 0000000000000..d4dd9239920de --- /dev/null +++ b/test/distributions/test_constraints.py @@ -0,0 +1,95 @@ +import pytest + +import torch +from torch.distributions import biject_to, constraints, transform_to +from torch.testing._internal.common_cuda import TEST_CUDA + + +CONSTRAINTS = [ + (constraints.real,), + (constraints.positive,), + (constraints.greater_than, [-10., -2, 0, 2, 10]), + (constraints.greater_than, 0), + (constraints.greater_than, 2), + (constraints.greater_than, -2), + (constraints.greater_than_eq, 0), + (constraints.greater_than_eq, 2), + (constraints.greater_than_eq, -2), + (constraints.less_than, [-10., -2, 0, 2, 10]), + (constraints.less_than, 0), + (constraints.less_than, 2), + (constraints.less_than, -2), + (constraints.unit_interval,), + (constraints.interval, [-4., -2, 0, 2, 4], [-3., 3, 1, 5, 5]), + (constraints.interval, -2, -1), + (constraints.interval, 1, 2), + (constraints.half_open_interval, [-4., -2, 0, 2, 4], [-3., 3, 1, 5, 5]), + (constraints.half_open_interval, -2, -1), + (constraints.half_open_interval, 1, 2), + (constraints.simplex,), + (constraints.corr_cholesky,), + (constraints.lower_cholesky,), +] + + +def build_constraint(constraint_fn, args, is_cuda=False): + if not args: + return constraint_fn + t = torch.cuda.DoubleTensor if is_cuda else torch.DoubleTensor + return constraint_fn(*(t(x) if isinstance(x, list) else x for x in args)) + + +@pytest.mark.parametrize('constraint_fn, args', [(c[0], c[1:]) for c in CONSTRAINTS]) +@pytest.mark.parametrize('is_cuda', [False, + pytest.param(True, marks=pytest.mark.skipif(not TEST_CUDA, + reason='CUDA not found.'))]) +def test_biject_to(constraint_fn, args, is_cuda): + constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda) + try: + t = biject_to(constraint) + except NotImplementedError: + pytest.skip('`biject_to` not implemented.') + assert t.bijective, "biject_to({}) is not bijective".format(constraint) + if constraint_fn is constraints.corr_cholesky: + # (D * (D-1)) / 2 (where D = 4) = 6 (size of last dim) + x = torch.randn(6, 6, dtype=torch.double) + else: + x = torch.randn(5, 5, dtype=torch.double) + if is_cuda: + x = x.cuda() + y = t(x) + assert constraint.check(y).all(), '\n'.join([ + "Failed to biject_to({})".format(constraint), + "x = {}".format(x), + "biject_to(...)(x) = {}".format(y), + ]) + x2 = t.inv(y) + assert torch.allclose(x, x2), "Error in biject_to({}) inverse".format(constraint) + + j = t.log_abs_det_jacobian(x, y) + assert j.shape == x.shape[:x.dim() - t.input_event_dim] + + +@pytest.mark.parametrize('constraint_fn, args', [(c[0], c[1:]) for c in CONSTRAINTS]) +@pytest.mark.parametrize('is_cuda', [False, + pytest.param(True, marks=pytest.mark.skipif(not TEST_CUDA, + reason='CUDA not found.'))]) +def test_transform_to(constraint_fn, args, is_cuda): + constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda) + t = transform_to(constraint) + if constraint_fn is constraints.corr_cholesky: + # (D * (D-1)) / 2 (where D = 4) = 6 (size of last dim) + x = torch.randn(6, 6, dtype=torch.double) + else: + x = torch.randn(5, 5, dtype=torch.double) + if is_cuda: + x = x.cuda() + y = t(x) + assert constraint.check(y).all(), "Failed to transform_to({})".format(constraint) + x2 = t.inv(y) + y2 = t(x2) + assert torch.allclose(y, y2), "Error in transform_to({}) pseudoinverse".format(constraint) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/test/test_distributions.py b/test/distributions/test_distributions.py similarity index 91% rename from test/test_distributions.py rename to test/distributions/test_distributions.py index 86cc5a0531d1e..a196169be142b 100644 --- a/test/test_distributions.py +++ b/test/distributions/test_distributions.py @@ -39,30 +39,28 @@ from torch.testing._internal.common_utils import TestCase, run_tests, set_rng_seed, TEST_WITH_UBSAN, load_tests from torch.testing._internal.common_cuda import TEST_CUDA from torch.autograd import grad, gradcheck +from torch.autograd.functional import jacobian from torch.distributions import (Bernoulli, Beta, Binomial, Categorical, Cauchy, Chi2, ContinuousBernoulli, Dirichlet, Distribution, Exponential, ExponentialFamily, FisherSnedecor, Gamma, Geometric, Gumbel, - HalfCauchy, HalfNormal, - Independent, Laplace, LogisticNormal, + HalfCauchy, HalfNormal, Independent, Kumaraswamy, + LKJCholesky, Laplace, LogisticNormal, LogNormal, LowRankMultivariateNormal, MixtureSameFamily, Multinomial, MultivariateNormal, - NegativeBinomial, Normal, OneHotCategorical, Pareto, - Poisson, RelaxedBernoulli, RelaxedOneHotCategorical, + NegativeBinomial, Normal, + OneHotCategorical, OneHotCategoricalStraightThrough, + Pareto, Poisson, RelaxedBernoulli, RelaxedOneHotCategorical, StudentT, TransformedDistribution, Uniform, VonMises, Weibull, constraints, kl_divergence) -from torch.distributions.constraint_registry import biject_to, transform_to +from torch.distributions.constraint_registry import transform_to from torch.distributions.constraints import Constraint, is_dependent from torch.distributions.dirichlet import _Dirichlet_backward from torch.distributions.kl import _kl_expfamily_expfamily -from torch.distributions.transforms import (AbsTransform, AffineTransform, - CatTransform, ComposeTransform, ExpTransform, - LowerCholeskyTransform, - PowerTransform, SigmoidTransform, - TanhTransform, SoftmaxTransform, - StickBreakingTransform, - identity_transform, StackTransform) -from torch.distributions.utils import probs_to_logits, lazy_property +from torch.distributions.transforms import (AffineTransform, CatTransform, ExpTransform, + StackTransform, identity_transform) +from torch.distributions.utils import (probs_to_logits, lazy_property, tril_matrix_to_vec, + vec_to_tril_matrix) from torch.nn.functional import softmax # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for @@ -240,6 +238,30 @@ def is_all_nan(tensor): 'reinterpreted_batch_ndims': 3, }, ]), + Example(Kumaraswamy, [ + { + 'concentration1': torch.empty(2, 3).uniform_(1, 2).requires_grad_(), + 'concentration0': torch.empty(2, 3).uniform_(1, 2).requires_grad_(), + }, + { + 'concentration1': torch.rand(4).uniform_(1, 2).requires_grad_(), + 'concentration0': torch.rand(4).uniform_(1, 2).requires_grad_(), + }, + ]), + Example(LKJCholesky, [ + { + 'dim': 2, + 'concentration': 0.5 + }, + { + 'dim': 3, + 'concentration': torch.tensor([0.5, 1., 2.]), + }, + { + 'dim': 100, + 'concentration': 4. + }, + ]), Example(Laplace, [ { 'loc': torch.randn(5, 5, requires_grad=True), @@ -335,6 +357,11 @@ def is_all_nan(tensor): {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)}, {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)}, ]), + Example(OneHotCategoricalStraightThrough, [ + {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)}, + {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)}, + {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)}, + ]), Example(Pareto, [ { 'scale': 1.0, @@ -604,6 +631,10 @@ def is_all_nan(tensor): {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True)}, {'probs': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)}, ]), + Example(OneHotCategoricalStraightThrough, [ + {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True)}, + {'probs': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)}, + ]), Example(Pareto, [ { 'scale': 0.0, @@ -696,7 +727,7 @@ def _gradcheck_log_prob(self, dist_ctor, ctor_params): # performs gradient checks on log_prob distribution = dist_ctor(*ctor_params) s = distribution.sample() - if s.is_floating_point(): + if not distribution.support.is_discrete: s = s.detach().requires_grad_() expected_shape = distribution.batch_shape + distribution.event_shape @@ -1391,7 +1422,7 @@ def test_uniform(self): self.assertEqual(Uniform(0.0, 1.0).sample((1,)).size(), (1,)) # Check log_prob computation when value outside range - uniform = Uniform(low_1d, high_1d) + uniform = Uniform(low_1d, high_1d, validate_args=False) above_high = torch.tensor([4.0]) below_low = torch.tensor([-1.0]) self.assertEqual(uniform.log_prob(above_high).item(), -inf) @@ -1486,7 +1517,7 @@ def test_halfcauchy(self): def test_halfnormal(self): std = torch.randn(5, 5).abs().requires_grad_() - std_1d = torch.randn(1, requires_grad=True) + std_1d = torch.randn(1).abs().requires_grad_() std_delta = torch.tensor([1e-5, 1e-5]) self.assertEqual(HalfNormal(std).sample().size(), (5, 5)) self.assertEqual(HalfNormal(std).sample((7,)).size(), (7, 5, 5)) @@ -1947,6 +1978,8 @@ def gradcheck_func(samples, mu, sigma, prec, scale_tril): sigma = 0.5 * (sigma + sigma.transpose(-1, -2)) # Ensure symmetry of covariance if prec is not None: prec = 0.5 * (prec + prec.transpose(-1, -2)) # Ensure symmetry of precision + if scale_tril is not None: + scale_tril = scale_tril.tril() return MultivariateNormal(mu, sigma, prec, scale_tril).log_prob(samples) gradcheck(gradcheck_func, (mvn_samples, mean, covariance, precision, scale_tril), raise_exception=True) @@ -2249,6 +2282,42 @@ def test_gumbel_sample(self): scipy.stats.gumbel_r(loc=loc, scale=scale), 'Gumbel(loc={}, scale={})'.format(loc, scale)) + def test_kumaraswamy_shape(self): + concentration1 = torch.randn(2, 3).abs().requires_grad_() + concentration0 = torch.randn(2, 3).abs().requires_grad_() + concentration1_1d = torch.randn(1).abs().requires_grad_() + concentration0_1d = torch.randn(1).abs().requires_grad_() + self.assertEqual(Kumaraswamy(concentration1, concentration0).sample().size(), (2, 3)) + self.assertEqual(Kumaraswamy(concentration1, concentration0).sample((5,)).size(), (5, 2, 3)) + self.assertEqual(Kumaraswamy(concentration1_1d, concentration0_1d).sample().size(), (1,)) + self.assertEqual(Kumaraswamy(concentration1_1d, concentration0_1d).sample((1,)).size(), (1, 1)) + self.assertEqual(Kumaraswamy(1.0, 1.0).sample().size(), ()) + self.assertEqual(Kumaraswamy(1.0, 1.0).sample((1,)).size(), (1,)) + + # Kumaraswamy distribution is not implemented in SciPy + # Hence these tests are explicit + def test_kumaraswamy_mean_variance(self): + c1_1 = torch.randn(2, 3).abs().requires_grad_() + c0_1 = torch.randn(2, 3).abs().requires_grad_() + c1_2 = torch.randn(4).abs().requires_grad_() + c0_2 = torch.randn(4).abs().requires_grad_() + cases = [(c1_1, c0_1), (c1_2, c0_2)] + for i, (a, b) in enumerate(cases): + m = Kumaraswamy(a, b) + samples = m.sample((60000, )) + expected = samples.mean(0) + actual = m.mean + error = (expected - actual).abs() + max_error = max(error[error == error]) + self.assertLess(max_error, 0.01, + "Kumaraswamy example {}/{}, incorrect .mean".format(i + 1, len(cases))) + expected = samples.var(0) + actual = m.variance + error = (expected - actual).abs() + max_error = max(error[error == error]) + self.assertLess(max_error, 0.01, + "Kumaraswamy example {}/{}, incorrect .variance".format(i + 1, len(cases))) + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") def test_fishersnedecor(self): df1 = torch.randn(2, 3).abs().requires_grad_() @@ -2483,6 +2552,29 @@ def test_continuous_bernoulli_3d(self): (2, 5, 2, 3, 5)) self.assertEqual(ContinuousBernoulli(p).sample((2,)).size(), (2, 2, 3, 5)) + def test_lkj_cholesky_log_prob(self): + def tril_cholesky_to_tril_corr(x): + x = vec_to_tril_matrix(x, -1) + diag = (1 - (x * x).sum(-1)).sqrt().diag_embed() + x = x + diag + return tril_matrix_to_vec(x @ x.T, -1) + + for dim in range(2, 5): + log_probs = [] + lkj = LKJCholesky(dim, concentration=1.) + for i in range(2): + sample = lkj.sample() + sample_tril = tril_matrix_to_vec(sample, diag=-1) + log_prob = lkj.log_prob(sample) + log_abs_det_jacobian = torch.slogdet(jacobian(tril_cholesky_to_tril_corr, sample_tril)).logabsdet + log_probs.append(log_prob - log_abs_det_jacobian) + # for concentration=1., the density is uniform over the space of all + # correlation matrices. + if dim == 2: + # for dim=2, pdf = 0.5 (jacobian adjustment factor is 0.) + self.assertTrue(all([x == torch.tensor(0.5).log() for x in log_probs])) + self.assertEqual(log_probs[0], log_probs[1]) + def test_independent_shape(self): for Dist, params in EXAMPLES: for param in params: @@ -2553,7 +2645,7 @@ def test_cdf_log_prob(self): for i, param in enumerate(params): dist = Dist(**param) samples = dist.sample() - if samples.dtype.is_floating_point: + if not dist.support.is_discrete: samples.requires_grad_() try: cdfs = dist.cdf(samples) @@ -2622,6 +2714,18 @@ def test_valid_parameter_broadcasting(self): (1, 2)), (Gumbel(loc=torch.tensor([0.]), scale=torch.tensor([[1.]])), (1, 1)), + (Kumaraswamy(concentration1=torch.tensor([1., 1.]), concentration0=1.), + (2,)), + (Kumaraswamy(concentration1=1, concentration0=torch.tensor([1., 1.])), + (2, )), + (Kumaraswamy(concentration1=torch.tensor([1., 1.]), concentration0=torch.tensor([1.])), + (2,)), + (Kumaraswamy(concentration1=torch.tensor([1., 1.]), concentration0=torch.tensor([[1.], [1.]])), + (2, 2)), + (Kumaraswamy(concentration1=torch.tensor([1., 1.]), concentration0=torch.tensor([[1.]])), + (1, 2)), + (Kumaraswamy(concentration1=torch.tensor([1.]), concentration0=torch.tensor([[1.]])), + (1, 1)), (Laplace(loc=torch.tensor([0., 0.]), scale=1), (2,)), (Laplace(loc=0, scale=torch.tensor([1., 1.])), @@ -2701,6 +2805,14 @@ def test_invalid_parameter_broadcasting(self): 'concentration': torch.tensor([0, 0]), 'rate': torch.tensor([1, 1, 1]) }), + (Kumaraswamy, { + 'concentration1': torch.tensor([[1, 1]]), + 'concentration0': torch.tensor([1, 1, 1, 1]) + }), + (Kumaraswamy, { + 'concentration1': torch.tensor([[[1, 1, 1], [1, 1, 1]]]), + 'concentration0': torch.tensor([1, 1]) + }), (Laplace, { 'loc': torch.tensor([0, 0]), 'scale': torch.tensor([1, 1, 1]) @@ -2940,11 +3052,9 @@ def setUp(self): self.scalar_sample = 1 self.tensor_sample_1 = torch.ones(3, 2) self.tensor_sample_2 = torch.ones(3, 2, 3) - Distribution.set_default_validate_args(True) def tearDown(self): super(TestDistributionShapes, self).tearDown() - Distribution.set_default_validate_args(False) def test_entropy_shape(self): for Dist, params in EXAMPLES: @@ -3076,23 +3186,23 @@ def test_one_hot_categorical_shape(self): self.assertEqual(dist.sample().size(), torch.Size((3,))) self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3))) self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1) - simplex_sample = self.tensor_sample_2 / self.tensor_sample_2.sum(-1, keepdim=True) - self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 2,))) + sample = torch.tensor([0., 1., 0.]).expand(3, 2, 3) + self.assertEqual(dist.log_prob(sample).size(), torch.Size((3, 2,))) self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((3,))) - simplex_sample = torch.ones(3, 3) / 3 - self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,))) + sample = torch.eye(3) + self.assertEqual(dist.log_prob(sample).size(), torch.Size((3,))) # batched dist = OneHotCategorical(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]])) self.assertEqual(dist._batch_shape, torch.Size((3,))) self.assertEqual(dist._event_shape, torch.Size((2,))) self.assertEqual(dist.sample().size(), torch.Size((3, 2))) self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2))) - simplex_sample = self.tensor_sample_1 / self.tensor_sample_1.sum(-1, keepdim=True) - self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,))) + sample = torch.tensor([0., 1.]) + self.assertEqual(dist.log_prob(sample).size(), torch.Size((3,))) self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2) self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((2, 3))) - simplex_sample = torch.ones(3, 1, 2) / 2 - self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 3))) + sample = torch.tensor([0., 1.]).expand(3, 1, 2) + self.assertEqual(dist.log_prob(sample).size(), torch.Size((3, 3))) def test_cauchy_shape_scalar_params(self): cauchy = Cauchy(0, 1) @@ -3121,8 +3231,7 @@ def test_halfcauchy_shape_scalar_params(self): self.assertEqual(halfcauchy.sample().size(), torch.Size()) self.assertEqual(halfcauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2))) - self.assertEqual(halfcauchy.log_prob(self.scalar_sample).size(), - torch.Size()) + self.assertRaises(ValueError, halfcauchy.log_prob, self.scalar_sample) self.assertEqual(halfcauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))) self.assertEqual(halfcauchy.log_prob(self.tensor_sample_2).size(), @@ -3242,6 +3351,15 @@ def test_gumbel_shape_scalar_params(self): self.assertEqual(gumbel.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))) self.assertEqual(gumbel.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))) + def test_kumaraswamy_shape_scalar_params(self): + kumaraswamy = Kumaraswamy(1, 1) + self.assertEqual(kumaraswamy._batch_shape, torch.Size()) + self.assertEqual(kumaraswamy._event_shape, torch.Size()) + self.assertEqual(kumaraswamy.sample().size(), torch.Size()) + self.assertEqual(kumaraswamy.sample((3, 2)).size(), torch.Size((3, 2))) + self.assertEqual(kumaraswamy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))) + self.assertEqual(kumaraswamy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))) + def test_vonmises_shape_tensor_params(self): von_mises = VonMises(torch.tensor([0., 0.]), torch.tensor([1., 1.])) self.assertEqual(von_mises._batch_shape, torch.Size((2,))) @@ -3412,12 +3530,15 @@ def __init__(self, probs): [0.2, 0.7, 0.1], [0.33, 0.33, 0.34], [0.2, 0.2, 0.6]]) - pareto = pairwise(Pareto, [2.5, 4.0, 2.5, 4.0], [2.25, 3.75, 2.25, 3.75]) + pareto = (Pareto(torch.tensor([2.5, 4.0, 2.5, 4.0]).expand(4, 4), + torch.tensor([2.25, 3.75, 2.25, 3.75]).expand(4, 4)), + Pareto(torch.tensor([2.25, 3.75, 2.25, 3.8]).expand(4, 4), + torch.tensor([2.25, 3.75, 2.25, 3.75]).expand(4, 4))) poisson = pairwise(Poisson, [0.3, 1.0, 5.0, 10.0]) - uniform_within_unit = pairwise(Uniform, [0.15, 0.95, 0.2, 0.8], [0.1, 0.9, 0.25, 0.75]) + uniform_within_unit = pairwise(Uniform, [0.1, 0.9, 0.2, 0.75], [0.15, 0.95, 0.25, 0.8]) uniform_positive = pairwise(Uniform, [1, 1.5, 2, 4], [1.2, 2.0, 3, 7]) uniform_real = pairwise(Uniform, [-2., -1, 0, 2], [-1., 1, 1, 4]) - uniform_pareto = pairwise(Uniform, [6.5, 8.5, 6.5, 8.5], [7.5, 7.5, 9.5, 9.5]) + uniform_pareto = pairwise(Uniform, [6.5, 7.5, 6.5, 8.5], [7.5, 8.5, 9.5, 9.5]) continuous_bernoulli = pairwise(ContinuousBernoulli, [0.1, 0.2, 0.5, 0.9]) # These tests should pass with precision = 0.01, but that makes tests very expensive. @@ -3755,13 +3876,21 @@ def test_entropy_exponential_family(self): class TestConstraints(TestCase): def test_params_constraints(self): + normalize_probs_dists = ( + Categorical, + Multinomial, + OneHotCategorical, + OneHotCategoricalStraightThrough, + RelaxedOneHotCategorical + ) + for Dist, params in EXAMPLES: for i, param in enumerate(params): dist = Dist(**param) for name, value in param.items(): if isinstance(value, numbers.Number): value = torch.tensor([value]) - if Dist in (Categorical, OneHotCategorical, Multinomial) and name == 'probs': + if Dist in normalize_probs_dists and name == 'probs': # These distributions accept positive probs, but elsewhere we # use a stricter constraint to the simplex. value = value / value.sum(-1, True) @@ -4005,6 +4134,7 @@ def test_continuous_bernoulli_with_logits_overflow(self): expected_gradient=tensor_type([0.])) +# TODO: make this a pytest parameterized test class TestLazyLogitsInitialization(TestCase): def setUp(self): super(TestLazyLogitsInitialization, self).setUp() @@ -4015,13 +4145,13 @@ def setUp(self): def test_lazy_logits_initialization(self): for Dist, params in self.examples: - param = params[0] + param = params[0].copy() if 'probs' in param: probs = param.pop('probs') param['logits'] = probs_to_logits(probs) dist = Dist(**param) - shape = (1,) if not dist.event_shape else dist.event_shape - dist.log_prob(torch.ones(shape)) + # Create new instance to generate a valid sample + dist.log_prob(Dist(**param).sample()) message = 'Failed for {} example 0/{}'.format(Dist.__name__, len(params)) self.assertFalse('probs' in vars(dist), msg=message) try: @@ -4034,7 +4164,7 @@ def test_lazy_logits_initialization(self): def test_lazy_probs_initialization(self): for Dist, params in self.examples: - param = params[0] + param = params[0].copy() if 'probs' in param: dist = Dist(**param) dist.sample() @@ -4206,319 +4336,6 @@ def test_icdf(self): self.assertEqual(icdf, scipy_dist.ppf(samples), msg=pytorch_dist) -class TestTransforms(TestCase): - def setUp(self): - super(TestTransforms, self).setUp() - self.transforms = [] - transforms_by_cache_size = {} - for cache_size in [0, 1]: - transforms = [ - AbsTransform(cache_size=cache_size), - ExpTransform(cache_size=cache_size), - PowerTransform(exponent=2, - cache_size=cache_size), - PowerTransform(exponent=torch.tensor(5.).normal_(), - cache_size=cache_size), - SigmoidTransform(cache_size=cache_size), - TanhTransform(cache_size=cache_size), - AffineTransform(0, 1, cache_size=cache_size), - AffineTransform(1, -2, cache_size=cache_size), - AffineTransform(torch.randn(5), - torch.randn(5), - cache_size=cache_size), - AffineTransform(torch.randn(4, 5), - torch.randn(4, 5), - cache_size=cache_size), - SoftmaxTransform(cache_size=cache_size), - StickBreakingTransform(cache_size=cache_size), - LowerCholeskyTransform(cache_size=cache_size), - ComposeTransform([ - AffineTransform(torch.randn(4, 5), - torch.randn(4, 5), - cache_size=cache_size), - ]), - ComposeTransform([ - AffineTransform(torch.randn(4, 5), - torch.randn(4, 5), - cache_size=cache_size), - ExpTransform(cache_size=cache_size), - ]), - ComposeTransform([ - AffineTransform(0, 1, cache_size=cache_size), - AffineTransform(torch.randn(4, 5), - torch.randn(4, 5), - cache_size=cache_size), - AffineTransform(1, -2, cache_size=cache_size), - AffineTransform(torch.randn(4, 5), - torch.randn(4, 5), - cache_size=cache_size), - ]), - ] - for t in transforms[:]: - transforms.append(t.inv) - transforms.append(identity_transform) - self.transforms += transforms - if cache_size == 0: - self.unique_transforms = transforms[:] - - def _generate_data(self, transform): - domain = transform.domain - codomain = transform.codomain - x = torch.empty(4, 5) - if domain is constraints.lower_cholesky or codomain is constraints.lower_cholesky: - x = torch.empty(6, 6) - x = x.normal_() - return x - elif domain is constraints.real: - return x.normal_() - elif domain is constraints.positive: - return x.normal_().exp() - elif domain is constraints.unit_interval: - return x.uniform_() - elif isinstance(domain, constraints.interval): - x = x.uniform_() - x = x.mul_(domain.upper_bound - domain.lower_bound).add_(domain.lower_bound) - return x - elif domain is constraints.simplex: - x = x.normal_().exp() - x /= x.sum(-1, True) - return x - raise ValueError('Unsupported domain: {}'.format(domain)) - - def test_inv_inv(self): - for t in self.transforms: - self.assertTrue(t.inv.inv is t) - - def test_equality(self): - transforms = self.unique_transforms - for x, y in product(transforms, transforms): - if x is y: - self.assertTrue(x == y) - self.assertFalse(x != y) - else: - self.assertFalse(x == y) - self.assertTrue(x != y) - - self.assertTrue(identity_transform == identity_transform.inv) - self.assertFalse(identity_transform != identity_transform.inv) - - def test_with_cache(self): - for transform in self.transforms: - if transform._cache_size == 0: - transform = transform.with_cache(1) - self.assertTrue(transform._cache_size == 1) - - x = self._generate_data(transform).requires_grad_() - try: - y = transform(x) - except NotImplementedError: - continue - y2 = transform(x) - self.assertTrue(y2 is y) - - def test_forward_inverse_cache(self): - for transform in self.transforms: - x = self._generate_data(transform).requires_grad_() - try: - y = transform(x) - except NotImplementedError: - continue - x2 = transform.inv(y) # should be implemented at least by caching - y2 = transform(x2) # should be implemented at least by caching - if transform.bijective: - # verify function inverse - self.assertEqual(x2, x, msg='\n'.join([ - '{} t.inv(t(-)) error'.format(transform), - 'x = {}'.format(x), - 'y = t(x) = {}'.format(y), - 'x2 = t.inv(y) = {}'.format(x2), - ])) - else: - # verify weaker function pseudo-inverse - self.assertEqual(y2, y, msg='\n'.join([ - '{} t(t.inv(t(-))) error'.format(transform), - 'x = {}'.format(x), - 'y = t(x) = {}'.format(y), - 'x2 = t.inv(y) = {}'.format(x2), - 'y2 = t(x2) = {}'.format(y2), - ])) - - def test_forward_inverse_no_cache(self): - for transform in self.transforms: - x = self._generate_data(transform).requires_grad_() - try: - y = transform(x) - x2 = transform.inv(y.clone()) # bypass cache - y2 = transform(x2) - except NotImplementedError: - continue - if transform.bijective: - # verify function inverse - self.assertEqual(x2, x, msg='\n'.join([ - '{} t.inv(t(-)) error'.format(transform), - 'x = {}'.format(x), - 'y = t(x) = {}'.format(y), - 'x2 = t.inv(y) = {}'.format(x2), - ])) - else: - # verify weaker function pseudo-inverse - self.assertEqual(y2, y, msg='\n'.join([ - '{} t(t.inv(t(-))) error'.format(transform), - 'x = {}'.format(x), - 'y = t(x) = {}'.format(y), - 'x2 = t.inv(y) = {}'.format(x2), - 'y2 = t(x2) = {}'.format(y2), - ])) - - def test_univariate_forward_jacobian(self): - for transform in self.transforms: - if transform.event_dim > 0: - continue - x = self._generate_data(transform).requires_grad_() - try: - y = transform(x) - actual = transform.log_abs_det_jacobian(x, y) - except NotImplementedError: - continue - expected = torch.abs(grad([y.sum()], [x])[0]).log() - self.assertEqual(actual, expected, msg='\n'.join([ - 'Bad {}.log_abs_det_jacobian() disagrees with ()'.format(transform), - 'Expected: {}'.format(expected), - 'Actual: {}'.format(actual), - ])) - - def test_univariate_inverse_jacobian(self): - for transform in self.transforms: - if transform.event_dim > 0: - continue - y = self._generate_data(transform.inv).requires_grad_() - try: - x = transform.inv(y) - actual = transform.log_abs_det_jacobian(x, y) - except NotImplementedError: - continue - expected = -torch.abs(grad([x.sum()], [y])[0]).log() - self.assertEqual(actual, expected, msg='\n'.join([ - '{}.log_abs_det_jacobian() disagrees with .inv()'.format(transform), - 'Expected: {}'.format(expected), - 'Actual: {}'.format(actual), - ])) - - def test_jacobian_shape(self): - for transform in self.transforms: - x = self._generate_data(transform) - try: - y = transform(x) - actual = transform.log_abs_det_jacobian(x, y) - except NotImplementedError: - continue - self.assertEqual(actual.shape, x.shape[:x.dim() - transform.event_dim]) - - def test_transform_shapes(self): - transform0 = ExpTransform() - transform1 = SoftmaxTransform() - transform2 = LowerCholeskyTransform() - - self.assertEqual(transform0.event_dim, 0) - self.assertEqual(transform1.event_dim, 1) - self.assertEqual(transform2.event_dim, 2) - self.assertEqual(ComposeTransform([transform0, transform1]).event_dim, 1) - self.assertEqual(ComposeTransform([transform0, transform2]).event_dim, 2) - self.assertEqual(ComposeTransform([transform1, transform2]).event_dim, 2) - - def test_transformed_distribution_shapes(self): - transform0 = ExpTransform() - transform1 = SoftmaxTransform() - transform2 = LowerCholeskyTransform() - base_dist0 = Normal(torch.zeros(4, 4), torch.ones(4, 4)) - base_dist1 = Dirichlet(torch.ones(4, 4)) - base_dist2 = Normal(torch.zeros(3, 4, 4), torch.ones(3, 4, 4)) - examples = [ - ((4, 4), (), base_dist0), - ((4,), (4,), base_dist1), - ((4, 4), (), TransformedDistribution(base_dist0, [transform0])), - ((4,), (4,), TransformedDistribution(base_dist0, [transform1])), - ((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])), - ((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])), - ((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])), - ((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])), - ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])), - ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])), - ((4,), (4,), TransformedDistribution(base_dist1, [transform0])), - ((4,), (4,), TransformedDistribution(base_dist1, [transform1])), - ((), (4, 4), TransformedDistribution(base_dist1, [transform2])), - ((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])), - ((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])), - ((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])), - ((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])), - ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])), - ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])), - ((3, 4, 4), (), base_dist2), - ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])), - ((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])), - ((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])), - ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])), - ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])), - ] - for batch_shape, event_shape, dist in examples: - self.assertEqual(dist.batch_shape, batch_shape) - self.assertEqual(dist.event_shape, event_shape) - x = dist.rsample() - try: - dist.log_prob(x) # this should not crash - except NotImplementedError: - continue - - def test_jit_fwd(self): - for transform in self.unique_transforms: - x = self._generate_data(transform).requires_grad_() - - def f(x): - return transform(x) - - try: - traced_f = torch.jit.trace(f, (x,)) - except NotImplementedError: - continue - - # check on different inputs - x = self._generate_data(transform).requires_grad_() - self.assertEqual(f(x), traced_f(x)) - - def test_jit_inv(self): - for transform in self.unique_transforms: - y = self._generate_data(transform.inv).requires_grad_() - - def f(y): - return transform.inv(y) - - try: - traced_f = torch.jit.trace(f, (y,)) - except NotImplementedError: - continue - - # check on different inputs - y = self._generate_data(transform.inv).requires_grad_() - self.assertEqual(f(y), traced_f(y)) - - def test_jit_jacobian(self): - for transform in self.unique_transforms: - x = self._generate_data(transform).requires_grad_() - - def f(x): - y = transform(x) - return transform.log_abs_det_jacobian(x, y) - - try: - traced_f = torch.jit.trace(f, (x,)) - except NotImplementedError: - continue - - # check on different inputs - x = self._generate_data(transform).requires_grad_() - self.assertEqual(f(x), traced_f(x)) - - class TestFunctors(TestCase): def test_cat_transform(self): x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100) @@ -4586,6 +4403,22 @@ def test_cat_transform_non_uniform(self): t2.log_abs_det_jacobian(x2, y2)], dim=dim) self.assertEqual(actual_jac, expected_jac) + def test_cat_event_dim(self): + t1 = AffineTransform(0, 2 * torch.ones(2), event_dim=1) + t2 = AffineTransform(0, 2 * torch.ones(2), event_dim=1) + dim = 1 + bs = 16 + x1 = torch.randn(bs, 2) + x2 = torch.randn(bs, 2) + x = torch.cat([x1, x2], dim=1) + t = CatTransform([t1, t2], dim=dim, lengths=[2, 2]) + y1 = t1(x1) + y2 = t2(x2) + y = t(x) + actual_jac = t.log_abs_det_jacobian(x, y) + expected_jac = sum([t1.log_abs_det_jacobian(x1, y1), + t2.log_abs_det_jacobian(x2, y2)]) + def test_stack_transform(self): x1 = -1 * torch.arange(1, 101, dtype=torch.float) x2 = (torch.arange(1, 101, dtype=torch.float) - 1) / 100 @@ -4621,105 +4454,9 @@ def test_stack_transform(self): self.assertEqual(actual_jac, expected_jac) -class TestConstraintRegistry(TestCase): - def get_constraints(self, is_cuda=False): - tensor = torch.cuda.DoubleTensor if is_cuda else torch.DoubleTensor - return [ - constraints.real, - constraints.positive, - constraints.greater_than(tensor([-10., -2, 0, 2, 10])), - constraints.greater_than(0), - constraints.greater_than(2), - constraints.greater_than(-2), - constraints.greater_than_eq(0), - constraints.greater_than_eq(2), - constraints.greater_than_eq(-2), - constraints.less_than(tensor([-10., -2, 0, 2, 10])), - constraints.less_than(0), - constraints.less_than(2), - constraints.less_than(-2), - constraints.unit_interval, - constraints.interval(tensor([-4., -2, 0, 2, 4]), - tensor([-3., 3, 1, 5, 5])), - constraints.interval(-2, -1), - constraints.interval(1, 2), - constraints.half_open_interval(tensor([-4., -2, 0, 2, 4]), - tensor([-3., 3, 1, 5, 5])), - constraints.half_open_interval(-2, -1), - constraints.half_open_interval(1, 2), - constraints.simplex, - constraints.lower_cholesky, - ] - - def test_biject_to(self): - for constraint in self.get_constraints(): - try: - t = biject_to(constraint) - except NotImplementedError: - continue - self.assertTrue(t.bijective, "biject_to({}) is not bijective".format(constraint)) - x = torch.randn(5, 5) - y = t(x) - self.assertTrue(constraint.check(y).all(), '\n'.join([ - "Failed to biject_to({})".format(constraint), - "x = {}".format(x), - "biject_to(...)(x) = {}".format(y), - ])) - x2 = t.inv(y) - self.assertEqual(x, x2, msg="Error in biject_to({}) inverse".format(constraint)) - - j = t.log_abs_det_jacobian(x, y) - self.assertEqual(j.shape, x.shape[:x.dim() - t.event_dim]) - - @unittest.skipIf(not TEST_CUDA, "CUDA not found") - def test_biject_to_cuda(self): - for constraint in self.get_constraints(is_cuda=True): - try: - t = biject_to(constraint) - except NotImplementedError: - continue - self.assertTrue(t.bijective, "biject_to({}) is not bijective".format(constraint)) - # x = torch.randn(5, 5, device="cuda") - x = torch.randn(5, 5).cuda() - y = t(x) - self.assertTrue(constraint.check(y).all(), '\n'.join([ - "Failed to biject_to({})".format(constraint), - "x = {}".format(x), - "biject_to(...)(x) = {}".format(y), - ])) - x2 = t.inv(y) - self.assertEqual(x, x2, msg="Error in biject_to({}) inverse".format(constraint)) - - j = t.log_abs_det_jacobian(x, y) - self.assertEqual(j.shape, x.shape[:x.dim() - t.event_dim]) - - def test_transform_to(self): - for constraint in self.get_constraints(): - t = transform_to(constraint) - x = torch.randn(5, 5) - y = t(x) - self.assertTrue(constraint.check(y).all(), "Failed to transform_to({})".format(constraint)) - x2 = t.inv(y) - y2 = t(x2) - self.assertEqual(y, y2, msg="Error in transform_to({}) pseudoinverse".format(constraint)) - - @unittest.skipIf(not TEST_CUDA, "CUDA not found") - def test_transform_to_cuda(self): - for constraint in self.get_constraints(is_cuda=True): - t = transform_to(constraint) - # x = torch.randn(5, 5, device="cuda") - x = torch.randn(5, 5).cuda() - y = t(x) - self.assertTrue(constraint.check(y).all(), "Failed to transform_to({})".format(constraint)) - x2 = t.inv(y) - y2 = t(x2) - self.assertEqual(y, y2, msg="Error in transform_to({}) pseudoinverse".format(constraint)) - - class TestValidation(TestCase): def setUp(self): super(TestCase, self).setUp() - Distribution.set_default_validate_args(True) def test_valid(self): for Dist, params in EXAMPLES: @@ -4737,9 +4474,29 @@ def test_invalid(self): fail_string = 'ValueError not raised for {} example {}/{}' raise AssertionError(fail_string.format(Dist.__name__, i + 1, len(params))) from e + def test_warning_unimplemented_constraints(self): + class Delta(Distribution): + def __init__(self, validate_args=True): + super().__init__(validate_args=validate_args) + + def sample(self, sample_shape=torch.Size()): + return torch.tensor(0.).expand(sample_shape) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + value[value != 0.] = -float('inf') + value[value == 0.] = 0. + return value + + with self.assertWarns(UserWarning): + d = Delta() + sample = d.sample((2,)) + with self.assertWarns(UserWarning): + d.log_prob(sample) + def tearDown(self): super(TestValidation, self).tearDown() - Distribution.set_default_validate_args(False) class TestJit(TestCase): diff --git a/test/distributions/test_transforms.py b/test/distributions/test_transforms.py new file mode 100644 index 0000000000000..b5e9144f0bd89 --- /dev/null +++ b/test/distributions/test_transforms.py @@ -0,0 +1,365 @@ +from numbers import Number + +import pytest + +import torch +from torch.autograd.functional import jacobian +from torch.distributions import Dirichlet, Normal, TransformedDistribution, constraints +from torch.distributions.transforms import (AbsTransform, AffineTransform, ComposeTransform, + CorrCholeskyTransform, ExpTransform, + LowerCholeskyTransform, PowerTransform, + SigmoidTransform, TanhTransform, SoftmaxTransform, + StickBreakingTransform, identity_transform, Transform, + _InverseTransform) +from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix + + +def get_transforms(cache_size): + transforms = [ + AbsTransform(cache_size=cache_size), + ExpTransform(cache_size=cache_size), + PowerTransform(exponent=2, + cache_size=cache_size), + PowerTransform(exponent=torch.tensor(5.).normal_(), + cache_size=cache_size), + SigmoidTransform(cache_size=cache_size), + TanhTransform(cache_size=cache_size), + AffineTransform(0, 1, cache_size=cache_size), + AffineTransform(1, -2, cache_size=cache_size), + AffineTransform(torch.randn(5), + torch.randn(5), + cache_size=cache_size), + AffineTransform(torch.randn(4, 5), + torch.randn(4, 5), + cache_size=cache_size), + SoftmaxTransform(cache_size=cache_size), + StickBreakingTransform(cache_size=cache_size), + LowerCholeskyTransform(cache_size=cache_size), + CorrCholeskyTransform(cache_size=cache_size), + ComposeTransform([ + AffineTransform(torch.randn(4, 5), + torch.randn(4, 5), + cache_size=cache_size), + ]), + ComposeTransform([ + AffineTransform(torch.randn(4, 5), + torch.randn(4, 5), + cache_size=cache_size), + ExpTransform(cache_size=cache_size), + ]), + ComposeTransform([ + AffineTransform(0, 1, cache_size=cache_size), + AffineTransform(torch.randn(4, 5), + torch.randn(4, 5), + cache_size=cache_size), + AffineTransform(1, -2, cache_size=cache_size), + AffineTransform(torch.randn(4, 5), + torch.randn(4, 5), + cache_size=cache_size), + ]), + ] + transforms += [t.inv for t in transforms] + return transforms + + +def reshape_transform(transform, shape): + # Needed to squash batch dims for testing jacobian + if isinstance(transform, AffineTransform): + if isinstance(transform.loc, Number): + return transform + try: + return AffineTransform(transform.loc.expand(shape), transform.scale.expand(shape), cache_size=transform._cache_size) + except RuntimeError: + return AffineTransform(transform.loc.reshape(shape), transform.scale.reshape(shape), cache_size=transform._cache_size) + if isinstance(transform, ComposeTransform): + reshaped_parts = [] + for p in transform.parts: + reshaped_parts.append(reshape_transform(p, shape)) + return ComposeTransform(reshaped_parts, cache_size=transform._cache_size) + if isinstance(transform.inv, AffineTransform): + return reshape_transform(transform.inv, shape).inv + if isinstance(transform.inv, ComposeTransform): + return reshape_transform(transform.inv, shape).inv + return transform + + +# Generate pytest ids +def transform_id(x): + assert isinstance(x, Transform) + name = f'Inv({type(x._inv).__name__})' if isinstance(x, _InverseTransform) else f'{type(x).__name__}' + return f'{name}(cache_size={x._cache_size})' + + +def generate_data(transform): + torch.manual_seed(1) + domain = transform.domain + codomain = transform.codomain + x = torch.empty(4, 5) + if domain is constraints.lower_cholesky or codomain is constraints.lower_cholesky: + x = torch.empty(6, 6) + x = x.normal_() + return x + elif domain is constraints.real: + return x.normal_() + elif domain is constraints.real_vector: + # For corr_cholesky the last dim in the vector + # must be of size (dim * dim) // 2 + x = torch.empty(3, 6) + x = x.normal_() + return x + elif domain is constraints.positive: + return x.normal_().exp() + elif domain is constraints.unit_interval: + return x.uniform_() + elif isinstance(domain, constraints.interval): + x = x.uniform_() + x = x.mul_(domain.upper_bound - domain.lower_bound).add_(domain.lower_bound) + return x + elif domain is constraints.simplex: + x = x.normal_().exp() + x /= x.sum(-1, True) + return x + elif domain is constraints.corr_cholesky: + x = torch.empty(4, 5, 5) + x = x.normal_().tril() + x /= x.norm(dim=-1, keepdim=True) + x.diagonal(dim1=-1).copy_(x.diagonal(dim1=-1).abs()) + return x + raise ValueError('Unsupported domain: {}'.format(domain)) + + +TRANSFORMS_CACHE_ACTIVE = get_transforms(cache_size=1) +TRANSFORMS_CACHE_INACTIVE = get_transforms(cache_size=0) +ALL_TRANSFORMS = TRANSFORMS_CACHE_ACTIVE + TRANSFORMS_CACHE_INACTIVE + [identity_transform] + + +@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id) +def test_inv_inv(transform, ids=transform_id): + assert transform.inv.inv is transform + + +@pytest.mark.parametrize('x', TRANSFORMS_CACHE_INACTIVE, ids=transform_id) +@pytest.mark.parametrize('y', TRANSFORMS_CACHE_INACTIVE, ids=transform_id) +def test_equality(x, y): + if x is y: + assert x == y + else: + assert x != y + assert identity_transform == identity_transform.inv + + +@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id) +def test_with_cache(transform): + if transform._cache_size == 0: + transform = transform.with_cache(1) + assert transform._cache_size == 1 + x = generate_data(transform).requires_grad_() + try: + y = transform(x) + except NotImplementedError: + pytest.skip('Not implemented.') + y2 = transform(x) + assert y2 is y + + +@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id) +@pytest.mark.parametrize('test_cached', [True, False]) +def test_forward_inverse(transform, test_cached): + x = generate_data(transform).requires_grad_() + try: + y = transform(x) + except NotImplementedError: + pytest.skip('Not implemented.') + if test_cached: + x2 = transform.inv(y) # should be implemented at least by caching + else: + try: + x2 = transform.inv(y.clone()) # bypass cache + except NotImplementedError: + pytest.skip('Not implemented.') + y2 = transform(x2) + if transform.bijective: + # verify function inverse + assert torch.allclose(x2, x, atol=1e-4, equal_nan=True), '\n'.join([ + '{} t.inv(t(-)) error'.format(transform), + 'x = {}'.format(x), + 'y = t(x) = {}'.format(y), + 'x2 = t.inv(y) = {}'.format(x2), + ]) + else: + # verify weaker function pseudo-inverse + assert torch.allclose(y2, y, atol=1e-4, equal_nan=True), '\n'.join([ + '{} t(t.inv(t(-))) error'.format(transform), + 'x = {}'.format(x), + 'y = t(x) = {}'.format(y), + 'x2 = t.inv(y) = {}'.format(x2), + 'y2 = t(x2) = {}'.format(y2), + ]) + + +def test_compose_transform_shapes(): + transform0 = ExpTransform() + transform1 = SoftmaxTransform() + transform2 = LowerCholeskyTransform() + + assert transform0.event_dim == 0 + assert transform1.event_dim == 1 + assert transform2.event_dim == 2 + assert ComposeTransform([transform0, transform1]).event_dim == 1 + assert ComposeTransform([transform0, transform2]).event_dim == 2 + assert ComposeTransform([transform1, transform2]).event_dim == 2 + + +transform0 = ExpTransform() +transform1 = SoftmaxTransform() +transform2 = LowerCholeskyTransform() +base_dist0 = Normal(torch.zeros(4, 4), torch.ones(4, 4)) +base_dist1 = Dirichlet(torch.ones(4, 4)) +base_dist2 = Normal(torch.zeros(3, 4, 4), torch.ones(3, 4, 4)) + + +@pytest.mark.parametrize('batch_shape, event_shape, dist', [ + ((4, 4), (), base_dist0), + ((4,), (4,), base_dist1), + ((4, 4), (), TransformedDistribution(base_dist0, [transform0])), + ((4,), (4,), TransformedDistribution(base_dist0, [transform1])), + ((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])), + ((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])), + ((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])), + ((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])), + ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])), + ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])), + ((4,), (4,), TransformedDistribution(base_dist1, [transform0])), + ((4,), (4,), TransformedDistribution(base_dist1, [transform1])), + ((), (4, 4), TransformedDistribution(base_dist1, [transform2])), + ((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])), + ((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])), + ((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])), + ((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])), + ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])), + ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])), + ((3, 4, 4), (), base_dist2), + ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])), + ((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])), + ((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])), + ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])), + ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])), +]) +def test_transformed_distribution_shapes(batch_shape, event_shape, dist): + assert dist.batch_shape == batch_shape + assert dist.event_shape == event_shape + x = dist.rsample() + try: + dist.log_prob(x) # this should not crash + except NotImplementedError: + pytest.skip('Not implemented.') + + +@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id) +def test_jit_fwd(transform): + x = generate_data(transform).requires_grad_() + + def f(x): + return transform(x) + + try: + traced_f = torch.jit.trace(f, (x,)) + except NotImplementedError: + pytest.skip('Not implemented.') + + # check on different inputs + x = generate_data(transform).requires_grad_() + assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True) + + +@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id) +def test_jit_inv(transform): + y = generate_data(transform.inv).requires_grad_() + + def f(y): + return transform.inv(y) + + try: + traced_f = torch.jit.trace(f, (y,)) + except NotImplementedError: + pytest.skip('Not implemented.') + + # check on different inputs + y = generate_data(transform.inv).requires_grad_() + assert torch.allclose(f(y), traced_f(y), atol=1e-5, equal_nan=True) + + +@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id) +def test_jit_jacobian(transform): + x = generate_data(transform).requires_grad_() + + def f(x): + y = transform(x) + return transform.log_abs_det_jacobian(x, y) + + try: + traced_f = torch.jit.trace(f, (x,)) + except NotImplementedError: + pytest.skip('Not implemented.') + + # check on different inputs + x = generate_data(transform).requires_grad_() + assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True) + + +@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id) +def test_jacobian(transform): + x = generate_data(transform) + try: + y = transform(x) + actual = transform.log_abs_det_jacobian(x, y) + except NotImplementedError: + pytest.skip('Not implemented.') + # Test shape + target_shape = x.shape[:x.dim() - transform.input_event_dim] + assert actual.shape == target_shape + + # Expand if required + transform = reshape_transform(transform, x.shape) + ndims = len(x.shape) + event_dim = ndims - transform.input_event_dim + x_ = x.view((-1,) + x.shape[event_dim:]) + n = x_.shape[0] + # Reshape to squash batch dims to a single batch dim + transform = reshape_transform(transform, x_.shape) + + # 1. Transforms with 0 off-diagonal elements + if transform.input_event_dim == 0: + jac = jacobian(transform, x_) + # assert off-diagonal elements are zero + assert torch.allclose(jac, jac.diagonal().diag_embed()) + expected = jac.diagonal().abs().log().reshape(x.shape) + # 2. Transforms with non-0 off-diagonal elements + else: + if isinstance(transform, CorrCholeskyTransform): + jac = jacobian(lambda x: tril_matrix_to_vec(transform(x), diag=-1), x_) + elif isinstance(transform.inv, CorrCholeskyTransform): + jac = jacobian(lambda x: transform(vec_to_tril_matrix(x, diag=-1)), + tril_matrix_to_vec(x_, diag=-1)) + elif isinstance(transform, StickBreakingTransform): + jac = jacobian(lambda x: transform(x)[..., :-1], x_) + else: + jac = jacobian(transform, x_) + + # Note that jacobian will have shape (batch_dims, y_event_dims, batch_dims, x_event_dims) + # However, batches are independent so this can be converted into a (batch_dims, event_dims, event_dims) + # after reshaping the event dims (see above) to give a batched square matrix whose determinant + # can be computed. + gather_idx_shape = list(jac.shape) + gather_idx_shape[-2] = 1 + gather_idxs = torch.arange(n).reshape((n,) + (1,) * (len(jac.shape) - 1)).expand(gather_idx_shape) + jac = jac.gather(-2, gather_idxs).squeeze(-2) + out_ndims = jac.shape[-2] + jac = jac[..., :out_ndims] # Remove extra zero-valued dims (for inverse stick-breaking). + expected = torch.slogdet(jac).logabsdet + + assert torch.allclose(actual, expected, atol=1e-5) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/test/distributions/test_utils.py b/test/distributions/test_utils.py new file mode 100644 index 0000000000000..5751246eb10a9 --- /dev/null +++ b/test/distributions/test_utils.py @@ -0,0 +1,24 @@ +import pytest + +import torch +from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix + + +@pytest.mark.parametrize('shape', [ + (2, 2), + (3, 3), + (2, 4, 4), + (2, 2, 4, 4), +]) +def test_tril_matrix_to_vec(shape): + mat = torch.randn(shape) + n = mat.shape[-1] + for diag in range(-n, n): + actual = mat.tril(diag) + vec = tril_matrix_to_vec(actual, diag) + tril_mat = vec_to_tril_matrix(vec, diag) + assert torch.allclose(tril_mat, actual) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/test/expect/TestTensorBoard.test_pytorch_graph.expect b/test/expect/TestTensorBoard.test_pytorch_graph.expect index 52d232c98778e..83534d57e446f 100644 --- a/test/expect/TestTensorBoard.test_pytorch_graph.expect +++ b/test/expect/TestTensorBoard.test_pytorch_graph.expect @@ -26,7 +26,7 @@ node { node { name: "output/output.1" op: "IO Node" - input: "myLinear/Linear[l]/21" + input: "myLinear/Linear[l]/23" attr { key: "_output_shapes" value { @@ -50,7 +50,7 @@ node { } } node { - name: "myLinear/Linear[l]/17" + name: "myLinear/Linear[l]/19" op: "prim::Constant" attr { key: "attr" @@ -60,9 +60,9 @@ node { } } node { - name: "myLinear/Linear[l]/bias/18" + name: "myLinear/Linear[l]/bias/20" op: "prim::GetAttr" - input: "myLinear/Linear[l]/weight/14" + input: "myLinear/Linear[l]/weight/16" attr { key: "attr" value { @@ -71,9 +71,9 @@ node { } } node { - name: "myLinear/Linear[l]/weight/19" + name: "myLinear/Linear[l]/weight/21" op: "prim::GetAttr" - input: "myLinear/Linear[l]/weight/14" + input: "myLinear/Linear[l]/weight/16" attr { key: "attr" value { @@ -82,9 +82,9 @@ node { } } node { - name: "myLinear/Linear[l]/20" + name: "myLinear/Linear[l]/22" op: "aten::t" - input: "myLinear/Linear[l]/weight/19" + input: "myLinear/Linear[l]/weight/21" attr { key: "_output_shapes" value { @@ -108,13 +108,13 @@ node { } } node { - name: "myLinear/Linear[l]/21" + name: "myLinear/Linear[l]/23" op: "aten::addmm" - input: "myLinear/Linear[l]/bias/18" + input: "myLinear/Linear[l]/bias/20" input: "input/input" - input: "myLinear/Linear[l]/20" - input: "myLinear/Linear[l]/17" - input: "myLinear/Linear[l]/17" + input: "myLinear/Linear[l]/22" + input: "myLinear/Linear[l]/19" + input: "myLinear/Linear[l]/19" attr { key: "_output_shapes" value { diff --git a/test/fx/named_tup.py b/test/fx/named_tup.py new file mode 100644 index 0000000000000..d26fe9df3c922 --- /dev/null +++ b/test/fx/named_tup.py @@ -0,0 +1,7 @@ +from typing import NamedTuple + +import torch + +class MyNamedTup(NamedTuple): + i : torch.Tensor + f : torch.Tensor diff --git a/test/fx/quantization.py b/test/fx/quantization.py index 968c797c91636..ff6c98ac038b2 100644 --- a/test/fx/quantization.py +++ b/test/fx/quantization.py @@ -164,7 +164,7 @@ def matches(modules, node, pattern, max_uses=sys.maxsize): self_match = pattern arg_matches = None - if node.uses > max_uses: + if len(node.users) > max_uses: return False if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module): @@ -219,6 +219,7 @@ def observe(self, args): def load_arg(a): return map_arg(a, lambda node: env[node.name]) + output_node : Optional[Node] = None for node in self.graph.nodes: if node.op == 'placeholder': result = next(args_iter) @@ -232,6 +233,8 @@ def load_arg(a): result = getattr(self_obj, node.target)(*args, **kwargs) elif node.op == 'call_module': result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs)) + elif node.op == 'output': + return load_arg(node.args[0]) env[node.name] = result root_node, obj = self.matches.get(node.name, (None, None)) @@ -240,7 +243,7 @@ def load_arg(a): if node.name in self.quants: self.quants[node.name].observe(node, env) - return load_arg(self.graph.result) + raise RuntimeError('Graph had no output node!') def quantize(self): self.quantized_graph = Graph() @@ -281,7 +284,6 @@ def load_or_emit(n): else: quant_env[node.name] = r - self.quantized_graph.output(load_arg(self.graph.result, quantized=False)) return GraphModule(self.root, self.quantized_graph) def _find_matches(self, patterns): diff --git a/test/fx/test_fx_const_fold.py b/test/fx/test_fx_const_fold.py new file mode 100644 index 0000000000000..db06663285da4 --- /dev/null +++ b/test/fx/test_fx_const_fold.py @@ -0,0 +1,274 @@ +import unittest + +import torch +from torch.fx.experimental import const_fold + + +class TestConstFold(unittest.TestCase): + def _verify_const_fold_mod(self, mod_folded: const_fold.FoldedGraphModule): + self.assertTrue(mod_folded.const_subgraph_module is not None) + + # Check that the constants are attributes in the main subgraph. + num_folded_attrs = 0 + for node in mod_folded.graph.nodes: + if node.op == "get_attr" and (node.target in mod_folded.const_output_names): + num_folded_attrs += 1 + self.assertEqual(num_folded_attrs, len(mod_folded.const_output_names)) + + def test_const_fold_basic_one_attr_no_name_collision(self): + r""" + Perform constant folding conversion, from original mod to split constant folding + module with two split subgraphs, where there's a single attr to fold and + a single output attr result to replace. + + attr1 attr1 + | | | | + x add add + \ / | + sub y output (becomes attr add_1) + \ / ==> -------+------- (const/base subgraph split) + mul attr2 x / (input from previous subgraph + \ / \ / is attr) + add sub y + | \ / + output mul attr2 + \ / + add + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]])) + self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]])) + + def forward(self, x, y): + a = self.attr_1 + self.attr_1 + x = x - a + return x * y + self.attr_2 + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + self._verify_const_fold_mod(mod_folded) + + # Now run both folded and non-folded to check results equal. + in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9]) + base_result = mod(in_x, in_y) + fold_result = mod_folded(in_x, in_y) + self.assertTrue(torch.equal(fold_result, base_result)) + + def test_const_fold_basic_one_attr_name_collision(self): + r""" + Perform constant folding conversion, from original mod to split constant folding + module with two split subgraphs, where there's a single attr to fold and + a single output attr result to replace. Name the attrs such that they will + collide by name with folded attrs. + + add_1 add_1 + | | | | + x add add + \ / | + sub y output (becomes attr add_1) + \ / ==> -------+------- (const/base subgraph split) + mul add_2 x / (input from previous subgraph + \ / \ / is attr) + add sub y + | \ / + output mul add_2 + \ / + add + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + # Note: Named as such to result in name collision. + self.add_1__CF = torch.nn.Parameter(torch.tensor([[1.0]])) + self.add_2__CF = torch.nn.Parameter(torch.tensor([[17.1]])) + + def forward(self, x, y): + a = self.add_1__CF + self.add_1__CF + x = x - a + return x * y + self.add_2__CF + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + self._verify_const_fold_mod(mod_folded) + + # Now run both folded and non-folded to check results equal. + in_x, in_y = torch.tensor([[5.0]]), torch.tensor([4.0]) + base_result = mod(in_x, in_y) + fold_result = mod_folded(in_x, in_y) + self.assertTrue(torch.equal(fold_result, base_result)) + + def test_const_fold_noop(self): + r""" + Check that a graph with no constant folding is handled correctly. + + x attr1 + \ / + sub + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]])) + + def forward(self, x): + return x - self.attr1 + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + + # Check that the folded graph module is None, since there was no folding to do. + self.assertTrue(mod_folded.const_subgraph_module is None) + + # Now run both folded and non-folded to check results equal. + in_x = torch.tensor([[-0.45]]) + base_result = mod(in_x) + fold_result = mod_folded(in_x) + self.assertTrue(torch.equal(fold_result, base_result)) + + def test_const_fold_basic_two_attr_three_input(self): + r""" + Perform constant folding conversion, from original mod to split constant + folding module with two split subgraphs, where there are two attrs to + fold into a single output, and there are three placeholder inputs. + + attr1 attr2 attr1 attr2 + \ / \ / + x add add + \ / | + sub y output (becomes attr add_1) + \ / ==> -------+------- (const/base subgraph split) + mul z x / (input from previous subgraph + \ / \ / is attr) + div sub y + | \ / + output mul z + \ / + div + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]])) + self.attr1 = torch.nn.Parameter(torch.tensor([[1.32]])) + + def forward(self, x, y, z): + a = self.attr1 + self.attr1 + sub = x - a + mul = sub * y + return mul / z + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + self._verify_const_fold_mod(mod_folded) + + # Now run both folded and non-folded to check results equal. + in_x, in_y, in_z = ( + torch.tensor([[-0.45]]), + torch.tensor([0.9]), + torch.tensor([1.1]), + ) + base_result = mod(in_x, in_y, in_z) + fold_result = mod_folded(in_x, in_y, in_z) + self.assertTrue(torch.equal(fold_result, base_result)) + + def test_const_fold_basic_two_attr(self): + r""" + Perform constant folding conversion, from original mod to split constant + folding module with two split subgraphs, where there are two attrs to + fold into a single output. + + attr1 attr2 attr1 attr2 + \ / \ / + x add add (becomes attr add_1) + \ / ==> -------+------- (const/base subgraph split) + sub x | (input from previous subgraph is attr) + | \ / + output sub + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr1 = torch.nn.Parameter(torch.randn(2, 3)) + self.attr2 = torch.nn.Parameter(torch.randn(2, 3)) + + def forward(self, x): + y = self.attr1 + self.attr2 + return x + y + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + self._verify_const_fold_mod(mod_folded) + + # Now run both folded and non-folded to check results equal. + in_x = torch.randn(2, 3) + fold_result = mod_folded(in_x) + base_result = mod(in_x) + self.assertTrue(torch.equal(fold_result, base_result)) + + def test_const_fold_multi_const_folded_attrs(self): + r""" + Perform constant folding conversion, from original mod to split constant + folding module with two split subgraphs, where there are two attrs to + fold into two new attrs. + + attr1 attr2 attr1 attr2 + / \ | / \ | + permute | sum permute | sum + \ / / \ / | + x add y / add | + \ / \ / | | + sub add output output (become attrs add_1 and mul_1) + \ / ==> --------+-------+------ (const/base subgraph split) + \ / x | y | (inputs from previous subgraph + add \ / \ / are attrs) + | sub add + linear \ / + | add + sigmoid | + | linear + output | + sigmoid + | + output + """ + + class ConstFoldTestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr1 = torch.nn.Parameter(torch.randn(4, 4)) + self.attr2 = torch.nn.Parameter(torch.randn(4, 4)) + self.lin = torch.nn.Linear(4, 4) + + def forward(self, x, y): + a = self.attr1 + self.attr1.permute(1, 0) + x = x - a + amax = torch.sum(self.attr2, dim=1) + y = y + amax + return torch.sigmoid(self.lin(x + y)) + + mod = ConstFoldTestModule() + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) + self._verify_const_fold_mod(mod_folded) + + # Now run both folded and non-folded to check results equal. + in_x, in_y = torch.randn(4, 4), torch.randn(4) + fold_result = mod_folded(in_x, in_y) + base_result = mod(in_x, in_y) + self.assertTrue(torch.equal(fold_result, base_result)) diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py new file mode 100644 index 0000000000000..3eeaa883da59d --- /dev/null +++ b/test/fx/test_subgraph_rewriter.py @@ -0,0 +1,167 @@ +import os +import sys + +import torch +from torch.fx import symbolic_trace, subgraph_rewriter + +# Make the helper files in test/ importable +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) +from torch.testing._internal.jit_utils import JitTestCase + +if __name__ == '__main__': + raise RuntimeError("This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_fx.py TESTNAME\n\n" + "instead.") + +class TestSubgraphRewriter(JitTestCase): + + def test_subgraph_rewriter_preserves_logic(self): + class M(torch.nn.Module): + def forward(self, x): + val = torch.neg(x) + torch.relu(x) + return torch.add(val, val) + + def pattern(x): + return torch.neg(x) + torch.relu(x) + + def comparison(x): + val = torch.neg(x) + torch.relu(x) + return torch.add(val, val) + + traced_module = symbolic_trace(M()) + comparison_fn = symbolic_trace(comparison) + + x = torch.rand(1, 3) + + # Replace `pattern` with the same pattern (shouldn't change + # the underlying logic) + subgraph_rewriter.replace_pattern(traced_module, pattern, pattern) + + traced_module.graph.lint(traced_module) + + ref_output = comparison_fn(x) + test_output = traced_module.forward(x) + self.assertEqual(ref_output, test_output) + + def test_subgraph_rewriter_with_oneliner_pattern(self): + class M(torch.nn.Module): + def forward(self, x): + val = torch.neg(x) + return torch.add(val, val) + + def pattern(x): + return torch.neg(x) + + def replacement(x): + return torch.relu(x) + + def comparison(x): + val = torch.relu(x) + return torch.add(val, val) + + traced_module = symbolic_trace(M()) + comparison_fn = symbolic_trace(comparison) + + x = torch.rand(1, 3) + + subgraph_rewriter.replace_pattern(traced_module, pattern, replacement) + + traced_module.graph.lint(traced_module) + + ref_output = comparison_fn(x) + test_output = traced_module.forward(x) + self.assertEqual(ref_output, test_output) + + def test_subgraph_rewriter_single_pattern_match(self): + class M(torch.nn.Module): + def forward(self, x): + val = torch.neg(x) + torch.relu(x) + return torch.add(val, val) + + def pattern(x): + return torch.neg(x) + torch.relu(x) + + def replacement(x): + return torch.relu(x) + + def comparison(x): + val = torch.relu(x) + return torch.add(val, val) + + traced_module = symbolic_trace(M()) + comparison_fn = symbolic_trace(comparison) + + x = torch.rand(1, 3) + + subgraph_rewriter.replace_pattern(traced_module, pattern, replacement) + + traced_module.graph.lint(traced_module) + + ref_output = comparison_fn(x) + test_output = traced_module.forward(x) + self.assertEqual(ref_output, test_output) + + def test_subgraph_rewriter_multiple_pattern_match(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, w1, w2): + m1 = torch.cat([w1, w2]).sum() + m2 = torch.cat([w1, w2]).sum() + return x + torch.max(m1) + torch.max(m2) + + def pattern(w1, w2): + return torch.cat([w1, w2]).sum() + + def replacement(w1, w2): + return torch.stack([w1, w2]) + + def comparison(x, w1, w2): + m1 = torch.stack([w1, w2]) + m2 = torch.stack([w1, w2]) + return x + torch.max(m1) + torch.max(m2) + + traced_module = symbolic_trace(M()) + comparison_fn = symbolic_trace(comparison) + + x = torch.rand(1, 3) + w1 = torch.rand(1, 3) + w2 = torch.rand(1, 3) + + subgraph_rewriter.replace_pattern(traced_module, pattern, replacement) + + traced_module.graph.lint(traced_module) + + ref_outs = comparison_fn(x, w1, w2) + test_outs = traced_module.forward(x, w1, w2) + self.assertEqual(ref_outs, test_outs) + + def test_subgraph_rewriter_graph_argument_order(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.mm(x, y) + + def pattern(x, y): + return torch.mm(x, y) + + def comparison(x, y): + return torch.mm(x, y) + + traced_module = symbolic_trace(M()) + comparison_fn = symbolic_trace(comparison) + + x = torch.randn(3, 4) + y = torch.randn(4, 5) + + subgraph_rewriter.replace_pattern(traced_module, pattern, pattern) + + traced_module.graph.lint(traced_module) + + ref_outs = comparison_fn(x, y) + test_outs = traced_module.forward(x, y) + self.assertEqual(ref_outs, test_outs) diff --git a/test/jit/test_async.py b/test/jit/test_async.py index e2a55fa3f23c1..b4b6b8e294f79 100644 --- a/test/jit/test_async.py +++ b/test/jit/test_async.py @@ -5,13 +5,13 @@ import torch import torch.nn as nn -from typing import Any +from typing import Any, Tuple # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from torch.testing._internal.jit_utils import JitTestCase, _inline_everything -from typing import List, Tuple +from typing import List from torch import Tensor class TestAsync(JitTestCase): @@ -41,8 +41,7 @@ def foo(inp): def test_async_parsing(self): @torch.jit.script - def foo(x): - # type: (Tensor) -> List[Tensor] + def foo(x: Tensor) -> List[Tensor]: return [torch.neg(x), x.t()] @torch.jit.script @@ -257,8 +256,7 @@ def __init__(self): self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True) @torch.jit.script_method - def forward(self, x): - # type: (Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor] + def forward(self, x: Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]: future1 = torch.jit._fork(self.traced, x) future2 = torch.jit._fork(torch.neg, x) diff --git a/test/jit/test_autodiff_subgraph_slicing.py b/test/jit/test_autodiff_subgraph_slicing.py index 0e5a82c6b9b0d..ec8f3e2b43da3 100644 --- a/test/jit/test_autodiff_subgraph_slicing.py +++ b/test/jit/test_autodiff_subgraph_slicing.py @@ -29,7 +29,11 @@ def _perform_ad_subgraph_slicing(self, fn, *input_sizes): return ge.graph_for(*inputs) def assertGraphSize(self, graph, size): - nodes = list(filter(lambda n : n.kind() != "prim::BailOut" and n.kind() != "prim::BailoutTemplate", graph.nodes())) + nodes = list(filter(lambda n: (n.kind() != "prim::BailOut" and + n.kind() != "prim::BailoutTemplate" and + n.kind() != "prim::TypeCheck" and + n.kind() != "prim::RequiresGradCheck"), + graph.nodes())) self.assertEqual(len(list(nodes)), size) def test_chunk_constant_script_ad(self): diff --git a/test/jit/test_backends.py b/test/jit/test_backends.py index e2eaa0b2a1e52..9f61ec77a1f60 100644 --- a/test/jit/test_backends.py +++ b/test/jit/test_backends.py @@ -5,9 +5,17 @@ import torch import torch._C +from torch.testing import FileCheck from pathlib import Path -from torch.testing._internal.common_utils import TEST_WITH_ROCM, skipIfRocm, IS_SANDCASTLE, IS_WINDOWS, IS_MACOS +from torch.testing._internal.common_utils import ( + IS_FBCODE, + IS_MACOS, + IS_SANDCASTLE, + IS_WINDOWS, + TEST_WITH_ROCM, + skipIfRocm, +) # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) @@ -28,6 +36,13 @@ def to_test_backend_multi(module, method_compile_spec): return torch._C._jit_to_backend("test_backend", module, method_compile_spec) +def to_test_backend_selective(module, method_compile_spec, submodules): + def _to_test_backend(module): + return to_test_backend(module, method_compile_spec) + + return torch._C._jit_to_backend_selective(module, _to_test_backend, submodules) + + class BasicModule(torch.nn.Module): """ A simple Module used to test to_backend lowering machinery. @@ -54,7 +69,7 @@ class JitBackendTestCase(JitTestCase): def setUp(self): super().setUp() - if TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS: + if TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE: raise unittest.SkipTest("non-portable load_library call used in test") torch_root = Path(__file__).resolve().parent.parent.parent p = torch_root / 'build' / 'lib' / 'libjitbackend_test.so' @@ -75,9 +90,9 @@ def check_function(self, function_name, input): backend_method = self.lowered_module.__getattr__(function_name) # Run methods. - python_output = python_method(input, input) - jit_output = jit_method(input, input) - backend_output = backend_method(input, input) + python_output = python_method(*input) + jit_output = jit_method(*input) + backend_output = backend_method(*input) # The answers returned by Python, JIT and to_backend should all match. self.assertEqual(python_output, backend_output) @@ -89,6 +104,24 @@ def save_load(self): """ self.lowered_module = self.getExportImportCopy(self.lowered_module) + def test_execution(self): + """ + Stub for correctness tests. + """ + pass + + def test_save_load(self): + """ + Stub for serialization tests. + """ + pass + + def test_errors(self): + """ + Stub for testing error checking. + """ + pass + class BasicModuleTest(JitBackendTestCase): """ @@ -101,7 +134,7 @@ def setUp(self): self.module = BasicModule() self.scripted_module = torch.jit.script(BasicModule()) self.lowered_module = to_test_backend_multi( - self.scripted_module._c, + self.scripted_module, {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}}, ) @@ -110,9 +143,9 @@ def test_execution(self): input = torch.randn(5) # Test all three module methods. - self.check_function("accum", input) - self.check_function("sub_accum", input) - self.check_function("forward", input) + self.check_function("accum", (input, input)) + self.check_function("sub_accum", (input, input)) + self.check_function("forward", (input, input)) @skipIfRocm def test_save_load(self): @@ -160,8 +193,12 @@ def setUp(self): self.module = NestedModuleTest.NestedModule(BasicModule()) # Both modules in self.scripted_module are ScriptModules. self.scripted_module = torch.jit.script(NestedModuleTest.NestedModule(BasicModule())) + + # First, script another instance of NestedModule with share_types=False so that it can be + # selectively lowered without modifying the type of self.scripted_module. lowered_module = to_test_backend_multi( - self.scripted_module._c, {"forward": {"": ""}} + torch.jit.script(BasicModule()), + {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}}, ) # self.lowered_module is a ScriptModule, but its submodule is a lowered module. self.lowered_module = torch.jit.script(NestedModuleTest.NestedModule(lowered_module)) @@ -171,7 +208,7 @@ def test_execution(self): input = torch.randn(5) # Test forward. - self.check_function("forward", input) + self.check_function("forward", (input, input)) def test_save_load(self): # Lowered module should produce the same outputs. @@ -184,6 +221,161 @@ def test_save_load(self): self.test_execution() +class SelectiveLoweringTest(JitBackendTestCase): + """ + Tests for the selective lowering API. + """ + class OuterModule(torch.nn.Module): + def __init__(self, sub1, sub2, other): + super().__init__() + self.sub1 = sub1 + self.sub2 = sub2 + self.other = other + + def forward(self, x, y): + # Call the module that will be lowered directly to test + # type remapping in modules that are not its parent. + a, b = self.sub1.submodule.forward(x, y) + c, d = self.sub2.forward(x, y) + e, f = self.other.forward(x, y) + return a + c + e, b + d + f + + class MiddleModule(torch.nn.Module): + def __init__(self, submodule): + super().__init__() + self.submodule = submodule + + def forward(self, x, y): + return self.submodule.forward(x, y) + + def setUp(self): + super().setUp() + OuterModule = SelectiveLoweringTest.OuterModule + MiddleModule = SelectiveLoweringTest.MiddleModule + + def script_without_type_sharing(mod): + return torch.jit._recursive.create_script_module(mod, torch.jit._recursive.infer_methods_to_compile, share_types=False) + # Create Python, JIT and backend versions of a hierarchy that looks like this: + # --------- OuterModule -------- + # | | | + # MiddleModule MiddleModule MiddleModule + # | | | + # BasicModule BasicModule BasicModule + # + # Two BasicModules will be lowered and the third will not. + self.module = OuterModule(MiddleModule(BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule())) + self.scripted_module = script_without_type_sharing(OuterModule(MiddleModule( + BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule()))) + self.lowered_module = script_without_type_sharing(OuterModule(MiddleModule( + BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule()))) + self.lowered_module = to_test_backend_selective(self.lowered_module, {"forward": ""}, [ + "sub1.submodule", "sub2.submodule"]) + + def test_execution(self): + input = torch.randn(5) + self.check_function("forward", (input, input)) + + self.test_selective_lowering_type_remap() + + def test_save_load(self): + self.test_execution() + self.save_load() + self.test_execution() + + self.test_selective_lowering_type_remap() + + def test_selective_lowering_type_remap(self): + """ + Check that type remapping and replacement occurred during selective lowering. + """ + # Check that self.lowered_module was not lowered, but that it does contain test_backendLoweredModule due to it + # calling the lowered module directly. + FileCheck() \ + .check("OuterModule") \ + .check("BasicModule") \ + .run(self.scripted_module.graph) + FileCheck() \ + .check("OuterModule") \ + .check_not("__torch__.torch.classes.__backends__.test_backend") \ + .check("test_backendLoweredModule") \ + .run(self.lowered_module.graph) + + # Check that self.lowered_module.sub1/sub2 were not lowered but that BasicModule has been replaced in their graphs. + FileCheck() \ + .check("MiddleModule") \ + .check("BasicModule") \ + .check_not("test_backendLoweredModule") \ + .run(self.scripted_module.sub1.graph) + FileCheck() \ + .check("MiddleModule") \ + .check_not("__torch__.torch.classes.__backends__.test_backend") \ + .check("test_backendLoweredModule") \ + .check_not("BasicModule") \ + .run(self.lowered_module.sub1.graph) + + FileCheck() \ + .check("MiddleModule") \ + .check("BasicModule") \ + .check_not("test_backendLoweredModule") \ + .run(self.scripted_module.sub2.graph) + FileCheck() \ + .check("MiddleModule") \ + .check_not("__torch__.torch.classes.__backends__.test_backend") \ + .check("test_backendLoweredModule") \ + .check_not("BasicModule") \ + .run(self.lowered_module.sub2.graph) + + # Check that self.lowered_module.sub1/sub2.submodule were lowered. Its graph should mention + # __torch__.torch.classes.__backends__.test_backend, the TorchBind class for executing functions + # on the test JIT backend. + FileCheck() \ + .check("test_backendLoweredModule") \ + .check_not("BasicModule") \ + .check("__torch__.torch.classes.__backends__.test_backend") \ + .run(self.lowered_module.sub1.submodule.graph) + + FileCheck() \ + .check("test_backendLoweredModule") \ + .check_not("BasicModule") \ + .check("__torch__.torch.classes.__backends__.test_backend") \ + .run(self.lowered_module.sub2.submodule.graph) + + # Check that self.other and self.other.submodule have been left untouched by the selective lowering process. + FileCheck() \ + .check("MiddleModule") \ + .check("BasicModule") \ + .check_not("__torch__.torch.classes.__backends__.test_backend") \ + .check_not("test_backendLoweredModule") \ + .run(self.scripted_module.other.graph) + FileCheck() \ + .check("BasicModule") \ + .check_not("__torch__.torch.classes.__backends__.test_backend") \ + .check_not("test_backendLoweredModule") \ + .run(self.scripted_module.other.submodule.graph) + + def test_errors(self): + """ + Check errors associated with selective lowering. + """ + # Check error messages thrown when attempting to lower something that is not a ScriptModule. + with self.assertRaisesRegex(RuntimeError, r"Object .* is not a ScriptModule"): + to_test_backend_selective(torch.nn.ReLU(), {"forward": ""}, ["submodule"]) + + MiddleModule = SelectiveLoweringTest.MiddleModule + mod = MiddleModule(BasicModule()) + mod.new_attr = 3 + + with self.assertRaisesRegex(RuntimeError, r"Attribute named new_attr is not a Module"): + to_test_backend_selective(torch.jit.script(mod), {"forward": ""}, ["new_attr"]) + + # Check error message thrown when module hierarchy doesn't have unique types. + OuterModule = SelectiveLoweringTest.OuterModule + mod = OuterModule(MiddleModule(BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule())) + + with self.assertRaisesRegex(RuntimeError, r"Selective lowering is only supported for module hierarchies with unique types"): + to_test_backend_selective(torch.jit.script(mod), {"forward": ""}, ["sub1.submodule"]) + + class TestBackends(JitTestCase): """ This class wraps and invokes all subclasses of JitBackendTestCase so that each one @@ -194,19 +386,27 @@ def __init__(self, name): super().__init__(name) self.basic_module_test = BasicModuleTest(name) self.nested_module_test = NestedModuleTest(name) + self.selective_lowering_test = SelectiveLoweringTest(name) def setUp(self): super().setUp() if not TEST_WITH_ROCM: self.basic_module_test.setUp() self.nested_module_test.setUp() + self.selective_lowering_test.setUp() @skipIfRocm def test_execution(self): self.basic_module_test.test_execution() self.nested_module_test.test_execution() + self.selective_lowering_test.test_execution() @skipIfRocm def test_save_load(self): self.basic_module_test.test_save_load() self.nested_module_test.test_save_load() + self.selective_lowering_test.test_save_load() + + @skipIfRocm + def test_errors(self): + self.selective_lowering_test.test_errors() diff --git a/test/jit/test_builtins.py b/test/jit/test_builtins.py index dafc95013b963..b5a0dd8599a64 100644 --- a/test/jit/test_builtins.py +++ b/test/jit/test_builtins.py @@ -2,7 +2,7 @@ import sys import inspect import unittest -from typing import List +from typing import Dict, List import torch @@ -78,8 +78,7 @@ def forward(self, name): torch.jit.script(Mod()) def test_del(self): - def fn(x): - # type: (List[int]) -> List[int] + def fn(x: List[int]) -> List[int]: a = x * 2 del a return x @@ -109,22 +108,28 @@ def fn(x): return a def test_del_multiple_operands(self): + def fn(x: List[int]) -> List[int]: + a, b, c = x[0], x[1], x[2] + del a, b, c + return x - with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, - "with more than one operand"): - @torch.jit.script - def del_list_multiple_operands(x): - # type: (List[int]) -> List[int] - del x[0], x[1] - return x + self.checkScript(fn, ([1, 2, 3],)) - with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, - "with more than one operand"): - @torch.jit.script - def del_dict_multiple_operands(x): - # type: (Dict[str, int]) -> Dict[str, int] - del x['hi'], x['there'] - return x + def del_list_multiple_operands(x: List[int]) -> List[int]: + del x[0], x[1] + return x + + py_out = del_list_multiple_operands([0, 1, 2]) + jit_out = torch.jit.script(del_list_multiple_operands)([0, 1, 2]) + self.assertEquals(py_out, jit_out) + + def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]: + del x['hi'], x['there'] + return x + + py_out = del_dict_multiple_operands({"hi": 5, "there": 6}) + jit_out = torch.jit.script(del_dict_multiple_operands)({"hi": 5, "there": 6}) + self.assertEquals(py_out, jit_out) class TestTensorBuiltins(JitTestCase): diff --git a/test/jit/test_class_type.py b/test/jit/test_class_type.py index 3fcd893470914..4d3d73e5f7c70 100644 --- a/test/jit/test_class_type.py +++ b/test/jit/test_class_type.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn from torch.testing import FileCheck +from typing import Any # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) @@ -54,8 +55,7 @@ class FooTest(object): # noqa: B903 def __init__(self): pass - def __contains__(self, key): - # type: (str) -> bool + def __contains__(self, key: str) -> bool: return key == 'hi' @torch.jit.script @@ -67,17 +67,14 @@ def fn(): def test_set_attr_in_method(self): class FooTest(object): - def __init__(self, x): - # type: (int) -> None + def __init__(self, x: int) -> None: self.foo = x - def incFooTest(self, y): - # type: (int) -> None + def incFooTest(self, y: int) -> None: self.foo = self.foo + y @torch.jit.script - def fn(x): - # type: (int) -> int + def fn(x: int) -> int: foo = FooTest(x) foo.incFooTest(2) return foo.foo @@ -128,8 +125,7 @@ def test_type_annotations(self): with self.assertRaisesRegex(RuntimeError, "Expected a value of type \'bool"): @torch.jit.script # noqa: B903 class FooTest(object): # noqa: B903 - def __init__(self, x): - # type: (bool) -> None + def __init__(self, x: bool) -> None: self.foo = x @torch.jit.script @@ -143,7 +139,7 @@ def test_conditional_set_attr(self): @torch.jit.script class FooTest(object): def __init__(self, x): - if True: + if 1 == 1: self.attr = x def test_class_type_as_param(self): @@ -154,8 +150,7 @@ def __init__(self, x): self.attr = x @torch.jit.script - def fn(foo): - # type: (FooTest) -> Tensor + def fn(foo: FooTest) -> torch.Tensor: return foo.attr @torch.jit.script @@ -292,8 +287,7 @@ def __init__(self, x, y): self.y = y @torch.jit.script - def use_foo(foo): - # type: (Foo) -> Foo + def use_foo(foo: Foo) -> Foo: return foo # create from python @@ -318,8 +312,7 @@ def __init__(self, x, y): self.x = x self.y = y - def use_foo(foo, foo2, tup): - # type: (Foo, Foo, Tuple[Foo, Foo]) -> Tensor + def use_foo(foo: Foo, foo2: Foo, tup: Tuple[Foo, Foo]) -> torch.Tensor: a, b = tup return foo.x + foo2.y + a.x + b.y @@ -339,19 +332,17 @@ def test_class_sorting(self): global Foo # see [local resolution in python] class Foo(object): # noqa: B903 - def __init__(self, x): - # type: (int) -> None + def __init__(self, x: int) -> None: self.x = x - def __lt__(self, other): + def __lt__(self, other) -> bool: # type: (Foo) -> bool return self.x < other.x def getVal(self): return self.x - def test(li, reverse=False): - # type: (List[Foo], bool) -> Tuple[List[int], List[int]] + def test(li: List[Foo], reverse: bool = False) -> Tuple[List[int], List[int]]: li_sorted = sorted(li) ret_sorted = torch.jit.annotate(List[int], []) for foo in li_sorted: @@ -445,6 +436,39 @@ class Derived(Base): def two(self, x): return x + self.b + 2 + + def test_class_inheritance_implicit(self): + """ + Test that inheritance is detected in + implicit scripting codepaths (e.g. try_ann_to_type). + """ + class A: + def __init__(self, t): + self.t = t + + @staticmethod + def f(a: torch.Tensor): + return A(a + 1) + + class B(A): + def __init__(self, t): + self.t = t + 10 + + @staticmethod + def f(a: torch.Tensor): + return A(a + 1) + + x = A(torch.tensor([3])) + + def fun(x: Any): + if isinstance(x, A): + return A.f(x.t) + else: + return B.f(x.t) + + with self.assertRaisesRegex(RuntimeError, "Tried to access nonexistent attribute or method"): + sc = torch.jit.script(fun) + @unittest.skipIf(IS_SANDCASTLE, "Importing like this doesn't work in fbcode") def test_imported_classes(self): import jit._imported_class_test.foo @@ -502,36 +526,29 @@ def two(self, x): @torch.jit.interface class OneTwo(object): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: pass - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: torch.Tensor) -> torch.Tensor: pass @torch.jit.interface class OneTwoThree(object): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: pass - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: torch.Tensor) -> torch.Tensor: pass - def three(self, x): - # type: (Tensor) -> Tensor + def three(self, x: torch.Tensor) -> torch.Tensor: pass @torch.jit.interface class OneTwoWrong(object): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: pass - def two(self, x): - # type: (int) -> int + def two(self, x: int) -> int: pass @torch.jit.script @@ -551,8 +568,7 @@ def __init__(self): def one(self, x, y): return x + y - def two(self, x): - # type: (int) -> int + def two(self, x: int) -> int: return 3 def use_them(x): @@ -566,13 +582,11 @@ def use_them(x): self.checkScript(use_them, (torch.rand(3, 4),)) @torch.jit.script - def as_interface(x): - # type: (OneTwo) -> OneTwo + def as_interface(x: OneTwo) -> OneTwo: return x @torch.jit.script - def inherit(x): - # type: (OneTwoThree) -> OneTwo + def inherit(x: OneTwoThree) -> OneTwo: return as_interface(x) with self.assertRaisesRegex(RuntimeError, "does not have method"): @@ -593,8 +607,7 @@ def wrong3(): with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"): @torch.jit.script - def wrong4(x): - # type: (OneTwoWrong) -> int + def wrong4(x: OneTwoWrong) -> int: return as_interface(x) # Test interface/class python assignment @@ -646,16 +659,14 @@ class Foo(object): def __init__(self, x): self.x = x - def __len__(self): - # type: () -> int + def __len__(self) -> int: return len(self.x) def __neg__(self): self.x = -self.x return self - def __mul__(self, other): - # type: (Tensor) -> Tensor + def __mul__(self, other: torch.Tensor) -> torch.Tensor: return self.x * other def test_overload(): @@ -668,80 +679,61 @@ def test_overload(): # TODO - support compiling classes from strings in jit.CompilationUnit @torch.jit.script class MyClass(object): - def __init__(self, x): - # type: (int) -> None + def __init__(self, x: int) -> None: self.x = x - def __add__(self, other): - # type: (int) -> int + def __add__(self, other: int) -> int: return self.x + other - def __sub__(self, other): - # type: (int) -> int + def __sub__(self, other: int) -> int: return self.x - other - def __mul__(self, other): - # type: (int) -> int + def __mul__(self, other: int) -> int: return self.x * other - def __pow__(self, other): - # type: (int) -> int + def __pow__(self, other: int) -> int: return int(self.x ** other) - def __truediv__(self, other): - # type: (int) -> float + def __truediv__(self, other: int) -> float: return self.x / other - def __mod__(self, other): - # type: (int) -> int + def __mod__(self, other: int) -> int: return self.x % other - def __ne__(self, other): # noqa T484 - # type: (int) -> bool + def __ne__(self, other: int) -> bool: # noqa T484 return self.x != other - def __eq__(self, other): # noqa T484 - # type: (int) -> bool + def __eq__(self, other: int) -> bool: # noqa T484 return self.x == other - def __lt__(self, other): - # type: (int) -> bool + def __lt__(self, other: int) -> bool: return self.x < other - def __gt__(self, other): - # type: (int) -> bool + def __gt__(self, other: int) -> bool: return self.x > other - def __le__(self, other): - # type: (int) -> bool + def __le__(self, other: int) -> bool: return self.x <= other - def __ge__(self, other): - # type: (int) -> bool + def __ge__(self, other: int) -> bool: return self.x >= other - def __and__(self, other): - # type: (int) -> int + def __and__(self, other: int) -> int: return self.x & other - def __or__(self, other): - # type: (int) -> int + def __or__(self, other: int) -> int: return self.x | other - def __xor__(self, other): - # type: (int) -> int + def __xor__(self, other: int) -> int: return self.x ^ other - def __getitem__(self, other): - # type: (int) -> int + def __getitem__(self, other: int) -> int: return other + 1 - def __setitem__(self, idx, val): - # type: (int, int) -> None + def __setitem__(self, idx: int, val: int) -> None: self.x = val * idx - def __call__(self, val): - # type: (int) -> int + def __call__(self, val: int) -> int: return self.x * val * 3 @@ -799,8 +791,7 @@ def test_cast_overloads(self): @torch.jit.script class Foo(object): - def __init__(self, val): - # type: (float) -> None + def __init__(self, val: float) -> None: self.val = val def __int__(self): @@ -815,8 +806,7 @@ def __bool__(self): def __str__(self): return str(self.val) - def test(foo): - # type: (Foo) -> Tuple[int, float, bool] + def test(foo: Foo) -> Tuple[int, float, bool]: if foo: pass return int(foo), float(foo), bool(foo) @@ -857,8 +847,7 @@ def __init__(self, x, y): def test_class_constructs_itself(self): @torch.jit.script # noqa: B903 class LSTMStateStack(object): # noqa: B903 - def __init__(self, num_layers, hidden_size): - # type: (int, int) -> None + def __init__(self, num_layers: int, hidden_size: int) -> None: self.num_layers = num_layers self.hidden_size = hidden_size self.last_state = ( @@ -885,8 +874,7 @@ class Tree(object): # noqa: B903 def __init__(self): self.child = torch.jit.annotate(Optional[Leaf], None) - def add_child(self, child): - # type: (Leaf) -> None + def add_child(self, child: Leaf) -> None: self.child = child def test_recursive_class(self): @@ -925,6 +913,26 @@ def forward(self, x): # Make sure class constant is accessible from module self.assertEqual(m.w, m_loaded.w) + def test_py_class_to_ivalue_missing_attribute(self): + global Foo # see [local resolution in python] + + class Foo(object): + i : int + f : float + + def __init__(self, i : int, f : float): + self.i = i + self.f = f + + @torch.jit.script + def test_fn(x : Foo) -> float: + return x.i + x.f + + test_fn(Foo(3, 4.0)) + + with self.assertRaisesRegex(RuntimeError, 'missing attribute i'): + test_fn(torch.rand(3, 4)) + def test_unused_method(self): """ Test unused methods on scripted classes. @@ -1158,6 +1166,32 @@ def test_function(a: int, b: int) -> 'ClassWithStaticMethod': self.checkScript(test_function, (1, 2)) + def test_classmethod(self): + """ + Test classmethods on class types. + """ + global ClassWithClassMethod + + @torch.jit.script + class ClassWithClassMethod: + def __init__(self, a: int): + self.a: int = a + + def __eq__(self, other: 'ClassWithClassMethod'): + return self.a == other.a + + @classmethod + def create(cls, a: int) -> 'ClassWithClassMethod': + return cls(a) + + def test_function(a: int) -> 'ClassWithClassMethod': + x = ClassWithClassMethod(a) + # Support calling classmethod with an instance + # Calling with the class is not supported. + return x.create(a) + + self.checkScript(test_function, (1,)) + def test_properties(self): """ Test that a scripted class can make use of the @property decorator. @@ -1167,6 +1201,8 @@ def free_function(x: int) -> int: @torch.jit.script class Properties(object): + __jit_unused_properties__ = ["unsupported"] + def __init__(self, a: int): self.a = a @@ -1174,6 +1210,19 @@ def __init__(self, a: int): def attr(self) -> int: return self.a - 1 + @property + def unsupported(self) -> int: + return sum([self.a]) + + @torch.jit.unused + @property + def unsupported_2(self) -> int: + return sum([self.a]) + + @unsupported_2.setter + def unsupported_2(self, value): + self.a = sum([self.a]) + @attr.setter def attr(self, value: int): self.a = value + 3 @@ -1259,3 +1308,99 @@ def fn() -> bool: with self.assertRaisesRegexWithHighlight(RuntimeError, r"Class does not define __delitem__", "example[key]"): self.checkScript(fn, ()) + + def test_recursive_script_builtin_type_resolution(self): + """ + Test resolution of built-in torch types(e.g. torch.Tensor, torch.device) when a class is recursively compiled. + """ + # A will be implicitly compiled because it is not annotated with @torch.jit.script + # but is used in g() below. + tensor_t = torch.Tensor + device_t = torch.device + device_ty = torch.device + + class A(object): + def __init__(self): + pass + + def f(self, x: tensor_t, y: torch.device) -> tensor_t: + return x.to(device=y) + + def g(self, x: device_t) -> device_ty: + return x + + def h(self, a: 'A') -> 'A': + return A() + + def i(self, a: List[int]) -> int: + return a[0] + + def j(self, l: List[device_t]) -> device_ty: + return l[0] + + def call_f(): + a = A() + return a.f(torch.tensor([1]), torch.device("cpu")) + + def call_g(): + a = A() + return a.g(torch.device("cpu")) + + def call_i(): + a = A() + return a.i([3]) + + def call_j(): + a = A() + return a.j([torch.device("cpu"), torch.device("cpu")]) + + for fn in [call_f, call_g, call_i, call_j]: + self.checkScript(fn, ()) + s = self.getExportImportCopy(torch.jit.script(fn)) + self.assertEqual(s(), fn()) + + def test_recursive_script_module_builtin_type_resolution(self): + """ + Test resolution of built-in torch types(e.g. torch.Tensor, torch.device) when a class is recursively compiled + when compiling a module. + """ + class Wrapper(): + def __init__(self, t): + self.t = t + + def to(self, l: List[torch.device], device: Optional[torch.device] = None): + return self.t.to(device=device) + + + class A(nn.Module): + def forward(self): + return Wrapper(torch.rand(4, 4)) + + scripted = torch.jit.script(A()) + self.getExportImportCopy(scripted) + + def test_class_attribute_wrong_type(self): + """ + Test that the error message displayed when convering a class type + to an IValue that has an attribute of the wrong type. + """ + @torch.jit.script + class ValHolder(object): # noqa: B903 + def __init__(self, val): + self.val = val + + class Mod(nn.Module): + def __init__(self): + super(Mod, self).__init__() + self.mod1 = ValHolder(1) + self.mod2 = ValHolder(2) + + def forward(self, cond: bool): + if cond: + mod = self.mod1 + else: + mod = self.mod2 + return mod.val + + with self.assertRaisesRegex(RuntimeError, "Could not cast attribute 'val' to type Tensor"): + torch.jit.script(Mod()) diff --git a/test/jit/test_cuda.py b/test/jit/test_cuda.py new file mode 100644 index 0000000000000..f7af8e3a2efc1 --- /dev/null +++ b/test/jit/test_cuda.py @@ -0,0 +1,476 @@ +import os +import sys +import gc +import unittest + +import torch +from typing import NamedTuple +from torch.testing._internal.jit_utils import JitTestCase +from torch.testing._internal.common_utils import skipIfRocm, skipCUDANonDefaultStreamIf + +# Make the helper files in test/ importable +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) + +# Check if GPU is available +TEST_CUDA = torch.cuda.is_available() +# Check if multiple GPU's are available +TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2 + +# If GPU is not available, then do not run the tests +if not TEST_CUDA: + print('CUDA not available, skipping tests', file=sys.stderr) + JitTestCase = object # noqa: F811 + +TEST_LARGE_TENSOR = TEST_CUDA + +# If GPU is available, then initialize the cuda context and check +# if there is memory available to allocate for LARGE Tensors. +if TEST_CUDA: + torch.ones(1).cuda() # initialize cuda context + TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 5e9 + +if __name__ == "__main__": + raise RuntimeError( + "This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_jit.py TESTNAME\n\n" + "instead." + ) + +class TestCUDA(JitTestCase): + """ + A suite of tests for the CUDA API in TorchScript. + """ + def setUp(self): + super(TestCUDA, self).setUp() + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + super(TestCUDA, self).tearDown() + + @skipIfRocm + @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") + def test_current_stream(self): + # Test current stream on the device and check if the stream device index + # matches with the device ID + @torch.jit.script + def fn(): + device_index = torch.cuda._current_device() + s0 = torch.cuda.current_stream(device_index) + s1 = torch.cuda.current_stream(1) + s2 = torch.cuda.current_stream(0) + + return s0.device_index(), s1.device_index(), s2.device_index() + + d0, d1, d2 = fn() + + # By default, the current device ID is 0. + self.assertEqual(0, d0) + self.assertEqual(1, d1) + self.assertEqual(0, d2) + self.assertEqual(d0, d2) + + @skipIfRocm + @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") + @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") + @skipCUDANonDefaultStreamIf(True) + def test_streams_and_events(self): + # This test checks for the default stream ID is set to 0 on the device + @torch.jit.script + def test_default_streams(): + s0 = torch.cuda.default_stream(0) + s1 = torch.cuda.default_stream(1) + + d = torch.device('cuda:1') + + # Check the current stream id and default id are same + # on the current device. The current device id by default is 0 + s2 = torch.cuda.current_stream(0) + check_s2 = s2.id() == s0.id() + check_d0 = torch.cuda._current_device() == s2.device_index() + + # Set the current device to d1 and check if the stream + # has been set to the default stream on d1 + with torch.jit.cuda.device(d): + s3 = torch.cuda.current_stream(1) + check_s3 = s3.id() == s1.id() + check_d1 = torch.cuda._current_device() == s3.device_index() + + # Check if the current device was reset to 0 + is_device_d0 = torch.cuda._current_device() == s2.device_index() + + return s0.device_index(), s1.device_index(), check_s2, check_s3, check_d0, check_d1, is_device_d0 + + d0, d1, check_s2, check_s3, check_d0, check_d1, is_device_d0 = test_default_streams() + + self.assertEqual(d0, 0) + self.assertEqual(d1, 1) + self.assertTrue(check_s2) + self.assertTrue(check_s3) + self.assertTrue(check_d0) + self.assertTrue(check_d1) + self.assertTrue(is_device_d0) + + # This test checks if the Stream Context manager is a no op + # when the stream is none for `with torch.jit.cuda.stream` + @torch.jit.script + def test_set_none_stream(): + device_index = torch.cuda._current_device() + current_stream = torch.cuda.current_stream(device_index) + default_stream = torch.cuda.default_stream(device_index) + + # When stream is none, check if this operation is a no-op + with torch.jit.cuda.stream(None): + cur_device_index = torch.cuda._current_device() + is_device_index_same = cur_device_index == device_index + is_current_stream_same = torch.cuda.current_stream(cur_device_index).id() == current_stream.id() + is_default_stream_same = torch.cuda.default_stream(device_index).id() == default_stream.id() + + # Check if the device index, current stream and default streams have not changed + are_streams_same = is_device_index_same and is_current_stream_same and is_default_stream_same + return are_streams_same + self.assertTrue(test_set_none_stream()) + + # This test checks if the Device Context manager is a no op + # when the device is none for `with torch.jit.cuda.device` + @torch.jit.script + def test_set_device_none(): + device_index = torch.cuda._current_device() + # When device is none, check if this operation is a no-op + with torch.jit.cuda.device(None): + # Check if the current device is the same + is_device_same = torch.cuda._current_device() == device_index + return is_device_same + self.assertTrue(test_set_device_none()) + + # Check if a CUDA JIT stream is created + # on the _current_device + @torch.jit.script + def test_simple_stream(): + device_index = torch.cuda._current_device() + s = torch.jit.cuda.Stream(device_index, 0) + return device_index == s.device_index() + + self.assertTrue(test_simple_stream(), "Could not create Stream!") + + # Class used to store results for the test: test_get_stream. + class Result(NamedTuple): + t1 : torch.Tensor + t2 : torch.Tensor + is_current_and_default_stream_same : bool + is_default_and_user_stream_not_same : bool + is_stream_set : bool + is_stream_reset : bool + default_stream_query : bool + default_stream_id : int + user_stream_id : int + + # The test aims at checking different stream proporties. + @torch.jit.script + def test_get_stream(): + device_index = torch.cuda._current_device() + current_stream = torch.cuda.current_stream(device_index) + default_stream = torch.cuda.default_stream(device_index) + user_stream = torch.jit.cuda.Stream(device_index, 0) + + # Check if the current and default streams are the same on the device + is_current_and_default_stream_same = current_stream.id() == default_stream.id() + # Check if user stream and default stream are not the same on the device + is_default_and_user_stream_not_same = default_stream.id() != user_stream.id() + + with torch.jit.cuda.stream(user_stream): + is_stream_set = torch.cuda.current_stream(device_index).id() == user_stream.id() + + # Check if the stream was reset to current_stream + is_stream_reset = torch.cuda.current_stream(device_index).id() == current_stream.id() + + tensor1 = torch.rand(10000, 10000, device="cuda") + tensor2 = torch.mm(tensor1, tensor1).to("cuda") + default_stream.synchronize() + default_stream_query = default_stream.query() + + # Capture all the results in the class Result + res = Result( + tensor1, tensor2, is_current_and_default_stream_same, + is_default_and_user_stream_not_same, is_stream_set, + is_stream_reset, default_stream_query, default_stream.id(), user_stream.id()) + return res + + result = test_get_stream() + + self.assertEqual(torch.matmul(result.t1, result.t1), result.t2) + self.assertTrue(result.is_current_and_default_stream_same) + self.assertTrue(result.is_default_and_user_stream_not_same) + self.assertTrue(result.is_stream_set) + self.assertTrue(result.is_stream_reset) + self.assertTrue(result.default_stream_query) + self.assertEqual(result.default_stream_id, 0) # Check if the default stream ID is always 0 + self.assertNotEqual(result.user_stream_id, 0) # Check if the user stream is always non zero + + # Test the stream context manager. This test checks if the stream is switched + # to the user stream on using the stream context manager. + @torch.jit.script + def test_stream_context(): + device_index = torch.cuda._current_device() + current_stream = torch.cuda.current_stream(device_index) + user_stream = torch.jit.cuda.Stream(device_index, 0) + A = torch.rand(1000, 1000, device="cuda") + + with torch.jit.cuda.stream(user_stream): + check = torch.cuda.current_stream(device_index).id() == user_stream.id() + B = torch.mm(A, A).to("cuda") + # Wait for B to be computed + user_stream.synchronize() + # Check if the stream has been reset on the current device + is_stream_reset = torch.cuda.current_stream(device_index).id() == current_stream.id() + + return A, B, check, is_stream_reset + + A, B, is_stream_set, is_stream_reset = test_stream_context() + self.assertEqual(torch.matmul(A, A), B) + self.assertTrue(is_stream_set, "Error: Current stream was not set to user stream!") + self.assertTrue(is_stream_reset, "Error: The stream was not restored to previous stream!") + + # Test multiple nested streams. Check if the operations are computed as expected on the streams + # This test has been adapted from the eager mode tests available at test/test_cuda.py + @torch.jit.script + def test_multiple_stream(): + prev_device_index = torch.cuda._current_device() + prev_current_stream = torch.cuda.current_stream(prev_device_index) + s1 = torch.jit.cuda.Stream(0, 0) + s2 = torch.jit.cuda.Stream(1, 0) + + A = torch.rand(1000, 1000, device="cuda") + B = torch.rand(1000, 1000, device="cuda") + with torch.jit.cuda.stream(s1): + C = torch.mm(A, A).to("cuda") + # Check if the stream and device have been set to s1 + is_stream_s1 = torch.cuda.current_stream(s1.device_index()).id() == s1.id() + is_device_s1 = torch.cuda._current_device() == s1.device_index() + with torch.jit.cuda.stream(s2): + # Check if the stream and device have been set to s2 + is_stream_s2 = torch.cuda.current_stream(s2.device_index()).id() == s2.id() + is_device_s2 = torch.cuda._current_device() == s2.device_index() + D = torch.mm(B, B).to("cuda") + # Check if the stream and device have been set to s1 + is_stream_s1_after = torch.cuda.current_stream(s1.device_index()).id() == s1.id() + is_device_s1_after = torch.cuda._current_device() == s1.device_index() + # Wait for D to be computed + s2.synchronize() + # Wait for C to be computed on S1 + s1.synchronize() + + # Check if the stream and device has been restored to previous stream and device + is_device_current = torch.cuda._current_device() == prev_device_index + is_stream_current = torch.cuda.current_stream(prev_device_index).id() == prev_current_stream.id() + + check_stream = is_stream_s1 and is_stream_s2 and is_stream_s1_after and is_stream_current + check_device = is_device_s1 and is_device_s2 and is_device_s1_after and is_device_current + return A, B, C, D, check_stream, check_device + A, B, C, D, check_stream, check_device = test_multiple_stream() + + self.assertEqual(torch.matmul(A, A), C) + self.assertEqual(torch.matmul(B, B), D) + self.assertTrue(check_stream) + self.assertTrue(check_device) + + # Test multiple streams waiting on each other for the operations to be completed. + @torch.jit.script + def test_data_dependency_between_streams(): + device_index = torch.cuda._current_device() + prev_current_stream = torch.cuda.current_stream(device_index) + s1 = torch.jit.cuda.Stream(0, 0) + s2 = torch.jit.cuda.Stream(0, 0) + event = torch.jit.cuda.Event(False, False, False) + + A = torch.rand(1000, 1000, device="cuda") + with torch.jit.cuda.stream(s1): + is_stream_s1 = torch.cuda.current_stream(device_index).id() == s1.id() + B = torch.mm(A, A).to("cuda") + s1.record_event(event) + # Check if the current_stream is reset + is_current_stream_1 = torch.cuda.current_stream(device_index).id() == prev_current_stream.id() + # Wait for ops on s1 to be computed + s2.wait_event(event) + with torch.jit.cuda.stream(s2): + is_stream_s2 = torch.cuda.current_stream(device_index).id() == s2.id() + C = torch.mm(B, B).to("cuda") + # Wait for C to be computed + s2.synchronize() + # Check if the current_stream is reset + is_current_stream_2 = torch.cuda.current_stream(device_index).id() == prev_current_stream.id() + + check_stream = is_current_stream_1 and is_current_stream_2 and is_stream_s1 and is_stream_s2 + return A, B, C, check_stream + + A, B, C, check_stream = test_data_dependency_between_streams() + self.assertEqual(torch.matmul(A, A), B) + self.assertEqual(torch.matmul(B, B), C) + self.assertTrue(check_stream) + + # Test a simple CUDA event. Test if the CUDA event was created successfully + @torch.jit.script + def test_simple_event(): + e = torch.jit.cuda.Event(True, False, False) + return e is not None + self.assertTrue(test_simple_event(), "Could not create CUDA Event!") + + # Record the CUDA event for operation torch.mm on the current stream + # and then test if the elapsed time is greater than 0. This test is also + # an adaption from eager mdoe CUDA tests available at test/test_cuda.py + @torch.jit.script + def test_event(): + device_index = torch.cuda._current_device() + stream = torch.cuda.current_stream(device_index) + event = torch.jit.cuda.Event(True, False, False) + is_true_event_query = event.query() + start_event = torch.jit.cuda.Event(True, False, False) + stream.record_event(start_event) + tensor1 = torch.rand(1000000000, 1000000000, device="cuda") + tensor2 = torch.mm(tensor1, tensor1).to("cuda") + stream.record_event(event) + event.synchronize() + is_again_true_event_query = event.query() + + if not (is_true_event_query and is_again_true_event_query): + return -1.0 + return start_event.elapsed_time(event) + + self.assertGreater(test_event(), 0) + + # Check for stream synchronization , when a large tensor multiplication is + # computed on the stream. The stream.query should be true once the synchroniztion is done + @torch.jit.script + def test_stream_synchronize() -> float: + device_index = torch.cuda._current_device() + s = torch.jit.cuda.Stream(device_index, 0) + e_tik = torch.jit.cuda.Event(True, False, False) + e_tok = torch.jit.cuda.Event(True, False, False) + + e_tik.record(s) + tensor1 = torch.rand(1000000000, 1000000000, device="cuda") + with torch.jit.cuda.stream(s): + tensor2 = torch.mm(tensor1, tensor1).to("cuda") + s.synchronize() + e_tok.record(s) + e_tok.synchronize() + + if not s.query(): + return -1.0 + + # not necessary to check e_tik and e_tok, as elapsed_time would throw + # exception if otherwise. + return e_tik.elapsed_time(e_tok) + self.assertGreater(test_stream_synchronize(), 0) + + # Test event synchronization for the event that records a stream doing + # a large tensor multiplication. Check if the elapsed time is greater than 0 + # and the stream.query evaluates to true. + @torch.jit.script + def test_event_synchronize() -> float: + device_index = torch.cuda._current_device() + s = torch.jit.cuda.Stream(device_index, 0) + e_tik = torch.jit.cuda.Event(True, False, False) + e_tok = torch.jit.cuda.Event(True, False, False) + + e_tik.record(s) + tensor1 = torch.rand(1000000000, 1000000000, device="cuda") + with torch.jit.cuda.stream(s): + tensor = torch.mm(tensor1, tensor1).to("cuda") + s.record_event(e_tok) + e_tok.synchronize() + s.synchronize() + + if not s.query(): + return -1.0 + + # not necessary to check e_tik and e_tok, as elapsed_time would throw + # exception if otherwise. + return e_tik.elapsed_time(e_tok) + + self.assertGreater(test_event_synchronize(), 0) + + # Test for event wait. Check if event waits for the all the operations on + # the stream to be done. Check for synchronizations and query on the streams + # and events. This test is adapted from eager mode tests for CUDA. Please refer + # test/test_cuda.py + @torch.jit.script + def test_event_wait() -> float: + device_index = torch.cuda._current_device() + s0 = torch.cuda.current_stream(device_index) + s1 = torch.jit.cuda.Stream(device_index, 0) + e_tik = torch.jit.cuda.Event(True, True, False) + e_tok = torch.jit.cuda.Event(True, True, False) + + e_tik.record(s0) + tensor1 = torch.rand(1000000000, 1000000000, device="cuda") + with torch.jit.cuda.stream(s0): + tensor2 = torch.mm(tensor1, tensor1).cuda() + e_sync = torch.jit.cuda.Event(True, False, False) + e_sync.record(torch.cuda.current_stream(device_index)) + e_sync.wait(s1) + with torch.jit.cuda.stream(s1): + tensor3 = torch.rand(1000000000, 1000000000, device="cuda") + tensor4 = torch.mm(tensor3, tensor3).cuda() + s1.synchronize() + e_tok.record(torch.cuda.current_stream(device_index)) + e_tok.synchronize() + s0.synchronize() + + if not s0.query() or not s1.query() or not e_sync.query(): + return -1.0 + + # not necessary to check e_tik and e_tok, as elapsed_time would throw + # exception if otherwise. + return e_tik.elapsed_time(e_tok) + self.assertGreater(test_event_wait(), 0) + + # Test for stream wait_event. Checks if the stream waits on the event + @torch.jit.script + def test_wait_event(): + d1 = torch.device('cuda:1') + + with torch.jit.cuda.device(d1): + s0 = torch.cuda.current_stream(1) + tensor1 = torch.rand(1000000000, 1000000000, device="cuda") + tensor2 = torch.mm(tensor1, tensor1).to("cuda") + e0 = torch.jit.cuda.Event(False, False, False) + s0.record_event(e0) + + s1 = torch.cuda.current_stream(0) + s1.wait_event(e0) + s1.synchronize() + + return e0.query() and s0.query() and s1.query() + self.assertTrue(test_wait_event()) + + # Test if a scripted module with cuda streams can be saved, loaded and executed + def test_save_load(self): + class Model(torch.nn.Module): + def forward(self): + device_index = torch.cuda._current_device() + s = torch.jit.cuda.Stream(device_index, 0) + a = torch.rand(3, 4, device="cuda") + b = torch.rand(3, 4, device="cuda") + + with torch.jit.cuda.stream(s): + is_stream_s = torch.cuda.current_stream(s.device_index()).id() == s.id() + c = torch.cat((a, b), 0).cuda() + s.synchronize() + return is_stream_s, a, b, c + + model = Model() + + # Script the model and save + script_model = torch.jit.script(model) + is_stream_s, a, b, c = script_model() + # Verify if the output is correct + self.assertTrue(is_stream_s) + self.assertEqual(torch.cat((a, b), 0), c) + + # Save and load scripted model + load_model = self.getExportImportCopy(script_model) + is_stream_s, a_load, b_load, c_load = load_model() + self.assertTrue(is_stream_s) + self.assertEqual(torch.cat((a_load, b_load), 0), c_load) diff --git a/test/jit/test_enum.py b/test/jit/test_enum.py index aa34c22413ad8..b39732d0e9bcf 100644 --- a/test/jit/test_enum.py +++ b/test/jit/test_enum.py @@ -356,6 +356,6 @@ def iterate_enum(x: Color): .check_same("Color.BLUE") \ .run(str(scripted.graph)) - # PURPLE always appear last because we follow Python's Enum definition order. + # PURPLE always appears last because we follow Python's Enum definition order. self.assertEqual(scripted(Color.RED), [Color.GREEN.value, Color.BLUE.value]) self.assertEqual(scripted(Color.GREEN), [Color.RED.value, Color.BLUE.value]) diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index 4ec8f7e46d1b3..bd31a7e8ce167 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -7,6 +7,8 @@ from torch.testing._internal.common_quantization import skipIfNoFBGEMM from torch.jit._recursive import wrap_cpp_module +from typing import Any +from itertools import product import io @@ -524,6 +526,77 @@ def forward(self, x): self.assertEqual(output_s, output_f) + def test_freeze_module_with_preserve_sub_module(self): + class SubModule(nn.Module): + def __init__(self): + super(SubModule, self).__init__() + self.a = torch.tensor([1.1]) + self.b = 2.2 + + def forward(self, x): + return self.a + + class TestModule(nn.Module): + def __init__(self): + super(TestModule, self).__init__() + self.sub1 = SubModule() # aliasing + self.sub2 = SubModule() + + def forward(self, x): + return self.sub2(x) + self.sub1(x) + m = TestModule() + ms = torch.jit.script(m) + ms.eval() + mf = torch._C._freeze_module(ms._c, ["sub1"]) + + # Test that 'sub1' is preserved entirely and 'sub2' is completely folded + self.assertTrue(mf.hasattr('sub1')) + self.assertTrue(mf.sub1.hasattr('a')) + self.assertTrue(mf.sub1.hasattr('b')) + self.assertFalse(mf.hasattr('sub2')) + input = torch.randn(2, 2) + output_s = ms.forward(input) + output_f = mf.forward(input) + self.assertEqual(output_s, output_f) + + def test_freeze_module_with_preserve_sub_module_and_mutation(self): + class SubModule(nn.Module): + def __init__(self): + super(SubModule, self).__init__() + self.a = torch.tensor([1.1]) + self.b = 2.2 + + def forward(self, x): + self.a[0] = 3.3 + return self.a + + class TestModule(nn.Module): + def __init__(self): + super(TestModule, self).__init__() + self.sub1 = SubModule() # aliasing + self.sub2 = SubModule() + + def forward(self, x): + return self.sub2(x) + self.sub1(x) + m = TestModule() + ms = torch.jit.script(m) + ms.eval() + mf = torch._C._freeze_module(ms._c, ["sub1"]) + + # Test that be both sub1 and sub1 are preserved and 'b' is preserved + # even if it is not used. To fulfill user request to preserve 'sub1' + self.assertTrue(mf.hasattr('sub1')) + self.assertTrue(mf.sub1.hasattr('a')) + self.assertTrue(mf.sub1.hasattr('b')) + self.assertTrue(mf.hasattr('sub2')) + self.assertTrue(mf.sub2.hasattr('a')) + self.assertTrue(mf.sub2.hasattr('b')) + input = torch.randn(2, 2) + output_s = ms.forward(input) + output_f = mf.forward(input) + self.assertEqual(output_s, output_f) + + def test_freeze_module_with_helperfunction(self): class SubModule(nn.Module): def __init__(self): @@ -554,7 +627,7 @@ def _forward(self, x): self.assertFalse(mf.hasattr('sub')) self.assertFalse(mf.hasattr('a')) self.assertTrue(mf.hasattr('b')) - with self.assertRaisesRegex(RuntimeError, "TestModule does not have a field with name '_forward'"): + with self.assertRaisesRegex(AttributeError, "TestModule does not have a field with name '_forward'"): mf._forward(x) def test_freeze_module_with_inplace_mutable(self): @@ -963,10 +1036,78 @@ def forward(self, x): model = torch.jit.script(Net()) model.train() - - with self.assertRaisesRegex(RuntimeError, 'Freezing module in training mode is not yet supported'): - mTrain_freezed = torch._C._freeze_module(model._c) - + mTrain_freezed = torch._C._freeze_module(model._c) + # verify mTrain_freezed looks exactly as: + # module { + # attributes { + # conv1 = ... + # conv2 = ... + # dropout1 = ... + # dropout2 = ... + # fc1 = ... + # fc2 = ... + # } + # ... + # submodules { + # module conv1 { + # attributes { + # weight = ... + # bias = ... + # } + # ... + # } + # module conv2 { + # attributes { + # weight = ... + # bias = ... + # } + # ... + # } + # module dropout1 { + # attributes { + # training = ... + # } + # ... + # } + # module dropout2 { + # attributes { + # training = ... + # } + # ... + # } + # module fc1 { + # attributes { + # weight = ... + # bias = ... + # } + # ... + # } + # module fc2 { + # attributes { + # weight = ... + # bias = ... + # } + # ... + # } + self.assertFalse(mTrain_freezed.hasattr('training')) + self.assertTrue(mTrain_freezed.hasattr('conv1')) + self.assertFalse(mTrain_freezed.conv1.hasattr('training')) + self.assertTrue(mTrain_freezed.conv1.hasattr('weight')) + self.assertTrue(mTrain_freezed.conv1.hasattr('bias')) + self.assertTrue(mTrain_freezed.hasattr('conv2')) + self.assertFalse(mTrain_freezed.conv2.hasattr('training')) + self.assertTrue(mTrain_freezed.conv2.hasattr('weight')) + self.assertTrue(mTrain_freezed.conv2.hasattr('bias')) + self.assertTrue(mTrain_freezed.hasattr('dropout1')) + self.assertTrue(mTrain_freezed.dropout1.hasattr('training')) + self.assertTrue(mTrain_freezed.hasattr('dropout2')) + self.assertTrue(mTrain_freezed.dropout2.hasattr('training')) + self.assertTrue(mTrain_freezed.hasattr('fc1')) + self.assertTrue(mTrain_freezed.fc1.hasattr('weight')) + self.assertTrue(mTrain_freezed.fc1.hasattr('bias')) + self.assertTrue(mTrain_freezed.hasattr('fc2')) + self.assertTrue(mTrain_freezed.fc2.hasattr('weight')) + self.assertTrue(mTrain_freezed.fc2.hasattr('bias')) model.eval() mEval_freezed = torch._C._freeze_module(model._c) self.assertFalse(mEval_freezed.hasattr('conv1')) @@ -976,7 +1117,7 @@ def forward(self, x): self.assertFalse(mEval_freezed.hasattr('fc1')) self.assertFalse(mEval_freezed.hasattr('dropout2')) self.assertFalse(mEval_freezed.hasattr('fc2')) - with self.assertRaisesRegex(RuntimeError, "does not have a field with name 'state_dict'"): + with self.assertRaisesRegex(AttributeError, "does not have a field with name 'state_dict'"): print(mEval_freezed.state_dict()) buffer = io.BytesIO() torch.jit.save(mEval_freezed, buffer) @@ -984,6 +1125,14 @@ def forward(self, x): m = torch.jit.load(buffer) FileCheck().check_not('GetAttr[name=') \ .run(m._c._get_method('forward').graph) + m2 = torch._C._freeze_module(model._c, preserveParameters=True) + self.assertTrue(m2.hasattr('conv1')) + self.assertTrue(m2.hasattr('conv2')) + self.assertFalse(m2.hasattr('dropout1')) + self.assertFalse(m2.hasattr('training')) + self.assertTrue(m2.hasattr('fc1')) + self.assertFalse(m2.hasattr('dropout2')) + self.assertTrue(m2.hasattr('fc2')) def test_freeze_module_detach_gradient(self): mod = nn.Conv2d(8, 3, 4, 2, 1) @@ -997,7 +1146,8 @@ def test_freeze_module_detach_gradient(self): inp = torch.ones(1, 8, 32, 32) out1 = fmod.forward(inp) # FIXME: frozen module mutated from outside (original module). - smod.weight[0, 0, 0, 0] += 100.0 + with torch.no_grad(): + smod.weight[0, 0, 0, 0] += 100.0 out2 = fmod.forward(inp) out3 = smod(inp) self.assertNotEqual(out1, out2) @@ -1124,3 +1274,246 @@ def _static_quant(model): # It used to segfault while running frozen module. m_frozen_res = m_frozen(data) self.assertEqual(m_res, m_frozen_res) + + def test_module_getattr_indirection(self): + @torch.jit.script + class ValHolder(object): + def __init__(self, val: int): + self.val: int = val + + class Mod(nn.Module): + def __init__(self): + super(Mod, self).__init__() + self.mod1 = ValHolder(1) + self.mod2 = ValHolder(2) + + def forward(self, cond: bool): + if cond: + mod = self.mod1 + else: + mod = self.mod2 + return mod.val + + mod = Mod() + mod.eval() + frozen_mod = torch.jit.freeze(torch.jit.script(mod)) + mod_eager = Mod() + self.assertEqual(mod_eager(True), frozen_mod(True)) + self.assertEqual(mod_eager(False), frozen_mod(False)) + + def test_freeze_module_with_non_static_module_dict_index(self): + """ + Test that a Module contained a non-static ModuleDict index + cannot be frozen. + """ + @torch.jit.interface + class ModuleInterface(torch.nn.Module): + def forward(self, inp: Any) -> Any: + pass + + class ImplementsInterface(torch.nn.Module): + def forward(self, inp: Any) -> Any: + if isinstance(inp, torch.Tensor): + return torch.max(inp, dim=0) + + return inp + + # Test annotation of submodule. + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.d = torch.nn.ModuleDict({"module": ImplementsInterface()}) + + def forward(self, x: torch.Tensor, key: str) -> Any: + value: ModuleInterface = self.d[key] + return value.forward(x) + + m = torch.jit.script(Mod()) + m.eval() + with self.assertRaisesRegex(RuntimeError, "Freezing modules containing prim::ModuleDictIndex is not supported"): + mf = torch._C._freeze_module(m._c) + + def test_freeze_non_module_class_getattr(self): + class BoxCoder(object): + def __init__(self, bbox_xform_clip): + # type: (float) -> None + self.bbox_xform_clip = bbox_xform_clip + + def decode(self, input): + return input * self.bbox_xform_clip + + class MyModule(torch.nn.Module): + __annotations__ = { + 'box_coder': BoxCoder, + } + + def __init__(self): + super(MyModule, self).__init__() + self.box_coder = BoxCoder(50.) + + def forward(self, input): + return self.box_coder.decode(input) + + model = MyModule() + model.eval() + script_model = torch.jit.freeze(torch.jit.script(model)) + inp = torch.randn([4, 4]) + output_eager = model(inp) + self.assertEqual(model(inp), script_model(inp)) + FileCheck().check_not("GetAttr").run(script_model.graph) + +class TestFrozenOptimizations(JitTestCase): + def setUp(self): + self.default_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.double) + + def tearDown(self): + torch.set_default_dtype(self.default_dtype) + + def test_conv_bn_folding(self): + conv_bias = [True, False] + module_pairs = [(nn.Conv1d, nn.BatchNorm1d), (nn.Conv2d, nn.BatchNorm2d), (nn.Conv3d, nn.BatchNorm3d)] + use_tracing = [True, False] + + for use_bias, modules, tracing in product(conv_bias, module_pairs, use_tracing): + class ConvBN(torch.nn.Module): + def __init__(self, in_channels, out_channels, **kwargs): + super(ConvBN, self).__init__() + self.conv = modules[0](in_channels, out_channels, bias=use_bias, **kwargs) + self.bn = modules[1](out_channels, eps=0.001) + + def forward(self, x): + x = self.conv(x) + return self.bn(x) + + mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).eval() + inps = [4, 3, 4] + if modules[0] == nn.Conv2d: + inps.append(inps[-1]) + if modules[0] == nn.Conv3d: + inps.append(inps[-1]) + inps.append(inps[-1]) + + inp = torch.rand(inps) + + if tracing: + scripted_mod = torch.jit.trace(mod_eager, (inp)) + else: + scripted_mod = torch.jit.script(mod_eager) + + self.run_pass("inline", scripted_mod.graph) + self.run_pass("peephole", scripted_mod.graph) + self.run_pass("constant_propagation", scripted_mod.graph) + + FileCheck().check("conv").check("batch").run(scripted_mod.graph) + # successfully no-ops with non-const inputs + self.run_pass("fold_frozen_conv_bn", scripted_mod.graph) + FileCheck().check("conv").check("aten::batch_norm").run(scripted_mod.graph) + + scripted_mod = torch.jit.freeze(scripted_mod) + self.run_pass("fold_frozen_conv_bn", scripted_mod.graph) + FileCheck().check("conv").check_not("aten::batch_norm").run(scripted_mod.graph) + + self.assertEqual(mod_eager(inp), scripted_mod(inp)) + self.assertEqual(mod_eager(inp), scripted_mod(inp)) + + + def test_conv_add_folding(self): + + @torch.no_grad() + def test_conv_fusion(use_bias, module, tracing, op, scalar, add_tensor, expect_success): + + class ConvOp(torch.nn.Module): + __constants__ = ['use_scalar'] + + def __init__(self, in_channels, out_channels, tensor=None, **kwargs): + super(ConvOp, self).__init__() + self.conv = module(in_channels, out_channels, bias=use_bias, **kwargs) + self.conv2 = module(in_channels, out_channels, bias=use_bias, **kwargs) + self.use_scalar = scalar + tensor_size = [1 for _ in range(self.conv.weight.ndim)] + tensor_size[1] = self.conv.weight.size(0) + self.tensor = add_tensor if add_tensor is not None else torch.rand(tensor_size) + self.op = op + + def forward(self, x): + x = self.conv(x) + if self.use_scalar: + return self.op(x, 2.) + else: + return self.op(x, self.tensor) + + mod_eager = ConvOp(3, 32, kernel_size=3, stride=2).eval() + + inps = [4, 3, 4] + if module == nn.Conv2d: + inps.append(inps[-1]) + if module == nn.Conv3d: + inps.append(inps[-1]) + inps.append(inps[-1]) + + + inp = torch.rand(inps) + + if tracing: + scripted_mod = torch.jit.trace(mod_eager, (inp,)) + else: + scripted_mod = torch.jit.script(mod_eager) + + self.run_pass("inline", scripted_mod.graph) + op_str = "aten::" + op.__name__ + + FileCheck().check("conv").check(op_str).run(scripted_mod.graph) + # successively no-ops with non-const inputs + self.run_pass("fold_frozen_conv_mul_or_div", scripted_mod.graph) + self.run_pass("fold_frozen_conv_add_or_sub", scripted_mod.graph) + FileCheck().check("conv").check(op_str).run(scripted_mod.graph) + scripted_mod = torch.jit.freeze(scripted_mod) + self.run_pass("fold_frozen_conv_mul_or_div", scripted_mod.graph) + self.run_pass("fold_frozen_conv_add_or_sub", scripted_mod.graph) + + if expect_success: + FileCheck().check("conv").check_not(op_str).run(scripted_mod.graph) + else: + FileCheck().check("conv").check(op_str).run(scripted_mod.graph) + + self.assertEqual(mod_eager(inp), scripted_mod(inp)) + self.assertEqual(mod_eager(inp), scripted_mod(inp)) + + conv_bias = [True, False] + modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d] + use_tracing = [False, True] + use_scalar = [False, True] + ops = [torch.add, torch.sub, torch.mul, torch.div] + + for use_bias, module, tracing, pytorch_op, scalar in product(conv_bias, modules, use_tracing, ops, use_scalar): + test_conv_fusion(use_bias, module, tracing, pytorch_op, scalar, add_tensor=None, expect_success=True) + + + for use_bias, pytorch_op in product(conv_bias, ops): + # broadcasting add + test_conv_fusion(use_bias, nn.Conv2d, False, pytorch_op, False, + add_tensor=torch.rand(32, 1, 32), expect_success=False) + + # broadcasting add + test_conv_fusion(use_bias, nn.Conv2d, False, pytorch_op, False, add_tensor=torch.rand(1, 1), expect_success=True) + + # add with different dtype + test_conv_fusion(use_bias, nn.Conv2d, False, pytorch_op, False, + add_tensor=torch.rand(1).to(torch.int), expect_success=False) + + def test_optimize_freeze_module(self): + in_channels, out_channels = 3, 32 + conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True) + bn = torch.nn.BatchNorm2d(out_channels, eps=.001) + mod = torch.nn.Sequential(conv, bn) + # set optimize to False here, by default freezing runs optimize_frozen_module + frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False) + # inspect frozen mod + FileCheck().check("batch_norm").run(frozen_mod.graph) + torch.jit.optimize_frozen_module(frozen_mod) + FileCheck().check_not("batch_norm").run(frozen_mod.graph) + + # optimize_frozen_module should be run + frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval())) + FileCheck().check_not("batch_norm").run(frozen_mod.graph) diff --git a/test/jit/test_fuser_common.py b/test/jit/test_fuser_common.py new file mode 100644 index 0000000000000..960ce640b57fb --- /dev/null +++ b/test/jit/test_fuser_common.py @@ -0,0 +1,17 @@ +import torch +from torch.testing._internal.jit_utils import JitTestCase + +class TestFuserCommon(JitTestCase): + def test_autodiff_fallback(self): + for rq in [True, False]: + @torch.jit.script + def fn(x): + return torch.max(x**2.0, x**3.0) + + x = torch.randn(5, requires_grad=not rq) + # cause optimization to be created + for i in range(5): + fn(x) + # test fallback when optimization is not applicable + y = fn(torch.randn(5, requires_grad=rq)) + self.assertEqual(y.requires_grad, rq) diff --git a/test/jit/test_hash.py b/test/jit/test_hash.py new file mode 100644 index 0000000000000..13c4761fc2bca --- /dev/null +++ b/test/jit/test_hash.py @@ -0,0 +1,107 @@ +import os +import sys + +import torch + +from typing import Tuple, List + +# Make the helper files in test/ importable +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) +from torch.testing._internal.jit_utils import JitTestCase + +if __name__ == "__main__": + raise RuntimeError("This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_jit.py TESTNAME\n\n" + "instead.") + +class TestHash(JitTestCase): + def test_hash_tuple(self): + def fn(t1: Tuple[int, int], t2: Tuple[int, int]) -> bool: + return hash(t1) == hash(t2) + + self.checkScript(fn, ((1, 2), (1, 2))) + self.checkScript(fn, ((1, 2), (3, 4))) + self.checkScript(fn, ((1, 2), (2, 1))) + + def test_hash_tuple_nested_unhashable_type(self): + # Tuples may contain unhashable types like `list`, check that we error + # properly in that case. + @torch.jit.script + def fn_unhashable(t1: Tuple[int, List[int]]): + return hash(t1) + + with self.assertRaisesRegex(RuntimeError, "unhashable"): + fn_unhashable((1, [1])) + + def test_hash_tensor(self): + """Tensors should hash by identity""" + def fn(t1, t2): + return hash(t1) == hash(t2) + + tensor1 = torch.tensor(1) + tensor1_clone = torch.tensor(1) + tensor2 = torch.tensor(2) + + self.checkScript(fn, (tensor1, tensor1)) + self.checkScript(fn, (tensor1, tensor1_clone)) + self.checkScript(fn, (tensor1, tensor2)) + + def test_hash_none(self): + def fn(): + n1 = None + n2 = None + return hash(n1) == hash(n2) + + self.checkScript(fn, ()) + + def test_hash_bool(self): + def fn(b1: bool, b2: bool): + return hash(b1) == hash(b2) + + self.checkScript(fn, (True, False)) + self.checkScript(fn, (True, True)) + self.checkScript(fn, (False, True)) + self.checkScript(fn, (False, False)) + + def test_hash_float(self): + def fn(f1: float, f2: float): + return hash(f1) == hash(f2) + + self.checkScript(fn, (1.2345, 1.2345)) + self.checkScript(fn, (1.2345, 6.789)) + self.checkScript(fn, (1.2345, float("inf"))) + self.checkScript(fn, (float("inf"), float("inf"))) + self.checkScript(fn, (1.2345, float('nan'))) + self.checkScript(fn, (float("nan"), float("nan"))) + self.checkScript(fn, (float("nan"), float("inf"))) + + def test_hash_int(self): + def fn(i1: int, i2: int): + return hash(i1) == hash(i2) + + self.checkScript(fn, (123, 456)) + self.checkScript(fn, (123, 123)) + self.checkScript(fn, (123, -123)) + self.checkScript(fn, (-123, -123)) + self.checkScript(fn, (123, 0)) + + def test_hash_string(self): + def fn(s1: str, s2: str): + return hash(s1) == hash(s2) + + self.checkScript(fn, ("foo", "foo")) + self.checkScript(fn, ("foo", "bar")) + self.checkScript(fn, ("foo", "")) + + def test_hash_device(self): + def fn(d1: torch.device, d2: torch.device): + return hash(d1) == hash(d2) + + gpu0 = torch.device('cuda:0') + gpu1 = torch.device('cuda:1') + cpu = torch.device('cpu') + self.checkScript(fn, (gpu0, gpu0)) + self.checkScript(fn, (gpu0, gpu1)) + self.checkScript(fn, (gpu0, cpu)) + self.checkScript(fn, (cpu, cpu)) diff --git a/test/jit/test_isinstance.py b/test/jit/test_isinstance.py new file mode 100644 index 0000000000000..2e93c280c1036 --- /dev/null +++ b/test/jit/test_isinstance.py @@ -0,0 +1,273 @@ +import os +import sys + +import torch +from typing import List, Any, Dict, Tuple, Optional + +# Make the helper files in test/ importable +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) +from torch.testing._internal.jit_utils import JitTestCase + +if __name__ == "__main__": + raise RuntimeError( + "This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_jit.py TESTNAME\n\n" + "instead." + ) + +# Tests for torch.jit.isinstance +class TestIsinstance(JitTestCase): + def test_int(self): + def int_test(x: Any): + assert torch.jit.isinstance(x, int) + assert not torch.jit.isinstance(x, float) + + x = 1 + self.checkScript(int_test, (x,)) + + def test_float(self): + def float_test(x: Any): + assert torch.jit.isinstance(x, float) + assert not torch.jit.isinstance(x, int) + + x = 1.0 + self.checkScript(float_test, (x,)) + + def test_bool(self): + def bool_test(x: Any): + assert torch.jit.isinstance(x, bool) + assert not torch.jit.isinstance(x, float) + + x = False + self.checkScript(bool_test, (x,)) + + def test_list(self): + def list_str_test(x: Any): + assert torch.jit.isinstance(x, List[str]) + assert not torch.jit.isinstance(x, List[int]) + assert not torch.jit.isinstance(x, Tuple[int]) + + x = ["1", "2", "3"] + self.checkScript(list_str_test, (x,)) + + def test_list_tensor(self): + def list_tensor_test(x: Any): + assert torch.jit.isinstance(x, List[torch.Tensor]) + assert not torch.jit.isinstance(x, Tuple[int]) + + x = [torch.Tensor([1]), torch.Tensor([2]), torch.Tensor([3])] + self.checkScript(list_tensor_test, (x,)) + + def test_dict(self): + def dict_str_int_test(x: Any): + assert torch.jit.isinstance(x, Dict[str, int]) + assert not torch.jit.isinstance(x, Dict[int, str]) + assert not torch.jit.isinstance(x, Dict[str, str]) + + x = {"a": 1, "b": 2} + self.checkScript(dict_str_int_test, (x,)) + + def test_dict_tensor(self): + def dict_int_tensor_test(x: Any): + assert torch.jit.isinstance(x, Dict[int, torch.Tensor]) + + x = {2: torch.tensor([2])} + self.checkScript(dict_int_tensor_test, (x,)) + + def test_tuple(self): + def tuple_test(x: Any): + assert torch.jit.isinstance(x, Tuple[str, int, str]) + assert not torch.jit.isinstance(x, Tuple[int, str, str]) + assert not torch.jit.isinstance(x, Tuple[str]) + + x = ("a", 1, "b") + self.checkScript(tuple_test, (x,)) + + def test_tuple_tensor(self): + def tuple_tensor_test(x: Any): + assert torch.jit.isinstance(x, Tuple[torch.Tensor, torch.Tensor]) + + x = (torch.tensor([1]), torch.tensor([[2], [3]])) + self.checkScript(tuple_tensor_test, (x,)) + + def test_optional(self): + def optional_test(x: Any): + assert torch.jit.isinstance(x, Optional[torch.Tensor]) + assert not torch.jit.isinstance(x, Optional[str]) + + x = torch.ones(3, 3) + self.checkScript(optional_test, (x,)) + + def test_optional_none(self): + def optional_test_none(x: Any): + assert torch.jit.isinstance(x, Optional[torch.Tensor]) + # assert torch.jit.isinstance(x, Optional[str]) + # TODO: above line in eager will evaluate to True while in + # the TS interpreter will evaluate to False as the + # first torch.jit.isinstance refines the 'None' type + + x = None + self.checkScript(optional_test_none, (x,)) + + def test_list_nested(self): + def list_nested(x: Any): + assert torch.jit.isinstance(x, List[Dict[str, int]]) + assert not torch.jit.isinstance(x, List[List[str]]) + + x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}] + self.checkScript(list_nested, (x,)) + + def test_dict_nested(self): + def dict_nested(x: Any): + assert torch.jit.isinstance(x, Dict[str, Tuple[str, str, str]]) + assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]]) + + x = {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")} + self.checkScript(dict_nested, (x,)) + + def test_tuple_nested(self): + def tuple_nested(x: Any): + assert torch.jit.isinstance( + x, Tuple[Dict[str, Tuple[str, str, str]], List[bool], Optional[str]] + ) + assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]]) + assert not torch.jit.isinstance(x, Tuple[str]) + assert not torch.jit.isinstance(x, Tuple[List[bool], List[str], List[int]]) + + x = ( + {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")}, + [True, False, True], + None, + ) + self.checkScript(tuple_nested, (x,)) + + def test_optional_nested(self): + def optional_nested(x: Any): + assert torch.jit.isinstance(x, Optional[List[str]]) + + x = ["a", "b", "c"] + self.checkScript(optional_nested, (x,)) + + def test_list_tensor_type_true(self): + def list_tensor_type_true(x: Any): + assert torch.jit.isinstance(x, List[torch.Tensor]) + + x = [torch.rand(3, 3), torch.rand(4, 3)] + self.checkScript(list_tensor_type_true, (x,)) + + def test_tensor_type_false(self): + def list_tensor_type_false(x: Any): + assert not torch.jit.isinstance(x, List[torch.Tensor]) + + x = [1, 2, 3] + self.checkScript(list_tensor_type_false, (x,)) + + def test_in_if(self): + def list_in_if(x: Any): + if torch.jit.isinstance(x, List[int]): + assert True + if torch.jit.isinstance(x, List[str]): + assert not True + + x = [1, 2, 3] + self.checkScript(list_in_if, (x,)) + + def test_if_else(self): + def list_in_if_else(x: Any): + if torch.jit.isinstance(x, Tuple[str, str, str]): + assert True + else: + assert not True + + x = ("a", "b", "c") + self.checkScript(list_in_if_else, (x,)) + + def test_in_while_loop(self): + def list_in_while_loop(x: Any): + count = 0 + while torch.jit.isinstance(x, List[Dict[str, int]]) and count <= 0: + count = count + 1 + assert count == 1 + + x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}] + self.checkScript(list_in_while_loop, (x,)) + + def test_type_refinement(self): + def type_refinement(obj: Any): + hit = False + if torch.jit.isinstance(obj, List[torch.Tensor]): + hit = not hit + for el in obj: + # perform some tensor operation + y = el.clamp(0, 0.5) + if torch.jit.isinstance(obj, Dict[str, str]): + hit = not hit + str_cat = "" + for val in obj.values(): + str_cat = str_cat + val + assert "111222" == str_cat + assert hit + + x = [torch.rand(3, 3), torch.rand(4, 3)] + self.checkScript(type_refinement, (x,)) + x = {"1": "111", "2": "222"} + self.checkScript(type_refinement, (x,)) + + def test_list_no_contained_type(self): + def list_no_contained_type(x: Any): + assert torch.jit.isinstance(x, List) + + x = ["1", "2", "3"] + + err_msg = "Attempted to use List without a contained type. " \ + r"Please add a contained type, e.g. List\[int\]" + + with self.assertRaisesRegex(RuntimeError, err_msg,): + torch.jit.script(list_no_contained_type) + with self.assertRaisesRegex(RuntimeError, err_msg,): + list_no_contained_type(x) + + + + def test_tuple_no_contained_type(self): + def tuple_no_contained_type(x: Any): + assert torch.jit.isinstance(x, Tuple) + + x = ("1", "2", "3") + + err_msg = "Attempted to use Tuple without a contained type. " \ + r"Please add a contained type, e.g. Tuple\[int\]" + + with self.assertRaisesRegex(RuntimeError, err_msg,): + torch.jit.script(tuple_no_contained_type) + with self.assertRaisesRegex(RuntimeError, err_msg,): + tuple_no_contained_type(x) + + def test_optional_no_contained_type(self): + def optional_no_contained_type(x: Any): + assert torch.jit.isinstance(x, Optional) + + x = ("1", "2", "3") + + err_msg = "Attempted to use Optional without a contained type. " \ + r"Please add a contained type, e.g. Optional\[int\]" + + with self.assertRaisesRegex(RuntimeError, err_msg,): + torch.jit.script(optional_no_contained_type) + with self.assertRaisesRegex(RuntimeError, err_msg,): + optional_no_contained_type(x) + + def test_dict_no_contained_type(self): + def dict_no_contained_type(x: Any): + assert torch.jit.isinstance(x, Dict) + + x = {"a": "aa"} + + err_msg = "Attempted to use Dict without contained types. " \ + r"Please add contained type, e.g. Dict\[int, int\]" + + with self.assertRaisesRegex(RuntimeError, err_msg,): + torch.jit.script(dict_no_contained_type) + with self.assertRaisesRegex(RuntimeError, err_msg,): + dict_no_contained_type(x) diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index 8d0f74349b3be..9d6be7806628d 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -1,13 +1,13 @@ import os import sys import inspect -from typing import Dict, List, Optional, Tuple, Any +from typing import Any, Dict, List, Optional, Tuple from textwrap import dedent from collections import OrderedDict +from torch import Tensor import torch from torch.testing import FileCheck -from torch import Tensor # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) @@ -21,22 +21,19 @@ class TestList(JitTestCase): def test_in_check(self): - def int_in(x): - # type: (List[int]) -> bool + def int_in(x: List[int]) -> bool: return 2 in x self.checkScript(int_in, ([1, 2, 3],)) self.checkScript(int_in, ([1, 3, 3],)) - def float_in(x): - # type: (List[float]) -> bool + def float_in(x: List[float]) -> bool: return 2. in x self.checkScript(float_in, ([1., 2., 3.],)) self.checkScript(float_in, ([1., 3., 3.],)) - def str_in(x): - # type: (List[str]) -> bool + def str_in(x: List[str]) -> bool: return 'hi' in x self.checkScript(str_in, (['not', 'here'],)) @@ -46,21 +43,21 @@ def str_in(x): def test_list_literal(self): def reassign(): x = [1] - if True: + if 1 == 1: x = [2, 3] return self.checkScript(reassign, (), optimize=False) def reassign_arity_change(): x = [1] - if True: + if 1 == 1: x = [1, 2, 3] return self.checkScript(reassign_arity_change, (), optimize=False) def reassign_from_empty_literal(): x = [] - if True: + if 1 == 1: x = [1, 2, 3] return with self.assertRaisesRegex(RuntimeError, r"previously has type List\[Tensor\]"): @@ -68,20 +65,20 @@ def reassign_from_empty_literal(): def reassign_from_empty_builtin(): x = torch.jit.annotate(List[int], []) - if True: + if 1 == 1: x = [1, 2, 3] y = torch.jit.annotate(List[float], []) - if True: + if 1 == 1: y = [1.0, 2.0, 3.0] z = [] - if True: + if 1 == 1: z = [torch.randn([1])] return self.checkScript(reassign_from_empty_builtin, (), optimize=False) def reassign_bad_type(): x = [1] - if True: + if 1 == 1: x = [1.0] return with self.assertRaisesRegex(RuntimeError, "previously has type"): @@ -89,9 +86,9 @@ def reassign_bad_type(): def reassign_nested(): x = torch.jit.annotate(List[int], []) - if True: + if 1 == 1: x = [1, 2, 3] - if True: + if 1 == 1: x = [1.0] return with self.assertRaisesRegex(RuntimeError, "previously has type"): @@ -101,8 +98,7 @@ def test_del(self): def inputs(): return [1, 2, 3, 4] - def fn(x): - # type: (List[int]) -> List[int] + def fn(x: List[int]) -> List[int]: del x[1] return x @@ -115,8 +111,7 @@ def fn(x): self.assertEqual(torch.jit.script(fn)(inputs()), python_out) @torch.jit.script - def fn2(x): - # type: (List[int]) -> List[int] + def fn2(x: List[int]) -> List[int]: del x[100] return x @@ -125,29 +120,43 @@ def fn2(x): with self.assertRaisesRegex(RuntimeError, "deletion at a single index"): @torch.jit.script - def fn(x): - # type: (List[int]) -> List[int] + def fn(x: List[int]) -> List[int]: del x[1:3] return x + def test_list_keyword(self): + def foo(): + return list([1, 2, 3]), list(("a", "b")), list(range(5)), list("abcdefg") # noqa: C410 + + self.checkScript(foo, ()) + + def foo2(): + x: List[int] = list() + x.append(1) + return x, + + self.checkScript(foo2, ()) + + def foo3(): + return list(list("abc")) + + self.checkScript(foo3, ()) + FileCheck().check_count("aten::list", 2, exactly=True).run(torch.jit.script(foo3).graph) + def test_min_bool_list(self): - def jit_min_list(a, b): - # type: (List[bool], List[bool]) -> List[bool] + def jit_min_list(a: List[bool], b: List[bool]) -> List[bool]: return min(a, b) self.checkScript(jit_min_list, ([True, False], [False, True])) def test_min_max_list(self): - def jit_min_list(a, b): - # type: (List[int], List[int]) -> List[int] + def jit_min_list(a: List[int], b: List[int]) -> List[int]: return min(a, b) - def jit_min_list_float(a, b): - # type: (List[float], List[float]) -> List[float] + def jit_min_list_float(a: List[float], b: List[float]) -> List[float]: return min(a, b) - def jit_min_list_bool(a, b): - # type: (List[bool], List[bool]) -> List[bool] + def jit_min_list_bool(a: List[bool], b: List[bool]) -> List[bool]: return min(a, b) def run_tests(func, a, b): @@ -168,16 +177,13 @@ def run_tests(func, a, b): [False, True], [False, False, True], [False, False, False]] run_tests(jit_min_list_bool, args_left_bool, args_right_bool) - def jit_max_list(a, b): - # type: (List[int], List[int]) -> List[int] + def jit_max_list(a: List[int], b: List[int]) -> List[int]: return max(a, b) - def jit_max_list_float(a, b): - # type: (List[float], List[float]) -> List[float] + def jit_max_list_float(a: List[float], b: List[float]) -> List[float]: return max(a, b) - def jit_max_list_bool(a, b): - # type: (List[bool], List[bool]) -> List[bool] + def jit_max_list_bool(a: List[bool], b: List[bool]) -> List[bool]: return max(a, b) args_left_int = [[1, 8, 8], [8, 1, 1], [], [1], [], [1, 2]] @@ -347,8 +353,7 @@ def func(): t2 = scope['func']() self.assertEqual(t1, t2) - def test_fail(x): - # type: (List[Tensor]) -> List[Tensor] + def test_fail(x: List[Tensor]) -> List[Tensor]: x.sort() return x @@ -408,6 +413,43 @@ def test_over_slice(): return a[3:10] == [3, 4] self.checkScript(test_backward_slice, ()) + def test_slice_index(self): + a = torch.tensor( + [ + [[1, 11], [2, 22]], + [[3, 33], [4, 44]], + [[5, 55], [6, 66]], + ] + ) + + def test_index_slice1(x): + x = x[:, :, [0, 1]] + return x + self.checkScript(test_index_slice1, (a,)) + + def test_index_slice2(x): + x = x[[2, 1, 0], :, :] + return x + self.checkScript(test_index_slice2, (a,)) + + def test_index_slice3(x): + x = x[[0, 1], :, [1]] + return x + self.checkScript(test_index_slice3, (a,)) + + def test_index_slice_empty_list(x): + empty_list: List[int] = [] + x = x[empty_list, :, :] + return x + self.checkScript(test_index_slice_empty_list, (a,)) + + def test_index_slice_out_of_bounds_index(x): + x = x[[4], :, :] + return x + with self.assertRaisesRegex(RuntimeError, "index 4 is out of bounds for dimension 0 with size 3"): + self.checkScript(test_index_slice_out_of_bounds_index, (a,)) + + def test_mutable_list_append(self): def test_append(): a = [0, 1] @@ -417,8 +459,7 @@ def test_append(): self.checkScript(test_append, ()) def test_comprehensions_basic(self): - def comp(l): - # type: (List[int]) -> List[int] + def comp(l: List[int]) -> List[int]: n = [x * 3 for x in l] return n @@ -427,8 +468,7 @@ def comp(l): self.checkScript(comp, ([1, 2, 3],)) def test_comprehensions_basic_float(self): - def comp(l): - # type: (List[float]) -> List[float] + def comp(l: List[float]) -> List[float]: n = [x * 3 for x in l] return n @@ -437,8 +477,7 @@ def comp(l): def test_comprehensions_two_comps(self): @torch.jit.script - def comp(l1, l2): - # type: (List[int], List[int]) -> List[int] + def comp(l1: List[int], l2: List[int]) -> List[int]: n = [x * 3 for x in l1] n2 = [x + 2 for x in l2] @@ -447,8 +486,7 @@ def comp(l1, l2): self.assertEqual(comp([1, 2, 3], [4, 5]), [3, 6, 9, 6, 7]) def test_comprehension_out_type_not_in_type(self): - def list_cast(): - # type: () -> int + def list_cast() -> int: li = [int(i) for i in [torch.tensor(0), torch.tensor(1), torch.tensor(2)]] return li[0] + li[1] + li[2] @@ -458,15 +496,13 @@ def test_comprehension_iterable(self): def test_func(fn, inputs): self.assertEqual(fn(*inputs), torch.jit.script(fn)(*inputs)) - def foo(names, results): - # type: (List[int], List[int]) -> List[Tuple[int, int]] + def foo(names: List[int], results: List[int]) -> List[Tuple[int, int]]: return [(k + 5, v - 2) for k, v in zip(names, results)] test_func(foo, ([1, 2, 4], [4, 7, 9])) test_func(foo, ([5], [4, 7, 9])) - def fn(x): - # type: (int) -> List[int] + def fn(x: int) -> List[int]: return [i for i in range(x)] # noqa: C416 test_func(fn, (9,)) @@ -498,7 +534,7 @@ def test_append_2(): def test_mutable_list_append_if(self): def test_append_if(): a = [1] - if True: + if 1 == 1: a.append(4) return a == [1, 4] self.checkScript(test_append_if, ()) @@ -506,7 +542,7 @@ def test_append_if(): def test_mutable_list_append_if_else(self): def test_append_if_else(): a = [1] - if False: + if 1 == 2: a.append(4) else: a.append(10) @@ -546,8 +582,7 @@ def test_nested_loop(): def test_mutable_list_function_inline(self): @torch.jit.script - def bar(y): - # type: (List[int]) -> None + def bar(y: List[int]) -> None: y.append(4) @torch.jit.script @@ -833,8 +868,7 @@ def test_list_remove2(): def test_extend_list_mutable(self): @torch.jit.script - def extend_list(a, b): - # type: (List[Tensor], List[Tensor]) -> List[Tensor] + def extend_list(a: List[Tensor], b: List[Tensor]) -> List[Tensor]: a.extend(b) return a @@ -845,8 +879,7 @@ def extend_list(a, b): def test_extend_list_immutable(self): @torch.jit.script - def extend_list(a, b): - # type: (List[int], List[int]) -> List[int] + def extend_list(a: List[int], b: List[int]) -> List[int]: a.extend(b) return a @@ -857,8 +890,7 @@ def extend_list(a, b): def test_copy_list_mutable(self): @torch.jit.script - def copy_list(a): - # type: (List[Tensor]) -> List[Tensor] + def copy_list(a: List[Tensor]) -> List[Tensor]: return a.copy() for l in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]: @@ -866,36 +898,29 @@ def copy_list(a): def test_copy_list_immutable(self): @torch.jit.script - def copy_list(a): - # type: (List[int]) -> List[int] + def copy_list(a: List[int]) -> List[int]: return a.copy() for l in [[], [1], [1, 2, 3]]: self.assertEqual(copy_list(l), l) def test_min_max_single_list(self): - def min_intlist(li): - # type: (List[int]) -> int + def min_intlist(li: List[int]) -> int: return min(li) - def max_intlist(li): - # type: (List[int]) -> int + def max_intlist(li: List[int]) -> int: return max(li) - def min_boollist(li): - # type: (List[bool]) -> bool + def min_boollist(li: List[bool]) -> bool: return min(li) - def max_boollist(li): - # type: (List[bool]) -> bool + def max_boollist(li: List[bool]) -> bool: return max(li) - def min_floatlist(li): - # type: (List[float]) -> float + def min_floatlist(li: List[float]) -> float: return min(li) - def max_floatlist(li): - # type: (List[float]) -> float + def max_floatlist(li: List[float]) -> float: return max(li) @@ -911,11 +936,11 @@ def check_list(fn, li): check_list(min_intlist, int_list) check_list(max_intlist, int_list) - bool_li = list(map(lambda x: bool(x), int_list)) + bool_li = [bool(x) for x in int_list] check_list(min_boollist, bool_li) check_list(max_boollist, bool_li) - float_li = list(map(lambda x: float(x), int_list)) + float_li = [float(x) for x in int_list] check_list(min_floatlist, float_li) check_list(max_floatlist, float_li) @@ -925,23 +950,19 @@ def test_to_list(self): """ Boolean dtype unit tests. """ - def to_list_bool_0D(x): - # type: (torch.Tensor) -> bool + def to_list_bool_0D(x: torch.Tensor) -> bool: li = torch.jit.annotate(bool, x.tolist()) return li - def to_list_bool_1D(x): - # type: (torch.Tensor) -> List[bool] + def to_list_bool_1D(x: torch.Tensor) -> List[bool]: li = torch.jit.annotate(List[bool], x.tolist()) return li - def to_list_bool_2D(x): - # type: (torch.Tensor) -> List[List[bool]] + def to_list_bool_2D(x: torch.Tensor) -> List[List[bool]]: li = torch.jit.annotate(List[List[bool]], x.tolist()) return li - def to_list_bool_3D(x): - # type: (torch.Tensor) -> List[List[List[bool]]] + def to_list_bool_3D(x: torch.Tensor) -> List[List[List[bool]]]: li = torch.jit.annotate(List[List[List[bool]]], x.tolist()) return li @@ -966,23 +987,19 @@ def to_list_bool_3D(x): """ Int dtype unit tests. """ - def to_list_int_0D(x): - # type: (torch.Tensor) -> int + def to_list_int_0D(x: torch.Tensor) -> int: li = torch.jit.annotate(int, x.tolist()) return li - def to_list_int_1D(x): - # type: (torch.Tensor) -> List[int] + def to_list_int_1D(x: torch.Tensor) -> List[int]: li = torch.jit.annotate(List[int], x.tolist()) return li - def to_list_int_2D(x): - # type: (torch.Tensor) -> List[List[int]] + def to_list_int_2D(x: torch.Tensor) -> List[List[int]]: li = torch.jit.annotate(List[List[int]], x.tolist()) return li - def to_list_int_3D(x): - # type: (torch.Tensor) -> List[List[List[int]]] + def to_list_int_3D(x: torch.Tensor) -> List[List[List[int]]]: li = torch.jit.annotate(List[List[List[int]]], x.tolist()) return li @@ -1003,23 +1020,19 @@ def to_list_int_3D(x): """ Float dtype unit tests. """ - def to_list_float_0D(x): - # type: (torch.Tensor) -> float + def to_list_float_0D(x: torch.Tensor) -> float: li = torch.jit.annotate(float, x.tolist()) return li - def to_list_float_1D(x): - # type: (torch.Tensor) -> List[float] + def to_list_float_1D(x: torch.Tensor) -> List[float]: li = torch.jit.annotate(List[float], x.tolist()) return li - def to_list_float_2D(x): - # type: (torch.Tensor) -> List[List[float]] + def to_list_float_2D(x: torch.Tensor) -> List[List[float]]: li = torch.jit.annotate(List[List[float]], x.tolist()) return li - def to_list_float_3D(x): - # type: (torch.Tensor) -> List[List[List[float]]] + def to_list_float_3D(x: torch.Tensor) -> List[List[List[float]]]: li = torch.jit.annotate(List[List[List[float]]], x.tolist()) return li @@ -1044,28 +1057,23 @@ def to_list_float_3D(x): - type annotation with the wrong dimension - type annotation with scalar type that doesn't match the input scalar type """ - def to_list_missing_type_annotation(x): - # type: (torch.Tensor) -> List[float] + def to_list_missing_type_annotation(x: torch.Tensor) -> List[float]: li = x.tolist() return li - def to_list_incorrect_type_annotation(x): - # type: (torch.Tensor) -> List[float] + def to_list_incorrect_type_annotation(x: torch.Tensor) -> List[float]: li = torch.jit.annotate(float, x.tolist()) return li - def to_list_unsupported_type_annotation(x): - # type: (torch.Tensor) -> List[float] + def to_list_unsupported_type_annotation(x: torch.Tensor) -> List[float]: li = torch.jit.annotate(List[str], x.tolist()) return li - def to_list_type_annotation_wrong_dim(x): - # type: (torch.Tensor) -> List[List[float]] + def to_list_type_annotation_wrong_dim(x: torch.Tensor) -> List[List[float]]: li = torch.jit.annotate(List[List[float]], x.tolist()) return li - def to_list_type_annotation_incorrect_scalar_type(x): - # type: (torch.Tensor) -> List[float] + def to_list_type_annotation_incorrect_scalar_type(x: torch.Tensor) -> List[float]: li = torch.jit.annotate(List[float], x.tolist()) return li @@ -1109,18 +1117,15 @@ def test_to_list_gpu(self): if not torch.cuda.is_available() or torch.cuda.device_count() == 0: self.skipTest("CUDA is not available") - def to_list_bool_1D(x): - # type: (torch.Tensor) -> List[bool] + def to_list_bool_1D(x: torch.Tensor) -> List[bool]: li = torch.jit.annotate(List[bool], x.tolist()) return li - def to_list_int_1D(x): - # type: (torch.Tensor) -> List[int] + def to_list_int_1D(x: torch.Tensor) -> List[int]: li = torch.jit.annotate(List[int], x.tolist()) return li - def to_list_float_1D(x): - # type: (torch.Tensor) -> List[float] + def to_list_float_1D(x: torch.Tensor) -> List[float]: li = torch.jit.annotate(List[float], x.tolist()) return li @@ -1132,8 +1137,7 @@ def to_list_float_1D(x): 5, dtype=torch.double).cuda(),)) def test_no_element_type_annotation(self): - def fn_with_comment(x): - # type: (torch.Tensor) -> List + def fn_with_comment(x: torch.Tensor) -> List: a: List = x.tolist() return a @@ -1175,8 +1179,7 @@ def test_del(self): def inputs(): return {'hi': 2, 'bye': 3} - def fn(x): - # type: (Dict[str, int]) -> Dict[str, int] + def fn(x: Dict[str, int]) -> Dict[str, int]: del x['hi'] return x @@ -1192,8 +1195,7 @@ def fn(x): def test_keys(self): @torch.jit.script - def keys(x): - # type: (Dict[str, Tensor]) -> List[str] + def keys(x: Dict[str, Tensor]) -> List[str]: return list(x.keys()) self.assertEqual(set(keys(self.dict())), set(self.dict().keys())) @@ -1208,30 +1210,26 @@ def specialized_list(): def test_values(self): @torch.jit.script - def values(x): - # type: (Dict[str, Tensor]) -> List[Tensor] + def values(x: Dict[str, Tensor]) -> List[Tensor]: return list(x.values()) the_dict = self.dict() self.assertEqual(set(values(the_dict)), set(the_dict.values())) def test_len(self): - def length(x): - # type: (Dict[str, Tensor]) -> int + def length(x: Dict[str, Tensor]) -> int: return len(x) self.checkScript(length, (self.dict(),)) def test_copy(self): - def func(x): - # type: (Dict[str, Tensor]) -> Dict[str, Tensor] + def func(x: Dict[str, Tensor]) -> Dict[str, Tensor]: return x.copy() self.checkScript(func, (self.dict(),)) def test_items(self): - def func(x): - # type: (Dict[str, Tensor]) -> List[Tuple[str, Tensor]] + def func(x: Dict[str, Tensor]) -> List[Tuple[str, Tensor]]: return x.items() # The value returned by Python is in arbitrary order, so we can't use @@ -1246,8 +1244,7 @@ def func(x): self.assertTrue(item in script_out) def test_pop(self): - def pop(x, key): - # type: (Dict[str, Tensor], str) -> Tuple[Tensor, Dict[str, Tensor]] + def pop(x: Dict[str, Tensor], key: str) -> Tuple[Tensor, Dict[str, Tensor]]: return x.pop(key), x # checkScript doesn't copy the inputs, so we can't use it since this mutates @@ -1263,16 +1260,14 @@ def tester(fn, *args): torch.jit.script(pop)(self.dict(), 'x') - def default_pop(x, key, default): - # type: (Dict[str, Tensor], str, Tensor) -> Tuple[Tensor, Dict[str, Tensor]] + def default_pop(x: Dict[str, Tensor], key: str, default: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]: return x.pop(key, default), x tester(default_pop, 'a', torch.randn(2, 2)) tester(default_pop, 'x', torch.randn(2, 2)) def test_setdefault(self): - def setdefault(x, key, default): - # type: (Dict[str, Tensor], str, Tensor) -> Dict[str, Tensor] + def setdefault(x: Dict[str, Tensor], key: str, default: Tensor) -> Dict[str, Tensor]: x.setdefault(key, default) return x @@ -1280,17 +1275,24 @@ def setdefault(x, key, default): self.checkScript(setdefault, (self.dict(), 'nonexistant', torch.randn(2, 2))) def test_update(self): - def update(a, b): - # type: (Dict[str, Tensor], Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]] + def update(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]: a.update(b) return a, b self.checkScript(update, (self.dict(), self.dict())) self.checkScript(update, (self.dict(), self.dict2())) + def test_update_existing_key(self): + def foo() -> Dict[str, int]: + a: Dict[str, int] = {} + for i in range(3): + a.update({'a': i}) + return a + + self.checkScript(foo, ()) + def test_aug_assign(self): - def aug_assign_dict_tensor(a): - # type: (Dict[str, Tensor]) -> Dict[str, Tensor] + def aug_assign_dict_tensor(a: Dict[str, Tensor]) -> Dict[str, Tensor]: a['a'] += 1 a['b'] -= 12 a['c'] *= 122 @@ -1298,8 +1300,7 @@ def aug_assign_dict_tensor(a): a['c'] %= 2 return a - def aug_assign_dict_prim(a): - # type: (Dict[str, float]) -> Dict[str, float] + def aug_assign_dict_prim(a: Dict[str, float]) -> Dict[str, float]: a['a'] += 3.4 a['b'] -= 2.4 a['c'] *= 3.0 @@ -1312,8 +1313,7 @@ def aug_assign_dict_prim(a): def test_popitem(self): @torch.jit.script - def popitem(x): - # type: (Dict[str, Tensor]) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]] + def popitem(x: Dict[str, Tensor]) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]]: item = x.popitem() return item, x @@ -1331,65 +1331,56 @@ def popitem(x): self.assertTrue(isinstance(script_out[0][1], torch.Tensor)) def test_clear(self): - def clear(x): - # type: (Dict[str, Tensor]) -> Dict[str, Tensor] + def clear(x: Dict[str, Tensor]) -> Dict[str, Tensor]: x.clear() return x self.checkScript(clear, (self.dict(),)) def test_get(self): - def get(x, key): - # type: (Dict[str, Tensor], str) -> Optional[Tensor] + def get(x: Dict[str, Tensor], key: str) -> Optional[Tensor]: return x.get(key) self.checkScript(get, (self.dict(), 'a')) self.checkScript(get, (self.dict(), "doesn't exist")) - def get_default(x, key): - # type: (Dict[str, Tensor], str) -> Optional[Tensor] + def get_default(x: Dict[str, Tensor], key: str) -> Optional[Tensor]: return x.get(key, torch.randn(2, 2)) self.checkScript(get, (self.dict(), 'a')) self.checkScript(get, (self.dict(), "doesn't exist")) def test_get_boolkey(self): - def get(x, key): - # type: (Dict[bool, int], bool) -> Optional[int] + def get(x: Dict[bool, int], key: bool) -> Optional[int]: return x.get(key) self.checkScript(get, (self.dict_bool(), True)) self.checkScript(get, (self.dict_bool(), False)) - def get_default(x, key): - # type: (Dict[bool, int], bool) -> int + def get_default(x: Dict[bool, int], key: bool) -> int: return x.get(key, 42) self.checkScript(get_default, (self.dict_bool(), True)) self.checkScript(get_default, (self.dict_bool(), False)) def test_basic(self): - def simple(x): - # type: (Dict[str, int]) -> Dict[str, int] + def simple(x: Dict[str, int]) -> Dict[str, int]: return x self.checkScript(simple, ({'item': 20, 'other_item': 120},)) - def index(x): - # type: (Dict[str, int]) -> int + def index(x: Dict[str, int]) -> int: return x['item'] self.checkScript(index, ({'item': 20, 'other_item': 120},)) - def type_default(): - # type: () -> Dict[str, Tensor] + def type_default() -> Dict[str, Tensor]: return {} self.checkScript(type_default, ()) @torch.jit.script - def missing_index(x): - # type: (Dict[str, int]) -> int + def missing_index(x: Dict[str, int]) -> int: return x['dne'] with self.assertRaisesRegex(RuntimeError, "KeyError"): @@ -1411,16 +1402,14 @@ def literal3(): ''')) self.assertEqual({10: 1.2, 11: 1.3}, cu.literal3()) - def list_of_dicts(): - # type: () -> List[Dict[str, Tensor]] + def list_of_dicts() -> List[Dict[str, Tensor]]: return [{'word': torch.ones(2) + 3}, {'other word': torch.ones(1) + 2}] self.checkScript(list_of_dicts, ()) def test_mutability(self): @torch.jit.script - def fn(): - # type: () -> Dict[str, int] + def fn() -> Dict[str, int]: a = torch.jit.annotate(Dict[str, int], {}) a['ok'] = 10 return a @@ -1430,14 +1419,12 @@ def fn(): def test_key_type(self): with self.assertRaisesRegex(RuntimeError, "but instead found type"): @torch.jit.script - def fn(a): - # type: (Dict[str, int]) -> int + def fn(a: Dict[str, int]) -> int: return a[None] def test_loop(self): @torch.jit.script - def fn(x): - # type: (int) -> Dict[str, int] + def fn(x: int) -> Dict[str, int]: a = torch.jit.annotate(Dict[str, int], {}) for i in range(x): a['ok'] = i @@ -1456,16 +1443,14 @@ def fn(x, y): self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3))) def test_membership(self): - def fn(x, y): - # type: (Dict[int, int], int) -> int + def fn(x: Dict[int, int], y: int) -> int: return x.get(y, 3) d = {1: 2, 3: 4} self.checkScript(fn, (d, 3)) self.checkScript(fn, (d, 2)) - def optional(x, y): - # type: (Dict[int, int], int) -> bool + def optional(x: Dict[int, int], y: int) -> bool: res = x.get(y) return res is None @@ -1474,18 +1459,15 @@ def optional(x, y): with self.assertRaisesRegex(RuntimeError, "is actually of type Optional"): @torch.jit.script - def bad_types(x, y): - # type: (Dict[int, int], int) -> int + def bad_types(x: Dict[int, int], y: int) -> int: return x.get(y) # noqa: T484 def test_dict_to_python(self): @torch.jit.ignore - def python_lookup(my_dict, keys): - # type: (Dict[str, int], List[str]) -> List[int] + def python_lookup(my_dict: Dict[str, int], keys: List[str]) -> List[int]: return [my_dict[k] for k in keys] - def fn(my_dict, keys): - # type: (Dict[str, int], List[str]) -> List[int] + def fn(my_dict: Dict[str, int], keys: List[str]) -> List[int]: return python_lookup(my_dict, keys) a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2} @@ -1537,8 +1519,7 @@ def test_type_annotation_missing_contained_type(self): key and value types produces an error. """ # This function uses a type comment. - def fn_with_comment(input): - # type: (Dict) -> Any + def fn_with_comment(input: Dict) -> Any: return input # This function uses Python3 style type annotations. diff --git a/test/jit/test_module_containers.py b/test/jit/test_module_containers.py index b53bf10a70c28..e261124bedb55 100644 --- a/test/jit/test_module_containers.py +++ b/test/jit/test_module_containers.py @@ -1,7 +1,7 @@ import os import sys -from typing import List +from typing import Any, List, Tuple from collections import OrderedDict import torch import torch.nn as nn @@ -428,3 +428,64 @@ def forward(self, inputs): m = MyModule() self.checkModule(m, [torch.randn(2, 2)]) + + def test_typed_module_dict(self): + """ + Test that a type annotation can be provided for a ModuleDict that allows + non-static indexing. + """ + @torch.jit.interface + class ModuleInterface(torch.nn.Module): + def forward(self, inp: Any) -> Any: + pass + + class ImplementsInterface(torch.nn.Module): + def forward(self, inp: Any) -> Any: + if isinstance(inp, torch.Tensor): + return torch.max(inp, dim=0) + + return inp + + class DoesNotImplementInterface(torch.nn.Module): + def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.max(inp, dim=0) + + # Test annotation of submodule. + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.d = torch.nn.ModuleDict({"module": ImplementsInterface()}) + + def forward(self, x: torch.Tensor, key: str) -> Any: + value: ModuleInterface = self.d[key] + return value.forward(x) + + m = Mod() + self.checkModule(m, (torch.randn(2, 2), "module")) + + # Test annotation of self. + class ModDict(torch.nn.ModuleDict): + def __init__(self): + super().__init__({"module": ImplementsInterface()}) + + def forward(self, x: torch.Tensor, key: str) -> Any: + submodule: ModuleInterface = self[key] + return submodule.forward(x) + + m = ModDict() + self.checkModule(m, (torch.randn(2, 2), "module")) + + # Test error message thrown when annotated attribute does not comply with the + # annotation. + class ModWithWrongAnnotation(torch.nn.ModuleDict): + def __init__(self): + super().__init__() + self.d = torch.nn.ModuleDict({"module": DoesNotImplementInterface()}) + + def forward(self, x: torch.Tensor, key: str) -> Any: + submodule: ModuleInterface = self.d[key] + return submodule.forward(x) + + with self.assertRaisesRegex(RuntimeError, r"Attribute module is not of annotated type"): + torch.jit.script(ModWithWrongAnnotation()) diff --git a/test/jit/test_module_interface.py b/test/jit/test_module_interface.py index f06dafbc1ba20..d0626b1068b44 100644 --- a/test/jit/test_module_interface.py +++ b/test/jit/test_module_interface.py @@ -1,11 +1,12 @@ # flake8: noqa # TODO: enable linting check for this file -from typing import List +from typing import List, Any import torch import torch.nn as nn import os import sys +from torch import Tensor from torch.testing._internal.jit_utils import JitTestCase # Make the helper files in test/ importable @@ -22,36 +23,30 @@ class OrigModule(nn.Module): def __init__(self): super(OrigModule, self).__init__() - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 + inp2 + 1 - def two(self, input): - # type: (Tensor) -> Tensor + def two(self, input: Tensor) -> Tensor: return input + 2 - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return input + self.one(input, input) + 1 class NewModule(nn.Module): def __init__(self): super(NewModule, self).__init__() - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 * inp2 + 1 - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.one(input, input + 1) class TestModuleInterface(JitTestCase): def test_not_submodule_interface_call(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass class TestNotModuleInterfaceCall(nn.Module): @@ -61,8 +56,7 @@ def __init__(self): super(TestNotModuleInterfaceCall, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod.two(input) with self.assertRaisesRegex(RuntimeError, "Tried to access nonexistent attribute"): @@ -72,64 +66,51 @@ def test_module_interface(self): global OneTwoModule, OneTwoClass @torch.jit.interface class OneTwoModule(nn.Module): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: pass - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: pass - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: pass @torch.jit.interface class OneTwoClass(object): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: pass - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: pass class FooMod(nn.Module): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: return x + y - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: return 2 * x - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: return self.one(self.two(x), x) class BarMod(nn.Module): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: return x * y - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: return 2 / x - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: return self.two(self.one(x, x)) @torch.jit.export - def forward2(self, x): - # type: (Tensor) -> Tensor + def forward2(self, x: Tensor) -> Tensor: return self.two(self.one(x, x)) + 1 def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor): return mod_list[0].forward(x) + mod_list[1].forward(x) - def use_class_interface(mod_list, x): - # type: (List[OneTwoClass], Tensor) -> Tensor + def use_class_interface(mod_list: List[OneTwoClass], x: Tensor) -> Tensor: return mod_list[0].two(x) + mod_list[1].one(x, x) scripted_foo_mod = torch.jit.script(FooMod()) @@ -139,48 +120,66 @@ def use_class_interface(mod_list, x): self.checkScript(use_class_interface, ([scripted_foo_mod, scripted_bar_mod], torch.rand(3, 4),)) - def call_module_interface_on_other_method(mod_interface, x): - # type: (OneTwoModule, Tensor) -> Tensor + def call_module_interface_on_other_method(mod_interface: OneTwoModule, x: Tensor) -> Tensor: return mod_interface.forward2(x) # ensure error out when we call the module on the method other than the interface specified. with self.assertRaisesRegex(RuntimeError, "Tried to access nonexistent attribute or method"): self.checkScript(call_module_interface_on_other_method, (scripted_bar_mod, torch.rand(3, 4),)) + def test_module_doc_string(self): + @torch.jit.interface + class TestInterface(nn.Module): + def one(self, inp1, inp2): + # type: (Tensor, Tensor) -> Tensor + pass + def forward(self, input): + # type: (Tensor) -> Tensor + r"""stuff 1""" + r"""stuff 2""" + pass + r"""stuff 3""" + + class TestModule(nn.Module): + proxy_mod : TestInterface + + def __init__(self): + super(TestModule, self).__init__() + self.proxy_mod = OrigModule() + + def forward(self, input): + # type: (Tensor) -> Tensor + return self.proxy_mod.forward(input) + + input = torch.randn(3, 4) + self.checkModule(TestModule(), (input,)) def test_module_interface_subtype(self): global OneTwoModule @torch.jit.interface class OneTwoModule(nn.Module): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: pass - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: pass - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: pass @torch.jit.script - def as_module_interface(x): - # type: (OneTwoModule) -> OneTwoModule + def as_module_interface(x: OneTwoModule) -> OneTwoModule: return x @torch.jit.script class Foo(object): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: Tensor, y: Tensor) -> Tensor: return x + y - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: Tensor) -> Tensor: return 2 * x - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: return self.one(self.two(x), x) # check class object is not a subtype of module interface @@ -188,12 +187,10 @@ def forward(self, x): as_module_interface(Foo()) class WrongMod(nn.Module): - def two(self, x): - # type: (int) -> int + def two(self, x: int) -> int: return 2 * x - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: return x + torch.randn(3, self.two(3)) scripted_wrong_mod = torch.jit.script(WrongMod()) @@ -202,23 +199,58 @@ def forward(self, x): with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"): as_module_interface(scripted_wrong_mod) + # Check that interface implementations can be contravariant in argument types and covariant in return type. + global TensorToAny + @torch.jit.interface + class TensorToAny(nn.Module): + def forward(self, input: torch.Tensor) -> Any: + pass + + @torch.jit.script + def as_tensor_to_any(x: TensorToAny) -> TensorToAny: + return x + + global AnyToAny + @torch.jit.interface + class AnyToAny(nn.Module): + def forward(self, input: Any) -> Any: + pass + + @torch.jit.script + def as_any_to_any(x: AnyToAny) -> AnyToAny: + return x + + class TensorToAnyImplA(nn.Module): + def forward(self, input: Any) -> Any: + return input + + class TensorToAnyImplB(nn.Module): + def forward(self, input: Any) -> torch.Tensor: + return torch.tensor([1]) + + class AnyToAnyImpl(nn.Module): + def forward(self, input: Any) -> torch.Tensor: + return torch.tensor([1]) + + as_tensor_to_any(torch.jit.script(TensorToAnyImplA())) + as_tensor_to_any(torch.jit.script(TensorToAnyImplB())) + as_any_to_any(torch.jit.script(AnyToAnyImpl())) + + def test_module_interface_inheritance(self): with self.assertRaisesRegex(RuntimeError, "does not support inheritance yet. Please directly"): @torch.jit.interface class InheritMod(nn.ReLU): - def three(self, x): - # type: (Tensor) -> Tensor + def three(self, x: Tensor) -> Tensor: return 3 * x def test_module_swap(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: pass class TestModule(nn.Module): @@ -228,8 +260,7 @@ def __init__(self): super(TestModule, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod.forward(input) scripted_mod = torch.jit.script(TestModule()) @@ -247,20 +278,17 @@ def forward(self, input): def test_module_swap_wrong_module(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: pass class NewModuleWrong(nn.Module): def __init__(self): super(NewModuleWrong, self).__init__() - def forward(self, input): - # type: (int) -> int + def forward(self, input: int) -> int: return input + 1 class TestModule(nn.Module): @@ -270,8 +298,7 @@ def __init__(self): super(TestModule, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod.forward(input) scripted_mod = torch.jit.script(TestModule()) @@ -282,12 +309,10 @@ def forward(self, input): def test_module_swap_no_lazy_compile(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: pass class TestModule(nn.Module): @@ -297,20 +322,17 @@ def __init__(self): super(TestModule, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod.forward(input) class NewModuleMethodNotLazyCompile(nn.Module): def __init__(self): super(NewModuleMethodNotLazyCompile, self).__init__() - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 * inp2 + 1 - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return input + 1 scripted_mod = torch.jit.script(TestModule()) @@ -324,12 +346,10 @@ def __init__(self): super(NewModuleMethodManualExport, self).__init__() @torch.jit.export - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 * inp2 + 1 - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return input + 1 scripted_mod.proxy_mod = torch.jit.script(NewModuleMethodManualExport()) @@ -343,8 +363,7 @@ def __init__(self): super(TestNoModuleInterface, self).__init__() self.proxy_mod = OrigModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod(input) scripted_no_module_interface = torch.jit.script(TestNoModuleInterface()) @@ -359,12 +378,10 @@ def forward(self, input): def test_script_module_as_interface_swap(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: pass class OrigScriptModule(torch.jit.ScriptModule): @@ -372,13 +389,11 @@ def __init__(self): super(OrigScriptModule, self).__init__() @torch.jit.script_method - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 + inp2 + 1 @torch.jit.script_method - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return input + self.one(input, input) + 1 class NewScriptModule(torch.jit.ScriptModule): @@ -386,13 +401,11 @@ def __init__(self): super(NewScriptModule, self).__init__() @torch.jit.script_method - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: return inp1 * inp2 + 1 @torch.jit.script_method - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.one(input, input + 1) class TestNNModuleWithScriptModule(nn.Module): @@ -402,8 +415,7 @@ def __init__(self): super(TestNNModuleWithScriptModule, self).__init__() self.proxy_mod = OrigScriptModule() - def forward(self, input): - # type: (Tensor) -> Tensor + def forward(self, input: Tensor) -> Tensor: return self.proxy_mod.forward(input) input = torch.randn(3, 4) @@ -434,8 +446,7 @@ def forward(self, x): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x): - # type: (Tensor) -> int + def forward(self, x: Tensor) -> int: pass class TestModule(torch.nn.Module): @@ -482,8 +493,7 @@ def forward(self, x): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x): - # type: (Tensor) -> int + def forward(self, x: Tensor) -> int: pass class TestModule(torch.nn.Module): @@ -526,8 +536,7 @@ def forward(self, x): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: pass class TestModule(torch.nn.Module): @@ -572,8 +581,7 @@ def forward(self, x): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: pass class TestModule(torch.nn.Module): @@ -615,8 +623,7 @@ def forward(self, x): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: Tensor) -> Tensor: pass class TestModule(torch.nn.Module): @@ -650,8 +657,7 @@ def forward(self, x): def test_module_apis_interface(self): @torch.jit.interface class ModuleInterface(nn.Module): - def one(self, inp1, inp2): - # type: (Tensor, Tensor) -> Tensor + def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: pass class TestModule(nn.Module): diff --git a/test/jit/test_peephole.py b/test/jit/test_peephole.py new file mode 100644 index 0000000000000..a41e1136fd3ba --- /dev/null +++ b/test/jit/test_peephole.py @@ -0,0 +1,219 @@ +import torch +from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA, _inline_everything +from torch import nn +from torch.testing import FileCheck + +import unittest + +if __name__ == '__main__': + raise RuntimeError("This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_jit.py TESTNAME\n\n" + "instead.") + +class TestPeephole(JitTestCase): + def test_peephole_with_writes(self): + def test_write(x): + s = 0 + s += x + s += x + return s + + self.checkScript(test_write, (torch.ones(4, 4),)) + + def test_peephole_with_non_output_writes(self): + @torch.jit.ignore + def nomnom(x): + pass + + def test_write(x): + t = torch.ones_like(x) + z = x.clone() + y = z + 0 + z.add_(t) + # this makes sure z isn't blasted out of existence + # because it isn't returned or used in a side-effectful + # way + nomnom(z) + return y + y + + a = torch.ones(4, 4) + j = self.checkScript(test_write, (a,)) + + def test_peephole_no_output_aliasing(self): + def test_peephole(x): + y = x + 0 + return x, y + + a = torch.ones(4, 4) + j = self.checkScript(test_peephole, (a,)) + r1, r2 = j(a) + self.assertNotEqual(r1.data_ptr(), r2.data_ptr()) + + def test_peephole(self): + a = torch.tensor([0.4]) + b = torch.tensor([0.7]) + c = torch.tensor([0], dtype=torch.int32) + + def f(x, y): + return x.type_as(y) + + tf = torch.jit.trace(f, (a, b)) + FileCheck().check("type_as").run(str(tf.graph)) + self.run_pass('peephole', tf.graph) + FileCheck().check_not("type_as").run(str(tf.graph)) + tf2 = torch.jit.trace(f, (a, c)) + s = str(tf2.graph) + self.run_pass('peephole', tf2.graph) + self.assertEqual(s, str(s)) + + def test_peephole_dynamic(self): + def f(x, y): + return x.type_as(y) + + fn = torch.jit.script(f) + s = str(fn.graph) + torch._C._jit_pass_peephole(fn.graph) + self.assertEqual(s, str(fn.graph)) + + def test_peephole_list_ops(self): + @torch.jit.script + def foo(x, y, z): + return len([x, y, z]) + + self.run_pass('peephole', foo.graph) + FileCheck().check("value=3").check_next("return").run(foo.graph) + + @torch.jit.script + def foo(x, y, z): + li = [x, y, z] + for i in range(len(x)): + li.append(x) + return len([x, y, z]) + + self.run_pass('peephole', foo.graph) + FileCheck().check_not("aten::len").run(foo.graph) + + @torch.jit.script + def foo(x, y, z): + li = [x, y, z] + return li[1], li[-2] + + FileCheck().check("aten::__getitem__").run(foo.graph) + self.run_pass('peephole', foo.graph) + FileCheck().check_not("aten::__getitem__").run(foo.graph) + + @torch.jit.script + def foo(x, y, z): + li = [x, y, z] + return li[-7] + + self.run_pass('peephole', foo.graph) + FileCheck().check("aten::__getitem__").run(foo.graph) + + @torch.jit.script + def foo(x, y, z): + li = [x, y, z] + for i in range(len(x)): + li.append(x) + return li[-2] + + self.run_pass('peephole', foo.graph) + FileCheck().check("aten::__getitem__").run(foo.graph) + + @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA") + def test_peephole_cuda(self): + a = torch.tensor([0.4], device='cpu') + b = torch.tensor([0.7], device='cuda') + c = torch.tensor([0.7], device='cuda') + + def f(x, y): + return x.type_as(y) + + trace = torch.jit.trace(f, (a, c)) + s = str(trace.graph) + self.run_pass('peephole', trace.graph) + self.assertEqual(s, str(trace.graph)) + trace = torch.jit.trace(f, (b, c)) + self.run_pass('peephole', trace.graph) + self.run_pass('dce', trace.graph) + FileCheck().check_not("type_as").run(str(trace.graph)) + + @_inline_everything + def test_peephole_type_refinements(self): + def refine(x): + # type: (Optional[Tensor]) -> Tensor + return x if x is not None else torch.tensor(3) + + @torch.jit.script + def test(): + return refine(torch.tensor(4)) + + FileCheck().check("prim::unchecked_cast").run(test.graph) + self.run_pass('peephole', test.graph) + FileCheck().check_not("prim::unchecked_cast").run(test.graph) + + # refinement not optimzied out + def is_int_tensor(x): + scalar = x.item() + if isinstance(scalar, int): + return scalar + 3 + else: + return 8 + + self.checkScript(is_int_tensor, (torch.tensor(2),)) + self.checkScript(is_int_tensor, (torch.tensor(2.5),)) + graph = torch.jit.script(is_int_tensor).graph + self.run_pass('peephole', graph) + FileCheck().check("prim::unchecked_cast").run(graph) + + def test_short_circuit_optimization(self): + @torch.jit.script + def const_expressions(x): + # type: (int) -> Tuple[bool, bool] + return x == 1 and False, x == 1 or True + self.run_pass('constant_propagation', const_expressions.graph) + FileCheck().check_not("prim::If").check_not("aten::eq").run(const_expressions.graph) + self.assertEqual(const_expressions(1), (False, True)) + + @torch.jit.script + def redundant_expressions(x): + # type: (int) -> Tuple[bool, bool] + return x == 1 and True, x == 1 or False + + self.run_pass('peephole', redundant_expressions.graph) + self.assertEqual(redundant_expressions(1), (True, True)) + self.assertEqual(redundant_expressions(0), (False, False)) + # and True / or False are removed from graph + FileCheck().check("aten::eq").check_not("prim::If").run(redundant_expressions.graph) + + def test_conv_dim_folding(self): + modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d] + for mod in modules: + class ConvDim(torch.nn.Module): + def __init__(self): + super(ConvDim, self).__init__() + self.conv = mod(3, 32, kernel_size=3, stride=2, bias=False) + + def forward(self, x): + x = self.conv(x) + return x.dim() + + conv_dim = torch.jit.script(ConvDim()) + self.run_pass("inline", conv_dim.graph) + self.run_pass("peephole", conv_dim.graph) + FileCheck().check_not("conv").check_not("dim").run(conv_dim.graph) + + class ConvDimMutate(torch.nn.Module): + def __init__(self): + super(ConvDimMutate, self).__init__() + self.conv = mod(3, 32, kernel_size=3, stride=2, bias=False) + + def forward(self, x): + x = self.conv(x) + x.resize_([4, 4]) + return x.dim() + + conv_dim = torch.jit.script(ConvDimMutate()) + self.run_pass("inline", conv_dim.graph) + self.run_pass("peephole", conv_dim.graph) + FileCheck().check("conv").check("dim").run(conv_dim.graph) diff --git a/test/jit/test_profiler.py b/test/jit/test_profiler.py index 55604f5ff6bf3..80072e3825182 100644 --- a/test/jit/test_profiler.py +++ b/test/jit/test_profiler.py @@ -25,7 +25,8 @@ def setUp(self): self.default_dtype = torch.get_default_dtype() self.old_reduction_enabled = torch._C._jit_set_texpr_reductions_enabled(True) torch.set_default_dtype(torch.double) - + self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining() + torch._C._debug_set_fusion_group_inlining(False) def tearDown(self): torch._C._jit_set_profiling_executor(self.prev_exec) @@ -35,6 +36,7 @@ def tearDown(self): torch._C._jit_override_can_fuse_on_cpu(self.can_fuse_on_cpu) torch.set_default_dtype(self.default_dtype) torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled) + torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining) def test_tensor_type_not_determined_by_inputs(self): @torch.jit.script @@ -115,7 +117,7 @@ def test_fuse(a, b): g = torch.jit.last_executed_optimized_graph() # Types should remain specialized for typecheck outputs & fusion outputs - FileCheck().check("Double(").check_same("prim::TypeCheck").check("Double").check_same("TensorExpr").run(g) + FileCheck().check("Double(").check_same("prim::TypeCheck").check_same("\n").check("Double").check_same("TensorExpr").run(g) # other outputs should not be specialized FileCheck().check("Tensor = prim::If").run(g) @@ -137,6 +139,25 @@ def foo(a, b): self.assertEqual(len(list(g.findAllNodes("prim::TypeCheck"))), 2) FileCheck().check("TensorExpr").check("aten::add_").check("TensorExpr").run(g) + def test_use_not_profiled(self): + def foo(t1, t2, t3, t4, t: float): + h = t1 + t2 + t3 + t4 + if t > 0.5: + # Putting a use of t1 in a never-executed conditional prevents + return t1 + 1 + return h + + t = torch.rand(8, dtype=torch.float) + + foo_script = torch.jit.script(foo) + for _ in range(torch._C._jit_get_num_profiled_runs() + 1): + foo_script(t, t, t, t, 0.1) + + self.assertEqual(foo(t, t, t, t, 0.1), foo_script(t, t, t, t, 0.1)) + g = torch.jit.last_executed_optimized_graph() + # all adds fused + FileCheck().check("graph").check_not("aten::add").check("prim::If").run(g) + def test_not_fusing_scalar_ops(self): @torch.jit.script def foo(x: int, y: int): @@ -193,6 +214,19 @@ def foo(a, b): g = torch.jit.last_executed_optimized_graph() FileCheck().check("fallback_function").check_next("CallFunction").run(g) + def test_tensor_constant(self): + def foo(a, b): + return a + b + torch.tensor([2]) + + x = torch.ones(1, requires_grad=False) + foo_script = torch.jit.script(foo) + foo_script(x, x) + foo_script(x, x) + + self.assertEqual(foo_script(x, x), foo(x, x)) + g = torch.jit.last_executed_optimized_graph() + FileCheck().check_count("aten::add", 2, exactly=True).run(g) + def test_iterative_fusion(self): @torch.jit.script def foo(a, b, c, d): diff --git a/test/jit/test_recursive_script.py b/test/jit/test_recursive_script.py index 82347b8ca9f57..d18c4f6a3dab3 100644 --- a/test/jit/test_recursive_script.py +++ b/test/jit/test_recursive_script.py @@ -1,4 +1,3 @@ -import unittest import os import sys import typing @@ -159,7 +158,6 @@ def a(x): # Make sure that no entries are left over from the previous failure FileCheck().check_count("is being compiled", 2).run(str(e)) - @unittest.skipIf(sys.version_info[:2] < (3, 7), "Class annotations are a thing in > 3.5, need to fix for < 3.7") def test_constants_with_final(self): class M1(torch.nn.Module): x : torch.jit.Final[int] @@ -287,8 +285,7 @@ def forward(self, x): test_module_dir(nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)]))) def test_class_compile(self): - def other_fn(a, b): - # type: (int, Tensor) -> Tensor + def other_fn(a: int, b: Tensor) -> Tensor: return a * b class B(object): @@ -310,8 +307,7 @@ def forward(self, x): self.checkModule(N(), (torch.randn(2, 2),)) def test_error_stack(self): - def d(x): - # type: (int) -> int + def d(x: int) -> int: return x + 10 def c(x): @@ -334,8 +330,7 @@ def a(x): checker.run(str(e)) def test_error_stack_module(self): - def d(x): - # type: (int) -> int + def d(x: int) -> int: return x + 10 def c(x): @@ -497,6 +492,59 @@ def forward(self, x): self.checkModule(M(), (torch.randn(5, 5),)) + def test_prepare_scriptable_basic(self): + class SeluButReluWhenScripted(torch.nn.SELU): + def __prepare_scriptable__(self): + return nn.ReLU() + + t = torch.randn(5, 5) + m = SeluButReluWhenScripted() + sm = torch.jit.script(m) + eager_out = m(t) + script_out = sm(t) + self.assertNotEqual(eager_out, script_out) + + def test_prepare_scriptable_iterable_modules(self): + class SeluButReluWhenScripted(torch.nn.SELU): + def __prepare_scriptable__(self): + return nn.ReLU() + + class M(torch.nn.Module): + def __init__(self): + super(M, self).__init__() + shared = SeluButReluWhenScripted() + self.sequential = nn.Sequential( + SeluButReluWhenScripted(), + SeluButReluWhenScripted(), + nn.Sequential(SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()), + shared, + ) + self.module_list = nn.ModuleList([SeluButReluWhenScripted(), + shared, + SeluButReluWhenScripted()]) + + def forward(self, x): + for mod in self.module_list: + x += mod(x) + x += self.sequential(x) + return x + + t = torch.randn(5, 5) + m = M() + eager_out = m(t.clone()) + sm = torch.jit.script(m) + script_out = sm(t.clone()) + self.assertNotEqual(eager_out, script_out) + + def test_prepare_scriptable_cycle(self): + t = torch.randn(5, 5) + c = torch.nn.Module() + p = torch.nn.Module() + c.__dict__["_p"] = p + p.__dict__["_c"] = c + + sm = torch.jit.script(p) + def test_attributes(self): @torch.jit.script class Inner2(object): @@ -515,8 +563,7 @@ def __init__(self): self.a = 4 self.inner = Inner2() - def __setstate__(self, obj): - # type: (Tuple[int, Inner2]) -> None + def __setstate__(self, obj: Tuple[int, Inner2]) -> None: a, inner = obj self.a = a self.inner = inner diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py index 9f11731d1864c..5136e50144f15 100644 --- a/test/jit/test_save_load.py +++ b/test/jit/test_save_load.py @@ -1,12 +1,14 @@ -import os +from itertools import product as product +from typing import NamedTuple, Optional import io -import sys +import os +import pathlib import random -import torch -from itertools import product as product +import sys + from torch import Tensor from torch.testing._internal.common_utils import TemporaryFileName -from typing import NamedTuple +import torch # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) @@ -359,16 +361,14 @@ def _helper(m, fn): else: fn_result = self._try_fn(fn, a, b) - if not a.is_floating_point(): - # NOTE: Torchscript rewrites the module forward into - # torch.reciprocal(a) * b, but torch.reciprocal is - # implemented for integer dtypes. - self.assertTrue(m_result, Exception) - self.assertTrue('"reciprocal_cpu" not implemented for' in str(m_result)) - elif isinstance(m_result, Exception): - self.assertTrue(fn_result, Exception) - else: + if isinstance(m_result, Exception): + self.assertTrue(isinstance(fn_result, Exception)) + elif fn is torch.div or a.is_floating_point(): self.assertEqual(m_result, fn_result) + else: + # Skip when fn is not torch.div and a is integral because + # historic_div_scalar_int performs floored division + pass if isinstance(b, float): _helper(v3_module_float, historic_div_scalar_float_reciprocal) @@ -682,8 +682,7 @@ def test_different_interfaces(self): """ @torch.jit.interface class MyInterface(object): - def bar(self, x): - # type: (Tensor) -> Tensor + def bar(self, x: Tensor) -> Tensor: pass @torch.jit.script @@ -713,8 +712,7 @@ def forward(self, x): @torch.jit.interface class MyInterface(object): - def not_bar(self, x): - # type: (Tensor) -> Tensor + def not_bar(self, x: Tensor) -> Tensor: pass @torch.jit.script # noqa: F811 @@ -769,8 +767,7 @@ class MyCoolNamedTuple(NamedTuple): @torch.jit.interface class MyInterface(object): - def bar(self, x): - # type: (Tensor) -> Tensor + def bar(self, x: Tensor) -> Tensor: pass @torch.jit.script @@ -811,8 +808,7 @@ def forward(self, x): @torch.jit.interface class MyInterface(object): - def not_bar(self, x): - # type: (Tensor) -> Tensor + def not_bar(self, x: Tensor) -> Tensor: pass @torch.jit.script # noqa F811 @@ -920,3 +916,65 @@ def forward(self, a): with self.assertRaises(RuntimeError): extra_files['bar'] = '' torch.jit.load(buffer, _extra_files=extra_files) + + def test_save_load_using_pathlib(self): + class MyMod(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, a): + return 2 * a + + m = MyMod() + + # Save then load. + with TemporaryFileName() as fname: + path = pathlib.Path(fname) + m.save(path) + m2 = torch.jit.load(path) + + x = torch.tensor([1., 2., 3., 4.]) + self.assertTrue(torch.equal(m(x), m2(x))) + + def test_save_nonexit_file(self): + class Foo(torch.nn.Module): + def forward(self, x): + return 2 * x + + script_module = torch.jit.script(Foo()) + with self.assertRaises(RuntimeError): + script_module.save("NonExist/path/test.pt") + + def test_save_namedtuple_input_only(self): + """ + Even if a NamedTuple is only used as an input argument, saving and + loading should work correctly. + """ + global FooTuple # see [local resolution in python] + + class FooTuple(NamedTuple): + a: int + + class MyModule(torch.nn.Module): + def forward(self, x: FooTuple) -> torch.Tensor: + return torch.tensor(3) + + m_loaded = self.getExportImportCopy(torch.jit.script(MyModule())) + output = m_loaded(FooTuple(a=5)) + self.assertEqual(output, torch.tensor(3)) + + def test_save_namedtuple_output_only(self): + """ + Even if a NamedTuple is only used as an output argument, saving and + loading should work correctly. + """ + global FooTuple # see [local resolution in python] + + class FooTuple(NamedTuple): + a: int + + class MyModule(torch.nn.Module): + def forward(self) -> Optional[FooTuple]: + return None + + m_loaded = self.getExportImportCopy(torch.jit.script(MyModule())) + output = m_loaded() + self.assertEqual(output, None) diff --git a/test/jit/test_string_formatting.py b/test/jit/test_string_formatting.py new file mode 100644 index 0000000000000..2f9f093ebf365 --- /dev/null +++ b/test/jit/test_string_formatting.py @@ -0,0 +1,148 @@ +import os +import sys + +import torch +from typing import List + +# Make the helper files in test/ importable +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) +from torch.testing._internal.jit_utils import JitTestCase + +if __name__ == '__main__': + raise RuntimeError("This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_jit.py TESTNAME\n\n" + "instead.") + +class TestStringFormatting(JitTestCase): + + def test_modulo_operator(self): + def fn(dividend: int, divisor: int) -> int: + return dividend % divisor + self.checkScript(fn, (5, 2)) + + def test_string_interpolation_with_string_placeholder_and_string_variable(self): + def fn(arg1: str): + return "%s in template" % arg1 + self.checkScript(fn, ("foo",)) + + def test_string_interpolation_with_string_placeholder_and_format_string_variable(self): + def fn(arg1: str): + return arg1 % "foo" + self.checkScript(fn, ("%s in template",)) + + def test_string_interpolation_with_double_percent_in_string(self): + def fn(arg1: str): + return "%s in template %%" % arg1 + self.checkScript(fn, ("foo",)) + + def test_string_interpolation_with_percent_in_string(self): + @torch.jit.script + def fn(arg1: str) -> str: + return "%s in template %" % arg1 # noqa: F501 + + with self.assertRaisesRegex(RuntimeError, "Incomplete format specifier"): + fn("foo") + + def test_string_interpolation_with_string_placeholder_and_digit_variable(self): + def fn(arg1: int) -> str: + return "%s in template" % arg1 + self.checkScript(fn, (1,)) + + def test_string_interpolation_with_digit_placeholder_and_digit_variable(self): + def fn(arg1: int) -> str: + return "%d in template" % arg1 + self.checkScript(fn, (1,)) + + def test_string_interpolation_with_alternate_digit_placeholder(self): + def fn(arg1: int) -> str: + return "%i in template" % arg1 + self.checkScript(fn, (1,)) + + def test_string_interpolation_with_digit_placeholder_and_string_variable(self): + @torch.jit.script + def fn(arg1: str) -> str: + return "%d in template" % arg1 + + with self.assertRaisesRegex(RuntimeError, "%d requires a number for formatting, but got String"): + fn("1") + + def test_string_interpolation_with_exponent_placeholder_and_string_variable(self): + @torch.jit.script + def fn(arg1: str) -> str: + return "%e in template" % arg1 + + with self.assertRaisesRegex(RuntimeError, "%e requires a number for formatting, but got String"): + fn("1") + + def test_string_interpolation_with_lowercase_exponent_placeholder_and_digit_variable(self): + def fn(arg1: int) -> str: + return "%e in template" % arg1 + self.checkScript(fn, (1,)) + + def test_string_interpolation_with_capital_exponent_placeholder_and_digit_variable(self): + def fn(arg1: int) -> str: + return "%E in template" % arg1 + self.checkScript(fn, (1,)) + + def test_string_interpolation_with_float_placeholder_and_float_variable(self): + def fn(arg1: float) -> str: + return "%f in template" % arg1 + self.checkScript(fn, (1.0,)) + + def test_string_interpolation_with_float_placeholder_and_digit_variable(self): + def fn(arg1: int) -> str: + return "%f in template" % arg1 + self.checkScript(fn, (1,)) + + def test_string_interpolation_with_char_placeholder_and_char_variable(self): + def fn(arg1: str) -> str: + return "%c in template" % arg1 + self.checkScript(fn, ("a",)) + + def test_string_interpolation_with_char_placeholder_and_digit_variable(self): + def fn(arg1: int) -> str: + return "%c in template" % arg1 + self.checkScript(fn, (97,)) + + def test_string_interpolation_with_char_placeholder_and_true_string_variable(self): + @torch.jit.script + def fn(arg1: str) -> str: + return "%c in template" % arg1 + + with self.assertRaisesRegex(RuntimeError, "%c requires an int or char for formatting, but got String"): + fn("foo") + + def test_string_interpolation_with_multiple_placeholders(self): + def fn(arg1: str, arg2: int, arg3: float) -> str: + return "%s %d %f in template" % (arg1, arg2, arg3) + self.checkScript(fn, ("foo", 1, 1)) + + def test_string_interpolation_with_subscript(self): + def fn(arg1: List[str]) -> str: + return "%s in template" % arg1[0] + self.checkScript(fn, (["foo", "bar"],)) + + def test_string_interpolation_with_too_few_arguments(self): + @torch.jit.script + def fn(arg1: str) -> str: + return "%s %s in template" % arg1 + + with self.assertRaisesRegex(RuntimeError, "Too few arguments for format string"): + fn("foo") + + def test_string_interpolation_with_too_many_arguments(self): + @torch.jit.script + def fn(arg1: str, arg2: str) -> str: + return "%s in template" % (arg1, arg2) # noqa: F507 + + with self.assertRaisesRegex(RuntimeError, "Too many arguments for format string"): + fn("foo", "bar") + + def test_string_interpolation_with_unknown_format_specifier(self): + @torch.jit.script + def fn(arg1: str) -> str: + return "%a in template" % arg1 # noqa: F501 + + with self.assertRaisesRegex(RuntimeError, "The specifier %a is not supported in TorchScript format strings"): + fn("foo") diff --git a/test/jit/test_torchbind.py b/test/jit/test_torchbind.py index a8bea73c984d0..7f43b31fe6ec0 100644 --- a/test/jit/test_torchbind.py +++ b/test/jit/test_torchbind.py @@ -12,7 +12,7 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from torch.testing._internal.jit_utils import JitTestCase -from torch.testing._internal.common_utils import TEST_WITH_ROCM, IS_WINDOWS, IS_SANDCASTLE, IS_MACOS +from torch.testing._internal.common_utils import TEST_WITH_ROCM, IS_WINDOWS, IS_SANDCASTLE, IS_MACOS, IS_FBCODE from torch.testing import FileCheck if __name__ == "__main__": @@ -24,10 +24,14 @@ class TestTorchbind(JitTestCase): def setUp(self): - if TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS: + if IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE: raise unittest.SkipTest("non-portable load_library call used in test") - torch_root = Path(__file__).resolve().parent.parent.parent - p = torch_root / 'build' / 'lib' / 'libtorchbind_test.so' + if TEST_WITH_ROCM: + torch_root = Path(torch.__file__).resolve().parent + p = torch_root / 'lib' / 'libtorchbind_test.so' + else: + torch_root = Path(__file__).resolve().parent.parent.parent + p = torch_root / 'build' / 'lib' / 'libtorchbind_test.so' torch.ops.load_library(str(p)) def test_torchbind(self): @@ -58,6 +62,32 @@ def f(): return ss1.pop() + ss2.pop() test_equality(f, lambda x: x) + # test nn module with prepare_scriptable function + class NonJitableClass(object): + def __init__(self, int1, int2): + self.int1 = int1 + self.int2 = int2 + + def return_vals(self): + return self.int1, self.int2 + + class CustomWrapper(torch.nn.Module): + def __init__(self, foo): + super(CustomWrapper, self).__init__() + self.foo = foo + + def forward(self) -> None: + self.foo.increment(1) + return + + def __prepare_scriptable__(self): + int1, int2 = self.foo.return_vals() + foo = torch.classes._TorchScriptTesting._Foo(int1, int2) + return CustomWrapper(foo) + + foo = CustomWrapper(NonJitableClass(1, 2)) + jit_foo = torch.jit.script(foo) + def test_torchbind_take_as_arg(self): global StackString # see [local resolution in python] StackString = torch.classes._TorchScriptTesting._StackString @@ -139,6 +169,23 @@ def foo(): scripted = torch.jit.script(foo) self.assertEqual(scripted(), "mom") + def test_torchbind_class_attr_recursive(self): + class FooBar(torch.nn.Module): + def __init__(self, foo_model): + super(FooBar, self).__init__() + self.foo_mod = foo_model + + def forward(self) -> int: + return self.foo_mod.info() + + def to_ivalue(self): + torchbind_model = torch.classes._TorchScriptTesting._Foo(self.foo_mod.info(), 1) + return FooBar(torchbind_model) + + inst = FooBar(torch.classes._TorchScriptTesting._Foo(2, 3)) + scripted = torch.jit.script(inst.to_ivalue()) + self.assertEqual(scripted(), 6) + def test_torchbind_class_attribute(self): class FooBar1234(torch.nn.Module): def __init__(self): @@ -219,6 +266,10 @@ def forward(self): traced = torch.jit.trace(TryTracing(), ()) self.assertEqual(torch.zeros(4, 4), traced()) + def test_torchbind_pass_wrong_type(self): + with self.assertRaisesRegex(RuntimeError, 'missing attribute capsule'): + torch.ops._TorchScriptTesting.take_an_instance(torch.rand(3, 4)) + def test_torchbind_tracing_nested(self): class TryTracingNest(torch.nn.Module): def __init__(self): @@ -282,3 +333,19 @@ def test_profiler_custom_op(self): if e.name == '_TorchScriptTesting::take_an_instance': found_event = True self.assertTrue(found_event) + + def test_torchbind_getattr(self): + foo = torch.classes._TorchScriptTesting._StackString(["test"]) + self.assertEqual(None, getattr(foo, 'bar', None)) + + def test_torchbind_attr_exception(self): + foo = torch.classes._TorchScriptTesting._StackString(["test"]) + with self.assertRaisesRegex(AttributeError, 'does not have a field'): + foo.bar + + def test_lambda_as_constructor(self): + obj_no_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, False) + self.assertEqual(obj_no_swap.diff(), 1) + + obj_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, True) + self.assertEqual(obj_swap.diff(), -1) diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index 24db4cfe857ef..f43f9adf476e0 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -15,18 +15,16 @@ sys.path.append(pytorch_test_dir) from torch.testing._internal.common_utils import suppress_warnings, \ skipIfCompiledWithoutNumpy, enable_profiling_mode_for_profiling_tests, \ - IS_SANDCASTLE, IS_WINDOWS + IS_SANDCASTLE, TemporaryFileName from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, \ _tmp_donotuse_dont_inline_everything, _trace, RUN_CUDA, RUN_CUDA_MULTI_GPU from torch.testing._internal.common_cuda import with_tf32_off -from typing import List, Tuple from torch import Tensor # Standard library from collections import namedtuple from itertools import chain -import tempfile -from typing import Dict +from typing import Dict, List, Optional, Tuple import warnings if __name__ == '__main__': @@ -399,6 +397,18 @@ def full_with_shape_like(x): self.assertEqual(ge(y).shape, y.shape) self.assertEqual(ge(x).shape, x.shape) + # Test that the trace of setitem doesn't store shapes as constants + # Fix https://github.com/pytorch/pytorch/issues/43548 + def test_trace_slice_setitem_dynamic_shape(self): + def slice_setitem(x, y): + x[:, 2] = y + 1 + return x + + x = torch.randn(3, 4) + traced = torch.jit.trace(slice_setitem, (x, x[:, 0])) + x = torch.randn(10, 5) + self.assertEqual(traced(x.clone(), x[:, 0]), slice_setitem(x.clone(), x[:, 0])) + # Suppression: we are intentionally slicing a tensor, we don't care that it # will be constantified @suppress_warnings @@ -1203,15 +1213,14 @@ def foo(x): self.run_pass('inline', traced_tensor_size.graph) FileCheck().check("prim::device").run(traced_tensor_size.graph) - @unittest.skipIf(IS_WINDOWS, "temp file name on windows") def test_trace_save(self): def fn(x): return x + 2 def check(func): - with tempfile.NamedTemporaryFile() as f: - func.save(f.name) - loaded = torch.jit.load(f.name) + with TemporaryFileName() as fname: + func.save(fname) + loaded = torch.jit.load(fname) input = torch.randn(2, 2) self.assertEqual(func(input), loaded(input)) @@ -1395,8 +1404,7 @@ def forward(self, x): @_tmp_donotuse_dont_inline_everything def test_trace_optional(self): @torch.jit.script - def test(x): - # type: (Optional[Tensor]) + def test(x: Optional[Tensor]): if x is None: return torch.zeros(1) else: @@ -1526,7 +1534,7 @@ def foo(x): x[i, :] = torch.zeros(4) return x - self.checkTrace(foo, (torch.rand(3, 4),)) + self.checkTrace(foo, (torch.rand(3, 4),), inputs_require_grads=False) def test_trace_checker_inplace_on_view(self): def foo(x): @@ -1834,17 +1842,29 @@ def f(x): with self.assertRaisesRegex(RuntimeError, r"Type 'Tuple\[int\]' cannot be traced"): torch.jit.trace(f, (1,)) + def test_trace_skip_none_submodule(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.submod = torch.nn.Linear(3, 4) + self.submod = None + + def forward(self, inputs): + return inputs + + m = TestModule() + tm = torch.jit.trace(m, torch.tensor(1.)) + self.assertFalse(hasattr(tm, "submod")) + class TestMixTracingScripting(JitTestCase): def test_trace_script(self): @torch.jit.script - def func1(x): - # type: (Tuple[Tensor, Tensor]) -> Tensor + def func1(x: Tuple[Tensor, Tensor]) -> Tensor: return x[0] + x[1] @torch.jit.script - def func2(x): - # type: (List[Tensor]) -> Tensor + def func2(x: List[Tensor]) -> Tensor: return x[0] + x[1] a = torch.randn(5) @@ -1854,8 +1874,7 @@ def func2(x): self.checkTrace(func2, ((a, b),)) @torch.jit.script - def func3(x, method='bilinear', align_corners=True): - # type: (Tensor, str, bool) -> Tensor + def func3(x: Tensor, method: str = 'bilinear', align_corners: bool = True) -> Tensor: hw = x.shape[2:4] return F.interpolate(x, hw, mode=method, align_corners=align_corners) @@ -1863,8 +1882,7 @@ def func3(x, method='bilinear', align_corners=True): self.checkTrace(func3, (inp,)) @torch.jit.script - def func4(x, a): - # type: (Tensor, List[Optional[str]]) -> Tensor + def func4(x: Tensor, a: List[Optional[str]]) -> Tensor: if len(a) == 2: return x + 2 else: @@ -1912,7 +1930,7 @@ def foo(x): @torch.jit.script def bar(x): y = int(foo(x)) - if True: + if 1 == 1: y = 7 return y + 1 @@ -2267,3 +2285,46 @@ def forward(self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tenso traced_module = torch.jit.trace(eager_module, input1) self.assertEqual(traced_module(input1), eager_module(input1)) self.assertEqual(traced_module(input2), eager_module(input2)) + + def test_trace_returning_dict_with_tensor_tuples(self): + """Tracing over a module returning a dictionary whose values are tuples of tensors + should work. + """ + class ReturnsDict(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, k: torch.Tensor, v: torch.Tensor + ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]: + x = 2 * k + y = 3 * v + result = { + "imakey": (x, y) + } + return result + + class ReturnsBadDict(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, k: torch.Tensor, v: torch.Tensor + ) -> Dict[str, Tuple[torch.Tensor, float]]: + x = 2 * k + result = { + "imakey": (x, 1) + } + return result + + mod = ReturnsDict() + traced_module = torch.jit.trace(mod, [torch.ones(1), torch.ones(1)], strict=False) + out = traced_module(torch.ones(1), torch.ones(1)) + expected = { + "imakey": (torch.tensor([2.]), torch.tensor([3.])) + } + self.assertEqual(out, expected) + + with self.assertRaisesRegex(RuntimeError, "cannot be understood by the tracer, only outputs matching"): + mod = ReturnsBadDict() + traced_module = torch.jit.trace(mod, [torch.ones(1), torch.ones(1)], strict=False) diff --git a/test/jit/test_type_sharing.py b/test/jit/test_type_sharing.py index 7981ed96d5105..cb6677937b968 100644 --- a/test/jit/test_type_sharing.py +++ b/test/jit/test_type_sharing.py @@ -560,3 +560,54 @@ def forward(self, x): self.assertDifferentType(top1_s, top2_s.sub) self.assertDifferentType(top2_s, top2_s.sub) self.assertDifferentType(top2_s, top1_s.sub) + + def test_type_shared_ignored_attributes(self): + """ + Test that types are shared if the exclusion of their + ignored attributes makes them equal. + """ + class A(torch.nn.Module): + __jit_ignored_attributes__ = ["a"] + + def __init__(self, a, b): + super().__init__() + self.a = a + self.b = b + + def forward(self, x): + return x + + a_with_linear = A(torch.nn.Linear(5, 5), 5) + a_with_string = A("string", 10) + + # Both should have the same type because the attribute + # that differs in type is ignored and the common attribute + # has the same type. + self.assertSameType(a_with_linear, a_with_string) + + def test_type_not_shared_ignored_attributes(self): + """ + Test that types are not shared if the exclusion of their + ignored attributes makes them not equal. + """ + class A(torch.nn.Module): + __jit_ignored_attributes__ = ["a"] + + def __init__(self, a, b, c): + super().__init__() + self.a = a + self.b = b + self.c = c + + def forward(self, x): + return x + + mod = A(torch.nn.Linear(5, 5), 5, "string") + s1 = torch.jit.script(mod) + A.__jit_ignored_attributes__ = ["a", "b"] + s2 = torch.jit.script(mod) + + # The types of s1 and s2 should differ. Although they are instances + # of A, __jit_ignored_attributes__ was modified before scripting s2, + # so the set of ignored attributes is different between s1 and s2. + self.assertDifferentType(s1, s2) diff --git a/test/jit/test_warn.py b/test/jit/test_warn.py new file mode 100644 index 0000000000000..6a89ba4dc3854 --- /dev/null +++ b/test/jit/test_warn.py @@ -0,0 +1,165 @@ +import os +import sys +import io + +import torch +import warnings +from contextlib import redirect_stderr +from torch.testing import FileCheck + +# Make the helper files in test/ importable +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) +from torch.testing._internal.jit_utils import JitTestCase + +if __name__ == '__main__': + raise RuntimeError("This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_jit.py TESTNAME\n\n" + "instead.") + + +class TestWarn(JitTestCase): + def test_warn(self): + @torch.jit.script + def fn(): + warnings.warn("I am warning you") + + f = io.StringIO() + with redirect_stderr(f): + fn() + + FileCheck() \ + .check_count( + str="UserWarning: I am warning you", + count=1, + exactly=True) \ + .run(f.getvalue()) + + def test_warn_only_once(self): + @torch.jit.script + def fn(): + for _ in range(10): + warnings.warn("I am warning you") + + f = io.StringIO() + with redirect_stderr(f): + fn() + + FileCheck() \ + .check_count( + str="UserWarning: I am warning you", + count=1, + exactly=True) \ + .run(f.getvalue()) + + def test_warn_only_once_in_loop_func(self): + def w(): + warnings.warn("I am warning you") + + @torch.jit.script + def fn(): + for _ in range(10): + w() + + f = io.StringIO() + with redirect_stderr(f): + fn() + + FileCheck() \ + .check_count( + str="UserWarning: I am warning you", + count=1, + exactly=True) \ + .run(f.getvalue()) + + def test_warn_once_per_func(self): + def w1(): + warnings.warn("I am warning you") + + def w2(): + warnings.warn("I am warning you") + + @torch.jit.script + def fn(): + w1() + w2() + + f = io.StringIO() + with redirect_stderr(f): + fn() + + FileCheck() \ + .check_count( + str="UserWarning: I am warning you", + count=2, + exactly=True) \ + .run(f.getvalue()) + + def test_warn_once_per_func_in_loop(self): + def w1(): + warnings.warn("I am warning you") + + def w2(): + warnings.warn("I am warning you") + + @torch.jit.script + def fn(): + for _ in range(10): + w1() + w2() + + f = io.StringIO() + with redirect_stderr(f): + fn() + + FileCheck() \ + .check_count( + str="UserWarning: I am warning you", + count=2, + exactly=True) \ + .run(f.getvalue()) + + def test_warn_multiple_calls_multiple_warnings(self): + @torch.jit.script + def fn(): + warnings.warn("I am warning you") + + f = io.StringIO() + with redirect_stderr(f): + fn() + fn() + + FileCheck() \ + .check_count( + str="UserWarning: I am warning you", + count=2, + exactly=True) \ + .run(f.getvalue()) + + def test_warn_multiple_calls_same_func_diff_stack(self): + def warn(caller: str): + warnings.warn("I am warning you from " + caller) + + @torch.jit.script + def foo(): + warn("foo") + + @torch.jit.script + def bar(): + warn("bar") + + f = io.StringIO() + with redirect_stderr(f): + foo() + bar() + + FileCheck() \ + .check_count( + str="UserWarning: I am warning you from foo", + count=1, + exactly=True) \ + .check_count( + str="UserWarning: I am warning you from bar", + count=1, + exactly=True) \ + .run(f.getvalue()) diff --git a/test/jit/test_with.py b/test/jit/test_with.py index ffd0631639f6f..f958dc46c39ad 100644 --- a/test/jit/test_with.py +++ b/test/jit/test_with.py @@ -1,7 +1,7 @@ import os import sys -from typing import Any +from typing import Any, List import torch from torch.testing._internal.jit_utils import JitTestCase @@ -50,8 +50,7 @@ def __enter__(self): def __exit__(self, type: Any, value: Any, tb: Any): self.count.sub_(0.3) - def test_basic(x): - # type: (Tensor) -> Tensor + def test_basic(x: torch.Tensor) -> torch.Tensor: """Basic test with one with-statement.""" c = Context(1) @@ -62,8 +61,7 @@ def test_basic(x): y *= c.count return y - def test_pass(x): - # type: (Tensor) -> Tensor + def test_pass(x: torch.Tensor) -> torch.Tensor: """ Test with a pass statement inside a with-statement. Although the body of the with is empty, __enter__ and __exit__ should @@ -77,8 +75,7 @@ def test_pass(x): x *= c.count return x - def test_early_return(x, c): - # type: (Tensor, Context) -> Tensor + def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test that returning early from inside a with-statement works as expected. @@ -90,8 +87,7 @@ def test_early_return(x, c): x = y + y return x - def test_conditional_early_return(x, c): - # type: (Tensor, Context) -> Tensor + def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test that conditionally returning early from inside a with-statement works as expected. @@ -104,8 +100,7 @@ def test_conditional_early_return(x, c): x = y + y return x - def test_break(x, c, l): - # type: (Tensor, Context, List[int]) -> Tensor + def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: """ Test that breaking early from inside a with-statement works as expected. @@ -118,8 +113,7 @@ def test_break(x, c, l): return x - def test_continue(x, c, l): - # type: (Tensor, Context, List[int]) -> Tensor + def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: """ Test that using continue inside a with-statement works as expected. @@ -132,8 +126,7 @@ def test_continue(x, c, l): return x - def test_serial(x): - # type: (Tensor) -> Tensor + def test_serial(x: torch.Tensor) -> torch.Tensor: """ Test two with-statements in a row. """ @@ -147,8 +140,7 @@ def test_serial(x): return y - def test_nested(x): - # type: (Tensor) -> Tensor + def test_nested(x: torch.Tensor) -> torch.Tensor: """ Test nested with-statements. """ @@ -162,8 +154,7 @@ def test_nested(x): return y - def test_combined(x): - # type: (Tensor) -> Tensor + def test_combined(x: torch.Tensor) -> torch.Tensor: """ Test a with-statement with multiple with items. """ @@ -215,8 +206,7 @@ def __enter__(self): def __exit__(self, type: Any, value: Any, tb: Any): self.count.sub_(0.3) - def test_basic(x): - # type: (Tensor) -> Tensor + def test_basic(x: torch.Tensor) -> torch.Tensor: """Basic test with one with-statement.""" c = Context(1) @@ -227,8 +217,7 @@ def test_basic(x): y *= c.count return y - def test_pass(x): - # type: (Tensor) -> Tensor + def test_pass(x: torch.Tensor) -> torch.Tensor: """ Test with a pass statement inside a with-statement. Although the body of the with is empty, __enter__ and __exit__ should @@ -242,8 +231,7 @@ def test_pass(x): x *= c.count return x - def test_early_return(x, c): - # type: (Tensor, Context) -> Tensor + def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test that returning early from inside a with-statement works as expected. @@ -255,8 +243,7 @@ def test_early_return(x, c): x = y + y return x - def test_conditional_early_return(x, c): - # type: (Tensor, Context) -> Tensor + def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test that conditionally returning early from inside a with-statement works as expected. @@ -269,8 +256,7 @@ def test_conditional_early_return(x, c): x = y + y return x - def test_break(x, c, l): - # type: (Tensor, Context, List[int]) -> Tensor + def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: """ Test that breaking early from inside a with-statement works as expected. @@ -283,8 +269,7 @@ def test_break(x, c, l): return x - def test_continue(x, c, l): - # type: (Tensor, Context, List[int]) -> Tensor + def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: """ Test that using continue inside a with-statement works as expected. @@ -297,8 +282,7 @@ def test_continue(x, c, l): return x - def test_serial(x): - # type: (Tensor) -> Tensor + def test_serial(x: torch.Tensor) -> torch.Tensor: """ Test two with-statements in a row. """ @@ -312,8 +296,7 @@ def test_serial(x): return y - def test_nested(x): - # type: (Tensor) -> Tensor + def test_nested(x: torch.Tensor) -> torch.Tensor: """ Test nested with-statements. """ @@ -327,8 +310,7 @@ def test_nested(x): return y - def test_combined(x): - # type: (Tensor) -> Tensor + def test_combined(x: torch.Tensor) -> torch.Tensor: """ Test a with-statement with multiple with items. """ @@ -381,13 +363,11 @@ def __exit__(self, type: Any, value: Any, tb: Any): self.count.sub_(0.3) @torch.jit.script - def method_that_raises(): - # type: () -> Tensor + def method_that_raises() -> torch.Tensor: raise Exception("raised exception") @torch.jit.script - def test_exception(x, c): - # type: (Tensor, Context) -> Tensor + def test_exception(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test the case in which an exception is thrown while executing the body of a with-statement. """ @@ -397,8 +377,7 @@ def test_exception(x, c): return x @torch.jit.script - def test_exception_nested(x, c): - # type: (Tensor, Context) -> Tensor + def test_exception_nested(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test the case in which an exception is thrown while executing the body of a nested with-statement. """ @@ -409,8 +388,7 @@ def test_exception_nested(x, c): return x @torch.jit.script - def with_that_raises(c): - # type: (Context) -> Tensor + def with_that_raises(c: Context) -> torch.Tensor: a = torch.tensor([1]) with c as _: @@ -419,8 +397,7 @@ def with_that_raises(c): return a @torch.jit.script - def test_exception_fn_call(x, c): - # type: (Tensor, Context) -> Tensor + def test_exception_fn_call(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test the case in which an exception is thrown while there are active with-statements in two different frames. @@ -506,29 +483,25 @@ def __enter__(self): def __exit__(self, type: Any, value: int, tb: int): pass - def test_no_enter_no_exit(x, c): - # type: (Tensor, NoEnterNoExit) -> Tensor + def test_no_enter_no_exit(x: torch.Tensor, c: NoEnterNoExit) -> torch.Tensor: with c as _: pass return x - def test_bad_enter(x, c): - # type: (Tensor, BadEnter) -> Tensor + def test_bad_enter(x: torch.Tensor, c: BadEnter) -> torch.Tensor: with c as _: pass return x - def test_bad_exit(x, c): - # type: (Tensor, BadExit) -> Tensor + def test_bad_exit(x: torch.Tensor, c: BadExit) -> torch.Tensor: with c as _: pass return x - def test_exit_incorrect_types(x, c): - # type: (Tensor, ExitIncorrectTypes) -> Tensor + def test_exit_incorrect_types(x: torch.Tensor, c: ExitIncorrectTypes) -> torch.Tensor: with c as _: pass @@ -565,8 +538,7 @@ def test_with_no_grad(self): """ # Basic no_grad test. - def test_no_grad(x, y): - # type: (Tensor, Tensor) -> Tensor + def test_no_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: with torch.no_grad(): w = x + y @@ -583,8 +555,7 @@ def test_no_grad(x, y): # Test assignment of a grad-less Tensor to a Tensor with gradients # in a no_grad block. - def test_no_grad_assignment(x, y): - # type: (Tensor, Tensor) -> Tensor + def test_no_grad_assignment(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: with torch.no_grad(): x[0] = y @@ -603,13 +574,11 @@ def __init__(self): super().__init__() @torch.jit.ignore - def adder(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def adder(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: w = x + y return w - def forward(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: with torch.no_grad(): w = self.adder(x, y) @@ -625,8 +594,7 @@ def test_with_record_function(self): Check that torch.autograd.profiler.record_function context manager is torchscriptable. """ - def with_rf(x, y): - # type: (Tensor, Tensor) -> Tensor + def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: with torch.autograd.profiler.record_function("foo"): # Nested record_function. with torch.autograd.profiler.record_function("nested"): diff --git a/test/mobile/custom_build/CMakeLists.txt b/test/mobile/custom_build/CMakeLists.txt index 384f4b126e961..339ee953f7ec7 100644 --- a/test/mobile/custom_build/CMakeLists.txt +++ b/test/mobile/custom_build/CMakeLists.txt @@ -20,4 +20,5 @@ target_link_libraries(Predictor ${TORCH_LIBRARIES} -Wl,--no-whole-archive Threads::Threads + ${CMAKE_DL_LIBS} ) diff --git a/test/mobile/op_deps/quantized_ops.cpp b/test/mobile/op_deps/quantized_ops.cpp index 81dec6f7c42e9..78b2367f7dd44 100644 --- a/test/mobile/op_deps/quantized_ops.cpp +++ b/test/mobile/op_deps/quantized_ops.cpp @@ -3,6 +3,7 @@ #include #include +#include #include // This file simulates some irregular op registration/invocation patterns for diff --git a/test/mobile/op_deps/simple_ops.cpp b/test/mobile/op_deps/simple_ops.cpp index 19d6ecdc91211..a76c58838a726 100644 --- a/test/mobile/op_deps/simple_ops.cpp +++ b/test/mobile/op_deps/simple_ops.cpp @@ -80,7 +80,7 @@ namespace { // cares about the name TORCH_LIBRARY(_test, m) { m.def("AA(Tensor self) -> Tensor"); - m.impl("AA", torch::CppFunction::makeUnboxedOnly(AA_op)); + m.impl("AA", torch::CppFunction::makeFromUnboxedFunction(AA_op)); m.def("BB(Tensor self) -> Tensor"); m.impl("BB", TORCH_FN(BB_op)); @@ -89,7 +89,7 @@ TORCH_LIBRARY(_test, m) { m.def("DD", TORCH_FN(DD_op)); } -TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(_test, m) { +TORCH_LIBRARY_FRAGMENT(_test, m) { m.def("EE(Tensor self) -> Tensor"); m.def("FF(Tensor self) -> Tensor"); m.def("GG(Tensor self) -> Tensor"); @@ -97,10 +97,10 @@ TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(_test, m) { } TORCH_LIBRARY_IMPL(_test, CPU, m) { - m.impl_UNBOXED("EE", EE_op); + m.impl("EE", EE_op); m.impl("FF", torch::dispatch(DispatchKey::CPU, - torch::CppFunction::makeUnboxedOnly(FF_op)) + torch::CppFunction::makeFromUnboxedFunction(FF_op)) ); m.impl("GG", torch::dispatch(DispatchKey::CPU, diff --git a/test/mobile/test_lite_script_module.py b/test/mobile/test_lite_script_module.py index 253b45be22176..3549582dcfac5 100644 --- a/test/mobile/test_lite_script_module.py +++ b/test/mobile/test_lite_script_module.py @@ -1,8 +1,10 @@ import unittest import torch import torch.utils.bundled_inputs - +from torch.utils.mobile_optimizer import * import io +from typing import NamedTuple +from collections import namedtuple from torch.jit.mobile import _load_for_lite_interpreter @@ -34,7 +36,7 @@ def forward(self, x): mobile_module_run_method_result = mobile_module.run_method("forward", input) torch.testing.assert_allclose(script_module_result, mobile_module_run_method_result) - def test_save_mobile_module_with_debug_info(self): + def test_save_mobile_module_with_debug_info_with_trace(self): class A(torch.nn.Module): def __init__(self): super(A, self).__init__() @@ -53,13 +55,83 @@ def forward(self, x): input = torch.tensor([5]) trace_module = torch.jit.trace(B(), input) - bytes = trace_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True) + exported_module = trace_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True) + + assert(b"mobile_debug.pkl" in exported_module) + assert(b"module_debug_info" in exported_module) + assert(b"top(B).forward" in exported_module) + assert(b"top(B).A0(A).forward" in exported_module) + assert(b"top(B).A1(A).forward" in exported_module) + + def test_save_mobile_module_with_debug_info_with_script_duplicate_class(self): + class A(torch.nn.Module): + def __init__(self): + super(A, self).__init__() + + def forward(self, x): + return x + 1 + + class B(torch.nn.Module): + def __init__(self): + super(B, self).__init__() + self.A0 = A() + self.A1 = A() + + def forward(self, x): + return self.A0(x) + self.A1(x) + + input_data = torch.tensor([5]) + scripted_module = torch.jit.script(B(), input_data) + exported_module = scripted_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True) + + assert(b"mobile_debug.pkl" in exported_module) + assert(b"module_debug_info" in exported_module) + assert(b"top(B).forward" in exported_module) + assert(b"top(B).A0(A).forward" in exported_module) + assert(b"top(B).A1(A).forward" in exported_module) + + def test_save_mobile_module_with_debug_info_with_script_nested_call(self): + class A(torch.nn.Module): + def __init__(self): + super(A, self).__init__() + + def forward(self, x): + return x + 1 + + class B(torch.nn.Module): + def __init__(self): + super(B, self).__init__() + + def forward(self, x): + return x + 2 + + class C(torch.nn.Module): + def __init__(self): + super(C, self).__init__() + self.A0 = A() + self.B0 = B() + + def forward(self, x): + return self.A0(self.B0(x)) + 1 + + input = torch.tensor([5]) + scripted_module = torch.jit.script(C(), input) - assert(b"mobile_debug.pkl" in bytes) - assert(b"module_debug_info" in bytes) - assert(b"top(B).forward" in bytes) - assert(b"top(B).A0(A).forward" in bytes) - assert(b"top(B).A1(A).forward" in bytes) + optimized_scripted_module = optimize_for_mobile(scripted_module) + + exported_module = scripted_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True) + optimized_exported_module = optimized_scripted_module._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True) + assert(b"mobile_debug.pkl" in exported_module) + assert(b"module_debug_info" in exported_module) + assert(b"top(C).forward" in exported_module) + assert(b"top(C).A0(A).forward" in exported_module) + assert(b"top(C).B0(B).forward" in exported_module) + + assert(b"mobile_debug.pkl" in optimized_exported_module) + assert(b"module_debug_info" in optimized_exported_module) + assert(b"top(C).forward" in optimized_exported_module) + assert(b"top(C).A0(A).forward" in optimized_exported_module) + assert(b"top(C).B0(B).forward" in optimized_exported_module) def test_load_mobile_module_with_debug_info(self): class MyTestModule(torch.nn.Module): @@ -138,7 +210,81 @@ def forward(self, arg): r"define a pytorch class \(class Foo\(torch\.nn\.Module\)\)\.$"): script_module._save_to_buffer_for_lite_interpreter() + def test_unsupported_return_typing_namedtuple(self): + myNamedTuple = NamedTuple('myNamedTuple', [('a', torch.Tensor)]) + + class MyTestModule(torch.nn.Module): + def forward(self): + return myNamedTuple(torch.randn(1)) + + script_module = torch.jit.script(MyTestModule()) + with self.assertRaisesRegex(RuntimeError, + r"A named tuple type is not supported in mobile module. " + r"Workaround: instead of using a named tuple type\'s fields, " + r"use a dictionary type\'s key-value pair itmes or " + r"a pytorch class \(class Foo\(torch\.nn\.Module\)\)\'s attributes."): + script_module._save_to_buffer_for_lite_interpreter() + + def test_unsupported_return_collections_namedtuple(self): + myNamedTuple = namedtuple('myNamedTuple', [('a')]) + class MyTestModule(torch.nn.Module): + def forward(self): + return myNamedTuple(torch.randn(1)) + + script_module = torch.jit.script(MyTestModule()) + with self.assertRaisesRegex(RuntimeError, + r"A named tuple type is not supported in mobile module. " + r"Workaround: instead of using a named tuple type\'s fields, " + r"use a dictionary type\'s key-value pair itmes or " + r"a pytorch class \(class Foo\(torch\.nn\.Module\)\)\'s attributes."): + script_module._save_to_buffer_for_lite_interpreter() + + def test_unsupported_return_list_with_module_class(self): + class Foo(torch.nn.Module): + def __init__(self): + super(Foo, self).__init__() + + class MyTestModuleForListWithModuleClass(torch.nn.Module): + def __init__(self): + super(MyTestModuleForListWithModuleClass, self).__init__() + self.foo = Foo() + + def forward(self): + my_list: List[Foo] = [self.foo] + return my_list + + script_module = torch.jit.script(MyTestModuleForListWithModuleClass()) + with self.assertRaisesRegex(RuntimeError, + r"^Returining a list or dictionary with pytorch class type " + r"is not supported in mobile module " + r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. " + r"Workaround\: instead of using pytorch class as their element type\, " + r"use a combination of list\, dictionary\, and single types\.$"): + script_module._save_to_buffer_for_lite_interpreter() + + def test_unsupported_return_dict_with_module_class(self): + class Foo(torch.nn.Module): + def __init__(self): + super(Foo, self).__init__() + + class MyTestModuleForDictWithModuleClass(torch.nn.Module): + def __init__(self): + super(MyTestModuleForDictWithModuleClass, self).__init__() + self.foo = Foo() + + def forward(self): + my_dict: Dict[int, Foo] = {1: self.foo} + return my_dict + + script_module = torch.jit.script(MyTestModuleForDictWithModuleClass()) + with self.assertRaisesRegex(RuntimeError, + r"^Returining a list or dictionary with pytorch class type " + r"is not supported in mobile module " + r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. " + r"Workaround\: instead of using pytorch class as their element type\, " + r"use a combination of list\, dictionary\, and single types\.$"): + script_module._save_to_buffer_for_lite_interpreter() if __name__ == '__main__': unittest.main() diff --git a/test/onnx/expect/TestOperators.test_arange_dynamic.expect b/test/onnx/expect/TestOperators.test_arange_dynamic.expect index 98c35e7fab802..e49a72fb44605 100644 --- a/test/onnx/expect/TestOperators.test_arange_dynamic.expect +++ b/test/onnx/expect/TestOperators.test_arange_dynamic.expect @@ -146,7 +146,7 @@ graph { elem_type: 1 shape { dim { - dim_value: 10 + dim_param: "13_0" } } } diff --git a/test/onnx/expect/TestOperators.test_batchnorm.expect b/test/onnx/expect/TestOperators.test_batchnorm.expect index 40a373ef0901a..ec5d87812f39c 100644 --- a/test/onnx/expect/TestOperators.test_batchnorm.expect +++ b/test/onnx/expect/TestOperators.test_batchnorm.expect @@ -26,25 +26,25 @@ graph { initializer { dims: 2 data_type: 1 - name: "bias" - raw_data: "\000\000\000\000\000\000\000\000" + name: "weight" + raw_data: "\000\000\200?\000\000\200?" } initializer { dims: 2 data_type: 1 - name: "running_mean" + name: "bias" raw_data: "\000\000\000\000\000\000\000\000" } initializer { dims: 2 data_type: 1 - name: "running_var" - raw_data: "\000\000\200?\000\000\200?" + name: "running_mean" + raw_data: "\000\000\000\000\000\000\000\000" } initializer { dims: 2 data_type: 1 - name: "weight" + name: "running_var" raw_data: "\000\000\200?\000\000\200?" } input { diff --git a/test/onnx/expect/TestOperators.test_batchnorm_1d.expect b/test/onnx/expect/TestOperators.test_batchnorm_1d.expect index 27b4e18c6e69c..4a87406616960 100644 --- a/test/onnx/expect/TestOperators.test_batchnorm_1d.expect +++ b/test/onnx/expect/TestOperators.test_batchnorm_1d.expect @@ -26,25 +26,25 @@ graph { initializer { dims: 2 data_type: 1 - name: "bias" - raw_data: "\000\000\000\000\000\000\000\000" + name: "weight" + raw_data: "\000\000\200?\000\000\200?" } initializer { dims: 2 data_type: 1 - name: "running_mean" + name: "bias" raw_data: "\000\000\000\000\000\000\000\000" } initializer { dims: 2 data_type: 1 - name: "running_var" - raw_data: "\000\000\200?\000\000\200?" + name: "running_mean" + raw_data: "\000\000\000\000\000\000\000\000" } initializer { dims: 2 data_type: 1 - name: "weight" + name: "running_var" raw_data: "\000\000\200?\000\000\200?" } input { diff --git a/test/onnx/expect/TestOperators.test_batchnorm_onnx_irv4.expect b/test/onnx/expect/TestOperators.test_batchnorm_onnx_irv4.expect index 72231618f7104..6e6530c4e5d36 100644 --- a/test/onnx/expect/TestOperators.test_batchnorm_onnx_irv4.expect +++ b/test/onnx/expect/TestOperators.test_batchnorm_onnx_irv4.expect @@ -26,25 +26,25 @@ graph { initializer { dims: 2 data_type: 1 - name: "bias" - raw_data: "\000\000\000\000\000\000\000\000" + name: "weight" + raw_data: "\000\000\200?\000\000\200?" } initializer { dims: 2 data_type: 1 - name: "running_mean" + name: "bias" raw_data: "\000\000\000\000\000\000\000\000" } initializer { dims: 2 data_type: 1 - name: "running_var" - raw_data: "\000\000\200?\000\000\200?" + name: "running_mean" + raw_data: "\000\000\000\000\000\000\000\000" } initializer { dims: 2 data_type: 1 - name: "weight" + name: "running_var" raw_data: "\000\000\200?\000\000\200?" } input { diff --git a/test/onnx/expect/TestOperators.test_batchnorm_training.expect b/test/onnx/expect/TestOperators.test_batchnorm_training.expect index 0980885042f03..507d1091c0312 100644 --- a/test/onnx/expect/TestOperators.test_batchnorm_training.expect +++ b/test/onnx/expect/TestOperators.test_batchnorm_training.expect @@ -27,6 +27,12 @@ graph { } } name: "torch-jit-export" + initializer { + dims: 2 + data_type: 1 + name: "weight" + raw_data: "\000\000\200?\000\000\200?" + } initializer { dims: 2 data_type: 1 @@ -45,12 +51,6 @@ graph { name: "running_var" raw_data: "fff?fff?" } - initializer { - dims: 2 - data_type: 1 - name: "weight" - raw_data: "\000\000\200?\000\000\200?" - } input { name: "input" type { diff --git a/test/onnx/expect/TestOperators.test_bitshift.expect b/test/onnx/expect/TestOperators.test_bitshift.expect index 1e94496413758..af67f38bb5e98 100644 --- a/test/onnx/expect/TestOperators.test_bitshift.expect +++ b/test/onnx/expect/TestOperators.test_bitshift.expect @@ -17,23 +17,34 @@ graph { } node { input: "4" - input: "10" + input: "11" output: "5" name: "Pow_1" op_type: "Pow" } node { - input: "0" input: "5" output: "6" - name: "Div_2" + name: "Cast_2" + op_type: "Cast" + attribute { + name: "to" + i: 1 + type: INT + } + } + node { + input: "0" + input: "6" + output: "7" + name: "Div_3" op_type: "Div" } node { input: "1" - input: "11" - output: "9" - name: "BitShift_3" + input: "12" + output: "10" + name: "BitShift_4" op_type: "BitShift" attribute { name: "direction" @@ -44,12 +55,12 @@ graph { name: "torch-jit-export" initializer { data_type: 1 - name: "10" + name: "11" raw_data: "\000\000\200?" } initializer { data_type: 2 - name: "11" + name: "12" raw_data: "\002" } input { @@ -91,7 +102,7 @@ graph { } } output { - name: "6" + name: "7" type { tensor_type { elem_type: 1 @@ -110,7 +121,7 @@ graph { } } output { - name: "9" + name: "10" type { tensor_type { elem_type: 2 diff --git a/test/onnx/expect/TestOperators.test_empty_like.expect b/test/onnx/expect/TestOperators.test_empty_like.expect index e2560305767a2..69bacbd9afe3c 100644 --- a/test/onnx/expect/TestOperators.test_empty_like.expect +++ b/test/onnx/expect/TestOperators.test_empty_like.expect @@ -47,10 +47,10 @@ graph { elem_type: 1 shape { dim { - dim_value: 5 + dim_param: "2_0" } dim { - dim_value: 8 + dim_param: "2_1" } } } diff --git a/test/onnx/expect/TestOperators.test_equal.expect b/test/onnx/expect/TestOperators.test_equal.expect index 0066afd119be9..4d3ca19778380 100644 --- a/test/onnx/expect/TestOperators.test_equal.expect +++ b/test/onnx/expect/TestOperators.test_equal.expect @@ -3,15 +3,15 @@ producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { node { - input: "x" - input: "y" + input: "0" + input: "1" output: "2" name: "Equal_0" op_type: "Equal" } name: "torch-jit-export" input { - name: "x" + name: "0" type { tensor_type { elem_type: 6 @@ -33,7 +33,7 @@ graph { } } input { - name: "y" + name: "1" type { tensor_type { elem_type: 6 diff --git a/test/onnx/expect/TestOperators.test_expand.expect b/test/onnx/expect/TestOperators.test_expand.expect index 817fe955c9b5c..24e67fa690e96 100644 --- a/test/onnx/expect/TestOperators.test_expand.expect +++ b/test/onnx/expect/TestOperators.test_expand.expect @@ -17,7 +17,7 @@ graph { } } node { - input: "11" + input: "10" output: "3" name: "ConstantOfShape_1" op_type: "ConstantOfShape" @@ -74,35 +74,24 @@ graph { } node { input: "7" - output: "8" - name: "Cast_6" - op_type: "Cast" - attribute { - name: "to" - i: 9 - type: INT - } - } - node { - input: "8" input: "3" input: "1" - output: "9" - name: "Where_7" + output: "8" + name: "Where_6" op_type: "Where" } node { input: "0" - input: "9" - output: "10" - name: "Expand_8" + input: "8" + output: "9" + name: "Expand_7" op_type: "Expand" } name: "torch-jit-export" initializer { dims: 1 data_type: 7 - name: "11" + name: "10" raw_data: "\003\000\000\000\000\000\000\000" } input { @@ -122,7 +111,7 @@ graph { } } output { - name: "10" + name: "9" type { tensor_type { elem_type: 1 diff --git a/test/onnx/expect/TestOperators.test_full.expect b/test/onnx/expect/TestOperators.test_full.expect index 0899d05e2a8a7..b5ea4c50e216e 100644 --- a/test/onnx/expect/TestOperators.test_full.expect +++ b/test/onnx/expect/TestOperators.test_full.expect @@ -137,10 +137,10 @@ graph { elem_type: 1 shape { dim { - dim_value: 3 + dim_param: "10_0" } dim { - dim_value: 4 + dim_param: "10_1" } } } diff --git a/test/onnx/expect/TestOperators.test_full_like.expect b/test/onnx/expect/TestOperators.test_full_like.expect index 87d139dab1328..f7b05e632b73a 100644 --- a/test/onnx/expect/TestOperators.test_full_like.expect +++ b/test/onnx/expect/TestOperators.test_full_like.expect @@ -47,10 +47,10 @@ graph { elem_type: 1 shape { dim { - dim_value: 3 + dim_param: "2_0" } dim { - dim_value: 4 + dim_param: "2_1" } } } diff --git a/test/onnx/expect/TestOperators.test_ge.expect b/test/onnx/expect/TestOperators.test_ge.expect index 3c2a91443a54f..9551d6d0c88fa 100644 --- a/test/onnx/expect/TestOperators.test_ge.expect +++ b/test/onnx/expect/TestOperators.test_ge.expect @@ -3,8 +3,8 @@ producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { node { - input: "x" - input: "y" + input: "0" + input: "1" output: "2" name: "Less_0" op_type: "Less" @@ -17,7 +17,7 @@ graph { } name: "torch-jit-export" input { - name: "x" + name: "0" type { tensor_type { elem_type: 6 @@ -33,7 +33,7 @@ graph { } } input { - name: "y" + name: "1" type { tensor_type { elem_type: 6 diff --git a/test/onnx/expect/TestOperators.test_gt.expect b/test/onnx/expect/TestOperators.test_gt.expect index d2cc8dc0ac002..353e022b50052 100644 --- a/test/onnx/expect/TestOperators.test_gt.expect +++ b/test/onnx/expect/TestOperators.test_gt.expect @@ -3,15 +3,15 @@ producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { node { - input: "x" - input: "y" + input: "0" + input: "1" output: "2" name: "Greater_0" op_type: "Greater" } name: "torch-jit-export" input { - name: "x" + name: "0" type { tensor_type { elem_type: 6 @@ -33,7 +33,7 @@ graph { } } input { - name: "y" + name: "1" type { tensor_type { elem_type: 6 diff --git a/test/onnx/expect/TestOperators.test_layer_norm_aten.expect b/test/onnx/expect/TestOperators.test_layer_norm_aten.expect index c1c4f8023245e..1ef3fb40126f2 100644 --- a/test/onnx/expect/TestOperators.test_layer_norm_aten.expect +++ b/test/onnx/expect/TestOperators.test_layer_norm_aten.expect @@ -35,15 +35,15 @@ graph { dims: 10 dims: 10 data_type: 1 - name: "bias" - raw_data: "\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000" + name: "weight" + raw_data: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?" } initializer { dims: 10 dims: 10 data_type: 1 - name: "weight" - raw_data: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?" + name: "bias" + raw_data: "\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000" } input { name: "input" diff --git a/test/onnx/expect/TestOperators.test_le.expect b/test/onnx/expect/TestOperators.test_le.expect index e7257cc449124..86679b01efca0 100644 --- a/test/onnx/expect/TestOperators.test_le.expect +++ b/test/onnx/expect/TestOperators.test_le.expect @@ -3,8 +3,8 @@ producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { node { - input: "x" - input: "y" + input: "0" + input: "1" output: "2" name: "Greater_0" op_type: "Greater" @@ -17,7 +17,7 @@ graph { } name: "torch-jit-export" input { - name: "x" + name: "0" type { tensor_type { elem_type: 6 @@ -33,7 +33,7 @@ graph { } } input { - name: "y" + name: "1" type { tensor_type { elem_type: 6 diff --git a/test/onnx/expect/TestOperators.test_linear.expect b/test/onnx/expect/TestOperators.test_linear.expect index abb5a07c564a9..4eb21f3d81e30 100644 --- a/test/onnx/expect/TestOperators.test_linear.expect +++ b/test/onnx/expect/TestOperators.test_linear.expect @@ -28,16 +28,16 @@ graph { name: "torch-jit-export" initializer { dims: 5 + dims: 4 data_type: 1 - name: "bias" - raw_data: "\324BO\276@\245T>\350\377\245\275\374u\336\276&\212\304>" + name: "weight" + raw_data: "\212\332\356>@\265u>p\303E\275 \320\306\274\354\201\221>\004\354\261\276\2746*>8\247)\276\340\035\224>\024\2446\276\200\211\312<\224\344,>D\356\257>\320\202\226\275\364\213\351>z\226\330\276\310\250\266\275\352F\377\276\000\250)=\244K\021>" } initializer { dims: 5 - dims: 4 data_type: 1 - name: "weight" - raw_data: "\212\332\356>@\265u>p\303E\275 \320\306\274\354\201\221>\004\354\261\276\2746*>8\247)\276\340\035\224>\024\2446\276\200\211\312<\224\344,>D\356\257>\320\202\226\275\364\213\351>z\226\330\276\310\250\266\275\352F\377\276\000\250)=\244K\021>" + name: "bias" + raw_data: "\324BO\276@\245T>\350\377\245\275\374u\336\276&\212\304>" } input { name: "input" diff --git a/test/onnx/expect/TestOperators.test_lt.expect b/test/onnx/expect/TestOperators.test_lt.expect index f34c95ffba163..8778b0ef5cdb6 100644 --- a/test/onnx/expect/TestOperators.test_lt.expect +++ b/test/onnx/expect/TestOperators.test_lt.expect @@ -3,15 +3,15 @@ producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { node { - input: "x" - input: "y" + input: "0" + input: "1" output: "2" name: "Less_0" op_type: "Less" } name: "torch-jit-export" input { - name: "x" + name: "0" type { tensor_type { elem_type: 6 @@ -33,7 +33,7 @@ graph { } } input { - name: "y" + name: "1" type { tensor_type { elem_type: 6 diff --git a/test/onnx/expect/TestOperators.test_ones_like.expect b/test/onnx/expect/TestOperators.test_ones_like.expect index ad529a21feec0..469c860f8e961 100644 --- a/test/onnx/expect/TestOperators.test_ones_like.expect +++ b/test/onnx/expect/TestOperators.test_ones_like.expect @@ -47,10 +47,10 @@ graph { elem_type: 1 shape { dim { - dim_value: 6 + dim_param: "2_0" } dim { - dim_value: 10 + dim_param: "2_1" } } } diff --git a/test/onnx/expect/TestOperators.test_softmaxcrossentropy.expect b/test/onnx/expect/TestOperators.test_softmaxcrossentropy.expect index cde473fcdb4d5..1479846789d46 100644 --- a/test/onnx/expect/TestOperators.test_softmaxcrossentropy.expect +++ b/test/onnx/expect/TestOperators.test_softmaxcrossentropy.expect @@ -8,6 +8,11 @@ graph { output: "2" name: "SoftmaxCrossEntropyLoss_0" op_type: "SoftmaxCrossEntropyLoss" + attribute { + name: "ignore_index" + i: -100 + type: INT + } attribute { name: "reduction" s: "mean" diff --git a/test/onnx/expect/TestOperators.test_softmaxcrossentropy_3d.expect b/test/onnx/expect/TestOperators.test_softmaxcrossentropy_3d.expect index 58d8c805163d2..f5cfba35b0324 100644 --- a/test/onnx/expect/TestOperators.test_softmaxcrossentropy_3d.expect +++ b/test/onnx/expect/TestOperators.test_softmaxcrossentropy_3d.expect @@ -8,6 +8,11 @@ graph { output: "2" name: "SoftmaxCrossEntropyLoss_0" op_type: "SoftmaxCrossEntropyLoss" + attribute { + name: "ignore_index" + i: -100 + type: INT + } attribute { name: "reduction" s: "mean" diff --git a/test/onnx/expect/TestOperators.test_softmaxcrossentropy_3d_none.expect b/test/onnx/expect/TestOperators.test_softmaxcrossentropy_3d_none.expect index 10d47a6ed84d0..8b0ec04b24c8e 100644 --- a/test/onnx/expect/TestOperators.test_softmaxcrossentropy_3d_none.expect +++ b/test/onnx/expect/TestOperators.test_softmaxcrossentropy_3d_none.expect @@ -8,6 +8,11 @@ graph { output: "2" name: "SoftmaxCrossEntropyLoss_0" op_type: "SoftmaxCrossEntropyLoss" + attribute { + name: "ignore_index" + i: -100 + type: INT + } attribute { name: "reduction" s: "none" diff --git a/test/onnx/expect/TestOperators.test_softmaxcrossentropy_4d.expect b/test/onnx/expect/TestOperators.test_softmaxcrossentropy_4d.expect index 6ccab9f7b50f3..8d3539ca1c643 100644 --- a/test/onnx/expect/TestOperators.test_softmaxcrossentropy_4d.expect +++ b/test/onnx/expect/TestOperators.test_softmaxcrossentropy_4d.expect @@ -8,6 +8,11 @@ graph { output: "2" name: "SoftmaxCrossEntropyLoss_0" op_type: "SoftmaxCrossEntropyLoss" + attribute { + name: "ignore_index" + i: -100 + type: INT + } attribute { name: "reduction" s: "mean" diff --git a/test/onnx/expect/TestOperators.test_softmaxcrossentropy_weights.expect b/test/onnx/expect/TestOperators.test_softmaxcrossentropy_weights.expect index 1ea4adac8cab3..bf1667b588123 100644 --- a/test/onnx/expect/TestOperators.test_softmaxcrossentropy_weights.expect +++ b/test/onnx/expect/TestOperators.test_softmaxcrossentropy_weights.expect @@ -9,6 +9,11 @@ graph { output: "3" name: "SoftmaxCrossEntropyLoss_0" op_type: "SoftmaxCrossEntropyLoss" + attribute { + name: "ignore_index" + i: -100 + type: INT + } attribute { name: "reduction" s: "mean" diff --git a/test/onnx/expect/TestOperators.test_std.expect b/test/onnx/expect/TestOperators.test_std.expect index a3a416908f3b1..690e381042f3d 100644 --- a/test/onnx/expect/TestOperators.test_std.expect +++ b/test/onnx/expect/TestOperators.test_std.expect @@ -3,16 +3,9 @@ producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { node { - input: "0" input: "0" output: "1" - name: "Mul_0" - op_type: "Mul" - } - node { - input: "1" - output: "2" - name: "ReduceMean_1" + name: "ReduceMean_0" op_type: "ReduceMean" attribute { name: "axes" @@ -28,85 +21,127 @@ graph { } node { input: "0" + output: "2" + name: "Shape_1" + op_type: "Shape" + } + node { output: "3" - name: "ReduceMean_2" - op_type: "ReduceMean" - attribute { - name: "axes" - ints: 0 - ints: 1 - type: INTS - } + name: "Constant_2" + op_type: "Constant" attribute { - name: "keepdims" - i: 1 - type: INT + name: "value" + t { + dims: 2 + data_type: 7 + raw_data: "\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000" + } + type: TENSOR } } node { - input: "3" + input: "2" input: "3" output: "4" - name: "Mul_3" - op_type: "Mul" + name: "Gather_3" + op_type: "Gather" + attribute { + name: "axis" + i: 0 + type: INT + } } node { - input: "2" input: "4" output: "5" - name: "Sub_4" - op_type: "Sub" + name: "ReduceProd_4" + op_type: "ReduceProd" + attribute { + name: "keepdims" + i: 0 + type: INT + } } node { - input: "5" + input: "0" + input: "1" output: "6" - name: "Abs_5" - op_type: "Abs" + name: "Sub_5" + op_type: "Sub" } node { + input: "6" + input: "6" output: "7" - name: "Constant_6" - op_type: "Constant" - attribute { - name: "value" - t { - data_type: 1 - raw_data: "\000\000\300@" - } - type: TENSOR - } + name: "Mul_6" + op_type: "Mul" } node { - input: "6" input: "7" output: "8" - name: "Mul_7" - op_type: "Mul" + name: "ReduceMean_7" + op_type: "ReduceMean" + attribute { + name: "axes" + ints: 0 + ints: 1 + type: INTS + } + attribute { + name: "keepdims" + i: 1 + type: INT + } } node { + input: "5" output: "9" - name: "Constant_8" + name: "Cast_8" + op_type: "Cast" + attribute { + name: "to" + i: 1 + type: INT + } + } + node { + input: "8" + input: "9" + output: "10" + name: "Mul_9" + op_type: "Mul" + } + node { + output: "11" + name: "Constant_10" op_type: "Constant" attribute { name: "value" t { data_type: 1 - raw_data: "\000\000\240@" + raw_data: "\000\000\200?" } type: TENSOR } } node { - input: "8" input: "9" - output: "10" - name: "Div_9" - op_type: "Div" + input: "11" + output: "12" + name: "Sub_11" + op_type: "Sub" } node { input: "10" - output: "11" - name: "Sqrt_10" + input: "12" + output: "13" + name: "Div_12" + op_type: "Div" + } + node { + input: "13" + output: "14" + name: "Sqrt_13" op_type: "Sqrt" } name: "torch-jit-export" @@ -130,7 +165,7 @@ graph { } } output { - name: "11" + name: "14" type { tensor_type { elem_type: 1 diff --git a/test/onnx/expect/TestOperators.test_topk.expect b/test/onnx/expect/TestOperators.test_topk.expect index 27f374243d2c1..71c3b2b226fbb 100644 --- a/test/onnx/expect/TestOperators.test_topk.expect +++ b/test/onnx/expect/TestOperators.test_topk.expect @@ -67,7 +67,7 @@ graph { elem_type: 1 shape { dim { - dim_value: 3 + dim_param: "4_0" } } } @@ -80,7 +80,7 @@ graph { elem_type: 7 shape { dim { - dim_value: 3 + dim_param: "5_0" } } } diff --git a/test/onnx/expect/TestOperators.test_topk_smallest_unsorted.expect b/test/onnx/expect/TestOperators.test_topk_smallest_unsorted.expect index 1c09194cb9a51..fa807dc9eb5d8 100644 --- a/test/onnx/expect/TestOperators.test_topk_smallest_unsorted.expect +++ b/test/onnx/expect/TestOperators.test_topk_smallest_unsorted.expect @@ -77,7 +77,7 @@ graph { elem_type: 1 shape { dim { - dim_value: 3 + dim_param: "4_0" } } } @@ -90,7 +90,7 @@ graph { elem_type: 7 shape { dim { - dim_value: 3 + dim_param: "5_0" } } } diff --git a/test/onnx/expect/TestOperators.test_unique.expect b/test/onnx/expect/TestOperators.test_unique.expect index 43e46b3b58902..5f407644acaf7 100644 --- a/test/onnx/expect/TestOperators.test_unique.expect +++ b/test/onnx/expect/TestOperators.test_unique.expect @@ -51,7 +51,7 @@ graph { elem_type: 1 shape { dim { - dim_value: 2 + dim_param: "1_0" } dim { dim_value: 3 @@ -73,7 +73,7 @@ graph { elem_type: 7 shape { dim { - dim_value: 2 + dim_param: "4_0" } } } diff --git a/test/onnx/expect/TestOperators.test_upsample_nearest_scale.expect b/test/onnx/expect/TestOperators.test_upsample_nearest_scale.expect index 67d765831c1be..56c825560d077 100644 --- a/test/onnx/expect/TestOperators.test_upsample_nearest_scale.expect +++ b/test/onnx/expect/TestOperators.test_upsample_nearest_scale.expect @@ -50,16 +50,16 @@ graph { elem_type: 1 shape { dim { - dim_value: 1 + dim_param: "4_0" } dim { - dim_value: 2 + dim_param: "4_1" } dim { - dim_value: 6 + dim_param: "4_2" } dim { - dim_value: 8 + dim_param: "4_3" } } } diff --git a/test/onnx/expect/TestOperators.test_upsample_nearest_scale_default_scale_factor.expect b/test/onnx/expect/TestOperators.test_upsample_nearest_scale_default_scale_factor.expect index 67d765831c1be..56c825560d077 100644 --- a/test/onnx/expect/TestOperators.test_upsample_nearest_scale_default_scale_factor.expect +++ b/test/onnx/expect/TestOperators.test_upsample_nearest_scale_default_scale_factor.expect @@ -50,16 +50,16 @@ graph { elem_type: 1 shape { dim { - dim_value: 1 + dim_param: "4_0" } dim { - dim_value: 2 + dim_param: "4_1" } dim { - dim_value: 6 + dim_param: "4_2" } dim { - dim_value: 8 + dim_param: "4_3" } } } diff --git a/test/onnx/expect/TestOperators.test_view.expect b/test/onnx/expect/TestOperators.test_view.expect index 75202b5d0da21..abd2276e7716f 100644 --- a/test/onnx/expect/TestOperators.test_view.expect +++ b/test/onnx/expect/TestOperators.test_view.expect @@ -3,16 +3,26 @@ producer_name: "pytorch" producer_version: "CURRENT_VERSION" graph { node { - input: "0" output: "1" - name: "Flatten_0" - op_type: "Flatten" + name: "Constant_0" + op_type: "Constant" attribute { - name: "axis" - i: 1 - type: INT + name: "value" + t { + dims: 2 + data_type: 7 + raw_data: "\001\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000" + } + type: TENSOR } } + node { + input: "0" + input: "1" + output: "2" + name: "Reshape_1" + op_type: "Reshape" + } name: "torch-jit-export" input { name: "0" @@ -28,7 +38,7 @@ graph { } } output { - name: "1" + name: "2" type { tensor_type { elem_type: 1 diff --git a/test/onnx/expect/TestOperators.test_view_flatten.expect b/test/onnx/expect/TestOperators.test_view_flatten.expect index 07667797e2cf9..5ae9c0576c7a9 100644 --- a/test/onnx/expect/TestOperators.test_view_flatten.expect +++ b/test/onnx/expect/TestOperators.test_view_flatten.expect @@ -65,60 +65,40 @@ graph { } } node { - input: "6" output: "7" - name: "Cast_6" - op_type: "Cast" - attribute { - name: "to" - i: 11 - type: INT - } - } - node { - output: "8" - name: "Constant_7" + name: "Constant_6" op_type: "Constant" attribute { name: "value" t { - data_type: 11 - raw_data: "\000\000\000\000\000\000\360?" + data_type: 7 + raw_data: "\030\000\000\000\000\000\000\000" } type: TENSOR } } node { - input: "8" input: "7" - output: "9" - name: "Div_8" + input: "6" + output: "8" + name: "Div_7" op_type: "Div" } node { - output: "10" - name: "Constant_9" - op_type: "Constant" + input: "8" + output: "9" + name: "Cast_8" + op_type: "Cast" attribute { - name: "value" - t { - data_type: 11 - raw_data: "\000\000\000\000\000\0008@" - } - type: TENSOR + name: "to" + i: 7 + type: INT } } node { input: "9" - input: "10" - output: "11" - name: "Mul_10" - op_type: "Mul" - } - node { - input: "11" - output: "12" - name: "Cast_11" + output: "10" + name: "Cast_9" op_type: "Cast" attribute { name: "to" @@ -128,8 +108,8 @@ graph { } node { input: "3" - output: "13" - name: "Unsqueeze_12" + output: "11" + name: "Unsqueeze_10" op_type: "Unsqueeze" attribute { name: "axes" @@ -138,9 +118,9 @@ graph { } } node { - input: "12" - output: "14" - name: "Unsqueeze_13" + input: "10" + output: "12" + name: "Unsqueeze_11" op_type: "Unsqueeze" attribute { name: "axes" @@ -149,10 +129,10 @@ graph { } } node { - input: "13" - input: "14" - output: "15" - name: "Concat_14" + input: "11" + input: "12" + output: "13" + name: "Concat_12" op_type: "Concat" attribute { name: "axis" @@ -162,9 +142,9 @@ graph { } node { input: "0" - input: "15" - output: "16" - name: "Reshape_15" + input: "13" + output: "14" + name: "Reshape_13" op_type: "Reshape" } name: "torch-jit-export" @@ -191,7 +171,7 @@ graph { } } output { - name: "16" + name: "14" type { tensor_type { elem_type: 1 diff --git a/test/onnx/expect/TestOperators.test_zeros_like.expect b/test/onnx/expect/TestOperators.test_zeros_like.expect index e2560305767a2..69bacbd9afe3c 100644 --- a/test/onnx/expect/TestOperators.test_zeros_like.expect +++ b/test/onnx/expect/TestOperators.test_zeros_like.expect @@ -47,10 +47,10 @@ graph { elem_type: 1 shape { dim { - dim_value: 5 + dim_param: "2_0" } dim { - dim_value: 8 + dim_param: "2_1" } } } diff --git a/test/onnx/export_onnx_tests_generator.py b/test/onnx/export_onnx_tests_generator.py index 39971b5a313b8..0eab014d86418 100644 --- a/test/onnx/export_onnx_tests_generator.py +++ b/test/onnx/export_onnx_tests_generator.py @@ -124,7 +124,7 @@ def convert_tests(testcases, sets=1): input = gen_input(t) if (module_name != "FunctionalModule"): nn_module[module_name] |= 1 - except: # noqa: E722 + except: # noqa: E722,B001 traceback.print_exc() if (module_name != "FunctionalModule"): nn_module[module_name] |= 2 diff --git a/test/onnx/pytorch_helper.py b/test/onnx/pytorch_helper.py index 17afa53b57e4d..e027e0393ba70 100644 --- a/test/onnx/pytorch_helper.py +++ b/test/onnx/pytorch_helper.py @@ -23,7 +23,7 @@ def PyTorchModule(helper, model, sample_arguments, caffe2_inputs, prefix_name=No """ Embed an ONNX-exportable PyTorch Model into a Caffe2 model being built. - Arguments: + Args: helper (caffe2.python.core.ModelHelder): the model helper where this imported network should be inserted model (torch.nn.Module): the model to be exported diff --git a/test/onnx/test_models.py b/test/onnx/test_models.py index 6f37fa6d7e72c..0613c69c08677 100644 --- a/test/onnx/test_models.py +++ b/test/onnx/test_models.py @@ -49,7 +49,6 @@ class TestModels(TestCase): opset_version = _export_onnx_opset_version def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7): - self.is_script_test_enabled = True with torch.onnx.select_model_mode_for_export(model, None): graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX) torch._C._jit_pass_lint(graph) @@ -94,21 +93,18 @@ def test_srresnet(self): self.exportTest(toC(SRResNet(rescale_factor=4, n_filters=64, n_blocks=8)), toC(x)) @skipIfNoLapack - @disableScriptTest() def test_super_resolution(self): x = Variable( torch.randn(BATCH_SIZE, 1, 224, 224).fill_(1.0) ) self.exportTest(toC(SuperResolutionNet(upscale_factor=3)), toC(x), atol=1e-6) - @disableScriptTest() def test_alexnet(self): x = Variable( torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0) ) self.exportTest(toC(alexnet()), toC(x)) - @disableScriptTest() def test_mnist(self): x = Variable(torch.randn(BATCH_SIZE, 1, 28, 28).fill_(1.0)) self.exportTest(toC(MNIST()), toC(x)) @@ -137,13 +133,12 @@ def test_vgg19_bn(self): x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)) self.exportTest(toC(vgg19_bn()), toC(x)) - @disableScriptTest() def test_resnet(self): # ResNet50 model x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)) self.exportTest(toC(resnet50()), toC(x), atol=1e-6) - @disableScriptTest() + @disableScriptTest() # None type in outputs def test_inception(self): x = Variable( torch.randn(BATCH_SIZE, 3, 299, 299) + 1.) @@ -162,7 +157,6 @@ def test_squeezenet(self): sqnet_v1_1 = SqueezeNet(version=1.1) self.exportTest(toC(sqnet_v1_1), toC(x)) - @disableScriptTest() def test_densenet(self): # Densenet-121 model x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)) @@ -208,22 +202,20 @@ def test_qat_resnet(self): self.exportTest(toC(qat_resnet50), toC(x)) - @disableScriptTest() + @disableScriptTest() # None type in outputs def test_googlenet(self): x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)) self.exportTest(toC(googlenet()), toC(x), rtol=1e-3, atol=1e-5) - @disableScriptTest() def test_mnasnet(self): x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)) self.exportTest(toC(mnasnet1_0()), toC(x), rtol=1e-3, atol=1e-5) - @disableScriptTest() def test_mobilenet(self): x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)) self.exportTest(toC(mobilenet_v2()), toC(x), rtol=1e-3, atol=1e-5) - @disableScriptTest() + @disableScriptTest() # prim_data def test_shufflenet(self): x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)) self.exportTest(toC(shufflenet_v2_x1_0()), toC(x), rtol=1e-3, atol=1e-5) @@ -238,20 +230,18 @@ def test_deeplab(self): x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)) self.exportTest(toC(deeplabv3_resnet101()), toC(x), rtol=1e-3, atol=1e-5) - @disableScriptTest() def test_r3d_18_video(self): x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0)) self.exportTest(toC(r3d_18()), toC(x), rtol=1e-3, atol=1e-5) - @disableScriptTest() def test_mc3_18_video(self): x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0)) self.exportTest(toC(mc3_18()), toC(x), rtol=1e-3, atol=1e-5) - @disableScriptTest() def test_r2plus1d_18_video(self): x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0)) self.exportTest(toC(r2plus1d_18()), toC(x), rtol=1e-3, atol=1e-5) + if __name__ == '__main__': run_tests() diff --git a/test/onnx/test_models_onnxruntime.py b/test/onnx/test_models_onnxruntime.py index 657a1479723d6..c916b60844d1c 100644 --- a/test/onnx/test_models_onnxruntime.py +++ b/test/onnx/test_models_onnxruntime.py @@ -15,13 +15,31 @@ def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None): input=inputs, rtol=rtol, atol=atol) if self.is_script_test_enabled and opset_version > 11: + TestModels.use_new_jit_passes = True + TestModels.onnx_shape_inference = True + outputs = model(inputs) script_model = torch.jit.script(model) run_model_test(self, script_model, False, example_outputs=outputs, - input=inputs, rtol=rtol, atol=atol, use_new_jit_passes=True) + input=inputs, rtol=rtol, atol=atol) + + +TestModels = type(str("TestModels"), + (unittest.TestCase,), + dict(TestModels.__dict__, + is_script_test_enabled=False, + exportTest=exportTest)) + + +# model tests for scripting with new JIT APIs and shape inference +TestModels_new_jit_API = type(str("TestModels_new_jit_API"), + (unittest.TestCase,), + dict(TestModels.__dict__, + exportTest=exportTest, + is_script_test_enabled=True, + use_new_jit_passes=True, + onnx_shape_inference=True)) if __name__ == '__main__': - TestModels.is_script_test_enabled = True - TestModels.exportTest = exportTest unittest.main() diff --git a/test/onnx/test_onnx_opset.py b/test/onnx/test_onnx_opset.py index c3d4482ce6590..29b98e5f7988d 100644 --- a/test/onnx/test_onnx_opset.py +++ b/test/onnx/test_onnx_opset.py @@ -333,45 +333,6 @@ def forward(self, x): x = torch.randn(20, 16, 50) check_onnx_opsets_operator(MyDynamicModel(), x, ops, opset_versions=[9, 10]) - def test_std(self): - class MyModule(Module): - def forward(self, input): - return torch.std(input, unbiased=False) - - ops = [{"op_name": "Mul"}, - {"op_name": "ReduceMean", "attributes": [{"name": "keepdims", "i": 0, "type": 2}]}, - {"op_name": "ReduceMean", "attributes": [{"name": "keepdims", "i": 0, "type": 2}]}, - {"op_name": "Mul"}, - {"op_name": "Sub"}, - {"op_name": "Abs"}, - {"op_name": "Sqrt"}] - ops = {9: ops, 10: ops} - x = torch.randn(2, 3, 4) - check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10]) - - def test_std_along_dims(self): - class MyModule(Module): - def forward(self, input): - return torch.std(input, dim=(0, 1), unbiased=True, keepdim=True) - - ops = [{"op_name": "Mul"}, - {"op_name": "ReduceMean", - "attributes": [{"name": "axes", "ints": [0, 1], "type": 7}, {"name": "keepdims", "i": 1, "type": 2}]}, - {"op_name": "ReduceMean", - "attributes": [{"name": "axes", "ints": [0, 1], "type": 7}, {"name": "keepdims", "i": 1, "type": 2}]}, - {"op_name": "Mul"}, - {"op_name": "Sub"}, - {"op_name": "Abs"}, - {"op_name": "Constant"}, - {"op_name": "Mul"}, - {"op_name": "Constant"}, - {"op_name": "Div"}, - {"op_name": "Sqrt"} - ] - ops = {9: ops, 10: ops} - x = torch.randn(2, 3, 4) - check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10]) - if __name__ == '__main__': run_tests() diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index 8ccf0fdfdb890..f6fa533d78379 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -776,7 +776,7 @@ def forward(self, x_in): return x_out x = {torch.tensor(1.): torch.randn(1, 2, 3)} - self.assertONNX(MyModel(), (x,)) + self.assertONNX(MyModel(), (x, {})) def test_dict_str(self): class MyModel(torch.nn.Module): @@ -786,7 +786,7 @@ def forward(self, x_in): return x_out x = {"test_key_in": torch.randn(1, 2, 3)} - self.assertONNX(MyModel(), (x,)) + self.assertONNX(MyModel(), (x, {})) def test_arange_dynamic(self): class TestModel(torch.nn.Module): diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py index 9437197e986bb..3f0bb2a6b2485 100644 --- a/test/onnx/test_pytorch_onnx_caffe2.py +++ b/test/onnx/test_pytorch_onnx_caffe2.py @@ -1,22 +1,24 @@ -import numpy as np +from typing import Tuple +import io +import itertools import sys import unittest -import itertools -import torch.onnx -import torch.onnx.operators -from torch.onnx import ExportTypes +import numpy as np + +from debug_embed_params import run_embed_params from torch import nn from torch.autograd import Variable, function -import torch.utils.model_zoo as model_zoo from torch.nn.utils import rnn as rnn_utils -from debug_embed_params import run_embed_params -import io +from torch.onnx import ExportTypes +import torch.onnx +import torch.onnx.operators +import torch.utils.model_zoo as model_zoo # Import various models for testing from torchvision.models.alexnet import alexnet -from torchvision.models.inception import inception_v3 from torchvision.models.densenet import densenet121 +from torchvision.models.inception import inception_v3 from torchvision.models.resnet import resnet50 from torchvision.models.vgg import vgg16, vgg16_bn, vgg19, vgg19_bn @@ -464,11 +466,11 @@ def test_rnn_init_predict_split(self): do_constant_folding=False)[0]) prepared = c2.prepare(mp, device='CPU') if self.embed_params: - assert len(prepared.init_net.op) == 875 - assert len(prepared.predict_net.op) == 130 + assert len(prepared.init_net.op) == 879 + assert len(prepared.predict_net.op) == 133 else: - assert len(prepared.init_net.op) == 8 - assert len(prepared.predict_net.op) == 997 + assert len(prepared.init_net.op) == 12 + assert len(prepared.predict_net.op) == 1000 def test_alexnet(self): state_dict = model_zoo.load_url(model_urls['alexnet'], progress=False) @@ -1981,8 +1983,7 @@ def forward(self, lstm_in): def test_tuple_input_output(self): class TupleModel(torch.jit.ScriptModule): @torch.jit.script_method - def forward(self, a): - # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor] + def forward(self, a: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: return a x = (torch.randn(3, 4), torch.randn(4, 3)) @@ -1992,8 +1993,7 @@ def forward(self, a): def test_nested_tuple_input_output(self): class NestedTupleModel(torch.jit.ScriptModule): @torch.jit.script_method - def forward(self, a, b): - # type: (Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor + def forward(self, a: torch.Tensor, b: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) -> torch.Tensor: return a + b[0] + b[1][0] + b[1][1] x = torch.randn(4, 5) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 816951dfc79e2..33d428bbb42a5 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -6,6 +6,7 @@ import io import itertools import copy +import os from torch.nn.utils import rnn as rnn_utils from model_defs.lstm_flattening_result import LstmFlatteningResult @@ -15,8 +16,16 @@ skipIfUnsupportedMaxOpsetVersion, skipIfONNXShapeInference) from test_pytorch_common import BATCH_SIZE from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE +from typing import List import model_defs.word_language_model as word_language_model import torchvision +from torchvision import ops +from torchvision.models.detection.image_list import ImageList +from torchvision.models.detection.transform import GeneralizedRCNNTransform +from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork +from torchvision.models.detection.roi_heads import RoIHeads +from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead +from collections import OrderedDict import onnx def to_numpy(tensor): @@ -52,7 +61,7 @@ def convert_to_onnx(model, input=None, opset_version=9, example_outputs=None, def run_ort(ort_sess, input): input_copy = copy.deepcopy(input) input, _ = torch.jit._flatten(input_copy) - inputs = list(map(to_numpy, input)) + inputs = [to_numpy(inp) for inp in input] ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(inputs)) ort_outs = ort_sess.run(None, ort_inputs) @@ -61,7 +70,7 @@ def run_ort(ort_sess, input): def ort_compare_with_pytorch(ort_outs, output, rtol, atol): output, _ = torch.jit._flatten(output) - outputs = list(map(to_numpy, output)) + outputs = [to_numpy(outp) for outp in output] # compare onnxruntime and PyTorch results assert len(outputs) == len(ort_outs), "number of outputs differ" @@ -75,22 +84,30 @@ def run_model_test(self, model, batch_size=2, state_dict=None, example_outputs=None, do_constant_folding=True, dynamic_axes=None, test_with_inputs=None, input_names=None, output_names=None, - fixed_batch_size=False): + fixed_batch_size=False, dict_check=True): model.eval() if input is None: input = torch.randn(batch_size, 3, 224, 224, requires_grad=True) - with torch.no_grad(): if isinstance(input, torch.Tensor): input = (input,) # In-place operators will update input tensor data as well. # Thus inputs are replicated before every forward call. - input_copy = copy.deepcopy(input) - output = model(*input_copy) + if isinstance(input, dict): + input = (input,) + input_args = copy.deepcopy(input) + input_kwargs = {} + if dict_check and isinstance(input_args[-1], dict): + input_kwargs = input_args[-1] + input_args = input_args[:-1] + output = model(*input_args, **input_kwargs) if isinstance(output, torch.Tensor): output = (output,) + if not dict_check and isinstance(input[-1], dict): + input = input + ({},) + ort_sess = convert_to_onnx(model, input=input, opset_version=self.opset_version, example_outputs=output, do_constant_folding=do_constant_folding, keep_initializers_as_inputs=self.keep_initializers_as_inputs, @@ -116,13 +133,78 @@ def run_model_test(self, model, batch_size=2, state_dict=None, ort_outs = run_ort(ort_sess, test_input) ort_compare_with_pytorch(ort_outs, output, rtol, atol) +def _init_test_generalized_rcnn_transform(): + min_size = 100 + max_size = 200 + image_mean = [0.485, 0.456, 0.406] + image_std = [0.229, 0.224, 0.225] + transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) + return transform + +def _init_test_rpn(): + anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) + aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) + rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios) + out_channels = 256 + rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0]) + rpn_fg_iou_thresh = 0.7 + rpn_bg_iou_thresh = 0.3 + rpn_batch_size_per_image = 256 + rpn_positive_fraction = 0.5 + rpn_pre_nms_top_n = dict(training=2000, testing=1000) + rpn_post_nms_top_n = dict(training=2000, testing=1000) + rpn_nms_thresh = 0.7 + + rpn = RegionProposalNetwork( + rpn_anchor_generator, rpn_head, + rpn_fg_iou_thresh, rpn_bg_iou_thresh, + rpn_batch_size_per_image, rpn_positive_fraction, + rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh) + return rpn + +def _init_test_roi_heads_faster_rcnn(): + out_channels = 256 + num_classes = 91 + + box_fg_iou_thresh = 0.5 + box_bg_iou_thresh = 0.5 + box_batch_size_per_image = 512 + box_positive_fraction = 0.25 + bbox_reg_weights = None + box_score_thresh = 0.05 + box_nms_thresh = 0.5 + box_detections_per_img = 100 + + box_roi_pool = ops.MultiScaleRoIAlign( + featmap_names=['0', '1', '2', '3'], + output_size=7, + sampling_ratio=2) + + resolution = box_roi_pool.output_size[0] + representation_size = 1024 + box_head = TwoMLPHead( + out_channels * resolution ** 2, + representation_size) + + representation_size = 1024 + box_predictor = FastRCNNPredictor( + representation_size, + num_classes) + + roi_heads = RoIHeads( + box_roi_pool, box_head, box_predictor, + box_fg_iou_thresh, box_bg_iou_thresh, + box_batch_size_per_image, box_positive_fraction, + bbox_reg_weights, + box_score_thresh, box_nms_thresh, box_detections_per_img) + return roi_heads class TestONNXRuntime(unittest.TestCase): from torch.onnx.symbolic_helper import _export_onnx_opset_version opset_version = _export_onnx_opset_version keep_initializers_as_inputs = True # For IR version 3 type export. use_new_jit_passes = False # For testing main code-path - onnx_shape_inference = False + onnx_shape_inference = True def setUp(self): torch.manual_seed(0) @@ -130,18 +212,19 @@ def setUp(self): if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) np.random.seed(seed=0) + os.environ['ALLOW_RELEASED_ONNX_OPSET_ONLY'] = '0' self.is_script_test_enabled = True def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=True, batch_size=2, use_gpu=True, dynamic_axes=None, test_with_inputs=None, - input_names=None, output_names=None, fixed_batch_size=False): + input_names=None, output_names=None, fixed_batch_size=False, dict_check=True): def _run_test(m): return run_model_test(self, m, batch_size=batch_size, input=input, use_gpu=use_gpu, rtol=rtol, atol=atol, do_constant_folding=do_constant_folding, dynamic_axes=dynamic_axes, test_with_inputs=test_with_inputs, input_names=input_names, output_names=output_names, - fixed_batch_size=fixed_batch_size) + fixed_batch_size=fixed_batch_size, dict_check=dict_check) if self.is_script_test_enabled and self.use_new_jit_passes: script_model = torch.jit.script(model) _run_test(script_model) @@ -189,6 +272,7 @@ def run_model_test_with_external_data(self, model, input, rtol=0.001, atol=1e-7, ort_outs = run_ort(ort_sess, input_copy) ort_compare_with_pytorch(ort_outs, output, rtol, atol) + @skipIfUnsupportedMinOpsetVersion(9) # Because external data format was released with Opset 9. def test_embedding_model_with_external_data(self): class LargeModel(torch.nn.Module): @@ -210,6 +294,7 @@ def forward(self, input): x = torch.tensor([2], dtype=torch.long) self.run_model_test_with_external_data(model, x) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) # Because external data format was released with Opset 9. def test_mobilenet_v2_with_external_data(self): model = torchvision.models.mobilenet_v2(pretrained=True) @@ -314,14 +399,17 @@ def run_word_language_model(self, model_name): # Only support CPU version, since tracer is not working in GPU RNN. self.run_test(model, (x, model.hidden)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) - @disableScriptTest() + @disableScriptTest() # Faster RCNN model is not scriptable def test_faster_rcnn(self): model = torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) model.eval() x = torch.randn(2, 3, 200, 300, requires_grad=True) self.run_test(model, (x,), rtol=1e-3, atol=1e-5) + self.run_test(model, (x,), input_names=["images_tensors"], output_names=["outputs"], + dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]}, rtol=1e-3, atol=1e-5) def get_image_from_url(self, url): import os @@ -348,6 +436,39 @@ def get_test_images(self): images = [image] return images + @skipIfUnsupportedOpsetVersion([13]) + def test_paste_mask_in_image(self): + # disable profiling + torch._C._jit_set_profiling_executor(False) + torch._C._jit_set_profiling_mode(False) + + masks = torch.rand(10, 1, 26, 26) + boxes = torch.rand(10, 4) + boxes[:, 2:] += torch.rand(10, 2) + boxes *= 50 + o_im_s = (100, 100) + from torchvision.models.detection.roi_heads import paste_masks_in_image + out = paste_masks_in_image(masks, boxes, o_im_s) + jit_trace = torch.jit.trace(paste_masks_in_image, + (masks, boxes, + [torch.tensor(o_im_s[0]), + torch.tensor(o_im_s[1])])) + out_trace = jit_trace(masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])]) + + assert torch.all(out.eq(out_trace)) + + masks2 = torch.rand(20, 1, 26, 26) + boxes2 = torch.rand(20, 4) + boxes2[:, 2:] += torch.rand(20, 2) + boxes2 *= 100 + o_im_s2 = (200, 200) + from torchvision.models.detection.roi_heads import paste_masks_in_image + out2 = paste_masks_in_image(masks2, boxes2, o_im_s2) + out_trace2 = jit_trace(masks2, boxes2, [torch.tensor(o_im_s2[0]), torch.tensor(o_im_s2[1])]) + + assert torch.all(out2.eq(out_trace2)) + + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) @disableScriptTest() def test_mask_rcnn(self): @@ -355,7 +476,35 @@ def test_mask_rcnn(self): max_size=300) images = self.get_test_images() self.run_test(model, (images,), rtol=1e-3, atol=1e-5) - + self.run_test(model, (images,), input_names=["images_tensors"], output_names=["boxes", "labels", "scores", "masks"], + dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0], + "scores": [0], "masks": [0, 1, 2]}, rtol=1e-3, atol=1e-5) + + def test_heatmaps_to_keypoints(self): + # disable profiling + torch._C._jit_set_profiling_executor(False) + torch._C._jit_set_profiling_mode(False) + + maps = torch.rand(10, 1, 26, 26) + rois = torch.rand(10, 4) + from torchvision.models.detection.roi_heads import heatmaps_to_keypoints + out = heatmaps_to_keypoints(maps, rois) + jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois)) + out_trace = jit_trace(maps, rois) + + assert torch.all(out[0].eq(out_trace[0])) + assert torch.all(out[1].eq(out_trace[1])) + + maps2 = torch.rand(20, 2, 21, 21) + rois2 = torch.rand(20, 4) + from torchvision.models.detection.roi_heads import heatmaps_to_keypoints + out2 = heatmaps_to_keypoints(maps2, rois2) + out_trace2 = jit_trace(maps2, rois2) + + assert torch.all(out2[0].eq(out_trace2[0])) + assert torch.all(out2[1].eq(out_trace2[1])) + + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) @disableScriptTest() def test_keypoint_rcnn(self): @@ -363,44 +512,98 @@ def test_keypoint_rcnn(self): max_size=300) images = self.get_test_images() self.run_test(model, (images,), rtol=1e-3, atol=1e-5) + self.run_test(model, (images,), input_names=["images_tensors"], + output_names=["outputs1", "outputs2", "outputs3", "outputs4"], + dynamic_axes={"images_tensors": [0, 1, 2]}, + rtol=1e-3, atol=1e-5) + @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() def test_word_language_model_RNN_TANH(self): self.run_word_language_model("RNN_TANH") + @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() def test_word_language_model_RNN_RELU(self): self.run_word_language_model("RNN_RELU") + @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() def test_word_language_model_LSTM(self): self.run_word_language_model("LSTM") + @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() def test_word_language_model_GRU(self): self.run_word_language_model("GRU") - @disableScriptTest() def test_index_1d(self): - self._test_index_generic(lambda input: input[0]) + class MyModel(torch.nn.Module): + def forward(self, input): + return input[0] + + m1 = torch.randn(3, 4, 5, 6, 7) + self.run_test(MyModel(), m1) - @disableScriptTest() def test_index_2d_1dimslice(self): - self._test_index_generic(lambda input: input[0:1, :]) + class MyModel(torch.nn.Module): + def forward(self, input): + return input[0:1, :] + + m1 = torch.randn(3, 4, 5, 6, 7) + self.run_test(MyModel(), m1) - @disableScriptTest() def test_index_2d_sliceint(self): - self._test_index_generic(lambda input: input[1, :]) + class MyModel(torch.nn.Module): + def forward(self, input): + return input[1, :] + + m1 = torch.randn(3, 4, 5, 6, 7) + self.run_test(MyModel(), m1) - @disableScriptTest() def test_index_2d_neg_slice(self): - self._test_index_generic(lambda input: input[0:-1, :]) + class MyModel(torch.nn.Module): + def forward(self, input): + return input[0:-1, :] + + m1 = torch.randn(3, 4, 5, 6, 7) + self.run_test(MyModel(), m1) @skipIfUnsupportedMinOpsetVersion(9) - @disableScriptTest() def test_index_mask(self): - self._test_index_generic(lambda input: input[torch.tensor([0, 1, 0], dtype=torch.uint8)]) - self._test_index_generic(lambda input: input[torch.tensor([0, 1, 0], dtype=torch.bool)]) + class MyModel(torch.nn.Module): + def forward(self, input): + return input[torch.tensor([0, 1, 0], dtype=torch.uint8)] + + m1 = torch.randn(3, 4, 5, 6, 7) + self.run_test(MyModel(), m1) + + class MyModel(torch.nn.Module): + def forward(self, input): + return input[torch.tensor([0, 1, 0], dtype=torch.bool)] + + m1 = torch.randn(3, 4, 5, 6, 7) + self.run_test(MyModel(), m1) + + @skipIfUnsupportedMinOpsetVersion(9) + def test_data(self): + class Data(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + return x.new_zeros(x.data.size()) + + x = torch.randn(3, 4) + self.run_test(Data(), x) + + @skipIfUnsupportedMinOpsetVersion(11) + @disableScriptTest() # Need type inference + def test_index_mask_nd(self): + class MyModel(torch.nn.Module): + def forward(self, input): + return input[input > 0] + + m1 = torch.randn(3, 4, 5, 6, 7) + self.run_test(MyModel(), m1) @disableScriptTest() def test_dict(self): @@ -411,7 +614,7 @@ def forward(self, x_in): return x_out x = {torch.tensor(1.): torch.randn(1, 2, 3)} - self.run_test(MyModel(), (x,)) + self.run_test(MyModel(), (x, {})) @disableScriptTest() def test_dict_str(self): @@ -422,8 +625,142 @@ def forward(self, x_in): return x_out x = {"test_key_in": torch.randn(1, 2, 3)} - self.run_test(MyModel(), (x,)) + self.run_test(MyModel(), (x, {})) + + def test_optional_inputs_with_no_optionals(self): + class NoOptionalModel(torch.nn.Module): + def forward(self, input): + return input + + # Without empty optional arguments dictionary + x = torch.randn(2, 3) + self.run_test(NoOptionalModel(), (x,)) + # With empty optional arguments dictionary + y = torch.randn(2, 3) + self.run_test(NoOptionalModel(), (y, {})) + + def test_optional_inputs_with_mixed_optionals(self): + class MixedModel(torch.nn.Module): + def forward(self, x, y=None, z=None): + if y is not None: + return x + y + if z is not None: + return x + z + return x + + x = torch.randn(2, 3) + y = torch.randn(2, 3) + z = torch.randn(2, 3) + # Without optional arguments dictionary + self.run_test(MixedModel(), (x, y, None)) + self.run_test(MixedModel(), (x, None, z)) + # With optional arguments dictionary + self.run_test(MixedModel(), (x, {'y': y, 'z': None})) + self.run_test(MixedModel(), (x, {'y': None, 'z': z})) + self.run_test(MixedModel(), (x, {'z': z})) + self.run_test(MixedModel(), (x, {'y': y})) + + def test_optional_inputs_with_all_optionals(self): + class AllOptionalModel(torch.nn.Module): + def forward(self, y=None, z=None): + if y is not None: + return y + if z is not None: + return z + + y = torch.randn(2, 3) + # Without optional arguments dictionary + self.run_test(AllOptionalModel(), (y, None)) + # With optional arguments dictionary + self.run_test(AllOptionalModel(), {'y': y, 'z': None}) + + def test_input_names_with_optional_args(self): + class NoOptionalModel(torch.nn.Module): + def forward(self, input): + return input + + # Without empty optional arguments dictionary + x = torch.randn(2, 3) + self.run_test(NoOptionalModel(), (x,), input_names=['input_x']) + # With empty optional arguments dictionary + y = torch.randn(2, 3) + self.run_test(NoOptionalModel(), (y, {})) + + class MixedModel(torch.nn.Module): + def forward(self, x, y=None, z=None): + if y is not None: + return x + y + if z is not None: + return x + z + return x + + x = torch.randn(2, 3) + y = torch.randn(2, 3) + z = torch.randn(2, 3) + # Without optional arguments dictionary + self.run_test(MixedModel(), (x, y, None), input_names=['input_x', 'input_y']) + self.run_test(MixedModel(), (x, None, z), input_names=['input_x', 'input_z']) + + # With optional arguments dictionary + self.run_test(MixedModel(), (x, {'y': y, 'z': None}), input_names=['input_x', 'input_y']) + self.run_test(MixedModel(), (x, {'y': None, 'z': z}), input_names=['input_x', 'input_z']) + + class AllOptionalModel(torch.nn.Module): + def forward(self, y=None, z=None): + if y is not None: + return y + if z is not None: + return z + + y = torch.randn(2, 3) + z = torch.randn(2, 3) + # Without optional arguments dictionary + self.run_test(AllOptionalModel(), (y, None), input_names=['input_y']) + self.run_test(AllOptionalModel(), (None, z), input_names=['input_z']) + # With optional arguments dictionary + self.run_test(AllOptionalModel(), {'y': y, 'z': None}, input_names=['input_y']) + self.run_test(AllOptionalModel(), {'y': None, 'z': z}, input_names=['input_z']) + + @disableScriptTest() + def test_none_as_input(self): + class Model(torch.nn.Module): + def forward(self, x, y): + if y is not None: + return x + y + return x + + x = torch.randn(2, 3) + self.run_test(Model(), (x, None)) + + @disableScriptTest() + def test_none_as_tuple_input(self): + class Model(torch.nn.Module): + def forward(self, x, y): + if y[0] is not None: + return x + y[0] + if y[1] is not None: + return x + y[1] + return x + x = torch.randn(2, 3) + y = torch.randn(2, 3) + self.run_test(Model(), (x, (None, y))) + + @disableScriptTest() + def test_none_as_named_input(self): + class Model(torch.nn.Module): + def forward(self, x, y=None, z=None): + if y is not None: + return x + y + if z is not None: + return x + z + return x + + x = torch.randn(2, 3) + z = torch.randn(2, 3) + self.run_test(Model(), (x, None, z)) + + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_cste_script(self): class MyModel(torch.jit.ScriptModule): @@ -502,6 +839,27 @@ def forward(self, x): x = torch.arange(-5, 5).to(dtype=torch.float32) self.run_test(MyModel(), x) + def test_hardswish(self): + model = torch.nn.Hardswish() + + x = torch.rand(3, 3).to(dtype=torch.float32) + self.run_test(model, x) + + # Testing edge cases + x = torch.tensor(3).to(dtype=torch.float32) + self.run_test(model, x) + x = torch.tensor(-3).to(dtype=torch.float32) + self.run_test(model, x) + + def test_hardswish_script(self): + class MyModel(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + return torch.nn.functional.hardswish(x) + + x = torch.rand(3, 3).to(dtype=torch.float32) + self.run_test(MyModel(), x) + def test_clamp(self): class ClampModel(torch.nn.Module): def forward(self, x): @@ -594,43 +952,34 @@ def __init__(self): def forward(self, input1, input2, input3): return self.conv1(input1), self.conv2(input2), self.conv3(input3) - class ScriptModel(torch.jit.ScriptModule): - def __init__(self): - super(ScriptModel, self).__init__() - self.conv1 = torch.nn.Conv1d(16, 33, 3, stride=2) - self.conv2 = torch.nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) - self.conv3 = torch.nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)) - - @torch.jit.script_method - def forward(self, input1, input2, input3): - return self.conv1(input1), self.conv2(input2), self.conv3(input3) - x1 = torch.randn(20, 16, 50) x2 = torch.randn(20, 16, 50, 100) x3 = torch.randn(20, 16, 10, 50, 100) self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5) - self.run_test(ScriptModel(), (x1, x2, x3), atol=10e-5) - def test_conv_transpose(self): - class TraceModel(torch.nn.Module): + def test_conv_shape_inference(self): + class Model(torch.nn.Module): def __init__(self): - super(TraceModel, self).__init__() - self.conv1 = torch.nn.ConvTranspose1d(16, 33, 3, stride=2) - self.conv2 = torch.nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) - self.conv3 = torch.nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)) + super(Model, self).__init__() + self.conv2 = torch.nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) - def forward(self, input1, input2, input3): - return self.conv1(input1), self.conv2(input2), self.conv3(input3) + def forward(self, input): + return self.conv2(input) + 2 - class ScriptModel(torch.jit.ScriptModule): + x = torch.randn(20, 16, 50, 100) + self.run_test(Model(), x, atol=10e-5, + input_names=['x'], + dynamic_axes={'x': [0]}) + + def test_conv_transpose(self): + class TraceModel(torch.nn.Module): def __init__(self): - super(ScriptModel, self).__init__() + super(TraceModel, self).__init__() self.conv1 = torch.nn.ConvTranspose1d(16, 33, 3, stride=2) self.conv2 = torch.nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) self.conv3 = torch.nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)) - @torch.jit.script_method def forward(self, input1, input2, input3): return self.conv1(input1), self.conv2(input2), self.conv3(input3) @@ -639,7 +988,6 @@ def forward(self, input1, input2, input3): x3 = torch.randn(20, 16, 10, 50, 100) self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5) - self.run_test(ScriptModel(), (x1, x2, x3), atol=10e-5) # Conversion of Transpose depends on input shape to be known. # The following test only works when onnx shape inference is enabled. @@ -656,54 +1004,69 @@ def forward(self, x): return x.transpose(0, 1) x = torch.randn(32, 3, 64, 64) - self.run_test(TransposeModule(), x) + y = torch.randn(16, 3, 8, 64) + self.run_test(TransposeModule(), x, input_names=['x'], + dynamic_axes={'x': [0, 2]}, + test_with_inputs=[y]) def squeeze_model_tests(self, d, x1, x2): class Squeeze(torch.nn.Module): + def __init__(self, d): + super(Squeeze, self).__init__() + self.d = d + def forward(self, x): - if d is not None: - return torch.squeeze(x, dim=d) + if self.d is not None: + return torch.squeeze(x, dim=self.d) else: return torch.squeeze(x) x2 = [] if x2 is None else [x2] - self.run_test(Squeeze(), x1, input_names=['input'], dynamic_axes={'input': {0: '0', 1: '1', 2: '2'}}, test_with_inputs=x2) + if len(x2) > 0: + self.run_test(Squeeze(d), x1, + input_names=['input'], dynamic_axes={'input': {0: '0', 1: '1', 2: '2'}}, + test_with_inputs=x2) + else: + self.run_test(Squeeze(d), x1) + @skipIfUnsupportedOpsetVersion([13]) def test_squeeze_without_no_op(self): x = torch.randn(2, 1, 4) self.squeeze_model_tests(1, x, None) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) - def test_squeeze(self): + def test_squeeze_dynamic(self): x_squeeze = torch.randn(2, 1, 4) x_noop = torch.randn(2, 2, 3) self.squeeze_model_tests(1, x_squeeze, x_noop) + @skipIfUnsupportedOpsetVersion([13]) def test_squeeze_neg_without_no_op(self): x = torch.randn(2, 1, 4) self.squeeze_model_tests(-2, x, None) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_squeeze_neg(self): x_squeeze = torch.randn(2, 1, 4) x_noop = torch.randn(2, 2, 3) self.squeeze_model_tests(-2, x_squeeze, x_noop) + @skipIfUnsupportedOpsetVersion([13]) def test_squeeze_all_dims(self): x_squeeze = torch.randn(2, 1, 4) x_noop = torch.randn(2, 2, 3) self.squeeze_model_tests(None, x_squeeze, x_noop) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_squeeze_no_op(self): x_noop = torch.randn(2, 1, 4) x_squeeze = torch.randn(2, 2, 1) self.squeeze_model_tests(2, x_noop, x_squeeze) - def test_squeeze_no_op_without_additional_inputs(self): - x_noop = torch.randn(2, 1, 4) - self.squeeze_model_tests(2, x_noop, None) - + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_squeeze_runtime_dim(self): class Squeeze(torch.nn.Module): @@ -717,6 +1080,7 @@ def forward(self, d1, d2): self.run_test(Squeeze(), (d1, d4), test_with_inputs=[(d3, d4)]) self.run_test(Squeeze(), (d3, d4), test_with_inputs=[(d1, d3)]) + @skipIfUnsupportedOpsetVersion([13]) def test_unsqueeze(self): class Unsqueeze(torch.nn.Module): def forward(self, x): @@ -738,7 +1102,10 @@ def forward(self, x): def test_maxpool_adaptive(self): model = torch.nn.AdaptiveMaxPool1d((5), return_indices=False) x = torch.randn(20, 16, 50, requires_grad=True) - self.run_test(model, x) + y = torch.randn(32, 16, 50, requires_grad=True) + self.run_test(model, x, input_names=['x'], + dynamic_axes={'x' : [0]}, + test_with_inputs=[y]) def test_maxpool_2d(self): model = torch.nn.MaxPool2d(5, padding=(1, 2)) @@ -761,7 +1128,7 @@ def test_maxpool_3d_ceil(self): self.run_test(model, x) @skipIfUnsupportedMinOpsetVersion(8) - @disableScriptTest() + @disableScriptTest() # Functional module not scriptable def test_maxpool_with_indices(self): model = torch.nn.MaxPool1d(2, stride=1, return_indices=True) x = torch.randn(20, 16, 50) @@ -800,7 +1167,65 @@ def test_avgpool_2d_ceil(self): def test_avgpool_3d_ceil(self): model = torch.nn.AvgPool3d(3, 2, ceil_mode=True) x = torch.randn(20, 16, 50, 44, 31) - self.run_test(model, x) + y = torch.randn(32, 8, 50, 44, 31) + self.run_test(model, x, input_names=['x'], + dynamic_axes={'x' : [0, 1]}, + test_with_inputs=[y]) + + @skipIfUnsupportedMinOpsetVersion(9) + def test_floating_point(self): + class FloatingPoint(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + if x.is_floating_point(): + return x.new_zeros(x.shape) + return x.new_zeros(x.shape) + + x = torch.randn(2, 3, 4) + self.run_test(FloatingPoint(), x) + + class FloatingPoint(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + if x.size(0) > 1: + a = x + 2 + if a.is_floating_point(): + return x + 1 + return x + 1 + return x + + x = torch.randn(2, 3, 4) + self.run_test(FloatingPoint(), x) + + @unittest.skip("If operator rank mismatch between outputs of two branches.") + @skipIfUnsupportedMinOpsetVersion(9) + @skipIfONNXShapeInference(False) + def test_floating_point_infer_dtype(self): + class FloatingPoint(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + if x.size(0) > 1: + a = x + 2 + if a.is_floating_point(): + return x.new_zeros(x.shape[1:]) + return x.new_zeros(x.shape) + return x + + x = torch.randn(2, 3, 4) + self.run_test(FloatingPoint(), x) + + class FloatingPoint(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + if x.size(0) > 1: + a = x + 2 + if a.is_floating_point(): + return x + 1 + return x + return x + + x = torch.randn(2, 3, 4).to(torch.int32) + self.run_test(FloatingPoint(), x) def test_arithmetic(self): class ArithmeticModule(torch.nn.Module): @@ -814,7 +1239,6 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(ArithmeticModule(), x) - @disableScriptTest() # In scripting the first transpose node do not carry shape and dtype info. # The following test only works when onnx shape inference is enabled. @skipIfONNXShapeInference(False) @@ -856,6 +1280,7 @@ def forward(self, x, y): y = torch.randn(2, 3, 4) self.run_test(FloorDivModule(), (x, y)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_floordiv(self): class FloordivModule(torch.nn.Module): @@ -868,7 +1293,7 @@ def forward(self, x): def test_div(self): class DivModule(torch.nn.Module): def forward(self, x, y): - return x / y + return x / y, torch.true_divide(x, y) x = torch.randn(2, 3, 4).to(torch.int) y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int) @@ -882,7 +1307,7 @@ def forward(self, x, y): def test_div_promotion_trace(self): class DivModule(torch.nn.Module): def forward(self, x, y): - return x / y + return x / y, torch.true_divide(x, y) x = torch.randn(2, 3, 4).to(torch.int) y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int) @@ -900,14 +1325,14 @@ def forward(self, x, y): # In scripting x, y do not carry shape and dtype info. # The following test only works when onnx shape inference is enabled. @skipIfONNXShapeInference(False) - def test_true_div_script(self): - class TrueDivModule(torch.nn.Module): + def test_div_promotion_script(self): + class DivModule(torch.nn.Module): def forward(self, x, y): # Add transpose to hide shape/type information # Otherwise shape and type are still avaiable from input. x = x.transpose(1, 2) y = y.transpose(1, 2) - return torch.true_divide(x, y) + return x / y, torch.true_divide(x, y) x = torch.randn(2, 3, 4).to(torch.int) y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int) @@ -918,21 +1343,22 @@ def forward(self, x, y): # This can be handled by the default case, where both are cast to float. # It works even if type of x, y are unknown. torch.set_default_dtype(torch.float) - self.run_test(torch.jit.script(TrueDivModule()), (x, y)) + self.run_test(torch.jit.script(DivModule()), (x, y)) # 2. x,y are int, and output is double. # This can be handled by the default case, where both are cast to double. # It works even if type of x, y are unknown. torch.set_default_dtype(torch.double) - self.run_test(torch.jit.script(TrueDivModule()), (x, y)) + self.run_test(torch.jit.script(DivModule()), (x, y)) # 3. x is int, y is double, and output is double. # This can only be handled when both type of x and y are known. torch.set_default_dtype(prev_default) x = torch.randn(2, 3, 4).to(torch.int) y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double) - self.run_test(torch.jit.script(TrueDivModule()), (x, y)) + self.run_test(torch.jit.script(DivModule()), (x, y)) + @skipIfUnsupportedOpsetVersion([13]) def test_slice_trace(self): class MyModule(torch.nn.Module): def forward(self, x): @@ -941,6 +1367,7 @@ def forward(self, x): x = torch.randn(3) self.run_test(MyModule(), x) + @skipIfUnsupportedOpsetVersion([13]) def test_slice_neg(self): class NegSlice(torch.nn.Module): def forward(self, x): @@ -949,6 +1376,7 @@ def forward(self, x): x = torch.randn(3, 4, 5) self.run_test(NegSlice(), x) + @skipIfUnsupportedOpsetVersion([13]) def test_slice_neg_large(self): class NegSlice(torch.nn.Module): def forward(self, x): @@ -957,6 +1385,7 @@ def forward(self, x): x = torch.randn(3, 4, 5, 6, 7) self.run_test(NegSlice(), x) + @skipIfUnsupportedOpsetVersion([13]) def test_slice_neg_large_negone(self): class NegSlice(torch.nn.Module): def forward(self, x): @@ -965,6 +1394,7 @@ def forward(self, x): x = torch.randn(3, 4, 5, 6, 7) self.run_test(NegSlice(), x) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_slice_with_input_index(self): class InputIndexSlice(torch.nn.Module): @@ -976,8 +1406,9 @@ def forward(self, x, y): y = torch.rand((22, 256)) self.run_test(InputIndexSlice(), (x, y)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(10) - @disableScriptTest() + @disableScriptTest() # scripting tuple/list append def test_slice_dynamic(self): class DynamicSliceExportMod(torch.nn.Module): def forward(self, x): @@ -994,6 +1425,7 @@ def forward(self, x): dynamic_axes={'input_1': [0, 1, 2], 'output_1': [0, 1, 2]}) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(10) def test_slice_dynamic_script(self): class DynamicSliceModel(torch.jit.ScriptModule): @@ -1004,6 +1436,7 @@ def forward(self, x): x = torch.rand(1, 2) self.run_test(DynamicSliceModel(), x) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(10) def test_slice_dynamic_shape_script(self): class DynamicSliceModel(torch.nn.Module): @@ -1013,8 +1446,9 @@ def forward(self, x): x = torch.rand(1, 2, 3, 4) self.run_test(DynamicSliceModel(), x) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(10) - @disableScriptTest() + @disableScriptTest() # scripting tuple/list append def test_slice_dynamic_to_end(self): class DynamicSliceExportMod(torch.nn.Module): def forward(self, x): @@ -1028,6 +1462,14 @@ def forward(self, x): dynamic_axes={'input_1': [0, 1, 2], 'output_1': [0, 1, 2]}) + def test_square(self): + class Square(torch.nn.Module): + def forward(self, x): + return torch.square(x) + + x = torch.randn(2, 3, 4) + self.run_test(Square(), x) + @skipIfUnsupportedMinOpsetVersion(9) def test_arange_dynamic(self): class ArangeModel(torch.nn.Module): @@ -1111,6 +1553,7 @@ def forward(self, end): x = torch.tensor(6.2, dtype=torch.float) self.run_test(ArangeModel(), x) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_size(self): class SizeModel(torch.nn.Module): @@ -1120,8 +1563,9 @@ def forward(self, input): x = torch.randn(5, 3, 2) self.run_test(SizeModel(), x) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) - @disableScriptTest() + @disableScriptTest() # x.stride() not scriptable def test_as_strided(self): class Model(torch.nn.Module): def forward(self, x): @@ -1134,29 +1578,47 @@ def forward(self, x): x = torch.randn(5, 8, 7) self.run_test(Model(), x) - def _test_index_generic(self, fn): + @skipIfUnsupportedOpsetVersion([13]) + @disableScriptTest() # Ellipses followed by tensor indexing not scriptable + def test_tensor_index_advanced_indexing_ellipsis(self): class MyModel(torch.nn.Module): - def __init__(self): - super(MyModel, self).__init__() - def forward(self, input): - return fn(input) + return input[..., torch.tensor([2, 1]), torch.tensor([0, 3])] m1 = torch.randn(3, 4, 5, 6, 7) - self.run_test(MyModel(), m1) + self.run_test(MyModel(), (m1,)) - @disableScriptTest() + @skipIfUnsupportedOpsetVersion([13]) def test_tensor_index_advanced_indexing(self): - self._test_index_generic( - lambda input: input[:, torch.tensor([[0, 2], [1, 1]]), :, torch.tensor([2, 1]), torch.tensor([0, 3])]) - self._test_index_generic(lambda input: input[..., torch.tensor([2, 1]), torch.tensor([0, 3])]) - self._test_index_generic(lambda input: input[:, torch.tensor([0, 2]), None, 2:4, torch.tensor([[1, 3], [4, 0]])]) - self._test_index_generic(lambda input: input[:, torch.tensor([0, 2]), torch.tensor([1]), 2:4, torch.tensor([[1], [4]])]) + class MyModel(torch.nn.Module): + def forward(self, input): + return input[:, torch.tensor([[0, 2], [1, 1]]), :, torch.tensor([2, 1]), torch.tensor([0, 3])] - @disableScriptTest() + m1 = torch.randn(3, 4, 5, 6, 7) + self.run_test(MyModel(), (m1,)) + + class MyModel(torch.nn.Module): + def forward(self, input): + return input[:, torch.tensor([0, 2]), None, 2:4, torch.tensor([[1, 3], [4, 0]])] + + self.run_test(MyModel(), (m1,)) + + class MyModel(torch.nn.Module): + def forward(self, input): + return input[:, torch.tensor([0, 2]), torch.tensor([1]), 2:4, torch.tensor([[1], [4]])] + + self.run_test(MyModel(), (m1,)) + + @skipIfUnsupportedOpsetVersion([13]) def test_tensor_index_advanced_indexing_consecutive(self): - self._test_index_generic(lambda input: input[:, torch.tensor([0, 2]), torch.tensor([[1, 3], [4, 0]]), None]) + class MyModel(torch.nn.Module): + def forward(self, input): + return input[:, torch.tensor([0, 2]), torch.tensor([[1, 3], [4, 0]]), None] + m1 = torch.randn(3, 4, 5, 6, 7) + self.run_test(MyModel(), (m1,)) + + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_index_put(self): class IndexPutModel(torch.nn.Module): @@ -1169,6 +1631,7 @@ def forward(self, x, ind, update): update = torch.ones(4) self.run_test(IndexPutModel(), (x, ind, update)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_index_put_accumulate(self): class IndexPutModel(torch.nn.Module): @@ -1180,8 +1643,8 @@ def forward(self, x, ind, update): update = torch.ones(4) self.run_test(IndexPutModel(), (x, ind, update)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) - @disableScriptTest() def test_index_put_slice_index(self): class IndexPutModel(torch.nn.Module): def forward(self, x, update): @@ -1255,8 +1718,9 @@ def forward(self, x, update): update = torch.arange(3 * 5).to(torch.float).view(3, 5) self.run_test(IndexPutModel8(), (x, update)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) - @disableScriptTest() + @disableScriptTest() # Ellipses followed by tensor indexing not scriptable def test_index_put_ellipsis(self): class IndexPutModel(torch.nn.Module): def forward(self, x, update): @@ -1276,8 +1740,38 @@ def forward(self, x, update): update = torch.randn(4, 1, 3, 2) self.run_test(IndexPutModel2(), (x, update)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) - @disableScriptTest() + def test_index_put_loop(self): + @torch.jit.script + def ngram_attention_bias(sequence_length: int, ngram: int, device: torch.device, dtype: torch.dtype): + bias = torch.ones((ngram, sequence_length), device=device, dtype=dtype) * float("-inf") + for stream_idx in range(ngram): + for i in range(sequence_length): + bias[stream_idx, i] = 5 + return bias + + class ScriptModel(torch.nn.Module): + def __init__(self): + super(ScriptModel, self).__init__() + self.ngram = 2 + self.max_target_positions = 512 + + def forward(self, hidden_states): + seq_length, batch_size = hidden_states.shape[:2] + predict_causal_mask = ngram_attention_bias( + self.max_target_positions, self.ngram, hidden_states.device, hidden_states.dtype + ) + predict_causal_mask = predict_causal_mask[:, :seq_length] + return predict_causal_mask + + x = torch.randn(6, 2) + y = torch.randn(4, 1) + self.run_test(ScriptModel(), x, input_names=['x'], + dynamic_axes={'x': {0: 'seq_length', 1: 'batch_size'}}, test_with_inputs=[y]) + + @skipIfUnsupportedMinOpsetVersion(11) + @skipIfUnsupportedOpsetVersion([13]) def test_copy_(self): class CopyModel(torch.nn.Module): def forward(self, x, data): @@ -1319,9 +1813,6 @@ def forward(self, x, data): update = torch.randn(2) self.run_test(CopyModel3(), (x, update)) - update = torch.randn(1, 2) - self.run_test(CopyModel3(), (x, update)) - class CopyModel4(torch.nn.Module): def forward(self, x, ind, data): x[ind] = data @@ -1332,8 +1823,21 @@ def forward(self, x, ind, data): data = torch.randn(4) self.run_test(CopyModel4(), (x, ind, data)) + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(11) + @disableScriptTest() # Model not scriptable (output with shape doesn't match the broadcast shape) + def test_copy_tracing(self): + class CopyModel(torch.nn.Module): + def forward(self, x, data): + x[1, 1:3] = data + return x + + x = torch.randn(3, 4) + update = torch.randn(1, 2) + self.run_test(CopyModel(), (x, update)) + + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) - @disableScriptTest() def test_copy_ellipsis(self): class CopyModel(torch.nn.Module): def forward(self, x, update): @@ -1348,14 +1852,20 @@ def forward(self, x, update): update = torch.ones(1) self.run_test(CopyModel(), (x, update)) - class CopyModel2(torch.nn.Module): + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(11) + # TODO: Limited scripting support with ellipsis indexing. + # Due to dependency on input tensor rank being known. + def test_copy_ellipsis_tracing(self): + class CopyModel(torch.nn.Module): def forward(self, x, update): x[2, ..., 1:3] = update return x x = torch.randn(3, 4, 5, 6) + update = torch.ones(1) - self.run_test(CopyModel2(), (x, update)) + self.run_test(CopyModel(), (x, update)) @skipIfUnsupportedMinOpsetVersion(10) def test_flip(self): @@ -1381,7 +1891,7 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(Rand(), x) - @disableScriptTest() + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_random_dynamic_size(self): class RandN(torch.nn.Module): @@ -1415,7 +1925,6 @@ def forward(self, x): self.run_test(RandLike(), x) self.run_test(torch.jit.script(RandLike()), x) - @disableScriptTest() def test_random_like_dtype(self): class RandNLike(torch.nn.Module): def forward(self, x): @@ -1545,10 +2054,12 @@ def _interpolate_tests(self, is_upsample): self._interpolate_script(xi, mode_i, False, is_upsample, True) self._interpolate_script(xi, mode_i, False, is_upsample) + @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() def test_interpolate_upsample(self): self._interpolate_tests(True) + @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() @skipIfUnsupportedMinOpsetVersion(9) def test_interpolate_function_substitution(self): @@ -1579,11 +2090,13 @@ def forward(self, x): self.run_test(TracingModule(), (x,)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(10) @disableScriptTest() def test_interpolate_downsample(self): self._interpolate_tests(False) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) @disableScriptTest() def test_interpolate_no_shape(self): @@ -1599,7 +2112,16 @@ def forward(self, x, y): y = torch.randn(16, 16, requires_grad=True) self.run_test(MyModel(), (x, y)) - @disableScriptTest() + @skipIfUnsupportedOpsetVersion([13]) + def test_interpolate_adaptive_pooling_error(self): + x = torch.randn(1, 2, 6, requires_grad=True) + with self.assertRaises(RuntimeError) as cm: + self._interpolate(x, "area", True, True) + + with self.assertRaises(RuntimeError) as cm: + self._interpolate(x, "area", False, True) + + @skipIfUnsupportedOpsetVersion([13]) def test_groupnorm(self): model = torch.nn.GroupNorm(3, 6, 0.002) x = torch.randn(4, 6, 180, 180, 180) @@ -1613,6 +2135,7 @@ def test_groupnorm(self): x = torch.randn(4, 6, 180, 180) self.run_test(model, x) + @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() def test_groupnorm_noaffine(self): model = torch.nn.GroupNorm(4, 8, 0.002, affine=False) @@ -1627,14 +2150,26 @@ def test_groupnorm_noaffine(self): x = torch.randn(4, 6, 180, 180) self.run_test(model, x) - def test_std(self): - class StandardDeviation(torch.nn.Module): - def forward(self, input): - return torch.std(input, unbiased=False) + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(9) + def test_listunpack(self): + class ListUnpack(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + a, b = x.shape + return x.new_zeros((a, b)) - x = torch.randn(2, 3, 4) - model = StandardDeviation() - self.run_test(model, x) + x = torch.randn(2, 3) + self.run_test(ListUnpack(), x) + + class ListUnpackSlice(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + a, b = x.shape[2:] + return x.new_zeros((a, b)) + + x = torch.randn(2, 3, 4, 5) + self.run_test(ListUnpackSlice(), x) def test_pow(self): class PowModule(torch.nn.Module): @@ -1657,24 +2192,240 @@ def forward(self, x, y): y = torch.randint(10, (2, 3, 4)) self.run_test(PowModule(), (x, y)) - def test_std_along_dims(self): + def test_std(self): class StandardDeviation(torch.nn.Module): def forward(self, input): - return torch.std(input, dim=(0, 1), unbiased=False) + return torch.std(input, unbiased=False) x = torch.randn(2, 3, 4) model = StandardDeviation() self.run_test(model, x) - def test_std_keepdim(self): - class StandardDeviation(torch.nn.Module): + class StandardDeviationUnbiased(torch.nn.Module): def forward(self, input): - return torch.std(input, dim=(0, 1), unbiased=False, keepdim=True) + return torch.std(input, unbiased=True) - x = torch.randn(2, 3, 4) + model = StandardDeviationUnbiased() + self.run_test(model, x) + + def test_std_along_dims(self): + class StandardDeviation(torch.nn.Module): + def forward(self, input): + return torch.std(input, dim=(0, 1), unbiased=False) + + x = torch.randn(2, 3, 4) + model = StandardDeviation() + self.run_test(model, x) + + class StandardDeviationUnbiased(torch.nn.Module): + def forward(self, input): + return torch.std(input, dim=(0, 1), unbiased=True) + + x = torch.randn(2, 3, 4) + model = StandardDeviationUnbiased() + self.run_test(model, x) + + def test_std_keepdim(self): + class StandardDeviation(torch.nn.Module): + def forward(self, input): + return torch.std(input, dim=(0, 1), unbiased=False, keepdim=True) + + x = torch.randn(2, 3, 4) + model = StandardDeviation() + self.run_test(model, x) + + class StandardDeviationUnbiased(torch.nn.Module): + def forward(self, input): + return torch.std(input, dim=(0, 1), unbiased=True, keepdim=True) + + x = torch.randn(2, 3, 4) + model = StandardDeviationUnbiased() + self.run_test(model, x) + + def test_var(self): + class Variance(torch.nn.Module): + def forward(self, input): + return torch.var(input, unbiased=False) + + x = torch.randn(2, 3, 4) + model = Variance() + self.run_test(model, x) + + class VarianceUnbiased(torch.nn.Module): + def forward(self, input): + return torch.var(input, unbiased=True) + + model = VarianceUnbiased() + self.run_test(model, x) + + class VarianceSqrt(torch.nn.Module): + def forward(self, input): + y = torch.var(input, 1) + return torch.sqrt(y + 1e-8) + + x = torch.randn(1, 2, 3, 300, 300) + model = VarianceSqrt() + self.run_test(model, x) + + def test_var_along_dims(self): + class Variance(torch.nn.Module): + def forward(self, input): + return torch.var(input, dim=(0, 1), unbiased=False) + + x = torch.randn(2, 3, 4) + model = Variance() + self.run_test(model, x) + + class VarianceUnbiased(torch.nn.Module): + def forward(self, input): + return torch.var(input, dim=(0, 1), unbiased=True) + + x = torch.randn(2, 3, 4) + model = VarianceUnbiased() + self.run_test(model, x) + + def test_var_keepdim(self): + class Variance(torch.nn.Module): + def forward(self, input): + return torch.var(input, dim=(0, 1), unbiased=False, keepdim=True) + + x = torch.randn(2, 3, 4) + model = Variance() + self.run_test(model, x) + + class VarianceUnbiased(torch.nn.Module): + def forward(self, input): + return torch.var(input, dim=(0, 1), unbiased=True, keepdim=True) + + x = torch.randn(2, 3, 4) + model = VarianceUnbiased() + self.run_test(model, x) + + def test_var_mean(self): + class Variance(torch.nn.Module): + def forward(self, input): + return torch.var_mean(input, unbiased=False) + + x = torch.randn(2, 3, 4) + model = Variance() + self.run_test(model, x) + + class VarianceUnbiased(torch.nn.Module): + def forward(self, input): + return torch.var_mean(input, unbiased=True) + + model = VarianceUnbiased() + self.run_test(model, x) + + def test_var_mean_along_dims(self): + class Variance(torch.nn.Module): + def forward(self, input): + return torch.var_mean(input, dim=(0, 1), unbiased=False) + + x = torch.randn(2, 3, 4) + model = Variance() + self.run_test(model, x) + + class VarianceUnbiased(torch.nn.Module): + def forward(self, input): + return torch.var_mean(input, dim=(0, 1), unbiased=True) + + x = torch.randn(2, 3, 4) + model = VarianceUnbiased() + self.run_test(model, x) + + def test_var_mean_mixed_dims(self): + class ReverseDims(torch.nn.Module): + def forward(self, input): + return torch.var_mean(input, dim=(2, 1), unbiased=False) + + x = torch.randn(2, 3, 4) + model = ReverseDims() + self.run_test(model, x) + + class SkipDims(torch.nn.Module): + def forward(self, input): + return torch.var_mean(input, dim=(0, 2), unbiased=False) + + x = torch.randn(2, 3, 4) + model = SkipDims() + self.run_test(model, x) + + class NonZeroDims(torch.nn.Module): + def forward(self, input): + return torch.var_mean(input, dim=(1, 2), unbiased=False) + + x = torch.randn(2, 3, 4) + model = NonZeroDims() + self.run_test(model, x) + + def test_var_mean_keepdim(self): + class Variance(torch.nn.Module): + def forward(self, input): + return torch.var_mean(input, dim=(0, 1), unbiased=False, keepdim=True) + + x = torch.randn(2, 3, 4) + model = Variance() + self.run_test(model, x) + + class VarianceUnbiased(torch.nn.Module): + def forward(self, input): + return torch.var_mean(input, dim=(0, 1), unbiased=True, keepdim=True) + + x = torch.randn(2, 3, 4) + model = VarianceUnbiased() + self.run_test(model, x) + + def test_std_mean(self): + class StandardDeviation(torch.nn.Module): + def forward(self, input): + return torch.std_mean(input, unbiased=False) + + x = torch.randn(2, 3, 4) + model = StandardDeviation() + self.run_test(model, x) + + class StandardDeviationUnbiased(torch.nn.Module): + def forward(self, input): + return torch.std_mean(input, unbiased=True) + + model = StandardDeviationUnbiased() + self.run_test(model, x) + + def test_std_mean_along_dims(self): + class StandardDeviation(torch.nn.Module): + def forward(self, input): + return torch.std_mean(input, dim=(0, 1), unbiased=False) + + x = torch.randn(2, 3, 4) + model = StandardDeviation() + self.run_test(model, x) + + class VarianceUnbiased(torch.nn.Module): + def forward(self, input): + return torch.std_mean(input, dim=(0, 1), unbiased=True) + + x = torch.randn(2, 3, 4) + model = VarianceUnbiased() + self.run_test(model, x) + + def test_std_mean_keepdim(self): + class StandardDeviation(torch.nn.Module): + def forward(self, input): + return torch.std_mean(input, dim=(0, 1), unbiased=False, keepdim=True) + + x = torch.randn(2, 3, 4) model = StandardDeviation() self.run_test(model, x) + class StandardDeviationUnbiased(torch.nn.Module): + def forward(self, input): + return torch.std_mean(input, dim=(0, 1), unbiased=True, keepdim=True) + + x = torch.randn(2, 3, 4) + model = StandardDeviationUnbiased() + self.run_test(model, x) + def test_bitshift(self): class BitshiftModel(torch.nn.Module): def forward(self, input, input2): @@ -1703,6 +2454,7 @@ def forward(self, input, input2): input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2) self.run_test(BitshiftModel(), (input, input2)) + @skipIfUnsupportedOpsetVersion([13]) def test_narrow(self): class NarrowModel(torch.nn.Module): def forward(self, input): @@ -1711,6 +2463,17 @@ def forward(self, input): x = torch.randn(3, 3, requires_grad=True) self.run_test(NarrowModel(), x) + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(11) + def test_narrow_dynamic(self): + class NarrowModel(torch.nn.Module): + def forward(self, input): + return torch.narrow(input, 0, 0, input.shape[0] - 1) + + x = torch.randn(3, 3, requires_grad=True) + self.run_test(NarrowModel(), x) + + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_index_fill(self): class IndexFillModel(torch.nn.Module): @@ -1721,6 +2484,7 @@ def forward(self, input): x = torch.randn(3, 4, 5, requires_grad=True) self.run_test(IndexFillModel(), x) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_index_copy(self): class IndexCopyModel(torch.nn.Module): @@ -1748,8 +2512,6 @@ def forward(self, x): x = torch.randn(3, 4) self.run_test(Select(), x) - # TODO: enable for opset 10 when ONNXRuntime version will be updated - def test_index_select_constant_scaler_index(self): class IndexSelectScalerIndexModel(torch.nn.Module): def forward(self, x): @@ -1758,7 +2520,6 @@ def forward(self, x): x = torch.randn(3, 4) self.run_test(IndexSelectScalerIndexModel(), x) - @disableScriptTest() def test_index_select_scaler_index(self): class IndexSelectScalerIndexModel(torch.nn.Module): def __init__(self, index_base): @@ -1817,7 +2578,6 @@ def forward(self, x, k): self.run_test(MyModuleDynamic(), [x, k]) @skipIfUnsupportedOpsetVersion([7]) - @disableScriptTest() def test_normalize(self): class Model(torch.nn.Module): def forward(self, x): @@ -1847,6 +2607,14 @@ def test_batchnorm1d_noaffine(self): x = torch.randn(10, 10, 128) self.run_test(model, x) + def test_batchnorm1d_norunningstats(self): + x = torch.randn(10, 10) + model = torch.nn.BatchNorm1d(10, track_running_stats=False) + self.run_test(model, x) + + x = torch.randn(10, 10, 128) + self.run_test(model, x) + def test_batchnorm2d(self): x = torch.randn(10, 3, 128, 128) model = torch.nn.BatchNorm2d(3, affine=True) @@ -1857,6 +2625,11 @@ def test_batchnorm2d_noaffine(self): model = torch.nn.BatchNorm2d(3, affine=False) self.run_test(model, x) + def test_batchnorm2d_norunningstats(self): + x = torch.randn(10, 3, 128, 128) + model = torch.nn.BatchNorm2d(3, track_running_stats=False) + self.run_test(model, x) + def test_batchnorm3d(self): x = torch.randn(10, 3, 128, 128, 128) model = torch.nn.BatchNorm3d(3, affine=True) @@ -1953,8 +2726,8 @@ def forward(self, input, indices): indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64) self.run_test(GatherModel(), input=(input, indices)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) - @disableScriptTest() def test_expand(self): class ExpandModel(torch.nn.Module): def forward(self, input): @@ -1975,7 +2748,7 @@ def forward(self, input, size): return input.expand(size) input = torch.randn(3,) - size = torch.tensor([-1]) + size = torch.tensor(-1) self.run_test(ExpandTensorSizeModel(), input=(input, size)) def test_multinomial(self): @@ -2025,6 +2798,7 @@ def forward(self, input): x = torch.randn(4, 5, dtype=torch.float) self.run_test(ReducedOpModule(), x) + @skipIfUnsupportedOpsetVersion([13]) def test_reduced_sum(self): return self._test_reduced_ops(op=torch.sum) @@ -2104,7 +2878,17 @@ def test_logsoftmax_dim(self): input = torch.randn(3, 4, 5, 6) self.run_test(model, input) + def test_logsoftmax_dtype(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.log_softmax(x, dim=1, dtype=torch.float64) + + x = torch.randn(3, 4, 5, requires_grad=True) + self.run_test(Model(), x) + + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) + @disableScriptTest() # scripting prim_dtype def test_lstm_no_hidden(self): class LSTMModel(torch.nn.Module): def __init__(self): @@ -2117,7 +2901,24 @@ def forward(self, x): input = torch.randn((10, 16, 16)) self.run_test(LSTMModel(), (input,)) + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(9) + @disableScriptTest() # scripting prim_dtype + def test_lstm_proj_no_hidden(self): + class LSTMModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.rnn = torch.nn.LSTM(input_size=16, hidden_size=16, proj_size=8) + + def forward(self, x): + return self.rnn(x) + + input = torch.randn((10, 16, 16)) + with self.assertRaises(RuntimeError): + self.run_test(LSTMModel(), (input,)) + @skipIfUnsupportedMinOpsetVersion(9) + @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() def test_lstm(self): model = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False) @@ -2126,6 +2927,7 @@ def test_lstm(self): c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE) self.run_test(model, (input, (h0, c0))) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) @disableScriptTest() def test_lstm_default_init_state(self): @@ -2133,8 +2935,9 @@ def test_lstm_default_init_state(self): input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE) self.run_test(model, input) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) - @disableScriptTest() + @disableScriptTest() # LSTMModel model not scriptable def test_lstm_fixed_batch_size(self): class LSTMModel(torch.nn.Module): def __init__(self): @@ -2154,6 +2957,7 @@ def forward(self, input): input2 = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE) self.run_test(LSTMModel(), input, fixed_batch_size=True, test_with_inputs=[input2]) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) @disableScriptTest() def test_lstm_post_fix_init_state(self): @@ -2178,6 +2982,7 @@ def forward(self, input): self.run_test(model, input, dynamic_axes={'input' : {0 : 'seq', 1 : 'batch'}}, test_with_inputs=[input2]) + @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() def test_lstm_constant_folding(self): class LstmNet(torch.nn.Module): @@ -2205,6 +3010,7 @@ def get_LstmNet_model_and_inputs(input_size, hidden_size, num_layers, batch_size model2, input2 = get_LstmNet_model_and_inputs(5, 4, 3, batch_size2, 7, False) self.run_test(model2, input2, do_constant_folding=True) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) @disableScriptTest() def test_lstm_no_bias(self): @@ -2230,6 +3036,7 @@ def get_LstmNet_model_and_inputs(num_layers, bidirectional): for model, input in models_and_inputs: self.run_test(model, input) + @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() def test_rnn_no_bias(self): def make_model(layers, packed_sequence): @@ -2269,6 +3076,7 @@ def make_input(batch_size, layers, packed_sequence): for model, input in zip(models, inputs): self.run_test(model, input, batch_size=RNN_BATCH_SIZE) + @skipIfUnsupportedOpsetVersion([13]) def test_gru_no_bias(self): class GruNet(torch.nn.Module): def __init__(self, input_size, hidden_size, num_layers, bidirectional): @@ -2298,6 +3106,7 @@ def get_GruNet_model_and_inputs(input_size, hidden_size, num_layers, batch_size, for model, input in models_and_inputs: self.run_test(model, input, do_constant_folding=True) + @skipIfUnsupportedOpsetVersion([13]) def test_gru_constant_folding(self): class GruNet(torch.nn.Module): def __init__(self, input_size, hidden_size, num_layers, bidirectional): @@ -2560,6 +3369,7 @@ def test_argmin_argmax_select_last_index(self): input = torch.ones(7, 3, 5) self._argmin_argmax_model(input) + @skipIfUnsupportedOpsetVersion([13]) def test_repeat(self): class RepeatModel(torch.nn.Module): def forward(self, x, y): @@ -2579,6 +3389,7 @@ def forward(self, input): x = torch.randint(10, (4, 2, 3, 4), dtype=torch.int32) self.run_test(ViewModel(), x) + @skipIfUnsupportedOpsetVersion([13]) def test_view_dynamic(self): class ViewModel(torch.nn.Module): def forward(self, input, other): @@ -2588,6 +3399,18 @@ def forward(self, input, other): shape = torch.randn(6, 4) self.run_test(ViewModel(), (x, shape)) + @skipIfUnsupportedOpsetVersion([13]) + def test_view_dynamic_zero_dim(self): + class ViewModel(torch.nn.Module): + def forward(self, input): + input = input.view(-1, 2) + return input.view(1, -1) + + x = torch.ones(2) + another_x = torch.empty((0,)) + self.run_test(ViewModel(), x, test_with_inputs=[another_x], + input_names=['input_1'], dynamic_axes={'input_1': [0, ]}) + def test_view_as(self): class ViewModel(torch.nn.Module): def forward(self, input, other): @@ -2597,12 +3420,17 @@ def forward(self, input, other): y = torch.randn(6, 4) self.run_test(ViewModel(), (x, y)) - @disableScriptTest() def test_weight_norm(self): + # addmm for 3-d inputs converts to onnx::MatMul model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=1) x = torch.randn(3, 4, 5, requires_grad=True) self.run_test(model, x) + # addmm for 2-d inputs converts to onnx::Gemm + model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=1) + x = torch.randn(4, 5, requires_grad=True) + self.run_test(model, x) + model = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, 3)) x = torch.randn(1, 1, 5, requires_grad=True) self.run_test(model, x) @@ -2615,12 +3443,17 @@ def test_weight_norm(self): x = torch.randn(3, 3, 5, requires_grad=True) self.run_test(model, x) - @disableScriptTest() def test_weight_norm_nodim(self): + # addmm for 3-d inputs converts to onnx::MatMul model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=None) x = torch.randn(3, 4, 5, requires_grad=True) self.run_test(model, x) + # addmm for 2-d inputs converts to onnx::Gemm + model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=None) + x = torch.randn(4, 5, requires_grad=True) + self.run_test(model, x) + def test_flatten(self): class FlattenModel(torch.nn.Module): def forward(self, input): @@ -2645,7 +3478,6 @@ def forward(self, x): x = torch.randint(10, (1, 2, 3, 4)) self.run_test(FlattenModel(), x) - @disableScriptTest() @skipIfUnsupportedMinOpsetVersion(9) def test_flatten_dynamic_axes(self): class MyModule(torch.nn.Module): @@ -2680,11 +3512,22 @@ def forward(self, x, y, z, ind): ind = torch.tensor(-2, dtype=torch.long) self.run_test(GetItemModel(), (x, y, z, ind)) - @disableScriptTest() + @skipIfUnsupportedOpsetVersion([13]) + @disableScriptTest() # torch.nonzero(x, as_tuple=True) is not scriptable. + @skipIfUnsupportedMinOpsetVersion(9) + def test_nonzero(self): + class NonzeroModel(torch.nn.Module): + def forward(self, x): + return x.nonzero(), x.nonzero(as_tuple=True) + + x = torch.randn(60).index_fill_(0, torch.randint(0, 60, (20,)), 0).view(3, 4, 5) + self.run_test(NonzeroModel(), (x,)) + def test_unbind(self): class UnbindModel(torch.nn.Module): def forward(self, input): - return input.unbind() + _, out, _ = input.unbind() + return out x = torch.randn(3, 4, 5) self.run_test(UnbindModel(), x) @@ -2716,12 +3559,13 @@ def forward(self, input): self.run_test(LenModel(), x, input_names=['input'], dynamic_axes={'input': {0: 'seq'}}, test_with_inputs=(torch.randn(5, 5),)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_len_list(self): class LenListModel(torch.jit.ScriptModule): @torch.jit.script_method def forward(self, input): - return torch.ones(len(input.shape)) + return torch.ones(len(input.shape)) x = torch.randn(4, 5) self.run_test(LenListModel(), x) @@ -2744,18 +3588,19 @@ def forward(self, input): x = torch.randn(3, 4, 5) self.run_test(UnbindModel2(), x) - @disableScriptTest() def test_split(self): class SplitModel(torch.nn.Module): def forward(self, input): - return input.split([2, 1, 2]) + out1, out2, out3 = input.split([2, 1, 2]) + return out1, out2, out3 x = torch.randn(5, 4, 3) self.run_test(SplitModel(), x) class SplitModel2(torch.nn.Module): def forward(self, input): - return input.split([2, 1, 1], -2) + out1, out2, out3 = input.split([2, 1, 1], -2) + return out1, out2, out3 x = torch.randn(5, 4, 3) self.run_test(SplitModel2(), x) @@ -2768,22 +3613,26 @@ def forward(self, input): x = torch.randn(5, 4, 3) self.run_test(torch.jit.script(SplitModel3()), x) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) @disableScriptTest() def test_split_size_as_list(self): class SplitModel(torch.nn.Module): - def forward(self, input): + def forward(self, input, split_sizes: List[int]): out = [] - split_sizes = [input.shape[0] - 1, 1] - for ob in input.split(split_sizes): + split_list: List[torch.Tensor] = input.split(split_sizes) + + for ob in split_list: out.append(ob) return torch.cat(out, dim=0) - x = torch.randn(5, 4, 3) - self.run_test(SplitModel(), x) + x = torch.randn(6, 4, 3) + split_sizes = [torch.tensor(2), torch.tensor(4)] + self.run_test(SplitModel(), (x, split_sizes)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) - def test_split_size_list_to_slice(self): + def test_split_size_with_slice(self): class SplitModule(torch.nn.Module): def forward(self, x, y, t): splits = (x.size(1), y.size(1)) @@ -2813,6 +3662,25 @@ def forward(self, input): x = torch.randn(5, 4, 3) self.run_test(SplitModel2(), x) + @skipIfUnsupportedMinOpsetVersion(11) + def test_chunk(self): + class ChunkModel(torch.nn.Module): + def __init__(self): + super(ChunkModel, self).__init__() + + def forward(self, x): + return torch.chunk(x, 3, dim=1) + + model = ChunkModel() + model.eval() + x = torch.randn(1, 18) + + for dim_size_ in range(13, 16): + y = torch.randn(1, dim_size_) + self.run_test(model, x, test_with_inputs=[y], + input_names=['x'], + dynamic_axes={'x': {0: 'batch_size', 1: 'dims'}}) + def test_concat(self): class ConcatModel(torch.nn.Module): def forward(self, x, y, z): @@ -2833,6 +3701,7 @@ def forward(self, x): x = torch.randn(4, 5, 6) self.run_test(ConcatDynamicModel(), x) + @skipIfUnsupportedOpsetVersion([13]) def test_stack(self): class StackModel(torch.nn.Module): def forward(self, x, y, z): @@ -2843,6 +3712,7 @@ def forward(self, x, y, z): z = torch.randn(3, 4, 5) self.run_test(StackModel(), (x, y, z)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_stack_dynamic(self): class StackDynamicModel(torch.jit.ScriptModule): @@ -2881,6 +3751,7 @@ def forward(self, x): inputs = torch.zeros(1, 2, 3, dtype=torch.long) self.run_test(model, inputs) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_loop_with_list(self): class ListLoopModel(torch.jit.ScriptModule): @@ -2918,6 +3789,22 @@ def forward(self, x): x = torch.randn(5, 3, 3) self.run_test(model, x) + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(11) + def test_loop_multi_dim(self): + class LoopMultiDimModel(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x, y): + for x_ in torch.flip(x.narrow(0, 0, 7), [0]): + y = x_[0][y] + return y + + model = LoopMultiDimModel() + x = torch.randint(0, 5, (8, 1, 17), dtype=torch.long) + y = torch.ones(1, dtype=torch.long) + self.run_test(model, (x, y)) + + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_list(self): class ListModel(torch.jit.ScriptModule): @@ -2939,6 +3826,7 @@ def forward(self, x): inputs = torch.randn(16, 1) self.run_test(model, inputs) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_tensor_factories(self): class TensorFactory(torch.nn.Module): @@ -2948,6 +3836,7 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(TensorFactory(), x) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_tensor_factories_script(self): class TensorFactory(torch.jit.ScriptModule): @@ -2958,6 +3847,7 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(TensorFactory(), x) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_tensor_like_factories_script(self): class TensorFactory(torch.jit.ScriptModule): @@ -2970,11 +3860,14 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(TensorFactory(), x) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_eye(self): class TensorFactory(torch.nn.Module): def forward(self, x): - return torch.eye(x.size()[1], 3), torch.eye(4, 4, dtype=torch.long), torch.eye(x.size()[1], 2, dtype=torch.long) + return torch.eye(x.size()[1], 3), torch.eye(4, 4, dtype=torch.long), \ + torch.eye(x.size()[1], 2, dtype=torch.long), torch.eye(x.shape[0]), \ + torch.eye(x.shape[0], dtype=torch.float64) x = torch.randn(2, 3, 4) another_x = torch.randn(5, 6, 7) @@ -2990,8 +3883,8 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(Zero_(), x) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) - @disableScriptTest() def test_new_zeros(self): class Zero_(torch.nn.Module): def forward(self, x): @@ -3000,7 +3893,23 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(Zero_(), x) + @skipIfONNXShapeInference(True) @skipIfUnsupportedMinOpsetVersion(9) + def test_tolist(self): + class List(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, input): + cur_shape = torch._shape_as_tensor(input) + final_shape: List[int] = cur_shape.tolist() + pad_tensor = torch.zeros([1, 2] + final_shape) + return pad_tensor + + x = torch.randn(2, 3) + self.run_test(List(), (x,)) + + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(9) + @disableScriptTest() def test_list_pass(self): class Slice(torch.nn.Module): def forward(self, x, y): @@ -3038,6 +3947,7 @@ def forward(self, x, y): y = torch.randn(1, 2, 3) self.run_test(List(), (x, y)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_new_empty(self): class Emtpy(torch.nn.Module): @@ -3047,6 +3957,7 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(Emtpy(), x) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_new_full(self): class Full(torch.nn.Module): @@ -3056,6 +3967,18 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(Full(), x) + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(9) + def test_inplace_list(self): + class Arithmetic(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x, y): + return torch.cat([x.add_(3), y.fill_(0)]) + + x = torch.randn(2, 3) + y = torch.randn(2, 3) + self.run_test(Arithmetic(), (x, y)) + @skipIfUnsupportedMinOpsetVersion(9) def test_inplace_fill(self): class Fill_(torch.nn.Module): @@ -3119,6 +4042,29 @@ def forward(self, x): x = torch.arange(16).view(2, 2, 4).to(torch.float32) self.run_test(MaskedFillModel2(), x) + @skipIfUnsupportedMinOpsetVersion(9) + def test_masked_fill_inplace(self): + + class MaskedFillModel(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.uint8) + x.masked_fill_(mask, 2) + return x + + x = torch.zeros(4, 2, 3, requires_grad=True) + self.run_test(MaskedFillModel(), x) + + class MaskedFillModel2(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + x.masked_fill_(x > 3, -1) + return x + + x = torch.arange(16).view(2, 2, 4).to(torch.float32) + self.run_test(MaskedFillModel2(), x) + + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_masked_scatter(self): class MaskedScatterModel(torch.nn.Module): @@ -3137,6 +4083,35 @@ def forward(self, x): x = torch.randn(3, 4, 5, requires_grad=True) self.run_test(MaskedSelectModel(), x) + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(11) + @disableScriptTest() # dtype not available + def test_index_put_to_masked_fill(self): + class MaskedFillModel(torch.nn.Module): + def forward(self, input_mask, some_const): + mask = input_mask.clone() + mask[mask != some_const] = 1 + mask[mask == some_const] = 0 + return mask + + mask = torch.randn(2, 2, 2, requires_grad=True) + constant = torch.tensor(5, dtype=torch.float) + self.run_test(MaskedFillModel(), (mask, constant)) + + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(11) + @disableScriptTest() # dtype not available + def test_index_put_to_masked_scatter(self): + class MaskedScatterModel(torch.nn.Module): + def forward(self, input_mask, some_const): + mask = input_mask.clone() + mask[mask != some_const] = torch.ones(8) + return mask + + mask = torch.randn(2, 2, 2, requires_grad=True) + constant = torch.tensor(5, dtype=torch.float) + self.run_test(MaskedScatterModel(), (mask, constant)) + @skipIfUnsupportedMinOpsetVersion(9) def test_pixel_shuffle(self): class PixelShuffle(torch.nn.Module): @@ -3147,11 +4122,10 @@ def forward(self, x): self.run_test(PixelShuffle(), x) @skipIfUnsupportedMinOpsetVersion(9) - @disableScriptTest() def test_scalar_type(self): class ArithmeticModel(torch.nn.Module): def forward(self, x): - return x.size(0) * 2 * x + return x.size(0) * 2 * x, 2 - x x = torch.ones(2, 3, dtype=torch.float32) self.run_test(ArithmeticModel(), x) @@ -3194,7 +4168,6 @@ def forward(self, x): self.run_test(FullModel(), x) @skipIfUnsupportedMinOpsetVersion(9) - @disableScriptTest() def test_full_like(self): class FullLikeModel(torch.nn.Module): def forward(self, x): @@ -3204,7 +4177,6 @@ def forward(self, x): self.run_test(FullLikeModel(), x) @skipIfUnsupportedMinOpsetVersion(9) - @disableScriptTest() def test_full_like_value(self): class FullLikeModel(torch.nn.Module): def forward(self, x, y): @@ -3247,14 +4219,20 @@ def forward(self, x): x = torch.randn(4, 2, 3, requires_grad=True) self.run_test(NormModel(), x) + @skipIfUnsupportedOpsetVersion([13]) def test_unfold(self): class UnfoldModel(torch.nn.Module): def forward(self, x): return x.unfold(dimension=2, size=2, step=2) x = torch.randn(4, 2, 3, requires_grad=True) - self.run_test(UnfoldModel(), x) + y = torch.randn(2, 1, 3, requires_grad=True) + self.run_test(UnfoldModel(), x, + dynamic_axes={'x': [0, 1]}, + input_names=['x'], + test_with_inputs=[y]) + @skipIfUnsupportedOpsetVersion([13]) @skipIfONNXShapeInference(False) def test_unfold_infer_shape(self): class UnfoldModule(torch.jit.ScriptModule): @@ -3270,6 +4248,32 @@ def forward(self, x): x = torch.randn(32, 3, 64) self.run_test(UnfoldModule(), x) + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(12) + def test_unfold_dynamic_inputs(self): + class UnfoldModel(torch.nn.Module): + def forward(self, x): + return x.unfold(dimension=2, size=x.shape[1], step=x.shape[1] - 1) + + x = torch.randn(4, 2, 4, requires_grad=True) + self.run_test(UnfoldModel(), x) + + @skipIfUnsupportedOpsetVersion([13]) + def test_prelu(self): + class PReluModel(torch.nn.Module): + def __init__(self): + super(PReluModel, self).__init__() + self.prelu = torch.nn.PReLU() + + def forward(self, x): + return self.prelu(x) + + x = torch.randn(2, 3, 4) + y = torch.randn(2, 4, 5) + self.run_test(PReluModel(), x, input_names=['x'], + dynamic_axes={'x': [1, 2]}, + test_with_inputs=[y]) + def test_remainder(self): class RemainderModel(torch.nn.Module): def forward(self, input, other): @@ -3306,6 +4310,16 @@ def forward(self, input): x = torch.randint(10, (2, 3)) self.run_test(FModModel(), x) + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(9) + def test_glu(self): + class GluModel(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.glu(x) + + x = torch.randn(2, 4, 5, 6, requires_grad=True) + self.run_test(GluModel(), x) + @skipIfUnsupportedMinOpsetVersion(9) def test_gelu(self): class GeluModel(torch.nn.Module): @@ -3378,28 +4392,9 @@ def forward(self, input): x = torch.tensor([False, True, True]) self.run_test(model, x) - @unittest.skip("Enable once jit trace Tensor.numel as constant is fixed.") - def test_embedding_bag_dynamic(self): - class EmbeddingModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.embeddingbag = torch.nn.EmbeddingBag(40, 12, mode='sum') - - def forward(self, input): - return self.embeddingbag(input) - - model = EmbeddingModel() - x = torch.randint(7, (10, 5)) - y = torch.randint(10, (20, 5)) - self.run_test(model, x, test_with_inputs=[y], - input_names=['input'], - output_names=['output'], - dynamic_axes={'input': [0], - 'output': [0] - }) - - @disableScriptTest() + @disableScriptTest() # error in propagate as assign input shape @skipIfUnsupportedMinOpsetVersion(10) + @skipIfUnsupportedOpsetVersion([12, 13]) # Due to ONNX Loop shape inference issue def test_embedding_bag(self): model = torch.nn.EmbeddingBag(10, 5, mode='sum', scale_grad_by_freq=True) input = torch.randint(10, (7,)) @@ -3415,27 +4410,27 @@ def test_embedding_bag(self): input = torch.randint(10, (7, 5)) self.run_test(model, (input)) - @disableScriptTest() - @skipIfUnsupportedMinOpsetVersion(10) + @skipIfUnsupportedMinOpsetVersion(11) + @skipIfUnsupportedOpsetVersion([12, 13]) # Due to ONNX Loop shape inference issue def test_embedding_bag_1d_per_sample_weights(self): class EmbeddingModel(torch.nn.Module): def forward(self, embedding_matrix, input, offset, weights): - return torch.nn.functional.embedding_bag(embedding_matrix, input, offsets=offset, + return torch.nn.functional.embedding_bag(input, embedding_matrix, offsets=offset, mode='sum', per_sample_weights=weights) model = EmbeddingModel() x = torch.randint(7, (6,)) - w = torch.randn(6,) + w = torch.randn(6, ) offset = torch.tensor([0, 2, 5]) embedding_matrix = torch.rand(10, 15) self.run_test(model, (embedding_matrix, x, offset, w)) - @disableScriptTest() - @skipIfUnsupportedMinOpsetVersion(10) + @skipIfUnsupportedMinOpsetVersion(11) + @skipIfUnsupportedOpsetVersion([12, 13]) # Due to ONNX Loop shape inference issue def test_embedding_bag_2d_per_sample_weights(self): class EmbeddingModel(torch.nn.Module): def forward(self, embedding_matrix, input, weights): - return torch.nn.functional.embedding_bag(embedding_matrix, input, + return torch.nn.functional.embedding_bag(input, embedding_matrix, mode='sum', per_sample_weights=weights) embedding_matrix = torch.rand(10, 15) @@ -3444,12 +4439,52 @@ def forward(self, embedding_matrix, input, weights): w = torch.randn(2, 3) self.run_test(model, (embedding_matrix, x, w)) + @disableScriptTest() # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast + @skipIfUnsupportedMinOpsetVersion(11) + @unittest.skip("Due to ONNX Loop shape inference issue.") + def test_embedding_bag_dynamic_input(self): + class EmbeddingModel1D(torch.nn.Module): + def forward(self, embedding_matrix, input, weights, offsets): + return torch.nn.functional.embedding_bag(input, embedding_matrix, offsets=offsets, + mode='sum', per_sample_weights=weights) + + model = EmbeddingModel1D() + x = torch.randint(7, (6,)) + w = torch.randn(6, ) + offsets = torch.tensor([0, 2, 5], dtype=torch.long) + embedding_matrix = torch.rand(10, 15) + x2 = torch.randint(7, (2,)) + w2 = torch.randn(2, ) + embedding_matrix2 = torch.rand(12, 25) + offsets2 = torch.tensor([0, ], dtype=torch.long) + self.run_test(model, (embedding_matrix, x, w, offsets), + test_with_inputs=[(embedding_matrix2, x2, w2, offsets2)], + input_names=['embedding_matrix', 'x', 'offsets', 'w'], + dynamic_axes={'embedding_matrix': [0, 1], 'x': [0], 'offsets': [0], 'w': [0]}) + + class EmbeddingModel2D(torch.nn.Module): + def forward(self, embedding_matrix, input, weights): + return torch.nn.functional.embedding_bag(input, embedding_matrix, + mode='sum', per_sample_weights=weights) + + model = EmbeddingModel2D() + x = torch.randint(7, (2, 3)) + w = torch.randn(2, 3) + embedding_matrix = torch.rand(10, 15) + x2 = torch.randint(7, (3, 5)) + w2 = torch.randn(3, 5) + embedding_matrix2 = torch.rand(12, 25) + self.run_test(model, (embedding_matrix, x, w), + test_with_inputs=[(embedding_matrix2, x2, w2)], + input_names=['embedding_matrix', 'x', 'w'], + dynamic_axes={'embedding_matrix': [0, 1], 'x': [0, 1], 'w': [0, 1]}) + @skipIfUnsupportedMinOpsetVersion(8) - @disableScriptTest() def test_meshgrid(self): class Meshgrid(torch.nn.Module): def forward(self, x, y, z): - return torch.meshgrid(x, y, z) + output1, output2, output3 = torch.meshgrid(x, y, z) + return output1, output2, output3 x = torch.randn(3, requires_grad=True) y = torch.zeros(4, requires_grad=True) @@ -3457,11 +4492,11 @@ def forward(self, x, y, z): self.run_test(Meshgrid(), (x, y, z)) @skipIfUnsupportedMinOpsetVersion(8) - @disableScriptTest() def test_meshgrid_scalar(self): class Meshgrid(torch.nn.Module): def forward(self, x, y, z): - return torch.meshgrid(x, y, z) + output1, output2, output3 = torch.meshgrid(x, y, z) + return output1, output2, output3 x = torch.ones(3, requires_grad=True) y = torch.zeros(4, requires_grad=True) @@ -3510,6 +4545,28 @@ def forward(self, input): model = MyModule() self.run_test(model, (x,)) + def test_dtype(self): + class MyModel(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, input, other): + return input.to(dtype=other.dtype) + other + + x = torch.randn(2, 3) + y = torch.randn(2, 3) + self.run_test(MyModel(), (x, y)) + + def test_dtype_eq(self): + class MyModel(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, input, other): + if input.dtype == other.dtype: + return input + other + return input + + x = torch.randn(2, 3) + y = torch.randn(2, 3) + self.run_test(MyModel(), (x, y)) + def test_cast_to(self): class MyModule(torch.jit.ScriptModule): @torch.jit.script_method @@ -3531,8 +4588,8 @@ def forward(self, input, other): model = MyModule() self.run_test(model, (x, y)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) - @disableScriptTest() def test_ones_bool(self): class MyModule(torch.nn.Module): def forward(self, input): @@ -3578,8 +4635,9 @@ def test_constant_pad(self): self.run_test(model, x) # Dynamic padding is added in opset 11 + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) - @disableScriptTest() + @disableScriptTest() # Functional module not scriptable def test_pad_types(self): # Test for different pad integer types class Pad(torch.nn.Module): @@ -3613,7 +4671,157 @@ def run(): self.assertEqual('Unsupported: ONNX export of Pad in opset 9. The sizes of the padding must be constant. ' + 'Please try opset version 11.', the_exception.args[0]) - @disableScriptTest() + @skipIfUnsupportedMinOpsetVersion(9) + def test_if_fold(self): + class IfFoldModel(torch.nn.Module): + def forward(self, y): + if y.dim() == 2: + y = y + 4 + y = y + 2 + else: + y = y - 1 + return y + x = torch.ones((3, 4), dtype=torch.int) + self.run_test(IfFoldModel(), x) + + class IfFoldModel(torch.nn.Module): + def forward(self, y): + if y.numel() > 1: + y = y + 4 + else: + y = y + 2 + return y + + x = torch.ones((3, 4), dtype=torch.int) + self.run_test(IfFoldModel(), x) + + class IfFoldModel(torch.nn.Module): + def forward(self, y): + if y.dim() != 3: + y = y + 4 + y = y + 2 + else: + return y + return y + + x = torch.ones((3, 4), dtype=torch.int) + self.run_test(IfFoldModel(), x) + + class IfFoldModel(torch.nn.Module): + def forward(self, y): + if y.dim() >= 1: + y = y + 4 + else: + y = y - 1 + return y + + x = torch.ones((3, 4), dtype=torch.int) + self.run_test(IfFoldModel(), x) + + class IfFoldModel(torch.nn.Module): + def forward(self, y): + if y.dim() <= 1: + y = y + 4 + else: + y = y + 2 + return y + + x = torch.ones((3, 4), dtype=torch.int) + self.run_test(IfFoldModel(), x) + + class IfFoldModel(torch.nn.Module): + def forward(self, y): + if y.dim() < 3 and y.dtype == torch.int: + y = y + 4 + y = y + 2 + else: + return y + return y + + x = torch.ones((3, 4), dtype=torch.int) + self.run_test(IfFoldModel(), x) + + class IfFoldModel(torch.nn.Module): + def forward(self, y): + if y.dim() == 3 and y.dtype == torch.int: + y = y + 4 + y = y + 2 + else: + y = y + 1 + return y + + x = torch.ones((3, 4), dtype=torch.int) + self.run_test(IfFoldModel(), x) + + class IfFoldModel(torch.nn.Module): + def forward(self, y): + if y.numel() != 0 and y.dim() == 2: + y = y + 4 + y = y + 2 + else: + return y + return y + + x = torch.ones((3, 4), dtype=torch.int) + self.run_test(IfFoldModel(), x) + + class IfFoldModel(torch.nn.Module): + def forward(self, x, y): + if x.numel() == y.numel(): + y = x + y + else: + y = y - x + return y + + x = torch.ones((3, 4), dtype=torch.int) + y = torch.ones((3, 4), dtype=torch.int) + self.run_test(IfFoldModel(), (x, y)) + + class IfFoldModel(torch.nn.Module): + def forward(self, x, y): + if x.numel() != y.numel(): + y = x + y + else: + y = y - x + return y + + x = torch.ones((3, 4), dtype=torch.int) + y = torch.ones((3, 4), dtype=torch.int) + self.run_test(IfFoldModel(), (x, y)) + + @skipIfUnsupportedMinOpsetVersion(11) + @skipIfONNXShapeInference(False) + def test_uninitialized(self): + class UninitializedModel(torch.nn.Module): + def forward(self, y): + if y.shape[1] < 5: + if y.size(0) == 1: + y = y + 4 + else: + return y + return y + + x = torch.ones((3, 4), dtype=torch.int) + self.run_test(UninitializedModel(), x) + + @skipIfUnsupportedMinOpsetVersion(11) + @skipIfONNXShapeInference(False) + def test_uninitialized_dynamic(self): + class UninitializedModel(torch.nn.Module): + def forward(self, y): + if y.shape[1] < 5: + if y.size(0) == 1: + y = y + 4 + else: + return y + return y + + x = torch.ones((3, 4), dtype=torch.int) + y = torch.ones((6, 7), dtype=torch.int) + self.run_test(UninitializedModel(), x, test_with_inputs=[y], + input_names=['input_1'], + dynamic_axes={'input_1': [0, 1]}) + def test_reflection_pad(self): model = torch.nn.ReflectionPad1d(2) x = torch.randn(2, 4, 4) @@ -3623,7 +4831,6 @@ def test_reflection_pad(self): x = torch.randn(2, 2, 4, 4) self.run_test(model, x) - @disableScriptTest() def test_replication_pad(self): model = torch.nn.ReplicationPad1d(2) x = torch.randn(2, 4, 4) @@ -3633,8 +4840,8 @@ def test_replication_pad(self): x = torch.randn(2, 2, 4, 4) self.run_test(model, x) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) - @disableScriptTest() def test_im2col(self): class Unfold(torch.nn.Module): def forward(self, input): @@ -3657,8 +4864,8 @@ def forward(self, x): # This test checks output scalar type in the ONNX graph should not be null # https://github.com/pytorch/pytorch/issues/28607 + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(10) - @disableScriptTest() def test_trace_script(self): @torch.jit.script def center_slice_helper(input, h_offset): @@ -3688,13 +4895,14 @@ def forward(self, input): out = input * 2 out *= out.dim() return out + empty_input = torch.randn(0, requires_grad=True) multi_dim_input = torch.randn(1, 2, 3, requires_grad=True) self.run_test(DimModel(), empty_input) self.run_test(DimModel(), multi_dim_input) @skipIfUnsupportedMinOpsetVersion(12) - @disableScriptTest() + @disableScriptTest() # variable number of inputs not scriptable def test_einsum(self): class EinsumModelBatchDiagonal(torch.nn.Module): def forward(self, *tensor_list): @@ -3730,142 +4938,106 @@ def forward(self, *tensor_list): x = torch.randn(3, 4) self.run_test(EinsumModelTranspose(), input=(x,)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(12) - @disableScriptTest() def test_crossentropyloss(self): - x = torch.randn(3, 5) - y = torch.empty(3, dtype=torch.long).random_(5) - self._crossentropyloss(x, y) + for ignore_index in [-100, 1]: + x = torch.randn(3, 5) + y = torch.empty(3, dtype=torch.long).random_(5) + y[y == 1] = ignore_index + + self._crossentropyloss(x, y, ignore_index) - x = torch.randn(3, 5, 2) - y = torch.empty(3, 2, dtype=torch.long).random_(5) - self._crossentropyloss(x, y) + x = torch.randn(3, 5, 2) + y = torch.empty(3, 2, dtype=torch.long).random_(5) + y[y == 1] = ignore_index + self._crossentropyloss(x, y, ignore_index) - x = torch.randn(3, 5, 2, 7) - y = torch.empty(3, 2, 7, dtype=torch.long).random_(5) - self._crossentropyloss(x, y) + x = torch.randn(3, 5, 2, 7) + y = torch.empty(3, 2, 7, dtype=torch.long).random_(5) + y[y == 1] = ignore_index + self._crossentropyloss(x, y, ignore_index) - def _crossentropyloss(self, x, y): + def _crossentropyloss(self, x, y, ignore_index): class CrossEntropyLossNone(torch.nn.Module): - def __init__(self): + def __init__(self, ignore_index): super(CrossEntropyLossNone, self).__init__() - self.loss = torch.nn.CrossEntropyLoss(reduction='none') + if ignore_index == -100: + self.loss = torch.nn.CrossEntropyLoss(reduction='none') + else: + self.loss = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=ignore_index) def forward(self, input, target): return self.loss(input, target) - self.run_test(CrossEntropyLossNone(), input=(x, y)) + self.run_test(CrossEntropyLossNone(ignore_index), input=(x, y)) class CrossEntropyLossNoneWeight(torch.nn.Module): - def __init__(self): + def __init__(self, ignore_index): super(CrossEntropyLossNoneWeight, self).__init__() - self.loss = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.randn(5)) + if ignore_index == -100: + self.loss = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.randn(5)) + else: + self.loss = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.randn(5), ignore_index=ignore_index) def forward(self, input, target): return self.loss(input, target) - self.run_test(CrossEntropyLossNoneWeight(), input=(x, y)) + self.run_test(CrossEntropyLossNoneWeight(ignore_index), input=(x, y)) class CrossEntropyLossSum(torch.nn.Module): - def __init__(self): + def __init__(self, ignore_index): super(CrossEntropyLossSum, self).__init__() - self.loss = torch.nn.CrossEntropyLoss(reduction='sum') + if ignore_index == -100: + self.loss = torch.nn.CrossEntropyLoss(reduction='sum') + else: + self.loss = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=ignore_index) def forward(self, input, target): return self.loss(input, target) - self.run_test(CrossEntropyLossSum(), input=(x, y)) + self.run_test(CrossEntropyLossSum(ignore_index), input=(x, y)) class CrossEntropyLossSumWeight(torch.nn.Module): - def __init__(self): + def __init__(self, ignore_index): super(CrossEntropyLossSumWeight, self).__init__() - self.loss = torch.nn.CrossEntropyLoss(reduction='sum', weight=torch.randn(5)) + if ignore_index == -100: + self.loss = torch.nn.CrossEntropyLoss(reduction='sum', weight=torch.randn(5)) + else: + self.loss = torch.nn.CrossEntropyLoss(reduction='sum', weight=torch.randn(5), ignore_index=ignore_index) def forward(self, input, target): return self.loss(input, target) - self.run_test(CrossEntropyLossSumWeight(), input=(x, y)) + self.run_test(CrossEntropyLossSumWeight(ignore_index), input=(x, y)) class CrossEntropyLossMean(torch.nn.Module): - def __init__(self): + def __init__(self, ignore_index): super(CrossEntropyLossMean, self).__init__() - self.loss = torch.nn.CrossEntropyLoss() + if ignore_index == -100: + self.loss = torch.nn.CrossEntropyLoss() + else: + self.loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index) def forward(self, input, target): return self.loss(input, target) - self.run_test(CrossEntropyLossMean(), input=(x, y)) + self.run_test(CrossEntropyLossMean(ignore_index), input=(x, y)) class CrossEntropyLossMeanWeight(torch.nn.Module): - def __init__(self): + def __init__(self, ignore_index): super(CrossEntropyLossMeanWeight, self).__init__() - self.loss = torch.nn.CrossEntropyLoss(weight=torch.randn(5)) - - def forward(self, input, target): - return self.loss(input, target) - - self.run_test(CrossEntropyLossMeanWeight(), input=(x, y)) - - class CrossEntropyLossNoneIgnoreIndex(torch.nn.Module): - def __init__(self): - super(CrossEntropyLossNoneIgnoreIndex, self).__init__() - self.loss = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=1) - - def forward(self, input, target): - return self.loss(input, target) - - self.run_test(CrossEntropyLossNoneIgnoreIndex(), input=(x, y)) - - class CrossEntropyLossNoneWeightIgnoreIndex(torch.nn.Module): - def __init__(self): - super(CrossEntropyLossNoneWeightIgnoreIndex, self).__init__() - self.loss = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.randn(5), ignore_index=1) - - def forward(self, input, target): - return self.loss(input, target) - - self.run_test(CrossEntropyLossNoneWeightIgnoreIndex(), input=(x, y)) - - class CrossEntropyLossSumIgnoreIndex(torch.nn.Module): - def __init__(self): - super(CrossEntropyLossSumIgnoreIndex, self).__init__() - self.loss = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=1) - - def forward(self, input, target): - return self.loss(input, target) - - self.run_test(CrossEntropyLossSumIgnoreIndex(), input=(x, y)) - - class CrossEntropyLossSumWeightIgnoreIndex(torch.nn.Module): - def __init__(self): - super(CrossEntropyLossSumWeightIgnoreIndex, self).__init__() - self.loss = torch.nn.CrossEntropyLoss(reduction='sum', weight=torch.randn(5), ignore_index=1) - - def forward(self, input, target): - return self.loss(input, target) - - self.run_test(CrossEntropyLossSumWeightIgnoreIndex(), input=(x, y)) - - class CrossEntropyLossMeanIgnoreIndex(torch.nn.Module): - def __init__(self): - super(CrossEntropyLossMeanIgnoreIndex, self).__init__() - self.loss = torch.nn.CrossEntropyLoss(ignore_index=1) - - def forward(self, input, target): - return self.loss(input, target) - - self.run_test(CrossEntropyLossMeanIgnoreIndex(), input=(x, y)) - - class CrossEntropyLossMeanWeightIgnoreIndex(torch.nn.Module): - def __init__(self): - super(CrossEntropyLossMeanWeightIgnoreIndex, self).__init__() - self.loss = torch.nn.CrossEntropyLoss(weight=torch.randn(5), ignore_index=1) + if ignore_index == -100: + self.loss = torch.nn.CrossEntropyLoss(weight=torch.randn(5)) + else: + self.loss = torch.nn.CrossEntropyLoss(weight=torch.randn(5), ignore_index=ignore_index) def forward(self, input, target): return self.loss(input, target) - self.run_test(CrossEntropyLossMeanWeightIgnoreIndex(), input=(x, y)) + self.run_test(CrossEntropyLossMeanWeight(ignore_index), input=(x, y)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_kldiv_loss(self): @@ -3932,8 +5104,8 @@ def forward(self, input, target): self.run_test(KLDivLossMiniBatchMean(), input=(x, y)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(12) - @disableScriptTest() def test_nllloss(self): class NLLModel(torch.nn.Module): def __init__(self): @@ -3948,10 +5120,13 @@ def forward(self, input, target): N, C = 5, 4 input = torch.randn(N, 16) target = torch.empty(N, dtype=torch.long).random_(0, C) + + # using test data containing default ignore_index=-100 + target[target == 1] = -100 self.run_test(NLLModel(), (input, target)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(12) - @disableScriptTest() def test_nllloss_2d_none(self): class NLLModel(torch.nn.Module): def __init__(self): @@ -3967,10 +5142,13 @@ def forward(self, input, target): N, C = 5, 4 input = torch.randn(N, 16, 10, 10) target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C) + + # using test data containing default ignore_index=-100 + target[target == 1] = -100 self.run_test(NLLModel(), (input, target)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(12) - @disableScriptTest() def test_nllloss_2d_mean(self): class NLLModel(torch.nn.Module): def __init__(self): @@ -3986,10 +5164,13 @@ def forward(self, input, target): N, C = 5, 4 input = torch.randn(N, 16, 10, 10) target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C) + + # using test data containing default ignore_index=-100 + target[target == 1] = -100 self.run_test(NLLModel(), (input, target)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(12) - @disableScriptTest() def test_nllloss_2d_sum(self): class NLLModel(torch.nn.Module): def __init__(self): @@ -4005,10 +5186,13 @@ def forward(self, input, target): N, C = 5, 4 input = torch.randn(N, 16, 10, 10) target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C) + + # using test data containing default ignore_index=-100 + target[target == 1] = -100 self.run_test(NLLModel(), (input, target)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(12) - @disableScriptTest() def test_nllloss_2d_mean_weights(self): class NLLModel(torch.nn.Module): def __init__(self): @@ -4024,10 +5208,13 @@ def forward(self, input, target): N, C = 5, 4 input = torch.randn(N, 16, 10, 10) target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C) + + # using test data containing default ignore_index=-100 + target[target == 1] = -100 self.run_test(NLLModel(), (input, target)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(12) - @disableScriptTest() def test_nllloss_2d_mean_ignore_index(self): class NLLModel(torch.nn.Module): def __init__(self): @@ -4045,8 +5232,8 @@ def forward(self, input, target): target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C) self.run_test(NLLModel(), (input, target)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(12) - @disableScriptTest() def test_nllloss_2d_mean_ignore_index_weights(self): class NLLModel(torch.nn.Module): def __init__(self): @@ -4074,6 +5261,7 @@ def forward(self, mat1, mat2): mat2 = torch.randn(3, 3) self.run_test(M(), input=(mat1, mat2)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) # Because where op is not supported for opset < 9. def test_where_with_bool_tensor(self): class M(torch.nn.Module): @@ -4085,6 +5273,7 @@ def forward(self, mat1, mat2): mat2 = torch.ones(2, 3) self.run_test(M(), input=(mat1, mat2)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) # Because where op is not supported for opset < 9. def test_where_with_byte_tensor(self): class M(torch.nn.Module): @@ -4187,6 +5376,7 @@ def forward(self, cond, input, other): z = torch.ones(2, 3, 1) self.run_test(Model(), (x, y, z)) + @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_where_condition(self): class Model1(torch.nn.Module): @@ -4217,9 +5407,53 @@ def forward(self, input): else: pass return out + x = torch.randn(1, 2, 3, requires_grad=True) self.run_test(EmptyBranchModel(), x) + def test_derive_index(self): + class MyModule(torch.nn.Module): + def forward(self, x: torch.Tensor): + j = [] + for idx in range(len(x) - 1, -len(x), -2): + y = x[idx] + j += [x * y] + return j + + x = torch.randn(5, 13) + self.run_test(MyModule(), x) + + class MyModule(torch.nn.Module): + def forward(self, x: torch.Tensor): + j = [] + for idx in range(-len(x), len(x) - 1, 2): + y = x[idx] + j += [x * y] + return j + + x = torch.randn(5, 13) + self.run_test(MyModule(), x) + + class MyModule(torch.nn.Module): + def forward(self, x: torch.Tensor): + j = [] + for idx in range(len(x) - 1, -len(x), -3): + y = x[idx] + j += [x * y] + return j + + self.run_test(MyModule(), x) + + class MyModule(torch.nn.Module): + def forward(self, x: torch.Tensor): + j = [] + for idx in range(-len(x), len(x) - 1, 3): + y = x[idx] + j += [x * y] + return j + + self.run_test(MyModule(), x) + @skipIfONNXShapeInference(False) @skipIfUnsupportedMinOpsetVersion(11) def test_if_transpose(self): @@ -4236,6 +5470,25 @@ def forward(self, x): output_names=['output_1'], dynamic_axes={'output_1': [0, 1]}) + @skipIfONNXShapeInference(False) + @skipIfUnsupportedMinOpsetVersion(13) + @skipIfUnsupportedOpsetVersion([13]) + def test_if_list(self): + class IfModel(torch.nn.Module): + def forward(self, x, y, cond): + res = [] + if cond: + res = res + [x] + else: + res = res + [y] + # TODO: remove torch.stack once graph sequence output is supported. + return torch.stack(res) + + x = torch.randn(2, 3) + y = torch.randn(3, 3) + cond = torch.tensor(1, dtype=torch.bool) + self.run_test(torch.jit.script(IfModel()), (x, y, cond)) + def test_onnx_proto_checker(self): class Model(torch.nn.Module): def __init__(self): @@ -4243,6 +5496,7 @@ def __init__(self): def forward(self, x): return 2 * x + x = torch.randn(1, 2, 3, requires_grad=True) f = io.BytesIO() torch.onnx._export(Model(), x, f) @@ -4251,13 +5505,15 @@ def forward(self, x): def check_proto(): torch._C._check_onnx_proto(model.SerializeToString()) + self.assertRaises(RuntimeError, check_proto) - @disableScriptTest() + @disableScriptTest() # dtype mismatch def test_split_tensor_scalar(self): class SplitModel(torch.nn.Module): def forward(self, x): return torch.split(x, x.size(1)) + x = torch.randn(1, 2, 3, requires_grad=True) self.run_test(SplitModel(), x) @@ -4265,10 +5521,12 @@ def test_split_tensor_multi(self): class SplitModel(torch.nn.Module): def forward(self, x): return torch.split(x, torch.ones(3)) + x = torch.randn(1, 2, 3, requires_grad=True) def run_model(): SplitModel(x) + self.assertRaises(TypeError, run_model) def _dispatch_rnn_test(self, name, *args, **kwargs): @@ -4406,6 +5664,50 @@ def forward(self, input): x = torch.randn(6, 4, 3, 3) self.run_test(FakeQuantizePerTensorModel(), (x)) + @skipIfUnsupportedOpsetVersion([13]) + def test_batchnorm_training(self): + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + self.bn = torch.nn.BatchNorm2d(3, affine=True) + + def forward(self, x): + bn = self.bn(x) + return bn + + model = MyModule() + x = torch.randn(10, 3, 128, 128) + + model.train() + out = model(x) + + # state after 1 train epoch + running_mean = model.bn.running_mean + running_var = model.bn.running_var + saved_mean = x.mean((0, 2, 3)) + saved_var = x.var((0, 2, 3)) + + pytorch_out = [out.detach().numpy(), + running_mean.cpu().numpy(), running_var.cpu().numpy(), + saved_mean.cpu().numpy(), saved_var.cpu().numpy()] + + model_export = MyModule() + f = io.BytesIO() + + ort_sess = convert_to_onnx(model_export, input=(x,), opset_version=self.opset_version, + training=torch.onnx.TrainingMode.TRAINING) + ort_outs = run_ort(ort_sess, input=(x,)) + [np.testing.assert_allclose(p_out, ort_out, atol=10e-3, rtol=10e-3) for p_out, ort_out in zip(pytorch_out, ort_outs)] + + model_export = torch.jit.script(MyModule()) + ort_sess = convert_to_onnx(model_export, input=(x,), opset_version=self.opset_version, + example_outputs=out, + training=torch.onnx.TrainingMode.TRAINING, + use_new_jit_passes=True, onnx_shape_inference=True) + ort_outs = run_ort(ort_sess, input=(x,)) + [np.testing.assert_allclose(p_out, ort_out, atol=10e-3, rtol=10e-3) for p_out, ort_out in + zip(pytorch_out, ort_outs)] + @skipIfUnsupportedMinOpsetVersion(12) def test_dropout_training(self): class MyModule(torch.nn.Module): @@ -4422,7 +5724,16 @@ def forward(self, x): model.train() - ort_sess = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING) + ort_sess = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, + training=torch.onnx.TrainingMode.TRAINING) + ort_outs = run_ort(ort_sess, input=(x,)) + assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0]))) + + script_model = torch.jit.script(model) + output = model(x) + ort_sess = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version, + example_outputs=output, use_new_jit_passes=True, + training=torch.onnx.TrainingMode.TRAINING) ort_outs = run_ort(ort_sess, input=(x,)) assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0]))) @@ -4448,12 +5759,27 @@ def forward(self, x): model.train() - ort_sess = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING) + ort_sess = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, + training=torch.onnx.TrainingMode.TRAINING) ort_outs = run_ort(ort_sess, input=(x,)) y = model(input) output = y.cpu().numpy() + ort_mask = np.where(ort_outs[0] != 0, 1, 0) + pyt_mask = np.where(output != 0, 1, 0) + ratio_pytorch = np.sum(pyt_mask) / nb_elements + ratio_ort = np.sum(ort_mask) / nb_elements + + np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01) + + script_model = torch.jit.script(model) + y = model(input) + output = y.cpu().numpy() + ort_sess = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version, + example_outputs=y, use_new_jit_passes=True, + training=torch.onnx.TrainingMode.TRAINING) + ort_outs = run_ort(ort_sess, input=(x,)) ort_mask = np.where(ort_outs[0] != 0, 1, 0) pyt_mask = np.where(output != 0, 1, 0) @@ -4462,6 +5788,7 @@ def forward(self, x): np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01) + @skipIfUnsupportedOpsetVersion([13]) def test_conv_bn(self): class MyModule(torch.nn.Module): def __init__(self): @@ -4476,11 +5803,27 @@ def forward(self, x): model = MyModule() x = torch.randn(10, 3, 128, 128) - ort_sess1 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING) + ort_sess1 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, + training=torch.onnx.TrainingMode.TRAINING) + ort_outs1 = run_ort(ort_sess1, input=(x,)) + ort_sess2 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, + training=torch.onnx.TrainingMode.EVAL) + ort_outs2 = run_ort(ort_sess2, input=(x,)) + [np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in + zip(ort_outs1, ort_outs2)] + + script_model = torch.jit.script(model) + outputs = model(x) + ort_sess1 = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version, + example_outputs=outputs, use_new_jit_passes=True, + training=torch.onnx.TrainingMode.TRAINING) ort_outs1 = run_ort(ort_sess1, input=(x,)) - ort_sess2 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, training=torch.onnx.TrainingMode.EVAL) + ort_sess2 = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version, + example_outputs=outputs, use_new_jit_passes=True, + training=torch.onnx.TrainingMode.EVAL) ort_outs2 = run_ort(ort_sess2, input=(x,)) - [np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in zip(ort_outs1, ort_outs2)] + [np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in + zip(ort_outs1, ort_outs2)] def test_multiple_conv_bn(self): class MyModule(torch.nn.Module): @@ -4494,7 +5837,6 @@ def __init__(self): self.relu = torch.nn.ReLU(inplace=True) self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - def forward(self, x): x = self.conv1(x) x = self.bn(x) @@ -4510,11 +5852,304 @@ def forward(self, x): model = MyModule() x = torch.randn(2, 3, 224, 224) - ort_sess1 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING) + ort_sess1 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, + training=torch.onnx.TrainingMode.TRAINING) ort_outs1 = run_ort(ort_sess1, input=(x,)) - ort_sess2 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, training=torch.onnx.TrainingMode.EVAL) + ort_sess2 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, + training=torch.onnx.TrainingMode.EVAL) ort_outs2 = run_ort(ort_sess2, input=(x,)) - [np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in zip(ort_outs1, ort_outs2)] + [np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in + zip(ort_outs1, ort_outs2)] + + @skipIfUnsupportedOpsetVersion([13]) + def test_initializer_sequence(self): + class MyModule(torch.nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super(MyModule, self).__init__() + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(hidden_size, num_classes) + + def forward(self, x): + out = self.fc1(x) + out = self.relu(out) + out = self.fc2(out) + return out + + test_model = MyModule(3, 4, 10) + state_dict_list = [k for (k, v) in test_model.state_dict().items()] + named_params_list = [k for (k, v) in test_model.named_parameters()] + + x = torch.randn(32, 3) + f = io.BytesIO() + torch.onnx._export(test_model, (x,), f, _retain_param_name=True) + loaded_model = onnx.load_from_string(f.getvalue()) + + actual_list = [p.name for p in loaded_model.graph.initializer] + assert actual_list == state_dict_list, \ + "Initializers' sequence is not as same as state_dict(). Expected: (" \ + + ', '.join(state_dict_list) + "). Actual:(" + ', '.join(actual_list) + ")." + assert actual_list == named_params_list, \ + "Initializers' sequence is not as same as named_parameters(). Expected: (" \ + + ', '.join(named_params_list) + "). Actual:(" + ', '.join(actual_list) + ")." + + def test_initializer_sequence_script_model(self): + def list_is_expected(short_list, long_list) -> bool: + if (len(short_list) > len(long_list)): + return False + + for i in range(len(short_list)): + if (short_list[i] not in long_list[i]): + return False + + return True + + def loop(x, y): + for i in range(int(y)): + x = x + i + return x + + class MyModule(torch.nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super(MyModule, self).__init__() + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(hidden_size, num_classes) + + def forward(self, x, y): + x = loop(x, y) + out = self.fc1(x) + out = self.relu(out) + out = self.fc2(out) + return out + + test_model = torch.jit.script(MyModule(3, 4, 10)) + state_dict_list = [k for (k, v) in test_model.state_dict().items()] + named_params_list = [k for (k, v) in test_model.named_parameters()] + + x = torch.ones(2, 3, dtype=torch.float) + y = torch.tensor(5, dtype=torch.long) + example_output = (test_model(x, y),) + f = io.BytesIO() + + torch.onnx.export(test_model, (x, y), f, example_outputs=example_output, _retain_param_name=True) + loaded_model = onnx.load_from_string(f.getvalue()) + + actual_list = [p.name for p in loaded_model.graph.initializer] + assert list_is_expected(state_dict_list, actual_list), \ + "ScriptModel - Initializers' sequence is not as same as state_dict(). Expected: (" \ + + ', '.join(state_dict_list) + "). Actual:(" + ', '.join(actual_list) + ")." + assert list_is_expected(named_params_list, actual_list), \ + "ScriptModel - Initializers' sequence is not as same as named_parameters(). Expected: (" \ + + ', '.join(named_params_list) + "). Actual:(" + ', '.join(actual_list) + ")." + + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(11) + def test_nms(self): + boxes = torch.rand(5, 4) + boxes[:, 2:] += torch.rand(5, 2) + scores = torch.randn(5) + + class Module(torch.nn.Module): + def forward(self, boxes, scores): + return ops.nms(boxes, scores, 0.5) + + self.run_test(Module(), (boxes, scores)) + + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(11) + def test_clip_boxes_to_image(self): + boxes = torch.randn(5, 4) * 500 + boxes[:, 2:] += boxes[:, :2] + size = torch.randn(200, 300) + + size_2 = torch.randn(300, 400) + + class Module(torch.nn.Module): + def forward(self, boxes, size): + return ops.boxes.clip_boxes_to_image(boxes, size.shape) + + self.run_test(Module(), (boxes, size), + input_names=["boxes", "size"], + dynamic_axes={"size": [0, 1]}, + test_with_inputs=[(boxes, size_2)]) + + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(11) + def test_roi_align(self): + x = torch.rand(1, 1, 10, 10, dtype=torch.float32) + single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32) + model = ops.RoIAlign((5, 5), 1, 2) + self.run_test(model, (x, single_roi)) + + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(11) + def test_roi_align_aligned(self): + x = torch.rand(1, 1, 10, 10, dtype=torch.float32) + single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32) + model = ops.RoIAlign((5, 5), 1, 2, aligned=True) + self.run_test(model, (x, single_roi)) + + x = torch.rand(1, 1, 10, 10, dtype=torch.float32) + single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32) + model = ops.RoIAlign((5, 5), 0.5, 3, aligned=True) + self.run_test(model, (x, single_roi)) + + x = torch.rand(1, 1, 10, 10, dtype=torch.float32) + single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32) + model = ops.RoIAlign((5, 5), 1.8, 2, aligned=True) + self.run_test(model, (x, single_roi)) + + x = torch.rand(1, 1, 10, 10, dtype=torch.float32) + single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32) + model = ops.RoIAlign((2, 2), 2.5, 0, aligned=True) + self.run_test(model, (x, single_roi)) + + @skipIfUnsupportedMinOpsetVersion(11) + def test_roi_pool(self): + x = torch.rand(1, 1, 10, 10, dtype=torch.float32) + rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32) + pool_h = 5 + pool_w = 5 + model = ops.RoIPool((pool_h, pool_w), 2) + self.run_test(model, (x, rois)) + + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(11) + def test_resize_images(self): + class TransformModule(torch.nn.Module): + def __init__(self): + super(TransformModule, self).__init__() + self.transform = _init_test_generalized_rcnn_transform() + + def forward(self, images): + return self.transform.resize(images, None)[0] + + input = torch.rand(3, 10, 20) + input_test = torch.rand(3, 100, 150) + self.run_test(TransformModule(), (input,), + input_names=["input1"], dynamic_axes={"input1": [0, 1, 2]}, + test_with_inputs=[(input_test,)]) + + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(11) + def test_transform_images(self): + + class TransformModule(torch.nn.Module): + def __init__(self): + super(TransformModule, self).__init__() + self.transform = _init_test_generalized_rcnn_transform() + + def forward(self, images): + return self.transform(images)[0].tensors + + input = torch.rand(3, 100, 200), torch.rand(3, 200, 200) + input_test = torch.rand(3, 100, 200), torch.rand(3, 200, 200) + self.run_test(TransformModule(), (input,), test_with_inputs=[(input_test,)]) + + def get_features(self, images): + s0, s1 = images.shape[-2:] + features = [ + ('0', torch.rand(2, 256, s0 // 4, s1 // 4)), + ('1', torch.rand(2, 256, s0 // 8, s1 // 8)), + ('2', torch.rand(2, 256, s0 // 16, s1 // 16)), + ('3', torch.rand(2, 256, s0 // 32, s1 // 32)), + ('4', torch.rand(2, 256, s0 // 64, s1 // 64)), + ] + features = OrderedDict(features) + return features + + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(11) + def test_rpn(self): + class RPNModule(torch.nn.Module): + def __init__(self): + super(RPNModule, self).__init__() + self.rpn = _init_test_rpn() + + def forward(self, images, features): + images = ImageList(images, [i.shape[-2:] for i in images]) + return self.rpn(images, features) + + images = torch.rand(2, 3, 150, 150) + features = self.get_features(images) + images2 = torch.rand(2, 3, 80, 80) + test_features = self.get_features(images2) + + model = RPNModule() + model.eval() + model(images, features) + self.run_test(model, (images, features), + input_names=["input1", "input2", "input3", "input4", "input5", "input6"], + dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], + "input3": [0, 1, 2, 3], "input4": [0, 1, 2, 3], + "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]}, + test_with_inputs=[(images2, test_features)], + dict_check=False) + + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(11) + def test_multi_scale_roi_align(self): + + class TransformModule(torch.nn.Module): + def __init__(self): + super(TransformModule, self).__init__() + self.model = ops.MultiScaleRoIAlign(['feat1', 'feat2'], 3, 2) + self.image_sizes = [(512, 512)] + + def forward(self, input, boxes): + return self.model(input, boxes, self.image_sizes) + + i = OrderedDict() + i['feat1'] = torch.rand(1, 5, 64, 64) + i['feat2'] = torch.rand(1, 5, 16, 16) + boxes = torch.rand(6, 4) * 256 + boxes[:, 2:] += boxes[:, :2] + + i1 = OrderedDict() + i1['feat1'] = torch.rand(1, 5, 64, 64) + i1['feat2'] = torch.rand(1, 5, 16, 16) + boxes1 = torch.rand(6, 4) * 256 + boxes1[:, 2:] += boxes1[:, :2] + + self.run_test(TransformModule(), (i, [boxes],), test_with_inputs=[(i1, [boxes1],)]) + + @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(11) + def test_roi_heads(self): + class RoiHeadsModule(torch.nn.Module): + def __init__(self): + super(RoiHeadsModule, self).__init__() + self.transform = _init_test_generalized_rcnn_transform() + self.rpn = _init_test_rpn() + self.roi_heads = _init_test_roi_heads_faster_rcnn() + + def forward(self, images, features): + original_image_sizes = [img.shape[-2:] for img in images] + images = ImageList(images, [i.shape[-2:] for i in images]) + proposals, _ = self.rpn(images, features) + detections, _ = self.roi_heads(features, proposals, images.image_sizes) + detections = self.transform.postprocess(detections, + images.image_sizes, + original_image_sizes) + return detections + + images = torch.rand(2, 3, 100, 100) + features = self.get_features(images) + images2 = torch.rand(2, 3, 150, 150) + test_features = self.get_features(images2) + + model = RoiHeadsModule() + model.eval() + model(images, features) + + self.run_test(model, (images, features), + input_names=["input1", "input2", "input3", "input4", "input5", "input6"], + dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3], + "input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]}, + test_with_inputs=[(images2, test_features)], + dict_check=False) + def make_test(name, base, layer, bidirectional, initial_state, variable_length, dropout, @@ -4527,7 +6162,8 @@ def make_test(name, base, layer, bidirectional, initial_state, # Cannot export with older opsets because of 'ConstantFill' op # ConstantFill was a temp op removed at opset 8. This is no longer supported by onnxruntime - @disableScriptTest() + @skipIfUnsupportedOpsetVersion([13]) + @disableScriptTest() # Test code not scriptable @skipIfUnsupportedMinOpsetVersion(9) def f(self): self._dispatch_rnn_test( @@ -4542,7 +6178,6 @@ def f(self): f.__name__ = test_name setattr(TestONNXRuntime, f.__name__, f) - def setup_rnn_tests(): layers_opts = [ (1, 'unilayer'), @@ -4567,13 +6202,12 @@ def setup_rnn_tests(): ] test_count = 0 for (layer, bidirectional, initial_state, variable_length, dropout) in \ - itertools.product( - layers_opts, - bidirectional_opts, - initial_state_opts, - variable_length_opts, - dropout_opts, - ): + itertools.product( + layers_opts, + bidirectional_opts, + initial_state_opts, + variable_length_opts, + dropout_opts,): for base, name, extra_kwargs in ( ('elman', 'elman_relu', {'nonlinearity': u'relu'}), @@ -4594,7 +6228,6 @@ def setup_rnn_tests(): if test_count != 192: raise ValueError('Expected 192 tests but found {}'.format(test_count)) - setup_rnn_tests() @@ -4654,17 +6287,31 @@ def setup_rnn_tests(): dict(TestONNXRuntime.__dict__, opset_version=12, keep_initializers_as_inputs=False)) -# opset 9 tests, with use_new_jit_passes=True for using new jit API -TestONNXRuntime_opset9_new_jit_API = type(str("TestONNXRuntime_opset9_new_jit_API"), - (unittest.TestCase,), - dict(TestONNXRuntime.__dict__, - use_new_jit_passes=True)) - -# opset 12 tests, with use_new_jit_passes=True for using new jit API -TestONNXRuntime_opset12_new_jit_API = type(str("TestONNXRuntime_opset12_new_jit_API"), - (unittest.TestCase,), - dict(TestONNXRuntime.__dict__, opset_version=12, - use_new_jit_passes=True)) +# opset 13 tests +TestONNXRuntime_opset13 = type(str("TestONNXRuntime_opset13"), + (unittest.TestCase,), + dict(TestONNXRuntime.__dict__, opset_version=13, + keep_initializers_as_inputs=False, + onnx_shape_inference=True)) + +# opset 9 tests, with use_new_jit_passes=True for using new jit API, +# and with keep_initializers_as_inputs=False for IR version 4 style export. +TestONNXRuntime_opset9_IRv4_new_jit_API = type(str("TestONNXRuntime_opset9_IRv4_new_jit_API"), + (unittest.TestCase,), + dict(TestONNXRuntime.__dict__, + keep_initializers_as_inputs=False, + use_new_jit_passes=True, + onnx_shape_inference=True)) + + +# opset 12 tests, with use_new_jit_passes=True for using new jit API, +# and keep_initializers_as_inputs=False for IR version 4 style export. +TestONNXRuntime_opset12_IRv4_new_jit_API = type(str("TestONNXRuntime_opset12_IRv4_new_jit_API"), + (unittest.TestCase,), + dict(TestONNXRuntime.__dict__, opset_version=12, + keep_initializers_as_inputs=False, + use_new_jit_passes=True, + onnx_shape_inference=True)) # opset 12 tests, with _onnx_shape_inference=True. diff --git a/test/onnx/test_pytorch_onnx_shape_inference.py b/test/onnx/test_pytorch_onnx_shape_inference.py new file mode 100644 index 0000000000000..b0b56d9296c70 --- /dev/null +++ b/test/onnx/test_pytorch_onnx_shape_inference.py @@ -0,0 +1,78 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import unittest +import torch + +import copy + +import test_pytorch_onnx_onnxruntime +from test_pytorch_onnx_onnxruntime import TestONNXRuntime +from torch.onnx import utils, OperatorExportTypes, TrainingMode +from torch.onnx.utils import _validate_dynamic_axes +from torch.onnx.symbolic_helper import (_set_opset_version, _set_operator_export_type, + _set_onnx_shape_inference, _set_training_mode, + _is_tensor_list, _is_tensor, _is_none) + + +def verify_inferred_shape(graph): + # Check every node in graph has type properly assigned. + for n in graph.nodes(): + for out in n.outputs(): + if not _is_tensor_list(out) and not _is_tensor(out) and not _is_none(out): + raise RuntimeError("Output of node is neither type Tensor nor type list of Tensor: ", out) + if _is_tensor(out) and out.type().scalarType() is None: + raise RuntimeError("Output of node does not have type assigned", out) + if _is_tensor(out) and out.type().dim() is None: + raise RuntimeError("Output of node does not have shape assigned", out) + + +def run_model_test(self, model, batch_size=2, state_dict=None, + input=None, use_gpu=True, rtol=0.001, atol=1e-7, + example_outputs=None, do_constant_folding=True, + dynamic_axes=None, test_with_inputs=None, + input_names=None, output_names=None, + fixed_batch_size=False): + model.eval() + + if input is None: + input = torch.randn(batch_size, 3, 224, 224, requires_grad=True) + + with torch.no_grad(): + if isinstance(input, torch.Tensor): + input = (input,) + # In-place operators will update input tensor data as well. + # Thus inputs are replicated before every forward call. + input_copy = copy.deepcopy(input) + output = model(*input_copy) + if isinstance(output, torch.Tensor): + output = (output,) + + _set_opset_version(self.opset_version) + _set_operator_export_type(OperatorExportTypes.ONNX) + _set_onnx_shape_inference(True) + _set_training_mode(False) + if dynamic_axes is None: + dynamic_axes = {} + _validate_dynamic_axes(dynamic_axes, model, input_names, output_names) + + input_copy = copy.deepcopy(input) + graph, _, _ = utils._model_to_graph(model, input_copy, + input_names=input_names, + output_names=output_names, + operator_export_type=OperatorExportTypes.ONNX, + example_outputs=output, + do_constant_folding=do_constant_folding, + training=TrainingMode.EVAL, + use_new_jit_passes=self.use_new_jit_passes, + dynamic_axes=dynamic_axes) + verify_inferred_shape(graph) + + +if __name__ == '__main__': + TestONNXRuntime.opset_version = 12 + test_pytorch_onnx_onnxruntime.run_model_test = run_model_test + + unittest.main() diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 88daef3d5fb01..5c1bfe8b55157 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -678,6 +678,31 @@ def forward(self, x): assert len(params_dict) == 2 + def test_scripting_param(self): + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True) + self.bn = torch.nn.BatchNorm2d(16, affine=True) + + def forward(self, x): + x = self.conv(x) + bn = self.bn(x) + return bn + + model = torch.jit.script(MyModule()) + x = torch.randn(10, 3, 128, 128) + example_outputs = model(x) + f = io.BytesIO() + _set_opset_version(self.opset_version) + _set_operator_export_type(OperatorExportTypes.ONNX) + graph, _, __ = utils._model_to_graph(model, (x,), do_constant_folding=True, example_outputs=example_outputs, + operator_export_type=OperatorExportTypes.ONNX) + + graph_input_params = [param.debugName() for param in graph.inputs()] + assert all(item in graph_input_params for item in dict(model.named_parameters())), \ + "Graph parameter names does not match model parameters." + def test_modifying_params(self): class MyModel(torch.nn.Module): def __init__(self): diff --git a/test/onnx/verify.py b/test/onnx/verify.py index 4518f6a94ff98..1e57afdd8d415 100644 --- a/test/onnx/verify.py +++ b/test/onnx/verify.py @@ -219,7 +219,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): if self.errors: - errors_msg = "\n\n".join(map(lambda x: "ERROR: " + x, self.errors)) + errors_msg = "\n\n".join("ERROR: " + x for x in self.errors) final_msg = "{}\n{}\n{}".format(self.msg, '-' * 70, errors_msg) raise AssertionError(final_msg) if exc_type == self.exc_class: @@ -246,7 +246,7 @@ def verify(model, args, backend, verbose=False, training=torch.onnx.TrainingMode For reproducibility, we recommend explicitly setting PyTorch's seed before invoking this function. - Arguments: + Args: model (torch.nn.Module): the model to be exported and verified args (tuple of arguments): the inputs to the model, e.g., such that ``model(*args)`` is a valid @@ -386,8 +386,8 @@ def run(args): "it had a different set of parameters. Are you assigning Parameters\n" "in the forward() of your model definition?") with errs.addErrCtxt(initializer_order_hint): - errs.requireEqual(list(map(lambda x: x.name, proto.graph.initializer)), - list(map(lambda x: x.name, alt_proto.graph.initializer)), + errs.requireEqual([x.name for x in proto.graph.initializer], + [x.name for x in alt_proto.graph.initializer], msg="Parameters list differs") # Now check if the embedded parameters are actually the same diff --git a/test/package_a/__init__.py b/test/package_a/__init__.py index 4761b3db5e414..109408f68cedc 100644 --- a/test/package_a/__init__.py +++ b/test/package_a/__init__.py @@ -5,3 +5,6 @@ class PackageAObject: def __init__(self, obj): self.obj = obj + + def return_result(self): + return result diff --git a/test/print_test_stats.py b/test/print_test_stats.py index 522e6652efe19..95b267028fae0 100755 --- a/test/print_test_stats.py +++ b/test/print_test_stats.py @@ -3,10 +3,13 @@ # Read and print test results statistics from xml.dom import minidom from glob import glob +import bz2 import json import os +import statistics import time +import boto3 import datetime import requests @@ -42,21 +45,19 @@ def append(self, test_case): self.skipped_count += 1 if test_case.skipped else 0 self.errored_count += 1 if test_case.errored else 0 - def print_report(self): + def print_report(self, num_longest=3): sorted_tests = sorted(self.test_cases, key=lambda x: x.time) test_count = len(sorted_tests) print(f"class {self.name}:") print(f" tests: {test_count} failed: {self.failed_count} skipped: {self.skipped_count} errored: {self.errored_count}") print(f" run_time: {self.total_time:.2f} seconds") print(f" avg_time: {self.total_time/test_count:.2f} seconds") - if test_count > 2: - print(f" mean_time: {sorted_tests[test_count>>1].time:.2f} seconds") - print(" Three longest tests:") - for idx in [-1, -2, -3]: - print(f" {sorted_tests[idx].name} time: {sorted_tests[idx].time:.2f} seconds") - elif test_count > 0: - print(" Longest test:") - print(f" {sorted_tests[-1].name} time: {sorted_tests[-1].time:.2f} seconds") + if test_count >= 2: + print(f" median_time: {statistics.median(x.time for x in sorted_tests):.2f} seconds") + sorted_tests = sorted_tests[-num_longest:] + print(f" {len(sorted_tests)} longest tests:") + for test in reversed(sorted_tests): + print(f" {test.name} time: {test.time:.2f} seconds") print("") @@ -77,13 +78,20 @@ def parse_reports(folder): tests_by_class[class_name].append(test_case) return tests_by_class +def build_info(): + return { + "build_pr": os.environ.get("CIRCLE_PR_NUMBER"), + "build_tag": os.environ.get("CIRCLE_TAG"), + "build_sha1": os.environ.get("CIRCLE_SHA1"), + "build_branch": os.environ.get("CIRCLE_BRANCH"), + "build_job": os.environ.get("CIRCLE_JOB"), + "build_workflow_id": os.environ.get("CIRCLE_WORKFLOW_ID"), + } + def build_message(test_case): return { "normal": { - "build_pr": os.environ.get("CIRCLE_PR_NUMBER"), - "build_tag": os.environ.get("CIRCLE_TAG"), - "build_sha1": os.environ.get("CIRCLE_SHA1"), - "build_branch": os.environ.get("CIRCLE_BRANCH"), + **build_info(), "test_suite_name": test_case.class_name, "test_case_name": test_case.name, }, @@ -97,7 +105,7 @@ def build_message(test_case): }, } -def send_report(reports): +def send_report_to_scribe(reports): access_token = os.environ.get("SCRIBE_GRAPHQL_ACCESS_TOKEN") if not access_token: @@ -122,33 +130,125 @@ def send_report(reports): ), }, ) - print("Scribe report status: {}".format(r.text)) r.raise_for_status() +def send_report_to_s3(reports, *, total_seconds): + job = os.environ.get('CIRCLE_JOB') + sha1 = os.environ.get('CIRCLE_SHA1') + branch = os.environ.get('CIRCLE_BRANCH', '') + if branch not in ['master', 'nightly'] and not branch.startswith("release/"): + print("S3 upload only enabled on master, nightly and release branches.") + print(f"skipping test report on branch: {branch}") + return + now = datetime.datetime.utcnow().isoformat() + key = f'test_time/{sha1}/{job}/{now}Z.json.bz2' # Z meaning UTC + s3 = boto3.resource('s3') + try: + s3.get_bucket_acl(Bucket='ossci-metrics') + except Exception as e: + print(f"AWS ACL failed: {e}") + print("AWS credential found, uploading to S3...") + + obj = s3.Object('ossci-metrics', key) + print("") + # use bz2 because the results are smaller than gzip, and the + # compression time penalty we pay is only about half a second for + # input files of a few megabytes in size like these JSON files, and + # because for some reason zlib doesn't seem to play nice with the + # gunzip command whereas Python's bz2 does work with bzip2 + obj.put(Body=bz2.compress(json.dumps({ + **build_info(), + 'total_seconds': total_seconds, + 'suites': { + name: { + 'total_seconds': suite.total_time, + 'cases': [ + { + 'name': case.name, + 'seconds': case.time, + 'errored': case.errored, + 'failed': case.failed, + 'skipped': case.skipped, + } + for case in suite.test_cases + ], + } + for name, suite in reports.items() + } + }).encode())) + +def positive_integer(value): + parsed = int(value) + if parsed < 1: + raise argparse.ArgumentTypeError(f"{value} is not a natural number") + return parsed + +def positive_float(value): + parsed = float(value) + if parsed <= 0.0: + raise argparse.ArgumentTypeError(f"{value} is not a positive rational number") + return parsed + if __name__ == '__main__': + import argparse import sys - if len(sys.argv) == 1: - print("Please specify test report folder") - sys.exit(0) + parser = argparse.ArgumentParser( + "Print statistics from test XML output.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--longest-of-class", + type=positive_integer, + default=3, + metavar="N", + help="how many longest tests to show for each class", + ) + parser.add_argument( + "--class-print-threshold", + type=positive_float, + default=1.0, + metavar="N", + help="Minimal total time to warrant class report", + ) + parser.add_argument( + "--longest-of-run", + type=positive_integer, + default=10, + metavar="N", + help="how many longest tests to show from the entire run", + ) + parser.add_argument( + "--upload-to-s3", + action="store_true", + help="upload test time to S3 bucket", + ) + parser.add_argument( + "folder", + help="test report folder", + ) + args = parser.parse_args() - reports = parse_reports(sys.argv[1]) + reports = parse_reports(args.folder) if len(reports) == 0: - print(f"No test reports found in {sys.argv[1]}") + print(f"No test reports found in {args.folder}") sys.exit(0) - send_report(reports) + send_report_to_scribe(reports) longest_tests = [] total_time = 0 for name in sorted(reports.keys()): test_suite = reports[name] - test_suite.print_report() + if test_suite.total_time >= args.class_print_threshold: + test_suite.print_report(args.longest_of_class) total_time += test_suite.total_time longest_tests.extend(test_suite.test_cases) - if len(longest_tests) > 10: - longest_tests = sorted(longest_tests, key=lambda x: x.time)[-10:] + longest_tests = sorted(longest_tests, key=lambda x: x.time)[-args.longest_of_run:] + + if args.upload_to_s3: + send_report_to_s3(reports, total_seconds=total_time) print(f"Total runtime is {datetime.timedelta(seconds=int(total_time))}") - print("Ten longest tests of entire run:") + print(f"{len(longest_tests)} longest tests of entire run:") for test_case in reversed(longest_tests): print(f" {test_case.class_name}.{test_case.name} time: {test_case.time:.2f} seconds") diff --git a/test/quantization/test_numeric_suite.py b/test/quantization/test_numeric_suite.py index 0a41864557c04..74ecc4a904692 100644 --- a/test/quantization/test_numeric_suite.py +++ b/test/quantization/test_numeric_suite.py @@ -11,7 +11,9 @@ quantize_dynamic, ) from torch.quantization._numeric_suite import ( + OutputLogger, Shadow, + ShadowLogger, compare_model_outputs, compare_model_stub, compare_weights, @@ -85,8 +87,7 @@ def forward(self, x): class TestEagerModeNumericSuite(QuantizationTestCase): @override_qengines def test_compare_weights_conv_static(self): - r"""Compare the weights of float and static quantized conv layer - """ + r"""Compare the weights of float and static quantized conv layer""" qengine = torch.backends.quantized.engine @@ -103,13 +104,12 @@ def compare_and_validate_results(float_model, q_model): model.eval() if hasattr(model, "fuse_model"): model.fuse_model() - q_model = quantize(model, test_only_eval_fn, self.img_data_2d) + q_model = quantize(model, test_only_eval_fn, [self.img_data_2d]) compare_and_validate_results(model, q_model) @override_qengines def test_compare_weights_linear_static(self): - r"""Compare the weights of float and static quantized linear layer - """ + r"""Compare the weights of float and static quantized linear layer""" qengine = torch.backends.quantized.engine @@ -126,13 +126,12 @@ def compare_and_validate_results(float_model, q_model): model.eval() if hasattr(model, "fuse_model"): model.fuse_model() - q_model = quantize(model, test_only_eval_fn, self.calib_data) + q_model = quantize(model, test_only_eval_fn, [self.calib_data]) compare_and_validate_results(model, q_model) @override_qengines def test_compare_weights_linear_dynamic(self): - r"""Compare the weights of float and dynamic quantized linear layer - """ + r"""Compare the weights of float and dynamic quantized linear layer""" qengine = torch.backends.quantized.engine @@ -142,7 +141,9 @@ def compare_and_validate_results(float_model, q_model): ) self.assertEqual(len(weight_dict), 1) for k, v in weight_dict.items(): - self.assertTrue(v["float"].shape == v["quantized"].shape) + self.assertTrue(len(v["float"]) == len(v["quantized"])) + for i, val in enumerate(v["quantized"]): + self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) model_list = [SingleLayerLinearDynamicModel(qengine)] for model in model_list: @@ -154,8 +155,7 @@ def compare_and_validate_results(float_model, q_model): @override_qengines def test_compare_weights_lstm_dynamic(self): - r"""Compare the weights of float and dynamic quantized LSTM layer - """ + r"""Compare the weights of float and dynamic quantized LSTM layer""" qengine = torch.backends.quantized.engine @@ -165,7 +165,9 @@ def compare_and_validate_results(float_model, q_model): ) self.assertEqual(len(weight_dict), 1) for k, v in weight_dict.items(): - self.assertTrue(v["float"].shape == v["quantized"].shape) + self.assertTrue(len(v["float"]) == len(v["quantized"])) + for i, val in enumerate(v["quantized"]): + self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) model_list = [LSTMwithHiddenDynamicModel(qengine)] for model in model_list: @@ -177,8 +179,7 @@ def compare_and_validate_results(float_model, q_model): @override_qengines def test_compare_model_stub_conv_static(self): - r"""Compare the output of static quantized conv layer and its float shadow module - """ + r"""Compare the output of static quantized conv layer and its float shadow module""" qengine = torch.backends.quantized.engine @@ -186,7 +187,9 @@ def compare_and_validate_results(float_model, q_model, module_swap_list, data): ob_dict = compare_model_stub(float_model, q_model, module_swap_list, data) self.assertEqual(len(ob_dict), 1) for k, v in ob_dict.items(): - self.assertTrue(v["float"].shape == v["quantized"].shape) + self.assertTrue(len(v["float"]) == len(v["quantized"])) + for i, val in enumerate(v["quantized"]): + self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) model_list = [AnnotatedConvModel(qengine), AnnotatedConvBnReLUModel(qengine)] module_swap_list = [nn.Conv2d, nn.intrinsic.modules.fused.ConvReLU2d] @@ -194,15 +197,14 @@ def compare_and_validate_results(float_model, q_model, module_swap_list, data): model.eval() if hasattr(model, "fuse_model"): model.fuse_model() - q_model = quantize(model, test_only_eval_fn, self.img_data_2d) + q_model = quantize(model, test_only_eval_fn, [self.img_data_2d]) compare_and_validate_results( model, q_model, module_swap_list, self.img_data_2d[0][0] ) @override_qengines def test_compare_model_stub_linear_static(self): - r"""Compare the output of static quantized linear layer and its float shadow module - """ + r"""Compare the output of static quantized linear layer and its float shadow module""" qengine = torch.backends.quantized.engine @@ -210,7 +212,9 @@ def compare_and_validate_results(float_model, q_model, module_swap_list, data): ob_dict = compare_model_stub(float_model, q_model, module_swap_list, data) self.assertEqual(len(ob_dict), 1) for k, v in ob_dict.items(): - self.assertTrue(v["float"].shape == v["quantized"].shape) + self.assertTrue(len(v["float"]) == len(v["quantized"])) + for i, val in enumerate(v["quantized"]): + self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) linear_data = self.calib_data[0][0] module_swap_list = [nn.Linear] @@ -219,18 +223,17 @@ def compare_and_validate_results(float_model, q_model, module_swap_list, data): model.eval() if hasattr(model, "fuse_model"): model.fuse_model() - q_model = quantize(model, test_only_eval_fn, self.calib_data) + q_model = quantize(model, test_only_eval_fn, [self.calib_data]) compare_and_validate_results(model, q_model, module_swap_list, linear_data) @override_qengines def test_compare_model_stub_submodule_static(self): - r"""Compare the output of static quantized submodule and its float shadow module - """ + r"""Compare the output of static quantized submodule and its float shadow module""" qengine = torch.backends.quantized.engine model = ModelWithSubModules().eval() - q_model = quantize(model, test_only_eval_fn, self.img_data_2d) + q_model = quantize(model, test_only_eval_fn, [self.img_data_2d]) module_swap_list = [SubModule] ob_dict = compare_model_stub( model, q_model, module_swap_list, self.img_data_2d[0][0] @@ -238,12 +241,14 @@ def test_compare_model_stub_submodule_static(self): self.assertTrue(isinstance(q_model.mod1, Shadow)) self.assertFalse(isinstance(q_model.conv, Shadow)) for k, v in ob_dict.items(): - torch.testing.assert_allclose(v["float"], v["quantized"].dequantize()) + for i, val in enumerate(v["quantized"]): + torch.testing.assert_allclose( + v["float"][i], v["quantized"][i].dequantize() + ) @override_qengines def test_compare_model_stub_functional_static(self): - r"""Compare the output of static quantized functional layer and its float shadow module - """ + r"""Compare the output of static quantized functional layer and its float shadow module""" qengine = torch.backends.quantized.engine @@ -264,12 +269,13 @@ def test_compare_model_stub_functional_static(self): self.assertTrue(isinstance(q_model.my_scalar_add, Shadow)) self.assertTrue(isinstance(q_model.my_scalar_mul, Shadow)) for k, v in ob_dict.items(): - self.assertTrue(v["float"].shape == v["quantized"].shape) + self.assertTrue(len(v["float"]) == len(v["quantized"])) + for i, val in enumerate(v["quantized"]): + self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) @override_qengines def test_compare_model_stub_linear_dynamic(self): - r"""Compare the output of dynamic quantized linear layer and its float shadow module - """ + r"""Compare the output of dynamic quantized linear layer and its float shadow module""" qengine = torch.backends.quantized.engine @@ -277,7 +283,9 @@ def compare_and_validate_results(float_model, q_model, module_swap_list, data): ob_dict = compare_model_stub(float_model, q_model, module_swap_list, data) self.assertEqual(len(ob_dict), 1) for k, v in ob_dict.items(): - self.assertTrue(v["float"].shape == v["quantized"].shape) + self.assertTrue(len(v["float"]) == len(v["quantized"])) + for i, val in enumerate(v["quantized"]): + self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) linear_data = self.calib_data[0][0] @@ -292,8 +300,7 @@ def compare_and_validate_results(float_model, q_model, module_swap_list, data): @override_qengines def test_compare_model_stub_lstm_dynamic(self): - r"""Compare the output of dynamic quantized LSTM layer and its float shadow module - """ + r"""Compare the output of dynamic quantized LSTM layer and its float shadow module""" qengine = torch.backends.quantized.engine @@ -305,7 +312,9 @@ def compare_and_validate_results( ) self.assertEqual(len(ob_dict), 1) for k, v in ob_dict.items(): - self.assertTrue(v["float"].shape == v["quantized"].shape) + self.assertTrue(len(v["float"]) == len(v["quantized"])) + for i, val in enumerate(v["quantized"]): + self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) lstm_input = torch.rand((1, 1, 2)) lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2)) @@ -341,7 +350,7 @@ def compare_and_validate_results(float_model, q_model, data): model.eval() if hasattr(model, "fuse_model"): model.fuse_model() - q_model = quantize(model, test_only_eval_fn, self.img_data_2d) + q_model = quantize(model, test_only_eval_fn, [self.img_data_2d]) compare_and_validate_results(model, q_model, self.img_data_2d[0][0]) @override_qengines @@ -357,7 +366,9 @@ def compare_and_validate_results(float_model, q_model, data): self.assertTrue(act_compare_dict.keys() == expected_act_compare_dict_keys) for k, v in act_compare_dict.items(): - self.assertTrue(v["float"].shape == v["quantized"].shape) + self.assertTrue(len(v["float"]) == len(v["quantized"])) + for i, val in enumerate(v["quantized"]): + self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) linear_data = self.calib_data[0][0] model_list = [AnnotatedSingleLayerLinearModel(qengine)] @@ -365,7 +376,7 @@ def compare_and_validate_results(float_model, q_model, data): model.eval() if hasattr(model, "fuse_model"): model.fuse_model() - q_model = quantize(model, test_only_eval_fn, self.calib_data) + q_model = quantize(model, test_only_eval_fn, [self.calib_data]) compare_and_validate_results(model, q_model, linear_data) @override_qengines @@ -393,7 +404,9 @@ def test_compare_model_outputs_functional_static(self): } self.assertTrue(act_compare_dict.keys() == expected_act_compare_dict_keys) for k, v in act_compare_dict.items(): - self.assertTrue(v["float"].shape == v["quantized"].shape) + self.assertTrue(len(v["float"]) == len(v["quantized"])) + for i, val in enumerate(v["quantized"]): + self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) @override_qengines def test_compare_model_outputs_linear_dynamic(self): @@ -408,7 +421,9 @@ def compare_and_validate_results(float_model, q_model, data): self.assertTrue(act_compare_dict.keys() == expected_act_compare_dict_keys) for k, v in act_compare_dict.items(): - self.assertTrue(v["float"].shape == v["quantized"].shape) + self.assertTrue(len(v["float"]) == len(v["quantized"])) + for i, val in enumerate(v["quantized"]): + self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) linear_data = self.calib_data[0][0] @@ -435,7 +450,18 @@ def compare_and_validate_results(float_model, q_model, input, hidden): self.assertTrue(act_compare_dict.keys() == expected_act_compare_dict_keys) for k, v in act_compare_dict.items(): - self.assertTrue(v["float"][0].shape == v["quantized"][0].shape) + self.assertTrue(len(v["float"]) == len(v["quantized"])) + for i, val in enumerate(v["quantized"]): + self.assertTrue(len(v["float"][i]) == len(v["quantized"][i])) + if i == 0: + self.assertTrue(v["float"][i][0].shape == v["quantized"][i][0].shape) + else: + self.assertTrue( + v["float"][i][0].shape == v["quantized"][i][0].shape + ) + self.assertTrue( + v["float"][i][1].shape == v["quantized"][i][1].shape + ) lstm_input = torch.rand((1, 1, 2)) lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2)) @@ -447,3 +473,35 @@ def compare_and_validate_results(float_model, q_model, input, hidden): model.fuse_model() q_model = quantize_dynamic(model) compare_and_validate_results(model, q_model, lstm_input, lstm_hidden) + + @override_qengines + def test_output_logger(self): + r"""Compare output from OutputLogger with the expected results""" + x = torch.rand(2, 2) + y = torch.rand(2, 1) + + l = [] + l.append(x) + l.append(y) + + logger = OutputLogger() + logger.forward(x) + logger.forward(y) + + self.assertEqual(l, logger.stats["tensor_val"]) + + @override_qengines + def test_shadow_logger(self): + r"""Compare output from ShawdowLogger with the expected results""" + a_float = torch.rand(2, 2) + a_quantized = torch.rand(2, 2) + + b_float = torch.rand(3, 2, 2) + b_quantized = torch.rand(3, 2, 2) + + logger = ShadowLogger() + logger.forward(a_float, a_quantized) + logger.forward(b_float, b_quantized) + + self.assertEqual(len(logger.stats["float"]), 2) + self.assertEqual(len(logger.stats["quantized"]), 2) diff --git a/test/quantization/test_numeric_suite_fx.py b/test/quantization/test_numeric_suite_fx.py new file mode 100644 index 0000000000000..0955b4b58914e --- /dev/null +++ b/test/quantization/test_numeric_suite_fx.py @@ -0,0 +1,144 @@ +import copy + +import torch +from torch.quantization import get_default_qconfig +from torch.quantization._numeric_suite_fx import ( + compare_weights_fx, + remove_qconfig_observer_fx, +) +from torch.quantization.fx.quantize import is_activation_post_process +from torch.quantization.quantize_fx import convert_fx, fuse_fx, prepare_fx +from torch.testing._internal.common_quantization import ( + ConvBnModel, + ConvBNReLU, + ConvModel, + QuantizationTestCase, + SingleLayerLinearDynamicModel, + SingleLayerLinearModel, + skipIfNoFBGEMM, +) + + +@skipIfNoFBGEMM +class TestGraphModeNumericSuite(QuantizationTestCase): + def test_remove_qconfig_observer_fx(self): + r"""Remove activation_post_process node from fx prepred model""" + float_model = SingleLayerLinearModel() + float_model.eval() + + qengine = torch.backends.quantized.engine + qconfig = get_default_qconfig(qengine) + + qconfig_dict = {"": qconfig} + + prepared_model = prepare_fx(float_model, qconfig_dict) + + backup_prepared_model = copy.deepcopy(prepared_model) + backup_prepared_model.eval() + + model = remove_qconfig_observer_fx(backup_prepared_model) + + modules = dict(model.named_modules()) + for node in model.graph.nodes: + if node.op == "call_module": + self.assertFalse(is_activation_post_process(modules[node.target])) + + @skipIfNoFBGEMM + def test_compare_weights_conv_static_fx(self): + r"""Compare the weights of float and static quantized conv layer""" + + def calibrate(model, calib_data): + model.eval() + with torch.no_grad(): + for inp in calib_data: + model(*inp) + + def compare_and_validate_results(float_model, q_model): + weight_dict = compare_weights_fx( + float_model.state_dict(), q_model.state_dict() + ) + self.assertEqual(len(weight_dict), 1) + for k, v in weight_dict.items(): + self.assertTrue(v["float"].shape == v["quantized"].shape) + + qengine = torch.backends.quantized.engine + qconfig = get_default_qconfig(qengine) + qconfig_dict = {"": qconfig} + + model_list = [ConvModel(), ConvBnModel(), ConvBNReLU()] + for float_model in model_list: + float_model.eval() + + fused = fuse_fx(float_model) + prepared_model = prepare_fx(float_model, qconfig_dict) + + # Run calibration + calibrate(prepared_model, self.img_data_2d) + q_model = convert_fx(prepared_model) + + compare_and_validate_results(fused, q_model) + + @skipIfNoFBGEMM + def test_compare_weights_linear_static_fx(self): + r"""Compare the weights of float and static quantized linear layer""" + + def calibrate(model, calib_data): + model.eval() + with torch.no_grad(): + for inp in calib_data: + model(*inp) + + def compare_and_validate_results(float_model, q_model): + weight_dict = compare_weights_fx( + float_model.state_dict(), q_model.state_dict() + ) + self.assertEqual(len(weight_dict), 1) + for k, v in weight_dict.items(): + self.assertTrue(v["float"].shape == v["quantized"].shape) + + float_model = SingleLayerLinearModel() + float_model.eval() + + qengine = torch.backends.quantized.engine + qconfig = get_default_qconfig(qengine) + qconfig_dict = {"": qconfig} + + prepared_model = prepare_fx(float_model, qconfig_dict) + + backup_prepared_model = copy.deepcopy(prepared_model) + backup_prepared_model.eval() + + # Run calibration + calibrate(prepared_model, self.calib_data) + q_model = convert_fx(prepared_model) + + compare_and_validate_results(backup_prepared_model, q_model) + + @skipIfNoFBGEMM + def test_compare_weights_linear_dynamic_fx(self): + r"""Compare the weights of float and dynamic quantized linear layer""" + + def compare_and_validate_results(float_model, q_model): + weight_dict = compare_weights_fx( + float_model.state_dict(), q_model.state_dict() + ) + self.assertEqual(len(weight_dict), 1) + for k, v in weight_dict.items(): + self.assertTrue(len(v["float"]) == len(v["quantized"])) + for i, val in enumerate(v["quantized"]): + self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) + + float_model = SingleLayerLinearDynamicModel() + float_model.eval() + + qconfig = torch.quantization.qconfig.default_dynamic_qconfig + qconfig_dict = {"": qconfig} + + prepared_model = prepare_fx(float_model, qconfig_dict) + + backup_prepared_model = copy.deepcopy(prepared_model) + backup_prepared_model.eval() + + q_model = convert_fx(prepared_model) + + compare_and_validate_results(backup_prepared_model, q_model) diff --git a/test/quantization/test_qat_module.py b/test/quantization/test_qat_module.py index 4144c07441045..32de0ff50f0e7 100644 --- a/test/quantization/test_qat_module.py +++ b/test/quantization/test_qat_module.py @@ -110,7 +110,11 @@ def _forward(self, input): running_std = torch.sqrt(self.running_var + self.eps) scale_factor = self.gamma / running_std scaled_weight = self.weight * scale_factor.reshape([-1, 1, 1, 1]) - conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight)) + if self.bias is not None: + zero_bias = torch.zeros_like(self.bias) + else: + zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device) + conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight), zero_bias) if self.training and not self.freeze_bn: # recovering original conv to get original batch_mean and batch_var diff --git a/test/quantization/test_quantize.py b/test/quantization/test_quantize.py index e54eb33770c24..71391a9ddee74 100644 --- a/test/quantization/test_quantize.py +++ b/test/quantization/test_quantize.py @@ -22,9 +22,11 @@ default_dynamic_qconfig, per_channel_dynamic_qconfig, float16_dynamic_qconfig, - float_qparams_dynamic_qconfig, - register_observed_custom_module_mapping, - register_quantized_custom_module_mapping, + float_qparams_weight_only_qconfig, + PerChannelMinMaxObserver, + QConfigDynamic, + default_dynamic_quant_observer, + FixedQParamsFakeQuantize, ) from torch.testing._internal.common_quantization import ( @@ -80,12 +82,14 @@ hu.assert_deadline_disabled() # Standard library +from typing import Tuple import copy import io import unittest import numpy as np class TestPostTrainingStatic(QuantizationTestCase): + def test_single_layer(self): r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped to nnq.Linear which is the quantized version of the module @@ -118,7 +122,7 @@ def checkQuantized(model): base = AnnotatedSingleLayerLinearModel(qengine) base.qconfig = qconfig keys_before = set(list(base.state_dict().keys())) - model = quantize(base, test_only_eval_fn, self.calib_data) + model = quantize(base, test_only_eval_fn, [self.calib_data]) checkQuantized(model) keys_after = set(list(base.state_dict().keys())) self.assertEqual(keys_before, keys_after) # simple check that nothing changed @@ -126,7 +130,7 @@ def checkQuantized(model): # in-place version model = AnnotatedSingleLayerLinearModel(qengine) model.qconfig = qconfig - quantize(model, test_only_eval_fn, self.calib_data, inplace=True) + quantize(model, test_only_eval_fn, [self.calib_data], inplace=True) checkQuantized(model) @skipIfNoFBGEMM @@ -160,7 +164,7 @@ def checkQuantized(model): # test one line API model = quantize(AnnotatedTwoLayerLinearModel(), test_only_eval_fn, - self.calib_data) + [self.calib_data]) checkQuantized(model) def test_nested1(self): @@ -202,7 +206,7 @@ def checkQuantized(model): # test one line API model = quantize(AnnotatedNestedModel(qengine), test_only_eval_fn, - self.calib_data) + [self.calib_data]) checkQuantized(model) @@ -243,7 +247,7 @@ def checkQuantized(model): # test one line API model = quantize(AnnotatedSubNestedModel(), test_only_eval_fn, - self.calib_data) + [self.calib_data]) checkQuantized(model) def test_nested3(self): @@ -285,7 +289,7 @@ def checkQuantized(model): # test one line API model = quantize(AnnotatedCustomConfigNestedModel(), test_only_eval_fn, - self.calib_data) + [self.calib_data]) checkQuantized(model) def test_skip_quant(self): @@ -305,15 +309,15 @@ def checkQuantized(model): self.checkQuantDequant(model.sub) self.checkQuantizedLinear(model.sub.module.fc1) self.checkQuantizedLinear(model.sub.module.fc2) - self.assertEqual(type(model.sub.module.relu1), nnq.ReLU) - self.assertEqual(type(model.sub.module.relu2), nnq.ReLU) + self.assertEqual(type(model.sub.module.relu1), nn.ReLU) + self.assertEqual(type(model.sub.module.relu2), nn.ReLU) self.checkScriptable(model, self.calib_data) self.checkNoQconfig(model) checkQuantized(model) # test one line API - model = quantize(AnnotatedSkipQuantModel(qengine), test_only_eval_fn, self.calib_data) + model = quantize(AnnotatedSkipQuantModel(qengine), test_only_eval_fn, [self.calib_data]) checkQuantized(model) @skipIfNoFBGEMM @@ -339,7 +343,7 @@ def checkQuantized(model): checkQuantized(model) # test one line API - model = quantize(QuantStubModel(), test_only_eval_fn, self.calib_data) + model = quantize(QuantStubModel(), test_only_eval_fn, [self.calib_data]) checkQuantized(model) def test_resnet_base(self): @@ -350,10 +354,9 @@ def test_resnet_base(self): with override_quantized_engine(qengine): qconfig = torch.quantization.get_default_qconfig(qengine) model = ResNetBase().float().eval() + model.fuse_model() model = QuantWrapper(model) model.qconfig = qconfig - fuse_list = ['module.conv1', 'module.bn1', 'module.relu1'] - fuse_modules(model, fuse_list, inplace=True) model = prepare(model) self.checkObservers(model) test_only_eval_fn(model, self.img_data_2d) @@ -363,6 +366,8 @@ def checkQuantized(model): self.assertEqual(type(model.module.conv1), nn.intrinsic.quantized.ConvReLU2d) self.assertEqual(type(model.module.myop), nn.quantized.QFunctional) self.assertEqual(type(model.module.avgpool), nn.AdaptiveAvgPool2d) + self.assertEqual(type(model.module.fc), nnq.Linear) + test_only_eval_fn(model, self.img_data_2d) self.checkNoQconfig(model) @@ -398,7 +403,7 @@ def checkQuantized(model): checkQuantized(model) model_oneline = quantize( - NormalizationTestModel(), test_only_eval_fn, self.calib_data) + NormalizationTestModel(), test_only_eval_fn, [self.calib_data]) checkQuantized(model) def test_save_load_state_dict(self): @@ -461,7 +466,7 @@ def checkQuantized(model): # test one line API model_oneline = quantize(ActivationsTestModel(), test_only_eval_fn, - self.calib_data) + [self.calib_data]) checkQuantized(model_oneline) @override_qengines @@ -519,7 +524,7 @@ def test_quantized_embedding(self): model = EmbeddingModule().eval() indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) weights = torch.randn(10, 12, dtype=torch.float32) - model.qconfig = float_qparams_dynamic_qconfig + model.qconfig = float_qparams_weight_only_qconfig prepare(model, inplace=True) convert(model, inplace=True) self.assertTrue('QuantizedEmbedding' in str(model)) @@ -533,47 +538,98 @@ def test_quantized_embedding(self): self.assertTrue('QuantizedLinear' in str(model)) self.checkQuantizedLinear(model.fc) + @skipIfNoFBGEMM + def test_embedding_linear_dynamic(self): + class EmbeddingWithLinearDynamic(torch.nn.Module): + def __init__(self): + super().__init__() + self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) + self.fc = torch.nn.Linear(5, 5) + + def forward(self, indices, linear_in): + return self.emb(indices), self.fc(linear_in) + + model = EmbeddingWithLinearDynamic() + qconfig_dict = {'fc' : default_dynamic_qconfig} + model = EmbeddingWithLinear() + quantize_dynamic(model, qconfig_dict, inplace=True) + model.emb.qconfig = float_qparams_weight_only_qconfig + prepare(model, inplace=True) + convert(model, inplace=True) + self.assertTrue('QuantizedEmbedding' in str(model)) + self.assertTrue('DynamicQuantizedLinear' in str(model)) + + + @skipIfNoFBGEMM + def test_dequant_stub(self): + m = QuantStubModel().eval() + prepare(m, inplace=True) + self.checkObservers(m) + convert(m, inplace=True) + self.assertEqual(type(m.quant), nnq.Quantize) + self.assertEqual(type(m.fc), nnq.Linear) + self.assertEqual(type(m.dequant), nnq.DeQuantize) + + # check DeQuantStub is not swapped when it doesn't have a qconfig + m2 = QuantStubModel().eval() + m2.dequant.qconfig = None + prepare(m2, inplace=True) + self.checkObservers(m2) + convert(m2, inplace=True) + self.assertEqual(type(m2.quant), nnq.Quantize) + self.assertEqual(type(m2.fc), nnq.Linear) + self.assertEqual(type(m2.dequant), DeQuantStub) + + + @skipIfNoFBGEMM def test_quantized_embedding_bag(self): r""" Test the post-training quantization flow, serialization and scripting of embedding_bag modules """ - model = EmbeddingBagModule().eval() indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) offsets = torch.tensor([0, 19, 20, 28, 28, 32]) weights = torch.randn(10, 12, dtype=torch.float32) - model.qconfig = float_qparams_dynamic_qconfig - prepare(model, inplace=True) - quantized_model = convert(model) + for dtype in [torch.quint8, torch.quint4x2]: + model = EmbeddingBagModule().eval() + float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype, + qscheme=torch.per_channel_affine_float_qparams, + ch_axis=0) + float_qparams_qconfig = QConfigDynamic(activation=default_dynamic_quant_observer, + weight=float_qparams_observer) + model.qconfig = float_qparams_qconfig - per_sample_weights = torch.from_numpy(np.random.uniform( - low=0.01, high=0.5, size=[len(indices)]).astype(np.float32)) + prepare(model, inplace=True) + quantized_model = convert(model) - # Test to make sure module is quantized correctly. - self.assertTrue('QuantizedEmbeddingBag' in str(quantized_model)) - self.checkDynamicQuantizedModule(quantized_model.emb, torch.nn.quantized.EmbeddingBag, torch.quint8) - self.checkScriptable(quantized_model, [[indices, offsets, per_sample_weights]], check_save_load=True) + per_sample_weights = torch.from_numpy(np.random.uniform( + low=0.01, high=0.5, size=[len(indices)]).astype(np.float32)) - class EmbeddingBagWithLinear(torch.nn.Module): - def __init__(self): - super().__init__() - self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, - include_last_offset=True, scale_grad_by_freq=False, mode='sum') - self.fc = torch.nn.Linear(5, 5) + # Test to make sure module is quantized correctly. + self.assertTrue('QuantizedEmbeddingBag' in str(quantized_model)) + self.checkDynamicQuantizedModule(quantized_model.emb, torch.nn.quantized.EmbeddingBag, torch.quint8) + self.checkScriptable(quantized_model, [[indices, offsets, per_sample_weights]], check_save_load=True) + + class EmbeddingBagWithLinear(torch.nn.Module): + def __init__(self): + super().__init__() + self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, + include_last_offset=True, scale_grad_by_freq=False, mode='sum') + self.fc = torch.nn.Linear(5, 5) - def forward(self, indices, offsets, per_sample_weights, linear_in): - return self.emb(indices, offsets, per_sample_weights), self.fc(linear_in) + def forward(self, indices, offsets, per_sample_weights, linear_in): + return self.emb(indices, offsets, per_sample_weights), self.fc(linear_in) - # Test quantization of embedding_bag layer only - model = EmbeddingBagWithLinear().eval() - model.emb.qconfig = float_qparams_dynamic_qconfig - prepare(model, inplace=True) - quantized_model = convert(model) + # Test quantization of embedding_bag layer only + model2 = EmbeddingBagWithLinear().eval() + model2.emb.qconfig = float_qparams_qconfig + prepare(model2, inplace=True) + quantized_model = convert(model2) - self.assertTrue('QuantizedEmbeddingBag' in str(quantized_model)) - self.checkLinear(model.fc) - self.checkDynamicQuantizedModule(quantized_model.emb, torch.nn.quantized.EmbeddingBag, torch.quint8) + self.assertTrue('QuantizedEmbeddingBag' in str(quantized_model)) + self.checkLinear(model2.fc) + self.checkDynamicQuantizedModule(quantized_model.emb, torch.nn.quantized.EmbeddingBag, torch.quint8) @skipIfNoFBGEMM def test_custom_module_class(self): @@ -617,9 +673,6 @@ def from_observed(cls, observed_module): quantized = cls(nnq.Conv2d.from_float(observed_module.conv)) return quantized - register_observed_custom_module_mapping(CustomModule, ObservedCustomModule) - register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule) - class M(torch.nn.Module): def __init__(self): super().__init__() @@ -660,14 +713,28 @@ def forward(self, x): original_ref_m.conv2.bias = torch.nn.Parameter(original_m.custom.conv.bias.detach()) original_m.qconfig = default_qconfig - m = prepare(original_m) - self.checkObservers(m) + prepare_custom_config_dict = { + "float_to_observed_custom_module_class": { + CustomModule: ObservedCustomModule + } + } + convert_custom_config_dict = { + "observed_to_quantized_custom_module_class": { + ObservedCustomModule: QuantizedCustomModule + } + } + m = prepare( + original_m, + prepare_custom_config_dict=prepare_custom_config_dict) + self.checkObservers(m, None, prepare_custom_config_dict) # calibration m(data) # all activation observers are inserted in the top level module # check converted/quantized model - m = convert(m) + m = convert( + m, + convert_custom_config_dict=convert_custom_config_dict) # check if the module is properly quantized self.assertEqual(type(m.quant), nnq.Quantize) self.assertEqual(type(m.conv), nnq.Conv2d) @@ -684,6 +751,20 @@ def forward(self, x): ref_res = ref_m(data) self.assertEqual(res, ref_res) + @skipIfNoFBGEMM + def test_convtranspose_per_channel_fails_early(self): + r""" + Verifies that attempting to quantize a ConvTranspose module with per-Channel + weight observers fails in the prepare step, as opposed to the convert step. + """ + m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1)) + m.qconfig = torch.quantization.get_default_qconfig('fbgemm') + with self.assertRaises(AssertionError) as context: + mp = torch.quantization.prepare(m) + self.assertTrue( + str(context.exception) == + 'Per channel weight observer is not supported yet for ConvTranspose{n}d.') + @skipIfNoFBGEMM class TestPostTrainingDynamic(QuantizationTestCase): @@ -918,40 +999,59 @@ def checkQuantized(model): def test_quantized_rnn(self, qconfig, dtype): r"""Test dynamic quantization, scriptability and serialization for dynamic quantized lstm modules on int8 and fp16 """ - model = RNNDynamicModel('LSTM').eval() niter = 10 x = torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1) qconfig_dict = { - torch.nn.LSTM : qconfig + torch.nn.LSTM : qconfig, + torch.nn.GRU: qconfig } - if dtype == torch.float16: - model_quantized = quantize_dynamic(model=model, dtype=dtype) - else: - model_quantized = quantize_dynamic(model=model, qconfig_spec=qconfig_dict, dtype=dtype) - # Smoke test extra reprs - self.assertTrue('DynamicQuantizedLSTM' in str(model_quantized)) - self.checkDynamicQuantizedModule(model_quantized.mod, torch.nn.quantized.dynamic.LSTM, dtype) - self.checkScriptable(model_quantized, [[x]], check_save_load=True) + def checkQuantized(model, module_type): + mod_type_map = {'LSTM': torch.nn.quantized.dynamic.LSTM, + 'GRU': torch.nn.quantized.dynamic.GRU} + mod_repr_map = {'LSTM': 'DynamicQuantizedLSTM', + 'GRU': 'DynamicQuantizedGRU'} + self.assertTrue(mod_repr_map[module_type] in str(model_quantized)) + self.checkDynamicQuantizedModule(model_quantized.mod, mod_type_map[module_type], dtype) + + for module_type in ['LSTM', 'GRU']: + model = RNNDynamicModel(module_type).eval() - class ScriptWrapperPacked(torch.nn.Module): - def __init__(self, cell): - super(ScriptWrapperPacked, self).__init__() - self.cell = cell + if dtype == torch.float16: + model_quantized = quantize_dynamic(model=model, dtype=dtype) + else: + model_quantized = quantize_dynamic(model=model, qconfig_spec=qconfig_dict, dtype=dtype) - def forward(self, - x # type: PackedSequence - ): - # type: (...) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]] - return self.cell(x) + checkQuantized(model_quantized, module_type) + self.checkScriptable(model_quantized, [[x]], check_save_load=True) - packed_input = torch.nn.utils.rnn.pack_padded_sequence(x, torch.tensor([10, 5, 2])) - model_with_packed_input = ScriptWrapperPacked(model_quantized.mod) - scripted = torch.jit.script(model_with_packed_input) - # We cannot trace with input dtype being a packed sequence - self._checkScriptable(model_with_packed_input, scripted, [[packed_input]], True) + class ScriptWrapperPackedLSTM(torch.nn.Module): + def __init__(self, cell): + super(ScriptWrapperPackedLSTM, self).__init__() + self.cell = cell + + def forward(self, x: PackedSequence) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]: + return self.cell(x) + + class ScriptWrapperPackedGRU(torch.nn.Module): + def __init__(self, cell): + super(ScriptWrapperPackedGRU, self).__init__() + self.cell = cell + + def forward(self, x: PackedSequence) -> Tuple[PackedSequence, torch.Tensor]: + return self.cell(x) + + script_wrapper_map = {'LSTM': ScriptWrapperPackedLSTM, + 'GRU': ScriptWrapperPackedGRU} + packed_input = torch.nn.utils.rnn.pack_padded_sequence(x, torch.tensor([10, 5, 2])) + model_with_packed_input = script_wrapper_map[module_type](model_quantized.mod) + model_with_packed_input(packed_input) + scripted = torch.jit.script(model_with_packed_input) + scripted(packed_input) + # We cannot trace with input dtype being a packed sequence + self._checkScriptable(model_with_packed_input, scripted, [[packed_input]], True) @given(qconfig=st.sampled_from([per_channel_dynamic_qconfig, default_dynamic_qconfig]), @@ -1062,7 +1162,7 @@ def checkQuantized(model): checkQuantized(model) model = quantize_qat(ManualLinearQATModel(qengine), test_only_train_fn, - self.train_data) + [self.train_data]) checkQuantized(model) def test_eval_only_fake_quant(self): @@ -1102,7 +1202,7 @@ def checkQuantized(model): checkQuantized(model) model = ManualConvLinearQATModel() - model = quantize_qat(model, test_only_train_fn, self.img_data_2d_train) + model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train]) checkQuantized(model) def test_train_save_load_eval(self): @@ -1198,6 +1298,178 @@ def checkHooksIsPresent(model, before_convert=True): torch.quantization.convert(model, inplace=True) checkHooksIsPresent(model, False) + def test_add_scalar_uses_input_qparams(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.quant = torch.quantization.QuantStub() + self.ff = torch.nn.quantized.FloatFunctional() + + def forward(self, x): + x = self.quant(x) + x = self.ff.add_scalar(x, 1.0) + return x + + m = M() + m.qconfig = torch.quantization.default_qconfig + mp = torch.quantization.prepare_qat(m) + mp(torch.randn(4, 4)) + mq = torch.quantization.convert(mp) + res = mq(torch.randn(4, 4)) + eps = 1e-5 + self.assertTrue(torch.abs(mq.quant.scale - res.q_scale()) < eps) + + def test_mul_scalar_uses_input_qparams(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.quant = torch.quantization.QuantStub() + self.ff = torch.nn.quantized.FloatFunctional() + + def forward(self, x): + x = self.quant(x) + x = self.ff.mul_scalar(x, 2.0) + return x + + m = M() + m.qconfig = torch.quantization.default_qconfig + mp = torch.quantization.prepare_qat(m) + mp(torch.randn(4, 4)) + mq = torch.quantization.convert(mp) + res = mq(torch.randn(4, 4)) + eps = 1e-5 + self.assertTrue(torch.abs(mq.quant.scale * 2 - res.q_scale()) < eps) + + +class TestEagerModeOps(QuantizationTestCase): + def _test_activation_op_impl( + self, float_module_class, quantized_module_class, extra_module_kwargs): + """ Implementation for testing common activation ops like leaky relu + Args: + extra_module_kwargs: keyword args to instantiate the float module + """ + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.activation_op = float_module_class(**extra_module_kwargs) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.activation_op(x) + x = self.dequant(x) + return x + + m = M().eval() + m.qconfig = default_qconfig + m = prepare(m) + self.checkObservers(m) + m = convert(m) + self.assertEqual(type(m.activation_op), quantized_module_class) + + def test_leaky_relu(self): + self._test_activation_op_impl(nn.LeakyReLU, nnq.LeakyReLU, {'negative_slope': 0.1, 'inplace': False}) + + def test_relu(self): + self._test_activation_op_impl(nn.ReLU, nn.ReLU, {'inplace': False}) + + +class TestEagerModeQATOps(QuantizationTestCase): + def _test_activation_convert_numerics_impl(self, Act, data): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.act = Act() + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.act(x) + x = self.dequant(x) + return x + + m = M().train() + m.qconfig = default_qat_qconfig + m = prepare_qat(m) + before_convert = m(data) + m = convert(m) + after_convert = m(data) + self.assertEqual(before_convert, after_convert) + + def test_fixed_qparam_ops(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.sigmoid = torch.nn.Sigmoid() + self.hardsigmoid = torch.nn.Hardsigmoid() + self.tanh = torch.nn.Tanh() + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.sigmoid(x) + x = self.hardsigmoid(x) + x = self.tanh(x) + x = self.dequant(x) + return x + + m = M().train() + m.qconfig = default_qat_qconfig + m = prepare_qat(m) + for attr in ['sigmoid', 'hardsigmoid', 'tanh']: + self.assertEqual(type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize) + data = torch.randn(1, 3, 2, 4) + before_convert = m(data) + m = convert(m) + after_convert = m(data) + self.assertEqual(before_convert, after_convert) + # make sure activation post process is removed + for attr in ['sigmoid', 'hardsigmoid', 'tanh']: + # verify fake quant module is removd + self.assertFalse(hasattr(getattr(m, attr), 'activation_post_process')) + # verify that hooks are removed + self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0) + + # make sure no fake quantize module is inserted for eval mode + + def checkNoFQModule(m): + for attr in ['sigmoid', 'hardsigmoid', 'tanh']: + self.assertFalse(hasattr(getattr(m, attr), "activation_post_process")) + self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0) + + m = M().eval() + m.qconfig = default_qconfig + m = prepare(m) + checkNoFQModule(m) + m = convert(m) + checkNoFQModule(m) + + def test_leaky_relu(self): + data = torch.randn(1, 3, 2, 4) + self._test_activation_convert_numerics_impl(nn.LeakyReLU, data) + + def test_relu(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = nn.ReLU() + + def forward(self, x): + x = self.relu(x) + return x + + m = M().train() + m.qconfig = default_qconfig + m = prepare_qat(m) + # make sure no activation_post_process is inserted for relu + self.assertFalse(hasattr(m, "activation_post_process")) + m = convert(m) + # make sure ReLU module is not changed + self.assertTrue(type(m.relu), nn.ReLU) + class TestFunctionalModule(QuantizationTestCase): # Histogram Observers are slow, so have no-deadline to ensure test doesn't time out @given(train_mode=st.booleans()) @@ -1284,7 +1556,7 @@ def checkQuantized(model): model = ModelForFusion(default_qat_qconfig).train() model = fuse_modules(model, [['conv1', 'bn1', 'relu1'], ['sub1.conv', 'sub1.bn']]) - model = quantize_qat(model, test_only_train_fn, self.img_data_1d_train) + model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train]) with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"): checkQuantized(model) @@ -1364,7 +1636,7 @@ def checkQuantized(model): ['bn2', 'relu3'], ['sub1.conv', 'sub1.bn'], ['conv3', 'bn3', 'relu4']]) - model = quantize(model, test_only_eval_fn, self.img_data_1d) + model = quantize(model, test_only_eval_fn, [self.img_data_1d]) checkQuantized(model) def test_fusion_sequential_model_train(self): @@ -1759,8 +2031,9 @@ def __init__(self, cell): self.cell = cell @torch.jit.script_method - def forward(self, x, hiddens): - # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor] + def forward(self, x: torch.Tensor, + hiddens: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: return self.cell(x, hiddens) else: @@ -1770,8 +2043,7 @@ def __init__(self, cell): self.cell = cell @torch.jit.script_method - def forward(self, x, hiddens): - # type: (torch.Tensor, torch.Tensor) -> torch.Tensor + def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> torch.Tensor: return self.cell(x, hiddens) cell = ScriptWrapper(cell) @@ -1883,8 +2155,7 @@ def __init__(self, cell): self.cell = cell @torch.jit.script_method - def forward(self, x, hiddens): - # type: (torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] + def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return self.cell(x, hiddens) compare_quantized_unquantized(ScriptWrapper, cell) diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index c1641ae3e1946..57fd992a8de03 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -3,35 +3,45 @@ import torch.nn as nn import torch.nn.quantized as nnq import torch.nn.quantized.dynamic as nnqd +import torch.nn.intrinsic as nni import torch.nn.intrinsic.quantized as nniq import torch.multiprocessing as mp -# symbolic trace -from torch.fx import symbolic_trace - # graph mode quantization based on fx -from torch.quantization import ( - QuantType, - fuse_fx, +from torch.quantization.quantize_fx import ( prepare_fx, convert_fx, - prepare_static_fx, - convert_static_fx, - quantize_static_fx, - quantize_dynamic_fx, prepare_qat_fx, - register_observed_custom_module_mapping, - register_quantized_custom_module_mapping, +) + +from torch.quantization.fx.pattern_utils import ( + is_match, + MatchAllNode, ) from torch.quantization import ( + QuantType, + QuantStub, + DeQuantStub, + QuantWrapper, + quant_type_to_str, default_qconfig, default_dynamic_qconfig, - float16_dynamic_qconfig, default_qat_qconfig, + per_channel_dynamic_qconfig, + float16_dynamic_qconfig, + float_qparams_weight_only_qconfig, + get_default_qconfig, + get_default_qat_qconfig, + fuse_modules, prepare, prepare_qat, convert, + quantize_dynamic, + default_placeholder_observer, + PerChannelMinMaxObserver, + QConfigDynamic, + FixedQParamsFakeQuantize, ) # test utils @@ -42,27 +52,184 @@ skip_if_no_torchvision, train_one_epoch, run_ddp, + test_only_eval_fn, + test_only_train_fn, +) + +from torch.testing._internal.common_quantization import ( + LinearModelWithSubmodule, + ResNetBase, + RNNDynamicModel, + RNNCellDynamicModel, ) from torch.testing._internal.common_quantized import ( + supported_qengines, override_qengines, + override_quantized_engine, ) from torch.testing._internal.common_distributed import skip_if_not_multigpu from torch.testing._internal.common_quantization import NodeSpec as ns -from torch.testing._internal.common_quantization import ( - test_only_eval_fn, -) from torch.testing import FileCheck +import copy import itertools import operator import unittest +import io +from typing import Callable + +class TestFuseFx(QuantizationTestCase): + def test_fuse_conv_bn_relu(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1d = nn.Conv1d(1, 1, 1) + self.conv2d = nn.Conv2d(1, 1, 1) + self.conv3d = nn.Conv3d(1, 1, 1) + self.bn1d = nn.BatchNorm1d(1) + self.bn2d = nn.BatchNorm2d(1) + self.bn3d = nn.BatchNorm3d(1) + self.conv1d2 = nn.Conv1d(1, 1, 1) + self.conv2d2 = nn.Conv2d(1, 1, 1) + self.conv3d2 = nn.Conv3d(1, 1, 1) + self.bn1d2 = nn.BatchNorm1d(1) + self.bn2d2 = nn.BatchNorm2d(1) + self.bn3d2 = nn.BatchNorm3d(1) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv1d(x) + x = self.bn1d(x) + x = self.conv2d(x) + x = self.bn2d(x) + x = self.conv3d(x) + x = self.bn3d(x) + x = self.conv1d2(x) + x = self.bn1d2(x) + x = self.relu(x) + x = self.conv2d2(x) + x = self.bn2d2(x) + x = self.relu(x) + x = self.conv3d2(x) + x = self.bn3d2(x) + x = self.relu(x) + return x + + # test train mode + m = M().train() + # currently we don't check if the module are configured with qconfig before fusion + # TODO: if we decide to do that in the future, this test needs to + # be updated + # train mode fuse_fx is called in prepare_qat_fx + m = prepare_qat_fx(m, {}) + expected_nodes = [ + ns.call_module(nni.ConvBn1d), + ns.call_module(nni.ConvBn2d), + ns.call_module(nni.ConvBn3d), + ns.call_module(nni.ConvBnReLU1d), + ns.call_module(nni.ConvBnReLU2d), + ns.call_module(nni.ConvBnReLU3d), + ] + expected_occurrence = { + ns.call_module(nn.ReLU): 0 + } + self.checkGraphModuleNodes( + m, + expected_node_list=expected_nodes, + expected_node_occurrence=expected_occurrence) + + # test eval mode + m = M().eval() + from torch.quantization.quantize_fx import fuse_fx + # fuse_fx is a top level api and only supports eval mode + m = fuse_fx(m) + expected_nodes = [ + ns.call_module(nn.Conv1d), + ns.call_module(nn.Conv2d), + ns.call_module(nn.Conv3d), + ns.call_module(nni.ConvReLU1d), + ns.call_module(nni.ConvReLU2d), + ns.call_module(nni.ConvReLU3d), + ] + # ConvBnRelu1d is not fused + expected_occurrence = { + ns.call_module(nn.ReLU): 0 + } + self.checkGraphModuleNodes( + m, + expected_node_list=expected_nodes, + expected_node_occurrence=expected_occurrence) + + def test_fuse_module_relu(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1d = nn.Conv1d(1, 1, 1) + self.conv2d = nn.Conv2d(1, 1, 1) + self.conv3d = nn.Conv3d(1, 1, 1) + self.bn1d = nn.BatchNorm1d(1) + self.bn2d = nn.BatchNorm2d(1) + self.bn3d = nn.BatchNorm3d(1) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv1d(x) + x = self.relu(x) + x = self.conv2d(x) + x = self.relu(x) + x = self.conv3d(x) + x = self.relu(x) + x = self.bn1d(x) + x = self.relu(x) + x = self.bn2d(x) + x = self.relu(x) + x = self.bn3d(x) + x = self.relu(x) + return x + + m = M().eval() + from torch.quantization.quantize_fx import fuse_fx + m = fuse_fx(m) + expected_nodes = [ + ns.call_module(nni.ConvReLU1d), + ns.call_module(nni.ConvReLU2d), + ns.call_module(nni.ConvReLU3d), + ns.call_module(nni.BNReLU2d), + ns.call_module(nni.BNReLU3d), + ] + self.checkGraphModuleNodes(m, expected_node_list=expected_nodes) @skipIfNoFBGEMM class TestQuantizeFx(QuantizationTestCase): + def test_pattern_match(self): + """ test MatchAllNode with + conv - bn - add - relu pattern + """ + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(1, 1, 1) + self.bn = nn.BatchNorm2d(1) + self.relu = nn.ReLU() + + def forward(self, x, y): + x = self.conv(x) + x = self.bn(x) + x = x + y + x = self.relu(x) + return x + + pattern = (nn.ReLU, (operator.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode)) + m = torch.fx.symbolic_trace(M()) + modules = dict(m.named_modules()) + for n in m.graph.nodes: + if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU: + self.assertTrue(is_match(modules, n, pattern)) + def _get_conv_linear_test_cases(self): ''' Returns a list of test cases, with format: is_dynamic, ModuleClass, module_constructor_inputs, @@ -154,11 +321,11 @@ def test_functional_debug(self): quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC node_occurrence = dict() if weight_prepack_node: - node_occurrence[weight_prepack_node] = 1 + node_occurrence[weight_prepack_node] = 0 + node_occurrence[quantized_node] = 0 self.checkGraphModeFxOp( ModuleClass(*module_constructor_inputs), inputs, quant_type, - expected_node=quantized_node, expected_node_occurrence=node_occurrence, debug=True) @@ -176,16 +343,94 @@ def forward(self, x): return F.linear(x, self.weight) m = M(torch.rand(1, 1)).eval() - original = symbolic_trace(m) qconfig = default_dynamic_qconfig qconfig_dict = {'': qconfig} - quantized = quantize_dynamic_fx(original, qconfig_dict, debug=True) + prepared = prepare_fx(m, qconfig_dict) + quantized = convert_fx(prepared, debug=True) qparams = (quantized._scale_0, quantized._zero_point_0) weight_obs = qconfig.weight() weight_obs(quantized.weight) ref_qparams = weight_obs.calculate_qparams() self.assertEqual(qparams, ref_qparams) + def test_conv_bn_relu(self): + convs = { + 1: nn.Conv1d, + 2: nn.Conv2d, + 3: nn.Conv3d, + } + bns = { + 1: nn.BatchNorm1d, + 2: nn.BatchNorm2d, + 3: nn.BatchNorm3d, + } + quantized_convs = { + 1: nnq.Conv1d, + 2: nnq.Conv2d, + 3: nnq.Conv3d, + } + quantized_conv_relus = { + 1: nniq.ConvReLU1d, + 2: nniq.ConvReLU2d, + 3: nniq.ConvReLU3d, + } + + class M(torch.nn.Module): + def __init__(self, dim, has_relu): + super().__init__() + self.conv = convs[dim](3, 3, 3) + self.bn = bns[dim](3) + self.relu = nn.ReLU() if has_relu else nn.Identity() + self.has_relu = has_relu + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.conv(x) + x = self.bn(x) + if self.has_relu: + x = self.relu(x) + x = self.dequant(x) + return x + + options = itertools.product([1, 2], [True, False], self.static_quant_types) + for dim, has_relu, quant_type in options: + expected_node = ns.call_module( + quantized_conv_relus[dim] if has_relu + else quantized_convs[dim]) + m = M(dim, has_relu) + m_eager = copy.deepcopy(m) + result = self.checkGraphModeFxOp( + m, + self.img_data_dict[dim], + quant_type, + expected_node=expected_node, + ) + + # check numerics + qengine = torch.backends.quantized.engine + if quant_type == QuantType.STATIC: + m_eager.eval() + qconfig = get_default_qconfig(qengine) + prepare_fn = prepare + else: + m_eager.train() + qconfig = get_default_qat_qconfig(qengine) + prepare_fn = prepare_qat + + fuse_list = ["conv", "bn"] + if has_relu: + fuse_list.append("relu") + fuse_modules(m_eager, fuse_list, inplace=True) + m_eager.qconfig = qconfig + m_eager = prepare_fn(m_eager) + m_eager(*self.img_data_dict[dim][0]) + m_eager = convert(m_eager) + result_eager = m_eager(*self.img_data_dict[dim][0]) + self.assertEqual(result, result_eager) + + @skipIfNoFBGEMM def test_dynamic_quant_fp16(self): class Linear(torch.nn.Module): @@ -222,14 +467,11 @@ def forward(self, x): for debug in [True, False]: node_occurrence = dict() if weight_prepack_node: - if debug: - node_occurrence[weight_prepack_node] = 1 - else: - node_occurrence[weight_prepack_node] = 0 + node_occurrence[weight_prepack_node] = 0 m = ModuleClass(*module_constructor_inputs).eval() - m = symbolic_trace(m) qconfig_dict = {"": float16_dynamic_qconfig} - m = quantize_dynamic_fx(m, qconfig_dict, debug=debug) + m = prepare_fx(m, qconfig_dict) + m = convert_fx(m, debug=debug) self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) @@ -261,12 +503,8 @@ def forward(self, x): device = torch.device('cuda:0') model.to(device) - # symbolically trace - model = symbolic_trace(model) - # QAT prepare - model = fuse_fx(model) - model = prepare_fx(model, qconfig_dict) + model = prepare_qat_fx(model, qconfig_dict) # ensure that running an input on CUDA works without any needed changes input = torch.randn(4, 1, 4, 4, device=device) @@ -279,27 +517,6 @@ def forward(self, x): model_device = next(iter(model_devices)) self.assertEqual(model_device, device) - @skipIfNoFBGEMM - def test_inplace_option(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 3, 3) - - def forward(self, x): - return self.conv(x) - - model = symbolic_trace(M().eval()) - qconfig_dict = {'': default_qconfig} - non_inplace_model = quantize_static_fx( - model, qconfig_dict, test_only_eval_fn, [self.img_data_2d], inplace=False) - inplace_model = model - inplace_model = quantize_static_fx( - inplace_model, qconfig_dict, test_only_eval_fn, [self.img_data_2d], inplace=True) - non_inplace_res = non_inplace_model(self.img_data_2d[0][0]) - inplace_res = inplace_model(self.img_data_2d[0][0]) - self.assertEqual(non_inplace_res, inplace_res) - @skipIfNoFBGEMM def test_dict_output(self): """ Make sure quantization runs for models with dictionary output @@ -313,13 +530,202 @@ def forward(self, x): return {"output": self.conv(x["input"])} dict_input = {"input": torch.randn(1, 1, 1, 1)} - m = symbolic_trace(M()).eval() + m = M().eval() qconfig_dict = {"": default_qconfig} - m = prepare_static_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict) m(dict_input) - m = convert_static_fx(m) + m = convert_fx(m) m(dict_input) + @override_qengines + def test_attention(self): + """ Make sure quantization runs for a corner case in attention module + """ + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + x = self.conv(x) + q, k, v = x.chunk(3, dim=0) + q = q.contiguous().view(-1, 1).transpose(0, 1) + k = k.contiguous().view(-1, 1).transpose(0, 1) + v = v.contiguous().view(-1, 1).transpose(0, 1) + torch._assert( + k.size(1) == 1, "key size should be equal to 1" + ) + r = torch.mm(k, v) + return q * k + r + + tensor_input = torch.randn(3, 1, 1, 1) + m = M().eval() + qconfig_dict = { + "": None, + "object_type": [ + (nn.Conv2d, default_qconfig), + ] + } + # make sure it runs + m = prepare_fx(m, qconfig_dict) + m(tensor_input) + m = convert_fx(m) + m(tensor_input) + + def _test_standalone_module( + self, + interface_config, + prepare_count_check, + standalone_prepare_count_check, + convert_count_check, + standalone_convert_count_check): + """ Test standalone module with different quantized input/quantized output + configurations + """ + class StandaloneModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + return self.conv(x) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1) + self.standalone = StandaloneModule() + + def forward(self, x): + x = self.conv(x) + x = self.standalone(x) + return x + + class RefM(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 1, 1) + self.conv2 = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + data = torch.randn(1, 1, 1, 1) + # instantiate M and RefM and align the parameters + original_m = M().eval() + original_ref_m = RefM().eval() + original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach()) + original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach()) + original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach()) + original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach()) + + for is_name in [True, False]: + if is_name: + prepare_config = { + "standalone_module_name": [("standalone", None, interface_config)] + } + else: + prepare_config = { + "standalone_module_class": [(StandaloneModule, None, interface_config)] + } + + original_m_copy = copy.deepcopy(original_m) + original_ref_m_copy = copy.deepcopy(original_ref_m) + + qconfig_dict = {"": default_qconfig} + # check prepared model + m = prepare_fx( + original_m_copy, qconfig_dict, prepare_custom_config_dict=prepare_config) + # calibration + m(data) + self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check) + self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check) + + # check converted/quantized model + m = convert_fx(m) + self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check) + self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check) + res = m(data) + + # quantize the reference model + ref_m = prepare_fx(original_ref_m_copy, qconfig_dict) + ref_m(data) + ref_m = convert_fx(ref_m) + ref_res = ref_m(data) + self.assertEqual(res, ref_res) + + def test_standalone_module_float_interface(self): + float_interface_config = { + "input_quantized_idxs": [], # float input + "output_quantized_idxs": [], # float output + } + interface_config = float_interface_config + # input and output of first conv, observer for standalone module + # will be inserted in the standalone module itself + prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 2 + } + # for input and output of conv in the standalone module + standalone_prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 2 + } + convert_count_check = { + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_module(nnq.Conv2d) : 1, + ns.call_method("dequantize") : 1, + } + standalone_convert_count_check = { + # standalone module will take float as input and output + # so we'll see quantize and dequantize in the modoule + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_module(nnq.Conv2d): 1, + ns.call_method("dequantize") : 1, + } + self._test_standalone_module( + interface_config, + prepare_count_check, + standalone_prepare_count_check, + convert_count_check, + standalone_convert_count_check) + + def test_standalone_module_quantized_interface(self): + quantized_interface_config = { + "input_quantized_idxs": [0], # quantized input + "output_quantized_idxs": [0], # quantized output + } + interface_config = quantized_interface_config + # observer for input and output of first conv + prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 2 + } + # for output of conv in the standalone module + standalone_prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 1 + } + convert_count_check = { + # quantizing input for conv + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_module(nnq.Conv2d) : 1, + # dequantizing output of standalone module + ns.call_method("dequantize") : 1, + } + standalone_convert_count_check = { + # quantization of input happens in parent module + # quantization of output happens in the quantized conv module + ns.call_function(torch.quantize_per_tensor) : 0, + ns.call_module(nnq.Conv2d): 1, + # dequantization for output happens in parent module + ns.call_method("dequantize") : 0, + } + self._test_standalone_module( + interface_config, + prepare_count_check, + standalone_prepare_count_check, + convert_count_check, + standalone_convert_count_check) + @skipIfNoFBGEMM def test_qconfig_none(self): class M(torch.nn.Module): @@ -334,13 +740,12 @@ def forward(self, x): return x m = M().eval() - m = symbolic_trace(m) qconfig_dict = {"": default_qconfig, "module_name": [("conv2", None)]} - m = prepare_static_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict) data = torch.randn(1, 1, 1, 1) m(data) - m = convert_static_fx(m) + m = convert_fx(m) m(data) # first conv is quantized, second conv is not quantized node_list = [ @@ -364,12 +769,11 @@ def forward(self, x): return x m = M().eval() - m = symbolic_trace(m) qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]} - m = prepare_static_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict) data = torch.randn(1, 1, 1, 1) m(data) - m = convert_static_fx(m) + m = convert_fx(m) m(data) # first conv is quantized, second conv is not quantized node_list = [ @@ -389,12 +793,11 @@ def forward(self, x, y): return x + y m = M().eval() - m = symbolic_trace(m) qconfig_dict = {"object_type": [(operator.add, default_qconfig)]} - m = prepare_static_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict) data = torch.randn(1, 1, 1, 1) m(data, data) - m = convert_static_fx(m) + m = convert_fx(m) m(data, data) # first conv is quantized, second conv is not quantized node_list = [ @@ -417,12 +820,11 @@ def forward(self, x): return x m = M().eval() - m = symbolic_trace(m) qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]} - m = prepare_static_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict) data = torch.randn(1, 1, 1, 1) m(data) - m = convert_static_fx(m) + m = convert_fx(m) m(data) # first conv is quantized, second conv is not quantized node_list = [ @@ -454,7 +856,6 @@ def forward(self, x): return x m = M().eval() - m = symbolic_trace(m) global_qconfig = default_qconfig object_type_qconfig = default_dynamic_qconfig module_name_regex_qconfig = float16_dynamic_qconfig @@ -464,13 +865,12 @@ def forward(self, x): "object_type": [(nn.Conv2d, object_type_qconfig)], "module_name_regex": [("module_conv*", module_name_regex_qconfig)], "module_name": [("module_conv2", module_name_qconfig)]} - m = prepare_static_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict) self.assertEqual(m.linear.qconfig, global_qconfig) self.assertEqual(m.conv.qconfig, object_type_qconfig) self.assertEqual(m.module_conv1.qconfig, module_name_regex_qconfig) self.assertEqual(m.module_conv2.qconfig, module_name_qconfig) - def test_remove_qconfig(self): class M(torch.nn.Module): def __init__(self): @@ -481,46 +881,141 @@ def forward(self, x): return self.avg_pool(x) m = M().eval() - m = symbolic_trace(m) qconfig_dict = {'': default_qconfig} - m = prepare_static_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict) data = torch.randn(1, 1, 1, 1) m(data) - m = convert_static_fx(m) + m = convert_fx(m) m(data) for name, module in m.named_modules(): self.assertFalse(hasattr(module, 'qconfig'), 'qconfig is not removed for ' + name) - @skipIfNoFBGEMM - def test_qat_and_script(self): - class TwoLayerLinear(nn.Module): + def test_default_quant_after_none_qconfig(self): + """ Make sure default quant is inserted properly""" + class M(torch.nn.Module): def __init__(self): - super(TwoLayerLinear, self).__init__() - self.fc1 = nn.Linear(5, 5) - self.fc2 = nn.Linear(5, 5) + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 1, 1) + self.conv2 = torch.nn.Conv2d(1, 1, 1) def forward(self, x): - x = self.fc1(x) - return self.fc2(x) + x = self.conv1(x) + x = x.transpose(1, 2) + x = self.conv2(x) - class Model(nn.Module): + m = M().eval() + qconfig_dict = { + "": default_qconfig, + "module_name": [ + ("conv1", None) + ] + } + m = prepare_fx(m, qconfig_dict) + m = convert_fx(m) + + def test_qconfig_for_call_method(self): + class Sub(torch.nn.Module): def __init__(self): - super(Model, self).__init__() - self.subm = TwoLayerLinear() - self.fc = nn.Linear(5, 5) + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1) def forward(self, x): - x = self.subm(x) - x = self.fc(x) - return x + x = x.transpose(2, 3) + x = self.conv(x) + return x.transpose(2, 3) - model = Model() + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.sub = Sub() + self.conv1 = torch.nn.Conv2d(1, 1, 1) + self.conv2 = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.sub(x) + x = self.conv2(x) + return x.transpose(2, 3) + + qconfig_dict1 = {"": default_qconfig, "module_name": [("sub", None)]} + # since sub is configured to have qconfig None, we should dequantize the output + # of self.conv1 and quantize the input of self.conv2 + # dequantize after conv2 should happen after transpose since + # it is configured with default_qconfig + # nodes in Sub module instance is not quantized + node_list1 = [ + ns.call_function(torch.quantize_per_tensor), + ns.call_module(nnq.Conv2d), + ns.call_method("dequantize"), + ns.call_method("transpose"), + ns.call_module(nn.Conv2d), + ns.call_method("transpose"), + ns.call_function(torch.quantize_per_tensor), + ns.call_module(nnq.Conv2d), + ns.call_method("transpose"), + ns.call_method("dequantize") + ] + + qconfig_dict2 = {"": None, "module_name": [("sub", default_qconfig)]} + # Only nodes in Sub module instance are quantized + # the first transpose is not quantized because the input is not quantized + node_list2 = [ + ns.call_module(nn.Conv2d), + ns.call_method("transpose"), + ns.call_function(torch.quantize_per_tensor), + ns.call_module(nnq.Conv2d), + ns.call_method("transpose"), + ns.call_method("dequantize"), + ns.call_module(nn.Conv2d), + ns.call_method("transpose"), + ] + + for qconfig_dict, node_list in [ + (qconfig_dict1, node_list1), + (qconfig_dict2, node_list2) + ]: + m = M().eval() + m = prepare_fx(m, qconfig_dict) + m(torch.randn(2, 1, 3, 3)) + m = convert_fx(m) + self.checkGraphModuleNodes(m, expected_node_list=node_list) + # make sure it runs + m(torch.randn(2, 1, 3, 3)) + + def test_preserve_attributes(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + return self.conv(x) + + m = M() + m.eval() + m.preserved_attr = 3 + prepare_custom_config_dict = { + "preserved_attributes": ["preserved_attr"] + } + m = prepare_fx(m, {"": default_qconfig}, prepare_custom_config_dict) + + def assertAttrPreserved(m): + self.assertTrue(hasattr(m, "preserved_attr")) + self.assertTrue(m.preserved_attr, 3) + + assertAttrPreserved(m) + convert_custom_config_dict = { + "preserved_attributes": ["preserved_attr"] + } + m = convert_fx(m, convert_custom_config_dict=convert_custom_config_dict) + assertAttrPreserved(m) + + @skipIfNoFBGEMM + def test_qat_and_script(self): + model = LinearModelWithSubmodule().train() qengine = torch.backends.quantized.engine qconfig_dict = {'': torch.quantization.get_default_qat_qconfig(qengine)} - - # symbolically trace - model = symbolic_trace(model) model = prepare_qat_fx(model, qconfig_dict) # ensure scripting works @@ -553,55 +1048,31 @@ def forward(self, x): @skipIfNoFBGEMM def test_save_observer_state_dict(self): - class TwoLayerLinear(nn.Module): - def __init__(self): - super(TwoLayerLinear, self).__init__() - self.fc1 = nn.Linear(5, 5) - self.fc2 = nn.Linear(5, 5) - - def forward(self, x): - x = self.fc1(x) - return self.fc2(x) - - class Model(nn.Module): - def __init__(self): - super(Model, self).__init__() - self.subm = TwoLayerLinear() - self.fc = nn.Linear(5, 5) - - def forward(self, x): - x = self.subm(x) - x = self.fc(x) - return x - - model = Model().eval() + orig = LinearModelWithSubmodule().eval() + model = orig qconfig_dict = {'': torch.quantization.get_default_qconfig('fbgemm')} - - # symbolically trace - model = symbolic_trace(model) - model = prepare_static_fx(model, qconfig_dict) + model = prepare_fx(model, qconfig_dict) # run it through input x = torch.randn(5, 5) model(x) - quant = convert_static_fx(model) + quant = convert_fx(model) # save state_dict of model - import io + obs_dict = torch.quantization.get_observer_state_dict(model) b = io.BytesIO() - torch.save(model.state_dict(), b) + torch.save(obs_dict, b) b.seek(0) # Load the stats into new model - model_2 = Model().eval() - model_2 = symbolic_trace(model_2) - model_2 = prepare_static_fx(model_2, qconfig_dict) + model_2 = orig + model_2 = prepare_fx(model_2, qconfig_dict) loaded_dict = torch.load(b) - model_2.load_state_dict(loaded_dict) + torch.quantization.load_observer_state_dict(model_2, loaded_dict) - quant_2 = convert_static_fx(model_2) + quant_2 = convert_fx(model_2) # Verify that loaded state dict produces same results. self.assertEqual(quant(x), quant_2(x)) @@ -611,70 +1082,291 @@ def test_custom_module_class(self): class CustomModule(torch.nn.Module): def __init__(self): super().__init__() - self.conv = torch.nn.Conv2d(1, 1, 1) + self.linear = torch.nn.Linear(3, 3) def forward(self, x): - return self.conv(x) + return self.linear(x) class ObservedCustomModule(torch.nn.Module): - def __init__(self, conv): + def __init__(self, linear): super().__init__() - self.conv = conv + self.linear = linear def forward(self, x): - return self.conv(x) + return self.linear(x) @classmethod def from_float(cls, float_module): assert hasattr(float_module, 'qconfig') - observed = cls(float_module.conv) + observed = cls(float_module.linear) observed.qconfig = float_module.qconfig return observed - class QuantizedCustomModule(torch.nn.Module): - def __init__(self, conv): + class StaticQuantCustomModule(torch.nn.Module): + def __init__(self, linear): super().__init__() - self.conv = conv + self.linear = linear def forward(self, x): - return self.conv(x) + return self.linear(x) @classmethod def from_observed(cls, observed_module): assert hasattr(observed_module, 'qconfig') assert hasattr(observed_module, 'activation_post_process') - observed_module.conv.activation_post_process = \ + observed_module.linear.activation_post_process = \ observed_module.activation_post_process - quantized = cls(nnq.Conv2d.from_float(observed_module.conv)) + quantized = cls(nnq.Linear.from_float(observed_module.linear)) return quantized - class DynamicallyQuantizedCustomModule(torch.nn.Module): - def __init__(self, conv): + class DynamicQuantCustomModule(torch.nn.Module): + def __init__(self, linear): super().__init__() - self.conv = conv + self.linear = linear def forward(self, x): - return self.conv(x) + return self.linear(x) @classmethod def from_observed(cls, observed_module): assert hasattr(observed_module, 'qconfig') - assert hasattr(observed_module, 'activation_post_process') - quantized = cls(nnqd.Conv2d.from_float(observed_module.conv)) + quantized = cls(nnqd.Linear.from_float(observed_module.linear)) return quantized class M(torch.nn.Module): def __init__(self): super().__init__() - self.conv = torch.nn.Conv2d(1, 1, 1) - self.custom = CustomModule() + self.linear = torch.nn.Linear(3, 3) + self.custom = CustomModule() + + def forward(self, x): + x = self.linear(x) + x = self.custom(x) + return x + + class RefM(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(3, 3) + self.linear2 = torch.nn.Linear(3, 3) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + data = torch.randn(3, 3) + # instantiate M and RefM and align the parameters + original_m = M().eval() + original_ref_m = RefM().eval() + original_ref_m.linear1.weight = torch.nn.Parameter(original_m.linear.weight.detach()) + original_ref_m.linear1.bias = torch.nn.Parameter(original_m.linear.bias.detach()) + original_ref_m.linear2.weight = torch.nn.Parameter(original_m.custom.linear.weight.detach()) + original_ref_m.linear2.bias = torch.nn.Parameter(original_m.custom.linear.bias.detach()) + + test_configs = { + "static": (default_qconfig, StaticQuantCustomModule, 3), + "dynamic": (default_dynamic_qconfig, DynamicQuantCustomModule, 0) + } + + for quant_type in [QuantType.DYNAMIC]: + key = quant_type_to_str(quant_type) + qconfig, quantized_module_class, num_observers = test_configs[key] + qconfig_dict = {"": qconfig} + if key == "static": + prepare_custom_config_dict = { + "float_to_observed_custom_module_class": { + "static": { + CustomModule: ObservedCustomModule + } + } + } + convert_custom_config_dict = { + "observed_to_quantized_custom_module_class": { + "static": { + ObservedCustomModule: quantized_module_class + } + } + } + else: + prepare_custom_config_dict = { + "non_traceable_module_class": [ + CustomModule + ] + } + convert_custom_config_dict = { + "observed_to_quantized_custom_module_class": { + "dynamic": { + CustomModule: quantized_module_class + } + } + } + + # check prepared model + m = prepare_fx( + original_m, + qconfig_dict, + prepare_custom_config_dict=prepare_custom_config_dict) + # calibration + m(data) + # all activation observers are inserted in the top level module + count_check = { + ns.call_module(torch.quantization.MinMaxObserver): num_observers + } + self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) + + # check converted/quantized model + m = convert_fx( + m, + convert_custom_config_dict=convert_custom_config_dict) + if quant_type == QuantType.STATIC: + count_check = { + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_module(nnq.Linear) : 1, + ns.call_method('dequantize') : 1, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) + self.assertEqual(type(m.custom), quantized_module_class) + res = m(data) + + # quantize the reference model + ref_m = prepare_fx(original_ref_m, qconfig_dict) + ref_m(data) + ref_m = convert_fx(ref_m) + ref_res = ref_m(data) + self.assertEqual(res, ref_res) + + @skipIfNoFBGEMM + def test_non_traceable_module(self): + class NonTraceable(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + for k in x.keys(): + print(x[k]) + return x + + class NonTraceable2(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + # data dependent control flow is not traceable + for i in x: + print(i) + return x + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.m1 = NonTraceable() + self.m2 = NonTraceable2() + + def forward(self, x): + x = self.m1(x) + x = self.m2(x) + return x + + m = M().eval() + qconfig_dict = {"": default_qconfig} + prepare_custom_config_dict = { + "non_traceable_module_name": [ + "m1" + ], + "non_traceable_module_class": [ + NonTraceable2 + ] + } + m = prepare_fx( + m, qconfig_dict, + prepare_custom_config_dict=prepare_custom_config_dict) + + node_occurrence = { + ns.call_module(NonTraceable) : 1, + ns.call_module(NonTraceable2) : 1, + } + # make sure these modules are not traced + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_prepared_model_deepcopy(self): + """Ensures that copy.deepcopy works correctly on a prepared model. + """ + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1) + self._foobar = 'foobar' + self.foobar2 = 'foobar2' + + def forward(self, x): + x = self.conv(x) + return x + + m = M() + m.eval() + qconfig_dict = {'': torch.quantization.default_qconfig} + prepared = prepare_fx(m, qconfig_dict) + # calibrate + prepared(torch.randn(4, 1, 4, 4)) + # copy + prepared_copy = copy.deepcopy(prepared) + # quantize, should run with no errors + quantized = convert_fx(prepared_copy) + + def test_dequantize(self): + r""" Test to make sure dequantize node are placed before + non-quantizable node + """ + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1) + self.act = torch.nn.GELU() + + def forward(self, x): + x = self.conv(x) + return self.act(x) + + data = torch.rand(5, 1, 3, 3, dtype=torch.float) + for quant_type in self.static_quant_types: + node_list = [ + ns.call_module(nnq.Conv2d), + ns.call_method("dequantize"), + ns.call_module(nn.GELU), + ] + self.checkGraphModeFxOp( + M().eval(), (data,), quant_type, expected_node_list=node_list) + + def test_sequential(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.convs = torch.nn.Sequential( + torch.nn.Conv2d(1, 1, 1), + torch.nn.Conv2d(1, 1, 1) + ) def forward(self, x): - x = self.conv(x) - x = self.custom(x) + x = self.convs(x) return x - class RefM(torch.nn.Module): + data = torch.rand(5, 1, 3, 3, dtype=torch.float) + for quant_type in self.static_quant_types: + node_list = [ + ns.call_module(nnq.Conv2d), + ns.call_module(nnq.Conv2d), + ] + self.checkGraphModeFxOp( + M().eval(), (data,), quant_type, expected_node_list=node_list) + + def _test_quantized_inputs_outputs( + self, prepare_custom_config_dict, prepare_count_check, + convert_count_check): + """ + Test the option to have inputs and outputs of the graph quantized + """ + class M(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = torch.nn.Conv2d(1, 1, 1) @@ -685,66 +1377,90 @@ def forward(self, x): x = self.conv2(x) return x - data = torch.randn(1, 1, 1, 1) - # instantiate M and RefM and align the parameters - original_m = M() - original_ref_m = RefM() - original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach()) - original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach()) - original_ref_m.conv2.weight = torch.nn.Parameter(original_m.custom.conv.weight.detach()) - original_ref_m.conv2.bias = torch.nn.Parameter(original_m.custom.conv.bias.detach()) - - from torch.fx.symbolic_trace import Tracer - - # define a custom tracer to not trace through the custom module - - class CustomTracer(Tracer): - def is_leaf_module(self, m, module_qualified_name): - return (m.__module__.startswith('torch.nn') and - not isinstance(m, torch.nn.Sequential)) or \ - isinstance(m, CustomModule) - - # TODO: add other quant types after mixed mode support - for quant_type in [QuantType.STATIC]: - # register observed and quantized custom module classes - register_observed_custom_module_mapping(CustomModule, ObservedCustomModule) - register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule) - - m = CustomTracer().trace(original_m).eval() - qconfig_dict = {'': default_qconfig} - # check prepared model - m = prepare_static_fx(m, qconfig_dict) - # calibration - m(data) - # all activation observers are inserted in the top level module - count_check = { - ns.call_module(torch.quantization.MinMaxObserver): 3 - } - self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) + # quantized input, quantized output + m = M() + qconfig_dict = {'': torch.quantization.default_qconfig} + m.eval() + mp = torch.quantization.quantize_fx.prepare_fx( + m, qconfig_dict, + prepare_custom_config_dict=prepare_custom_config_dict) + self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check) + mp(torch.randn(1, 1, 4, 4)) + mq = torch.quantization.quantize_fx.convert_fx(mp) + self.checkGraphModuleNodes(mq, expected_node_occurrence=convert_count_check) + + def test_quantized_input_quantized_output(self): + prepare_custom_config_dict = { + 'input_quantized_idxs': [0], 'output_quantized_idxs': [0]} + prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 2, + } + convert_count_check = { + ns.call_function(torch.quantize_per_tensor): 0, + ns.call_method('dequantize'): 0, + } + self._test_quantized_inputs_outputs( + prepare_custom_config_dict, prepare_count_check, convert_count_check) + + def test_fp32_input_quantized_output(self): + prepare_custom_config_dict = { + 'output_quantized_idxs': [0]} + prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 3, + } + convert_count_check = { + ns.call_function(torch.quantize_per_tensor): 1, + ns.call_method('dequantize'): 0, + } + self._test_quantized_inputs_outputs( + prepare_custom_config_dict, prepare_count_check, convert_count_check) + + def test_quantized_input_fp32_output(self): + prepare_custom_config_dict = { + 'input_quantized_idxs': [0]} + prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 2, + } + convert_count_check = { + ns.call_function(torch.quantize_per_tensor): 0, + ns.call_method('dequantize'): 1, + } + self._test_quantized_inputs_outputs( + prepare_custom_config_dict, prepare_count_check, convert_count_check) - # check converted/quantized model - m = convert_static_fx(m) - count_check = { - ns.call_function(torch.quantize_per_tensor) : 1, - ns.call_module(nnq.Conv2d) : 1, - ns.call_method('dequantize') : 1, - } - self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) - res = m(data) + def test_fp32_input_fp32_output(self): + prepare_custom_config_dict = {} + prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 3, + } + convert_count_check = { + ns.call_function(torch.quantize_per_tensor): 1, + ns.call_method('dequantize'): 1, + } + self._test_quantized_inputs_outputs( + prepare_custom_config_dict, prepare_count_check, convert_count_check) - # quantize the reference model - ref_m = symbolic_trace(original_ref_m).eval() - ref_m = prepare_fx(ref_m, qconfig_dict) - ref_m(data) - ref_m = convert_fx(ref_m) - ref_res = ref_m(data) - self.assertEqual(res, ref_res) + @skipIfNoFBGEMM + def test_convtranspose_per_channel_fails_early(self): + r""" + Verifies that attempting to quantize a ConvTranspose module with per-Channel + weight observers fails in the prepare step, as opposed to the convert step. + """ + m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1)) + m.eval() + qconfig_dict = {'': torch.quantization.get_default_qconfig('fbgemm')} + with self.assertRaises(AssertionError) as context: + mp = prepare_fx(m, qconfig_dict) + self.assertTrue( + str(context.exception) == + 'Per channel weight observer is not supported yet for ConvTranspose{n}d.') +@skipIfNoFBGEMM class TestQuantizeFxOps(QuantizationTestCase): """Unit tests for individual ops """ @skipIfNoFBGEMM - def test_linear(self): + def test_linear_module(self): class ModuleLinear(torch.nn.Module): def __init__(self, has_relu=False, f_relu=False): super(ModuleLinear, self).__init__() @@ -760,27 +1476,9 @@ def __init__(self, has_relu=False, f_relu=False): def forward(self, x): return self.relu(self.linear(x)) - class FuncLinear(torch.nn.Module): - def __init__(self, has_relu=False, f_relu=False): - super(FuncLinear, self).__init__() - self.w = torch.randn(4, 30) - self.b = torch.randn(4) - if has_relu: - if f_relu: - self.relu = F.relu - else: - self.relu = torch.nn.ReLU() - else: - self.relu = torch.nn.Identity() - - def forward(self, x): - return self.relu(F.linear(x, self.w, self.b)) - data = (torch.rand((1, 30), dtype=torch.float),) options = itertools.product( [(ModuleLinear(has_relu=False), True)], - # TODO: enable after raw `tensor` is supported in fx - # (FuncLinear(has_relu=False), False)], self.all_quant_types) quantized_nodes = { # is_module @@ -791,12 +1489,6 @@ def forward(self, x): # note that we are checking the final result QuantType.QAT: ns.call_module(nnq.Linear), }, - False: { - # quant_type: - QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic), - QuantType.STATIC: ns.call_function(torch.ops.quantized.linear), - QuantType.QAT: ns.call_function(torch.ops.quantized.linear), - } } for (model, is_module), quant_type in options: self.checkGraphModeFxOp( @@ -805,17 +1497,65 @@ def forward(self, x): for f_relu, quant_type in itertools.product([True, False], [QuantType.STATIC, QuantType.QAT]): for model, quantized_node in [ (ModuleLinear(has_relu=True, f_relu=f_relu), ns.call_module(nniq.LinearReLU))]: - # TODO: support functional linear + relu fusion - # (FuncLinear(has_relu=True, f_relu=f_relu), ns.call_function(torch.ops.quantized.linear_relu))]: self.checkGraphModeFxOp(model, data, quant_type, quantized_node) @skipIfNoFBGEMM - def test_quantized_conv(self): + def test_linear_functional(self): + + class FuncLinear(torch.nn.Module): + def __init__(self, use_bias): + super(FuncLinear, self).__init__() + self.w = torch.randn(4, 30) + self.b = torch.randn(4) + self.use_bias = use_bias + + def forward(self, x): + if self.use_bias: + x = F.linear(x, self.w, self.b) + else: + x = F.linear(x, self.w) + return x + + data = (torch.rand((1, 30), dtype=torch.float),) + quant_type_to_qlinear_fun = { + QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic), + QuantType.STATIC: ns.call_function(torch.ops.quantized.linear), + QuantType.QAT: ns.call_function(torch.ops.quantized.linear), + } + quant_type_to_prepare_expected_node_occurrence = { + QuantType.DYNAMIC: {}, + # There should be 3 observers: after input, weight and activation. + QuantType.STATIC: { + ns.call_module(torch.quantization.HistogramObserver): 2, + ns.call_module(torch.quantization.PerChannelMinMaxObserver): 1, + }, + # There should be 3 observers: after input, weight and activation. + QuantType.QAT: { + ns.call_module(torch.quantization.FakeQuantize): 3, + }, + } + options = itertools.product( + (QuantType.DYNAMIC, QuantType.STATIC, QuantType.QAT), + (True, False), # use_bias + ) + for quant_type, use_bias in options: + model = FuncLinear(use_bias) + qlinear_fun = quant_type_to_qlinear_fun[quant_type] + prepare_expected_node_occurrence = \ + quant_type_to_prepare_expected_node_occurrence[quant_type] + self.checkGraphModeFxOp( + model, data, quant_type, qlinear_fun, + prepare_expected_node_occurrence=prepare_expected_node_occurrence) + + # TODO(future PR): test for Linear + ReLU fusion + + @skipIfNoFBGEMM + def test_conv_module(self): conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d} - class Conv(torch.nn.Module): + class ConvWrapper(torch.nn.Module): def __init__(self, dim): - super(Conv, self).__init__() + super(ConvWrapper, self).__init__() self.conv = conv_module[dim](3, 3, 3).float() def forward(self, x): @@ -830,9 +1570,27 @@ def forward(self, x): } for dim, quant_type in options: model = self.checkGraphModeFxOp( - Conv(dim), self.img_data_dict[dim], quant_type, + ConvWrapper(dim), self.img_data_dict[dim], quant_type, quantized_nodes[dim]) + @skipIfNoFBGEMM + def test_conv2d_functional(self): + for bias in [True, False]: + conv = torch.nn.Conv2d(1, 1, 1, bias=bias) + # There should be 3 observers: after input, weight and activation. + # No observer after bias. + prepare_expected_node_occurrence = { + ns.call_module(torch.quantization.HistogramObserver): 2, + ns.call_module(torch.quantization.PerChannelMinMaxObserver): 1, + } + expected_node_occurrence = \ + {ns.call_function(torch.ops.quantized.conv2d): 1} + self.checkGraphModeFxOp( + conv, (torch.randn(4, 1, 4, 4),), QuantType.STATIC, + prepare_expected_node_occurrence=prepare_expected_node_occurrence, + expected_node_occurrence=expected_node_occurrence, + ) + @skipIfNoFBGEMM def test_quantized_conv_relu(self): """tests for conv1d_relu/conv2d_relu/conv3d_relu""" @@ -892,7 +1650,10 @@ def __init__(self, is_inplace, is_scalar): def forward(self, x, y): x = self.conv1(x) y = 3 if self.is_scalar else self.conv2(y) + # x = x + y x = self.op(x, y) + # x = y + x + x = self.op(y, x) return x # TODO: decide whether we want to quantize or not @@ -935,6 +1696,8 @@ def forward(self, x, y): y = 3 if self.is_scalar else self.conv2(y) x = self.op(x, y) x = self.relu(x) + x = self.op(y, x) + x = self.relu(x) return x data = (torch.rand((1, 1, 1, 1), dtype=torch.float), @@ -968,6 +1731,135 @@ def test_quantized_mul_relu(self): self._test_quantized_binary_op_relu_impl( operator.mul, operator.imul, torch.ops.quantized.mul_relu) + # TODO(future PR): make more generic + def _test_quantized_add_mul_qat(self, model, expected_node_occurrence): + qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')} + mp = torch.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict) + self.checkGraphModuleNodes( + mp, expected_node_occurrence=expected_node_occurrence) + + @skipIfNoFBGEMM + def test_quantized_add_qat(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 1, 1) + self.conv2 = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + x = torch.add(x, 1.0) + x = self.conv1(x) + x = torch.add(x, 1.0) + x = torch.relu(x) + x = self.conv2(x) + return x + + m = M() + expected_node_occurrence = { + ns.call_module(torch.quantization.FakeQuantize): 4, + } + self._test_quantized_add_mul_qat(m, expected_node_occurrence) + + @skipIfNoFBGEMM + def test_quantized_mul_qat(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 1, 1) + self.conv2 = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + x = torch.mul(x, 1.0) + x = self.conv1(x) + x = torch.mul(x, 1.0) + x = torch.relu(x) + x = self.conv2(x) + return x + + m = M() + expected_node_occurrence = { + ns.call_module(torch.quantization.FakeQuantize): 4, + } + self._test_quantized_add_mul_qat(m, expected_node_occurrence) + + def test_int8_input_no_unnecessary_fq(self): + """ + If the inputs to the graph are quantized and the only node + does not need an activation observer, verifies that the + activation observer is not inserted. + """ + class M(nn.Module): + def __init__(self, scalar): + super().__init__() + self.scalar = scalar + self.add_func = torch.nn.quantized.FloatFunctional() + + def forward(self, x): + return self.add_func.add_scalar(x, self.scalar) + + m = M(0.5) + mp = torch.quantization.quantize_fx.prepare_qat_fx( + m, {'': torch.quantization.get_default_qat_qconfig('fbgemm')}, + prepare_custom_config_dict={"input_quantized_idxs": [0]}) + expected_node_occurrence = { + ns.call_module(torch.quantization.FakeQuantize): 0, + } + self.checkGraphModuleNodes( + mp, expected_node_occurrence=expected_node_occurrence) + + def test_quant_output_always_observed(self): + """ + If the output is hardcoded to be quantized, ensure that + there is always an observer, even if the last non-output node is not + quantizeable. + """ + qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')} + prepare_custom_config_dict = {'output_quantized_idxs': [0]} + data = (torch.randn(4, 1, 4, 4),) + + # non-quantizeable node, quantized output + class M1(torch.nn.Module): + def __init__(self): + super().__init__() + self.identity = torch.nn.Identity() + + def forward(self, x): + x = self.identity(x) + return x + + m1 = M1() + self.checkGraphModeFxOp( + m1, data, QuantType.QAT, + prepare_expected_node_occurrence={ + ns.call_module(torch.quantization.FakeQuantize): 1, + }, + expected_node_occurrence={ + ns.call_function(torch.quantize_per_tensor): 1, + }, + prepare_custom_config_dict=prepare_custom_config_dict) + + # quantizeable node, quantized output + class M2(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + x = self.conv(x) + return x + + m2 = M2() + self.checkGraphModeFxOp( + m2, data, QuantType.QAT, + prepare_expected_node_occurrence={ + # one for weights, one for activations + ns.call_module(torch.quantization.FakeQuantize): 2, + }, + expected_node_occurrence={ + ns.call_function(torch.quantize_per_tensor): 1, + }, + prepare_custom_config_dict=prepare_custom_config_dict) + @skipIfNoFBGEMM def test_quantized_cat(self): """ quantization of the output of cat will be depend on the @@ -1106,6 +1998,9 @@ def test_hardswish(self): def test_elu(self): self._test_activation_impl(nn.ELU, F.elu, nnq.ELU, torch.ops.quantized.elu) + def test_leaky_relu(self): + self._test_activation_impl(nn.LeakyReLU, F.leaky_relu, nnq.LeakyReLU, torch.ops.quantized.leaky_relu) + def _test_norm_impl( self, float_module, float_op, op_args, data, quantized_module, quantized_op, skip_op_arg_for_functional=False): @@ -1295,11 +2190,10 @@ def forward(self, x): data = torch.rand(1, 3, 10, 10) # This model is not executable since we just put all ops # in the same forward - m = M() - original = symbolic_trace(m) + m = M().eval() # nothing to fuse so skipping the fuse step qconfig_dict = {'': default_qconfig} - prepared = prepare_fx(original, qconfig_dict) + prepared = prepare_fx(m, qconfig_dict) # not runnable quantized = convert_fx(prepared) @@ -1332,7 +2226,7 @@ def test_general_value_ops(self): """ class M(torch.nn.Module): def __init__(self): - super(M, self).__init__() + super().__init__() self.conv = torch.nn.Conv2d(3, 3, 3) self.avg_pool1d = torch.nn.AvgPool1d(3) self.avg_pool2d = torch.nn.AvgPool2d(3) @@ -1340,10 +2234,6 @@ def __init__(self): self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1)) self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1)) - self.leaky_relu = torch.nn.LeakyReLU() - self.hardsigmoid = torch.nn.Hardsigmoid() - self.sigmoid = torch.nn.Sigmoid() - self.tanh = torch.nn.Tanh() def forward(self, x): x = self.conv(x) @@ -1365,36 +2255,15 @@ def forward(self, x): x = x.mean([2, 3], True) x = F.interpolate(x, 4, mode='nearest') x = F.interpolate(x, 4, mode='linear') - x = self.leaky_relu(x) - x = F.leaky_relu(x) - x = F.leaky_relu(x, inplace=True) - x = x.leaky_relu() - x.leaky_relu_() - x = self.hardsigmoid(x) - x = F.hardsigmoid(x) - x = F.hardsigmoid(x, inplace=True) - x = x.hardsigmoid() - x.hardsigmoid_() - x = self.sigmoid(x) - x = torch.sigmoid(x) - # F.sigmoid is deprecated - x = x.sigmoid() - x.sigmoid_() - x = self.tanh(x) - # F.tanh is deprecated - x = torch.tanh(x) - x = x.tanh() - x.tanh_() x = self.conv(x) return x # This model is not executable since we just put all ops # in the same forward - m = M() - original = symbolic_trace(m) + m = M().eval() # nothing to fuse so skipping the fuse step qconfig_dict = {'': default_qconfig} - prepared = prepare_fx(original, qconfig_dict) + prepared = prepare_fx(m, qconfig_dict) # not runnable quantized = convert_fx(prepared) @@ -1417,6 +2286,312 @@ def forward(self, x): expected_node_occurrence=count_check, expected_node_list=order_check) + @skipIfNoFBGEMM + def test_fixed_qparams_ops(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + self.sigmoid = torch.nn.Sigmoid() + self.hardsigmoid = torch.nn.Hardsigmoid() + self.tanh = torch.nn.Tanh() + + def forward(self, x): + x = self.conv(x) + # F.sigmoid is deprecated + x = self.sigmoid(x) + x = torch.sigmoid(x) + x = x.sigmoid() + x.sigmoid_() + x = self.hardsigmoid(x) + x = F.hardsigmoid(x) + x = F.hardsigmoid(x, inplace=True) + x = x.hardsigmoid() + x.hardsigmoid_() + x = self.tanh(x) + # F.tanh is deprecated + x = torch.tanh(x) + x = x.tanh() + x.tanh_() + x = self.conv(x) + return x + + for eval_mode in [True, False]: + # This model is not executable since we just put all ops + # in the same forward + m = M() + if eval_mode: + m.eval() + qconfig = default_qconfig + prepare = prepare_fx + fq_count = 0 + else: + m.train() + qconfig = default_qat_qconfig + prepare = prepare_qat_fx + fq_count = 13 + + # nothing to fuse so skipping the fuse step + qconfig_dict = {'': qconfig} + prepared = prepare(m, qconfig_dict) + # check the correct number of activation_post_process is inserted + count_check = { + ns.call_module(FixedQParamsFakeQuantize) : fq_count, + } + self.checkGraphModuleNodes( + prepared, + expected_node_occurrence=count_check) + # not runnable + quantized = convert_fx(prepared) + + # This checks that the dequantize from the output of first conv + # is being propagated to the end, so that we don't insert extra + # observers + # check exact counts of quantize and dequantize + count_check = { + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_method('dequantize') : 1 + } + order_check = [ + ns.call_function(torch.quantize_per_tensor), + ns.call_module(nnq.Conv2d), + ns.call_module(nn.Sigmoid), + ns.call_module(nnq.Conv2d), + ns.call_method('dequantize'), + ] + self.checkGraphModuleNodes( + quantized, + expected_node_occurrence=count_check, + expected_node_list=order_check) + + def test_float_functional(self): + class TorchAdd(nn.Module): + """Wrapper around torch.add so that all ops can be found at build""" + def __init__(self): + super().__init__() + self.add_func = nnq.FloatFunctional() + + def forward(self, x, y): + return self.add_func.add(x, y) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.ff1 = TorchAdd() + self.ff2 = nnq.FloatFunctional() + self.ff3 = nnq.FloatFunctional() + self.ff4 = nnq.FloatFunctional() + self.ff5 = nnq.FloatFunctional() + self.ff6 = nnq.FloatFunctional() + + def forward(self, x): + x = self.ff1(x, x) + x = self.ff2.add_scalar(x, 3) + x = self.ff3.mul(x, x) + x = self.ff4.mul_scalar(x, 3) + x = self.ff5.add_relu(x, x) + x = self.ff6.cat([x]) + return x + + data = torch.rand(3, 3) + # Note: QAT test succeeded by chance, to make it actually work + # we need to fix eager mode FloatFunctional by removing + # activation_post_process in add_scalar and mul_scalar + for quant_type in self.static_quant_types: + m = M() + ref_m = torch.quantization.QuantWrapper(M()) + is_qat = quant_type == QuantType.QAT + if is_qat: + m.train() + ref_m.train() + qconfig = default_qat_qconfig + expected_act_post_process = torch.quantization.FakeQuantize + else: + m.eval() + ref_m.eval() + qconfig = default_qconfig + expected_act_post_process = torch.quantization.MinMaxObserver + + prepare_fx_function = prepare_qat_fx if is_qat else prepare_fx + qconfig_dict = {"": qconfig} + m = prepare_fx_function(m, qconfig_dict) + node_occurrence = { + ns.call_module(expected_act_post_process): 5, + ns.call_module(torch.nn.quantized.FloatFunctional): 0 + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + m(data) + node_list = [ + ns.call_function(torch.quantize_per_tensor), + ns.call_function(torch.ops.quantized.add), + ns.call_function(torch.ops.quantized.add), + ns.call_function(torch.ops.quantized.mul), + ns.call_function(torch.ops.quantized.mul), + ns.call_function(torch.ops.quantized.add_relu), + ns.call_function(torch.ops.quantized.cat), + ns.call_method('dequantize') + ] + m = convert_fx(m) + self.checkGraphModuleNodes(m, expected_node_list=node_list) + + # make sure numerics match with eager mode + ref_m.qconfig = qconfig + prepare_function = prepare_qat if is_qat else prepare + ref_m = prepare_function(ref_m) + ref_m(data) + ref_m = convert(ref_m) + self.assertEqual(m(data), ref_m(data)) + + def test_embedding(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) + + def forward(self, indices): + return self.emb(indices) + + model = M().eval() + indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) + quantized_node = ns.call_module(nnq.Embedding) + configs = [ + (float_qparams_weight_only_qconfig, ns.call_module(nnq.Embedding)), + (None, ns.call_module(nn.Embedding)), + (default_qconfig, ns.call_module(nn.Embedding)), + ] + + for qconfig, node in configs: + qconfig_dict = {"": qconfig} + m = prepare_fx(model, qconfig_dict) + self.checkGraphModuleNodes(m, expected_node_occurrence={ + ns.call_module(torch.quantization.MinMaxObserver): 0 + }) + m = convert_fx(m) + self.checkGraphModuleNodes(m, expected_node=node) + # make sure it runs + m(indices) + + def test_embedding_bag(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True) + + def forward(self, indices, offsets): + return self.emb(indices, offsets) + + indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) + offsets = torch.tensor([0, 19, 20, 28, 28, 32]) + quantized_node = ns.call_module(nnq.EmbeddingBag) + inputs = (indices, offsets) + + for dtype in [torch.quint8, torch.quint4x2]: + model = M().eval() + float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype, + qscheme=torch.per_channel_affine_float_qparams, + ch_axis=0) + float_qparams_qconfig = QConfigDynamic(activation=default_placeholder_observer, + weight=float_qparams_observer) + self.checkGraphModeFxOp( + model, + inputs, + QuantType.DYNAMIC, + quantized_node, + custom_qconfig=float_qparams_qconfig + ) + + # check it works in None and static qconfig + for qconfig in [None, default_qconfig]: + qconfig_dict = {"": default_qconfig} + m = M().eval() + m = prepare_fx(model, qconfig_dict) + self.checkGraphModuleNodes(m, expected_node_occurrence={ + ns.call_module(torch.quantization.MinMaxObserver): 0 + }) + m = convert_fx(m) + self.checkGraphModuleNodes(m, expected_node=ns.call_module(nn.EmbeddingBag)) + # make sure it runs + m(*inputs) + + def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_input): + options = itertools.product(qconfigs, module_type_strs) + for qconfig, module_type_str in options: + model_eager = M(module_type_str).eval() + model_graph = copy.deepcopy(model_eager) + if torch.backends.quantized.engine == 'qnnpack' and \ + qconfig is float16_dynamic_qconfig: + continue + # fp16 dynamic quant is not supported for qnnpack + + eager_qconfig_dict = {x : qconfig for x in module_types} + model_eager = quantize_dynamic(model_eager, qconfig_spec=eager_qconfig_dict) + + graph_qconfig_dict = { + "object_type": [ + (x, qconfig) for x in module_types + ] + } + model_graph = prepare_fx(model_graph, graph_qconfig_dict) + model_graph = convert_fx(model_graph) + self.assertEqual(model_eager(sample_input), model_graph(sample_input)) + self.checkScriptable(model_graph, [[sample_input]], True) + + def test_rnn_cell(self): + qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig] + module_type_strs = ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU'] + module_types = [torch.nn.LSTMCell, torch.nn.GRUCell, torch.nn.RNNCell] + sample_input = torch.tensor([[100, -155], + [-155, 100], + [100, -155]], dtype=torch.float) + self._test_rnn_impl(qconfigs, RNNCellDynamicModel, module_type_strs, module_types, sample_input) + + def test_rnn(self): + qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig] + module_type_strs = ['LSTM'] + module_types = [torch.nn.LSTM] + niter = 10 + sample_input = torch.tensor([[100, -155], + [-155, 100], + [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1) + self._test_rnn_impl(qconfigs, RNNDynamicModel, module_type_strs, module_types, sample_input) + + def _test_conv_transpose_impl( + self, float_cls: Callable, q_cls: Callable, data: torch.Tensor): + with override_quantized_engine('qnnpack'): + # Create fp32 versions of FX and Eager models + m1 = torch.nn.Sequential(float_cls(1, 1, 1)) + m2 = torch.nn.Sequential(float_cls(1, 1, 1)) + m2.load_state_dict(m1.state_dict()) + m2 = torch.quantization.QuantWrapper(m2) + # FX graph + q_result1 = self.checkGraphModeFxOp( + m1, (data,), QuantType.STATIC, + expected_node_occurrence={ + ns.call_module(q_cls): 1, + }) + # Eager + m2.qconfig = get_default_qconfig(torch.backends.quantized.engine) + m2.eval() + m2p = torch.quantization.prepare(m2) + m2p(data) + m2q = torch.quantization.convert(m2p) + q_result2 = m2q(data) + # verify results match + self.assertTrue(torch.allclose(q_result1, q_result2)) + + @unittest.skipUnless('qnnpack' in supported_qengines, + "This Pytorch Build has not been built with or does not support QNNPACK") + def test_conv_transpose_1d(self): + self._test_conv_transpose_impl( + torch.nn.ConvTranspose1d, nnq.ConvTranspose1d, torch.randn(4, 1, 4)) + + @unittest.skipUnless('qnnpack' in supported_qengines, + "This Pytorch Build has not been built with or does not support QNNPACK") + def test_conv_transpose_2d(self): + self._test_conv_transpose_impl( + torch.nn.ConvTranspose2d, nnq.ConvTranspose2d, torch.randn(4, 1, 4, 4)) + + class TestQuantizeFxModels(QuantizationTestCase): def _test_model_impl( self, mode, name, model, eager_quantizable_model, @@ -1443,24 +2618,21 @@ def _test_model_impl( qconfig = default_qconfig if mode == 'static' else default_qat_qconfig qconfig_dict = {'': qconfig} - graph_module = symbolic_trace(model) # print('graph module:', graph_module.src) - script = torch.jit.script(graph_module) + script = torch.jit.script(model) # make sure graph module and script module are both runanble - original_out = graph_module(input_value) + original_out = model(input_value) is_not_tuple_out = not isinstance(original_out, tuple) script_out = script(input_value) - self.assertEqual( - (original_out - script_out).abs().max(), 0, - 'Reslut of original graph module and script module does not match') # set to train just before quantization + prepare_fx_fn = prepare_fx if mode != 'static': model.train() + prepare_fx_fn = prepare_qat_fx - graph_module = fuse_fx(graph_module) - prepared = prepare_fx(graph_module, qconfig_dict) + prepared = prepare_fx_fn(model, qconfig_dict) if mode == 'ddp': mp.spawn(run_ddp, @@ -1539,6 +2711,58 @@ def _test_model_impl( ' should match. Mode: ' + mode + ' diff:' + str(diff_from_eager[mode][name])) + def _test_building_block(self, quant_type, BB): + eager = BB().float() + graph = copy.deepcopy(eager) + + if quant_type == QuantType.STATIC: + qconfig = default_qconfig + eager_prepare = prepare + graph_prepare = prepare_fx + eager.eval() + graph.eval() + calibrate_or_train = test_only_eval_fn + data = self.img_data_2d + else: + assert quant_type == QuantType.QAT + qconfig = default_qat_qconfig + eager_prepare = prepare_qat + graph_prepare = prepare_qat_fx + eager.train() + graph.train() + calibrate_or_train = test_only_train_fn + data = self.img_data_2d_train + + if hasattr(eager, "fuse_model"): + eager.fuse_model() + eager = QuantWrapper(eager) + eager.qconfig = qconfig + eager = eager_prepare(eager) + + qconfig_dict = {"": qconfig} + graph = graph_prepare(graph, qconfig_dict) + + eager_out = eager(data[0][0]) + graph_out = graph(data[0][0]) + self.assertEqual(eager_out, graph_out) + + calibrate_or_train(eager, data) + calibrate_or_train(graph, data) + + eager = convert(eager) + graph = convert_fx(graph) + + eager_out = eager(data[0][0]) + graph_out = graph(data[0][0]) + self.assertEqual(eager_out, graph_out) + + @override_qengines + def test_resnet_base(self): + models = [ResNetBase] + options = itertools.product(self.static_quant_types, models) + for quant_type, M in options: + self._test_building_block(quant_type, M) + @skip_if_no_torchvision @skipIfNoFBGEMM @unittest.skip("skip for now since tbb failed") @@ -1556,15 +2780,11 @@ def get_available_classification_models(models): quantized_model_list = set(quantized_model_list) - no_pretrained_model # test eager and graph consistency model_list = quantized_model_list - # slice need to be fixed in symbolic tracing(https://github.com/pytorch/pytorch/issues/43511) - model_list = set(model_list) - {'googlenet', 'inception_v3'} - # getattr should not be used as node name(https://github.com/pytorch/pytorch/issues/43522) - model_list -= {'shufflenet_v2_x1_0', 'mobilenet_v2'} - + # inception_v3 is not symbolically traceable: https://github.com/pytorch/pytorch/issues/48813 + model_list = set(model_list) - {'inception_v3'} # mobilenet: dropout error RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'QUInt8' # incpetion_v3: looks like there is some problem with AuxLogits - quantized_not_working = [('qat', 'mobilenet_v2'), - ('qat', 'inception_v3'), + quantized_not_working = [('qat', 'inception_v3'), ('static', 'inception_v3')] fx_eager_not_matching = ['googlenet', # because _transform_input is not quantized in eager @@ -1606,7 +2826,6 @@ def print_diffs(diffs): @skip_if_no_torchvision @skip_if_not_multigpu @skipIfNoFBGEMM - @unittest.skip('TODO: not working yet due to https://github.com/pytorch/pytorch/issues/43513') def test_resnet18_ddp(self): from torchvision import models from torchvision.models import quantization as quantized_models diff --git a/test/quantization/test_quantize_jit.py b/test/quantization/test_quantize_jit.py index 6d94919eee1f3..90e5e411d6921 100644 --- a/test/quantization/test_quantize_jit.py +++ b/test/quantization/test_quantize_jit.py @@ -51,6 +51,7 @@ SkipQuantModel, NestedModel, ConvModel, + ConvTransposeModel, default_per_channel_qconfig, test_only_eval_fn, ConvBnModel, @@ -61,6 +62,7 @@ AnnotatedSkipQuantModel, AnnotatedNestedModel, AnnotatedConvModel, + AnnotatedConvTransposeModel, AnnotatedConvBnModel, ) @@ -72,12 +74,15 @@ from torch.jit._recursive import wrap_cpp_module # Standard library +from typing import List, Tuple +import io import itertools import unittest class TestQuantizeJitPasses(QuantizationTestCase): """ Test graph mode quantization passes used by quantize_jit """ + def test_foldbn_trivial(self): bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d} conv_module = {2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d} @@ -738,8 +743,7 @@ def forward(self, x): .run(m.graph) def test_insert_observers_propagate_observed_for_function(self): - def channel_shuffle(x, groups): - # type: (torch.Tensor, int) -> torch.Tensor + def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor: batchsize, num_channels, height, width = x.data.size() channels_per_group = num_channels // groups # reshape @@ -1122,8 +1126,7 @@ def __init__(self): self.conv = torch.nn.Conv2d(3, 3, 1).float() self.use_skip = True - def forward(self, x, cond): - # type: (Tensor, bool) -> Tensor + def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor: # to avoid being frozen self.use_skip = cond if self.use_skip: @@ -1223,8 +1226,7 @@ def __init__(self): super(ComplexModel, self).__init__() self.layers = torch.nn.ModuleList([SimpleLinearLayer() for i in range(2)]) - def forward(self, x): - # type: (torch.Tensor) -> List[torch.Tensor] + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: states = [] for layer in self.layers: val = layer(x) @@ -1320,8 +1322,7 @@ def forward(self, x, y): @torch.jit.interface class ModInterface(torch.nn.Module): - def forward(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: pass class TestModule(torch.nn.Module): @@ -1361,6 +1362,52 @@ def forward(self, x, y): FileCheck().check("quantized::embedding_bag_byte_rowwise_offsets") \ .run(m.graph) + @skipIfNoFBGEMM + def test_quantize_fork_wait(self): + """ Tests the case where fork and wait calls are in different subgraphs + Calling inline fork-wait only removes the fork call and leaves aten::wait + calls in the graph, with Tensor as input (instead of Future[Tensor]) + """ + class MainModule(nn.Module): + def __init__(self): + super(MainModule, self).__init__() + self.fork_ops = ForkModule() + + def init_values(self, x): + shared_module = self.fork_ops(x) + self.fork_dict = shared_module + + def forward(self, x): + val = torch.jit._wait(self.fork_ops(x)) + return val + + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, x): + w = torch.ones(5, 5) + b = torch.zeros(5) + return torch.nn.functional.linear(x, w, b) + + class ForkModule(nn.Module): + def __init__(self): + super(ForkModule, self).__init__() + self.test = TestModule() + + def forward(self, x): + fut = torch.jit._fork(self.test.forward, x) + return fut + + model = MainModule().eval() + traced = torch.jit.trace(model, (torch.randn(5, 5),)) + model = prepare_dynamic_jit(traced, {'' : default_qconfig}) + model = convert_dynamic_jit(model) + FileCheck().check("quantized::linear_dynamic") \ + .run(model.graph) + # Make sure model save works + b = io.BytesIO() + torch.jit.save(model, b) class TestQuantizeJitOps(QuantizationTestCase): """ Test graph mode post training static quantization works @@ -2378,8 +2425,7 @@ def __init__(self): self.conv1 = torch.nn.Conv2d(3, 3, 3).float() self.conv2 = torch.nn.Conv2d(3, 3, 3).float() - def forward(self, x): - # type: (Tensor) -> Tuple[Tensor, Tensor] + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: x1 = self.conv1(x) x2 = self.conv2(x) return x1, x2 @@ -2625,6 +2671,7 @@ def forward(self, x): num_quantize_per_tensor = 1 # for output for num_quant, num_op in num_op_by_num_quant.items(): num_quantize_per_tensor += num_op * num_quant + num_quantize_per_tensor -= 4 # constant propagation removes some prepacks FileCheck().check_count("aten::quantize_per_tensor(", num_quantize_per_tensor, exactly=True) \ .run(m1.graph) @@ -2658,6 +2705,28 @@ def test_conv_with_benchmark_flag(self): FileCheck().check("quantized::conv2d") \ .run(converted_model.graph) + @skipIfNoFBGEMM + def test_cat_linear(self): + class LinearModel(torch.nn.Module): + def __init__(self): + super(LinearModel, self).__init__() + self.weight = torch.randn(5, 5) + + def forward(self, x, y): + a = torch.cat([x, y]) + b = F.linear(a, self.weight) + c = F.linear(b, self.weight) + return b, c + + model = LinearModel().eval() + qconfig = {'' : default_qconfig} + float_model = torch.jit.script(model) + prepared_model = prepare_jit(float_model, qconfig) + prepared_model(torch.rand(5, 5), torch.rand(5, 5)) + converted_model = convert_jit(prepared_model) + FileCheck().check("quantized::linear") \ + .check("quantized::linear") \ + .run(converted_model.graph) class TestQuantizeDynamicJitPasses(QuantizationTestCase): def test_prepare_dynamic(self): @@ -2683,11 +2752,11 @@ def forward(self, x): else: # for input of FC for dynamic quant assert len(attrs_with_prefix(m, '_observer_')) == 1 - observer_name = 'DynamicQuantObserver = prim::GetAttr[name="_observer_' + observer_name = 'Observer = prim::GetAttr[name="_observer_' FileCheck().check(observer_name) \ .check('prim::GetAttr[name="fc"]') \ .check('prim::CallMethod') \ - .check_not('Observer = prim::GetAttr[name="_observer_') \ + .check_not(observer_name) \ .run(m.graph) @@ -2723,7 +2792,7 @@ def forward(self, x): assert len(attrs_with_prefix(m.sub.fc, '_observer_')) == 1 FileCheck().check('prim::GetAttr[name="sub') \ .check('prim::CallMethod') \ - .check('DynamicQuantObserver = prim::GetAttr[name="_observer_') \ + .check('Observer = prim::GetAttr[name="_observer_') \ .check('prim::CallMethod') \ .check_not('Observer = prim::GetAttr[name="_observer_') \ .run(m.graph) @@ -2846,8 +2915,7 @@ def __init__(self): super(Res, self).__init__() self.weight = torch.nn.Parameter(torch.ones(5, 5)) - def forward(self, x, cond): - # type: (Tensor, bool) -> Tensor + def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor: if cond: return torch.nn.functional.linear(x, self.weight) else: @@ -2948,8 +3016,7 @@ def forward(self, x): m = torch.jit.script(M()) m = quantize_dynamic_jit(m, {'': float16_dynamic_qconfig}) - FileCheck().check("quantized::linear_prepack_fp16") \ - .check_next("quantized::linear_dynamic_fp16") \ + FileCheck().check("quantized::linear_dynamic_fp16") \ .check_not("aten::linear") \ .check_not("aten::dequantize") \ .check_not("aten::quantize") \ @@ -2983,6 +3050,7 @@ def forward(self, x): FunctionalLinear(weight, bias), x, "quantized::linear_dynamic", tracing=tracing, dynamic=True) + @skipIfNoFBGEMM def test_embedding_bag(self): class M(torch.nn.Module): def __init__(self, weights): @@ -2990,14 +3058,14 @@ def __init__(self, weights): self.embedding1 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, - sparse=False, + sparse=True, _weight=weights, mode='sum') self.embedding2 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, - sparse=False, + sparse=True, _weight=weights, mode='sum') @@ -3026,8 +3094,9 @@ def forward(self, indices1, offsets1, indices2, offsets2): m = prepare_jit(m, {'embedding1' : int4_qconfig, 'embedding2' : int8_qconfig}) m = convert_jit(m) FileCheck().check("quantized::embedding_bag_4bit_rowwise_offsets") \ - .check_next("quantized::embedding_bag_byte_rowwise_offsets") \ + .check("quantized::embedding_bag_byte_rowwise_offsets") \ .run(m.graph) + m(*dummy_inputs) @@ -3044,7 +3113,7 @@ def test_single_linear(self): # compare the result of the two quantized models later linear_model.fc1.weight = torch.nn.Parameter(annotated_linear_model.fc1.module.weight.detach()) linear_model.fc1.bias = torch.nn.Parameter(annotated_linear_model.fc1.module.bias.detach()) - model_eager = quantize(annotated_linear_model, test_only_eval_fn, self.calib_data) + model_eager = quantize(annotated_linear_model, test_only_eval_fn, [self.calib_data]) qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)} model_traced = torch.jit.trace(linear_model, self.calib_data[0][0]) @@ -3084,7 +3153,7 @@ def test_observer_with_ignored_function(self): linear_model.fc1.weight = torch.nn.Parameter(annotated_linear_model.fc1.module.weight.detach()) linear_model.fc1.bias = torch.nn.Parameter(annotated_linear_model.fc1.module.bias.detach()) model_eager = quantize(annotated_linear_model, test_only_eval_fn, - self.calib_data) + [self.calib_data]) qconfig_dict = {'': qconfig} model_traced = torch.jit.trace(linear_model, self.calib_data[0][0]) @@ -3110,7 +3179,7 @@ def test_conv(self): # copy the weight from eager mode so that we can # compare the result of the two quantized models later conv_model.conv.weight = torch.nn.Parameter(annotated_conv_model.conv.weight.detach()) - model_eager = quantize(annotated_conv_model, test_only_eval_fn, self.img_data_2d) + model_eager = quantize(annotated_conv_model, test_only_eval_fn, [self.img_data_2d]) qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)} model_traced = torch.jit.trace(conv_model, self.img_data_2d[0][0]) model_script = torch.jit.script(conv_model) @@ -3124,6 +3193,35 @@ def test_conv(self): inplace=False) self.assertEqual(model_quantized(self.img_data_2d[0][0]), result_eager) + @override_qengines + def test_conv_transpose(self): + r"""Compare the result of quantizing conv_transpose layer in + eager mode and graph mode + """ + if not qengine_is_qnnpack(): + return # Currently only qnnpack is supported + # eager mode + annotated_conv_model = AnnotatedConvTransposeModel( + torch.backends.quantized.engine).eval() + conv_model = ConvTransposeModel().eval() + # copy the weight from eager mode so that we can + # compare the result of the two quantized models later + conv_model.conv.weight = torch.nn.Parameter(annotated_conv_model.conv.weight.detach()) + model_eager = quantize(annotated_conv_model, test_only_eval_fn, [self.img_data_2d]) + qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)} + model_traced = torch.jit.trace(conv_model, self.img_data_2d[0][0]) + model_script = torch.jit.script(conv_model) + result_eager = model_eager(self.img_data_2d[0][0]) + for model_under_test in [model_traced, model_script]: + model_quantized = quantize_jit( + model_under_test, + qconfig_dict, + test_only_eval_fn, + [self.img_data_2d], + inplace=False) + self.assertEqual(model_quantized(self.img_data_2d[0][0]), + result_eager) + @override_qengines def test_conv_bn(self): r"""Compare the result of quantizing conv + bn layer in @@ -3137,7 +3235,7 @@ def test_conv_bn(self): conv_model_to_script.conv.weight = torch.nn.Parameter(conv_model.conv.weight.detach()) fuse_modules(conv_model, ['conv', 'bn'], inplace=True) model_eager = quantize(conv_model, test_only_eval_fn, - self.img_data_2d) + [self.img_data_2d]) qconfig_dict = { '': default_qconfig } @@ -3168,7 +3266,7 @@ def test_nested(self): script_model.fc3.weight = torch.nn.Parameter(eager_model.fc3.module.weight.detach()) script_model.fc3.bias = torch.nn.Parameter(eager_model.fc3.module.bias.detach()) - model_eager = quantize(eager_model, test_only_eval_fn, self.calib_data) + model_eager = quantize(eager_model, test_only_eval_fn, [self.calib_data]) qconfig_dict = { 'sub2.fc1': default_per_channel_qconfig if qengine_is_fbgemm() else default_qconfig, 'fc3': default_qconfig @@ -3204,7 +3302,7 @@ def test_skip_quant(self): eager_model.fuse_modules() - model_eager = quantize(eager_model, test_only_eval_fn, self.calib_data) + model_eager = quantize(eager_model, test_only_eval_fn, [self.calib_data]) qconfig_dict = { '': get_default_qconfig(torch.backends.quantized.engine), 'fc': None diff --git a/test/quantization/test_quantized_functional.py b/test/quantization/test_quantized_functional.py index 548b0677efe03..59242493d869d 100644 --- a/test/quantization/test_quantized_functional.py +++ b/test/quantization/test_quantized_functional.py @@ -26,7 +26,7 @@ def test_relu_api(self): zero_point = 1 qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, dtype=torch.quint8) qY = torch.relu(qX) - qY_hat = qF.relu(qX) + qY_hat = F.relu(qX) self.assertEqual(qY, qY_hat) def _test_conv_api_impl( diff --git a/test/quantization/test_quantized_module.py b/test/quantization/test_quantized_module.py index fe0d52a9356bb..28867e8260b6d 100644 --- a/test/quantization/test_quantized_module.py +++ b/test/quantization/test_quantized_module.py @@ -2,12 +2,16 @@ import torch.nn as nn import torch.nn.intrinsic as nni import torch.nn.intrinsic.quantized as nnq_fused +import torch.nn.intrinsic.quantized._reference as nniqr import torch.nn.quantized as nnq +import torch.nn.quantized._reference as nnqr import torch.nn.quantized.dynamic as nnqd +import torch.nn.functional as F import torch.quantization from torch.quantization import ( - default_float_qparams_observer + default_float_qparams_observer, + PerChannelMinMaxObserver ) from torch.testing._internal.common_quantization import ( QuantizationTestCase, @@ -28,17 +32,18 @@ import io import numpy as np +import itertools -''' +""" Note that tests in this file are just API test, to make sure we wrapped the quantized operator implementations correctly in the user facing APIs, these are not correctness test for the underlying quantized operators. For correctness -test please see `caffe2/test/test_quantized_op.py`. -''' +test please see `test/quantization/test_quantized_op.py`. +""" class TestStaticQuantizedModule(QuantizationTestCase): def test_relu(self): - relu_module = nnq.ReLU() + relu_module = nn.ReLU() relu6_module = nnq.ReLU6() x = torch.arange(-10, 10, dtype=torch.float) @@ -54,20 +59,35 @@ def test_relu(self): self.assertEqual(y6_ref, qy6.dequantize(), msg="ReLU6 module API failed") - - @given( - batch_size=st.integers(1, 5), - in_features=st.integers(16, 32), - out_features=st.integers(4, 8), - use_bias=st.booleans(), - use_fused=st.booleans(), - per_channel=st.booleans() - ) @override_qengines - def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_fused, per_channel): + def test_linear_api(self): """test API functionality for nn.quantized.linear and nn.intrinsic.quantized.linear_relu""" + options = itertools.product( + [1, 5], + [16, 32], + [4, 8], + [True, False], + [True, False], + [True, False], + [True, False]) + for (batch_size, in_features, out_features, use_bias, + use_fused, per_channel, reference) in options: + self._test_linear_api_impl( + batch_size, in_features, out_features, use_bias, use_fused, + per_channel, reference) + + def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, use_fused, per_channel, reference): if torch.backends.quantized.engine == 'qnnpack': per_channel = False + + # (use_fused, reference) -> quantized class + class_map = { + (True, True) : nniqr.LinearReLU, + (True, False) : nnq_fused.LinearReLU, + (False, True) : nnqr.Linear, + (False, False) : nnq.Linear, + } + W = torch.rand(out_features, in_features).float() if per_channel: scale_tensor = torch.ones(out_features, dtype=torch.double) @@ -85,11 +105,11 @@ def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_f B = torch.rand(out_features).float() if use_bias else None scale = 0.5 zero_point = 3 - if use_fused: - qlinear = nnq_fused.LinearReLU(in_features, out_features) - else: - qlinear = nnq.Linear(in_features, out_features) + qlinear = class_map[(use_fused, reference)](in_features, out_features) + qlinear_copy = qlinear # deepcopy does not work right now + # qlinear_copy = copy.deepcopy(qlinear) + self.checkScriptable(qlinear_copy, [[X_q]], check_save_load=True) # Run module with default-initialized parameters. # This tests that the constructor is correct. qlinear(X_q) @@ -97,22 +117,33 @@ def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_f qlinear.set_weight_bias(W_q, B) # Simple round-trip test to ensure weight()/set_weight() API self.assertEqual(qlinear.weight(), W_q, atol=1e-5, rtol=0) - W_pack = qlinear._packed_params._packed_params + # testing packed param implementation qlinear.scale = float(scale) qlinear.zero_point = int(zero_point) Z_q = qlinear(X_q) + # Check if the module implementation matches calling the # ops directly - if use_fused: - Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point) - - self.assertTrue('QuantizedLinearReLU' in str(qlinear)) + if reference: + weight = qlinear._qweight + bias = qlinear._bias + weight_dequant = weight.dequantize() + X_q_dq = X_q.dequantize() + Z_ref = F.linear(X_q_dq, weight_dequant, bias) + if use_fused: + Z_ref = F.relu(Z_ref, inplace=True) + Z_ref = torch.quantize_per_tensor(Z_ref, scale, zero_point, torch.quint8) else: - Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point) + W_pack = qlinear._packed_params._packed_params + if use_fused: + Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point) + else: + Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point) - self.assertTrue('QuantizedLinear' in str(qlinear)) self.assertEqual(Z_ref, Z_q) + self.assertTrue( + ("QuantizedLinearReLU" if use_fused else "QuantizedLinear") in str(qlinear)) # Test serialization of quantized Linear Module using state_dict model_dict = qlinear.state_dict() @@ -129,26 +160,26 @@ def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_f self.assertEqual(b_model, b_loaded) else: self.assertEqual(model_dict[key], loaded_dict[key]) - if use_fused: - loaded_qlinear = nnq_fused.LinearReLU(in_features, out_features) - else: - loaded_qlinear = nnq.Linear(in_features, out_features) - loaded_qlinear.load_state_dict(loaded_dict) - linear_unpack = torch.ops.quantized.linear_unpack - self.assertEqual(linear_unpack(qlinear._packed_params._packed_params), - linear_unpack(loaded_qlinear._packed_params._packed_params)) - if use_bias: - self.assertEqual(qlinear.bias(), loaded_qlinear.bias()) + loaded_qlinear = class_map[(use_fused, reference)]( + in_features, out_features) + loaded_qlinear.load_state_dict(loaded_dict) + if reference: + self.assertEqual(qlinear._qweight, loaded_qlinear._qweight) + self.assertEqual(qlinear._bias, loaded_qlinear._bias) + else: + linear_unpack = torch.ops.quantized.linear_unpack + self.assertEqual(linear_unpack(qlinear._packed_params._packed_params), + linear_unpack(loaded_qlinear._packed_params._packed_params)) self.assertEqual(qlinear.scale, loaded_qlinear.scale) self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point) + # make sure loaded_qlinear has the same dir as qlinear since + # scripting the module will add __overloads__ to __dict__ + self.checkScriptable(loaded_qlinear, [[X_q]], check_save_load=True) self.assertTrue(dir(qlinear) == dir(loaded_qlinear)) - self.assertTrue(hasattr(qlinear, '_packed_params')) - self.assertTrue(hasattr(loaded_qlinear, '_packed_params')) - self.assertTrue(hasattr(qlinear, '_weight_bias')) - self.assertTrue(hasattr(loaded_qlinear, '_weight_bias')) self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias()) - self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params)) + if not reference: + self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params)) Z_q2 = loaded_qlinear(X_q) self.assertEqual(Z_q, Z_q2) @@ -163,20 +194,24 @@ def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_f # Test JIT self.checkScriptable(qlinear, [[X_q]], check_save_load=True) - # Test from_float. - float_linear = torch.nn.Linear(in_features, out_features).float() - float_linear.qconfig = torch.quantization.default_qconfig - torch.quantization.prepare(float_linear, inplace=True) - float_linear(X.float()) - # Sequential allows swapping using "convert". - quantized_float_linear = torch.nn.Sequential(float_linear) - quantized_float_linear = torch.quantization.convert(quantized_float_linear, inplace=True) + # Make sure `from_float` works for all linear variants + modules_under_test = [torch.nn.Linear, torch.nn.modules.linear._LinearWithBias] - # Smoke test to make sure the module actually runs - quantized_float_linear(X_q) + for mut in modules_under_test: + # Test from_float. + float_linear = mut(in_features, out_features).float() + float_linear.qconfig = torch.quantization.default_qconfig + torch.quantization.prepare(float_linear, inplace=True) + float_linear(X.float()) + # Sequential allows swapping using "convert". + quantized_float_linear = torch.nn.Sequential(float_linear) + quantized_float_linear = torch.quantization.convert(quantized_float_linear, inplace=True) - # Smoke test extra_repr - self.assertTrue('QuantizedLinear' in str(quantized_float_linear)) + # Smoke test to make sure the module actually runs + quantized_float_linear(X_q) + + # Smoke test extra_repr + self.assertTrue('QuantizedLinear' in str(quantized_float_linear)) def test_quant_dequant_api(self): r = torch.tensor([[1., -1.], [1., -1.]], dtype=torch.float) @@ -303,10 +338,11 @@ def _test_conv_api_impl( check_save_load=True) # Test from_float - conv_module.qconfig = torch.quantization.default_qconfig - torch.quantization.prepare(conv_module, inplace=True) - conv_module(X.float()) - converted_qconv_module = torch.nn.Sequential(conv_module) + fused_conv_module = torch.nn.intrinsic._FusedModule(conv_module) + fused_conv_module.qconfig = torch.quantization.default_qconfig + torch.quantization.prepare(fused_conv_module, inplace=True) + fused_conv_module(X.float()) + converted_qconv_module = fused_conv_module torch.quantization.convert(converted_qconv_module, inplace=True) # Smoke test to make sure the module actually runs @@ -678,7 +714,7 @@ def test_instance_norm(self): msg="InstanceNorm module API failed, qY_ref\n{} vs qY\n{}" .format(qY_ref, qY)) - def test_elu(self): + def _test_activation_module_impl(self, name, float_module_class, quantized_module_class, extra_kwargs): """Tests the correctness of the ELU module. The correctness is defined against the functional implementation. """ @@ -694,24 +730,52 @@ def test_elu(self): qX = torch.quantize_per_tensor(X, x_scale, x_zero_point, dtype=torch.quint8) dqX = qX.dequantize() - float_mod = torch.nn.ELU(alpha).float() + float_mod = float_module_class(**extra_kwargs).float() dqY_ref = float_mod(dqX) qY_ref = torch.quantize_per_tensor( dqY_ref, y_scale, y_zero_point, dtype=torch.quint8) - quant_mod = nnq.ELU(y_scale, y_zero_point, alpha) + quant_mod = quantized_module_class(y_scale, y_zero_point, **extra_kwargs) qY = quant_mod(qX) - self.assertEqual(qY_ref.int_repr().numpy(), qY.int_repr().numpy(), - msg="ELU module API failed, qY_ref\n{} vs qY\n{}" - .format(qY_ref, qY)) + msg="{} module API failed, qY_ref\n{} vs qY\n{}" + .format(name, qY_ref, qY)) + + def _test_leaky_relu_serialization(self): + scale_original = 10.0 / 256 + zero_point_original = 1.0 + + quant_mod_original = nnq.LeakyReLU(scale_original, zero_point_original) + state_dict = quant_mod_original.state_dict() + + scale_new = 5.0 / 256 + zero_point_new = 2.0 + quant_mod_new = nnq.LeakyReLU(scale_new, zero_point_new) + quant_mod_new.load_state_dict(state_dict) + + self.assertEqual(quant_mod_original.scale, quant_mod_new.scale) + self.assertEqual(quant_mod_original.zero_point, quant_mod_new.zero_point) + + def test_elu(self): + """Tests the correctness of the ELU module. + The correctness is defined against the functional implementation. + """ + self._test_activation_module_impl("ELU", nn.ELU, nnq.ELU, {"alpha": 1.5}) + + def test_leaky_relu(self): + self._test_activation_module_impl("LeakyReLU", nn.LeakyReLU, nnq.LeakyReLU, {"negative_slope": 0.2}) + self._test_leaky_relu_serialization() + + def test_sigmoid(self): + self._test_activation_module_impl("Sigmoid", nn.Sigmoid, nnq.Sigmoid, {}) @given( num_embeddings=st.integers(10, 50), embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0), set_qconfig=st.booleans(), ) + @skipIfNoFBGEMM def test_embedding_api(self, num_embeddings, embedding_dim, set_qconfig): num_lengths = np.random.randint(1, 6) lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32) @@ -734,8 +798,8 @@ def test_embedding_api(self, num_embeddings, embedding_dim, set_qconfig): w_packed = qemb._packed_params._packed_weight module_out = qemb(indices) - # Call the qembedding_bag operator directly - ref = torch.ops.quantized.embedding_byte(w_packed, indices, sparse=False) + # Call the qembedding operator directly + ref = torch.ops.quantized.embedding_byte(w_packed, indices, pruned_weights=False) self.assertEqual(module_out, ref) self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, None, set_qconfig=False, is_emb_bag=False) @@ -750,6 +814,7 @@ def test_embedding_api(self, num_embeddings, embedding_dim, set_qconfig): def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set_qconfig): r"""Test execution and serialization for dynamic quantized embedding_bag modules on int8 """ + num_lengths = np.random.randint(1, 6) lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32) num_indices = np.sum(lengths) @@ -760,28 +825,36 @@ def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set offsets = torch.cat((offsets, torch.tensor([indices.size(0)], dtype=torch.long)), 0) weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32)) - obs = default_float_qparams_observer() - obs(weights) - # Get the scale and zero point for the weight tensor - qparams = obs.calculate_qparams() - # Quantize the weights to 8bits - qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8) - qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim, - include_last_offset=True, mode='sum', _weight=qweight) - qemb(indices, offsets) - - # Ensure the module has the correct weights - self.assertEqual(qweight, qemb.weight()) - - w_packed = qemb._packed_params._packed_weight - module_out = qemb(indices, offsets) + for qdtype in [torch.quint8, torch.quint4x2]: + obs = PerChannelMinMaxObserver(dtype=qdtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) + obs(weights) + # Get the scale and zero point for the weight tensor + qparams = obs.calculate_qparams() + # Quantize the weights to 8bits + qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=qdtype) + qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim, + include_last_offset=True, mode='sum', _weight=qweight, dtype=qdtype) + qemb(indices, offsets) + + # Ensure the module has the correct weights + self.assertEqual(qweight, qemb.weight()) + + w_packed = qemb._packed_params._packed_weight + module_out = qemb(indices, offsets) + + # Call the qembedding_bag operator directly + if qdtype == torch.quint8: + ref = torch.ops.quantized.embedding_bag_byte(w_packed, indices, offsets, mode=0, + per_sample_weights=None, + include_last_offset=True) + else: + ref = torch.ops.quantized.embedding_bag_4bit(w_packed, indices, offsets, mode=0, + per_sample_weights=None, + include_last_offset=True) - # Call the qembedding_bag operator directly - ref = torch.ops.quantized.embedding_bag_byte(w_packed, indices, offsets, mode=0, - per_sample_weights=None, - include_last_offset=True) - self.assertEqual(module_out, ref) - self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, offsets, set_qconfig, is_emb_bag=True) + self.assertEqual(module_out, ref) + self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, + offsets, set_qconfig, is_emb_bag=True, dtype=qdtype) class TestDynamicQuantizedModule(QuantizationTestCase): @given( @@ -859,16 +932,18 @@ def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_d # Test JIT self.checkScriptable(qlinear, [[X]], check_save_load=True) - # Test from_float - float_linear = torch.nn.Linear(in_features, out_features).float() - if use_default_observer: - float_linear.qconfig = torch.quantization.default_dynamic_qconfig - prepare_dynamic(float_linear) - float_linear(X.float()) - quantized_float_linear = nnqd.Linear.from_float(float_linear) + modules_under_test = [torch.nn.Linear, torch.nn.modules.linear._LinearWithBias] + for mut in modules_under_test: + # Test from_float + float_linear = mut(in_features, out_features).float() + if use_default_observer: + float_linear.qconfig = torch.quantization.default_dynamic_qconfig + prepare_dynamic(float_linear) + float_linear(X.float()) + quantized_float_linear = nnqd.Linear.from_float(float_linear) - # Smoke test to make sure the module actually runs - quantized_float_linear(X) + # Smoke test to make sure the module actually runs + quantized_float_linear(X) # Smoke test extra_repr self.assertTrue('QuantizedLinear' in str(quantized_float_linear)) @@ -947,6 +1022,55 @@ def test_lstm_api(self, dtype, bidirectional): self.check_eager_serialization(cell_dq, ref_dq, [x]) self.check_weight_bias_api(cell_dq, weight_keys, bias_keys) + @override_qengines + def test_gru_api(self): + r"""Test execution and serialization for dynamic quantized lstm modules on int8 and fp16 + """ + # Check that module matches the numerics of the op and ensure that module can be + # instantiated for all engines and dtypes + + for dtype in [torch.qint8, torch.float16]: + if dtype == torch.float16 and torch.backends.quantized.engine == "qnnpack": + # fp16 dynamic quant is not supported for qnnpack + continue + # Test default instantiation + seq_len = 4 + batch = 2 + input_size = 3 + hidden_size = 7 + num_layers = 2 + bias = True + bidirectional = False + + x = torch.rand(seq_len, batch, input_size) + h = torch.rand(num_layers * (bidirectional + 1), batch, hidden_size) + + + cell_dq = torch.nn.quantized.dynamic.GRU(input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + batch_first=False, + dropout=0.0, + bidirectional=bidirectional, + dtype=dtype) + + _all_params = ([m.param for m in cell_dq._all_weight_values]) + result = torch.quantized_gru(x, + h, + _all_params, + cell_dq.bias, + cell_dq.num_layers, + float(cell_dq.dropout), + False, + bidirectional, + False) + + + y, h = cell_dq(x, h) + self.assertEqual(result[0], y, msg="GRU module API failed") + self.assertEqual(result[1], h, msg="GRU module API failed") + @given( dtype=st.sampled_from([torch.qint8, torch.float16]), ) diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py index 9412332c238bb..a192eddca234a 100644 --- a/test/quantization/test_quantized_op.py +++ b/test/quantization/test_quantized_op.py @@ -1,9 +1,11 @@ from builtins import round +import copy import itertools import numpy as np import sys import unittest +import operator import torch from torch import _VF @@ -18,10 +20,10 @@ hu.assert_deadline_disabled() from torch.testing._internal.common_utils import TestCase -from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN +from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS from torch.testing._internal.common_quantization import skipIfNoFBGEMM from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \ - override_quantized_engine, supported_qengines, override_qengines + override_quantized_engine, supported_qengines, override_qengines, _snr from torch.testing._internal.common_quantized import qengine_is_qnnpack from torch.quantization import PerChannelMinMaxObserver @@ -89,9 +91,9 @@ def pool_output_shape(input_size, kernel_size, padding, stride, output_size = ( (input_size + 2 * padding - dilation * (kernel_size - 1) - 1 + (stride - 1 if ceiling_mode else 0)) // stride + 1) - if (padding > 0 and + if (ceiling_mode and ((output_size - 1) * stride >= input_size + padding)): - output_size += 1 + output_size -= 1 return output_size """ @@ -140,17 +142,17 @@ def _test_activation_function(self, X, fn_name, test_configs): quantized_fn: a list of the quantized functions to be tested reference_fn: the original reference function to be called on the the dequantized X - inplace_kwarg: the additional inplace keyword argument to test in-place + extra_kwargs: the additional keyword arguments for each test entry in ops_under_test, it must have at least the fields - for quantized_fn and reference_fn. If inplace_kwarg is missing, the - quantized function is assumed to be either inplace by default or the - test is not testing an inplace function. + for quantized_fn and reference_fn. output_range: the output range the operator will map to. By default, if it is no specified, the range will not be controlled and depend on Xmin and Xmax. change_zero_point: a boolean flag indicating if the zero point parameter should be determined based on torch_type during quantization (see sigmoid/hardsigmoid for examples). By default, if it is not specified, change_zero_point is assumed to be False and zero point will just take on the default value from X. + `output_is_observed`: if specified and is True, we'll append extra + output_scale/output_zero_point keyword argument when calling quantized op """ # Retrives the default parameters from X. X, (scale, zero_point, torch_type) = X @@ -162,44 +164,57 @@ def _test_activation_function(self, X, fn_name, test_configs): for op_group in test_configs: ref_op = op_group['reference_fn'] for q_op in op_group['quantized_fn']: - # Quantizes and dequantizes to account for max error. - qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, - dtype=torch_type) - dqX = qX.dequantize() - dqY_hat = ref_op(dqX.clone()) - - # Retrieves the inplace keyword arguments - # some functions require inplace=True to test in-place. - inplace_kwarg = op_group.get('inplace_kwarg', dict()) - - # Adjusts output_scale if needed. - # The output_scale determines the quantization scale for functions that - # have a constrained output range. e.x. sigmoid ranges from 0 to 1. - output_scale = scale - if 'output_range' in op_group: - (f_min, f_max) = op_group['output_range'] - output_scale = (f_max - f_min) / (q_max - q_min + 1.0) - - # Adjusts output_zero_point if needed (see explanation for the - # change_zero_point parameter above). - # output_zero_point determines the additional offset that will be - # added to a scaled value during quantization. - if op_group.get('change_zero_point', False): - output_zero_point = 0 if torch_type == torch.qint32 else q_min - else: - output_zero_point = zero_point - # Quantizes the dequantized version of Y_hat. - qY_hat = torch.quantize_per_tensor(dqY_hat, scale=output_scale, - zero_point=output_zero_point, + for memory_format in (torch.channels_last, torch.contiguous_format): + if memory_format == torch.channels_last and len(X.shape) != 4: + continue + X = X.to(memory_format=memory_format) + + # Retrieves the inplace keyword arguments + # some functions require inplace=True to test in-place. + # copy.copy is needed because these are modified in place + extra_kwargs = \ + copy.copy(op_group.get('extra_kwargs', dict())) + output_is_observed = \ + copy.copy(op_group.get('output_is_observed', False)) + + # Quantizes and dequantizes to account for max error. + qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, dtype=torch_type) + dqX = qX.dequantize() + dqY_hat = ref_op(dqX.clone(), **extra_kwargs) + + # Adjusts output_scale if needed. + # The output_scale determines the quantization scale for functions that + # have a constrained output range. e.x. sigmoid ranges from 0 to 1. + output_scale = scale + if 'output_range' in op_group: + (f_min, f_max) = op_group['output_range'] + output_scale = (f_max - f_min) / (q_max - q_min + 1.0) + + # Adjusts output_zero_point if needed (see explanation for the + # change_zero_point parameter above). + # output_zero_point determines the additional offset that will be + # added to a scaled value during quantization. + if op_group.get('change_zero_point', False): + output_zero_point = 0 if torch_type == torch.qint32 else q_min + else: + output_zero_point = zero_point + + # Quantizes the dequantized version of Y_hat. + qY_hat = torch.quantize_per_tensor(dqY_hat, scale=output_scale, + zero_point=output_zero_point, + dtype=torch_type) - # Finds qY using in-place or non-in-place quantized operators. - qY = q_op(qX, **inplace_kwarg) + if output_is_observed: + extra_kwargs.update({'output_scale': output_scale, 'output_zero_point': output_zero_point}) - self.assertEqual(qY, qY_hat, msg='{} - {} failed: ({} vs. {})'.format( - fn_name, q_op, qY, qY_hat - )) + # Finds qY using in-place or non-in-place quantized operators. + qY = q_op(qX, **extra_kwargs) + + self.assertEqual(qY, qY_hat, msg='{} - {} failed: ({} vs. {})'.format( + fn_name, q_op, qY, qY_hat + )) """Tests the correctness of the quantized::relu op.""" @override_qengines @@ -212,17 +227,17 @@ def test_qrelu(self, X): torch.relu, torch.relu_, torch.nn.functional.relu, - torch.nn.quantized.functional.relu, + torch.nn.functional.relu, ], 'reference_fn': torch.nn.functional.relu }, { 'quantized_fn': [ torch.nn.functional.relu, - torch.nn.quantized.functional.relu, + torch.nn.functional.relu, ], 'reference_fn': torch.nn.functional.relu, - 'inplace_kwarg': { + 'extra_kwargs': { 'inplace': True } } @@ -250,7 +265,7 @@ def test_qrelu6(self, X): @override_qengines @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), qparams=hu.qparams())) - def test_qsigmoid(self, X): + def test_sigmoid_non_observed(self, X): sigmoid_test_configs = [ { 'quantized_fn': [ @@ -263,11 +278,29 @@ def test_qsigmoid(self, X): ] self._test_activation_function(X, 'sigmoid', sigmoid_test_configs) - """Tests the correctness of the quantized::hardsigmoid op.""" - @override_qengines + """Tests the correctness of the quantized::sigmoid op.""" + # TODO: enable after observed output is supported in qnnpack + # @override_qengines + @skipIfNoFBGEMM @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), qparams=hu.qparams())) - def test_qhardsigmoid(self, X): + def test_sigmoid(self, X): + sigmoid_test_configs = [ + { + 'quantized_fn': [ + torch.ops.quantized.sigmoid + ], + 'reference_fn': torch.sigmoid, + 'output_range': (0.0, 1.0), + 'change_zero_point': True, + 'output_is_observed': True, + } + ] + self._test_activation_function(X, 'sigmoid', sigmoid_test_configs) + + """Tests the correctness of the quantized::hardsigmoid op.""" + @override_qengines + def test_qhardsigmoid(self): hardsigmoid_test_configs = [ { 'quantized_fn': [ @@ -278,28 +311,57 @@ def test_qhardsigmoid(self, X): 'change_zero_point': True } ] - self._test_activation_function(X, 'hardsigmoid', hardsigmoid_test_configs) + shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4)) + dtypes = (torch.quint8, torch.qint8) + test_cases = itertools.product(shapes, dtypes) + for shape, dtype in test_cases: + X = (np.random.rand(*shape).astype(np.float32), (1.0, 0, dtype)) + self._test_activation_function(X, 'hardsigmoid', hardsigmoid_test_configs) - """Tests the correctness of the quantized::relu op.""" + @override_qengines @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), - qparams=hu.qparams()), - alpha=st.floats(0.0, 1.0, allow_nan=False, allow_infinity=False)) - def test_qrelu_leaky(self, X, alpha): - X, (scale, zero_point, torch_type) = X + qparams=hu.qparams())) + def test_leaky_relu_observed_output(self, X): + leaky_relu_test_configs = [ + { + 'quantized_fn': [ + torch.ops.quantized.leaky_relu + ], + 'reference_fn': torch.nn.functional.leaky_relu, + 'extra_kwargs': { + 'negative_slope': 0.1, + 'inplace': False, + }, + 'output_is_observed': True, + } + ] + self._test_activation_function(X, 'leaky_relu', leaky_relu_test_configs) - X = torch.from_numpy(X) - qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, - dtype=torch_type) - dqX = qX.dequantize() + """Tests the correctness of the quantized::relu op.""" + def test_leaky_relu(self): + shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4)) + dtypes = (torch.quint8, torch.qint8) + memory_formats = (torch.channels_last, torch.contiguous_format) + test_cases = itertools.product(shapes, dtypes, memory_formats) + for shape, dtype, memory_format in test_cases: + if memory_format == torch.channels_last and len(shape) != 4: + continue + X, scale, zero_point, torch_type, alpha = \ + torch.randn(*shape), 0.1, 0, dtype, 0.01 + X = X.to(memory_format=memory_format) - # torch.nn.functional - op = torch.nn.functional.leaky_relu - dqY = op(dqX, negative_slope=alpha) - qY = torch.quantize_per_tensor(dqY, scale=scale, zero_point=zero_point, - dtype=torch_type) - qY_hat = op(qX, negative_slope=alpha) - self.assertEqual(qY.dequantize(), qY_hat.dequantize(), - msg="F.leaky_relu failed ({} vs {})".format(qY, qY_hat)) + qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, + dtype=torch_type) + dqX = qX.dequantize() + + # torch.nn.functional + op = torch.nn.functional.leaky_relu + dqY = op(dqX, negative_slope=alpha) + qY = torch.quantize_per_tensor(dqY, scale=scale, zero_point=zero_point, + dtype=torch_type) + qY_hat = op(qX, negative_slope=alpha) + self.assertEqual(qY.dequantize(), qY_hat.dequantize(), + msg="F.leaky_relu failed ({} vs {})".format(qY, qY_hat)) """Tests the correctness of the quantized::elu op.""" @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), @@ -513,22 +575,37 @@ def test_qclamp(self, X, min_val, max_val): X, (scale, zero_point, torch_type) = X assume(min_val <= max_val) - Y = X.copy() - Y[Y < min_val] = min_val - Y[Y > max_val] = max_val - qY = torch.quantize_per_tensor(torch.from_numpy(Y), scale=scale, - zero_point=zero_point, dtype=torch_type) + Y_clamp = torch.clamp(torch.from_numpy(X), min=min_val, max=max_val) + qY_clamp = torch.quantize_per_tensor(Y_clamp, scale=scale, + zero_point=zero_point, dtype=torch_type) + X = torch.from_numpy(X) qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, dtype=torch_type) - ops_under_test = { 'ops.quantized': torch.ops.quantized.clamp, } for name, op in ops_under_test.items(): - qY_hat = op(qX, min_val, max_val) - self.assertEqual(qY, qY_hat, msg="{} qclamp failed".format(name)) + qY_clamp_hat = op(qX, min=min_val, max=max_val) + self.assertEqual(qY_clamp, qY_clamp_hat, msg="{} qclamp failed".format(name)) + + if torch.backends.quantized.engine == 'fbgemm': + with override_quantized_engine('fbgemm'): + Y_min_clamp = torch.clamp(X, min=min_val) + Y_max_clamp = torch.clamp(X, max=max_val) + + qY_min_clamp = torch.quantize_per_tensor(Y_min_clamp, scale=scale, + zero_point=zero_point, dtype=torch_type) + qY_max_clamp = torch.quantize_per_tensor(Y_max_clamp, scale=scale, + zero_point=zero_point, dtype=torch_type) + + + for name, op in ops_under_test.items(): + qY_min_clamp_hat = op(qX, min=min_val) + self.assertEqual(qY_min_clamp, qY_min_clamp_hat, msg="{} qclamp failed".format(name)) + qY_max_clamp_hat = op(qX, max=max_val) + self.assertEqual(qY_max_clamp, qY_max_clamp_hat, msg="{} qclamp failed".format(name)) """Tests the correctness of the quantized::hardtanh op.""" @skipIfNoFBGEMM @@ -573,11 +650,11 @@ def test_hardtanh(self, X, min_val, max_val): """Tests the correctness of the quantized::hardswish op.""" @override_qengines def test_hardswish(self): - max_sides = (3, 5) - side_lens = (1, 7, 8) + max_sides = (3, 4) + side_lens = (1, 7) torch_types = (torch.quint8, torch.qint8) - y_scales = (0.1, 4.23) - y_zero_points = (0, 1) + y_scales = (0.1, ) + y_zero_points = (1,) combined = [max_sides, side_lens, torch_types, y_scales, y_zero_points] test_cases = itertools.product(*combined) for test_case in test_cases: @@ -589,52 +666,69 @@ def test_hardswish(self): shapes = [side_len] * max_side X, X_scale, X_zero_point = \ _get_random_tensor_and_q_params(shapes, 2.0, torch_type) - qX = torch.quantize_per_tensor(X, scale=X_scale, zero_point=X_zero_point, - dtype=torch_type) - dqX = qX.dequantize() - - dqY_hat = F.hardswish(dqX) - qY_hat = torch.quantize_per_tensor(dqY_hat, scale=Y_scale, - zero_point=Y_zero_point, + for memory_format in torch.channels_last, torch.contiguous_format: + if memory_format == torch.channels_last and len(shapes) == 4: + X = X.to(memory_format=memory_format) + qX = torch.quantize_per_tensor(X, scale=X_scale, zero_point=X_zero_point, dtype=torch_type) + dqX = qX.dequantize() - qY = torch.nn.quantized.functional.hardswish( - qX, scale=Y_scale, zero_point=Y_zero_point) - self.assertEqual( - qY, qY_hat, - msg="Hardswish failed: {} vs {}, {}".format(qY, qY_hat, torch.backends.quantized.engine)) + dqY_hat = F.hardswish(dqX) + qY_hat = torch.quantize_per_tensor(dqY_hat, scale=Y_scale, + zero_point=Y_zero_point, + dtype=torch_type) - """Tests the correctness of the scalar addition.""" - @unittest.skip("Failing on MacOS") - @given(A=hu.tensor(shapes=hu.array_shapes(1, 4, 1, 5), - elements=hu.floats(-1e6, 1e6, allow_nan=False), - qparams=hu.qparams()), - b=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False)) - def test_qadd_scalar_relu(self, A, b): + qY = torch.nn.quantized.functional.hardswish( + qX, scale=Y_scale, zero_point=Y_zero_point) + self.assertEqual( + qY, qY_hat, + msg="Hardswish failed: {} vs {}, {}".format(qY, qY_hat, torch.backends.quantized.engine)) + + """Tests the correctness of the binary op + scalar.""" + def _test_binary_op_scalar_relu(self, A, b, binary_op_name, binary_op, quantized_op, quantized_op_relu): import copy - add_scalar = torch.ops.quantized.add - add_scalar_relu = torch.ops.quantized.add_relu + op_scalar = quantized_op + op_scalar_relu = quantized_op_relu A, (scale, zero_point, dtype) = A A = A.astype(np.float32) qA = torch.quantize_per_tensor(torch.from_numpy(A), scale, zero_point, dtype) - C = qA.dequantize() + round(b / scale) * scale + if binary_op_name == 'add': + C = binary_op(qA.dequantize(), round(b / scale) * scale) + else: + C = binary_op(qA.dequantize(), b) C_relu = copy.deepcopy(C) C_relu[C_relu < 0] = 0 - C_hat = add_scalar(qA, b) + C_hat = op_scalar(qA, b) C_ref = torch.quantize_per_tensor(C, C_hat.q_scale(), C_hat.q_zero_point(), dtype) - C_relu_hat = add_scalar_relu(qA, b) + C_relu_hat = op_scalar_relu(qA, b) C_relu_ref = torch.quantize_per_tensor( C_relu, C_relu_hat.q_scale(), C_relu_hat.q_zero_point(), dtype) self.assertEqual(C_ref.dequantize(), C_hat.dequantize(), - msg="Scalar add results don't match:\ - {} vs {}".format(C_ref.dequantize(), C_hat.dequantize())) + msg="{}_scalar results don't match: " + "{} vs {}".format(binary_op_name, C_ref.dequantize(), C_hat.dequantize())) self.assertEqual(C_relu_ref.dequantize(), C_relu_hat.dequantize(), - msg="Scalar add relu results don't match:\ - {} vs {}".format(C_relu_ref.dequantize(), C_relu_hat.dequantize())) + msg="{}_scalar_relu results don't match: " + "{} vs {}".format(binary_op_name, C_relu_ref.dequantize(), C_relu_hat.dequantize())) + + @unittest.skipIf(IS_MACOS, "skipping macos test") + @given(A=hu.tensor(shapes=hu.array_shapes(1, 4, 1, 5), + elements=hu.floats(-1e6, 1e6, allow_nan=False), + qparams=hu.qparams()), + b=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False)) + def test_add_scalar_relu(self, A, b): + self._test_binary_op_scalar_relu(A, b, "add", operator.add, torch.ops.quantized.add, torch.ops.quantized.add_relu) + + @unittest.skipIf(IS_MACOS, "skipping macos test") + @given(A=hu.tensor(shapes=hu.array_shapes(1, 4, 1, 5), + elements=hu.floats(-1e6, 1e6, allow_nan=False), + qparams=hu.qparams()), + b=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False)) + def test_mul_scalar_relu(self, A, b): + self._test_binary_op_scalar_relu(A, b, "mul", operator.mul, torch.ops.quantized.mul, torch.ops.quantized.mul_relu) """Tests the correctness of the add and add_relu op.""" def test_qadd_relu_same_qparams(self): @@ -906,7 +1000,56 @@ def test_channel_shuffle(self, X, groups): self.assertEqual(a_ref, a_hat.dequantize(), msg="torch.nn.functional.channel_shuffle results are off") - """Tests max pool operation on quantized tensors.""" + """Tests 1D max pool operation on quantized tensors.""" + @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=2, max_dims=3, + min_side=1, max_side=10), + qparams=hu.qparams()), + kernel=st.sampled_from((3, 5, 7)), + stride=st.sampled_from((None, 1, 2)), + dilation=st.integers(1, 2), + padding=st.integers(0, 2), + ceil_mode=st.booleans()) + def test_max_pool1d(self, X, kernel, stride, dilation, padding, ceil_mode): + X, (scale, zero_point, torch_type) = X + # Check constraints + assume(kernel // 2 >= padding) # Kernel cannot be overhanging! + iW = X.shape[-1] + oW = pool_output_shape(iW, kernel, padding, stride, dilation, ceil_mode) + assume(oW > 0) + + a = torch.from_numpy(X) + a_pool = torch.nn.functional.max_pool1d(a, kernel_size=kernel, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode) + a_ref = torch.quantize_per_tensor(a_pool, scale=scale, + zero_point=zero_point, dtype=torch_type) + a_ref = a_ref.dequantize() + qa = torch.quantize_per_tensor(a, scale=scale, zero_point=zero_point, + dtype=torch_type) + + ops_under_test = { + "torch": torch.max_pool1d, + "nn.functional": torch.nn.functional.max_pool1d, + "nn.quantized.functional": torch.nn.quantized.functional.max_pool1d + } + + for name, op in ops_under_test.items(): + a_hat = op(qa, kernel_size=kernel, stride=stride, padding=padding, + dilation=dilation, ceil_mode=ceil_mode) + self.assertEqual(a_ref, a_hat.dequantize(), + msg="{} results are off".format(name)) + # Test the ops.quantized separately, because None is not treated. + a_hat = torch.ops.quantized.max_pool1d( + qa, kernel_size=_single(kernel), + stride=_single(kernel if stride is None else stride), + padding=_single(padding), dilation=_single(dilation), + ceil_mode=ceil_mode) + self.assertEqual(a_ref, a_hat.dequantize(), + msg="ops.quantized.max_pool1d results are off") + + """Tests 2D max pool operation on quantized tensors.""" @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4, min_side=1, max_side=10), qparams=hu.qparams()), @@ -1677,12 +1820,14 @@ def test_cat_nhwc(self, X, relu): torch.testing.assert_allclose(out.dequantize(), ref.dequantize()) self.assertNotEqual(out.stride(), sorted(out.stride())) - @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=3, - min_side=1, max_side=2), + @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=1, max_dims=5, + min_side=1, max_side=4), qparams=hu.qparams()), - dim=st.integers(1, 2)) + dim=st.integers(-1, 5)) + @override_qengines def test_mean(self, X, dim): X, (scale, zero_point, torch_type) = X + assume(dim < X.ndim) qX = torch.quantize_per_tensor(torch.tensor(X).float(), scale, zero_point, torch_type) Y = torch.mean(qX.dequantize(), dim) @@ -2129,6 +2274,126 @@ def test_empty_batch(self): result = torch.ops.quantized.linear_dynamic(X, w_packed) self.assertEqual(result.shape, (0, 2)) + def test_advanced_indexing(self): + """ + Verifies that the x[:, [0], :, :] syntax works for quantized tensors. + """ + for dtype in (torch.qint8, torch.quint8, torch.qint32): + scale = 0.1 + zp = 0 + x_q = torch.quantize_per_tensor( + torch.randn(1, 4, 4, 4), scale, zp, dtype) + # reference + x_fp32 = x_q.dequantize() + + # single dim, single index + x_q_s1 = x_q[:, [0], :, :] + x_fp32_s1 = x_fp32[:, [0], :, :] + x_fp32_s1_ref = \ + torch.quantize_per_tensor(x_fp32_s1, scale, zp, dtype) + self.assertEqual(x_q_s1, x_fp32_s1_ref) + + # multiple dim, single index + x_q_s2 = x_q[:, [0], [2], :] + x_fp32_s2 = x_fp32[:, [0], [2], :] + x_fp32_s2_ref = \ + torch.quantize_per_tensor(x_fp32_s2, scale, zp, dtype) + self.assertEqual(x_q_s2, x_fp32_s2_ref) + + # single dim, multiple indices + x_q_s3 = x_q[:, [2, 0, 1], :, :] + x_fp32_s3 = x_fp32[:, [2, 0, 1], :, :] + x_fp32_s3_ref = \ + torch.quantize_per_tensor(x_fp32_s3, scale, zp, dtype) + self.assertEqual(x_q_s3, x_fp32_s3_ref) + + # multiple dim, multiple indices + x_q_s4 = x_q[:, [2, 0, 1], :, [1]] + x_fp32_s4 = x_fp32[:, [2, 0, 1], :, [1]] + x_fp32_s4_ref = \ + torch.quantize_per_tensor(x_fp32_s4, scale, zp, dtype) + self.assertEqual(x_q_s4, x_fp32_s4_ref) + + @override_qengines + def test_custom_module_lstm(self): + qengine = torch.backends.quantized.engine + + batch_size = 4 + seq_len = 8 + input_size = 12 + + hidden_size = 8 + num_layers = 2 + + dropout = 0 # This is not supported + + Bias = [False, True] + Batch_first = [False, True] + Bidirectional = [False, True] + + dtype = np.uint8 + qtype = torch.quint8 + + custom_module_config = { + 'float_to_observed_custom_module_class': { + torch.nn.LSTM: torch.nn.quantizable.LSTM + } + } + + x = np.random.randn(seq_len, batch_size, input_size) + scale, zero_point = _calculate_dynamic_qparams(x, dtype=dtype) + x = torch.from_numpy(x).to(torch.float) + qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, + dtype=qtype) + x = qx.dequantize() + + with torch.no_grad(): + for bias, batch_first, bidirectional in itertools.product( + Bias, Batch_first, Bidirectional): + # Assume 12dB is sufficient for functional equivalence + # Without the bias, linear performs poorly + min_power = 10 if bias else 5 + max_mse = 5e-6 if bias else 5e-1 + + if batch_first: + x = x.reshape(batch_size, seq_len, input_size) + qx = qx.reshape(batch_size, seq_len, input_size) + else: + x = x.reshape(seq_len, batch_size, input_size) + qx = qx.reshape(seq_len, batch_size, input_size) + + lstm = torch.nn.Sequential( + torch.nn.LSTM(input_size, hidden_size, + num_layers=num_layers, + bias=bias, batch_first=batch_first, + dropout=dropout, + bidirectional=bidirectional)) + lstm.eval() + y_ref = lstm(x) + + # Prepare + lstm.qconfig = torch.quantization.get_default_qconfig(qengine) + lstm_prepared = torch.quantization.prepare( + lstm, prepare_custom_config_dict=custom_module_config) + self.assertTrue(hasattr(lstm_prepared[0], 'layers')) + self.assertEqual(num_layers, len(lstm_prepared[0].layers)) + + # Calibrate + y = lstm_prepared(x) + self.assertEqual(y_ref, y) + + # Quantize + lstm_quantized = torch.quantization.convert(lstm_prepared) + qy = lstm_quantized(qx) + + snr = _snr(y, qy) + snr = [snr[0]] + snr[1] + + for signal, mse, power in snr: + self.assertTrue( + power > min_power or mse < max_mse, + msg=(f"Error is too high: SNR(dB): {power}, " + f"Signal: {signal}, MSE: {mse}")) class TestDynamicQuantizedLinear(TestCase): @@ -2327,10 +2592,10 @@ def test_qlinear_legacy(self, batch_size, input_channels, output_channels): class TestDynamicQuantizedRNNOp(TestCase): """Tests the correctness of the dynamic quantized lstm/gru.""" - def _get_rnn_inputs(self, seq_len, num_batches, input_size, hidden_size, num_directions): + def _get_rnn_inputs(self, seq_len, num_batches, input_size, hidden_size, num_directions, reduce_range): # For Input (seq_len, batch, input_size) X = torch.randn(seq_len, num_batches, input_size) - s, z = _calculate_dynamic_qparams(X, torch.quint8, reduce_range=True) + s, z = _calculate_dynamic_qparams(X, torch.quint8, reduce_range) Xq = torch.quantize_per_tensor(X, s, z, torch.quint8) # For H and C: (num_layers(1) * num_directions, batch, hidden_size) @@ -2342,9 +2607,9 @@ def _get_rnn_inputs(self, seq_len, num_batches, input_size, hidden_size, num_dir H = torch.zeros(num_directions, num_batches, hidden_size) C = torch.zeros(num_directions, num_batches, hidden_size) - s, z = _calculate_dynamic_qparams(H, torch.quint8, reduce_range=True) + s, z = _calculate_dynamic_qparams(H, torch.quint8, reduce_range) Hq = torch.quantize_per_tensor(H, s, z, torch.quint8) - s, z = _calculate_dynamic_qparams(C, torch.quint8, reduce_range=True) + s, z = _calculate_dynamic_qparams(C, torch.quint8, reduce_range) Cq = torch.quantize_per_tensor(C, s, z, torch.quint8) return Xq, Hq, Cq @@ -2387,7 +2652,11 @@ def test_qlstmGRU(self, num_batches, input_size, hidden_size, if torch.backends.quantized.engine == 'qnnpack' and dtype == torch.float16: continue - Xq, Hq, Cq = self._get_rnn_inputs(seq_len, num_batches, input_size, hidden_size, num_directions) + if torch.backends.quantized.engine == 'qnnpack': + reduce_range = False + else: + reduce_range = True + Xq, Hq, Cq = self._get_rnn_inputs(seq_len, num_batches, input_size, hidden_size, num_directions, reduce_range) Wq1, Wq2, b1, b2 = self._get_rnn_weights_and_bias(input_size, hidden_size, num_directions, @@ -2396,7 +2665,7 @@ def test_qlstmGRU(self, num_batches, input_size, hidden_size, if dtype == torch.qint8: packed_ih = torch.ops.quantized.linear_prepack(Wq1, b1) packed_hh = torch.ops.quantized.linear_prepack(Wq2, b2) - cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(packed_ih, packed_hh, b1, b2, True) + cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(packed_ih, packed_hh, b1, b2, reduce_range) W_ref1 = Wq1.dequantize() W_ref2 = Wq2.dequantize() @@ -2516,7 +2785,12 @@ def test_qrnncell(self, num_batches, input_size, hidden_size, per_channel_quant) if torch.backends.quantized.engine == 'qnnpack' and dtype == torch.float16: continue - Xq, Hq, Cq = self._get_rnn_inputs(seq_len, num_batches, input_size, hidden_size, 1) + if torch.backends.quantized.engine == 'qnnpack': + reduce_range = False + else: + reduce_range = True + + Xq, Hq, Cq = self._get_rnn_inputs(seq_len, num_batches, input_size, hidden_size, 1, reduce_range) Wq1, Wq2, b1, b2 = self._get_rnn_weights_and_bias(input_size, hidden_size, 1, per_channel_quant, rnn_type) if dtype == torch.qint8: packed_ih = torch.ops.quantized.linear_prepack(Wq1, b1) @@ -2720,23 +2994,24 @@ class TestQuantizedEmbeddingOps(TestCase): def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate, optimized_qparams): weights = torch.from_numpy((np.random.random_sample(( num_embeddings, embedding_dim)) + 1).astype(np.float32)) - + qtype = torch.quint8 if bit_rate == 8: w_packed = pack_fn(weights) else: w_packed = pack_fn(weights, optimized_qparams=optimized_qparams) w_unpacked = unpack_fn(w_packed) - if bit_rate == 8: + if bit_rate == 8 or bit_rate == 4: # Check numerics of prepack function that accepts qtensor as input. # We use min-max observer to mimic the quantization performed in the original function. obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) obs(weights) # Get the scale and zero point for the weight tensor qparams = obs.calculate_qparams() - + if bit_rate == 4: + qtype = torch.quint4x2 # Quantize the weights to 8bits - qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8) + qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=qtype) real_packed_weight = torch.ops.quantized.embedding_bag_prepack(qweight) self.assertEqual(isinstance(real_packed_weight, torch._C.ScriptObject), True) unpacked_weight = torch.ops.quantized.embedding_bag_unpack(real_packed_weight) @@ -2818,10 +3093,13 @@ def test_embedding_bag_2bit_unpack(self, num_embeddings, embedding_dim, optimize self._test_embedding_bag_unpack_fn(pack_fn, unpack_fn, num_embeddings, embedding_dim, 2, optimized_qparams) + def embedding_bag_rowwise_offsets_run( self, bit_rate, num_embeddings, - embedding_dim, num_offsets, enable_per_sample_weights, - include_last_offset, atol, rtol): + embedding_dim, num_offsets, + use_32bit_indices, use_32bit_offsets, + enable_per_sample_weights, + include_last_offset, fallback_to_no_sparse, sparsity, atol, rtol): pt_op = torch.ops.quantized.embedding_bag_byte_rowwise_offsets pt_prepack_op = torch.ops.quantized.embedding_bag_byte_prepack if bit_rate == 4: @@ -2876,71 +3154,128 @@ def get_reference_result( return embedding_bag(indices, offsets, per_sample_weights=per_sample_weights) + mapping_table = np.zeros(num_embeddings, dtype=np.int32) + pruned_weights = weights + prune_weights = sparsity > 0 + if prune_weights: + if fallback_to_no_sparse: + # Testing that prune_weight with mapping_table {0} will + # fallback to non sparse embedding look up kernel. + mapping_table = np.zeros(1, dtype=np.int32) + else: + # Prune and generate mapping table + num_compressed_rows = 0 + unpruned_ids = [] + for i in range(num_embeddings): + if np.random.uniform() < sparsity: + mapping_table[i] = -1 + q_weights[i, :] = 0 + weights[i, :] = 0 + else: + mapping_table[i] = num_compressed_rows + num_compressed_rows += 1 + unpruned_ids.append(i) + q_weights = q_weights[unpruned_ids] + pruned_weights = weights[unpruned_ids] + + result = pt_op(q_weights, + indices.int() if use_32bit_indices else indices, + offsets.int() if use_32bit_offsets else offsets, + mode=0, + pruned_weights=prune_weights, + per_sample_weights=per_sample_weights, + compressed_indices_mapping=torch.tensor(mapping_table), + include_last_offset=include_last_offset) + reference_result = get_reference_result( num_embeddings, embedding_dim, include_last_offset, weights, per_sample_weights, indices, offsets) - result = pt_op( - q_weights, - indices, - offsets, - mode=0, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset, - ) + torch.testing.assert_allclose(reference_result, result, atol=atol, rtol=rtol) - if bit_rate == 8: + + if bit_rate == 8 or bit_rate == 4: # Test operator that accepts TorchBind packed weights. - obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) - obs(weights) + if bit_rate == 4: + qdtype = torch.quint4x2 + op = torch.ops.quantized.embedding_bag_4bit + else: + qdtype = torch.quint8 + op = torch.ops.quantized.embedding_bag_byte + obs = PerChannelMinMaxObserver(dtype=qdtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) + obs(pruned_weights) # Get the scale and zero point for the weight tensor qparams = obs.calculate_qparams() - # Quantize the weights to 8bits - qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8) + qweight = torch.quantize_per_channel(pruned_weights, qparams[0], qparams[1], axis=0, dtype=qdtype) packed_weight = torch.ops.quantized.embedding_bag_prepack(qweight) - result = torch.ops.quantized.embedding_bag_byte(packed_weight, indices, offsets, mode=0, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset) + result = op(packed_weight, indices, offsets, mode=0, + pruned_weights=prune_weights, + per_sample_weights=per_sample_weights, + compressed_indices_mapping=torch.tensor(mapping_table), + include_last_offset=include_last_offset) torch.testing.assert_allclose(reference_result, result, atol=atol, rtol=rtol) + """ Tests the correctness of the embedding_bag_8bit quantized operator """ @skipIfNoFBGEMM @given(num_embeddings=st.integers(10, 100), embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0), num_offsets=st.integers(1, 20), + use_32bit_indices=st.booleans(), + use_32bit_offsets=st.booleans(), enable_per_sample_weights=st.booleans(), - include_last_offset=st.booleans()) - def test_embedding_bag_byte_rowwise_offsets(self, num_embeddings, - embedding_dim, num_offsets, - enable_per_sample_weights, - include_last_offset): + include_last_offset=st.booleans(), + fallback_to_no_sparse=st.booleans(), + sparsity=st.sampled_from([0.0, 0.5, 0.7])) + def test_embedding_bag_byte(self, num_embeddings, + embedding_dim, num_offsets, + use_32bit_indices, + use_32bit_offsets, + enable_per_sample_weights, + include_last_offset, + fallback_to_no_sparse, + sparsity): self.embedding_bag_rowwise_offsets_run( 8, num_embeddings, embedding_dim, num_offsets, + use_32bit_indices, use_32bit_offsets, enable_per_sample_weights, include_last_offset, - atol=0.005, rtol=1e-3) + fallback_to_no_sparse, + sparsity=sparsity, atol=0.005, rtol=1e-3) """ Tests the correctness of the embedding_bag_4bit quantized operator """ @given(num_embeddings=st.integers(10, 100), embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0), num_offsets=st.integers(1, 20), + use_32bit_indices=st.booleans(), + use_32bit_offsets=st.booleans(), enable_per_sample_weights=st.booleans(), - include_last_offset=st.booleans()) - def test_embedding_bag_4bit_rowwise_offsets(self, num_embeddings, - embedding_dim, num_offsets, - enable_per_sample_weights, - include_last_offset): + include_last_offset=st.booleans(), + fallback_to_no_sparse=st.booleans(), + sparsity=st.sampled_from([0.0, 0.5, 0.7])) + def test_embedding_bag_4bit(self, num_embeddings, + embedding_dim, num_offsets, + use_32bit_indices, + use_32bit_offsets, + enable_per_sample_weights, + include_last_offset, + fallback_to_no_sparse, + sparsity): self.embedding_bag_rowwise_offsets_run(4, num_embeddings, embedding_dim, num_offsets, + use_32bit_indices, use_32bit_offsets, enable_per_sample_weights, - include_last_offset, atol=0.1, - rtol=1e-2) + include_last_offset, + fallback_to_no_sparse, + sparsity=sparsity, + atol=0.1, rtol=1e-2) """ Tests the correctness of the quantized embedding lookup operator """ @given(num_embeddings=st.integers(10, 100), embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0)) + @skipIfNoFBGEMM def test_embedding_byte(self, num_embeddings, embedding_dim): quant_op = torch.ops.quantized.embedding_byte prepack_op = torch.ops.quantized.embedding_bag_prepack @@ -2965,11 +3300,73 @@ def test_embedding_byte(self, num_embeddings, embedding_dim): low=0, high=num_embeddings, size=num_indices, dtype=np.int64)) packed_weight = prepack_op(qweight) - qresult = quant_op(packed_weight, indices, sparse=False) + qresult = quant_op(packed_weight, indices, pruned_weights=False) ref = torch.embedding(weights, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False) torch.testing.assert_allclose(ref, qresult, atol=0.005, rtol=1e-3) + + @skipIfNoFBGEMM + def test_embedding_2d_indices(self): + """ + Tests the case where 2D indices are passed into the operator + In this case the operator computes the correct offsets argument. + Output shape is dependent on the indices dimension. + """ + quant_op = torch.ops.quantized.embedding_byte + prepack_op = torch.ops.quantized.embedding_bag_prepack + + indices = torch.tensor([[9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8], [3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]]) + weights = torch.randn(10, 12, dtype=torch.float32) + + ref = torch.embedding(weights, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False) + obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) + obs(weights) + qparams = obs.calculate_qparams() + + qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8) + packed_weight = prepack_op(qweight) + qresult = quant_op(packed_weight, indices, pruned_weights=False) + torch.testing.assert_allclose(ref, qresult, atol=0.05, rtol=1e-3) + + @skipIfNoFBGEMM + def test_embedding_bag_2d_indices(self): + """ + Tests the case where 2D indices are passed into the operator + In this case the operator computes the correct offsets argument. + """ + indices = torch.tensor([[9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8], [3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]]) + weights = torch.randn(10, 12, dtype=torch.float32) + + embedding_bag = torch.nn.EmbeddingBag( + num_embeddings=10, + embedding_dim=12, + include_last_offset=False, _weight=weights, + scale_grad_by_freq=False, mode='sum' + ) + result = embedding_bag(indices) + + pt_op = torch.ops.quantized.embedding_bag_byte_rowwise_offsets + pt_prepack_op = torch.ops.quantized.embedding_bag_byte_prepack + q_weights = pt_prepack_op(weights) + qresult = pt_op(q_weights, indices, mode=0, pruned_weights=False) + torch.testing.assert_allclose(result, qresult, atol=0.05, rtol=1e-3) + + # Test TorchBind based embedding_bag operator + obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) + obs(weights) + # Get the scale and zero point for the weight tensor + qparams = obs.calculate_qparams() + + # Quantize the weights to 8bits + qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8) + + packed_weight = torch.ops.quantized.embedding_bag_prepack(qweight) + qresult = torch.ops.quantized.embedding_bag_byte(packed_weight, indices, mode=0) + + torch.testing.assert_allclose(result, qresult, atol=0.05, rtol=1e-3) + + class TestQuantizedConv(TestCase): def _test_qconv_unpack_impl(self, qconv_prepack_fn, qconv_unpack_fn, inputs, strides, i_pads, o_pads, channelwise): @@ -2980,12 +3377,14 @@ def _test_qconv_unpack_impl(self, qconv_prepack_fn, qconv_unpack_fn, inputs, W = torch.from_numpy(W).float() bias = torch.from_numpy(bias).float() - + if channelwise and transposed: + # currently transposed conv and per-channel per quantization does not work + return if channelwise: if transposed: - output_channels = W.shape[1] + output_channels = W.shape[1] # IC OC/G else: - output_channels = W.shape[0] + output_channels = W.shape[0] # OC IC/G W_scale = torch.tensor([W_scale] * output_channels) W_zero_point = torch.tensor([W_zero_point] * output_channels) W_q = torch.quantize_per_channel( @@ -3386,8 +3785,6 @@ def test_qconv_transpose2d( Y_scale, Y_zero_point, use_bias): - if not qengine_is_qnnpack(): - return # Currently only QNNPACK is supported if qengine_is_qnnpack() and (IS_PPC or TEST_WITH_UBSAN): return # QNNPACK doesn't support these assume(o_pad_h < stride_h or o_pad_h < dilation) @@ -3445,6 +3842,122 @@ def test_qconv_transpose2d( Y_q = qconv_op(X_q) self.assertEqual(Y_q_ref, Y_q) + """Tests the correctness of quantized convolution op.""" + @given(batch_size=st.integers(1, 3), + input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]), + time=st.integers(2, 5), + height=st.integers(10, 16), + width=st.integers(7, 14), + output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]), + groups=st.integers(1, 3), + kernel_t=st.integers(1, 7), + kernel_h=st.integers(1, 7), + kernel_w=st.integers(1, 7), + stride_t=st.integers(1, 2), + stride_h=st.integers(1, 2), + stride_w=st.integers(1, 2), + pad_t=st.integers(0, 2), + pad_h=st.integers(0, 2), + pad_w=st.integers(0, 2), + o_pad_t=st.integers(0, 2), + o_pad_h=st.integers(0, 2), + o_pad_w=st.integers(0, 2), + dilation=st.integers(1, 2), + X_scale=st.floats(1.2, 1.6), + X_zero_point=st.integers(0, 4), + W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2), + W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2), + Y_scale=st.floats(4.2, 5.6), + Y_zero_point=st.integers(0, 4), + use_bias=st.booleans()) + @override_qengines + def test_qconv_transpose3d( + self, + batch_size, + input_channels_per_group, + time, + height, + width, + output_channels_per_group, + groups, + kernel_t, + kernel_h, + kernel_w, + stride_t, + stride_h, + stride_w, + pad_t, + pad_h, + pad_w, + o_pad_t, + o_pad_h, + o_pad_w, + dilation, + X_scale, + X_zero_point, + W_scale, + W_zero_point, + Y_scale, + Y_zero_point, + use_bias): + if qengine_is_qnnpack(): + return # QNNPACK doesn't support this + assume(o_pad_t < stride_t or o_pad_t < dilation) + assume(o_pad_h < stride_h or o_pad_h < dilation) + assume(o_pad_w < stride_w or o_pad_w < dilation) + + input_channels = input_channels_per_group * groups + output_channels = output_channels_per_group * groups + kernels = (kernel_t, kernel_h, kernel_w) + strides = (stride_t, stride_h, stride_w) + pads = (pad_t, pad_h, pad_w) + o_pads = (o_pad_t, o_pad_h, o_pad_w) + dilations = (dilation, dilation, dilation) + + qconv = torch.ops.quantized.conv_transpose3d + qconv_prepack = torch.ops.quantized.conv_transpose3d_prepack + conv_op = torch.nn.ConvTranspose3d( + in_channels=input_channels, + out_channels=output_channels, + kernel_size=kernels, + stride=strides, + padding=pads, + output_padding=o_pads, + groups=groups, + dilation=dilations, + bias=use_bias + ) + X_q, W_q, bias_float = self._test_qconv_impl( + qconv, qconv_prepack, conv_op, batch_size, + input_channels_per_group, (time, height, width), + output_channels_per_group, groups, kernels, strides, pads, o_pads, + dilations, X_scale, X_zero_point, W_scale, W_zero_point, + Y_scale, Y_zero_point, use_bias, use_relu=False, + use_channelwise=False, use_transpose=True) + + # Test the module implementation + qconv_op = torch.nn.quantized.ConvTranspose3d( + in_channels=input_channels, + out_channels=output_channels, + kernel_size=kernels, + stride=strides, + padding=pads, + output_padding=o_pads, + groups=groups, + dilation=dilations, + bias=use_bias + ) + qconv_op.scale = Y_scale + qconv_op.zero_point = Y_zero_point + qconv_op.set_weight_bias(W_q, bias_float) + + Y_dq_ref = conv_op(X_q.dequantize()) + Y_q_ref = torch.quantize_per_tensor(Y_dq_ref, scale=Y_scale, + zero_point=Y_zero_point, + dtype=torch.quint8) + Y_q = qconv_op(X_q) + self.assertEqual(Y_q_ref, Y_q) + @given( inputs=hu.tensor_conv( spatial_dim=1, batch_size_range=(1, 3), @@ -3513,8 +4026,6 @@ def test_qconv2d_unpack(self, inputs, stride, pad, o_pad, channelwise): return if qengine == 'qnnpack': assume(not channelwise) # QNNPACK doesn't support channelwise - else: - assume(not transposed) # Only QNNPACK supports transposed conv if transposed: qconv_prepack = torch.ops.quantized.conv_transpose2d_prepack qconv_unpack = torch.ops.quantized.conv_transpose2d_unpack @@ -3699,22 +4210,26 @@ def test_qconv3d( stride_w=st.integers(1, 2), pad_d=st.integers(1, 2), pad_h=st.integers(1, 2), pad_w=st.integers(1, 2), - channelwise=st.booleans(), - qengine=st.sampled_from(("fbgemm",))) + o_pad=st.integers(0, 2), + channelwise=st.booleans()) + @override_qengines def test_qconv3d_unpack( - self, inputs, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, - channelwise, qengine + self, inputs, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, o_pad, + channelwise ): - if qengine not in supported_qengines: - return - - with override_quantized_engine(qengine): - qconv3d_prepack = torch.ops.quantized.conv3d_prepack - qconv3d_unpack = torch.ops.quantized.conv3d_unpack - self._test_qconv_unpack_impl( - qconv3d_prepack, qconv3d_unpack, inputs, - (stride_d, stride_h, stride_w), (pad_d, pad_h, pad_w), None, - channelwise) + if qengine_is_qnnpack(): + return # QNNPACK doesn't support this + transposed = inputs[-1] + if transposed: + qconv_prepack = torch.ops.quantized.conv_transpose3d_prepack + qconv_unpack = torch.ops.quantized.conv_transpose3d_unpack + else: + qconv_prepack = torch.ops.quantized.conv3d_prepack + qconv_unpack = torch.ops.quantized.conv3d_unpack + self._test_qconv_unpack_impl( + qconv_prepack, qconv_unpack, inputs, + (stride_d, stride_h, stride_w), (pad_d, pad_h, pad_w), (o_pad, o_pad, o_pad), + channelwise) class TestPadding(TestCase): @given(batch_size=st.integers(1, 64), @@ -3735,7 +4250,35 @@ def test_reflection_pad1d(self, batch_size, channels, width, qtype): y_ref = padding_op(x) qy_ref = torch.quantize_per_tensor(y_ref, scale, zp, qtype) qy_hat = padding_op(qx) + self.assertEqual(qy_ref, qy_hat) + # Out variant + qy_hat = torch._C._nn.reflection_pad1d(qx, padding, out=qy_hat) + self.assertEqual(qy_ref, qy_hat) + + @given(batch_size=st.integers(1, 64), + channels=st.integers(1, 64), + height=st.integers(16, 128), + width=st.integers(16, 128), + qtype=st.sampled_from(hu._ALL_QINT_TYPES)) + def test_reflection_pad2d(self, batch_size, channels, height, width, qtype): + padding = (width // 4, width // 4, height // 4, height // 4) + + x = torch.arange(batch_size * channels * height * width).to(torch.float) + x = x.resize(batch_size, channels, height, width) + # Per-Tensor test + scale, zp = _calculate_dynamic_qparams(x, qtype) + qx = torch.quantize_per_tensor(x, scale, zp, qtype) + + padding_op = torch.nn.ReflectionPad2d(padding) + + y_ref = padding_op(x) + qy_ref = torch.quantize_per_tensor(y_ref, scale, zp, qtype) + qy_hat = padding_op(qx) + self.assertEqual(qy_ref, qy_hat) + + # Out variant + qy_hat = torch._C._nn.reflection_pad2d(qx, padding, out=qy_hat) self.assertEqual(qy_ref, qy_hat) @given(batch_size=st.integers(1, 64), @@ -3796,55 +4339,68 @@ def test_qnnpack_relu(self, X): """Tests the correctness of the quantized::qnnpack_tanh op.""" @skipIfNoFBGEMM - @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), - qparams=hu.qparams(dtypes=torch.quint8))) - def test_qnnpack_tanh(self, X): + def test_qnnpack_tanh(self): # Note: In QNNPACK the output scale and zero_point can only be # 2.0/256, 128 respectively, as it uses a LUT with 256 bins. - X, (scale, zero_point, torch_type) = X - X = torch.from_numpy(X) - qX = torch.quantize_per_tensor(X, scale=scale, - zero_point=zero_point, - dtype=torch_type) - # Floating point reference - Y = torch.tanh(X) - qY = torch.quantize_per_tensor(Y, scale=1.0 / 128, zero_point=128, - dtype=torch.quint8) - with override_quantized_engine('fbgemm'): - qYserver = torch.tanh(qX) - with override_quantized_engine('qnnpack'): - qY_hat = torch.tanh(qX) - self.assertEqual(qY, qY_hat, - msg="QNNPACK TanH failed (FP ref)!") - self.assertEqual(qYserver, qY_hat, - msg="QNNPACK TanH failed (FBGEMM ref)!") + shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4)) + memory_formats = (torch.channels_last, torch.contiguous_format) + test_cases = itertools.product(shapes, memory_formats) + for shape, memory_format in test_cases: + X, scale, zero_point, torch_type = torch.randn(*shape), 1.0, 0, torch.quint8 + if memory_format == torch.channels_last and len(shape) != 4: + continue + X = X.to(memory_format=memory_format) + qX = torch.quantize_per_tensor(X, scale=scale, + zero_point=zero_point, + dtype=torch_type) + + # Floating point reference + Y = torch.tanh(qX.dequantize()) + qY = torch.quantize_per_tensor(Y, scale=1.0 / 128, zero_point=128, + dtype=torch.quint8) + with override_quantized_engine('fbgemm'): + qYserver = torch.tanh(qX) + with override_quantized_engine('qnnpack'): + qY_hat = torch.tanh(qX) + self.assertEqual( + qY, qY_hat, + msg="QNNPACK TanH failed (FP ref), memory_format {}".format(memory_format)) + self.assertEqual( + qYserver, qY_hat, + msg="QNNPACK TanH failed (FBGEMM ref), memory_format {}".format(memory_format)) """Tests the correctness of the quantized::qnnpack_sigmoid op.""" @skipIfNoFBGEMM - @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), - qparams=hu.qparams(dtypes=torch.quint8))) - def test_qnnpack_sigmoid(self, X): + def test_qnnpack_sigmoid(self): # Note: In QNNPACK the output scale and zero_point can only be # 1.0/256, 0 respectively, as it uses a LUT with 256 bins. - X, (scale, zero_point, torch_type) = X - X = torch.from_numpy(X).to(torch.float32) - qX = torch.quantize_per_tensor(X, scale=scale, - zero_point=zero_point, - dtype=torch_type) + shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4)) + memory_formats = (torch.channels_last, torch.contiguous_format) + test_cases = itertools.product(shapes, memory_formats) + for shape, memory_format in test_cases: + X, scale, zero_point, torch_type = torch.randn(*shape), 1.0, 0, torch.quint8 + if memory_format == torch.channels_last and len(shape) != 4: + continue + X = X.to(memory_format=memory_format) + qX = torch.quantize_per_tensor(X, scale=scale, + zero_point=zero_point, + dtype=torch_type) - # Floating point reference - Y = torch.sigmoid(X) - qY = torch.quantize_per_tensor(Y, scale=1.0 / 256, zero_point=0, - dtype=torch.quint8) - with override_quantized_engine('fbgemm'): - qYserver = torch.sigmoid(qX) - with override_quantized_engine('qnnpack'): - qY_hat = torch.sigmoid(qX) - self.assertEqual(qY, qY_hat, - msg="QNNPACK Sigmoid failed (FP ref)!") - self.assertEqual(qYserver, qY_hat, - msg="QNNPACK Sigmoid failed (FBGEMM ref)!") + # Floating point reference + Y = torch.sigmoid(qX.dequantize()) + qY = torch.quantize_per_tensor(Y, scale=1.0 / 256, zero_point=0, + dtype=torch.quint8) + with override_quantized_engine('fbgemm'): + qYserver = torch.sigmoid(qX) + with override_quantized_engine('qnnpack'): + qY_hat = torch.sigmoid(qX) + self.assertEqual( + qY, qY_hat, + msg="QNNPACK Sigmoid failed (FP ref), memory_format {}".format(memory_format)) + self.assertEqual( + qYserver, qY_hat, + msg="QNNPACK Sigmoid failed (FBGEMM ref), memory_format {}".format(memory_format)) @skipIfNoFBGEMM def test_qnnpack_sigmoid_sweep(self): @@ -4094,31 +4650,32 @@ def test_mean(self, batch_size, channels, height, width, scale, zero_point): np.testing.assert_array_almost_equal(Y.int_repr().numpy(), qY.int_repr().numpy(), decimal=0) """Tests the correctness of the quantized::hardtanh op.""" - @given(X=hu.tensor(shapes=hu.array_shapes(1, 8, 1, 8, max_numel=10**5), - elements=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False), - qparams=hu.qparams(dtypes=torch.quint8)), - min_val=hu.floats(-1e6, -9.999999974752427e-07, allow_nan=False, allow_infinity=False), - max_val=hu.floats(9.999999974752427e-07, 1e6, allow_nan=False, allow_infinity=False)) - def test_hardtanh(self, X, min_val, max_val): + def test_hardtanh(self): if 'qnnpack' not in torch.backends.quantized.supported_engines: return with override_quantized_engine('qnnpack'): - X, (scale, zero_point, torch_type) = X + shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4)) + memory_formats = (torch.channels_last, torch.contiguous_format) + min_vals = (-0.5, -0.3, 0.5) + max_vals = (-0.3, 0.3, 0.7) + test_cases = itertools.product(shapes, memory_formats, min_vals, max_vals) + for shape, memory_format, min_val, max_val in test_cases: + X, scale, zero_point, torch_type = torch.randn(*shape), 1.0, 0, torch.quint8 + if memory_format == torch.channels_last and len(shape) != 4: + continue - assume(min_val <= max_val) - Y = X.copy() - Y[Y < min_val] = min_val - Y[Y > max_val] = max_val - qY = torch.quantize_per_tensor(torch.from_numpy(Y), scale=scale, - zero_point=zero_point, dtype=torch_type) - X = torch.from_numpy(X) - qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, - dtype=torch_type) + Y = X.clone() + Y[Y < min_val] = min_val + Y[Y > max_val] = max_val + qY = torch.quantize_per_tensor(Y, scale=scale, + zero_point=zero_point, dtype=torch_type) + qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, + dtype=torch_type) - qY_hat = torch.nn.quantized.functional.hardtanh(qX, min_val, max_val) - self.assertEqual( - qY, qY_hat, - msg="hardtanh failed:\nactual {}\nexpected {}".format(qY_hat, qY)) + qY_hat = torch.nn.quantized.functional.hardtanh(qX, min_val, max_val) + self.assertEqual( + qY, qY_hat, + msg="hardtanh failed:\nactual {}\nexpected {}\nmemory_format {}".format(qY_hat, qY, memory_format)) """Tests the correctness of the tensor comparators.""" class TestComparatorOps(TestCase): diff --git a/test/quantization/test_quantized_tensor.py b/test/quantization/test_quantized_tensor.py index fc3aa3c655ebe..e919deb9d2bd5 100644 --- a/test/quantization/test_quantized_tensor.py +++ b/test/quantization/test_quantized_tensor.py @@ -67,6 +67,75 @@ def _calculate_dynamic_qparams(X, dtype, reduce_range=False): def get_supported_device_types(): return ['cpu', 'cuda'] if torch.cuda.is_available() and not TEST_WITH_ROCM else ['cpu'] +# Note we explicitly cast variables to np.float32 in a couple of places to avoid +# the default casting in Python often resuling in double precision and to make +# sure we're doing the same numerics as C++ code. +def param_search_greedy(x, bit_rate, n_bins=200, ratio=0.16): + xmin, xmax = np.min(x), np.max(x) + stepsize = (xmax - xmin) / np.float32(n_bins) + min_bins = np.float32(n_bins) * (np.float32(1) - np.float32(ratio)) + xq, loss = _compress_uniform_simplified(x, bit_rate, xmin, xmax) + + solutions = [] # [(left, right, loss)] # local optima solution + + cur_min, cur_max, cur_loss = xmin, xmax, loss + thr = min_bins * stepsize + while cur_min + thr < cur_max: + # move left + xq, loss1 = _compress_uniform_simplified( + x, bit_rate, cur_min + stepsize, cur_max + ) + # move right + xq, loss2 = _compress_uniform_simplified( + x, bit_rate, cur_min, cur_max - stepsize + ) + + if cur_loss < loss1 and cur_loss < loss2: + # found a local optima + solutions.append((cur_min, cur_max, cur_loss)) + if loss1 < loss2: + cur_min, cur_max, cur_loss = cur_min + stepsize, cur_max, loss1 + else: + cur_min, cur_max, cur_loss = cur_min, cur_max - stepsize, loss2 + if len(solutions): + best = solutions[0] + for solution in solutions: + if solution[-1] < best[-1]: + best = solution + return best[1], best[0] # xmax, xmin + return xmax, xmin + + +def _compress_uniform_simplified(X, bit_rate, xmin, xmax, fp16_scale_bias=True): + # affine transform to put Xq in [0,2**bit_rate - 1] + # Xq = (2 ** bit_rate - 1) * (Xq - xmin) / data_range + if fp16_scale_bias: + xmin = xmin.astype(np.float16).astype(np.float32) + data_range = xmax - xmin + scale = np.where( + data_range == 0, np.float32(1), data_range / np.float32(2 ** bit_rate - 1) + ) + if fp16_scale_bias: + scale = scale.astype(np.float16).astype(np.float32) + inverse_scale = np.float32(1) / scale + Xq = np.clip(np.round((X - xmin) * inverse_scale), 0, np.float32(2 ** bit_rate - 1)) + Xq = Xq * scale + xmin + + # Manually compute loss instead of using np.linalg.norm to use the same + # accumulation order used by C++ code + vlen = 8 + loss_v = np.zeros(vlen).astype(np.float32) + for i in range(len(Xq) // vlen * vlen): + loss_v[i % vlen] += (X[i] - Xq[i]) * (X[i] - Xq[i]) + loss = np.float32(0) + for i in range(vlen): + loss += loss_v[i] + for i in range(len(Xq) // vlen * vlen, len(Xq)): + loss += (X[i] - Xq[i]) * (X[i] - Xq[i]) + loss = np.sqrt(loss) + + return Xq, loss + class TestQuantizedTensor(TestCase): def test_qtensor(self): num_elements = 10 @@ -103,6 +172,36 @@ def test_qtensor(self): "quantization_scheme=torch.per_tensor_affine, " + "scale=1.0, zero_point=2)") + def test_qtensor_sub_byte(self): + num_elements = 10 + scale = 1.0 + zero_point = 2 + for dtype in [torch.quint4x2]: + r = torch.ones((5, 2), dtype=torch.float) + qr = torch.quantize_per_tensor(r, scale, zero_point, dtype) + self.assertEqual(qr.q_scale(), scale) + self.assertEqual(qr.q_zero_point(), zero_point) + self.assertTrue(qr.is_quantized) + self.assertFalse(r.is_quantized) + self.assertEqual(qr.storage().size(), 5) + + int_repr = qr.int_repr() + for num in int_repr[0:5]: + self.assertEqual(num, 51) # Packed entries, each of value 3, i.e. 00110011 + + # Test tensor creation + q = torch._empty_affine_quantized([num_elements], scale=scale, zero_point=zero_point, + dtype=torch.quint4x2) + self.assertEqual(q.storage().size(), 5) + + # Test save/load + with tempfile.NamedTemporaryFile() as f: + torch.save(qr, f) + f.seek(0) + loaded_q = torch.load(f) + loaded_int_repr = loaded_q.int_repr()[0:5] + self.assertEqual(int_repr[0:5], loaded_int_repr) + def test_qtensor_float_assignment(self): # Scalar Tensor # item @@ -216,15 +315,10 @@ def test_qtensor_dtypes(self): r = torch.rand(3, 2, dtype=torch.float) * 4 - 2 scale = 0.2 zero_point = 2 - qr = torch.quantize_per_tensor(r, scale, zero_point, torch.qint8) - rqr = qr.dequantize() - self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale)) - qr = torch.quantize_per_tensor(r, scale, zero_point, torch.quint8) - rqr = qr.dequantize() - self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale)) - qr = torch.quantize_per_tensor(r, scale, zero_point, torch.qint32) - rqr = qr.dequantize() - self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale)) + for dtype in [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2]: + qr = torch.quantize_per_tensor(r, scale, zero_point, dtype) + rqr = qr.dequantize() + self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale)) def _test_quantize_per_channel(self, r, scales, zero_points, axis, float_params): @@ -335,6 +429,52 @@ def quantize_ref(data, scales, zero_points): zero_points = torch.tensor([0.1, 0.2, 1.], dtype=torch.float) self._test_quantize_per_channel(r, scales, zero_points, 0, True) + def test_quantize_per_channel_sub_byte(self): + """ Tests the per channel quantization scheme for 4-bit qtensors. + The scale and zero point for this have to be in floating point. """ + r = torch.rand(3, 2, dtype=torch.float) * 4 + scales = torch.tensor([0.2, 0.3, 0.1], dtype=torch.float) + zero_points = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float) + qr = torch.quantize_per_channel(r, scales, zero_points, 0, torch.quint4x2) + dequant_tensor = qr.dequantize() + + def _get_qranges(bit_width): + if bit_width == 4: + return 0, 15 + + def _quantize_per_channel_sub_byte_ref(data, scales, zero_points, axis, bit_width): + dims = data.size() + data = data.view(-1, dims[axis], np.prod(dims[axis + 1:])) + qtensor_size = math.ceil(data.numel() / 2) + res = torch.empty(qtensor_size, dtype=torch.uint8) + elem_per_byte = 8 / bit_width + quant_min, quant_max = _get_qranges(bit_width) + for i in range(data.size()[0]): + for j in range(data.size()[1]): + for k in range(data.size()[2]): + inv_scale = 1.0 / scales[j] + index = i * data.size()[1] * data.size()[2] + j * data.size()[2] + k + qvalue = np.clip( + np.round(data[i][j][k] * inv_scale + zero_points[j]), quant_min, quant_max).to(dtype=torch.int) + res_idx = int(index / elem_per_byte) + if (index % elem_per_byte == 0): + res[res_idx] = qvalue + else: + res[res_idx] |= (qvalue << ((index % elem_per_byte) * bit_width)) + return res + + ref_res = _quantize_per_channel_sub_byte_ref(r, scales, zero_points, 0, 4) + self.assertTrue(np.allclose(qr.int_repr(), ref_res)) + self.assertTrue(np.allclose(r.numpy(), dequant_tensor.numpy(), atol=1 / np.min(scales.numpy()))) + + # Check 4D tensor with non-zero axis. + r = torch.rand(3, 2, 4, 5, dtype=torch.float) * 4 + scales = torch.tensor([0.2, 0.03], dtype=torch.float) + zero_points = torch.tensor([0.1, 0.2], dtype=torch.float) + qr = torch.quantize_per_channel(r, scales, zero_points, axis=1, dtype=torch.quint4x2) + ref_res = _quantize_per_channel_sub_byte_ref(r, scales, zero_points, 1, 4) + self.assertTrue(np.allclose(qr.int_repr(), ref_res)) + def test_qtensor_permute(self): scale = 0.02 zero_point = 1 @@ -422,7 +562,9 @@ def test_qtensor_per_channel_load_save(self): scales = torch.rand(10, dtype=torch.double) * 0.02 + 0.01 zero_points = torch.round(torch.rand(10) * 20 + 1).to(torch.long) # quint32, cuda is not supported yet - for dtype in [torch.quint8, torch.qint8]: + for dtype in [torch.quint8, torch.qint8, torch.quint4x2]: + if dtype == torch.quint4x2: + zero_points = torch.ones(10, dtype=torch.float) qr = torch.quantize_per_channel(r, scales, zero_points, 1, dtype) with tempfile.NamedTemporaryFile() as f: # Serializing and Deserializing Tensor @@ -745,3 +887,11 @@ def test_fp16_saturate_op(self): ref[0] = torch.ones(5) * -65504 y = torch._saturate_weight_to_fp16(x) self.assertEqual(y, ref) + + def test_choose_qparams_optimized(self): + for bit_width in [4, 2]: + x = torch.randn(64, dtype=torch.float) + y = torch.choose_qparams_optimized(x, numel=64, n_bins=200, ratio=0.16, bit_width=bit_width) + ref = param_search_greedy(x.numpy(), bit_rate=bit_width) + self.assertEqual(y[0].numpy(), ref[0]) + self.assertEqual(y[1].numpy(), ref[1]) diff --git a/test/quantization/test_workflow_module.py b/test/quantization/test_workflow_module.py index 5068a6fe7fd44..8a70ae149c290 100644 --- a/test/quantization/test_workflow_module.py +++ b/test/quantization/test_workflow_module.py @@ -5,17 +5,19 @@ PerChannelMinMaxObserver, MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver, - MinMaxDynamicQuantObserver, HistogramObserver, RecordingObserver, PlaceholderObserver, NoopObserver, FakeQuantize, + FixedQParamsFakeQuantize, default_debug_qconfig, default_observer, default_per_channel_weight_observer, + default_affine_fixed_qparams_fake_quant, get_observer_dict, prepare, + QConfig, ) from torch.quantization._learnable_fake_quantize import ( @@ -44,6 +46,7 @@ QuantizationTestCase, AnnotatedSingleLayerLinearModel, test_only_eval_fn, + SingleLayerLinearModel, ) from torch.testing._internal.common_quantized import ( @@ -265,25 +268,6 @@ def test_per_tensor_observers(self, qdtype, qscheme, reduce_range): self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams()) - @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=2, max_dims=4, - min_side=1, max_side=10), - qparams=hu.qparams()), - reduce_range=st.booleans()) - def test_per_tensor_dynamic_quant_observers(self, X, reduce_range): - - X, (scale, zero_point, torch_type) = X - x = torch.from_numpy(X) - - obs = MinMaxDynamicQuantObserver(dtype=torch.quint8, reduce_range=reduce_range) - - result = obs(x) - qparams = obs.calculate_qparams() - ref = torch._choose_qparams_per_tensor(x, reduce_range) - - self.assertEqual(ref[0], qparams[0]) - self.assertEqual(ref[1], qparams[1]) - - @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)), qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric, torch.per_channel_affine_float_qparams)), ch_axis=st.sampled_from((0, 1, 2, 3)), reduce_range=st.booleans()) @@ -394,7 +378,7 @@ def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range): def test_observer_scriptable(self): - obs_list = [MinMaxObserver(), MovingAverageMinMaxObserver(), MinMaxDynamicQuantObserver()] + obs_list = [MinMaxObserver(), MovingAverageMinMaxObserver()] for obs in obs_list: scripted = torch.jit.script(obs) @@ -423,7 +407,7 @@ def test_state_dict_respects_device_affinity(self): [device_cpu, device_cuda], [device_cpu, device_cuda], [MinMaxObserver, MovingAverageMinMaxObserver, - MinMaxDynamicQuantObserver, PerChannelMinMaxObserver, + PerChannelMinMaxObserver, MovingAveragePerChannelMinMaxObserver, # TODO: enable this (separate PR) # HistogramObserver, @@ -473,6 +457,69 @@ def test_histogram_observer_save_load_state_dict(self): self.assertEqual(obs2.max_val.shape, torch.Size([])) + def test_save_load_state_dict_script(self): + """ + Tests that we can save and load state_dict for observers that are scripted + in a quantized model. + """ + obs_list = [MinMaxObserver, MovingAverageMinMaxObserver, + PerChannelMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, HistogramObserver] + + for obs in obs_list: + model = SingleLayerLinearModel().eval() + qconfig = QConfig(activation=default_observer, weight=obs) + qconfig_dict = {'' : qconfig} + scripted = torch.jit.script(model) + scripted = torch.quantization.prepare_jit(scripted, qconfig_dict) + x = torch.rand(5, 5) + scripted(x) + obs_dict = torch.quantization.get_observer_state_dict(scripted) + + # Load stats + scripted_2 = torch.jit.script(model) + scripted_2 = torch.quantization.prepare_jit(scripted_2, qconfig_dict) + torch.quantization.load_observer_state_dict(scripted_2, obs_dict) + # Verify that state_dict matches exactly with original one. + self.assertEqual(scripted.state_dict(), scripted_2.state_dict()) + + + @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + def test_observer_qparams_respects_device_affinity(self): + """ + Ensure that the scale and zero_point returned by the observer + are on the same device as the input tensor. + """ + observerList = [MinMaxObserver(), + MovingAverageMinMaxObserver(), + PerChannelMinMaxObserver(), + MovingAveragePerChannelMinMaxObserver()] + for obs in observerList: + device = torch.device('cuda:1') + x = torch.randn(1, 2, device=device) + obs.to(device) + result = obs(x) + scale, zero_point = obs.calculate_qparams() + + self.assertEqual(x.device, scale.device) + self.assertEqual(x.device, zero_point.device) + + def test_zero_numel(self): + obs_list = [MinMaxObserver, MovingAverageMinMaxObserver, + PerChannelMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, HistogramObserver, + FakeQuantize, FixedQParamsFakeQuantize] + for obs_cls in obs_list: + if obs_cls is FixedQParamsFakeQuantize: + obs = obs_cls(0.1, 0) + else: + obs = obs_cls() + x = torch.Tensor() + # verify no crash + x = obs(x) + + # HistogramObserver that works like it does on master class _ReferenceHistogramObserver(HistogramObserver): def __init__(self, *args, **kwargs): @@ -727,7 +774,7 @@ def test_histogram_observer_same_inputs(self): self.assertEqual(myobs.max_val, 8.0) self.assertEqual(myobs.histogram, [2., 3., 3.]) - @given(N=st.sampled_from([10, 1000, 10**6]), + @given(N=st.sampled_from([10, 1000]), bins=st.sampled_from([256, 512, 1024, 2048]), dtype=st.sampled_from([torch.qint8, torch.quint8]), qscheme=st.sampled_from([torch.per_tensor_affine, torch.per_tensor_symmetric]), @@ -748,7 +795,7 @@ def test_histogram_observer_against_reference(self, N, bins, dtype, qscheme, red self.assertEqual(ref_qparams, my_qparams) -class TestFakeQuantizePerTensor(TestCase): +class TestFakeQuantize(TestCase): @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']), X=hu.tensor(shapes=hu.array_shapes(1, 5,), qparams=hu.qparams(dtypes=torch.quint8))) @@ -987,7 +1034,7 @@ def test_numerical_consistency_per_tensor(self, device, X): X=hu.tensor(shapes=hu.array_shapes(1, 5,), qparams=hu.qparams(dtypes=[torch.quint8])), ) - def test_fq_module(self, device, X): + def test_fq_module_per_tensor(self, device, X): np.random.seed(NP_RANDOM_SEED) X, (scale, zero_point, torch_type) = X quant_min = torch.iinfo(torch_type).min @@ -1008,7 +1055,22 @@ def test_fq_module(self, device, X): dX = _fake_quantize_per_tensor_affine_grad_reference(dout, X, fq_module.scale, fq_module.zero_point, quant_min, quant_max) np.testing.assert_allclose(dX.cpu().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance) - def test_fq_serializable(self): + @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']), + X=hu.tensor(shapes=hu.array_shapes(1, 5,), + qparams=hu.qparams(dtypes=torch.quint8))) + def test_fixed_qparams_fq_module(self, device, X): + X, (scale, zero_point, torch_type) = X + X = to_tensor(X, device) + fq_module = default_affine_fixed_qparams_fake_quant() + fixed_scale = fq_module.scale.clone() + fixed_zero_point = fq_module.zero_point.clone() + # run fq module and make sure the quantization parameters does not change + torch.quantization.enable_observer(fq_module) + fq_module(X) + self.assertEqual(fixed_scale, fq_module.scale) + self.assertEqual(fixed_zero_point, fq_module.zero_point) + + def test_fq_serializable_per_tensor(self): observer = default_observer quant_min = 0 quant_max = 255 @@ -1113,8 +1175,6 @@ def fake_quant_scriptable(self): loaded_module.calculate_qparams()) -class TestFakeQuantizePerChannel(TestCase): - @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']), X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,), qparams=hu.qparams(dtypes=torch.quint8))) @@ -1359,7 +1419,7 @@ def test_numerical_consistency_per_channel(self, device, X): @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']), X=hu.per_channel_tensor(shapes=hu.array_shapes(2, 5,), qparams=hu.qparams(dtypes=torch.qint8))) - def test_fq_module(self, device, X): + def test_fq_module_per_channel(self, device, X): np.random.seed(NP_RANDOM_SEED) X, (scale, zero_point, axis, torch_type) = X quant_min = torch.iinfo(torch_type).min @@ -1382,7 +1442,7 @@ def test_fq_module(self, device, X): fq_module.zero_point, axis, quant_min, quant_max) np.testing.assert_allclose(dX.cpu().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance) - def test_fq_serializable(self): + def test_fq_serializable_per_channel(self): observer = default_per_channel_weight_observer quant_min = -128 quant_max = 127 @@ -1417,7 +1477,6 @@ def test_observers_preserve_buffers(self): observer_types = [ torch.quantization.MinMaxObserver.with_args(dtype=torch.qint8), torch.quantization.MovingAverageMinMaxObserver.with_args(dtype=torch.qint8), - torch.quantization.MinMaxDynamicQuantObserver.with_args(dtype=torch.qint8), torch.quantization.PerChannelMinMaxObserver.with_args(dtype=torch.qint8), torch.quantization.MovingAveragePerChannelMinMaxObserver.with_args(dtype=torch.qint8), torch.quantization.HistogramObserver.with_args(dtype=torch.qint8), diff --git a/test/run_test.py b/test/run_test.py index b24a20c60f460..54eb00104ad47 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import argparse +import copy from datetime import datetime import modulefinder import os @@ -13,12 +14,15 @@ import torch import torch._six from torch.utils import cpp_extension -from torch.testing._internal.common_utils import TEST_WITH_ROCM, shell +from torch.testing._internal.common_utils import TEST_WITH_ROCM, shell, set_cwd, FILE_SCHEMA import torch.distributed as dist from typing import Dict, Optional TESTS = [ + 'test_type_hints', 'test_autograd', + 'benchmark_utils/test_benchmark_utils', + 'test_binary_ufuncs', 'test_bundled_inputs', 'test_complex', 'test_cpp_api_parity', @@ -26,17 +30,19 @@ 'test_cpp_extensions_aot_ninja', 'test_cpp_extensions_jit', 'distributed/test_c10d', + 'distributed/test_jit_c10d', 'distributed/test_c10d_spawn', 'test_cuda', 'test_jit_cuda_fuser', - 'test_jit_cuda_fuser_legacy', - 'test_jit_cuda_fuser_profiling', 'test_cuda_primary_ctx', 'test_dataloader', + 'test_dataset', 'distributed/test_data_parallel', 'distributed/test_distributed_fork', 'distributed/test_distributed_spawn', - 'test_distributions', + 'distributions/test_constraints', + 'distributions/test_distributions', + 'test_dispatch', 'test_expecttest', 'test_foreach', 'test_indexing', @@ -48,31 +54,37 @@ 'test_multiprocessing_spawn', 'distributed/test_nccl', 'test_native_functions', - 'test_nn', 'test_numba_integration', + 'test_nn', 'test_ops', 'test_optim', + 'test_pytree', 'test_mobile_optimizer', 'test_xnnpack_integration', 'test_vulkan', - 'test_quantization', 'test_sparse', + 'test_quantization', 'test_spectral_ops', 'test_serialization', + 'test_shape_ops', 'test_show_pickle', + 'test_sort_and_select', 'test_tensor_creation_ops', + 'test_testing', 'test_torch', 'test_type_info', - 'test_type_hints', 'test_unary_ufuncs', 'test_utils', + 'test_view_ops', 'test_vmap', 'test_namedtuple_return_api', + 'test_numpy_interop', 'test_jit_profiling', 'test_jit_legacy', 'test_jit_fuser_legacy', 'test_tensorboard', 'test_namedtensor', + 'test_reductions', 'test_type_promotion', 'test_jit_disabled', 'test_function_schema', @@ -80,6 +92,7 @@ 'test_overrides', 'test_jit_fuser_te', 'test_tensorexpr', + 'test_tensorexpr_pybind', 'test_openmp', 'test_profiler', 'distributed/nn/jit/test_instantiator', @@ -90,8 +103,60 @@ 'test_determination', 'test_futures', 'test_fx', + 'test_fx_experimental', 'test_functional_autograd_benchmark', 'test_package', + 'distributed/pipeline/sync/skip/test_api', + 'distributed/pipeline/sync/skip/test_gpipe', + 'distributed/pipeline/sync/skip/test_inspect_skip_layout', + 'distributed/pipeline/sync/skip/test_leak', + 'distributed/pipeline/sync/skip/test_portal', + 'distributed/pipeline/sync/skip/test_stash_pop', + 'distributed/pipeline/sync/skip/test_tracker', + 'distributed/pipeline/sync/skip/test_verify_skippables', + 'distributed/pipeline/sync/test_balance', + 'distributed/pipeline/sync/test_bugs', + 'distributed/pipeline/sync/test_checkpoint', + 'distributed/pipeline/sync/test_copy', + 'distributed/pipeline/sync/test_deferred_batch_norm', + 'distributed/pipeline/sync/test_dependency', + 'distributed/pipeline/sync/test_inplace', + 'distributed/pipeline/sync/test_microbatch', + 'distributed/pipeline/sync/test_phony', + 'distributed/pipeline/sync/test_pipe', + 'distributed/pipeline/sync/test_pipeline', + 'distributed/pipeline/sync/test_stream', + 'distributed/pipeline/sync/test_transparency', + 'distributed/pipeline/sync/test_worker', +] + +# Tests need to be run with pytest. +USE_PYTEST_LIST = [ + 'distributed/pipeline/sync/skip/test_api', + 'distributed/pipeline/sync/skip/test_gpipe', + 'distributed/pipeline/sync/skip/test_inspect_skip_layout', + 'distributed/pipeline/sync/skip/test_leak', + 'distributed/pipeline/sync/skip/test_portal', + 'distributed/pipeline/sync/skip/test_stash_pop', + 'distributed/pipeline/sync/skip/test_tracker', + 'distributed/pipeline/sync/skip/test_verify_skippables', + 'distributed/pipeline/sync/test_balance', + 'distributed/pipeline/sync/test_bugs', + 'distributed/pipeline/sync/test_checkpoint', + 'distributed/pipeline/sync/test_copy', + 'distributed/pipeline/sync/test_deferred_batch_norm', + 'distributed/pipeline/sync/test_dependency', + 'distributed/pipeline/sync/test_inplace', + 'distributed/pipeline/sync/test_microbatch', + 'distributed/pipeline/sync/test_phony', + 'distributed/pipeline/sync/test_pipe', + 'distributed/pipeline/sync/test_pipeline', + 'distributed/pipeline/sync/test_stream', + 'distributed/pipeline/sync/test_transparency', + 'distributed/pipeline/sync/test_worker', + 'distributions/test_constraints', + 'distributions/test_transforms', + 'distributions/test_utils', ] WINDOWS_BLOCKLIST = [ @@ -100,7 +165,28 @@ 'distributed/rpc/test_process_group_agent', 'distributed/rpc/test_tensorpipe_agent', 'distributed/test_distributed_fork', - 'distributed/test_distributed_spawn', + 'distributed/pipeline/sync/skip/test_api', + 'distributed/pipeline/sync/skip/test_gpipe', + 'distributed/pipeline/sync/skip/test_inspect_skip_layout', + 'distributed/pipeline/sync/skip/test_leak', + 'distributed/pipeline/sync/skip/test_portal', + 'distributed/pipeline/sync/skip/test_stash_pop', + 'distributed/pipeline/sync/skip/test_tracker', + 'distributed/pipeline/sync/skip/test_verify_skippables', + 'distributed/pipeline/sync/test_balance', + 'distributed/pipeline/sync/test_bugs', + 'distributed/pipeline/sync/test_checkpoint', + 'distributed/pipeline/sync/test_copy', + 'distributed/pipeline/sync/test_deferred_batch_norm', + 'distributed/pipeline/sync/test_dependency', + 'distributed/pipeline/sync/test_inplace', + 'distributed/pipeline/sync/test_microbatch', + 'distributed/pipeline/sync/test_phony', + 'distributed/pipeline/sync/test_pipe', + 'distributed/pipeline/sync/test_pipeline', + 'distributed/pipeline/sync/test_stream', + 'distributed/pipeline/sync/test_transparency', + 'distributed/pipeline/sync/test_worker', ] ROCM_BLOCKLIST = [ @@ -111,7 +197,6 @@ 'test_determination', 'test_multiprocessing', 'test_jit_legacy', - 'test_tensorexpr', 'test_type_hints', 'test_openmp', ] @@ -130,18 +215,28 @@ 'test_cuda_primary_ctx', ] + [test for test in TESTS if test.startswith('distributed/')] + # These tests are slow enough that it's worth calculating whether the patch # touched any related files first. SLOW_TESTS = [ + 'distributions/test_distributions', 'test_nn', 'test_autograd', 'test_cpp_extensions_jit', 'test_jit_legacy', 'test_dataloader', 'test_overrides', + 'test_linalg', 'test_jit', 'test_jit_profiling', 'test_torch', + 'test_binary_ufuncs' + 'test_numpy_interop', + 'test_reductions', + 'test_shape_ops', + 'test_sort_and_select', + 'test_testing', + 'test_view_ops', 'distributed/nn/jit/test_instantiator', 'distributed/test_distributed_fork', 'distributed/rpc/test_process_group_agent', @@ -153,16 +248,38 @@ 'test_cpp_extensions_aot_ninja', 'test_cpp_extensions_aot_no_ninja', 'test_serialization', - 'test_distributions', 'test_optim', 'test_utils', 'test_multiprocessing', 'test_tensorboard', 'distributed/test_c10d', + 'distributed/test_jit_c10d', 'distributed/test_c10d_spawn', 'test_quantization', 'test_determination', 'test_futures', + 'distributed/pipeline/sync/skip/test_api', + 'distributed/pipeline/sync/skip/test_gpipe', + 'distributed/pipeline/sync/skip/test_inspect_skip_layout', + 'distributed/pipeline/sync/skip/test_leak', + 'distributed/pipeline/sync/skip/test_portal', + 'distributed/pipeline/sync/skip/test_stash_pop', + 'distributed/pipeline/sync/skip/test_tracker', + 'distributed/pipeline/sync/skip/test_verify_skippables', + 'distributed/pipeline/sync/test_balance', + 'distributed/pipeline/sync/test_bugs', + 'distributed/pipeline/sync/test_checkpoint', + 'distributed/pipeline/sync/test_copy', + 'distributed/pipeline/sync/test_deferred_batch_norm', + 'distributed/pipeline/sync/test_dependency', + 'distributed/pipeline/sync/test_inplace', + 'distributed/pipeline/sync/test_microbatch', + 'distributed/pipeline/sync/test_phony', + 'distributed/pipeline/sync/test_pipe', + 'distributed/pipeline/sync/test_pipeline', + 'distributed/pipeline/sync/test_stream', + 'distributed/pipeline/sync/test_transparency', + 'distributed/pipeline/sync/test_worker', ] _DEP_MODULES_CACHE: Dict[str, set] = {} @@ -183,7 +300,7 @@ 'WORLD_SIZE': '2' if torch.cuda.device_count() == 2 else '3', 'TEST_REPORT_SOURCE_OVERRIDE': 'dist-nccl' } - if not TEST_WITH_ROCM and dist.is_gloo_available(): + if dist.is_gloo_available(): DISTRIBUTED_TESTS_CONFIG['gloo'] = { 'WORLD_SIZE': '2' if torch.cuda.device_count() == 2 else '3', 'TEST_REPORT_SOURCE_OVERRIDE': 'dist-gloo' @@ -202,6 +319,13 @@ PYTORCH_COLLECT_COVERAGE = bool(os.environ.get("PYTORCH_COLLECT_COVERAGE")) +JIT_EXECUTOR_TESTS = [ + 'test_jit_cuda_fuser', + 'test_jit_profiling', + 'test_jit_legacy', + 'test_jit_fuser_legacy', +] + def print_to_stderr(message): print(message, file=sys.stderr) @@ -222,12 +346,17 @@ def get_executable_command(options, allow_pytest): def run_test(test_module, test_directory, options, launcher_cmd=None, extra_unittest_args=None): unittest_args = options.additional_unittest_args.copy() if options.verbose: - unittest_args.append('--verbose') + unittest_args.append(f'-{"v"*options.verbose}') # in case of pytest if test_module in RUN_PARALLEL_BLOCKLIST: unittest_args = [arg for arg in unittest_args if not arg.startswith('--run-parallel')] if extra_unittest_args: assert isinstance(extra_unittest_args, list) unittest_args.extend(extra_unittest_args) + + # If using pytest, replace -f with equivalent -x + if options.pytest: + unittest_args = [arg if arg != '-f' else '-x' for arg in unittest_args] + # Can't call `python -m unittest test_*` here because it doesn't run code # in `if __name__ == '__main__': `. So call `python test_*.py` instead. argv = [test_module + '.py'] + unittest_args @@ -307,15 +436,19 @@ def test_distributed(test_module, test_directory, options): 'MPI not available -- MPI backend tests will be skipped') config = DISTRIBUTED_TESTS_CONFIG for backend, env_vars in config.items(): + if sys.platform == 'win32' and backend != 'gloo': + continue if backend == 'mpi' and not mpi_available: continue for with_init_file in {True, False}: + if sys.platform == 'win32' and not with_init_file: + continue tmp_dir = tempfile.mkdtemp() if options.verbose: init_str = "with {} init_method" with_init = init_str.format("file" if with_init_file else "env") print_to_stderr( - 'Running distributed tests for the {} backend{}'.format( + 'Running distributed tests for the {} backend {}'.format( backend, with_init)) os.environ['TEMP_DIR'] = tmp_dir os.environ['BACKEND'] = backend @@ -323,9 +456,9 @@ def test_distributed(test_module, test_directory, options): os.environ.update(env_vars) if with_init_file: if test_module in ["test_distributed_fork", "test_distributed_spawn"]: - init_method = 'file://{}/'.format(tmp_dir) + init_method = f'{FILE_SCHEMA}{tmp_dir}/' else: - init_method = 'file://{}/shared_init_file'.format(tmp_dir) + init_method = f'{FILE_SCHEMA}{tmp_dir}/shared_init_file' os.environ['INIT_METHOD'] = init_method try: os.mkdir(os.path.join(tmp_dir, 'barrier')) @@ -333,11 +466,14 @@ def test_distributed(test_module, test_directory, options): if backend == 'mpi': # test mpiexec for --noprefix option with open(os.devnull, 'w') as devnull: + allowrunasroot_opt = '--allow-run-as-root' if subprocess.call( + 'mpiexec --allow-run-as-root -n 1 bash -c ""', shell=True, + stdout=devnull, stderr=subprocess.STDOUT) == 0 else '' noprefix_opt = '--noprefix' if subprocess.call( - 'mpiexec -n 1 --noprefix bash -c ""', shell=True, + f'mpiexec {allowrunasroot_opt} -n 1 --noprefix bash -c ""', shell=True, stdout=devnull, stderr=subprocess.STDOUT) == 0 else '' - mpiexec = ['mpiexec', '-n', '3', noprefix_opt] + mpiexec = ['mpiexec', '-n', '3', noprefix_opt, allowrunasroot_opt] return_code = run_test(test_module, test_directory, options, launcher_cmd=mpiexec) @@ -378,7 +514,8 @@ def parse_args(): parser.add_argument( '-v', '--verbose', - action='store_true', + action='count', + default=0, help='print verbose information and test-by-test results') parser.add_argument( '--jit', @@ -448,6 +585,19 @@ def parse_args(): nargs='*', help='additional arguments passed through to unittest, e.g., ' 'python run_test.py -i sparse -- TestSparse.test_factory_size_check') + parser.add_argument( + '--shard', + nargs=2, + type=int, + help='runs a shard of the tests (taking into account other selections), e.g., ' + '--shard 2 3 will break up the selected tests into 3 shards and run the tests ' + 'in the 2nd shard (the first number should not exceed the second)', + ) + parser.add_argument( + '--exclude-jit-executor', + action='store_true', + help='exclude tests that are run for a specific jit config' + ) return parser.parse_args() @@ -468,7 +618,7 @@ def find_test_index(test, selected_tests, find_last_index=False): If :attr:`test`='torch' and :attr:`find_last_index`=False, result should be **2**. If :attr:`test`='torch' and :attr:`find_last_index`=True, result should be **4**. - Arguments: + Args: test (str): Name of test to lookup selected_tests (list): List of tests find_last_index (bool, optional): should we lookup the index of first or last @@ -515,6 +665,17 @@ def get_selected_tests(options): last_index = find_test_index(options.last, selected_tests, find_last_index=True) selected_tests = selected_tests[:last_index + 1] + if options.shard: + assert len(options.shard) == 2, "Unexpected shard format" + assert min(options.shard) > 0, "Shards must be positive numbers" + which_shard, num_shards = options.shard + assert which_shard <= num_shards, "Selected shard must be less or equal that total number of shards" + assert num_shards <= len(selected_tests), f"Number of shards must be less than {len(selected_tests)}" + selected_tests = selected_tests[which_shard - 1 :: num_shards] + + if options.exclude_jit_executor: + options.exclude.extend(JIT_EXECUTOR_TESTS) + selected_tests = exclude_tests(options.exclude, selected_tests) if sys.platform == 'win32' and not options.ignore_win_blocklist: @@ -717,26 +878,33 @@ def main(): failure_messages = [] try: for test in selected_tests: - err_message = run_test_module(test, test_directory, options) + options_clone = copy.deepcopy(options) + if test in USE_PYTEST_LIST: + options_clone.pytest = True + err_message = run_test_module(test, test_directory, options_clone) if err_message is None: continue has_failed = True failure_messages.append(err_message) - if not options.continue_through_error: + if not options_clone.continue_through_error: raise RuntimeError(err_message) print_to_stderr(err_message) finally: if options.coverage: + from coverage import Coverage test_dir = os.path.dirname(os.path.abspath(__file__)) - if not PYTORCH_COLLECT_COVERAGE: - shell(['coverage', 'combine'], cwd=test_dir) - shell(['coverage', 'html'], cwd=test_dir) - else: - shell(['coverage', 'combine', '--append'], cwd=test_dir) + with set_cwd(test_dir): + cov = Coverage() + if PYTORCH_COLLECT_COVERAGE: + cov.load() + cov.combine(strict=False) + cov.save() + if not PYTORCH_COLLECT_COVERAGE: + cov.html_report() if options.continue_through_error and has_failed: for err in failure_messages: - print_to_stderr(message) + print_to_stderr(err) sys.exit(1) if __name__ == '__main__': diff --git a/test/test_autograd.py b/test/test_autograd.py index c03c1a496605d..cfb1989a71dd2 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -10,9 +10,9 @@ import warnings from copy import deepcopy from collections import OrderedDict -from itertools import product +from itertools import product, permutations from operator import mul -from functools import reduce +from functools import reduce, partial import torch import json @@ -29,12 +29,14 @@ record_function, emit_nvtx) import torch.autograd.functional as autogradF from torch.utils.checkpoint import checkpoint -from torch.testing._internal.common_utils import (TEST_MKL, TEST_WITH_ROCM, TestCase, run_tests, skipIfNoLapack, +from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoLapack, suppress_warnings, slowTest, - load_tests, random_symmetric_pd_matrix, random_symmetric_matrix, - IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck) -from torch.autograd import Variable, Function, detect_anomaly + load_tests, random_symmetric_matrix, + IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck, + TemporaryFileName, TEST_WITH_ROCM) +from torch.autograd import Variable, Function, detect_anomaly, kineto_available from torch.autograd.function import InplaceFunction +import torch.autograd.forward_ad as fwAD from torch.testing import randn_like from torch.testing._internal.common_methods_invocations import (method_tests, create_input, unpack_variables, @@ -45,7 +47,7 @@ mask_not_all_zeros, S) from torch.testing._internal.common_device_type import (instantiate_device_type_tests, skipCUDAIfRocm, - onlyCPU, onlyCUDA, dtypes, dtypesIfCUDA, + onlyCPU, onlyCUDA, onlyOnCPUAndCUDA, dtypes, dtypesIfCUDA, deviceCountAtLeast, skipCUDAIfCudnnVersionLessThan, skipCUDAIf) @@ -57,9 +59,6 @@ def getattr_qualified(obj, qname, default=None): e.g. getattr(torch, 'fft.rfft') """ path = qname.split('.') - if len(path) > 1 and path[0] == 'fft': - import torch.fft # noqa: F401 - for name in path: obj = getattr(obj, name, _END_SENTINEL) if obj is _END_SENTINEL: @@ -74,6 +73,10 @@ def getattr_qualified(obj, qname, default=None): PRECISION = 1e-4 +# See #49409, we should remove these if we end up with a global gradcheck setting +gradcheck = partial(gradcheck, check_batched_grad=True) +gradgradcheck = partial(gradgradcheck, check_batched_grad=True) + @contextlib.contextmanager def backward_engine(engine): @@ -100,6 +103,25 @@ def graph_desc(fn): class TestAutograd(TestCase): + def test_tensor_grad_warnings(self): + dummy = torch.empty(1) + + with warnings.catch_warnings(record=True) as w: + # Accessing .grad on leaf + dummy.requires_grad_() + foo = dummy.grad + self.assertEqual(len(w), 0) + + # Accessing .grad on non-leaf + dummy = dummy.clone() + foo = dummy.grad + self.assertEqual(len(w), 1) + + # Accessing .grad on non-leaf that retains gradients + dummy.retain_grad() + foo = dummy.grad + self.assertEqual(len(w), 1) + def _function_test(self, cls): x = torch.randn(5, 5, requires_grad=True) y = torch.randn(5, 5, requires_grad=True) @@ -780,16 +802,55 @@ def test_sparse_mm_backward(self): sparse = torch.sparse_coo_tensor(size, requires_grad=True) dense = torch.randn(size, requires_grad=True) - z = sparse.mm(dense) - with self.assertRaisesRegex(RuntimeError, - "calculating the gradient of a sparse Tensor argument to mm is not supported."): - z.sum().backward() - - z = dense.addmm(sparse, dense) - with self.assertRaisesRegex(RuntimeError, - "calculating the gradient of a sparse Tensor argument to mm is not supported."): - z.sum().backward() + with self.assertRaisesRegex( + RuntimeError, + "The backward pass for this operation requires the 'mat1' tensor to be strided,"): + z = dense.addmm(sparse, dense) + + mm_test_cases = [ + # a requires grad, a is sparse, b requires grad, b is sparse, error message + (False, True, True, False, None), + (False, False, True, True, "The backward pass for this operation requires the 'mat2'"), + (False, True, True, True, "The backward pass for this operation requires the 'mat2'"), + (True, False, True, True, "The backward pass for this operation requires the 'mat2'"), + (True, True, False, False, "The backward pass for this operation requires the 'self'"), + (True, True, True, False, "The backward pass for this operation requires the 'self'"), + (True, True, True, True, "The backward pass for this operation requires the 'mat2'"), + ] + for a_req_grad, a_is_sparse, b_req_grad, b_is_sparse, err_msg in mm_test_cases: + # We should only be testing cases with sparse inputs, and at least one + # input needs to require grad so we can call a backward pass + assert a_is_sparse or b_is_sparse + assert a_req_grad or b_req_grad + + a = torch.randn(size, requires_grad=a_req_grad) + if a_is_sparse: + a = a.to_sparse() + b = torch.randn(size, requires_grad=b_req_grad) + if b_is_sparse: + b = b.to_sparse() + + # If no error expected, check that sparse and dense cases match + if err_msg is None: + r = a.mm(b) + r.sum().backward() + a_grad = None if a.grad is None else a.grad.clone().detach() + b_grad = None if b.grad is None else b.grad.clone().detach() + + # Redo with only dense tensors + a = (a.to_dense() if a.is_sparse else a).clone().detach() + a.requires_grad = a_req_grad + b = (b.to_dense() if b.is_sparse else b).clone().detach() + b.requires_grad = b_req_grad + r = a.mm(b) + r.sum().backward() + + self.assertEqual(a_grad, a.grad) + self.assertEqual(b_grad, b.grad) + else: + with self.assertRaisesRegex(RuntimeError, err_msg): + a.mm(b) def test_multi_backward(self): x = torch.randn(5, 5, requires_grad=True) @@ -827,6 +888,64 @@ def call_backwards(): torch.autograd.backward([z, q], [torch.ones(5, 5), torch.ones(5, 5)]) self.assertRaises(RuntimeError, call_backwards) + def test_backward_with_inputs(self): + x = torch.randn(2, 2, requires_grad=True) + y = torch.randn(2, 2, requires_grad=True) + + def fn(): + return x ** 2 + y * x + y ** 2 + + gradient = torch.ones(2, 2) + x_grad_expected = 2 * x + y + y_grad_expected = x + 2 * y + + @torch.no_grad() + def reset_grad(): + x.grad.zero_() + y.grad.zero_() + + torch.autograd.backward(fn(), gradient, inputs=[x, y]) + self.assertEqual(x.grad, x_grad_expected) + self.assertEqual(y.grad, y_grad_expected) + + reset_grad() + torch.autograd.backward(fn(), gradient, inputs=[x]) + self.assertEqual(x.grad, x_grad_expected) + self.assertEqual(y.grad, torch.zeros(2, 2)) + + reset_grad() + torch.autograd.backward(fn(), gradient, inputs=[y]) + self.assertEqual(y.grad, y_grad_expected) + self.assertEqual(x.grad, torch.zeros(2, 2)) + + reset_grad() + self.assertRaisesRegex(RuntimeError, 'cannot be empty', + lambda: torch.autograd.backward(fn(), gradient, inputs=[])) + + def test_backward_with_nonleaf_inputs(self): + x = torch.randn(2, 2, requires_grad=True) + x_nonleaf = x * 1 + y = torch.randn(2, 2, requires_grad=True) + z = torch.randn(2, 2, requires_grad=True) + + out = x_nonleaf ** 2 + y * x_nonleaf + y ** 2 + + out.backward(torch.ones(2, 2), create_graph=True, inputs=[x, y]) + x_grad_expected = 2 * x + y + y_grad_expected = x + 2 * y + + self.assertEqual(y.grad, y_grad_expected) + self.assertEqual(x.grad, x_grad_expected) + + self.assertRaisesRegex(RuntimeError, 'not a leaf Tensor', + lambda: out.backward(torch.ones(2, 2), create_graph=True, inputs=[x, y, x_nonleaf])) + + # backward doesn't have an allow_unused flag, so the behavior of backward + # when variable is not part of the graph is as if allow_used were true + # x.grad will simply be None. + out.backward(torch.ones(2, 2), create_graph=True, inputs=[z]) + self.assertIsNone(z.grad) + def test_dependent_backward(self): x = torch.randn(10, requires_grad=True) y = x ** 2 @@ -1048,6 +1167,187 @@ def no_grad_context_manager_recursive(depth): enable_grad_context_manager_recursive(10) self.assertFalse(torch.is_grad_enabled()) + def test_set_grad_coroutines(self): + @torch.no_grad() + def coro_no_grad(n=10): + self.assertFalse(torch.is_grad_enabled()) + for i in range(n): + self.assertFalse(torch.is_grad_enabled()) + r = yield i + self.assertFalse(torch.is_grad_enabled()) + self.assertEqual(i, r) + self.assertFalse(torch.is_grad_enabled()) + + @torch.enable_grad() + def coro_enable_grad(n=10): + self.assertTrue(torch.is_grad_enabled()) + for i in range(n): + self.assertTrue(torch.is_grad_enabled()) + r = yield i + self.assertTrue(torch.is_grad_enabled()) + self.assertEqual(i, r) + self.assertTrue(torch.is_grad_enabled()) + + with torch.enable_grad(): + self.assertTrue(torch.is_grad_enabled()) + coro, r = coro_no_grad(), None + try: + while True: + self.assertTrue(torch.is_grad_enabled()) + r = coro.send(r) + self.assertTrue(torch.is_grad_enabled()) + + except StopIteration: + pass + + with torch.no_grad(): + self.assertFalse(torch.is_grad_enabled()) + coro, r = coro_enable_grad(), None + try: + while True: + self.assertFalse(torch.is_grad_enabled()) + r = coro.send(r) + self.assertFalse(torch.is_grad_enabled()) + + except StopIteration: + pass + + def test_set_grad_coroutines_benign_exceptions(self): + class RecoverableException(Exception): + pass + + @torch.no_grad() + def coro_no_grad(n=10): + has_raised = False + for i in range(n): + try: + self.assertFalse(torch.is_grad_enabled()) + yield (-i if has_raised else i) + + except RecoverableException: + self.assertFalse(torch.is_grad_enabled()) + has_raised = True + + @torch.enable_grad() + def coro_enable_grad(n=10): + has_raised = False + for i in range(n): + try: + self.assertTrue(torch.is_grad_enabled()) + yield (-i if has_raised else i) + + except RecoverableException: + self.assertTrue(torch.is_grad_enabled()) + has_raised = True + + with torch.enable_grad(): + coro = coro_no_grad() + assert 0 == next(coro) + try: + while True: + r = coro.throw(RecoverableException) + self.assertLess(r, 0) + + except StopIteration: + pass + + with torch.no_grad(): + coro = coro_enable_grad() + assert 0 == next(coro) + try: + while True: + r = coro.throw(RecoverableException) + self.assertLess(r, 0) + + except StopIteration: + pass + + def test_set_grad_coroutines_critical_exceptions(self): + class UnrecoverableException(Exception): + pass + + class SecondaryException(Exception): + pass + + @torch.no_grad() + def coro_no_grad(n=10): + has_raised = False + for i in range(n): + try: + self.assertFalse(torch.is_grad_enabled()) + yield (-i if has_raised else i) + + except UnrecoverableException: + self.assertFalse(torch.is_grad_enabled()) + raise SecondaryException + + @torch.enable_grad() + def coro_enable_grad(n=10): + has_raised = False + for i in range(n): + try: + self.assertTrue(torch.is_grad_enabled()) + yield (-i if has_raised else i) + + except UnrecoverableException: + self.assertTrue(torch.is_grad_enabled()) + raise SecondaryException + + with torch.enable_grad(): + coro = coro_no_grad() + assert 0 == next(coro) + with self.assertRaises(SecondaryException): + coro.throw(UnrecoverableException) + + with torch.no_grad(): + coro = coro_enable_grad() + assert 0 == next(coro) + with self.assertRaises(SecondaryException): + coro.throw(UnrecoverableException) + + def test_set_grad_coroutines_exit(self): + @torch.no_grad() + def coro_no_grad(state): + for i in range(10): + try: + self.assertFalse(torch.is_grad_enabled()) + yield i + + except GeneratorExit: + self.assertFalse(torch.is_grad_enabled()) + state.add('GeneratorExit') + raise + + @torch.enable_grad() + def coro_enable_grad(state): + for i in range(10): + try: + self.assertTrue(torch.is_grad_enabled()) + yield i + + except GeneratorExit: + self.assertTrue(torch.is_grad_enabled()) + state.add('GeneratorExit') + raise + + state = set() + with torch.enable_grad(): + coro = coro_no_grad(state) + for i in range(5): + next(coro) + + coro.close() + self.assertTrue('GeneratorExit' in state) + + state = set() + with torch.no_grad(): + coro = coro_enable_grad(state) + for i in range(5): + next(coro) + + coro.close() + self.assertTrue('GeneratorExit' in state) + def test_no_grad_python_function(self): """Python Functions should respect grad mode.""" x = torch.ones(5, 5, requires_grad=True) @@ -1650,60 +1950,6 @@ def test_slice_expanded_v(self): expected[3:5] = v_expanded self.assertEqual(result, expected) - def test_stack(self): - x = torch.randn(10, 10, requires_grad=True) - y = torch.randn(10, 10, requires_grad=True) - z = torch.randn(10, 10, requires_grad=True) - stacked = torch.stack([x, y, z], 0) - grad = torch.randn(3, 10, 10) - stacked.backward(grad) - self.assertEqual(x.grad, grad[0]) - self.assertEqual(y.grad, grad[1]) - self.assertEqual(z.grad, grad[2]) - - def test_hstack(self): - x = torch.randn(10, 10, requires_grad=True) - y = torch.randn(10, 10, requires_grad=True) - z = torch.randn(10, 10, requires_grad=True) - stacked = torch.hstack([x, y, z]) - grad = torch.randn(10, 30) - stacked.backward(grad) - self.assertEqual(x.grad, grad[:, 0:10]) - self.assertEqual(y.grad, grad[:, 10:20]) - self.assertEqual(z.grad, grad[:, 20:30]) - - x = torch.randn(10, requires_grad=True) - y = torch.randn(10, requires_grad=True) - z = torch.randn(10, requires_grad=True) - stacked = torch.hstack([x, y, z]) - grad = torch.randn(30) - stacked.backward(grad) - self.assertEqual(x.grad, grad[0:10]) - self.assertEqual(y.grad, grad[10:20]) - self.assertEqual(z.grad, grad[20:30]) - - def test_vstack(self): - x = torch.randn(10, 10, requires_grad=True) - y = torch.randn(10, 10, requires_grad=True) - z = torch.randn(10, 10, requires_grad=True) - stacked = torch.vstack([x, y, z]) - grad = torch.randn(30, 10) - stacked.backward(grad) - self.assertEqual(x.grad, grad[0:10]) - self.assertEqual(y.grad, grad[10:20]) - self.assertEqual(z.grad, grad[20:30]) - - def test_dstack(self): - x = torch.randn(10, 10, requires_grad=True) - y = torch.randn(10, 10, requires_grad=True) - z = torch.randn(10, 10, requires_grad=True) - stacked = torch.dstack([x, y, z]) - grad = torch.randn(10, 10, 3) - stacked.backward(grad) - self.assertEqual(x.grad, grad[:, :, 0]) - self.assertEqual(y.grad, grad[:, :, 1]) - self.assertEqual(z.grad, grad[:, :, 2]) - def test_unbind(self): stacked = torch.randn(3, 10, 10, requires_grad=True) x, y, z = stacked.unbind() @@ -2498,47 +2744,6 @@ def test_var_mean_differentiable(self): torch.autograd.backward(r2, grad) self.assertTrue(torch.allclose(input1.grad, input2.grad, rtol=0.01, atol=0.0)) - @skipIfNoLapack - def test_cholesky(self): - def func(root, upper): - x = torch.matmul(root, root.transpose(-1, -2)) + 1e-05 - return torch.cholesky(x, upper) - - def run_test(upper, dims): - root = torch.rand(*dims, requires_grad=True) - - gradcheck(func, [root, upper]) - gradgradcheck(func, [root, upper]) - - root = random_symmetric_pd_matrix(dims[-1], *dims[:-2]).requires_grad_() - chol = root.cholesky().sum().backward() - self.assertEqual(root.grad, root.grad.transpose(-1, -2)) # Check the gradient is symmetric - - for upper, dims in product([True, False], [(3, 3), (4, 3, 2, 2)]): - run_test(upper, dims) - run_test(upper, dims) - - @skipIfNoLapack - def test_cholesky_solve(self): - def _test_with_size(A_dims, B_dims, upper): - root = torch.rand(*A_dims).requires_grad_() - b = torch.rand(*B_dims).requires_grad_() - - def func(root, b, upper): - if upper: - A = root.triu() - else: - A = root.tril() - return torch.cholesky_solve(b, A, upper) - - gradcheck(func, [root, b, upper]) - gradgradcheck(func, [root, b, upper]) - - for (a_size, b_size), upper in product([((3, 3), (3, 4)), ((3, 3), (3, 2)), - ((2, 3, 3), (2, 3, 4)), ((2, 3, 3), (2, 3, 2))], - [True, False]): - _test_with_size(a_size, b_size, upper) - @skipIfNoLapack def test_eig(self): def func(B): @@ -2592,6 +2797,67 @@ def run_test(upper, dims): for upper, dims in product([True, False], [(3, 3), (5, 3, 3), (4, 3, 2, 2)]): run_test(upper, dims) + @slowTest + @skipIfNoLapack + def test_lobpcg(self): + + def func(k, A, largest=True, B=None): + X_shape = list(A.shape) + X_shape[-1] = k + X = torch.eye(A.size(-2), k, dtype=A.dtype, device=A.device) + if A.dim() > 2: + X = X.expand(X_shape) + + D, U = torch.lobpcg(A=A, k=k, B=B, X=X) + + # LOBPCG uses a random initial eigenspace approximation + # if parameter `X` is not provided. + # This may cause a non-deterministic behavior + # when it comes to the sign of an eigenvector + # (note if v is an eigenvector, so is -v), + # hence we eliminate this non-determinism + # by making sure that each column of U + # gets multiplied by the sign of its max (in absolute value) element. + # Also, gradcheck changes the content of the input by +/- eps (default to 1e-06) + # to compute the numerical gradient which can also cause the signs to flip. + _, idx = U.abs().max(-2, keepdim=True) + sign = U.gather(-2, idx).sign() + U = U * sign + return D, U + + def run_symeig_test(k, sizes, largest=True): + A = torch.rand(*sizes).double() + A = A.matmul(A.transpose(-1, -2)) / 10 + A.requires_grad_(True) + + gradcheck(lambda A: func(k, A, largest), A, check_batched_grad=False) + + # Custom gradient vectors for better stability due to some + # non-determinism in the lobpcg's forward. + # Note it is not required if symeig is in forward instead (tested). + D_grad = torch.rand(*A.shape[:-2], k) / 100 + U_grad = torch.rand(*A.shape[:-1], k) / 100 + gradgradcheck(lambda A: func(k, A, largest), A, [D_grad, U_grad], atol=1e-4, check_batched_grad=False) + + # check whether A.grad is symmetric + A = A.detach().requires_grad_(True) + D, U = func(k, A, largest) + (D.sum() + U.sum()).backward() + self.assertEqual(A.grad, A.grad.transpose(-1, -2)) + + # the tests below take about 1-2 minutes to finish, + # but we want to be extra sure that the backward is correct. + for largest in [True, False]: + run_symeig_test(1, (6, 6), largest=largest) + run_symeig_test(1, (2, 6, 6), largest=largest) + run_symeig_test(1, (2, 2, 6, 6), largest=largest) + run_symeig_test(2, (6, 6), largest=largest) + run_symeig_test(2, (2, 6, 6), largest=largest) + run_symeig_test(2, (2, 2, 6, 6), largest=largest) + run_symeig_test(3, (9, 9), largest=largest) + run_symeig_test(3, (2, 9, 9), largest=largest) + run_symeig_test(3, (2, 2, 9, 9), largest=largest) + @skipIfNoLapack def test_cholesky_inverse(self): def _test_with_size(upper, dims): @@ -2615,112 +2881,6 @@ def func(A, upper): for upper, dims in product([True, False], [(3, 3), (5, 5)]): _test_with_size(upper, dims) - @skipIfNoLapack - def test_triangular_solve(self): - def _test_with_size(A_dims, B_dims): - A = torch.rand(*A_dims).requires_grad_() - b = torch.rand(*B_dims).requires_grad_() - - for upper, transpose, unitriangular in product((True, False), repeat=3): - def func(A, b): - return torch.triangular_solve(b, A, upper, transpose, unitriangular) - - gradcheck(func, [A, b]) - gradgradcheck(func, [A, b]) - - _test_with_size((3, 3), (3, 4)) - _test_with_size((3, 3), (3, 2)) - _test_with_size((2, 3, 3), (2, 3, 4)) - _test_with_size((2, 3, 3), (2, 3, 2)) - - @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support") - def test_fft_ifft_rfft_irfft(self): - def _test_complex(sizes, signal_ndim): - x = torch.randn(sizes, requires_grad=True, dtype=torch.double) - - for normalized in (True, False): - def fft(x): - return x.fft(signal_ndim, normalized=normalized) - - gradcheck(fft, [x]) - gradgradcheck(fft, [x], gen_non_contig_grad_outputs=True) - - def ifft(fx): - return fx.ifft(signal_ndim, normalized=normalized) - - # Use output of fft(x) for inverse fft, due to symmetry requirements - fx = fft(x).detach() - fx.requires_grad = True - gradcheck(ifft, [fx]) - gradgradcheck(ifft, [fx], gen_non_contig_grad_outputs=True) - - def _test_real(sizes, signal_ndim): - x = torch.randn(sizes, requires_grad=True, dtype=torch.double) - if x.dim() == signal_ndim: - start_dim = 0 - else: - start_dim = 1 - signal_sizes = x.size()[start_dim:start_dim + signal_ndim] - - for normalized, onesided in product((True, False), repeat=2): - def rfft(x): - return x.rfft(signal_ndim, normalized=normalized, onesided=onesided) - - gradcheck(rfft, [x]) - gradgradcheck(rfft, [x], gen_non_contig_grad_outputs=True) - - # Generally speaking, irfft itself won't and can't pass the - # current gradcheck as it assumes the input follows conjugate - # symmetry, an requirement that is never true with our point - # numerical Jacobian estimate. Without input symmtry, irfft's - # behavior is undefined. - # - # Even onesided results can't remove all redundancy. For - # example, consider the .select(last_signal_dim, 0) slice. - # It is entirely represented in the onesided results (except - # for 1D), and will be reflected onto itself! - # - # So only 1D onesided irfft should pass grad check as it is - # guaranteed that the input has no symmetrical values. - # - # In other cases, we test a function that first uses rfft to - # generate a tensor that follows the conjugate symmetry irfft - # expects, and then feeds it into irfft. Since rfft is already - # tested above, we thereby verify the correctness of irfft. - if signal_ndim == 1 and onesided: - def irfft(fx): - return fx.irfft(signal_ndim, normalized=normalized, - onesided=onesided, signal_sizes=signal_sizes) - - # Use output of rfft(x) for inverse rfft, due to symmetry requirements - fx = rfft(x).detach() - fx.requires_grad = True - gradcheck(irfft, [fx]) - gradgradcheck(irfft, [fx], gen_non_contig_grad_outputs=True) - else: - # Test this function: f(x) = ifft(rfft(x) + rfft(z)), where - # z is some fixed tensor of same size as x. rfft(z) term is - # needed because otherwise f becomes identity. - z = torch.randn(sizes, dtype=torch.double) - fz = z.rfft(signal_ndim, normalized=normalized, onesided=onesided) - - def rfft_irfft(x): - fx = x.rfft(signal_ndim, normalized=normalized, onesided=onesided) - y = fx + fz - return y.irfft(signal_ndim, normalized=normalized, - onesided=onesided, signal_sizes=signal_sizes) - - gradcheck(rfft_irfft, [x]) - gradgradcheck(rfft_irfft, [x], gen_non_contig_grad_outputs=True) - - _test_real((2, 10), 1) - _test_real((2, 3, 4), 2) - _test_real((2, 3, 4, 3), 3) - - _test_complex((2, 2, 10, 2), 1) - _test_complex((1, 2, 3, 4, 2), 2) - _test_complex((2, 1, 3, 4, 3, 2), 3) - def test_gradcheck_fail_when_no_differentiable_outputs_and_num_grad_not_zero(self): def autograd_fn(input): output = torch.detach(input) @@ -2761,6 +2921,20 @@ def run_test(input_size, norm_deg): run_test((10,), 3) run_test((10,), 1) run_test((10,), 1.5) + run_test((10,), inf) + + def test_norm_inf_subgradient(self): + def run_test(input, expected, dim=None): + x = torch.tensor(input, requires_grad=True) + out = x.norm(inf, dim=dim, keepdim=True) + out.backward(torch.ones(out.size())) + self.assertEqual(x.grad, expected) + + run_test([0., 0., 0.], [0., 0., 0.]) + run_test([1., 0., 1.], [0.5, 0., 0.5]) + run_test([[1., 0., 1.], [0., 1., 1.]], [[0.25, 0., 0.25], [0., 0.25, 0.25]]) + run_test([[1., 0., 1.], [0., 1., 0.]], [[0.5, 0., 0.5], [0., 1., 0.]], (1,)) + run_test(torch.ones((2, 2, 2)), torch.full((2, 2, 2), 0.25), (0, 2)) def test_pow_zero_tensor_gradient(self): def run_test(input_size, exponent): @@ -2776,29 +2950,21 @@ def test_pow_scalar_base(self): a = torch.arange(1, 13, dtype=torch.double).view(3, 4).requires_grad_() gradcheck(lambda a: torch.pow(2, a), (a,)) - @skipIfNoLapack - def test_pinverse(self): - # Why is pinverse tested this way, and not ordinarily as other linear algebra methods? - # 1. Pseudo-inverses are not generally continuous, which means that they are not differentiable - # 2. Derivatives for pseudo-inverses exist typically for constant rank (Golub et al, 1973) - # 3. This method creates two orthogonal matrices, and a constructs a test case with large - # singular values (given by x to the function). - # 4. This will ensure that small perturbations don't affect the rank of matrix, in which case - # a derivative exists. - # 5. This test exists since pinverse is implemented using SVD, and is hence a backpropable method - m, n = 5, 10 - U = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n - V = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n - - def func(x): - S = torch.cat([x, torch.zeros(n - m)], 0) - M = U.mm(torch.diag(S)).mm(V.t()) - return M.pinverse() - - gradcheck(func, [torch.rand(m).add_(1).requires_grad_()]) - gradcheck(func, [torch.rand(m).add_(10).requires_grad_()]) - gradgradcheck(func, [torch.rand(m).add_(1).requires_grad_()]) - gradgradcheck(func, [torch.rand(m).add_(10).requires_grad_()]) + def test_igamma(self): + # 1e-3 offset to avoid zeros + # NOTE: derivative for s is not implemented + s = (torch.rand(100, dtype=torch.double) + 1e-3) + x = (torch.rand(100, dtype=torch.double) + 1e-3).requires_grad_() + gradcheck(torch.igamma, (s, x)) + gradgradcheck(torch.igamma, (s, x)) + + def test_igammac(self): + # 1e-3 offset to avoid zeros in s + # NOTE: derivative for s is not implemented + s = (torch.rand(100, dtype=torch.double) + 1e-3) + x = (torch.rand(100, dtype=torch.double)).requires_grad_() + gradcheck(torch.igamma, (s, x)) + gradgradcheck(torch.igamma, (s, x)) def test_chain_matmul(self): def gen_matrices(p): @@ -2814,18 +2980,17 @@ def gen_matrices(p): gradgradcheck(torch.chain_matmul, gen_matrices([3, 5, 2, 6])) gradgradcheck(torch.chain_matmul, gen_matrices([6, 2, 4, 8, 10])) - @unittest.skipIf(IS_WINDOWS, """File open permission error on Windows, - https://github.com/pytorch/pytorch/issues/34086""") def test_profiler_tracing(self): t1, t2 = torch.ones(1), torch.ones(1) - with torch.autograd.profiler.profile() as prof: + with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof: torch.add(t1, t2) - with tempfile.NamedTemporaryFile(mode="w+") as f: - prof.export_chrome_trace(f.name) + with TemporaryFileName(mode="w+") as fname: + prof.export_chrome_trace(fname) # read the trace and expect valid json # if the JSON generated by export_chrome_trace is not valid, this will throw and fail the test. - json.load(f) + with io.open(fname, 'r') as f: + json.load(f) # Same test but for cuda. if not torch.cuda.is_available(): @@ -2833,18 +2998,19 @@ def test_profiler_tracing(self): device = torch.device("cuda:0") t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device) - with torch.autograd.profiler.profile(use_cuda=True) as prof: + with torch.autograd.profiler.profile(use_cuda=True, use_kineto=kineto_available()) as prof: torch.add(t1, t2) - with tempfile.NamedTemporaryFile(mode="w+") as f: - prof.export_chrome_trace(f.name) + with TemporaryFileName(mode="w+") as fname: + prof.export_chrome_trace(fname) # Now validate the json - json.load(f) + with io.open(fname, 'r') as f: + json.load(f) def test_profiler(self): x = torch.randn(10, 10) - with profile() as p: + with profile(use_kineto=kineto_available()) as p: self.assertTrue(torch.autograd._profiler_enabled()) y = x * 2 + 4 @@ -2855,22 +3021,21 @@ def test_profiler(self): 'aten::empty', 'aten::add', 'aten::to', 'aten::empty_strided', 'aten::copy_', 'aten::empty'] top_level_names = ['aten::mul', 'aten::add'] - top_level_iter = iter(top_level_names) - self.assertEqual(len(p.function_events), len(names)) - for info, expected_name in zip(p.function_events, names): - if info.cpu_interval.start > last_end: - top_level_name_expected = next(top_level_iter) - self.assertEqual(info.name, top_level_name_expected) - last_end = info.cpu_interval.end - self.assertEqual(info.name, expected_name) + for evt in p.function_events: + if evt.time_range.start > last_end: + self.assertTrue(evt.name in top_level_names) + last_end = evt.time_range.end + self.assertTrue(evt.name in names) def test_profiler_seq_nr(self): - with profile() as p: + with profile(use_kineto=kineto_available()) as p: x = torch.randn(10, 10, requires_grad=True) y = torch.randn(10, 10, requires_grad=True) z = x + y s = z.sum() s.backward() + print(p.key_averages().table( + sort_by="self_cpu_time_total", row_limit=-1)) # expecting aten::add, aten::sum to have the sequence numbers, # expecting the corresponding backward nodes to have the same numbers # as the forward ops @@ -2913,7 +3078,7 @@ def test_profiler_seq_nr(self): def test_profiler_unboxed_only(self): x = torch.rand(3, 4) - with torch.autograd.profiler.profile() as prof: + with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof: x.resize_([3, 2]) def test_profiler_propagation(self): @@ -2938,7 +3103,7 @@ def bar(x): traced_bar = torch.jit.trace(bar, x) - with profile() as p: + with profile(use_kineto=kineto_available()) as p: traced_bar(x) found_foo = False @@ -2960,7 +3125,7 @@ def bar(x): def test_record_function_callbacks(self): x = torch.randn(10, 10) - with profile() as p: + with profile(use_kineto=kineto_available()) as p: with record_function("foo"): y = x * 2 + 4 @@ -2992,12 +3157,12 @@ def get_id(): node_id=0, name="", thread=thread, - cpu_start=range[0], - cpu_end=range[1], + start_us=range[0], + end_us=range[1], ) ) - events.populate_cpu_children() + events._populate_cpu_children() # Note that [1, 3] pushes out [0, 2] first. Then we record [1, 2] # as a child of [1, 3] @@ -3016,7 +3181,7 @@ def test_profiler_aggregation_table(self): """ x = torch.randn(1024) - with torch.autograd.profiler.profile() as prof: + with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof: torch.einsum("i->", x) prof_str = str(prof) @@ -3026,8 +3191,8 @@ def test_profiler_aggregation_table(self): def test_profiler_function_event_avg(self): avg = FunctionEventAvg() - avg.add(FunctionEvent(id=0, node_id=0, name="foo", thread=0, cpu_start=10, cpu_end=15)) - avg.add(FunctionEvent(id=1, node_id=0, name="foo", thread=0, cpu_start=20, cpu_end=30)) + avg.add(FunctionEvent(id=0, node_id=0, name="foo", thread=0, start_us=10, end_us=15)) + avg.add(FunctionEvent(id=1, node_id=0, name="foo", thread=0, start_us=20, end_us=30)) avg.add(avg) self.assertEqual(avg.key, "foo") @@ -3046,7 +3211,7 @@ def test_profiler_shapes(self): layer1 = torch.nn.Linear(20, 30) layer2 = torch.nn.Linear(30, 40) input = torch.randn(128, 20) - with profile(record_shapes=True) as prof: + with profile(record_shapes=True, use_kineto=kineto_available()) as prof: layer2(layer1(input)) print(prof.function_events) @@ -3062,18 +3227,18 @@ def test_profiler_shapes(self): last_end = 0 for event in prof.function_events: - if event.cpu_interval.start > last_end: + if event.time_range.start > last_end: name_expected, input_shape_expected = next(expected_iter) if name_expected is not None: self.assertEqual(event.name, name_expected) self.assertEqual(event.input_shapes, input_shape_expected) - last_end = event.cpu_interval.end + last_end = event.time_range.end def test_profiler_no_cuda(self): print("") layer = torch.nn.Linear(20, 30) x = torch.randn(128, 20) - with profile(use_cuda=False) as prof: + with profile(use_cuda=False, use_kineto=kineto_available()) as prof: layer(x) prof_str = str(prof) @@ -3085,7 +3250,7 @@ def test_profiler_aggregation_lstm(self): print("") rnn = torch.nn.LSTM(10, 20, 2) total_time_s = 0 - with profile(record_shapes=True) as prof: + with profile(record_shapes=True, use_kineto=kineto_available()) as prof: for i in range(20): input = torch.randn(5, 3, 10) h = torch.randn(2, 3, 20) @@ -3100,7 +3265,7 @@ def test_profiler_aggregation_lstm(self): print(prof.key_averages(group_by_input_shape=True).table( sort_by="self_cpu_time_total", row_limit=10)) print(prof.table( - sort_by="self_cpu_time_total", row_limit=10, header="TEST", top_level_events_only=True)) + sort_by="self_cpu_time_total", row_limit=10, max_src_column_width=300, header="TEST", top_level_events_only=True)) print(prof.key_averages(group_by_input_shape=True).table( sort_by="self_cpu_time_total", row_limit=10, top_level_events_only=True)) @@ -3122,7 +3287,7 @@ def test_profiler_aggregation_lstm(self): def test_memory_profiler(self): def run_profiler(tensor_creation_fn, metric): # collecting allocs / deallocs - with profile(profile_memory=True, record_shapes=True) as prof: + with profile(profile_memory=True, record_shapes=True, use_kineto=kineto_available()) as prof: x = None with record_function("test_user_scope_alloc"): x = tensor_creation_fn() @@ -3214,7 +3379,7 @@ def create_mkldnn_tensor(): # check partial overlap of tensor allocation with memory profiler x = torch.rand(10, 10) - with profile(profile_memory=True, record_shapes=True) as prof: + with profile(profile_memory=True, record_shapes=True, use_kineto=kineto_available()) as prof: del x x = torch.rand(10, 10) del x @@ -3240,7 +3405,7 @@ def forward(x): forward(x) - with profile() as p: + with profile(use_kineto=kineto_available()) as p: forward(x) events = p.function_events @@ -3265,7 +3430,7 @@ def forward(x): def f(x, y): return x + y - with profile() as p: + with profile(use_kineto=kineto_available()) as p: f(1, 2) self.assertTrue('my_func' in str(p)) @@ -3415,6 +3580,16 @@ def test(): test() self.assertEqual(dealloc[0], 1) + def test_inplace_view_leaf_errors(self): + # Issue #21875: Fail faster (when we try to modify the view vs. in backward()) + x = torch.zeros(1, requires_grad=True) + y = x.view_as(x) + with self.assertRaisesRegex(RuntimeError, + "a view of a leaf Variable that " + "requires grad is being used in " + "an in-place operation."): + y.add_(1) + def test_inplace_view_backward(self): # Issue #10532: Make sure that this does not raise RuntimeError. net = nn.Sequential( @@ -3459,13 +3634,10 @@ def test_inplace_view_weak_grad_fn(self): s.backward() self.assertEqual(s, torch.tensor(1.0)) - # Issue 23502: Ensure RuntimeError for modification of SavedVariable. + # Issue #21875: Fail faster (when we try to modify the view vs. in backward()) a = torch.rand(10, requires_grad=True).narrow(0, 0, 10) - b = a.relu_() - c = b.add_(100) - del b with self.assertRaises(RuntimeError): - c.sum().backward(torch.ones(1, requires_grad=True)) + b = a.relu_() def test_mul_out(self): a = torch.randn(2, 2, requires_grad=True) @@ -3822,14 +3994,14 @@ def backward(ctx, grad_out): return NonDetFunc.apply(grad_out, ctx._jitter) * (1 + torch.rand_like(grad_out) * ctx._jitter), None inp = torch.randn(5, 5, requires_grad=True) - gradcheck(lambda x: NonDetFunc.apply(x, 0.0), inp) + gradcheck(lambda x: NonDetFunc.apply(x, 0.0), inp, check_batched_grad=False) with self.assertRaisesRegex(RuntimeError, 'Backward is not reentrant'): - gradcheck(lambda x: NonDetFunc.apply(x, 1e-6), inp) + gradcheck(lambda x: NonDetFunc.apply(x, 1e-6), inp, check_batched_grad=False) with self.assertRaisesRegex(RuntimeError, 'Backward is not reentrant'): - gradgradcheck(lambda x: NonDetFunc.apply(x, 1e-12), inp) - gradcheck(lambda x: NonDetFunc.apply(x, 0.0), inp, nondet_tol=1e-5) - gradcheck(lambda x: NonDetFunc.apply(x, 1e-6), inp, nondet_tol=1e-5) - gradgradcheck(lambda x: NonDetFunc.apply(x, 1e-12), inp, nondet_tol=1e-5) + gradgradcheck(lambda x: NonDetFunc.apply(x, 1e-12), inp, check_batched_grad=False) + gradcheck(lambda x: NonDetFunc.apply(x, 0.0), inp, nondet_tol=1e-5, check_batched_grad=False) + gradcheck(lambda x: NonDetFunc.apply(x, 1e-6), inp, nondet_tol=1e-5, check_batched_grad=False) + gradgradcheck(lambda x: NonDetFunc.apply(x, 1e-12), inp, nondet_tol=1e-5, check_batched_grad=False) def test_version_counter(self): x = torch.randn(1, 2) @@ -4160,6 +4332,50 @@ def maybe_check_raise(fn, should_raise): run_test(grad_mode=False, requires_grad=False, is_view=True, should_raise_tuple=(None, None, None)) + def test_inplace_not_requires_grad(self): + class MyFn(torch.autograd.Function): + @staticmethod + def forward(ctx, inp): + return inp.view_as(inp) + + @staticmethod + def backward(ctx, grad): + return grad + + # Original Tensor does not require grad + a = torch.rand(1, 2) + + # Tensor being written does require grad + b = torch.rand(1, requires_grad=True) + + # Take an invalid view on 'a' that should raise an error (warns during deprecation) + view_a = MyFn.apply(a) + + with self.assertWarnsRegex(UserWarning, "This view was created inside a custom Function"): + view_a += b + + # Extra test for copy_ that is a manual implementation and could be easily + # forgotten when the codegen is updated (warns during deprecation) + a = torch.rand(1, 2) + b = torch.rand(1, requires_grad=True) + view_a = MyFn.apply(a) + + with self.assertWarnsRegex(UserWarning, "This view was created inside a custom Function"): + view_a.copy_(b) + + # Functions that should throw must properly throw + a = torch.rand(1, 2) + b = torch.rand(1, requires_grad=True) + view_a = a.unbind()[0] + with self.assertRaisesRegex(RuntimeError, "This view is the output of a function that returns " + "multiple views."): + view_a.copy_(b) + + # Sanity check that views that should work still work + a = torch.rand(1, 2) + b = torch.rand(1, requires_grad=True) + a.select(1, 0).copy_(b) + def _do_test_autograd_simple_views_python(self, dtype): # This is not necessarily the absolute correct behavior, but this is the current # one. This test is here to make sure that any change to this behavior is detected @@ -4273,21 +4489,29 @@ def fn(a, b): if fn_id == "view_of_temp": # This will be fixed after the deprecation cycle and the warning becomes # an error. - with self.assertRaisesRegex(RuntimeError, "Jacobian mismatch for output 0"): - gradcheck(fn, (a, b)) + with self.assertRaisesRegex(RuntimeError, + "a view of a leaf Variable that requires grad " + "is being used in an in-place operation."): + gradcheck(fn, (a, b), check_batched_grad=False) else: # This works but the custom backward is not called (or called with partial) # gradients as tested below - gradcheck(fn, (a, b)) + gradcheck(fn, (a, b), check_batched_grad=False) self.assertTrue(len(w) > 0) else: - gradcheck(fn, (a, b)) + gradcheck(fn, (a, b), check_batched_grad=False) # Was the custom backward called properly bw_called[0] = 0 ga_nz[0] = True # For the case where the backward is called with warnings.catch_warnings(record=True) as w: - fn(a, b).backward() + if inplace and output_is_a_view and fn_id != "one_output": + with self.assertRaisesRegex(RuntimeError, + "a view of a leaf Variable that requires grad " + "is being used in an in-place operation."): + fn(a, b).backward() + else: + fn(a, b).backward() expected_called = 1 expected_ga_nz = True @@ -4468,6 +4692,14 @@ def fn(a, b): with self.assertRaisesRegex(RuntimeError, bad_mark_dirty_err): fn(a, b) + def test_named_tensor_for_complex_views(self): + names = ["batch", "height", "width", "complex"] + z = torch.ones((5, 12, 14, 2), requires_grad=True) + z_named = z.refine_names(*names) + z_complex = torch.view_as_complex(z_named.rename(None)).refine_names(*names[:-1]) + z_complex.sum().backward() + self.assertEqual(z.grad, torch.view_as_real(torch.ones_like(z_complex).rename(None))) + def test_custom_function_return_view_in_nograd(self): class Alias(Function): @staticmethod @@ -4574,6 +4806,33 @@ def test(inp, inp_dtype, out_dtype): test(inp, torch.float, torch.double) test(inp, torch.double, torch.float) + def test_nan_to_num(self): + a = torch.randn(3, 3, 3, 3) + with torch.no_grad(): + a[torch.rand_like(a) < 0.2] = float('nan') + a[torch.rand_like(a) < 0.2] = float('inf') + a[torch.rand_like(a) < 0.2] = -float('inf') + + a.requires_grad = True + + gradcheck(lambda x: x.nan_to_num(), a) + gradgradcheck(lambda x: x.nan_to_num(), a) + + gradcheck(lambda x: x.nan_to_num(nan=1.2), a) + gradgradcheck(lambda x: x.nan_to_num(nan=1.2), a) + + gradcheck(lambda x: x.nan_to_num(nan=1.2, posinf=2.0), a) + gradgradcheck(lambda x: x.nan_to_num(nan=1.2, posinf=2.0), a) + + gradcheck(lambda x: x.nan_to_num(nan=1.2, posinf=2.0, neginf=-2.0), a) + gradgradcheck(lambda x: x.nan_to_num(nan=1.2, posinf=2.0, neginf=-2.0), a) + + gradcheck(lambda x: x.nan_to_num(posinf=2.0, neginf=-2.0), a) + gradgradcheck(lambda x: x.nan_to_num(posinf=2.0, neginf=-2.0), a) + + gradcheck(lambda x: x.nan_to_num(neginf=-2.0), a) + gradgradcheck(lambda x: x.nan_to_num(neginf=-2.0), a) + def test_custom_function_error(self): class BadFw(Function): @staticmethod @@ -4634,17 +4893,57 @@ def test_integer_outputs(self): self.assertFalse(out.dtype.is_floating_point) self.assertFalse(out.requires_grad) - bins = torch.linspace(0, 1.0, requires_grad=True) + bins = torch.linspace(0, 1.0, steps=100, requires_grad=True) vals = torch.rand(5, 5, requires_grad=True) out = torch.bucketize(vals, bins) self.assertFalse(out.dtype.is_floating_point) self.assertFalse(out.requires_grad) -def index_variable(shape, max_indices): - if not isinstance(shape, tuple): - shape = (shape,) - index = torch.rand(*shape).mul_(max_indices).floor_().long() - return index + def assert_only_first_requires_grad(res): + if not isinstance(res, tuple): + res = (res,) + self.assertTrue(res[0].requires_grad) + for out in res[1:]: + if out is not None: + self.assertFalse(out.requires_grad) + + for sort in [True, False]: + for return_inverse in [True, False]: + for return_counts in [True, False]: + res = torch.unique(inp, sorted=sort, return_inverse=return_inverse, + return_counts=return_counts) + assert_only_first_requires_grad(res) + + res = torch.unique(inp, sorted=sort, return_inverse=return_inverse, + return_counts=return_counts, dim=0) + assert_only_first_requires_grad(res) + + res = torch.unique_consecutive(inp, return_inverse=return_inverse, + return_counts=return_counts) + assert_only_first_requires_grad(res) + + res = torch.unique_consecutive(inp, return_inverse=return_inverse, + return_counts=return_counts, dim=0) + assert_only_first_requires_grad(res) + + # Here we test the internal functions to make sure all of them are + # covered on top of the public API + res = torch._unique(inp, sorted=sort, return_inverse=return_inverse) + assert_only_first_requires_grad(res) + + # This looks public but is actually manually deleted from the + # torch namespace in torch/functional.py + res = torch._VF.unique_dim(inp, dim=0, sorted=sort, return_inverse=return_inverse, + return_counts=return_counts) + assert_only_first_requires_grad(res) + + # We don't test `unique_dim_consecutive` here. + # It looks public but the python binding is actually manually disabled in + # tools/autograd/gen_python_functions.py + + res = torch._unique2(inp, sorted=sort, return_inverse=return_inverse, + return_counts=return_counts) + assert_only_first_requires_grad(res) def index_perm_variable(shape, max_indices): @@ -4654,20 +4953,6 @@ def index_perm_variable(shape, max_indices): index = torch.randperm(max_indices).narrow(0, 0, reduce(mul, shape)).view(shape) return index - -def gather_variable(shape, index_dim, max_indices, duplicate=False): - assert len(shape) == 2 - assert index_dim < 2 - batch_dim = 1 - index_dim - index = torch.LongTensor(*shape) - for i in range(shape[index_dim]): - index.select(index_dim, i).copy_( - torch.randperm(max_indices)[:shape[batch_dim]]) - if duplicate: - index.select(batch_dim, 0).copy_(index.select(batch_dim, 1)) - return index - - def bernoulli_scalar(): return torch.tensor(0, dtype=torch.uint8).bernoulli_() @@ -4693,8 +4978,9 @@ def gradgradcheck_method_precision_override(test_name): return override def run_grad_and_gradgrad_checks(test_case, name, test_name, apply_method, output_variable, - input_variables, run_gradgradcheck=True): - test_case.assertTrue(gradcheck(apply_method, input_variables, eps=1e-6, atol=PRECISION)) + input_variables, run_gradgradcheck=True, check_batched_grad=True): + test_case.assertTrue(gradcheck(apply_method, input_variables, eps=1e-6, atol=PRECISION, + check_batched_grad=check_batched_grad)) if name in EXCLUDE_GRADGRADCHECK or test_name in EXCLUDE_GRADGRADCHECK_BY_TEST_NAME: return gradgradcheck_precision_override = gradgradcheck_method_precision_override(test_name) @@ -4702,9 +4988,12 @@ def run_grad_and_gradgrad_checks(test_case, name, test_name, apply_method, outpu atol = gradgradcheck_precision_override['atol'] rtol = gradgradcheck_precision_override['rtol'] test_case.assertTrue(gradgradcheck(apply_method, input_variables, None, atol=atol, rtol=rtol, - gen_non_contig_grad_outputs=True)) + gen_non_contig_grad_outputs=True, + check_batched_grad=check_batched_grad)) else: - test_case.assertTrue(gradgradcheck(apply_method, input_variables, gen_non_contig_grad_outputs=True)) + test_case.assertTrue(gradgradcheck(apply_method, input_variables, + gen_non_contig_grad_outputs=True, + check_batched_grad=check_batched_grad)) def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, @@ -4727,22 +5016,30 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, # the tests for these ops which do not have 'complex' in variant should not run for complex # and only run for floating point -# TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition -separate_complex_tests = ['view_as_real', 'real', 'imag', 'asin', 'acos'] # ['log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan'] +separate_complex_tests = ['view_as_real', 'real', 'imag', 'div', 'pow', 'rsqrt', '__rdiv__', 'add', 'sub'] # NOTE: Some non-holomorphic are separately tested in TestAutogradComplex until gradcheck works properly # for non-holomorphic functions # allow list for complex complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone', - 'repeat', 'expand', 'flip', 'fliplr', 'flipud', 'rot90', 'transpose', + 'repeat', 'expand', 'rot90', 'transpose', 'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu', - 'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_', 'round', - 'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh', - 'cosh', '__rmul__', 'sgn'] + separate_complex_tests - -# TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition - @anjali411 -# complex_list += ['fill_', 't', '__rdiv__', 'tanh'] + 'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_', + 'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'mul', + '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split', 'matmul', + 'bmm', 'mv', 'ger', 'diagonal', 'fill_', 'sub', + 'mean', 'inverse', 'triangular_solve', 'solve', 'addcmul', + 'addcdiv', 'linalg.tensorinv', 'matrix_exp', 'qr', + 'narrow', 'swapaxes', 'swapdims', 'tensor_split', 'tile', + 'baddbmm', 'addbmm', 'addmv'] + separate_complex_tests + +# deny list for batched grad computation +EXCLUDE_BATCHED_GRAD_TESTS = set([ + 'test_unfold_scalar', + 'test_unfold_scalar_neg0', + 'test_to_sparse', +]) def add_test( name, @@ -4816,8 +5113,10 @@ def fn(*inputs): return output_process_fn(output) if not is_inplace and name not in EXCLUDE_GRADCHECK: + check_batched_grad = test_name not in EXCLUDE_BATCHED_GRAD_TESTS run_grad_and_gradgrad_checks(self, name, test_name, fn, - output_variable, (self_variable,) + args_variable) + output_variable, (self_variable,) + args_variable, + check_batched_grad=check_batched_grad) # functional interface tests torch_fn = getattr_qualified(torch, name) @@ -4864,7 +5163,9 @@ def fn(*inputs): 'broadcast_all' in test_name or 'atanh' in test_name or 'acosh' in test_name or - 'asinh' in test_name) + 'asinh' in test_name or + 'abs_complex' in test_name or + 'abs_scalar_complex' in test_name) if hasattr(torch.ones(1), inplace_name) and not skip_inplace: output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable) if not isinstance(output_variable, tuple): @@ -4911,7 +5212,10 @@ def fn(*inputs): inplace_name = name + '_' # can't broadcast inplace to left hand side broadcast_skip_inplace = 'broadcast_lhs' in test_name or 'broadcast_all' in test_name - if hasattr(torch.ones(1), inplace_name) and not broadcast_skip_inplace: + # skip C -> R inplace tests + skip_c_to_r_inplace = 'abs_complex' in test_name or 'abs_scalar_complex' in test_name + skip_inplace = broadcast_skip_inplace or skip_c_to_r_inplace + if hasattr(torch.ones(1), inplace_name) and not skip_inplace: check(inplace_name) assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name @@ -4979,6 +5283,26 @@ def fn(a, dim0_size=5): self.assertEqual(x.grad, y.grad) + def test_view_with_multi_output(self): + x = torch.randn(2, 2, 2, dtype=torch.double) + + x1 = torch.view_as_complex(x) + # Taking an invalid view should always be allowed as long as it is not + # modified inplace + res = x1.unbind(0) + + with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"): + res[0] += torch.rand(2, requires_grad=True) + + x.requires_grad_(True) + x1 = torch.view_as_complex(x) + # Taking an invalid view should always be allowed as long as it is not + # modified inplace + res = x1.unbind(0) + + with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"): + res[0] += torch.rand(2, requires_grad=True) + def as_identity(self): # view_as_real and view_as_complex behavior should be like an identity def func(z): @@ -5977,22 +6301,80 @@ def foo(a): self.assertEqual(hvp, torch.mm(hes, v.unsqueeze(1)).squeeze(1)) self.assertEqual(vhp, torch.mm(v.unsqueeze(0), hes).squeeze(0)) +class TestAutogradForwardMode(TestCase): + def test_forward_level_cleanup(self): + import weakref + + def get_tensor_and_weak_ref(): + # Helper function to get a Tensor and a weak ref that tells us + # if the c++ version of this Tensor is still alive or not. + # + # Create the following reference chain to do so: + # - python Tensor t + # - c++ Tensor corresponding by t + # - c++ Node corresponding to t.grad_fn + # - python dict of metadata from this Node + # - an object in this dict that we can take a weakref of + + + # Create a new Tensor and Node + t = torch.rand(2, requires_grad=True).clone() + # Create the metadata dict + meta_dict = t.grad_fn.metadata + # Create the object in the dict + + class Foo(object): + pass + my_obj = Foo() + meta_dict[0] = my_obj + + # After exiting this function, the python Tensor t is the only + # thing keeping ref alive + ref = weakref.ref(my_obj) + return t, ref + + # Sanity check that the helper function works as expected + t, t_ref = get_tensor_and_weak_ref() + self.assertIsNotNone(t_ref()) + + del t + self.assertIsNone(t_ref()) + + # Main test code + foo = torch.rand(2) + + with fwAD.dual_level(): + tangent, tangent_ref = get_tensor_and_weak_ref() + self.assertIsNotNone(tangent_ref()) + + dual = fwAD.make_dual(foo, tangent) + self.assertIsNotNone(tangent_ref()) + + # Make sure that the tangent we provided has been re-used as is + self.assertTrue(fwAD.unpack_dual(dual)[1] is tangent) + + # Make sure that dual is keeping the tangent alive + del tangent + self.assertIsNotNone(tangent_ref()) + + # Make sure that the dual level does not keep the c++ + # version of the tangent alive + del dual + self.assertIsNone(tangent_ref()) # Generic device type autograd tests. class TestAutogradDeviceType(TestCase): def test_min_max_median_backprops_to_all_values(self, device): - for f in [torch.min, torch.max, torch.median]: - x = torch.tensor([1., 0., 1., 0., 1., 0.], device=device, requires_grad=True) - y = f(x) - y.backward() - self.assertEqual(x.grad.sum(), 1.) - self.assertEqual((x.grad == 1 / 3).sum(), 3) - - # skip this test if running on rocm, because in cdist - # we use __shfl_down_sync on CUDA for fast reduction - # and it gives incorrect results on rocm platform - @skipCUDAIfRocm + for f in [torch.min, torch.max, torch.median, torch.nanmedian]: + x1 = torch.tensor([1., 0., 1., 0., 1., 0.], device=device, requires_grad=True) + x2 = torch.tensor([float('nan'), float('nan'), float('nan')], requires_grad=True) + for x in [x1, x2]: + y = f(x) + y.backward() + self.assertEqual(x.grad.sum(), 1.) + self.assertEqual((x.grad == 1 / 3).sum(), 3) + def test_cdist(self, device): def _test_cdist_for_size(sizex, sizey=None): if sizey is None: @@ -6037,6 +6419,18 @@ def _test_euclidean_large_cdist(sizex, sizey=None): _test_cdist_for_size((1, 1), (S, 1)) _test_euclidean_large_cdist((2000, 5)) + # Ensure that cdist backward with p<1 does not produce NaNs + def test_cdist_grad_p_lt_1_no_nan(self, device): + for p in [0.99, 0.7, 0.5, 0.1, 0.01]: + x = torch.randn(1, 2, device=device) + y = x.clone().detach() + torch.tensor([[1., 0.]], device=device) + x.requires_grad = True + y.requires_grad = True + result = torch.cdist(x, y, p=p) + result.backward(torch.ones_like(result)) + self.assertFalse(torch.isnan(x.grad).any()) + self.assertFalse(torch.isnan(y.grad).any()) + def test_cdist_same_inputs(self, device): # Test to detect issues in cdist gradient calculation # When the distances are 0 @@ -6064,8 +6458,6 @@ def test_parameter_resize(self, device): m = torch.cat((asd, asd)) m.sum().backward() - # NOTE: flaky on ROCm CI - @skipCUDAIfRocm def test_sparse_ctor_getter_backward(self, device): # See NOTE [ Sparse: autograd and API ] on the expected behavior of this test def _test(size, sparse_dim, nnz, device): @@ -6086,9 +6478,9 @@ def fn(v): z = torch.sparse_coo_tensor(y.indices(), new_v, y.size()) return z.coalesce().values() - gradcheck(fn, (inp,)) + gradcheck(fn, (inp,), check_batched_grad=False) # FIXME: make gradgradcheck work. - # gradgradcheck(fn, (inp,)) + # gradgradcheck(fn, (inp,), check_batched_grad=False) # assert that _values is non-differentiable with self.assertRaisesRegex(RuntimeError, "does not have a grad_fn"): @@ -6386,7 +6778,6 @@ def test_ctc_loss_cudnn(self, device): grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out) self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0) - @skipCUDAIfRocm def test_leaky_relu_inplace_with_neg_slope(self, device): a = torch.tensor([-1., 1.], device=device, requires_grad=True) b = torch.nn.functional.leaky_relu_(a.clone(), -2) @@ -6398,7 +6789,6 @@ def test_leaky_relu_inplace_with_neg_slope(self, device): with self.assertRaisesRegex(RuntimeError, "call out-of-place version"): b.backward(torch.ones(2, device=device)) - @skipCUDAIfRocm def test_leaky_relu_inplace_with_zero_slope(self, device): a = torch.tensor([-2., 0., 2.], device=device, requires_grad=True) b = torch.nn.functional.leaky_relu_(a.clone(), 0.0) @@ -6406,6 +6796,18 @@ def test_leaky_relu_inplace_with_zero_slope(self, device): expected = torch.tensor([0., 0., 1.], device=device) self.assertEqual(a.grad, expected) + @onlyOnCPUAndCUDA + def test_elu_inplace_with_neg_alpha(self, device): + a = torch.tensor([-1., 1.], device=device, requires_grad=True) + b = torch.nn.functional.elu_(a.clone(), alpha=-2) + with self.assertRaisesRegex(RuntimeError, "call out-of-place version"): + b.backward(torch.ones(2, device=device)) + + a = torch.tensor([-1., 1.], device=device, requires_grad=True) + b = torch.nn.functional.celu_(a.clone(), alpha=-2) + with self.assertRaisesRegex(RuntimeError, "call out-of-place version"): + b.backward(torch.ones(2, device=device)) + @onlyCUDA def test_free_unneeded_tensor(self, device): x = torch.randn(2, 3, 10, 10, device=device, requires_grad=True) @@ -6476,6 +6878,24 @@ def flatten_out(mod, inp): torch.autograd.gradcheck(gradcheckfunc, inp) torch.autograd.gradgradcheck(gradcheckfunc, inp) + if inp.is_cuda and not TEST_WITH_ROCM: + # Assert that we have good error message around unsupported CuDNN double backward + # NB: we trigger double backward using .backward() instead of autograd.grad due to + # https://github.com/pytorch/pytorch/issues/37874 + with torch.backends.cudnn.flags(enabled=True): + result = gradcheckfunc(inp) + result[0].sum().backward(create_graph=True) + grad0 = next(mod.parameters()).grad + with self.assertRaisesRegex(RuntimeError, + "please disable the CuDNN backend temporarily"): + grad0.sum().backward() + + # Here we avoid the backward(create_graph=True) memory leak + # described in https://github.com/pytorch/pytorch/issues/7343 + for param in mod.parameters(): + param.grad = None + inp.grad = None + def test_LSTM_grad_and_gradgrad(self, device): hsize = 4 inp = torch.rand(1, 3, hsize, device=device, dtype=torch.float64, requires_grad=True) @@ -6490,6 +6910,39 @@ def test_GRU_grad_and_gradgrad(self, device): mod = torch.nn.GRU(hsize, hsize, bias=bias).to(device).to(torch.float64) self._test_rnn_mod(mod, inp) + def test_copysign_subgradient(self, device): + # Input is 0.0 + x = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float, device=device, requires_grad=True) + y = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float, device=device, requires_grad=True) + out = torch.copysign(x, y) + out.sum().backward() + self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0]) + self.assertEqual(y.grad.tolist(), [0.0] * 3) + + # Input is -0.0 + x = torch.tensor([-0.0, -0.0, -0.0], dtype=torch.float, device=device, requires_grad=True) + y = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float, device=device, requires_grad=True) + out = torch.copysign(x, y) + out.sum().backward() + self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0]) + self.assertEqual(y.grad.tolist(), [0.0] * 3) + + # Other is 0.0 + x = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float, device=device, requires_grad=True) + y = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float, device=device, requires_grad=True) + out = torch.copysign(x, y) + out.sum().backward() + self.assertEqual(x.grad.tolist(), [-1.0, 0.0, 1.0]) + self.assertEqual(y.grad.tolist(), [0.0] * 3) + + # Other is -0.0 + x = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float, device=device, requires_grad=True) + y = torch.tensor([-0.0, -0.0, -0.0], dtype=torch.float, device=device, requires_grad=True) + out = torch.copysign(x, y) + out.sum().backward() + self.assertEqual(x.grad.tolist(), [1.0, 0.0, -1.0]) + self.assertEqual(y.grad.tolist(), [0.0] * 3) + @deviceCountAtLeast(1) def test_grad_assignment(self, devices): x = torch.randn(5, 5, device=devices[0]) @@ -6832,6 +7285,32 @@ def test_logcumsumexp_large_value(self, device): gradcheck(lambda x: x.logcumsumexp(2), a) gradgradcheck(lambda x: x.logcumsumexp(2), a) + @slowTest + def test_lu_backward(self, device): + def run_test(*sizes): + x = torch.rand(*sizes, device=device, dtype=torch.double).requires_grad_(True) + + gradcheck(lambda x: x.lu(get_infos=True), x) + gradgradcheck(lambda x: x.lu(get_infos=True), x) + + gradcheck(lambda x: x.lu(get_infos=False), x) + gradgradcheck(lambda x: x.lu(get_infos=False), x) + + # there is no pivot-less LU factorization on CPU + if x.device.type == 'cuda': + gradcheck(lambda x: x.lu(pivot=False, get_infos=True), x) + gradgradcheck(lambda x: x.lu(pivot=False, get_infos=True), x) + + gradcheck(lambda x: x.lu(pivot=False, get_infos=False), x) + gradgradcheck(lambda x: x.lu(pivot=False, get_infos=False), x) + + run_test(3, 3) + run_test(3, 3, 3) + run_test(3, 3, 3, 3) + run_test(5, 5) + run_test(3, 5, 5) + run_test(3, 3, 5, 5) + def test_strided_leaf_grad_layout(self, device): # (1) If leaf is non-overlapping and dense, grad's layout should match its leaf. for fmt_a in (torch.contiguous_format, torch.channels_last): @@ -6864,17 +7343,6 @@ def test_strided_leaf_grad_layout(self, device): (c * d).sum().backward() self.assertEqual(c.grad.stride(), (2, 1)) - def test_movedim(self, device): - x = torch.randn(4, 3, 2, 1, dtype=torch.double, device=device, requires_grad=True) - - # Positive axis - gradcheck(lambda x: torch.movedim(x, (0, 1, 2, 3), (3, 2, 1, 0)), x) - gradgradcheck(lambda x: torch.movedim(x, (0, 1, 2, 3), (3, 2, 1, 0)), x) - - # Negative axis - gradcheck(lambda x: torch.movedim(x, (0, -1, -2, -3), (-3, -2, -1, -0)), x) - gradgradcheck(lambda x: torch.movedim(x, (0, -1, -2, -3), (-3, -2, -1, -0)), x) - def _test_atleast(self, device, torch_fn): # 0-dim s = torch.tensor(0.5, dtype=torch.double, requires_grad=True) @@ -6902,6 +7370,54 @@ def test_atleast(self, device): self._test_atleast(device, torch.atleast_2d) self._test_atleast(device, torch.atleast_3d) + def test_xlogy(self, device): + + def _tensor_tensor_helper(x, y): + gradcheck(lambda x, y: torch.xlogy(x, y), (x, y)) + gradgradcheck(lambda x, y: torch.xlogy(x, y), (x, y)) + + with torch.no_grad(): + x = x.clone() + x[torch.rand_like(x) > 0.5] = 0 + + gradcheck(lambda y: torch.xlogy(x, y), (y)) + gradgradcheck(lambda y: torch.xlogy(x, y), (y)) + + shapes = ((4,), (1, 4), (1, 1, 4), (1, 1, 1, 4)) + + # For broadcastible shapes and scalar. + for x_shape, y_shape in permutations(shapes, 2): + x = torch.rand(*x_shape, dtype=torch.double, device=device, requires_grad=True) + y = torch.rand(*y_shape, dtype=torch.double, device=device, requires_grad=True) + + _tensor_tensor_helper(x, y) + _tensor_tensor_helper(y, x) + + gradcheck(lambda y: torch.xlogy(0, y), (y)) + gradgradcheck(lambda y: torch.xlogy(0, y), (y)) + + gradcheck(lambda y: torch.xlogy(2, y), (y)) + gradgradcheck(lambda y: torch.xlogy(2, y), (y)) + gradcheck(lambda y: torch.xlogy(y, 2), (y)) + gradgradcheck(lambda y: torch.xlogy(y, 2), (y)) + + # Different shape + x = torch.rand(2, 3, 4, 5, dtype=torch.double, device=device, requires_grad=True) + y = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True) + _tensor_tensor_helper(x, y) + _tensor_tensor_helper(y, x) + _tensor_tensor_helper(x, x) + _tensor_tensor_helper(y, y) + + # Same shape + x = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True) + y = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True) + _tensor_tensor_helper(x, y) + _tensor_tensor_helper(y, x) + _tensor_tensor_helper(x, x) + _tensor_tensor_helper(y, y) + + class TestMultithreadAutograd(TestCase): def _run_py_multithread_fn(self, fn, args=(), num_threads=10, kwargs=None): threads = [] @@ -7061,9 +7577,7 @@ def backward(ctx, *grad): instantiate_device_type_tests( TestAutogradDeviceType, globals(), - # Exclude ROCM for now, there are a lot of failures. See - # https://github.com/pytorch/pytorch/issues/30845 - except_for='cuda' if TEST_WITH_ROCM else None + except_for=None ) if __name__ == '__main__': diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py new file mode 100644 index 0000000000000..89e3c58be4988 --- /dev/null +++ b/test/test_binary_ufuncs.py @@ -0,0 +1,2537 @@ +import torch +import numpy as np + +import itertools +from itertools import product +import math +import random +import unittest +import warnings +import operator +from functools import partial + +from torch._six import inf, nan +from torch.testing._internal.common_utils import ( + TestCase, iter_indices, TEST_WITH_ASAN, run_tests, + torch_to_numpy_dtype_dict, make_tensor, TEST_SCIPY) +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA, + dtypesIfCPU, deviceCountAtLeast, precisionOverride, onlyOnCPUAndCUDA, + skipCUDAIfRocm, skipIf) + +if TEST_SCIPY: + import scipy.special + +# TODO: remove this +def _generate_input(shape, dtype, device, with_extremal): + if shape == (): + x = torch.tensor((), dtype=dtype, device=device) + else: + if dtype.is_floating_point or dtype.is_complex: + # work around torch.randn not being implemented for bfloat16 + if dtype == torch.bfloat16: + x = torch.randn(*shape, device=device) * random.randint(30, 100) + x = x.to(torch.bfloat16) + else: + x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100) + x[torch.randn(*shape) > 0.5] = 0 + if with_extremal and dtype.is_floating_point: + # Use extremal values + x[torch.randn(*shape) > 0.5] = float('nan') + x[torch.randn(*shape) > 0.5] = float('inf') + x[torch.randn(*shape) > 0.5] = float('-inf') + elif with_extremal and dtype.is_complex: + x[torch.randn(*shape) > 0.5] = complex('nan') + x[torch.randn(*shape) > 0.5] = complex('inf') + x[torch.randn(*shape) > 0.5] = complex('-inf') + elif dtype == torch.bool: + x = torch.zeros(shape, dtype=dtype, device=device) + x[torch.randn(*shape) > 0.5] = True + else: + x = torch.randint(15, 100, shape, dtype=dtype, device=device) + + return x + +# TODO: refactor this out +# Converts half/bfloat16 dtype to float when device is cpu +def _convert_t(dtype, device): + if device == 'cpu' and dtype in {torch.half, torch.bfloat16}: + return torch.float + return dtype + +# TODO: revise the tests to use make_tensor in common_utils.py instead +# Returns a tensor of the requested shape, dtype, and device +# Requesting a half CPU tensor returns a float CPU tensor with +# values representable by a half. +# Initialization uses randint for non-float types and randn for float types. +def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor: + # Returns a tensor filled with ones + if fill_ones: + return torch.ones(*shape, dtype=_convert_t(dtype, device), device=device) + + # Returns a tensor with random integer values + if not (dtype.is_floating_point or dtype.is_complex): + t = torch.randint(0, 10, shape, device=device) + if dtype != torch.uint8: + t = t - 5 # generate negative values also + return t.to(_convert_t(dtype, device)) + + # Populates the CPU tensor with floats representable as half/bfloat16 + if dtype == torch.half and device == 'cpu': + return torch.randn(*shape, dtype=torch.float, device=device).half().float() + if dtype == torch.bfloat16 and device == 'cpu': + return torch.randn(*shape, dtype=torch.float, device=device).bfloat16().float() + + # Default: returns a tensor with random float values + return torch.randn(shape, dtype=dtype, device=device).to(dtype=dtype) + +# TODO: update to use opinfos consistently +class TestBinaryUfuncs(TestCase): + + def test_add_broadcast_empty(self, device): + # empty + empty + self.assertRaises(RuntimeError, lambda: torch.randn(5, 0, device=device) + torch.randn(0, 5, device=device)) + self.assertEqual(torch.randn(5, 0, device=device), torch.randn(0, device=device) + torch.randn(5, 0, device=device)) + self.assertEqual(torch.randn(5, 0, 0, device=device), torch.randn(0, device=device) + torch.randn(5, 0, 1, device=device)) + + # scalar + empty + self.assertEqual(torch.randn(5, 0, 6, device=device), torch.randn((), device=device) + torch.randn(5, 0, 6, device=device)) + + # non-empty, empty + self.assertEqual(torch.randn(0, device=device), torch.randn(0, device=device) + torch.randn(1, device=device)) + self.assertEqual(torch.randn(0, 7, 0, 6, 5, 0, 7, device=device), + torch.randn(0, 7, 0, 6, 5, 0, 1, device=device) + torch.randn(1, 1, 5, 1, 7, device=device)) + self.assertRaises(RuntimeError, lambda: torch.randn(7, 0, device=device) + torch.randn(2, 1, device=device)) + + def test_addcmul_scalars_as_floats(self, device): + # zero-dim variables that don't require grad should bind to scalar arguments + x = torch.tensor(2.) + y = torch.tensor(3., device=device) + # 3 + (3 * 3) * 2 + self.assertEqual(y.addcmul(y, y, value=x), 21) + + x = torch.tensor(2., requires_grad=True) + self.assertRaises(Exception, lambda: y.addcmul(y, y, value=x)) + + # TODO: update to work on CUDA, too + @onlyCPU + def test_comparison_ops(self, device): + x = torch.randn(5, 5) + y = torch.randn(5, 5) + + eq = x == y + for idx in iter_indices(x): + self.assertEqual(x[idx] == y[idx], eq[idx] == 1) + + ne = x != y + for idx in iter_indices(x): + self.assertEqual(x[idx] != y[idx], ne[idx] == 1) + + lt = x < y + for idx in iter_indices(x): + self.assertEqual(x[idx] < y[idx], lt[idx] == 1) + + le = x <= y + for idx in iter_indices(x): + self.assertEqual(x[idx] <= y[idx], le[idx] == 1) + + gt = x > y + for idx in iter_indices(x): + self.assertEqual(x[idx] > y[idx], gt[idx] == 1) + + ge = x >= y + for idx in iter_indices(x): + self.assertEqual(x[idx] >= y[idx], ge[idx] == 1) + + # TODO: update to work on CUDA, too + @onlyCPU + def test_comparison_ops_must_take_bool_output(self, device): + for op in [torch.lt, torch.le, torch.gt, torch.ge, torch.eq, torch.ne, + torch.logical_and, torch.logical_or, torch.logical_xor]: + self.assertEqual(op(torch.tensor([True]), torch.tensor([False])).dtype, torch.bool) + + # TODO: update to work on CUDA, too + @onlyCPU + def test_inplace_comparison_ops_require_inputs_have_same_dtype(self, device): + with self.assertRaisesRegex(RuntimeError, 'Expected object of scalar type'): + for op in ['lt_', 'le_', 'gt_', 'ge_', 'eq_', 'ne_', 'logical_xor_', 'logical_and_', 'logical_or_']: + x = torch.tensor([1], dtype=torch.int) + y = torch.tensor([2], dtype=torch.long) + in_place_method = getattr(x, op) + in_place_method(y) + + # TODO: update to work on CUDA, too + @onlyCPU + def test_comparison_ops_check_for_scalar_overflow(self, device): + s = 1 << 20 + t = torch.tensor([1 << 5], dtype=torch.uint8) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(t < s) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(s < t) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(t <= s) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(s <= t) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(t > s) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(s > t) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(t >= s) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(s >= t) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(t == s) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(s == t) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(t != s) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(s != t) + + # TODO: update to work on CUDA, too + @onlyCPU + def test_comparison_ops_check_for_zerodim_tensor_overflow(self, device): + t1 = torch.tensor([1 << 5], dtype=torch.uint8) + t2 = torch.tensor([1 << 30], dtype=torch.int32) + ts1 = torch.tensor(1 << 20, dtype=torch.int32) + ts2 = torch.tensor(1 << 40, dtype=torch.int64) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(t1 < ts1) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(ts2 < t2) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(t1 <= ts1) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(ts2 <= t2) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(t1 > ts1) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(ts2 > t2) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(t1 >= ts1) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(ts2 >= t2) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(t1 == ts1) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(ts2 == t2) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(t1 != ts1) + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + self.assertTrue(ts2 != t2) + + # TODO: update to work on CUDA, too + @onlyCPU + def test_bitwise_ops(self, device): + x = torch.randn(5, 5).gt(0) + y = torch.randn(5, 5).gt(0) + + and_result = x & y + for idx in iter_indices(x): + if and_result[idx]: + self.assertTrue(x[idx] and y[idx]) + else: + self.assertFalse(x[idx] and y[idx]) + + or_result = x | y + for idx in iter_indices(x): + if or_result[idx]: + self.assertTrue(x[idx] or y[idx]) + else: + self.assertFalse(x[idx] or y[idx]) + + xor_result = x ^ y + for idx in iter_indices(x): + if xor_result[idx]: + self.assertTrue(x[idx] ^ y[idx]) + else: + self.assertFalse(x[idx] ^ y[idx]) + + x_clone = x.clone() + x_clone &= y + self.assertEqual(x_clone, and_result) + + x_clone = x.clone() + x_clone |= y + self.assertEqual(x_clone, or_result) + + x_clone = x.clone() + x_clone ^= y + self.assertEqual(x_clone, xor_result) + + def test_inplace_division(self, device): + t = torch.rand(5, 5, device=device) + id_before = id(t) + t /= 2 + id_after = id(t) + self.assertEqual(id_before, id_after) + + # TODO: update to run on CUDA -- what is this test even testing? + @onlyCPU + def test_cast_binary_op(self, device): + # Scalar + a = torch.tensor(2) + b = torch.tensor(3) + a_copy = a.clone() + b_copy = b.clone() + + self.assertEqual(torch.tensor(6, dtype=torch.float), a.float() * b) + + self.assertEqualTypeString(a, a_copy) + self.assertEqualTypeString(b, b_copy) + + # Tests that trying to add, inplace, a CUDA tensor to a CPU tensor + # throws the correct error message + @onlyCUDA + def test_cross_device_inplace_error_msg(self, device): + a = torch.tensor(2.) + b = torch.tensor(2., device=device) + with self.assertRaisesRegex(RuntimeError, + "Expected all tensors to be on the same device"): + a += b + + # TODO: refactor this test into a more generic one, it's parked here currently + @onlyOnCPUAndCUDA + def test_out_resize_warning(self, device): + a = torch.tensor((1, 2, 3), device=device, dtype=torch.float32) + b = torch.tensor((4, 5, 6), device=device, dtype=torch.float32) + + unary_inputs = (a,) + binary_inputs = (a, b) + unary_ops = (torch.ceil, torch.exp) + binary_ops = (torch.add, torch.sub) + for op in (unary_ops + binary_ops): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + inputs = unary_inputs if op in unary_ops else binary_inputs + + # No warnings + op(*inputs, out=torch.empty(3, device=device)) + op(*inputs, out=torch.empty(0, device=device)) + self.assertEqual(len(w), 0) + + # Cases that throw warnings + op(*inputs, out=torch.empty(2, device=device)) + self.assertEqual(len(w), 1) + + # Verifies that the inplace dunders (like idiv) actually are in place + @onlyOnCPUAndCUDA + def test_inplace_dunders(self, device): + t = torch.randn((1,), device=device) + expected = t.data_ptr() + t += 1 + t -= 1 + t *= 1 + t /= 1 + t //= 1 + t %= 1 + self.assertEqual(expected, t.data_ptr()) + + def check_internal_mem_overlap(self, inplace_op, num_inputs, + dtype, device, + expected_failure=False): + if isinstance(inplace_op, str): + inplace_op = getattr(torch.Tensor, inplace_op) + input = torch.randn(1, dtype=dtype, device=device).expand(3, 3) + inputs = [input] + [torch.randn_like(input) + for i in range(num_inputs - 1)] + if not expected_failure: + with self.assertRaisesRegex(RuntimeError, 'single memory location'): + inplace_op(*inputs) + else: + with self.assertRaises(AssertionError): + with self.assertRaisesRegex(RuntimeError, 'single memory location'): + inplace_op(*inputs) + + def unary_check_input_output_mem_overlap(self, data, sz, op, + expected_failure=False): + + def _test(op, output, input): + output_exp = torch.empty_like(output) + op(input, out=output_exp) + self.assertEqual(op(input, out=output), output_exp, msg=op.__name__) + + # output is identical to input: + _test(op, output=data[0:sz], input=data[0:sz]) + # output and input are independent: + _test(op, output=data[0:sz], input=data[sz:2 * sz]) + # output partially overlaps with input: + if not expected_failure: + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + _test(op, data[0:sz], data[1:sz + 1]) + else: + with self.assertRaises(AssertionError): + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + _test(op, data[0:sz], data[1:sz + 1]) + + def binary_check_input_output_mem_overlap(self, op, device, + expected_failure=False): + sz = 3 + data = torch.randn(2 * sz, device=device) + other = torch.randn(sz, device=device) + + self.unary_check_input_output_mem_overlap( + data, sz, lambda input, out: op(other, input, out=out), + expected_failure=expected_failure) + + self.unary_check_input_output_mem_overlap( + data, sz, lambda input, out: op(input, other, out=out), + expected_failure=expected_failure) + + @dtypes(torch.double) + def test_binary_op_mem_overlap(self, device, dtype): + ops = [ + ("add", True, True, 'cpu'), + ("add", True, True, 'cuda'), + ("mul", True, True, 'cpu'), + ("mul", True, True, 'cuda'), + ("sub", True, True, 'cpu'), + ("sub", True, True, 'cuda'), + ("div", True, True, 'cpu'), + ("div", True, True, 'cuda'), + ("pow", True, True, 'cpu'), + ("pow", True, True, 'cuda'), + ("fmod", True, True, 'cpu'), + ("fmod", True, True, 'cuda'), + ("atan2", True, True, 'cpu'), + ("atan2", True, True, 'cuda'), + ("hypot", True, True, 'cpu'), + ("hypot", True, True, 'cuda'), + ("igamma", True, True, 'cpu'), + ("igamma", True, True, 'cuda'), + ("igammac", True, True, 'cpu'), + ("igammac", True, True, 'cuda'), + ("nextafter", True, True, 'cpu'), + ("nextafter", True, True, 'cuda'), + ("le", True, True, 'cpu'), + ("le", True, True, 'cuda'), + ("lt", True, True, 'cpu'), + ("lt", True, True, 'cuda'), + ("ge", True, True, 'cpu'), + ("ge", True, True, 'cuda'), + ("gt", True, True, 'cpu'), + ("gt", True, True, 'cuda'), + ("eq", True, True, 'cpu'), + ("eq", True, True, 'cuda'), + ("ne", True, True, 'cpu'), + ("ne", True, True, 'cuda'), + ("logical_and", True, True, 'cpu'), + ("logical_and", True, True, 'cuda'), + ("logical_or", True, True, 'cpu'), + ("logical_or", True, True, 'cuda'), + ("logical_xor", True, True, 'cpu'), + ("logical_xor", True, True, 'cuda'), + ] + + for (fn, has_input_output_mem_overlap_check, + has_internal_mem_overlap_check, dev) in ops: + if dev != device: + continue + out_op = getattr(torch, fn) + inplace_op = getattr(torch.Tensor, fn + '_') + self.check_internal_mem_overlap( + inplace_op, 2, dtype, device, + expected_failure=not has_internal_mem_overlap_check) + + self.binary_check_input_output_mem_overlap(out_op, device, + expected_failure=not has_input_output_mem_overlap_check) + + def _do_pow_for_exponents(self, m1, exponents, pow_fn, atol): + for num in exponents: + if isinstance(num, int) and num < 0 and not m1.is_floating_point() and not m1.is_complex(): + with self.assertRaisesRegex(RuntimeError, + r'Integers to negative integer powers are not allowed\.'): + torch.pow(m1[4], num) + else: + # base - tensor, exponent - number + # contiguous + res1 = torch.pow(m1[4], num) + res2 = res1.clone().zero_() + # `math.pow` has issues with complex exponentiation so we need to resort to normal `pow`. + for i in range(res2.size(0)): + res2[i] = pow_fn(m1[4][i], num) + rtol = 0 if atol is not None else None + self.assertEqual(res1, res2, atol=atol, rtol=rtol) + + # non-contiguous + res1 = torch.pow(m1[:, 4], num) + res2 = res1.clone().zero_() + for i in range(res2.size(0)): + res2[i] = pow_fn(m1[i, 4], num) + self.assertEqual(res1, res2, atol=atol, rtol=rtol) + + # scalar ** tensor to enforce correct handling of dtypes for __rpow__(). + expected_dtype = torch.result_type(num, m1) + res1 = num ** m1[4] + res2 = torch.tensor(num, dtype=expected_dtype, device=m1.device) ** m1[4] + self.assertEqual(res1, res2) + self.assertEqual(res1.dtype, expected_dtype) + + def test_pow(self, device): + # [res] torch.pow([res,] x) + + # pow has dedicated implementation for different exponents + for dtype in torch.testing.get_all_math_dtypes(device): + + # This test won't work on torch.half because math.pow will generate a much more accurate result. We skip it + # for now. + if dtype == torch.half: + continue + + # deferring to https://github.com/pytorch/pytorch/pull/36793 + if dtype.is_complex: + continue + + m1 = torch.empty(0, dtype=dtype, device=device) + if m1.is_floating_point() or m1.is_complex(): + m1 = torch.rand(100, 100, dtype=dtype, device=device) + 0.5 + else: + # math.pow will overflow and throw exceptions for large integers + range_high = 4 if dtype in (torch.int8, torch.uint8) else 10 + m1 = torch.randint(1, range_high, (100, 100), dtype=dtype, device=device) + + exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3] + complex_exponents = [-2.5j, -1.0j, 0j, 1.0j, 2.5j, 1.0 + 1.0j, -1.0 - 1.5j, 3.3j] + if m1.is_complex(): + self._do_pow_for_exponents(m1, exponents + complex_exponents, pow, 10e-4) + else: + self._do_pow_for_exponents(m1, exponents, math.pow, None) + self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4) + + # base - number, exponent - tensor + # contiguous + res1 = torch.pow(3, m1[4]) + res2 = res1.clone().zero_() + for i in range(res2.size(0)): + res2[i] = math.pow(3, m1[4, i]) + self.assertEqual(res1, res2) + + # non-contiguous + res1 = torch.pow(3, m1[:, 4]) + res2 = res1.clone().zero_() + for i in range(res2.size(0)): + res2[i] = math.pow(3, m1[i][4]) + self.assertEqual(res1, res2) + + # resize behavior for exp == 1 + out = torch.zeros(1, dtype=dtype, device=device) + torch.pow(m1, 1, out=out) + self.assertEqual(out, m1) + + # TODO: refactor all these tests using opinfos properly + def _test_pow(self, base, exponent, np_exponent=None): + if np_exponent is None: + np_exponent = exponent + + def to_np(value): + if isinstance(value, torch.Tensor): + return value.cpu().numpy() + return value + + try: + np_res = np.power(to_np(base), to_np(np_exponent)) + expected = torch.from_numpy(np_res) if isinstance(np_res, np.ndarray) else torch.tensor(np_res, dtype=base.dtype) + except ValueError as e: + err_msg = "Integers to negative integer powers are not allowed." + self.assertEqual(str(e), err_msg) + out = torch.empty_like(base) + test_cases = [ + lambda: base.pow(exponent), + lambda: base.pow_(exponent), + lambda: torch.pow(base, exponent), + lambda: torch.pow(base, exponent, out=out) + ] + for test_case in test_cases: + self.assertRaisesRegex(RuntimeError, err_msg, test_case) + else: + if isinstance(base, torch.Tensor): + actual = base.pow(exponent) + self.assertEqual(actual, expected.to(actual)) + actual = base.clone() + if torch.can_cast(torch.result_type(base, exponent), base.dtype): + actual2 = actual.pow_(exponent) + self.assertEqual(actual, expected) + self.assertEqual(actual2, expected) + else: + self.assertRaisesRegex(RuntimeError, "can't be cast", lambda: actual.pow_(exponent)) + + actual = torch.pow(base, exponent) + self.assertEqual(actual, expected.to(actual)) + + actual2 = torch.pow(base, exponent, out=actual) + self.assertEqual(actual, expected.to(actual)) + self.assertEqual(actual2, expected.to(actual)) + + def test_int_pow(self, device): + + def _test_integral_pow(dt, range, dev): + tensor = torch.tensor((3, 3), dtype=dt, device=dev).random_(*range) + exps = [0, 1, 2, 4, + torch.tensor((3, 3), dtype=dt, device=dev).random_(0, 5)] + for exp in exps: + self._test_pow(tensor, exp) + + _test_integral_pow(torch.int8, (-3, 4), device) + _test_integral_pow(torch.uint8, (0, 4), device) + _test_integral_pow(torch.int16, (-5, 5), device) + _test_integral_pow(torch.int64, (-10, 10), device) + _test_integral_pow(torch.int32, (-10, 10), device) + + def test_int_tensor_pow_neg_ints(self, device): + ints = [torch.iinfo(torch.int32).min, + -3, -2, -1, 0, 1, 2, 3, + torch.iinfo(torch.int32).max] + neg_ints = [torch.iinfo(torch.int32).min, -3, -2, -1] + tensor = torch.tensor(ints, dtype=torch.int32, device=device) + for pow in neg_ints: + self._test_pow(tensor, pow) + + def test_long_tensor_pow_floats(self, device): + ints = [0, 1, 23, 4567] + floats = [0.0, 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0] + tensor = torch.tensor(ints, dtype=torch.int64, device=device) + for pow in floats: + self._test_pow(tensor, pow) + + def test_float_scalar_pow_float_tensor(self, device): + floats = [2.0, -3 / 2, -1.0, -1 / 2, -1 / 3, 0.0, + 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0] + tensor = torch.tensor(floats, dtype=torch.float32, device=device) + for base in floats: + self._test_pow(base, tensor) + + @onlyCUDA + def test_cuda_tensor_pow_scalar_tensor(self, device): + cuda_tensors = [torch.randn((3, 3), device=device), torch.tensor(3.0, device=device)] + scalar_tensors = [torch.tensor(5.0, device='cpu'), torch.tensor(-3), torch.tensor(1)] + for base, exp in product(cuda_tensors, scalar_tensors): + self._test_pow(base, exp) + + @onlyCUDA + def test_cpu_tensor_pow_cuda_scalar_tensor(self, device): + cpu_tensors = [torch.randn((3, 3), device='cpu'), torch.tensor(3.0, device='cpu')] + cuda_tensors = [torch.tensor(5.0, device='cuda'), torch.tensor(-3, device='cuda')] + for base, exp in product(cpu_tensors, cuda_tensors): + regex = 'Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!' + self.assertRaisesRegex(RuntimeError, regex, torch.pow, base, exp) + + @onlyOnCPUAndCUDA + @dtypes(*(torch.testing.get_all_dtypes(include_bool=False, include_bfloat16=False))) + def test_complex_scalar_pow_tensor(self, device, dtype): + complexes = [0.5j, 1. + 1.j, -1.5j, 2.2 - 1.6j, 1 + 0j] + exp = make_tensor((100,), device, dtype, low=-2, high=2) + exp[0] = exp[10] = exp[20] = 0 + for base in complexes: + self._test_pow(base, exp) + + def test_tensor_pow_tensor(self, dev): + def rotate(l, n): + return l[-n:] + l[:-n] + + def test_tensor_pow_tensor(values, torch_type, numpy_type): + vals_tensor = torch.tensor(values, dtype=torch_type, device=dev) + for i in range(len(values)): + pows = rotate(values, i) + pows_tensor = torch.tensor(pows, dtype=torch_type, device=dev) + self._test_pow(vals_tensor, pows_tensor) + + ints = [0, 1, 2, 3] + test_tensor_pow_tensor(ints, torch.int32, np.int32) + test_tensor_pow_tensor(ints, torch.int64, np.int64) + + floats = [-3.0, -2.0, -1.0, -1 / 2, -1 / 3, + 0.0, + 1 / 3, 1 / 2, 1.0, 2.0, 3.0] + test_tensor_pow_tensor(floats, torch.float32, np.float32) + test_tensor_pow_tensor(floats, torch.float64, np.float64) + + def test_logical_xor_with_nontrivial_alignment(self, device): + # test tensor that is not aligned to multiple of 16 bytes + size = 128 + a = (torch.randn(size, device=device) > 0) + b = (torch.randn(size, device=device) > 0) + c = (torch.randn(size, device=device) > 0) + non_trivial_alignment = [1, 2, 4, 8, 15] + for i in non_trivial_alignment: + for j in non_trivial_alignment: + for k in non_trivial_alignment: + a_ = a[i: 100 + i] + b_ = b[j: 100 + j] + c_ = c[k: 100 + k] + torch.logical_xor(a_, b_, out=c_) + for x, y, z in zip(a_.tolist(), b_.tolist(), c_.tolist()): + self.assertEqual(x ^ y, z) + + @dtypes(torch.float) + def test_add_with_tail(self, device, dtype): + # test tensor where there is a tail which is not a multiple + # of GPU warp size + for tail_size in [1, 63, 67, 130]: + size = 4096 + tail_size + a = torch.randn(size, device=device, dtype=dtype) + b = torch.randn(size, device=device, dtype=dtype) + c = a + b + for x, y, z in zip(a.tolist(), b.tolist(), c.tolist()): + self.assertEqual(x + y, z) + + # Tests that CUDA tensors on different devices cannot be used in the same + # binary operation, and that CUDA "scalars" cannot be used in the same + # binary operation as non-scalar CPU tensors. + @deviceCountAtLeast(2) + @onlyCUDA + def test_cross_device_binary_ops(self, devices): + vals = (1., (2.,)) + cpu_tensor = torch.randn(2, 2) + for op in (operator.add, torch.add, + operator.sub, torch.sub, + operator.mul, torch.mul, + operator.truediv, torch.true_divide, + operator.floordiv, torch.floor_divide): + for a, b in product(vals, vals): + a = torch.tensor(a, device=devices[0]) + b = torch.tensor(b, device=devices[1]) + + with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"): + op(a, b) + with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"): + op(b, a) + with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"): + op(a, cpu_tensor) + with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"): + op(cpu_tensor, a) + + # This test ensures that a scalar Tensor can be safely used + # in a binary operation in conjunction with a Tensor on all + # available CUDA devices + @deviceCountAtLeast(2) + @onlyCUDA + def test_binary_op_scalar_device_unspecified(self, devices): + scalar_val = torch.tensor(1.) + for default_device in devices: + with torch.cuda.device(default_device): + for device in devices: + device_obj = torch.device(device) + x = torch.rand(3, device=device) + y0 = x * scalar_val + self.assertEqual(y0.device, device_obj) + y1 = scalar_val * x + self.assertEqual(y1.device, device_obj) + self.assertEqual(y0, y1) + + def test_div_and_floordiv_vs_python(self, device): + # Tests torch division ops which can handle both arguments being + # scalars. + # NOTE: torch.floor_divide currently truncates instead of flooring. + # the quotient. See https://github.com/pytorch/pytorch/issues/43874. + def _scalar_helper(python_op, torch_op): + for a, b in product(range(-10, 10), range(-10, 10)): + for op in (lambda x: x * .5, lambda x: math.floor(x)): + a = op(a) + b = op(b) + + # Skips zero divisors + if b == 0: + continue + + expected = python_op(a, b) + + for op in (operator.truediv, torch.true_divide): + actual_scalar = torch_op(a, b) + + a_t = torch.tensor(a, device=device) + b_t = torch.tensor(b, device=device) + + actual_tensor = torch_op(a_t, b_t) + actual_first_tensor = torch_op(a_t, b) + actual_second_tensor = torch_op(a, b_t) + + self.assertEqual(actual_scalar, expected_div) + self.assertEqual(actual_tensor.item(), expected_div) + self.assertEqual(actual_first_tensor, actual_tensor) + self.assertEqual(actual_second_tensor, actual_tensor) + + _scalar_helper(operator.truediv, operator.truediv) + _scalar_helper(operator.truediv, torch.true_divide) + _scalar_helper(lambda a, b: math.trunc(a / b), operator.floordiv) + _scalar_helper(lambda a, b: math.trunc(a / b), torch.floor_divide) + + # NOTE: torch.floor_divide currently truncates instead of flooring. + # See https://github.com/pytorch/pytorch/issues/43874. + @onlyOnCPUAndCUDA + def test_div_and_floordiv_script_vs_python(self, device): + # Creates jitted functions of two tensors + def _wrapped_div(a, b): + return a / b + + def _wrapped_floordiv(a, b): + return a // b + + scripted_div = torch.jit.script(_wrapped_div) + scripted_floordiv = torch.jit.script(_wrapped_floordiv) + for a, b in product(range(-10, 10), range(-10, 10)): + for op in (lambda x: x * .5, lambda x: math.floor(x)): + a = op(a) + b = op(b) + + # Skips zero divisors + if b == 0: + continue + + expected_div = a / b + expected_truncdiv = math.trunc(a / b) + a_t = torch.tensor(a, device=device) + b_t = torch.tensor(b, device=device) + + self.assertEqual(scripted_div(a_t, b_t), expected_div) + self.assertEqual(scripted_floordiv(a_t, b_t), expected_truncdiv) + + # Creates jitted functions of one tensor + def _wrapped_div_scalar(a): + return a / 5 + + # NOTE: the JIT implements division as torch.reciprocal(a) * 5 + def _wrapped_rdiv_scalar(a): + return 5 / a + + def _wrapped_floordiv_scalar(a): + return a // 5 + + # NOTE: this fails if the input is not an integer tensor + # See https://github.com/pytorch/pytorch/issues/45199 + def _wrapped_rfloordiv_scalar(a): + return 5 // a + + scripted_div_scalar = torch.jit.script(_wrapped_div_scalar) + scripted_rdiv_scalar = torch.jit.script(_wrapped_rdiv_scalar) + scripted_floordiv_scalar = torch.jit.script(_wrapped_floordiv_scalar) + scripted_rfloordiv_scalar = torch.jit.script(_wrapped_rfloordiv_scalar) + + for a in range(-10, 10): + for op in (lambda x: x * .5, lambda x: math.floor(x)): + a = op(a) + + a_t = torch.tensor(a, device=device) + + self.assertEqual(a / 5, scripted_div_scalar(a_t)) + self.assertEqual(math.trunc(a / 5), scripted_floordiv_scalar(a_t)) + + # Skips zero divisors + if a == 0: + continue + + self.assertEqual(5 / a, scripted_rdiv_scalar(a_t)) + + # Handles Issue 45199 (see comment above) + if a_t.is_floating_point(): + with self.assertRaises(RuntimeError): + scripted_rfloordiv_scalar(a_t) + else: + self.assertEqual(5 // a, scripted_rfloordiv_scalar(a_t)) + + # NOTE: torch.floor_divide currently truncates instead of flooring + # the quotient. See https://github.com/pytorch/pytorch/issues/43874. + @onlyOnCPUAndCUDA + def test_idiv_and_ifloordiv_vs_python(self, device): + def _wrapped_idiv_tensor(a, b): + a /= b + return a + + def _wrapped_idiv_scalar(a): + a /= 5 + return a + + def _wrapped_true_divide__tensor(a, b): + a.true_divide_(b) + return a + + def _wrapped_true_divide__scalar(a): + a.true_divide_(5) + return a + + def _wrapped_floor_divide__tensor(a, b): + a.floor_divide_(b) + return a + + def _wrapped_floor_divide__scalar(a): + a.floor_divide_(5) + return a + + # The following functions are unsupported by the JIT + def _wrapped_ifloordiv_tensor(a, b): + a //= b + return a + + def _wrapped_ifloordiv_scalar(a): + a //= 5 + return a + + with self.assertRaises(torch.jit.frontend.NotSupportedError): + scripted_ifloordiv_tensor = torch.jit.script(_wrapped_ifloordiv_tensor) + + with self.assertRaises(torch.jit.frontend.NotSupportedError): + scripted_ifloordiv_scalar = torch.jit.script(_wrapped_ifloordiv_scalar) + + scripted_idiv_tensor = torch.jit.script(_wrapped_idiv_tensor) + scripted_idiv_scalar = torch.jit.script(_wrapped_idiv_scalar) + scripted_true_divide__tensor = torch.jit.script(_wrapped_true_divide__tensor) + scripted_true_divide__scalar = torch.jit.script(_wrapped_true_divide__scalar) + scripted_floor_divide__tensor = torch.jit.script(_wrapped_floor_divide__tensor) + scripted_floor_divide__scalar = torch.jit.script(_wrapped_floor_divide__scalar) + + for a, b in product(range(-10, 10), range(-10, 10)): + for op in (lambda x: x * .5, lambda x: math.floor(x)): + a = op(a) + b = op(b) + + # Skips zero divisors + if b == 0: + continue + + expected_idiv = a / b + expected_ifloordiv = a // b + expected_itruncdiv = math.trunc(a / b) + + a_t = torch.tensor(a, device=device) + b_t = torch.tensor(b, device=device) + + if a_t.is_floating_point(): + tmp0 = a_t.clone() + tmp0 /= b + + tmp1 = a_t.clone() + tmp1 /= b_t + + self.assertEqual(tmp0.item(), expected_idiv) + self.assertEqual(tmp1.item(), expected_idiv) + self.assertEqual(scripted_true_divide__tensor(a_t.clone(), b_t).item(), expected_idiv) + self.assertEqual(scripted_true_divide__scalar(a_t.clone()).item(), a / 5) + else: + tmp = a_t.clone() + with self.assertRaises(RuntimeError): + tmp /= b + with self.assertRaises(RuntimeError): + tmp /= b_t + with self.assertRaises(RuntimeError): + scripted_true_divide__tensor(tmp, b_t) + with self.assertRaises(RuntimeError): + scripted_true_divide__scalar(tmp) + + + if not a_t.is_floating_point() and b_t.is_floating_point(): + # Inplace modification fails because a float tensor is required + # if the divisor is a float tensor + with self.assertRaises(RuntimeError): + a_t.clone().floor_divide_(b_t) + with self.assertRaises(RuntimeError): + scripted_floor_divide_tensor(a_t.clone(), b_t) + tmp = a_t.clone() + with self.assertRaises(RuntimeError): + tmp //= b_t + else: + # Inplace modification is OK when both or neither tensor is + # a float tensor + self.assertEqual(a_t.clone().floor_divide_(b_t).item(), expected_itruncdiv) + self.assertEqual(scripted_floor_divide__tensor(a_t.clone(), b_t).item(), expected_itruncdiv) + tmp = a_t.clone() + tmp //= b_t + self.assertEqual(tmp.item(), expected_itruncdiv) + + self.assertEqual(scripted_floor_divide__scalar(a_t), math.trunc(a / 5)) + + # Tests binary op equivalence with Python builtin ops + # Also tests that reverse operations are equivalent to forward ops + # NOTE: division ops are tested separately above + def test_binary_ops_with_scalars(self, device): + for ops in ((operator.add, torch.add), + (operator.sub, torch.sub), + (operator.mul, torch.mul), + (operator.truediv, torch.div)): + python_op, torch_op = ops + + for a, b in product(range(-10, 10), range(-10, 10)): + for op in (lambda x: x * .5, lambda x: math.floor(x)): + a = op(a) + b = op(b) + + # Skips zero divisors + if b == 0 or a == 0: + continue + + a_tensor = torch.tensor(a, device=device) + b_tensor = torch.tensor(b, device=device) + a_tensor_cpu = a_tensor.cpu() + b_tensor_cpu = b_tensor.cpu() + vals = (a, b, a_tensor, b_tensor, a_tensor_cpu, b_tensor_cpu) + + for args in product(vals, vals): + first, second = args + + first_scalar = first if not isinstance(first, torch.Tensor) else first.item() + second_scalar = second if not isinstance(second, torch.Tensor) else second.item() + expected = python_op(first_scalar, second_scalar) + + self.assertEqual(expected, python_op(first, second)) + self.assertEqual(expected, torch_op(first, second)) + + @dtypes(*product(torch.testing.get_all_dtypes(include_complex=False), torch.testing.get_all_dtypes(include_complex=False))) + def test_maximum_minimum_type_promotion(self, device, dtypes): + a = torch.tensor((0, 1), device=device, dtype=dtypes[0]) + b = torch.tensor((1, 0), device=device, dtype=dtypes[1]) + for op in (torch.maximum, torch.max, torch.minimum, torch.min): + result = op(a, b) + self.assertEqual(result.dtype, torch.result_type(a, b)) + + @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool])) + def test_maximum_minimum_int_and_bool(self, device, dtype): + ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum)) + rng = np.random.default_rng() + a_np = np.array(rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype]) + b_np = np.array(rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype]) + + for torch_op, alias, numpy_op in ops: + a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype) + b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype) + tensor_result = torch_op(a_tensor, b_tensor) + alias_result = alias(a_tensor, b_tensor) + + out = torch.empty_like(a_tensor) + torch_op(a_tensor, b_tensor, out=out) + + numpy_result = numpy_op(a_np, b_np) + + self.assertEqual(alias_result, tensor_result) + self.assertEqual(tensor_result, numpy_result) + self.assertEqual(out, numpy_result) + + @precisionOverride({torch.bfloat16: 1e-2}) + @dtypes(*(torch.testing.get_all_fp_dtypes())) + def test_maximum_minimum_float(self, device, dtype): + ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum)) + + if dtype == torch.bfloat16: + a_np = np.random.randn(10).astype(np.float64) + b_np = np.random.randn(10).astype(np.float64) + else: + a_np = np.random.randn(10).astype(torch_to_numpy_dtype_dict[dtype]) + b_np = np.random.randn(10).astype(torch_to_numpy_dtype_dict[dtype]) + + for torch_op, alias, numpy_op in ops: + numpy_result = numpy_op(a_np, b_np) + + a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype) + b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype) + tensor_result = torch_op(a_tensor, b_tensor) + alias_result = alias(a_tensor, b_tensor) + out = torch.empty_like(a_tensor) + torch_op(a_tensor, b_tensor, out=out) + + self.assertEqual(alias_result, tensor_result) + self.assertEqual(tensor_result, numpy_result) + self.assertEqual(out, numpy_result) + + @dtypes(*(torch.testing.get_all_fp_dtypes())) + def test_maximum_minimum_float_nan_and_inf(self, device, dtype): + # np.maximum and np.minimum functions compare input arrays element-wisely. + # if one of the elements being compared is a NaN, then that element is returned. + ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum)) + a_vals = (float('inf'), -float('inf'), float('nan'), float('nan')) + b_vals = (-float('inf'), float('inf'), float('inf'), float('nan')) + if dtype == torch.bfloat16: + a_np = np.array(a_vals, dtype=np.float64) + b_np = np.array(b_vals, dtype=np.float64) + else: + a_np = np.array(a_vals, dtype=torch_to_numpy_dtype_dict[dtype]) + b_np = np.array(b_vals, dtype=torch_to_numpy_dtype_dict[dtype]) + + for torch_op, alias, numpy_op in ops: + numpy_result = numpy_op(a_np, b_np) + + a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype) + b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype) + tensor_result = torch_op(a_tensor, b_tensor) + alias_result = alias(a_tensor, b_tensor) + + out = torch.empty_like(a_tensor) + torch_op(a_tensor, b_tensor, out=out) + + self.assertEqual(alias_result, tensor_result) + if dtype == torch.bfloat16: + self.assertEqual(tensor_result, numpy_result, exact_dtype=False) + self.assertEqual(out, numpy_result, exact_dtype=False) + else: + self.assertEqual(tensor_result, numpy_result) + self.assertEqual(out, numpy_result) + + @dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes())) + def test_maximum_minimum_complex(self, device, dtypes): + for torch_op in (torch.maximum, torch.minimum, torch.max, torch.min): + with self.assertRaisesRegex(RuntimeError, '.+not implemented for.+'): + torch_op(torch.ones(1, device=device, dtype=dtypes[0]), + torch.ones(1, device=device, dtype=dtypes[1])) + + with self.assertRaisesRegex(RuntimeError, '.+not implemented for.+'): + torch_op(torch.ones(1, device=device, dtype=dtypes[1]), + torch.ones(1, device=device, dtype=dtypes[0])) + + @onlyCUDA + def test_maximum_minimum_cross_device(self, device): + a = torch.tensor((1, 2, -1)) + b = torch.tensor((3, 0, 4), device=device) + ops = (torch.maximum, torch.minimum) + + for torch_op in ops: + with self.assertRaisesRegex(RuntimeError, + "Expected all tensors to be on the same device"): + torch_op(a, b) + + with self.assertRaisesRegex(RuntimeError, + "Expected all tensors to be on the same device"): + torch_op(b, a) + + # test cuda tensor and cpu scalar + ops = ((torch.maximum, np.maximum), (torch.minimum, np.minimum)) + a_np = np.array(1) + b_np = np.array([3, 0, 4]) + + for torch_op, numpy_op in ops: + a_tensor = torch.from_numpy(a_np) + b_tensor = torch.from_numpy(b_np).to(device=device) + tensor_result_1 = torch_op(a_tensor, b_tensor) + numpy_result_1 = numpy_op(a_np, b_np) + tensor_result_2 = torch_op(b_tensor, a_tensor) + numpy_result_2 = numpy_op(b_np, a_np) + + self.assertEqual(tensor_result_1, numpy_result_1) + self.assertEqual(tensor_result_2, numpy_result_2) + + # TODO: tests like this should be generic + @dtypesIfCUDA(torch.half, torch.float, torch.double) + @dtypes(torch.float, torch.double) + def test_mul_intertype_scalar(self, device, dtype): + x = torch.tensor(1.5, dtype=dtype, device=device) + y = torch.tensor(3, dtype=torch.int32, device=device) + + self.assertEqual(x * y, 4.5) + self.assertEqual(y * x, 4.5) + + with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"): + y *= x + x *= y + self.assertEqual(x, 4.5) + + @onlyCPU + @dtypes(*torch.testing.get_all_dtypes()) + def test_sub(self, device, dtype): + m1 = torch.tensor([2.34, 4.44], dtype=dtype, device=device) + m2 = torch.tensor([1.23, 2.33], dtype=dtype, device=device) + + if dtype == torch.bool: + self.assertRaises(RuntimeError, lambda: m1 - m2) + elif (dtype == torch.bfloat16 or dtype == torch.half): + # bfloat16 has a lower precision so we have to have a separate check for it + self.assertEqual(m1 - m2, torch.tensor([1.11, 2.11], dtype=dtype), atol=0.01, rtol=0) + else: + self.assertEqual(m1 - m2, torch.tensor([1.11, 2.11], dtype=dtype)) + + # TODO: what is this test testing? + @onlyCPU + @dtypes(torch.float) + def test_csub(self, device, dtype): + # with a tensor + a = torch.randn(100, 90, dtype=dtype, device=device) + b = a.clone().normal_() + + res_add = torch.add(a, b, alpha=-1) + res_csub = a.clone() + res_csub.sub_(b) + self.assertEqual(res_add, res_csub) + + # with a scalar + a = torch.randn(100, 100, dtype=dtype, device=device) + + scalar = 123.5 + res_add = torch.add(a, -scalar) + res_csub = a.clone() + res_csub.sub_(scalar) + self.assertEqual(res_add, res_csub) + + # TODO: reconcile with minimum/maximum tests + @dtypesIfCUDA(torch.half, torch.float, torch.double) + @dtypes(torch.float, torch.double) + def test_min_max_binary_op_nan(self, device, dtype): + a = torch.rand(1000, dtype=dtype, device=device) + b = torch.rand(1000, dtype=dtype, device=device) + + # 0:250: a -- nan, b -- not nan + a[:250] = float('nan') + # 250:500: a -- not nan, b -- nan + b[250:500] = float('nan') + # 500:750: a and b both nan + a[500:750] = float('nan') + b[500:750] = float('nan') + # 750:1000: neither nan + + ma = torch.max(a, b) + mi = torch.min(a, b) + + for i in range(750): + self.assertTrue(torch.isnan(ma[i]), "max(a, b): {}, a: {}, b: {}".format(ma[i], a[i], b[i])) + self.assertTrue(torch.isnan(mi[i]), "min(a, b): {}, a: {}, b: {}".format(mi[i], a[i], b[i])) + + for i in range(750, 1000): + self.assertFalse(torch.isnan(ma[i]), "max(a, b): {}, a: {}, b: {}".format(ma[i], a[i], b[i])) + self.assertFalse(torch.isnan(mi[i]), "min(a, b): {}, a: {}, b: {}".format(mi[i], a[i], b[i])) + + @dtypes(*product(torch.testing.get_all_dtypes(include_complex=False), + torch.testing.get_all_dtypes(include_complex=False))) + def test_copysign(self, device, dtypes): + def _test_copysign_numpy(a, b): + torch_result = torch.copysign(a, b) + + if a.dtype == torch.bfloat16: + np_a = a.to(torch.float).cpu().numpy() + else: + np_a = a.cpu().numpy() + + if b.dtype == torch.bfloat16: + np_b = b.to(torch.float).cpu().numpy() + else: + np_b = b.cpu().numpy() + expected = torch.from_numpy(np.copysign(np_a, np_b)) + # To handle inconsistencies of type promotion between PyTorch and Numpy + # Applied for both arguments having integral precision and bfloat16 + types = [torch.bool, torch.bfloat16] + torch.testing.get_all_int_dtypes() + if a.dtype in types or b.dtype in types: + promoted_type = torch.promote_types(torch_result.dtype, expected.dtype) + torch_result = torch_result.to(promoted_type) + expected = expected.to(promoted_type) + + # Verify Value + self.assertEqual(torch_result, expected) + # Verify Sign + # Use double copysign to verify the correctnes of 0.0 and -0.0, since + # it always True for self.assertEqual(0.0 == -0.0). So, we use 1 as the + # magnitude to verify the sign between torch and numpy results, elementwise. + # Special case: NaN conversions between FP32 and FP16 is not bitwise + # equivalent to pass this assertion. + if a.dtype != torch.float16 and b.dtype != torch.float16: + self.assertEqual(torch.copysign(torch.tensor(1.0), torch_result), + torch.copysign(torch.tensor(1.0), expected)) + + # Compare Result with NumPy + # Type promotion + a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9) + b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9) + _test_copysign_numpy(a, b) + + # Broadcast + a = make_tensor((10, 1, 10), device=device, dtype=dtypes[0], low=-9, high=9) + b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9) + _test_copysign_numpy(a, b) + + a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9) + b = make_tensor((10, 1, 10), device=device, dtype=dtypes[1], low=-9, high=9) + _test_copysign_numpy(a, b) + + # 0.0/-0.0/inf/-inf/nan + cases = [0.0, -0.0, float('inf'), float('-inf'), float('nan')] + # torch.bfloat16 can not hold '-nan' + # torch.half can not hold '-nan' on CUDA + types = [torch.float32, torch.float64] + if device == 'cpu': + types.append(torch.float16) + if dtypes[0] in types: + b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9) + for case in cases: + _test_copysign_numpy(torch.tensor([case], device=device, dtype=dtypes[0]), b) + + if dtypes[1] in torch.testing.get_all_fp_dtypes(): + a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9) + for case in cases: + _test_copysign_numpy(a, torch.tensor([case], device=device, dtype=dtypes[1])) + + @dtypes(torch.bfloat16, torch.float) + def test_div(self, device, dtype): + for op, method, inplace in ((torch.div, torch.Tensor.div, torch.Tensor.div_), + (torch.true_divide, torch.Tensor.true_divide, + torch.Tensor.true_divide_)): + m1 = torch.randn(10, 10, dtype=torch.float, device=device).to(dtype=dtype) + res1 = m1.clone() + inplace(res1[:, 3], 2) + res2 = m1.clone() + for i in range(m1.size(0)): + res2[i, 3] = res2[i, 3] / 2 + self.assertEqual(res1, res2) + + if dtype == torch.bfloat16: + a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device) + a2 = torch.tensor([2., 2.], dtype=dtype, device=device) + self.assertEqual(op(a1, a2), + torch.tensor([2.1, 3.1], dtype=dtype, device=device), + atol=0.01, rtol=0) + self.assertEqual(method(a1, a2), op(a1, a2)) + + @dtypes(torch.bfloat16, torch.float) + def test_true_divide_out(self, device, dtype): + a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device) + a2 = torch.tensor([2., 2.], dtype=dtype, device=device) + res = torch.empty_like(a1) + self.assertEqual(torch.true_divide(a1, a2, out=res), + torch.tensor([2.1, 3.1], dtype=dtype, device=device), + atol=0.01, rtol=0) + + @onlyCUDA + @dtypes(torch.half) + def test_divmul_scalar(self, device, dtype): + x = torch.tensor(100., device=device, dtype=dtype) + x_ref = x.float() + scale = 1e5 + res = x.div(scale) + expected = x_ref.div(scale) + self.assertEqual(res, expected.to(dtype), atol=0., rtol=0.) + x = torch.tensor(1e-5, device=device, dtype=dtype) + x_ref = x.float() + res = x.mul(scale) + expected = x_ref.mul(scale) + self.assertEqual(res, expected.to(dtype), atol=0., rtol=0.) + res = scale * x + self.assertEqual(res, expected.to(dtype), atol=0., rtol=0.) + + @dtypesIfCUDA(*set(torch.testing.get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128}) + @dtypes(*set(torch.testing.get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128}) + def test_floor_divide_tensor(self, device, dtype): + x = torch.randn(10, device=device).mul(30).to(dtype) + y = torch.arange(1, 11, dtype=dtype, device=device) + + z = x // y + z_alt = torch.trunc(x.double() / y.double()).to(dtype) + + self.assertEqual(z.dtype, x.dtype) + self.assertEqual(z, z_alt) + + @dtypesIfCUDA(*set(torch.testing.get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128}) + @dtypes(*set(torch.testing.get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128}) + def test_floor_divide_scalar(self, device, dtype): + x = torch.randn(100, device=device).mul(10).to(dtype) + + z = x // 3 + z_alt = torch.tensor([math.trunc(v.item() / 3.) for v in x], dtype=x.dtype, device=device) + + self.assertEqual(z.dtype, x.dtype) + self.assertEqual(z, z_alt) + + # Note: this tests fails on XLA + @onlyOnCPUAndCUDA + @dtypes(torch.float, torch.long) + def test_floor_divide_out(self, device, dtype): + x = torch.randn(10, device=device).mul(10).to(dtype) + y = torch.arange(1, 11, dtype=dtype, device=device) + o = torch.empty(10, dtype=dtype, device=device) + + torch.floor_divide(x, y, out=o) + self.assertEqual(o, x // y) + + # Tests scalar with out + torch.floor_divide(x, 2, out=o) + self.assertEqual(o, x // 2) + + if dtype == torch.int: + o = torch.empty(10, dtype=torch.float, device=device) + torch.floor_divide(x, y, out=o) + self.assertEqual(o, torch.floor_divide(x.float(), y.float())) + + @onlyCPU + @dtypes(*torch.testing.get_all_math_dtypes('cpu')) + def test_rdiv(self, device, dtype): + if dtype is torch.float16: + return + elif dtype.is_complex: + x = torch.rand(100, dtype=dtype, device=device).add(1).mul(4) + else: + x = torch.rand(100, device=device).add(1).mul(4).to(dtype) + y = 30 / x + z = torch.tensor([30 / v.item() for v in x], device=device) + self.assertEqual(y, z, exact_dtype=False) + + @dtypes(*torch.testing.get_all_fp_dtypes(include_bfloat16=False)) + def test_fmod_remainder_by_zero_float(self, device, dtype): + fn_list = (torch.fmod, torch.remainder) + for fn in fn_list: + # check floating-point tensor fmod/remainder to zero is nan on both CPU and GPU + x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9) + zero = torch.zeros_like(x) + self.assertTrue(torch.all(fn(x, 0.0).isnan())) + self.assertTrue(torch.all(fn(x, zero).isnan())) + + @onlyOnCPUAndCUDA # Check Issue https://github.com/pytorch/pytorch/issues/48130 + @skipCUDAIfRocm # Error happens on both ROCM and XLA + @dtypes(*torch.testing.get_all_int_dtypes()) + def test_fmod_remainder_by_zero_integral(self, device, dtype): + fn_list = (torch.fmod, torch.remainder) + for fn in fn_list: + # check integral tensor fmod/remainder to zero + x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9) + zero = torch.zeros_like(x) + # RuntimeError on CPU + if self.device_type == 'cpu': + with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"): + fn(x, zero) + # Different value for different dtype on CUDA: + # Due to it's an undefined behavior, CUDA returns a pattern of all 1s + # for integral dividend (other than int64) divided by zero. For int64, + # CUDA returns all 1s for negative dividend, half 1s for positive dividend. + # uint8: 0xff -> 255 + # int32: 0xffffffff -> -1 + else: + if dtype == torch.int64: + self.assertEqual(fn(x, zero) == 4294967295, x >= 0) + self.assertEqual(fn(x, zero) == -1, x < 0) + else: + value = 255 if dtype == torch.uint8 else -1 + self.assertTrue(torch.all(fn(x, zero) == value)) + + @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False)) + def test_fmod_remainder(self, device, dtype): + # Use numpy as reference + def _helper(x, mod): + fns_list = ((torch.fmod, torch.Tensor.fmod_, np.fmod), + (torch.remainder, torch.Tensor.remainder_, np.remainder)) + for fn, inplace_fn, ref_fn in fns_list: + np_x = x.cpu().numpy() + np_mod = mod.cpu().numpy() if torch.is_tensor(mod) else mod + exp = ref_fn(np_x, np_mod) + exp = torch.from_numpy(exp) + res = fn(x, mod) + + self.assertEqual(res, exp, exact_dtype=False) + # out + out = torch.empty(0, device=device, dtype=res.dtype) + fn(x, mod, out=out) + self.assertEqual(out, exp, exact_dtype=False) + self.assertEqual(out.size(), torch.Size([10, 10])) + # in-place (Type cast runtime error) + try: + inplace_fn(x, mod) + self.assertEqual(x, exp, exact_dtype=False) + except RuntimeError as e: + self.assertRegex(str(e), "result type (Half|Float|Double) " + "can't be cast to the desired output " + "type (Byte|Char|Short|Int|Long)") + + x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9) + # mod with same dtype as x + mod = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9) + # Exclude 0 + mod[mod == 0] = 1 + + # Mods: Integer, Float, Tensor, Non-contiguous Tensor + mods = [3, 2.3, mod, mod.t()] + # mod with floating-point dtype + if dtype in torch.testing.get_all_int_dtypes(): + mod_float = make_tensor((10, 10), device=device, dtype=torch.float, low=-9, high=9) + mod[mod == 0] = 1 + mods.append(mod_float) + + for dividend, mod in product([x, x.t()], mods): + _helper(dividend, mod) + + @dtypes(torch.float, torch.double) + def test_remainder_fmod_large_dividend(self, device, dtype): + alarge = 1e9 + pi = 3.14159265358979 + for avalue in [alarge, -alarge]: + for bvalue in [pi, -pi]: + a = torch.tensor([avalue], dtype=dtype, device=device) + b = torch.tensor([bvalue], dtype=dtype, device=device) + c = torch.remainder(a, b) + d = torch.fmod(a, b) + self.assertTrue((b[0] > 0) == (c[0] > 0)) # remainder has same sign as divisor + self.assertTrue((a[0] > 0) == (d[0] > 0)) # fmod has same sign as dividend + self.assertTrue(abs(c[0]) < abs(b[0])) # remainder is within range of divisor + self.assertTrue(abs(d[0]) < abs(b[0])) # fmod is within range of divisor + if ((a[0] > 0) == (b[0] > 0)): + self.assertTrue(c[0] == d[0]) # remainder is same as fmod + else: + self.assertTrue(abs(c[0] - d[0]) == abs(b[0])) # differ by one divisor + + @dtypesIfCPU(torch.bfloat16, torch.float32, torch.float64) + @dtypes(torch.float32, torch.float64) + def test_hypot(self, device, dtype): + inputs = [ + (torch.randn(10, device=device).to(dtype), torch.randn(10, device=device).to(dtype)), + (torch.randn((3, 3, 3), device=device).to(dtype), torch.randn((3, 3, 3), device=device).to(dtype)), + (torch.randn((10, 1), device=device).to(dtype), torch.randn((10, 1), device=device).to(dtype).transpose(0, 1)), + (torch.randint(100, (10, ), device=device, dtype=torch.long), torch.randn(10, device=device).to(dtype)) + ] + for input in inputs: + actual = torch.hypot(input[0], input[1]) + if dtype == torch.bfloat16: + expected = torch.sqrt(input[0] * input[0] + input[1] * input[1]) + else: + expected = np.hypot(input[0].cpu().numpy(), input[1].cpu().numpy()) + self.assertEqual(actual, expected) + + @onlyOnCPUAndCUDA + @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) + def test_gcd(self, device, dtype): + # Tests gcd(0, 0), gcd(0, a) cases + t1 = torch.tensor([0, 10, 0], dtype=dtype, device=device) + t2 = torch.tensor([0, 0, 10], dtype=dtype, device=device) + actual = torch.gcd(t1, t2) + expected = np.gcd([0, 10, 0], [0, 0, 10]) + self.assertEqual(actual, expected) + + if dtype == torch.uint8: + # Test unsigned integers with potential sign issues (i.e., uint8 with value >= 128) + a = torch.tensor([190, 210], device=device, dtype=dtype) + b = torch.tensor([190, 220], device=device, dtype=dtype) + actual = torch.gcd(a, b) + expected = torch.tensor([190, 10], device=device, dtype=dtype) + else: + # Compares with NumPy + a = torch.randint(-20, 20, (1024,), device=device, dtype=dtype) + b = torch.randint(-20, 20, (1024,), device=device, dtype=dtype) + actual = torch.gcd(a, b) + expected = np.gcd(a.cpu().numpy(), b.cpu().numpy()) + self.assertEqual(actual, expected) + + @onlyOnCPUAndCUDA + @dtypes(torch.int16, torch.int32, torch.int64) + def test_lcm(self, device, dtype): + # Tests lcm(0, 0), lcm(0, a) cases + t1 = torch.tensor([0, 10, 0], dtype=dtype, device=device) + t2 = torch.tensor([0, 0, 10], dtype=dtype, device=device) + actual = torch.lcm(t1, t2) + expected = np.lcm([0, 10, 0], [0, 0, 10]) + self.assertEqual(actual, expected) + + # Compares with NumPy + a = torch.randint(-20, 20, (1024,), device=device, dtype=dtype) + b = torch.randint(-20, 20, (1024,), device=device, dtype=dtype) + actual = torch.lcm(a, b) + expected = np.lcm(a.cpu().numpy(), b.cpu().numpy()) + self.assertEqual(actual, expected) + + @onlyOnCPUAndCUDA + @dtypes(torch.float32, torch.float64) + def test_nextafter(self, device, dtype): + # Test special cases + t1 = torch.tensor([0, 0, 10], device=device, dtype=dtype) + t2 = torch.tensor([inf, -inf, 10], device=device, dtype=dtype) + actual = torch.nextafter(t1, t2) + expected = np.nextafter(t1.cpu().numpy(), t2.cpu().numpy()) + self.assertEqual(actual, expected, atol=0, rtol=0) + + actual = torch.nextafter(t2, t1) + expected = np.nextafter(t2.cpu().numpy(), t1.cpu().numpy()) + self.assertEqual(actual, expected, atol=0, rtol=0) + + t1 = torch.tensor([0, nan], device=device, dtype=dtype) + t2 = torch.tensor([nan, 0], device=device, dtype=dtype) + self.assertTrue(torch.nextafter(t1, t2).isnan().all()) + + a = torch.randn(100, device=device, dtype=dtype) + b = torch.randn(100, device=device, dtype=dtype) + actual = torch.nextafter(a, b) + expected = np.nextafter(a.cpu().numpy(), b.cpu().numpy()) + self.assertEqual(actual, expected, atol=0, rtol=0) + + def _test_cop(self, torchfn, mathfn, dtype, device): + def reference_implementation(res2): + for i, j in iter_indices(sm1): + idx1d = i * sm1.size(0) + j + res2[i, j] = mathfn(sm1[i, j], sm2[idx1d]) + return res2 + + # contiguous + m1 = torch.randn(10, 10, 10, dtype=dtype, device=device) + m2 = torch.randn(10, 10 * 10, dtype=dtype, device=device) + sm1 = m1[4] + sm2 = m2[4] + + res1 = torchfn(sm1, sm2.view(10, 10)) + res2 = reference_implementation(res1.clone()) + self.assertEqual(res1, res2) + + # non-contiguous + m1 = torch.randn(10, 10, 10, dtype=dtype, device=device) + m2 = torch.randn(10 * 10, 10 * 10, dtype=dtype, device=device) + sm1 = m1[:, 4] + sm2 = m2[:, 4] + # view as sm1.size() + sm2.set_(sm2.storage(), sm2.storage_offset(), sm1.size(), (sm2.stride()[0] * 10, sm2.stride()[0])) + res1 = torchfn(sm1, sm2) + # reference_implementation assumes 1-d sm2 + sm2.set_(sm2.storage(), sm2.storage_offset(), m2[:, 4].size(), m2[:, 4].stride()) + res2 = reference_implementation(res1.clone()) + self.assertEqual(res1, res2) + + @onlyCPU + @dtypes(torch.float) + def test_cdiv(self, device, dtype): + self._test_cop(torch.div, lambda x, y: x / y, dtype, device) + + @onlyCPU + @dtypes(torch.float) + def test_cremainder(self, device, dtype): + self._test_cop(torch.remainder, lambda x, y: x % y, dtype, device) + + @onlyCPU + @dtypes(torch.float) + def test_cmul(self, device, dtype): + self._test_cop(torch.mul, lambda x, y: x * y, dtype, device) + + @onlyCPU + @dtypes(torch.float) + def test_cpow(self, device, dtype): + self._test_cop(torch.pow, lambda x, y: nan if x < 0 else math.pow(x, y), dtype, device) + + @onlyCPU + @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) + def test_floor_divide_zero(self, device, dtype): + a = torch.tensor([0, 1], dtype=dtype, device=device) + b = torch.tensor([0, 1], dtype=dtype, device=device) + with self.assertRaisesRegex(RuntimeError, 'ZeroDivisionError'): + a // b + + @unittest.skipIf(TEST_WITH_ASAN, "Integer overflows are not allowed under ASAN") + @dtypes(*torch.testing.get_all_dtypes()) + def test_muldiv_scalar(self, device, dtype): + x = make_tensor((10, 3), device, dtype, low=None, high=None) + s = make_tensor((1,), 'cpu', dtype, low=None, high=None).item() + y = torch.full_like(x, s) + self.assertEqual(x * s, x * y) + self.assertEqual(s * x, y * x) + self.assertEqual(x / s, x / y) + self.assertEqual(s / x, y / x) + + @dtypes(*tuple(itertools.combinations_with_replacement(torch.testing.get_all_dtypes(), 2))) + def test_comparison_ops_type_promotion_and_broadcasting(self, device, dtypes): + # issue #42660 + # testing all combinations of broadcasting and type promotion + # with a range of dtypes and input shapes, and with extremal values + def compare_with_numpy_bin_op(torch_fn, np_fn, x, y, out=None): + # working around the fact that numpy doesn't support bfloat16 + # by letting numpy treat them as float32's + x_np = x if x.dtype != torch.bfloat16 else x.to(torch.float32) + y_np = y.cpu().numpy() if y.dtype != torch.bfloat16 else y.to(torch.float32).cpu().numpy() + self.compare_with_numpy(lambda inp: torch_fn(inp, y, out=out) if out else torch_fn(inp, y), + lambda inp: np_fn(inp, y_np, out=out) if out else np_fn(inp, y_np), + x_np) + + complex_op_denylist = [torch.lt, torch.le, torch.gt, torch.ge] # complex not supported + input_sizes = [ + (1,), + (10,), + (10, 1), + (1, 10), + (4, 10), + (64, 10), + (12, 3)] + op_pairs = [(torch.lt, np.less), + (torch.le, np.less_equal), + (torch.gt, np.greater), + (torch.ge, np.greater_equal), + (torch.eq, np.equal), + (torch.ne, np.not_equal), + (torch.logical_and, np.logical_and), + (torch.logical_or, np.logical_or), + (torch.logical_xor, np.logical_xor)] + + for size1 in input_sizes: + size2 = (2,) + size1 # perform broadcasting + for with_extremal in [False, True]: + a = _generate_input(size1, dtypes[0], device, with_extremal) + b = _generate_input(size2, dtypes[1], device, with_extremal) + for torch_op, numpy_op in op_pairs: + if (dtypes[0].is_complex or dtypes[1].is_complex) and torch_op in complex_op_denylist: + continue + # functional version of op + compare_with_numpy_bin_op(torch_op, numpy_op, a, b) + + # functional comparison ops always return bool tensors + self.assertEqual(torch_op(a, b).dtype, torch.bool) + + # out version of op + out = torch.zeros(1, dtype=torch.complex128) # all casts to complex128 are safe + compare_with_numpy_bin_op(torch_op, numpy_op, a, b, out=out) + + @onlyOnCPUAndCUDA + @dtypes(torch.int8, torch.int16, torch.int32, torch.int64) + def test_signed_shift(self, device, dtype): + "Ensure that signed integer bit shifting works as expected." + a = torch.tensor([-10, 10], device=device, dtype=dtype) # [11...1110110, 1010] + expected_l = torch.tensor([-40, 40], device=device, dtype=dtype) # [11...11011000, 101000] + self.assertEqual(a << 2, expected_l) + self.compare_with_numpy(lambda x: x << 2, lambda x: np.left_shift(x, 2), a) + expected_r = torch.tensor([-5, 5], device=device, dtype=dtype) # [1111...111011, 101] + self.assertEqual(a >> 1, expected_r) + self.compare_with_numpy(lambda x: x >> 1, lambda x: np.right_shift(x, 1), a) + + def test_bitwise_and(self, device): + for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): + a = torch.tensor([1, -2, 3], dtype=dtype, device=device) + b = torch.tensor([2, 1, 3], dtype=dtype, device=device) + expected_res = torch.tensor([0, 0, 3], dtype=dtype, device=device) + b_scalar = 2 + expected_res_scalar = torch.tensor([0, 2, 2], dtype=dtype, device=device) + + # standard version + self.assertEqual(torch.bitwise_and(a, b), expected_res) + self.assertEqual(torch.bitwise_and(a, b_scalar), expected_res_scalar) + + # out + c = torch.empty(0, dtype=dtype, device=device) + torch.bitwise_and(a, b, out=c) + self.assertEqual(c, expected_res) + torch.bitwise_and(a, b_scalar, out=c) + self.assertEqual(c, expected_res_scalar) + + # in-place + a1 = a.clone() + a1.bitwise_and_(b) + self.assertEqual(a1, expected_res) + a.bitwise_and_(b_scalar) + self.assertEqual(a, expected_res_scalar) + + self.assertEqual(torch.tensor([False, True, False], device=device), + torch.bitwise_and(torch.tensor([True, True, False], device=device), + torch.tensor([False, True, False], device=device))) + + def test_bitwise_or(self, device): + for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): + a = torch.tensor([1, -2, 3], dtype=dtype, device=device) + b = torch.tensor([2, 1, 3], dtype=dtype, device=device) + expected_res = torch.tensor([3, -1, 3], dtype=dtype, device=device) + b_scalar = 2 + expected_res_scalar = torch.tensor([3, -2, 3], dtype=dtype, device=device) + + # standard version + self.assertEqual(torch.bitwise_or(a, b), expected_res) + self.assertEqual(torch.bitwise_or(a, b_scalar), expected_res_scalar) + + # out + c = torch.empty(0, dtype=dtype, device=device) + torch.bitwise_or(a, b, out=c) + self.assertEqual(c, expected_res) + torch.bitwise_or(a, b_scalar, out=c) + self.assertEqual(c, expected_res_scalar) + + # in-place + a1 = a.clone() + a1.bitwise_or_(b) + self.assertEqual(a1, expected_res) + a.bitwise_or_(b_scalar) + self.assertEqual(a, expected_res_scalar) + + self.assertEqual(torch.tensor([True, True, False], device=device), + torch.bitwise_or(torch.tensor([True, True, False], device=device), + torch.tensor([False, True, False], device=device))) + + def test_bitwise_xor(self, device): + for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): + a = torch.tensor([1, -2, 3], dtype=dtype, device=device) + b = torch.tensor([2, 1, 3], dtype=dtype, device=device) + expected_res = torch.tensor([3, -1, 0], dtype=dtype, device=device) + b_scalar = 2 + expected_res_scalar = torch.tensor([3, -4, 1], dtype=dtype, device=device) + + # standard version + self.assertEqual(torch.bitwise_xor(a, b), expected_res) + self.assertEqual(torch.bitwise_xor(a, b_scalar), expected_res_scalar) + + # out + c = torch.empty(0, dtype=dtype, device=device) + torch.bitwise_xor(a, b, out=c) + self.assertEqual(c, expected_res) + torch.bitwise_xor(a, b_scalar, out=c) + self.assertEqual(c, expected_res_scalar) + + # in-place + a1 = a.clone() + a1.bitwise_xor_(b) + self.assertEqual(a1, expected_res) + a.bitwise_xor_(b_scalar) + self.assertEqual(a, expected_res_scalar) + + self.assertEqual(torch.tensor([True, False, False], device=device), + torch.bitwise_xor(torch.tensor([True, True, False], device=device), + torch.tensor([False, True, False], device=device))) + + @onlyOnCPUAndCUDA + @dtypes(*list(product(torch.testing.get_all_dtypes(include_complex=False), + torch.testing.get_all_dtypes(include_complex=False)))) + def test_heaviside(self, device, dtypes): + input_dtype = dtypes[0] + values_dtype = dtypes[1] + + rng = np.random.default_rng() + input = np.array(rng.integers(-10, 10, size=10), + dtype=torch_to_numpy_dtype_dict[input_dtype if (input_dtype != torch.bfloat16) else torch.float64]) + input[0] = input[3] = input[7] = 0 + values = np.array(rng.integers(-10, 10, size=10), + dtype=torch_to_numpy_dtype_dict[values_dtype if (values_dtype != torch.bfloat16) else torch.float64]) + np_result = torch.from_numpy(np.heaviside(input, values)).to(device=device, dtype=input_dtype) + + input = torch.from_numpy(input).to(device=device, dtype=input_dtype) + values = torch.from_numpy(values).to(device=device, dtype=values_dtype) + out = torch.empty_like(input) + + if input_dtype == values_dtype: + torch_result = torch.heaviside(input, values) + self.assertEqual(np_result, torch_result) + + torch_result = input.heaviside(values) + self.assertEqual(np_result, torch_result) + + torch.heaviside(input, values, out=out) + self.assertEqual(np_result, out) + + input.heaviside_(values) + self.assertEqual(np_result, input) + else: + with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): + torch.heaviside(input, values) + with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): + input.heaviside(values) + with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): + torch.heaviside(input, values, out=out) + with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): + input.heaviside_(values) + + @onlyCUDA + def test_heaviside_cross_device(self, device): + x = torch.tensor([-9, 5, 0, 6, -2, 2], device='cuda') + y = torch.tensor(0) + result = torch.heaviside(x, y) + expect = torch.tensor([0, 1, 0, 1, 0, 1], device='cuda') + self.assertEqual(result, expect) + + result = torch.heaviside(y, x) + expect = torch.tensor([-9, 5, 0, 6, -2, 2], device='cuda') + self.assertEqual(result, expect) + + x = torch.tensor([-9, 5, 0, 6, -2, 2]) + y = torch.tensor(0, device='cuda') + with self.assertRaisesRegex(RuntimeError, 'Expected all tensors to be on the same device'): + torch.heaviside(x, y) + + with self.assertRaisesRegex(RuntimeError, 'Expected all tensors to be on the same device'): + torch.heaviside(y, x) + + @dtypes(*list(product(torch.testing.get_all_complex_dtypes(), + torch.testing.get_all_complex_dtypes()))) + def test_heaviside_complex(self, device, dtypes): + input_dtype = dtypes[0] + values_dtype = dtypes[1] + + data = (complex(0, -6), complex(-1, 3), complex(1, 1)) + input = torch.tensor(data, device=device, dtype=input_dtype) + values = torch.tensor(data, device=device, dtype=values_dtype) + out = torch.empty_like(input) + real = input.real + + with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): + torch.heaviside(input, real) + with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): + real.heaviside(values) + with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): + input.heaviside_(values) + with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): + torch.heaviside(real, real, out=out) + + def _test_logical(self, device, dtypes, op, a_, b_, expected_res_): + expected_res = torch.tensor(expected_res_, dtype=dtypes[0], device=device) + a = torch.tensor(a_, dtype=dtypes[0], device=device) + b = torch.tensor(b_, dtype=dtypes[1], device=device) + + # new tensor + self.assertEqual(expected_res.bool(), getattr(a, op)(b)) + # out + c = torch.empty(0, dtype=torch.bool, device=device) + getattr(torch, op)(a, b, out=c) + self.assertEqual(expected_res.bool(), c) + + # in-place + # TODO: remove when different dtypes as operands are supported + if dtypes[0] != dtypes[1]: + with self.assertRaises(RuntimeError): + getattr(a, op + '_')(b) + return + + getattr(a, op + '_')(b) + self.assertEqual(expected_res, a) + + @dtypes(*product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes())) + def test_logical_xor(self, device, dtypes): + self._test_logical(device, dtypes, 'logical_xor', [10, 0, 1, 0], [1, 0, 0, 10], [0, 0, 1, 1]) + + @dtypes(*product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes())) + def test_logical_and(self, device, dtypes): + self._test_logical(device, dtypes, 'logical_and', [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 0, 0]) + + @dtypes(*product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes())) + def test_logical_or(self, device, dtypes): + self._test_logical(device, dtypes, 'logical_or', [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 1, 1]) + + def test_remainder_overflow(self, device): + # Check Integer Overflows + x = torch.tensor(23500, dtype=torch.int64, device=device) + q = 392486996410368 + self.assertEqual(x % q, x) + self.assertEqual(-x % q, q - x) + self.assertEqual(x % -q, x - q) + self.assertEqual(-x % -q, -x) + + def test_rpow(self, device): + m = torch.randn(10, 10, device=device) + self.assertEqual(torch.pow(2, m), 2**m) + + # test with scalar + m = torch.randn(1, device=device).squeeze() + assert m.dim() == 0, "m is intentionally a scalar" + self.assertEqual(torch.pow(2, m), 2**m) + + @onlyCPU + def test_ldexp(self, device): + # random values + mantissas = torch.randn(64, device=device) + exponents = torch.randint(-31, 31, (64,), device=device, dtype=torch.int32) + + # basic test + np_outcome = np.ldexp(mantissas.numpy(), exponents.numpy()) + pt_outcome_1 = torch.ldexp(mantissas, exponents) + pt_outcome_2 = mantissas.ldexp(exponents) + self.assertEqual(np_outcome, pt_outcome_1) + self.assertEqual(np_outcome, pt_outcome_2) + mantissas.ldexp_(exponents) + self.assertEqual(np_outcome, mantissas) + + # test bounds + mantissas = torch.tensor([float('inf'), float('-inf'), float('inf'), float('nan')], device=device) + exponents = torch.randint(0, 31, (4,), device=device, dtype=torch.int32) + np_outcome = np.ldexp(mantissas.numpy(), exponents.numpy()) + pt_outcome = torch.ldexp(mantissas, exponents) + self.assertEqual(np_outcome, pt_outcome) + + def test_lerp(self, device): + start_end_shapes = [(), (5,), (5, 5), (5, 5, 5)] + for shapes in product(start_end_shapes, start_end_shapes): + start = torch.randn(shapes[0], device=device) + end = torch.randn(shapes[1], device=device) + + # Tensor weights + for weight in [torch.randn(shapes[0], device=device), random.random()]: + actual = torch.lerp(start, end, weight) + actual_method = start.lerp(end, weight) + self.assertEqual(actual, actual_method) + actual_out = torch.Tensor().to(device) + torch.lerp(start, end, weight, out=actual_out) + self.assertEqual(actual, actual_out) + expected = start + weight * (end - start) + self.assertEqual(expected, actual) + + def _test_logaddexp(self, device, dtype, base2): + if base2: + ref_func = np.logaddexp2 + our_func = torch.logaddexp2 + else: + ref_func = np.logaddexp + our_func = torch.logaddexp + + def _test_helper(a, b): + ref = ref_func(a.cpu().numpy(), b.cpu().numpy()) + v = our_func(a, b) + self.assertEqual(ref, v) + + # simple test + a = torch.randn(64, 2, dtype=dtype, device=device) - 0.5 + b = torch.randn(64, 2, dtype=dtype, device=device) - 0.5 + _test_helper(a, b) + _test_helper(a[:3], b[:3]) + + # large value test for numerical stability + a *= 10000 + b *= 10000 + _test_helper(a, b) + _test_helper(a[:3], b[:3]) + + a = torch.tensor([float('inf'), float('-inf'), float('inf'), float("nan")], dtype=dtype, device=device) + b = torch.tensor([float('inf'), float('-inf'), float('-inf'), float("nan")], dtype=dtype, device=device) + _test_helper(a, b) + + @dtypes(torch.float32, torch.float64) + def test_logaddexp(self, device, dtype): + self._test_logaddexp(device, dtype, base2=False) + + @dtypes(torch.float32, torch.float64) + def test_logaddexp2(self, device, dtype): + self._test_logaddexp(device, dtype, base2=True) + + def test_add(self, device): + dtypes = [torch.float, torch.double] + torch.testing.get_all_complex_dtypes() + for dtype in dtypes: + # [res] torch.add([res,] tensor1, tensor2) + m1 = torch.randn(100, 100, dtype=dtype, device=device) + v1 = torch.randn(100, dtype=dtype, device=device) + + # contiguous + res1 = torch.add(m1[4], v1) + res2 = res1.clone().zero_() + for i in range(m1.size(1)): + res2[i] = m1[4, i] + v1[i] + self.assertEqual(res1, res2) + + m1 = torch.randn(100, 100, device=device) + v1 = torch.randn(100, device=device) + + # non-contiguous + res1 = torch.add(m1[:, 4], v1) + res2 = res1.clone().zero_() + for i in range(m1.size(0)): + res2[i] = m1[i, 4] + v1[i] + self.assertEqual(res1, res2) + + # [res] torch.add([res,] tensor, value) + m1 = torch.randn(10, 10, device=device) + + # contiguous + res1 = m1.clone() + res1[3].add_(2) + res2 = m1.clone() + for i in range(m1.size(1)): + res2[3, i] = res2[3, i] + 2 + self.assertEqual(res1, res2) + + # non-contiguous + m1 = torch.randn(10, 10, device=device) + res1 = m1.clone() + res1[:, 3].add_(2) + res2 = m1.clone() + for i in range(m1.size(0)): + res2[i, 3] = res2[i, 3] + 2 + self.assertEqual(res1, res2) + + # inter-type + m1 = torch.randn(10, 10, dtype=dtype, device=device) + self.assertEqual(m1 + 3, m1 + torch.tensor(3)) + self.assertEqual(3 + m1, torch.tensor(3) + m1) + + # contiguous + non-contiguous + m1 = torch.randn(10, 10, dtype=dtype, device=device) + m2 = torch.randn(10, 10, dtype=dtype, device=device).t() + res = m1 + m2 + self.assertTrue(res.is_contiguous()) + self.assertEqual(res, m1 + m2.contiguous()) + + # 1d + empty + m1 = torch.tensor([1.0], dtype=dtype, device=device) + m2 = torch.tensor([], dtype=dtype, device=device) + self.assertEqual(m1 + m2, []) + + # inter-type unint8 + one = torch.tensor(1, dtype=torch.uint8, device=device) + self.assertEqual(torch.add(one, 1), 2) + self.assertEqual(torch.add(one, 1).dtype, torch.uint8) + + # bool + m1 = torch.tensor([True, False, False, True, False, False], dtype=torch.bool, device=device) + m2 = torch.tensor([True, True, False, False, False, True], dtype=torch.bool, device=device) + expected = torch.tensor([True, True, False, True, False, True], dtype=torch.bool, device=device) + self.assertEqual(m1 + m2, expected) + + # fused multiply add + a = torch.zeros(2, 3, dtype=torch.bool, device=device) + res = torch.add(a, a, alpha=0) + expected = torch.zeros(2, 3, device=device).bool() + self.assertEqual(res, expected) + + # bfloat16 + m1 = torch.tensor([1., 2.], dtype=torch.bfloat16) + m2 = torch.tensor([3., 4.], dtype=torch.bfloat16) + self.assertEqual(m1 + m2, torch.tensor([4., 6.], dtype=torch.bfloat16)) + + # different alpha types + m1 = torch.tensor([2 + 3j, 4 + 5j], dtype=torch.complex64, device=device) + m2 = torch.tensor([4 + 5j, 2 + 3j], dtype=torch.complex64, device=device) + # add complex numbers with float alpha + res = torch.add(m1, m2, alpha=0.1) + expected = torch.tensor([2.4000 + 3.5000j, 4.2000 + 5.3000j], dtype=torch.complex64, device=device) + self.assertEqual(res, expected) + + # add complex numbers with complex alpha + res = torch.add(m1, m2, alpha=complex(0.1, 0.2)) + expected = torch.tensor([1.4000 + 4.3000j, 3.6000 + 5.7000j], dtype=torch.complex64, device=device) + self.assertEqual(res, expected) + + # add complex numbers with integer alpha + res = torch.add(m1, m2, alpha=2) + expected = torch.tensor([10. + 13.j, 8. + 11.j], dtype=torch.complex64, device=device) + self.assertEqual(res, expected) + + # mismatched alpha + m1 = torch.tensor([1], dtype=torch.int8, device=device) + m2 = torch.tensor([2], dtype=torch.int8, device=device) + self.assertRaisesRegex(RuntimeError, + r"Boolean alpha only supported for Boolean results\.", + lambda: torch.add(m1, m2, alpha=True)) + self.assertRaisesRegex(RuntimeError, + r"For integral input tensors, argument alpha must not be a floating point number\.", + lambda: torch.add(m1, m2, alpha=1.0)) + + # mismatched alpha, float / double tensor and complex alpha + m1 = torch.tensor([3., 4.], device=device) + m2 = torch.tensor([4., 3.], device=device) + self.assertRaises(RuntimeError, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2))) + + m1 = torch.tensor([3., 4.], dtype=torch.double, device=device) + m2 = torch.tensor([4., 3.], dtype=torch.double, device=device) + self.assertRaises(RuntimeError, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2))) + + # complex + m1 = torch.tensor((4.0000 + 4.0000j), dtype=torch.complex64) + m2 = torch.tensor(4., dtype=torch.float64) + self.assertRaisesRegex(RuntimeError, r"result type ComplexFloat can't be cast to the desired output type Double", + lambda: torch.add(m1, m1, out=m2)) + + + def test_sub_typing(self, device): + m1 = torch.tensor([True, False, False, True, False, False], dtype=torch.bool, device=device) + m2 = torch.tensor([True, True, False, False, False, True], dtype=torch.bool, device=device) + self.assertRaisesRegex(RuntimeError, + r"Subtraction, the `\-` operator, with two bool tensors is not supported. " + r"Use the `\^` or `logical_xor\(\)` operator instead.", + lambda: m1 - m2) + self.assertRaisesRegex(RuntimeError, + r"Subtraction, the `\-` operator, with a bool tensor is not supported. " + r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", + lambda: 1 - m1) + self.assertRaisesRegex(RuntimeError, + r"Subtraction, the `\-` operator, with a bool tensor is not supported. " + r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", + lambda: m2 - 1) + + # mismatched alpha + m1 = torch.tensor([1], dtype=torch.int8, device=device) + m2 = torch.tensor([2], dtype=torch.int8, device=device) + self.assertRaisesRegex(RuntimeError, + r"Boolean alpha only supported for Boolean results\.", + lambda: torch.sub(m1, m2, alpha=True)) + self.assertRaisesRegex(RuntimeError, + r"For integral input tensors, argument alpha must not be a floating point number\.", + lambda: torch.sub(m1, m2, alpha=1.0)) + + def test_mul(self, device): + m1 = torch.randn(10, 10, device=device) + res1 = m1.clone() + res1[:, 3].mul_(2) + res2 = m1.clone() + for i in range(res1.size(0)): + res2[i, 3] = res2[i, 3] * 2 + self.assertEqual(res1, res2) + + a1 = torch.tensor([True, False, False, True], dtype=torch.bool, device=device) + a2 = torch.tensor([True, False, True, False], dtype=torch.bool, device=device) + self.assertEqual(a1 * a2, torch.tensor([True, False, False, False], dtype=torch.bool, device=device)) + + if device == 'cpu': + a1 = torch.tensor([0.1, 0.1], dtype=torch.bfloat16, device=device) + a2 = torch.tensor([1.1, 0.1], dtype=torch.bfloat16, device=device) + self.assertEqual(a1 * a2, torch.tensor([0.11, 0.01], dtype=torch.bfloat16, device=device), atol=0.01, rtol=0) + self.assertEqual(a1.mul(a2), a1 * a2) + + def test_bool_tensor_comparison_ops(self, device): + a = torch.tensor([True, False, True, False, True, False], dtype=torch.bool, device=device) + b = torch.tensor([True, False, True, True, True, True], dtype=torch.bool, device=device) + self.assertEqual(a == b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)) + self.assertEqual(a != b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)) + self.assertEqual(a < b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)) + self.assertEqual(a > b, torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.bool, device=device)) + self.assertEqual(a >= b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)) + self.assertEqual(a <= b, torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.bool, device=device)) + self.assertEqual(a > False, torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device)) + self.assertEqual(a == torch.tensor(True, dtype=torch.bool, device=device), + torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device)) + self.assertEqual(a == torch.tensor(0, dtype=torch.bool, device=device), + torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool, device=device)) + self.assertFalse(a.equal(b)) + + @dtypes(*torch.testing.get_all_dtypes(include_complex=False)) + def test_logical(self, device, dtype): + if dtype != torch.bool: + x = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype) + b = torch.tensor([2], device=device, dtype=dtype) + self.assertEqual(x.lt(2), torch.tensor([True, False, False, False])) + self.assertEqual(x.le(2), torch.tensor([True, True, False, False])) + self.assertEqual(x.ge(2), torch.tensor([False, True, True, True])) + self.assertEqual(x.gt(2), torch.tensor([False, False, True, True])) + self.assertEqual(x.eq(2), torch.tensor([False, True, False, False])) + self.assertEqual(x.ne(2), torch.tensor([True, False, True, True])) + + self.assertEqual(x.lt(b), torch.tensor([True, False, False, False])) + self.assertEqual(x.le(b), torch.tensor([True, True, False, False])) + self.assertEqual(x.ge(b), torch.tensor([False, True, True, True])) + self.assertEqual(x.gt(b), torch.tensor([False, False, True, True])) + self.assertEqual(x.eq(b), torch.tensor([False, True, False, False])) + self.assertEqual(x.ne(b), torch.tensor([True, False, True, True])) + else: + x = torch.tensor([True, False, True, False], device=device) + self.assertEqual(x.lt(True), torch.tensor([False, True, False, True])) + self.assertEqual(x.le(True), torch.tensor([True, True, True, True])) + self.assertEqual(x.ge(True), torch.tensor([True, False, True, False])) + self.assertEqual(x.gt(True), torch.tensor([False, False, False, False])) + self.assertEqual(x.eq(True), torch.tensor([True, False, True, False])) + self.assertEqual(x.ne(True), torch.tensor([False, True, False, True])) + + def test_atan2(self, device): + def _test_atan2_with_size(size, device): + a = torch.rand(size=size, device=device, dtype=torch.double) + b = torch.rand(size=size, device=device, dtype=torch.double) + actual = a.atan2(b) + x = a.view(-1) + y = b.view(-1) + expected = torch.tensor([math.atan2(x[i].item(), y[i].item()) for i in range(x.numel())], + device=device, dtype=torch.double) + self.assertEqual(expected, actual.view(-1), rtol=0, atol=0.02) + + _test_atan2_with_size((2, 2), device) + _test_atan2_with_size((3, 3), device) + _test_atan2_with_size((5, 5), device) + + def test_atan2_edgecases(self, device): + def _test_atan2(x, y, expected, device, dtype): + expected_tensor = torch.tensor([expected], dtype=dtype, device=device) + x_tensor = torch.tensor([x], dtype=dtype, device=device) + y_tensor = torch.tensor([y], dtype=dtype, device=device) + actual = torch.atan2(y_tensor, x_tensor) + self.assertEqual(expected_tensor, actual, rtol=0, atol=0.02) + + for dtype in [torch.float, torch.double]: + _test_atan2(0, 0, 0, device, dtype) + _test_atan2(0, 1, math.pi / 2, device, dtype) + _test_atan2(0, -1, math.pi / -2, device, dtype) + _test_atan2(-1, 0, math.pi, device, dtype) + _test_atan2(1, 0, 0, device, dtype) + _test_atan2(-1, -1, math.pi * -3 / 4 , device, dtype) + _test_atan2(1, 1, math.pi / 4 , device, dtype) + _test_atan2(1, -1, math.pi / -4 , device, dtype) + _test_atan2(-1, 1, math.pi * 3 / 4 , device, dtype) + + def test_trapz(self, device): + def test_dx(sizes, dim, dx, device): + t = torch.randn(sizes, device=device) + actual = torch.trapz(t, dx=dx, dim=dim) + expected = np.trapz(t.cpu().numpy(), dx=dx, axis=dim) + self.assertEqual(expected.shape, actual.shape) + self.assertEqual(expected, actual) + + def test_x(sizes, dim, x, device): + t = torch.randn(sizes, device=device) + actual = torch.trapz(t, x=torch.tensor(x, device=device), dim=dim) + expected = np.trapz(t.cpu().numpy(), x=x, axis=dim) + self.assertEqual(expected.shape, actual.shape) + self.assertEqual(expected, actual.cpu()) + + test_dx((2, 3, 4), 1, 1, device) + test_dx((10, 2), 0, 0.1, device) + test_dx((1, 10), 0, 2.3, device) + test_dx((0, 2), 0, 1.0, device) + test_dx((0, 2), 1, 1.0, device) + test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device) + test_x((10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device) + test_x((1, 10), 0, [1.0], device) + test_x((0, 2), 0, [], device) + test_x((0, 2), 1, [1.0, 2.0], device) + with self.assertRaisesRegex( + IndexError, + 'Dimension out of range'): + test_x((2, 3), 2, [], device) + test_dx((2, 3), 2, 1.0, device) + with self.assertRaisesRegex( + RuntimeError, + 'There must be one `x` value for each sample point'): + test_x((2, 3), 1, [1.0, 2.0], device) + test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device) + + @dtypes(torch.double) + def test_pow_scalar_overloads_mem_overlap(self, device, dtype): + sz = 3 + doubles = torch.randn(2 * sz, dtype=dtype, device=device) + self.check_internal_mem_overlap( + lambda t: t.pow_(42), 1, dtype, device) + self.unary_check_input_output_mem_overlap( + doubles, sz, lambda input, out: torch.pow(input, 42, out=out)) + self.unary_check_input_output_mem_overlap( + doubles, sz, lambda input, out: torch.pow(42, input, out=out)) + + @dtypes(*list(product(torch.testing.get_all_dtypes(include_bool=False), + torch.testing.get_all_dtypes(include_bool=False)))) + def test_float_power(self, device, dtypes): + def to_np(value): + if isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16: + return value.to(torch.float).cpu().numpy() + return value.cpu().numpy() if isinstance(value, torch.Tensor) else value + + base_dtype = dtypes[0] + exp_dtype = dtypes[1] + out_dtype = torch.complex128 if base_dtype.is_complex or exp_dtype.is_complex else torch.float64 + + base = make_tensor((30,), device, base_dtype, low=1, high=100) + # Complex and real results do not agree between PyTorch and NumPy when computing negative and zero power of 0 + # Related: https://github.com/pytorch/pytorch/issues/48000 + # base[0] = base[3] = base[7] = 0 + exp = make_tensor((30,), device, exp_dtype, low=-2, high=2) + exp[0] = exp[4] = exp[6] = 0 + + expected = torch.from_numpy(np.float_power(to_np(base), to_np(exp))) + + exponents = [-2.8, -2, -1, -0.5, 0.5, 1, 2] + complex_exponents = exponents + [-2.5j, -1.0j, 1.0j, 2.5j, 1.0 + 1.0j, -1.0 - 1.5j, 3.3j] + + for op in (torch.float_power, torch.Tensor.float_power, torch.Tensor.float_power_): + + # Case of Tensor x Tensor + if op is torch.Tensor.float_power_ and base_dtype != out_dtype: + with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"): + op(base.clone(), exp) + else: + result = op(base.clone(), exp) + self.assertEqual(expected, result) + + if op is torch.float_power: + out = torch.empty_like(base).to(device=device, dtype=out_dtype) + op(base, exp, out=out) + self.assertEqual(expected, out) + + # Case of Tensor x Scalar + for i in complex_exponents if exp_dtype.is_complex else exponents: + out_dtype_scalar_exp = torch.complex128 if base_dtype.is_complex or type(i) == complex else torch.float64 + expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i)) + + if op is torch.Tensor.float_power_ and base_dtype != out_dtype_scalar_exp: + with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"): + op(base.clone(), i) + else: + result = op(base.clone(), i) + self.assertEqual(expected_scalar_exp, result) + + if op is torch.float_power: + out = torch.empty_like(base).to(device=device, dtype=out_dtype_scalar_exp) + op(base, i, out=out) + self.assertEqual(expected_scalar_exp, out) + + # Case of Scalar x Tensor + for i in complex_exponents if base_dtype.is_complex else exponents: + out_dtype_scalar_base = torch.complex128 if exp_dtype.is_complex or type(i) == complex else torch.float64 + expected_scalar_base = torch.from_numpy(np.float_power(i, to_np(exp))) + + result = torch.float_power(i, exp) + self.assertEqual(expected_scalar_base, result) + + out = torch.empty_like(exp).to(device=device, dtype=out_dtype_scalar_base) + torch.float_power(i, exp, out=out) + self.assertEqual(expected_scalar_base, out) + + def test_float_power_exceptions(self, device): + def _promo_helper(x, y): + for i in (x, y): + if type(i) == complex: + return torch.complex128 + elif type(i) == torch.Tensor and i.is_complex(): + return torch.complex128 + return torch.double + + test_cases = ((torch.tensor([-2, -1, 0, 1, 2], device=device), -.25), + (torch.tensor([-1.0j, 0j, 1.0j, 1.0 + 1.0j, -1.0 - 1.5j], device=device), 2.)) + for base, exp in test_cases: + for out_dtype in (torch.long, torch.float, torch.double, torch.cdouble): + out = torch.empty(1, device=device, dtype=out_dtype) + required_dtype = _promo_helper(base, exp) + + if out.dtype == required_dtype: + torch.float_power(base, exp, out=out) + else: + with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"): + torch.float_power(base, exp, out=out) + + if base.dtype == required_dtype: + torch.Tensor.float_power_(base.clone(), exp) + else: + with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"): + torch.Tensor.float_power_(base.clone(), exp) + + @skipIf(not TEST_SCIPY, "Scipy required for the test.") + @dtypes(*product(torch.testing.get_all_dtypes(include_complex=False, include_bfloat16=False), + torch.testing.get_all_dtypes(include_complex=False, include_bfloat16=False))) + def test_xlogy(self, device, dtypes): + def out_variant_helper(torch_fn, x, y): + expected = torch_fn(x, y) + out = torch.empty_like(expected) + torch_fn(x, y, out=out) + self.assertEqual(expected, out) + + def inplace_variant_helper(x, y): + if x.dtype in torch.testing.get_all_int_dtypes() + [torch.bool]: + with self.assertRaisesRegex(RuntimeError, + "can't be cast to the desired output type"): + x.clone().xlogy_(y) + else: + expected = torch.empty_like(x) + torch.xlogy(x, y, out=expected) + inplace_out = x.clone().xlogy_(y) + self.assertEqual(expected, inplace_out) + + x_dtype, y_dtype = dtypes + + # Tensor-Tensor Test (tensor of same and different shape) + x = make_tensor((3, 2, 4, 5), device, x_dtype, low=0.5, high=1000) + y = make_tensor((3, 2, 4, 5), device, y_dtype, low=0.5, high=1000) + z = make_tensor((4, 5), device, y_dtype, low=0.5, high=1000) + + torch_fn = partial(torch.xlogy, x) + reference_fn = partial(scipy.special.xlogy, x.cpu().numpy()) + + self.compare_with_numpy(torch_fn, reference_fn, x, exact_dtype=False) + self.compare_with_numpy(torch_fn, reference_fn, y, exact_dtype=False) + self.compare_with_numpy(torch_fn, reference_fn, z, exact_dtype=False) + out_variant_helper(torch.xlogy, x, x) + out_variant_helper(torch.xlogy, x, y) + out_variant_helper(torch.xlogy, x, z) + inplace_variant_helper(x, x) + inplace_variant_helper(x, y) + inplace_variant_helper(x, z) + + # Scalar-Tensor Test + torch_fn = partial(torch.xlogy, 3.14) + reference_fn = partial(scipy.special.xlogy, 3.14) + + self.compare_with_numpy(torch_fn, reference_fn, x, exact_dtype=False) + self.compare_with_numpy(torch_fn, reference_fn, y, exact_dtype=False) + self.compare_with_numpy(torch_fn, reference_fn, z, exact_dtype=False) + out_variant_helper(torch.xlogy, 3.14, x) + out_variant_helper(torch.xlogy, 3.14, y) + out_variant_helper(torch.xlogy, 3.14, z) + + # Special Values Tensor-Tensor + t = torch.tensor([0., 1., 2., float('inf'), -float('inf'), float('nan')], device=device) + zeros = torch.zeros(6, dtype=y_dtype, device=device) + + torch_fn = partial(torch.xlogy, zeros) + reference_fn = partial(scipy.special.xlogy, zeros.cpu().numpy()) + self.compare_with_numpy(torch_fn, reference_fn, t, exact_dtype=False) + out_variant_helper(torch.xlogy, zeros, t) + inplace_variant_helper(zeros, t) + + # Special Values Scalar-Tensor + torch_fn = partial(torch.xlogy, 0) + reference_fn = partial(scipy.special.xlogy, 0) + self.compare_with_numpy(torch_fn, reference_fn, t, exact_dtype=False) + out_variant_helper(torch.xlogy, 0, t) + + def test_xlogy_scalar_type_promotion(self, device): + # Test that python numbers don't participate in type promotion at the same + # priority level as 0-dim tensors + t = torch.randn((), dtype=torch.float32, device=device) + + self.assertEqual(t.dtype, torch.xlogy(t, 5).dtype) + self.assertEqual(t.dtype, torch.xlogy(t, 5.).dtype) + + self.assertEqual(t.dtype, torch.xlogy(5, t).dtype) + self.assertEqual(t.dtype, torch.xlogy(5., t).dtype) + + @skipIf(not TEST_SCIPY, "Scipy required for the test.") + def test_xlogy_bfloat16(self, device): + def _compare_helper(x, y): + x_np = x if isinstance(x, float) else x.cpu().to(torch.float).numpy() + y_np = y if isinstance(y, float) else y.cpu().to(torch.float).numpy() + expected = torch.from_numpy(scipy.special.xlogy(x_np, y_np)) + actual = torch.xlogy(x, y) + self.assertEqual(expected, actual, exact_dtype=False) + + x_dtype, y_dtype = torch.bfloat16, torch.bfloat16 + + # Tensor-Tensor Test (tensor of same and different shape) + x = make_tensor((3, 2, 4, 5), device, x_dtype, low=0.5, high=1000) + y = make_tensor((3, 2, 4, 5), device, y_dtype, low=0.5, high=1000) + z = make_tensor((4, 5), device, y_dtype, low=0.5, high=1000) + + _compare_helper(x, x) + _compare_helper(x, y) + _compare_helper(x, z) + + _compare_helper(x, 3.14) + _compare_helper(y, 3.14) + _compare_helper(z, 3.14) + + # Special Values Tensor-Tensor + t = torch.tensor([0., 1., 2., float('inf'), -float('inf'), float('nan')], device=device) + zeros = torch.tensor(5, dtype=y_dtype, device=device) + _compare_helper(t, zeros) + _compare_helper(t, 0.) + +tensor_binary_ops = [ + '__lt__', '__le__', + '__gt__', '__ge__', + '__eq__', '__ne__', + + '__add__', '__radd__', '__iadd__', + '__sub__', '__rsub__', '__isub__', + '__mul__', '__rmul__', '__imul__', + '__matmul__', '__rmatmul__', '__imatmul__', + '__truediv__', '__rtruediv__', '__itruediv__', + '__floordiv__', '__rfloordiv__', '__ifloordiv__', + '__mod__', '__rmod__', '__imod__', + '__divmod__', '__rdivmod__', '__idivmod__', + '__pow__', '__rpow__', '__ipow__', + '__lshift__', '__rlshift__', '__ilshift__', + '__rshift__', '__rrshift__', '__irshift__', + '__and__', '__rand__', '__iand__', + '__xor__', '__rxor__', '__ixor__', + '__or__', '__ror__', '__ior__', +] + +# Test that binary math operations return NotImplemented for unknown types. +def generate_not_implemented_tests(cls): + class UnknownType: + pass + + # TODO: refactor to inline these + _types = [ + torch.half, torch.float, torch.double, + torch.int8, torch.short, torch.int, torch.long, + torch.uint8 + ] + + # TODO: refactor to use make_tensor + def _small_2d(dtype, device, has_zeros=True, fill_ones=False, oneish=False): + t = _make_tensor((5, 5), dtype, device, fill_ones=fill_ones) + if oneish: + return t.clamp(min=_number(.99, 1, dtype), max=1.01) + if not has_zeros: + return t.clamp(min=(_number(_div_min, 1, dtype))) + return t + + for op in tensor_binary_ops: + @dtypes(*_types) + def test(self, device, dtype): + # Generate the inputs + tensor = _small_2d(dtype, device) + + # Runs the tensor op on the device + result = getattr(tensor, op)(UnknownType()) + self.assertEqual(result, NotImplemented) + + test_name = "test_{}_not_implemented".format(op) + assert not hasattr(cls, test_name), "{0} already in {1}".format( + test_name, cls.__name__) + + setattr(cls, test_name, test) + + +generate_not_implemented_tests(TestBinaryUfuncs) +instantiate_device_type_tests(TestBinaryUfuncs, globals()) + +if __name__ == '__main__': + run_tests() diff --git a/test/test_bundled_inputs.py b/test/test_bundled_inputs.py index f57407c9b1d15..e12339f3acea0 100644 --- a/test/test_bundled_inputs.py +++ b/test/test_bundled_inputs.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 import io +from typing import List + import torch import torch.utils.bundled_inputs from torch.testing._internal.common_utils import TestCase, run_tests @@ -27,7 +29,7 @@ def forward(self, arg): sm = torch.jit.script(SingleTensorModel()) original_size = model_size(sm) - get_expr = [] + get_expr : List[str] = [] samples = [ # Tensor with small numel and small storage. (torch.tensor([1]),), diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py index 85a1a927a8cda..f0a03d62acbe4 100644 --- a/test/test_cpp_extensions_aot.py +++ b/test/test_cpp_extensions_aot.py @@ -171,5 +171,28 @@ def test_rng(self): del copy2 self.assertEqual(rng_extension.getInstanceCount(), 0) + +@unittest.skipIf(not TEST_CUDA, "CUDA not found") +@unittest.skipIf(IS_WINDOWS, "MSVC have bug compiling this") +class TestTorchLibrary(common.TestCase): + + def test_torch_library(self): + import torch_test_cpp_extension.torch_library # noqa: F401 + + def f(a: bool, b: bool): + return torch.ops.torch_library.logical_and(a, b) + + self.assertTrue(f(True, True)) + self.assertFalse(f(True, False)) + self.assertFalse(f(False, True)) + self.assertFalse(f(False, False)) + s = torch.jit.script(f) + self.assertTrue(s(True, True)) + self.assertFalse(s(True, False)) + self.assertFalse(s(False, True)) + self.assertFalse(s(False, False)) + self.assertIn('torch_library::logical_and', str(s.graph)) + + if __name__ == "__main__": common.run_tests() diff --git a/test/test_cuda.py b/test/test_cuda.py index 6c904a67e6196..3d52c99df856c 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1,27 +1,30 @@ +from itertools import repeat, chain, product +from typing import NamedTuple import collections +import gc import io -import tempfile -import unittest -import sys -from itertools import repeat, chain, product import os -import gc -import threading -import queue import pickle +import queue +import sys +import tempfile +import threading +import unittest import torch import torch.cuda import torch.cuda.comm as comm from torch import multiprocessing as mp +from torch.nn.parallel import scatter_gather +from torch.utils.checkpoint import checkpoint_sequential from torch._six import inf, nan, container_abcs from test_torch import AbstractTestCases from torch.testing._internal.common_methods_invocations import tri_tests_args, tri_large_tests_args, \ _compare_trilu_indices, _compare_large_trilu_indices -from torch.testing._internal.common_utils import TestCase, get_gpu_type, freeze_rng_state, run_tests, \ - NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_SANDCASTLE, \ +from torch.testing._internal.common_utils import TestCase, freeze_rng_state, run_tests, \ + NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_REMOTE_GPU, IS_SANDCASTLE, IS_WINDOWS, \ slowTest, skipCUDANonDefaultStreamIf, TEST_WITH_ROCM, TEST_NUMPY from torch.testing._internal.autocast_test_lists import AutocastTestLists @@ -285,10 +288,10 @@ def test_cudart_register(self): self.assertFalse(t.is_pinned()) cudart = torch.cuda.cudart() r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0) - self.assertEquals(r, 0) + self.assertEqual(r, 0) self.assertTrue(t.is_pinned()) r = cudart.cudaHostUnregister(t.data_ptr()) - self.assertEquals(r, 0) + self.assertEqual(r, 0) self.assertFalse(t.is_pinned()) def test_memory_stats(self): @@ -382,8 +385,37 @@ def advance(gen, end): def test_out_of_memory(self): tensor = torch.zeros(1024, device='cuda') - with self.assertRaisesRegex(RuntimeError, "Tried to allocate 80.00 GiB"): - torch.empty(1024 * 1024 * 1024 * 80, dtype=torch.int8, device='cuda') + with self.assertRaisesRegex(RuntimeError, "Tried to allocate 8000000000.00 GiB"): + torch.empty(1024 * 1024 * 1024 * 8000000000, dtype=torch.int8, device='cuda') + + # ensure out of memory error doesn't disturb subsequent kernel + tensor.fill_(1) + self.assertTrue((tensor == 1).all()) + + def test_set_per_process_memory_fraction(self): + # test invalid fraction value. + with self.assertRaisesRegex(TypeError, "Invalid type"): + torch.cuda.set_per_process_memory_fraction(int(1)) + with self.assertRaisesRegex(ValueError, "Invalid fraction value"): + torch.cuda.set_per_process_memory_fraction(-0.1) + with self.assertRaisesRegex(ValueError, "Invalid fraction value"): + torch.cuda.set_per_process_memory_fraction(2.0) + + tensor = torch.zeros(1024, device='cuda') + torch.cuda.empty_cache() + total_memory = torch.cuda.get_device_properties(0).total_memory + torch.cuda.set_per_process_memory_fraction(0.5, 0) + + # test 0.499 allocation is ok. + application = int(total_memory * 0.499) - torch.cuda.max_memory_reserved() + tmp_tensor = torch.empty(application, dtype=torch.int8, device='cuda') + del tmp_tensor + torch.cuda.empty_cache() + + application = int(total_memory * 0.5) + # it will get OOM when try to allocate more than half memory. + with self.assertRaisesRegex(RuntimeError, "out of memory"): + torch.empty(application, dtype=torch.int8, device='cuda') # ensure out of memory error doesn't disturb subsequent kernel tensor.fill_(1) @@ -490,7 +522,6 @@ def _test_copy_non_blocking(a, b): event = torch.cuda.Event() a.copy_(b, non_blocking=True) event.record() - self.assertFalse(event.query()) event.synchronize() self.assertEqual(a, b) @@ -503,22 +534,35 @@ def _test_copy_non_blocking(a, b): y = torch.ones(10000000, dtype=torch.uint8).cuda() _test_copy_non_blocking(x, y) - @unittest.skip("skipped because test could be flaky, see #35144") def test_to_non_blocking(self): - def _test_to_non_blocking(a, non_blocking): - stream = torch.cuda.current_stream() - with torch.cuda.stream(stream): - b = a.to('cuda', non_blocking=non_blocking) - self.assertEqual(stream.query(), not non_blocking) - stream.synchronize() - self.assertEqual(a, b) + stream = torch.cuda.current_stream() - # 10MB copies - x = torch.ones(10000000, dtype=torch.uint8) - _test_to_non_blocking(x, True) + def _test_to_non_blocking(a, non_blocking, dst): + torch.cuda.synchronize() + # Pushes an 0.1 second spin to stream so if the copy is non blocking, + # stream will almost surely be active when we query(). + torch.cuda._sleep(int(100 * get_cycles_per_ms())) + b = a.to(device=dst, non_blocking=non_blocking) + self.assertEqual(stream.query(), not non_blocking) + stream.synchronize() + self.assertEqual(a, b) + self.assertTrue(b.is_pinned() == (non_blocking and dst == "cpu")) + + for dst, try_non_blocking in product(("cuda", "cpu"), (True, False)): + # Creates source on the opposite device from destination. + src = torch.randn(1000000, + device="cuda" if dst == "cpu" else "cpu", + pin_memory=True if dst == "cuda" else False) + _test_to_non_blocking(src, try_non_blocking, dst) - y = torch.ones(10000000, dtype=torch.uint8) - _test_to_non_blocking(y, False) + def test_to_cpu_blocking_by_default(self): + src = torch.randn(1000000, device="cuda") + torch.cuda.synchronize() + torch.cuda._sleep(int(100 * get_cycles_per_ms())) + dst = src.to(device="cpu") + self.assertEqual(torch.cuda.current_stream().query(), True) + self.assertEqual(src, dst) + self.assertFalse(dst.is_pinned()) def test_serialization_array_with_storage(self): x = torch.randn(5, 5).cuda() @@ -721,12 +765,6 @@ def test_cuda_set_device(self): torch.cuda.set_device(1) self.assertEqual(x.cuda().get_device(), 0) - def test_is_tensor(self): - for t in types: - tensor = get_gpu_type(t)() - self.assertTrue(torch.is_tensor(tensor)) - self.assertTrue(torch.is_tensor(torch.cuda.HalfTensor())) - def test_cuda_synchronize(self): torch.cuda.synchronize() torch.cuda.synchronize('cuda') @@ -964,7 +1002,6 @@ def test_streams_multi_gpu_eq(self): self.assertNotEqual(hash(s0), hash(s3)) @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") - @skipIfRocm def test_streams_priority(self): low, high = torch.cuda.Stream.priority_range() s0 = torch.cuda.Stream(device=0, priority=low) @@ -1543,7 +1580,6 @@ def test_bincount_ext(self): counted = t.bincount(minlength=65536) self.assertEqual(torch.sum(counted), 10) - @skipIfRocm def test_tiny_half_norm_(self): a = torch.arange(25).cuda().float() a /= 100000000 @@ -1731,8 +1767,53 @@ def test_streaming_backwards_device_transfer(self): self.assertTrue(a.grad.sum().item() == 4 * size) self.assertTrue(b.grad.sum().item() == 4 * size) + def test_streaming_backward_sync_graph_root(self): + # This function tests if bwd ops running on a side stream properly sync with the GraphRoot. + # The potential bug it targets is a race condition. The test uses multiple trials and + # torch.cuda._sleep such that if the race condition exists, the test will almost certainly fail, + # but there's a chance it may spuriously pass. Passing does not guarantee the backend is bug-free, + # but failure does guarantee there is a bug. + fwd_bwd_op_stream = torch.cuda.Stream() + bwd_ambient_stream = torch.cuda.Stream() + # We need these streams to be different otherwise the test is meaningless. + self.assertTrue(fwd_bwd_op_stream != bwd_ambient_stream) + + size = int(1e3) + + a = torch.full((size,), 2.0, device="cuda", requires_grad=True) + b = torch.full((size,), 3.0, device="cuda", requires_grad=True) + + # I don't think we need any manual record_streams below. + # a and b remain in scope for the entire test. + # c and grad remain in scope for each iteration, and there's a full sync between iterations. + for trial in range(5): + torch.cuda.synchronize() + a.grad = b.grad = None + with torch.cuda.stream(fwd_bwd_op_stream): + c = a * b + + with torch.cuda.stream(bwd_ambient_stream): + torch.cuda.synchronize() + # Long-running dummy kernel on bwd_ambient_stream delays filling of grad + torch.cuda._sleep(int(50 * get_cycles_per_ms())) + # Fills grad on bwd_ambient_stream + grad = torch.full((size,), float(trial + 1), device="cuda") + + # Bwd ops still run on fwd_bwd_ops_stream, so the following will likely fail if + # bwd ops don't sync with bwd_ambient_stream before consuming grad. + torch.autograd.backward(tensors=c, grad_tensors=grad) + + # See https://github.com/pytorch/pytorch/issues/47028 + # assertEquals below run on bwd_ambient_stream, so this test may also fail + # if backward() fails to sync with bwd_ambient_stream at the end. + # Synchronizing here works around the issue until a proper fix can be made. + torch.cuda.synchronize() + with torch.no_grad(): + self.assertEqual(a.grad, grad * b) + self.assertEqual(b.grad, grad * a) + @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") - @unittest.skipIf(not IS_SANDCASTLE, "Does not work on Sandcastle") + @unittest.skipIf(IS_SANDCASTLE or IS_REMOTE_GPU, "Does not work on Sandcastle") def test_cuda_init_race(self): # See https://github.com/pytorch/pytorch/issues/16559 import subprocess @@ -1749,32 +1830,102 @@ def worker(rank): t2.start() """]) - def test_grad_scaling_builtins(self, device="cuda", dtype=torch.float): - inv_scale = torch.tensor([0.25], dtype=dtype, device=device) + def test_grad_scaling_unscale(self, dtype=torch.float): + inv_scale = torch.full((1,), 0.25, dtype=torch.float, device="cuda:0") + found_inf = torch.full((1,), 0.0, dtype=torch.float, device="cuda:0") + + size = 10 + g = torch.full((size, size), 4.0, dtype=dtype, device="cuda:0") + ginf = g.clone() + ginf[2, 2] = float('inf') + gnan = g.clone() + gnan[2, 2] = float('nan') + + # Tries selected combinations of + # - contiguous grads + # - g.clone().t() which is not contiguous but still non overlapping and dense + # - variants of g.clone()[:, :5] which are not non overlapping and dense + # Non overlapping and dense grads route into a multi tensor apply kernel, + # others use a fallback per-tensor kernel, so we should try both. + cases = ( + ([g.clone(), g.clone()], False), + ([g.clone(), g.clone().t()], False), + ([g.clone(), g.clone()[:, :5]], False), + ([g.clone()[:, :5], g.clone()[:, :5]], False), + ([g.clone(), ginf.clone()], True), + ([g.clone(), gnan.clone()], True), + ([g.clone(), ginf.clone()[:, :5]], True), + ([g.clone(), gnan.clone()[:, :5]], True), + ([ginf.clone(), g.clone()[:, :5]], True), + ([ginf.clone()[:, :5], g.clone()[:, :5]], True), + ) + + for grads, has_inf in cases: + found_inf.zero_() + torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale) + if has_inf: + self.assertEqual(found_inf, 1.0) + else: + self.assertEqual(found_inf, 0.0) + for grad in grads: + self.assertTrue(torch.allclose(grad, torch.ones_like(grad), atol=1e-7)) - found_inf = torch.tensor([0.0], dtype=dtype, device=device) - g = torch.tensor([4.0], dtype=dtype, device=device) - torch._amp_non_finite_check_and_unscale_(g, found_inf, inv_scale) - self.assertEqual(found_inf, 0.0) - self.assertTrue(torch.allclose(g, torch.ones(10, dtype=torch.float32, device="cuda"), atol=1e-7)) + # Passing lists with mismatched devices or dtypes to a raw + # _amp_foreach_non_finite_check_and_unscale_ call should raise errors. + with self.assertRaisesRegex(RuntimeError, r"must have the same dtype"): + torch._amp_foreach_non_finite_check_and_unscale_([g.clone(), g.to(dtype=torch.float16)], + found_inf, + inv_scale) - found_inf.zero_() - g = torch.tensor([float('inf')], dtype=dtype, device=device) - torch._amp_non_finite_check_and_unscale_(g, found_inf, inv_scale) - self.assertEqual(found_inf, 1.0) + if TEST_MULTIGPU: + with self.assertRaisesRegex(RuntimeError, r"scaled_grads must be on the same device."): + torch._amp_foreach_non_finite_check_and_unscale_([g.clone(), g.to(device="cuda:1")], + found_inf, + inv_scale) + + # Creates a list of grads with mismatched dtypes and devices, to ensure + # scaler._unscale_grads_ organizes grads by dtype and device before calling + # _amp_foreach_non_finite_check_and_unscale_ on each set. + # If inject_inf >= 0, writes an inf into one grad for _unscale_grads_ to find. + def perfect_storm_grads(inject_inf): + grads = [g.clone(), g.clone()[:, :5], g.to(dtype=torch.float16), g.to(dtype=torch.float16)] + if TEST_MULTIGPU: + grads += [g.to(device="cuda:1"), + g.to(device="cuda:1")[:, :5], + g.to(device="cuda:1", dtype=torch.float16), + g.to(device="cuda:1", dtype=torch.float16)] + if inject_inf >= 0: + grads[inject_inf][2, 2] = float('inf') + return grads - found_inf.zero_() - g = torch.tensor([float('nan')], dtype=dtype, device=device) - torch._amp_non_finite_check_and_unscale_(g, found_inf, inv_scale) - self.assertEqual(found_inf, 1.0) + scaler = torch.cuda.amp.GradScaler() + dummy_params = [torch.empty_like(g) for g in perfect_storm_grads(-1)] + dummy_opt = torch.optim.SGD(dummy_params, lr=1.) + + # Ensures the inf/nan checking can find an inf injected onto any grad in the perfect storm. + for inject_inf in range(-1, len(dummy_params)): + found_inf = torch.full((1,), 0.0, dtype=torch.float, device="cuda:0") + grads = perfect_storm_grads(inject_inf) + for i, p in enumerate(dummy_params): + p.grad = grads[i] + found_inf_per_device = scaler._unscale_grads_(dummy_opt, inv_scale, found_inf, True) + if inject_inf < 0: + # No inf was injected, ensures unscaling worked normally. + self.assertTrue(sum(v.item() for v in found_inf_per_device.values()) == 0) + for grad in grads: + self.assertTrue(torch.allclose(grad, torch.ones_like(grad), atol=1e-7)) + else: + # inf was injected, ensures inf was found. + self.assertTrue(sum(v.item() for v in found_inf_per_device.values()) == 1) + def test_grad_scaling_update_scale(self, device="cuda", dtype=torch.float): growth = 2.0 backoff = 0.25 growth_interval = 2 - scale = torch.tensor([4.0], dtype=dtype, device=device) - growth_tracker = torch.tensor([0], dtype=torch.int32, device=device) + scale = torch.full((1,), 4.0, dtype=dtype, device=device) + growth_tracker = torch.full((1,), 0.0, dtype=torch.int32, device=device) + found_inf = torch.full((1,), 0.0, dtype=torch.float, device="cuda:0") - found_inf.zero_() # Simulates 2 consecutive unskipped iterations scale = torch._amp_update_scale(growth_tracker, scale, found_inf, growth, backoff, growth_interval) self.assertEqual(growth_tracker, 1) @@ -1792,7 +1943,7 @@ def test_grad_scaling_builtins(self, device="cuda", dtype=torch.float): def test_grad_scaling_unscale_sparse(self, device="cuda", dtype=torch.float): scaler = torch.cuda.amp.GradScaler() - inv_scale = torch.tensor([0.25], dtype=dtype, device=device) + inv_scale = torch.full((1,), 0.25, dtype=dtype, device=device) found_inf = torch.empty((1,), dtype=dtype, device=device) cur = found_inf.device @@ -1855,6 +2006,7 @@ def test_grad_scaling_device_as_key(self): # are treated as identical keys by dicts. GradScaler relies on this behavior, and may # error otherwise in a way that's difficult to detect (a silent performance hit). d = {} + t = torch.empty((1,), device="cuda:0") dev0a = torch.device("cuda:0") dev0b = torch.device("cuda:0") dev1a = torch.device("cuda:1") @@ -1867,6 +2019,9 @@ def test_grad_scaling_device_as_key(self): d[dev0b] = "0b" self.assertTrue(len(d) == 1) self.assertTrue(d[dev0a] == "0b") + d[t.device] = "t" + self.assertTrue(len(d) == 1) + self.assertTrue(d[dev0a] == "t") d[dev1a] = "1a" d[dev1b] = "1b" @@ -1876,8 +2031,8 @@ def test_grad_scaling_device_as_key(self): @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") def test_grad_scaling_scale(self): scaler = torch.cuda.amp.GradScaler(init_scale=2.) - t0 = torch.tensor([4.0], dtype=torch.float32, device="cuda:0") - t1 = torch.tensor([4.0], dtype=torch.float32, device="cuda:1") + t0 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:0") + t1 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:1") # Create some nested iterables of tensors on different devices. outputs = (t1.clone(), (t0.clone(), t1.clone()), [t0.clone(), (t1.clone(), t0.clone())]) outputs = scaler.scale(outputs) @@ -1895,7 +2050,7 @@ def test_grad_scaling_state_dict(self): if lazy_init_scale: # Dummy scale() call to ensure the scale tensor is lazily initialized. - s1.scale(torch.tensor([4.0], dtype=torch.float32, device="cuda:0")) + s1.scale(torch.full((1,), 4.0, dtype=torch.float32, device="cuda:0")) self.assertTrue(isinstance(s1._scale, torch.cuda.FloatTensor)) s1.load_state_dict(s0.state_dict()) @@ -2034,6 +2189,7 @@ def run(data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api): self._run_scaling_case(run, unskipped=3, skipped=1) + @unittest.skipIf(IS_WINDOWS, 'FIXME: fix this test for Windows') def test_grad_scaling_penalty(self): def run(data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api): for i, (input, target) in enumerate(data): @@ -2260,7 +2416,6 @@ def _worker(t): self.assertEqual(results[t].sum().item(), size * size) @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') - @skipIfRocm def test_cudnn_multiple_threads_same_device(self): # This function is intended to test the lazy creation and reuse of per-thread # cudnn handles on each device in aten/src/ATen/cudnn/Handles.cpp. @@ -2406,7 +2561,7 @@ def cast(val, to_type): "{} not found as an attribute on either Tensor or the requested module {}".format( op, module)) - # Accounts for ops that return tuples and other non-Tensors. + # Accounts for ops that return Tensors, iterables, and other non-Tensors. # For example, lstm_cell returns a tuple and equal returns bool. def compare(first, second): if isinstance(first, torch.Tensor): @@ -2706,6 +2861,33 @@ def test_autocast_rnn(self): for grad, grad_control in zip(grads, grads_control): self.assertEqual(grad.half(), grad_control) + def test_autocast_cache_leak(self): + # Reported at https://github.com/pytorch/pytorch/issues/48049 + # Test is used to check, if autocast recaches the same parameters + # when executed in a `torch.no_grad()` block. + + linear = torch.nn.Linear(10, 10).to('cuda') + data = torch.randn(1, 10, device='cuda') + + with torch.cuda.amp.autocast(): + with torch.no_grad(): + out = linear(data) + first_iter_mem = torch.cuda.memory_allocated() + for _ in range(3): + out = linear(data) + self.assertTrue(first_iter_mem == torch.cuda.memory_allocated()) + + def test_autocast_checkpointing(self): + model = torch.nn.Sequential(torch.nn.Linear(8, 8), + torch.nn.Linear(8, 8), + torch.nn.Linear(8, 8)).cuda() + input = torch.rand((8, 8), device="cuda", dtype=torch.float16, requires_grad=True) + with torch.cuda.amp.autocast(): + output = checkpoint_sequential(model, 2, input) + self.assertTrue(output.requires_grad) + self.assertTrue(output.dtype is torch.float16) + output.sum().backward() + @slowTest @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") def test_max_large_axis(self): @@ -2719,6 +2901,217 @@ def test_max_large_axis(self): def test_to_numpy(self): self.assertRaises(TypeError, lambda: torch.empty(1, device="cuda").numpy()) + @unittest.skipIf((not TEST_CUDA) or + TEST_WITH_ROCM or + int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs") + def test_graph_capture_simple(self): + s1 = torch.cuda.Stream() + + with torch.cuda.stream(s1): + a = torch.zeros((1000,), device="cuda") + a += 1 + g = torch.cuda._Graph() + g.capture_begin() + a += 1 + g.capture_end() + torch.cuda.current_stream().wait_stream(s1) + + g.replay() + g.replay() + + self.assertTrue(a.sum().item() == 3000.) + + @unittest.skipIf((not TEST_CUDA) or + TEST_WITH_ROCM or + int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs") + def test_graph_rng_functional(self): + # The caching allocator isn't yet graph-safe. + # In this test, graphed regions try to ensure allocator safety by + # stashing references to all temporaries. This is why we use _fused_dropout + # instead of a public dropout API: _fused_dropout returns the mask temporary + # as well as the output, so we can stash references to both. + # + # TODO: + # Switch to public dropout API when the allocator is made graph-safe. + ops_with_kwargs = ((torch._fused_dropout, {"p": 0.1}), + (torch.nn.functional.rrelu, {"training": True}),) + size = 10000 + + def run(op, kwargs): + a = torch.randn((size,), device="cuda", dtype=torch.float) + + torch.cuda.manual_seed(5) + + # Control + eager_out = a + for _ in range(6): + out = op(eager_out, **kwargs) + # _fused_dropout returns a tuple, rrelu returns a bare tensor. + eager_out = out[0] if isinstance(out, tuple) else out + + graph_in = a.clone() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + # warms up allocator so no mallocs occur in capture + refs = () + graph_out = graph_in + for _ in range(3): + out = op(graph_out, **kwargs) + refs += tuple(out) + graph_out = out[0] if isinstance(out, tuple) else out + del out, refs, graph_out + + torch.cuda.manual_seed(5) + + refs = () + g = torch.cuda._Graph() + g.capture_begin() + graph_out = graph_in + for _ in range(2): + out = op(graph_out, **kwargs) + refs += tuple(out) + graph_out = out[0] if isinstance(out, tuple) else out + g.capture_end() + torch.cuda.current_stream().wait_stream(stream) + + # Runs a graphed->eager->graphed sequence of RNG ops. + # replay() plays 2 invocations of the op, so the sequence has 6 + # invocations total, matching Control. + # replay() reads from graph_in and writes to graph_out. + g.replay() + out = op(graph_out, **kwargs) + out = op(out[0], **kwargs)[0] if isinstance(out, tuple) else op(out, **kwargs) + graph_in.copy_(out) + g.replay() + + # If replay() updated RNG state correctly, graph_out + # should now hold data equal to eager_out. + try: + self.assertEqual(eager_out, graph_out) + except Exception as e: + raise RuntimeError("Failed on ", op) from e + + # We hold references to all tensors used across streams up til this sync, + # so no need to call record_stream on those tensors. + torch.cuda.synchronize() + + for op, kwargs in ops_with_kwargs: + run(op, kwargs) + + @unittest.skipIf((not TEST_CUDA) or + TEST_WITH_ROCM or + int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs") + def test_graph_rng_distributions(self): + # The caching allocator isn't yet graph-safe. + # In this test, all ops maintain static references to inputs and outputs + # that persist across replay(), so they should be safe to test with graphs, + # EXCEPT for multinomial which is a complicated compound op. + # + # TODO: + # Uncomment multinomial when the allocator is made graph-safe. + size = 10000 + input = torch.rand((size,), device="cuda", dtype=torch.float) + alloc = torch.empty((size,), device="cuda", dtype=torch.float) + + # Torch ops to test with sample args (tuple) and kwargs (dict) + torch_with_args = (("bernoulli", (input.clone(),), {}), + # ("multinomial", (input.clone(), size, True), {}), + # ("multinomial", (input.clone(), size // 2, False), {}), + ("normal", (input.clone() + 1, input.clone()), {}), + ("poisson", (input.clone(),), {}), + ("rand", (size,), {"device": "cuda", "dtype": torch.float}), + ("randint", (0, 3, (size,)), {"device": "cuda", "dtype": torch.float}), + ("randn", (size,), {"device": "cuda", "dtype": torch.float}),) + + # Tensor methods to test with sample args (tuple) + tensor_with_args = (("bernoulli_", (input.clone(),)), + ("cauchy_", ()), + ("exponential_", ()), + ("geometric_", (0.3,)), + ("log_normal_", ()), + ("normal_", ()), + ("random_", ()), + ("uniform_", ()),) + + def run(module, op, args, kwargs): + torch.cuda.manual_seed(5) + + # Each path runs a dummy op to increment the state a bit before creating controls. + if (module == "torch"): + dummy = getattr(torch, op)(*args, **kwargs) + control1 = getattr(torch, op)(*args, **kwargs) + control2 = getattr(torch, op)(*args, **kwargs) + else: + dummy = alloc.clone() + control1 = alloc.clone() + control2 = alloc.clone() + getattr(dummy, op)(*args) + getattr(control1, op)(*args) + getattr(control2, op)(*args) + + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + torch.cuda.manual_seed(5) + + g = torch.cuda._Graph() + if (module == "torch"): + g.capture_begin() + t1 = getattr(torch, op)(*args, **kwargs) + t2 = getattr(torch, op)(*args, **kwargs) + g.capture_end() + else: + t1 = alloc.clone() + t2 = alloc.clone() + g.capture_begin() + getattr(t1, op)(*args) + getattr(t2, op)(*args) + g.capture_end() + torch.cuda.current_stream().wait_stream(stream) + + try: + self.assertNotEqual(control1, t1) + self.assertNotEqual(control2, t2) + except Exception as e: + raise RuntimeError("Failed on " + module + "." + op) from e + + # Runs a dummy op prelude, as for controls, to make sure replay() + # picks up the dummy op's state increment. + if module == "torch": + dummy = getattr(torch, op)(*args, **kwargs) + else: + dummy = alloc.clone() + getattr(dummy, op)(*args) + + # Runs RNG ops that fill t1 and t2. + g.replay() + + try: + self.assertEqual(control1, t1) + self.assertEqual(control2, t2) + except Exception as e: + raise RuntimeError("Failed on " + module + "." + op) from e + + # We hold references to all tensors used across streams up til this sync, + # so no need to call record_stream on those tensors. + torch.cuda.synchronize() + + for op_with_args in torch_with_args: + run("torch", *op_with_args) + + for meth_with_args in tensor_with_args: + # Adds an empty dict for kwargs, which none of the Tensor methods use + run("Tensor", *(meth_with_args + ({},))) + + def test_batch_norm_gather_stats(self): + input = torch.randn(1, 3, 3, 3, device='cuda') + mean, invstd = torch.batch_norm_gather_stats( + input, mean=torch.ones(2, 3, device='cuda'), invstd=torch.ones(2, 3, device='cuda'), + running_mean=None, running_var=None , momentum=.1, eps=1e-5, count=2 + ) + self.assertEqual(mean, torch.ones(3, device='cuda')) + self.assertEqual(invstd, torch.ones(3, device='cuda')) class TestCudaComm(TestCase): def _test_broadcast(self, input): @@ -2841,9 +3234,9 @@ def test_reduce_add(self): self.assertEqual(result.cpu(), x + y) def _test_reduce_add_coalesced(self, tensors, buffer_size): - dup_tensors = [tensors, list(map(lambda t: t.cuda(1), tensors))] + dup_tensors = [tensors, [t.cuda(1) for t in tensors]] - r_tensors = list(map(comm.reduce_add, zip(*dup_tensors))) + r_tensors = [comm.reduce_add(t) for t in zip(*dup_tensors)] for r, t in zip(r_tensors, tensors): self.assertEqualTypeString(r, t) self.assertEqual(r, t * 2) @@ -3060,6 +3453,49 @@ def test_matmul_device_mismatch(self): with self.assertRaisesRegex(RuntimeError, "expected (it|them) to be on GPU"): torch.addmm(s, m1, m2) + @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs") + def test_scatter_namedtuple(self): + # tests ability to scatter namedtuples and retrieve a list where each + # element is of the expected namedtuple type. + fields = ("a", "b") + TestNamedTupleInput_0 = collections.namedtuple("NamedTuple", fields) + num_gpus = torch.cuda.device_count() + a = torch.rand(num_gpus * 2, device=0) + b = torch.rand(num_gpus * 2, device=0) + a_tensors_for_gpu = [a[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)] + b_tensors_for_gpu = [b[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)] + + inp = TestNamedTupleInput_0(a, b) + target_gpus = [torch.device(i) for i in range(num_gpus)] + scatter_out = scatter_gather.scatter(inp, target_gpus) + + for i, x in enumerate(scatter_out): + self.assertTrue(isinstance(x, type(inp))) + self.assertEqual(x._fields, fields) + expected_a = a_tensors_for_gpu[i] + expected_b = b_tensors_for_gpu[i] + self.assertEqual(expected_a, x.a) + self.assertEqual(expected_b, x.b) + + class TestNamedTupleInput_1(NamedTuple): + a: torch.tensor + b: torch.tensor + + a = torch.rand(num_gpus * 2, device=0) + b = torch.rand(num_gpus * 2, device=0) + a_tensors_for_gpu = [a[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)] + b_tensors_for_gpu = [b[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)] + inp = TestNamedTupleInput_1(a, b) + + scatter_out = scatter_gather.scatter(inp, target_gpus) + for i, x in enumerate(scatter_out): + self.assertTrue(isinstance(x, type(inp))) + self.assertEqual(x._fields, fields) + expected_a = a_tensors_for_gpu[i] + expected_b = b_tensors_for_gpu[i] + self.assertEqual(expected_a, x.a) + self.assertEqual(expected_b, x.b) + if __name__ == '__main__': run_tests() diff --git a/test/test_dataloader.py b/test/test_dataloader.py index ce23593ec7bc9..edc31b75485e4 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -3,6 +3,7 @@ import errno import os import ctypes +import faulthandler import torch import gc import time @@ -11,13 +12,15 @@ import itertools import warnings import tempfile +import random from torch import multiprocessing as mp -from torch.utils.data import _utils, Dataset, IterableDataset, TensorDataset, DataLoader, ConcatDataset, ChainDataset +from torch.utils.data import (_utils, Dataset, IterableDataset, TensorDataset, DataLoader, ConcatDataset, + ChainDataset, BufferedShuffleDataset) from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL from torch.utils.data.dataset import random_split from torch._utils import ExceptionWrapper from torch.testing._internal.common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, - IS_PYTORCH_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, + IS_PYTORCH_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, slowTest, load_tests, TEST_WITH_ROCM, TEST_WITH_TSAN, IS_SANDCASTLE) try: @@ -32,18 +35,6 @@ else: warnings.warn(err_msg) -try: - import faulthandler - HAS_FAULTHANDLER = True -except ImportError: - HAS_FAULTHANDLER = False - err_msg = ("faulthandler not found. Some data loader tests use it for error " - "reporting (e.g., TestDataLoader.test_proper_exit).") - if IS_PYTORCH_CI: - raise ImportError(err_msg) from None - else: - warnings.warn(err_msg) - # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings @@ -84,9 +75,7 @@ JOIN_TIMEOUT = 60.0 # seconds -supported_multiprocessing_contexts = [None] -if torch.multiprocessing._supports_context: - supported_multiprocessing_contexts += list(torch.multiprocessing.get_all_start_methods()) +supported_multiprocessing_contexts = [None] + list(torch.multiprocessing.get_all_start_methods()) @unittest.skipIf( @@ -310,29 +299,25 @@ def test_iterable_dataset_err(self): # takes in dummy var so this can also be used as a `worker_init_fn` def set_faulthander_if_available(_=None): - if HAS_FAULTHANDLER: - faulthandler.enable(sys.__stderr__) - if not IS_WINDOWS: - # windows does not have faulthandler.register - # chain=False prevents the default behavior of killing the process - faulthandler.register(signal.SIGUSR1, file=sys.__stderr__, chain=False) + faulthandler.enable(sys.__stderr__) + if not IS_WINDOWS: + # windows does not have faulthandler.register + # chain=False prevents the default behavior of killing the process + faulthandler.register(signal.SIGUSR1, file=sys.__stderr__, chain=False) set_faulthander_if_available() # Process `pid` must have called `set_faulthander_if_available` def print_traces_of_all_threads(pid): - if HAS_FAULTHANDLER: - if not IS_WINDOWS: - # use the custom signal if available - os.kill(pid, signal.SIGUSR1) - else: - # otherwise we can still use the handler given by faulthandler.enable() - # at the cost of killing the process. - os.kill(pid, signal.SIGSEGV) + if not IS_WINDOWS: + # use the custom signal if available + os.kill(pid, signal.SIGUSR1) else: - # if there is no faulthandler, use SIGINT otherwise and hope for the best - os.kill(pid, signal.SIGINT) + # otherwise we can still use the handler given by faulthandler.enable() + # at the cost of killing the process. + os.kill(pid, signal.SIGSEGV) + # wait in parent process to give subprocess some time to print time.sleep(5) @@ -482,6 +467,17 @@ def __len__(self): return self.size +class EmptyTensorDataset(torch.utils.data.Dataset): + def __init__(self, len): + self.len = len + + def __len__(self): + return self.len + + def __getitem__(self, any): + return torch.empty(0) + + class SynchronizedSeedDataset(SynchronizedDataset): def __getitem__(self, idx): self.sync_once() @@ -502,6 +498,24 @@ def _test_timeout_pin_memory(persistent_workers): _ = next(iter(dataloader)) +def _test_large_sampler_indices(persistent_workers): + # See + # test_large_sampler_indices + # https://github.com/pytorch/pytorch/issues/48666 + + dataloader = torch.utils.data.DataLoader( + EmptyTensorDataset(10000000), + batch_size=40960, + persistent_workers=persistent_workers, + num_workers=1) + + it = iter(dataloader) + + for x in it: + assert x.numel() == 0 + raise RuntimeError('My Error') + + def disable_stderr(worker_id): r""" Avoids printing "ERROR: Unexpected segmentation fault encountered in worker." @@ -710,6 +724,10 @@ def init_fn(worker_id): torch.manual_seed(12345) +def shuffle_ds_init_fn(worker_id): + random.seed(123) + + # used with test_error_in_init class ErrorIterableDataset(IterableDataset): def __iter__(self): @@ -972,6 +990,24 @@ def test_timeout(self): finally: p.terminate() + def test_large_sampler_indices(self): + # Test that the data loader cleanly exit when the process errors + # 1. having an reference to the iterator + # 2. using a sampler that yields big elements s.t. _index_queues putters block + # + # More context: https://github.com/pytorch/pytorch/issues/48666 + + p = ErrorTrackingProcess(target=_test_large_sampler_indices, args=(self.persistent_workers,)) + p.start() + p.join(JOIN_TIMEOUT) + try: + self.assertFalse(p.is_alive()) + self.assertNotEqual(p.exitcode, 0) + self.assertIsInstance(p.exception, RuntimeError) + self.assertRegex(str(p.exception), r'My Error') + finally: + p.terminate() + def test_invalid_ctor_args_combinations(self): # general with self.assertRaisesRegex(ValueError, "num_workers option should be non-negative"): @@ -984,17 +1020,13 @@ def test_invalid_ctor_args_combinations(self): "batch_size=None option disables auto-batching and is mutually exclusive"): self._get_data_loader(self.dataset, batch_size=None, drop_last=True) - if torch.multiprocessing._supports_context: - valid_ctx = list(torch.multiprocessing.get_all_start_methods())[-1] - with self.assertRaisesRegex(ValueError, r"multi-process loading \(num_workers > 0\), but got"): - self._get_data_loader(self.dataset, num_workers=0, multiprocessing_context=valid_ctx) - with self.assertRaisesRegex(ValueError, "should specify a valid start method in"): - self._get_data_loader(self.dataset, num_workers=1, multiprocessing_context='bad') - with self.assertRaisesRegex(TypeError, "multiprocessing_context option should be a valid context "): - self._get_data_loader(self.dataset, num_workers=1, multiprocessing_context=object()) - else: - with self.assertRaisesRegex(ValueError, "multiprocessing_context relies on Python >= 3.4"): - self._get_data_loader(self.dataset, num_workers=1, multiprocessing_context='fork') + valid_ctx = list(torch.multiprocessing.get_all_start_methods())[-1] + with self.assertRaisesRegex(ValueError, r"multi-process loading \(num_workers > 0\), but got"): + self._get_data_loader(self.dataset, num_workers=0, multiprocessing_context=valid_ctx) + with self.assertRaisesRegex(ValueError, "should specify a valid start method in"): + self._get_data_loader(self.dataset, num_workers=1, multiprocessing_context='bad') + with self.assertRaisesRegex(TypeError, "multiprocessing_context option should be a valid context "): + self._get_data_loader(self.dataset, num_workers=1, multiprocessing_context=object()) # map-style sampler = torch.utils.data.SequentialSampler(self.dataset) @@ -1213,6 +1245,37 @@ def test_chain_iterable_style_dataset(self): with self.assertRaisesRegex(AssertionError, "ChainDataset only supports IterableDataset"): list(iter(ChainDataset([dataset1, self.dataset]))) + def test_buffer_shuffle_dataset(self): + dataset = CountingIterableDataset(20) + expected = list(range(20)) + buffer_sizes = [5, 20, 25] + for num_workers in [0, 1]: + # Buffer Size <= 1: Not shuffled dataset + fetched_nos = list(self._get_data_loader(BufferedShuffleDataset(dataset, 1), num_workers=num_workers)) + self.assertEqual(len(fetched_nos), len(expected)) + for e, d in zip(expected, fetched_nos): + self.assertIsInstance(d, torch.Tensor) + self.assertEqual(e, d) + # Buffer Size > 1: Shuffled dataset + for buffer_size in buffer_sizes: + fetched = sorted(list(self._get_data_loader(BufferedShuffleDataset(dataset, buffer_size), num_workers=num_workers))) + self.assertEqual(len(fetched), len(expected)) + for e, d in zip(expected, fetched): + self.assertIsInstance(d, torch.Tensor) + self.assertEqual(e, d) + # Random Seed for single process + random.seed(123) + fetched_seed1 = list(self._get_data_loader(BufferedShuffleDataset(dataset, buffer_size), num_workers=num_workers, + worker_init_fn=shuffle_ds_init_fn)) + random.seed(123) + fetched_seed2 = list(self._get_data_loader(BufferedShuffleDataset(dataset, buffer_size), num_workers=num_workers, + worker_init_fn=shuffle_ds_init_fn)) + self.assertEqual(len(fetched_seed1), len(fetched_seed2)) + for d1, d2 in zip(fetched_seed1, fetched_seed2): + self.assertIsInstance(d1, torch.Tensor) + self.assertIsInstance(d2, torch.Tensor) + self.assertEqual(d1, d2) + def test_multiprocessing_contexts(self): reference = [ torch.arange(3), @@ -1370,6 +1433,15 @@ def test_random_sampler_len_with_replacement(self): self.assertEqual(int(math.ceil(float(num_samples) / batch_size)), count_num_samples_in_data_loader) + def test_distributed_sampler_invalid_rank(self): + from torch.utils.data.distributed import DistributedSampler + dataset = torch.IntTensor(range(10)) + with self.assertRaisesRegex(ValueError, "Invalid rank"): + sampler = DistributedSampler(dataset, 3, 3) + + with self.assertRaisesRegex(ValueError, "Invalid rank"): + sampler = DistributedSampler(dataset, 3, -1) + def test_duplicating_data_with_drop_last(self): from torch.utils.data.distributed import DistributedSampler @@ -1411,7 +1483,7 @@ def _test_sampler(self, **kwargs): def test_sampler(self): self._test_sampler() self._test_sampler(num_workers=4) - if not NO_MULTIPROCESSING_SPAWN and torch.multiprocessing._supports_context: + if not NO_MULTIPROCESSING_SPAWN: self._test_batch_sampler(num_workers=4, multiprocessing_context='spawn') def _test_batch_sampler(self, **kwargs): @@ -1436,7 +1508,7 @@ def _test_batch_sampler(self, **kwargs): def test_batch_sampler(self): self._test_batch_sampler() self._test_batch_sampler(num_workers=4) - if not NO_MULTIPROCESSING_SPAWN and torch.multiprocessing._supports_context: + if not NO_MULTIPROCESSING_SPAWN: self._test_batch_sampler(num_workers=4, multiprocessing_context='spawn') @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") @@ -1493,8 +1565,10 @@ def test_partial_workers(self): pin_memory_thread.join(JOIN_TIMEOUT) self.assertFalse(pin_memory_thread.is_alive()) + # Takes 2.5min to finish, see https://github.com/pytorch/pytorch/issues/46065 @skipIfRocm @unittest.skipIf(not HAS_PSUTIL, "psutil not found") + @slowTest def test_proper_exit(self): (r'''There might be ConnectionResetError or leaked semaphore warning ''' r'''(due to dirty process exit), but they are all safe to ignore''') @@ -1525,7 +1599,7 @@ def test_proper_exit(self): # In all cases, all processes should end properly. if use_workers: exit_methods = [None, 'loader_error', 'loader_kill', 'worker_error', 'worker_kill'] - persistent_workers = self.persistent_workers + persistent_workers = self.persistent_workers else: exit_methods = [None, 'loader_error', 'loader_kill'] persistent_workers = False @@ -1801,6 +1875,12 @@ def test_default_collate_shared_tensor(self): finally: _utils.worker._worker_info = old + def test_excessive_thread_creation_warning(self): + with self.assertWarnsRegex( + UserWarning, + r"excessive worker creation might get DataLoader running slow or even freeze"): + dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000) + class StringDataset(Dataset): def __init__(self): @@ -1951,15 +2031,19 @@ def __next__(self): def test_dataset_not_reset(self): dataset = DummyDataset() - dataloader = self._get_data_loader(dataset, num_workers=2) - dataset.start = 0 - for i in range(10): - for x in dataloader: - pass - # Changing the start value here doesn't have any effect in the dataset - # cached by the workers. since they are not recreated between epochs - # and can cache values safely - dataset.start = i + pin_memory_configs = [False] + if TEST_CUDA: + pin_memory_configs.append(True) + for pin_memory in pin_memory_configs: + dataloader = self._get_data_loader(dataset, num_workers=2, pin_memory=pin_memory) + dataset.start = 0 + for i in range(10): + for x in dataloader: + pass + # Changing the start value here doesn't have any effect in the dataset + # cached by the workers. since they are not recreated between epochs + # and can cache values safely + dataset.start = i diff --git a/test/test_dataset.py b/test/test_dataset.py new file mode 100644 index 0000000000000..a72b87cca5553 --- /dev/null +++ b/test/test_dataset.py @@ -0,0 +1,164 @@ +import tempfile +import warnings + +import torch +from torch.testing._internal.common_utils import (TestCase, run_tests) +from torch.utils.data import IterableDataset, RandomSampler +from torch.utils.data.datasets import \ + (CollateIterableDataset, BatchIterableDataset, ListDirFilesIterableDataset, + LoadFilesFromDiskIterableDataset, SamplerIterableDataset) + + +def create_temp_dir_and_files(): + # The temp dir and files within it will be released and deleted in tearDown(). + # Adding `noqa: P201` to avoid mypy's warning on not releasing the dir handle within this function. + temp_dir = tempfile.TemporaryDirectory() # noqa: P201 + temp_dir_path = temp_dir.name + temp_file1 = tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False) # noqa: P201 + temp_file2 = tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False) # noqa: P201 + temp_file3 = tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False) # noqa: P201 + + return (temp_dir, temp_file1.name, temp_file2.name, temp_file3.name) + + +class TestIterableDatasetBasic(TestCase): + + def setUp(self): + ret = create_temp_dir_and_files() + self.temp_dir = ret[0] + self.temp_files = ret[1:] + + def tearDown(self): + try: + self.temp_dir.cleanup() + except Exception as e: + warnings.warn("TestIterableDatasetBasic was not able to cleanup temp dir due to {}".format(str(e))) + + def test_listdirfiles_iterable_dataset(self): + temp_dir = self.temp_dir.name + dataset = ListDirFilesIterableDataset(temp_dir, '') + for pathname in dataset: + self.assertTrue(pathname in self.temp_files) + + def test_loadfilesfromdisk_iterable_dataset(self): + temp_dir = self.temp_dir.name + dataset1 = ListDirFilesIterableDataset(temp_dir, '') + dataset2 = LoadFilesFromDiskIterableDataset(dataset1) + + for rec in dataset2: + self.assertTrue(rec[0] in self.temp_files) + self.assertTrue(rec[1].read() == open(rec[0], 'rb').read()) + + +class IterDatasetWithoutLen(IterableDataset): + def __init__(self, ds): + super().__init__() + self.ds = ds + + def __iter__(self): + for i in self.ds: + yield i + + +class IterDatasetWithLen(IterableDataset): + def __init__(self, ds): + super().__init__() + self.ds = ds + self.length = len(ds) + + def __iter__(self): + for i in self.ds: + yield i + + def __len__(self): + return self.length + + +class TestFunctionalIterableDataset(TestCase): + def test_collate_dataset(self): + arrs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + ds_len = IterDatasetWithLen(arrs) + ds_nolen = IterDatasetWithoutLen(arrs) + + def _collate_fn(batch): + return torch.tensor(sum(batch), dtype=torch.float) + + collate_ds = CollateIterableDataset(ds_len, collate_fn=_collate_fn) + self.assertEqual(len(ds_len), len(collate_ds)) + ds_iter = iter(ds_len) + for x in collate_ds: + y = next(ds_iter) + self.assertEqual(x, torch.tensor(sum(y), dtype=torch.float)) + + collate_ds_nolen = CollateIterableDataset(ds_nolen) # type: ignore + with self.assertRaises(NotImplementedError): + len(collate_ds_nolen) + ds_nolen_iter = iter(ds_nolen) + for x in collate_ds_nolen: + y = next(ds_nolen_iter) + self.assertEqual(x, torch.tensor(y)) + + def test_batch_dataset(self): + arrs = range(10) + ds = IterDatasetWithLen(arrs) + with self.assertRaises(AssertionError): + batch_ds0 = BatchIterableDataset(ds, batch_size=0) + + # Default not drop the last batch + batch_ds1 = BatchIterableDataset(ds, batch_size=3) + self.assertEqual(len(batch_ds1), 4) + batch_iter = iter(batch_ds1) + value = 0 + for i in range(len(batch_ds1)): + batch = next(batch_iter) + if i == 3: + self.assertEqual(len(batch), 1) + self.assertEqual(batch, [9]) + else: + self.assertEqual(len(batch), 3) + for x in batch: + self.assertEqual(x, value) + value += 1 + + # Drop the last batch + batch_ds2 = BatchIterableDataset(ds, batch_size=3, drop_last=True) + self.assertEqual(len(batch_ds2), 3) + value = 0 + for batch in batch_ds2: + self.assertEqual(len(batch), 3) + for x in batch: + self.assertEqual(x, value) + value += 1 + + batch_ds3 = BatchIterableDataset(ds, batch_size=2) + self.assertEqual(len(batch_ds3), 5) + batch_ds4 = BatchIterableDataset(ds, batch_size=2, drop_last=True) + self.assertEqual(len(batch_ds4), 5) + + ds_nolen = IterDatasetWithoutLen(arrs) + batch_ds_nolen = BatchIterableDataset(ds_nolen, batch_size=5) + with self.assertRaises(NotImplementedError): + len(batch_ds_nolen) + + def test_sampler_dataset(self): + arrs = range(10) + ds = IterDatasetWithLen(arrs) + # Default SequentialSampler + sampled_ds = SamplerIterableDataset(ds) # type: ignore + self.assertEqual(len(sampled_ds), 10) + i = 0 + for x in sampled_ds: + self.assertEqual(x, i) + i += 1 + + # RandomSampler + random_sampled_ds = SamplerIterableDataset(ds, sampler=RandomSampler, replacement=True) # type: ignore + + # Requires `__len__` to build SamplerDataset + ds_nolen = IterDatasetWithoutLen(arrs) + with self.assertRaises(AssertionError): + sampled_ds = SamplerIterableDataset(ds_nolen) + + +if __name__ == '__main__': + run_tests() diff --git a/test/test_determination.py b/test/test_determination.py index 0f860cab51014..7e9420285e5a9 100644 --- a/test/test_determination.py +++ b/test/test_determination.py @@ -112,6 +112,7 @@ def test_torch_file(self): "distributed/test_distributed_fork", "test_cpp_extensions_aot_ninja", "test_cpp_extensions_aot_no_ninja", + "test_utils", "test_determination", ], ) diff --git a/test/test_dispatch.py b/test/test_dispatch.py index 45480d8916f07..092d2564123bc 100644 --- a/test/test_dispatch.py +++ b/test/test_dispatch.py @@ -3,7 +3,6 @@ from collections import namedtuple import itertools -import unittest import re # TODO: Expand the dispatcher API to be a generic API for interfacing with @@ -25,6 +24,7 @@ Result = namedtuple('Result', 'state table provenance') dispatch_keys_to_check = ( + 'Undefined', 'CPU', 'CUDA', 'XLA', @@ -243,7 +243,7 @@ def test_def(self): CPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] AutogradCPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] Autograd[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -catchall: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +Math[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') def test_def_impl_schema_mismatch(self): @@ -255,7 +255,14 @@ def test_def_impl_schema_mismatch(self): # m.impl("foo", [](const Tensor & x) { return x }) lambda m: m.impl_t_t("foo"), ], expect_raises=True).state - self.assertExpectedInline(state, '''In registration for test::foo: expected schema of operator to be "test::foo(Tensor x, Tensor y) -> (Tensor)" (registered at /dev/null:0), but got inferred schema "(Tensor _0) -> (Tensor _0)" (impl_t_t). The number of arguments is different. 2 vs 1.''') # noqa + self.assertExpectedInline(state, '''\ +Inferred operator schema for a C++ kernel function doesn't match the expected function schema. + operator: test::foo + expected schema: test::foo(Tensor x, Tensor y) -> (Tensor) + registered at /dev/null:0 + inferred schema: (Tensor _0) -> (Tensor _0) + impl_t_t + reason: The number of arguments is different. 2 vs 1.''') # noqa def test_def_with_inference(self): state = self.commute("foo", [ @@ -276,7 +283,7 @@ def test_def_with_inference(self): CPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] AutogradCPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] Autograd[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +Math[alias]: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') def test_def_only(self): @@ -308,7 +315,7 @@ def test_impl_only(self): CPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] AutogradCPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] Autograd[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -catchall: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +Math[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') def test_computed_table(self): @@ -334,23 +341,24 @@ def test_computed_table(self): XLA: fn_xla :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] AutogradCPU: fn_autogradcpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] Autograd[alias]: fn_autograd :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +Math[alias]: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') # computed dispatch table is too big, so we only check on a few entries we're interested in. extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check) self.assertExpectedInline(extracted_table, '''\ +Undefined: default_def_name_t_t [math kernel] CPU: fn_cpu [kernel] -CUDA: default_def_name_t_t [catch all] +CUDA: default_def_name_t_t [math kernel] XLA: fn_xla [kernel] -AutogradOther: fn_autograd [autograd kernel] +AutogradOther: default_def_name_t_t [math kernel] AutogradCPU: fn_autogradcpu [kernel] -AutogradCUDA: fn_autograd [autograd kernel] +AutogradCUDA: default_def_name_t_t [math kernel] AutogradXLA: fn_autograd [autograd kernel] ''') - def test_computed_table_with_cpu_catchall(self): + def test_computed_table_with_cpu_math_autogradcpu_fallthrough(self): global_m = C._dispatch_library("IMPL", "_", "AutogradCPU") result = self.commute("foo", [ # m.def("foo", [](const Tensor & x) { return x }) @@ -365,20 +373,21 @@ def test_computed_table_with_cpu_catchall(self): debug: registered at /dev/null:0 alias analysis kind: CONSERVATIVE CPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +Math[alias]: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') # computed dispatch table is too big, so we only check on a few entries we're interested in. extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check) self.assertExpectedInline(extracted_table, '''\ +Undefined: default_def_name_t_t [math kernel] CPU: impl_t_t [kernel] -CUDA: default_def_name_t_t [catch all] -XLA: default_def_name_t_t [catch all] -AutogradOther: default_def_name_t_t [catch all] +CUDA: default_def_name_t_t [math kernel] +XLA: default_def_name_t_t [math kernel] +AutogradOther: default_def_name_t_t [math kernel] AutogradCPU: fallthrough registered in pytorch framework [backend fallback] -AutogradCUDA: default_def_name_t_t [catch all] -AutogradXLA: default_def_name_t_t [catch all] +AutogradCUDA: default_def_name_t_t [math kernel] +AutogradXLA: default_def_name_t_t [math kernel] ''') def test_computed_table_with_math(self): @@ -402,6 +411,7 @@ def test_computed_table_with_math(self): extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check) self.assertExpectedInline(extracted_table, '''\ +Undefined: impl_t_t [math kernel] CPU: impl_t_t [math kernel] CUDA: impl_t_t [math kernel] XLA: impl_t_t [math kernel] @@ -435,6 +445,7 @@ def test_computed_table_with_cpu_math(self): extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check) self.assertExpectedInline(extracted_table, '''\ +Undefined: fn_math [math kernel] CPU: fn_cpu [kernel] CUDA: fn_math [math kernel] XLA: fn_math [math kernel] @@ -471,10 +482,11 @@ def test_computed_table_with_autograd(self): AutogradXLA: impl_t_t [autograd kernel] ''') - def test_computed_table_with_cpu_autograd_math_catchall(self): + # Now that catchAll maps to Math, registering to both catchAll and Math breaks commutativity. + def test_computed_table_with_cpu_autograd_math(self): result = self.commute("foo", [ - # m.def("foo", [](const Tensor & x) { return x }) - lambda m: m.def_name_t_t("foo"), + # m.def("foo(Tensor x) -> Tensor") + lambda m: m.def_("foo(Tensor x) -> Tensor"), # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"), # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x }) @@ -485,19 +497,19 @@ def test_computed_table_with_cpu_autograd_math_catchall(self): state, table = result.state, result.table self.assertExpectedInline(state, '''\ name: test::foo -schema: test::foo(Tensor _0) -> (Tensor _0) +schema: test::foo(Tensor x) -> (Tensor) debug: registered at /dev/null:0 -alias analysis kind: CONSERVATIVE +alias analysis kind: FROM_SCHEMA CPU: fn_cpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] Autograd[alias]: fn_autograd :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] Math[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') # computed dispatch table is too big, so we only check on a few entries we're interested in. extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check) self.assertExpectedInline(extracted_table, '''\ +Undefined: fn_math [math kernel] CPU: fn_cpu [kernel] CUDA: fn_math [math kernel] XLA: fn_math [math kernel] @@ -507,83 +519,162 @@ def test_computed_table_with_cpu_autograd_math_catchall(self): AutogradXLA: fn_math [math kernel] ''') - def test_computed_table_with_cpu_autograd_catchall(self): + def test_computed_table_with_ambiguous_autogradother(self): result = self.commute("foo", [ - # m.def("foo", [](const Tensor & x) { return x }) - lambda m: m.def_name_t_t("foo"), + # m.def("foo(Tensor x) -> Tensor") + lambda m: m.def_("foo(Tensor x) -> Tensor"), + # m.impl("foo", torch::kMath, [](const Tensor & x) { return x }) + lambda m: m.impl_t_t("foo", "Math", debug="fn_math"), + # m.impl("foo", torch::kQuantizedCPU, [](const Tensor & x) { return x }) + lambda m: m.impl_t_t("foo", "QuantizedCPU", debug="fn_quantizedcpu"), + ]) + state, table = result.state, result.table + self.assertExpectedInline(state, '''\ +name: test::foo +schema: test::foo(Tensor x) -> (Tensor) +debug: registered at /dev/null:0 +alias analysis kind: FROM_SCHEMA +QuantizedCPU: fn_quantizedcpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +Math[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +''') + + # computed dispatch table is too big, so we only check on a few entries we're interested in. + extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',)) + + self.assertExpectedInline(extracted_table, '''\ +Undefined: fn_math [math kernel] +CPU: fn_math [math kernel] +CUDA: fn_math [math kernel] +XLA: fn_math [math kernel] +AutogradOther: ambiguous_autogradother [ambiguous autogradother] +AutogradCPU: fn_math [math kernel] +AutogradCUDA: fn_math [math kernel] +AutogradXLA: fn_math [math kernel] +QuantizedCPU: fn_quantizedcpu [kernel] +''') + + def test_computed_table_with_cpu_defaultbackend(self): + result = self.commute("foo", [ + # m.def("foo(Tensor x) -> Tensor") + lambda m: m.def_("foo(Tensor x) -> Tensor"), + # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) + lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"), + # m.impl("foo", torch::kDefaultBackend, [](const Tensor & x) { return x }) + lambda m: m.impl_t_t("foo", "DefaultBackend", debug="fn_defaultbackend"), + ]) + state, table = result.state, result.table + self.assertExpectedInline(state, '''\ +name: test::foo +schema: test::foo(Tensor x) -> (Tensor) +debug: registered at /dev/null:0 +alias analysis kind: FROM_SCHEMA +CPU: fn_cpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +DefaultBackend[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +''') + + # computed dispatch table is too big, so we only check on a few entries we're interested in. + extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check) + + self.assertExpectedInline(extracted_table, '''\ +Undefined: fn_defaultbackend [default backend kernel] +CPU: fn_cpu [kernel] +CUDA: fn_defaultbackend [default backend kernel] +XLA: fn_defaultbackend [default backend kernel] +AutogradOther: fallthrough registered in pytorch framework [backend fallback] +AutogradCPU: fallthrough registered in pytorch framework [backend fallback] +AutogradCUDA: fallthrough registered in pytorch framework [backend fallback] +AutogradXLA: fallthrough registered in pytorch framework [backend fallback] +''') + + def test_computed_table_with_cpu_autograd_defaultbackend(self): + result = self.commute("foo", [ + # m.def("foo(Tensor x) -> Tensor") + lambda m: m.def_("foo(Tensor x) -> Tensor"), # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"), # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x }) lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"), + # m.impl("foo", torch::kDefaultBackend, [](const Tensor & x) { return x }) + lambda m: m.impl_t_t("foo", "DefaultBackend", debug="fn_defaultbackend"), ]) state, table = result.state, result.table self.assertExpectedInline(state, '''\ name: test::foo -schema: test::foo(Tensor _0) -> (Tensor _0) +schema: test::foo(Tensor x) -> (Tensor) debug: registered at /dev/null:0 -alias analysis kind: CONSERVATIVE +alias analysis kind: FROM_SCHEMA CPU: fn_cpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] Autograd[alias]: fn_autograd :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +DefaultBackend[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') # computed dispatch table is too big, so we only check on a few entries we're interested in. - extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check) + extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',)) self.assertExpectedInline(extracted_table, '''\ +Undefined: fn_defaultbackend [default backend kernel] CPU: fn_cpu [kernel] -CUDA: default_def_name_t_t [catch all] -XLA: default_def_name_t_t [catch all] +CUDA: fn_defaultbackend [default backend kernel] +XLA: fn_defaultbackend [default backend kernel] AutogradOther: fn_autograd [autograd kernel] AutogradCPU: fn_autograd [autograd kernel] AutogradCUDA: fn_autograd [autograd kernel] AutogradXLA: fn_autograd [autograd kernel] +QuantizedCPU: fn_defaultbackend [default backend kernel] ''') - def test_computed_table_with_ambiguous_autogradother(self): + def test_computed_table_with_cpu_autograd_math_defaultbackend(self): result = self.commute("foo", [ - # m.def("foo", [](const Tensor & x) { return x }) - lambda m: m.def_name_t_t("foo"), + # m.def("foo(Tensor x) -> Tensor") + lambda m: m.def_("foo(Tensor x) -> Tensor"), + # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) + lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"), + # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x }) + lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"), # m.impl("foo", torch::kMath, [](const Tensor & x) { return x }) lambda m: m.impl_t_t("foo", "Math", debug="fn_math"), - # m.impl("foo", torch::kQuantizedCPU, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "QuantizedCPU", debug="fn_quantizedcpu"), + # m.impl("foo", torch::kDefaultBackend, [](const Tensor & x) { return x }) + lambda m: m.impl_t_t("foo", "DefaultBackend", debug="fn_defaultbackend"), ]) state, table = result.state, result.table self.assertExpectedInline(state, '''\ name: test::foo -schema: test::foo(Tensor _0) -> (Tensor _0) +schema: test::foo(Tensor x) -> (Tensor) debug: registered at /dev/null:0 -alias analysis kind: CONSERVATIVE -QuantizedCPU: fn_quantizedcpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +alias analysis kind: FROM_SCHEMA +CPU: fn_cpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +Autograd[alias]: fn_autograd :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] Math[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +DefaultBackend[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') # computed dispatch table is too big, so we only check on a few entries we're interested in. extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check) self.assertExpectedInline(extracted_table, '''\ -CPU: fn_math [math kernel] -CUDA: fn_math [math kernel] -XLA: fn_math [math kernel] -AutogradOther: ambiguous_autogradother [ambiguous autogradother] -AutogradCPU: fn_math [math kernel] -AutogradCUDA: fn_math [math kernel] -AutogradXLA: fn_math [math kernel] +Undefined: fn_defaultbackend [default backend kernel] +CPU: fn_cpu [kernel] +CUDA: fn_defaultbackend [default backend kernel] +XLA: fn_defaultbackend [default backend kernel] +AutogradOther: fn_autograd [autograd kernel] +AutogradCPU: fn_autograd [autograd kernel] +AutogradCUDA: fn_autograd [autograd kernel] +AutogradXLA: fn_autograd [autograd kernel] ''') - # Can't do this yet for BC reasons - @unittest.expectedFailure def test_multiple_def_error(self): - state = self.commute("foo", [ + ops = [ # m.def("foo(Tensor x, Tensor y) -> Tensor") lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"), # m.def("foo(Tensor x, Tensor y) -> Tensor") lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"), - ], expect_raises=True).state - # TODO: fill in the error message here - # self.assertExpectedInline(state, '''''') + ] + self.assertExpectedInline( + self.commute("foo", ops, expect_raises=True).state, + '''Tried to register an operator (test::foo(Tensor x, Tensor y) -> (Tensor)) with the same name and overload ''' + '''name multiple times. Each overload's schema should only be registered with a single call to def(). ''' + '''Duplicate registration: registered at /dev/null:0. Original registration: registered at /dev/null:0''' + ) def test_def_with_explicit_alias(self): state = self.commute("foo", [ @@ -600,26 +691,7 @@ def test_def_with_explicit_alias(self): alias analysis kind: PURE_FUNCTION ''') - # TODO: get rid of this test when multiple defs are wrong - def test_multiple_def_schema_mismatch(self): - # error message is order dependent - ops = [ - # m.def("foo(Tensor x, Tensor y) -> Tensor") - lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"), - # m.def("foo(Tensor x) -> Tensor") - lambda m: m.def_("foo(Tensor x) -> Tensor"), - ] - self.assertExpectedInline( - self.commute("foo", ops, ctor_order=(0, 1), expect_raises=True).state, - '''Tried to register multiple operators with the same name and the same overload name but different schemas: test::foo(Tensor x) -> (Tensor) (registered at /dev/null:0) vs test::foo(Tensor x, Tensor y) -> (Tensor) (registered at /dev/null:0)''' # noqa - ) - self.assertExpectedInline( - self.commute("foo", ops, ctor_order=(1, 0), expect_raises=True).state, - '''Tried to register multiple operators with the same name and the same overload name but different schemas: test::foo(Tensor x, Tensor y) -> (Tensor) (registered at /dev/null:0) vs test::foo(Tensor x) -> (Tensor) (registered at /dev/null:0)''' # noqa - ) - def test_multiple_def_alias_defaulting(self): - # TODO: should be an error in both directions soon ops = [ # m.def(torch::schema("foo(Tensor x) -> Tensor", # c10::AliasAnalysisKind::PURE_FUNCTION)) @@ -627,25 +699,14 @@ def test_multiple_def_alias_defaulting(self): # RegisterOperators().op("foo(Tensor x) -> Tensor") lambda m: m.def_legacy("foo(Tensor x) -> Tensor"), ] - state = self.commute("foo", ops, ctor_order=(0, 1)).state self.assertExpectedInline( - state, - '''\ -name: test::foo -schema: test::foo(Tensor x) -> (Tensor) -debug: registered at /dev/null:0 -alias analysis kind: PURE_FUNCTION -''' + self.commute("foo", ops, expect_raises=True).state, + '''Tried to register an operator (test::foo(Tensor x) -> (Tensor)) with the same name and overload ''' + '''name multiple times. Each overload's schema should only be registered with a single call to def(). ''' + '''Duplicate registration: registered at /dev/null:0. Original registration: registered at /dev/null:0''' ) - # NB: When run with ctor order (1, 0), the destructors are NOT - # COMMUTATIVE. THIS IS A BUG, however we are purposely leaving the bug - # in as it is very benign (only leaves us in a bad state during - # destruction, when no useful work is being done), will be fixed when we - # make alias defaulting a hard error, and is very nontrivial to fix - # prior to that. def test_multiple_def_alias_mismatch(self): - # error message is order dependent ops = [ # m.def(torch::schema("foo(Tensor x) -> Tensor", # c10::AliasAnalysisKind::PURE_FUNCTION)) @@ -655,12 +716,10 @@ def test_multiple_def_alias_mismatch(self): lambda m: m.def_("foo(Tensor x) -> Tensor", alias="CONSERVATIVE"), ] self.assertExpectedInline( - self.commute("foo", ops, ctor_order=(0, 1), expect_raises=True).state, - '''Tried to define the schema for test::foo with different alias analysis kinds: PURE_FUNCTION (registered at /dev/null:0) vs CONSERVATIVE (registered at /dev/null:0)''' # noqa - ) - self.assertExpectedInline( - self.commute("foo", ops, ctor_order=(1, 0), expect_raises=True).state, - '''Tried to define the schema for test::foo with different alias analysis kinds: CONSERVATIVE (registered at /dev/null:0) vs PURE_FUNCTION (registered at /dev/null:0)''' # noqa + self.commute("foo", ops, expect_raises=True).state, + '''Tried to register an operator (test::foo(Tensor x) -> (Tensor)) with the same name and overload ''' + '''name multiple times. Each overload's schema should only be registered with a single call to def(). ''' + '''Duplicate registration: registered at /dev/null:0. Original registration: registered at /dev/null:0''' # noqa ) def test_multiple_fallback(self): @@ -671,12 +730,13 @@ def test_multiple_fallback(self): except RuntimeError as e: self.assertExpectedInline( str(e), - '''Tried to register multiple backend fallbacks for the same dispatch key XLA; previous registration registered at /dev/null:0, new registration registered at /dev/null:0''' # noqa + '''Tried to register multiple backend fallbacks for the same dispatch key XLA; previous registration ''' + '''registered at /dev/null:0, new registration registered at /dev/null:0''' # noqa ) else: self.assertTrue(False) - def test_overwrite_catchall(self): + def test_overwrite_math(self): ops = [ lambda m: m.impl_t_t("foo", debug="fn1"), lambda m: m.impl_t_t("foo", debug="fn2"), @@ -687,8 +747,8 @@ def test_overwrite_catchall(self): '''\ name: test::foo schema: (none) -catchall: fn2 :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -catchall (inactive): fn1 :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +Math[alias]: fn2 :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +Math[alias] (inactive): fn1 :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''' ) diff --git a/test/test_expecttest.py b/test/test_expecttest.py index 652a33c418699..39b6f44136761 100644 --- a/test/test_expecttest.py +++ b/test/test_expecttest.py @@ -1,9 +1,10 @@ -import torch.testing._internal.expecttest as expecttest +from torch.testing._internal import expecttest +from torch.testing._internal.common_utils import TestCase, run_tests -import unittest import string import textwrap import doctest +from typing import Dict, Any import hypothesis from hypothesis.strategies import text, integers, composite, sampled_from, booleans @@ -16,7 +17,7 @@ def text_lineno(draw): return (t, lineno) -class TestExpectTest(expecttest.TestCase): +class TestExpectTest(TestCase): @hypothesis.given(text_lineno()) def test_nth_line_ref(self, t_lineno): t, lineno = t_lineno @@ -38,7 +39,7 @@ def test_replace_string_literal_roundtrip(self, t, raw, quote): r3 = {r}{quote}placeholder3{quote} """.format(r='r' if raw else '', quote=quote * 3) new_prog = expecttest.replace_string_literal(textwrap.dedent(prog), 2, t)[0] - ns = {} + ns : Dict[str, Any] = {} exec(new_prog, ns) msg = "program was:\n{}".format(new_prog) self.assertEqual(ns['r'], 'placeholder', msg=msg) # noqa: F821 @@ -102,4 +103,4 @@ def load_tests(loader, tests, ignore): if __name__ == '__main__': - unittest.main() + run_tests() diff --git a/test/test_foreach.py b/test/test_foreach.py index 8369ba5b9be5c..c55c4e71dab02 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -1,24 +1,69 @@ import torch import unittest -from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, TEST_WITH_SLOW from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, skipCUDAIfRocm +from torch._six import inf, nan + +N_values = [20] if not TEST_WITH_SLOW else [30, 300] class TestForeach(TestCase): - bin_ops = [ + foreach_bin_ops = [ torch._foreach_add, - torch._foreach_add_, torch._foreach_sub, - torch._foreach_sub_, torch._foreach_mul, - torch._foreach_mul_, torch._foreach_div, + ] + + foreach_bin_ops_ = [ + torch._foreach_add_, + torch._foreach_sub_, + torch._foreach_mul_, torch._foreach_div_, ] + torch_bin_ops = [ + torch.add, + torch.sub, + torch.mul, + torch.div, + ] + + unary_ops = [ + # foreach_op, foreach_op_, torch_op, bf16, complex64/128 + (torch._foreach_sqrt, torch._foreach_sqrt_, torch.sqrt, True , True), + (torch._foreach_exp, torch._foreach_exp_, torch.exp, True, True), + (torch._foreach_acos, torch._foreach_acos_, torch.acos, False, True), + (torch._foreach_asin, torch._foreach_asin_, torch.asin, False, True), + (torch._foreach_atan, torch._foreach_atan_, torch.atan, False, True), + (torch._foreach_cos, torch._foreach_cos_, torch.cos, True, True), + (torch._foreach_cosh, torch._foreach_cosh_, torch.cosh, False, True), + (torch._foreach_log, torch._foreach_log_, torch.log, True, True), + (torch._foreach_log10, torch._foreach_log10_, torch.log10, True, True), + (torch._foreach_log2, torch._foreach_log2_, torch.log2, True, True), + (torch._foreach_neg, torch._foreach_neg_, torch.neg, True, True), + (torch._foreach_tan, torch._foreach_tan_, torch.tan, False, True), + (torch._foreach_tanh, torch._foreach_tanh_, torch.tanh, True, True), + (torch._foreach_sin, torch._foreach_sin_, torch.sin, False, True), + (torch._foreach_sinh, torch._foreach_sinh_, torch.sinh, False, True), + (torch._foreach_ceil, torch._foreach_ceil_, torch.ceil, False, False), + (torch._foreach_erf, torch._foreach_erf_, torch.erf, True, False), + (torch._foreach_erfc, torch._foreach_erfc_, torch.erfc, False, False), + (torch._foreach_expm1, torch._foreach_expm1_, torch.expm1, False, False), + (torch._foreach_floor, torch._foreach_floor_, torch.floor, False, False), + (torch._foreach_log1p, torch._foreach_log1p_, torch.log1p, True, False), + (torch._foreach_round, torch._foreach_round_, torch.round, False, False), + (torch._foreach_frac, torch._foreach_frac_, torch.frac, False, False), + (torch._foreach_reciprocal, torch._foreach_reciprocal_, torch.reciprocal, True, True), + (torch._foreach_sigmoid, torch._foreach_sigmoid_, torch.sigmoid, True, False), + (torch._foreach_trunc, torch._foreach_trunc_, torch.trunc, False, False), + + # See test_abs + # (torch._foreach_abs, torch._foreach_abs_, torch.abs, True, True), + ] + def _get_test_data(self, device, dtype, N): if dtype in [torch.bfloat16, torch.bool, torch.float16]: tensors = [torch.randn(N, N, device=device).to(dtype) for _ in range(N)] - elif dtype in torch.testing.get_all_int_dtypes(): tensors = [torch.randint(1, 100, (N, N), device=device, dtype=dtype) for _ in range(N)] else: @@ -26,69 +71,185 @@ def _get_test_data(self, device, dtype, N): return tensors - def _test_bin_op_list(self, device, dtype, foreach_op, foreach_op_, torch_op, N=20): - tensors1 = self._get_test_data(device, dtype, N) - tensors2 = self._get_test_data(device, dtype, N) - - expected = [torch_op(tensors1[i], tensors2[i]) for i in range(N)] - res = foreach_op(tensors1, tensors2) - foreach_op_(tensors1, tensors2) - self.assertEqual(res, tensors1) - self.assertEqual(tensors1, expected) - - def _test_unary_op(self, device, dtype, foreach_op, foreach_op_, torch_op, N=20): - tensors1 = self._get_test_data(device, dtype, N) - expected = [torch_op(tensors1[i]) for i in range(N)] - res = foreach_op(tensors1) - foreach_op_(tensors1) - self.assertEqual(res, tensors1) - self.assertEqual(tensors1, expected) - - def _test_pointwise_op(self, device, dtype, foreach_op, foreach_op_, torch_op, N=20): - tensors = self._get_test_data(device, dtype, N) - tensors1 = self._get_test_data(device, dtype, N) - tensors2 = self._get_test_data(device, dtype, N) - value = 2 - - expected = [torch_op(tensors[i], tensors1[i], tensors2[i], value=value) for i in range(N)] - - res = foreach_op(tensors, tensors1, tensors2, value) - foreach_op_(tensors, tensors1, tensors2, value) - self.assertEqual(res, tensors) - self.assertEqual(tensors, expected) - - def _test_bin_op_list_alpha(self, device, dtype, foreach_op, foreach_op_, torch_op, N=20): - tensors1 = self._get_test_data(device, dtype, N) - tensors2 = self._get_test_data(device, dtype, N) - alpha = 2 - - expected = [torch_op(tensors1[i], torch.mul(tensors2[i], alpha)) for i in range(N)] - res = foreach_op(tensors1, tensors2, alpha) - foreach_op_(tensors1, tensors2, alpha) - self.assertEqual(res, tensors1) - - if dtype == torch.bool: - expected = [e.to(torch.bool) for e in expected] - self.assertEqual(tensors1, expected) + def _test_bin_op_list(self, device, dtype, foreach_op, foreach_op_, torch_op): + for N in N_values: + tensors1 = self._get_test_data(device, dtype, N) + tensors2 = self._get_test_data(device, dtype, N) + + # Mimics cuda kernel dtype flow. With fp16/bf16 input, runs in fp32 and casts output back to fp16/bf16. + control_dtype = torch.float32 if (self.device_type == 'cuda' and + (dtype is torch.float16 or dtype is torch.bfloat16)) else dtype + expected = [torch_op(tensors1[i].to(dtype=control_dtype), + tensors2[i].to(dtype=control_dtype)).to(dtype=dtype) for i in range(N)] + res = foreach_op(tensors1, tensors2) + foreach_op_(tensors1, tensors2) + self.assertEqual(res, tensors1) + if (dtype is torch.float16 or dtype is torch.bfloat16) and TEST_WITH_ROCM: + self.assertEqual(tensors1, expected, atol=1.e-3, rtol=self.dtype_precisions[dtype][0]) + else: + self.assertEqual(tensors1, expected) + + def _test_pointwise_op(self, device, dtype, foreach_op, foreach_op_, torch_op): + for N in N_values: + values = [2 + i for i in range(N)] + for vals in [values[0], values]: + tensors = self._get_test_data(device, dtype, N) + tensors1 = self._get_test_data(device, dtype, N) + tensors2 = self._get_test_data(device, dtype, N) + + # Mimics cuda kernel dtype flow. With fp16/bf16 input, runs in fp32 and casts output back to fp16/bf16. + control_dtype = torch.float32 if (self.device_type == 'cuda' and + (dtype is torch.float16 or dtype is torch.bfloat16)) else dtype + + if not isinstance(vals, list): + expected = [torch_op(tensors[i].to(dtype=control_dtype), + tensors1[i].to(dtype=control_dtype), + tensors2[i].to(dtype=control_dtype), + value=values[0]).to(dtype=dtype) for i in range(N)] + else: + expected = [torch_op(tensors[i].to(dtype=control_dtype), + tensors1[i].to(dtype=control_dtype), + tensors2[i].to(dtype=control_dtype), + value=values[i]).to(dtype=dtype) for i in range(N)] + + res = foreach_op(tensors, tensors1, tensors2, vals) + foreach_op_(tensors, tensors1, tensors2, vals) + self.assertEqual(res, tensors) + + if (dtype is torch.float16 or dtype is torch.bfloat16) and TEST_WITH_ROCM: + self.assertEqual(tensors, expected, atol=1.e-3, rtol=self.dtype_precisions[dtype][0]) + else: + self.assertEqual(tensors, expected) + + # test error cases + for op in [torch._foreach_addcmul, torch._foreach_addcmul_, torch._foreach_addcdiv, torch._foreach_addcdiv_]: + tensors = self._get_test_data(device, dtype, N) + tensors1 = self._get_test_data(device, dtype, N) + tensors2 = self._get_test_data(device, dtype, N) + + with self.assertRaisesRegex(RuntimeError, "Tensor list must have same number of elements as scalar list."): + op(tensors, tensors1, tensors2, [2 for _ in range(N + 1)]) + + with self.assertRaisesRegex(RuntimeError, "Tensor list must have same number of elements as scalar list."): + op(tensors, tensors1, tensors2, [2 for _ in range(N - 1)]) + + tensors = self._get_test_data(device, dtype, N + 1) + with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 21 and 20"): + op(tensors, tensors1, tensors2, [2 for _ in range(N)]) + + tensors1 = self._get_test_data(device, dtype, N + 1) + with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 21 and 20"): + op(tensors, tensors1, tensors2, [2 for _ in range(N)]) + + def _test_bin_op_list_alpha(self, device, dtype, foreach_op, foreach_op_, torch_op): + for N in [30, 300]: + tensors1 = self._get_test_data(device, dtype, N) + tensors2 = self._get_test_data(device, dtype, N) + alpha = 2 + + # Mimics cuda kernel dtype flow. With fp16/bf16 input, runs in fp32 and casts output back to fp16/bf16. + control_dtype = torch.float32 if (self.device_type == 'cuda' and + (dtype is torch.float16 or dtype is torch.bfloat16)) else dtype + expected = [torch_op(tensors1[i].to(dtype=control_dtype), + torch.mul(tensors2[i].to(dtype=control_dtype), + alpha)).to(dtype=dtype) for i in range(N)] + res = foreach_op(tensors1, tensors2, alpha=alpha) + foreach_op_(tensors1, tensors2, alpha=alpha) + self.assertEqual(res, tensors1) + + if dtype == torch.bool: + expected = [e.to(torch.bool) for e in expected] + if (dtype is torch.float16 or dtype is torch.bfloat16) and TEST_WITH_ROCM: + self.assertEqual(tensors1, expected, atol=1.e-3, rtol=self.dtype_precisions[dtype][0]) + else: + self.assertEqual(tensors1, expected) # # Unary ops # - @dtypes(*[torch.float, torch.double, torch.complex64, torch.complex128]) - def test_sqrt(self, device, dtype): - self._test_unary_op(device, dtype, torch._foreach_sqrt, torch._foreach_sqrt_, torch.sqrt) + @dtypes(*(torch.testing.floating_and_complex_types_and(torch.bfloat16, torch.half))) + def test_unary_ops(self, device, dtype): + for fe_op, fe_op_, torch_op, support_bfloat16, support_complex in self.unary_ops: + for N in N_values: + tensors1 = self._get_test_data(device, dtype, N) + # Mimics cuda kernel dtype flow. With fp16/bf16 input, runs in fp32 and casts output back to fp16/bf16. + control_dtype = torch.float32 if (self.device_type == 'cuda' and + (dtype is torch.float16 or dtype is torch.bfloat16)) else dtype + + if self.device_type == 'cpu' and dtype == torch.half and torch_op not in [torch.neg, torch.frac, torch.reciprocal]: + with self.assertRaisesRegex(RuntimeError, r"not implemented for \'Half\'"): + expected = [torch_op(tensors1[i]) for i in range(N)] + + with self.assertRaisesRegex(RuntimeError, r"not implemented for \'Half\'"): + res = fe_op(tensors1) + break + + if dtype == torch.bfloat16 and not support_bfloat16: + if self.device_type == 'cuda' or torch_op in [torch.sinh, torch.cosh]: + with self.assertRaisesRegex(RuntimeError, r"not implemented for \'BFloat16\'"): + expected = [torch_op(tensors1[i]) for i in range(N)] + + with self.assertRaisesRegex(RuntimeError, r"not implemented for \'BFloat16\'"): + res = fe_op(tensors1) + break + + if dtype in [torch.complex64, torch.complex128] and not support_complex: + if not (self.device_type == 'cpu' and torch_op in [torch.sigmoid]): + # not using assertRaisesRegex due to different error messages + with self.assertRaises(RuntimeError): + expected = [torch_op(tensors1[i]) for i in range(N)] + + with self.assertRaises(RuntimeError): + res = fe_op(tensors1) + break + + expected = [torch_op(tensors1[i].to(dtype=control_dtype)).to(dtype=dtype) for i in range(N)] + res = fe_op(tensors1) + if (dtype is torch.float16 or dtype is torch.bfloat16) and TEST_WITH_ROCM: + self.assertEqual(res, expected, atol=1.e-3, rtol=self.dtype_precisions[dtype][0]) + + fe_op_(tensors1) + self.assertEqual(res, tensors1) + else: + self.assertEqual(res, expected) + + fe_op_(tensors1) + self.assertEqual(res, tensors1) + + # Separate test for abs due to a lot of special cases + # Absolute value of a complex number a + bj is defined as sqrt(a^2 + b^2), i.e. a floating point + @dtypes(*(torch.testing.floating_and_complex_types_and(torch.bfloat16, torch.half))) + def test_abs(self, device, dtype): + for N in N_values: + tensors1 = self._get_test_data(device, dtype, N) + # Mimics cuda kernel dtype flow. With fp16/bf16 input, runs in fp32 and casts output back to fp16/bf16. + control_dtype = torch.float32 if (self.device_type == 'cuda' and + (dtype is torch.float16 or dtype is torch.bfloat16)) else dtype + + expected = [torch.abs(tensors1[i].to(dtype=control_dtype)).to(dtype=dtype) for i in range(N)] + res = torch._foreach_abs(tensors1) + if (dtype is torch.float16 or dtype is torch.bfloat16) and TEST_WITH_ROCM: + self.assertEqual(res, expected, atol=1.e-3, rtol=self.dtype_precisions[dtype][0]) + + torch._foreach_abs_(tensors1) + self.assertEqual(res, tensors1) + else: + expected = [torch.abs(tensors1[i]) for i in range(N)] + self.assertEqual(res, expected) - @dtypes(*[torch.float, torch.double, torch.complex64, torch.complex128]) - def test_exp(self, device, dtype): - self._test_unary_op(device, dtype, torch._foreach_exp, torch._foreach_exp_, torch.exp) + if dtype in [torch.complex64, torch.complex128]: + with self.assertRaisesRegex(RuntimeError, r"In-place abs is not supported for complex tensors."): + torch._foreach_abs_(tensors1) + else: + torch._foreach_abs_(tensors1) + self.assertEqual(res, tensors1) # # Pointwise ops # - @skipCUDAIfRocm @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False)) def test_addcmul(self, device, dtype): - if device == 'cpu': + if self.device_type == 'cpu': if dtype == torch.half: with self.assertRaisesRegex(RuntimeError, r"\"addcmul_cpu_out\" not implemented for \'Half\'"): self._test_pointwise_op(device, dtype, torch._foreach_addcmul, @@ -105,7 +266,7 @@ def test_addcdiv(self, device, dtype): self._test_pointwise_op(device, dtype, torch._foreach_addcdiv, torch._foreach_addcdiv_, torch.addcdiv) return - if device == 'cpu': + if self.device_type == 'cpu': if dtype == torch.half: with self.assertRaisesRegex(RuntimeError, r"\"addcdiv_cpu_out\" not implemented for \'Half\'"): self._test_pointwise_op(device, dtype, torch._foreach_addcdiv, @@ -113,88 +274,481 @@ def test_addcdiv(self, device, dtype): return self._test_pointwise_op(device, dtype, torch._foreach_addcdiv, torch._foreach_addcdiv_, torch.addcdiv) + @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False)) + def test_min_max(self, device, dtype): + for N in N_values: + tensors1 = self._get_test_data(device, dtype, N) + tensors2 = self._get_test_data(device, dtype, N) + + # Mimics cuda kernel dtype flow. With fp16/bf16 input, runs in fp32 and casts output back to fp16/bf16. + control_dtype = torch.float32 if (self.device_type == 'cuda' and + (dtype is torch.float16 or dtype is torch.bfloat16)) else dtype + + expected_max = [torch.max(tensors1[i].to(dtype=control_dtype), + tensors2[i].to(dtype=control_dtype)).to(dtype=dtype) for i in range(N)] + + expected_min = [torch.min(tensors1[i].to(dtype=control_dtype), + tensors2[i].to(dtype=control_dtype)).to(dtype=dtype) for i in range(N)] + + res_max = torch._foreach_maximum(tensors1, tensors2) + self.assertEqual(res_max, expected_max) + + res_min = torch._foreach_minimum(tensors1, tensors2) + self.assertEqual(res_min, expected_min) + + + @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=False))) + def test_max_min_float_inf_nan(self, device, dtype): + a = [ + torch.tensor([float('inf')], device=device, dtype=dtype), + torch.tensor([-float('inf')], device=device, dtype=dtype), + torch.tensor([float('nan')], device=device, dtype=dtype), + torch.tensor([float('nan')], device=device, dtype=dtype) + ] + + b = [ + torch.tensor([-float('inf')], device=device, dtype=dtype), + torch.tensor([float('inf')], device=device, dtype=dtype), + torch.tensor([float('inf')], device=device, dtype=dtype), + torch.tensor([float('nan')], device=device, dtype=dtype) + ] + + expected = [torch.max(a1, b1) for a1, b1 in zip(a, b)] + res = torch._foreach_maximum(a, b) + self.assertEqual(expected, res) + + expected = [torch.min(a1, b1) for a1, b1 in zip(a, b)] + res = torch._foreach_minimum(a, b) + self.assertEqual(expected, res) + + @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=False))) + def test_max_min_inf_nan(self, device, dtype): + a = [ + torch.tensor([inf], device=device, dtype=dtype), + torch.tensor([-inf], device=device, dtype=dtype), + torch.tensor([nan], device=device, dtype=dtype), + torch.tensor([nan], device=device, dtype=dtype) + ] + + b = [ + torch.tensor([-inf], device=device, dtype=dtype), + torch.tensor([inf], device=device, dtype=dtype), + torch.tensor([inf], device=device, dtype=dtype), + torch.tensor([nan], device=device, dtype=dtype) + ] + + expected_max = [torch.max(a1, b1) for a1, b1 in zip(a, b)] + res_max = torch._foreach_maximum(a, b) + self.assertEqual(expected_max, res_max) + + expected_min = [torch.min(a1, b1) for a1, b1 in zip(a, b)] + res_min = torch._foreach_minimum(a, b) + self.assertEqual(expected_min, res_min) + # # Ops with scalar # + @skipCUDAIfRocm @dtypes(*torch.testing.get_all_dtypes()) def test_int_scalar(self, device, dtype): - tensors = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)] - int_scalar = 1 - - # bool tensor + 1 will result in int64 tensor - if dtype == torch.bool: - expected = [torch.ones(10, 10, device=device, dtype=torch.int64) for _ in range(10)] - else: - expected = [torch.ones(10, 10, device=device, dtype=dtype) for _ in range(10)] - - res = torch._foreach_add(tensors, int_scalar) - self.assertEqual(res, expected) - - if dtype in [torch.bool]: - with self.assertRaisesRegex(RuntimeError, - "result type Long can't be cast to the desired output type Bool"): - torch._foreach_add_(tensors, int_scalar) - else: - torch._foreach_add_(tensors, int_scalar) - self.assertEqual(res, tensors) + for N in N_values: + for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops, + self.foreach_bin_ops_, + self.torch_bin_ops): + tensors = self._get_test_data(device, dtype, N) + scalar = 3 + expected = [torch_bin_op(t, scalar) for t in tensors] + + res = foreach_bin_op(tensors, scalar) + + if dtype == torch.bool: + self.assertEqual(res, expected) + + with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"): + foreach_bin_op_(tensors, scalar) + return + + + if foreach_bin_op_ == torch._foreach_div_ and dtype in torch.testing.integral_types() and self.device_type == "cpu": + with self.assertRaisesRegex(RuntimeError, + "can't be cast to the desired output type"): + foreach_bin_op_(tensors, scalar) + return + + # TODO[type promotion]: Fix once type promotion is enabled. + if dtype in torch.testing.integral_types() and self.device_type == 'cuda': + self.assertEqual(res, [e.to(dtype) for e in expected]) + + foreach_bin_op_(tensors, scalar) + self.assertEqual(tensors, [e.to(dtype) for e in expected]) + else: + self.assertEqual(res, expected) + foreach_bin_op_(tensors, scalar) + self.assertEqual(tensors, expected) + + # TODO[Fix scalar list]: + # We need to update codegen to correctly handle function overloads with float[] and int[]. + # As optimizers work with float tensors, the result will always be torch.float32 for now. + # Current schema is using 'float[]' as scalar list type. + @skipCUDAIfRocm + @dtypes(*torch.testing.get_all_dtypes()) + def test_int_scalarlist(self, device, dtype): + for N in N_values: + for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops, + self.foreach_bin_ops_, + self.torch_bin_ops): + tensors = self._get_test_data(device, dtype, N) + scalars = [1 for _ in range(N)] + expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)] + + # we dont support bool and complex types on CUDA for now + if (dtype in torch.testing.get_all_complex_dtypes() or dtype == torch.bool) and self.device_type == 'cuda': + with self.assertRaisesRegex(RuntimeError, "not implemented for"): + foreach_bin_op_(tensors, scalars) + + with self.assertRaisesRegex(RuntimeError, "not implemented for"): + foreach_bin_op(tensors, scalars) + return + + res = foreach_bin_op(tensors, scalars) + + if dtype == torch.bool: + self.assertEqual(res, [torch_bin_op(t.to(torch.float32), s) for t, s in zip(tensors, scalars)]) + + with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"): + foreach_bin_op_(tensors, scalars) + return + + if dtype in torch.testing.integral_types(): + if self.device_type == 'cpu': + self.assertEqual(res, [e.to(torch.float32) for e in expected]) + else: + # TODO[type promotion]: Fix once type promotion is enabled. + self.assertEqual(res, [e.to(dtype) for e in expected]) + else: + self.assertEqual(res, expected) + + if dtype in torch.testing.integral_types() and self.device_type == 'cpu': + with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"): + foreach_bin_op_(tensors, scalars) + return + else: + foreach_bin_op_(tensors, scalars) + self.assertEqual(res, tensors) + @skipCUDAIfRocm @dtypes(*torch.testing.get_all_dtypes()) def test_float_scalar(self, device, dtype): - tensors = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)] - float_scalar = 1. + for N in N_values: + for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops, + self.foreach_bin_ops_, + self.torch_bin_ops): + tensors = self._get_test_data(device, dtype, N) + scalar = 3.3 + + # Mimics cuda kernel dtype flow. With fp16/bf16 input, runs in fp32 and casts output back to fp16/bf16. + control_dtype = torch.float32 if (self.device_type == 'cuda' and + (dtype is torch.float16 or dtype is torch.bfloat16)) else dtype + expected = [torch_bin_op(t.to(dtype=control_dtype), + scalar) for t in tensors] + if (dtype is torch.float16 or dtype is torch.bfloat16): + expected = [e.to(dtype=dtype) for e in expected] + + if dtype == torch.bool: + if foreach_bin_op == torch._foreach_sub: + with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"): + foreach_bin_op_(tensors, scalar) + + with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"): + foreach_bin_op(tensors, scalar) + return + + res = foreach_bin_op(tensors, scalar) + if (dtype is torch.float16 or dtype is torch.bfloat16) and TEST_WITH_ROCM: + self.assertEqual(res, expected, atol=1.e-3, rtol=self.dtype_precisions[dtype][0]) + else: + self.assertEqual(res, expected) + + if dtype in torch.testing.integral_types(): + with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"): + foreach_bin_op_(tensors, scalar) + return + + foreach_bin_op_(tensors, scalar) + if (dtype is torch.float16 or dtype is torch.bfloat16) and TEST_WITH_ROCM: + self.assertEqual(tensors, expected, atol=1.e-3, rtol=self.dtype_precisions[dtype][0]) + else: + self.assertEqual(tensors, expected) - # float scalar + integral tensor will result in float tensor - if dtype in [torch.uint8, torch.int8, torch.int16, - torch.int32, torch.int64, torch.bool]: - expected = [torch.ones(10, 10, device=device, dtype=torch.float32) for _ in range(10)] - else: - expected = [torch.ones(10, 10, device=device, dtype=dtype) for _ in range(10)] - - res = torch._foreach_add(tensors, float_scalar) - self.assertEqual(res, expected) - - if dtype in [torch.uint8, torch.int8, torch.int16, - torch.int32, torch.int64, torch.bool]: - self.assertRaises(RuntimeError, lambda: torch._foreach_add_(tensors, float_scalar)) - else: - torch._foreach_add_(tensors, float_scalar) - self.assertEqual(res, tensors) + @skipCUDAIfRocm + @dtypes(*torch.testing.get_all_dtypes()) + def test_float_scalarlist(self, device, dtype): + for N in N_values: + for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops, + self.foreach_bin_ops_, + self.torch_bin_ops): + tensors = self._get_test_data(device, dtype, N) + scalars = [1.1 for _ in range(N)] + + # If incoming dtype is float16 or bfloat16, runs in float32 and casts output back to dtype. + control_dtype = torch.float32 if (self.device_type == 'cuda' and + (dtype is torch.float16 or dtype is torch.bfloat16)) else dtype + expected = [torch_bin_op(t.to(dtype=control_dtype), + s) for t, s in zip(tensors, scalars)] + if (dtype is torch.float16 or dtype is torch.bfloat16): + expected = [e.to(dtype=dtype) for e in expected] + + # we dont support bool and complex types on CUDA for now + if (dtype in torch.testing.get_all_complex_dtypes() or dtype == torch.bool) and self.device_type == 'cuda': + with self.assertRaisesRegex(RuntimeError, "not implemented for"): + foreach_bin_op_(tensors, scalars) + + with self.assertRaisesRegex(RuntimeError, "not implemented for"): + foreach_bin_op(tensors, scalars) + return + + res = foreach_bin_op(tensors, scalars) + + if dtype == torch.bool: + # see TODO[Fix scalar list] + self.assertEqual(res, [torch_bin_op(t.to(torch.float32), s) for t, s in zip(tensors, scalars)]) + + with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"): + foreach_bin_op_(tensors, scalars) + return + + if dtype in torch.testing.integral_types() and self.device_type == 'cuda': + # see TODO[Fix scalar list] + self.assertEqual(res, [e.to(dtype) for e in expected]) + + foreach_bin_op_(tensors, scalars) + self.assertEqual(tensors, res) + return + else: + if (dtype is torch.float16 or dtype is torch.bfloat16) and TEST_WITH_ROCM: + self.assertEqual(res, expected, atol=1.e-3, rtol=self.dtype_precisions[dtype][0]) + else: + self.assertEqual(res, expected) + + if dtype in torch.testing.integral_types() and self.device_type == "cpu": + with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"): + foreach_bin_op_(tensors, scalars) + return + + foreach_bin_op_(tensors, scalars) + if (dtype is torch.float16 or dtype is torch.bfloat16) and TEST_WITH_ROCM: + self.assertEqual(tensors, expected, atol=1.e-3, rtol=self.dtype_precisions[dtype][0]) + else: + self.assertEqual(tensors, expected) + @skipCUDAIfRocm @dtypes(*torch.testing.get_all_dtypes()) def test_complex_scalar(self, device, dtype): - tensors = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)] - complex_scalar = 3 + 5j - - # bool tensor + 1 will result in int64 tensor - expected = [torch.add(complex_scalar, torch.zeros(10, 10, device=device, dtype=dtype)) for _ in range(10)] + for N in N_values: + for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops, + self.foreach_bin_ops_, + self.torch_bin_ops): + tensors = self._get_test_data(device, dtype, N) + scalar = 3 + 5j + expected = [torch_bin_op(t, scalar) for t in tensors] + + if dtype == torch.bool: + if foreach_bin_op == torch._foreach_sub: + with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"): + foreach_bin_op_(tensors, scalar) + + with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"): + foreach_bin_op(tensors, scalar) + return + + if dtype in torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=True) and \ + self.device_type == 'cuda': + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): + foreach_bin_op_(tensors, scalar) + + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): + foreach_bin_op(tensors, scalar) + return + + res = foreach_bin_op(tensors, scalar) + self.assertEqual(res, expected) + + if dtype not in [torch.complex64, torch.complex128]: + with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"): + foreach_bin_op_(tensors, scalar) + else: + foreach_bin_op_(tensors, scalar) + self.assertEqual(res, tensors) - if dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16] and device == 'cuda:0': - # value cannot be converted to dtype without overflow: - self.assertRaises(RuntimeError, lambda: torch._foreach_add_(tensors, complex_scalar)) - self.assertRaises(RuntimeError, lambda: torch._foreach_add(tensors, complex_scalar)) - return - - res = torch._foreach_add(tensors, complex_scalar) - self.assertEqual(res, expected) - - if dtype not in [torch.complex64, torch.complex128]: - self.assertRaises(RuntimeError, lambda: torch._foreach_add_(tensors, complex_scalar)) - else: - torch._foreach_add_(tensors, complex_scalar) - self.assertEqual(res, tensors) + @dtypes(*torch.testing.get_all_dtypes()) + def test_complex_scalarlist(self, device, dtype): + for N in N_values: + for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops, + self.foreach_bin_ops_, + self.torch_bin_ops): + tensors = self._get_test_data(device, dtype, N) + scalars = [3 + 5j for _ in range(N)] + expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)] + + if dtype == torch.bool: + if foreach_bin_op == torch._foreach_sub: + with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"): + foreach_bin_op_(tensors, scalar) + + with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"): + foreach_bin_op(tensors, scalar) + return + + with self.assertRaisesRegex(TypeError, "argument 'scalars' must be tuple of floats"): + res = foreach_bin_op(tensors, scalars) + + with self.assertRaisesRegex(TypeError, "argument 'scalars' must be tuple of floats"): + foreach_bin_op_(tensors, scalars) + @skipCUDAIfRocm @dtypes(*torch.testing.get_all_dtypes()) def test_bool_scalar(self, device, dtype): - tensors = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)] - bool_scalar = True - - expected = [torch.ones(10, 10, device=device, dtype=dtype) for _ in range(10)] + for N in N_values: + for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops, + self.foreach_bin_ops_, + self.torch_bin_ops): + tensors = self._get_test_data(device, dtype, N) + scalar = True + + if dtype == torch.bool: + expected = [torch_bin_op(t, scalar) for t in tensors] + res = foreach_bin_op(tensors, scalar) + + foreach_bin_op_(tensors, scalar) + self.assertEqual(tensors, res) + return + + if foreach_bin_op == torch._foreach_sub and self.device_type == "cpu": + with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator"): + res = foreach_bin_op(tensors, scalar) + + with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator"): + foreach_bin_op_(tensors, scalar) + elif foreach_bin_op == torch._foreach_sub and self.device_type == 'cuda': + res = foreach_bin_op(tensors, scalar) + self.assertEqual(res, foreach_bin_op(tensors, 1)) + + foreach_bin_op_(tensors, scalar) + self.assertEqual(tensors, res) + else: + expected = [torch_bin_op(t, scalar) for t in tensors] + res = foreach_bin_op(tensors, scalar) + + # TODO[type promotion]: Fix once type promotion is enabled. + if dtype in torch.testing.integral_types() and self.device_type == 'cuda': + self.assertEqual(res, [e.to(dtype) for e in expected]) + else: + self.assertEqual(res, expected) + + if dtype in torch.testing.integral_types(): + if foreach_bin_op == torch._foreach_div and self.device_type == "cpu": + with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired "): + foreach_bin_op_(tensors, scalar) + else: + foreach_bin_op_(tensors, scalar) + self.assertEqual(tensors, res) + else: + foreach_bin_op_(tensors, scalar) + self.assertEqual(tensors, expected) - res = torch._foreach_add(tensors, bool_scalar) - self.assertEqual(res, expected) - - torch._foreach_add_(tensors, bool_scalar) - self.assertEqual(res, tensors) + @skipCUDAIfRocm + @dtypes(*torch.testing.get_all_dtypes()) + def test_bool_scalarlist(self, device, dtype): + for N in N_values: + for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops, + self.foreach_bin_ops_, + self.torch_bin_ops): + tensors = self._get_test_data(device, dtype, N) + scalars = [True for _ in range(N)] + + if dtype == torch.bool: + if self.device_type == 'cuda': + with self.assertRaisesRegex(RuntimeError, "not implemented for"): + foreach_bin_op(tensors, scalars) + + with self.assertRaisesRegex(RuntimeError, "not implemented for"): + foreach_bin_op_(tensors, scalars) + return + else: + if foreach_bin_op == torch._foreach_sub: + with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"): + foreach_bin_op_(tensors, scalars) + + with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"): + foreach_bin_op(tensors, scalars) + else: + with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired"): + foreach_bin_op_(tensors, scalars) + + res = foreach_bin_op(tensors, scalars) + for r in res: + self.assertTrue(r.dtype == torch.float32) + else: + # we dont support bool and complex types on CUDA for now + if (dtype in torch.testing.get_all_complex_dtypes()) and self.device_type == 'cuda': + with self.assertRaisesRegex(RuntimeError, "not implemented for"): + foreach_bin_op_(tensors, scalars) + + with self.assertRaisesRegex(RuntimeError, "not implemented for"): + foreach_bin_op(tensors, scalars) + return + + if foreach_bin_op == torch._foreach_sub: + if self.device_type == "cpu": + # see TODO[Fix scalar list] + res = foreach_bin_op(tensors, scalars) + if dtype in torch.testing.integral_types(): + self.assertEqual(res, [r.to(torch.float32) for r in [torch_bin_op(t, 1) for t in tensors]]) + + with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the "): + foreach_bin_op_(tensors, scalars) + else: + self.assertEqual(res, [torch_bin_op(t, 1) for t in tensors]) + foreach_bin_op_(tensors, scalars) + self.assertEqual(res, tensors) + else: + # see TODO[Fix scalar list] + res = foreach_bin_op(tensors, scalars) + if dtype in torch.testing.integral_types(): + self.assertEqual(res, [r.to(dtype) for r in [torch_bin_op(t, 1) for t in tensors]]) + else: + self.assertEqual(res, [torch_bin_op(t, 1) for t in tensors]) + + foreach_bin_op_(tensors, scalars) + self.assertEqual(res, tensors) + else: + if self.device_type == "cpu": + expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)] + res = foreach_bin_op(tensors, scalars) + + # see TODO[Fix scalar list] + if dtype in torch.testing.integral_types(): + self.assertEqual(res, [e.to(torch.float32) for e in expected]) + else: + self.assertEqual(res, expected) + + if dtype in torch.testing.integral_types(): + with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired "): + foreach_bin_op_(tensors, scalars) + else: + foreach_bin_op_(tensors, scalars) + self.assertEqual(tensors, expected) + else: + expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)] + res = foreach_bin_op(tensors, scalars) + + if dtype in torch.testing.integral_types(): + self.assertEqual(res, [e.to(dtype) for e in expected]) + else: + self.assertEqual(res, expected) + + foreach_bin_op_(tensors, scalars) + self.assertEqual(res, tensors) @dtypes(*torch.testing.get_all_dtypes()) def test_add_with_different_size_tensors(self, device, dtype): @@ -236,56 +790,57 @@ def test_bin_op_scalar_with_different_tensor_dtypes(self, device): # # Ops with list # - def test_add_list_error_cases(self, device): - tensors1 = [] - tensors2 = [] - - # Empty lists - with self.assertRaises(RuntimeError): - torch._foreach_add(tensors1, tensors2) - with self.assertRaises(RuntimeError): - torch._foreach_add_(tensors1, tensors2) - - # One empty list - tensors1.append(torch.tensor([1], device=device)) - with self.assertRaisesRegex(RuntimeError, "Tensor list must have at least one tensor."): - torch._foreach_add(tensors1, tensors2) - with self.assertRaisesRegex(RuntimeError, "Tensor list must have at least one tensor."): - torch._foreach_add_(tensors1, tensors2) - - # Lists have different amount of tensors - tensors2.append(torch.tensor([1], device=device)) - tensors2.append(torch.tensor([1], device=device)) - with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"): - torch._foreach_add(tensors1, tensors2) - with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"): - torch._foreach_add_(tensors1, tensors2) - - # Different dtypes - tensors1 = [torch.zeros(10, 10, device=device, dtype=torch.float) for _ in range(10)] - tensors2 = [torch.ones(10, 10, device=device, dtype=torch.int) for _ in range(10)] - - with self.assertRaisesRegex(RuntimeError, "All tensors in the tensor list must have the same dtype."): - torch._foreach_add(tensors1, tensors2) - with self.assertRaisesRegex(RuntimeError, "All tensors in the tensor list must have the same dtype."): - torch._foreach_add_(tensors1, tensors2) - - # different devices - if torch.cuda.is_available() and torch.cuda.device_count() > 1: - tensor1 = torch.zeros(10, 10, device="cuda:0") - tensor2 = torch.ones(10, 10, device="cuda:1") - with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): - torch._foreach_add([tensor1], [tensor2]) - with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): - torch._foreach_add_([tensor1], [tensor2]) - - # Coresponding tensors with different sizes - tensors1 = [torch.zeros(10, 10, device=device) for _ in range(10)] - tensors2 = [torch.ones(11, 11, device=device) for _ in range(10)] - with self.assertRaisesRegex(RuntimeError, "Corresponding tensors in lists must have the same size"): - torch._foreach_add(tensors1, tensors2) - with self.assertRaisesRegex(RuntimeError, r", got \[10, 10\] and \[11, 11\]"): - torch._foreach_add_(tensors1, tensors2) + def test_bin_op_list_error_cases(self, device): + for bin_op, bin_op_ in zip(self.foreach_bin_ops, self.foreach_bin_ops_): + tensors1 = [] + tensors2 = [] + + # Empty lists + with self.assertRaisesRegex(RuntimeError, "There were no tensor arguments to this function"): + bin_op(tensors1, tensors2) + with self.assertRaisesRegex(RuntimeError, "There were no tensor arguments to this function"): + bin_op_(tensors1, tensors2) + + # One empty list + tensors1.append(torch.tensor([1], device=device)) + with self.assertRaisesRegex(RuntimeError, "Tensor list must have same number of elements as scalar list."): + bin_op(tensors1, tensors2) + with self.assertRaisesRegex(RuntimeError, "Tensor list must have same number of elements as scalar list."): + bin_op_(tensors1, tensors2) + + # Lists have different amount of tensors + tensors2.append(torch.tensor([1], device=device)) + tensors2.append(torch.tensor([1], device=device)) + with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"): + bin_op(tensors1, tensors2) + with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"): + bin_op_(tensors1, tensors2) + + # Different dtypes + tensors1 = [torch.zeros(10, 10, device=device, dtype=torch.float) for _ in range(10)] + tensors2 = [torch.ones(10, 10, device=device, dtype=torch.int) for _ in range(10)] + + with self.assertRaisesRegex(RuntimeError, "All tensors in the tensor list must have the same dtype."): + bin_op(tensors1, tensors2) + with self.assertRaisesRegex(RuntimeError, "All tensors in the tensor list must have the same dtype."): + bin_op_(tensors1, tensors2) + + # different devices + if torch.cuda.is_available() and torch.cuda.device_count() > 1: + tensor1 = torch.zeros(10, 10, device="cuda:0") + tensor2 = torch.ones(10, 10, device="cuda:1") + with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): + bin_op([tensor1], [tensor2]) + with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): + bin_op_([tensor1], [tensor2]) + + # Corresponding tensors with different sizes + tensors1 = [torch.zeros(10, 10, device=device) for _ in range(10)] + tensors2 = [torch.ones(11, 11, device=device) for _ in range(10)] + with self.assertRaisesRegex(RuntimeError, "Corresponding tensors in lists must have the same size"): + bin_op(tensors1, tensors2) + with self.assertRaisesRegex(RuntimeError, r", got \[10, 10\] and \[11, 11\]"): + bin_op_(tensors1, tensors2) @dtypes(*torch.testing.get_all_dtypes()) def test_add_list(self, device, dtype): @@ -318,34 +873,19 @@ def test_div_list(self, device, dtype): self.skipTest("Skipped! See https://github.com/pytorch/pytorch/issues/44489") return - self._test_bin_op_list(device, dtype, torch._foreach_div, torch._foreach_div_, torch.div) - - def test_bin_op_list_error_cases(self, device): - tensors1 = [] - tensors2 = [] - - for bin_op in self.bin_ops: - # Empty lists - with self.assertRaises(RuntimeError): - bin_op(tensors1, tensors2) - - # One empty list - tensors1.append(torch.tensor([1], device=device)) - with self.assertRaises(RuntimeError): - bin_op(tensors1, tensors2) + for N in N_values: + tensors1 = self._get_test_data(device, dtype, N) - # Lists have different amount of tensors - tensors2.append(torch.tensor([1], device=device)) - tensors2.append(torch.tensor([1], device=device)) - with self.assertRaises(RuntimeError): - bin_op(tensors1, tensors2) - - # Different dtypes - tensors1 = [torch.zeros(2, 2, device=device, dtype=torch.float) for _ in range(2)] - tensors2 = [torch.ones(2, 2, device=device, dtype=torch.int) for _ in range(2)] + if dtype in [torch.bfloat16, torch.bool, torch.float16]: + tensors2 = [torch.zeros(N, N, device=device, dtype=dtype).add(2) for _ in range(N)] + else: + tensors2 = self._get_test_data(device, dtype, N) - with self.assertRaises(RuntimeError): - bin_op(tensors1, tensors2) + expected = [torch.div(tensors1[i], tensors2[i]) for i in range(N)] + res = torch._foreach_div(tensors1, tensors2) + torch._foreach_div_(tensors1, tensors2) + self.assertEqual(res, tensors1) + self.assertEqual(tensors1, res) @dtypes(*torch.testing.get_all_dtypes()) def test_add_list_different_sizes(self, device, dtype): diff --git a/test/test_function_schema.py b/test/test_function_schema.py index f2ad2290d326b..5a15273734789 100644 --- a/test/test_function_schema.py +++ b/test/test_function_schema.py @@ -14,90 +14,77 @@ def test_serialize_and_deserialize(self): self.assertEqual(parsed_schema, schema) self.assertTrue(parsed_schema.is_backward_compatible_with(schema)) - def test_backward_compatible_args(self): - old_schema = parse_schema('any(Tensor self, int dim) -> Tensor') - new_schema = parse_schema('any(Tensor self, int? dim) -> Tensor') + def test_backward_compatible_structure(self): + old_schema = parse_schema('any.over(Tensor self, *, Tensor b) -> Tensor') + # BC: A new schema without changes. + new_schema = parse_schema('any.over(Tensor self, *, Tensor b) -> Tensor') self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) - self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - new_schema = parse_schema('any(Tensor self, int dim=5) -> Tensor') - self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) - self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor') - self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) - self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - - def test_backward_compatible_kwargs(self): - old_schema = parse_schema('any(Tensor self, *, Tensor out) -> Tensor') - new_schema = parse_schema('any(Tensor self, *, bool extra1=True, Tensor out, bool extra2=False) -> Tensor') - self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) - self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - new_schema = parse_schema('any(Tensor self, Tensor out) -> Tensor') - self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) - self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - - def test_backward_compatible_ret(self): - old_schema = parse_schema('any(Tensor self) -> Tensor?') - new_schema = parse_schema('any(Tensor self) -> Tensor') - self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) - self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - - def test_backward_incompatible_name(self): - old_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor') - new_schema = parse_schema('any_(Tensor self, int dim, bool keepdim=False) -> Tensor') + # No-BC: A new schema with different name. + new_schema = parse_schema('any_.over(Tensor self, *, Tensor b) -> Tensor') self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) - self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - - def test_backward_incompatible_vararg(self): - old_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor') - new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False, ...) -> Tensor') + # No-BC: A new schema with different overload name. + new_schema = parse_schema('any.other(Tensor self, *, Tensor b) -> Tensor') self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) - self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - - def test_backward_incompatible_returns(self): - old_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor') - new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> (Tensor, ...)') + # No-BC: A new schema that adds vararg. + new_schema = parse_schema('any.over(Tensor self, *, Tensor b, ...) -> Tensor') self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) - self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> int') + # No-BC: A new schema with different number of outputs. + new_schema = parse_schema('any.over(Tensor self, *, Tensor b) -> (Tensor, Tensor)') self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) - self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor?') + + def test_backward_compatible_outputs(self): + old_schema = parse_schema('any.over(Tensor self, *, Tensor b) -> Tensor') + # No-BC: A new schema with output becoming of optional type. + new_schema = parse_schema('any.over(Tensor self, *, Tensor b) -> Tensor?') self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + # BC: (the opposite case) An schema where the output is not of optional type anymore. self.assertTrue(old_schema.is_backward_compatible_with(new_schema)) - new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor)') + # No-BC: A new schema with a different output type. + new_schema = parse_schema('any.over(Tensor self, *, Tensor b) -> int') self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) - self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor out') + # No-BC: A new schema with a different output type. + new_schema = parse_schema('any.over(Tensor self, *, Tensor b) -> Tensor out') self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) - self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - def test_backward_incompatible_args(self): - old_schema = parse_schema('any(Tensor self, int[] dims, bool keepdim=False) -> Tensor') - new_schema = parse_schema('any(Tensor s, int[] dims, bool keepdim=False) -> Tensor') + def test_backward_compatible_arguments(self): + old_schema = parse_schema('any(Tensor self, *, Tensor b, int c) -> Tensor') + # No-BC: A new schema with less arguments. + new_schema = parse_schema('any(Tensor self, *, Tensor b) -> Tensor') self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) - self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - new_schema = parse_schema('any(Tensor self, int[3] dims, bool keepdim=False) -> Tensor') + # No-BC: A new schema with more arguments, appended, but no default value. + new_schema = parse_schema('any(Tensor self, *, Tensor b, int c, int d) -> Tensor') self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) - self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - new_schema = parse_schema('any(Tensor self, int[](a) dims, bool keepdim=False) -> Tensor') + # BC: A new schema with more arguments, appended, that have a default value. + new_schema = parse_schema('any(Tensor self, *, Tensor b, int c, int d=1) -> Tensor') + self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) + # No-BC: A new schema with more arguments, not-appended, that have a default value. + new_schema = parse_schema('any(Tensor self, int d=1, *, Tensor b, int c) -> Tensor') self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + # BC: A new schema where old kwargs becomes positional. + new_schema = parse_schema('any(Tensor self, Tensor b, *, int c) -> Tensor') + self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) + # BC: (the opposite case) A new schema where an old positional argument becomes kwarg. self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - new_schema = parse_schema('any(Tensor self, int dims, bool keepdim=False) -> Tensor') - self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + # BC: A new schema where all old kwargs become positional. + new_schema = parse_schema('any(Tensor self, Tensor b, int c) -> Tensor') + self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) + # BC: (the opposite case) A new schema where all old positional arguments become kwarg. self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - new_schema = parse_schema('any(Tensor self, int[] dim, bool keepdim=False, bool? extra=None) -> Tensor') + # No-BC: A new schema where old kwargs appear in different order. + new_schema = parse_schema('any(Tensor self, *, int c, Tensor b) -> Tensor') self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) - self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - - def test_backward_incompatible_kwargs(self): - old_schema = parse_schema('any(Tensor self, int[] dims, *, bool keepdim=False) -> Tensor') - new_schema = parse_schema('any(Tensor self, int[] dims, *, bool keepdim) -> Tensor') + # BC: A new schema where argument becomes of type optional. + new_schema = parse_schema('any(Tensor self, *, Tensor b, int? c) -> Tensor') + self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) + # BC: A new schema where argument gains a default value. + new_schema = parse_schema('any(Tensor self, *, Tensor b, int c=1) -> Tensor') + self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) + # No-BC: A new schema where argument is "renamed". + new_schema = parse_schema('any(Tensor self, *, Tensor b, int renamed) -> Tensor') self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) - self.assertTrue(old_schema.is_backward_compatible_with(new_schema)) - new_schema = parse_schema('any(Tensor self, int[] dims, *, bool keepdim=False, bool extra) -> Tensor') + # No-BC: A new schema where argument type changes to an incompatible type. + new_schema = parse_schema('any(Tensor self, *, Tensor b, int[] c) -> Tensor') self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) - self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) - if __name__ == '__main__': run_tests() diff --git a/test/test_functional_autograd_benchmark.py b/test/test_functional_autograd_benchmark.py index 8c8e06754b6f8..c75edc5b6c7fa 100644 --- a/test/test_functional_autograd_benchmark.py +++ b/test/test_functional_autograd_benchmark.py @@ -5,6 +5,8 @@ import os import unittest +PYTORCH_COLLECT_COVERAGE = bool(os.environ.get("PYTORCH_COLLECT_COVERAGE")) + # This is a very simple smoke test for the functional autograd benchmarking script. class TestFunctionalAutogradBenchmark(TestCase): def _test_runner(self, model, disable_gpu=False): @@ -34,6 +36,7 @@ def _test_runner(self, model, disable_gpu=False): @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows does not have all the features we need.") + @unittest.skipIf(PYTORCH_COLLECT_COVERAGE, "Can deadlocks with gcov, see https://github.com/pytorch/pytorch/issues/49656") def test_fast_tasks(self): fast_tasks = ['resnet18', 'ppl_simple_reg', 'ppl_robust_reg', 'wav2letter', 'transformer', 'multiheadattn'] diff --git a/test/test_futures.py b/test/test_futures.py index 26e916e57ca4a..4a783f2f6664d 100644 --- a/test/test_futures.py +++ b/test/test_futures.py @@ -92,7 +92,7 @@ def test_chained_then(self): for i in range(len(futs)): self.assertEqual(futs[i].wait(), torch.ones(2, 2) + i + 1) - def _test_error(self, cb, errMsg): + def _test_then_error(self, cb, errMsg): fut = Future[int]() then_fut = fut.then(cb) @@ -106,21 +106,127 @@ def test_then_wrong_arg(self): def wrong_arg(tensor): return tensor + 1 - self._test_error(wrong_arg, "unsupported operand type.*Future.*int") + self._test_then_error(wrong_arg, "unsupported operand type.*Future.*int") def test_then_no_arg(self): def no_arg(): return True - self._test_error(no_arg, "takes 0 positional arguments but 1 was given") + self._test_then_error(no_arg, "takes 0 positional arguments but 1 was given") def test_then_raise(self): def raise_value_error(fut): raise ValueError("Expected error") - self._test_error(raise_value_error, "Expected error") + self._test_then_error(raise_value_error, "Expected error") + + def test_add_done_callback_simple(self): + callback_result = False + + def callback(fut): + nonlocal callback_result + fut.wait() + callback_result = True + + fut = Future[torch.Tensor]() + fut.add_done_callback(callback) + + self.assertFalse(callback_result) + fut.set_result(torch.ones(2, 2)) + self.assertEqual(fut.wait(), torch.ones(2, 2)) + self.assertTrue(callback_result) + + def test_add_done_callback_maintains_callback_order(self): + callback_result = 0 + + def callback_set1(fut): + nonlocal callback_result + fut.wait() + callback_result = 1 + + def callback_set2(fut): + nonlocal callback_result + fut.wait() + callback_result = 2 + + fut = Future[torch.Tensor]() + fut.add_done_callback(callback_set1) + fut.add_done_callback(callback_set2) + + fut.set_result(torch.ones(2, 2)) + self.assertEqual(fut.wait(), torch.ones(2, 2)) + # set2 called last, callback_result = 2 + self.assertEqual(callback_result, 2) + + def _test_add_done_callback_error_ignored(self, cb): + fut = Future[int]() + fut.add_done_callback(cb) + + fut.set_result(5) + # error msg logged to stdout + self.assertEqual(5, fut.wait()) + + def test_add_done_callback_error_is_ignored(self): + + def raise_value_error(fut): + raise ValueError("Expected error") + + self._test_add_done_callback_error_ignored(raise_value_error) + + def test_add_done_callback_no_arg_error_is_ignored(self): + + def no_arg(): + return True + + # Adding another level of function indirection here on purpose. + # Otherwise mypy will pick up on no_arg having an incompatible type and fail CI + self._test_add_done_callback_error_ignored(no_arg) + + def test_interleaving_then_and_add_done_callback_maintains_callback_order(self): + callback_result = 0 + + def callback_set1(fut): + nonlocal callback_result + fut.wait() + callback_result = 1 + + def callback_set2(fut): + nonlocal callback_result + fut.wait() + callback_result = 2 + + def callback_then(fut): + nonlocal callback_result + return fut.wait() + callback_result + + fut = Future[torch.Tensor]() + fut.add_done_callback(callback_set1) + then_fut = fut.then(callback_then) + fut.add_done_callback(callback_set2) + + fut.set_result(torch.ones(2, 2)) + self.assertEqual(fut.wait(), torch.ones(2, 2)) + # then_fut's callback is called with callback_result = 1 + self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1) + # set2 called last, callback_result = 2 + self.assertEqual(callback_result, 2) + + def test_interleaving_then_and_add_done_callback_propagates_error(self): + def raise_value_error(fut): + raise ValueError("Expected error") + + fut = Future[torch.Tensor]() + then_fut = fut.then(raise_value_error) + fut.add_done_callback(raise_value_error) + fut.set_result(torch.ones(2, 2)) + + # error from add_done_callback's callback is swallowed + # error from then's callback is not + self.assertEqual(fut.wait(), torch.ones(2, 2)) + with self.assertRaisesRegex(RuntimeError, "Expected error"): + then_fut.wait() def test_collect_all(self): fut1 = Future[int]() diff --git a/test/test_fx.py b/test/test_fx.py index a48274e168094..ec8321b919d1a 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -4,17 +4,26 @@ import numbers import pickle import copy +import sys +import functools +import contextlib from pathlib import Path -from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Tracer, Graph +from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Tracer, Graph, wrap +from torch.fx.experimental import shape_prop +from torch.fx.immutable_collections import immutable_dict, immutable_list +from copy import deepcopy from torch.fx.proxy import TraceError from fx.quantization import Quantizer +from fx.test_subgraph_rewriter import TestSubgraphRewriter # noqa: F401 -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, NamedTuple, List, Optional, Tuple, Union from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, IS_WINDOWS, IS_SANDCASTLE, IS_MACOS from torch.testing._internal.jit_utils import JitTestCase +from fx.named_tup import MyNamedTup + try: from torchvision.models import resnet18 HAS_TORCHVISION = True @@ -26,6 +35,29 @@ class SimpleTest(torch.nn.Module): def forward(self, x): return torch.relu(x + 3.0) +def a_non_torch_leaf(a, b): + return a + b + +# Test wrap() passing both a function name as well as a function +# directly +def a_lifted_leaf(a, b): + return a[0] + a[1] + b + +wrap('a_lifted_leaf') +# Test wrapping twice doesn't break anything +wrap('a_lifted_leaf') + +def a_lifted_leaf2(a, b): + return a[0] + a[1] + b + +wrap(a_lifted_leaf2) + +wrap('len') + +class Pair(NamedTuple): + x : torch.Tensor + y : torch.Tensor + class TestFX(JitTestCase): def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None): """Check that an nn.Module's results match the GraphModule version @@ -34,6 +66,7 @@ def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None): kwargs = kwargs if kwargs else {} ref_outs = m(*args, **kwargs) gm = symbolic_trace(m) + gm.graph.lint(gm) test_outs = gm(*args, **kwargs) self.assertEqual(ref_outs, test_outs) @@ -79,6 +112,17 @@ def forward(self, A, b=4, *args, c=5, **kwargs): t = T() symbolic_trace(t) + def test_custom_import(self): + graph = torch.fx.Graph() + a = graph.placeholder('x') + b = graph.placeholder('y') + c = graph.call_function(a_non_torch_leaf, (a, b)) + d = graph.call_function(torch.sin, (c,)) + graph.output(d) + gm = GraphModule(torch.nn.Module(), graph) + x, y = torch.rand(1), torch.rand(1) + self.assertEqual(torch.sin(x + y), gm(x, y)) + def test_args_kwargs(self): class T(torch.nn.Module): def forward(self, *args, **kwargs): @@ -88,6 +132,16 @@ def forward(self, *args, **kwargs): t = T() self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)}) + def test_args_kwargs_no_self(self): + class T(torch.nn.Module): + def forward(*args, **kwargs): # noqa: B902 + self = args[0] + return torch.relu(args[1]) + + t = T() + with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'): + self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)}) + def test_fx_shifts(self): class MyModule(torch.nn.Module): def forward(self, x): @@ -112,7 +166,8 @@ def test_disallow_override(self): # Custom delegate to disallow in-place tensor operations class NoMutableCallTracer(Tracer): def create_node(self, kind : str, target : Union[str, Callable], - args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None) -> Node: + args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None, + type_expr : Optional[Any] = None) -> Node: name = target if isinstance(target, str) else torch.typename(target) if name[-1] == '_': raise RuntimeError('In-place operations are not supported') @@ -165,8 +220,29 @@ def forward(self, x): mrm = MyReluMod() sym = NoLeafModulesTracer().trace(mrm) - for node in sym.graph.nodes: + for node in sym.nodes: self.assertNotEqual(node.op, 'call_module') + sym.lint(sym) + + def test_wrap(self): + self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5)) + + def to_trace(y): + return a_lifted_leaf((4, y), 3) + a_lifted_leaf((3, 4), 5) + a_lifted_leaf((y, y), y) + + m = symbolic_trace(to_trace) + self.assertIn('a_lifted_leaf', m.code) + self.assertEqual(27, m(2)) + + def test_wrap_fn_directly(self): + self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5)) + + def to_trace(y): + return a_lifted_leaf2((4, y), 3) + a_lifted_leaf2((3, 4), 5) + a_lifted_leaf2((y, y), y) + + m = symbolic_trace(to_trace) + self.assertIn('a_lifted_leaf2', m.code) + self.assertEqual(27, m(2)) def test_graph_edit_with_proxy(self): class M(torch.nn.Module): @@ -175,13 +251,48 @@ def forward(self, a, b): m = M() g = symbolic_trace(m).graph new_g = torch.fx.Graph() - new_g.graph_copy(g) - t = Proxy(new_g.nodes[-1]) + val_map : Dict[Node, Node] = {} + output_val = new_g.graph_copy(g, val_map) + t = Proxy(output_val) # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules. new_g.output((t + t).node) gm = GraphModule(m, new_g) + gm.graph.lint(gm) self.assertEqual(gm(3, 4), 14) + def test_graph_unique_names(self): + class M(torch.nn.Module): + def forward(self, a, b): + return a + b + m = M() + g = symbolic_trace(m).graph + new_g = torch.fx.Graph() + val_map : Dict[Node, Node] = {} + output_val = new_g.graph_copy(g, val_map) + t = Proxy(output_val) + # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules. + new_g.output((t + t).node) + gm = GraphModule(m, new_g) + seen_names : Set[str] = set() + for node in gm.graph.nodes: + assert node.name not in seen_names + seen_names.add(node.name) + + def test_graph_unique_names_manual(self): + graph : torch.fx.Graph = torch.fx.Graph() + a : torch.fx.Node = graph.create_node('placeholder', 'x') + b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1') + c : torch.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1') + d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c)) + graph.output(d) + graph2 = torch.fx.Graph() + val_map : Dict[Node, Node] = {} + graph2.graph_copy(graph, val_map) + seen_names : Set[str] = set() + for node in graph2.nodes: + assert node.name not in seen_names + seen_names.add(node.name) + @skipIfNoTorchVision def test_resnet(self): resnet = resnet18() @@ -204,6 +315,7 @@ def test_resnet(self): quantizer.observe((torch.rand(1, 3, 224, 224),)) qgraph = quantizer.quantize() + qgraph.graph.lint(qgraph) qgraph_script = torch.jit.script(qgraph) d = qgraph(ip) @@ -273,6 +385,7 @@ def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Mod operator.mul : "mul" } + output_node : Optional[Node] = None # For each instruction, create a triple # (instruction_name : str, inputs : List[str], output : str) # to feed into the C++ interpreter @@ -299,9 +412,12 @@ def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Mod else: arg_names.append(arg.name) instructions.append((target_to_name[target], arg_names, out_name)) - + elif n.op == 'output': + if output_node is not None: + raise RuntimeError('Multiple output nodes!') + output_node = n else: - raise RuntimeError('Unsupported opcode' + n.op) + raise RuntimeError('Unsupported opcode ' + n.op) interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter() # Load constants @@ -312,7 +428,8 @@ def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Mod # Load instructions interpreter.set_instructions(instructions) # Specify name for single output - interpreter.set_output_name(mod.graph.result.name) + assert isinstance(output_node.args[0], torch.fx.Node) + interpreter.set_output_name(output_node.args[0].name) # ===== Stage 3: Create a wrapper GraphModule around the interpreter ===== class WrapperModule(torch.nn.Module): @@ -347,6 +464,8 @@ def __init__(self, interpreter): # Register output graph.output(output_node) + graph.lint(wrapper) + # Return final GraphModule!!! return GraphModule(wrapper, graph) @@ -378,13 +497,15 @@ def forward(self, a): m = M() m_g = symbolic_trace(m) + m_g.graph.lint(m_g) for node in m_g.graph.nodes: self.assertTrue(node.name != "getattr") def test_node_tagging(self): class TaggingTracer(Tracer): def create_node(self, kind : str, target : Union[str, Callable], - args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None) -> Node: + args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None, + type_expr : Optional[Any] = None) -> Node: n = super().create_node(kind, target, args, kwargs, name) n.tag = 'foo' return n @@ -394,7 +515,8 @@ def forward(self, a, b): return a + b m = M() - g = TaggingTracer().trace(m).graph + g = TaggingTracer().trace(m) + g.lint(m) for n in g.nodes: self.assertTrue(hasattr(n, 'tag')) self.assertEqual(n.tag, 'foo') @@ -422,6 +544,7 @@ def forward(self, x): wfq = WrapperForQualname() traced2 = symbolic_trace(wfq) + traced2.graph.lint(traced2) traced2(torch.rand(4, 4)) def test_symbolic_trace_sequential(self): @@ -435,6 +558,7 @@ def forward(self, x): Simple() ) traced = symbolic_trace(seq) + traced.graph.lint(traced) x = torch.rand(3, 4) self.assertEqual(traced(x), seq(x)) @@ -445,6 +569,7 @@ def forward(self, x): ct = ConstTensor() traced = symbolic_trace(ct) + traced.graph.lint(traced) traced(torch.rand(4, 4)) def test_pickle_graphmodule(self): @@ -458,23 +583,43 @@ def forward(self, x): n = Nested() traced = symbolic_trace(n) + traced.graph.lint(traced) pickled = pickle.dumps(traced) loaded = pickle.loads(pickled) + loaded.graph.lint(loaded) x = torch.rand(3, 4) self.assertEqual(loaded(x), traced(x)) + def test_all_input_nodes(self): + graph : torch.fx.Graph = torch.fx.Graph() + a : torch.fx.Node = graph.placeholder('x') + b : torch.fx.Node = graph.call_module('linear_mod', args=(a,)) + c : torch.fx.Node = graph.get_attr('y_attr') + d : torch.fx.Node = graph.call_function(operator.add, args=(b, c)) + e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0)) + graph.output(e) + graph.lint() + + self.assertEqual(b.all_input_nodes, [a]) + self.assertEqual(c.all_input_nodes, []) + self.assertEqual(d.all_input_nodes, [b, c]) + self.assertEqual(e.all_input_nodes, [d]) + def test_deepcopy_graphmodule_with_transform(self): st = SimpleTest() traced = symbolic_trace(st) + traced.graph.lint(traced) def transform(traced): new_graph = torch.fx.Graph() - new_graph.graph_copy(traced.graph) + val_map : Dict[Node, Node] = {} + output_value = new_graph.graph_copy(traced.graph, val_map) relu_out = new_graph.create_node( - op='call_method', target='neg', args=(new_graph.nodes[-1],), kwargs={}) + op='call_method', target='neg', args=(output_value,), kwargs={}) new_graph.output(relu_out) return GraphModule(traced, new_graph) transformed = transform(traced) + transformed.graph.lint(transformed) copied = copy.deepcopy(transformed) self.assertNotEqual(id(type(transformed)), id(type(copied))) x = torch.randn(3, 4) @@ -500,7 +645,9 @@ def forward(self, x): baz = Baz() traced = symbolic_trace(baz) + traced.graph.lint(traced) copied = copy.deepcopy(traced) + copied.graph.lint(copied) def test_unpack_list_better_error(self): class SomeArgs(torch.nn.Module): @@ -516,7 +663,7 @@ def forward(self, x : list): return self.sa(*x) ul = UnpacksList() - with self.assertRaisesRegex(TraceError, 'Proxy object cannot be unpacked as function argument'): + with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'): symbolic_trace(ul) def test_unpack_dict_better_error(self): @@ -533,9 +680,63 @@ def forward(self, x : dict): return self.sk(**x) ud = UnpacksDict() - with self.assertRaisesRegex(TraceError, 'Proxy object cannot be unpacked as function argument'): + with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'): symbolic_trace(ud) + def test_pretty_print_targets(self): + # Test that Graph pretty-print prints friendly name for targets + # in `operator` and `builtins` + + class SomeMod(torch.nn.Module): + def forward(self, x): + return torch.add(x.foo + x.bar, 3.0) + + traced = symbolic_trace(SomeMod()) + graph_str = str(traced.graph) + self.assertIn('builtins.getattr', graph_str) + self.assertIn('operator.add', graph_str) + self.assertIn('torch.add', graph_str) + + def test_script_tensor_constant(self): + # TorchScript seems to ignore attributes that start with `__`. + # We used to call anonymous Tensor values `__tensor_constant*`, but + # they were getting ignored by script. Now they're called + # `_tensor_constant*` + class IHaveATensorConstant(torch.nn.Module): + def forward(self, x): + return x + torch.rand(3, 4) + + traced = torch.fx.symbolic_trace(IHaveATensorConstant()) + torch.jit.script(traced) + + def test_torch_fx_len(self): + class FXLenTest(torch.nn.Module): + def forward(self, x): + return len(x) + + traced = symbolic_trace(FXLenTest()) + self.assertEqual(traced(torch.rand(3, 4)), 3) + + # Test scriptability + scripted = torch.jit.script(FXLenTest()) + self.assertEqual(scripted(torch.rand(3)), 3) + + traced_scripted = torch.jit.script(traced) + self.assertEqual(traced_scripted(torch.rand(3)), 3) + + # Test non-proxy len + class FXLenTest2(torch.nn.Module): + def __init__(self): + super().__init__() + self.l = [3, 4, 5] + + def forward(self, x): + return x + len(self.l) + + traced2 = symbolic_trace(FXLenTest2()) + inp = torch.rand(3, 4) + self.assertEqual(traced2(inp), inp + 3.0) + def test_torch_custom_ops(self): class M(torch.nn.Module): def forward(self, a): @@ -546,12 +747,14 @@ def forward(self, a): input = torch.randn(3) ref_out = m(input) gm = symbolic_trace(m) + gm.graph.lint(gm) out = gm(input) self.assertEqual(out, ref_out) def test_pretty_print(self): st = SimpleTest() traced = symbolic_trace(st) + traced.graph.lint(traced) printed = str(traced) assert 'GraphModuleImpl()' in printed assert 'torch.relu' in printed @@ -562,8 +765,9 @@ def forward(self, x): return torch.squeeze(x + 3.0, dim=2) st = KwargPrintTest() traced = symbolic_trace(st) + traced.graph.lint(traced) stringed = str(traced.graph) - for s in ['args', 'kwargs', 'uses']: + for s in ['args', 'kwargs', '#users']: assert s in stringed def test_graph_fns(self): @@ -578,11 +782,25 @@ def test_graph_fns(self): mod.linear = torch.nn.Linear(3, 4) mod.bias = torch.rand(4) gm = GraphModule(mod, g) + gm.graph.lint(gm) input = torch.rand(3) r = gm(input) ref = torch.sin(mod.linear(input) + mod.bias) self.assertEqual(r, ref) + def test_remove_uses(self): + g : torch.fx.Graph = Graph() + x : torch.fx.Node = g.placeholder('x') + relu : torch.fx.Node = g.call_function(torch.relu, (x,)) + neg : torch.fx.Node = g.call_function(torch.neg, (relu,)) + g.output(neg) + + neg.replace_all_uses_with(relu) + g.erase_node(neg) + + self.assertTrue(neg not in relu.users) + + def test_construct_root_dict(self): graph : torch.fx.Graph = torch.fx.Graph() a : torch.fx.Node = graph.create_node('placeholder', 'x') @@ -595,6 +813,7 @@ def test_construct_root_dict(self): add_param : torch.Tensor = torch.rand(3, 4) gm : torch.fx.GraphModule = torch.fx.GraphModule( {'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph) + gm.graph.lint(gm) assert 'self.foo.bar.baz' in gm.code @@ -603,6 +822,460 @@ def test_construct_root_dict(self): ref_out : torch.Tensor = linear_mod(x) + add_param self.assertEqual(out, ref_out) + def test_symbolic_trace_assert(self): + + class AssertsTensorShape(torch.nn.Module): + def forward(self, x): + torch._assert(x.shape[1] > 4, "assert_foobar") + return x + + m = AssertsTensorShape() + # verify traceability + traced = symbolic_trace(m) + # verify assertion on traced model works correctly at runtime + traced(torch.rand(4, 5)) + with self.assertRaisesRegex(AssertionError, "assert_foobar"): + traced(torch.rand(4, 3)) + # verify the symbolically traced module is scriptable + ms = torch.jit.script(m) + with self.assertRaisesRegex(torch.jit.Error, "assert_foobar"): + ms(torch.rand(4, 3)) + + + def test_copy_no_remap(self): + traced = symbolic_trace(SimpleTest()) + g = traced.graph + copied = torch.fx.Graph() + for node in g.nodes: + copied.node_copy(node) + with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'): + copied.lint() + + def test_wrong_topo(self): + graph : torch.fx.Graph = torch.fx.Graph() + a : torch.fx.Node = graph.create_node('placeholder', 'x') + b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,)) + c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam') + d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c)) + graph.output(d) + nodes = list(graph.nodes) + nodes[3].append(nodes[2]) + with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'): + graph.lint() + + def test_example_shape_prop(self): + class TestCase(torch.nn.Module): + def __init__(self): + super().__init__() + self.attr = torch.randn(3, 4) + self.submod = torch.nn.Linear(4, 4) + + def forward(self, x): + return torch.neg(self.submod(x.relu() + self.attr)) + tc = TestCase() + tc_traced = symbolic_trace(tc) + ref_out = tc_traced(torch.rand(3, 4)) + shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4)) + + # Make sure we're testing all opcodes + opcodes = set() + output_shape : Optional[torch.Shape] = None + for node in tc_traced.graph.nodes: + opcodes.add(node.op) + if node.op == 'output': + output_shape = node.args[0].shape + self.assertEqual(opcodes, set(['placeholder', 'get_attr', 'call_function', 'call_method', + 'call_module', 'output'])) + + # Test shape propogation and make sure results match actual + self.assertEqual(output_shape, ref_out.shape) + + def test_fn_type_annotations(self): + class Foo(torch.nn.Module): + def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]: + return {'a': p.x + p.y + z + i} + + foo_scripted = torch.jit.script(Foo()) + foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) + + fxed = symbolic_trace(Foo()) + fxed_scripted = torch.jit.script(fxed) + fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) + + def test_fn_type_annotation_empty(self): + def forward(a : List[torch.Tensor]): + return a[0] + torch.jit.script(symbolic_trace(forward)) + + def test_wrapped_method(self): + def wrap_with_relu(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return torch.relu(fn(*args, **kwargs)) + return wrapper + + class Foo(torch.nn.Module): + @wrap_with_relu + def forward(self, x, w): + return torch.matmul(x, w) + + f = Foo() + traced = symbolic_trace(f) + x, w = torch.rand(3, 4), torch.rand(4, 4) + self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes)) + + def test_empty_graph_codegen(self): + graph = torch.fx.Graph() + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + self.assertEqual(gm(), None) + + def test_sequential(self): + m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)) + gm = torch.fx.symbolic_trace(m) + gm_copy = copy.deepcopy(gm) + + def test_ctx_mgr(self): + @contextlib.contextmanager + def do_nothing(): + yield + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + @do_nothing() + def forward(self, x): + return torch.relu(x) + + m = M() + self.checkGraphModule(m, (torch.rand(3, 4),)) + + def test_typename_print(self): + graph : torch.fx.Graph = torch.fx.Graph() + x : torch.fx.Node = graph.create_node('placeholder', 'x') + b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,), + type_expr=List[float]) + output : torch.fx.Node = graph.output(b) + self.assertTrue('typing.List[float]' in str(graph)) + + def test_inf_nan(self): + class FooMod(torch.nn.Module): + def forward(self, x): + return x + float('inf'), x + float('-inf'), x + float('nan') + + fm = FooMod() + self.checkGraphModule(fm, (torch.rand(3, 4),)) + + def test_inf_nan_kwds(self): + graph : torch.fx.Graph = torch.fx.Graph() + x : torch.fx.Node = graph.create_node('placeholder', 'x') + b : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('inf')), {}, name='inf') + c : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('nan')), {}, name='nan') + graph.output((b, c)) + + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + x = torch.rand(3, 4) + self.assertEqual(gm(x), (x + float('inf'), x + float('nan'))) + + def test_deepcopy_recursion_depth(self): + depth = sys.getrecursionlimit() + 20 + + g = torch.fx.Graph() + x = g.placeholder('x') + for i in range(depth): + x = g.call_function(torch.relu, (x,)) + g.output(x) + + copied_graph = copy.deepcopy(g) + + val_map = {} + for orig_node, new_node in zip(g.nodes, copied_graph.nodes): + val_map[orig_node] = new_node + + for orig_node, new_node in zip(g.nodes, copied_graph.nodes): + orig_users = set(orig_node.users.keys()) + orig_users_equiv = set(val_map[u] for u in orig_users) + new_users = set(new_node.users.keys()) + self.assertEqual(orig_users_equiv, new_users) + + @skipIfNoTorchVision + def test_replace_uses(self): + rn18 = resnet18() + + class LowerReluTracer(torch.fx.Tracer): + def is_leaf_module(self, m : torch.nn.Module, qualname : str): + if isinstance(m, torch.nn.ReLU): + return False + return super().is_leaf_module(m, qualname) + + rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18)) + + to_erase = [] + for node in rn18_traced.graph.nodes: + if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]: + kwargs = node.kwargs.copy() + # Neg doesn't have in-place + kwargs.pop('inplace') + with rn18_traced.graph.inserting_before(node): + new_node = rn18_traced.graph.call_function( + the_function=torch.neg, args=node.args, kwargs=node.kwargs) + node.replace_all_uses_with(replace_with=new_node) + to_erase.append(node) + + for node in to_erase: + rn18_traced.graph.erase_node(node) + + def test_insertion_point(self): + graph : torch.fx.Graph = torch.fx.Graph() + x : torch.fx.Node = graph.create_node('placeholder', 'x') + b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) + output : torch.fx.Node = graph.output(b) + + with graph.inserting_before(b): + neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,)) + _, *relu_args = b.args + b.args = (neg, *relu_args) + + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + input = torch.randn(33, 44) + self.assertEqual(gm(input), torch.relu(torch.neg(input))) + + + def test_move_before(self): + graph : torch.fx.Graph = torch.fx.Graph() + x : torch.fx.Node = graph.create_node('placeholder', 'x') + b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) + output : torch.fx.Node = graph.output(b) + + neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,)) + _, *relu_args = b.args + b.args = (neg, *relu_args) + b.prepend(neg) + + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + input = torch.randn(33, 44) + self.assertEqual(gm(input), torch.relu(torch.neg(input))) + + def test_erase_node_error(self): + st = SimpleTest() + traced = symbolic_trace(st) + + for node in traced.graph.nodes: + # Test deleting with uses both in another Node and at the output + if node.target in [operator.add, torch.relu]: + with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'): + traced.graph.erase_node(node) + + def test_copy_it(self): + d = immutable_dict([(3, 4), (5, 6)]) + l = immutable_list([(3, 4), (5, 6)]) + + self.assertEqual(d, deepcopy(d)) + self.assertEqual(l, deepcopy(l)) + + def test_find_uses(self): + graph = torch.fx.Graph() + x = torch.fx.Proxy(graph.placeholder('x')) + + y = torch.relu(x) + z = x + x + u = torch.neg(x) + graph.output((y + z + u).node) + graph.lint() + + users_of_x = x.node.users + self.assertEqual(len(users_of_x), 3) + expected_ops = set(['relu', 'add', 'neg']) + for use in users_of_x: + assert any(use.name.startswith(prefix) for prefix in expected_ops) + + def test_inline_graph(self): + class InlineInto(torch.nn.Module): + def forward(self, x): + return torch.relu(x) + + class ToInline(torch.nn.Module): + def forward(self, x): + return torch.neg(x) + + inline_into = symbolic_trace(InlineInto()) + to_inline = symbolic_trace(ToInline()) + + combined_graph = torch.fx.Graph() + output_node = combined_graph.graph_copy(inline_into.graph, {}) + + input_node = list(to_inline.graph.nodes)[0] + assert input_node and input_node.op == 'placeholder' + + val_map = {input_node : output_node} + output = combined_graph.graph_copy(to_inline.graph, val_map) + combined_graph.output(output) + + combined_module = torch.fx.GraphModule(torch.nn.Module(), combined_graph) + + input = torch.rand(3, 4) + self.assertEqual(combined_module(input), input.relu().neg()) + + def test_multi_insert_point(self): + graph = torch.fx.Graph() + x = torch.fx.Proxy(graph.placeholder('x')) + relu = torch.relu(x) + + with graph.inserting_before(relu.node): + y = torch.neg(x) + z = torch.tanh(y) + + graph.output((relu.node, z.node)) + graph.lint() + + expected_ops = ['x', 'neg', 'tanh', 'relu'] + for node, expected in zip(graph.nodes, expected_ops): + assert expected in node.name + + def test_reassign_args_kwargs_uses(self): + graph = torch.fx.Graph() + x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y')) + z = x + y + zed = z + z + z + graph.output(zed.node) + graph.lint() + + # zed = z + z + z -> zed = z + z + x + zed.node.args = (zed.node.args[0], x.node) + self.assertEqual(x.node.users.keys(), [z.node, zed.node]) + + # z = x + y -> z = y + y + z.node.args = (y.node, y.node) + self.assertEqual(x.node.users.keys(), [zed.node]) + + def test_trace_function(self): + def foo(x, y): + return torch.relu(x) + y + + x, y = torch.randn(3, 4), torch.randn(3, 4) + self.checkGraphModule(foo, (x, y)) + + def test_direct_param_use(self): + class TransposeTest(torch.nn.Module): + def __init__(self): + super().__init__() + self.b = torch.nn.Parameter(torch.rand(4, 3)) + + def forward(self, x): + return self.b + + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = TransposeTest() + + def forward(self, x): + return self.a.b, self.a.b.t(), self.a.b.view(12) + + traced = torch.fx.symbolic_trace(Foo()) + assert(all('constant' not in node.target for node in traced.graph.nodes)) + + def test_single_default_arg(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, y=1): + return y + + m = M() + self.checkGraphModule(m, ()) + self.checkGraphModule(m, (3,)) + + def test_multiple_default_args(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, y=1, z=2): + return y + z + + m = M() + self.checkGraphModule(m, ()) + self.checkGraphModule(m, (3,)) + self.checkGraphModule(m, (3, 4)) + + def test_regular_and_default_args(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y=1): + return x + y + + m = M() + self.checkGraphModule(m, (2,)) + self.checkGraphModule(m, (2, 3)) + + def test_string_literal_return(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self): + return "foo" + + m = M() + self.checkGraphModule(m, ()) + + def test_namedtuple_return_qualname(self): + class NamedTupReturn(torch.nn.Module): + def forward(self, x): + return MyNamedTup(x, x) + + traced = symbolic_trace(NamedTupReturn()) + input = torch.rand(3, 4) + self.assertEqual(traced(input), MyNamedTup(input, input)) + + def test_update_args_kwargs_yells_at_you(self): + symtraced = symbolic_trace(SimpleTest()) + node = next(iter(symtraced.graph.nodes)) + with self.assertRaisesRegex(AttributeError, '__update_args_kwargs'): + node.__update_args_kwargs((), {}) + + def test_torchbind_class_attribute_in_fx(self): + if TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS: + self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping") + + class FooBar1234(torch.nn.Module): + def __init__(self): + super(FooBar1234, self).__init__() + self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"]) + + def forward(self): + return self.f.top() + + m = FooBar1234() + self.checkGraphModule(m, ()) + + def test_namedtuple_return_trace(self): + class NamedTupReturn(torch.nn.Module): + def forward(self, x): + return Pair(x, x) + + traced = symbolic_trace(NamedTupReturn()) + input = torch.rand(3, 4) + self.assertEqual(traced(input), Pair(input, input)) + + def test_return_type_exists(self): + class ReturnTypeModule(torch.nn.Module): + def other(self, x: List[str]) -> List[str]: + return x + + def forward(self, x: List[str]) -> List[str]: + return self.other(x) + + traced = symbolic_trace(ReturnTypeModule()) + self.assertIn("-> typing.List[str]", traced._code) + scripted = torch.jit.script(traced) + self.assertIn("-> List[str]", scripted.code) if __name__ == '__main__': run_tests() diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py new file mode 100644 index 0000000000000..ac71d60375913 --- /dev/null +++ b/test/test_fx_experimental.py @@ -0,0 +1,972 @@ +import torch +import unittest +import sys +from typing import Callable, Dict, Union, List +from torch.fx.symbolic_trace import symbolic_trace +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node +from torch.fx.experimental import graph_manipulation +from torch.fx.experimental.accelerator_partitioner import Partitioner +from torch.fx.experimental.rewriter import RewritingTracer +from torch.fx.experimental.param_fetch import lift_lowering_attrs_to_nodes +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.jit_utils import JitTestCase +from torch.fx.experimental.subgraph_creation_example import split_module +from torch.fx.experimental.partitioner_utils import ( + NodeLatency, + get_partition_to_latency_mapping, + get_latency_of_partitioned_graph, + Device, + PartitionerConfig, + PartitionMode +) +from torch.fx.experimental.fuser import fuse +from torch.fx.experimental import merge_matmul + +try: + from torchvision.models import resnet18 + HAS_TORCHVISION = True +except ImportError: + HAS_TORCHVISION = False +skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") + + +def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule: + return GraphModule( + root if isinstance(root, torch.nn.Module) else torch.nn.Module(), + RewritingTracer().trace(root), + ) + + +class TestFXExperimental(JitTestCase): + def test_serialize_graph(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + self.e = torch.rand(4) + self.conv = torch.nn.Conv2d(3, 3, 2, bias=False) + + def forward(self, a, b, c): + add_1 = a + b + conv1 = self.conv(c) + linear = self.linear(add_1 + conv1) + add_2 = linear + self.e + return add_2 + + m = TestModule() + traced = symbolic_trace(m) + a = torch.rand(4) + b = torch.rand(4) + c = torch.rand(3, 3, 2, 2) + graph_manipulation.get_size_of_all_nodes(traced, [a, b, c]) + + partitioner = Partitioner() + devices = [Device("dev_0", 5000, 0), Device("dev_1", 125, 1)] + partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn) + ret = partitioner.partition_graph(traced, m, partitioner_config) + module_with_submodules = ret.module_with_submodules + # Fix for now to add type/shape to output + for node in traced.graph.nodes: + if node.op == "output": + node.shape = a.shape + node.dtype = a.dtype + for mod in module_with_submodules.modules(): + if isinstance(mod, GraphModule): + for node in mod.graph.nodes: + node.shape = a.shape + node.dtype = a.dtype + for node in module_with_submodules.graph.nodes: + node.shape = a.shape + node.dtype = a.dtype + + agm1 = graph_manipulation.AcceleratedGraphModule(traced) + agm2 = graph_manipulation.AcceleratedGraphModule(module_with_submodules) + assert len(agm1.weights) == 4 + assert len(agm2.weights) == 4 + assert len(agm1.serialized_graph["nodes"]) == 10 + assert len(agm1.serialized_graph["weights"]) == 4 + assert len(agm1.serialized_graph["modules"]) == 0 + assert len(agm2.serialized_graph["nodes"]) == 6 + assert len(agm2.serialized_graph["weights"]) == 4 + assert len(agm2.serialized_graph["modules"]) == 1 + assert agm1.serialized_graph["weights"]["linear.weight"]["shape"] == "[4, 4]" + assert ( + agm1.serialized_graph["weights"]["linear.weight"]["dtype"] + == "torch.float32" + ) + assert ( + agm1.serialized_graph["weights"]["linear.weight"]["is_quantized"] is False + ) + assert agm1.serialized_graph["nodes"][0]["shape"] == "[4]" + assert agm1.serialized_graph["nodes"][0]["dtype"] == "torch.float32" + assert agm1.serialized_graph["nodes"][0]["target"] == "a" + assert agm1.serialized_graph["nodes"][0]["op_code"] == "placeholder" + assert agm1.serialized_graph["nodes"][0]["name"] == "a" + assert agm1.serialized_graph["nodes"][6]["args"][0]["name"] == "add_2" + assert agm1.serialized_graph["nodes"][6]["args"][0]["is_node"] is True + + # Test quantization info serialization. + x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]]) + q_tensor = torch.quantize_per_tensor(x, 1, 0, torch.qint32) + q_tensor_channel = torch.quantize_per_channel( + x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8 + ) + result = graph_manipulation.serialize_tensor_quantization(q_tensor) + result2 = graph_manipulation.serialize_tensor_quantization(q_tensor_channel) + assert result["q_scheme"] == "torch.per_tensor_affine" + assert result["q_scale"] == 1.0 + assert result2["q_scheme"] == "torch.per_channel_affine" + assert len(result2["q_per_channel_scales"]) == 2 + + def test_find_single_partition(self): + class TestModule(torch.nn.Module): + def forward(self, a, b): + return a + b + + m = TestModule() + traced = symbolic_trace(m) + a = torch.rand(1) + b = torch.rand(1) + graph_manipulation.get_size_of_all_nodes(traced, [a, b]) + partitioner = Partitioner() + devices = [ + Device("dev_0", 125, 0), + Device("dev_1", 125, 1), + Device("dev_2", 125, 2) + ] + partitioner_config = PartitionerConfig(devices) + ret = partitioner.partition_graph(traced, m, partitioner_config) + module_with_submodules = ret.module_with_submodules + dag = ret.dag + self.assertEqual(traced(a, b), module_with_submodules(a, b)) + assert dag.nodes[0].logical_device_ids == [0] + + def test_lack_of_devices(self): + class TestModule(torch.nn.Module): + def forward(self, a, b): + return a + b + + m = TestModule() + traced = symbolic_trace(m) + a = torch.rand(4) + b = torch.rand(4) + graph_manipulation.get_size_of_all_nodes(traced, [a, b]) + partitioner = Partitioner() + devices = [Device("dev_0", 4, 0), Device("dev_1", 4, 1)] + partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) + catch_runtime_error = False + try: + ret = partitioner.partition_graph(traced, m, partitioner_config) + except RuntimeError: + catch_runtime_error = True + assert catch_runtime_error + + def test_large_node_error(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, a): + linear = self.linear(a) + add = linear + a + return add + + m = TestModule() + traced = symbolic_trace(m) + a = torch.rand(4) + graph_manipulation.get_size_of_all_nodes(traced, [a]) + partitioner = Partitioner() + devices = [ + Device("dev_0", 40, 0), + Device("dev_1", 40, 0), + Device("dev_2", 40, 0), + Device("dev_3", 40, 0), + Device("dev_4", 40, 0) + ] + partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) + catch_runtime_error = False + try: + ret = partitioner.partition_graph(traced, m, partitioner_config) + except RuntimeError: + catch_runtime_error = True + assert catch_runtime_error + + def test_partition_node_manipulation(self): + class TestModule(torch.nn.Module): + def forward(self, a, b): + add_1 = a + b + add_2 = add_1 + torch.rand(4) + add_3 = add_2 + torch.rand(4) + return add_3 + + m = TestModule() + traced = symbolic_trace(m) + a, b = torch.rand(4), torch.rand(4) + graph_manipulation.get_size_of_all_nodes(traced, [a, b]) + partitioner = Partitioner() + devices = [Device('dev_0', 1000, 0)] + partitioner_config = PartitionerConfig(devices) + ret = partitioner.partition_graph(traced, m, partitioner_config) + partition = partitioner.partitions[0] + assert partition.used_mem_bytes == 112 + # Select add_3 node to remove + selected_node = None + for node in partition.nodes: + if node.name == 'add_3': + selected_node = node + partition.remove_node(selected_node) + assert(partition.used_mem_bytes == 80) + + def test_size_based_partition(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + self.c = torch.rand(4) + + def forward(self, a, b): + add_1 = a + b + linear = self.linear(add_1) + add_2 = linear + self.c + return add_2 + + m = TestModule() + traced = symbolic_trace(m) + a = torch.rand(4) + b = torch.rand(4) + graph_manipulation.get_size_of_all_nodes(traced, [a, b]) + partitioner = Partitioner() + devices = [ + Device("dev_0", 125, 0), + Device("dev_1", 125, 1), + Device("dev_2", 125, 2) + ] + partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) + ret = partitioner.partition_graph(traced, m, partitioner_config) + module_with_submodules = ret.module_with_submodules + dag = ret.dag + self.assertEqual(traced(a, b), module_with_submodules(a, b)) + for i, node in enumerate(dag.nodes): + assert node.logical_device_ids == [i] + + def test_partition_device_mapping(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, a): + b = torch.rand(4) + add_1 = a + b + linear_1 = self.linear(add_1) + add_2 = torch.rand(4) + a + add_3 = add_2 + linear_1 + return add_3 + + m = TestModule() + traced = symbolic_trace(m) + a = torch.rand(4) + graph_manipulation.get_size_of_all_nodes(traced, [a]) + partitioner = Partitioner() + devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)] + partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) + ret = partitioner.partition_graph(traced, m, partitioner_config) + module_with_submodules = ret.module_with_submodules + dag = ret.dag + self.assertEqual(traced(a), module_with_submodules(a)) + for i, node in enumerate(dag.nodes): + if i == 1: + assert node.logical_device_ids == [1] + else: + assert node.logical_device_ids == [0] + + def test_sparse_nn_partition(self): + class MyRecommendationModule(torch.nn.Module): + def create_mlp(self, num_of_layers: int, input_size: int, output_size: int): + layers = torch.nn.ModuleList() + for _ in range(num_of_layers): + ll = torch.nn.Linear(input_size, output_size) + layers.append(ll) + layers.append(torch.nn.ReLU()) + return layers + + def __init__(self): + super(MyRecommendationModule, self).__init__() + layers = self.create_mlp(4, 4, 4) + self.bottom_layers = torch.nn.Sequential(*layers) + layers = self.create_mlp(3, 24, 24) + self.top_layers = torch.nn.Sequential(*layers) + self.embedding_layers = torch.nn.ModuleList() + el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) + self.embedding_layers.append(el) + for i in range(3): + el = torch.nn.EmbeddingBag(1000000, 4, mode="sum", sparse=True) + self.embedding_layers.append(el) + el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) + self.embedding_layers.append(el) + + def forward(self, a, b, offset): + x = self.bottom_layers(a) + y = [] + c = [] + for i in range(len(self.embedding_layers)): + temp = torch.randint(10, (8,)) + c.append(temp + b) + for i in range(len(self.embedding_layers)): + if i % 2 == 0: + y.append(self.embedding_layers[i](c[i], offset)) + else: + y.append( + self.embedding_layers[i](torch.randint(10, (8,)), offset) + ) + z = torch.cat([x] + y, dim=1) + p = self.top_layers(z) + return p + + m = MyRecommendationModule() + a = torch.rand(2, 4) + b = torch.randint(10, (8,)) + offset = torch.randint(1, (2,)) + traced = symbolic_trace(m) + graph_manipulation.get_size_of_all_nodes(traced, [a, b, offset]) + devices = [ + Device("dev_0", 33000000, 0), + Device("dev_1", 33000000, 1), + Device("dev_2", 33000000, 2) + ] + partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn) + partitioner = Partitioner() + ret = partitioner.partition_graph(traced, m, partitioner_config) + module_with_submodules = ret.module_with_submodules + dag = ret.dag + self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset)) + assert len(module_with_submodules.graph.nodes) == 24 + + def test_partition_latency(self): + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, a): + add_1 = a + torch.rand(4) + add_2 = add_1 + torch.rand(4) + linear_1 = self.linear(add_1) + add_3 = add_2 + linear_1 + add_4 = add_2 + add_3 + return add_4 + + def get_node_to_latency_mapping(fx_module: GraphModule): + """Given a fx module, generate node latency for each node + based on the size of each node + """ + node_to_latency_mapping: Dict[Node, NodeLatency] = {} + for node in fx_module.graph.nodes: + if node.op not in {"output", "placeholder", "get_attr"}: + if node.size_bytes.total_size == node.size_bytes.output_size: + node_to_latency_mapping[node] = NodeLatency( + node.size_bytes.total_size, 2.0 * node.size_bytes.total_size + ) + else: + node_to_latency_mapping[node] = NodeLatency( + node.size_bytes.total_size, node.size_bytes.output_size + ) + return node_to_latency_mapping + + m = TestModule() + traced = symbolic_trace(m) + a = torch.rand(4) + graph_manipulation.get_size_of_all_nodes(traced, [a]) + node_to_latency_mapping = get_node_to_latency_mapping(traced) + devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)] + partitioner = Partitioner() + partitioner_config = PartitionerConfig(devices) + ret = partitioner.partition_graph(traced, m, partitioner_config) + module_with_submodules = ret.module_with_submodules + self.assertEqual(traced(a), module_with_submodules(a)) + partitions = partitioner.partitions + partition_to_latency_mapping = get_partition_to_latency_mapping( + partitions, node_to_latency_mapping + ) + for p in partition_to_latency_mapping: + if p.partition_id == 0: + assert partition_to_latency_mapping[p] == (128.0, 80.0, 160.0) + else: + assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0) + transfer_rate_bytes_per_sec = 2 + critical_path_latency_sec = get_latency_of_partitioned_graph( + partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec + ) + assert critical_path_latency_sec == 208.0 + + def test_cost_aware_partition(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, a): + add_1 = a + torch.rand(4) + add_2 = add_1 + torch.rand(4) + linear_1 = self.linear(add_1) + add_3 = add_2 + torch.rand(4) + add_4 = add_2 + linear_1 + add_5 = add_3 + add_4 + return add_5 + + def get_node_to_latency_mapping(fx_module: GraphModule): + node_to_latency_mapping: Dict[Node, Nodelatency] = {} + for node in fx_module.graph.nodes: + if node.op not in {'output', 'placeholder', 'get_attr'}: + if node.size_bytes.total_size == node.size_bytes.output_size: + node_to_latency_mapping[node] = NodeLatency(node.size_bytes.total_size, 1) + else: + node_to_latency_mapping[node] = NodeLatency(node.size_bytes.total_size, node.size_bytes.output_size) + return node_to_latency_mapping + + m = MyModule() + traced = symbolic_trace(m) + a = torch.rand(4) + graph_manipulation.get_size_of_all_nodes(traced, [a]) + devices = [ + Device('dev_0', 125, 0), + Device('dev_1', 125, 1), + Device('dev_2', 125, 2), + Device('dev_3', 125, 3) + ] + node_to_latency_mapping = get_node_to_latency_mapping(traced) + partitioner_config = PartitionerConfig( + devices, + mode=PartitionMode.cost_aware, + transfer_rate_bytes_per_sec=2, + node_to_latency_mapping=node_to_latency_mapping + ) + partitioner = Partitioner() + ret = partitioner.partition_graph(traced, m, partitioner_config) + module_with_submodules = ret.module_with_submodules + dag = ret.dag + self.assertEqual(traced(a), module_with_submodules(a)) + partitions = partitioner.partitions + partition_to_latency_mapping = get_partition_to_latency_mapping(partitions, node_to_latency_mapping) + critical_path_latency_sec = get_latency_of_partitioned_graph( + partitions, + partition_to_latency_mapping, + partitioner_config.transfer_rate_bytes_per_sec + ) + assert critical_path_latency_sec == 160. + + def test_kl_based_partition(self): + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + self.linear = torch.nn.Linear(4, 4) + self.b = torch.rand(4) + self.c = torch.rand(4) + self.d = torch.rand(4) + + def forward(self, a): + add_1 = a + self.b + add_2 = add_1 + self.c + linear_1 = self.linear(add_1) + add_3 = add_2 + linear_1 + add_4 = add_2 + self.d + add_5 = add_3 + add_4 + return add_4 + m = TestModule() + traced = symbolic_trace(m) + a = torch.rand(4) + graph_manipulation.get_size_of_all_nodes(traced, [a]) + node_to_latency_mapping = get_node_to_latency_mapping(traced) + transfer_rate_bytes_per_sec = 2 + devices = [ + Device('dev_0', 200, 0), + Device('dev_1', 200, 1), + Device('dev_2', 200, 2), + Device('dev_3', 200, 3) + ] + partitioner = Partitioner() + partitioner_config = PartitionerConfig( + devices, + mode=PartitionMode.kl_based, + transfer_rate_bytes_per_sec=transfer_rate_bytes_per_sec, + node_to_latency_mapping=node_to_latency_mapping + ) + ret = partitioner.partition_graph(traced, m, partitioner_config) + module_with_submodules = ret.module_with_submodules + self.assertEqual(traced(a), module_with_submodules(a)) + dag = ret.dag + assert dag.nodes[0] == 176 + assert dag.nodes[1] == 112 + partition_to_latency_mapping = get_partition_to_latency_mapping( + partitioner.partitions, + node_to_latency_mapping + ) + cost = get_latency_of_partitioned_graph( + partitioner.partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec + ) + assert cost == 208. + + def test_aot_based_partition(self): + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + self.b = torch.rand(4) + self.c = torch.rand(4) + + def forward(self, a): + add_1 = a + self.b + add_2 = self.c + add_1 + return add_2 + m = TestModule() + traced = symbolic_trace(m) + a = torch.rand(4) + node_to_partition_id = {} + partition_to_logical_devices = {} + count = 0 + GraphManipulation.get_size_of_all_nodes(traced, [a]) + for node in traced.graph.nodes: + if node.op not in {'placeholder', 'get_attr', 'output'}: + node_to_partition_id[node] = count + partition_to_logical_devices[count] = [0] + count += 1 + devices = [Device('dev_0', 200, 0)] + partitioner_config = PartitionerConfig( + devices=devices, + mode=PartitionMode.aot_based, + node_to_partition_mapping=node_to_partition_id, + partition_to_logical_device_mapping=partition_to_logical_devices + ) + partitioner = Partitioner() + ret = partitioner.partition_graph(traced, m, partitioner_config) + module_with_submodules = ret.module_with_submodules + dag = ret.dag + self.assertEqual(module_with_submodules(a), traced(a)) + for node in dag.nodes: + assert node.size_bytes == 48 + assert node.logical_device_ids == [0] + + def test_replace_target_nodes_with(self): + class testModule(torch.nn.Module): + def forward(self, a, b): + return a + b + m = testModule() + traced = symbolic_trace(m) + input1 = torch.randn(1) + input2 = torch.randn(1) + assert (input1 + input2) == traced(input1, input2) + graph_manipulation.replace_target_nodes_with( + fx_module=traced, + old_op="call_function", + old_target=operator.add, + new_op="call_function", + new_target=operator.mul, + ) + assert (input1 * input2) == traced(input1, input2) + + @skipIfNoTorchVision + def test_conv_bn_fusion(self): + rn18 = resnet18().eval() + traced = symbolic_trace(rn18) + fused = fuse(traced) + + self.assertTrue(all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())) + + N, C, H, W = 20, 3, 224, 224 + inp = torch.randn(N, C, H, W) + + self.assertEqual(fused(inp), rn18(inp)) + + def test_call_to_assert_no_msg(self): + class M(torch.nn.Module): + def forward(self, a, b): + assert a == b + return a + b + + m = M() + traced = symbolic_trace_with_rewrite(m) + + # Make sure the graph is well-formed + traced.graph.lint(traced) + + # Check the IR to make sure there's a call_function node with target == "Assert" + self.assertTrue( + any( + node.op == "call_function" and node.target == torch._assert + for node in traced.graph.nodes + ) + ) + + # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to + traced(3, 3) + with self.assertRaisesRegex(AssertionError, ""): + traced(3, 5) + + # Confirm that the output is correct + self.assertEqual(traced(3, 3), m(3, 3)) + + def test_call_to_assert_with_msg(self): + class M(torch.nn.Module): + def forward(self, a, b): + assert a == b, "test message" + return a + b + + m = M() + traced = symbolic_trace_with_rewrite(m) + + # Make sure the graph is well-formed + traced.graph.lint(traced) + + # Check the IR to make sure there's a call_function node with target == "Assert" + self.assertTrue( + any( + node.op == "call_function" and node.target == torch._assert + for node in traced.graph.nodes + ) + ) + + # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to + traced(3, 3) + with self.assertRaisesRegex(AssertionError, "test message"): + traced(3, 5) + + # Confirm that the output is correct + self.assertEqual(traced(3, 3), m(3, 3)) + + def test_call_to_assert_with_empty_msg(self): + class M(torch.nn.Module): + def forward(self, a, b): + assert a == b, "" + return a + b + + m = M() + traced = symbolic_trace_with_rewrite(m) + + # Make sure the graph is well-formed + traced.graph.lint(traced) + + # Check the IR to make sure there's a call_function node with target == "Assert" + self.assertTrue( + any( + node.op == "call_function" and node.target == torch._assert + for node in traced.graph.nodes + ) + ) + + # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to + traced(3, 3) + with self.assertRaisesRegex(AssertionError, ""): + traced(3, 5) + + # Confirm that the output is correct + self.assertEqual(traced(3, 3), m(3, 3)) + + def test_call_to_assert_with_multiline_message(self): + class M(torch.nn.Module): + def forward(self, a, b): + error_msg = """ +An error message with +terrible spacing + """ + assert a == b, error_msg + return a + b + + m = M() + traced = symbolic_trace_with_rewrite(m) + + # Make sure the graph is well-formed + traced.graph.lint(traced) + + # Check the IR to make sure there's a call_function node with target == "Assert" + self.assertTrue( + any( + node.op == "call_function" and node.target == torch._assert + for node in traced.graph.nodes + ) + ) + + # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to + error_msg = """ +An error message with +terrible spacing + """ + traced(3, 3) + with self.assertRaisesRegex(AssertionError, error_msg): + traced(3, 5) + + # Confirm that the output is correct + self.assertEqual(traced(3, 3), m(3, 3)) + + def test_subgraph_creation(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x, y): + z = self.linear(x + self.param).clamp(min=0.0, max=1.0) + w = self.linear(y).clamp(min=0.0, max=1.0) + return z + w + + # symbolically trace model + my_module = MyModule() + my_module_traced = symbolic_trace(my_module) + + # random mod partitioning + partition_counter = 0 + NPARTITIONS = 3 + + def mod_partition(node: Node): + nonlocal partition_counter + partition = partition_counter % NPARTITIONS + partition_counter = (partition_counter + 1) % NPARTITIONS + return partition + + # split module in module with submodules + module_with_submodules = split_module(my_module_traced, my_module, mod_partition) + + x = torch.rand(3, 4) + y = torch.rand(3, 4) + + orig_out = my_module_traced(x, y) + submodules_out = module_with_submodules(x, y) + + self.assertEqual(orig_out, submodules_out) + + @skipIfNoTorchVision + def test_subgraph_trivial_resnet(self): + # Smoke test trivially splitting resnet into 1 partition works + # There was an issue before causing submodule names to be aliased + m = resnet18() + traced = symbolic_trace(m) + a = torch.rand(64, 3, 7, 7) + module_with_submodules = split_module(traced, m, lambda node: 0) + module_with_submodules(a) + + def test_subgraph_uniquename(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, a, b, c, d): + add_1 = a + b + add_2 = add_1 + c + linear_1 = self.linear(add_1) + add_3 = add_2 + d + add_4 = add_2 + linear_1 + add_5 = add_3 + add_4 + return add_5 + + a, b, c, d = torch.ones(4), torch.ones(4), torch.ones(4), torch.ones(4) + mm = MyModule() + traced = symbolic_trace(mm) + + def split_cb(node : torch.fx.Node): + if node.name == 'a' or node.name == 'b' or node.name == 'add': + return 0 + else: + return 1 + module_with_submodule = split_module(traced, mm, split_cb) + self.assertEqual(module_with_submodule(a, b, c, d), traced(a, b, c, d)) + + def test_traceable_function_with_nonstandard_name(self): + def foo(x): + return torch.relu(x) + + traced = symbolic_trace_with_rewrite(foo) + + def test_to_folder(self): + class Test(torch.nn.Module): + def __init__(self): + super(Test, self).__init__() + self.W = torch.nn.Parameter(torch.randn(2)) + self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2)) + self.linear = torch.nn.Linear(2, 2) + self.attr = torch.randn(2) + self.register_buffer('attr2', torch.randn(2)) + + def forward(self, x): + return self.linear(self.seq(self.W + self.attr + self.attr2 + x)) + + mod = symbolic_trace(Test()) + module_name = 'Foo' + import tempfile + from pathlib import Path + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_dir = Path(tmp_dir) + mod.to_folder(tmp_dir, module_name) + # Recipe taken from here: + # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly + import importlib.util + spec = importlib.util.spec_from_file_location(module_name, tmp_dir / '__init__.py') + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + t = torch.randn(2, 2) + self.assertEqual(module.Foo()(t), mod(t)) + + def test_fetch(self): + attrs_for_lowering: Dict[str, List[str]] = { + "torch.nn.modules.conv.Conv2d": [ + "weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode" + ], + "torch.nn.modules.batchnorm.BatchNorm2d": [ + "weight", "bias", "running_mean", "running_var", "eps" + ], + } + + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 2) + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, a): + a = self.conv(a) + a += a + return self.bn(a) + + mod = TestModule() + traced = symbolic_trace(mod) + lift_lowering_attrs_to_nodes(traced) + + for node in traced.graph.nodes: + if node.op == "call_module": + assert hasattr(node, "attrs_for_lowering") + para_list = attrs_for_lowering[node.attrs_for_lowering["name"]] + + # node.attrs_for_lowering has an addition field of class name + assert len(para_list) + 1 == len(node.attrs_for_lowering) + for p_name in para_list: + assert p_name in node.attrs_for_lowering + + def test_merge_matmuls(self): + """ + A collection of test cases for torch.fx.experimental.merge_matmul, + a graph transformation that merges matrix multiplication operations. + """ + # Utility function for counting matmuls for test assertions. + def _count_matmuls(mod): + gm = torch.fx.symbolic_trace(mod) + + num_matmuls = 0 + for node in gm.graph.nodes: + if node.target == torch.matmul: + num_matmuls += 1 + + return num_matmuls + + # Simple test case in which there are two matmuls of the same size to merge. + class SimpleMergeMatmulModule(torch.nn.Module): + def __init__(self, rhs): + super().__init__() + self.rhs = rhs + + def forward(self, x, y): + a = torch.matmul(x, self.rhs) + b = torch.matmul(y, self.rhs) + return a + b + + # Initialize inputs. + a = torch.randn(3, 3) + b = torch.randn(3, 3) + + # Initialize RHS for matmuls. + rhs = torch.randn(3, 4) + + # Construct SimpleMergeMatmulModule and call merge_matmul on it. + module = SimpleMergeMatmulModule(rhs) + opt_module = merge_matmul.merge_matmul(module) + + # Numerical correctness check. + before = module(a, b) + after = opt_module(a, b) + before.allclose(after) + + # Basic graph structure check; original module should have 2 matmuls + # and optimized module should have 1. + self.assertEqual(_count_matmuls(module), 2) + self.assertEqual(_count_matmuls(opt_module), 1) + + # Test case in which there are multiple matmuls of different sizes to merge. + class FiveMergeMatmulModule(torch.nn.Module): + def __init__(self, rhs): + super().__init__() + self.rhs = rhs + + def forward(self, a, b, c, d, e): + s = torch.Tensor((0)) + matmuls = [] + + # For some reason using a list comprehension or for-loop for this + # doesn't work. + matmuls.append(torch.matmul(a, self.rhs)) + matmuls.append(torch.matmul(b, self.rhs)) + matmuls.append(torch.matmul(c, self.rhs)) + matmuls.append(torch.matmul(d, self.rhs)) + matmuls.append(torch.matmul(e, self.rhs)) + + for m in matmuls: + s += torch.sum(m) + + return s + + # Initialize inputs. + inputs = [torch.randn(2 * i + 1, 5) for i in range(5)] + + # Initialize RHS. + rhs = torch.randn(5, 4) + + # Construct FiveMergeMatmulModule and call merge_matmul on it. + module = FiveMergeMatmulModule(rhs) + opt_module = merge_matmul.merge_matmul(module) + + # Numerical correctness check. + before = module(*inputs) + after = opt_module(*inputs) + before.allclose(after) + + # Basic graph structure check; original module should have len(inputs) matmuls + # and optimized module should have 1. + self.assertEqual(_count_matmuls(module), len(inputs)) + self.assertEqual(_count_matmuls(opt_module), 1) + + # Simple test case in which two matmuls cannot be merged due to a data dependency between + # the LHS operands. + class UnmergeableMatmulModule(torch.nn.Module): + def __init__(self, rhs): + super().__init__() + self.rhs = rhs + + def forward(self, x): + a = torch.matmul(x, self.rhs) + a_abs = torch.abs(a) + b = torch.matmul(a_abs.transpose(1, 0), self.rhs) + return b + + # Initialize inputs. + a = torch.randn(3, 3) + + # Initialize RHS for matmuls. + rhs = torch.randn(3, 4) + + # Construct UnmergeableMatmulModule and call merge_matmul on it. + module = UnmergeableMatmulModule(rhs) + opt_module = merge_matmul.merge_matmul(module) + + # Numerical correctness check. + before = module(a) + after = opt_module(a) + before.allclose(after) + + # Basic graph structure check; the number of matrix multiplcations should not have changed. + self.assertEqual(_count_matmuls(module), 2) + self.assertEqual(_count_matmuls(opt_module), 2) + +if __name__ == "__main__": + run_tests() diff --git a/test/test_indexing.py b/test/test_indexing.py index fb9e472d7a795..b92fd94e8cbde 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -1,11 +1,649 @@ -from torch.testing._internal.common_utils import TestCase, run_tests -from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU, dtypesIfCUDA import torch from torch import tensor + import unittest import warnings +import random +from functools import reduce + +from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU, dtypesIfCUDA, + onlyOnCPUAndCUDA) + class TestIndexing(TestCase): + def test_index(self, device): + + def consec(size, start=1): + sequence = torch.ones(int(torch.Tensor(size).prod(0))).cumsum(0) + sequence.add_(start - 1) + return sequence.view(*size) + + reference = consec((3, 3, 3)).to(device) + + # empty tensor indexing + self.assertEqual(reference[torch.LongTensor().to(device)], reference.new(0, 3, 3)) + + self.assertEqual(reference[0], consec((3, 3)), atol=0, rtol=0) + self.assertEqual(reference[1], consec((3, 3), 10), atol=0, rtol=0) + self.assertEqual(reference[2], consec((3, 3), 19), atol=0, rtol=0) + self.assertEqual(reference[0, 1], consec((3,), 4), atol=0, rtol=0) + self.assertEqual(reference[0:2], consec((2, 3, 3)), atol=0, rtol=0) + self.assertEqual(reference[2, 2, 2], 27, atol=0, rtol=0) + self.assertEqual(reference[:], consec((3, 3, 3)), atol=0, rtol=0) + + # indexing with Ellipsis + self.assertEqual(reference[..., 2], torch.Tensor([[3, 6, 9], + [12, 15, 18], + [21, 24, 27]]), atol=0, rtol=0) + self.assertEqual(reference[0, ..., 2], torch.Tensor([3, 6, 9]), atol=0, rtol=0) + self.assertEqual(reference[..., 2], reference[:, :, 2], atol=0, rtol=0) + self.assertEqual(reference[0, ..., 2], reference[0, :, 2], atol=0, rtol=0) + self.assertEqual(reference[0, 2, ...], reference[0, 2], atol=0, rtol=0) + self.assertEqual(reference[..., 2, 2, 2], 27, atol=0, rtol=0) + self.assertEqual(reference[2, ..., 2, 2], 27, atol=0, rtol=0) + self.assertEqual(reference[2, 2, ..., 2], 27, atol=0, rtol=0) + self.assertEqual(reference[2, 2, 2, ...], 27, atol=0, rtol=0) + self.assertEqual(reference[...], reference, atol=0, rtol=0) + + reference_5d = consec((3, 3, 3, 3, 3)).to(device) + self.assertEqual(reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], atol=0, rtol=0) + self.assertEqual(reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0], atol=0, rtol=0) + self.assertEqual(reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1], atol=0, rtol=0) + self.assertEqual(reference_5d[...], reference_5d, atol=0, rtol=0) + + # LongTensor indexing + reference = consec((5, 5, 5)).to(device) + idx = torch.LongTensor([2, 4]).to(device) + self.assertEqual(reference[idx], torch.stack([reference[2], reference[4]])) + # TODO: enable one indexing is implemented like in numpy + # self.assertEqual(reference[2, idx], torch.stack([reference[2, 2], reference[2, 4]])) + # self.assertEqual(reference[3, idx, 1], torch.stack([reference[3, 2], reference[3, 4]])[:, 1]) + + # None indexing + self.assertEqual(reference[2, None], reference[2].unsqueeze(0)) + self.assertEqual(reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0)) + self.assertEqual(reference[2:4, None], reference[2:4].unsqueeze(1)) + self.assertEqual(reference[None, 2, None, None], reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0)) + self.assertEqual(reference[None, 2:5, None, None], reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2)) + + # indexing 0-length slice + self.assertEqual(torch.empty(0, 5, 5), reference[slice(0)]) + self.assertEqual(torch.empty(0, 5), reference[slice(0), 2]) + self.assertEqual(torch.empty(0, 5), reference[2, slice(0)]) + self.assertEqual(torch.tensor([]), reference[2, 1:1, 2]) + + # indexing with step + reference = consec((10, 10, 10)).to(device) + self.assertEqual(reference[1:5:2], torch.stack([reference[1], reference[3]], 0)) + self.assertEqual(reference[1:6:2], torch.stack([reference[1], reference[3], reference[5]], 0)) + self.assertEqual(reference[1:9:4], torch.stack([reference[1], reference[5]], 0)) + self.assertEqual(reference[2:4, 1:5:2], torch.stack([reference[2:4, 1], reference[2:4, 3]], 1)) + self.assertEqual(reference[3, 1:6:2], torch.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0)) + self.assertEqual(reference[None, 2, 1:9:4], torch.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0)) + self.assertEqual(reference[:, 2, 1:6:2], + torch.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1)) + + lst = [list(range(i, i + 10)) for i in range(0, 100, 10)] + tensor = torch.DoubleTensor(lst).to(device) + for _i in range(100): + idx1_start = random.randrange(10) + idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1) + idx1_step = random.randrange(1, 8) + idx1 = slice(idx1_start, idx1_end, idx1_step) + if random.randrange(2) == 0: + idx2_start = random.randrange(10) + idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1) + idx2_step = random.randrange(1, 8) + idx2 = slice(idx2_start, idx2_end, idx2_step) + lst_indexed = [l[idx2] for l in lst[idx1]] + tensor_indexed = tensor[idx1, idx2] + else: + lst_indexed = lst[idx1] + tensor_indexed = tensor[idx1] + self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed) + + self.assertRaises(ValueError, lambda: reference[1:9:0]) + self.assertRaises(ValueError, lambda: reference[1:9:-1]) + + self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1]) + self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1]) + self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3]) + + self.assertRaises(IndexError, lambda: reference[0.0]) + self.assertRaises(TypeError, lambda: reference[0.0:2.0]) + self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0]) + self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0]) + self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0]) + self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0]) + + def delitem(): + del reference[0] + + self.assertRaises(TypeError, delitem) + + @onlyOnCPUAndCUDA + @dtypes(torch.half, torch.double) + def test_advancedindex(self, device, dtype): + # Tests for Integer Array Indexing, Part I - Purely integer array + # indexing + + def consec(size, start=1): + # Creates the sequence in float since CPU half doesn't support the + # needed operations. Converts to dtype before returning. + numel = reduce(lambda x, y: x * y, size, 1) + sequence = torch.ones(numel, dtype=torch.float, device=device).cumsum(0) + sequence.add_(start - 1) + return sequence.view(*size).to(dtype=dtype) + + # pick a random valid indexer type + def ri(indices): + choice = random.randint(0, 2) + if choice == 0: + return torch.LongTensor(indices).to(device) + elif choice == 1: + return list(indices) + else: + return tuple(indices) + + def validate_indexing(x): + self.assertEqual(x[[0]], consec((1,))) + self.assertEqual(x[ri([0]), ], consec((1,))) + self.assertEqual(x[ri([3]), ], consec((1,), 4)) + self.assertEqual(x[[2, 3, 4]], consec((3,), 3)) + self.assertEqual(x[ri([2, 3, 4]), ], consec((3,), 3)) + self.assertEqual(x[ri([0, 2, 4]), ], torch.tensor([1, 3, 5], dtype=dtype, device=device)) + + def validate_setting(x): + x[[0]] = -2 + self.assertEqual(x[[0]], torch.tensor([-2], dtype=dtype, device=device)) + x[[0]] = -1 + self.assertEqual(x[ri([0]), ], torch.tensor([-1], dtype=dtype, device=device)) + x[[2, 3, 4]] = 4 + self.assertEqual(x[[2, 3, 4]], torch.tensor([4, 4, 4], dtype=dtype, device=device)) + x[ri([2, 3, 4]), ] = 3 + self.assertEqual(x[ri([2, 3, 4]), ], torch.tensor([3, 3, 3], dtype=dtype, device=device)) + x[ri([0, 2, 4]), ] = torch.tensor([5, 4, 3], dtype=dtype, device=device) + self.assertEqual(x[ri([0, 2, 4]), ], torch.tensor([5, 4, 3], dtype=dtype, device=device)) + + # Only validates indexing and setting for halfs + if dtype == torch.half: + reference = consec((10,)) + validate_indexing(reference) + validate_setting(reference) + return + + # Case 1: Purely Integer Array Indexing + reference = consec((10,)) + validate_indexing(reference) + + # setting values + validate_setting(reference) + + # Tensor with stride != 1 + # strided is [1, 3, 5, 7] + reference = consec((10,)) + strided = torch.tensor((), dtype=dtype, device=device) + strided.set_(reference.storage(), storage_offset=0, + size=torch.Size([4]), stride=[2]) + + self.assertEqual(strided[[0]], torch.tensor([1], dtype=dtype, device=device)) + self.assertEqual(strided[ri([0]), ], torch.tensor([1], dtype=dtype, device=device)) + self.assertEqual(strided[ri([3]), ], torch.tensor([7], dtype=dtype, device=device)) + self.assertEqual(strided[[1, 2]], torch.tensor([3, 5], dtype=dtype, device=device)) + self.assertEqual(strided[ri([1, 2]), ], torch.tensor([3, 5], dtype=dtype, device=device)) + self.assertEqual(strided[ri([[2, 1], [0, 3]]), ], + torch.tensor([[5, 3], [1, 7]], dtype=dtype, device=device)) + + # stride is [4, 8] + strided = torch.tensor((), dtype=dtype, device=device) + strided.set_(reference.storage(), storage_offset=4, + size=torch.Size([2]), stride=[4]) + self.assertEqual(strided[[0]], torch.tensor([5], dtype=dtype, device=device)) + self.assertEqual(strided[ri([0]), ], torch.tensor([5], dtype=dtype, device=device)) + self.assertEqual(strided[ri([1]), ], torch.tensor([9], dtype=dtype, device=device)) + self.assertEqual(strided[[0, 1]], torch.tensor([5, 9], dtype=dtype, device=device)) + self.assertEqual(strided[ri([0, 1]), ], torch.tensor([5, 9], dtype=dtype, device=device)) + self.assertEqual(strided[ri([[0, 1], [1, 0]]), ], + torch.tensor([[5, 9], [9, 5]], dtype=dtype, device=device)) + + # reference is 1 2 + # 3 4 + # 5 6 + reference = consec((3, 2)) + self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.tensor([1, 3, 5], dtype=dtype, device=device)) + self.assertEqual(reference[ri([0, 1, 2]), ri([1])], torch.tensor([2, 4, 6], dtype=dtype, device=device)) + self.assertEqual(reference[ri([0]), ri([0])], consec((1,))) + self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6)) + self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], torch.tensor([1, 2], dtype=dtype, device=device)) + self.assertEqual(reference[[ri([0, 1, 1, 0, 2]), ri([1])]], + torch.tensor([2, 4, 4, 2, 6], dtype=dtype, device=device)) + self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], + torch.tensor([1, 2, 3, 3], dtype=dtype, device=device)) + + rows = ri([[0, 0], + [1, 2]]) + columns = [0], + self.assertEqual(reference[rows, columns], torch.tensor([[1, 1], + [3, 5]], dtype=dtype, device=device)) + + rows = ri([[0, 0], + [1, 2]]) + columns = ri([1, 0]) + self.assertEqual(reference[rows, columns], torch.tensor([[2, 1], + [4, 5]], dtype=dtype, device=device)) + rows = ri([[0, 0], + [1, 2]]) + columns = ri([[0, 1], + [1, 0]]) + self.assertEqual(reference[rows, columns], torch.tensor([[1, 2], + [4, 5]], dtype=dtype, device=device)) + + # setting values + reference[ri([0]), ri([1])] = -1 + self.assertEqual(reference[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device)) + reference[ri([0, 1, 2]), ri([0])] = torch.tensor([-1, 2, -4], dtype=dtype, device=device) + self.assertEqual(reference[ri([0, 1, 2]), ri([0])], + torch.tensor([-1, 2, -4], dtype=dtype, device=device)) + reference[rows, columns] = torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device) + self.assertEqual(reference[rows, columns], + torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device)) + + # Verify still works with Transposed (i.e. non-contiguous) Tensors + + reference = torch.tensor([[0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11]], dtype=dtype, device=device).t_() + + # Transposed: [[0, 4, 8], + # [1, 5, 9], + # [2, 6, 10], + # [3, 7, 11]] + + self.assertEqual(reference[ri([0, 1, 2]), ri([0])], + torch.tensor([0, 1, 2], dtype=dtype, device=device)) + self.assertEqual(reference[ri([0, 1, 2]), ri([1])], + torch.tensor([4, 5, 6], dtype=dtype, device=device)) + self.assertEqual(reference[ri([0]), ri([0])], + torch.tensor([0], dtype=dtype, device=device)) + self.assertEqual(reference[ri([2]), ri([1])], + torch.tensor([6], dtype=dtype, device=device)) + self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], + torch.tensor([0, 4], dtype=dtype, device=device)) + self.assertEqual(reference[[ri([0, 1, 1, 0, 3]), ri([1])]], + torch.tensor([4, 5, 5, 4, 7], dtype=dtype, device=device)) + self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], + torch.tensor([0, 4, 1, 1], dtype=dtype, device=device)) + + rows = ri([[0, 0], + [1, 2]]) + columns = [0], + self.assertEqual(reference[rows, columns], + torch.tensor([[0, 0], [1, 2]], dtype=dtype, device=device)) + + rows = ri([[0, 0], + [1, 2]]) + columns = ri([1, 0]) + self.assertEqual(reference[rows, columns], + torch.tensor([[4, 0], [5, 2]], dtype=dtype, device=device)) + rows = ri([[0, 0], + [1, 3]]) + columns = ri([[0, 1], + [1, 2]]) + self.assertEqual(reference[rows, columns], + torch.tensor([[0, 4], [5, 11]], dtype=dtype, device=device)) + + # setting values + reference[ri([0]), ri([1])] = -1 + self.assertEqual(reference[ri([0]), ri([1])], + torch.tensor([-1], dtype=dtype, device=device)) + reference[ri([0, 1, 2]), ri([0])] = torch.tensor([-1, 2, -4], dtype=dtype, device=device) + self.assertEqual(reference[ri([0, 1, 2]), ri([0])], + torch.tensor([-1, 2, -4], dtype=dtype, device=device)) + reference[rows, columns] = torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device) + self.assertEqual(reference[rows, columns], + torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device)) + + # stride != 1 + + # strided is [[1 3 5 7], + # [9 11 13 15]] + + reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8) + strided = torch.tensor((), dtype=dtype, device=device) + strided.set_(reference.storage(), 1, size=torch.Size([2, 4]), + stride=[8, 2]) + + self.assertEqual(strided[ri([0, 1]), ri([0])], + torch.tensor([1, 9], dtype=dtype, device=device)) + self.assertEqual(strided[ri([0, 1]), ri([1])], + torch.tensor([3, 11], dtype=dtype, device=device)) + self.assertEqual(strided[ri([0]), ri([0])], + torch.tensor([1], dtype=dtype, device=device)) + self.assertEqual(strided[ri([1]), ri([3])], + torch.tensor([15], dtype=dtype, device=device)) + self.assertEqual(strided[[ri([0, 0]), ri([0, 3])]], + torch.tensor([1, 7], dtype=dtype, device=device)) + self.assertEqual(strided[[ri([1]), ri([0, 1, 1, 0, 3])]], + torch.tensor([9, 11, 11, 9, 15], dtype=dtype, device=device)) + self.assertEqual(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], + torch.tensor([1, 3, 9, 9], dtype=dtype, device=device)) + + rows = ri([[0, 0], + [1, 1]]) + columns = [0], + self.assertEqual(strided[rows, columns], + torch.tensor([[1, 1], [9, 9]], dtype=dtype, device=device)) + + rows = ri([[0, 1], + [1, 0]]) + columns = ri([1, 2]) + self.assertEqual(strided[rows, columns], + torch.tensor([[3, 13], [11, 5]], dtype=dtype, device=device)) + rows = ri([[0, 0], + [1, 1]]) + columns = ri([[0, 1], + [1, 2]]) + self.assertEqual(strided[rows, columns], + torch.tensor([[1, 3], [11, 13]], dtype=dtype, device=device)) + + # setting values + + # strided is [[10, 11], + # [17, 18]] + + reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8) + strided = torch.tensor((), dtype=dtype, device=device) + strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), + stride=[7, 1]) + self.assertEqual(strided[ri([0]), ri([1])], + torch.tensor([11], dtype=dtype, device=device)) + strided[ri([0]), ri([1])] = -1 + self.assertEqual(strided[ri([0]), ri([1])], + torch.tensor([-1], dtype=dtype, device=device)) + + reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8) + strided = torch.tensor((), dtype=dtype, device=device) + strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), + stride=[7, 1]) + self.assertEqual(strided[ri([0, 1]), ri([1, 0])], + torch.tensor([11, 17], dtype=dtype, device=device)) + strided[ri([0, 1]), ri([1, 0])] = torch.tensor([-1, 2], dtype=dtype, device=device) + self.assertEqual(strided[ri([0, 1]), ri([1, 0])], + torch.tensor([-1, 2], dtype=dtype, device=device)) + + reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8) + strided = torch.tensor((), dtype=dtype, device=device) + strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), + stride=[7, 1]) + + rows = ri([[0], + [1]]) + columns = ri([[0, 1], + [0, 1]]) + self.assertEqual(strided[rows, columns], + torch.tensor([[10, 11], [17, 18]], dtype=dtype, device=device)) + strided[rows, columns] = torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device) + self.assertEqual(strided[rows, columns], + torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device)) + + # Tests using less than the number of dims, and ellipsis + + # reference is 1 2 + # 3 4 + # 5 6 + reference = consec((3, 2)) + self.assertEqual(reference[ri([0, 2]), ], + torch.tensor([[1, 2], [5, 6]], dtype=dtype, device=device)) + self.assertEqual(reference[ri([1]), ...], + torch.tensor([[3, 4]], dtype=dtype, device=device)) + self.assertEqual(reference[..., ri([1])], + torch.tensor([[2], [4], [6]], dtype=dtype, device=device)) + + # verify too many indices fails + with self.assertRaises(IndexError): + reference[ri([1]), ri([0, 2]), ri([3])] + + # test invalid index fails + reference = torch.empty(10, dtype=dtype, device=device) + # can't test cuda because it is a device assert + if not reference.is_cuda: + for err_idx in (10, -11): + with self.assertRaisesRegex(IndexError, r'out of'): + reference[err_idx] + with self.assertRaisesRegex(IndexError, r'out of'): + reference[torch.LongTensor([err_idx]).to(device)] + with self.assertRaisesRegex(IndexError, r'out of'): + reference[[err_idx]] + + def tensor_indices_to_np(tensor, indices): + # convert the Torch Tensor to a numpy array + tensor = tensor.to(device='cpu') + npt = tensor.numpy() + + # convert indices + idxs = tuple(i.tolist() if isinstance(i, torch.LongTensor) else + i for i in indices) + + return npt, idxs + + def get_numpy(tensor, indices): + npt, idxs = tensor_indices_to_np(tensor, indices) + + # index and return as a Torch Tensor + return torch.tensor(npt[idxs], dtype=dtype, device=device) + + def set_numpy(tensor, indices, value): + if not isinstance(value, int): + if self.device_type != 'cpu': + value = value.cpu() + value = value.numpy() + + npt, idxs = tensor_indices_to_np(tensor, indices) + npt[idxs] = value + return npt + + def assert_get_eq(tensor, indexer): + self.assertEqual(tensor[indexer], get_numpy(tensor, indexer)) + + def assert_set_eq(tensor, indexer, val): + pyt = tensor.clone() + numt = tensor.clone() + pyt[indexer] = val + numt = torch.tensor(set_numpy(numt, indexer, val), dtype=dtype, device=device) + self.assertEqual(pyt, numt) + + def assert_backward_eq(tensor, indexer): + cpu = tensor.float().clone().detach().requires_grad_(True) + outcpu = cpu[indexer] + gOcpu = torch.rand_like(outcpu) + outcpu.backward(gOcpu) + dev = cpu.to(device).detach().requires_grad_(True) + outdev = dev[indexer] + outdev.backward(gOcpu.to(device)) + self.assertEqual(cpu.grad, dev.grad) + + def get_set_tensor(indexed, indexer): + set_size = indexed[indexer].size() + set_count = indexed[indexer].numel() + set_tensor = torch.randperm(set_count).view(set_size).double().to(device) + return set_tensor + + # Tensor is 0 1 2 3 4 + # 5 6 7 8 9 + # 10 11 12 13 14 + # 15 16 17 18 19 + reference = torch.arange(0., 20, dtype=dtype, device=device).view(4, 5) + + indices_to_test = [ + # grab the second, fourth columns + [slice(None), [1, 3]], + + # first, third rows, + [[0, 2], slice(None)], + + # weird shape + [slice(None), [[0, 1], + [2, 3]]], + # negatives + [[-1], [0]], + [[0, 2], [-1]], + [slice(None), [-1]], + ] + + # only test dupes on gets + get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]] + + for indexer in get_indices_to_test: + assert_get_eq(reference, indexer) + if self.device_type != 'cpu': + assert_backward_eq(reference, indexer) + + for indexer in indices_to_test: + assert_set_eq(reference, indexer, 44) + assert_set_eq(reference, + indexer, + get_set_tensor(reference, indexer)) + + reference = torch.arange(0., 160, dtype=dtype, device=device).view(4, 8, 5) + + indices_to_test = [ + [slice(None), slice(None), [0, 3, 4]], + [slice(None), [2, 4, 5, 7], slice(None)], + [[2, 3], slice(None), slice(None)], + [slice(None), [0, 2, 3], [1, 3, 4]], + [slice(None), [0], [1, 2, 4]], + [slice(None), [0, 1, 3], [4]], + [slice(None), [[0, 1], [1, 0]], [[2, 3]]], + [slice(None), [[0, 1], [2, 3]], [[0]]], + [slice(None), [[5, 6]], [[0, 3], [4, 4]]], + [[0, 2, 3], [1, 3, 4], slice(None)], + [[0], [1, 2, 4], slice(None)], + [[0, 1, 3], [4], slice(None)], + [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], + [[[0, 1], [1, 0]], [[2, 3]], slice(None)], + [[[0, 1], [2, 3]], [[0]], slice(None)], + [[[2, 1]], [[0, 3], [4, 4]], slice(None)], + [[[2]], [[0, 3], [4, 1]], slice(None)], + # non-contiguous indexing subspace + [[0, 2, 3], slice(None), [1, 3, 4]], + + # less dim, ellipsis + [[0, 2], ], + [[0, 2], slice(None)], + [[0, 2], Ellipsis], + [[0, 2], slice(None), Ellipsis], + [[0, 2], Ellipsis, slice(None)], + [[0, 2], [1, 3]], + [[0, 2], [1, 3], Ellipsis], + [Ellipsis, [1, 3], [2, 3]], + [Ellipsis, [2, 3, 4]], + [Ellipsis, slice(None), [2, 3, 4]], + [slice(None), Ellipsis, [2, 3, 4]], + + # ellipsis counts for nothing + [Ellipsis, slice(None), slice(None), [0, 3, 4]], + [slice(None), Ellipsis, slice(None), [0, 3, 4]], + [slice(None), slice(None), Ellipsis, [0, 3, 4]], + [slice(None), slice(None), [0, 3, 4], Ellipsis], + [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], + [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)], + [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis], + ] + + for indexer in indices_to_test: + assert_get_eq(reference, indexer) + assert_set_eq(reference, indexer, 212) + assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) + if torch.cuda.is_available(): + assert_backward_eq(reference, indexer) + + reference = torch.arange(0., 1296, dtype=dtype, device=device).view(3, 9, 8, 6) + + indices_to_test = [ + [slice(None), slice(None), slice(None), [0, 3, 4]], + [slice(None), slice(None), [2, 4, 5, 7], slice(None)], + [slice(None), [2, 3], slice(None), slice(None)], + [[1, 2], slice(None), slice(None), slice(None)], + [slice(None), slice(None), [0, 2, 3], [1, 3, 4]], + [slice(None), slice(None), [0], [1, 2, 4]], + [slice(None), slice(None), [0, 1, 3], [4]], + [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]], + [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]], + [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]], + [slice(None), [0, 2, 3], [1, 3, 4], slice(None)], + [slice(None), [0], [1, 2, 4], slice(None)], + [slice(None), [0, 1, 3], [4], slice(None)], + [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)], + [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)], + [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)], + [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)], + [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)], + [[0, 1, 2], [1, 3, 4], slice(None), slice(None)], + [[0], [1, 2, 4], slice(None), slice(None)], + [[0, 1, 2], [4], slice(None), slice(None)], + [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)], + [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)], + [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)], + [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)], + [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]], + [slice(None), [2, 3, 4], [1, 3, 4], [4]], + [slice(None), [0, 1, 3], [4], [1, 3, 4]], + [slice(None), [6], [0, 2, 3], [1, 3, 4]], + [slice(None), [2, 3, 5], [3], [4]], + [slice(None), [0], [4], [1, 3, 4]], + [slice(None), [6], [0, 2, 3], [1]], + [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]], + [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)], + [[2, 0, 1], [1, 2, 3], [4], slice(None)], + [[0, 1, 2], [4], [1, 3, 4], slice(None)], + [[0], [0, 2, 3], [1, 3, 4], slice(None)], + [[0, 2, 1], [3], [4], slice(None)], + [[0], [4], [1, 3, 4], slice(None)], + [[1], [0, 2, 3], [1], slice(None)], + [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)], + + # less dim, ellipsis + [Ellipsis, [0, 3, 4]], + [Ellipsis, slice(None), [0, 3, 4]], + [Ellipsis, slice(None), slice(None), [0, 3, 4]], + [slice(None), Ellipsis, [0, 3, 4]], + [slice(None), slice(None), Ellipsis, [0, 3, 4]], + [slice(None), [0, 2, 3], [1, 3, 4]], + [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis], + [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)], + [[0], [1, 2, 4]], + [[0], [1, 2, 4], slice(None)], + [[0], [1, 2, 4], Ellipsis], + [[0], [1, 2, 4], Ellipsis, slice(None)], + [[1], ], + [[0, 2, 1], [3], [4]], + [[0, 2, 1], [3], [4], slice(None)], + [[0, 2, 1], [3], [4], Ellipsis], + [Ellipsis, [0, 2, 1], [3], [4]], + ] + + for indexer in indices_to_test: + assert_get_eq(reference, indexer) + assert_set_eq(reference, indexer, 1333) + assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) + indices_to_test += [ + [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]], + [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]], + ] + for indexer in indices_to_test: + assert_get_eq(reference, indexer) + assert_set_eq(reference, indexer, 1333) + if self.device_type != 'cpu': + assert_backward_eq(reference, indexer) + + def test_advancedindex_big(self, device): + reference = torch.arange(0, 123344, dtype=torch.int, device=device) + + self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ], + torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int)) + def test_single_int(self, device): v = torch.randn(5, 7, 3, device=device) self.assertEqual(v[4].shape, (7, 3)) @@ -126,7 +764,7 @@ def test_int_indices(self, device): @dtypes(torch.float, torch.bfloat16, torch.long, torch.bool) @dtypesIfCPU(torch.float, torch.long, torch.bool, torch.bfloat16) - @dtypesIfCUDA(torch.half, torch.long, torch.bool) + @dtypesIfCUDA(torch.half, torch.long, torch.bool, torch.bfloat16) def test_index_put_src_datatype(self, device, dtype): src = torch.ones(3, 2, 4, device=device, dtype=dtype) vals = torch.ones(3, 2, 4, device=device, dtype=dtype) diff --git a/test/test_jit.py b/test/test_jit.py index b689f76681f76..10ef7067fb615 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -19,7 +19,8 @@ from jit.test_class_type import TestClassType # noqa: F401 from jit.test_builtins import TestBuiltins, TestTensorBuiltins # noqa: F401 from jit.test_unsupported_ops import TestUnsupportedOps # noqa: F401 -from jit.test_freezing import TestFreezing # noqa: F401 +from jit.test_freezing import TestFreezing, TestFrozenOptimizations # noqa: F401 +from jit.test_peephole import TestPeephole # noqa: F401 from jit.test_save_load import TestSaveLoad # noqa: F401 from jit.test_module_containers import TestModuleContainers # noqa: F401 from jit.test_python_ir import TestPythonIr # noqa: F401 @@ -30,8 +31,13 @@ from jit.test_onnx_export import TestONNXExport # noqa: F401 from jit.test_with import TestWith # noqa: F401 from jit.test_enum import TestEnum # noqa: F401 +from jit.test_string_formatting import TestStringFormatting # noqa: F401 from jit.test_profiler import TestProfiler # noqa: F401 from jit.test_slice import TestSlice # noqa: F401 +from jit.test_warn import TestWarn # noqa: F401 +from jit.test_isinstance import TestIsinstance # noqa: F401 +from jit.test_cuda import TestCUDA # noqa: F401 +from jit.test_hash import TestHash # noqa: F401 # Torch from torch import Tensor @@ -39,7 +45,9 @@ from torch._six import PY37, StringIO from torch.autograd import Variable from torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any # noqa: F401 +from torch.nn.utils.rnn import PackedSequence from torch.testing import FileCheck +import torch.autograd.profiler import torch.cuda import torch.jit import torch.jit._logging @@ -49,6 +57,7 @@ # Testing utils from torch.testing._internal import jit_utils +from torch.testing._internal.common_jit import check_against_reference from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \ suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, \ freeze_rng_state, set_rng_seed, slowTest, TemporaryFileName, skipIfCompiledWithoutNumpy, \ @@ -58,8 +67,8 @@ execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \ RUN_CUDA from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, nn_functional_tests, get_script_args, \ - get_call, script_template, EXCLUDE_SCRIPT, additional_module_tests, EXCLUDE_SCRIPT_MODULES, \ - get_nn_module_name_from_kwargs, script_method_template, create_traced_fn + EXCLUDE_SCRIPT, additional_module_tests, EXCLUDE_SCRIPT_MODULES, \ + get_nn_module_name_from_kwargs, script_method_template, create_traced_fn, check_alias_annotation from torch.testing._internal.common_nn import module_tests, new_module_tests, criterion_tests from torch.testing._internal.common_methods_invocations import method_tests as autograd_method_tests @@ -73,30 +82,30 @@ # Standard library from collections import defaultdict, namedtuple, OrderedDict -import copy from copy import deepcopy -from itertools import product, chain -import itertools +from itertools import product from textwrap import dedent -from typing import List, Dict, Optional, Tuple, Union +from typing import List, Dict, NamedTuple, Optional, Tuple, Union +import copy +import functools import inspect +import io +import itertools import math -import functools import numpy as np -import io import os import pickle import pickletools import random +import re import shutil +import string import sys import tempfile import types import unittest import warnings import zipfile -import re -import string def canonical(graph): @@ -261,6 +270,8 @@ def get_grad_executor(plan_state, diff_graph_idx=None, skip_check=False): nodes = list(filter(lambda n : n.kind() != "prim::BailOut" and n.kind() != "prim::BailoutTemplate", nodes)) if len(nodes) == 1 or (len(nodes) == 2 and nodes[1].kind() == "prim::TupleConstruct"): pass + elif len(nodes) == 2 and nodes[0].kind() == "prim::RequiresGradCheck" and nodes[1].kind() == "prim::If": + pass else: raise RuntimeError("Can't get a grad_executor for a non-differentiable graph") grad_executors = list(plan_state.code.grad_executor_states()) @@ -325,6 +336,40 @@ def test_inferred_as_tensor(self): def dot(points, query, dim): return (points * query).sum(dim) + def test_dict_comprehension(self): + def fn(): + return {i : chr(i + 65) for i in range(4)} + self.checkScript(fn, ()) + + def test_dict_comprehension_with_type_annotation(self): + def fn(): + d: Dict[int, str] = {i : chr(i + 65) for i in range(4)} + return d + self.checkScript(fn, ()) + + with self.assertRaisesRegex(RuntimeError, ""): + with self.assertRaisesRegex(AssertionError, "Expected Dict " + "type annotation for dict " + "comprehension, found " + "Tuple[int, str]"): + @torch.jit.script + def fn(): + d: Tuple[int, str] = {i : chr(i + 65) for i in range(4)} + return d + + def test_dict_comprehension_scope(self): + def comprehension_can_access_outer_scope_variables(): + lst = ["foo", "bar", "baz"] + return {l : len(l) for l in lst} + + self.checkScript(comprehension_can_access_outer_scope_variables, ()) + + with self.assertRaisesRegex(RuntimeError, "undefined value i"): + @torch.jit.script + def outer_scope_cannot_access_comprehension_variables(): + d = {i : chr(i + 65) for i in range(4)} + i = i + 1 + def test_constants_pkl(self): # This test asserts that the serialization archive includes a `constants.pkl` # file. This file is used by `torch.load` to determine whether a zip file @@ -409,6 +454,15 @@ def forward(self, x): self.assertEqual(origin_result, m3(input.cpu())) self.assertEqual(origin_result, m4(input.cuda(0))) + def test_trace_retains_train(self): + class M(torch.nn.Module): + def forward(self, x): + return x + m = M() + m.eval() + tm = torch.jit.trace(m, (torch.rand(3))) + self.assertEqual(tm.training, m.training) + @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA") def test_restore_shared_storage_on_cuda(self): class Foo(torch.jit.ScriptModule): @@ -428,134 +482,6 @@ def __init__(self): self.assertTrue(m2.b0.is_shared()) self.assertEqual(m2.b0.storage().data_ptr(), m2.p0.storage().data_ptr()) - def test_peephole_with_writes(self): - def test_write(x): - s = 0 - s += x - s += x - return s - - self.checkScript(test_write, (torch.ones(4, 4),)) - - def test_peephole_with_non_output_writes(self): - - @torch.jit.ignore - def nomnom(x): - pass - - def test_write(x): - t = torch.ones_like(x) - z = x.clone() - y = z + 0 - z.add_(t) - # this makes sure z isn't blasted out of existence - # because it isn't returned or used in a side-effectful - # way - nomnom(z) - return y + y - - a = torch.ones(4, 4) - j = self.checkScript(test_write, (a,)) - - def test_peephole_no_output_aliasing(self): - def test_peephole(x): - y = x + 0 - return x, y - - a = torch.ones(4, 4) - j = self.checkScript(test_peephole, (a,)) - r1, r2 = j(a) - self.assertNotEqual(r1.data_ptr(), r2.data_ptr()) - - def test_peephole(self): - a = torch.tensor([0.4]) - b = torch.tensor([0.7]) - c = torch.tensor([0], dtype=torch.int32) - - def f(x, y): - return x.type_as(y) - - tf = torch.jit.trace(f, (a, b)) - FileCheck().check("type_as").run(str(tf.graph)) - self.run_pass('peephole', tf.graph) - FileCheck().check_not("type_as").run(str(tf.graph)) - tf2 = torch.jit.trace(f, (a, c)) - s = str(tf2.graph) - self.run_pass('peephole', tf2.graph) - self.assertEqual(s, str(s)) - - def test_peephole_dynamic(self): - def f(x, y): - return x.type_as(y) - - fn = torch.jit.script(f) - s = str(fn.graph) - torch._C._jit_pass_peephole(fn.graph) - self.assertEqual(s, str(fn.graph)) - - def test_peephole_list_ops(self): - @torch.jit.script - def foo(x, y, z): - return len([x, y, z]) - - self.run_pass('peephole', foo.graph) - FileCheck().check("value=3").check_next("return").run(foo.graph) - - @torch.jit.script - def foo(x, y, z): - li = [x, y, z] - for i in range(len(x)): - li.append(x) - return len([x, y, z]) - - self.run_pass('peephole', foo.graph) - FileCheck().check_not("aten::len").run(foo.graph) - - @torch.jit.script - def foo(x, y, z): - li = [x, y, z] - return li[1], li[-2] - - FileCheck().check("aten::__getitem__").run(foo.graph) - self.run_pass('peephole', foo.graph) - FileCheck().check_not("aten::__getitem__").run(foo.graph) - - @torch.jit.script - def foo(x, y, z): - li = [x, y, z] - return li[-7] - - self.run_pass('peephole', foo.graph) - FileCheck().check("aten::__getitem__").run(foo.graph) - - @torch.jit.script - def foo(x, y, z): - li = [x, y, z] - for i in range(len(x)): - li.append(x) - return li[-2] - - self.run_pass('peephole', foo.graph) - FileCheck().check("aten::__getitem__").run(foo.graph) - - @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA") - def test_peephole_cuda(self): - a = torch.tensor([0.4], device='cpu') - b = torch.tensor([0.7], device='cuda') - c = torch.tensor([0.7], device='cuda') - - def f(x, y): - return x.type_as(y) - - trace = torch.jit.trace(f, (a, c)) - s = str(trace.graph) - self.run_pass('peephole', trace.graph) - self.assertEqual(s, str(trace.graph)) - trace = torch.jit.trace(f, (b, c)) - self.run_pass('peephole', trace.graph) - self.run_pass('dce', trace.graph) - FileCheck().check_not("type_as").run(str(trace.graph)) - def test_add_relu_fusion(self): class M(torch.nn.Module): def __init__(self, relu_op): @@ -581,7 +507,7 @@ def forward(self, a, b, c): m = torch.jit.load(buffer) new_res = m(a, b, c) FileCheck().check_not("aten::relu(") \ - .check("aten::add_relu(") \ + .check("aten::_add_relu(") \ .run(m.graph) torch.testing.assert_allclose(orig_res, new_res) @@ -600,7 +526,7 @@ def forward(self, a, b, c): m = torch.jit.load(buffer) new_res = m(a, b, c) FileCheck().check_not("aten::relu_(") \ - .check("aten::add_relu(") \ + .check("aten::_add_relu(") \ .run(m.graph) torch.testing.assert_allclose(orig_res, new_res) @@ -631,10 +557,10 @@ def forward(self, a, b): new_res = m(a_copy, b) FileCheck().check_not("aten::add_(") \ .check_not("aten::relu_(") \ - .check("aten::add_relu_(") \ + .check("aten::_add_relu_(") \ .run(m.graph) torch.testing.assert_allclose(orig_res, new_res) - # Since add_relu_ does inplace mutation ensure + # Since _add_relu_ does inplace mutation ensure # a_copy is modified torch.testing.assert_allclose(orig_res, a_copy) @@ -669,10 +595,10 @@ def forward(self, a, b): new_res = m(a_copy, b) FileCheck().check_not("aten::add(") \ .check_not("aten::relu_(") \ - .check("aten::add_relu(") \ + .check("aten::_add_relu(") \ .run(m.graph) torch.testing.assert_allclose(orig_res, new_res) - # Since add_relu_ with out=a does inplace mutation ensure + # Since _add_relu_ with out=a does inplace mutation ensure # a_copy is modified torch.testing.assert_allclose(orig_res, a_copy) @@ -798,6 +724,29 @@ def check(x, y): self.assertTrue(check(x, y)) + def test_nn_conv(self): + class Mod(nn.Module): + def __init__(self, conv): + super().__init__() + self.conv = conv + + def forward(self, input): + return self.conv(input) + + inputs = [ + # Conv + (Mod(nn.Conv1d(16, 33, 3, stride=2)), torch.randn(20, 16, 5)), + (Mod(nn.Conv2d(16, 33, 3, stride=2)), torch.randn(20, 16, 5, 10)), + (Mod(nn.Conv3d(16, 33, 3, stride=2)), torch.randn(20, 16, 3, 5, 4)), + # ConvTransposed + (Mod(nn.ConvTranspose1d(16, 33, 3, stride=2)), torch.randn(20, 16, 5)), + (Mod(nn.ConvTranspose2d(16, 33, 3, stride=2)), torch.randn(20, 16, 5, 10)), + (Mod(nn.ConvTranspose3d(16, 33, 3, stride=2)), torch.randn(20, 16, 3, 5, 4)), + ] + + for m, inp in inputs: + self.checkModule(m, (inp,)) + def test_numel(self): @torch.jit.script def get_numel_script(x): @@ -877,6 +826,30 @@ def forward(self, input): m_dropout.eval() self.assertEqual(dropout(input) + 1, m_dropout(input)) + def test_nn_padding(self): + class Mod(nn.Module): + def __init__(self, padding): + super().__init__() + self.padding = padding + + def forward(self, input): + return self.padding(input) + + inputs = [ + (Mod(nn.ConstantPad1d(2, 3.5)), torch.randn(1, 2, 4)), + (Mod(nn.ConstantPad2d(2, 3.5)), torch.randn(1, 2, 2)), + (Mod(nn.ConstantPad3d(3, 3.5)), torch.randn(16, 3, 10, 20, 30)), + (Mod(nn.ReflectionPad1d(2)), torch.arange(8, dtype=torch.float).reshape(1, 2, 4)), + (Mod(nn.ReflectionPad2d(2)), torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)), + (Mod(nn.ReplicationPad1d(2)), torch.arange(8, dtype=torch.float).reshape(1, 2, 4)), + (Mod(nn.ReplicationPad2d(2)), torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)), + (Mod(nn.ReplicationPad3d(3)), torch.randn(16, 3, 8, 32, 48)), + (Mod(nn.ZeroPad2d(2)), torch.randn(1, 1, 3, 3)) + ] + + for m, inp in inputs: + self.checkModule(m, (inp,)) + def test_script_autograd_grad(self): def test_simple_grad(x, y): # type: (Tensor, Tensor) -> List[Optional[Tensor]] @@ -1233,56 +1206,6 @@ def forward(self, x): FileCheck().check("my::matched_conv_bn").run(m._c._get_method("forward").graph) - def test_reconstruct_scopes(self): - class SubModule(torch.nn.Module): - def __init__(self): - super(SubModule, self).__init__() - - def bar(self, x): - return x + x - - def forward(self, x): - return x * self.bar(x) - - class MyModule(torch.nn.Module): - def __init__(self): - super(MyModule, self).__init__() - self.sub = SubModule() - - def forward(self, x): - return self.sub(x) + x - - traced = torch.jit.trace(MyModule(), torch.zeros(1)) - g = traced.graph - torch._C._jit_pass_inline(g) - torch._C._jit_pass_reconstruct_scopes(traced._c, g) - FileCheck().check("scope: top(MyModule).sub(SubModule).forward").run(g) - - def test_reconstruct_scopes_duplicated_class_types(self): - class SubModule(torch.nn.Module): - def __init__(self): - super(SubModule, self).__init__() - - def forward(self, x): - return x + 2 - - class MyModule(torch.nn.Module): - def __init__(self): - super(MyModule, self).__init__() - self.sub1 = SubModule() - self.sub2 = SubModule() - - def forward(self, x): - return self.sub1(x) + self.sub2(x) - - traced = torch.jit.trace(MyModule(), torch.zeros(1)) - g = traced.graph - torch._C._jit_pass_inline(g) - torch._C._jit_pass_reconstruct_scopes(traced._c, g) - FileCheck().check_dag("scope: top(MyModule).sub1(SubModule).forward") \ - .check_dag("scope: top(MyModule).sub2(SubModule).forward") \ - .run(g) - def test_expand_quantlint(self): pass @@ -1298,18 +1221,18 @@ def broadcast(a, b): graph = torch.jit.script(broadcast).graph torch._C._jit_pass_complete_shape_analysis(graph, (x, y), False) - FileCheck().check("Double(4:120, 3:40, 8:5, 5:1, device=cpu)").run(str(graph)) + FileCheck().check("Double(4, 3, 8, 5, strides=[120, 40, 5, 1], device=cpu)").run(str(graph)) def test_shape_analysis_unsqueeze_in_loop(self): input_str = """graph(%x.1 : Tensor): %4 : bool = prim::Constant[value=1]() %1 : int = prim::Constant[value=2]() %7 : int = prim::Constant[value=0]() - # CHECK: FloatTensor = prim::Loop + # CHECK: FloatTensor(requires_grad=0, device=cpu) = prim::Loop %x : Tensor = prim::Loop(%1, %4, %x.1) - # CHECK: : FloatTensor): + # CHECK: : FloatTensor(requires_grad=0, device=cpu)): block0(%i : int, %x.6 : Tensor): - # CHECK: FloatTensor = aten::unsqueeze + # CHECK: FloatTensor(requires_grad=0, device=cpu) = aten::unsqueeze %x.3 : Tensor = aten::unsqueeze(%x.6, %7) -> (%4, %x.3) return (%x)""" @@ -1425,7 +1348,7 @@ def test_dropout(self): self.assertEqual(outputs, m(*inputs)) @slowTest - @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, 'Testing differentiable graph') + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, 'Testing differentiable graph') def test_dropout_module_requires_grad(self): with enable_profiling_mode_for_profiling_tests(): class MyModule(torch.nn.Module): @@ -1779,61 +1702,13 @@ def foo(): @torch.jit.script def fn(): - if True: + if 1 == 1: return 1 else: return 2 FileCheck().check_not("prim::If").run(fn.graph) - def test_short_circuit_optimization(self): - @torch.jit.script - def const_expressions(x): - # type: (int) -> Tuple[bool, bool] - return x == 1 and False, x == 1 or True - self.run_pass('constant_propagation', const_expressions.graph) - FileCheck().check_not("prim::If").check_not("aten::eq").run(const_expressions.graph) - self.assertEqual(const_expressions(1), (False, True)) - - @torch.jit.script - def redundant_expressions(x): - # type: (int) -> Tuple[bool, bool] - return x == 1 and True, x == 1 or False - - self.run_pass('peephole', redundant_expressions.graph) - self.assertEqual(redundant_expressions(1), (True, True)) - self.assertEqual(redundant_expressions(0), (False, False)) - # and True / or False are removed from graph - FileCheck().check("aten::eq").check_not("prim::If").run(redundant_expressions.graph) - - @_inline_everything - def test_peephole_type_refinements(self): - def refine(x): - # type: (Optional[Tensor]) -> Tensor - return x if x is not None else torch.tensor(3) - - @torch.jit.script - def test(): - return refine(torch.tensor(4)) - - FileCheck().check("prim::unchecked_cast").run(test.graph) - self.run_pass('peephole', test.graph) - FileCheck().check_not("prim::unchecked_cast").run(test.graph) - - # refinement not optimzied out - def is_int_tensor(x): - scalar = x.item() - if isinstance(scalar, int): - return scalar + 3 - else: - return 8 - - self.checkScript(is_int_tensor, (torch.tensor(2),)) - self.checkScript(is_int_tensor, (torch.tensor(2.5),)) - graph = torch.jit.script(is_int_tensor).graph - self.run_pass('peephole', graph) - FileCheck().check("prim::unchecked_cast").run(graph) - def test_unchecked_cast(self): def test(cond): # type: (bool) @@ -1857,15 +1732,15 @@ def constant_prop(a, b): c2 = 1 if bool(a): # -> c0, c1 if bool(b): # -> c0 - if True: # -> c0 + if 1 == 1: # -> c0 c0 = c0 + 1 - if False: + if 1 == 2: c1 = c1 + 1 c2 = c2 + 1 else: # -> c0, c1 c1 = c1 + 1 - if True: # inlined + if 1 == 1: # inlined c0 = c0 + 1 # dynamic c2 = c2 + 4 # set to 5 return a + c0 + c1 + c2 @@ -1916,7 +1791,7 @@ def constant_prop(iter): b = 1 c = 1 for i in range(iter): - if False: + if 1 == 2: a = 10 if i == 5: b = 2 @@ -1968,22 +1843,24 @@ def check_constant(constant_constructor): tup_constant = constants[i] + ", " + constants[j] check_constant(tup_constant) + dict_constants = [] for i in range(len(constants)): # check_constant constructs the second dict with another Tensor # which fails the comparison - if isinstance(eval(constants[i]), (list, bool, Tensor)) or eval(constants[i]) is None: + if not isinstance(eval(constants[i]), (str, int, float)): continue for j in range(len(constants)): dict_constant = "{ " + constants[i] + ": " + constants[j] + "}" check_constant(dict_constant) - + dict_constants.append(dict_constant) + constants = constants + dict_constants # testing node hashing funcs_template = dedent(''' def func(): print({constant_constructor}) ''') - single_elem_tuples = map(lambda x: "(" + x + ",)", constants) + single_elem_tuples = ("(" + x + ",)" for x in constants) input_arg = ", ".join(single_elem_tuples) scope = {} funcs_str = funcs_template.format(constant_constructor=input_arg) @@ -2008,14 +1885,8 @@ def func(): # generate dicts with built-in types (excluding torch.Tensor) xprod = itertools.product(constants, constants) - def keys_pred(t): - return isinstance(eval(t[0]), (list, bool)) or eval(t[0]) is None - - filt = [x for x in xprod if not keys_pred(x)] - dict_strs = map(lambda t: '{' + t[0] + ':' + t[1] + '}', filt) - # test that equal tuples and dicts correctly work with node hashing - for tup in chain(map(lambda x: "(" + x + ",)", constants), dict_strs): + for tup in ("(" + x + ",)" for x in constants): funcs_str = funcs_template.format(constant_constructor=tup) scope = {} execWrapper(funcs_str, globals(), scope) @@ -2398,8 +2269,7 @@ def fn(x): warns = [str(w.message) for w in warns] self.assertEqual(len(warns), 0) - @unittest.skipIf(IS_WINDOWS or True, "TODO: need to fix this test case for " - "Windows, re-enable with https://github.com/pytorch/pytorch/pull/29339") + @unittest.skipIf(True, "TODO: re-enable with https://github.com/pytorch/pytorch/pull/29339") def test_torch_load_error(self): class J(torch.jit.ScriptModule): def __init__(self): @@ -2410,20 +2280,20 @@ def forward(self, input): return input + 100 j = J() - with tempfile.NamedTemporaryFile() as f: - j.save(f.name) + with TemporaryFileName() as fname: + j.save(fname) with self.assertRaisesRegex(RuntimeError, "is a zip"): - torch.load(f.name) + torch.load(fname) - @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows") def test_torch_load_zipfile_check(self): @torch.jit.script def fn(x): return x + 10 - with tempfile.NamedTemporaryFile() as f: - fn.save(f.name) - self.assertTrue(torch.serialization._is_zipfile(f)) + with TemporaryFileName() as fname: + fn.save(fname) + with io.open(fname, 'rb') as f: + self.assertTrue(torch.serialization._is_zipfile(f)) def test_python_bindings(self): lstm_cell = torch.jit.script(LSTMCellS) @@ -2543,10 +2413,10 @@ def fn(x): for e in prof.function_events: if e.name == "aten::mul": self.assertTrue(e.thread not in mul_events) - mul_events[e.thread] = e.cpu_interval.elapsed_us() + mul_events[e.thread] = e.time_range.elapsed_us() elif e.name == "other_fn": self.assertTrue(e.thread not in other_fn_events) - other_fn_events[e.thread] = e.cpu_interval.elapsed_us() + other_fn_events[e.thread] = e.time_range.elapsed_us() self.assertTrue(len(mul_events) == 2) self.assertTrue(len(other_fn_events) == 2) @@ -2803,8 +2673,8 @@ def test_not_const(x): test_not_const(torch.rand([2, 2])) graph_str = torch.jit.last_executed_optimized_graph() - FileCheck().check("profiled_type=Double(*:2, 2:1, requires_grad=0, device=cpu").run(graph_str) - FileCheck().check_not("profiled_type=Double(1:2, 2:1, requires_grad=0, device=cpu").run(graph_str) + FileCheck().check("profiled_type=Double(*, 2, strides=[2, 1], requires_grad=0, device=cpu").run(graph_str) + FileCheck().check_not("profiled_type=Double(1, 2, strides=[2, 1], requires_grad=0, device=cpu").run(graph_str) def test_nested_bailouts(self): @@ -3056,7 +2926,7 @@ def fn(self): return MyTuple(1) def forward(self, x): - if True: + if 1 == 1: return MyTuple(torch.rand(2, 3)) else: return self.fn() @@ -3099,7 +2969,7 @@ def fn(x): FileCheck().check_not("goodbye").check_not("hello").run(traced_bad.graph) # Working example - untraceable = torch.jit._script_if_tracing(untraceable) + untraceable = torch.jit.script_if_tracing(untraceable) def fn2(x): return untraceable(x) @@ -3112,7 +2982,7 @@ def fn2(x): def foo(x: int): return x + 1 - @torch.jit._script_if_tracing + @torch.jit.script_if_tracing def fee(x: int = 2): return foo(1) + x @@ -3527,17 +3397,22 @@ def replace(e): 'buffers_r': ['B'], 'children': ['another', 'foo'], 'modules': ['a', 'another', 'bar', 'foo'], - 'named_attributes': [('another', 'another'), + 'named_attributes': [('_is_full_backward_hook', None), + ('another', 'another'), ('foo', 'foo'), ('name', 'a'), ('p', 'P'), ('training', True)], - 'named_attributes_r': [('another', 'another'), + 'named_attributes_r': [('_is_full_backward_hook', None), + ('another', 'another'), + ('another._is_full_backward_hook', None), ('another.name', 'another'), ('another.training', True), ('foo', 'foo'), + ('foo._is_full_backward_hook', None), ('foo.b', 'B'), ('foo.bar', 'bar'), + ('foo.bar._is_full_backward_hook', None), ('foo.bar.an_int', 4), ('foo.bar.name', 'bar'), ('foo.bar.training', True), @@ -4164,8 +4039,8 @@ def debug_records_from_mod(mod): archive = zipfile.ZipFile(buffer) files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist())) debug_files = filter(lambda f: f.endswith('.debug_pkl'), files) - debug_files = map(lambda f: archive.open(f), debug_files) - debug_files = map(lambda f: pickle.load(f), debug_files) + debug_files = (archive.open(f) for f in debug_files) + debug_files = (pickle.load(f) for f in debug_files) return list(debug_files) debug_files = debug_records_from_mod(ft3) @@ -6207,7 +6082,7 @@ def fn(x, y, b): res = fn(None, t2, False) res = fn(None, t2, True) g = torch.jit.last_executed_optimized_graph() - self.assertEqual(next(g.outputs()).type().str(), "Tensor") + self.assertIn(next(g.outputs()).type().str(), ("Tensor", "Tensor(requires_grad=1)")) @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the current version of Profiler doesn't profile/specialize Optionals") def test_optional_list(self): @@ -6529,6 +6404,24 @@ def func(a, b): self.checkScript(func, inputs_true, optimize=True) self.checkScript(func, inputs_false, optimize=True) + def test_ternary_module_type_hint(self): + class M1(torch.nn.Module): + def forward(self) -> Any: + return 'out' if self.training else {} + + class M2(torch.nn.Module): + def forward(self) -> Any: + out: Any = 'out' if self.training else {} + return out + + class M3(torch.nn.Module): + def forward(self) -> Optional[int]: + return None if self.training else 1 + + for module in [M1, M2, M3]: + self.checkModule(module().train(), ()) + self.checkModule(module().eval(), ()) + def test_print(self): def func(x, y): q = (x + y).sigmoid() @@ -6735,6 +6628,11 @@ def complicated_arithmetic_operation(): self.checkScript(complicated_arithmetic_operation, ()) + def test_in_operator_with_two_strings(self): + def fn() -> bool: + return "a" in "abcd" + self.checkScript(fn, ()) + def test_bitwise_ops(self): def int_test(): @@ -6890,6 +6788,61 @@ def foo(a, b): with self.assertRaisesRegex(RuntimeError, 'division by 0'): foo(i, j) + # Testing bitwise shorthand aug assignment + def test_bool_augassign_bitwise_or(self): + def func(a: bool, b: bool) -> bool: + a |= b + return a + + self.checkScript(func, (True, False), optimize=True) + self.checkScript(func, (True, True), optimize=True) + self.checkScript(func, (False, False), optimize=True) + self.checkScript(func, (False, True), optimize=True) + + def test_bool_augassign_bitwise_and(self): + def func(a: bool, b: bool) -> bool: + a &= b + return a + + self.checkScript(func, (True, False), optimize=True) + self.checkScript(func, (True, True), optimize=True) + self.checkScript(func, (False, False), optimize=True) + self.checkScript(func, (False, True), optimize=True) + + def test_bool_augassign_bitwise_xor(self): + def func(a: bool, b: bool) -> bool: + a ^= b + return a + + self.checkScript(func, (True, False), optimize=True) + self.checkScript(func, (True, True), optimize=True) + self.checkScript(func, (False, False), optimize=True) + self.checkScript(func, (False, True), optimize=True) + + def test_number_augassign_bitwise_lshift(self): + def func() -> int: + z = 8 + z <<= 2 + return z + + self.checkScript(func, (), optimize=True) + + def test_number_augassign_bitwise_rshift(self): + def func() -> int: + z = 8 + z >>= 2 + return z + + self.checkScript(func, (), optimize=True) + + def test_number_augassign_bitwise_pow(self): + def func() -> float: + z = 8 + z **= 2 + return z + + self.checkScript(func, (), optimize=True) + def test_number_augassign(self): def func(): z = 1 @@ -7072,7 +7025,8 @@ def test_dtype(inp_dtype: torch.dtype): else: g = test_dtype.graph_for(5) # first should have type set second should not - FileCheck().check("Float(requires_grad=1, device=cpu) = aten::tensor").check("Tensor = aten::tensor").run(g) + FileCheck().check("Float(requires_grad=1, device=cpu) = aten::tensor") \ + .check("Tensor(requires_grad=0) = aten::tensor").run(g) @torch.jit.script def test_as_tensor_tensor_input(input): @@ -7205,6 +7159,29 @@ def f(x): x = torch.rand(3, 4) self.assertEqual(scripted_f(x), f(x)) + def test_multiline_string_dedents(self): + def foo() -> None: + multiline_string_dedent_1 = """ +This is a string dedent """ + multiline_string_dedent_2 = """ This is a + string dedent """ + multiline_string_dedent_3 = """ + This is a string +dedent """ + multiline_string_dedent_4 = """ This is a string dedent """ + + scripted_foo = torch.jit.script(foo) + self.assertEqual(scripted_foo(), foo()) + + def test_class_with_comment_at_lower_indentation(self): + class Foo(torch.nn.Module): + def forward(self, x): + x = torch.neg(x) + # This comment is at the wrong indent + return x + + torch.jit.script(Foo()) + # adapted from test in test_torch def test_tensor_to(self): template = dedent(''' @@ -7385,6 +7362,19 @@ def any_refinement2(a): self.assertEqual(any_refinement2(3), torch.tensor(3)) self.assertEqual(any_refinement2(torch.tensor(5)), torch.tensor(5)) + @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "bug persists in deprecated executor") + def test_unspecialized_any_binding(self): + # any binding will infer the type, if it infers + # a specialized tensor type `x` Dict type will fail isinstance check + + @torch.jit.script + def foo(x: Any): + assert isinstance(x, Dict[str, torch.Tensor]) + + foo({"1": torch.tensor(3)}) + with self.assertRaises(Exception): + foo(2) + def test_isinstance(self): # test isinstance operator for static type checking template = dedent(''' @@ -7582,7 +7572,7 @@ def foo_break(cond): self.checkScript(foo_break, (i,)) def test_refine_outside_loop(): - if True: + if 1 == 1: x = None else: x = 1 @@ -8019,7 +8009,7 @@ def contained_blocks(node): return len(node.findAllNodes("prim::If")) * 2 + len(node.findAllNodes("prim::Loop")) for node in ifs + loops: outs = list(node.outputs()) - out_name = list(map(lambda x: x.debugName(), outs)) + out_name = [x.debugName() for x in outs] if len(out_name) == 0: continue fc = FileCheck() @@ -8223,7 +8213,7 @@ def _dtype_to_expect(self, dtype, dim=0): def _test_dtype_op_shape(self, ops, args, input_dims=1): if input_dims < 1: - raise 'input dims must be at least 1' + raise RuntimeError("input dims must be at least 1") dtypes = [torch.float32, torch.float64, torch.int64, torch.int32] str_args = ', '.join([str(arg) for arg in args]) + (', ' if len(args) else '') tensor_data = ('[' * input_dims) + '1, 2, 3' + (input_dims * ']') @@ -8783,7 +8773,7 @@ def test_pack_unpack_state(self): def test_torch_functional(self): def stft(input, n_fft): # type: (Tensor, int) -> Tensor - return torch.stft(input, n_fft) + return torch.stft(input, n_fft, return_complex=True) inps = (torch.randn(10), 7) self.assertEqual(stft(*inps), torch.jit.script(stft)(*inps)) @@ -8792,8 +8782,8 @@ def istft(input, n_fft): # type: (Tensor, int) -> Tensor return torch.istft(input, n_fft) - inps2 = (torch.stft(*inps), inps[1]) - self.assertEqual(torch.istft(*inps2), torch.jit.script(torch.istft)(*inps2)) + inps2 = (stft(*inps), inps[1]) + self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2)) def lu(x): # type: (Tensor) -> Tuple[Tensor, Tensor] @@ -9601,7 +9591,7 @@ def foo(i): c = a # some nonsense with if-statements and loops to check # that tuple lowering doesn't fail - if True: + if 1 == 1: c = (i * 9, i + 1) t0, t1 = c while False: @@ -9618,14 +9608,14 @@ def foo(i): @torch.jit.script def mixtypes(x): a = (x, x) - if True: + if 1 == 1: a = 4 def test_if_tuple_sizes(self): with self.assertRaisesRegex(RuntimeError, "Type mismatch"): @torch.jit.script def diff_tuple_sizes(x): - if False: + if 1 == 2: c0 = ((x, x), (x, x, x)) else: c0 = ((x, x, x), (x, x)) @@ -9636,7 +9626,7 @@ def test_if_different_type(self): "in the true branch and type float in the false branch:"): @torch.jit.script def diff_type_used(): - if False: + if 1 == 2: c0 = 1 else: c0 = 1.0 @@ -9646,14 +9636,14 @@ def diff_type_used(): @torch.jit.script def diff_existing_type(x): c0 = 1.0 - if False: + if 1 == 2: c0 = 1 print(x) return x @torch.jit.script def diff_type_unused(): - if True: + if 1 == 1: c0 = 1 print(c0) else: @@ -9665,13 +9655,13 @@ def test_if_not_defined_error(self): with self.assertRaisesRegex(RuntimeError, "c0 is not defined in the false branch"): @torch.jit.script def test(): - if True: + if 1 == 1: c0 = 1 return c0 with self.assertRaisesRegex(RuntimeError, "c0 is not defined in the true branch"): @torch.jit.script def test2(): - if True: + if 1 == 1: pass else: c0 = 1 @@ -9760,7 +9750,7 @@ def forward(self): cm = ScriptMod(Mod()) # specialized tensor in graph - FileCheck().check("Double(1:3, 3:1, requires_grad=0, device=cpu)").run(cm.forward.graph) + FileCheck().check("Double(1, 3, strides=[3, 1], requires_grad=0, device=cpu)").run(cm.forward.graph) buffer = io.BytesIO() torch.jit.save(cm, buffer) buffer.seek(0) @@ -9995,6 +9985,21 @@ def method(self, x): with self.assertRaisesRegex(RuntimeError, "Argument y not provided."): ModuleDefault() + def test_type_inferred_from_empty_annotation(self): + """ + Test that the type inferred from an empty or missing annotation is Torch.Tensor wtih `inferred=true` + """ + @torch.jit.script + def fn(x): + return x + + graph = fn.graph + n = next(graph.inputs()) + self.assertTrue(n.type() == torch._C.TensorType.getInferred()) + + with self.assertRaisesRegex(RuntimeError, "Inferred \'x\' to be of type \'Tensor"): + fn(1) + def test_script_define_order(self): class M(torch.jit.ScriptModule): @@ -10304,7 +10309,7 @@ def foo(x, y): a = torch.zeros(2, 2) b = torch.zeros(4, dtype=torch.long) torch._C._jit_pass_complete_shape_analysis(foo.graph, (a, b), False) - FileCheck().check("Double(2:4, 4:1, requires_grad=0, device=cpu)").run(str(foo.graph)) + FileCheck().check("Double(2, 4, strides=[4, 1], requires_grad=0, device=cpu)").run(str(foo.graph)) def test_shape_analysis_loop(self): def foo(a, b, x): @@ -10408,7 +10413,7 @@ def t1(a): def t2(a): # mix const/non-const attributes - if True: + if 1 == 1: b = 1 else: b = 0 @@ -10592,8 +10597,8 @@ def test_rand(): out = fn() graph_str = torch.jit.last_executed_optimized_graph() self.assertEqual(out.dtype, torch.double) - FileCheck().check("Double(3:4, 4:1, requires_grad=0, device=cpu)") \ - .check_not("Float(3:4, 4:1, requires_grad=0, device=cpu)").run(graph_str) + FileCheck().check("Double(3, 4, strides=[4, 1], requires_grad=0, device=cpu)") \ + .check_not("Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)").run(graph_str) # fn = self.checkScript(test_rand, ()) # out = fn() @@ -10610,7 +10615,7 @@ def randint(): out = randint() graph_str = torch.jit.last_executed_optimized_graph() self.assertEqual(out.dtype, torch.double) - FileCheck().check("profiled_type=Double(1:2, 2:1, requires_grad=0, device=cpu)").run(graph_str) + FileCheck().check("profiled_type=Double(1, 2, strides=[2, 1], requires_grad=0, device=cpu)").run(graph_str) def test_erase_number_types(self): @@ -12066,10 +12071,10 @@ def tuple_slice(a): scripted_fn = torch.jit.script(tuple_slice) self.assertEqual(scripted_fn(torch.tensor(1)), (2, 3)) tuple_graph = scripted_fn.graph - slices = tuple_graph.findAllNodes("prim::TupleSlice") - num_outputs = set(map(lambda x: len(x.output().type().elements()), slices)) - # one tuple slice should have an output with 2 elements, other 4 - self.assertTrue(num_outputs == {2, 4}) + slices = tuple_graph.findAllNodes("prim::TupleConstruct") + num_outputs = set(len(x.output().type().elements()) for x in slices) + # there should be only one tupleSlice with length of 2 + self.assertTrue(num_outputs == {2}) self.run_pass('lower_all_tuples', tuple_graph) self.assertTrue('Tuple' not in str(tuple_graph)) @@ -12080,6 +12085,26 @@ def test_indexing_end_out_of_bounds(): self.assertEqual(test_indexing_end_out_of_bounds(), ()) + def test_stepped_tuple_slicing(self): + + def check_slicing_tuple(slicing, tuple_type, tuple): + template = dedent(""" + def func(x): + # type: ({}) -> Any + return x{} + """) + self._check_code(template.format(tuple_type, slicing), "func", [tuple]) + + check_slicing_tuple("[-3:3:2]", "Tuple[int, int, int]", (0, 1, 2)) + check_slicing_tuple("[::55]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) + check_slicing_tuple("[:4:4]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) + check_slicing_tuple("[::-1]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)) + check_slicing_tuple("[7:5:2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)) + check_slicing_tuple("[5:7:-2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)) + check_slicing_tuple("[::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) + check_slicing_tuple("[:4:-3]", "Tuple[int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5)) + check_slicing_tuple("[3::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) + def test_lower_nested_tuples(self): @torch.jit.script def test(): @@ -12515,7 +12540,7 @@ def foo(cond): a = 3 if bool(cond): raise ArbitraryError(a, "hi") - if False: + if 1 == 2: raise ArbitraryError return a @@ -12554,10 +12579,10 @@ def foo(): # a escapes scope @torch.jit.script def foo(): - if True: + if 1 == 1: a = 1 else: - if True: + if 1 == 1: raise Exception("Hi") else: raise Exception("Hi") @@ -13259,7 +13284,7 @@ def backward(grad_output): ''') cu = torch.jit.CompilationUnit(code) g = cu.tanh.graph - FileCheck().check_count("prim::Function_0", 2).check("None = prim::Constant") \ + FileCheck().check_count("prim::Closure_0", 2).check("None = prim::Constant") \ .check_next("return").run(g) code = dedent(''' @@ -13276,7 +13301,7 @@ def backward(grad_output): ''') cu = torch.jit.CompilationUnit(code) g = cu.tanh.graph - FileCheck().check_count("prim::Function_0", 2).check("int = prim::If") \ + FileCheck().check_count("prim::Closure_0", 2).check("int = prim::If") \ .run(g) code = dedent(''' @@ -13290,16 +13315,16 @@ def backward(grad_output): ''') cu = torch.jit.CompilationUnit(code) fc = FileCheck() - fc.check("prim::Function").check("(Tensor, None) = prim::TupleConstruct") + fc.check("prim::Closure").check("(Tensor, None) = prim::TupleConstruct") # Loop then two if's added in exit transform - fc.check("prim::Function").check("prim::Loop").check_count("prim::If", 2) + fc.check("prim::Closure").check("prim::Loop").check_count("prim::If", 2) fc.run(cu.loop_in_closure.graph) code = dedent(''' def tanh(self): output = torch.tanh(self) def backward(grad_output): - if True: + if 1 == 1: return 1 else: return 1. @@ -13720,6 +13745,32 @@ def foo(): self.assertEqual(foo(), 1) + def test_boolean_literal_constant_metacompile(self): + class Mod(torch.nn.Module): + __constants__ = ['val'] + + def __init__(self, val): + super(Mod, self).__init__() + self.val = val + + def forward(self): + if self.val: + return 1 + else: + return "2" + + self.checkModule(Mod(True), ()) + self.checkModule(Mod(False), ()) + + @torch.jit.script + def foo(): + if True: + return 1 + else: + return "2" + + self.assertEqual(foo(), 1) + def test_assert_is_scripting_metacompile(self): def foo(): assert not torch.jit.is_scripting(), "TestErrorMsg" @@ -13758,6 +13809,23 @@ def test_non_primitive_types(x): out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0))) self.assertEqual(out, torch.tensor(6.0)) + def test_namedtuple_type_inference(self): + _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)]) + _UnannotatedNamedTuple = namedtuple('_NamedTupleUnAnnotated', ['value']) + + def test_check_named_tuple_value(): + named_tuple = _AnnotatedNamedTuple(1) + return named_tuple.value + + self.checkScript(test_check_named_tuple_value, ()) + + def test_error(): + return _UnannotatedNamedTuple(1) + + with self.assertRaisesRegex(RuntimeError, r"Expected a value of type \'Tensor \(inferred\)\' " + r"for argument \'value\' but instead found type \'int\'."): + torch.jit.script(test_error) + def test_isinstance_dynamic(self): @torch.jit.script def foo(a): @@ -14177,7 +14245,6 @@ def forward(self, x, lengths, h0, c0): self.assertEqual(eager_out, script_out) def test_nn_LSTM(self): - from torch.nn.utils.rnn import PackedSequence input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)]) class S(torch.jit.ScriptModule): @@ -14186,8 +14253,7 @@ def __init__(self): self.x = torch.nn.LSTM(5, 5) @torch.jit.script_method - def forward(self, input): - # type: (PackedSequence) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]] # noqa + def forward(self, input: PackedSequence) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]: return self.x(input) eager_out = self.runAndSaveRNG(lambda x: torch.nn.LSTM(5, 5)(x), (input,))[0] @@ -14196,7 +14262,6 @@ def forward(self, input): self.assertEqual(eager_out, script_out) def test_nn_GRU(self): - from torch.nn.utils.rnn import PackedSequence seq_input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)]) tensor_input = torch.randn(5, 5, 5) @@ -14206,8 +14271,7 @@ def __init__(self): self.x = torch.nn.GRU(5, 5) @torch.jit.script_method - def forward(self, input): - # type: (PackedSequence) -> Tuple[PackedSequence, Tensor] + def forward(self, input: PackedSequence) -> Tuple[PackedSequence, torch.Tensor]: return self.x(input) class TensorGRU(torch.jit.ScriptModule): @@ -14216,8 +14280,7 @@ def __init__(self): self.x = torch.nn.GRU(5, 5) @torch.jit.script_method - def forward(self, input): - # type: (Tensor) -> Tuple[Tensor, Tensor] + def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return self.x(input) seq_eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (seq_input,))[0] @@ -14451,7 +14514,7 @@ def single_val_ignored(self, x, y): def forward(self, x, use_ignore_path): # type: (Tensor, bool) -> Tuple[Tensor, Tensor] - if False: + if 1 == 2: return self.tuple_ignored(x) if use_ignore_path: return self.single_val_ignored(x, x), self.single_val_ignored(x, x) @@ -15112,6 +15175,25 @@ def test_dict_types(self): def foo(): new_item = {'score': [1.0], 'ys': [1, 2, 3]} + def test_dict_invalid_annotations(self): + # Check for invalid value type annotation + def wrong_value_type(dictionary: Dict[str, torch.jit.ScriptModule]): + return + with self.assertRaisesRegex(ValueError, "Unknown type annotation"): + torch.jit.script(wrong_value_type) + + # Check for invalid key type annotation + def wrong_key_type(dictionary: Dict[torch.jit.ScriptModule, str]): + return + with self.assertRaisesRegex(ValueError, "Unknown type annotation"): + torch.jit.script(wrong_key_type) + + # Check for invalid key and value type annotation + def wrong_key_value_type(dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule]): + return + with self.assertRaisesRegex(ValueError, "Unknown type annotation"): + torch.jit.script(wrong_key_value_type) + def test_get_set_state_with_tensors(self): class M(torch.nn.Module): def __init__(self): @@ -15349,6 +15431,18 @@ def fn(x): ref = a * b self.assertEqual(test, ref) + def test_signed_float_zero(self): + + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + + def forward(self, x): + return torch.div(x, -0.) + + inp = torch.ones(1) + self.checkModule(MyModule(), inp) + # known to be failing in tracer EXCLUDE_TRACED = { # The following fail due to #12024. @@ -15398,6 +15492,12 @@ def fn(x): 'test_split_size_list', 'test_split_size_list_dim', 'test_split_size_list_dim_neg0', + 'test_tensor_indices_sections', + 'test_tensor_indices_sections_dim', + 'test_tensor_indices_sections_dim_neg0', + 'test_tensor_split_sections', + 'test_tensor_split_sections_dim', + 'test_tensor_split_sections_dim_neg0' } EXCLUDE_PYTHON_PRINT = { @@ -15414,94 +15514,12 @@ def fn(x): EXCLUDE_ALIAS = { # aliases, which may appear in method_tests but are tested elsewhere 'true_divide', -} - -def check_alias_annotation(method_name, args, kwargs): - formals, tensors, actuals = get_script_args(args) - call = get_call(method_name, 'method', actuals, kwargs) - script = script_template.format(', '.join(formals), call) - CU = torch.jit.CompilationUnit(script) - torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), method_name) + # Disable tests for lu from common_methods_invocations.py + # TODO(@nikitaved) Enable jit tests once autograd.Function does support scripting + 'lu' +} -def check_output_types(self, func, ref_outputs, args, kwargs): - graph = getattr(func, 'last_graph', None) - types = [o.type() for o in graph.outputs()] - self.assertTrue(len(types) == 1) - t = types[0] - torch._C._jit_assert_is_instance(ref_outputs, t) - - -def check_against_reference(self, func, reference_func, args, kwargs=None, - allow_unused=True, check_types=True, no_grad=False): - kwargs = kwargs if kwargs else {} - - def allSum(vs): - if isinstance(vs, torch.Tensor): - vs = (vs,) - return sum((i + 1) * v.sum() - for i, v in enumerate(vs) - if v is not None and v.dtype.is_floating_point) - - def clone_inputs(requires_grad): - inputs = [ - arg.detach().clone().requires_grad_(requires_grad and arg.requires_grad) - if isinstance(arg, torch.Tensor) else arg for arg in args - ] - return inputs, [input for input in inputs if isinstance(input, torch.Tensor) and input.requires_grad] - - nograd_inputs, nograd_tensors = clone_inputs(False) - recording_inputs, recording_tensors = clone_inputs(True) - - # test no gradients case - outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs) - with enable_profiling_mode_for_profiling_tests(): - outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs) - self.assertEqual(outputs, outputs_test) - - if check_types: - check_output_types(self, func, outputs_test, nograd_inputs, kwargs) - - if no_grad: - # skip grad tests - return - - with enable_profiling_mode_for_profiling_tests(): - # test single grad case - outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs) - grads = torch.autograd.grad(allSum(outputs), recording_tensors, - allow_unused=allow_unused) - outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs) - grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors, - allow_unused=allow_unused) - self.assertEqual(outputs, outputs_test) - self.assertEqual(grads, grads_test) - # test the grad grad case - if self._testMethodName in nn_functional_single_grad: - return - - outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs) - l1 = allSum(outputs) - grads = torch.autograd.grad(l1, recording_tensors, create_graph=True, - allow_unused=allow_unused) - - l2 = (allSum(grads) * l1) - grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused) - recording_inputs, recording_tensors = clone_inputs(True) - outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs) - l1_test = allSum(outputs_test) - grads_test = torch.autograd.grad( - l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused) - - l2_test = (allSum(grads_test) * l1_test) - grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused) - - self.assertEqual(outputs, outputs_test) - self.assertEqual(grads, grads_test) - for g2, g2_test in zip(grads2, grads2_test): - if g2 is None and g2_test is None: - continue - self.assertTrue(torch.allclose(g2, g2_test, atol=5e-4, rtol=1e-4)) class TestJitGeneratedAutograd(JitTestCase): pass @@ -15583,7 +15601,7 @@ def add_autograd_test( # Disable complex tests # TODO: Add complex support for jit - if 'complex' in variant_name or name in ['view_as_complex', 'complex']: + if 'complex' in variant_name or name in ['view_as_complex', 'complex', 'angle']: return # Skips aliases, which are tested in test_op_aliases.py @@ -15691,8 +15709,8 @@ def fn(*inputs, **kwargs): check_types=check_types) # alias annotation testing - if is_inplace and test_name not in EXCLUDE_SCRIPT: - check_alias_annotation(name, (self_variable,) + args_variable, kwargs_variable) + if not is_magic_method and test_name not in EXCLUDE_SCRIPT and not exclude_tensor_method(name, test_name): + check_alias_annotation(name, (self_variable,) + args_variable, kwargs_variable, aten_name=name) check(name) inplace_name = name + '_' diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py index 0c8a1f9a967d6..c09dbd5c9c3d6 100644 --- a/test/test_jit_cuda_fuser.py +++ b/test/test_jit_cuda_fuser.py @@ -3,29 +3,46 @@ import torch -from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, skipIfRocm +from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR from torch.testing._internal.codegen.random_topo_test import runDefaultTestWithSeed from test_jit import JitTestCase, RUN_CUDA + +from jit.test_fuser_common import TestFuserCommon # noqa: F401 + import itertools import numpy as np +os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1' +os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1' +os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0' + if GRAPH_EXECUTOR == ProfilingMode.PROFILING: torch._C._jit_set_texpr_fuser_enabled(False) torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_mode(True) FUSION_GROUP = 'prim::CudaFusionGroup' +FUSION_GUARD = 'prim::CudaFusionGuard' class TestCudaFuser(JitTestCase): def _getSubgraphInFusion(self, graph): - self.assertGraphContainsExactly(graph, FUSION_GROUP, 1, consider_subgraphs=False) - - for node in graph.nodes(): - if node.kind() == FUSION_GROUP: - self.assertTrue(node.hasAttribute('Subgraph')) - return node.g('Subgraph') + num_node = 0 + subgraph = None + + def count(block, ret): + for n in block.nodes(): + if n.kind() == FUSION_GROUP: + ret[0] = ret[0] + 1 + self.assertTrue(n.hasAttribute('Subgraph')) + ret[1] = n.g('Subgraph') + for block in n.blocks(): + count(block, ret) + ret = [num_node, subgraph] + count(graph, ret) + self.assertEqual(ret[0], 1) + return ret[1] def setUp(self): super(TestCudaFuser, self).setUp() @@ -33,6 +50,7 @@ def setUp(self): self.old_gpu_fuse = torch._C._jit_can_fuse_on_gpu() torch._C._jit_override_can_fuse_on_cpu(False) torch._C._jit_override_can_fuse_on_gpu(False) + self.old_guard = torch._C._jit_set_nvfuser_guard_mode(False) if(RUN_CUDA): self.old_nvfuser = torch._C._jit_set_nvfuser_enabled(True) @@ -42,6 +60,7 @@ def tearDown(self): torch._C._jit_set_nvfuser_enabled(self.old_nvfuser) torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuse) torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuse) + torch._C._jit_set_nvfuser_guard_mode(self.old_guard) super(TestCudaFuser, self).tearDown() def _run_helper(self, jit_op, op, *args): @@ -52,11 +71,11 @@ def _run_helper(self, jit_op, op, *args): torch.cuda.manual_seed_all(123) o = op(*args) self.assertEqual(o, jit_o) - self.assertGraphContains(jit_op.graph_for(*args), FUSION_GROUP) + self.assertGraphContains(jit_op.graph_for(*args), FUSION_GUARD) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_half(self): def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: float): o_16 = torch.add(x, y) @@ -77,11 +96,11 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: float): for oo, jit_oo in zip(o, jit_o): self.assertEqual(oo.dtype, jit_oo.dtype) self.assertEqual(oo, jit_oo) - self.assertGraphContains(t_jit.graph_for(x, y, z, alpha), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x, y, z, alpha), FUSION_GUARD) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_const(self): def t(x, y): o = x + y @@ -94,11 +113,11 @@ def t(x, y): jit_o = t_jit(x, y) o = t(x, y) self.assertEqual(o, jit_o) - self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_chunk(self): def t(x, y, z, q): o = x + q @@ -117,11 +136,11 @@ def t(x, y, z, q): jit_o = t_jit(x, y, z, q) o = t(x, y, z, q) self.assertEqual(o, jit_o) - self.assertGraphContains(t_jit.graph_for(x, y, z, q), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x, y, z, q), FUSION_GUARD) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_scalar_input(self): def t(x: torch.Tensor, y: torch.Tensor, z: float): o = x + y @@ -135,11 +154,11 @@ def t(x: torch.Tensor, y: torch.Tensor, z: float): jit_o = t_jit(x, y, 2.0) o = t(x, y, 2.0) self.assertEqual(o, jit_o) - self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_broadcasting_0(self): def t(x: torch.Tensor, y: torch.Tensor, z: float): @@ -157,8 +176,8 @@ def t(x: torch.Tensor, y: torch.Tensor, z: float): self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_broadcasting_1(self): def t(x: torch.Tensor, y: torch.Tensor, z: float): @@ -176,8 +195,8 @@ def t(x: torch.Tensor, y: torch.Tensor, z: float): self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_broadcasting_2(self): def t(x: torch.Tensor, y: torch.Tensor, z: float): @@ -195,8 +214,8 @@ def t(x: torch.Tensor, y: torch.Tensor, z: float): self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_broadcasting_3(self): def t(x: torch.Tensor, y: torch.Tensor, z: float): @@ -217,8 +236,8 @@ def t(x: torch.Tensor, y: torch.Tensor, z: float): # Testing partition logic that is capable to avoid creating unsupported # broadcasting semantics in CudaFusionGroup @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_broadcasting_partition_logic_0(self): def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): @@ -239,8 +258,8 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): self.assertGraphContainsExactly(subgraph, 'aten::add', 4, consider_subgraphs=False) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_broadcasting_partition_logic_1(self): def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): @@ -262,8 +281,8 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): @unittest.skipIf(True, "Broadcast with different output not supported yet") @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_broadcasting_multiple_output_shape(self): def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): o = x + 12 @@ -280,12 +299,12 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): o = t(x, y, z) self.assertEqual(o, jit_o) # Currently cannot fuse this - self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) @unittest.skipIf(True, "broadcast on branches can't be resolved yet") @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_broadcasting_multiple_output(self): def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): o = x + 12 @@ -302,7 +321,7 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): o = t(x, y, z) self.assertEqual(o, jit_o) # Currently cannot fuse this - self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) def _binary_test_helper(self, operation): def t(x: torch.Tensor, y: torch.Tensor, z: float): @@ -316,7 +335,7 @@ def t(x: torch.Tensor, y: torch.Tensor, z: float): jit_o = t_jit(x, y, 2.0) o = t(x, y, 2.0) self.assertEqual(o, jit_o) - self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD) def _unary_test_helper(self, operation): def t(x: torch.Tensor, z: float): @@ -329,11 +348,11 @@ def t(x: torch.Tensor, z: float): jit_o = t_jit(x, 2.0) o = t(x, 2.0) self.assertEqual(o, jit_o) - self.assertGraphContains(t_jit.graph_for(x, 2.0), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x, 2.0), FUSION_GUARD) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_unary_ops(self): operations = [torch.neg, torch.abs, @@ -369,8 +388,8 @@ def test_unary_ops(self): self._unary_test_helper(op) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_binary_ops(self): operations = [torch.div, torch.mul, @@ -476,10 +495,11 @@ def addcmul_const_alpha(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): self._run_helper(addcmul_const_alpha_jit, addcmul_const_alpha, x, y, z) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_dynamic_size(self): - torch._C._jit_set_bailout_depth(3) + old_guard = torch._C._jit_set_nvfuser_guard_mode(True) + torch._C._jit_set_bailout_depth(20) def t(x: torch.Tensor, y: torch.Tensor, z: float): o = x + y @@ -504,14 +524,15 @@ def t(x: torch.Tensor, y: torch.Tensor, z: float): jit_o = t_jit(x, y, 2.0) o = t(x, y, 2.0) self.assertEqual(o, jit_o) - self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD) x = torch.randn(8, 17, 8, dtype=torch.float, device="cuda") y = torch.randn(8, 17, 1, dtype=torch.float, device="cuda") jit_o = t_jit(x, y, 2.0) jit_o = t_jit(x, y, 2.0) o = t(x, y, 2.0) self.assertEqual(o, jit_o) - self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD) + torch._C._jit_set_nvfuser_guard_mode(old_guard) @unittest.skipIf(not RUN_CUDA, "requires CUDA") def test_random_topo(self): @@ -552,16 +573,15 @@ def t(x: torch.Tensor, y: torch.Tensor): o = t(x, y) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o, jit_o) - self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD) # end-2-end test of permutation & contiguity handling in integration. # we are testing inputs with all combination of permutation order, just to # ensure that integration would be able to generate functionally correct # kernels @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") - @skipIfRocm + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_binary_ops_permutation(self): # note that num_dim is exclusive from len(x), so we are not reducing # to single element (codegen limitation at this moment) @@ -598,12 +618,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): # numerical issues here due to our scheduling. # can't use `self.assertEqual(o, jit_o)` self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4)) - self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") - @skipIfRocm + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_reduction(self): for x in ([7, 8, 12], [12, 8, 7, 9, 15], [128, 16, 8, 32]): # note that num_dim is exclusive from len(x), so we are not reducing @@ -615,9 +634,8 @@ def test_reduction(self): self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") - @skipIfRocm + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_reduction_permutation(self): x = [7, 8, 12] # note that num_dim is exclusive from len(x), so we are not reducing @@ -629,10 +647,11 @@ def test_reduction_permutation(self): self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_reduction_multiple_output(self): - torch._C._jit_set_bailout_depth(2) + old_guard = torch._C._jit_set_nvfuser_guard_mode(True) + torch._C._jit_set_bailout_depth(20) def t(x: torch.Tensor, y: torch.Tensor, scale: float, z: torch.Tensor): o = torch.mul(x, y) @@ -652,7 +671,7 @@ def t(x: torch.Tensor, y: torch.Tensor, scale: float, z: torch.Tensor): for oo, jit_oo in zip(o, jit_o): self.assertEqual(oo.dtype, jit_oo.dtype) self.assertEqual(oo, jit_oo) - self.assertGraphContains(t_jit.graph_for(x, y, scale, z), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x, y, scale, z), FUSION_GUARD) x = x.to(memory_format=torch.channels_last) y = y.to(memory_format=torch.channels_last) @@ -663,12 +682,12 @@ def t(x: torch.Tensor, y: torch.Tensor, scale: float, z: torch.Tensor): for oo, jit_oo in zip(o, jit_o): self.assertEqual(oo.dtype, jit_oo.dtype) self.assertEqual(oo, jit_oo) - self.assertGraphContains(t_jit.graph_for(x, y, scale, z), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x, y, scale, z), FUSION_GUARD) + torch._C._jit_set_nvfuser_guard_mode(old_guard) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") - @skipIfRocm + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_reduction_dtype(self): def t(x: torch.Tensor): o = torch.mul(x, 1.0) @@ -682,12 +701,11 @@ def t(x: torch.Tensor): o = t(x) self.assertEqual(o.dtype, jit_o.dtype) self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4)) - self.assertGraphContains(t_jit.graph_for(x), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") - @skipIfRocm + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_reduction_half(self): def t(x: torch.Tensor): o = torch.mul(x, 1.0) @@ -701,12 +719,11 @@ def t(x: torch.Tensor): o = t(x) self.assertEqual(o.dtype, jit_o.dtype) self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4)) - self.assertGraphContains(t_jit.graph_for(x), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") - @skipIfRocm + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_pw_single_reduction_partition(self): sizes = [8, 8, 8] dtype = torch.float @@ -726,12 +743,11 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): o = t(x, y, z) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o, jit_o) - self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") - @skipIfRocm + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_single_reduction_broadcast(self): dtype = torch.float device = "cuda" @@ -750,14 +766,51 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): o = t(x, y, z) self.assertEqual(o.dtype, jit_o.dtype) self.assertEqual(o, jit_o) - self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) + + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_profiling_node(self): + dtype = torch.float + device = "cuda" + x = torch.randn(4, 8, 8, 8, dtype=dtype, device=device) + + def repro(x: torch.Tensor, alpha: float): + o = torch.rand_like(x) + o = torch.add(o, alpha) + return o + repro_jit = torch.jit.script(repro) + self._run_helper(repro_jit, repro, x, 0.6) + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_reduction_sizes_op(self): + dtype = torch.float + device = "cuda" + x = torch.randn(2, 3, 4, 5, dtype=dtype, device=device) + y = torch.randn(2, 3, 4, 5, dtype=dtype, device=device) + + def t(x: torch.Tensor, y: torch.Tensor): + o = x + y + o = torch.relu(o) + o = o.sum((1, 3)) + return o.size() + t_jit = torch.jit.script(t) + jit_o = t_jit(x, y) + jit_o = t_jit(x, y) + o = t(x, y) + self.assertEqual(o, jit_o) + # since the output value is not used at all, the fusion operator should + # have been optimized away + self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 0) class TestPassManagerCudaFuser(JitTestCase): @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != - ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_context_manager_test(self): x = torch.randn(4, 8, dtype=torch.float, device="cuda") y = torch.randn(4, 8, dtype=torch.float, device="cuda") @@ -771,7 +824,7 @@ def t1(x, y): t_jit = torch.jit.script(t1) t_jit(x, y) t_jit(x, y) - self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GROUP) + self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD) def t2(x, y): o = x + y @@ -780,7 +833,7 @@ def t2(x, y): t_jit_2 = torch.jit.script(t2) t_jit_2(x, y) t_jit_2(x, y) - self.assertGraphContains(t_jit_2.graph_for(x, y), FUSION_GROUP) + self.assertGraphContains(t_jit_2.graph_for(x, y), FUSION_GUARD) def t3(x, y): o = x + y @@ -789,7 +842,7 @@ def t3(x, y): t_jit_3 = torch.jit.script(t3) t_jit_3(x, y) t_jit_3(x, y) - self.assertGraphContainsExactly(t_jit_3.graph_for(x, y), FUSION_GROUP, 0) + self.assertGraphContainsExactly(t_jit_3.graph_for(x, y), FUSION_GUARD, 0) @unittest.skipIf(not RUN_CUDA, "requires CUDA") def test_register_fuser(self): diff --git a/test/test_jit_cuda_fuser_legacy.py b/test/test_jit_cuda_fuser_legacy.py deleted file mode 100644 index 41e16df7d6869..0000000000000 --- a/test/test_jit_cuda_fuser_legacy.py +++ /dev/null @@ -1,12 +0,0 @@ -import sys -sys.argv.append("--ge_config=legacy") - -import os -os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1' -os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1' -os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0' - -from test_jit_cuda_fuser import * - -if __name__ == '__main__': - run_tests() diff --git a/test/test_jit_cuda_fuser_profiling.py b/test/test_jit_cuda_fuser_profiling.py deleted file mode 100644 index 7559b85519c45..0000000000000 --- a/test/test_jit_cuda_fuser_profiling.py +++ /dev/null @@ -1,12 +0,0 @@ -import sys -sys.argv.append("--ge_config=profiling") - -import os -os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1' -os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1' -os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0' - -from test_jit_cuda_fuser import * - -if __name__ == '__main__': - run_tests() diff --git a/test/test_jit_disabled.py b/test/test_jit_disabled.py index f8aa7abc4ca75..3cdc6f7d2db3c 100644 --- a/test/test_jit_disabled.py +++ b/test/test_jit_disabled.py @@ -1,9 +1,8 @@ -import unittest import sys import os import contextlib import subprocess -from torch.testing._internal.common_utils import TemporaryFileName +from torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName @contextlib.contextmanager @@ -16,7 +15,7 @@ def _jit_disabled(): os.environ["PYTORCH_JIT"] = cur_env -class TestJitDisabled(unittest.TestCase): +class TestJitDisabled(TestCase): """ These tests are separate from the rest of the JIT tests because we need run a new subprocess and `import torch` with the correct environment @@ -91,4 +90,4 @@ def forward(self, input): self.compare_enabled_disabled(_program_string) if __name__ == '__main__': - unittest.main() + run_tests() diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py index b4efbf12c358c..9218ae908e04f 100644 --- a/test/test_jit_fuser.py +++ b/test/test_jit_fuser.py @@ -1,11 +1,13 @@ import unittest +import os +import sys import torch import torch.nn as nn import torch.nn.functional as F from torch.testing import FileCheck from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \ - enable_profiling_mode_for_profiling_tests + enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \ RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward from textwrap import dedent @@ -60,6 +62,21 @@ def func(x): def test_abs_cpu(self): self._test_fused_abs() + @unittest.skipIf(not IS_WINDOWS, "This is meant to be Windows-specific") + @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") + @enable_cpu_fuser + def test_abs_cpu_unicode_temp_dir(self): + with TemporaryDirectoryName(suffix='中文') as dname: + shell_env = os.environ.copy() + shell_env['TMP'] = dname + cmd = [sys.executable, os.path.basename(__file__), type(self).__name__ + '.test_abs_cpu'] + legacy_jit_flag = '--jit_executor=legacy' + for v in sys.argv: + if v == legacy_jit_flag: + cmd.append(legacy_jit_flag) + return_code = shell(cmd, cwd=os.path.dirname(__file__), env=shell_env) + self.assertEqual(return_code, 0) + @unittest.skipIf(not RUN_CUDA, "requires CUDA") def test_abs_cuda(self): self._test_fused_abs(device="cuda") @@ -540,8 +557,7 @@ def f(x): @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_scalar_arg_cuda(self): - def fn_test_scalar_arg(x, p): - # type: (Tensor, float) -> Tensor + def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor: return p * (x * x + x) x = torch.randn(4, 4, dtype=torch.float, device='cuda') @@ -553,8 +569,7 @@ def fn_test_scalar_arg(x, p): # use another function otherwise we will bailout # and won't be able to do fused checks - def fn_test_scalar_arg_requires_grad(x, p): - # type: (Tensor, float) -> Tensor + def fn_test_scalar_arg_requires_grad(x: torch.Tensor, p: float) -> torch.Tensor: return p * (x * x + x) scripted = torch.jit.script(fn_test_scalar_arg_requires_grad) diff --git a/test/test_jit_fuser_legacy.py b/test/test_jit_fuser_legacy.py index c33983e45e796..420075f6e611f 100644 --- a/test/test_jit_fuser_legacy.py +++ b/test/test_jit_fuser_legacy.py @@ -1,5 +1,5 @@ import sys -sys.argv.append("--ge_config=legacy") +sys.argv.append("--jit_executor=legacy") from test_jit_fuser import * if __name__ == '__main__': diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 6fab65006927a..3cea1cc86b5db 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -1,5 +1,3 @@ -from collections import defaultdict - import operator import unittest import contextlib @@ -17,10 +15,10 @@ torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_mode(True) -from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \ +from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, \ enable_profiling_mode_for_profiling_tests from torch.testing._internal.jit_utils import JitTestCase, _inline_everything, \ - RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward + RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining from textwrap import dedent from itertools import product, permutations @@ -30,7 +28,10 @@ from torch.testing._internal.te_utils import CudaCodeGenExecuted +from jit.test_fuser_common import TestFuserCommon # noqa: F401 + FUSION_GROUP = 'prim::TensorExprGroup' +LLVM_ENABLED = torch._C._llvm_enabled() def strip_profiling_nodes(nodes): profiling_opcodes = set(['prim::BailoutTemplate', 'prim::BailOut']) @@ -54,55 +55,51 @@ def texpr_reductions_enabled(): class TestTEFuser(JitTestCase): def setUp(self): self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu() + self.old_must_use_cpu_state = torch._C._jit_get_te_must_use_llvm_cpu() self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu() torch._C._jit_override_can_fuse_on_cpu(True) + # TODO: force LLVM. need to add it to asan, mac, windows builds + sandcastle + # torch._C._jit_set_te_must_use_llvm_cpu(True) torch._C._jit_override_can_fuse_on_gpu(True) self.old_profiling_executor = torch._C._jit_set_profiling_executor(True) self.old_profiling_mode = torch._C._jit_set_profiling_mode(True) + self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining() + torch._C._debug_set_fusion_group_inlining(False) + self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() torch._C._jit_set_texpr_fuser_enabled(True) + self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] + self.int_dtypes = [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.bool, + ] + self.fp_dtypes = [ + torch.float16, + torch.float32, + torch.float64, + ] + self.dtypes = self.int_dtypes + self.fp_dtypes + def tearDown(self): torch._C._jit_set_profiling_executor(self.old_profiling_executor) torch._C._jit_set_profiling_mode(self.old_profiling_mode) torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state) torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state) + torch._C._jit_set_te_must_use_llvm_cpu(self.old_must_use_cpu_state) + torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining) torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state) - def assertAllFused(self, graph, except_for=()): - - # note this helper collects nodes on 'fast path' only - # i.e. the true blocks of specialized checks - def get_nodes_and_parents_recursively(block, kind, acc): - for node in block.nodes(): - if node.kind() == kind: - acc[block].append(node) - elif node.kind() == 'prim::DifferentiableGraph': - get_nodes_and_parents_recursively(node.g('Subgraph'), kind, acc) - elif node.kind() == 'prim::If' and (node.inputs().__next__().node().kind() == 'aten::all' or - node.inputs().__next__().node().kind() == 'prim::TypeCheck'): - get_nodes_and_parents_recursively(node.blocks().__next__(), kind, acc) - else: - for inner_block in node.blocks(): - get_nodes_and_parents_recursively(inner_block, kind, acc) - - allowed_nodes = {'prim::Constant', FUSION_GROUP, 'prim::BailoutTemplate', - 'prim::TupleConstruct', 'prim::If', 'prim::TypeCheck'} | set(except_for) - - fusion_groups = defaultdict(list) - get_nodes_and_parents_recursively(graph, FUSION_GROUP, fusion_groups) - self.assertTrue(len(fusion_groups) == 1, 'got {}'.format(graph)) - (graph, fusion_nodes) = list(fusion_groups.items())[0] - # the block contains one FUSION_GROUP and the rest of nodes are `allowed_nodes` - self.assertTrue(len(fusion_nodes) == 1, 'got {}'.format(graph)) - self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()), - 'got {}'.format(graph)) - + def assertLastGraphAllFused(self): + self.assertAllFused(torch.jit.last_executed_optimized_graph()) def findFusionGroups(self, graph): result = [] @@ -120,12 +117,8 @@ def func(x): a = torch.randn(5, device=device) scripted = self.checkScript(func, (a,)) - graph = scripted.graph_for(a) - fusion_groups = self.findFusionGroups(graph) - self.assertEqual(len(fusion_groups), 1) - FileCheck().check("aten::abs").check("aten::mul").run(str(fusion_groups[0])) + self.assertLastGraphAllFused() - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") def test_sum_simple(self): def func(x): x2 = x * x @@ -135,26 +128,23 @@ def func(x): a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu') a = a.reshape(5, 3) scripted = self.checkScript(func, (a,)) - graph = scripted.graph_for(a) - fusion_groups = self.findFusionGroups(graph) - self.assertEqual(len(fusion_groups), 1) - self.assertEqual(scripted(a), func(a)) + self.assertLastGraphAllFused() - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") def test_sum_dim(self): def func(x): return x.sum((0, )) * 2 + def func_neg(x): + return x.sum((-2, )) * 2 + with texpr_reductions_enabled(): a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu') a = a.reshape(5, 3) scripted = self.checkScript(func, (a,)) - graph = scripted.graph_for(a) - fusion_groups = self.findFusionGroups(graph) - self.assertEqual(len(fusion_groups), 1) - self.assertEqual(scripted(a), func(a)) + self.assertLastGraphAllFused() + scripted = self.checkScript(func_neg, (a,)) + self.assertLastGraphAllFused() - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") def test_sum_keepdim_cast(self): def func(x): return x.sum((0, ), keepdim=True, dtype=torch.double) * 2 @@ -162,13 +152,10 @@ def func(x): with texpr_reductions_enabled(): a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu') a = a.reshape(5, 3) - scripted = self.checkScript(func, (a,)) - graph = scripted.graph_for(a) - fusion_groups = self.findFusionGroups(graph) - self.assertEqual(len(fusion_groups), 1) - self.assertEqual(scripted(a), func(a)) - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") + self.checkScript(func, (a,)) + self.assertLastGraphAllFused() + def test_abs_cpu(self): self._test_fused_abs() @@ -190,7 +177,6 @@ def decode(sin_t, cos_t): def test_zero_element_tensors_cuda(self): self._test_zero_element_tensors(device="cuda") - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") def test_zero_element_tensors_cpu(self): self._test_zero_element_tensors(device="cpu") @@ -209,21 +195,17 @@ def f(x, y): traced_f = torch.jit.trace(f, (x, y,)) self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_broadcast_cuda(self): - def scaleshift(x, scale, shift): - return x * scale + shift + def test_broadcast(self): + for device in self.devices: + def scaleshift(x, scale, shift): + return x * scale + shift - inputs = [ - torch.randn(4, 4, dtype=torch.float, device='cuda'), - torch.randn(4, dtype=torch.float, device='cuda'), - torch.randn(4, dtype=torch.float, device='cuda'), - ] - ge = self.checkTrace(scaleshift, inputs) - graph = ge.graph_for(*inputs) - fusion_groups = self.findFusionGroups(graph) - self.assertEqual(len(fusion_groups), 1) - FileCheck().check("aten::mul").check("aten::add").run(str(fusion_groups[0])) + inputs = [ + torch.randn(4, 4, dtype=torch.float, device=device), + torch.randn(4, dtype=torch.float, device=device), + torch.randn(4, dtype=torch.float, device=device), + ] + self.checkScript(scaleshift, inputs) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(not RUN_CUDA_HALF, "no half support") @@ -261,35 +243,35 @@ def test_cuda_half(self): grads_half = [t.half() for t in grads] self.assertEqual(grads_half, fusion_grads) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_checks_cat_inputs(self): - # We shouldn't treat cat nodes as broadcasting. All their inputs - # need to be checked for having the same map size, before we can - # run the kernel. - def f(x, y): - return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0) - - # NOTE: y is broadcastable to x, but output of f(x, y) should have - # shape 3x4, and not 4x4. - x = torch.randn(2, 4, dtype=torch.float, device='cuda') - y = torch.randn(1, 4, dtype=torch.float, device='cuda') - - scripted = self.checkScript(f, (x, y)) - self.assertEqual(scripted(x, y).shape, (3, 4)) - self.assertAllFused(scripted.graph_for(x, y)) - - @unittest.skipIf(not RUN_CUDA, "No CUDA") - def test_chunk_cuda(self): - def fn(x): - a, b, c = x.chunk(3, 1) - return a * b + c - - inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')] - - ge = self.checkScript(fn, inputs) - graph = ge.graph_for(*inputs) - self.assertAllFused(graph) - FileCheck().check("prim::ConstantChunk[chunks=3, dim=1]").run(str(graph)) + # single fusion node causes error + with set_fusion_group_inlining(True): + for device in self.devices: + # We shouldn't treat cat nodes as broadcasting. All their inputs + # need to be checked for having the same map size, before we can + # run the kernel. + def f(x, y): + return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0) + + # NOTE: y is broadcastable to x, but output of f(x, y) should have + # shape 3x4, and not 4x4. + x = torch.randn(2, 4, dtype=torch.float, device=device) + y = torch.randn(1, 4, dtype=torch.float, device=device) + + scripted = self.checkScript(f, (x, y)) + self.assertEqual(scripted(x, y).shape, (3, 4)) + self.assertAllFused(scripted.graph_for(x, y)) + + def test_chunk(self): + for device in self.devices: + def fn(x): + a, b, c = x.chunk(3, 1) + return a * b + c + + inputs = [torch.randn(10, 6, dtype=torch.float, device=device)] + + self.checkScript(fn, inputs) + self.assertLastGraphAllFused() @staticmethod def _test_chunk_correctness(self, device='cpu'): @@ -320,8 +302,8 @@ def chunk_4_last(x): for tensor in tensors: for fn in fns: self.checkScript(fn, [tensor]) + self.assertLastGraphAllFused() - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") def test_chunk_correctness(self): return self._test_chunk_correctness(self, 'cpu') @@ -329,120 +311,114 @@ def test_chunk_correctness(self): def test_chunk_correctness_cuda(self): return self._test_chunk_correctness(self, 'cuda') - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_chunk_distributes_cuda(self): - def f(x, y): - z1, z2 = (x + y).chunk(2, dim=1) - return z1 * z2 + def test_chunk_distributes(self): + for device in self.devices: + def f(x, y): + z1, z2 = (x + y).chunk(2, dim=1) + return z1 * z2 - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') + x = torch.randn(4, 4, dtype=torch.float, device=device) + y = torch.randn(4, 4, dtype=torch.float, device=device) - ge = self.checkTrace(f, (x, y)) - graph = ge.graph_for(x, y) - # XXX: The old fuser does broadcast_tensors but the new fuser doesn't. - # FileCheck().check("broadcast_tensors").check('with ' + FUSION_GROUP + '_') \ - # .check_count('ConstantChunk', 2, exactly=True).run(str(graph)) - FileCheck().check("with " + FUSION_GROUP + "_").check_count( - "ConstantChunk", 1, exactly=True - ).run(str(graph)) + ge = self.checkTrace(f, (x, y)) + graph = ge.graph_for(x, y) + # XXX: The old fuser does broadcast_tensors but the new fuser doesn't. + # FileCheck().check("broadcast_tensors").check('with ' + FUSION_GROUP + '_') \ + # .check_count('ConstantChunk', 2, exactly=True).run(str(graph)) + FileCheck().check("with " + FUSION_GROUP + "_").check_count( + "ConstantChunk", 1, exactly=True + ).run(str(graph)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_chunk_motion_deduplicates_inputs(self): - def func1(x): - z = x * x - z0, z1 = z.chunk(2) - return z0 * z1 - - def func2(x): - z = x * x * x - z0, z1 = z.chunk(2) - return z0 * z1 - - inputs = [ - torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float), - ] - for func in [func1, func2]: - module = self.checkScript(func, inputs) - forward_graph = module.graph_for(*inputs) - self.assertGraphContainsExactly(forward_graph, FUSION_GROUP, 1) - fusion_group = list(forward_graph.nodes())[-1] - self.assertEqual(len(list(fusion_group.inputs())), 1) - - @unittest.skipIf(not RUN_CUDA, "No CUDA") - def test_chunk_multiple_cuda(self): - # The arguments are intentionally used out of order as a test to see - # if the fusion compiler adds extra args in the correct order - def fn(s, x, y, z): - z1, z2 = z.chunk(2, 2) - x1, x2, x3 = x.chunk(3, 1) - y1, y2 = y.chunk(2, 0) - return s + x1 + x2 + x3 + y1 + y2 + z1 + z2 - - inputs = [ - torch.randn(5, 2, 3, dtype=torch.float, device='cuda'), - torch.randn(5, 6, 3, dtype=torch.float, device='cuda'), - torch.randn(10, 2, 3, dtype=torch.float, device='cuda'), - torch.randn(5, 2, 6, dtype=torch.float, device='cuda'), - ] - - ge = self.checkScript(fn, inputs) - self.assertAllFused(ge.graph_for(*inputs)) + for device in self.devices: + def func1(x): + z = x * x + z0, z1 = z.chunk(2) + return z0 * z1 + + def func2(x): + z = x * x * x + z0, z1 = z.chunk(2) + return z0 * z1 + + inputs = [ + torch.tensor([1.1, 1.2], device=device, dtype=torch.float), + ] + for func in [func1, func2]: + self.checkScript(func, inputs) + self.assertLastGraphAllFused() + + def test_chunk_multiple(self): + for device in self.devices: + # The arguments are intentionally used out of order as a test to see + # if the fusion compiler adds extra args in the correct order + def fn(s, x, y, z): + z1, z2 = z.chunk(2, 2) + x1, x2, x3 = x.chunk(3, 1) + y1, y2 = y.chunk(2, 0) + return s + x1 + x2 + x3 + y1 + y2 + z1 + z2 + + inputs = [ + torch.randn(5, 2, 3, dtype=torch.float, device=device), + torch.randn(5, 6, 3, dtype=torch.float, device=device), + torch.randn(10, 2, 3, dtype=torch.float, device=device), + torch.randn(5, 2, 6, dtype=torch.float, device=device), + ] + + ge = self.checkScript(fn, inputs) + self.assertAllFused(ge.graph_for(*inputs)) def test_minmax(self): - def tmax(a, b): - return torch.max(2 * a, b) + for device in self.devices: + def tmax(a, b): + return torch.max(2 * a, b) - def tmin(a, b): - return torch.min(2 * a, b) + def tmin(a, b): + return torch.min(2 * a, b) - a = torch.randn(4, 4, dtype=torch.float) - b = torch.randn(4, 4, dtype=torch.float) - nan = torch.tensor(float('nan'), dtype=torch.float) + a = torch.randn(4, 4, dtype=torch.float) + b = torch.randn(4, 4, dtype=torch.float) + nan = torch.tensor(float('nan'), dtype=torch.float) + + for f, inputs, device in product( + (tmax, tmin), + ([a, b], [a, nan], [b, nan]), + self.devices): + inputs = [t.to(device) for t in inputs] + s = self.checkScript(f, inputs) + self.assertAllFused(s.graph_for(*inputs)) - devices = ["cpu"] - if torch.cuda.is_available(): - devices.append("cuda") - for f, inputs, device in product( - (tmax, tmin), - ([a, b], [a, nan], [b, nan]), - devices): - inputs = [t.to(device) for t in inputs] - s = self.checkScript(f, inputs) - self.assertAllFused(s.graph_for(*inputs)) - - # TODO: reenable the test after backwards passes start working in PE - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_clamp(self): - def func2(a, b): - return torch.clamp(a + b, min=0, max=2) + for device in self.devices: + def func2(a, b): + return torch.clamp(a + b, min=0, max=2) - def funcInf(a, b): - return torch.clamp(a + b, min=0, max=float('inf')) + def funcInf(a, b): + return torch.clamp(a + b, min=0, max=float('inf')) - def funcNegInf(a, b): - return torch.clamp(a + b, min=float('-inf'), max=0) + def funcNegInf(a, b): + return torch.clamp(a + b, min=float('-inf'), max=0) - def funcOptMin(a, b): - return torch.clamp(a + b, max=2) + def funcOptMin(a, b): + return torch.clamp(a + b, max=2) - def funcOptMax(a, b): - return torch.clamp(a + b, min=0) + def funcOptMax(a, b): + return torch.clamp(a + b, min=0) - a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True) - b = torch.randn(4, 4, dtype=torch.float, device='cuda') - nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda') - - funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax) - for f, inputs in product(funcs, [[a, b], [a, nan]]): - inp1, inp2 = inputs - s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING) - self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'}) - c = s(inp1, inp2) - with enable_profiling_mode_for_profiling_tests(): - warmup_backward(c.sum()) - graph = backward_graph(s) - self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'}) + a = torch.randn(4, 4, dtype=torch.float, device=device, requires_grad=True) + b = torch.randn(4, 4, dtype=torch.float, device=device) + nan = torch.tensor(float('nan'), dtype=torch.float, device=device) + + funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax) + for f, inputs in product(funcs, [[a, b], [a, nan]]): + inp1, inp2 = inputs + s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING) + self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'}) + c = s(inp1, inp2) + with enable_profiling_mode_for_profiling_tests(): + warmup_backward(c.sum()) + graph = backward_graph(s) + self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'}) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on") @@ -460,62 +436,55 @@ def func(x): graph = backward_graph(s, skip_check=True) self.assertAllFused(graph, except_for={'aten::div', 'prim::Constant'}) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_add_bool(self): - def f(x, y, z): - return x + y + z - - x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda') - y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda') - z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda') + sizes = [(1,), (2,), (4, 4)] + for device, size in product(self.devices, sizes): + def f(x, y, z): + return x + y + z - ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) - self.assertAllFused(ge.graph_for(x, y, z)) + x = torch.randint(0, 2, size, dtype=torch.bool, device=device) + y = torch.randint(0, 2, size, dtype=torch.bool, device=device) + z = torch.randint(0, 2, size, dtype=torch.bool, device=device) + ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) + self.assertAllFused(ge.graph_for(x, y, z)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_mul_bool(self): - def f(x, y, z): - return x * y * z + for device in self.devices: + def f(x, y, z): + return x * y * z - x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda') - y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda') - z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda') + x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) + y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) + z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) - ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) - self.assertAllFused(ge.graph_for(x, y, z)) + ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) + self.assertAllFused(ge.graph_for(x, y, z)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_div_bool(self): - def f(x, y, z): - return (x + y) / z + for device in self.devices: + def f(x, y, z): + return (x + y) / z - x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda') - y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device='cuda') - z = torch.ones_like(x, dtype=torch.bool, device='cuda') + x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) + y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) + z = torch.ones_like(x, dtype=torch.bool, device=device) - ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) - self.assertAllFused(ge.graph_for(x, y, z)) + ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) + self.assertAllFused(ge.graph_for(x, y, z)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_bitwise_ops(self): def apply(fn): return lambda x, y, z: fn(fn(x, y), z) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.bool, - ] binary_ops = [ operator.__and__, operator.__or__, - operator.__xor__ + operator.__xor__, + operator.__lshift__, + operator.__rshift__, ] - devices = ["cuda"] - for dtype, op, device in product(dtypes, binary_ops, devices): + devices = self.devices + for dtype, op, device in product(self.int_dtypes, binary_ops, devices): try: x = self.data_for(dtype, device) y = self.data_for(dtype, device) @@ -536,25 +505,16 @@ def apply(fn): " ".join(["Failed:", str(dtype), op.__name__, device]) ) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_minmax_int_ops(self): def apply(fn): return lambda x, y, z: fn(fn(x, y), z) - dtypes = [ - torch.int8, - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.bool, - ] binary_ops = [ torch.min, torch.max ] - devices = ["cuda"] - for dtype, op, device in product(dtypes, binary_ops, devices): + devices = self.devices + for dtype, op, device in product(self.int_dtypes, binary_ops, devices): try: x = self.data_for(dtype, device) y = self.data_for(dtype, device) @@ -575,20 +535,20 @@ def apply(fn): " ".join(["Failed:", str(dtype), op.__name__, device]) ) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_comparison_eq_ne(self): - def f(x, y): - mask = (x == 0).type_as(x) - z = x * mask + y - mask = (x != 0).type_as(x) - z = z * mask + y - return z + for device in self.devices: + def f(x, y): + mask = (x == 0).type_as(x) + z = x * mask + y + mask = (x != 0).type_as(x) + z = z * mask + y + return z - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') + x = torch.randn(4, 4, dtype=torch.float, device=device) + y = torch.randn(4, 4, dtype=torch.float, device=device) - ge = self.checkTrace(f, (x, y)) - self.assertAllFused(ge.graph_for(x, y)) + ge = self.checkTrace(f, (x, y)) + self.assertAllFused(ge.graph_for(x, y)) @staticmethod def fn_test_comparison_gt_lt(x, y): @@ -598,47 +558,47 @@ def fn_test_comparison_gt_lt(x, y): z = z * mask + y return z - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_comparison_gt_lt_cuda(self): - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') - - ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y)) - self.assertAllFused(ge.graph_for(x, y)) - - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_comparison_ge_le_cuda(self): - def f(x, y): - mask = (x >= 0).type_as(x) - z = x * mask + y - mask = (x <= 0).type_as(x) - z = z * mask + y - return z - - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') - - ge = self.checkTrace(f, (x, y)) - self.assertAllFused(ge.graph_for(x, y)) - x.requires_grad_(True) - y.requires_grad_(True) - self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes", - "aten::_size_if_not_equal")) - - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_addcmul_cuda(self): - t = torch.randn(1, 4, dtype=torch.float, device='cuda') - t1 = torch.randn(4, 1, dtype=torch.float, device='cuda') - t2 = torch.randn(1, 4, dtype=torch.float, device='cuda') - - def foo(t, t1, t2): - return t.addcmul(t + 1, t2, value=0.1) - - ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True) - graph = ge.graph_for(t, t1, t2) - fusion_groups = self.findFusionGroups(graph) - self.assertEqual(len(fusion_groups), 1) - FileCheck().check("aten::add(").check("aten::addcmul(").run(str(fusion_groups[0])) + def test_comparison_gt_lt(self): + for device in self.devices: + x = torch.randn(4, 4, dtype=torch.float, device=device) + y = torch.randn(4, 4, dtype=torch.float, device=device) + + ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y)) + self.assertAllFused(ge.graph_for(x, y)) + + def test_comparison_ge_le(self): + for device in self.devices: + def f(x, y): + mask = (x >= 0).type_as(x) + z = x * mask + y + mask = (x <= 0).type_as(x) + z = z * mask + y + return z + + x = torch.randn(4, 4, dtype=torch.float, device=device) + y = torch.randn(4, 4, dtype=torch.float, device=device) + + ge = self.checkTrace(f, (x, y)) + self.assertAllFused(ge.graph_for(x, y)) + x.requires_grad_(True) + y.requires_grad_(True) + self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes", + "aten::_size_if_not_equal")) + + def test_addcmul(self): + for device in self.devices: + t = torch.randn(1, 4, dtype=torch.float, device=device) + t1 = torch.randn(4, 1, dtype=torch.float, device=device) + t2 = torch.randn(1, 4, dtype=torch.float, device=device) + + def foo(t, t1, t2): + return t.addcmul(t + 1, t2, value=0.1) + + ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True) + graph = ge.graph_for(t, t1, t2) + fusion_groups = self.findFusionGroups(graph) + self.assertEqual(len(fusion_groups), 1) + FileCheck().check("aten::add(").check("aten::addcmul(").run(str(fusion_groups[0])) # TODO: We leak CUDA memory here because the traced graph holds onto a # constant-ified tensor. Since the Python-global CompilationUnit is alive @@ -646,73 +606,95 @@ def foo(t, t1, t2): # Removed `_cuda` suffix from this test which disables leak-checking. # If this is a real problem, we'll need to revisit Torchscript Function # lifetimes in Python. - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_lerp(self): - start = torch.randn(4, 1, dtype=torch.float, device='cuda') - end = torch.randn(1, 4, dtype=torch.float, device='cuda') - weight = torch.tensor(0.5, dtype=torch.float, device='cuda') - - # scalar weight overload - def foo_weight_scalar(start, end): - return torch.lerp(start + 1, end, 0.5) - - # tensor weight overload - def foo_weight_tensor(start, end): - return torch.lerp(start + 1, end, weight) - - ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end)) - graph = ge_weight_scalar.graph_for(start, end) - self.assertAllFused(graph) - - # TODO: uncomment when TE enables support for scalar tensors - # ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end)) - # graph = ge_weight_tensor.graph_for(start, end) - # self.assertAllFused(graph) + for device in self.devices: + start = torch.randn(4, 1, dtype=torch.float, device=device) + end = torch.randn(1, 4, dtype=torch.float, device=device) + weight = torch.tensor(0.5, dtype=torch.float, device=device) + + # scalar weight overload + def foo_weight_scalar(start, end): + return torch.lerp(start + 1, end, 0.5) + + # tensor weight overload + def foo_weight_tensor(start, end): + return torch.lerp(start + 1, end, weight) + + ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end)) + graph = ge_weight_scalar.graph_for(start, end) + self.assertAllFused(graph) + + # TODO: uncomment when TE enables support for scalar tensors + # ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end)) + # graph = ge_weight_tensor.graph_for(start, end) + # self.assertAllFused(graph) + + def test_concat(self): + # disabling concat causes error with single concat node + with set_fusion_group_inlining(True): + for device in self.devices: + hx = torch.randn(3, 20, dtype=torch.float, device=device) + cx = torch.randn(3, 20, dtype=torch.float, device=device) + + def foo(hx, cx): + return torch.cat((hx + cx, hx * cx)) + + ge = self.checkTrace(foo, (hx, cx)) + graph = ge.graph_for(hx, cx) + self.assertAllFused(graph) + # XXX: TE fuser can handle concats in a fusion group. + # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_concat_cuda(self): - hx = torch.randn(3, 20, dtype=torch.float, device='cuda') - cx = torch.randn(3, 20, dtype=torch.float, device='cuda') - - def foo(hx, cx): - return torch.cat((hx + cx, hx * cx)) - - ge = self.checkTrace(foo, (hx, cx)) - graph = ge.graph_for(hx, cx) - self.assertAllFused(graph) - # XXX: TE fuser can handle concats in a fusion group. - # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) - - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_concat_invariant_cuda(self): - # Invariant: the output of prim::FusedConcat may - # not be an input to any node inside the FusionGroup. - def fn(x, y, z): - x1 = x + y - y1 = x - y - w = torch.cat([x1, y1]) - return w + z - - x = torch.randn(2, 2, dtype=torch.float, device='cuda') - y = torch.randn(2, 2, dtype=torch.float, device='cuda') - z = torch.randn(4, 2, dtype=torch.float, device='cuda') - ge = self.checkTrace(fn, (x, y, z)) - graph = ge.graph_for(x, y, z) - self.assertAllFused(graph, except_for={'aten::add'}) - # XXX: TE fuser can handle concats inside a fusion group. - # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) + def test_remove_output_used_only_in_size(self): + def test_fuse(a, b): + c = a + b + d = c + b + return d + + scripted_f = torch.jit.script(test_fuse) + x = torch.ones(1, requires_grad=True, device='cuda') + y = torch.ones(1, requires_grad=True, device='cuda') + warmup_forward(scripted_f, x, y) + g = torch.jit.last_executed_optimized_graph() + diff_nodes = g.findAllNodes('prim::DifferentiableGraph') + self.assertEqual(len(diff_nodes), 1) + g = diff_nodes[0].g('Subgraph') + if_nodes = [n for n in g.nodes() if n.kind() == 'prim::If'] + self.assertEqual(len(if_nodes), 1) + # the if node and the fusion group inside it should only have one output + self.assertEqual(len(list(if_nodes[0].outputs())), 1) + + def test_concat_invariant(self): + for device in self.devices: + # Invariant: the output of prim::FusedConcat may + # not be an input to any node inside the FusionGroup. + def fn(x, y, z): + x1 = x + y + y1 = x - y + w = torch.cat([x1, y1]) + return w + z + + x = torch.randn(2, 2, dtype=torch.float, device=device) + y = torch.randn(2, 2, dtype=torch.float, device=device) + z = torch.randn(4, 2, dtype=torch.float, device=device) + ge = self.checkTrace(fn, (x, y, z)) + graph = ge.graph_for(x, y, z) + self.assertAllFused(graph, except_for={'aten::add'}) + # XXX: TE fuser can handle concats inside a fusion group. + # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) @staticmethod def fn_test_exp(x, y): return (x + .5 * y).exp() - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_exp_cuda(self): - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') + def test_exp(self): + for device in self.devices: + x = torch.randn(4, 4, dtype=torch.float, device=device) + y = torch.randn(4, 4, dtype=torch.float, device=device) - ge = self.checkTrace(self.fn_test_exp, (x, y)) - self.assertAllFused(ge.graph_for(x, y)) + ge = self.checkTrace(self.fn_test_exp, (x, y)) + self.assertAllFused(ge.graph_for(x, y)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on") @@ -770,40 +752,37 @@ def test_norm_decompose(nm, in_opt_graph, not_in_opt_graph, in_fusegraph): test_norm_decompose(lm, ['aten::batch_norm_stats'], ['aten::layer_norm('], ['aten::sub', 'aten::mul', 'aten::add']) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") def test_threshold(self): - def f(x): - return torch.threshold(x, 0, -10) + x + x + x + for device in self.devices: + def f(x): + return torch.threshold(x, 0, -10) + x + x + x - x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device='cuda') - scripted = self.checkScript(f, (x,)) - self.assertAllFused(scripted.graph_for(x)) + x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device=device) + scripted = self.checkScript(f, (x,)) + self.assertAllFused(scripted.graph_for(x)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_scalar_arg_cuda(self): - def fn_test_scalar_arg(x, p): - # type: (Tensor, float) -> Tensor - return p * (x * x + x) + def test_scalar_arg(self): + for device in self.devices: + def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor: + return p * (x * x + x) - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - p = 3 - scripted = self.checkScript(fn_test_scalar_arg, (x, p)) - self.assertAllFused(scripted.graph_for(x, p)) + x = torch.randn(4, 4, dtype=torch.float, device=device) + p = 3 + scripted = self.checkScript(fn_test_scalar_arg, (x, p)) + self.assertAllFused(scripted.graph_for(x, p)) - x.requires_grad_(True) + x.requires_grad_(True) - # use another function otherwise we will bailout - # and won't be able to do fused checks - def fn_test_scalar_arg_requires_grad(x, p): - # type: (Tensor, float) -> Tensor - return p * (x * x + x) + # use another function otherwise we will bailout + # and won't be able to do fused checks + def fn_test_scalar_arg_requires_grad(x: torch.Tensor, p: float) -> torch.Tensor: + return p * (x * x + x) - scripted = torch.jit.script(fn_test_scalar_arg_requires_grad) - out = scripted(x, p) - self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes", - "aten::_size_if_not_equal")) + scripted = torch.jit.script(fn_test_scalar_arg_requires_grad) + out = scripted(x, p) + self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes", + "aten::_size_if_not_equal")) - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") @unittest.skip("deduplicating introduces aliasing in backward graph's outputs") def test_fuser_deduplication(self): # See that fusion kernel outputs are deduplicated when removing _grad_sum_to_size in the fuser's compilation @@ -825,7 +804,6 @@ def f(x, y): # check that a, b share storage, i.e. were generated as a single output in the fuser self.assertEqual(ga2.data_ptr(), gb2.data_ptr()) - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") @unittest.skip("temporarily disabled because fusion was restricted in fixing #22833") def test_fuser_iou(self): # This checks if most of Intersection over Union is fused. @@ -930,94 +908,66 @@ def doit(x, y): ge = self.checkTrace(doit, (x, y)) self.assertAllFused(ge.graph_for(x, y)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_lstm_cuda(self): - inputs = get_lstm_inputs('cuda', training=True) - module = self.checkScript(LSTMCellS, inputs) - return - forward_graph = module.graph_for(*inputs) - self.assertGraphContainsExactly( - forward_graph, FUSION_GROUP, 1, consider_subgraphs=True) - self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2) - # Everything is differentiable but TupleConstruct return - FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \ - .check_next("return").run(str(forward_graph)) - - with enable_profiling_mode_for_profiling_tests(True): - hy, cy = module(*inputs) - warmup_backward((hy + cy).sum()) - backward = backward_graph(module) - self.assertAllFused(backward, except_for=("aten::t", "aten::mm", - "aten::_grad_sum_to_size")) - - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_lstm_concat_cuda(self): - inputs = get_lstm_inputs('cuda') - ge = self.checkTrace(LSTMCellC, inputs) - graph = ge.graph_for(*inputs) - # XXX: TE fuser can handle concats inside a fusion group. - # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) - - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_lstm_gates_permutations_cuda(self): - # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh. - # Test that any permutation of this will still result in one FusionGroup. - choices = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh'] - template = dedent(''' - def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh): - gates = {} + {} + {} + {} - ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) - return ingate * forgetgate * cellgate * outgate - ''') - for permutation in permutations(choices, len(choices)): - code = template.format(*permutation) - scope = {} - exec(code, globals(), scope) - cu = torch.jit.CompilationUnit(code) - - inputs = get_lstm_inputs('cuda', training=False) - self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs)) - forward_graph = cu.cell.graph_for(*inputs) - self.assertGraphContainsExactly(forward_graph, FUSION_GROUP, 1) + def test_lstm(self): + for device in self.devices: + inputs = get_lstm_inputs(device, training=True) + module = self.checkScript(LSTMCellS, inputs) + self.assertLastGraphAllFused() + + def test_lstm_concat(self): + # single fusion node causes error + with set_fusion_group_inlining(True): + for device in self.devices: + inputs = get_lstm_inputs(device) + ge = self.checkTrace(LSTMCellC, inputs) + graph = ge.graph_for(*inputs) + self.assertLastGraphAllFused() + # XXX: TE fuser can handle concats inside a fusion group. + # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) + + def test_lstm_gates_permutations(self): + for device in self.devices: + # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh. + # Test that any permutation of this will still result in one FusionGroup. + choices = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh'] + template = dedent(''' + def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh): + gates = {} + {} + {} + {} + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + return ingate * forgetgate * cellgate * outgate + ''') + for permutation in permutations(choices, len(choices)): + code = template.format(*permutation) + scope = {} + exec(code, globals(), scope) + cu = torch.jit.CompilationUnit(code) + + inputs = get_lstm_inputs(device, training=False) + self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs)) + forward_graph = cu.cell.graph_for(*inputs) + self.assertGraphContainsExactly(forward_graph, FUSION_GROUP, 1) # TODO: Fuser doesn't work at all when inputs require grad. Fix that - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_lstm_traced_cuda(self): - inputs = get_lstm_inputs('cuda') - ge = self.checkTrace(LSTMCellF, inputs) - graph = ge.graph_for(*inputs) - fusion_groups = self.findFusionGroups(graph) - self.assertEqual(len(fusion_groups), 1) - FileCheck().check("Chunk").check("aten::sigmoid").check("aten::tanh").run(str(fusion_groups[0])) - - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") - @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746") - def test_lstm_traced_cpu(self): - inputs = get_lstm_inputs('cpu') - try: + def test_lstm_traced(self): + for device in self.devices: + inputs = get_lstm_inputs(device) ge = self.checkTrace(LSTMCellF, inputs) graph = ge.graph_for(*inputs) - FileCheck.check("FusionGroup").run(str(graph)) - except RuntimeError as e: - if 'Failed to compile' in e.args[0]: - warnings.warn('CPU fuser test has failed! This is not a hard failure, ' - 'because the kernels sometimes trigger bugs in compilers ' - '(most notably GCC 7.2).') - raise unittest.SkipTest('Failed to compile') from e - else: - raise + fusion_groups = self.findFusionGroups(graph) + self.assertEqual(len(fusion_groups), 1) + FileCheck().check("Chunk").check("aten::sigmoid").check("aten::tanh").run(str(fusion_groups[0])) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_milstm_cuda(self): - inputs = get_milstm_inputs('cuda', training=True) - module = self.checkScript(MiLSTMCell, inputs) - forward_graph = module.graph_for(*inputs) - self.assertGraphContainsExactly( - forward_graph, FUSION_GROUP, 1, consider_subgraphs=True) - FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \ - .check_next("return").check(FUSION_GROUP).run(str(forward_graph)) - hy, cy = module(*inputs) - warmup_backward((hy + cy).sum()) + def test_milstm(self): + for device in self.devices: + inputs = get_milstm_inputs(device, training=True) + module = self.checkScript(MiLSTMCell, inputs) + forward_graph = module.graph_for(*inputs) + self.assertGraphContainsExactly( + forward_graph, FUSION_GROUP, 1, consider_subgraphs=True) + FileCheck().check("DifferentiableGraph").check("TupleConstruct") \ + .check_next("return").check(FUSION_GROUP).run(str(forward_graph)) + hy, cy = module(*inputs) + warmup_backward((hy + cy).sum()) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skip("rand_like is not supported yet") @@ -1050,26 +1000,26 @@ def create(self, x): def fn_test_relu(x, y): return F.relu(x + .5 * y) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_relu_cuda(self): - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') + def test_relu(self): + for device in self.devices: + x = torch.randn(4, 4, dtype=torch.float, device=device) + y = torch.randn(4, 4, dtype=torch.float, device=device) - ge = self.checkTrace(self.fn_test_relu, (x, y)) - self.assertAllFused(ge.graph_for(x, y)) + ge = self.checkTrace(self.fn_test_relu, (x, y)) + self.assertAllFused(ge.graph_for(x, y)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_erf_cuda(self): - def fn_test_erf(x): - return F.relu(torch.erf(x) - torch.erfc(x)) + def test_erf(self): + for device in self.devices: + def fn_test_erf(x): + return F.relu(torch.erf(x) - torch.erfc(x)) - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - ge = self.checkTrace(fn_test_erf, (x,)) - self.assertAllFused(ge.graph_for(x)) - x.requires_grad_(True) - ge = self.checkTrace(fn_test_erf, (x,)) - self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes", - "aten::_size_if_not_equal")) + x = torch.randn(4, 4, dtype=torch.float, device=device) + ge = self.checkTrace(fn_test_erf, (x,)) + self.assertAllFused(ge.graph_for(x)) + x.requires_grad_(True) + ge = self.checkTrace(fn_test_erf, (x,)) + self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes", + "aten::_size_if_not_equal")) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skip("rand_like is not supported yet") @@ -1132,68 +1082,68 @@ def fn(x, y): ge = self.checkScript(fn, (x, y)) self.assertAllFused(ge.graph_for(x, y)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_small_constant_cuda(self): - def fn_test_small_constant(x, y): - return (1e-8 * x + 5e-9 * y) * 1e8 - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') + def test_small_constant(self): + for device in self.devices: + def fn_test_small_constant(x, y): + return (1e-8 * x + 5e-9 * y) * 1e8 + x = torch.randn(4, 4, dtype=torch.float, device=device) + y = torch.randn(4, 4, dtype=torch.float, device=device) - ge = self.checkTrace(fn_test_small_constant, (x, y)) - self.assertAllFused(ge.graph_for(x, y)) + ge = self.checkTrace(fn_test_small_constant, (x, y)) + self.assertAllFused(ge.graph_for(x, y)) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") # Currently we don't pull constants into fusion groups, because in some # cases it could remove the constant from the original graph and now our # fusion group needs to return that constant for its other users. # Instead of never pulling constants into the fusion group, we should just # be more careful at how we rewrite its users. # TODO: fix that and reenable the test. - def test_tensor_scalar_ops_cuda(self): - def should_fuse(x): - z = 3. - y = x + z - return x * y - - def should_fuse_scalar(x, z): - y = x + int(z) - return x * y - - inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')] - ge = self.checkScript(should_fuse, inputs) - graph = ge.graph_for(*inputs) - fusion_groups = self.findFusionGroups(graph) - self.assertEqual(len(fusion_groups), 1) - FileCheck().check("aten::add").check("aten::mul").run(str(fusion_groups[0])) - - inputs = [ - torch.randn(2, 2, dtype=torch.float, device='cuda'), - torch.tensor(3., dtype=torch.float, device='cuda'), - ] - ge = self.checkScript(should_fuse_scalar, inputs) - # Check that the fused graph computes correct results when the scalar - # input changes. - inputs = [ - torch.randn(2, 2, dtype=torch.float, device='cuda'), - torch.tensor(7., dtype=torch.float, device='cuda'), - ] - self.assertEqual(ge(*inputs), should_fuse_scalar(*inputs)) - # The TE fuser supports fusion of non-constant scalars - self.assertGraphContainsExactly( - ge.graph_for(*inputs), FUSION_GROUP, 1, consider_subgraphs=True) + def test_tensor_scalar_ops(self): + for device in self.devices: + def should_fuse(x): + z = 3. + y = x + z + return x * y + + def should_fuse_scalar(x, z): + y = x + int(z) + return x * y + + inputs = [torch.randn(2, 2, dtype=torch.float, device=device)] + ge = self.checkScript(should_fuse, inputs) + graph = ge.graph_for(*inputs) + fusion_groups = self.findFusionGroups(graph) + self.assertEqual(len(fusion_groups), 1) + FileCheck().check("aten::add").check("aten::mul").run(str(fusion_groups[0])) + + inputs = [ + torch.randn(2, 2, dtype=torch.float, device=device), + torch.tensor(3., dtype=torch.float, device=device), + ] + ge = self.checkScript(should_fuse_scalar, inputs) + # Check that the fused graph computes correct results when the scalar + # input changes. + inputs = [ + torch.randn(2, 2, dtype=torch.float, device=device), + torch.tensor(7., dtype=torch.float, device=device), + ] + self.assertEqual(ge(*inputs), should_fuse_scalar(*inputs)) + # The TE fuser supports fusion of non-constant scalars + self.assertGraphContainsExactly( + ge.graph_for(*inputs), FUSION_GROUP, 1, consider_subgraphs=True) - @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") def test_where_and_typing(self): - def f(x, y): - mask = x > y - res = torch.where(mask, x, y) - return mask, res + for device in self.devices: + def f(x, y): + mask = x > y + res = torch.where(mask, x, y) + return mask, res - x = torch.randn(4, 4, dtype=torch.double) - y = torch.randn(4, 4, dtype=torch.double) + x = torch.randn(4, 4, dtype=torch.double, device=device) + y = torch.randn(4, 4, dtype=torch.double, device=device) - script_f = self.checkScript(f, (x, y)) - self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'}) + script_f = self.checkScript(f, (x, y)) + self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'}) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on") @@ -1248,8 +1198,11 @@ def fn(a): torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuser_state) - def data_for(self, dtype, device="cuda"): - v = torch.arange(1, 3, dtype=torch.float, device=device) + def data_for(self, dtype, device="cuda", size=None): + if size is None: + v = torch.arange(1, 3, dtype=torch.float, device=device) + else: + v = torch.rand(*size, device=device) if dtype == torch.bool: return v > 2 elif dtype in [torch.qint8, torch.quint8, torch.qint32]: @@ -1257,29 +1210,155 @@ def data_for(self, dtype, device="cuda"): else: return v.to(dtype) - @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_unary_ops(self): - def apply(fn): - return lambda x: fn(2 * x) + def test_torch_to(self): + # test no op + @torch.jit.script + def foo(x): + return x.to(torch.float) + + foo(torch.tensor([3.], dtype=torch.float)) + foo(torch.tensor([3.], dtype=torch.float)) + FileCheck().check_not("TensorExpr").run(torch.jit.last_executed_optimized_graph()) + + # test not fusing non-const inputs + @torch.jit.script + def foo(x, dtype: int): + return x.to(dtype) + foo(torch.tensor([3.], dtype=torch.float), torch.int) + foo(torch.tensor([3.], dtype=torch.float), torch.int) + FileCheck().check_not("TensorExpr").run(torch.jit.last_executed_optimized_graph()) + + # test not fusing to_pinned inputs + @torch.jit.script + def foo(x, dtype: int): + return x.to(pin_memory=True) + + foo(torch.tensor([3.], dtype=torch.float), torch.int) + foo(torch.tensor([3.], dtype=torch.float), torch.int) + FileCheck().check_not("TensorExpr").run(torch.jit.last_executed_optimized_graph()) + + + # test across-device not supported + if torch.cuda.is_available(): + @torch.jit.script + def foo(x): + return x.to(device="cuda") + + foo(torch.tensor([3.], dtype=torch.float)) + foo(torch.tensor([3.], dtype=torch.float)) + FileCheck().check_not("TensorExpr").run(torch.jit.last_executed_optimized_graph()) + + sizes = [(1, 4), (4, 4)] + # reuses cast impl, smaller dtype set for faster test + dtypes = [ + torch.bool, + torch.int, + torch.float16, + torch.float32, + torch.float64, + ] + + class MyMod(torch.nn.Module): + def __init__(self, dtype): + super(MyMod, self).__init__() + self.dtype = dtype + + def forward(self, x): + return x.to(self.dtype) + + bad_dtypes = [] + for dtype, output_dtype, device, size in product(dtypes, dtypes, self.devices, sizes): + if dtype == output_dtype: + continue + + x = self.data_for(dtype, device, size=size) + mod = MyMod(output_dtype) + ref = mod.forward(x) + # use freezing to make non-Tensor args to `to` constant + mod = torch.jit.freeze(torch.jit.script(mod.eval())) + warmup_forward(mod.forward, x) + self.assertEqual(ref, mod.forward(x)) + self.assertLastGraphAllFused() + + @unittest.skip("Temporarily disabled") + def test_masked_fill(self): dtypes = [ torch.int8, - torch.uint8, torch.int16, torch.int32, torch.int64, - # torch.float16, + torch.float16, + torch.float32, + torch.float64, + torch.bool, + ] + sizes = [(2,), (4, 4)] + for self_dtype, device, scalar_val, size in product(dtypes, self.devices, [0.4, 3], sizes): + input_v = self.data_for(self_dtype, device, size=size) + mask = self.data_for(torch.bool, device, size=size) + + def fn(input_v, mask): + return torch.masked_fill(input_v, mask, scalar_val) + ref = fn(input_v, mask) + try: + t = torch.jit.trace(fn, (input_v, mask)) + torch.testing.assert_allclose(ref, t(input_v, mask)) + print(torch.jit.last_executed_optimized_graph()) + self.assertLastGraphAllFused() + except Exception as e: + raise RuntimeError( + " ".join(["Failed:", str(self_dtype), op.__name__, device, str(size)]) + ) + + def test_isnan(self): + x = torch.rand([4]) + x[0] = float('nan') + inputs = [ + x, + torch.tensor([float('nan'), .5]) + ] + dtypes = [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.float16, torch.float32, torch.float64, torch.bool, ] + + for inp, device, dtype in product(inputs, self.devices, dtypes): + # TODO + if dtype == torch.float16 and not LLVM_ENABLED: + continue + + inp = inp.to(device=device, dtype=dtype) + try: + f = torch.jit.trace(lambda x: x.isnan(), (inp,)) + warmup_forward(f, inp) + self.assertEqual(f(inp), inp.isnan()) + self.assertLastGraphAllFused() + except Exception as e: + raise RuntimeError( + " ".join(["Failed:", str(dtype), 'isnan', device]) + ) + + # @unittest.skipIf(not LLVM_ENABLED, "TODO: bugs in ir eval") + def test_unary_ops(self): + def apply(fn): + return lambda x: fn(x) + unary_ops = [ + torch.lgamma, torch.sigmoid, torch.reciprocal, torch.neg, torch.relu, torch.log, torch.log10, + torch.log1p, torch.log2, torch.exp, torch.expm1, @@ -1302,11 +1381,13 @@ def apply(fn): torch.round, torch.trunc, torch.frac, + lambda x: torch.threshold(x, 0, -10), + lambda x: torch.clamp(x, -10, 10), ] - devices = ["cuda"] - for dtype, op, device in product(dtypes, unary_ops, devices): + sizes = [(1,), (2,), (4, 4)] + for dtype, op, device, size in product(self.dtypes, unary_ops, self.devices, sizes): try: - x = self.data_for(dtype, device) + x = self.data_for(dtype, device, size=size) fn = apply(op) ref = fn(x) except Exception: @@ -1318,6 +1399,263 @@ def apply(fn): t = torch.jit.trace(fn, (x,)) torch.testing.assert_allclose(ref, t(x)) self.assertAllFused(t.graph_for(x)) + except Exception as e: + raise RuntimeError( + " ".join(["Failed:", str(dtype), op.__name__, device, str(size)]) + ) + + def test_binary_ops(self): + def apply(fn): + return lambda x, y: fn(x, y) + + binary_ops = [ + operator.__and__, + operator.__or__, + operator.__xor__, + torch.add, + torch.sub, + torch.mul, + torch.min, + torch.max, + lambda x, y: torch.lerp(x, y, 0.5), + torch.atan2, + torch.div, + torch.eq, + torch.ne, + torch.ge, + torch.gt, + torch.lt, + torch.fmod, + torch.remainder, + lambda x, y: y.type_as(x), + ] + fp_only = [ + torch.fmod, + torch.remainder, + ] + devices = self.devices + for dtype, op, device in product(self.dtypes, binary_ops, devices): + try: + x = self.data_for(dtype, device) + y = self.data_for(dtype, device) + fn = apply(op) + ref = fn(x, y) + except Exception: + # If eager mode doesn't support a dtype/op/device combo, + # neither does the fuser. Catch everything to avoid needing to + # guess what errors might be thrown by eager. + continue + try: + t = torch.jit.trace(fn, (x, y)) + self.assertEqual(ref, t(x, y)) + if op not in fp_only or dtype.is_floating_point: + self.assertAllFused(t.graph_for(x, y)) + except Exception as e: + raise RuntimeError( + " ".join(["Failed:", str(dtype), op.__name__, device]) + ) + + @unittest.skipIf(not LLVM_ENABLED, "TODO: bugs in ir eval") + def test_binary_tensor_scalar_ops(self): + def apply_with_scalar(fn, scalar): + return lambda x: fn(x, scalar) + + # FIXME: Fails in IR Eval: torch.int64 and_ cpu + binary_ops = [ + operator.__and__, + operator.__or__, + operator.__xor__, + torch.add, + torch.sub, + torch.mul, + torch.eq, + torch.ne, + torch.ge, + torch.lt, + torch.gt, + ] + devices = self.devices + # Maybe we should split this into separate tests to speed it up by + # only using scalar values relevant to particular ops + scalars = [1.5, 3, 0, -2.0, -1] + for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars): + try: + x = self.data_for(dtype, device) + fn = apply_with_scalar(op, scalar) + ref = fn(x) + except Exception: + # If eager mode doesn't support a dtype/op/device combo, + # neither does the fuser. Catch everything to avoid needing to + # guess what errors might be thrown by eager. + continue + try: + t = torch.jit.trace(fn, (x)) + self.assertEqual(ref, t(x)) + self.assertAllFused(t.graph_for(x)) + except Exception as e: + raise RuntimeError( + " ".join(["Failed:", str(dtype), op.__name__, device]) + ) + + def test_binary_div_ops(self): + def apply_with_scalar(fn, scalar): + return lambda x: fn(x, scalar) + + binary_ops = [ + torch.div, + torch.remainder, + torch.fmod, + ] + devices = self.devices + # Maybe we should split this into separate tests to speed it up by + # only using scalar values relevant to particular ops + scalars = [1.5, 3, -2.0, -1] # skip 0 + for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars): + try: + x = self.data_for(dtype, device) + fn = apply_with_scalar(op, scalar) + ref = fn(x) + except Exception: + # If eager mode doesn't support a dtype/op/device combo, + # neither does the fuser. Catch everything to avoid needing to + # guess what errors might be thrown by eager. + continue + try: + t = torch.jit.trace(fn, (x)) + self.assertEqual(ref, t(x)) + except Exception as e: + raise RuntimeError( + "Failed: {} {} {} {}".format(dtype, op.__name__, device, scalar) + ) + + def test_binary_cuda_only_ops(self): + def apply_with_scalar(fn, scalar): + return lambda x: fn(x, scalar) + + dtypes = [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + # FIXME: 'pow' fails with dtype=torch.float16/device=cuda/scalar=0 + # torch.float16, + torch.float32, + torch.float64, + # torch.bool intentionally not included + ] + binary_ops = [ + torch.pow, + ] + devices = ['cuda'] if torch.cuda.is_available() else [] + # Maybe we should split this into separate tests to speed it up by + # only using scalar values relevant to particular ops + scalars = [1.5, 3, 0, -2.0, -1] + for dtype, op, device, scalar in product(dtypes, binary_ops, devices, scalars): + try: + x = self.data_for(dtype, device) + fn = apply_with_scalar(op, scalar) + ref = fn(x) + except Exception: + # If eager mode doesn't support a dtype/op/device combo, + # neither does the fuser. Catch everything to avoid needing to + # guess what errors might be thrown by eager. + continue + try: + t = torch.jit.trace(fn, (x)) + self.assertEqual(ref, t(x)) + self.assertAllFused(t.graph_for(x)) + except Exception as e: + raise RuntimeError( + " ".join(["Failed:", str(dtype), op.__name__, device]) + ) + + @unittest.skipIf(not LLVM_ENABLED, "TODO: enable in ir eval") + def test_ternary_ops(self): + def apply(fn): + return lambda x, y, z: fn(x, y, z) + + ternary_ops = [ + torch.lerp, + torch.addcmul, + ] + devices = self.devices + for dtype, op, device in product(self.dtypes, ternary_ops, devices): + try: + x = self.data_for(dtype, device) + y = self.data_for(dtype, device) + z = self.data_for(dtype, device) + fn = apply(op) + ref = fn(x, y, z) + except Exception: + # If eager mode doesn't support a dtype/op/device combo, + # neither does the fuser. Catch everything to avoid needing to + # guess what errors might be thrown by eager. + continue + try: + t = torch.jit.trace(fn, (x, y, z)) + self.assertEqual(ref, t(x, y, z)) + self.assertAllFused(t.graph_for(x, y, z)) + except Exception as e: + raise RuntimeError( + " ".join(["Failed:", str(dtype), op.__name__, device]) + ) + + @unittest.skip("FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure") + def test_list_ops(self): + def apply(fn): + return lambda x, y, z: fn([x * x, y * y, z * z]) + + devices = self.devices + list_ops = [ + torch.cat, + ] + for dtype, op, device in product(self.dtypes, list_ops, devices): + try: + x = self.data_for(dtype, device, size=[5, 4, 1, 7]) + y = self.data_for(dtype, device, size=[5, 4, 1, 7]) + z = self.data_for(dtype, device, size=[5, 4, 1, 7]) + fn = apply(op) + ref = fn(x, y, z) + except Exception: + # If eager mode doesn't support a dtype/op/device combo, + # neither does the fuser. Catch everything to avoid needing to + # guess what errors might be thrown by eager. + continue + try: + t = torch.jit.trace(fn, (x, y, z)) + self.assertEqual(ref, t(x, y, z)) + self.assertAllFused(t.graph_for(x, y, z)) + except Exception as e: + raise RuntimeError( + " ".join(["Failed:", str(dtype), op.__name__, device]) + ) + + def test_where_ops(self): + def apply(fn): + return lambda cond, x, y: fn(cond, x, y) + + ops = [ + torch.where, + lambda cond, x, y: torch.where(cond, x, 3.1415), + lambda cond, x, y: torch.where(cond, 42, y), + ] + devices = self.devices + for dtype, op, device in product(self.dtypes, ops, devices): + try: + cond = self.data_for(torch.bool, device) + x = self.data_for(dtype, device) + y = self.data_for(dtype, device) + fn = apply(op) + ref = fn(cond, x, y) + except Exception: + # If eager mode doesn't support a dtype/op/device combo, + # neither does the fuser. Catch everything to avoid needing to + # guess what errors might be thrown by eager. + continue + try: + t = torch.jit.trace(fn, (cond, x, y)) + self.assertEqual(ref, t(cond, x, y)) + self.assertAllFused(t.graph_for(cond, x, y)) except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) @@ -1329,6 +1667,7 @@ def fn(x): return x * x + x unsupported_dtypes = [ + torch.uint8, torch.bfloat16, torch.complex32, torch.complex64, @@ -1380,6 +1719,62 @@ def eager(t0, t1, t2, t3, t4): torch.testing.assert_allclose(test, ref) self.assertAllFused(script.graph_for(*inputs)) + @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") + def test_sub_gt_and(self): + def eager(t1, t2, t3, t4, t: float): + w = t1 - t2 + h = t3 - t4 + k = (w > t) & (h > t) + assert k.dtype == torch.bool + if t > 0.5: + # Putting a use of k in a never-executed conditional prevents + # profiling its type, which leaves it as "Tensor". If we + # propagate Tensor back to the definition of k, we have to be + # careful not to create a fusion group containing it. + return k + 1 + return w + t = torch.rand(8, dtype=torch.float, device='cuda') + scripted = self.checkScript(eager, (t, t, t, t, 0.1)) + + def test_chunk_mul_one(self): + for device in self.devices: + def eager(x): + z, y, w = torch.chunk(x, 3, -1) + return z * 3, y, w + x = torch.rand(64, 1, 3072, dtype=torch.float, device=device) + z, y, w = eager(x) + script = self.checkScript(eager, (x,)) + + @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") + def test_eq_unsqueeze_type_as(self): + def eager(a, b): + mask = b == 1 + mask = torch.unsqueeze(mask, -1) + x = mask.type_as(a) + return x, mask + a = torch.rand(1, 64, 1024, device='cuda', dtype=torch.float) + b = torch.randint(-2, 2, (1, 64), device='cuda', dtype=torch.long) + script = self.checkScript(eager, (a, b)) + + def test_neg_pow(self): + def eager_tt(a: torch.Tensor, b: torch.Tensor): + return torch.neg(torch.pow(a, b)) + + def eager_ts(a: torch.Tensor, b: float): + return torch.neg(torch.pow(a, b)) + + def eager_st(a: float, b: torch.Tensor): + return torch.neg(torch.pow(a, b)) + + a = torch.rand(1, dtype=torch.float) + b = torch.rand(1, dtype=torch.float) + s = b.item() + script = self.checkScript(eager_tt, (a, b)) + self.assertAllFused(script.graph_for(a, b)) + script = self.checkScript(eager_ts, (a, s)) + self.assertAllFused(script.graph_for(a, s)) + script = self.checkScript(eager_st, (s, b)) + self.assertAllFused(script.graph_for(s, b)) if __name__ == '__main__': run_tests() diff --git a/test/test_jit_legacy.py b/test/test_jit_legacy.py index 2422e518a7f9a..b17908e910bb6 100644 --- a/test/test_jit_legacy.py +++ b/test/test_jit_legacy.py @@ -1,5 +1,5 @@ import sys -sys.argv.append("--ge_config=legacy") +sys.argv.append("--jit_executor=legacy") from test_jit import * if __name__ == '__main__': diff --git a/test/test_jit_profiling.py b/test/test_jit_profiling.py index be02985e69a80..1cf67f87ded9b 100644 --- a/test/test_jit_profiling.py +++ b/test/test_jit_profiling.py @@ -1,10 +1,9 @@ import sys -sys.argv.append("--ge_config=profiling") +sys.argv.append("--jit_executor=profiling") from test_jit import * if __name__ == '__main__': run_tests() - if not PY2: - import test_jit_py3 - suite = unittest.findTestCases(test_jit_py3) - unittest.TextTestRunner().run(suite) + import test_jit_py3 + suite = unittest.findTestCases(test_jit_py3) + unittest.TextTestRunner().run(suite) diff --git a/test/test_jit_py3.py b/test/test_jit_py3.py index 4de5db8840353..e8694fd91aabd 100644 --- a/test/test_jit_py3.py +++ b/test/test_jit_py3.py @@ -1,10 +1,11 @@ from collections import namedtuple +from typing import Any, Dict, List, NamedTuple, Optional, Tuple + from torch.testing._internal.common_utils import run_tests from torch.testing._internal.jit_utils import JitTestCase from torch.testing import FileCheck from torch import jit from textwrap import dedent -from typing import NamedTuple, List, Optional, Dict, Tuple, Any from jit.test_module_interface import TestModuleInterface # noqa: F401 import inspect import unittest @@ -53,7 +54,7 @@ def compute(self, total_rows): def fn(): return NormalizationInfo(1, 2, 3, 4, 5) - with self.assertRaisesRegex(OSError, "NormalizationInfo"): + with self.assertRaisesRegex(OSError, "could not get source code"): torch.jit.script(fn) def test_optional_dict_construct(self): @@ -189,8 +190,7 @@ def __init__(self): super().__init__() @torch.jit.ignore - def foo(self, x, z): - # type: (Tensor, Tensor) -> Tuple[GG, GG] + def foo(self, x: torch.Tensor, z: torch.Tensor) -> Tuple[GG, GG]: return GG(x, z), GG(x, z) def forward(self, x, z): @@ -412,8 +412,7 @@ def test_optional_no_element_type_annotation(self): """ Test that using an optional with no contained types produces an error. """ - def fn_with_comment(x): - # type: (torch.Tensor) -> Optional + def fn_with_comment(x: torch.Tensor) -> Optional: return (x, x) def annotated_fn(x: torch.Tensor) -> Optional: @@ -437,8 +436,7 @@ def test_tuple_no_element_type_annotation(self): """ Test that using a tuple with no contained types produces an error. """ - def fn_with_comment(x): - # type: (torch.Tensor) -> Tuple + def fn_with_comment(x: torch.Tensor) -> Tuple: return (x, x) def annotated_fn(x: torch.Tensor) -> Tuple: @@ -555,7 +553,7 @@ def test_reannotate(self): @torch.jit.script def foo(): x = 5 - if True: + if 1 == 1: x : Optional[int] = 7 def test_module_inplace_construct(self): @@ -621,7 +619,7 @@ def if_function(inp: torch.Tensor) -> Any: def test_module_properties(self): class ModuleWithProperties(torch.nn.Module): - __ignored_properties__ = ["ignored_attr"] + __jit_unused_properties__ = ["ignored_attr"] def __init__(self, a: int): super().__init__() @@ -639,6 +637,15 @@ def attr(self): def ignored_attr(self): return sum([self.a]) + @torch.jit.unused + @property + def ignored_attr_2(self): + return sum([self.a]) + + @ignored_attr_2.setter + def ignored_attr_2(self, value): + self.a = sum([self.a]) + @attr.setter def attr(self, a: int): if a > 0: @@ -666,50 +673,91 @@ def attr(self): mod = ModuleWithProperties(3) scripted_mod = torch.jit.script(mod) - with self.assertRaisesRegex(torch.nn.modules.module.ModuleAttributeError, "has no attribute"): + with self.assertRaisesRegex(AttributeError, "has no attribute"): scripted_mod.ignored_attr + def test_ignoring_module_attributes(self): + """ + Test that module attributes can be ignored. + """ + class Sub(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a: int) -> int: + return sum([a]) + + class ModuleWithIgnoredAttr(torch.nn.Module): + __jit_ignored_attributes__ = ["a", "sub"] + + def __init__(self, a: int, b: int): + super().__init__() + self.a = a + self.b = b + self.sub = Sub() + + def forward(self) -> int: + return self.b + + @torch.jit.ignore + def ignored_fn(self) -> int: + return self.sub.forward(self.a) + + mod = ModuleWithIgnoredAttr(1, 4) + scripted_mod = torch.jit.script(mod) + self.assertEqual(scripted_mod(), 4) + self.assertEqual(scripted_mod.ignored_fn(), 1) + + # Test the error message for ignored attributes. + class ModuleUsesIgnoredAttr(torch.nn.Module): + __jit_ignored_attributes__ = ["a", "sub"] + + def __init__(self, a: int): + super().__init__() + self.a = a + self.sub = Sub() + + def forward(self) -> int: + return self.sub(self.b) + + mod = ModuleUsesIgnoredAttr(1) + + with self.assertRaisesRegexWithHighlight(RuntimeError, r"attribute was ignored during compilation", "self.sub"): + scripted_mod = torch.jit.script(mod) + + def test_export_opnames_interface(self): global OneTwoModule @torch.jit.interface class OneTwoModule(nn.Module): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: pass - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: torch.Tensor) -> torch.Tensor: pass - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: torch.Tensor) -> torch.Tensor: pass class FooMod(nn.Module): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + y - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: torch.Tensor) -> torch.Tensor: return 2 * x - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.one(self.two(x), x) class BarMod(nn.Module): - def one(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x * y - def two(self, x): - # type: (Tensor) -> Tensor + def two(self, x: torch.Tensor) -> torch.Tensor: return 2 / x - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.two(self.one(x, x)) class M(nn.Module): @@ -719,8 +767,7 @@ def __init__(self): super(M, self).__init__() self.sub = BarMod() - def forward(self, x): - # type: (Tensor) -> Tensor + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.sub.forward(x) def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor): @@ -738,6 +785,22 @@ def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor): # self.assertTrue(set(['aten::add.Tensor', 'aten::mul.Scalar']).issubset( # set(torch.jit.export_opnames(scripted_M_mod)))) + def test_broadcasting_list(self): + """ + Test BroadcastingList and torch.nn._size_N_t alias + """ + from torch._jit_internal import BroadcastingList2 + from torch.nn.common_types import _size_2_t + + def sum_i(x: _size_2_t) -> int: + return x[0] + x[1] + + def sum_f(x: BroadcastingList2[float]) -> float: + return x[0] + x[1] + + self.assertTrue(torch.jit.script(sum_i)(4) == 8) + self.assertTrue(torch.jit.script(sum_f)(4.5) == 9.) + if __name__ == '__main__': run_tests() diff --git a/test/test_jit_simple.py b/test/test_jit_simple.py index 910e4a17713de..23c7f3b4b6f6f 100644 --- a/test/test_jit_simple.py +++ b/test/test_jit_simple.py @@ -1,10 +1,9 @@ import sys -sys.argv.append("--ge_config=simple") +sys.argv.append("--jit_executor=simple") from test_jit import * if __name__ == '__main__': run_tests() - if not PY2: - import test_jit_py3 - suite = unittest.findTestCases(test_jit_py3) - unittest.TextTestRunner().run(suite) + import test_jit_py3 + suite = unittest.findTestCases(test_jit_py3) + unittest.TextTestRunner().run(suite) diff --git a/test/test_jit_string.py b/test/test_jit_string.py index c0f466688a72a..383eb0bf8353c 100644 --- a/test/test_jit_string.py +++ b/test/test_jit_string.py @@ -1,28 +1,26 @@ from test_jit import JitTestCase from torch.testing._internal.common_utils import run_tests +from typing import List, Tuple + class TestScript(JitTestCase): def test_str_ops(self): - def test_str_is(s): - # type: (str) -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool] + def test_str_is(s: str) -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]: return s.isupper(), s.islower(), s.isdigit(), s.isspace(), \ s.isalnum(), s.isalpha(), s.isdecimal(), s.isnumeric(), \ s.isidentifier(), s.istitle(), s.isprintable() - def test_str_to(s): - # type: (str) -> Tuple[str, str, str, str, str] + def test_str_to(s: str) -> Tuple[str, str, str, str, str]: return s.upper(), s.lower(), s.capitalize(), s.title(), s.swapcase() - def test_str_strip(s): - # type: (str) -> Tuple[str, str, str] + def test_str_strip(s: str) -> Tuple[str, str, str]: return ( s.lstrip(), s.rstrip(), s.strip(), ) - def test_str_strip_char_set(s, char_set): - # type: (str, str) -> Tuple[str, str, str] + def test_str_strip_char_set(s: str, char_set: str) -> Tuple[str, str, str]: return ( s.lstrip(char_set), s.rstrip(char_set), @@ -34,44 +32,34 @@ def test_str_strip_char_set(s, char_set): "more strings with spaces", "Titular Strings", "\x0acan'tprintthis", "spaces at the end ", " begin"] - def test_str_center(i, s): - # type: (int, str) -> str + def test_str_center(i: int, s: str) -> str: return s.center(i) - def test_str_center_fc(i, s): - # type: (int, str) -> str + def test_str_center_fc(i: int, s: str) -> str: return s.center(i, '*') - def test_str_center_error(s): - # type: (str) -> str + def test_str_center_error(s: str) -> str: return s.center(10, '**') - def test_ljust(s, i): - # type: (str, int) -> str + def test_ljust(s: str, i: int) -> str: return s.ljust(i) - def test_ljust_fc(s, i, fc): - # type: (str, int, str) -> str + def test_ljust_fc(s: str, i: int, fc: str) -> str: return s.ljust(i, fc) - def test_ljust_fc_err(s): - # type: (str) -> str + def test_ljust_fc_err(s: str) -> str: return s.ljust(10, '**') - def test_rjust(s, i): - # type: (str, int) -> str + def test_rjust(s: str, i: int) -> str: return s.rjust(i) - def test_rjust_fc(s, i, fc): - # type: (str, int, str) -> str + def test_rjust_fc(s: str, i: int, fc: str) -> str: return s.rjust(i, fc) - def test_rjust_fc_err(s): - # type: (str) -> str + def test_rjust_fc_err(s: str) -> str: return s.rjust(10, '**') - def test_zfill(s, i): - # type: (str, int) -> str + def test_zfill(s: str, i: int) -> str: return s.zfill(i) for input in inputs: @@ -93,8 +81,7 @@ def test_zfill(s, i): test_str_center_error("error") test_ljust("error") - def test_count(): - # type: () -> Tuple[int, int, int, int, int, int, int, int, int, int, int, int] + def test_count() -> Tuple[int, int, int, int, int, int, int, int, int, int, int, int]: return ( "hello".count("h"), "hello".count("h", 0, 1), @@ -111,8 +98,7 @@ def test_count(): ) self.checkScript(test_count, ()) - def test_endswith(): - # type: () -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool] + def test_endswith() -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]: return ( "hello".endswith("lo"), "hello".endswith("lo", 0), @@ -131,8 +117,7 @@ def test_endswith(): ) self.checkScript(test_endswith, ()) - def test_startswith(): - # type: () -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool] + def test_startswith() -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]: return ( "hello".startswith("lo"), "hello".startswith("lo", 0), @@ -151,8 +136,7 @@ def test_startswith(): ) self.checkScript(test_startswith, ()) - def test_expandtabs(): - # type: () -> Tuple[str, str, str, str, str, str] + def test_expandtabs() -> Tuple[str, str, str, str, str, str]: return ( 'xyz\t82345\tabc'.expandtabs(), 'xyz\t32345\tabc'.expandtabs(3), @@ -163,8 +147,7 @@ def test_expandtabs(): ) self.checkScript(test_expandtabs, ()) - def test_rfind(): - # type: () -> Tuple[int, int, int, int, int, int, int, int, int] + def test_rfind() -> Tuple[int, int, int, int, int, int, int, int, int]: return ( "hello123abc".rfind("llo"), "hello123abc".rfind("12"), @@ -178,8 +161,7 @@ def test_rfind(): ) self.checkScript(test_rfind, ()) - def test_find(): - # type: () -> Tuple[int, int, int, int, int, int, int, int, int] + def test_find() -> Tuple[int, int, int, int, int, int, int, int, int]: return ( "hello123abc".find("llo"), "hello123abc".find("12"), @@ -193,8 +175,7 @@ def test_find(): ) self.checkScript(test_find, ()) - def test_index(): - # type: () -> Tuple[int, int, int, int, int, int] + def test_index() -> Tuple[int, int, int, int, int, int]: return ( "hello123abc".index("llo"), "hello123abc".index("12"), @@ -205,8 +186,7 @@ def test_index(): ) self.checkScript(test_index, ()) - def test_rindex(): - # type: () -> Tuple[int, int, int, int, int, int] + def test_rindex() -> Tuple[int, int, int, int, int, int]: return ( "hello123abc".rindex("llo"), "hello123abc".rindex("12"), @@ -217,8 +197,7 @@ def test_rindex(): ) self.checkScript(test_rindex, ()) - def test_replace(): - # type: () -> Tuple[str, str, str, str, str, str, str] + def test_replace() -> Tuple[str, str, str, str, str, str, str]: return ( "hello123abc".replace("llo", "sdf"), "ff".replace("f", "ff"), @@ -230,11 +209,9 @@ def test_replace(): ) self.checkScript(test_replace, ()) - def test_partition(): - """ - type: () -> Tuple[Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str], - Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str]] - """ + def test_partition() -> Tuple[Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str], + Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str], + Tuple[str, str, str]]: return ( "hello123abc".partition("llo"), "ff".partition("f"), @@ -246,11 +223,9 @@ def test_partition(): ) self.checkScript(test_partition, ()) - def test_rpartition(): - """ - type: () -> Tuple[Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str], - Tuple[str,str,str], Tuple[str,str,str], Tuple[str,str,str]] - """ + def test_rpartition() -> Tuple[Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str], + Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str], + Tuple[str, str, str]]: return ( "hello123abc".rpartition("llo"), "ff".rpartition("f"), @@ -262,11 +237,8 @@ def test_rpartition(): ) self.checkScript(test_rpartition, ()) - def test_split(): - """ - type: () -> Tuple[List[str], List[str], List[str], List[str], List[str], - List[str], List[str], List[str], List[str], List[str], List[str]] - """ + def test_split() -> Tuple[List[str], List[str], List[str], List[str], List[str], + List[str], List[str], List[str], List[str], List[str], List[str]]: return ( "a a a a a".split(), "a a a a a".split(), @@ -290,8 +262,8 @@ def test_split_empty_separator(): self.checkScriptRaisesRegex(test_split_empty_separator, (), Exception, "empty separator") - def test_rsplit(): - # type: () -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str]] + def test_rsplit() -> Tuple[List[str], List[str], List[str], List[str], List[str], + List[str], List[str], List[str], List[str]]: return ( "a a a a a".rsplit(), " a a a a a ".rsplit(" "), @@ -305,8 +277,8 @@ def test_rsplit(): ) self.checkScript(test_rsplit, ()) - def test_splitlines(): - # type: () -> Tuple[ List[str], List[str], List[str], List[str], List[str], List[str] ] + def test_splitlines() -> Tuple[List[str], List[str], List[str], List[str], + List[str], List[str]]: return ( "hello\ntest".splitlines(), "hello\n\ntest\n".splitlines(), @@ -317,8 +289,7 @@ def test_splitlines(): ) self.checkScript(test_splitlines, ()) - def test_str_cmp(a, b): - # type: (str, str) -> Tuple[bool, bool, bool, bool, bool, bool] + def test_str_cmp(a: str, b: str) -> Tuple[bool, bool, bool, bool, bool, bool]: return a != b, a == b, a < b, a > b, a <= b, a >= b for i in range(len(inputs) - 1): diff --git a/test/test_kernel_launch_checks.py b/test/test_kernel_launch_checks.py new file mode 100644 index 0000000000000..698a5cda2a420 --- /dev/null +++ b/test/test_kernel_launch_checks.py @@ -0,0 +1,49 @@ +from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing import check_cuda_kernel_launches, check_code_for_cuda_kernel_launches + + +class AlwaysCheckCudaLaunchTest(TestCase): + def test_check_code(self): + """Verifies that the regex works for a few different situations""" + + # Try some different spacings + self.assertEqual(2, check_code_for_cuda_kernel_launches(""" +some_function_call<<<1,2,0,stream>>>(arg1,arg2,arg3); +C10_CUDA_KERNEL_LAUNCH_CHECK(); +some_function_call<<<1,2,0,stream>>>(arg1,arg2,arg3); + +some_function_call<<<1,2,0,stream>>>(arg1,arg2,arg3); +C10_CUDA_KERNEL_LAUNCH_CHECK(); +some_function_call<<<1,2,0,stream>>>(arg1,arg2,arg3); +some_other_stuff; +some_function_call<<<1,2,0,stream>>>(arg1,arg2,arg3); +C10_CUDA_KERNEL_LAUNCH_CHECK(); +some_function_call<<<1,2,0,stream>>> (arg1,arg2,arg3); +C10_CUDA_KERNEL_LAUNCH_CHECK(); +some_function_call<<<1,2,0,stream>>> ( arg1 , arg2 , arg3 ) ; + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + """)) + + # Does it work for macros? + self.assertEqual(0, check_code_for_cuda_kernel_launches(r""" +#define SOME_MACRO(x) some_function_call<<<1,2>>> ( x ) ; \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); + +#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \ + indexAddSmallIndex \ + <<>>( \ + selfInfo, sourceInfo, indexInfo, \ + selfAddDim, sourceAddDim, sliceSize, selfAddDimSize); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); + """)) + + def test_check_cuda_launches(self): + check_cuda_kernel_launches() + # TODO: Enable this after warning messages have been dealt with. + self.assertTrue(True) + # self.assertTrue(check_cuda_kernel_launches() == 0) + + +if __name__ == '__main__': + run_tests() diff --git a/test/test_linalg.py b/test/test_linalg.py index d3e1905e8d248..85c7f67f25499 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1,24 +1,90 @@ import torch +import numpy as np + +import sys +import subprocess +import os import unittest import itertools import warnings +import math from math import inf, nan, isnan +import random +from random import randrange +from itertools import product +from functools import reduce from torch.testing._internal.common_utils import \ - (TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, TEST_WITH_ASAN) + (TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest, + TEST_WITH_ASAN, make_tensor, TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, + wrapDeterministicFlagAPITest, iter_indices) from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, dtypes, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride) + (instantiate_device_type_tests, dtypes, + onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, + skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyOnCPUAndCUDA, dtypesIfCUDA, + onlyCUDA) +from torch.testing import floating_and_complex_types, floating_types, all_types +from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32 from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args -from torch.autograd import gradcheck +from torch.autograd import gradcheck, gradgradcheck + +# Protects against includes accidentally setting the default dtype +# NOTE: jit_metaprogramming_utils sets the default dtype to double! +torch.set_default_dtype(torch.float32) +assert torch.get_default_dtype() is torch.float32 -if TEST_NUMPY: - import numpy as np +if TEST_SCIPY: + import scipy + +# TODO: make this common and import it +AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32() class TestLinalg(TestCase): exact_dtype = True + @dtypes(torch.float, torch.cfloat) + @precisionOverride({torch.float: 1e-06, torch.cfloat: 1e-06}) + def test_inner(self, device, dtype): + def check(a_sizes_, b_sizes_): + for a_sizes, b_sizes in ((a_sizes_, b_sizes_), (b_sizes_, a_sizes_)): + a = torch.randn(a_sizes, dtype=dtype, device=device) + b = torch.randn(b_sizes, dtype=dtype, device=device) + res = torch.inner(a, b) + ref = np.inner(a.cpu().numpy(), b.cpu().numpy()) + self.assertEqual(res.cpu(), torch.from_numpy(np.array(ref))) + out = torch.zeros_like(res) + torch.inner(a, b, out=out) + self.assertEqual(res, out) + + check([], []) # scalar x scalar + check([], [0]) # scalar x empty + check([], [3]) # scalar x 1D + check([], [2, 3, 4]) # scalar x 3D + + check([0], [0]) # empty x empty + check([0], [2, 0]) # empty x 2D + + check([2], [2]) # 1D x 1D + check([2], [3, 1, 2]) # 1D x 3D + check([2], [3, 0, 2]) # 1D x 3D empty + + check([1, 2], [3, 2]) # 2D x 2D + check([1, 2], [3, 4, 2]) # 2D x 3D + check([2, 1, 3, 2], [1, 3, 2, 2]) # 4D x 4D + + # Test discontiguous input + a = torch.randn(3, 2, device=device, dtype=dtype).transpose_(0, 1) + b = torch.randn(4, 3, device=device, dtype=dtype)[::2, :] + self.assertFalse(a.is_contiguous() or b.is_contiguous()) + self.assertEqual(a.inner(b).cpu().numpy(), np.inner(a.cpu().numpy(), b.cpu().numpy())) + + # Test error message + with self.assertRaisesRegex(RuntimeError, + r"inner\(\) the last dimension must match on both " + r"input tensors but got shapes \[2, 3\] and \[2, 2\]"): + torch.randn(2, 3, device=device, dtype=dtype).inner(torch.randn(2, 2, device=device, dtype=dtype)) + # Tests torch.outer, and its alias, torch.ger, vs. NumPy - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @precisionOverride({torch.bfloat16: 1e-1}) @dtypes(*(torch.testing.get_all_dtypes())) def test_outer(self, device, dtype): @@ -55,11 +121,232 @@ def run_test_case(a, b): run_test_case(zero_strided, b) run_test_case(a, zero_strided) - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @precisionOverride({torch.bfloat16: 1e-1}) - @dtypes(*(torch.testing.get_all_dtypes())) - def test_addr(self, device, dtype): - def run_test_case(m, a, b, beta=1, alpha=1): + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_cholesky(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_pd_matrix + + def run_test(shape, batch, contiguous): + A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device) + if A.numel() > 0 and not contiguous: + A = A.transpose(-2, -1) + self.assertFalse(A.is_contiguous()) + expected_L = np.linalg.cholesky(A.cpu().numpy()) + actual_L = torch.linalg.cholesky(A) + + # For fp32 individual entries in matrices can differ between PyTorch and NumPy + # Let's compare the norms of matrices instead + if A.numel() > 0 and dtype in [torch.float32, torch.complex64]: + # axis is specified to calculate matrix norm for batched input + expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1)) + actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1)) + # Compare the norms with standard tolerances + self.assertEqual(actual_norm, expected_norm) + # and individual values with a higher tolerance + self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5) + else: + self.assertEqual(actual_L, expected_L) + + shapes = (0, 3, 5) + batches = ((), (3, ), (2, 2)) + larger_input_case = [(100, (5, ), True)] + for shape, batch, contiguous in list(itertools.product(shapes, batches, (True, False))) + larger_input_case: + run_test(shape, batch, contiguous) + + # check the out= variant + A = random_hermitian_pd_matrix(3, 3, dtype=dtype, device=device) + out = torch.empty_like(A) + ans = torch.linalg.cholesky(A, out=out) + self.assertEqual(ans, out) + expected = torch.linalg.cholesky(A) + self.assertEqual(expected, out) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_cholesky_errors_and_warnings(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_pd_matrix + + # cholesky requires the input to be a square matrix or batch of square matrices + A = torch.randn(2, 3, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'): + torch.linalg.cholesky(A) + A = torch.randn(2, 2, 3, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'): + torch.linalg.cholesky(A) + with self.assertRaisesRegex(np.linalg.LinAlgError, r'Last 2 dimensions of the array must be square'): + np.linalg.cholesky(A.cpu().numpy()) + + # cholesky requires the input to be at least 2 dimensional tensor + A = torch.randn(2, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'): + torch.linalg.cholesky(A) + with self.assertRaisesRegex(np.linalg.LinAlgError, + r'1-dimensional array given\. Array must be at least two-dimensional'): + np.linalg.cholesky(A.cpu().numpy()) + + # if the input matrix is singular, an error should be raised + A = torch.eye(3, 3, dtype=dtype, device=device) + A[-1, -1] = 0 # Now A is singular + with self.assertRaisesRegex(RuntimeError, r'U\(3,3\) is zero, singular U\.'): + torch.linalg.cholesky(A) + with self.assertRaisesRegex(np.linalg.LinAlgError, r'Matrix is not positive definite'): + np.linalg.cholesky(A.cpu().numpy()) + + # if at least one matrix in the batch is singular, an error should be raised + A = torch.eye(3, 3, dtype=dtype, device=device) + A = A.reshape((1, 3, 3)) + A = A.repeat(5, 1, 1) + A[4, -1, -1] = 0 # Now A[4] is singular + with self.assertRaisesRegex(RuntimeError, r'For batch 4: U\(3,3\) is zero, singular U\.'): + torch.linalg.cholesky(A) + + # if out tensor with wrong shape is passed a warning is given + A = random_hermitian_pd_matrix(3, dtype=dtype, device=device) + out = torch.empty(2, 3, dtype=dtype, device=device) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.linalg.cholesky(A, out=out) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes should match + out = torch.empty_like(A).to(torch.int) + with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"): + torch.linalg.cholesky(A, out=out) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float64, torch.complex128) + def test_cholesky_autograd(self, device, dtype): + def func(root): + x = 0.5 * (root + root.transpose(-1, -2).conj()) + return torch.linalg.cholesky(x) + + def run_test(shape): + root = torch.rand(*shape, dtype=dtype, device=device, requires_grad=True) + root = root + torch.eye(shape[-1], dtype=dtype, device=device) + + gradcheck(func, root) + gradgradcheck(func, root) + + root = torch.rand(*shape, dtype=dtype, device=device) + root = torch.matmul(root, root.transpose(-1, -2).conj()) + root.requires_grad_() + chol = torch.linalg.cholesky(root).sum().backward() + self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj()) # Check the gradient is hermitian + + shapes = ((3, 3), (4, 3, 2, 2)) + for shape in shapes: + run_test(shape) + + # NOTE: old_cholesky* tests were moved here from test_torch.py and test_autograd.py + @slowTest + @skipCUDAIf(True, "See issue #26789.") + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_old_cholesky_batched_many_batches(self, device, dtype): + from torch.testing._internal.common_utils import random_symmetric_pd_matrix + + def cholesky_test_helper(n, batchsize, device, upper): + A = random_symmetric_pd_matrix(n, batchsize, dtype=dtype, device=device) + chol_fact = torch.cholesky(A, upper=upper) + if upper: + # Correctness check + self.assertEqual(A, chol_fact.transpose(-2, -1).matmul(chol_fact)) + # Upper triangular check + self.assertEqual(chol_fact, chol_fact.triu()) + else: + # Correctness check + self.assertEqual(A, chol_fact.matmul(chol_fact.transpose(-2, -1))) + # Lower triangular check + self.assertEqual(chol_fact, chol_fact.tril()) + + for upper, batchsize in itertools.product([True, False], [262144, 524288]): + cholesky_test_helper(2, batchsize, device, upper) + + @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_old_cholesky_batched(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_pd_matrix + + def cholesky_test_helper(n, batch_dims, upper): + A = random_hermitian_pd_matrix(n, *batch_dims, dtype=dtype, device=device) + cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)]) + cholesky_exp = cholesky_exp.reshape_as(A) + self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper)) + + for upper, batchsize in itertools.product([True, False], [(3,), (3, 4), (2, 3, 4)]): + cholesky_test_helper(3, batchsize, upper) + + @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @tf32_on_and_off(0.01) + def test_old_cholesky(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_pd_matrix + + A = random_hermitian_pd_matrix(10, dtype=dtype, device=device) + + # default Case + C = torch.cholesky(A) + B = torch.mm(C, C.t().conj()) + self.assertEqual(A, B, atol=1e-14, rtol=0) + + # test Upper Triangular + U = torch.cholesky(A, True) + B = torch.mm(U.t().conj(), U) + self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (upper) did not allow rebuilding the original matrix') + + # test Lower Triangular + L = torch.cholesky(A, False) + B = torch.mm(L, L.t().conj()) + self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (lower) did not allow rebuilding the original matrix') + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_old_cholesky_empty(self, device, dtype): + def run_test(upper): + A = torch.empty(0, 0, dtype=dtype, device=device) + chol = torch.cholesky(A, upper) + chol_A = torch.matmul(chol, chol.t().conj()) + self.assertEqual(A, chol_A) + for upper in [True, False]: + run_test(upper) + + @onlyCPU + @skipCPUIfNoLapack + @dtypes(torch.float64, torch.complex128) + def test_old_cholesky_autograd(self, device, dtype): + def func(root, upper): + x = 0.5 * (root + root.transpose(-1, -2).conj()) + return torch.cholesky(x, upper) + + def run_test(upper, dims): + root = torch.rand(*dims, dtype=dtype, device=device, requires_grad=True) + root = root + torch.eye(dims[-1]) + + gradcheck(func, [root, upper]) + gradgradcheck(func, [root, upper]) + + root = torch.rand(*dims, dtype=dtype, device=device) + root = torch.matmul(root, root.transpose(-1, -2).conj()) + root.requires_grad_() + chol = root.cholesky().sum().backward() + self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj()) # Check the gradient is hermitian + + for upper, dims in itertools.product([True, False], [(3, 3), (4, 3, 2, 2)]): + run_test(upper, dims) + + def _test_addr_vs_numpy(self, device, dtype, beta=1, alpha=1): + def check(m, a, b, beta, alpha): if dtype == torch.bfloat16: a_np = a.to(torch.double).cpu().numpy() b_np = b.to(torch.double).cpu().numpy() @@ -68,41 +355,87 @@ def run_test_case(m, a, b, beta=1, alpha=1): a_np = a.cpu().numpy() b_np = b.cpu().numpy() m_np = m.cpu().numpy() - if beta == 0: expected = alpha * np.outer(a_np, b_np) else: expected = beta * m_np + alpha * np.outer(a_np, b_np) - self.assertEqual(torch.addr(m, a, b, beta=beta, alpha=alpha), expected) - self.assertEqual(torch.Tensor.addr(m, a, b, beta=beta, alpha=alpha), expected) + res = torch.addr(m, a, b, beta=beta, alpha=alpha) + self.assertEqual(res, expected) - result_dtype = torch.addr(m, a, b, beta=beta, alpha=alpha).dtype - out = torch.empty_like(m, dtype=result_dtype) + # Test out variant + out = torch.empty_like(res) torch.addr(m, a, b, beta=beta, alpha=alpha, out=out) self.assertEqual(out, expected) - a = torch.randn(50).to(device=device, dtype=dtype) - b = torch.randn(50).to(device=device, dtype=dtype) - m = torch.randn(50, 50).to(device=device, dtype=dtype) - - # when beta is zero - run_test_case(m, a, b, beta=0., alpha=2) + m = make_tensor((50, 50), device=device, dtype=dtype, low=-2, high=2) + a = make_tensor((50,), device=device, dtype=dtype, low=-2, high=2) + b = make_tensor((50,), device=device, dtype=dtype, low=-2, high=2) - # when beta is not zero - run_test_case(m, a, b, beta=0.5, alpha=2) + check(m, a, b, beta, alpha) # test transpose m_transpose = torch.transpose(m, 0, 1) - run_test_case(m_transpose, a, b, beta=0.5, alpha=2) + check(m_transpose, a, b, beta, alpha) # test 0 strided tensor - zero_strided = torch.randn(1).to(device=device, dtype=dtype).expand(50) - run_test_case(m, zero_strided, b, beta=0.5, alpha=2) + zero_strided = make_tensor((1,), device=device, dtype=dtype, low=-2, high=2).expand(50) + check(m, zero_strided, b, beta, alpha) # test scalar m_scalar = torch.tensor(1, device=device, dtype=dtype) - run_test_case(m_scalar, a, b) + check(m_scalar, a, b, beta, alpha) + + # test nans and infs are not propagated to the output when beta == 0 + float_and_complex_dtypes = torch.testing.get_all_fp_dtypes() + torch.testing.get_all_complex_dtypes() + if beta == 0 and dtype in float_and_complex_dtypes: + m[0][10] = m[10][10] = m[20][20] = float('inf') + m[1][10] = m[11][10] = m[21][20] = float('nan') + check(m, a, b, 0, alpha) + + @dtypes(torch.bool) + def test_addr_bool(self, device, dtype): + self._test_addr_vs_numpy(device, dtype, beta=True, alpha=False) + self._test_addr_vs_numpy(device, dtype, beta=False, alpha=True) + self._test_addr_vs_numpy(device, dtype, beta=False, alpha=False) + self._test_addr_vs_numpy(device, dtype, beta=True, alpha=True) + + @dtypes(*(torch.testing.get_all_int_dtypes())) + def test_addr_integral(self, device, dtype): + with self.assertRaisesRegex(RuntimeError, + 'argument beta must not be a floating point number.'): + self._test_addr_vs_numpy(device, dtype, beta=2., alpha=1) + with self.assertRaisesRegex(RuntimeError, + 'argument alpha must not be a floating point number.'): + self._test_addr_vs_numpy(device, dtype, beta=2, alpha=1.) + with self.assertRaisesRegex(RuntimeError, + 'Boolean beta only supported for Boolean results.'): + self._test_addr_vs_numpy(device, dtype, beta=True, alpha=1) + with self.assertRaisesRegex(RuntimeError, + 'Boolean alpha only supported for Boolean results.'): + self._test_addr_vs_numpy(device, dtype, beta=2, alpha=True) + + # when beta is zero + self._test_addr_vs_numpy(device, dtype, beta=0, alpha=2) + # when beta is not zero + self._test_addr_vs_numpy(device, dtype, beta=2, alpha=2) + + @precisionOverride({torch.bfloat16: 1e-1}) + @dtypes(*(torch.testing.get_all_fp_dtypes() + torch.testing.get_all_complex_dtypes())) + def test_addr_float_and_complex(self, device, dtype): + with self.assertRaisesRegex(RuntimeError, + 'Boolean beta only supported for Boolean results.'): + self._test_addr_vs_numpy(device, dtype, beta=True, alpha=1) + with self.assertRaisesRegex(RuntimeError, + 'Boolean alpha only supported for Boolean results.'): + self._test_addr_vs_numpy(device, dtype, beta=2, alpha=True) + + # when beta is zero + self._test_addr_vs_numpy(device, dtype, beta=0., alpha=2) + # when beta is not zero + self._test_addr_vs_numpy(device, dtype, beta=0.5, alpha=2) + if dtype in torch.testing.get_all_complex_dtypes(): + self._test_addr_vs_numpy(device, dtype, beta=(0 + 0.1j), alpha=(0.2 - 0.2j)) @dtypes(*itertools.product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes())) @@ -114,23 +447,19 @@ def test_outer_type_promotion(self, device, dtypes): self.assertEqual(result.dtype, torch.result_type(a, b)) @dtypes(*itertools.product(torch.testing.get_all_dtypes(), + torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes())) def test_addr_type_promotion(self, device, dtypes): - a = torch.randn(5).to(device=device, dtype=dtypes[0]) - b = torch.randn(5).to(device=device, dtype=dtypes[1]) - m = torch.randn(5, 5).to(device=device, - dtype=torch.result_type(a, b)) + a = make_tensor((5,), device=device, dtype=dtypes[0], low=-2, high=2) + b = make_tensor((5,), device=device, dtype=dtypes[1], low=-2, high=2) + m = make_tensor((5, 5), device=device, dtype=dtypes[2], low=-2, high=2) + + desired_dtype = torch.promote_types(torch.promote_types(dtypes[0], dtypes[1]), + dtypes[2]) for op in (torch.addr, torch.Tensor.addr): - # pass the integer 1 to the torch.result_type as both - # the default values of alpha and beta are integers (alpha=1, beta=1) - desired_dtype = torch.result_type(m, 1) result = op(m, a, b) self.assertEqual(result.dtype, desired_dtype) - desired_dtype = torch.result_type(m, 2.) - result = op(m, a, b, beta=0, alpha=2.) - self.assertEqual(result.dtype, desired_dtype) - # Tests migrated from test_torch.py # 1) test the shape of the result tensor when there is empty input tensor # 2) test the Runtime Exception when there is scalar input tensor @@ -158,8 +487,7 @@ def test_outer_ger_addr_legacy_tests(self, device): # Tests torch.det and its alias, torch.linalg.det, vs. NumPy @skipCUDAIfNoMagma @skipCPUIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.double) + @dtypes(torch.double, torch.cdouble) def test_det(self, device, dtype): tensors = ( torch.randn((2, 2), device=device, dtype=dtype), @@ -175,34 +503,427 @@ def test_det(self, device, dtype): for op in ops: actual = op(t) self.assertEqual(actual, expected) + self.compare_with_numpy(op, np.linalg.det, t) # NOTE: det requires a 2D+ tensor t = torch.randn(1, device=device, dtype=dtype) with self.assertRaises(RuntimeError): op(t) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + def test_eigh(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_matrix + + def run_test(shape, batch, uplo): + matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device) + expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo) + actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo) + self.assertEqual(actual_w, expected_w) + # sign of eigenvectors is not unique and therefore absolute values are compared + self.assertEqual(abs(actual_v), abs(expected_v)) + # additionally we can flip the sign and then compare the values + # let's choose the convention that the first element of the eigenvector should be positive, + # otherwise flip the sign of the eigenvector + if matrix.numel() > 0: + sign = np.sign(expected_v[..., 0, :]).reshape(batch + (1, shape)) + expected_v = sign * expected_v + torch_real_slice = actual_v[..., 0, :].real if dtype.is_complex else actual_v[..., 0, :] + sign = torch.sign(torch_real_slice).reshape(batch + (1, shape)) + actual_v = sign * actual_v + self.assertEqual(actual_v, expected_v) + + # check the out= variant + out_w = torch.empty_like(actual_w) + out_v = torch.empty_like(actual_v) + ans_w, ans_v = torch.linalg.eigh(matrix, UPLO=uplo, out=(out_w, out_v)) + self.assertEqual(ans_w, out_w) + self.assertEqual(ans_v, out_v) + self.assertEqual(ans_w, actual_w) + self.assertEqual(abs(ans_v), abs(actual_v)) + + shapes = (0, 3, 5) + batches = ((), (3, ), (2, 2)) + uplos = ["U", "L"] + for shape, batch, uplo in itertools.product(shapes, batches, uplos): + run_test(shape, batch, uplo) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + def test_eigh_lower_uplo(self, device, dtype): + def run_test(shape, batch, uplo): + # check lower case uplo + # use non-symmetric input to check whether uplo argument is working as intended + matrix = torch.randn(shape, shape, *batch, dtype=dtype, device=device) + expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo) + actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo) + self.assertEqual(actual_w, expected_w) + self.assertEqual(abs(actual_v), abs(expected_v)) + + uplos = ["u", "l"] + for uplo in uplos: + run_test(3, (2, 2), uplo) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_eigh_errors_and_warnings(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_matrix + + # eigh requires a square matrix + t = torch.randn(2, 3, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"): + torch.linalg.eigh(t) + + # eigh requires 'uplo' parameter to be 'U' or 'L' + t = torch.randn(3, 3, device=device, dtype=dtype) + for uplo in ["a", "wrong"]: + with self.assertRaisesRegex(RuntimeError, "be \'L\' or \'U\'"): + torch.linalg.eigh(t, UPLO=uplo) + with self.assertRaisesRegex(ValueError, "be \'L\' or \'U\'"): + np.linalg.eigh(t.cpu().numpy(), UPLO=uplo) + + # if non-empty out tensor with wrong shape is passed a warning is given + a = random_hermitian_matrix(3, dtype=dtype, device=device) + real_dtype = a.real.dtype if dtype.is_complex else dtype + out_w = torch.empty(7, 7, dtype=real_dtype, device=device) + out_v = torch.empty(7, 7, dtype=dtype, device=device) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.linalg.eigh(a, out=(out_w, out_v)) + # Check warning occurs + self.assertEqual(len(w), 2) + self.assertTrue("An output with one or more elements was resized" in str(w[-2].message)) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes should match + out_w = torch.empty_like(a).to(torch.int) + out_v = torch.empty_like(a) + with self.assertRaisesRegex(RuntimeError, "dtype Int does not match self dtype"): + torch.linalg.eigh(a, out=(out_w, out_v)) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + def test_eigh_non_contiguous(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_matrix + + def run_test(matrix, uplo): + self.assertFalse(matrix.is_contiguous()) + expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo) + actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo) + self.assertEqual(actual_w, expected_w) + # sign of eigenvectors is not unique and therefore absolute values are compared + self.assertEqual(abs(actual_v), abs(expected_v)) + + def run_test_permuted(shape, batch, uplo): + # check for permuted / transposed inputs + matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device) + matrix = matrix.transpose(-2, -1) + run_test(matrix, uplo) + + def run_test_skipped_elements(shape, batch, uplo): + # check for inputs with skipped elements + matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device) + matrix = matrix[::2] + run_test(matrix, uplo) + + shapes = (3, 5) + batches = ((4, ), (4, 2)) + uplos = ["U", "L"] + for shape, batch, uplo in itertools.product(shapes, batches, uplos): + run_test_permuted(shape, batch, uplo) + run_test_skipped_elements(shape, batch, uplo) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float64, torch.complex128) + def test_eigh_autograd(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_matrix + + def func(x, uplo): + x = 0.5 * (x + x.conj().transpose(-2, -1)) + return torch.linalg.eigh(x, UPLO=uplo) + + def func_grad_w(x, uplo): + return func(x, uplo)[0] + + def func_grad_v(x, uplo): + # gauge invariant loss function + return abs(func(x, uplo)[1]) + + def run_test(dims, uplo): + x = torch.randn(*dims, dtype=dtype, device=device, requires_grad=True) + + gradcheck(func_grad_w, [x, uplo]) + gradgradcheck(func_grad_w, [x, uplo]) + + gradcheck(func_grad_v, [x, uplo]) + gradgradcheck(func_grad_v, [x, uplo]) + + x = random_hermitian_matrix(dims[-1], *dims[:-2]).requires_grad_() + w, v = torch.linalg.eigh(x) + (w.sum() + abs(v).sum()).backward() + self.assertEqual(x.grad, x.grad.conj().transpose(-1, -2)) # Check the gradient is Hermitian + + for dims, uplo in itertools.product([(3, 3), (2, 3, 3)], ["L", "U"]): + run_test(dims, uplo) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + def test_eigvalsh(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_matrix + + def run_test(shape, batch, uplo): + matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device) + expected_w = np.linalg.eigvalsh(matrix.cpu().numpy(), UPLO=uplo) + actual_w = torch.linalg.eigvalsh(matrix, UPLO=uplo) + self.assertEqual(actual_w, expected_w) + + # check the out= variant + out = torch.empty_like(actual_w) + ans = torch.linalg.eigvalsh(matrix, UPLO=uplo, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, actual_w) + + shapes = (0, 3, 5) + batches = ((), (3, ), (2, 2)) + uplos = ["U", "L"] + for shape, batch, uplo in itertools.product(shapes, batches, uplos): + run_test(shape, batch, uplo) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_eigvalsh_errors_and_warnings(self, device, dtype): + # eigvalsh requires a square matrix + t = torch.randn(2, 3, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"): + torch.linalg.eigvalsh(t) + + # eigvalsh requires 'uplo' parameter to be 'U' or 'L' + t = torch.randn(3, 3, device=device, dtype=dtype) + for uplo in ["a", "wrong"]: + with self.assertRaisesRegex(RuntimeError, "be \'L\' or \'U\'"): + torch.linalg.eigvalsh(t, UPLO=uplo) + with self.assertRaisesRegex(ValueError, "be \'L\' or \'U\'"): + np.linalg.eigvalsh(t.cpu().numpy(), UPLO=uplo) + + # if non-empty out tensor with wrong shape is passed a warning is given + real_dtype = t.real.dtype if dtype.is_complex else dtype + out = torch.empty_like(t).to(real_dtype) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.linalg.eigvalsh(t, out=out) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes should match + out = torch.empty_like(t).to(torch.int) + with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"): + torch.linalg.eigvalsh(t, out=out) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + def test_eigvalsh_non_contiguous(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_matrix + + def run_test(matrix, uplo): + self.assertFalse(matrix.is_contiguous()) + expected_w = np.linalg.eigvalsh(matrix.cpu().numpy(), UPLO=uplo) + actual_w = torch.linalg.eigvalsh(matrix, UPLO=uplo) + self.assertEqual(actual_w, expected_w) + + def run_test_permuted(shape, batch, uplo): + # check for permuted / transposed inputs + matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device) + matrix = matrix.transpose(-2, -1) + run_test(matrix, uplo) + + def run_test_skipped_elements(shape, batch, uplo): + # check for inputs with skipped elements + matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device) + matrix = matrix[::2] + run_test(matrix, uplo) + + shapes = (3, 5) + batches = ((4, ), (4, 2)) + uplos = ["U", "L"] + for shape, batch, uplo in itertools.product(shapes, batches, uplos): + run_test_permuted(shape, batch, uplo) + run_test_skipped_elements(shape, batch, uplo) + + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_kron(self, device, dtype): + + def run_test_case(a_shape, b_shape): + a = torch.rand(a_shape, dtype=dtype, device=device) + b = torch.rand(b_shape, dtype=dtype, device=device) + + expected = np.kron(a.cpu().numpy(), b.cpu().numpy()) + result = torch.kron(a, b) + self.assertEqual(result, expected) + + # check the out= variant + out = torch.empty_like(result) + ans = torch.kron(a, b, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, result) + + shapes = [(4,), (2, 2), (1, 2, 3), (1, 2, 3, 3)] + for a_shape, b_shape in itertools.product(shapes, reversed(shapes)): + run_test_case(a_shape, b_shape) + + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_kron_non_contiguous(self, device, dtype): + + def run_test_transposed(a_shape, b_shape): + # check for transposed case + a = torch.rand(a_shape, dtype=dtype, device=device).transpose(-2, -1) + b = torch.rand(b_shape, dtype=dtype, device=device).transpose(-2, -1) + self.assertFalse(a.is_contiguous()) + self.assertFalse(b.is_contiguous()) + + expected = np.kron(a.cpu().numpy(), b.cpu().numpy()) + result = torch.kron(a, b) + self.assertEqual(result, expected) + + # check the out= variant + out = torch.empty(result.transpose(-2, -1).shape, dtype=dtype, device=device).transpose(-2, -1) + self.assertFalse(out.is_contiguous()) + ans = torch.kron(a, b, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, result) + + def run_test_skipped_elements(a_shape, b_shape): + # check for transposed case + a = torch.rand(2 * a_shape[0], *a_shape[1:], dtype=dtype, device=device)[::2] + b = torch.rand(2 * b_shape[0], *b_shape[1:], dtype=dtype, device=device)[::2] + self.assertFalse(a.is_contiguous()) + self.assertFalse(b.is_contiguous()) + + expected = np.kron(a.cpu().numpy(), b.cpu().numpy()) + result = torch.kron(a, b) + self.assertEqual(result, expected) + + # check the out= variant + out = torch.empty(2 * result.shape[0], *result.shape[1:], dtype=dtype, device=device)[::2] + self.assertFalse(out.is_contiguous()) + ans = torch.kron(a, b, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, result) + + shapes = [(2, 2), (2, 2, 3), (2, 2, 3, 3)] + for a_shape, b_shape in itertools.product(shapes, reversed(shapes)): + # run_test_transposed(a_shape, b_shape) + run_test_skipped_elements(a_shape, b_shape) + + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_kron_empty(self, device, dtype): + + def run_test_case(empty_shape): + a = torch.eye(3, dtype=dtype, device=device) + b = torch.empty(empty_shape, dtype=dtype, device=device) + result = torch.kron(a, b) + expected = np.kron(a.cpu().numpy(), b.cpu().numpy()) + self.assertEqual(result, expected) + + # NumPy doesn't work if the first argument is empty + result = torch.kron(b, a) + self.assertEqual(result.shape, expected.shape) + + empty_shapes = [(0,), (2, 0), (1, 0, 3)] + for empty_shape in empty_shapes: + run_test_case(empty_shape) + + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_kron_errors_and_warnings(self, device, dtype): + # if non-empty out tensor with wrong shape is passed a warning is given + a = torch.eye(3, dtype=dtype, device=device) + b = torch.ones((2, 2), dtype=dtype, device=device) + out = torch.empty_like(a) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.kron(a, b, out=out) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes should match + out = torch.empty_like(a).to(torch.int) + with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"): + torch.kron(a, b, out=out) + # This test confirms that torch.linalg.norm's dtype argument works # as expected, according to the function's documentation @skipCUDAIfNoMagma def test_norm_dtype(self, device): - def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype): + def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype): + # Determine the best dtype to use for comparisons between tensors + # of two different types + def get_compare_dtype(type0, type1): + types_32bit_based = [torch.float, torch.cfloat] + is_complex = type0.is_complex or type1.is_complex + + if type0 in types_32bit_based or type1 in types_32bit_based: + return torch.cfloat if is_complex else torch.float + else: + return torch.cdouble if is_complex else torch.double + + compare_dtype = get_compare_dtype(from_dtype, to_dtype) + + def get_value_type(dtype): + if dtype == torch.cfloat: + return torch.float + elif dtype == torch.cdouble: + return torch.double + elif dtype == torch.complex32: + return torch.float16 + else: + return dtype + msg = ( f'input_size={input_size}, ord={ord}, keepdim={keepdim}, ' f'from_dtype={from_dtype}, to_dtype={to_dtype}') input = torch.randn(*input_size, dtype=from_dtype, device=device) - result = torch.linalg.norm(input, ord, keepdim=keepdim, dtype=from_dtype) - self.assertEqual(result.dtype, from_dtype, msg=msg) - result_converted = torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype) - self.assertEqual(result_converted.dtype, to_dtype, msg=msg) - self.assertEqual(result.to(compare_dtype), result_converted.to(compare_dtype), msg=msg) + result = torch.linalg.norm(input, ord, keepdim=keepdim) + if from_dtype.is_complex: + # By default, norm downgrades a complex input to the corresponding real number type + self.assertEqual(result.dtype, get_value_type(from_dtype), msg=msg) + else: + self.assertEqual(result.dtype, from_dtype, msg=msg) - result_out_converted = torch.empty_like(result_converted) - torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype, out=result_out_converted) - self.assertEqual(result_out_converted.dtype, to_dtype, msg=msg) - self.assertEqual(result_converted, result_out_converted, msg=msg) + result_out = torch.empty((), dtype=to_dtype, device=device) + torch.linalg.norm(input, ord, keepdim=keepdim, out=result_out) + self.assertEqual(result_out.dtype, to_dtype, msg=msg) + self.assertEqual(result.to(compare_dtype), result_out.to(compare_dtype), msg=msg) - ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None] - ord_matrix = [1, -1, 2, -2, inf, -inf, None] + result_with_dtype = torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype) + self.assertEqual(result_with_dtype.dtype, to_dtype, msg=msg) + + if from_dtype.is_complex: + result_convert_first = torch.linalg.norm(input.to(to_dtype), ord, keepdim=keepdim) + self.assertEqual(result_with_dtype.to(compare_dtype), result_convert_first.to(compare_dtype), msg=msg) + else: + self.assertEqual(result.to(compare_dtype), result_with_dtype.to(compare_dtype), msg=msg) + + result_out_with_dtype = torch.empty_like(result_with_dtype) + torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype, out=result_out_with_dtype) + self.assertEqual(result_out_with_dtype.dtype, to_dtype, msg=msg) + self.assertEqual(result_with_dtype, result_out_with_dtype, msg=msg) + + ord_vector = [0, 0.1, -0.1, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None] + ord_matrix = ['fro', 'nuc', 1, -1, 2, -2, inf, -inf, None] S = 10 test_cases = [ ((S, ), ord_vector), @@ -211,15 +932,16 @@ def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype) for keepdim in [True, False]: for input_size, ord_settings in test_cases: for ord in ord_settings: - # float to double - run_test_case(input_size, ord, keepdim, torch.float, torch.double, torch.float) - # double to float - run_test_case(input_size, ord, keepdim, torch.double, torch.double, torch.float) + dtypes = [torch.float, torch.double, torch.cfloat, torch.cdouble] + for from_dtype, to_dtype in itertools.product(dtypes, dtypes): + run_test_case(input_size, ord, keepdim, from_dtype, to_dtype) # Make sure that setting dtype != out.dtype raises an error dtype_pairs = [ (torch.float, torch.double), (torch.double, torch.float), + (torch.cfloat, torch.cdouble), + (torch.cdouble, torch.cfloat), ] for keepdim in [True, False]: for input_size, ord_settings in test_cases: @@ -230,16 +952,8 @@ def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype) with self.assertRaisesRegex(RuntimeError, r'provided dtype must match dtype of result'): torch.linalg.norm(input, ord=ord, keepdim=keepdim, dtype=dtype, out=result) - # TODO: Once dtype arg is supported in nuclear and frobenius norms, remove the following test - # and add 'nuc' and 'fro' to ord_matrix above - for ord in ['nuc', 'fro']: - input = torch.randn(10, 10, device=device) - with self.assertRaisesRegex(RuntimeError, f"ord=\'{ord}\' does not yet support the dtype argument"): - torch.linalg.norm(input, ord, dtype=torch.float) - # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that # their vector norm results match - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @dtypes(torch.float, torch.double) def test_norm_vector(self, device, dtype): def run_test_case(input, p, dim, keepdim): @@ -278,7 +992,6 @@ def run_test_case(input, p, dim, keepdim): # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that # their matrix norm results match @skipCUDAIfNoMagma - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @dtypes(torch.float, torch.double) def test_norm_matrix(self, device, dtype): def run_test_case(input, p, dim, keepdim): @@ -314,6 +1027,117 @@ def run_test_case(input, p, dim, keepdim): for ord in ord_settings: run_test_case(input, ord, dim, keepdim) + @skipCPUIfNoLapack + @skipCUDAIfNoMagma + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3}) + def test_cond(self, device, dtype): + def run_test_case(input, p): + result = torch.linalg.cond(input, p) + result_numpy = np.linalg.cond(input.cpu().numpy(), p) + self.assertEqual(result, result_numpy, rtol=1e-2, atol=self.precision) + + # test out= variant + out = torch.empty_like(result) + ans = torch.linalg.cond(input, p, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, result) + + norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None] + input_sizes = [(32, 32), (2, 3, 3, 3)] + for input_size in input_sizes: + input = torch.randn(*input_size, dtype=dtype, device=device) + for p in norm_types: + run_test_case(input, p) + + # test empty batch sizes + input_sizes = [(0, 3, 3), (0, 2, 5, 5)] + for input_size in input_sizes: + input = torch.randn(*input_size, dtype=dtype, device=device) + for p in norm_types: + run_test_case(input, p) + + # test non-square input + input_sizes = [(16, 32), (32, 16), (2, 3, 5, 3), (2, 3, 3, 5)] + for input_size in input_sizes: + input = torch.randn(*input_size, dtype=dtype, device=device) + for p in [2, -2, None]: + run_test_case(input, p) + + # test for singular input + a = torch.eye(3, dtype=dtype, device=device) + a[-1, -1] = 0 # make 'a' singular + for p in norm_types: + run_test_case(a, p) + + # test for 0x0 matrices. NumPy doesn't work for such input, we return 0 + input_sizes = [(0, 0), (2, 5, 0, 0)] + for input_size in input_sizes: + input = torch.randn(*input_size, dtype=dtype, device=device) + for p in ['fro', 2]: + expected_dtype = a.real.dtype if dtype.is_complex else dtype + expected = torch.zeros(input_size[:-2], dtype=expected_dtype, device=device) + actual = torch.linalg.cond(input, p) + self.assertEqual(actual, expected) + + @skipCPUIfNoLapack + @skipCUDAIfNoMagma + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3}) + def test_cond_errors_and_warnings(self, device, dtype): + norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None] + + # cond expects the input to be at least 2-dimensional + a = torch.ones(3, dtype=dtype, device=device) + for p in norm_types: + with self.assertRaisesRegex(RuntimeError, r'supports matrices or batches of matrices'): + torch.linalg.cond(a, p) + + # for some norm types cond expects the input to be square + a = torch.ones(3, 2, dtype=dtype, device=device) + norm_types = [1, -1, inf, -inf, 'fro', 'nuc'] + for p in norm_types: + with self.assertRaisesRegex(RuntimeError, r'supports square matrices or batches of square matrices'): + torch.linalg.cond(a, p) + + # if non-empty out tensor with wrong shape is passed a warning is given + a = torch.ones((2, 2), dtype=dtype, device=device) + for p in ['fro', 2]: + real_dtype = a.real.dtype if dtype.is_complex else dtype + out = torch.empty(a.shape, dtype=real_dtype, device=device) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.linalg.cond(a, p, out=out) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes should match + out = torch.empty_like(a).to(torch.int) + for p in ['fro', 2]: + with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match"): + torch.linalg.cond(a, p, out=out) + + # for batched input if at least one matrix in the batch is not invertible, + # we can't get the result for all other (possibly) invertible matrices in the batch without an explicit for loop. + # this should change when at::inverse works with silent errors + # NumPy works fine in this case because it's possible to silence the error and get the inverse matrix results + # possibly filled with NANs + batch_dim = 3 + a = torch.eye(3, 3, dtype=dtype, device=device) + a = a.reshape((1, 3, 3)) + a = a.repeat(batch_dim, 1, 1) + a[0, -1, -1] = 0 # now a[0] is singular + for p in [1, -1, inf, -inf, 'fro', 'nuc']: + with self.assertRaisesRegex(RuntimeError, "linalg_cond does not support yet"): + torch.linalg.cond(a, p) + + # check invalid norm type + a = torch.ones(3, 3, dtype=dtype, device=device) + for p in ['wrong_norm', 5]: + with self.assertRaisesRegex(RuntimeError, f"linalg_cond got an invalid norm type: {p}"): + torch.linalg.cond(a, p) + # Test autograd and jit functionality for linalg functions. # TODO: Once support for linalg functions is added to method_tests in common_methods_invocations.py, # the `test_cases` entries below should be moved there. These entries are in a similar format, @@ -341,11 +1165,13 @@ def test_autograd_and_jit(self, device, dtype): ('norm', (S, S, S), (), 'default_3d'), ('norm', (S,), (inf,), 'vector_inf'), ('norm', (S,), (3.5,), 'vector_3_5'), + ('norm', (S,), (0.5,), 'vector_0_5'), ('norm', (S,), (2,), 'vector_2'), ('norm', (S,), (1,), 'vector_1'), ('norm', (S,), (0,), 'vector_0'), ('norm', (S,), (-inf,), 'vector_neg_inf'), ('norm', (S,), (-3.5,), 'vector_neg_3_5'), + ('norm', (S,), (-0.5,), 'vector_neg_0_5'), ('norm', (S,), (2,), 'vector_neg_2'), ('norm', (S,), (1,), 'vector_neg_1'), ('norm', (S, S), (inf,), 'matrix_inf'), @@ -400,7 +1226,6 @@ def run_func(input): # This test calls torch.linalg.norm and numpy.linalg.norm with illegal arguments # to ensure that they both throw errors - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @dtypes(torch.float, torch.double) def test_norm_errors(self, device, dtype): def run_error_test_case(input, ord, dim, keepdim, error_type, error_regex): @@ -441,86 +1266,100 @@ def run_error_test_case(input, ord, dim, keepdim, error_type, error_regex): for ord in ord_settings: run_error_test_case(input, ord, dim, keepdim, error_type, error_regex) - # Test complex number inputs for linalg.norm. Some cases are not supported yet, so - # this test also verifies that those cases raise an error. + # Test complex number inputs for linalg.norm @skipCUDAIfNoMagma @skipCPUIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") @dtypes(torch.cfloat, torch.cdouble) def test_norm_complex(self, device, dtype): def gen_error_message(input_size, ord, keepdim, dim=None): return "complex norm failed for input size %s, ord=%s, keepdim=%s, dim=%s" % ( input_size, ord, keepdim, dim) - if self.device_type == 'cpu': - supported_vector_ords = [0, 1, 3, inf, -1, -2, -3, -inf] - supported_matrix_ords = ['nuc', 1, 2, inf, -1, -2, -inf] - unsupported_vector_ords = [ - (2, r'norm with p=2 not supported for complex tensors'), - (None, r'norm with p=2 not supported for complex tensors'), - ] - unsupported_matrix_ords = [ - ('fro', r'frobenius norm not supported for complex tensors'), - (None, r'norm with p=2 not supported for complex tensors'), - ] - - elif self.device_type == 'cuda': - supported_vector_ords = [inf, -inf] - supported_matrix_ords = [1, inf, -1, -inf] - unsupported_vector_ords = [ - (0, r'norm_cuda" not implemented for \'Complex'), - (1, r'norm_cuda" not implemented for \'Complex'), - (2, r'norm with p=2 not supported for complex tensors'), - (-1, r'norm_cuda" not implemented for \'Complex'), - (-2, r'norm_cuda" not implemented for \'Complex'), - (None, r'norm with p=2 not supported for complex tensors'), - ] - unsupported_matrix_ords = [ - (None, r'norm with p=2 not supported for complex tensors'), - ('fro', r'frobenius norm not supported for complex tensors'), - (2, r'"svd_cuda" not implemented for \'Complex'), - (-2, r'"svd_cuda" not implemented for \'Complex'), - ('nuc', r'"svd_cuda" not implemented for \'Complex'), - ] + vector_ords = [None, 0, 1, 2, 3, inf, -1, -2, -3, -inf] + matrix_ords = [None, 'fro', 'nuc', 1, 2, inf, -1, -2, -inf] # Test supported ords for keepdim in [False, True]: # vector norm x = torch.randn(25, device=device, dtype=dtype) xn = x.cpu().numpy() - for ord in supported_vector_ords: + for ord in vector_ords: res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu() expected = np.linalg.norm(xn, ord, keepdims=keepdim) msg = gen_error_message(x.size(), ord, keepdim) self.assertEqual(res.shape, expected.shape, msg=msg) self.assertEqual(res, expected, msg=msg) + res_out = torch.Tensor().to(device) + torch.linalg.norm(x, ord, keepdim=keepdim, out=res_out) + self.assertEqual(res_out.shape, expected.shape, msg=msg) + self.assertEqual(res_out.cpu(), expected, msg=msg) + # matrix norm x = torch.randn(25, 25, device=device, dtype=dtype) xn = x.cpu().numpy() - for ord in supported_matrix_ords: - # TODO: Need to fix abort when nuclear norm is given cdouble input: - # "double free or corruption (!prev) Aborted (core dumped)" - if ord == 'nuc' and dtype == torch.cdouble: - continue + for ord in matrix_ords: res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu() expected = np.linalg.norm(xn, ord, keepdims=keepdim) msg = gen_error_message(x.size(), ord, keepdim) self.assertEqual(res.shape, expected.shape, msg=msg) self.assertEqual(res, expected, msg=msg) - # Test unsupported ords - # vector norm - x = torch.randn(25, device=device, dtype=dtype) - for ord, error_msg in unsupported_vector_ords: - with self.assertRaisesRegex(RuntimeError, error_msg): - torch.linalg.norm(x, ord) + res_out = torch.Tensor().to(device) + torch.linalg.norm(x, ord, keepdim=keepdim, out=res_out) + self.assertEqual(res_out.shape, expected.shape, msg=msg) + self.assertEqual(res_out.cpu(), expected, msg=msg) + + # Test complex number inputs for linalg.norm + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.cfloat, torch.cdouble) + def test_norm_complex_autograd(self, device, dtype): + def gen_error_message(input_size, ord, keepdim, dim=None): + return "complex norm autograd failed for input size %s, ord=%s, keepdim=%s, dim=%s" % ( + input_size, ord, keepdim, dim) + + if dtype == torch.cfloat: + dtype_real = torch.float + elif dtype == torch.cdouble: + dtype_real = torch.double + else: + raise RuntimeError(f'dtype not supported in this test: {dtype}') + + vector_ords = [None, 0, 1, 2, 3, inf, -1, -2, -3, -inf] + matrix_ords = [None, 'fro', 1, inf, -1, -inf] + + # TODO: Fix autograd for matrix orders 'nuc', 2, and -2 by adding complex + # support to svd's backward method. Once this is done, these ords + # should be added to `matrix_ords` above + # Update: svd's backward now works with https://github.com/pytorch/pytorch/pull/47761 + # However run_test_case doesn't work for 'matrix_ords_unsupported' cases + # because singular values of 'x' and 'x_real' can be different and so is their norms based on singular values + matrix_ords_unsupported = ['nuc', 2, -2] + + def run_test_case(x, ord, keepdim): + res = torch.linalg.norm(x, ord, keepdim=keepdim) + res.backward() + + x_real = x.clone().detach().abs().requires_grad_(True) + res_real = torch.linalg.norm(x_real, ord, keepdim=keepdim) + res_real.backward() + + msg = gen_error_message(x.size(), ord, keepdim) + + self.assertEqual(res.shape, res_real.shape, msg=msg) + self.assertEqual(res, res_real, msg=msg) + self.assertEqual(x.grad.abs(), x_real.grad, msg=msg) + + # Test supported ords + for keepdim in [False, True]: + for ord in vector_ords: + x = torch.randn(25, dtype=dtype, device=device, requires_grad=True) + run_test_case(x, ord, keepdim) - # matrix norm - x = torch.randn(25, 25, device=device, dtype=dtype) - for ord, error_msg in unsupported_matrix_ords: - with self.assertRaisesRegex(RuntimeError, error_msg): - torch.linalg.norm(x, ord) + for ord in matrix_ords: + x = torch.randn(25, 25, dtype=dtype, device=device, requires_grad=True) + run_test_case(x, ord, keepdim) # Test that linal.norm gives the same result as numpy when inputs # contain extreme values (inf, -inf, nan) @@ -528,7 +1367,6 @@ def gen_error_message(input_size, ord, keepdim, dim=None): @unittest.skipIf(IS_MACOS, "Skipped on MacOS!") @skipCUDAIfNoMagma @skipCPUIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") def test_norm_extreme_values(self, device): vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf] matrix_ords = ['fro', 'nuc', 1, 2, inf, -1, -2, -inf] @@ -585,12 +1423,6 @@ def run_test_case(input, ord, dim, keepdim, should_error): with self.assertRaises(RuntimeError): torch.linalg.norm(input, ord, dim, keepdim) else: - if dtype in [torch.cfloat, torch.cdouble] and ord in [2, None]: - # TODO: Once these ord values have support for complex numbers, - # remove this error test case - with self.assertRaises(RuntimeError): - torch.linalg.norm(input, ord, dim, keepdim) - return result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim) result = torch.linalg.norm(input, ord, dim, keepdim) self.assertEqual(result, result_numpy, msg=msg) @@ -614,16 +1446,9 @@ def run_test_case(input, ord, dim, keepdim, should_error): # Test degenerate shape results match numpy for linalg.norm matrix norms @skipCUDAIfNoMagma @skipCPUIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) def test_norm_matrix_degenerate_shapes(self, device, dtype): def run_test_case(input, ord, dim, keepdim, should_error): - if dtype in [torch.cfloat, torch.cdouble] and ord in ['fro', None]: - # TODO: Once these ord values have support for complex numbers, - # remove this error test case - with self.assertRaises(RuntimeError): - torch.linalg.norm(input, ord, dim, keepdim) - return msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' input_numpy = input.cpu().numpy() if should_error: @@ -654,18 +1479,6 @@ def run_test_case(input, ord, dim, keepdim, should_error): for ord in ord_matrix: run_test_case(input, ord, dim, keepdim, ord in error_ords) - def test_norm_deprecated(self, device): - expected_message = ( - r'torch.norm is deprecated and may be removed in a future PyTorch release. ' - r'Use torch.linalg.norm instead.') - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - for func in [torch.norm, torch.functional.norm]: - func(torch.rand(10, device=device)) - self.assertEqual(len(w), 2) - for wi in w: - self.assertEqual(str(wi.message), expected_message) - def test_norm_fastpaths(self, device): x = torch.randn(3, 5, device=device) @@ -694,6 +1507,4912 @@ def test_norm_fastpaths(self, device): expected = torch.pow(x.pow(3).abs().sum(1), 1.0 / 3.0) self.assertEqual(result, expected) + @skipCPUIfNoLapack + @skipCUDAIfNoMagma + @dtypes(torch.double, torch.float) + def test_eig_basic(self, device, dtype): + a = torch.tensor([[1.96, 0.00, 0.00, 0.00, 0.00], + [-6.49, 3.80, 0.00, 0.00, 0.00], + [-0.47, -6.39, 4.17, 0.00, 0.00], + [-7.20, 1.50, -1.51, 5.70, 0.00], + [-0.65, -6.34, 2.67, 1.80, -7.10]], + dtype=dtype, device=device).t() + e = torch.eig(a)[0] + ee, vv = torch.eig(a, True) + te = torch.tensor((), dtype=dtype, device=device) + tv = torch.tensor((), dtype=dtype, device=device) + eee, vvv = torch.eig(a, True, out=(te, tv)) + self.assertEqual(e, ee, atol=1e-12, rtol=0) + self.assertEqual(ee, eee, atol=1e-12, rtol=0) + self.assertEqual(ee, te, atol=1e-12, rtol=0) + self.assertEqual(vv, vvv, atol=1e-12, rtol=0) + self.assertEqual(vv, tv, atol=1e-12, rtol=0) + # + # compare with numpy + np_e, np_v = np.linalg.eig(a.cpu().numpy()) + # np_e.shape == (n, 2), where each column contain the real and + # imaginary parts of the result + self.assertEqual(ee[:, 0], np_e) # real part + self.assertEqual(ee[:, 1], torch.zeros(ee.shape[0], dtype=dtype)) # imaginary part + self.assertEqual(vv, np_v) + + @skipCPUIfNoLapack + @skipCUDAIfNoMagma + @dtypes(torch.double, torch.float) + def test_eig_reuse(self, device, dtype): + X = torch.randn(4, 4, dtype=dtype, device=device) + X = torch.mm(X.t(), X) + e = torch.zeros(4, 2, dtype=dtype, device=device) + v = torch.zeros(4, 4, dtype=dtype, device=device) + torch.eig(X, True, out=(e, v)) + Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t()) + if dtype is torch.float: + atol = 1e-7 + rtol = 1e-5 + else: + atol = 1e-8 + rtol = 0 + self.assertEqual(X, Xhat, atol=atol, rtol=rtol, msg='VeV\' wrong') + self.assertTrue(v.is_contiguous(), 'V is not contiguous') + + torch.eig(X, True, out=(e, v)) + Xhat = torch.mm(v, torch.mm(e.select(1, 0).diag(), v.t())) + self.assertEqual(X, Xhat, atol=atol, rtol=rtol, msg='VeV\' wrong') + self.assertTrue(v.is_contiguous(), 'V is not contiguous') + + @skipCPUIfNoLapack + @skipCUDAIfNoMagma + @dtypes(torch.double, torch.float) + def test_eig_non_contiguous(self, device, dtype): + X = torch.randn(4, 4, dtype=dtype, device=device) + X = torch.mm(X.t(), X) + e = torch.zeros(4, 2, 2, dtype=dtype, device=device)[:, 1] + v = torch.zeros(4, 2, 4, dtype=dtype, device=device)[:, 1] + self.assertFalse(v.is_contiguous(), 'V is contiguous') + self.assertFalse(e.is_contiguous(), 'E is contiguous') + torch.eig(X, True, out=(e, v)) + Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t()) + if dtype is torch.float: + atol = 1e-7 + rtol = 1e-5 + else: + atol = 1e-8 + rtol = 0 + self.assertEqual(X, Xhat, atol=atol, rtol=rtol, msg='VeV\' wrong') + + @skipCPUIfNoLapack + @skipCUDAIfNoMagma + @dtypes(torch.double, torch.float) + def test_eig_invalid_input(self, device, dtype): + # test invalid input + self.assertRaisesRegex( + RuntimeError, + 'input should be 2 dimensional', + lambda: torch.eig(torch.ones((2)))) + self.assertRaisesRegex( + RuntimeError, + 'input should be square', + lambda: torch.eig(torch.ones((2, 3)))) + self.assertRaisesRegex( + RuntimeError, + 'input should not contain infs or NaNs', + lambda: torch.eig(np.inf * torch.ones((2, 2)))) + self.assertRaisesRegex( + RuntimeError, + 'input should not contain infs or NaNs', + lambda: torch.eig(np.nan * torch.ones((2, 2)))) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double, torch.float) + def test_eig_out(self, device, dtype): + # the out version of torch.eig needs to be tested manually: we can't + # use the "test_out=True" parameter to tensor_op_tests because the + # signature is irregular (since we have *two* output vectors) + t = torch.randn(10, 10, dtype=dtype, device=device) + evals, evecs = torch.eig(t, eigenvectors=True) + # + # check that the out= version computes the same values as the normal one + out_evals = torch.empty_like(evals) + out_evecs = torch.empty_like(evecs) + evals2, evecs2 = torch.eig(t, eigenvectors=True, out=(out_evals, out_evecs)) + # check that the out tensors were used in-place + self.assertEqual(evals2.data_ptr(), out_evals.data_ptr()) + self.assertEqual(evecs2.data_ptr(), out_evecs.data_ptr()) + # check that the result is the same as the non-out version + self.assertEqual(evals, out_evals) + self.assertEqual(evecs, out_evecs) + # + # check what happens in the eigenvectors=False case + out_evals = torch.empty_like(evals) + out_evecs = torch.tensor([1, 2, 3], dtype=dtype, device=device) + evals2, evecs2 = torch.eig(t, eigenvectors=False, out=(out_evals, out_evecs)) + # check that the out_evals was used in-place + self.assertEqual(evals2.data_ptr(), out_evals.data_ptr()) + self.assertEqual(evals, out_evals) + # check that out_evecs was NOT touched at all + assert out_evecs.tolist() == [1, 2, 3] + # + # check that we complain if we pass an out vector of the wrong dtype + wrong_out = torch.empty((0, 0), dtype=int) + with self.assertRaisesRegex(RuntimeError, r"Expected .* but got .*"): + torch.eig(t, eigenvectors=True, out=(wrong_out, out_evecs)) + with self.assertRaisesRegex(RuntimeError, r"Expected .* but got .*"): + torch.eig(t, eigenvectors=True, out=(out_evals, wrong_out)) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_norm_old(self, device): + def gen_error_message(input_size, p, keepdim, dim=None): + return "norm failed for input size %s, p=%s, keepdim=%s, dim=%s" % ( + input_size, p, keepdim, dim) + + for keepdim in [False, True]: + # full reduction + x = torch.randn(25, device=device) + xn = x.cpu().numpy() + for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3, 1.5]: + res = x.norm(p, keepdim=keepdim).cpu() + expected = np.linalg.norm(xn, p, keepdims=keepdim) + self.assertEqual(res, expected, atol=1e-5, rtol=0, msg=gen_error_message(x.size(), p, keepdim)) + + # one dimension + x = torch.randn(25, 25, device=device) + xn = x.cpu().numpy() + for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3]: + dim = 1 + res = x.norm(p, dim, keepdim=keepdim).cpu() + expected = np.linalg.norm(xn, p, dim, keepdims=keepdim) + msg = gen_error_message(x.size(), p, keepdim, dim) + self.assertEqual(res.shape, expected.shape, msg=msg) + self.assertEqual(res, expected, msg=msg) + + # matrix norm + for p in ['fro', 'nuc']: + res = x.norm(p, keepdim=keepdim).cpu() + expected = np.linalg.norm(xn, p, keepdims=keepdim) + msg = gen_error_message(x.size(), p, keepdim) + self.assertEqual(res.shape, expected.shape, msg=msg) + self.assertEqual(res, expected, msg=msg) + + # zero dimensions + x = torch.randn((), device=device) + xn = x.cpu().numpy() + res = x.norm(keepdim=keepdim).cpu() + expected = np.linalg.norm(xn, keepdims=keepdim) + msg = gen_error_message(x.size(), None, keepdim) + self.assertEqual(res.shape, expected.shape, msg=msg) + self.assertEqual(res, expected, msg=msg) + + # larger tensor sanity check + self.assertEqual( + 2 * torch.norm(torch.ones(10000), keepdim=keepdim), + torch.norm(torch.ones(40000), keepdim=keepdim)) + + # matrix norm with non-square >2-D tensors, all combinations of reduction dims + x = torch.randn(5, 6, 7, 8, device=device) + xn = x.cpu().numpy() + for p in ['fro', 'nuc']: + for dim in itertools.product(*[list(range(4))] * 2): + if dim[0] == dim[1]: + continue + res = x.norm(p=p, dim=dim, keepdim=keepdim).cpu() + expected = np.linalg.norm(xn, ord=p, axis=dim, keepdims=keepdim) + msg = gen_error_message(x.size(), p, keepdim, dim) + self.assertEqual(res.shape, expected.shape, msg=msg) + self.assertEqual(res, expected, msg=msg) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_norm_complex_old(self, device): + def gen_error_message(input_size, p, keepdim, dim=None): + return "complex norm failed for input size %s, p=%s, keepdim=%s, dim=%s" % ( + input_size, p, keepdim, dim) + + for keepdim in [False, True]: + # vector norm + x = torch.randn(25, device=device) + 1j * torch.randn(25, device=device) + xn = x.cpu().numpy() + for p in [0, 1, 2, 3, inf, -1, -2, -3, -inf]: + res = x.norm(p, keepdim=keepdim).cpu() + expected = np.linalg.norm(xn, p, keepdims=keepdim) + msg = gen_error_message(x.size(), p, keepdim) + self.assertEqual(res.shape, expected.shape, msg=msg) + self.assertEqual(res, expected, msg=msg) + + # matrix norm + x = torch.randn(25, 25, device=device) + 1j * torch.randn(25, 25, device=device) + xn = x.cpu().numpy() + for p in ['nuc', 'fro']: + res = x.norm(p, keepdim=keepdim).cpu() + expected = np.linalg.norm(xn, p, keepdims=keepdim) + msg = gen_error_message(x.size(), p, keepdim) + self.assertEqual(res.shape, expected.shape, msg=msg) + self.assertEqual(res, expected, msg=msg) + + # Ensure torch.norm with p='fro' and p=2 give the same results for mutually supported input combinations + @dtypes(torch.float) + @skipCUDAIfRocm + def test_norm_fro_2_equivalence_old(self, device, dtype): + input_sizes = [ + (0,), + (10,), + (0, 0), + (4, 30), + (0, 45), + (100, 0), + (45, 10, 23), + (0, 23, 59), + (23, 0, 37), + (34, 58, 0), + (0, 0, 348), + (0, 3434, 0), + (0, 0, 0), + (5, 3, 8, 1, 3, 5)] + + for input_size in input_sizes: + a = make_tensor(input_size, device, dtype, low=-9, high=9) + + # Try full reduction + dim_settings = [None] + + # Try all possible 1-D reductions + dim_settings += list(range(-a.dim(), a.dim())) + + def wrap_dim(dim, ndims): + assert (dim < ndims) and (dim >= -ndims) + if dim >= 0: + return dim + else: + return dim + ndims + + # Try all possible 2-D reductions + dim_settings += [ + (d0, d1) for d0, d1 in itertools.combinations(range(-a.dim(), a.dim()), 2) + if wrap_dim(d0, a.dim()) != wrap_dim(d1, a.dim())] + + for dim in dim_settings: + for keepdim in [True, False]: + a_norm_2 = torch.norm(a, p=2, dim=dim, keepdim=keepdim) + a_norm_fro = torch.norm(a, p='fro', dim=dim, keepdim=keepdim) + self.assertEqual(a_norm_fro, a_norm_2) + + @skipCUDAIfNoMagma + def test_nuclear_norm_axes_small_brute_force_old(self, device): + def check_single_nuclear_norm(x, axes): + if self.device_type != 'cpu' and randrange(100) < 95: + return # too many cpu <==> device copies + + a = np.array(x.cpu(), copy=False) + expected = np.linalg.norm(a, "nuc", axis=axes) + + ans = torch.norm(x, "nuc", dim=axes) + self.assertTrue(ans.is_contiguous()) + self.assertEqual(ans.shape, expected.shape) + self.assertEqual(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True) + + out = torch.zeros(expected.shape, dtype=x.dtype, device=x.device) + ans = torch.norm(x, "nuc", dim=axes, out=out) + self.assertIs(ans, out) + self.assertTrue(ans.is_contiguous()) + self.assertEqual(ans.shape, expected.shape) + self.assertEqual(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True) + + for n in range(1, 3): + for m in range(1, 3): + for axes in itertools.permutations([0, 1], 2): + # 2d, inner dimensions C + x = torch.randn(n, m, device=device) + check_single_nuclear_norm(x, axes) + + # 2d, inner dimensions Fortran + x = torch.randn(m, n, device=device).transpose(-1, -2) + check_single_nuclear_norm(x, axes) + + # 2d, inner dimensions non-contiguous + x = torch.randn(n, 2 * m, device=device)[:, ::2] + check_single_nuclear_norm(x, axes) + + # 2d, all dimensions non-contiguous + x = torch.randn(7 * n, 2 * m, device=device)[::7, ::2] + check_single_nuclear_norm(x, axes) + + for o in range(1, 3): + for axes in itertools.permutations([0, 1, 2], 2): + # 3d, inner dimensions C + x = torch.randn(o, n, m, device=device) + check_single_nuclear_norm(x, axes) + + # 3d, inner dimensions Fortran + x = torch.randn(o, m, n, device=device).transpose(-1, -2) + check_single_nuclear_norm(x, axes) + + # 3d, inner dimensions non-contiguous + x = torch.randn(o, n, 2 * m, device=device)[:, :, ::2] + check_single_nuclear_norm(x, axes) + + # 3d, all dimensions non-contiguous + x = torch.randn(7 * o, 5 * n, 2 * m, device=device)[::7, ::5, ::2] + check_single_nuclear_norm(x, axes) + + for r in range(1, 3): + for axes in itertools.permutations([0, 1, 2, 3], 2): + # 4d, inner dimensions C + x = torch.randn(r, o, n, m, device=device) + check_single_nuclear_norm(x, axes) + + # 4d, inner dimensions Fortran + x = torch.randn(r, o, n, m, device=device).transpose(-1, -2) + check_single_nuclear_norm(x, axes) + + # 4d, inner dimensions non-contiguous + x = torch.randn(r, o, n, 2 * m, device=device)[:, :, :, ::2] + check_single_nuclear_norm(x, axes) + + # 4d, all dimensions non-contiguous + x = torch.randn(7 * r, 5 * o, 11 * n, 2 * m, device=device)[::7, ::5, ::11, ::2] + check_single_nuclear_norm(x, axes) + + @skipCUDAIfNoMagma + def test_nuclear_norm_exceptions_old(self, device): + for lst in [], [1], [1, 2]: + x = torch.tensor(lst, dtype=torch.double, device=device) + for axes in (), (0,): + self.assertRaises(RuntimeError, torch.norm, x, "nuc", axes) + self.assertRaises(IndexError, torch.norm, x, "nuc", (0, 1)) + + x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.double, device=device) + self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0)) + self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) + + # ~~~ tests for torch.svd ~~~ + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_svd(self, device, dtype): + def run_test(dims, some, compute_uv): + x = torch.randn(*dims, dtype=dtype, device=device) + outu = torch.empty(0, dtype=dtype, device=device) + outs = torch.empty(0, dtype=dtype, device=device) + outv = torch.empty(0, dtype=dtype, device=device) + torch.svd(x, some=some, compute_uv=compute_uv, out=(outu, outs, outv)) + + if compute_uv: + if some: + x_recon = torch.matmul(outu, torch.matmul(outs.diag_embed(), outv.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + narrow_u = outu[..., :min(*dims[-2:])] + narrow_v = outv[..., :min(*dims[-2:])] + x_recon = torch.matmul(narrow_u, torch.matmul(outs.diag_embed(), narrow_v.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + _, singvals, _ = torch.svd(x, compute_uv=True) + self.assertEqual(singvals, outs, msg='Singular values mismatch') + self.assertEqual(outu, torch.zeros_like(outu), msg='U not zero') + self.assertEqual(outv, torch.zeros_like(outv), msg='V not zero') + + resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) + self.assertEqual(resu, outu, msg='outputs of svd and svd with out differ') + self.assertEqual(ress, outs, msg='outputs of svd and svd with out differ') + self.assertEqual(resv, outv, msg='outputs of svd and svd with out differ') + + # test non-contiguous + x = torch.randn(*dims, dtype=dtype, device=device) + n_dim = len(dims) + # Reverse the batch dimensions and the matrix dimensions and then concat them + x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) + assert not x.is_contiguous(), "x is intentionally non-contiguous" + resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) + if compute_uv: + if some: + x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), resv.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + narrow_u = resu[..., :min(*dims[-2:])] + narrow_v = resv[..., :min(*dims[-2:])] + x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + _, singvals, _ = torch.svd(x, compute_uv=True) + self.assertEqual(singvals, ress, msg='Singular values mismatch') + self.assertEqual(resu, torch.zeros_like(resu), msg='U not zero') + self.assertEqual(resv, torch.zeros_like(resv), msg='V not zero') + + shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices + (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices + (3, 7), (5, 3, 7), (7, 5, 3, 7)] # thin matrices + for dims, some, compute_uv in product(shapes, [True, False], [True, False]): + run_test(dims, some, compute_uv) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float) + def test_svd_no_singularvectors(self, device, dtype): + for size in [(5, 5), (5, 20), (20, 5)]: + a = torch.randn(*size, device=device, dtype=dtype) + u, s_expect, v = torch.svd(a) + u, s_actual, v = torch.svd(a, compute_uv=False) + self.assertEqual(s_expect, s_actual, msg="Singular values don't match") + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_svd_lowrank(self, device, dtype): + from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix + + def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options): + density = options.pop('density', 1) + if isinstance(matrix_size, int): + rows = columns = matrix_size + else: + rows, columns = matrix_size + if density == 1: + a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype) + a = a_input + else: + assert batches == () + a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype) + a = a_input.to_dense() + + q = min(*size) + u, s, v = svd_lowrank(a_input, q=q, **options) + + # check if u, s, v is a SVD + u, s, v = u[..., :q], s[..., :q], v[..., :q] + A = u.matmul(s.diag_embed()).matmul(v.transpose(-2, -1)) + self.assertEqual(A, a) + + # check if svd_lowrank produces same singular values as torch.svd + U, S, V = torch.svd(a) + self.assertEqual(s.shape, S.shape) + self.assertEqual(u.shape, U.shape) + self.assertEqual(v.shape, V.shape) + self.assertEqual(s, S) + + if density == 1: + # actual_rank is known only for dense inputs + # + # check if pairs (u, U) and (v, V) span the same + # subspaces, respectively + u, s, v = u[..., :actual_rank], s[..., :actual_rank], v[..., :actual_rank] + U, S, V = U[..., :actual_rank], S[..., :actual_rank], V[..., :actual_rank] + self.assertEqual(u.transpose(-2, -1).matmul(U).det().abs(), torch.ones(batches, device=device, dtype=dtype)) + self.assertEqual(v.transpose(-2, -1).matmul(V).det().abs(), torch.ones(batches, device=device, dtype=dtype)) + + all_batches = [(), (1,), (3,), (2, 3)] + for actual_rank, size, all_batches in [ + (2, (17, 4), all_batches), + (4, (17, 4), all_batches), + (4, (17, 17), all_batches), + (10, (100, 40), all_batches), + (7, (1000, 1000), [()]), + ]: + # dense input + for batches in all_batches: + run_subtest(actual_rank, size, batches, device, torch.svd_lowrank) + if size != size[::-1]: + run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank) + + # sparse input + for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 1000)]: + for density in [0.005, 0.1]: + run_subtest(None, size, (), device, torch.svd_lowrank, density=density) + + # jitting support + jitted = torch.jit.script(torch.svd_lowrank) + actual_rank, size, batches = 2, (17, 4), () + run_subtest(actual_rank, size, batches, device, jitted) + + @onlyCPU + @skipCPUIfNoLapack + @dtypes(torch.cfloat) + def test_svd_complex(self, device, dtype): + t = torch.randn((10, 10), dtype=dtype, device=device) + U, S, V = torch.svd(t, some=False) + # note: from the math point of view, it is weird that we need to use + # V.T instead of V.T.conj(): torch.svd has a buggy behavior for + # complex numbers and it's deprecated. You should use torch.linalg.svd + # instead. + t2 = U @ torch.diag(S).type(dtype) @ V.T + self.assertEqual(t, t2) + + def _test_svd_helper(self, shape, some, col_maj, device, dtype): + cpu_tensor = torch.randn(shape, device='cpu').to(dtype) + device_tensor = cpu_tensor.to(device=device) + if col_maj: + cpu_tensor = cpu_tensor.t() + device_tensor = device_tensor.t() + cpu_result = torch.svd(cpu_tensor, some=some) + device_result = torch.svd(device_tensor, some=some) + m = min(cpu_tensor.shape[-2:]) + # torch.svd returns torch.return_types.svd which is a tuple of (U, V, S). + # - When some==False, U[..., m:] can be arbitrary. + # - When some==True, U shape: [..., m], V shape: [m, m] + # - Signs are not deterministic. If the sign of a column of U is changed + # then the corresponding column of the V has to be changed. + # Thus here we only compare result[..., :m].abs() from CPU and device. + for x, y in zip(cpu_result, device_result): + self.assertEqual(x[..., :m].abs(), y[..., :m].abs(), atol=1e-5, rtol=0) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_and_complex_types()) + def test_svd_square(self, device, dtype): + self._test_svd_helper((10, 10), True, False, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_square_col_maj(self, device, dtype): + self._test_svd_helper((10, 10), True, True, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_tall_some(self, device, dtype): + self._test_svd_helper((20, 5), True, False, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_tall_all(self, device, dtype): + self._test_svd_helper((20, 5), False, False, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_tall_some_col_maj(self, device, dtype): + self._test_svd_helper((5, 20), True, True, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(*floating_types()) + def test_svd_tall_all_col_maj(self, device, dtype): + self._test_svd_helper((5, 20), False, True, device, dtype) + + # ~~~ tests for torch.linalg.svd ~~~ + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_linalg_svd_compute_uv(self, device, dtype): + """ + Test the default case, compute_uv=True. Here we have the very same behavior as + numpy + """ + t = torch.randn((10, 11), device=device, dtype=dtype) + np_t = t.cpu().numpy() + for full_matrices in (True, False): + # check linalg.svd vs numpy + expected = np.linalg.svd(np_t, full_matrices, compute_uv=True) + actual = torch.linalg.svd(t, full_matrices, compute_uv=True) + self.assertEqual(actual, expected) + # check linalg.svd vs linalg.svd(out=...) + out = (torch.empty_like(actual[0]), + torch.empty_like(actual[1]), + torch.empty_like(actual[2])) + out2 = torch.linalg.svd(t, full_matrices, compute_uv=True, out=out) + self.assertEqual(actual, out) + self.assertEqual(actual, out2) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_linalg_svd_no_compute_uv(self, device, dtype): + """ + Test the compute_uv=False case. Here we have a different return type than + numpy: numpy returns S, we return (empty, S, empty) + """ + t = torch.randn((10, 11), device=device, dtype=dtype) + np_t = t.cpu().numpy() + + def is_empty(x): + return x.numel() == 0 and x.dtype == t.dtype and x.device == t.device + + for full_matrices in (True, False): + # check linalg.svd vs numpy + np_s = np.linalg.svd(np_t, full_matrices, compute_uv=False) + USV = torch.linalg.svd(t, full_matrices, compute_uv=False) + assert is_empty(USV.U) + self.assertEqual(USV.S, np_s) + assert is_empty(USV.V) + # check linalg.svd vs linalg.svd(out=...) + out = (torch.empty_like(USV.U), torch.empty_like(USV.S), torch.empty_like(USV.V)) + USV = torch.linalg.svd(t, full_matrices, compute_uv=False, out=out) + assert USV.U is out[0] + assert USV.S is out[1] + assert USV.V is out[2] + self.assertEqual(USV.S, np_s) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @onlyCUDA + @dtypes(torch.float) + def test_linalg_svd_out_different_device(self, device, dtype): + t = torch.randn(5, 7, device=device, dtype=dtype) # this is on cuda + u = torch.empty((5, 5), device='cpu', dtype=dtype) + s = torch.empty((5,), device='cpu', dtype=dtype) + v = torch.empty((7, 7), device='cpu', dtype=dtype) + with self.assertRaisesRegex(RuntimeError, 'svd output tensor U is on the wrong device: expected cuda:.* got cpu'): + torch.linalg.svd(t, out=(u, s, v)) + + def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_pd_matrix + + b = torch.randn(*b_dims, dtype=dtype, device=device) + A = random_hermitian_pd_matrix(*A_dims, dtype=dtype, device=device) + L = torch.cholesky(A, upper=upper) + return b, A, L + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, + torch.float64: 1e-8, torch.complex128: 1e-8}) + def test_cholesky_solve(self, device, dtype): + for (k, n), upper in itertools.product(zip([2, 3, 5], [3, 5, 7]), [True, False]): + b, A, L = self.cholesky_solve_test_helper((n,), (n, k), upper, device, dtype) + x = torch.cholesky_solve(b, L, upper=upper) + self.assertEqual(b, A.mm(x)) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, + torch.float64: 1e-8, torch.complex128: 1e-8}) + def test_cholesky_solve_batched(self, device, dtype): + def cholesky_solve_batch_helper(A_dims, b_dims, upper): + b, A, L = self.cholesky_solve_test_helper(A_dims, b_dims, upper, device, dtype) + x_exp_list = [] + for i in range(b_dims[0]): + x_exp_list.append(torch.cholesky_solve(b[i], L[i], upper=upper)) + x_exp = torch.stack(x_exp_list) # Stacked output + x_act = torch.cholesky_solve(b, L, upper=upper) # Actual output + self.assertEqual(x_act, x_exp) # Equality check + Ax = torch.matmul(A, x_act) + self.assertEqual(b, Ax) # Correctness check + + for upper, batchsize in itertools.product([True, False], [1, 3, 4]): + cholesky_solve_batch_helper((5, batchsize), (batchsize, 5, 10), upper) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_cholesky_solve_batched_non_contiguous(self, device, dtype): + from numpy.linalg import solve + from torch.testing._internal.common_utils import random_hermitian_pd_matrix + + for upper in [True, False]: + A = random_hermitian_pd_matrix(2, 2, dtype=dtype, device='cpu') + b = torch.randn(2, 2, 2, dtype=dtype, device='cpu') + x_exp = solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy()) + A = A.to(device).permute(0, 2, 1) + b = b.to(device).permute(2, 1, 0) + assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs" + L = torch.cholesky(A, upper) + x = torch.cholesky_solve(b, L, upper=upper) + self.assertEqual(x, x_exp) + + @slowTest + @skipCUDAIf(True, "See https://github.com/pytorch/pytorch/issues/48996") + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, + torch.float64: 1e-8, torch.complex128: 1e-8}) + def test_cholesky_solve_batched_many_batches(self, device, dtype): + for A_dims, b_dims in zip([(5, 256, 256), (5,)], [(5, 10), (512, 512, 5, 10)]): + for upper in [True, False]: + b, A, L = self.cholesky_solve_test_helper(A_dims, b_dims, upper, device, dtype) + x = torch.cholesky_solve(b, L, upper) + Ax = torch.matmul(A, x) + self.assertEqual(Ax, b.expand_as(Ax)) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, + torch.float64: 1e-8, torch.complex128: 1e-8}) + def test_cholesky_solve_batched_broadcasting(self, device, dtype): + from numpy.linalg import solve + from torch.testing._internal.common_utils import random_hermitian_pd_matrix + + def run_test(A_dims, b_dims, upper): + A_matrix_size = A_dims[-1] + A_batch_dims = A_dims[:-2] + A = random_hermitian_pd_matrix(A_matrix_size, *A_batch_dims, + dtype=dtype, device='cpu') + b = torch.randn(*b_dims, dtype=dtype, device='cpu') + x_exp = torch.tensor(solve(A.numpy(), b.numpy()), dtype=dtype, device=device) + A, b = A.to(dtype=dtype, device=device), b.to(dtype=dtype, device=device) + L = torch.cholesky(A, upper) + x = torch.cholesky_solve(b, L, upper=upper) + self.assertEqual(x, x_exp) + # https://github.com/pytorch/pytorch/issues/42695 + x = torch.cholesky_solve(b, L, upper=upper, out=x) + self.assertEqual(x, x_exp) + + # test against numpy.linalg.solve + for upper in [True, False]: + run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), upper) # no broadcasting + run_test((2, 1, 3, 4, 4), (4, 6), upper) # broadcasting b + run_test((4, 4), (2, 1, 3, 4, 2), upper) # broadcasting A + run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), upper) # broadcasting A & b + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float64, torch.complex128) + def test_cholesky_solve_autograd(self, device, dtype): + def run_test(A_dims, B_dims, upper): + root = torch.randn(*A_dims, device=device, dtype=dtype).requires_grad_() + b = torch.randn(*B_dims, device=device, dtype=dtype).requires_grad_() + + def func(root, b, upper): + if upper: + A = root.triu() + else: + A = root.tril() + return torch.cholesky_solve(b, A, upper) + + gradcheck(func, [root, b, upper]) + gradgradcheck(func, [root, b, upper], atol=1e-3) + + for (a_size, b_size), upper in itertools.product([((3, 3), (3, 4)), ((3, 3), (3, 2)), + ((2, 3, 3), (2, 3, 4)), ((2, 3, 3), (2, 3, 2))], + [True, False]): + run_test(a_size, b_size, upper) + + @skipCUDAIfNoMagmaAndNoCusolver + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3, + torch.float64: 1e-8, torch.complex128: 1e-8}) + def test_inverse(self, device, dtype): + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + + def run_test(torch_inverse, matrix, batches, n): + matrix_inverse = torch_inverse(matrix) + + # Compare against NumPy output + # NumPy uses 'gesv' LAPACK routine solving the equation A A_inv = I + # But in PyTorch 'gertf' + 'getri' is used causing element-wise differences + expected = np.linalg.inv(matrix.cpu().numpy()) + self.assertEqual(matrix_inverse, expected, atol=self.precision, rtol=self.precision) + + # Additional correctness tests, check matrix*matrix_inverse == identity + identity = torch.eye(n, dtype=dtype, device=device) + self.assertEqual(identity.expand_as(matrix), torch.matmul(matrix, matrix_inverse)) + self.assertEqual(identity.expand_as(matrix), torch.matmul(matrix_inverse, matrix)) + + # check the out= variant + # prepare the expected out tensor + matrix_inverse_out = torch.empty(*batches, n, n, dtype=dtype, device=device) + matrix_inverse_out_t = matrix_inverse_out.transpose(-2, -1).clone(memory_format=torch.contiguous_format) + matrix_inverse_out = matrix_inverse_out_t.transpose(-2, -1) + ans = torch_inverse(matrix, out=matrix_inverse_out) + self.assertEqual(matrix_inverse_out, ans, atol=0, rtol=0) + self.assertEqual(matrix_inverse_out, matrix_inverse, atol=0, rtol=0) + + # batched matrices: 3+ dimensional tensors, check matrix_inverse same as single-inverse for each matrix + if matrix.ndim > 2 and batches[0] != 0: + expected_inv_list = [] + p = int(np.prod(batches)) # use `p` instead of -1, so that the test works for empty input as well + for mat in matrix.contiguous().view(p, n, n): + expected_inv_list.append(torch_inverse(mat)) + expected_inv = torch.stack(expected_inv_list).view(*batches, n, n) + if self.device_type == 'cuda' and dtype in [torch.float32, torch.complex64]: + # single-inverse is done using cuSOLVER, while batched inverse is done using MAGMA + # individual values can be significantly different for fp32, hence rather high rtol is used + # the important thing is that torch_inverse passes above checks with identity + self.assertEqual(matrix_inverse, expected_inv, atol=1e-1, rtol=1e-2) + else: + self.assertEqual(matrix_inverse, expected_inv) + + for torch_inverse in [torch.inverse, torch.linalg.inv]: + for batches, n in itertools.product( + [[], [0], [1], [4], [2, 3]], + [0, 5, 64] + ): + # large batch size and large matrix size will be tested in test_inverse_many_batches (slow test) + if batches and batches[0] == 32 and n == 256: + continue + matrices = random_fullrank_matrix_distinct_singular_value(n, *batches, dtype=dtype).to(device) + run_test(torch_inverse, matrices, batches, n) + + # test non-contiguous input + run_test(torch_inverse, matrices.transpose(-2, -1), batches, n) + if n > 0: + run_test( + torch_inverse, + random_fullrank_matrix_distinct_singular_value(n * 2, *batches, dtype=dtype).to(device) + .view(-1, n * 2, n * 2)[:, ::2, ::2].view(*batches, n, n), + batches, n + ) + + @slowTest + @skipCUDAIfNoMagmaAndNoCusolver + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3, + torch.float64: 1e-5, torch.complex128: 1e-5}) + def test_inverse_many_batches(self, device, dtype): + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + + def test_inverse_many_batches_helper(torch_inverse, b, n): + matrices = random_fullrank_matrix_distinct_singular_value(b, n, n, dtype=dtype).to(device) + matrices_inverse = torch_inverse(matrices) + + # Compare against NumPy output + expected = np.linalg.inv(matrices.cpu().numpy()) + self.assertEqual(matrices_inverse, expected, atol=self.precision, rtol=1e-3) + + for torch_inverse in [torch.inverse, torch.linalg.inv]: + test_inverse_many_batches_helper(torch_inverse, 5, 256) + test_inverse_many_batches_helper(torch_inverse, 3, 512) + test_inverse_many_batches_helper(torch_inverse, 64, 64) + + @skipCUDAIfNoMagmaAndNoCusolver + @skipCPUIfNoLapack + @onlyOnCPUAndCUDA # TODO: XLA doesn't raise exception + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_inverse_errors(self, device, dtype): + # inverse expects batches of square matrices as input + with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"): + torch.inverse(torch.randn(2, 3, 4, 3)) + + # if input is not invertible, RuntimeError is raised mentioning the first non-invertible batch + def run_test_singular_input(batch_dim, n): + x = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1) + x[n, -1, -1] = 0 + with self.assertRaisesRegex(RuntimeError, rf'For batch {n}: U\(3,3\) is zero'): + torch.inverse(x) + + for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]: + run_test_singular_input(*params) + + @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-7, torch.complex128: 1e-7}) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_pinv(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_pd_matrix + + def run_test_main(A, hermitian): + # Testing against definition for pseudo-inverses + A_pinv = torch.linalg.pinv(A, hermitian=hermitian) + if A.numel() > 0: + self.assertEqual(A, A @ A_pinv @ A, atol=self.precision, rtol=self.precision) + self.assertEqual(A_pinv, A_pinv @ A @ A_pinv, atol=self.precision, rtol=self.precision) + self.assertEqual(A @ A_pinv, (A @ A_pinv).conj().transpose(-2, -1)) + self.assertEqual(A_pinv @ A, (A_pinv @ A).conj().transpose(-2, -1)) + else: + self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2])) + + # Check out= variant + out = torch.empty_like(A_pinv) + ans = torch.linalg.pinv(A, hermitian=hermitian, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, A_pinv) + + def run_test_numpy(A, hermitian): + # Check against NumPy output + # Test float rcond, and specific value for each matrix + rconds = [float(torch.rand(1)), ] + # Test different types of rcond tensor + for rcond_type in all_types(): + rconds.append(torch.rand(A.shape[:-2], dtype=torch.double, device=device).to(rcond_type)) + # Test broadcasting of rcond + if A.ndim > 2: + rconds.append(torch.rand(A.shape[-3], device=device)) + for rcond in rconds: + actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian) + numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy() + expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian) + self.assertEqual(actual, expected, atol=self.precision, rtol=1e-5) + + for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5), # square matrices + (3, 2), (5, 3, 2), (2, 5, 3, 2), # fat matrices + (2, 3), (5, 2, 3), (2, 5, 2, 3), # thin matrices + (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]: # zero numel matrices + A = torch.randn(*sizes, dtype=dtype, device=device) + hermitian = False + run_test_main(A, hermitian) + run_test_numpy(A, hermitian) + + # Check hermitian = True + for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5), # square matrices + (0, 0), (3, 0, 0), ]: # zero numel square matrices + A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device) + hermitian = True + run_test_main(A, hermitian) + run_test_numpy(A, hermitian) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_pinv_errors_and_warnings(self, device, dtype): + # pinv requires at least 2D tensor + a = torch.randn(1, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, "expected a tensor with 2 or more dimensions"): + torch.linalg.pinv(a) + + # if non-empty out tensor with wrong shape is passed a warning is given + a = torch.randn(3, 3, dtype=dtype, device=device) + out = torch.empty(7, 7, dtype=dtype, device=device) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.linalg.pinv(a, out=out) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes of out and input should match + out = torch.empty_like(a).to(torch.int) + with self.assertRaisesRegex(RuntimeError, "dtype Int does not match the expected dtype"): + torch.linalg.pinv(a, out=out) + + if torch.cuda.is_available(): + # device of out and input should match + wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' + out = torch.empty_like(a).to(wrong_device) + with self.assertRaisesRegex(RuntimeError, "Expected result and input to be on the same device"): + torch.linalg.pinv(a, out=out) + + # device of rcond and input should match + wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' + rcond = torch.full((), 1e-2, device=wrong_device) + with self.assertRaisesRegex(RuntimeError, "Expected rcond and input to be on the same device"): + torch.linalg.pinv(a, rcond=rcond) + + # rcond can't be complex + rcond = torch.full((), 1j, device=device) + with self.assertRaisesRegex(RuntimeError, "rcond tensor of complex type is not supported"): + torch.linalg.pinv(a, rcond=rcond) + + @skipCUDAIfNoMagmaAndNoCusolver + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_inv_errors(self, device, dtype): + # inv expects batches of square matrices as input + a = torch.randn(2, 3, 4, 3, dtype=dtype, device=device) + with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"): + torch.linalg.inv(a) + + # inv requires the input to be at least 2 dimensional tensor + a = torch.randn(2, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"): + torch.linalg.inv(a) + + # if input is not invertible, RuntimeError is raised mentioning the first non-invertible batch + def run_test_singular_input(batch_dim, n): + a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1) + a[n, -1, -1] = 0 + with self.assertRaisesRegex(RuntimeError, rf"For batch {n}: U\(3,3\) is zero"): + torch.linalg.inv(a) + + for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]: + run_test_singular_input(*params) + + # if non-empty out tensor with wrong shape is passed an error is thrown + a = torch.randn(2, 3, 3, device=device, dtype=dtype) + out = torch.empty(1, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, "does not match input shape"): + torch.linalg.inv(a, out=out) + + # dtypes should match + out = torch.empty_like(a).to(torch.int) + with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match input dtype"): + torch.linalg.inv(a, out=out) + + # device should match + if torch.cuda.is_available(): + wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' + out = torch.empty(0, device=wrong_device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, "does not match input device"): + torch.linalg.inv(a, out=out) + + def solve_test_helper(self, A_dims, b_dims, device, dtype): + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + + b = torch.randn(*b_dims, dtype=dtype, device=device) + A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype).to(device) + return b, A + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3}) + def test_solve(self, device, dtype): + def run_test(n, batch, rhs): + A_dims = (n, *batch) + b_dims = (*batch, n, *rhs) + b, A = self.solve_test_helper(A_dims, b_dims, device, dtype) + + # Correctness test + x = torch.linalg.solve(A, b) + if rhs == (): + Ax = torch.matmul(A, x.unsqueeze(-1)) + Ax.squeeze_(-1) + else: + Ax = torch.matmul(A, x) + self.assertEqual(b.expand_as(Ax), Ax) + + # Check against NumPy + expected = np.linalg.solve(A.cpu().numpy(), b.expand_as(x).cpu().numpy()) + self.assertEqual(x, expected) + + # Check out= variant + if rhs == (): + out = torch.empty_like(x.unsqueeze(-1)) + else: + out = torch.empty_like(x) + ans = torch.linalg.solve(A, b, out=out) + self.assertEqual(ans, out) + self.assertEqual(x, out) + + # Check empty out + out = torch.empty(0, dtype=dtype, device=device) + ans = torch.linalg.solve(A, b, out=out) + self.assertEqual(ans, out) + self.assertEqual(x, out) + + batches = [(), (0, ), (3, ), (2, 3)] + ns = [0, 5, 32] + nrhs = [(), (1, ), (5, )] + for n, batch, rhs in itertools.product(ns, batches, nrhs): + run_test(n, batch, rhs) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3}) + def test_solve_batched_non_contiguous(self, device, dtype): + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype).to(device).permute(1, 0, 2) + b = torch.randn(2, 2, 2, dtype=dtype, device=device).permute(2, 1, 0) + self.assertFalse(A.is_contiguous()) + self.assertFalse(b.is_contiguous()) + actual = torch.linalg.solve(A, b) + expected = np.linalg.solve(A.cpu().numpy(), b.cpu().numpy()) + self.assertEqual(actual, expected) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_solve_errors(self, device, dtype): + # solve expects batches of square matrices as input + with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"): + a = torch.randn(2, 3, 4, 3, dtype=dtype, device=device) + b = torch.randn(2, 3, 4, 1, dtype=dtype, device=device) + torch.linalg.solve(a, b) + + # solve expects compatible shapes for A x = b + with self.assertRaisesRegex(RuntimeError, "Incompatible matrix sizes"): + a = torch.randn(2, 3, 3, 3, dtype=dtype, device=device) + b = torch.randn(2, 3, 2, 1, dtype=dtype, device=device) + torch.linalg.solve(a, b) + + # if input is not solvable, RuntimeError is raised mentioning the first non-solvable batch + def run_test_singular_input(batch_dim, n): + a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1) + a[n, -1, -1] = 0 + b = torch.randn(batch_dim, 3, 1, dtype=dtype, device=device) + with self.assertRaisesRegex(RuntimeError, rf'For batch {n}: U\(3,3\) is zero'): + torch.linalg.solve(a, b) + + for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]: + run_test_singular_input(*params) + + # if out is non-empty then it should have correct sizes + with self.assertRaisesRegex(RuntimeError, r'does not match broadcasted other shape'): + out = torch.empty(1, dtype=dtype, device=device) + A = torch.eye(3, dtype=dtype, device=device) + b = torch.randn(3, 1, dtype=dtype, device=device) + torch.linalg.solve(A, b, out=out) + + # if out is non-empty then it should also be Fortran contiguous + with self.assertRaisesRegex(RuntimeError, r'tensor must be in batched column major'): + out = torch.zeros(2, 2, 2, dtype=dtype, device=device).permute(2, 1, 0) + self.assertFalse(out.is_contiguous()) + A = torch.eye(2, dtype=dtype, device=device).reshape((1, 2, 2)).repeat(2, 1, 1) + b = torch.randn(2, 2, 2, dtype=dtype, device=device) + torch.linalg.solve(A, b, out=out) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_old_solve(self, device, dtype): + for (k, n) in zip([2, 3, 5], [3, 5, 7]): + b, A = self.solve_test_helper((n,), (n, k), device, dtype) + x = torch.solve(b, A)[0] + self.assertEqual(b, A.mm(x)) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_old_solve_batched(self, device, dtype): + def solve_batch_helper(A_dims, b_dims): + b, A = self.solve_test_helper(A_dims, b_dims, device, dtype) + x_exp_list = [] + for i in range(b_dims[0]): + x_exp_list.append(torch.solve(b[i], A[i])[0]) + x_exp = torch.stack(x_exp_list) # Stacked output + x_act = torch.solve(b, A)[0] # Actual output + self.assertEqual(x_exp, x_act) # Equality check + Ax = torch.matmul(A, x_act) + self.assertEqual(b, Ax) + + for batchsize in [1, 3, 4]: + solve_batch_helper((5, batchsize), (batchsize, 5, 10)) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_old_solve_batched_non_contiguous(self, device, dtype): + from numpy.linalg import solve + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype).to(device).permute(1, 0, 2) + b = torch.randn(2, 2, 2, dtype=dtype, device=device).permute(2, 1, 0) + x, _ = torch.solve(b, A) + x_exp = solve(A.cpu().numpy(), b.cpu().numpy()) + self.assertEqual(x, x_exp) + + @slowTest + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_old_solve_batched_many_batches(self, device, dtype): + for A_dims, b_dims in zip([(5, 256, 256), (3, )], [(5, 1), (512, 512, 3, 1)]): + b, A = self.solve_test_helper(A_dims, b_dims, device, dtype) + x, _ = torch.solve(b, A) + Ax = torch.matmul(A, x) + self.assertEqual(Ax, b.expand_as(x)) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_old_solve_batched_broadcasting(self, device, dtype): + from numpy.linalg import solve + + def run_test(A_dims, b_dims): + A_matrix_size = A_dims[-1] + A_batch_dims = A_dims[:-2] + b, A = self.solve_test_helper((A_matrix_size,) + A_batch_dims, b_dims, device, dtype) + x, _ = torch.solve(b, A) + x_exp = solve(A.cpu().numpy(), b.cpu().numpy()) + self.assertEqual(x, x_exp) + + # test against numpy.linalg.solve + run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6)) # no broadcasting + run_test((2, 1, 3, 4, 4), (4, 6)) # broadcasting b + run_test((4, 4), (2, 1, 3, 4, 2)) # broadcasting A + run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5)) # broadcasting A & b + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}) + def test_tensorsolve(self, device, dtype): + def run_test(a_shape, dims): + a = torch.randn(a_shape, dtype=dtype, device=device) + b = torch.randn(a_shape[:2], dtype=dtype, device=device) + result = torch.linalg.tensorsolve(a, b, dims=dims) + expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims) + self.assertEqual(result, expected) + + # check the out= variant + out = torch.empty_like(result) + ans = torch.linalg.tensorsolve(a, b, dims=dims, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, result) + + a_shapes = [(2, 3, 6), (3, 4, 4, 3)] + dims = [None, (0, 2)] + for a_shape, d in itertools.product(a_shapes, dims): + run_test(a_shape, d) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_tensorsolve_empty(self, device, dtype): + # Check for empty inputs. NumPy does not work for these cases. + a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device) + b = torch.empty(a.shape[:2], dtype=dtype, device=device) + x = torch.linalg.tensorsolve(a, b) + self.assertEqual(torch.tensordot(a, x, dims=len(x.shape)), b) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}) + def test_tensorsolve_non_contiguous(self, device, dtype): + def run_test_permuted(a_shape, dims): + # check for permuted / transposed inputs + a = torch.randn(a_shape, dtype=dtype, device=device) + a = a.movedim((0, 2), (-2, -1)) + self.assertFalse(a.is_contiguous()) + b = torch.randn(a.shape[:2], dtype=dtype, device=device) + b = b.t() + self.assertFalse(b.is_contiguous()) + result = torch.linalg.tensorsolve(a, b, dims=dims) + expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims) + self.assertEqual(result, expected) + + def run_test_skipped_elements(a_shape, dims): + # check for inputs with skipped elements + a = torch.randn(a_shape, dtype=dtype, device=device) + a = a[::2] + self.assertFalse(a.is_contiguous()) + b = torch.randn(a_shape[:2], dtype=dtype, device=device) + b = b[::2] + self.assertFalse(b.is_contiguous()) + result = torch.linalg.tensorsolve(a, b, dims=dims) + expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims) + self.assertEqual(result, expected) + + # check non-contiguous out + out = torch.empty(2 * result.shape[0], *result.shape[1:], dtype=dtype, device=device)[::2] + self.assertFalse(out.is_contiguous()) + ans = torch.linalg.tensorsolve(a, b, dims=dims, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, result) + + a_shapes = [(2, 3, 6), (3, 4, 4, 3)] + dims = [None, (0, 2)] + for a_shape, d in itertools.product(a_shapes, dims): + run_test_permuted(a_shape, d) + + a_shapes = [(4, 3, 6), (6, 4, 4, 3)] + dims = [None, (0, 2)] + for a_shape, d in itertools.product(a_shapes, dims): + run_test_skipped_elements(a_shape, d) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32) + def test_tensorsolve_errors_and_warnings(self, device, dtype): + # tensorsolve expects the input that can be reshaped to a square matrix + a = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4)) + b = torch.randn(8, 4) + self.assertTrue(np.prod(a.shape[2:]) != np.prod(b.shape)) + with self.assertRaisesRegex(RuntimeError, r'Expected self to satisfy the requirement'): + torch.linalg.tensorsolve(a, b) + + # if non-empty out tensor with wrong shape is passed a warning is given + out = torch.empty_like(a) + b = torch.randn(6, 4) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.linalg.tensorsolve(a, b, out=out) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes should match + out = torch.empty_like(a).to(torch.int) + with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"): + torch.linalg.tensorsolve(a, b, out=out) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float: 1e-3, torch.cfloat: 1e-3}) + def test_tensorinv(self, device, dtype): + + def run_test(a_shape, ind): + a = torch.randn(a_shape, dtype=dtype, device=device) + a_numpy = a.cpu().numpy() + result = torch.linalg.tensorinv(a, ind=ind) + expected = np.linalg.tensorinv(a_numpy, ind=ind) + self.assertEqual(result, expected) + + # check the out= variant + out = torch.empty_like(result) + ans = torch.linalg.tensorinv(a, ind=ind, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, result) + + # compare to NumPy output + run_test((12, 3, 4), ind=1) + run_test((3, 8, 24), ind=2) + run_test((18, 3, 3, 2), ind=1) + run_test((1, 4, 2, 2), ind=2) + run_test((2, 3, 5, 30), ind=3) + run_test((24, 2, 2, 3, 2), ind=1) + run_test((3, 4, 2, 3, 2), ind=2) + run_test((1, 2, 3, 2, 3), ind=3) + run_test((3, 2, 1, 2, 12), ind=4) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float: 1e-3, torch.cfloat: 1e-3}) + def test_tensorinv_non_contiguous(self, device, dtype): + + def run_test(a_shape, ind): + # check for permuted (transposed) case + a = torch.randn(a_shape, dtype=dtype, device=device) + permutation = list(range(0, a.ndim)) + a = a.permute(permutation[ind:] + permutation[:ind]) + self.assertFalse(a.is_contiguous()) + a_numpy = a.cpu().numpy() + result = torch.linalg.tensorinv(a, ind=a.ndim - ind) + expected = np.linalg.tensorinv(a_numpy, ind=a.ndim - ind) + self.assertEqual(result, expected) + + def run_test_skipped_elements(a_shape, ind): + # check for input with skipped elements + a = torch.randn(a_shape, dtype=dtype, device=device) + a = a[::2] + self.assertFalse(a.is_contiguous()) + a_numpy = a.cpu().numpy() + result = torch.linalg.tensorinv(a, ind=ind) + expected = np.linalg.tensorinv(a_numpy, ind=ind) + self.assertEqual(result, expected) + + # check non-contiguous out + out = torch.empty(2 * result.shape[0], *result.shape[1:], dtype=dtype, device=device)[::2] + self.assertFalse(out.is_contiguous()) + ans = torch.linalg.tensorinv(a, ind=ind, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, result) + + run_test((12, 3, 4), ind=1) + run_test((3, 8, 24), ind=2) + run_test((18, 3, 3, 2), ind=1) + run_test((1, 4, 2, 2), ind=2) + run_test((2, 3, 5, 30), ind=3) + run_test((24, 2, 2, 3, 2), ind=1) + run_test((3, 4, 2, 3, 2), ind=2) + run_test((1, 2, 3, 2, 3), ind=3) + run_test((3, 2, 1, 2, 12), ind=4) + + run_test_skipped_elements((12, 3, 2), ind=1) + run_test_skipped_elements((18, 3, 3, 1), ind=1) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_tensorinv_empty(self, device, dtype): + for ind in range(1, 4): + # Check for empty inputs. NumPy does not work for these cases. + a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device) + a_inv = torch.linalg.tensorinv(a, ind=ind) + self.assertEqual(a_inv.shape, a.shape[ind:] + a.shape[:ind]) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_tensorinv_errors_and_warnings(self, device, dtype): + + def check_shape(a_shape, ind): + # tensorinv requires the input to satisfy + # prod(a.shape[ind:]) == prod(a.shape[:ind]) + a = torch.randn(a_shape) + with self.assertRaisesRegex(RuntimeError, "Expected self to satisfy the requirement"): + torch.linalg.tensorinv(a, ind=ind) + + def check_ind(a_shape, ind): + a = torch.randn(a_shape) + with self.assertRaisesRegex(RuntimeError, "Expected a strictly positive integer"): + torch.linalg.tensorinv(a, ind=ind) + + def check_out(a_shape, ind): + # if non-empty out tensor with wrong shape is passed a warning is given + a = torch.randn(a_shape) + out = torch.empty_like(a) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.linalg.tensorinv(a, ind=ind, out=out) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes should match + out = torch.empty_like(a).to(torch.int) + with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"): + torch.linalg.tensorinv(a, ind=ind, out=out) + + # test for invalid shape + check_shape((2, 3, 4), ind=1) + check_shape((1, 2, 3, 4), ind=3) + + # test for invalid ind + check_ind((12, 3, 4), ind=-1) + check_ind((18, 3, 3, 2), ind=0) + + # test for invalid out tensor + check_out((12, 3, 4), ind=1) + check_out((3, 8, 24), ind=2) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_tensorinv_singular_input(self, device, dtype): + + def check_singular_input(a_shape, ind): + prod_ind_end = np.prod(a_shape[ind:]) + a = torch.eye(prod_ind_end, dtype=dtype, device=device) + a[-1, -1] = 0 # Now `a` is singular + a = a.reshape(a_shape) + with self.assertRaisesRegex(RuntimeError, "Failed to invert the input tensor, because it is singular"): + torch.linalg.tensorinv(a, ind=ind) + + # test for non-invertible input + check_singular_input((12, 3, 4), ind=1) + check_singular_input((3, 6, 18), ind=2) + + def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn): + def check(x, y): + # Compare with numpy + res = torch_fn(x, y) + ref = torch.from_numpy(np.array(np_fn(x.cpu().numpy(), y.cpu().numpy()))) + self.assertEqual(res.cpu(), ref) + + # Test out variant + out = torch.empty_like(res) + torch_fn(x, y, out=out) + self.assertEqual(out, res) + + # Empty + x = torch.tensor([], dtype=dtype, device=device) + y = torch.tensor([], dtype=dtype, device=device) + check(x, y) + + # Contiguous + x = torch.randn(10, dtype=dtype, device=device) + y = torch.randn(10, dtype=dtype, device=device) + check(x, y) + + # 0 strided + y = torch.randn(1, dtype=dtype, device=device).expand(10) + check(x, y) + + # 2 strided + check(x[::2], y[::2]) + + @dtypes(torch.float, torch.cfloat) + @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5}) + def test_dot_vs_numpy(self, device, dtype): + self._test_dot_vdot_vs_numpy(device, dtype, torch.dot, np.dot) + + @dtypes(torch.float, torch.cfloat) + @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5}) + def test_vdot_vs_numpy(self, device, dtype): + self._test_dot_vdot_vs_numpy(device, dtype, torch.vdot, np.vdot) + + def _test_dot_vdot_invalid_args(self, device, torch_fn, complex_dtypes=False): + def check(x, y, regex): + with self.assertRaisesRegex(RuntimeError, regex): + torch_fn(x, y) + + if complex_dtypes: + x = torch.randn(1, dtype=torch.cfloat, device=device) + y = torch.randn(3, dtype=torch.cdouble, device=device) + else: + x = torch.randn(1, dtype=torch.float, device=device) + y = torch.randn(3, dtype=torch.double, device=device) + + check(x, y, 'dot : expected both vectors to have same dtype') + check(x.reshape(1, 1), y, '1D tensors expected') + check(x.expand(9), y.to(x.dtype), 'inconsistent tensor size') + + if self.device_type != 'cpu': + x_cpu = x.expand(3).cpu() + check(x_cpu, y.to(x.dtype), 'expected all tensors to be on the same device') + + @onlyOnCPUAndCUDA + def test_vdot_invalid_args(self, device): + self._test_dot_vdot_invalid_args(device, torch.vdot) + self._test_dot_vdot_invalid_args(device, torch.vdot, complex_dtypes=True) + + @onlyOnCPUAndCUDA + def test_dot_invalid_args(self, device): + self._test_dot_vdot_invalid_args(device, torch.dot) + self._test_dot_vdot_invalid_args(device, torch.dot, complex_dtypes=True) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_matrix_rank(self, device, dtype): + matrix_rank = torch.linalg.matrix_rank + + def run_test(shape0, shape1, batch): + a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device) + rank_a = matrix_rank(a) + + self.assertEqual(rank_a, matrix_rank(a.conj().transpose(-2, -1))) + aaH = torch.matmul(a, a.conj().transpose(-2, -1)) + rank_aaH = matrix_rank(aaH) + rank_aaH_hermitian = matrix_rank(aaH, hermitian=True) + self.assertEqual(rank_aaH, rank_aaH_hermitian) + aHa = torch.matmul(a.conj().transpose(-2, -1), a) + self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True)) + + # check against NumPy + self.assertEqual(rank_a, np.linalg.matrix_rank(a.cpu().numpy())) + self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01)) + + self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy())) + self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01)) + + # hermitian flag for NumPy was added in 1.14.0 + if np.lib.NumpyVersion(np.__version__) >= '1.14.0': + self.assertEqual(rank_aaH_hermitian, + np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True)) + self.assertEqual(matrix_rank(aaH, 0.01, True), + np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True)) + + # check out= variant + out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device) + ans = matrix_rank(a, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, rank_a) + + shapes = (3, 13) + batches = ((), (0, ), (4, ), (3, 5, )) + for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches): + run_test(shape0, shape1, batch) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_matrix_rank_empty(self, device, dtype): + matrix_rank = torch.linalg.matrix_rank + + # NumPy doesn't work for input with no elements + def run_test(shape0, shape1, batch): + a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device) + rank_a = matrix_rank(a) + expected = torch.zeros(batch, dtype=torch.int64, device=device) + + self.assertEqual(rank_a, matrix_rank(a.conj().transpose(-2, -1))) + + aaH = torch.matmul(a, a.conj().transpose(-2, -1)) + rank_aaH = matrix_rank(aaH) + rank_aaH_hermitian = matrix_rank(aaH, hermitian=True) + self.assertEqual(rank_aaH, rank_aaH_hermitian) + + aHa = torch.matmul(a.conj().transpose(-2, -1), a) + self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True)) + + self.assertEqual(rank_a, expected) + self.assertEqual(matrix_rank(a, 0.01), expected) + + self.assertEqual(rank_aaH, expected) + self.assertEqual(matrix_rank(aaH, 0.01), expected) + + self.assertEqual(rank_aaH_hermitian, expected) + self.assertEqual(matrix_rank(aaH, 0.01, True), expected) + + batches = ((), (4, ), (3, 5, )) + for batch in batches: + run_test(0, 0, batch) + run_test(0, 3, batch) + run_test(3, 0, batch) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_matrix_rank_basic(self, device, dtype): + matrix_rank = torch.linalg.matrix_rank + + a = torch.eye(10, dtype=dtype, device=device) + self.assertEqual(matrix_rank(a).item(), 10) + self.assertEqual(matrix_rank(a, hermitian=True).item(), 10) + + a[5, 5] = 0 + self.assertEqual(matrix_rank(a).item(), 9) + self.assertEqual(matrix_rank(a, hermitian=True).item(), 9) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_old_matrix_rank(self, device, dtype): + a = torch.eye(10, dtype=dtype, device=device) + self.assertEqual(torch.matrix_rank(a).item(), 10) + self.assertEqual(torch.matrix_rank(a, True).item(), 10) + + a[5, 5] = 0 + self.assertEqual(torch.matrix_rank(a).item(), 9) + self.assertEqual(torch.matrix_rank(a, True).item(), 9) + + a = torch.randn(24, 42, dtype=dtype, device=device) + self.assertEqual(torch.matrix_rank(a), torch.matrix_rank(a.t())) + aaT = torch.mm(a, a.conj().t()) + self.assertEqual(torch.matrix_rank(aaT), torch.matrix_rank(aaT, True)) + aTa = torch.mm(a.conj().t(), a) + self.assertEqual(torch.matrix_rank(aTa), torch.matrix_rank(aTa, True)) + + a = torch.randn(35, 75, dtype=dtype, device=device) + self.assertEqual(torch.matrix_rank(a), np.linalg.matrix_rank(a.cpu().numpy())) + self.assertEqual(torch.matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01)) + + aaT = torch.mm(a, a.conj().t()) + self.assertEqual(torch.matrix_rank(aaT), np.linalg.matrix_rank(aaT.cpu().numpy())) + self.assertEqual(torch.matrix_rank(aaT, 0.01), np.linalg.matrix_rank(aaT.cpu().numpy(), 0.01)) + + if np.lib.NumpyVersion(np.__version__) >= '1.14.0': + self.assertEqual(torch.matrix_rank(aaT, True), np.linalg.matrix_rank(aaT.cpu().numpy(), True)) + self.assertEqual(torch.matrix_rank(aaT, 0.01, True), np.linalg.matrix_rank(aaT.cpu().numpy(), 0.01, True)) + + @precisionOverride({torch.float32: 5e-6, torch.complex64: 5e-6}) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_qr(self, device, dtype): + def run_test(tensor_dims, some): + A = torch.randn(*tensor_dims, dtype=dtype, device=device) + Q, R = torch.qr(A, some=some) + + # Check0: Q[-2:] = (m, n_columns), R[-2:] = (n_columns, n) + m, n = tensor_dims[-2:] + n_columns = m if (not some) and m > n else min(m, n) + self.assertEqual(Q.size(-2), m) + self.assertEqual(R.size(-1), n) + self.assertEqual(Q.size(-1), n_columns) + + A_ = A.cpu().numpy() + Q_ = Q.cpu().numpy() + R_ = R.cpu().numpy() + + # Check1: A = QR + self.assertEqual(A_, np.matmul(Q_, R_)) + + # Check2: A = QR (with out) + Q_out, R_out = torch.full_like(Q, math.nan), torch.full_like(R, math.nan) + torch.qr(A, some=some, out=(Q_out, R_out)) + Q_out_ = Q_out.cpu().numpy() + R_out_ = R_out.cpu().numpy() + self.assertEqual(A_, np.matmul(Q_out_, R_out_)) + + # Check3: Q == Q_out, R == R_out + self.assertEqual(Q_, Q_out_) + self.assertEqual(R_, R_out_) + + # Check4: Q^{T}Q = I, triu(R) = R + eye = torch.eye(n_columns, device=device, dtype=dtype).expand(Q.shape[:-2] + (n_columns, n_columns)).cpu().numpy() + self.assertEqual(np.matmul(Q_.swapaxes(-1, -2).conj(), Q_), eye) + self.assertEqual(R.triu(), R) + + tensor_dims_list = [(3, 5), (5, 5), (5, 3), # Single matrix + (7, 3, 5), (7, 5, 5), (7, 5, 3), # 3-dim Tensors + (7, 5, 3, 5), (7, 5, 5, 5), (7, 5, 5, 3)] # 4-dim Tensors + for tensor_dims, some in itertools.product(tensor_dims_list, [True, False]): + run_test(tensor_dims, some) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_qr_vs_numpy(self, device, dtype): + """ + test torch.linalg.qr vs numpy.linalg.qr + """ + sizes_to_test = [ + (7, 5), + (5, 7), + (5, 0), # empty + (0, 5), # empty + ] + for size in sizes_to_test: + t = torch.randn(size, device=device, dtype=dtype) + np_t = t.cpu().numpy() + for mode in ['reduced', 'complete']: + exp_q, exp_r = np.linalg.qr(np_t, mode=mode) + q, r = torch.linalg.qr(t, mode=mode) + self.assertEqual(q, exp_q) + self.assertEqual(r, exp_r) + # + # for mode='r' we need a special logic because numpy returns only r + exp_r = np.linalg.qr(np_t, mode='r') + q, r = torch.linalg.qr(t, mode='r') + # check that q is empty + self.assertEqual(q.shape, (0,)) + self.assertEqual(q.dtype, t.dtype) + self.assertEqual(q.device, t.device) + # check r + self.assertEqual(r, exp_r) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float) + def test_linalg_qr_autograd_errors(self, device, dtype): + # torch.linalg.qr(mode='r') returns only 'r' and discards 'q', but + # without 'q' you cannot compute the backward pass. Check that + # linalg_qr_backward complains cleanly in that case. + inp = torch.randn((5, 7), device=device, dtype=dtype, requires_grad=True) + q, r = torch.linalg.qr(inp, mode='r') + self.assertEqual(q.shape, (0,)) # empty tensor + b = torch.sum(r) + with self.assertRaisesRegex(RuntimeError, + "The derivative of qr is not implemented when mode='r'"): + b.backward() + # + inp = torch.randn((7, 5), device=device, dtype=dtype, requires_grad=True) + q, r = torch.linalg.qr(inp, mode='complete') + b = torch.sum(r) + with self.assertRaisesRegex(RuntimeError, + "The derivative of qr is not implemented when mode='complete' and nrows > ncols"): + b.backward() + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_qr_batched(self, device, dtype): + """ + test torch.linalg.qr vs numpy.linalg.qr. We need some special logic + because numpy does not support batched qr + """ + def np_qr_batched(a, mode): + """poor's man batched version of np.linalg.qr""" + all_q = [] + all_r = [] + for matrix in a: + result = np.linalg.qr(matrix, mode=mode) + if mode == 'r': + all_r.append(result) + else: + q, r = result + all_q.append(q) + all_r.append(r) + if mode == 'r': + return np.array(all_r) + else: + return np.array(all_q), np.array(all_r) + + t = torch.randn((3, 7, 5), device=device, dtype=dtype) + np_t = t.cpu().numpy() + for mode in ['reduced', 'complete']: + exp_q, exp_r = np_qr_batched(np_t, mode=mode) + q, r = torch.linalg.qr(t, mode=mode) + self.assertEqual(q, exp_q) + self.assertEqual(r, exp_r) + # for mode='r' we need a special logic because numpy returns only r + exp_r = np_qr_batched(np_t, mode='r') + q, r = torch.linalg.qr(t, mode='r') + # check that q is empty + self.assertEqual(q.shape, (0,)) + self.assertEqual(q.dtype, t.dtype) + self.assertEqual(q.device, t.device) + # check r + self.assertEqual(r, exp_r) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_qr_out(self, device, dtype): + """ + test torch.linalg.qr(out=...) vs torch.lingalg.qr + """ + sizes_to_test = [ + (7, 5), + (5, 7), + (5, 0), # empty + (0, 5), # empty + ] + for size in sizes_to_test: + t = torch.randn(size, device=device, dtype=dtype) + np_t = t.cpu().numpy() + for mode in ['reduced', 'complete', 'r']: + q, r = torch.linalg.qr(t, mode=mode) + out = (torch.empty((0), dtype=dtype, device=device), + torch.empty((0), dtype=dtype, device=device)) + q2, r2 = torch.linalg.qr(t, mode=mode, out=out) + self.assertIs(q2, out[0]) + self.assertIs(r2, out[1]) + self.assertEqual(q2, q) + self.assertEqual(r2, r) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float) + def test_qr_error_cases(self, device, dtype): + t1 = torch.randn(5, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, 'qr input should have at least 2 dimensions, but has 1 dimensions instead'): + torch.linalg.qr(t1) + t2 = torch.randn((5, 7), device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, "qr received unrecognized mode 'hello'"): + torch.linalg.qr(t2, mode='hello') + + @dtypes(torch.double, torch.cdouble) + def test_einsum(self, device, dtype): + def check(equation, *operands): + ref = np.einsum(equation, *[operand.cpu().numpy() for operand in operands]) + res = torch.einsum(equation, operands) + self.assertEqual(res.cpu(), torch.from_numpy(np.array(ref))) + + # Check autograd + ops = [op.detach().requires_grad_() for op in operands] + self.assertTrue(torch.autograd.gradcheck(lambda *ops: torch.einsum(equation, ops), ops)) + for op in ops: + self.assertTrue(op._version == 0) + + # Test cases from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f + x = torch.rand(5, device=device, dtype=dtype) + y = torch.rand(7, device=device, dtype=dtype) + A = torch.randn(3, 5, device=device, dtype=dtype) + B = torch.randn(2, 5, device=device, dtype=dtype) + C = torch.randn(2, 3, 5, device=device, dtype=dtype) + D = torch.randn(2, 5, 7, device=device, dtype=dtype) + E = torch.randn(7, 9, device=device, dtype=dtype) + F = torch.randn(2, 3, 3, 5, device=device, dtype=dtype) + G = torch.randn(5, 4, 6, device=device, dtype=dtype) + H = torch.randn(4, 4, device=device, dtype=dtype) + I = torch.rand(2, 3, 2, device=device, dtype=dtype) + + # Note: gradcheck fails if the same input is given multiple times which is why the + # calls to clone below. (see https://github.com/pytorch/pytorch/issues/9282) + + # Vector operations + check('i->', x) # sum + check('i,i->', x, x.clone()) # dot + check('i,i->i', x, x.clone()) # vector element-wisem mul + check('i,j->ij', x, y) # outer + + # Matrix operations + check("ij->ji", A) # transpose + check("ij->j", A) # row sum + check("ij->i", A) # col sum + check("ij,ij->ij", A, A.clone()) # matrix element-wise mul + check("ij,j->i", A, x) # matrix vector multiplication + check("ij,kj->ik", A, B) # matmul + check("ij,ab->ijab", A, E) # matrix outer product + + # Tensor operations + check("aij,ajk->aik", C, D) # batch matmul + check("ijk,jk->i", C, A) # tensor matrix contraction + check("aij,jk->aik", D, E) # tensor matrix contraction + check("abcd,dfg->abcfg", F, G) # tensor tensor contraction + check("ijk,jk->ik", C, A) # tensor matrix contraction with double indices + check("ijk,jk->ij", C, A) # tensor matrix contraction with double indices + check("ijk,ik->j", C, B) # non contiguous + check("ijk,ik->jk", C, B) # non contiguous with double indices + + # Test diagonals + check("ii", H) # trace + check("ii->i", H) # diagonal + check('iji->j', I) # non-contiguous trace + check('ngrg...->nrg...', torch.rand((2, 1, 3, 1, 4), device=device, dtype=dtype)) + + # Test ellipsis + check("i...->...", H) + check("ki,...k->i...", A.t(), B) + check("k...,jk->...", A.t(), B) + check('...ik, ...j -> ...ij', C, x) + check('bik,k...j->i...j', C, torch.rand(5, 3, device=device, dtype=dtype)) + check('i...j, ij... -> ...ij', C, torch.rand(2, 5, 2, 3, device=device, dtype=dtype)) + + # torch.bilinear with discontiguous tensors + l = torch.randn(10, 5, device=device, dtype=dtype).transpose(0, 1) + r = torch.randn(20, 5, device=device, dtype=dtype).transpose(0, 1) + w = torch.randn(15, 10, 20, device=device, dtype=dtype) + check("bn,anm,bm->ba", l, w, r) + + # with strided tensors + check("bn,anm,bm->ba", l[:, ::2], w[:, ::2, ::2], r[:, ::2]) + + @dtypes(torch.double, torch.cdouble) + def test_einsum_random(self, device, dtype): + def check(equation, *operands): + ref = np.einsum(equation, *[op.cpu().numpy() for op in operands]) + res = torch.einsum(equation, operands) + self.assertEqual(res.cpu(), torch.from_numpy(np.array(ref))) + + for _ in range(20): + # Create a random number of input operands, each with a random + # number of dimensions randomly labeled. + op_labels = [] + valid_labels = set() + for _ in range(random.randint(1, 3)): + labels = np.random.randint(0, 10, random.randint(1, 5)) + op_labels.append(labels) + valid_labels.update(labels) + label_size = np.random.randint(1, 5, 10) + ell_sizes = np.random.randint(1, 5, 3) + + # Build equation and tensors from input operand labels. + ops = [] + equation = '' + for labels in op_labels: + sizes = [label_size[label] for label in labels] + labels = [chr(ord('a') + label) for label in labels] + + # Add ellipsis dimensions at random + ell_num_dim = random.randint(0, 3) + if ell_num_dim > 0: + ell_index = random.randint(0, len(labels)) + sizes[ell_index:ell_index] = ell_sizes[-ell_num_dim:] + labels.insert(ell_index, "...") + + equation += ''.join(labels) + ',' + ops.append(torch.rand(sizes, device=device, dtype=dtype)) + equation = equation[:-1] + + # Test with implicit output + check(equation, *ops) + + # Randomly choose some labels to be part of the output + out_labels = np.unique(np.random.choice(list(valid_labels), random.randint(1, len(valid_labels)))) + out_labels = [chr(ord('a') + label) for label in out_labels] + ell_index = random.randint(0, len(out_labels)) + out_labels.insert(ell_index, '...') + equation += '->' + ''.join(out_labels) + + # Randomly test the output + check(equation, *ops) + + def test_einsum_corner_cases(self, device): + def check(equation, *operands, expected_output): + tensors = [torch.tensor(operand, dtype=torch.float32, device=device) if not isinstance(operand, tuple) + else torch.rand(operand, dtype=torch.float32, device=device) for operand in operands] + output = torch.einsum(equation, tensors) + self.assertEqual(output, torch.tensor(expected_output, dtype=torch.float32, device=device)) + + # Test equation variantions + check(' ', 1, expected_output=1) + check(' -> ', 1, expected_output=1) + check(' , ', 2, 2, expected_output=4) + check(' , , ', 2, 2, 2, expected_output=8) + check(' , -> ', 2, 2, expected_output=4) + check(' i ', [1], expected_output=[1]) + check(' i -> ', [1], expected_output=1) + check(' i -> i ', [1], expected_output=[1]) + check(' i , i ', [2], [2], expected_output=4) + check(' i , i -> i ', [2], [2], expected_output=[4]) + + # Test tensors with 0 size dimensions + check('i', [], expected_output=[]) + check(' i j -> j', [[], []], expected_output=[]) + check('ij->i', [[], []], expected_output=[0., 0.]) + check(' i j k , k -> i j ', (3, 0, 6), (6,), expected_output=[[], [], []]) + + # Test broadcasting + check('i,j', [2], [1, 2], expected_output=[[2, 4]]) + check('i,ij->ij', [1, 2], [[1, 2, 3], [2, 3, 4]], expected_output=[[1, 2, 3], [4, 6, 8]]) + + # Test ellipsis broadcasting + check('...', 1, expected_output=1) + check('...->', 1, expected_output=1) + check('...->...', 1, expected_output=1) + check('...', [1], expected_output=[1]) + check('...->', [1], expected_output=1) + check('i...->i', [1], expected_output=[1]) + check('i...->...i', [1], expected_output=[1]) + check('...a->', [[2], [4]], expected_output=6) + check('a...b->ab', [[[1], [2]], [[3], [4]]], expected_output=[[3], [7]]) + + def test_einsum_error_cases(self, device): + def check(equation, operands, regex, exception=RuntimeError): + with self.assertRaisesRegex(exception, r'einsum\(\) ' + regex): + torch.einsum(equation, operands) + + x = torch.rand(2) + y = torch.rand(2, 3) + + check('', [], r'must provide at least one operand') + check('. ..', [x], r'found \'.\' for operand 0 that is not part of any ellipsis') + check('... ...', [x], r'found \'.\' for operand 0 for which an ellipsis was already found') + check('A', [x], r'operand subscript must be in range \[a, z\] but found A for operand 0') + check(',', [x], r'fewer operands were provided than specified in the equation') + check('', [x, x], r'more operands were provided than specified in the equation') + check('', [x], r'the number of subscripts in the equation \(0\) does not match the number ' + r'of dimensions \(1\) for operand 0 and no ellipsis was given') + check('ai', [x], r'the number of subscripts in the equation \(2\) does not match the number ' + r'of dimensions \(1\) for operand 0 and no ellipsis was given') + check('ai...', [x], r'the number of subscripts in the equation \(2\) is more than the number ' + r'of dimensions \(1\) for operand 0') + check('a->... .', [x], r'found \'.\' for output but an ellipsis \(...\) was already found') + check('a->..', [x], r'found \'.\' for output that is not part of any ellipsis \(...\)') + check('a->A', [x], r'subscripts must be in range \[a, z\] but found A for the output') + check('a->aa', [x], r'output subscript a appears more than once in the output') + check('a->i', [x], r'output subscript i does not appear in the equation for any input operand') + check('aa', [y], r'subscript a is repeated for operand 0 but the sizes don\'t match, 3 != 2') + check('a, ba', [x, y], r'operands do not broadcast with remapped shapes \[original->remapped\]: ' + r'\[2\]->\[1, 2\] \[2, 3\]->\[2, 3\]') + + def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular, + device, dtype): + triangle_function = torch.triu if upper else torch.tril + b = torch.randn(*b_dims, dtype=dtype, device=device) + A = torch.randn(*A_dims, dtype=dtype, device=device) + # create positive definite matrix + A = torch.matmul(A, A.transpose(-2, -1)) + A_triangular = triangle_function(A) + if unitriangular: + A_triangular.diagonal(dim1=-2, dim2=-1).fill_(1.) + return b, A_triangular + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, + torch.float64: 1e-8, torch.complex128: 1e-8}) + def test_triangular_solve(self, device, dtype): + for (k, n), (upper, unitriangular, transpose) in itertools.product(zip([2, 3, 5], [3, 5, 7]), + itertools.product([True, False], repeat=3)): + b, A = self.triangular_solve_test_helper((n, n), (n, k), upper, + unitriangular, device, dtype) + x = torch.triangular_solve(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0] + if transpose: + self.assertEqual(b, A.t().mm(x)) + else: + self.assertEqual(b, A.mm(x)) + + @skipCPUIfNoLapack + @skipCUDAIfNoMagma + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, + torch.float64: 1e-8, torch.complex128: 1e-8}) + def test_triangular_solve_batched(self, device, dtype): + def triangular_solve_batch_helper(A_dims, b_dims, upper, unitriangular, transpose): + b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper, + unitriangular, device, dtype) + x_exp_list = [] + for i in range(b_dims[0]): + x_exp_list.append(torch.triangular_solve(b[i], A[i], upper=upper, + unitriangular=unitriangular, + transpose=transpose)[0]) + x_exp = torch.stack(x_exp_list) # Stacked output + x_act = torch.triangular_solve(b, A, upper=upper, + unitriangular=unitriangular, + transpose=transpose)[0] # Actual output + self.assertEqual(x_act, x_exp) # Equality check + if transpose: + A = A.transpose(-2, -1) + + Ax = torch.matmul(A, x_act) + self.assertEqual(b, Ax) + + for (upper, unitriangular, transpose), batchsize in itertools.product(itertools.product( + [True, False], repeat=3), [1, 3, 4]): + triangular_solve_batch_helper((batchsize, 5, 5), (batchsize, 5, 10), + upper, unitriangular, transpose) + + + @slowTest + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, + torch.float64: 1e-8, torch.complex128: 1e-8}) + def test_triangular_solve_batched_many_batches(self, device, dtype): + for upper, transpose, unitriangular in itertools.product([True, False], repeat=3): + # test batched A case + b, A = self.triangular_solve_test_helper((256, 256, 5, 5), (5, 1), + upper, unitriangular, device, dtype) + x, _ = torch.triangular_solve(b, A, + upper=upper, transpose=transpose, unitriangular=unitriangular) + if transpose: + A = A.transpose(-2, -1) + + Ax = torch.matmul(A, x) + + rtol = 1e-2 if dtype in [torch.float32, torch.complex64] else self.precision + self.assertEqual(Ax, b.expand_as(Ax), atol=self.precision, rtol=rtol) + + # test batched b case + b, A = self.triangular_solve_test_helper((3, 3), (512, 512, 3, 1), + upper, unitriangular, device, dtype) + x, _ = torch.triangular_solve(b, A, upper=upper, transpose=transpose, + unitriangular=unitriangular) + if transpose: + A = A.transpose(-2, -1) + + self.assertEqual(torch.matmul(A, x), b) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_triangular_solve_batched_broadcasting(self, device, dtype): + from scipy.linalg import solve_triangular as tri_solve + + def scipy_tri_solve_batched(A, B, upper, trans, diag): + batch_dims_A, batch_dims_B = A.shape[:-2], B.shape[:-2] + single_dim_A, single_dim_B = A.shape[-2:], B.shape[-2:] + expand_dims = tuple(torch._C._infer_size(torch.Size(batch_dims_A), + torch.Size(batch_dims_B))) + expand_A = np.broadcast_to(A, expand_dims + single_dim_A) + expand_B = np.broadcast_to(B, expand_dims + single_dim_B) + flat_A = expand_A.reshape((-1,) + single_dim_A) + flat_B = expand_B.reshape((-1,) + single_dim_B) + flat_X = np.vstack([tri_solve(a, b, lower=(not upper), trans=int(trans), unit_diagonal=diag) + for a, b in zip(flat_A, flat_B)]) + return flat_X.reshape(expand_B.shape) + + def run_test(A_dims, b_dims, device, upper, transpose, unitriangular): + b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper, + unitriangular, device, dtype) + x_exp = torch.as_tensor(scipy_tri_solve_batched(A.cpu().numpy(), b.cpu().numpy(), + upper, transpose, unitriangular)) + x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0] + + self.assertEqual(x, x_exp.to(device)) + + for upper, transpose, unitriangular in itertools.product([True, False], repeat=3): + # test against scipy.linalg.solve_triangular + run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), device, upper, transpose, unitriangular) # no broadcasting + run_test((2, 1, 3, 4, 4), (4, 6), device, upper, transpose, unitriangular) # broadcasting b + run_test((4, 4), (2, 1, 3, 4, 2), device, upper, transpose, unitriangular) # broadcasting A + run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device, upper, transpose, unitriangular) # broadcasting A & b + + @onlyCPU + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_triangular_solve_singular(self, device, dtype): + b = torch.rand(3, 1, dtype=dtype, device=device) + A = torch.eye(3, 3, dtype=dtype, device=device) + A[-1, -1] = 0 # Now A is singular + err_str = r"triangular_solve_cpu: U\(3,3\) is zero, singular U\." + with self.assertRaisesRegex(RuntimeError, err_str): + torch.triangular_solve(b, A) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float64, torch.complex128) + def test_triangular_solve_autograd(self, device, dtype): + def run_test(A_dims, B_dims): + A = torch.rand(*A_dims, dtype=dtype).requires_grad_() + b = torch.rand(*B_dims, dtype=dtype).requires_grad_() + + for upper, transpose, unitriangular in itertools.product((True, False), repeat=3): + def func(A, b): + return torch.triangular_solve(b, A, upper, transpose, unitriangular) + + gradcheck(func, [A, b]) + gradgradcheck(func, [A, b]) + + run_test((3, 3), (3, 4)) + run_test((3, 3), (3, 2)) + run_test((2, 3, 3), (2, 3, 4)) + run_test((2, 3, 3), (2, 3, 2)) + + def check_single_matmul(self, x, y, shape): + a = np.array(x, copy=False) + b = np.array(y, copy=False) + expected = np.matmul(a, b) + + ans = torch.matmul(x, y) + self.assertTrue(ans.is_contiguous()) + self.assertTrue(np.array_equal(ans, expected)) + + out = torch.zeros(*shape, dtype=torch.int64).to(x.device) + ans = torch.matmul(x, y, out=out) + self.assertIs(ans, out) + self.assertTrue(ans.is_contiguous()) + self.assertTrue(np.array_equal(ans, expected)) + + # TODO: update to run on CUDA, too + @onlyCPU + def test_matmul_small_brute_force_1d_Nd(self, device): + # Issue #20452: range(0, 10) does not work. + n = 1 + for m in range(1, 8): + for p in range(1, 8): + for o in range(1, 5): + # 1d, 3d, inner dimensions C + x = torch.arange(m, device=device) + y = torch.arange(o * m * p, device=device).reshape(o, m, p) + self.check_single_matmul(x, y, (o, n, p)) + + # 1d, 3d, inner dimensions Fortran + x = torch.arange(m, device=device) + y = torch.arange(o * p * m, device=device).reshape(o, p, m).transpose(-1, -2) + self.check_single_matmul(x, y, (o, n, p)) + + # 1d, 3d, inner dimensions non-contiguous + x = torch.arange(2 * m, device=device)[::2] + y = torch.arange(o * m * 2 * p, device=device).reshape(o, m, 2 * p)[:, :, ::2] + self.check_single_matmul(x, y, (o, n, p)) + + for r in range(1, 5): + # 1d, 4d, inner dimensions C + x = torch.arange(m) + y = torch.arange(r * o * m * p, device=device).reshape(r, o, m, p) + self.check_single_matmul(x, y, (r, o, n, p)) + + # 1d, 4d, inner dimensions Fortran + x = torch.arange(m) + y = torch.arange(r * o * p * m, device=device).reshape(r, o, p, m).transpose(-1, -2) + self.check_single_matmul(x, y, (r, o, n, p)) + + # 1d, 4d, inner dimensions non-contiguous + x = torch.arange(2 * m, device=device)[::2] + y = torch.arange(r * o * m * 2 * p, device=device).reshape(r, o, m, 2 * p)[:, :, :, ::2] + self.check_single_matmul(x, y, (r, o, n, p)) + + # TODO: update to run on CUDA, too + @onlyCPU + def test_matmul_small_brute_force_2d_Nd(self, device): + # Issue #20452: range(0, 10) does not work. + for n in range(1, 5): + for m in range(1, 5): + for p in range(1, 5): + for o in range(1, 3): + # 2d, 3d, inner dimensions C + x = torch.arange(n * m, device=device).reshape(n, m) + y = torch.arange(o * m * p, device=device).reshape(o, m, p) + self.check_single_matmul(x, y, (o, n, p)) + + # 2d, 3d, inner dimensions Fortran + x = torch.arange(m * n, device=device).reshape(m, n).transpose(-1, -2) + y = torch.arange(o * p * m, device=device).reshape(o, p, m).transpose(-1, -2) + self.check_single_matmul(x, y, (o, n, p)) + + # 2d, 3d, inner dimensions non-contiguous + x = torch.arange(n * 2 * m, device=device).reshape(n, 2 * m)[:, ::2] + y = torch.arange(o * m * 2 * p, device=device).reshape(o, m, 2 * p)[:, :, ::2] + self.check_single_matmul(x, y, (o, n, p)) + + for r in range(1, 2): + # 2d, 4d, inner dimensions C + x = torch.arange(n * m, device=device).reshape(n, m) + y = torch.arange(r * o * m * p, device=device).reshape(r, o, m, p) + self.check_single_matmul(x, y, (r, o, n, p)) + + # 2d, 4d, inner dimensions Fortran + x = torch.arange(m * n, device=device).reshape(m, n).transpose(-1, -2) + y = torch.arange(r * o * p * m, device=device).reshape(r, o, p, m).transpose(-1, -2) + self.check_single_matmul(x, y, (r, o, n, p)) + + # 2d, 4d, inner dimensions non-contiguous + x = torch.arange(n * 2 * m, device=device).reshape(n, 2 * m)[:, ::2] + y = torch.arange(r * o * m * 2 * p, device=device).reshape(r, o, m, 2 * p)[:, :, :, ::2] + self.check_single_matmul(x, y, (r, o, n, p)) + + def test_linear_algebra_scalar_raises(self, device) -> None: + m = torch.randn(5, 5, device=device) + v = torch.randn(5, device=device) + s = torch.tensor(7, device=device) + self.assertRaises(RuntimeError, lambda: torch.mv(m, s)) + self.assertRaises(RuntimeError, lambda: torch.addmv(v, m, s)) + + @onlyCPU + @dtypes(torch.float) + def test_cross(self, device, dtype): + x = torch.rand(100, 3, 100, dtype=dtype, device=device) + y = torch.rand(100, 3, 100, dtype=dtype, device=device) + res1 = torch.cross(x, y) + res2 = torch.tensor((), dtype=dtype, device=device) + torch.cross(x, y, out=res2) + self.assertEqual(res1, res2) + + @onlyCPU + @dtypes(torch.float) + def test_cross_with_and_without_dim(self, device, dtype): + x = torch.rand(100, 3, dtype=dtype, device=device) + y = torch.rand(100, 3, dtype=dtype, device=device) + res1 = torch.cross(x, y, dim=1) + res2 = torch.cross(x, y, dim=-1) + res3 = torch.cross(x, y) + self.assertEqual(res1, res2) + self.assertEqual(res1, res3) + + def test_cross_errors(self, device): + self.assertRaisesRegex( + RuntimeError, "inconsistent tensors dimensions", + lambda: torch.cross(torch.rand(100, 3, device=device), torch.rand(100, 3, 10, device=device))) + self.assertRaisesRegex( + RuntimeError, "inconsistent tensors sizes", + lambda: torch.cross(torch.rand(5, 3, device=device), torch.rand(3, 5, device=device))) + self.assertRaisesRegex( + RuntimeError, "no dimension of size 3 in input", + lambda: torch.cross(torch.rand(5, 4, device=device), torch.rand(5, 4, device=device))) + self.assertRaisesRegex( + RuntimeError, "dimension 0 does not have size 3", + lambda: torch.cross(torch.rand(5, 4, 3, device=device), torch.rand(5, 4, 3, device=device), dim=0)) + self.assertRaisesRegex( + RuntimeError, "dimension -1 does not have size 3", + lambda: torch.cross(torch.rand(5, 3, 4, device=device), torch.rand(5, 3, 4, device=device), dim=-1)) + self.assertRaisesRegex( + IndexError, "Dimension out of range", + lambda: torch.cross(torch.rand(5, 3, 4, device=device), torch.rand(5, 3, 4, device=device), dim=-5)) + + def test_renorm(self, device): + m1 = torch.randn(10, 5, device=device) + res1 = torch.tensor((), device=device) + + def renorm(matrix, value, dim, max_norm): + m1 = matrix.transpose(dim, 0).contiguous() + # collapse non-dim dimensions. + m2 = m1.clone().resize_(m1.size(0), int(math.floor(m1.nelement() / m1.size(0)))) + norms = m2.norm(value, 1, True) + # clip + new_norms = norms.clone() + new_norms[torch.gt(norms, max_norm)] = max_norm + new_norms.div_(norms.add_(1e-7)) + # renormalize + m1.mul_(new_norms.expand_as(m1)) + return m1.transpose(dim, 0) + + # note that the axis fed to torch.renorm is different (2~=1) + maxnorm = m1.norm(2, 1).mean() + m2 = renorm(m1, 2, 1, maxnorm) + m1.renorm_(2, 1, maxnorm) + self.assertEqual(m1, m2, atol=1e-5, rtol=0) + self.assertEqual(m1.norm(2, 0), m2.norm(2, 0), atol=1e-5, rtol=0) + + m1 = torch.randn(3, 4, 5, device=device) + m2 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4) + maxnorm = m2.norm(2, 0).mean() + m2 = renorm(m2, 2, 1, maxnorm) + m1.renorm_(2, 1, maxnorm) + m3 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4) + self.assertEqual(m3, m2) + self.assertEqual(m3.norm(2, 0), m2.norm(2, 0)) + + # TODO: make this work on CUDA, too + @onlyCPU + @skipCPUIfNoLapack + def test_ormqr(self, device): + mat1 = torch.randn(7, 7) + mat2 = torch.randn(7, 7) + q, r = torch.qr(mat1) + m, tau = torch.geqrf(mat1) + out_holder = torch.empty_like(mat1) + + res1 = torch.mm(q, mat2) + res2 = torch.ormqr(m, tau, mat2, left=True, transpose=False) + torch.ormqr(m, tau, mat2, out=out_holder) + self.assertEqual(res1, res2) + self.assertEqual(res2, out_holder) + + res1 = torch.mm(mat2, q) + res2 = torch.ormqr(m, tau, mat2, left=False, transpose=False) + torch.ormqr(m, tau, mat2, left=False, transpose=False, out=out_holder) + self.assertEqual(res1, res2) + self.assertEqual(res2, out_holder) + + res1 = torch.mm(q.t(), mat2) + res2 = torch.ormqr(m, tau, mat2, left=True, transpose=True) + torch.ormqr(m, tau, mat2, left=True, transpose=True, out=out_holder) + self.assertEqual(res1, res2) + self.assertEqual(res2, out_holder) + + res1 = torch.mm(mat2, q.t()) + res2 = torch.ormqr(m, tau, mat2, left=False, transpose=True) + torch.ormqr(m, tau, mat2, left=False, transpose=True, out=out_holder) + self.assertEqual(res1, res2) + self.assertEqual(res2, out_holder) + + @skipCUDAIfRocm + def test_blas_empty(self, device): + def fn(torchfn, *args, test_out=False, **kwargs): + def call_torch_fn(*args, **kwargs): + return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape + for shape in args), **kwargs) + result = call_torch_fn(*args, **kwargs) + if not test_out: + return result + else: + out = torch.full_like(result, math.nan) + out1 = call_torch_fn(*args, **kwargs, out=out) + return out + + # mm, addmm + self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape) + self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape) + self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape) + self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape) + self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6))) + self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6), test_out=True)) + + self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape) + self.assertEqual((0, 1), fn(torch.addmm, (1, ), (0, 17), (17, 1)).shape) + t = torch.randn((5, 6), device=device) + self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6))) + self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True)) + + # mv, addmv + self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape) + self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape) + self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,))) + self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), test_out=True)) + + self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape) + t = torch.randn((3,), device=device) + self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,))) + self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True)) + + # bmm, baddbmm + self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape) + self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape) + self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape) + self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6))) + self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True)) + + self.assertEqual((0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape) + self.assertEqual((3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape) + self.assertEqual((0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape) + self.assertEqual((3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape) + c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5) + self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2)) # Issue #33467 + self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True)) # Issue #33467 + + # addbmm + self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape) + self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape) + t = torch.randn((5, 6), device=device) + self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6))) + self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True)) + + # matmul + self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,))) + self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,), test_out=True)) + self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape) + self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape) + self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape) + self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4))) + self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True)) + + # dot + self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,))) + self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,), test_out=True)) + + if torch._C.has_lapack: + # lu + A_LU, pivots = fn(torch.lu, (0, 5, 5)) + self.assertEqual([(0, 5, 5), (0, 5)], [A_LU.shape, pivots.shape]) + A_LU, pivots = fn(torch.lu, (0, 0, 0)) + self.assertEqual([(0, 0, 0), (0, 0)], [A_LU.shape, pivots.shape]) + A_LU, pivots = fn(torch.lu, (2, 0, 0)) + self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape]) + + @skipCUDAIfRocm + @dtypesIfCUDA(*(torch.float, torch.double, torch.cfloat, torch.cdouble) + + # This test is disabled on CUDA 9, due to: + # See: https://github.com/pytorch/pytorch/issues/31006 + ((torch.half,) if torch.version.cuda and not torch.version.cuda.startswith('9.') else ())) + @dtypes(*(set(torch.testing.get_all_dtypes()) - {torch.half, torch.bool})) + def test_blas_alpha_beta_empty(self, device, dtype): + if dtype is torch.bfloat16 and self.device_type == 'xla': + # TODO (@zasdfgbnm): this causes the following error on test + # TestTorchDeviceTypeXLA.test_blas_alpha_beta_empty_xla_bfloat16: + # + # RuntimeError: _th_equal not supported on CPUType for BFloat16 + return + # ensure beta is respected + value = 11 + input = torch.full((2,), value, dtype=dtype, device=device) + mat = torch.ones((2, 0), dtype=dtype, device=device) + vec = torch.ones((0,), dtype=dtype, device=device) + out = torch.empty((2,), dtype=dtype, device=device) + if dtype.is_complex: + alpha = 6 + 7j + beta = 3 + 4j + else: + alpha = 6 + beta = 3 + self.assertEqual(torch.full((2,), beta * value, dtype=dtype, device=device), + torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta)) + self.assertEqual(torch.full((2,), beta * value, dtype=dtype, device=device), + torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta, out=out)) + + # torch.addmm + input = torch.full((2, 3), value, dtype=dtype, device=device) + mat2 = torch.ones((0, 3), dtype=dtype, device=device) + out = torch.empty((2, 3), dtype=dtype, device=device) + self.assertEqual(torch.full((2, 3), beta * value, dtype=dtype, device=device), + torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta)) + self.assertEqual(torch.full((2, 3), beta * value, dtype=dtype, device=device), + torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta, out=out)) + + @dtypes(*(torch.testing.get_all_complex_dtypes() + torch.testing.get_all_fp_dtypes())) + def test_blas_nan_out(self, device, dtype): + # These functions should work correctly with NaN filled outputs, + # but need special handling, see [NOTE: cpu_zero] + b = 3 + n = 5 + m = 7 + p = 11 + + # torch.mv + nm = torch.randn((m, n), device=device).t() + _m = torch.randn((), device=device).expand(m) + _m_out = torch.full((m,), float('nan'), device=device) + self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out)) + self.assertEqual(0, torch.isnan(torch.mv(nm, _m)).sum()) + + # torch.mm + mp = torch.randn((p, m), device=device).t() + np_out = torch.full((n, p), float('nan'), device=device) + self.assertEqual(torch.mm(nm, mp), torch.mm(nm, mp, out=np_out)) + + # torch.bmm + bnm = torch.randn((b, m, n), device=device).transpose(1, 2) + bmp = torch.randn((b, p, m), device=device).transpose(1, 2) + bnp_out = torch.full((b, n, p), float('nan'), device=device) + self.assertEqual(torch.bmm(bnm, bmp), torch.bmm(bnm, bmp, out=bnp_out)) + + @onlyCPU # not supported by CUBLAS + def test_blas_mv_large_input(self, device): + # This would previously fail if the allocated output had NaNs, see: + # https://github.com/pytorch/pytorch/issues/31663 and [NOTE: cpu_zero] + n = 3000 + m = 200 + + nm = torch.randn((m, n), device=device).t() + _m = torch.randn((), device=device).expand(m) + _m_out = torch.full((m,), 0., device=device) + + self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out)) + + @onlyCPU + def test_renorm_ps(self, device): + # full reduction + x = torch.randn(5, 5) + xn = x.numpy() + for p in [1, 2, 3, 4, inf]: + res = x.renorm(p, 1, 1) + expected = x / x.norm(p, 0, keepdim=True).clamp(min=1) + self.assertEqual(res, expected, msg="renorm failed for {}-norm".format(p)) + + @onlyCPU + @skipCPUIfNoLapack + def test_orgqr_errors(self, device): + test_cases = [ + # input1 size, input2 size, error regex + ((10,), (2,), r"'input' should be 2 dimensional"), + ((10, 6), (20,), r"input.size\(1\) must be greater than or equal to input2.size\(0\)"), + ((6, 10), (5,), r"input.size\(0\) must be greater than or equal to input.size\(1\)"), + ((0, 0), (0,), r"'input' should not be empty"), + ((2, 2), (2, 0,), r"'tau' should not be empty") + ] + for a_size, tau_size, error_regex in test_cases: + a = torch.rand(*a_size, device=device) + tau = torch.rand(*tau_size, device=device) + with self.assertRaisesRegex(RuntimeError, error_regex): + torch.orgqr(a, tau) + + @precisionOverride({torch.complex64: 5e-6}) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double, torch.cfloat, torch.cdouble) + def test_lu(self, device, dtype): + from torch.testing._internal.common_utils import random_matrix + + def run_test(device, pivot): + def run_subtest(matrix_size, batches, device, pivot, singular=False, a=None): + if isinstance(matrix_size, int): + rows = columns = matrix_size + else: + rows, columns = matrix_size + if a is None: + a = random_matrix(rows, columns, *batches, **dict(singular=singular, dtype=dtype)).to(device) + a_LU_info, pivots_info, info_ = a.lu(pivot=pivot, get_infos=True) + self.assertEqual(a_LU_info.size(), torch.Size(batches + (rows, columns))) + self.assertEqual(pivots_info.size(), torch.Size(batches + (min(rows, columns),))) + self.assertEqual(info_.size(), torch.Size(batches)) + # If a randomly generated input matrix is singular, + # then info_ contains indices i such that U[i, i] == + # 0. This however conveys that the factorization was + # successful albeit with a singular input. Therefore, + # we require info.min() >= 0 + self.assertGreaterEqual(info_.min(), 0) + a_LU, pivots = a.lu(pivot=pivot) + self.assertEqual(a_LU, a_LU_info) + self.assertEqual(pivots_info, pivots) + + + P, L, U = torch.lu_unpack(a_LU, pivots) + P_ = P.cpu().numpy() + L_ = L.cpu().numpy() + U_ = U.cpu().numpy() + + self.assertEqual(np.matmul(P_, np.matmul(L_, U_)), a) + + if self.device_type == 'cuda': + # lu without pivoting is implemented only for cuda device + a_LU_info_nopiv, nopiv, info_nopiv = a.lu(pivot=False, get_infos=True) + P_nopiv, L_nopiv, U_nopiv = torch.lu_unpack(a_LU_info_nopiv, nopiv) + P_nopiv_ = P_nopiv.cpu().numpy() + L_nopiv_ = L_nopiv.cpu().numpy() + U_nopiv_ = U_nopiv.cpu().numpy() + + self.assertEqual(np.matmul(P_nopiv_, np.matmul(L_nopiv_, U_nopiv_)), a) + + k = min(rows, columns) + self.assertEqual(nopiv, torch.arange(1, 1 + k, device=device, dtype=torch.int32).expand(a.shape[:-2] + (k, ))) + if not singular: + # It is not guaranteed that LU factorization + # without pivoting is able to determine if a + # matrix is singular while LU factorization + # with pivoting is. Therefore, we require the + # equality of info-s only for non-singular + # matrices. + # NOTE: infor_ is reshaped because info_nopiv might have + # squashed batch dimensions for complex types on CUDA, + # see the TODOs above. + self.assertEqual(info_.reshape(info_nopiv.shape), info_nopiv) + + for ms, batch in itertools.product([3, 5, 7, (4, 2), (3, 4)], [(), (2,), (3,), (3, 5)]): + run_subtest(ms, batch, device, pivot) + run_subtest(ms, batch, device, pivot, singular=True) + + # Reproducer of a magma bug, see https://bitbucket.org/icl/magma/issues/13/getrf_batched-kernel-produces-nans-on + a = torch.ones(batch + (ms if isinstance(ms, tuple) else (ms, ms)), dtype=torch.double, device=device) + run_subtest(ms, batch, device, pivot, singular=True, a=a) + + # Info should be positive for rank deficient matrices + a = torch.ones(5, 3, 3, device=device) + self.assertGreater(a.lu(pivot=pivot, get_infos=True)[2][0], 0) + + run_test(device, True) + + if self.device_type == 'cpu': + # Error checking, no pivoting variant on CPU + with self.assertRaisesRegex(RuntimeError, 'lu without pivoting is not implemented on the CPU'): + torch.lu(torch.empty(1, 2, 2), pivot=False) + else: + run_test(device, False) + + @skipCPUIfNoLapack + @skipCUDAIfNoMagma + @dtypes(torch.double) + def test_lu_unpack(self, device, dtype): + def run_test(pivot): + for shape in ((3, 3), (5, 3, 3), (7, 3, 5, 5), (7, 5, 3, 3, 3)): + a = torch.randn(*shape, dtype=dtype, device=device) + a_lu, p = torch.lu(a, pivot=pivot) + p_ref, l_ref, u_ref = torch.lu_unpack(a_lu, p) + self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a) + + run_test(True) + + if self.device_type == 'cuda': + run_test(False) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_lobpcg_basic(self, device, dtype): + self._test_lobpcg_method(device, dtype, 'basic') + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_lobpcg_ortho(self, device, dtype): + self._test_lobpcg_method(device, dtype, 'ortho') + + def _test_lobpcg_method(self, device, dtype, method): + from torch.testing._internal.common_utils import random_symmetric_pd_matrix, random_sparse_pd_matrix + from torch._linalg_utils import matmul, qform + from torch._lobpcg import lobpcg + + def test_tracker(worker): + k = worker.iparams['k'] + nc = worker.ivars['converged_count'] + if k <= nc: + tol = worker.fparams['tol'] + rerr = worker.tvars['rerr'] + X = worker.X + E = worker.E + B = worker.B + A = worker.A + dtype = X.dtype + device = X.device + + # Check convergence + self.assertLessEqual(rerr[:k].max(), tol) + + # Check B-orthogonality + I = torch.eye(k, k, dtype=dtype, device=device) + self.assertEqual(qform(B, X[:, :k]), I) + + # Check block equation + self.assertEqual(qform(A, X[:, :k]) / E[:k], I, atol=0.2, rtol=0) + + orig_lobpcg = lobpcg + + def lobpcg(*args, **kwargs): + kwargs['tracker'] = test_tracker + kwargs['niter'] = 1000 + kwargs['method'] = method + kwargs['tol'] = 1e-8 + return orig_lobpcg(*args, **kwargs) + prec = 5e-4 + + # check dense input + mm = torch.matmul + for batches in [(), (2,), (2, 3)]: + for m, n, k in [ + (9, 3, 1), + (9, 3, 2), + (9, 2, 2), + (100, 15, 5), + ]: + # skip tests that are known to fail with the basic + # LOBPCG method due to calling cholesky on singular + # input + if method == 'basic' and (m, n, k) in [(9, 2, 2), (100, 15, 5)]: + continue + A = random_symmetric_pd_matrix(m, *batches, device=device, dtype=dtype) + B = random_symmetric_pd_matrix(m, *batches, device=device, dtype=dtype) + + # classical eigenvalue problem, smallest eigenvalues + E, V = lobpcg(A, k=k, n=n, largest=False) + self.assertEqual(E.shape, batches + (k,)) + self.assertEqual(V.shape, batches + (m, k)) + self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0) + e = torch.symeig(A)[0] + e_smallest = e[..., :k] + self.assertEqual(E, e_smallest) + + # classical eigenvalue problem, largest eigenvalues + E, V = lobpcg(A, k=k, n=n, largest=True) + e_largest, _ = torch.sort(e[..., -k:], descending=True) + self.assertEqual(E, e_largest, atol=prec, rtol=0) + self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0) + + # generalized eigenvalue problem, smallest eigenvalues + E, V = lobpcg(A, B=B, k=k, n=n, largest=False) + self.assertEqual(matmul(A, V), mm(matmul(B, V), E.diag_embed()), atol=prec, rtol=0) + + # generalized eigenvalue problem, largest eigenvalues + E, V = lobpcg(A, B=B, k=k, n=n, largest=True) + self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()), + atol=prec, rtol=0) + + # check sparse input + for m, n, k, density in [ + (5, 1, 1, 0.8), + (9, 3, 2, 0.5), + (100, 1, 1, 0.1), + (1000, 7, 3, 0.01), + ]: + # skip tests that are known to fail with the basic LOBCG + # method due to insufficient accuracy + if method == 'basic' and (m, n, k, density) in [(1000, 7, 3, 0.01)]: + continue + A = random_sparse_pd_matrix(m, density=density, device=device, dtype=dtype) + B = random_sparse_pd_matrix(m, density=density, device=device, dtype=dtype) + A_eigenvalues = torch.arange(1, m + 1, dtype=dtype) / m + e_smallest = A_eigenvalues[..., :k] + e_largest, _ = torch.sort(A_eigenvalues[..., -k:], descending=True) + + # classical eigenvalue problem, smallest eigenvalues + E, V = lobpcg(A, k=k, n=n, largest=False) + self.assertEqual(E, e_smallest) + self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0) + + # classical eigenvalue problem, largest eigenvalues + E, V = lobpcg(A, k=k, n=n, largest=True) + self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0) + self.assertEqual(E, e_largest) + + # generalized eigenvalue problem, smallest eigenvalues + E, V = lobpcg(A, B=B, k=k, n=n, largest=False) + self.assertEqual(matmul(A, V), matmul(B, mm(V, E.diag_embed())), atol=prec, rtol=0) + + # generalized eigenvalue problem, largest eigenvalues + E, V = lobpcg(A, B=B, k=k, n=n, largest=True) + self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()), + atol=prec, rtol=0) + + @skipCPUIfNoLapack + @onlyCPU + @dtypes(torch.double) + def test_lobpcg_torchscript(self, device, dtype): + from torch.testing._internal.common_utils import random_sparse_pd_matrix + from torch._linalg_utils import matmul as mm + + lobpcg = torch.jit.script(torch.lobpcg) + + m = 500 + k = 5 + A1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype) + X1 = torch.randn((m, k), dtype=dtype, device=device) + E1, V1 = lobpcg(A1, X=X1) + eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max() + self.assertLess(eq_err, 1e-6) + + @unittest.skipIf(not TEST_SCIPY or (TEST_SCIPY and scipy.__version__ < '1.4.1'), "Scipy not found or older than 1.4.1") + @skipCPUIfNoLapack + @onlyCPU + @dtypes(torch.double) + def test_lobpcg_scipy(self, device, dtype): + """Compare torch and scipy.sparse.linalg implementations of lobpcg + """ + import time + from torch.testing._internal.common_utils import random_sparse_pd_matrix + from torch._linalg_utils import matmul as mm + from scipy.sparse.linalg import lobpcg as scipy_lobpcg + import scipy.sparse + + def toscipy(A): + if A.layout == torch.sparse_coo: + values = A.coalesce().values().cpu().numpy().copy() + indices = A.coalesce().indices().cpu().numpy().copy() + return scipy.sparse.coo_matrix((values, (indices[0], indices[1])), A.shape) + return A.cpu().numpy().copy() + + niter = 1000 + repeat = 10 + m = 500 # size of the square matrix + k = 7 # the number of requested eigenpairs + A1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype) + B1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype) + X1 = torch.randn((m, k), dtype=dtype, device=device) + + A2 = toscipy(A1) + B2 = toscipy(B1) + X2 = toscipy(X1) + + lambdas1 = [] + + def tracker(worker): + lambdas1.append(worker.E[:]) + + tol = 1e-8 + # tol for scipy lobpcg will be choosed so that the number of + # iterations will be equal or very close to pytorch lobpcg + # (that is around 170-180) + + # Standard eigenvalue problem + E1, V1 = torch.lobpcg(A1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol) + E2, V2, lambdas2 = scipy_lobpcg(A2, X2, maxiter=niter, largest=True, retLambdaHistory=True, tol=1.1 * tol) + iters1 = len(lambdas1) + iters2 = len(lambdas2) + self.assertLess(abs(iters1 - iters2), 0.05 * max(iters1, iters2)) + + E2a, V2a = scipy_lobpcg(A2, X2, maxiter=niter, largest=False) + + eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max() + eq_err_scipy = (abs(A2.dot(V2) - V2 * E2)**2).sum() ** 0.5 / E2.max() + self.assertLess(eq_err, 1e-6) # std + self.assertLess(eq_err_scipy, 1e-6) # std + + self.assertEqual(E1, torch.from_numpy(E2.copy())) + + # Generalized eigenvalue problem + lambdas1 = [] + + def tracker(worker): + lambdas1.append(worker.E[:]) + + E1, V1 = torch.lobpcg(A1, B=B1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol) + E2, V2, lambdas2 = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=True, retLambdaHistory=True, tol=39 * tol) + E2a, V2a = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=False) + iters1 = len(lambdas1) + iters2 = len(lambdas2) + self.assertLess(abs(iters1 - iters2), 0.05 * max(iters1, iters2)) + + eq_err = torch.norm((mm(A1, V1) - mm(B1, V1) * E1), 2) / E1.max() + eq_err_scipy = (abs(A2.dot(V2) - B2.dot(V2) * E2)**2).sum() ** 0.5 / E2.max() + self.assertLess(eq_err, 1e-6) # general + self.assertLess(eq_err_scipy, 1e-6) # general + + self.assertEqual(E1, torch.from_numpy(E2.copy())) + + # Timings + elapsed_ortho = 0 + elapsed_ortho_general = 0 + elapsed_scipy = 0 + elapsed_general_scipy = 0 + for i in range(repeat): + start = time.time() + torch.lobpcg(A1, X=X1, niter=niter, method='ortho', tol=tol) + end = time.time() + elapsed_ortho += end - start + + start = time.time() + torch.lobpcg(A1, X=X1, B=B1, niter=niter, method='ortho', tol=tol) + end = time.time() + elapsed_ortho_general += end - start + + start = time.time() + scipy_lobpcg(A2, X2, maxiter=niter, tol=1.1 * tol) + end = time.time() + elapsed_scipy += end - start + + start = time.time() + scipy_lobpcg(A2, X2, B=B2, maxiter=niter, tol=39 * tol) + end = time.time() + elapsed_general_scipy += end - start + + elapsed_ortho_ms = 1000.0 * elapsed_ortho / repeat + elapsed_ortho_general_ms = 1000.0 * elapsed_ortho_general / repeat + elapsed_scipy_ms = 1000.0 * elapsed_scipy / repeat + elapsed_general_scipy_ms = 1000.0 * elapsed_general_scipy / repeat + + print(''' +CPU timings: torch.lobpcg vs scipy.sparse.linalg.lobpcg +------------------------------------------------------- + | standard | generalized | method +torch.lobpcg | {:10.2f} | {:10.2f} | ortho +scipy_lobpcg | {:10.2f} | {:10.2f} | N/A +-(input size: {:4}, eigenpairs:{:2}, units: ms per call)- + '''.format(elapsed_ortho_ms, elapsed_ortho_general_ms, + elapsed_scipy_ms, elapsed_general_scipy_ms, + m, k)) + + # Handling of very small tolerence + tol = 1e-100 + + lambdas1 = [] + + def tracker(worker): + lambdas1.append(worker.E[:]) + + E1, V1 = torch.lobpcg(A1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol) + iters1 = len(lambdas1) + eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max() + + try: + E2, V2, lambdas2 = scipy_lobpcg(A2, X2, maxiter=niter, largest=True, retLambdaHistory=True, tol=tol) + iters2 = len(lambdas2) + eq_err_scipy = (abs(A2.dot(V2) - V2 * E2)**2).sum() ** 0.5 / E2.max() + except Exception as msg: + print('Calling scipy_lobpcg failed [standard]:', msg) + iters2 = -1 + eq_err_scipy = -1 + + lambdas1 = [] + + def tracker(worker): + lambdas1.append(worker.E[:]) + + E1, V1 = torch.lobpcg(A1, X=X1, B=B1, niter=niter, largest=True, tracker=tracker, tol=tol) + iters1_general = len(lambdas1) + eq_err_general = torch.norm((mm(A1, V1) - mm(B1, V1) * E1), 2) / E1.max() + + try: + E2, V2, lambdas2 = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=True, retLambdaHistory=True, tol=tol) + iters2_general = len(lambdas2) + eq_err_general_scipy = (abs(A2.dot(V2) - B2.dot(V2) * E2)**2).sum() ** 0.5 / E2.max() + except Exception as msg: + print('Calling scipy_lobpcg failed [generalized]:', msg) + iters2_general = -1 + eq_err_general_scipy = -1 + + print('''\ +Handling of small tol={:6.0e}: torch.lobpcg vs scipy.sparse.linalg.lobpcg +---------------------------------------------------------------------------- + | standard | generalized | niter | method +torch.lobpcg | {:10.2e} | {:10.2e} | {:6} | ortho +scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A +---(input size: {:4}, eigenpairs:{:2}, units: relative error, maxiter={:4})--- +'''.format(tol, eq_err, eq_err_general, iters1, eq_err_scipy, eq_err_general_scipy, iters2, m, k, niter)) + + def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False): + dtype = t.dtype + numpy_dtype = dtype + if dtype in {torch.bfloat16}: + numpy_dtype = torch.float + if dtype.is_complex: + alpha = 0.9 + 0.3j if alpha is None else alpha + beta = 0.5 + 0.6j if beta is None else beta + else: + alpha = 1.2 if alpha is None else alpha + beta = 0.8 if beta is None else beta + res1 = f(t, m, v, alpha=alpha, beta=beta) + res2 = torch.full_like(res1, math.nan) + if transpose_out: + res2 = res2.t().clone(memory_format=torch.contiguous_format).t() + f(t, m, v, alpha=alpha, beta=beta, out=res2) + res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy()) + if beta != 0: + res3 += (beta * t).to(numpy_dtype).cpu().numpy() + res3 = torch.from_numpy(res3).to(dtype) + self.assertEqual(res1, res2) + self.assertEqual(res1, res3) + + @precisionOverride({torch.bfloat16: 1e-0, torch.half: 5e-4, torch.float: 1e-4, torch.double: 1e-8, + torch.cfloat: 1e-4, torch.cdouble: 1e-8}) + @dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(), + *([torch.float32, torch.float64, torch.bfloat16] + if TEST_WITH_ROCM else torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))) + @dtypes(torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_addmv(self, device, dtype): + # have to use torch.randn(...).to(bfloat16) instead of + # torch.randn(..., dtype=bfloat16). randn does not support + # bfloat16 yet. + ts = [ + torch.randn(10, device=device).to(dtype), + torch.randn(1, device=device).to(dtype).expand(10), + ] + vs = [ + torch.randn(100, device=device).to(dtype), + torch.ones(1, device=device).to(dtype).expand(100), # to reduce errors for low precision + ] + ms = [ + # 0d + torch.ones((), device=device).to(dtype).expand(10, 100), # to reduce errors for low precision + # 1d + torch.randn((1, 100), device=device).to(dtype).expand(10, 100), + # this initialization reduces errors for low precision for broadcasted matrices + # by making sure that intermediate and result values are exactly representable + # in low precision type + torch.randint(3, (10, 1), dtype=torch.float, device=device).to(dtype).expand(10, 100), + # 2d + torch.randn((10, 100), device=device).to(dtype), + torch.randn((100, 10), device=device).to(dtype).t(), + ] + for m, v, t in itertools.product(ms, vs, ts): + self._test_addmm_addmv(torch.addmv, t, m, v) + # Test beta=0, t=nan + t = torch.full((10,), math.nan, device=device).to(dtype) + for m, v in itertools.product(ms, vs): + self._test_addmm_addmv(torch.addmv, t, m, v, beta=0) + + @dtypesIfCUDA(*([torch.half, torch.float, torch.double] + + ([torch.bfloat16] if TEST_WITH_ROCM else []))) + @dtypes(torch.float, torch.double) + def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype): + # tests (o, s)*(s). o is output size, s is summed size. + o = 5 + s = 3 + a_data = torch.arange(1, o * s + 1, device=device, dtype=dtype).view(o, s) + x_data = torch.arange(1, s + 1, 1, device=device, dtype=dtype) + y_data = torch.ones(o, device=device, dtype=dtype) + control = torch.tensor([15., 33., 51., 69., 87.], device=device, dtype=dtype) + + def _test(row_major, incx, incy, lda_tail): + if row_major: + a_storage = torch.full((o, s + lda_tail), float('nan'), device=device, dtype=dtype) + else: + a_storage = torch.full((s, o + lda_tail), float('nan'), device=device, dtype=dtype).permute(1, 0) + a = a_storage[:o, :s].copy_(a_data) + + x_storage = torch.full((s, incx), float('nan'), device=device, dtype=dtype) + x = x_storage[:, 0].copy_(x_data) + + y_storage = torch.full((o, incy), float('nan'), device=device, dtype=dtype) + y = y_storage[:, 0].copy_(y_data) + + self._test_addmm_addmv(torch.addmv, y, a, x) + + for row_major, incx, incy, lda_tail in itertools.product((False, True), (1, 2), (1, 2), (0, 1)): + _test(row_major, incx, incy, lda_tail) + + @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6, + torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) + @dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)) + @dtypes(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes()) + @tf32_on_and_off(0.05) + def test_addmm(self, device, dtype): + M = torch.randn(10, 25, device=device).to(dtype) + m1 = torch.randn(10, 50, device=device).to(dtype) + m2 = torch.randn(50, 25, device=device).to(dtype) + self._test_addmm_addmv(torch.addmm, M, m1, m2) + + # Test 0-strided + M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25) + m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50) + m2 = torch.randn(50, 25, device=device).to(dtype) + self._test_addmm_addmv(torch.addmm, M, m1, m2) + + # Test beta=0, M=nan + M = torch.full((10, 25), math.nan, device=device).to(dtype) + m1 = torch.randn(10, 50, device=device).to(dtype) + m2 = torch.randn(50, 25, device=device).to(dtype) + self._test_addmm_addmv(torch.addmm, M, m1, m2, beta=0) + + # Test transpose + for t1, t2, t3, t4 in itertools.product([True, False], repeat=4): + def maybe_transpose(cond, m): + if not cond: + return m + return m.t().clone(memory_format=torch.contiguous_format).t() + + M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype)) + m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype)) + m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype)) + self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4) + + @dtypes(torch.float, torch.double) + @dtypesIfCUDA(*([torch.float, torch.double] + + ([] if TEST_WITH_ROCM else torch.testing.get_all_complex_dtypes()))) + @tf32_on_and_off(0.005) + def test_addmm_sizes(self, device, dtype): + for m in [0, 1, 25]: + for n in [0, 1, 10]: + for k in [0, 1, 8]: + M = torch.randn(n, m, device=device).to(dtype) + m1 = torch.randn(n, k, device=device).to(dtype) + m2 = torch.randn(k, m, device=device).to(dtype) + self._test_addmm_addmv(torch.addmm, M, m1, m2) + + @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") + @onlyCUDA + def test_matmul_45724(self, device): + # https://github.com/pytorch/pytorch/issues/45724 + a = torch.rand(65537, 22, 64, device=device, dtype=torch.half) + b = torch.rand(65537, 64, 22, device=device, dtype=torch.half) + c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device=device) + cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).cuda().half() + torch.matmul(a, b, out=c) + self.assertEqual(c, cpu_result) + + @slowTest + @onlyOnCPUAndCUDA + @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.int32, torch.int64, torch.cfloat, torch.cdouble) + @dtypesIfCUDA(torch.float32, torch.float64, torch.cfloat, torch.cdouble) + @tf32_on_and_off(0.01) + def test_mm(self, device, dtype): + def _test_mm(n, m, p, dtype, genf): + # helper function + def matrixmultiply(mat1, mat2): + n = mat1.size(0) + m = mat1.size(1) + p = mat2.size(1) + res = torch.zeros(n, p, dtype=dtype, device=device) + for i, j in iter_indices(res): + res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m)) + return res + + # contiguous case + mat1 = genf(n, m) + mat2 = genf(m, p) + res = torch.mm(mat1, mat2) + + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) + + # non contiguous case 1 + mat1 = genf(n, m) + mat2 = genf(p, m).t() + res = torch.mm(mat1, mat2) + + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) + + # non contiguous case 2 + mat1 = genf(m, n).t() + mat2 = genf(m, p) + res = torch.mm(mat1, mat2) + + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) + + # non contiguous case 3 + mat1 = genf(m, n).t() + mat2 = genf(p, m).t() + res = torch.mm(mat1, mat2) + + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) + + # test with zero stride + mat1 = genf(n, m) + mat2 = genf(m, 1).expand(m, p) + res = torch.mm(mat1, mat2) + + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) + + # explicitly exercise the _out variant in torch.mm(). + # contiguous case + mat1 = genf(n, m) + mat2 = genf(m, p) + res = genf(n, p) + torch.mm(mat1, mat2, out=res) + + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) + + # explicitly exercise the _out variant in torch.mm(). + # non contiguous case 3 + mat1 = genf(m, n).t() + mat2 = genf(p, m).t() + res = genf(n, p) + torch.mm(mat1, mat2, out=res) + + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) + + def genf_int(x, y): + return torch.randint(0, 100, (x, y), dtype=dtype, device=device) + + def genf_bfloat(x, y): + return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) + + def genf_float(x, y): + return torch.randn(x, y, dtype=dtype, device=device) + + for (n, m, p) in [(20, 10, 5), (15, 5, 10), (5, 18, 10)]: + if (dtype == torch.int32) or (dtype == torch.int64): + genf = genf_int + elif (dtype == torch.bfloat16): + genf = genf_bfloat + else: + genf = genf_float + + _test_mm(n, m, p, dtype, genf) + + @onlyOnCPUAndCUDA + @dtypes(torch.float32, torch.float64) + def test_strided_mm_bmm(self, device, dtype): + # Tests strided view case with stride smaller than corresponding dimension size + x = torch.tensor([[1., 2., 3.], [4., 5., 6.]], dtype=dtype, device=device) + new_shape = [2, 2, 2] + new_stride = [3, 1, 1] + sx = torch.as_strided(x, size=new_shape, stride=new_stride) + + torch_fn = lambda x: torch.bmm(x, x) # noqa: E731 + np_fn = lambda x: np.matmul(x, x) # noqa: E731 + self.compare_with_numpy(torch_fn, np_fn, sx) + + torch_fn = lambda x: torch.mm(x, x) # noqa: E731 + self.compare_with_numpy(torch_fn, np_fn, sx[0]) + + @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05}) + @skipCUDAIf(torch.version.cuda == "10.1", "flaky on CUDA 10.1") + @onlyOnCPUAndCUDA + @dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes()) + @tf32_on_and_off(0.05) + def test_bmm(self, device, dtype): + num_batches = 10 + M, N, O = 23, 8, 12 + numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32 + + if self.device_type == 'cpu': + is_supported = True + elif self.device_type == 'cuda': + is_supported = True if dtype != torch.bfloat16 else AMPERE_OR_ROCM + + if not is_supported: + b1 = torch.randn(num_batches, M, N, device=device).to(dtype) + b2 = torch.randn(num_batches, N, O, device=device).to(dtype) + self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.bmm(b1, b2)) + return + + def invert_perm(p): + d = {x: i for i, x in enumerate(p)} + return (d[0], d[1], d[2]) + + def generate_inputs(): + # transposed tensors + for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2): + b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1) + b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1) + b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) + b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) + yield b1, b2 + # broadcasting tensors + for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6): + shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1) + shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1) + b1 = make_tensor(shape1, device, dtype, low=-1, high=1).expand(num_batches, M, N) + b2 = make_tensor(shape2, device, dtype, low=-1, high=1).expand(num_batches, N, O) + yield b1, b2 + # zero-sized tensors + for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): + shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) + shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) + b1 = torch.randn(shape1, dtype=dtype, device=device) + b2 = torch.randn(shape2, dtype=dtype, device=device) + yield b1, b2 + + for (b1, b2), perm3 in itertools.product(generate_inputs(), itertools.permutations((0, 1, 2))): + res1 = torch.bmm(b1, b2) + res2 = torch.full((num_batches, M, O), math.nan, dtype=dtype, device=device) \ + .permute(perm3).contiguous().permute(invert_perm(perm3)) + torch.bmm(b1, b2, out=res2) + expect = torch.from_numpy( + b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype) + self.assertEqual(expect, res1) + self.assertEqual(expect, res2) + + if self.device_type == 'cuda': + # check that mixed arguments are rejected + self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu())) + self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2)) + self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu())) + + @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") + @onlyCUDA + @wrapDeterministicFlagAPITest + def test_cublas_config_deterministic_error(self, device): + test_cases = [ + # (function, (tensor sizes)) + ('mm', ((2, 2), (2, 2),)), + ('mv', ((2, 2), (2,),)), + ('bmm', ((1, 2, 2), (1, 2, 2),))] + + test_configs = [ + # (CuBLAS workspace config, is deterministic) + ('garbage', False), + (None, False), + (':4096:8', True), + (':16:8', True)] + + cublas_var_name = 'CUBLAS_WORKSPACE_CONFIG' + is_cuda10_2_or_higher = ( + (torch.version.cuda is not None) + and ([int(x) for x in torch.version.cuda.split(".")] >= [10, 2])) + + def test_case_info(fn_name, config): + return f'function "{fn_name}" with config "{"" if config is None else config}"' + + # Create processes to test each combination of test cases and config settings + processes = [] + for fn_name, arg_sizes in test_cases: + for config, is_config_deterministic in test_configs: + env = os.environ.copy() + if config is None: + if env.get(cublas_var_name) is not None: + del env[cublas_var_name] + else: + env[cublas_var_name] = config + should_throw_error = is_cuda10_2_or_higher and not is_config_deterministic + script = f""" +import torch +torch.set_deterministic(True) +fn = torch.{fn_name} +arg_sizes = {arg_sizes} +device = '{device}' +should_throw_error = {should_throw_error} +args = [] +for arg_size in arg_sizes: + args.append(torch.randn(*arg_size, device=device)) +try: + fn(*args) +except RuntimeError as e: + if not should_throw_error: + raise RuntimeError('Did not expect any error to be raised') + elif 'Deterministic behavior was enabled with either' not in str(e): + raise RuntimeError('Expected a CuBLAS nondeterministic error, but got a different error') +else: + if should_throw_error: + raise RuntimeError('Expected a CuBLAS nondeterministic error, but it was not raised') + +""" + try: + subprocess.check_output( + [sys.executable, '-c', script], + stderr=subprocess.STDOUT, + # On Windows, opening the subprocess with the default CWD makes `import torch` + # fail, so just set CWD to this script's directory + cwd=os.path.dirname(os.path.realpath(__file__)), + env=env) + except subprocess.CalledProcessError as e: + self.fail(msg=( + f'Subprocess exception while attempting to run {test_case_info(fn_name, config)}:\n' + + e.output.decode("utf-8"))) + + def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor): + getattr(out_tensor, func + "_")(b1, b2) + self.assertEqual(out_tensor, ref) + res3 = out_tensor.clone() + + with self.maybeWarnsRegex( + UserWarning, f"This overload of {func}_ is deprecated"): + getattr(out_tensor, func + "_")(1, b1, b2) + self.assertEqual(out_tensor, ref * 2), + getattr(res3, func + "_")(b1, b2, beta=1) + self.assertEqual(out_tensor, res3) + + with self.maybeWarnsRegex( + UserWarning, f"This overload of {func}_ is deprecated"): + getattr(out_tensor, func + "_")(1., .5, b1, b2) + self.assertEqual(out_tensor, ref * 2.5) + getattr(res3, func + "_")(b1, b2, beta=1., alpha=.5) + self.assertEqual(out_tensor, res3) + + with self.maybeWarnsRegex( + UserWarning, f"This overload of {func} is deprecated"): + self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2)) + + res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=.5) + self.assertEqual(res4, ref * 3), + + nan = torch.full_like(out_tensor, math.nan) + res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1) + self.assertEqual(res5, ref) + + if b1.is_complex(): + res6 = getattr(torch, func)(out_tensor, b1, b2, beta=.1j, alpha=.5j) + self.assertEqual(res6, out_tensor * .1j + .5j * ref) + else: + res6 = getattr(torch, func)(out_tensor, b1, b2, beta=.1, alpha=.5) + self.assertEqual(res6, out_tensor * .1 + .5 * ref) + + res7 = torch.full_like(out_tensor, math.nan) + getattr(torch, func)(nan, b1, b2, beta=0, out=res7) + self.assertEqual(res7, ref) + + @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05}) + @onlyOnCPUAndCUDA + @dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes()) + @tf32_on_and_off(0.05) + def test_addbmm(self, device, dtype): + num_batches = 2 + M, N, O = 2, 3, 4 + + if self.device_type == 'cpu': + is_supported = True + if dtype == torch.bfloat16: + self.precision = 1 # 43 vs 43.75 + else: + is_supported = (dtype != torch.bfloat16 or AMPERE_OR_ROCM) + + if not is_supported: + b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1) + b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1) + t = make_tensor((M, O), device, dtype, low=-1, high=1) + self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.addbmm(t, b1, b2)) + return + + def invert_perm(p): + d = {x: i for i, x in enumerate(p)} + return (d[0], d[1], d[2]) + + def generate_tensor(): + numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32 + # transposed tensors + for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2): + for perm3 in itertools.permutations((0, 1)): + b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1) + b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1) + b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) + b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) + ref = torch.from_numpy( + b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() + ).to(device=device, dtype=dtype).sum(0) + out_tensor = torch.zeros_like(ref).permute(perm3).contiguous().permute(perm3) + yield b1, b2, ref, out_tensor + # broadcasting tensors + for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6): + shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1) + shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1) + b1 = make_tensor(shape1, device, dtype, low=-1, high=1).expand(num_batches, M, N) + b2 = make_tensor(shape2, device, dtype, low=-1, high=1).expand(num_batches, N, O) + ref = torch.from_numpy( + b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() + ).to(device=device, dtype=dtype).sum(0) + out_tensor = torch.zeros_like(ref) + yield b1, b2, ref, out_tensor + # zero-sized tensors + for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): + shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) + shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) + b1 = make_tensor(shape1, device, dtype, low=-1, high=1) + b2 = make_tensor(shape2, device, dtype, low=-1, high=1) + ref = torch.from_numpy( + b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() + ).to(device=device, dtype=dtype).sum(0) + out_tensor = torch.zeros_like(ref) + yield b1, b2, ref, out_tensor + + for b1, b2, ref, out_tensor in generate_tensor(): + self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor) + + @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5}) + @onlyOnCPUAndCUDA + @dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes()) + @tf32_on_and_off(0.05) + def test_baddbmm(self, device, dtype): + num_batches = 10 + M, N, O = 12, 8, 5 + + if self.device_type == 'cpu': + is_supported = True + elif self.device_type == 'cuda': + is_supported = True if dtype != torch.bfloat16 else AMPERE_OR_ROCM + + if not is_supported: + b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1) + b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1) + t = make_tensor((num_batches, M, O), device, dtype, low=-1, high=1) + self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.baddbmm(t, b1, b2)) + return + + def invert_perm(p): + d = {x: i for i, x in enumerate(p)} + return (d[0], d[1], d[2]) + + def generate_tensor(): + numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32 + # transposed tensors + for perm1, perm2, perm3 in itertools.product(itertools.permutations((0, 1, 2)), repeat=3): + b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1) + b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1) + b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) + b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) + ref = torch.from_numpy( + b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype) + out_tensor = torch.zeros_like(ref) + out_tensor = out_tensor.permute(perm3).contiguous().permute(invert_perm(perm3)) + yield b1, b2, ref, out_tensor + # broadcasting tensors + for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6): + shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1) + shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1) + b1 = make_tensor(shape1, device, dtype, low=-1, high=1).expand(num_batches, M, N) + b2 = make_tensor(shape2, device, dtype, low=-1, high=1).expand(num_batches, N, O) + ref = torch.from_numpy( + b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype) + out_tensor = torch.zeros_like(ref) + yield b1, b2, ref, out_tensor + # zero-sized tensors + for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): + shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) + shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) + b1 = make_tensor(shape1, device, dtype, low=-2, high=2) + b2 = make_tensor(shape2, device, dtype, low=-2, high=2) + ref = torch.from_numpy( + b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype) + out_tensor = torch.zeros_like(ref) + yield b1, b2, ref, out_tensor + + for b1, b2, ref, out_tensor in generate_tensor(): + self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor) + + # TODO: update to compare against NumPy + @onlyCUDA + def test_solve_methods_arg_device(self, device): + for b_device, A_device in itertools.product(['cpu', device], repeat=2): + if b_device == A_device: + continue + + b = torch.randn(3, 1, device=b_device) + A = torch.randn(3, 3, device=A_device) + err_str = "Expected b and A to be on the same device" + with self.assertRaisesRegex(RuntimeError, err_str): + torch.solve(b, A) + + with self.assertRaisesRegex(RuntimeError, err_str): + torch.cholesky_solve(b, A) + + with self.assertRaisesRegex(RuntimeError, err_str): + torch.triangular_solve(b, A) + + # b and A have to be modified to match accepted inputs sizes for lu_solve + b = b.unsqueeze(0) + A = A.unsqueeze(0) + with self.assertRaisesRegex(RuntimeError, err_str): + torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=A_device).int()) + + # This checks if a suitable error message is thrown + # when LU output and pivots are on the same device + with self.assertRaisesRegex(RuntimeError, + "Expected LU_pivots and LU_data to be on the same device"): + torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=b_device).int()) + + @precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3}) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_pinverse(self, device, dtype): + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value as fullrank + + def run_test(M): + # Testing against definition for pseudo-inverses + MPI = torch.pinverse(M) + MPI_ = MPI.cpu().numpy() + M_ = M.cpu().numpy() + if M.numel() > 0: + self.assertEqual(M_, np.matmul(np.matmul(M_, MPI_), M_)) + self.assertEqual(MPI_, np.matmul(np.matmul(MPI_, M_), MPI_)) + self.assertEqual(np.matmul(M_, MPI_), np.matmul(M_, MPI_).swapaxes(-2, -1).conj()) + self.assertEqual(np.matmul(MPI_, M_), np.matmul(MPI_, M_).swapaxes(-2, -1).conj()) + else: + self.assertEqual(M.shape, MPI.shape[:-2] + (MPI.shape[-1], MPI.shape[-2])) + for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5), # square matrices + (3, 2), (5, 3, 2), (7, 5, 3, 2), # fat matrices + (2, 3), (5, 2, 3), (7, 5, 2, 3), # thin matrices + (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]: # zero numel matrices + M = torch.randn(*sizes, dtype=dtype, device=device) + run_test(M) + + # Test inverse and pseudo-inverse for invertible matrix + for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5)]: + matsize = sizes[-1] + batchdims = sizes[:-2] + M = fullrank(matsize, *batchdims, dtype=dtype, device=device) + self.assertEqual(torch.eye(matsize, dtype=dtype, device=device).expand(sizes), M.pinverse().matmul(M), + atol=1e-7, rtol=0, msg='pseudo-inverse for invertible matrix') + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_matrix_power(self, device, dtype): + def run_test(M, sign=1): + if sign == -1: + M = M.inverse() + MP2 = torch.matrix_power(M, 2) + self.assertEqual(MP2, torch.matmul(M, M)) + + MP3 = torch.matrix_power(M, 3) + self.assertEqual(MP3, torch.matmul(MP2, M)) + + MP4 = torch.matrix_power(M, 4) + self.assertEqual(MP4, torch.matmul(MP2, MP2)) + + MP6 = torch.matrix_power(M, 6) + self.assertEqual(MP6, torch.matmul(MP3, MP3)) + + MP0 = torch.matrix_power(M, 0) + self.assertEqual(MP0, torch.eye(M.size(-2), dtype=dtype).expand_as(M)) + + # Single matrix + M = torch.randn(5, 5, dtype=dtype, device=device) + run_test(M) + + # Batch matrices + M = torch.randn(3, 3, 3, dtype=dtype, device=device) + run_test(M) + + # Many batch matrices + M = torch.randn(2, 3, 3, 3, dtype=dtype, device=device) + run_test(M) + + # This is for negative powers + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + M = random_fullrank_matrix_distinct_singular_value(5, dtype=dtype, device=device) + run_test(M, sign=-1) + + M = random_fullrank_matrix_distinct_singular_value(3, 3, dtype=dtype, device=device) + run_test(M, sign=-1) + + M = random_fullrank_matrix_distinct_singular_value(3, 2, 3, dtype=dtype, device=device) + run_test(M, sign=-1) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.complex64) + def test_matrix_exp_utils(self, device, dtype): + # test linear combination + def run_test(coeff_shape, data_shape): + coeffs = torch.rand(*coeff_shape, device=device, dtype=torch.float) + x = torch.rand(coeff_shape[1], *data_shape, device=device, dtype=dtype) + + res1 = torch._compute_linear_combination(x, coeffs) + res2 = (x.unsqueeze(0) * coeffs.view(*coeff_shape, *([1] * len(data_shape)))).sum(1) + self.assertEqual(res1, res2, atol=1e-5, rtol=0.0) + + # check `out=` version + res3 = torch.zeros(coeff_shape[0], *data_shape, device=device, dtype=dtype) + torch._compute_linear_combination(x, coeffs, out=res3) + self.assertEqual(res1, res3, atol=1e-5, rtol=0.0) + + res4 = torch.ones(coeff_shape[0], *data_shape, device=device, dtype=dtype) + torch._compute_linear_combination(x, coeffs, out=res4) + self.assertEqual(res1, res4 - 1.0, atol=1e-5, rtol=0.0) + + res5 = torch.ones(coeff_shape[0], *data_shape, device=device, dtype=dtype) + res5_clone = res5.clone() + torch._compute_linear_combination(x, coeffs, out=res5) + self.assertEqual(res1, res5 - res5_clone, atol=1e-5, rtol=0.0) + + run_test([1, 3], [2, 2]) + run_test([3, 1], [2, 2]) + run_test([1, 10], [10, 10]) + run_test([10, 1], [10, 10]) + run_test([5, 3], [2, 2]) + run_test([5, 3], [100, 100]) + run_test([3, 4], [3, 3, 3]) + run_test([3, 4], [3, 3, 3, 3]) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.complex64, torch.complex128) + def test_matrix_exp_boundary_cases(self, device, dtype): + + with self.assertRaisesRegex(RuntimeError, "expected a tensor of floating or complex types"): + torch.randn(3, 3).type(torch.int).matrix_exp() + + with self.assertRaisesRegex(RuntimeError, "with dim at least 2"): + torch.randn(3).matrix_exp() + + with self.assertRaisesRegex(RuntimeError, "expected a tensor of squared matrices"): + torch.randn(3, 2, 1).matrix_exp() + + # check 1x1 matrices + x = torch.randn(3, 3, 1, 1) + mexp = x.matrix_exp() + self.assertEqual(mexp, x.exp()) + + @slowTest + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_matrix_exp_analytic(self, device, dtype): + # check zero matrix + x = torch.zeros(20, 20, dtype=dtype, device=device) + self.assertTrue((x.matrix_exp() == torch.eye(20, 20, dtype=dtype, device=device)).all().item()) + + def normalize_to_1_operator_norm(sample, desired_norm): + sample_norm, _ = sample.abs().sum(-2).max(-1) + sample_to_1_norm = sample / sample_norm.unsqueeze(-1).unsqueeze(-1) + return sample_to_1_norm * desired_norm + + def gen_good_cond_number_matrices(*n): + """ + Generates a diagonally-domimant matrix + with the eigenvalues centered at 1 + and the radii at most (n[-1] - 1) / (n[-2] ** 2) + """ + identity = torch.eye(n[-2], n[-1], dtype=dtype, device=device).expand(*n) + x = torch.rand(*n, dtype=dtype, device=device) / (n[-1] ** 2) + x = (x - x * identity) + identity + return x + + def run_test(*n): + if dtype == torch.float: + thetas = [ + 1.192092800768788e-07, # deg 1 + 5.978858893805233e-04, # deg 2 + 5.116619363445086e-02, # deg 4 + 5.800524627688768e-01, # deg 8 + 1.461661507209034e+00, # deg 12 + 3.010066362817634e+00 # deg 18 + ] + else: # if torch.double + thetas = [ + 2.220446049250313e-16, # deg 1 + 2.580956802971767e-08, # deg 2 + 3.397168839976962e-04, # deg 4 + 4.991228871115323e-02, # deg 8 + 2.996158913811580e-01, # deg 12 + 1.090863719290036e+00 # deg 18 + ] + + # generate input + q = gen_good_cond_number_matrices(*n) + q_ = q.cpu().numpy() + qinv = torch.inverse(q) + qinv_ = qinv.cpu().numpy() + d = torch.randn(n[:-1], dtype=dtype, device=device) + x = torch.from_numpy( + np.matmul(q_, np.matmul(torch.diag_embed(d).cpu().numpy(), qinv_))).to(device) + x_norm, _ = x.abs().sum(-2).max(-1) + + # test simple analytic whatever norm generated + mexp = x.matrix_exp() + mexp_analytic = np.matmul( + q_, + np.matmul( + torch.diag_embed(d.exp()).cpu().numpy(), + qinv_ + ) + ) + self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0) + + # generate norms to test different degree expansions + sample_norms = [] + for i in range(len(thetas) - 1): + sample_norms.append(0.5 * (thetas[i] + thetas[i + 1])) + sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2] + + # matrices to equal norm + for sample_norm in sample_norms: + x_normalized = normalize_to_1_operator_norm(x, sample_norm) + + mexp = x_normalized.matrix_exp() + mexp_analytic = np.matmul( + q_, + np.matmul( + torch.diag_embed((d / x_norm.unsqueeze(-1) * sample_norm).exp()).cpu().numpy(), + qinv_ + ) + ) + self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0) + + # single matrix + run_test(2, 2) + run_test(3, 3) + run_test(4, 4) + run_test(5, 5) + run_test(100, 100) + run_test(200, 200) + + # small batch of matrices + run_test(3, 2, 2) + run_test(3, 3, 3) + run_test(3, 4, 4) + run_test(3, 5, 5) + run_test(3, 100, 100) + run_test(3, 200, 200) + + # large batch of matrices + run_test(3, 3, 2, 2) + run_test(3, 3, 3, 3) + run_test(3, 3, 4, 4) + run_test(3, 3, 5, 5) + run_test(3, 3, 100, 100) + run_test(3, 3, 200, 200) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double) + def test_matrix_exp_batch(self, device, dtype): + + def run_test(*n): + tensors_batch = torch.zeros(n, dtype=dtype, device=device) + tensors_batch = tensors_batch.view(-1, n[-2], n[-1]) + + num_matrices = tensors_batch.size(0) + tensors_list = [] + for i in range(num_matrices): + tensors_list.append(torch.randn(n[-2], n[-1], dtype=dtype, device=device)) + + for i in range(num_matrices): + tensors_batch[i, ...] = tensors_list[i] + + tensors_exp_map = (x.matrix_exp() for x in tensors_list) + tensors_exp_batch = tensors_batch.matrix_exp() + + for i, tensor_exp in enumerate(tensors_exp_map): + self.assertEqual(tensors_exp_batch[i, ...], tensor_exp) + + # small batch of matrices + run_test(3, 2, 2) + run_test(3, 3, 3) + run_test(3, 4, 4) + run_test(3, 5, 5) + + # large batch of matrices + run_test(3, 3, 2, 2) + run_test(3, 3, 3, 3) + run_test(3, 3, 4, 4) + run_test(3, 3, 5, 5) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_matrix_exp_compare_with_taylor(self, device, dtype): + + def normalize_to_1_operator_norm(sample, desired_norm): + sample_norm, _ = sample.abs().sum(-2).max(-1) + sample_to_1_norm = sample / sample_norm.unsqueeze(-1).unsqueeze(-1) + return sample_to_1_norm * desired_norm + + def gen_good_cond_number_matrices(*n): + """ + Generates a diagonally-domimant matrix + with the eigenvalues centered at 1 + and the radii at most (n[-1] - 1) / (n[-2] ** 2) + """ + identity = torch.eye(n[-2], n[-1], dtype=dtype, device=device).expand(*n) + x = torch.rand(*n, dtype=dtype, device=device) / (n[-1] ** 2) + x = (x - x * identity) + identity + return x + + def get_taylor_approximation(a, deg): + a_ = a.cpu().numpy() + identity = torch.eye(a.size(-2), a.size(-1), dtype=dtype, device=device).expand_as(a) + res = identity.cpu().numpy() + taylor_term = identity.cpu().numpy() + + for i in range(1, deg + 1): + taylor_term = np.matmul(a_, taylor_term) / i + res = res + taylor_term + + return res + + def scale_square(a, deg): + if a.abs().pow(2).sum().sqrt() < 1.0: + return get_taylor_approximation(a, 12) + else: + s = int(torch.log2(a.abs().pow(2).sum().sqrt()).ceil().item()) + b = a / (2 ** s) + b = get_taylor_approximation(b, 18) + for _ in range(s): + b = np.matmul(b, b) + return torch.from_numpy(b).to(a.device) + + def run_test(*n): + degs = [1, 2, 4, 8, 12, 18] + if dtype == torch.float: + thetas = [ + 1.192092800768788e-07, # deg 1 + 5.978858893805233e-04, # deg 2 + 5.116619363445086e-02, # deg 4 + 5.800524627688768e-01, # deg 8 + 1.461661507209034e+00, # deg 12 + 3.010066362817634e+00 # deg 18 + ] + else: # if torch.double + thetas = [ + 2.220446049250313e-16, # deg 1 + 2.580956802971767e-08, # deg 2 + 3.397168839976962e-04, # deg 4 + 4.991228871115323e-02, # deg 8 + 2.996158913811580e-01, # deg 12 + 1.090863719290036e+00 # deg 18 + ] + + # generate norms to test different degree expansions + sample_norms = [] + for i in range(len(thetas) - 1): + sample_norms.append(0.5 * (thetas[i] + thetas[i + 1])) + sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2] + degs = [degs[0]] + degs + + for sample_norm, deg in zip(sample_norms, degs): + x = gen_good_cond_number_matrices(*n) + x = normalize_to_1_operator_norm(x, sample_norm) + + mexp = x.matrix_exp() + mexp_taylor = scale_square(x, deg) + + self.assertEqual(mexp, mexp_taylor, atol=1e-2, rtol=0.0) + + # single matrix + run_test(2, 2) + run_test(3, 3) + run_test(4, 4) + run_test(5, 5) + + # small batch of matrices + run_test(3, 2, 2) + run_test(3, 3, 3) + run_test(3, 4, 4) + run_test(3, 5, 5) + + # large batch of matrices + run_test(3, 3, 2, 2) + run_test(3, 3, 3, 3) + run_test(3, 3, 4, 4) + run_test(3, 3, 5, 5) + + @dtypes(torch.double) + def test_chain_matmul(self, device, dtype): + def product(matrices): + for mat in matrices[1:]: + matrices[0] = matrices[0].mm(mat) + return matrices[0] + + def run_test(p): + matrices = [] + for (pi, pi_1) in zip(p[:-1], p[1:]): + matrices.append(torch.randn(pi, pi_1, dtype=dtype, device=device)) + self.assertEqual(torch.chain_matmul(*matrices), product(matrices)) + + run_test([10, 20, 30, 5]) + run_test([15, 5, 10, 20, 25]) + + with self.assertRaisesRegex(RuntimeError, "chain_matmul: Expected one or more matrices"): + torch.chain_matmul() + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, + torch.float64: 1e-8, torch.complex128: 1e-8}) + def test_slogdet(self, device, dtype): + from torch.testing._internal.common_utils import (random_hermitian_matrix, random_hermitian_psd_matrix, + random_hermitian_pd_matrix, random_square_matrix_of_rank) + + # mat_chars denotes matrix characteristics + # possible values are: hermitian, hermitian_psd, hermitian_pd, singular, non_singular + def run_test(matsize, batchdims, mat_chars): + num_matrices = np.prod(batchdims) + list_of_matrices = [] + if num_matrices != 0: + for idx in range(num_matrices): + mat_type = idx % len(mat_chars) + if mat_chars[mat_type] == 'hermitian': + list_of_matrices.append(random_hermitian_matrix(matsize, dtype=dtype, device=device)) + elif mat_chars[mat_type] == 'hermitian_psd': + list_of_matrices.append(random_hermitian_psd_matrix(matsize, dtype=dtype, device=device)) + elif mat_chars[mat_type] == 'hermitian_pd': + list_of_matrices.append(random_hermitian_pd_matrix(matsize, dtype=dtype, device=device)) + elif mat_chars[mat_type] == 'singular': + list_of_matrices.append(torch.ones(matsize, matsize, dtype=dtype, device=device)) + elif mat_chars[mat_type] == 'non_singular': + list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize, dtype=dtype, device=device)) + full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize)) + else: + full_tensor = torch.randn(*batchdims, matsize, matsize, dtype=dtype, device=device) + + actual_value = torch.linalg.slogdet(full_tensor) + expected_value = np.linalg.slogdet(full_tensor.cpu().numpy()) + self.assertEqual(expected_value[0], actual_value[0], atol=self.precision, rtol=self.precision) + self.assertEqual(expected_value[1], actual_value[1], atol=self.precision, rtol=self.precision) + + # test out=variant + sign_out = torch.empty_like(actual_value[0]) + logabsdet_out = torch.empty_like(actual_value[1]) + ans = torch.linalg.slogdet(full_tensor, out=(sign_out, logabsdet_out)) + self.assertEqual(ans[0], sign_out) + self.assertEqual(ans[1], logabsdet_out) + self.assertEqual(sign_out, actual_value[0]) + self.assertEqual(logabsdet_out, actual_value[1]) + + for matsize, batchdims in itertools.product([0, 3, 5], [(0,), (3,), (5, 3)]): + run_test(matsize, batchdims, mat_chars=['hermitian_pd']) + run_test(matsize, batchdims, mat_chars=['singular']) + run_test(matsize, batchdims, mat_chars=['non_singular']) + run_test(matsize, batchdims, mat_chars=['hermitian', 'hermitian_pd', 'hermitian_psd']) + run_test(matsize, batchdims, mat_chars=['singular', 'non_singular']) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_slogdet_errors_and_warnings(self, device, dtype): + # slogdet requires the input to be a square matrix or batch of square matrices + a = torch.randn(2, 3, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'): + torch.linalg.slogdet(a) + + # slogdet requires the input to be at least 2 dimensional tensor + a = torch.randn(2, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'): + torch.linalg.slogdet(a) + + # slogdet requires the input to be of float, double, cfloat or cdouble types + a = torch.randn(2, 2, device=device, dtype=torch.bfloat16) + with self.assertRaisesRegex(RuntimeError, r'of float, double, cfloat or cdouble types'): + torch.linalg.slogdet(a) + + # if non-empty out tensor with wrong shape is passed a warning is given + a = torch.randn(2, 3, 3, device=device, dtype=dtype) + sign_out = torch.empty(1, device=device, dtype=dtype) + real_dtype = a.real.dtype if dtype.is_complex else dtype + logabsdet_out = torch.empty(1, device=device, dtype=real_dtype) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.linalg.slogdet(a, out=(sign_out, logabsdet_out)) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes should match + sign_out = torch.empty_like(a).to(torch.int) + logabsdet_out = torch.empty_like(a).to(torch.int) + with self.assertRaisesRegex(RuntimeError, "sign dtype Int does not match input dtype"): + torch.linalg.slogdet(a, out=(sign_out, logabsdet_out)) + + sign_out = torch.empty(0, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, "logabsdet dtype Int does not match the expected dtype"): + torch.linalg.slogdet(a, out=(sign_out, logabsdet_out)) + + # device should match + if torch.cuda.is_available(): + wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' + sign_out = torch.empty(0, device=wrong_device, dtype=dtype) + logabsdet_out = torch.empty(0, device=wrong_device, dtype=real_dtype) + with self.assertRaisesRegex(RuntimeError, "Expected sign, logabsdet and input to be on the same device"): + torch.linalg.slogdet(a, out=(sign_out, logabsdet_out)) + + @slowTest + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_det_logdet_slogdet(self, device, dtype): + def reference_slogdet(M): + sdet, logabsdet = np.linalg.slogdet(M.detach().cpu().numpy()) + return M.new_tensor(sdet), M.new_tensor(logabsdet) + + def test_single_det(M, target, desc): + target_sdet, target_logabsdet = target + + det = M.det() + logdet = M.logdet() + sdet, logabsdet = M.slogdet() + linalg_sdet, linalg_logabsdet = torch.linalg.slogdet(M) + + # Test det + self.assertEqual(det, target_sdet * target_logabsdet.exp(), + atol=1e-7, rtol=0, msg='{} (det)'.format(desc)) + + # Test slogdet + # Compare the overall value rather than individual parts because of + # precision issues when det is near zero. + self.assertEqual(sdet * logabsdet.exp(), target_sdet * target_logabsdet.exp(), + atol=1e-7, rtol=0, msg='{} (slogdet)'.format(desc)) + self.assertEqual(linalg_sdet * linalg_logabsdet.exp(), target_sdet * target_logabsdet.exp(), + atol=1e-7, rtol=0, msg='{} (linalg_slogdet)'.format(desc)) + + # Test logdet + # Compare logdet against our own pytorch slogdet because they should + # be consistent, while it may behave slightly differently with other + # slogdet implementations when det is near zero due to precision + # issues. + if sdet.item() < 0: + self.assertTrue(logdet.item() != logdet.item(), '{} (logdet negative case)'.format(desc)) + else: + self.assertEqual(logdet.exp(), target_logabsdet.exp(), + atol=1e-7, rtol=0, msg='{} (logdet non-negative case)'.format(desc)) + + eye = torch.eye(5, dtype=dtype, device=device) + test_single_det(eye, (torch.ones((), dtype=dtype, device=device), torch.zeros((), dtype=dtype, device=device)), 'identity') + # Testing bug in #34061 (https://github.com/pytorch/pytorch/issues/34061) + for n in range(250, 551, 100): + mat = torch.randn(n, n, dtype=dtype, device=device) + q, _ = torch.qr(mat) + ref_det, ref_logabsdet = reference_slogdet(q) + test_single_det(q, (ref_det, ref_logabsdet), 'orthogonal') + + def test(M): + assert M.size(0) >= 5, 'this helper fn assumes M to be at least 5x5' + M = M.to(device) + + ref_M_sdet, ref_M_logabsdet = reference_slogdet(M) + + test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'basic') + if ref_M_logabsdet.exp().item() >= 1e-6: # skip singular + M_inv = M.inverse() + test_single_det(M_inv, reference_slogdet(M_inv), 'inverse') + + test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'transpose') + + for x in [0, 2, 4]: + for scale in [-2, -0.1, 0, 10]: + if scale > 0: + target = ref_M_sdet, ref_M_logabsdet + math.log(scale) + elif scale == 0: + target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf) + else: + target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-scale) + + # dim 0 + M_clone = M.clone() + M_clone[:, x] *= scale + test_single_det(M_clone, target, 'scale a row') + # dim 1 + M_clone = M.clone() + M_clone[x, :] *= scale + test_single_det(M_clone, target, 'scale a column') + + for x1, x2 in [(0, 3), (4, 1), (3, 2)]: + assert x1 != x2, 'x1 and x2 needs to be different for this test' + target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf) + # dim 0 + M_clone = M.clone() + M_clone[:, x2] = M_clone[:, x1] + test_single_det(M_clone, target, 'two rows are same') + # dim 1 + M_clone = M.clone() + M_clone[x2, :] = M_clone[x1, :] + test_single_det(M_clone, target, 'two columns are same') + + for scale1, scale2 in [(0.3, -1), (0, 2), (10, 0.1)]: + det_scale = scale1 * scale2 * -1 + if det_scale > 0: + target = ref_M_sdet, ref_M_logabsdet + math.log(det_scale) + elif det_scale == 0: + target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf) + else: + target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-det_scale) + + # dim 0 + M_clone = M.clone() + t = M_clone[:, x1] * scale1 + M_clone[:, x1] += M_clone[:, x2] * scale2 + M_clone[:, x2] = t + test_single_det(M_clone, target, 'exchanging rows') + # dim 1 + M_clone = M.clone() + t = M_clone[x1, :] * scale1 + M_clone[x1, :] += M_clone[x2, :] * scale2 + M_clone[x2, :] = t + test_single_det(M_clone, target, 'exchanging columns') + + def get_random_mat_scale(n): + # For matrices with values i.i.d. with 0 mean, unit variance, and + # subexponential tail, we have: + # E[log det(A^2)] \approx log((n-1)!) + # + # Notice: + # log Var[det(A)] = log E[det(A^2)] >= E[log det(A^2)] + # + # So: + # stddev[det(A)] >= sqrt( (n-1)! ) + # + # We use this as an intuitive guideline to scale random generated + # matrices so our closeness tests can work more robustly: + # scale by sqrt( (n-1)! )^(-1/n) = ( (n-1)! )^(-1/(2n)) + # + # source: https://arxiv.org/pdf/1112.0752.pdf + + # TODO: technically we need subexponential distn for this to hold, + # but we mostly use gaussian entries below. Consider switching + # to Chi-sq if this turns out not stable enough, since Chi-sq + # is easy enough to sample from. + return math.factorial(n - 1) ** (-1.0 / (2 * n)) + + for n in [5, 10, 25]: + scale = get_random_mat_scale(n) + test(torch.randn(n, n, dtype=dtype, device=device) * scale) + r = torch.randn(n, n, dtype=dtype, device=device) * scale + # symmetric psd + test(r.mm(r.t())) + # symmetric pd + r = torch.randn(n, n, dtype=dtype, device=device) * scale + test(r.mm(r.t()) + torch.eye(n, dtype=dtype, device=device) * 1e-6) + # symmetric + r = torch.randn(n, n, dtype=dtype, device=device) * scale + for i in range(n): + for j in range(i): + r[i, j] = r[j, i] + test(r) + # non-contiguous + test((torch.randn(n, n, n + 1, dtype=dtype, device=device) * scale)[:, 2, 1:]) + # det = 0 + r = torch.randn(n, n, dtype=dtype, device=device) * scale + u, s, v = r.svd() + if reference_slogdet(u)[0] < 0: + u = -u + if reference_slogdet(v)[0] < 0: + v = -v + s[0] *= -1 + s[-1] = 0 + test(u.mm(s.diag()).mm(v)) + + # Small values to test numerical stability. Note that we don't scale + # this matrix. + r = torch.randn(512, 512, dtype=dtype, device=device) + u, s, v = r.svd() + s.fill_(1. / (100 * s.numel())) + test(u.mm(s.diag()).mm(v)) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_det_logdet_slogdet_batched(self, device, dtype): + from torch.testing._internal.common_utils import (random_symmetric_matrix, random_symmetric_psd_matrix, + random_symmetric_pd_matrix, random_square_matrix_of_rank) + + # mat_chars denotes matrix characteristics + # possible values are: sym, sym_psd, sym_pd, sing, non_sym + def run_test(matsize, batchdims, mat_chars): + num_matrices = reduce(lambda x, y: x * y, batchdims, 1) + list_of_matrices = [] + + for idx in range(num_matrices): + mat_type = idx % len(mat_chars) + if mat_chars[mat_type] == 'sym': + list_of_matrices.append(random_symmetric_matrix(matsize, dtype=dtype, device=device)) + elif mat_chars[mat_type] == 'sym_psd': + list_of_matrices.append(random_symmetric_psd_matrix(matsize, dtype=dtype, device=device)) + elif mat_chars[mat_type] == 'sym_pd': + list_of_matrices.append(random_symmetric_pd_matrix(matsize, dtype=dtype, device=device)) + elif mat_chars[mat_type] == 'sing': + list_of_matrices.append(torch.ones(matsize, matsize, dtype=dtype, device=device)) + elif mat_chars[mat_type] == 'non_sing': + list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize, dtype=dtype, device=device)) + full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize)) + # Scaling adapted from `get_random_mat_scale` in _test_det_logdet_slogdet + full_tensor *= (math.factorial(matsize - 1) ** (-1.0 / (2 * matsize))) + + for fn in [torch.det, torch.logdet, torch.slogdet, torch.linalg.slogdet]: + expected_value = [] + actual_value = fn(full_tensor) + for full_idx in itertools.product(*map(lambda x: list(range(x)), batchdims)): + expected_value.append(fn(full_tensor[full_idx])) + + if fn == torch.slogdet or fn == torch.linalg.slogdet: + sign_value = torch.stack([tup[0] for tup in expected_value], dim=0).reshape(batchdims) + expected_value = torch.stack([tup[1] for tup in expected_value], dim=0).reshape(batchdims) + self.assertEqual(sign_value, actual_value[0]) + self.assertEqual(expected_value, actual_value[1]) + else: + expected_value = torch.stack(expected_value, dim=0).reshape(batchdims) + self.assertEqual(actual_value, expected_value) + + for matsize, batchdims in itertools.product([3, 5], [(3,), (5, 3)]): + run_test(matsize, batchdims, mat_chars=['sym_pd']) + run_test(matsize, batchdims, mat_chars=['sing']) + run_test(matsize, batchdims, mat_chars=['non_sing']) + run_test(matsize, batchdims, mat_chars=['sym', 'sym_pd', 'sym_psd']) + run_test(matsize, batchdims, mat_chars=['sing', 'non_sing']) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_cholesky_inverse(self, device, dtype): + from torch.testing._internal.common_utils import random_symmetric_pd_matrix + a = random_symmetric_pd_matrix(5, dtype=dtype, device=device) + + # compute inverse directly + inv0 = torch.inverse(a) + + # default case + chol = torch.cholesky(a) + inv1 = torch.cholesky_inverse(chol, False) + self.assertLessEqual(inv0.dist(inv1), 1e-12) + + # upper Triangular Test + chol = torch.cholesky(a, True) + inv1 = torch.cholesky_inverse(chol, True) + self.assertLessEqual(inv0.dist(inv1), 1e-12) + + # lower Triangular Test + chol = torch.cholesky(a, False) + inv1 = torch.cholesky_inverse(chol, False) + self.assertLessEqual(inv0.dist(inv1), 1e-12) + + def _select_broadcastable_dims(self, dims_full=None): + # select full dimensionality + if dims_full is None: + dims_full = [] + ndims = random.randint(1, 4) + dims_full = [random.randint(1, 8) for _ in range(ndims)] + else: + ndims = len(dims_full) + + # select actual dimensions for ops: + # larger: full ndims, individual sizes may be reduced + # smaller: possibly reduced ndims, sizes may be reduced + smaller_ndims = random.randint(1, ndims) + dims_small = [] + dims_large = [] + for i in range(ndims - 1, -1, -1): + j = random.randint(1, 3) + if j == 1: # no reduced singleton dimension + ds = dims_full[i] + dl = dims_full[i] + elif j == 2: # larger may have reduced singleton dimension + ds = dims_full[i] + dl = 1 if len(dims_small) < smaller_ndims else dims_full[i] + elif j == 3: # smaller may have reduced singleton dimension + ds = 1 + dl = dims_full[i] + dims_large = [dl] + dims_large + if len(dims_small) < smaller_ndims: + dims_small = [ds] + dims_small + return (dims_small, dims_large, dims_full) + + def test_broadcast_fused_matmul(self, device): + fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"] + + for fn in fns: + batch_dim = random.randint(1, 8) + n_dim = random.randint(1, 8) + m_dim = random.randint(1, 8) + p_dim = random.randint(1, 8) + + def dims_full_for_fn(): + if fn == "baddbmm": + return ([batch_dim, n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim]) + elif fn == "addbmm": + return ([n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim]) + elif fn == "addmm": + return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim]) + elif fn == "addmv": + return ([n_dim], [n_dim, m_dim], [m_dim]) + elif fn == "addr": + return ([n_dim, m_dim], [n_dim], [m_dim]) + else: + raise AssertionError("unknown function") + + (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn() + (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full) + + t0_small = torch.randn(*t0_dims_small, device=device).float() + t1 = torch.randn(*t1_dims, device=device).float() + t2 = torch.randn(*t2_dims, device=device).float() + + t0_full = t0_small.expand(*t0_dims_full).to(device) + + fntorch = getattr(torch, fn) + r0 = fntorch(t0_small, t1, t2) + r1 = fntorch(t0_full, t1, t2) + self.assertEqual(r0, r1) + + @tf32_on_and_off(0.001) + def test_broadcast_batched_matmul(self, device): + n_dim = random.randint(1, 8) + m_dim = random.randint(1, 8) + p_dim = random.randint(1, 8) + full_batch_dims = [random.randint(1, 3) for i in range(random.randint(1, 3))] + (batch_dims_small, _, _) = self._select_broadcastable_dims(full_batch_dims) + + def verify_batched_matmul(full_lhs, one_dimensional): + if not one_dimensional: + lhs_dims = [n_dim, m_dim] + rhs_dims = [m_dim, p_dim] + result_dims = [n_dim, p_dim] + else: + lhs_dims = [n_dim, m_dim] if full_lhs else [m_dim] + rhs_dims = [m_dim, p_dim] if not full_lhs else [m_dim] + result_dims = [n_dim] if full_lhs else [p_dim] + + lhs_mat_dims = lhs_dims if len(lhs_dims) != 1 else [1, m_dim] + rhs_mat_dims = rhs_dims if len(rhs_dims) != 1 else [m_dim, 1] + full_mat_dims = lhs_mat_dims if full_lhs else rhs_mat_dims + dim0_dims = rhs_dims if full_lhs else lhs_dims + small_dims = batch_dims_small + (rhs_mat_dims if full_lhs else lhs_mat_dims) + + small = torch.randn(*(small_dims), device=device).float() + dim0 = torch.randn(*(dim0_dims), device=device).float() + full = torch.randn(*(full_batch_dims + full_mat_dims), device=device).float() + if not one_dimensional: + (lhsTensors, rhsTensors) = ((full,), (small, dim0)) if full_lhs else ((small, dim0), (full,)) + else: + (lhsTensors, rhsTensors) = ((full,), (dim0,)) if full_lhs else ((dim0,), (full,)) + + def maybe_squeeze_result(l, r, result): + if len(lhs_dims) == 1 and l.dim() != 1: + return result.squeeze(-2) + elif len(rhs_dims) == 1 and r.dim() != 1: + return result.squeeze(-1) + else: + return result + + for lhs in lhsTensors: + lhs_expanded = lhs.expand(*(torch.Size(full_batch_dims) + torch.Size(lhs_mat_dims))) + lhs_expanded_matmul_fn = lhs_expanded.matmul + for rhs in rhsTensors: + rhs_expanded = ((rhs if len(rhs_dims) != 1 else rhs.unsqueeze(-1)). + expand(*(torch.Size(full_batch_dims) + torch.Size(rhs_mat_dims)))) + truth = maybe_squeeze_result(lhs_expanded, rhs_expanded, lhs_expanded_matmul_fn(rhs_expanded)) + for l in (lhs, lhs_expanded): + for r in (rhs, rhs_expanded): + l_matmul_fn = l.matmul + result = maybe_squeeze_result(l, r, l_matmul_fn(r)) + self.assertEqual(truth, result) + # test torch.matmul function as well + torch_result = maybe_squeeze_result(l, r, torch.matmul(l, r)) + self.assertEqual(truth, torch_result) + # test torch.matmul with out + out = torch.zeros_like(torch_result) + torch.matmul(l, r, out=out) + self.assertEqual(truth, maybe_squeeze_result(l, r, out)) + + # compare to bmm + bmm_result = (torch.bmm(lhs_expanded.contiguous().view(-1, *lhs_mat_dims), + rhs_expanded.contiguous().view(-1, *rhs_mat_dims))) + self.assertEqual(truth.view(-1, *result_dims), bmm_result.view(-1, *result_dims)) + + for indices in itertools.product((True, False), repeat=2): + verify_batched_matmul(*indices) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_lu_solve_batched_non_contiguous(self, device, dtype): + from numpy.linalg import solve + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + + A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype, device='cpu') + b = torch.randn(2, 2, 2, dtype=dtype, device='cpu') + x_exp = torch.as_tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy())).to(device) + A = A.to(device).permute(0, 2, 1) + b = b.to(device).permute(2, 1, 0) + assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs" + LU_data, LU_pivots = torch.lu(A) + x = torch.lu_solve(b, LU_data, LU_pivots) + self.assertEqual(x, x_exp) + + def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype): + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + + b = torch.randn(*b_dims, dtype=dtype, device=device) + A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype).to(device) + LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot) + self.assertEqual(info, torch.zeros_like(info)) + return b, A, LU_data, LU_pivots + + @skipCPUIfNoLapack + @skipCUDAIfNoMagma + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, + torch.float64: 1e-8, torch.complex128: 1e-8}) + def test_lu_solve(self, device, dtype): + def sub_test(pivot): + for k, n in zip([2, 3, 5], [3, 5, 7]): + b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n,), (n, k), pivot, device, dtype) + x = torch.lu_solve(b, LU_data, LU_pivots) + self.assertEqual(b, A.mm(x)) + + sub_test(True) + if self.device_type == 'cuda': + sub_test(False) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, + torch.float64: 1e-8, torch.complex128: 1e-8}) + def test_lu_solve_batched(self, device, dtype): + def sub_test(pivot): + def lu_solve_batch_test_helper(A_dims, b_dims, pivot): + b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, pivot, device, dtype) + x_exp_list = [] + for i in range(b_dims[0]): + x_exp_list.append(torch.lu_solve(b[i], LU_data[i], LU_pivots[i])) + x_exp = torch.stack(x_exp_list) # Stacked output + x_act = torch.lu_solve(b, LU_data, LU_pivots) # Actual output + self.assertEqual(x_exp, x_act) # Equality check + Ax = torch.matmul(A, x_act) + self.assertEqual(b, Ax) + + for batchsize in [1, 3, 4]: + lu_solve_batch_test_helper((5, batchsize), (batchsize, 5, 10), pivot) + + # Tests tensors with 0 elements + b = torch.randn(3, 0, 3, dtype=dtype, device=device) + A = torch.randn(3, 0, 0, dtype=dtype, device=device) + LU_data, LU_pivots = torch.lu(A) + self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots)) + + sub_test(True) + if self.device_type == 'cuda': + sub_test(False) + + @slowTest + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_lu_solve_batched_many_batches(self, device, dtype): + def run_test(A_dims, b_dims): + b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype) + x = torch.lu_solve(b, LU_data, LU_pivots) + Ax = torch.matmul(A, x) + self.assertEqual(Ax, b.expand_as(Ax)) + + run_test((5, 65536), (65536, 5, 10)) + run_test((5, 262144), (262144, 5, 10)) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_lu_solve_batched_broadcasting(self, device, dtype): + from numpy.linalg import solve + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + + def run_test(A_dims, b_dims, pivot=True): + A_matrix_size = A_dims[-1] + A_batch_dims = A_dims[:-2] + A = random_fullrank_matrix_distinct_singular_value(A_matrix_size, *A_batch_dims, dtype=dtype) + b = torch.randn(*b_dims, dtype=dtype) + x_exp = torch.as_tensor(solve(A.numpy(), b.numpy())).to(dtype=dtype, device=device) + A, b = A.to(device), b.to(device) + LU_data, LU_pivots = torch.lu(A, pivot=pivot) + x = torch.lu_solve(b, LU_data, LU_pivots) + self.assertEqual(x, x_exp) + + # test against numpy.linalg.solve + run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6)) # no broadcasting + run_test((2, 1, 3, 4, 4), (4, 6)) # broadcasting b + run_test((4, 4), (2, 1, 3, 4, 2)) # broadcasting A + run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5)) # broadcasting A & b + + @precisionOverride({torch.float32: 1e-5, torch.complex64: 1e-5}) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_symeig(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_matrix + + def run_test(dims, eigenvectors, upper): + x = random_hermitian_matrix(*dims, dtype=dtype, device=device) + if dtype.is_complex: + real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 + else: + real_dtype = dtype + oute = torch.empty(dims[1:] + dims[:1], dtype=real_dtype, device=device) + outv = torch.empty(dims[1:] + dims[:1] * 2, dtype=dtype, device=device) + torch.symeig(x, eigenvectors=eigenvectors, upper=upper, out=(oute, outv)) + + if eigenvectors: + outv_ = outv.cpu().numpy() + x_recon = np.matmul(np.matmul(outv_, torch.diag_embed(oute.to(dtype)).cpu().numpy()), + outv_.swapaxes(-2, -1).conj()) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using V @ diag(e) @ V.T') + else: + eigvals, _ = torch.symeig(x, eigenvectors=True, upper=upper) + self.assertEqual(eigvals, oute, msg='Eigenvalues mismatch') + self.assertEqual(torch.empty(0, device=device, dtype=dtype), outv, msg='Eigenvector matrix not empty') + + rese, resv = x.symeig(eigenvectors=eigenvectors, upper=upper) + self.assertEqual(rese, oute, msg="outputs of symeig and symeig with out don't match") + self.assertEqual(resv, outv, msg="outputs of symeig and symeig with out don't match") + + # test non-contiguous + x = random_hermitian_matrix(*dims, dtype=dtype, device=device) + n_dim = len(dims) + 1 + # Reverse the batch dimensions and the matrix dimensions and then concat them + x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) + assert not x.is_contiguous(), "x is intentionally non-contiguous" + rese, resv = torch.symeig(x, eigenvectors=eigenvectors, upper=upper) + if eigenvectors: + resv_ = resv.cpu().numpy() + x_recon = np.matmul(np.matmul(resv_, torch.diag_embed(rese.to(dtype)).cpu().numpy()), + resv_.swapaxes(-2, -1).conj()) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using V @ diag(e) @ V.T') + else: + eigvals, _ = torch.symeig(x, eigenvectors=True, upper=upper) + self.assertEqual(eigvals, rese, msg='Eigenvalues mismatch') + self.assertEqual(torch.empty(0, device=device, dtype=dtype), resv, msg='Eigenvector matrix not empty') + + batch_dims_set = [(), (3,), (3, 5), (5, 3, 5)] + for batch_dims, eigenvectors, upper in itertools.product(batch_dims_set, (True, False), (True, False)): + run_test((5,) + batch_dims, eigenvectors, upper) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_pca_lowrank(self, device): + from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix + + dtype = torch.double + + def run_subtest(guess_rank, actual_rank, matrix_size, batches, device, pca, **options): + density = options.pop('density', 1) + if isinstance(matrix_size, int): + rows = columns = matrix_size + else: + rows, columns = matrix_size + if density == 1: + a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype) + a = a_input + else: + a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype) + a = a_input.to_dense() + + u, s, v = pca(a_input, q=guess_rank, **options) + + self.assertEqual(s.shape[-1], guess_rank) + self.assertEqual(u.shape[-2], rows) + self.assertEqual(u.shape[-1], guess_rank) + self.assertEqual(v.shape[-1], guess_rank) + self.assertEqual(v.shape[-2], columns) + + A1 = u.matmul(s.diag_embed()).matmul(v.transpose(-2, -1)) + ones_m1 = torch.ones(batches + (rows, 1), dtype=a.dtype, device=device) + c = a.sum(axis=-2) / rows + c = c.reshape(batches + (1, columns)) + A2 = a - ones_m1.matmul(c) + self.assertEqual(A1, A2) + + if density == 1: + # actual rank is known only for dense input + detect_rank = (s.abs() > 1e-5).sum(axis=-1) + self.assertEqual(actual_rank * torch.ones(batches, device=device, dtype=torch.int64), detect_rank) + U, S, V = torch.svd(A2) + self.assertEqual(s[..., :actual_rank], S[..., :actual_rank]) + + all_batches = [(), (1,), (3,), (2, 3)] + for actual_rank, size, all_batches in [ + (2, (17, 4), all_batches), + (2, (100, 4), all_batches), + (6, (100, 40), all_batches), + (12, (1000, 1000), [()]), + ]: + for batches in all_batches: + for guess_rank in [ + actual_rank, + actual_rank + 2, + actual_rank + 6, + ]: + if guess_rank <= min(*size): + run_subtest(guess_rank, actual_rank, size, batches, device, torch.pca_lowrank) + run_subtest(guess_rank, actual_rank, size[::-1], batches, device, torch.pca_lowrank) + + # sparse input + for guess_rank, size in [ + (4, (17, 4)), (4, (4, 17)), (16, (17, 17)), + (21, (100, 40)), (20, (40, 100)), (600, (1000, 1000))]: + for density in [0.005, 0.1]: + run_subtest(guess_rank, None, size, (), device, torch.pca_lowrank, density=density) + + # jitting support + jitted = torch.jit.script(torch.pca_lowrank) + guess_rank, actual_rank, size, batches = 2, 2, (17, 4), () + run_subtest(guess_rank, actual_rank, size, batches, device, jitted) + + # Ensure that nuclear_norm's out variant gives the same result as the non-out + @onlyOnCPUAndCUDA + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64) + def test_nuclear_norm_out(self, device, dtype): + test_cases = [ + # input size, dim + ((25, 25), None), + ((25, 25), (0, 1)), + ((25, 25), (1, 0)), + ((25, 25, 25), (2, 0)), + ((25, 25, 25), (0, 1)), + ] + for keepdim in [False, True]: + for input_size, dim in test_cases: + msg = f'input_size: {input_size}, dim: {dim}, keepdim: {keepdim}' + x = torch.randn(*input_size, device=device, dtype=dtype) + result_out = torch.empty(0, device=device, dtype=dtype) + if dim is None: + result = torch.nuclear_norm(x, keepdim=keepdim) + torch.nuclear_norm(x, keepdim=keepdim, out=result_out) + else: + result = torch.nuclear_norm(x, keepdim=keepdim, dim=dim) + torch.nuclear_norm(x, keepdim=keepdim, dim=dim, out=result_out) + self.assertEqual(result, result_out, msg=msg) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_geqrf(self, device): + a = torch.randn(5, 5, device=device) + b, c = torch.geqrf(a) + b_placeholder, c_placeholder = torch.empty_like(b), torch.empty_like(c) + torch.geqrf(a, out=(b_placeholder, c_placeholder)) + self.assertEqual(b, b_placeholder) + self.assertEqual(c, c_placeholder) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_lstsq(self, device, dtype): + def _test_underdetermined(a, b, expectedNorm): + # underdetermined systems are only supported on CPU + if self.device_type != 'cpu': + return + + m = a.size()[0] + n = a.size()[1] + assert(m <= n) + + a_copy = a.clone() + b_copy = b.clone() + res1 = torch.lstsq(b, a)[0] + self.assertEqual(a, a_copy, atol=0, rtol=0) + self.assertEqual(b, b_copy, atol=0, rtol=0) + self.assertEqual((torch.mm(a, res1) - b).norm(), expectedNorm, atol=1e-8, rtol=0) + + ta = torch.tensor((), dtype=dtype, device=device) + tb = torch.tensor((), dtype=dtype, device=device) + res2 = torch.lstsq(b, a, out=(tb, ta))[0] + self.assertEqual(a, a_copy, atol=0, rtol=0) + self.assertEqual(b, b_copy, atol=0, rtol=0) + self.assertEqual((torch.mm(a, res1) - b).norm(), expectedNorm, atol=1e-8, rtol=0) + + res3 = torch.lstsq(b, a, out=(b, a))[0] + self.assertEqual((torch.mm(a_copy, b) - b_copy).norm(), expectedNorm, atol=1e-8, rtol=0) + self.assertEqual(res1, tb, atol=0, rtol=0) + self.assertEqual(res1, b, atol=0, rtol=0) + self.assertEqual(res1, res2, atol=0, rtol=0) + self.assertEqual(res1, res3, atol=0, rtol=0) + + def _test_overdetermined(a, b, expectedNorm): + m = a.size()[0] + n = a.size()[1] + assert(m > n) + + def check_norm(a, b, expected_norm, gels_result): + # Checks |ax - b| and the residual info from the result + + # The first n rows is the least square solution. + # Rows n to m-1 contain residual information. + x = gels_result[:n] + resid_info = gels_result[n:] + + resid_norm = (torch.mm(a, x) - b).norm() + self.assertEqual(resid_norm, expectedNorm, atol=1e-8, rtol=0) + self.assertEqual(resid_info.norm(), resid_norm, atol=1e-8, rtol=0) + + a_copy = a.clone() + b_copy = b.clone() + res1 = torch.lstsq(b, a)[0] + self.assertEqual(a, a_copy, atol=0, rtol=0) + self.assertEqual(b, b_copy, atol=0, rtol=0) + check_norm(a, b, expectedNorm, res1) + + ta = torch.tensor((), dtype=dtype, device=device) + tb = torch.tensor((), dtype=dtype, device=device) + res2 = torch.lstsq(b, a, out=(tb, ta))[0] + self.assertEqual(a, a_copy, atol=0, rtol=0) + self.assertEqual(b, b_copy, atol=0, rtol=0) + check_norm(a, b, expectedNorm, res2) + + res3 = torch.lstsq(b, a, out=(b, a))[0] + check_norm(a_copy, b_copy, expectedNorm, res3) + + self.assertEqual(res1, tb, atol=0, rtol=0) + self.assertEqual(res1, b, atol=0, rtol=0) + self.assertEqual(res1, res2, atol=0, rtol=0) + self.assertEqual(res1, res3, atol=0, rtol=0) + + # basic test + expectedNorm = 0 + a = torch.tensor(((1.44, -9.96, -7.55, 8.34), + (-7.84, -0.28, 3.24, 8.09), + (-4.39, -3.24, 6.27, 5.28), + (4.53, 3.83, -6.64, 2.06)), dtype=dtype, device=device).t() + b = torch.tensor(((8.58, 8.26, 8.48, -5.28), + (9.35, -4.43, -0.70, -0.26)), dtype=dtype, device=device).t() + _test_underdetermined(a, b, expectedNorm) + + # test overdetermined + expectedNorm = 17.390200628863 + a = torch.tensor(((1.44, -9.96, -7.55, 8.34, 7.08, -5.45), + (-7.84, -0.28, 3.24, 8.09, 2.52, -5.70), + (-4.39, -3.24, 6.27, 5.28, 0.74, -1.19), + (4.53, 3.83, -6.64, 2.06, -2.47, 4.70)), dtype=dtype, device=device).t() + b = torch.tensor(((8.58, 8.26, 8.48, -5.28, 5.72, 8.93), + (9.35, -4.43, -0.70, -0.26, -7.36, -2.52)), dtype=dtype, device=device).t() + _test_overdetermined(a, b, expectedNorm) + + # test underdetermined + expectedNorm = 0 + a = torch.tensor(((1.44, -9.96, -7.55), + (-7.84, -0.28, 3.24), + (-4.39, -3.24, 6.27), + (4.53, 3.83, -6.64)), dtype=dtype, device=device).t() + b = torch.tensor(((8.58, 8.26, 8.48), + (9.35, -4.43, -0.70)), dtype=dtype, device=device).t() + _test_underdetermined(a, b, expectedNorm) + + # test reuse + expectedNorm = 0 + a = torch.tensor(((1.44, -9.96, -7.55, 8.34), + (-7.84, -0.28, 3.24, 8.09), + (-4.39, -3.24, 6.27, 5.28), + (4.53, 3.83, -6.64, 2.06)), dtype=dtype, device=device).t() + b = torch.tensor(((8.58, 8.26, 8.48, -5.28), + (9.35, -4.43, -0.70, -0.26)), dtype=dtype, device=device).t() + ta = torch.tensor((), dtype=dtype, device=device) + tb = torch.tensor((), dtype=dtype, device=device) + torch.lstsq(b, a, out=(tb, ta)) + self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, atol=1e-8, rtol=0) + torch.lstsq(b, a, out=(tb, ta)) + self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, atol=1e-8, rtol=0) + torch.lstsq(b, a, out=(tb, ta)) + self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, atol=1e-8, rtol=0) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_lapack_empty(self, device): + # FIXME: these are just a selection of LAPACK functions -- we need a general strategy here. + # The LAPACK functions themselves generally do NOT work with zero sized dimensions, although + # numpy/sci often has a direct wrapper (e.g. lu_factor) and a wrapper that "does the right thing" + # (e.g. lu). We often name our functions identically to the lapack function, so it will take work + # to name / migrate-to better wrappers. + def fn(torchfn, *args): + return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape + for shape in args)) + + # inverse, pinverse + self.assertEqual((0, 0), fn(torch.inverse, (0, 0)).shape) + self.assertEqual((5, 0), fn(torch.pinverse, (0, 5)).shape) + self.assertEqual((0, 5), fn(torch.pinverse, (5, 0)).shape) + self.assertEqual((0, 0), fn(torch.pinverse, (0, 0)).shape) + + # det, logdet, slogdet + self.assertEqual(torch.tensor(1., device=device), fn(torch.det, (0, 0))) + self.assertEqual(torch.tensor(0., device=device), fn(torch.logdet, (0, 0))) + self.assertEqual((torch.tensor(1., device=device), torch.tensor(0., device=device)), + fn(torch.slogdet, (0, 0))) + + # eig, symeig + evalues, evectors = fn(torch.eig, (0, 0), True) + self.assertEqual([(0, 2), (0, 0)], [evalues.shape, evectors.shape]) + evalues, evectors = fn(torch.symeig, (0, 0), True) + self.assertEqual([(0,), (0, 0)], [evalues.shape, evectors.shape]) + + # qr + q, r = fn(torch.qr, (3, 0), True) + self.assertEqual([(3, 0), (0, 0)], [q.shape, r.shape]) + q, r = fn(torch.qr, (0, 3), True) + self.assertEqual([(0, 0), (0, 3)], [q.shape, r.shape]) + q, r = fn(torch.qr, (3, 0), False) + self.assertEqual([(3, 3), (3, 0)], [q.shape, r.shape]) + + # lstsq + self.assertRaises(RuntimeError, lambda: torch.lstsq(torch.randn(0, 0), torch.randn(0, 0))) + self.assertRaises(RuntimeError, lambda: torch.lstsq(torch.randn(0,), torch.randn(0, 0))) + + @tf32_on_and_off(0.005) + def test_tensordot(self, device): + a = torch.arange(60., device=device).reshape(3, 4, 5) + b = torch.arange(24., device=device).reshape(4, 3, 2) + c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu() + cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), + axes=([1, 0], [0, 1]))) + self.assertEqual(c, cn) + + cout = torch.zeros((5, 2)) + torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout).cpu() + self.assertEqual(c, cout) + + a = torch.randn(2, 3, 4, 5, device=device) + b = torch.randn(4, 5, 6, 7, device=device) + c = torch.tensordot(a, b, dims=2).cpu() + cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), + axes=2)) + + with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"): + torch.tensordot(a, b, dims=-1) + + self.assertEqual(c, cn) + c = torch.tensordot(a, b).cpu() + cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy())) + self.assertEqual(c, cn) + + instantiate_device_type_tests(TestLinalg, globals()) if __name__ == '__main__': diff --git a/test/test_metal.py b/test/test_metal.py new file mode 100644 index 0000000000000..f5a77d0b06d6c --- /dev/null +++ b/test/test_metal.py @@ -0,0 +1,159 @@ +import torch +from torch.nn import functional as F + +from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing import FileCheck +import io + +class TestMetalRewritePass(TestCase): + @staticmethod + def validate_transformed_module( + # To please flake + self, + pattern_count_map, + data_shape, + prepack_removal=False, + fuse_clamping_ops=False): + module_instance = self + scripted_model = torch.jit.script(module_instance) + scripted_model.eval() + input_data = torch.normal(1, 20, size=data_shape) + ref_result = scripted_model(input_data) + torch._C._jit_pass_metal_insert_prepacked_ops(scripted_model._c) + if fuse_clamping_ops or prepack_removal: + scripted_model._c = torch._C._freeze_module(scripted_model._c) + if fuse_clamping_ops: + torch._C._jit_pass_metal_fuse_clamp_w_prepacked_conv(scripted_model._c) + if prepack_removal: + torch._C._jit_pass_metal_fold_prepacking_ops(scripted_model._c) + + buffer = io.BytesIO() + torch.jit.save(scripted_model, buffer) + buffer.seek(0) + deserialized_scripted_model = torch.jit.load(buffer) + for pattern, v in pattern_count_map.items(): + if (v == 0): + FileCheck().check(pattern).run(deserialized_scripted_model.graph) + elif (v == -1): + FileCheck().check_not(pattern).run(deserialized_scripted_model.graph) + else: + FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph) + + def test_conv(self): + # Conv params + batch_size = 2 + input_channels_per_group = 6 + height = 16 + width = 16 + output_channels_per_group = 6 + groups = 4 + kernel_h = kernel_w = 3 + stride_h = stride_w = 1 + pad_h = pad_w = 1 + dilation = 1 + input_channels = input_channels_per_group * groups + output_channels = output_channels_per_group * groups + kernels = (kernel_h, kernel_w) + strides = (stride_h, stride_w) + paddings = (pad_h, pad_w) + dilations = (dilation, dilation) + conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w) + conv_bias_shape = (output_channels) + + class Conv2D(torch.nn.Module): + def __init__(self): + super(Conv2D, self).__init__() + self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False) + self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False) + self.strides = strides + self.paddings = paddings + self.dilations = dilations + self.groups = groups + + def forward(self, x): + return F.conv2d(x, self.weight, self.bias, + self.strides, self.paddings, self.dilations, self.groups) + + data_shape = (batch_size, input_channels, height, width) + pattern_count_map = {"Tensor = aten::conv2d": -1, + "metal_prepack::conv2d_prepack": 1, + "metal_prepack::conv2d_run": 1} + TestMetalRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape) + + class Conv2DRelu(torch.nn.Module): + def __init__(self): + super(Conv2DRelu, self).__init__() + self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False) + self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False) + self.strides = strides + self.paddings = paddings + self.dilations = dilations + self.groups = groups + + def forward(self, x): + o = F.conv2d(x, self.weight, self.bias, + self.strides, self.paddings, self.dilations, self.groups) + o = F.relu(o) + return o + + data_shape = (batch_size, input_channels, height, width) + pattern_count_map = {"Tensor = aten::conv2d": -1, + "metal_prepack::conv2d_prepack": 1, + "metal_prepack::conv2d_run": 1} + TestMetalRewritePass.validate_transformed_module( + Conv2DRelu(), pattern_count_map, data_shape) + + pattern_count_map["aten::relu"] = 1 + pattern_count_map["metal_prepack::conv2d_prepack"] = -1 + TestMetalRewritePass.validate_transformed_module( + Conv2DRelu(), + pattern_count_map, + data_shape, + prepack_removal=True) + pattern_count_map["aten::relu"] = -1 + TestMetalRewritePass.validate_transformed_module( + Conv2DRelu(), + pattern_count_map, + data_shape, + prepack_removal=True, + fuse_clamping_ops=True) + + + class Conv2DHardtanh(torch.nn.Module): + def __init__(self): + super(Conv2DHardtanh, self).__init__() + self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False) + self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False) + self.strides = strides + self.paddings = paddings + self.dilations = dilations + self.groups = groups + + def forward(self, x): + o = F.conv2d(x, self.weight, self.bias, + self.strides, self.paddings, self.dilations, self.groups) + o = F.hardtanh(o) + return o + + data_shape = (batch_size, input_channels, height, width) + pattern_count_map = {"Tensor = aten::conv2d": -1, + "metal_prepack::conv2d_prepack": 1, + "metal_prepack::conv2d_run": 1} + TestMetalRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape) + pattern_count_map["aten::hardtanh"] = 1 + pattern_count_map["metal_prepack::conv2d_prepack"] = -1 + TestMetalRewritePass.validate_transformed_module( + Conv2DHardtanh(), + pattern_count_map, + data_shape, + prepack_removal=True) + pattern_count_map["aten::hardtanh"] = -1 + TestMetalRewritePass.validate_transformed_module( + Conv2DRelu(), + pattern_count_map, + data_shape, + prepack_removal=True, + fuse_clamping_ops=True) + +if __name__ == "__main__": + run_tests() diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py index c856b5bd570ad..66187fe2463a4 100644 --- a/test/test_mkldnn.py +++ b/test/test_mkldnn.py @@ -19,6 +19,7 @@ from torch.autograd.gradcheck import gradgradcheck, gradcheck +types = [torch.float, torch.bfloat16] # Comment the line below to find out the CI machines having MKL-DNN build disabled @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled") @@ -29,17 +30,55 @@ def test_conversion(self): torch.randn((1, 2, 3, 4, 5), dtype=torch.float, device=torch.device('cpu'))[:, :, :, :, 1]]: cpu_tensor.requires_grad_() - mkldnn_tensor = cpu_tensor.to_mkldnn() - cpu_tensor_1 = mkldnn_tensor.to_dense() - self.assertEqual(cpu_tensor, cpu_tensor_1) - self.assertEqual(mkldnn_tensor.dtype, torch.float) - self.assertEqual(mkldnn_tensor.device, torch.device('cpu')) - self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4])) - self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel()) - self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size()) - self.assertRaisesRegex(RuntimeError, - "Cannot access data pointer of Tensor that doesn't have storage", - lambda: mkldnn_tensor.data_ptr() != 0) + # float cpu tensor to mkldnn float tensor or bfloat tensor. + for dtype1 in types: + mkldnn_tensor = cpu_tensor.to_mkldnn(dtype1) + self.assertEqual(mkldnn_tensor.dtype, dtype1) + cpu_tensor_1 = mkldnn_tensor.to_dense() + # not given dtype for to_dense, mkldnn tensor has same dtype with cpu tensor + self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype) + # mkldnn float/bfloat tensor to cpu float or bfloat tensor + for dtype2 in types: + cpu_tensor_2 = mkldnn_tensor.to_dense(dtype2) + self.assertEqual(cpu_tensor_2.dtype, dtype2) + atol = 1e-5 if dtype1 == torch.float and dtype2 == torch.float else 1e-2 + self.assertEqual(cpu_tensor, cpu_tensor_2.float(), atol=atol, rtol=0) + + self.assertEqual(mkldnn_tensor.device, torch.device('cpu')) + self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4])) + self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel()) + if dtype1 == torch.float: + self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size()) + else: + self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size() / 2) + self.assertRaisesRegex(RuntimeError, + "Cannot access data pointer of Tensor that doesn't have storage", + lambda: mkldnn_tensor.data_ptr() != 0) + + # bfloat cpu tensor to mkldnn float tensor or bfloat tensor. + cpu_tensor_bf16 = cpu_tensor.bfloat16() + for dtype1 in types: + mkldnn_tensor = cpu_tensor_bf16.to_mkldnn(dtype1) + self.assertEqual(mkldnn_tensor.dtype, dtype1) + cpu_tensor_1 = mkldnn_tensor.to_dense() + # not given dtype for to_dense, mkldnn tensor has same dtype with cpu tensor + self.assertEqual(mkldnn_tensor.dtype, cpu_tensor_1.dtype) + # mkldnn float/bfloat tensor to cpu float or bfloat tensor + for dtype2 in types: + cpu_tensor_2 = mkldnn_tensor.to_dense(dtype2) + self.assertEqual(cpu_tensor_2.dtype, dtype2) + self.assertEqual(cpu_tensor_bf16, cpu_tensor_2.bfloat16(), atol=1e-5, rtol=0) + + self.assertEqual(mkldnn_tensor.device, torch.device('cpu')) + self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4])) + self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel()) + if dtype1 == torch.bfloat16: + self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor_bf16.element_size()) + else: + self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor_bf16.element_size() * 2) + self.assertRaisesRegex(RuntimeError, + "Cannot access data pointer of Tensor that doesn't have storage", + lambda: mkldnn_tensor.data_ptr() != 0) def test_unsupported(self): # unsupported types and unsupported types with gpu diff --git a/test/test_mobile_optimizer.py b/test/test_mobile_optimizer.py index eae6175fb0244..11ef019a26dbe 100644 --- a/test/test_mobile_optimizer.py +++ b/test/test_mobile_optimizer.py @@ -3,6 +3,7 @@ import torch.nn as nn import torch.backends.xnnpack import torch.utils.bundled_inputs +from torch.testing._internal.common_utils import TestCase, run_tests from torch.testing._internal.jit_utils import get_forward, get_forward_graph from torch.utils.mobile_optimizer import * from torch.nn import functional as F @@ -11,7 +12,7 @@ FileCheck = torch._C.FileCheck -class TestOptimizer(unittest.TestCase): +class TestOptimizer(TestCase): @unittest.skipUnless(torch.backends.xnnpack.enabled, " XNNPACK must be enabled for these tests." @@ -95,13 +96,13 @@ def forward(self, x): .check_count("prepacked::linear_clamp_run", 1, exactly=True) \ .check_not("aten::add(") \ .check_not("aten::relu(") \ - .check_count("aten::add_relu(", 1, exactly=True) \ + .check_count("aten::_add_relu(", 1, exactly=True) \ .run(optimized_scripted_model.graph) torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3) - optimization_blacklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS} - optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blacklist_no_prepack) + optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS} + optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blocklist_no_prepack) optimized_result_no_prepack = optimized_scripted_model_no_prepack(input_data) FileCheck().check_count("Tensor = aten::conv2d", 1, exactly=True) \ @@ -118,19 +119,36 @@ def forward(self, x): FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \ .run(str(get_forward(bn_scripted_module._c).graph)) - optimization_blacklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS} - bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blacklist_no_prepack) + optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS} + bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_prepack) self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1) bn_input = torch.rand(1, 1, 6, 6) torch.testing.assert_allclose(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3) - optimization_blacklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION} - no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blacklist_no_fold_bn) + optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION} + no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_fold_bn) FileCheck().check_count("aten::batch_norm", 1, exactly=True) \ .run(str(get_forward_graph(no_bn_fold_scripted_module._c))) bn_input = torch.rand(1, 1, 6, 6) torch.testing.assert_allclose(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3) + class MyMobileOptimizedTagTest(torch.nn.Module): + def __init__(self): + super(MyMobileOptimizedTagTest, self).__init__() + self.linear_weight = torch.nn.Parameter(torch.Tensor(torch.rand(linear_weight_shape))) + self.linear_bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim)))) + + def forward(self, x): + o = F.linear(x, self.linear_weight, self.linear_bias) + return F.relu(o) + + mobile_optimized_tag_module = MyMobileOptimizedTagTest() + m = torch.jit.script(mobile_optimized_tag_module) + m.eval() + opt_m = optimize_for_mobile(m) + tag = getattr(opt_m, "mobile_optimized", None) + self.assertTrue(tag) + class MyPreserveMethodsTest(torch.nn.Module): def __init__(self): super(MyPreserveMethodsTest, self).__init__() @@ -251,6 +269,69 @@ def get_lint_count_by_type(lint_type, module_lint_List): bi_module_lint_list = generate_mobile_module_lints(bi_module) self.assertEqual(len(bi_module_lint_list), 0) + def test_preserve_bundled_inputs_methods(self): + class MyBundledInputModule(torch.nn.Module): + def __init__(self): + super(MyBundledInputModule, self).__init__() + + def forward(self, inputs): + return inputs + + class MyIncompleteBundledInputModule(torch.nn.Module): + def __init__(self): + super(MyIncompleteBundledInputModule, self).__init__() + + def forward(self, inputs): + return inputs + + @torch.jit.export + def get_all_bundled_inputs(self): + pass + + bi_module = torch.jit.script(MyBundledInputModule()) + module_optim_bi_not_preserved = optimize_for_mobile(bi_module) + + # Expected to be False since no bundled inputs methods were added + self.assertFalse( + hasattr(module_optim_bi_not_preserved, 'get_all_bundled_inputs') or + hasattr(module_optim_bi_not_preserved, 'get_num_bundled_inputs') or + hasattr(module_optim_bi_not_preserved, 'run_on_bundled_input') + ) + + # We expect an exception here + with self.assertRaises(AttributeError): + module_optim_bi_not_preserved.run_on_bundled_input(0) + + # Add bundled inputs methods to the module + torch.utils.bundled_inputs.augment_model_with_bundled_inputs( + bi_module, [(torch.tensor([1]),)], []) + # Now they should be preserved + module_optim_bi_preserved = optimize_for_mobile(bi_module) + + # All of the bundled inputs methods were preserved + self.assertTrue( + hasattr(module_optim_bi_preserved, 'get_all_bundled_inputs') and + hasattr(module_optim_bi_preserved, 'get_num_bundled_inputs') and + hasattr(module_optim_bi_preserved, 'run_on_bundled_input') + ) + + # We do not expect an exception here + module_optim_bi_preserved.run_on_bundled_input(0) + + bundled_input = module_optim_bi_preserved.get_all_bundled_inputs()[0] + module_optim_bi_preserved(*bundled_input) + + # If not all 3 bundled inputs methods are present in the module, + # we will not try to preserve them unless specified by the user. + incomplete_bi_module = torch.jit.script(MyIncompleteBundledInputModule()) + incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module) + self.assertFalse(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs')) + + # Specifically preserve get_all_bundled_inputs even if it's the only one + # bundled inputs method available. + incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module, preserved_methods=['get_all_bundled_inputs']) + self.assertTrue(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs')) + @unittest.skipUnless(torch.backends.xnnpack.enabled, " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.") @@ -323,7 +404,7 @@ def _quant_script_and_optimize(model): m, m_optim = _quant_script_and_optimize(Standalone()) FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \ - .check_count("_jit_pass_hoist_conv_packed_params", 2, exactly=True) \ + .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \ .run(m_optim.graph) self.assertFalse(hasattr(m_optim, "conv1")) self.assertFalse(hasattr(m_optim, "conv2")) @@ -337,7 +418,7 @@ def _quant_script_and_optimize(model): m, m_optim = _quant_script_and_optimize(Parent()) FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \ - .check_count("_jit_pass_hoist_conv_packed_params", 2, exactly=True) \ + .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \ .run(m_optim.graph) self.assertFalse(hasattr(m_optim, "conv1")) self.assertFalse(hasattr(m_optim, "child")) @@ -349,4 +430,4 @@ def _quant_script_and_optimize(model): if __name__ == '__main__': - unittest.main() + run_tests() diff --git a/test/test_multiprocessing.py b/test/test_multiprocessing.py index 49e0a3cb45c00..81b33c5900dbc 100644 --- a/test/test_multiprocessing.py +++ b/test/test_multiprocessing.py @@ -3,7 +3,6 @@ import os import sys import time -import subprocess import unittest import copy from sys import platform @@ -368,8 +367,11 @@ def test_inherit_tensor(self): t = torch.zeros(5, 5) p = SubProcess(t.share_memory_()) p.start() - p.join(1) - self.assertEqual(t, torch.ones(5, 5) * 3, atol=0, rtol=0) + p.join(2) + if p.exitcode is None: + print("test_inherit_tensor: SubProcess too slow") + else: + self.assertEqual(t, torch.ones(5, 5) * 3, atol=0, rtol=0) @unittest.skipIf(IS_WINDOWS, "Test needs to use fork multiprocessing") def test_autograd_errors(self): @@ -522,7 +524,7 @@ def test_cuda_bad_call(self): @unittest.skipIf(IS_WINDOWS, 'not applicable to Windows (only fails with fork)') @unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available') def test_wrong_cuda_fork(self): - results = self.run_process_no_exception("""\ + stderr = TestCase.runWithPytorchAPIUsageStderr("""\ import torch from torch.multiprocessing import Process def run(rank): @@ -539,7 +541,7 @@ def run(rank): for p in processes: p.join() """) - self.assertRegex(results[1].decode('ascii'), "Cannot re-initialize CUDA in forked subprocess.") + self.assertRegex(stderr, "Cannot re-initialize CUDA in forked subprocess.") @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \ don't support multiprocessing with spawn start method") @@ -828,15 +830,6 @@ def test_cuda_parameter_sharing(self): param = Parameter(torch.arange(1., 26, device='cuda').view(5, 5)) self._test_autograd_sharing(param, mp.get_context('spawn'), is_parameter=True) - @staticmethod - def run_process_no_exception(code): - popen = subprocess.Popen( - [sys.executable, '-c', code], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - pipes = popen.communicate() - return pipes - @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \ don't support multiprocessing with spawn start method") def test_integer_parameter_serialization(self): diff --git a/test/test_multiprocessing_spawn.py b/test/test_multiprocessing_spawn.py index 3e995231c68c5..51b9a3598c631 100644 --- a/test/test_multiprocessing_spawn.py +++ b/test/test_multiprocessing_spawn.py @@ -54,6 +54,15 @@ def test_nested_child_body(i, ready_queue, nested_child_sleep): time.sleep(nested_child_sleep) +def test_infinite_task(i): + while True: + time.sleep(1) + + +def test_process_exit(idx): + sys.exit(12) + + def test_nested(i, pids_queue, nested_child_sleep, start_method): context = mp.get_context(start_method) nested_child_ready_queue = context.Queue() @@ -184,6 +193,23 @@ def test_nested(self): class SpawnTest(TestCase, _TestMultiProcessing): start_method = 'spawn' + def test_exception_raises(self): + with self.assertRaises(mp.ProcessRaisedException): + mp.spawn(test_success_first_then_exception_func, args=(), nprocs=1) + + def test_signal_raises(self): + context = mp.spawn(test_infinite_task, args=(), nprocs=1, join=False) + for pid in context.pids(): + os.kill(pid, signal.SIGTERM) + with self.assertRaises(mp.ProcessExitedException): + context.join() + + def test_process_exited(self): + with self.assertRaises(mp.ProcessExitedException) as e: + mp.spawn(test_process_exit, args=(), nprocs=1) + self.assertEqual(12, e.exit_code) + + @unittest.skipIf( IS_WINDOWS, "Fork is only available on Unix", diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py index 5358d2bbab104..07c66775c9487 100644 --- a/test/test_namedtensor.py +++ b/test/test_namedtensor.py @@ -146,16 +146,18 @@ def _test_factory(self, factory, device): names65 = ['A' * i for i in range(1, 66)] x = factory([1] * 65, names=names64, device=device) - def test_none_names_refcount(self): + def test_none_names_refcount(self, N=10): def scope(): unnamed = torch.empty(2, 3) unnamed.names # materialize [None, None] prev_none_refcnt = sys.getrefcount(None) - scope() - self.assertEqual(sys.getrefcount(None), prev_none_refcnt, - msg='Using tensor.names should not change ' - 'the refcount of Py_None') + # Ran it N times to reduce flakiness + [scope() for i in range(N)] + after_none_refcnt = sys.getrefcount(None) + self.assertTrue(after_none_refcnt - prev_none_refcnt < N / 2, + msg='Using tensor.names should not change ' + 'the refcount of Py_None') def test_has_names(self): unnamed = torch.empty(2, 3) @@ -1218,6 +1220,7 @@ def kthvalue_wrapper(tensor, *args, **kwargs): Case(torch.mode, False, False, True, True, values_and_indices), Case(kthvalue_wrapper, False, False, True, True, values_and_indices), Case(torch.median, True, False, True, True, values_and_indices), + Case(torch.nanmedian, True, False, True, True, values_and_indices), ] for testcase, device in itertools.product(tests, torch.testing.get_all_device_types()): diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py index 785ae4c4fb379..00432c9e71cd1 100644 --- a/test/test_namedtuple_return_api.py +++ b/test/test_namedtuple_return_api.py @@ -1,22 +1,24 @@ import os import re import yaml -import unittest import textwrap import torch + +from torch.testing._internal.common_utils import TestCase, run_tests from collections import namedtuple path = os.path.dirname(os.path.realpath(__file__)) aten_native_yaml = os.path.join(path, '../aten/src/ATen/native/native_functions.yaml') all_operators_with_namedtuple_return = { - 'max', 'min', 'median', 'mode', 'kthvalue', 'svd', 'symeig', 'eig', + 'max', 'min', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'eig', 'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk', 'lstsq', - 'triangular_solve', 'cummax', 'cummin' + 'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "unpack_dual", 'linalg_qr', + '_svd_helper', 'linalg_svd', 'linalg_slogdet', } -class TestNamedTupleAPI(unittest.TestCase): +class TestNamedTupleAPI(TestCase): def test_native_functions_yaml(self): operators_found = set() @@ -52,29 +54,52 @@ def test_namedtuple_return(self): op = namedtuple('op', ['operators', 'input', 'names', 'hasout']) operators = [ - op(operators=['max', 'min', 'median', 'mode', 'sort', 'topk', 'cummax', 'cummin'], input=(0,), + op(operators=['max', 'min', 'median', 'nanmedian', 'mode', 'sort', 'topk', 'cummax', 'cummin'], input=(0,), names=('values', 'indices'), hasout=True), op(operators=['kthvalue'], input=(1, 0), names=('values', 'indices'), hasout=True), - op(operators=['svd'], input=(), names=('U', 'S', 'V'), hasout=True), + op(operators=['svd', '_svd_helper', 'linalg_svd'], input=(), names=('U', 'S', 'V'), hasout=True), op(operators=['slogdet'], input=(), names=('sign', 'logabsdet'), hasout=False), - op(operators=['qr'], input=(), names=('Q', 'R'), hasout=True), + op(operators=['qr', 'linalg_qr'], input=(), names=('Q', 'R'), hasout=True), op(operators=['solve'], input=(a,), names=('solution', 'LU'), hasout=True), op(operators=['geqrf'], input=(), names=('a', 'tau'), hasout=True), op(operators=['symeig', 'eig'], input=(True,), names=('eigenvalues', 'eigenvectors'), hasout=True), op(operators=['triangular_solve'], input=(a,), names=('solution', 'cloned_coefficient'), hasout=True), op(operators=['lstsq'], input=(a,), names=('solution', 'QR'), hasout=True), + op(operators=['linalg_eigh'], input=("L",), names=('eigenvalues', 'eigenvectors'), hasout=True), + op(operators=['linalg_slogdet'], input=(), names=('sign', 'logabsdet'), hasout=True), + op(operators=['unpack_dual'], input=(0,), names=('primal', 'tangent'), hasout=False), ] + def get_func(f): + "Return either torch.f or torch.linalg.f, where 'f' is a string" + if f.startswith('linalg_'): + return getattr(torch.linalg, f[7:]) + return getattr(torch, f, None) + + def check_namedtuple(tup, names): + "Check that the namedtuple 'tup' has the given names" + for i, name in enumerate(names): + self.assertIs(getattr(tup, name), tup[i]) + for op in operators: for f in op.operators: - ret = getattr(a, f)(*op.input) - for i, name in enumerate(op.names): - self.assertIs(getattr(ret, name), ret[i]) - if op.hasout: - ret1 = getattr(torch, f)(a, *op.input, out=tuple(ret)) - for i, name in enumerate(op.names): - self.assertIs(getattr(ret, name), ret[i]) + # 1. check the namedtuple returned by calling torch.f + func = get_func(f) + if func: + ret1 = func(a, *op.input) + check_namedtuple(ret1, op.names) + # + # 2. check the out= variant, if it exists + if func and op.hasout: + ret2 = func(a, *op.input, out=tuple(ret1)) + check_namedtuple(ret2, op.names) + # + # 3. check the Tensor.f method, if it exists + meth = getattr(a, f, None) + if meth: + ret3 = meth(*op.input) + check_namedtuple(ret3, op.names) all_covered_operators = set([x for y in operators for x in y.operators]) @@ -85,4 +110,4 @@ def test_namedtuple_return(self): if __name__ == '__main__': - unittest.main() + run_tests() diff --git a/test/test_native_functions.py b/test/test_native_functions.py index 869c7aad47fbc..2bb5aafa832da 100644 --- a/test/test_native_functions.py +++ b/test/test_native_functions.py @@ -58,7 +58,7 @@ def fake_module(values, const): self.do_test_optional_floatlist_with_module(fake_module) def test_optional_floatlist_invalid(self): - with self.assertRaisesRegex(TypeError, "must be .* but found"): + with self.assertRaisesRegex(TypeError, "must be tuple of floats, not list"): FloatListWrapperModule()(torch.zeros(1), ["hi"]) with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"): @@ -176,6 +176,22 @@ def fake_module(values, const): self.do_test_optional_filled_intlist_with_module(fake_module) + def test_string_defaults(self): + dummy = torch.rand(1) + fn = torch._C._nn._test_string_default + fn(dummy) + + with self.assertRaisesRegex(RuntimeError, "A"): + fn(dummy, a="") + + with self.assertRaisesRegex(RuntimeError, "B"): + fn(dummy, b="") + + def f(x): + torch._C._nn._test_string_default(x) + scripted_fn = torch.jit.script(f) + scripted_fn(dummy) + if __name__ == '__main__': run_tests() diff --git a/test/test_nn.py b/test/test_nn.py index 4020eb0cf3081..63f4c02da8b92 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -32,19 +32,20 @@ from torch.autograd import gradcheck from torch.autograd.gradcheck import gradgradcheck from torch.nn import Parameter +from torch.nn.parameter import UninitializedParameter from torch.nn.parallel._functions import Broadcast from torch.testing import get_all_fp_dtypes from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \ TEST_NUMPY, TEST_SCIPY, TEST_WITH_ROCM, download_file, \ get_function_arglist, load_tests, repeat_test_for_types, ALL_TENSORTYPES, \ - ALL_TENSORTYPES2, TemporaryFileName, TEST_WITH_UBSAN, IS_PPC + ALL_TENSORTYPES2, suppress_warnings, TemporaryFileName, TEST_WITH_UBSAN, IS_PPC from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \ module_tests, criterion_tests, loss_reference_fns, \ ctcloss_reference, new_module_tests from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, \ dtypesIfCUDA, skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \ - skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, largeCUDATensorTest, onlyOnCPUAndCUDA, \ + skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, onlyOnCPUAndCUDA, \ deviceCountAtLeast, expectedAlertNondeterministic, largeTensorTest from torch.nn import MultiheadAttention @@ -53,6 +54,10 @@ from torch.testing._internal.common_utils import _assertGradAndGradgradChecks from torch.testing._internal.common_utils import dtype2prec_DONTUSE from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, tf32_off, tf32_on +from torch.types import _TensorOrTensors + + +AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32() # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings @@ -107,7 +112,7 @@ def _ordered_sequence(self, tensor_type): def _padded_sequence(self, tensor_type): """Create Tensor of random padded sequences""" ordered = self._ordered_sequence(tensor_type) - lengths = list(map(len, ordered)) + lengths = [len(i) for i in ordered] padded_tensor = rnn_utils.pad_sequence(ordered) return padded_tensor, lengths @@ -313,15 +318,19 @@ class TestNN(NNTestCase): _do_cuda_memory_leak_check = True _do_cuda_non_default_stream = True - def _forward(self, module, input): + def _forward(self, module, input: _TensorOrTensors): with freeze_rng_state(): - return module(input) + if isinstance(input, tuple): + return module(*input) + else: + return module(input) - def _backward(self, module, input, output, grad_output, create_graph=False): + def _backward(self, module, input: _TensorOrTensors, output, grad_output, create_graph=False): output.backward(grad_output, retain_graph=True, create_graph=create_graph) - if input.grad is None: - return None - return input.grad.data + if isinstance(input, tuple): + return tuple(i.grad.data if i.grad is not None else None for i in input) + else: + return input.grad.data if input.grad is not None else None def _forward_criterion(self, criterion, input, target, extra_args=None): if extra_args is None: @@ -333,19 +342,20 @@ def _forward_criterion(self, criterion, input, target, extra_args=None): output = criterion(input, target, *extra_args) return output - def _backward_criterion(self, criterion, input, target, gradOutput=None, extra_args=None): + def _backward_criterion(self, criterion, input, output, target, gradOutput=None, extra_args=None): if extra_args is None: extra_args = tuple() input_tuple = input if isinstance(input, tuple) else (input,) + output_tuple = output if isinstance(output, tuple) else (output,) for i in input_tuple: if i.grad is not None: i.grad.data.zero_() args = input_tuple + (target,) + extra_args if gradOutput is None: gradOutput = torch.ones(()) - criterion(*args).backward(gradOutput.to(input_tuple[0])) + criterion(*args).backward(gradOutput.to(output_tuple[0])) if isinstance(input, tuple): - return tuple(map(lambda i: i.grad.data, input)) + return tuple(i.grad.data for i in input) else: return input.grad.data @@ -446,7 +456,7 @@ def forward(self, inp): for b in net.buffers(): self.assertTrue(b.storage().is_shared()) - def test_hooks(self): + def _test_hooks(self, backward_register_fn): module = nn.Sigmoid() input = torch.ones(5, 5, requires_grad=True) @@ -477,7 +487,7 @@ def bw_hook(inc, h_module, grad_input, grad_output): self.assertEqual(counter['forwards'], 2) self.assertEqual(counter['backwards'], 0) - test_bwd = module.register_backward_hook( + test_bwd = getattr(module, backward_register_fn)( lambda *args: bw_hook(1, *args)) output = module(input) @@ -498,7 +508,7 @@ def bw_hook(inc, h_module, grad_input, grad_output): self.assertEqual(counter['forwards'], 6) self.assertEqual(counter['backwards'], 2) - test2_bwd = module.register_backward_hook(lambda *args: bw_hook(2, *args)) + test2_bwd = getattr(module, backward_register_fn)(lambda *args: bw_hook(2, *args)) module(input).backward(torch.ones(5, 5) * 2) self.assertEqual(counter['forwards'], 9) @@ -519,17 +529,19 @@ def bw_hook(inc, h_module, grad_input, grad_output): test_fwd.remove() test_bwd.remove() + def test_hooks(self): + self._test_hooks("register_backward_hook") + self._test_hooks("register_full_backward_hook") + def test_hook_cpp(self): - counter = [0] bn = nn.BatchNorm1d(5) def hook(module, grad_inputs, grad_outputs): - counter[0] += 1 - self.assertEqual(len(grad_inputs), 3) + self.assertEqual(len(grad_inputs), 1) self.assertEqual(len(grad_outputs), 1) self.assertEqual(module, bn) - bn.register_backward_hook(hook) + bn.register_full_backward_hook(hook) output = bn(torch.randn(5, 5, requires_grad=True)) output.sum().backward() @@ -551,6 +563,165 @@ def bw_fail2(self, grad_input, grad_output): with self.assertRaisesRegex(RuntimeError, 'got 2, but expected 1'): module(input).sum().backward() + def test_hook_requires_grad(self): + test_self = self + + class MyModule(nn.Module): + def forward(self, arg1, arg2, arg3): + test_self.assertTrue(arg1.requires_grad) + test_self.assertFalse(arg2.requires_grad) + test_self.assertTrue(arg3.requires_grad) + return arg1.sum() + arg2.sum() + arg3.sum() + + inp = torch.rand(2, requires_grad=True) + mod = MyModule() + + mod(inp, inp.detach(), inp) + # Ensure that requires grad is properly propagated + mod.register_full_backward_hook(lambda mod, gI, gO: None) + mod(inp, inp.detach(), inp) + + def test_hook_extra_input(self): + class MyModule(nn.Module): + def forward(self, non_tensor, tensor): + return tensor.clone(), non_tensor + + inp = torch.rand(2, requires_grad=True) + mod = MyModule() + + def hook(mod, grad_input, grad_output): + self.assertIsNone(grad_input[0]) + self.assertIsInstance(grad_input[1], torch.Tensor) + + self.assertIsInstance(grad_output[0], torch.Tensor) + self.assertIsNone(grad_output[1]) + + mod.register_full_backward_hook(hook) + out, _ = mod(True, inp) + out.sum().backward() + + def test_hook_inplace(self): + class MyModule(nn.Module): + def forward(self, inp, do_inplace): + self.inp = inp + if do_inplace: + inp += 1 + return inp.clone() + + hook_called = [0] + + def hook(mod, grad_input, grad_output): + hook_called[0] += 1 + + inp = torch.rand(10, requires_grad=True) + mod = MyModule() + mod.register_full_backward_hook(hook) + + # No inplace should work + mod(inp, False).sum().backward() + self.assertEqual(hook_called[0], 1) + + # Input inplace error should throw an error (warning during deprecation cycle) + with self.assertWarnsRegex(UserWarning, "Output 0 of BackwardHookFunctionBackward is " + "a view and is being modified inplace."): + mod(inp.clone(), True) + + # Input inplace error should throw an error if we try to re-use the view after they have + # been modified (warning during deprecation cycle) + local_inp = inp.clone() + out = mod(local_inp, False) + local_inp[0] *= 1 + with self.assertWarnsRegex(UserWarning, "Output 0 of BackwardHookFunctionBackward is " + "a view and its base or another view"): + # Any operation involving the view will fail here + mod.inp + 2 + + # Output inplace error should throw an error (warning during deprecation cycle) + with self.assertWarnsRegex(UserWarning, "BackwardHookFunctionBackward is a view " + "and is being modified inplace."): + # This error won't happen once the warning above is a proper error + with self.assertRaisesRegex(RuntimeError, "Module backward hook for grad_input is " + "called before the grad_output one."): + out = mod(inp, False) + out += 1 + out.sum().backward() + + def test_hook_non_full_warning(self): + def noop(*args): + pass + + a = torch.rand(2, requires_grad=True) + b = torch.rand(2, requires_grad=True) + + # Check invalid input container + class MyModule(nn.Module): + def forward(self, l): + return l[0].clone(), l[1].clone() + + m = MyModule() + m.register_backward_hook(noop) + + with self.assertWarnsRegex(UserWarning, "does not take as input a single Tensor or a tuple of Tensors"): + m([a, b]) + + # Check invalid output container + class MyModule(nn.Module): + def forward(self, a, b): + return [a.clone(), b.clone()] + + m = MyModule() + m.register_backward_hook(noop) + + with self.assertWarnsRegex(UserWarning, "does not return a single Tensor or a tuple of Tensors"): + m(a, b) + + # Check invalid output from different Nodes + class MyModule(nn.Module): + def forward(self, a, b): + return a.clone(), b.clone() + + m = MyModule() + m.register_backward_hook(noop) + + with self.assertWarnsRegex(UserWarning, "outputs are generated by different autograd Nodes"): + m(a, b) + + # Check invalid forward with multiple Nodes + class MyModule(nn.Module): + def forward(self, a): + return a.clone().clone() + + m = MyModule() + m.register_backward_hook(noop) + + with self.assertWarnsRegex(UserWarning, "the forward contains multiple autograd Nodes"): + m(a) + + def test_hook_backward_size(self): + # Make module with multiple operations in forward + # And different size for input and outputs + class MyModule(nn.Module): + def forward(self, arg1, arg2): + tmp = arg1.sum() * arg2 + tmp = tmp + arg2.sum() * arg1.sum() + tmp = tmp.sum().view(1) + tmp = tmp.expand(8).contiguous() + return tmp + + module = MyModule() + inp1 = torch.randn(5, 5, requires_grad=True) + inp2 = torch.randn(10, 10, requires_grad=True) + + def bw_hook(module, grad_input, grad_output): + self.assertEqual(len(grad_input), 2) + self.assertEqual(grad_input[0].size(), torch.Size([5, 5])) + self.assertEqual(grad_input[1].size(), torch.Size([10, 10])) + self.assertEqual(len(grad_output), 1) + self.assertEqual(grad_output[0].size(), torch.Size([8])) + + with module.register_full_backward_hook(bw_hook): + module(inp1, inp2).sum().backward() + def test_hook_backward_writeable(self): module = nn.Sigmoid() input = torch.randn(5, 5, requires_grad=True) @@ -1041,32 +1212,19 @@ def test_add_module_raises_error_if_attr_exists(self): with self.assertRaises(KeyError): m.add_module('attribute_name', nn.Module()) + @unittest.expectedFailure def test_getattr_with_property(self): class Model(nn.Module): - def __init__(self): - super(Model, self).__init__() - self.linear = nn.Linear(4, 5) - - def forward(self, input): - return self.linear(input) - @property def some_property(self): return self.something_that_doesnt_exist model = Model() - with self.assertRaises(nn.modules.module.ModuleAttributeError) as mae: - check = model.shouldnt_exist - self.assertIn("shouldnt_exist", mae) - - # Before using nn.modules.ModuleAttributeError, if an AttributeError - # was raised in a property. The AttributeError was raised on the - # property itself. This checks that some_property is not in the - # expection. - with self.assertRaises(nn.modules.module.ModuleAttributeError) as mae: - check = model.some_property - self.assertIn("something_that_doesnt_exist", mae) - self.assertNotIn("some_propery", mae) + + with self.assertRaisesRegex( + AttributeError, + r"'Model' object has no attribute 'something_that_doesnt_exist'"): + model.some_property def test_Sequential_getitem(self): l1 = nn.Linear(10, 20) @@ -1544,7 +1702,8 @@ def test_overwrite_module_params_on_conversion(self): m = nn.Linear(20, 10).float() mw = m.weight[:] m.double() - mw[0][0] = 5 + with torch.no_grad(): + mw[0][0] = 5 self.assertTrue(mw[0][0].dtype == torch.float) self.assertTrue(mw._base[0][0].dtype == torch.double) @@ -1557,7 +1716,8 @@ def test_overwrite_module_params_on_conversion(self): m = nn.Linear(20, 10).float() mw = m.weight[:] m.double() - mw[0][0] = 5 + with torch.no_grad(): + mw[0][0] = 5 self.assertTrue(mw[0][0] == mw._base[0][0]) # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`, @@ -2233,7 +2393,6 @@ def test_pruning_container_compute_mask(self): # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType(expected_mask, computed_mask) - def test_l1_unstructured_pruning(self): r"""Test that l1 unstructured pruning actually removes the lowest entries by l1 norm (by hand). It also checks that applying l1 @@ -2258,6 +2417,35 @@ def test_l1_unstructured_pruning(self): # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType(expected_weight, m.weight) + def test_l1_unstructured_pruning_with_importance_scores(self): + r"""Test that l1 unstructured pruning actually removes the lowest + entries of importance scores and not the parameter by l1 norm (by hand). + It also checks that applying l1 unstructured pruning more than once + respects the previous mask. + """ + m = nn.Linear(4, 2) + # modify its weight matrix by hand + m.weight = torch.nn.Parameter( + torch.tensor( + [[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32 + ) + ) + importance_scores = torch.tensor( + [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32 + ) + + prune.l1_unstructured(m, 'weight', amount=2, importance_scores=importance_scores) + expected_weight = torch.tensor([[1, 2, 0, 4], [-4, 0, -2, -1]]) + # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 + self.assertEqualIgnoreType(expected_weight, m.weight) + + # check that pruning again removes two entries of m.weight that are colocated with + # the next two smallest absolute values of importance scores. + prune.l1_unstructured(m, 'weight', amount=2, importance_scores=importance_scores) + expected_weight = torch.tensor([[1, 0, 0, 4], [-4, 0, 0, -1]]) + # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 + self.assertEqualIgnoreType(expected_weight, m.weight) + def test_unstructured_pruning_same_magnitude(self): r"""Since it may happen that the tensor to prune has entries with the same exact magnitude, it is important to check that pruning happens @@ -2275,7 +2463,6 @@ def test_unstructured_pruning_same_magnitude(self): self.assertEqual(nparams_toprune, nparams_pruned) def test_random_structured_pruning_amount(self): - AMOUNT = 0.6 AXIS = 2 p = prune.RandomStructured(amount=AMOUNT, dim=AXIS) @@ -2292,7 +2479,6 @@ def test_random_structured_pruning_amount(self): ) assert per_column_sums == [0, 20] - def test_ln_structured_pruning(self): r"""Check Ln structured pruning by hand. """ @@ -2316,6 +2502,33 @@ def test_ln_structured_pruning(self): prune.ln_structured(m, 'weight', amount=1, n=1, dim=-1) self.assertEqual(expected_mask_axis3, m.weight_mask) + def test_ln_structured_pruning_importance_scores(self): + r"""Check Ln structured pruning by hand. + """ + m = nn.Conv2d(3, 1, 2) + m.weight.data = torch.Tensor( + [[[[1., 2.], [1., 2.5]], + [[0.5, 1.], [0.1, 0.1]], + [[-3., -5.], [0.1, -1.]]]] + ) + importance_scores = torch.Tensor( + [[[[10., 1.], [10., 1.]], + [[30., 3.], [30., 3.]], + [[-20., -2.], [-20., -2.]]]] + ) + # expected effect of pruning 1 of the 3 channels by L2-norm + expected_mask_axis1 = torch.ones_like(m.weight) + expected_mask_axis1[:, 0] = 0. + + prune.ln_structured(m, 'weight', amount=1, n=2, dim=1, importance_scores=importance_scores) + self.assertEqual(expected_mask_axis1, m.weight_mask) + + # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm + expected_mask_axis3 = expected_mask_axis1 + expected_mask_axis3[:, :, :, 1] = 0. + + prune.ln_structured(m, 'weight', amount=1, n=1, dim=-1, importance_scores=importance_scores) + self.assertEqual(expected_mask_axis3, m.weight_mask) def test_remove_pruning(self): r"""`prune.remove` removes the hook and the reparametrization @@ -2395,6 +2608,49 @@ def test_global_pruning(self): expected_nweight = torch.tensor([[0, 0, -2]]).to(dtype=n.weight.dtype) self.assertEqual(expected_nweight, n.weight) + def test_global_pruning_importance_scores(self): + r"""Test that global l1 unstructured pruning over 2 parameters removes + the `amount=4` smallest global weights across the 2 parameters. + """ + m = nn.Linear(4, 2) + n = nn.Linear(3, 1) + # modify the weight matrices by hand + m.weight = torch.nn.Parameter( + torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to( + dtype=torch.float32) + ) + m_importance_scores = torch.tensor( + [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32 + ) + n.weight = torch.nn.Parameter( + torch.tensor([[0, 0.1, -2]]).to( + dtype=torch.float32) + ) + n_importance_scores = torch.tensor([[0, 10., -0.2]]).to(dtype=torch.float32) + + params_to_prune = ( + (m, 'weight'), + (n, 'weight'), + ) + importance_scores = { + (m, 'weight'): m_importance_scores, + (n, 'weight'): n_importance_scores, + } + + # prune the 4 smallest weights globally by L1 magnitude + prune.global_unstructured( + params_to_prune, + pruning_method=prune.L1Unstructured, + amount=4, + importance_scores=importance_scores, + ) + + expected_m_weight = torch.tensor([[1, 2, 0, 4], [-4, 0, -2, -1]]) + # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 + self.assertEqualIgnoreType(expected_m_weight, m.weight) + + expected_n_weight = torch.tensor([[0, 0.1, 0]]).to(dtype=n.weight.dtype) + self.assertEqual(expected_n_weight, n.weight) def test_custom_from_mask_pruning(self): r"""Test that the CustomFromMask is capable of receiving @@ -2484,7 +2740,6 @@ def test_pruning_serialization_model(self): self.assertEqual(pruned_weight, new_model[0].weight) - def test_pruning_serialization_state_dict(self): # create a model model = torch.nn.Sequential( @@ -2535,7 +2790,6 @@ def test_pruning_serialization_state_dict(self): self.assertEqual(pruned_weight, new_model[0].weight) - def test_prune(self): # create a new pruning method p = prune.L1Unstructured(amount=2) @@ -2549,6 +2803,37 @@ def test_prune(self): pruned_tensor = p.prune(t, default_mask) self.assertEqual(t * expected_mask, pruned_tensor) + def test_prune_importance_scores(self): + # create a new pruning method + p = prune.L1Unstructured(amount=2) + # create tensor to be pruned + t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) + importance_scores = torch.tensor( + [[1, 2, 3, 4], [1.5, 1.6, 1.7, 1.8]] + ).to(dtype=torch.float32) + # create prior mask by hand + default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) + # since we are pruning the two lowest magnitude units, the outcome of + # the calculation should be this: + expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]]) + pruned_tensor = p.prune(t, default_mask, importance_scores=importance_scores) + self.assertEqual(t * expected_mask, pruned_tensor) + + def test_prune_importance_scores_mimic_default(self): + # create a new pruning method + p = prune.L1Unstructured(amount=2) + # create tensor to be pruned + t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) + # create prior mask by hand + default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) + # since we are pruning the two lowest magnitude units, the outcome of + # the calculation should be this: + expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]]) + pruned_tensor_without_importance_scores = p.prune(t, default_mask) + pruned_tensor_with_importance_scores = p.prune(t, default_mask, importance_scores=t) + self.assertEqual(pruned_tensor_without_importance_scores, pruned_tensor_with_importance_scores) + self.assertEqual(t * expected_mask, pruned_tensor_without_importance_scores) + def test_rnn_pruning(self): l = torch.nn.LSTM(32, 32) # This Module has 4 parameters called: @@ -2582,37 +2867,41 @@ def test_rnn_pruning(self): def test_rnn_weight_norm(self): - l = torch.nn.LSTM(32, 32) - # This Module has 4 parameters called: - # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0' + def check_weight_norm(l, name, num_params): + # This Module has 4 or 5 parameters called: + # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0', weight_hr_l0 - # Applying weight norm on one of them causes it to become a tensor - l = torch.nn.utils.weight_norm(l, name='weight_ih_l0') - assert ( - sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]) - == 3 - ) + # Applying weight norm on one of them causes it to become a tensor + l = torch.nn.utils.weight_norm(l, name=name) + self.assertEqual( + sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]), + num_params - 1, + ) - # Removing the weight norm reparametrization restores the Parameter - l = torch.nn.utils.remove_weight_norm(l, name='weight_ih_l0') - assert ( - sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]) - == 4 - ) + # Removing the weight norm reparametrization restores the Parameter + l = torch.nn.utils.remove_weight_norm(l, name=name) + self.assertEqual( + sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]), + num_params, + ) + + # Make sure that, upon removal of the reparametrization, the + # `._parameters` and `.named_parameters` contain the right params. + # Specifically, the original weight ('weight_ih_l0') should be placed + # back in the parameters, while the reparametrization components + # ('weight_ih_l0_v' and 'weight_ih_l0_g') should be removed. + self.assertTrue(name in l._parameters) + self.assertIsNotNone(l._parameters[name]) + self.assertTrue(name + '_v' not in l._parameters) + self.assertTrue(name + '_g' not in l._parameters) + self.assertTrue(name in dict(l.named_parameters())) + self.assertIsNotNone(dict(l.named_parameters())[name]) + self.assertTrue(name + '_v' not in dict(l.named_parameters())) + self.assertTrue(name + '_g' not in dict(l.named_parameters())) + + check_weight_norm(torch.nn.LSTM(32, 32), 'weight_ih_l0', 4) + check_weight_norm(torch.nn.LSTM(32, 32, proj_size=16), 'weight_hr_l0', 5) - # Make sure that, upon removal of the reparametrization, the - # `._parameters` and `.named_parameters` contain the right params. - # Specifically, the original weight ('weight_ih_l0') should be placed - # back in the parameters, while the reparametrization components - # ('weight_ih_l0_v' and 'weight_ih_l0_g') should be removed. - assert 'weight_ih_l0' in l._parameters - assert l._parameters['weight_ih_l0'] is not None - assert 'weight_ih_l0_v' not in l._parameters - assert 'weight_ih_l0_g' not in l._parameters - assert 'weight_ih_l0' in dict(l.named_parameters()) - assert dict(l.named_parameters())['weight_ih_l0'] is not None - assert 'weight_ih_l0_v' not in dict(l.named_parameters()) - assert 'weight_ih_l0_g' not in dict(l.named_parameters()) def test_weight_norm(self): input = torch.randn(3, 5) @@ -2648,18 +2937,69 @@ def test_weight_norm(self): m = torch.nn.utils.weight_norm(m) def test_parameterlistdict_setting_attributes(self): - mod = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)])) + with warnings.catch_warnings(record=True) as w: + mod = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)])) + self.assertTrue(len(w) == 0) + + with warnings.catch_warnings(record=True) as w: + mod.train() + mod.eval() + self.assertTrue(len(w) == 0) with self.assertWarnsRegex(UserWarning, r"Setting attributes on ParameterList is not supported"): torch.nn.utils.weight_norm(mod, "0") - mod = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))}) + with warnings.catch_warnings(record=True) as w: + mod = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))}) + self.assertTrue(len(w) == 0) + + with warnings.catch_warnings(record=True) as w: + mod.train() + mod.eval() + self.assertTrue(len(w) == 0) with self.assertWarnsRegex(UserWarning, r"Setting attributes on ParameterDict is not supported"): torch.nn.utils.weight_norm(mod, "b") + def test_parameterlistdict_pickle(self): + m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)])) + with warnings.catch_warnings(record=True) as w: + m = pickle.loads(pickle.dumps(m)) + self.assertTrue(len(w) == 0) + + m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)])) + del m._initialized + with warnings.catch_warnings(record=True) as w: + m = pickle.loads(pickle.dumps(m)) + self.assertTrue(len(w) == 0) + + # Test whether loading from older checkpoints works without triggering warnings + m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)])) + del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set + with warnings.catch_warnings(record=True) as w: + m = pickle.loads(pickle.dumps(m)) + self.assertTrue(len(w) == 0) + + m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))}) + with warnings.catch_warnings(record=True) as w: + m = pickle.loads(pickle.dumps(m)) + self.assertTrue(len(w) == 0) + + m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))}) + del m._initialized + with warnings.catch_warnings(record=True) as w: + m = pickle.loads(pickle.dumps(m)) + self.assertTrue(len(w) == 0) + + # Test whether loading from older checkpoints works without triggering warnings + m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))}) + del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set + with warnings.catch_warnings(record=True) as w: + m = pickle.loads(pickle.dumps(m)) + self.assertTrue(len(w) == 0) + def test_weight_norm_pickle(self): m = torch.nn.utils.weight_norm(nn.Linear(5, 7)) m = pickle.loads(pickle.dumps(m)) @@ -2946,6 +3286,23 @@ def test_threshold_int(self): expected = torch.tensor([99, 99, 99, 99, 1, 2, 3]) self.assertEqual(F.threshold(x, 0, 99), expected) + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + def test_embedding_max_norm_unsorted_repeating_indices(self): + def create_embedding(device): + # Seed RNG so we get the same Embedding each time + torch.manual_seed(0) + return torch.nn.Embedding( + num_embeddings=20, + embedding_dim=64, + max_norm=1.0).to(device) + + ix = torch.arange(2, device='cpu', dtype=torch.long).repeat(2000) + out_cpu = create_embedding('cpu')(ix) + + ix = ix.to('cuda') + out = create_embedding('cuda')(ix) + self.assertEqual(out.cpu(), out_cpu) + def test_embedding_sparse_basic(self): embedding = nn.Embedding(10, 20, sparse=True) input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long) @@ -2997,6 +3354,12 @@ def test_embedding_from_pretrained(self): output = embedding(input) self.assertEqual(a, output) + def test_embedding_from_pretrained_padding_idx(self): + padding_idx = 2 + embeddings = torch.rand(4, 3, requires_grad=True) + embedding_nn = nn.Embedding.from_pretrained(embeddings, padding_idx=padding_idx) + self.assertEqual(embedding_nn.weight[padding_idx].sum(), 0) + def test_embedding_from_pretrained_options(self): a = torch.Tensor([[1, 2, 3], [4, 5, 6]]) opts = { @@ -3027,6 +3390,13 @@ def test_embedding_functional(self): res_F = F.embedding(a, embeddings) self.assertEqual(res_old, res_F) + embed_old = torch.nn.Embedding(4, 3) + embed_old = embed_old.from_pretrained(embeddings, padding_idx=2) + res_old = embed_old(a) + res_F = F.embedding(a, embeddings, padding_idx=2) + + self.assertEqual(res_old, res_F) + @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines, 'Linear_FP16_weight requires FBGEMM. FBGEMM is only optimized for CPUs' ' with instruction set support avx2 or newer.') @@ -3495,51 +3865,60 @@ def test_adaptive_pooling_size_none(self): output = module(input) self.assertEqual(output.size(), (4,) + (2,) * (numel - 1) + (4,)) - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_adaptive_pooling_avg_nhwc(self): - input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32, device="cuda") - input = input.contiguous(memory_format=torch.channels_last).requires_grad_() - grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32, device="cuda") - pool = torch.nn.AdaptiveAvgPool2d((7, 7)).cuda() + device_list = ['cpu'] + if TEST_CUDA: + device_list.append('cuda') - ref_input = input.detach().clone().contiguous().requires_grad_(True) - ref_grad = grad.detach().clone().contiguous() - ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).cuda() + for device in device_list: + input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32).to(device) + input = input.contiguous(memory_format=torch.channels_last).requires_grad_() + grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32).to(device) + pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device) - out = pool(input) - out.backward(grad) - ref_out = ref_pool(ref_input) - ref_out.backward(ref_grad) + ref_input = input.detach().clone().contiguous().requires_grad_(True) + ref_grad = grad.detach().clone().contiguous() + ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device) - self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) - self.assertTrue(ref_out.is_contiguous()) - self.assertEqual(out, ref_out) - self.assertEqual(input.grad, ref_input.grad) + out = pool(input) + out.backward(grad) + ref_out = ref_pool(ref_input) + ref_out.backward(ref_grad) + + self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(ref_out.is_contiguous()) + self.assertEqual(out, ref_out) + self.assertEqual(input.grad, ref_input.grad) - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_adaptive_pooling_avg_nhwc_non_contiguous(self): - input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32, device="cuda") - input = input.contiguous(memory_format=torch.channels_last) - input = input[:, ::2, :, :].requires_grad_() - grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32, device="cuda") - grad = grad[:, ::2, :, :] - pool = torch.nn.AdaptiveAvgPool2d((7, 7)).cuda() + device_list = ['cpu'] + if TEST_CUDA: + device_list.append('cuda') - ref_input = input.detach().clone().contiguous().requires_grad_(True) - ref_grad = grad.detach().clone().contiguous() - ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).cuda() + for device in device_list: + input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32).to(device) + input = input.contiguous(memory_format=torch.channels_last) + input = input[:, ::2, :, :].requires_grad_() + grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32).to(device) + grad = grad[:, ::2, :, :] + pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device) - out = pool(input) - out.backward(grad) - ref_out = ref_pool(ref_input) - ref_out.backward(ref_grad) + ref_input = input.detach().clone().contiguous().requires_grad_(True) + ref_grad = grad.detach().clone().contiguous() + ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device) - self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) - self.assertTrue(ref_out.is_contiguous()) - self.assertEqual(out, ref_out) - self.assertEqual(input.grad, ref_input.grad) + out = pool(input) + out.backward(grad) + ref_out = ref_pool(ref_input) + ref_out.backward(ref_grad) + + self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(ref_out.is_contiguous()) + self.assertEqual(out, ref_out) + self.assertEqual(input.grad, ref_input.grad) - @largeCUDATensorTest('12GB') + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + @largeTensorTest('12GB', device='cuda') def test_adaptive_pooling_avg_nhwc_launch_config_backward(self): input = torch.randint(1, 10, (1, 32, 2 ** 17 + 1, 32), dtype=torch.float32, device="cuda") input = input.contiguous(memory_format=torch.channels_last).requires_grad_() @@ -3561,7 +3940,8 @@ def test_adaptive_pooling_avg_nhwc_launch_config_backward(self): self.assertEqual(out, ref_out) self.assertEqual(input.grad, ref_input.grad) - @largeCUDATensorTest('12GB') + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + @largeTensorTest('12GB', device='cuda') def test_adaptive_pooling_avg_nhwc_launch_config_forward(self): input = torch.randint(1, 10, (1, 32, 16, 16), dtype=torch.float32, device="cuda") input = input.contiguous(memory_format=torch.channels_last).requires_grad_() @@ -3639,7 +4019,7 @@ def test_state_dict(self): self.assertIn('bn.running_var', state_dict) self.assertIn('bn.running_mean', state_dict) self.assertIn('bn.num_batches_tracked', state_dict) - self.assertFalse(any(map(lambda k: k.startswith('empty'), state_dict.keys()))) + self.assertFalse(any(k.startswith('empty') for k in state_dict.keys())) for k, v in state_dict.items(): param = net for component in k.split('.'): @@ -3790,8 +4170,9 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, # use sequential to verify nesting m = nn.Sequential(CustomState()) - m[0].param[0] = 10 - m[0].sub.weight[0, 0] = 555 + with torch.no_grad(): + m[0].param[0] = 10 + m[0].sub.weight[0, 0] = 555 state_dict = m.state_dict() self.assertEqual(state_dict["0.serialized"].item(), 11) self.assertIn("0.sub.weight", state_dict) @@ -3924,6 +4305,15 @@ def test_Conv2d_inconsistent_types_on_GPU_without_cudnn(self): # but it should work with the same type nn.functional.conv2d(inputs.float(), weights.float(), bias.float()) + def test_Conv2d_1x1(self): + in_channels = 2 + out_channels = 2 + mod = torch.nn.Conv2d(2, 2, 1, bias=False).to(dtype=torch.double) + input = torch.randn(1, in_channels, 5, 5, requires_grad=True, dtype=torch.double) + for enabled in (False, True): + with torch.backends.mkldnn.flags(enabled=enabled): + gradcheck(F.conv2d, (input, mod.weight)) + @unittest.skipIf(not TEST_CUDA, 'CUDA not available') @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') def test_cudnn_non_contiguous(self): @@ -3953,7 +4343,7 @@ def test_Conv2d_inconsistent_types_on_GPU_with_cudnn(self): @unittest.skipIf(not TEST_CUDA, 'CUDA not available') @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') - @repeat_test_for_types(ALL_TENSORTYPES2) + @repeat_test_for_types(get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)) def test_Conv2d_deterministic_cudnn(self, dtype=torch.float): inputs = torch.randn(2, 3, 5, 5, device="cuda", dtype=dtype, requires_grad=True) with cudnn.flags(enabled=True, benchmark=True, deterministic=True): @@ -3983,7 +4373,7 @@ def test_Conv2d_backward_twice(self): lambda: o1.sum().backward()) @unittest.skipIf(not TEST_CUDA, 'CUDA not available') - @repeat_test_for_types(ALL_TENSORTYPES2) + @repeat_test_for_types(get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)) def test_Conv2d_large_workspace(self, dtype=torch.float): # These sizes require huge cuDNN workspaces. Make sure we choose a # reasonable algorithm that does not run out of memory @@ -4104,13 +4494,29 @@ def test_ConvTranspose2d_half_cublas_gemm(self): output = deconv(inputs) output.mean().backward() + @unittest.skipIf(not TEST_CUDA, 'CUDA not available') + @repeat_test_for_types([torch.half, torch.float]) + def test_ConvTranspose2d_large_output_padding(self, dtype=torch.half): + net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\ + .to(device='cuda', dtype=dtype) + net2 = torch.nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)\ + .to(device='cuda', dtype=dtype) + net3 = torch.nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=1)\ + .to(device='cuda', dtype=dtype) + x = torch.rand(1, 128, 6, 6, device='cuda', dtype=dtype, requires_grad=True) + x = net1(x) + x = net2(x) + x = net3(x) + x.backward(torch.randn_like(x)) + torch.cuda.synchronize() + # For https://github.com/pytorch/pytorch/pull/1273 # Almost identical to the above `test_Conv2d_naive_groups` def test_Conv2d_groups_nobias(self): dev_dtypes = [("cpu", torch.float)] if TEST_CUDA: dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)] - if TEST_WITH_ROCM: + if AMPERE_OR_ROCM: dev_dtypes += [("cuda", torch.bfloat16)] for device, dtype in dev_dtypes: m = nn.Conv2d(4, 4, kernel_size=3, groups=2, bias=False).to(device, dtype) @@ -4148,7 +4554,7 @@ def test_Conv2d_groups_nobias_v2(self): dev_dtypes = [("cpu", torch.float)] if TEST_CUDA: dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)] - if TEST_WITH_ROCM: + if AMPERE_OR_ROCM: dev_dtypes += [("cuda", torch.bfloat16)] for device, dtype in dev_dtypes: m = nn.Conv2d(4, 16, kernel_size=3, groups=2, bias=False).to(device, dtype) @@ -5662,6 +6068,7 @@ def test_cudnn_rnn_dropout_states_device(self): def test_cudnn_weight_format(self): rnns = [ nn.LSTM(10, 20, batch_first=True), + nn.LSTM(10, 20, batch_first=True, proj_size=10), nn.GRU(10, 20, batch_first=True), nn.RNN(10, 20, batch_first=True) ] @@ -5672,6 +6079,10 @@ def test_cudnn_weight_format(self): hx = torch.randn(1, 5, 20, requires_grad=True, device="cuda") all_vars = [input, hx] + list(rnn.parameters()) if isinstance(rnn, nn.LSTM): + # LSTM with projections has different hx size + if rnn.proj_size > 0: + hx = torch.randn(1, 5, 10, requires_grad=True, device="cuda") + all_vars[1] = hx cx = torch.randn(1, 5, 20, requires_grad=True, device="cuda") all_vars[2:2] = [cx] hx = (hx, cx) @@ -5711,6 +6122,7 @@ def test_cudnn_weight_format(self): def test_cudnn_weight_tying(self): rnns = [ nn.LSTM(10, 20, batch_first=True, bidirectional=True), + nn.LSTM(10, 20, batch_first=True, bidirectional=True, proj_size=10), nn.GRU(10, 20, batch_first=True, bidirectional=True), nn.RNN(10, 20, batch_first=True, bidirectional=True) ] @@ -5723,6 +6135,10 @@ def test_cudnn_weight_tying(self): opt = torch.optim.SGD(rnn.parameters(), lr=0.1) opt.zero_grad() if isinstance(rnn, nn.LSTM): + # LSTM with projections has different hx size + if rnn.proj_size > 0: + hx = torch.randn(2, 5, 10, requires_grad=True, device="cuda") + all_vars[1] = hx cx = torch.randn(2, 5, 20, requires_grad=True, device="cuda") all_vars[2:2] = [cx] hx = (hx, cx) @@ -5983,6 +6399,82 @@ def get_inputs(input_shape, hidden_shape, mode): hidden_shape = update_shape(correct_hidden_shape, 0, bad_size) test(input_shape, hidden_shape, mode) + def test_projections_lstm_args_check(self): + input_size = 3 + hidden_size = 5 + proj_size = 2 + num_layers = 2 + batch_size = 4 + seq_len = 6 + num_directions = 1 + bad_size = 7 # prime number so that no size can divide it. + + def test(input_shape, hidden_h_shape, hidden_c_shape): + for input, hidden in get_inputs(input_shape, hidden_h_shape, hidden_c_shape): + model = nn.LSTM(input_size, hidden_size, num_layers, proj_size=proj_size) + self.assertRaises(RuntimeError, lambda: model(input, hidden)) + + correct_input_shape = (seq_len, batch_size, input_size) + correct_hidden_h_shape = (num_layers * num_directions, batch_size, proj_size) + correct_hidden_c_shape = (num_layers * num_directions, batch_size, hidden_size) + + def update_shape(shape, dim, new_dim_size): + new_shape = list(shape) + new_shape[dim] = new_dim_size + return tuple(new_shape) + + def get_inputs(input_shape, hidden_h_shape, hidden_c_shape): + '''returns list( tuple(input, hidden) ) + where input, hidden are inputs to a model''' + input = torch.randn(input_shape) + hidden_h = torch.randn(hidden_h_shape) + hidden_c = torch.randn(hidden_c_shape) + return [(input, (hidden_h, hidden_c))] + + # Incorrect input batch size + input_shape = update_shape(correct_input_shape, 1, bad_size) + test(input_shape, correct_hidden_h_shape, correct_hidden_c_shape) + + # Incorrect hidden batch size + input_shape = correct_input_shape + hidden_h_shape = update_shape(correct_hidden_h_shape, 1, bad_size) + hidden_c_shape = update_shape(correct_hidden_c_shape, 1, bad_size) + test(input_shape, hidden_h_shape, hidden_c_shape) + + # Incorrect input size + input_shape = update_shape(correct_input_shape, 2, bad_size) + test(input_shape, correct_hidden_h_shape, correct_hidden_c_shape) + + # Incorrect hidden size + input_shape = correct_input_shape + hidden_h_shape = update_shape(correct_hidden_h_shape, 2, bad_size) + hidden_c_shape = update_shape(correct_hidden_c_shape, 2, bad_size) + test(input_shape, hidden_h_shape, hidden_c_shape) + + # Incorrect hidden[0] + input_shape = correct_input_shape + hidden_h_shape = update_shape(correct_hidden_h_shape, 0, bad_size) + hidden_c_shape = update_shape(correct_hidden_c_shape, 0, bad_size) + test(input_shape, hidden_h_shape, hidden_c_shape) + + # Incorrect proj size = hidden size + input_shape = correct_input_shape + hidden_h_shape = update_shape(correct_hidden_h_shape, 0, hidden_size) + hidden_c_shape = correct_hidden_c_shape + test(input_shape, hidden_h_shape, hidden_c_shape) + + # Incorrect proj size != hidden size + input_shape = correct_input_shape + hidden_h_shape = update_shape(correct_hidden_h_shape, 0, bad_size) + hidden_c_shape = correct_hidden_c_shape + test(input_shape, hidden_h_shape, hidden_c_shape) + + # Incorrect cell size != hidden size + input_shape = correct_input_shape + hidden_h_shape = correct_hidden_h_shape + hidden_c_shape = update_shape(correct_hidden_c_shape, 0, bad_size) + test(input_shape, hidden_h_shape, hidden_c_shape) + @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") def test_rnn_check_device(self): input_size = 3 @@ -6020,6 +6512,40 @@ def test_rnn_check_device(self): "Input and hidden tensors are not at the same device"): model(input.to('cuda:0'), (hidden.to('cuda:0'), hidden.to('cuda:1'))) + @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") + def test_projections_lstm_check_device(self): + input_size = 3 + hidden_size = 5 + proj_size = 2 + num_layers = 2 + batch_size = 4 + seq_len = 6 + num_directions = 1 + + correct_input_shape = (seq_len, batch_size, input_size) + correct_hidden_h_shape = (num_layers * num_directions, batch_size, proj_size) + correct_hidden_c_shape = (num_layers * num_directions, batch_size, hidden_size) + + model = nn.LSTM(input_size, hidden_size, num_layers, proj_size=proj_size) + input = torch.randn(correct_input_shape) + hidden_h = torch.randn(correct_hidden_h_shape) + hidden_c = torch.randn(correct_hidden_c_shape) + + # input and weights are not at the same device + with self.assertRaisesRegex(RuntimeError, + "Input and parameter tensors are not at the same device"): + model(input.to('cuda:0')) + + # input and hiddens are not at the same device + with self.assertRaisesRegex(RuntimeError, + r"Input and hidden tensors are not at the same device"): + model(input, (hidden_h.to('cuda:0'), hidden_c.to('cuda:0'))) + + # hidden tensors are not at the same CUDA device + with self.assertRaisesRegex(RuntimeError, + "Input and hidden tensors are not at the same device"): + model(input.to('cuda:0'), (hidden_h.to('cuda:0'), hidden_c.to('cuda:1'))) + def test_rnn_initial_hidden_state(self): rnn_modes = ['RNN', 'GRU', 'LSTM'] for mode in rnn_modes: @@ -6034,9 +6560,29 @@ def test_rnn_initial_hidden_state(self): self.assertEqual(output1, output2) self.assertEqual(hidden1, hidden2) + def test_projections_lstm_initial_hidden_state(self): + for bidir in [False, True]: + rnn = nn.LSTM(30, 20, 2, bidirectional=bidir, proj_size=10) + num_dirs = 2 if bidir else 1 + input = torch.randn(10, 32, 30) + hidden_h = torch.zeros(2 * num_dirs, 32, 10) + hidden_c = torch.zeros(2 * num_dirs, 32, 20) + hidden = (hidden_h, hidden_c) + output1, hidden1 = rnn(input, hidden) + output2, hidden2 = rnn(input) + self.assertEqual(output1, output2) + self.assertEqual(hidden1, hidden2) + + def test_projections_errors_on_gru_and_rnn(self): + error_msg = "proj_size argument is only supported for LSTM, not RNN or GRU" + for mode in ['RNN', 'GRU']: + with self.assertRaisesRegex(ValueError, error_msg): + rnn = getattr(nn, mode)(30, 20, 2, proj_size=10) + def _test_RNN_cpu_vs_cudnn(self, dropout, dtype=torch.double): - def forward_backward(cuda, rnn, input_val, hx_val, grad_output, grad_hy, weights_val): + def forward_backward(cuda, rnn, input_val, grad_output, weights_val, hx_val, grad_hy, + cx_val=None, grad_cy=None): is_lstm = isinstance(rnn, nn.LSTM) for x_layer, y_layer in zip(rnn.all_weights, weights_val): @@ -6051,8 +6597,12 @@ def forward_backward(cuda, rnn, input_val, hx_val, grad_output, grad_hy, weights input = input_val.clone().requires_grad_(True) input_var = input if is_lstm: - hx = (hx_val.clone().requires_grad_(True), - hx_val.add(1).requires_grad_(True)) + if cx_val is None: + hx = (hx_val.clone().requires_grad_(True), + hx_val.add(1).requires_grad_(True)) + else: + hx = (hx_val.clone().requires_grad_(True), + cx_val.add(1).requires_grad_(True)) else: hx = hx_val.clone().requires_grad_(True) @@ -6065,6 +6615,8 @@ def forward_backward(cuda, rnn, input_val, hx_val, grad_output, grad_hy, weights else: hx.data = hx.data.cuda() grad_hy = grad_hy.cuda() + if grad_cy is not None: + grad_cy = grad_cy.cuda() grad_output = grad_output.cuda() output, hy = rnn(input, hx) @@ -6073,7 +6625,10 @@ def forward_backward(cuda, rnn, input_val, hx_val, grad_output, grad_hy, weights output = output.data if is_lstm: - torch.autograd.backward([output, hy[0], hy[1]], [grad_output, grad_hy, grad_hy + 1]) + if grad_cy is None: + torch.autograd.backward([output, hy[0], hy[1]], [grad_output, grad_hy, grad_hy + 1]) + else: + torch.autograd.backward([output, hy[0], hy[1]], [grad_output, grad_hy, grad_cy + 1]) else: torch.autograd.backward([output, hy], [grad_output, grad_hy]) @@ -6087,6 +6642,7 @@ def forward_backward(cuda, rnn, input_val, hx_val, grad_output, grad_hy, weights input_size = 10 hidden_size = 6 + proj_size = 3 num_layers = 2 seq_length = 7 batch = 6 @@ -6118,15 +6674,15 @@ def compare_cpu_gpu(outputs_cpu, outputs_gpu): input_val = torch.randn(seq_length, batch, input_size, dtype=dtype) grad_output = torch.randn(seq_length, batch, hidden_size * num_directions, dtype=dtype) + hx_val = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype) + grad_hy = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype) + if not contig: grad_output = make_noncontig(grad_output) grad_hy = make_noncontig(grad_hy) input_var = make_noncontig(input_val) hx_val = make_noncontig(hx_val) - hx_val = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype) - grad_hy = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype) - if variable_len: lengths = [7, 5, 5, 2, 1, 1] if lens_as_tensor: @@ -6143,7 +6699,7 @@ def compare_cpu_gpu(outputs_cpu, outputs_gpu): batch_first=batch_first).to(dtype) outputs_cpu = forward_backward( - False, rnn, input_val, hx_val, grad_output, grad_hy, rnn.all_weights) + False, rnn, input_val, grad_output, rnn.all_weights, hx_val, grad_hy) rnn_gpu = module(input_size, hidden_size, @@ -6154,7 +6710,7 @@ def compare_cpu_gpu(outputs_cpu, outputs_gpu): batch_first=batch_first).to(dtype) outputs_gpu = forward_backward( - True, rnn_gpu, input_val, hx_val, grad_output, grad_hy, rnn.all_weights) + True, rnn_gpu, input_val, grad_output, rnn.all_weights, hx_val, grad_hy) compare_cpu_gpu(outputs_cpu, outputs_gpu) @@ -6167,13 +6723,78 @@ def compare_cpu_gpu(outputs_cpu, outputs_gpu): num_layers * num_directions, batch, hidden_size, dtype=dtype) rnn = nn.RNN(input_size, hidden_size, num_layers, bias=bias, nonlinearity=nonlinearity).to(dtype) - outputs_cpu = forward_backward(False, rnn, input_val, hx_val, grad_output, grad_hy, rnn.all_weights) + outputs_cpu = forward_backward(False, rnn, input_val, grad_output, rnn.all_weights, hx_val, grad_hy) rnn_gpu = nn.RNN(input_size, hidden_size, num_layers, bias=bias, nonlinearity=nonlinearity).to(dtype) - outputs_gpu = forward_backward(True, rnn_gpu, input_val, hx_val, grad_output, grad_hy, rnn.all_weights) + outputs_gpu = forward_backward(True, rnn_gpu, input_val, grad_output, rnn.all_weights, hx_val, grad_hy) compare_cpu_gpu(outputs_cpu, outputs_gpu) + # checking LSTM with projections + for bias, bidirectional, batch_first, contig, variable_len, lens_as_tensor \ + in product((True, False), repeat=6): + num_directions = 2 if bidirectional else 1 + if batch_first: + input_val = torch.randn(batch, seq_length, input_size, dtype=dtype) + grad_output = torch.randn(batch, seq_length, proj_size * num_directions, dtype=dtype) + else: + input_val = torch.randn(seq_length, batch, input_size, dtype=dtype) + grad_output = torch.randn(seq_length, batch, proj_size * num_directions, dtype=dtype) + + hx_val = torch.randn(num_layers * num_directions, batch, proj_size, dtype=dtype) + cx_val = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype) + grad_hy = torch.randn(num_layers * num_directions, batch, proj_size, dtype=dtype) + grad_cy = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype) + + if not contig: + grad_output = make_noncontig(grad_output) + grad_hy = make_noncontig(grad_hy) + grad_cy = make_noncontig(grad_cy) + input_var = make_noncontig(input_val) + hx_val = make_noncontig(hx_val) + cx_val = make_noncontig(cx_val) + + if variable_len: + lengths = [7, 5, 5, 2, 1, 1] + if lens_as_tensor: + lengths = torch.tensor(lengths, dtype=torch.long) + input_val = rnn_utils.pack_padded_sequence(input_val, lengths, batch_first=batch_first) + grad_output = rnn_utils.pack_padded_sequence(grad_output, lengths, batch_first=batch_first).data + + rnn = nn.LSTM(input_size, + hidden_size, + num_layers, + bias=bias, + dropout=dropout, + bidirectional=bidirectional, + batch_first=batch_first, + proj_size=proj_size).to(dtype) + + outputs_cpu = forward_backward( + False, rnn, input_val, grad_output, rnn.all_weights, + hx_val, grad_hy, cx_val, grad_cy) + + rnn_gpu = nn.LSTM(input_size, + hidden_size, + num_layers, + bias=bias, + dropout=dropout, + bidirectional=bidirectional, + batch_first=batch_first, + proj_size=proj_size).to(dtype) + # LSTM with projections is not supported with MIOpen + if TEST_WITH_ROCM and dtype == torch.float: + with self.assertRaisesRegex(RuntimeError, + "LSTM with projections is not supported with MIOpen"): + outputs_gpu = forward_backward( + True, rnn_gpu, input_val, grad_output, rnn.all_weights, + hx_val, grad_hy, cx_val, grad_cy) + else: + outputs_gpu = forward_backward( + True, rnn_gpu, input_val, grad_output, rnn.all_weights, + hx_val, grad_hy, cx_val, grad_cy) + compare_cpu_gpu(outputs_cpu, outputs_gpu) + @unittest.skipIf(not TEST_CUDNN, "needs cudnn") def test_RNN_cpu_vs_cudnn_no_dropout(self): if TEST_WITH_ROCM: @@ -6196,25 +6817,27 @@ def test_RNN_cudnn_weight_norm(self): batch = 6 # runs on CPU to acquire expected output - m = nn.LSTM(input_size, hidden_size, num_layers) - input = torch.randn(seq_length, batch, input_size) - expected_output = m(input) + def check_weight_norm(m, name): + input = torch.randn(seq_length, batch, input_size) + expected_output = m(input) - # adds weight normalization - name = 'weight_hh_l0' - m = torch.nn.utils.weight_norm(m, name=name) + # adds weight normalization + m = torch.nn.utils.weight_norm(m, name=name) - # moves to CUDA - m = m.cuda() - input = input.cuda() + # moves to CUDA + m = m.cuda() + input = input.cuda() - # otherwise, subsequent warnings will be hidden, and further tests rely on them - warnings.simplefilter("always") - self.assertEqual(m(input), expected_output) + # otherwise, subsequent warnings will be hidden, and further tests rely on them + warnings.simplefilter("always") + self.assertEqual(m(input), expected_output) - # remove weight norm - m = torch.nn.utils.remove_weight_norm(m, name=name) - self.assertEqual(m(input), expected_output) + # remove weight norm + m = torch.nn.utils.remove_weight_norm(m, name=name) + self.assertEqual(m(input), expected_output) + + check_weight_norm(nn.LSTM(input_size, hidden_size, num_layers), 'weight_hh_l0') + check_weight_norm(nn.LSTM(input_size, hidden_size, num_layers, proj_size=3), 'weight_hr_l0') @unittest.skipIf(not TEST_CUDA, 'CUDA not available') def test_partial_flat_weights(self): @@ -6359,16 +6982,6 @@ def test_RNN_change_dropout(self): self.assertNotEqual(output2.data, prev_output) prev_output = output1.data - def _verify_pixel_shuffle(self, input, output, upscale_factor): - for c in range(output.size(1)): - for h in range(output.size(2)): - for w in range(output.size(3)): - height_idx = h // upscale_factor - weight_idx = w // upscale_factor - channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \ - (c * upscale_factor ** 2) - self.assertEqual(output[:, c, h, w], input[:, channel_idx, height_idx, weight_idx]) - def test_inplace_thnn(self): modules = [nn.ReLU, nn.ELU, nn.SELU, nn.CELU, nn.RReLU] for mod in modules: @@ -6381,7 +6994,7 @@ def test_inplace_thnn(self): self.assertEqual(grad_output, grad_output_clone) @unittest.skipIf(not TEST_CUDA, 'CUDA not available') - @repeat_test_for_types(ALL_TENSORTYPES2) + @repeat_test_for_types(get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)) def test_noncontig_conv_grad_cuda(self, dtype=torch.float): # FIXME: remove after adding non-contiguous grad tests for all modules module = nn.Conv2d(3, 5, kernel_size=3, padding=1).to("cuda", dtype) @@ -6398,19 +7011,105 @@ def test_noncontig_conv_grad_cuda(self, dtype=torch.float): output.backward(grad.contiguous()) self.assertEqual(result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0) - def test_pixel_shuffle(self): - batch_size = random.randint(1, 3) - upscale_factor = random.randint(2, 5) - channels = random.randint(1, 4) * upscale_factor ** 2 - height = random.randint(5, 10) - width = random.randint(5, 10) + def test_pixel_shuffle_unshuffle(self): + def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True, + upscale_factor=None): + # Function to imperatively ensure pixels are shuffled to the correct locations. + # Used to validate the batch operations in pixel_shuffle. + def _verify_pixel_shuffle(input, output, upscale_factor): + for c in range(output.size(-3)): + for h in range(output.size(-2)): + for w in range(output.size(-1)): + height_idx = h // upscale_factor + weight_idx = w // upscale_factor + channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \ + (c * upscale_factor ** 2) + self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, weight_idx]) + + upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor + # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2. + channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1) + height = random.randint(5, 10) + width = random.randint(5, 10) + + if num_input_dims == 1: + input = torch.rand(channels, requires_grad=True) + elif num_input_dims == 2: + input = torch.rand(height, width, requires_grad=True) + else: + batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)] + input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True) + ps = nn.PixelShuffle(upscale_factor) + pus = nn.PixelUnshuffle(downscale_factor=upscale_factor) + + if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0: + output = ps(input) + _verify_pixel_shuffle(input, output, upscale_factor) + output.backward(output.data) + self.assertEqual(input.data, input.grad.data) + + # Ensure unshuffle properly inverts shuffle. + unshuffle_output = pus(output) + self.assertEqual(input, unshuffle_output) + else: + self.assertRaises(RuntimeError, lambda: ps(input)) + + def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True, + downscale_factor=None): + downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor + channels = random.randint(1, 4) + # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor. + height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1) + # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor. + width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1) + + if num_input_dims == 1: + input = torch.rand(channels, requires_grad=True) + elif num_input_dims == 2: + input = torch.rand(height, width, requires_grad=True) + else: + batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)] + input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True) + + pus = nn.PixelUnshuffle(downscale_factor) + self.assertRaises(RuntimeError, lambda: pus(input)) + + def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims): + # For 1D - 2D, this is an error case. + # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle. + _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims) + + # Error cases for pixel_shuffle. + _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, valid_channels_dim=False) + _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=0) + _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=-2) + + # Error cases for pixel_unshuffle. + _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False) + _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False) + _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0) + _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2) + + def test_pixel_shuffle_unshuffle_1D(): + _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1) - input = torch.rand(batch_size, channels, height, width, requires_grad=True) - ps = nn.PixelShuffle(upscale_factor) - output = ps(input) - self._verify_pixel_shuffle(input.data, output.data, upscale_factor) - output.backward(output.data) - self.assertEqual(input.data, input.grad.data) + def test_pixel_shuffle_unshuffle_2D(): + _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2) + + def test_pixel_shuffle_unshuffle_3D(): + _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3) + + def test_pixel_shuffle_unshuffle_4D(): + _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4) + + def test_pixel_shuffle_unshuffle_5D(): + _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5) + + test_pixel_shuffle_unshuffle_1D() + test_pixel_shuffle_unshuffle_2D() + test_pixel_shuffle_unshuffle_3D() + test_pixel_shuffle_unshuffle_4D() + test_pixel_shuffle_unshuffle_5D() def test_elu_inplace_view(self): v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True) @@ -6449,27 +7148,31 @@ def test_PReLU_backward_requires_grad_false(self): @unittest.skipIf( not TEST_NUMPY or not TEST_SCIPY, "Numpy or Scipy not found") def test_gelu(self): - def _test_gelu(n, m, dtype, contiguous): + def _test_gelu(n, m, dtype, contiguous, atol=None, rtol=None): + numpy_dtype = { + torch.bfloat16: torch.float, torch.float: torch.float, torch.double: torch.double + }[dtype] + devices = ['cpu'] if dtype != torch.bfloat16 else [] + \ + ['cuda'] if TEST_CUDA else [] + def _gelu_ref(X): return X * stats.norm.cdf(X) - if contiguous: - X = torch.rand(n, m, dtype=dtype, requires_grad=True) - else: - X = torch.rand(n, m, dtype=dtype, requires_grad=True)[:, ::2] - res = F.gelu(X) - ref = _gelu_ref(X.detach().numpy()) - self.assertEqual(res, ref) - gradcheck(F.gelu, [X], eps=1e-4) - - if TEST_CUDA: - X_cuda = X.cuda() - res_cuda = F.gelu(X_cuda) - self.assertEqual(res_cuda.cpu(), ref) - gradcheck(F.gelu, [X_cuda], eps=1e-4) + for d in devices: + if contiguous: + X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d) + else: + X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)[:, ::2] + res = F.gelu(X) + ref = _gelu_ref(X.to(numpy_dtype).cpu().detach().numpy()) + self.assertEqual(res, ref, rtol=rtol, atol=atol) + if dtype != torch.bfloat16: + gradcheck(F.gelu, [X], eps=1e-4) for n in range(1, 10): for m in range(1, 10): + _test_gelu(n, m, torch.bfloat16, True, 1e-2, 0) + _test_gelu(n, m, torch.bfloat16, False, 1e-2, 0) _test_gelu(n, m, torch.float32, True) _test_gelu(n, m, torch.float32, False) _test_gelu(n, m, torch.float64, True) @@ -6698,31 +7401,39 @@ def test_hardtanh_backward(self): @unittest.skipIf(not TEST_CUDNN, "needs cudnn") @skipIfRocm def test_batchnorm_cudnn_nhwc(self): - input = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda", requires_grad=True) - input = input.contiguous(memory_format=torch.channels_last) - input.retain_grad() - grad = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda") - grad = grad.contiguous(memory_format=torch.channels_last) - bn = nn.BatchNorm2d(8).cuda().float() - bn.weight.data.uniform_() - bn.bias.data.uniform_() - - ref_input = input.detach().clone().contiguous().requires_grad_(True) - ref_grad = grad.detach().clone().contiguous() - ref_bn = nn.BatchNorm2d(8).cuda().float() - ref_bn.load_state_dict(bn.state_dict()) + def run_test(input, grad_output): + c = input.size(1) + mod = nn.BatchNorm2d(c).cuda().float() + mod.weight.data.uniform_() + mod.bias.data.uniform_() + ref_input = input.detach().clone().contiguous().requires_grad_(True) + ref_grad = grad.detach().clone().contiguous() + ref_mod = nn.BatchNorm2d(c).cuda().float() + ref_mod.load_state_dict(mod.state_dict()) + out = mod(input) + out.backward(grad_output) + ref_out = ref_mod(ref_input) + ref_out.backward(ref_grad) + self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(ref_out.is_contiguous()) + self.assertEqual(out, ref_out) + self.assertEqual(mod.weight.grad, ref_mod.weight.grad) + self.assertEqual(mod.bias.grad, ref_mod.bias.grad) + self.assertEqual(input.grad, ref_input.grad) - out = bn(input) - out.backward(grad) - ref_out = ref_bn(ref_input) - ref_out.backward(ref_grad) + input = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda") + input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_() - self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) - self.assertTrue(ref_out.is_contiguous()) - self.assertEqual(out, ref_out) - self.assertEqual(bn.weight.grad, ref_bn.weight.grad) - self.assertEqual(bn.bias.grad, ref_bn.bias.grad) - self.assertEqual(input.grad, ref_input.grad) + grad = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda") + grad = grad.contiguous(memory_format=torch.channels_last) + run_test(input, grad) + # see #42588, grad is channels_last contiguous, but grad.suggest_memory_format (rightly) return "contiguous" + # not channels_last + input = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda") + input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_() + grad = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda") + grad = grad.permute(0, 2, 1, 3) + run_test(input, grad) @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_batchnorm_cudnn_half(self): @@ -7018,11 +7729,16 @@ def test_pointwise_loss_broadcast(self): # https://github.com/pytorch/pytorch/issues/27692 reports # that l1_loss get a wrong result for big batch size def test_l1_loss_correct(self): - for N in range(1, 50, 10): - input = torch.rand(N, 3, 1024, 1024) - self.assertEqual( - torch.nn.L1Loss()(input, torch.zeros_like(input)), - input.abs().mean()) + for dtype in [torch.float, torch.cfloat]: + for N in range(1, 50, 10): + input = torch.rand(N, 3, 1024, 1024, dtype=dtype) + self.assertEqual( + torch.nn.L1Loss()(input, torch.zeros_like(input)), + input.abs().mean()) + + def test_smoothl1loss_negative_beta_not_supported(self): + with self.assertRaises(RuntimeError): + F.smooth_l1_loss(torch.randn(2, 2), torch.randn(2, 2), beta=-1.0) def test_cosine_similarity(self): input1 = torch.randn(4, 4, requires_grad=True) @@ -7090,6 +7806,9 @@ def test_grid_sample_error_checking(self): with self.assertRaisesRegex(RuntimeError, "expected input to have non-empty spatial dimensions"): F.grid_sample(torch.empty(1, 1, 0, 2), grid, align_corners=False) + with self.assertRaisesRegex(RuntimeError, "bicubic interpolation only supports 4D input"): + F.grid_sample(torch.empty(1, 1, 2, 2, 2), torch.empty(1, 1, 1, 1, 3), mode='bicubic') + if TEST_CUDA: with self.assertRaisesRegex(RuntimeError, "expected input and grid to be on same device"): F.grid_sample(input.cuda(), grid, align_corners=False) @@ -7222,8 +7941,8 @@ def get_grid(device='cpu', data=None): self.assertEqual(out_fallback, out_cpu.float(), atol=1e-5, rtol=5e-5) out_fallback.backward(gradients.float()) - self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-5, rtol=5e-5) - self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-5, rtol=5e-5) + self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-4, rtol=5e-5) + self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-4, rtol=5e-5) if TEST_CUDA: input_cuda = input_cpu.detach().transpose(0, 1).cuda().transpose(0, 1).requires_grad_() @@ -7301,7 +8020,7 @@ def get_grid(device='cpu', data=None): W = random.randint(3, IW + 2) test_shape(0, C, IH, IW, H, W, mode, padding_mode, align_corners) - for mode in ('bilinear', 'nearest'): + for mode in ('bilinear', 'nearest', 'bicubic'): for padding_mode in ('zeros', 'border', 'reflection'): for align_corners in (True, False): # test known input on CPU @@ -7369,6 +8088,37 @@ def get_grid(device='cpu', data=None): [1., 8., 5., 8., 9.]]).view(1, 1, 2, 5) else: raise AssertionError("missing groundtruth test for padding mode '{}'".format(padding_mode)) + elif mode == 'bicubic': + if padding_mode == 'zeros': + if align_corners: + groundtruth = torch.tensor( + [[-0.10424726, 7.1400003, 5.0000, 5.7842274, 9.0000], + [2.4492188, 7.4814040, 5.0000, 6.0277520, 0.0000]]).view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[0.00000, 7.6287503, 1.0625, 5.5977230, 5.3270264], + [0.40625, 8.0288770, 1.0625, 5.9375067, -0.3515625]]).view(1, 1, 2, 5) + elif padding_mode == 'border': + if align_corners: + groundtruth = torch.tensor( + [[1.1520010, 6.0599990, 5.0000, 4.870930, 9.0000000], + [2.1328125, 6.4258375, 5.0000, 5.076003, 8.8671875]]).view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[0.894531, 6.6050020, 4.625, 4.7138715, 9.800781], + [0.906250, 7.2822485, 4.625, 5.0000052, 10.00000]]).view(1, 1, 2, 5) + elif padding_mode == 'reflection': + if align_corners: + groundtruth = torch.tensor( + [[3.1822524, 6.239998, 5.0000, 4.8709273, 9.00000], + [1.7812500, 6.703594, 5.0000, 5.0760007, 8.21875]]).view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[2.7993753, 6.6050020, 4.25, 4.7138715, 10.269531], + [0.8125000, 7.2822485, 4.25, 5.0000052, 9.332031]]).view(1, 1, 2, 5) + else: + raise AssertionError("missing groundtruth test for padding mode '{}'".format(padding_mode)) + else: raise AssertionError("missing groundtruth test for interpolation mode '{}'".format(mode)) output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode, @@ -7424,11 +8174,42 @@ def get_grid(device='cpu', data=None): groundtruth = torch.tensor( [[[[-0., -0.], [-0., 0.], [-0., -0.], [-0., 0.]], [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]]).view(1, 2, 4, 2) + elif mode == 'bicubic': + if padding_mode == 'zeros': + if align_corners: + groundtruth = torch.tensor( + [[[[-4.5, -6.], [-4.5, 6.], [2.725679, 0.740878], [2.725679, -0.740878]], + [[1.5, 0.], [1.5, 0.], [1.927921, -0.05688], [1.927921, 0.05688]]]]).view(1, 2, 4, 2) + else: + groundtruth = torch.tensor( + [[[[-5.859375, -5.888672], [-5.859375, 5.888672], [-5.6250, -7.5000], [-5.6250, 7.5000]], + [[-0.234375, -0.263672], [-0.234375, 0.263672], [1.8750, 0.], [1.8750, 0.]]]] + ).view(1, 2, 4, 2) + elif padding_mode == 'border': + if align_corners: + groundtruth = torch.tensor( + [[[[1.5, 0.], [1.5, 0.], [1.74, 0.], [1.74, 0.]], + [[1.5, 0.], [1.5, 0.], [1.74, 0.], [1.74, 0.]]]]).view(1, 2, 4, 2) + else: + groundtruth = torch.tensor( + [[[[-0.46875, 0.], [-0.46875, 0.], [1.8750, 0.], [1.8750, 0.]], + [[-0.46875, 0.], [-0.46875, 0.], [1.8750, 0.], [1.8750, 0.]]]]).view(1, 2, 4, 2) + elif padding_mode == 'reflection': + if align_corners: + groundtruth = torch.tensor( + [[[[0., 0.], [0., 0.], [1.92, 0.], [1.92, 0.]], + [[0., 0.], [0., 0.], [1.92, 0.], [1.92, 0.]]]]).view(1, 2, 4, 2) + else: + groundtruth = torch.tensor( + [[[[0., 0.], [0., 0.], [1.875, 0.], [1.875, 0.]], + [[0., 0.], [0., 0.], [1.875, 0.], [1.875, 0.]]]]).view(1, 2, 4, 2) + else: + raise AssertionError("missing gradient groundtruth test for padding mode '{}'".format(padding_mode)) else: raise AssertionError("missing gradient groundtruth test for interpolation mode '{}'".format(mode)) F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners).sum().backward() - self.assertEqual(grid.grad, groundtruth, + self.assertEqual(grid.grad, groundtruth, atol=1e-5, rtol=0, msg="gradient groundtruth comparison failed for mode={}, " "padding_mode={}".format(mode, padding_mode)) @@ -7439,7 +8220,7 @@ def get_grid(device='cpu', data=None): F.GRID_SAMPLE_INTERPOLATION_MODES[mode], F.GRID_SAMPLE_PADDING_MODES[padding_mode], align_corners).sum().backward() - self.assertEqual(grid.grad, groundtruth) + self.assertEqual(grid.grad, groundtruth, atol=1e-5, rtol=0) # do gradcheck N = random.randint(2, 8) @@ -8625,18 +9406,19 @@ def test_flatten(self): def test_unflatten(self): tensor_input = torch.randn(2, 50) - # Unflatten Tensor + # Unflatten Tensor (unflattened_size as a tuple of ints and list of ints) - unflatten = nn.Unflatten(dim=1, unflattened_size=(2, 5, 5)) - tensor_output = unflatten(tensor_input) - self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5])) + for us in ((2, 5, 5), [2, 5, 5]): + unflatten = nn.Unflatten(dim=1, unflattened_size=us) + tensor_output = unflatten(tensor_input) + self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5])) # Unflatten NamedTensor unflatten = nn.Unflatten(dim='features', unflattened_size=(('C', 2), ('H', 5), ('W', 5))) named_tensor_input = tensor_input.refine_names('N', 'features') named_tensor_output = unflatten(named_tensor_input) - self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5])) + self.assertEqual(named_tensor_output.size(), torch.Size([2, 2, 5, 5])) def test_unflatten_invalid_arg(self): # Wrong type for unflattened_size (tuple of floats) @@ -8646,6 +9428,13 @@ def test_unflatten_invalid_arg(self): r"unflattened_size must be tuple of ints, but found element of type float at pos 2"): nn.Unflatten(dim=1, unflattened_size=(2, 5, 5.0)) + # Wrong type for unflattened_size (list of lists and list of tuples) + for us in ([['C', 2], ['W', 5], ['H', 5]], [('C', 2), ('W', 5), ('H', 5)]): + with self.assertRaisesRegex( + TypeError, + r"unflattened_size must be a tuple of tuples, but found type list"): + nn.Unflatten(dim='features', unflattened_size=us) + # Wrong type for unflattened_size (tuple of lists) with self.assertRaisesRegex( @@ -8653,19 +9442,12 @@ def test_unflatten_invalid_arg(self): r"unflattened_size must be tuple of tuples, but found element of type list at pos 0"): nn.Unflatten(dim='features', unflattened_size=(['C', 2], ['W', 5], ['H', 5])) - # Wrong type for unflattened_size (list of ints) - - with self.assertRaisesRegex( - TypeError, - r"unflattened_size must be a tuple of ints, but found type list"): - nn.Unflatten(dim=1, unflattened_size=[2, 5, 5]) - - # Wrong type for unflattened_size (list of lists) + # Wrong type for unflattened_size (tuple of dicts) with self.assertRaisesRegex( TypeError, - r"unflattened_size must be a tuple of tuples, but found type list"): - nn.Unflatten(dim='features', unflattened_size=[['C', 2], ['W', 5], ['H', 5]]) + r"unflattened_size must be tuple of tuples, but found element of type dict at pos 0"): + nn.Unflatten(dim='features', unflattened_size=({'C': 2}, {'W': 5}, {'H': 5})) def test_layer_norm_grads_with_create_graph_flag(self): atol = 1e-5 @@ -8740,6 +9522,8 @@ def test_calculate_gain_nonlinear(self): self.assertEqual(gain, 1.4142135623730951) elif fn == 'leaky_relu': # sqrt(2 / 1 + slope^2)) self.assertEqual(gain, 1.4141428569978354) + elif fn == 'selu': + self.assertEqual(gain, 0.75) def test_calculate_gain_leaky_relu(self): for param in [None, 0, 0.01, 10]: @@ -9129,6 +9913,18 @@ def test_fuse_module_eval_numerics(self, X, running_mean, running_var): self.assertEqual(Y_ref, Y_hat, msg="Conv+BN fusion results are off") + na_bn_ref = torch.nn.BatchNorm2d(oC, affine=False) + na_bn_ref.running_mean = torch.from_numpy(running_mean[0]).to(torch.double) + na_bn_ref.running_var = torch.from_numpy(running_var[0]).to(torch.double) + na_bn_ref.eval() + + Y_ref = na_bn_ref(conv_ref(inputs)) + conv_na_bn_fused = torch.nn.utils.fusion.fuse_conv_bn_eval(conv_ref, + na_bn_ref) + Y_hat = conv_na_bn_fused(inputs) + + self.assertEqual(Y_ref, Y_hat, msg="Conv+BN(non-affine) fusion results are off") + class TestAddRelu(TestCase): def test_add_relu(self): @@ -9140,7 +9936,7 @@ def test_add_relu(self): a = a + 5 add_res = a + b relu_res = torch.relu(add_res) - add_relu_res = torch.add_relu(a, b) + add_relu_res = torch._VF._add_relu(a, b) self.assertTrue(torch.allclose(add_relu_res, relu_res)) @@ -9192,6 +9988,15 @@ def test_bfloat16(self, test=test, kwargs=kwargs): if getattr(test, 'check_bfloat16', True): add(cuda_test_name + '_bfloat16', test_bfloat16) + def test_cfloat(self, test=test, kwargs=kwargs): + test.test_cuda(self, dtype=torch.cfloat, **kwargs) + + def test_cdouble(self, test=test, kwargs=kwargs): + test.test_cuda(self, dtype=torch.cdouble, **kwargs) + if getattr(test, 'check_complex', False): + add(cuda_test_name + '_cfloat', test_cfloat) + add(cuda_test_name + '_cdouble', test_cdouble) + else: if tf32_is_not_fp32() and test.with_tf32: @@ -9567,6 +10372,46 @@ def _test_dropout(self, cls, device, input, memory_format=torch.contiguous_forma module.__repr__() str(module) + def _test_dropout_discontiguous(self, cls, device, memory_format=torch.contiguous_format): + # In this test, we verify that dropout preserves the layout and data for different memory formats. + # We check whether, we get same values for the output of dropout, when the probability + # of dropout is 0 or very close to 0. + # Reference: https://github.com/pytorch/pytorch/issues/47176 + close_to_zero_p = 1e-10 # Should be almost zero but not zero, as for p=0 different path is taken + for p in [0, close_to_zero_p]: + inp = torch.ones(2, 3, 3, 3, device=device) + inp_discontiguous = torch.empty(2, 3, 3, 6, device=device, memory_format=memory_format)[..., ::2] + inp_discontiguous.copy_(inp) + mod = cls(p=p) + out = mod(inp_discontiguous) + if p != 0: # Zero will keep strides as is based on input. + # When prob == 0, input stride (54, 18, 6, 2) -> output stride (54, 18, 6, 2) + # When prob != 0, input stride (54, 18, 6, 2) -> output stride (27, 9, 3, 1) + self.assertTrue(out.is_contiguous(memory_format=memory_format)) + self.assertEqual(inp_discontiguous, out) + + def _test_dropout_stride_mean_preserve(self, cls, device): + def invert_perm(p): + d = {x: i for i, x in enumerate(p)} + return (d[0], d[1], d[2], d[3]) + + inp = torch.ones(2, 3, 4, 5, device=device) + shifts = [(0, 0), (1, 0), (0, 1), (1, 1)] + for perm in itertools.permutations((0, 1, 2, 3), r=4): + for shift in shifts: + for p in [1e-10, 0.3, 0.5, 0.7]: + mod = cls(p=p) + permuted_inp = inp.permute(perm).contiguous().permute(invert_perm(perm)) + permuted_inp = permuted_inp[shift[0]:, shift[1]:, :, :] + out = mod(permuted_inp) + + self.assertTrue(out.permute(perm).is_contiguous()) + self.assertEqual(inp.mean(), out.mean(), rtol=0.5, atol=0.5) + if p == 1e-10: + self.assertEqual(permuted_inp, out) + else: + self.assertNotEqual(permuted_inp, out) + def _test_InstanceNorm_general(self, cls, input, device, dtype=torch.float): # default case track_running_stats=False b, c = input.size(0), input.size(1) @@ -10002,10 +10847,15 @@ def test_affine_3d_rotateRandom(self, device): self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary)) def test_Dropout(self, device): - input = torch.Tensor(1000) + input = torch.empty(1000) self._test_dropout(nn.Dropout, device, input) - if self.device_type == 'cuda' and TEST_WITH_ROCM: + self._test_dropout_discontiguous(nn.Dropout, device) + self._test_dropout_discontiguous(nn.Dropout, device, memory_format=torch.channels_last) + + self._test_dropout_stride_mean_preserve(nn.Dropout, device) + + if self.device_type == 'cuda': input = input.bfloat16() self._test_dropout(nn.Dropout, device, input) @@ -10014,19 +10864,25 @@ def test_Dropout2d(self, device): w = random.randint(1, 5) h = random.randint(1, 5) num_features = 1000 - input = torch.Tensor(num_features, b, w, h) + input = torch.empty(num_features, b, w, h) self._test_dropout(nn.Dropout2d, device, input) self._test_dropout(nn.Dropout2d, device, input, memory_format=torch.channels_last) + self._test_dropout_discontiguous(nn.Dropout2d, device) + self._test_dropout_discontiguous(nn.Dropout2d, device, memory_format=torch.channels_last) + def test_Dropout3d(self, device): b = random.randint(1, 5) w = random.randint(1, 5) h = random.randint(1, 5) d = random.randint(1, 2) num_features = 1000 - input = torch.Tensor(num_features, b, d, w, h) + input = torch.empty(num_features, b, d, w, h) self._test_dropout(nn.Dropout3d, device, input) + self._test_dropout_discontiguous(nn.Dropout3d, device) + self._test_dropout_discontiguous(nn.Dropout3d, device, memory_format=torch.channels_last) + def test_InstanceNorm1d_general(self, device): b = random.randint(3, 5) c = random.randint(3, 5) @@ -10097,6 +10953,34 @@ def test_GroupNorm_empty(self, device): with torch.backends.cudnn.flags(enabled=False): self._test_module_empty_input(mod, inp) + @onlyOnCPUAndCUDA + def test_ReplicationPad_empty(self, device): + for mod, inp in [ + (torch.nn.ReplicationPad1d(3), torch.randn(0, 3, 10, device=device)), + (torch.nn.ReplicationPad2d(3), torch.randn(0, 3, 10, 10, device=device)), + (torch.nn.ReplicationPad3d(3), torch.randn(0, 3, 10, 10, 10, device=device))]: + self._test_module_empty_input(mod, inp, check_size=False) + + with self.assertRaisesRegex(NotImplementedError, 'Only 3D'): + mod = torch.nn.ReplicationPad1d(2) + inp = torch.randn(3, 10, device=device) + mod(inp) + + with self.assertRaisesRegex(RuntimeError, 'Expected 2D or 3D'): + mod = torch.nn.ReplicationPad1d(2) + inp = torch.randn(3, 0, 10, device=device) + mod(inp) + + with self.assertRaisesRegex(RuntimeError, 'Expected 3D or 4D'): + mod = torch.nn.ReplicationPad2d((2, 2, 2, 2)) + inp = torch.randn(43, 0, 10, 10, device=device) + mod(inp) + + with self.assertRaisesRegex(RuntimeError, 'Expected 4D or 5D'): + mod = torch.nn.ReplicationPad3d((2, 2, 2, 2, 2, 2)) + inp = torch.randn(3, 0, 10, 10, 10, device=device) + mod(inp) + @onlyOnCPUAndCUDA def test_ReflectionPad_empty(self, device): for mod, inp in [ @@ -10114,6 +10998,46 @@ def test_ReflectionPad_empty(self, device): inp = torch.randn(3, 0, 10, 10, device=device) mod(inp) + + @onlyOnCPUAndCUDA + @dtypes(torch.float, torch.double) + def test_MarginLoss_empty(self, device, dtype): + for mod, x, y in [ + (torch.nn.MultiMarginLoss().to(device), + torch.randn(0, 10, requires_grad=True, device=device, dtype=dtype), + torch.ones(0, device=device).type(torch.long)), + (torch.nn.MultiLabelMarginLoss().to(device), + torch.randn(0, 10, requires_grad=True, device=device, dtype=dtype), + torch.ones(0, 10, device=device).type(torch.long))]: + + out = mod(x, y) + out.sum().backward() + + self.assertEqual(x, torch.zeros_like(x)) + self.assertEqual(x.grad, torch.zeros_like(x)) + + with self.assertRaisesRegex(RuntimeError, 'Expected'): + x = torch.randn(0, requires_grad=True, device=device, dtype=dtype) + y = torch.ones(10, device=device).type(torch.long) + mod(x, y) + + with self.assertRaisesRegex(RuntimeError, 'Expected'): + x = torch.randn(10, 0, requires_grad=True, device=device, dtype=dtype) + y = torch.ones(10, 0, device=device).type(torch.long) + mod(x, y) + + + @onlyOnCPUAndCUDA + def test_Unfold_empty(self, device): + inp = torch.randn(0, 3, 3, 4, device=device) + unfold = torch.nn.Unfold(kernel_size=(2, 3)).to(device) + self._test_module_empty_input(unfold, inp, check_size=False) + + with self.assertRaisesRegex(RuntimeError, 'Expected 3D or 4D'): + inp = torch.randn(3, 0, 3, 4, device=device) + unfold = torch.nn.Unfold(kernel_size=(2, 3)).to(device) + unfold(inp) + @onlyCUDA @dtypes(torch.float, torch.double) @tf32_on_and_off(0.005) @@ -10219,8 +11143,22 @@ def test_convTranspose_empty(self, device): with torch.backends.cudnn.flags(enabled=False): self._test_module_empty_input(mod, inp, check_size=False) + @onlyOnCPUAndCUDA + def test_AvgPool2d_empty(self, device): + avgpool = torch.nn.AvgPool2d(3, stride=2).to(device) + inp = torch.randn(0, 16, 20, 32, device=device) + self._test_module_empty_input(avgpool, inp, check_size=False) + + clast_inp = torch.randn(0, 16, 20, 32, device=device).contiguous(memory_format=torch.channels_last) + self._test_module_empty_input(avgpool, clast_inp, check_size=False) + + # test with empty non-batch input + with self.assertRaisesRegex(RuntimeError, '3D or 4D'): + inp = torch.randn(16, 0, 20, 32, device=device) + avgpool(inp) + @onlyCUDA - @largeCUDATensorTest('16GB') + @largeTensorTest('16GB') def test_prelu_backward_32bit_indexing(self, device): m = torch.nn.PReLU().cuda().half() input_ = torch.ones((1024, 1024, 1024, 2), dtype=torch.half, device=device) @@ -10328,6 +11266,7 @@ def verify_reduction_scalars(input, reduction, output): @onlyOnCPUAndCUDA def test_invalid_reduction_strings(self, device): input = torch.randn(3, 5, requires_grad=True, device=device) + cinput = torch.randn(3, 5, requires_grad=True, device=device, dtype=torch.cfloat) target = torch.tensor([1, 0, 4], device=device) for reduction in ['none', 'invalid']: @@ -10344,6 +11283,7 @@ def v(fn): v(lambda: F.kl_div(input, input, reduction=reduction)) v(lambda: F.smooth_l1_loss(input, input, reduction=reduction)) v(lambda: F.l1_loss(input, input, reduction=reduction)) + v(lambda: F.l1_loss(cinput, cinput, reduction=reduction)) v(lambda: F.mse_loss(input, input, reduction=reduction)) v(lambda: F.hinge_embedding_loss(input, input, reduction=reduction)) v(lambda: F.poisson_nll_loss(input, input, reduction=reduction)) @@ -10402,6 +11342,27 @@ def test(nonlinearity, *args, **kwargs): test('threshold', 3, 2) test('threshold', 3, 2, inplace=True) + def test_pooling_shape(self, device): + ''' Test the output shape calculation for pooling functions ''' + + # Checks output shape against expected for 1D, 2D and 3D + def check(expected_out_shape, sizes, *args, **kwargs): + for kernel in ['max', 'avg']: + for i in [1, 2, 3]: + if hasattr(torch.nn.functional, f'{kernel}_pool{i}d'): + op = getattr(torch.nn.functional, f'{kernel}_pool{i}d') + t = torch.randn(sizes[:i + 2], device=device) + self.assertEqual(op(t, *args, **kwargs).shape, expected_out_shape[:i + 2]) + + check((1, 1, 3, 3, 4), (1, 1, 5, 6, 7), kernel_size=1, stride=2, padding=0, ceil_mode=True) + check((1, 1, 2, 3, 3), (1, 1, 3, 4, 5), kernel_size=2, stride=2, padding=1, ceil_mode=False) + check((1, 1, 2, 3, 3), (1, 1, 3, 4, 5), kernel_size=2, stride=2, padding=1, ceil_mode=True) + + # Test case from issue https://github.com/pytorch/pytorch/issues/45357 + x = torch.randn(1, 1, 6, 7, device=device) + y = torch.nn.functional.max_pool2d(x, 1, stride=(2, 2), padding=0, ceil_mode=True) + self.assertEqual(y.size(), (1, 1, 3, 4)) + @onlyOnCPUAndCUDA # TODO: fix on XLA def test_adaptive_avg_pool2d_output_size_one(self, device): def helper(size, memory_format): @@ -10492,8 +11453,10 @@ def check(x, args, message): def test_max_pool1d_corner_cases(self, device, dtype): def check(x, args, expected): model = torch.nn.MaxPool1d(*args) - tensor = torch.tensor(x, device=device, dtype=dtype) - self.assertEqual(model(tensor), torch.tensor(expected, device=device, dtype=dtype)) + if isinstance(x, list): + x = torch.tensor(x, device=device, dtype=dtype) + expected = torch.tensor(expected, device=device, dtype=dtype) + self.assertEqual(model(x), expected) # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode) check([[]], (1, None, 0, 1, False, False), [[]]) @@ -10505,7 +11468,7 @@ def check(x, args, expected): check([[1, 2]], (2, 1, 1, 2, False, False), [[2, 1]]) check([[1, 2]], (2, 2, 1, 2, False, True), [[2, 2]]) - empty_tensor = torch.empty((2, 0, 1), dtype=torch.float32) + empty_tensor = torch.empty((2, 0, 1), device=device, dtype=dtype) check(empty_tensor, (1, None, 0, 1, False, False), empty_tensor) @onlyCPU @@ -10515,8 +11478,7 @@ def test_max_pool1d(self, device, dtype): def check(x, *args, **kwargs): model = torch.nn.MaxPool1d(*args, **kwargs) ref_model = torch.nn.MaxPool1d(*args, **kwargs, return_indices=True) - tensor = torch.tensor(x, device=device, dtype=dtype) - self.assertEqual(model(tensor), ref_model(tensor)[0]) + self.assertEqual(model(x), ref_model(x)[0]) sizes = [random.sample(range(8, 128), 3) for _ in range(3)] kernel_sizes = random.sample(range(1, 5), 3) @@ -10527,10 +11489,11 @@ def check(x, *args, **kwargs): for size, kernel_size, stride, dilation, ceil_mode in \ itertools.product(sizes, kernel_sizes, strides, dilations, ceil_modes): padding = random.sample(range(0, math.floor(kernel_size / 2) + 1), 1) - check(torch.randn(size), kernel_size, stride, padding, dilation, ceil_mode=ceil_mode) + check(torch.randn(size, device=device, dtype=dtype), + kernel_size, stride, padding, dilation, ceil_mode=ceil_mode) # Non-contiguous test - tensor = torch.randn(5, 151, 33)[::2, ::3, ::2] + tensor = torch.randn(5, 151, 33, device=device, dtype=dtype)[::2, ::3, ::2] check(tensor, 3, 2, 1, 2, ceil_mode=True) check(tensor.transpose(1, 2), 3, 2, 1, 2, ceil_mode=True) @@ -10625,6 +11588,15 @@ def fn(weight): fn = fn_wrapper(device) _assertGradAndGradgradChecks(self, fn, (weight, )) + def fn_wrapper(device): + def padding_fn(weight): + inp = torch.tensor([[0, 1, 1, 2], [1, 1, 0, 2]], dtype=torch.long).to(device) + return torch.nn.functional.embedding(inp, weight, padding_idx=1) + return padding_fn + + fn = fn_wrapper(device) + _assertGradAndGradgradChecks(self, fn, (weight, )) + def test_embedding_scalar_weight_error(self, device): indices = torch.rand(2, 2, device=device).long() weight = torch.tensor(1.0, device=device) @@ -10721,7 +11693,7 @@ def test_embedding_padding_idx(self, device, dtype): embedding.zero_grad() self.assertEqual(after, pre) - # test is flaky on ROCm CI + # Test fails on Vg20 @skipCUDAIfRocm @dtypesIfCUDA(torch.half, torch.float) @dtypes(torch.float) @@ -10773,7 +11745,8 @@ def _test_helper(shape): # test non-persistent softmax kernel _test_helper((4, 1536)) - @largeCUDATensorTest('12GB') + @onlyCUDA + @largeTensorTest('12GB') def test_conv_large_nosplit(self, device): # Here we just test the convolution correctly route to the fallback implementation # that is, it does not crash. The correctness of fallback implementation should be @@ -10914,7 +11887,7 @@ def test_grid_sample_large_index_2d(self, device, dtype): sum(i * s for i, s in zip(large_view.size(), large_view.stride())) >= 2 ** 31, msg="View must use 64-bit indexing") for mode, padding_mode, align_corners in itertools.product( - ('nearest', 'bilinear'), ('zeros', 'border', 'reflection'), (True, False)): + ('nearest', 'bilinear', 'bicubic'), ('zeros', 'border', 'reflection'), (True, False)): a = F.grid_sample( small_image, coords, mode=mode, padding_mode=padding_mode, align_corners=align_corners) @@ -10974,7 +11947,8 @@ def test_grid_sample_large_index_3d(self, device, dtype): small_image.grad.zero_() large_view.grad.zero_() - @largeCUDATensorTest('12GB') + @onlyCUDA + @largeTensorTest('12GB') def test_conv_transposed_large(self, device): dtype = torch.half if self.device_type == 'cuda' else torch.float conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype) @@ -10990,8 +11964,9 @@ def test_conv_transposed_large(self, device): self.assertEqual(maxdiff2, 0) self.assertEqual(maxdiff3, 0) + @onlyCUDA @skipCUDAIfRocm - @largeCUDATensorTest('12GB') + @largeTensorTest('12GB') def test_conv_large(self, device): dtype = torch.half if self.device_type == 'cuda' else torch.float conv = nn.Conv2d(2, 2, 8, 8, bias=False).to(device).to(dtype) @@ -11253,7 +12228,7 @@ def _ordered_sequence(self, device, dtype): def _padded_sequence(self, device, dtype): """Create Tensor of random padded sequences""" ordered = self._ordered_sequence(device, dtype) - lengths = list(map(len, ordered)) + lengths = [len(i) for i in ordered] padded_tensor = rnn_utils.pad_sequence(ordered) return padded_tensor, lengths @@ -11295,7 +12270,8 @@ def test_overwrite_module_params_on_conversion_cpu_device(self, device): m = nn.Linear(20, 10) mw = m.weight[:] m.to(device) - mw[0][0] = 5 + with torch.no_grad(): + mw[0][0] = 5 self.assertTrue(mw[0][0] == mw._base[0][0]) # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`, @@ -11321,9 +12297,9 @@ def test_embedding_max_norm_device(self, device, dtype): self.assertEqual(output[1], output[2]) self.assertTrue(output.data.norm(p=2, dim=1).le(1).all()) - # test is flaky on ROCm CI - @onlyCUDA + # Test fails on Vg20 @skipCUDAIfRocm + @onlyCUDA @dtypes(torch.half, torch.float) def test_softmax(self, device, dtype): input = torch.rand(32, 100, device=device, dtype=dtype, requires_grad=True) @@ -11381,25 +12357,27 @@ def test_pooling_size_empty(self, device): self.assertRaises(RuntimeError, lambda: F.adaptive_max_pool2d(t, [])) self.assertRaises(RuntimeError, lambda: F.adaptive_max_pool3d(t, [])) - def test_embedding_bag_empty_input(self, device): + @dtypes(torch.int, torch.long) + def test_embedding_bag_empty_input(self, device, dtype): m = 4 n = 3 - x = torch.tensor([], device=device, dtype=torch.long) + x = torch.tensor([], device=device, dtype=dtype) for sparse in [True, False]: Embed = torch.nn.EmbeddingBag(m, n, sparse=sparse) Embed.to(device) - output = Embed(input=x, offsets=torch.tensor([0], device=device, dtype=torch.long)) + output = Embed(input=x, offsets=torch.tensor([0], device=device, dtype=dtype)) self.assertEqual(output, torch.zeros_like(output)) - output = Embed(input=x, offsets=torch.tensor([0, 0], device=device, dtype=torch.long)) + output = Embed(input=x, offsets=torch.tensor([0, 0], device=device, dtype=dtype)) self.assertEqual(output, torch.zeros_like(output)) - def test_EmbeddingBag_per_sample_weights_failures(self, device): + @dtypes(torch.int, torch.long) + def test_EmbeddingBag_per_sample_weights_failures(self, device, dtype): # Failure 1: mismatched embeddings / per_sample_weights dtype es = nn.EmbeddingBag(5, 2, mode='sum').to(dtype=torch.float, device=device) - input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device) - offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long, device=device) + input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtype, device=device) + offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtype, device=device) per_sample_weights = torch.randn_like(input, dtype=torch.double, device=device) if device == 'cpu': with self.assertRaisesRegex(RuntimeError, 'have the same type as'): @@ -11409,14 +12387,14 @@ def test_EmbeddingBag_per_sample_weights_failures(self, device): es(input, offsets, per_sample_weights) # Failure 2.1: input/per_sample_weights have different sizes (1d input) - input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device) - offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long, device=device) + input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtype, device=device) + offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtype, device=device) per_sample_weights = torch.randn(5, dtype=torch.float, device=device) with self.assertRaisesRegex(ValueError, 'same shape as the input'): es(input, offsets, per_sample_weights) # Failure 2.2: input/per_sample_weights have different sizes (2d input) - input = torch.randint(5, (7, 3), dtype=torch.long, device=device) + input = torch.randint(5, (7, 3), dtype=dtype, device=device) offsets = None per_sample_weights = torch.randn(7 * 3, dtype=torch.float, device=device) with self.assertRaisesRegex(ValueError, 'same shape as the input'): @@ -11426,7 +12404,7 @@ def test_EmbeddingBag_per_sample_weights_failures(self, device): for unsupported_mode in ('max', 'mean'): es = nn.EmbeddingBag(5, 2, mode=unsupported_mode).to( dtype=torch.float, device=device) - input = torch.randint(5, (7, 3), dtype=torch.long, device=device) + input = torch.randint(5, (7, 3), dtype=dtype, device=device) offsets = None per_sample_weights = torch.randn(7, 3, dtype=torch.float, device=device) with self.assertRaisesRegex(NotImplementedError, @@ -11444,7 +12422,8 @@ def _embedding_bag_reference_impl(self, input, weight, offsets=None, mode='sum', assert input.numel() == per_sample_weights.numel() bags = [] - embeddings = weight.index_select(0, input) * per_sample_weights.unsqueeze(1) + long_input = input.to(torch.long) + embeddings = weight.index_select(0, long_input) * per_sample_weights.unsqueeze(1) if include_last_offset: for index in range(len(offsets) - 1): offset = offsets[index] @@ -11469,7 +12448,7 @@ def _embedding_bag_reference_impl(self, input, weight, offsets=None, mode='sum', if index + 1 < len(offsets): next_offset = offsets[index + 1] else: - next_offset = len(input) + next_offset = len(long_input) length = next_offset - offset if length == 0: bags.append( @@ -11487,14 +12466,55 @@ def _embedding_bag_reference_impl(self, input, weight, offsets=None, mode='sum', bags.append(embeddings.narrow(0, offset, length).max(0)[0]) return torch.stack(bags) - def test_EmbeddingBag_per_sample_weights_and_offsets(self, device): - def test_per_sample_weights(mode, dtype, trainable_scale): - es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtype, device=device) + @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half))) + @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double))) + def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device, dtypes): + # Test empty input and per sample weight, and backward pass. There was a CUDA + # invalid configuration bug (more context in #46572) + def test_per_sample_weights(mode, trainable_scale): + es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[1], device=device) + es.weight.data.copy_( + torch.arange(1, 11, device=device, dtype=dtypes[1]).view_as(es.weight)) + input = torch.tensor([], device=device, dtype=dtypes[0]) + offsets = torch.tensor([0, 0, 0, 0, 0], device=device, dtype=dtypes[0]) + per_sample_weights = torch.randn_like(input, dtype=dtypes[1]) \ + .requires_grad_(trainable_scale) + ref_per_sample_weights = \ + per_sample_weights.detach().requires_grad_(trainable_scale) + reference_weights = es.weight.detach().requires_grad_() + + expected = self._embedding_bag_reference_impl( + input, reference_weights, offsets, mode, ref_per_sample_weights) + result = es(input, offsets, per_sample_weights) + self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) + + grad = torch.randn_like(expected) + result.backward(grad) + # the reference impl doesn't have grad fn for empty input; but the grad should + # simply be a zero tensor + ref_weights_grad = torch.zeros_like(es.weight) + self.assertEqual(es.weight.grad, ref_weights_grad, + atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) + if trainable_scale: + ref_per_sample_weights_grad = torch.empty_like(per_sample_weights) + self.assertEqual(per_sample_weights.grad, ref_per_sample_weights_grad, + atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) + + modes = ('sum',) + trainable_scale = (True, False) + for mode, trainable in itertools.product(modes, trainable_scale): + test_per_sample_weights(mode, trainable) + + @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half))) + @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double))) + def test_EmbeddingBag_per_sample_weights_and_offsets(self, device, dtypes): + def test_per_sample_weights(mode, trainable_scale): + es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[1], device=device) es.weight.data.copy_( - torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight)) - input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long) - offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long) - per_sample_weights = torch.randn_like(input, dtype=dtype) \ + torch.arange(1, 11, device=device, dtype=dtypes[1]).view_as(es.weight)) + input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0]) + offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[0]) + per_sample_weights = torch.randn_like(input, dtype=dtypes[1]) \ .requires_grad_(trainable_scale) ref_per_sample_weights = \ per_sample_weights.detach().requires_grad_(trainable_scale) @@ -11503,39 +12523,37 @@ def test_per_sample_weights(mode, dtype, trainable_scale): expected = self._embedding_bag_reference_impl( input, reference_weights, offsets, mode, ref_per_sample_weights) result = es(input, offsets, per_sample_weights) - self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) grad = torch.randn_like(expected) result.backward(grad) expected.backward(grad) self.assertEqual(es.weight.grad, reference_weights.grad, - atol=dtype2prec_DONTUSE[dtype], rtol=0) + atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) if trainable_scale: self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad, - atol=dtype2prec_DONTUSE[dtype], rtol=0) + atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) - if device == 'cuda': - dtypes = (torch.float, torch.double, torch.half) - else: - dtypes = (torch.float, torch.double) modes = ('sum',) trainable_scale = (True, False) - for dtype, mode, trainable in itertools.product(dtypes, modes, trainable_scale): - test_per_sample_weights(mode, dtype, trainable) - - def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device): - def test_per_sample_weights_new_offsets(mode, dtype, trainable_scale, include_last_offset, has_weight=True): - es = nn.EmbeddingBag(5, 2, mode=mode, include_last_offset=include_last_offset).to(dtype=dtype, device=device) + for mode, trainable in itertools.product(modes, trainable_scale): + test_per_sample_weights(mode, trainable) + + @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half))) + @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double))) + def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device, dtypes): + def test_per_sample_weights_new_offsets(mode, trainable_scale, include_last_offset, has_weight=True): + es = nn.EmbeddingBag(5, 2, mode=mode, include_last_offset=include_last_offset).to(dtype=dtypes[1], device=device) es.weight.data.copy_( - torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight)) - input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long) - offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long) + torch.arange(1, 11, device=device, dtype=dtypes[1]).view_as(es.weight)) + input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0]) + offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[0]) if include_last_offset: - offsets = torch.cat((offsets, torch.tensor([input.size(0)], device=device, dtype=torch.long)), 0) + offsets = torch.cat((offsets, torch.tensor([input.size(0)], device=device, dtype=dtypes[0])), 0) if has_weight: - per_sample_weights = torch.randn_like(input, device=device, dtype=dtype) \ + per_sample_weights = torch.randn_like(input, device=device, dtype=dtypes[1]) \ .requires_grad_(trainable_scale) ref_per_sample_weights = \ per_sample_weights.detach().requires_grad_(trainable_scale) @@ -11548,51 +12566,48 @@ def test_per_sample_weights_new_offsets(mode, dtype, trainable_scale, include_la expected = self._embedding_bag_reference_impl( input, reference_weights, offsets, mode, ref_per_sample_weights, include_last_offset) result = es(input, offsets, per_sample_weights) - self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) grad = torch.randn_like(expected) result.backward(grad) expected.backward(grad) self.assertEqual(es.weight.grad, reference_weights.grad, - atol=dtype2prec_DONTUSE[dtype], rtol=0) + atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) if has_weight and trainable_scale: self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad, - atol=dtype2prec_DONTUSE[dtype], rtol=0) + atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0) - if device == 'cuda': - dtypes = (torch.float, torch.double, torch.half) - else: - dtypes = (torch.float, torch.double) trainable_scale = (True, False) include_last_offset = (True, False) modes = (('sum', False), ('sum', True), ('max', False), ('mean', False)) - for dtype, (mode, has_weight), trainable, include_last_offset in itertools.product( - dtypes, modes, trainable_scale, include_last_offset + for (mode, has_weight), trainable, include_last_offset in itertools.product( + modes, trainable_scale, include_last_offset ): test_per_sample_weights_new_offsets( - mode, dtype, trainable, include_last_offset, has_weight + mode, trainable, include_last_offset, has_weight ) def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None, mode='mean', device='cpu', - dtype=torch.float, + wdtype=torch.float, + dtype=torch.long, test_per_sample_weights=False, trainable_per_sample_weights=False, sparse=False, test_backward=True, backward_prec=None): - es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, dtype) - e = nn.Embedding(N, D, max_norm=max_norm).to(device, dtype) + es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, wdtype) + e = nn.Embedding(N, D, max_norm=max_norm).to(device, wdtype) e.weight.data.copy_(es.weight) - input = torch.randint(N, (B, L), device=device, dtype=torch.long) - offsets = torch.arange(0, B, device=device, dtype=torch.long).mul_(L) - grad_output = torch.rand(B, D, device=device, dtype=dtype) + input = torch.randint(N, (B, L), device=device, dtype=dtype) + offsets = torch.arange(0, B, device=device, dtype=dtype).mul_(L) + grad_output = torch.rand(B, D, device=device, dtype=wdtype) if test_per_sample_weights: # To prevent large gradients, weights should sum to 1 for each bag per_sample_weights = \ - torch.randn(B, L, device=device, dtype=dtype).softmax(dim=-1) + torch.randn(B, L, device=device, dtype=wdtype).softmax(dim=-1) per_sample_weights_reference = \ per_sample_weights.clone().requires_grad_(trainable_per_sample_weights) per_sample_weights.requires_grad_(trainable_per_sample_weights) @@ -11614,7 +12629,7 @@ def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None, assert not test_per_sample_weights ref_output = e(input).max(1)[0] - self.assertEqual(output, ref_output, atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(output, ref_output, atol=dtype2prec_DONTUSE[wdtype], rtol=0) if not test_backward: return @@ -11627,7 +12642,7 @@ def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None, # We have more floating point error here because we are dealing with larger numbers if backward_prec is None: - needed_prec = dtype2prec_DONTUSE[dtype] * 3 + needed_prec = dtype2prec_DONTUSE[wdtype] * 3 else: needed_prec = backward_prec @@ -11635,13 +12650,15 @@ def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None, if test_per_sample_weights and trainable_per_sample_weights: self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad, - atol=dtype2prec_DONTUSE[dtype], rtol=0) + atol=dtype2prec_DONTUSE[wdtype], rtol=0) @skipCUDAIf(True, "Temporarily disabled. See t54369166") - def test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device): - def run_tests(dtype, mode, sparse, trainable_per_sample_weights): + @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.half, torch.float, torch.double))) + @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double))) + def test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device, dtypes): + def run_tests(mode, sparse, trainable_per_sample_weights): kwargs = dict(test_per_sample_weights=True, device=device, - mode=mode, dtype=dtype, sparse=sparse, + mode=mode, wdtype=dtypes[1], dtype=dtypes[0], sparse=sparse, trainable_per_sample_weights=trainable_per_sample_weights) # Simple case @@ -11656,78 +12673,76 @@ def run_tests(dtype, mode, sparse, trainable_per_sample_weights): # Large embedding_dim self._test_EmbeddingBag_vs_Embedding(2, 101, 3, 7, **kwargs) - dtypes = (torch.float, torch.double) modes = ('sum',) sparsity = (True, False) trainable_scale = (True, False) - for dtype, mode, sparse, trainable_per_sample_weights in \ - itertools.product(dtypes, modes, sparsity, trainable_scale): - run_tests(dtype, mode, sparse, trainable_per_sample_weights) + for mode, sparse, trainable_per_sample_weights in \ + itertools.product(modes, sparsity, trainable_scale): + run_tests(mode, sparse, trainable_per_sample_weights) # Test CUDA Dense on half precision if device == 'cuda': - dtypes = (torch.half,) modes = ('sum',) sparsity = (False,) trainable_scale = (True, False) - for dtype, mode, sparse, trainable_per_sample_weights in \ - itertools.product(dtypes, modes, sparsity, trainable_scale): - run_tests(dtype, mode, sparse, trainable_per_sample_weights) + for mode, sparse, trainable_per_sample_weights in \ + itertools.product(modes, sparsity, trainable_scale): + run_tests(mode, sparse, trainable_per_sample_weights) - def _test_EmbeddingBag(self, device, mode, sparse, dtype=torch.double, test_backward=True): + def _test_EmbeddingBag(self, device, mode, sparse, wdtype=torch.double, dtype=torch.long, test_backward=True): # check a known test example - es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, dtype) - es.weight.data.copy_(torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight)) - input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long) - offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long) + es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, wdtype) + es.weight.data.copy_(torch.arange(1, 11, device=device, dtype=wdtype).view_as(es.weight)) + input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtype) + offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtype) grad_output = torch.tensor( [1, 2, - 3, 4], device=device, dtype=dtype).view(2, 2) + 3, 4], device=device, dtype=wdtype).view(2, 2) grad_output_with_empty = torch.tensor( [99, 99, 1, 2, 99, 99, 3, 4, - 99, 99], device=device, dtype=dtype).view(5, 2) + 99, 99], device=device, dtype=wdtype).view(5, 2) if mode == "sum" or mode == "mean": denominator = 1 if mode == "sum" else 3 expected_output = torch.tensor( [[13, 16], - [13, 16]], device=device, dtype=dtype) / denominator + [13, 16]], device=device, dtype=wdtype) / denominator expected_output_with_empty = torch.tensor( [[0, 0], [13, 16], [0, 0], [13, 16], - [0, 0]], device=device, dtype=dtype) / denominator + [0, 0]], device=device, dtype=wdtype) / denominator expected_grad_weight = torch.tensor( [[3, 4], [5, 8], [0, 0], [1, 2], - [3, 4]], device=device, dtype=dtype) / denominator + [3, 4]], device=device, dtype=wdtype) / denominator elif mode == "max": expected_output = torch.tensor( [[7, 8], - [9, 10]], device=device, dtype=dtype) + [9, 10]], device=device, dtype=wdtype) expected_output_with_empty = torch.tensor( [[0, 0], [7, 8], [0, 0], [9, 10], - [0, 0]], device=device, dtype=dtype) + [0, 0]], device=device, dtype=wdtype) expected_grad_weight = torch.tensor( [[0, 0], [0, 0], [0, 0], [1, 2], - [3, 4]], device=device, dtype=dtype) + [3, 4]], device=device, dtype=wdtype) output = es(input, offsets) output.backward(grad_output_with_empty) @@ -11735,7 +12750,7 @@ def _test_EmbeddingBag(self, device, mode, sparse, dtype=torch.double, test_back if sparse: es_weight_grad = es.weight.grad.to_dense() self.assertEqual(output, expected_output_with_empty) - self.assertEqual(es_weight_grad, expected_grad_weight, atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(es_weight_grad, expected_grad_weight, atol=dtype2prec_DONTUSE[wdtype], rtol=0) # check same example except as 2D (2 x 3) input = input.view(2, -1) @@ -11747,12 +12762,12 @@ def _test_EmbeddingBag(self, device, mode, sparse, dtype=torch.double, test_back if sparse: es_weight_grad = es.weight.grad.to_dense() self.assertEqual(output, expected_output) - self.assertEqual(es_weight_grad, expected_grad_weight, atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(es_weight_grad, expected_grad_weight, atol=dtype2prec_DONTUSE[wdtype], rtol=0) # test all empty bags es.zero_grad() - inputs = torch.tensor([], dtype=torch.long, device=device) - offsets = torch.tensor([0, 0, 0, 0], device=device) + inputs = torch.tensor([], dtype=dtype, device=device) + offsets = torch.tensor([0, 0, 0, 0], dtype=dtype, device=device) es(inputs, offsets).sum().backward() dense_grad = es.weight.grad if dense_grad.is_sparse: @@ -11761,7 +12776,7 @@ def _test_EmbeddingBag(self, device, mode, sparse, dtype=torch.double, test_back # now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length N, D, B, L = random.randint(1, 100), random.randint(1, 100), random.randint(1, 50), random.randint(1, 50) - kwargs = dict(mode=mode, sparse=sparse, device=device, dtype=dtype, test_backward=test_backward) + kwargs = dict(mode=mode, sparse=sparse, device=device, wdtype=wdtype, dtype=dtype, test_backward=test_backward) self._test_EmbeddingBag_vs_Embedding(N, D, B, L, **kwargs) for max_norm in (None, 3): for p in itertools.product([1, 2], repeat=4): @@ -11769,8 +12784,8 @@ def _test_EmbeddingBag(self, device, mode, sparse, dtype=torch.double, test_back # check that giving illegal input combos raises error es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse) - input = torch.ones(3, 4, dtype=torch.long) - offset = torch.arange(0, 3) + input = torch.ones(3, 4, dtype=dtype) + offset = torch.arange(0, 3, dtype=dtype) self.assertRaises(ValueError, lambda: es(input, offset)) self.assertRaises(ValueError, lambda: es(input.view(-1))) offset[0] = 1 @@ -11780,35 +12795,35 @@ def _test_EmbeddingBag(self, device, mode, sparse, dtype=torch.double, test_back offset[-1] = 100 self.assertRaises(RuntimeError, lambda: es(input.view(-1), offset)) - @dtypesIfCUDA(torch.half, torch.float, torch.double) - @dtypes(torch.float, torch.double) - def test_embedding_bag_device(self, device, dtype): - self._test_EmbeddingBag(device, 'sum', False, dtype) - self._test_EmbeddingBag(device, 'mean', False, dtype) - self._test_EmbeddingBag(device, 'max', False, dtype) + @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half))) + @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double))) + def test_embedding_bag_device(self, device, dtypes): + self._test_EmbeddingBag(device, 'sum', False, wdtype=dtypes[1], dtype=dtypes[0]) + self._test_EmbeddingBag(device, 'mean', False, wdtype=dtypes[1], dtype=dtypes[0]) + self._test_EmbeddingBag(device, 'max', False, wdtype=dtypes[1], dtype=dtypes[0]) test_backward = False if self.device_type == 'cuda': # see 'todo' in test_embedding_bag. - test_backward = dtype is not torch.float16 + test_backward = dtypes[1] is not torch.float16 elif self.device_type == 'cpu': # TODO: figure out why precision on sparse embeddings isn't the # same as for dense. - test_backward = dtype is not torch.float + test_backward = dtypes[1] is not torch.float - self._test_EmbeddingBag(device, 'sum', True, dtype, test_backward=test_backward) - self._test_EmbeddingBag(device, 'mean', True, dtype, test_backward=test_backward) + self._test_EmbeddingBag(device, 'sum', True, wdtype=dtypes[1], dtype=dtypes[0], test_backward=test_backward) + self._test_EmbeddingBag(device, 'mean', True, wdtype=dtypes[1], dtype=dtypes[0], test_backward=test_backward) - @dtypesIfCUDA(torch.half, torch.float, torch.double) - @dtypes(torch.float, torch.double) - def test_embedding_bag_non_contiguous_weight(self, device, dtype): - weight_tensor = torch.randn(3, 4, dtype=dtype, device=device) + @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half))) + @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double))) + def test_embedding_bag_non_contiguous_weight(self, device, dtypes): + weight_tensor = torch.randn(3, 4, dtype=dtypes[1], device=device) weight_tensor_non_contig = weight_tensor[:, :3] # This is non-contiguous strided. weight_tensor_contig = weight_tensor_non_contig.clone().contiguous() # Contig-strided. - index = torch.tensor([0, 1, 2], device=device) - offsets = torch.tensor([0, 2], device=device) + index = torch.tensor([0, 1, 2], dtype=dtypes[0], device=device) + offsets = torch.tensor([0, 2], dtype=dtypes[0], device=device) for mode in ['sum', 'mean', 'max']: output_non_contig = F.embedding_bag( input=index, @@ -11826,10 +12841,10 @@ def test_embedding_bag_non_contiguous_weight(self, device, dtype): @onlyCUDA - @skipCUDAIfNotRocm - def test_embedding_bag_bfloat16(self, device): - self._test_EmbeddingBag(device, 'sum', True, dtype=torch.bfloat16, test_backward=True) - self._test_EmbeddingBag(device, 'mean', True, dtype=torch.bfloat16, test_backward=True) + @dtypes(torch.int, torch.long) + def test_embedding_bag_bfloat16(self, device, dtype): + self._test_EmbeddingBag(device, 'sum', True, wdtype=torch.bfloat16, dtype=dtype, test_backward=True) + self._test_EmbeddingBag(device, 'mean', True, wdtype=torch.bfloat16, dtype=dtype, test_backward=True) @onlyCUDA @@ -11847,7 +12862,7 @@ def test_multihead_attention_dtype(self, device, dtype): self.assertEqual(q.size(), out[0].size()) self.assertEqual(dtype, out[0].dtype) - @dtypesIfCUDA(*ALL_TENSORTYPES2) + @dtypesIfCUDA(*get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)) @dtypes(torch.float) def test_Conv2d_naive_groups(self, device, dtype): # Check that grouped convolutions matches two half convolutions @@ -11969,7 +12984,6 @@ def test_batchnorm_eval(self, device): self._test_batchnorm_eval(device) @onlyCUDA - @skipCUDAIfNotRocm def test_batchnorm_eval_bfloat16(self, device): self._test_batchnorm_eval(device, torch.bfloat16) @@ -12255,6 +13269,7 @@ def test_empty_dropout(self, device): @dtypesIfCUDA(torch.half, torch.float, torch.double) @dtypes(torch.float) + @tf32_on_and_off(0.005) def test_variable_sequence(self, device, dtype): def pad(var, length): if var.size(0) == length: @@ -12267,7 +13282,7 @@ def maybe_index_tuple(maybe_tuple_of_tensors, index): return tuple(maybe_tuple_of_tensors[j][:, index:index + 1, :].contiguous() for j in range(2)) - def check_lengths(lengths, enforce_sorted, use_default_hiddens): + def check_lengths(lengths, enforce_sorted, use_default_hiddens, proj_size): input_size = 3 hidden_size = 4 num_layers = 2 @@ -12278,15 +13293,17 @@ def check_lengths(lengths, enforce_sorted, use_default_hiddens): dtype=dtype, requires_grad=True) num_directions = 2 if bidirectional else 1 lstm = nn.LSTM(input_size, hidden_size, bidirectional=bidirectional, - num_layers=num_layers).to(device, dtype) + num_layers=num_layers, proj_size=proj_size).to(device, dtype) lstm2 = deepcopy(lstm).to(device, dtype) x = x_leaf hidden0 = None if not use_default_hiddens: - hidden0 = tuple(torch.randn(num_directions * num_layers, len(lengths), hidden_size, - device=device, dtype=dtype) - for _ in range(2)) + real_hidden_size = hidden_size if proj_size == 0 else proj_size + hidden0 = (torch.randn(num_directions * num_layers, len(lengths), real_hidden_size, + device=device, dtype=dtype), + torch.randn(num_directions * num_layers, len(lengths), hidden_size, + device=device, dtype=dtype)) # Compute sequences separately seq_outs = [] @@ -12321,7 +13338,7 @@ def check_lengths(lengths, enforce_sorted, use_default_hiddens): for p1, p2 in zip(lstm.parameters(), lstm2.parameters()): prec = dtype2prec_DONTUSE[dtype] if dtype == torch.float16: - prec = 2e-2 + prec = 4e-2 self.assertEqual(p1.grad, p2.grad, atol=prec, rtol=0) tests = [ @@ -12333,9 +13350,16 @@ def check_lengths(lengths, enforce_sorted, use_default_hiddens): [False, [2, 1, 3, 2, 10, 5, 3]], ] + rocm_error_msg = "LSTM with projections is not supported with MIOpen" for enforce_sorted, seq_lens, in tests: for use_default_hiddens in (True, False): - check_lengths(seq_lens, enforce_sorted, use_default_hiddens) + for proj_size in [0, 2]: + # LSTM with projections is not supported with MIOpen + if device != 'cpu' and dtype == torch.float32 and TEST_WITH_ROCM and proj_size > 0: + with self.assertRaisesRegex(RuntimeError, rocm_error_msg): + check_lengths(seq_lens, enforce_sorted, use_default_hiddens, proj_size) + else: + check_lengths(seq_lens, enforce_sorted, use_default_hiddens, proj_size) def _test_batchnorm_update_stats(self, device, dtype=torch.float): module = nn.BatchNorm1d(3).to(device, dtype) @@ -12810,6 +13834,22 @@ def cosine_distance(x, y): self.assertEqual(functional, modular, atol=1e-6, rtol=1e-6) self.assertEqual(traced, modular, atol=1e-6, rtol=1e-6) + def test_to_complex(self, device): + m = nn.Linear(3, 5).to(device) + self.assertIs(m, m.to(device)) + m.to(torch.cfloat) + self.assertIs(m.weight.dtype, torch.cfloat) + m.to(torch.cdouble) + self.assertIs(m.weight.dtype, torch.cdouble) + m.to(torch.float) + self.assertIs(m.weight.dtype, torch.float) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + m.to(torch.cfloat) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("Complex modules are a new feature" in str(w[-1].message)) + class TestModuleGlobalHooks(TestCase): @@ -12924,7 +13964,7 @@ def bw_fail2(self, grad_input, grad_output): def test_module_backward_global_hook_writeable(self): module = nn.Sigmoid() input = torch.randn(5, 5, requires_grad=True) - sig_x = torch.nn.functional.sigmoid(input) + sig_x = torch.sigmoid(input) def bw_hook(module, grad_input, grad_output): for grad in grad_input: @@ -12941,7 +13981,7 @@ def bw_hook(module, grad_input, grad_output): def test_module_global_forward_preforward_hook_writeable(self): module = nn.Sigmoid() input = torch.randn(5, 5, requires_grad=True) - sig_x = torch.nn.functional.sigmoid(input) + sig_x = torch.sigmoid(input) def forward_pre_hook(m, input): return torch.nn.functional.relu(input[0]) @@ -12952,7 +13992,7 @@ def forward_hook(m, input, output): nn.modules.module.register_module_forward_pre_hook(forward_pre_hook) nn.modules.module.register_module_forward_hook(forward_hook) output = module(input) - expected_res = -torch.nn.functional.sigmoid(torch.nn.functional.relu(input)) + expected_res = -torch.sigmoid(torch.nn.functional.relu(input)) self.assertEqual(output, expected_res) output.backward(torch.ones(5, 5) * 2, retain_graph=True) mask = (input > 0).double() @@ -13019,6 +14059,329 @@ def local_backward_hook(m, input, output): output.backward(torch.ones(5, 5), retain_graph=True) self.assertTrue(local_backward_called and global_backward_called) + +class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module): + pass + + +class TestLazyModules(TestCase): + + @suppress_warnings + def test_lazy_module_parameter(self): + module = LazyModule() + module.register_parameter('test_param', UninitializedParameter()) + self.assertTrue(module.has_uninitialized_params()) + state_dict = module.state_dict() + self.assertIsInstance(state_dict['test_param'], UninitializedParameter) + new_module = LazyModule() + # An error is raised when there is an attempt to replace an existing parameter + # with an uninitialized one + new_module.register_parameter('test_param', nn.Parameter(torch.ones(5, 5))) + with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'): + new_module.load_state_dict(state_dict) + # Uninitialized parameters are overriden when the state dict to be loaded contains a valid one + new_module = LazyModule() + new_module.register_parameter('test_param', nn.Parameter(torch.ones(5, 5))) + module.load_state_dict(new_module.state_dict()) + self.assertEqual(module.test_param, torch.ones((5, 5))) + + # Uninitialized parameters are left unchanged + module = LazyModule() + module.register_parameter('test_param', UninitializedParameter()) + self.assertTrue(module.has_uninitialized_params()) + + new_module = LazyModule() + new_module.register_parameter('test_param', UninitializedParameter()) + module.load_state_dict(new_module.state_dict()) + self.assertTrue(module.has_uninitialized_params()) + + @suppress_warnings + def test_lazy_module_jit(self): + module = LazyModule() + module.register_parameter('test_param', UninitializedParameter()) + self.assertTrue(module.has_uninitialized_params()) + with self.assertRaisesRegex(RuntimeError, 'run a forward pass'): + torch.jit.script(module) + + @suppress_warnings + def test_lazy_share_memory(self): + module = LazyModule() + module.register_parameter('test_param', UninitializedParameter()) + self.assertTrue(module.has_uninitialized_params()) + with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'): + module.share_memory() + + @suppress_warnings + def test_linear(self): + module = nn.LazyLinear(10) + self.assertIsInstance(module.weight, UninitializedParameter) + input = torch.ones(5, 5) + module(input) + self.assertIsInstance(module, nn.Linear) + self.assertNotIsInstance(module, nn.LazyLinear) + self.assertTrue(module.weight.shape == (10, 5)) + y = module(input) + self.assertTrue(torch.equal(torch.nn.functional.linear(input, module.weight, module.bias), y)) + + @suppress_warnings + def test_lazy_linear_pickle(self): + module = nn.LazyLinear(10) + self.assertIsInstance(module.weight, UninitializedParameter) + module = pickle.loads(pickle.dumps(module)) + self.assertIsInstance(module, nn.LazyLinear) + self.assertIsInstance(module.weight, UninitializedParameter) + input = torch.ones(5, 5) + module(input) # fully materialized + new_module = pickle.loads(pickle.dumps(module)) + self.assertIsInstance(new_module, nn.Linear) + self.assertNotIsInstance(new_module, nn.LazyLinear) + self.assertTrue(new_module.weight.shape == (10, 5)) + self.assertNotIsInstance(new_module.weight, UninitializedParameter) + + @suppress_warnings + def test_linear_state(self): + module = nn.Linear(5, 10) + lazy_module = nn.LazyLinear(10) + lazy_module.load_state_dict(module.state_dict()) + # Parameters have been initialized but the module won't become a full + # Linear one until the first iteration. This is due to + # limitations on the state_dict loading logic + self.assertFalse(lazy_module.has_uninitialized_params()) + self.assertTrue(lazy_module.weight.shape == (10, 5)) + + module = nn.Linear(5, 10) + lazy_module = nn.LazyLinear(10) + with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'): + module.load_state_dict(lazy_module.state_dict()) + + def _check_lazy_conv(self, cls, lazy_cls, func, init_args, input_shape, expected_weight_shape): + module = lazy_cls(*init_args) + self.assertIsInstance(module.weight, UninitializedParameter) + input = torch.ones(*input_shape) + module(input) + self.assertIsInstance(module, cls) + self.assertNotIsInstance(module, lazy_cls) + self.assertEqual(module.weight.shape, expected_weight_shape) + y = module(input) + self.assertTrue(torch.equal(func(input, module.weight, module.bias), y)) + + def _check_lazy_conv_pickle(self, cls, lazy_cls, init_args, input_shape, expected_weight_shape): + module = lazy_cls(*init_args) + self.assertIsInstance(module.weight, UninitializedParameter) + module = pickle.loads(pickle.dumps(module)) + self.assertIsInstance(module, lazy_cls) + self.assertIsInstance(module.weight, UninitializedParameter) + input = torch.ones(*input_shape) + module(input) # fully materialized + new_module = pickle.loads(pickle.dumps(module)) + self.assertIsInstance(new_module, cls) + self.assertNotIsInstance(new_module, lazy_cls) + self.assertEqual(new_module.weight.shape, expected_weight_shape) + self.assertNotIsInstance(new_module.weight, UninitializedParameter) + + def _check_lazy_conv_state(self, gen_module, gen_lazy_module, expected_weight_shape): + module = gen_module() + lazy_module = gen_lazy_module() + lazy_module.load_state_dict(module.state_dict()) + # Parameters have been initialized but the module won't become a full + # Conv one until the first iteration. This is due to + # limitations on the state_dict loading logic + self.assertFalse(lazy_module.has_uninitialized_params()) + self.assertEqual(lazy_module.weight.shape, expected_weight_shape) + + module = gen_module() + lazy_module = gen_lazy_module() + with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'): + module.load_state_dict(lazy_module.state_dict()) + + @suppress_warnings + def test_lazy_conv1d(self): + self._check_lazy_conv(nn.Conv1d, nn.LazyConv1d, torch.nn.functional.conv1d, + (32, 2), (192, 16, 50), (32, 16, 2)) + + @suppress_warnings + def test_lazy_conv1d_pickle(self): + self._check_lazy_conv_pickle(nn.Conv1d, nn.LazyConv1d, (32, 2), (192, 16, 50), (32, 16, 2)) + + @suppress_warnings + def test_lazy_conv1d_state(self): + self._check_lazy_conv_state(lambda: nn.Conv1d(16, 32, 2), + lambda: nn.LazyConv1d(32, 2), + (32, 16, 2)) + + @suppress_warnings + def test_lazy_conv2d(self): + self._check_lazy_conv(nn.Conv2d, nn.LazyConv2d, torch.nn.functional.conv2d, + (32, 2), (192, 16, 8, 6), (32, 16, 2, 2)) + + @suppress_warnings + def test_lazy_conv2d_pickle(self): + self._check_lazy_conv_pickle(nn.Conv2d, nn.LazyConv2d, (32, 2), (192, 16, 8, 6), (32, 16, 2, 2)) + + @suppress_warnings + def test_lazy_conv2d_state(self): + self._check_lazy_conv_state(lambda: nn.Conv2d(16, 32, 2), + lambda: nn.LazyConv2d(32, 2), + (32, 16, 2, 2)) + + @suppress_warnings + def test_lazy_conv3d(self): + self._check_lazy_conv(nn.Conv3d, nn.LazyConv3d, torch.nn.functional.conv3d, + (32, 2), (192, 16, 8, 7, 6), (32, 16, 2, 2, 2)) + + @suppress_warnings + def test_lazy_conv3d_pickle(self): + self._check_lazy_conv_pickle(nn.Conv3d, nn.LazyConv3d, (32, 2), (192, 16, 8, 7, 6), (32, 16, 2, 2, 2)) + + @suppress_warnings + def test_lazy_conv3d_state(self): + self._check_lazy_conv_state(lambda: nn.Conv3d(16, 32, 2), + lambda: nn.LazyConv3d(32, 2), + (32, 16, 2, 2, 2)) + + @suppress_warnings + def test_lazy_conv_transposed1d(self): + self._check_lazy_conv(nn.ConvTranspose1d, nn.LazyConvTranspose1d, torch.nn.functional.conv_transpose1d, + (32, 2), (192, 16, 50), (16, 32, 2)) + + @suppress_warnings + def test_lazy_conv_transpose1d_pickle(self): + self._check_lazy_conv_pickle(nn.ConvTranspose1d, nn.LazyConvTranspose1d, (32, 2), (192, 16, 50), (16, 32, 2)) + + @suppress_warnings + def test_lazy_conv_transpose1d_state(self): + self._check_lazy_conv_state(lambda: nn.ConvTranspose1d(16, 32, 2), + lambda: nn.LazyConvTranspose1d(32, 2), + (16, 32, 2)) + + @suppress_warnings + def test_lazy_conv_transpose2d(self): + self._check_lazy_conv(nn.ConvTranspose2d, nn.LazyConvTranspose2d, torch.nn.functional.conv_transpose2d, + (32, 2), (192, 16, 8, 6), (16, 32, 2, 2)) + + @suppress_warnings + def test_lazy_conv_transpose2d_pickle(self): + self._check_lazy_conv_pickle(nn.ConvTranspose2d, nn.LazyConvTranspose2d, (32, 2), (192, 16, 8, 6), (16, 32, 2, 2)) + + @suppress_warnings + def test_lazy_conv_transpose2d_state(self): + self._check_lazy_conv_state(lambda: nn.ConvTranspose2d(16, 32, 2), + lambda: nn.LazyConvTranspose2d(32, 2), + (16, 32, 2, 2)) + + @suppress_warnings + def test_lazy_conv_transpose3d(self): + self._check_lazy_conv(nn.ConvTranspose3d, nn.LazyConvTranspose3d, torch.nn.functional.conv_transpose3d, + (32, 2), (192, 16, 8, 7, 6), (16, 32, 2, 2, 2)) + + @suppress_warnings + def test_lazy_conv_transpose3d_pickle(self): + self._check_lazy_conv_pickle(nn.ConvTranspose3d, nn.LazyConvTranspose3d, (32, 2), (192, 16, 8, 7, 6), (16, 32, 2, 2, 2)) + + @suppress_warnings + def test_lazy_conv_transpose3d_state(self): + self._check_lazy_conv_state(lambda: nn.ConvTranspose3d(16, 32, 2), + lambda: nn.LazyConvTranspose3d(32, 2), + (16, 32, 2, 2, 2)) + + @suppress_warnings + def test_materialize_dtype(self): + module = LazyModule() + module.register_parameter('test_param', UninitializedParameter()) + module.test_param.materialize(10) + self.assertTrue(module.test_param.dtype == torch.float64) + module = LazyModule() + module.register_parameter('test_param', UninitializedParameter()) + module.half() + module.test_param.materialize(10) + self.assertTrue(module.test_param.dtype == torch.float16) + + @unittest.skipIf(not TEST_CUDA, 'CUDA not available') + @suppress_warnings + def test_materialize_device(self): + module = LazyModule() + module.register_parameter('test_param', UninitializedParameter()) + module.test_param.materialize(10) + self.assertTrue(module.test_param.device.type == 'cpu') + module = LazyModule() + module.register_parameter('test_param', UninitializedParameter()) + module.cuda() + module.test_param.materialize(10) + self.assertTrue(module.test_param.device.type == 'cuda') + + @suppress_warnings + def test_chained_initialization(self): + class MyNetwork(torch.nn.Module): + def __init__(self): + super(MyNetwork, self).__init__() + self.linear_1 = torch.nn.LazyLinear(15) + self.linear_2 = torch.nn.LazyLinear(10) + + def forward(self, x): + y = self.linear_1(x) + return self.linear_2(y) + + net = MyNetwork() + net(torch.ones(5, 10)) + self.assertTrue(net.linear_1.weight.shape == (15, 10)) + self.assertTrue(net.linear_2.weight.shape == (10, 15)) + + @suppress_warnings + def test_optimizer_pass(self): + optimizers = [torch.optim.Adadelta, torch.optim.Adagrad, torch.optim.Adam, + torch.optim.AdamW, torch.optim.Adamax, + torch.optim.ASGD, torch.optim.SGD, torch.optim.Rprop, + torch.optim.RMSprop, torch.optim.LBFGS] + + def run_step(module, optim): + self.assertIsInstance(optim.param_groups[0]['params'][0], UninitializedParameter) + module.test_param.materialize(10) + self.assertIsInstance(optim.param_groups[0]['params'][0], Parameter) + self.assertNotIsInstance(optim.param_groups[0]['params'][0], UninitializedParameter) + for p in module.parameters(): + p.grad = torch.rand_like(p) + if isinstance(optim, torch.optim.LBFGS): + optim.step(lambda: 1.0) + else: + optim.step() + + for optim_cls in optimizers: + module = LazyModule() + module.register_parameter('test_param', UninitializedParameter()) + if optim_cls is torch.optim.SGD: + optim = optim_cls(module.parameters(), lr=0.0) + elif optim_cls is torch.optim.Adagrad: + with self.assertRaisesRegex(ValueError, 'uninitialized parameter'): + optim = optim_cls(module.parameters()) + continue + else: + optim = optim_cls(module.parameters()) + run_step(module, optim) + + @suppress_warnings + def test_weight_norm(self): + m = nn.LazyLinear(7) + with self.assertRaisesRegex(ValueError, 'have uninitialized parameters.'): + m = torch.nn.utils.weight_norm(m) + + @suppress_warnings + def test_spectral_norm(self): + m = nn.LazyLinear(7) + with self.assertRaisesRegex(ValueError, 'have uninitialized parameters.'): + m = torch.nn.utils.spectral_norm(m) + + @suppress_warnings + def test_invalid_functions(self): + param = torch.nn.parameter.UninitializedParameter() + with self.assertRaisesRegex(ValueError, 'uninitialized parameter'): + torch.empty_like(param) + + with self.assertRaisesRegex(ValueError, 'uninitialized parameter'): + torch.add(param, param) + + with self.assertRaisesRegex(ValueError, 'uninitialized parameter'): + param + param + instantiate_device_type_tests(TestNNDeviceType, globals()) if __name__ == '__main__': diff --git a/test/test_numba_integration.py b/test/test_numba_integration.py index 1f13b49420b60..5cec57915d391 100644 --- a/test/test_numba_integration.py +++ b/test/test_numba_integration.py @@ -1,7 +1,7 @@ import unittest import torch.testing._internal.common_utils as common -from torch.testing._internal.common_utils import TEST_NUMBA, TEST_NUMPY +from torch.testing._internal.common_utils import TEST_NUMPY from torch.testing._internal.common_cuda import TEST_NUMBA_CUDA, TEST_CUDA, TEST_MULTIGPU import torch @@ -9,9 +9,6 @@ if TEST_NUMPY: import numpy -if TEST_NUMBA: - import numba - if TEST_NUMBA_CUDA: import numba.cuda diff --git a/test/test_numpy_interop.py b/test/test_numpy_interop.py new file mode 100644 index 0000000000000..81c385ae90a21 --- /dev/null +++ b/test/test_numpy_interop.py @@ -0,0 +1,417 @@ +import torch +import numpy as np + +from itertools import product + +from torch.testing._internal.common_utils import \ + (TestCase, run_tests) +from torch.testing._internal.common_device_type import \ + (instantiate_device_type_tests, onlyCPU, dtypes) + +# For testing handling NumPy objects and sending tensors to / accepting +# arrays from NumPy. +class TestNumPyInterop(TestCase): + # Note: the warning this tests for only appears once per program, so + # other instances of this warning should be addressed to avoid + # the tests depending on the order in which they're run. + @onlyCPU + def test_numpy_non_writeable(self, device): + arr = np.zeros(5) + arr.flags['WRITEABLE'] = False + self.assertWarns(UserWarning, lambda: torch.from_numpy(arr)) + + @onlyCPU + def test_numpy_unresizable(self, device) -> None: + x = np.zeros((2, 2)) + y = torch.from_numpy(x) + with self.assertRaises(ValueError): + x.resize((5, 5)) + + z = torch.randn(5, 5) + w = z.numpy() + with self.assertRaises(RuntimeError): + z.resize_(10, 10) + with self.assertRaises(ValueError): + w.resize((10, 10)) + + @onlyCPU + def test_to_numpy(self, device) -> None: + def get_castable_tensor(shape, dtype): + if dtype.is_floating_point: + dtype_info = torch.finfo(dtype) + # can't directly use min and max, because for double, max - min + # is greater than double range and sampling always gives inf. + low = max(dtype_info.min, -1e10) + high = min(dtype_info.max, 1e10) + t = torch.empty(shape, dtype=torch.float64).uniform_(low, high) + else: + # can't directly use min and max, because for int64_t, max - min + # is greater than int64_t range and triggers UB. + low = max(torch.iinfo(dtype).min, int(-1e10)) + high = min(torch.iinfo(dtype).max, int(1e10)) + t = torch.empty(shape, dtype=torch.int64).random_(low, high) + return t.to(dtype) + + dtypes = [ + torch.uint8, + torch.int8, + torch.short, + torch.int, + torch.half, + torch.float, + torch.double, + torch.long, + ] + + for dtp in dtypes: + # 1D + sz = 10 + x = get_castable_tensor(sz, dtp) + y = x.numpy() + for i in range(sz): + self.assertEqual(x[i], y[i]) + + # 1D > 0 storage offset + xm = get_castable_tensor(sz * 2, dtp) + x = xm.narrow(0, sz - 1, sz) + self.assertTrue(x.storage_offset() > 0) + y = x.numpy() + for i in range(sz): + self.assertEqual(x[i], y[i]) + + def check2d(x, y): + for i in range(sz1): + for j in range(sz2): + self.assertEqual(x[i][j], y[i][j]) + + # empty + x = torch.Tensor().to(dtp) + y = x.numpy() + self.assertEqual(y.size, 0) + + # contiguous 2D + sz1 = 3 + sz2 = 5 + x = get_castable_tensor((sz1, sz2), dtp) + y = x.numpy() + check2d(x, y) + self.assertTrue(y.flags['C_CONTIGUOUS']) + + # with storage offset + xm = get_castable_tensor((sz1 * 2, sz2), dtp) + x = xm.narrow(0, sz1 - 1, sz1) + y = x.numpy() + self.assertTrue(x.storage_offset() > 0) + check2d(x, y) + self.assertTrue(y.flags['C_CONTIGUOUS']) + + # non-contiguous 2D + x = get_castable_tensor((sz2, sz1), dtp).t() + y = x.numpy() + check2d(x, y) + self.assertFalse(y.flags['C_CONTIGUOUS']) + + # with storage offset + xm = get_castable_tensor((sz2 * 2, sz1), dtp) + x = xm.narrow(0, sz2 - 1, sz2).t() + y = x.numpy() + self.assertTrue(x.storage_offset() > 0) + check2d(x, y) + + # non-contiguous 2D with holes + xm = get_castable_tensor((sz2 * 2, sz1 * 2), dtp) + x = xm.narrow(0, sz2 - 1, sz2).narrow(1, sz1 - 1, sz1).t() + y = x.numpy() + self.assertTrue(x.storage_offset() > 0) + check2d(x, y) + + if dtp != torch.half: + # check writeable + x = get_castable_tensor((3, 4), dtp) + y = x.numpy() + self.assertTrue(y.flags.writeable) + y[0][1] = 3 + self.assertTrue(x[0][1] == 3) + y = x.t().numpy() + self.assertTrue(y.flags.writeable) + y[0][1] = 3 + self.assertTrue(x[0][1] == 3) + + def test_to_numpy_bool(self, device) -> None: + x = torch.tensor([True, False], dtype=torch.bool) + self.assertEqual(x.dtype, torch.bool) + + y = x.numpy() + self.assertEqual(y.dtype, np.bool) + for i in range(len(x)): + self.assertEqual(x[i], y[i]) + + x = torch.tensor([True], dtype=torch.bool) + self.assertEqual(x.dtype, torch.bool) + + y = x.numpy() + self.assertEqual(y.dtype, np.bool) + self.assertEqual(x[0], y[0]) + + def test_from_numpy(self, device) -> None: + dtypes = [ + np.double, + np.float, + np.float16, + np.complex64, + np.complex128, + np.int64, + np.int32, + np.int16, + np.int8, + np.uint8, + np.longlong, + np.bool, + ] + complex_dtypes = [ + np.complex64, + np.complex128, + ] + + for dtype in dtypes: + array = np.array([1, 2, 3, 4], dtype=dtype) + tensor_from_array = torch.from_numpy(array) + # TODO: change to tensor equality check once HalfTensor + # implements `==` + for i in range(len(array)): + self.assertEqual(tensor_from_array[i], array[i]) + # ufunc 'remainder' not supported for complex dtypes + if dtype not in complex_dtypes: + # This is a special test case for Windows + # https://github.com/pytorch/pytorch/issues/22615 + array2 = array % 2 + tensor_from_array2 = torch.from_numpy(array2) + for i in range(len(array2)): + self.assertEqual(tensor_from_array2[i], array2[i]) + + # Test unsupported type + array = np.array([1, 2, 3, 4], dtype=np.uint16) + with self.assertRaises(TypeError): + tensor_from_array = torch.from_numpy(array) + + # check storage offset + x = np.linspace(1, 125, 125) + x.shape = (5, 5, 5) + x = x[1] + expected = torch.arange(1, 126, dtype=torch.float64).view(5, 5, 5)[1] + self.assertEqual(torch.from_numpy(x), expected) + + # check noncontiguous + x = np.linspace(1, 25, 25) + x.shape = (5, 5) + expected = torch.arange(1, 26, dtype=torch.float64).view(5, 5).t() + self.assertEqual(torch.from_numpy(x.T), expected) + + # check noncontiguous with holes + x = np.linspace(1, 125, 125) + x.shape = (5, 5, 5) + x = x[:, 1] + expected = torch.arange(1, 126, dtype=torch.float64).view(5, 5, 5)[:, 1] + self.assertEqual(torch.from_numpy(x), expected) + + # check zero dimensional + x = np.zeros((0, 2)) + self.assertEqual(torch.from_numpy(x).shape, (0, 2)) + x = np.zeros((2, 0)) + self.assertEqual(torch.from_numpy(x).shape, (2, 0)) + + # check ill-sized strides raise exception + x = np.array([3., 5., 8.]) + x.strides = (3,) + self.assertRaises(ValueError, lambda: torch.from_numpy(x)) + + @onlyCPU + def test_ctor_with_numpy_scalar_ctor(self, device) -> None: + dtypes = [ + np.double, + np.float, + np.float16, + np.int64, + np.int32, + np.int16, + np.uint8, + np.bool, + ] + for dtype in dtypes: + self.assertEqual(dtype(42), torch.tensor(dtype(42)).item()) + + @onlyCPU + def test_numpy_index(self, device): + i = np.int32([0, 1, 2]) + x = torch.randn(5, 5) + for idx in i: + self.assertFalse(isinstance(idx, int)) + self.assertEqual(x[idx], x[int(idx)]) + + @onlyCPU + def test_numpy_array_interface(self, device): + types = [ + torch.DoubleTensor, + torch.FloatTensor, + torch.HalfTensor, + torch.LongTensor, + torch.IntTensor, + torch.ShortTensor, + torch.ByteTensor, + ] + dtypes = [ + np.float64, + np.float32, + np.float16, + np.int64, + np.int32, + np.int16, + np.uint8, + ] + for tp, dtype in zip(types, dtypes): + if np.dtype(dtype).kind == 'u': + # .type expects a XxxTensor, which have no type hints on + # purpose, so ignore during mypy type checking + x = torch.Tensor([1, 2, 3, 4]).type(tp) # type: ignore + array = np.array([1, 2, 3, 4], dtype=dtype) + else: + x = torch.Tensor([1, -2, 3, -4]).type(tp) # type: ignore + array = np.array([1, -2, 3, -4], dtype=dtype) + + # Test __array__ w/o dtype argument + asarray = np.asarray(x) + self.assertIsInstance(asarray, np.ndarray) + self.assertEqual(asarray.dtype, dtype) + for i in range(len(x)): + self.assertEqual(asarray[i], x[i]) + + # Test __array_wrap__, same dtype + abs_x = np.abs(x) + abs_array = np.abs(array) + self.assertIsInstance(abs_x, tp) + for i in range(len(x)): + self.assertEqual(abs_x[i], abs_array[i]) + + # Test __array__ with dtype argument + for dtype in dtypes: + x = torch.IntTensor([1, -2, 3, -4]) + asarray = np.asarray(x, dtype=dtype) + self.assertEqual(asarray.dtype, dtype) + if np.dtype(dtype).kind == 'u': + wrapped_x = np.array([1, -2, 3, -4], dtype=dtype) + for i in range(len(x)): + self.assertEqual(asarray[i], wrapped_x[i]) + else: + for i in range(len(x)): + self.assertEqual(asarray[i], x[i]) + + # Test some math functions with float types + float_types = [torch.DoubleTensor, torch.FloatTensor] + float_dtypes = [np.float64, np.float32] + for tp, dtype in zip(float_types, float_dtypes): + x = torch.Tensor([1, 2, 3, 4]).type(tp) # type: ignore + array = np.array([1, 2, 3, 4], dtype=dtype) + for func in ['sin', 'sqrt', 'ceil']: + ufunc = getattr(np, func) + res_x = ufunc(x) + res_array = ufunc(array) + self.assertIsInstance(res_x, tp) + for i in range(len(x)): + self.assertEqual(res_x[i], res_array[i]) + + # Test functions with boolean return value + for tp, dtype in zip(types, dtypes): + x = torch.Tensor([1, 2, 3, 4]).type(tp) # type: ignore + array = np.array([1, 2, 3, 4], dtype=dtype) + geq2_x = np.greater_equal(x, 2) + geq2_array = np.greater_equal(array, 2).astype('uint8') + self.assertIsInstance(geq2_x, torch.ByteTensor) + for i in range(len(x)): + self.assertEqual(geq2_x[i], geq2_array[i]) + + @onlyCPU + def test_multiplication_numpy_scalar(self, device) -> None: + for np_dtype in [np.float32, np.float64, np.int32, np.int64, np.int16, np.uint8]: + for t_dtype in [torch.float, torch.double]: + np_sc = np_dtype(2.0) + t = torch.ones(2, requires_grad=True, dtype=t_dtype) + r1 = t * np_sc + self.assertIsInstance(r1, torch.Tensor) + self.assertTrue(r1.dtype == t_dtype) + self.assertTrue(r1.requires_grad) + r2 = np_sc * t + self.assertIsInstance(r2, torch.Tensor) + self.assertTrue(r2.dtype == t_dtype) + self.assertTrue(r2.requires_grad) + + @onlyCPU + def test_parse_numpy_int(self, device): + self.assertRaisesRegex(RuntimeError, "Overflow", + lambda: torch.mean(torch.randn(1, 1), np.uint64(-1))) + # https://github.com/pytorch/pytorch/issues/29252 + for nptype in [np.int16, np.int8, np.uint8, np.int32, np.int64]: + scalar = 3 + np_arr = np.array([scalar], dtype=nptype) + np_val = np_arr[0] + + # np integral type can be treated as a python int in native functions with + # int parameters: + self.assertEqual(torch.ones(5).diag(scalar), torch.ones(5).diag(np_val)) + self.assertEqual(torch.ones([2, 2, 2, 2]).mean(scalar), torch.ones([2, 2, 2, 2]).mean(np_val)) + + # numpy integral type parses like a python int in custom python bindings: + self.assertEqual(torch.Storage(np_val).size(), scalar) # type: ignore + + tensor = torch.tensor([2], dtype=torch.int) + tensor[0] = np_val + self.assertEqual(tensor[0], np_val) + + # Original reported issue, np integral type parses to the correct + # PyTorch integral type when passed for a `Scalar` parameter in + # arithmetic operations: + t = torch.from_numpy(np_arr) + self.assertEqual((t + np_val).dtype, t.dtype) + self.assertEqual((np_val + t).dtype, t.dtype) + + def test_has_storage_numpy(self, device): + for dtype in [np.float32, np.float64, np.int64, + np.int32, np.int16, np.uint8]: + arr = np.array([1], dtype=dtype) + self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.float32).storage()) + self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.double).storage()) + self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.int).storage()) + self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.long).storage()) + self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.uint8).storage()) + + @dtypes(*torch.testing.get_all_dtypes()) + def test_numpy_scalar_cmp(self, device, dtype): + if dtype.is_complex: + tensors = (torch.tensor(complex(1, 3), dtype=dtype, device=device), + torch.tensor([complex(1, 3), 0, 2j], dtype=dtype, device=device), + torch.tensor([[complex(3, 1), 0], [-1j, 5]], dtype=dtype, device=device)) + else: + tensors = (torch.tensor(3, dtype=dtype, device=device), + torch.tensor([1, 0, -3], dtype=dtype, device=device), + torch.tensor([[3, 0, -1], [3, 5, 4]], dtype=dtype, device=device)) + + for tensor in tensors: + if dtype == torch.bfloat16: + with self.assertRaises(TypeError): + np_array = tensor.cpu().numpy() + continue + + np_array = tensor.cpu().numpy() + for t, a in product((tensor.flatten()[0], tensor.flatten()[0].item()), + (np_array.flatten()[0], np_array.flatten()[0].item())): + self.assertEqual(t, a) + if dtype == torch.complex64 and torch.is_tensor(t) and type(a) == np.complex64: + # TODO: Imaginary part is dropped in this case. Need fix. + # https://github.com/pytorch/pytorch/issues/43579 + self.assertFalse(t == a) + else: + self.assertTrue(t == a) + +instantiate_device_type_tests(TestNumPyInterop, globals()) + +if __name__ == '__main__': + run_tests() diff --git a/test/test_op_aliases.py b/test/test_op_aliases.py index 8a106d7860d1a..2a410e6ad6fc8 100644 --- a/test/test_op_aliases.py +++ b/test/test_op_aliases.py @@ -6,6 +6,7 @@ from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, skipCPUIfNoLapack, skipCUDAIfNoMagma, onlyCPU) +from collections.abc import Sequence # Information for generating an alias test # NOTE: ending the alias_name with an underscore will interpret the test @@ -150,6 +151,18 @@ def __init__(self, AliasInfo('true_divide_', torch.Tensor.true_divide_, 'div_', torch.Tensor.div_, lambda d: torch.randn(20, device=d), get_args=lambda d: (torch.rand(20, device=d) + .1,), decorators=(onlyCPU,)), + AliasInfo('swapdims', torch.swapdims, 'transpose', torch.transpose, + lambda d: torch.randn(20, 3, 2, 1, device=d), get_args=lambda d: (3, 1)), + AliasInfo('swapdims_', torch.Tensor.swapdims_, 'transpose_', torch.Tensor.transpose_, + lambda d: torch.randn(20, 3, 2, 1, device=d), get_args=lambda d: (3, 1)), + AliasInfo('swapaxes', torch.swapaxes, 'transpose', torch.transpose, + lambda d: torch.randn(20, 3, 2, 1, device=d), get_args=lambda d: (3, 1)), + AliasInfo('swapaxes_', torch.Tensor.swapaxes_, 'transpose_', torch.Tensor.transpose_, + lambda d: torch.randn(20, 3, 2, 1, device=d), get_args=lambda d: (3, 1)), + AliasInfo('row_stack', torch.row_stack, 'vstack', torch.vstack, + lambda d: ((torch.randn(20, device=d), torch.randn(20, device=d)))), + AliasInfo('moveaxis', torch.moveaxis, 'movedim', torch.movedim, + lambda d: torch.randn(20, 3, 2, 1, device=d), get_args=lambda d: (3, 1)), ) # Placeholder test class for validating that aliases are correctly @@ -157,6 +170,14 @@ def __init__(self, class TestOpNormalization(JitTestCase): pass + +# Clone input tensor and sequence of Tensors +def clone_inp(inp): + if isinstance(inp, Sequence): + return list(map(torch.clone, inp)) + else: + return inp.clone() + # Generates alias tests and adds them to the specified class (cls) def create_alias_tests(cls): for info in alias_infos: @@ -180,10 +201,18 @@ def _fn(t): arg_string = ', '.join((str(arg) for arg in info.get_args(device))) script = fn_template.format(alias_name=info.alias_name, args=arg_string) else: - fn_template = ''' - def _fn(t): + is_input_tensor_list = isinstance(info.get_input(device), Sequence) + # For sequence of Tensors, annotate the type to be List[Tensor] + if is_input_tensor_list: + fn_template = ''' + def _fn(t: List[Tensor]): return op(t{args}) - ''' + ''' + else: + fn_template = ''' + def _fn(t): + return op(t{args}) + ''' arg_string = ", " + ', '.join((str(arg) for arg in info.get_args(device))) script = fn_template.format(args=arg_string) @@ -192,8 +221,8 @@ def _fn(t): # Acquires and checks the graph remaps the alias inp = info.get_input(device) - scripted(inp.clone()) - graph = scripted.graph_for(inp.clone()) + scripted(clone_inp(inp)) + graph = scripted.graph_for(clone_inp(inp)) FileCheck().check(info.original_name).check_not(info.alias_name).run(graph) # Checks that tracing converts aliases @@ -203,9 +232,9 @@ def _fn(t): def _fn(t, info=info, args=args): return info.alias_op(t, *args) - traced = torch.jit.trace(_fn, (inp.clone(),)) - traced(inp.clone()) - graph = traced.graph_for(inp.clone()) + traced = torch.jit.trace(_fn, (clone_inp(inp),)) + traced(clone_inp(inp)) + graph = traced.graph_for(clone_inp(inp)) FileCheck().check(info.original_name).check_not(info.alias_name).run(graph) # Applies decorators @@ -223,10 +252,10 @@ def _test_alias_computation(self, device, info=info): inp = info.get_input(device) args = info.get_args(device) - alias_input = inp.clone() + alias_input = clone_inp(inp) alias_result = alias_op(alias_input, *args) - original_input = inp.clone() + original_input = clone_inp(inp) original_result = alias_op(original_input, *args) self.assertEqual(alias_input, original_input, atol=0, rtol=0) diff --git a/test/test_ops.py b/test/test_ops.py index 28570d9892aba..8dbb4ccd62c94 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2,14 +2,20 @@ import torch +from torch.testing import floating_and_complex_types_and from torch.testing._internal.common_utils import \ - (TestCase, run_tests) + (TestCase, run_tests, IS_SANDCASTLE, clone_input_helper) from torch.testing._internal.common_methods_invocations import \ (op_db) from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, ops, dtypes, onlyOnCPUAndCUDA, skipCUDAIfRocm) + (instantiate_device_type_tests, ops, onlyOnCPUAndCUDA, skipCUDAIfRocm, OpDTypes) +from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference from torch.autograd.gradcheck import gradcheck, gradgradcheck +from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, \ + check_alias_annotation +from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining + # Tests that apply to all operators @@ -21,23 +27,25 @@ class TestOpInfo(TestCase): # throws a runtime error @skipCUDAIfRocm @onlyOnCPUAndCUDA - @ops(op_db, unsupported_dtypes_only=True) + @ops(op_db, dtypes=OpDTypes.unsupported) def test_unsupported_dtypes(self, device, dtype, op): - samples = op.sample_inputs(device, dtype) - if len(samples) == 0: - self.skipTest("Skipped! No sample inputs!") - - # NOTE: only tests on first sample - sample = samples[0] + # sample_inputs can have a function for generating the input that doesn't work for specified dtype + # https://github.com/pytorch/pytorch/issues/49024 with self.assertRaises(RuntimeError): - op(sample.input, *sample.args, **sample.kwargs) + samples = op.sample_inputs(device, dtype) + if len(samples) == 0: + self.skipTest("Skipped! No sample inputs!") + + # NOTE: only tests on first sample + sample = samples[0] + op(*sample.input, *sample.args, **sample.kwargs) # Verifies that ops have their supported dtypes # registered correctly by testing that each claimed supported dtype # does NOT throw a runtime error @skipCUDAIfRocm @onlyOnCPUAndCUDA - @ops(op_db) + @ops(op_db, dtypes=OpDTypes.supported) def test_supported_dtypes(self, device, dtype, op): samples = op.sample_inputs(device, dtype) if len(samples) == 0: @@ -45,7 +53,12 @@ def test_supported_dtypes(self, device, dtype, op): # NOTE: only tests on first sample sample = samples[0] - op(sample.input, *sample.args, **sample.kwargs) + op(*sample.input, *sample.args, **sample.kwargs) + + +# gradcheck requires double precision +_gradcheck_ops = partial(ops, dtypes=OpDTypes.supported, + allowed_dtypes=[torch.double, torch.cdouble]) class TestGradients(TestCase): @@ -68,15 +81,26 @@ def _check_helper(self, device, dtype, op, variant, check): samples = op.sample_inputs(device, dtype, requires_grad=True) for sample in samples: - partial_fn = partial(variant, **sample.kwargs) + if sample.output_process_fn_grad is not None: + out_fn = sample.output_process_fn_grad + + def variant_out_fn(*args, **kwargs): + return out_fn(variant(*args, **kwargs)) + else: + variant_out_fn = variant + + def fn(*inputs): + output = variant_out_fn(*inputs, **sample.kwargs) + return op.output_func(output) + if check == 'gradcheck': - self.assertTrue(gradcheck(partial_fn, (sample.input,) + sample.args, + self.assertTrue(gradcheck(fn, (*sample.input,) + sample.args, check_grad_dtypes=True)) elif check == 'gradgradcheck': - self.assertTrue(gradgradcheck(partial_fn, (sample.input,) + sample.args, + self.assertTrue(gradgradcheck(fn, (*sample.input,) + sample.args, gen_non_contig_grad_outputs=False, check_grad_dtypes=True)) - self.assertTrue(gradgradcheck(partial_fn, (sample.input,) + sample.args, + self.assertTrue(gradgradcheck(fn, (*sample.input,) + sample.args, gen_non_contig_grad_outputs=True, check_grad_dtypes=True)) else: @@ -88,51 +112,241 @@ def _grad_test_helper(self, device, dtype, op, variant): def _gradgrad_test_helper(self, device, dtype, op, variant): return self._check_helper(device, dtype, op, variant, 'gradgradcheck') + def _skip_helper(self, op, dtype): + if not op.test_complex_grad and dtype.is_complex: + self.skipTest("Skipped! complex grad tests marked to skip.") + # Tests that gradients are computed correctly - # TODO(@anjali411) enable this for torch.cdouble. - @dtypes(torch.double) - @ops(op_db) + @_gradcheck_ops(op_db) def test_fn_grad(self, device, dtype, op): + self._skip_helper(op, dtype) self._grad_test_helper(device, dtype, op, op.get_op()) - # TODO(@anjali411) enable this for torch.cdouble. - @dtypes(torch.double) - @ops(op_db) - def test_method_grad(self, device, dtype, op): - self._grad_test_helper(device, dtype, op, op.get_method()) + # Method grad (and gradgrad, see below) tests are disabled since they're + # costly and redundant with function grad (and gradgad) tests + # @_gradcheck_ops(op_db) + # def test_method_grad(self, device, dtype, op): + # self._skip_helper(op, dtype) + # self._grad_test_helper(device, dtype, op, op.get_method()) - # TODO(@anjali411) enable this for torch.cdouble. - @dtypes(torch.double) - @ops(op_db) + @_gradcheck_ops(op_db) def test_inplace_grad(self, device, dtype, op): + self._skip_helper(op, dtype) if not op.test_inplace_grad: self.skipTest("Skipped! Inplace gradcheck marked to skip.") self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace())) - # TODO(@anjali411) enable this for torch.cdouble. # Test that gradients of gradients are computed correctly - @dtypes(torch.double) - @ops(op_db) + @_gradcheck_ops(op_db) def test_fn_gradgrad(self, device, dtype, op): + self._skip_helper(op, dtype) self._gradgrad_test_helper(device, dtype, op, op.get_op()) - # TODO(@anjali411) enable this for torch.cdouble. - @dtypes(torch.double) - @ops(op_db) - def test_method_gradgrad(self, device, dtype, op): - self._gradgrad_test_helper(device, dtype, op, op.get_method()) + # Method gradgrad (and grad, see above) tests are disabled since they're + # costly and redundant with function gradgrad (and grad) tests + # @_gradcheck_ops(op_db) + # def test_method_gradgrad(self, device, dtype, op): + # self._skip_helper(op, dtype) + # self._gradgrad_test_helper(device, dtype, op, op.get_method()) - # TODO(@anjali411) enable this for torch.cdouble. - @dtypes(torch.double) - @ops(op_db) + @_gradcheck_ops(op_db) def test_inplace_gradgrad(self, device, dtype, op): + self._skip_helper(op, dtype) if not op.test_inplace_grad: self.skipTest("Skipped! Inplace gradgradcheck marked to skip.") self._gradgrad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace())) +# Tests operators for consistency between JIT and eager, also checks +# correctness of JIT specific alias schemas and intended +# autodifferentiation behavior. +# Inherits from JitCommonTestCase instead of TestCase directly to share +# functionality with original test_jit.py method operator tests +class TestCommon(JitCommonTestCase): + exact_dtype = True + + # Compares variant's backward + # NOTE: verifies it fails when the forward fails + def check_variant_backward(self, input, forward_result, expected_grad, expected_exception): + variant_exception_during_backwards = False + try: + forward_result.sum().backward() + variant_grad = input.grad + input.grad = None + except Exception as e: + if not expected_exception: + self.fail("Unexpected exception during backwards!") + variant_exception_during_backwards = True + + if expected_exception != variant_exception_during_backwards: + self.fail("Unexpected success during backwards!") + + if not expected_exception: + self.assertEqual(variant_grad, expected_grad) + + # Tests that the forward and backward passes of operations produce the + # same values for the cross-product of op variants (method, inplace) + # against eager's gold standard op function variant + @ops(op_db) + def test_variant_consistency_eager(self, device, dtype, op): + samples = op.sample_inputs(device, dtype, requires_grad=True) + if len(samples) == 0: + self.skipTest("Skipped! No sample inputs!") + + for sample in samples: + # Acquires variants to test + method = op.get_method() + inplace = op.get_inplace() + variants = (v for v in (method, inplace) if v is not None) + # Computes expected forward + + # below calls op's function variant + expected_forward = op(*sample.input, *sample.args, **sample.kwargs) + + # Computes expected backward + # NOTE: backward may fail for some dtypes + exception_during_backwards = False + expected_grad = None + try: + expected_forward.sum().backward() + expected_grad = sample.input.grad + sample.input.grad = None + except Exception as e: + exception_during_backwards = True + + # Test eager consistency + for variant in variants: + # Verifies that inplace operations that promote int->float fail + # on tensors with integer dtypes. + if (variant is inplace and op.promotes_integers_to_float and + dtype in (torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)): + try: + variant_forward = variant(*(clone_input_helper(input) for input in sample.input), + *sample.args, + **sample.kwargs) + except Exception as e: + continue + self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!") + # Compares variant's forward + # Note: copy the tensor-type inputs when testing inplace operation + variant_forward = variant(*(clone_input_helper(input) if variant is inplace else input + for input in sample.input), + *sample.args, + **sample.kwargs) + self.assertEqual(variant_forward, expected_forward) + + # Compares variant's backward + if variant is not inplace or op.test_inplace_grad: + self.check_variant_backward(sample.input, variant_forward, + expected_grad, exception_during_backwards) + + # Tests that the forward and backward passes of operations produce the + # same values for the cross-product of op variants (function, method, inplace) + # and runtimes (eager, traced, scripted). + # TODO WARNING: inplace x {traced, scripted} not currently tested + @ops(op_db) + def test_variant_consistency_jit(self, device, dtype, op): + samples = op.sample_inputs(device, dtype, requires_grad=True) + if len(samples) == 0: + self.skipTest("Skipped! No sample inputs!") + + for sample in samples: + + # Acquires variants to test + func = op.get_op() + method = op.get_method() + inplace = op.get_inplace() + variants = { + 'function': func, 'method': method, + # TODO: inplace tests currently fail + # 'inplace': inplace, + } + + # Test traced and scripted consistency + for func_type, variant in variants.items(): + if variant is None: + continue + + # Create accessor for script function variant + name = op.name + '_' if func_type == 'inplace' else op.name + + # run with disable_autodiff_subgraph_inlining(True) to test + # autodiff support. Context manager forces the graph to contain + # DifferentiableGraph nodes if they are present + with disable_autodiff_subgraph_inlining(): + def fn(*inputs, **kwargs): + output = func(*inputs, **kwargs) + return op.output_func(output) + + # bfloat16 grad doesn't work for some operators + dtypes_to_grad_check = floating_and_complex_types_and(torch.half) \ + if op.skip_bfloat16_grad else floating_and_complex_types_and(torch.half, torch.bfloat16) + + # Check scripted forward, grad, and grad grad + script_fn = create_script_fn(self, name, func_type, op.output_func) + + check_against_reference(self, + script_fn, + fn, + (*sample.input,) + sample.args, + sample.kwargs, + no_grad=(dtype not in dtypes_to_grad_check)) + + # Check traced forward, grad, and grad grad + traced_fn = create_traced_fn(self, variant) + check_against_reference(self, + traced_fn, + fn, + (*sample.input,) + sample.args, + sample.kwargs, + no_grad=(dtype not in dtypes_to_grad_check)) + + # Check alias annotation schema for correctness (make + # sure inputs that aren't supposed to be modified aren't) + # Note: only runs in float32 and int64 because schema isn't affected by dtype, + # so running it on all dtypes is would be excessive + if dtype in [torch.float32, torch.int32]: + check_alias_annotation(name, (*sample.input,) + sample.args, sample.kwargs, + func_type=func_type, aten_name=op.aten_name) + + # Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample + if dtype is torch.float32: + # Sandcastle doesn't fuse nodes + if IS_SANDCASTLE: + # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs + nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes + fusible_nodes = [] + else: + nonfusible_nodes = op.autodiff_nonfusible_nodes + fusible_nodes = op.autodiff_fusible_nodes + + self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes) + self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes) + + + @ops(op_db) + def test_out(self, device, dtype, op): + if not op.supports_tensor_out: + self.skipTest("Skipped! Operator %s does not support out=..." % op.name) + + samples = op.sample_inputs(device, dtype) + if len(samples) == 0: + self.skipTest("Skipped! No sample inputs!") + + # NOTE: only tests on first sample + sample = samples[0] + # call it normally to get the expected result + expected = op(*sample.input, *sample.args, **sample.kwargs) + # call it with out=... and check we get the expected result + out_kwargs = sample.kwargs.copy() + out_kwargs['out'] = out = torch.empty_like(expected) + op(*sample.input, *sample.args, **out_kwargs) + self.assertEqual(expected, out) + + instantiate_device_type_tests(TestOpInfo, globals()) instantiate_device_type_tests(TestGradients, globals()) +instantiate_device_type_tests(TestCommon, globals()) if __name__ == '__main__': run_tests() diff --git a/test/test_optim.py b/test/test_optim.py index b00184cc93430..00d3f7a2bd131 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -6,6 +6,7 @@ import torch from torch._six import inf import torch.optim as optim +import torch.optim._multi_tensor as optim_mt import torch.nn.functional as F from torch.optim import SGD from torch.autograd import Variable @@ -249,105 +250,199 @@ def _build_params_dict_single(self, weight, bias, **kwargs): return [dict(params=bias, **kwargs)] def test_sgd(self): - self._test_basic_cases( - lambda weight, bias: optim.SGD([weight, bias], lr=1e-3) - ) - self._test_basic_cases( - lambda weight, bias: optim.SGD( - self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3) - ) - self._test_basic_cases( - lambda weight, bias: optim.SGD( - self._build_params_dict_single(weight, bias, lr=1e-2), - lr=1e-3) - ) - self._test_basic_cases( - lambda weight, bias: optim.SGD( - self._build_params_dict_single(weight, bias, lr=1e-2)) - ) - self._test_basic_cases( - lambda weight, bias: optim.SGD([weight, bias], lr=1e-3), - [lambda opt: StepLR(opt, gamma=0.9, step_size=10)] - ) - self._test_basic_cases( - lambda weight, bias: optim.SGD([weight, bias], lr=1e-3), - [lambda opt: StepLR(opt, gamma=0.9, step_size=10), - lambda opt: ReduceLROnPlateau(opt)] - ) - self._test_basic_cases( - lambda weight, bias: optim.SGD([weight, bias], lr=1e-3), - [lambda opt: StepLR(opt, gamma=0.99, step_size=10), - lambda opt: ExponentialLR(opt, gamma=0.99), - lambda opt: ReduceLROnPlateau(opt)] - ) - with self.assertRaisesRegex(ValueError, "Invalid momentum value: -0.5"): - optim.SGD(None, lr=1e-2, momentum=-0.5) + for optimizer in [optim.SGD, optim_mt.SGD]: + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-2), + lr=1e-3) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict_single(weight, bias, lr=1e-2), + lr=1e-3) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict_single(weight, bias, lr=1e-2)) + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3), + [lambda opt: StepLR(opt, gamma=0.9, step_size=10)] + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3), + [lambda opt: StepLR(opt, gamma=0.9, step_size=10), + lambda opt: ReduceLROnPlateau(opt)] + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3), + [lambda opt: StepLR(opt, gamma=0.99, step_size=10), + lambda opt: ExponentialLR(opt, gamma=0.99), + lambda opt: ReduceLROnPlateau(opt)] + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3, momentum=1) + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3, momentum=1, weight_decay=1) + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], nesterov=True, lr=1e-3, momentum=1, weight_decay=1) + ) + with self.assertRaisesRegex(ValueError, "Invalid momentum value: -0.5"): + optimizer(None, lr=1e-2, momentum=-0.5) def test_sgd_sparse(self): - self._test_rosenbrock_sparse( - lambda params: optim.SGD(params, lr=5e-3) - ) - self._test_rosenbrock_sparse( - lambda params: optim.SGD(params, lr=0.005), - [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)] - ) + for optimizer in [optim.SGD, optim_mt.SGD]: + self._test_rosenbrock_sparse( + lambda params: optimizer(params, lr=5e-3) + ) + self._test_rosenbrock_sparse( + lambda params: optimizer(params, lr=0.005), + [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)] + ) - def test_adam(self): - self._test_basic_cases( - lambda weight, bias: optim.Adam([weight, bias], lr=1e-3) - ) - self._test_basic_cases( - lambda weight, bias: optim.Adam( - self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3) - ) - self._test_basic_cases( - lambda weight, bias: optim.Adam([weight, bias], lr=1e-3, - amsgrad=True) - ) - self._test_basic_cases( - lambda weight, bias: optim.Adam( - self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3, amsgrad=True) - ) - self._test_basic_cases( - lambda weight, bias: optim.Adam( - self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3), - [lambda opt: ExponentialLR(opt, gamma=0.9)] - ) - self._test_basic_cases( - lambda weight, bias: optim.Adam([weight, bias], lr=1e-3, - amsgrad=True), - [lambda opt: ExponentialLR(opt, gamma=0.9), - lambda opt: ReduceLROnPlateau(opt)] - ) - self._test_basic_cases( - lambda weight, bias: optim.Adam( - self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3, amsgrad=True), - [lambda opt: StepLR(opt, gamma=0.9, step_size=10), - lambda opt: ReduceLROnPlateau(opt)] - ) - with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"): - optim.Adam(None, lr=1e-2, betas=(1.0, 0.0)) + @skipIfRocm + def test_multi_tensor_optimizers(self): + if not torch.cuda.is_available(): + return - with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"): - optim.Adam(None, lr=1e-2, weight_decay=-1) + optimizer_pairs_with_flags = [ + ((optim.Adam, optim._multi_tensor.Adam), dict(weight_decay=1., amsgrad=True)), + ((optim.Adam, optim._multi_tensor.Adam), dict(weight_decay=1., amsgrad=False)), + ((optim.Adam, optim._multi_tensor.Adam), dict(weight_decay=0., amsgrad=True)), + ((optim.Adam, optim._multi_tensor.Adam), dict(weight_decay=0., amsgrad=False)), + ((optim.AdamW, optim._multi_tensor.AdamW), dict(weight_decay=1., amsgrad=True)), + ((optim.AdamW, optim._multi_tensor.AdamW), dict(weight_decay=1., amsgrad=False)), + ((optim.AdamW, optim._multi_tensor.AdamW), dict(weight_decay=0., amsgrad=True)), + ((optim.AdamW, optim._multi_tensor.AdamW), dict(weight_decay=0., amsgrad=False)), + ((optim.SGD, optim._multi_tensor.SGD), dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True)), + ((optim.SGD, optim._multi_tensor.SGD), dict(lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False)), + ((optim.RMSprop, optim._multi_tensor.RMSprop), dict(weight_decay=1, momentum=1, centered=True)), + ((optim.RMSprop, optim._multi_tensor.RMSprop), dict(weight_decay=1, momentum=0, centered=True)), + ((optim.RMSprop, optim._multi_tensor.RMSprop), dict(weight_decay=1, momentum=1, centered=False)), + ((optim.RMSprop, optim._multi_tensor.RMSprop), dict(weight_decay=0, momentum=1, centered=False)), + ((optim.Rprop, optim._multi_tensor.Rprop), dict(lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50))), + ((optim.ASGD, optim._multi_tensor.ASGD), dict(weight_decay=0)), + ((optim.ASGD, optim._multi_tensor.ASGD), dict(weight_decay=1)), + ((optim.Adamax, optim._multi_tensor.Adamax), dict(weight_decay=0)), + ((optim.Adamax, optim._multi_tensor.Adamax), dict(weight_decay=1)), + ((optim.Adadelta, optim._multi_tensor.Adadelta), dict(weight_decay=0)), + ((optim.Adadelta, optim._multi_tensor.Adadelta), dict(weight_decay=1)), + ] + + kIterations = 11 + device = 'cuda' + + for optimizers, params in optimizer_pairs_with_flags: + res = [] + for opt in optimizers: + weight = torch.tensor([[-0.2109, -0.4976], [-0.1413, -0.3420], [-0.2524, 0.6976]], + dtype=torch.float64, device=device, requires_grad=True) + bias = torch.tensor([-0.1085, -0.2979, 0.6892], dtype=torch.float64, device=device, requires_grad=True) + weight2 = torch.tensor([[-0.0508, -0.3941, -0.2843]], + dtype=torch.float64, device=device, requires_grad=True) + bias2 = torch.tensor([-0.0711], dtype=torch.float64, device=device, requires_grad=True) + input = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=torch.float64, device=device).reshape(3, 2) + + model = torch.nn.Sequential(torch.nn.Linear(2, 3), + torch.nn.Sigmoid(), + torch.nn.Linear(3, 1), + torch.nn.Sigmoid()) + model.to(torch.float64).to(device) + + pretrained_dict = model.state_dict() + pretrained_dict['0.weight'] = weight + pretrained_dict['0.bias'] = bias + pretrained_dict['2.weight'] = weight2 + pretrained_dict['2.bias'] = bias2 + model.load_state_dict(pretrained_dict) + + optimizer = opt(model.parameters(), **params) + + for _ in range(kIterations): + optimizer.zero_grad() + output = model(input) + loss = output.sum() + loss.backward() + + if iter == 0: + model.parameters().__next__().grad = None + + optimizer.step() + + res.append(model.parameters()) + + for p1, p2 in zip(res[0], res[1]): + self.assertEqual(p1, p2) - def test_adamw(self): - self._test_basic_cases( - lambda weight, bias: optim.AdamW([weight, bias], lr=1e-3) - ) - self._test_basic_cases( - lambda weight, bias: optim.AdamW( - self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3) - ) - with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"): - optim.AdamW(None, lr=1e-2, weight_decay=-1) + def test_adam(self): + for optimizer in [optim.Adam, optim_mt.Adam]: + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-2), + lr=1e-3) + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True) + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3, weight_decay=0.1) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-2), + lr=1e-3, amsgrad=True) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-2), + lr=1e-3), + [lambda opt: ExponentialLR(opt, gamma=0.9)] + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True), + [lambda opt: ExponentialLR(opt, gamma=0.9), + lambda opt: ReduceLROnPlateau(opt)] + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-2), + lr=1e-3, amsgrad=True), + [lambda opt: StepLR(opt, gamma=0.9, step_size=10), + lambda opt: ReduceLROnPlateau(opt)] + ) + with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"): + optimizer(None, lr=1e-2, betas=(1.0, 0.0)) + + with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"): + optimizer(None, lr=1e-2, weight_decay=-1) + + def test_adamw(self): + for optimizer in [optim.AdamW, optim_mt.AdamW]: + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-2), + lr=1e-3) + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3, weight_decay=1) + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3, weight_decay=1, amsgrad=True) + ) + with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"): + optimizer(None, lr=1e-2, weight_decay=-1) def test_sparse_adam(self): self._test_rosenbrock_sparse( @@ -365,21 +460,25 @@ def test_sparse_adam(self): # ROCm precision is too low to pass this test @skipIfRocm def test_adadelta(self): - self._test_basic_cases( - lambda weight, bias: optim.Adadelta([weight, bias]) - ) - self._test_basic_cases( - lambda weight, bias: optim.Adadelta( - self._build_params_dict(weight, bias, rho=0.95)) - ) - self._test_basic_cases( - lambda weight, bias: optim.Adadelta( - self._build_params_dict(weight, bias, rho=0.95)), - [lambda opt: StepLR(opt, gamma=0.9, step_size=10), - lambda opt: ReduceLROnPlateau(opt)] - ) - with self.assertRaisesRegex(ValueError, "Invalid rho value: 1.1"): - optim.Adadelta(None, lr=1e-2, rho=1.1) + for optimizer in [optim.Adadelta, optim_mt.Adadelta]: + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias]) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, rho=0.95)) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, rho=0.95)), + [lambda opt: StepLR(opt, gamma=0.9, step_size=10), + lambda opt: ReduceLROnPlateau(opt)] + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], weight_decay=1) + ) + with self.assertRaisesRegex(ValueError, "Invalid rho value: 1.1"): + optimizer(None, lr=1e-2, rho=1.1) def test_adagrad(self): self._test_basic_cases( @@ -421,52 +520,84 @@ def test_adagrad_sparse(self): ) def test_adamax(self): - self._test_basic_cases( - lambda weight, bias: optim.Adamax([weight, bias], lr=1e-1) - ) - self._test_basic_cases( - lambda weight, bias: optim.Adamax( - self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-1) - ) - with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 1: 1.0"): - optim.Adamax(None, lr=1e-2, betas=(0.0, 1.0)) + for optimizer in [optim.Adamax, optim_mt.Adamax]: + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-1) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-2), + lr=1e-1) + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-1, weight_decay=1) + ) + with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 1: 1.0"): + optimizer(None, lr=1e-2, betas=(0.0, 1.0)) def test_rmsprop(self): - self._test_basic_cases( - lambda weight, bias: optim.RMSprop([weight, bias], lr=1e-2) - ) - self._test_basic_cases( - lambda weight, bias: optim.RMSprop( - self._build_params_dict(weight, bias, lr=1e-3), - lr=1e-2) - ) - with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"): - optim.RMSprop(None, lr=1e-2, momentum=-1.0) + for optimizer in [optim.RMSprop, optim_mt.RMSprop]: + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-2) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-3), + lr=1e-2) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-3), + lr=1e-2, centered=True) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-3), + lr=1e-2, centered=True, momentum=0.1) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-3), + lr=1e-2, momentum=0.1) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-3), + lr=1e-2, momentum=0.1, weight_decay=1) + ) + with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"): + optimizer(None, lr=1e-2, momentum=-1.0) def test_asgd(self): - self._test_basic_cases( - lambda weight, bias: optim.ASGD([weight, bias], lr=1e-3, t0=100) - ) - self._test_basic_cases( - lambda weight, bias: optim.ASGD( - self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3, t0=100) - ) - with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -0.5"): - optim.ASGD(None, lr=1e-2, weight_decay=-0.5) + for optimizer in [optim.ASGD, optim_mt.ASGD]: + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3, t0=100) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-2), + lr=1e-3, t0=100) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-3), + lr=1e-2, weight_decay=1) + ) + with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -0.5"): + optimizer(None, lr=1e-2, weight_decay=-0.5) def test_rprop(self): - self._test_basic_cases( - lambda weight, bias: optim.Rprop([weight, bias], lr=1e-3) - ) - self._test_basic_cases( - lambda weight, bias: optim.Rprop( - self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3) - ) - with self.assertRaisesRegex(ValueError, "Invalid eta values: 1.0, 0.5"): - optim.Rprop(None, lr=1e-2, etas=(1.0, 0.5)) + for optimizer in [optim.Rprop, optim_mt.Rprop]: + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-2), + lr=1e-3) + ) + with self.assertRaisesRegex(ValueError, "Invalid eta values: 1.0, 0.5"): + optimizer(None, lr=1e-2, etas=(1.0, 0.5)) def test_lbfgs(self): self._test_basic_cases( @@ -747,7 +878,7 @@ def test_step_lr(self): # lr = 0.0005 if epoch >= 9 epochs = 10 single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3 - targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + targets = [single_targets, [x * epochs for x in single_targets]] scheduler = StepLR(self.opt, gamma=0.1, step_size=3) self._test(scheduler, targets, epochs) @@ -766,7 +897,7 @@ def test_get_last_lr_multi_step_lr(self): # lr = 0.00005 if 9 <= epoch epochs = 10 single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 1 - targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + targets = [single_targets, [x * epochs for x in single_targets]] scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) self._test_get_last_lr(scheduler, targets, epochs) @@ -777,7 +908,7 @@ def test_multi_step_lr(self): # lr = 0.00005 if epoch >= 9 epochs = 10 single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3 - targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + targets = [single_targets, [x * epochs for x in single_targets]] scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) self._test(scheduler, targets, epochs) @@ -788,14 +919,14 @@ def test_multi_step_lr_with_epoch(self): # lr = 0.00005 if epoch >= 9 epochs = 10 single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3 - targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + targets = [single_targets, [x * epochs for x in single_targets]] scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) self._test_with_epoch(scheduler, targets, epochs) def test_exp_lr(self): epochs = 10 single_targets = [0.05 * (0.9 ** x) for x in range(epochs)] - targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + targets = [single_targets, [x * epochs for x in single_targets]] scheduler = ExponentialLR(self.opt, gamma=0.9) self._test(scheduler, targets, epochs) @@ -805,7 +936,7 @@ def test_cos_anneal_lr(self): single_targets = [eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 for x in range(epochs)] - targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + targets = [single_targets, [x * epochs for x in single_targets]] scheduler = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) self._test(scheduler, targets, epochs) @@ -927,7 +1058,7 @@ def test_compound_step_and_exp_lr(self): single_targets += [0.005 * (0.9 ** x) for x in range(3, 6)] single_targets += [0.0005 * (0.9 ** x) for x in range(6, 9)] single_targets += [0.00005 * (0.9 ** x) for x in range(9, 12)] - targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + targets = [single_targets, [x * epochs for x in single_targets]] schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) schedulers[1] = ExponentialLR(self.opt, gamma=0.9) self._test(schedulers, targets, epochs) @@ -939,7 +1070,7 @@ def test_compound_exp_and_multistep_lr(self): single_targets += [0.005 * (0.9 ** x) for x in range(2, 5)] single_targets += [0.0005 * (0.9 ** x) for x in range(5, 9)] single_targets += [0.00005 * (0.9 ** x) for x in range(9, 11)] - targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + targets = [single_targets, [x * epochs for x in single_targets]] schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) schedulers[1] = ExponentialLR(self.opt, gamma=0.9) self._test(schedulers, targets, epochs) @@ -951,7 +1082,7 @@ def test_compound_cosanneal_and_step_lr(self): (1 + math.cos(math.pi * x / epochs)) / 2 for x in range(epochs)] single_targets = [x * 0.1 ** (i // 3) for i, x in enumerate(single_targets)] - targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + targets = [single_targets, [x * epochs for x in single_targets]] schedulers = [None] * 2 schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3) @@ -965,7 +1096,7 @@ def test_compound_cosanneal_and_multistep_lr(self): for x in range(epochs)] multipliers = [1] * 2 + [0.1] * 3 + [0.01] * 4 + [0.001] single_targets = [x * y for x, y in zip(single_targets, multipliers)] - targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + targets = [single_targets, [x * epochs for x in single_targets]] schedulers = [None] * 2 schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) @@ -979,7 +1110,7 @@ def test_compound_cosanneal_and_exp_lr(self): for x in range(epochs)] multipliers = [0.1 ** i for i in range(epochs)] single_targets = [x * y for x, y in zip(single_targets, multipliers)] - targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + targets = [single_targets, [x * epochs for x in single_targets]] schedulers = [None] * 2 schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) schedulers[1] = ExponentialLR(self.opt, gamma=0.1) @@ -1091,8 +1222,8 @@ def test_cycle_lr_exp_range_mode_one_lr(self): diff_lr = max_lr - base_lr gamma = 0.9 xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1] - lr_target = list(map(lambda x: base_lr + x[1] * diff_lr * gamma**x[0], enumerate(xs))) - momentum_target = list(map(lambda x: max_lr - x[1] * diff_lr * gamma**x[0], enumerate(xs))) + lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)] + momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)] lr_targets = [lr_target, lr_target] momentum_targets = [momentum_target, momentum_target] scheduler = CyclicLR(self.opt, base_lr=base_lr, @@ -1103,10 +1234,10 @@ def test_cycle_lr_exp_range_mode_one_lr(self): def test_cycle_lr_triangular_mode(self): lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] - lr_target_2 = list(map(lambda x: x + 1, lr_target_1)) + lr_target_2 = [x + 1 for x in lr_target_1] lr_targets = [lr_target_1, lr_target_2] momentum_target_1 = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3] - momentum_target_2 = list(map(lambda x: x + 1, momentum_target_1)) + momentum_target_2 = [x + 1 for x in momentum_target_1] momentum_targets = [momentum_target_1, momentum_target_2] scheduler = CyclicLR(self.opt, base_lr=[1, 2], max_lr=[5, 6], step_size_up=4, cycle_momentum=True, base_momentum=[1, 2], max_momentum=[5, 6], @@ -1116,11 +1247,11 @@ def test_cycle_lr_triangular_mode(self): def test_cycle_lr_triangular2_mode(self): lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 1.5, 2.0, 2.5, 3.0, 2.5, 2.0, 1.5, 1, 1.25, 1.50, 1.75, 2.00, 1.75] - lr_target_2 = list(map(lambda x: x + 2, lr_target_1)) + lr_target_2 = [x + 2 for x in lr_target_1] lr_targets = [lr_target_1, lr_target_2] momentum_target_1 = [5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.5, 4.0, 3.5, 3.0, 3.5, 4.0, 4.5, 5.0, 4.75, 4.5, 4.25, 4.0, 4.25] - momentum_target_2 = list(map(lambda x: x + 2, momentum_target_1)) + momentum_target_2 = [x + 2 for x in momentum_target_1] momentum_targets = [momentum_target_1, momentum_target_2] scheduler = CyclicLR(self.opt, base_lr=[1, 3], max_lr=[5, 7], step_size_up=4, cycle_momentum=True, base_momentum=[1, 3], max_momentum=[5, 7], @@ -1136,11 +1267,11 @@ def test_cycle_lr_exp_range_mode(self): gamma = 0.9 xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1] - lr_target_1 = list(map(lambda x: base_lr_1 + x[1] * diff_lr_1 * gamma**x[0], enumerate(xs))) - lr_target_2 = list(map(lambda x: base_lr_2 + x[1] * diff_lr_2 * gamma**x[0], enumerate(xs))) + lr_target_1 = [base_lr_1 + x * diff_lr_1 * gamma**i for i, x in enumerate(xs)] + lr_target_2 = [base_lr_2 + x * diff_lr_2 * gamma**i for i, x in enumerate(xs)] lr_targets = [lr_target_1, lr_target_2] - momentum_target_1 = list(map(lambda x: max_lr_1 - x[1] * diff_lr_1 * gamma**x[0], enumerate(xs))) - momentum_target_2 = list(map(lambda x: max_lr_2 - x[1] * diff_lr_2 * gamma**x[0], enumerate(xs))) + momentum_target_1 = [max_lr_1 - x * diff_lr_1 * gamma**i for i, x in enumerate(xs)] + momentum_target_2 = [max_lr_2 - x * diff_lr_2 * gamma**i for i, x in enumerate(xs)] momentum_targets = [momentum_target_1, momentum_target_2] scheduler = CyclicLR(self.opt, base_lr=[base_lr_1, base_lr_2], max_lr=[max_lr_1, max_lr_2], step_size_up=4, @@ -1259,6 +1390,18 @@ def test_onecycle_lr_linear_annealing(self): total_steps=10, anneal_strategy='linear') self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) + def test_onecycle_lr_linear_annealing_three_phases(self): + lr_target = [1, 9, 17, 25, 17, 9, 1, 0.75, 0.5, 0.25] + momentum_target = [22, 15, 8, 1, 8, 15, 22, 22, 22, 22] + lr_targets = [lr_target, lr_target] + momentum_targets = [momentum_target, momentum_target] + scheduler = OneCycleLR(self.opt, max_lr=25, div_factor=25, + base_momentum=1, max_momentum=22, + total_steps=10, anneal_strategy='linear', + pct_start=0.4, final_div_factor=4, + three_phase=True) + self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) + def test_onecycle_lr_cosine_annealing(self): def annealing_cos(start, end, pct): cos_out = math.cos(math.pi * pct) + 1 diff --git a/test/test_overrides.py b/test/test_overrides.py index b48d9056731f2..d9812d43300c9 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -1,11 +1,10 @@ import torch import numpy as np -import unittest import inspect import functools import pprint -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import TestCase, run_tests from torch.overrides import ( handle_torch_function, has_torch_function, @@ -563,6 +562,8 @@ def instance_gen(): func_args.append(instance_gen()) elif t == 'TensorList': func_args.append([instance_gen(), instance_gen()]) + elif t == 'c10::List>': + func_args.append([instance_gen(), instance_gen()]) elif t == 'IntArrayRef': size = arg.get('size', 2) if size == 1: @@ -575,6 +576,8 @@ def instance_gen(): func_args.append(False) elif t.startswith('int') or t in {'Dimname', 'DimnameList'}: func_args.append(0) + elif t in {'Stream'}: + func_args.append(torch.Stream()) elif t.startswith('float') or t == 'double': func_args.append(1.0) elif t in {'Generator', 'MemoryFormat', 'TensorOptions'}: @@ -587,6 +590,16 @@ def instance_gen(): raise RuntimeError(f"Unsupported argument type {t} for {arg['name']} of function {func}") else: args = inspect.getfullargspec(override) + try: + func_args = inspect.getfullargspec(func) + # Remove annotations from argspec + func_args = type(func_args)(**{**func_args, 'annotations': None}) + if func_args != args: + raise RuntimeError(f"Override for {func} doesn't match its argspec.\n" + + f"Original: {inspect.signature(func)}\n" + + f"Override: {inspect.signature(override)}") + except TypeError: + pass nargs = len(args.args) if args.defaults is not None: nargs -= len(args.defaults) @@ -760,52 +773,110 @@ def test_wrapper(self): self.assertTrue(torch.allclose(torch.einsum('ik,jkl,il->ij', [a, b, c]), torch.nn.functional.bilinear(a, c, b))) -# TODO(@anjali411): re-enable this test -# class TestGradCheckOverride(TestCase): -# "Test that wrappers work with gradcheck." -# def test_gradcheck(self): -# from torch.autograd import gradcheck - -# a = wrap(torch.tensor(5.0, dtype=torch.double)) -# b = wrap(torch.tensor(6.0, dtype=torch.double)) - -# a.requires_grad = True -# b.requires_grad = True - -# gradcheck(torch.add, (a, b), raise_exception=False) - -# total_used_attrs = a.used_attrs.union(b.used_attrs) -# total_used_calls = a.used_calls.union(b.used_calls) - -# # These attributes (and the functions below) may change -# # if the gradcheck implementation changes. It's best to -# # aim for attributes that may be commonly present on other -# # Tensor-likes. -# self.assertEqual(total_used_attrs, { -# 'data', -# 'dtype', -# 'is_floating_point', -# 'is_sparse', -# 'layout', -# 'nelement', -# 'new_zeros', -# 'requires_grad', -# 'retain_grad', -# 'size', -# 'stride', -# }) - -# self.assertEqual(total_used_calls, { -# torch.Tensor.new_zeros, -# torch.Tensor.size, -# torch.Tensor.is_floating_point, -# torch.Tensor.nelement, -# torch.Tensor.retain_grad, -# torch.Tensor.stride, -# torch.autograd.grad, -# torch.add, -# }) - +class TestGradCheckOverride(TestCase): + "Test that wrappers work with gradcheck." + def test_gradcheck(self): + from torch.autograd import gradcheck, gradgradcheck + + a = wrap(torch.tensor(5.0, dtype=torch.double)) + b = wrap(torch.tensor(6.0, dtype=torch.double)) + + a.requires_grad = True + b.requires_grad = True + + gradcheck(torch.add, (a, b), raise_exception=False) + gradgradcheck(torch.add, (a, b), raise_exception=False) + + total_used_attrs = a.used_attrs.union(b.used_attrs) + total_used_calls = a.used_calls.union(b.used_calls) + + # These attributes (and the functions below) may change + # if the gradcheck implementation changes. It's best to + # aim for attributes that may be commonly present on other + # Tensor-likes. + self.assertEqual(total_used_attrs, { + 'data', + 'device', + 'dtype', + 'is_complex', + 'is_floating_point', + 'is_sparse', + 'layout', + 'nelement', + 'new_zeros', + 'requires_grad', + 'retain_grad', + 'size', + 'stride', + }) + + self.assertEqual(total_used_calls, { + torch.Tensor.new_zeros, + torch.Tensor.size, + torch.Tensor.is_complex, + torch.Tensor.is_floating_point, + torch.Tensor.nelement, + torch.Tensor.retain_grad, + torch.Tensor.stride, + torch.autograd.grad, + torch.add, + }) + +class TestNamedTuple(TestCase): + """ Regression test for gh-47090 """ + def test_max(self): + x = torch.tensor([1, 2]) + xs = x.as_subclass(SubTensor2) + r = torch.max(x, dim=0) + rs = torch.max(xs, dim=0) + self.assertEqual(type(r), type(rs)) + self.assertEqual(r, rs) + +class TestGradNewOnesOverride(TestCase): + """ Regression test for gh-47069 """ + def test_newones(self): + t = torch.tensor([1, 2]).as_subclass(SubTensor2) + n = t.new_ones((1, 2)) + self.assertEqual(type(n), SubTensor2) + + +class TestBroadcastAllOverride(TestCase): + """ test for gh-37141 """ + def test_broadcast_all(self): + from torch.distributions.utils import broadcast_all + a = torch.tensor([1.2, 3.4, 5.6]) + a_w = Wrapper(a) + b = torch.tensor(5.0) + b_w = Wrapper(b) + c = torch.tensor([5.0, 5.0, 5.0]) + + o_1 = broadcast_all(a_w, b_w) + self.assertTrue(isinstance(o_1[0], Wrapper)) + self.assertTrue(isinstance(o_1[1], Wrapper)) + self.assertEqual(o_1[0]._data, a) + self.assertEqual(o_1[1]._data, c) + + o_2 = broadcast_all(a_w, b) + self.assertTrue(isinstance(o_2[0], Wrapper)) + self.assertTrue(isinstance(o_2[1], Wrapper)) + self.assertEqual(o_2[0]._data, a) + self.assertEqual(o_2[1]._data, c) + +class TestWrapTorchFunction(TestCase): + def test_wrap_torch_function(self): + class A: + @classmethod + def __torch_function__(cls, func, types, args, kwargs): + return -1 + + def dispatcher(a): + return (a,) + + @torch.overrides.wrap_torch_function(dispatcher) + def f(a): + return a + + self.assertEqual(f(A()), -1) if __name__ == '__main__': - unittest.main() + run_tests() diff --git a/test/test_package.py b/test/test_package.py index a25726a53c00b..7f669bb983225 100644 --- a/test/test_package.py +++ b/test/test_package.py @@ -1,11 +1,14 @@ -from unittest import main, skipIf -from torch.testing._internal.common_utils import TestCase, IS_WINDOWS +from unittest import skipIf +from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS from tempfile import NamedTemporaryFile from torch.package import PackageExporter, PackageImporter +from torch.package._mangling import PackageMangler, demangle, is_mangled, get_mangle_prefix from pathlib import Path from tempfile import TemporaryDirectory import torch from sys import version_info +from io import StringIO +import pickle try: from torchvision.models import resnet18 @@ -118,7 +121,9 @@ def test_resources(self): def test_extern(self): filename = self.temp() with PackageExporter(filename, verbose=False) as he: - he.extern_modules(['package_a.subpackage', 'module_a']) + he.extern(['package_a.subpackage', 'module_a']) + he.require_module('package_a.subpackage') + he.require_module('module_a') he.save_module('package_a') hi = PackageImporter(filename) import package_a.subpackage @@ -132,12 +137,118 @@ def test_extern(self): self.assertIsNot(package_a, package_a_im) self.assertIs(package_a.subpackage, package_a_im.subpackage) - @skipIf(version_info.major < 3 or version_info.minor < 7, 'mock uses __getattr__ a 3.7 feature') + def test_extern_glob(self): + filename = self.temp() + with PackageExporter(filename, verbose=False) as he: + he.extern(['package_a.*', 'module_*']) + he.save_module('package_a') + he.save_source_string('test_module', """\ +import package_a.subpackage +import module_a +""") + hi = PackageImporter(filename) + import package_a.subpackage + import module_a + + module_a_im = hi.import_module('module_a') + hi.import_module('package_a.subpackage') + package_a_im = hi.import_module('package_a') + + self.assertIs(module_a, module_a_im) + self.assertIsNot(package_a, package_a_im) + self.assertIs(package_a.subpackage, package_a_im.subpackage) + + def test_save_imported_module_fails(self): + """ + Directly saving/requiring an PackageImported module should raise a specific error message. + """ + import package_a.subpackage + obj = package_a.subpackage.PackageASubpackageObject() + obj2 = package_a.PackageAObject(obj) + f1 = self.temp() + with PackageExporter(f1, verbose=False) as pe: + pe.save_pickle("obj", "obj.pkl", obj) + + importer1 = PackageImporter(f1) + loaded1 = importer1.load_pickle("obj", "obj.pkl") + + f2 = self.temp() + pe = PackageExporter(f2, verbose=False) + pe.importers.insert(0, importer1.import_module) + with self.assertRaisesRegex(ModuleNotFoundError, 'torch.package'): + pe.require_module(loaded1.__module__) + with self.assertRaisesRegex(ModuleNotFoundError, 'torch.package'): + pe.save_module(loaded1.__module__) + + def test_exporting_mismatched_code(self): + """ + If an object with the same qualified name is loaded from different + packages, the user should get an error if they try to re-save the + object with the wrong package's source code. + """ + import package_a.subpackage + obj = package_a.subpackage.PackageASubpackageObject() + obj2 = package_a.PackageAObject(obj) + f1 = self.temp() + with PackageExporter(f1, verbose=False) as pe: + pe.save_pickle("obj", "obj.pkl", obj2) + + importer1 = PackageImporter(f1) + loaded1 = importer1.load_pickle("obj", "obj.pkl") + importer2 = PackageImporter(f1) + loaded2 = importer2.load_pickle("obj", "obj.pkl") + + f2 = self.temp() + + def make_exporter(): + pe = PackageExporter(f2, verbose=False) + # Ensure that the importer finds the 'PackageAObject' defined in 'importer1' first. + pe.importers.insert(0, importer1.import_module) + return pe + + # This should fail. The 'PackageAObject' type defined from 'importer1' + # is not necessarily the same 'obj2's version of 'PackageAObject'. + pe = make_exporter() + with self.assertRaises(pickle.PicklingError): + pe.save_pickle("obj", "obj.pkl", obj2) + + # This should also fail. The 'PackageAObject' type defined from 'importer1' + # is not necessarily the same as the one defined from 'importer2' + pe = make_exporter() + with self.assertRaises(pickle.PicklingError): + pe.save_pickle("obj", "obj.pkl", loaded2) + + # This should succeed. The 'PackageAObject' type defined from + # 'importer1' is a match for the one used by loaded1. + pe = make_exporter() + pe.save_pickle("obj", "obj.pkl", loaded1) + + def test_unique_module_names(self): + import package_a.subpackage + obj = package_a.subpackage.PackageASubpackageObject() + obj2 = package_a.PackageAObject(obj) + f1 = self.temp() + with PackageExporter(f1, verbose=False) as pe: + pe.save_pickle("obj", "obj.pkl", obj2) + + importer1 = PackageImporter(f1) + loaded1 = importer1.load_pickle("obj", "obj.pkl") + importer2 = PackageImporter(f1) + loaded2 = importer2.load_pickle("obj", "obj.pkl") + + # Modules from loaded packages should not shadow the names of modules. + # See mangling.md for more info. + self.assertNotEqual(type(obj2).__module__, type(loaded1).__module__) + self.assertNotEqual(type(loaded1).__module__, type(loaded2).__module__) + + @skipIf(version_info < (3, 7), 'mock uses __getattr__ a 3.7 feature') def test_mock(self): filename = self.temp() with PackageExporter(filename, verbose=False) as he: - he.mock_modules(['package_a.subpackage', 'module_a']) + he.mock(['package_a.subpackage', 'module_a']) he.save_module('package_a') + he.require_module('package_a.subpackage') + he.require_module('module_a') hi = PackageImporter(filename) import package_a.subpackage _ = package_a.subpackage @@ -149,14 +260,35 @@ def test_mock(self): with self.assertRaisesRegex(NotImplementedError, 'was mocked out'): r() - @skipIf(version_info.major < 3 or version_info.minor < 7, 'mock uses __getattr__ a 3.7 feature') + @skipIf(version_info < (3, 7), 'mock uses __getattr__ a 3.7 feature') + def test_mock_glob(self): + filename = self.temp() + with PackageExporter(filename, verbose=False) as he: + he.mock(['package_a.*', 'module*']) + he.save_module('package_a') + he.save_source_string('test_module', """\ +import package_a.subpackage +import module_a +""") + hi = PackageImporter(filename) + import package_a.subpackage + _ = package_a.subpackage + import module_a + _ = module_a + + m = hi.import_module('package_a.subpackage') + r = m.result + with self.assertRaisesRegex(NotImplementedError, 'was mocked out'): + r() + + @skipIf(version_info < (3, 7), 'mock uses __getattr__ a 3.7 feature') def test_custom_requires(self): filename = self.temp() class Custom(PackageExporter): def require_module(self, name, dependencies): if name == 'module_a': - self.mock_module('module_a') + self.save_mock_module('module_a') elif name == 'package_a': self.save_source_string('package_a', 'import module_a\nresult = 5\n') else: @@ -183,6 +315,11 @@ def test_resnet(self): # the objects in the pickle e.save_pickle('model', 'model.pkl', resnet) + # check th debug graph has something reasonable: + buf = StringIO() + debug_graph = e._write_dep_graph(failing_module='torch') + self.assertIn('torchvision.models.resnet', debug_graph) + # we can now load the saved model i = PackageImporter(f1) r2 = i.load_pickle('model', 'model.pkl') @@ -305,5 +442,115 @@ def load(): self.assertTrue(torch.allclose(*results)) + @skipIfNoTorchVision + def test_script_resnet(self): + resnet = resnet18() + + f1 = self.temp() + # Option 1: save by pickling the whole model + # + single-line, similar to torch.jit.save + # - more difficult to edit the code after the model is created + with PackageExporter(f1, verbose=False) as e: + e.save_pickle('model', 'pickled', resnet) + + i = PackageImporter(f1) + loaded = i.load_pickle('model', 'pickled') + torch.jit.script(loaded) + + + def test_module_glob(self): + from torch.package.exporter import _GlobGroup + + def check(include, exclude, should_match, should_not_match): + x = _GlobGroup(include, exclude) + for e in should_match: + self.assertTrue(x.matches(e)) + for e in should_not_match: + self.assertFalse(x.matches(e)) + + check('torch.*', [], ['torch.foo', 'torch.bar'], ['tor.foo', 'torch.foo.bar', 'torch']) + check('torch.**', [], ['torch.foo', 'torch.bar', 'torch.foo.bar', 'torch'], ['what.torch', 'torchvision']) + check('torch.*.foo', [], ['torch.w.foo'], ['torch.hi.bar.baz']) + check('torch.**.foo', [], ['torch.w.foo', 'torch.hi.bar.foo'], ['torch.f.foo.z']) + check('torch*', [], ['torch', 'torchvision'], ['torch.f']) + check('torch.**', ['torch.**.foo'], ['torch', 'torch.bar', 'torch.barfoo'], ['torch.foo', 'torch.some.foo']) + check('**.torch', [], ['torch', 'bar.torch'], ['visiontorch']) + + @skipIf(version_info < (3, 7), 'mock uses __getattr__ a 3.7 feature') + def test_pickle_mocked(self): + import package_a.subpackage + obj = package_a.subpackage.PackageASubpackageObject() + obj2 = package_a.PackageAObject(obj) + + filename = self.temp() + with PackageExporter(filename, verbose=False) as he: + he.mock(include='package_a.subpackage') + he.save_pickle('obj', 'obj.pkl', obj2) + + hi = PackageImporter(filename) + with self.assertRaises(NotImplementedError): + hi.load_pickle('obj', 'obj.pkl') + + +class ManglingTest(TestCase): + def test_unique_manglers(self): + """ + Each mangler instance should generate a unique mangled name for a given input. + """ + a = PackageMangler() + b = PackageMangler() + self.assertNotEqual(a.mangle("foo.bar"), b.mangle("foo.bar")) + + def test_mangler_is_consistent(self): + """ + Mangling the same name twice should produce the same result. + """ + a = PackageMangler() + self.assertEqual(a.mangle("abc.def"), a.mangle("abc.def")) + + def test_roundtrip_mangling(self): + a = PackageMangler() + self.assertEqual("foo", demangle(a.mangle("foo"))) + + def test_is_mangled(self): + a = PackageMangler() + b = PackageMangler() + self.assertTrue(is_mangled(a.mangle("foo.bar"))) + self.assertTrue(is_mangled(b.mangle("foo.bar"))) + + self.assertFalse(is_mangled("foo.bar")) + self.assertFalse(is_mangled(demangle(a.mangle("foo.bar")))) + + def test_demangler_multiple_manglers(self): + """ + PackageDemangler should be able to demangle name generated by any PackageMangler. + """ + a = PackageMangler() + b = PackageMangler() + + self.assertEqual("foo.bar", demangle(a.mangle("foo.bar"))) + self.assertEqual("bar.foo", demangle(b.mangle("bar.foo"))) + + def test_mangle_empty_errors(self): + a = PackageMangler() + with self.assertRaises(AssertionError): + a.mangle("") + + def test_demangle_base(self): + """ + Demangling a mangle parent directly should currently return an empty string. + """ + a = PackageMangler() + mangled = a.mangle("foo") + mangle_parent = mangled.partition(".")[0] + self.assertEqual("", demangle(mangle_parent)) + + def test_mangle_prefix(self): + a = PackageMangler() + mangled = a.mangle("foo.bar") + mangle_prefix = get_mangle_prefix(mangled) + self.assertEqual(mangle_prefix + "." + "foo.bar", mangled) + + if __name__ == '__main__': - main() + run_tests() diff --git a/test/test_profiler.py b/test/test_profiler.py index aefdfbb937faf..826a9f5d0b57b 100644 --- a/test/test_profiler.py +++ b/test/test_profiler.py @@ -1,24 +1,31 @@ import collections import gc +import io import unittest import torch +import torch.nn as nn +import torch.optim +import torch.utils.data from torch.testing._internal.common_utils import ( - TestCase, run_tests, TEST_WITH_ASAN, IS_WINDOWS) + TestCase, run_tests, TEST_WITH_ASAN, IS_WINDOWS, TemporaryFileName) +import torch.autograd.profiler as profiler from torch.autograd.profiler import profile +from torch.autograd import kineto_available try: import psutil HAS_PSUTIL = True except ImportError: HAS_PSUTIL = False +import pickle @unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run") @unittest.skipIf(TEST_WITH_ASAN, "Cannot test with ASAN") @unittest.skipIf(IS_WINDOWS, "Test is flaky on Windows") @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") -class TestProfiler_cuda(TestCase): +class TestProfilerCUDA(TestCase): def test_mem_leak(self): """Checks that there's no memory leak when using profiler with CUDA """ @@ -44,5 +51,258 @@ def test_mem_leak(self): self.assertTrue(not (is_increasing and max_diff > 100 * 1024), msg='memory usage is increasing, {}'.format(str(last_rss))) +class TestProfiler(TestCase): + def test_source(self): + """Checks that source code attribution works for eager, TS and autograd mode + """ + # avoid automatic inlining + prev_opt = torch._C._get_graph_executor_optimize() + torch._C._set_graph_executor_optimize(False) + + @torch.jit.script + def ts_method_2(x, y): + return torch.matmul(x, y) + + @torch.jit.script + def ts_method_1(x, y, z): + a = x + z + w = ts_method_2(x, y) + a + return w.sum() + + class DummyModule(nn.Module): + def __init__(self): + super(DummyModule, self).__init__() + self.conv = torch.nn.Conv2d(3, 2, kernel_size=1, stride=2, padding=3, bias=False) + + def forward(self, x): + return self.conv(x) + + mod = DummyModule() + + with profile(with_stack=True, use_kineto=kineto_available()) as p: + x = torch.randn(10, 10, requires_grad=True) + y = torch.randn(10, 10, requires_grad=True) + z = x + y + w = ts_method_1(x, y, z) + v = 2 * w + v.backward() + a = torch.randn(2, 3, 2, 2, requires_grad=True) + b = mod(a) + c = b.sum() + c.backward() + + print(p.key_averages( + group_by_stack_n=5).table( + sort_by="self_cpu_time_total", row_limit=-1)) + + for e in p.function_events: + if "aten::add" in e.name or "AddBackward" in e.name: + self.assertTrue(any(["test_profiler" in entry for entry in e.stack])) + self.assertTrue(any([( + "test_source" in entry or + "ts_method_1" in entry or + "ts_method_2" in entry) for entry in e.stack])) + + torch._C._set_graph_executor_optimize(prev_opt) + + def payload(self): + x = torch.randn(10, 10).cuda() + y = torch.randn(10, 10).cuda() + z = torch.mm(x, y) + z = z + y + z = z.cpu() + + @unittest.skipIf(not kineto_available(), "Kineto is required") + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") + def test_kineto(self): + with profile(use_cuda=True, use_kineto=True): + self.payload() + + # rerun to avoid initial start overhead + with profile(use_cuda=True, use_kineto=True) as p: + self.payload() + print(p.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + found_gemm = False + found_memcpy = False + for e in p.function_events: + if "gemm" in e.name: + found_gemm = True + if "Memcpy" in e.name or "memcpy" in e.name: + found_memcpy = True + self.assertTrue(found_gemm) + self.assertTrue(found_memcpy) + # p.export_chrome_trace("/tmp/test_trace.json") + + def test_high_level_trace(self): + """Checks that python side high level events are recorded. + """ + class RepeatedDataset(torch.utils.data.Dataset): + def __init__(self, N, D_in, D_out): + self.N = N + self.x = torch.randn(N, D_in) + self.y = torch.randn(N, D_out) + + def __len__(self): + return self.N + + def __getitem__(self, idx): + return self.x, self.y + + class TwoLayerNet(torch.nn.Module): + def __init__(self, D_in, H, D_out): + super(TwoLayerNet, self).__init__() + self.linear1 = torch.nn.Linear(D_in, H) + self.linear2 = torch.nn.Linear(H, D_out) + + def forward(self, x): + h_relu = self.linear1(x).clamp(min=0) + y_pred = self.linear2(h_relu) + return y_pred + + class CustomSGD(torch.optim.SGD): + def __init__(self, *args, **kwargs): + super(CustomSGD, self).__init__(*args, **kwargs) + + def train(): + for _, data in enumerate(dataloader): + x, y = data[0], data[1] + y_pred = model(x) + loss = criterion(y_pred, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + N, D_in, H, D_out = 8, 10, 5, 2 + model = TwoLayerNet(D_in, H, D_out) + criterion = torch.nn.MSELoss(reduction='sum') + optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) + ds = RepeatedDataset(N, D_in, D_out) + dataloader = torch.utils.data.DataLoader(ds, batch_size=1) + + try: + train() + except Exception: + self.assertTrue(False, "Expected no exception without profiling.") + + # Create multiple instances, expect each func is hooked only one time. + # Nested wrappers(repeated patching) will make following test fail. + optimizer_duplicate = torch.optim.SGD(model.parameters(), lr=1e-4) + dataloader_duplicate = torch.utils.data.DataLoader(ds, batch_size=1) + + def judge(expected_event_count, prof): + actual_event_count = {} + for e in prof.function_events: + if "#" in e.name: + key = e.name + if key in expected_event_count.keys(): + actual_event_count[key] = actual_event_count.setdefault(key, 0) + 1 + for key, count in expected_event_count.items(): + self.assertTrue((key in actual_event_count.keys()) and (count == actual_event_count[key])) + + with profile() as prof: + train() + expected_event_count = { + # "+1" because the final iteration will enter __next__ but skip the loop body. + "enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__": (N + 1), + "Optimizer.step#SGD.step": N, + "Optimizer.zero_grad#SGD.zero_grad": N + } + judge(expected_event_count, prof) + + # Test on pickle/unpickle. Expect to work in multi-processing. + optimizer = pickle.loads(pickle.dumps(optimizer)) + with profile() as prof: + train() + judge(expected_event_count, prof) + + # Test on customized optimizer. + optimizer = CustomSGD(model.parameters(), lr=1e-4) + with profile() as prof: + train() + expected_event_count = { + "enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__": (N + 1), + "Optimizer.step#CustomSGD.step": N, + "Optimizer.zero_grad#CustomSGD.zero_grad": N + } + judge(expected_event_count, prof) + + def test_flops(self): + model = torch.nn.Sequential( + nn.Conv2d(16, 33, 18), + nn.ReLU(), + nn.Linear(243, 243), + nn.ReLU(), + ) + inputs = torch.randn(40, 16, 18, 260) + with profiler.profile(record_shapes=True, with_flops=True) as prof: + model(inputs) + profiler_output = prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10) + print(profiler_output) + self.assertIn("FLOPS", profiler_output) + + @unittest.skipIf(not kineto_available(), "Kineto is required") + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") + def test_kineto_profiler_api(self): + called_num = [0] + + with profile(use_cuda=True, use_kineto=True): + self.payload() + + def trace_handler(p): + print(p.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + # p.export_chrome_trace("/tmp/test_trace_" + str(called_num[0]) + ".json") + called_num[0] += 1 + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA], + schedule=torch.profiler.schedule( + wait=1, + warmup=1, + active=2), + on_trace_ready=trace_handler + ) as p: + for idx in range(8): + self.payload() + p.next_step() + + self.assertEqual(called_num[0], 2) + + # case without enable_pred + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA] + ) as p: + self.payload() + self.payload() + print(p.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + + def test_export_stacks(self): + with profile(with_stack=True, use_kineto=kineto_available()) as p: + x = torch.randn(10, 10) + y = torch.randn(10, 10) + z = torch.mm(x, y) + z = z + y + + with TemporaryFileName(mode="w+") as fname: + p.export_stacks(fname) + with io.open(fname, 'r') as f: + lines = f.readlines() + assert len(lines) > 0, "Empty stacks file" + for line in lines: + is_int = False + try: + assert int(line.split(" ")[-1]) > 0, "Invalid stacks record" + is_int = True + except ValueError: + pass + assert is_int, "Invalid stacks record" + + if __name__ == '__main__': run_tests() diff --git a/test/test_pytree.py b/test/test_pytree.py new file mode 100644 index 0000000000000..1a3e69dbc0753 --- /dev/null +++ b/test/test_pytree.py @@ -0,0 +1,155 @@ +import torch +from torch.testing._internal.common_utils import TestCase, run_tests +from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec, LeafSpec +from torch.utils._pytree import _broadcast_to_and_flatten + +class TestPytree(TestCase): + def test_treespec_equality(self): + self.assertTrue(LeafSpec() == LeafSpec()) + self.assertTrue(TreeSpec(list, None, []) == TreeSpec(list, None, [])) + self.assertTrue(TreeSpec(list, None, [LeafSpec()]) == TreeSpec(list, None, [LeafSpec()])) + self.assertFalse(TreeSpec(tuple, None, []) == TreeSpec(list, None, [])) + self.assertTrue(TreeSpec(tuple, None, []) != TreeSpec(list, None, [])) + + def test_flatten_unflatten_leaf(self): + def run_test_with_leaf(leaf): + values, treespec = tree_flatten(leaf) + self.assertEqual(values, [leaf]) + self.assertEqual(treespec, LeafSpec()) + + unflattened = tree_unflatten(values, treespec) + self.assertEqual(unflattened, leaf) + + run_test_with_leaf(1) + run_test_with_leaf(1.) + run_test_with_leaf(None) + run_test_with_leaf(bool) + run_test_with_leaf(torch.randn(3, 3)) + + def test_flatten_unflatten_list(self): + def run_test(lst): + expected_spec = TreeSpec(list, None, [LeafSpec() for _ in lst]) + values, treespec = tree_flatten(lst) + self.assertTrue(isinstance(values, list)) + self.assertEqual(values, lst) + self.assertEqual(treespec, expected_spec) + + unflattened = tree_unflatten(values, treespec) + self.assertEqual(unflattened, lst) + self.assertTrue(isinstance(unflattened, list)) + + run_test([]) + run_test([1., 2]) + run_test([torch.tensor([1., 2]), 2, 10, 9, 11]) + + def test_flatten_unflatten_tuple(self): + def run_test(tup): + expected_spec = TreeSpec(tuple, None, [LeafSpec() for _ in tup]) + values, treespec = tree_flatten(tup) + self.assertTrue(isinstance(values, list)) + self.assertEqual(values, list(tup)) + self.assertEqual(treespec, expected_spec) + + unflattened = tree_unflatten(values, treespec) + self.assertEqual(unflattened, tup) + self.assertTrue(isinstance(unflattened, tuple)) + + run_test(()) + run_test((1.,)) + run_test((1., 2)) + run_test((torch.tensor([1., 2]), 2, 10, 9, 11)) + + def test_flatten_unflatten_dict(self): + def run_test(tup): + expected_spec = TreeSpec(dict, list(tup.keys()), + [LeafSpec() for _ in tup.values()]) + values, treespec = tree_flatten(tup) + self.assertTrue(isinstance(values, list)) + self.assertEqual(values, list(tup.values())) + self.assertEqual(treespec, expected_spec) + + unflattened = tree_unflatten(values, treespec) + self.assertEqual(unflattened, tup) + self.assertTrue(isinstance(unflattened, dict)) + + run_test({}) + run_test({'a': 1}) + run_test({'abcdefg': torch.randn(2, 3)}) + run_test({1: torch.randn(2, 3)}) + run_test({'a': 1, 'b': 2, 'c': torch.randn(2, 3)}) + + def test_flatten_unflatten_nested(self): + def run_test(pytree): + values, treespec = tree_flatten(pytree) + self.assertTrue(isinstance(values, list)) + self.assertEqual(len(values), treespec.num_leaves) + + # NB: python basic data structures (dict list tuple) all have + # contents equality defined on them, so the following works for them. + unflattened = tree_unflatten(values, treespec) + self.assertEqual(unflattened, pytree) + + cases = [ + [()], + ([],), + {'a': ()}, + {'a': 0, 'b': [{'c': 1}]}, + {'a': 0, 'b': [1, {'c': 2}, torch.randn(3)], 'c': (torch.randn(2, 3), 1)}, + ] + for case in cases: + run_test(case) + + def test_treespec_repr(self): + # Check that it looks sane + pytree = (0, [0, 0, 0]) + _, spec = tree_flatten(pytree) + self.assertEqual( + repr(spec), 'TreeSpec(tuple, None, [*, TreeSpec(list, None, [*, *, *])])') + + def test_broadcast_to_and_flatten(self): + cases = [ + (1, (), []), + + # Same (flat) structures + ((1,), (0,), [1]), + ([1], [0], [1]), + ((1, 2, 3), (0, 0, 0), [1, 2, 3]), + ({'a': 1, 'b': 2}, {'a': 0, 'b': 0}, [1, 2]), + + # Mismatched (flat) structures + ([1], (0,), None), + ([1], (0,), None), + ((1,), [0], None), + ((1, 2, 3), (0, 0), None), + ({'a': 1, 'b': 2}, {'a': 0}, None), + ({'a': 1, 'b': 2}, {'a': 0, 'c': 0}, None), + ({'a': 1, 'b': 2}, {'a': 0, 'b': 0, 'c': 0}, None), + + # Same (nested) structures + ((1, [2, 3]), (0, [0, 0]), [1, 2, 3]), + ((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]), + + # Mismatched (nested) structures + ((1, [2, 3]), (0, (0, 0)), None), + ((1, [2, 3]), (0, [0, 0, 0]), None), + + # Broadcasting single value + (1, (0, 0, 0), [1, 1, 1]), + (1, [0, 0, 0], [1, 1, 1]), + (1, {'a': 0, 'b': 0}, [1, 1]), + (1, (0, [0, [0]], 0), [1, 1, 1, 1]), + (1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]), + + # Broadcast multiple things + ((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]), + ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]), + (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]), + ] + for pytree, to_pytree, expected in cases: + _, to_spec = tree_flatten(to_pytree) + result = _broadcast_to_and_flatten(pytree, to_spec) + self.assertEqual(result, expected, msg=str([pytree, to_spec, expected])) + + +if __name__ == '__main__': + run_tests() diff --git a/test/test_quantization.py b/test/test_quantization.py index fc67891c24fe3..1c370913c6d0e 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -14,6 +14,7 @@ from quantization.test_quantized_op import TestComparatorOps # noqa: F401 from quantization.test_quantized_op import TestPadding # noqa: F401 from quantization.test_quantized_op import TestQuantizedEmbeddingOps # noqa: F401 +from quantization.test_quantized_op import TestDynamicQuantizedRNNOp # noqa: F401 # Quantized Functional from quantization.test_quantized_functional import TestQuantizedFunctional # noqa: F401 @@ -29,11 +30,9 @@ from quantization.test_fusion_passes import TestFusionPasses # noqa: F401 # Module -# TODO: merge the fake quant per tensor and per channel test cases # TODO: some of the tests are actually operator tests, e.g. test_forward_per_tensor, and # should be moved to test_quantized_op -from quantization.test_workflow_module import TestFakeQuantizePerTensor # noqa: F401 -from quantization.test_workflow_module import TestFakeQuantizePerChannel # noqa: F401 +from quantization.test_workflow_module import TestFakeQuantize # noqa: F401 from quantization.test_workflow_module import TestObserver # noqa: F401 # TODO: merge with TestObserver # TODO: some tests belong to test_quantize.py, e.g. test_record_observer @@ -45,6 +44,8 @@ from quantization.test_quantize import TestPostTrainingStatic # noqa: F401 from quantization.test_quantize import TestPostTrainingDynamic # noqa: F401 from quantization.test_quantize import TestQuantizationAwareTraining # noqa: F401 +from quantization.test_quantize import TestEagerModeOps # noqa: F401 +from quantization.test_quantize import TestEagerModeQATOps # noqa: F401 # TODO: merge with other tests in test_quantize.py? from quantization.test_quantize import TestFunctionalModule # noqa: F401 @@ -61,13 +62,24 @@ from quantization.test_quantize_jit import TestQuantizeDynamicJitOps # noqaa: F401 # 3. GraphModule based graph mode quantization -from quantization.test_quantize_fx import TestQuantizeFx # noqa: F401 -from quantization.test_quantize_fx import TestQuantizeFxOps # noqa: F401 -from quantization.test_quantize_fx import TestQuantizeFxModels # noqa: F401 +try: + from quantization.test_quantize_fx import TestFuseFx # noqa: F401 + from quantization.test_quantize_fx import TestQuantizeFx # noqa: F401 + from quantization.test_quantize_fx import TestQuantizeFxOps # noqa: F401 + from quantization.test_quantize_fx import TestQuantizeFxModels # noqa: F401 +except ImportError: + # In FBCode we separate FX out into a separate target for the sake of dev + # velocity. These are covered by a separate test target `quantization_fx` + pass -# Tooling: numric_suite +# Tooling: numeric_suite from quantization.test_numeric_suite import TestEagerModeNumericSuite # noqa: F401 +try: + from quantization.test_numeric_suite_fx import TestGraphModeNumericSuite # noqa: F401 +except ImportError: + pass + # Backward Compatibility from quantization.test_backward_compatibility import TestSerialization # noqa: F401 diff --git a/test/test_reductions.py b/test/test_reductions.py new file mode 100644 index 0000000000000..b08cebf7947b4 --- /dev/null +++ b/test/test_reductions.py @@ -0,0 +1,2275 @@ +import torch +import numpy as np + +import unittest +import math +from typing import Dict, List +import random +from functools import partial +from itertools import product, combinations, permutations + +from torch._six import inf, nan, istuple +from torch.testing._internal.common_utils import ( + TestCase, run_tests, TEST_SCIPY, slowTest, torch_to_numpy_dtype_dict, + IS_WINDOWS) +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU, + onlyOnCPUAndCUDA, onlyCUDA, expectedAlertNondeterministic, largeTensorTest) + +# TODO: replace with make_tensor +def _generate_input(shape, dtype, device, with_extremal): + if shape == (): + x = torch.tensor((), dtype=dtype, device=device) + else: + if dtype.is_floating_point or dtype.is_complex: + # work around torch.randn not being implemented for bfloat16 + if dtype == torch.bfloat16: + x = torch.randn(*shape, device=device) * random.randint(30, 100) + x = x.to(torch.bfloat16) + else: + x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100) + x[torch.randn(*shape) > 0.5] = 0 + if with_extremal and dtype.is_floating_point: + # Use extremal values + x[torch.randn(*shape) > 0.5] = float('nan') + x[torch.randn(*shape) > 0.5] = float('inf') + x[torch.randn(*shape) > 0.5] = float('-inf') + elif with_extremal and dtype.is_complex: + x[torch.randn(*shape) > 0.5] = complex('nan') + x[torch.randn(*shape) > 0.5] = complex('inf') + x[torch.randn(*shape) > 0.5] = complex('-inf') + elif dtype == torch.bool: + x = torch.zeros(shape, dtype=dtype, device=device) + x[torch.randn(*shape) > 0.5] = True + else: + x = torch.randint(15, 100, shape, dtype=dtype, device=device) + + return x + +# TODO: replace with make_tensor +def _rand_shape(dim, min_size, max_size): + shape = [] + for i in range(dim): + shape.append(random.randint(min_size, max_size)) + return tuple(shape) + +class TestReductions(TestCase): + + def test_var_unbiased(self, device): + tensor = torch.randn(100, device=device) + self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True)) + self.assertEqual(tensor.var(), tensor.var(unbiased=True)) + self.assertEqual(tensor.var(unbiased=False), tensor.var(0, unbiased=False)) + + tensor = torch.tensor([1.0, 2.0], device=device) + self.assertEqual(tensor.var(unbiased=True), 0.5) + self.assertEqual(tensor.var(unbiased=False), 0.25) + + tensor = torch.tensor([1.0, 2.0, 3.0], device=device) + self.assertEqual(tensor.var(unbiased=True), 1.0) + self.assertEqual(tensor.var(unbiased=False), 2.0 / 3.0) + + tensor = torch.randn(100, device=device) + self.assertEqual(tensor.std(0), tensor.std(0, unbiased=True)) + self.assertEqual(tensor.std(), tensor.std(unbiased=True)) + self.assertEqual(tensor.std(unbiased=False), tensor.std(0, unbiased=False)) + + def test_var_stability(self, device): + tensor = torch.tensor([2281.5, 2281.25], device=device) + self.assertEqual(tensor.var(dim=0), 0.03125) + self.assertEqual(tensor.var(), 0.03125) + + def test_sum_dim_reduction_uint8_overflow(self, device): + example = [[-1, 2, 1], [5, 3, 6]] + x = torch.tensor(example, dtype=torch.uint8, device=device) + self.assertEqual(x.sum(dtype=torch.uint8).item(), 16) + self.assertEqual(x.sum(0, dtype=torch.uint8), torch.tensor([4, 5, 7], dtype=torch.uint8, device=device)) + self.assertEqual(x.sum(1, dtype=torch.uint8), torch.tensor([2, 14], dtype=torch.uint8, device=device)) + y = torch.tensor(example, dtype=torch.uint8, device=device) + torch.sum(x, 0, out=y) + self.assertEqual(x.sum(0, dtype=torch.uint8), y) + + def test_dim_reduction_less_than_64(self, device): + sizes = [1] * 65 + x = torch.randn(sizes, device=device) + ops = [torch.mean, torch.sum, torch.nansum, torch.std, torch.logsumexp, torch.std, torch.var, + torch.amin, torch.amax, torch.norm] + for op in ops: + with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"): + op(x, 64) + with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"): + op(x, -1) + + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") + def test_logsumexp(self, device): + from scipy.special import logsumexp + a = torch.randn(5, 4, device=device) + a[0, 0] = inf + a[1, :] = -inf + actual = a.logsumexp(1) + expected = logsumexp(a.cpu().numpy(), 1) + self.assertEqual(expected.shape, actual.shape) + self.assertEqual(expected, actual) + # check that out is actually inplace + b = torch.zeros(5, 2, device=device) + c = b[:, 0] + torch.logsumexp(a, 1, out=c) + self.assertEqual(expected, b[:, 0]) + + @onlyCPU + def test_sum_parallel(self, device): + # To use parallel branches we'll need to compare on tensors + # that are relatively large. Even if this is run on a single + # core machine these tests will still give you signal on + # the correctness + + def _run_test(size): + for dim in range(len(size) + 1): + nv = np.round(np.random.rand(*size)) # 0s and 1s + tv = torch.from_numpy(nv) + # Parallelisim is only used if numel is + # larger than grainsize defined in Parallel.h + self.assertTrue(tv.numel() > 32768) + if dim == len(size): + nvs = nv.sum() + tvs = tv.sum() + else: + nvs = nv.sum(dim) + tvs = tv.sum(dim) + diff = np.abs(nvs - tvs.numpy()).sum() + self.assertEqual(diff, 0) + + _run_test([2, 3, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3]) + _run_test([4, 4, 4, 4, 4, 4, 4, 4, 4, 4]) + _run_test([1, 32 * 8 * 32 * 8]) + _run_test([1, 32770]) + + # TODO: kill map2_ (and similar) uses and update to compare with NumPy + # only works on CPU since this uses map2_, which is only supported on CPU + def _testCSelection(self, torchfn, mathfn): + # Two tensors + size = (100, 100) + a = torch.rand(*size) + b = torch.rand(*size) + c = torchfn(a, b) + expected_c = torch.zeros(*size) + expected_c.map2_(a, b, lambda _, a, b: mathfn(a, b)) + self.assertEqual(expected_c, c, atol=0, rtol=0) + + @onlyCPU + def test_max_elementwise(self, device): + self._testCSelection(torch.max, max) + + @onlyCPU + def test_min_elementwise(self, device): + self._testCSelection(torch.min, min) + + def test_all_any(self, device): + def test(size): + x = torch.ones(*size, device=device).byte() + self.assertTrue(x.all()) + self.assertTrue(x.any()) + + x[3] = 0 + self.assertFalse(x.all()) + self.assertTrue(x.any()) + + x.zero_() + self.assertFalse(x.all()) + self.assertFalse(x.any()) + + x.fill_(2) + self.assertTrue(x.all()) + self.assertTrue(x.any()) + + x = torch.ones(*size, device=device).bool() + self.assertTrue(x.all()) + self.assertTrue(x.any()) + + x[3] = False + self.assertFalse(x.all()) + self.assertTrue(x.any()) + + test((10,)) + test((5, 5)) + + def test_all_any_with_dim(self, device): + def test(x): + r1 = x.prod(dim=0, keepdim=False).byte() + r2 = x.all(dim=0, keepdim=False) + self.assertEqual(r1.shape, r2.shape) + self.assertTrue((r1 == r2).all()) + + r3 = x.sum(dim=1, keepdim=True).clamp(0, 1).byte() + r4 = x.any(dim=1, keepdim=True) + self.assertEqual(r3.shape, r4.shape) + self.assertTrue((r3 == r4).all()) + + test(torch.tensor([[0, 0, 0], + [0, 0, 1], + [0, 1, 1], + [1, 1, 1]], device=device, dtype=torch.uint8)) + + def test_numpy_named_args(self, device): + x1 = torch.randn(10, device=device) + x2 = torch.randn(10, device=device) + res1 = torch.add(input=x1, other=x2) + res2 = torch.add(x1=x1, x2=x2) + self.assertEqual(res1, res2) + + x1 = torch.randn(10, 10, 10, device=device) + res1 = x1.sum(dim=(0, 2), keepdim=True) + res2 = x1.sum(axis=(0, 2), keepdims=True) + self.assertEqual(res1, res2) + + # TODO: kill this ane replace with common creation ops + def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True, + use_complex=False) -> Dict[str, List[torch.Tensor]]: + float_types = [torch.double, + torch.float] + int_types = [torch.int64, + torch.int32, + torch.int16] + + complex_types = [torch.complex64, + torch.complex128] + + def make_contiguous(shape, dtype) -> torch.Tensor: + if dtype in float_types: + val = torch.randn(shape, dtype=dtype) + val = val * ((val_range[1] - val_range[0]) / (math.pi * 2.0)) + val = val + ((val_range[1] - val_range[0]) / 2.0) + val = torch.clamp(val, min=val_range[0], max=val_range[1]) + return val + result = torch.zeros(shape, dtype=dtype) + result.apply_(lambda x: random.randint(val_range[0], val_range[1])) + return result + + def make_non_contiguous(shape, dtype) -> torch.Tensor: + contig = make_contiguous(shape, dtype) + non_contig = torch.empty(shape + (2, 2), dtype=dtype)[..., 0] + non_contig = non_contig.select(-1, -1) + non_contig.copy_(contig) + self.assertFalse(non_contig.is_contiguous()) + return non_contig + + def make_contiguous_slice(size, dtype) -> torch.Tensor: + contig = make_contiguous((1, size), dtype) + non_contig = contig[:1, 1:size - 1] + self.assertTrue(non_contig.is_contiguous()) + return contig + + types = [] + if use_floating: + types += float_types + if use_integral: + types += int_types + if use_complex: + types += complex_types + tensors: Dict[str, List[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []} + for dtype in types: + tensors["cont"].append(make_contiguous(shape, dtype)) + tensors["noncont"].append(make_non_contiguous(shape, dtype)) + tensors["slice"].append(make_contiguous_slice(sum(list(shape)), dtype)) + + return tensors + + # TODO: refactor this to use comparators from common_utils + def _assert_matches_numpy(self, t, n): + self.assertEqual(n.shape, t.shape) + if t.dtype == torch.float: + self.assertEqual(n, t, rtol=1e-03, atol=1e-05, equal_nan=True) + else: + self.assertEqual(n, t, equal_nan=True) + + # TODO: update this and tests that use it to use the device argument properly + def _test_dim_ops(self, pytorch_op, numpy_op, + use_floating=True, use_integral=True, use_complex=False): + def do_one(tensors_dict, dim): + for category, tensors in tensors_dict.items(): + if category == "slice": + dim = 0 + for tensor in tensors: + # we have no control over NumPy warnings... + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + expected = numpy_op(tensor.cpu().numpy(), dim) + actual = pytorch_op(tensor, dim) + self._assert_matches_numpy(actual, expected) + if torch.cuda.is_available(): + self._assert_matches_numpy(pytorch_op(tensor.cuda(), dim).cpu(), expected) + do_one(self._make_tensors((5, 400000), use_floating=use_floating, + use_integral=use_integral, use_complex=use_complex), 1) + do_one(self._make_tensors((3, 5, 7), use_floating=use_floating, + use_integral=use_integral, use_complex=use_complex), 0) + do_one(self._make_tensors((3, 5, 7), use_floating=use_floating, + use_integral=use_integral, use_complex=use_complex), 1) + do_one(self._make_tensors((3, 5, 7), use_floating=use_floating, + use_integral=use_integral, use_complex=use_complex), 2) + do_one(self._make_tensors((100000, ), use_floating=use_floating, + use_integral=use_integral, use_complex=use_complex), -1) + do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, + use_integral=use_integral, use_complex=use_complex), 0) + do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, + use_integral=use_integral, use_complex=use_complex), 1) + do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, + use_integral=use_integral, use_complex=use_complex), 2) + do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, + use_integral=use_integral, use_complex=use_complex), (1, 2)) + do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, + use_integral=use_integral, use_complex=use_complex), (1, -1)) + do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, + use_integral=use_integral, use_complex=use_complex), (0, 2)) + do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, + use_integral=use_integral, use_complex=use_complex), (0, 2, 1)) + + @slowTest + @onlyCPU + def test_sum_dim(self, device): + self._test_dim_ops( + lambda t, d: t.sum(d), + lambda n, d: n.sum(d), + use_floating=True, use_integral=True, use_complex=True) + + @onlyCPU + def test_mean_dim(self, device): + self._test_dim_ops( + lambda t, d: t.mean(d), + lambda n, d: n.mean(d), + use_integral=False, + use_complex=True) + + @onlyCPU + def test_std_dim(self, device): + for unbiased in [False, True]: + self._test_dim_ops( + lambda t, d: t.std(d, unbiased=unbiased), + lambda n, d: n.std(d, ddof=1 if unbiased else 0), + use_integral=False) + + @onlyCPU + def test_var_dim(self, device): + for unbiased in [False, True]: + self._test_dim_ops( + lambda t, d: t.var(d, unbiased=unbiased), + lambda n, d: n.var(d, ddof=1 if unbiased else 0), + use_integral=False) + + @onlyCPU + @unittest.skipIf(not TEST_SCIPY, 'Scipy not found') + def test_logsumexp_dim(self, device): + from scipy.special import logsumexp + self._test_dim_ops( + lambda t, d: t.logsumexp(d), + lambda n, d: logsumexp(n, d), + use_integral=False) + + # TODO: update this and tests that use it to handle device properly + def _test_reduce_integer_upcast(self, fn, has_out=True, test_complex=True): + shape = (3, 4, 5) + reduced_shape = fn(torch.ones(shape)).shape + + def _test_out(dtype, other_dtype): + out = torch.ones(reduced_shape, dtype=dtype) + result = fn(x, out=out) + self.assertIs(out.dtype, result.dtype) + self.assertEqual(fn(x.to(dtype)), result, exact_dtype=False) + result = fn(x, out=out, dtype=dtype) + self.assertIs(out.dtype, result.dtype) + self.assertEqual(fn(x.to(dtype)), result, exact_dtype=False) + # 'out' is favored over dtype, check error + self.assertRaises(RuntimeError, lambda: fn(x, out=out, dtype=other_dtype)) + + for dtype in [dtype for dtype in torch.testing.get_all_math_dtypes('cpu') if dtype != torch.float16]: + x = torch.ones(shape, dtype=dtype) + expected_dtype = dtype if dtype.is_floating_point or dtype.is_complex else torch.int64 + self.assertIs(expected_dtype, fn(x).dtype) + self.assertEqual(fn(x.to(expected_dtype)), fn(x)) + + if dtype.is_floating_point: + other_dtype = torch.float32 if dtype == torch.float64 else torch.float64 + elif dtype.is_complex: + other_dtype = torch.complex64 if dtype == torch.complex128 else torch.complex128 + else: + other_dtype = torch.int32 if dtype != torch.int32 else torch.int16 + self.assertIs(other_dtype, fn(x, dtype=other_dtype).dtype) + self.assertEqual(fn(x.to(other_dtype)), fn(x, dtype=other_dtype), exact_dtype=False) + + # test mixed int/float/complex + if dtype.is_floating_point: + mixed_dtypes = [torch.int32, torch.complex64] + elif dtype.is_complex: + mixed_dtypes = [torch.int32, torch.float32] + else: + mixed_dtypes = [torch.float32, torch.complex64] + + for mixed_dtype in mixed_dtypes: + self.assertIs(mixed_dtype, fn(x, dtype=mixed_dtype).dtype) + self.assertEqual(fn(x.to(mixed_dtype)), fn(x, dtype=mixed_dtype), exact_dtype=False) + + if has_out: + _test_out(dtype, other_dtype) + _test_out(dtype, mixed_dtype) + + @onlyCPU + def test_sum_integer_upcast(self, device): + self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, **kwargs), False) + self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, 0, **kwargs)) + + @onlyCPU + def test_prod_integer_upcast(self, device): + self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, **kwargs), False) + self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, 0, **kwargs)) + + @onlyCPU + def test_cumsum_integer_upcast(self, device): + self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumsum(x, 0, **kwargs)) + + @onlyCPU + def test_cumprod_integer_upcast(self, device): + self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumprod(x, 0, **kwargs)) + + def test_mode(self, device): + x = torch.arange(1., SIZE * SIZE + 1, device=device).clone().resize_(SIZE, SIZE) + x[:2] = 1 + x[:, :2] = 1 + x0 = x.clone() + + # Pre-calculated results. + res1val = torch.tensor(SIZE, device=device).fill_(1) + # The indices are the position of the last appearance of the mode element. + res1ind = torch.tensor(SIZE, device=device, dtype=torch.long).fill_(1) + res1ind[0] = SIZE - 1 + res1ind[1] = SIZE - 1 + + res2val, res2ind = torch.mode(x, keepdim=False) + self.assertEqual(res1val, res2val, atol=0, rtol=0) + self.assertEqual(res1ind, res2ind, atol=0, rtol=0) + + # Test use of result tensor + res2val = torch.tensor((), device=device) + res2ind = torch.tensor((), device=device, dtype=torch.long) + torch.mode(x, keepdim=False, out=(res2val, res2ind)) + self.assertEqual(res1val, res2val, atol=0, rtol=0) + self.assertEqual(res1ind, res2ind, atol=0, rtol=0) + + # Test non-default dim + res2val, res2ind = torch.mode(x, 0, False) + self.assertEqual(res1val, res2val, atol=0, rtol=0) + self.assertEqual(res1ind, res2ind, atol=0, rtol=0) + + # input unchanged + self.assertEqual(x, x0, atol=0, rtol=0) + + # TODO: make work on CUDA, too + @onlyCPU + def test_accreal_type(self, device) -> None: + x = torch.ones(2, 3, 4) + self.assertIsInstance(x.double().sum().item(), float) + self.assertIsInstance(x.float().sum().item(), float) + self.assertIsInstance(x.long().sum().item(), int) + self.assertIsInstance(x.int().sum().item(), int) + self.assertIsInstance(x.short().sum().item(), int) + self.assertIsInstance(x.char().sum().item(), int) + self.assertIsInstance(x.byte().sum().item(), int) + + def test_var_mean_some_dims(self, device): + sizes = (4, 6, 7, 5, 3) + dims = len(sizes) + + x = torch.rand(sizes, device=device) + for num_of_dims in range(2, dims): + dim_list = list(combinations(list(range(dims)), r=num_of_dims)) + for dim in dim_list: + for unbiased in [False, True]: + for keepdim in [False, True]: + var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) + var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim) + mean2 = x.mean(dim=dim, keepdim=keepdim) + self.assertEqual(var1, var2) + self.assertEqual(mean1, mean2) + + # TODO: this should be a generic opinfo test + def test_all_any_empty(self, device): + x = torch.ByteTensor().to(device) + self.assertTrue(x.all()) + self.assertFalse(x.any()) + + x = torch.BoolTensor().to(device) + self.assertTrue(x.all()) + self.assertFalse(x.any()) + + @dtypesIfCUDA(torch.half, torch.float, torch.double) + @dtypes(torch.float, torch.double) + def test_max_with_inf(self, device, dtype): + a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device) + self.assertTrue(torch.all(torch.max(a, dim=1).values == inf).item()) + self.assertTrue(torch.all(torch.amax(a, dim=1) == inf).item()) + self.assertTrue(torch.max(a).item() == inf) + self.assertTrue(torch.amax(a).item() == inf) + + @dtypesIfCUDA(torch.half, torch.float, torch.double) + @dtypes(torch.float, torch.double) + def test_min_with_inf(self, device, dtype): + a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device) + self.assertTrue(torch.all(torch.min(a, dim=1).values == (-inf)).item()) + self.assertTrue(torch.all(torch.amin(a, dim=1) == (-inf)).item()) + self.assertTrue(torch.min(a).item() == -inf) + self.assertTrue(torch.amin(a).item() == -inf) + + def _test_minmax_helper(self, torchfn, reffn, device, dtype, skip_indices=False): + def create_input(shape, device, dtype): + if dtype.is_floating_point: + return torch.randn(*shape, device=device, dtype=dtype) + else: + low = 0 if dtype == torch.bool else -1000 + high = 2 if dtype == torch.bool else 1000 + return torch.randint(low, high, shape, device=device, dtype=dtype) + x = create_input((100, 100), device, dtype) + self.compare_with_numpy(torchfn, reffn, x) + # non contiguous + x = create_input((10, 10, 10), device, dtype) + x = x[:, 4] + self.compare_with_numpy(torchfn, reffn, x) + + def get_values(x): + if istuple(x): + return x[0] + return x + + # indices + if not skip_indices: + size = 5 + x = create_input((size, size), device, dtype) + inputs = (x, x.t()) + dims = (0, 1) + for xinp, d in product(inputs, dims): + self.compare_with_numpy(lambda x: get_values(torchfn(x, d, False)), lambda x: reffn(x, d, keepdims=False), xinp) + result = torchfn(xinp, d, False) + if istuple(result): + v, i = result + if d == 1: + self.assertEqual(xinp[torch.arange(size), i], v, atol=0, rtol=0) + else: + self.assertEqual(xinp[i, torch.arange(size)], v, atol=0, rtol=0) + # nan + if dtype.is_floating_point: + for index in (0, 4, 99): + x = create_input((100,), device, dtype) + x[index] = nan + if not skip_indices: + result = torchfn(x, 0) + v = get_values(result) + self.assertEqual(v, nan) + if istuple(result): + i = result[1] + self.assertEqual(i, index) + self.assertEqual(torchfn(x), nan) + + @dtypesIfCPU(torch.float, torch.double, torch.long, torch.bool) + @dtypesIfCUDA(torch.half, torch.float, torch.long, torch.bool) + @dtypes(torch.float, torch.double) + def test_max(self, device, dtype): + self._test_minmax_helper(torch.max, np.amax, device, dtype) + + @dtypesIfCPU(torch.float, torch.double, torch.long, torch.bool) + @dtypesIfCUDA(torch.half, torch.float, torch.long, torch.bool) + @dtypes(torch.float, torch.double) + def test_min(self, device, dtype): + self._test_minmax_helper(torch.min, np.amin, device, dtype) + + @dtypesIfCPU(torch.float, torch.double, torch.int, torch.long, torch.bool) + @dtypesIfCUDA(torch.half, torch.float, torch.int, torch.long, torch.bool) + @dtypes(torch.float, torch.double) + def test_amin(self, device, dtype): + self._test_minmax_helper(torch.amin, np.amin, device, dtype) + + @dtypesIfCPU(torch.float, torch.double, torch.int, torch.long, torch.bool) + @dtypesIfCUDA(torch.half, torch.float, torch.int, torch.long, torch.bool) + @dtypes(torch.float, torch.double) + def test_amax(self, device, dtype): + self._test_minmax_helper(torch.amax, np.amax, device, dtype) + + @onlyOnCPUAndCUDA + @dtypesIfCPU(torch.float, torch.double) + @dtypesIfCUDA(torch.half, torch.float) + def test_aminmax(self, device, dtype): + + def _amin_wrapper(x, dim=None, keepdims=False): + if dim is None: + return torch._aminmax(x)[0] + else: + return torch._aminmax(x, dim, keepdims)[0] + + def _amax_wrapper(x, dim=None, keepdims=False): + if dim is None: + return torch._aminmax(x)[1] + else: + return torch._aminmax(x, dim, keepdims)[1] + + self._test_minmax_helper(_amin_wrapper, np.amin, device, dtype) + self._test_minmax_helper(_amax_wrapper, np.amax, device, dtype) + + # TODO: bincount isn't a classic reduction -- maybe this test suite is + # reductions and summary ops? + def test_bincount(self, device): + # negative input throws + with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'): + torch.bincount(torch.tensor([1, -1], device=device)) + # n-d input, with n > 1 throws + with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'): + torch.bincount(torch.tensor([[1, 2], [3, 4]], device=device)) + # floating input type throws + with self.assertRaisesRegex(RuntimeError, 'not implemented'): + torch.bincount(torch.tensor([1., 0.3], device=device)) + # minlength < 0 throws + with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'): + torch.bincount(torch.tensor([1, 3], device=device), + torch.tensor([.2, .2], device=device), + minlength=-1) + # input and weights dim mismatch + with self.assertRaisesRegex(RuntimeError, 'same length'): + torch.bincount(torch.tensor([1, 0], device=device), + torch.tensor([1., 0.3, 0.5], device=device)) + # 1-d input with no elements and default minlength + self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long)), + torch.zeros(0, dtype=torch.long, device=device)) + # 1-d input with no elements and specified minlength + self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10), + torch.zeros(10, dtype=torch.long, device=device)) + + # test tensor method without weights + long_counts = torch.tensor( + [0, 3, 2, 1, 3], dtype=torch.uint8, device=device).bincount() + self.assertEqual( + torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device), + long_counts) + # test minlength functionality + int_counts = torch.bincount( + torch.tensor([1, 1, 1, 1], device=device), minlength=5) + self.assertEqual( + torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device), + int_counts) + # test weights + byte_counts = torch.bincount( + torch.tensor([0, 1, 1, 1, 4], device=device), + torch.tensor([.1, .2, .3, .4, .5], device=device)) + self.assertEqual( + torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), byte_counts) + byte_counts = torch.bincount( + torch.tensor([0, 1, 1, 1, 4], device=device), + torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device)) + self.assertEqual( + torch.tensor([1, 9, 0, 0, 5], device=device, dtype=torch.float64), byte_counts) + # test non-contiguous inputs and weights + inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device) + weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device) + for i in [0, 1]: + assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous" + assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous" + # inputs are non-contiguous but weights are contiguous + self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2])) + # inputs and weights are non-contiguous + self.assertEqual( + inputs[:, 1].bincount(weights[:, 1]), + torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32)) + # weights are non-contiguous but inputs are contiguous + self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]), + torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32)) + + # test bincount on non-contiguous slices + all0s = torch.zeros((32, 2), dtype=torch.int64, device=device) + self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32])) + + all1s = torch.ones((32, 2), dtype=torch.int64, device=device) + self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32])) + + # test large number of bins - global memory use + big_exp = torch.zeros(10000000, device=device) + big_exp[-1] = 50.0 + big_w = torch.tensor([.5] * 100, device=device) + big_out = torch.tensor([9999999] * 100, device=device).bincount(big_w) + self.assertEqual(big_exp, big_out) + # test large input size + big_exp = torch.zeros(2, device=device, dtype=torch.int64) + big_exp[1] = 1000000 + big_out = torch.ones(1000000, dtype=torch.int8, device=device).bincount() + self.assertEqual(big_exp, big_out) + + @onlyCUDA + @expectedAlertNondeterministic('_bincount_cuda', fn_has_device_arg=False) + def test_bincount_alert_nondeterministic(self, device): + torch.bincount(torch.tensor([], device=device, dtype=torch.long)) + + # TODO: how many var stability tests are there? + def test_var_stability2(self, device): + tensor = torch.FloatTensor([2281.5, 2281.25]).to(device) + + # Stability for inner dim + self.assertEqual(tensor.var(0), 0.03125) + + # General stability + self.assertEqual(tensor.var(), 0.03125) + + # Stability for outer dimensions + tensor = tensor.unsqueeze(1) + self.assertEqual(tensor.var(0), 0.03125) + + @onlyCPU + @dtypes(torch.bool, torch.double) + def test_sum_all(self, device, dtype) -> None: + def check_sum_all(tensor: torch.Tensor) -> None: + pylist = tensor.reshape(-1).tolist() + self.assertEqual(tensor.sum(), sum(pylist)) + + if dtype != torch.bool: + check_sum_all(torch.tensor([1, 2, 3, 4, 5], dtype=dtype, device=device)) + check_sum_all(torch.randn(200000, dtype=dtype, device=device)) + check_sum_all(torch.randn(2000, 2, dtype=dtype, device=device)[:, 0]) + else: + check_sum_all(torch.tensor([True, False, True], dtype=torch.bool, device=device)) + + def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn, + memory_format, compare_data=True, default_is_preserve=False): + + assert(memory_format == torch.channels_last or memory_format == torch.channels_last_3d) + + # xc is a channels last tensor + xc = input_generator_fn(device) + # xc is not memory dense, but looks like channels last + if memory_format == torch.channels_last: + xc = xc[..., ::2, ::2] + else: + xc = xc[..., ::2, ::2, ::2] + + clone = transformation_fn(xc, memory_format=torch.preserve_format) + self.assertFalse(clone.is_contiguous()) + self.assertTrue(clone.is_contiguous(memory_format=memory_format)) + self.assertFalse(xc.is_contiguous()) + self.assertFalse(xc.is_contiguous(memory_format=memory_format)) + if compare_data: + self.assertEqual(xc, clone.to(xc)) + + xc = input_generator_fn(device) + clone = transformation_fn(xc, memory_format=torch.contiguous_format) + self.assertTrue(clone.is_contiguous()) + self.assertFalse(clone.is_contiguous(memory_format=memory_format)) + if compare_data: + self.assertEqual(xc, clone.to(xc)) + + xc = input_generator_fn(device) + clone = transformation_fn(xc) + + if default_is_preserve: + self.assertFalse(clone.is_contiguous()) + self.assertTrue(clone.is_contiguous(memory_format=memory_format)) + else: + self.assertTrue(clone.is_contiguous()) + self.assertFalse(clone.is_contiguous(memory_format=memory_format)) + if compare_data: + self.assertEqual(xc, clone.to(xc)) + + x = torch.randn((3, 4, 5, 6, 7, 8, 9), device=device) + for _ in range(10): + permutation = list(range(len(x.shape))) + random.shuffle(permutation) + x = x.permute(permutation) + self.assertEqual(x.stride(), transformation_fn(x, memory_format=torch.preserve_format).stride()) + + @onlyCPU + @dtypes(torch.double) + def test_sum_out(self, device, dtype: torch.dtype) -> None: + x = torch.rand(100, 100, dtype=dtype, device=device) + res1 = torch.sum(x, 1) + res2 = torch.tensor((), dtype=dtype, device=device) + torch.sum(x, 1, out=res2) + self.assertEqual(res1, res2) + x = torch.rand(100, 100, 100, dtype=dtype, device=device) + res1 = x.sum(2).sum(1) + res2 = torch.tensor((), dtype=dtype, device=device) + torch.sum(x, (2, 1), out=res2) + self.assertEqual(res1, res2) + + @onlyCUDA + @dtypes(torch.float16, torch.float32) + def test_prod_gpu(self, device, dtype): + x = torch.tensor([2, 3, 6, 9, 8], dtype=dtype, device=device) + + # Check all combinations: fp16 input - fp16 output, fp16 input - fp32 + # output, fp32 input - fp16 output, fp32 input - fp32 output + for dtype_output in [torch.float16, torch.float32]: + result_expected = torch.tensor(2592, dtype=dtype_output, device=device) + output = torch.prod(x, dtype=dtype_output) + self.assertEqual(output, result_expected) + + output = x.prod(dtype=dtype_output) + self.assertEqual(output, result_expected) + + @onlyCPU + @dtypes(torch.float) + def test_prod(self, device, dtype): + x = torch.rand(100, 100, dtype=dtype, device=device) + res1 = torch.prod(x, 1) + res2 = torch.tensor((), dtype=dtype, device=device) + torch.prod(x, 1, out=res2) + self.assertEqual(res1, res2) + + def test_prod_bool(self, device): + vals = [[True, True], [True, False], [False, False], []] + for val in vals: + result = torch.prod(torch.tensor(val, device=device), dtype=torch.bool).item() + expect = np.prod(np.array(val), dtype=np.bool) + self.assertEqual(result, expect) + + result = torch.prod(torch.tensor(val, device=device)).item() + expect = np.prod(np.array(val)) + self.assertEqual(result, expect) + + @onlyCPU + def test_max_mixed_devices(self, device): + a = torch.randn(10, device=device) + if torch.cuda.is_available(): + values = torch.randn(10).cuda() + indices = torch.cuda.LongTensor() + self.assertRaises(RuntimeError, + lambda: torch.max(a, 0, out=(values, indices))) + self.assertRaises(RuntimeError, + lambda: torch.amax(a, 0, out=values)) + + @onlyCPU + def test_min_mixed_devices(self, device): + a = torch.randn(10, device=device) + if torch.cuda.is_available(): + values = torch.randn(10).cuda() + indices = torch.cuda.LongTensor() + self.assertRaises(RuntimeError, + lambda: torch.min(a, 0, out=(values, indices))) + self.assertRaises(RuntimeError, + lambda: torch.amin(a, 0, out=values)) + + # TODO: consider refactoring with bincount test + def test_bucketization(self, device): + values_1d = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9], device=device) + values_3d = torch.tensor([[[1, 3, 5], [2, 4, 6]], [[1, 2, 3], [4, 5, 6]]], device=device) + + # regular case 3d boundary and 3d input value + boundaries = torch.tensor([[[1, 2, 3, 4], [3, 4, 5, 6]], [[1, 3, 5, 7], [2, 4, 6, 8]]], device=device) + expected_result = torch.tensor([[[0, 2, 4], [0, 1, 3]], [[0, 1, 1], [1, 2, 2]]], device=device) + output = torch.empty(2, 2, 3, device=device, dtype=torch.int64) + self.assertEqual(torch.searchsorted(boundaries, values_3d), expected_result) + self.assertEqual(torch.searchsorted(boundaries, values_3d, out=output), expected_result) + expected_result = torch.tensor([[[1, 3, 4], [0, 2, 4]], [[1, 1, 2], [2, 2, 3]]], device=device) + self.assertEqual(torch.searchsorted(boundaries, values_3d, right=True), expected_result) + self.assertEqual(torch.searchsorted(boundaries, values_3d, right=True, out=output), expected_result) + + # simple 1d boundary and 3d input value + boundaries = torch.tensor([1, 2, 3, 4, 5, 6], device=device) + expected_result = torch.tensor([[[0, 2, 4], [1, 3, 5]], [[0, 1, 2], [3, 4, 5]]], device=device) + output = torch.empty(2, 2, 3, device=device, dtype=torch.int64) + self.assertEqual(torch.searchsorted(boundaries, values_3d), expected_result) + self.assertEqual(torch.bucketize(values_3d, boundaries), expected_result) + self.assertEqual(torch.bucketize(values_3d, boundaries, out=output), expected_result) + expected_result = torch.tensor([[[1, 3, 5], [2, 4, 6]], [[1, 2, 3], [4, 5, 6]]], device=device) + self.assertEqual(torch.searchsorted(boundaries, values_3d, right=True), expected_result) + self.assertEqual(torch.bucketize(values_3d, boundaries, right=True), expected_result) + self.assertEqual(torch.bucketize(values_3d, boundaries, out=output, right=True), expected_result) + + # simple float 1d boundary and 1d input with output int32 type + values_1d_float = values_1d.to(torch.float32) + boundaries = torch.tensor([0.9, 1, 2, 2, 3, 3, 4, 4.1, 9, 9], device=device, dtype=torch.float32) + expected_result = torch.tensor([1, 2, 4, 6, 8, 8, 8, 8, 8], device=device, dtype=torch.int32) + self.assertEqual(torch.searchsorted(boundaries, values_1d_float, out_int32=True), expected_result) + self.assertEqual(torch.bucketize(values_1d_float, boundaries, out_int32=True), expected_result) + + # multiple dimension input with 0 elements + boundaries = torch.tensor([1, 2, 3, 4, 5, 6], device=device, dtype=torch.int64) + values_0_el = torch.tensor([[[]]], device=device, dtype=torch.int64) + expected_result = values_0_el.to(torch.int64) + self.assertEqual(torch.searchsorted(boundaries, values_0_el), expected_result) + self.assertEqual(torch.bucketize(values_0_el, boundaries), expected_result) + + # nan input + values_nan = torch.tensor([1.0, float('nan'), 2.0, float('nan')], device=device, dtype=torch.float64) + boundaries = torch.tensor([0.0, 1.0, 2.0, 3.0], device=device, dtype=torch.float64) + expected_result = torch.tensor([1, 4, 2, 4], device=device) + self.assertEqual(torch.searchsorted(boundaries, values_nan), expected_result) + expected_result = torch.tensor([2, 4, 3, 4], device=device) + self.assertEqual(torch.searchsorted(boundaries, values_nan, right=True), expected_result) + + # type promotion and non contiguous tensors + values_3d_permute = values_3d.permute(2, 1, 0).to(torch.int32) + boundaries_permute = values_3d.permute(2, 1, 0).to(torch.float64) + expected_result = torch.tensor([[[0, 0], [0, 1]], [[2, 0], [0, 1]], [[2, 0], [0, 0]]], device=device) + if self.device_type != 'xla': + self.assertWarnsRegex( + UserWarning, "tensor is non-contiguous", + lambda: self.assertEqual(torch.searchsorted(boundaries_permute, values_3d_permute), expected_result)) + else: + # All tensors in XLA is contiguous even doing permute, no warning msg will be generate in XLA + self.assertEqual(torch.searchsorted(boundaries_permute, values_3d_permute), expected_result) + + # scalar type + boundaries = torch.tensor([1.5, 2.5, 3.5], device=device) + expected_result = torch.tensor(1, device=device) + self.assertEqual(torch.searchsorted(boundaries, 2), expected_result) + self.assertEqual(torch.bucketize(torch.tensor(2, device=device), boundaries), expected_result) + expected_result = torch.tensor(3, device=device) + scalar_tensor_nan = torch.tensor(float('nan'), device=device) + self.assertEqual(torch.searchsorted(boundaries, scalar_tensor_nan), expected_result) + self.assertEqual(torch.bucketize(float('nan'), boundaries, right=True), expected_result) + + # invalid input dimensions + boundaries = torch.tensor([[1, 2, 3], [4, 5, 6]], device=device) + with self.assertRaisesRegex( + RuntimeError, "first N-1 dimensions of boundaries tensor and input value tensor must match"): + torch.searchsorted(boundaries, values_3d) + with self.assertRaisesRegex( + RuntimeError, "boundaries tensor must be 1 dimension"): + torch.bucketize(values_3d, boundaries) + with self.assertRaisesRegex( + RuntimeError, "only when boundaries tensor dimension is 1"): + torch.searchsorted(boundaries, 1) + + # incompatiable output tensor's dtype + def test_output_dtype(dtype, is_int32): + output = values_1d.to(dtype) + with self.assertRaisesRegex( + RuntimeError, "output tensor's dtype is wrong"): + torch.searchsorted(values_1d, values_1d, out=output, out_int32=is_int32) + + test_output_dtype(torch.float32, False) + test_output_dtype(torch.int32, False) + test_output_dtype(torch.int64, True) + + @dtypesIfCUDA(torch.half, torch.float, torch.double, + torch.int8, torch.short, torch.int, torch.long) + @dtypes(torch.float, torch.double, + torch.int8, torch.short, torch.int, torch.long) + def test_nansum(self, device, dtype): + x = (torch.randn(3, 3)) + if dtype in [torch.half, torch.float, torch.double]: + x[x < 0.2] = float('nan') + # Randomly scale the values + x = (x * random.randint(10, 100)).tolist() + + self.compare_with_numpy(torch.nansum, np.nansum, x, device, dtype) + + def _test_reduction_function_with_numpy(self, torch_func, np_func, device, dtype, + with_extremal=False, atol=None, rtol=None, + exact_dtype=True, with_keepdim=False): + # Test 0-d to 3-d tensors. + for ndims in range(0, 4): + shape = _rand_shape(ndims, min_size=5, max_size=10) + for n in range(ndims + 1): + for c in combinations(list(range(ndims)), n): + for count_dim in permutations(c): + # Generate Input. + x = _generate_input(shape, dtype, device, with_extremal) + + if count_dim == (): + # Default `dims=None` case + self.compare_with_numpy(torch_func, np_func, x, device=None, dtype=None, + atol=atol, rtol=rtol, exact_dtype=exact_dtype) + else: + # With `dims: tuple of ints` case + if with_keepdim: + torch_func_partial = partial(torch_func, keepdim=True, dim=count_dim) + np_func_partial = partial(np_func, keepdims=True, axis=count_dim) + else: + torch_func_partial = partial(torch_func, dim=count_dim) + np_func_partial = partial(np_func, axis=count_dim) + self.compare_with_numpy(torch_func_partial, np_func_partial, x, device=None, dtype=None, + atol=atol, rtol=rtol, exact_dtype=exact_dtype) + + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + + torch.testing.get_all_complex_dtypes())) + def test_count_nonzero(self, device, dtype): + self._test_reduction_function_with_numpy(torch.count_nonzero, np.count_nonzero, device, dtype) + self._test_reduction_function_with_numpy(torch.count_nonzero, np.count_nonzero, device, dtype, True) + + def _test_sum_reduction_vs_numpy(self, torch_fn, np_fn, device, dtype, with_keepdim=False, with_extremal=False): + def is_integral(dtype): + return dtype in torch.testing.get_all_int_dtypes() + + # On Windows CI, the current version of `numpy` promotes all lower integers + # dtypes to int32 while `torch` promotes them to int64. Hence we skip on checking + # the exact dtype. + # Reference : https://dr.pytorch.org/api/view-log-full?build_id=122051580 + # PR : https://github.com/pytorch/pytorch/pull/38628#issuecomment-655905370 + exact_dtype = False if (IS_WINDOWS and is_integral(dtype)) else True + + if dtype == torch.uint8: + with self.assertRaises(TypeError): + self._test_reduction_function_with_numpy(torch_fn, np_fn, device, dtype, with_extremal=with_extremal) + else: + # TODO: Investigate why the output is not close to numpy. + if dtype == torch.float16: + atol = 0.4 + rtol = 1e-2 + elif dtype == torch.float32: + atol = 7e-05 + rtol = 3e-06 + else: + # Default values + atol = None + rtol = None + self._test_reduction_function_with_numpy(torch_fn, np_fn, device, dtype, + atol=atol, rtol=rtol, exact_dtype=exact_dtype, + with_keepdim=with_keepdim, with_extremal=with_extremal) + + @onlyOnCPUAndCUDA + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False))) + def test_sum_vs_numpy(self, device, dtype): + self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype) + self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_extremal=True) + self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_keepdim=True) + + @onlyOnCPUAndCUDA + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False))) + def test_nansum_vs_numpy(self, device, dtype): + self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype) + self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, with_extremal=True) + self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, with_keepdim=True) + + @dtypes(*(torch.testing.get_all_complex_dtypes())) + def test_nansum_complex(self, device, dtype): + x = torch.randn((3, 3, 3), device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, "nansum does not support complex inputs"): + torch.nansum(x) + + def test_nansum_out_dtype(self, device): + dtypes = list(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False)) + for inp_dtype, out_dtype in combinations(dtypes, 2): + shape = _rand_shape(random.randint(2, 5), min_size=5, max_size=10) + x = _generate_input(shape, inp_dtype, device, with_extremal=False) + torch_fn = partial(torch.nansum, dtype=out_dtype) + np_out_dtype = torch_to_numpy_dtype_dict[out_dtype] + np_fn = partial(np.nansum, dtype=np_out_dtype) + self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) + + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False))) + def test_argminmax_multiple(self, device, dtype): + # Case: All Ones + t = torch.ones(3, 3, device=device, dtype=dtype) + self.compare_with_numpy(torch.argmax, np.argmax, t) + self.compare_with_numpy(torch.argmin, np.argmin, t) + + # Case: With single `nan` present. + if dtype in torch.testing.get_all_fp_dtypes(): + t[2, 2] = float('nan') + self.compare_with_numpy(torch.argmax, np.argmax, t) + self.compare_with_numpy(torch.argmin, np.argmin, t) + + # Case: Randomly Generated Tensors + for ndims in range(1, 5): + shape = _rand_shape(ndims, min_size=5, max_size=10) + for with_extremal in [False, True]: + for contiguous in [False, True]: + # Generate Input. + x = _generate_input(shape, dtype, device, with_extremal) + + if dtype == torch.half: + max_val = torch.max(x.to(torch.float)) + min_val = torch.min(x.to(torch.float)) + else: + max_val = torch.max(x) + min_val = torch.min(x) + + mask = torch.randn(x.shape) > 0.5 + x[mask] = torch.tensor(max_val + 1, dtype=dtype) + + mask = torch.randn(x.shape) > 0.5 + x[mask] = torch.tensor(min_val - 1, dtype=dtype) + + if not contiguous: + x = x.T + + self.compare_with_numpy(torch.argmax, np.argmax, x, device=None, dtype=None) + self.compare_with_numpy(torch.argmin, np.argmin, x, device=None, dtype=None) + + # Verify indices returned by max and min. + if dtype != torch.half: + rand_dim = random.randint(0, ndims - 1) + self.compare_with_numpy(lambda x: torch.max(x, dim=rand_dim)[1], + lambda x: np.argmax(x, axis=rand_dim), x, device=None, dtype=None) + self.compare_with_numpy(lambda x: torch.min(x, dim=rand_dim)[1], + lambda x: np.argmin(x, axis=rand_dim), x, device=None, dtype=None) + + def verify_against_numpy(t): + # Argmax + torch_fn = partial(torch.argmax, dim=1) + np_fn = partial(np.argmax, axis=1) + self.compare_with_numpy(torch_fn, np_fn, t) + # Non-contiguous input + self.compare_with_numpy(torch_fn, np_fn, t.T) + + # Verify indices returned by max. + if dtype != torch.half: + self.compare_with_numpy(lambda x: torch.max(x, dim=1)[1], np_fn, x, device=None, dtype=None) + self.compare_with_numpy(lambda x: torch.max(x, dim=1)[1], np_fn, x.T, device=None, dtype=None) + + # Argmin + torch_fn = partial(torch.argmin, dim=1) + np_fn = partial(np.argmin, axis=1) + self.compare_with_numpy(torch_fn, np_fn, t) + # Non-contiguous input + self.compare_with_numpy(torch_fn, np_fn, t.T) + + # Verify indices returned by min. + if dtype != torch.half: + self.compare_with_numpy(lambda x: torch.min(x, dim=1)[1], np_fn, x, device=None, dtype=None) + self.compare_with_numpy(lambda x: torch.min(x, dim=1)[1], np_fn, x.T, device=None, dtype=None) + + # Case: Sample from issue: https://github.com/pytorch/pytorch/issues/41998 + t = torch.tensor([[1, 5], + [2, 10], + [3, 3]], device=device, dtype=dtype) + verify_against_numpy(t) + + # Case: Sample from issue: https://github.com/pytorch/pytorch/issues/41998 + t = torch.tensor([[1, 5], + [2, 10], + [0, 0]], device=device, dtype=dtype) + verify_against_numpy(t) + + @dtypes(*(torch.testing.get_all_dtypes(include_half=True, include_bfloat16=False, + include_bool=True, include_complex=True))) + def test_all_any_vs_numpy(self, device, dtype): + # Note [all, any uint8 compatibility]: However for compatibility reason, + # for `uint8`, they return Tensor of same dtype `uint8`. + # Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561 + exact_dtype = True if dtype != torch.uint8 else False + + def _test_all_any(x): + self.compare_with_numpy(torch.all, np.all, x) + self.compare_with_numpy(torch.any, np.any, x) + + def _test_all_any_with_dim(x, dim): + torch_fn = partial(torch.all, dim=dim) + np_fn = partial(np.all, axis=dim) + self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype) + + torch_fn = partial(torch.any, dim=dim) + np_fn = partial(np.any, axis=dim) + self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype) + + def _test_out_variant(x, dim): + out = torch.empty_like(x) + if dtype == torch.bool or dtype == torch.uint8: + expected = torch.all(x, dim) + torch.all(x, dim, out=out) + self.assertEqual(expected, out) + + expected = torch.any(x, dim) + torch.any(x, dim, out=out) + self.assertEqual(expected, out) + else: + with self.assertRaisesRegex(RuntimeError, "all only supports bool tensor for result, got"): + torch.all(x, dim, out=out) + + with self.assertRaisesRegex(RuntimeError, "any only supports bool tensor for result, got"): + torch.any(x, dim, out=out) + + def _test_all_any_with_dim_keepdim(x, dim, keepdim): + torch_fn = partial(torch.all, dim=dim, keepdim=keepdim) + np_fn = partial(np.all, axis=dim, keepdims=keepdim) + self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype) + + torch_fn = partial(torch.any, dim=dim, keepdim=keepdim) + np_fn = partial(np.any, axis=dim, keepdims=keepdim) + self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype) + + def _test_output_dtype(x): + # This test will fail once the functions return bool output + # for uint8 input. + expected_dtype = torch.uint8 if dtype == torch.uint8 else torch.bool + self.assertEqual(torch.all(x).dtype, expected_dtype) + self.assertEqual(torch.any(x).dtype, expected_dtype) + + self.assertEqual(torch.all(x, dim=0).dtype, expected_dtype) + self.assertEqual(torch.any(x, dim=0).dtype, expected_dtype) + + for ndim in range(5): + shape = _rand_shape(ndim, 1, 5) + x = _generate_input(shape, dtype, device, with_extremal=False) + _test_all_any(x) + _test_all_any(x.T) + _test_all_any(x[..., ::2]) + + x = _generate_input(shape, dtype, device, with_extremal=True) + _test_all_any(x) + _test_all_any(x.T) + _test_all_any(x[..., ::2]) + + x = torch.zeros_like(x) + _test_all_any(x) + _test_all_any(x.T) + _test_all_any(x[..., ::2]) + + x = torch.ones_like(x) + _test_all_any(x) + _test_all_any(x.T) + _test_all_any(x[..., ::2]) + _test_output_dtype(x) + for dim in range(ndim): + x = _generate_input(shape, dtype, device, with_extremal=False) + _test_all_any_with_dim(x, dim) + _test_all_any_with_dim(x.T, dim) + _test_all_any_with_dim(x[..., ::2], dim) + _test_out_variant(x, dim) + _test_all_any_with_dim_keepdim(x, dim, keepdim=True) + _test_all_any_with_dim_keepdim(x, dim, keepdim=False) + + x = _generate_input(shape, dtype, device, with_extremal=True) + _test_all_any_with_dim(x, dim) + _test_all_any_with_dim(x.T, dim) + _test_all_any_with_dim(x[..., ::2], dim) + _test_out_variant(x, dim) + _test_all_any_with_dim_keepdim(x, dim, keepdim=True) + _test_all_any_with_dim_keepdim(x, dim, keepdim=False) + + x = torch.zeros_like(x) + _test_all_any_with_dim(x, dim) + _test_all_any_with_dim(x.T, dim) + _test_all_any_with_dim(x[..., ::2], dim) + _test_out_variant(x, dim) + _test_all_any_with_dim_keepdim(x, dim, keepdim=True) + _test_all_any_with_dim_keepdim(x, dim, keepdim=False) + + x = torch.ones_like(x) + _test_all_any_with_dim(x, dim) + _test_all_any_with_dim(x.T, dim) + _test_all_any_with_dim(x[..., ::2], dim) + _test_out_variant(x, dim) + _test_all_any_with_dim_keepdim(x, dim, keepdim=True) + _test_all_any_with_dim_keepdim(x, dim, keepdim=False) + + # TODO: part of this test covers torch.norm, with should be covered by test_linalg + @onlyOnCPUAndCUDA + def test_repeated_dim(self, device): + ops = [torch.mean, torch.sum, torch.nansum, torch.std, torch.logsumexp, torch.std, torch.var, + torch.amin, torch.amax, torch.norm] + x = torch.randn(3, 3, 3, 3, device=device) + + error_msg = r'appears multiple times in the list of dims' + norm_error_msg = r'Expected dims to be different, got' + for op in ops: + for dim in [(0, 0), (0, -4)]: + e_msg = norm_error_msg if op == torch.norm else error_msg + with self.assertRaisesRegex(RuntimeError, e_msg): + op(x, dim=dim) + + # TODO: update this test to comapre against NumPy + @onlyCUDA + def test_var(self, device): + cpu_tensor = torch.randn(2, 3, 3) + device_tensor = cpu_tensor.to(device) + self.assertEqual(device_tensor.var(), cpu_tensor.var()) + self.assertEqual(device_tensor.var(1), cpu_tensor.var(1)) + self.assertEqual(device_tensor.var(2), cpu_tensor.var(2)) + self.assertEqual(device_tensor.std(), cpu_tensor.std()) + self.assertEqual(device_tensor.std(1), cpu_tensor.std(1)) + self.assertEqual(device_tensor.var(2), cpu_tensor.var(2)) + + cpu_tensor = torch.randn(100) + device_tensor = cpu_tensor.to(device) + self.assertEqual(device_tensor.var(), cpu_tensor.var()) + + # TODO: update this test to compare against NumPy + @onlyCUDA + def test_var_large_input(self, device): + # Large, not-nice input + cpu_tensor = torch.randn(2 * 32 * 1024 + 1, 2, 67) + device_tensor = cpu_tensor.to(device) + + self.assertEqual(cpu_tensor.var(2), device_tensor.var(2)) + + # TODO: update this to compare against NumPy instead of CPU + @onlyCUDA + @dtypes(torch.double) + def test_sum_noncontig(self, device, dtype): + x = torch.randn(1, 75, 57, 20, dtype=dtype, device=device).permute(0, 3, 1, 2) + y = x.cpu() + self.assertEqual(x.sum().cpu(), y.sum()) + self.assertEqual(x.sum(dim=(-1, -2)).cpu(), y.sum(dim=(-1, -2))) + self.assertEqual(x.sum(dim=(1, 3)).cpu(), y.sum(dim=(1, 3))) + + # TODO: update this to compare against NumPy instead of CPU + @onlyCUDA + def test_min_max_nan(self, device): + tests = [(lambda x: x.min(), 'min'), + (lambda x: x.max(), 'max'), + (lambda x: x.amin(), 'amin'), + (lambda x: x.amax(), 'amax'), + (lambda x: x.min(0).values, 'min_dim'), + (lambda x: x.max(0).values, 'max_dim'), + (lambda x: x.amin(0), 'amin_dim'), + (lambda x: x.amax(0), 'amax_dim')] + for f, name in tests: + a = torch.arange(25.0).view(5, 5) + a[2, 2] = nan + actual = f(a.to(device)).cpu() + expected = f(a).cpu() + self.assertEqual(torch.isnan(actual), torch.isnan(expected), msg='nans for {}'.format(name)) + self.assertEqual(actual[~torch.isnan(actual)], + expected[~torch.isnan(expected)], msg='nans for {}'.format(name)) + + # TODO: make this test generic using OpInfos + @onlyCUDA + def test_sum_cpu_device_mismatch(self, device): + x = torch.randn(20, dtype=torch.float32, device=device) + y = torch.randn(1, dtype=torch.float32) + + err_string = "Expected all tensors to be on the same device, but found at least two devices, {0}".format(device) + + with self.assertRaisesRegex(RuntimeError, err_string): + torch.sum(x, dim=[0], dtype=torch.float32, out=y) + + # tests half to float promotion + if self.device_type == 'cuda': + x = x.half() + with self.assertRaisesRegex(RuntimeError, err_string): + torch.sum(x, dim=[0], dtype=torch.float32, out=y) + + # Assert for illegal dtype would not be raised on XLA + @onlyOnCPUAndCUDA + def test_minmax_illegal_dtype(self, device): + x = torch.randn(5, 5, dtype=torch.float32, device=device) + valid_values = torch.empty(5, dtype=torch.float32, device=device) + valid_indices = torch.empty(5, dtype=torch.long, device=device) + illegal_values = torch.empty(5, dtype=torch.int, device=device) + illegal_indices = torch.empty(5, dtype=torch.double, device=device) + torch.max(x, dim=0, out=(valid_values, valid_indices)) + torch.min(x, dim=0, out=(valid_values, valid_indices)) + torch.amax(x, dim=0, out=valid_values) + torch.amin(x, dim=0, out=valid_values) + rmsg = r'scalar type|dtype' + with self.assertRaisesRegex(RuntimeError, rmsg): + torch.max(x, dim=0, out=(illegal_values, valid_indices)) + with self.assertRaisesRegex(RuntimeError, rmsg): + torch.min(x, dim=0, out=(illegal_values, valid_indices)) + with self.assertRaisesRegex(RuntimeError, rmsg): + torch.amax(x, dim=0, out=illegal_values) + with self.assertRaisesRegex(RuntimeError, rmsg): + torch.amin(x, dim=0, out=illegal_values) + with self.assertRaisesRegex(RuntimeError, rmsg): + torch.max(x, dim=0, out=(valid_values, illegal_indices)) + with self.assertRaisesRegex(RuntimeError, rmsg): + torch.min(x, dim=0, out=(valid_values, illegal_indices)) + with self.assertRaisesRegex(RuntimeError, rmsg): + torch.max(x, dim=0, out=(illegal_values, illegal_indices)) + with self.assertRaisesRegex(RuntimeError, rmsg): + torch.min(x, dim=0, out=(illegal_values, illegal_indices)) + + @dtypes(torch.float, torch.double, torch.int64, torch.int32, torch.int16) + @dtypesIfCUDA(torch.float, torch.double, torch.int64, torch.int32, torch.int16, torch.half) + def test_dim_arg_reduction_scalar(self, device, dtype): + example = 4.0 + + x = torch.tensor(example, device=device, dtype=dtype) + self.assertEqual(x.argmax().item(), 0) + self.assertEqual(x.argmax(dim=None).item(), 0) + self.assertEqual(x.argmax(dim=0).item(), 0) + self.assertEqual(x.argmax(dim=0, keepdim=True), torch.tensor(0, dtype=torch.int64)) + + x = torch.tensor(example, device=device, dtype=dtype) + self.assertEqual(x.argmin().item(), 0) + self.assertEqual(x.argmin(dim=None).item(), 0) + self.assertEqual(x.argmin(dim=0).item(), 0) + self.assertEqual(x.argmin(dim=0, keepdim=True), torch.tensor(0, dtype=torch.int64)) + + + def test_dim_reduction(self, device): + example = [[-1, 2, 1], [5, 3, 6]] + + types = [torch.double, + torch.float, + torch.int64, + torch.int32, + torch.int16] + if self.device_type == 'cuda': # 'cpu' and 'xla' do not support half + types.append(torch.half) + + sum_dtype = { + torch.double: torch.double, + torch.float: torch.float, + torch.half: torch.half, + torch.int64: torch.int64, + torch.int32: torch.int64, + torch.int16: torch.int64, + } + + # This won't test for 256bit instructions, since we usually + # only work on 1 cacheline (1024bit) at a time and these + # examples aren't big enough to trigger that. + for dtype in types: + x = torch.tensor(example, device=device, dtype=dtype) + self.assertEqual(x.sum().item(), 16) + self.assertEqual(x.sum(0), torch.tensor([4, 5, 7], dtype=sum_dtype[dtype])) + self.assertEqual(x.sum(1), torch.tensor([2, 14], dtype=sum_dtype[dtype])) + y = torch.tensor(example, device=device, dtype=sum_dtype[dtype]) + torch.sum(x, 0, out=y) + self.assertEqual(x.sum(0), y) + + # Mean not supported for Int types + for dtype in types[:2]: + x = torch.tensor(example, device=device, dtype=dtype) + self.assertEqual(x.mean().item(), 16.0 / 6) + self.assertEqual(x.mean(0), torch.tensor([2.0, 2.5, 7.0 / 2], dtype=dtype)) + self.assertEqual(x.mean(1), torch.tensor([2.0 / 3, 14.0 / 3], dtype=dtype)) + self.assertEqual(x.mean(), x.mean((0, 1))) + + prod_dtype = { + torch.double: torch.double, + torch.float: torch.float, + torch.half: torch.half, + torch.int64: torch.int64, + torch.int32: torch.int64, + torch.int16: torch.int64 + } + + for dtype in types: + x = torch.tensor(example, device=device, dtype=dtype) + self.assertEqual(x.prod().item(), -180) + self.assertEqual(x.prod(0), torch.tensor([-5, 6, 6], dtype=prod_dtype[dtype])) + self.assertEqual(x.prod(1), torch.tensor([-2, 90], dtype=prod_dtype[dtype])) + + for dtype in types: + x = torch.tensor(example, device=device, dtype=dtype) + + self.assertEqual(x.min().item(), -1) + self.assertEqual(x.argmin().item(), 0) + + # TODO: torch.min does not support the same operation as argmin + # for the same case, should we enable it? + self.assertEqual(x.argmin(dim=None).item(), 0) + + self.assertEqual(x.min(0), (torch.tensor([-1, 2, 1], dtype=dtype), + torch.tensor([0, 0, 0], dtype=torch.int64))) + self.assertEqual(x.amin(0), torch.tensor([-1, 2, 1], dtype=dtype)) + self.assertEqual(x.argmin(0), torch.tensor([0, 0, 0], dtype=torch.int64)) + + self.assertEqual(x.min(dim=0, keepdim=True), (torch.tensor([[-1, 2, 1]], dtype=dtype), + torch.tensor([[0, 0, 0]], dtype=torch.int64))) + self.assertEqual(x.amin(dim=0, keepdim=True), torch.tensor([[-1, 2, 1]], dtype=dtype)) + self.assertEqual(x.argmin(dim=0, keepdim=True), torch.tensor([[0, 0, 0]], dtype=torch.int64)) + + self.assertEqual(x.min(1), (torch.tensor([-1, 3], dtype=dtype), + torch.tensor([0, 1], dtype=torch.int64))) + self.assertEqual(x.amin(1), torch.tensor([-1, 3], dtype=dtype)) + self.assertEqual(x.argmin(1), torch.tensor([0, 1], dtype=torch.int64)) + + self.assertEqual(x.min(dim=1, keepdim=True), (torch.tensor([[-1], [3]], dtype=dtype), + torch.tensor([[0], [1]], dtype=torch.int64))) + self.assertEqual(x.amin(dim=1, keepdim=True), torch.tensor([[-1], [3]], dtype=dtype)) + self.assertEqual(x.argmin(dim=1, keepdim=True), torch.tensor([[0], [1]], dtype=torch.int64)) + + # test that non-contiguous tensors work + self.assertEqual(x[:, :2].min().item(), -1) + self.assertEqual(x[:, :2].amin().item(), -1) + self.assertEqual(x[:, :2].argmin().item(), 0) + + for dtype in types: + x = torch.tensor(example, device=device, dtype=dtype) + + self.assertEqual(x.max().item(), 6) + self.assertEqual(x.amax().item(), 6) + self.assertEqual(x.argmax().item(), 5) + + self.assertEqual(x.max(0), (torch.tensor([5, 3, 6], dtype=dtype), + torch.tensor([1, 1, 1], dtype=torch.int64))) + self.assertEqual(x.amax(0), torch.tensor([5, 3, 6], dtype=dtype)) + self.assertEqual(x.argmax(dim=0), torch.tensor([1, 1, 1], dtype=torch.int64)) + + self.assertEqual(x.max(dim=0, keepdim=True), (torch.tensor([[5, 3, 6]], dtype=dtype), + torch.tensor([[1, 1, 1]], dtype=torch.int64))) + self.assertEqual(x.amax(dim=0, keepdim=True), torch.tensor([[5, 3, 6]], dtype=dtype)) + self.assertEqual(x.argmax(dim=0, keepdim=True), torch.tensor([[1, 1, 1]], dtype=torch.int64)) + + self.assertEqual(x.max(1), (torch.tensor([2, 6], dtype=dtype), + torch.tensor([1, 2], dtype=torch.int64))) + self.assertEqual(x.amax(1), torch.tensor([2, 6], dtype=dtype)) + self.assertEqual(x.argmax(dim=1), torch.tensor([1, 2], dtype=torch.int64)) + + self.assertEqual(x.max(1, keepdim=True), (torch.tensor([[2], [6]], dtype=dtype), + torch.tensor([[1], [2]], dtype=torch.int64))) + self.assertEqual(x.amax(1, keepdim=True), torch.tensor([[2], [6]], dtype=dtype)) + self.assertEqual(x.argmax(dim=1, keepdim=True), torch.tensor([[1], [2]], dtype=torch.int64)) + + # test that non-contiguous tensors work + self.assertEqual(x[:, :2].max().item(), 5) + self.assertEqual(x[:, :2].amax().item(), 5) + self.assertEqual(x[:, :2].argmax().item(), 2) + + dim_red_fns = [ + "mean", "median", "nanmedian", "mode", "norm", "prod", + "std", "sum", "var", "max", "min", "amax", "amin"] + + def normfn_attr(t, dim, keepdim=False, out=None): + attr = torch.norm + return attr(t, 2, dim, keepdim, out=out) + + for fn_name in dim_red_fns: + fn_attr = getattr(torch, fn_name) if fn_name != "norm" else normfn_attr + + def fn(x, dim, keepdim=False, out=None): + ans = fn_attr(x, dim, keepdim=keepdim, out=out) + return ans if not istuple(ans) else ans[0] + + def fn_tuple(x, dim, keepdim=False, out=None): + return fn_attr(x, dim, keepdim=keepdim, out=out) + + def test_multidim(x, dim): + self.assertEqual(fn(x, dim).unsqueeze(dim), fn(x, dim, keepdim=True)) + self.assertEqual(x.ndimension() - 1, fn(x, dim).ndimension()) + self.assertEqual(x.ndimension(), fn(x, dim, keepdim=True).ndimension()) + + # general case + x = torch.randn(3, 4, 5, device=device) + dim = random.randint(0, 2) + test_multidim(x, dim) + + # check 1-d behavior + x = torch.randn(1, device=device) + dim = 0 + self.assertEqual(fn(x, dim).shape, ()) + self.assertEqual(fn(x, dim, keepdim=True).shape, (1,)) + + # check reducing of a singleton dimension + dims = [3, 4, 5] + singleton_dim = random.randint(0, 2) + dims[singleton_dim] = 1 + x = torch.randn(dims, device=device) + test_multidim(x, singleton_dim) + + # check reducing with output kwargs + if fn_name in ['median', 'nanmedian', 'mode', 'max', 'min']: + y = torch.randn(5, 3, device=device) + values = torch.randn(5, 3, device=device) + indices = torch.zeros(5, 3, device=device).long() - 1 + fn_tuple(y, 1, keepdim=False, out=(values[:, 1], indices[:, 1])) + values_expected, indices_expected = fn_tuple(y, 1, keepdim=False) + self.assertEqual(values[:, 1], values_expected, + msg='{} values with out= kwarg'.format(fn_name)) + self.assertEqual(indices[:, 1], indices_expected, + msg='{} indices with out= kwarg'.format(fn_name)) + continue + + x = torch.randn(5, 3, device=device) + y = torch.randn(5, 3, device=device) + fn(y, 1, keepdim=False, out=x[:, 1]) + expected = fn(y, 1, keepdim=False) + self.assertEqual(x[:, 1], expected, msg='{} with out= kwarg'.format(fn_name)) + + @onlyCUDA + @largeTensorTest('10GB') + def test_reduction_split(self, device): + # Test reduction when there is a 32bit-indexing split + # https://github.com/pytorch/pytorch/issues/37583 + input_ = torch.randn(5, 14400, 14400, device=device) + result = input_.sum(dim=0) + expect = input_[0] + input_[1] + input_[2] + input_[3] + input_[4] + self.assertEqual(result, expect) + + @onlyCUDA + @dtypes(torch.half, torch.float, torch.double) + def test_reduction_vectorize_along_input_corner(self, device, dtype): + # 1D case: sum + size = 1024 * 1024 * 64 + 3 + shift = 1 + x = torch.zeros(size, dtype=dtype, device=device) + y = x[shift:] + for i in range(100): + x.zero_() + x[i] = 1 + self.assertEqual(x.sum(), 1.0) + if i < shift: + self.assertEqual(y.sum(), 0.0) + else: + self.assertEqual(y.sum(), 1.0) + for i in range(1, 100): + x.zero_() + x[-i] = 1 + self.assertEqual(x.sum(), 1.0) + self.assertEqual(y.sum(), 1.0) + # 1D case: argmax + size = 1024 * 1024 * 64 + 3 + shift = 1 + ysize = size - shift + x = torch.zeros(size, dtype=dtype, device=device) + y = x[shift:] + for i in range(100): + x.zero_() + x[i] = 1 + self.assertEqual(x.argmax().item(), i) + if i >= shift: + self.assertEqual(y.argmax().item(), i - shift) + for i in range(1, 100): + x.zero_() + x[-i] = 1 + self.assertEqual(x.argmax().item(), size - i) + self.assertEqual(y.argmax().item(), ysize - i) + # 2D case: sum + size = (7, 1024 * 1024 + 3) + x = torch.zeros(size, dtype=dtype, device=device) + for i in range(100): + x.zero_() + for j in range(7): + x[j][i] = j + xs = x.sum(dim=-1) + for j in range(7): + self.assertEqual(xs[j].item(), float(j)) + for i in range(100): + x.zero_() + for j in range(7): + x[j][-i] = j + xs = x.sum(dim=-1) + for j in range(7): + self.assertEqual(xs[j].item(), float(j)) + # 2D case: max/argmax + size = (7, 1024 * 1024 + 3) + x = torch.zeros(size, dtype=dtype, device=device) + for i in range(100): + x.zero_() + for j in range(7): + x[j][i] = j + 1 + xs1 = x.argmax(dim=-1) + xs2 = x.max(dim=-1).indices + for j in range(7): + self.assertEqual(xs1[j].item(), i) + self.assertEqual(xs2[j].item(), i) + for i in range(1, 100): + x.zero_() + for j in range(7): + x[j][-i] = j + 1 + xs1 = x.argmax(dim=-1) + xs2 = x.max(dim=-1).indices + for j in range(7): + self.assertEqual(xs1[j].item(), size[1] - i) + self.assertEqual(xs2[j].item(), size[1] - i) + # 2D case: min/argmin + size = (7, 1024 * 1024 + 3) + x = torch.zeros(size, dtype=dtype, device=device) + for i in range(100): + x.zero_() + for j in range(7): + x[j][i] = -(j + 1) + xs1 = x.argmin(dim=-1) + xs2 = x.min(dim=-1).indices + for j in range(7): + self.assertEqual(xs1[j].item(), i) + self.assertEqual(xs2[j].item(), i) + for i in range(1, 100): + x.zero_() + for j in range(7): + x[j][-i] = -(j + 1) + xs1 = x.argmin(dim=-1) + xs2 = x.min(dim=-1).indices + for j in range(7): + self.assertEqual(xs1[j].item(), size[1] - i) + self.assertEqual(xs2[j].item(), size[1] - i) + + @onlyCUDA + @dtypes(torch.half, torch.float, torch.double) + def test_reduction_vectorize_along_output(self, device, dtype): + def run_test(input_): + M, N = input_.shape + input_.zero_() + for i in range(min(M, N)): + input_[i][i] = 1 + output1 = input_.argmax(dim=0) + output2 = input_.sum(dim=0) + for i in range(min(M, N)): + self.assertEqual(output1[i], i) + self.assertEqual(output2[i], 1) + # vec 4 + run_test(torch.zeros(64, 64, dtype=dtype, device=device)) + # vec 2 + run_test(torch.zeros(64 * 64 + 2, dtype=dtype, device=device)[2:].view(64, 64)) + run_test(torch.zeros(64, 62, dtype=dtype, device=device)) + run_test(torch.zeros(64, 2, dtype=dtype, device=device)) + # vec 1 + run_test(torch.zeros(64 * 64 + 1, dtype=dtype, device=device)[1:].view(64, 64)) + run_test(torch.zeros(64, 61, dtype=dtype, device=device)) + run_test(torch.zeros(64, 1, dtype=dtype, device=device)) + + @slowTest + def test_argminmax_large_axis(self, device): + # Regression test for gh-32863 + x = torch.zeros(2**31, device=device, dtype=torch.int8) + x[-1] = 1 + self.assertEqual(x.argmax(0), x.shape[0] - 1) + self.assertEqual(x.max(0).indices, x.shape[0] - 1) + x[-1] = -1 + self.assertEqual(x.argmin(0), x.shape[0] - 1) + self.assertEqual(x.min(0).indices, x.shape[0] - 1) + + def test_argminmax_axis_with_dim_one(self, device): + # See: https://github.com/pytorch/pytorch/issues/38922 + n = 32768 + x = torch.zeros(1, n) + self.assertEqual(x.argmax(dim=0), torch.zeros(n, dtype=torch.int64)) + self.assertEqual(x.argmin(dim=0), torch.zeros(n, dtype=torch.int64)) + + self.assertEqual(x.argmax(dim=-2), torch.zeros(n, dtype=torch.int64)) + self.assertEqual(x.argmin(dim=-2), torch.zeros(n, dtype=torch.int64)) + + self.assertEqual(x.argmax(dim=0, keepdim=True), torch.zeros(1, n, dtype=torch.int64)) + self.assertEqual(x.argmin(dim=0, keepdim=True), torch.zeros(1, n, dtype=torch.int64)) + + self.assertEqual(x.argmax(dim=-2, keepdim=True), torch.zeros(1, n, dtype=torch.int64)) + self.assertEqual(x.argmin(dim=-2, keepdim=True), torch.zeros(1, n, dtype=torch.int64)) + + @dtypes(torch.int, torch.long, torch.float, torch.double) + @dtypesIfCUDA(torch.int, torch.long, torch.half, torch.float, torch.double) + def test_median_real_values(self, device, dtype): + # Generate random 0-3D sizes + sizes = [random.sample(range(1, 32), i) for i in range(4) for _ in range(2)] + for size in sizes: + # Create random input tensor + t = torch.randn(size, device=device).type(dtype) + t_numpy = t.cpu().numpy() + res = t.median() + self.assertEqual(res, t.nanmedian()) + k = int((t.numel() - 1) / 2) + self.assertEqual(res, t.view(-1).sort()[0][k]) + if t.numel() % 2 == 1: + # We can only test agains numpy for odd reductions because numpy + # returns the mean of the two medians and torch returns the lower + self.assertEqual(res.cpu().numpy(), np.median(t_numpy)) + for dim in range(t.ndim): + res = t.median(dim, True) + self.assertEqual(res, t.nanmedian(dim, True)) + size = t.size(dim) if t.ndim > 0 else 1 + k = int((size - 1) / 2) + self.assertEqual(res[0], (t.sort(dim)[0]).select(dim, k).unsqueeze_(dim)) + self.assertEqual(res[0], t.gather(dim, res[1])) + if size % 2 == 1: + # We can only test agains numpy for odd reductions because numpy + # returns the mean of the two medians and torch returns the lower + self.assertEqual(res[0].cpu().numpy(), np.median(t_numpy, dim, keepdims=True)) + + @dtypes(torch.float, torch.double) + @dtypesIfCUDA(torch.half, torch.float, torch.double) + def test_median_nan_values(self, device, dtype): + # Generate random 0-3D sizes + sizes = [random.sample(range(1, 32), i) for i in range(4) for _ in range(2)] + for size in sizes: + # Create random input tensor with nan values + t = torch.rand(size, device=device, dtype=dtype) + t.masked_fill_(t < 0.1, float('nan')) + t_numpy = t.cpu().numpy() + for op in [torch.median, torch.nanmedian]: + numpy_op = np.median if op == torch.median else np.nanmedian + res = op(t) + num_nan = t.isnan().sum() + if op == torch.median and num_nan > 0: + k = t.numel() - 1 + else: + k = int((t.numel() - num_nan - 1) / 2) + self.assertEqual(res, t.view(-1).sort()[0][k]) + if (t.numel() - num_nan) % 2 == 1: + # We can only test agains numpy for odd reductions because numpy + # returns the mean of the two medians and torch returns the lower + self.assertEqual(res.item(), numpy_op(t.cpu().numpy())) + for dim in range(t.ndim): + res = op(t, dim, True) + size = t.size(dim) if t.ndim > 0 else 1 + num_nan = t.isnan().sum(dim, True) + if op == torch.median: + k = torch.where(num_nan > 0, size - 1, int((size - 1) / 2)) + else: + k = ((size - num_nan - 1) / 2).type(torch.long) + self.assertEqual(res[0], (t.sort(dim)[0]).gather(dim, k)) + self.assertEqual(res[0], t.gather(dim, res[1])) + # We can only test agains numpy for odd reductions because numpy + # returns the mean of the two medians and torch returns the lower + mask = (size - num_nan) % 2 == 1 + res = res[0].masked_select(mask).cpu() + ref = numpy_op(t_numpy, dim, keepdims=True)[mask.cpu().numpy()] + self.assertEqual(res, torch.from_numpy(ref)) + + def test_median_corner_cases(self, device): + def check(op, a, args, key): + t = torch.tensor(a, device=device) + res = op(t, *args) + if not args: + key = torch.tensor(key, device=device) + else: + if len(key) == 1: + key = torch.tensor(key[0], device=device) + res = res[0] + else: + key = (torch.tensor(key[0], device=device), torch.tensor(key[1], device=device)) + self.assertEqual(res, key) + + nan = float('nan') + check(torch.median, nan, [], nan) + check(torch.nanmedian, nan, [], nan) + check(torch.median, nan, [0], [nan, 0]) + check(torch.nanmedian, nan, [0], [nan, 0]) + check(torch.median, [nan], [0, True], [[nan], [0]]) + check(torch.nanmedian, [nan], [0, True], [[nan], [0]]) + check(torch.median, [nan], [0, True], [[nan], [0]]) + check(torch.nanmedian, [nan], [0, True], [[nan], [0]]) + + # Indices are not deterministic here so can only check values + check(torch.median, [[nan, nan], [1, 2]], [0], [[nan, nan]]) + check(torch.nanmedian, [[nan, nan], [1, 2]], [0], [[1, 2.]]) + check(torch.median, [[nan, nan], [1, 2]], [1], [[nan, 1]]) + check(torch.nanmedian, [[nan, nan], [1, 2]], [1], [[nan, 1.]]) + + # Discontiguous and strided tensors + a = torch.arange(12, device=device) + self.assertEqual(a[::2].median(), torch.tensor(4, device=device)) + self.assertEqual(a[::2].nanmedian(), torch.tensor(4, device=device)) + + a.resize_(3, 4) + self.assertEqual(a.T.median(), torch.tensor(5, device=device)) + self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device)) + self.assertEqual(a[::2, ::2].median(-1)[0], torch.tensor([0, 8], device=device)) + self.assertEqual(a[::2, ::2].nanmedian(-1)[0], torch.tensor([0, 8], device=device)) + + a.resize_(2, 3, 2) + self.assertEqual(a.T.median(), torch.tensor(5, device=device)) + self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device)) + self.assertEqual(a[:, ::2, :].median(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device)) + self.assertEqual(a[:, ::2, :].nanmedian(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device)) + + + @onlyOnCPUAndCUDA + @dtypes(torch.float, torch.double) + def test_quantile(self, device, dtype): + # Generate some random test cases + ops = ['quantile', 'nanquantile'] + inputs = [tuple(np.random.randint(2, 10, size=i)) for i in range(1, 4)] + quantiles = [tuple(np.random.rand(i)) for i in range(0, 5)] + keepdims = [True, False] + + # Add corner cases + inputs.extend([0.75, (1,), (1, 1), (1, 2, 1)]) + inputs.extend([[float('nan')], [[float('nan'), float('nan')], [1, 2]]]) + inputs.extend([[[float('nan'), float('nan')], [float('nan'), 2]]]) + quantiles.extend([0.5, [0., 1.], np.random.rand(10)]) + + # Enumerate all input combinations + for op, x, q, keepdim in product(ops, inputs, quantiles, keepdims): + if type(x) is tuple: + a = torch.randn(x, dtype=dtype, device=device) + # Make some random elements NaN + a.masked_fill_(torch.randint_like(a, 20) == 0, float('nan')) + else: + a = torch.tensor(x, dtype=dtype, device=device) + + q = torch.tensor(q, dtype=dtype, device=device) + + torch_op = getattr(torch, op) + numpy_op = getattr(np, op) + + # Compute quantile along every dimension and flattened tensor + for dim in [None] + list(range(a.ndim)): + result = torch_op(a, q, dim, keepdim) + expected = numpy_op(a.cpu().numpy(), q.cpu().numpy(), dim, keepdims=keepdim) + self.assertEqual(result.cpu(), torch.from_numpy(np.array(expected)).type(result.type())) + + # Test out variation + out = torch.empty_like(result) + torch_op(a, q, dim, keepdim, out=out) + self.assertEqual(out.cpu(), result.cpu()) + + def test_quantile_backward(self, device): + def check(a, q, dim, expected_grad, ops=(torch.quantile, torch.nanquantile)): + for op in ops: + t = torch.tensor(a, device=device, requires_grad=True) + op(t, torch.tensor(q, device=device), dim).sum().backward() + self.assertEqual(t.grad, expected_grad) + + check([1., 2, 3], 0.5, 0, [0, 1, 0]) + check([1., 2, 3, 4], 0.5, 0, [0, 0.5, 0.5, 0]) + check([3., 1, 4, 2], 0.5, 0, [0.5, 0, 0, 0.5]) + check([1., 2, 3, 4], [0.25, 0.5, 0.75], 0, [0.25, 1.25, 1.25, 0.25]) + check([[1., 2], [2, 1]], 0., 0, [[1, 0], [0, 1]]) + check([[1., 2], [4, 3]], 1., 1, [[0, 1], [1, 0]]) + check([1, float('nan'), 2], 0.5, 0, [0, 1, 0], [torch.quantile]) + check([1, float('nan'), 2], 0.5, 0, [0.5, 0, 0.5], [torch.nanquantile]) + + def test_quantile_error(self, device): + def check(a, q, args, kwargs, message): + with self.assertRaisesRegex(RuntimeError, r'quantile\(\) ' + message): + at = torch.tensor(a, device=device) + qt = torch.tensor(q, device=device) if isinstance(q, list) else q + torch.quantile(at, qt, *args, **kwargs) + + check([], 0.5, [], {}, r'input tensor must be non-empty') + check([1.], [[1.]], [], {}, r'q must be a scalar or 1D tensor') + check([1], 0.5, [], {}, r'input tensor must be either float or double dtype') + check([1.], [1], [], {}, r'q tensor must be same dtype as the input tensor') + check([1.], -1., [], {}, r'q must be in the range \[0, 1\] but got -1') + check([1.], 1.1, [], {}, r'q must be in the range \[0, 1\] but got 1.1') + check([1.], 0.5, [], {'out': torch.empty([], dtype=torch.float64, device=device)}, + r'out tensor must be same dtype as the input tensor') + + if self.device_type == "cpu": + check([1.], [0.5, 1.1, -1], [], {}, r'q values must be in the range \[0, 1\]') + + if self.device_type == "cuda": + with self.assertRaisesRegex( + RuntimeError, r'quantile\(\) q tensor must be on the same device as the input tensor'): + torch.randn(1, device=device).quantile(torch.tensor(0.5)) + with self.assertRaisesRegex( + RuntimeError, r'quantile\(\) out tensor must be on the same device as the input tensor'): + torch.quantile(torch.randn(1, device=device), 0.5, out=torch.scalar_tensor(1)) + + def test_std_mean(self, device): + x = torch.rand(100, 50, 20, device=device) + for dim in range(x.dim()): + for unbiased in [False, True]: + for keepdim in [False, True]: + std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) + std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim) + mean2 = x.mean(dim=dim, keepdim=keepdim) + self.assertEqual(std1, std2) + self.assertEqual(mean1, mean2) + + def test_std_mean_all_dims(self, device): + x = torch.rand(100, 50, 20, device=device) + for unbiased in [False, True]: + std1, mean1 = torch.std_mean(x, unbiased=unbiased) + std2 = x.std(unbiased=unbiased) + mean2 = x.mean() + self.assertEqual(std1, std2) + self.assertEqual(mean1, mean2) + + def test_var_mean(self, device): + x = torch.rand(100, 300, 50, device=device) + for dim in range(x.dim()): + for unbiased in [False, True]: + for keepdim in [False, True]: + var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) + var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim) + mean2 = x.mean(dim=dim, keepdim=keepdim) + self.assertEqual(var1, var2) + self.assertEqual(mean1, mean2) + + def test_var_mean_all_dims(self, device): + x = torch.rand(100, 50, 20, device=device) + for unbiased in [False, True]: + var1, mean1 = torch.var_mean(x, unbiased=unbiased) + var2 = x.var(unbiased=unbiased) + mean2 = x.mean() + self.assertEqual(var1, var2) + self.assertEqual(mean1, mean2) + + def test_std_mean_some_dims(self, device): + sizes = (4, 6, 7, 5, 3) + dims = len(sizes) + x = torch.rand(sizes, device=device) + for num_of_dims in range(2, dims): + dim_list = list(combinations(list(range(dims)), r=num_of_dims)) + for dim in dim_list: + for unbiased in [False, True]: + for keepdim in [False, True]: + std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) + std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim) + mean2 = x.mean(dim=dim, keepdim=keepdim) + self.assertEqual(std1, std2) + self.assertEqual(mean1, mean2) + + def _compare_std_var_with_numpy(self, op, device, dtype, input, dim, + keepdim, unbiased, use_out): + a = input.cpu().numpy() if input.dtype is not torch.bfloat16 else input.float().cpu().numpy() + numpy_kwargs = { + 'axis' : dim, + 'keepdims' : keepdim, + 'ddof' : 1 if unbiased else 0, + } + + if dim is None: + del numpy_kwargs['axis'] + del numpy_kwargs['keepdims'] + + if op == 'var': + torch_op = torch.var + numpy_op = np.var + elif op == 'std': + torch_op = torch.std + numpy_op = np.std + else: + self.fail("Unknown op!") + + numpy_result = numpy_op(a, **numpy_kwargs) + + if dim is None and use_out is False: + torch_result = torch_op(input, unbiased) + elif dim is not None and use_out is False: + torch_result = torch_op(input, dim, unbiased, keepdim) + elif dim is not None and use_out is True: + out = torch.empty(0, device=device, dtype=dtype) + torch_result = torch_op(input, dim, unbiased, keepdim, out=out) + else: + out = torch.empty(0, device=device, dtype=dtype) + try: + torch_result = torch_op(input, dim, unbiased, keepdim, out=out) + except RuntimeError: + return + self.fail("Failed to hit RuntimeError!") + + self.assertEqual(torch_result, numpy_result, exact_dtype=False) + + @dtypesIfCUDA(torch.float, torch.double, torch.cfloat, torch.cdouble) + @dtypes(torch.float, torch.double) + def test_var_vs_numpy(self, device, dtype): + _size = (20, 20) + + for test_case in product((torch.randn(_size, device=device, dtype=dtype),), + (None, 0, 1), + (False, True), + (False, True), + (False, True),): + self._compare_std_var_with_numpy('var', device, dtype, *test_case) + + @dtypesIfCUDA(torch.float, torch.double, torch.cfloat, torch.cdouble) + @dtypes(torch.float, torch.double) + def test_std_vs_numpy(self, device, dtype): + _size = (20, 20) + + for test_case in product((torch.randn(_size, device=device, dtype=dtype),), + (None, 0, 1), + (False, True), + (False, True), + (False, True),): + self._compare_std_var_with_numpy('std', device, dtype, *test_case) + + def test_amin_amax_some_dims(self, device): + sizes = (4, 6, 7, 5, 3) + dims = len(sizes) + x = torch.rand(sizes, device=device) + for num_of_dims in range(2, dims): + dim_list = list(combinations(list(range(dims)), r=num_of_dims)) + for dim in dim_list: + for keepdim in [False, True]: + amin1 = torch.amin(x, dim=dim, keepdim=keepdim) + amax1 = torch.amax(x, dim=dim, keepdim=keepdim) + amin2 = x + amax2 = x + for i, d in enumerate(dim): + if not keepdim: + d -= i + amin2 = torch.amin(amin2, dim=d, keepdim=keepdim) + amax2 = torch.amax(amax2, dim=d, keepdim=keepdim) + self.assertEqual(amin1, amin2) + self.assertEqual(amax1, amax2) + + @onlyCUDA + @expectedAlertNondeterministic('_histc_cuda', fn_has_device_arg=False) + def test_histc_alert_nondeterministic(self, device): + torch.histc(torch.tensor([], device=device), min=0, max=3) + + def test_histc(self, device): + # negative nbins throws + with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'): + torch.histc(torch.tensor([1], dtype=torch.float, device=device), bins=-1) + # empty tensor + actual = torch.histc(torch.tensor([], device=device), min=0, max=3) + expected = torch.zeros(100, dtype=torch.float, device=device) + self.assertEqual(expected, actual) + + # without nbins + actual = torch.histc( + torch.tensor([2, 5], dtype=torch.float, device=device)) + expected = torch.zeros(100, dtype=torch.float, device=device) + expected[0] = 1 + expected[99] = 1 + self.assertEqual(expected, actual) + # tensor with the same element + actual = torch.histc(torch.ones(5, dtype=torch.float, device=device), bins=5) + self.assertEqual( + torch.tensor([0, 0, 5, 0, 0], dtype=torch.float, device=device), + actual) + # no element falls between [min, max] + actual = torch.histc( + torch.ones(5, dtype=torch.float, device=device), bins=5, min=2, max=3) + self.assertEqual( + torch.tensor([0, 0, 0, 0, 0], dtype=torch.float, device=device), + actual) + # element falls below min + integral bin size and + actual = torch.histc( + torch.tensor([2, 4, 2, 2, 5, 4], dtype=torch.float, device=device), + bins=5, min=1, max=5) + self.assertEqual( + torch.tensor([0, 3, 0, 2, 1], dtype=torch.float, device=device), + actual) + # non-integral bin size + actual = torch.histc( + torch.tensor([1, 2, 1], dtype=torch.float, device=device), + bins=4, min=0, max=3) + self.assertEqual( + torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device), + actual) + # double input + actual = torch.histc( + torch.tensor([1, 2, 1], dtype=torch.double, device=device), bins=4, min=0, max=3) + self.assertEqual( + torch.tensor([0, 2, 1, 0], dtype=torch.double, device=device), + actual) + self.assertEqual(actual.dtype, torch.double) + # mixed input + actual = torch.histc( + torch.tensor([1., 2, 1], dtype=torch.float, device=device), + bins=4, min=0, max=3) + self.assertEqual( + torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device), + actual) + self.assertEqual(actual.dtype, torch.float) + # scalar input and 1 bin -- should return a 1-dimensional tensor, not a scalar. + actual = torch.histc( + torch.tensor(0, dtype=torch.float, device=device), + bins=1, min=0, max=3) + self.assertEqual( + torch.tensor([1], dtype=torch.float, device=device), + actual) + # tensors with inf; min, max not provided -- should throw a RuntimeError + with self.assertRaisesRegex(RuntimeError, r'range of \[inf, inf\] is not finite'): + torch.histc(torch.tensor([float("inf")], dtype=torch.float, device=device)) + with self.assertRaisesRegex(RuntimeError, r'range of \[1, inf\] is not finite'): + torch.histc(torch.tensor([1., 2., float("inf")], dtype=torch.float, device=device)) + # tensors with inf; min, max provided + self.assertEqual( + torch.histc(torch.tensor([float("inf")], dtype=torch.float, device=device), + bins=1, min=0, max=3), + torch.tensor([0], dtype=torch.float, device=device)) + self.assertEqual( + torch.histc(torch.tensor([1., 2., float("inf")], dtype=torch.float, device=device), + bins=4, max=3), + torch.tensor([0, 1, 1, 0], dtype=torch.float, device=device)) + # tensor with nan -- should throw a RuntimeError + with self.assertRaisesRegex(RuntimeError, r'range of \[nan, nan\] is not finite'): + torch.histc(torch.tensor([float("nan")], dtype=torch.float, device=device)) + # tensors with min > max -- should throw a RuntimeError + with self.assertRaisesRegex(RuntimeError, "max must be larger than min"): + torch.histc(torch.tensor([1., 2., 3.], dtype=torch.float, device=device), + bins=4, min=5, max=1) + + # test against numpy.histogram() + def test_against_np(tensor, bins=100, min=0, max=0): + if min == 0 and max == 0: + min = tensor.min().item() + max = tensor.max().item() + nparr = tensor.cpu().numpy() + actual = torch.histc(tensor, bins=bins, min=min, max=max) + expected = torch.from_numpy(np.histogram(nparr, bins=bins, range=(min, max))[0]) + actual_cpu = actual.cpu() + # NB: Numpy returns a int64 tensor, like normal people... + self.assertEqual(actual, expected.to(actual_cpu)) + + test_against_np(torch.tensor([1., 2, 1], device=device)) + test_against_np(torch.randn(5000, device=device)) + + # Test bins arg + test_against_np(torch.randn(301, device=device), bins=10) + + # Test truncated range + test_against_np(torch.randn(201, device=device), min=0.1, max=1) + + noncontig = torch.randn(100, 3, device=device)[:, 2] + test_against_np(noncontig) + + multidim = torch.randn(3, 5, 7, 2, device=device) + test_against_np(multidim) + + expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2) + test_against_np(expanded) + + def test_reduction_empty(self, device): + fns_to_test = [ + # name, function, identity + ('max', torch.max, None), + ('amax', torch.amax, None), + ('argmax', torch.argmax, None), + ('min', torch.min, None), + ('amin', torch.amin, None), + ('argmin', torch.argmin, None), + ('mode', torch.mode, None), + ('kthvalue', lambda *args, **kwargs: torch.kthvalue(*args, k=1, **kwargs), None), + ('prod', torch.prod, 1.), + ('sum', torch.sum, 0.), + ('norm', torch.norm, 0.), + ('mean', torch.mean, nan), + ('var', torch.var, nan), + ('std', torch.std, nan), + ('logsumexp', torch.logsumexp, -inf), + ] + + shape = (2, 0, 4) + x = torch.randn(shape, device=device) + + for fn in [torch.max, torch.min]: + ident_err = 'operation does not have an identity' + self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x)) + + # median and nanmedian have been updated to follow the new convention for empty tensors + # where it should only fail if the dimension being reduced has size 0. + for name, fn in [('median', torch.median), ('nanmedian', torch.nanmedian)]: + ident_err = 'does not have an identity' + self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1)) + self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1, keepdim=True)) + self.assertEqual(fn(x, dim=0)[0].shape, (shape[1], shape[2])) + self.assertEqual(fn(x, dim=0, keepdim=True)[0].shape, (1, shape[1], shape[2])) + self.assertEqual(fn(x, dim=2)[0].shape, (shape[0], shape[1])) + self.assertEqual(fn(x, dim=2, keepdim=True)[0].shape, (shape[0], shape[1], 1)) + + for item in fns_to_test: + name, fn, identity = item + if identity is None: + ident_err = 'does not have an identity' + + # Reductions over non-zero dimensions should work even for empty tensors + # See https://github.com/pytorch/pytorch/issues/34907 for a discussion on this. + self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=2)) + self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=2, keepdim=True)) + + self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1)) + self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1, keepdim=True)) + else: + self.assertEqual(torch.empty((2, 0), device=device), fn(x, dim=2)) + self.assertEqual(torch.empty((2, 0, 1), device=device), fn(x, dim=2, keepdim=True)) + # assertEqual doesn't work with inf, -inf, nan and two tensors. + check = (torch.testing.assert_allclose if math.isnan(identity) or math.isinf(identity) else + self.assertEqual) + check(torch.full((2, 4), identity, device=device), fn(x, dim=1)) + check(torch.full((2, 1, 4), identity, device=device), fn(x, dim=1, keepdim=True)) + try: + check(torch.full((), identity, device=device), fn(x)) + except TypeError as err: + # ignore if there is no allreduce. + self.assertTrue('dim' in str(err)) + + for dtype in torch.testing.get_all_dtypes(include_half=True, include_bfloat16=False, + include_bool=True, include_complex=True): + # Refer: [all, any uint8 compatibility] + if dtype == torch.uint8: + out_dtype = torch.uint8 + else: + out_dtype = torch.bool # output of all/any is bool irrespective of input dtype + + # any + xb = x.to(dtype) + yb = x.to(dtype) + self.assertEqual((2, 0), xb.any(2).shape) + self.assertEqual((2, 0, 1), xb.any(2, keepdim=True).shape) + self.assertEqual(torch.zeros((2, 4), device=device, dtype=out_dtype), xb.any(1)) + self.assertEqual(torch.zeros((2, 1, 4), device=device, dtype=out_dtype), xb.any(1, keepdim=True)) + self.assertEqual(torch.zeros((), device=device, dtype=out_dtype), xb.any()) + + # all + self.assertEqual((2, 0), xb.all(2).shape) + self.assertEqual((2, 0, 1), xb.all(2, keepdim=True).shape) + self.assertEqual(torch.ones((2, 4), device=device, dtype=out_dtype), xb.all(1)) + self.assertEqual(torch.ones((2, 1, 4), device=device, dtype=out_dtype), xb.all(1, keepdim=True)) + self.assertEqual(torch.ones((), device=device, dtype=out_dtype), xb.all()) + + +instantiate_device_type_tests(TestReductions, globals()) + +if __name__ == '__main__': + run_tests() diff --git a/test/test_serialization.py b/test/test_serialization.py index 9b30e46905406..916f133c3fe17 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -17,7 +17,8 @@ from torch.serialization import check_module_version_greater_or_equal from torch.testing._internal.common_utils import TestCase, IS_WINDOWS, \ - TEST_DILL, run_tests, download_file, BytesIOContext + TEST_DILL, run_tests, download_file, BytesIOContext, TemporaryFileName +from torch.testing._internal.common_device_type import instantiate_device_type_tests # These tests were all copied from `test/test_torch.py` at some point, so see # the actual blame, see this revision @@ -112,7 +113,6 @@ def _test_serialization_assert(self, b, c): rootview = c[8] self.assertEqual(rootview.data_ptr(), c[0].data_ptr()) - @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") def test_serialization_zipfile_utils(self): data = { 'a': b'12039810948234589', @@ -137,24 +137,22 @@ def test(name_or_buffer): with tempfile.NamedTemporaryFile() as f: test(f) - with tempfile.NamedTemporaryFile() as f: - test(f.name) + with TemporaryFileName() as fname: + test(fname) test(io.BytesIO()) def test_serialization(self): # Test serialization with a real file b = self._test_serialization_data() - for use_name in (False, True): - # Passing filename to torch.save(...) will cause the file to be opened twice, - # which is not supported on Windows - if sys.platform == "win32" and use_name: - continue - with tempfile.NamedTemporaryFile() as f: - handle = f if not use_name else f.name - torch.save(b, handle) - f.seek(0) - c = torch.load(handle) + with tempfile.NamedTemporaryFile() as f: + torch.save(b, f) + f.seek(0) + c = torch.load(f) + self._test_serialization_assert(b, c) + with TemporaryFileName() as fname: + torch.save(b, fname) + c = torch.load(fname) self._test_serialization_assert(b, c) # test non-ascii encoding of bytes arrays/strings # The following bytes are produced by serializing @@ -191,7 +189,6 @@ def test_serialization_filelike(self): c = torch.load(f) self._test_serialization_assert(b, c) - @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows") def test_serialization_fake_zip(self): data = [ ord('P'), @@ -204,14 +201,15 @@ def test_serialization_fake_zip(self): t = torch.tensor(data, dtype=torch.uint8) with tempfile.NamedTemporaryFile() as f: - torch.save(t, f.name) + torch.save(t, f) # If this check is False for all Python versions (i.e. the fix # has been backported), this test and torch.serialization._is_zipfile # can be deleted self.assertTrue(zipfile.is_zipfile(f)) self.assertFalse(torch.serialization._is_zipfile(f)) - self.assertEqual(torch.load(f.name), t) + f.seek(0) + self.assertEqual(torch.load(f), t) def test_serialization_gzip(self): # Test serialization with gzip file @@ -275,17 +273,16 @@ def test_serialization_offset_gzip(self): self.assertTrue(torch.equal(a, b)) self.assertEqual(i, j) - @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") def test_serialization_sparse(self): x = torch.zeros(3, 3) x[1][1] = 1 x = x.to_sparse() with tempfile.NamedTemporaryFile() as f: - torch.save({"tensor": x}, f.name) - y = torch.load(f.name) + torch.save({"tensor": x}, f) + f.seek(0) + y = torch.load(f) self.assertEqual(x, y["tensor"]) - @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") def test_serialization_sparse_invalid(self): x = torch.zeros(3, 3) x[1][1] = 1 @@ -308,11 +305,12 @@ def __reduce_ex__(self, proto): self.tensor.size()))) with tempfile.NamedTemporaryFile() as f: - torch.save({"spoofed": TensorSerializationSpoofer(x)}, f.name) + torch.save({"spoofed": TensorSerializationSpoofer(x)}, f) + f.seek(0) with self.assertRaisesRegex( RuntimeError, "size is inconsistent with indices"): - y = torch.load(f.name) + y = torch.load(f) def test_serialize_device(self): device_str = ['cpu', 'cpu:0', 'cuda', 'cuda:0'] @@ -583,17 +581,24 @@ def wrapper(*args, **kwargs): def __exit__(self, *args, **kwargs): torch.save = self.torch_save -class TestBothSerialization(TestCase, SerializationMixin): - def test_serialization_new_format_old_format_compat(self): - x = [torch.ones(200, 200) for i in range(30)] - torch.save(x, "big_tensor.zip", _use_new_zipfile_serialization=True) - x_new_load = torch.load("big_tensor.zip") - self.assertEqual(x, x_new_load) +class TestBothSerialization(TestCase): + @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") + def test_serialization_new_format_old_format_compat(self, device): + x = [torch.ones(200, 200, device=device) for i in range(30)] + + def test(f_new, f_old): + torch.save(x, f_new, _use_new_zipfile_serialization=True) + f_new.seek(0) + x_new_load = torch.load(f_new) + self.assertEqual(x, x_new_load) + + torch.save(x, f_old, _use_new_zipfile_serialization=False) + f_old.seek(0) + x_old_load = torch.load(f_old) + self.assertEqual(x_old_load, x_new_load) - torch.save(x, "big_tensor.zip", _use_new_zipfile_serialization=False) - x_old_load = torch.load("big_tensor.zip") - self.assertEqual(x_old_load, x_new_load) - os.remove("big_tensor.zip") + with tempfile.NamedTemporaryFile() as f_new, tempfile.NamedTemporaryFile() as f_old: + test(f_new, f_old) class TestOldSerialization(TestCase, SerializationMixin): @@ -693,7 +698,6 @@ def run(self, *args, **kwargs): class TestSerialization(TestCase, SerializationMixin): - @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") def test_serialization_zipfile(self): data = self._test_serialization_data() @@ -708,12 +712,12 @@ def test(name_or_buffer): with tempfile.NamedTemporaryFile() as f: test(f) - with tempfile.NamedTemporaryFile() as f: - test(f.name) + + with TemporaryFileName() as fname: + test(fname) test(io.BytesIO()) - @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") def test_serialization_zipfile_actually_jit(self): with tempfile.NamedTemporaryFile() as f: torch.jit.save(torch.jit.script(torch.nn.Linear(3, 4)), f) @@ -729,12 +733,11 @@ def test_serialization_2gb_file(self): f.seek(0) state = torch.load(f) - @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") def test_pathlike_serialization(self): model = torch.nn.Conv2d(20, 3200, kernel_size=3) - with tempfile.NamedTemporaryFile() as f: - path = pathlib.Path(f.name) + with TemporaryFileName() as fname: + path = pathlib.Path(fname) torch.save(model, path) torch.load(path) @@ -742,6 +745,7 @@ def run(self, *args, **kwargs): with serialization_method(use_zip=True): return super(TestSerialization, self).run(*args, **kwargs) +instantiate_device_type_tests(TestBothSerialization, globals()) if __name__ == '__main__': run_tests() diff --git a/test/test_shape_ops.py b/test/test_shape_ops.py new file mode 100644 index 0000000000000..1d65eb163e787 --- /dev/null +++ b/test/test_shape_ops.py @@ -0,0 +1,620 @@ +import torch +import numpy as np + +from itertools import product, combinations, permutations +from functools import partial +import random + +from torch._six import nan +from torch.testing._internal.common_utils import ( + TestCase, run_tests, make_tensor, torch_to_numpy_dtype_dict) +from torch.testing._internal.common_methods_invocations import shape_funcs +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, onlyCPU, dtypes, onlyOnCPUAndCUDA, + dtypesIfCPU, dtypesIfCUDA, ops) + +# TODO: replace with make_tensor +def _generate_input(shape, dtype, device, with_extremal): + if shape == (): + x = torch.tensor((), dtype=dtype, device=device) + else: + if dtype.is_floating_point or dtype.is_complex: + # work around torch.randn not being implemented for bfloat16 + if dtype == torch.bfloat16: + x = torch.randn(*shape, device=device) * random.randint(30, 100) + x = x.to(torch.bfloat16) + else: + x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100) + x[torch.randn(*shape) > 0.5] = 0 + if with_extremal and dtype.is_floating_point: + # Use extremal values + x[torch.randn(*shape) > 0.5] = float('nan') + x[torch.randn(*shape) > 0.5] = float('inf') + x[torch.randn(*shape) > 0.5] = float('-inf') + elif with_extremal and dtype.is_complex: + x[torch.randn(*shape) > 0.5] = complex('nan') + x[torch.randn(*shape) > 0.5] = complex('inf') + x[torch.randn(*shape) > 0.5] = complex('-inf') + elif dtype == torch.bool: + x = torch.zeros(shape, dtype=dtype, device=device) + x[torch.randn(*shape) > 0.5] = True + else: + x = torch.randint(15, 100, shape, dtype=dtype, device=device) + + return x + +class TestShapeOps(TestCase): + + # TODO: update to work on CUDA, too + @onlyCPU + def test_unbind(self, device): + x = torch.rand(2, 3, 4, 5) + for dim in range(4): + res = torch.unbind(x, dim) + res2 = x.unbind(dim) + self.assertEqual(x.size(dim), len(res)) + self.assertEqual(x.size(dim), len(res2)) + for i in range(dim): + self.assertEqual(x.select(dim, i), res[i]) + self.assertEqual(x.select(dim, i), res2[i]) + + # TODO: update to work on CUDA, too? + @onlyCPU + def test_tolist(self, device): + list0D = [] + tensor0D = torch.Tensor(list0D) + self.assertEqual(tensor0D.tolist(), list0D) + + table1D = [1, 2, 3] + tensor1D = torch.Tensor(table1D) + storage = torch.Storage(table1D) + self.assertEqual(tensor1D.tolist(), table1D) + self.assertEqual(storage.tolist(), table1D) + self.assertEqual(tensor1D.tolist(), table1D) + self.assertEqual(storage.tolist(), table1D) + + table2D = [[1, 2], [3, 4]] + tensor2D = torch.Tensor(table2D) + self.assertEqual(tensor2D.tolist(), table2D) + + tensor3D = torch.Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + tensorNonContig = tensor3D.select(1, 1) + self.assertFalse(tensorNonContig.is_contiguous()) + self.assertEqual(tensorNonContig.tolist(), [[3, 4], [7, 8]]) + + @dtypes(torch.int64, torch.float, torch.complex128) + def test_movedim_invalid(self, device, dtype): + shape = self._rand_shape(4, min_size=5, max_size=10) + x = _generate_input(shape, dtype, device, False) + + for fn in [torch.movedim, torch.moveaxis]: + # Invalid `source` and `destination` dimension + with self.assertRaisesRegex(IndexError, "Dimension out of range"): + fn(x, 5, 0) + + with self.assertRaisesRegex(IndexError, "Dimension out of range"): + fn(x, 0, 5) + + # Mismatch in size of `source` and `destination` + with self.assertRaisesRegex(RuntimeError, "movedim: Invalid source or destination dims:"): + fn(x, (1, 0), (0, )) + + with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"): + fn(x, (0, 0), (0, 1)) + + with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"): + fn(x, (0, 1, 0), (0, 1, 2)) + + with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"): + fn(x, (0, 1), (1, 1)) + + with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"): + fn(x, (0, 1, 2), (1, 0, 1)) + + @dtypes(torch.int64, torch.float, torch.complex128) + def test_movedim(self, device, dtype): + for fn in [torch.moveaxis, torch.movedim]: + for nd in range(5): + shape = self._rand_shape(nd, min_size=5, max_size=10) + x = _generate_input(shape, dtype, device, with_extremal=False) + for random_negative in [True, False]: + for src_dim, dst_dim in permutations(range(nd), r=2): + random_prob = random.random() + + if random_negative and random_prob > 0.66: + src_dim = src_dim - nd + elif random_negative and random_prob > 0.33: + dst_dim = dst_dim - nd + elif random_negative: + src_dim = src_dim - nd + dst_dim = dst_dim - nd + + # Integer `source` and `destination` + torch_fn = partial(fn, source=src_dim, destination=dst_dim) + np_fn = partial(np.moveaxis, source=src_dim, destination=dst_dim) + self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) + + if nd == 0: + continue + + def make_index_negative(sequence, idx): + sequence = list(sequence) + sequence[random_idx] = sequence[random_idx] - nd + return tuple(src_sequence) + + for src_sequence in permutations(range(nd), r=random.randint(1, nd)): + # Sequence `source` and `destination` + dst_sequence = tuple(random.sample(range(nd), len(src_sequence))) + + # Randomly change a dim to a negative dim representation of itself. + random_prob = random.random() + if random_negative and random_prob > 0.66: + random_idx = random.randint(0, len(src_sequence) - 1) + src_sequence = make_index_negative(src_sequence, random_idx) + elif random_negative and random_prob > 0.33: + random_idx = random.randint(0, len(src_sequence) - 1) + dst_sequence = make_index_negative(dst_sequence, random_idx) + elif random_negative: + random_idx = random.randint(0, len(src_sequence) - 1) + dst_sequence = make_index_negative(dst_sequence, random_idx) + random_idx = random.randint(0, len(src_sequence) - 1) + src_sequence = make_index_negative(src_sequence, random_idx) + + torch_fn = partial(fn, source=src_sequence, destination=dst_sequence) + np_fn = partial(np.moveaxis, source=src_sequence, destination=dst_sequence) + self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) + + # Move dim to same position + x = torch.randn(2, 3, 5, 7, 11) + torch_fn = partial(fn, source=(0, 1), destination=(0, 1)) + np_fn = partial(np.moveaxis, source=(0, 1), destination=(0, 1)) + self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) + + torch_fn = partial(fn, source=1, destination=1) + np_fn = partial(np.moveaxis, source=1, destination=1) + self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) + + # Empty Sequence + torch_fn = partial(fn, source=(), destination=()) + np_fn = partial(np.moveaxis, source=(), destination=()) + self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) + + @dtypes(torch.float, torch.bool) + def test_diag(self, device, dtype): + if dtype is torch.bool: + x = torch.rand(100, 100, device=device) >= 0.5 + else: + x = torch.rand(100, 100, dtype=dtype, device=device) + + res1 = torch.diag(x) + res2 = torch.tensor((), dtype=dtype, device=device) + torch.diag(x, out=res2) + self.assertEqual(res1, res2) + + def test_diagonal(self, device): + x = torch.randn((100, 100), device=device) + result = torch.diagonal(x) + expected = torch.diag(x) + self.assertEqual(result, expected) + + x = torch.randn((100, 100), device=device) + result = torch.diagonal(x, 17) + expected = torch.diag(x, 17) + self.assertEqual(result, expected) + + @onlyCPU + @dtypes(torch.float) + def test_diagonal_multidim(self, device, dtype): + x = torch.randn(10, 11, 12, 13, dtype=dtype, device=device) + xn = x.numpy() + for args in [(2, 2, 3), + (2,), + (-2, 1, 2), + (0, -2, -1)]: + result = torch.diagonal(x, *args) + expected = xn.diagonal(*args) + self.assertEqual(expected.shape, result.shape) + self.assertEqual(expected, result) + # test non-continguous + xp = x.permute(1, 2, 3, 0) + result = torch.diagonal(xp, 0, -2, -1) + expected = xp.numpy().diagonal(0, -2, -1) + self.assertEqual(expected.shape, result.shape) + self.assertEqual(expected, result) + + @onlyOnCPUAndCUDA + @dtypesIfCPU(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_half=False, + include_bfloat16=False)) + @dtypesIfCUDA(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=False)) + def test_trace(self, device, dtype): + def test(shape): + tensor = make_tensor(shape, device, dtype, low=-9, high=9) + expected_dtype = tensor.sum().dtype + expected_dtype = torch_to_numpy_dtype_dict[expected_dtype] + + result = np.trace(tensor.cpu().numpy(), dtype=expected_dtype) + expected = torch.tensor(result, device=device) + self.assertEqual(tensor.trace(), expected) + + shapes = ( + [10, 1], + [1, 10], + [100, 100], + [20, 100], + [100, 20], + ) + for shape in shapes: + test(shape) + + def generate_clamp_baseline(self, device, dtype, *, min_vals, max_vals, with_nans): + """ + Creates a random tensor for a given device and dtype, and computes the expected clamped + values given the min_vals and/or max_vals. + If with_nans is provided, then some values are randomly set to nan. + """ + X = torch.rand(100, device=device).mul(50).add(-25) # uniform in [-25, 25] + X = X.to(dtype) + if with_nans: + mask = torch.randint(0, 2, X.shape, dtype=torch.bool, device=device) + X[mask] = nan + + if isinstance(min_vals, torch.Tensor): + min_vals = min_vals.cpu().numpy() + + if isinstance(max_vals, torch.Tensor): + max_vals = max_vals.cpu().numpy() + + # Use NumPy implementation as reference + X_clamped = torch.tensor(np.clip(X.cpu().numpy(), a_min=min_vals, a_max=max_vals), device=device) + return X, X_clamped + + # Tests clamp and its alias, clip + @dtypes(torch.int64, torch.float32) + def test_clamp(self, device, dtype): + op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_, + torch.clip, torch.Tensor.clip, torch.Tensor.clip_) + + # min/max argument product + args = product((-10, None), (10, None)) + + for op in op_list: + for min_val, max_val in args: + if min_val is None and max_val is None: + continue + + X, Y_expected = self.generate_clamp_baseline(device, dtype, + min_vals=min_val, + max_vals=max_val, + with_nans=False) + + # Test op + X1 = X.clone() # So that the in-place ops do not change X + Y_actual = op(X1, min_val, max_val) + self.assertEqual(Y_expected, Y_actual) + + # Test op-out behavior (out does not exist for method versions) + if op in (torch.clamp, torch.clip): + Y_out = torch.empty_like(X) + op(X, min=min_val, max=max_val, out=Y_out) + self.assertEqual(Y_expected, Y_out) + + def test_clamp_propagates_nans(self, device): + op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_, + torch.clip, torch.Tensor.clip, torch.Tensor.clip_) + + # min/max argument product + args = product((-10, None), (10, None)) + + for op in op_list: + for min_val, max_val in args: + if min_val is None and max_val is None: + continue + + X, Y_expected = self.generate_clamp_baseline(device, torch.float, + min_vals=min_val, + max_vals=max_val, + with_nans=True) + Y_expected = torch.isnan(Y_expected) + + # Test op + X1 = X.clone() # So that the in-place ops do not change X + Y_actual = op(X1, min_val, max_val) + self.assertEqual(Y_expected, torch.isnan(Y_actual)) + + # Test op-out behavior (out does not exist for method versions) + if op in (torch.clamp, torch.clip): + Y_out = torch.empty_like(X) + op(X, min_val, max_val, out=Y_out) + self.assertEqual(Y_expected, torch.isnan(Y_out)) + + def test_clamp_raises_arg_errors(self, device): + X = torch.randn(100, dtype=torch.float, device=device) + error_msg = 'At least one of \'min\' or \'max\' must not be None' + with self.assertRaisesRegex(RuntimeError, error_msg): + X.clamp() + with self.assertRaisesRegex(RuntimeError, error_msg): + X.clamp_() + with self.assertRaisesRegex(RuntimeError, error_msg): + torch.clamp(X) + + def test_flip(self, device): + data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2) + + self.assertEqual(torch.tensor([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2), data.flip(0)) + self.assertEqual(torch.tensor([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2), data.flip(1)) + self.assertEqual(torch.tensor([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2), data.flip(2)) + self.assertEqual(torch.tensor([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2), data.flip(0, 1)) + self.assertEqual(torch.tensor([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2), data.flip(0, 1, 2)) + + # check for wrap dim + self.assertEqual(torch.tensor([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2), data.flip(-1)) + # check for permute + self.assertEqual(torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(0, 2)) + self.assertEqual(torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(2, 0)) + + # not allow flip on the same dim more than once + self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1)) + # not allow empty list as input + self.assertRaises(TypeError, lambda: data.flip()) + + # not allow size of flip dim > total dims + self.assertRaises(IndexError, lambda: data.flip(0, 1, 2, 3)) + # not allow dim > max dim + self.assertRaises(IndexError, lambda: data.flip(3)) + + # test for non-contiguous case + expanded_data = torch.arange(1, 4, device=device).view(3, 1).expand(3, 2) + transposed_data = torch.arange(1, 9, device=device).view(2, 2, 2).transpose(0, 1) + self.assertEqual(torch.tensor([3, 3, 2, 2, 1, 1]).view(3, 2), expanded_data.flip(0)) + self.assertEqual(torch.tensor([8, 7, 4, 3, 6, 5, 2, 1]).view(2, 2, 2), transposed_data.flip(0, 1, 2)) + + # test for shape + data = torch.randn(2, 3, 4, device=device) + size = [2, 3, 4] + test_dims = [] + for i in range(1, 3): + test_dims += combinations(range(len(size)), i) + + for ds in test_dims: + self.assertEqual(size, list(data.flip(ds).size())) + + # test rectangular case + data = torch.tensor([1, 2, 3, 4, 5, 6], device=device).view(2, 3) + flip0_result = torch.tensor([[4, 5, 6], [1, 2, 3]], device=device) + flip1_result = torch.tensor([[3, 2, 1], [6, 5, 4]], device=device) + + self.assertEqual(flip0_result, data.flip(0)) + self.assertEqual(flip1_result, data.flip(1)) + + # test empty tensor, should just return an empty tensor of the same shape + data = torch.tensor((), device=device) + self.assertEqual(data, data.flip(0)) + + # test bool tensor + a = torch.tensor([False, True], device=device) + self.assertEqual(a.flip(0), torch.tensor([True, False])) + + # case: dims=() + a = torch.randn(3, 2, 1, device=device) + self.assertEqual(a.flip(dims=()), a) + + def _rand_shape(self, dim, min_size, max_size): + shape = [] + for i in range(dim): + shape.append(random.randint(min_size, max_size)) + return tuple(shape) + + @dtypes(torch.cfloat, torch.cdouble) + def test_complex_flip(self, device, dtype): + rand_dim = random.randint(3, 4) + shape = self._rand_shape(rand_dim, 5, 10) + + # Axis to sample for given shape. + for i in range(1, rand_dim): + # Check all combinations of `i` axis. + for flip_dim in combinations(range(rand_dim), i): + data = torch.randn(*shape, device=device, dtype=dtype) + torch_fn = partial(torch.flip, dims=flip_dim) + np_fn = partial(np.flip, axis=flip_dim) + self.compare_with_numpy(torch_fn, np_fn, data) + + def _test_fliplr_flipud(self, torch_fn, np_fn, min_dim, max_dim, device, dtype): + for dim in range(min_dim, max_dim + 1): + shape = self._rand_shape(dim, 5, 10) + # Randomly scale the input + if dtype.is_floating_point or dtype.is_complex: + data = torch.randn(*shape, device=device, dtype=dtype) + else: + data = torch.randint(0, 10, shape, device=device, dtype=dtype) + self.compare_with_numpy(torch_fn, np_fn, data) + + @dtypes(torch.int64, torch.double, torch.cdouble) + def test_fliplr(self, device, dtype): + self._test_fliplr_flipud(torch.fliplr, np.fliplr, 2, 4, device, dtype) + + @dtypes(torch.int64, torch.double, torch.cdouble) + def test_fliplr_invalid(self, device, dtype): + x = torch.randn(42).to(dtype) + with self.assertRaisesRegex(RuntimeError, "Input must be >= 2-d."): + torch.fliplr(x) + with self.assertRaisesRegex(RuntimeError, "Input must be >= 2-d."): + torch.fliplr(torch.tensor(42, device=device, dtype=dtype)) + + @dtypes(torch.int64, torch.double, torch.cdouble) + def test_flipud(self, device, dtype): + self._test_fliplr_flipud(torch.flipud, np.flipud, 1, 4, device, dtype) + + @dtypes(torch.int64, torch.double, torch.cdouble) + def test_flipud_invalid(self, device, dtype): + with self.assertRaisesRegex(RuntimeError, "Input must be >= 1-d."): + torch.flipud(torch.tensor(42, device=device, dtype=dtype)) + + def test_rot90(self, device): + data = torch.arange(1, 5, device=device).view(2, 2) + self.assertEqual(torch.tensor([1, 2, 3, 4]).view(2, 2), data.rot90(0, [0, 1])) + self.assertEqual(torch.tensor([2, 4, 1, 3]).view(2, 2), data.rot90(1, [0, 1])) + self.assertEqual(torch.tensor([4, 3, 2, 1]).view(2, 2), data.rot90(2, [0, 1])) + self.assertEqual(torch.tensor([3, 1, 4, 2]).view(2, 2), data.rot90(3, [0, 1])) + + # test for default args k=1, dims=[0, 1] + self.assertEqual(data.rot90(), data.rot90(1, [0, 1])) + + # test for reversed order of dims + self.assertEqual(data.rot90(3, [0, 1]), data.rot90(1, [1, 0])) + + # test for modulo of k + self.assertEqual(data.rot90(5, [0, 1]), data.rot90(1, [0, 1])) + self.assertEqual(data.rot90(3, [0, 1]), data.rot90(-1, [0, 1])) + self.assertEqual(data.rot90(-5, [0, 1]), data.rot90(-1, [0, 1])) + + # test for dims out-of-range error + self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, -3])) + self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 2])) + + # test tensor with more than 2D + data = torch.arange(1, 9, device=device).view(2, 2, 2) + self.assertEqual(torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2])) + self.assertEqual(data.rot90(1, [1, -1]), data.rot90(1, [1, 2])) + + # test for errors + self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 3])) + self.assertRaises(RuntimeError, lambda: data.rot90(1, [1, 1])) + self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 1, 2])) + self.assertRaises(RuntimeError, lambda: data.rot90(1, [0])) + + @dtypes(torch.cfloat, torch.cdouble) + def test_complex_rot90(self, device, dtype): + shape = self._rand_shape(random.randint(2, 4), 5, 10) + for rot_times in range(4): + data = torch.randn(*shape, device=device, dtype=dtype) + torch_fn = partial(torch.rot90, k=rot_times, dims=[0, 1]) + np_fn = partial(np.rot90, k=rot_times, axes=[0, 1]) + self.compare_with_numpy(torch_fn, np_fn, data) + + @dtypes(*torch.testing.get_all_dtypes(include_complex=False)) + def test_nonzero(self, device, dtype): + + shapes = [ + torch.Size((12,)), + torch.Size((12, 1)), + torch.Size((1, 12)), + torch.Size((6, 2)), + torch.Size((3, 2, 2)), + torch.Size((5, 5, 5)), + ] + + def gen_nontrivial_input(shape, dtype, device): + if dtype != torch.bfloat16: + return torch.randint(2, shape, device=device, dtype=dtype) + else: + # windows does not work for bfloat16 randing + return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype) + + for shape in shapes: + tensor = gen_nontrivial_input(shape, dtype, device) + dst1 = torch.nonzero(tensor, as_tuple=False) + dst2 = tensor.nonzero(as_tuple=False) + dst3 = torch.empty([], dtype=torch.long, device=device) + torch.nonzero(tensor, out=dst3) + if self.device_type != 'xla': + # xla does not raise runtime error + self.assertRaisesRegex( + RuntimeError, + "scalar type Long", + lambda: torch.nonzero(tensor, out=torch.empty([], dtype=torch.float)) + ) + if self.device_type == 'cuda': + self.assertRaisesRegex( + RuntimeError, + "on the same device", + lambda: torch.nonzero(tensor, out=torch.empty([], dtype=torch.long)) + ) + np_array = tensor.cpu().numpy() if dtype != torch.bfloat16 else tensor.float().cpu().numpy() + np_result = torch.from_numpy(np.stack(np_array.nonzero())).t() + self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0) + self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0) + self.assertEqual(dst3.cpu(), np_result, atol=0, rtol=0) + tup1 = torch.nonzero(tensor, as_tuple=True) + tup2 = tensor.nonzero(as_tuple=True) + tup1 = torch.stack(tup1).t().cpu() + tup2 = torch.stack(tup2).t().cpu() + self.assertEqual(tup1, np_result, atol=0, rtol=0) + self.assertEqual(tup2, np_result, atol=0, rtol=0) + + def test_nonzero_astuple_out(self, device): + t = torch.randn((3, 3, 3), device=device) + out = torch.empty_like(t, dtype=torch.long) + + with self.assertRaises(RuntimeError): + torch.nonzero(t, as_tuple=True, out=out) + + self.assertEqual(torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out)) + + # Verifies that JIT script cannot handle the as_tuple kwarg + # See Issue https://github.com/pytorch/pytorch/issues/45499. + def _foo(t): + tuple_result = torch.nonzero(t, as_tuple=True) + nontuple_result = torch.nonzero(t, as_tuple=False) + out = torch.empty_like(nontuple_result) + torch.nonzero(t, as_tuple=False, out=out) + return tuple_result, nontuple_result, out + + with self.assertRaises(RuntimeError): + scripted_foo = torch.jit.script(_foo) + + # Verifies that JIT tracing works fine + traced_foo = torch.jit.trace(_foo, t) + traced_tuple, traced_nontuple, traced_out = traced_foo(t) + expected_tuple = torch.nonzero(t, as_tuple=True) + expected_nontuple = torch.nonzero(t) + + self.assertEqual(traced_tuple, expected_tuple) + self.assertEqual(traced_nontuple, expected_nontuple) + self.assertEqual(traced_out, expected_nontuple) + + @onlyOnCPUAndCUDA + def test_nonzero_discontiguous(self, device): + shape = (4, 4) + tensor = torch.randint(2, shape, device=device) + tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(tensor) + dst1 = tensor.nonzero(as_tuple=False) + dst2 = tensor_nc.nonzero(as_tuple=False) + self.assertEqual(dst1, dst2, atol=0, rtol=0) + dst3 = torch.empty_like(dst1) + data_ptr = dst3.data_ptr() + # expect dst3 storage to be reused + torch.nonzero(tensor, out=dst3) + self.assertEqual(data_ptr, dst3.data_ptr()) + self.assertEqual(dst1, dst3, atol=0, rtol=0) + # discontiguous out + dst4 = torch.empty(dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device)[:, ::2] + data_ptr = dst4.data_ptr() + strides = dst4.stride() + torch.nonzero(tensor, out=dst4) + self.assertEqual(data_ptr, dst4.data_ptr()) + self.assertEqual(dst1, dst4, atol=0, rtol=0) + self.assertEqual(strides, dst4.stride()) + + def test_nonzero_non_diff(self, device): + x = torch.randn(10, requires_grad=True) + nz = x.nonzero() + self.assertFalse(nz.requires_grad) + +class TestShapeFuncs(TestCase): + """Test suite for Shape manipulating operators using the ShapeFuncInfo.""" + + @dtypes(*(torch.uint8, torch.int64, torch.double, torch.complex128)) + @ops([op for op in shape_funcs if op.name in ['tile', 'repeat']]) + def test_repeat_tile_vs_numpy(self, device, dtype, op): + samples = op.sample_inputs(device, dtype, requires_grad=False) + for sample in samples: + (t, dims) = sample.input + expected = op.ref(t.cpu().numpy(), dims, **sample.kwargs) + result = op(t, dims, **sample.kwargs).cpu().numpy() + self.assertEqual(expected, result) + +instantiate_device_type_tests(TestShapeOps, globals()) +instantiate_device_type_tests(TestShapeFuncs, globals()) + +if __name__ == '__main__': + run_tests() diff --git a/test/test_show_pickle.py b/test/test_show_pickle.py index f7617bc089101..79a558cbe51cc 100644 --- a/test/test_show_pickle.py +++ b/test/test_show_pickle.py @@ -4,9 +4,9 @@ import torch import torch.utils.show_pickle -from torch.testing._internal.common_utils import IS_WINDOWS +from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS -class TestShowPickle(unittest.TestCase): +class TestShowPickle(TestCase): @unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows") def test_scripted_model(self): @@ -31,4 +31,4 @@ def forward(self, x): if __name__ == '__main__': - unittest.main() + run_tests() diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py new file mode 100644 index 0000000000000..620c77d8b8be0 --- /dev/null +++ b/test/test_sort_and_select.py @@ -0,0 +1,699 @@ +import torch +import numpy as np + +import random +from torch._six import nan +from itertools import product + +from torch.testing._internal.common_utils import \ + (TestCase, run_tests, make_tensor) +from torch.testing._internal.common_device_type import \ + (instantiate_device_type_tests, dtypes, onlyOnCPUAndCUDA, + skipCUDAIfRocm, onlyCUDA, dtypesIfCUDA, onlyCPU) + +# TODO: remove this +SIZE = 100 + +class TestSortAndSelect(TestCase): + + def assertIsOrdered(self, order, x, mxx, ixx, task): + SIZE = 4 + if order == 'descending': + def check_order(a, b): + # `a != a` because we put NaNs + # at the end of ascending sorted lists, + # and the beginning of descending ones. + return a != a or a >= b + elif order == 'ascending': + def check_order(a, b): + # see above + return b != b or a <= b + else: + error('unknown order "{}", must be "ascending" or "descending"'.format(order)) + + are_ordered = True + for j, k in product(range(SIZE), range(1, SIZE)): + self.assertTrue(check_order(mxx[j][k - 1], mxx[j][k]), + 'torch.sort ({}) values unordered for {}'.format(order, task)) + + seen = set() + indicesCorrect = True + size = x.size(x.dim() - 1) + for k in range(size): + seen.clear() + for j in range(size): + self.assertEqual(x[k][ixx[k][j]], mxx[k][j], + msg='torch.sort ({}) indices wrong for {}'.format(order, task)) + seen.add(ixx[k][j]) + self.assertEqual(len(seen), size) + + def test_sort(self, device): + SIZE = 4 + x = torch.rand(SIZE, SIZE, device=device) + res1val, res1ind = torch.sort(x) + + # Test use of result tensor + res2val = torch.tensor((), device=device) + res2ind = torch.tensor((), device=device, dtype=torch.long) + torch.sort(x, out=(res2val, res2ind)) + self.assertEqual(res1val, res2val, atol=0, rtol=0) + self.assertEqual(res1ind, res2ind, atol=0, rtol=0) + self.assertEqual(torch.argsort(x), res1ind) + self.assertEqual(x.argsort(), res1ind) + + # Test sorting of random numbers + self.assertIsOrdered('ascending', x, res2val, res2ind, 'random') + + # Test simple sort + self.assertEqual( + torch.sort(torch.tensor((50, 40, 30, 20, 10), device=device))[0], + torch.tensor((10, 20, 30, 40, 50), device=device), + atol=0, rtol=0 + ) + + # Test that we still have proper sorting with duplicate keys + x = torch.floor(torch.rand(SIZE, SIZE, device=device) * 10) + torch.sort(x, out=(res2val, res2ind)) + self.assertIsOrdered('ascending', x, res2val, res2ind, 'random with duplicate keys') + + # DESCENDING SORT + x = torch.rand(SIZE, SIZE, device=device) + res1val, res1ind = torch.sort(x, x.dim() - 1, True) + + # Test use of result tensor + res2val = torch.tensor((), device=device) + res2ind = torch.tensor((), device=device, dtype=torch.long) + torch.sort(x, x.dim() - 1, True, out=(res2val, res2ind)) + self.assertEqual(res1val, res2val, atol=0, rtol=0) + self.assertEqual(res1ind, res2ind, atol=0, rtol=0) + self.assertEqual(torch.argsort(x, x.dim() - 1, True), res1ind) + self.assertEqual(x.argsort(x.dim() - 1, True), res1ind) + + # Test sorting of random numbers + self.assertIsOrdered('descending', x, res2val, res2ind, 'random') + + # Test simple sort task + self.assertEqual( + torch.sort(torch.tensor((10, 20, 30, 40, 50), device=device), 0, True)[0], + torch.tensor((50, 40, 30, 20, 10), device=device), + atol=0, rtol=0 + ) + + # Test that we still have proper sorting with duplicate keys + self.assertIsOrdered('descending', x, res2val, res2ind, 'random with duplicate keys') + + # Test sorting with NaNs + x = torch.rand(SIZE, SIZE, device=device) + x[1][2] = float('NaN') + x[3][0] = float('NaN') + torch.sort(x, out=(res2val, res2ind)) + self.assertIsOrdered('ascending', x, res2val, res2ind, + 'random with NaNs') + torch.sort(x, out=(res2val, res2ind), descending=True) + self.assertIsOrdered('descending', x, res2val, res2ind, + 'random with NaNs') + + @onlyCPU + def test_stable_sort(self, device): + # no stable sort for CUDA yet + for dtype in ( + torch.float, torch.double, + torch.int8, torch.int16, torch.int32, + torch.bool + ): + for ncopies in (100, 1000, 10000): + x = torch.tensor([0, 1] * ncopies, dtype=dtype, device=torch.device(device)) + _, idx = x.sort(stable=True) + self.assertEqual( + idx[:ncopies], + torch.arange(start=0, end=2 * ncopies, step=2, device=torch.device(device)) + ) + self.assertEqual( + idx[ncopies:], + torch.arange(start=1, end=2 * ncopies, step=2, device=torch.device(device)) + ) + + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False))) + def test_msort(self, device, dtype): + def test(shape): + tensor = make_tensor(shape, device, dtype, low=-9, high=9) + if tensor.size() != torch.Size([]): + expected = torch.from_numpy(np.msort(tensor.cpu().numpy())) + else: + expected = tensor # numpy.msort() does not support empty shapes tensor + + result = torch.msort(tensor) + self.assertEqual(result, expected) + + out = torch.empty_like(result) + torch.msort(tensor, out=out) + self.assertEqual(out, expected) + + shapes = ( + [], + [0, ], + [20, ], + [1, 20], + [30, 30], + [10, 20, 30] + ) + for shape in shapes: + test(shape) + + def test_topk(self, device): + def topKViaSort(t, k, dim, dir): + sorted, indices = t.sort(dim, dir) + return sorted.narrow(dim, 0, k), indices.narrow(dim, 0, k) + + def compareTensors(t, res1, ind1, res2, ind2, dim): + # Values should be exactly equivalent + self.assertEqual(res1, res2, atol=0, rtol=0) + + # Indices might differ based on the implementation, since there is + # no guarantee of the relative order of selection + if not ind1.eq(ind2).all(): + # To verify that the indices represent equivalent elements, + # gather from the input using the topk indices and compare against + # the sort indices + vals = t.gather(dim, ind2) + self.assertEqual(res1, vals, atol=0, rtol=0) + + def compare(t, k, dim, dir): + topKVal, topKInd = t.topk(k, dim, dir, True) + sortKVal, sortKInd = topKViaSort(t, k, dim, dir) + compareTensors(t, sortKVal, sortKInd, topKVal, topKInd, dim) + + t = torch.rand(random.randint(1, SIZE), + random.randint(1, SIZE), + random.randint(1, SIZE), device=device) + + for _kTries in range(3): + for _dimTries in range(3): + for transpose in (True, False): + for dir in (True, False): + testTensor = t + if transpose: + dim1 = random.randrange(t.ndimension()) + dim2 = dim1 + while dim1 == dim2: + dim2 = random.randrange(t.ndimension()) + + testTensor = t.transpose(dim1, dim2) + + dim = random.randrange(testTensor.ndimension()) + k = random.randint(1, testTensor.size(dim)) + compare(testTensor, k, dim, dir) + + def test_topk_arguments(self, device): + q = torch.randn(10, 2, 10, device=device) + # Make sure True isn't mistakenly taken as the 2nd dimension (interpreted as 1) + self.assertRaises(TypeError, lambda: q.topk(4, True)) + + @skipCUDAIfRocm + def test_unique_dim(self, device): + self.assertFalse(hasattr(torch, 'unique_dim')) + + def run_test(device, dtype): + x = torch.tensor([[[1., 1.], + [0., 1.], + [2., 1.], + [0., 1.]], + [[1., 1.], + [0., 1.], + [2., 1.], + [0., 1.]]], + dtype=dtype, + device=device) + x_empty = torch.empty(5, 0, dtype=dtype, device=device) + x_ill_formed_empty = torch.empty(5, 0, 0, dtype=dtype, device=device) + x_ill_formed_empty_another = torch.empty(5, 0, 5, dtype=dtype, device=device) + expected_unique_dim0 = torch.tensor([[[1., 1.], + [0., 1.], + [2., 1.], + [0., 1.]]], + dtype=dtype, + device=device) + expected_inverse_dim0 = torch.tensor([0, 0]) + expected_counts_dim0 = torch.tensor([2]) + expected_unique_dim1 = torch.tensor([[[0., 1.], + [1., 1.], + [2., 1.]], + [[0., 1.], + [1., 1.], + [2., 1.]]], + dtype=dtype, + device=device) + expected_unique_dim1_bool = torch.tensor([[[False, True], [True, True]], + [[False, True], [True, True]]], + dtype=torch.bool, + device=device) + expected_inverse_dim1 = torch.tensor([1, 0, 2, 0]) + expected_inverse_dim1_bool = torch.tensor([1, 0, 1, 0]) + expected_counts_dim1 = torch.tensor([2, 1, 1]) + expected_counts_dim1_bool = torch.tensor([2, 2]) + expected_unique_dim2 = torch.tensor([[[1., 1.], + [0., 1.], + [2., 1.], + [0., 1.]], + [[1., 1.], + [0., 1.], + [2., 1.], + [0., 1.]]], + dtype=dtype, + device=device) + expected_inverse_dim2 = torch.tensor([0, 1]) + expected_counts_dim2 = torch.tensor([1, 1]) + expected_unique_empty = torch.tensor([], dtype=dtype, device=device) + expected_inverse_empty = torch.tensor([], dtype=torch.long, device=device) + expected_counts_empty = torch.tensor([], dtype=torch.long, device=device) + # dim0 + x_unique = torch.unique(x, dim=0) + self.assertEqual(expected_unique_dim0, x_unique) + + x_unique, x_inverse = torch.unique( + x, + return_inverse=True, + dim=0) + self.assertEqual(expected_unique_dim0, x_unique) + self.assertEqual(expected_inverse_dim0, x_inverse) + + x_unique, x_counts = torch.unique( + x, + return_inverse=False, + return_counts=True, + dim=0) + self.assertEqual(expected_unique_dim0, x_unique) + self.assertEqual(expected_counts_dim0, x_counts) + + x_unique, x_inverse, x_counts = torch.unique( + x, + return_inverse=True, + return_counts=True, + dim=0) + self.assertEqual(expected_unique_dim0, x_unique) + self.assertEqual(expected_inverse_dim0, x_inverse) + self.assertEqual(expected_counts_dim0, x_counts) + + # dim1 + x_unique = torch.unique(x, dim=1) + if x.dtype == torch.bool: + self.assertEqual(expected_unique_dim1_bool, x_unique) + else: + self.assertEqual(expected_unique_dim1, x_unique) + + x_unique, x_inverse = torch.unique( + x, + return_inverse=True, + dim=1) + if x.dtype == torch.bool: + self.assertEqual(expected_unique_dim1_bool, x_unique) + self.assertEqual(expected_inverse_dim1_bool, x_inverse) + else: + self.assertEqual(expected_unique_dim1, x_unique) + self.assertEqual(expected_inverse_dim1, x_inverse) + + x_unique, x_counts = torch.unique( + x, + return_inverse=False, + return_counts=True, + dim=1) + if x.dtype == torch.bool: + self.assertEqual(expected_unique_dim1_bool, x_unique) + self.assertEqual(expected_counts_dim1_bool, x_counts) + else: + self.assertEqual(expected_unique_dim1, x_unique) + self.assertEqual(expected_counts_dim1, x_counts) + + x_unique, x_inverse, x_counts = torch.unique( + x, + return_inverse=True, + return_counts=True, + dim=1) + if x.dtype == torch.bool: + self.assertEqual(expected_unique_dim1_bool, x_unique) + self.assertEqual(expected_inverse_dim1_bool, x_inverse) + self.assertEqual(expected_counts_dim1_bool, x_counts) + else: + self.assertEqual(expected_unique_dim1, x_unique) + self.assertEqual(expected_inverse_dim1, x_inverse) + self.assertEqual(expected_counts_dim1, x_counts) + + # dim2 + x_unique = torch.unique(x, dim=2) + self.assertEqual(expected_unique_dim2, x_unique) + + x_unique, x_inverse = torch.unique( + x, + return_inverse=True, + dim=2) + self.assertEqual(expected_unique_dim2, x_unique) + self.assertEqual(expected_inverse_dim2, x_inverse) + + x_unique, x_counts = torch.unique( + x, + return_inverse=False, + return_counts=True, + dim=2) + self.assertEqual(expected_unique_dim2, x_unique) + self.assertEqual(expected_counts_dim2, x_counts) + + x_unique, x_inverse, x_counts = torch.unique( + x, + return_inverse=True, + return_counts=True, + dim=2) + self.assertEqual(expected_unique_dim2, x_unique) + self.assertEqual(expected_inverse_dim2, x_inverse) + self.assertEqual(expected_counts_dim2, x_counts) + + # test empty tensor + x_unique, x_inverse, x_counts = torch.unique( + x_empty, + return_inverse=True, + return_counts=True, + dim=1) + self.assertEqual(expected_unique_empty, x_unique) + self.assertEqual(expected_inverse_empty, x_inverse) + self.assertEqual(expected_counts_empty, x_counts) + + # test not a well formed tensor + # Checking for runtime error, as this is the expected behaviour + with self.assertRaises(RuntimeError): + torch.unique( + x_ill_formed_empty, + return_inverse=True, + return_counts=True, + dim=1) + + # test along dim2 + with self.assertRaises(RuntimeError): + torch.unique( + x_ill_formed_empty_another, + return_inverse=True, + return_counts=True, + dim=2) + + # test consecutive version + y = torch.tensor( + [[0, 1], + [0, 1], + [0, 1], + [1, 2], + [1, 2], + [3, 4], + [0, 1], + [0, 1], + [3, 4], + [1, 2]], + dtype=dtype, + device=device + ) + expected_y_unique = torch.tensor( + [[0, 1], + [1, 2], + [3, 4], + [0, 1], + [3, 4], + [1, 2]], + dtype=dtype, + device=device + ) + expected_y_inverse = torch.tensor([0, 0, 0, 1, 1, 2, 3, 3, 4, 5], dtype=torch.int64, device=device) + expected_y_counts = torch.tensor([3, 2, 1, 2, 1, 1], dtype=torch.int64, device=device) + expected_y_inverse_bool = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 3, 3], dtype=torch.int64, device=device) + expected_y_counts_bool = torch.tensor([3, 3, 2, 2], dtype=torch.int64, device=device) + y_unique, y_inverse, y_counts = torch.unique_consecutive(y, return_inverse=True, return_counts=True, dim=0) + if x.dtype == torch.bool: + self.assertEqual(expected_y_inverse_bool, y_inverse) + self.assertEqual(expected_y_counts_bool, y_counts) + else: + self.assertEqual(expected_y_inverse, y_inverse) + self.assertEqual(expected_y_counts, y_counts) + + run_test(device, torch.float) + run_test(device, torch.double) + run_test(device, torch.long) + run_test(device, torch.uint8) + run_test(device, torch.bool) + + @onlyCUDA + def test_topk_noncontiguous_gpu(self, device): + t = torch.randn(20, device=device)[::2] + top1, idx1 = t.topk(5) + top2, idx2 = t.contiguous().topk(5) + self.assertEqual(top1, top2) + self.assertEqual(idx1, idx2) + + @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64) + def test_topk_integral(self, device, dtype): + a = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max, size=(10,), + dtype=dtype, device=device) + sort_topk = a.sort()[0][-5:].flip(0) + topk = a.topk(5) + self.assertEqual(sort_topk, topk[0]) # check values + self.assertEqual(sort_topk, a[topk[1]]) # check indices + + @dtypesIfCUDA(*torch.testing.get_all_fp_dtypes()) + @dtypes(torch.float, torch.double) + def test_topk_nonfinite(self, device, dtype): + x = torch.tensor([float('nan'), float('inf'), 1e4, 0, -1e4, -float('inf')], device=device, dtype=dtype) + val, idx = x.topk(4) + expect = torch.tensor([float('nan'), float('inf'), 1e4, 0], device=device, dtype=dtype) + self.assertEqual(val, expect) + self.assertEqual(idx, [0, 1, 2, 3]) + + val, idx = x.topk(4, largest=False) + expect = torch.tensor([-float('inf'), -1e4, 0, 1e4], device=device, dtype=dtype) + self.assertEqual(val, expect) + self.assertEqual(idx, [5, 4, 3, 2]) + + def test_topk_4d(self, device): + x = torch.ones(2, 3072, 2, 2, device=device) + x[:, 1, :, :] *= 2. + x[:, 10, :, :] *= 1.5 + val, ind = torch.topk(x, k=2, dim=1) + expected_ind = torch.ones(2, 2, 2, 2, dtype=torch.long, device=device) + expected_ind[:, 1, :, :] = 10 + expected_val = torch.ones(2, 2, 2, 2, device=device) + expected_val[:, 0, :, :] *= 2. + expected_val[:, 1, :, :] *= 1.5 + self.assertEqual(val, expected_val, atol=0, rtol=0) + self.assertEqual(ind, expected_ind, atol=0, rtol=0) + + def _test_unique_scalar_empty(self, dtype, device, f): + # test scalar + x = torch.tensor(0, dtype=dtype, device=device) + unique, inverse, counts = f(x, return_inverse=True, return_counts=True) + expected_unique = torch.tensor([0], dtype=dtype, device=device) + expected_inverse = torch.tensor(0, device=device) + expected_counts = torch.tensor([1], device=device) + self.assertEqual(unique, expected_unique) + self.assertEqual(inverse, expected_inverse) + self.assertEqual(counts, expected_counts) + + # test zero sized tensor + x = torch.zeros((0, 0, 3), dtype=dtype, device=device) + unique, inverse, counts = f(x, return_inverse=True, return_counts=True) + expected_unique = torch.tensor([], dtype=dtype, device=device) + expected_inverse = torch.empty((0, 0, 3), dtype=torch.long, device=device) + expected_counts = torch.tensor([], dtype=torch.long, device=device) + self.assertEqual(unique, expected_unique) + self.assertEqual(inverse, expected_inverse) + self.assertEqual(counts, expected_counts) + + def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape): + def ensure_tuple(x): + if isinstance(x, torch.Tensor): + return (x,) + return x + + for return_inverse in [True, False]: + for return_counts in [True, False]: + # test with expected + ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts)) + self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts)) + self.assertEqual(expected_unique, ret[0]) + if return_inverse: + self.assertEqual(expected_inverse, ret[1]) + if return_counts: + count_index = 1 + int(return_inverse) + self.assertEqual(expected_counts, ret[count_index]) + + # tests per-element unique on a higher rank tensor. + y = x.view(additional_shape) + y_unique, y_inverse, y_counts = f(y, return_inverse=True, return_counts=True) + self.assertEqual(expected_unique, y_unique) + self.assertEqual(expected_inverse.view(additional_shape), y_inverse) + self.assertEqual(expected_counts, y_counts) + + @dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128}) + def test_unique(self, device, dtype): + if dtype is torch.half and self.device_type == 'cpu': + return # CPU does not have half support + + def ensure_tuple(x): + if isinstance(x, torch.Tensor): + return (x,) + return x + + if dtype is torch.bool: + x = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device) + expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device) + expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device) + expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device) + else: + x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device) + expected_unique = torch.tensor([1, 2, 3, 5, 8], dtype=dtype, device=device) + expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device) + expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device) + + # test sorted unique + fs = [ + lambda x, **kwargs: torch.unique(x, sorted=True, **kwargs), + lambda x, **kwargs: x.unique(sorted=True, **kwargs), + ] + for f in fs: + self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (2, 2, 2)) + self._test_unique_scalar_empty(dtype, device, f) + + # test unsorted unique + fs = [ + lambda x, **kwargs: torch.unique(x, sorted=False, **kwargs), + lambda x, **kwargs: x.unique(sorted=False, **kwargs) + ] + for f in fs: + self._test_unique_scalar_empty(dtype, device, f) + for return_inverse in [True, False]: + for return_counts in [True, False]: + ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts)) + self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts)) + x_list = x.tolist() + x_unique_list = ret[0].tolist() + self.assertEqual(expected_unique.tolist(), sorted(x_unique_list)) + if return_inverse: + x_inverse_list = ret[1].tolist() + for i, j in enumerate(x_inverse_list): + self.assertEqual(x_list[i], x_unique_list[j]) + if return_counts: + count_index = 1 + int(return_inverse) + x_counts_list = ret[count_index].tolist() + for i, j in zip(x_unique_list, x_counts_list): + count = 0 + for k in x_list: + if k == i: + count += 1 + self.assertEqual(j, count) + + @dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128}) + def test_unique_consecutive(self, device, dtype): + if dtype is torch.half and self.device_type == 'cpu': + return # CPU does not have half support + + if dtype is torch.bool: + x = torch.tensor([True, False, False, False, True, True, False, False, False], dtype=torch.bool, device=device) + expected_unique = torch.tensor([True, False, True, False], dtype=torch.bool, device=device) + expected_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 3], dtype=torch.long, device=device) + expected_counts = torch.tensor([1, 3, 2, 3], dtype=torch.long, device=device) + else: + x = torch.tensor([1, 2, 2, 2, 5, 5, 2, 2, 3], dtype=dtype, device=device) + expected_unique = torch.tensor([1, 2, 5, 2, 3], dtype=dtype, device=device) + expected_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4], device=device) + expected_counts = torch.tensor([1, 3, 2, 2, 1], device=device) + + for f in [torch.unique_consecutive, lambda x, **kwargs: x.unique_consecutive(**kwargs)]: + self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (3, 3)) + self._test_unique_scalar_empty(dtype, device, f) + + @dtypes(torch.double) + def test_kthvalue(self, device, dtype): + SIZE = 50 + x = torch.rand(SIZE, SIZE, SIZE, dtype=dtype, device=device) + x0 = x.clone() + + k = random.randint(1, SIZE) + res1val, res1ind = torch.kthvalue(x, k, keepdim=False) + res2val, res2ind = torch.sort(x) + + self.assertEqual(res1val[:, :], res2val[:, :, k - 1], atol=0, rtol=0) + self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0) + # test use of result tensors + k = random.randint(1, SIZE) + res1val = torch.tensor([], dtype=dtype, device=device) + res1ind = torch.tensor([], dtype=torch.long, device=device) + torch.kthvalue(x, k, keepdim=False, out=(res1val, res1ind)) + res2val, res2ind = torch.sort(x) + self.assertEqual(res1val[:, :], res2val[:, :, k - 1], atol=0, rtol=0) + self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0) + + # test non-default dim + k = random.randint(1, SIZE) + res1val, res1ind = torch.kthvalue(x, k, 0, keepdim=False) + res2val, res2ind = torch.sort(x, 0) + self.assertEqual(res1val, res2val[k - 1], atol=0, rtol=0) + self.assertEqual(res1ind, res2ind[k - 1], atol=0, rtol=0) + + # non-contiguous + y = x.narrow(1, 0, 1) + y0 = y.contiguous() + k = random.randint(1, SIZE) + res1val, res1ind = torch.kthvalue(y, k) + res2val, res2ind = torch.kthvalue(y0, k) + self.assertEqual(res1val, res2val, atol=0, rtol=0) + self.assertEqual(res1ind, res2ind, atol=0, rtol=0) + + # non-contiguous [Reference: https://github.com/pytorch/pytorch/issues/45721] + non_contig_t = torch.tensor([0, -1, 1, -2, 2], dtype=dtype, device=device)[::2] + expected_val, expected_ind = non_contig_t.contiguous().kthvalue(2) + non_contig_cpu_t = non_contig_t.cpu() + expected_val_cpu, expected_ind_cpu = non_contig_cpu_t.kthvalue(2) + + out_val, out_ind = non_contig_t.kthvalue(2) + self.assertEqual(expected_val, out_val, atol=0, rtol=0) + self.assertEqual(expected_ind, out_ind, atol=0, rtol=0) + self.assertEqual(expected_val_cpu, out_val, atol=0, rtol=0) + self.assertEqual(expected_ind_cpu, out_ind, atol=0, rtol=0) + + # check that the input wasn't modified + self.assertEqual(x, x0, atol=0, rtol=0) + + # simple test case (with repetitions) + y = torch.tensor((3., 5, 4, 1, 1, 5), dtype=dtype, device=device) + self.assertEqual(torch.kthvalue(y, 3)[0], 3, atol=0, rtol=0) + self.assertEqual(torch.kthvalue(y, 2)[0], 1, atol=0, rtol=0) + + # simple test case (with NaN) + SIZE = 50 + x = torch.rand(SIZE, SIZE, SIZE, dtype=dtype, device=device) + x[torch.arange(SIZE), :, torch.randint(50, (50,))] = nan + ks = [random.randint(1, SIZE), 1, SIZE, SIZE - 1] + res2val, res2ind = torch.sort(x) + for k in ks: + res1val, res1ind = torch.kthvalue(x, k, keepdim=False) + self.assertEqual(res1val[:, :], res2val[:, :, k - 1], atol=0, rtol=0) + self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0) + + # test overlapping output + @dtypes(torch.double) + @onlyOnCPUAndCUDA # Fails on XLA + def test_kthvalue_overlap(self, device, dtype): + S = 10 + k = 5 + a = torch.randn(S) + indices = torch.empty((), device=device, dtype=torch.long) + with self.assertRaisesRegex(RuntimeError, "unsupported operation:"): + torch.kthvalue(a, k, out=(a, indices)) + + @dtypes(torch.float) + @onlyOnCPUAndCUDA # Fails on XLA + def test_kthvalue_scalar(self, device, dtype): + # Test scalar input (test case from https://github.com/pytorch/pytorch/issues/30818) + # Tests that passing a scalar tensor or 1D tensor with 1 element work either way + res = torch.tensor(2, device=device, dtype=dtype).kthvalue(1) + ref = torch.tensor([2], device=device, dtype=dtype).kthvalue(1) + self.assertEqual(res[0], ref[0].squeeze()) + self.assertEqual(res[1], ref[1].squeeze()) + +instantiate_device_type_tests(TestSortAndSelect, globals()) + +if __name__ == '__main__': + run_tests() diff --git a/test/test_sparse.py b/test/test_sparse.py index 2a0e76afe36a4..78e3e3de15993 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -8,12 +8,21 @@ import functools import operator import random +from collections import defaultdict import unittest from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \ - do_test_empty_full, load_tests, TEST_NUMPY, TEST_WITH_ROCM, IS_WINDOWS + do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version from numbers import Number from torch.autograd.gradcheck import gradcheck +from typing import Dict, Any +from torch.testing._internal.common_device_type import \ + (instantiate_device_type_tests, ops) +from torch.testing._internal.common_methods_invocations import \ + (sparse_unary_ufuncs) + +if TEST_SCIPY: + import scipy.sparse # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings @@ -161,7 +170,7 @@ def test_shape(sparse_dims, nnz, with_size): self.assertEqual(i, x._indices()) self.assertEqual(v, x._values()) self.assertEqual(x.ndimension(), len(with_size)) - self.assertEqual(self.safeCoalesce(x)._nnz(), nnz) + self.assertEqual(x.coalesce()._nnz(), nnz) self.assertEqual(list(x.size()), with_size) # Test .indices() and .values() @@ -183,7 +192,7 @@ def test_shape(sparse_dims, nnz, with_size): i = self.index_tensor([[9, 0, 0, 0, 8, 1, 1, 1, 2, 7, 2, 2, 3, 4, 6, 9]]) v = self.value_tensor([[idx**2, idx] for idx in range(i.size(1))]) x = self.sparse_tensor(i, v, torch.Size([10, 2])) - self.assertEqual(self.safeCoalesce(x)._nnz(), 9) + self.assertEqual(x.coalesce()._nnz(), 9) # Make sure we can access empty indices / values x = self.legacy_sparse_tensor() @@ -191,13 +200,50 @@ def test_shape(sparse_dims, nnz, with_size): self.assertEqual(x._values().numel(), 0) def test_coalesce(self): + + def _test_coalesce(x): + tc = t.coalesce() + self.assertEqual(tc.to_dense(), t.to_dense()) + self.assertTrue(tc.is_coalesced()) + # Our code below doesn't work when nnz is 0, because + # then it's a 0D tensor, not a 2D tensor. + if t._nnz() == 0: + self.assertEqual(t._indices(), tc._indices()) + self.assertEqual(t._values(), tc._values()) + return tc + + value_map: Dict[Any, Any] = {} + for idx, val in zip(t._indices().t(), t._values()): + idx_tup = tuple(idx.tolist()) + if idx_tup in value_map: + value_map[idx_tup] += val + else: + value_map[idx_tup] = val.clone() if isinstance(val, torch.Tensor) else val + + new_indices = sorted(list(value_map.keys())) + _new_values = [value_map[idx] for idx in new_indices] + if t._values().ndimension() < 2: + new_values = t._values().new(_new_values) + else: + new_values = torch.stack(_new_values) + + new_indices = t._indices().new(new_indices).t() + tg = t.new(new_indices, new_values, t.size()) + + self.assertEqual(tc._indices(), tg._indices()) + self.assertEqual(tc._values(), tg._values()) + + if t.is_coalesced(): + self.assertEqual(tc._indices(), t._indices()) + self.assertEqual(tc._values(), t._values()) + for empty_i, empty_v, empty_nnz in itertools.product([True, False], repeat=3): sparse_size = [] if empty_i else [2, 1] dense_size = [1, 0, 2] if empty_v else [1, 2] nnz = 0 if empty_nnz else 5 t, _, _ = self._gen_sparse(len(sparse_size), nnz, sparse_size + dense_size) - self.safeCoalesce(t) # this tests correctness + _test_coalesce(t) # this tests correctness def test_ctor_size_checks(self): indices = self.index_tensor([ @@ -310,6 +356,11 @@ def test_to_sparse(self): sp, _, _ = self._gen_sparse(2, 10, [3, 3, 3]) self.assertRaises(RuntimeError, lambda: sp.to_sparse()) + def test_sparse_bool(self): + a = self.value_tensor([True, False]).to(torch.bool) + b = a.to_sparse().to_dense() + self.assertEqual(a, b) + def test_scalar(self): # tensor with value a = self.sparse_tensor(self.index_tensor([]).unsqueeze(1), 12.3, []) @@ -399,7 +450,7 @@ def fn(x): def test_contig(self): def test_tensor(x, exp_i, exp_v): - x = self.safeCoalesce(x) + x = x.coalesce() self.assertEqual(exp_i, x._indices()) self.assertEqual(exp_v, x._values()) @@ -479,7 +530,7 @@ def test_tensor(x, exp_i, exp_v): def test_contig_hybrid(self): def test_tensor(x, exp_i, exp_v): - x = self.safeCoalesce(x) + x = x.coalesce() self.assertEqual(exp_i, x._indices()) self.assertEqual(exp_v, x._values()) @@ -624,7 +675,6 @@ def test_Sparse_to_Sparse_copy_(self): self.assertEqual(None, x1.grad) @unittest.skipIf(torch.cuda.device_count() < 2, "no multi-GPU") - @skipIfRocm def test_Sparse_to_Sparse_copy_multi_gpu(self): # This is for testing torch.copy_(SparseTensor, SparseTensor) across GPU devices sparse_dims = 3 @@ -1080,6 +1130,51 @@ def test_shape(di, dj, dk, nnz): test_shape(1000, 0, 100, 0) test_shape(1000, 100, 0, 0) + @cpu_only + def test_sspaddmm(self): + + def test_shape(di, dj, dk, nnz): + x = self._gen_sparse(2, nnz, [di, dj])[0] + t = self._gen_sparse(2, nnz, [di, dk])[0] + y = torch.randn(dj, dk) + alpha = random.random() + beta = random.random() + + res = t.sspaddmm(x, y, beta=beta, alpha=alpha) + expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y, beta=beta, alpha=alpha) + self.assertEqual(self.safeToDense(res), expected) + + res = t.sspaddmm(x, y) + expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y) + self.assertEqual(self.safeToDense(res), expected) + + test_shape(7, 5, 3, 20) + test_shape(1000, 100, 100, 20) + test_shape(3000, 64, 300, 20) + test_shape(0, 100, 100, 0) + test_shape(1000, 0, 100, 0) + test_shape(1000, 100, 0, 0) + + # Test code from issue https://github.com/pytorch/pytorch/issues/45113 + batch_size, input_size, hidden_size = 5, 3, 7 + + # Create coalesced sparse tensor as in the issue + weight = torch.randn(hidden_size, input_size).to_sparse() + self.assertTrue(weight.is_coalesced()) + self.assertFalse(weight._indices().is_contiguous()) + # Create un/coalesced sparse tensor + bias = torch.randn((hidden_size, 1)).to_sparse() + bias = torch.cat([bias] * batch_size, dim=1) + + if not self.is_uncoalesced: + bias = bias.coalesce() + + x = torch.randn(input_size, batch_size) + res = bias.sspaddmm(weight, x) + + true_result = (bias.to_dense() + torch.matmul(weight.to_dense(), x)).to_sparse() + self.assertEqual(self.safeToDense(res), self.safeToDense(true_result)) + def test_sparse_addmm(self): def test_shape(m, n, p, nnz, broadcast): if broadcast: @@ -1135,7 +1230,6 @@ def test_shape(di, dj, dk, nnz): test_shape(1000, 100, 0, 0) test_shape(1000, 100, 0, 20) - @skipIfRocm def test_hsmm(self): def test_shape(di, dj, dk, nnz): x = self._gen_sparse(2, nnz, [di, dj])[0] @@ -1218,7 +1312,6 @@ def test_spadd_hybrid(self): self._test_spadd_shape(10, [50, 30, 20], [2, 0]) @cuda_only - @unittest.skipIf(not TEST_WITH_ROCM, "runs only on ROCm") def test_sparse_add_out_bfloat16(self): # fp32 x, _, _ = self._gen_sparse(3, 5, 10) @@ -1261,7 +1354,6 @@ def test_shape(sparse_dims, nnz, with_size): x.norm(**kwargs) - @skipIfRocm def test_sparse_sum(self): def run_tests(S, td=None): @@ -1730,63 +1822,83 @@ def test_narrow(self): self.assertRaises(RuntimeError, lambda: with_dense.narrow_copy(10, 0, 3)) # dim > sparseDim + denseDim def _test_log1p_tensor(self, sparse_tensor): + def is_integral(dtype): + return dtype in torch.testing.get_all_int_dtypes() + dense_tensor = sparse_tensor.to_dense() expected_output = dense_tensor.log1p() - + is_integral_dtype = is_integral(sparse_tensor.dtype) self.assertEqual(expected_output, sparse_tensor.log1p().to_dense()) - self.assertEqual(expected_output, sparse_tensor.coalesce().log1p_().to_dense()) + if is_integral_dtype: + with self.assertRaisesRegex(RuntimeError, "log1p: result type cannot be Integral, got:"): + sparse_tensor.coalesce().log1p_() + else: + self.assertEqual(expected_output, sparse_tensor.coalesce().log1p_().to_dense()) - if self.is_uncoalesced: + if self.is_uncoalesced and not is_integral_dtype: # test in-place op on uncoalesced input with self.assertRaisesRegex(RuntimeError, "in-place on uncoalesced tensors is not supported"): sparse_tensor.log1p_() + elif self.is_uncoalesced and is_integral_dtype: + with self.assertRaisesRegex(RuntimeError, "log1p: result type cannot be Integral, got"): + sparse_tensor.log1p_() - sparse_tensor.requires_grad_() - self.assertTrue(sparse_tensor.requires_grad) + if not is_integral_dtype: + sparse_tensor.requires_grad_() + self.assertTrue(sparse_tensor.requires_grad) - # test autograd - x = sparse_tensor.clone() - y = sparse_tensor.log1p() - with self.assertRaisesRegex(RuntimeError, "log1p of a sparse tensor is made to be non-differentiable"): - y.backward(x) + # test autograd + x = sparse_tensor.clone() + y = sparse_tensor.log1p() + with self.assertRaisesRegex(RuntimeError, "log1p of a sparse tensor is made to be non-differentiable"): + y.backward(x) + else: + with self.assertRaisesRegex(RuntimeError, "only Tensors of floating point dtype can require gradients"): + sparse_tensor.requires_grad_() def test_log1p(self): - if not self.is_uncoalesced: - input_coalesced = torch.sparse_coo_tensor( - indices=torch.tensor([[0], [1], [2]]).transpose(1, 0), - values=torch.tensor([3.0, 4.0, 5.0]), - size=[3, ], - device=self.device - ).coalesce() - self._test_log1p_tensor(input_coalesced) - - # hybrid sparse input - input_coalesced = torch.sparse_coo_tensor( - indices=torch.tensor([[1, 3], [2, 4]]), - values=torch.tensor([[1.0, 3.0], [5.0, 7.0]]), - size=[4, 5, 2], - device=self.device - ).coalesce() - self._test_log1p_tensor(input_coalesced) + for dtype in torch.testing.get_all_dtypes(include_bool=False, include_half=False, + include_bfloat16=False, include_complex=False): + if not self.is_uncoalesced: + input_coalesced = torch.sparse_coo_tensor( + indices=torch.tensor([[0], [1], [2]]).transpose(1, 0), + values=torch.tensor([3.0, 4.0, 5.0]), + size=[3, ], + device=self.device, + dtype=dtype + ).coalesce() + self._test_log1p_tensor(input_coalesced) + + # hybrid sparse input + input_coalesced = torch.sparse_coo_tensor( + indices=torch.tensor([[1, 3], [2, 4]]), + values=torch.tensor([[1.0, 3.0], [5.0, 7.0]]), + size=[4, 5, 2], + device=self.device, + dtype=dtype + ).coalesce() + self._test_log1p_tensor(input_coalesced) - if self.is_uncoalesced: - # test uncoalesced input - input_uncoalesced = torch.sparse_coo_tensor( - indices=torch.tensor([[0], [1], [2], [0], [1], [2]]).transpose(1, 0), - values=torch.tensor([2.0, 3.0, 4.0, 1.0, 1.0, 1.0]), - size=[3, ], - device=self.device - ) - self._test_log1p_tensor(input_uncoalesced) - - # test on empty sparse tensor - input_uncoalesced = torch.sparse_coo_tensor( - indices=torch.zeros([2, 0]), - values=torch.zeros([0, 5, 5, 5, 5, 5, 5, 0]), - size=[0, 0, 5, 5, 5, 5, 5, 5, 0], - device=self.device - ) - self._test_log1p_tensor(input_uncoalesced) + if self.is_uncoalesced: + # test uncoalesced input + input_uncoalesced = torch.sparse_coo_tensor( + indices=torch.tensor([[0], [1], [2], [0], [1], [2]]).transpose(1, 0), + values=torch.tensor([2.0, 3.0, 4.0, 1.0, 1.0, 1.0]), + size=[3, ], + device=self.device, + dtype=dtype + ) + self._test_log1p_tensor(input_uncoalesced) + + # test on empty sparse tensor + input_uncoalesced = torch.sparse_coo_tensor( + indices=torch.zeros([2, 0]), + values=torch.zeros([0, 5, 5, 5, 5, 5, 5, 0]), + size=[0, 0, 5, 5, 5, 5, 5, 5, 0], + device=self.device, + dtype=dtype + ) + self._test_log1p_tensor(input_uncoalesced) def _test_neg_negative(self, sparse_tensor): dense_tensor = sparse_tensor.to_dense() @@ -1846,6 +1958,10 @@ def test_neg_negative(self): self._test_neg_negative(input_uncoalesced) def _test_asin_arcsin(self, sparse_tensor): + def is_integral(dtype): + return dtype in torch.testing.get_all_int_dtypes() + is_integral_dtype = is_integral(sparse_tensor.dtype) + dense_tensor = sparse_tensor.to_dense() expected_output = dense_tensor.asin() @@ -1857,54 +1973,73 @@ def _test_asin_arcsin(self, sparse_tensor): self.assertEqual(expected_output, op(sparse_tensor).to_dense()) if op in (torch.asin, torch.arcsin): sparse_tensor_out = torch.zeros_like(sparse_tensor) - op(sparse_tensor, out=sparse_tensor_out) - self.assertEqual(expected_output, sparse_tensor_out.to_dense()) + if not is_integral_dtype: + op(sparse_tensor, out=sparse_tensor_out) + self.assertEqual(expected_output, sparse_tensor_out.to_dense()) + else: + with self.assertRaisesRegex(RuntimeError, "asin: result type cannot be Integral"): + op(sparse_tensor, out=sparse_tensor_out) for op in (torch.Tensor.asin_, torch.Tensor.arcsin_): - self.assertEqual(expected_output, op(sparse_tensor.clone().coalesce()).to_dense()) - if self.is_uncoalesced: + if is_integral_dtype: + # test coalesce on integral dtype tensor + with self.assertRaisesRegex(RuntimeError, "asin: result type cannot be Integral"): + op(sparse_tensor.clone().coalesce()).to_dense() + else: + self.assertEqual(expected_output, op(sparse_tensor.clone().coalesce()).to_dense()) + + if self.is_uncoalesced and not is_integral_dtype: # test in-place op on uncoalesced input with self.assertRaisesRegex(RuntimeError, "in-place on uncoalesced tensors is not supported"): op(sparse_tensor) + elif self.is_uncoalesced: + # test in-place op on integral dtype tensor + with self.assertRaisesRegex(RuntimeError, "asin: result type cannot be Integral"): + op(sparse_tensor) def test_asin_arcsin(self): + for dtype in torch.testing.get_all_dtypes(include_bool=False, include_half=False, + include_bfloat16=False, include_complex=False): + if not self.is_uncoalesced: + input_coalesced = torch.sparse_coo_tensor( + indices=torch.tensor([[0, 1, 2, 3]]), + values=torch.tensor([0.5, -0.5, 0.7, -0.7]), + size=[4, ], + dtype=dtype, + device=self.device + ).coalesce() + self._test_asin_arcsin(input_coalesced) + + # hybrid sparse input + input_coalesced = torch.sparse_coo_tensor( + indices=torch.tensor([[1, 3], [2, 4]]), + values=torch.tensor([[-0.1, 0.24], [-0.44, 0.1]]), + size=[4, 5, 2], + dtype=dtype, + device=self.device + ).coalesce() + self._test_asin_arcsin(input_coalesced) - if not self.is_uncoalesced: - input_coalesced = torch.sparse_coo_tensor( - indices=torch.tensor([[0, 1, 2, 3]]), - values=torch.tensor([0.5, -0.5, 0.7, -0.7]), - size=[4, ], - device=self.device - ).coalesce() - self._test_asin_arcsin(input_coalesced) - - # hybrid sparse input - input_coalesced = torch.sparse_coo_tensor( - indices=torch.tensor([[1, 3], [2, 4]]), - values=torch.tensor([[-0.1, 0.24], [-0.44, 0.1]]), - size=[4, 5, 2], - device=self.device - ).coalesce() - self._test_asin_arcsin(input_coalesced) - - if self.is_uncoalesced: - # test uncoalesced input - input_uncoalesced = torch.sparse_coo_tensor( - indices=torch.tensor([[0], [1], [2], [0], [1], [2]]).transpose(1, 0), - values=torch.tensor([0.3, -0.3, -0.4, 0.3, -0.5, 0.15]), - size=[3, ], - device=self.device - ) - self._test_asin_arcsin(input_uncoalesced) - - # test on empty sparse tensor - input_uncoalesced = torch.sparse_coo_tensor( - indices=torch.zeros([2, 0]), - values=torch.zeros([0, 5, 5, 5, 5, 5, 5, 0]), - size=[0, 0, 5, 5, 5, 5, 5, 5, 0], - device=self.device - ) - self._test_asin_arcsin(input_uncoalesced) + if self.is_uncoalesced: + # test uncoalesced input + input_uncoalesced = torch.sparse_coo_tensor( + indices=torch.tensor([[0], [1], [2], [0], [1], [2]]).transpose(1, 0), + values=torch.tensor([0.3, -0.3, -0.4, 0.3, -0.5, 0.15]), + size=[3, ], + dtype=dtype, + device=self.device + ) + self._test_asin_arcsin(input_uncoalesced) + + # test on empty sparse tensor + input_uncoalesced = torch.sparse_coo_tensor( + indices=torch.zeros([2, 0]), + values=torch.zeros([0, 5, 5, 5, 5, 5, 5, 0]), + size=[0, 0, 5, 5, 5, 5, 5, 5, 0], + dtype=dtype, + device=self.device + ) + self._test_asin_arcsin(input_uncoalesced) def test_mv(self): def test_shape(di, dj, dk, nnz): @@ -2020,7 +2155,9 @@ def test_shape(sparse_dims, nnz, with_size): if not x.is_cuda: # CUDA sparse tensors currently requires the size to be # specified if nDimV > 0 - self.assertEqual(x.new(indices, values), x) + out = x.new(indices, values).coalesce() + x_c = x.coalesce() + self.assertEqual((out.indices(), out.values()), (x_c.indices(), x_c.values())) self.assertEqual(x.new(indices, values, x.size()), x) test_shape(3, 10, 100) @@ -2589,7 +2726,6 @@ def test_sparse_to_numpy(self): t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [2, 0])), torch.tensor([1, 4])) self.assertRaises(TypeError, lambda: t.numpy()) - @skipIfRocm def test_softmax(self): import torch.nn.functional as F @@ -2881,6 +3017,164 @@ def sparse_log(x): test_op(3, 100, [3, 4, 2, 3, 5, 2]) test_op(4, 100, [3, 4, 2, 3, 5, 2]) + def test_sparse_matmul(self): + """ + This function test `torch.sparse.mm` when both the mat1 and mat2 are sparse tensors. + """ + + def _indices2csr(indices, dim): + nnz = len(indices) + r = [0] * (dim + 1) + last_i = 0 + for i in indices: + if i != last_i: + for _i in range(last_i, i + 1): + r[_i + 1] = r[last_i + 1] + last_i = i + r[last_i + 1] += 1 + for _i in range(last_i, dim): + r[_i + 1] = r[last_i + 1] + assert r[-1] == nnz + return r + + def sparse_mm(a, b, method='scipy'): + a = a.to('cpu') + b = b.to('cpu') + if method == 'scipy': + indices_1 = a._indices().numpy() + values_1 = a._values().numpy() + indices_2 = b._indices().numpy() + values_2 = b._values().numpy() + + mat1 = scipy.sparse.coo_matrix((values_1, (indices_1[0], indices_1[1])), shape=a.shape) + mat2 = scipy.sparse.coo_matrix((values_2, (indices_2[0], indices_2[1])), shape=b.shape) + result = mat1.dot(mat2).tocoo() + return torch.sparse_coo_tensor([result.row, result.col], result.data, result.shape) + else: + assert a.shape[1] == b.shape[0] + n, p = a.shape + p, m = b.shape + indices_a = a._indices() + values_a = a._values() + indices_b = b._indices() + values_b = b._values() + nnz1 = len(indices_a[0]) + nnz2 = len(indices_b[0]) + + if a.is_coalesced() and b.is_coalesced(): + r2 = _indices2csr(indices_b[0], b.shape[0]) + d = defaultdict(values_b.numpy().dtype.type) + for n1 in range(nnz1): + for n2 in range(r2[indices_a[1][n1]], r2[indices_a[1][n1] + 1]): + d[indices_a[0][n1].item(), indices_b[1][n2].item()] += values_a[n1] * values_b[n2] + + else: + d = defaultdict(values_b.numpy().dtype.type) + for n1 in range(nnz1): + for n2 in range(nnz2): + if indices_b[0][n2] == indices_a[1][n1]: + d[indices_a[0][n1].item(), indices_b[1][n2].item()] += values_a[n1] * values_b[n2] + i3 = [] + j3 = [] + values = [] + for i, j in sorted(d): + i3.append(i) + j3.append(j) + values.append(d[i, j]) + return torch.sparse_coo_tensor(torch.tensor([i3, j3]), torch.tensor(values), (n, m)) + + def grad_with_custom_sparsity_pattern_test_helper(sparse_dims, nnz, shape_a, shape_b): + def test_grad_dense(a_s, b_s, g_s): + a = a_s.to_dense().detach() + b = b_s.to_dense().detach() + g = g_s.to_dense().detach() + + a.requires_grad_(True) + b.requires_grad_(True) + c = a @ b + c.backward(g) + return a.grad.sparse_mask(a_s.coalesce()), b.grad.sparse_mask(b_s.coalesce()) + + a, _, _ = self._gen_sparse(sparse_dims, nnz, shape_a) + b, _, _ = self._gen_sparse(sparse_dims, nnz, shape_b) + a.requires_grad_(True) + b.requires_grad_(True) + + c = torch.sparse.mm(a, b) + c2 = c.to_dense().detach() + c2 = torch.rand_like(c2) + g = c2.sparse_mask(c.coalesce()) + + c.backward(g) + + a_grad, b_grad = test_grad_dense(a, b, g) + self.assertEqual(a.grad, a_grad) + self.assertEqual(b.grad, b_grad) + + def test_sparse_matmul(sparse_dims, nnz, shape_a, shape_b): + a, i_a, v_a = self._gen_sparse(sparse_dims, nnz, shape_a) + b, i_b, v_b = self._gen_sparse(sparse_dims, nnz, shape_b) + + # python implementation + r1 = sparse_mm(a, b, 'scipy' if TEST_SCIPY else 'direct') + + self.assertEqual(r1.to_dense(), torch.mm(a.to_dense(), b.to_dense())) + + # cpp implementation + r2 = torch.sparse.mm(a, b) + self.assertEqual(r1, r2) + + a.requires_grad_(True) + b.requires_grad_(True) + + # check autograd support on sparse matmul + def fn(D1, D2): + return torch.sparse.mm(D1, D2).to_dense() + + # For cuda, `nondet_tol` is set with `1e-5` + # This is because cuSparse sometimes returns approximate zero values like `~e-323` + # TODO: Check this cuSparse issue. + # This happens when you do chain multiplication `torch.sparse.mm` operations + gradcheck(fn, (a, b), check_sparse_nnz=True, nondet_tol=1e-5) + grad_with_custom_sparsity_pattern_test_helper(sparse_dims, nnz, shape_a, shape_b) + + def test_error_cases(): + def fn(sparse_dims, nnz, shape_a, shape_b): + a, i_a, v_a = self._gen_sparse(sparse_dims, nnz, shape_a) + b, i_b, v_b = self._gen_sparse(sparse_dims, nnz, shape_b) + r2 = torch.sparse.mm(a, b) + + # This is not a matrix + self.assertRaises(RuntimeError, lambda: fn(3, 4, [2, 2, 2], [2, 2, 2])) + + # Shapes does not + self.assertRaisesRegex(RuntimeError, + r"mat1 and mat2 shapes cannot be multiplied \(2x3 and 4x2\)", + lambda: fn(2, 10, [2, 3], [4, 2])) + + def different_dtypes(): + a, i_a, v_a = self._gen_sparse(2, 10, [2, 2]) + b, i_b, v_b = self._gen_sparse(2, 10, [2, 2]) + r2 = torch.sparse.mm(a.to(torch.float64), a.to(torch.float32)) + + self.assertRaisesRegex(RuntimeError, 'mat1 dtype Double does not match mat2 dtype Float', different_dtypes) + + for n in range(2, 5): + for m in range(2, 8): + for p in range(2, 8): + test_sparse_matmul(2, 10, [n, m], [m, p]) + + test_sparse_matmul(2, 0, [0, 0], [0, 0]) + test_sparse_matmul(2, 0, [0, 10], [10, 0]) + test_error_cases() + + def test_assign(self): + def assign_to(a): + a, i_a, v_a = self._gen_sparse(2, 5, [2, 3]) + a[0] = 100 + + self.assertRaises(TypeError, assign_to) + class TestUncoalescedSparse(TestSparse): def setUp(self): @@ -2951,6 +3245,33 @@ def test_cuda_sparse_cpu_dense_add(self): with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"): x + sparse_y +class TestSparseUnaryUfuncs(TestCase): + exact_dtype = True + + @ops(sparse_unary_ufuncs) + def test_sparse_consistency(self, device, dtype, op): + unsupportedTypes = [torch.bfloat16, torch.cfloat, torch.cdouble] + if dtype in unsupportedTypes: + self.skipTest('Skipped! Unsupported dtypes for Sparse') + + samples = op.sample_inputs(device, dtype) + + if len(samples) == 0: + self.skipTest("Skipped! No sample inputs!") + + sample = samples[0] + + if len(sample.input) > 1: + self.skipTest("Skipped! Testing unary ops, one input is expected") + sample = sample.input[0] + + expected = op(sample) + assert torch.is_tensor(expected) + output = op(sample.to_sparse()) + assert torch.is_tensor(output) + self.assertEqual(output.to_dense(), expected) + +instantiate_device_type_tests(TestSparseUnaryUfuncs, globals()) if __name__ == '__main__': run_tests() diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py index d7ef731699b34..227e119e4ca7b 100644 --- a/test/test_spectral_ops.py +++ b/test/test_spectral_ops.py @@ -8,8 +8,9 @@ from torch.testing._internal.common_utils import \ (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA) from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, dtypes, onlyOnCPUAndCUDA, precisionOverride, - skipCPUIfNoMkl, skipCUDAIfRocm, deviceCountAtLeast, onlyCUDA) + (instantiate_device_type_tests, ops, dtypes, onlyOnCPUAndCUDA, + skipCPUIfNoMkl, skipCUDAIfRocm, deviceCountAtLeast, onlyCUDA, OpDTypes) +from torch.testing._internal.common_methods_invocations import spectral_funcs from distutils.version import LooseVersion from typing import Optional, List @@ -22,10 +23,6 @@ if TEST_LIBROSA: import librosa -# saves the torch.fft function that's clobbered by importing the torch.fft module -fft_fn = torch.fft -import torch.fft - def _complex_stft(x, *args, **kwargs): # Transform real and imaginary components separably @@ -99,28 +96,9 @@ def _stft_reference(x, hop_length, window): class TestFFT(TestCase): exact_dtype = True - @skipCPUIfNoMkl - @skipCUDAIfRocm - def test_fft_function_clobbered(self, device): - t = torch.randn((100, 2), device=device) - eager_result = fft_fn(t, 1) - - def method_fn(t): - return t.fft(1) - scripted_method_fn = torch.jit.script(method_fn) - - self.assertEqual(scripted_method_fn(t), eager_result) - - with self.assertRaisesRegex(TypeError, "'module' object is not callable"): - torch.fft(t, 1) - - @skipCPUIfNoMkl - @skipCUDAIfRocm @onlyOnCPUAndCUDA - @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') - @precisionOverride({torch.complex64: 1e-4, torch.float: 1e-4}) - @dtypes(torch.float, torch.double, torch.complex64, torch.complex128) - def test_fft_numpy(self, device, dtype): + @ops([op for op in spectral_funcs if not op.ndimensional]) + def test_reference_1d(self, device, dtype, op): norm_modes = ((None, "forward", "backward", "ortho") if LooseVersion(np.__version__) >= '1.20.0' else (None, "ortho")) @@ -147,39 +125,15 @@ def test_fft_numpy(self, device, dtype): ) ] + for iargs in test_args: + args = list(iargs) + input = args[0] + args = args[1:] - fft_functions = ['fft', 'ifft', 'hfft', 'irfft'] - # Real-only functions - if not dtype.is_complex: - fft_functions += ['rfft', 'ihfft'] - - for fname in fft_functions: - torch_fn = getattr(torch.fft, fname) - numpy_fn = getattr(np.fft, fname) - - def fn(t: torch.Tensor, n: Optional[int], dim: int, norm: Optional[str]): - return torch_fn(t, n, dim, norm) - scripted_fn = torch.jit.script(fn) - - # TODO: revisit the following function if t.fft() becomes torch.fft.fft - # def method_fn(t, n, dim, norm): - # return getattr(t, fname)(n, dim, norm) - # scripted_method_fn = torch.jit.script(method_fn) - - # TODO: revisit the following function if t.fft() becomes torch.fft.fft - # torch_fns = (torch.fft.fft, torch.Tensor.fft, scripted_fn, scripted_method_fn) - torch_fns = (torch_fn, scripted_fn) - - for iargs in test_args: - args = list(iargs) - input = args[0] - args = args[1:] - - expected = numpy_fn(input.cpu().numpy(), *args) - exact_dtype = dtype in (torch.double, torch.complex128) - for fn in torch_fns: - actual = fn(input, *args) - self.assertEqual(actual, expected, exact_dtype=exact_dtype) + expected = op.ref(input.cpu().numpy(), *args) + exact_dtype = dtype in (torch.double, torch.complex128) + actual = op(input, *args) + self.assertEqual(actual, expected, exact_dtype=exact_dtype) @skipCUDAIfRocm @skipCPUIfNoMkl @@ -221,35 +175,25 @@ def test_fft_round_trip(self, device, dtype): forward != torch.fft.fft or x.is_complex())) # Note: NumPy will throw a ValueError for an empty input - @skipCUDAIfRocm - @skipCPUIfNoMkl @onlyOnCPUAndCUDA - @dtypes(torch.float, torch.double, torch.complex64, torch.complex128) - def test_empty_fft(self, device, dtype): + @ops(spectral_funcs) + def test_empty_fft(self, device, dtype, op): t = torch.empty(0, device=device, dtype=dtype) match = r"Invalid number of data points \([-\d]*\) specified" - fft_functions = [torch.fft.fft, torch.fft.fftn, - torch.fft.ifft, torch.fft.ifftn, - torch.fft.irfft, torch.fft.irfftn, - torch.fft.hfft] - # Real-only functions - if not dtype.is_complex: - fft_functions += [torch.fft.rfft, torch.fft.rfftn, torch.fft.ihfft] - for fn in fft_functions: - with self.assertRaisesRegex(RuntimeError, match): - fn(t) + with self.assertRaisesRegex(RuntimeError, match): + op(t) def test_fft_invalid_dtypes(self, device): t = torch.randn(64, device=device, dtype=torch.complex128) - with self.assertRaisesRegex(RuntimeError, "Expected a real input tensor"): + with self.assertRaisesRegex(RuntimeError, "rfft expects a real input tensor"): torch.fft.rfft(t) - with self.assertRaisesRegex(RuntimeError, "Expected a real input tensor"): + with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input tensor"): torch.fft.rfftn(t) - with self.assertRaisesRegex(RuntimeError, "Expected a real input tensor"): + with self.assertRaisesRegex(RuntimeError, "ihfft expects a real input tensor"): torch.fft.ihfft(t) @skipCUDAIfRocm @@ -291,74 +235,21 @@ def test_fft_type_promotion(self, device, dtype): C = torch.fft.rfft(t) self.assertEqual(C.dtype, PROMOTION_MAP_R2C[dtype]) - @skipCUDAIfRocm - @skipCPUIfNoMkl @onlyOnCPUAndCUDA - @dtypes(torch.half, torch.bfloat16) - def test_fft_half_errors(self, device, dtype): + @ops(spectral_funcs, dtypes=OpDTypes.unsupported, + allowed_dtypes=[torch.half, torch.bfloat16]) + def test_fft_half_and_bfloat16_errors(self, device, dtype, op): # TODO: Remove torch.half error when complex32 is fully implemented x = torch.randn(64, device=device).to(dtype) - fft_functions = (torch.fft.fft, torch.fft.ifft, - torch.fft.fftn, torch.fft.ifftn, - torch.fft.rfft, torch.fft.irfft, - torch.fft.rfftn, torch.fft.irfftn, - torch.fft.hfft, torch.fft.ihfft) - for fn in fft_functions: - with self.assertRaisesRegex(RuntimeError, "Unsupported dtype "): - fn(x) - - @skipCPUIfNoMkl - @skipCUDAIfRocm - @onlyOnCPUAndCUDA - @dtypes(torch.double, torch.complex128) # gradcheck requires double - def test_fft_backward(self, device, dtype): - test_args = list(product( - # input - (torch.randn(67, device=device, dtype=dtype), - torch.randn(9, 6, 3, device=device, dtype=dtype)), - # n - (None, 6), - # dim - (-1, 0), - # norm - (None, "forward", "backward", "ortho") - )) - - fft_functions = ['fft', 'ifft', 'hfft', 'irfft'] - # Real-only functions - if not dtype.is_complex: - fft_functions += ['rfft', 'ihfft'] - - for fname in fft_functions: - torch_fn = getattr(torch.fft, fname) - - for iargs in test_args: - args = list(iargs) - input = args[0] - args = args[1:] - - # Workaround for gradcheck's poor support for complex input - # Use real input instead and put view_as_complex into the graph - if dtype.is_complex: - def test_fn(x): - return torch_fn(torch.view_as_complex(x), *args) - input = torch.view_as_real(input).detach().requires_grad_() - else: - def test_fn(x): - return torch_fn(x, *args) - input = input.detach().requires_grad_() - - self.assertTrue(torch.autograd.gradcheck(test_fn, (input,))) + with self.assertRaisesRegex(RuntimeError, "Unsupported dtype "): + op(x) # nd-fft tests - @skipCPUIfNoMkl - @skipCUDAIfRocm @onlyOnCPUAndCUDA @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') - @precisionOverride({torch.complex64: 1e-4, torch.float: 1e-4}) - @dtypes(torch.float, torch.double, torch.complex64, torch.complex128) - def test_fftn_numpy(self, device, dtype): + @ops([op for op in spectral_funcs if op.ndimensional]) + def test_reference_nd(self, device, dtype, op): norm_modes = ((None, "forward", "backward", "ortho") if LooseVersion(np.__version__) >= '1.20.0' else (None, "ortho")) @@ -376,28 +267,14 @@ def test_fftn_numpy(self, device, dtype): (4, (10, 10), (0, 1)) ] - fft_functions = ['fftn', 'ifftn', 'irfftn'] - # Real-only functions - if not dtype.is_complex: - fft_functions += ['rfftn'] - for input_ndim, s, dim in transform_desc: shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim) input = torch.randn(*shape, device=device, dtype=dtype) - for fname, norm in product(fft_functions, norm_modes): - torch_fn = getattr(torch.fft, fname) - numpy_fn = getattr(np.fft, fname) - - def fn(t: torch.Tensor, s: Optional[List[int]], dim: Optional[List[int]], norm: Optional[str]): - return torch_fn(t, s, dim, norm) - - torch_fns = (torch_fn, torch.jit.script(fn)) - - expected = numpy_fn(input.cpu().numpy(), s, dim, norm) + for norm in norm_modes: + expected = op.ref(input.cpu().numpy(), s, dim, norm) exact_dtype = dtype in (torch.double, torch.complex128) - for fn in torch_fns: - actual = fn(input, s, dim, norm) - self.assertEqual(actual, expected, exact_dtype=exact_dtype) + actual = op(input, s, dim, norm) + self.assertEqual(actual, expected, exact_dtype=exact_dtype) @skipCUDAIfRocm @skipCPUIfNoMkl @@ -439,126 +316,259 @@ def test_fftn_round_trip(self, device, dtype): self.assertEqual(x, y, exact_dtype=( forward != torch.fft.fftn or x.is_complex())) + @onlyOnCPUAndCUDA + @ops([op for op in spectral_funcs if op.ndimensional], + allowed_dtypes=[torch.float, torch.cfloat]) + def test_fftn_invalid(self, device, dtype, op): + a = torch.rand(10, 10, 10, device=device, dtype=dtype) + + with self.assertRaisesRegex(RuntimeError, "dims must be unique"): + op(a, dim=(0, 1, 0)) + + with self.assertRaisesRegex(RuntimeError, "dims must be unique"): + op(a, dim=(2, -1)) + + with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"): + op(a, s=(1,), dim=(0, 1)) + + with self.assertRaisesRegex(IndexError, "Dimension out of range"): + op(a, dim=(3,)) + + with self.assertRaisesRegex(RuntimeError, "tensor only has 3 dimensions"): + op(a, s=(10, 10, 10, 10)) + + # 2d-fft tests + + # NOTE: 2d transforms are only thin wrappers over n-dim transforms, + # so don't require exhaustive testing. + @skipCPUIfNoMkl @skipCUDAIfRocm @onlyOnCPUAndCUDA - @dtypes(torch.double, torch.complex128) # gradcheck requires double - def test_fftn_backward(self, device, dtype): - # input_ndim, s, dim + @dtypes(torch.double, torch.complex128) + def test_fft2_numpy(self, device, dtype): + norm_modes = ((None, "forward", "backward", "ortho") + if LooseVersion(np.__version__) >= '1.20.0' + else (None, "ortho")) + + # input_ndim, s transform_desc = [ - *product((2, 3), (None,), (None, (0,), (0, -1))), - *product((2, 3), (None, (4, 10)), (None,)), - (4, None, None), - (3, (10, 10), (0, 1)), - (2, (1, 1), (0, 1)), - (2, None, (1,)), - (1, None, (0,)), - (1, (11,), (0,)), + *product(range(2, 5), (None, (4, 10))), ] + + fft_functions = ['fft2', 'ifft2', 'irfft2'] + if dtype.is_floating_point: + fft_functions += ['rfft2'] + + for input_ndim, s in transform_desc: + shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim) + input = torch.randn(*shape, device=device, dtype=dtype) + for fname, norm in product(fft_functions, norm_modes): + torch_fn = getattr(torch.fft, fname) + numpy_fn = getattr(np.fft, fname) + + def fn(t: torch.Tensor, s: Optional[List[int]], dim: List[int] = (-2, -1), norm: Optional[str] = None): + return torch_fn(t, s, dim, norm) + + torch_fns = (torch_fn, torch.jit.script(fn)) + + # Once with dim defaulted + input_np = input.cpu().numpy() + expected = numpy_fn(input_np, s, norm=norm) + for fn in torch_fns: + actual = fn(input, s, norm=norm) + self.assertEqual(actual, expected) + + # Once with explicit dims + dim = (1, 0) + expected = numpy_fn(input.cpu(), s, dim, norm) + for fn in torch_fns: + actual = fn(input, s, dim, norm) + self.assertEqual(actual, expected) + + @skipCUDAIfRocm + @skipCPUIfNoMkl + @onlyOnCPUAndCUDA + @dtypes(torch.float, torch.complex64) + def test_fft2_fftn_equivalence(self, device, dtype): norm_modes = (None, "forward", "backward", "ortho") - fft_functions = ['fftn', 'ifftn', 'irfftn'] + # input_ndim, s, dim + transform_desc = [ + *product(range(2, 5), (None, (4, 10)), (None, (1, 0))), + (3, None, (0, 2)), + ] + + fft_functions = ['fft', 'ifft', 'irfft'] # Real-only functions - if not dtype.is_complex: - fft_functions += ['rfftn'] + if dtype.is_floating_point: + fft_functions += ['rfft'] for input_ndim, s, dim in transform_desc: shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim) - input = torch.randn(*shape, device=device, dtype=dtype) + x = torch.randn(*shape, device=device, dtype=dtype) - for fname, norm in product(fft_functions, norm_modes): - torch_fn = getattr(torch.fft, fname) + for func, norm in product(fft_functions, norm_modes): + f2d = getattr(torch.fft, func + '2') + fnd = getattr(torch.fft, func + 'n') + + kwargs = {'s': s, 'norm': norm} - # Workaround for gradcheck's poor support for complex input - # Use real input instead and put view_as_complex into the graph - if dtype.is_complex: - def test_fn(x): - return torch_fn(torch.view_as_complex(x), s, dim, norm) - inputs = (torch.view_as_real(input).detach().requires_grad_(),) + if dim is not None: + kwargs['dim'] = dim + expect = fnd(x, **kwargs) else: - def test_fn(x): - return torch_fn(x, s, dim, norm) - inputs = (input.detach().requires_grad_(),) + expect = fnd(x, dim=(-2, -1), **kwargs) + + actual = f2d(x, **kwargs) - self.assertTrue(torch.autograd.gradcheck(test_fn, inputs)) + self.assertEqual(actual, expect) @skipCUDAIfRocm @skipCPUIfNoMkl @onlyOnCPUAndCUDA - def test_fftn_invalid(self, device): + def test_fft2_invalid(self, device): a = torch.rand(10, 10, 10, device=device) - fft_funcs = (torch.fft.fftn, torch.fft.ifftn, - torch.fft.rfftn, torch.fft.irfftn) + fft_funcs = (torch.fft.fft2, torch.fft.ifft2, + torch.fft.rfft2, torch.fft.irfft2) for func in fft_funcs: - with self.assertRaisesRegex(RuntimeError, "FFT dims must be unique"): - func(a, dim=(0, 1, 0)) + with self.assertRaisesRegex(RuntimeError, "dims must be unique"): + func(a, dim=(0, 0)) - with self.assertRaisesRegex(RuntimeError, "FFT dims must be unique"): + with self.assertRaisesRegex(RuntimeError, "dims must be unique"): func(a, dim=(2, -1)) with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"): - func(a, s=(1,), dim=(0, 1)) + func(a, s=(1,)) with self.assertRaisesRegex(IndexError, "Dimension out of range"): - func(a, dim=(3,)) - - with self.assertRaisesRegex(RuntimeError, "tensor only has 3 dimensions"): - func(a, s=(10, 10, 10, 10)) + func(a, dim=(2, 3)) c = torch.complex(a, a) - with self.assertRaisesRegex(RuntimeError, "Expected a real input"): - torch.fft.rfftn(c) + with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input"): + torch.fft.rfft2(c) + + # Helper functions + + @skipCPUIfNoMkl + @skipCUDAIfRocm + @onlyOnCPUAndCUDA + @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') + @dtypes(torch.float, torch.double) + def test_fftfreq_numpy(self, device, dtype): + test_args = [ + *product( + # n + range(1, 20), + # d + (None, 10.0), + ) + ] + + functions = ['fftfreq', 'rfftfreq'] + + for fname in functions: + torch_fn = getattr(torch.fft, fname) + numpy_fn = getattr(np.fft, fname) + + for n, d in test_args: + args = (n,) if d is None else (n, d) + expected = numpy_fn(*args) + actual = torch_fn(*args, device=device, dtype=dtype) + self.assertEqual(actual, expected, exact_dtype=False) + + @skipCPUIfNoMkl + @skipCUDAIfRocm + @onlyOnCPUAndCUDA + @dtypes(torch.float, torch.double) + def test_fftfreq_out(self, device, dtype): + for func in (torch.fft.fftfreq, torch.fft.rfftfreq): + expect = func(n=100, d=.5, device=device, dtype=dtype) + actual = torch.empty((), device=device, dtype=dtype) + with self.assertWarnsRegex(UserWarning, "out tensor will be resized"): + func(n=100, d=.5, out=actual) + self.assertEqual(actual, expect) + + + @skipCPUIfNoMkl + @skipCUDAIfRocm + @onlyOnCPUAndCUDA + @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') + @dtypes(torch.float, torch.double, torch.complex64, torch.complex128) + def test_fftshift_numpy(self, device, dtype): + test_args = [ + # shape, dim + *product(((11,), (12,)), (None, 0, -1)), + *product(((4, 5), (6, 6)), (None, 0, (-1,))), + *product(((1, 1, 4, 6, 7, 2),), (None, (3, 4))), + ] + + functions = ['fftshift', 'ifftshift'] + + for shape, dim in test_args: + input = torch.rand(*shape, device=device, dtype=dtype) + input_np = input.cpu().numpy() + + for fname in functions: + torch_fn = getattr(torch.fft, fname) + numpy_fn = getattr(np.fft, fname) + + expected = numpy_fn(input_np, axes=dim) + actual = torch_fn(input, dim=dim) + self.assertEqual(actual, expected) + + @skipCPUIfNoMkl + @skipCUDAIfRocm + @onlyOnCPUAndCUDA + @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') + @dtypes(torch.float, torch.double) + def test_fftshift_frequencies(self, device, dtype): + for n in range(10, 15): + sorted_fft_freqs = torch.arange(-(n // 2), n - (n // 2), + device=device, dtype=dtype) + x = torch.fft.fftfreq(n, d=1 / n, device=device, dtype=dtype) + + # Test fftshift sorts the fftfreq output + shifted = torch.fft.fftshift(x) + self.assertTrue(torch.allclose(shifted, shifted.sort().values)) + self.assertEqual(sorted_fft_freqs, shifted) + + # And ifftshift is the inverse + self.assertEqual(x, torch.fft.ifftshift(shifted)) # Legacy fft tests def _test_fft_ifft_rfft_irfft(self, device, dtype): + complex_dtype = { + torch.float16: torch.complex32, + torch.float32: torch.complex64, + torch.float64: torch.complex128 + }[dtype] + def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x): - x = prepro_fn(torch.randn(*sizes, dtype=dtype, device=device)) - for normalized in (True, False): - res = x.fft(signal_ndim, normalized=normalized) - rec = res.ifft(signal_ndim, normalized=normalized) + x = prepro_fn(torch.randn(*sizes, dtype=complex_dtype, device=device)) + dim = tuple(range(-signal_ndim, 0)) + for norm in ('ortho', None): + res = torch.fft.fftn(x, dim=dim, norm=norm) + rec = torch.fft.ifftn(res, dim=dim, norm=norm) self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='fft and ifft') - res = x.ifft(signal_ndim, normalized=normalized) - rec = res.fft(signal_ndim, normalized=normalized) + res = torch.fft.ifftn(x, dim=dim, norm=norm) + rec = torch.fft.fftn(res, dim=dim, norm=norm) self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='ifft and fft') def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x): x = prepro_fn(torch.randn(*sizes, dtype=dtype, device=device)) signal_numel = 1 signal_sizes = x.size()[-signal_ndim:] - for normalized, onesided in product((True, False), repeat=2): - res = x.rfft(signal_ndim, normalized=normalized, onesided=onesided) - if not onesided: # check Hermitian symmetry - def test_one_sample(res, test_num=10): - idxs_per_dim = [torch.LongTensor(test_num).random_(s).tolist() for s in signal_sizes] - for idx in zip(*idxs_per_dim): - reflected_idx = tuple((s - i) % s for i, s in zip(idx, res.size())) - idx_val = res.__getitem__(idx) - reflected_val = res.__getitem__(reflected_idx) - self.assertEqual(idx_val[0], reflected_val[0], msg='rfft hermitian symmetry on real part') - self.assertEqual(idx_val[1], -reflected_val[1], msg='rfft hermitian symmetry on imaginary part') - if len(sizes) == signal_ndim: - test_one_sample(res) - else: - output_non_batch_shape = res.size()[-(signal_ndim + 1):] - flatten_batch_res = res.view(-1, *output_non_batch_shape) - nb = flatten_batch_res.size(0) - test_idxs = torch.LongTensor(min(nb, 4)).random_(nb) - for test_idx in test_idxs.tolist(): - test_one_sample(flatten_batch_res[test_idx]) - # compare with C2C - xc = torch.stack([x, torch.zeros_like(x)], -1) - xc_res = xc.fft(signal_ndim, normalized=normalized) - self.assertEqual(res, xc_res) - test_input_signal_sizes = [signal_sizes] - rec = res.irfft(signal_ndim, normalized=normalized, - onesided=onesided, signal_sizes=signal_sizes) + dim = tuple(range(-signal_ndim, 0)) + for norm in (None, 'ortho'): + res = torch.fft.rfftn(x, dim=dim, norm=norm) + rec = torch.fft.irfftn(res, s=signal_sizes, dim=dim, norm=norm) self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='rfft and irfft') - if not onesided: # check that we can use C2C ifft - rec = res.ifft(signal_ndim, normalized=normalized) - self.assertEqual(x, rec.select(-1, 0), atol=1e-8, rtol=0, msg='twosided rfft and ifft real') - self.assertEqual(rec.select(-1, 1).abs().mean(), 0, atol=1e-8, - rtol=0, msg='twosided rfft and ifft imaginary') + res = torch.fft.fftn(x, dim=dim, norm=norm) + rec = torch.fft.ifftn(res, dim=dim, norm=norm) + x_complex = torch.complex(x, torch.zeros_like(x)) + self.assertEqual(x_complex, rec, atol=1e-8, rtol=0, msg='fft and ifft (from real)') # contiguous case _test_real((100,), 1) @@ -568,12 +578,12 @@ def test_one_sample(res, test_num=10): _test_real((50, 40, 70), 3) _test_real((30, 1, 50, 25, 20), 3) - _test_complex((100, 2), 1) - _test_complex((100, 100, 2), 1) - _test_complex((100, 100, 2), 2) - _test_complex((1, 20, 80, 60, 2), 2) - _test_complex((50, 40, 70, 2), 3) - _test_complex((6, 5, 50, 25, 20, 2), 3) + _test_complex((100,), 1) + _test_complex((100, 100), 1) + _test_complex((100, 100), 2) + _test_complex((1, 20, 80, 60), 2) + _test_complex((50, 40, 70), 3) + _test_complex((6, 5, 50, 25, 20), 3) # non-contiguous case _test_real((165,), 1, lambda x: x.narrow(0, 25, 100)) # input is not aligned to complex type @@ -583,20 +593,10 @@ def test_one_sample(res, test_num=10): _test_real((65, 80, 115), 3, lambda x: x[10:60, 13:53, 10:80]) _test_real((30, 20, 50, 25), 3, lambda x: x.transpose(1, 2).transpose(2, 3)) - _test_complex((2, 100), 1, lambda x: x.t()) - _test_complex((100, 2), 1, lambda x: x.expand(100, 100, 2)) - _test_complex((300, 200, 3), 2, lambda x: x[:100, :100, 1:]) # input is not aligned to complex type - _test_complex((20, 90, 110, 2), 2, lambda x: x[:, 5:85].narrow(2, 5, 100)) - _test_complex((40, 60, 3, 80, 2), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:]) - _test_complex((30, 55, 50, 22, 2), 3, lambda x: x[:, 3:53, 15:40, 1:21]) - - # non-contiguous with strides not representable as aligned with complex type - _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [3, 2, 1])) - _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [4, 2, 2])) - _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [4, 3, 1])) - _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [3, 3, 1])) - _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [4, 2, 2])) - _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [4, 3, 1])) + _test_complex((100,), 1, lambda x: x.expand(100, 100)) + _test_complex((20, 90, 110), 2, lambda x: x[:, 5:85].narrow(2, 5, 100)) + _test_complex((40, 60, 3, 80), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:]) + _test_complex((30, 55, 50, 22), 3, lambda x: x[:, 3:53, 15:40, 1:21]) @skipCUDAIfRocm @skipCPUIfNoMkl @@ -647,7 +647,7 @@ def plan_cache_max_size(device, n): # Test that different GPU has different cache x0 = torch.randn(2, 3, 3, device=devices[0]) x1 = x0.to(devices[1]) - self.assertEqual(x0.rfft(2), x1.rfft(2)) + self.assertEqual(torch.fft.rfftn(x0, dim=(-2, -1)), torch.fft.rfftn(x1, dim=(-2, -1))) # If a plan is used across different devices, the following line (or # the assert above) would trigger illegal memory access. Other ways # to trigger the error include @@ -711,7 +711,9 @@ def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None, else: window = None if expected_error is None: - result = x.stft(n_fft, hop_length, win_length, window, center=center) + with self.maybeWarnsRegex(UserWarning, "stft with return_complex=False"): + result = x.stft(n_fft, hop_length, win_length, window, + center=center, return_complex=False) # NB: librosa defaults to np.complex64 output, no matter what # the input dtype ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center) @@ -923,15 +925,22 @@ def test_complex_stft_onesided(self, device): with self.assertRaisesRegex(RuntimeError, 'complex'): x.stft(10, window=window, pad_mode='constant', onesided=True) else: - y = x.stft(10, window=window, pad_mode='constant', onesided=True) - self.assertEqual(y.dtype, torch.double) - self.assertEqual(y.size(), (6, 51, 2)) + y = x.stft(10, window=window, pad_mode='constant', onesided=True, + return_complex=True) + self.assertEqual(y.dtype, torch.cdouble) + self.assertEqual(y.size(), (6, 51)) - y = torch.rand(100, device=device, dtype=torch.double) - window = torch.randn(10, device=device, dtype=torch.cdouble) + x = torch.rand(100, device=device, dtype=torch.cdouble) with self.assertRaisesRegex(RuntimeError, 'complex'): x.stft(10, pad_mode='constant', onesided=True) + # stft is currently warning that it requires return-complex while an upgrader is written + def test_stft_requires_complex(self, device): + x = torch.rand(100) + y = x.stft(10, pad_mode='constant') + # with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'): + # y = x.stft(10, pad_mode='constant') + @skipCUDAIfRocm @skipCPUIfNoMkl def test_fft_input_modification(self, device): @@ -939,18 +948,18 @@ def test_fft_input_modification(self, device): signal = torch.ones((2, 2, 2), device=device) signal_copy = signal.clone() - spectrum = signal.fft(2) + spectrum = torch.fft.fftn(signal, dim=(-2, -1)) self.assertEqual(signal, signal_copy) spectrum_copy = spectrum.clone() - _ = torch.ifft(spectrum, 2) + _ = torch.fft.ifftn(spectrum, dim=(-2, -1)) self.assertEqual(spectrum, spectrum_copy) - half_spectrum = torch.rfft(signal, 2) + half_spectrum = torch.fft.rfftn(signal, dim=(-2, -1)) self.assertEqual(signal, signal_copy) half_spectrum_copy = half_spectrum.clone() - _ = torch.irfft(half_spectrum_copy, 2, signal_sizes=(2, 2)) + _ = torch.fft.irfftn(half_spectrum_copy, s=(2, 2), dim=(-2, -1)) self.assertEqual(half_spectrum, half_spectrum_copy) @onlyOnCPUAndCUDA @@ -959,7 +968,7 @@ def test_fft_input_modification(self, device): def test_istft_round_trip_simple_cases(self, device, dtype): """stft -> istft should recover the original signale""" def _test(input, n_fft, length): - stft = torch.stft(input, n_fft=n_fft) + stft = torch.stft(input, n_fft=n_fft, return_complex=True) inverse = torch.istft(stft, n_fft=n_fft, length=length) self.assertEqual(input, inverse, exact_dtype=True) @@ -981,7 +990,7 @@ def _test_istft_is_inverse_of_stft(stft_kwargs): for sizes in data_sizes: for i in range(num_trials): original = torch.randn(*sizes, dtype=dtype, device=device) - stft = torch.stft(original, **stft_kwargs) + stft = torch.stft(original, return_complex=True, **stft_kwargs) inversed = torch.istft(stft, length=original.size(1), **istft_kwargs) # trim the original for case when constructed signal is shorter than original diff --git a/test/test_static_runtime.py b/test/test_static_runtime.py index 407ea03acda6c..6a8dd07c8c760 100644 --- a/test/test_static_runtime.py +++ b/test/test_static_runtime.py @@ -1,9 +1,9 @@ +import numpy as np import torch from torch import nn -import numpy as np - from torch.testing._internal.common_utils import TestCase, run_tests +from typing import Dict, Optional class StaticRuntime: def __init__(self, scripted): @@ -13,12 +13,22 @@ def __init__(self, scripted): else: self.static_runtime = torch._C._jit_to_static_runtime(scripted.graph) - def __call__(self, *inps): - return self.static_runtime.run(inps) + def __call__(self, *args, **kwargs): + if not kwargs: + return self.static_runtime.run(args) + else: + return self.static_runtime.run(args, kwargs) + def benchmark(self, args, kwargs, warmup_runs, main_runs): + self.static_runtime.benchmark(args, kwargs, warmup_runs, main_runs) -def linear_shim(input, weight, bias=None): - # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor + def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs): + return self.static_runtime.benchmark_individual_ops( + args, kwargs, warmup_runs, main_runs + ) + + +def linear_shim(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: output = input.matmul(weight.t()) if bias is not None: output += bias @@ -95,6 +105,21 @@ def trivial_graph(a, b, c): s = torch.tensor([[3, 3], [3, 3]]) return a + b * c + s +def loop_graph(a, b, iters : int): + c = a + b * 2 + for i in range(iters): + c = c + b + c *= 2 + c -= a + return c + +def output_graph(a, b, c, iters : int): + s = torch.tensor([[3, 3], [3, 3]]) + k = a + b * c + s + d : Dict[int, torch.Tensor] = {} + for i in range(iters): + d[i] = k + i + return d class TestStaticRuntime(TestCase): def test_multihead_attention_layer(self): @@ -106,7 +131,8 @@ def test_multihead_attention_layer(self): DROPOUT = 0.1 device = torch.device("cpu") attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device) - src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device) + with torch.no_grad(): + src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device) src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device) attention.eval() @@ -116,9 +142,36 @@ def test_multihead_attention_layer(self): attention_a = StaticRuntime(attention) o_test = attention_a(src, src, src, src_mask) + o_test_kw = attention_a(src, src, value=src, mask=src_mask) + for a, b in zip(o_ref, o_test): torch.testing.assert_allclose(a, b) + for a, b in zip(o_ref, o_test_kw): + torch.testing.assert_allclose(a, b) + + def test_multihead_attention_layer_benchmark(self): + HID_DIM = 256 + QUERY_LEN = 8 + BATCH_SIZE = 128 + LAYERS = 3 + HEADS = 8 + DROPOUT = 0.1 + device = torch.device("cpu") + attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device) + with torch.no_grad(): + src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device) + src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device) + + attention.eval() + attention = torch.jit.script(attention) + attention_a = StaticRuntime(attention) + + attention_a.benchmark([src, src, src, src_mask], {}, 2, 2) + metrics = attention_a.benchmark_individual_ops( + [src, src, src, src_mask], {}, 2, 2 + ) + def test_mlp(self): # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh ln_bot = [512, 512, 64] @@ -129,8 +182,9 @@ def test_mlp(self): bot_l_acc = StaticRuntime(bot_l) top_l = create_mlp(ln_top, sigmoid_top) top_l_acc = StaticRuntime(top_l) - bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512]) - top_inp = torch.randn(2048, 100) # torch.Size([2048, 100]) + with torch.no_grad(): + bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512]) + top_inp = torch.randn(2048, 100) # torch.Size([2048, 100]) ref_bot = bot_l(bot_inp) acc_bot = bot_l_acc(bot_inp)[0] torch.testing.assert_allclose(acc_bot, ref_bot) @@ -138,8 +192,9 @@ def test_mlp(self): acc_top = top_l_acc(top_inp)[0] torch.testing.assert_allclose(acc_top, ref_top) for _ in range(5): - bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512]) - top_inp = torch.randn(2048, 100) # torch.Size([2048, 100]) + with torch.no_grad(): + bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512]) + top_inp = torch.randn(2048, 100) # torch.Size([2048, 100]) ref_bot = bot_l(bot_inp) acc_bot = bot_l_acc(bot_inp)[0] torch.testing.assert_allclose(acc_bot, ref_bot) @@ -147,13 +202,78 @@ def test_mlp(self): acc_top = top_l_acc(top_inp)[0] torch.testing.assert_allclose(acc_top, ref_top) - # def test_trivial_graph(self): - # s = torch.full((2, 2), 2) - # tg = torch.jit.script(trivial_graph) - # o_ref = tg(s, s, s) - # tg_a = StaticRuntime(tg) - # o_test = tg_a(s, s, s)[0] - # torch.testing.assert_allclose(o_ref, o_test) + def test_trivial_graph(self): + s = torch.full((2, 2), 2) + tg = torch.jit.script(trivial_graph) + o_ref = tg(s, s, s) + tg_a = StaticRuntime(tg) + o_test = tg_a(s, s, s)[0] + torch.testing.assert_allclose(o_ref, o_test) + + def test_leaky_relu(self): + s = torch.randn(5, 5) + tg = torch.jit.script(nn.LeakyReLU(0.1)) + o_ref = tg(s) + tg_a = StaticRuntime(tg) + o_test = tg_a(s)[0] + torch.testing.assert_allclose(o_ref, o_test) + + def test_fusion_trivial_graph(self): + s = torch.full((2, 2), 2) + tg = torch.jit.script(trivial_graph) + o_ref = tg(s, s, s) + torch._C._fuse_to_static_runtime(tg.graph) + assert "StaticSubgraph" in str(tg.graph) + o_test = tg(s, s, s) + torch.testing.assert_allclose(o_ref, o_test) + + def test_fusion_multihead_attention_layer(self): + HID_DIM = 256 + QUERY_LEN = 8 + BATCH_SIZE = 128 + LAYERS = 3 + HEADS = 8 + DROPOUT = 0.1 + device = torch.device("cpu") + attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device) + with torch.no_grad(): + src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device) + src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device) + + attention.eval() + attention = torch.jit.script(attention) + attention.eval() + o_ref = attention(src, src, src, src_mask) + + torch._C._fuse_to_static_runtime(attention._c) + o_test = attention(src, src, src, src_mask) + + for a, b in zip(o_ref, o_test): + torch.testing.assert_allclose(a, b) + + def test_fusion_loop(self): + a = torch.randn(5, 5) + b = torch.randn(5, 5) + c = 4 + lg = torch.jit.script(loop_graph) + o_ref = lg(a, b, c) + torch._C._fuse_to_static_runtime(lg.graph) + assert "StaticSubgraph" in str(lg.graph) + o_test = lg(a, b, c) + torch.testing.assert_allclose(o_ref, o_test) + + def test_fusion_outputs(self): + a = torch.randn(2, 2) + b = torch.randn(2, 2) + c = 4 + og = torch.jit.script(output_graph) + o_ref = og(a, b, b, c) + torch._C._fuse_to_static_runtime(og.graph) + assert "StaticSubgraph" in str(og.graph) + o_test = og(a, b, b, c) + for i in o_ref.keys(): + torch.testing.assert_allclose(o_ref[i], o_test[i]) + if __name__ == "__main__": diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index eaaee2dab836f..8c616e0963695 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -1,20 +1,64 @@ +import torch +import numpy as np + +import sys +import math import warnings import unittest -from itertools import product +from itertools import product, combinations, combinations_with_replacement, permutations import random -import torch +from torch.testing._internal.common_utils import ( + TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings, + torch_to_numpy_dtype_dict, slowTest, TEST_SCIPY, IS_MACOS, IS_PPC, + IS_WINDOWS) +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, deviceCountAtLeast, onlyOnCPUAndCUDA, + onlyCPU, largeTensorTest, precisionOverride, dtypes, + onlyCUDA, skipCPUIf, dtypesIfCUDA, dtypesIfCPU) + +# TODO: refactor tri_tests_args, _compare_trilu_indices, run_additional_tri_tests +from torch.testing._internal.common_methods_invocations import ( + tri_tests_args, _compare_trilu_indices, run_additional_tri_tests) + + +# TODO: replace with make_tensor +def _generate_input(shape, dtype, device, with_extremal): + if shape == (): + x = torch.tensor((), dtype=dtype, device=device) + else: + if dtype.is_floating_point or dtype.is_complex: + # work around torch.randn not being implemented for bfloat16 + if dtype == torch.bfloat16: + x = torch.randn(*shape, device=device) * random.randint(30, 100) + x = x.to(torch.bfloat16) + else: + x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100) + x[torch.randn(*shape) > 0.5] = 0 + if with_extremal and dtype.is_floating_point: + # Use extremal values + x[torch.randn(*shape) > 0.5] = float('nan') + x[torch.randn(*shape) > 0.5] = float('inf') + x[torch.randn(*shape) > 0.5] = float('-inf') + elif with_extremal and dtype.is_complex: + x[torch.randn(*shape) > 0.5] = complex('nan') + x[torch.randn(*shape) > 0.5] = complex('inf') + x[torch.randn(*shape) > 0.5] = complex('-inf') + elif dtype == torch.bool: + x = torch.zeros(shape, dtype=dtype, device=device) + x[torch.randn(*shape) > 0.5] = True + else: + x = torch.randint(15, 100, shape, dtype=dtype, device=device) + + return x -from torch.testing._internal.common_utils import \ - (TestCase, run_tests, do_test_empty_full, TEST_NUMPY, suppress_warnings, - torch_to_numpy_dtype_dict, slowTest) -from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, deviceCountAtLeast, onlyOnCPUAndCUDA, - onlyCPU, skipCUDAIfNotRocm, largeCUDATensorTest, precisionOverride, dtypes, - onlyCUDA, skipCPUIf, dtypesIfCUDA) -if TEST_NUMPY: - import numpy as np +# TODO: replace with make_tensor +def _rand_shape(dim, min_size, max_size): + shape = [] + for i in range(dim): + shape.append(random.randint(min_size, max_size)) + return tuple(shape) # Test suite for tensor creation ops # @@ -28,6 +72,1564 @@ class TestTensorCreation(TestCase): exact_dtype = True + @onlyCPU + @dtypes(torch.float) + def test_diag_embed(self, device, dtype): + x = torch.arange(3 * 4, dtype=dtype, device=device).view(3, 4) + result = torch.diag_embed(x) + expected = torch.stack([torch.diag(r) for r in x], 0) + self.assertEqual(result, expected) + + result = torch.diag_embed(x, offset=1, dim1=0, dim2=2) + expected = torch.stack([torch.diag(r, 1) for r in x], 1) + self.assertEqual(result, expected) + + def test_cat_mem_overlap(self, device): + x = torch.rand((1, 3), device=device).expand((6, 3)) + y = torch.rand((3, 3), device=device) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + torch.cat([y, y], out=x) + + @onlyOnCPUAndCUDA + def test_vander(self, device): + x = torch.tensor([1, 2, 3, 5], device=device) + + self.assertEqual((0, 0), torch.vander(torch.tensor([]), 0).shape) + + with self.assertRaisesRegex(RuntimeError, "N must be non-negative."): + torch.vander(x, N=-1) + + with self.assertRaisesRegex(RuntimeError, "x must be a one-dimensional tensor."): + torch.vander(torch.stack((x, x))) + + @onlyOnCPUAndCUDA + @dtypes(torch.bool, torch.uint8, torch.int8, torch.short, torch.int, torch.long, + torch.float, torch.double, + torch.cfloat, torch.cdouble) + def test_vander_types(self, device, dtype): + if dtype is torch.uint8: + # Note: no negative uint8 values + X = [[1, 2, 3, 5], [0, 1 / 3, 1, math.pi, 3 / 7]] + elif dtype is torch.bool: + # Note: see https://github.com/pytorch/pytorch/issues/37398 + # for why this is necessary. + X = [[True, True, True, True], [False, True, True, True, True]] + elif dtype in [torch.cfloat, torch.cdouble]: + X = [[1 + 1j, 1 + 0j, 0 + 1j, 0 + 0j], + [2 + 2j, 3 + 2j, 4 + 3j, 5 + 4j]] + else: + X = [[1, 2, 3, 5], [-math.pi, 0, 1 / 3, 1, math.pi, 3 / 7]] + + N = [None, 0, 1, 3] + increasing = [False, True] + + for x, n, inc in product(X, N, increasing): + numpy_dtype = torch_to_numpy_dtype_dict[dtype] + pt_x = torch.tensor(x, device=device, dtype=dtype) + np_x = np.array(x, dtype=numpy_dtype) + + pt_res = torch.vander(pt_x, increasing=inc) if n is None else torch.vander(pt_x, n, inc) + np_res = np.vander(np_x, n, inc) + + self.assertEqual( + pt_res, + torch.from_numpy(np_res), + atol=1e-3, + rtol=0, + exact_dtype=False) + + def test_cat_all_dtypes_and_devices(self, device): + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor([[1, 2], [3, 4]], dtype=dt, device=device) + + expected1 = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=dt, device=device) + self.assertEqual(torch.cat((x, x), 0), expected1) + + expected2 = torch.tensor([[1, 2, 1, 2], [3, 4, 3, 4]], dtype=dt, device=device) + self.assertEqual(torch.cat((x, x), 1), expected2) + + def test_fill_all_dtypes_and_devices(self, device): + for dt in torch.testing.get_all_dtypes(): + for x in [torch.tensor((10, 10), dtype=dt, device=device), + torch.empty(10000, dtype=dt, device=device)]: # large tensor + numel = x.numel() + bound = 100 if dt in (torch.uint8, torch.int8) else 2000 + for n in range(-bound, bound, bound // 10): + x.fill_(n) + self.assertEqual(x, torch.tensor([n] * numel, dtype=dt, device=device)) + self.assertEqual(dt, x.dtype) + + def test_roll(self, device): + numbers = torch.arange(1, 9, device=device) + + single_roll = numbers.roll(1, 0) + expected = torch.tensor([8, 1, 2, 3, 4, 5, 6, 7], device=device) + self.assertEqual(single_roll, expected, msg="{} did not equal expected result".format(single_roll)) + + roll_backwards = numbers.roll(-2, 0) + expected = torch.tensor([3, 4, 5, 6, 7, 8, 1, 2], device=device) + self.assertEqual(roll_backwards, expected, msg="{} did not equal expected result".format(roll_backwards)) + + data = numbers.view(2, 2, 2) + rolled = data.roll(1, 0) + expected = torch.tensor([5, 6, 7, 8, 1, 2, 3, 4], device=device).view(2, 2, 2) + self.assertEqual(expected, rolled, msg="{} did not equal expected result: {}".format(rolled, expected)) + + data = data.view(2, 4) + # roll a loop until back where started + loop_rolled = data.roll(2, 0).roll(4, 1) + self.assertEqual(data, loop_rolled, msg="{} did not equal the original: {}".format(loop_rolled, data)) + # multiple inverse loops + self.assertEqual(data, data.roll(-20, 0).roll(-40, 1)) + self.assertEqual(torch.tensor([8, 1, 2, 3, 4, 5, 6, 7], device=device), numbers.roll(1, 0)) + + # test non-contiguous + # strided equivalent to numbers.as_strided(size=(4, 2), stride=(1, 4)) + strided = numbers.view(2, 4).transpose(0, 1) + self.assertFalse(strided.is_contiguous(), "this test needs a non-contiguous tensor") + expected = torch.tensor([4, 8, 1, 5, 2, 6, 3, 7]).view(4, 2) + rolled = strided.roll(1, 0) + self.assertEqual(expected, rolled, + msg="non contiguous tensor rolled to {} instead of {} ".format(rolled, expected)) + + # test roll with no dimension specified + expected = numbers.roll(1, 0).view(2, 4) + self.assertEqual(expected, data.roll(1), msg="roll with no dims should flatten and roll.") + self.assertEqual(expected, data.roll(1, dims=None), msg="roll with no dims should flatten and roll.") + + # test roll over multiple dimensions + expected = torch.tensor([[7, 8, 5, 6], [3, 4, 1, 2]], device=device) + double_rolled = data.roll(shifts=(2, -1), dims=(1, 0)) + self.assertEqual(double_rolled, expected, + msg="should be able to roll over two dimensions, got {}".format(double_rolled)) + + self.assertRaisesRegex(RuntimeError, "required", lambda: data.roll(shifts=(), dims=())) + self.assertRaisesRegex(RuntimeError, "required", lambda: data.roll(shifts=(), dims=1)) + # shifts/dims should align + self.assertRaisesRegex(RuntimeError, "align", lambda: data.roll(shifts=(1, 2), dims=(1,))) + self.assertRaisesRegex(RuntimeError, "align", lambda: data.roll(shifts=(1,), dims=(1, 2))) + + # test bool tensor + t = torch.zeros(6, dtype=torch.bool, device=device) + t[0] = True + t[3] = True + self.assertEqual(torch.tensor([False, True, False, False, True, False]), t.roll(1, 0)) + + # test complex tensor + t = torch.tensor([1, 2 + 1j, 3.5, 4. + 2j, 5j, 6.], device=device) + t[0] = 1 + 0.5j + t[3] = 4. + expected = torch.tensor([6., 1 + 0.5j, 2 + 1j, 3.5, 4., 5j], device=device) + self.assertEqual(expected, t.roll(1, 0)) + + @slowTest + def test_triu_tril(self, device): + def gen_mask(shape, diagonal, device, upper): + mask = torch.zeros(*shape[-2:]).byte() + for i in range(shape[-2]): + for j in range(shape[-1]): + cond = j - i < diagonal if upper else j - i > diagonal + if cond: + mask[i, j] = 1 + return mask.expand(*shape).to(device) + + torch_functions = {True: torch.triu, False: torch.tril} + numpy_functions = {True: np.triu, False: np.tril} + + # TODO: remove this when bool and half are supported for torch.where + def bool_half_compat_where(pred, true_tensor, false_tensor, dtype): + if dtype == torch.bool or dtype == torch.half: + return torch.where(pred.byte(), true_tensor.byte(), false_tensor.byte()).to(dtype=dtype) + else: + return torch.where(pred, true_tensor, false_tensor) + + def run_test(shape, device, diagonal, dtype): + x = torch.empty(*shape, device=device, dtype=dtype).fill_(2) + + for upper in [True, False]: + # normal test with mask + torch_tri_func = torch_functions[upper] + res1 = torch_tri_func(x, diagonal=diagonal) + res2 = torch.empty(0, device=device, dtype=dtype) + torch_tri_func(x, diagonal=diagonal, out=res2) + exp_mask = gen_mask(shape, diagonal, device, upper) + expected = bool_half_compat_where(exp_mask, torch.tensor(0).type_as(x), x, dtype) + self.assertEqual(res1, res2, atol=0, rtol=0) + self.assertEqual(expected, res1, atol=0, rtol=0) + + # non-contiguous and expanded tensors test + if 0 not in shape: + for s in range(-len(shape), -1): + # non-contiguous tensors + x_nc = x.clone().transpose(s, s + 1) + exp_mask = gen_mask(x_nc.size(), diagonal, device, upper) + if 1 not in shape: + assert not x_nc.is_contiguous(), "x is intentionally non-contiguous" + exp_nc = bool_half_compat_where(exp_mask, torch.tensor(0).type_as(x), x_nc, dtype) + self.assertEqual(torch_tri_func(x_nc, diagonal), exp_nc, atol=0, rtol=0) + x_nc_is_contiguous = x_nc.is_contiguous() + if upper: + self.assertEqual(x_nc.triu_(diagonal), exp_nc, atol=0, rtol=0) + else: + self.assertEqual(x_nc.tril_(diagonal), exp_nc, atol=0, rtol=0) + + self.assertTrue(x_nc.is_contiguous() == x_nc_is_contiguous, + "contiguity of x_nc should not be changed") + + # expanded tensors + expanded_size = (x.size(0),) + x.size() + x_expanded = x.clone().expand(*expanded_size) + if x.size(0) != 1: + assert 0 in x_expanded.stride(), "x intentionally has 0 in its stride" + output = torch_tri_func(x_expanded, diagonal) + self.assertEqual(output, expected.expand(expanded_size), atol=0, rtol=0) + if x.size(0) != 1: + self.assertTrue(0 in x_expanded.stride(), + "geometry of x_expanded should be the same") + if upper: + self.assertEqual(output, x_expanded.triu_(diagonal), atol=0, rtol=0) + else: + self.assertEqual(output, x_expanded.tril_(diagonal), atol=0, rtol=0) + + # numpy test + numpy_tri_func = numpy_functions[upper] + self.assertEqual(numpy_tri_func(x.to('cpu').numpy(), diagonal), res1.cpu().numpy()) + + diagonals = [-2, -1, 0, 1, 2] + shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices + (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices + (3, 7), (5, 3, 7), (7, 5, 3, 7), # thin matrices + (3, 0), (0, 3, 3), (3, 3, 0, 0), # no numel matrices + (3, 1), (5, 3, 1), (7, 5, 3, 1), # very fat matrices + (1, 3), (5, 1, 3), (7, 5, 1, 3), # very thin matrices + (1, 3, 3, 3), (3, 1, 3, 3, 3)] # unsqueezed batch dimensions + dtypes = [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.bfloat16] + for s, d, dtype in product(shapes, diagonals, dtypes): + run_test(s, device, d, dtype) + + def test_diagflat(self, device): + dtype = torch.float32 + # Basic sanity test + x = torch.randn((100,), dtype=dtype, device=device) + result = torch.diagflat(x) + expected = torch.diag(x) + self.assertEqual(result, expected) + + # Test offset + x = torch.randn((100,), dtype=dtype, device=device) + result = torch.diagflat(x, 17) + expected = torch.diag(x, 17) + self.assertEqual(result, expected) + + # Test where input has more than one dimension + x = torch.randn((2, 3, 4), dtype=dtype, device=device) + result = torch.diagflat(x) + expected = torch.diag(x.contiguous().view(-1)) + self.assertEqual(result, expected) + + # Noncontig input + x = torch.randn((2, 3, 4), dtype=dtype, device=device).transpose(2, 0) + self.assertFalse(x.is_contiguous()) + result = torch.diagflat(x) + expected = torch.diag(x.contiguous().view(-1)) + self.assertEqual(result, expected) + + # Complex number support + result = torch.diagflat(torch.ones(4, dtype=torch.complex128)) + expected = torch.eye(4, dtype=torch.complex128) + self.assertEqual(result, expected) + + def test_block_diag(self, device): + def block_diag_workaround(*arrs): + arrs_expanded = [] + for a in arrs: + if a.dim() == 2: + arrs_expanded.append(a) + elif a.dim() == 1: + arrs_expanded.append(a.expand(1, a.size(0))) + elif a.dim() == 0: + arrs_expanded.append(a.expand(1, 1)) + shapes = torch.tensor([a.shape for a in arrs_expanded], device=device) + out = torch.zeros( + torch.sum(shapes, dim=0).tolist(), + dtype=arrs_expanded[0].dtype, + device=device + ) + r, c = 0, 0 + for i, (rr, cc) in enumerate(shapes): + out[r:r + rr, c:c + cc] = arrs_expanded[i] + r += rr + c += cc + return out + + tensors = [ + torch.rand((2, 2), device=device), + torch.rand((2, 3), device=device), + torch.rand(10, device=device), + torch.rand((8, 1), device=device), + torch.rand(1, device=device)[0] + ] + result = torch.block_diag(*tensors) + result_check = block_diag_workaround(*tensors) + self.assertEqual(result, result_check) + + tensor = torch.rand(1, device=device)[0] + result = torch.block_diag(tensor) + result_check = tensor.expand(1, 1) + self.assertEqual(result, result_check) + + tensor = torch.rand(10, device=device) + result = torch.block_diag(tensor) + result_check = tensor.expand(1, tensor.size(0)) + self.assertEqual(result, result_check) + + result = torch.block_diag() + result_check = torch.empty(1, 0, device=device) + self.assertEqual(result, result_check) + self.assertEqual(result.device.type, 'cpu') + + test_dtypes = [ + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.float32, + torch.float64, + torch.complex64, + torch.complex128 + ] + # Test pairs of different dtypes + for dtype1 in test_dtypes: + for dtype2 in test_dtypes: + a = torch.tensor(1, device=device, dtype=dtype1) + b = torch.tensor(2, device=device, dtype=dtype2) + result = torch.block_diag(a, b) + result_dtype = torch.result_type(a, b) + result_check = torch.tensor([[1, 0], [0, 2]], device=device, dtype=result_dtype) + self.assertEqual(result, result_check) + + with self.assertRaisesRegex( + RuntimeError, + "torch.block_diag: Input tensors must have 2 or fewer dimensions. Input 1 has 3 dimensions" + ): + torch.block_diag(torch.tensor(5), torch.tensor([[[6]]])) + + with self.assertRaisesRegex( + RuntimeError, + "torch.block_diag: Input tensors must have 2 or fewer dimensions. Input 0 has 4 dimensions" + ): + torch.block_diag(torch.tensor([[[[6]]]])) + + if device != 'cpu': + with self.assertRaisesRegex( + RuntimeError, + ( + "torch.block_diag: input tensors must all be on the same device." + " Input 0 is on device cpu and input 1 is on device " + ) + ): + torch.block_diag(torch.ones(2, 2).cpu(), torch.ones(2, 2, device=device)) + + @unittest.skipIf(not TEST_SCIPY, "Scipy not found") + def test_block_diag_scipy(self, device): + import scipy.linalg + scipy_tensors_list = [ + [ + 1, + [2], + [], + [3, 4, 5], + [[], []], + [[6], [7.3]] + ], + [ + [[1, 2], [3, 4]], + [1] + ], + [ + [[4, 9], [7, 10]], + [4.6, 9.12], + [1j + 3] + ], + [] + ] + + expected_torch_types = [ + torch.float32, + torch.int64, + torch.complex64, + torch.float32 + ] + + expected_scipy_types = [ + torch.float64, + # windows scipy block_diag returns int32 types + torch.int32 if IS_WINDOWS else torch.int64, + torch.complex128, + torch.float64 + ] + + for scipy_tensors, torch_type, scipy_type in zip(scipy_tensors_list, expected_torch_types, expected_scipy_types): + torch_tensors = [torch.tensor(t, device=device) for t in scipy_tensors] + torch_result = torch.block_diag(*torch_tensors) + self.assertEqual(torch_result.dtype, torch_type) + + scipy_result = torch.tensor( + scipy.linalg.block_diag(*scipy_tensors), + device=device + ) + self.assertEqual(scipy_result.dtype, scipy_type) + scipy_result = scipy_result.to(torch_type) + + self.assertEqual(torch_result, scipy_result) + + @onlyOnCPUAndCUDA + @dtypes(torch.float32, torch.float64) + def test_torch_complex(self, device, dtype): + real = torch.tensor([1, 2], device=device, dtype=dtype) + imag = torch.tensor([3, 4], device=device, dtype=dtype) + z = torch.complex(real, imag) + complex_dtype = torch.complex64 if dtype == torch.float32 else torch.complex128 + self.assertEqual(torch.tensor([1.0 + 3.0j, 2.0 + 4.0j], dtype=complex_dtype), z) + + @onlyOnCPUAndCUDA + @dtypes(torch.float32, torch.float64) + def test_torch_polar(self, device, dtype): + abs = torch.tensor([1, 2, -3, -4.5, 1, 1], device=device, dtype=dtype) + angle = torch.tensor([math.pi / 2, 5 * math.pi / 4, 0, -11 * math.pi / 6, math.pi, -math.pi], + device=device, dtype=dtype) + z = torch.polar(abs, angle) + complex_dtype = torch.complex64 if dtype == torch.float32 else torch.complex128 + self.assertEqual(torch.tensor([1j, -1.41421356237 - 1.41421356237j, -3, + -3.89711431703 - 2.25j, -1, -1], + dtype=complex_dtype), + z, atol=1e-5, rtol=1e-5) + + @onlyOnCPUAndCUDA + @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, + torch.float16, torch.complex64, torch.complex128, torch.bool) + def test_torch_complex_floating_dtype_error(self, device, dtype): + for op in (torch.complex, torch.polar): + a = torch.tensor([1, 2], device=device, dtype=dtype) + b = torch.tensor([3, 4], device=device, dtype=dtype) + error = r"Expected both inputs to be Float or Double tensors but " \ + r"got [A-Za-z]+ and [A-Za-z]+" + with self.assertRaisesRegex(RuntimeError, error): + op(a, b) + + @onlyOnCPUAndCUDA + @dtypes(torch.float32, torch.float64) + def test_torch_complex_same_dtype_error(self, device, dtype): + + def dtype_name(dtype): + return 'Float' if dtype == torch.float32 else 'Double' + + for op in (torch.complex, torch.polar): + other_dtype = torch.float64 if dtype == torch.float32 else torch.float32 + a = torch.tensor([1, 2], device=device, dtype=dtype) + b = torch.tensor([3, 4], device=device, dtype=other_dtype) + error = "Expected object of scalar type {} but got scalar type " \ + "{} for second argument".format(dtype_name(dtype), + dtype_name(other_dtype)) + with self.assertRaisesRegex(RuntimeError, error): + op(a, b) + + @onlyOnCPUAndCUDA + @dtypes(torch.float32, torch.float64) + def test_torch_complex_out_dtype_error(self, device, dtype): + + def dtype_name(dtype): + return 'Float' if dtype == torch.float32 else 'Double' + + def complex_dtype_name(dtype): + return 'ComplexFloat' if dtype == torch.complex64 else 'ComplexDouble' + + for op in (torch.complex, torch.polar): + a = torch.tensor([1, 2], device=device, dtype=dtype) + b = torch.tensor([3, 4], device=device, dtype=dtype) + out = torch.zeros(2, device=device, dtype=dtype) + expected_dtype = torch.complex64 if dtype == torch.float32 else torch.complex128 + error = "Expected object of scalar type {} but got scalar type " \ + "{} for argument 'out'".format( + complex_dtype_name(expected_dtype), dtype_name(dtype)) + with self.assertRaisesRegex(RuntimeError, error): + op(a, b, out=out) + + def test_cat_empty_legacy(self, device): + # FIXME: this is legacy behavior and should be removed + # when we support empty tensors with arbitrary sizes + dtype = torch.float32 + + x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device) + empty = torch.randn((0,), dtype=dtype, device=device) + + res1 = torch.cat([x, empty], dim=1) + res2 = torch.cat([empty, x], dim=1) + self.assertEqual(res1, res2) + + res1 = torch.cat([empty, empty], dim=1) + self.assertEqual(res1, empty) + + with self.assertRaisesRegex(RuntimeError, + 'non-empty list of Tensors'): + torch.cat([], dim=1) + + def test_cat_empty(self, device): + dtype = torch.float32 + + x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device) + empty = torch.randn((4, 0, 32, 32), dtype=dtype, device=device) + + res1 = torch.cat([x, empty], dim=1) + res2 = torch.cat([empty, x], dim=1) + self.assertEqual(res1, res2) + + res1 = torch.cat([empty, empty], dim=1) + self.assertEqual(res1, empty) + + # check non-legacy-behavior (sizes don't match) + empty = torch.randn((4, 0, 31, 32), dtype=dtype, device=device) + self.assertRaises(RuntimeError, lambda: torch.cat([x, empty], dim=1)) + self.assertRaises(RuntimeError, lambda: torch.cat([empty, x], dim=1)) + + # check non-legacy-behavior (dimensions don't match) + empty = torch.randn((4, 0), dtype=dtype, device=device) + self.assertRaises(RuntimeError, lambda: torch.cat([x, empty], dim=1)) + self.assertRaises(RuntimeError, lambda: torch.cat([empty, x], dim=1)) + + def test_cat_out(self, device): + x = torch.zeros((0), device=device) + y = torch.randn((4, 6), device=device) + + with self.assertRaisesRegex( + RuntimeError, r"unsupported operation:.* input tensor 0"): + torch.cat([x, y], dim=0, out=x) + + with self.assertRaisesRegex( + RuntimeError, r"unsupported operation:.* input tensor 1"): + torch.cat([x, y], dim=0, out=y) + + z = torch.zeros((4, 6), device=device) + with self.assertRaisesRegex( + RuntimeError, r"unsupported operation:.* input tensor 1"): + torch.cat([y, z], out=z[:2, :]) + + w = y.view(-1).clone() + a = torch.cat([w[:2], w[4:6]]) + b = torch.cat([w[:2], w[4:6]], out=w[6:10]) + self.assertEqual(a, b) + self.assertEqual(w[:6], y.view(-1)[:6]) + + # Case: + # Reference: https://github.com/pytorch/pytorch/issues/49878 + for dim in [0, 1]: + x = torch.zeros((10, 5, 2), device=device) + + random_length = random.randint(1, 4) + y = x.narrow(dim, 0, x.shape[dim] - random_length) + val = torch.full_like(y[0], 3., device=device) + + if dim == 0: + self.assertTrue(y.is_contiguous()) + else: + self.assertFalse(y.is_contiguous()) + + torch.cat((val[None],) * y.shape[0], dim=0, out=y) + + expected_y = torch.cat((val[None],) * y.shape[0], dim=0) + expected_x = torch.zeros((10, 5, 2), device=device) + if dim == 0: + expected_x[:x.shape[dim] - random_length, :, :] = expected_y + elif dim == 1: + expected_x[:, :x.shape[dim] - random_length, :] = expected_y + + self.assertEqual(y, expected_y) + self.assertEqual(x, expected_x) + + def test_cat_out_channels_last(self, device): + x = torch.randn((4, 3, 8, 8)) + y = torch.randn(x.shape) + res1 = torch.cat((x, y)) + z = res1.clone().contiguous(memory_format=torch.channels_last) + res2 = torch.cat((x, y), out=z) + self.assertEqual(res1, res2) + + @onlyCPU + def test_cat_in_channels_last(self, device): + for dim in range(4): + x = torch.randn((4, 15, 8, 8), device=device) + y = torch.randn(x.shape, device=device) + res1 = torch.cat((x, y), dim=dim) + x = x.clone().contiguous(memory_format=torch.channels_last) + y = y.clone().contiguous(memory_format=torch.channels_last) + res2 = torch.cat((x, y), dim=dim) + self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last)) + self.assertEqual(res1, res2) + + # Size larger than grain size. + x = torch.randn((4, 15, 256, 256), device=device) + y = torch.randn(x.shape, device=device) + res1 = torch.cat((x, y), dim=dim) + x = x.clone().contiguous(memory_format=torch.channels_last) + y = y.clone().contiguous(memory_format=torch.channels_last) + res2 = torch.cat((x, y), dim=dim) + self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last)) + self.assertEqual(res1, res2) + + @onlyCUDA + def test_cat_preserve_channels_last(self, device): + x = torch.randn((4, 3, 8, 8), device=device) + y = torch.randn(x.shape, device=device) + res1 = torch.cat((x, y)) + res2 = torch.cat((x.contiguous(memory_format=torch.channels_last), y.contiguous(memory_format=torch.channels_last))) + self.assertEqual(res1, res2) + self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last)) + + @onlyCUDA + @deviceCountAtLeast(2) + def test_cat_different_devices(self, devices): + cuda0 = torch.randn((3, 3), device=devices[0]) + cuda1 = torch.randn((3, 3), device=devices[1]) + with self.assertRaisesRegex(RuntimeError, + "input tensors must be on the same device"): + torch.cat((cuda0, cuda1)) + cpu = torch.randn(3, 3) + with self.assertRaisesRegex(RuntimeError, + "input tensors must be on the same device"): + torch.cat((cuda0, cpu)) + with self.assertRaisesRegex(RuntimeError, + "input tensors must be on the same device"): + torch.cat((cpu, cuda0)) + + # TODO: reconcile with other cat tests + # TODO: Compare with a NumPy reference instead of CPU + @onlyCUDA + def test_cat(self, device): + SIZE = 10 + for dim in range(-3, 3): + pos_dim = dim if dim >= 0 else 3 + dim + x = torch.rand(13, SIZE, SIZE, device=device).transpose(0, pos_dim) + y = torch.rand(17, SIZE, SIZE, device=device).transpose(0, pos_dim) + z = torch.rand(19, SIZE, SIZE, device=device).transpose(0, pos_dim) + + res1 = torch.cat((x, y, z), dim) + self.assertEqual(res1.narrow(pos_dim, 0, 13), x, atol=0, rtol=0) + self.assertEqual(res1.narrow(pos_dim, 13, 17), y, atol=0, rtol=0) + self.assertEqual(res1.narrow(pos_dim, 30, 19), z, atol=0, rtol=0) + + x = torch.randn(20, SIZE, SIZE, device=device) + self.assertEqual(torch.cat(torch.split(x, 7)), x) + self.assertEqual(torch.cat(torch.chunk(x, 7)), x) + + y = torch.randn(1, SIZE, SIZE, device=device) + z = torch.cat([x, y]) + self.assertEqual(z.size(), (21, SIZE, SIZE)) + + # TODO: update this test to compare against NumPy instead of CPU + @onlyCUDA + @dtypesIfCUDA(torch.half, torch.float, torch.double) + @dtypes(torch.float, torch.double) + def test_device_rounding(self, device, dtype): + # test half-to-even + a = [-5.8, -3.5, -2.3, -1.5, -0.5, 0.5, 1.5, 2.3, 3.5, 5.8] + res = [-6., -4., -2., -2., 0., 0., 2., 2., 4., 6.] + + a_tensor = torch.tensor(a, device=device).round() + res_tensor = torch.tensor(res, device='cpu') + self.assertEqual(a_tensor, res_tensor) + + # Note: This test failed on XLA since its test cases are created by empty_strided which + # doesn't support overlapping sizes/strides in XLA impl + @onlyOnCPUAndCUDA + def test_like_fn_stride_proparation_vs_tensoriterator_unary_op(self, device): + # Test like functions against tensoriterator based unary operator (exp) to + # make sure the returned tensor from like function follows the same stride propergation + # rule as what tensoriterator does for unary operator. The like function's output strides + # is computed on CPU side always, no need to test GPU here. + + def compare_helper_(like_fn, t): + te = torch.exp(t) + tl = like_fn(t) + self.assertEqual(te.stride(), tl.stride()) + self.assertEqual(te.size(), tl.size()) + + like_fns = [ + lambda t, **kwargs: torch.zeros_like(t, **kwargs), + lambda t, **kwargs: torch.ones_like(t, **kwargs), + lambda t, **kwargs: torch.randint_like(t, 10, 100, **kwargs), + lambda t, **kwargs: torch.randint_like(t, 100, **kwargs), + lambda t, **kwargs: torch.randn_like(t, **kwargs), + lambda t, **kwargs: torch.rand_like(t, **kwargs), + lambda t, **kwargs: torch.full_like(t, 7, **kwargs), + lambda t, **kwargs: torch.empty_like(t, **kwargs)] + + # dense non-overlapping tensor, + # non-dense non-overlapping sliced tensor + # non-dense non-overlapping gapped tensor + # non-dense non-overlapping 0 strided tensor + # non-dense overlapping general tensor + # non-dense overlapping sliced tensor + # non-dense overlapping gapped tensor + # non-dense overlapping 0 strided tensor + # non-dense overlapping equal strides + tset = ( + torch.randn(4, 3, 2, device=device), + torch.randn(4, 3, 2, device=device)[:, :, ::2], + torch.empty_strided((4, 3, 2), (10, 3, 1), device=device).fill_(1.0), + torch.empty_strided((4, 3, 2), (10, 0, 3), device=device).fill_(1.0), + torch.empty_strided((4, 3, 2), (10, 1, 2), device=device).fill_(1.0), + torch.empty_strided((4, 3, 2), (4, 2, 1), device=device)[:, :, ::2].fill_(1.0), + torch.empty_strided((4, 3, 2), (10, 1, 1), device=device).fill_(1.0), + torch.empty_strided((4, 1, 1, 2), (10, 0, 0, 2), device=device).fill_(1.0), + torch.empty_strided((4, 2, 3), (10, 3, 3), device=device).fill_(1.0)) + + for like_fn in like_fns: + for t in tset: + for p in permutations(range(t.dim())): + tp = t.permute(p) + compare_helper_(like_fn, tp) + + def _test_special_stacks(self, dim, at_least_dim, torch_fn, np_fn, device, dtype): + # Test error for non-tuple argument + t = torch.randn(10) + with self.assertRaisesRegex(TypeError, "must be tuple of Tensors, not Tensor"): + torch_fn(t) + # Test error for a single array + with self.assertRaisesRegex(TypeError, "must be tuple of Tensors, not Tensor"): + torch_fn((t)) + + # Test 0-D + num_tensors = random.randint(1, 5) + input_t = [torch.tensor(random.uniform(0, 10), device=device, dtype=dtype) for i in range(num_tensors)] + actual = torch_fn(input_t) + expected = np_fn([input.cpu().numpy() for input in input_t]) + self.assertEqual(actual, expected) + + for ndims in range(1, 5): + base_shape = list(_rand_shape(ndims, min_size=1, max_size=5)) + for i in range(ndims): + shape = list(base_shape) + num_tensors = random.randint(1, 5) + torch_input = [] + # Create tensors with shape being different along one axis only + for param in range(num_tensors): + shape[i] = random.randint(1, 5) + torch_input.append(_generate_input(tuple(shape), dtype, device, with_extremal=False)) + + # Determine if input tensors have valid dimensions. + valid_dim = True + for k in range(len(torch_input) - 1): + for tdim in range(ndims): + # Test whether all tensors have the same shape except in concatenating dimension + # Unless the number of dimensions is less than the corresponding at_least function dimension + # Since the original concatenating dimension would shift after applying at_least and would no + # longer be the concatenating dimension + if (ndims < at_least_dim or tdim != dim) and torch_input[k].size()[tdim] != torch_input[k + 1].size()[tdim]: + valid_dim = False + + # Special case for hstack is needed since hstack works differently when ndims is 1 + if valid_dim or (torch_fn is torch.hstack and ndims == 1): + # Valid dimensions, test against numpy + np_input = [input.cpu().numpy() for input in torch_input] + actual = torch_fn(torch_input) + expected = np_fn(np_input) + self.assertEqual(actual, expected) + else: + # Invalid dimensions, test for error + with self.assertRaisesRegex(RuntimeError, "Sizes of tensors must match except in dimension"): + torch_fn(torch_input) + with self.assertRaises(ValueError): + np_input = [input.cpu().numpy() for input in torch_input] + np_fn(np_input) + + @onlyOnCPUAndCUDA + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + + torch.testing.get_all_complex_dtypes())) + def test_hstack_column_stack(self, device, dtype): + ops = ((torch.hstack, np.hstack), (torch.column_stack, np.column_stack)) + for torch_op, np_op in ops: + self._test_special_stacks(1, 1, torch_op, np_op, device, dtype) + + # Test torch.column_stack with combinations of 1D and 2D tensors input + one_dim_tensor = torch.arange(0, 10).to(dtype=dtype, device=device) + two_dim_tensor = torch.arange(0, 100).to(dtype=dtype, device=device).reshape(10, 10) + inputs = two_dim_tensor, one_dim_tensor, two_dim_tensor, one_dim_tensor + torch_result = torch.column_stack(inputs) + + np_inputs = [input.cpu().numpy() for input in inputs] + np_result = np.column_stack(np_inputs) + + self.assertEqual(np_result, + torch_result) + + @onlyOnCPUAndCUDA + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + + torch.testing.get_all_complex_dtypes())) + def test_vstack_row_stack(self, device, dtype): + ops = ((torch.vstack, np.vstack), (torch.row_stack, np.row_stack)) + for torch_op, np_op in ops: + self._test_special_stacks(0, 2, torch_op, np_op, device, dtype) + for i in range(5): + # Test dimension change for 1D tensor of size (N) and 2D tensor of size (1, N) + n = random.randint(1, 10) + input_a = _generate_input((n,), dtype, device, with_extremal=False) + input_b = _generate_input((1, n), dtype, device, with_extremal=False) + torch_input = [input_a, input_b] + np_input = [input.cpu().numpy() for input in torch_input] + actual = torch_op(torch_input) + expected = np_op(np_input) + self.assertEqual(actual, expected) + + @onlyOnCPUAndCUDA + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + + torch.testing.get_all_complex_dtypes())) + def test_dstack(self, device, dtype): + self._test_special_stacks(2, 3, torch.dstack, np.dstack, device, dtype) + for i in range(5): + # Test dimension change for 1D tensor of size (N), 2D tensor of size (1, N), and 3D tensor of size (1, N, 1) + n = random.randint(1, 10) + input_a = _generate_input((n,), dtype, device, with_extremal=False) + input_b = _generate_input((1, n), dtype, device, with_extremal=False) + input_c = _generate_input((1, n, 1), dtype, device, with_extremal=False) + torch_input = [input_a, input_b, input_c] + np_input = [input.cpu().numpy() for input in torch_input] + actual = torch.dstack(torch_input) + expected = np.dstack(np_input) + self.assertEqual(actual, expected) + + # Test dimension change for 2D tensor of size (M, N) and 3D tensor of size (M, N, 1) + m = random.randint(1, 10) + n = random.randint(1, 10) + input_a = _generate_input((m, n), dtype, device, with_extremal=False) + input_b = _generate_input((m, n, 1), dtype, device, with_extremal=False) + torch_input = [input_a, input_b] + np_input = [input.cpu().numpy() for input in torch_input] + actual = torch.dstack(torch_input) + expected = np.dstack(np_input) + self.assertEqual(actual, expected) + + @dtypes(torch.int32, torch.int64) + def test_large_linspace(self, device, dtype): + start = torch.iinfo(dtype).min + end = torch.iinfo(dtype).max & ~0xfff + steps = 15 + x = torch.linspace(start, end, steps, dtype=dtype, device=device) + self.assertGreater(x[1] - x[0], (end - start) / steps) + + @dtypes(torch.float32, torch.float64) + def test_unpack_double(self, device, dtype): + # Reference: https://github.com/pytorch/pytorch/issues/33111 + vals = (2 ** 24 + 1, 2 ** 53 + 1, + np.iinfo(np.int64).max, np.iinfo(np.uint64).max, np.iinfo(np.uint64).max + 1, + -1e500, 1e500) + for val in vals: + t = torch.tensor(val, dtype=dtype, device=device) + a = np.array(val, dtype=torch_to_numpy_dtype_dict[dtype]) + self.assertEqual(t, torch.from_numpy(a)) + + def _float_to_int_conversion_helper(self, vals, device, dtype): + a = np.array(vals, dtype=np.float32).astype(torch_to_numpy_dtype_dict[dtype]) + t = torch.tensor(vals, device=device, dtype=torch.float).to(dtype) + self.assertEqual(torch.from_numpy(a), t.cpu()) + + # Checks that float->integer casts don't produce undefined behavior errors. + # Note: In C++, casting from a floating value to an integral dtype + # is undefined if the floating point value is not within the integral + # dtype's dynamic range. This can (and should) cause undefined behavior + # errors with UBSAN. These casts are deliberate in PyTorch, however, and + # NumPy has the same behavior. + @onlyOnCPUAndCUDA + @unittest.skipIf(IS_MACOS, "Test is broken on MacOS, see https://github.com/pytorch/pytorch/issues/38752") + @unittest.skipIf(IS_PPC, "Test is borken on PowerPC, see https://github.com/pytorch/pytorch/issues/39671") + @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) + def test_float_to_int_conversion_finite(self, device, dtype): + min = torch.finfo(torch.float).min + max = torch.finfo(torch.float).max + + # Note: CUDA max float -> integer conversion is divergent on some dtypes + vals = (min, -2, -1.5, -.5, 0, .5, 1.5, 2, max) + if self.device_type == 'cuda': + if torch.version.hip: + # HIP min float -> int64 conversion is divergent + vals = (-2, -1.5, -.5, 0, .5, 1.5, 2) + else: + vals = (min, -2, -1.5, -.5, 0, .5, 1.5, 2) + + self._float_to_int_conversion_helper(vals, device, dtype) + + # Note: CUDA will fail this test on most dtypes, often dramatically. + @onlyCPU + @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) + def test_float_to_int_conversion_nonfinite(self, device, dtype): + vals = (float('-inf'), float('inf'), float('nan')) + + self._float_to_int_conversion_helper(vals, device, dtype) + + # TODO: re-enable this test + @unittest.skipIf(True, "real and imag not implemented for complex") + @onlyOnCPUAndCUDA + def test_complex_type_conversions(self, device): + dtypes = [torch.float, torch.complex64, torch.complex128] + for from_type in dtypes: + for to_type in dtypes: + from_tensor = torch.randn(4, dtype=from_type, device=device) + to_tensor = from_tensor.to(to_type) + if from_type.is_complex and not to_type.is_complex: + self.assertEqual(torch.real(from_tensor), to_tensor, exact_dtype=False) + elif not from_type.is_complex and to_type.is_complex: + self.assertEqual(from_tensor, torch.real(to_tensor), exact_dtype=False) + self.assertEqual(torch.zeros_like(torch.imag(to_tensor)), torch.imag(to_tensor), exact_dtype=False) + else: + self.assertEqual(from_tensor, to_tensor, exact_dtype=False) + + @slowTest + @onlyCPU + def test_cat_big(self, device): + SIZE1 = 6500 + SIZE2 = 4500 + concat_list = [] + concat_list.append(torch.ones((SIZE1, 1024 * 512), dtype=torch.uint8, device=device)) + concat_list.append(torch.ones((SIZE2, 1024 * 512), dtype=torch.uint8, device=device)) + result = torch.cat(concat_list) + self.assertEqual(result.size(0), SIZE1 + SIZE2) + + @onlyCPU + def test_cat_bad_input_sizes(self, device): + x = torch.randn(2, 1, device=device) + y = torch.randn(2, 1, 1, device=device) + z = torch.randn(2, 1, 1, device=device) + self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z])) + + x = torch.randn(2, 1, 2, device=device) + y = torch.randn(2, 1, 1, device=device) + z = torch.randn(2, 2, 1, device=device) + self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z], dim=1)) + + @onlyCPU + @dtypes(torch.half, torch.double, torch.int) + def test_cat2(self, device, dtype): + SIZE = 10 + for dim in range(-3, 3): + pos_dim = dim if dim >= 0 else 3 + dim + x = torch.randint(low=-100, high=100, size=(13, SIZE, SIZE), device=device).to(dtype).transpose(0, pos_dim) + y = torch.randint(low=-100, high=100, size=(17, SIZE, SIZE), device=device).to(dtype).transpose(0, pos_dim) + z = torch.randint(low=-100, high=100, size=(19, SIZE, SIZE), device=device).to(dtype).transpose(0, pos_dim) + + res1 = torch.cat((x, y, z), dim) + self.assertEqual(res1.narrow(pos_dim, 0, 13), x, atol=0, rtol=0) + self.assertEqual(res1.narrow(pos_dim, 13, 17), y, atol=0, rtol=0) + self.assertEqual(res1.narrow(pos_dim, 30, 19), z, atol=0, rtol=0) + + x = torch.randint(low=-100, high=100, size=(20, SIZE, SIZE), device=device).to(dtype) + self.assertEqual(torch.cat(torch.split(x, 7)), x) + self.assertEqual(torch.cat(torch.chunk(x, 7)), x) + + y = torch.randint(low=-100, high=100, size=(1, SIZE, SIZE), device=device).to(dtype) + z = torch.cat([x, y]) + self.assertEqual(z.size(), (21, SIZE, SIZE)) + + self.assertRaises(RuntimeError, lambda: torch.cat([])) + self.assertRaisesRegex(TypeError, 'got None', lambda: torch.cat([x, None])) + + @onlyCPU + def test_cat_scalars(self, device): + x = torch.tensor(0, device=device) + y = torch.tensor(1, device=device) + with self.assertRaisesRegex(RuntimeError, 'zero-dimensional.*cannot be concatenated'): + torch.cat([x, y]) + + def test_zeros_dtype_out_match(self, device): + d = torch.tensor((2, 3), device=device, dtype=torch.double) + self.assertRaises(RuntimeError, lambda: torch.zeros((2, 3), device=device, dtype=torch.float32, out=d)) + + # TODO: update to work on CUDA, too + @onlyCPU + def test_trilu_indices(self, device): + for test_args in tri_tests_args: + _compare_trilu_indices(self, *test_args) + run_additional_tri_tests(self, 'cpu') + + # test default options + x = torch.ones( + 3, 3, dtype=torch.long, device='cpu', layout=torch.strided) + self.assertEqual( + x.tril(0).nonzero().transpose(0, 1), torch.tril_indices(3, 3)) + self.assertEqual( + x.triu(0).nonzero().transpose(0, 1), torch.triu_indices(3, 3)) + + # test stride 0 cases + x = torch.ones( + 3, 1, 3, 3, dtype=torch.long, device='cpu', layout=torch.strided) + output = x.triu(2).expand(3, 3, 3, 3) + b = x.clone().expand(3, 3, 3, 3) + self.assertEqual(b.triu(2), output) + self.assertRaises(RuntimeError, lambda: b.triu_(2)) + + # TODO: update to work on CUDA, too + @onlyCPU + def test_stack(self, device): + for dtype in (torch.half, torch.double, torch.int): + x = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) + y = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) + z = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) + for dim in range(4): + res = torch.stack((x, y, z), dim) + res_neg = torch.stack((x, y, z), dim - 4) + expected_size = x.size()[:dim] + (3,) + x.size()[dim:] + self.assertEqual(res, res_neg) + self.assertEqual(res.size(), expected_size) + self.assertEqual(res.select(dim, 0), x, atol=0, rtol=0) + self.assertEqual(res.select(dim, 1), y, atol=0, rtol=0) + self.assertEqual(res.select(dim, 2), z, atol=0, rtol=0) + + # TODO: update to work on CUDA, too + @onlyCPU + def test_stack_out(self, device): + for dtype in (torch.half, torch.double, torch.int): + x = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) + y = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) + z = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) + for dim in range(4): + expected_size = x.size()[:dim] + (3,) + x.size()[dim:] + res_out = x.new(expected_size) + res_neg_out = x.new(expected_size) + res_out_dp = res_out.data_ptr() + res_out_neg_dp = res_neg_out.data_ptr() + torch.stack((x, y, z), dim, out=res_out) + torch.stack((x, y, z), dim - 4, out=res_neg_out) + self.assertEqual(res_out, res_neg_out) + self.assertEqual(res_out.size(), expected_size) + self.assertEqual(res_out_dp, res_out.data_ptr()) + self.assertEqual(res_out_neg_dp, res_neg_out.data_ptr()) + self.assertEqual(res_out.select(dim, 0), x, atol=0, rtol=0) + self.assertEqual(res_out.select(dim, 1), y, atol=0, rtol=0) + self.assertEqual(res_out.select(dim, 2), z, atol=0, rtol=0) + + def test_repeat_interleave(self, device): + x = torch.tensor([0, 1, 2, 3], device=device) + expected = torch.tensor([1, 2, 2, 3, 3, 3], device=device) + self.assertEqual(torch.repeat_interleave(x), expected) + + with self.assertRaises(RuntimeError): + torch.repeat_interleave(torch.arange(4, device=device).reshape(2, 2)) + + with self.assertRaises(RuntimeError): + torch.repeat_interleave(torch.arange(4.0, device=device)) + + with self.assertRaises(RuntimeError): + torch.repeat_interleave(torch.tensor([1, 2, -1, 3, 4], device=device)) + + y = torch.tensor([[1, 2], [3, 4]], device=device) + + y1_v1 = torch.repeat_interleave(y, 2) + y1_v2 = torch.repeat_interleave(y, torch.tensor(2, device=device)) + y1_v3 = torch.repeat_interleave(y, torch.tensor([2], device=device)) + y1_expect = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4], device=device) + self.assertEqual(y1_v1, y1_expect) + self.assertEqual(y1_v2, y1_expect) + self.assertEqual(y1_v3, y1_expect) + + y2 = torch.repeat_interleave(y, 3, dim=1) + y2_expect = torch.tensor([[1, 1, 1, 2, 2, 2], + [3, 3, 3, 4, 4, 4]], device=device) + self.assertEqual(y2, y2_expect) + + y3 = torch.repeat_interleave(y, torch.tensor([1, 2], device=device), dim=0) + y3_expect = torch.tensor([[1, 2], + [3, 4], + [3, 4]], device=device) + self.assertEqual(y3, y3_expect) + + with self.assertRaises(RuntimeError): + torch.repeat_interleave(y, torch.tensor([1, 2, 3], device=device), dim=0) + + with self.assertRaises(RuntimeError): + torch.repeat_interleave(y, torch.arange(9, device=device).reshape(3, 3), dim=0) + + # test zero sized dimension + x = torch.zeros((5, 0), device=device) + y = torch.repeat_interleave(x, repeats=3, dim=1) + self.assertEqual(y, x.new_zeros(5, 0, device=device)) + + x = torch.tensor([], dtype=torch.int64, device=device) + y = torch.repeat_interleave(x, x) + self.assertEqual(y, x) + + # TODO: udpate to work on CUDA, too + @onlyCPU + def test_new_methods_requires_grad(self, device): + size = (10,) + test_cases = [ + # method name, args + ('new_full', [size, 1]), + ('new_empty', [size]), + ('new_zeros', [size]), + ] + for method_name, args in test_cases: + x = torch.randn(size) + for requires_grad in [True, False]: + x_new = x.__getattribute__(method_name)(*args, requires_grad=requires_grad) + self.assertEqual(x_new.requires_grad, requires_grad) + x = torch.randint(10, size) + with self.assertRaisesRegex( + RuntimeError, + r'Only Tensors of floating point and complex dtype can require gradients'): + x_new = x.__getattribute__(method_name)(*args, requires_grad=True) + + # TODO: update to work on CUDA, too? + @onlyCPU + def test_tensor_from_sequence(self, device): + class MockSequence(object): + def __init__(self, lst): + self.lst = lst + + def __len__(self): + return len(self.lst) + + def __getitem__(self, item): + raise TypeError + + class GoodMockSequence(MockSequence): + def __getitem__(self, item): + return self.lst[item] + + bad_mock_seq = MockSequence([1.0, 2.0, 3.0]) + good_mock_seq = GoodMockSequence([1.0, 2.0, 3.0]) + with self.assertRaisesRegex(ValueError, 'could not determine the shape'): + torch.Tensor(bad_mock_seq) + self.assertEqual(torch.Tensor([1.0, 2.0, 3.0]), torch.Tensor(good_mock_seq)) + + # TODO: update to work on CUDA, too? + @onlyCPU + def test_simple_scalar_cast(self, device): + ok = [torch.Tensor([1.5]), torch.zeros(1, 1, 1, 1)] + ok_values = [1.5, 0] + + not_ok = map(torch.Tensor, [[], [1, 2], [[1, 2], [3, 4]]]) + + for tensor, value in zip(ok, ok_values): + self.assertEqual(int(tensor), int(value)) + self.assertEqual(float(tensor), float(value)) + self.assertEqual(complex(tensor), complex(value)) + + self.assertEqual(complex(torch.tensor(1.5j)), 1.5j) + + for tensor in not_ok: + self.assertRaises(ValueError, lambda: int(tensor)) + self.assertRaises(ValueError, lambda: float(tensor)) + self.assertRaises(ValueError, lambda: complex(tensor)) + + self.assertRaises(RuntimeError, lambda: float(torch.tensor(1.5j))) + self.assertRaises(RuntimeError, lambda: int(torch.tensor(1.5j))) + + # TODO: update to work on CUDA, too? + @onlyCPU + def test_offset_scalar_cast(self, device): + x = torch.Tensor([1, 2, 3]) + y = x[2:] + self.assertEqual(int(y), 3) + + def test_meshgrid(self, device): + a = torch.tensor(1, device=device) + b = torch.tensor([1, 2, 3], device=device) + c = torch.tensor([1, 2], device=device) + grid_a, grid_b, grid_c = torch.meshgrid([a, b, c]) + self.assertEqual(grid_a.shape, torch.Size([1, 3, 2])) + self.assertEqual(grid_b.shape, torch.Size([1, 3, 2])) + self.assertEqual(grid_c.shape, torch.Size([1, 3, 2])) + grid_a2, grid_b2, grid_c2 = torch.meshgrid(a, b, c) + self.assertEqual(grid_a2.shape, torch.Size([1, 3, 2])) + self.assertEqual(grid_b2.shape, torch.Size([1, 3, 2])) + self.assertEqual(grid_c2.shape, torch.Size([1, 3, 2])) + expected_grid_a = torch.ones(1, 3, 2, dtype=torch.int64, device=device) + expected_grid_b = torch.tensor([[[1, 1], + [2, 2], + [3, 3]]], device=device) + expected_grid_c = torch.tensor([[[1, 2], + [1, 2], + [1, 2]]], device=device) + self.assertTrue(grid_a.equal(expected_grid_a)) + self.assertTrue(grid_b.equal(expected_grid_b)) + self.assertTrue(grid_c.equal(expected_grid_c)) + self.assertTrue(grid_a2.equal(expected_grid_a)) + self.assertTrue(grid_b2.equal(expected_grid_b)) + self.assertTrue(grid_c2.equal(expected_grid_c)) + + def test_cartesian_prod(self, device): + a = torch.tensor([1], device=device) + b = torch.tensor([1, 2, 3], device=device) + c = torch.tensor([1, 2], device=device) + prod = torch.cartesian_prod(a, b, c) + expected = torch.tensor(list(product([a], b, c)), device=device) + self.assertEqual(expected, prod) + + # test 0 size input + d = torch.empty(0, dtype=b.dtype, device=device) + prod = torch.cartesian_prod(a, b, c, d) + expected = torch.empty(0, 4, dtype=b.dtype, device=device) + self.assertEqual(expected, prod) + + # test single input + prod = torch.cartesian_prod(b) + self.assertEqual(b, prod) + + def test_combinations(self, device): + a = torch.tensor([1, 2, 3], device=device) + + c = torch.combinations(a, r=1) + expected = torch.tensor(list(combinations(a, r=1)), device=device) + self.assertEqual(c, expected) + + c = torch.combinations(a, r=1, with_replacement=True) + expected = torch.tensor(list(combinations_with_replacement(a, r=1)), device=device) + self.assertEqual(c, expected) + + c = torch.combinations(a) + expected = torch.tensor(list(combinations(a, r=2)), device=device) + self.assertEqual(c, expected) + + c = torch.combinations(a, with_replacement=True) + expected = torch.tensor(list(combinations_with_replacement(a, r=2)), device=device) + self.assertEqual(c, expected) + + c = torch.combinations(a, r=3) + expected = torch.tensor(list(combinations(a, r=3)), device=device) + self.assertEqual(c, expected) + + c = torch.combinations(a, r=4) + expected = torch.empty(0, 4, dtype=a.dtype, device=device) + self.assertEqual(c, expected) + + c = torch.combinations(a, r=5) + expected = torch.empty(0, 5, dtype=a.dtype, device=device) + self.assertEqual(c, expected) + + # test empty imput + a = torch.empty(0, device=device) + c1 = torch.combinations(a) + c2 = torch.combinations(a, with_replacement=True) + expected = torch.empty(0, 2, dtype=a.dtype, device=device) + self.assertEqual(c1, expected) + self.assertEqual(c2, expected) + + def test_linlogspace_mem_overlap(self, device): + x = torch.rand(1, device=device).expand(10) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + torch.linspace(1, 10, 10, out=x) + + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + torch.logspace(1, 10, 10, out=x) + + def test_ctor_with_numpy_array(self, device): + correct_dtypes = [ + np.double, + np.float, + np.float16, + np.int64, + np.int32, + np.int16, + np.int8, + np.uint8, + np.bool, + ] + + incorrect_byteorder = '>' if sys.byteorder == 'little' else '<' + incorrect_dtypes = [incorrect_byteorder + t for t in ['d', 'f']] + + for dtype in correct_dtypes: + array = np.array([1, 2, 3, 4], dtype=dtype) + + # Upcast + tensor = torch.DoubleTensor(array).to(device) + for i in range(len(array)): + self.assertEqual(tensor[i], array[i]) + + # Downcast (sometimes) + tensor = torch.FloatTensor(array).to(device) + for i in range(len(array)): + self.assertEqual(tensor[i], array[i]) + + tensor = torch.HalfTensor(array).to(device) + for i in range(len(array)): + self.assertEqual(tensor[i], array[i]) + + @dtypes(torch.float, torch.double, torch.int8, torch.int16, torch.int32, torch.int64) + def test_random(self, device, dtype): + # This test is flaky with p<=(2/(ub-lb))^200=6e-36 + t = torch.empty(200, dtype=dtype, device=device) + lb = 1 + ub = 4 + + t.fill_(-1) + t.random_(lb, ub) + self.assertEqual(t.min(), lb) + self.assertEqual(t.max(), ub - 1) + + t.fill_(-1) + t.random_(ub) + self.assertEqual(t.min(), 0) + self.assertEqual(t.max(), ub - 1) + + def test_random_bool(self, device): + size = 2000 + t = torch.empty(size, dtype=torch.bool, device=device) + + t.fill_(False) + t.random_() + self.assertEqual(t.min(), False) + self.assertEqual(t.max(), True) + self.assertTrue(0.4 < (t.eq(True)).to(torch.int).sum().item() / size < 0.6) + + t.fill_(True) + t.random_() + self.assertEqual(t.min(), False) + self.assertEqual(t.max(), True) + self.assertTrue(0.4 < (t.eq(True)).to(torch.int).sum().item() / size < 0.6) + + def test_random_from_to_bool(self, device): + size = 2000 + + int64_min_val = torch.iinfo(torch.int64).min + int64_max_val = torch.iinfo(torch.int64).max + + min_val = 0 + max_val = 1 + + froms = [int64_min_val, -42, min_val - 1, min_val, max_val, max_val + 1, 42] + tos = [-42, min_val - 1, min_val, max_val, max_val + 1, 42, int64_max_val] + + for from_ in froms: + for to_ in tos: + t = torch.empty(size, dtype=torch.bool, device=device) + if to_ > from_: + if not (min_val <= from_ <= max_val): + self.assertRaisesRegex( + RuntimeError, + "from is out of bounds", + lambda: t.random_(from_, to_) + ) + elif not (min_val <= (to_ - 1) <= max_val): + self.assertRaisesRegex( + RuntimeError, + "to - 1 is out of bounds", + lambda: t.random_(from_, to_) + ) + else: + t.random_(from_, to_) + range_ = to_ - from_ + delta = 1 + self.assertTrue(from_ <= t.to(torch.int).min() < (from_ + delta)) + self.assertTrue((to_ - delta) <= t.to(torch.int).max() < to_) + else: + self.assertRaisesRegex( + RuntimeError, + "random_ expects 'from' to be less than 'to', but got from=" + str(from_) + " >= to=" + str(to_), + lambda: t.random_(from_, to_) + ) + + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) + def test_random_full_range(self, device, dtype): + size = 2000 + alpha = 0.1 + + int64_min_val = torch.iinfo(torch.int64).min + int64_max_val = torch.iinfo(torch.int64).max + + if dtype == torch.double: + fp_limit = 2**53 + elif dtype == torch.float: + fp_limit = 2**24 + elif dtype == torch.half: + fp_limit = 2**11 + elif dtype == torch.bfloat16: + fp_limit = 2**8 + else: + fp_limit = 0 + + t = torch.empty(size, dtype=dtype, device=device) + + if dtype in [torch.float, torch.double, torch.half, torch.bfloat16]: + from_ = int(max(-fp_limit, int64_min_val)) + to_inc_ = int(min(fp_limit, int64_max_val)) + else: + from_ = int(max(torch.iinfo(dtype).min, int64_min_val)) + to_inc_ = int(min(torch.iinfo(dtype).max, int64_max_val)) + range_ = to_inc_ - from_ + 1 + + t.random_(from_, None) + delta = max(1, alpha * range_) + self.assertTrue(from_ <= t.to(torch.double).min() < (from_ + delta)) + self.assertTrue((to_inc_ - delta) < t.to(torch.double).max() <= to_inc_) + + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) + def test_random_from_to(self, device, dtype): + size = 2000 + alpha = 0.1 + + int64_min_val = torch.iinfo(torch.int64).min + int64_max_val = torch.iinfo(torch.int64).max + + if dtype in [torch.float, torch.double, torch.half]: + min_val = int(max(torch.finfo(dtype).min, int64_min_val)) + max_val = int(min(torch.finfo(dtype).max, int64_max_val)) + froms = [min_val, -42, 0, 42] + tos = [-42, 0, 42, max_val >> 1] + elif dtype == torch.bfloat16: + min_val = int64_min_val + max_val = int64_max_val + froms = [min_val, -42, 0, 42] + tos = [-42, 0, 42, max_val >> 1] + elif dtype == torch.uint8: + min_val = torch.iinfo(dtype).min + max_val = torch.iinfo(dtype).max + froms = [int64_min_val, -42, min_val - 1, min_val, 42, max_val, max_val + 1] + tos = [-42, min_val - 1, min_val, 42, max_val, max_val + 1, int64_max_val] + elif dtype == torch.int64: + min_val = int64_min_val + max_val = int64_max_val + froms = [min_val, -42, 0, 42] + tos = [-42, 0, 42, max_val] + else: + min_val = torch.iinfo(dtype).min + max_val = torch.iinfo(dtype).max + froms = [int64_min_val, min_val - 1, min_val, -42, 0, 42, max_val, max_val + 1] + tos = [min_val - 1, min_val, -42, 0, 42, max_val, max_val + 1, int64_max_val] + + if dtype == torch.double: + fp_limit = 2**53 + elif dtype == torch.float: + fp_limit = 2**24 + elif dtype == torch.half: + fp_limit = 2**11 + elif dtype == torch.bfloat16: + fp_limit = 2**8 + else: + fp_limit = 0 + + for from_ in froms: + for to_ in tos: + t = torch.empty(size, dtype=dtype, device=device) + if to_ > from_: + if not (min_val <= from_ <= max_val): + self.assertRaisesRegex( + RuntimeError, + "from is out of bounds", + lambda: t.random_(from_, to_) + ) + elif not (min_val <= (to_ - 1) <= max_val): + self.assertRaisesRegex( + RuntimeError, + "to - 1 is out of bounds", + lambda: t.random_(from_, to_) + ) + else: + if dtype.is_floating_point and ( + not (-fp_limit <= from_ <= fp_limit) or not (-fp_limit <= (to_ - 1) <= fp_limit)): + if not (-fp_limit <= from_ <= fp_limit): + self.assertWarnsRegex(UserWarning, "from is out of bounds", + lambda: t.random_(from_, to_)) + if not (-fp_limit <= (to_ - 1) <= fp_limit): + self.assertWarnsRegex(UserWarning, "to - 1 is out of bounds", + lambda: t.random_(from_, to_)) + else: + t.random_(from_, to_) + range_ = to_ - from_ + delta = max(1, alpha * range_) + if dtype == torch.bfloat16: + # Less strict checks because of rounding errors + # TODO investigate rounding errors + self.assertTrue(from_ <= t.to(torch.double).min() < (from_ + delta)) + self.assertTrue((to_ - delta) < t.to(torch.double).max() <= to_) + else: + self.assertTrue(from_ <= t.to(torch.double).min() < (from_ + delta)) + self.assertTrue((to_ - delta) <= t.to(torch.double).max() < to_) + else: + self.assertRaisesRegex( + RuntimeError, + "random_ expects 'from' to be less than 'to', but got from=" + str(from_) + " >= to=" + str(to_), + lambda: t.random_(from_, to_) + ) + + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) + def test_random_to(self, device, dtype): + size = 2000 + alpha = 0.1 + + int64_min_val = torch.iinfo(torch.int64).min + int64_max_val = torch.iinfo(torch.int64).max + + if dtype in [torch.float, torch.double, torch.half]: + min_val = int(max(torch.finfo(dtype).min, int64_min_val)) + max_val = int(min(torch.finfo(dtype).max, int64_max_val)) + tos = [-42, 0, 42, max_val >> 1] + elif dtype == torch.bfloat16: + min_val = int64_min_val + max_val = int64_max_val + tos = [-42, 0, 42, max_val >> 1] + elif dtype == torch.uint8: + min_val = torch.iinfo(dtype).min + max_val = torch.iinfo(dtype).max + tos = [-42, min_val - 1, min_val, 42, max_val, max_val + 1, int64_max_val] + elif dtype == torch.int64: + min_val = int64_min_val + max_val = int64_max_val + tos = [-42, 0, 42, max_val] + else: + min_val = torch.iinfo(dtype).min + max_val = torch.iinfo(dtype).max + tos = [min_val - 1, min_val, -42, 0, 42, max_val, max_val + 1, int64_max_val] + + from_ = 0 + for to_ in tos: + t = torch.empty(size, dtype=dtype, device=device) + if to_ > from_: + if not (min_val <= (to_ - 1) <= max_val): + self.assertRaisesRegex( + RuntimeError, + "to - 1 is out of bounds", + lambda: t.random_(from_, to_) + ) + else: + t.random_(to_) + range_ = to_ - from_ + delta = max(1, alpha * range_) + if dtype == torch.bfloat16: + # Less strict checks because of rounding errors + # TODO investigate rounding errors + self.assertTrue(from_ <= t.to(torch.double).min() < (from_ + delta)) + self.assertTrue((to_ - delta) < t.to(torch.double).max() <= to_) + else: + self.assertTrue(from_ <= t.to(torch.double).min() < (from_ + delta)) + self.assertTrue((to_ - delta) <= t.to(torch.double).max() < to_) + else: + self.assertRaisesRegex( + RuntimeError, + "random_ expects 'from' to be less than 'to', but got from=" + str(from_) + " >= to=" + str(to_), + lambda: t.random_(from_, to_) + ) + + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) + def test_random_default(self, device, dtype): + size = 2000 + alpha = 0.1 + + if dtype == torch.float: + to_inc = 1 << 24 + elif dtype == torch.double: + to_inc = 1 << 53 + elif dtype == torch.half: + to_inc = 1 << 11 + elif dtype == torch.bfloat16: + to_inc = 1 << 8 + else: + to_inc = torch.iinfo(dtype).max + + t = torch.empty(size, dtype=dtype, device=device) + t.random_() + self.assertTrue(0 <= t.to(torch.double).min() < alpha * to_inc) + self.assertTrue((to_inc - alpha * to_inc) < t.to(torch.double).max() <= to_inc) + # TODO: this test should be updated @onlyOnCPUAndCUDA def test_empty_full(self, device): @@ -54,9 +1656,7 @@ def test_tensor_device(self, devices): torch.ones((2, 3), dtype=torch.float32, device='cpu:0').device.type) self.assertEqual('cpu', torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cpu:0').device.type) - - if TEST_NUMPY: - self.assertEqual('cpu', torch.tensor(np.random.randn(2, 3), device='cpu').device.type) + self.assertEqual('cpu', torch.tensor(np.random.randn(2, 3), device='cpu').device.type) if device_type == 'cuda': self.assertEqual('cuda:0', str(torch.tensor(5).cuda(0).device)) self.assertEqual('cuda:0', str(torch.tensor(5).cuda('cuda:0').device)) @@ -67,8 +1667,7 @@ def test_tensor_device(self, devices): self.assertEqual('cuda:0', str(torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cuda:0').device)) - if TEST_NUMPY: - self.assertEqual('cuda:0', str(torch.tensor(np.random.randn(2, 3), device='cuda:0').device)) + self.assertEqual('cuda:0', str(torch.tensor(np.random.randn(2, 3), device='cuda:0').device)) for device in devices: with torch.cuda.device(device): @@ -92,9 +1691,8 @@ def test_tensor_device(self, devices): str(torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cuda:1').device)) - if TEST_NUMPY: - self.assertEqual('cuda:1', - str(torch.tensor(np.random.randn(2, 3), device='cuda:1').device)) + self.assertEqual('cuda:1', + str(torch.tensor(np.random.randn(2, 3), device='cuda:1').device)) # TODO: this test should be updated @onlyOnCPUAndCUDA @@ -265,13 +1863,12 @@ def test_tensor_factory(self, device): self.assertIs(torch.int, res1.dtype) # test copy with numpy - if TEST_NUMPY: - for dtype in [np.float64, np.int64, np.int8, np.uint8]: - a = np.array([5.]).astype(dtype) - res1 = torch.tensor(a) - self.assertEqual(5., res1[0].item()) - a[0] = 7. - self.assertEqual(5., res1[0].item()) + for dtype in [np.float64, np.int64, np.int8, np.uint8]: + a = np.array([5.]).astype(dtype) + res1 = torch.tensor(a) + self.assertEqual(5., res1[0].item()) + a[0] = 7. + self.assertEqual(5., res1[0].item()) # test boolean tensor a = torch.tensor([True, True, False, True, True], dtype=torch.bool) @@ -348,17 +1945,16 @@ def test_inference(default_dtype): self.assertIs(torch.int64, torch.tensor(((5, 3), (3, 5))).dtype) self.assertIs(default_complex_dtype, torch.tensor(((5, 3 + 2j), (3, 5 + 4j))).dtype) - if TEST_NUMPY: - self.assertIs(torch.float64, torch.tensor(np.array(())).dtype) - self.assertIs(torch.float64, torch.tensor(np.array(5.)).dtype) - if np.array(5).dtype == np.int64: # np long, which can be 4 bytes (e.g. on windows) - self.assertIs(torch.int64, torch.tensor(np.array(5)).dtype) - else: - self.assertIs(torch.int32, torch.tensor(np.array(5)).dtype) - self.assertIs(torch.uint8, torch.tensor(np.array(3, dtype=np.uint8)).dtype) - self.assertIs(default_dtype, torch.tensor(((7, np.array(5)), (np.array(9), 5.))).dtype) - self.assertIs(torch.float64, torch.tensor(((7, 5), (9, np.array(5.)))).dtype) - self.assertIs(torch.int64, torch.tensor(((5, np.array(3)), (np.array(3), 5))).dtype) + self.assertIs(torch.float64, torch.tensor(np.array(())).dtype) + self.assertIs(torch.float64, torch.tensor(np.array(5.)).dtype) + if np.array(5).dtype == np.int64: # np long, which can be 4 bytes (e.g. on windows) + self.assertIs(torch.int64, torch.tensor(np.array(5)).dtype) + else: + self.assertIs(torch.int32, torch.tensor(np.array(5)).dtype) + self.assertIs(torch.uint8, torch.tensor(np.array(3, dtype=np.uint8)).dtype) + self.assertIs(default_dtype, torch.tensor(((7, np.array(5)), (np.array(9), 5.))).dtype) + self.assertIs(torch.float64, torch.tensor(((7, 5), (9, np.array(5.)))).dtype) + self.assertIs(torch.int64, torch.tensor(((5, np.array(3)), (np.array(3), 5))).dtype) torch.set_default_dtype(saved_dtype) test_inference(torch.float64) @@ -386,13 +1982,12 @@ def test_new_tensor(self, device): self.assertIs(torch.int, res2.dtype) # test copy with numpy - if TEST_NUMPY: - a = np.array([5.]) - res1 = torch.tensor(a) - res1 = res1.new_tensor(a) - self.assertEqual(5., res1[0].item()) - a[0] = 7. - self.assertEqual(5., res1[0].item()) + a = np.array([5.]) + res1 = torch.tensor(a) + res1 = res1.new_tensor(a) + self.assertEqual(5., res1[0].item()) + a[0] = 7. + self.assertEqual(5., res1[0].item()) if torch.cuda.device_count() >= 2: expected = expected.cuda(1) @@ -453,29 +2048,28 @@ def test_as_tensor(self, device): self.assertIs(y_cuda, torch.as_tensor(y_cuda)) self.assertIs(y_cuda, torch.as_tensor(y_cuda, device='cuda')) - if TEST_NUMPY: - # doesn't copy - for dtype in [np.float64, np.int64, np.int8, np.uint8]: - n = np.random.rand(5, 6).astype(dtype) - n_astensor = torch.as_tensor(n) - self.assertEqual(torch.tensor(n), n_astensor) - n_astensor[0][0] = 25.7 - self.assertEqual(torch.tensor(n), n_astensor) - - # changing dtype causes copy - n = np.random.rand(5, 6).astype(np.float32) - n_astensor = torch.as_tensor(n, dtype=torch.float64) - self.assertEqual(torch.tensor(n, dtype=torch.float64), n_astensor) - n_astensor[0][1] = 250.8 - self.assertNotEqual(torch.tensor(n, dtype=torch.float64), n_astensor) - - # changing device causes copy - if torch.cuda.is_available(): - n = np.random.randn(5, 6) - n_astensor = torch.as_tensor(n, device='cuda') - self.assertEqual(torch.tensor(n, device='cuda'), n_astensor) - n_astensor[0][2] = 250.9 - self.assertNotEqual(torch.tensor(n, device='cuda'), n_astensor) + # doesn't copy + for dtype in [np.float64, np.int64, np.int8, np.uint8]: + n = np.random.rand(5, 6).astype(dtype) + n_astensor = torch.as_tensor(n) + self.assertEqual(torch.tensor(n), n_astensor) + n_astensor[0][0] = 25.7 + self.assertEqual(torch.tensor(n), n_astensor) + + # changing dtype causes copy + n = np.random.rand(5, 6).astype(np.float32) + n_astensor = torch.as_tensor(n, dtype=torch.float64) + self.assertEqual(torch.tensor(n, dtype=torch.float64), n_astensor) + n_astensor[0][1] = 250.8 + self.assertNotEqual(torch.tensor(n, dtype=torch.float64), n_astensor) + + # changing device causes copy + if torch.cuda.is_available(): + n = np.random.randn(5, 6) + n_astensor = torch.as_tensor(n, device='cuda') + self.assertEqual(torch.tensor(n, device='cuda'), n_astensor) + n_astensor[0][2] = 250.9 + self.assertNotEqual(torch.tensor(n, device='cuda'), n_astensor) # TODO: this test should be updated @suppress_warnings @@ -680,6 +2274,25 @@ def test_empty_strided(self, device): self.assertEqual(empty_strided.shape, as_strided.shape) self.assertEqual(empty_strided.stride(), as_strided.stride()) + def test_new_empty_strided(self, device): + def _test(sizes, strides, dtype): + x = torch.zeros(5, 5, dtype=dtype, device=device) + result = x.new_empty_strided(sizes, strides) + expected = torch.empty_strided(sizes, strides, dtype=x.dtype, device=x.device) + self.assertEqual(result.shape, expected.shape) + self.assertEqual(result.stride(), expected.stride()) + self.assertEqual(result.dtype, expected.dtype) + self.assertEqual(result.device, expected.device) + + _test([2, 3], [3, 1], torch.float) + _test([5, 3], [0, 1], torch.int) + _test([], [], torch.float) + + # Some really weird cases + for shape in [(2, 3, 4), (0, 2, 0)]: + for strides in [(12, 4, 1), (2, 4, 6), (0, 0, 0)]: + _test(shape, strides, torch.float) + def test_strided_mismatched_stride_shape(self, device): for shape, strides in [((1, ), ()), ((1, 2), (1, ))]: with self.assertRaisesRegex(RuntimeError, "mismatch in length of strides and shape"): @@ -702,6 +2315,23 @@ def test_eye(self, device): for dtype in torch.testing.get_all_dtypes(): if dtype == torch.bfloat16: continue + # Test the RuntimeError is raised when either m or n is a negative number + for n, m in ((-1, 1), (1, -1), (-1, -1)): + with self.assertRaisesRegex(RuntimeError, 'must be greater or equal to'): + torch.eye(n, m, device=device, dtype=dtype) + + # Test when the `m` parameter is not provided + for n in (3, 5, 7): + res1 = torch.eye(n, device=device, dtype=dtype) + naive_eye = torch.zeros(n, n, dtype=dtype, device=device) + naive_eye.diagonal(dim1=-2, dim2=-1).fill_(1) + self.assertEqual(naive_eye, res1) + + # Check eye_out outputs + res2 = torch.empty(0, device=device, dtype=dtype) + torch.eye(n, out=res2) + self.assertEqual(res1, res2) + for n, m in product([3, 5, 7], repeat=2): # Construct identity using diagonal and fill res1 = torch.eye(n, m, device=device, dtype=dtype) @@ -714,7 +2344,6 @@ def test_eye(self, device): torch.eye(n, m, out=res2) self.assertEqual(res1, res2) - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @precisionOverride({torch.float: 1e-8, torch.double: 1e-10}) @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False) + torch.testing.get_all_complex_dtypes())) @@ -730,7 +2359,6 @@ def test_linspace_vs_numpy(self, device, dtype): self.assertTrue(t[0].item() == a[0]) self.assertTrue(t[steps - 1].item() == a[steps - 1]) - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @precisionOverride({torch.float: 1e-6, torch.double: 1e-10}) @dtypes(*torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False)) def test_logspace_vs_numpy(self, device, dtype): @@ -757,7 +2385,8 @@ def test_linspace_steps_warning(self, device, dtype): def test_logspace_steps_warning(self, device, dtype): self._linspace_logspace_warning_helper(torch.logspace, device, dtype) - @largeCUDATensorTest('16GB') + @onlyCUDA + @largeTensorTest('16GB') def test_range_factories_64bit_indexing(self, device): bigint = 2 ** 31 + 1 t = torch.arange(bigint, dtype=torch.long, device=device) @@ -805,6 +2434,37 @@ def test_tensor_ctor_device_inference(self, device): sparse_size, dtype=torch.float64) self.assertEqual(sparse_with_dtype.device, torch.device('cpu')) + @onlyOnCPUAndCUDA + @precisionOverride({torch.bfloat16: 5e-2, torch.half: 1e-3}) + @unittest.skipIf(not TEST_SCIPY, "Scipy not found") + @dtypesIfCUDA(torch.float, torch.double, torch.bfloat16, torch.half, torch.long) + @dtypesIfCPU(torch.float, torch.double, torch.long) + def test_signal_window_functions(self, device, dtype): + import scipy.signal as signal + + def test(name, kwargs): + torch_method = getattr(torch, name + '_window') + if not dtype.is_floating_point: + with self.assertRaisesRegex(RuntimeError, r'floating point'): + torch_method(3, dtype=dtype) + return + for size in [0, 1, 2, 5, 10, 50, 100, 1024, 2048]: + for periodic in [True, False]: + res = torch_method(size, periodic=periodic, **kwargs, device=device, dtype=dtype) + # NB: scipy always returns a float64 result + ref = torch.from_numpy(signal.get_window((name, *(kwargs.values())), size, fftbins=periodic)) + self.assertEqual(res, ref, exact_dtype=False) + with self.assertRaisesRegex(RuntimeError, r'not implemented for sparse types'): + torch_method(3, layout=torch.sparse_coo) + self.assertTrue(torch_method(3, requires_grad=True).requires_grad) + self.assertFalse(torch_method(3).requires_grad) + + for window in ['hann', 'hamming', 'bartlett', 'blackman']: + test(window, kwargs={}) + + for num_test in range(50): + test('kaiser', kwargs={'beta': random.random() * 30}) + def test_tensor_factories_empty(self, device): # ensure we can create empty tensors from each factory function shapes = [(5, 0, 1), (0,), (0, 0, 1, 0, 2, 0, 0)] @@ -887,7 +2547,6 @@ def test_arange_device_vs_cpu(self, device, dtype): self.assertEqual(cpu_tensor, device_tensor) @onlyCUDA - @skipCUDAIfNotRocm def test_arange_bfloat16(self, device): ref_tensor = torch.tensor([0, 1, 2, 3], dtype=torch.bfloat16, device=device) bfloat16_tensor = torch.arange(0, 4, dtype=torch.bfloat16, device=device) @@ -1048,7 +2707,9 @@ def test_logspace_special_steps(self, device, dtype): self._test_logspace_base2(device, dtype, steps=steps) @dtypes(*torch.testing.get_all_dtypes(include_bool=False, include_half=False, include_complex=False)) - @dtypesIfCUDA(*torch.testing.get_all_dtypes(include_bool=False, include_half=True, include_complex=False)) + @dtypesIfCUDA(*((torch.testing.get_all_int_dtypes() + [torch.float32, torch.float16, torch.bfloat16]) + if TEST_WITH_ROCM + else torch.testing.get_all_dtypes(include_bool=False, include_half=True, include_complex=False))) def test_logspace(self, device, dtype): _from = random.random() to = _from + random.random() @@ -1126,12 +2787,200 @@ def test_full_out(self, device): self.assertEqual(torch.full(o.shape, 1., out=o).dtype, o.dtype) self.assertEqual(torch.full(size, 1, out=o).dtype, o.dtype) + # check that warning for numpy being not writable is suppressed + # when a copy of it is being created. + # see issue #47160 + def test_tensor_from_non_writable_numpy(self, device): + with warnings.catch_warnings(record=True) as w: + a = np.arange(5.) + a.flags.writeable = False + t = torch.tensor(a) + self.assertEqual(len(w), 0) # Class for testing random tensor creation ops, like torch.randint class TestRandomTensorCreation(TestCase): exact_dtype = True + # TODO: add torch.complex64, torch.complex128 + @dtypes(torch.float, torch.double) + def test_normal(self, device, dtype): + + def helper(self, device, dtype, ptype, t_transform, std_transform): + q = torch.empty(100, 100, dtype=dtype, device=device) + + q.normal_() + self.assertEqual(t_transform(q).mean(), 0, atol=0.2, rtol=0) + self.assertEqual(t_transform(q).std(), std_transform(1), atol=0.2, rtol=0) + + q.normal_(2, 3) + self.assertEqual(t_transform(q).mean(), 2, atol=0.3, rtol=0) + self.assertEqual(t_transform(q).std(), std_transform(3), atol=0.3, rtol=0) + + q = torch.empty(100, 100, dtype=dtype, device=device) + q_row1 = q[0:1].clone() + q[99:100].normal_() + self.assertEqual(t_transform(q[99:100]).mean(), 0, atol=0.2, rtol=0) + self.assertEqual(t_transform(q[99:100]).std(), std_transform(1), atol=0.2, rtol=0) + self.assertEqual(t_transform(q[0:1]).clone(), t_transform(q_row1)) + + mean = torch.empty(100, 100, dtype=dtype, device=device) + mean[:50].fill_(ptype(0)) + mean[50:].fill_(ptype(1)) + + std = torch.empty(100, 100, dtype=torch.float, device=device) + std[:, :50] = 4 + std[:, 50:] = 1 + + r = torch.normal(mean) + self.assertEqual(r.dtype, dtype) + self.assertEqual(str(r.device), device) + self.assertEqual(t_transform(r[:50]).mean(), 0, atol=0.2, rtol=0) + self.assertEqual(t_transform(r[50:]).mean(), 1, atol=0.2, rtol=0) + self.assertEqual(t_transform(r).std(), std_transform(1), atol=0.2, rtol=0) + + r.fill_(42) + r = torch.normal(mean, 3) + self.assertEqual(r.dtype, dtype) + self.assertEqual(str(r.device), device) + self.assertEqual(t_transform(r[:50]).mean(), 0, atol=0.2, rtol=0) + self.assertEqual(t_transform(r[50:]).mean(), 1, atol=0.2, rtol=0) + self.assertEqual(t_transform(r).std(), std_transform(3), atol=0.2, rtol=0) + + r.fill_(42) + torch.normal(mean, 3, out=r) + self.assertEqual(r.dtype, dtype) + self.assertEqual(str(r.device), device) + self.assertEqual(t_transform(r[:50]).mean(), 0, atol=0.2, rtol=0) + self.assertEqual(t_transform(r[50:]).mean(), 1, atol=0.2, rtol=0) + self.assertEqual(t_transform(r).std(), std_transform(3), atol=0.2, rtol=0) + + r.fill_(42) + r = torch.normal(2, std) + self.assertFalse(r.dtype.is_complex) + self.assertEqual(str(r.device), device) + self.assertEqual(r.mean(), 2, atol=0.2, rtol=0) + self.assertEqual(r[:, :50].std(), 4, atol=0.3, rtol=0) + self.assertEqual(r[:, 50:].std(), 1, atol=0.2, rtol=0) + + r.fill_(42) + torch.normal(2, std, out=r) + self.assertFalse(r.dtype.is_complex) + self.assertEqual(str(r.device), device) + self.assertEqual(r.mean(), 2, atol=0.2, rtol=0) + self.assertEqual(r[:, :50].std(), 4, atol=0.3, rtol=0) + self.assertEqual(r[:, 50:].std(), 1, atol=0.2, rtol=0) + + r.fill_(42) + r = torch.normal(mean, std) + self.assertEqual(r.dtype, dtype) + self.assertEqual(str(r.device), device) + self.assertEqual(t_transform(r[:50]).mean(), 0, atol=0.2, rtol=0) + self.assertEqual(t_transform(r[50:]).mean(), 1, atol=0.2, rtol=0) + self.assertEqual(t_transform(r[:, :50]).std(), std_transform(4), atol=0.3, rtol=0) + self.assertEqual(t_transform(r[:, 50:]).std(), std_transform(1), atol=0.2, rtol=0) + + r.fill_(42) + torch.normal(mean, std, out=r) + self.assertEqual(r.dtype, dtype) + self.assertEqual(str(r.device), device) + self.assertEqual(t_transform(r[:50]).mean(), 0, atol=0.2, rtol=0) + self.assertEqual(t_transform(r[50:]).mean(), 1, atol=0.2, rtol=0) + self.assertEqual(t_transform(r[:, :50]).std(), std_transform(4), atol=0.3, rtol=0) + self.assertEqual(t_transform(r[:, 50:]).std(), std_transform(1), atol=0.2, rtol=0) + + r.fill_(42) + r = torch.normal(2, 3, (100, 100), dtype=dtype, device=device) + self.assertEqual(r.dtype, dtype) + self.assertEqual(str(r.device), device) + self.assertEqual(t_transform(r).mean(), 2, atol=0.3, rtol=0) + self.assertEqual(t_transform(r).std(), std_transform(3), atol=0.3, rtol=0) + + r.fill_(42) + torch.normal(2, 3, (100, 100), dtype=dtype, device=device, out=r) + self.assertEqual(r.dtype, dtype) + self.assertEqual(str(r.device), device) + self.assertEqual(t_transform(r).mean(), 2, atol=0.3, rtol=0) + self.assertEqual(t_transform(r).std(), std_transform(3), atol=0.3, rtol=0) + + if dtype.is_complex: + helper(self, device, dtype, lambda x: complex(x, x), + lambda t: torch.real(t).to(torch.float), lambda mean: mean / math.sqrt(2)) + helper(self, device, dtype, lambda x: complex(x, x), + lambda t: torch.imag(t).to(torch.float), lambda mean: mean / math.sqrt(2)) + self.assertRaisesRegex( + RuntimeError, "normal expects standard deviation to be non-complex", + lambda: torch.normal(0, torch.empty(100, 100, dtype=dtype, device=device))) + out = torch.empty(100, 100, dtype=dtype, device=device) + self.assertRaisesRegex( + RuntimeError, "normal expects standard deviation to be non-complex", + lambda: torch.normal(0, torch.empty(100, 100, dtype=dtype, device=device), out=out)) + else: + helper(self, device, dtype, lambda x: x, lambda t: t, lambda mean: mean) + + @dtypes(torch.float, torch.double, torch.half) + @dtypesIfCUDA(torch.float, torch.double, torch.half, torch.bfloat16) + def test_uniform_from_to(self, device, dtype): + size = 2000 + alpha = 0.1 + + float_min = torch.finfo(torch.float).min + float_max = torch.finfo(torch.float).max + double_min = torch.finfo(torch.double).min + double_max = torch.finfo(torch.double).max + + if dtype == torch.bfloat16: + min_val = -3.389531389251535e+38 + max_val = 3.389531389251535e+38 + else: + min_val = torch.finfo(dtype).min + max_val = torch.finfo(dtype).max + + values = [double_min, float_min, -42, 0, 42, float_max, double_max] + + for from_ in values: + for to_ in values: + t = torch.empty(size, dtype=dtype, device=device) + if not (min_val <= from_ <= max_val) or not (min_val <= to_ <= max_val): + pass + elif to_ < from_: + self.assertRaisesRegex( + RuntimeError, + "uniform_ expects to return", + lambda: t.uniform_(from_, to_) + ) + elif to_ - from_ > max_val: + self.assertRaisesRegex( + RuntimeError, + "uniform_ expects to-from", + lambda: t.uniform_(from_, to_) + ) + else: + t.uniform_(from_, to_) + range_ = to_ - from_ + if not (dtype == torch.bfloat16) and not ( + dtype == torch.half and device == 'cpu') and not torch.isnan(t).all(): + delta = alpha * range_ + double_t = t.to(torch.double) + if range_ == 0: + self.assertTrue(double_t.min() == from_) + self.assertTrue(double_t.max() == to_) + elif dtype == torch.half: + self.assertTrue(from_ <= double_t.min() <= (from_ + delta)) + self.assertTrue((to_ - delta) <= double_t.max() <= to_) + else: + self.assertTrue(from_ <= double_t.min() <= (from_ + delta)) + self.assertTrue((to_ - delta) <= double_t.max() < to_) + + def test_random_neg_values(self, device): + SIZE = 10 + signed_dtypes = [torch.double, torch.float, torch.long, torch.int, torch.short] + for dtype in signed_dtypes: + res = torch.rand(SIZE, SIZE).to(device=device, dtype=dtype) + res.random_(-10, -1) + self.assertLessEqual(res.max().item(), 9) + self.assertGreaterEqual(res.min().item(), -10) + # TODO: this test should be updated @onlyCPU def test_randint_inference(self, device): @@ -1180,7 +3029,7 @@ def seed(generator): self.assertTrue((res1 < 6).all().item()) self.assertTrue((res1 >= 0).all().item()) - @dtypes(torch.half, torch.float, torch.double, + @dtypes(torch.half, torch.float, torch.bfloat16, torch.double, torch.complex32, torch.complex64, torch.complex128) def test_randn(self, device, dtype): SIZE = 100 @@ -1253,6 +3102,20 @@ def test_randperm(self, device): torch.randperm(n, out=non_contiguous_tensor) self.assertEqual(non_contiguous_tensor, res) + # Test exceptions when device and generator types are incompatible + @onlyCUDA + def test_randperm_device_compatibility(self, device): + cuda_gen = torch.Generator(device='cuda') + cpu_gen = torch.Generator(device='cpu') + for n in (0, 3, 100, 30000): + regex = 'Expected a .* generator device but found .*' + cuda_t = torch.tensor(n, device='cuda') + self.assertRaisesRegex(RuntimeError, regex, lambda: torch.randperm(n, device='cuda', generator=cpu_gen)) + self.assertRaisesRegex(RuntimeError, regex, lambda: torch.randperm(n, device='cuda', generator=cpu_gen, out=cuda_t)) + cpu_t = torch.tensor(n, device='cpu') + self.assertRaisesRegex(RuntimeError, regex, lambda: torch.randperm(n, device='cpu', generator=cuda_gen)) + self.assertRaisesRegex(RuntimeError, regex, lambda: torch.randperm(n, device='cpu', generator=cuda_gen, out=cpu_t)) + self.assertRaisesRegex(RuntimeError, regex, lambda: torch.randperm(n, generator=cuda_gen)) # implicitly on CPU # Class for testing *like ops, like torch.ones_like class TestLikeTensorCreation(TestCase): @@ -1311,7 +3174,6 @@ def test_full_like_inference(self, device): self.assertEqual(torch.full_like(like, 1., dtype=torch.complex64).dtype, torch.complex64) - instantiate_device_type_tests(TestTensorCreation, globals()) instantiate_device_type_tests(TestRandomTensorCreation, globals()) instantiate_device_type_tests(TestLikeTensorCreation, globals()) diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py index 143c6dab91d2c..4b523379bc4d4 100644 --- a/test/test_tensorexpr.py +++ b/test/test_tensorexpr.py @@ -4,13 +4,15 @@ from torch import nn import unittest -from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs +from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests from torch.testing._internal.te_utils import CudaCodeGenCreated, CudaCodeGenExecuted, \ LLVMCodeGenExecuted, SimpleIREvalExecuted +from torch.testing._internal.jit_utils import JitTestCase -class BaseTestClass(unittest.TestCase): + +class BaseTestClass(JitTestCase): def setUp(self): self.old_profiling_executor = torch._C._jit_set_profiling_executor(True) self.old_profiling_mode = torch._C._jit_set_profiling_mode(True) @@ -21,6 +23,10 @@ def setUp(self): torch._C._jit_override_can_fuse_on_gpu(True) self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() torch._C._jit_set_texpr_fuser_enabled(True) + self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining() + torch._C._debug_set_fusion_group_inlining(False) + + self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] def tearDown(self): torch._C._jit_set_profiling_executor(self.old_profiling_executor) @@ -29,6 +35,16 @@ def tearDown(self): torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state) torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state) torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state) + torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining) + + def assertLastGraphAllFused(self): + self.assertAllFused(torch.jit.last_executed_optimized_graph()) + + +def warmup_and_run_forward(f, *args): + for _ in range(torch._C._jit_get_num_profiled_runs() + 1): + results = f(*args) + return results class TestTensorExprFuser(BaseTestClass): @@ -41,7 +57,8 @@ def easy(x, y): a = torch.rand(1024) b = torch.rand(1024) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy()) def test_three_arg(self): @@ -60,7 +77,8 @@ def easy(x, y, z): a = torch.rand(1024) b = torch.rand(1024) c = torch.rand(1024) - x = traced(a, b, c) + x = warmup_and_run_forward(traced, a, b, c) + self.assertLastGraphAllFused() npr = a.numpy() + b.numpy() + c.numpy() np.testing.assert_allclose(npr, x.numpy()) assert ( @@ -90,7 +108,8 @@ def run_addcmul(x, y, z, w): ), ) - x = traced(rand_a, rand_b, rand_c, rand_d) + x = warmup_and_run_forward(traced, rand_a, rand_b, rand_c, rand_d) + self.assertLastGraphAllFused() y = run_addcmul(rand_a, rand_b, rand_c, rand_d) np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=1e-6) @@ -120,6 +139,8 @@ def test(x, y, z): b = torch.rand(M, N, device="cuda") c = torch.rand(M, N, device="cuda") x = traced(a, b, c) + x = warmup_and_run_forward(traced, a, b, c) + self.assertLastGraphAllFused() npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy() np.testing.assert_allclose(npr, x.cpu().numpy()) assert cuda_cg_executed.elapsed_value() >= 1 @@ -155,7 +176,8 @@ def test(x, y, z): a = torch.rand(*a_shape, device="cuda") b = torch.rand(*b_shape, device="cuda") c = torch.rand(*c_shape, device="cuda") - x = traced(a, b, c) + x = warmup_and_run_forward(traced, a, b, c) + self.assertLastGraphAllFused() npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy() np.testing.assert_allclose(npr, x.cpu().numpy()) assert cuda_cg_executed.elapsed_value() >= 1 @@ -187,7 +209,8 @@ def np_easy(x, y, z): a = torch.rand(1024) b = torch.rand(1024) c = torch.rand(1024) - x = traced(a, b, c) + x = warmup_and_run_forward(traced, a, b, c) + self.assertLastGraphAllFused() npr = np_easy(a.numpy(), b.numpy(), c.numpy()) np.testing.assert_allclose(npr, x.numpy()) @@ -214,7 +237,8 @@ def np_easy(x, y, z): a = torch.rand(shape) b = torch.rand(shape) c = torch.rand(shape) - x = traced(a, b, c) + x = warmup_and_run_forward(traced, a, b, c) + self.assertLastGraphAllFused() npr = np_easy(a.numpy(), b.numpy(), c.numpy()) np.testing.assert_allclose(npr, x.numpy()) @@ -235,10 +259,12 @@ def np_easy(x, y, z): a = torch.rand(N, N) b = torch.rand(N) c = torch.rand(N, N) - x = traced(a, b, c) + x = warmup_and_run_forward(traced, a, b, c) + self.assertLastGraphAllFused() npr = np_easy(a.numpy(), b.numpy(), c.numpy()) np.testing.assert_allclose(npr, x.numpy()) + @unittest.skip("temporarily disable") def test_broadcast_2(self): zero = torch.tensor([0.0], dtype=torch.float) @@ -257,10 +283,13 @@ def foo_np(x, y, z): z = torch.rand(4) traced = torch.jit.trace(foo, (x, y, z)) - r = traced(x, y, z) + r = warmup_and_run_forward(traced, x, y, z) + self.assertLastGraphAllFused() + rnp = foo_np(x.numpy(), y.numpy(), z.numpy()) np.testing.assert_allclose(r, rnp) + @unittest.skip("temporarily disable") def test_broadcast_big2(self): zero = torch.tensor([0.0], dtype=torch.float) @@ -279,7 +308,8 @@ def foo_np(x, y, z): z = torch.rand(1024) traced = torch.jit.trace(foo, (x, y, z)) - r = traced(x, y, z) + r = warmup_and_run_forward(traced, x, y, z) + self.assertLastGraphAllFused() rnp = foo_np(x.numpy(), y.numpy(), z.numpy()) np.testing.assert_allclose(r, rnp) @@ -295,6 +325,7 @@ def alpha(x): np.testing.assert_allclose(a.numpy() + 2.0 * a.numpy(), x.numpy()) @suppress_warnings + @unittest.skip("temporarily disable") def test_constant(self): def constant(x): bbb = torch.tensor([1.0]) @@ -304,7 +335,8 @@ def constant(x): traced = torch.jit.trace(constant, (torch.tensor([1.0]))) a = torch.tensor([1.0]) - x = traced(a) + x = warmup_and_run_forward(traced, a) + self.assertLastGraphAllFused() np.testing.assert_allclose(a.numpy() + 1.0, x.numpy()) def test_add_sub(self): @@ -320,7 +352,8 @@ def easy(x, y, z): a = torch.rand(1024) b = torch.rand(1024) c = torch.rand(1024) - x = traced(a, b, c) + x = warmup_and_run_forward(traced, a, b, c) + self.assertLastGraphAllFused() np.testing.assert_allclose(a.numpy() + b.numpy() - c.numpy(), x.numpy()) def test_promotion(self): @@ -335,7 +368,8 @@ def easy(x, y): a = torch.zeros(1024, dtype=torch.int32) b = torch.rand(1024, dtype=torch.float32) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy()) def test_double(self): @@ -353,7 +387,8 @@ def easy(x, y): a = torch.rand(TENSOR_LEN, dtype=torch.double) b = torch.full((TENSOR_LEN,), 0.5, dtype=torch.double) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) def test_short(self): @@ -372,7 +407,8 @@ def easy(x, y): a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16) b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) def test_char(self): @@ -386,12 +422,13 @@ def easy(x, y): traced = torch.jit.trace( easy, (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8), - torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.uint8)), + torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)), ) a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) - b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.uint8) - x = traced(a, b) + b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) def test_int64_promotion(self): @@ -410,7 +447,8 @@ def easy(x, y): a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int64) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) def test_eq(self): @@ -421,7 +459,8 @@ def easy(x, y): traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) a = torch.zeros(1024, dtype=torch.int32) b = torch.zeros(1024, dtype=torch.int32) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose(np.ones(1024), x.numpy()) def test_ne(self): @@ -432,7 +471,8 @@ def easy(x, y): traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) a = torch.zeros(1024, dtype=torch.int32) b = torch.ones(1024, dtype=torch.int32) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose(np.ones(1024), x.numpy()) def test_ge(self): @@ -445,7 +485,8 @@ def easy(x, y): aa.fill(5) a = torch.from_numpy(aa) b = torch.zeros(1024, dtype=torch.int32) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose(np.ones(1024), x.numpy()) def test_gt(self): @@ -456,7 +497,8 @@ def easy(x, y): traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) a = torch.ones(1024, dtype=torch.int32) b = torch.zeros(1024, dtype=torch.int32) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose(np.ones(1024), x.numpy()) def test_le(self): @@ -469,7 +511,8 @@ def easy(x, y): aa.fill(5) a = torch.from_numpy(aa) b = torch.zeros(1024, dtype=torch.int32) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose(np.zeros(1024), x.numpy()) def test_lt(self): @@ -482,7 +525,8 @@ def easy(x, y): traced = torch.jit.trace(easy, (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev))) a = torch.ones(1024, dtype=torch.int32, device=dev) b = torch.zeros(1024, dtype=torch.int32, device=dev) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() np.testing.assert_allclose(np.zeros(1024), x.cpu().numpy()) @suppress_warnings @@ -494,24 +538,29 @@ def test(x, y): a = 8.0 * torch.rand(1024) b = 8.0 * torch.rand(1024) np.testing.assert_allclose( - traced(a, b), np.maximum(np.minimum(a.numpy(), b.numpy()), [4.0]) + warmup_and_run_forward(traced, a, b), np.maximum(np.minimum(a.numpy(), b.numpy()), [4.0]) ) + self.assertLastGraphAllFused() + @unittest.skip("temporarily disable") def test_min_max_reduction(self): def test(x): return torch.min(x) + torch.max(x) traced = torch.jit.trace(test, (torch.zeros(1024))) a = 8.0 * torch.rand(1024) - np.testing.assert_allclose(traced(a), np.amin(a.numpy()) + np.amax(a.numpy())) + np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy()) + np.amax(a.numpy())) + self.assertLastGraphAllFused() + @unittest.skip("temporarily disable") def test_min_max_reduction2(self): def test(x): return x.min() + x.max() traced = torch.jit.trace(test, (torch.zeros(1024))) a = 8.0 * torch.rand(1024) - np.testing.assert_allclose(traced(a), np.amin(a.numpy()) + np.amax(a.numpy())) + np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy()) + np.amax(a.numpy())) + self.assertLastGraphAllFused() def test_min_max_reduction_dim1(self): def test(x): @@ -519,15 +568,19 @@ def test(x): traced = torch.jit.trace(test, (torch.zeros(16, 16))) a = 8.0 * torch.rand(16, 16) - np.testing.assert_allclose(traced(a), np.amin(a.numpy(), axis=1) + np.amax(a.numpy(), axis=1)) + np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin( + a.numpy(), axis=1) + np.amax(a.numpy(), axis=1)) + self.assertLastGraphAllFused() + @unittest.skip("temporarily disable") def test_min_max_reduction_dim1_2(self): def test(x): return torch.min(x, 1) traced = torch.jit.trace(test, (torch.zeros(16, 16))) a = 8.0 * torch.rand(16, 16) - np.testing.assert_allclose(traced(a)[0], np.amin(a.numpy(), axis=1)) + np.testing.assert_allclose(warmup_and_run_forward(traced, a)[0], np.amin(a.numpy(), axis=1)) + self.assertLastGraphAllFused() def test_clamp(self): def test(x): @@ -539,7 +592,8 @@ def test(x): traced = torch.jit.trace(test, (torch.zeros(1024, device=dev))) a = 20.0 * torch.rand(1024, device=dev) - 10.0 an = a.cpu().numpy() - np.testing.assert_allclose(traced(a).cpu(), np.clip(an + 3.0, 0.0, 6.0)) + np.testing.assert_allclose(warmup_and_run_forward(traced, a).cpu(), np.clip(an + 3.0, 0.0, 6.0)) + self.assertLastGraphAllFused() def test_relu(self): def test(x): @@ -550,7 +604,8 @@ def test(x): traced = torch.jit.trace(test, (torch.zeros(1024, device=dev))) a = 20.0 * torch.rand(1024, device=dev) - 10.0 an = a.cpu().numpy() - np.testing.assert_allclose(traced(a).cpu(), np.clip((np.maximum(0, an)), 0, 0.5)) + np.testing.assert_allclose(warmup_and_run_forward(traced, a).cpu(), np.clip((np.maximum(0, an)), 0, 0.5)) + self.assertLastGraphAllFused() def test_reps(self): def easy(x, y): @@ -562,7 +617,7 @@ def easy(x, y): for _ in range(32): a = torch.ones(1024) b = torch.zeros(1024) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) np.testing.assert_allclose(np.ones(1024), x.numpy()) def test_add_const_rhs(self): @@ -571,7 +626,8 @@ def test(x): traced = torch.jit.trace(test, torch.rand(4)) x = torch.rand(4) - y = traced(x) + y = warmup_and_run_forward(traced, x) + self.assertLastGraphAllFused() np.testing.assert_allclose(x.numpy() + 3.0, y.numpy()) def test_int_output(self): @@ -582,7 +638,8 @@ def test(x, y, z): x, y, z = xs xn, yn, zn = [t.numpy() for t in xs] traced = torch.jit.trace(test, (x, y, z)) - res = traced(x, y, z) + res = warmup_and_run_forward(traced, x, y, z) + self.assertLastGraphAllFused() np.testing.assert_allclose(xn * yn * zn, res.numpy()) def test_binary_ops(self): @@ -670,7 +727,8 @@ def test_type_as(x, y): in1 = 20 * torch.rand(1024, device=dev) in2 = 20 * torch.rand(1024, device=dev) traced = torch.jit.trace(torch_fn, (in1, in2)) - x = traced(rand_a, rand_b) + x = warmup_and_run_forward(traced, rand_a, rand_b) + self.assertLastGraphAllFused() y = torch_fn(rand_a, rand_b) np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=2e-3) @@ -844,6 +902,7 @@ def test_threshold(x, y): for torch_fn in fns: for dev in device_options: + # print(torch_fn, dev) rand_a = torch.rand(1024, device=dev) rand_b = torch.rand(1024, device=dev) ins = 20 * torch.rand(1024, device=dev) @@ -851,19 +910,22 @@ def test_threshold(x, y): cc.fill(np.nan) nans = torch.from_numpy(cc).to(dev) traced = torch.jit.trace(torch_fn, (ins, ins)) - x = traced(rand_a, rand_b) + x = warmup_and_run_forward(traced, rand_a, rand_b) + self.assertLastGraphAllFused() y = torch_fn(rand_a, rand_b) np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=2e-3) # nans - traced = torch.jit.trace(torch_fn, (ins, ins)) - x = traced(nans, rand_b) - y = torch_fn(nans, rand_b) - try: - np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) - except AssertionError: - # Print extra info before exiting: - print("Failed on dev=", dev, "function=", torch_fn) - np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) + # TODO: reenable. Currently all of the tests fail + # traced = torch.jit.trace(torch_fn, (ins, ins)) + # x = warmup_and_run_forward(traced, rand_a, rand_b) + # y = torch_fn(nans, rand_b) + # try: + # np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) + # print("Succeeded on dev=", dev, "function=", torch_fn) + # except AssertionError: + # # Print extra info before exiting: + # print("Failed on dev=", dev, "function=", torch_fn) + # # np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) def test_rand_like(self): devices = ["cuda"] if torch.cuda.is_available() else [] @@ -875,7 +937,8 @@ def run_rand_like(x, y): for device in devices: x = torch.rand(N, device=device) traced = torch.jit.trace(run_rand_like, (x, x), check_trace=False) - x_v = traced(x, x) + x_v = warmup_and_run_forward(traced, x, x) + self.assertLastGraphAllFused() x_np = x.cpu().numpy() x1_mean = np.mean(x_np) x2_mean = np.mean(x_np ** 2) @@ -897,10 +960,25 @@ def test_min(x, y): x = torch.tensor([np.nan]) y = torch.tensor([1.0]) - assert np.isnan(tmin(x, y).item()) - assert np.isnan(tmin(y, x).item()) - assert np.isnan(tmax(x, y).item()) - assert np.isnan(tmax(y, x).item()) + assert np.isnan(warmup_and_run_forward(tmin, x, y).item()) + assert np.isnan(warmup_and_run_forward(tmin, y, x).item()) + self.assertLastGraphAllFused() + assert np.isnan(warmup_and_run_forward(tmax, x, y).item()) + assert np.isnan(warmup_and_run_forward(tmax, y, x).item()) + self.assertLastGraphAllFused() + + def test_double_intrinsics(self): + # TODO: add "cpu" device once `pow` is supported there + devices = ["cuda"] if torch.cuda.is_available() else [] + + def do_pow(x): + return torch.pow(x, 7) + + for device in devices: + x = torch.rand(10, dtype=torch.double, device=device) + traced = torch.jit.trace(do_pow, (x)) + x = warmup_and_run_forward(traced, x) + self.assertLastGraphAllFused() def test_remainder(self): def run_remainder(x, y): @@ -916,19 +994,22 @@ def run_remainder(x, y): # random floats traced = torch.jit.trace(run_remainder, (torch.zeros(1024), torch.zeros(1024))) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() y = run_remainder(a, b) np.testing.assert_allclose(x.numpy(), y.numpy()) # div by 0 traced = torch.jit.trace(run_remainder, (torch.zeros(1024), torch.zeros(1024))) - x = traced(zeros, a) + x = warmup_and_run_forward(traced, zeros, a) + self.assertLastGraphAllFused() y = run_remainder(zeros, a) np.testing.assert_allclose(x.numpy(), y.numpy()) # numerators and denominatos are nan traced = torch.jit.trace(run_remainder, (torch.zeros(1024), torch.zeros(1024))) - x = traced(nans, a) + x = warmup_and_run_forward(traced, nans, a) + self.assertLastGraphAllFused() y = run_remainder(nans, a) np.testing.assert_allclose(x.numpy(), y.numpy()) @@ -941,7 +1022,8 @@ def easy(x): traced = torch.jit.trace(easy, (torch.zeros(1024))) a = torch.zeros(1024) - b, c = traced(a) + b, c = warmup_and_run_forward(traced, a) + self.assertLastGraphAllFused() bp = a.numpy() + 1 cp = bp + bp np.testing.assert_allclose(b.numpy(), bp) @@ -956,28 +1038,28 @@ def easy(x): traced = torch.jit.trace(easy, (torch.zeros(1024, 1024))) a = torch.zeros(1024, 1024) - x = traced(a) + x = warmup_and_run_forward(traced, a) + self.assertLastGraphAllFused() npr = a.numpy() npr2 = npr + 1 npr_a, npr_b = np.array_split(npr2, 2) np.testing.assert_allclose(npr_a + npr_b, x.numpy()) def _test_cat(self, device): - def easy(*args): + def foo(*args): args_2 = [v + i for i, v in enumerate(args)] v = torch.cat(args_2, dim=1) - return v + return v * v - M = 1024 - Ns = [1024, 512, 256, 128] + M = 16 + Ns = [128, 16, 1] values = [torch.zeros(M, N, device=device) for N in Ns] - traced = torch.jit.trace(easy, values) + traced = torch.jit.trace(foo, values) - x = traced(*values) - npr = [v.cpu().numpy() for v in values] - npr_2 = [v + i for i, v in enumerate(npr)] - npr_x = np.concatenate(npr_2, axis=1) - np.testing.assert_allclose(npr_x, x.cpu().numpy()) + x = warmup_and_run_forward(traced, *values) + self.assertLastGraphAllFused() + ref = foo(*values) + np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) def test_cat_cpu(self): self._test_cat('cpu') @@ -986,15 +1068,106 @@ def test_cat_cpu(self): def test_cat_cuda(self): self._test_cat('cuda') + # This test checks that we correctly handle fusion group with just aten::cat in it. + # Note that the test only makes sense with min_fusion_group=1, otherwise no + # fusion groups would be formed at all. + # TODO: Fix and re-enable the test. + def _test_cat_only(self, device): + def foo(*args): + args_2 = [v + i for i, v in enumerate(args)] + v = torch.cat(args_2, dim=1) + return v + + M = 16 + Ns = [128, 16, 1] + values = [torch.zeros(M, N, device=device) for N in Ns] + traced = torch.jit.trace(foo, values) + + x = warmup_and_run_forward(traced, *values) + self.assertLastGraphAllFused() + ref = foo(*values) + np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) + + @unittest.skip("temporarily disable") + def test_cat_only_cpu(self): + self._test_cat_only('cpu') + + @unittest.skip("temporarily disable") + @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") + def test_cat_only_cuda(self): + self._test_cat_only('cuda') + + def _test_cat_negative_dim(self, device): + def foo(*args): + v = torch.cat(args, dim=-1) + return v * v + + M = 16 + Ns = [128, 16, 1] + values = [torch.randn(M, N, device=device) for N in Ns] + traced = torch.jit.trace(foo, values) + + x = warmup_and_run_forward(traced, *values) + self.assertLastGraphAllFused() + ref = foo(*values) + np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) + + def test_cat_negative_dim_cpu(self): + self._test_cat_negative_dim('cpu') + + @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") + def test_cat_negative_dim_cuda(self): + self._test_cat_negative_dim('cuda') + + @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") + def test_cat_promote_inputs(self): + def foo(*args): + v = torch.cat(args, dim=1) + return v * v + + M = 16 + Ns = [128, 16, 1] + dtypes = [torch.half, torch.float32, torch.double] + values = [torch.randn(M, N, device='cuda', dtype=dt) for N, dt in zip(Ns, dtypes)] + traced = torch.jit.trace(foo, values) + + x = warmup_and_run_forward(traced, *values) + self.assertLastGraphAllFused() + ref = foo(*values) + np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) + + @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") + def test_cat_empty_tensors(self): + def foo(*args): + v = torch.cat(args, dim=1) + return v * v + + M = 16 + Ns = [128, 16, 1] + empty = torch.tensor([], device='cuda', dtype=torch.double) + values = [empty] + [torch.randn(M, N, device='cuda') for N in Ns] + traced = torch.jit.trace(foo, values) + + x = warmup_and_run_forward(traced, *values) + self.assertLastGraphAllFused() + ref = foo(*values) + np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) + + # now test with only empty tensors + values = [empty for i in range(3)] + traced = torch.jit.trace(foo, values) + x = warmup_and_run_forward(traced, *values) + self.assertLastGraphAllFused() + ref = foo(*values) + np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) + def test_scalar(self): @torch.jit.script - def test_float(x, y, z, a, b): - # type: (Tensor, Tensor, Tensor, float, float) -> Tensor + def test_float(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: float, b: float) -> torch.Tensor: return torch.add(torch.add(x, y, alpha=a), z, alpha=b) @torch.jit.script - def test_int(x, y, z, a, b): - # type: (Tensor, Tensor, Tensor, int, int) -> Tensor + def test_int(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: int, b: int) -> torch.Tensor: return torch.add(torch.add(x, y, alpha=a), z, alpha=b) for test in (test_float, test_int): @@ -1011,8 +1184,7 @@ def test_int(x, y, z, a, b): def test_loop(self): @torch.jit.script - def test(x, y, z): - # type: (Tensor, Tensor, int) -> Tensor + def test(x: torch.Tensor, y: torch.Tensor, z: int) -> torch.Tensor: b = y for i in range(0, z): a = x + y @@ -1046,18 +1218,18 @@ def easy(x, y): # FIXME: interp.elapsed_value() also increments due to simplifier assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1 - def test_unsqueeze(self): + def test_unsqueeze(self, N=256): def easy(x, y): a = torch.unsqueeze(x, 0) b = torch.unsqueeze(y, 0) return a + b - traced = torch.jit.trace(easy, (torch.ones(1024, 1024), torch.zeros(1024, 1024))) + traced = torch.jit.trace(easy, (torch.ones(N, N), torch.zeros(N, N))) llvm = LLVMCodeGenExecuted() interp = SimpleIREvalExecuted() - a = torch.rand(1024, 1024) + a = torch.rand(N, N) x = traced(a, a) npr = np.expand_dims(a, 0) npr = npr + npr @@ -1065,6 +1237,82 @@ def easy(x, y): # FIXME: interp.elapsed_value() also increments due to simplifier assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1 + def _test_softmax(self, device): + def test_softmax(x, y): + a = F.softmax(x, dim=0, dtype=torch.float32) + b = F.softmax(y, dim=0, dtype=torch.float32) + c = F.softmax(x, dim=1, dtype=torch.float32) + d = F.softmax(y, dim=1, dtype=torch.float32) + return a + b + c + d + + def test_softmax_neg_index(x, y): + a = F.softmax(x, dim=-2, dtype=torch.float32) + b = F.softmax(y, dim=-2, dtype=torch.float32) + c = F.softmax(x, dim=-1, dtype=torch.float32) + d = F.softmax(y, dim=-1, dtype=torch.float32) + return a + b + c + d + + def test_log_softmax(x, y): + a = F.log_softmax(x, dim=0, dtype=torch.float32) + b = F.log_softmax(y, dim=0, dtype=torch.float32) + c = F.log_softmax(x, dim=1, dtype=torch.float32) + d = F.log_softmax(y, dim=1, dtype=torch.float32) + return a + b + c + d + + for test in (test_softmax, test_log_softmax, test_softmax_neg_index): + old = torch._C._jit_set_texpr_reductions_enabled(True) + traced = torch.jit.trace(test, (torch.randn(2, 3, device=device), torch.randn(2, 3, device=device))) + inp = torch.randn(2, 3, device=device) + res = traced(inp, inp) + # Use eager mode as reference. + ref = test(inp, inp) + np.testing.assert_allclose(ref, res.cpu().numpy(), rtol=1e-06, atol=1e-06) + torch._C._jit_set_texpr_reductions_enabled(old) + + def test_softmax_cpu(self): + llvm = LLVMCodeGenExecuted() + interp = SimpleIREvalExecuted() + self._test_softmax('cpu') + # FIXME: interp.elapsed_value() also increments due to simplifier + assert llvm.elapsed_value() == 1 or interp.elapsed_value() > 1 + + @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") + @unittest.skip("global allocs are not supported yet.") + def test_softmax_cuda(self): + cuda = CudaCodeGenExecuted() + self._test_softmax('cuda') + assert cuda.elapsed_value() == 1 + + def test_half_gelu(self): + devices = ["cuda"] if torch.cuda.is_available() else [] + + @torch.jit.script + def bias_gelu(bias, y): + x = bias + y + return x * 0.5 * (1.0 + torch.erf(x / 1.41421)) + + for device in devices: + a = torch.rand(1024, dtype=torch.half, device=device) + b = torch.rand(1024, dtype=torch.half, device=device) + traced = torch.jit.trace(bias_gelu, (a, b)) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() + + def test_exp_pow(self): + devices = ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"] + + @torch.jit.script + def do_exp(x, y, z): + return ((x * y) * 2) * torch.pow(z, 2) + + for device in devices: + x = torch.rand(10, dtype=torch.double, device=device) + y = torch.rand(10, dtype=torch.double, device=device) + z = torch.rand(10, dtype=torch.double, device=device) + traced = torch.jit.trace(do_exp, (x, y, z)) + x = warmup_and_run_forward(traced, x, y, z) + self.assertLastGraphAllFused() + def test_transpose(self): @torch.jit.script def test(x, y, z): @@ -1178,7 +1426,8 @@ def run_rshift(x, y): b = torch.zeros(128, dtype=torch.int32, device=device) inp = torch.ones(128, dtype=torch.int32, device=device) traced = torch.jit.trace(fn, (inp, inp)) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() y = fn(a, b) np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) @@ -1189,7 +1438,8 @@ def run_where(x, y): a = torch.rand(1024, dtype=float) b = torch.rand(1024, dtype=float) traced = torch.jit.trace(run_where, (torch.zeros(1024), torch.zeros(1024))) - x = traced(a, b) + x = warmup_and_run_forward(traced, a, b) + self.assertLastGraphAllFused() y = run_where(a, b) np.testing.assert_allclose(x.numpy(), y.numpy()) @@ -1200,9 +1450,10 @@ def test(x): return (x + y) - (y - x) a = torch.rand(4, device="cuda") scripted = torch.jit.script(test) - scripted(a) cx = CudaCodeGenExecuted() - assert torch.allclose(scripted(a), 2 * a) + out = warmup_and_run_forward(scripted, a) + self.assertLastGraphAllFused() + assert torch.allclose(out, 2 * a) assert cx.elapsed_value() == 1 def test_mask(self): @@ -1214,8 +1465,9 @@ def test(x): for d in devices: x = torch.rand(4, device=d) > 0.5 scripted = torch.jit.script(test) - scripted(x) - assert torch.equal(scripted(x), test(x)) + out = warmup_and_run_forward(scripted, x) + self.assertLastGraphAllFused() + assert torch.equal(out, test(x)) def test_simple_add(self): val = torch._C._jit_get_te_generate_block_code() @@ -1236,6 +1488,39 @@ def simple(a, b): torch._C._jit_set_te_generate_block_code(val) torch._C._jit_texpr_set_fallback_allowed(fall_bk) + def test_strided_output_preserved(self): + def foo(a, b): + return a + b - a + + # smaller, easier to debug example + x = torch.arange(6) + x = torch.as_strided(x, (2, 3), (1, 2)) + total = 0 + for i in range(2): + for j in range(3): + x[i, j] = total + total += 1 + foo_script = torch.jit.script(foo) + foo_script(x, x) + foo_script(x, x) + out_s = foo_script(x, x) + out_eager = foo(x, x) + self.assertEqual(out_s, out_eager) + self.assertEqual(out_s.stride(), out_eager.stride()) + self.assertLastGraphAllFused() + + # more dims + N, C, H, W, = 2, 3, 4, 5 + x = torch.rand(N, C, H, W).to(memory_format=torch.channels_last) + foo_script = torch.jit.script(foo) + foo_script(x, x) + foo_script(x, x) + out_s = foo_script(x, x) + out_eager = foo(x, x) + self.assertEqual(out_s, out_eager) + self.assertEqual(out_s.stride(), out_eager.stride()) + self.assertLastGraphAllFused() + def test_alias_analysis_module(self): class AliasModule(nn.Module): def __init__(self): @@ -1342,6 +1627,24 @@ def getModule(script): torch.testing.assert_allclose(ref, test) + def test_multiple_outputs(self): + for device in self.devices: + # A bug reported internally similar to the one reported in #48533 + def foo(a, b, c): + t_next = c + 1 + t5 = t_next * b + t6 = torch.unsqueeze(t_next, 1) + t7 = a * t6 + return (t7, t5, t_next) + + a = torch.rand(20, 20, dtype=torch.float32, device=device) + b = torch.rand(20 * 29, dtype=torch.float32, device=device).as_strided([20], [29]) + c = torch.ones(20, dtype=torch.int64, device=device) + traced = torch.jit.trace(foo, (a, b, c)) + ref = foo(a, b, c) + exp = traced(a, b, c) + exp = traced(a, b, c) + self.assertEqual(ref, exp) if __name__ == '__main__': - unittest.main() + run_tests() diff --git a/test/test_tensorexpr_pybind.py b/test/test_tensorexpr_pybind.py new file mode 100644 index 0000000000000..d8db2fef89ff4 --- /dev/null +++ b/test/test_tensorexpr_pybind.py @@ -0,0 +1,40 @@ +import torch +import unittest + +from torch.testing._internal.jit_utils import JitTestCase + +class kernel_arena_scope(object): + def __enter__(self): + self.scope = torch._C.te.KernelScope() + + def __exit__(self, typ, val, traceback): + self.scope = None + +class TestTensorExprPyBind(JitTestCase): + def test_simple_sum(self): + with kernel_arena_scope(): + dtype = torch._C.te.Dtype.Float + N = 32 + dN = torch._C.te.ExprHandle.int(N) + + A = torch._C.te.Placeholder('A', dtype, [dN]) + B = torch._C.te.Placeholder('B', dtype, [dN]) + + def compute(i): + return A.load([i]) + B.load([i]) + C = torch._C.te.Compute('C', [torch._C.te.DimArg(dN, 'i')], compute) + + loopnest = torch._C.te.LoopNest([C]) + loopnest.prepare_for_codegen() + stmt = torch._C.te.simplify(loopnest.root_stmt()) + + cg = torch._C.te.construct_codegen('ir_eval', stmt, [torch._C.te.BufferArg(x) for x in [A, B, C]]) + + tA = torch.rand(N) * 5 + tB = torch.rand(N) * 6 + tC = torch.empty(N) + cg.call([tA, tB, tC]) + torch.testing.assert_allclose(tA + tB, tC) + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_testing.py b/test/test_testing.py new file mode 100644 index 0000000000000..4ff215233fe2e --- /dev/null +++ b/test/test_testing.py @@ -0,0 +1,491 @@ +import torch + +import math + +from torch.testing._internal.common_utils import \ + (TestCase, make_tensor, run_tests, slowTest) +from torch.testing._internal.common_device_type import \ + (instantiate_device_type_tests, onlyCUDA, onlyOnCPUAndCUDA, dtypes) + +# For testing TestCase methods and torch.testing functions +class TestTesting(TestCase): + # Ensure that assertEqual handles numpy arrays properly + @dtypes(*(torch.testing.get_all_dtypes(include_half=True, include_bfloat16=False, + include_bool=True, include_complex=True))) + def test_assertEqual_numpy(self, device, dtype): + S = 10 + test_sizes = [ + (), + (0,), + (S,), + (S, S), + (0, S), + (S, 0)] + for test_size in test_sizes: + a = make_tensor(test_size, device, dtype, low=-5, high=5) + a_n = a.cpu().numpy() + msg = f'size: {test_size}' + self.assertEqual(a_n, a, rtol=0, atol=0, msg=msg) + self.assertEqual(a, a_n, rtol=0, atol=0, msg=msg) + self.assertEqual(a_n, a_n, rtol=0, atol=0, msg=msg) + + # Tests that when rtol or atol (including self.precision) is set, then + # the other is zeroed. + # TODO: this is legacy behavior and should be updated after test + # precisions are reviewed to be consistent with torch.isclose. + @onlyOnCPUAndCUDA + def test__comparetensors_legacy(self, device): + a = torch.tensor((10000000.,)) + b = torch.tensor((10000002.,)) + + x = torch.tensor((1.,)) + y = torch.tensor((1. + 1e-5,)) + + # Helper for reusing the tensor values as scalars + def _scalar_helper(a, b, rtol=None, atol=None): + return self._compareScalars(a.item(), b.item(), rtol=rtol, atol=atol) + + for op in (self._compareTensors, _scalar_helper): + # Tests default + result, debug_msg = op(a, b) + self.assertTrue(result) + + # Tests setting atol + result, debug_msg = op(a, b, atol=2, rtol=0) + self.assertTrue(result) + + # Tests setting atol too small + result, debug_msg = op(a, b, atol=1, rtol=0) + self.assertFalse(result) + + # Tests setting rtol too small + result, debug_msg = op(x, y, atol=0, rtol=1.05e-5) + self.assertTrue(result) + + # Tests setting rtol too small + result, debug_msg = op(x, y, atol=0, rtol=1e-5) + self.assertFalse(result) + + @onlyOnCPUAndCUDA + def test__comparescalars_debug_msg(self, device): + # float x float + result, debug_msg = self._compareScalars(4., 7.) + expected_msg = ("Comparing 4.0 and 7.0 gives a difference of 3.0, " + "but the allowed difference with rtol=1.3e-06 and " + "atol=1e-05 is only 1.9100000000000003e-05!") + self.assertEqual(debug_msg, expected_msg) + + # complex x complex, real difference + result, debug_msg = self._compareScalars(complex(1, 3), complex(3, 1)) + expected_msg = ("Comparing the real part 1.0 and 3.0 gives a difference " + "of 2.0, but the allowed difference with rtol=1.3e-06 " + "and atol=1e-05 is only 1.39e-05!") + self.assertEqual(debug_msg, expected_msg) + + # complex x complex, imaginary difference + result, debug_msg = self._compareScalars(complex(1, 3), complex(1, 5.5)) + expected_msg = ("Comparing the imaginary part 3.0 and 5.5 gives a " + "difference of 2.5, but the allowed difference with " + "rtol=1.3e-06 and atol=1e-05 is only 1.715e-05!") + self.assertEqual(debug_msg, expected_msg) + + # complex x int + result, debug_msg = self._compareScalars(complex(1, -2), 1) + expected_msg = ("Comparing the imaginary part -2.0 and 0.0 gives a " + "difference of 2.0, but the allowed difference with " + "rtol=1.3e-06 and atol=1e-05 is only 1e-05!") + self.assertEqual(debug_msg, expected_msg) + + # NaN x NaN, equal_nan=False + result, debug_msg = self._compareScalars(float('nan'), float('nan'), equal_nan=False) + expected_msg = ("Found nan and nan while comparing and either one is " + "nan and the other isn't, or both are nan and equal_nan " + "is False") + self.assertEqual(debug_msg, expected_msg) + + # Checks that compareTensors provides the correct debug info + @onlyOnCPUAndCUDA + def test__comparetensors_debug_msg(self, device): + # Acquires atol that will be used + atol = max(1e-05, self.precision) + + # Checks float tensor comparisons (2D tensor) + a = torch.tensor(((0, 6), (7, 9)), device=device, dtype=torch.float32) + b = torch.tensor(((0, 7), (7, 22)), device=device, dtype=torch.float32) + result, debug_msg = self._compareTensors(a, b) + expected_msg = ("With rtol=1.3e-06 and atol={0}, found 2 element(s) (out of 4) " + "whose difference(s) exceeded the margin of error (including 0 nan comparisons). " + "The greatest difference was 13.0 (9.0 vs. 22.0), " + "which occurred at index (1, 1).").format(atol) + self.assertEqual(debug_msg, expected_msg) + + # Checks float tensor comparisons (with extremal values) + a = torch.tensor((float('inf'), 5, float('inf')), device=device, dtype=torch.float32) + b = torch.tensor((float('inf'), float('nan'), float('-inf')), device=device, dtype=torch.float32) + result, debug_msg = self._compareTensors(a, b) + expected_msg = ("With rtol=1.3e-06 and atol={0}, found 2 element(s) (out of 3) " + "whose difference(s) exceeded the margin of error (including 1 nan comparisons). " + "The greatest difference was nan (5.0 vs. nan), " + "which occurred at index 1.").format(atol) + self.assertEqual(debug_msg, expected_msg) + + # Checks float tensor comparisons (with finite vs nan differences) + a = torch.tensor((20, -6), device=device, dtype=torch.float32) + b = torch.tensor((-1, float('nan')), device=device, dtype=torch.float32) + result, debug_msg = self._compareTensors(a, b) + expected_msg = ("With rtol=1.3e-06 and atol={0}, found 2 element(s) (out of 2) " + "whose difference(s) exceeded the margin of error (including 1 nan comparisons). " + "The greatest difference was nan (-6.0 vs. nan), " + "which occurred at index 1.").format(atol) + self.assertEqual(debug_msg, expected_msg) + + # Checks int tensor comparisons (1D tensor) + a = torch.tensor((1, 2, 3, 4), device=device) + b = torch.tensor((2, 5, 3, 4), device=device) + result, debug_msg = self._compareTensors(a, b) + expected_msg = ("Found 2 different element(s) (out of 4), " + "with the greatest difference of 3 (2 vs. 5) " + "occuring at index 1.") + self.assertEqual(debug_msg, expected_msg) + + # Checks bool tensor comparisons (0D tensor) + a = torch.tensor((True), device=device) + b = torch.tensor((False), device=device) + result, debug_msg = self._compareTensors(a, b) + expected_msg = ("Found 1 different element(s) (out of 1), " + "with the greatest difference of 1 (1 vs. 0) " + "occuring at index 0.") + self.assertEqual(debug_msg, expected_msg) + + # Checks complex tensor comparisons (real part) + a = torch.tensor((1 - 1j, 4 + 3j), device=device) + b = torch.tensor((1 - 1j, 1 + 3j), device=device) + result, debug_msg = self._compareTensors(a, b) + expected_msg = ("Real parts failed to compare as equal! " + "With rtol=1.3e-06 and atol={0}, " + "found 1 element(s) (out of 2) whose difference(s) exceeded the " + "margin of error (including 0 nan comparisons). The greatest difference was " + "3.0 (4.0 vs. 1.0), which occurred at index 1.").format(atol) + self.assertEqual(debug_msg, expected_msg) + + # Checks complex tensor comparisons (imaginary part) + a = torch.tensor((1 - 1j, 4 + 3j), device=device) + b = torch.tensor((1 - 1j, 4 - 21j), device=device) + result, debug_msg = self._compareTensors(a, b) + expected_msg = ("Imaginary parts failed to compare as equal! " + "With rtol=1.3e-06 and atol={0}, " + "found 1 element(s) (out of 2) whose difference(s) exceeded the " + "margin of error (including 0 nan comparisons). The greatest difference was " + "24.0 (3.0 vs. -21.0), which occurred at index 1.").format(atol) + self.assertEqual(debug_msg, expected_msg) + + # Checks size mismatch + a = torch.tensor((1, 2), device=device) + b = torch.tensor((3), device=device) + result, debug_msg = self._compareTensors(a, b) + expected_msg = ("Attempted to compare equality of tensors " + "with different sizes. Got sizes torch.Size([2]) and torch.Size([]).") + self.assertEqual(debug_msg, expected_msg) + + # Checks dtype mismatch + a = torch.tensor((1, 2), device=device, dtype=torch.long) + b = torch.tensor((1, 2), device=device, dtype=torch.float32) + result, debug_msg = self._compareTensors(a, b, exact_dtype=True) + expected_msg = ("Attempted to compare equality of tensors " + "with different dtypes. Got dtypes torch.int64 and torch.float32.") + self.assertEqual(debug_msg, expected_msg) + + # Checks device mismatch + if self.device_type == 'cuda': + a = torch.tensor((5), device='cpu') + b = torch.tensor((5), device=device) + result, debug_msg = self._compareTensors(a, b, exact_device=True) + expected_msg = ("Attempted to compare equality of tensors " + "on different devices! Got devices cpu and cuda:0.") + self.assertEqual(debug_msg, expected_msg) + + # Helper for testing _compareTensors and _compareScalars + # Works on single element tensors + def _comparetensors_helper(self, tests, device, dtype, equal_nan, exact_dtype=True, atol=1e-08, rtol=1e-05): + for test in tests: + a = torch.tensor((test[0],), device=device, dtype=dtype) + b = torch.tensor((test[1],), device=device, dtype=dtype) + + # Tensor x Tensor comparison + compare_result, debug_msg = self._compareTensors(a, b, rtol=rtol, atol=atol, + equal_nan=equal_nan, + exact_dtype=exact_dtype) + self.assertEqual(compare_result, test[2]) + + # Scalar x Scalar comparison + compare_result, debug_msg = self._compareScalars(a.item(), b.item(), + rtol=rtol, atol=atol, + equal_nan=equal_nan) + self.assertEqual(compare_result, test[2]) + + def _isclose_helper(self, tests, device, dtype, equal_nan, atol=1e-08, rtol=1e-05): + for test in tests: + a = torch.tensor((test[0],), device=device, dtype=dtype) + b = torch.tensor((test[1],), device=device, dtype=dtype) + + actual = torch.isclose(a, b, equal_nan=equal_nan, atol=atol, rtol=rtol) + expected = test[2] + self.assertEqual(actual.item(), expected) + + # torch.close is not implemented for bool tensors + # see https://github.com/pytorch/pytorch/issues/33048 + def test_isclose_comparetensors_bool(self, device): + tests = ( + (True, True, True), + (False, False, True), + (True, False, False), + (False, True, False), + ) + + with self.assertRaises(RuntimeError): + self._isclose_helper(tests, device, torch.bool, False) + + self._comparetensors_helper(tests, device, torch.bool, False) + + @dtypes(torch.uint8, + torch.int8, torch.int16, torch.int32, torch.int64) + def test_isclose_comparetensors_integer(self, device, dtype): + tests = ( + (0, 0, True), + (0, 1, False), + (1, 0, False), + ) + + self._isclose_helper(tests, device, dtype, False) + + # atol and rtol tests + tests = [ + (0, 1, True), + (1, 0, False), + (1, 3, True), + ] + + self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5) + self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5) + + if dtype is torch.uint8: + tests = [ + (-1, 1, False), + (1, -1, False) + ] + else: + tests = [ + (-1, 1, True), + (1, -1, True) + ] + + self._isclose_helper(tests, device, dtype, False, atol=1.5, rtol=.5) + self._comparetensors_helper(tests, device, dtype, False, atol=1.5, rtol=.5) + + @onlyOnCPUAndCUDA + @dtypes(torch.float16, torch.float32, torch.float64) + def test_isclose_comparetensors_float(self, device, dtype): + tests = ( + (0, 0, True), + (0, -1, False), + (float('inf'), float('inf'), True), + (-float('inf'), float('inf'), False), + (float('inf'), float('nan'), False), + (float('nan'), float('nan'), False), + (0, float('nan'), False), + (1, 1, True), + ) + + self._isclose_helper(tests, device, dtype, False) + self._comparetensors_helper(tests, device, dtype, False) + + # atol and rtol tests + eps = 1e-2 if dtype is torch.half else 1e-6 + tests = ( + (0, 1, True), + (0, 1 + eps, False), + (1, 0, False), + (1, 3, True), + (1 - eps, 3, False), + (-.25, .5, True), + (-.25 - eps, .5, False), + (.25, -.5, True), + (.25 + eps, -.5, False), + ) + + self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5) + self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5) + + # equal_nan = True tests + tests = ( + (0, float('nan'), False), + (float('inf'), float('nan'), False), + (float('nan'), float('nan'), True), + ) + + self._isclose_helper(tests, device, dtype, True) + + self._comparetensors_helper(tests, device, dtype, True) + + # torch.close with equal_nan=True is not implemented for complex inputs + # see https://github.com/numpy/numpy/issues/15959 + # Note: compareTensor will compare the real and imaginary parts of a + # complex tensors separately, unlike isclose. + @dtypes(torch.complex64, torch.complex128) + def test_isclose_comparetensors_complex(self, device, dtype): + tests = ( + (complex(1, 1), complex(1, 1 + 1e-8), True), + (complex(0, 1), complex(1, 1), False), + (complex(1, 1), complex(1, 0), False), + (complex(1, 1), complex(1, float('nan')), False), + (complex(1, float('nan')), complex(1, float('nan')), False), + (complex(1, 1), complex(1, float('inf')), False), + (complex(float('inf'), 1), complex(1, float('inf')), False), + (complex(-float('inf'), 1), complex(1, float('inf')), False), + (complex(-float('inf'), 1), complex(float('inf'), 1), False), + (complex(float('inf'), 1), complex(float('inf'), 1), True), + (complex(float('inf'), 1), complex(float('inf'), 1 + 1e-4), False), + ) + + self._isclose_helper(tests, device, dtype, False) + self._comparetensors_helper(tests, device, dtype, False) + + # atol and rtol tests + + # atol and rtol tests + eps = 1e-6 + tests = ( + # Complex versions of float tests (real part) + (complex(0, 0), complex(1, 0), True), + (complex(0, 0), complex(1 + eps, 0), False), + (complex(1, 0), complex(0, 0), False), + (complex(1, 0), complex(3, 0), True), + (complex(1 - eps, 0), complex(3, 0), False), + (complex(-.25, 0), complex(.5, 0), True), + (complex(-.25 - eps, 0), complex(.5, 0), False), + (complex(.25, 0), complex(-.5, 0), True), + (complex(.25 + eps, 0), complex(-.5, 0), False), + # Complex versions of float tests (imaginary part) + (complex(0, 0), complex(0, 1), True), + (complex(0, 0), complex(0, 1 + eps), False), + (complex(0, 1), complex(0, 0), False), + (complex(0, 1), complex(0, 3), True), + (complex(0, 1 - eps), complex(0, 3), False), + (complex(0, -.25), complex(0, .5), True), + (complex(0, -.25 - eps), complex(0, .5), False), + (complex(0, .25), complex(0, -.5), True), + (complex(0, .25 + eps), complex(0, -.5), False), + ) + + self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5) + self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5) + + # atol and rtol tests for isclose + tests = ( + # Complex-specific tests + (complex(1, -1), complex(-1, 1), False), + (complex(1, -1), complex(2, -2), True), + (complex(-math.sqrt(2), math.sqrt(2)), + complex(-math.sqrt(.5), math.sqrt(.5)), True), + (complex(-math.sqrt(2), math.sqrt(2)), + complex(-math.sqrt(.501), math.sqrt(.499)), False), + (complex(2, 4), complex(1., 8.8523607), True), + (complex(2, 4), complex(1., 8.8523607 + eps), False), + (complex(1, 99), complex(4, 100), True), + ) + + self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5) + + # atol and rtol tests for compareTensors + tests = ( + (complex(1, -1), complex(-1, 1), False), + (complex(1, -1), complex(2, -2), True), + (complex(1, 99), complex(4, 100), False), + ) + + self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5) + + # equal_nan = True tests + tests = ( + (complex(1, 1), complex(1, float('nan')), False), + (complex(float('nan'), 1), complex(1, float('nan')), False), + (complex(float('nan'), 1), complex(float('nan'), 1), True), + ) + + with self.assertRaises(RuntimeError): + self._isclose_helper(tests, device, dtype, True) + + self._comparetensors_helper(tests, device, dtype, True) + + # Tests that isclose with rtol or atol values less than zero throws a + # RuntimeError + @dtypes(torch.bool, torch.uint8, + torch.int8, torch.int16, torch.int32, torch.int64, + torch.float16, torch.float32, torch.float64) + def test_isclose_atol_rtol_greater_than_zero(self, device, dtype): + t = torch.tensor((1,), device=device, dtype=dtype) + + with self.assertRaises(RuntimeError): + torch.isclose(t, t, atol=-1, rtol=1) + with self.assertRaises(RuntimeError): + torch.isclose(t, t, atol=1, rtol=-1) + with self.assertRaises(RuntimeError): + torch.isclose(t, t, atol=-1, rtol=-1) + + def test_assert_messages(self, device): + self.assertIsNone(self._get_assert_msg(msg=None)) + self.assertEqual("\nno_debug_msg", self._get_assert_msg("no_debug_msg")) + self.assertEqual("no_user_msg", self._get_assert_msg(msg=None, debug_msg="no_user_msg")) + self.assertEqual("debug_msg\nuser_msg", self._get_assert_msg(msg="user_msg", debug_msg="debug_msg")) + + @onlyCUDA + @slowTest + def test_cuda_assert_should_stop_test_suite(self, device): + # This test is slow because it spawn another process to run another test suite. + + # Test running of cuda assert test suite should early terminate. + stderr = TestCase.runWithPytorchAPIUsageStderr("""\ +#!/usr/bin/env python + +import torch + +from torch.testing._internal.common_utils import (TestCase, run_tests, slowTest) +from torch.testing._internal.common_device_type import instantiate_device_type_tests + +# This test is added to ensure that test suite terminates early when +# CUDA assert was thrown since all subsequent test will fail. +# See: https://github.com/pytorch/pytorch/issues/49019 +# This test file should be invoked from test_testing.py +class TestThatContainsCUDAAssertFailure(TestCase): + + @slowTest + def test_throw_unrecoverable_cuda_exception(self, device): + x = torch.rand(10, device=device) + # cause unrecoverable CUDA exception, recoverable on CPU + y = x[torch.tensor([25])].cpu() + + @slowTest + def test_trivial_passing_test_case_on_cpu_cuda(self, device): + x1 = torch.tensor([0., 1.], device=device) + x2 = torch.tensor([0., 1.], device='cpu') + self.assertEqual(x1, x2) + +instantiate_device_type_tests( + TestThatContainsCUDAAssertFailure, + globals(), + only_for='cuda' +) + +if __name__ == '__main__': + run_tests() +""") + # should capture CUDA error + self.assertIn('CUDA error: device-side assert triggered', stderr) + # should run only 1 test because it throws unrecoverable error. + self.assertIn('Ran 1 test', stderr) + + +instantiate_device_type_tests(TestTesting, globals()) + +if __name__ == '__main__': + run_tests() diff --git a/test/test_throughput_benchmark.py b/test/test_throughput_benchmark.py index d2f993ddaa3a3..9d60344b5912b 100644 --- a/test/test_throughput_benchmark.py +++ b/test/test_throughput_benchmark.py @@ -1,10 +1,9 @@ import torch -import tempfile from torch.utils import ThroughputBenchmark from torch.testing import assert_allclose -from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.common_utils import run_tests, TestCase, TemporaryFileName class TwoLayerNet(torch.jit.ScriptModule): def __init__(self, D_in, H, D_out): @@ -76,8 +75,8 @@ def test_module(self): self.linear_test(TwoLayerNetModule) def test_profiling(self): - with tempfile.NamedTemporaryFile(delete=False) as f: - self.linear_test(TwoLayerNetModule, profiler_output_path=f.name) + with TemporaryFileName() as fname: + self.linear_test(TwoLayerNetModule, profiler_output_path=fname) if __name__ == '__main__': diff --git a/test/test_torch.py b/test/test_torch.py index 6c875e68b12f8..0f83caa082a1f 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -1,64 +1,50 @@ -import sys +import torch +import numpy as np + import io import inspect import math import random import re import copy -import torch -import torch.cuda -import torch.backends.cuda import tempfile import unittest import warnings import types import pickle import textwrap -import operator -import os -import subprocess from torch.utils.dlpack import from_dlpack, to_dlpack -from torch._six import inf, nan, string_classes, istuple -from itertools import product, combinations, combinations_with_replacement, permutations -from functools import reduce -from functools import partial -from random import randrange +from torch._six import inf, nan, string_classes +from itertools import product, combinations, permutations from torch import multiprocessing as mp -from torch.testing._internal.common_methods_invocations import tri_tests_args, run_additional_tri_tests, \ - _compare_trilu_indices -from torch.testing._internal.common_utils import \ - (TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_WITH_ROCM, run_tests, - skipIfNoLapack, suppress_warnings, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, - do_test_dtypes, IS_SANDCASTLE, load_tests, slowTest, - skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, BytesIOContext, - skipIfRocm, torch_to_numpy_dtype_dict, skipIfNoSciPy, IS_MACOS, IS_PPC, - wrapDeterministicFlagAPITest) +from torch.testing._internal.common_utils import ( + TestCase, TEST_WITH_ROCM, run_tests, + IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN, + do_test_dtypes, IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, load_tests, slowTest, + skipCUDAMemoryLeakCheckIf, BytesIOContext, + skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName, + wrapDeterministicFlagAPITest, DeterministicGuard) from multiprocessing.reduction import ForkingPickler -from torch.testing._internal.common_device_type import instantiate_device_type_tests, \ - skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfRocm, skipCUDAIfNotRocm, onlyCUDA, onlyCPU, \ - dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, skipCUDAIf, precisionOverride, \ - PYTORCH_CUDA_MEMCHECK, largeCUDATensorTest, largeTensorTest, onlyOnCPUAndCUDA, expectedAlertNondeterministic -from typing import Dict, List, Tuple, Union +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + skipCUDAIfNoMagma, skipCUDAIfRocm, + onlyCUDA, onlyCPU, + dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, + PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyOnCPUAndCUDA, + expectedAlertNondeterministic) +from typing import Dict, List import torch.backends.quantized import torch.testing._internal.data -from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, with_tf32_off, \ - _get_torch_cuda_version, TEST_MAGMA +from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32 +# Protects against includes accidentally setting the default dtype +assert torch.get_default_dtype() is torch.float32 # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings load_tests = load_tests -if TEST_NUMPY: - import numpy as np - -if TEST_SCIPY: - import scipy - from scipy import signal - -SIZE = 100 - -AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32() +AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32() # Wrap base test class into a class to hide it from testing # See https://stackoverflow.com/a/25695512 @@ -195,8 +181,6 @@ def test_namespace(ns, *skips): test_namespace(torch.randn(1), 'as_strided_', re.compile('^clamp_(min|max)_?$'), - 'coalesce', - 'is_coalesced', 'is_distributed', 'is_nonzero', 'is_same_size', @@ -209,52 +193,19 @@ def test_namespace(ns, *skips): 'prelu', 'resize', 'resize_as', - 'smm', 'softmax', 'split_with_sizes', 'unsafe_split_with_sizes', - 'sspaddmm', - 'to_dense', - 'sparse_resize_', - 'sparse_resize_and_clear_', ) test_namespace(torch.nn) test_namespace(torch.nn.functional, 'assert_int_or_pair') # TODO: add torch.* tests when we have proper namespacing on ATen functions # test_namespace(torch) - def test_linear_algebra_scalar_raises(self) -> None: - m = torch.randn(5, 5) - v = torch.randn(5) - s = torch.tensor(7) - self.assertRaises(RuntimeError, lambda: torch.mv(m, s)) - self.assertRaises(RuntimeError, lambda: torch.addmv(v, m, s)) - - @unittest.skipIf(not TEST_SCIPY, "Scipy not found") - def test_mvlgamma(self): - from scipy.special import multigammaln - for d in range(1, 5): - input = torch.empty(10).uniform_(d, 10) - res_torch = torch.mvlgamma(input, d) - res_scipy = multigammaln(input.numpy(), d) - self.assertEqual(res_torch.numpy(), res_scipy, atol=1e-5, rtol=0) - - def test_mvlgamma_argcheck(self): - def run_test(d): - input = torch.linspace((d - 2) / 2, 10, 10) - torch.mvlgamma(input, d) - - with self.assertRaisesRegex(RuntimeError, r"All elements must be greater than \(p-1\)/2"): - run_test(3) - def test_msnpu_error(self): with self.assertRaisesRegex(RuntimeError, "support for msnpu"): torch.zeros(1, device=torch.device('msnpu')) - def test_polygamma_neg(self): - with self.assertRaisesRegex(RuntimeError, r'polygamma\(n, x\) does not support negative n\.'): - torch.polygamma(-1, torch.tensor([1.0, 2.0])) - def test_has_storage(self): self.assertIsNotNone(torch.Tensor().storage()) self.assertIsNotNone(torch.Tensor(0).storage()) @@ -263,116 +214,6 @@ def test_has_storage(self): self.assertIsNotNone(torch.Tensor([0, 0, 0]).nonzero().storage()) self.assertIsNotNone(torch.Tensor().new().storage()) - def test_dim_reduction_uint8_overflow(self): - example = [[-1, 2, 1], [5, 3, 6]] - x = torch.tensor(example, dtype=torch.uint8) - self.assertEqual(x.sum(dtype=torch.uint8).item(), 16) - self.assertEqual(x.sum(0, dtype=torch.uint8), torch.tensor([4, 5, 7], dtype=torch.uint8)) - self.assertEqual(x.sum(1, dtype=torch.uint8), torch.tensor([2, 14], dtype=torch.uint8)) - y = torch.tensor(example, dtype=torch.uint8) - torch.sum(x, 0, out=y) - self.assertEqual(x.sum(0, dtype=torch.uint8), y) - - def test_dim_reduction_less_than_64(self): - sizes = [1] * 65 - x = torch.randn(sizes) - ops = [torch.mean, torch.sum, torch.nansum, torch.std, torch.logsumexp, torch.std, torch.var, - torch.amin, torch.amax, torch.norm] - for op in ops: - with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"): - op(x, 64) - with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"): - op(x, -1) - - @unittest.skipIf(not TEST_SCIPY, "Scipy not found") - def test_logsumexp(self): - from scipy.special import logsumexp - a = torch.randn(5, 4) - a[0, 0] = inf - a[1, :] = -inf - actual = a.logsumexp(1) - expected = logsumexp(a.numpy(), 1) - self.assertEqual(expected.shape, actual.shape) - self.assertEqual(expected, actual) - # check that out is actually inplace - b = torch.zeros(5, 2) - c = b[:, 0] - torch.logsumexp(a, 1, out=c) - self.assertEqual(expected, b[:, 0]) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_cpu_parallel(self): - # To use parallel branches we'll need to compare on tensors - # that are relatively large. Even if this is run on a single - # core machine these tests will still give you signal on - # the correctness - - def _run_test(size): - for dim in range(len(size) + 1): - nv = np.round(np.random.rand(*size)) # 0s and 1s - tv = torch.from_numpy(nv) - # Parallelisim is only used if numel is - # larger than grainsize defined in Parallel.h - self.assertTrue(tv.numel() > 32768) - if dim == len(size): - nvs = nv.sum() - tvs = tv.sum() - else: - nvs = nv.sum(dim) - tvs = tv.sum(dim) - diff = np.abs(nvs - tvs.numpy()).sum() - self.assertEqual(diff, 0) - - _run_test([2, 3, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3]) - _run_test([4, 4, 4, 4, 4, 4, 4, 4, 4, 4]) - _run_test([1, 32 * 8 * 32 * 8]) - _run_test([1, 32770]) - - def _testCSelection(self, torchfn, mathfn): - # Two tensors - size = (100, 100) - a = torch.rand(*size) - b = torch.rand(*size) - c = torchfn(a, b) - expected_c = torch.zeros(*size) - expected_c.map2_(a, b, lambda _, a, b: mathfn(a, b)) - self.assertEqual(expected_c, c, atol=0, rtol=0) - - def test_max_elementwise(self): - self._testCSelection(torch.max, max) - - def test_min_elementwise(self): - self._testCSelection(torch.min, min) - - def test_all_any(self): - def test(size): - x = torch.ones(*size).byte() - self.assertTrue(x.all()) - self.assertTrue(x.any()) - - x[3] = 0 - self.assertFalse(x.all()) - self.assertTrue(x.any()) - - x.zero_() - self.assertFalse(x.all()) - self.assertFalse(x.any()) - - x.fill_(2) - self.assertTrue(x.all()) - self.assertTrue(x.any()) - - x = torch.ones(*size).bool() - self.assertTrue(x.all()) - self.assertTrue(x.any()) - - x[3] = False - self.assertFalse(x.all()) - self.assertTrue(x.any()) - - test((10,)) - test((5, 5)) - def test_where_invalid_device(self): if torch.cuda.is_available(): for devices in [('cpu', 'cuda', 'cuda'), ('cuda', 'cpu', 'cpu'), @@ -413,8 +254,8 @@ def get_tensor(size, dtype, device, contiguous): height = 5 width = 5 for device in torch.testing.get_all_device_types(): - for dt1 in torch.testing.get_all_math_dtypes(device): - for dt2 in torch.testing.get_all_math_dtypes(device): + for dt1 in torch.testing.get_all_dtypes(): + for dt2 in torch.testing.get_all_dtypes(): for contiguous in [True, False]: x1 = get_tensor((height, width), dt1, device, contiguous) x2 = get_tensor((height, width), dt2, device, contiguous) @@ -431,204 +272,6 @@ def get_tensor(size, dtype, device, contiguous): result = torch.where(condition, x1, x2) self.assertEqual(expected, result) - def test_all_any_with_dim(self): - def test(x): - r1 = x.prod(dim=0, keepdim=False).byte() - r2 = x.all(dim=0, keepdim=False) - self.assertEqual(r1.shape, r2.shape) - self.assertTrue((r1 == r2).all()) - - r3 = x.sum(dim=1, keepdim=True).clamp(0, 1).byte() - r4 = x.any(dim=1, keepdim=True) - self.assertEqual(r3.shape, r4.shape) - self.assertTrue((r3 == r4).all()) - - test(torch.ByteTensor([[0, 0, 0], - [0, 0, 1], - [0, 1, 1], - [1, 1, 1]])) - - def test_numpy_args(self): - x1 = torch.randn(10) - x2 = torch.randn(10) - res1 = torch.add(input=x1, other=x2) - res2 = torch.add(x1=x1, x2=x2) - self.assertEqual(res1, res2) - - x1 = torch.randn(10, 10, 10) - res1 = x1.sum(dim=(0, 2), keepdim=True) - res2 = x1.sum(axis=(0, 2), keepdims=True) - self.assertEqual(res1, res2) - - def _assert_matches_numpy(self, t, n): - self.assertEqual(n.shape, t.shape) - if t.dtype == torch.float: - self.assertEqual(n, t, rtol=1e-03, atol=1e-05, equal_nan=True) - else: - self.assertEqual(n, t, equal_nan=True) - - def _test_dim_ops(self, pytorch_op, numpy_op, - use_floating=True, use_integral=True, use_complex=False): - def do_one(tensors_dict, dim): - for category, tensors in tensors_dict.items(): - if category == "slice": - dim = 0 - for tensor in tensors: - # we have no control over NumPy warnings... - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - expected = numpy_op(tensor.numpy(), dim) - actual = pytorch_op(tensor, dim) - self._assert_matches_numpy(actual, expected) - if torch.cuda.is_available(): - self._assert_matches_numpy(pytorch_op(tensor.cuda(), - dim).cpu(), - expected) - do_one(self._make_tensors((5, 400000), use_floating=use_floating, - use_integral=use_integral, use_complex=use_complex), 1) - do_one(self._make_tensors((3, 5, 7), use_floating=use_floating, - use_integral=use_integral, use_complex=use_complex), 0) - do_one(self._make_tensors((3, 5, 7), use_floating=use_floating, - use_integral=use_integral, use_complex=use_complex), 1) - do_one(self._make_tensors((3, 5, 7), use_floating=use_floating, - use_integral=use_integral, use_complex=use_complex), 2) - do_one(self._make_tensors((100000, ), use_floating=use_floating, - use_integral=use_integral, use_complex=use_complex), -1) - do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, - use_integral=use_integral, use_complex=use_complex), 0) - do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, - use_integral=use_integral, use_complex=use_complex), 1) - do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, - use_integral=use_integral, use_complex=use_complex), 2) - do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, - use_integral=use_integral, use_complex=use_complex), (1, 2)) - do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, - use_integral=use_integral, use_complex=use_complex), (1, -1)) - do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, - use_integral=use_integral, use_complex=use_complex), (0, 2)) - do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, - use_integral=use_integral, use_complex=use_complex), (0, 2, 1)) - - @slowTest - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_sum_dim(self): - self._test_dim_ops( - lambda t, d: t.sum(d), - lambda n, d: n.sum(d), - use_floating=True, use_integral=True, use_complex=True) - - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_mean_dim(self): - self._test_dim_ops( - lambda t, d: t.mean(d), - lambda n, d: n.mean(d), - use_integral=False) - - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_std_dim(self): - for unbiased in [False, True]: - self._test_dim_ops( - lambda t, d: t.std(d, unbiased=unbiased), - lambda n, d: n.std(d, ddof=1 if unbiased else 0), - use_integral=False) - - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_var_dim(self): - for unbiased in [False, True]: - self._test_dim_ops( - lambda t, d: t.var(d, unbiased=unbiased), - lambda n, d: n.var(d, ddof=1 if unbiased else 0), - use_integral=False) - - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - @unittest.skipIf(not TEST_SCIPY, 'Scipy not found') - def test_logsumexp_dim(self): - from scipy.special import logsumexp - self._test_dim_ops( - lambda t, d: t.logsumexp(d), - lambda n, d: logsumexp(n, d), - use_integral=False) - - def _test_reduce_integer_upcast(self, fn, has_out=True, test_complex=True): - shape = (3, 4, 5) - reduced_shape = fn(torch.ones(shape)).shape - - def _test_out(dtype, other_dtype): - out = torch.ones(reduced_shape, dtype=dtype) - result = fn(x, out=out) - self.assertIs(out.dtype, result.dtype) - self.assertEqual(fn(x.to(dtype)), result, exact_dtype=False) - result = fn(x, out=out, dtype=dtype) - self.assertIs(out.dtype, result.dtype) - self.assertEqual(fn(x.to(dtype)), result, exact_dtype=False) - # 'out' is favored over dtype, check error - self.assertRaises(RuntimeError, lambda: fn(x, out=out, dtype=other_dtype)) - - for dtype in [dtype for dtype in torch.testing.get_all_math_dtypes('cpu') if dtype != torch.float16]: - x = torch.ones(shape, dtype=dtype) - expected_dtype = dtype if dtype.is_floating_point or dtype.is_complex else torch.int64 - self.assertIs(expected_dtype, fn(x).dtype) - self.assertEqual(fn(x.to(expected_dtype)), fn(x)) - - if dtype.is_floating_point: - other_dtype = torch.float32 if dtype == torch.float64 else torch.float64 - elif dtype.is_complex: - other_dtype = torch.complex64 if dtype == torch.complex128 else torch.complex128 - else: - other_dtype = torch.int32 if dtype != torch.int32 else torch.int16 - self.assertIs(other_dtype, fn(x, dtype=other_dtype).dtype) - self.assertEqual(fn(x.to(other_dtype)), fn(x, dtype=other_dtype), exact_dtype=False) - - # test mixed int/float/complex - if dtype.is_floating_point: - mixed_dtypes = [torch.int32, torch.complex64] - elif dtype.is_complex: - mixed_dtypes = [torch.int32, torch.float32] - else: - mixed_dtypes = [torch.float32, torch.complex64] - - for mixed_dtype in mixed_dtypes: - self.assertIs(mixed_dtype, fn(x, dtype=mixed_dtype).dtype) - self.assertEqual(fn(x.to(mixed_dtype)), fn(x, dtype=mixed_dtype), exact_dtype=False) - - if has_out: - _test_out(dtype, other_dtype) - _test_out(dtype, mixed_dtype) - - def test_sum_integer_upcast(self): - self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, **kwargs), False) - self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, 0, **kwargs)) - - def test_prod_integer_upcast(self): - self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, **kwargs), False) - self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, 0, **kwargs)) - - def test_cumsum_integer_upcast(self): - self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumsum(x, 0, **kwargs)) - - def test_cumprod_integer_upcast(self): - self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumprod(x, 0, **kwargs)) - - def test_cross_validation(self): - self.assertRaisesRegex( - RuntimeError, "inconsistent tensors dimensions", - lambda: torch.cross(torch.rand(100, 3), torch.rand(100, 3, 10))) - self.assertRaisesRegex( - RuntimeError, "inconsistent tensors sizes", - lambda: torch.cross(torch.rand(5, 3), torch.rand(3, 5))) - self.assertRaisesRegex( - RuntimeError, "no dimension of size 3 in input", - lambda: torch.cross(torch.rand(5, 4), torch.rand(5, 4))) - self.assertRaisesRegex( - RuntimeError, "dimension 0 does not have size 3", - lambda: torch.cross(torch.rand(5, 4, 3), torch.rand(5, 4, 3), dim=0)) - self.assertRaisesRegex( - RuntimeError, "dimension -1 does not have size 3", - lambda: torch.cross(torch.rand(5, 3, 4), torch.rand(5, 3, 4), dim=-1)) - self.assertRaisesRegex( - IndexError, "Dimension out of range", - lambda: torch.cross(torch.rand(5, 3, 4), torch.rand(5, 3, 4), dim=-5)) - def test_dtypes(self): all_dtypes = torch.testing.get_all_dtypes() do_test_dtypes(self, all_dtypes, torch.strided, torch.device('cpu')) @@ -654,6 +297,13 @@ def test_copy_transpose(self): self.assertEqual(y[:, 0], range(100)) self.assertEqual(y[:, 40], range(4000, 4100)) + # Validates regression reported in https://github.com/pytorch/pytorch/issues/45269 + x = torch.arange(100 * 100).reshape(100, 100).to(dtype=torch.cfloat).t() + y = torch.empty(100, 100, dtype=torch.cfloat) + y.copy_(x) + self.assertEqual(y[:, 0], range(100)) + self.assertEqual(y[:, 40], range(4000, 4100)) + def test_device(self): cpu = torch.device('cpu') self.assertEqual('cpu', str(cpu)) @@ -690,15 +340,7 @@ def test_device(self): self.assertEqual('cuda', cuda90.type) self.assertEqual(90, cuda90.index) - cuda23333 = torch.device('cuda', 23333) - self.assertEqual('cuda:23333', str(cuda23333)) - self.assertEqual('cuda', cuda23333.type) - self.assertEqual(23333, cuda23333.index) - self.assertRaises(RuntimeError, lambda: torch.device('cpu:-1')) - self.assertRaises(RuntimeError, lambda: torch.device('cpu:1')) - self.assertRaises(RuntimeError, lambda: torch.device('cpu', -1)) - self.assertRaises(RuntimeError, lambda: torch.device('cpu', 1)) self.assertRaises(RuntimeError, lambda: torch.device('cuda:-1')) self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 ')) self.assertRaises(RuntimeError, lambda: torch.device('cuda: 2')) @@ -711,7 +353,6 @@ def test_device(self): self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 cuda:3')) self.assertRaises(RuntimeError, lambda: torch.device('cuda:2+cuda:3')) self.assertRaises(RuntimeError, lambda: torch.device('cuda:2cuda:3')) - self.assertRaises(RuntimeError, lambda: torch.device('cuda', -1)) self.assertRaises(RuntimeError, lambda: torch.device(-1)) self.assertRaises(RuntimeError, lambda: torch.device('other')) @@ -723,6 +364,17 @@ def test_device(self): device_hash_set.add(hash(torch.device(device))) self.assertEqual(len(device_set), len(device_hash_set)) + def get_expected_device_repr(device): + if device.index is not None: + return "device(type='{type}', index={index})".format( + type=device.type, index=device.index) + + return "device(type='{type}')".format(type=device.type) + + for device in device_set: + dev = torch.device(device) + self.assertEqual(repr(dev), get_expected_device_repr(dev)) + def test_to(self): def test_copy_behavior(t, non_blocking=False): self.assertIs(t, t.to(t, non_blocking=non_blocking)) @@ -781,10 +433,6 @@ def test_to_with_tensor(self): self.assertEqual(a.device, b.to(a, non_blocking=non_blocking).device) self.assertEqual(b.device, a.to(b, non_blocking=non_blocking).device) - def test_dtype_out_match(self): - d = torch.autograd.Variable(torch.DoubleTensor(2, 3)) - self.assertRaises(RuntimeError, lambda: torch.zeros((2, 3), out=d, dtype=torch.float32)) - def test_as_subclass(self): class SubTensor(torch.Tensor): member_var = object() @@ -855,39 +503,6 @@ def test_qengine(self): assert torch.backends.quantized.engine == qe, 'qengine not set successfully' torch.backends.quantized.engine = original_qe - def test_renorm(self): - m1 = torch.randn(10, 5) - res1 = torch.Tensor() - - def renorm(matrix, value, dim, max_norm): - m1 = matrix.transpose(dim, 0).contiguous() - # collapse non-dim dimensions. - m2 = m1.clone().resize_(m1.size(0), int(math.floor(m1.nelement() / m1.size(0)))) - norms = m2.norm(value, 1, True) - # clip - new_norms = norms.clone() - new_norms[torch.gt(norms, max_norm)] = max_norm - new_norms.div_(norms.add_(1e-7)) - # renormalize - m1.mul_(new_norms.expand_as(m1)) - return m1.transpose(dim, 0) - - # note that the axis fed to torch.renorm is different (2~=1) - maxnorm = m1.norm(2, 1).mean() - m2 = renorm(m1, 2, 1, maxnorm) - m1.renorm_(2, 1, maxnorm) - self.assertEqual(m1, m2, atol=1e-5, rtol=0) - self.assertEqual(m1.norm(2, 0), m2.norm(2, 0), atol=1e-5, rtol=0) - - m1 = torch.randn(3, 4, 5) - m2 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4) - maxnorm = m2.norm(2, 0).mean() - m2 = renorm(m2, 2, 1, maxnorm) - m1.renorm_(2, 1, maxnorm) - m3 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4) - self.assertEqual(m3, m2) - self.assertEqual(m3.norm(2, 0), m2.norm(2, 0)) - def _spawn_method(self, method, arg): try: mp.set_start_method('spawn') @@ -917,31 +532,6 @@ def test_multinomial_invalid_probs(self): self._spawn_method(test_method, torch.Tensor([1, -inf, 1])) self._spawn_method(test_method, torch.Tensor([1, 1, nan])) - def test_broadcast_empty(self): - # empty + empty - self.assertRaises(RuntimeError, lambda: torch.randn(5, 0) + torch.randn(0, 5)) - self.assertEqual(torch.randn(5, 0), torch.randn(0) + torch.randn(5, 0)) - self.assertEqual(torch.randn(5, 0, 0), torch.randn(0) + torch.randn(5, 0, 1)) - - # scalar + empty - self.assertEqual(torch.randn(5, 0, 6), torch.randn(()) + torch.randn(5, 0, 6)) - - # non-empty, empty - self.assertEqual(torch.randn(0), torch.randn(0) + torch.randn(1)) - self.assertEqual(torch.randn(0, 7, 0, 6, 5, 0, 7), - torch.randn(0, 7, 0, 6, 5, 0, 1) + torch.randn(1, 1, 5, 1, 7)) - self.assertRaises(RuntimeError, lambda: torch.randn(7, 0) + torch.randn(2, 1)) - - def test_scalars_as_floats(self): - "zero-dim variables that don't require grad should bind to scalar arguments" - x = torch.tensor(2.) - y = torch.tensor(3.) - # 3 + (3 * 3) * 2 - self.assertEqual(y.addcmul(y, y, value=x), 21) - - x = torch.tensor(2., requires_grad=True) - self.assertRaises(Exception, lambda: y.addcmul(y, y, value=x)) - def test_copy_broadcast(self): torch.zeros(5, 6).copy_(torch.zeros(6)) self.assertRaises(RuntimeError, lambda: torch.zeros(5, 6).copy_(torch.zeros(30))) @@ -951,308 +541,6 @@ def test_copy_many_to_one(self): # storage to a single storage would cause RuntimeError to be thrown self.assertRaises(RuntimeError, lambda: torch.zeros(1, 6).expand(5, 6).copy_(torch.zeros(5, 6))) - def assertIsOrdered(self, order, x, mxx, ixx, task): - SIZE = 4 - if order == 'descending': - def check_order(a, b): - # `a != a` because we put NaNs - # at the end of ascending sorted lists, - # and the beginning of descending ones. - return a != a or a >= b - elif order == 'ascending': - def check_order(a, b): - # see above - return b != b or a <= b - else: - error('unknown order "{}", must be "ascending" or "descending"'.format(order)) - - are_ordered = True - for j, k in product(range(SIZE), range(1, SIZE)): - self.assertTrue(check_order(mxx[j][k - 1], mxx[j][k]), - 'torch.sort ({}) values unordered for {}'.format(order, task)) - - seen = set() - indicesCorrect = True - size = x.size(x.dim() - 1) - for k in range(size): - seen.clear() - for j in range(size): - self.assertEqual(x[k][ixx[k][j]], mxx[k][j], - msg='torch.sort ({}) indices wrong for {}'.format(order, task)) - seen.add(ixx[k][j]) - self.assertEqual(len(seen), size) - - def test_sort(self): - SIZE = 4 - x = torch.rand(SIZE, SIZE) - res1val, res1ind = torch.sort(x) - - # Test use of result tensor - res2val = torch.Tensor() - res2ind = torch.LongTensor() - torch.sort(x, out=(res2val, res2ind)) - self.assertEqual(res1val, res2val, atol=0, rtol=0) - self.assertEqual(res1ind, res2ind, atol=0, rtol=0) - self.assertEqual(torch.argsort(x), res1ind) - self.assertEqual(x.argsort(), res1ind) - - # Test sorting of random numbers - self.assertIsOrdered('ascending', x, res2val, res2ind, 'random') - - # Test simple sort - self.assertEqual( - torch.sort(torch.Tensor((50, 40, 30, 20, 10)))[0], - torch.Tensor((10, 20, 30, 40, 50)), - atol=0, rtol=0 - ) - - # Test that we still have proper sorting with duplicate keys - x = torch.floor(torch.rand(SIZE, SIZE) * 10) - torch.sort(x, out=(res2val, res2ind)) - self.assertIsOrdered('ascending', x, res2val, res2ind, 'random with duplicate keys') - - # DESCENDING SORT - x = torch.rand(SIZE, SIZE) - res1val, res1ind = torch.sort(x, x.dim() - 1, True) - - # Test use of result tensor - res2val = torch.Tensor() - res2ind = torch.LongTensor() - torch.sort(x, x.dim() - 1, True, out=(res2val, res2ind)) - self.assertEqual(res1val, res2val, atol=0, rtol=0) - self.assertEqual(res1ind, res2ind, atol=0, rtol=0) - self.assertEqual(torch.argsort(x, x.dim() - 1, True), res1ind) - self.assertEqual(x.argsort(x.dim() - 1, True), res1ind) - - # Test sorting of random numbers - self.assertIsOrdered('descending', x, res2val, res2ind, 'random') - - # Test simple sort task - self.assertEqual( - torch.sort(torch.Tensor((10, 20, 30, 40, 50)), 0, True)[0], - torch.Tensor((50, 40, 30, 20, 10)), - atol=0, rtol=0 - ) - - # Test that we still have proper sorting with duplicate keys - self.assertIsOrdered('descending', x, res2val, res2ind, 'random with duplicate keys') - - # Test sorting with NaNs - x = torch.rand(SIZE, SIZE) - x[1][2] = float('NaN') - x[3][0] = float('NaN') - torch.sort(x, out=(res2val, res2ind)) - self.assertIsOrdered('ascending', x, res2val, res2ind, - 'random with NaNs') - torch.sort(x, out=(res2val, res2ind), descending=True) - self.assertIsOrdered('descending', x, res2val, res2ind, - 'random with NaNs') - - def test_topk(self): - def topKViaSort(t, k, dim, dir): - sorted, indices = t.sort(dim, dir) - return sorted.narrow(dim, 0, k), indices.narrow(dim, 0, k) - - def compareTensors(t, res1, ind1, res2, ind2, dim): - # Values should be exactly equivalent - self.assertEqual(res1, res2, atol=0, rtol=0) - - # Indices might differ based on the implementation, since there is - # no guarantee of the relative order of selection - if not ind1.eq(ind2).all(): - # To verify that the indices represent equivalent elements, - # gather from the input using the topk indices and compare against - # the sort indices - vals = t.gather(dim, ind2) - self.assertEqual(res1, vals, atol=0, rtol=0) - - def compare(t, k, dim, dir): - topKVal, topKInd = t.topk(k, dim, dir, True) - sortKVal, sortKInd = topKViaSort(t, k, dim, dir) - compareTensors(t, sortKVal, sortKInd, topKVal, topKInd, dim) - - t = torch.rand(random.randint(1, SIZE), - random.randint(1, SIZE), - random.randint(1, SIZE)) - - for _kTries in range(3): - for _dimTries in range(3): - for transpose in (True, False): - for dir in (True, False): - testTensor = t - if transpose: - dim1 = random.randrange(t.ndimension()) - dim2 = dim1 - while dim1 == dim2: - dim2 = random.randrange(t.ndimension()) - - testTensor = t.transpose(dim1, dim2) - - dim = random.randrange(testTensor.ndimension()) - k = random.randint(1, testTensor.size(dim)) - compare(testTensor, k, dim, dir) - - def test_topk_arguments(self): - q = torch.randn(10, 2, 10) - # Make sure True isn't mistakenly taken as the 2nd dimension (interpreted as 1) - self.assertRaises(TypeError, lambda: q.topk(4, True)) - - def test_median(self): - for size in (155, 156): - x = torch.rand(size, size) - x0 = x.clone() - - nelem = x.nelement() - res1val = torch.median(x) - res2val, _ = torch.sort(x.view(nelem)) - ind = int(math.floor((nelem + 1) / 2) - 1) - - self.assertEqual(res2val[ind], res1val, atol=0, rtol=0) - - res1val, res1ind = torch.median(x, dim=1, keepdim=False) - res2val, res2ind = torch.sort(x) - ind = int(math.floor((size + 1) / 2) - 1) - - self.assertEqual(res2val.select(1, ind), res1val, atol=0, rtol=0) - self.assertEqual(res2val.select(1, ind), res1val, atol=0, rtol=0) - - # Test use of result tensor - res2val = torch.Tensor() - res2ind = torch.LongTensor() - torch.median(x, dim=-1, keepdim=False, out=(res2val, res2ind)) - self.assertEqual(res2val, res1val, atol=0, rtol=0) - self.assertEqual(res2ind, res1ind, atol=0, rtol=0) - - # Test non-default dim - res1val, res1ind = torch.median(x, 0, keepdim=False) - res2val, res2ind = torch.sort(x, 0) - self.assertEqual(res1val, res2val[ind], atol=0, rtol=0) - self.assertEqual(res1ind, res2ind[ind], atol=0, rtol=0) - - # input unchanged - self.assertEqual(x, x0, atol=0, rtol=0) - - def test_mode(self): - x = torch.arange(1., SIZE * SIZE + 1).clone().resize_(SIZE, SIZE) - x[:2] = 1 - x[:, :2] = 1 - x0 = x.clone() - - # Pre-calculated results. - res1val = torch.Tensor(SIZE).fill_(1) - # The indices are the position of the last appearance of the mode element. - res1ind = torch.LongTensor(SIZE).fill_(1) - res1ind[0] = SIZE - 1 - res1ind[1] = SIZE - 1 - - res2val, res2ind = torch.mode(x, keepdim=False) - self.assertEqual(res1val, res2val, atol=0, rtol=0) - self.assertEqual(res1ind, res2ind, atol=0, rtol=0) - - # Test use of result tensor - res2val = torch.Tensor() - res2ind = torch.LongTensor() - torch.mode(x, keepdim=False, out=(res2val, res2ind)) - self.assertEqual(res1val, res2val, atol=0, rtol=0) - self.assertEqual(res1ind, res2ind, atol=0, rtol=0) - - # Test non-default dim - res2val, res2ind = torch.mode(x, 0, False) - self.assertEqual(res1val, res2val, atol=0, rtol=0) - self.assertEqual(res1ind, res2ind, atol=0, rtol=0) - - # input unchanged - self.assertEqual(x, x0, atol=0, rtol=0) - - def test_trilu_indices(self): - for test_args in tri_tests_args: - _compare_trilu_indices(self, *test_args) - run_additional_tri_tests(self, 'cpu') - - # test default options - x = torch.ones( - 3, 3, dtype=torch.long, device='cpu', layout=torch.strided) - self.assertEqual( - x.tril(0).nonzero().transpose(0, 1), torch.tril_indices(3, 3)) - self.assertEqual( - x.triu(0).nonzero().transpose(0, 1), torch.triu_indices(3, 3)) - - # test stride 0 cases - x = torch.ones( - 3, 1, 3, 3, dtype=torch.long, device='cpu', layout=torch.strided) - output = x.triu(2).expand(3, 3, 3, 3) - b = x.clone().expand(3, 3, 3, 3) - self.assertEqual(b.triu(2), output) - self.assertRaises(RuntimeError, lambda: b.triu_(2)) - - def test_narrow(self): - x = torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) - self.assertEqual(x.narrow(0, 0, 1), torch.Tensor([[0, 1, 2]])) - self.assertEqual(x.narrow(0, 0, 2), torch.Tensor([[0, 1, 2], [3, 4, 5]])) - self.assertEqual(x.narrow(0, 1, 1), torch.Tensor([[3, 4, 5]])) - self.assertEqual(x.narrow(0, -1, 1), torch.Tensor([[6, 7, 8]])) - self.assertEqual(x.narrow(0, -2, 2), torch.Tensor([[3, 4, 5], [6, 7, 8]])) - self.assertEqual(x.narrow(0, -3, 3), torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])) - self.assertEqual(x.narrow(-1, -1, 1), torch.Tensor([[2], [5], [8]])) - self.assertEqual(x.narrow(-2, -1, 1), torch.Tensor([[6, 7, 8]])) - - def test_narrow_tensor(self): - x = torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) - self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.Tensor([[0, 1, 2]])) - with self.assertRaises(Exception): - x.narrow(0, torch.tensor(0.), 1) - with self.assertRaises(Exception): - x.narrow(0, torch.tensor([0]), 1) - with self.assertRaises(Exception): - x.narrow(0, torch.tensor([0, 1]), 1) - - def test_stack(self): - for dtype in (torch.half, torch.double, torch.int): - x = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) - y = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) - z = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) - for dim in range(4): - res = torch.stack((x, y, z), dim) - res_neg = torch.stack((x, y, z), dim - 4) - expected_size = x.size()[:dim] + (3,) + x.size()[dim:] - self.assertEqual(res, res_neg) - self.assertEqual(res.size(), expected_size) - self.assertEqual(res.select(dim, 0), x, atol=0, rtol=0) - self.assertEqual(res.select(dim, 1), y, atol=0, rtol=0) - self.assertEqual(res.select(dim, 2), z, atol=0, rtol=0) - - def test_stack_out(self): - for dtype in (torch.half, torch.double, torch.int): - x = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) - y = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) - z = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) - for dim in range(4): - expected_size = x.size()[:dim] + (3,) + x.size()[dim:] - res_out = x.new(expected_size) - res_neg_out = x.new(expected_size) - res_out_dp = res_out.data_ptr() - res_out_neg_dp = res_neg_out.data_ptr() - torch.stack((x, y, z), dim, out=res_out) - torch.stack((x, y, z), dim - 4, out=res_neg_out) - self.assertEqual(res_out, res_neg_out) - self.assertEqual(res_out.size(), expected_size) - self.assertEqual(res_out_dp, res_out.data_ptr()) - self.assertEqual(res_out_neg_dp, res_neg_out.data_ptr()) - self.assertEqual(res_out.select(dim, 0), x, atol=0, rtol=0) - self.assertEqual(res_out.select(dim, 1), y, atol=0, rtol=0) - self.assertEqual(res_out.select(dim, 2), z, atol=0, rtol=0) - - def test_unbind(self): - x = torch.rand(2, 3, 4, 5) - for dim in range(4): - res = torch.unbind(x, dim) - res2 = x.unbind(dim) - self.assertEqual(x.size(dim), len(res)) - self.assertEqual(x.size(dim), len(res2)) - for i in range(dim): - self.assertEqual(x.select(dim, i), res[i]) - self.assertEqual(x.select(dim, i), res2[i]) - def test_slice(self): empty = torch.empty(0, 4) x = torch.arange(0., 16).view(4, 4) @@ -1271,38 +559,6 @@ def test_slice(self): self.assertEqual(x[:, -2:3].tolist(), [[2], [6], [10], [14]]) self.assertEqual(x[0:-1:2].tolist(), [[0, 1, 2, 3], [8, 9, 10, 11]]) - @skipIfNoLapack - def test_ormqr(self): - mat1 = torch.randn(7, 7) - mat2 = torch.randn(7, 7) - q, r = torch.qr(mat1) - m, tau = torch.geqrf(mat1) - out_holder = torch.empty_like(mat1) - - res1 = torch.mm(q, mat2) - res2 = torch.ormqr(m, tau, mat2, left=True, transpose=False) - torch.ormqr(m, tau, mat2, out=out_holder) - self.assertEqual(res1, res2) - self.assertEqual(res2, out_holder) - - res1 = torch.mm(mat2, q) - res2 = torch.ormqr(m, tau, mat2, left=False, transpose=False) - torch.ormqr(m, tau, mat2, left=False, transpose=False, out=out_holder) - self.assertEqual(res1, res2) - self.assertEqual(res2, out_holder) - - res1 = torch.mm(q.t(), mat2) - res2 = torch.ormqr(m, tau, mat2, left=True, transpose=True) - torch.ormqr(m, tau, mat2, left=True, transpose=True, out=out_holder) - self.assertEqual(res1, res2) - self.assertEqual(res2, out_holder) - - res1 = torch.mm(mat2, q.t()) - res2 = torch.ormqr(m, tau, mat2, left=False, transpose=True) - torch.ormqr(m, tau, mat2, left=False, transpose=True, out=out_holder) - self.assertEqual(res1, res2) - self.assertEqual(res2, out_holder) - @unittest.skip("Not implemented yet") def test_conv2(self): x = torch.rand(math.floor(torch.uniform(50, 100)), math.floor(torch.uniform(50, 100))) @@ -1536,16 +792,6 @@ def test_numel(self): self.assertEqual(b.nelement(), 3 * 100 * 100) self.assertEqual(b.numel(), 3 * 100 * 100) - # Note: the warning this tests for only appears once per program, so - # other instances of this warning should be addressed to avoid - # the tests depending on the order in which they're run. - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_numpy_non_writeable(self): - arr = np.zeros(5) - arr.flags['WRITEABLE'] = False - self.assertWarns(UserWarning, lambda: torch.from_numpy(arr)) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") def test_empty_storage_view(self): # we should be able to "modify" slices of a 0-element # array without an error being raised due to @@ -1553,7 +799,6 @@ def test_empty_storage_view(self): t = torch.from_numpy(np.empty((0, 4))) t[:, 1::2] *= 1 - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") def test_newaxis_numpy_comparison(self): def run_test(tensor, *idx): npt = tensor.numpy() @@ -1654,23 +899,25 @@ def checkPartialAssign(index): reference[0.0, :, 0.0] = 1 def test_index_add(self): - for dest_contig, src_contig, index_contig in product([True, False], repeat=3): - for other_sizes in ((), (4, 5)): - num_copy, num_dest = 3, 3 - dest = torch.randn(num_dest, *other_sizes) - if not dest_contig: - dest = torch.testing.make_non_contiguous(dest) - src = torch.randn(num_copy, *other_sizes) - if not src_contig: - src = torch.testing.make_non_contiguous(src) - idx = torch.randperm(num_dest).narrow(0, 0, num_copy) - if not index_contig: - idx = torch.testing.make_non_contiguous(idx) - dest2 = dest.clone() - dest.index_add_(0, idx, src) - for i in range(idx.size(0)): - dest2[idx[i]] += src[i] - self.assertEqual(dest, dest2) + for device in torch.testing.get_all_device_types(): + for dest_contig, src_contig, index_contig in product([True, False], repeat=3): + for other_sizes in ((), (4, 5)): + for dtype in [torch.int, torch.long]: + num_copy, num_dest = 3, 3 + dest = torch.randn(num_dest, *other_sizes, device=device) + if not dest_contig: + dest = torch.testing.make_non_contiguous(dest) + src = torch.randn(num_copy, *other_sizes, device=device) + if not src_contig: + src = torch.testing.make_non_contiguous(src) + idx = torch.randperm(num_dest, dtype=dtype, device=device).narrow(0, 0, num_copy) + if not index_contig: + idx = torch.testing.make_non_contiguous(idx) + dest2 = dest.clone() + dest.index_add_(0, idx, src) + for i in range(idx.size(0)): + dest2[idx[i]] += src[i] + self.assertEqual(dest, dest2) # add coverage for issue with atomic add that appeared only for # specific dtypes on cuda: @@ -1678,50 +925,20 @@ def test_index_add(self): def test_index_add_all_dtypes(self): for device in torch.testing.get_all_device_types(): for dtype in torch.testing.get_all_math_dtypes(device): - size = [5, 5] - if dtype.is_floating_point or dtype.is_complex: - tensor = torch.rand(size, dtype=dtype, device=device) - elif dtype.is_signed: - tensor = torch.randint(-5, 15, size, dtype=dtype, device=device) - else: - tensor = torch.randint(0, 10, size, dtype=dtype, device=device) - - # index_add calls atomicAdd on cuda. - zeros = torch.zeros(size, dtype=dtype, device=device) + for idx_dtype in [torch.int, torch.long]: + size = [5, 5] + if dtype.is_floating_point or dtype.is_complex: + tensor = torch.rand(size, dtype=dtype, device=device) + elif dtype.is_signed: + tensor = torch.randint(-5, 15, size, dtype=dtype, device=device) + else: + tensor = torch.randint(0, 10, size, dtype=dtype, device=device) - # index_add is not supported for complex dtypes on cuda yet - if device.startswith('cuda') and dtype.is_complex: - continue + # index_add calls atomicAdd on cuda. + zeros = torch.zeros(size, dtype=dtype, device=device) - added = zeros.index_add(0, torch.arange(0, size[0], dtype=torch.long, device=device), tensor) - self.assertEqual(added, tensor) - - def test_t(self): - # Test 0D tensors - x = torch.randn(()) - self.assertEqual(x, x.t()) - x = x.to_sparse() - self.assertEqual(x, x.t()) - - # Test 1D tensors - x = torch.arange(4) - self.assertEqual(x, x.t()) - x = x.to_sparse() - self.assertEqual(x, x.t()) - - # Test 2D tensors - x = torch.rand((2, 2)) - self.assertEqual(x.t(), x.transpose(0, 1)) - x = x.to_sparse() - self.assertEqual(x.t(), x.transpose(0, 1)) - - # Test 3D tensor - x = torch.rand((2, 2, 2)) - with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'): - x.t() - x = x.to_sparse() - with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'): - x.t() + added = zeros.index_add(0, torch.arange(0, size[0], dtype=idx_dtype, device=device), tensor) + self.assertEqual(added, tensor) def test_take(self): def check(src, idx): @@ -1771,63 +988,6 @@ def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o): ii[dim] = slice(0, idx.size(dim) + 1) idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row] - def test_flatten(self): - # Test that flatten returns 1-dim tensor when given a 0-dim tensor - zero_dim_tensor = torch.tensor(123) - flat0 = zero_dim_tensor.flatten() - one_dim_tensor = torch.tensor([123]) - flat1 = zero_dim_tensor.flatten() - - self.assertEqual(zero_dim_tensor.shape, torch.Size([])) - self.assertEqual(flat0.shape, torch.Size([1])) - self.assertEqual(one_dim_tensor.shape, torch.Size([1])) - self.assertEqual(flat1.shape, torch.Size([1])) - self.assertEqual(flat0, one_dim_tensor) - self.assertEqual(flat0, flat1) - self.assertEqual(flat0.shape, flat1.shape) - - # Test both float tensor and quantized tensor - tensors = [torch.randn(5, 5, 5, 5), - torch._empty_affine_quantized([5, 5, 5, 5], - scale=2, - zero_point=3, - dtype=torch.quint8)] - for src in tensors: - flat = src.flatten(0, -1) - self.assertEqual(flat.shape, torch.Size([625])) - self.assertEqual(src.view(-1), flat.view(-1)) - - flat = src.flatten(0, 2) - self.assertEqual(flat.shape, torch.Size([125, 5])) - self.assertEqual(src.view(-1), flat.view(-1)) - - flat = src.flatten(0, 1) - self.assertEqual(flat.shape, torch.Size([25, 5, 5])) - self.assertEqual(src.view(-1), flat.view(-1)) - - flat = src.flatten(1, 2) - self.assertEqual(flat.shape, torch.Size([5, 25, 5])) - self.assertEqual(src.view(-1), flat.view(-1)) - - flat = src.flatten(2, 3) - self.assertEqual(flat.shape, torch.Size([5, 5, 25])) - self.assertEqual(src.view(-1), flat.view(-1)) - - flat = src.flatten(-2, -1) - self.assertEqual(flat.shape, torch.Size([5, 5, 25])) - self.assertEqual(src.view(-1), flat.view(-1)) - - flat = src.flatten(2, 2) - self.assertEqual(flat, src) - - # out of bounds index - with self.assertRaisesRegex(IndexError, 'Dimension out of range'): - src.flatten(5, 10) - - # invalid start and end - with self.assertRaisesRegex(RuntimeError, 'start_dim cannot come after end_dim'): - src.flatten(2, 0) - def test_unflatten(self): # test args: tensor, int, sizes self.assertEqual(torch.tensor([]).unflatten(0, (0, 1)), torch.empty(0, 1)) @@ -2104,26 +1264,6 @@ def test_masked_fill(self): for wi in w: self.assertEqual(str(wi.message)[0:52], str(warn)) - - def test_unbiased(self): - tensor = torch.randn(100) - self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True)) - self.assertEqual(tensor.var(), tensor.var(unbiased=True)) - self.assertEqual(tensor.var(unbiased=False), tensor.var(0, unbiased=False)) - - tensor = torch.FloatTensor([1.0, 2.0]) - self.assertEqual(tensor.var(unbiased=True), 0.5) - self.assertEqual(tensor.var(unbiased=False), 0.25) - - tensor = torch.FloatTensor([1.0, 2.0, 3.0]) - self.assertEqual(tensor.var(unbiased=True), 1.0) - self.assertEqual(tensor.var(unbiased=False), 2.0 / 3.0) - - tensor = torch.randn(100) - self.assertEqual(tensor.std(0), tensor.std(0, unbiased=True)) - self.assertEqual(tensor.std(), tensor.std(unbiased=True)) - self.assertEqual(tensor.std(unbiased=False), tensor.std(0, unbiased=False)) - def test_structseq_repr(self): a = torch.arange(250).reshape(5, 5, 10) expected = """ @@ -2140,278 +1280,6 @@ def test_structseq_repr(self): [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]]))""" self.assertEqual(repr(a.max(1)), textwrap.dedent(expected).strip()) - def test_var_stability(self): - tensor = torch.FloatTensor([2281.5, 2281.25]) - self.assertEqual(tensor.var(dim=0), 0.03125) - self.assertEqual(tensor.var(), 0.03125) - - # TODO: this should be refactored into the view ops test suite - def test_view_empty(self): - x = torch.randn(0, 6) - self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape) - - # TODO: this should be refactored into the view ops test suite - def test_reshape(self): - x = torch.randn(3, 3) - self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr()) - self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr()) - self.assertEqual(torch.reshape(x, (9,)), x.reshape(9)) - self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1)) - - y = torch.randn(4, 4, 4)[:, 0, :] - self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr()) - self.assertEqual(y.contiguous().view(-1), y.reshape(-1)) - self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr()) - - s = torch.randn(()) - self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr()) - self.assertEqual(s.reshape(-1).shape, (1,)) - self.assertRaises(RuntimeError, lambda: s.reshape(2)) - - empty = torch.tensor([]) - self.assertEqual(empty, empty.reshape(-1)) - self.assertEqual(empty, empty.reshape([0])) - # TODO: fix these once we have multi-dimensional empty tensors - self.assertEqual(empty.reshape([0, 1]).shape, (0, 1)) - self.assertEqual(empty.reshape([1, -1]).shape, (1, 0)) - self.assertRaises(RuntimeError, lambda: empty.reshape(1)) - - x = torch.randn(3, 3) - self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr()) - self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr()) - self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10))) - - # TODO: this should be refactored into the view ops test suite - def test_empty_reshape(self): - x = torch.randn(0, 6) - self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape) - # should be viewable -- i.e. data_ptr is the same. - self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr()) - - # match NumPy semantics -- don't infer the size of dimension with a degree of freedom - self.assertRaises(RuntimeError, lambda: x.reshape(0, -1)) - - def check_single_matmul(self, x, y, shape): - a = np.array(x, copy=False) - b = np.array(y, copy=False) - expected = np.matmul(a, b) - - ans = torch.matmul(x, y) - self.assertTrue(ans.is_contiguous()) - self.assertTrue(np.array_equal(ans, expected)) - - out = torch.zeros(*shape, dtype=torch.int64) - ans = torch.matmul(x, y, out=out) - self.assertIs(ans, out) - self.assertTrue(ans.is_contiguous()) - self.assertTrue(np.array_equal(ans, expected)) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_matmul_small_brute_force_1d_Nd(self): - # Issue #20452: range(0, 10) does not work. - n = 1 - for m in range(1, 8): - for p in range(1, 8): - for o in range(1, 5): - # 1d, 3d, inner dimensions C - x = torch.arange(m) - y = torch.arange(o * m * p).reshape(o, m, p) - self.check_single_matmul(x, y, (o, n, p)) - - # 1d, 3d, inner dimensions Fortran - x = torch.arange(m) - y = torch.arange(o * p * m).reshape(o, p, m).transpose(-1, -2) - self.check_single_matmul(x, y, (o, n, p)) - - # 1d, 3d, inner dimensions non-contiguous - x = torch.arange(2 * m)[::2] - y = torch.arange(o * m * 2 * p).reshape(o, m, 2 * p)[:, :, ::2] - self.check_single_matmul(x, y, (o, n, p)) - - for r in range(1, 5): - # 1d, 4d, inner dimensions C - x = torch.arange(m) - y = torch.arange(r * o * m * p).reshape(r, o, m, p) - self.check_single_matmul(x, y, (r, o, n, p)) - - # 1d, 4d, inner dimensions Fortran - x = torch.arange(m) - y = torch.arange(r * o * p * m).reshape(r, o, p, m).transpose(-1, -2) - self.check_single_matmul(x, y, (r, o, n, p)) - - # 1d, 4d, inner dimensions non-contiguous - x = torch.arange(2 * m)[::2] - y = torch.arange(r * o * m * 2 * p).reshape(r, o, m, 2 * p)[:, :, :, ::2] - self.check_single_matmul(x, y, (r, o, n, p)) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_matmul_small_brute_force_2d_Nd(self): - # Issue #20452: range(0, 10) does not work. - for n in range(1, 5): - for m in range(1, 5): - for p in range(1, 5): - for o in range(1, 3): - # 2d, 3d, inner dimensions C - x = torch.arange(n * m).reshape(n, m) - y = torch.arange(o * m * p).reshape(o, m, p) - self.check_single_matmul(x, y, (o, n, p)) - - # 2d, 3d, inner dimensions Fortran - x = torch.arange(m * n).reshape(m, n).transpose(-1, -2) - y = torch.arange(o * p * m).reshape(o, p, m).transpose(-1, -2) - self.check_single_matmul(x, y, (o, n, p)) - - # 2d, 3d, inner dimensions non-contiguous - x = torch.arange(n * 2 * m).reshape(n, 2 * m)[:, ::2] - y = torch.arange(o * m * 2 * p).reshape(o, m, 2 * p)[:, :, ::2] - self.check_single_matmul(x, y, (o, n, p)) - - for r in range(1, 2): - # 2d, 4d, inner dimensions C - x = torch.arange(n * m).reshape(n, m) - y = torch.arange(r * o * m * p).reshape(r, o, m, p) - self.check_single_matmul(x, y, (r, o, n, p)) - - # 2d, 4d, inner dimensions Fortran - x = torch.arange(m * n).reshape(m, n).transpose(-1, -2) - y = torch.arange(r * o * p * m).reshape(r, o, p, m).transpose(-1, -2) - self.check_single_matmul(x, y, (r, o, n, p)) - - # 2d, 4d, inner dimensions non-contiguous - x = torch.arange(n * 2 * m).reshape(n, 2 * m)[:, ::2] - y = torch.arange(r * o * m * 2 * p).reshape(r, o, m, 2 * p)[:, :, :, ::2] - self.check_single_matmul(x, y, (r, o, n, p)) - - def test_expand(self): - tensor = torch.rand(1, 8, 1) - tensor2 = torch.rand(5) - template = torch.rand(4, 8, 5) - target = template.size() - self.assertEqual(tensor.expand_as(template).size(), target) - self.assertEqual(tensor.expand(4, 8, 5).size(), target) - self.assertEqual(tensor.expand(target).size(), target) - self.assertEqual(tensor2.expand_as(template).size(), target) - self.assertEqual(tensor2.expand(4, 8, 5).size(), target) - self.assertEqual(tensor2.expand(target).size(), target) - - # test double expand - self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1)) - - # test non-contiguous - noncontig = torch.randn(5, 2, 1, 3)[:, 0] - self.assertFalse(noncontig.is_contiguous()) - self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1)) - - # make sure it's compatible with unsqueeze - expanded = tensor2.expand(1, 1, 5) - unsqueezed = tensor2.unsqueeze(0).unsqueeze(1) - self.assertEqual(expanded, unsqueezed) - self.assertEqual(expanded.stride(), unsqueezed.stride()) - - # test -1 as target size - self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5)) - self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1)) - - # test expanding empty to empty - self.assertEqual(torch.zeros(0).expand((0,)), torch.zeros(0)) - - def test_repeat(self): - initial_shape = (8, 4) - tensor = torch.rand(*initial_shape) - - size = (3, 1, 1) - torchSize = torch.Size(size) - target = [3, 8, 4] - self.assertEqual(tensor.repeat(*size).size(), target, msg='Error in repeat') - self.assertEqual(tensor.repeat(torchSize).size(), target, - msg='Error in repeat using LongStorage') - result = tensor.repeat(*size) - self.assertEqual(result.size(), target, msg='Error in repeat using result') - result = tensor.repeat(torchSize) - self.assertEqual(result.size(), target, msg='Error in repeat using result and LongStorage') - self.assertEqual(result.mean(0).view(8, 4), tensor, msg='Error in repeat (not equal)') - - zeroDimTarget = torch.Size([24, 0]) - self.assertEqual(tensor.repeat((3, 0)).size(), zeroDimTarget, msg="Error when calling with 0 repeats") - - def test_repeat_interleave(self): - x = torch.tensor([0, 1, 2, 3]) - expected = torch.tensor([1, 2, 2, 3, 3, 3]) - self.assertEqual(torch.repeat_interleave(x), expected) - - with self.assertRaises(RuntimeError): - torch.repeat_interleave(torch.arange(4).reshape(2, 2)) - - with self.assertRaises(RuntimeError): - torch.repeat_interleave(torch.arange(4.0)) - - with self.assertRaises(RuntimeError): - torch.repeat_interleave(torch.tensor([1, 2, -1, 3, 4])) - - y = torch.tensor([[1, 2], [3, 4]]) - - y1_v1 = torch.repeat_interleave(y, 2) - y1_v2 = torch.repeat_interleave(y, torch.tensor(2)) - y1_v3 = torch.repeat_interleave(y, torch.tensor([2])) - y1_expect = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4]) - self.assertEqual(y1_v1, y1_expect) - self.assertEqual(y1_v2, y1_expect) - self.assertEqual(y1_v3, y1_expect) - - y2 = torch.repeat_interleave(y, 3, dim=1) - y2_expect = torch.tensor([[1, 1, 1, 2, 2, 2], - [3, 3, 3, 4, 4, 4]]) - self.assertEqual(y2, y2_expect) - - y3 = torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) - y3_expect = torch.tensor([[1, 2], - [3, 4], - [3, 4]]) - self.assertEqual(y3, y3_expect) - - with self.assertRaises(RuntimeError): - torch.repeat_interleave(y, torch.tensor([1, 2, 3]), dim=0) - - with self.assertRaises(RuntimeError): - torch.repeat_interleave(y, torch.arange(9).reshape(3, 3), dim=0) - - # test zero sized dimension - x = torch.zeros((5, 0)) - y = torch.repeat_interleave(x, repeats=3, dim=1) - self.assertEqual(y, x.new_zeros(5, 0)) - - x = torch.tensor([], dtype=torch.int64) - y = torch.repeat_interleave(x, x) - self.assertEqual(y, x) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_repeat_tile(self): - - initial_shape = (8, 4) - - repeats = ((3, 1, 1), - (3, 3, 3), - (1, 2, 1), - (2, 2, 2, 2)) - - def _generate_noncontiguous_input(): - - out = np.broadcast_to(np.random.random((1, 4)), - initial_shape) - # Note: non-writeable NumPy arrays will warn if converted to tensors - out.setflags(write=True) - - assert not (out.flags.c_contiguous or out.flags.f_contiguous) - - return out - - for repeat in repeats: - for tensor in (torch.from_numpy(np.random.random(initial_shape)), - torch.from_numpy(_generate_noncontiguous_input()),): - - self.assertEqual(tensor.repeat(*repeat).numpy(), - np.tile(tensor.numpy(), repeat)) - def test_is_same_size(self): t1 = torch.Tensor(3, 4, 9, 10) t2 = torch.Tensor(3, 4) @@ -2563,90 +1431,11 @@ def test_element_size(self): self.assertGreaterEqual(long, int) self.assertGreaterEqual(double, float) - def test_split(self): - tensor = torch.rand(7, 4) - split_size = 3 - dim = 0 - target_sizes = ([3, 4], [3, 4], [1, 4]) - splits = tensor.split(split_size, dim) - start = 0 - for target_size, split in zip(target_sizes, splits): - self.assertEqual(split.size(), target_size) - self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, - atol=0, rtol=0) - start = start + target_size[dim] - - # Variable sections split - tensor = torch.randn(20, 10) - dim = 0 - split_sizes = [5, 5, 10] - target_sizes = ([[5, 10], [5, 10], [10, 10]]) - splits = tensor.split(split_sizes, dim) - start = 0 - for target_size, split in zip(target_sizes, splits): - self.assertEqual(split.size(), target_size) - self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, - atol=0, rtol=0) - start = start + target_size[dim] - - split_sizes = [2, 2, 6] - target_sizes = ([20, 2], [20, 2], [20, 6]) - dim = 1 - splits = tensor.split(split_sizes, dim) - start = 0 - for target_size, split in zip(target_sizes, splits): - self.assertEqual(split.size(), target_size) - self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, - atol=0, rtol=0) - start = start + target_size[dim] - - def test_chunk(self): - tensor = torch.rand(4, 7) - num_chunks = 3 - dim = 1 - target_sizes = ([4, 3], [4, 3], [4, 1]) - splits = tensor.chunk(num_chunks, dim) - start = 0 - for target_size, split in zip(target_sizes, splits): - self.assertEqual(split.size(), target_size) - self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, - atol=0, rtol=0) - start = start + target_size[dim] - - # Invalid chunk sizes - error_regex = 'chunk expects.*greater than 0' - with self.assertRaisesRegex(RuntimeError, error_regex): - tensor.chunk(0) - with self.assertRaisesRegex(RuntimeError, error_regex): - tensor.chunk(-2) - - def test_tolist(self): - list0D = [] - tensor0D = torch.Tensor(list0D) - self.assertEqual(tensor0D.tolist(), list0D) - - table1D = [1, 2, 3] - tensor1D = torch.Tensor(table1D) - storage = torch.Storage(table1D) - self.assertEqual(tensor1D.tolist(), table1D) - self.assertEqual(storage.tolist(), table1D) - self.assertEqual(tensor1D.tolist(), table1D) - self.assertEqual(storage.tolist(), table1D) - - table2D = [[1, 2], [3, 4]] - tensor2D = torch.Tensor(table2D) - self.assertEqual(tensor2D.tolist(), table2D) - - tensor3D = torch.Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) - tensorNonContig = tensor3D.select(1, 1) - self.assertFalse(tensorNonContig.is_contiguous()) - self.assertEqual(tensorNonContig.tolist(), [[3, 4], [7, 8]]) - def test_permute(self): orig = [1, 2, 3, 4, 5, 6, 7] perm = torch.randperm(7).tolist() x = torch.Tensor(*orig).fill_(0) - new = list(map(lambda x: x - 1, x.permute(*perm).size())) + new = [i - 1 for i in x.permute(*perm).size()] self.assertEqual(perm, new) self.assertEqual(x.size(), orig) @@ -2893,25 +1682,22 @@ def test_parsing_intlist(self): self.assertEqual(torch.Size([3, 4]), torch.ones((torch.tensor(3), torch.tensor(4))).shape) self.assertEqual(torch.Size([3, 4]), torch.ones(torch.tensor(3), torch.tensor(4)).shape) # parse with numpy integers - if TEST_NUMPY: - self.assertEqual(torch.Size([3, 4]), torch.ones((np.array(3), np.int64(4))).shape) - self.assertEqual(torch.Size([3, 4]), torch.ones(np.array(3), np.int64(4)).shape) - self.assertEqual(torch.Size([3, 4]), torch.ones((np.int64(3), np.array(4))).shape) - self.assertEqual(torch.Size([3, 4]), torch.ones(np.int64(3), np.array(4)).shape) + self.assertEqual(torch.Size([3, 4]), torch.ones((np.array(3), np.int64(4))).shape) + self.assertEqual(torch.Size([3, 4]), torch.ones(np.array(3), np.int64(4)).shape) + self.assertEqual(torch.Size([3, 4]), torch.ones((np.int64(3), np.array(4))).shape) + self.assertEqual(torch.Size([3, 4]), torch.ones(np.int64(3), np.array(4)).shape) # fail parse with float variables self.assertRaises(TypeError, lambda: torch.ones((torch.tensor(3.), torch.tensor(4)))) # fail parse with numpy floats - if TEST_NUMPY: - self.assertRaises(TypeError, lambda: torch.ones((np.float(3.), torch.tensor(4)))) - self.assertRaises(TypeError, lambda: torch.ones((np.array(3.), torch.tensor(4)))) + self.assertRaises(TypeError, lambda: torch.ones((np.float(3.), torch.tensor(4)))) + self.assertRaises(TypeError, lambda: torch.ones((np.array(3.), torch.tensor(4)))) # fail parse with > 1 element variables self.assertRaises(TypeError, lambda: torch.ones(torch.tensor(3, 3))) self.assertRaises(TypeError, lambda: torch.ones((torch.tensor(3, 3)))) - if TEST_NUMPY: - self.assertRaises(TypeError, lambda: torch.ones(np.array(3, 3))) - self.assertRaises(TypeError, lambda: torch.ones((np.array(3, 3)))) + self.assertRaises(TypeError, lambda: torch.ones(np.array(3, 3))) + self.assertRaises(TypeError, lambda: torch.ones((np.array(3, 3)))) # fail parse with additional positional args after intlist arg self.assertRaisesRegex(TypeError, @@ -2922,20 +1708,34 @@ def test_parsing_intlist(self): lambda: torch.tensor().new_zeros((5, 5), 0)) def test_half_tensor(self): - x = torch.randn(5, 5).float() - y = torch.randn(5, 5).float() - xh, yh = x.half(), y.half() - - self.assertEqual(x.half().float(), x, atol=1e-3, rtol=0) + devices = ["cpu"] + if torch.cuda.is_available(): + devices.append("cuda") - z = torch.Tensor(5, 5) - self.assertEqual(z.copy_(xh), x, atol=1e-3, rtol=0) + # contiguous tensor + # non-contiguous tensor + # dense non-overlapping tensor + # non-dense non-overlapping sliced tensor + # non-dense overlapping equal strides + for device in devices: + tset = ( + torch.randn(4, 3, 2, device=device, dtype=torch.float).contiguous(), + torch.randn(4, 3, 2, device=device, dtype=torch.float).transpose(0, 1), + torch.randn(4, 3, 2, device=device, dtype=torch.float), + torch.randn(4, 3, 2, device=device, dtype=torch.float)[:, :, ::2], + torch.empty_strided( + (4, 2, 3), (10, 3, 3), device=device, dtype=torch.float + ).copy_(torch.rand((4, 2, 3), dtype=torch.float, device=device)), + ) - with tempfile.NamedTemporaryFile() as f: - torch.save(xh, f) - f.seek(0) - xh2 = torch.load(f) - self.assertEqual(xh.float(), xh2.float()) + for x in tset: + self.assertEqual(x.half().float(), x, atol=1e-3, rtol=0) + xh = x.half() + with tempfile.NamedTemporaryFile() as f: + torch.save(xh, f) + f.seek(0) + xh2 = torch.load(f) + self.assertEqual(xh.float(), xh2.float()) def test_from_buffer(self): a = bytearray([1, 2, 3, 4]) @@ -3048,15 +1848,14 @@ def test_storage_casts(self): self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage') self.assertIs(complexdouble_storage.dtype, torch.complex128) - @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows") def test_from_file(self): - size = 10000 - with tempfile.NamedTemporaryFile() as f: - s1 = torch.FloatStorage.from_file(f.name, True, size) + def assert_with_filename(filename): + size = 10000 + s1 = torch.FloatStorage.from_file(filename, True, size) t1 = torch.FloatTensor(s1).copy_(torch.randn(size)) # check mapping - s2 = torch.FloatStorage.from_file(f.name, True, size) + s2 = torch.FloatStorage.from_file(filename, True, size) t2 = torch.FloatTensor(s2) self.assertEqual(t1, t2, atol=0, rtol=0) @@ -3070,15 +1869,24 @@ def test_from_file(self): t2.fill_(rnum) self.assertEqual(t1, t2, atol=0, rtol=0) - @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows") + # release the tensors + del s1, t1, s2, t2 + + with TemporaryFileName() as fname: + assert_with_filename(fname) + + if IS_FILESYSTEM_UTF8_ENCODING: + with TemporaryDirectoryName(suffix='中文') as dname, TemporaryFileName(dir=dname) as fname: + assert_with_filename(fname) + def test_torch_from_file(self): - size = 10000 - with tempfile.NamedTemporaryFile() as f: - s1 = torch.from_file(f.name, True, size, dtype=torch.float) + def assert_with_filename(filename): + size = 10000 + s1 = torch.from_file(filename, True, size, dtype=torch.float) t1 = torch.FloatTensor(s1).copy_(torch.randn(size)) # check mapping - s2 = torch.from_file(f.name, True, size, dtype=torch.float) + s2 = torch.from_file(filename, True, size, dtype=torch.float) t2 = torch.FloatTensor(s2) self.assertEqual(t1, t2, atol=0, rtol=0) @@ -3092,6 +1900,16 @@ def test_torch_from_file(self): t2.fill_(rnum) self.assertEqual(t1, t2, atol=0, rtol=0) + # release the tensors + del s1, t1, s2, t2 + + with TemporaryFileName() as fname: + assert_with_filename(fname) + + if IS_FILESYSTEM_UTF8_ENCODING: + with TemporaryDirectoryName(suffix='中文') as dname, TemporaryFileName(dir=dname) as fname: + assert_with_filename(fname) + def test_print(self): default_type = torch.Tensor().type() for t in torch._tensor_classes: @@ -3356,20 +2174,6 @@ def test_sizeof(self) -> None: self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10) self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0) - def test_unsqueeze(self) -> None: - x = torch.randn(2, 3, 4) - y = x.unsqueeze(1) - self.assertEqual(y, x.view(2, 1, 3, 4)) - y = x.clone().unsqueeze_(2) - self.assertEqual(y, x.view(2, 3, 1, 4)) - - x = x[:, 1] - self.assertFalse(x.is_contiguous()) - y = x.unsqueeze(1) - self.assertEqual(y, x.contiguous().view(2, 1, 4)) - y = x.clone().unsqueeze_(2) - self.assertEqual(y, x.contiguous().view(2, 4, 1)) - def test_iter(self) -> None: x = torch.randn(5, 5) for i, sub in enumerate(x): @@ -3378,16 +2182,6 @@ def test_iter(self) -> None: x = torch.Tensor() self.assertEqual(list(x), []) - def test_accreal_type(self) -> None: - x = torch.ones(2, 3, 4) - self.assertIsInstance(x.double().sum().item(), float) - self.assertIsInstance(x.float().sum().item(), float) - self.assertIsInstance(x.long().sum().item(), int) - self.assertIsInstance(x.int().sum().item(), int) - self.assertIsInstance(x.short().sum().item(), int) - self.assertIsInstance(x.char().sum().item(), int) - self.assertIsInstance(x.byte().sum().item(), int) - def test_assertEqual(self) -> None: x = torch.FloatTensor([0]) self.assertEqual(x, 0) @@ -3416,9 +2210,8 @@ def test_new(self) -> None: self.assertEqual(x.new([3, 4]).shape, [2]) self.assertEqual(x.new([3, 4]).tolist(), [3, 4]) self.assertEqual(x.new((3, 4)).tolist(), [3, 4]) - if TEST_NUMPY: - self.assertEqual(x.new([np.int32(3), np.float64(4)]).tolist(), [3, 4]) - self.assertEqual(x.new(np.array((3, 4))).tolist(), [3, 4]) + self.assertEqual(x.new([np.int32(3), np.float64(4)]).tolist(), [3, 4]) + self.assertEqual(x.new(np.array((3, 4))).tolist(), [3, 4]) self.assertEqual(x.new([z[2], z[0] + 3]).tolist(), [3, 4]) self.assertEqual(x.new(size=(3, 4)).shape, [3, 4]) self.assertEqual(x.new(()).shape, [0]) @@ -3445,360 +2238,6 @@ def test_pin_memory(self): self.assertIs(pinned, pinned.pin_memory()) self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr()) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_numpy_unresizable(self) -> None: - x = np.zeros((2, 2)) - y = torch.from_numpy(x) - with self.assertRaises(ValueError): - x.resize((5, 5)) - - z = torch.randn(5, 5) - w = z.numpy() - with self.assertRaises(RuntimeError): - z.resize_(10, 10) - with self.assertRaises(ValueError): - w.resize((10, 10)) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_to_numpy(self) -> None: - def get_castable_tensor(shape, dtype): - if dtype.is_floating_point: - dtype_info = torch.finfo(dtype) - # can't directly use min and max, because for double, max - min - # is greater than double range and sampling always gives inf. - low = max(dtype_info.min, -1e10) - high = min(dtype_info.max, 1e10) - t = torch.empty(shape, dtype=torch.float64).uniform_(low, high) - else: - # can't directly use min and max, because for int64_t, max - min - # is greater than int64_t range and triggers UB. - dtype_info = torch.iinfo(dtype) - low = max(dtype_info.min, int(-1e10)) - high = min(dtype_info.max, int(1e10)) - dtype_info = torch.iinfo(dtype) - t = torch.empty(shape, dtype=torch.int64).random_(low, high) - return t.to(dtype) - - dtypes = [ - torch.uint8, - torch.int8, - torch.short, - torch.int, - torch.half, - torch.float, - torch.double, - torch.long, - ] - for dtp in dtypes: - # 1D - sz = 10 - x = get_castable_tensor(sz, dtp) - y = x.numpy() - for i in range(sz): - self.assertEqual(x[i], y[i]) - - # 1D > 0 storage offset - xm = get_castable_tensor(sz * 2, dtp) - x = xm.narrow(0, sz - 1, sz) - self.assertTrue(x.storage_offset() > 0) - y = x.numpy() - for i in range(sz): - self.assertEqual(x[i], y[i]) - - def check2d(x, y): - for i in range(sz1): - for j in range(sz2): - self.assertEqual(x[i][j], y[i][j]) - - # empty - x = torch.Tensor().to(dtp) - y = x.numpy() - self.assertEqual(y.size, 0) - - # contiguous 2D - sz1 = 3 - sz2 = 5 - x = get_castable_tensor((sz1, sz2), dtp) - y = x.numpy() - check2d(x, y) - self.assertTrue(y.flags['C_CONTIGUOUS']) - - # with storage offset - xm = get_castable_tensor((sz1 * 2, sz2), dtp) - x = xm.narrow(0, sz1 - 1, sz1) - y = x.numpy() - self.assertTrue(x.storage_offset() > 0) - check2d(x, y) - self.assertTrue(y.flags['C_CONTIGUOUS']) - - # non-contiguous 2D - x = get_castable_tensor((sz2, sz1), dtp).t() - y = x.numpy() - check2d(x, y) - self.assertFalse(y.flags['C_CONTIGUOUS']) - - # with storage offset - xm = get_castable_tensor((sz2 * 2, sz1), dtp) - x = xm.narrow(0, sz2 - 1, sz2).t() - y = x.numpy() - self.assertTrue(x.storage_offset() > 0) - check2d(x, y) - - # non-contiguous 2D with holes - xm = get_castable_tensor((sz2 * 2, sz1 * 2), dtp) - x = xm.narrow(0, sz2 - 1, sz2).narrow(1, sz1 - 1, sz1).t() - y = x.numpy() - self.assertTrue(x.storage_offset() > 0) - check2d(x, y) - - if dtp != torch.half: - # check writeable - x = get_castable_tensor((3, 4), dtp) - y = x.numpy() - self.assertTrue(y.flags.writeable) - y[0][1] = 3 - self.assertTrue(x[0][1] == 3) - y = x.t().numpy() - self.assertTrue(y.flags.writeable) - y[0][1] = 3 - self.assertTrue(x[0][1] == 3) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_to_numpy_bool(self) -> None: - x = torch.tensor([True, False], dtype=torch.bool) - self.assertEqual(x.dtype, torch.bool) - - y = x.numpy() - self.assertEqual(y.dtype, np.bool) - for i in range(len(x)): - self.assertEqual(x[i], y[i]) - - x = torch.tensor([True], dtype=torch.bool) - self.assertEqual(x.dtype, torch.bool) - - y = x.numpy() - self.assertEqual(y.dtype, np.bool) - self.assertEqual(x[0], y[0]) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_from_numpy(self) -> None: - dtypes = [ - np.double, - np.float, - np.float16, - np.complex64, - np.complex128, - np.int64, - np.int32, - np.int16, - np.int8, - np.uint8, - np.longlong, - np.bool, - ] - complex_dtypes = [ - np.complex64, - np.complex128, - ] - - for dtype in dtypes: - array = np.array([1, 2, 3, 4], dtype=dtype) - tensor_from_array = torch.from_numpy(array) - # TODO: change to tensor equality check once HalfTensor - # implements `==` - for i in range(len(array)): - self.assertEqual(tensor_from_array[i], array[i]) - # ufunc 'remainder' not supported for complex dtypes - if dtype not in complex_dtypes: - # This is a special test case for Windows - # https://github.com/pytorch/pytorch/issues/22615 - array2 = array % 2 - tensor_from_array2 = torch.from_numpy(array2) - for i in range(len(array2)): - self.assertEqual(tensor_from_array2[i], array2[i]) - - # Test unsupported type - array = np.array([1, 2, 3, 4], dtype=np.uint16) - with self.assertRaises(TypeError): - tensor_from_array = torch.from_numpy(array) - - # check storage offset - x = np.linspace(1, 125, 125) - x.shape = (5, 5, 5) - x = x[1] - expected = torch.arange(1, 126, dtype=torch.float64).view(5, 5, 5)[1] - self.assertEqual(torch.from_numpy(x), expected) - - # check noncontiguous - x = np.linspace(1, 25, 25) - x.shape = (5, 5) - expected = torch.arange(1, 26, dtype=torch.float64).view(5, 5).t() - self.assertEqual(torch.from_numpy(x.T), expected) - - # check noncontiguous with holes - x = np.linspace(1, 125, 125) - x.shape = (5, 5, 5) - x = x[:, 1] - expected = torch.arange(1, 126, dtype=torch.float64).view(5, 5, 5)[:, 1] - self.assertEqual(torch.from_numpy(x), expected) - - # check zero dimensional - x = np.zeros((0, 2)) - self.assertEqual(torch.from_numpy(x).shape, (0, 2)) - x = np.zeros((2, 0)) - self.assertEqual(torch.from_numpy(x).shape, (2, 0)) - - # check ill-sized strides raise exception - x = np.array([3., 5., 8.]) - x.strides = (3,) - self.assertRaises(ValueError, lambda: torch.from_numpy(x)) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_ctor_with_numpy_scalar_ctor(self) -> None: - dtypes = [ - np.double, - np.float, - np.float16, - np.int64, - np.int32, - np.int16, - np.uint8, - np.bool, - ] - for dtype in dtypes: - self.assertEqual(dtype(42), torch.tensor(dtype(42)).item()) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_numpy_index(self): - i = np.int32([0, 1, 2]) - x = torch.randn(5, 5) - for idx in i: - self.assertFalse(isinstance(idx, int)) - self.assertEqual(x[idx], x[int(idx)]) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_numpy_array_interface(self): - types = [ - torch.DoubleTensor, - torch.FloatTensor, - torch.HalfTensor, - torch.LongTensor, - torch.IntTensor, - torch.ShortTensor, - torch.ByteTensor, - ] - dtypes = [ - np.float64, - np.float32, - np.float16, - np.int64, - np.int32, - np.int16, - np.uint8, - ] - for tp, dtype in zip(types, dtypes): - if np.dtype(dtype).kind == 'u': - x = torch.Tensor([1, 2, 3, 4]).type(tp) - array = np.array([1, 2, 3, 4], dtype=dtype) - else: - x = torch.Tensor([1, -2, 3, -4]).type(tp) - array = np.array([1, -2, 3, -4], dtype=dtype) - - # Test __array__ w/o dtype argument - asarray = np.asarray(x) - self.assertIsInstance(asarray, np.ndarray) - self.assertEqual(asarray.dtype, dtype) - for i in range(len(x)): - self.assertEqual(asarray[i], x[i]) - - # Test __array_wrap__, same dtype - abs_x = np.abs(x) - abs_array = np.abs(array) - self.assertIsInstance(abs_x, tp) - for i in range(len(x)): - self.assertEqual(abs_x[i], abs_array[i]) - - # Test __array__ with dtype argument - for dtype in dtypes: - x = torch.IntTensor([1, -2, 3, -4]) - asarray = np.asarray(x, dtype=dtype) - self.assertEqual(asarray.dtype, dtype) - if np.dtype(dtype).kind == 'u': - wrapped_x = np.array([1, -2, 3, -4], dtype=dtype) - for i in range(len(x)): - self.assertEqual(asarray[i], wrapped_x[i]) - else: - for i in range(len(x)): - self.assertEqual(asarray[i], x[i]) - - # Test some math functions with float types - float_types = [torch.DoubleTensor, torch.FloatTensor] - float_dtypes = [np.float64, np.float32] - for tp, dtype in zip(float_types, float_dtypes): - x = torch.Tensor([1, 2, 3, 4]).type(tp) - array = np.array([1, 2, 3, 4], dtype=dtype) - for func in ['sin', 'sqrt', 'ceil']: - ufunc = getattr(np, func) - res_x = ufunc(x) - res_array = ufunc(array) - self.assertIsInstance(res_x, tp) - for i in range(len(x)): - self.assertEqual(res_x[i], res_array[i]) - - # Test functions with boolean return value - for tp, dtype in zip(types, dtypes): - x = torch.Tensor([1, 2, 3, 4]).type(tp) - array = np.array([1, 2, 3, 4], dtype=dtype) - geq2_x = np.greater_equal(x, 2) - geq2_array = np.greater_equal(array, 2).astype('uint8') - self.assertIsInstance(geq2_x, torch.ByteTensor) - for i in range(len(x)): - self.assertEqual(geq2_x[i], geq2_array[i]) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_multiplication_numpy_scalar(self) -> None: - for np_dtype in [np.float32, np.float64, np.int32, np.int64, np.int16, np.uint8]: - for t_dtype in [torch.float, torch.double]: - np_sc = np_dtype(2.0) - t = torch.ones(2, requires_grad=True, dtype=t_dtype) - r1 = t * np_sc - self.assertIsInstance(r1, torch.Tensor) - self.assertTrue(r1.dtype == t_dtype) - self.assertTrue(r1.requires_grad) - r2 = np_sc * t - self.assertIsInstance(r2, torch.Tensor) - self.assertTrue(r2.dtype == t_dtype) - self.assertTrue(r2.requires_grad) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_parse_numpy_int(self): - self.assertRaisesRegex(RuntimeError, "Overflow", - lambda: torch.mean(torch.randn(1, 1), np.uint64(-1))) - # https://github.com/pytorch/pytorch/issues/29252 - for nptype in [np.int16, np.int8, np.uint8, np.int32, np.int64]: - scalar = 3 - np_arr = np.array([scalar], dtype=nptype) - np_val = np_arr[0] - - # np integral type can be treated as a python int in native functions with - # int parameters: - self.assertEqual(torch.ones(5).diag(scalar), torch.ones(5).diag(np_val)) - self.assertEqual(torch.ones([2, 2, 2, 2]).mean(scalar), torch.ones([2, 2, 2, 2]).mean(np_val)) - - # numpy integral type parses like a python int in custom python bindings: - self.assertEqual(torch.Storage(np_val).size(), scalar) - - tensor = torch.tensor([2], dtype=torch.int) - tensor[0] = np_val - self.assertEqual(tensor[0], np_val) - - # Original reported issue, np integral type parses to the correct - # PyTorch integral type when passed for a `Scalar` parameter in - # arithmetic operations: - t = torch.from_numpy(np_arr) - self.assertEqual((t + np_val).dtype, t.dtype) - self.assertEqual((np_val + t).dtype, t.dtype) - def test_error_msg_type_translation(self): with self.assertRaisesRegex( RuntimeError, @@ -3812,150 +2251,6 @@ def test_error_msg_type_translation(self): model.weight = weight out = model(input) - def test_tensor_from_sequence(self): - class MockSequence(object): - def __init__(self, lst): - self.lst = lst - - def __len__(self): - return len(self.lst) - - def __getitem__(self, item): - raise TypeError - - class GoodMockSequence(MockSequence): - def __getitem__(self, item): - return self.lst[item] - - bad_mock_seq = MockSequence([1.0, 2.0, 3.0]) - good_mock_seq = GoodMockSequence([1.0, 2.0, 3.0]) - with self.assertRaisesRegex(ValueError, 'could not determine the shape'): - torch.Tensor(bad_mock_seq) - self.assertEqual(torch.Tensor([1.0, 2.0, 3.0]), torch.Tensor(good_mock_seq)) - - def test_comparison_ops(self): - x = torch.randn(5, 5) - y = torch.randn(5, 5) - - eq = x == y - for idx in iter_indices(x): - self.assertEqual(x[idx] == y[idx], eq[idx] == 1) - - ne = x != y - for idx in iter_indices(x): - self.assertEqual(x[idx] != y[idx], ne[idx] == 1) - - lt = x < y - for idx in iter_indices(x): - self.assertEqual(x[idx] < y[idx], lt[idx] == 1) - - le = x <= y - for idx in iter_indices(x): - self.assertEqual(x[idx] <= y[idx], le[idx] == 1) - - gt = x > y - for idx in iter_indices(x): - self.assertEqual(x[idx] > y[idx], gt[idx] == 1) - - ge = x >= y - for idx in iter_indices(x): - self.assertEqual(x[idx] >= y[idx], ge[idx] == 1) - - def test_comparison_ops_must_take_bool_output(self): - for op in [torch.lt, torch.le, torch.gt, torch.ge, torch.eq, torch.ne, - torch.logical_and, torch.logical_or, torch.logical_xor]: - self.assertEqual(op(torch.tensor([True]), torch.tensor([False])).dtype, torch.bool) - - def test_inplace_comparison_ops_require_inputs_have_same_dtype(self): - with self.assertRaisesRegex(RuntimeError, 'Expected object of scalar type'): - for op in ['lt_', 'le_', 'gt_', 'ge_', 'eq_', 'ne_', 'logical_xor_', 'logical_and_', 'logical_or_']: - x = torch.tensor([1], dtype=torch.int) - y = torch.tensor([2], dtype=torch.long) - in_place_method = getattr(x, op) - in_place_method(y) - - def test_comparison_ops_check_for_scalar_overflow(self): - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): - torch.tensor([1 << 5], dtype=torch.uint8) < (1 << 20) - (1 << 20) < torch.tensor([1 << 5], dtype=torch.uint8) - torch.tensor([1 << 5], dtype=torch.uint8) <= (1 << 20) - (1 << 20) <= torch.tensor([1 << 5], dtype=torch.uint8) - torch.tensor([1 << 5], dtype=torch.uint8) > (1 << 20) - (1 << 20) > torch.tensor([1 << 5], dtype=torch.uint8) - torch.tensor([1 << 5], dtype=torch.uint8) >= (1 << 20) - (1 << 20) >= torch.tensor([1 << 5], dtype=torch.uint8) - torch.tensor([1 << 5], dtype=torch.uint8) == (1 << 20) - (1 << 20) == torch.tensor([1 << 5], dtype=torch.uint8) - torch.tensor([1 << 5], dtype=torch.uint8) != (1 << 20) - (1 << 20) != torch.tensor([1 << 5], dtype=torch.uint8) - - def test_comparison_ops_check_for_zerodim_tensor_overflow(self): - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): - torch.tensor([1 << 5], dtype=torch.uint8) < torch.tensor(1 << 20, dtype=torch.int32) - torch.tensor(1 << 40, dtype=torch.int64) < torch.tensor([1 << 30], dtype=torch.int32) - torch.tensor([1 << 5], dtype=torch.uint8) <= torch.tensor(1 << 20, dtype=torch.int32) - torch.tensor(1 << 40, dtype=torch.int64) <= torch.tensor([1 << 30], dtype=torch.int32) - torch.tensor([1 << 5], dtype=torch.uint8) > torch.tensor(1 << 20, dtype=torch.int32) - torch.tensor(1 << 40, dtype=torch.int64) > torch.tensor([1 << 30], dtype=torch.int32) - torch.tensor([1 << 5], dtype=torch.uint8) >= torch.tensor(1 << 20, dtype=torch.int32) - torch.tensor(1 << 40, dtype=torch.int64) >= torch.tensor([1 << 30], dtype=torch.int32) - torch.tensor([1 << 5], dtype=torch.uint8) == torch.tensor(1 << 20, dtype=torch.int32) - torch.tensor(1 << 40, dtype=torch.int64) == torch.tensor([1 << 30], dtype=torch.int32) - torch.tensor([1 << 5], dtype=torch.uint8) != torch.tensor(1 << 20, dtype=torch.int32) - torch.tensor(1 << 40, dtype=torch.int64) != torch.tensor([1 << 30], dtype=torch.int32) - - def test_bitwise_ops(self): - x = torch.randn(5, 5).gt(0) - y = torch.randn(5, 5).gt(0) - - and_result = x & y - for idx in iter_indices(x): - if and_result[idx]: - self.assertTrue(x[idx] and y[idx]) - else: - self.assertFalse(x[idx] and y[idx]) - - or_result = x | y - for idx in iter_indices(x): - if or_result[idx]: - self.assertTrue(x[idx] or y[idx]) - else: - self.assertFalse(x[idx] or y[idx]) - - xor_result = x ^ y - for idx in iter_indices(x): - if xor_result[idx]: - self.assertTrue(x[idx] ^ y[idx]) - else: - self.assertFalse(x[idx] ^ y[idx]) - - x_clone = x.clone() - x_clone &= y - self.assertEqual(x_clone, and_result) - - x_clone = x.clone() - x_clone |= y - self.assertEqual(x_clone, or_result) - - x_clone = x.clone() - x_clone ^= y - self.assertEqual(x_clone, xor_result) - - def test_op_invert(self): - res = 0xffff - torch.arange(127, dtype=torch.int8) - for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): - a = torch.arange(127, dtype=dtype) - self.assertEqual(res.to(dtype), ~a) - - self.assertEqual(torch.tensor([True, False]), - ~torch.tensor([False, True])) - - # test exceptions - for dtype in (torch.half, torch.float, torch.double): - a = torch.zeros(10, dtype=dtype) - with self.assertRaises(TypeError): - b = ~a - def test_apply(self): x = torch.arange(1, 6) res = x.clone().apply_(lambda k: k + k) @@ -4012,47 +2307,6 @@ def test_t_not_2d_error(self): self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t()) self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t_()) - # unit test for special case transposed copy (see ATen/native/Copy.cpp for details) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_big_transpose(self): - t = torch.rand(456, 789) - t1 = t.t().contiguous() - t2 = torch.from_numpy(t.numpy().transpose()) - self.assertEqual(t1, t2) - - def test_inplace_division(self): - t = torch.rand(5, 5) - id_before = id(t) - t /= 2 - id_after = id(t) - self.assertEqual(id_before, id_after) - - def test_simple_scalar_cast(self): - ok = [torch.Tensor([1.5]), torch.zeros(1, 1, 1, 1)] - ok_values = [1.5, 0] - - not_ok = map(torch.Tensor, [[], [1, 2], [[1, 2], [3, 4]]]) - - for tensor, value in zip(ok, ok_values): - self.assertEqual(int(tensor), int(value)) - self.assertEqual(float(tensor), float(value)) - self.assertEqual(complex(tensor), complex(value)) - - self.assertEqual(complex(torch.tensor(1.5j)), 1.5j) - - for tensor in not_ok: - self.assertRaises(ValueError, lambda: int(tensor)) - self.assertRaises(ValueError, lambda: float(tensor)) - self.assertRaises(ValueError, lambda: complex(tensor)) - - self.assertRaises(RuntimeError, lambda: float(torch.tensor(1.5j))) - self.assertRaises(RuntimeError, lambda: int(torch.tensor(1.5j))) - - def test_offset_scalar_cast(self): - x = torch.Tensor([1, 2, 3]) - y = x[2:] - self.assertEqual(int(y), 3) - # skip this test for now as it affects all tests @unittest.skipIf(True, "flush_denormal not supported") def test_set_flush_denormal(self): @@ -4080,6 +2334,10 @@ def test_show_config(self): # We can't usefully test the output; just make sure this doesn't crash torch.__config__.show() + @unittest.skipIf(IS_FBCODE, "CXX_FLAGS is only for OSS build.") + def test_cxx_flags(self): + torch.__config__._cxx_flags() + def test_parallel_info(self): torch.__config__.parallel_info() @@ -4089,16 +2347,10 @@ def test_slow_test(self): pass def test_is_nonzero(self): - self.assertExpectedRaisesInline( - RuntimeError, - lambda: torch.tensor([]).is_nonzero(), - "Boolean value of Tensor with no values is ambiguous", - ) - self.assertExpectedRaisesInline( - RuntimeError, - lambda: torch.tensor([0, 0]).is_nonzero(), - "Boolean value of Tensor with more than one value is ambiguous", - ) + with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"): + torch.tensor([]).is_nonzero() + with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with more than one value is ambiguous"): + torch.tensor([0, 0]).is_nonzero() self.assertFalse(torch.tensor(0).is_nonzero()) self.assertTrue(torch.tensor(1).is_nonzero()) self.assertFalse(torch.tensor([0]).is_nonzero()) @@ -4106,32 +2358,6 @@ def test_is_nonzero(self): self.assertFalse(torch.tensor([[0]]).is_nonzero()) self.assertTrue(torch.tensor([[1]]).is_nonzero()) - def test_meshgrid(self): - a = torch.tensor(1) - b = torch.tensor([1, 2, 3]) - c = torch.tensor([1, 2]) - grid_a, grid_b, grid_c = torch.meshgrid([a, b, c]) - self.assertEqual(grid_a.shape, torch.Size([1, 3, 2])) - self.assertEqual(grid_b.shape, torch.Size([1, 3, 2])) - self.assertEqual(grid_c.shape, torch.Size([1, 3, 2])) - grid_a2, grid_b2, grid_c2 = torch.meshgrid(a, b, c) - self.assertEqual(grid_a2.shape, torch.Size([1, 3, 2])) - self.assertEqual(grid_b2.shape, torch.Size([1, 3, 2])) - self.assertEqual(grid_c2.shape, torch.Size([1, 3, 2])) - expected_grid_a = torch.ones(1, 3, 2, dtype=torch.int64) - expected_grid_b = torch.tensor([[[1, 1], - [2, 2], - [3, 3]]]) - expected_grid_c = torch.tensor([[[1, 2], - [1, 2], - [1, 2]]]) - self.assertTrue(grid_a.equal(expected_grid_a)) - self.assertTrue(grid_b.equal(expected_grid_b)) - self.assertTrue(grid_c.equal(expected_grid_c)) - self.assertTrue(grid_a2.equal(expected_grid_a)) - self.assertTrue(grid_b2.equal(expected_grid_b)) - self.assertTrue(grid_c2.equal(expected_grid_c)) - # NB: we must not be built with CUDA; if we are built with CUDA but no CUDA # is available, we get a different error. @unittest.skipIf(torch.backends.cuda.is_built() or IS_SANDCASTLE, "CUDA is built, can't test CUDA not built error") @@ -4144,75 +2370,6 @@ def test_cuda_not_built(self): self.assertRaisesRegex(TypeError, msg, lambda: torch.set_default_tensor_type(torch.cuda.FloatTensor)) self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1]).to(device="cuda")) - def test_cast_binary_op(self): - # Scalar - a = torch.tensor(2) - b = torch.tensor(3) - a_copy = a.clone() - b_copy = b.clone() - - self.assertEqual(torch.tensor(6, dtype=torch.float), a.float() * b) - - self.assertEqualTypeString(a, a_copy) - self.assertEqualTypeString(b, b_copy) - - def test_cartesian_prod(self): - a = torch.tensor([1]) - b = torch.tensor([1, 2, 3]) - c = torch.tensor([1, 2]) - prod = torch.cartesian_prod(a, b, c) - expected = torch.tensor(list(product([a], b, c))) - self.assertEqual(expected, prod) - - # test 0 size input - d = torch.empty(0, dtype=b.dtype) - prod = torch.cartesian_prod(a, b, c, d) - expected = torch.empty(0, 4, dtype=b.dtype) - self.assertEqual(expected, prod) - - # test single input - prod = torch.cartesian_prod(b) - self.assertEqual(b, prod) - - def test_combinations(self): - a = torch.tensor([1, 2, 3]) - - c = torch.combinations(a, r=1) - expected = torch.tensor(list(combinations(a, r=1))) - self.assertEqual(c, expected) - - c = torch.combinations(a, r=1, with_replacement=True) - expected = torch.tensor(list(combinations_with_replacement(a, r=1))) - self.assertEqual(c, expected) - - c = torch.combinations(a) - expected = torch.tensor(list(combinations(a, r=2))) - self.assertEqual(c, expected) - - c = torch.combinations(a, with_replacement=True) - expected = torch.tensor(list(combinations_with_replacement(a, r=2))) - self.assertEqual(c, expected) - - c = torch.combinations(a, r=3) - expected = torch.tensor(list(combinations(a, r=3))) - self.assertEqual(c, expected) - - c = torch.combinations(a, r=4) - expected = torch.empty(0, 4, dtype=a.dtype) - self.assertEqual(c, expected) - - c = torch.combinations(a, r=5) - expected = torch.empty(0, 5, dtype=a.dtype) - self.assertEqual(c, expected) - - # test empty imput - a = torch.empty(0) - c1 = torch.combinations(a) - c2 = torch.combinations(a, with_replacement=True) - expected = torch.empty(0, 2, dtype=a.dtype) - self.assertEqual(c1, expected) - self.assertEqual(c2, expected) - def test_has_internal_overlap(self): OVERLAP_NO = 0 OVERLAP_YES = 1 @@ -4313,29 +2470,6 @@ def test_ndim(self): c = torch.randn(1, 0) self.assertEqual(2, c.ndim) - def test_T(self): - a = torch.randn(2, 3, 4) - t1 = a.T - t2 = a.permute(2, 1, 0) - self.assertEqual(t2, t1) - b = torch.randn(10) - self.assertEqual(b, b.T) - scalar = torch.tensor(5) - self.assertEqual(scalar, scalar.T) - - def test_python_types(self): - a1 = torch.randn((1, 2), dtype=torch.float64) - a2 = torch.randn((1, 2), dtype=float) - self.assertEqual(a1.dtype, a2.dtype) - - b1 = torch.arange(10, 20, dtype=torch.int64) - b2 = torch.arange(10, 20, dtype=int) - self.assertEqual(b1.dtype, b2.dtype) - - c1 = torch.tensor([True, False], dtype=torch.bool) - c2 = torch.tensor([True, False], dtype=bool) - self.assertEqual(c1.dtype, c2.dtype) - def test_fill_diagonal(self): a1 = torch.randn(7, 3) a2 = a1.clone() @@ -4427,25 +2561,25 @@ def test_empty_meta(self): y = torch.empty_meta(2 ** 20) z = x + y self.assertEqual(z.size(), (2 ** 20, 2 ** 20)) - - def test_tensor_grad_warnings(self): - dummy = torch.empty(1) - - with warnings.catch_warnings(record=True) as w: - # Accessing .grad on leaf - dummy.requires_grad_() - foo = dummy.grad - self.assertEqual(len(w), 0) - - # Accessing .grad on non-leaf - dummy = dummy.clone() - foo = dummy.grad - self.assertEqual(len(w), 1) - - # Accessing .grad on non-leaf that retains gradients - dummy.retain_grad() - foo = dummy.grad - self.assertEqual(len(w), 1) + self.assertRaises(RuntimeError, lambda: z[0][0].item()) + + def test_upsample_nearest1d_meta(self): + # TODO: this is not a sustainable way of testing meta functions, + # but I want some quick scaffolding first before a more + # integrated testing strategy + # NB: Can't make the exponent too big, or it will overflow + # signed 64-bit integer + x = torch.empty_meta(2 * 10 ** 8, 3, 2 * 10 ** 8) + z = torch.nn.functional.interpolate(x, scale_factor=2) + self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8)) + self.assertRaises(RuntimeError, lambda: z[0][0][0].item()) + + # interpolate doesn't seem to support out= + # (not sure why passing None here doesn't work? How strange...) + z = torch.empty_meta(0) + torch._C._nn.upsample_nearest1d(x, (4 * 10 ** 8,), 2, out=z) + self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8)) + self.assertRaises(RuntimeError, lambda: z[0][0][0].item()) def test_normal_shape(self): warned = False @@ -4597,7 +2731,7 @@ def neg_dim_test(self): ndim = len(tensor_arg) ndim += extra_dim - n_dim_to_test = sum(map(lambda e: e is DIM_ARG, arg_constr())) + n_dim_to_test = sum(e is DIM_ARG for e in arg_constr()) for dims_val in combinations(range(ndim), n_dim_to_test): arg = arg_constr() @@ -4653,6 +2787,7 @@ def add_neg_dim_tests(): ('cummin', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), ('mean', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), ('median', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), + ('nanmedian', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), ('mode', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), ('norm', (10, 20), lambda: [2, DIM_ARG], [METHOD, FUNCTIONAL]), ('prod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), @@ -4690,116 +2825,25 @@ def add_neg_dim_tests(): class TestTorchDeviceType(TestCase): exact_dtype = True - # Tests that trying to add, inplace, a CUDA tensor to a CPU tensor - # throws the correct error message - @onlyCUDA - def test_cross_device_inplace_error_msg(self, device): - a = torch.tensor(2.) - b = torch.tensor(2., device=device) - with self.assertRaisesRegex(RuntimeError, - "Expected all tensors to be on the same device"): - a += b + # TODO: move all tensor creation to common ops + def _rand_shape(self, dim, min_size, max_size): + shape = [] + for i in range(dim): + shape.append(random.randint(min_size, max_size)) + return tuple(shape) - @onlyOnCPUAndCUDA - def test_out_resize_warning(self, device): - a = torch.tensor((1, 2, 3), device=device, dtype=torch.float32) - b = torch.tensor((4, 5, 6), device=device, dtype=torch.float32) - - unary_inputs = (a,) - binary_inputs = (a, b) - unary_ops = (torch.ceil, torch.exp) - binary_ops = (torch.add, torch.sub) - for op in (unary_ops + binary_ops): + @onlyCPU + def test_set_deterministic_beta_warning(self, device): + with DeterministicGuard(torch.is_deterministic()): + # Ensures setting to false does not throw a warning with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - inputs = unary_inputs if op in unary_ops else binary_inputs - - # No warnings - op(*inputs, out=torch.empty(3, device=device)) - op(*inputs, out=torch.empty(0, device=device)) + torch.set_deterministic(False) self.assertEqual(len(w), 0) - # Cases that throw warnings - op(*inputs, out=torch.empty(2, device=device)) - self.assertEqual(len(w), 1) - - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.complex64, torch.complex128) - def test_abs_angle_complex_to_float(self, device, dtype): - # Constructs random complex values - from random import random - random_vals = [] - for multiplier in (-1, 1, -10, 10, -100, 100): - for _ in range(10): - random_vals.append(complex(random() * multiplier, random() * multiplier)) - - for vals in (random_vals, []): - a = np.array(vals, dtype=torch_to_numpy_dtype_dict[dtype]) - t = torch.tensor(vals, device=device, dtype=dtype) - - for fn_name in ('abs', 'angle'): - torch_fn = getattr(torch, fn_name) - np_fn = getattr(np, fn_name) - - # Tests function - np_result = torch.from_numpy(np_fn(a)) - torch_result = torch_fn(t).cpu() - self.assertEqual(np_result, torch_result, exact_dtype=True) - - # Tests float out - float_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 - np_float_out = np_fn(a).astype(torch_to_numpy_dtype_dict[float_dtype]) - float_out = torch.empty_like(t).float() - torch_fn(t, out=float_out) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(torch.from_numpy(np_float_out), float_out.cpu()) - - # Tests float out (resized out) - float_out = torch.empty(1, device=device, dtype=float_dtype) - torch_fn(t, out=float_out) - self.assertEqual(torch.from_numpy(np_float_out), float_out.cpu()) - - # Tests complex out - np_complex_out = np_fn(a) - complex_out = torch.empty_like(t) - torch_fn(t, out=complex_out) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(torch.from_numpy(np_complex_out), complex_out.cpu()) - - # Tests complex out (resized out) - complex_out = torch.empty(0, device=device, dtype=dtype) - torch_fn(t, out=complex_out) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(torch.from_numpy(np_complex_out), complex_out.cpu()) - - # Tests long out behavior (expected failure) - long_out = torch.empty(0, device=device, dtype=torch.long) - with self.assertRaises(RuntimeError): - torch_fn(t, out=long_out) - - # Tests inplace - if fn_name == 'abs': - torch_inplace_method = getattr(torch.Tensor, fn_name + "_") - np_fn(a, out=a) - torch_inplace_method(t) - self.assertEqual(torch.from_numpy(a), t.cpu()) - - # Note: angle does not have an in-place variant - if fn_name == 'angle': - with self.assertRaises(AttributeError): - torch_inplace_method = getattr(torch.Tensor, fn_name + "_") - - # Verifies that the inplace dunders (like idiv) actually are in place - @onlyOnCPUAndCUDA - def test_inplace_dunders(self, device): - t = torch.randn((1,), device=device) - expected = t.data_ptr() - t += 1 - t -= 1 - t *= 1 - t /= 1 - t //= 1 - self.assertEqual(expected, t.data_ptr()) + # Setting set_deterministic(True) throws a warning once per process + with self.maybeWarnsRegex(UserWarning, "torch.set_deterministic is in beta"): + torch.set_deterministic(True) @dtypes(torch.float32, torch.complex64) def test_storage(self, device, dtype): @@ -4837,435 +2881,6 @@ def test_deepcopy_scalar(self, device, dtype): self.assertEqual(a.size(), deepcopy(a).size()) self.assertEqual(a, deepcopy(a)) - # Tests that when rtol or atol (including self.precision) is set, then - # the other is zeroed. - # TODO: this is legacy behavior and should be updated after test - # precisions are reviewed to be consistent with torch.isclose. - @onlyOnCPUAndCUDA - def test__comparetensors_legacy(self, device): - a = torch.tensor((10000000.,)) - b = torch.tensor((10000002.,)) - - x = torch.tensor((1.,)) - y = torch.tensor((1. + 1e-5,)) - - # Helper for reusing the tensor values as scalars - def _scalar_helper(a, b, rtol=None, atol=None): - return self._compareScalars(a.item(), b.item(), rtol=rtol, atol=atol) - - for op in (self._compareTensors, _scalar_helper): - # Tests default - result, debug_msg = op(a, b) - self.assertTrue(result) - - # Tests setting atol - result, debug_msg = op(a, b, atol=2, rtol=0) - self.assertTrue(result) - - # Tests setting atol too small - result, debug_msg = op(a, b, atol=1, rtol=0) - self.assertFalse(result) - - # Tests setting rtol too small - result, debug_msg = op(x, y, atol=0, rtol=1.05e-5) - self.assertTrue(result) - - # Tests setting rtol too small - result, debug_msg = op(x, y, atol=0, rtol=1e-5) - self.assertFalse(result) - - @onlyOnCPUAndCUDA - def test__comparescalars_debug_msg(self, device): - # float x float - result, debug_msg = self._compareScalars(4., 7.) - expected_msg = ("Comparing 4.0 and 7.0 gives a difference of 3.0, " - "but the allowed difference with rtol=1.3e-06 and " - "atol=1e-05 is only 1.9100000000000003e-05!") - self.assertEqual(debug_msg, expected_msg) - - # complex x complex, real difference - result, debug_msg = self._compareScalars(complex(1, 3), complex(3, 1)) - expected_msg = ("Comparing the real part 1.0 and 3.0 gives a difference " - "of 2.0, but the allowed difference with rtol=1.3e-06 " - "and atol=1e-05 is only 1.39e-05!") - self.assertEqual(debug_msg, expected_msg) - - # complex x complex, imaginary difference - result, debug_msg = self._compareScalars(complex(1, 3), complex(1, 5.5)) - expected_msg = ("Comparing the imaginary part 3.0 and 5.5 gives a " - "difference of 2.5, but the allowed difference with " - "rtol=1.3e-06 and atol=1e-05 is only 1.715e-05!") - self.assertEqual(debug_msg, expected_msg) - - # complex x int - result, debug_msg = self._compareScalars(complex(1, -2), 1) - expected_msg = ("Comparing the imaginary part -2.0 and 0.0 gives a " - "difference of 2.0, but the allowed difference with " - "rtol=1.3e-06 and atol=1e-05 is only 1e-05!") - self.assertEqual(debug_msg, expected_msg) - - # NaN x NaN, equal_nan=False - result, debug_msg = self._compareScalars(float('nan'), float('nan'), equal_nan=False) - expected_msg = ("Found nan and nan while comparing and either one is " - "nan and the other isn't, or both are nan and equal_nan " - "is False") - self.assertEqual(debug_msg, expected_msg) - - # Checks that compareTensors provides the correct debug info - @onlyOnCPUAndCUDA - def test__comparetensors_debug_msg(self, device): - # Acquires atol that will be used - atol = max(1e-05, self.precision) - - # Checks float tensor comparisons (2D tensor) - a = torch.tensor(((0, 6), (7, 9)), device=device, dtype=torch.float32) - b = torch.tensor(((0, 7), (7, 22)), device=device, dtype=torch.float32) - result, debug_msg = self._compareTensors(a, b) - expected_msg = ("With rtol=1.3e-06 and atol={0}, found 2 element(s) (out of 4) " - "whose difference(s) exceeded the margin of error (including 0 nan comparisons). " - "The greatest difference was 13.0 (9.0 vs. 22.0), " - "which occurred at index (1, 1).").format(atol) - self.assertEqual(debug_msg, expected_msg) - - # Checks float tensor comparisons (with extremal values) - a = torch.tensor((float('inf'), 5, float('inf')), device=device, dtype=torch.float32) - b = torch.tensor((float('inf'), float('nan'), float('-inf')), device=device, dtype=torch.float32) - result, debug_msg = self._compareTensors(a, b) - expected_msg = ("With rtol=1.3e-06 and atol={0}, found 2 element(s) (out of 3) " - "whose difference(s) exceeded the margin of error (including 1 nan comparisons). " - "The greatest difference was nan (5.0 vs. nan), " - "which occurred at index 1.").format(atol) - self.assertEqual(debug_msg, expected_msg) - - # Checks float tensor comparisons (with finite vs nan differences) - a = torch.tensor((20, -6), device=device, dtype=torch.float32) - b = torch.tensor((-1, float('nan')), device=device, dtype=torch.float32) - result, debug_msg = self._compareTensors(a, b) - expected_msg = ("With rtol=1.3e-06 and atol={0}, found 2 element(s) (out of 2) " - "whose difference(s) exceeded the margin of error (including 1 nan comparisons). " - "The greatest difference was nan (-6.0 vs. nan), " - "which occurred at index 1.").format(atol) - self.assertEqual(debug_msg, expected_msg) - - # Checks int tensor comparisons (1D tensor) - a = torch.tensor((1, 2, 3, 4), device=device) - b = torch.tensor((2, 5, 3, 4), device=device) - result, debug_msg = self._compareTensors(a, b) - expected_msg = ("Found 2 different element(s) (out of 4), " - "with the greatest difference of 3 (2 vs. 5) " - "occuring at index 1.") - self.assertEqual(debug_msg, expected_msg) - - # Checks bool tensor comparisons (0D tensor) - a = torch.tensor((True), device=device) - b = torch.tensor((False), device=device) - result, debug_msg = self._compareTensors(a, b) - expected_msg = ("Found 1 different element(s) (out of 1), " - "with the greatest difference of 1 (1 vs. 0) " - "occuring at index 0.") - self.assertEqual(debug_msg, expected_msg) - - # Checks complex tensor comparisons (real part) - a = torch.tensor((1 - 1j, 4 + 3j), device=device) - b = torch.tensor((1 - 1j, 1 + 3j), device=device) - result, debug_msg = self._compareTensors(a, b) - expected_msg = ("Real parts failed to compare as equal! " - "With rtol=1.3e-06 and atol={0}, " - "found 1 element(s) (out of 2) whose difference(s) exceeded the " - "margin of error (including 0 nan comparisons). The greatest difference was " - "3.0 (4.0 vs. 1.0), which occurred at index 1.").format(atol) - self.assertEqual(debug_msg, expected_msg) - - # Checks complex tensor comparisons (imaginary part) - a = torch.tensor((1 - 1j, 4 + 3j), device=device) - b = torch.tensor((1 - 1j, 4 - 21j), device=device) - result, debug_msg = self._compareTensors(a, b) - expected_msg = ("Imaginary parts failed to compare as equal! " - "With rtol=1.3e-06 and atol={0}, " - "found 1 element(s) (out of 2) whose difference(s) exceeded the " - "margin of error (including 0 nan comparisons). The greatest difference was " - "24.0 (3.0 vs. -21.0), which occurred at index 1.").format(atol) - self.assertEqual(debug_msg, expected_msg) - - # Checks size mismatch - a = torch.tensor((1, 2), device=device) - b = torch.tensor((3), device=device) - result, debug_msg = self._compareTensors(a, b) - expected_msg = ("Attempted to compare equality of tensors " - "with different sizes. Got sizes torch.Size([2]) and torch.Size([]).") - self.assertEqual(debug_msg, expected_msg) - - # Checks dtype mismatch - a = torch.tensor((1, 2), device=device, dtype=torch.long) - b = torch.tensor((1, 2), device=device, dtype=torch.float32) - result, debug_msg = self._compareTensors(a, b, exact_dtype=True) - expected_msg = ("Attempted to compare equality of tensors " - "with different dtypes. Got dtypes torch.int64 and torch.float32.") - self.assertEqual(debug_msg, expected_msg) - - # Checks device mismatch - if self.device_type == 'cuda': - a = torch.tensor((5), device='cpu') - b = torch.tensor((5), device=device) - result, debug_msg = self._compareTensors(a, b, exact_device=True) - expected_msg = ("Attempted to compare equality of tensors " - "on different devices! Got devices cpu and cuda:0.") - self.assertEqual(debug_msg, expected_msg) - - # Helper for testing _compareTensors and _compareScalars - # Works on single element tensors - def _comparetensors_helper(self, tests, device, dtype, equal_nan, exact_dtype=True, atol=1e-08, rtol=1e-05): - for test in tests: - a = torch.tensor((test[0],), device=device, dtype=dtype) - b = torch.tensor((test[1],), device=device, dtype=dtype) - - # Tensor x Tensor comparison - compare_result, debug_msg = self._compareTensors(a, b, rtol=rtol, atol=atol, - equal_nan=equal_nan, - exact_dtype=exact_dtype) - self.assertEqual(compare_result, test[2]) - - # Scalar x Scalar comparison - compare_result, debug_msg = self._compareScalars(a.item(), b.item(), - rtol=rtol, atol=atol, - equal_nan=equal_nan) - self.assertEqual(compare_result, test[2]) - - def _isclose_helper(self, tests, device, dtype, equal_nan, atol=1e-08, rtol=1e-05): - for test in tests: - a = torch.tensor((test[0],), device=device, dtype=dtype) - b = torch.tensor((test[1],), device=device, dtype=dtype) - - actual = torch.isclose(a, b, equal_nan=equal_nan, atol=atol, rtol=rtol) - expected = test[2] - self.assertEqual(actual.item(), expected) - - # torch.close is not implemented for bool tensors - # see https://github.com/pytorch/pytorch/issues/33048 - def test_isclose_comparetensors_bool(self, device): - tests = ( - (True, True, True), - (False, False, True), - (True, False, False), - (False, True, False), - ) - - with self.assertRaises(RuntimeError): - self._isclose_helper(tests, device, torch.bool, False) - - self._comparetensors_helper(tests, device, torch.bool, False) - - @dtypes(torch.uint8, - torch.int8, torch.int16, torch.int32, torch.int64) - def test_isclose_comparetensors_integer(self, device, dtype): - tests = ( - (0, 0, True), - (0, 1, False), - (1, 0, False), - ) - - self._isclose_helper(tests, device, dtype, False) - - # atol and rtol tests - tests = [ - (0, 1, True), - (1, 0, False), - (1, 3, True), - ] - - self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5) - self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5) - - if dtype is torch.uint8: - tests = [ - (-1, 1, False), - (1, -1, False) - ] - else: - tests = [ - (-1, 1, True), - (1, -1, True) - ] - - self._isclose_helper(tests, device, dtype, False, atol=1.5, rtol=.5) - self._comparetensors_helper(tests, device, dtype, False, atol=1.5, rtol=.5) - - @onlyOnCPUAndCUDA - @dtypes(torch.float16, torch.float32, torch.float64) - def test_isclose_comparetensors_float(self, device, dtype): - tests = ( - (0, 0, True), - (0, -1, False), - (float('inf'), float('inf'), True), - (-float('inf'), float('inf'), False), - (float('inf'), float('nan'), False), - (float('nan'), float('nan'), False), - (0, float('nan'), False), - (1, 1, True), - ) - - self._isclose_helper(tests, device, dtype, False) - self._comparetensors_helper(tests, device, dtype, False) - - # atol and rtol tests - eps = 1e-2 if dtype is torch.half else 1e-6 - tests = ( - (0, 1, True), - (0, 1 + eps, False), - (1, 0, False), - (1, 3, True), - (1 - eps, 3, False), - (-.25, .5, True), - (-.25 - eps, .5, False), - (.25, -.5, True), - (.25 + eps, -.5, False), - ) - - self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5) - self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5) - - # equal_nan = True tests - tests = ( - (0, float('nan'), False), - (float('inf'), float('nan'), False), - (float('nan'), float('nan'), True), - ) - - self._isclose_helper(tests, device, dtype, True) - - self._comparetensors_helper(tests, device, dtype, True) - - # torch.close with equal_nan=True is not implemented for complex inputs - # see https://github.com/numpy/numpy/issues/15959 - # Note: compareTensor will compare the real and imaginary parts of a - # complex tensors separately, unlike isclose. - @dtypes(torch.complex64, torch.complex128) - def test_isclose_comparetensors_complex(self, device, dtype): - tests = ( - (complex(1, 1), complex(1, 1 + 1e-8), True), - (complex(0, 1), complex(1, 1), False), - (complex(1, 1), complex(1, 0), False), - (complex(1, 1), complex(1, float('nan')), False), - (complex(1, float('nan')), complex(1, float('nan')), False), - (complex(1, 1), complex(1, float('inf')), False), - (complex(float('inf'), 1), complex(1, float('inf')), False), - (complex(-float('inf'), 1), complex(1, float('inf')), False), - (complex(-float('inf'), 1), complex(float('inf'), 1), False), - (complex(float('inf'), 1), complex(float('inf'), 1), True), - (complex(float('inf'), 1), complex(float('inf'), 1 + 1e-4), False), - ) - - self._isclose_helper(tests, device, dtype, False) - self._comparetensors_helper(tests, device, dtype, False) - - # atol and rtol tests - - # atol and rtol tests - eps = 1e-6 - tests = ( - # Complex versions of float tests (real part) - (complex(0, 0), complex(1, 0), True), - (complex(0, 0), complex(1 + eps, 0), False), - (complex(1, 0), complex(0, 0), False), - (complex(1, 0), complex(3, 0), True), - (complex(1 - eps, 0), complex(3, 0), False), - (complex(-.25, 0), complex(.5, 0), True), - (complex(-.25 - eps, 0), complex(.5, 0), False), - (complex(.25, 0), complex(-.5, 0), True), - (complex(.25 + eps, 0), complex(-.5, 0), False), - # Complex versions of float tests (imaginary part) - (complex(0, 0), complex(0, 1), True), - (complex(0, 0), complex(0, 1 + eps), False), - (complex(0, 1), complex(0, 0), False), - (complex(0, 1), complex(0, 3), True), - (complex(0, 1 - eps), complex(0, 3), False), - (complex(0, -.25), complex(0, .5), True), - (complex(0, -.25 - eps), complex(0, .5), False), - (complex(0, .25), complex(0, -.5), True), - (complex(0, .25 + eps), complex(0, -.5), False), - ) - - self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5) - self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5) - - # atol and rtol tests for isclose - tests = ( - # Complex-specific tests - (complex(1, -1), complex(-1, 1), False), - (complex(1, -1), complex(2, -2), True), - (complex(-math.sqrt(2), math.sqrt(2)), - complex(-math.sqrt(.5), math.sqrt(.5)), True), - (complex(-math.sqrt(2), math.sqrt(2)), - complex(-math.sqrt(.501), math.sqrt(.499)), False), - (complex(2, 4), complex(1., 8.8523607), True), - (complex(2, 4), complex(1., 8.8523607 + eps), False), - (complex(1, 99), complex(4, 100), True), - ) - - self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5) - - # atol and rtol tests for compareTensors - tests = ( - (complex(1, -1), complex(-1, 1), False), - (complex(1, -1), complex(2, -2), True), - (complex(1, 99), complex(4, 100), False), - ) - - self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5) - - # equal_nan = True tests - tests = ( - (complex(1, 1), complex(1, float('nan')), False), - (complex(float('nan'), 1), complex(1, float('nan')), False), - (complex(float('nan'), 1), complex(float('nan'), 1), True), - ) - - with self.assertRaises(RuntimeError): - self._isclose_helper(tests, device, dtype, True) - - self._comparetensors_helper(tests, device, dtype, True) - - # Tests that isclose with rtol or atol values less than zero throws a - # RuntimeError - @dtypes(torch.bool, torch.uint8, - torch.int8, torch.int16, torch.int32, torch.int64, - torch.float16, torch.float32, torch.float64) - def test_isclose_atol_rtol_greater_than_zero(self, device, dtype): - t = torch.tensor((1,), device=device, dtype=dtype) - - with self.assertRaises(RuntimeError): - torch.isclose(t, t, atol=-1, rtol=1) - with self.assertRaises(RuntimeError): - torch.isclose(t, t, atol=1, rtol=-1) - with self.assertRaises(RuntimeError): - torch.isclose(t, t, atol=-1, rtol=-1) - - # XLA tests fail for self.assertRaises for complex dtypes - @onlyOnCPUAndCUDA - def test_complex_assert_raises(self, device): - for dtype in [torch.complex64, torch.complex128]: - size = [5, 5] - tensor = torch.rand(size, dtype=dtype, device=device) - - # index_add calls atomicAdd on cuda. - zeros = torch.zeros(size, dtype=dtype, device=device) - - # index_add is not supported for complex dtypes on cuda yet - if device.startswith('cuda') and dtype.is_complex: - self.assertRaises(RuntimeError, - lambda: zeros.index_add(0, torch.arange(0, size[0], dtype=torch.long, device=device), tensor)) - - self.assertRaises(RuntimeError, lambda: torch.sign(torch.tensor([4j], device=device, dtype=dtype))) - - a = torch.rand((2, 2), dtype=dtype, device=device) - b = torch.rand((2, 2), dtype=dtype, device=device) - c = torch.rand((2, 2), dtype=dtype, device=device) - alpha = 3 - - # addcmul is not supported for complex dtypes on cuda yet - if device.startswith('cuda') and dtype.is_complex: - self.assertRaises(RuntimeError, lambda: torch.addcmul(a, b, c, value=alpha)) - def check_internal_mem_overlap(self, inplace_op, num_inputs, dtype, device, expected_failure=False): @@ -5303,20 +2918,6 @@ def _test(op, output, input): with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): _test(op, data[0:sz], data[1:sz + 1]) - def binary_check_input_output_mem_overlap(self, op, device, - expected_failure=False): - sz = 3 - data = torch.randn(2 * sz, device=device) - other = torch.randn(sz, device=device) - - self.unary_check_input_output_mem_overlap( - data, sz, lambda input, out: op(other, input, out=out), - expected_failure=expected_failure) - - self.unary_check_input_output_mem_overlap( - data, sz, lambda input, out: op(input, other, out=out), - expected_failure=expected_failure) - def ternary_check_input_output_mem_overlap(self, op, device, expected_failure=False): sz = 3 @@ -5336,48 +2937,7 @@ def ternary_check_input_output_mem_overlap(self, op, device, data, sz, lambda input, out: op(other1, other2, input, out=out), expected_failure=expected_failure) - def _test_pow(self, base, exponent, np_exponent=None): - if np_exponent is None: - np_exponent = exponent - - def to_np(value): - if isinstance(value, torch.Tensor): - return value.cpu().numpy() - return value - - try: - expected = torch.from_numpy( - np.power(to_np(base), to_np(np_exponent))) - except ValueError as e: - err_msg = "Integers to negative integer powers are not allowed." - self.assertEqual(str(e), err_msg) - out = torch.empty_like(base) - test_cases = [ - lambda: base.pow(exponent), - lambda: base.pow_(exponent), - lambda: torch.pow(base, exponent), - lambda: torch.pow(base, exponent, out=out) - ] - for test_case in test_cases: - self.assertRaisesRegex(RuntimeError, err_msg, test_case) - else: - if isinstance(base, torch.Tensor): - actual = base.pow(exponent) - self.assertEqual(actual, expected.to(actual)) - actual = base.clone() - if torch.can_cast(torch.result_type(base, exponent), base.dtype): - actual2 = actual.pow_(exponent) - self.assertEqual(actual, expected) - self.assertEqual(actual2, expected) - else: - self.assertRaisesRegex(RuntimeError, "can't be cast", lambda: actual.pow_(exponent)) - actual = torch.pow(base, exponent) - self.assertEqual(actual, expected.to(actual)) - - actual2 = torch.pow(base, exponent, out=actual) - self.assertEqual(actual, expected.to(actual)) - self.assertEqual(actual2, expected.to(actual)) def _select_broadcastable_dims(self, dims_full=None): # select full dimensionality @@ -5416,9 +2976,6 @@ def test_scalar_check(self, device): zero_d = torch.randn((), device=device) one_d = torch.randn((1,), device=device) - # _multinomial_alias_setup - self.assertRaises(RuntimeError, lambda: torch._multinomial_alias_setup(zero_d)) - # remainder self.assertEqual((), torch.remainder(zero_d, zero_d).shape) self.assertEqual((), torch.remainder(zero_d, 2).shape) @@ -5533,9 +3090,6 @@ def test_scalar_check(self, device): self.assertEqual((1,), (zero_d_int & one_d_int).shape) self.assertEqual((1,), (one_d_int & 1).shape) - # _multinomial_alias_draw - self.assertRaises(RuntimeError, lambda: torch._multinomial_alias_draw(zero_d, zero_d_int, 10)) - # clone self.assertEqual((), zero_d.clone().shape) @@ -5716,165 +3270,7 @@ def warn_fn(): self.assertEqual(frameinfo.lineno - 6, warning.lineno) self.assertEqual(len(w), 1) - @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') - @dtypes(torch.float) - def test_isfinite_isinf_isnan(self, device, dtype): - vals = (-float('inf'), float('inf'), float('nan'), -1, 0, 1) - - self.compare_with_numpy(torch.isfinite, np.isfinite, vals, device, dtype) - self.compare_with_numpy(torch.isinf, np.isinf, vals, device, dtype) - self.compare_with_numpy(torch.isnan, np.isnan, vals, device, dtype) - - @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') - @dtypes(torch.long) - def test_isfinite_isinf_isnan_int(self, device, dtype): - vals = (-1, 0, 1) - - self.compare_with_numpy(torch.isfinite, np.isfinite, vals, device, dtype) - self.compare_with_numpy(torch.isinf, np.isinf, vals, device, dtype) - self.compare_with_numpy(torch.isnan, np.isnan, vals, device, dtype) - - @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') - @dtypes(*(torch.testing.get_all_fp_dtypes())) - def test_isposinf_isneginf_float(self, device, dtype): - ops = ((torch.isposinf, np.isposinf), (torch.isneginf, np.isneginf)) - vals = (-float('inf'), float('inf'), float('nan'), -1, 0, 1) - - for torch_op, numpy_op in ops: - if torch_op == torch.isposinf: - target_vals = (0, 1, 0, 0, 0, 0) - else: - target_vals = (1, 0, 0, 0, 0, 0) - - t = torch.tensor(vals, device=device, dtype=dtype) - # Manual check here as numpy does not support bfloat16 - if dtype == torch.bfloat16: - self.assertEqual(torch_op(t), - torch.tensor(target_vals, device=device, dtype=torch.bool)) - else: - self.compare_with_numpy(torch_op, numpy_op, vals, device, dtype) - - # test the boolean tensor as the `out=` parameter - out = torch.empty_like(t, dtype=torch.bool) - t_target = torch.tensor(target_vals, device=device, dtype=torch.bool) - torch_op(t, out=out) - self.assertEqual(out, t_target) - - @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') - @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool])) - def test_isposinf_isneginf_int_and_bool(self, device, dtype): - ops = ((torch.isposinf, np.isposinf), (torch.isneginf, np.isneginf)) - vals = (-1, 0, 1) - - for torch_op, numpy_op in ops: - self.compare_with_numpy(torch_op, numpy_op, vals, device, dtype) - - # test the boolean tensor as the `out=` parameter - t = torch.tensor(vals, device=device, dtype=dtype) - out = torch.empty_like(t, dtype=torch.bool) - t_target = torch.zeros_like(t, dtype=torch.bool) - torch_op(t, out=out) - self.assertEqual(out, t_target) - - @dtypes(torch.complex64, torch.complex128) - def test_isposinf_isneginf_complex(self, device, dtype): - torch_ops = (torch.isposinf, torch.isneginf) - vals = (complex(0, float('inf')), complex(1, -float('inf'))) - t = torch.tensor(vals, device=device, dtype=dtype) - out = torch.empty_like(t) - - for torch_op in torch_ops: - with self.assertRaisesRegex(RuntimeError, 'does not support complex inputs'): - torch_op(t) - with self.assertRaisesRegex(RuntimeError, 'does not support complex inputs'): - torch_op(t, out=out) - - @dtypes(*(torch.testing.get_all_dtypes(include_bool=False))) - def test_isposinf_isneginf_non_boolean_output(self, device, dtype): - # test non-boolean tensors as the `out=` parameters - # boolean outputs are tested in the above testcases - vals = (float('inf'), -float('inf'), 1.2) - t = torch.tensor(vals, device=device) - for torch_op in (torch.isposinf, torch.isneginf): - out = torch.empty_like(t, dtype=dtype) - with self.assertRaisesRegex(RuntimeError, 'does not support non-boolean outputs'): - torch_op(t, out=out) - - @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') - @dtypes(torch.complex64) - def test_isfinite_isinf_isnan_complex(self, device, dtype): - vals = ( - complex(-float('inf'), float('inf')), - complex(-float('inf'), 0), - complex(0, float('inf')), - complex(float('inf'), float('nan')), - complex(float('nan'), 0), - complex(-1, 0), - complex(0, 1) - ) - - self.compare_with_numpy(torch.isfinite, np.isfinite, vals, device, dtype) - self.compare_with_numpy(torch.isinf, np.isinf, vals, device, dtype) - self.compare_with_numpy(torch.isnan, np.isnan, vals, device, dtype) - - @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') - @dtypes(torch.complex64, torch.complex128) - def test_isreal_complex(self, device, dtype): - vals = (1, 1 + 1j, 2 + 0j, 3j, 2 - 1j, 2 - 0j) - self.compare_with_numpy(torch.isreal, np.isreal, vals, device, dtype) - - @dtypes(*torch.testing.get_all_dtypes()) - def test_isreal_noncomplex(self, device, dtype): - vals = (1, 2, 3) - # Manual check here since numpy doesn't support bfloat16 - result = torch.isreal(torch.tensor(vals, dtype=dtype)) - expected = torch.ones(result.size(), dtype=torch.bool, device=device) - self.assertEqual(result, expected) - - @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') - @dtypes(torch.complex64) - def test_isreal_nan_inf(self, device, dtype): - vals = ( - complex(-float('inf'), float('inf')), - complex(-float('inf'), 0), - complex(0, float('inf')), - complex(float('inf'), float('nan')), - complex(float('nan'), 0), - complex(-1, 0), - complex(0, 1) - ) - self.compare_with_numpy(torch.isreal, np.isreal, vals, device, dtype) - - @onlyCPU - def test_isfinite_type(self, device): - with self.assertRaises(TypeError): - torch.isfinite(1) # Parameter must be a tensor - - @onlyCPU - def test_isinf_type(self, device): - with self.assertRaises(TypeError): - torch.isinf(1) # Parameter must be a tensor - - @onlyCPU - @dtypes(torch.float) - def test_diag(self, device, dtype): - x = torch.rand(100, 100, dtype=dtype, device=device) - res1 = torch.diag(x) - res2 = torch.tensor((), dtype=dtype, device=device) - torch.diag(x, out=res2) - self.assertEqual(res1, res2) - - def test_diagonal(self, device): - x = torch.randn((100, 100), device=device) - result = torch.diagonal(x) - expected = torch.diag(x) - self.assertEqual(result, expected) - - x = torch.randn((100, 100), device=device) - result = torch.diagonal(x, 17) - expected = torch.diag(x, 17) - self.assertEqual(result, expected) - + # TODO: this test should be in test_nn.py def test_conv_transposed_backward_agnostic_to_memory_format(self, device): in_channels = 64 out_channels = 128 @@ -5891,6 +3287,7 @@ def test_conv_transposed_backward_agnostic_to_memory_format(self, device): input_ = layer_norm(input_.transpose(1, 2).contiguous()).contiguous() input_.sum().backward() + # TODO: this test should be in test_nn.py @largeTensorTest('12GB') def test_conv_transposed_large(self, device): # ConvTranspose3d works for large input tensors (gh-32866) @@ -5905,884 +3302,6 @@ def test_conv_transposed_large(self, device): x = torch.rand([1, 64, 8, 128, 172]).to(device) y = conv(x) - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - @onlyCPU - @dtypes(torch.float) - def test_diagonal_multidim(self, device, dtype): - x = torch.randn(10, 11, 12, 13, dtype=dtype, device=device) - xn = x.numpy() - for args in [(2, 2, 3), - (2,), - (-2, 1, 2), - (0, -2, -1)]: - result = torch.diagonal(x, *args) - expected = xn.diagonal(*args) - self.assertEqual(expected.shape, result.shape) - self.assertEqual(expected, result) - # test non-continguous - xp = x.permute(1, 2, 3, 0) - result = torch.diagonal(xp, 0, -2, -1) - expected = xp.numpy().diagonal(0, -2, -1) - self.assertEqual(expected.shape, result.shape) - self.assertEqual(expected, result) - - @onlyCPU - @dtypes(torch.float) - def test_broadcast_tensors(self, device, dtype): - x0 = torch.randn(2, 1, 3, dtype=dtype, device=device) - x1 = torch.randn(3, dtype=dtype, device=device) - x2 = torch.randn(3, 1, dtype=dtype, device=device) - expected_size = (2, 3, 3) - - y0, y1, y2 = torch.broadcast_tensors(x0, x1, x2) - self.assertTrue(y0.size() == expected_size) - self.assertTrue(y1.size() == expected_size) - self.assertTrue(y2.size() == expected_size) - - def _do_pow_for_exponents(self, m1, exponents, pow_fn, atol): - for num in exponents: - if isinstance(num, int) and num < 0 and not m1.is_floating_point() and not m1.is_complex(): - with self.assertRaisesRegex(RuntimeError, - r'Integers to negative integer powers are not allowed\.'): - torch.pow(m1[4], num) - else: - # base - tensor, exponent - number - # contiguous - res1 = torch.pow(m1[4], num) - res2 = res1.clone().zero_() - # `math.pow` has issues with complex exponentiation so we need to resort to normal `pow`. - for i in range(res2.size(0)): - res2[i] = pow_fn(m1[4][i], num) - rtol = 0 if atol is not None else None - self.assertEqual(res1, res2, atol=atol, rtol=rtol) - - # non-contiguous - res1 = torch.pow(m1[:, 4], num) - res2 = res1.clone().zero_() - for i in range(res2.size(0)): - res2[i] = pow_fn(m1[i, 4], num) - self.assertEqual(res1, res2, atol=atol, rtol=rtol) - - # scalar ** tensor to enforce correct handling of dtypes for __rpow__(). - expected_dtype = torch.result_type(num, m1) - res1 = num ** m1[4] - res2 = torch.tensor(num, dtype=expected_dtype, device=m1.device) ** m1[4] - self.assertEqual(res1, res2) - self.assertEqual(res1.dtype, expected_dtype) - - def test_pow(self, device): - # [res] torch.pow([res,] x) - - # pow has dedicated implementation for different exponents - for dtype in torch.testing.get_all_math_dtypes(device): - - # This test won't work on torch.half because math.pow will generate a much more accurate result. We skip it - # for now. - if dtype == torch.half: - continue - - # deferring to https://github.com/pytorch/pytorch/pull/36793 - if dtype.is_complex: - continue - - m1 = torch.empty(0, dtype=dtype, device=device) - if m1.is_floating_point() or m1.is_complex(): - m1 = torch.rand(100, 100, dtype=dtype, device=device) + 0.5 - else: - # math.pow will overflow and throw exceptions for large integers - range_high = 4 if dtype in (torch.int8, torch.uint8) else 10 - m1 = torch.randint(1, range_high, (100, 100), dtype=dtype, device=device) - - exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3] - complex_exponents = [-2.5j, -1.0j, 0j, 1.0j, 2.5j, 1.0 + 1.0j, -1.0 - 1.5j, 3.3j] - if m1.is_complex(): - self._do_pow_for_exponents(m1, exponents + complex_exponents, pow, 10e-4) - else: - self._do_pow_for_exponents(m1, exponents, math.pow, None) - self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4) - - # base - number, exponent - tensor - # contiguous - res1 = torch.pow(3, m1[4]) - res2 = res1.clone().zero_() - for i in range(res2.size(0)): - res2[i] = math.pow(3, m1[4, i]) - self.assertEqual(res1, res2) - - # non-contiguous - res1 = torch.pow(3, m1[:, 4]) - res2 = res1.clone().zero_() - for i in range(res2.size(0)): - res2[i] = math.pow(3, m1[i][4]) - self.assertEqual(res1, res2) - - # resize behavior for exp == 1 - out = torch.zeros(1, dtype=dtype, device=device) - torch.pow(m1, 1, out=out) - self.assertEqual(out, m1) - - @skipCUDAIf( - _get_torch_cuda_version() < [10, 0] and not TEST_MAGMA, - "On cuda 9.2, torch.inverse relies on magma" - ) - @skipCPUIfNoLapack - def test_inverse(self, device): - from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value - - def test_inverse_helper(matrix, batches, n): - identity = torch.eye(n, dtype=torch.float64, device=device) - - # correctness test, check matrix*matrix_inverse == identity - matrix_inverse = torch.inverse(matrix) - - self.assertEqual(identity.expand_as(matrix), torch.matmul(matrix, matrix_inverse), atol=1e-8, rtol=0) - self.assertEqual(identity.expand_as(matrix), torch.matmul(matrix_inverse, matrix), atol=1e-8, rtol=0) - - # torch.inverse with out and batches - matrix_inverse_out = torch.empty(*batches, n, n, dtype=torch.float64, device=device) - torch.inverse(matrix, out=matrix_inverse_out) - self.assertEqual(matrix_inverse_out, matrix_inverse, atol=0, rtol=0) - - # batched matrices: 3+ dimensional tensors, check matrix_inverse same as single-inverse for each matrix - if matrix.ndim > 2: - expected_inv_list = [] - for mat in matrix.contiguous().view(-1, n, n): - expected_inv_list.append(torch.inverse(mat)) - expected_inv = torch.stack(expected_inv_list).view(*batches, n, n) - self.assertEqual(matrix_inverse, expected_inv) - - for batches, n in product( - [[], [1], [4], [2, 3], [32]], - [5, 256] - ): - # large batch size and large matrix size will be tested in test_inverse_many_batches (slow test) - if batches and batches[0] == 32 and n == 256: - continue - _matrices = random_fullrank_matrix_distinct_singular_value(n, *batches).to(device) - test_inverse_helper(_matrices, batches, n) - test_inverse_helper(_matrices.transpose(-2, -1), batches, n) - test_inverse_helper( - random_fullrank_matrix_distinct_singular_value(n * 2, *batches).to(device) - .view(-1, n * 2, n * 2)[:, ::2, ::2].view(*batches, n, n), - batches, n - ) - - # incorrect input test - with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"): - torch.inverse(torch.randn(2, 3, 4, 3)) - - # test for zero-sized tensor - def test_inverse_helper_zero_size(size): - data = torch.zeros(*size, device=device) - out = torch.inverse(data) - self.assertTrue(out.size() == data.size()) - - test_inverse_helper_zero_size([0, 0]) - test_inverse_helper_zero_size([3, 0, 0]) - test_inverse_helper_zero_size([0, 3, 3]) - - # non-contiguous inputs - if not TEST_NUMPY: - return - - from numpy.linalg import inv - matrices = random_fullrank_matrix_distinct_singular_value(3, 2).to(device).permute(0, 2, 1) - assert not matrices.is_contiguous() - matrices_inverse = torch.inverse(matrices) - expected_inv = torch.as_tensor(inv(matrices.cpu().numpy())) - self.assertEqual(matrices_inverse, expected_inv.to(device)) - - @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') - @onlyOnCPUAndCUDA - @dtypes(torch.int8, torch.int16, torch.int32, torch.int64) - def test_signed_shift(self, device, dtype): - "Ensure that signed integer bit shifting works as expected." - a = torch.tensor([-10, 10], device=device, dtype=dtype) # [11...1110110, 1010] - expected_l = torch.tensor([-40, 40], device=device, dtype=dtype) # [11...11011000, 101000] - self.assertEqual(a << 2, expected_l) - self.compare_with_numpy(lambda x: x << 2, lambda x: np.left_shift(x, 2), a) - expected_r = torch.tensor([-5, 5], device=device, dtype=dtype) # [1111...111011, 101] - self.assertEqual(a >> 1, expected_r) - self.compare_with_numpy(lambda x: x >> 1, lambda x: np.right_shift(x, 1), a) - - def test_bitwise_not(self, device): - res = 0xffff - torch.arange(127, dtype=torch.int8, device=device) - for dtype in (torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): - if dtype == torch.bool: - a = torch.tensor([True, False], device=device) - expected_res = torch.tensor([False, True], device=device) - else: - a = torch.arange(127, dtype=dtype, device=device) - expected_res = res.to(dtype) - # new tensor - self.assertEqual(expected_res, a.bitwise_not()) - # out - b = torch.empty(0, dtype=dtype, device=device) - torch.bitwise_not(a, out=b) - self.assertEqual(expected_res, b) - # in-place - a.bitwise_not_() - self.assertEqual(expected_res, a) - - # test exceptions - for dtype in (torch.half, torch.float, torch.double): - a = torch.zeros(10, dtype=dtype, device=device) - # new tensor - with self.assertRaises(RuntimeError): - a.bitwise_not() - # out - b = torch.empty(0, dtype=dtype, device=device) - with self.assertRaises(RuntimeError): - torch.bitwise_not(a, out=b) - # in-place - with self.assertRaises(RuntimeError): - a.bitwise_not_() - - def test_bitwise_and(self, device): - for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): - a = torch.tensor([1, -2, 3], dtype=dtype, device=device) - b = torch.tensor([2, 1, 3], dtype=dtype, device=device) - expected_res = torch.tensor([0, 0, 3], dtype=dtype, device=device) - b_scalar = 2 - expected_res_scalar = torch.tensor([0, 2, 2], dtype=dtype, device=device) - - # standard version - self.assertEqual(torch.bitwise_and(a, b), expected_res) - self.assertEqual(torch.bitwise_and(a, b_scalar), expected_res_scalar) - - # out - c = torch.empty(0, dtype=dtype, device=device) - torch.bitwise_and(a, b, out=c) - self.assertEqual(c, expected_res) - torch.bitwise_and(a, b_scalar, out=c) - self.assertEqual(c, expected_res_scalar) - - # in-place - a1 = a.clone() - a1.bitwise_and_(b) - self.assertEqual(a1, expected_res) - a.bitwise_and_(b_scalar) - self.assertEqual(a, expected_res_scalar) - - self.assertEqual(torch.tensor([False, True, False], device=device), - torch.bitwise_and(torch.tensor([True, True, False], device=device), - torch.tensor([False, True, False], device=device))) - - def test_bitwise_or(self, device): - for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): - a = torch.tensor([1, -2, 3], dtype=dtype, device=device) - b = torch.tensor([2, 1, 3], dtype=dtype, device=device) - expected_res = torch.tensor([3, -1, 3], dtype=dtype, device=device) - b_scalar = 2 - expected_res_scalar = torch.tensor([3, -2, 3], dtype=dtype, device=device) - - # standard version - self.assertEqual(torch.bitwise_or(a, b), expected_res) - self.assertEqual(torch.bitwise_or(a, b_scalar), expected_res_scalar) - - # out - c = torch.empty(0, dtype=dtype, device=device) - torch.bitwise_or(a, b, out=c) - self.assertEqual(c, expected_res) - torch.bitwise_or(a, b_scalar, out=c) - self.assertEqual(c, expected_res_scalar) - - # in-place - a1 = a.clone() - a1.bitwise_or_(b) - self.assertEqual(a1, expected_res) - a.bitwise_or_(b_scalar) - self.assertEqual(a, expected_res_scalar) - - self.assertEqual(torch.tensor([True, True, False], device=device), - torch.bitwise_or(torch.tensor([True, True, False], device=device), - torch.tensor([False, True, False], device=device))) - - def test_bitwise_xor(self, device): - for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): - a = torch.tensor([1, -2, 3], dtype=dtype, device=device) - b = torch.tensor([2, 1, 3], dtype=dtype, device=device) - expected_res = torch.tensor([3, -1, 0], dtype=dtype, device=device) - b_scalar = 2 - expected_res_scalar = torch.tensor([3, -4, 1], dtype=dtype, device=device) - - # standard version - self.assertEqual(torch.bitwise_xor(a, b), expected_res) - self.assertEqual(torch.bitwise_xor(a, b_scalar), expected_res_scalar) - - # out - c = torch.empty(0, dtype=dtype, device=device) - torch.bitwise_xor(a, b, out=c) - self.assertEqual(c, expected_res) - torch.bitwise_xor(a, b_scalar, out=c) - self.assertEqual(c, expected_res_scalar) - - # in-place - a1 = a.clone() - a1.bitwise_xor_(b) - self.assertEqual(a1, expected_res) - a.bitwise_xor_(b_scalar) - self.assertEqual(a, expected_res_scalar) - - self.assertEqual(torch.tensor([True, False, False], device=device), - torch.bitwise_xor(torch.tensor([True, True, False], device=device), - torch.tensor([False, True, False], device=device))) - - @onlyOnCPUAndCUDA - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - @dtypes(*list(product(torch.testing.get_all_dtypes(include_complex=False), - torch.testing.get_all_dtypes(include_complex=False)))) - def test_heaviside(self, device, dtypes): - input_dtype = dtypes[0] - values_dtype = dtypes[1] - - rng = np.random.default_rng() - input = np.array(rng.integers(-10, 10, size=10), - dtype=torch_to_numpy_dtype_dict[input_dtype if (input_dtype != torch.bfloat16) else torch.float64]) - input[0] = input[3] = input[7] = 0 - values = np.array(rng.integers(-10, 10, size=10), - dtype=torch_to_numpy_dtype_dict[values_dtype if (values_dtype != torch.bfloat16) else torch.float64]) - np_result = torch.from_numpy(np.heaviside(input, values)).to(device=device, dtype=input_dtype) - - input = torch.from_numpy(input).to(device=device, dtype=input_dtype) - values = torch.from_numpy(values).to(device=device, dtype=values_dtype) - out = torch.empty_like(input) - - if input_dtype == values_dtype: - torch_result = torch.heaviside(input, values) - self.assertEqual(np_result, torch_result) - - torch_result = input.heaviside(values) - self.assertEqual(np_result, torch_result) - - torch.heaviside(input, values, out=out) - self.assertEqual(np_result, out) - - input.heaviside_(values) - self.assertEqual(np_result, input) - else: - with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): - torch.heaviside(input, values) - with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): - input.heaviside(values) - with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): - torch.heaviside(input, values, out=out) - with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): - input.heaviside_(values) - - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - @dtypes(*list(product(torch.testing.get_all_complex_dtypes(), - torch.testing.get_all_complex_dtypes()))) - def test_heaviside_complex(self, device, dtypes): - input_dtype = dtypes[0] - values_dtype = dtypes[1] - - data = (complex(0, -6), complex(-1, 3), complex(1, 1)) - input = torch.tensor(data, device=device, dtype=input_dtype) - values = torch.tensor(data, device=device, dtype=values_dtype) - out = torch.empty_like(input) - real = input.real - - with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): - torch.heaviside(input, real) - with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): - real.heaviside(values) - with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): - input.heaviside_(values) - with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): - torch.heaviside(real, real, out=out) - - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - @dtypes(*torch.testing.get_all_dtypes()) - def test_logical_not(self, device, dtype): - data = [10, 1, 0.3, 0, -0.3, -1, -10] - a = torch.tensor(data, dtype=dtype, device=device) - - # do this before constructing the numpy array because np can't construct - # bfloat16 tensors. Can we define our own dtype in NumPy so testing would be easier? - if dtype == torch.bfloat16 or dtype.is_complex: - self.assertRaises(RuntimeError, lambda: a.logical_not()) - self.assertRaises(RuntimeError, lambda: a.logical_not_()) - raise unittest.SkipTest('logical_not not supported on {}'.format(dtype)) - - a_np = np.array(data, dtype=torch_to_numpy_dtype_dict[dtype]) - self.assertEqual(np.logical_not(a_np), torch.logical_not(a).to('cpu')) - self.assertEqual(np.logical_not(a_np, out=a_np), a.logical_not_().to('cpu')) - - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - @dtypes(*list(product(torch.testing.get_all_dtypes(), - torch.testing.get_all_dtypes()))) - def test_logical_not_out(self, device, dtypes): - dtype = dtypes[0] - out_dtype = dtypes[1] - data = [10, 1, 0.3, 0, -0.3, -1, -10] - a = torch.tensor(data, dtype=dtype, device=device) - out = torch.empty(a.shape, dtype=out_dtype, device=device) - - if (dtype == torch.bfloat16 or dtype.is_complex or - out_dtype == torch.bfloat16 or out_dtype.is_complex): - self.assertRaises(RuntimeError, lambda: torch.logical_not(a, out=out)) - raise unittest.SkipTest('logical_not not supported on {}'.format(out_dtype)) - - out_np = np.empty(a.shape, dtype=torch_to_numpy_dtype_dict[out_dtype]) - - self.assertEqual(a, a.cpu().numpy()) - torch.logical_not(a, out=out) - np.logical_not(a.cpu().numpy(), out=out_np) - self.assertEqual(out_np, out.to('cpu')) - - def _test_logical(self, device, dtypes, op, a_, b_, expected_res_): - expected_res = torch.tensor(expected_res_, dtype=dtypes[0], device=device) - a = torch.tensor(a_, dtype=dtypes[0], device=device) - b = torch.tensor(b_, dtype=dtypes[1], device=device) - - # new tensor - self.assertEqual(expected_res.bool(), getattr(a, op)(b)) - # out - c = torch.empty(0, dtype=torch.bool, device=device) - getattr(torch, op)(a, b, out=c) - self.assertEqual(expected_res.bool(), c) - - # in-place - # TODO: remove when different dtypes as operands are supported - if dtypes[0] != dtypes[1]: - with self.assertRaises(RuntimeError): - getattr(a, op + '_')(b) - return - - getattr(a, op + '_')(b) - self.assertEqual(expected_res, a) - - @dtypes(*product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes())) - def test_logical_xor(self, device, dtypes): - self._test_logical(device, dtypes, 'logical_xor', [10, 0, 1, 0], [1, 0, 0, 10], [0, 0, 1, 1]) - - @dtypes(*product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes())) - def test_logical_and(self, device, dtypes): - self._test_logical(device, dtypes, 'logical_and', [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 0, 0]) - - @dtypes(*product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes())) - def test_logical_or(self, device, dtypes): - self._test_logical(device, dtypes, 'logical_or', [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 1, 1]) - - # Tests clamp and its alias, clip - def test_clamp(self, device): - op_list = ((torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_), - (torch.clip, torch.Tensor.clip, torch.Tensor.clip_)) - for op, method_op, inplace_op in op_list: - - m1 = torch.rand(100, device=device).mul(5).add(-2.5) # uniform in [-2.5, 2.5] - # just in case we're extremely lucky. - min_val = -1 - max_val = 1 - m1[1] = min_val - m1[2] = max_val - - res1 = m1.clone() - inplace_op(res1, min_val, max_val) - res2 = m1.clone() - for i in iter_indices(res2): - res2[i] = max(min_val, min(max_val, res2[i])) - self.assertEqual(res1, res2) - - out = m1.clone() - op(m1, min=min_val, max=max_val, out=out) - self.assertEqual(out, res1) - - res1 = op(m1, min=min_val) - res2 = m1.clone() - for i in iter_indices(res2): - res2[i] = max(min_val, res2[i]) - self.assertEqual(res1, res2) - - op(m1, min=min_val, out=out) - self.assertEqual(out, res1) - - res1 = op(m1, max=max_val) - res2 = m1.clone() - for i in iter_indices(res2): - res2[i] = min(max_val, res2[i]) - self.assertEqual(res1, res2) - - op(m1, max=max_val, out=out) - self.assertEqual(out, res1) - - # if the tensor contains nan case - test_tens = torch.tensor([nan], device=device) - - res1 = test_tens.clone() - inplace_op(res1, min_val, max_val) - res2 = test_tens.clone() - for i in iter_indices(res2): - res2[i] = max(min(res2[i], max_val), min_val) - self.assertEqual(torch.isnan(res1), torch.isnan(res2)) - - out = test_tens.clone() - op(test_tens, min=min_val, max=max_val, out=out) - self.assertEqual(torch.isnan(out), torch.isnan(res1)) - - res1 = op(test_tens, min=min_val) - res2 = test_tens.clone() - for i in iter_indices(res2): - res2[i] = max(res2[i], min_val) - self.assertEqual(torch.isnan(res1), torch.isnan(res2)) - - op(test_tens, min=min_val, out=out) - self.assertEqual(torch.isnan(out), torch.isnan(res1)) - - res1 = op(test_tens, max=max_val) - res2 = test_tens.clone() - for i in iter_indices(res2): - res2[i] = min(res2[i], max_val) - self.assertEqual(torch.isnan(res1), torch.isnan(res2)) - - op(test_tens, max=max_val, out=out) - self.assertEqual(torch.isnan(out), torch.isnan(res1)) - - error_msg = 'At least one of \'min\' or \'max\' must not be None' - with self.assertRaisesRegex(RuntimeError, error_msg): - method_op(m1) - with self.assertRaisesRegex(RuntimeError, error_msg): - inplace_op(m1) - - @onlyOnCPUAndCUDA - @dtypes(torch.float32, torch.float64) - def test_torch_complex(self, device, dtype): - real = torch.tensor([1, 2], device=device, dtype=dtype) - imag = torch.tensor([3, 4], device=device, dtype=dtype) - z = torch.complex(real, imag) - complex_dtype = torch.complex64 if dtype == torch.float32 else torch.complex128 - self.assertEqual(torch.tensor([1.0 + 3.0j, 2.0 + 4.0j], dtype=complex_dtype), z) - - @onlyOnCPUAndCUDA - @dtypes(torch.float32, torch.float64) - def test_torch_polar(self, device, dtype): - abs = torch.tensor([1, 2, -3, -4.5, 1, 1], device=device, dtype=dtype) - angle = torch.tensor([math.pi / 2, 5 * math.pi / 4, 0, -11 * math.pi / 6, math.pi, -math.pi], - device=device, dtype=dtype) - z = torch.polar(abs, angle) - complex_dtype = torch.complex64 if dtype == torch.float32 else torch.complex128 - self.assertEqual(torch.tensor([1j, -1.41421356237 - 1.41421356237j, -3, - -3.89711431703 - 2.25j, -1, -1], - dtype=complex_dtype), - z, atol=1e-5, rtol=1e-5) - - @onlyOnCPUAndCUDA - @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, - torch.float16, torch.complex64, torch.complex128, torch.bool) - def test_torch_complex_floating_dtype_error(self, device, dtype): - for op in (torch.complex, torch.polar): - a = torch.tensor([1, 2], device=device, dtype=dtype) - b = torch.tensor([3, 4], device=device, dtype=dtype) - error = r"Expected both inputs to be Float or Double tensors but " \ - r"got [A-Za-z]+ and [A-Za-z]+" - with self.assertRaisesRegex(RuntimeError, error): - op(a, b) - - @onlyOnCPUAndCUDA - @dtypes(torch.float32, torch.float64) - def test_torch_complex_same_dtype_error(self, device, dtype): - - def dtype_name(dtype): - return 'Float' if dtype == torch.float32 else 'Double' - - for op in (torch.complex, torch.polar): - other_dtype = torch.float64 if dtype == torch.float32 else torch.float32 - a = torch.tensor([1, 2], device=device, dtype=dtype) - b = torch.tensor([3, 4], device=device, dtype=other_dtype) - error = "Expected object of scalar type {} but got scalar type " \ - "{} for second argument".format(dtype_name(dtype), - dtype_name(other_dtype)) - with self.assertRaisesRegex(RuntimeError, error): - op(a, b) - - @onlyOnCPUAndCUDA - @dtypes(torch.float32, torch.float64) - def test_torch_complex_out_dtype_error(self, device, dtype): - - def dtype_name(dtype): - return 'Float' if dtype == torch.float32 else 'Double' - - def complex_dtype_name(dtype): - return 'ComplexFloat' if dtype == torch.complex64 else 'ComplexDouble' - - for op in (torch.complex, torch.polar): - a = torch.tensor([1, 2], device=device, dtype=dtype) - b = torch.tensor([3, 4], device=device, dtype=dtype) - out = torch.zeros(2, device=device, dtype=dtype) - expected_dtype = torch.complex64 if dtype == torch.float32 else torch.complex128 - error = "Expected object of scalar type {} but got scalar type " \ - "{} for argument 'out'".format( - complex_dtype_name(expected_dtype), dtype_name(dtype)) - with self.assertRaisesRegex(RuntimeError, error): - op(a, b, out=out) - - def test_cat_empty_legacy(self, device): - # FIXME: this is legacy behavior and should be removed - # when we support empty tensors with arbitrary sizes - dtype = torch.float32 - - x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device) - empty = torch.randn((0,), dtype=dtype, device=device) - - res1 = torch.cat([x, empty], dim=1) - res2 = torch.cat([empty, x], dim=1) - self.assertEqual(res1, res2) - - res1 = torch.cat([empty, empty], dim=1) - self.assertEqual(res1, empty) - - with self.assertRaisesRegex(RuntimeError, - 'non-empty list of Tensors'): - torch.cat([], dim=1) - - def test_cat_empty(self, device): - dtype = torch.float32 - - x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device) - empty = torch.randn((4, 0, 32, 32), dtype=dtype, device=device) - - res1 = torch.cat([x, empty], dim=1) - res2 = torch.cat([empty, x], dim=1) - self.assertEqual(res1, res2) - - res1 = torch.cat([empty, empty], dim=1) - self.assertEqual(res1, empty) - - # check non-legacy-behavior (sizes don't match) - empty = torch.randn((4, 0, 31, 32), dtype=dtype, device=device) - self.assertRaises(RuntimeError, lambda: torch.cat([x, empty], dim=1)) - self.assertRaises(RuntimeError, lambda: torch.cat([empty, x], dim=1)) - - # check non-legacy-behavior (dimensions don't match) - empty = torch.randn((4, 0), dtype=dtype, device=device) - self.assertRaises(RuntimeError, lambda: torch.cat([x, empty], dim=1)) - self.assertRaises(RuntimeError, lambda: torch.cat([empty, x], dim=1)) - - def test_cat_out(self, device): - x = torch.zeros((0), device=device) - y = torch.randn((4, 6), device=device) - - with self.assertRaisesRegex( - RuntimeError, r"unsupported operation:.* input tensor 0"): - torch.cat([x, y], dim=0, out=x) - - with self.assertRaisesRegex( - RuntimeError, r"unsupported operation:.* input tensor 1"): - torch.cat([x, y], dim=0, out=y) - - z = torch.zeros((4, 6), device=device) - with self.assertRaisesRegex( - RuntimeError, r"unsupported operation:.* input tensor 1"): - torch.cat([y, z], out=z[:2, :]) - - w = y.view(-1).clone() - a = torch.cat([w[:2], w[4:6]]) - b = torch.cat([w[:2], w[4:6]], out=w[6:10]) - self.assertEqual(a, b) - self.assertEqual(w[:6], y.view(-1)[:6]) - - def test_cat_out_channels_last(self, device): - x = torch.randn((4, 3, 8, 8)) - y = torch.randn(x.shape) - res1 = torch.cat((x, y)) - z = res1.clone().contiguous(memory_format=torch.channels_last) - res2 = torch.cat((x, y), out=z) - self.assertEqual(res1, res2) - - @onlyCPU - def test_cat_in_channels_last(self, device): - for dim in range(4): - x = torch.randn((4, 15, 8, 8), device=device) - y = torch.randn(x.shape, device=device) - res1 = torch.cat((x, y), dim=dim) - x = x.clone().contiguous(memory_format=torch.channels_last) - y = y.clone().contiguous(memory_format=torch.channels_last) - res2 = torch.cat((x, y), dim=dim) - self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last)) - self.assertEqual(res1, res2) - - # Size larger than grain size. - x = torch.randn((4, 15, 256, 256), device=device) - y = torch.randn(x.shape, device=device) - res1 = torch.cat((x, y), dim=dim) - x = x.clone().contiguous(memory_format=torch.channels_last) - y = y.clone().contiguous(memory_format=torch.channels_last) - res2 = torch.cat((x, y), dim=dim) - self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last)) - self.assertEqual(res1, res2) - - @onlyCUDA - def test_cat_preserve_channels_last(self, device): - x = torch.randn((4, 3, 8, 8), device=device) - y = torch.randn(x.shape, device=device) - res1 = torch.cat((x, y)) - res2 = torch.cat((x.contiguous(memory_format=torch.channels_last), y.contiguous(memory_format=torch.channels_last))) - self.assertEqual(res1, res2) - self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last)) - - @onlyCUDA - @deviceCountAtLeast(2) - def test_cat_different_devices(self, devices): - cuda0 = torch.randn((3, 3), device=devices[0]) - cuda1 = torch.randn((3, 3), device=devices[1]) - with self.assertRaisesRegex(RuntimeError, - "input tensors must be on the same device"): - torch.cat((cuda0, cuda1)) - cpu = torch.randn(3, 3) - with self.assertRaisesRegex(RuntimeError, - "input tensors must be on the same device"): - torch.cat((cuda0, cpu)) - with self.assertRaisesRegex(RuntimeError, - "input tensors must be on the same device"): - torch.cat((cpu, cuda0)) - - def test_block_diag(self, device): - def block_diag_workaround(*arrs): - arrs_expanded = [] - for a in arrs: - if a.dim() == 2: - arrs_expanded.append(a) - elif a.dim() == 1: - arrs_expanded.append(a.expand(1, a.size(0))) - elif a.dim() == 0: - arrs_expanded.append(a.expand(1, 1)) - shapes = torch.tensor([a.shape for a in arrs_expanded], device=device) - out = torch.zeros( - torch.sum(shapes, dim=0).tolist(), - dtype=arrs_expanded[0].dtype, - device=device - ) - r, c = 0, 0 - for i, (rr, cc) in enumerate(shapes): - out[r:r + rr, c:c + cc] = arrs_expanded[i] - r += rr - c += cc - return out - - tensors = [ - torch.rand((2, 2), device=device), - torch.rand((2, 3), device=device), - torch.rand(10, device=device), - torch.rand((8, 1), device=device), - torch.rand(1, device=device)[0] - ] - result = torch.block_diag(*tensors) - result_check = block_diag_workaround(*tensors) - self.assertEqual(result, result_check) - - tensor = torch.rand(1, device=device)[0] - result = torch.block_diag(tensor) - result_check = tensor.expand(1, 1) - self.assertEqual(result, result_check) - - tensor = torch.rand(10, device=device) - result = torch.block_diag(tensor) - result_check = tensor.expand(1, tensor.size(0)) - self.assertEqual(result, result_check) - - result = torch.block_diag() - result_check = torch.empty(1, 0, device=device) - self.assertEqual(result, result_check) - self.assertEqual(result.device.type, 'cpu') - - test_dtypes = [ - torch.uint8, - torch.int8, - torch.int16, - torch.int32, - torch.int64, - torch.float32, - torch.float64, - torch.complex64, - torch.complex128 - ] - # Test pairs of different dtypes - for dtype1 in test_dtypes: - for dtype2 in test_dtypes: - a = torch.tensor(1, device=device, dtype=dtype1) - b = torch.tensor(2, device=device, dtype=dtype2) - result = torch.block_diag(a, b) - result_dtype = torch.result_type(a, b) - result_check = torch.tensor([[1, 0], [0, 2]], device=device, dtype=result_dtype) - self.assertEqual(result, result_check) - - with self.assertRaisesRegex( - RuntimeError, - "torch.block_diag: Input tensors must have 2 or fewer dimensions. Input 1 has 3 dimensions" - ): - torch.block_diag(torch.tensor(5), torch.tensor([[[6]]])) - - with self.assertRaisesRegex( - RuntimeError, - "torch.block_diag: Input tensors must have 2 or fewer dimensions. Input 0 has 4 dimensions" - ): - torch.block_diag(torch.tensor([[[[6]]]])) - - if device != 'cpu': - with self.assertRaisesRegex( - RuntimeError, - ( - "torch.block_diag: input tensors must all be on the same device." - " Input 0 is on device cpu and input 1 is on device " - ) - ): - torch.block_diag(torch.ones(2, 2).cpu(), torch.ones(2, 2, device=device)) - - @unittest.skipIf(not TEST_SCIPY, "Scipy not found") - def test_block_diag_scipy(self, device): - import scipy.linalg - scipy_tensors_list = [ - [ - 1, - [2], - [], - [3, 4, 5], - [[], []], - [[6], [7.3]] - ], - [ - [[1, 2], [3, 4]], - [1] - ], - [ - [[4, 9], [7, 10]], - [4.6, 9.12], - [1j + 3] - ], - [] - ] - - expected_torch_types = [ - torch.float32, - torch.int64, - torch.complex64, - torch.float32 - ] - - expected_scipy_types = [ - torch.float64, - # windows scipy block_diag returns int32 types - torch.int32 if IS_WINDOWS else torch.int64, - torch.complex128, - torch.float64 - ] - - for scipy_tensors, torch_type, scipy_type in zip(scipy_tensors_list, expected_torch_types, expected_scipy_types): - torch_tensors = [torch.tensor(t, device=device) for t in scipy_tensors] - torch_result = torch.block_diag(*torch_tensors) - self.assertEqual(torch_result.dtype, torch_type) - - scipy_result = torch.tensor( - scipy.linalg.block_diag(*scipy_tensors), - device=device - ) - self.assertEqual(scipy_result.dtype, scipy_type) - scipy_result = scipy_result.to(torch_type) - - self.assertEqual(torch_result, scipy_result) - def test_is_set_to(self, device): t1 = torch.empty(3, 4, 9, 10, device=device) t2 = torch.empty(3, 4, 9, 10, device=device) @@ -6813,1240 +3332,6 @@ def test_is_set_to(self, device): self.assertFalse(t1.is_set_to(t2)) self.assertFalse(t2.is_set_to(t1)) - @slowTest - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_inverse_many_batches(self, device): - from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value - - def test_inverse_many_batches_helper(b, n): - matrices = random_fullrank_matrix_distinct_singular_value(b, n, n).to(device) - matrices_inverse = torch.inverse(matrices) - self.assertEqual(torch.matmul(matrices_inverse, matrices), - torch.eye(b, dtype=torch.float64, device=device).expand_as(matrices)) - - test_inverse_many_batches_helper(5, 256) - test_inverse_many_batches_helper(3, 512) - test_inverse_many_batches_helper(64, 64) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_pinverse(self, device, dtype): - from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value as fullrank - - def run_test(M): - # Testing against definition for pseudo-inverses - MPI = torch.pinverse(M) - if M.numel() > 0: - self.assertEqual(M, M.matmul(MPI).matmul(M), atol=1e-8, rtol=0, msg='pseudo-inverse condition 1') - self.assertEqual(MPI, MPI.matmul(M).matmul(MPI), atol=1e-8, rtol=0, msg='pseudo-inverse condition 2') - self.assertEqual(M.matmul(MPI), (M.matmul(MPI)).transpose(-2, -1), - atol=1e-8, rtol=0, msg='pseudo-inverse condition 3') - self.assertEqual(MPI.matmul(M), (MPI.matmul(M)).transpose(-2, -1), - atol=1e-8, rtol=0, msg='pseudo-inverse condition 4') - else: - self.assertEqual(M.shape, MPI.shape[:-2] + (MPI.shape[-1], MPI.shape[-2])) - for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5), # square matrices - (3, 2), (5, 3, 2), (7, 5, 3, 2), # fat matrices - (2, 3), (5, 2, 3), (7, 5, 2, 3), # thin matrices - (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]: # zero numel matrices - M = torch.randn(*sizes, dtype=dtype, device=device) - run_test(M) - - # Test inverse and pseudo-inverse for invertible matrix - for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5)]: - matsize = sizes[-1] - batchdims = sizes[:-2] - M = fullrank(matsize, *batchdims, dtype=dtype, device=device) - self.assertEqual(torch.eye(matsize, dtype=dtype, device=device).expand(sizes), M.pinverse().matmul(M), - atol=1e-7, rtol=0, msg='pseudo-inverse for invertible matrix') - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_matrix_rank(self, device): - a = torch.eye(10, device=device) - self.assertEqual(torch.matrix_rank(a).item(), 10) - self.assertEqual(torch.matrix_rank(a, True).item(), 10) - - a[5, 5] = 0 - self.assertEqual(torch.matrix_rank(a).item(), 9) - self.assertEqual(torch.matrix_rank(a, True).item(), 9) - - a = torch.randn(24, 42, device=device) - self.assertEqual(torch.matrix_rank(a), torch.matrix_rank(a.t())) - aaT = torch.mm(a, a.t()) - self.assertEqual(torch.matrix_rank(aaT), torch.matrix_rank(aaT, True)) - aTa = torch.mm(a.t(), a) - self.assertEqual(torch.matrix_rank(aTa), torch.matrix_rank(aTa, True)) - - if TEST_NUMPY: - from numpy.linalg import matrix_rank - a = torch.randn(35, 75, device=device) - self.assertEqual(torch.matrix_rank(a).item(), matrix_rank(a.cpu().numpy())) - self.assertEqual(torch.matrix_rank(a, 0.01).item(), matrix_rank(a.cpu().numpy(), 0.01)) - - aaT = torch.mm(a, a.t()) - self.assertEqual(torch.matrix_rank(aaT).item(), matrix_rank(aaT.cpu().numpy())) - self.assertEqual(torch.matrix_rank(aaT, 0.01).item(), matrix_rank(aaT.cpu().numpy(), 0.01)) - - if np.lib.NumpyVersion(np.__version__) >= '1.14.0': - self.assertEqual(torch.matrix_rank(aaT, True).item(), matrix_rank(aaT.cpu().numpy(), True)) - self.assertEqual(torch.matrix_rank(aaT, 0.01, True).item(), - matrix_rank(aaT.cpu().numpy(), 0.01, True)) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_matrix_power(self, device, dtype): - def run_test(M, sign=1): - if sign == -1: - M = M.inverse() - MP2 = torch.matrix_power(M, 2) - self.assertEqual(MP2, torch.matmul(M, M)) - - MP3 = torch.matrix_power(M, 3) - self.assertEqual(MP3, torch.matmul(MP2, M)) - - MP4 = torch.matrix_power(M, 4) - self.assertEqual(MP4, torch.matmul(MP2, MP2)) - - MP6 = torch.matrix_power(M, 6) - self.assertEqual(MP6, torch.matmul(MP3, MP3)) - - MP0 = torch.matrix_power(M, 0) - self.assertEqual(MP0, torch.eye(M.size(-2), dtype=dtype).expand_as(M)) - - # Single matrix - M = torch.randn(5, 5, dtype=dtype, device=device) - run_test(M) - - # Batch matrices - M = torch.randn(3, 3, 3, dtype=dtype, device=device) - run_test(M) - - # Many batch matrices - M = torch.randn(2, 3, 3, 3, dtype=dtype, device=device) - run_test(M) - - # This is for negative powers - from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value - M = random_fullrank_matrix_distinct_singular_value(5, dtype=dtype, device=device) - run_test(M, sign=-1) - - M = random_fullrank_matrix_distinct_singular_value(3, 3, dtype=dtype, device=device) - run_test(M, sign=-1) - - M = random_fullrank_matrix_distinct_singular_value(3, 2, 3, dtype=dtype, device=device) - run_test(M, sign=-1) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.float, torch.complex64) - def test_matrix_exp_utils(self, device, dtype): - # test linear combination - def run_test(coeff_shape, data_shape): - coeffs = torch.rand(*coeff_shape, device=device, dtype=torch.float) - x = torch.rand(coeff_shape[1], *data_shape, device=device, dtype=dtype) - - res1 = torch._compute_linear_combination(x, coeffs) - res2 = (x.unsqueeze(0) * coeffs.view(*coeff_shape, *([1] * len(data_shape)))).sum(1) - self.assertEqual(res1, res2, atol=1e-5, rtol=0.0) - - # check `out=` version - res3 = torch.zeros(coeff_shape[0], *data_shape, device=device, dtype=dtype) - torch._compute_linear_combination(x, coeffs, out=res3) - self.assertEqual(res1, res3, atol=1e-5, rtol=0.0) - - res4 = torch.ones(coeff_shape[0], *data_shape, device=device, dtype=dtype) - torch._compute_linear_combination(x, coeffs, out=res4) - self.assertEqual(res1, res4 - 1.0, atol=1e-5, rtol=0.0) - - res5 = torch.ones(coeff_shape[0], *data_shape, device=device, dtype=dtype) - res5_clone = res5.clone() - torch._compute_linear_combination(x, coeffs, out=res5) - self.assertEqual(res1, res5 - res5_clone, atol=1e-5, rtol=0.0) - - run_test([1, 3], [2, 2]) - run_test([3, 1], [2, 2]) - run_test([1, 10], [10, 10]) - run_test([10, 1], [10, 10]) - run_test([5, 3], [2, 2]) - run_test([5, 3], [100, 100]) - run_test([3, 4], [3, 3, 3]) - run_test([3, 4], [3, 3, 3, 3]) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.float, torch.double, torch.complex64, torch.complex128) - def test_matrix_exp_boundary_cases(self, device, dtype): - - with self.assertRaisesRegex(RuntimeError, "expected a tensor of floating or complex types"): - torch.randn(3, 3).type(torch.int).matrix_exp() - - with self.assertRaisesRegex(RuntimeError, "with dim at least 2"): - torch.randn(3).matrix_exp() - - with self.assertRaisesRegex(RuntimeError, "expected a tensor of squared matrices"): - torch.randn(3, 2, 1).matrix_exp() - - # check 1x1 matrices - x = torch.randn(3, 3, 1, 1) - mexp = x.matrix_exp() - self.assertEqual(mexp, x.exp()) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.float, torch.double) - # Although tf32 is always disabled on matrix_exp, this test uses matmul, - # which has tf32 on by default - @with_tf32_off - def test_matrix_exp_analytic(self, device, dtype): - # check zero matrix - x = torch.zeros(20, 20, dtype=dtype, device=device) - self.assertTrue((x.matrix_exp() == torch.eye(20, 20, dtype=dtype, device=device)).all().item()) - - def normalize_to_1_operator_norm(sample, desired_norm): - sample_norm, _ = sample.abs().sum(-2).max(-1) - sample_to_1_norm = sample / sample_norm.unsqueeze(-1).unsqueeze(-1) - return sample_to_1_norm * desired_norm - - def gen_good_cond_number_matrices(*n): - """ - Generates a diagonally-domimant matrix - with the eigenvalues centered at 1 - and the radii at most (n[-1] - 1) / (n[-2] ** 2) - """ - identity = torch.eye(n[-2], n[-1], dtype=dtype, device=device).expand(*n) - x = torch.rand(*n, dtype=dtype, device=device) / (n[-1] ** 2) - x = (x - x * identity) + identity - return x - - def run_test(*n): - if dtype == torch.float: - thetas = [ - 1.192092800768788e-07, # deg 1 - 5.978858893805233e-04, # deg 2 - 5.116619363445086e-02, # deg 4 - 5.800524627688768e-01, # deg 8 - 1.461661507209034e+00, # deg 12 - 3.010066362817634e+00 # deg 18 - ] - else: # if torch.double - thetas = [ - 2.220446049250313e-16, # deg 1 - 2.580956802971767e-08, # deg 2 - 3.397168839976962e-04, # deg 4 - 4.991228871115323e-02, # deg 8 - 2.996158913811580e-01, # deg 12 - 1.090863719290036e+00 # deg 18 - ] - - # generate input - q = gen_good_cond_number_matrices(*n) - qinv = torch.inverse(q) - d = torch.randn(n[:-1], dtype=dtype, device=device) - x = torch.matmul(q, torch.matmul(torch.diag_embed(d), qinv)) - x_norm, _ = x.abs().sum(-2).max(-1) - - # test simple analytic whatever norm generated - mexp = x.matrix_exp() - mexp_analytic = torch.matmul( - q, - torch.matmul( - torch.diag_embed(d.exp()), - qinv - ) - ) - self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0) - - # generate norms to test different degree expansions - sample_norms = [] - for i in range(len(thetas) - 1): - sample_norms.append(0.5 * (thetas[i] + thetas[i + 1])) - sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2] - - # matrices to equal norm - for sample_norm in sample_norms: - x_normalized = normalize_to_1_operator_norm(x, sample_norm) - - mexp = x_normalized.matrix_exp() - mexp_analytic = torch.matmul( - q, - torch.matmul( - torch.diag_embed((d / x_norm.unsqueeze(-1) * sample_norm).exp()), - qinv - ) - ) - self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0) - - # single matrix - run_test(2, 2) - run_test(3, 3) - run_test(4, 4) - run_test(5, 5) - run_test(100, 100) - run_test(200, 200) - - # small batch of matrices - run_test(3, 2, 2) - run_test(3, 3, 3) - run_test(3, 4, 4) - run_test(3, 5, 5) - run_test(3, 100, 100) - run_test(3, 200, 200) - - # large batch of matrices - run_test(3, 3, 2, 2) - run_test(3, 3, 3, 3) - run_test(3, 3, 4, 4) - run_test(3, 3, 5, 5) - run_test(3, 3, 100, 100) - run_test(3, 3, 200, 200) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.float, torch.double) - def test_matrix_exp_batch(self, device, dtype): - - def run_test(*n): - tensors_batch = torch.zeros(n, dtype=dtype, device=device) - tensors_batch = tensors_batch.view(-1, n[-2], n[-1]) - - num_matrices = tensors_batch.size(0) - tensors_list = [] - for i in range(num_matrices): - tensors_list.append(torch.randn(n[-2], n[-1], dtype=dtype, device=device)) - - for i in range(num_matrices): - tensors_batch[i, ...] = tensors_list[i] - - tensors_exp_map = map(lambda x: x.matrix_exp(), tensors_list) - tensors_exp_batch = tensors_batch.matrix_exp() - - for i, tensor_exp in enumerate(tensors_exp_map): - self.assertEqual(tensors_exp_batch[i, ...], tensor_exp) - - # small batch of matrices - run_test(3, 2, 2) - run_test(3, 3, 3) - run_test(3, 4, 4) - run_test(3, 5, 5) - - # large batch of matrices - run_test(3, 3, 2, 2) - run_test(3, 3, 3, 3) - run_test(3, 3, 4, 4) - run_test(3, 3, 5, 5) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.float, torch.double) - # Although tf32 is always disabled on matrix_exp, this test uses matmul, - # which has tf32 on by default - @with_tf32_off - def test_matrix_exp_compare_with_taylor(self, device, dtype): - - def normalize_to_1_operator_norm(sample, desired_norm): - sample_norm, _ = sample.abs().sum(-2).max(-1) - sample_to_1_norm = sample / sample_norm.unsqueeze(-1).unsqueeze(-1) - return sample_to_1_norm * desired_norm - - def gen_good_cond_number_matrices(*n): - """ - Generates a diagonally-domimant matrix - with the eigenvalues centered at 1 - and the radii at most (n[-1] - 1) / (n[-2] ** 2) - """ - identity = torch.eye(n[-2], n[-1], dtype=dtype, device=device).expand(*n) - x = torch.rand(*n, dtype=dtype, device=device) / (n[-1] ** 2) - x = (x - x * identity) + identity - return x - - def get_taylor_approximation(a, deg): - identity = torch.eye(a.size(-2), a.size(-1), dtype=dtype, device=device).expand_as(a) - res = identity - taylor_term = identity - - for i in range(1, deg + 1): - taylor_term = torch.matmul(a, taylor_term) / i - res = res + taylor_term - - return res - - def scale_square(a, deg): - if a.norm() < 1.0: - return get_taylor_approximation(a, 12) - else: - s = int(torch.log2(a.norm()).ceil().item()) - b = a / (2 ** s) - b = get_taylor_approximation(b, 18) - for _ in range(s): - b = torch.matmul(b, b) - return b - - def run_test(*n): - degs = [1, 2, 4, 8, 12, 18] - if dtype == torch.float: - thetas = [ - 1.192092800768788e-07, # deg 1 - 5.978858893805233e-04, # deg 2 - 5.116619363445086e-02, # deg 4 - 5.800524627688768e-01, # deg 8 - 1.461661507209034e+00, # deg 12 - 3.010066362817634e+00 # deg 18 - ] - else: # if torch.double - thetas = [ - 2.220446049250313e-16, # deg 1 - 2.580956802971767e-08, # deg 2 - 3.397168839976962e-04, # deg 4 - 4.991228871115323e-02, # deg 8 - 2.996158913811580e-01, # deg 12 - 1.090863719290036e+00 # deg 18 - ] - - # generate norms to test different degree expansions - sample_norms = [] - for i in range(len(thetas) - 1): - sample_norms.append(0.5 * (thetas[i] + thetas[i + 1])) - sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2] - degs = [degs[0]] + degs - - for sample_norm, deg in zip(sample_norms, degs): - x = gen_good_cond_number_matrices(*n) - x = normalize_to_1_operator_norm(x, sample_norm) - - mexp = x.matrix_exp() - mexp_taylor = scale_square(x, deg) - - self.assertEqual(mexp, mexp_taylor, atol=1e-2, rtol=0.0) - - # single matrix - run_test(2, 2) - run_test(3, 3) - run_test(4, 4) - run_test(5, 5) - - # small batch of matrices - run_test(3, 2, 2) - run_test(3, 3, 3) - run_test(3, 4, 4) - run_test(3, 5, 5) - - # large batch of matrices - run_test(3, 3, 2, 2) - run_test(3, 3, 3, 3) - run_test(3, 3, 4, 4) - run_test(3, 3, 5, 5) - - @dtypes(torch.double) - def test_chain_matmul(self, device, dtype): - def product(matrices): - for mat in matrices[1:]: - matrices[0] = matrices[0].mm(mat) - return matrices[0] - - def run_test(p): - matrices = [] - for (pi, pi_1) in zip(p[:-1], p[1:]): - matrices.append(torch.randn(pi, pi_1, dtype=dtype, device=device)) - self.assertEqual(torch.chain_matmul(*matrices), product(matrices)) - - run_test([10, 20, 30, 5]) - run_test([15, 5, 10, 20, 25]) - - with self.assertRaisesRegex(RuntimeError, "chain_matmul: Expected one or more matrices"): - torch.chain_matmul() - - @slowTest - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_det_logdet_slogdet(self, device, dtype): - def reference_slogdet(M): - if TEST_NUMPY: - sdet, logabsdet = np.linalg.slogdet(M.detach().cpu().numpy()) - return M.new_tensor(sdet), M.new_tensor(logabsdet) - else: - # naive row reduction - M = M.clone() - l = M.size(0) - multiplier = 1 - for i in range(l): - if M[i, 0].item() != 0: - if i != 0: - M[0], M[i] = M[i], M[0] - multiplier = -1 - break - else: - return 0 - for i in range(1, l): - row = M[i] - for j in range(i): - row -= row[j] / M[j, j] * M[j] - M[i] = row - sdet = M.diag().sign().prod() - logabsdet = M.diag().abs_().log_().sum().add_(math.log(multiplier)) - return sdet, logabsdet - - def test_single_det(M, target, desc): - target_sdet, target_logabsdet = target - - det = M.det() - logdet = M.logdet() - sdet, logabsdet = M.slogdet() - - # Test det - self.assertEqual(det, target_sdet * target_logabsdet.exp(), - atol=1e-7, rtol=0, msg='{} (det)'.format(desc)) - - # Test slogdet - # Compare the overall value rather than individual parts because of - # precision issues when det is near zero. - self.assertEqual(sdet * logabsdet.exp(), target_sdet * target_logabsdet.exp(), - atol=1e-7, rtol=0, msg='{} (slogdet)'.format(desc)) - - # Test logdet - # Compare logdet against our own pytorch slogdet because they should - # be consistent, while it may behave slightly differently with other - # slogdet implementations when det is near zero due to precision - # issues. - if sdet.item() < 0: - self.assertTrue(logdet.item() != logdet.item(), '{} (logdet negative case)'.format(desc)) - else: - self.assertEqual(logdet.exp(), target_logabsdet.exp(), - atol=1e-7, rtol=0, msg='{} (logdet non-negative case)'.format(desc)) - - eye = torch.eye(5, dtype=dtype, device=device) - test_single_det(eye, (torch.ones((), dtype=dtype, device=device), torch.zeros((), dtype=dtype, device=device)), 'identity') - # Testing bug in #34061 (https://github.com/pytorch/pytorch/issues/34061) - for n in range(250, 551, 100): - mat = torch.randn(n, n, dtype=dtype, device=device) - q, _ = torch.qr(mat) - ref_det, ref_logabsdet = reference_slogdet(q) - test_single_det(q, (ref_det, ref_logabsdet), 'orthogonal') - - def test(M): - assert M.size(0) >= 5, 'this helper fn assumes M to be at least 5x5' - M = M.to(device) - - ref_M_sdet, ref_M_logabsdet = reference_slogdet(M) - - test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'basic') - if ref_M_logabsdet.exp().item() >= 1e-6: # skip singular - M_inv = M.inverse() - test_single_det(M_inv, reference_slogdet(M_inv), 'inverse') - - test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'transpose') - - for x in [0, 2, 4]: - for scale in [-2, -0.1, 0, 10]: - if scale > 0: - target = ref_M_sdet, ref_M_logabsdet + math.log(scale) - elif scale == 0: - target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf) - else: - target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-scale) - - # dim 0 - M_clone = M.clone() - M_clone[:, x] *= scale - test_single_det(M_clone, target, 'scale a row') - # dim 1 - M_clone = M.clone() - M_clone[x, :] *= scale - test_single_det(M_clone, target, 'scale a column') - - for x1, x2 in [(0, 3), (4, 1), (3, 2)]: - assert x1 != x2, 'x1 and x2 needs to be different for this test' - target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf) - # dim 0 - M_clone = M.clone() - M_clone[:, x2] = M_clone[:, x1] - test_single_det(M_clone, target, 'two rows are same') - # dim 1 - M_clone = M.clone() - M_clone[x2, :] = M_clone[x1, :] - test_single_det(M_clone, target, 'two columns are same') - - for scale1, scale2 in [(0.3, -1), (0, 2), (10, 0.1)]: - det_scale = scale1 * scale2 * -1 - if det_scale > 0: - target = ref_M_sdet, ref_M_logabsdet + math.log(det_scale) - elif det_scale == 0: - target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf) - else: - target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-det_scale) - - # dim 0 - M_clone = M.clone() - t = M_clone[:, x1] * scale1 - M_clone[:, x1] += M_clone[:, x2] * scale2 - M_clone[:, x2] = t - test_single_det(M_clone, target, 'exchanging rows') - # dim 1 - M_clone = M.clone() - t = M_clone[x1, :] * scale1 - M_clone[x1, :] += M_clone[x2, :] * scale2 - M_clone[x2, :] = t - test_single_det(M_clone, target, 'exchanging columns') - - def get_random_mat_scale(n): - # For matrices with values i.i.d. with 0 mean, unit variance, and - # subexponential tail, we have: - # E[log det(A^2)] \approx log((n-1)!) - # - # Notice: - # log Var[det(A)] = log E[det(A^2)] >= E[log det(A^2)] - # - # So: - # stddev[det(A)] >= sqrt( (n-1)! ) - # - # We use this as an intuitive guideline to scale random generated - # matrices so our closeness tests can work more robustly: - # scale by sqrt( (n-1)! )^(-1/n) = ( (n-1)! )^(-1/(2n)) - # - # source: https://arxiv.org/pdf/1112.0752.pdf - - # TODO: technically we need subexponential distn for this to hold, - # but we mostly use gaussian entries below. Consider switching - # to Chi-sq if this turns out not stable enough, since Chi-sq - # is easy enough to sample from. - return math.factorial(n - 1) ** (-1.0 / (2 * n)) - - for n in [5, 10, 25]: - scale = get_random_mat_scale(n) - test(torch.randn(n, n, dtype=dtype, device=device) * scale) - r = torch.randn(n, n, dtype=dtype, device=device) * scale - # symmetric psd - test(r.mm(r.t())) - # symmetric pd - r = torch.randn(n, n, dtype=dtype, device=device) * scale - test(r.mm(r.t()) + torch.eye(n, dtype=dtype, device=device) * 1e-6) - # symmetric - r = torch.randn(n, n, dtype=dtype, device=device) * scale - for i in range(n): - for j in range(i): - r[i, j] = r[j, i] - test(r) - # non-contiguous - test((torch.randn(n, n, n + 1, dtype=dtype, device=device) * scale)[:, 2, 1:]) - # det = 0 - r = torch.randn(n, n, dtype=dtype, device=device) * scale - u, s, v = r.svd() - if reference_slogdet(u)[0] < 0: - u = -u - if reference_slogdet(v)[0] < 0: - v = -v - s[0] *= -1 - s[-1] = 0 - test(u.mm(s.diag()).mm(v)) - - # Small values to test numerical stability. Note that we don't scale - # this matrix. - r = torch.randn(512, 512, dtype=dtype, device=device) - u, s, v = r.svd() - s.fill_(1. / (100 * s.numel())) - test(u.mm(s.diag()).mm(v)) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_det_logdet_slogdet_batched(self, device, dtype): - from torch.testing._internal.common_utils import (random_symmetric_matrix, random_symmetric_psd_matrix, - random_symmetric_pd_matrix, random_square_matrix_of_rank) - - # mat_chars denotes matrix characteristics - # possible values are: sym, sym_psd, sym_pd, sing, non_sym - def run_test(matsize, batchdims, mat_chars): - num_matrices = reduce(lambda x, y: x * y, batchdims, 1) - list_of_matrices = [] - - for idx in range(num_matrices): - mat_type = idx % len(mat_chars) - if mat_chars[mat_type] == 'sym': - list_of_matrices.append(random_symmetric_matrix(matsize, dtype=dtype, device=device)) - elif mat_chars[mat_type] == 'sym_psd': - list_of_matrices.append(random_symmetric_psd_matrix(matsize, dtype=dtype, device=device)) - elif mat_chars[mat_type] == 'sym_pd': - list_of_matrices.append(random_symmetric_pd_matrix(matsize, dtype=dtype, device=device)) - elif mat_chars[mat_type] == 'sing': - list_of_matrices.append(torch.ones(matsize, matsize, dtype=dtype, device=device)) - elif mat_chars[mat_type] == 'non_sing': - list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize, dtype=dtype, device=device)) - full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize)) - # Scaling adapted from `get_random_mat_scale` in _test_det_logdet_slogdet - full_tensor *= (math.factorial(matsize - 1) ** (-1.0 / (2 * matsize))) - - for fn in [torch.det, torch.logdet, torch.slogdet]: - expected_value = [] - actual_value = fn(full_tensor) - for full_idx in product(*map(lambda x: list(range(x)), batchdims)): - expected_value.append(fn(full_tensor[full_idx])) - - if fn == torch.slogdet: - sign_value = torch.stack([tup[0] for tup in expected_value], dim=0).reshape(batchdims) - expected_value = torch.stack([tup[1] for tup in expected_value], dim=0).reshape(batchdims) - self.assertEqual(sign_value, actual_value[0]) - self.assertEqual(expected_value, actual_value[1]) - else: - expected_value = torch.stack(expected_value, dim=0).reshape(batchdims) - self.assertEqual(actual_value, expected_value) - - for matsize, batchdims in product([3, 5], [(3,), (5, 3)]): - run_test(matsize, batchdims, mat_chars=['sym_pd']) - run_test(matsize, batchdims, mat_chars=['sing']) - run_test(matsize, batchdims, mat_chars=['non_sing']) - run_test(matsize, batchdims, mat_chars=['sym', 'sym_pd', 'sym_psd']) - run_test(matsize, batchdims, mat_chars=['sing', 'non_sing']) - - def solve_test_helper(self, A_dims, b_dims, device, dtype): - from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value - - b = torch.randn(*b_dims, dtype=dtype, device=device) - A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype, device=device) - return b, A - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_solve(self, device, dtype): - for (k, n) in zip([2, 3, 5], [3, 5, 7]): - b, A = self.solve_test_helper((n,), (n, k), device, dtype) - x = torch.solve(b, A)[0] - self.assertLessEqual(b.dist(A.mm(x)), 1e-12) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_solve_batched(self, device, dtype): - def solve_batch_helper(A_dims, b_dims): - b, A = self.solve_test_helper(A_dims, b_dims, device, dtype) - x_exp_list = [] - for i in range(b_dims[0]): - x_exp_list.append(torch.solve(b[i], A[i])[0]) - x_exp = torch.stack(x_exp_list) # Stacked output - x_act = torch.solve(b, A)[0] # Actual output - self.assertEqual(x_exp, x_act) # Equality check - self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 1e-12) # Correctness check - - for batchsize in [1, 3, 4]: - solve_batch_helper((5, batchsize), (batchsize, 5, 10)) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.double) - def test_solve_batched_non_contiguous(self, device, dtype): - from numpy.linalg import solve - from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value - A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype, - device=device).permute(1, 0, 2) - b = torch.randn(2, 2, 2, dtype=dtype, device=device).permute(2, 1, 0) - x, _ = torch.solve(b, A) - x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())).to(dtype=dtype, device=device) - self.assertEqual(x, x_exp) - - @slowTest - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_solve_batched_many_batches(self, device, dtype): - b, A = self.solve_test_helper((5, 256, 256), (5, 1), device, dtype) - x, _ = torch.solve(b, A) - self.assertEqual(torch.matmul(A, x), b.expand(A.shape[:-2] + (5, 1))) - - b, A = self.solve_test_helper((3,), (512, 512, 3, 1), device, dtype) - x, _ = torch.solve(b, A) - self.assertEqual(torch.matmul(A, x), b) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.double) - def test_solve_batched_broadcasting(self, device, dtype): - from numpy.linalg import solve - - def run_test(A_dims, b_dims): - A_matrix_size = A_dims[-1] - A_batch_dims = A_dims[:-2] - b, A = self.solve_test_helper((A_matrix_size,) + A_batch_dims, b_dims, device, dtype) - x, _ = torch.solve(b, A) - x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())).to(dtype=dtype, device=device) - self.assertEqual(x, x_exp) - - # test against numpy.linalg.solve - run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6)) # no broadcasting - run_test((2, 1, 3, 4, 4), (4, 6)) # broadcasting b - run_test((4, 4), (2, 1, 3, 4, 2)) # broadcasting A - run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5)) # broadcasting A & b - - def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype): - from torch.testing._internal.common_utils import random_symmetric_pd_matrix - - b = torch.randn(*b_dims, dtype=dtype, device=device) - A = random_symmetric_pd_matrix(*A_dims, dtype=dtype, device=device) - L = torch.cholesky(A, upper=upper) - return b, A, L - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_cholesky_solve(self, device, dtype): - for (k, n), upper in product(zip([2, 3, 5], [3, 5, 7]), [True, False]): - b, A, L = self.cholesky_solve_test_helper((n,), (n, k), upper, device, dtype) - x = torch.cholesky_solve(b, L, upper=upper) - self.assertLessEqual(b.dist(A.mm(x)), 1e-12) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_cholesky_solve_batched(self, device, dtype): - def cholesky_solve_batch_helper(A_dims, b_dims, upper): - b, A, L = self.cholesky_solve_test_helper(A_dims, b_dims, upper, device, dtype) - x_exp_list = [] - for i in range(b_dims[0]): - x_exp_list.append(torch.cholesky_solve(b[i], L[i], upper=upper)) - x_exp = torch.stack(x_exp_list) # Stacked output - x_act = torch.cholesky_solve(b, L, upper=upper) # Actual output - self.assertEqual(x_act, x_exp) # Equality check - self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 2e-12) # Correctness check - - for upper, batchsize in product([True, False], [1, 3, 4]): - cholesky_solve_batch_helper((5, batchsize), (batchsize, 5, 10), upper) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.double) - def test_cholesky_solve_batched_non_contiguous(self, device, dtype): - from numpy.linalg import solve - from torch.testing._internal.common_utils import random_symmetric_pd_matrix - - for upper in [True, False]: - A = random_symmetric_pd_matrix(2, 2, dtype=dtype, device='cpu') - b = torch.randn(2, 2, 2, dtype=dtype, device='cpu') - x_exp = torch.Tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy())).to(dtype=dtype, device=device) - A = A.to(device).permute(0, 2, 1) - b = b.to(device).permute(2, 1, 0) - assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs" - L = torch.cholesky(A, upper) - x = torch.cholesky_solve(b, L, upper=upper) - self.assertEqual(x, x_exp) - - @slowTest - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_cholesky_solve_batched_many_batches(self, device, dtype): - for upper in [True, False]: - b, A, L = self.cholesky_solve_test_helper((5, 256, 256), (5, 10), upper, device, dtype) - x = torch.cholesky_solve(b, L, upper) - self.assertEqual(torch.matmul(A, x), b.expand(A.shape[:-2] + (5, 10))) - - b, A, L = self.cholesky_solve_test_helper((5,), (512, 512, 5, 10), upper, device, dtype) - x = torch.cholesky_solve(b, L, upper) - self.assertEqual(torch.matmul(A, x), b) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.double) - def test_cholesky_solve_batched_broadcasting(self, device, dtype): - from numpy.linalg import solve - from torch.testing._internal.common_utils import random_symmetric_pd_matrix - - def run_test(A_dims, b_dims, upper): - A_matrix_size = A_dims[-1] - A_batch_dims = A_dims[:-2] - A = random_symmetric_pd_matrix(A_matrix_size, *A_batch_dims, - dtype=dtype, device='cpu') - b = torch.randn(*b_dims, dtype=dtype, device='cpu') - x_exp = torch.tensor(solve(A.numpy(), b.numpy()), dtype=dtype, device=device) - A, b = A.to(dtype=dtype, device=device), b.to(dtype=dtype, device=device) - L = torch.cholesky(A, upper) - x = torch.cholesky_solve(b, L, upper=upper) - self.assertEqual(x, x_exp) - # issue gh-42695 - x = torch.cholesky_solve(b, L, upper=upper, out=x) - self.assertEqual(x, x_exp) - - # test against numpy.linalg.solve - for upper in [True, False]: - run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), upper) # no broadcasting - run_test((2, 1, 3, 4, 4), (4, 6), upper) # broadcasting b - run_test((4, 4), (2, 1, 3, 4, 2), upper) # broadcasting A - run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), upper) # broadcasting A & b - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_cholesky_inverse(self, device, dtype): - from torch.testing._internal.common_utils import random_symmetric_pd_matrix - a = random_symmetric_pd_matrix(5, dtype=dtype, device=device) - - # compute inverse directly - inv0 = torch.inverse(a) - - # default case - chol = torch.cholesky(a) - inv1 = torch.cholesky_inverse(chol, False) - self.assertLessEqual(inv0.dist(inv1), 1e-12) - - # upper Triangular Test - chol = torch.cholesky(a, True) - inv1 = torch.cholesky_inverse(chol, True) - self.assertLessEqual(inv0.dist(inv1), 1e-12) - - # lower Triangular Test - chol = torch.cholesky(a, False) - inv1 = torch.cholesky_inverse(chol, False) - self.assertLessEqual(inv0.dist(inv1), 1e-12) - - @slowTest - @skipCUDAIf(True, "See issue #26789.") - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_cholesky_batched_many_batches(self, device, dtype): - from torch.testing._internal.common_utils import random_symmetric_pd_matrix - - def cholesky_test_helper(n, batchsize, device, upper): - A = random_symmetric_pd_matrix(n, batchsize, dtype=dtype, device=device) - chol_fact = torch.cholesky(A, upper=upper) - if upper: - # Correctness check - self.assertEqual(A, chol_fact.transpose(-2, -1).matmul(chol_fact)) - # Upper triangular check - self.assertEqual(chol_fact, chol_fact.triu()) - else: - # Correctness check - self.assertEqual(A, chol_fact.matmul(chol_fact.transpose(-2, -1))) - # Lower triangular check - self.assertEqual(chol_fact, chol_fact.tril()) - - for upper, batchsize in product([True, False], [262144, 524288]): - cholesky_test_helper(2, batchsize, device, upper) - - @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) - def test_cholesky_batched(self, device, dtype): - from torch.testing._internal.common_utils import \ - (random_symmetric_pd_matrix, - random_fullrank_matrix_distinct_singular_value) - - def cholesky_test_helper(n, batch_dims, upper): - # This is a workaround while there is no support for complex random_symmetric_pd_matrix - if dtype.is_complex: - real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 - A_real = random_fullrank_matrix_distinct_singular_value(n, *batch_dims, dtype=real_dtype, device=device) - A_imag = random_fullrank_matrix_distinct_singular_value(n, *batch_dims, dtype=real_dtype, device=device) - A = A_real + 1j * A_imag - # There is no support for complex batched matmul yet - matmul_list = [] - for mat in A.contiguous().view(-1, n, n): - matmul_list.append(mat @ mat.t().conj()) - A = torch.stack(matmul_list).view(*batch_dims, n, n) - else: - A = random_symmetric_pd_matrix(n, *batch_dims, dtype=dtype, device=device) - cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)]) - cholesky_exp = cholesky_exp.reshape_as(A) - self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper)) - - for upper, batchsize in product([True, False], [(3,), (3, 4), (2, 3, 4)]): - cholesky_test_helper(3, batchsize, upper) - - @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) - def test_cholesky(self, device, dtype): - from torch.testing._internal.common_utils import \ - (random_symmetric_pd_matrix, - random_fullrank_matrix_distinct_singular_value) - - # This is a workaround while there is no support for complex random_symmetric_pd_matrix - if dtype.is_complex: - real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 - A_real = random_fullrank_matrix_distinct_singular_value(10, dtype=real_dtype, device=device) - A_imag = random_fullrank_matrix_distinct_singular_value(10, dtype=real_dtype, device=device) - A = A_real + 1j * A_imag - A = A @ A.t().conj() - else: - A = random_symmetric_pd_matrix(10, dtype=dtype, device=device) - - # default Case - C = torch.cholesky(A) - B = torch.mm(C, C.t().conj()) - self.assertEqual(A, B, atol=1e-14, rtol=0) - - # test Upper Triangular - U = torch.cholesky(A, True) - B = torch.mm(U.t().conj(), U) - self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (upper) did not allow rebuilding the original matrix') - - # test Lower Triangular - L = torch.cholesky(A, False) - B = torch.mm(L, L.t().conj()) - self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (lower) did not allow rebuilding the original matrix') - - def test_view(self, device): - tensor = torch.rand(15, device=device) - template = torch.rand(3, 5, device=device) - empty = torch.empty(0, device=device) - target = template.size() - self.assertEqual(tensor.view_as(template).size(), target) - self.assertEqual(tensor.view(3, 5).size(), target) - self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target) - self.assertEqual(tensor.view(-1, 5).size(), target) - self.assertEqual(tensor.view(3, -1).size(), target) - tensor_view = tensor.view(5, 3) - tensor_view.fill_(random.uniform(0, 1)) - self.assertEqual(empty.view_as(empty), empty) - self.assertEqual(empty.view(0), empty) - self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1])) - self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty) - - # test size inference with empty tensors - self.assertEqual(empty.view(-1).size(), torch.Size([0])) - self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0])) - - with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"): - empty.view(-1, 0) - - with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"): - empty.view(3, 0, -1, 0) - - self.assertRaises(RuntimeError, lambda: tensor.view(15, 0)) - self.assertRaises(RuntimeError, lambda: tensor.view(7, -1)) - self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1)) - - # test view when tensor is not contiguous in every dimension, but only - # contiguous dimensions are touched. - tensor = torch.rand(4, 2, 5, 1, 6, 2, 9, 3, device=device).transpose(-1, 2).transpose(-2, 3) - # size: [ 4, 2, 3, 9, 6, 2, 1, 5] - # stride: [3840, 1620, 1, 3, 54, 27, 324, 324] - # contiguous dim chunks: [__________, ____, ____, __________, ____, ____] - # merging 1 to chunk after: [__________, ____, ____, __________, __________] - contig_tensor = tensor.clone() - # [4, 2] => [8, 1] - # [3] => [3] - # [9] => [3, 3] - # [6, 2] => [4, 1, 3] - # [1, 5] => [5] - view_size = [8, 1, 3, 3, 3, 4, 1, 3, 5] - self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) - # [4, 2] => [2, 4] - # [3] => [3] - # [9] => [1, 9] - # [6, 2] => [2, 2, 3] - # [1, 5] => [5, 1] - view_size = [2, 4, 3, 1, 9, 2, 2, 3, 5, 1] - self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) - # adding size 1 dims - view_size = [1, 1, 2, 1, 4, 3, 1, 1, 9, 1, 2, 1, 2, 3, 1, 5, 1, 1] - self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) - - # invalid views - self.assertRaises(RuntimeError, lambda: tensor.view(-1)) - # crossing [4, 2], [3] - self.assertRaises(RuntimeError, lambda: tensor.view(24, 9, 6, 2, 1, 5)) - # crossing [6, 2], [1, 5] - self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 9, 6, 10)) - # crossing [9], [6, 2] - self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 54, 2, 1, 5)) - - # view with stride 0 dims - tensor = torch.empty(1, 1, device=device).expand(3, 4) # all dims are contiguous - contig_tensor = tensor.clone() - self.assertEqual(tensor.view(-1), contig_tensor.view(-1)) - self.assertEqual(tensor.view(1, -1, 1), contig_tensor.view(1, -1, 1)) - self.assertEqual(tensor.view(-1, 1), contig_tensor.view(-1, 1)) - self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1)) - self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1)) - - def test_flip(self, device): - data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2) - - self.assertEqual(torch.tensor([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2), data.flip(0)) - self.assertEqual(torch.tensor([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2), data.flip(1)) - self.assertEqual(torch.tensor([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2), data.flip(2)) - self.assertEqual(torch.tensor([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2), data.flip(0, 1)) - self.assertEqual(torch.tensor([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2), data.flip(0, 1, 2)) - - # check for wrap dim - self.assertEqual(torch.tensor([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2), data.flip(-1)) - # check for permute - self.assertEqual(torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(0, 2)) - self.assertEqual(torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(2, 0)) - - # not allow flip on the same dim more than once - self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1)) - # not allow empty list as input - self.assertRaises(TypeError, lambda: data.flip()) - - # not allow size of flip dim > total dims - self.assertRaises(IndexError, lambda: data.flip(0, 1, 2, 3)) - # not allow dim > max dim - self.assertRaises(IndexError, lambda: data.flip(3)) - - # test for non-contiguous case - expanded_data = torch.arange(1, 4, device=device).view(3, 1).expand(3, 2) - transposed_data = torch.arange(1, 9, device=device).view(2, 2, 2).transpose(0, 1) - self.assertEqual(torch.tensor([3, 3, 2, 2, 1, 1]).view(3, 2), expanded_data.flip(0)) - self.assertEqual(torch.tensor([8, 7, 4, 3, 6, 5, 2, 1]).view(2, 2, 2), transposed_data.flip(0, 1, 2)) - - # test for shape - data = torch.randn(2, 3, 4, device=device) - size = [2, 3, 4] - test_dims = [] - for i in range(1, 3): - test_dims += combinations(range(len(size)), i) - - for ds in test_dims: - self.assertEqual(size, list(data.flip(ds).size())) - - # test rectangular case - data = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3).to(device) - flip0_result = torch.tensor([[4, 5, 6], [1, 2, 3]]).to(device) - flip1_result = torch.tensor([[3, 2, 1], [6, 5, 4]]).to(device) - - self.assertEqual(flip0_result, data.flip(0)) - self.assertEqual(flip1_result, data.flip(1)) - - # test empty tensor, should just return an empty tensor of the same shape - data = torch.tensor([]) - self.assertEqual(data, data.flip(0)) - - # test bool tensor - a = torch.tensor([False, True]) - self.assertEqual(a.flip(0), torch.tensor([True, False])) - - def _rand_shape(self, dim, min_size, max_size): - shape = [] - for i in range(dim): - shape.append(random.randint(min_size, max_size)) - return tuple(shape) - - @dtypes(torch.cfloat, torch.cdouble) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_complex_flip(self, device, dtype): - rand_dim = random.randint(3, 4) - shape = self._rand_shape(rand_dim, 5, 10) - - # Axis to sample for given shape. - for i in range(1, rand_dim): - # Check all combinations of `i` axis. - for flip_dim in combinations(range(rand_dim), i): - data = torch.randn(*shape, device=device, dtype=dtype) - torch_fn = partial(torch.flip, dims=flip_dim) - np_fn = partial(np.flip, axis=flip_dim) - self.compare_with_numpy(torch_fn, np_fn, data) - - def _test_fliplr_flipud(self, torch_fn, np_fn, min_dim, max_dim, device, dtype): - for dim in range(min_dim, max_dim + 1): - shape = self._rand_shape(dim, 5, 10) - # Randomly scale the input - if dtype.is_floating_point or dtype.is_complex: - data = torch.randn(*shape, device=device, dtype=dtype) - else: - data = torch.randint(0, 10, shape, device=device, dtype=dtype) - self.compare_with_numpy(torch_fn, np_fn, data) - - @dtypes(torch.int64, torch.double, torch.cdouble) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_fliplr(self, device, dtype): - self._test_fliplr_flipud(torch.fliplr, np.fliplr, 2, 4, device, dtype) - - @dtypes(torch.int64, torch.double, torch.cdouble) - def test_fliplr_invalid(self, device, dtype): - x = torch.randn(42).to(dtype) - with self.assertRaisesRegex(RuntimeError, "Input must be >= 2-d."): - torch.fliplr(x) - with self.assertRaisesRegex(RuntimeError, "Input must be >= 2-d."): - torch.fliplr(torch.tensor(42, device=device, dtype=dtype)) - - @dtypes(torch.int64, torch.double, torch.cdouble) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_flipud(self, device, dtype): - self._test_fliplr_flipud(torch.flipud, np.flipud, 1, 4, device, dtype) - - @dtypes(torch.int64, torch.double, torch.cdouble) - def test_flipud_invalid(self, device, dtype): - with self.assertRaisesRegex(RuntimeError, "Input must be >= 1-d."): - torch.flipud(torch.tensor(42, device=device, dtype=dtype)) - - def test_rot90(self, device): - data = torch.arange(1, 5, device=device).view(2, 2) - self.assertEqual(torch.tensor([1, 2, 3, 4]).view(2, 2), data.rot90(0, [0, 1])) - self.assertEqual(torch.tensor([2, 4, 1, 3]).view(2, 2), data.rot90(1, [0, 1])) - self.assertEqual(torch.tensor([4, 3, 2, 1]).view(2, 2), data.rot90(2, [0, 1])) - self.assertEqual(torch.tensor([3, 1, 4, 2]).view(2, 2), data.rot90(3, [0, 1])) - - # test for default args k=1, dims=[0, 1] - self.assertEqual(data.rot90(), data.rot90(1, [0, 1])) - - # test for reversed order of dims - self.assertEqual(data.rot90(3, [0, 1]), data.rot90(1, [1, 0])) - - # test for modulo of k - self.assertEqual(data.rot90(5, [0, 1]), data.rot90(1, [0, 1])) - self.assertEqual(data.rot90(3, [0, 1]), data.rot90(-1, [0, 1])) - self.assertEqual(data.rot90(-5, [0, 1]), data.rot90(-1, [0, 1])) - - # test for dims out-of-range error - self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, -3])) - self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 2])) - - # test tensor with more than 2D - data = torch.arange(1, 9, device=device).view(2, 2, 2) - self.assertEqual(torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2])) - self.assertEqual(data.rot90(1, [1, -1]), data.rot90(1, [1, 2])) - - # test for errors - self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 3])) - self.assertRaises(RuntimeError, lambda: data.rot90(1, [1, 1])) - self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 1, 2])) - self.assertRaises(RuntimeError, lambda: data.rot90(1, [0])) - - @dtypes(torch.cfloat, torch.cdouble) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_complex_rot90(self, device, dtype): - shape = self._rand_shape(random.randint(2, 4), 5, 10) - for rot_times in range(4): - data = torch.randn(*shape, device=device, dtype=dtype) - torch_fn = partial(torch.rot90, k=rot_times, dims=[0, 1]) - np_fn = partial(np.rot90, k=rot_times, axes=[0, 1]) - self.compare_with_numpy(torch_fn, np_fn, data) - - @onlyOnCPUAndCUDA - @unittest.skipIf(not TEST_SCIPY, "Scipy not found") - def test_signal_window_functions(self, device): - - def test(name, kwargs): - torch_method = getattr(torch, name + '_window') - for size in [0, 1, 2, 5, 10, 50, 100, 1024, 2048]: - for periodic in [True, False]: - res = torch_method(size, periodic=periodic, **kwargs, device=device) - # NB: scipy always returns a float32 result - ref = torch.from_numpy(signal.get_window((name, *(kwargs.values())), size, fftbins=periodic)) - self.assertEqual(res, ref, exact_dtype=False) - with self.assertRaisesRegex(RuntimeError, r'not implemented for sparse types'): - torch_method(3, layout=torch.sparse_coo) - with self.assertRaisesRegex(RuntimeError, r'floating point'): - torch_method(3, dtype=torch.long) - self.assertTrue(torch_method(3, requires_grad=True).requires_grad) - self.assertFalse(torch_method(3).requires_grad) - - for window in ['hann', 'hamming', 'bartlett', 'blackman']: - test(window, kwargs={}) - - for num_test in range(50): - test('kaiser', kwargs={'beta': random.random() * 30}) - def test_broadcast(self, device): # all functions @@ -8200,2791 +3485,104 @@ def _test_in_place_broadcastable(t0, t1, t2=None): _test_in_place_broadcastable(small2, small_expanded, large_expanded) _test_in_place_broadcastable(small2, small, large) - def test_broadcast_fused_matmul(self, device): - fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"] - - for fn in fns: - batch_dim = random.randint(1, 8) - n_dim = random.randint(1, 8) - m_dim = random.randint(1, 8) - p_dim = random.randint(1, 8) - - def dims_full_for_fn(): - if fn == "baddbmm": - return ([batch_dim, n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim]) - elif fn == "addbmm": - return ([n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim]) - elif fn == "addmm": - return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim]) - elif fn == "addmv": - return ([n_dim], [n_dim, m_dim], [m_dim]) - elif fn == "addr": - return ([n_dim, m_dim], [n_dim], [m_dim]) - else: - raise AssertionError("unknown function") - - (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn() - (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full) - - t0_small = torch.randn(*t0_dims_small, device=device).float() - t1 = torch.randn(*t1_dims, device=device).float() - t2 = torch.randn(*t2_dims, device=device).float() - - t0_full = t0_small.expand(*t0_dims_full).to(device) - - fntorch = getattr(torch, fn) - r0 = fntorch(t0_small, t1, t2) - r1 = fntorch(t0_full, t1, t2) - self.assertEqual(r0, r1) - - @tf32_on_and_off(0.001) - def test_broadcast_batched_matmul(self, device): - n_dim = random.randint(1, 8) - m_dim = random.randint(1, 8) - p_dim = random.randint(1, 8) - full_batch_dims = [random.randint(1, 3) for i in range(random.randint(1, 3))] - (batch_dims_small, _, _) = self._select_broadcastable_dims(full_batch_dims) - - def verify_batched_matmul(full_lhs, one_dimensional): - if not one_dimensional: - lhs_dims = [n_dim, m_dim] - rhs_dims = [m_dim, p_dim] - result_dims = [n_dim, p_dim] - else: - lhs_dims = [n_dim, m_dim] if full_lhs else [m_dim] - rhs_dims = [m_dim, p_dim] if not full_lhs else [m_dim] - result_dims = [n_dim] if full_lhs else [p_dim] - - lhs_mat_dims = lhs_dims if len(lhs_dims) != 1 else [1, m_dim] - rhs_mat_dims = rhs_dims if len(rhs_dims) != 1 else [m_dim, 1] - full_mat_dims = lhs_mat_dims if full_lhs else rhs_mat_dims - dim0_dims = rhs_dims if full_lhs else lhs_dims - small_dims = batch_dims_small + (rhs_mat_dims if full_lhs else lhs_mat_dims) - - small = torch.randn(*(small_dims), device=device).float() - dim0 = torch.randn(*(dim0_dims), device=device).float() - full = torch.randn(*(full_batch_dims + full_mat_dims), device=device).float() - if not one_dimensional: - (lhsTensors, rhsTensors) = ((full,), (small, dim0)) if full_lhs else ((small, dim0), (full,)) + # Ensures that kthvalue throws nondeterministic alerts in the correct cases + @dtypes(torch.double) + def test_kthvalue_nondeterministic_alert(self, device, dtype): + @expectedAlertNondeterministic('kthvalue CUDA', 'cuda') + def test_func(slf, device, call_type): + S = 10 + k = 5 + a = torch.randn(S, device=device) + if call_type == 'function': + torch.kthvalue(a, k) + elif call_type == 'method': + a.kthvalue(k) + elif call_type == 'out': + values = torch.empty_like(a) + indices = torch.empty((), device=device, dtype=torch.long) + torch.kthvalue(a, k, out=(values, indices)) else: - (lhsTensors, rhsTensors) = ((full,), (dim0,)) if full_lhs else ((dim0,), (full,)) + self.fail(f"'{call_type}' is not a valid call type") - def maybe_squeeze_result(l, r, result): - if len(lhs_dims) == 1 and l.dim() != 1: - return result.squeeze(-2) - elif len(rhs_dims) == 1 and r.dim() != 1: - return result.squeeze(-1) - else: - return result - - for lhs in lhsTensors: - lhs_expanded = lhs.expand(*(torch.Size(full_batch_dims) + torch.Size(lhs_mat_dims))) - lhs_expanded_matmul_fn = lhs_expanded.matmul - for rhs in rhsTensors: - rhs_expanded = ((rhs if len(rhs_dims) != 1 else rhs.unsqueeze(-1)). - expand(*(torch.Size(full_batch_dims) + torch.Size(rhs_mat_dims)))) - truth = maybe_squeeze_result(lhs_expanded, rhs_expanded, lhs_expanded_matmul_fn(rhs_expanded)) - for l in (lhs, lhs_expanded): - for r in (rhs, rhs_expanded): - l_matmul_fn = l.matmul - result = maybe_squeeze_result(l, r, l_matmul_fn(r)) - self.assertEqual(truth, result) - # test torch.matmul function as well - torch_result = maybe_squeeze_result(l, r, torch.matmul(l, r)) - self.assertEqual(truth, torch_result) - # test torch.matmul with out - out = torch.zeros_like(torch_result) - torch.matmul(l, r, out=out) - self.assertEqual(truth, maybe_squeeze_result(l, r, out)) - - # compare to bmm - bmm_result = (torch.bmm(lhs_expanded.contiguous().view(-1, *lhs_mat_dims), - rhs_expanded.contiguous().view(-1, *rhs_mat_dims))) - self.assertEqual(truth.view(-1, *result_dims), bmm_result.view(-1, *result_dims)) - - for indices in product((True, False), repeat=2): - verify_batched_matmul(*indices) - - def test_contiguous(self, device): - x = torch.randn(1, 16, 5, 5, device=device) - self.assertTrue(x.is_contiguous()) - stride = list(x.stride()) - stride[0] = 20 - # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1 - x.set_(x.storage(), 0, x.size(), stride) - self.assertTrue(x.is_contiguous()) - - def test_index(self, device): - - def consec(size, start=1): - sequence = torch.ones(int(torch.Tensor(size).prod(0))).cumsum(0) - sequence.add_(start - 1) - return sequence.view(*size) - - reference = consec((3, 3, 3)).to(device) - - # empty tensor indexing - self.assertEqual(reference[torch.LongTensor().to(device)], reference.new(0, 3, 3)) - - self.assertEqual(reference[0], consec((3, 3)), atol=0, rtol=0) - self.assertEqual(reference[1], consec((3, 3), 10), atol=0, rtol=0) - self.assertEqual(reference[2], consec((3, 3), 19), atol=0, rtol=0) - self.assertEqual(reference[0, 1], consec((3,), 4), atol=0, rtol=0) - self.assertEqual(reference[0:2], consec((2, 3, 3)), atol=0, rtol=0) - self.assertEqual(reference[2, 2, 2], 27, atol=0, rtol=0) - self.assertEqual(reference[:], consec((3, 3, 3)), atol=0, rtol=0) - - # indexing with Ellipsis - self.assertEqual(reference[..., 2], torch.Tensor([[3, 6, 9], - [12, 15, 18], - [21, 24, 27]]), atol=0, rtol=0) - self.assertEqual(reference[0, ..., 2], torch.Tensor([3, 6, 9]), atol=0, rtol=0) - self.assertEqual(reference[..., 2], reference[:, :, 2], atol=0, rtol=0) - self.assertEqual(reference[0, ..., 2], reference[0, :, 2], atol=0, rtol=0) - self.assertEqual(reference[0, 2, ...], reference[0, 2], atol=0, rtol=0) - self.assertEqual(reference[..., 2, 2, 2], 27, atol=0, rtol=0) - self.assertEqual(reference[2, ..., 2, 2], 27, atol=0, rtol=0) - self.assertEqual(reference[2, 2, ..., 2], 27, atol=0, rtol=0) - self.assertEqual(reference[2, 2, 2, ...], 27, atol=0, rtol=0) - self.assertEqual(reference[...], reference, atol=0, rtol=0) - - reference_5d = consec((3, 3, 3, 3, 3)).to(device) - self.assertEqual(reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], atol=0, rtol=0) - self.assertEqual(reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0], atol=0, rtol=0) - self.assertEqual(reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1], atol=0, rtol=0) - self.assertEqual(reference_5d[...], reference_5d, atol=0, rtol=0) - - # LongTensor indexing - reference = consec((5, 5, 5)).to(device) - idx = torch.LongTensor([2, 4]).to(device) - self.assertEqual(reference[idx], torch.stack([reference[2], reference[4]])) - # TODO: enable one indexing is implemented like in numpy - # self.assertEqual(reference[2, idx], torch.stack([reference[2, 2], reference[2, 4]])) - # self.assertEqual(reference[3, idx, 1], torch.stack([reference[3, 2], reference[3, 4]])[:, 1]) - - # None indexing - self.assertEqual(reference[2, None], reference[2].unsqueeze(0)) - self.assertEqual(reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0)) - self.assertEqual(reference[2:4, None], reference[2:4].unsqueeze(1)) - self.assertEqual(reference[None, 2, None, None], reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0)) - self.assertEqual(reference[None, 2:5, None, None], reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2)) - - # indexing 0-length slice - self.assertEqual(torch.empty(0, 5, 5), reference[slice(0)]) - self.assertEqual(torch.empty(0, 5), reference[slice(0), 2]) - self.assertEqual(torch.empty(0, 5), reference[2, slice(0)]) - self.assertEqual(torch.tensor([]), reference[2, 1:1, 2]) - - # indexing with step - reference = consec((10, 10, 10)).to(device) - self.assertEqual(reference[1:5:2], torch.stack([reference[1], reference[3]], 0)) - self.assertEqual(reference[1:6:2], torch.stack([reference[1], reference[3], reference[5]], 0)) - self.assertEqual(reference[1:9:4], torch.stack([reference[1], reference[5]], 0)) - self.assertEqual(reference[2:4, 1:5:2], torch.stack([reference[2:4, 1], reference[2:4, 3]], 1)) - self.assertEqual(reference[3, 1:6:2], torch.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0)) - self.assertEqual(reference[None, 2, 1:9:4], torch.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0)) - self.assertEqual(reference[:, 2, 1:6:2], - torch.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1)) - - lst = [list(range(i, i + 10)) for i in range(0, 100, 10)] - tensor = torch.DoubleTensor(lst).to(device) - for _i in range(100): - idx1_start = random.randrange(10) - idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1) - idx1_step = random.randrange(1, 8) - idx1 = slice(idx1_start, idx1_end, idx1_step) - if random.randrange(2) == 0: - idx2_start = random.randrange(10) - idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1) - idx2_step = random.randrange(1, 8) - idx2 = slice(idx2_start, idx2_end, idx2_step) - lst_indexed = list(map(lambda l: l[idx2], lst[idx1])) - tensor_indexed = tensor[idx1, idx2] - else: - lst_indexed = lst[idx1] - tensor_indexed = tensor[idx1] - self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed) - - self.assertRaises(ValueError, lambda: reference[1:9:0]) - self.assertRaises(ValueError, lambda: reference[1:9:-1]) - - self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1]) - self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1]) - self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3]) - - self.assertRaises(IndexError, lambda: reference[0.0]) - self.assertRaises(TypeError, lambda: reference[0.0:2.0]) - self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0]) - self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0]) - self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0]) - self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0]) - - def delitem(): - del reference[0] - - self.assertRaises(TypeError, delitem) - - @dtypes(torch.half, torch.double) - def test_advancedindex(self, device, dtype): - # Tests for Integer Array Indexing, Part I - Purely integer array - # indexing - - def consec(size, start=1): - # Creates the sequence in float since CPU half doesn't support the - # needed operations. Converts to dtype before returning. - numel = reduce(lambda x, y: x * y, size, 1) - sequence = torch.ones(numel, dtype=torch.float, device=device).cumsum(0) - sequence.add_(start - 1) - return sequence.view(*size).to(dtype=dtype) - - # pick a random valid indexer type - def ri(indices): - choice = random.randint(0, 2) - if choice == 0: - return torch.LongTensor(indices).to(device) - elif choice == 1: - return list(indices) - else: - return tuple(indices) - - def validate_indexing(x): - self.assertEqual(x[[0]], consec((1,))) - self.assertEqual(x[ri([0]), ], consec((1,))) - self.assertEqual(x[ri([3]), ], consec((1,), 4)) - self.assertEqual(x[[2, 3, 4]], consec((3,), 3)) - self.assertEqual(x[ri([2, 3, 4]), ], consec((3,), 3)) - self.assertEqual(x[ri([0, 2, 4]), ], torch.tensor([1, 3, 5], dtype=dtype, device=device)) - - def validate_setting(x): - x[[0]] = -2 - self.assertEqual(x[[0]], torch.tensor([-2], dtype=dtype, device=device)) - x[[0]] = -1 - self.assertEqual(x[ri([0]), ], torch.tensor([-1], dtype=dtype, device=device)) - x[[2, 3, 4]] = 4 - self.assertEqual(x[[2, 3, 4]], torch.tensor([4, 4, 4], dtype=dtype, device=device)) - x[ri([2, 3, 4]), ] = 3 - self.assertEqual(x[ri([2, 3, 4]), ], torch.tensor([3, 3, 3], dtype=dtype, device=device)) - x[ri([0, 2, 4]), ] = torch.tensor([5, 4, 3], dtype=dtype, device=device) - self.assertEqual(x[ri([0, 2, 4]), ], torch.tensor([5, 4, 3], dtype=dtype, device=device)) - - # Only validates indexing and setting for halfs - if dtype == torch.half: - reference = consec((10,)) - validate_indexing(reference) - validate_setting(reference) - return + test_func(self, device, 'function') + test_func(self, device, 'method') + test_func(self, device, 'out') - # Case 1: Purely Integer Array Indexing - reference = consec((10,)) - validate_indexing(reference) - - # setting values - validate_setting(reference) - - # Tensor with stride != 1 - # strided is [1, 3, 5, 7] - reference = consec((10,)) - strided = torch.tensor((), dtype=dtype, device=device) - strided.set_(reference.storage(), storage_offset=0, - size=torch.Size([4]), stride=[2]) - - self.assertEqual(strided[[0]], torch.tensor([1], dtype=dtype, device=device)) - self.assertEqual(strided[ri([0]), ], torch.tensor([1], dtype=dtype, device=device)) - self.assertEqual(strided[ri([3]), ], torch.tensor([7], dtype=dtype, device=device)) - self.assertEqual(strided[[1, 2]], torch.tensor([3, 5], dtype=dtype, device=device)) - self.assertEqual(strided[ri([1, 2]), ], torch.tensor([3, 5], dtype=dtype, device=device)) - self.assertEqual(strided[ri([[2, 1], [0, 3]]), ], - torch.tensor([[5, 3], [1, 7]], dtype=dtype, device=device)) - - # stride is [4, 8] - strided = torch.tensor((), dtype=dtype, device=device) - strided.set_(reference.storage(), storage_offset=4, - size=torch.Size([2]), stride=[4]) - self.assertEqual(strided[[0]], torch.tensor([5], dtype=dtype, device=device)) - self.assertEqual(strided[ri([0]), ], torch.tensor([5], dtype=dtype, device=device)) - self.assertEqual(strided[ri([1]), ], torch.tensor([9], dtype=dtype, device=device)) - self.assertEqual(strided[[0, 1]], torch.tensor([5, 9], dtype=dtype, device=device)) - self.assertEqual(strided[ri([0, 1]), ], torch.tensor([5, 9], dtype=dtype, device=device)) - self.assertEqual(strided[ri([[0, 1], [1, 0]]), ], - torch.tensor([[5, 9], [9, 5]], dtype=dtype, device=device)) - - # reference is 1 2 - # 3 4 - # 5 6 - reference = consec((3, 2)) - self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.tensor([1, 3, 5], dtype=dtype, device=device)) - self.assertEqual(reference[ri([0, 1, 2]), ri([1])], torch.tensor([2, 4, 6], dtype=dtype, device=device)) - self.assertEqual(reference[ri([0]), ri([0])], consec((1,))) - self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6)) - self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], torch.tensor([1, 2], dtype=dtype, device=device)) - self.assertEqual(reference[[ri([0, 1, 1, 0, 2]), ri([1])]], - torch.tensor([2, 4, 4, 2, 6], dtype=dtype, device=device)) - self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], - torch.tensor([1, 2, 3, 3], dtype=dtype, device=device)) - - rows = ri([[0, 0], - [1, 2]]) - columns = [0], - self.assertEqual(reference[rows, columns], torch.tensor([[1, 1], - [3, 5]], dtype=dtype, device=device)) - - rows = ri([[0, 0], - [1, 2]]) - columns = ri([1, 0]) - self.assertEqual(reference[rows, columns], torch.tensor([[2, 1], - [4, 5]], dtype=dtype, device=device)) - rows = ri([[0, 0], - [1, 2]]) - columns = ri([[0, 1], - [1, 0]]) - self.assertEqual(reference[rows, columns], torch.tensor([[1, 2], - [4, 5]], dtype=dtype, device=device)) - - # setting values - reference[ri([0]), ri([1])] = -1 - self.assertEqual(reference[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device)) - reference[ri([0, 1, 2]), ri([0])] = torch.tensor([-1, 2, -4], dtype=dtype, device=device) - self.assertEqual(reference[ri([0, 1, 2]), ri([0])], - torch.tensor([-1, 2, -4], dtype=dtype, device=device)) - reference[rows, columns] = torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device) - self.assertEqual(reference[rows, columns], - torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device)) - - # Verify still works with Transposed (i.e. non-contiguous) Tensors - - reference = torch.tensor([[0, 1, 2, 3], - [4, 5, 6, 7], - [8, 9, 10, 11]], dtype=dtype, device=device).t_() - - # Transposed: [[0, 4, 8], - # [1, 5, 9], - # [2, 6, 10], - # [3, 7, 11]] - - self.assertEqual(reference[ri([0, 1, 2]), ri([0])], - torch.tensor([0, 1, 2], dtype=dtype, device=device)) - self.assertEqual(reference[ri([0, 1, 2]), ri([1])], - torch.tensor([4, 5, 6], dtype=dtype, device=device)) - self.assertEqual(reference[ri([0]), ri([0])], - torch.tensor([0], dtype=dtype, device=device)) - self.assertEqual(reference[ri([2]), ri([1])], - torch.tensor([6], dtype=dtype, device=device)) - self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], - torch.tensor([0, 4], dtype=dtype, device=device)) - self.assertEqual(reference[[ri([0, 1, 1, 0, 3]), ri([1])]], - torch.tensor([4, 5, 5, 4, 7], dtype=dtype, device=device)) - self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], - torch.tensor([0, 4, 1, 1], dtype=dtype, device=device)) - - rows = ri([[0, 0], - [1, 2]]) - columns = [0], - self.assertEqual(reference[rows, columns], - torch.tensor([[0, 0], [1, 2]], dtype=dtype, device=device)) - - rows = ri([[0, 0], - [1, 2]]) - columns = ri([1, 0]) - self.assertEqual(reference[rows, columns], - torch.tensor([[4, 0], [5, 2]], dtype=dtype, device=device)) - rows = ri([[0, 0], - [1, 3]]) - columns = ri([[0, 1], - [1, 2]]) - self.assertEqual(reference[rows, columns], - torch.tensor([[0, 4], [5, 11]], dtype=dtype, device=device)) - - # setting values - reference[ri([0]), ri([1])] = -1 - self.assertEqual(reference[ri([0]), ri([1])], - torch.tensor([-1], dtype=dtype, device=device)) - reference[ri([0, 1, 2]), ri([0])] = torch.tensor([-1, 2, -4], dtype=dtype, device=device) - self.assertEqual(reference[ri([0, 1, 2]), ri([0])], - torch.tensor([-1, 2, -4], dtype=dtype, device=device)) - reference[rows, columns] = torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device) - self.assertEqual(reference[rows, columns], - torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device)) - - # stride != 1 - - # strided is [[1 3 5 7], - # [9 11 13 15]] - - reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8) - strided = torch.tensor((), dtype=dtype, device=device) - strided.set_(reference.storage(), 1, size=torch.Size([2, 4]), - stride=[8, 2]) - - self.assertEqual(strided[ri([0, 1]), ri([0])], - torch.tensor([1, 9], dtype=dtype, device=device)) - self.assertEqual(strided[ri([0, 1]), ri([1])], - torch.tensor([3, 11], dtype=dtype, device=device)) - self.assertEqual(strided[ri([0]), ri([0])], - torch.tensor([1], dtype=dtype, device=device)) - self.assertEqual(strided[ri([1]), ri([3])], - torch.tensor([15], dtype=dtype, device=device)) - self.assertEqual(strided[[ri([0, 0]), ri([0, 3])]], - torch.tensor([1, 7], dtype=dtype, device=device)) - self.assertEqual(strided[[ri([1]), ri([0, 1, 1, 0, 3])]], - torch.tensor([9, 11, 11, 9, 15], dtype=dtype, device=device)) - self.assertEqual(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], - torch.tensor([1, 3, 9, 9], dtype=dtype, device=device)) - - rows = ri([[0, 0], - [1, 1]]) - columns = [0], - self.assertEqual(strided[rows, columns], - torch.tensor([[1, 1], [9, 9]], dtype=dtype, device=device)) - - rows = ri([[0, 1], - [1, 0]]) - columns = ri([1, 2]) - self.assertEqual(strided[rows, columns], - torch.tensor([[3, 13], [11, 5]], dtype=dtype, device=device)) - rows = ri([[0, 0], - [1, 1]]) - columns = ri([[0, 1], - [1, 2]]) - self.assertEqual(strided[rows, columns], - torch.tensor([[1, 3], [11, 13]], dtype=dtype, device=device)) - - # setting values - - # strided is [[10, 11], - # [17, 18]] - - reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8) - strided = torch.tensor((), dtype=dtype, device=device) - strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), - stride=[7, 1]) - self.assertEqual(strided[ri([0]), ri([1])], - torch.tensor([11], dtype=dtype, device=device)) - strided[ri([0]), ri([1])] = -1 - self.assertEqual(strided[ri([0]), ri([1])], - torch.tensor([-1], dtype=dtype, device=device)) - - reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8) - strided = torch.tensor((), dtype=dtype, device=device) - strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), - stride=[7, 1]) - self.assertEqual(strided[ri([0, 1]), ri([1, 0])], - torch.tensor([11, 17], dtype=dtype, device=device)) - strided[ri([0, 1]), ri([1, 0])] = torch.tensor([-1, 2], dtype=dtype, device=device) - self.assertEqual(strided[ri([0, 1]), ri([1, 0])], - torch.tensor([-1, 2], dtype=dtype, device=device)) - - reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8) - strided = torch.tensor((), dtype=dtype, device=device) - strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), - stride=[7, 1]) - - rows = ri([[0], - [1]]) - columns = ri([[0, 1], - [0, 1]]) - self.assertEqual(strided[rows, columns], - torch.tensor([[10, 11], [17, 18]], dtype=dtype, device=device)) - strided[rows, columns] = torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device) - self.assertEqual(strided[rows, columns], - torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device)) - - # Tests using less than the number of dims, and ellipsis - - # reference is 1 2 - # 3 4 - # 5 6 - reference = consec((3, 2)) - self.assertEqual(reference[ri([0, 2]), ], - torch.tensor([[1, 2], [5, 6]], dtype=dtype, device=device)) - self.assertEqual(reference[ri([1]), ...], - torch.tensor([[3, 4]], dtype=dtype, device=device)) - self.assertEqual(reference[..., ri([1])], - torch.tensor([[2], [4], [6]], dtype=dtype, device=device)) - - # verify too many indices fails - with self.assertRaises(IndexError): - reference[ri([1]), ri([0, 2]), ri([3])] - - # test invalid index fails - reference = torch.empty(10, dtype=dtype, device=device) - # can't test cuda because it is a device assert - if not reference.is_cuda: - for err_idx in (10, -11): - with self.assertRaisesRegex(IndexError, r'out of'): - reference[err_idx] - with self.assertRaisesRegex(IndexError, r'out of'): - reference[torch.LongTensor([err_idx]).to(device)] - with self.assertRaisesRegex(IndexError, r'out of'): - reference[[err_idx]] - - if TEST_NUMPY: - # we use numpy to compare against, to verify that our advanced - # indexing semantics are the same, and also for ease of test - # writing - - def tensor_indices_to_np(tensor, indices): - # convert the Torch Tensor to a numpy array - tensor = tensor.to(device='cpu') - npt = tensor.numpy() + def test_embedding_scalar_weight_error(self, device): + indices = torch.rand(2, 2, device=device).long() + weight = torch.tensor(1.0) + with self.assertRaisesRegex(RuntimeError, "'weight' must be at least 1-D"): + torch.embedding(weight, indices) - # convert indices - idxs = tuple(i.tolist() if isinstance(i, torch.LongTensor) else - i for i in indices) - - return npt, idxs - - def get_numpy(tensor, indices): - npt, idxs = tensor_indices_to_np(tensor, indices) - - # index and return as a Torch Tensor - return torch.tensor(npt[idxs], dtype=dtype, device=device) - - def set_numpy(tensor, indices, value): - if not isinstance(value, int): - if self.device_type != 'cpu': - value = value.cpu() - value = value.numpy() - - npt, idxs = tensor_indices_to_np(tensor, indices) - npt[idxs] = value - return npt - - def assert_get_eq(tensor, indexer): - self.assertEqual(tensor[indexer], get_numpy(tensor, indexer)) - - def assert_set_eq(tensor, indexer, val): - pyt = tensor.clone() - numt = tensor.clone() - pyt[indexer] = val - numt = torch.tensor(set_numpy(numt, indexer, val), dtype=dtype, device=device) - self.assertEqual(pyt, numt) - - def assert_backward_eq(tensor, indexer): - cpu = tensor.float().clone().detach().requires_grad_(True) - outcpu = cpu[indexer] - gOcpu = torch.rand_like(outcpu) - outcpu.backward(gOcpu) - dev = cpu.to(device).detach().requires_grad_(True) - outdev = dev[indexer] - outdev.backward(gOcpu.to(device)) - self.assertEqual(cpu.grad, dev.grad) - - def get_set_tensor(indexed, indexer): - set_size = indexed[indexer].size() - set_count = indexed[indexer].numel() - set_tensor = torch.randperm(set_count).view(set_size).double().to(device) - return set_tensor - - # Tensor is 0 1 2 3 4 - # 5 6 7 8 9 - # 10 11 12 13 14 - # 15 16 17 18 19 - reference = torch.arange(0., 20, dtype=dtype, device=device).view(4, 5) - - indices_to_test = [ - # grab the second, fourth columns - [slice(None), [1, 3]], - - # first, third rows, - [[0, 2], slice(None)], - - # weird shape - [slice(None), [[0, 1], - [2, 3]]], - # negatives - [[-1], [0]], - [[0, 2], [-1]], - [slice(None), [-1]], - ] + def test_dist(self, device): + def run_test(x, y): + for p in [0, 1, 2, 3, 4, inf, -inf]: + dist_xy = torch.dist(x, y, p) + dist_xy_norm = torch.norm(x - y, p) + self.assertEqual(dist_xy, dist_xy_norm) - # only test dupes on gets - get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]] - - for indexer in get_indices_to_test: - assert_get_eq(reference, indexer) - if self.device_type != 'cpu': - assert_backward_eq(reference, indexer) - - for indexer in indices_to_test: - assert_set_eq(reference, indexer, 44) - assert_set_eq(reference, - indexer, - get_set_tensor(reference, indexer)) - - reference = torch.arange(0., 160, dtype=dtype, device=device).view(4, 8, 5) - - indices_to_test = [ - [slice(None), slice(None), [0, 3, 4]], - [slice(None), [2, 4, 5, 7], slice(None)], - [[2, 3], slice(None), slice(None)], - [slice(None), [0, 2, 3], [1, 3, 4]], - [slice(None), [0], [1, 2, 4]], - [slice(None), [0, 1, 3], [4]], - [slice(None), [[0, 1], [1, 0]], [[2, 3]]], - [slice(None), [[0, 1], [2, 3]], [[0]]], - [slice(None), [[5, 6]], [[0, 3], [4, 4]]], - [[0, 2, 3], [1, 3, 4], slice(None)], - [[0], [1, 2, 4], slice(None)], - [[0, 1, 3], [4], slice(None)], - [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], - [[[0, 1], [1, 0]], [[2, 3]], slice(None)], - [[[0, 1], [2, 3]], [[0]], slice(None)], - [[[2, 1]], [[0, 3], [4, 4]], slice(None)], - [[[2]], [[0, 3], [4, 1]], slice(None)], - # non-contiguous indexing subspace - [[0, 2, 3], slice(None), [1, 3, 4]], - - # less dim, ellipsis - [[0, 2], ], - [[0, 2], slice(None)], - [[0, 2], Ellipsis], - [[0, 2], slice(None), Ellipsis], - [[0, 2], Ellipsis, slice(None)], - [[0, 2], [1, 3]], - [[0, 2], [1, 3], Ellipsis], - [Ellipsis, [1, 3], [2, 3]], - [Ellipsis, [2, 3, 4]], - [Ellipsis, slice(None), [2, 3, 4]], - [slice(None), Ellipsis, [2, 3, 4]], - - # ellipsis counts for nothing - [Ellipsis, slice(None), slice(None), [0, 3, 4]], - [slice(None), Ellipsis, slice(None), [0, 3, 4]], - [slice(None), slice(None), Ellipsis, [0, 3, 4]], - [slice(None), slice(None), [0, 3, 4], Ellipsis], - [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], - [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)], - [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis], - ] + run_test(torch.randn(5, device=device), torch.randn(5, device=device)) - for indexer in indices_to_test: - assert_get_eq(reference, indexer) - assert_set_eq(reference, indexer, 212) - assert_set_eq(reference, - indexer, - get_set_tensor(reference, indexer)) - if torch.cuda.is_available(): - assert_backward_eq(reference, indexer) - - reference = torch.arange(0., 1296, dtype=dtype, device=device).view(3, 9, 8, 6) - - indices_to_test = [ - [slice(None), slice(None), slice(None), [0, 3, 4]], - [slice(None), slice(None), [2, 4, 5, 7], slice(None)], - [slice(None), [2, 3], slice(None), slice(None)], - [[1, 2], slice(None), slice(None), slice(None)], - [slice(None), slice(None), [0, 2, 3], [1, 3, 4]], - [slice(None), slice(None), [0], [1, 2, 4]], - [slice(None), slice(None), [0, 1, 3], [4]], - [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]], - [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]], - [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]], - [slice(None), [0, 2, 3], [1, 3, 4], slice(None)], - [slice(None), [0], [1, 2, 4], slice(None)], - [slice(None), [0, 1, 3], [4], slice(None)], - [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)], - [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)], - [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)], - [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)], - [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)], - [[0, 1, 2], [1, 3, 4], slice(None), slice(None)], - [[0], [1, 2, 4], slice(None), slice(None)], - [[0, 1, 2], [4], slice(None), slice(None)], - [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)], - [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)], - [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)], - [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)], - [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]], - [slice(None), [2, 3, 4], [1, 3, 4], [4]], - [slice(None), [0, 1, 3], [4], [1, 3, 4]], - [slice(None), [6], [0, 2, 3], [1, 3, 4]], - [slice(None), [2, 3, 5], [3], [4]], - [slice(None), [0], [4], [1, 3, 4]], - [slice(None), [6], [0, 2, 3], [1]], - [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]], - [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)], - [[2, 0, 1], [1, 2, 3], [4], slice(None)], - [[0, 1, 2], [4], [1, 3, 4], slice(None)], - [[0], [0, 2, 3], [1, 3, 4], slice(None)], - [[0, 2, 1], [3], [4], slice(None)], - [[0], [4], [1, 3, 4], slice(None)], - [[1], [0, 2, 3], [1], slice(None)], - [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)], - - # less dim, ellipsis - [Ellipsis, [0, 3, 4]], - [Ellipsis, slice(None), [0, 3, 4]], - [Ellipsis, slice(None), slice(None), [0, 3, 4]], - [slice(None), Ellipsis, [0, 3, 4]], - [slice(None), slice(None), Ellipsis, [0, 3, 4]], - [slice(None), [0, 2, 3], [1, 3, 4]], - [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis], - [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)], - [[0], [1, 2, 4]], - [[0], [1, 2, 4], slice(None)], - [[0], [1, 2, 4], Ellipsis], - [[0], [1, 2, 4], Ellipsis, slice(None)], - [[1], ], - [[0, 2, 1], [3], [4]], - [[0, 2, 1], [3], [4], slice(None)], - [[0, 2, 1], [3], [4], Ellipsis], - [Ellipsis, [0, 2, 1], [3], [4]], - ] + x = torch.zeros(3, device=device) + y = torch.zeros(3, device=device) + y[1] = 1. + run_test(x, y) - for indexer in indices_to_test: - assert_get_eq(reference, indexer) - assert_set_eq(reference, indexer, 1333) - assert_set_eq(reference, - indexer, - get_set_tensor(reference, indexer)) - indices_to_test += [ - [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]], - [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]], - ] - for indexer in indices_to_test: - assert_get_eq(reference, indexer) - assert_set_eq(reference, indexer, 1333) - if self.device_type != 'cpu': - assert_backward_eq(reference, indexer) + # Ensures that median throws nondeterministic alerts in the correct cases + @dtypes(torch.double) + def test_median_nondeterministic_alert(self, device, dtype): + def test_func(slf, device, call_type): + S = 10 + a = torch.randn(S, device=device) + if call_type == 'function': + torch.median(a) + elif call_type == 'function with indices': + torch.median(a, 0) + elif call_type == 'method': + a.median() + elif call_type == 'method with indices': + a.median(0) + elif call_type == 'out with indices': + result = torch.empty_like(a) + indices = torch.empty((), dtype=torch.long, device=device) + torch.median(a, 0, out=(result, indices)) + else: + self.fail(f"'{call_type}' is not a valid call type") - def test_advancedindex_big(self, device): - reference = torch.arange(0, 123344, dtype=torch.int, device=device) + @expectedAlertNondeterministic('median CUDA with indices output', 'cuda') + def test_func_expect_error(slf, device, call_type): + test_func(slf, device, call_type) - self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ], - torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int)) + test_func(self, device, 'function') + test_func_expect_error(self, device, 'function with indices') + test_func(self, device, 'method') + test_func_expect_error(self, device, 'method with indices') + test_func_expect_error(self, device, 'out with indices') - @dtypes(torch.double) - def test_kthvalue(self, device, dtype): - SIZE = 50 - x = torch.rand(SIZE, SIZE, SIZE, dtype=dtype, device=device) - x0 = x.clone() - - k = random.randint(1, SIZE) - res1val, res1ind = torch.kthvalue(x, k, keepdim=False) - res2val, res2ind = torch.sort(x) - - self.assertEqual(res1val[:, :], res2val[:, :, k - 1], atol=0, rtol=0) - self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0) - # test use of result tensors - k = random.randint(1, SIZE) - res1val = torch.tensor([], dtype=dtype, device=device) - res1ind = torch.tensor([], dtype=torch.long, device=device) - torch.kthvalue(x, k, keepdim=False, out=(res1val, res1ind)) - res2val, res2ind = torch.sort(x) - self.assertEqual(res1val[:, :], res2val[:, :, k - 1], atol=0, rtol=0) - self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0) - - # test non-default dim - k = random.randint(1, SIZE) - res1val, res1ind = torch.kthvalue(x, k, 0, keepdim=False) - res2val, res2ind = torch.sort(x, 0) - self.assertEqual(res1val, res2val[k - 1], atol=0, rtol=0) - self.assertEqual(res1ind, res2ind[k - 1], atol=0, rtol=0) - - # non-contiguous - y = x.narrow(1, 0, 1) - y0 = y.contiguous() - k = random.randint(1, SIZE) - res1val, res1ind = torch.kthvalue(y, k) - res2val, res2ind = torch.kthvalue(y0, k) - self.assertEqual(res1val, res2val, atol=0, rtol=0) - self.assertEqual(res1ind, res2ind, atol=0, rtol=0) - - # check that the input wasn't modified - self.assertEqual(x, x0, atol=0, rtol=0) - - # simple test case (with repetitions) - y = torch.tensor((3., 5, 4, 1, 1, 5), dtype=dtype, device=device) - self.assertEqual(torch.kthvalue(y, 3)[0], 3, atol=0, rtol=0) - self.assertEqual(torch.kthvalue(y, 2)[0], 1, atol=0, rtol=0) - - # simple test case (with NaN) - SIZE = 50 - x = torch.rand(SIZE, SIZE, SIZE, dtype=dtype, device=device) - x[torch.arange(SIZE), :, torch.randint(50, (50,))] = nan - ks = [random.randint(1, SIZE), 1, SIZE, SIZE - 1] - res2val, res2ind = torch.sort(x) - for k in ks: - res1val, res1ind = torch.kthvalue(x, k, keepdim=False) - self.assertEqual(res1val[:, :], res2val[:, :, k - 1], atol=0, rtol=0) - self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.double) - def test_lu_solve_batched_non_contiguous(self, device, dtype): - from numpy.linalg import solve - from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value - - A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype, device='cpu') - b = torch.randn(2, 2, 2, dtype=dtype, device='cpu') - x_exp = torch.as_tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy())).to(device) - A = A.to(device).permute(0, 2, 1) - b = b.to(device).permute(2, 1, 0) - assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs" - LU_data, LU_pivots = torch.lu(A) - x = torch.lu_solve(b, LU_data, LU_pivots) - self.assertEqual(x, x_exp) - - def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype): - from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value - - b = torch.randn(*b_dims, dtype=dtype, device=device) - A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype, device=device) - LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot) - self.assertEqual(info, torch.zeros_like(info)) - return b, A, LU_data, LU_pivots - - @skipCPUIfNoLapack - @skipCUDAIfNoMagma - @dtypes(torch.double) - def test_lu_solve(self, device, dtype): - def sub_test(pivot): - for k, n in zip([2, 3, 5], [3, 5, 7]): - b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n,), (n, k), pivot, device, dtype) - x = torch.lu_solve(b, LU_data, LU_pivots) - self.assertLessEqual(b.dist(A.mm(x)), 1e-12) - - sub_test(True) - if self.device_type == 'cuda': - sub_test(False) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_lu_solve_batched(self, device, dtype): - def sub_test(pivot): - def lu_solve_batch_test_helper(A_dims, b_dims, pivot): - b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, pivot, device, dtype) - x_exp_list = [] - for i in range(b_dims[0]): - x_exp_list.append(torch.lu_solve(b[i], LU_data[i], LU_pivots[i])) - x_exp = torch.stack(x_exp_list) # Stacked output - x_act = torch.lu_solve(b, LU_data, LU_pivots) # Actual output - self.assertEqual(x_exp, x_act) # Equality check - self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 1e-12) # Correctness check - - for batchsize in [1, 3, 4]: - lu_solve_batch_test_helper((5, batchsize), (batchsize, 5, 10), pivot) - - # Tests tensors with 0 elements - b = torch.randn(3, 0, 3, dtype=dtype, device=device) - A = torch.randn(3, 0, 0, dtype=dtype, device=device) - LU_data, LU_pivots = torch.lu(A) - self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots)) - - sub_test(True) - if self.device_type == 'cuda': - sub_test(False) + @dtypes(*torch.testing.get_all_fp_dtypes()) + def test_log_normal(self, device, dtype): + a = torch.tensor([10], dtype=dtype, device=device).log_normal_() + self.assertEqual(a.dtype, dtype) + self.assertEqual(a.size(), torch.Size([1])) - @slowTest - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_lu_solve_batched_many_batches(self, device, dtype): - def run_test(A_dims, b_dims): - b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype) - x = torch.lu_solve(b, LU_data, LU_pivots) - b_ = torch.matmul(A, x) - self.assertEqual(b_, b.expand_as(b_)) - - run_test((5, 65536), (65536, 5, 10)) - run_test((5, 262144), (262144, 5, 10)) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.double) - def test_lu_solve_batched_broadcasting(self, device, dtype): - from numpy.linalg import solve - from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value - - def run_test(A_dims, b_dims, pivot=True): - A_matrix_size = A_dims[-1] - A_batch_dims = A_dims[:-2] - A = random_fullrank_matrix_distinct_singular_value(A_matrix_size, *A_batch_dims, dtype=dtype) - b = torch.randn(*b_dims, dtype=dtype) - x_exp = torch.as_tensor(solve(A.numpy(), b.numpy())).to(dtype=dtype, device=device) - A, b = A.to(device), b.to(device) - LU_data, LU_pivots = torch.lu(A, pivot=pivot) - x = torch.lu_solve(b, LU_data, LU_pivots) - self.assertEqual(x, x_exp) - - # test against numpy.linalg.solve - run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6)) # no broadcasting - run_test((2, 1, 3, 4, 4), (4, 6)) # broadcasting b - run_test((4, 4), (2, 1, 3, 4, 2)) # broadcasting A - run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5)) # broadcasting A & b - - # Assert for illegal dtype would not be raised on XLA - @onlyOnCPUAndCUDA - def test_minmax_illegal_dtype(self, device): - x = torch.randn(5, 5, dtype=torch.float32, device=device) - valid_values = torch.empty(5, dtype=torch.float32, device=device) - valid_indices = torch.empty(5, dtype=torch.long, device=device) - illegal_values = torch.empty(5, dtype=torch.int, device=device) - illegal_indices = torch.empty(5, dtype=torch.double, device=device) - torch.max(x, dim=0, out=(valid_values, valid_indices)) - torch.min(x, dim=0, out=(valid_values, valid_indices)) - torch.amax(x, dim=0, out=valid_values) - torch.amin(x, dim=0, out=valid_values) - rmsg = r'scalar type|dtype' - with self.assertRaisesRegex(RuntimeError, rmsg): - torch.max(x, dim=0, out=(illegal_values, valid_indices)) - with self.assertRaisesRegex(RuntimeError, rmsg): - torch.min(x, dim=0, out=(illegal_values, valid_indices)) - with self.assertRaisesRegex(RuntimeError, rmsg): - torch.amax(x, dim=0, out=illegal_values) - with self.assertRaisesRegex(RuntimeError, rmsg): - torch.amin(x, dim=0, out=illegal_values) - with self.assertRaisesRegex(RuntimeError, rmsg): - torch.max(x, dim=0, out=(valid_values, illegal_indices)) - with self.assertRaisesRegex(RuntimeError, rmsg): - torch.min(x, dim=0, out=(valid_values, illegal_indices)) - with self.assertRaisesRegex(RuntimeError, rmsg): - torch.max(x, dim=0, out=(illegal_values, illegal_indices)) - with self.assertRaisesRegex(RuntimeError, rmsg): - torch.min(x, dim=0, out=(illegal_values, illegal_indices)) - - @dtypes(torch.float, torch.double, torch.int64, torch.int32, torch.int16) - @dtypesIfCUDA(torch.float, torch.double, torch.int64, torch.int32, torch.int16, torch.half) - def test_dim_arg_reduction_scalar(self, device, dtype): - example = 4.0 - - x = torch.tensor(example, device=device, dtype=dtype) - self.assertEqual(x.argmax().item(), 0) - self.assertEqual(x.argmax(dim=None).item(), 0) - self.assertEqual(x.argmax(dim=0).item(), 0) - self.assertEqual(x.argmax(dim=0, keepdim=True), torch.tensor(0, dtype=torch.int64)) - - x = torch.tensor(example, device=device, dtype=dtype) - self.assertEqual(x.argmin().item(), 0) - self.assertEqual(x.argmin(dim=None).item(), 0) - self.assertEqual(x.argmin(dim=0).item(), 0) - self.assertEqual(x.argmin(dim=0, keepdim=True), torch.tensor(0, dtype=torch.int64)) - - - def test_dim_reduction(self, device): - example = [[-1, 2, 1], [5, 3, 6]] - - types = [torch.double, - torch.float, - torch.int64, - torch.int32, - torch.int16] - if self.device_type == 'cuda': # 'cpu' and 'xla' do not support half - types.append(torch.half) - - sum_dtype = { - torch.double: torch.double, - torch.float: torch.float, - torch.half: torch.half, - torch.int64: torch.int64, - torch.int32: torch.int64, - torch.int16: torch.int64, - } + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) + def test_geometric(self, device, dtype): + a = torch.tensor([10], dtype=dtype, device=device).geometric_(0.5) + self.assertEqual(a.dtype, dtype) + self.assertEqual(a.size(), torch.Size([1])) - # This won't test for 256bit instructions, since we usually - # only work on 1 cacheline (1024bit) at a time and these - # examples aren't big enough to trigger that. - for dtype in types: - x = torch.tensor(example, device=device, dtype=dtype) - self.assertEqual(x.sum().item(), 16) - self.assertEqual(x.sum(0), torch.tensor([4, 5, 7], dtype=sum_dtype[dtype])) - self.assertEqual(x.sum(1), torch.tensor([2, 14], dtype=sum_dtype[dtype])) - y = torch.tensor(example, device=device, dtype=sum_dtype[dtype]) - torch.sum(x, 0, out=y) - self.assertEqual(x.sum(0), y) - - # Mean not supported for Int types - for dtype in types[:2]: - x = torch.tensor(example, device=device, dtype=dtype) - self.assertEqual(x.mean().item(), 16.0 / 6) - self.assertEqual(x.mean(0), torch.tensor([2.0, 2.5, 7.0 / 2], dtype=dtype)) - self.assertEqual(x.mean(1), torch.tensor([2.0 / 3, 14.0 / 3], dtype=dtype)) - self.assertEqual(x.mean(), x.mean((0, 1))) - - prod_dtype = { - torch.double: torch.double, - torch.float: torch.float, - torch.half: torch.half, - torch.int64: torch.int64, - torch.int32: torch.int64, - torch.int16: torch.int64 - } + @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False))) + @dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_bfloat16=False))) + def test_bernoulli_p(self, device, dtype): + for trivial_p in ([0, 1], [1, 0, 1, 1, 0, 1]): + x = torch.tensor(trivial_p, dtype=dtype, device=device) + self.assertEqual(x.bernoulli().tolist(), trivial_p) - for dtype in types: - x = torch.tensor(example, device=device, dtype=dtype) - self.assertEqual(x.prod().item(), -180) - self.assertEqual(x.prod(0), torch.tensor([-5, 6, 6], dtype=prod_dtype[dtype])) - self.assertEqual(x.prod(1), torch.tensor([-2, 90], dtype=prod_dtype[dtype])) - - for dtype in types: - x = torch.tensor(example, device=device, dtype=dtype) - - self.assertEqual(x.min().item(), -1) - self.assertEqual(x.argmin().item(), 0) - - # TODO: torch.min does not support the same operation as argmin - # for the same case, should we enable it? - self.assertEqual(x.argmin(dim=None).item(), 0) - - self.assertEqual(x.min(0), (torch.tensor([-1, 2, 1], dtype=dtype), - torch.tensor([0, 0, 0], dtype=torch.int64))) - self.assertEqual(x.amin(0), torch.tensor([-1, 2, 1], dtype=dtype)) - self.assertEqual(x.argmin(0), torch.tensor([0, 0, 0], dtype=torch.int64)) - - self.assertEqual(x.min(dim=0, keepdim=True), (torch.tensor([[-1, 2, 1]], dtype=dtype), - torch.tensor([[0, 0, 0]], dtype=torch.int64))) - self.assertEqual(x.amin(dim=0, keepdim=True), torch.tensor([[-1, 2, 1]], dtype=dtype)) - self.assertEqual(x.argmin(dim=0, keepdim=True), torch.tensor([[0, 0, 0]], dtype=torch.int64)) - - self.assertEqual(x.min(1), (torch.tensor([-1, 3], dtype=dtype), - torch.tensor([0, 1], dtype=torch.int64))) - self.assertEqual(x.amin(1), torch.tensor([-1, 3], dtype=dtype)) - self.assertEqual(x.argmin(1), torch.tensor([0, 1], dtype=torch.int64)) - - self.assertEqual(x.min(dim=1, keepdim=True), (torch.tensor([[-1], [3]], dtype=dtype), - torch.tensor([[0], [1]], dtype=torch.int64))) - self.assertEqual(x.amin(dim=1, keepdim=True), torch.tensor([[-1], [3]], dtype=dtype)) - self.assertEqual(x.argmin(dim=1, keepdim=True), torch.tensor([[0], [1]], dtype=torch.int64)) - - # test that non-contiguous tensors work - self.assertEqual(x[:, :2].min().item(), -1) - self.assertEqual(x[:, :2].amin().item(), -1) - self.assertEqual(x[:, :2].argmin().item(), 0) - - for dtype in types: - x = torch.tensor(example, device=device, dtype=dtype) - - self.assertEqual(x.max().item(), 6) - self.assertEqual(x.amax().item(), 6) - self.assertEqual(x.argmax().item(), 5) - - self.assertEqual(x.max(0), (torch.tensor([5, 3, 6], dtype=dtype), - torch.tensor([1, 1, 1], dtype=torch.int64))) - self.assertEqual(x.amax(0), torch.tensor([5, 3, 6], dtype=dtype)) - self.assertEqual(x.argmax(dim=0), torch.tensor([1, 1, 1], dtype=torch.int64)) - - self.assertEqual(x.max(dim=0, keepdim=True), (torch.tensor([[5, 3, 6]], dtype=dtype), - torch.tensor([[1, 1, 1]], dtype=torch.int64))) - self.assertEqual(x.amax(dim=0, keepdim=True), torch.tensor([[5, 3, 6]], dtype=dtype)) - self.assertEqual(x.argmax(dim=0, keepdim=True), torch.tensor([[1, 1, 1]], dtype=torch.int64)) - - self.assertEqual(x.max(1), (torch.tensor([2, 6], dtype=dtype), - torch.tensor([1, 2], dtype=torch.int64))) - self.assertEqual(x.amax(1), torch.tensor([2, 6], dtype=dtype)) - self.assertEqual(x.argmax(dim=1), torch.tensor([1, 2], dtype=torch.int64)) - - self.assertEqual(x.max(1, keepdim=True), (torch.tensor([[2], [6]], dtype=dtype), - torch.tensor([[1], [2]], dtype=torch.int64))) - self.assertEqual(x.amax(1, keepdim=True), torch.tensor([[2], [6]], dtype=dtype)) - self.assertEqual(x.argmax(dim=1, keepdim=True), torch.tensor([[1], [2]], dtype=torch.int64)) - - # test that non-contiguous tensors work - self.assertEqual(x[:, :2].max().item(), 5) - self.assertEqual(x[:, :2].amax().item(), 5) - self.assertEqual(x[:, :2].argmax().item(), 2) - - dim_red_fns = [ - "mean", "median", "mode", "norm", "prod", - "std", "sum", "var", "max", "min", "amax", "amin"] - - def normfn_attr(t, dim, keepdim=False, out=None): - attr = torch.norm - return attr(t, 2, dim, keepdim, out=out) - - for fn_name in dim_red_fns: - fn_attr = getattr(torch, fn_name) if fn_name != "norm" else normfn_attr - - def fn(x, dim, keepdim=False, out=None): - ans = fn_attr(x, dim, keepdim=keepdim, out=out) - return ans if not istuple(ans) else ans[0] - - def fn_tuple(x, dim, keepdim=False, out=None): - return fn_attr(x, dim, keepdim=keepdim, out=out) - - def test_multidim(x, dim): - self.assertEqual(fn(x, dim).unsqueeze(dim), fn(x, dim, keepdim=True)) - self.assertEqual(x.ndimension() - 1, fn(x, dim).ndimension()) - self.assertEqual(x.ndimension(), fn(x, dim, keepdim=True).ndimension()) - - # general case - x = torch.randn(3, 4, 5, device=device) - dim = random.randint(0, 2) - test_multidim(x, dim) - - # check 1-d behavior - x = torch.randn(1, device=device) - dim = 0 - self.assertEqual(fn(x, dim).shape, ()) - self.assertEqual(fn(x, dim, keepdim=True).shape, (1,)) - - # check reducing of a singleton dimension - dims = [3, 4, 5] - singleton_dim = random.randint(0, 2) - dims[singleton_dim] = 1 - x = torch.randn(dims, device=device) - test_multidim(x, singleton_dim) - - # check reducing median with NaNs - # If the element in the median is a NaN, there can be issues - # when comparining with other nan elements - if fn_name == 'median': - y = torch.full((1, 3), np.nan, dtype=torch.float64, device=device) - y[:, :1] = 1.1 - values, indices = fn_tuple(y, dim=1) - expected_values = torch.tensor([nan], dtype=torch.float64, device=device) - self.assertEqual(values, expected_values) - self.assertTrue(torch.isnan(y.flatten()[indices[0]])) - - # check reducing with output kwargs - if fn_name in ['median', 'mode', 'max', 'min']: - y = torch.randn(5, 3, device=device) - values = torch.randn(5, 3, device=device) - indices = torch.zeros(5, 3, device=device).long() - 1 - fn_tuple(y, 1, keepdim=False, out=(values[:, 1], indices[:, 1])) - values_expected, indices_expected = fn_tuple(y, 1, keepdim=False) - self.assertEqual(values[:, 1], values_expected, - msg='{} values with out= kwarg'.format(fn_name)) - self.assertEqual(indices[:, 1], indices_expected, - msg='{} indices with out= kwarg'.format(fn_name)) - continue + def isBinary(t): + return torch.ne(t, 0).mul_(torch.ne(t, 1)).sum().item() == 0 - x = torch.randn(5, 3, device=device) - y = torch.randn(5, 3, device=device) - fn(y, 1, keepdim=False, out=x[:, 1]) - expected = fn(y, 1, keepdim=False) - self.assertEqual(x[:, 1], expected, msg='{} with out= kwarg'.format(fn_name)) - - @largeCUDATensorTest('10GB') - def test_reduction_split(self, device): - # Test reduction when there is a 32bit-indexing split - # https://github.com/pytorch/pytorch/issues/37583 - input_ = torch.randn(5, 14400, 14400, device=device) - result = input_.sum(dim=0) - expect = input_[0] + input_[1] + input_[2] + input_[3] + input_[4] - self.assertEqual(result, expect) - - @onlyCUDA - @dtypes(torch.half, torch.float, torch.double) - def test_reduction_vectorize_along_input_corner(self, device, dtype): - # 1D case: sum - size = 1024 * 1024 * 64 + 3 - shift = 1 - x = torch.zeros(size, dtype=dtype, device=device) - y = x[shift:] - for i in range(100): - x.zero_() - x[i] = 1 - self.assertEqual(x.sum(), 1.0) - if i < shift: - self.assertEqual(y.sum(), 0.0) - else: - self.assertEqual(y.sum(), 1.0) - for i in range(1, 100): - x.zero_() - x[-i] = 1 - self.assertEqual(x.sum(), 1.0) - self.assertEqual(y.sum(), 1.0) - # 1D case: argmax - size = 1024 * 1024 * 64 + 3 - shift = 1 - ysize = size - shift - x = torch.zeros(size, dtype=dtype, device=device) - y = x[shift:] - for i in range(100): - x.zero_() - x[i] = 1 - self.assertEqual(x.argmax().item(), i) - if i >= shift: - self.assertEqual(y.argmax().item(), i - shift) - for i in range(1, 100): - x.zero_() - x[-i] = 1 - self.assertEqual(x.argmax().item(), size - i) - self.assertEqual(y.argmax().item(), ysize - i) - # 2D case: sum - size = (7, 1024 * 1024 + 3) - x = torch.zeros(size, dtype=dtype, device=device) - for i in range(100): - x.zero_() - for j in range(7): - x[j][i] = j - xs = x.sum(dim=-1) - for j in range(7): - self.assertEqual(xs[j].item(), float(j)) - for i in range(100): - x.zero_() - for j in range(7): - x[j][-i] = j - xs = x.sum(dim=-1) - for j in range(7): - self.assertEqual(xs[j].item(), float(j)) - # 2D case: max/argmax - size = (7, 1024 * 1024 + 3) - x = torch.zeros(size, dtype=dtype, device=device) - for i in range(100): - x.zero_() - for j in range(7): - x[j][i] = j + 1 - xs1 = x.argmax(dim=-1) - xs2 = x.max(dim=-1).indices - for j in range(7): - self.assertEqual(xs1[j].item(), i) - self.assertEqual(xs2[j].item(), i) - for i in range(1, 100): - x.zero_() - for j in range(7): - x[j][-i] = j + 1 - xs1 = x.argmax(dim=-1) - xs2 = x.max(dim=-1).indices - for j in range(7): - self.assertEqual(xs1[j].item(), size[1] - i) - self.assertEqual(xs2[j].item(), size[1] - i) - # 2D case: min/argmin - size = (7, 1024 * 1024 + 3) - x = torch.zeros(size, dtype=dtype, device=device) - for i in range(100): - x.zero_() - for j in range(7): - x[j][i] = -(j + 1) - xs1 = x.argmin(dim=-1) - xs2 = x.min(dim=-1).indices - for j in range(7): - self.assertEqual(xs1[j].item(), i) - self.assertEqual(xs2[j].item(), i) - for i in range(1, 100): - x.zero_() - for j in range(7): - x[j][-i] = -(j + 1) - xs1 = x.argmin(dim=-1) - xs2 = x.min(dim=-1).indices - for j in range(7): - self.assertEqual(xs1[j].item(), size[1] - i) - self.assertEqual(xs2[j].item(), size[1] - i) - - @onlyCUDA - @dtypes(torch.half, torch.float, torch.double) - def test_reduction_vectorize_along_output(self, device, dtype): - def run_test(input_): - M, N = input_.shape - input_.zero_() - for i in range(min(M, N)): - input_[i][i] = 1 - output1 = input_.argmax(dim=0) - output2 = input_.sum(dim=0) - for i in range(min(M, N)): - self.assertEqual(output1[i], i) - self.assertEqual(output2[i], 1) - # vec 4 - run_test(torch.zeros(64, 64, dtype=dtype, device=device)) - # vec 2 - run_test(torch.zeros(64 * 64 + 2, dtype=dtype, device=device)[2:].view(64, 64)) - run_test(torch.zeros(64, 62, dtype=dtype, device=device)) - run_test(torch.zeros(64, 2, dtype=dtype, device=device)) - # vec 1 - run_test(torch.zeros(64 * 64 + 1, dtype=dtype, device=device)[1:].view(64, 64)) - run_test(torch.zeros(64, 61, dtype=dtype, device=device)) - run_test(torch.zeros(64, 1, dtype=dtype, device=device)) - - @slowTest - def test_argminmax_large_axis(self, device): - # Regression test for gh-32863 - x = torch.zeros(2**31, device=device, dtype=torch.int8) - x[-1] = 1 - self.assertEqual(x.argmax(0), x.shape[0] - 1) - self.assertEqual(x.max(0).indices, x.shape[0] - 1) - x[-1] = -1 - self.assertEqual(x.argmin(0), x.shape[0] - 1) - self.assertEqual(x.min(0).indices, x.shape[0] - 1) - - def test_argminmax_axis_with_dim_one(self, device): - # See: https://github.com/pytorch/pytorch/issues/38922 - n = 32768 - x = torch.zeros(1, n) - self.assertEqual(x.argmax(dim=0), torch.zeros(n, dtype=torch.int64)) - self.assertEqual(x.argmin(dim=0), torch.zeros(n, dtype=torch.int64)) - - self.assertEqual(x.argmax(dim=-2), torch.zeros(n, dtype=torch.int64)) - self.assertEqual(x.argmin(dim=-2), torch.zeros(n, dtype=torch.int64)) - - self.assertEqual(x.argmax(dim=0, keepdim=True), torch.zeros(1, n, dtype=torch.int64)) - self.assertEqual(x.argmin(dim=0, keepdim=True), torch.zeros(1, n, dtype=torch.int64)) - - self.assertEqual(x.argmax(dim=-2, keepdim=True), torch.zeros(1, n, dtype=torch.int64)) - self.assertEqual(x.argmin(dim=-2, keepdim=True), torch.zeros(1, n, dtype=torch.int64)) - - def test_remainder_overflow(self, device): - # Check Integer Overflows - x = torch.tensor(23500, dtype=torch.int64, device=device) - q = 392486996410368 - self.assertEqual(x % q, x) - self.assertEqual(-x % q, q - x) - self.assertEqual(x % -q, x - q) - self.assertEqual(-x % -q, -x) - - def test_rpow(self, device): - m = torch.randn(10, 10, device=device) - self.assertEqual(torch.pow(2, m), 2**m) - - # test with scalar - m = torch.randn(1, device=device).squeeze() - assert m.dim() == 0, "m is intentionally a scalar" - self.assertEqual(torch.pow(2, m), 2**m) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_symeig(self, device, dtype): - from torch.testing._internal.common_utils import random_symmetric_matrix - - def run_test(dims, eigenvectors, upper): - x = random_symmetric_matrix(*dims, dtype=dtype, device=device) - oute = torch.empty(dims[1:] + dims[:1], dtype=dtype, device=device) - outv = torch.empty(dims[1:] + dims[:1] * 2, dtype=dtype, device=device) - torch.symeig(x, eigenvectors=eigenvectors, upper=upper, out=(oute, outv)) - - if eigenvectors: - x_recon = torch.matmul(torch.matmul(outv, torch.diag_embed(oute)), outv.transpose(-2, -1)) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using V @ diag(e) @ V.T') - else: - eigvals, _ = torch.symeig(x, eigenvectors=True, upper=upper) - self.assertEqual(eigvals, oute, msg='Eigenvalues mismatch') - self.assertEqual(torch.empty(0, device=device, dtype=dtype), outv, msg='Eigenvector matrix not empty') - - rese, resv = x.symeig(eigenvectors=eigenvectors, upper=upper) - self.assertEqual(rese, oute, msg="outputs of symeig and symeig with out don't match") - self.assertEqual(resv, outv, msg="outputs of symeig and symeig with out don't match") - - # test non-contiguous - x = random_symmetric_matrix(*dims, dtype=dtype, device=device) - n_dim = len(dims) + 1 - # Reverse the batch dimensions and the matrix dimensions and then concat them - x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) - assert not x.is_contiguous(), "x is intentionally non-contiguous" - rese, resv = torch.symeig(x, eigenvectors=eigenvectors, upper=upper) - if eigenvectors: - x_recon = torch.matmul(torch.matmul(resv, torch.diag_embed(rese)), resv.transpose(-2, -1)) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using V @ diag(e) @ V.T') - else: - eigvals, _ = torch.symeig(x, eigenvectors=True, upper=upper) - self.assertEqual(eigvals, rese, msg='Eigenvalues mismatch') - self.assertEqual(torch.empty(0, device=device, dtype=dtype), resv, msg='Eigenvector matrix not empty') - - batch_dims_set = [(), (3,), (3, 5), (5, 3, 5)] - for batch_dims, eigenvectors, upper in product(batch_dims_set, (True, False), (True, False)): - run_test((5,) + batch_dims, eigenvectors, upper) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_svd(self, device, dtype): - def run_test(dims, some, compute_uv): - x = torch.randn(*dims, dtype=dtype, device=device) - outu = torch.tensor((), dtype=dtype, device=device) - outs = torch.tensor((), dtype=dtype, device=device) - outv = torch.tensor((), dtype=dtype, device=device) - torch.svd(x, some=some, compute_uv=compute_uv, out=(outu, outs, outv)) - - if compute_uv: - if some: - x_recon = torch.matmul(outu, torch.matmul(outs.diag_embed(), outv.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - narrow_u = outu[..., :min(*dims[-2:])] - narrow_v = outv[..., :min(*dims[-2:])] - x_recon = torch.matmul(narrow_u, torch.matmul(outs.diag_embed(), narrow_v.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - _, singvals, _ = torch.svd(x, compute_uv=True) - self.assertEqual(singvals, outs, msg='Singular values mismatch') - self.assertEqual(outu, torch.zeros_like(outu), msg='U not zero') - self.assertEqual(outv, torch.zeros_like(outv), msg='V not zero') - - resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) - self.assertEqual(resu, outu, msg='outputs of svd and svd with out differ') - self.assertEqual(ress, outs, msg='outputs of svd and svd with out differ') - self.assertEqual(resv, outv, msg='outputs of svd and svd with out differ') - - # test non-contiguous - x = torch.randn(*dims, dtype=dtype, device=device) - n_dim = len(dims) - # Reverse the batch dimensions and the matrix dimensions and then concat them - x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) - assert not x.is_contiguous(), "x is intentionally non-contiguous" - resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) - if compute_uv: - if some: - x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), resv.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - narrow_u = resu[..., :min(*dims[-2:])] - narrow_v = resv[..., :min(*dims[-2:])] - x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - _, singvals, _ = torch.svd(x, compute_uv=True) - self.assertEqual(singvals, ress, msg='Singular values mismatch') - self.assertEqual(resu, torch.zeros_like(resu), msg='U not zero') - self.assertEqual(resv, torch.zeros_like(resv), msg='V not zero') - - shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices - (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices - (3, 7), (5, 3, 7), (7, 5, 3, 7)] # thin matrices - for dims, some, compute_uv in product(shapes, [True, False], [True, False]): - run_test(dims, some, compute_uv) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_svd_no_singularvectors(self, device): - for size in [(5, 5), (5, 20), (20, 5)]: - a = torch.randn(*size, device=device) - u, s_expect, v = torch.svd(a) - u, s_actual, v = torch.svd(a, compute_uv=False) - self.assertEqual(s_expect, s_actual, msg="Singular values don't match") - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_svd_lowrank(self, device): - import torch - from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix - - dtype = torch.double - - def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options): - density = options.pop('density', 1) - if isinstance(matrix_size, int): - rows = columns = matrix_size - else: - rows, columns = matrix_size - if density == 1: - a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype) - a = a_input - else: - assert batches == () - a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype) - a = a_input.to_dense() - - q = min(*size) - u, s, v = svd_lowrank(a_input, q=q, **options) - - # check if u, s, v is a SVD - u, s, v = u[..., :q], s[..., :q], v[..., :q] - A = u.matmul(s.diag_embed()).matmul(v.transpose(-2, -1)) - self.assertEqual(A, a) - - # check if svd_lowrank produces same singular values as torch.svd - U, S, V = torch.svd(a) - self.assertEqual(s.shape, S.shape) - self.assertEqual(u.shape, U.shape) - self.assertEqual(v.shape, V.shape) - self.assertEqual(s, S) - - if density == 1: - # actual_rank is known only for dense inputs - # - # check if pairs (u, U) and (v, V) span the same - # subspaces, respectively - u, s, v = u[..., :actual_rank], s[..., :actual_rank], v[..., :actual_rank] - U, S, V = U[..., :actual_rank], S[..., :actual_rank], V[..., :actual_rank] - self.assertEqual(u.transpose(-2, -1).matmul(U).det().abs(), torch.ones(batches, device=device, dtype=dtype)) - self.assertEqual(v.transpose(-2, -1).matmul(V).det().abs(), torch.ones(batches, device=device, dtype=dtype)) - - all_batches = [(), (1,), (3,), (2, 3)] - for actual_rank, size, all_batches in [ - (2, (17, 4), all_batches), - (4, (17, 4), all_batches), - (4, (17, 17), all_batches), - (10, (100, 40), all_batches), - (7, (1000, 1000), [()]), - ]: - # dense input - for batches in all_batches: - run_subtest(actual_rank, size, batches, device, torch.svd_lowrank) - if size != size[::-1]: - run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank) - - # sparse input - for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 1000)]: - for density in [0.005, 0.1]: - run_subtest(None, size, (), device, torch.svd_lowrank, density=density) - - # jitting support - jitted = torch.jit.script(torch.svd_lowrank) - actual_rank, size, batches = 2, (17, 4), () - run_subtest(actual_rank, size, batches, device, jitted) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_pca_lowrank(self, device): - from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix - - dtype = torch.double - - def run_subtest(guess_rank, actual_rank, matrix_size, batches, device, pca, **options): - density = options.pop('density', 1) - if isinstance(matrix_size, int): - rows = columns = matrix_size - else: - rows, columns = matrix_size - if density == 1: - a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype) - a = a_input - else: - a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype) - a = a_input.to_dense() - - u, s, v = pca(a_input, q=guess_rank, **options) - - self.assertEqual(s.shape[-1], guess_rank) - self.assertEqual(u.shape[-2], rows) - self.assertEqual(u.shape[-1], guess_rank) - self.assertEqual(v.shape[-1], guess_rank) - self.assertEqual(v.shape[-2], columns) - - A1 = u.matmul(s.diag_embed()).matmul(v.transpose(-2, -1)) - ones_m1 = torch.ones(batches + (rows, 1), dtype=a.dtype, device=device) - c = a.sum(axis=-2) / rows - c = c.reshape(batches + (1, columns)) - A2 = a - ones_m1.matmul(c) - self.assertEqual(A1, A2) - - if density == 1: - # actual rank is known only for dense input - detect_rank = (s.abs() > 1e-5).sum(axis=-1) - self.assertEqual(actual_rank * torch.ones(batches, device=device, dtype=torch.int64), detect_rank) - U, S, V = torch.svd(A2) - self.assertEqual(s[..., :actual_rank], S[..., :actual_rank]) - - all_batches = [(), (1,), (3,), (2, 3)] - for actual_rank, size, all_batches in [ - (2, (17, 4), all_batches), - (2, (100, 4), all_batches), - (6, (100, 40), all_batches), - (12, (1000, 1000), [()]), - ]: - for batches in all_batches: - for guess_rank in [ - actual_rank, - actual_rank + 2, - actual_rank + 6, - ]: - if guess_rank <= min(*size): - run_subtest(guess_rank, actual_rank, size, batches, device, torch.pca_lowrank) - run_subtest(guess_rank, actual_rank, size[::-1], batches, device, torch.pca_lowrank) - - # sparse input - for guess_rank, size in [ - (4, (17, 4)), (4, (4, 17)), (16, (17, 17)), - (21, (100, 40)), (20, (40, 100)), (600, (1000, 1000))]: - for density in [0.005, 0.1]: - run_subtest(guess_rank, None, size, (), device, torch.pca_lowrank, density=density) - - # jitting support - jitted = torch.jit.script(torch.pca_lowrank) - guess_rank, actual_rank, size, batches = 2, 2, (17, 4), () - run_subtest(guess_rank, actual_rank, size, batches, device, jitted) - - def test_lerp(self, device): - start_end_shapes = [(), (5,), (5, 5), (5, 5, 5)] - for shapes in product(start_end_shapes, start_end_shapes): - start = torch.randn(shapes[0], device=device) - end = torch.randn(shapes[1], device=device) - - # Tensor weights - for weight in [torch.randn(shapes[0], device=device), random.random()]: - actual = torch.lerp(start, end, weight) - actual_method = start.lerp(end, weight) - self.assertEqual(actual, actual_method) - actual_out = torch.Tensor().to(device) - torch.lerp(start, end, weight, out=actual_out) - self.assertEqual(actual, actual_out) - expected = start + weight * (end - start) - self.assertEqual(expected, actual) - - def _test_logaddexp(self, device, dtype, base2): - if base2: - ref_func = np.logaddexp2 - our_func = torch.logaddexp2 - else: - ref_func = np.logaddexp - our_func = torch.logaddexp - - def _test_helper(a, b): - ref = ref_func(a.cpu().numpy(), b.cpu().numpy()) - v = our_func(a, b) - self.assertEqual(ref, v) - - # simple test - a = torch.randn(64, 2, dtype=dtype, device=device) - 0.5 - b = torch.randn(64, 2, dtype=dtype, device=device) - 0.5 - _test_helper(a, b) - _test_helper(a[:3], b[:3]) - - # large value test for numerical stability - a *= 10000 - b *= 10000 - _test_helper(a, b) - _test_helper(a[:3], b[:3]) - - a = torch.tensor([float('inf'), float('-inf'), float('inf'), float("nan")], dtype=dtype, device=device) - b = torch.tensor([float('inf'), float('-inf'), float('-inf'), float("nan")], dtype=dtype, device=device) - _test_helper(a, b) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - @dtypes(torch.float32, torch.float64) - def test_logaddexp(self, device, dtype): - self._test_logaddexp(device, dtype, base2=False) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - @dtypes(torch.float32, torch.float64) - def test_logaddexp2(self, device, dtype): - self._test_logaddexp(device, dtype, base2=True) - - def test_diagflat(self, device): - dtype = torch.float32 - # Basic sanity test - x = torch.randn((100,), dtype=dtype, device=device) - result = torch.diagflat(x) - expected = torch.diag(x) - self.assertEqual(result, expected) - - # Test offset - x = torch.randn((100,), dtype=dtype, device=device) - result = torch.diagflat(x, 17) - expected = torch.diag(x, 17) - self.assertEqual(result, expected) - - # Test where input has more than one dimension - x = torch.randn((2, 3, 4), dtype=dtype, device=device) - result = torch.diagflat(x) - expected = torch.diag(x.contiguous().view(-1)) - self.assertEqual(result, expected) - - # Noncontig input - x = torch.randn((2, 3, 4), dtype=dtype, device=device).transpose(2, 0) - self.assertFalse(x.is_contiguous()) - result = torch.diagflat(x) - expected = torch.diag(x.contiguous().view(-1)) - self.assertEqual(result, expected) - - # Ensure that nuclear_norm's out variant gives the same result as the non-out - @onlyOnCPUAndCUDA - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.float32, torch.float64) - def test_nuclear_norm_out(self, device, dtype): - test_cases = [ - # input size, dim - ((25, 25), None), - ((25, 25), (0, 1)), - ((25, 25), (1, 0)), - ((25, 25, 25), (2, 0)), - ((25, 25, 25), (0, 1)), - ] - for keepdim in [False, True]: - for input_size, dim in test_cases: - msg = f'input_size: {input_size}, dim: {dim}, keepdim: {keepdim}' - x = torch.randn(*input_size, device=device, dtype=dtype) - result_out = torch.empty(0, device=device, dtype=dtype) - if dim is None: - result = torch.nuclear_norm(x, keepdim=keepdim) - torch.nuclear_norm(x, keepdim=keepdim, out=result_out) - else: - result = torch.nuclear_norm(x, keepdim=keepdim, dim=dim) - torch.nuclear_norm(x, keepdim=keepdim, dim=dim, out=result_out) - self.assertEqual(result, result_out, msg=msg) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_norm(self, device): - def gen_error_message(input_size, p, keepdim, dim=None): - return "norm failed for input size %s, p=%s, keepdim=%s, dim=%s" % ( - input_size, p, keepdim, dim) - - for keepdim in [False, True]: - # full reduction - x = torch.randn(25, device=device) - xn = x.cpu().numpy() - for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3, 1.5]: - res = x.norm(p, keepdim=keepdim).cpu() - expected = np.linalg.norm(xn, p, keepdims=keepdim) - self.assertEqual(res, expected, atol=1e-5, rtol=0, msg=gen_error_message(x.size(), p, keepdim)) - - # one dimension - x = torch.randn(25, 25, device=device) - xn = x.cpu().numpy() - for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3]: - dim = 1 - res = x.norm(p, dim, keepdim=keepdim).cpu() - expected = np.linalg.norm(xn, p, dim, keepdims=keepdim) - msg = gen_error_message(x.size(), p, keepdim, dim) - self.assertEqual(res.shape, expected.shape, msg=msg) - self.assertEqual(res, expected, msg=msg) - - # matrix norm - for p in ['fro', 'nuc']: - res = x.norm(p, keepdim=keepdim).cpu() - expected = np.linalg.norm(xn, p, keepdims=keepdim) - msg = gen_error_message(x.size(), p, keepdim) - self.assertEqual(res.shape, expected.shape, msg=msg) - self.assertEqual(res, expected, msg=msg) - - # zero dimensions - x = torch.randn((), device=device) - xn = x.cpu().numpy() - res = x.norm(keepdim=keepdim).cpu() - expected = np.linalg.norm(xn, keepdims=keepdim) - msg = gen_error_message(x.size(), None, keepdim) - self.assertEqual(res.shape, expected.shape, msg=msg) - self.assertEqual(res, expected, msg=msg) - - # larger tensor sanity check - self.assertEqual( - 2 * torch.norm(torch.ones(10000), keepdim=keepdim), - torch.norm(torch.ones(40000), keepdim=keepdim)) - - # matrix norm with non-square >2-D tensors, all combinations of reduction dims - x = torch.randn(5, 6, 7, 8, device=device) - xn = x.cpu().numpy() - for p in ['fro', 'nuc']: - for dim in product(*[list(range(4))] * 2): - if dim[0] == dim[1]: - continue - res = x.norm(p=p, dim=dim, keepdim=keepdim).cpu() - expected = np.linalg.norm(xn, ord=p, axis=dim, keepdims=keepdim) - msg = gen_error_message(x.size(), p, keepdim, dim) - self.assertEqual(res.shape, expected.shape, msg=msg) - self.assertEqual(res, expected, msg=msg) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_norm_complex(self, device): - def gen_error_message(input_size, p, keepdim, dim=None): - return "complex norm failed for input size %s, p=%s, keepdim=%s, dim=%s" % ( - input_size, p, keepdim, dim) - - if device == 'cpu': - for keepdim in [False, True]: - # vector norm - x = torch.randn(25, device=device) + 1j * torch.randn(25, device=device) - xn = x.cpu().numpy() - for p in [0, 1, 3, inf, -1, -2, -3, -inf]: - res = x.norm(p, keepdim=keepdim).cpu() - expected = np.linalg.norm(xn, p, keepdims=keepdim) - msg = gen_error_message(x.size(), p, keepdim) - self.assertEqual(res.shape, expected.shape, msg=msg) - self.assertEqual(res, expected, msg=msg) - - # matrix norm - x = torch.randn(25, 25, device=device) + 1j * torch.randn(25, 25, device=device) - xn = x.cpu().numpy() - for p in ['nuc']: - res = x.norm(p, keepdim=keepdim).cpu() - expected = np.linalg.norm(xn, p, keepdims=keepdim) - msg = gen_error_message(x.size(), p, keepdim) - self.assertEqual(res.shape, expected.shape, msg=msg) - self.assertEqual(res, expected, msg=msg) - - # TODO: remove error test and add functionality test above when 2-norm support is added - with self.assertRaisesRegex(RuntimeError, r'norm with p=2 not supported for complex tensors'): - x = torch.randn(2, device=device, dtype=torch.complex64).norm(p=2) - - # TODO: remove error test and add functionality test above when frobenius support is added - with self.assertRaisesRegex(RuntimeError, r'frobenius norm not supported for complex tensors'): - x = torch.randn(2, 2, device=device, dtype=torch.complex64).norm(p='fro') - - elif device == 'cuda': - with self.assertRaisesRegex(RuntimeError, r'"norm_cuda" not implemented for \'ComplexFloat\''): - (1j * torch.randn(25)).norm() - - @skipCUDAIfNoMagma - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_nuclear_norm_axes_small_brute_force(self, device): - def check_single_nuclear_norm(x, axes): - if self.device_type != 'cpu' and randrange(100) < 95: - return # too many cpu <==> device copies - - a = np.array(x.cpu(), copy=False) - expected = np.linalg.norm(a, "nuc", axis=axes) - - ans = torch.norm(x, "nuc", dim=axes) - self.assertTrue(ans.is_contiguous()) - self.assertEqual(ans.shape, expected.shape) - self.assertEqual(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True) - - out = torch.zeros(expected.shape, dtype=x.dtype, device=x.device) - ans = torch.norm(x, "nuc", dim=axes, out=out) - self.assertIs(ans, out) - self.assertTrue(ans.is_contiguous()) - self.assertEqual(ans.shape, expected.shape) - self.assertEqual(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True) - - for n in range(1, 3): - for m in range(1, 3): - for axes in permutations([0, 1], 2): - # 2d, inner dimensions C - x = torch.randn(n, m, device=device) - check_single_nuclear_norm(x, axes) - - # 2d, inner dimensions Fortran - x = torch.randn(m, n, device=device).transpose(-1, -2) - check_single_nuclear_norm(x, axes) - - # 2d, inner dimensions non-contiguous - x = torch.randn(n, 2 * m, device=device)[:, ::2] - check_single_nuclear_norm(x, axes) - - # 2d, all dimensions non-contiguous - x = torch.randn(7 * n, 2 * m, device=device)[::7, ::2] - check_single_nuclear_norm(x, axes) - - for o in range(1, 3): - for axes in permutations([0, 1, 2], 2): - # 3d, inner dimensions C - x = torch.randn(o, n, m, device=device) - check_single_nuclear_norm(x, axes) - - # 3d, inner dimensions Fortran - x = torch.randn(o, m, n, device=device).transpose(-1, -2) - check_single_nuclear_norm(x, axes) - - # 3d, inner dimensions non-contiguous - x = torch.randn(o, n, 2 * m, device=device)[:, :, ::2] - check_single_nuclear_norm(x, axes) - - # 3d, all dimensions non-contiguous - x = torch.randn(7 * o, 5 * n, 2 * m, device=device)[::7, ::5, ::2] - check_single_nuclear_norm(x, axes) - - for r in range(1, 3): - for axes in permutations([0, 1, 2, 3], 2): - # 4d, inner dimensions C - x = torch.randn(r, o, n, m, device=device) - check_single_nuclear_norm(x, axes) - - # 4d, inner dimensions Fortran - x = torch.randn(r, o, n, m, device=device).transpose(-1, -2) - check_single_nuclear_norm(x, axes) - - # 4d, inner dimensions non-contiguous - x = torch.randn(r, o, n, 2 * m, device=device)[:, :, :, ::2] - check_single_nuclear_norm(x, axes) - - # 4d, all dimensions non-contiguous - x = torch.randn(7 * r, 5 * o, 11 * n, 2 * m, device=device)[::7, ::5, ::11, ::2] - check_single_nuclear_norm(x, axes) - - @skipCUDAIfNoMagma - def test_nuclear_norm_exceptions(self, device): - for lst in [], [1], [1, 2]: - x = torch.tensor(lst, dtype=torch.double, device=device) - for axes in (), (0,): - self.assertRaises(RuntimeError, torch.norm, x, "nuc", axes) - self.assertRaises(IndexError, torch.norm, x, "nuc", (0, 1)) - - x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.double, device=device) - self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0)) - self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) - - def test_embedding_scalar_weight_error(self, device): - indices = torch.rand(2, 2, device=device).long() - weight = torch.tensor(1.0) - with self.assertRaisesRegex(RuntimeError, "'weight' must be at least 1-D"): - torch.embedding(weight, indices) - - def test_dist(self, device): - def run_test(x, y): - for p in [0, 1, 2, 3, 4, inf, -inf]: - dist_xy = torch.dist(x, y, p) - dist_xy_norm = torch.norm(x - y, p) - self.assertEqual(dist_xy, dist_xy_norm) - - run_test(torch.randn(5, device=device), torch.randn(5, device=device)) - - x = torch.zeros(3, device=device) - y = torch.zeros(3, device=device) - y[1] = 1. - run_test(x, y) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_geqrf(self, device): - a = torch.randn(5, 5, device=device) - b, c = torch.geqrf(a) - b_placeholder, c_placeholder = torch.empty_like(b), torch.empty_like(c) - torch.geqrf(a, out=(b_placeholder, c_placeholder)) - self.assertEqual(b, b_placeholder) - self.assertEqual(c, c_placeholder) - - def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular, - device, dtype): - triangle_function = torch.triu if upper else torch.tril - b = torch.randn(*b_dims, dtype=dtype, device=device) - A = torch.randn(*A_dims, dtype=dtype, device=device) - A_triangular = triangle_function(A) - if unitriangular: - A_triangular.diagonal(dim1=-2, dim2=-1).fill_(1.) - return b, A_triangular - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_triangular_solve(self, device, dtype): - for (k, n), (upper, unitriangular, transpose) in product(zip([2, 3, 5], [3, 5, 7]), - product([True, False], repeat=3)): - b, A = self.triangular_solve_test_helper((n, n), (n, k), upper, - unitriangular, device, dtype) - x = torch.triangular_solve(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0] - if transpose: - self.assertLessEqual(b.dist(A.t().mm(x)), 4e-12) - else: - self.assertLessEqual(b.dist(A.mm(x)), 4e-12) - - @skipCPUIfNoLapack - @skipCUDAIfNoMagma - @dtypes(torch.double) - def test_triangular_solve_batched(self, device, dtype): - def triangular_solve_batch_helper(A_dims, b_dims, upper, unitriangular, transpose): - b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper, - unitriangular, device, dtype) - x_exp_list = [] - for i in range(b_dims[0]): - x_exp_list.append(torch.triangular_solve(b[i], A[i], upper=upper, - unitriangular=unitriangular, - transpose=transpose)[0]) - x_exp = torch.stack(x_exp_list) # Stacked output - x_act = torch.triangular_solve(b, A, upper=upper, - unitriangular=unitriangular, - transpose=transpose)[0] # Actual output - self.assertEqual(x_act, x_exp) # Equality check - if transpose: - self.assertLessEqual(b.dist(torch.matmul(A.transpose(-2, -1), x_act)), 3e-12) # Correctness check - else: - self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 3e-12) # Correctness check - - for (upper, unitriangular, transpose), batchsize in product(product([True, False], repeat=3), [1, 3, 4]): - triangular_solve_batch_helper((batchsize, 5, 5), (batchsize, 5, 10), - upper, unitriangular, transpose) - - - @slowTest - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_triangular_solve_batched_many_batches(self, device, dtype): - for upper, transpose, unitriangular in product([True, False], repeat=3): - b, A = self.triangular_solve_test_helper((256, 256, 5, 5), (5, 1), - upper, unitriangular, device, dtype) - x, _ = torch.triangular_solve(b, A, - upper=upper, transpose=transpose, unitriangular=unitriangular) - if transpose: - A = A.transpose(-2, -1) - self.assertEqual(torch.matmul(A, x), b.expand(A.shape[:-2] + (5, 1))) - - b, A = self.triangular_solve_test_helper((3, 3), (512, 512, 3, 1), - upper, unitriangular, device, dtype) - x, _ = torch.triangular_solve(b, A, upper=upper, transpose=transpose, - unitriangular=unitriangular) - if transpose: - A = A.transpose(-2, -1) - self.assertEqual(torch.matmul(A, x), b) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @unittest.skipIf(not TEST_SCIPY, "SciPy not found") - @dtypes(torch.double) - def test_triangular_solve_batched_broadcasting(self, device, dtype): - from scipy.linalg import solve_triangular as tri_solve - - def scipy_tri_solve_batched(A, B, upper, trans, diag): - batch_dims_A, batch_dims_B = A.shape[:-2], B.shape[:-2] - single_dim_A, single_dim_B = A.shape[-2:], B.shape[-2:] - expand_dims = tuple(torch._C._infer_size(torch.Size(batch_dims_A), - torch.Size(batch_dims_B))) - expand_A = np.broadcast_to(A, expand_dims + single_dim_A) - expand_B = np.broadcast_to(B, expand_dims + single_dim_B) - flat_A = expand_A.reshape((-1,) + single_dim_A) - flat_B = expand_B.reshape((-1,) + single_dim_B) - flat_X = np.vstack([tri_solve(a, b, lower=(not upper), trans=int(trans), unit_diagonal=diag) - for a, b in zip(flat_A, flat_B)]) - return flat_X.reshape(expand_B.shape) - - def run_test(A_dims, b_dims, device, upper, transpose, unitriangular): - b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper, - unitriangular, device, dtype) - x_exp = torch.as_tensor(scipy_tri_solve_batched(A.cpu().numpy(), b.cpu().numpy(), - upper, transpose, unitriangular)) - x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0] - - self.assertEqual(x, x_exp.to(device)) - - for upper, transpose, unitriangular in product([True, False], repeat=3): - # test against scipy.linalg.solve_triangular - run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), device, upper, transpose, unitriangular) # no broadcasting - run_test((2, 1, 3, 4, 4), (4, 6), device, upper, transpose, unitriangular) # broadcasting b - run_test((4, 4), (2, 1, 3, 4, 2), device, upper, transpose, unitriangular) # broadcasting A - run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device, upper, transpose, unitriangular) # broadcasting A & b - - @onlyCPU - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_triangular_solve_singular(self, device, dtype): - b = torch.rand(3, 1, device=device) - A = torch.eye(3, 3, device=device) - A[-1, -1] = 0 # Now A is singular - err_str = r"triangular_solve_cpu: U\(3,3\) is zero, singular U\." - with self.assertRaisesRegex(RuntimeError, err_str): - torch.triangular_solve(b, A) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_lstsq(self, device, dtype): - def _test_underdetermined(a, b, expectedNorm): - # underdetermined systems are only supported on CPU - if self.device_type != 'cpu': - return - - m = a.size()[0] - n = a.size()[1] - assert(m <= n) - - a_copy = a.clone() - b_copy = b.clone() - res1 = torch.lstsq(b, a)[0] - self.assertEqual(a, a_copy, atol=0, rtol=0) - self.assertEqual(b, b_copy, atol=0, rtol=0) - self.assertEqual((torch.mm(a, res1) - b).norm(), expectedNorm, atol=1e-8, rtol=0) - - ta = torch.tensor((), dtype=dtype, device=device) - tb = torch.tensor((), dtype=dtype, device=device) - res2 = torch.lstsq(b, a, out=(tb, ta))[0] - self.assertEqual(a, a_copy, atol=0, rtol=0) - self.assertEqual(b, b_copy, atol=0, rtol=0) - self.assertEqual((torch.mm(a, res1) - b).norm(), expectedNorm, atol=1e-8, rtol=0) - - res3 = torch.lstsq(b, a, out=(b, a))[0] - self.assertEqual((torch.mm(a_copy, b) - b_copy).norm(), expectedNorm, atol=1e-8, rtol=0) - self.assertEqual(res1, tb, atol=0, rtol=0) - self.assertEqual(res1, b, atol=0, rtol=0) - self.assertEqual(res1, res2, atol=0, rtol=0) - self.assertEqual(res1, res3, atol=0, rtol=0) - - def _test_overdetermined(a, b, expectedNorm): - m = a.size()[0] - n = a.size()[1] - assert(m > n) - - def check_norm(a, b, expected_norm, gels_result): - # Checks |ax - b| and the residual info from the result - - # The first n rows is the least square solution. - # Rows n to m-1 contain residual information. - x = gels_result[:n] - resid_info = gels_result[n:] - - resid_norm = (torch.mm(a, x) - b).norm() - self.assertEqual(resid_norm, expectedNorm, atol=1e-8, rtol=0) - self.assertEqual(resid_info.norm(), resid_norm, atol=1e-8, rtol=0) - - a_copy = a.clone() - b_copy = b.clone() - res1 = torch.lstsq(b, a)[0] - self.assertEqual(a, a_copy, atol=0, rtol=0) - self.assertEqual(b, b_copy, atol=0, rtol=0) - check_norm(a, b, expectedNorm, res1) - - ta = torch.tensor((), dtype=dtype, device=device) - tb = torch.tensor((), dtype=dtype, device=device) - res2 = torch.lstsq(b, a, out=(tb, ta))[0] - self.assertEqual(a, a_copy, atol=0, rtol=0) - self.assertEqual(b, b_copy, atol=0, rtol=0) - check_norm(a, b, expectedNorm, res2) - - res3 = torch.lstsq(b, a, out=(b, a))[0] - check_norm(a_copy, b_copy, expectedNorm, res3) - - self.assertEqual(res1, tb, atol=0, rtol=0) - self.assertEqual(res1, b, atol=0, rtol=0) - self.assertEqual(res1, res2, atol=0, rtol=0) - self.assertEqual(res1, res3, atol=0, rtol=0) - - # basic test - expectedNorm = 0 - a = torch.tensor(((1.44, -9.96, -7.55, 8.34), - (-7.84, -0.28, 3.24, 8.09), - (-4.39, -3.24, 6.27, 5.28), - (4.53, 3.83, -6.64, 2.06)), dtype=dtype, device=device).t() - b = torch.tensor(((8.58, 8.26, 8.48, -5.28), - (9.35, -4.43, -0.70, -0.26)), dtype=dtype, device=device).t() - _test_underdetermined(a, b, expectedNorm) - - # test overdetermined - expectedNorm = 17.390200628863 - a = torch.tensor(((1.44, -9.96, -7.55, 8.34, 7.08, -5.45), - (-7.84, -0.28, 3.24, 8.09, 2.52, -5.70), - (-4.39, -3.24, 6.27, 5.28, 0.74, -1.19), - (4.53, 3.83, -6.64, 2.06, -2.47, 4.70)), dtype=dtype, device=device).t() - b = torch.tensor(((8.58, 8.26, 8.48, -5.28, 5.72, 8.93), - (9.35, -4.43, -0.70, -0.26, -7.36, -2.52)), dtype=dtype, device=device).t() - _test_overdetermined(a, b, expectedNorm) - - # test underdetermined - expectedNorm = 0 - a = torch.tensor(((1.44, -9.96, -7.55), - (-7.84, -0.28, 3.24), - (-4.39, -3.24, 6.27), - (4.53, 3.83, -6.64)), dtype=dtype, device=device).t() - b = torch.tensor(((8.58, 8.26, 8.48), - (9.35, -4.43, -0.70)), dtype=dtype, device=device).t() - _test_underdetermined(a, b, expectedNorm) - - # test reuse - expectedNorm = 0 - a = torch.tensor(((1.44, -9.96, -7.55, 8.34), - (-7.84, -0.28, 3.24, 8.09), - (-4.39, -3.24, 6.27, 5.28), - (4.53, 3.83, -6.64, 2.06)), dtype=dtype, device=device).t() - b = torch.tensor(((8.58, 8.26, 8.48, -5.28), - (9.35, -4.43, -0.70, -0.26)), dtype=dtype, device=device).t() - ta = torch.tensor((), dtype=dtype, device=device) - tb = torch.tensor((), dtype=dtype, device=device) - torch.lstsq(b, a, out=(tb, ta)) - self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, atol=1e-8, rtol=0) - torch.lstsq(b, a, out=(tb, ta)) - self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, atol=1e-8, rtol=0) - torch.lstsq(b, a, out=(tb, ta)) - self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, atol=1e-8, rtol=0) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @tf32_on_and_off(0.001) - def test_qr(self, device): - def run_test(tensor_dims, some): - A = torch.randn(*tensor_dims, device=device) - Q, R = torch.qr(A, some=some) - - # Check0: Q[-2:] = (m, n_columns), R[-2:] = (n_columns, n) - m, n = tensor_dims[-2:] - n_columns = m if (not some) and m > n else min(m, n) - self.assertEqual(Q.size(-2), m) - self.assertEqual(R.size(-1), n) - self.assertEqual(Q.size(-1), n_columns) - - # Check1: A = QR - self.assertEqual(A, torch.matmul(Q, R)) - - # Check2: A = QR (with out) - Q_out, R_out = torch.Tensor().to(device), torch.Tensor().to(device) - torch.qr(A, some=some, out=(Q_out, R_out)) - self.assertEqual(A, torch.matmul(Q_out, R_out)) - - # Check3: Q == Q_out, R == R_out - self.assertEqual(Q, Q_out) - self.assertEqual(R, R_out) - - # Check4: Q^{T}Q = I, triu(R) = R - self.assertEqual(torch.matmul(Q.transpose(-2, -1), Q), - torch.eye(n_columns, device=device).expand(Q.shape[:-2] + (n_columns, n_columns))) - self.assertEqual(R.triu(), R) - - tensor_dims_list = [(3, 5), (5, 5), (5, 3), # Single matrix - (7, 3, 5), (7, 5, 5), (7, 5, 3), # 3-dim Tensors - (7, 5, 3, 5), (7, 5, 5, 5), (7, 5, 5, 3)] # 4-dim Tensors - for tensor_dims, some in product(tensor_dims_list, [True, False]): - run_test(tensor_dims, some) - - @onlyOnCPUAndCUDA - @dtypes(torch.float, torch.double) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_quantile(self, device, dtype): - # Generate some random test cases - ops = ['quantile', 'nanquantile'] - inputs = [tuple(np.random.randint(2, 10, size=i)) for i in range(1, 4)] - quantiles = [tuple(np.random.rand(i)) for i in range(0, 5)] - keepdims = [True, False] - - # Add corner cases - inputs.extend([0.75, (1,), (1, 1), (1, 2, 1)]) - inputs.extend([[float('nan')], [[float('nan'), float('nan')], [1, 2]]]) - inputs.extend([[[float('nan'), float('nan')], [float('nan'), 2]]]) - quantiles.extend([0.5, [0., 1.], np.random.rand(10)]) - - # Enumerate all input combinations - for op, x, q, keepdim in product(ops, inputs, quantiles, keepdims): - if type(x) is tuple: - a = torch.randn(x, dtype=dtype, device=device) - # Make some random elements NaN - a.masked_fill_(torch.randint_like(a, 20) == 0, float('nan')) - else: - a = torch.tensor(x, dtype=dtype, device=device) - - q = torch.tensor(q, dtype=dtype, device=device) - - torch_op = getattr(torch, op) - numpy_op = getattr(np, op) - - # Compute quantile along every dimension and flattened tensor - for dim in [None] + list(range(a.ndim)): - result = torch_op(a, q, dim, keepdim) - expected = numpy_op(a.cpu().numpy(), q.cpu().numpy(), dim, keepdims=keepdim) - self.assertEqual(result.cpu(), torch.from_numpy(np.array(expected)).type(result.type())) - - # Test out variation - out = torch.empty_like(result) - torch_op(a, q, dim, keepdim, out=out) - self.assertEqual(out.cpu(), result.cpu()) - - def test_quantile_backward(self, device): - def check(a, q, dim, expected_grad, ops=(torch.quantile, torch.nanquantile)): - for op in ops: - t = torch.tensor(a, device=device, requires_grad=True) - op(t, torch.tensor(q, device=device), dim).sum().backward() - self.assertEqual(t.grad, expected_grad) - - check([1., 2, 3], 0.5, 0, [0, 1, 0]) - check([1., 2, 3, 4], 0.5, 0, [0, 0.5, 0.5, 0]) - check([3., 1, 4, 2], 0.5, 0, [0.5, 0, 0, 0.5]) - check([1., 2, 3, 4], [0.25, 0.5, 0.75], 0, [0.25, 1.25, 1.25, 0.25]) - check([[1., 2], [2, 1]], 0., 0, [[1, 0], [0, 1]]) - check([[1., 2], [4, 3]], 1., 1, [[0, 1], [1, 0]]) - check([1, float('nan'), 2], 0.5, 0, [0, 1, 0], [torch.quantile]) - check([1, float('nan'), 2], 0.5, 0, [0.5, 0, 0.5], [torch.nanquantile]) - - def test_quantile_error(self, device): - def check(a, q, args, kwargs, message): - with self.assertRaisesRegex(RuntimeError, r'quantile\(\) ' + message): - at = torch.tensor(a, device=device) - qt = torch.tensor(q, device=device) if isinstance(q, list) else q - torch.quantile(at, qt, *args, **kwargs) - - check([], 0.5, [], {}, r'input tensor must be non-empty') - check([1.], [[1.]], [], {}, r'q must be a scalar or 1D tensor') - check([1], 0.5, [], {}, r'input tensor must be either float or double dtype') - check([1.], [1], [], {}, r'q tensor must be same dtype as the input tensor') - check([1.], -1., [], {}, r'q must be in the range \[0, 1\] but got -1') - check([1.], 1.1, [], {}, r'q must be in the range \[0, 1\] but got 1.1') - check([1.], 0.5, [], {'out': torch.empty([], dtype=torch.float64, device=device)}, - r'out tensor must be same dtype as the input tensor') - - if self.device_type == "cpu": - check([1.], [0.5, 1.1, -1], [], {}, r'q values must be in the range \[0, 1\]') - - if self.device_type == "cuda": - with self.assertRaisesRegex( - RuntimeError, r'quantile\(\) q tensor must be on the same device as the input tensor'): - torch.randn(1, device=device).quantile(torch.tensor(0.5)) - with self.assertRaisesRegex( - RuntimeError, r'quantile\(\) out tensor must be on the same device as the input tensor'): - torch.quantile(torch.randn(1, device=device), 0.5, out=torch.scalar_tensor(1)) - - - def test_random_neg_values(self, device): - signed_dtypes = [torch.double, torch.float, torch.long, torch.int, torch.short] - for dtype in signed_dtypes: - res = torch.rand(SIZE, SIZE).to(device=device, dtype=dtype) - res.random_(-10, -1) - self.assertLessEqual(res.max().item(), 9) - self.assertGreaterEqual(res.min().item(), -10) - - @slowTest - def test_triu_tril(self, device): - def gen_mask(shape, diagonal, device, upper): - mask = torch.zeros(*shape[-2:]).byte() - for i in range(shape[-2]): - for j in range(shape[-1]): - cond = j - i < diagonal if upper else j - i > diagonal - if cond: - mask[i, j] = 1 - return mask.expand(*shape).to(device) - - torch_functions = {True: torch.triu, False: torch.tril} - if TEST_NUMPY: - numpy_functions = {True: np.triu, False: np.tril} - - # TODO: remove this when bool and half are supported for torch.where - def bool_half_compat_where(pred, true_tensor, false_tensor, dtype): - if dtype == torch.bool or dtype == torch.half: - return torch.where(pred.byte(), true_tensor.byte(), false_tensor.byte()).to(dtype=dtype) - else: - return torch.where(pred, true_tensor, false_tensor) - - def run_test(shape, device, diagonal, dtype): - x = torch.empty(*shape, device=device, dtype=dtype).fill_(2) - - for upper in [True, False]: - # normal test with mask - torch_tri_func = torch_functions[upper] - res1 = torch_tri_func(x, diagonal=diagonal) - res2 = torch.empty(0, device=device, dtype=dtype) - torch_tri_func(x, diagonal=diagonal, out=res2) - exp_mask = gen_mask(shape, diagonal, device, upper) - expected = bool_half_compat_where(exp_mask, torch.tensor(0).type_as(x), x, dtype) - self.assertEqual(res1, res2, atol=0, rtol=0) - self.assertEqual(expected, res1, atol=0, rtol=0) - - # non-contiguous and expanded tensors test - if 0 not in shape: - for s in range(-len(shape), -1): - # non-contiguous tensors - x_nc = x.clone().transpose(s, s + 1) - exp_mask = gen_mask(x_nc.size(), diagonal, device, upper) - if 1 not in shape: - assert not x_nc.is_contiguous(), "x is intentionally non-contiguous" - exp_nc = bool_half_compat_where(exp_mask, torch.tensor(0).type_as(x), x_nc, dtype) - self.assertEqual(torch_tri_func(x_nc, diagonal), exp_nc, atol=0, rtol=0) - x_nc_is_contiguous = x_nc.is_contiguous() - if upper: - self.assertEqual(x_nc.triu_(diagonal), exp_nc, atol=0, rtol=0) - else: - self.assertEqual(x_nc.tril_(diagonal), exp_nc, atol=0, rtol=0) - - self.assertTrue(x_nc.is_contiguous() == x_nc_is_contiguous, - "contiguity of x_nc should not be changed") - - # expanded tensors - expanded_size = (x.size(0),) + x.size() - x_expanded = x.clone().expand(*expanded_size) - if x.size(0) != 1: - assert 0 in x_expanded.stride(), "x intentionally has 0 in its stride" - output = torch_tri_func(x_expanded, diagonal) - self.assertEqual(output, expected.expand(expanded_size), atol=0, rtol=0) - if x.size(0) != 1: - self.assertTrue(0 in x_expanded.stride(), - "geometry of x_expanded should be the same") - if upper: - self.assertEqual(output, x_expanded.triu_(diagonal), atol=0, rtol=0) - else: - self.assertEqual(output, x_expanded.tril_(diagonal), atol=0, rtol=0) - - if not TEST_NUMPY: - continue - - # numpy test - numpy_tri_func = numpy_functions[upper] - self.assertEqual(numpy_tri_func(x.to('cpu').numpy(), diagonal), res1.cpu().numpy()) - - diagonals = [-2, -1, 0, 1, 2] - shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices - (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices - (3, 7), (5, 3, 7), (7, 5, 3, 7), # thin matrices - (3, 0), (0, 3, 3), (3, 3, 0, 0), # no numel matrices - (3, 1), (5, 3, 1), (7, 5, 3, 1), # very fat matrices - (1, 3), (5, 1, 3), (7, 5, 1, 3), # very thin matrices - (1, 3, 3, 3), (3, 1, 3, 3, 3)] # unsqueezed batch dimensions - dtypes = [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.bfloat16] - for s, d, dtype in product(shapes, diagonals, dtypes): - run_test(s, device, d, dtype) - - @skipCUDANonDefaultStreamIf(True) - def test_multinomial_alias(self, device): - # Get probs vector to use in setup - def get_probs(length, is_contiguous): - probs = torch.softmax(torch.randn(length), 0) - if not is_contiguous: - probs = torch.softmax(torch.randn(length, 2), 0)[:, 1] - assert not (is_contiguous ^ probs.is_contiguous()), "contiguity requirement not met" - return probs.to(device) - - for is_contiguous in [True, False]: - probs = get_probs(4, is_contiguous) - alias_table, prob_table = torch._multinomial_alias_setup(probs) - for n_samples in [-1, 1, 10]: - if n_samples > 0: - samples = torch._multinomial_alias_draw(prob_table, alias_table, n_samples) - self.assertEqual(prob_table.size(), torch.Size([4]), msg="size mismatch: probability table") - self.assertEqual(alias_table.size(), torch.Size([4]), msg="size mismatch: alias table") - self.assertEqual(samples.size(), torch.Size([n_samples]), msg="wrong number of samples") - else: - with self.assertRaisesRegex(RuntimeError, "cannot sample <= 0 samples"): - torch._multinomial_alias_draw(prob_table, alias_table, n_samples) - - with self.assertRaisesRegex(RuntimeError, "expected 1-D"): - probs = probs.view(2, 2) - torch._multinomial_alias_setup(probs) - - with self.assertRaisesRegex(RuntimeError, "expected 1-D"): - a_t, p_t = torch._multinomial_alias_setup(probs) - torch._multinomial_alias_draw(p_t.view(2, 2), a_t.view(2, 2)) - - MAX_SAMPLES = 200000 - for probs in [get_probs(4, True), - torch.tensor([0.8, 0.2], device=device), - torch.tensor([0.7, 0.2, 0.1], device=device)]: - # Check how different the alias distribution and the original distribution are - alias_dist = torch.zeros_like(probs) - alias_table, prob_table = torch._multinomial_alias_setup(probs) - alias_samples = torch._multinomial_alias_draw(prob_table, alias_table, MAX_SAMPLES) - alias_dist = torch.unique(alias_samples, return_counts=True)[1].to(dtype=probs.dtype) / MAX_SAMPLES - self.assertEqual(alias_dist, probs, rtol=0.02, atol=0.0, - msg="Actual: {}\nExpected: {}".format(alias_dist, probs)) - - for probs in [torch.tensor([0.2501, 0.25, 0.2499, 0.25], device=device), - torch.tensor([0.8, 0.199, 0.001], device=device), - torch.tensor([0.25001, 0.25, 0.24999, 0.25], device=device), - torch.tensor([0.33, 0.34, 0.33], device=device), - torch.tensor([0.8, 0.1999, 0.0001], device=device)]: - # Check the difference between the original probabilities and the reconstructed - # probabilities from the alias and probability tables output by _multinomial_alias_setup - alias_table, prob_table = torch._multinomial_alias_setup(probs) - actual = torch.zeros_like(probs) - for i, vals in enumerate(zip(alias_table, prob_table)): - idx, p = vals - actual[i] += p - actual[idx] += 1. - p - actual = actual / len(probs) - self.assertEqual(actual, probs, atol=1e-6, rtol=0) - - # Some special cases - test_cases = [torch.tensor([1.0, 0.0, 0.0], device=device), torch.tensor([0.0, 1.0], device=device)] - for probs in test_cases: - alias_table, prob_table = torch._multinomial_alias_setup(probs) - alias_samples = torch._multinomial_alias_draw(prob_table, alias_table, MAX_SAMPLES) - self.assertEqual(alias_samples.unique(), probs.nonzero().squeeze(-1)) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_lapack_empty(self, device): - # FIXME: these are just a selection of LAPACK functions -- we need a general strategy here. - # The LAPACK functions themselves generally do NOT work with zero sized dimensions, although - # numpy/sci often has a direct wrapper (e.g. lu_factor) and a wrapper that "does the right thing" - # (e.g. lu). We often name our functions identically to the lapack function, so it will take work - # to name / migrate-to better wrappers. - def fn(torchfn, *args): - return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape - for shape in args)) - - # inverse, pinverse - self.assertEqual((0, 0), fn(torch.inverse, (0, 0)).shape) - self.assertEqual((5, 0), fn(torch.pinverse, (0, 5)).shape) - self.assertEqual((0, 5), fn(torch.pinverse, (5, 0)).shape) - self.assertEqual((0, 0), fn(torch.pinverse, (0, 0)).shape) - - # det, logdet, slogdet - self.assertEqual(torch.tensor(1., device=device), fn(torch.det, (0, 0))) - self.assertEqual(torch.tensor(0., device=device), fn(torch.logdet, (0, 0))) - self.assertEqual((torch.tensor(1., device=device), torch.tensor(0., device=device)), - fn(torch.slogdet, (0, 0))) - - # eig, symeig - evalues, evectors = fn(torch.eig, (0, 0), True) - self.assertEqual([(0, 2), (0, 0)], [evalues.shape, evectors.shape]) - evalues, evectors = fn(torch.symeig, (0, 0), True) - self.assertEqual([(0,), (0, 0)], [evalues.shape, evectors.shape]) - - # qr - q, r = fn(torch.qr, (3, 0), True) - self.assertEqual([(3, 0), (0, 0)], [q.shape, r.shape]) - q, r = fn(torch.qr, (0, 3), True) - self.assertEqual([(0, 0), (0, 3)], [q.shape, r.shape]) - q, r = fn(torch.qr, (3, 0), False) - self.assertEqual([(3, 3), (3, 0)], [q.shape, r.shape]) - - # lstsq - self.assertRaises(RuntimeError, lambda: torch.lstsq(torch.randn(0, 0), torch.randn(0, 0))) - self.assertRaises(RuntimeError, lambda: torch.lstsq(torch.randn(0,), torch.randn(0, 0))) - - def test_roll(self, device): - numbers = torch.arange(1, 9, device=device) - - single_roll = numbers.roll(1, 0) - expected = torch.tensor([8, 1, 2, 3, 4, 5, 6, 7], device=device) - self.assertEqual(single_roll, expected, msg="{} did not equal expected result".format(single_roll)) - - roll_backwards = numbers.roll(-2, 0) - expected = torch.tensor([3, 4, 5, 6, 7, 8, 1, 2], device=device) - self.assertEqual(roll_backwards, expected, msg="{} did not equal expected result".format(roll_backwards)) - - data = numbers.view(2, 2, 2) - rolled = data.roll(1, 0) - expected = torch.tensor([5, 6, 7, 8, 1, 2, 3, 4], device=device).view(2, 2, 2) - self.assertEqual(expected, rolled, msg="{} did not equal expected result: {}".format(rolled, expected)) - - data = data.view(2, 4) - # roll a loop until back where started - loop_rolled = data.roll(2, 0).roll(4, 1) - self.assertEqual(data, loop_rolled, msg="{} did not equal the original: {}".format(loop_rolled, data)) - # multiple inverse loops - self.assertEqual(data, data.roll(-20, 0).roll(-40, 1)) - self.assertEqual(torch.tensor([8, 1, 2, 3, 4, 5, 6, 7], device=device), numbers.roll(1, 0)) - - # test non-contiguous - # strided equivalent to numbers.as_strided(size=(4, 2), stride=(1, 4)) - strided = numbers.view(2, 4).transpose(0, 1) - self.assertFalse(strided.is_contiguous(), "this test needs a non-contiguous tensor") - expected = torch.tensor([4, 8, 1, 5, 2, 6, 3, 7]).view(4, 2) - rolled = strided.roll(1, 0) - self.assertEqual(expected, rolled, - msg="non contiguous tensor rolled to {} instead of {} ".format(rolled, expected)) - - # test roll with no dimension specified - expected = numbers.roll(1, 0).view(2, 4) - self.assertEqual(expected, data.roll(1), msg="roll with no dims should flatten and roll.") - self.assertEqual(expected, data.roll(1, dims=None), msg="roll with no dims should flatten and roll.") - - # test roll over multiple dimensions - expected = torch.tensor([[7, 8, 5, 6], [3, 4, 1, 2]], device=device) - double_rolled = data.roll(shifts=(2, -1), dims=(1, 0)) - self.assertEqual(double_rolled, expected, - msg="should be able to roll over two dimensions, got {}".format(double_rolled)) - - self.assertRaisesRegex(RuntimeError, "required", lambda: data.roll(shifts=(), dims=())) - self.assertRaisesRegex(RuntimeError, "required", lambda: data.roll(shifts=(), dims=1)) - # shifts/dims should align - self.assertRaisesRegex(RuntimeError, "align", lambda: data.roll(shifts=(1, 2), dims=(1,))) - self.assertRaisesRegex(RuntimeError, "align", lambda: data.roll(shifts=(1,), dims=(1, 2))) - - # test bool tensor - t = torch.zeros(6, dtype=torch.bool, device=device) - t[0] = True - t[3] = True - self.assertEqual(torch.tensor([False, True, False, False, True, False]), t.roll(1, 0)) - - # test complex tensor - t = torch.tensor([1, 2 + 1j, 3.5, 4. + 2j, 5j, 6.], device=device) - t[0] = 1 + 0.5j - t[3] = 4. - expected = torch.tensor([6., 1 + 0.5j, 2 + 1j, 3.5, 4., 5j], device=device) - self.assertEqual(expected, t.roll(1, 0)) - - def test_nonzero_empty(self, device): - def assert_tuple_empty(tup, dim): - self.assertEqual(dim, len(tup)) - for t in tup: - self.assertEqual(torch.Size([0]), t.shape) - - x = torch.randn(0, 2, 0, 5, 0, device=device) - y = torch.nonzero(x) - z = torch.nonzero(x, as_tuple=True) - - self.assertEqual(0, y.numel()) - self.assertEqual(torch.Size([0, 5]), y.shape) - assert_tuple_empty(z, 5) - - x = torch.tensor(0.5, device=device) - y = torch.nonzero(x) - # nonzero with as_tuple returns a - # tuple of len 1 for a zero-dim tensor. - # This is done to match Numpy behavior. - z = torch.nonzero(x, as_tuple=True) - self.assertEqual(1, len(z)) - self.assertEqual(torch.zeros(1, dtype=torch.long), z[0]) - - x = torch.zeros((), device=device) - y = torch.nonzero(x) - z = torch.nonzero(x, as_tuple=True) - self.assertEqual(torch.Size([0, 0]), y.shape) - self.assertEqual(1, len(z)) - self.assertEqual(torch.empty(0, dtype=torch.long), z[0]) - - @onlyOnCPUAndCUDA - def test_nonzero_deprecated(self, device): - x = torch.randn((2, 3), device=device) - with self.maybeWarnsRegex(UserWarning, "This overload of nonzero is deprecated"): - x.nonzero() - - with self.maybeWarnsRegex(UserWarning, "This overload of nonzero is deprecated"): - torch.nonzero(x) - - # TODO: add torch.complex64, torch.complex128 - @dtypes(torch.float, torch.double) - def test_normal(self, device, dtype): - - def helper(self, device, dtype, ptype, t_transform, std_transform): - q = torch.empty(100, 100, dtype=dtype, device=device) - - q.normal_() - self.assertEqual(t_transform(q).mean(), 0, atol=0.2, rtol=0) - self.assertEqual(t_transform(q).std(), std_transform(1), atol=0.2, rtol=0) - - q.normal_(2, 3) - self.assertEqual(t_transform(q).mean(), 2, atol=0.3, rtol=0) - self.assertEqual(t_transform(q).std(), std_transform(3), atol=0.3, rtol=0) - - q = torch.empty(100, 100, dtype=dtype, device=device) - q_row1 = q[0:1].clone() - q[99:100].normal_() - self.assertEqual(t_transform(q[99:100]).mean(), 0, atol=0.2, rtol=0) - self.assertEqual(t_transform(q[99:100]).std(), std_transform(1), atol=0.2, rtol=0) - self.assertEqual(t_transform(q[0:1]).clone(), t_transform(q_row1)) - - mean = torch.empty(100, 100, dtype=dtype, device=device) - mean[:50].fill_(ptype(0)) - mean[50:].fill_(ptype(1)) - - std = torch.empty(100, 100, dtype=torch.float, device=device) - std[:, :50] = 4 - std[:, 50:] = 1 - - r = torch.normal(mean) - self.assertEqual(r.dtype, dtype) - self.assertEqual(str(r.device), device) - self.assertEqual(t_transform(r[:50]).mean(), 0, atol=0.2, rtol=0) - self.assertEqual(t_transform(r[50:]).mean(), 1, atol=0.2, rtol=0) - self.assertEqual(t_transform(r).std(), std_transform(1), atol=0.2, rtol=0) - - r.fill_(42) - r = torch.normal(mean, 3) - self.assertEqual(r.dtype, dtype) - self.assertEqual(str(r.device), device) - self.assertEqual(t_transform(r[:50]).mean(), 0, atol=0.2, rtol=0) - self.assertEqual(t_transform(r[50:]).mean(), 1, atol=0.2, rtol=0) - self.assertEqual(t_transform(r).std(), std_transform(3), atol=0.2, rtol=0) - - r.fill_(42) - torch.normal(mean, 3, out=r) - self.assertEqual(r.dtype, dtype) - self.assertEqual(str(r.device), device) - self.assertEqual(t_transform(r[:50]).mean(), 0, atol=0.2, rtol=0) - self.assertEqual(t_transform(r[50:]).mean(), 1, atol=0.2, rtol=0) - self.assertEqual(t_transform(r).std(), std_transform(3), atol=0.2, rtol=0) - - r.fill_(42) - r = torch.normal(2, std) - self.assertFalse(r.dtype.is_complex) - self.assertEqual(str(r.device), device) - self.assertEqual(r.mean(), 2, atol=0.2, rtol=0) - self.assertEqual(r[:, :50].std(), 4, atol=0.3, rtol=0) - self.assertEqual(r[:, 50:].std(), 1, atol=0.2, rtol=0) - - r.fill_(42) - torch.normal(2, std, out=r) - self.assertFalse(r.dtype.is_complex) - self.assertEqual(str(r.device), device) - self.assertEqual(r.mean(), 2, atol=0.2, rtol=0) - self.assertEqual(r[:, :50].std(), 4, atol=0.3, rtol=0) - self.assertEqual(r[:, 50:].std(), 1, atol=0.2, rtol=0) - - r.fill_(42) - r = torch.normal(mean, std) - self.assertEqual(r.dtype, dtype) - self.assertEqual(str(r.device), device) - self.assertEqual(t_transform(r[:50]).mean(), 0, atol=0.2, rtol=0) - self.assertEqual(t_transform(r[50:]).mean(), 1, atol=0.2, rtol=0) - self.assertEqual(t_transform(r[:, :50]).std(), std_transform(4), atol=0.3, rtol=0) - self.assertEqual(t_transform(r[:, 50:]).std(), std_transform(1), atol=0.2, rtol=0) - - r.fill_(42) - torch.normal(mean, std, out=r) - self.assertEqual(r.dtype, dtype) - self.assertEqual(str(r.device), device) - self.assertEqual(t_transform(r[:50]).mean(), 0, atol=0.2, rtol=0) - self.assertEqual(t_transform(r[50:]).mean(), 1, atol=0.2, rtol=0) - self.assertEqual(t_transform(r[:, :50]).std(), std_transform(4), atol=0.3, rtol=0) - self.assertEqual(t_transform(r[:, 50:]).std(), std_transform(1), atol=0.2, rtol=0) - - r.fill_(42) - r = torch.normal(2, 3, (100, 100), dtype=dtype, device=device) - self.assertEqual(r.dtype, dtype) - self.assertEqual(str(r.device), device) - self.assertEqual(t_transform(r).mean(), 2, atol=0.3, rtol=0) - self.assertEqual(t_transform(r).std(), std_transform(3), atol=0.3, rtol=0) - - r.fill_(42) - torch.normal(2, 3, (100, 100), dtype=dtype, device=device, out=r) - self.assertEqual(r.dtype, dtype) - self.assertEqual(str(r.device), device) - self.assertEqual(t_transform(r).mean(), 2, atol=0.3, rtol=0) - self.assertEqual(t_transform(r).std(), std_transform(3), atol=0.3, rtol=0) - - if dtype.is_complex: - helper(self, device, dtype, lambda x: complex(x, x), - lambda t: torch.real(t).to(torch.float), lambda mean: mean / math.sqrt(2)) - helper(self, device, dtype, lambda x: complex(x, x), - lambda t: torch.imag(t).to(torch.float), lambda mean: mean / math.sqrt(2)) - self.assertRaisesRegex( - RuntimeError, "normal expects standard deviation to be non-complex", - lambda: torch.normal(0, torch.empty(100, 100, dtype=dtype, device=device))) - out = torch.empty(100, 100, dtype=dtype, device=device) - self.assertRaisesRegex( - RuntimeError, "normal expects standard deviation to be non-complex", - lambda: torch.normal(0, torch.empty(100, 100, dtype=dtype, device=device), out=out)) - else: - helper(self, device, dtype, lambda x: x, lambda t: t, lambda mean: mean) - - @dtypes(torch.float, torch.double, torch.half) - @dtypesIfCUDA(torch.float, torch.double, torch.half, torch.bfloat16) - def test_uniform_from_to(self, device, dtype): - size = 2000 - alpha = 0.1 - - float_min = torch.finfo(torch.float).min - float_max = torch.finfo(torch.float).max - double_min = torch.finfo(torch.double).min - double_max = torch.finfo(torch.double).max - - if dtype == torch.bfloat16: - min_val = -3.389531389251535e+38 - max_val = 3.389531389251535e+38 - else: - min_val = torch.finfo(dtype).min - max_val = torch.finfo(dtype).max - - values = [double_min, float_min, -42, 0, 42, float_max, double_max] - - for from_ in values: - for to_ in values: - t = torch.empty(size, dtype=dtype, device=device) - if not (min_val <= from_ <= max_val) or not (min_val <= to_ <= max_val): - pass - elif to_ < from_: - self.assertRaisesRegex( - RuntimeError, - "uniform_ expects to return", - lambda: t.uniform_(from_, to_) - ) - elif to_ - from_ > max_val: - self.assertRaisesRegex( - RuntimeError, - "uniform_ expects to-from", - lambda: t.uniform_(from_, to_) - ) - else: - t.uniform_(from_, to_) - range_ = to_ - from_ - if not (dtype == torch.bfloat16) and not ( - dtype == torch.half and device == 'cpu') and not torch.isnan(t).all(): - delta = alpha * range_ - double_t = t.to(torch.double) - if range_ == 0: - self.assertTrue(double_t.min() == from_) - self.assertTrue(double_t.max() == to_) - elif dtype == torch.half: - self.assertTrue(from_ <= double_t.min() <= (from_ + delta)) - self.assertTrue((to_ - delta) <= double_t.max() <= to_) - else: - self.assertTrue(from_ <= double_t.min() <= (from_ + delta)) - self.assertTrue((to_ - delta) <= double_t.max() < to_) - - @dtypes(*torch.testing.get_all_fp_dtypes()) - def test_log_normal(self, device, dtype): - a = torch.tensor([10], dtype=dtype, device=device).log_normal_() - self.assertEqual(a.dtype, dtype) - self.assertEqual(a.size(), torch.Size([1])) - - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) - def test_geometric(self, device, dtype): - a = torch.tensor([10], dtype=dtype, device=device).geometric_(0.5) - self.assertEqual(a.dtype, dtype) - self.assertEqual(a.size(), torch.Size([1])) - - @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False))) - @dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_bfloat16=False))) - def test_bernoulli_p(self, device, dtype): - for trivial_p in ([0, 1], [1, 0, 1, 1, 0, 1]): - x = torch.tensor(trivial_p, dtype=dtype, device=device) - self.assertEqual(x.bernoulli().tolist(), trivial_p) - - def isBinary(t): - return torch.ne(t, 0).mul_(torch.ne(t, 1)).sum().item() == 0 - - p = torch.rand(5, 5, dtype=dtype, device=device) - self.assertTrue(isBinary(p.bernoulli())) + p = torch.rand(5, 5, dtype=dtype, device=device) + self.assertTrue(isBinary(p.bernoulli())) p = torch.rand(5, dtype=dtype, device=device).expand(5, 5) self.assertTrue(isBinary(p.bernoulli())) @@ -11052,79 +3650,6 @@ def test_exponential(self, device, dtype): with self.assertRaises(RuntimeError): torch.empty((1,), device=device, dtype=dtype).exponential_(-0.5) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False) + - torch.testing.get_all_complex_dtypes())) - @dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_half=True) + - torch.testing.get_all_complex_dtypes())) - def test_exp(self, device, dtype): - for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()): - a = torch.tensor(v, dtype=dtype, device=device) * torch.arange(18, device=device) / 3 * math.pi - a = a.to(dtype) - if dtype == torch.bfloat16: - with self.assertRaises(TypeError): # compare_with_numpy doesn't support bfloat16 - self.compare_with_numpy(torch.exp, np.exp, a) - return - self.compare_with_numpy(torch.exp, np.exp, a) - - if dtype.is_complex: - inf_real_zero_imag_in = torch.tensor(complex(float('inf'), 0), device=device, dtype=dtype) - inf_real_zero_imag_out = torch.exp(inf_real_zero_imag_in).item() - self.assertTrue(math.isinf(inf_real_zero_imag_out.real)) - if self.device_type == 'cpu': - pass - # These are commented out because it cannot be consistently reproduced. - # This is incorrect. It should be zero. Need fix! - # https://github.com/pytorch/pytorch/issues/40590 - # self.assertNotEqual(inf_real_zero_imag_out.imag, 0) - # This is incorrect. They should equal. Need fix! - # https://github.com/pytorch/pytorch/issues/40590 - # with self.assertRaises(AssertionError): - # self.compare_with_numpy(torch.exp, np.exp, inf_real_zero_imag_in) - else: - self.assertEqual(inf_real_zero_imag_out.imag, 0, atol=0, rtol=0) - self.compare_with_numpy(torch.exp, np.exp, inf_real_zero_imag_in) - - zero_real_inf_imag_in = torch.tensor(complex(0, float('inf')), device=device, dtype=dtype) - zero_real_inf_imag_out = torch.exp(zero_real_inf_imag_in).item() - self.assertTrue(math.isnan(zero_real_inf_imag_out.real)) - self.assertTrue(math.isnan(zero_real_inf_imag_out.imag)) - # Ensure we are notified when NumPy changes its behavior - self.compare_with_numpy(torch.exp, np.exp, zero_real_inf_imag_in) - - inf_real_imag_in = torch.tensor(complex(float('inf'), float('inf')), device=device, dtype=dtype) - inf_real_imag_out = torch.exp(inf_real_imag_in).item() - if self.device_type == 'cpu': - pass - # This is incorrect. Need fix! https://github.com/pytorch/pytorch/issues/40590 - # This is commented out because it cannot be consistently reproduced. - # with self.assertRaises(AssertionError): - # self.compare_with_numpy(torch.exp, np.exp, inf_real_imag_in) - else: - self.assertTrue(math.isinf(inf_real_imag_out.real)) - self.assertTrue(math.isnan(inf_real_imag_out.imag)) - self.compare_with_numpy(torch.exp, np.exp, inf_real_imag_in) - - inf_real_nan_imag_in = torch.tensor(complex(float('inf'), float('nan')), device=device, dtype=dtype) - inf_real_nan_imag_out = torch.exp(inf_real_nan_imag_in).item() - if self.device_type == 'cpu': - pass - # This is incorrect. It should be inf. Need fix! https://github.com/pytorch/pytorch/issues/40590 - # This is commented out because it cannot be consistently reproduced. - # with self.assertRaises(AssertionError): - # self.compare_with_numpy(torch.exp, np.exp, inf_real_nan_imag_in) - else: - self.assertTrue(math.isinf(inf_real_nan_imag_out.real)) - self.assertTrue(math.isnan(inf_real_nan_imag_out.imag)) - self.compare_with_numpy(torch.exp, np.exp, inf_real_nan_imag_in) - - nan_real_inf_imag_in = torch.tensor(complex(float('nan'), float('inf')), device=device, dtype=dtype) - nan_real_inf_imag_out = torch.exp(nan_real_inf_imag_in).item() - self.assertTrue(math.isnan(nan_real_inf_imag_out.real)) - self.assertTrue(math.isnan(nan_real_inf_imag_out.imag)) - # Ensure we are notified when NumPy changes its behavior - self.compare_with_numpy(torch.exp, np.exp, nan_real_inf_imag_in) - @skipIfNoSciPy @dtypes(*torch.testing.get_all_fp_dtypes()) def test_uniform_kstest(self, device, dtype): @@ -11185,7 +3710,6 @@ def test_cauchy_kstest(self, device, dtype): self.assertTrue(res.statistic < 0.1) @skipIfNoSciPy - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) def test_geometric_kstest(self, device, dtype): from scipy import stats @@ -11197,196 +3721,6 @@ def test_geometric_kstest(self, device, dtype): res = stats.chisquare(actual, expected) self.assertEqual(res.pvalue, 1.0, atol=0.1, rtol=0) - def test_sign(self, device): - for dtype in torch.testing.get_all_math_dtypes(device): - if dtype.is_complex: - continue - - # Include NaN for floating point numbers - if dtype.is_floating_point: - dt_info = torch.finfo(dtype) - - # Create tensor (with NaN checking) - a = torch.tensor([float('nan'), -12, 0, 71, dt_info.min, dt_info.max], device=device, dtype=dtype) - a_target = torch.tensor([0, -1, 0, 1, -1, 1], device=device, dtype=dtype) - - else: - dt_info = torch.iinfo(dtype) - - # If unsigned type, everything should be >= 0 - if dt_info.min == 0: - a = torch.tensor([12, 0, 71, dt_info.min, dt_info.max], device=device, dtype=dtype) - a_target = torch.tensor([1, 0, 1, 0, 1], device=device, dtype=dtype) - else: - a = torch.tensor([-12, 0, 71, dt_info.min, dt_info.max], device=device, dtype=dtype) - a_target = torch.tensor([-1, 0, 1, -1, 1], device=device, dtype=dtype) - - self.assertEqual(a.sign(), a_target, msg='sign device={} dtype={}'.format(device, dtype)) - self.assertEqual(torch.sign(a), a_target, msg='sign device={} dtype={}'.format(device, dtype)) - - out = torch.empty_like(a) - torch.sign(a, out=out) - self.assertEqual(out, a_target, msg='sign_out device={} dtype={}'.format(device, dtype)) - - a.sign_() - self.assertEqual(a, a_target, msg='sign_ device={} dtype={}'.format(device, dtype)) - - # Include test for bool dtype - a_bool = torch.tensor([True, True, False, float('nan')], device=device).bool() - a_bool_target = torch.tensor([True, True, False, True], device=device).bool() - self.assertEqual(a_bool.sign(), a_bool_target, msg='sign device={} dtype=bool'.format(device)) - self.assertEqual(torch.sign(a_bool), a_bool_target, msg='sign device={} dtype=bool'.format(device)) - - a_out = torch.empty_like(a_bool) - torch.sign(a_bool, out=a_out) - self.assertEqual(a_out, a_bool_target, msg='sign_out device={} dtype=bool'.format(device)) - - a_bool.sign_() - self.assertEqual(a_bool, a_bool_target, msg='sign_ device={} dtype=bool'.format(device)) - - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(*(torch.testing.torch.testing.get_all_fp_dtypes())) - def test_signbit_float(self, device, dtype): - t = torch.randn(5, 5, device=device) - - if dtype == torch.bfloat16: - t_bf16 = torch.tensor([1, 0, -1], device=device, dtype=dtype) - self.assertEqual(torch.signbit(t_bf16), torch.tensor([False, False, True])) - else: - self.compare_with_numpy(torch.signbit, np.signbit, t) - - t_target = torch.signbit(t) - out = torch.empty_like(t, device=device, dtype=torch.bool) - torch.signbit(t, out=out) - self.assertEqual(out, t_target) - - t_sp = (0, float('inf'), -float('inf'), float('nan')) - if dtype == torch.bfloat16: - t_sp_df16 = torch.tensor(t_sp, device=device, dtype=dtype) - self.assertEqual(torch.signbit(t_sp_df16), torch.tensor([False, False, True, False])) - else: - self.compare_with_numpy(torch.signbit, np.signbit, t_sp, device, dtype) - - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool])) - def test_signbit_int_and_bool(self, device, dtype): - t = torch.randint(-5, 5, (5, 5), device=device) - self.compare_with_numpy(torch.signbit, np.signbit, t) - - t_target = torch.signbit(t) - out = torch.empty_like(t, device=device, dtype=torch.bool) - torch.signbit(t, out=out) - self.assertEqual(out, t_target) - - @dtypes(torch.complex64, torch.complex128) - def test_signbit_complex(self, device, dtype): - vals = (complex(0, -1), complex(-1, 2)) - t = torch.tensor(vals, device=device, dtype=dtype) - out = torch.empty_like(t).real.bool() - - with self.assertRaisesRegex(RuntimeError, 'signbit is not implemented for complex tensors.'): - torch.signbit(t) - with self.assertRaisesRegex(RuntimeError, 'signbit is not implemented for complex tensors.'): - torch.signbit(t, out=out) - - @dtypes(torch.cfloat, torch.cdouble) - def test_sgn(self, device, dtype): - x = torch.randn(100, dtype=dtype) - angle = x.angle() - out = x.sgn() - self.assertEqual(out.angle(), angle) - self.assertEqual(out.abs(), torch.ones_like(x).real) - - x_out = torch.empty_like(x) - torch.sgn(x, out=x_out) - self.assertEqual(x_out.angle(), angle) - self.assertEqual(x_out.abs(), torch.ones_like(x).real) - - @dtypes(*(torch.testing.get_all_dtypes(include_bool=False))) - def test_signbit_non_boolean_output(self, device, dtype): - # test non-boolean tensors as the `out=` parameters - # boolean outputs are tested in the above testcases - t = torch.randn(5, 5) - out = torch.empty_like(t, dtype=dtype) - with self.assertRaisesRegex(RuntimeError, 'does not support non-boolean outputs'): - torch.signbit(t, out=out) - - def test_logical_any(self, device): - x = torch.zeros([2, 3, 400], dtype=torch.uint8, device=device) - - self.assertEqual( - torch.tensor(0, dtype=torch.uint8, device=device), - x.any()) - - self.assertEqual( - torch.zeros([1, 3, 400], dtype=torch.uint8, device=device), - x.any(0, keepdim=True)) - - self.assertEqual( - torch.zeros([2, 1, 400], dtype=torch.uint8, device=device), - x.any(1, keepdim=True)) - - self.assertEqual( - torch.zeros([2, 3, 1], dtype=torch.uint8, device=device), - x.any(2, keepdim=True)) - - # set the last element to 0 - x[-1][-1][-1] = 1 - - self.assertEqual( - torch.tensor(1, dtype=torch.uint8, device=device), - x.any()) - - y = torch.zeros([1, 3, 400], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 1 - self.assertEqual(y, x.any(0, keepdim=True)) - - y = torch.zeros([2, 1, 400], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 1 - self.assertEqual(y, x.any(1, keepdim=True)) - - y = torch.zeros([2, 3, 1], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 1 - self.assertEqual(y, x.any(2, keepdim=True)) - - def test_logical_all(self, device): - x = torch.ones([2, 3, 400], dtype=torch.uint8, device=device) - - self.assertEqual( - torch.tensor(1, dtype=torch.uint8, device=device), - x.all()) - - self.assertEqual( - torch.ones([1, 3, 400], dtype=torch.uint8, device=device), - x.all(0, keepdim=True)) - - self.assertEqual( - torch.ones([2, 1, 400], dtype=torch.uint8, device=device), - x.all(1, keepdim=True)) - - self.assertEqual( - torch.ones([2, 3, 1], dtype=torch.uint8, device=device), - x.all(2, keepdim=True)) - - # set the last element to 0 - x[-1][-1][-1] = 0 - - self.assertEqual( - torch.tensor(0, dtype=torch.uint8, device=device), - x.all()) - - y = torch.ones([1, 3, 400], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 0 - self.assertEqual(y, x.all(0, keepdim=True)) - - y = torch.ones([2, 1, 400], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 0 - self.assertEqual(y, x.all(1, keepdim=True)) - - y = torch.ones([2, 3, 1], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 0 - self.assertEqual(y, x.all(2, keepdim=True)) - def test_pairwise_distance_empty(self, device): shape = (2, 0) x = torch.randn(shape, device=device) @@ -11576,157 +3910,14 @@ def test_multinomial_constraints(self, device): RuntimeError, "number of categories cannot exceed", lambda: torch.multinomial(x, 3)) - def test_add(self, device): - dtypes = [torch.float, torch.double] + torch.testing.get_all_complex_dtypes() - for dtype in dtypes: - # [res] torch.add([res,] tensor1, tensor2) - m1 = torch.randn(100, 100, dtype=dtype, device=device) - v1 = torch.randn(100, dtype=dtype, device=device) - - # contiguous - res1 = torch.add(m1[4], v1) - res2 = res1.clone().zero_() - for i in range(m1.size(1)): - res2[i] = m1[4, i] + v1[i] - self.assertEqual(res1, res2) - - m1 = torch.randn(100, 100, device=device) - v1 = torch.randn(100, device=device) - - # non-contiguous - res1 = torch.add(m1[:, 4], v1) - res2 = res1.clone().zero_() - for i in range(m1.size(0)): - res2[i] = m1[i, 4] + v1[i] - self.assertEqual(res1, res2) - - # [res] torch.add([res,] tensor, value) - m1 = torch.randn(10, 10, device=device) - - # contiguous - res1 = m1.clone() - res1[3].add_(2) - res2 = m1.clone() - for i in range(m1.size(1)): - res2[3, i] = res2[3, i] + 2 - self.assertEqual(res1, res2) - - # non-contiguous - m1 = torch.randn(10, 10, device=device) - res1 = m1.clone() - res1[:, 3].add_(2) - res2 = m1.clone() - for i in range(m1.size(0)): - res2[i, 3] = res2[i, 3] + 2 - self.assertEqual(res1, res2) - - # inter-type - m1 = torch.randn(10, 10, dtype=dtype, device=device) - self.assertEqual(m1 + 3, m1 + torch.tensor(3)) - self.assertEqual(3 + m1, torch.tensor(3) + m1) - - # contiguous + non-contiguous - m1 = torch.randn(10, 10, dtype=dtype, device=device) - m2 = torch.randn(10, 10, dtype=dtype, device=device).t() - res = m1 + m2 - self.assertTrue(res.is_contiguous()) - self.assertEqual(res, m1 + m2.contiguous()) - - # 1d + empty - m1 = torch.tensor([1.0], dtype=dtype, device=device) - m2 = torch.tensor([], dtype=dtype, device=device) - self.assertEqual(m1 + m2, []) - - # inter-type unint8 - one = torch.tensor(1, dtype=torch.uint8, device=device) - self.assertEqual(torch.add(one, 1), 2) - self.assertEqual(torch.add(one, 1).dtype, torch.uint8) - - # bool - m1 = torch.tensor([True, False, False, True, False, False], dtype=torch.bool, device=device) - m2 = torch.tensor([True, True, False, False, False, True], dtype=torch.bool, device=device) - expected = torch.tensor([True, True, False, True, False, True], dtype=torch.bool, device=device) - self.assertEqual(m1 + m2, expected) - - # fused multiply add - a = torch.zeros(2, 3, dtype=torch.bool, device=device) - res = torch.add(a, a, alpha=0) - expected = torch.zeros(2, 3, device=device).bool() - self.assertEqual(res, expected) - - # bfloat16 - m1 = torch.tensor([1., 2.], dtype=torch.bfloat16) - m2 = torch.tensor([3., 4.], dtype=torch.bfloat16) - self.assertEqual(m1 + m2, torch.tensor([4., 6.], dtype=torch.bfloat16)) - - # mismatched alpha - m1 = torch.tensor([1], dtype=torch.int8, device=device) - m2 = torch.tensor([2], dtype=torch.int8, device=device) - self.assertRaisesRegex(RuntimeError, - r"Boolean alpha only supported for Boolean results\.", - lambda: torch.add(m1, m2, alpha=True)) - self.assertRaisesRegex(RuntimeError, - r"For integral input tensors, argument alpha must not be a floating point number\.", - lambda: torch.add(m1, m2, alpha=1.0)) - - # complex - m1 = torch.tensor((4.0000 + 4.0000j), dtype=torch.complex64) - m2 = torch.tensor(4., dtype=torch.float64) - self.assertRaisesRegex(RuntimeError, r"result type ComplexFloat can't be cast to the desired output type Double", - lambda: torch.add(m1, m1, out=m2)) - - - def test_sub_typing(self, device): - m1 = torch.tensor([True, False, False, True, False, False], dtype=torch.bool, device=device) - m2 = torch.tensor([True, True, False, False, False, True], dtype=torch.bool, device=device) - self.assertRaisesRegex(RuntimeError, - r"Subtraction, the `\-` operator, with two bool tensors is not supported. " - r"Use the `\^` or `logical_xor\(\)` operator instead.", - lambda: m1 - m2) - self.assertRaisesRegex(RuntimeError, - r"Subtraction, the `\-` operator, with a bool tensor is not supported. " - r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", - lambda: 1 - m1) - self.assertRaisesRegex(RuntimeError, - r"Subtraction, the `\-` operator, with a bool tensor is not supported. " - r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", - lambda: m2 - 1) - - # mismatched alpha - m1 = torch.tensor([1], dtype=torch.int8, device=device) - m2 = torch.tensor([2], dtype=torch.int8, device=device) - self.assertRaisesRegex(RuntimeError, - r"Boolean alpha only supported for Boolean results\.", - lambda: torch.sub(m1, m2, alpha=True)) - self.assertRaisesRegex(RuntimeError, - r"For integral input tensors, argument alpha must not be a floating point number\.", - lambda: torch.sub(m1, m2, alpha=1.0)) - - def test_mul(self, device): - m1 = torch.randn(10, 10, device=device) - res1 = m1.clone() - res1[:, 3].mul_(2) - res2 = m1.clone() - for i in range(res1.size(0)): - res2[i, 3] = res2[i, 3] * 2 - self.assertEqual(res1, res2) - - a1 = torch.tensor([True, False, False, True], dtype=torch.bool, device=device) - a2 = torch.tensor([True, False, True, False], dtype=torch.bool, device=device) - self.assertEqual(a1 * a2, torch.tensor([True, False, False, False], dtype=torch.bool, device=device)) - - if device == 'cpu': - a1 = torch.tensor([0.1, 0.1], dtype=torch.bfloat16, device=device) - a2 = torch.tensor([1.1, 0.1], dtype=torch.bfloat16, device=device) - self.assertEqual(a1 * a2, torch.tensor([0.11, 0.01], dtype=torch.bfloat16, device=device), atol=0.01, rtol=0) - self.assertEqual(a1.mul(a2), a1 * a2) - def test_cumsum(self, device): x = torch.rand(100, 100, device=device) res1 = torch.cumsum(x, 1) res2 = torch.Tensor().to(device) torch.cumsum(x, 1, out=res2) self.assertEqual(res1, res2) + x.cumsum_(1) + self.assertEqual(res1, x) a = torch.tensor([[True, False, True], [False, False, False], @@ -11775,6 +3966,8 @@ def test_cumprod(self, device): res2 = torch.Tensor().to(device) torch.cumprod(x, 1, out=res2) self.assertEqual(res1, res2) + x.cumprod_(1) + self.assertEqual(res1, x) a = torch.tensor([[True, False, True], [False, False, False], @@ -11919,6 +4112,7 @@ def _test_large_cum_fn_helper(self, x, fn): actual = fn(x).cpu().float() self.assertEqual(expected, actual.cpu().float()) + @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "sandcastle OOM with current tpx gpu/re configuration") @onlyCUDA @dtypesIfCUDA(torch.half) # only small dtype not to get oom def test_large_cumsum(self, device, dtype): @@ -11971,288 +4165,6 @@ def test_cummin_discontiguous(self, device): expected_ind = torch.tensor([[0, 1, 2, 3, 3, 3], [0, 1, 2, 3, 3, 5]], device=device, dtype=torch.long) self._test_cumminmax_helper(x, torch.cummin, expected_val, expected_ind) - - def test_std_mean(self, device): - x = torch.rand(100, 50, 20, device=device) - for dim in range(x.dim()): - for unbiased in [False, True]: - for keepdim in [False, True]: - std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) - std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim) - mean2 = x.mean(dim=dim, keepdim=keepdim) - self.assertEqual(std1, std2) - self.assertEqual(mean1, mean2) - - def test_std_mean_all_dims(self, device): - x = torch.rand(100, 50, 20, device=device) - for unbiased in [False, True]: - std1, mean1 = torch.std_mean(x, unbiased=unbiased) - std2 = x.std(unbiased=unbiased) - mean2 = x.mean() - self.assertEqual(std1, std2) - self.assertEqual(mean1, mean2) - - def test_var_mean(self, device): - x = torch.rand(100, 300, 50, device=device) - for dim in range(x.dim()): - for unbiased in [False, True]: - for keepdim in [False, True]: - var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) - var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim) - mean2 = x.mean(dim=dim, keepdim=keepdim) - self.assertEqual(var1, var2) - self.assertEqual(mean1, mean2) - - def test_var_mean_all_dims(self, device): - x = torch.rand(100, 50, 20, device=device) - for unbiased in [False, True]: - var1, mean1 = torch.var_mean(x, unbiased=unbiased) - var2 = x.var(unbiased=unbiased) - mean2 = x.mean() - self.assertEqual(var1, var2) - self.assertEqual(mean1, mean2) - - def test_std_mean_some_dims(self, device): - sizes = (4, 6, 7, 5, 3) - dims = len(sizes) - x = torch.rand(sizes, device=device) - for num_of_dims in range(2, dims): - dim_list = list(combinations(list(range(dims)), r=num_of_dims)) - for dim in dim_list: - for unbiased in [False, True]: - for keepdim in [False, True]: - std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) - std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim) - mean2 = x.mean(dim=dim, keepdim=keepdim) - self.assertEqual(std1, std2) - self.assertEqual(mean1, mean2) - - def _compare_std_var_with_numpy(self, op, device, dtype, input, dim, - keepdim, unbiased, use_out): - assert TEST_NUMPY - - a = input.cpu().numpy() if input.dtype is not torch.bfloat16 else input.float().cpu().numpy() - numpy_kwargs = { - 'axis' : dim, - 'keepdims' : keepdim, - 'ddof' : 1 if unbiased else 0, - } - - if dim is None: - del numpy_kwargs['axis'] - del numpy_kwargs['keepdims'] - - if op == 'var': - torch_op = torch.var - numpy_op = np.var - elif op == 'std': - torch_op = torch.std - numpy_op = np.std - else: - self.fail("Unknown op!") - - numpy_result = numpy_op(a, **numpy_kwargs) - - if dim is None and use_out is False: - torch_result = torch_op(input, unbiased) - elif dim is not None and use_out is False: - torch_result = torch_op(input, dim, unbiased, keepdim) - elif dim is not None and use_out is True: - out = torch.empty(0, device=device, dtype=dtype) - torch_result = torch_op(input, dim, unbiased, keepdim, out=out) - else: - out = torch.empty(0, device=device, dtype=dtype) - try: - torch_result = torch_op(input, dim, unbiased, keepdim, out=out) - except RuntimeError: - return - self.fail("Failed to hit RuntimeError!") - - self.assertEqual(torch_result, numpy_result, exact_dtype=False) - - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypesIfCUDA(torch.float, torch.double, torch.cfloat, torch.cdouble) - @dtypes(torch.float, torch.double) - def test_var_vs_numpy(self, device, dtype): - _size = (20, 20) - - for test_case in product((torch.randn(_size, device=device, dtype=dtype),), - (None, 0, 1), - (False, True), - (False, True), - (False, True),): - self._compare_std_var_with_numpy('var', device, dtype, *test_case) - - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypesIfCUDA(torch.float, torch.double, torch.cfloat, torch.cdouble) - @dtypes(torch.float, torch.double) - def test_std_vs_numpy(self, device, dtype): - _size = (20, 20) - - for test_case in product((torch.randn(_size, device=device, dtype=dtype),), - (None, 0, 1), - (False, True), - (False, True), - (False, True),): - self._compare_std_var_with_numpy('std', device, dtype, *test_case) - - def test_amin_amax_some_dims(self, device): - sizes = (4, 6, 7, 5, 3) - dims = len(sizes) - x = torch.rand(sizes, device=device) - for num_of_dims in range(2, dims): - dim_list = list(combinations(list(range(dims)), r=num_of_dims)) - for dim in dim_list: - for keepdim in [False, True]: - amin1 = torch.amin(x, dim=dim, keepdim=keepdim) - amax1 = torch.amax(x, dim=dim, keepdim=keepdim) - amin2 = x - amax2 = x - for i, d in enumerate(dim): - if not keepdim: - d -= i - amin2 = torch.amin(amin2, dim=d, keepdim=keepdim) - amax2 = torch.amax(amax2, dim=d, keepdim=keepdim) - self.assertEqual(amin1, amin2) - self.assertEqual(amax1, amax2) - - @onlyCUDA - @expectedAlertNondeterministic('_histc_cuda', fn_has_device_arg=False) - def test_histc_alert_nondeterministic(self, device): - torch.histc(torch.tensor([], device=device), min=0, max=3) - - def test_histc(self, device): - # negative nbins throws - with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'): - torch.histc(torch.tensor([1], dtype=torch.float, device=device), bins=-1) - # empty tensor - actual = torch.histc(torch.tensor([], device=device), min=0, max=3) - expected = torch.zeros(100, dtype=torch.float, device=device) - self.assertEqual(expected, actual) - - # without nbins - actual = torch.histc( - torch.tensor([2, 5], dtype=torch.float, device=device)) - expected = torch.zeros(100, dtype=torch.float, device=device) - expected[0] = 1 - expected[99] = 1 - self.assertEqual(expected, actual) - # tensor with the same element - actual = torch.histc(torch.ones(5, dtype=torch.float, device=device), bins=5) - self.assertEqual( - torch.tensor([0, 0, 5, 0, 0], dtype=torch.float, device=device), - actual) - # no element falls between [min, max] - actual = torch.histc( - torch.ones(5, dtype=torch.float, device=device), bins=5, min=2, max=3) - self.assertEqual( - torch.tensor([0, 0, 0, 0, 0], dtype=torch.float, device=device), - actual) - # element falls below min + integral bin size and - actual = torch.histc( - torch.tensor([2, 4, 2, 2, 5, 4], dtype=torch.float, device=device), - bins=5, min=1, max=5) - self.assertEqual( - torch.tensor([0, 3, 0, 2, 1], dtype=torch.float, device=device), - actual) - # non-integral bin size - actual = torch.histc( - torch.tensor([1, 2, 1], dtype=torch.float, device=device), - bins=4, min=0, max=3) - self.assertEqual( - torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device), - actual) - # double input - actual = torch.histc( - torch.tensor([1, 2, 1], dtype=torch.double, device=device), bins=4, min=0, max=3) - self.assertEqual( - torch.tensor([0, 2, 1, 0], dtype=torch.double, device=device), - actual) - self.assertEqual(actual.dtype, torch.double) - # mixed input - actual = torch.histc( - torch.tensor([1., 2, 1], dtype=torch.float, device=device), - bins=4, min=0, max=3) - self.assertEqual( - torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device), - actual) - self.assertEqual(actual.dtype, torch.float) - # scalar input and 1 bin -- should return a 1-dimensional tensor, not a scalar. - actual = torch.histc( - torch.tensor(0, dtype=torch.float, device=device), - bins=1, min=0, max=3) - self.assertEqual( - torch.tensor([1], dtype=torch.float, device=device), - actual) - # tensors with inf; min, max not provided -- should throw a RuntimeError - with self.assertRaisesRegex(RuntimeError, r'range of \[inf, inf\] is not finite'): - torch.histc(torch.tensor([float("inf")], dtype=torch.float, device=device)) - with self.assertRaisesRegex(RuntimeError, r'range of \[1, inf\] is not finite'): - torch.histc(torch.tensor([1., 2., float("inf")], dtype=torch.float, device=device)) - # tensors with inf; min, max provided - self.assertEqual( - torch.histc(torch.tensor([float("inf")], dtype=torch.float, device=device), - bins=1, min=0, max=3), - torch.tensor([0], dtype=torch.float, device=device)) - self.assertEqual( - torch.histc(torch.tensor([1., 2., float("inf")], dtype=torch.float, device=device), - bins=4, max=3), - torch.tensor([0, 1, 1, 0], dtype=torch.float, device=device)) - # tensor with nan -- should throw a RuntimeError - with self.assertRaisesRegex(RuntimeError, r'range of \[nan, nan\] is not finite'): - torch.histc(torch.tensor([float("nan")], dtype=torch.float, device=device)) - # tensors with min > max -- should throw a RuntimeError - with self.assertRaisesRegex(RuntimeError, "max must be larger than min"): - torch.histc(torch.tensor([1., 2., 3.], dtype=torch.float, device=device), - bins=4, min=5, max=1) - - # test against numpy.histogram() - def test_against_np(tensor, bins=100, min=0, max=0): - if min == 0 and max == 0: - min = tensor.min().item() - max = tensor.max().item() - nparr = tensor.cpu().numpy() - actual = torch.histc(tensor, bins=bins, min=min, max=max) - expected = torch.from_numpy(np.histogram(nparr, bins=bins, range=(min, max))[0]) - actual_cpu = actual.cpu() - # NB: Numpy returns a int64 tensor, like normal people... - self.assertEqual(actual, expected.to(actual_cpu)) - - if TEST_NUMPY: - test_against_np(torch.tensor([1., 2, 1], device=device)) - test_against_np(torch.randn(5000, device=device)) - - # Test bins arg - test_against_np(torch.randn(301, device=device), bins=10) - - # Test truncated range - test_against_np(torch.randn(201, device=device), min=0.1, max=1) - - noncontig = torch.randn(100, 3, device=device)[:, 2] - test_against_np(noncontig) - - multidim = torch.randn(3, 5, 7, 2, device=device) - test_against_np(multidim) - - expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2) - test_against_np(expanded) - - def test_bool_tensor_comparison_ops(self, device): - a = torch.tensor([True, False, True, False, True, False], dtype=torch.bool, device=device) - b = torch.tensor([True, False, True, True, True, True], dtype=torch.bool, device=device) - self.assertEqual(a == b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)) - self.assertEqual(a != b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)) - self.assertEqual(a < b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)) - self.assertEqual(a > b, torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.bool, device=device)) - self.assertEqual(a >= b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)) - self.assertEqual(a <= b, torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.bool, device=device)) - self.assertEqual(a > False, torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device)) - self.assertEqual(a == torch.tensor(True, dtype=torch.bool, device=device), - torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device)) - self.assertEqual(a == torch.tensor(0, dtype=torch.bool, device=device), - torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool, device=device)) - self.assertFalse(a.equal(b)) - def test_bool_tensor_value_change(self, device): x = torch.tensor([True, False], dtype=torch.bool, device=device) x[0] = False @@ -12289,36 +4201,6 @@ def test_copy_all_dtypes_and_devices(self, device): # not the data self.assertEqual(x, y) - def test_resize_all_dtypes_and_devices(self, device): - shape = (2, 2) - for dt in torch.testing.get_all_dtypes(): - x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) - x.resize_(shape) - self.assertEqual(shape, x.shape) - - def test_resize_as_all_dtypes_and_devices(self, device): - for dt in torch.testing.get_all_dtypes(): - x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) - y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device) - x.resize_as_(y) - self.assertEqual(y.shape, x.shape) - - def test_view_all_dtypes_and_devices(self, device): - for dt in torch.testing.get_all_dtypes(): - x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) - self.assertEqual(x.view(6).shape, [6]) - - def test_fill_all_dtypes_and_devices(self, device): - for dt in torch.testing.get_all_dtypes(): - for x in [torch.tensor((10, 10), dtype=dt, device=device), - torch.empty(10000, dtype=dt, device=device)]: # large tensor - numel = x.numel() - bound = 100 if dt in (torch.uint8, torch.int8) else 2000 - for n in range(-bound, bound, bound // 10): - x.fill_(n) - self.assertEqual(x, torch.tensor([n] * numel, dtype=dt, device=device)) - self.assertEqual(dt, x.dtype) - def test_clone_all_dtypes_and_devices(self, device): for dt in torch.testing.get_all_dtypes(): x = torch.tensor((1, 1), dtype=dt, device=device) @@ -12331,68 +4213,9 @@ def test_clone_zero_stride_dim(self, device): y = x.as_strided([2, 1, 5], [1, 0, 2]) self.assertEqual(y, y.clone()) - def test_cat_all_dtypes_and_devices(self, device): - for dt in torch.testing.get_all_dtypes(): - x = torch.tensor([[1, 2], [3, 4]], dtype=dt, device=device) - - expected1 = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=dt, device=device) - self.assertEqual(torch.cat((x, x), 0), expected1) - - expected2 = torch.tensor([[1, 2, 1, 2], [3, 4, 3, 4]], dtype=dt, device=device) - self.assertEqual(torch.cat((x, x), 1), expected2) - - - - @onlyOnCPUAndCUDA - def test_vander(self, device): - x = torch.tensor([1, 2, 3, 5], device=device) - - self.assertEqual((0, 0), torch.vander(torch.tensor([]), 0).shape) - - with self.assertRaisesRegex(RuntimeError, "N must be non-negative."): - torch.vander(x, N=-1) - - with self.assertRaisesRegex(RuntimeError, "x must be a one-dimensional tensor."): - torch.vander(torch.stack((x, x))) - - @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') - @onlyOnCPUAndCUDA - @dtypes(torch.bool, torch.uint8, torch.int8, torch.short, torch.int, torch.long, - torch.float, torch.double, - torch.cfloat, torch.cdouble) - def test_vander_types(self, device, dtype): - if dtype is torch.uint8: - # Note: no negative uint8 values - X = [[1, 2, 3, 5], [0, 1 / 3, 1, math.pi, 3 / 7]] - elif dtype is torch.bool: - # Note: see https://github.com/pytorch/pytorch/issues/37398 - # for why this is necessary. - X = [[True, True, True, True], [False, True, True, True, True]] - elif dtype in [torch.cfloat, torch.cdouble]: - X = [[1 + 1j, 1 + 0j, 0 + 1j, 0 + 0j], - [2 + 2j, 3 + 2j, 4 + 3j, 5 + 4j]] - else: - X = [[1, 2, 3, 5], [-math.pi, 0, 1 / 3, 1, math.pi, 3 / 7]] - - N = [None, 0, 1, 3] - increasing = [False, True] - - for x, n, inc in product(X, N, increasing): - numpy_dtype = torch_to_numpy_dtype_dict[dtype] - pt_x = torch.tensor(x, device=device, dtype=dtype) - np_x = np.array(x, dtype=numpy_dtype) - - pt_res = torch.vander(pt_x, increasing=inc) if n is None else torch.vander(pt_x, n, inc) - np_res = np.vander(np_x, n, inc) - - self.assertEqual( - pt_res, - torch.from_numpy(np_res), - atol=1e-3, - rtol=0, - exact_dtype=False) - - def test_addcmul(self, device): + @dtypesIfCUDA(*set(torch.testing.get_all_math_dtypes('cuda'))) + @dtypes(*set(torch.testing.get_all_math_dtypes('cpu'))) + def test_addcmul(self, device, dtype): def rand_tensor(size, dtype, device): if dtype.is_floating_point or dtype.is_complex: return torch.rand(size=size, dtype=dtype, device=device) @@ -12401,50 +4224,20 @@ def rand_tensor(size, dtype, device): else: return torch.randint(-5, 5, size=size, dtype=dtype, device=device) - for dtype in torch.testing.get_all_math_dtypes(device): - a = rand_tensor((2, 2), dtype=dtype, device=device) - b = rand_tensor((2, 2), dtype=dtype, device=device) - c = rand_tensor((2, 2), dtype=dtype, device=device) - if dtype.is_floating_point: - alpha = 0.1 - else: - alpha = 3 + a = rand_tensor((2, 2), dtype=dtype, device=device) + b = rand_tensor((2, 2), dtype=dtype, device=device) + c = rand_tensor((2, 2), dtype=dtype, device=device) - # addcmul is not supported for complex dtypes on cuda yet - if device.startswith('cuda') and dtype.is_complex: - continue + alpha = _number(0.5, 3, dtype) - actual = torch.addcmul(a, b, c, value=alpha) - expected = a + alpha * b * c + actual = torch.addcmul(a, b, c, value=alpha) + expected = a + alpha * b * c - self.assertEqual(expected, actual) + self.assertEqual(expected, actual) - with self.maybeWarnsRegex( - UserWarning, "This overload of addcmul is deprecated"): - self.assertEqual(actual, torch.addcmul(a, alpha, b, c)) - - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - @tf32_on_and_off(0.005) - def test_tensordot(self, device): - a = torch.arange(60., device=device).reshape(3, 4, 5) - b = torch.arange(24., device=device).reshape(4, 3, 2) - c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu() - cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), - axes=([1, 0], [0, 1]))) - self.assertEqual(c, cn) - a = torch.randn(2, 3, 4, 5, device=device) - b = torch.randn(4, 5, 6, 7, device=device) - c = torch.tensordot(a, b, dims=2).cpu() - cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), - axes=2)) - - with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"): - torch.tensordot(a, b, dims=-1) - - self.assertEqual(c, cn) - c = torch.tensordot(a, b).cpu() - cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy())) - self.assertEqual(c, cn) + with self.maybeWarnsRegex( + UserWarning, "This overload of addcmul is deprecated"): + self.assertEqual(actual, torch.addcmul(a, alpha, b, c)) def test_narrow_empty(self, device): x = torch.randn(2, 3, 4, device=device) @@ -12454,33 +4247,6 @@ def test_narrow_empty(self, device): sz[d] = 0 self.assertEqual(sz, y.size()) - @dtypes(*torch.testing.get_all_dtypes(include_complex=False)) - def test_logical(self, device, dtype): - if dtype != torch.bool: - x = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype) - b = torch.tensor([2], device=device, dtype=dtype) - self.assertEqual(x.lt(2), torch.tensor([True, False, False, False])) - self.assertEqual(x.le(2), torch.tensor([True, True, False, False])) - self.assertEqual(x.ge(2), torch.tensor([False, True, True, True])) - self.assertEqual(x.gt(2), torch.tensor([False, False, True, True])) - self.assertEqual(x.eq(2), torch.tensor([False, True, False, False])) - self.assertEqual(x.ne(2), torch.tensor([True, False, True, True])) - - self.assertEqual(x.lt(b), torch.tensor([True, False, False, False])) - self.assertEqual(x.le(b), torch.tensor([True, True, False, False])) - self.assertEqual(x.ge(b), torch.tensor([False, True, True, True])) - self.assertEqual(x.gt(b), torch.tensor([False, False, True, True])) - self.assertEqual(x.eq(b), torch.tensor([False, True, False, False])) - self.assertEqual(x.ne(b), torch.tensor([True, False, True, True])) - else: - x = torch.tensor([True, False, True, False], device=device) - self.assertEqual(x.lt(True), torch.tensor([False, True, False, True])) - self.assertEqual(x.le(True), torch.tensor([True, True, True, True])) - self.assertEqual(x.ge(True), torch.tensor([True, False, True, False])) - self.assertEqual(x.gt(True), torch.tensor([False, False, False, False])) - self.assertEqual(x.eq(True), torch.tensor([True, False, True, False])) - self.assertEqual(x.ne(True), torch.tensor([False, True, False, True])) - def test_index_copy(self, device): num_copy, num_dest = 3, 20 dest = torch.randn(num_dest, 4, 5, device=device) @@ -12513,6 +4279,29 @@ def test_index_copy(self, device): c = torch.zeros(3) self.assertRaises(IndexError, lambda: a.index_copy_(dim=1, index=torch.tensor([3]), source=c)) + # Ensures that index_copy throws nondeterministic alerts in the correct cases + @onlyOnCPUAndCUDA + @dtypes(torch.double) + def test_index_copy_nondeterministic_alert(self, device, dtype): + @expectedAlertNondeterministic('index_copy') + def test_func(slf, device, call_type): + S = 10 + a = torch.randn(S, device=device) + b = torch.randn(S, device=device) + index = torch.randint(S, (S,), device=device) + if call_type == 'function': + torch.index_copy(a, 0, index, b) + elif call_type == 'method': + a.index_copy(0, index, b) + elif call_type == 'method inplace': + a.index_copy_(0, index, b) + else: + self.fail(f"'{call_type}' is not a valid call type") + + test_func(self, device, 'function') + test_func(self, device, 'method') + test_func(self, device, 'method inplace') + def test_index_fill(self, device): for dt in torch.testing.get_all_dtypes(): if dt == torch.half or dt == torch.bfloat16 or dt.is_complex: @@ -12524,36 +4313,37 @@ def test_index_fill(self, device): self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dt, device=device)) def test_index_select(self, device): - src = torch.randn(3, 4, 5, device=device) - # Index can be duplicated. - idx = torch.tensor([2, 1, 0, 1, 2], dtype=torch.long, device=device) - dest = torch.index_select(src, 0, idx) - self.assertEqual(dest.shape, (5, 4, 5)) - for i in range(idx.size(0)): - self.assertEqual(dest[i], src[idx[i]]) - - # Check that 'out' is used correctly. - out = torch.randn(5 * 4 * 5, device=device) - dest = torch.index_select(src, 0, idx, out=out.view(5, 4, 5)) - self.assertEqual(dest.shape, (5, 4, 5)) - for i in range(idx.size(0)): - self.assertEqual(dest[i], src[idx[i]]) - out.fill_(0.123) - self.assertEqual(out, dest.view(-1)) # Must point to the same storage. - - # Bool tensor - src = torch.tensor([False, True, False, False], device=device, dtype=torch.bool) - idx = torch.tensor([1], dtype=torch.long, device=device) - dest = torch.index_select(src, 0, idx) - self.assertEqual(torch.tensor([True]), dest) - - # Complex Tensor - src = torch.randn(3, 4, 5, dtype=torch.complex64, device=device) - idx = torch.tensor([2, 1, 0, 1, 2], dtype=torch.long, device=device) - dest = torch.index_select(src, 0, idx) - self.assertEqual(dest.shape, (5, 4, 5)) - for i in range(idx.size(0)): - self.assertEqual(dest[i], src[idx[i]]) + for dtype in [torch.int, torch.long]: + src = torch.randn(3, 4, 5, device=device) + # Index can be duplicated. + idx = torch.tensor([2, 1, 0, 1, 2], dtype=dtype, device=device) + dest = torch.index_select(src, 0, idx) + self.assertEqual(dest.shape, (5, 4, 5)) + for i in range(idx.size(0)): + self.assertEqual(dest[i], src[idx[i]]) + + # Check that 'out' is used correctly. + out = torch.randn(5 * 4 * 5, device=device) + dest = torch.index_select(src, 0, idx, out=out.view(5, 4, 5)) + self.assertEqual(dest.shape, (5, 4, 5)) + for i in range(idx.size(0)): + self.assertEqual(dest[i], src[idx[i]]) + out.fill_(0.123) + self.assertEqual(out, dest.view(-1)) # Must point to the same storage. + + # Bool tensor + src = torch.tensor([False, True, False, False], device=device, dtype=torch.bool) + idx = torch.tensor([1], dtype=dtype, device=device) + dest = torch.index_select(src, 0, idx) + self.assertEqual(torch.tensor([True]), dest) + + # Complex Tensor + src = torch.randn(3, 4, 5, dtype=torch.complex64, device=device) + idx = torch.tensor([2, 1, 0, 1, 2], dtype=dtype, device=device) + dest = torch.index_select(src, 0, idx) + self.assertEqual(dest.shape, (5, 4, 5)) + for i in range(idx.size(0)): + self.assertEqual(dest[i], src[idx[i]]) def test_take_empty(self, device): for input_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]: @@ -13004,89 +4794,6 @@ def test_dim_function_empty(self, device): c = torch.randn((0, 1, 2), device=device) self.assertEqual(c, c.index_select(0, ind_empty)) - @dtypes(*torch.testing.get_all_dtypes(include_complex=False)) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_nonzero(self, device, dtype): - - shapes = [ - torch.Size((12,)), - torch.Size((12, 1)), - torch.Size((1, 12)), - torch.Size((6, 2)), - torch.Size((3, 2, 2)), - torch.Size((5, 5, 5)), - ] - - def gen_nontrivial_input(shape, dtype, device): - if dtype != torch.bfloat16: - return torch.randint(2, shape, device=device, dtype=dtype) - else: - # windows does not work for bfloat16 randing - return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype) - - for shape in shapes: - tensor = gen_nontrivial_input(shape, dtype, device) - dst1 = torch.nonzero(tensor, as_tuple=False) - dst2 = tensor.nonzero(as_tuple=False) - dst3 = torch.empty([], dtype=torch.long, device=device) - torch.nonzero(tensor, out=dst3) - self.assertRaisesRegex( - TypeError, - "received an invalid combination of arguments", - lambda: torch.nonzero(tensor, as_tuple=True, out=dst3)) - if self.device_type != 'xla': - # xla does not raise runtime error - self.assertRaisesRegex( - RuntimeError, - "scalar type Long", - lambda: torch.nonzero(tensor, out=torch.empty([], dtype=torch.float)) - ) - if self.device_type == 'cuda': - self.assertRaisesRegex( - RuntimeError, - "on the same device", - lambda: torch.nonzero(tensor, out=torch.empty([], dtype=torch.long)) - ) - np_array = tensor.cpu().numpy() if dtype != torch.bfloat16 else tensor.float().cpu().numpy() - np_result = torch.from_numpy(np.stack(np_array.nonzero())).t() - self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0) - self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0) - self.assertEqual(dst3.cpu(), np_result, atol=0, rtol=0) - tup1 = torch.nonzero(tensor, as_tuple=True) - tup2 = tensor.nonzero(as_tuple=True) - tup1 = torch.stack(tup1).t().cpu() - tup2 = torch.stack(tup2).t().cpu() - self.assertEqual(tup1, np_result, atol=0, rtol=0) - self.assertEqual(tup2, np_result, atol=0, rtol=0) - - @onlyOnCPUAndCUDA - def test_nonzero_discontiguous(self, device): - shape = (4, 4) - tensor = torch.randint(2, shape, device=device) - tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(tensor) - dst1 = tensor.nonzero(as_tuple=False) - dst2 = tensor_nc.nonzero(as_tuple=False) - self.assertEqual(dst1, dst2, atol=0, rtol=0) - dst3 = torch.empty_like(dst1) - data_ptr = dst3.data_ptr() - # expect dst3 storage to be reused - torch.nonzero(tensor, out=dst3) - self.assertEqual(data_ptr, dst3.data_ptr()) - self.assertEqual(dst1, dst3, atol=0, rtol=0) - # discontiguous out - dst4 = torch.empty(dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device)[:, ::2] - data_ptr = dst4.data_ptr() - strides = dst4.stride() - torch.nonzero(tensor, out=dst4) - self.assertEqual(data_ptr, dst4.data_ptr()) - self.assertEqual(dst1, dst4, atol=0, rtol=0) - self.assertEqual(strides, dst4.stride()) - - def test_nonzero_non_diff(self, device): - x = torch.randn(10, requires_grad=True) - nz = x.nonzero() - self.assertFalse(nz.requires_grad) - def _brute_pdist(self, inp, p=2): """Computes the same as torch.pdist using primitives""" n = inp.shape[-2] @@ -13139,6 +4846,7 @@ def test_pdist_norm_backward(self, device): for trans in [False, True]: self._pdist_single(shape, device, p, torch.float64, trans, grad_check=True) + @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "sandcastle OOM with current tpx gpu/re configuration") @skipIfRocm def test_pdist_norm_large(self, device): # use dim0>=46342 for forward, see: @@ -13150,158 +4858,10 @@ def test_pdist_norm_large(self, device): actual_gpu = torch.pdist(x.to(device), p=2) self.assertEqual(expected_cpu, actual_gpu.cpu()) - def test_atan2(self, device): - def _test_atan2_with_size(size, device): - a = torch.rand(size=size, device=device, dtype=torch.double) - b = torch.rand(size=size, device=device, dtype=torch.double) - actual = a.atan2(b) - x = a.view(-1) - y = b.view(-1) - expected = torch.tensor([math.atan2(x[i].item(), y[i].item()) for i in range(x.numel())], - device=device, dtype=torch.double) - self.assertEqual(expected, actual.view(-1), rtol=0, atol=0.02) - - _test_atan2_with_size((2, 2), device) - _test_atan2_with_size((3, 3), device) - _test_atan2_with_size((5, 5), device) - - def test_atan2_edgecases(self, device): - def _test_atan2(x, y, expected, device, dtype): - expected_tensor = torch.tensor([expected], dtype=dtype, device=device) - x_tensor = torch.tensor([x], dtype=dtype, device=device) - y_tensor = torch.tensor([y], dtype=dtype, device=device) - actual = torch.atan2(y_tensor, x_tensor) - self.assertEqual(expected_tensor, actual, rtol=0, atol=0.02) - - for dtype in [torch.float, torch.double]: - _test_atan2(0, 0, 0, device, dtype) - _test_atan2(0, 1, math.pi / 2, device, dtype) - _test_atan2(0, -1, math.pi / -2, device, dtype) - _test_atan2(-1, 0, math.pi, device, dtype) - _test_atan2(1, 0, 0, device, dtype) - _test_atan2(-1, -1, math.pi * -3 / 4 , device, dtype) - _test_atan2(1, 1, math.pi / 4 , device, dtype) - _test_atan2(1, -1, math.pi / -4 , device, dtype) - _test_atan2(-1, 1, math.pi * 3 / 4 , device, dtype) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_trapz(self, device): - def test_dx(sizes, dim, dx, device): - t = torch.randn(sizes, device=device) - actual = torch.trapz(t, dx=dx, dim=dim) - expected = np.trapz(t.cpu().numpy(), dx=dx, axis=dim) - self.assertEqual(expected.shape, actual.shape) - self.assertEqual(expected, actual) - - def test_x(sizes, dim, x, device): - t = torch.randn(sizes, device=device) - actual = torch.trapz(t, x=torch.tensor(x, device=device), dim=dim) - expected = np.trapz(t.cpu().numpy(), x=x, axis=dim) - self.assertEqual(expected.shape, actual.shape) - self.assertEqual(expected, actual.cpu()) - - test_dx((2, 3, 4), 1, 1, device) - test_dx((10, 2), 0, 0.1, device) - test_dx((1, 10), 0, 2.3, device) - test_dx((0, 2), 0, 1.0, device) - test_dx((0, 2), 1, 1.0, device) - test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device) - test_x((10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device) - test_x((1, 10), 0, [1.0], device) - test_x((0, 2), 0, [], device) - test_x((0, 2), 1, [1.0, 2.0], device) - with self.assertRaisesRegex( - IndexError, - 'Dimension out of range'): - test_x((2, 3), 2, [], device) - test_dx((2, 3), 2, 1.0, device) - with self.assertRaisesRegex( - RuntimeError, - 'There must be one `x` value for each sample point'): - test_x((2, 3), 1, [1.0, 2.0], device) - test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device) - - def test_reduction_empty(self, device): - fns_to_test = [ - # name, function, identity - ('max', torch.max, None), - ('amax', torch.amax, None), - ('kthvalue', lambda *args, **kwargs: torch.kthvalue(*args, k=1, **kwargs), None), - ('argmax', torch.argmax, None), - ('min', torch.min, None), - ('amin', torch.amin, None), - ('argmin', torch.argmin, None), - ('mode', torch.mode, None), - ('median', torch.median, None), - - ('prod', torch.prod, 1.), - ('sum', torch.sum, 0.), - ('norm', torch.norm, 0.), - ('mean', torch.mean, nan), - ('var', torch.var, nan), - ('std', torch.std, nan), - ('logsumexp', torch.logsumexp, -inf), - ] - - shape = (2, 0, 4) - x = torch.randn(shape, device=device) - - for fn in [torch.max, torch.min]: - ident_err = 'operation does not have an identity' - self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x)) - - for item in fns_to_test: - name, fn, identity = item - if identity is None: - ident_err = 'does not have an identity' - self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=2)) - self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=2, keepdim=True)) - self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1)) - self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1, keepdim=True)) - else: - self.assertEqual(torch.empty((2, 0), device=device), fn(x, dim=2)) - self.assertEqual(torch.empty((2, 0, 1), device=device), fn(x, dim=2, keepdim=True)) - # assertEqual doesn't work with inf, -inf, nan and two tensors. - check = (torch.testing.assert_allclose if math.isnan(identity) or math.isinf(identity) else - self.assertEqual) - check(torch.full((2, 4), identity, device=device), fn(x, dim=1)) - check(torch.full((2, 1, 4), identity, device=device), fn(x, dim=1, keepdim=True)) - try: - check(torch.full((), identity, device=device), fn(x)) - except TypeError as err: - # ignore if there is no allreduce. - self.assertTrue('dim' in str(err)) - - # any - xb = x.to(torch.uint8) - yb = x.to(torch.uint8) - self.assertEqual((2, 0), xb.any(2).shape) - self.assertEqual((2, 0, 1), xb.any(2, keepdim=True).shape) - self.assertEqual(torch.zeros((2, 4), device=device, dtype=torch.uint8), xb.any(1)) - self.assertEqual(torch.zeros((2, 1, 4), device=device, dtype=torch.uint8), xb.any(1, keepdim=True)) - self.assertEqual(torch.zeros((), device=device, dtype=torch.uint8), xb.any()) - - # all - self.assertEqual((2, 0), xb.all(2).shape) - self.assertEqual((2, 0, 1), xb.all(2, keepdim=True).shape) - self.assertEqual(torch.ones((2, 4), device=device, dtype=torch.uint8), xb.all(1)) - self.assertEqual(torch.ones((2, 1, 4), device=device, dtype=torch.uint8), xb.all(1, keepdim=True)) - self.assertEqual(torch.ones((), device=device, dtype=torch.uint8), xb.all()) - @onlyOnCPUAndCUDA - def test_addcdiv(self, device): - def _test_addcdiv(a, alpha, b, c): - actual = torch.addcdiv(a, b, c, value=alpha) - # implementation of addcdiv downcasts alpha. arithmetic ops don't. - if not actual.dtype.is_floating_point: - alpha = int(alpha) - expected = a + (alpha * b) / c - self.assertEqual(expected, actual) - - with self.maybeWarnsRegex( - UserWarning, "This overload of addcdiv is deprecated"): - self.assertEqual(actual, torch.addcdiv(a, alpha, b, c)) - + @dtypesIfCUDA(*set(torch.testing.get_all_math_dtypes('cuda'))) + @dtypes(*set(torch.testing.get_all_math_dtypes('cpu'))) + def test_addcdiv(self, device, dtype): def non_zero_rand(size, dtype, device): if dtype.is_floating_point or dtype.is_complex: a = torch.rand(size=size, dtype=dtype, device=device) @@ -13311,60 +4871,26 @@ def non_zero_rand(size, dtype, device): a = torch.randint(-5, 5, size=size, dtype=dtype, device=device) return a + (a == 0).to(dtype) - def _helper(): - _test_addcdiv( - non_zero_rand((2, 2), dtype=dtype, device=device), - 0.5, - non_zero_rand((2, 2), dtype=dtype, device=device), - non_zero_rand((2, 2), dtype=dtype, device=device)) - - for dtype in torch.testing.get_all_math_dtypes(device): - if dtype.is_complex: - # CPU complex addcdiv is wildly inaccurate - if self.device_type == 'cpu': - with self.assertRaises(AssertionError): - _helper() - - # CUDA complex addcdiv is not implemented - if self.device_type == 'cuda': - with self.assertRaises(RuntimeError): - _helper() - elif not dtype.is_floating_point: - # Integer division with addcdiv is prohibited - with self.assertRaises(RuntimeError): - _helper() - else: - _helper() - - # This function tests that a nan value is returned for input values not in domain - @dtypes(torch.float32, torch.float64) - def test_acosh_domain_float(self, device, dtype): - # Domain of acosh is [1, inf), for values outside the domain - output is mapped - # to NaN, except for input value `inf` - output is mapped to `inf` - sample = torch.tensor([float('-inf'), 1.00, -1.23, -0.06, 0.98, float('inf')], - device=device, dtype=dtype) - nan_mask = torch.tensor([True, False, True, True, True, False], device=device) - inf_mask = torch.tensor([False, False, False, False, False, True], device=device) - self.assertEqual(torch.isnan(torch.acosh(sample)), nan_mask) - self.assertEqual(torch.isnan(sample.acosh()), nan_mask) - self.assertEqual(torch.isinf(torch.acosh(sample)), inf_mask) - self.assertEqual(torch.isinf(sample.acosh()), inf_mask) - - # This function tests that a nan value is returned for input values not in domain - @dtypes(torch.float32, torch.float64) - def test_atanh_domain_float(self, device, dtype): - # Domain of atanh is (-1, 1), for edge values (-1 and 1) - output is mapped - # to inf and for other values outside this range - output is mapped to NaN - sample = torch.tensor([float('-inf'), -1.00, 1.00, -1.23, 1.06, float('inf')], - device=device, dtype=dtype) - nan_mask = torch.tensor([True, False, False, True, True, True], device=device) - inf_mask = torch.tensor([False, True, True, False, False, False], device=device) - # For values not in domain (except -1.0 and 1.0), atanh should return nan - self.assertEqual(torch.isnan(torch.atanh(sample)), nan_mask) - self.assertEqual(torch.isnan(sample.atanh()), nan_mask) - # For values -1.0 and 1.0, atanh should return -inf and inf respectively - self.assertEqual(torch.isinf(torch.atanh(sample)), inf_mask) - self.assertEqual(torch.isinf(sample.atanh()), inf_mask) + def _test_addcdiv(): + a = non_zero_rand((2, 2), dtype=dtype, device=device) + b = non_zero_rand((2, 2), dtype=dtype, device=device) + c = non_zero_rand((2, 2), dtype=dtype, device=device) + alpha = _number(0.5, 3, dtype) + + expected = a + (alpha * b) / c + actual = torch.addcdiv(a, b, c, value=alpha) + self.assertEqual(expected, actual) + + with self.maybeWarnsRegex( + UserWarning, "This overload of addcdiv is deprecated"): + self.assertEqual(actual, torch.addcdiv(a, alpha, b, c)) + + if not (dtype.is_floating_point or dtype.is_complex): + # Integer division with addcdiv is prohibited + with self.assertRaises(RuntimeError): + _test_addcdiv() + else: + _test_addcdiv() def test_nullary_op_mem_overlap(self, device): ops = ( @@ -13382,156 +4908,6 @@ def test_nullary_op_mem_overlap(self, device): with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): getattr(x, op)(*args) - # TODO: run on non-native device types - @dtypes(torch.double) - def test_unary_out_op_mem_overlap(self, device, dtype): - sz = 3 - doubles = torch.randn(2 * sz, dtype=dtype, device=device) - positives = torch.randint(1, 100, (2 * sz,), device=device).double() - ints = torch.randint(-100, 100, (2 * sz,), device=device) - unary_mem_overlap_cases = [ - ("abs", doubles, True, True, 'cpu'), - ("abs", doubles, True, True, 'cuda'), - ("acos", doubles, True, True, 'cpu'), - ("acos", doubles, True, True, 'cuda'), - ("asin", doubles, True, True, 'cpu'), - ("asin", doubles, True, True, 'cuda'), - ("atan", doubles, True, True, 'cpu'), - ("atan", doubles, True, True, 'cuda'), - ("acosh", doubles, True, True, 'cpu'), - ("acosh", doubles, True, True, 'cuda'), - ("asinh", doubles, True, True, 'cpu'), - ("asinh", doubles, True, True, 'cuda'), - ("atanh", doubles, True, True, 'cpu'), - ("atanh", doubles, True, True, 'cuda'), - ("bitwise_not", ints, True, True, 'cpu'), - ("bitwise_not", ints, True, True, 'cuda'), - ("ceil", doubles, True, True, 'cpu'), - ("ceil", doubles, True, True, 'cuda'), - ("cos", doubles, True, True, 'cpu'), - ("cos", doubles, True, True, 'cuda'), - ("cosh", doubles, True, True, 'cpu'), - ("cosh", doubles, True, True, 'cuda'), - ("digamma", doubles, True, True, 'cpu'), - ("erf", doubles, True, True, 'cpu'), - ("erf", doubles, True, True, 'cuda'), - ("erfc", doubles, True, True, 'cpu'), - ("erfc", doubles, True, True, 'cuda'), - ("erfinv", doubles, True, True, 'cpu'), - ("erfinv", doubles, True, True, 'cuda'), - ("exp", doubles, True, True, 'cpu'), - ("exp", doubles, True, True, 'cuda'), - ("exp2", doubles, True, True, 'cpu'), - ("exp2", doubles, True, True, 'cuda'), - ("expm1", doubles, True, True, 'cpu'), - ("expm1", doubles, True, True, 'cuda'), - ("floor", doubles, True, True, 'cpu'), - ("floor", doubles, True, True, 'cuda'), - ("frac", doubles, True, True, 'cpu'), - ("frac", doubles, True, True, 'cuda'), - ("i0", doubles, True, True, 'cpu'), - ("i0", doubles, True, True, 'cuda'), - ("log", positives, True, True, 'cpu'), - ("log", positives, True, True, 'cuda'), - ("log10", positives, True, True, 'cpu'), - ("log10", positives, True, True, 'cuda'), - ("log1p", positives, True, True, 'cpu'), - ("log1p", positives, True, True, 'cuda'), - ("log2", positives, True, True, 'cpu'), - ("log2", positives, True, True, 'cuda'), - ("neg", doubles, True, True, 'cpu'), - ("neg", doubles, True, True, 'cuda'), - ("reciprocal", doubles, True, True, 'cpu'), - ("reciprocal", doubles, True, True, 'cuda'), - ("round", doubles, True, True, 'cpu'), - ("round", doubles, True, True, 'cuda'), - ("rsqrt", positives, True, True, 'cpu'), - ("rsqrt", positives, True, True, 'cuda'), - ("sin", doubles, True, True, 'cpu'), - ("sin", doubles, True, True, 'cuda'), - ("sinh", doubles, True, True, 'cpu'), - ("sinh", doubles, False, True, 'cuda'), - ("sigmoid", doubles, True, True, 'cpu'), - ("sigmoid", doubles, True, True, 'cuda'), - ("logit", doubles, True, True, 'cpu'), - ("logit", doubles, True, True, 'cuda'), - ("sqrt", doubles, True, True, 'cpu'), - ("sqrt", doubles, False, True, 'cuda'), - ("tan", doubles, True, True, 'cpu'), - ("tan", doubles, True, True, 'cuda'), - ("tanh", doubles, True, True, 'cpu'), - ("tanh", doubles, True, True, 'cuda'), - ("trunc", doubles, True, True, 'cpu'), - ("trunc", doubles, True, True, 'cuda') - ] - - for (fn, inputs, has_input_output_mem_overlap_check, - has_internal_mem_overlap_check, dev) in unary_mem_overlap_cases: - if dev != device: - continue - out_fn = getattr(torch, fn) - in_fn = getattr(torch.Tensor, fn + '_') - - self.unary_check_input_output_mem_overlap(inputs, sz, out_fn, - expected_failure=not has_input_output_mem_overlap_check) - - self.check_internal_mem_overlap(in_fn, 1, dtype, dev, - expected_failure=not has_internal_mem_overlap_check) - - @dtypes(torch.double) - def test_binary_op_mem_overlap(self, device, dtype): - ops = [ - ("add", True, True, 'cpu'), - ("add", True, True, 'cuda'), - ("mul", True, True, 'cpu'), - ("mul", True, True, 'cuda'), - ("sub", True, True, 'cpu'), - ("sub", True, True, 'cuda'), - ("div", True, True, 'cpu'), - ("div", True, True, 'cuda'), - ("pow", True, True, 'cpu'), - ("pow", True, True, 'cuda'), - ("fmod", True, True, 'cpu'), - ("fmod", True, True, 'cuda'), - ("atan2", True, True, 'cpu'), - ("atan2", True, True, 'cuda'), - ("hypot", True, True, 'cpu'), - ("hypot", True, True, 'cuda'), - ("nextafter", True, True, 'cpu'), - ("nextafter", True, True, 'cuda'), - ("le", True, True, 'cpu'), - ("le", True, True, 'cuda'), - ("lt", True, True, 'cpu'), - ("lt", True, True, 'cuda'), - ("ge", True, True, 'cpu'), - ("ge", True, True, 'cuda'), - ("gt", True, True, 'cpu'), - ("gt", True, True, 'cuda'), - ("eq", True, True, 'cpu'), - ("eq", True, True, 'cuda'), - ("ne", True, True, 'cpu'), - ("ne", True, True, 'cuda'), - ("logical_and", True, True, 'cpu'), - ("logical_and", True, True, 'cuda'), - ("logical_or", True, True, 'cpu'), - ("logical_or", True, True, 'cuda'), - ("logical_xor", True, True, 'cpu'), - ("logical_xor", True, True, 'cuda'), - ] - - for (fn, has_input_output_mem_overlap_check, - has_internal_mem_overlap_check, dev) in ops: - if dev != device: - continue - out_op = getattr(torch, fn) - inplace_op = getattr(torch.Tensor, fn + '_') - self.check_internal_mem_overlap( - inplace_op, 2, dtype, device, - expected_failure=not has_internal_mem_overlap_check) - - self.binary_check_input_output_mem_overlap(out_op, device, - expected_failure=not has_input_output_mem_overlap_check) - @dtypes(torch.double) def test_ternary_op_mem_overlap(self, device, dtype): ops = [ @@ -13556,6 +4932,7 @@ def test_ternary_op_mem_overlap(self, device, dtype): expected_failure=not has_input_output_mem_overlap_check) @dtypes(torch.double) + @onlyOnCPUAndCUDA def test_copy_mem_overlap(self, device, dtype): self.check_internal_mem_overlap( torch.Tensor.copy_, num_inputs=2, dtype=dtype, device=device) @@ -13564,25 +4941,49 @@ def test_copy_mem_overlap(self, device, dtype): self.unary_check_input_output_mem_overlap( doubles, sz, lambda input, out: out.copy_(input)) - @dtypes(torch.double) - def test_pow_scalar_overloads_mem_overlap(self, device, dtype): - sz = 3 - doubles = torch.randn(2 * sz, dtype=dtype, device=device) - self.check_internal_mem_overlap( - lambda t: t.pow_(42), 1, dtype, device) - self.unary_check_input_output_mem_overlap( - doubles, sz, lambda input, out: torch.pow(input, 42, out=out)) - self.unary_check_input_output_mem_overlap( - doubles, sz, lambda input, out: torch.pow(42, input, out=out)) - + @onlyOnCPUAndCUDA def test_index_add_mem_overlap(self, device): x = torch.rand((1,), device=device).expand((6,)) y = torch.rand((6,), device=device) - ind = torch.tensor([0, 2, 3], device=device) + ind = torch.tensor([2, 1, 0], device=device) value = torch.rand((3,), device=device) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): x.index_add_(0, ind, value) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + y.index_add_(0, ind, y[:3]) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.index_add_(0, ind, ind.clone()) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.index_add_(0, ind.clone(), ind) + + @onlyOnCPUAndCUDA + def test_index_copy_mem_overlap(self, device): + x = torch.rand((1,), device=device).expand((6,)) + y = torch.rand((6,), device=device) + ind = torch.tensor([2, 1, 0], device=device) + value = torch.rand((3,), device=device) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + x.index_copy_(0, ind, value) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + y.index_copy_(0, ind, y[:3]) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.index_copy_(0, ind, ind.clone()) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.index_copy_(0, ind.clone(), ind) + + @onlyOnCPUAndCUDA + def test_index_fill_mem_overlap(self, device): + x = torch.rand((1,), device=device).expand((6,)) + y = torch.rand((6,), device=device) + ind = torch.tensor([2, 1, 0], device=device) + value = torch.rand((3,), device=device) + with self.assertWarnsRegex(UserWarning, "index_fill_ on expanded tensors"): + x.index_fill_(0, ind, 1.0) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.index_fill_(0, ind, 0) + + @onlyOnCPUAndCUDA def test_shift_mem_overlap(self, device): x = torch.rand(3, device=device) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): @@ -13590,6 +4991,7 @@ def test_shift_mem_overlap(self, device): with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): x[:-1] >>= x[1:] + @onlyOnCPUAndCUDA def test_bernoulli_mem_overlap(self, device): x = torch.rand((1,), device=device).expand((6,)) @@ -13603,16 +5005,26 @@ def test_bernoulli_mem_overlap(self, device): with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): torch.bernoulli(torch.rand_like(x), out=x) + @onlyOnCPUAndCUDA def test_index_put_mem_overlap(self, device): x = torch.rand((1,), device=device).expand((6,)) y = torch.rand((6,), device=device) - ind = torch.tensor([0, 2, 3], device=device) + ind = torch.tensor([2, 1, 0], device=device) value = torch.rand((3,), device=device) with self.assertWarnsRegex(UserWarning, 'expanded tensors'): x.index_put_((ind,), value) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): y.index_put_((ind,), y[0]) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.index_put_((ind,), ind) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + y.index_put_((ind,), y[:3]) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.index_put_((ind,), ind.clone()) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.index_put_((ind.clone(),), ind) + @onlyOnCPUAndCUDA def test_masked_fill_mem_overlap(self, device): x = torch.rand((1,), device=device).expand((6,)) mask = torch.tensor([True, False, True, True, False, False], device=device) @@ -13623,13 +5035,22 @@ def test_masked_fill_mem_overlap(self, device): with self.assertWarnsRegex(UserWarning, 'expanded tensors'): x.masked_fill_(mask, fill_val) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + mask[1:].masked_fill_(mask[:-1], False) + + @onlyOnCPUAndCUDA def test_masked_select_mem_overlap(self, device): x = torch.rand((1,), device=device).expand((3,)) y = torch.rand((6,), device=device) mask = torch.tensor([True, False, True, True, False, False], device=device) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): torch.masked_select(y, mask, out=x) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + torch.masked_select(y, mask, out=y) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + torch.masked_select(mask.clone(), mask, out=mask) + @onlyOnCPUAndCUDA def test_masked_scatter_mem_overlap(self, device): x = torch.rand((1,), device=device).expand((6,)) src = torch.rand((3,), device=device) @@ -13638,6 +5059,7 @@ def test_masked_scatter_mem_overlap(self, device): with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): x.masked_scatter_(mask, src) + @onlyOnCPUAndCUDA def test_index_select_mem_overlap(self, device): x = torch.rand((1, 6), device=device).expand((2, 6)) y = torch.rand((3, 6), device=device) @@ -13645,634 +5067,43 @@ def test_index_select_mem_overlap(self, device): with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): torch.index_select(y, 1, ind, out=x) - def test_cat_mem_overlap(self, device): - x = torch.rand((1, 3), device=device).expand((6, 3)) - y = torch.rand((3, 3), device=device) - with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): - torch.cat([y, y], out=x) - + @onlyOnCPUAndCUDA def test_scatter_mem_overlap(self, device): x = torch.rand((1,), device=device).expand((6,)) src = torch.rand((3,), device=device) - ind = torch.tensor([0, 2, 3], device=device, dtype=torch.int64) + ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): x.scatter_(0, ind, src) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + src.scatter_(0, ind, src) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + ind.scatter_(0, ind, ind.clone()) + @onlyOnCPUAndCUDA def test_gather_mem_overlap(self, device): x = torch.rand((1,), device=device).expand((3,)) src = torch.rand((6,), device=device) - ind = torch.tensor([0, 2, 3], device=device, dtype=torch.int64) + ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): torch.gather(src, 0, ind, out=x) - - def test_linlogspace_mem_overlap(self, device): - x = torch.rand(1, device=device).expand(10) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): - torch.linspace(1, 10, 10, out=x) - + torch.gather(src, 0, ind, out=src) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): - torch.logspace(1, 10, 10, out=x) - - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_int_pow(self, device): - - def _test_integral_pow(dt, range, dev): - tensor = torch.tensor((3, 3), dtype=dt, device=dev).random_(*range) - exps = [0, 1, 2, 4, - torch.tensor((3, 3), dtype=dt, device=dev).random_(0, 5)] - for exp in exps: - self._test_pow(tensor, exp) - - _test_integral_pow(torch.int8, (-3, 4), device) - _test_integral_pow(torch.uint8, (0, 4), device) - _test_integral_pow(torch.int16, (-5, 5), device) - _test_integral_pow(torch.int64, (-10, 10), device) - _test_integral_pow(torch.int32, (-10, 10), device) - - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_int_tensor_pow_neg_ints(self, device): - ints = [torch.iinfo(torch.int32).min, - -3, -2, -1, 0, 1, 2, 3, - torch.iinfo(torch.int32).max] - neg_ints = [torch.iinfo(torch.int32).min, -3, -2, -1] - tensor = torch.tensor(ints, dtype=torch.int32, device=device) - for pow in neg_ints: - self._test_pow(tensor, pow) - - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_long_tensor_pow_floats(self, device): - ints = [0, 1, 23, 4567] - floats = [0.0, 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0] - tensor = torch.tensor(ints, dtype=torch.int64, device=device) - for pow in floats: - self._test_pow(tensor, pow) - - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_float_scalar_pow_float_tensor(self, device): - floats = [2.0, -3 / 2, -1.0, -1 / 2, -1 / 3, 0.0, - 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0] - tensor = torch.tensor(floats, dtype=torch.float32, device=device) - for base in floats: - self._test_pow(base, tensor) - - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_tensor_pow_tensor(self, dev): - def rotate(l, n): - return l[-n:] + l[:-n] - - def test_tensor_pow_tensor(values, torch_type, numpy_type): - vals_tensor = torch.tensor(values, dtype=torch_type, device=dev) - for i in range(len(values)): - pows = rotate(values, i) - pows_tensor = torch.tensor(pows, dtype=torch_type, device=dev) - self._test_pow(vals_tensor, pows_tensor) - - ints = [0, 1, 2, 3] - test_tensor_pow_tensor(ints, torch.int32, np.int32) - test_tensor_pow_tensor(ints, torch.int64, np.int64) - - floats = [-3.0, -2.0, -1.0, -1 / 2, -1 / 3, - 0.0, - 1 / 3, 1 / 2, 1.0, 2.0, 3.0] - test_tensor_pow_tensor(floats, torch.float32, np.float32) - test_tensor_pow_tensor(floats, torch.float64, np.float64) - - @dtypes(torch.float) - def test_add_with_tail(self, device, dtype): - # test tensor where there is a tail which is not a multiple - # of GPU warp size - for tail_size in [1, 63, 67, 130]: - size = 4096 + tail_size - a = torch.randn(size, device=device, dtype=dtype) - b = torch.randn(size, device=device, dtype=dtype) - c = a + b - for x, y, z in zip(a.tolist(), b.tolist(), c.tolist()): - self.assertEqual(x + y, z) - - def test_logical_xor_with_nontrivial_alignment(self, device): - # test tensor that is not aligned to multiple of 16 bytes - size = 128 - a = (torch.randn(size, device=device) > 0) - b = (torch.randn(size, device=device) > 0) - c = (torch.randn(size, device=device) > 0) - non_trivial_alignment = [1, 2, 4, 8, 15] - for i in non_trivial_alignment: - for j in non_trivial_alignment: - for k in non_trivial_alignment: - a_ = a[i: 100 + i] - b_ = b[j: 100 + j] - c_ = c[k: 100 + k] - torch.logical_xor(a_, b_, out=c_) - for x, y, z in zip(a_.tolist(), b_.tolist(), c_.tolist()): - self.assertEqual(x ^ y, z) - - def test_var_mean_some_dims(self, device): - sizes = (4, 6, 7, 5, 3) - dims = len(sizes) - - x = torch.rand(sizes, device=device) - for num_of_dims in range(2, dims): - dim_list = list(combinations(list(range(dims)), r=num_of_dims)) - for dim in dim_list: - for unbiased in [False, True]: - for keepdim in [False, True]: - var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) - var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim) - mean2 = x.mean(dim=dim, keepdim=keepdim) - self.assertEqual(var1, var2) - self.assertEqual(mean1, mean2) - - @skipCUDAIfRocm - def test_blas_empty(self, device): - - def fn(torchfn, *args, test_out=False, **kwargs): - def call_torch_fn(*args, **kwargs): - return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape - for shape in args), **kwargs) - result = call_torch_fn(*args, **kwargs) - if not test_out: - return result - else: - out = torch.full_like(result, math.nan) - out1 = call_torch_fn(*args, **kwargs, out=out) - return out - - # mm, addmm - self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape) - self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape) - self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape) - self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape) - self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6))) - self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6), test_out=True)) - - self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape) - self.assertEqual((0, 1), fn(torch.addmm, (1, ), (0, 17), (17, 1)).shape) - t = torch.randn((5, 6), device=device) - self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6))) - self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True)) - - # mv, addmv - self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape) - self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape) - self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,))) - self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), test_out=True)) - - self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape) - t = torch.randn((3,), device=device) - self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,))) - self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True)) - - # bmm, baddbmm - self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape) - self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape) - self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape) - self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6))) - self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True)) - - self.assertEqual((0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape) - self.assertEqual((3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape) - self.assertEqual((0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape) - self.assertEqual((3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape) - c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5) - self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2)) # Issue #33467 - self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True)) # Issue #33467 - - # addbmm - self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape) - self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape) - t = torch.randn((5, 6), device=device) - self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6))) - self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True)) - - # matmul - self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,))) - self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,), test_out=True)) - self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape) - self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape) - self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape) - self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4))) - self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True)) - - # dot - self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,))) - self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,), test_out=True)) - - if torch._C.has_lapack: - # lu - A_LU, pivots = fn(torch.lu, (0, 5, 5)) - self.assertEqual([(0, 5, 5), (0, 5)], [A_LU.shape, pivots.shape]) - A_LU, pivots = fn(torch.lu, (0, 0, 0)) - self.assertEqual([(0, 0, 0), (0, 0)], [A_LU.shape, pivots.shape]) - A_LU, pivots = fn(torch.lu, (2, 0, 0)) - self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape]) - - @skipCUDAIfRocm - @dtypesIfCUDA(*(torch.float, torch.double, torch.cfloat, torch.cdouble) + - # This test is disabled on CUDA 9, due to: - # See: https://github.com/pytorch/pytorch/issues/31006 - ((torch.half,) if torch.version.cuda and not torch.version.cuda.startswith('9.') else ())) - @dtypes(*(set(torch.testing.get_all_dtypes()) - {torch.half, torch.bool})) - def test_blas_alpha_beta_empty(self, device, dtype): - if dtype is torch.bfloat16 and self.device_type == 'xla': - # TODO (@zasdfgbnm): this causes the following error on test - # TestTorchDeviceTypeXLA.test_blas_alpha_beta_empty_xla_bfloat16: - # - # RuntimeError: _th_equal not supported on CPUType for BFloat16 - return - # ensure beta is respected - value = 11 - input = torch.full((2,), value, dtype=dtype, device=device) - mat = torch.ones((2, 0), dtype=dtype, device=device) - vec = torch.ones((0,), dtype=dtype, device=device) - out = torch.empty((2,), dtype=dtype, device=device) - if dtype.is_complex: - alpha = 6 + 7j - beta = 3 + 4j - else: - alpha = 6 - beta = 3 - self.assertEqual(torch.full((2,), beta * value, dtype=dtype, device=device), - torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta)) - self.assertEqual(torch.full((2,), beta * value, dtype=dtype, device=device), - torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta, out=out)) - - # TODO: update this once torch.addmm is supported for complex - if dtype.is_complex and device != 'cpu': - return - - # torch.addmm - input = torch.full((2, 3), value, dtype=dtype, device=device) - mat2 = torch.ones((0, 3), dtype=dtype, device=device) - out = torch.empty((2, 3), dtype=dtype, device=device) - self.assertEqual(torch.full((2, 3), beta * value, dtype=dtype, device=device), - torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta)) - self.assertEqual(torch.full((2, 3), beta * value, dtype=dtype, device=device), - torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta, out=out)) - - @dtypes(*(torch.testing.get_all_complex_dtypes() + torch.testing.get_all_fp_dtypes())) - def test_blas_nan_out(self, device, dtype): - # These functions should work correctly with NaN filled outputs, - # but need special handling, see [NOTE: cpu_zero] - b = 3 - n = 5 - m = 7 - p = 11 - - # torch.mv - nm = torch.randn((m, n), device=device).t() - _m = torch.randn((), device=device).expand(m) - _m_out = torch.full((m,), float('nan'), device=device) - self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out)) - self.assertEqual(0, torch.isnan(torch.mv(nm, _m)).sum()) - - # torch.mm - mp = torch.randn((p, m), device=device).t() - np_out = torch.full((n, p), float('nan'), device=device) - self.assertEqual(torch.mm(nm, mp), torch.mm(nm, mp, out=np_out)) - - if dtype.is_complex and device.startswith('cuda'): - return - - # torch.bmm - bnm = torch.randn((b, m, n), device=device).transpose(1, 2) - bmp = torch.randn((b, p, m), device=device).transpose(1, 2) - bnp_out = torch.full((b, n, p), float('nan'), device=device) - self.assertEqual(torch.bmm(bnm, bmp), torch.bmm(bnm, bmp, out=bnp_out)) - - @onlyCPU # not supported by CUBLAS - def test_blas_mv_large_input(self, device): - # This would previously fail if the allocated output had NaNs, see: - # https://github.com/pytorch/pytorch/issues/31663 and [NOTE: cpu_zero] - n = 3000 - m = 200 - - nm = torch.randn((m, n), device=device).t() - _m = torch.randn((), device=device).expand(m) - _m_out = torch.full((m,), 0., device=device) - - self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out)) - - @skipCUDAIfRocm - def test_unique_dim(self, device): - self.assertFalse(hasattr(torch, 'unique_dim')) - - def run_test(device, dtype): - x = torch.tensor([[[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]], - [[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]]], - dtype=dtype, - device=device) - x_empty = torch.empty(5, 0, dtype=dtype, device=device) - x_ill_formed_empty = torch.empty(5, 0, 0, dtype=dtype, device=device) - x_ill_formed_empty_another = torch.empty(5, 0, 5, dtype=dtype, device=device) - expected_unique_dim0 = torch.tensor([[[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]]], - dtype=dtype, - device=device) - expected_inverse_dim0 = torch.tensor([0, 0]) - expected_counts_dim0 = torch.tensor([2]) - expected_unique_dim1 = torch.tensor([[[0., 1.], - [1., 1.], - [2., 1.]], - [[0., 1.], - [1., 1.], - [2., 1.]]], - dtype=dtype, - device=device) - expected_unique_dim1_bool = torch.tensor([[[False, True], [True, True]], - [[False, True], [True, True]]], - dtype=torch.bool, - device=device) - expected_inverse_dim1 = torch.tensor([1, 0, 2, 0]) - expected_inverse_dim1_bool = torch.tensor([1, 0, 1, 0]) - expected_counts_dim1 = torch.tensor([2, 1, 1]) - expected_counts_dim1_bool = torch.tensor([2, 2]) - expected_unique_dim2 = torch.tensor([[[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]], - [[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]]], - dtype=dtype, - device=device) - expected_inverse_dim2 = torch.tensor([0, 1]) - expected_counts_dim2 = torch.tensor([1, 1]) - expected_unique_empty = torch.tensor([], dtype=dtype, device=device) - expected_inverse_empty = torch.tensor([], dtype=torch.long, device=device) - expected_counts_empty = torch.tensor([], dtype=torch.long, device=device) - # dim0 - x_unique = torch.unique(x, dim=0) - self.assertEqual(expected_unique_dim0, x_unique) - - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - dim=0) - self.assertEqual(expected_unique_dim0, x_unique) - self.assertEqual(expected_inverse_dim0, x_inverse) - - x_unique, x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=0) - self.assertEqual(expected_unique_dim0, x_unique) - self.assertEqual(expected_counts_dim0, x_counts) - - x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=0) - self.assertEqual(expected_unique_dim0, x_unique) - self.assertEqual(expected_inverse_dim0, x_inverse) - self.assertEqual(expected_counts_dim0, x_counts) - - # dim1 - x_unique = torch.unique(x, dim=1) - if x.dtype == torch.bool: - self.assertEqual(expected_unique_dim1_bool, x_unique) - else: - self.assertEqual(expected_unique_dim1, x_unique) - - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - dim=1) - if x.dtype == torch.bool: - self.assertEqual(expected_unique_dim1_bool, x_unique) - self.assertEqual(expected_inverse_dim1_bool, x_inverse) - else: - self.assertEqual(expected_unique_dim1, x_unique) - self.assertEqual(expected_inverse_dim1, x_inverse) - - x_unique, x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=1) - if x.dtype == torch.bool: - self.assertEqual(expected_unique_dim1_bool, x_unique) - self.assertEqual(expected_counts_dim1_bool, x_counts) - else: - self.assertEqual(expected_unique_dim1, x_unique) - self.assertEqual(expected_counts_dim1, x_counts) - - x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=1) - if x.dtype == torch.bool: - self.assertEqual(expected_unique_dim1_bool, x_unique) - self.assertEqual(expected_inverse_dim1_bool, x_inverse) - self.assertEqual(expected_counts_dim1_bool, x_counts) - else: - self.assertEqual(expected_unique_dim1, x_unique) - self.assertEqual(expected_inverse_dim1, x_inverse) - self.assertEqual(expected_counts_dim1, x_counts) - - # dim2 - x_unique = torch.unique(x, dim=2) - self.assertEqual(expected_unique_dim2, x_unique) - - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - dim=2) - self.assertEqual(expected_unique_dim2, x_unique) - self.assertEqual(expected_inverse_dim2, x_inverse) - - x_unique, x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=2) - self.assertEqual(expected_unique_dim2, x_unique) - self.assertEqual(expected_counts_dim2, x_counts) - - x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=2) - self.assertEqual(expected_unique_dim2, x_unique) - self.assertEqual(expected_inverse_dim2, x_inverse) - self.assertEqual(expected_counts_dim2, x_counts) - - # test empty tensor - x_unique, x_inverse, x_counts = torch.unique( - x_empty, - return_inverse=True, - return_counts=True, - dim=1) - self.assertEqual(expected_unique_empty, x_unique) - self.assertEqual(expected_inverse_empty, x_inverse) - self.assertEqual(expected_counts_empty, x_counts) - - # test not a well formed tensor - # Checking for runtime error, as this is the expected behaviour - with self.assertRaises(RuntimeError): - torch.unique( - x_ill_formed_empty, - return_inverse=True, - return_counts=True, - dim=1) + torch.gather(ind.clone(), 0, ind[1:], out=ind[:1]) - # test along dim2 - with self.assertRaises(RuntimeError): - torch.unique( - x_ill_formed_empty_another, - return_inverse=True, - return_counts=True, - dim=2) - - # test consecutive version - y = torch.tensor( - [[0, 1], - [0, 1], - [0, 1], - [1, 2], - [1, 2], - [3, 4], - [0, 1], - [0, 1], - [3, 4], - [1, 2]], - dtype=dtype, - device=device - ) - expected_y_unique = torch.tensor( - [[0, 1], - [1, 2], - [3, 4], - [0, 1], - [3, 4], - [1, 2]], - dtype=dtype, - device=device - ) - expected_y_inverse = torch.tensor([0, 0, 0, 1, 1, 2, 3, 3, 4, 5], dtype=torch.int64, device=device) - expected_y_counts = torch.tensor([3, 2, 1, 2, 1, 1], dtype=torch.int64, device=device) - expected_y_inverse_bool = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 3, 3], dtype=torch.int64, device=device) - expected_y_counts_bool = torch.tensor([3, 3, 2, 2], dtype=torch.int64, device=device) - y_unique, y_inverse, y_counts = torch.unique_consecutive(y, return_inverse=True, return_counts=True, dim=0) - if x.dtype == torch.bool: - self.assertEqual(expected_y_inverse_bool, y_inverse) - self.assertEqual(expected_y_counts_bool, y_counts) - else: - self.assertEqual(expected_y_inverse, y_inverse) - self.assertEqual(expected_y_counts, y_counts) - - run_test(device, torch.float) - run_test(device, torch.double) - run_test(device, torch.long) - run_test(device, torch.uint8) - run_test(device, torch.bool) - - # Tests that CUDA tensors on different devices cannot be used in the same - # binary operation, and that CUDA "scalars" cannot be used in the same - # binary operation as non-scalar CPU tensors. - @deviceCountAtLeast(2) - @onlyCUDA - def test_cross_device_binary_ops(self, devices): - vals = (1., (2.,)) - cpu_tensor = torch.randn(2, 2) - for op in (operator.add, torch.add, - operator.sub, torch.sub, - operator.mul, torch.mul, - operator.truediv, torch.true_divide, - operator.floordiv, torch.floor_divide): - for a, b in product(vals, vals): - a = torch.tensor(a, device=devices[0]) - b = torch.tensor(b, device=devices[1]) - - with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"): - op(a, b) - with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"): - op(b, a) - with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"): - op(a, cpu_tensor) - with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"): - op(cpu_tensor, a) - - # This test ensures that a scalar Tensor can be safely used - # in a binary operation in conjunction with a Tensor on all - # available CUDA devices - @deviceCountAtLeast(2) - @onlyCUDA - def test_binary_op_scalar_device_unspecified(self, devices): - scalar_val = torch.tensor(1.) - for default_device in devices: - with torch.cuda.device(default_device): - for device in devices: - device_obj = torch.device(device) - x = torch.rand(3, device=device) - y0 = x * scalar_val - self.assertEqual(y0.device, device_obj) - y1 = scalar_val * x - self.assertEqual(y1.device, device_obj) - self.assertEqual(y0, y1) - - # Tests that CPU scalars (including zero dim tensors) can be used in - # binary operations with CUDA tensors. - @onlyCUDA - def test_cuda_cpu_scalar_binary_ops(self, device): - val_scalar = math.pi - val_tensor = torch.tensor(val_scalar) - for op in (operator.add, torch.add, - operator.sub, torch.sub, - operator.mul, torch.mul, - operator.truediv, torch.true_divide, - operator.floordiv, torch.floor_divide): - for tensor_val in (1, (1,)): - t_cuda = torch.tensor(tensor_val, device=device) - t_cpu = t_cuda.cpu() - for val in (val_scalar, val_tensor): - cpu_result = op(t_cpu, val) - cuda_result = op(t_cuda, val) - self.assertEqual(cpu_result, cuda_result) - - reverse_cpu_result = op(val, t_cpu) - reverse_cuda_result = op(val, t_cuda) - self.assertEqual(reverse_cpu_result, reverse_cuda_result) + @onlyOnCPUAndCUDA + def test_take_mem_overlap(self, device): + x = torch.rand((1,), device=device).expand((3,)) + src = torch.rand((6,), device=device) + ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + torch.take(src, ind, out=x) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + torch.take(src, ind, out=src) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + torch.take(ind.clone(), ind[1:], out=ind[:-1]) - @onlyCUDA - def test_ceil_out_mismatch(self, device): - a = torch.randn(1) - b = torch.randn(1, device=device) - self.assertRaises(RuntimeError, lambda: torch.ceil(a, out=b)) - - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_has_storage_numpy(self, device): - for dtype in [np.float32, np.float64, np.int64, - np.int32, np.int16, np.uint8]: - arr = np.array([1], dtype=dtype) - self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.float32).storage()) - self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.double).storage()) - self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.int).storage()) - self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.long).storage()) - self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.uint8).storage()) - - def test_all_any_empty(self, device): - x = torch.ByteTensor().to(device) - self.assertTrue(x.all()) - self.assertFalse(x.any()) - - x = torch.BoolTensor().to(device) - self.assertTrue(x.all()) - self.assertFalse(x.any()) @onlyCUDA def test_multinomial_device_constrain(self, device): @@ -14379,61 +5210,6 @@ def inplace(): x.to(y) x.to(x, copy=True) - @onlyCPU - def test_renorm_ps(self, device): - # full reduction - x = torch.randn(5, 5) - xn = x.numpy() - for p in [1, 2, 3, 4, inf]: - res = x.renorm(p, 1, 1) - expected = x / x.norm(p, 0, keepdim=True).clamp(min=1) - self.assertEqual(res, expected, msg="renorm failed for {}-norm".format(p)) - - @onlyCUDA - def test_topk_noncontiguous_gpu(self, device): - t = torch.randn(20, device=device)[::2] - top1, idx1 = t.topk(5) - top2, idx2 = t.contiguous().topk(5) - self.assertEqual(top1, top2) - self.assertEqual(idx1, idx2) - - @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64) - def test_topk_integral(self, device, dtype): - a = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max, size=(10,), - dtype=dtype, device=device) - sort_topk = a.sort()[0][-5:].flip(0) - topk = a.topk(5) - self.assertEqual(sort_topk, topk[0]) # check values - self.assertEqual(sort_topk, a[topk[1]]) # check indices - - @dtypesIfCUDA(*([torch.half, torch.float, torch.double] - + ([torch.bfloat16] if TEST_WITH_ROCM else []))) - @dtypes(torch.float, torch.double) - def test_topk_nonfinite(self, device, dtype): - x = torch.tensor([float('nan'), float('inf'), 1e4, 0, -1e4, -float('inf')], device=device, dtype=dtype) - val, idx = x.topk(4) - expect = torch.tensor([float('nan'), float('inf'), 1e4, 0], device=device, dtype=dtype) - self.assertEqual(val, expect) - self.assertEqual(idx, [0, 1, 2, 3]) - - val, idx = x.topk(4, largest=False) - expect = torch.tensor([-float('inf'), -1e4, 0, 1e4], device=device, dtype=dtype) - self.assertEqual(val, expect) - self.assertEqual(idx, [5, 4, 3, 2]) - - def test_topk_4d(self, device): - x = torch.ones(2, 3072, 2, 2, device=device) - x[:, 1, :, :] *= 2. - x[:, 10, :, :] *= 1.5 - val, ind = torch.topk(x, k=2, dim=1) - expected_ind = torch.ones(2, 2, 2, 2, dtype=torch.long, device=device) - expected_ind[:, 1, :, :] = 10 - expected_val = torch.ones(2, 2, 2, 2, device=device) - expected_val[:, 0, :, :] *= 2. - expected_val[:, 1, :, :] *= 1.5 - self.assertEqual(val, expected_val, atol=0, rtol=0) - self.assertEqual(ind, expected_ind, atol=0, rtol=0) - def test_is_signed(self, device): self.assertEqual(torch.IntTensor(5).to(device).is_signed(), True) self.assertEqual(torch.ByteTensor(5).to(device).is_signed(), False) @@ -14499,31 +5275,6 @@ def test_memory_format_preserved_after_permute(self, device): y = ndhwc.permute(0, 1, 4, 3, 2).permute(0, 1, 4, 3, 2) self.assertTrue(y.is_contiguous(memory_format=torch.channels_last_3d)) - def test_resize_as_preserves_strides(self, device): - x = torch.empty(2, 3).t() - old_strides = x.stride() - x.resize_as_(x) - self.assertEqual(x.stride(), old_strides) - - def test_memory_format_resize_as(self, device): - def test_helper(shape, memory_format, device): - xc = torch.randn(shape, device=device).contiguous(memory_format=memory_format) - flat = torch.randn(xc.numel(), device=device) - flat.resize_as_(xc, memory_format=torch.preserve_format) - self.assertTrue(flat.is_contiguous(memory_format=memory_format)) - - test_helper((10, 3, 32, 32), torch.channels_last, device) - test_helper((3, 10, 3, 32, 32), torch.channels_last_3d, device) - - def test_memory_format_resize_(self, device): - def test_helper(shape, numel, memory_format, device): - flat = torch.randn(numel, device=device) - flat.resize_(shape, memory_format=memory_format) - self.assertTrue(flat.is_contiguous(memory_format=memory_format)) - - test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device) - test_helper((3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device) - def test_memory_format_propagation_rules(self, device): contiguous = torch.rand(10, 3, 5, 5, device=device) @@ -14677,7 +5428,7 @@ def _test_helper(x, y, bias, memory_format): lambda x, y: x.expm1_(), lambda x, y: x.floor(), lambda x, y: x.floor_(), - # lambda x, y: x.fmod(2), # https://github.com/pytorch/pytorch/issues/24565 + lambda x, y: x.fmod(2), lambda x, y: x.frac(), lambda x, y: x.hypot(y), lambda x, y: x.hypot_(y), @@ -14816,213 +5567,6 @@ def compare_strides(s1, s2, div): for x in xs: _test_helper(x, op, unary=True) - - def _test_unique_scalar_empty(self, dtype, device, f): - # test scalar - x = torch.tensor(0, dtype=dtype, device=device) - unique, inverse, counts = f(x, return_inverse=True, return_counts=True) - expected_unique = torch.tensor([0], dtype=dtype, device=device) - expected_inverse = torch.tensor(0, device=device) - expected_counts = torch.tensor([1], device=device) - self.assertEqual(unique, expected_unique) - self.assertEqual(inverse, expected_inverse) - self.assertEqual(counts, expected_counts) - - # test zero sized tensor - x = torch.zeros((0, 0, 3), dtype=dtype, device=device) - unique, inverse, counts = f(x, return_inverse=True, return_counts=True) - expected_unique = torch.tensor([], dtype=dtype, device=device) - expected_inverse = torch.empty((0, 0, 3), dtype=torch.long, device=device) - expected_counts = torch.tensor([], dtype=torch.long, device=device) - self.assertEqual(unique, expected_unique) - self.assertEqual(inverse, expected_inverse) - self.assertEqual(counts, expected_counts) - - def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape): - def ensure_tuple(x): - if isinstance(x, torch.Tensor): - return (x,) - return x - - for return_inverse in [True, False]: - for return_counts in [True, False]: - # test with expected - ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts)) - self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts)) - self.assertEqual(expected_unique, ret[0]) - if return_inverse: - self.assertEqual(expected_inverse, ret[1]) - if return_counts: - count_index = 1 + int(return_inverse) - self.assertEqual(expected_counts, ret[count_index]) - - # tests per-element unique on a higher rank tensor. - y = x.view(additional_shape) - y_unique, y_inverse, y_counts = f(y, return_inverse=True, return_counts=True) - self.assertEqual(expected_unique, y_unique) - self.assertEqual(expected_inverse.view(additional_shape), y_inverse) - self.assertEqual(expected_counts, y_counts) - - @dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128}) - def test_unique(self, device, dtype): - if dtype is torch.half and self.device_type == 'cpu': - return # CPU does not have half support - - def ensure_tuple(x): - if isinstance(x, torch.Tensor): - return (x,) - return x - - if dtype is torch.bool: - x = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device) - expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device) - expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device) - expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device) - else: - x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device) - expected_unique = torch.tensor([1, 2, 3, 5, 8], dtype=dtype, device=device) - expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device) - expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device) - - # test sorted unique - fs = [ - lambda x, **kwargs: torch.unique(x, sorted=True, **kwargs), - lambda x, **kwargs: x.unique(sorted=True, **kwargs), - ] - for f in fs: - self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (2, 2, 2)) - self._test_unique_scalar_empty(dtype, device, f) - - # test unsorted unique - fs = [ - lambda x, **kwargs: torch.unique(x, sorted=False, **kwargs), - lambda x, **kwargs: x.unique(sorted=False, **kwargs) - ] - for f in fs: - self._test_unique_scalar_empty(dtype, device, f) - for return_inverse in [True, False]: - for return_counts in [True, False]: - ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts)) - self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts)) - x_list = x.tolist() - x_unique_list = ret[0].tolist() - self.assertEqual(expected_unique.tolist(), sorted(x_unique_list)) - if return_inverse: - x_inverse_list = ret[1].tolist() - for i, j in enumerate(x_inverse_list): - self.assertEqual(x_list[i], x_unique_list[j]) - if return_counts: - count_index = 1 + int(return_inverse) - x_counts_list = ret[count_index].tolist() - for i, j in zip(x_unique_list, x_counts_list): - count = 0 - for k in x_list: - if k == i: - count += 1 - self.assertEqual(j, count) - - @dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128}) - def test_unique_consecutive(self, device, dtype): - if dtype is torch.half and self.device_type == 'cpu': - return # CPU does not have half support - - if dtype is torch.bool: - x = torch.tensor([True, False, False, False, True, True, False, False, False], dtype=torch.bool, device=device) - expected_unique = torch.tensor([True, False, True, False], dtype=torch.bool, device=device) - expected_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 3], dtype=torch.long, device=device) - expected_counts = torch.tensor([1, 3, 2, 3], dtype=torch.long, device=device) - else: - x = torch.tensor([1, 2, 2, 2, 5, 5, 2, 2, 3], dtype=dtype, device=device) - expected_unique = torch.tensor([1, 2, 5, 2, 3], dtype=dtype, device=device) - expected_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4], device=device) - expected_counts = torch.tensor([1, 3, 2, 2, 1], device=device) - - for f in [torch.unique_consecutive, lambda x, **kwargs: x.unique_consecutive(**kwargs)]: - self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (3, 3)) - self._test_unique_scalar_empty(dtype, device, f) - - @dtypesIfCUDA(torch.half, torch.float, torch.double) - @dtypes(torch.float, torch.double) - def test_erfinv(self, device, dtype): - # general testing. Narrow the range to avoid accuracy issues - input_values = torch.randn(4, 4, dtype=dtype, device=device).clamp(-0.3, 0.3) - self.assertEqual(input_values.erf().erfinv(), input_values) - # test inf - self.assertTrue(torch.equal(torch.tensor([-1, 1], dtype=dtype, device=device).erfinv(), - torch.tensor([-inf, inf], dtype=dtype, device=device))) - # test nan - self.assertEqual(torch.tensor([-2, 2], dtype=dtype, device=device).erfinv(), - torch.tensor([nan, nan], dtype=dtype, device=device)) - - if dtype == torch.double: - # double precision - a = torch.tensor([0.5, 0.8], dtype=torch.double, device=device).erfinv() - self.assertEqual(a[0].item(), 0.47693627620447, atol=1e-13, rtol=0) - self.assertEqual(a[1].item(), 0.90619380243682, atol=1e-13, rtol=0) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_ctor_with_numpy_array(self, device): - correct_dtypes = [ - np.double, - np.float, - np.float16, - np.int64, - np.int32, - np.int16, - np.int8, - np.uint8, - np.bool, - ] - - incorrect_byteorder = '>' if sys.byteorder == 'little' else '<' - incorrect_dtypes = map(lambda t: incorrect_byteorder + t, ['d', 'f']) - - for dtype in correct_dtypes: - array = np.array([1, 2, 3, 4], dtype=dtype) - - # Upcast - tensor = torch.DoubleTensor(array).to(device) - for i in range(len(array)): - self.assertEqual(tensor[i], array[i]) - - # Downcast (sometimes) - tensor = torch.FloatTensor(array).to(device) - for i in range(len(array)): - self.assertEqual(tensor[i], array[i]) - - tensor = torch.HalfTensor(array).to(device) - for i in range(len(array)): - self.assertEqual(tensor[i], array[i]) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - @dtypes(*torch.testing.get_all_dtypes()) - def test_numpy_scalar_cmp(self, device, dtype): - if dtype.is_complex: - tensors = (torch.tensor(complex(1, 3), dtype=dtype, device=device), - torch.tensor([complex(1, 3), 0, 2j], dtype=dtype, device=device), - torch.tensor([[complex(3, 1), 0], [-1j, 5]], dtype=dtype, device=device)) - else: - tensors = (torch.tensor(3, dtype=dtype, device=device), - torch.tensor([1, 0, -3], dtype=dtype, device=device), - torch.tensor([[3, 0, -1], [3, 5, 4]], dtype=dtype, device=device)) - - for tensor in tensors: - if dtype == torch.bfloat16: - with self.assertRaises(TypeError): - np_array = tensor.cpu().numpy() - continue - - np_array = tensor.cpu().numpy() - for t, a in product((tensor.flatten()[0], tensor.flatten()[0].item()), - (np_array.flatten()[0], np_array.flatten()[0].item())): - self.assertEqual(t, a) - if dtype == torch.complex64 and torch.is_tensor(t) and type(a) == np.complex64: - # TODO: Imaginary part is dropped in this case. Need fix. - # https://github.com/pytorch/pytorch/issues/43579 - self.assertFalse(t == a) - else: - self.assertTrue(t == a) - def test_dlpack_conversion(self, device): x = torch.randn(1, 2, 3, 4, device=device, dtype=torch.float) z = from_dlpack(to_dlpack(x)) @@ -15073,415 +5617,8 @@ def test_storage_multigpu(self, devices): x = torch.tensor([], device=device) self.assertEqual(x.dtype, x.storage().dtype) - @onlyCPU - @skipCPUIfNoLapack - def test_orgqr_errors(self, device): - test_cases = [ - # input1 size, input2 size, error regex - ((10,), (2,), r"'input' should be 2 dimensional"), - ((10, 6), (20,), r"input.size\(1\) must be greater than or equal to input2.size\(0\)"), - ((6, 10), (5,), r"input.size\(0\) must be greater than or equal to input.size\(1\)"), - ((0, 0), (0,), r"'input' should not be empty") - ] - for a_size, tau_size, error_regex in test_cases: - a = torch.rand(*a_size, device=device) - tau = torch.rand(*tau_size, device=device) - with self.assertRaisesRegex(RuntimeError, error_regex): - torch.orgqr(a, tau) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_lu(self, device): - from torch.testing._internal.common_utils import random_matrix - - def run_test(device, pivot): - def run_subtest(matrix_size, batches, device, pivot, singular=False, a=None): - if isinstance(matrix_size, int): - rows = columns = matrix_size - else: - rows, columns = matrix_size - if a is None: - a = random_matrix(rows, columns, *batches, **dict(singular=singular)).to(device) - a_LU_info, pivots_info, info_ = a.lu(pivot=pivot, get_infos=True) - self.assertEqual(a_LU_info.size(), torch.Size(batches + (rows, columns))) - self.assertEqual(pivots_info.size(), torch.Size(batches + (min(rows, columns),))) - self.assertEqual(info_.size(), torch.Size(batches)) - # If a randomly generated input matrix is singular, - # then info_ contains indices i such that U[i, i] == - # 0. This however conveys that the factorization was - # successful albeit with a singular input. Therefore, - # we require info.min() >= 0 - self.assertGreaterEqual(info_.min(), 0) - a_LU, pivots = a.lu(pivot=pivot) - self.assertEqual(a_LU, a_LU_info) - self.assertEqual(pivots_info, pivots) - - P, L, U = torch.lu_unpack(a_LU, pivots) - self.assertEqual(P.matmul(L.matmul(U)), a) - - if self.device_type == 'cuda': - # lu without pivoting is implemented only for cuda device - a_LU_info_nopiv, nopiv, info_nopiv = a.lu(pivot=False, get_infos=True) - P_nopiv, L_nopiv, U_nopiv = torch.lu_unpack(a_LU_info_nopiv, nopiv) - self.assertEqual(P_nopiv.matmul(L_nopiv.matmul(U_nopiv)), a) - k = min(rows, columns) - self.assertEqual(nopiv, torch.arange(1, 1 + k, device=device, dtype=torch.int32).expand(a.shape[:-2] + (k, ))) - if not singular: - # It is not guaranteed that LU factorization - # without pivoting is able to determine if a - # matrix is singular while LU factorization - # with pivoting is. Therefore, we require the - # equality of info-s only for non-singular - # matrices. - self.assertEqual(info_, info_nopiv) - - for ms, batch in product([3, 5, 7, (4, 2), (3, 4)], [(), (2,), (3,), (3, 5)]): - run_subtest(ms, batch, device, pivot) - run_subtest(ms, batch, device, pivot, singular=True) - - # Reproducer of a magma bug, see https://bitbucket.org/icl/magma/issues/13/getrf_batched-kernel-produces-nans-on - a = torch.ones(batch + (ms if isinstance(ms, tuple) else (ms, ms)), dtype=torch.double, device=device) - run_subtest(ms, batch, device, pivot, singular=True, a=a) - - # Info should be positive for rank deficient matrices - a = torch.ones(5, 3, 3, device=device) - self.assertGreater(a.lu(pivot=pivot, get_infos=True)[2][0], 0) - - run_test(device, True) - - if self.device_type == 'cpu': - # Error checking, no pivoting variant on CPU - with self.assertRaisesRegex(RuntimeError, 'lu without pivoting is not implemented on the CPU'): - torch.lu(torch.empty(1, 2, 2), pivot=False) - else: - run_test(device, False) - - @skipCPUIfNoLapack - @skipCUDAIfNoMagma - @dtypes(torch.double) - def test_lu_unpack(self, device, dtype): - def run_test(pivot): - for shape in ((3, 3), (5, 3, 3), (7, 3, 5, 5), (7, 5, 3, 3, 3)): - a = torch.randn(*shape, dtype=dtype, device=device) - a_lu, p = torch.lu(a, pivot=pivot) - p_ref, l_ref, u_ref = torch.lu_unpack(a_lu, p) - self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a) - - run_test(True) - - if self.device_type == 'cuda': - run_test(False) - - @dtypesIfCUDA(torch.half, torch.float, torch.double) - @dtypes(torch.float, torch.double) - def test_max_with_inf(self, device, dtype): - a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device) - self.assertTrue(torch.all(torch.max(a, dim=1).values == inf).item()) - self.assertTrue(torch.all(torch.amax(a, dim=1) == inf).item()) - self.assertTrue(torch.max(a).item() == inf) - self.assertTrue(torch.amax(a).item() == inf) - - @dtypesIfCUDA(torch.half, torch.float, torch.double) - @dtypes(torch.float, torch.double) - def test_min_with_inf(self, device, dtype): - a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device) - self.assertTrue(torch.all(torch.min(a, dim=1).values == (-inf)).item()) - self.assertTrue(torch.all(torch.amin(a, dim=1) == (-inf)).item()) - self.assertTrue(torch.min(a).item() == -inf) - self.assertTrue(torch.amin(a).item() == -inf) - - def _test_minmax_helper(self, torchfn, reffn, device, dtype, skip_indices=False): - def create_input(shape, device, dtype): - if dtype.is_floating_point: - return torch.randn(*shape, device=device, dtype=dtype) - else: - low = 0 if dtype == torch.bool else -1000 - high = 2 if dtype == torch.bool else 1000 - return torch.randint(low, high, shape, device=device, dtype=dtype) - x = create_input((100, 100), device, dtype) - self.compare_with_numpy(torchfn, reffn, x) - # non contiguous - x = create_input((10, 10, 10), device, dtype) - x = x[:, 4] - self.compare_with_numpy(torchfn, reffn, x) - - def get_values(x): - if istuple(x): - return x[0] - return x - - # indices - if not skip_indices: - size = 5 - x = create_input((size, size), device, dtype) - inputs = (x, x.t()) - dims = (0, 1) - for xinp, d in product(inputs, dims): - self.compare_with_numpy(lambda x: get_values(torchfn(x, d, False)), lambda x: reffn(x, d, keepdims=False), xinp) - result = torchfn(xinp, d, False) - if istuple(result): - v, i = result - if d == 1: - self.assertEqual(xinp[torch.arange(size), i], v, atol=0, rtol=0) - else: - self.assertEqual(xinp[i, torch.arange(size)], v, atol=0, rtol=0) - # nan - if dtype.is_floating_point: - for index in (0, 4, 99): - x = create_input((100,), device, dtype) - x[index] = nan - if not skip_indices: - result = torchfn(x, 0) - v = get_values(result) - self.assertEqual(v, nan) - if istuple(result): - i = result[1] - self.assertEqual(i, index) - self.assertEqual(torchfn(x), nan) - - @dtypesIfCPU(torch.float, torch.double, torch.long, torch.bool) - @dtypesIfCUDA(torch.half, torch.float, torch.long, torch.bool) - @dtypes(torch.float, torch.double) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_max(self, device, dtype): - self._test_minmax_helper(torch.max, np.amax, device, dtype) - - @dtypesIfCPU(torch.float, torch.double, torch.long, torch.bool) - @dtypesIfCUDA(torch.half, torch.float, torch.long, torch.bool) - @dtypes(torch.float, torch.double) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_min(self, device, dtype): - self._test_minmax_helper(torch.min, np.amin, device, dtype) - - @dtypesIfCPU(torch.float, torch.double, torch.int, torch.long, torch.bool) - @dtypesIfCUDA(torch.half, torch.float, torch.int, torch.long, torch.bool) + @dtypesIfCUDA(torch.float, torch.double, torch.half) @dtypes(torch.float, torch.double) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_amin(self, device, dtype): - self._test_minmax_helper(torch.amin, np.amin, device, dtype) - - @dtypesIfCPU(torch.float, torch.double, torch.int, torch.long, torch.bool) - @dtypesIfCUDA(torch.half, torch.float, torch.int, torch.long, torch.bool) - @dtypes(torch.float, torch.double) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_amax(self, device, dtype): - self._test_minmax_helper(torch.amax, np.amax, device, dtype) - - @onlyOnCPUAndCUDA - @dtypesIfCPU(torch.float, torch.double) - @dtypesIfCUDA(torch.half, torch.float) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_aminmax(self, device, dtype): - - def _amin_wrapper(x, dim=None, keepdims=False): - if dim is None: - return torch._aminmax(x)[0] - else: - return torch._aminmax(x, dim, keepdims)[0] - - def _amax_wrapper(x, dim=None, keepdims=False): - if dim is None: - return torch._aminmax(x)[1] - else: - return torch._aminmax(x, dim, keepdims)[1] - - self._test_minmax_helper(_amin_wrapper, np.amin, device, dtype) - self._test_minmax_helper(_amax_wrapper, np.amax, device, dtype) - - @dtypes(*product(torch.testing.get_all_dtypes(include_complex=False), torch.testing.get_all_dtypes(include_complex=False))) - def test_maximum_minimum_type_promotion(self, device, dtypes): - a = torch.tensor((0, 1), device=device, dtype=dtypes[0]) - b = torch.tensor((1, 0), device=device, dtype=dtypes[1]) - for op in (torch.maximum, torch.max, torch.minimum, torch.min): - result = op(a, b) - self.assertEqual(result.dtype, torch.result_type(a, b)) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool])) - def test_maximum_minimum_int_and_bool(self, device, dtype): - ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum)) - rng = np.random.default_rng() - a_np = np.array(rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype]) - b_np = np.array(rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype]) - - for torch_op, alias, numpy_op in ops: - a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype) - b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype) - tensor_result = torch_op(a_tensor, b_tensor) - alias_result = alias(a_tensor, b_tensor) - - out = torch.empty_like(a_tensor) - torch_op(a_tensor, b_tensor, out=out) - - numpy_result = numpy_op(a_np, b_np) - - self.assertEqual(alias_result, tensor_result) - self.assertEqual(tensor_result, numpy_result) - self.assertEqual(out, numpy_result) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - @precisionOverride({torch.bfloat16: 1e-2}) - @dtypes(*(torch.testing.get_all_fp_dtypes())) - def test_maximum_minimum_float(self, device, dtype): - ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum)) - - if dtype == torch.bfloat16: - a_np = np.random.randn(10).astype(np.float64) - b_np = np.random.randn(10).astype(np.float64) - else: - a_np = np.random.randn(10).astype(torch_to_numpy_dtype_dict[dtype]) - b_np = np.random.randn(10).astype(torch_to_numpy_dtype_dict[dtype]) - - for torch_op, alias, numpy_op in ops: - numpy_result = numpy_op(a_np, b_np) - - a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype) - b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype) - tensor_result = torch_op(a_tensor, b_tensor) - alias_result = alias(a_tensor, b_tensor) - out = torch.empty_like(a_tensor) - torch_op(a_tensor, b_tensor, out=out) - - self.assertEqual(alias_result, tensor_result) - self.assertEqual(tensor_result, numpy_result) - self.assertEqual(out, numpy_result) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - @dtypes(*(torch.testing.get_all_fp_dtypes())) - def test_maximum_minimum_float_nan_and_inf(self, device, dtype): - # np.maximum and np.minimum functions compare input arrays element-wisely. - # if one of the elements being compared is a NaN, then that element is returned. - ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum)) - a_vals = (float('inf'), -float('inf'), float('nan'), float('nan')) - b_vals = (-float('inf'), float('inf'), float('inf'), float('nan')) - if dtype == torch.bfloat16: - a_np = np.array(a_vals, dtype=np.float64) - b_np = np.array(b_vals, dtype=np.float64) - else: - a_np = np.array(a_vals, dtype=torch_to_numpy_dtype_dict[dtype]) - b_np = np.array(b_vals, dtype=torch_to_numpy_dtype_dict[dtype]) - - for torch_op, alias, numpy_op in ops: - numpy_result = numpy_op(a_np, b_np) - - a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype) - b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype) - tensor_result = torch_op(a_tensor, b_tensor) - alias_result = alias(a_tensor, b_tensor) - - out = torch.empty_like(a_tensor) - torch_op(a_tensor, b_tensor, out=out) - - self.assertEqual(alias_result, tensor_result) - if dtype == torch.bfloat16: - self.assertEqual(tensor_result, numpy_result, exact_dtype=False) - self.assertEqual(out, numpy_result, exact_dtype=False) - else: - self.assertEqual(tensor_result, numpy_result) - self.assertEqual(out, numpy_result) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - @dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes())) - def test_maximum_minimum_complex(self, device, dtypes): - for torch_op in (torch.maximum, torch.minimum, torch.max, torch.min): - with self.assertRaisesRegex(RuntimeError, 'does not support complex inputs'): - torch_op(torch.ones(1, device=device, dtype=dtypes[0]), - torch.ones(1, device=device, dtype=dtypes[1])) - - with self.assertRaisesRegex(RuntimeError, 'does not support complex inputs'): - torch_op(torch.ones(1, device=device, dtype=dtypes[1]), - torch.ones(1, device=device, dtype=dtypes[0])) - - def test_bincount(self, device): - # negative input throws - with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'): - torch.bincount(torch.tensor([1, -1], device=device)) - # n-d input, with n > 1 throws - with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'): - torch.bincount(torch.tensor([[1, 2], [3, 4]], device=device)) - # floating input type throws - with self.assertRaisesRegex(RuntimeError, 'not implemented'): - torch.bincount(torch.tensor([1., 0.3], device=device)) - # minlength < 0 throws - with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'): - torch.bincount(torch.tensor([1, 3], device=device), - torch.tensor([.2, .2], device=device), - minlength=-1) - # input and weights dim mismatch - with self.assertRaisesRegex(RuntimeError, 'same length'): - torch.bincount(torch.tensor([1, 0], device=device), - torch.tensor([1., 0.3, 0.5], device=device)) - # 1-d input with no elements and default minlength - self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long)), - torch.zeros(0, dtype=torch.long, device=device)) - # 1-d input with no elements and specified minlength - self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10), - torch.zeros(10, dtype=torch.long, device=device)) - - # test tensor method without weights - long_counts = torch.tensor( - [0, 3, 2, 1, 3], dtype=torch.uint8, device=device).bincount() - self.assertEqual( - torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device), - long_counts) - # test minlength functionality - int_counts = torch.bincount( - torch.tensor([1, 1, 1, 1], device=device), minlength=5) - self.assertEqual( - torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device), - int_counts) - # test weights - byte_counts = torch.bincount( - torch.tensor([0, 1, 1, 1, 4], device=device), - torch.tensor([.1, .2, .3, .4, .5], device=device)) - self.assertEqual( - torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), byte_counts) - byte_counts = torch.bincount( - torch.tensor([0, 1, 1, 1, 4], device=device), - torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device)) - self.assertEqual( - torch.tensor([1, 9, 0, 0, 5], device=device, dtype=torch.float64), byte_counts) - # test non-contiguous inputs and weights - inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device) - weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device) - for i in [0, 1]: - assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous" - assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous" - # inputs are non-contiguous but weights are contiguous - self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2])) - # inputs and weights are non-contiguous - self.assertEqual( - inputs[:, 1].bincount(weights[:, 1]), - torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32)) - # weights are non-contiguous but inputs are contiguous - self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]), - torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32)) - - # test bincount on non-contiguous slices - all0s = torch.zeros((32, 2), dtype=torch.int64, device=device) - self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32])) - - all1s = torch.ones((32, 2), dtype=torch.int64, device=device) - self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32])) - - # test large number of bins - global memory use - big_exp = torch.zeros(10000000, device=device) - big_exp[-1] = 50.0 - big_w = torch.tensor([.5] * 100, device=device) - big_out = torch.tensor([9999999] * 100, device=device).bincount(big_w) - self.assertEqual(big_exp, big_out) - # test large input size - big_exp = torch.zeros(2, device=device, dtype=torch.int64) - big_exp[1] = 1000000 - big_out = torch.ones(1000000, dtype=torch.int8, device=device).bincount() - self.assertEqual(big_exp, big_out) - - @onlyCUDA - @expectedAlertNondeterministic('_bincount_cuda', fn_has_device_arg=False) - def test_bincount_alert_nondeterministic(self, device): - torch.bincount(torch.tensor([], device=device, dtype=torch.long)) - - @dtypes(torch.float, torch.double, torch.half) def test_multinomial(self, device, dtype): def make_prob_dist(shape, is_contiguous): if is_contiguous: @@ -15582,180 +5719,10 @@ def test_multinomial_rng_state_advance(self, device, dtype): # expect no more than 1 repeating elements generated in 2 attempts self.assertLessEqual(2 * n_sample - samples.unique().size(0), 1) - def test_var_unbiased(self, device): - tensor = torch.randn(100, device=device) - self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True)) - self.assertEqual(tensor.var(), tensor.var(unbiased=True)) - self.assertEqual(tensor.var(unbiased=False), tensor.var(0, unbiased=False)) - - tensor = torch.FloatTensor([1.0, 2.0]).to(device) - self.assertEqual(tensor.var(unbiased=True), 0.5) - self.assertEqual(tensor.var(unbiased=False), 0.25) - - tensor = torch.randn(100, device=device) - self.assertEqual(tensor.std(0), tensor.std(0, unbiased=True)) - self.assertEqual(tensor.std(), tensor.std(unbiased=True)) - self.assertEqual(tensor.std(unbiased=False), tensor.std(0, unbiased=False)) - - def test_var_stability(self, device): - tensor = torch.FloatTensor([2281.5, 2281.25]).to(device) + def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn, + memory_format, compare_data=True, default_is_preserve=False): - # Stability for inner dim - self.assertEqual(tensor.var(0), 0.03125) - - # General stability - self.assertEqual(tensor.var(), 0.03125) - - # Stability for outer dimensions - tensor = tensor.unsqueeze(1) - self.assertEqual(tensor.var(0), 0.03125) - - @dtypesIfCUDA(torch.half, torch.float, torch.double) - @dtypes(torch.float, torch.double) - def test_mul_intertype_scalar(self, device, dtype): - x = torch.tensor(1.5, dtype=dtype, device=device) - y = torch.tensor(3, dtype=torch.int32, device=device) - - self.assertEqual(x * y, 4.5) - self.assertEqual(y * x, 4.5) - - with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"): - y *= x - x *= y - self.assertEqual(x, 4.5) - - @onlyCPU - @dtypes(torch.float, torch.double) - def test_hardshrink(self, device, dtype): - data = torch.tensor([1, 0.5, 0.3, 0.6], dtype=dtype, device=device).view(2, 2) - self.assertEqual(torch.tensor([1, 0.5, 0, 0.6], dtype=dtype, device=device).view(2, 2), - data.hardshrink(0.3)) - self.assertEqual(torch.tensor([1, 0, 0, 0.6], dtype=dtype, device=device).view(2, 2), - data.hardshrink(0.5)) - - # test default lambd=0.5 - self.assertEqual(data.hardshrink(), data.hardshrink(0.5)) - - # test non-contiguous case - self.assertEqual(torch.tensor([1, 0, 0.5, 0.6], dtype=dtype, device=device).view(2, 2), - data.t().hardshrink(0.3)) - - @onlyCPU - @dtypes(torch.float, torch.double) - def test_hardshrink_edge_cases(self, device, dtype) -> None: - def h(values, l_expected): - for l, expected in l_expected.items(): - values_tensor = torch.tensor([float(v) for v in values], - dtype=dtype, device=device) - expected_tensor = torch.tensor([float(v) for v in expected], - dtype=dtype, device=device) - self.assertEqual(expected_tensor == values_tensor.hardshrink(l), - torch.ones_like(values_tensor, dtype=torch.bool)) - - def test_helper(min, max): - h([0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], - {0.0: [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], - min: [0.0, 0.0, 0.0, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], - 0.1: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, max, -max, inf, -inf], - 1.0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, max, -max, inf, -inf], - max: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, inf, -inf], - inf: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}) - - test_helper(torch.finfo(dtype).tiny, torch.finfo(dtype).max) - - @onlyCPU - @slowTest - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - @dtypes(torch.double) - def test_einsum(self, device: torch.device, dtype: torch.dtype) -> None: - # test cases taken from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f - x = torch.randn(5, dtype=dtype, device=device) - y = torch.randn(7, dtype=dtype, device=device) - A = torch.randn(3, 5, dtype=dtype, device=device) - B = torch.randn(2, 5, dtype=dtype, device=device) - C = torch.randn(2, 3, 5, dtype=dtype, device=device) - D = torch.randn(2, 5, 7, dtype=dtype, device=device) - E = torch.randn(7, 9, dtype=dtype, device=device) - F = torch.randn(2, 3, 5, 7, dtype=dtype, device=device) - G = torch.randn(7, 11, 13, dtype=dtype, device=device) - H = torch.randn(4, 4, dtype=dtype, device=device) - I = torch.randn(3, 4, 4, dtype=dtype, device=device) - l = torch.randn(5, 10, dtype=dtype, device=device) - r = torch.randn(5, 20, dtype=dtype, device=device) - w = torch.randn(30, 10, 20, dtype=dtype, device=device) - test_list: List[Union[Tuple[str, torch.Tensor], - Tuple[str, torch.Tensor, torch.Tensor], - Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]]] = [ - # -- Vector - ("i->", x), # sum - ("i,i->", x, x), # dot - ("i,i->i", x, x), # vector element-wise mul - ("i,j->ij", x, y), # outer - # -- Matrix - ("ij->ji", A), # transpose - ("ij->j", A), # row sum - ("ij->i", A), # col sum - ("ij,ij->ij", A, A), # matrix element-wise mul - ("ij,j->i", A, x), # matrix vector multiplication - ("ij,kj->ik", A, B), # matmul - ("ij,ab->ijab", A, E), # matrix outer product - # -- Tensor - ("aij,ajk->aik", C, D), # batch matmul - ("ijk,jk->i", C, A), # tensor matrix contraction - ("aij,jk->aik", D, E), # tensor matrix contraction - ("abcd,dfg->abcfg", F, G), # tensor tensor contraction - ("ijk,jk->ik", C, A), # tensor matrix contraction with double indices - ("ijk,jk->ij", C, A), # tensor matrix contraction with double indices - ("ijk,ik->j", C, B), # non contiguous - ("ijk,ik->jk", C, B), # non contiguous with double indices - # -- Diagonal - ("ii", H), # trace - ("ii->i", H), # diagonal - # -- Ellipsis - ("i...->...", H), - ("ki,...k->i...", A.t(), B), - ("k...,jk", A.t(), B), - ("...ii->...i", I), # batch diagonal - # -- Other - ("bn,anm,bm->ba", l, w, r), # as torch.bilinear - ("... ii->...i ", I), # batch diagonal with spaces - ] - for test in test_list: - actual = torch.einsum(test[0], test[1:]) - expected = np.einsum(test[0], *[t.numpy() for t in test[1:]]) - self.assertEqual(expected.shape, actual.shape, msg=test[0]) - self.assertEqual(expected, actual, msg=test[0]) - # test vararg - actual2 = torch.einsum(test[0], *test[1:]) - self.assertEqual(expected.shape, actual2.shape, msg=test[0]) - self.assertEqual(expected, actual2, msg=test[0]) - - def do_einsum(*args): - return torch.einsum(test[0], args) - # FIXME: following test cases fail gradcheck - if test[0] not in {"i,i->", "i,i->i", "ij,ij->ij"}: - gradcheck_inps = tuple(t.detach().requires_grad_() for t in test[1:]) - self.assertTrue(torch.autograd.gradcheck(do_einsum, gradcheck_inps)) - self.assertTrue(A._version == 0) # check that we do not use inplace ops - - @onlyCPU - @dtypes(torch.bool, torch.double) - def test_sum_all(self, device, dtype) -> None: - def check_sum_all(tensor: torch.Tensor) -> None: - pylist = tensor.reshape(-1).tolist() - self.assertEqual(tensor.sum(), sum(pylist)) - - if dtype != torch.bool: - check_sum_all(torch.tensor([1, 2, 3, 4, 5], dtype=dtype, device=device)) - check_sum_all(torch.randn(200000, dtype=dtype, device=device)) - check_sum_all(torch.randn(2000, 2, dtype=dtype, device=device)[:, 0]) - else: - check_sum_all(torch.tensor([True, False, True], dtype=torch.bool, device=device)) - - def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn, - memory_format, compare_data=True, default_is_preserve=False): - - assert(memory_format == torch.channels_last or memory_format == torch.channels_last_3d) + assert(memory_format == torch.channels_last or memory_format == torch.channels_last_3d) # xc is a channels last tensor xc = input_generator_fn(device) @@ -15850,20 +5817,6 @@ def transformation_fn(tensor, **kwargs): self._test_memory_format_transformations( device, get_generator(mf, shape), transformation_fn, mf, True, default_is_preserve=True) - @onlyCPU - @dtypes(torch.double) - def test_sum_out(self, device, dtype: torch.dtype) -> None: - x = torch.rand(100, 100, dtype=dtype, device=device) - res1 = torch.sum(x, 1) - res2 = torch.tensor((), dtype=dtype, device=device) - torch.sum(x, 1, out=res2) - self.assertEqual(res1, res2) - x = torch.rand(100, 100, 100, dtype=dtype, device=device) - res1 = x.sum(2).sum(1) - res2 = torch.tensor((), dtype=dtype, device=device) - torch.sum(x, (2, 1), out=res2) - self.assertEqual(res1, res2) - def test_memory_format_factory_like_functions_preserve(self, device): def get_generator(memory_format, shape): def input_generator_fn(device): @@ -15882,2070 +5835,67 @@ def input_generator_fn(device): formats_shapes = ( (torch.channels_last, (4, 3, 8, 8)), - (torch.channels_last_3d, (4, 3, 8, 8, 8))) - - for mf, shape, in formats_shapes: - for transformation_fn in transformation_fns: - self._test_memory_format_transformations( - device, get_generator(mf, shape), transformation_fn, mf, compare_data=False, default_is_preserve=True) - - def test_memory_format_type_shortcuts(self, device): - def get_generator(memory_format, shape, dtype): - def input_generator_fn(device): - return torch.randn(shape, device=device, dtype=dtype).clamp(0, 1) \ - .round().contiguous(memory_format=memory_format) - return input_generator_fn - - - def get_fn(fn_name): - def transformation_fn(tensor, **kwargs): - fn = getattr(tensor, fn_name) - return fn(**kwargs) - return transformation_fn - - shortcuts = ['byte', 'char', 'double', 'bool', 'half', 'int', 'long', 'short'] - if device == 'cpu': - shortcuts += ['bfloat16'] - - formats_shapes = ( - (torch.channels_last, (4, 3, 8, 8)), - (torch.channels_last_3d, (4, 3, 8, 8, 8))) - - for mf, shape in formats_shapes: - for fn_name in shortcuts: - self._test_memory_format_transformations( - device, get_generator(mf, shape, torch.float32), get_fn(fn_name), mf, default_is_preserve=True) - - # Test 'float' separately to avoid float->float no-op. - for mf, shape in formats_shapes: - self._test_memory_format_transformations( - device, get_generator(mf, shape, torch.float64), get_fn('float'), mf, default_is_preserve=True) - - @onlyCUDA - def test_memory_format_cpu_and_cuda_ops(self, device): - def get_generator(memory_format, shape): - def input_generator_fn(device): - return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format) - return input_generator_fn - - def transformation_cpu_fn(tensor, **kwargs): - return tensor.cpu(**kwargs) - - def transformation_cuda_fn(tensor, **kwargs): - return tensor.cuda(**kwargs) - - formats_shapes = ( - (torch.channels_last, (4, 3, 8, 8)), - (torch.channels_last_3d, (4, 3, 8, 8, 8))) - - for mf, shape in formats_shapes: - self._test_memory_format_transformations( - 'cuda', get_generator(mf, shape), transformation_cpu_fn, mf, default_is_preserve=True) - self._test_memory_format_transformations( - 'cpu', get_generator(mf, shape), transformation_cuda_fn, mf, default_is_preserve=True) - - @onlyCPU - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_eig(self, device, dtype): - a = torch.Tensor(((1.96, 0.00, 0.00, 0.00, 0.00), - (-6.49, 3.80, 0.00, 0.00, 0.00), - (-0.47, -6.39, 4.17, 0.00, 0.00), - (-7.20, 1.50, -1.51, 5.70, 0.00), - (-0.65, -6.34, 2.67, 1.80, -7.10))).t().contiguous().to(dtype=dtype, device=device) - e = torch.eig(a)[0] - ee, vv = torch.eig(a, True) - te = torch.tensor((), dtype=dtype, device=device) - tv = torch.tensor((), dtype=dtype, device=device) - eee, vvv = torch.eig(a, True, out=(te, tv)) - self.assertEqual(e, ee, atol=1e-12, rtol=0) - self.assertEqual(ee, eee, atol=1e-12, rtol=0) - self.assertEqual(ee, te, atol=1e-12, rtol=0) - self.assertEqual(vv, vvv, atol=1e-12, rtol=0) - self.assertEqual(vv, tv, atol=1e-12, rtol=0) - - # test reuse - X = torch.randn(4, 4, dtype=dtype, device=device) - X = torch.mm(X.t(), X) - e = torch.zeros(4, 2, dtype=dtype, device=device) - v = torch.zeros(4, 4, dtype=dtype, device=device) - torch.eig(X, True, out=(e, v)) - Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t()) - self.assertEqual(X, Xhat, atol=1e-8, rtol=0, msg='VeV\' wrong') - self.assertFalse(v.is_contiguous(), 'V is contiguous') - - torch.eig(X, True, out=(e, v)) - Xhat = torch.mm(v, torch.mm(e.select(1, 0).diag(), v.t())) - self.assertEqual(X, Xhat, atol=1e-8, rtol=0, msg='VeV\' wrong') - self.assertFalse(v.is_contiguous(), 'V is contiguous') - - # test non-contiguous - X = torch.randn(4, 4, dtype=dtype, device=device) - X = torch.mm(X.t(), X) - e = torch.zeros(4, 2, 2, dtype=dtype, device=device)[:, 1] - v = torch.zeros(4, 2, 4, dtype=dtype, device=device)[:, 1] - self.assertFalse(v.is_contiguous(), 'V is contiguous') - self.assertFalse(e.is_contiguous(), 'E is contiguous') - torch.eig(X, True, out=(e, v)) - Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t()) - self.assertEqual(X, Xhat, atol=1e-8, rtol=0, msg='VeV\' wrong') - - # test invalid input - self.assertRaisesRegex( - RuntimeError, - 'A should be 2 dimensional', - lambda: torch.eig(torch.ones((2)))) - self.assertRaisesRegex( - RuntimeError, - 'A should be square', - lambda: torch.eig(torch.ones((2, 3)))) - self.assertRaisesRegex( - RuntimeError, - 'A should not contain infs or NaNs', - lambda: torch.eig(np.inf * torch.ones((2, 2)))) - self.assertRaisesRegex( - RuntimeError, - 'A should not contain infs or NaNs', - lambda: torch.eig(np.nan * torch.ones((2, 2)))) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_lobpcg_basic(self, device, dtype): - self._test_lobpcg_method(device, dtype, 'basic') - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_lobpcg_ortho(self, device, dtype): - self._test_lobpcg_method(device, dtype, 'ortho') - - def _test_lobpcg_method(self, device, dtype, method): - from torch.testing._internal.common_utils import random_symmetric_pd_matrix, random_sparse_pd_matrix - from torch._linalg_utils import matmul, qform - from torch._lobpcg import lobpcg - - def test_tracker(worker): - k = worker.iparams['k'] - nc = worker.ivars['converged_count'] - if k <= nc: - tol = worker.fparams['tol'] - rerr = worker.tvars['rerr'] - X = worker.X - E = worker.E - B = worker.B - A = worker.A - dtype = X.dtype - device = X.device - - # Check convergence - self.assertLessEqual(rerr[:k].max(), tol) - - # Check B-orthogonality - I = torch.eye(k, k, dtype=dtype, device=device) - self.assertEqual(qform(B, X[:, :k]), I) - - # Check block equation - self.assertEqual(qform(A, X[:, :k]) / E[:k], I, atol=0.2, rtol=0) - - orig_lobpcg = lobpcg - - def lobpcg(*args, **kwargs): - kwargs['tracker'] = test_tracker - kwargs['niter'] = 1000 - kwargs['method'] = method - kwargs['tol'] = 1e-8 - return orig_lobpcg(*args, **kwargs) - prec = 5e-4 - - # check dense input - mm = torch.matmul - for batches in [(), (2,), (2, 3)]: - for m, n, k in [ - (9, 3, 1), - (9, 3, 2), - (9, 2, 2), - (100, 15, 5), - ]: - # skip tests that are known to fail with the basic - # LOBPCG method due to calling cholesky on singular - # input - if method == 'basic' and (m, n, k) in [(9, 2, 2), (100, 15, 5)]: - continue - A = random_symmetric_pd_matrix(m, *batches, device=device, dtype=dtype) - B = random_symmetric_pd_matrix(m, *batches, device=device, dtype=dtype) - - # classical eigenvalue problem, smallest eigenvalues - E, V = lobpcg(A, k=k, n=n, largest=False) - self.assertEqual(E.shape, batches + (k,)) - self.assertEqual(V.shape, batches + (m, k)) - self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0) - e = torch.symeig(A)[0] - e_smallest = e[..., :k] - self.assertEqual(E, e_smallest) - - # classical eigenvalue problem, largest eigenvalues - E, V = lobpcg(A, k=k, n=n, largest=True) - e_largest, _ = torch.sort(e[..., -k:], descending=True) - self.assertEqual(E, e_largest, atol=prec, rtol=0) - self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0) - - # generalized eigenvalue problem, smallest eigenvalues - E, V = lobpcg(A, B=B, k=k, n=n, largest=False) - self.assertEqual(matmul(A, V), mm(matmul(B, V), E.diag_embed()), atol=prec, rtol=0) - - # generalized eigenvalue problem, largest eigenvalues - E, V = lobpcg(A, B=B, k=k, n=n, largest=True) - self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()), - atol=prec, rtol=0) - - # check sparse input - for m, n, k, density in [ - (5, 1, 1, 0.8), - (9, 3, 2, 0.5), - (100, 1, 1, 0.1), - (1000, 7, 3, 0.01), - ]: - # skip tests that are known to fail with the basic LOBCG - # method due to insufficient accuracy - if method == 'basic' and (m, n, k, density) in [(1000, 7, 3, 0.01)]: - continue - A = random_sparse_pd_matrix(m, density=density, device=device, dtype=dtype) - B = random_sparse_pd_matrix(m, density=density, device=device, dtype=dtype) - A_eigenvalues = torch.arange(1, m + 1, dtype=dtype) / m - e_smallest = A_eigenvalues[..., :k] - e_largest, _ = torch.sort(A_eigenvalues[..., -k:], descending=True) - - # classical eigenvalue problem, smallest eigenvalues - E, V = lobpcg(A, k=k, n=n, largest=False) - self.assertEqual(E, e_smallest) - self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0) - - # classical eigenvalue problem, largest eigenvalues - E, V = lobpcg(A, k=k, n=n, largest=True) - self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0) - self.assertEqual(E, e_largest) - - # generalized eigenvalue problem, smallest eigenvalues - E, V = lobpcg(A, B=B, k=k, n=n, largest=False) - self.assertEqual(matmul(A, V), matmul(B, mm(V, E.diag_embed())), atol=prec, rtol=0) - - # generalized eigenvalue problem, largest eigenvalues - E, V = lobpcg(A, B=B, k=k, n=n, largest=True) - self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()), - atol=prec, rtol=0) - - @skipCPUIfNoLapack - @onlyCPU - @dtypes(torch.double) - def test_lobpcg_torchscript(self, device, dtype): - from torch.testing._internal.common_utils import random_sparse_pd_matrix - from torch._linalg_utils import matmul as mm - - lobpcg = torch.jit.script(torch.lobpcg) - - m = 500 - k = 5 - A1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype) - X1 = torch.randn((m, k), dtype=dtype, device=device) - E1, V1 = lobpcg(A1, X=X1) - eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max() - self.assertLess(eq_err, 1e-6) - - @unittest.skipIf(not TEST_SCIPY or (TEST_SCIPY and scipy.__version__ < '1.4.1'), "Scipy not found or older than 1.4.1") - @skipCPUIfNoLapack - @onlyCPU - @dtypes(torch.double) - def test_lobpcg_scipy(self, device, dtype): - """Compare torch and scipy.sparse.linalg implementations of lobpcg - """ - import time - import scipy - from torch.testing._internal.common_utils import random_sparse_pd_matrix - from torch._linalg_utils import matmul as mm - from scipy.sparse.linalg import lobpcg as scipy_lobpcg - import scipy.sparse - - def toscipy(A): - if A.layout == torch.sparse_coo: - values = A.coalesce().values().cpu().numpy().copy() - indices = A.coalesce().indices().cpu().numpy().copy() - return scipy.sparse.coo_matrix((values, (indices[0], indices[1])), A.shape) - return A.cpu().numpy().copy() - - niter = 1000 - repeat = 10 - m = 500 # size of the square matrix - k = 7 # the number of requested eigenpairs - A1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype) - B1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype) - X1 = torch.randn((m, k), dtype=dtype, device=device) - - A2 = toscipy(A1) - B2 = toscipy(B1) - X2 = toscipy(X1) - - lambdas1 = [] - - def tracker(worker): - lambdas1.append(worker.E[:]) - - tol = 1e-8 - # tol for scipy lobpcg will be choosed so that the number of - # iterations will be equal or very close to pytorch lobpcg - # (that is around 170-180) - - # Standard eigenvalue problem - E1, V1 = torch.lobpcg(A1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol) - E2, V2, lambdas2 = scipy_lobpcg(A2, X2, maxiter=niter, largest=True, retLambdaHistory=True, tol=1.1 * tol) - iters1 = len(lambdas1) - iters2 = len(lambdas2) - self.assertLess(abs(iters1 - iters2), 0.05 * max(iters1, iters2)) - - E2a, V2a = scipy_lobpcg(A2, X2, maxiter=niter, largest=False) - - eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max() - eq_err_scipy = (abs(A2.dot(V2) - V2 * E2)**2).sum() ** 0.5 / E2.max() - self.assertLess(eq_err, 1e-6) # std - self.assertLess(eq_err_scipy, 1e-6) # std - - self.assertEqual(E1, torch.from_numpy(E2.copy())) - - # Generalized eigenvalue problem - lambdas1 = [] - - def tracker(worker): - lambdas1.append(worker.E[:]) - - E1, V1 = torch.lobpcg(A1, B=B1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol) - E2, V2, lambdas2 = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=True, retLambdaHistory=True, tol=39 * tol) - E2a, V2a = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=False) - iters1 = len(lambdas1) - iters2 = len(lambdas2) - self.assertLess(abs(iters1 - iters2), 0.05 * max(iters1, iters2)) - - eq_err = torch.norm((mm(A1, V1) - mm(B1, V1) * E1), 2) / E1.max() - eq_err_scipy = (abs(A2.dot(V2) - B2.dot(V2) * E2)**2).sum() ** 0.5 / E2.max() - self.assertLess(eq_err, 1e-6) # general - self.assertLess(eq_err_scipy, 1e-6) # general - - self.assertEqual(E1, torch.from_numpy(E2.copy())) - - # Timings - elapsed_ortho = 0 - elapsed_ortho_general = 0 - elapsed_scipy = 0 - elapsed_general_scipy = 0 - for i in range(repeat): - start = time.time() - torch.lobpcg(A1, X=X1, niter=niter, method='ortho', tol=tol) - end = time.time() - elapsed_ortho += end - start - - start = time.time() - torch.lobpcg(A1, X=X1, B=B1, niter=niter, method='ortho', tol=tol) - end = time.time() - elapsed_ortho_general += end - start - - start = time.time() - scipy_lobpcg(A2, X2, maxiter=niter, tol=1.1 * tol) - end = time.time() - elapsed_scipy += end - start - - start = time.time() - scipy_lobpcg(A2, X2, B=B2, maxiter=niter, tol=39 * tol) - end = time.time() - elapsed_general_scipy += end - start - - elapsed_ortho_ms = 1000.0 * elapsed_ortho / repeat - elapsed_ortho_general_ms = 1000.0 * elapsed_ortho_general / repeat - elapsed_scipy_ms = 1000.0 * elapsed_scipy / repeat - elapsed_general_scipy_ms = 1000.0 * elapsed_general_scipy / repeat - - print(''' -CPU timings: torch.lobpcg vs scipy.sparse.linalg.lobpcg -------------------------------------------------------- - | standard | generalized | method -torch.lobpcg | {:10.2f} | {:10.2f} | ortho -scipy_lobpcg | {:10.2f} | {:10.2f} | N/A --(input size: {:4}, eigenpairs:{:2}, units: ms per call)- - '''.format(elapsed_ortho_ms, elapsed_ortho_general_ms, - elapsed_scipy_ms, elapsed_general_scipy_ms, - m, k)) - - # Handling of very small tolerence - tol = 1e-100 - - lambdas1 = [] - - def tracker(worker): - lambdas1.append(worker.E[:]) - - E1, V1 = torch.lobpcg(A1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol) - iters1 = len(lambdas1) - eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max() - - try: - E2, V2, lambdas2 = scipy_lobpcg(A2, X2, maxiter=niter, largest=True, retLambdaHistory=True, tol=tol) - iters2 = len(lambdas2) - eq_err_scipy = (abs(A2.dot(V2) - V2 * E2)**2).sum() ** 0.5 / E2.max() - except Exception as msg: - print('Calling scipy_lobpcg failed [standard]:', msg) - iters2 = -1 - eq_err_scipy = -1 - - lambdas1 = [] - - def tracker(worker): - lambdas1.append(worker.E[:]) - - E1, V1 = torch.lobpcg(A1, X=X1, B=B1, niter=niter, largest=True, tracker=tracker, tol=tol) - iters1_general = len(lambdas1) - eq_err_general = torch.norm((mm(A1, V1) - mm(B1, V1) * E1), 2) / E1.max() - - try: - E2, V2, lambdas2 = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=True, retLambdaHistory=True, tol=tol) - iters2_general = len(lambdas2) - eq_err_general_scipy = (abs(A2.dot(V2) - B2.dot(V2) * E2)**2).sum() ** 0.5 / E2.max() - except Exception as msg: - print('Calling scipy_lobpcg failed [generalized]:', msg) - iters2_general = -1 - eq_err_general_scipy = -1 - - print('''\ -Handling of small tol={:6.0e}: torch.lobpcg vs scipy.sparse.linalg.lobpcg ----------------------------------------------------------------------------- - | standard | generalized | niter | method -torch.lobpcg | {:10.2e} | {:10.2e} | {:6} | ortho -scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A ----(input size: {:4}, eigenpairs:{:2}, units: relative error, maxiter={:4})--- -'''.format(tol, eq_err, eq_err_general, iters1, eq_err_scipy, eq_err_general_scipy, iters2, m, k, niter)) - - def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False): - dtype = t.dtype - numpy_dtype = dtype - if dtype in {torch.bfloat16}: - numpy_dtype = torch.float - if dtype.is_complex: - alpha = 0.9 + 0.3j if alpha is None else alpha - beta = 0.5 + 0.6j if beta is None else beta - else: - alpha = 1.2 if alpha is None else alpha - beta = 0.8 if beta is None else beta - res1 = f(t, m, v, alpha=alpha, beta=beta) - res2 = torch.full_like(res1, math.nan) - if transpose_out: - res2 = res2.t().clone(memory_format=torch.contiguous_format).t() - f(t, m, v, alpha=alpha, beta=beta, out=res2) - res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy()) - if beta != 0: - res3 += (beta * t).to(numpy_dtype).cpu().numpy() - res3 = torch.from_numpy(res3).to(dtype) - self.assertEqual(res1, res2) - self.assertEqual(res1, res3) - - @precisionOverride({torch.bfloat16: 1e-0, torch.half: 5e-4, torch.float: 1e-4, torch.double: 1e-8, - torch.cfloat: 1e-4, torch.cdouble: 1e-8}) - @dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)) - @dtypes(torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_addmv(self, device, dtype): - # have to use torch.randn(...).to(bfloat16) instead of - # torch.randn(..., dtype=bfloat16). randn does not support - # bfloat16 yet. - ts = [ - torch.randn(10, device=device).to(dtype), - torch.randn(1, device=device).to(dtype).expand(10), - ] - vs = [ - torch.randn(100, device=device).to(dtype), - torch.ones(1, device=device).to(dtype).expand(100), # to reduce errors for low precision - ] - ms = [ - # 0d - torch.ones((), device=device).to(dtype).expand(10, 100), # to reduce errors for low precision - # 1d - torch.randn((1, 100), device=device).to(dtype).expand(10, 100), - # this initialization reduces errors for low precision for broadcasted matrices - # by making sure that intermediate and result values are exactly representable - # in low precision type - torch.randint(3, (10, 1), dtype=torch.float, device=device).to(dtype).expand(10, 100), - # 2d - torch.randn((10, 100), device=device).to(dtype), - torch.randn((100, 10), device=device).to(dtype).t(), - ] - for m, v, t in product(ms, vs, ts): - self._test_addmm_addmv(torch.addmv, t, m, v) - # Test beta=0, t=nan - t = torch.full((10,), math.nan, device=device).to(dtype) - for m, v in product(ms, vs): - self._test_addmm_addmv(torch.addmv, t, m, v, beta=0) - - @dtypesIfCUDA(*([torch.half, torch.float, torch.double] - + ([torch.bfloat16] if TEST_WITH_ROCM else []))) - @dtypes(torch.float, torch.double) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype): - # tests (o, s)*(s). o is output size, s is summed size. - o = 5 - s = 3 - a_data = torch.arange(1, o * s + 1, device=device, dtype=dtype).view(o, s) - x_data = torch.arange(1, s + 1, 1, device=device, dtype=dtype) - y_data = torch.ones(o, device=device, dtype=dtype) - control = torch.tensor([15., 33., 51., 69., 87.], device=device, dtype=dtype) - - def _test(row_major, incx, incy, lda_tail): - if row_major: - a_storage = torch.full((o, s + lda_tail), float('nan'), device=device, dtype=dtype) - else: - a_storage = torch.full((s, o + lda_tail), float('nan'), device=device, dtype=dtype).permute(1, 0) - a = a_storage[:o, :s].copy_(a_data) - - x_storage = torch.full((s, incx), float('nan'), device=device, dtype=dtype) - x = x_storage[:, 0].copy_(x_data) - - y_storage = torch.full((o, incy), float('nan'), device=device, dtype=dtype) - y = y_storage[:, 0].copy_(y_data) - - self._test_addmm_addmv(torch.addmv, y, a, x) - - for row_major, incx, incy, lda_tail in product((False, True), (1, 2), (1, 2), (0, 1)): - _test(row_major, incx, incy, lda_tail) - - @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6, - torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) - @dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)) - @dtypes(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes()) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - @tf32_on_and_off(0.05) - def test_addmm(self, device, dtype): - M = torch.randn(10, 25, device=device).to(dtype) - m1 = torch.randn(10, 50, device=device).to(dtype) - m2 = torch.randn(50, 25, device=device).to(dtype) - self._test_addmm_addmv(torch.addmm, M, m1, m2) - - # Test 0-strided - M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25) - m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50) - m2 = torch.randn(50, 25, device=device).to(dtype) - self._test_addmm_addmv(torch.addmm, M, m1, m2) - - # Test beta=0, M=nan - M = torch.full((10, 25), math.nan, device=device).to(dtype) - m1 = torch.randn(10, 50, device=device).to(dtype) - m2 = torch.randn(50, 25, device=device).to(dtype) - self._test_addmm_addmv(torch.addmm, M, m1, m2, beta=0) - - # Test transpose - for t1, t2, t3, t4 in product([True, False], repeat=4): - def maybe_transpose(cond, m): - if not cond: - return m - return m.t().clone(memory_format=torch.contiguous_format).t() - - M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype)) - m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype)) - m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype)) - self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4) - - @dtypes(torch.float, torch.double) - @dtypesIfCUDA(*([torch.float, torch.double] + - ([] if TEST_WITH_ROCM else torch.testing.get_all_complex_dtypes()))) - @tf32_on_and_off(0.005) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_addmm_sizes(self, device, dtype): - for m in [0, 1, 25]: - for n in [0, 1, 10]: - for k in [0, 1, 8]: - M = torch.randn(n, m, device=device).to(dtype) - m1 = torch.randn(n, k, device=device).to(dtype) - m2 = torch.randn(k, m, device=device).to(dtype) - self._test_addmm_addmv(torch.addmm, M, m1, m2) - - def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn): - def compare_with_numpy_bin_op(torch_fn, np_fn, x, y): - y_np = y.cpu().numpy() - - # `compare_with_numpy` takes care of moving `x` to correct device for calling np_fn. - self.compare_with_numpy(lambda inp: torch_fn(inp, y), lambda inp: np_fn(inp, y_np), x) - - # Use this tensor for out variant tests. - out = torch.randn((), dtype=dtype, device=device) - - def compare_out_variant(torch_fn, x, y): - torch_fn(v1, v2, out=out) - self.assertEqual(torch_fn(v1, v2), out) - - for _ in range(10): - numel = random.randint(10, 1000) - v1 = torch.randn(numel, dtype=dtype, device=device) - v2 = torch.randn(numel, dtype=dtype, device=device) - compare_with_numpy_bin_op(torch_fn, np_fn, v1, v2) - compare_out_variant(torch_fn, v1, v2) - - # Test 0-strided - v3 = torch.randn(1, dtype=dtype, device=device).expand(numel) - compare_with_numpy_bin_op(torch_fn, np_fn, v1, v3) - compare_out_variant(torch_fn, v1, v3) - - compare_with_numpy_bin_op(torch_fn, np_fn, v3, v1) - compare_out_variant(torch_fn, v3, v1) - - # Test stride greater than 1 - v4 = torch.randn(numel, numel, dtype=dtype, device=device)[:, numel - 1] - compare_with_numpy_bin_op(torch_fn, np_fn, v1, v4) - compare_out_variant(torch_fn, v1, v4) - - compare_with_numpy_bin_op(torch_fn, np_fn, v4, v1) - compare_out_variant(torch_fn, v4, v1) - - @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5}) - @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) - def test_dot_vs_numpy(self, device, dtype): - self._test_dot_vdot_vs_numpy(device, dtype, torch.dot, np.dot) - - @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5}) - @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) - def test_vdot_vs_numpy(self, device, dtype): - self._test_dot_vdot_vs_numpy(device, dtype, torch.vdot, np.vdot) - - def _test_dot_vdot_invalid_args(self, device, torch_fn, complex_dtypes=False): - if complex_dtypes: - x = torch.randn(1, dtype=torch.cfloat, device=device) - y = torch.randn(3, dtype=torch.cdouble, device=device) - else: - x = torch.randn(1, dtype=torch.float, device=device) - y = torch.randn(3, dtype=torch.double, device=device) - - with self.assertRaisesRegex(RuntimeError, - 'dot : expected both vectors to have same dtype'): - torch_fn(x, y) - - with self.assertRaisesRegex(RuntimeError, - '1D tensors expected'): - torch_fn(x.reshape(1, 1), y) - - with self.assertRaisesRegex(RuntimeError, - 'inconsistent tensor size'): - torch_fn(x.expand(9), y.to(x.dtype)) - - if self.device_type != 'cpu': - x_cpu = x.expand(3).cpu() - - with self.assertRaisesRegex(RuntimeError, - 'expected all tensors to be on the same device'): - torch_fn(x_cpu, y.to(x.dtype)) - - @onlyOnCPUAndCUDA - def test_vdot_invalid_args(self, device): - self._test_dot_vdot_invalid_args(device, torch.vdot) - self._test_dot_vdot_invalid_args(device, torch.vdot, complex_dtypes=True) - - @onlyOnCPUAndCUDA - def test_dot_invalid_args(self, device): - self._test_dot_vdot_invalid_args(device, torch.dot) - self._test_dot_vdot_invalid_args(device, torch.dot, complex_dtypes=True) - - @onlyCPU - @slowTest - @dtypes(torch.float) - def test_exp_slow(self, device, dtype): - # Test for https://github.com/pytorch/pytorch/issues/17271 - # This is pretty slow on my Macbook but it only takes a few - # seconds on a beefy Xeon server - a = torch.exp(torch.ones(2 ** 31, dtype=dtype, device=device)) - b = torch.exp(torch.ones(1, dtype=dtype, device=device)) - self.assertEqual(a, b.expand(2 ** 31)) - - @dtypes(torch.float, torch.double) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_hardswish(self, device, dtype): - inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000] - expectedOutput = np.multiply( - inputValues, - np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0) - precision_4dps = 0.0002 - - inputTensor = torch.tensor(inputValues, dtype=dtype, device=device) - expectedOutputTensor = \ - torch.tensor(expectedOutput, dtype=dtype, device=device) - - # normal - self.assertEqual(torch.nn.functional.hardswish(inputTensor), - expectedOutputTensor, - atol=precision_4dps, rtol=0) - - # inplace - inputTensorCpy = inputTensor.clone().detach() - torch.nn.functional.hardswish(inputTensorCpy, inplace=True) - self.assertEqual(inputTensorCpy, expectedOutputTensor, - atol=precision_4dps, rtol=0) - - @onlyCPU - @dtypes(torch.float, torch.double) - def test_sigmoid(self, device, dtype): - # TODO: why not simulate math.sigmoid like with rsqrt? - inputValues = [-1000, -1, 0, 0.5, 1, 2, 1000] - expectedOutput = [0.0000, 0.2689, 0.5, 0.6225, 0.7311, 0.8808, 1.000] - precision_4dps = 0.0002 - - self.assertEqual(torch.tensor(inputValues, dtype=dtype, device=device).sigmoid(), - torch.tensor(expectedOutput, dtype=dtype, device=device), - atol=precision_4dps, rtol=0) - - @dtypes(torch.float, torch.double) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_hardsigmoid(self, device, dtype): - inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000] - expectedOutput = np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0 - - inputTensor = torch.tensor(inputValues, dtype=dtype, device=device) - precision_4dps = 0.0002 - - # normal - self.assertEqual(torch.nn.functional.hardsigmoid(inputTensor), - torch.tensor(expectedOutput, dtype=dtype, device=device), - atol=precision_4dps, rtol=0) - - # inplace - inputTensorCpy = inputTensor.clone().detach() - self.assertEqual(torch.nn.functional.hardsigmoid(inputTensorCpy, inplace=True), - torch.tensor(expectedOutput, dtype=dtype, device=device), - atol=precision_4dps, rtol=0) - - @skipIfNoSciPy - @dtypes(torch.float, torch.double) - def test_silu(self, device, dtype): - input_np = np.random.randn(5, 8) - special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]] - input_np = np.concatenate((input_np, special_input), axis=0).astype( - torch_to_numpy_dtype_dict[dtype]) - expected_output_np = input_np * scipy.special.expit(input_np) - - expected_output = torch.from_numpy(expected_output_np).to(device) - expected_output_noncontig = expected_output.transpose(0, 1) - - atol = 1e-6 - rtol = 1e-6 - - input = torch.from_numpy(input_np).clone().contiguous().to(device) - self.assertEqual(torch.nn.functional.silu(input), expected_output, - atol=atol, rtol=rtol) - self.assertEqual(torch.nn.functional.silu(input, inplace=True), - expected_output, atol=atol, rtol=rtol) - - input = torch.from_numpy(input_np).clone().to(device) - input_noncontig = input.transpose(0, 1) - self.assertEqual(torch.nn.functional.silu(input_noncontig), - expected_output_noncontig, atol=atol, rtol=rtol) - self.assertEqual(torch.nn.functional.silu( - input_noncontig, inplace=True), expected_output_noncontig, - atol=atol, rtol=rtol) - - - @onlyCPU - @dtypes(torch.float) - def test_diag_embed(self, device, dtype): - x = torch.arange(3 * 4, dtype=dtype, device=device).view(3, 4) - result = torch.diag_embed(x) - expected = torch.stack([torch.diag(r) for r in x], 0) - self.assertEqual(result, expected) - - result = torch.diag_embed(x, offset=1, dim1=0, dim2=2) - expected = torch.stack([torch.diag(r, 1) for r in x], 1) - self.assertEqual(result, expected) - - @onlyCPU - @dtypes(*torch.testing.get_all_dtypes()) - def test_sub(self, device, dtype): - m1 = torch.tensor([2.34, 4.44], dtype=dtype, device=device) - m2 = torch.tensor([1.23, 2.33], dtype=dtype, device=device) - - if dtype == torch.bool: - self.assertRaises(RuntimeError, lambda: m1 - m2) - elif (dtype == torch.bfloat16 or dtype == torch.half): - # bfloat16 has a lower precision so we have to have a separate check for it - self.assertEqual(m1 - m2, torch.tensor([1.11, 2.11], dtype=dtype), atol=0.01, rtol=0) - else: - self.assertEqual(m1 - m2, torch.tensor([1.11, 2.11], dtype=dtype)) - - @onlyCPU - @dtypes(torch.float) - def test_csub(self, device, dtype): - # with a tensor - a = torch.randn(100, 90, dtype=dtype, device=device) - b = a.clone().normal_() - - res_add = torch.add(a, b, alpha=-1) - res_csub = a.clone() - res_csub.sub_(b) - self.assertEqual(res_add, res_csub) - - # with a scalar - a = torch.randn(100, 100, dtype=dtype, device=device) - - scalar = 123.5 - res_add = torch.add(a, -scalar) - res_csub = a.clone() - res_csub.sub_(scalar) - self.assertEqual(res_add, res_csub) - - @dtypesIfCUDA(torch.half, torch.float, torch.double) - @dtypes(torch.float, torch.double) - def test_min_max_binary_op_nan(self, device, dtype): - a = torch.rand(1000, dtype=dtype, device=device) - b = torch.rand(1000, dtype=dtype, device=device) - - # 0:250: a -- nan, b -- not nan - a[:250] = float('nan') - # 250:500: a -- not nan, b -- nan - b[250:500] = float('nan') - # 500:750: a and b both nan - a[500:750] = float('nan') - b[500:750] = float('nan') - # 750:1000: neither nan - - ma = torch.max(a, b) - mi = torch.min(a, b) - - for i in range(750): - self.assertTrue(torch.isnan(ma[i]), "max(a, b): {}, a: {}, b: {}".format(ma[i], a[i], b[i])) - self.assertTrue(torch.isnan(mi[i]), "min(a, b): {}, a: {}, b: {}".format(mi[i], a[i], b[i])) - - for i in range(750, 1000): - self.assertFalse(torch.isnan(ma[i]), "max(a, b): {}, a: {}, b: {}".format(ma[i], a[i], b[i])) - self.assertFalse(torch.isnan(mi[i]), "min(a, b): {}, a: {}, b: {}".format(mi[i], a[i], b[i])) - - @onlyCPU - @dtypes(*torch.testing.get_all_math_dtypes('cpu')) - def test_threshold(self, device, dtype): - if dtype != torch.uint8 and dtype != torch.float16 and not dtype.is_complex: - # 100 is wide enough to use AVX2 instructions for all types - x = torch.randn(100, dtype=torch.float, device=device).sign().to(dtype=dtype) - y = torch.threshold(x, 0, 0) - self.assertTrue(y.le(0).any()) - - @onlyCPU - @dtypes(torch.float, torch.double) - def test_reciprocal(self, device, dtype): - a = torch.randn(100, 89, device=device, dtype=dtype) - res_div = 1 / a - res_reciprocal = a.clone() - res_reciprocal.reciprocal_() - self.assertEqual(res_reciprocal, res_div) - - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.float, torch.double, torch.complex64, torch.complex128) - def test_reciprocal_complex(self, device, dtype): - t = torch.randn(10, 10, dtype=dtype, device=device) - expected = torch.from_numpy(np.reciprocal(t.cpu().numpy())) - actual = torch.reciprocal(t).cpu() - self.assertEqual(expected, actual) - - @onlyCUDA - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.complex64, torch.complex128) - def test_reciprocal_complex_extremal(self, device, dtype): - vals = ( - # Inf and Zeros - complex(float('inf'), float('inf')), - complex(float('inf'), 0.), - complex(0., float('inf')), - complex(0., 0.), - - # Nans and Zeros - complex(float('nan'), 0.), - complex(0., float('nan')), - complex(float('nan'), float('nan')), - - # Inf and Nans - complex(float('nan'), float('inf')), - complex(float('inf'), float('nan')), - - # Extremal and Normal Number - complex(float('nan'), 2.0), - complex(float('inf'), 2.0), - complex(2.0, float('nan')), - complex(2.0, float('inf')), - complex(2.0, 0.0), - complex(0.0, 2.0)) - - self.compare_with_numpy(torch.reciprocal, np.reciprocal, vals, device, dtype) - - @dtypes(torch.bfloat16, torch.float) - def test_div(self, device, dtype): - for op, method, inplace in ((torch.div, torch.Tensor.div, torch.Tensor.div_), - (torch.true_divide, torch.Tensor.true_divide, - torch.Tensor.true_divide_)): - m1 = torch.randn(10, 10, dtype=torch.float, device=device).to(dtype=dtype) - res1 = m1.clone() - inplace(res1[:, 3], 2) - res2 = m1.clone() - for i in range(m1.size(0)): - res2[i, 3] = res2[i, 3] / 2 - self.assertEqual(res1, res2) - - if dtype == torch.bfloat16: - a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device) - a2 = torch.tensor([2., 2.], dtype=dtype, device=device) - self.assertEqual(op(a1, a2), - torch.tensor([2.1, 3.1], dtype=dtype, device=device), - atol=0.01, rtol=0) - self.assertEqual(method(a1, a2), op(a1, a2)) - - @dtypes(torch.bfloat16, torch.float) - def test_true_divide_out(self, device, dtype): - a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device) - a2 = torch.tensor([2., 2.], dtype=dtype, device=device) - res = torch.empty_like(a1) - self.assertEqual(torch.true_divide(a1, a2, out=res), - torch.tensor([2.1, 3.1], dtype=dtype, device=device), - atol=0.01, rtol=0) - - @onlyCUDA - @dtypes(torch.half) - def test_divmul_scalar(self, device, dtype): - x = torch.tensor(100., device=device, dtype=dtype) - x_ref = x.float() - scale = 1e5 - res = x.div(scale) - expected = x_ref.div(scale) - self.assertEqual(res, expected.to(dtype), atol=0., rtol=0.) - x = torch.tensor(1e-5, device=device, dtype=dtype) - x_ref = x.float() - res = x.mul(scale) - expected = x_ref.mul(scale) - self.assertEqual(res, expected.to(dtype), atol=0., rtol=0.) - res = scale * x - self.assertEqual(res, expected.to(dtype), atol=0., rtol=0.) - - @dtypesIfCUDA(*set(torch.testing.get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128}) - @dtypes(*set(torch.testing.get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128}) - def test_floor_divide_tensor(self, device, dtype): - x = torch.randn(10, device=device).mul(30).to(dtype) - y = torch.arange(1, 11, dtype=dtype, device=device) - - z = x // y - z_alt = torch.trunc(x.double() / y.double()).to(dtype) - - self.assertEqual(z.dtype, x.dtype) - self.assertEqual(z, z_alt) - - @dtypesIfCUDA(*set(torch.testing.get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128}) - @dtypes(*set(torch.testing.get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128}) - def test_floor_divide_scalar(self, device, dtype): - x = torch.randn(100, device=device).mul(10).to(dtype) - - z = x // 3 - z_alt = torch.tensor([math.trunc(v.item() / 3.) for v in x], dtype=x.dtype, device=device) - - self.assertEqual(z.dtype, x.dtype) - self.assertEqual(z, z_alt) - - # Note: this tests fails on XLA - @onlyOnCPUAndCUDA - @dtypes(torch.float, torch.long) - def test_floor_divide_out(self, device, dtype): - x = torch.randn(10, device=device).mul(10).to(dtype) - y = torch.arange(1, 11, dtype=dtype, device=device) - o = torch.empty(10, dtype=dtype, device=device) - - torch.floor_divide(x, y, out=o) - self.assertEqual(o, x // y) - - # Tests scalar with out - torch.floor_divide(x, 2, out=o) - self.assertEqual(o, x // 2) - - if dtype == torch.int: - o = torch.empty(10, dtype=torch.float, device=device) - torch.floor_divide(x, y, out=o) - self.assertEqual(o, torch.floor_divide(x.float(), y.float())) - - @onlyCPU - @dtypes(*torch.testing.get_all_math_dtypes('cpu')) - def test_rdiv(self, device, dtype): - if dtype is torch.float16: - return - elif dtype.is_complex: - x = torch.rand(100, dtype=dtype, device=device).add(1).mul(4) - else: - x = torch.rand(100, device=device).add(1).mul(4).to(dtype) - y = 30 / x - if dtype.is_floating_point or dtype.is_complex: - z = torch.tensor([30 / v.item() for v in x], dtype=dtype, device=device) - else: - z = torch.tensor([math.trunc(30. / v.item()) for v in x], dtype=dtype, device=device) - self.assertEqual(y, z) - - @onlyCPU - @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False)) - def test_fmod(self, device, dtype): - m1 = torch.Tensor(10, 10).uniform_(-10., 10.).to(dtype=dtype, device=device) - res1 = m1.clone() - q = 3 - res1[:, 3].fmod_(q) - res2 = m1.clone() - for i in range(m1.size(1)): - res2[i, 3] = math.fmod(res2[i, 3], q) - self.assertEqual(res1, res2) - - zero = torch.zeros_like(m1) - if dtype in torch.testing.get_all_int_dtypes(): - with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"): - m1.fmod(0) - with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"): - m1.fmod(zero) - else: - self.assertTrue(torch.all(m1.fmod(0).isnan())) - self.assertTrue(torch.all(m1.fmod(zero).isnan())) - - @onlyCPU - @dtypes(torch.float, torch.long) - def test_remainder(self, device, dtype): - for use_item in [True, False]: - if dtype == torch.float: - m1 = torch.Tensor(10, 10).uniform_(-10., 10.).to(dtype=dtype, device=device) - res1 = m1.clone() - res2 = m1.clone() - qs = torch.arange(-5.1, 4.1, dtype=dtype, device=device) - # Check the case where the divisor is a simple float - for col_idx, q in enumerate(qs): - # Reference - for i in range(m1.size(0)): - res2[i, col_idx] = res2[i, col_idx] % q - # To test - res1[:, col_idx].remainder_(q if not use_item else q.item()) - self.assertEqual(res1, res2) - # Check the case where the divisor is a tensor - res1 = m1.clone() - res1.remainder_(qs.unsqueeze(0).expand_as(res1)) - self.assertEqual(res1, res2) - elif dtype == torch.long: - long_m1 = torch.LongTensor(10, 10).random_(-10, 10) - long_res1 = long_m1.clone() - long_res2 = long_m1.clone() - long_qs = torch.arange(-5, 5, dtype=dtype, device=device) - long_qs[5] = 5 # Can't handle the divisor=0 case - for col_idx, long_q in enumerate(long_qs): - # Reference - for i in range(long_m1.size(0)): - long_res2[i, col_idx] = long_res2[i, col_idx] % long_q - # To test - long_res1[:, col_idx].remainder_(long_q if not use_item else long_q.item()) - self.assertEqual(long_res1, long_res2) - # Divisor is a tensor case - long_res1 = long_m1.clone() - long_res1.remainder_(long_qs.unsqueeze(0).expand_as(long_res1)) - - @dtypes(torch.float, torch.double) - def test_remainder_fmod_large_dividend(self, device, dtype): - alarge = 1e9 - pi = 3.14159265358979 - for avalue in [alarge, -alarge]: - for bvalue in [pi, -pi]: - a = torch.tensor([avalue], dtype=dtype, device=device) - b = torch.tensor([bvalue], dtype=dtype, device=device) - c = torch.remainder(a, b) - d = torch.fmod(a, b) - self.assertTrue((b[0] > 0) == (c[0] > 0)) # remainder has same sign as divisor - self.assertTrue((a[0] > 0) == (d[0] > 0)) # fmod has same sign as dividend - self.assertTrue(abs(c[0]) < abs(b[0])) # remainder is within range of divisor - self.assertTrue(abs(d[0]) < abs(b[0])) # fmod is within range of divisor - if ((a[0] > 0) == (b[0] > 0)): - self.assertTrue(c[0] == d[0]) # remainder is same as fmod - else: - self.assertTrue(abs(c[0] - d[0]) == abs(b[0])) # differ by one divisor - - @dtypesIfCPU(torch.bfloat16, torch.float32, torch.float64) - @dtypes(torch.float32, torch.float64) - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - def test_hypot(self, device, dtype): - inputs = [ - (torch.randn(10, device=device).to(dtype), torch.randn(10, device=device).to(dtype)), - (torch.randn((3, 3, 3), device=device).to(dtype), torch.randn((3, 3, 3), device=device).to(dtype)), - (torch.randn((10, 1), device=device).to(dtype), torch.randn((10, 1), device=device).to(dtype).transpose(0, 1)), - (torch.randint(100, (10, ), device=device, dtype=torch.long), torch.randn(10, device=device).to(dtype)) - ] - for input in inputs: - actual = torch.hypot(input[0], input[1]) - if dtype == torch.bfloat16: - expected = torch.sqrt(input[0] * input[0] + input[1] * input[1]) - else: - expected = np.hypot(input[0].cpu().numpy(), input[1].cpu().numpy()) - self.assertEqual(actual, expected) - - @dtypes(torch.int64, torch.float64) - def test_remainder_edge_cases(self, device, dtype): - # Test variations of negative values used as input - a = torch.tensor([6, -6, -6, 6, 27, -27, -27, 27], dtype=dtype, device=device) - b = torch.tensor([-3, 3, -3, 3, -5, 5, -5, 5], dtype=dtype, device=device) - r = a.remainder(b) - r_expected = torch.tensor([0, 0, 0, 0, -3, 3, -2, 2], dtype=dtype, device=device) - self.assertEqual(r, r_expected) - - if dtype == torch.float64: - # Test cases where result should be nan - a = torch.tensor([-34, 0, 34], dtype=dtype, device=device) - b = torch.zeros(3, dtype=dtype, device=device) - self.assertTrue(torch.isnan(a.remainder(b)).all()) - - # Need to test a fairly large tensor with float cpu to run - # the Vec256 implementation - if device == 'cpu': - a = torch.tensor([6, -6, -6, 6, 27, -27, -27, 27] * 10000, dtype=dtype, device=device) - b = torch.tensor([-3, 3, -3, 3, -5, 5, -5, 5] * 10000, dtype=dtype, device=device) - r = a.remainder(b) - r_expected = torch.tensor([0, 0, 0, 0, -3, 3, -2, 2] * 10000, dtype=dtype, device=device) - self.assertEqual(r, r_expected) - - # Test nan cases - a = torch.tensor([-34, 0, 34] * 20000, dtype=dtype, device=device) - b = torch.zeros(3 * 20000, dtype=dtype, device=device) - self.assertTrue(torch.isnan(a.remainder(b)).all()) - - elif dtype == torch.int64: - if device == 'cpu': - # Test int divide by zero causes an exception - a = torch.ones(1000, dtype=dtype, device=device) - b = torch.ones(1000, dtype=dtype, device=device) - b[500] = 0 - self.assertRaises(RuntimeError, lambda: a.remainder(b)) - - # Check scalar type is promoted to match tensor - a = torch.ones(1, dtype=dtype, device=device) - b = 1.0 if dtype == torch.int64 else 1 - r = a.remainder(b) - self.assertEqual(r.dtype, a.dtype) - - @onlyOnCPUAndCUDA - @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - def test_gcd(self, device, dtype): - # Tests gcd(0, 0), gcd(0, a) cases - t1 = torch.tensor([0, 10, 0], dtype=dtype, device=device) - t2 = torch.tensor([0, 0, 10], dtype=dtype, device=device) - actual = torch.gcd(t1, t2) - expected = np.gcd([0, 10, 0], [0, 0, 10]) - self.assertEqual(actual, expected) - - if dtype == torch.uint8: - # Test unsigned integers with potential sign issues (i.e., uint8 with value >= 128) - a = torch.tensor([190, 210], device=device, dtype=dtype) - b = torch.tensor([190, 220], device=device, dtype=dtype) - actual = torch.gcd(a, b) - expected = torch.tensor([190, 10], device=device, dtype=dtype) - else: - # Compares with NumPy - a = torch.randint(-20, 20, (1024,), device=device, dtype=dtype) - b = torch.randint(-20, 20, (1024,), device=device, dtype=dtype) - actual = torch.gcd(a, b) - expected = np.gcd(a.cpu().numpy(), b.cpu().numpy()) - self.assertEqual(actual, expected) - - @onlyOnCPUAndCUDA - @dtypes(torch.int16, torch.int32, torch.int64) - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - def test_lcm(self, device, dtype): - # Tests lcm(0, 0), lcm(0, a) cases - t1 = torch.tensor([0, 10, 0], dtype=dtype, device=device) - t2 = torch.tensor([0, 0, 10], dtype=dtype, device=device) - actual = torch.lcm(t1, t2) - expected = np.lcm([0, 10, 0], [0, 0, 10]) - self.assertEqual(actual, expected) - - # Compares with NumPy - a = torch.randint(-20, 20, (1024,), device=device, dtype=dtype) - b = torch.randint(-20, 20, (1024,), device=device, dtype=dtype) - actual = torch.lcm(a, b) - expected = np.lcm(a.cpu().numpy(), b.cpu().numpy()) - self.assertEqual(actual, expected) - - @onlyOnCPUAndCUDA - @dtypes(torch.float32, torch.float64) - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - def test_nextafter(self, device, dtype): - # Test special cases - t1 = torch.tensor([0, 0, 10], device=device, dtype=dtype) - t2 = torch.tensor([inf, -inf, 10], device=device, dtype=dtype) - actual = torch.nextafter(t1, t2) - expected = np.nextafter(t1.cpu().numpy(), t2.cpu().numpy()) - self.assertEqual(actual, expected, atol=0, rtol=0) - - actual = torch.nextafter(t2, t1) - expected = np.nextafter(t2.cpu().numpy(), t1.cpu().numpy()) - self.assertEqual(actual, expected, atol=0, rtol=0) - - t1 = torch.tensor([0, nan], device=device, dtype=dtype) - t2 = torch.tensor([nan, 0], device=device, dtype=dtype) - self.assertTrue(torch.nextafter(t1, t2).isnan().all()) - - a = torch.randn(100, device=device, dtype=dtype) - b = torch.randn(100, device=device, dtype=dtype) - actual = torch.nextafter(a, b) - expected = np.nextafter(a.cpu().numpy(), b.cpu().numpy()) - self.assertEqual(actual, expected, atol=0, rtol=0) - - def _i0_helper(self, t): - # Test by comparing to scipy - dtype = t.dtype - actual = torch.i0(t) - if dtype is torch.bfloat16: - t = t.to(torch.float32) - expected = scipy.special.i0(t.cpu().numpy()) - # Casting down for dtype float16 is required since scipy upcasts to float32 - if dtype is torch.bfloat16 or dtype is torch.float16: - expected = torch.from_numpy(expected).to(dtype) - self.assertEqual(actual, expected) - - def _i0_range_helper(self, range, device, dtype): - # i0 tests are broken up by the domain for which the function does not overflow for each dtype - # This is done to ensure that the function performs well across all possible input values, without worrying - # about inf or nan possibilities - for r in (range, -range): - t = torch.rand(1000, device=device).to(dtype) * r - self._i0_helper(t) - - @dtypesIfCUDA(*torch.testing.get_all_fp_dtypes()) - @dtypes(torch.bfloat16, torch.float32, torch.float64) - @unittest.skipIf(not TEST_SCIPY, "SciPy not found") - def test_i0_range1(self, device, dtype): - # This tests the domain for i0 for which float16 does not overflow - # The domain is (-13.25, 13.25) - self._i0_range_helper(13.25, device, dtype) - - @dtypesIfCUDA(*torch.testing.get_all_fp_dtypes()) - @dtypes(torch.bfloat16, torch.float32, torch.float64) - @unittest.skipIf(not TEST_SCIPY, "SciPy not found") - def test_i0_range2(self, device, dtype): - # This tests the domain for i0 for which float32 and bfloat16 does not overflow - # The domain is (-88.5, 88.5) - self._i0_range_helper(88.5, device, dtype) - - @dtypes(torch.float64) - @unittest.skipIf(not TEST_SCIPY, "SciPy not found") - def test_i0_range3(self, device, dtype): - # This tests the domain for i0 for which float64 does not overflow - # The domain is (-709.75, 709.75) - self._i0_range_helper(709.75, device, dtype) - - @dtypesIfCUDA(*torch.testing.get_all_fp_dtypes()) - @dtypes(torch.bfloat16, torch.float32, torch.float64) - @unittest.skipIf(not TEST_SCIPY, "SciPy not found") - def test_i0_special(self, device, dtype): - t = torch.tensor([], device=device, dtype=dtype) - self._i0_helper(t) - - t = torch.tensor([inf, -inf, nan], device=device, dtype=dtype) - self.assertTrue(torch.i0(t).isnan().all()) - - @slowTest - @onlyOnCPUAndCUDA - @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.int32, torch.int64, torch.cfloat, torch.cdouble) - @dtypesIfCUDA(torch.float32, torch.float64, torch.cfloat, torch.cdouble) - @tf32_on_and_off(0.01) - def test_mm(self, device, dtype): - def _test_mm(n, m, p, dtype, genf): - # helper function - def matrixmultiply(mat1, mat2): - n = mat1.size(0) - m = mat1.size(1) - p = mat2.size(1) - res = torch.zeros(n, p, dtype=dtype, device=device) - for i, j in iter_indices(res): - res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m)) - return res - - # contiguous case - mat1 = genf(n, m) - mat2 = genf(m, p) - res = torch.mm(mat1, mat2) - - res2 = matrixmultiply(mat1, mat2) - self.assertEqual(res, res2) - - # non contiguous case 1 - mat1 = genf(n, m) - mat2 = genf(p, m).t() - res = torch.mm(mat1, mat2) - - res2 = matrixmultiply(mat1, mat2) - self.assertEqual(res, res2) - - # non contiguous case 2 - mat1 = genf(m, n).t() - mat2 = genf(m, p) - res = torch.mm(mat1, mat2) - - res2 = matrixmultiply(mat1, mat2) - self.assertEqual(res, res2) - - # non contiguous case 3 - mat1 = genf(m, n).t() - mat2 = genf(p, m).t() - res = torch.mm(mat1, mat2) - - res2 = matrixmultiply(mat1, mat2) - self.assertEqual(res, res2) - - # test with zero stride - mat1 = genf(n, m) - mat2 = genf(m, 1).expand(m, p) - res = torch.mm(mat1, mat2) - - res2 = matrixmultiply(mat1, mat2) - self.assertEqual(res, res2) - - # explicitly exercise the _out variant in torch.mm(). - # contiguous case - mat1 = genf(n, m) - mat2 = genf(m, p) - res = genf(n, p) - torch.mm(mat1, mat2, out=res) - - res2 = matrixmultiply(mat1, mat2) - self.assertEqual(res, res2) - - # explicitly exercise the _out variant in torch.mm(). - # non contiguous case 3 - mat1 = genf(m, n).t() - mat2 = genf(p, m).t() - res = genf(n, p) - torch.mm(mat1, mat2, out=res) - - res2 = matrixmultiply(mat1, mat2) - self.assertEqual(res, res2) - - def genf_int(x, y): - return torch.randint(0, 100, (x, y), dtype=dtype, device=device) - - def genf_bfloat(x, y): - return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) - - def genf_float(x, y): - return torch.randn(x, y, dtype=dtype, device=device) - - for (n, m, p) in [(20, 10, 5), (15, 5, 10), (5, 18, 10)]: - if (dtype == torch.int32) or (dtype == torch.int64): - genf = genf_int - elif (dtype == torch.bfloat16): - genf = genf_bfloat - else: - genf = genf_float - - _test_mm(n, m, p, dtype, genf) - - @onlyOnCPUAndCUDA - @dtypes(torch.float32, torch.float64) - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - def test_strided_mm_bmm(self, device, dtype): - # Tests strided view case with stride smaller than corresponding dimension size - x = torch.tensor([[1., 2., 3.], [4., 5., 6.]], dtype=dtype, device=device) - new_shape = [2, 2, 2] - new_stride = [3, 1, 1] - sx = torch.as_strided(x, size=new_shape, stride=new_stride) - - torch_fn = lambda x: torch.bmm(x, x) # noqa: E731 - np_fn = lambda x: np.matmul(x, x) # noqa: E731 - self.compare_with_numpy(torch_fn, np_fn, sx) - - torch_fn = lambda x: torch.mm(x, x) # noqa: E731 - self.compare_with_numpy(torch_fn, np_fn, sx[0]) - - @onlyCPU - @dtypes(*(torch.testing.get_all_complex_dtypes() + [torch.float, torch.double])) - def test_bmm(self, device, dtype): - num_batches = 10 - M, N, O = 23, 8, 12 - b1 = torch.randn(num_batches, M, N, dtype=dtype, device=device) - b2 = torch.randn(num_batches, N, O, dtype=dtype, device=device) - res = torch.bmm(b1, b2) - for i in range(num_batches): - r = torch.mm(b1[i], b2[i]) - self.assertEqual(r, res[i]) - if torch.cuda.is_available(): - # check that mixed arguments are rejected - self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cuda())) - self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cuda(), b2)) - - @onlyCUDA - @wrapDeterministicFlagAPITest - def test_cublas_config_deterministic_error(self, device): - test_cases = [ - # (function, (tensor sizes)) - ('mm', ((2, 2), (2, 2),)), - ('mv', ((2, 2), (2,),)), - ('bmm', ((1, 2, 2), (1, 2, 2),))] - - test_configs = [ - # (CuBLAS workspace config, is deterministic) - ('garbage', False), - (None, False), - (':4096:8', True), - (':16:8', True)] - - cublas_var_name = 'CUBLAS_WORKSPACE_CONFIG' - is_cuda10_2_or_higher = ( - (torch.version.cuda is not None) - and ([int(x) for x in torch.version.cuda.split(".")] >= [10, 2])) - - def test_case_info(fn_name, config): - return f'function "{fn_name}" with config "{"" if config is None else config}"' - - # Create processes to test each combination of test cases and config settings - processes = [] - for fn_name, arg_sizes in test_cases: - for config, is_config_deterministic in test_configs: - env = os.environ.copy() - if config is None: - if env.get(cublas_var_name) is not None: - del env[cublas_var_name] - else: - env[cublas_var_name] = config - should_throw_error = is_cuda10_2_or_higher and not is_config_deterministic - script = f""" -import torch -torch.set_deterministic(True) -fn = torch.{fn_name} -arg_sizes = {arg_sizes} -device = '{device}' -should_throw_error = {should_throw_error} -args = [] -for arg_size in arg_sizes: - args.append(torch.randn(*arg_size, device=device)) -try: - fn(*args) -except RuntimeError as e: - if not should_throw_error: - raise RuntimeError('Did not expect any error to be raised') - elif 'Deterministic behavior was enabled with either' not in str(e): - raise RuntimeError('Expected a CuBLAS nondeterministic error, but got a different error') -else: - if should_throw_error: - raise RuntimeError('Expected a CuBLAS nondeterministic error, but it was not raised') - -""" - try: - subprocess.check_output( - [sys.executable, '-c', script], - stderr=subprocess.STDOUT, - # On Windows, opening the subprocess with the default CWD makes `import torch` - # fail, so just set CWD to this script's directory - cwd=os.path.dirname(os.path.realpath(__file__)), - env=env) - except subprocess.CalledProcessError as e: - self.fail(msg=( - f'Subprocess exception while attempting to run {test_case_info(fn_name, config)}:\n' - + e.output.decode("utf-8"))) - - @onlyCPU - @dtypes(*(torch.testing.get_all_complex_dtypes() + [torch.float, torch.double])) - def test_addbmm(self, device, dtype): - # num_batches = 10 - # M, N, O = 12, 8, 5 - num_batches = 2 - M, N, O = 2, 3, 4 - b1 = torch.randn(num_batches, M, N, dtype=dtype, device=device) - b2 = torch.randn(num_batches, N, O, dtype=dtype, device=device) - res = torch.bmm(b1, b2) - res2 = torch.tensor((), dtype=dtype, device=device).resize_as_(res[0]).zero_() - res3 = torch.tensor((), dtype=dtype, device=device).resize_as_(res[0]).zero_() - - res2.addbmm_(b1, b2) - self.assertEqual(res2, res.sum(0, False)) - res3.copy_(res2) - - with self.maybeWarnsRegex( - UserWarning, "This overload of addbmm_ is deprecated"): - res2.addbmm_(1, b1, b2) - self.assertEqual(res2, res.sum(0, False) * 2), - res3.addbmm_(b1, b2, beta=1) - self.assertEqual(res2, res3) - - with self.maybeWarnsRegex( - UserWarning, "This overload of addbmm_ is deprecated"): - res2.addbmm_(1., .5, b1, b2) - self.assertEqual(res2, res.sum(0, False) * 2.5) - res3.addbmm_(b1, b2, beta=1., alpha=.5) - self.assertEqual(res2, res3) - - with self.maybeWarnsRegex( - UserWarning, "This overload of addbmm is deprecated"): - self.assertEqual(res2, torch.addbmm(1, res2, 0, b1, b2)) - - res4 = torch.addbmm(res2, b1, b2, beta=1, alpha=.5) - self.assertEqual(res4, res.sum(0, False) * 3), - - res5 = torch.addbmm(res2, b1, b2, beta=0, alpha=1) - self.assertEqual(res5, res.sum(0, False)) - - res6 = torch.addbmm(res2, b1, b2, beta=.1, alpha=.5) - self.assertEqual(res6, res2 * .1 + .5 * res.sum(0)), - - @onlyCPU - @dtypes(*(torch.testing.get_all_complex_dtypes() + [torch.float, torch.double])) - def test_baddbmm(self, device, dtype): - num_batches = 10 - M, N, O = 12, 8, 5 - b1 = torch.randn(num_batches, M, N, dtype=dtype, device=device) - b2 = torch.randn(num_batches, N, O, dtype=dtype, device=device) - res = torch.bmm(b1, b2) - res2 = torch.tensor((), dtype=dtype, device=device).resize_as_(res).zero_() - res3 = torch.tensor((), dtype=dtype, device=device).resize_as_(res).zero_() - - res2.baddbmm_(b1, b2) - self.assertEqual(res2, res) - res3.copy_(res2) - - with self.maybeWarnsRegex( - UserWarning, "This overload of baddbmm_ is deprecated"): - res2.baddbmm_(1, b1, b2) - self.assertEqual(res2, res * 2) - res3.baddbmm_(b1, b2, beta=1) - self.assertEqual(res3, res2) - - with self.maybeWarnsRegex( - UserWarning, "This overload of baddbmm_ is deprecated"): - res2.baddbmm_(1, .5, b1, b2) - self.assertEqual(res2, res * 2.5) - res3.baddbmm_(b1, b2, beta=1, alpha=.5) - self.assertEqual(res3, res2) - - - with self.maybeWarnsRegex( - UserWarning, "This overload of baddbmm is deprecated"): - self.assertEqual(torch.baddbmm(1, res2, 0, b1, b2), res2) - - res4 = torch.baddbmm(res2, b1, b2, beta=1, alpha=.5) - self.assertEqual(res4, res * 3, atol=2e-5, rtol=0) - - res5 = torch.baddbmm(res2, b1, b2, beta=0, alpha=1) - self.assertEqual(res5, res) - - res6 = torch.baddbmm(res2, b1, b2, beta=.1, alpha=.5) - self.assertEqual(res6, res2 * .1 + res * .5) - - def _test_cop(self, torchfn, mathfn, dtype, device): - def reference_implementation(res2): - for i, j in iter_indices(sm1): - idx1d = i * sm1.size(0) + j - res2[i, j] = mathfn(sm1[i, j], sm2[idx1d]) - return res2 - - # contiguous - m1 = torch.randn(10, 10, 10, dtype=dtype, device=device) - m2 = torch.randn(10, 10 * 10, dtype=dtype, device=device) - sm1 = m1[4] - sm2 = m2[4] - - res1 = torchfn(sm1, sm2.view(10, 10)) - res2 = reference_implementation(res1.clone()) - self.assertEqual(res1, res2) - - # non-contiguous - m1 = torch.randn(10, 10, 10, dtype=dtype, device=device) - m2 = torch.randn(10 * 10, 10 * 10, dtype=dtype, device=device) - sm1 = m1[:, 4] - sm2 = m2[:, 4] - # view as sm1.size() - sm2.set_(sm2.storage(), sm2.storage_offset(), sm1.size(), (sm2.stride()[0] * 10, sm2.stride()[0])) - res1 = torchfn(sm1, sm2) - # reference_implementation assumes 1-d sm2 - sm2.set_(sm2.storage(), sm2.storage_offset(), m2[:, 4].size(), m2[:, 4].stride()) - res2 = reference_implementation(res1.clone()) - self.assertEqual(res1, res2) - - @onlyCPU - @dtypes(torch.float) - def test_cdiv(self, device, dtype): - self._test_cop(torch.div, lambda x, y: x / y, dtype, device) - - @onlyCPU - @dtypes(torch.float) - def test_cfmod(self, device, dtype): - self._test_cop(torch.fmod, math.fmod, dtype, device) - - @onlyCPU - @dtypes(torch.float) - def test_cremainder(self, device, dtype): - self._test_cop(torch.remainder, lambda x, y: x % y, dtype, device) - - @onlyCPU - @dtypes(torch.float) - def test_cmul(self, device, dtype): - self._test_cop(torch.mul, lambda x, y: x * y, dtype, device) - - @onlyCPU - @dtypes(torch.float) - def test_cpow(self, device, dtype): - self._test_cop(torch.pow, lambda x, y: nan if x < 0 else math.pow(x, y), dtype, device) - - @onlyCUDA - @dtypes(torch.float16, torch.float32) - def test_prod_gpu(self, device, dtype): - x = torch.tensor([2, 3, 6, 9, 8], dtype=dtype, device=device) - - # Check all combinations: fp16 input - fp16 output, fp16 input - fp32 - # output, fp32 input - fp16 output, fp32 input - fp32 output - for dtype_output in [torch.float16, torch.float32]: - result_expected = torch.tensor(2592, dtype=dtype_output, device=device) - output = torch.prod(x, dtype=dtype_output) - self.assertEqual(output, result_expected) - - output = x.prod(dtype=dtype_output) - self.assertEqual(output, result_expected) - - @onlyCPU - @dtypes(torch.float) - def test_prod(self, device, dtype): - x = torch.rand(100, 100, dtype=dtype, device=device) - res1 = torch.prod(x, 1) - res2 = torch.tensor((), dtype=dtype, device=device) - torch.prod(x, 1, out=res2) - self.assertEqual(res1, res2) - - @onlyCPU - @dtypes(torch.float) - def test_cross(self, device, dtype): - x = torch.rand(100, 3, 100, dtype=dtype, device=device) - y = torch.rand(100, 3, 100, dtype=dtype, device=device) - res1 = torch.cross(x, y) - res2 = torch.tensor((), dtype=dtype, device=device) - torch.cross(x, y, out=res2) - self.assertEqual(res1, res2) - - @onlyCPU - @dtypes(torch.float) - def test_cross_with_and_without_dim(self, device, dtype): - x = torch.rand(100, 3, dtype=dtype, device=device) - y = torch.rand(100, 3, dtype=dtype, device=device) - res1 = torch.cross(x, y, dim=1) - res2 = torch.cross(x, y, dim=-1) - res3 = torch.cross(x, y) - self.assertEqual(res1, res2) - self.assertEqual(res1, res3) - - @dtypes(torch.float, torch.double, torch.int8, torch.int16, torch.int32, torch.int64) - def test_random(self, device, dtype): - # This test is flaky with p<=(2/(ub-lb))^200=6e-36 - t = torch.empty(200, dtype=dtype, device=device) - lb = 1 - ub = 4 - - t.fill_(-1) - t.random_(lb, ub) - self.assertEqual(t.min(), lb) - self.assertEqual(t.max(), ub - 1) - - t.fill_(-1) - t.random_(ub) - self.assertEqual(t.min(), 0) - self.assertEqual(t.max(), ub - 1) - - def test_random_bool(self, device): - size = 2000 - t = torch.empty(size, dtype=torch.bool, device=device) - - t.fill_(False) - t.random_() - self.assertEqual(t.min(), False) - self.assertEqual(t.max(), True) - self.assertTrue(0.4 < (t.eq(True)).to(torch.int).sum().item() / size < 0.6) - - t.fill_(True) - t.random_() - self.assertEqual(t.min(), False) - self.assertEqual(t.max(), True) - self.assertTrue(0.4 < (t.eq(True)).to(torch.int).sum().item() / size < 0.6) - - def test_random_from_to_bool(self, device): - size = 2000 - - int64_min_val = torch.iinfo(torch.int64).min - int64_max_val = torch.iinfo(torch.int64).max - - min_val = 0 - max_val = 1 - - froms = [int64_min_val, -42, min_val - 1, min_val, max_val, max_val + 1, 42] - tos = [-42, min_val - 1, min_val, max_val, max_val + 1, 42, int64_max_val] - - for from_ in froms: - for to_ in tos: - t = torch.empty(size, dtype=torch.bool, device=device) - if to_ > from_: - if not (min_val <= from_ <= max_val): - self.assertRaisesRegex( - RuntimeError, - "from is out of bounds", - lambda: t.random_(from_, to_) - ) - elif not (min_val <= (to_ - 1) <= max_val): - self.assertRaisesRegex( - RuntimeError, - "to - 1 is out of bounds", - lambda: t.random_(from_, to_) - ) - else: - t.random_(from_, to_) - range_ = to_ - from_ - delta = 1 - self.assertTrue(from_ <= t.to(torch.int).min() < (from_ + delta)) - self.assertTrue((to_ - delta) <= t.to(torch.int).max() < to_) - else: - self.assertRaisesRegex( - RuntimeError, - "random_ expects 'from' to be less than 'to', but got from=" + str(from_) + " >= to=" + str(to_), - lambda: t.random_(from_, to_) - ) - - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) - def test_random_full_range(self, device, dtype): - size = 2000 - alpha = 0.1 - - int64_min_val = torch.iinfo(torch.int64).min - int64_max_val = torch.iinfo(torch.int64).max - - if dtype == torch.double: - fp_limit = 2**53 - elif dtype == torch.float: - fp_limit = 2**24 - elif dtype == torch.half: - fp_limit = 2**11 - elif dtype == torch.bfloat16: - fp_limit = 2**8 - else: - fp_limit = 0 - - t = torch.empty(size, dtype=dtype, device=device) - - if dtype in [torch.float, torch.double, torch.half, torch.bfloat16]: - from_ = int(max(-fp_limit, int64_min_val)) - to_inc_ = int(min(fp_limit, int64_max_val)) - else: - from_ = int(max(torch.iinfo(dtype).min, int64_min_val)) - to_inc_ = int(min(torch.iinfo(dtype).max, int64_max_val)) - range_ = to_inc_ - from_ + 1 - - t.random_(from_, None) - delta = max(1, alpha * range_) - self.assertTrue(from_ <= t.to(torch.double).min() < (from_ + delta)) - self.assertTrue((to_inc_ - delta) < t.to(torch.double).max() <= to_inc_) - - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) - def test_random_from_to(self, device, dtype): - size = 2000 - alpha = 0.1 - - int64_min_val = torch.iinfo(torch.int64).min - int64_max_val = torch.iinfo(torch.int64).max - - if dtype in [torch.float, torch.double, torch.half]: - min_val = int(max(torch.finfo(dtype).min, int64_min_val)) - max_val = int(min(torch.finfo(dtype).max, int64_max_val)) - froms = [min_val, -42, 0, 42] - tos = [-42, 0, 42, max_val >> 1] - elif dtype == torch.bfloat16: - min_val = int64_min_val - max_val = int64_max_val - froms = [min_val, -42, 0, 42] - tos = [-42, 0, 42, max_val >> 1] - elif dtype == torch.uint8: - min_val = torch.iinfo(dtype).min - max_val = torch.iinfo(dtype).max - froms = [int64_min_val, -42, min_val - 1, min_val, 42, max_val, max_val + 1] - tos = [-42, min_val - 1, min_val, 42, max_val, max_val + 1, int64_max_val] - elif dtype == torch.int64: - min_val = int64_min_val - max_val = int64_max_val - froms = [min_val, -42, 0, 42] - tos = [-42, 0, 42, max_val] - else: - min_val = torch.iinfo(dtype).min - max_val = torch.iinfo(dtype).max - froms = [int64_min_val, min_val - 1, min_val, -42, 0, 42, max_val, max_val + 1] - tos = [min_val - 1, min_val, -42, 0, 42, max_val, max_val + 1, int64_max_val] - - if dtype == torch.double: - fp_limit = 2**53 - elif dtype == torch.float: - fp_limit = 2**24 - elif dtype == torch.half: - fp_limit = 2**11 - elif dtype == torch.bfloat16: - fp_limit = 2**8 - else: - fp_limit = 0 - - for from_ in froms: - for to_ in tos: - t = torch.empty(size, dtype=dtype, device=device) - if to_ > from_: - if not (min_val <= from_ <= max_val): - self.assertRaisesRegex( - RuntimeError, - "from is out of bounds", - lambda: t.random_(from_, to_) - ) - elif not (min_val <= (to_ - 1) <= max_val): - self.assertRaisesRegex( - RuntimeError, - "to - 1 is out of bounds", - lambda: t.random_(from_, to_) - ) - else: - if dtype.is_floating_point and ( - not (-fp_limit <= from_ <= fp_limit) or not (-fp_limit <= (to_ - 1) <= fp_limit)): - if not (-fp_limit <= from_ <= fp_limit): - self.assertWarnsRegex(UserWarning, "from is out of bounds", - lambda: t.random_(from_, to_)) - if not (-fp_limit <= (to_ - 1) <= fp_limit): - self.assertWarnsRegex(UserWarning, "to - 1 is out of bounds", - lambda: t.random_(from_, to_)) - else: - t.random_(from_, to_) - range_ = to_ - from_ - delta = max(1, alpha * range_) - if dtype == torch.bfloat16: - # Less strict checks because of rounding errors - # TODO investigate rounding errors - self.assertTrue(from_ <= t.to(torch.double).min() < (from_ + delta)) - self.assertTrue((to_ - delta) < t.to(torch.double).max() <= to_) - else: - self.assertTrue(from_ <= t.to(torch.double).min() < (from_ + delta)) - self.assertTrue((to_ - delta) <= t.to(torch.double).max() < to_) - else: - self.assertRaisesRegex( - RuntimeError, - "random_ expects 'from' to be less than 'to', but got from=" + str(from_) + " >= to=" + str(to_), - lambda: t.random_(from_, to_) - ) - - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) - def test_random_to(self, device, dtype): - size = 2000 - alpha = 0.1 - - int64_min_val = torch.iinfo(torch.int64).min - int64_max_val = torch.iinfo(torch.int64).max - - if dtype in [torch.float, torch.double, torch.half]: - min_val = int(max(torch.finfo(dtype).min, int64_min_val)) - max_val = int(min(torch.finfo(dtype).max, int64_max_val)) - tos = [-42, 0, 42, max_val >> 1] - elif dtype == torch.bfloat16: - min_val = int64_min_val - max_val = int64_max_val - tos = [-42, 0, 42, max_val >> 1] - elif dtype == torch.uint8: - min_val = torch.iinfo(dtype).min - max_val = torch.iinfo(dtype).max - tos = [-42, min_val - 1, min_val, 42, max_val, max_val + 1, int64_max_val] - elif dtype == torch.int64: - min_val = int64_min_val - max_val = int64_max_val - tos = [-42, 0, 42, max_val] - else: - min_val = torch.iinfo(dtype).min - max_val = torch.iinfo(dtype).max - tos = [min_val - 1, min_val, -42, 0, 42, max_val, max_val + 1, int64_max_val] - - from_ = 0 - for to_ in tos: - t = torch.empty(size, dtype=dtype, device=device) - if to_ > from_: - if not (min_val <= (to_ - 1) <= max_val): - self.assertRaisesRegex( - RuntimeError, - "to - 1 is out of bounds", - lambda: t.random_(from_, to_) - ) - else: - t.random_(to_) - range_ = to_ - from_ - delta = max(1, alpha * range_) - if dtype == torch.bfloat16: - # Less strict checks because of rounding errors - # TODO investigate rounding errors - self.assertTrue(from_ <= t.to(torch.double).min() < (from_ + delta)) - self.assertTrue((to_ - delta) < t.to(torch.double).max() <= to_) - else: - self.assertTrue(from_ <= t.to(torch.double).min() < (from_ + delta)) - self.assertTrue((to_ - delta) <= t.to(torch.double).max() < to_) - else: - self.assertRaisesRegex( - RuntimeError, - "random_ expects 'from' to be less than 'to', but got from=" + str(from_) + " >= to=" + str(to_), - lambda: t.random_(from_, to_) - ) - - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) - def test_random_default(self, device, dtype): - size = 2000 - alpha = 0.1 - - if dtype == torch.float: - to_inc = 1 << 24 - elif dtype == torch.double: - to_inc = 1 << 53 - elif dtype == torch.half: - to_inc = 1 << 11 - elif dtype == torch.bfloat16: - to_inc = 1 << 8 - else: - to_inc = torch.iinfo(dtype).max - - t = torch.empty(size, dtype=dtype, device=device) - t.random_() - self.assertTrue(0 <= t.to(torch.double).min() < alpha * to_inc) - self.assertTrue((to_inc - alpha * to_inc) < t.to(torch.double).max() <= to_inc) - - @onlyCPU - @dtypes(torch.half, torch.double, torch.int) - def test_cat(self, device, dtype): - SIZE = 10 - for dim in range(-3, 3): - pos_dim = dim if dim >= 0 else 3 + dim - x = torch.randint(low=-100, high=100, size=(13, SIZE, SIZE), device=device).to(dtype).transpose(0, pos_dim) - y = torch.randint(low=-100, high=100, size=(17, SIZE, SIZE), device=device).to(dtype).transpose(0, pos_dim) - z = torch.randint(low=-100, high=100, size=(19, SIZE, SIZE), device=device).to(dtype).transpose(0, pos_dim) - - res1 = torch.cat((x, y, z), dim) - self.assertEqual(res1.narrow(pos_dim, 0, 13), x, atol=0, rtol=0) - self.assertEqual(res1.narrow(pos_dim, 13, 17), y, atol=0, rtol=0) - self.assertEqual(res1.narrow(pos_dim, 30, 19), z, atol=0, rtol=0) - - x = torch.randint(low=-100, high=100, size=(20, SIZE, SIZE), device=device).to(dtype) - self.assertEqual(torch.cat(torch.split(x, 7)), x) - self.assertEqual(torch.cat(torch.chunk(x, 7)), x) - - y = torch.randint(low=-100, high=100, size=(1, SIZE, SIZE), device=device).to(dtype) - z = torch.cat([x, y]) - self.assertEqual(z.size(), (21, SIZE, SIZE)) - - self.assertRaises(RuntimeError, lambda: torch.cat([])) - self.assertRaisesRegex(TypeError, 'got None', lambda: torch.cat([x, None])) - - @onlyCPU - def test_cat_scalars(self, device): - x = torch.tensor(0, device=device) - y = torch.tensor(1, device=device) - with self.assertRaisesRegex(RuntimeError, 'zero-dimensional.*cannot be concatenated'): - torch.cat([x, y]) - - @onlyCPU - @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) - def test_floor_divide_zero(self, device, dtype): - a = torch.tensor([0, 1], dtype=dtype, device=device) - b = torch.tensor([0, 1], dtype=dtype, device=device) - with self.assertRaisesRegex(RuntimeError, 'ZeroDivisionError'): - a // b + (torch.channels_last_3d, (4, 3, 8, 8, 8))) - @onlyCPU - def test_cat_bad_input_sizes(self, device): - x = torch.randn(2, 1, device=device) - y = torch.randn(2, 1, 1, device=device) - z = torch.randn(2, 1, 1, device=device) - self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z])) + for mf, shape, in formats_shapes: + for transformation_fn in transformation_fns: + self._test_memory_format_transformations( + device, get_generator(mf, shape), transformation_fn, mf, compare_data=False, default_is_preserve=True) - x = torch.randn(2, 1, 2, device=device) - y = torch.randn(2, 1, 1, device=device) - z = torch.randn(2, 2, 1, device=device) - self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z], dim=1)) + def test_memory_format_type_shortcuts(self, device): + def get_generator(memory_format, shape, dtype): + def input_generator_fn(device): + return torch.randn(shape, device=device, dtype=dtype).clamp(0, 1) \ + .round().contiguous(memory_format=memory_format) + return input_generator_fn - @slowTest - @onlyCPU - def test_cat_big(self, device): - SIZE1 = 6500 - SIZE2 = 4500 - concat_list = [] - concat_list.append(torch.ones((SIZE1, 1024 * 512), dtype=torch.uint8, device=device)) - concat_list.append(torch.ones((SIZE2, 1024 * 512), dtype=torch.uint8, device=device)) - result = torch.cat(concat_list) - self.assertEqual(result.size(0), SIZE1 + SIZE2) + def get_fn(fn_name): + def transformation_fn(tensor, **kwargs): + fn = getattr(tensor, fn_name) + return fn(**kwargs) + return transformation_fn - @onlyCPU - def test_max_mixed_devices(self, device): - a = torch.randn(10, device=device) - if torch.cuda.is_available(): - values = torch.randn(10).cuda() - indices = torch.cuda.LongTensor() - self.assertRaises(RuntimeError, - lambda: torch.max(a, 0, out=(values, indices))) - self.assertRaises(RuntimeError, - lambda: torch.amax(a, 0, out=values)) + shortcuts = ['byte', 'char', 'double', 'bool', 'half', 'int', 'long', 'short'] + if device == 'cpu': + shortcuts += ['bfloat16'] - @onlyCPU - def test_min_mixed_devices(self, device): - a = torch.randn(10, device=device) - if torch.cuda.is_available(): - values = torch.randn(10).cuda() - indices = torch.cuda.LongTensor() - self.assertRaises(RuntimeError, - lambda: torch.min(a, 0, out=(values, indices))) - self.assertRaises(RuntimeError, - lambda: torch.amin(a, 0, out=values)) + formats_shapes = ( + (torch.channels_last, (4, 3, 8, 8)), + (torch.channels_last_3d, (4, 3, 8, 8, 8))) - def _float_to_int_conversion_helper(self, vals, device, dtype): - assert TEST_NUMPY + for mf, shape in formats_shapes: + for fn_name in shortcuts: + self._test_memory_format_transformations( + device, get_generator(mf, shape, torch.float32), get_fn(fn_name), mf, default_is_preserve=True) - a = np.array(vals, dtype=np.float32).astype(torch_to_numpy_dtype_dict[dtype]) - t = torch.tensor(vals, device=device, dtype=torch.float).to(dtype) - self.assertEqual(torch.from_numpy(a), t.cpu()) + # Test 'float' separately to avoid float->float no-op. + for mf, shape in formats_shapes: + self._test_memory_format_transformations( + device, get_generator(mf, shape, torch.float64), get_fn('float'), mf, default_is_preserve=True) - # Checks that float->integer casts don't produce undefined behavior errors. - # Note: In C++, casting from a floating value to an integral dtype - # is undefined if the floating point value is not within the integral - # dtype's dynamic range. This can (and should) cause undefined behavior - # errors with UBSAN. These casts are deliberate in PyTorch, however, and - # NumPy has the same behavior. - @onlyOnCPUAndCUDA - @unittest.skipIf(IS_MACOS, "Test is broken on MacOS, see https://github.com/pytorch/pytorch/issues/38752") - @unittest.skipIf(IS_PPC, "Test is borken on PowerPC, see https://github.com/pytorch/pytorch/issues/39671") - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) - def test_float_to_int_conversion_finite(self, device, dtype): - min = torch.finfo(torch.float).min - max = torch.finfo(torch.float).max - - # Note: CUDA max float -> integer conversion is divergent on some dtypes - vals = (min, -2, -1.5, -.5, 0, .5, 1.5, 2, max) - if self.device_type == 'cuda': - if torch.version.hip: - # HIP min float -> int64 conversion is divergent - vals = (-2, -1.5, -.5, 0, .5, 1.5, 2) - else: - vals = (min, -2, -1.5, -.5, 0, .5, 1.5, 2) + @onlyCUDA + def test_memory_format_cpu_and_cuda_ops(self, device): + def get_generator(memory_format, shape): + def input_generator_fn(device): + return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format) + return input_generator_fn - self._float_to_int_conversion_helper(vals, device, dtype) + def transformation_cpu_fn(tensor, **kwargs): + return tensor.cpu(**kwargs) - # Note: CUDA will fail this test on most dtypes, often dramatically. - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @onlyCPU - @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) - def test_float_to_int_conversion_nonfinite(self, device, dtype): - vals = (float('-inf'), float('inf'), float('nan')) + def transformation_cuda_fn(tensor, **kwargs): + return tensor.cuda(**kwargs) - self._float_to_int_conversion_helper(vals, device, dtype) + formats_shapes = ( + (torch.channels_last, (4, 3, 8, 8)), + (torch.channels_last_3d, (4, 3, 8, 8, 8))) - # TODO: re-enable this test - @unittest.skipIf(True, "real and imag not implemented for complex") - @onlyOnCPUAndCUDA - def test_complex_type_conversions(self, device): - dtypes = [torch.float, torch.complex64, torch.complex128] - for from_type in dtypes: - for to_type in dtypes: - from_tensor = torch.randn(4, dtype=from_type, device=device) - to_tensor = from_tensor.to(to_type) - if from_type.is_complex and not to_type.is_complex: - self.assertEqual(torch.real(from_tensor), to_tensor, exact_dtype=False) - elif not from_type.is_complex and to_type.is_complex: - self.assertEqual(from_tensor, torch.real(to_tensor), exact_dtype=False) - self.assertEqual(torch.zeros_like(torch.imag(to_tensor)), torch.imag(to_tensor), exact_dtype=False) - else: - self.assertEqual(from_tensor, to_tensor, exact_dtype=False) + for mf, shape in formats_shapes: + self._test_memory_format_transformations( + 'cuda', get_generator(mf, shape), transformation_cpu_fn, mf, default_is_preserve=True) + self._test_memory_format_transformations( + 'cpu', get_generator(mf, shape), transformation_cuda_fn, mf, default_is_preserve=True) @dtypes(torch.complex64, torch.complex128) def test_complex_unsupported(self, device, dtype): @@ -17962,42 +5912,61 @@ def test_complex_unsupported(self, device, dtype): # Note: whether PyTorch should support min and max on complex # tensors is an open question. # See https://github.com/pytorch/pytorch/issues/36374 - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.min(t) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): t.min() - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.min(t, dim=0) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.min(t, t) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.min(t, t, out=t) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.max(t) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): t.max() - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.max(t, dim=0) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.max(t, t) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.max(t, t, out=t) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.amin(t) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): t.amin() - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.amin(t, dim=0) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.amax(t) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): t.amax() - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.amax(t, dim=0) + # Tests _aminmax() variants with complex inputs, + # which are currently not supported due to min & max being unsupported + # for complex inputs, as per https://github.com/pytorch/pytorch/issues/36374 + # Test with a single-element tensor t, as well as a multi-element tensor x + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): + min_val, max_val = torch._aminmax(t) + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): + min_val = torch._aminmax(t, dim=0)[0] + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): + max_val = torch._aminmax(t, dim=0)[1] + # Test _aminmax() with a multi-element tensor + x = torch.tensor([(1 + 1j), (2 + 3j)], device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): + min_val, max_val = torch._aminmax(x) + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): + min_val = torch._aminmax(x, dim=0)[0] + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): + max_val = torch._aminmax(x, dim=0)[1] + # Tests clamp variants with complex inputs # Note: whether PyTorch should support clamp on complex # tensors is an open question. @@ -18005,154 +5974,19 @@ def test_complex_unsupported(self, device, dtype): min_val = 1 + 1j max_val = 4 + 4j out = torch.empty((0,), device=device, dtype=dtype) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.clamp(t, min=min_val) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.clamp(t, max=max_val) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.clamp(t, min_val, max_val) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.clamp(t, min=min_val, out=out) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.clamp(t, max=max_val, out=out) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, '(.*not support.*)|(.*not implemented.*)'): torch.clamp(t, min_val, max_val, out=out) - @dtypes(torch.long) - def test_abs_big_number(self, device, dtype): - bignumber = 2 ** 31 + 1 - res = torch.tensor([bignumber], device=device, dtype=dtype) - self.assertGreater(res.abs()[0], 0) - - @dtypes(torch.float, torch.double) - def test_abs_signed_zero(self, device, dtype): - # Both abs(0.0) and abs(-0.0) should result in 0.0 - size = 128 + 1 # pick a large enough number with remainder so that - # both vectorized and nonvectorized op is tested - inp = torch.zeros(size, device=device, dtype=dtype) - inp[::2] = -0.0 - inp = inp.abs() - for v in inp: - self.assertGreater(math.copysign(1.0, v), 0.0) - - @dtypes(torch.float) - def test_absolute(self, device, dtype): - # absolute is an alias for abs. Just check to see that results - # are the same. - t = torch.randn(10, 10, device=device, dtype=dtype) - r_abs = t.abs() - r_absolute = t.absolute() - self.assertEqual(r_abs, r_absolute) - - r_abs = torch.abs(t) - r_absolute = torch.absolute(t) - self.assertEqual(r_abs, r_absolute) - - r_abs = torch.empty((10, 10), device=device, dtype=dtype) - r_absolute = torch.empty((10, 10), device=device, dtype=dtype) - torch.abs(t, out=r_abs) - torch.absolute(t, out=r_absolute) - self.assertEqual(r_abs, r_absolute) - - from copy import deepcopy - t_copy = deepcopy(t) - t.absolute_() - t_copy.abs_() - self.assertEqual(t, t_copy) - - def test_bucketization(self, device): - values_1d = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9], device=device) - values_3d = torch.tensor([[[1, 3, 5], [2, 4, 6]], [[1, 2, 3], [4, 5, 6]]], device=device) - - # regular case 3d boundary and 3d input value - boundaries = torch.tensor([[[1, 2, 3, 4], [3, 4, 5, 6]], [[1, 3, 5, 7], [2, 4, 6, 8]]], device=device) - expected_result = torch.tensor([[[0, 2, 4], [0, 1, 3]], [[0, 1, 1], [1, 2, 2]]], device=device) - output = torch.empty(2, 2, 3, device=device, dtype=torch.int64) - self.assertEqual(torch.searchsorted(boundaries, values_3d), expected_result) - self.assertEqual(torch.searchsorted(boundaries, values_3d, out=output), expected_result) - expected_result = torch.tensor([[[1, 3, 4], [0, 2, 4]], [[1, 1, 2], [2, 2, 3]]], device=device) - self.assertEqual(torch.searchsorted(boundaries, values_3d, right=True), expected_result) - self.assertEqual(torch.searchsorted(boundaries, values_3d, right=True, out=output), expected_result) - - # simple 1d boundary and 3d input value - boundaries = torch.tensor([1, 2, 3, 4, 5, 6], device=device) - expected_result = torch.tensor([[[0, 2, 4], [1, 3, 5]], [[0, 1, 2], [3, 4, 5]]], device=device) - output = torch.empty(2, 2, 3, device=device, dtype=torch.int64) - self.assertEqual(torch.searchsorted(boundaries, values_3d), expected_result) - self.assertEqual(torch.bucketize(values_3d, boundaries), expected_result) - self.assertEqual(torch.bucketize(values_3d, boundaries, out=output), expected_result) - expected_result = torch.tensor([[[1, 3, 5], [2, 4, 6]], [[1, 2, 3], [4, 5, 6]]], device=device) - self.assertEqual(torch.searchsorted(boundaries, values_3d, right=True), expected_result) - self.assertEqual(torch.bucketize(values_3d, boundaries, right=True), expected_result) - self.assertEqual(torch.bucketize(values_3d, boundaries, out=output, right=True), expected_result) - - # simple float 1d boundary and 1d input with output int32 type - values_1d_float = values_1d.to(torch.float32) - boundaries = torch.tensor([0.9, 1, 2, 2, 3, 3, 4, 4.1, 9, 9], device=device, dtype=torch.float32) - expected_result = torch.tensor([1, 2, 4, 6, 8, 8, 8, 8, 8], device=device, dtype=torch.int32) - self.assertEqual(torch.searchsorted(boundaries, values_1d_float, out_int32=True), expected_result) - self.assertEqual(torch.bucketize(values_1d_float, boundaries, out_int32=True), expected_result) - - # multiple dimension input with 0 elements - boundaries = torch.tensor([1, 2, 3, 4, 5, 6], device=device, dtype=torch.int64) - values_0_el = torch.tensor([[[]]], device=device, dtype=torch.int64) - expected_result = values_0_el.to(torch.int64) - self.assertEqual(torch.searchsorted(boundaries, values_0_el), expected_result) - self.assertEqual(torch.bucketize(values_0_el, boundaries), expected_result) - - # nan input - values_nan = torch.tensor([1.0, float('nan'), 2.0, float('nan')], device=device, dtype=torch.float64) - boundaries = torch.tensor([0.0, 1.0, 2.0, 3.0], device=device, dtype=torch.float64) - expected_result = torch.tensor([1, 4, 2, 4], device=device) - self.assertEqual(torch.searchsorted(boundaries, values_nan), expected_result) - expected_result = torch.tensor([2, 4, 3, 4], device=device) - self.assertEqual(torch.searchsorted(boundaries, values_nan, right=True), expected_result) - - # type promotion and non contiguous tensors - values_3d_permute = values_3d.permute(2, 1, 0).to(torch.int32) - boundaries_permute = values_3d.permute(2, 1, 0).to(torch.float64) - expected_result = torch.tensor([[[0, 0], [0, 1]], [[2, 0], [0, 1]], [[2, 0], [0, 0]]], device=device) - if self.device_type != 'xla': - self.assertWarnsRegex( - UserWarning, "tensor is non-contiguous", - lambda: self.assertEqual(torch.searchsorted(boundaries_permute, values_3d_permute), expected_result)) - else: - # All tensors in XLA is contiguous even doing permute, no warning msg will be generate in XLA - self.assertEqual(torch.searchsorted(boundaries_permute, values_3d_permute), expected_result) - - # scalar type - boundaries = torch.tensor([1.5, 2.5, 3.5], device=device) - expected_result = torch.tensor(1, device=device) - self.assertEqual(torch.searchsorted(boundaries, 2), expected_result) - self.assertEqual(torch.bucketize(torch.tensor(2, device=device), boundaries), expected_result) - expected_result = torch.tensor(3, device=device) - scalar_tensor_nan = torch.tensor(float('nan'), device=device) - self.assertEqual(torch.searchsorted(boundaries, scalar_tensor_nan), expected_result) - self.assertEqual(torch.bucketize(float('nan'), boundaries, right=True), expected_result) - - # invalid input dimensions - boundaries = torch.tensor([[1, 2, 3], [4, 5, 6]], device=device) - with self.assertRaisesRegex( - RuntimeError, "first N-1 dimensions of boundaries tensor and input value tensor must match"): - torch.searchsorted(boundaries, values_3d) - with self.assertRaisesRegex( - RuntimeError, "boundaries tensor must be 1 dimension"): - torch.bucketize(values_3d, boundaries) - with self.assertRaisesRegex( - RuntimeError, "only when boundaries tensor dimension is 1"): - torch.searchsorted(boundaries, 1) - - # incompatiable output tensor's dtype - def test_output_dtype(dtype, is_int32): - output = values_1d.to(dtype) - with self.assertRaisesRegex( - RuntimeError, "output tensor's dtype is wrong"): - torch.searchsorted(values_1d, values_1d, out=output, out_int32=is_int32) - - test_output_dtype(torch.float32, False) - test_output_dtype(torch.int32, False) - test_output_dtype(torch.int64, True) - def test_pickle_gradscaler(self, device): # This test is not in test_cuda.py because it should pass in 3 cases: # 1. cuda is not available. @@ -18187,31 +6021,6 @@ def test_pickle_gradscaler(self, device): if lazy_init_scale: self.assertEqual(b.scale(torch.tensor([4.0], dtype=torch.float32, device=device)), 12.0) - @dtypesIfCUDA(torch.half, torch.float, torch.double, - torch.int8, torch.short, torch.int, torch.long) - @dtypes(torch.float, torch.double, - torch.int8, torch.short, torch.int, torch.long) - def test_nansum(self, device, dtype): - x = (torch.randn(3, 3)) - if dtype in [torch.half, torch.float, torch.double]: - x[x < 0.2] = float('nan') - # Randomly scale the values - x = (x * random.randint(10, 100)).tolist() - - self.compare_with_numpy(torch.nansum, np.nansum, x, device, dtype) - - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.float32, torch.float64) - def test_unpack_double(self, device, dtype): - # Reference: https://github.com/pytorch/pytorch/issues/33111 - vals = (2 ** 24 + 1, 2 ** 53 + 1, - np.iinfo(np.int64).max, np.iinfo(np.uint64).max, np.iinfo(np.uint64).max + 1, - -1e500, 1e500) - for val in vals: - t = torch.tensor(val, dtype=dtype, device=device) - a = np.array(val, dtype=torch_to_numpy_dtype_dict[dtype]) - self.assertEqual(t, torch.from_numpy(a)) - def test_multinomial_invalid(self, device): def test(probs): with self.assertRaisesRegex(RuntimeError, @@ -18266,7 +6075,12 @@ def _generate_input(self, shape, dtype, device, with_extremal): x = torch.tensor((), dtype=dtype, device=device) else: if dtype.is_floating_point or dtype.is_complex: - x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100) + # work around torch.randn not being implemented for bfloat16 + if dtype == torch.bfloat16: + x = torch.randn(*shape, device=device) * random.randint(30, 100) + x = x.to(torch.bfloat16) + else: + x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100) x[torch.randn(*shape) > 0.5] = 0 if with_extremal and dtype.is_floating_point: # Use extremal values @@ -18277,116 +6091,14 @@ def _generate_input(self, shape, dtype, device, with_extremal): x[torch.randn(*shape) > 0.5] = complex('nan') x[torch.randn(*shape) > 0.5] = complex('inf') x[torch.randn(*shape) > 0.5] = complex('-inf') + elif dtype == torch.bool: + x = torch.zeros(shape, dtype=dtype, device=device) + x[torch.randn(*shape) > 0.5] = True else: x = torch.randint(15, 100, shape, dtype=dtype, device=device) return x - def _test_reduction_function_with_numpy(self, torch_func, np_func, device, dtype, - with_extremal=False, atol=None, rtol=None, - exact_dtype=True, with_keepdim=False): - # Test 0-d to 3-d tensors. - for ndims in range(0, 4): - shape = self._rand_shape(ndims, min_size=5, max_size=10) - for n in range(ndims + 1): - for c in combinations(list(range(ndims)), n): - for count_dim in permutations(c): - # Generate Input. - x = self._generate_input(shape, dtype, device, with_extremal) - - if count_dim == (): - # Default `dims=None` case - self.compare_with_numpy(torch_func, np_func, x, device=None, dtype=None, - atol=atol, rtol=rtol, exact_dtype=exact_dtype) - else: - # With `dims: tuple of ints` case - if with_keepdim: - torch_func_partial = partial(torch_func, keepdim=True, dim=count_dim) - np_func_partial = partial(np_func, keepdims=True, axis=count_dim) - else: - torch_func_partial = partial(torch_func, dim=count_dim) - np_func_partial = partial(np_func, axis=count_dim) - self.compare_with_numpy(torch_func_partial, np_func_partial, x, device=None, dtype=None, - atol=atol, rtol=rtol, exact_dtype=exact_dtype) - - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + - torch.testing.get_all_complex_dtypes())) - def test_count_nonzero(self, device, dtype): - self._test_reduction_function_with_numpy(torch.count_nonzero, np.count_nonzero, device, dtype) - self._test_reduction_function_with_numpy(torch.count_nonzero, np.count_nonzero, device, dtype, True) - - def _test_sum_reduction_vs_numpy(self, torch_fn, np_fn, device, dtype, with_keepdim=False, with_extremal=False): - def is_integral(dtype): - return dtype in torch.testing.get_all_int_dtypes() - - # On Windows CI, the current version of `numpy` promotes all lower integers - # dtypes to int32 while `torch` promotes them to int64. Hence we skip on checking - # the exact dtype. - # Reference : https://dr.pytorch.org/api/view-log-full?build_id=122051580 - # PR : https://github.com/pytorch/pytorch/pull/38628#issuecomment-655905370 - exact_dtype = False if (IS_WINDOWS and is_integral(dtype)) else True - - if dtype == torch.uint8: - with self.assertRaises(TypeError): - self._test_reduction_function_with_numpy(torch_fn, np_fn, device, dtype, with_extremal=with_extremal) - else: - # TODO: Investigate why the output is not close to numpy. - if dtype == torch.float16: - atol = 0.4 - rtol = 1e-2 - elif dtype == torch.float32: - atol = 7e-05 - rtol = 3e-06 - else: - # Default values - atol = None - rtol = None - self._test_reduction_function_with_numpy(torch_fn, np_fn, device, dtype, - atol=atol, rtol=rtol, exact_dtype=exact_dtype, - with_keepdim=with_keepdim, with_extremal=with_extremal) - - @onlyOnCPUAndCUDA - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False))) - def test_sum_vs_numpy(self, device, dtype): - self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype) - self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_extremal=True) - self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_keepdim=True) - - @onlyOnCPUAndCUDA - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False))) - def test_nansum_vs_numpy(self, device, dtype): - self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype) - self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, with_extremal=True) - self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, with_keepdim=True) - - @dtypes(*(torch.testing.get_all_complex_dtypes())) - def test_nansum_complex(self, device, dtype): - x = torch.randn((3, 3, 3), device=device, dtype=dtype) - with self.assertRaisesRegex(RuntimeError, "nansum does not support complex inputs"): - torch.nansum(x) - - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - def test_nansum_out_dtype(self, device): - dtypes = list(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False)) - for inp_dtype, out_dtype in combinations(dtypes, 2): - shape = self._rand_shape(random.randint(2, 5), min_size=5, max_size=10) - x = self._generate_input(shape, inp_dtype, device, with_extremal=False) - torch_fn = partial(torch.nansum, dtype=out_dtype) - np_out_dtype = torch_to_numpy_dtype_dict[out_dtype] - np_fn = partial(np.nansum, dtype=np_out_dtype) - self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) - - @dtypes(torch.int32, torch.int64) - def test_large_linspace(self, device, dtype): - start = torch.iinfo(dtype).min - end = torch.iinfo(dtype).max & ~0xfff - steps = 15 - x = torch.linspace(start, end, steps, dtype=dtype, device=device) - self.assertGreater(x[1] - x[0], (end - start) / steps) - def _test_where_scalar_template(self, device, dtype, exec_fn): for with_extremal in [True, False]: for ndims in range(0, 4): @@ -18420,7 +6132,7 @@ def _where_valid_scalar_tensor_combination(self, scalar_type, dtype): return False @onlyOnCPUAndCUDA - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes() + torch.testing.get_all_complex_dtypes())) def test_where_scalar_invalid_combination_raises(self, device, dtype): @@ -18432,7 +6144,7 @@ def checkRaises(scalar_type, dtype, condition, x, scalar_1): self._test_where_scalar_template(device, dtype, checkRaises) - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes() + torch.testing.get_all_complex_dtypes())) def test_where_scalar_valid_combination(self, device, dtype): @@ -18483,411 +6195,11 @@ def get_dtype(scalar_type): # Reset the original dtype torch.set_default_dtype(default_dtype) - @dtypes(torch.int64, torch.float, torch.complex128) - def test_movedim_invalid(self, device, dtype): - shape = self._rand_shape(4, min_size=5, max_size=10) - x = self._generate_input(shape, dtype, device, False) - - # Invalid `source` and `destination` dimension - with self.assertRaisesRegex(IndexError, "Dimension out of range"): - torch.movedim(x, 5, 0) - - with self.assertRaisesRegex(IndexError, "Dimension out of range"): - torch.movedim(x, 0, 5) - - # Mismatch in size of `source` and `destination` - with self.assertRaisesRegex(RuntimeError, "movedim: Invalid source or destination dims:"): - torch.movedim(x, (1, 0), (0, )) - - with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"): - torch.movedim(x, (0, 0), (0, 1)) - - with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"): - torch.movedim(x, (0, 1, 0), (0, 1, 2)) - - with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"): - torch.movedim(x, (0, 1), (1, 1)) - - with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"): - torch.movedim(x, (0, 1, 2), (1, 0, 1)) - - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.int64, torch.float, torch.complex128) - def test_movedim(self, device, dtype): - for nd in range(5): - shape = self._rand_shape(nd, min_size=5, max_size=10) - x = self._generate_input(shape, dtype, device, with_extremal=False) - for random_negative in [True, False]: - for src_dim, dst_dim in permutations(range(nd), r=2): - random_prob = random.random() - - if random_negative and random_prob > 0.66: - src_dim = src_dim - nd - elif random_negative and random_prob > 0.33: - dst_dim = dst_dim - nd - elif random_negative: - src_dim = src_dim - nd - dst_dim = dst_dim - nd - - # Integer `source` and `destination` - torch_fn = partial(torch.movedim, source=src_dim, destination=dst_dim) - np_fn = partial(np.moveaxis, source=src_dim, destination=dst_dim) - self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) - - if nd == 0: - continue - - def make_index_negative(sequence, idx): - sequence = list(sequence) - sequence[random_idx] = sequence[random_idx] - nd - return tuple(src_sequence) - - for src_sequence in permutations(range(nd), r=random.randint(1, nd)): - # Sequence `source` and `destination` - dst_sequence = tuple(random.sample(range(nd), len(src_sequence))) - - # Randomly change a dim to a negative dim representation of itself. - random_prob = random.random() - if random_negative and random_prob > 0.66: - random_idx = random.randint(0, len(src_sequence) - 1) - src_sequence = make_index_negative(src_sequence, random_idx) - elif random_negative and random_prob > 0.33: - random_idx = random.randint(0, len(src_sequence) - 1) - dst_sequence = make_index_negative(dst_sequence, random_idx) - elif random_negative: - random_idx = random.randint(0, len(src_sequence) - 1) - dst_sequence = make_index_negative(dst_sequence, random_idx) - random_idx = random.randint(0, len(src_sequence) - 1) - src_sequence = make_index_negative(src_sequence, random_idx) - - torch_fn = partial(torch.movedim, source=src_sequence, destination=dst_sequence) - np_fn = partial(np.moveaxis, source=src_sequence, destination=dst_sequence) - self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) - - # Move dim to same position - x = torch.randn(2, 3, 5, 7, 11) - torch_fn = partial(torch.movedim, source=(0, 1), destination=(0, 1)) - np_fn = partial(np.moveaxis, source=(0, 1), destination=(0, 1)) - self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) - - torch_fn = partial(torch.movedim, source=1, destination=1) - np_fn = partial(np.moveaxis, source=1, destination=1) - self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) - - # Empty Sequence - torch_fn = partial(torch.movedim, source=(), destination=()) - np_fn = partial(np.moveaxis, source=(), destination=()) - self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) - - def _test_atleast_dim(self, torch_fn, np_fn, device, dtype): - for ndims in range(0, 5): - shape = self._rand_shape(ndims, min_size=5, max_size=10) - for n in range(ndims + 1): - for with_extremal in [False, True]: - for contiguous in [False, True]: - # Generate Input. - x = self._generate_input(shape, dtype, device, with_extremal) - if contiguous: - x = x.T - self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) - - # Compare sequence input - torch_sequence_x = (x,) * random.randint(3, 10) - np_sequence_x = tuple(map(lambda x: np.array(x.detach().cpu().numpy()), torch_sequence_x)) - torch_res = torch_fn(*torch_sequence_x) - np_res = np_fn(*np_sequence_x) - - torch_res = tuple(map(lambda x: x.cpu(), torch_res)) - np_res = tuple(map(lambda x: torch.from_numpy(x), np_res)) - self.assertEqual(np_res, torch_res) - - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + - torch.testing.get_all_complex_dtypes())) - def test_atleast(self, device, dtype): - self._test_atleast_dim(torch.atleast_1d, np.atleast_1d, device, dtype) - self._test_atleast_dim(torch.atleast_2d, np.atleast_2d, device, dtype) - self._test_atleast_dim(torch.atleast_3d, np.atleast_3d, device, dtype) - - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False))) - def test_argminmax_multiple(self, device, dtype): - # Case: All Ones - t = torch.ones(3, 3, device=device, dtype=dtype) - self.compare_with_numpy(torch.argmax, np.argmax, t) - self.compare_with_numpy(torch.argmin, np.argmin, t) - - # Case: With single `nan` present. - if dtype in torch.testing.get_all_fp_dtypes(): - t[2, 2] = float('nan') - self.compare_with_numpy(torch.argmax, np.argmax, t) - self.compare_with_numpy(torch.argmin, np.argmin, t) - - # Case: Randomly Generated Tensors - for ndims in range(1, 5): - shape = self._rand_shape(ndims, min_size=5, max_size=10) - for with_extremal in [False, True]: - for contiguous in [False, True]: - # Generate Input. - x = self._generate_input(shape, dtype, device, with_extremal) - - if dtype == torch.half: - max_val = torch.max(x.to(torch.float)) - min_val = torch.min(x.to(torch.float)) - else: - max_val = torch.max(x) - min_val = torch.min(x) - - mask = torch.randn(x.shape) > 0.5 - x[mask] = torch.tensor(max_val + 1, dtype=dtype) - - mask = torch.randn(x.shape) > 0.5 - x[mask] = torch.tensor(min_val - 1, dtype=dtype) - - if not contiguous: - x = x.T - - self.compare_with_numpy(torch.argmax, np.argmax, x, device=None, dtype=None) - self.compare_with_numpy(torch.argmin, np.argmin, x, device=None, dtype=None) - - # Verify indices returned by max and min. - if dtype != torch.half: - rand_dim = random.randint(0, ndims - 1) - self.compare_with_numpy(lambda x: torch.max(x, dim=rand_dim)[1], - lambda x: np.argmax(x, axis=rand_dim), x, device=None, dtype=None) - self.compare_with_numpy(lambda x: torch.min(x, dim=rand_dim)[1], - lambda x: np.argmin(x, axis=rand_dim), x, device=None, dtype=None) - - def verify_against_numpy(t): - # Argmax - torch_fn = partial(torch.argmax, dim=1) - np_fn = partial(np.argmax, axis=1) - self.compare_with_numpy(torch_fn, np_fn, t) - # Non-contiguous input - self.compare_with_numpy(torch_fn, np_fn, t.T) - - # Verify indices returned by max. - if dtype != torch.half: - self.compare_with_numpy(lambda x: torch.max(x, dim=1)[1], np_fn, x, device=None, dtype=None) - self.compare_with_numpy(lambda x: torch.max(x, dim=1)[1], np_fn, x.T, device=None, dtype=None) - - # Argmin - torch_fn = partial(torch.argmin, dim=1) - np_fn = partial(np.argmin, axis=1) - self.compare_with_numpy(torch_fn, np_fn, t) - # Non-contiguous input - self.compare_with_numpy(torch_fn, np_fn, t.T) - - # Verify indices returned by min. - if dtype != torch.half: - self.compare_with_numpy(lambda x: torch.min(x, dim=1)[1], np_fn, x, device=None, dtype=None) - self.compare_with_numpy(lambda x: torch.min(x, dim=1)[1], np_fn, x.T, device=None, dtype=None) - - # Case: Sample from issue: https://github.com/pytorch/pytorch/issues/41998 - t = torch.tensor([[1, 5], - [2, 10], - [3, 3]], device=device, dtype=dtype) - verify_against_numpy(t) - - # Case: Sample from issue: https://github.com/pytorch/pytorch/issues/41998 - t = torch.tensor([[1, 5], - [2, 10], - [0, 0]], device=device, dtype=dtype) - verify_against_numpy(t) - - def _test_special_stacks(self, dim, at_least_dim, torch_fn, np_fn, device, dtype): - # Test error for non-tuple argument - with self.assertRaisesRegex(TypeError, "must be tuple of Tensors, not Tensor"): - torch_fn(torch.randn(10)) - - # Test 0-D - num_tensors = random.randint(1, 5) - input_t = [torch.tensor(random.uniform(0, 10), device=device, dtype=dtype) for i in range(num_tensors)] - actual = torch_fn(input_t) - expected = np_fn([input.cpu().numpy() for input in input_t]) - self.assertEqual(actual, expected) - - for ndims in range(1, 5): - base_shape = list(self._rand_shape(ndims, min_size=1, max_size=5)) - for i in range(ndims): - shape = list(base_shape) - num_tensors = random.randint(1, 5) - torch_input = [] - # Create tensors with shape being different along one axis only - for param in range(num_tensors): - shape[i] = random.randint(1, 5) - torch_input.append(self._generate_input(tuple(shape), dtype, device, with_extremal=False)) - - # Determine if input tensors have valid dimensions. - valid_dim = True - for k in range(len(torch_input) - 1): - for tdim in range(ndims): - # Test whether all tensors have the same shape except in concatenating dimension - # Unless the number of dimensions is less than the corresponding at_least function dimension - # Since the original concatenating dimension would shift after applying at_least and would no - # longer be the concatenating dimension - if (ndims < at_least_dim or tdim != dim) and torch_input[k].size()[tdim] != torch_input[k + 1].size()[tdim]: - valid_dim = False - - # Special case for hstack is needed since hstack works differently when ndims is 1 - if valid_dim or (torch_fn is torch.hstack and ndims == 1): - # Valid dimensions, test against numpy - np_input = [input.cpu().numpy() for input in torch_input] - actual = torch_fn(torch_input) - expected = np_fn(np_input) - self.assertEqual(actual, expected) - else: - # Invalid dimensions, test for error - with self.assertRaisesRegex(RuntimeError, "Sizes of tensors must match except in dimension"): - torch_fn(torch_input) - with self.assertRaises(ValueError): - np_input = [input.cpu().numpy() for input in torch_input] - np_fn(np_input) - - @onlyOnCPUAndCUDA - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + - torch.testing.get_all_complex_dtypes())) - def test_hstack(self, device, dtype): - self._test_special_stacks(1, 1, torch.hstack, np.hstack, device, dtype) - - @onlyOnCPUAndCUDA - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + - torch.testing.get_all_complex_dtypes())) - def test_vstack(self, device, dtype): - self._test_special_stacks(0, 2, torch.vstack, np.vstack, device, dtype) - for i in range(5): - # Test dimension change for 1D tensor of size (N) and 2D tensor of size (1, N) - n = random.randint(1, 10) - input_a = self._generate_input((n,), dtype, device, with_extremal=False) - input_b = self._generate_input((1, n), dtype, device, with_extremal=False) - torch_input = [input_a, input_b] - np_input = [input.cpu().numpy() for input in torch_input] - actual = torch.vstack(torch_input) - expected = np.vstack(np_input) - self.assertEqual(actual, expected) - - @onlyOnCPUAndCUDA - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + - torch.testing.get_all_complex_dtypes())) - def test_dstack(self, device, dtype): - self._test_special_stacks(2, 3, torch.dstack, np.dstack, device, dtype) - for i in range(5): - # Test dimension change for 1D tensor of size (N), 2D tensor of size (1, N), and 3D tensor of size (1, N, 1) - n = random.randint(1, 10) - input_a = self._generate_input((n,), dtype, device, with_extremal=False) - input_b = self._generate_input((1, n), dtype, device, with_extremal=False) - input_c = self._generate_input((1, n, 1), dtype, device, with_extremal=False) - torch_input = [input_a, input_b, input_c] - np_input = [input.cpu().numpy() for input in torch_input] - actual = torch.dstack(torch_input) - expected = np.dstack(np_input) - self.assertEqual(actual, expected) - - # Test dimension change for 2D tensor of size (M, N) and 3D tensor of size (M, N, 1) - m = random.randint(1, 10) - n = random.randint(1, 10) - input_a = self._generate_input((m, n), dtype, device, with_extremal=False) - input_b = self._generate_input((m, n, 1), dtype, device, with_extremal=False) - torch_input = [input_a, input_b] - np_input = [input.cpu().numpy() for input in torch_input] - actual = torch.dstack(torch_input) - expected = np.dstack(np_input) - self.assertEqual(actual, expected) - - @onlyOnCPUAndCUDA - def test_repeated_dim(self, device): - ops = [torch.mean, torch.sum, torch.nansum, torch.std, torch.logsumexp, torch.std, torch.var, - torch.amin, torch.amax, torch.norm] - x = torch.randn(3, 3, 3, 3, device=device) - - error_msg = r'appears multiple times in the list of dims' - norm_error_msg = r'Expected dims to be different, got' - for op in ops: - for dim in [(0, 0), (0, -4)]: - e_msg = norm_error_msg if op == torch.norm else error_msg - with self.assertRaisesRegex(RuntimeError, e_msg): - op(x, dim=dim) - # Tests that compare a device's computation with the (gold-standard) CPU's. class TestDevicePrecision(TestCase): exact_dtype = True - # Note: ROCm fails when using float tensors - @dtypes(torch.double) - def test_polygamma(self, device, dtype): - cpu_tensor = torch.randn(10, 10, 10, dtype=dtype) - device_tensor = cpu_tensor.to(device) - zeros = torch.zeros(10, 10, 10, dtype=dtype) - for n in [0, 1, 2, 3, 4, 5]: - cpu_out = cpu_tensor.polygamma(n) - device_out = device_tensor.polygamma(n) - norm_errors = (device_out - cpu_out.to(device)) / device_out - self.assertEqual(norm_errors, zeros) - - cpu_tensor.requires_grad = True - for n in [0, 1, 2, 3, 4, 5]: - torch.autograd.gradcheck(lambda x: x.polygamma(n), - cpu_tensor) - - # Note: fails when using float tensors - @dtypes(torch.double) - def test_digamma(self, device, dtype): - cpu_tensor = torch.randn(10, 10, 10, dtype=dtype) - device_tensor = cpu_tensor.to(device) - zeros = torch.zeros(10, 10, 10, dtype=dtype) - cpu_out = cpu_tensor.digamma() - device_out = device_tensor.digamma() - norm_errors = (device_out - cpu_out.to(device)) / device_out - self.assertEqual(norm_errors, zeros) - - # Tests pole behavior - cpu_tensor = torch.tensor([-0.999999994, -1.999999994, -2.0000000111, - -100.99999994, -1931.99999994, 0.000000111, - -0.000000111, 0, -1, -2, -931], dtype=dtype) - expected_errors = torch.tensor([0, 0, 0, 0, 0, 0, 0, nan, nan, nan, nan], dtype=dtype) - device_tensor = cpu_tensor.to(device) - cpu_out = cpu_tensor.digamma() - device_out = device_tensor.digamma() - norm_errors = (device_out - cpu_out.to(device)) / device_out - self.assertEqual(norm_errors, expected_errors) - - def test_var(self, device): - cpu_tensor = torch.randn(2, 3, 3) - device_tensor = cpu_tensor.to(device) - self.assertEqual(device_tensor.var(), cpu_tensor.var()) - self.assertEqual(device_tensor.var(1), cpu_tensor.var(1)) - self.assertEqual(device_tensor.var(2), cpu_tensor.var(2)) - self.assertEqual(device_tensor.std(), cpu_tensor.std()) - self.assertEqual(device_tensor.std(1), cpu_tensor.std(1)) - self.assertEqual(device_tensor.var(2), cpu_tensor.var(2)) - - cpu_tensor = torch.randn(100) - device_tensor = cpu_tensor.to(device) - self.assertEqual(device_tensor.var(), cpu_tensor.var()) - - def test_var_large_input(self, device): - # Large, not-nice input - cpu_tensor = torch.randn(2 * 32 * 1024 + 1, 2, 67) - device_tensor = cpu_tensor.to(device) - - self.assertEqual(cpu_tensor.var(2), device_tensor.var(2)) - - @dtypesIfCUDA(torch.half, torch.float, torch.double) - @dtypes(torch.float, torch.double) - def test_device_rounding(self, device, dtype): - # test half-to-even - a = [-5.8, -3.5, -2.3, -1.5, -0.5, 0.5, 1.5, 2.3, 3.5, 5.8] - res = [-6., -4., -2., -2., 0., 0., 2., 2., 4., 6.] - - a_tensor = torch.tensor(a, device=device).round() - res_tensor = torch.tensor(res, device='cpu') - self.assertEqual(a_tensor, res_tensor) - @onlyCUDA - @skipCUDAIfNotRocm def test_index_add_bfloat16(self, device): inp_tensor = torch.randn(5, 3, device='cpu').bfloat16() t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.bfloat16, device='cpu') @@ -18901,14 +6213,6 @@ def test_index_add_bfloat16(self, device): self.assertEqual(out_cpu, out_gpu, atol=1e-2, rtol=0) - @dtypes(torch.double) - def test_sum_noncontig(self, device, dtype): - x = torch.randn(1, 75, 57, 20, dtype=dtype, device=device).permute(0, 3, 1, 2) - y = x.cpu() - self.assertEqual(x.sum().cpu(), y.sum()) - self.assertEqual(x.sum(dim=(-1, -2)).cpu(), y.sum(dim=(-1, -2))) - self.assertEqual(x.sum(dim=(1, 3)).cpu(), y.sum(dim=(1, 3))) - def test_device_serialization(self, device): x = torch.randn(4, 4, device=device) @@ -18952,13 +6256,6 @@ def do_test(d0, d1): if len(devices) > 1: do_test(devices[0], devices[1]) - @dtypes(torch.float, torch.double) - def test_abs_zero(self, device, dtype): - # Both abs(0.0) and abs(-0.0) should result in 0.0 - abs_zeros = torch.tensor([0.0, -0.0], device=device, dtype=dtype).abs().tolist() - for num in abs_zeros: - self.assertGreater(math.copysign(1.0, num), 0.0) - @deviceCountAtLeast(2) def test_type_conversions_same_device(self, devices): x = torch.randn(5, 5, device=devices[1]) @@ -18966,24 +6263,6 @@ def test_type_conversions_same_device(self, devices): self.assertEqual(x.type(torch.int).device, torch.device(devices[1])) self.assertEqual(x.to(torch.int).device, torch.device(devices[1])) - def test_min_max_nan(self, device): - tests = [(lambda x: x.min(), 'min'), - (lambda x: x.max(), 'max'), - (lambda x: x.amin(), 'amin'), - (lambda x: x.amax(), 'amax'), - (lambda x: x.min(0).values, 'min_dim'), - (lambda x: x.max(0).values, 'max_dim'), - (lambda x: x.amin(0), 'amin_dim'), - (lambda x: x.amax(0), 'amax_dim')] - for f, name in tests: - a = torch.arange(25.0).view(5, 5) - a[2, 2] = nan - actual = f(a.to(device)).cpu() - expected = f(a).cpu() - self.assertEqual(torch.isnan(actual), torch.isnan(expected), msg='nans for {}'.format(name)) - self.assertEqual(actual[~torch.isnan(actual)], - expected[~torch.isnan(expected)], msg='nans for {}'.format(name)) - @dtypesIfCUDA(torch.half, torch.float, torch.double, torch.int8, torch.short, torch.int, torch.long, torch.uint8) @@ -18995,42 +6274,6 @@ def test_from_sequence(self, device, dtype): reference = torch.arange(0, 20).resize_(5, 4) self.assertEqual(torch.tensor(seq, dtype=dtype, device=device), reference, exact_dtype=False) - def test_cat(self, device): - SIZE = 10 - for dim in range(-3, 3): - pos_dim = dim if dim >= 0 else 3 + dim - x = torch.rand(13, SIZE, SIZE, device=device).transpose(0, pos_dim) - y = torch.rand(17, SIZE, SIZE, device=device).transpose(0, pos_dim) - z = torch.rand(19, SIZE, SIZE, device=device).transpose(0, pos_dim) - - res1 = torch.cat((x, y, z), dim) - self.assertEqual(res1.narrow(pos_dim, 0, 13), x, atol=0, rtol=0) - self.assertEqual(res1.narrow(pos_dim, 13, 17), y, atol=0, rtol=0) - self.assertEqual(res1.narrow(pos_dim, 30, 19), z, atol=0, rtol=0) - - x = torch.randn(20, SIZE, SIZE, device=device) - self.assertEqual(torch.cat(torch.split(x, 7)), x) - self.assertEqual(torch.cat(torch.chunk(x, 7)), x) - - y = torch.randn(1, SIZE, SIZE, device=device) - z = torch.cat([x, y]) - self.assertEqual(z.size(), (21, SIZE, SIZE)) - - def test_sum_cpu_device_mismatch(self, device): - x = torch.randn(20, dtype=torch.float32, device=device) - y = torch.randn(1, dtype=torch.float32) - - err_string = "Expected all tensors to be on the same device, but found at least two devices, {0}".format(device) - - with self.assertRaisesRegex(RuntimeError, err_string): - torch.sum(x, dim=[0], dtype=torch.float32, out=y) - - # tests half to float promotion - if self.device_type == 'cuda': - x = x.half() - with self.assertRaisesRegex(RuntimeError, err_string): - torch.sum(x, dim=[0], dtype=torch.float32, out=y) - @deviceCountAtLeast(1) def test_advancedindex_mixed_cpu_devices(self, devices) -> None: def test(x: torch.Tensor, ia: torch.Tensor, ib: torch.Tensor) -> None: @@ -19093,478 +6336,6 @@ def test_copy_broadcast(self, device) -> None: x.copy_(y) self.assertEqual(x[3], y) - def test_solve_methods_arg_device(self, device): - for b_device, A_device in product(['cpu', device], repeat=2): - if b_device == A_device: - continue - - b = torch.randn(3, 1, device=b_device) - A = torch.randn(3, 3, device=A_device) - err_str = "Expected b and A to be on the same device" - with self.assertRaisesRegex(RuntimeError, err_str): - torch.solve(b, A) - - with self.assertRaisesRegex(RuntimeError, err_str): - torch.cholesky_solve(b, A) - - with self.assertRaisesRegex(RuntimeError, err_str): - torch.triangular_solve(b, A) - - # b and A have to be modified to match accepted inputs sizes for lu_solve - b = b.unsqueeze(0) - A = A.unsqueeze(0) - with self.assertRaisesRegex(RuntimeError, err_str): - torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=A_device).int()) - - # This checks if a suitable error message is thrown - # when LU output and pivots are on the same device - with self.assertRaisesRegex(RuntimeError, - "Expected LU_pivots and LU_data to be on the same device"): - torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=b_device).int()) - - -# Tests ops and indexing to ensure they return views (and new tensors) as -# appropriate. -class TestViewOps(TestCase): - exact_dtype = True - - def is_view_of(self, base, other): - if (not other._is_view() or - other is base or - other._base is not base or - base.device != other.device): - return False - # Note: only validates storage on native device types - # because some accelerators, like XLA, do not expose storage - if base.device.type == 'cpu' or base.device.type == 'cuda': - if base.storage().data_ptr() != other.storage().data_ptr(): - return False - - return True - - # Performs transpose if contiguous=True, else returns the input tensor as is - def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1): - if contiguous: - return x - else: - return x.transpose(dim0, dim1) - - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) - def test_conj_self(self, device, dtype): - t = torch.ones(5, 5, device=device) - s = t.conj() - self.assertTrue(s is t) - - @onlyOnCPUAndCUDA - def test_view_as_complex(self, device): - def fn(contiguous_input=True, dim0=0, dim1=1): - t = torch.randn(3, 2, 2, device=device) - c_t = t[:, :, 0] + 1j * t[:, :, 1] - - input = self._do_transpose(t, contiguous_input, dim0, dim1) - - if input.size()[-1] != 2: - self.assertRaisesRegex( - RuntimeError, "Tensor must have a last dimension of size 2", - lambda: torch.view_as_complex(input)) - return - - if input.stride()[-1] != 1: - self.assertRaisesRegex( - RuntimeError, "Tensor must have a last dimension with stride 1", - lambda: torch.view_as_complex(input)) - return - - res = torch.view_as_complex(input) - self.assertEqual(res, self._do_transpose(c_t, contiguous_input, dim0, dim1)) - self.assertTrue(self.is_view_of(t, res)) - - fn() - fn(contiguous_input=False) - # RuntimeError since in this case the last dim of input would not be of size 2 - fn(contiguous_input=False, dim0=0, dim1=2) - # RuntimeError since in this case the last dim of input would not have stride 1 - fn(contiguous_input=False, dim0=1, dim1=2) - - - # RuntimeError since in this case the stride of non-last dim of input would not be of size 2 - x = torch.randn(3, 3, device=device) - t = torch.as_strided(x, (2, 2), (1, 1)) - self.assertRaisesRegex( - RuntimeError, "Tensor must have a stride divisible by 2 for all but last dimension", - lambda: torch.view_as_complex(t)) - - # tensor with zero elements - x = torch.tensor([], device=device) # torch.Size([0]) - self.assertRaisesRegex( - RuntimeError, "Tensor must have a last dimension of size 2", - lambda: torch.view_as_complex(x)) - - # zero dimension tensor - z = torch.tensor(2.0) - self.assertRaisesRegex( - RuntimeError, "Input tensor must have one or more dimensions", - lambda: torch.view_as_complex(z)) - - y = x.reshape(0, 2) # torch.Size([0, 2]) - res = torch.view_as_complex(y) - self.assertTrue(self.is_view_of(x, res)) - self.assertEqual(res.shape, torch.Size([0])) - - @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_complex_dtypes(include_complex32=True)) - def test_view_as_real(self, device, dtype): - def fn(contiguous_input=True): - t = torch.randn(3, 4, dtype=dtype, device=device) - input = self._do_transpose(t, contiguous_input) - res = torch.view_as_real(input) - self.assertEqual(res[:, :, 0], input.real) - self.assertEqual(res[:, :, 1], input.imag) - # TODO: Add torch.ComplexHalfStorage - if dtype != torch.complex32: - self.assertTrue(self.is_view_of(t, res)) - else: - self.assertRaises(RuntimeError, lambda: self.is_view_of(t, res)) - - fn() - fn(contiguous_input=False) - - # tensor with zero elements - x = torch.tensor([], dtype=dtype, device=device) - res = torch.view_as_real(x) - # TODO: Add torch.ComplexHalfStorage - if dtype != torch.complex32: - self.assertTrue(self.is_view_of(x, res)) - else: - self.assertRaises(RuntimeError, lambda: self.is_view_of(x, res)) - self.assertEqual(res.shape, torch.Size([0, 2])) - - # tensor with zero dim - x = torch.tensor(2 + 3j, dtype=dtype, device=device) - res = torch.view_as_real(x) - # TODO: Add torch.ComplexHalfStorage - if dtype != torch.complex32: - self.assertTrue(self.is_view_of(x, res)) - else: - self.assertRaises(RuntimeError, lambda: self.is_view_of(x, res)) - self.assertEqual(res.shape, torch.Size([2])) - - @onlyOnCPUAndCUDA - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) - def test_real_imag_noncomplex(self, device, dtype): - t = torch.ones((5, 5), dtype=dtype, device=device) - - with self.assertRaises(RuntimeError): - torch.real(t) - - with self.assertRaises(RuntimeError): - torch.imag(t) - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_complex_dtypes()) - def test_real_imag_view(self, device, dtype): - def compare_with_numpy(contiguous_input=True): - t = torch.randn(3, 3, dtype=dtype, device=device) - if not contiguous_input: - u = t.T - else: - u = t - - re = u.real - exp = torch.from_numpy(u.cpu().numpy().real).to(device=device) - self.assertEqual(re, exp) - # for the case of contiguous_input, t=u - # for the case of non contiguous_input, the base still remains - # t since we are performing a view operation to make the input non-contiguous - self.assertTrue(self.is_view_of(t, re)) - - im = u.imag - exp = torch.from_numpy(u.cpu().numpy().imag).to(device=device) - self.assertEqual(im, exp) - self.assertTrue(self.is_view_of(t, im)) - - compare_with_numpy() - compare_with_numpy(contiguous_input=False) - - # ensure storage offset is being correctly set - a = torch.randn(10, dtype=dtype) - self.assertEqual(a[5:].real, a.real[5:]) - self.assertEqual(a[5:].imag, a.imag[5:]) - - @onlyOnCPUAndCUDA - @dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes())) - @suppress_warnings - def test_set_real_imag(self, device, dtypes): - x = torch.randn(10, dtype=dtypes[0], device=device) - - new_real = _make_tensor((10,), dtypes[1], device) - new_imag = _make_tensor((10,), dtypes[1], device) - - x.real = new_real - x.imag = new_imag - - if dtypes[1].is_complex: - self.assertEqual(x.real, new_real.real, exact_dtype=False) - self.assertEqual(x.imag, new_imag.real, exact_dtype=False) - - else: - self.assertEqual(x.real, new_real, exact_dtype=False) - self.assertEqual(x.imag, new_imag, exact_dtype=False) - - def test_diagonal_view(self, device) -> None: - t = torch.ones((5, 5), device=device) - v = torch.diagonal(t) - self.assertTrue(self.is_view_of(t, v)) - - v[0] = 0 - self.assertEqual(t[0, 0], v[0]) - - t = torch.ones((3, 3, 3), device=device) - v = torch.diagonal(t, offset=1, dim1=1, dim2=2) - self.assertTrue(self.is_view_of(t, v)) - - v[0, 0] = 0 - self.assertEqual(t[0, 0, 1], v[0, 0]) - - def test_select_view(self, device) -> None: - t = torch.ones((5, 5), device=device) - v = t.select(0, 2) - self.assertTrue(self.is_view_of(t, v)) - - v[0] = 0 - self.assertEqual(t[2, 0], v[0]) - - def test_unbind_view(self, device) -> None: - t = torch.zeros((5, 5), device=device) - tup = torch.unbind(t) - - for idx, v in enumerate(tup): - self.assertTrue(self.is_view_of(t, v)) - - v[0] = idx + 1 - self.assertEqual(t[idx, 0], v[0]) - - def test_expand_view(self, device) -> None: - t = torch.ones((5, 1), device=device) - v = t.expand(5, 5) - self.assertTrue(self.is_view_of(t, v)) - - v[2, 2] = 0 - self.assertEqual(t[2, 0], v[2, 2]) - - def test_expand_as_view(self, device): - t = torch.ones((5, 1), device=device) - e = torch.empty((5, 5), device=device) - v = t.expand_as(e) - self.assertTrue(self.is_view_of(t, v)) - - v[2, 2] = 0 - self.assertEqual(t[2, 0], v[2, 2]) - - def test_narrow_view(self, device): - t = torch.ones((5, 5), device=device) - v = torch.narrow(t, 1, 2, 2) - self.assertTrue(self.is_view_of(t, v)) - - v[0, 0] = 0 - self.assertEqual(t[0, 2], v[0, 0]) - - def test_permute_view(self, device) -> None: - t = torch.ones((5, 5), device=device) - v = t.permute(1, 0) - self.assertTrue(self.is_view_of(t, v)) - - v[0, 1] = 0 - self.assertEqual(t[1, 0], v[0, 1]) - - def test_transpose_view(self, device): - t = torch.ones((5, 5), device=device) - v = torch.transpose(t, 0, 1) - self.assertTrue(self.is_view_of(t, v)) - - v[0, 1] = 0 - self.assertEqual(t[1, 0], v[0, 1]) - - def test_t_view(self, device): - t = torch.ones((5, 5), device=device) - v = t.t() - self.assertTrue(self.is_view_of(t, v)) - - v[0, 1] = 0 - self.assertEqual(t[1, 0], v[0, 1]) - - def test_T_view(self, device): - t = torch.ones((5, 5), device=device) - v = t.T - self.assertTrue(self.is_view_of(t, v)) - - v[0, 1] = 0 - self.assertEqual(t[1, 0], v[0, 1]) - - def test_unfold_view(self, device): - t = torch.ones(10, device=device) - v = t.unfold(0, 3, 2) - self.assertTrue(self.is_view_of(t, v)) - - v[1, 0] = 0 - self.assertEqual(t[2], v[1, 0]) - - def test_squeeze_view(self, device): - t = torch.ones(5, 1, 5, device=device) - v = torch.squeeze(t) - self.assertTrue(self.is_view_of(t, v)) - v[0, 1] = 0 - self.assertEqual(t, v._base) - - def test_unsqueeze_view(self, device): - t = torch.ones(5, 5, device=device) - v = torch.unsqueeze(t, 1) - self.assertTrue(self.is_view_of(t, v)) - - v[0, 0, 1] = 0 - self.assertEqual(t[0, 1], v[0, 0, 1]) - - def test_as_strided_view(self, device): - t = torch.ones(5, 5, device=device) - v = torch.as_strided(t, (25,), (1,)) - self.assertTrue(self.is_view_of(t, v)) - - v[6] = 0 - self.assertEqual(t[1, 1], v[6]) - - def test_view_view(self, device): - t = torch.ones(5, 5, device=device) - v = t.view(25) - self.assertTrue(self.is_view_of(t, v)) - - v[6] = 0 - self.assertEqual(t[1, 1], v[6]) - - def test_view_as_view(self, device): - t = torch.ones(5, 5, device=device) - e = torch.empty((25,)) - v = t.view_as(e) - self.assertTrue(self.is_view_of(t, v)) - - v[6] = 0 - self.assertEqual(t[1, 1], v[6]) - - def test_contiguous_self(self, device): - t = torch.ones(5, 5, device=device) - s = t.contiguous() - self.assertTrue(s is t) - - def test_contiguous_nonview(self, device): - t = torch.ones(5, 5, device=device) - nv = t.t().contiguous() - self.assertTrue(not self.is_view_of(t, nv)) - - nv[0, 0] = 0 - self.assertNotEqual(t[0, 0], nv[0, 0]) - - def test_reshape_view(self, device): - t = torch.ones(5, 5, device=device) - v = torch.reshape(t, (25,)) - self.assertTrue(self.is_view_of(t, v)) - - v[6] = 0 - self.assertEqual(t[1, 1], v[6]) - - def test_reshape_as_view(self, device): - t = torch.ones(5, 5, device=device) - e = torch.empty((25,), device=device) - v = t.reshape_as(e) - self.assertTrue(self.is_view_of(t, v)) - - v[6] = 0 - self.assertEqual(t[1, 1], v[6]) - - def test_reshape_nonview(self, device): - t = torch.ones(5, 5, device=device) - nv = torch.reshape(t.t(), (25,)) - self.assertTrue(not self.is_view_of(t, nv)) - - nv[6] = 0 - self.assertNotEqual(t[1, 1], nv[6]) - - def test_basic_indexing_slice_view(self, device): - t = torch.ones(5, 5, device=device) - v = t[:2, :3] - self.assertTrue(self.is_view_of(t, v)) - - v[0, 0] = 0 - self.assertEqual(t[0, 0], v[0, 0]) - - def test_basic_indexing_ellipses_view(self, device): - t = torch.ones(5, 5, device=device) - v = t[..., :2] - self.assertTrue(self.is_view_of(t, v)) - - v[0, 0] = 0 - self.assertEqual(t[0, 0], v[0, 0]) - - def test_basic_indexing_newaxis_view(self, device): - t = torch.ones(5, 5, device=device) - v = t[None, :2, 3] - self.assertTrue(self.is_view_of(t, v)) - - v[0, 0] = 0 - self.assertEqual(t[0, 3], v[0, 0]) - - def test_advanced_indexing_nonview(self, device): - t = torch.ones(3, 3, device=device) - rows = torch.tensor([[0, 0], [2, 2]], device=device) - cols = torch.tensor([[0, 1], [2, 2]], device=device) - nv = t[rows, cols] - self.assertTrue(not self.is_view_of(t, nv)) - - nv[1, 1] = 0 - self.assertNotEqual(t[2, 2], nv[1, 1]) - - def test_advanced_indexing_assignment(self, device): - t = torch.ones(3, 3, device=device) - rows = torch.tensor([[0, 0], [2, 2]], device=device) - cols = torch.tensor([[0, 1], [2, 2]], device=device) - t[rows, cols] = 0 - self.assertEqual(t[2, 2], 0) - - @unittest.skip("See https://github.com/pytorch/pytorch/pull/32720") - def test_chunk_view(self, device): - t = torch.zeros(3, 3, device=device) - l = torch.chunk(t, 3) - - for idx, v in enumerate(l): - self.assertTrue(self.is_view_of(t, v)) - - v[0, 0] = idx + 1 - self.assertEqual(t[idx, 0], v[0, 0]) - - @unittest.skip("See https://github.com/pytorch/pytorch/pull/32720") - def test_split_view(self, device): - t = torch.zeros(3, 3, device=device) - l = torch.split(t, [1, 1, 1]) - - for idx, v in enumerate(l): - self.assertTrue(self.is_view_of(t, v)) - - v[0, 0] = idx + 1 - self.assertEqual(t[idx, 0], v[0, 0]) - - def test_movedim_view(self, device): - t = torch.zeros(3, 3, device=device) - out = torch.movedim(t, (0, 1), (1, 0)) - - self.assertTrue(self.is_view_of(t, out)) - - # Randomly change values in output - # and verify that original is changed - # as well. - for _ in range(3): - idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2) - out[idx_1, idx_2] = random.random() - self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2]) - # Below are fixtures and functions that generate tensor op comparison tests # These tests run a single op on both a CPU and device tensor and compare the # the results. In-place variants of the ops can also be run. @@ -19582,10 +6353,6 @@ def test_movedim_view(self, device): torch.uint8 ] -# _types2 adds bfloat16 type to _types only on ROCm. Should eventually be unified -# with _types when bfloat16 bringup is complete on all platforms. -_types2 = _types + [torch.bfloat16] if TEST_WITH_ROCM else _types - _float_types = [torch.half, torch.float, torch.double] _complex_types = [torch.cfloat, torch.cdouble] @@ -19594,12 +6361,8 @@ def test_movedim_view(self, device): _float_types_no_half = [torch.float, torch.double] -# _float_types2 adds bfloat16 type to _float_types only on ROCm. Should eventually be unified -# with _float_types when bfloat16 bringup is complete on all platforms -_float_types2 = _float_types + [torch.bfloat16] if TEST_WITH_ROCM else _float_types - _signed_types = [ - torch.half, torch.float, torch.double, + torch.half, torch.bfloat16, torch.float, torch.double, torch.int8, torch.short, torch.int, torch.long ] @@ -19637,7 +6400,10 @@ def test_movedim_view(self, device): def _number(floating, integer, dtype): if dtype in [torch.half, torch.float, torch.double, torch.bfloat16]: return floating - return integer + elif dtype in [torch.cfloat, torch.cdouble]: + return floating * (1 + 1j) + else: + return integer # Converts half/bfloat16 dtype to float when device is cpu def _convert_t(dtype, device): @@ -19725,7 +6491,7 @@ def inner(self, device, dtype): return inner return decorator - +# TODO: these tests should be refactored into other test suites using OpInfos # TODO: random functions, cat, gather, scatter, index*, masked*, # resize, resizeAs, storage_offset, storage, stride, unfold # Each tests is defined in tensor_op_tests as a tuple of: @@ -19789,11 +6555,11 @@ def inner(self, device, dtype): ('baddbmm', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)], 1e-2, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)), ('baddbmm', 'scalar', _small_3d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], - 1e-2, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM), _cpu_types, True, - [_wrap_maybe_warns("This overload of baddbmm_? is deprecated")]), + 1e-2, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types, _cpu_types, True, + [tf32_on_and_off(0.05), _wrap_maybe_warns("This overload of baddbmm_? is deprecated")]), ('baddbmm', 'two_scalars', _small_3d, lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], - 1e-2, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM), _cpu_types, True, - [_wrap_maybe_warns("This overload of baddbmm_? is deprecated")]), + 1e-2, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types, + _cpu_types, True, [tf32_on_and_off(0.05), _wrap_maybe_warns("This overload of baddbmm_? is deprecated")]), ('bmm', '', _small_3d, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5, _float_types_no_half, _cpu_types, False), ('addcdiv', '', _small_2d, @@ -19805,10 +6571,10 @@ def inner(self, device, dtype): _small_2d(t, d, has_zeros=False)], 1, 1e-5, 1e-3, _float_types, _cpu_types, True), ('addcmul', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)], 1e-2, 1e-1, 1e-3, - torch.testing.get_all_dtypes(include_complex=False, include_bool=False)), + torch.testing.get_all_dtypes(include_complex=True, include_bool=False)), ('addcmul', 'scalar', _small_3d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], 1e-2, - 1e-1, 1e-5, torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, True, + 1e-1, 1e-5, torch.testing.get_all_dtypes(include_complex=True, include_bool=False), _cpu_types, True, [_wrap_maybe_warns("This overload of addcmul_? is deprecated")]), ('addmm', '', _medium_2d, lambda t, d: [_medium_2d(t, d), _medium_2d(t, d)], 1e-1, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM), @@ -19825,7 +6591,7 @@ def inner(self, device, dtype): torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types_skip_rocm, _cpu_types, True, [], 0, True), ('addmv', 'scalar', _medium_1d, - lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)], 1e-2, 1e-1, 1e-4, + lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)], 1e-2, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types_skip_rocm, _cpu_types, True, [_wrap_maybe_warns("This overload of addmv_? is deprecated")]), ('addmv', 'two_scalars', _medium_1d, @@ -19833,7 +6599,6 @@ def inner(self, device, dtype): torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types_skip_rocm, _cpu_types, True, [_wrap_maybe_warns("This overload of addmv_? is deprecated")]), ('atan2', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-2, 1e-5, 1e-5, _types, _types_no_half), - ('angle', '', _small_3d, lambda t, d: [], 0, 0, 0, _types_no_half, [torch.bfloat16], False), ('fmod', 'value', _small_3d, lambda t, d: [3], 1e-3), ('fmod', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d, has_zeros=False)], 1e-3), ('chunk', '', _medium_2d, lambda t, d: [4], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), @@ -19841,8 +6606,10 @@ def inner(self, device, dtype): ('chunk', 'neg_dim', _medium_2d, lambda t, d: [4, -2], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('clamp', 'neg', _medium_2d, lambda t, d: [-1, 5], 1e-5, 1e-2, 1e-5, _signed_types, [torch.bfloat16]), ('clamp', 'pos', _medium_2d, lambda t, d: [1, 5], 1e-5, 1e-2, 1e-5, _unsigned_types, [torch.bfloat16]), - ('clamp_min', '', _medium_2d, lambda t, d: [1], 1e-2, 1e-2, 1e-5, _types, [torch.bfloat16]), - ('clamp_max', '', _medium_2d, lambda t, d: [1], 1e-2, 1e-2, 1e-5, _types, [torch.bfloat16]), + ('clamp_min', '', _medium_2d, lambda t, d: [1], 1e-2, 1e-2, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=True), [torch.bfloat16]), + ('clamp_max', '', _medium_2d, lambda t, d: [1], 1e-2, 1e-2, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=True), [torch.bfloat16]), ('clone', '', _medium_2d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('contiguous', '', _medium_2d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('conj', '', _small_3d, lambda t, d: [], 1e-5, 0, 1e-5, _types_no_half, [torch.bfloat16], False), @@ -19866,10 +6633,14 @@ def inner(self, device, dtype): ('dot', '', _medium_1d, lambda t, d: [_medium_1d(t, d)], 1e-2, 1e-5, 1e-5, _float_types + _complex_types, _cpu_types, False), ('element_size', '', _medium_1d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _float_types_no_half, _cpu_types, False), - ('eq', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5, _types2), - ('eq', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)], 1e-5, 1e-5, 1e-5, _types2), - ('ne', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5, _types2), - ('ne', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)], 1e-5, 1e-5, 1e-5, _types2), + ('eq', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False)), + ('eq', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)], 1e-5, 1e-5, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False)), + ('ne', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False)), + ('ne', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)], 1e-5, 1e-5, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False)), ('equal', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('equal', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), @@ -19883,10 +6654,14 @@ def inner(self, device, dtype): ('lcm', '', _small_3d, lambda t, d: [_small_3d(t, d)], 0, 0, 0, [torch.int16, torch.int32, torch.int64], [torch.int16, torch.int32, torch.int64], True, [onlyOnCPUAndCUDA]), - ('ge', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, _types2), - ('le', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, _types2), - ('gt', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, _types2), - ('lt', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, _types2), + ('ge', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False)), + ('le', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False)), + ('gt', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False)), + ('lt', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False)), ('is_contiguous', '', _medium_2d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), # TODO: can't check negative case - cross-device copy is contiguous ('is_same_size', 'negative', _medium_2d, lambda t, d: [_small_3d(t, d)], @@ -19914,9 +6689,12 @@ def inner(self, device, dtype): 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('minimum', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), - ('mean', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False), - ('mean', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False), - ('mean', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-2, 1e-2, torch.testing.get_all_fp_dtypes(), _cpu_types, False), + ('mean', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, + torch.testing.get_all_fp_dtypes() + torch.testing.get_all_complex_dtypes(), _cpu_types, False), + ('mean', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-2, 1e-5, + torch.testing.get_all_fp_dtypes() + torch.testing.get_all_complex_dtypes(), _cpu_types, False), + ('mean', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-2, 1e-2, + torch.testing.get_all_fp_dtypes() + torch.testing.get_all_complex_dtypes(), _cpu_types, False), # Double here because the CPU result will be wrong otherwise ('mean', '64bit_indexing', _giant_1d, lambda t, d: [], 1e-3, 1e-5, 1e-5, [torch.double], _cpu_types, False, [slowTest]), @@ -19927,14 +6705,14 @@ def inner(self, device, dtype): 1e-5, 1e-5, 1e-5, _float_types_no_half), ('mvlgamma', '2d_p=2', lambda t, d: _small_2d(t, d).clamp(0.6, 10), lambda t, d: [2], 1e-5, 1e-5, 1e-5, _float_types_no_half), - ('remainder', 'value', _small_3d, lambda t, d: [3], 1e-1, 1e-5, 1e-5, _signed_types), - ('remainder', 'negative_value', _small_3d, lambda t, d: [-3], 1e-1, 1e-5, 1e-5, _signed_types), + ('remainder', 'value', _small_3d, lambda t, d: [3], 1e-1, 1e-2, 1e-5, _signed_types), + ('remainder', 'negative_value', _small_3d, lambda t, d: [-3], 1e-1, 1e-2, 1e-5, _signed_types), ('remainder', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d, has_zeros=False)], - 1e-1, 1e-5, 1e-5, _signed_types), + 1e-1, 1e-2, 1e-5, _signed_types), ('remainder', 'negative_tensor', _small_3d, lambda t, d: [0 - _small_3d(t, d, has_zeros=False)], - 1e-1, 1e-5, 1e-5, _signed_types), + 1e-1, 1e-2, 1e-5, _signed_types), ('std', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False), ('std', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False), ('std', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False), @@ -19947,10 +6725,12 @@ def inner(self, device, dtype): ('narrow', '', _small_3d, lambda t, d: [1, 3, 2], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('narrow', 'neg_dim', _small_3d, lambda t, d: [-1, 3, 2], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('nonzero', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), - ('norm', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, _float_types2, _cpu_types, False), - ('norm', '3_norm', _small_3d, lambda t, d: [3], 1e-1, 1e-1, 1e-5, _float_types2, _cpu_types, False), - ('norm', '3_norm_dim', _small_3d, lambda t, d: [3, 0], 1e-1, 1e-1, 1e-5, _float_types2, _cpu_types, False), - ('norm', '3_norm_neg_dim', _small_3d, lambda t, d: [3, -2], 1e-1, 1e-1, 1e-5, _float_types2, _cpu_types, False), + ('norm', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False), + ('norm', '3_norm', _small_3d, lambda t, d: [3], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False), + ('norm', '3_norm_dim', _small_3d, lambda t, d: [3, 0], 1e-1, 1e-1, 1e-5, + torch.testing.get_all_fp_dtypes(), _cpu_types, False), + ('norm', '3_norm_neg_dim', _small_3d, lambda t, d: [3, -2], 1e-1, 1e-1, 1e-5, + torch.testing.get_all_fp_dtypes(), _cpu_types, False), ('new_ones', '', _small_3d, lambda t, d: [1, 2, 3, 4, 5], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('permute', '', _new_t((1, 2, 3, 4)), lambda t, d: [2, 1, 3, 0], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('put_', '', _new_t((2, 5, 3)), @@ -19965,12 +6745,16 @@ def inner(self, device, dtype): torch.LongTensor([[1], [2]]).to(dtype=_convert_t(t, d), device=d), True], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), - ('prod', '', lambda t, d: _small_2d(t, d, oneish=True), - lambda t, d: [], 1e-2, 1e-1, 1e-5, _types2, _cpu_types, False), - ('prod', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-1, 1e-5, _types2, _cpu_types, False), - ('prod', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-1, 1e-5, _types2, _cpu_types, False), - ('sum', '', _small_2d, lambda t, d: [], 1e-2, 1e-2, 1e-5, _types2, _cpu_types, False), - ('sum', 'dim', _small_3d, lambda t, d: [1], 1e-2, 1e-2, 1e-5, _types2, _cpu_types, False), + ('prod', '', lambda t, d: _small_2d(t, d, oneish=True), lambda t, d: [], 1e-2, 1e-1, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False), + ('prod', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-1, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False), + ('prod', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-1, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False), + ('sum', '', _small_2d, lambda t, d: [], 1e-2, 1e-2, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False), + ('sum', 'dim', _small_3d, lambda t, d: [1], 1e-2, 1e-2, 1e-5, + torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False), ('sum', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-2, 1e-5, 1e-5, _types, _cpu_types, False), ('sum', 'complex', _small_2d, lambda t, d: [], 1e-2, 1e-2, 1e-5, _complex_types, _cpu_types, False), ('sum', 'complex_dim', _small_3d, lambda t, d: [1], 1e-2, 1e-2, 1e-5, _complex_types, _cpu_types, False), @@ -19978,15 +6762,28 @@ def inner(self, device, dtype): ('renorm', '2_norm', _small_3d, lambda t, d: [2, 1, 1], 1e-3, 1e-5, 1e-5, _float_types), ('renorm', '2_norm_neg_dim', _small_3d, lambda t, d: [2, -1, 1], 1e-3, 1e-5, 1e-5, _float_types), ('renorm', '1_5_norm', _small_3d, lambda t, d: [1.5, 1, 1], 1e-3, 1e-5, 1e-5, _float_types), - ('repeat', '', _small_2d, lambda t, d: [2, 2, 2], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('size', '', _new_t((1, 2, 3, 4)), lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('size', 'dim', _new_t((1, 2, 3, 4)), lambda t, d: [1], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('size', 'neg_dim', _new_t((1, 2, 3, 4)), lambda t, d: [-2], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('sort', '', _small_3d_unique, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), - ('sort', 'dim', _small_3d_unique, lambda t, d: [1], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), - ('sort', 'neg_dim', _small_3d_unique, lambda t, d: [-1], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), - ('sort', 'dim_descending', _small_3d_unique, lambda t, d: [1, True], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), - ('sort', 'neg_dim_descending', _small_3d_unique, lambda t, d: [-1, True], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), + ('sort', 'stable', _small_3d_unique, lambda t, d: [0, False, True], + 1e-5, 1e-5, 1e-5, _types, _cpu_types, False, [onlyCPU]), + ('sort', 'dim', _small_3d_unique, lambda t, d: [1], + 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), + ('sort', 'dim_stable', _small_3d_unique, lambda t, d: [1, False, True], + 1e-5, 1e-5, 1e-5, _types, _cpu_types, False, [onlyCPU]), + ('sort', 'neg_dim', _small_3d_unique, lambda t, d: [-1], + 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), + ('sort', 'neg_dim_stable', _small_3d_unique, lambda t, d: [-1, False, True], + 1e-5, 1e-5, 1e-5, _types, _cpu_types, False, [onlyCPU]), + ('sort', 'dim_descending', _small_3d_unique, lambda t, d: [1, True, False], + 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), + ('sort', 'dim_descending_stable', _small_3d_unique, lambda t, d: [1, True, True], + 1e-5, 1e-5, 1e-5, _types, _cpu_types, False, [onlyCPU]), + ('sort', 'neg_dim_descending', _small_3d_unique, lambda t, d: [-1, True, False], + 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), + ('sort', 'neg_dim_descending_stable', _small_3d_unique, lambda t, d: [-1, True, True], + 1e-5, 1e-5, 1e-5, _types, _cpu_types, False, [onlyCPU]), ('split', '', _small_3d, lambda t, d: [2], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('split', 'dim', _small_3d, lambda t, d: [2, 1], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('split', 'neg_dim', _small_3d, lambda t, d: [2, -3], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), @@ -20001,11 +6798,11 @@ def inner(self, device, dtype): ('transpose', 'neg_dim', _new_t((1, 2, 3, 4)), lambda t, d: [-1, -2], ), ('tolist', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('topk', 'dim_sort', _small_3d_unique, lambda t, d: [2, 1, False, True], - 1e-5, 1e-5, 1e-5, _types2, _cpu_types, False), + 1e-5, 1e-5, 1e-5, torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False), ('topk', 'neg_dim_sort', _small_3d_unique, lambda t, d: [2, -1, False, True], - 1e-5, 1e-5, 1e-5, _types2, _cpu_types, False), + 1e-5, 1e-5, 1e-5, torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False), ('topk', 'dim_desc_sort', _small_3d_unique, lambda t, d: [2, 1, True, True], - 1e-5, 1e-5, 1e-5, _types2, _cpu_types, False), + 1e-5, 1e-5, 1e-5, torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False), ('trace', '', _medium_2d, lambda t, d: [], 1e-3, 1e-5, 1e-5, _types, _cpu_types, False), ('tril', '', _medium_2d, lambda t, d: [],), ('tril', 'zero_stride', _medium_2d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), @@ -20030,10 +6827,6 @@ def inner(self, device, dtype): ('rot90', 'k1_d12', _small_3d, lambda t, d: [1, [1, 2]], 1e-5, 1e-5, 1e-5, _types + _complex_types, _cpu_types, False), ('rot90', 'k1_neg_d', _small_3d, lambda t, d: [1, [1, -1]], 1e-5, 1e-5, 1e-5, _types + _complex_types, _cpu_types, False), ('rot90', 'default', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types + _complex_types, _cpu_types, False), - ('rsqrt', '', lambda t, d: _small_3d(t, d) + 1, lambda t, d: [], 1e-2, 1e-5, 1e-4, _float_types_no_half), - ('sinh', '', lambda t, d: _small_3d(t, d).clamp(-1, 1), lambda t, d: [], 1e-3, 1e-5, 1e-5, _float_types), - ('tan', '', lambda t, d: _small_3d(t, d).clamp(-1, 1), lambda t, d: [], 1e-3, 1e-5, 1e-5, _float_types), - ('tan', 'complex', lambda t, d: _small_3d(t, d), lambda t, d: [], 1e-3, 1e-5, 1e-5, _complex_types), ('__lshift__', '', lambda t, d: torch.pow(2, torch.arange(1, 5).to(dtype=_convert_t(t, d), device=d)), lambda t, d: [2], @@ -20054,43 +6847,19 @@ def inner(self, device, dtype): ('geqrf', '', _new_t((20, 20)), lambda t, d: [], 1e-5, 1e-5, 3e-4, _float_types_no_half, _cpu_types, False, [skipCUDAIfNoMagma]), ('eig', 'with_eigvec', _new_t((10, 10)), lambda t, d: [True], - 1e-5, 1e-5, 1e-5, _float_types_no_half, _cpu_types, False, [skipCUDAIfNoMagma]), + 1e-5, 1e-5, 1e-5, _float_types_no_half, _cpu_types, False, [skipCUDAIfNoMagma, onlyOnCPUAndCUDA]), ('abs', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, torch.testing.get_all_dtypes(include_complex=False, include_bool=False), [torch.bfloat16]), ('sign', '', _small_3d, lambda t, d: []), - ('log', '', _small_3d, lambda t, d: [], 1e-2, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), - ('log10', '', _small_3d, lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), - ('log1p', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types_no_half, [torch.bfloat16]), - ('log2', '', _small_3d, lambda t, d: [], 1e-2, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), - ('sigmoid', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes()), ('logit', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes()), - ('sqrt', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), - ('tanh', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, - torch.testing.get_all_fp_dtypes() + _complex_types, [torch.bfloat16]), - ('asin', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]), - ('atan', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]), - ('acosh', '', lambda t, d: _small_3d(t, d) + 1, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes()), - ('asinh', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes()), - ('atanh', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes()), - ('erf', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), - ('erfc', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]), - ('erfinv', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]), - ('exp', '', _small_3d, lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes()), - ('exp', 'small', lambda t, d: _small_3d(t, d).clamp(-1, 1), - lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), - ('expm1', '', _small_3d, lambda t, d: [], 1e-2, 1e-2, 1e-5, _float_types), - ('expm1', 'small', lambda t, d: _small_3d(t, d).clamp(-1, 1), - lambda t, d: [], 1e-2, 1e-2, 1e-5, _float_types, [torch.bfloat16]), ('rad2deg', '', _small_3d, lambda t, d: [], 1e-1, 1e-0, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), ('deg2rad', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), - ('reciprocal', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]), ('floor', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]), ('frac', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]), ('round', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]), ('trunc', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]), ('ceil', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]), ('lgamma', '', _small_3d, lambda t, d: [], 1e-2, 1e-1, 1e-5, _float_types_no_half, [torch.bfloat16]), - ('digamma', 'op', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e0, _float_types_no_half), ] # Creates and decorates a generic test and adds it to the class. @@ -20215,346 +6984,26 @@ def caller(cls, for test in tensor_op_tests: caller(cls, *test) -def _generate_reference_input(dtype, device): - input = [] - input.append(list(range(-5, 5))) - input.append([0 for x in range(-5, 5)]) - input.append([x + 1e-6 for x in range(-5, 5)]) - # Some vectorized implementations don't support large values - input.append([x + 1e10 for x in range(-5, 5)]) - input.append([x - 1e10 for x in range(-5, 5)]) - input.append([*torch.randn(7).tolist(), math.inf, -math.inf, math.nan]) - input.append((torch.randn(10) * 1e6).tolist()) - input.append([math.pi * (x / 2) for x in range(-5, 5)]) - return torch.tensor(input, dtype=dtype, device=device) - -def _generate_gamma_input(dtype, device, test_poles=True): - input = [] - input.append((torch.randn(10).abs() + 1e-4).tolist()) - input.append((torch.randn(10).abs() + 1e6).tolist()) - zeros = torch.linspace(-9.5, -0.5, 10) - input.append(zeros.tolist()) - input.append((zeros - 0.49).tolist()) - input.append((zeros + 0.49).tolist()) - input.append((zeros + (torch.rand(10) * 0.99) - 0.5).tolist()) - - if test_poles: - input.append([-0.999999994, -1.999999994, -2.0000000111, - -100.99999994, -1931.99999994, 0.000000111, - -0.000000111, 0, -2, -329]) - return torch.tensor(input, dtype=dtype, device=device) - -# this class contains information needed to generate tests for torch math functions -# the generated tests compare torch implementation with the reference numpy/scipy implementation, -# and also check proper behavior for contiguous/discontiguous/inplace outputs. -class _TorchMathTestMeta(object): - def __init__(self, - opstr, - args=(), - reffn=None, - refargs=lambda x: (x.numpy(),), - input_fn=_generate_reference_input, - inputargs=(), - substr='', - make_inplace=True, - decorators=None, - ref_backend='numpy', - rtol=None, - atol=None, - dtypes=_float_types_no_half, - replace_inf_with_nan=False): - self.opstr = opstr - self.args = args - self.reffn = reffn # reffn is either callable or ref_backend attribute, set to opstr if not specified - self.refargs = refargs - self.input_fn = input_fn - self.inputargs = inputargs - self.substr = substr - self.make_inplace = make_inplace - assert ref_backend == 'numpy' or ref_backend == 'scipy' - self.ref_backend = ref_backend - if ref_backend == 'numpy': - self.ref_decorator = [unittest.skipIf(not TEST_NUMPY, "Numpy not found")] - elif ref_backend == 'scipy': - self.ref_decorator = [unittest.skipIf(not TEST_SCIPY, "Scipy not found")] - self.decorators = decorators - self.rtol = rtol - self.atol = atol - self.dtypes = dtypes - self.replace_inf_with_nan = replace_inf_with_nan - -torch_op_tests = [_TorchMathTestMeta('sqrt'), - _TorchMathTestMeta('erf', ref_backend='scipy'), - _TorchMathTestMeta('erfc', ref_backend='scipy'), - _TorchMathTestMeta('exp'), - _TorchMathTestMeta('expm1'), - _TorchMathTestMeta('floor'), - _TorchMathTestMeta('ceil'), - _TorchMathTestMeta('rad2deg'), - _TorchMathTestMeta('deg2rad'), - _TorchMathTestMeta('rsqrt', reffn=lambda x: np.reciprocal(np.sqrt(x))), - _TorchMathTestMeta('frac', reffn='fmod', refargs=lambda x: (x.numpy(), 1)), - _TorchMathTestMeta('trunc'), - _TorchMathTestMeta('round'), - # FIXME lgamma produces different result compared to scipy at -inf - _TorchMathTestMeta('lgamma', reffn='gammaln', ref_backend='scipy', replace_inf_with_nan=True), - _TorchMathTestMeta('polygamma', args=[0], substr='_0', reffn='polygamma', - refargs=lambda x: (0, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False], - ref_backend='scipy'), - _TorchMathTestMeta('polygamma', args=[1], substr='_1', reffn='polygamma', - refargs=lambda x: (1, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False], - ref_backend='scipy', rtol=0.0008, atol=1e-5), - _TorchMathTestMeta('polygamma', args=[2], substr='_2', reffn='polygamma', - refargs=lambda x: (2, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False], - ref_backend='scipy', rtol=0.0008, atol=1e-5), - _TorchMathTestMeta('digamma', - input_fn=_generate_gamma_input, inputargs=[True], ref_backend='scipy', - replace_inf_with_nan=True), - _TorchMathTestMeta('abs', input_fn=_medium_2d, dtypes=_types_no_half, rtol=0., atol=0.), - _TorchMathTestMeta('logit', ref_backend='scipy')] - - -def generate_torch_test_functions(cls, testmeta, inplace): - opstr = testmeta.opstr if not inplace else testmeta.opstr + "_" - - def torchfn(x): - return getattr(x, opstr)(*testmeta.args) - - def fn_check_reference(self, device, dtype): - def reffn(x): - backend = np if testmeta.ref_backend == 'numpy' else scipy.special - opstr = None - if testmeta.reffn is None: - opstr = testmeta.opstr - elif isinstance(testmeta.reffn, str): - opstr = testmeta.reffn - if callable(testmeta.reffn): - fn = testmeta.reffn - else: - assert opstr is not None, "invalid reffn" - fn = getattr(backend, opstr) - return fn(*testmeta.refargs(x)) - - inp = testmeta.input_fn(dtype, device, *testmeta.inputargs) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - expected = torch.from_numpy(reffn(inp)) - actual = torchfn(inp) - if testmeta.replace_inf_with_nan: - actual[(actual == -inf) | (actual == inf)] = nan - expected[(expected == -inf) | (expected == inf)] = nan - - torch.testing.assert_allclose(actual, expected, rtol=testmeta.rtol, atol=testmeta.atol) - - def fn_non_contig(self, device, dtype) -> None: - shapes = [(5, 7), (1024,)] - for shape in shapes: - contig = _make_tensor(shape, dtype=dtype, device=device) - non_contig = torch.empty(shape + (2,), dtype=dtype)[..., 0] - non_contig.copy_(contig) - self.assertFalse(non_contig.is_contiguous()) - self.assertEqual(torchfn(contig), torchfn(non_contig), msg='non-contiguous') - - def fn_non_contig_index(self, device, dtype): - contig = _make_tensor((2, 2, 1, 2), dtype=dtype, device=device) - non_contig = contig[:, 1, ...] - contig = non_contig.clone() - self.assertFalse(non_contig.is_contiguous()) - self.assertEqual(torchfn(contig), torchfn(non_contig), msg='non-contiguous index') - - def fn_non_contig_expand(self, device, dtype): - shapes = [(1, 3), (1, 7), (5, 7)] - for shape in shapes: - contig = _make_tensor(shape, dtype=dtype, device=device) - non_contig = contig.clone().expand(3, -1, -1) - self.assertFalse(non_contig.is_contiguous()) - contig = torchfn(contig) - non_contig = torchfn(non_contig) - for i in range(3): - self.assertEqual(contig, non_contig[i], msg='non-contiguous expand[' + str(i) + ']') - - def fn_contig_size1(self, device, dtype): - contig = _make_tensor((5, 100), dtype=dtype, device=device) - contig = contig[:1, :50] - contig2 = torch.empty(contig.size(), dtype=dtype) - contig2.copy_(contig) - self.assertTrue(contig.is_contiguous()) - self.assertTrue(contig2.is_contiguous()) - self.assertEqual(torchfn(contig), torchfn(contig2), msg='contiguous size1') - - def fn_contig_size1_large_dim(self, device, dtype): - contig = _make_tensor((5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4), dtype=dtype, device=device) - contig = contig[:1, :, :, :, :, :, :, :, :, :, :, :] - contig2 = torch.empty(contig.size(), dtype=dtype) - contig2.copy_(contig) - self.assertTrue(contig.is_contiguous()) - self.assertTrue(contig2.is_contiguous()) - self.assertEqual(torchfn(contig), torchfn(contig2), msg='contiguous size1') - - def fn_large(self, device, dtype): - input = _make_tensor((1024, 512), dtype=dtype, device=device) - # clone input to properly test inplace functions - actual = torchfn(input.clone()) - expected = torch.stack([torchfn(slice) for slice in input]) - self.assertEqual(actual, expected, msg='large') - - test_functions = {"test_reference_": fn_check_reference, - "test_non_contig_": fn_non_contig, - "test_non_contig_index_": fn_non_contig_index, - "test_non_contig_expand_": fn_non_contig_expand, - "test_contig_size1_": fn_contig_size1, - "test_check_contig_size1_large_dim_": fn_contig_size1_large_dim, - "test_large_": fn_large} - for name in test_functions: - if inplace and 'expand' in name: - continue - test_name = name + testmeta.opstr + testmeta.substr - if inplace: - test_name += "_inplace" - assert not hasattr(cls, test_name), "{0} already in TestTorchMathOps".format(test_name) - - decorators = [] if testmeta.decorators is None else testmeta.decorators - if 'reference' in name: - decorators = decorators + testmeta.ref_decorator - decorators = decorators + [dtypes(*testmeta.dtypes)] - fn_test = test_functions[name] - for dec in decorators: - fn_test = dec(fn_test) - setattr(cls, test_name, fn_test) - - - - -def generate_torch_op_tests(cls): - for t in torch_op_tests: - generate_torch_test_functions(cls, t, False) - if t.make_inplace: - generate_torch_test_functions(cls, t, True) - - - - - -tensor_binary_ops = [ - '__lt__', '__le__', - '__gt__', '__ge__', - '__eq__', '__ne__', - - '__add__', '__radd__', '__iadd__', - '__sub__', '__rsub__', '__isub__', - '__mul__', '__rmul__', '__imul__', - '__matmul__', '__rmatmul__', '__imatmul__', - '__truediv__', '__rtruediv__', '__itruediv__', - '__floordiv__', '__rfloordiv__', '__ifloordiv__', - '__mod__', '__rmod__', '__imod__', - '__divmod__', '__rdivmod__', '__idivmod__', - '__pow__', '__rpow__', '__ipow__', - '__lshift__', '__rlshift__', '__ilshift__', - '__rshift__', '__rrshift__', '__irshift__', - '__and__', '__rand__', '__iand__', - '__xor__', '__rxor__', '__ixor__', - '__or__', '__ror__', '__ior__', -] - - -# Test that binary math operations return NotImplemented for unknown types. -def generate_not_implemented_tests(cls): - class UnknownType: - pass - - for op in tensor_binary_ops: - @dtypes(*_types) - def test(self, device, dtype): - # Generate the inputs - tensor = _small_2d(dtype, device) - - # Runs the tensor op on the device - result = getattr(tensor, op)(UnknownType()) - self.assertEqual(result, NotImplemented) - - test_name = "test_{}_not_implemented".format(op) - assert not hasattr(cls, test_name), "{0} already in {1}".format( - test_name, cls.__name__) - - setattr(cls, test_name, test) - - class TestTensorDeviceOps(TestCase): exact_dtype = True - def _test_svd_helper(self, shape, some, col_maj, device, dtype): - cpu_tensor = torch.randn(shape, device='cpu').to(dtype) - device_tensor = cpu_tensor.to(device=device) - if col_maj: - cpu_tensor = cpu_tensor.t() - device_tensor = device_tensor.t() - cpu_result = torch.svd(cpu_tensor, some=some) - device_result = torch.svd(device_tensor, some=some) - m = min(cpu_tensor.shape[-2:]) - # torch.svd returns torch.return_types.svd which is a tuple of (U, V, S). - # - When some==False, U[..., m:] can be arbitrary. - # - When some==True, U shape: [..., m], V shape: [m, m] - # - Signs are not deterministic. If the sign of a column of U is changed - # then the corresponding column of the V has to be changed. - # Thus here we only compare result[..., :m].abs() from CPU and device. - for x, y in zip(cpu_result, device_result): - self.assertEqual(x[..., :m].abs(), y[..., :m].abs(), atol=1e-5, rtol=0) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*_float_types_no_half) - def test_svd_square(self, device, dtype): - self._test_svd_helper((10, 10), True, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*_float_types_no_half) - def test_svd_square_col_maj(self, device, dtype): - self._test_svd_helper((10, 10), True, True, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*_float_types_no_half) - def test_svd_tall_some(self, device, dtype): - self._test_svd_helper((20, 5), True, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*_float_types_no_half) - def test_svd_tall_all(self, device, dtype): - self._test_svd_helper((20, 5), False, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*_float_types_no_half) - def test_svd_tall_some_col_maj(self, device, dtype): - self._test_svd_helper((5, 20), True, True, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*_float_types_no_half) - def test_svd_tall_all_col_maj(self, device, dtype): - self._test_svd_helper((5, 20), False, True, device, dtype) - -class TestTorchMathOps(TestCase): - exact_dtype = True - class TestTorch(AbstractTestCases._TestTorchMixin): exact_dtype = True +# TODO: this empy class is temporarily instantiated for XLA compatibility +# once XLA updates their test suite it should be removed +class TestViewOps(TestCase): + pass # Generates tests # Note: test generation must be done at file scope, not within main, or # pytest will fail. add_neg_dim_tests() generate_tensor_op_tests(TestTensorDeviceOps) -generate_not_implemented_tests(TestTorchDeviceType) -generate_torch_op_tests(TestTorchMathOps) -instantiate_device_type_tests(TestTorchDeviceType, globals()) instantiate_device_type_tests(TestViewOps, globals()) -instantiate_device_type_tests(TestDevicePrecision, globals(), except_for='cpu') instantiate_device_type_tests(TestTensorDeviceOps, globals()) -instantiate_device_type_tests(TestTorchMathOps, globals(), only_for='cpu') +instantiate_device_type_tests(TestTorchDeviceType, globals()) +instantiate_device_type_tests(TestDevicePrecision, globals(), except_for='cpu') if __name__ == '__main__': run_tests() diff --git a/test/test_type_hints.py b/test/test_type_hints.py index f2429608159e5..b8635aa7554e8 100644 --- a/test/test_type_hints.py +++ b/test/test_type_hints.py @@ -1,5 +1,5 @@ import unittest -from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing._internal.common_utils import TestCase, run_tests, set_cwd import tempfile import torch import re @@ -149,33 +149,13 @@ def test_doc_examples(self): except OSError: raise unittest.SkipTest('cannot symlink') from None (stdout, stderr, result) = mypy.api.run([ - '--follow-imports', 'silent', - '--check-untyped-defs', + '--cache-dir=.mypy_cache/doc', '--no-strict-optional', # needed because of torch.lu_unpack, see gh-36584 os.path.abspath(fn), ]) if result != 0: self.fail(f"mypy failed:\n{stdout}") - @unittest.skipIf(not HAVE_MYPY, "need mypy") - def test_type_hint_examples(self): - """ - Runs mypy over all the test examples present in - `type_hint_tests` directory. - """ - test_path = os.path.dirname(os.path.realpath(__file__)) - examples_folder = os.path.join(test_path, "type_hint_tests") - examples = os.listdir(examples_folder) - for example in examples: - example_path = os.path.join(examples_folder, example) - (stdout, stderr, result) = mypy.api.run([ - '--follow-imports', 'silent', - '--check-untyped-defs', - example_path, - ]) - if result != 0: - self.fail(f"mypy failed for example {example}\n{stdout}") - @unittest.skipIf(not HAVE_MYPY, "need mypy") def test_run_mypy(self): """ @@ -202,17 +182,11 @@ def is_torch_mypyini(path_to_file): if numpy.__version__ == '1.20.0.dev0+7af1024': self.skipTest("Typeannotations in numpy-1.20.0-dev are broken") - cwd = os.getcwd() # TODO: Would be better not to chdir here, this affects the entire # process! - os.chdir(repo_rootdir) - try: - (stdout, stderr, result) = mypy.api.run([ - '--check-untyped-defs', - '--follow-imports', 'silent', - ]) - finally: - os.chdir(cwd) + with set_cwd(repo_rootdir): + (stdout, stderr, result) = mypy.api.run([]) + if result != 0: self.fail(f"mypy failed: {stdout} {stderr}") @@ -227,14 +201,11 @@ def test_run_mypy_strict(self): if not os.path.exists(mypy_inifile): self.skipTest("Can't find PyTorch MyPy strict config file") - cwd = os.getcwd() - os.chdir(repo_rootdir) - try: + with set_cwd(repo_rootdir): (stdout, stderr, result) = mypy.api.run([ '--config', mypy_inifile, ]) - finally: - os.chdir(cwd) + if result != 0: self.fail(f"mypy failed: {stdout} {stderr}") diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index 7f10915a5ac47..576b427d0beb2 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -919,9 +919,9 @@ def test_unary_op_out_casting(self, device, dtypes): t = torch.tensor((1), dtype=dtypes[0], device=device) out = torch.empty(0, dtype=dtypes[1], device=device) - ops = (torch.neg, torch.floor, torch.ceil, torch.cos, torch.erf, torch.log) - float_only_ops = {torch.floor, torch.ceil, torch.cos, torch.erf, torch.log} - real_only_ops = {torch.floor, torch.ceil, torch.erf} + ops = (torch.neg, torch.floor, torch.ceil) + float_only_ops = {torch.floor, torch.ceil} + real_only_ops = {torch.floor, torch.ceil} for op in ops: if dtypes[0] is not dtypes[1]: with self.assertRaises(RuntimeError): diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 09a3cbd583a7f..3c6a4f0e7b0a1 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -1,23 +1,28 @@ +import torch +import numpy as np + +import warnings import math from itertools import product, chain from numbers import Number - +import random import unittest -import torch - -from torch.testing._internal.common_utils import \ - (TestCase, run_tests, torch_to_numpy_dtype_dict, suppress_warnings, - TEST_NUMPY, make_tensor) -from torch.testing._internal.common_methods_invocations import \ - (unary_ufuncs) -from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, ops, dtypes) -from torch.testing import \ - (floating_types_and) - -if TEST_NUMPY: - import numpy as np +from torch._six import inf, nan +from torch.testing._internal.common_utils import ( + TestCase, run_tests, torch_to_numpy_dtype_dict, suppress_warnings, + IS_MACOS, make_tensor, TEST_SCIPY, slowTest, skipIfNoSciPy) +from torch.testing._internal.common_methods_invocations import ( + unary_ufuncs) +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, ops, dtypes, onlyCPU, onlyOnCPUAndCUDA, + onlyCUDA, dtypesIfCUDA, precisionOverride, skipCUDAIfRocm, dtypesIfCPU, + OpDTypes) +from torch.testing import ( + floating_types_and, integral_types, all_types_and_complex_and, floating_types) + +if TEST_SCIPY: + import scipy # Tests for unary "universal functions (ufuncs)" that accept a single # tensor and have common properties like: @@ -204,10 +209,18 @@ def _fn(t): t = make_tensor((5, 5), device, dtype, low=op.domain[0], high=op.domain[1]) expected = op(t) - for alt in (op.get_method(), op.get_inplace(), torch.jit.script(_fn)): + for alt, inplace in ((op.get_method(), False), (op.get_inplace(), True), + (torch.jit.script(_fn), False)): if alt is None: + continue + + if inplace and op.promotes_integers_to_float and dtype in integral_types() + (torch.bool,): + # Assert that RuntimeError is raised + # for inplace variant of Operators that + # promote integer input to floating dtype. with self.assertRaises(RuntimeError): alt(t.clone()) + continue actual = alt(t.clone()) self.assertEqual(actual, expected, rtol=0, atol=0) @@ -219,14 +232,16 @@ def assertEqualHelper(self, actual, expected, msg, *, dtype, exact_dtype=True, * # Some NumPy functions return scalars, not arrays if isinstance(expected, Number): - self.assertEqual(actual.item(), expected) + self.assertEqual(actual.item(), expected, **kwargs) elif isinstance(expected, np.ndarray): # Handles exact dtype comparisons between arrays and tensors if exact_dtype: # Allows array dtype to be float32 when comparing with bfloat16 tensors # since NumPy doesn't support the bfloat16 dtype + # Also ops like scipy.special.erf, scipy.special.erfc, etc, promote float16 + # to float32 if expected.dtype == np.float32: - assert actual.dtype in (torch.bfloat16, torch.float32) + assert actual.dtype in (torch.float16, torch.bfloat16, torch.float32) else: assert expected.dtype == torch_to_numpy_dtype_dict[actual.dtype] @@ -242,7 +257,6 @@ def assertEqualHelper(self, actual, expected, msg, *, dtype, exact_dtype=True, * # values on a range of tensors, including empty tensors, scalar tensors, # 1D tensors and a large 2D tensor with interesting and extremal values # and discontiguities. - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @suppress_warnings @ops(unary_ufuncs) def test_reference_numerics(self, device, dtype, op): @@ -270,7 +284,18 @@ def test_reference_numerics(self, device, dtype, op): else: msg = None - self.assertEqualHelper(actual, expected, msg, dtype=dtype) + exact_dtype = True + if op.promotes_integers_to_float and dtype in integral_types() + (torch.bool,): + exact_dtype = False + + if dtype in [torch.uint8, torch.int8, torch.bool]: + # NOTE: For these dtypes, PyTorch computes in the default scalar type (float) + # while NumPy computes in float16 + self.assertEqualHelper(actual, expected, msg, dtype=dtype, + exact_dtype=exact_dtype, rtol=1e-3, atol=1e-2) + continue + + self.assertEqualHelper(actual, expected, msg, dtype=dtype, exact_dtype=exact_dtype) # Tests for testing (dis)contiguity consistency @@ -377,8 +402,1449 @@ def test_batch_vs_slicing(self, device, dtype, op): self.assertEqual(actual, expected) + def _test_out_arg(self, op, input, output): + dtype = input.dtype + out_dtype = output.dtype + if dtype is out_dtype: + expected = op(input) + op(input, out=output) + self.assertEqual(output, expected) + else: + with self.assertRaises(RuntimeError): + op(input, out=output) + + def _test_out_promote_int_to_float_op(self, op, input, output): + def compare_out(op, input, out): + out_dtype = out.dtype + expected = op(input) + op(input, out=out) + self.assertEqual(out, expected.to(out_dtype)) + + dtype = input.dtype + out_dtype = output.dtype + if out_dtype.is_floating_point and not dtype.is_complex: + compare_out(op, input, output) + elif out_dtype.is_floating_point and dtype.is_complex: + if op.supports_complex_to_float: + compare_out(op, input, output) + else: + # Can't cast complex to float + with self.assertRaises(RuntimeError): + op(input, out=output) + elif out_dtype.is_complex: + compare_out(op, input, output) + else: + # Can't cast to Integral types + with self.assertRaises(RuntimeError): + op(input, out=output) + + @ops(unary_ufuncs, dtypes=OpDTypes.supported) + def test_out_arg_all_dtypes(self, device, dtype, op): + input = make_tensor((64, 64), dtype=dtype, device=device, + low=op.domain[0], high=op.domain[1]) + + for out_dtype in all_types_and_complex_and(torch.bool, torch.half): + out = torch.empty_like(input, dtype=out_dtype) + if op.promotes_integers_to_float: + self._test_out_promote_int_to_float_op(op, input, out) + else: + self._test_out_arg(op, input, out) + + @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool] + + torch.testing.get_all_fp_dtypes(include_bfloat16=False))) + def test_nan_to_num(self, device, dtype): + for contiguous in [False, True]: + x = make_tensor((64, 64), low=0., high=100., dtype=dtype, device=device) + + if dtype.is_floating_point: + # Add extremal values. + extremals = [float('nan'), float('inf'), -float('inf')] + for idx, extremal in zip(torch.randint(0, 63, (3,)), extremals): + x[idx, :] = extremal + + if not contiguous: + x = x.T + + # With args + nan = random.random() + posinf = random.random() * 5 + neginf = random.random() * 10 + + self.compare_with_numpy(lambda x: x.nan_to_num(nan=nan, posinf=posinf), + lambda x: np.nan_to_num(x, nan=nan, posinf=posinf), + x) + self.compare_with_numpy(lambda x: x.nan_to_num(posinf=posinf, neginf=neginf), + lambda x: np.nan_to_num(x, posinf=posinf, neginf=neginf), + x) + + # Out Variant + out = torch.empty_like(x) + result = torch.nan_to_num(x) + torch.nan_to_num(x, out=out) + self.assertEqual(result, out) + + result = torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) + torch.nan_to_num(x, out=out, nan=nan, posinf=posinf, neginf=neginf) + self.assertEqual(result, out) + + @unittest.skipIf(IS_MACOS, "Skip Reference: https://github.com/pytorch/pytorch/issues/47500") + @dtypes(torch.cfloat, torch.cdouble) + def test_sqrt_complex_edge_values(self, device, dtype): + # Test Reference: https://github.com/pytorch/pytorch/pull/47424 + x = torch.tensor(0. - 1.0000e+20j, dtype=dtype, device=device) + self.compare_with_numpy(torch.sqrt, np.sqrt, x) + + x = torch.tensor(-1.0000e+20 - 4988429.2000j, dtype=dtype, device=device) + self.compare_with_numpy(torch.sqrt, np.sqrt, x) + + @unittest.skipIf(not TEST_SCIPY, "Requires SciPy") + @dtypes(torch.float, torch.double) + def test_digamma_special(self, device, dtype): + # Based on SciPy test for the following special values. + # Reference: + # https://github.com/scipy/scipy/blob/3a8a3a1d4657254a6611e77e9c28feafa26e6645/scipy/special/tests/test_digamma.py#L22 + euler = 0.57721566490153286 + dataset = [(0., -0.), + (1, -euler), + (0.5, -2 * math.log(2) - euler), + (1 / 3, -math.pi / (2 * math.sqrt(3)) - 3 * math.log(3) / 2 - euler), + (1 / 4, -math.pi / 2 - 3 * math.log(2) - euler), + (1 / 6, -math.pi * math.sqrt(3) / 2 - 2 * math.log(2) - 3 * math.log(3) / 2 - euler), + (1 / 8, -math.pi / 2 - 4 * math.log(2) - + (math.pi + math.log(2 + math.sqrt(2)) - math.log(2 - math.sqrt(2))) / math.sqrt(2) - euler)] + x = torch.tensor(dataset, device=device, dtype=dtype) + self.compare_with_numpy(torch.digamma, scipy.special.digamma, x) + + @unittest.skipIf(not TEST_SCIPY, "Requires SciPy") + @dtypes(torch.float, torch.double) + def test_digamma(self, device, dtype): + # Tests pole behavior + # TODO: Add value `-1931.99999994`, to the tensor below when + # https://github.com/pytorch/pytorch/issues/49015 is fixed + tensor = torch.tensor([-0.999999994, -1.999999994, -2.0000000111, + -100.99999994, 0.000000111, + -0.000000111, 0, -0, -1, -2, -931], dtype=dtype, device=device) + self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor) + + # TODO opinfo mvlgamma + @unittest.skipIf(not TEST_SCIPY, "Scipy not found") + def test_mvlgamma(self, device): + from scipy.special import multigammaln + for d in range(1, 5): + input = torch.empty(10, device=device).uniform_(d, 10) + res_torch = torch.mvlgamma(input, d) + res_scipy = multigammaln(input.cpu().numpy(), d) + self.assertEqual(res_torch.cpu().numpy(), res_scipy, atol=1e-5, rtol=0) + + def test_mvlgamma_argcheck(self, device): + def run_test(d): + input = torch.linspace((d - 2) / 2, 10, 10, device=device) + torch.mvlgamma(input, d) + + with self.assertRaisesRegex(RuntimeError, r"All elements must be greater than \(p-1\)/2"): + run_test(3) + + # TODO opinfo polygamma + def test_polygamma_neg(self, device): + with self.assertRaisesRegex(RuntimeError, r'polygamma\(n, x\) does not support negative n\.'): + torch.polygamma(-1, torch.tensor([1.0, 2.0], device=device)) + + # TODO resolve with opinfos + @onlyCPU + def test_op_invert(self, device): + res = 0xffff - torch.arange(127, dtype=torch.int8) + for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): + a = torch.arange(127, dtype=dtype) + self.assertEqual(res.to(dtype), ~a) + + self.assertEqual(torch.tensor([True, False]), ~torch.tensor([False, True])) + + # test exceptions + for dtype in (torch.half, torch.float, torch.double): + a = torch.zeros(10, dtype=dtype) + with self.assertRaises(TypeError): + b = ~a + + @dtypes(torch.complex64, torch.complex128) + def test_abs_angle_complex_to_float(self, device, dtype): + # Constructs random complex values + from random import random + random_vals = [] + for multiplier in (-1, 1, -10, 10, -100, 100): + for _ in range(10): + random_vals.append(complex(random() * multiplier, random() * multiplier)) + + for vals in (random_vals, []): + a = np.array(vals, dtype=torch_to_numpy_dtype_dict[dtype]) + t = torch.tensor(vals, device=device, dtype=dtype) + + for fn_name in ('abs', 'angle'): + torch_fn = getattr(torch, fn_name) + np_fn = getattr(np, fn_name) + + # Tests function + np_result = torch.from_numpy(np_fn(a)) + torch_result = torch_fn(t).cpu() + self.assertEqual(np_result, torch_result, exact_dtype=True) + + # Tests float out + float_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 + np_float_out = np_fn(a).astype(torch_to_numpy_dtype_dict[float_dtype]) + float_out = torch.empty_like(t).float() + torch_fn(t, out=float_out) + # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 + self.assertEqualIgnoreType(torch.from_numpy(np_float_out), float_out.cpu()) + + # Tests float out (resized out) + float_out = torch.empty(1, device=device, dtype=float_dtype) + torch_fn(t, out=float_out) + self.assertEqual(torch.from_numpy(np_float_out), float_out.cpu()) + + # Tests complex out + np_complex_out = np_fn(a) + complex_out = torch.empty_like(t) + torch_fn(t, out=complex_out) + # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 + self.assertEqualIgnoreType(torch.from_numpy(np_complex_out), complex_out.cpu()) + + # Tests complex out (resized out) + complex_out = torch.empty(0, device=device, dtype=dtype) + torch_fn(t, out=complex_out) + # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 + self.assertEqualIgnoreType(torch.from_numpy(np_complex_out), complex_out.cpu()) + + # Tests long out behavior (expected failure) + long_out = torch.empty(0, device=device, dtype=torch.long) + with self.assertRaises(RuntimeError): + torch_fn(t, out=long_out) + + # Tests inplace + if fn_name == 'abs': + torch_inplace_method = getattr(torch.Tensor, fn_name + "_") + np_fn(a, out=a) + if dtype.is_complex: + with self.assertRaisesRegex(RuntimeError, "In-place abs is not supported for complex tensors."): + torch_inplace_method(t) + return + torch_inplace_method(t) + self.assertEqual(torch.from_numpy(a), t.cpu()) + + # Note: angle does not have an in-place variant + if fn_name == 'angle': + with self.assertRaises(AttributeError): + torch_inplace_method = getattr(torch.Tensor, fn_name + "_") + + # TODO: update sign to use opinfo-based testing + # XLA tests fail for self.assertRaises for complex dtypes + @onlyOnCPUAndCUDA + def test_sign_complex_assert_raises(self, device): + for dtype in [torch.complex64, torch.complex128]: + size = [5, 5] + tensor = torch.rand(size, dtype=dtype, device=device) + + with self.assertRaisesRegex(RuntimeError, + (r'Unlike NumPy, torch.sign is not intended to support complex numbers\. ' + r'Please use torch.sgn instead\.')): + torch.sign(torch.tensor([4j], device=device, dtype=dtype)) + + def check_internal_mem_overlap(self, inplace_op, num_inputs, + dtype, device, + expected_failure=False): + if isinstance(inplace_op, str): + inplace_op = getattr(torch.Tensor, inplace_op) + input = torch.randn(1, dtype=dtype, device=device).expand(3, 3) + inputs = [input] + [torch.randn_like(input) + for i in range(num_inputs - 1)] + if not expected_failure: + with self.assertRaisesRegex(RuntimeError, 'single memory location'): + inplace_op(*inputs) + else: + with self.assertRaises(AssertionError): + with self.assertRaisesRegex(RuntimeError, 'single memory location'): + inplace_op(*inputs) + + def unary_check_input_output_mem_overlap(self, data, sz, op, + expected_failure=False): + + def _test(op, output, input): + output_exp = torch.empty_like(output) + op(input, out=output_exp) + self.assertEqual(op(input, out=output), output_exp, msg=op.__name__) + + # output is identical to input: + _test(op, output=data[0:sz], input=data[0:sz]) + # output and input are independent: + _test(op, output=data[0:sz], input=data[sz:2 * sz]) + # output partially overlaps with input: + if not expected_failure: + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + _test(op, data[0:sz], data[1:sz + 1]) + else: + with self.assertRaises(AssertionError): + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + _test(op, data[0:sz], data[1:sz + 1]) + + # TODO: run on non-native device types + @dtypes(torch.double) + def test_unary_out_op_mem_overlap(self, device, dtype): + sz = 3 + doubles = torch.randn(2 * sz, dtype=dtype, device=device) + positives = torch.randint(1, 100, (2 * sz,), device=device).double() + ints = torch.randint(-100, 100, (2 * sz,), device=device) + unary_mem_overlap_cases = [ + ("abs", doubles, True, True, 'cpu'), + ("abs", doubles, True, True, 'cuda'), + ("acos", doubles, True, True, 'cpu'), + ("acos", doubles, True, True, 'cuda'), + ("asin", doubles, True, True, 'cpu'), + ("asin", doubles, True, True, 'cuda'), + ("atan", doubles, True, True, 'cpu'), + ("atan", doubles, True, True, 'cuda'), + ("acosh", doubles, True, True, 'cpu'), + ("acosh", doubles, True, True, 'cuda'), + ("asinh", doubles, True, True, 'cpu'), + ("asinh", doubles, True, True, 'cuda'), + ("atanh", doubles, True, True, 'cpu'), + ("atanh", doubles, True, True, 'cuda'), + ("bitwise_not", ints, True, True, 'cpu'), + ("bitwise_not", ints, True, True, 'cuda'), + ("ceil", doubles, True, True, 'cpu'), + ("ceil", doubles, True, True, 'cuda'), + ("cos", doubles, True, True, 'cpu'), + ("cos", doubles, True, True, 'cuda'), + ("cosh", doubles, True, True, 'cpu'), + ("cosh", doubles, True, True, 'cuda'), + ("digamma", doubles, True, True, 'cpu'), + ("erf", doubles, True, True, 'cpu'), + ("erf", doubles, True, True, 'cuda'), + ("erfc", doubles, True, True, 'cpu'), + ("erfc", doubles, True, True, 'cuda'), + ("erfinv", doubles, True, True, 'cpu'), + ("erfinv", doubles, True, True, 'cuda'), + ("exp", doubles, True, True, 'cpu'), + ("exp", doubles, True, True, 'cuda'), + ("exp2", doubles, True, True, 'cpu'), + ("exp2", doubles, True, True, 'cuda'), + ("expm1", doubles, True, True, 'cpu'), + ("expm1", doubles, True, True, 'cuda'), + ("floor", doubles, True, True, 'cpu'), + ("floor", doubles, True, True, 'cuda'), + ("frac", doubles, True, True, 'cpu'), + ("frac", doubles, True, True, 'cuda'), + ("i0", doubles, True, True, 'cpu'), + ("i0", doubles, True, True, 'cuda'), + ("log", positives, True, True, 'cpu'), + ("log", positives, True, True, 'cuda'), + ("log10", positives, True, True, 'cpu'), + ("log10", positives, True, True, 'cuda'), + ("log1p", positives, True, True, 'cpu'), + ("log1p", positives, True, True, 'cuda'), + ("log2", positives, True, True, 'cpu'), + ("log2", positives, True, True, 'cuda'), + ("neg", doubles, True, True, 'cpu'), + ("neg", doubles, True, True, 'cuda'), + ("reciprocal", doubles, True, True, 'cpu'), + ("reciprocal", doubles, True, True, 'cuda'), + ("round", doubles, True, True, 'cpu'), + ("round", doubles, True, True, 'cuda'), + ("rsqrt", positives, True, True, 'cpu'), + ("rsqrt", positives, True, True, 'cuda'), + ("sin", doubles, True, True, 'cpu'), + ("sin", doubles, True, True, 'cuda'), + ("sinh", doubles, True, True, 'cpu'), + ("sinh", doubles, False, True, 'cuda'), + ("sigmoid", doubles, True, True, 'cpu'), + ("sigmoid", doubles, True, True, 'cuda'), + ("logit", doubles, True, True, 'cpu'), + ("logit", doubles, True, True, 'cuda'), + ("sqrt", doubles, True, True, 'cpu'), + ("sqrt", doubles, False, True, 'cuda'), + ("tan", doubles, True, True, 'cpu'), + ("tan", doubles, True, True, 'cuda'), + ("tanh", doubles, True, True, 'cpu'), + ("tanh", doubles, True, True, 'cuda'), + ("trunc", doubles, True, True, 'cpu'), + ("trunc", doubles, True, True, 'cuda') + ] + + for (fn, inputs, has_input_output_mem_overlap_check, + has_internal_mem_overlap_check, dev) in unary_mem_overlap_cases: + if dev != device: + continue + out_fn = getattr(torch, fn) + in_fn = getattr(torch.Tensor, fn + '_') + + self.unary_check_input_output_mem_overlap(inputs, sz, out_fn, + expected_failure=not has_input_output_mem_overlap_check) + + self.check_internal_mem_overlap(in_fn, 1, dtype, dev, + expected_failure=not has_internal_mem_overlap_check) + + # TODO: review with ceil opinfo + @onlyCUDA + def test_ceil_out_mismatch(self, device): + a = torch.randn(1) + b = torch.randn(1, device=device) + self.assertRaises(RuntimeError, lambda: torch.ceil(a, out=b)) + + # TODO: opinfo hardshrink + @onlyCPU + @dtypes(torch.float, torch.double) + def test_hardshrink(self, device, dtype): + data = torch.tensor([1, 0.5, 0.3, 0.6], dtype=dtype, device=device).view(2, 2) + self.assertEqual(torch.tensor([1, 0.5, 0, 0.6], dtype=dtype, device=device).view(2, 2), + data.hardshrink(0.3)) + self.assertEqual(torch.tensor([1, 0, 0, 0.6], dtype=dtype, device=device).view(2, 2), + data.hardshrink(0.5)) + + # test default lambd=0.5 + self.assertEqual(data.hardshrink(), data.hardshrink(0.5)) + + # test non-contiguous case + self.assertEqual(torch.tensor([1, 0, 0.5, 0.6], dtype=dtype, device=device).view(2, 2), + data.t().hardshrink(0.3)) + + @onlyCPU + @dtypes(torch.float, torch.double) + def test_hardshrink_edge_cases(self, device, dtype) -> None: + def h(values, l_expected): + for l, expected in l_expected.items(): + values_tensor = torch.tensor([float(v) for v in values], + dtype=dtype, device=device) + expected_tensor = torch.tensor([float(v) for v in expected], + dtype=dtype, device=device) + self.assertEqual(expected_tensor == values_tensor.hardshrink(l), + torch.ones_like(values_tensor, dtype=torch.bool)) + + def test_helper(min, max): + h([0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], + {0.0: [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], + min: [0.0, 0.0, 0.0, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], + 0.1: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, max, -max, inf, -inf], + 1.0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, max, -max, inf, -inf], + max: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, inf, -inf], + inf: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}) + + test_helper(torch.finfo(dtype).tiny, torch.finfo(dtype).max) + + @onlyCPU + @slowTest + @dtypes(torch.float) + def test_exp_slow(self, device, dtype): + # Test for https://github.com/pytorch/pytorch/issues/17271 + # This is pretty slow on my Macbook but it only takes a few + # seconds on a beefy Xeon server + a = torch.exp(torch.ones(2 ** 31, dtype=dtype, device=device)) + b = torch.exp(torch.ones(1, dtype=dtype, device=device)) + self.assertEqual(a, b.expand(2 ** 31)) + + @precisionOverride({torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}) + @dtypesIfCUDA(torch.float, torch.double, torch.bfloat16) + @dtypes(torch.float, torch.double) + def test_hardswish(self, device, dtype): + inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000] + expectedOutput = np.multiply( + inputValues, + np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0) + + inputTensor = torch.tensor(inputValues, dtype=dtype, device=device) + expectedOutputTensor = \ + torch.tensor(expectedOutput, dtype=dtype, device=device) + + # normal + self.assertEqual(torch.nn.functional.hardswish(inputTensor), + expectedOutputTensor) + + # inplace + inputTensorCpy = inputTensor.clone().detach() + torch.nn.functional.hardswish(inputTensorCpy, inplace=True) + self.assertEqual(inputTensorCpy, expectedOutputTensor) + + @precisionOverride({torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}) + @dtypesIfCUDA(torch.float, torch.double, torch.bfloat16) + @dtypes(torch.float, torch.double) + def test_hardsigmoid(self, device, dtype): + inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000] + expectedOutput = np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0 + + inputTensor = torch.tensor(inputValues, dtype=dtype, device=device) + + # normal + self.assertEqual(torch.nn.functional.hardsigmoid(inputTensor), + torch.tensor(expectedOutput, dtype=dtype, device=device)) + + # inplace + inputTensorCpy = inputTensor.clone().detach() + self.assertEqual(torch.nn.functional.hardsigmoid(inputTensorCpy, inplace=True), + torch.tensor(expectedOutput, dtype=dtype, device=device)) + + @skipIfNoSciPy + @dtypes(torch.float, torch.double) + def test_silu(self, device, dtype): + input_np = np.random.randn(5, 8) + special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]] + input_np = np.concatenate((input_np, special_input), axis=0).astype( + torch_to_numpy_dtype_dict[dtype]) + expected_output_np = input_np * scipy.special.expit(input_np) + + expected_output = torch.from_numpy(expected_output_np).to(device) + expected_output_noncontig = expected_output.transpose(0, 1) + + atol = 1e-6 + rtol = 1e-6 + + input = torch.from_numpy(input_np).clone().contiguous().to(device) + self.assertEqual(torch.nn.functional.silu(input), expected_output, + atol=atol, rtol=rtol) + self.assertEqual(torch.nn.functional.silu(input, inplace=True), + expected_output, atol=atol, rtol=rtol) + + input = torch.from_numpy(input_np).clone().to(device) + input_noncontig = input.transpose(0, 1) + self.assertEqual(torch.nn.functional.silu(input_noncontig), + expected_output_noncontig, atol=atol, rtol=rtol) + self.assertEqual(torch.nn.functional.silu( + input_noncontig, inplace=True), expected_output_noncontig, + atol=atol, rtol=rtol) + + # do ops like threshold need a test_unary(_nonufunc) test suite? + @onlyCPU + @dtypes(*torch.testing.get_all_math_dtypes('cpu')) + def test_threshold(self, device, dtype): + if dtype != torch.uint8 and dtype != torch.float16 and not dtype.is_complex: + # 100 is wide enough to use AVX2 instructions for all types + x = torch.randn(100, dtype=torch.float, device=device).sign().to(dtype=dtype) + y = torch.threshold(x, 0, 0) + self.assertTrue(y.le(0).any()) + + def _helper_test_igamma(self, loglo, loghi, device, dtype, + torch_fcn, scipy_fcn): + exp1 = 2.71828182846 + vec1 = torch.logspace(loglo, loghi, steps=500, base=exp1, + dtype=torch.float64, device=device).unsqueeze(-1) + vec1 = vec1.to(dtype) + inputs = [ + (vec1, vec1.transpose(0, 1)), + (vec1, vec1), # for large number, it should approach 0.5 + (vec1, 0.5 * vec1), # test for considerable ratio + (vec1, 2.0 * vec1), + (vec1[::2, :], vec1[::2, :]), # contiguous/discontiguous tests + (vec1[::2, :], vec1[:vec1.shape[0] // 2, :]), + (vec1[:vec1.shape[0] // 2, :], vec1[::2, :]), + ] + half_prec = dtype in [torch.bfloat16, torch.float16] + for input0, input1 in inputs: + actual = torch_fcn(input0, input1) + if half_prec: + input0 = input0.to(torch.float) + input1 = input1.to(torch.float) + expected = scipy_fcn(input0.cpu().numpy(), input1.cpu().numpy()) + expected = torch.from_numpy(expected).to(dtype) + self.assertEqual(actual, expected) + + @skipCUDAIfRocm # see issue https://github.com/pytorch/pytorch/issues/46531 + @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64) + @dtypes(torch.float32, torch.float64) + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") + @onlyOnCPUAndCUDA + def test_igamma_common(self, device, dtype): + # test igamma for reasonable range of values + loglo = -4 # approx 0.018 + loghi = 4 # approx 54.6 + self._helper_test_igamma(loglo, loghi, device, dtype, + torch.igamma, scipy.special.gammainc) + + @skipCUDAIfRocm + @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64) + @dtypes(torch.float32, torch.float64) + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") + @onlyOnCPUAndCUDA + def test_igammac_common(self, device, dtype): + # test igammac for reasonable range of values + loglo = -4 # approx 0.018 + loghi = 4 # approx 54.6 + self._helper_test_igamma(loglo, loghi, device, dtype, + torch.igammac, scipy.special.gammaincc) + + @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64) + @dtypes(torch.float32, torch.float64) + @onlyOnCPUAndCUDA + def test_igamma_edge_cases(self, device, dtype): + tkwargs = {"dtype": dtype, "device": device} + infs = torch.zeros((3,), **tkwargs) + float("inf") + zeros = torch.zeros((3,), **tkwargs) + ones = torch.ones((3,), **tkwargs) + zero_to_large = torch.tensor([0., 1., 1e3], **tkwargs) + small_to_inf = torch.tensor([1e-3, 1., float("inf")], **tkwargs) + nans = torch.zeros((3,), **tkwargs) + float("nan") + inpouts = [ + # (a , x), out + ((zeros, small_to_inf), ones), + ((small_to_inf, zeros), zeros), + ((infs, zero_to_large), zeros), + ((zero_to_large, infs), ones), + ((zeros, zeros), nans), + ((infs, infs), nans), + ((-small_to_inf, small_to_inf), nans), + ] + for inputs, output in inpouts: + input0, input1 = inputs + calc = torch.igamma(input0, input1) + if torch.all(torch.isnan(output)): + self.assertTrue(torch.all(torch.isnan(calc))) + else: + self.assertEqual(calc, output) + + @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64) + @dtypes(torch.float32, torch.float64) + @onlyOnCPUAndCUDA + def test_igammac_edge_cases(self, device, dtype): + tkwargs = {"dtype": dtype, "device": device} + infs = torch.zeros((3,), **tkwargs) + float("inf") + zeros = torch.zeros((3,), **tkwargs) + ones = torch.ones((3,), **tkwargs) + zero_to_large = torch.tensor([0., 1., 1e3], **tkwargs) + small_to_inf = torch.tensor([1e-3, 1., float("inf")], **tkwargs) + nans = torch.zeros((3,), **tkwargs) + float("nan") + inpouts = [ + # (a , x), out + ((zeros, small_to_inf), zeros), + ((small_to_inf, zeros), ones), + ((infs, zero_to_large), ones), + ((zero_to_large, infs), zeros), + ((zeros, zeros), nans), + ((infs, infs), nans), + ((-small_to_inf, small_to_inf), nans), + ] + for inputs, output in inpouts: + input0, input1 = inputs + calc = torch.igammac(input0, input1) + if torch.all(torch.isnan(output)): + self.assertTrue(torch.all(torch.isnan(calc))) + else: + self.assertEqual(calc, output) + + def _i0_helper(self, t): + # Test by comparing to scipy + dtype = t.dtype + actual = torch.i0(t) + if dtype is torch.bfloat16: + t = t.to(torch.float32) + expected = scipy.special.i0(t.cpu().numpy()) + # Casting down for dtype float16 is required since scipy upcasts to float32 + if dtype is torch.bfloat16 or dtype is torch.float16: + expected = torch.from_numpy(expected).to(dtype) + self.assertEqual(actual, expected) + + def _i0_range_helper(self, range, device, dtype): + # i0 tests are broken up by the domain for which the function does not overflow for each dtype + # This is done to ensure that the function performs well across all possible input values, without worrying + # about inf or nan possibilities + for r in (range, -range): + t = torch.rand(1000, device=device).to(dtype) * r + self._i0_helper(t) + + @dtypesIfCUDA(*torch.testing.get_all_fp_dtypes()) + @dtypes(torch.bfloat16, torch.float32, torch.float64) + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") + def test_i0_range1(self, device, dtype): + # This tests the domain for i0 for which float16 does not overflow + # The domain is (-13.25, 13.25) + self._i0_range_helper(13.25, device, dtype) + + @dtypesIfCUDA(*torch.testing.get_all_fp_dtypes()) + @dtypes(torch.bfloat16, torch.float32, torch.float64) + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") + def test_i0_range2(self, device, dtype): + # This tests the domain for i0 for which float32 and bfloat16 does not overflow + # The domain is (-88.5, 88.5) + self._i0_range_helper(88.5, device, dtype) + + @dtypes(torch.float64) + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") + def test_i0_range3(self, device, dtype): + # This tests the domain for i0 for which float64 does not overflow + # The domain is (-709.75, 709.75) + self._i0_range_helper(709.75, device, dtype) + + @dtypesIfCUDA(*torch.testing.get_all_fp_dtypes()) + @dtypes(torch.bfloat16, torch.float32, torch.float64) + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") + def test_i0_special(self, device, dtype): + t = torch.tensor([], device=device, dtype=dtype) + self._i0_helper(t) + + t = torch.tensor([inf, -inf, nan], device=device, dtype=dtype) + self.assertTrue(torch.i0(t).isnan().all()) + + # TODO: allow large opinfo values to be opted-into via metadata + @dtypes(torch.long) + def test_abs_big_number(self, device, dtype): + bignumber = 2 ** 31 + 1 + res = torch.tensor([bignumber], device=device, dtype=dtype) + self.assertGreater(res.abs()[0], 0) + + # TODO: add signed zero testing to opinfos + @dtypes(torch.float, torch.double) + def test_abs_signed_zero(self, device, dtype): + # Both abs(0.0) and abs(-0.0) should result in 0.0 + size = 128 + 1 # pick a large enough number with remainder so that + # both vectorized and nonvectorized op is tested + inp = torch.zeros(size, device=device, dtype=dtype) + inp[::2] = -0.0 + inp = inp.abs() + for v in inp: + self.assertGreater(math.copysign(1.0, v), 0.0) + + # TODO: rationalize with abs testing and verify absolute is tested as an alias + @dtypes(torch.float) + def test_absolute(self, device, dtype): + # absolute is an alias for abs. Just check to see that results + # are the same. + t = torch.randn(10, 10, device=device, dtype=dtype) + r_abs = t.abs() + r_absolute = t.absolute() + self.assertEqual(r_abs, r_absolute) + + r_abs = torch.abs(t) + r_absolute = torch.absolute(t) + self.assertEqual(r_abs, r_absolute) + + r_abs = torch.empty((10, 10), device=device, dtype=dtype) + r_absolute = torch.empty((10, 10), device=device, dtype=dtype) + torch.abs(t, out=r_abs) + torch.absolute(t, out=r_absolute) + self.assertEqual(r_abs, r_absolute) + + from copy import deepcopy + t_copy = deepcopy(t) + t.absolute_() + t_copy.abs_() + self.assertEqual(t, t_copy) + + # Note: ROCm fails when using float tensors + # TODO: update this test to just compare against NumPy + @onlyCUDA + @dtypes(torch.double) + def test_polygamma(self, device, dtype): + cpu_tensor = torch.randn(10, 10, 10, dtype=dtype) + device_tensor = cpu_tensor.to(device) + zeros = torch.zeros(10, 10, 10, dtype=dtype) + for n in [0, 1, 2, 3, 4, 5]: + cpu_out = cpu_tensor.polygamma(n) + device_out = device_tensor.polygamma(n) + norm_errors = (device_out - cpu_out.to(device)) / device_out + self.assertEqual(norm_errors, zeros) + + cpu_tensor.requires_grad = True + for n in [0, 1, 2, 3, 4, 5]: + torch.autograd.gradcheck(lambda x: x.polygamma(n), + cpu_tensor) + + # TODO: update to compare against NumPy by rationalizing with OpInfo + @onlyCUDA + @dtypes(torch.float, torch.double) + def test_abs_zero(self, device, dtype): + # Both abs(0.0) and abs(-0.0) should result in 0.0 + abs_zeros = torch.tensor([0.0, -0.0], device=device, dtype=dtype).abs().tolist() + for num in abs_zeros: + self.assertGreater(math.copysign(1.0, num), 0.0) + + @dtypes(*torch.testing.get_all_fp_dtypes()) + def test_isfinite_isinf_isnan(self, device, dtype): + vals = (-float('inf'), float('inf'), float('nan'), -1, 0, 1) + + self.compare_with_numpy(torch.isfinite, np.isfinite, vals, device, dtype) + self.compare_with_numpy(torch.isinf, np.isinf, vals, device, dtype) + self.compare_with_numpy(torch.isnan, np.isnan, vals, device, dtype) + + @dtypes(torch.int8, torch.int16, torch.int32, torch.int64) + def test_isfinite_isinf_isnan_int(self, device, dtype): + vals = (-1, 0, 1) + + self.compare_with_numpy(torch.isfinite, np.isfinite, vals, device, dtype) + self.compare_with_numpy(torch.isinf, np.isinf, vals, device, dtype) + self.compare_with_numpy(torch.isnan, np.isnan, vals, device, dtype) + + @dtypes(*(torch.testing.get_all_fp_dtypes())) + def test_isposinf_isneginf_float(self, device, dtype): + ops = ((torch.isposinf, np.isposinf), (torch.isneginf, np.isneginf)) + vals = (-float('inf'), float('inf'), float('nan'), -1, 0, 1) + + for torch_op, numpy_op in ops: + if torch_op == torch.isposinf: + target_vals = (0, 1, 0, 0, 0, 0) + else: + target_vals = (1, 0, 0, 0, 0, 0) + + t = torch.tensor(vals, device=device, dtype=dtype) + # Manual check here as numpy does not support bfloat16 + if dtype == torch.bfloat16: + self.assertEqual(torch_op(t), + torch.tensor(target_vals, device=device, dtype=torch.bool)) + else: + self.compare_with_numpy(torch_op, numpy_op, vals, device, dtype) + + # test the boolean tensor as the `out=` parameter + out = torch.empty_like(t, dtype=torch.bool) + t_target = torch.tensor(target_vals, device=device, dtype=torch.bool) + torch_op(t, out=out) + self.assertEqual(out, t_target) + + @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool])) + def test_isposinf_isneginf_int_and_bool(self, device, dtype): + ops = ((torch.isposinf, np.isposinf), (torch.isneginf, np.isneginf)) + vals = (-1, 0, 1) + + for torch_op, numpy_op in ops: + self.compare_with_numpy(torch_op, numpy_op, vals, device, dtype) + + # test the boolean tensor as the `out=` parameter + t = torch.tensor(vals, device=device, dtype=dtype) + out = torch.empty_like(t, dtype=torch.bool) + t_target = torch.zeros_like(t, dtype=torch.bool) + torch_op(t, out=out) + self.assertEqual(out, t_target) + + @dtypes(torch.complex64, torch.complex128) + def test_isposinf_isneginf_complex(self, device, dtype): + torch_ops = (torch.isposinf, torch.isneginf) + vals = (complex(0, float('inf')), complex(1, -float('inf'))) + t = torch.tensor(vals, device=device, dtype=dtype) + out = torch.empty_like(t) + + for torch_op in torch_ops: + with self.assertRaisesRegex(RuntimeError, 'does not support complex inputs'): + torch_op(t) + with self.assertRaisesRegex(RuntimeError, 'does not support complex inputs'): + torch_op(t, out=out) + + @dtypes(*(torch.testing.get_all_dtypes(include_bool=False))) + def test_isposinf_isneginf_non_boolean_output(self, device, dtype): + # test non-boolean tensors as the `out=` parameters + # boolean outputs are tested in the above testcases + vals = (float('inf'), -float('inf'), 1.2) + t = torch.tensor(vals, device=device) + for torch_op in (torch.isposinf, torch.isneginf): + out = torch.empty_like(t, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, 'does not support non-boolean outputs'): + torch_op(t, out=out) + + @dtypes(torch.complex64, torch.complex128) + def test_isfinite_isinf_isnan_complex(self, device, dtype): + vals = ( + complex(-float('inf'), float('inf')), + complex(-float('inf'), 0), + complex(0, float('inf')), + complex(float('inf'), float('nan')), + complex(float('nan'), 0), + complex(-1, 0), + complex(0, 1) + ) + + self.compare_with_numpy(torch.isfinite, np.isfinite, vals, device, dtype) + self.compare_with_numpy(torch.isinf, np.isinf, vals, device, dtype) + self.compare_with_numpy(torch.isnan, np.isnan, vals, device, dtype) + + @dtypes(torch.complex64, torch.complex128) + def test_isreal_complex(self, device, dtype): + vals = (1, 1 + 1j, 2 + 0j, 3j, 2 - 1j, 2 - 0j) + self.compare_with_numpy(torch.isreal, np.isreal, vals, device, dtype) + + @dtypes(*torch.testing.get_all_dtypes()) + def test_isreal_noncomplex(self, device, dtype): + vals = (1, 2, 3) + # Manual check here since numpy doesn't support bfloat16 + result = torch.isreal(torch.tensor(vals, dtype=dtype)) + expected = torch.ones(result.size(), dtype=torch.bool, device=device) + self.assertEqual(result, expected) + + @dtypes(torch.complex64) + def test_isreal_nan_inf(self, device, dtype): + vals = ( + complex(-float('inf'), float('inf')), + complex(-float('inf'), 0), + complex(0, float('inf')), + complex(float('inf'), float('nan')), + complex(float('nan'), 0), + complex(-1, 0), + complex(0, 1) + ) + self.compare_with_numpy(torch.isreal, np.isreal, vals, device, dtype) + + @onlyCPU + def test_isfinite_type(self, device): + with self.assertRaises(TypeError): + torch.isfinite(1) # Parameter must be a tensor + + @onlyCPU + def test_isinf_type(self, device): + with self.assertRaises(TypeError): + torch.isinf(1) # Parameter must be a tensor + + def test_bitwise_not(self, device): + res = 0xffff - torch.arange(127, dtype=torch.int8, device=device) + for dtype in (torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): + if dtype == torch.bool: + a = torch.tensor([True, False], device=device) + expected_res = torch.tensor([False, True], device=device) + else: + a = torch.arange(127, dtype=dtype, device=device) + expected_res = res.to(dtype) + # new tensor + self.assertEqual(expected_res, a.bitwise_not()) + # out + b = torch.empty(0, dtype=dtype, device=device) + torch.bitwise_not(a, out=b) + self.assertEqual(expected_res, b) + # in-place + a.bitwise_not_() + self.assertEqual(expected_res, a) + + # test exceptions + for dtype in (torch.half, torch.float, torch.double): + a = torch.zeros(10, dtype=dtype, device=device) + # new tensor + with self.assertRaises(RuntimeError): + a.bitwise_not() + # out + b = torch.empty(0, dtype=dtype, device=device) + with self.assertRaises(RuntimeError): + torch.bitwise_not(a, out=b) + # in-place + with self.assertRaises(RuntimeError): + a.bitwise_not_() + + @dtypes(*torch.testing.get_all_dtypes()) + def test_logical_not(self, device, dtype): + data = [10, 1, 0.3, 0, -0.3, -1, -10] + a = torch.tensor(data, dtype=dtype, device=device) + if dtype == torch.bfloat16: # numpy doesn't support these dtypes + result = [False, False, False, True, False, False, False] + self.assertEqual(torch.logical_not(a), torch.tensor(result, dtype=torch.bool, device=device)) + else: + a_np = np.array(data, dtype=torch_to_numpy_dtype_dict[dtype]) + self.assertEqual(np.logical_not(a_np), torch.logical_not(a).to('cpu')) + self.assertEqual(np.logical_not(a_np, out=a_np), a.logical_not_().to('cpu')) + + @dtypes(*product(torch.testing.get_all_dtypes(), + torch.testing.get_all_dtypes())) + def test_logical_not_out(self, device, dtypes): + dtype = dtypes[0] + out_dtype = dtypes[1] + data = [10, 1, 0.3, 0, -0.3, -1, -10] + a = torch.tensor(data, dtype=dtype, device=device) + out = torch.empty_like(a, dtype=out_dtype, device=device) + if torch.bfloat16 in dtypes: # numpy doesn't support these dtypes + result = [not i for i in a] + self.assertEqual(torch.logical_not(a, out=out), torch.tensor(result, dtype=out_dtype, device=device)) + else: + out_np = np.empty(a.shape, dtype=torch_to_numpy_dtype_dict[out_dtype]) + self.assertEqual(a, a.cpu().numpy()) + torch.logical_not(a, out=out) + np.logical_not(a.cpu().numpy(), out=out_np) + self.assertEqual(out_np, out.to('cpu')) + + def test_nonzero_empty(self, device): + def assert_tuple_empty(tup, dim): + self.assertEqual(dim, len(tup)) + for t in tup: + self.assertEqual(torch.Size([0]), t.shape) + + x = torch.randn(0, 2, 0, 5, 0, device=device) + y = torch.nonzero(x) + z = torch.nonzero(x, as_tuple=True) + + self.assertEqual(0, y.numel()) + self.assertEqual(torch.Size([0, 5]), y.shape) + assert_tuple_empty(z, 5) + + x = torch.tensor(0.5, device=device) + y = torch.nonzero(x) + # nonzero with as_tuple returns a + # tuple of len 1 for a zero-dim tensor. + # This is done to match Numpy behavior. + z = torch.nonzero(x, as_tuple=True) + self.assertEqual(1, len(z)) + self.assertEqual(torch.zeros(1, dtype=torch.long), z[0]) + + x = torch.zeros((), device=device) + y = torch.nonzero(x) + z = torch.nonzero(x, as_tuple=True) + self.assertEqual(torch.Size([0, 0]), y.shape) + self.assertEqual(1, len(z)) + self.assertEqual(torch.empty(0, dtype=torch.long), z[0]) + + # TODO: rationalize with exp OpInfo + @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False) + + torch.testing.get_all_complex_dtypes())) + @dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_half=True) + + torch.testing.get_all_complex_dtypes())) + def test_exp(self, device, dtype): + for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()): + a = torch.tensor(v, dtype=dtype, device=device) * torch.arange(18, device=device) / 3 * math.pi + a = a.to(dtype) + if dtype == torch.bfloat16: + with self.assertRaises(TypeError): # compare_with_numpy doesn't support bfloat16 + self.compare_with_numpy(torch.exp, np.exp, a) + return + self.compare_with_numpy(torch.exp, np.exp, a) + + if dtype.is_complex: + inf_real_zero_imag_in = torch.tensor(complex(float('inf'), 0), device=device, dtype=dtype) + inf_real_zero_imag_out = torch.exp(inf_real_zero_imag_in).item() + self.assertTrue(math.isinf(inf_real_zero_imag_out.real)) + if self.device_type == 'cpu': + pass + # These are commented out because it cannot be consistently reproduced. + # This is incorrect. It should be zero. Need fix! + # https://github.com/pytorch/pytorch/issues/40590 + # self.assertNotEqual(inf_real_zero_imag_out.imag, 0) + # This is incorrect. They should equal. Need fix! + # https://github.com/pytorch/pytorch/issues/40590 + # with self.assertRaises(AssertionError): + # self.compare_with_numpy(torch.exp, np.exp, inf_real_zero_imag_in) + else: + self.assertEqual(inf_real_zero_imag_out.imag, 0, atol=0, rtol=0) + self.compare_with_numpy(torch.exp, np.exp, inf_real_zero_imag_in) + + zero_real_inf_imag_in = torch.tensor(complex(0, float('inf')), device=device, dtype=dtype) + zero_real_inf_imag_out = torch.exp(zero_real_inf_imag_in).item() + self.assertTrue(math.isnan(zero_real_inf_imag_out.real)) + self.assertTrue(math.isnan(zero_real_inf_imag_out.imag)) + # Ensure we are notified when NumPy changes its behavior + self.compare_with_numpy(torch.exp, np.exp, zero_real_inf_imag_in) + + inf_real_imag_in = torch.tensor(complex(float('inf'), float('inf')), device=device, dtype=dtype) + inf_real_imag_out = torch.exp(inf_real_imag_in).item() + if self.device_type == 'cpu': + pass + # This is incorrect. Need fix! https://github.com/pytorch/pytorch/issues/40590 + # This is commented out because it cannot be consistently reproduced. + # with self.assertRaises(AssertionError): + # self.compare_with_numpy(torch.exp, np.exp, inf_real_imag_in) + else: + self.assertTrue(math.isinf(inf_real_imag_out.real)) + self.assertTrue(math.isnan(inf_real_imag_out.imag)) + self.compare_with_numpy(torch.exp, np.exp, inf_real_imag_in) + + inf_real_nan_imag_in = torch.tensor(complex(float('inf'), float('nan')), device=device, dtype=dtype) + inf_real_nan_imag_out = torch.exp(inf_real_nan_imag_in).item() + if self.device_type == 'cpu': + pass + # This is incorrect. It should be inf. Need fix! https://github.com/pytorch/pytorch/issues/40590 + # This is commented out because it cannot be consistently reproduced. + # with self.assertRaises(AssertionError): + # self.compare_with_numpy(torch.exp, np.exp, inf_real_nan_imag_in) + else: + self.assertTrue(math.isinf(inf_real_nan_imag_out.real)) + self.assertTrue(math.isnan(inf_real_nan_imag_out.imag)) + self.compare_with_numpy(torch.exp, np.exp, inf_real_nan_imag_in) + + nan_real_inf_imag_in = torch.tensor(complex(float('nan'), float('inf')), device=device, dtype=dtype) + nan_real_inf_imag_out = torch.exp(nan_real_inf_imag_in).item() + self.assertTrue(math.isnan(nan_real_inf_imag_out.real)) + self.assertTrue(math.isnan(nan_real_inf_imag_out.imag)) + # Ensure we are notified when NumPy changes its behavior + self.compare_with_numpy(torch.exp, np.exp, nan_real_inf_imag_in) + + @dtypes(*torch.testing.get_all_dtypes(include_complex=False)) + def test_sign(self, device, dtype): + if dtype == torch.bool: + a_bool = torch.tensor([True, True, False, float('nan')], device=device).bool() + a_bool_target = torch.tensor([True, True, False, True], device=device).bool() + self.assertEqual(a_bool.sign(), a_bool_target, msg='sign device={} dtype=bool'.format(device)) + self.assertEqual(torch.sign(a_bool), a_bool_target, msg='sign device={} dtype=bool'.format(device)) + + a_out = torch.empty_like(a_bool) + torch.sign(a_bool, out=a_out) + self.assertEqual(a_out, a_bool_target, msg='sign_out device={} dtype=bool'.format(device)) + + a_bool.sign_() + self.assertEqual(a_bool, a_bool_target, msg='sign_ device={} dtype=bool'.format(device)) + return + + # Include NaN for floating point numbers + if dtype.is_floating_point: + dt_info = torch.finfo(dtype) + + # Create tensor (with NaN checking) + a = torch.tensor([float('nan'), -12, 0, 71, dt_info.min, dt_info.max], device=device, dtype=dtype) + a_target = torch.tensor([0, -1, 0, 1, -1, 1], device=device, dtype=dtype) + else: + dt_info = torch.iinfo(dtype) + + # If unsigned type, everything should be >= 0 + if dt_info.min == 0: + a = torch.tensor([12, 0, 71, dt_info.min, dt_info.max], device=device, dtype=dtype) + a_target = torch.tensor([1, 0, 1, 0, 1], device=device, dtype=dtype) + else: + a = torch.tensor([-12, 0, 71, dt_info.min, dt_info.max], device=device, dtype=dtype) + a_target = torch.tensor([-1, 0, 1, -1, 1], device=device, dtype=dtype) + + self.assertEqual(a.sign(), a_target, msg='sign device={} dtype={}'.format(device, dtype)) + self.assertEqual(torch.sign(a), a_target, msg='sign device={} dtype={}'.format(device, dtype)) + + out = torch.empty_like(a) + torch.sign(a, out=out) + self.assertEqual(out, a_target, msg='sign_out device={} dtype={}'.format(device, dtype)) + + a.sign_() + self.assertEqual(a, a_target, msg='sign_ device={} dtype={}'.format(device, dtype)) + + @dtypes(*(torch.testing.torch.testing.get_all_fp_dtypes())) + def test_signbit_float(self, device, dtype): + t = torch.randn(5, 5, device=device) + + if dtype == torch.bfloat16: + t_bf16 = torch.tensor([1, 0, -1], device=device, dtype=dtype) + self.assertEqual(torch.signbit(t_bf16), torch.tensor([False, False, True])) + else: + self.compare_with_numpy(torch.signbit, np.signbit, t) + + t_target = torch.signbit(t) + out = torch.empty_like(t, device=device, dtype=torch.bool) + torch.signbit(t, out=out) + self.assertEqual(out, t_target) + + t_sp = (0, float('inf'), -float('inf'), float('nan')) + if dtype == torch.bfloat16: + t_sp_df16 = torch.tensor(t_sp, device=device, dtype=dtype) + self.assertEqual(torch.signbit(t_sp_df16), torch.tensor([False, False, True, False])) + else: + self.compare_with_numpy(torch.signbit, np.signbit, t_sp, device, dtype) + + @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool])) + def test_signbit_int_and_bool(self, device, dtype): + t = torch.randint(-5, 5, (5, 5), device=device) + self.compare_with_numpy(torch.signbit, np.signbit, t) + + t_target = torch.signbit(t) + out = torch.empty_like(t, device=device, dtype=torch.bool) + torch.signbit(t, out=out) + self.assertEqual(out, t_target) + + @dtypes(torch.complex64, torch.complex128) + def test_signbit_complex(self, device, dtype): + vals = (complex(0, -1), complex(-1, 2)) + t = torch.tensor(vals, device=device, dtype=dtype) + out = torch.empty_like(t).real.bool() + + with self.assertRaisesRegex(RuntimeError, 'signbit is not implemented for complex tensors.'): + torch.signbit(t) + with self.assertRaisesRegex(RuntimeError, 'signbit is not implemented for complex tensors.'): + torch.signbit(t, out=out) + + @dtypes(torch.cfloat, torch.cdouble) + def test_sgn(self, device, dtype): + x = torch.randn(100, dtype=dtype) + angle = x.angle() + out = x.sgn() + self.assertEqual(out.angle(), angle) + self.assertEqual(out.abs(), torch.ones_like(x).real) + + x_out = torch.empty_like(x) + torch.sgn(x, out=x_out) + self.assertEqual(x_out.angle(), angle) + self.assertEqual(x_out.abs(), torch.ones_like(x).real) + + @dtypes(*(torch.testing.get_all_dtypes(include_bool=False))) + def test_signbit_non_boolean_output(self, device, dtype): + # test non-boolean tensors as the `out=` parameters + # boolean outputs are tested in the above testcases + t = torch.randn(5, 5) + out = torch.empty_like(t, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, 'does not support non-boolean outputs'): + torch.signbit(t, out=out) + + # This function tests that a nan value is returned for input values not in domain + @dtypes(torch.float32, torch.float64) + def test_acosh_domain_float(self, device, dtype): + # Domain of acosh is [1, inf), for values outside the domain - output is mapped + # to NaN, except for input value `inf` - output is mapped to `inf` + sample = torch.tensor([float('-inf'), 1.00, -1.23, -0.06, 0.98, float('inf')], + device=device, dtype=dtype) + nan_mask = torch.tensor([True, False, True, True, True, False], device=device) + inf_mask = torch.tensor([False, False, False, False, False, True], device=device) + self.assertEqual(torch.isnan(torch.acosh(sample)), nan_mask) + self.assertEqual(torch.isnan(sample.acosh()), nan_mask) + self.assertEqual(torch.isinf(torch.acosh(sample)), inf_mask) + self.assertEqual(torch.isinf(sample.acosh()), inf_mask) + + # This function tests that a nan value is returned for input values not in domain + @dtypes(torch.float32, torch.float64) + def test_atanh_domain_float(self, device, dtype): + # Domain of atanh is (-1, 1), for edge values (-1 and 1) - output is mapped + # to inf and for other values outside this range - output is mapped to NaN + sample = torch.tensor([float('-inf'), -1.00, 1.00, -1.23, 1.06, float('inf')], + device=device, dtype=dtype) + nan_mask = torch.tensor([True, False, False, True, True, True], device=device) + inf_mask = torch.tensor([False, True, True, False, False, False], device=device) + # For values not in domain (except -1.0 and 1.0), atanh should return nan + self.assertEqual(torch.isnan(torch.atanh(sample)), nan_mask) + self.assertEqual(torch.isnan(sample.atanh()), nan_mask) + # For values -1.0 and 1.0, atanh should return -inf and inf respectively + self.assertEqual(torch.isinf(torch.atanh(sample)), inf_mask) + self.assertEqual(torch.isinf(sample.atanh()), inf_mask) + + +def _generate_reference_input(dtype, device): + input = [] + input.append(list(range(-5, 5))) + input.append([0 for x in range(-5, 5)]) + input.append([x + 1e-6 for x in range(-5, 5)]) + # Some vectorized implementations don't support large values + input.append([x + 1e10 for x in range(-5, 5)]) + input.append([x - 1e10 for x in range(-5, 5)]) + input.append([*torch.randn(7).tolist(), math.inf, -math.inf, math.nan]) + input.append((torch.randn(10) * 1e6).tolist()) + input.append([math.pi * (x / 2) for x in range(-5, 5)]) + return torch.tensor(input, dtype=dtype, device=device) + +def _generate_gamma_input(dtype, device, test_poles=True): + input = [] + input.append((torch.randn(10).abs() + 1e-4).tolist()) + input.append((torch.randn(10).abs() + 1e6).tolist()) + zeros = torch.linspace(-9.5, -0.5, 10) + input.append(zeros.tolist()) + input.append((zeros - 0.49).tolist()) + input.append((zeros + 0.49).tolist()) + input.append((zeros + (torch.rand(10) * 0.99) - 0.5).tolist()) + + if test_poles: + input.append([-0.999999994, -1.999999994, -2.0000000111, + -100.99999994, -1931.99999994, 0.000000111, + -0.000000111, 0, -2, -329]) + return torch.tensor(input, dtype=dtype, device=device) + +# this class contains information needed to generate tests for torch math functions +# the generated tests compare torch implementation with the reference numpy/scipy implementation, +# and also check proper behavior for contiguous/discontiguous/inplace outputs. +class _TorchMathTestMeta(object): + def __init__(self, + opstr, + args=(), + reffn=None, + refargs=lambda x: (x.numpy(),), + input_fn=_generate_reference_input, + inputargs=(), + substr='', + make_inplace=True, + decorators=None, + ref_backend='numpy', + rtol=None, + atol=None, + dtypes=floating_types(), + replace_inf_with_nan=False): + self.opstr = opstr + self.args = args + self.reffn = reffn # reffn is either callable or ref_backend attribute, set to opstr if not specified + self.refargs = refargs + self.input_fn = input_fn + self.inputargs = inputargs + self.substr = substr + self.make_inplace = make_inplace + assert ref_backend == 'numpy' or ref_backend == 'scipy' + self.ref_backend = ref_backend + if ref_backend == 'scipy': + self.ref_decorator = [unittest.skipIf(not TEST_SCIPY, "Scipy not found")] + else: + self.ref_decorator = [] + self.decorators = decorators + self.rtol = rtol + self.atol = atol + self.dtypes = dtypes + self.replace_inf_with_nan = replace_inf_with_nan + +# TODO: replace with make_tensor +# Converts half/bfloat16 dtype to float when device is cpu +def _convert_t(dtype, device): + if device == 'cpu' and dtype in {torch.half, torch.bfloat16}: + return torch.float + return dtype + +# TODO: replace with make_tensor +# Returns a tensor of the requested shape, dtype, and device +# Requesting a half CPU tensor returns a float CPU tensor with +# values representable by a half. +# Initialization uses randint for non-float types and randn for float types. +def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor: + # Returns a tensor filled with ones + if fill_ones: + return torch.ones(*shape, dtype=_convert_t(dtype, device), device=device) + + # Returns a tensor with random integer values + if not (dtype.is_floating_point or dtype.is_complex): + t = torch.randint(0, 10, shape, device=device) + if dtype != torch.uint8: + t = t - 5 # generate negative values also + return t.to(_convert_t(dtype, device)) + + # Populates the CPU tensor with floats representable as half/bfloat16 + if dtype == torch.half and device == 'cpu': + return torch.randn(*shape, dtype=torch.float, device=device).half().float() + if dtype == torch.bfloat16 and device == 'cpu': + return torch.randn(*shape, dtype=torch.float, device=device).bfloat16().float() + + # Default: returns a tensor with random float values + return torch.randn(shape, dtype=dtype, device=device).to(dtype=dtype) + +# TODO: replace with make_tensor +def _medium_2d(dtype, device): + return _make_tensor((50, 50), dtype, device) + +# TODO: replace with opinfo +_types_no_half = [ + torch.float, torch.double, + torch.int8, torch.short, torch.int, torch.long, + torch.uint8 +] + +# TODO: all these should be replaced with OpInfos +torch_op_tests = [ + _TorchMathTestMeta('floor'), + _TorchMathTestMeta('ceil'), + _TorchMathTestMeta('rad2deg'), + _TorchMathTestMeta('deg2rad'), + _TorchMathTestMeta('frac', reffn='fmod', refargs=lambda x: (x.numpy(), 1)), + _TorchMathTestMeta('trunc'), + _TorchMathTestMeta('round'), + # FIXME lgamma produces different result compared to scipy at -inf + _TorchMathTestMeta('lgamma', reffn='gammaln', ref_backend='scipy', replace_inf_with_nan=True), + _TorchMathTestMeta('polygamma', args=[0], substr='_0', reffn='polygamma', + refargs=lambda x: (0, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False], + ref_backend='scipy'), + _TorchMathTestMeta('polygamma', args=[1], substr='_1', reffn='polygamma', + refargs=lambda x: (1, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False], + ref_backend='scipy', rtol=0.0008, atol=1e-5), + _TorchMathTestMeta('polygamma', args=[2], substr='_2', reffn='polygamma', + refargs=lambda x: (2, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False], + ref_backend='scipy', rtol=0.0008, atol=1e-5), + _TorchMathTestMeta('abs', input_fn=_medium_2d, dtypes=_types_no_half, rtol=0., atol=0.), + _TorchMathTestMeta('logit', ref_backend='scipy')] + + +def generate_torch_test_functions(cls, testmeta, inplace): + opstr = testmeta.opstr if not inplace else testmeta.opstr + "_" + + def torchfn(x): + return getattr(x, opstr)(*testmeta.args) + + def fn_check_reference(self, device, dtype): + def reffn(x): + backend = np if testmeta.ref_backend == 'numpy' else scipy.special + opstr = None + if testmeta.reffn is None: + opstr = testmeta.opstr + elif isinstance(testmeta.reffn, str): + opstr = testmeta.reffn + if callable(testmeta.reffn): + fn = testmeta.reffn + else: + assert opstr is not None, "invalid reffn" + fn = getattr(backend, opstr) + return fn(*testmeta.refargs(x)) + + inp = testmeta.input_fn(dtype, device, *testmeta.inputargs) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + expected = torch.from_numpy(reffn(inp)) + actual = torchfn(inp) + if testmeta.replace_inf_with_nan: + actual[(actual == -inf) | (actual == inf)] = nan + expected[(expected == -inf) | (expected == inf)] = nan + + torch.testing.assert_allclose(actual, expected, rtol=testmeta.rtol, atol=testmeta.atol) + + def fn_non_contig(self, device, dtype) -> None: + shapes = [(5, 7), (1024,)] + for shape in shapes: + contig = _make_tensor(shape, dtype=dtype, device=device) + non_contig = torch.empty(shape + (2,), dtype=dtype)[..., 0] + non_contig.copy_(contig) + self.assertFalse(non_contig.is_contiguous()) + self.assertEqual(torchfn(contig), torchfn(non_contig), msg='non-contiguous') + + def fn_non_contig_index(self, device, dtype): + contig = _make_tensor((2, 2, 1, 2), dtype=dtype, device=device) + non_contig = contig[:, 1, ...] + contig = non_contig.clone() + self.assertFalse(non_contig.is_contiguous()) + self.assertEqual(torchfn(contig), torchfn(non_contig), msg='non-contiguous index') + + def fn_non_contig_expand(self, device, dtype): + shapes = [(1, 3), (1, 7), (5, 7)] + for shape in shapes: + contig = _make_tensor(shape, dtype=dtype, device=device) + non_contig = contig.clone().expand(3, -1, -1) + self.assertFalse(non_contig.is_contiguous()) + contig = torchfn(contig) + non_contig = torchfn(non_contig) + for i in range(3): + self.assertEqual(contig, non_contig[i], msg='non-contiguous expand[' + str(i) + ']') + + def fn_contig_size1(self, device, dtype): + contig = _make_tensor((5, 100), dtype=dtype, device=device) + contig = contig[:1, :50] + contig2 = torch.empty(contig.size(), dtype=dtype) + contig2.copy_(contig) + self.assertTrue(contig.is_contiguous()) + self.assertTrue(contig2.is_contiguous()) + self.assertEqual(torchfn(contig), torchfn(contig2), msg='contiguous size1') + + def fn_contig_size1_large_dim(self, device, dtype): + contig = _make_tensor((5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4), dtype=dtype, device=device) + contig = contig[:1, :, :, :, :, :, :, :, :, :, :, :] + contig2 = torch.empty(contig.size(), dtype=dtype) + contig2.copy_(contig) + self.assertTrue(contig.is_contiguous()) + self.assertTrue(contig2.is_contiguous()) + self.assertEqual(torchfn(contig), torchfn(contig2), msg='contiguous size1') + + def fn_large(self, device, dtype): + input = _make_tensor((1024, 512), dtype=dtype, device=device) + # clone input to properly test inplace functions + actual = torchfn(input.clone()) + expected = torch.stack([torchfn(slice) for slice in input]) + self.assertEqual(actual, expected, msg='large') + + test_functions = {"test_reference_": fn_check_reference, + "test_non_contig_": fn_non_contig, + "test_non_contig_index_": fn_non_contig_index, + "test_non_contig_expand_": fn_non_contig_expand, + "test_contig_size1_": fn_contig_size1, + "test_check_contig_size1_large_dim_": fn_contig_size1_large_dim, + "test_large_": fn_large} + for name in test_functions: + if inplace and 'expand' in name: + continue + test_name = name + testmeta.opstr + testmeta.substr + if inplace: + test_name += "_inplace" + assert not hasattr(cls, test_name), "{0} already in TestUnaryUfuncMathOps".format(test_name) + + decorators = [] if testmeta.decorators is None else testmeta.decorators + if 'reference' in name: + decorators = decorators + testmeta.ref_decorator + decorators = decorators + [dtypes(*testmeta.dtypes)] + fn_test = test_functions[name] + for dec in decorators: + fn_test = dec(fn_test) + setattr(cls, test_name, fn_test) + +class TestUnaryUfuncMathOps(TestCase): + exact_dtype = True + +def generate_torch_op_tests(cls): + for t in torch_op_tests: + generate_torch_test_functions(cls, t, False) + if t.make_inplace: + generate_torch_test_functions(cls, t, True) + +generate_torch_op_tests(TestUnaryUfuncMathOps) instantiate_device_type_tests(TestUnaryUfuncs, globals()) +instantiate_device_type_tests(TestUnaryUfuncMathOps, globals(), only_for='cpu') if __name__ == '__main__': run_tests() diff --git a/test/test_utils.py b/test/test_utils.py index 398a10971d0d5..9733ae036d6ef 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -3,21 +3,22 @@ import re import shutil import random +import subprocess import tempfile import textwrap import unittest import torch import torch.nn as nn import torch.utils.data +from torch.utils.data import DataLoader import torch.cuda from torch.utils.checkpoint import checkpoint, checkpoint_sequential -import torch.utils._benchmark as benchmark_utils +import torch.utils.cpp_extension import torch.hub as hub from torch.autograd._functions.utils import check_onnx_broadcast from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings from torch.testing._internal.common_utils import load_tests, retry, IS_SANDCASTLE, IS_WINDOWS from urllib.error import URLError -import numpy as np # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings @@ -28,7 +29,7 @@ from torch.testing._internal.common_utils import TestCase, run_tests -class RandomDatasetMock(object): +class RandomDatasetMock(torch.utils.data.Dataset): def __getitem__(self, index): return torch.tensor([torch.rand(1).item(), random.uniform(0, 1)]) @@ -190,7 +191,7 @@ def forward(self, a, b): b = torch.randn(1, 100, requires_grad=True) with self.assertRaises(TypeError): - checkpoint_sequential(model, 1, a, b) + checkpoint_sequential(model, 1, a, b) # type: ignore[call-arg] def test_checkpoint_sequential_deprecated_no_args(self): class Noop(nn.Module): @@ -200,7 +201,7 @@ def forward(self): model = nn.Sequential(Noop()) with self.assertRaises(TypeError): - checkpoint_sequential(model, 1) + checkpoint_sequential(model, 1) # type: ignore[call-arg] def test_checkpoint_rng_cpu(self): for _ in range(5): @@ -268,6 +269,25 @@ def run_fn(tensor1, tensor2): out = checkpoint(run_fn, input_var, None) out.sum().backward() + def test_checkpoint_partial_grad(self): + def run_fn(tensor1, tensor2): + # tensor 2 is used for other application logic + return tensor1, tensor2 + input_var = torch.randn(1, 4, requires_grad=True) + input_var2 = torch.randn(1, 4, requires_grad=False) + out = checkpoint(run_fn, input_var, input_var2) + out[0].sum().backward() + + def run_fn2(tensor1, tensor2): + return tensor1 + input_var = torch.randn(1, 4, requires_grad=False) + input_var2 = torch.randn(1, 4, requires_grad=True) + with self.assertRaisesRegex( + RuntimeError, + r"none of output has requires_grad=True, this checkpoint\(\) is not necessary" + ): + out = checkpoint(run_fn2, input_var, input_var2) + out.sum().backward() class TestDataLoader(TestCase): def setUp(self): @@ -289,35 +309,38 @@ def run(): self.assertEqual(x1, x2) def test_single_keep(self): - dataloader = torch.utils.data.DataLoader(self.dataset, - batch_size=self.batch_size, - num_workers=0, - drop_last=False) + # self.dataset is a Tensor here; technically not a valid input because + # not a Dataset subclass, but needs to stay working so add ignore's + # for type checking with mypy + dataloader : DataLoader = DataLoader(self.dataset, # type: ignore[arg-type] + batch_size=self.batch_size, + num_workers=0, + drop_last=False) dataiter = iter(dataloader) self.assertEqual(len(list(dataiter)), 2) def test_single_drop(self): - dataloader = torch.utils.data.DataLoader(self.dataset, - batch_size=self.batch_size, - num_workers=0, - drop_last=True) + dataloader : DataLoader = DataLoader(self.dataset, # type: ignore[arg-type] + batch_size=self.batch_size, + num_workers=0, + drop_last=True) dataiter = iter(dataloader) self.assertEqual(len(list(dataiter)), 1) @unittest.skip("FIXME: Intermittent CUDA out-of-memory error on Windows and time-out under ASAN") def test_multi_keep(self): - dataloader = torch.utils.data.DataLoader(self.dataset, - batch_size=self.batch_size, - num_workers=2, - drop_last=False) + dataloader : DataLoader = DataLoader(self.dataset, # type: ignore[arg-type] + batch_size=self.batch_size, + num_workers=2, + drop_last=False) dataiter = iter(dataloader) self.assertEqual(len(list(dataiter)), 2) def test_multi_drop(self): - dataloader = torch.utils.data.DataLoader(self.dataset, - batch_size=self.batch_size, - num_workers=2, - drop_last=True) + dataloader : DataLoader = DataLoader(self.dataset, # type: ignore[arg-type] + batch_size=self.batch_size, + num_workers=2, + drop_last=True) dataiter = iter(dataloader) self.assertEqual(len(list(dataiter)), 1) @@ -328,7 +351,7 @@ def test_multi_drop(self): class TestFFI(TestCase): def test_deprecated(self): with self.assertRaisesRegex(ImportError, "torch.utils.ffi is deprecated. Please use cpp extensions instead."): - from torch.utils.ffi import create_extension # noqa: F401 + from torch.utils.ffi import create_extension # type: ignore # noqa: F401 @unittest.skipIf('SKIP_TEST_BOTTLENECK' in os.environ.keys(), 'SKIP_TEST_BOTTLENECK is set') @@ -345,9 +368,9 @@ def _run(self, command, timeout=30): p.kill() output, err = p.communicate() rc = p.returncode - output = output.decode("ascii") - err = err.decode("ascii") - return (rc, output, err) + output_str = output.decode("ascii") + err_str = err.decode("ascii") + return (rc, output_str, err_str) def _run_bottleneck(self, test_file, scriptargs=''): curdir = os.path.dirname(os.path.abspath(__file__)) @@ -618,181 +641,85 @@ def test_import_hipify(self): from torch.utils.hipify import hipify_python # noqa -class TestBenchmarkUtils(TestCase): - def test_timer(self): - timer = benchmark_utils.Timer( - stmt="torch.ones(())", - ) - median = timer.blocked_autorange(min_run_time=0.01).median - self.assertIsInstance(median, float) - - # We set a very high threshold to avoid flakiness in CI. - # The internal algorithm is tested in `test_adaptive_timer` - median = timer.adaptive_autorange(threshold=0.5).median - - class _MockTimer: - _seed = 0 - - _timer_noise_level = 0.05 - _timer_cost = 100e-9 # 100 ns - - _function_noise_level = 0.05 - _function_costs = ( - ("pass", 8e-9), - ("cheap_fn()", 4e-6), - ("expensive_fn()", 20e-6), - ) - - def __init__(self, stmt, setup, timer, globals): - self._random_state = np.random.RandomState(seed=self._seed) - self._mean_cost = {k: v for k, v in self._function_costs}[stmt] - - def sample(self, mean, noise_level): - return max(self._random_state.normal(mean, mean * noise_level), 5e-9) - - def timeit(self, number): - return sum([ - # First timer invocation - self.sample(self._timer_cost, self._timer_noise_level), - - # Stmt body - self.sample(self._mean_cost * number, self._function_noise_level), - - # Second timer invocation - self.sample(self._timer_cost, self._timer_noise_level), - ]) - - def test_adaptive_timer(self): - class MockTimer(benchmark_utils.Timer): - _timer_cls = self._MockTimer - - def assert_reprs_match(measurement, expected): - measurement_repr = re.sub( - "object at 0x[0-9a-fA-F]+>", - "object at 0xXXXXXXXXXXXX>", - repr(measurement) +class TestAssert(TestCase): + def test_assert_true(self): + # verify assertions work as expected + # bool argument + torch._assert(True, "foo") + with self.assertRaisesRegex(AssertionError, "bar"): + torch._assert(False, "bar") + # tensor argument + torch._assert(torch.tensor([True], dtype=torch.bool), "foo") + with self.assertRaisesRegex(AssertionError, "bar"): + torch._assert(torch.tensor([False], dtype=torch.bool), "bar") + + def test_assert_scriptable(self): + class M(torch.nn.Module): + def forward(self, x): + torch._assert(x.sum() > 0, "foo") + return x + + m = M() + # scriptable + ms = torch.jit.script(m) + # data can be passed without errors + x = torch.randn(4, 4).fill_(1.0) + ms(x) + with self.assertRaisesRegex(torch.jit.Error, "foo"): # type: ignore[type-var] + ms(torch.tensor([False], dtype=torch.bool)) + + +@unittest.skipIf(IS_SANDCASTLE, "cpp_extension is OSS only.") +class TestStandaloneCPPJIT(TestCase): + def test_load_standalone(self): + build_dir = tempfile.mkdtemp() + try: + src_path = os.path.join(build_dir, "main.cpp") + src = textwrap.dedent("""\ + #include + #include + int main() { + auto x = torch::eye(3); + std::cout << x << std::endl; + } + """) + with open(src_path, "wt") as f: + f.write(src) + + exec_path = torch.utils.cpp_extension.load( + "standalone_load_test", + src_path, + build_directory=build_dir, + is_python_module=False, + is_standalone=True, ) - self.assertEqual(measurement_repr, textwrap.dedent(expected).strip()) - - assert_reprs_match( - MockTimer("pass").blocked_autorange(min_run_time=10), - """ - - pass - Median: 7.98 ns - IQR: 0.52 ns (7.74 to 8.26) - 125 measurements, 10000000 runs per measurement, 1 thread""" - ) - - assert_reprs_match( - MockTimer("pass").adaptive_autorange(), - """ - - pass - Median: 7.86 ns - IQR: 0.71 ns (7.63 to 8.34) - 6 measurements, 1000000 runs per measurement, 1 thread""" - ) - assert_reprs_match( - MockTimer("cheap_fn()").blocked_autorange(min_run_time=10), - """ - - cheap_fn() - Median: 3.98 us - IQR: 0.27 us (3.85 to 4.12) - 252 measurements, 10000 runs per measurement, 1 thread""" - ) - - assert_reprs_match( - MockTimer("cheap_fn()").adaptive_autorange(), - """ - - cheap_fn() - Median: 4.16 us - IQR: 0.22 us (4.04 to 4.26) - 4 measurements, 1000 runs per measurement, 1 thread""" - ) - - assert_reprs_match( - MockTimer("expensive_fn()").blocked_autorange(min_run_time=10), - """ - - expensive_fn() - Median: 19.97 us - IQR: 1.35 us (19.31 to 20.65) - 501 measurements, 1000 runs per measurement, 1 thread""" - ) - - assert_reprs_match( - MockTimer("expensive_fn()").adaptive_autorange(), - """ - - expensive_fn() - Median: 20.79 us - IQR: 1.09 us (20.20 to 21.29) - 4 measurements, 1000 runs per measurement, 1 thread""" - ) - - class _MockCudaTimer(self._MockTimer): - # torch.cuda.synchronize is much more expensive than - # just timeit.default_timer - _timer_cost = 10e-6 - - _function_costs = ( - self._MockTimer._function_costs[0], - self._MockTimer._function_costs[1], - - # GPU should be faster once there is enough work. - ("expensive_fn()", 5e-6), + ext = ".exe" if IS_WINDOWS else "" + self.assertEqual( + exec_path, + os.path.join(build_dir, f"standalone_load_test{ext}") ) - class MockCudaTimer(benchmark_utils.Timer): - _timer_cls = _MockCudaTimer - - configurations = ( - (7.9903966e-09, 376, 1000000, MockTimer("pass")), - (7.8554826e-09, 4, 100000000, MockCudaTimer("pass")), - (3.9930536e-06, 752, 1000, MockTimer("cheap_fn()")), - (3.9441239e-06, 8, 100000, MockCudaTimer("cheap_fn()")), - (1.9994249e-05, 150, 1000, MockTimer("expensive_fn()")), - (4.9301076e-06, 6, 100000, MockCudaTimer("expensive_fn()")), - ) - - for median, repeats, number_per_run, timer_instance in configurations: - measurement = timer_instance.blocked_autorange(min_run_time=3) - self.assertEqual(measurement.median, median) - self.assertEqual(len(measurement.times), repeats) - self.assertEqual(measurement.number_per_run, number_per_run) - - def test_compare(self): - compare = benchmark_utils.Compare([ - benchmark_utils.Timer( - "torch.ones((n,))", globals={"n": n}, - description="ones", label=str(n)).timeit(3) - for n in range(3) - ]) - compare.print() - - @unittest.skipIf(IS_WINDOWS and os.getenv("VC_YEAR") == "2019", "Random seed only accepts int32") - def test_fuzzer(self): - fuzzer = benchmark_utils.Fuzzer( - parameters=[ - benchmark_utils.FuzzedParameter( - "n", minval=1, maxval=16, distribution="loguniform")], - tensors=[benchmark_utils.FuzzedTensor("x", size=("n",))], - seed=0, - ) - - expected_results = [ - (0.7821, 0.0536, 0.9888, 0.1949, 0.5242, 0.1987, 0.5094), - (0.7166, 0.5961, 0.8303, 0.005), - ] - - for i, (tensors, _, _) in enumerate(fuzzer.take(2)): - x = tensors["x"] - self.assertEqual( - x, torch.Tensor(expected_results[i]), rtol=1e-3, atol=1e-3) + for shell in [True, False]: + r = subprocess.run( + [exec_path], + shell=shell, + stdout=subprocess.PIPE, + ) + self.assertEqual(r.returncode, 0) + self.assertEqual( + # Windows prints "\r\n" for newlines. + textwrap.dedent(r.stdout.decode("utf-8")).replace("\r\n", "\n"), + textwrap.dedent("""\ + 1 0 0 + 0 1 0 + 0 0 1 + [ CPUFloatType{3,3} ] + """) + ) + + finally: + shutil.rmtree(build_dir) if __name__ == '__main__': diff --git a/test/test_view_ops.py b/test/test_view_ops.py new file mode 100644 index 0000000000000..17d04c35d8cae --- /dev/null +++ b/test/test_view_ops.py @@ -0,0 +1,1402 @@ +import torch +import numpy as np + +import unittest +from itertools import product, permutations, combinations +from functools import partial +import random + +from torch.testing._internal.common_utils import \ + (TestCase, run_tests, suppress_warnings, make_tensor) +from torch.testing._internal.common_device_type import \ + (instantiate_device_type_tests, onlyCPU, dtypes, onlyOnCPUAndCUDA) + +# TODO: replace this with make_tensor() in common_utils.py +def _generate_input(shape, dtype, device, with_extremal): + if shape == (): + x = torch.tensor((), dtype=dtype, device=device) + else: + if dtype.is_floating_point or dtype.is_complex: + # work around torch.randn not being implemented for bfloat16 + if dtype == torch.bfloat16: + x = torch.randn(*shape, device=device) * random.randint(30, 100) + x = x.to(torch.bfloat16) + else: + x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100) + x[torch.randn(*shape) > 0.5] = 0 + if with_extremal and dtype.is_floating_point: + # Use extremal values + x[torch.randn(*shape) > 0.5] = float('nan') + x[torch.randn(*shape) > 0.5] = float('inf') + x[torch.randn(*shape) > 0.5] = float('-inf') + elif with_extremal and dtype.is_complex: + x[torch.randn(*shape) > 0.5] = complex('nan') + x[torch.randn(*shape) > 0.5] = complex('inf') + x[torch.randn(*shape) > 0.5] = complex('-inf') + elif dtype == torch.bool: + x = torch.zeros(shape, dtype=dtype, device=device) + x[torch.randn(*shape) > 0.5] = True + else: + x = torch.randint(15, 100, shape, dtype=dtype, device=device) + + return x + +# TODO: replace this with make_tensor() in common_utils.py +def _rand_shape(dim, min_size, max_size): + shape = [] + for i in range(dim): + shape.append(random.randint(min_size, max_size)) + return tuple(shape) + +# TODO: refactor tests to avoid this function +# Converts half/bfloat16 dtype to float when device is cpu +def _convert_t(dtype, device): + if device == 'cpu' and dtype in {torch.half, torch.bfloat16}: + return torch.float + return dtype + +# TODO: replace this with make_tensor() in common_utils.py +# Returns a tensor of the requested shape, dtype, and device +# Requesting a half CPU tensor returns a float CPU tensor with +# values representable by a half. +# Initialization uses randint for non-float types and randn for float types. +def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor: + # Returns a tensor filled with ones + if fill_ones: + return torch.ones(*shape, dtype=_convert_t(dtype, device), device=device) + + # Returns a tensor with random integer values + if not (dtype.is_floating_point or dtype.is_complex): + t = torch.randint(0, 10, shape, device=device) + if dtype != torch.uint8: + t = t - 5 # generate negative values also + return t.to(_convert_t(dtype, device)) + + # Populates the CPU tensor with floats representable as half/bfloat16 + if dtype == torch.half and device == 'cpu': + return torch.randn(*shape, dtype=torch.float, device=device).half().float() + if dtype == torch.bfloat16 and device == 'cpu': + return torch.randn(*shape, dtype=torch.float, device=device).bfloat16().float() + + # Default: returns a tensor with random float values + return torch.randn(shape, dtype=dtype, device=device).to(dtype=dtype) + +# Tests ops and indexing to ensure they return views (and new tensors) as +# appropriate. +class TestViewOps(TestCase): + exact_dtype = True + + def is_view_of(self, base, other): + if (not other._is_view() or + other is base or + other._base is not base or + base.device != other.device): + return False + # Note: only validates storage on native device types + # because some accelerators, like XLA, do not expose storage + if base.device.type == 'cpu' or base.device.type == 'cuda': + if base.storage().data_ptr() != other.storage().data_ptr(): + return False + + return True + + # Returns true if v1 and v2 are views of the same base + def is_view_of_same_base(self, v1, v2): + if (not v1._is_view() or v1 is v2): + return False + return self.is_view_of(v1._base, v2) + + # Performs transpose if contiguous=True, else returns the input tensor as is + def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1): + if contiguous: + return x + else: + return x.transpose(dim0, dim1) + + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) + def test_conj_self(self, device, dtype): + t = torch.ones(5, 5, device=device) + s = t.conj() + self.assertTrue(s is t) + + @onlyOnCPUAndCUDA + @dtypes(*torch.testing.get_all_fp_dtypes(include_bfloat16=False), torch.complex64) + def test_view_dtype(self, device, dtype): + int_dtype = { + torch.half: torch.int16, + torch.bfloat16: torch.int16, + torch.float: torch.int, + torch.double: torch.long, + torch.complex64: torch.long, + }[dtype] + numpy_dtype = { + torch.half: np.int16, + torch.bfloat16: np.int16, + torch.float: np.int32, + torch.double: np.int64, + torch.complex64: np.int64, + }[dtype] + + def generate_inputs(): + yield make_tensor((5, 5, 5), device, dtype, low=-5, high=5) + yield make_tensor((5, 5, 5), device, dtype, low=-5, high=5).permute(2, 0, 1) + yield make_tensor((1, 5, 1), device, dtype, low=-5, high=5).expand(5, 5, 5) + yield make_tensor((10, 5, 10), device, dtype, low=-5, high=5)[::2, :, ::2] + yield make_tensor((0, 5, 10), device, dtype, low=-5, high=5) + yield make_tensor((), device, dtype, low=-5, high=5) + + def run_test(fp_tensor): + self.assertRaises(RuntimeError, lambda: fp_tensor.view(torch.complex128)) + self.assertRaises(RuntimeError, lambda: fp_tensor.view(torch.int8)) + + int_tensor = fp_tensor.view(int_dtype) + self.assertEqual(int_tensor.dtype, int_dtype) + self.assertEqual(int_tensor.shape, fp_tensor.shape) + self.assertEqual(int_tensor.stride(), fp_tensor.stride()) + + self.assertEqual(fp_tensor, int_tensor.view(dtype), rtol=0, atol=0) + self.assertEqual(fp_tensor.cpu().numpy().view(numpy_dtype), int_tensor, rtol=0, atol=0) + + fp_tensor.zero_() + self.assertEqual(fp_tensor, torch.zeros_like(fp_tensor), rtol=0, atol=0) + + for fp_tensor in generate_inputs(): + run_test(fp_tensor) + + # Test that requires_grad is dropped, because view(dtype) does not support backward + if dtype is torch.double: + t = make_tensor((5, 5, 5), device, torch.double, low=-5, high=5, requires_grad=True) + self.assertFalse(t.view(torch.complex64).requires_grad) + + + @onlyOnCPUAndCUDA + def test_view_as_complex(self, device): + def fn(contiguous_input=True, dim0=0, dim1=1): + t = torch.randn(3, 2, 2, device=device) + c_t = t[:, :, 0] + 1j * t[:, :, 1] + + input = self._do_transpose(t, contiguous_input, dim0, dim1) + + if input.size()[-1] != 2: + self.assertRaisesRegex( + RuntimeError, "Tensor must have a last dimension of size 2", + lambda: torch.view_as_complex(input)) + return + + if input.stride()[-1] != 1: + self.assertRaisesRegex( + RuntimeError, "Tensor must have a last dimension with stride 1", + lambda: torch.view_as_complex(input)) + return + + res = torch.view_as_complex(input) + self.assertEqual(res, self._do_transpose(c_t, contiguous_input, dim0, dim1)) + self.assertTrue(self.is_view_of(t, res)) + + fn() + fn(contiguous_input=False) + # RuntimeError since in this case the last dim of input would not be of size 2 + fn(contiguous_input=False, dim0=0, dim1=2) + # RuntimeError since in this case the last dim of input would not have stride 1 + fn(contiguous_input=False, dim0=1, dim1=2) + + + # RuntimeError since in this case the stride of non-last dim of input would not be of size 2 + x = torch.randn(3, 3, device=device) + t = torch.as_strided(x, (2, 2), (1, 1)) + self.assertRaisesRegex( + RuntimeError, "Tensor must have a stride divisible by 2 for all but last dimension", + lambda: torch.view_as_complex(t)) + + # tensor with zero elements + x = torch.tensor([], device=device) # torch.Size([0]) + self.assertRaisesRegex( + RuntimeError, "Tensor must have a last dimension of size 2", + lambda: torch.view_as_complex(x)) + + # zero dimension tensor + z = torch.tensor(2.0) + self.assertRaisesRegex( + RuntimeError, "Input tensor must have one or more dimensions", + lambda: torch.view_as_complex(z)) + + y = x.reshape(0, 2) # torch.Size([0, 2]) + res = torch.view_as_complex(y) + self.assertTrue(self.is_view_of(x, res)) + self.assertEqual(res.shape, torch.Size([0])) + + @onlyOnCPUAndCUDA + @dtypes(*torch.testing.get_all_complex_dtypes(include_complex32=True)) + def test_view_as_real(self, device, dtype): + def fn(contiguous_input=True): + t = torch.randn(3, 4, dtype=dtype, device=device) + input = self._do_transpose(t, contiguous_input) + res = torch.view_as_real(input) + self.assertEqual(res[:, :, 0], input.real) + self.assertEqual(res[:, :, 1], input.imag) + # TODO: Add torch.ComplexHalfStorage + if dtype != torch.complex32: + self.assertTrue(self.is_view_of(t, res)) + else: + self.assertRaises(RuntimeError, lambda: self.is_view_of(t, res)) + + fn() + fn(contiguous_input=False) + + # tensor with zero elements + x = torch.tensor([], dtype=dtype, device=device) + res = torch.view_as_real(x) + # TODO: Add torch.ComplexHalfStorage + if dtype != torch.complex32: + self.assertTrue(self.is_view_of(x, res)) + else: + self.assertRaises(RuntimeError, lambda: self.is_view_of(x, res)) + self.assertEqual(res.shape, torch.Size([0, 2])) + + # tensor with zero dim + x = torch.tensor(2 + 3j, dtype=dtype, device=device) + res = torch.view_as_real(x) + # TODO: Add torch.ComplexHalfStorage + if dtype != torch.complex32: + self.assertTrue(self.is_view_of(x, res)) + else: + self.assertRaises(RuntimeError, lambda: self.is_view_of(x, res)) + self.assertEqual(res.shape, torch.Size([2])) + + @onlyOnCPUAndCUDA + @dtypes(*torch.testing.get_all_dtypes()) + def test_view_tensor_split(self, device, dtype): + a = make_tensor((40, 30), device, dtype, low=-9, high=9) + a_split_dim0 = a.tensor_split(7, 0) + for a_split_dim0_tensor in a_split_dim0: + self.assertTrue(self.is_view_of(a, a_split_dim0_tensor)) + a_split_dim1 = a.tensor_split(7, 1) + for a_split_dim1_tensor in a_split_dim1: + self.assertTrue(self.is_view_of(a, a_split_dim1_tensor)) + + @onlyOnCPUAndCUDA + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) + def test_real_imag_noncomplex(self, device, dtype): + t = torch.ones((5, 5), dtype=dtype, device=device) + + with self.assertRaises(RuntimeError): + torch.real(t) + + with self.assertRaises(RuntimeError): + torch.imag(t) + + @onlyOnCPUAndCUDA + @dtypes(*torch.testing.get_all_complex_dtypes()) + def test_real_imag_view(self, device, dtype): + def compare_with_numpy(contiguous_input=True): + t = torch.randn(3, 3, dtype=dtype, device=device) + if not contiguous_input: + u = t.T + else: + u = t + + re = u.real + exp = torch.from_numpy(u.cpu().numpy().real).to(device=device) + self.assertEqual(re, exp) + # for the case of contiguous_input, t=u + # for the case of non contiguous_input, the base still remains + # t since we are performing a view operation to make the input non-contiguous + self.assertTrue(self.is_view_of(t, re)) + + im = u.imag + exp = torch.from_numpy(u.cpu().numpy().imag).to(device=device) + self.assertEqual(im, exp) + self.assertTrue(self.is_view_of(t, im)) + + compare_with_numpy() + compare_with_numpy(contiguous_input=False) + + # ensure storage offset is being correctly set + a = torch.randn(10, dtype=dtype) + self.assertEqual(a[5:].real, a.real[5:]) + self.assertEqual(a[5:].imag, a.imag[5:]) + + @onlyOnCPUAndCUDA + @dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes())) + @suppress_warnings + def test_set_real_imag(self, device, dtypes): + x = torch.randn(10, dtype=dtypes[0], device=device) + + new_real = _make_tensor((10,), dtypes[1], device) + new_imag = _make_tensor((10,), dtypes[1], device) + + x.real = new_real + x.imag = new_imag + + if dtypes[1].is_complex: + self.assertEqual(x.real, new_real.real, exact_dtype=False) + self.assertEqual(x.imag, new_imag.real, exact_dtype=False) + + else: + self.assertEqual(x.real, new_real, exact_dtype=False) + self.assertEqual(x.imag, new_imag, exact_dtype=False) + + def test_diagonal_view(self, device) -> None: + t = torch.ones((5, 5), device=device) + v = torch.diagonal(t) + self.assertTrue(self.is_view_of(t, v)) + + v[0] = 0 + self.assertEqual(t[0, 0], v[0]) + + t = torch.ones((3, 3, 3), device=device) + v = torch.diagonal(t, offset=1, dim1=1, dim2=2) + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = 0 + self.assertEqual(t[0, 0, 1], v[0, 0]) + + def test_select_view(self, device) -> None: + t = torch.ones((5, 5), device=device) + v = t.select(0, 2) + self.assertTrue(self.is_view_of(t, v)) + + v[0] = 0 + self.assertEqual(t[2, 0], v[0]) + + def test_unbind_view(self, device) -> None: + t = torch.zeros((5, 5), device=device) + tup = torch.unbind(t) + + for idx, v in enumerate(tup): + self.assertTrue(self.is_view_of(t, v)) + + v[0] = idx + 1 + self.assertEqual(t[idx, 0], v[0]) + + def test_expand_view(self, device) -> None: + t = torch.ones((5, 1), device=device) + v = t.expand(5, 5) + self.assertTrue(self.is_view_of(t, v)) + + v[2, 2] = 0 + self.assertEqual(t[2, 0], v[2, 2]) + + def test_expand_as_view(self, device): + t = torch.ones((5, 1), device=device) + e = torch.empty((5, 5), device=device) + v = t.expand_as(e) + self.assertTrue(self.is_view_of(t, v)) + + v[2, 2] = 0 + self.assertEqual(t[2, 0], v[2, 2]) + + def test_narrow_view(self, device): + t = torch.ones((5, 5), device=device) + v = torch.narrow(t, 1, 2, 2) + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = 0 + self.assertEqual(t[0, 2], v[0, 0]) + + def test_permute_view(self, device) -> None: + t = torch.ones((5, 5), device=device) + v = t.permute(1, 0) + self.assertTrue(self.is_view_of(t, v)) + + v[0, 1] = 0 + self.assertEqual(t[1, 0], v[0, 1]) + + def test_transpose_view(self, device): + for fn in (torch.swapdims, torch.swapaxes, torch.transpose): + t = torch.ones((5, 5), device=device) + v = fn(t, 0, 1) + self.assertTrue(self.is_view_of(t, v)) + + v[0, 1] = 0 + self.assertEqual(t[1, 0], v[0, 1]) + + def test_t_view(self, device): + t = torch.ones((5, 5), device=device) + v = t.t() + self.assertTrue(self.is_view_of(t, v)) + + v[0, 1] = 0 + self.assertEqual(t[1, 0], v[0, 1]) + + def test_T_view(self, device): + t = torch.ones((5, 5), device=device) + v = t.T + self.assertTrue(self.is_view_of(t, v)) + + v[0, 1] = 0 + self.assertEqual(t[1, 0], v[0, 1]) + + def test_unfold_view(self, device): + t = torch.ones(10, device=device) + v = t.unfold(0, 3, 2) + self.assertTrue(self.is_view_of(t, v)) + + v[1, 0] = 0 + self.assertEqual(t[2], v[1, 0]) + + def test_squeeze_view(self, device): + t = torch.ones(5, 1, 5, device=device) + v = torch.squeeze(t) + self.assertTrue(self.is_view_of(t, v)) + v[0, 1] = 0 + self.assertEqual(t, v._base) + + def test_unsqueeze_view(self, device): + t = torch.ones(5, 5, device=device) + v = torch.unsqueeze(t, 1) + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0, 1] = 0 + self.assertEqual(t[0, 1], v[0, 0, 1]) + + def test_as_strided_view(self, device): + t = torch.ones(5, 5, device=device) + v = torch.as_strided(t, (25,), (1,)) + self.assertTrue(self.is_view_of(t, v)) + + v[6] = 0 + self.assertEqual(t[1, 1], v[6]) + + def test_view_view(self, device): + t = torch.ones(5, 5, device=device) + v = t.view(25) + self.assertTrue(self.is_view_of(t, v)) + + v[6] = 0 + self.assertEqual(t[1, 1], v[6]) + + def test_view_as_view(self, device): + t = torch.ones(5, 5, device=device) + e = torch.empty((25,)) + v = t.view_as(e) + self.assertTrue(self.is_view_of(t, v)) + + v[6] = 0 + self.assertEqual(t[1, 1], v[6]) + + def test_contiguous_self(self, device): + t = torch.ones(5, 5, device=device) + s = t.contiguous() + self.assertTrue(s is t) + + def test_contiguous_nonview(self, device): + t = torch.ones(5, 5, device=device) + nv = t.t().contiguous() + self.assertTrue(not self.is_view_of(t, nv)) + + nv[0, 0] = 0 + self.assertNotEqual(t[0, 0], nv[0, 0]) + + def test_reshape_view(self, device): + t = torch.ones(5, 5, device=device) + v = torch.reshape(t, (25,)) + self.assertTrue(self.is_view_of(t, v)) + + v[6] = 0 + self.assertEqual(t[1, 1], v[6]) + + def test_reshape_as_view(self, device): + t = torch.ones(5, 5, device=device) + e = torch.empty((25,), device=device) + v = t.reshape_as(e) + self.assertTrue(self.is_view_of(t, v)) + + v[6] = 0 + self.assertEqual(t[1, 1], v[6]) + + def test_reshape_nonview(self, device): + t = torch.ones(5, 5, device=device) + nv = torch.reshape(t.t(), (25,)) + self.assertTrue(not self.is_view_of(t, nv)) + + nv[6] = 0 + self.assertNotEqual(t[1, 1], nv[6]) + + def test_flatten_view(self, device): + def test_writes_propagate(t, v): + idx_t = (0,) * t.ndim + idx_v = (0,) * v.ndim + v[idx_v] = 0 + self.assertEqual(t[idx_t], v[idx_v]) + + t = torch.ones(1, 2, 3, 4, device=device) + v = t.flatten() + self.assertTrue(self.is_view_of(t, v)) + test_writes_propagate(t, v) + + # zero-dimensional tensor + t = torch.tensor(1, device=device) + v = t.flatten() + test_writes_propagate(t, v) + self.assertTrue(self.is_view_of(t, v)) + + t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3) + v = t.flatten(0, 1) + test_writes_propagate(t, v) + self.assertTrue(self.is_view_of_same_base(t, v)) + + # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups: + t = torch.ones(720, device=device) \ + .as_strided((2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0)) + # [--1--|---2---|-3-] [--1--|----2---|-3-] + v1 = t.flatten(0, 1) + v2 = v1.flatten(1, 3) + v3 = v2.flatten(2, 2) + test_writes_propagate(t, v1) + self.assertTrue(self.is_view_of_same_base(t, v1)) + test_writes_propagate(t, v2) + self.assertTrue(self.is_view_of_same_base(t, v2)) + test_writes_propagate(t, v3) + self.assertTrue(self.is_view_of_same_base(t, v3)) + + @onlyOnCPUAndCUDA + def test_flatten_nonview(self, device): + def assert_is_nonview(t, nv): + idx_t = (0,) * t.ndim + idx_nv = (0,) * nv.ndim + self.assertTrue(not nv._is_view()) + nv[idx_nv] = 0 + self.assertNotEqual(t[idx_t], nv[idx_nv]) + t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3) + nv = t.flatten(1, 3) + assert_is_nonview(t, nv) + + t = torch.ones(2, 2, device=device).T + nv = t.flatten() + assert_is_nonview(t, nv) + + # flatten returns the original object if start_dim=end_dim + t = t = torch.ones(2, 2, device=device) + nv = t.flatten(1, 1) + self.assertTrue(t is nv) + + def test_basic_indexing_slice_view(self, device): + t = torch.ones(5, 5, device=device) + v = t[:2, :3] + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = 0 + self.assertEqual(t[0, 0], v[0, 0]) + + def test_basic_indexing_ellipses_view(self, device): + t = torch.ones(5, 5, device=device) + v = t[..., :2] + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = 0 + self.assertEqual(t[0, 0], v[0, 0]) + + def test_basic_indexing_newaxis_view(self, device): + t = torch.ones(5, 5, device=device) + v = t[None, :2, 3] + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = 0 + self.assertEqual(t[0, 3], v[0, 0]) + + def test_advanced_indexing_nonview(self, device): + t = torch.ones(3, 3, device=device) + rows = torch.tensor([[0, 0], [2, 2]], device=device) + cols = torch.tensor([[0, 1], [2, 2]], device=device) + nv = t[rows, cols] + self.assertTrue(not self.is_view_of(t, nv)) + + nv[1, 1] = 0 + self.assertNotEqual(t[2, 2], nv[1, 1]) + + def test_advanced_indexing_assignment(self, device): + t = torch.ones(3, 3, device=device) + rows = torch.tensor([[0, 0], [2, 2]], device=device) + cols = torch.tensor([[0, 1], [2, 2]], device=device) + t[rows, cols] = 0 + self.assertEqual(t[2, 2], 0) + + @unittest.skip("See https://github.com/pytorch/pytorch/pull/32720") + def test_chunk_view(self, device): + t = torch.zeros(3, 3, device=device) + l = torch.chunk(t, 3) + + for idx, v in enumerate(l): + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = idx + 1 + self.assertEqual(t[idx, 0], v[0, 0]) + + @unittest.skip("See https://github.com/pytorch/pytorch/pull/32720") + def test_split_view(self, device): + t = torch.zeros(3, 3, device=device) + l = torch.split(t, [1, 1, 1]) + + for idx, v in enumerate(l): + self.assertTrue(self.is_view_of(t, v)) + + v[0, 0] = idx + 1 + self.assertEqual(t[idx, 0], v[0, 0]) + + def test_movedim_view(self, device): + def run_test(device, op): + t = torch.zeros(3, 3, device=device) + out = op(t) + + self.assertTrue(self.is_view_of(t, out)) + + # Randomly change values in output + # and verify that original is changed + # as well. + for _ in range(3): + idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2) + out[idx_1, idx_2] = random.random() + self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2]) + + for fn in [torch.movedim, torch.moveaxis]: + op = partial(fn, source=(0, 1), destination=(1, 0)) + run_test(device, op) + + op = partial(fn, source=0, destination=1) + run_test(device, op) + +class TestOldViewOps(TestCase): + def test_ravel(self, device): + + def _test_ravel(tensors, size, nc=False): + for src in tensors: + # Continuous Tensor -> View + flat = src.ravel() + self.assertEqual(flat.shape, torch.Size([size])) + self.assertEqual(src.view(-1), flat) + self.assertEqual(flat._base, src) + + # Non-continuous Tensor -> Copy + if nc: + nc_src = src.t() + nc_flat = nc_src.ravel() + self.assertEqual(nc_flat.shape, torch.Size([size])) + self.assertEqual(nc_src.reshape(-1), nc_flat) + self.assertTrue(nc_flat._base != nc_src) + + # Test that flatten returns 1-dim tensor when given a 0-dim tensor + zero_dim_tensor = torch.tensor(123, device=device) + flat0 = zero_dim_tensor.ravel() + one_dim_tensor = torch.tensor([123], device=device) + flat1 = zero_dim_tensor.ravel() + + self.assertEqual(zero_dim_tensor.shape, torch.Size([])) + self.assertEqual(flat0.shape, torch.Size([1])) + self.assertEqual(one_dim_tensor.shape, torch.Size([1])) + self.assertEqual(flat1.shape, torch.Size([1])) + self.assertEqual(flat0, one_dim_tensor) + self.assertEqual(flat0, flat1) + self.assertEqual(flat0.shape, flat1.shape) + + # Test both float tensor and quantized tensor + tensors = [torch.randn(5, 5, 5, 5, device=device), + torch._empty_affine_quantized([5, 5, 5, 5], + scale=2, + zero_point=3, + dtype=torch.quint8, + device=device)] + _test_ravel(tensors, 625) + + tensors = [torch.randn(0, 2, 3, device=device), + torch.randn(3, 0, 2, device=device), + torch._empty_affine_quantized([0, 2, 3], + scale=2, + zero_point=3, + dtype=torch.quint8, + device=device), + torch._empty_affine_quantized([3, 0, 2], + scale=2, + zero_point=3, + dtype=torch.quint8, + device=device)] + _test_ravel(tensors, 0) + + tensors = [torch.randn(5, 5, device=device), + torch._empty_affine_quantized([5, 5], + scale=2, + zero_point=3, + dtype=torch.quint8, + device=device)] + _test_ravel(tensors, 25, True) + + # TODO: this should be refactored into the view ops test suite + def test_empty_reshape(self, device): + x = torch.randn(0, 6, device=device) + self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape) + # should be viewable -- i.e. data_ptr is the same. + self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr()) + + # match NumPy semantics -- don't infer the size of dimension with a degree of freedom + self.assertRaises(RuntimeError, lambda: x.reshape(0, -1)) + + def test_expand(self, device): + tensor = torch.rand(1, 8, 1, device=device) + tensor2 = torch.rand(5, device=device) + template = torch.rand(4, 8, 5, device=device) + target = template.size() + self.assertEqual(tensor.expand_as(template).size(), target) + self.assertEqual(tensor.expand(4, 8, 5).size(), target) + self.assertEqual(tensor.expand(target).size(), target) + self.assertEqual(tensor2.expand_as(template).size(), target) + self.assertEqual(tensor2.expand(4, 8, 5).size(), target) + self.assertEqual(tensor2.expand(target).size(), target) + + # test double expand + self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1)) + + # test non-contiguous + noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0] + self.assertFalse(noncontig.is_contiguous()) + self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1)) + + # make sure it's compatible with unsqueeze + expanded = tensor2.expand(1, 1, 5) + unsqueezed = tensor2.unsqueeze(0).unsqueeze(1) + self.assertEqual(expanded, unsqueezed) + self.assertEqual(expanded.stride(), unsqueezed.stride()) + + # test -1 as target size + self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5)) + self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1)) + + # test expanding empty to empty + self.assertEqual(torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device)) + + # TODO: this should be refactored into the view ops test suite + def test_view_empty(self, device): + x = torch.randn(0, 6, device=device) + self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape) + + # TODO: this should be refactored into the view ops test suite + @onlyOnCPUAndCUDA + def test_reshape(self, device): + x = torch.randn(3, 3, device=device) + self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr()) + self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr()) + self.assertEqual(torch.reshape(x, (9,)), x.reshape(9)) + self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1)) + + y = torch.randn(4, 4, 4, device=device)[:, 0, :] + self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr()) + self.assertEqual(y.contiguous().view(-1), y.reshape(-1)) + self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr()) + + s = torch.randn((), device=device) + self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr()) + self.assertEqual(s.reshape(-1).shape, (1,)) + self.assertRaises(RuntimeError, lambda: s.reshape(2)) + + empty = torch.tensor([], device=device) + self.assertEqual(empty, empty.reshape(-1)) + self.assertEqual(empty, empty.reshape([0])) + # TODO: fix these once we have multi-dimensional empty tensors + self.assertEqual(empty.reshape([0, 1]).shape, (0, 1)) + self.assertEqual(empty.reshape([1, -1]).shape, (1, 0)) + self.assertRaises(RuntimeError, lambda: empty.reshape(1)) + + x = torch.randn(3, 3, device=device) + self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr()) + self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr()) + self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10, device=device))) + + def test_flatten(self, device): + # Test that flatten returns 1-dim tensor when given a 0-dim tensor + zero_dim_tensor = torch.tensor(123, device=device) + flat0 = zero_dim_tensor.flatten() + one_dim_tensor = torch.tensor([123], device=device) + flat1 = zero_dim_tensor.flatten() + + self.assertEqual(zero_dim_tensor.shape, torch.Size([])) + self.assertEqual(flat0.shape, torch.Size([1])) + self.assertEqual(one_dim_tensor.shape, torch.Size([1])) + self.assertEqual(flat1.shape, torch.Size([1])) + self.assertEqual(flat0, one_dim_tensor) + self.assertEqual(flat0, flat1) + self.assertEqual(flat0.shape, flat1.shape) + + # Test both float tensor and quantized tensor + tensors = [torch.randn(5, 5, 5, 5, device=device), + torch._empty_affine_quantized([5, 5, 5, 5], + scale=2, + zero_point=3, + dtype=torch.quint8, + device=device)] + for src in tensors: + flat = src.flatten(0, -1) + self.assertEqual(flat.shape, torch.Size([625])) + self.assertEqual(src.view(-1), flat.view(-1)) + + flat = src.flatten(0, 2) + self.assertEqual(flat.shape, torch.Size([125, 5])) + self.assertEqual(src.view(-1), flat.view(-1)) + + flat = src.flatten(0, 1) + self.assertEqual(flat.shape, torch.Size([25, 5, 5])) + self.assertEqual(src.view(-1), flat.view(-1)) + + flat = src.flatten(1, 2) + self.assertEqual(flat.shape, torch.Size([5, 25, 5])) + self.assertEqual(src.view(-1), flat.view(-1)) + + flat = src.flatten(2, 3) + self.assertEqual(flat.shape, torch.Size([5, 5, 25])) + self.assertEqual(src.view(-1), flat.view(-1)) + + flat = src.flatten(-2, -1) + self.assertEqual(flat.shape, torch.Size([5, 5, 25])) + self.assertEqual(src.view(-1), flat.view(-1)) + + flat = src.flatten(2, 2) + self.assertEqual(flat, src) + + # out of bounds index + with self.assertRaisesRegex(IndexError, 'Dimension out of range'): + src.flatten(5, 10) + + # invalid start and end + with self.assertRaisesRegex(RuntimeError, 'start_dim cannot come after end_dim'): + src.flatten(2, 0) + + # TODO: update to work on CUDA, too + @onlyCPU + def test_narrow(self, device): + x = torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) + self.assertEqual(x.narrow(0, 0, 1), torch.Tensor([[0, 1, 2]])) + self.assertEqual(x.narrow(0, 0, 2), torch.Tensor([[0, 1, 2], [3, 4, 5]])) + self.assertEqual(x.narrow(0, 1, 1), torch.Tensor([[3, 4, 5]])) + self.assertEqual(x.narrow(0, -1, 1), torch.Tensor([[6, 7, 8]])) + self.assertEqual(x.narrow(0, -2, 2), torch.Tensor([[3, 4, 5], [6, 7, 8]])) + self.assertEqual(x.narrow(0, -3, 3), torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])) + self.assertEqual(x.narrow(-1, -1, 1), torch.Tensor([[2], [5], [8]])) + self.assertEqual(x.narrow(-2, -1, 1), torch.Tensor([[6, 7, 8]])) + + # TODO: update to work on CUDA, too + @onlyCPU + def test_narrow_tensor(self, device): + x = torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) + self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.Tensor([[0, 1, 2]])) + with self.assertRaises(Exception): + x.narrow(0, torch.tensor(0.), 1) + with self.assertRaises(Exception): + x.narrow(0, torch.tensor([0]), 1) + with self.assertRaises(Exception): + x.narrow(0, torch.tensor([0, 1]), 1) + + # TODO: make work on CUDA, too + @onlyCPU + def test_t(self, device): + # Test 0D tensors + x = torch.randn(()) + self.assertEqual(x, x.t()) + x = x.to_sparse() + self.assertEqual(x, x.t()) + + # Test 1D tensors + x = torch.arange(4) + self.assertEqual(x, x.t()) + x = x.to_sparse() + self.assertEqual(x, x.t()) + + # Test 2D tensors + x = torch.rand((2, 2)) + self.assertEqual(x.t(), x.transpose(0, 1)) + x = x.to_sparse() + self.assertEqual(x.t(), x.transpose(0, 1)) + + # Test 3D tensor + x = torch.rand((2, 2, 2)) + with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'): + x.t() + x = x.to_sparse() + with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'): + x.t() + + @onlyCPU + def test_split(self, device): + tensor = torch.rand(7, 4) + split_size = 3 + dim = 0 + target_sizes = ([3, 4], [3, 4], [1, 4]) + splits = tensor.split(split_size, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0) + start = start + target_size[dim] + + # Variable sections split + tensor = torch.randn(20, 10) + dim = 0 + split_sizes = [5, 5, 10] + target_sizes = ([[5, 10], [5, 10], [10, 10]]) + splits = tensor.split(split_sizes, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0) + start = start + target_size[dim] + + split_sizes = [2, 2, 6] + target_sizes = ([20, 2], [20, 2], [20, 6]) + dim = 1 + splits = tensor.split(split_sizes, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0) + start = start + target_size[dim] + + @onlyCPU + def test_chunk(self, device): + tensor = torch.rand(4, 7) + num_chunks = 3 + dim = 1 + target_sizes = ([4, 3], [4, 3], [4, 1]) + splits = tensor.chunk(num_chunks, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, + atol=0, rtol=0) + start = start + target_size[dim] + + # Invalid chunk sizes + error_regex = 'chunk expects.*greater than 0' + with self.assertRaisesRegex(RuntimeError, error_regex): + tensor.chunk(0) + with self.assertRaisesRegex(RuntimeError, error_regex): + tensor.chunk(-2) + + # TODO: make work on CUDA, too + @onlyCPU + def test_unsqueeze(self, device) -> None: + x = torch.randn(2, 3, 4) + y = x.unsqueeze(1) + self.assertEqual(y, x.view(2, 1, 3, 4)) + y = x.clone().unsqueeze_(2) + self.assertEqual(y, x.view(2, 3, 1, 4)) + + x = x[:, 1] + self.assertFalse(x.is_contiguous()) + y = x.unsqueeze(1) + self.assertEqual(y, x.contiguous().view(2, 1, 4)) + y = x.clone().unsqueeze_(2) + self.assertEqual(y, x.contiguous().view(2, 4, 1)) + + # unit test for special case transposed copy (see ATen/native/Copy.cpp for details) + def test_big_transpose(self, device): + t = torch.rand(456, 789, device=device) + t1 = t.t().contiguous() + t2 = torch.from_numpy(t.cpu().numpy().transpose()) + self.assertEqual(t1, t2) + + def test_T(self, device): + a = torch.randn(2, 3, 4, device=device) + t1 = a.T + t2 = a.permute(2, 1, 0) + self.assertEqual(t2, t1) + b = torch.randn(10, device=device) + self.assertEqual(b, b.T) + scalar = torch.tensor(5, device=device) + self.assertEqual(scalar, scalar.T) + + def test_python_types(self, device): + a1 = torch.randn((1, 2), device=device, dtype=torch.float64) + a2 = torch.randn((1, 2), device=device, dtype=float) + self.assertEqual(a1.dtype, a2.dtype) + + b1 = torch.arange(10, 20, dtype=torch.int64, device=device) + b2 = torch.arange(10, 20, dtype=int, device=device) + self.assertEqual(b1.dtype, b2.dtype) + + c1 = torch.tensor([True, False], dtype=torch.bool, device=device) + c2 = torch.tensor([True, False], dtype=bool, device=device) + self.assertEqual(c1.dtype, c2.dtype) + + # TODO: is resize best put in test_view_ops? + def test_resize_as_preserves_strides(self, device): + x = torch.empty(2, 3).t() + old_strides = x.stride() + x.resize_as_(x) + self.assertEqual(x.stride(), old_strides) + + def test_memory_format_resize_as(self, device): + def test_helper(shape, memory_format, device): + xc = torch.randn(shape, device=device).contiguous(memory_format=memory_format) + flat = torch.randn(xc.numel(), device=device) + flat.resize_as_(xc, memory_format=torch.preserve_format) + self.assertTrue(flat.is_contiguous(memory_format=memory_format)) + + test_helper((10, 3, 32, 32), torch.channels_last, device) + test_helper((3, 10, 3, 32, 32), torch.channels_last_3d, device) + + def test_memory_format_resize_(self, device): + def test_helper(shape, numel, memory_format, device): + flat = torch.randn(numel, device=device) + flat.resize_(shape, memory_format=memory_format) + self.assertTrue(flat.is_contiguous(memory_format=memory_format)) + + test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device) + test_helper((3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device) + + @onlyOnCPUAndCUDA + @dtypes(torch.int64, torch.float, torch.complex128) + def test_transpose_invalid(self, device, dtype): + for fn in (torch.swapdims, torch.swapaxes, torch.transpose): + shape = _rand_shape(4, min_size=5, max_size=10) + x = _generate_input(shape, dtype, device, False) + + # Invalid `source` and `destination` dimension + with self.assertRaisesRegex(IndexError, "Dimension out of range"): + fn(x, 5, 0) + + with self.assertRaisesRegex(IndexError, "Dimension out of range"): + fn(x, 0, 5) + + @dtypes(torch.int64, torch.float, torch.complex128) + def test_transpose_vs_numpy(self, device, dtype): + for fn in (torch.swapdims, torch.swapaxes, torch.transpose): + for nd in range(5): + shape = _rand_shape(nd, min_size=5, max_size=10) + x = _generate_input(shape, dtype, device, with_extremal=False) + for random_negative in [True, False]: + for src_dim, dst_dim in permutations(range(nd), r=2): + random_prob = random.random() + + if random_negative and random_prob > 0.66: + src_dim = src_dim - nd + elif random_negative and random_prob > 0.33: + dst_dim = dst_dim - nd + elif random_negative: + src_dim = src_dim - nd + dst_dim = dst_dim - nd + + partial_map = { + torch.swapdims: partial(torch.swapdims, dim0=src_dim, dim1=dst_dim), + torch.swapaxes: partial(torch.swapaxes, axis0=src_dim, axis1=dst_dim), + torch.transpose: partial(torch.transpose, dim0=src_dim, dim1=dst_dim), + } + + torch_fn = partial_map[fn] + np_fn = partial(np.swapaxes, axis1=src_dim, axis2=dst_dim) + self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) + + # Move dim to same position + x = torch.randn(2, 3, 5, 7, 11) + partial_map = { + torch.swapdims: partial(torch.swapdims, dim0=0, dim1=0), + torch.swapaxes: partial(torch.swapaxes, axis0=0, axis1=0), + torch.transpose: partial(torch.transpose, dim0=0, dim1=0), + } + torch_fn = partial_map[fn] + np_fn = partial(np.swapaxes, axis1=0, axis2=0) + self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) + + def _test_atleast_dim(self, torch_fn, np_fn, device, dtype): + for ndims in range(0, 5): + shape = _rand_shape(ndims, min_size=5, max_size=10) + for n in range(ndims + 1): + for with_extremal in [False, True]: + for contiguous in [False, True]: + # Generate Input. + x = _generate_input(shape, dtype, device, with_extremal) + if contiguous: + x = x.T + self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) + + # Compare sequence input + torch_sequence_x = (x,) * random.randint(3, 10) + np_sequence_x = tuple(np.array(x.detach().cpu().numpy()) for x in torch_sequence_x) + torch_res = torch_fn(*torch_sequence_x) + np_res = np_fn(*np_sequence_x) + + torch_res = tuple(x.cpu() for x in torch_res) + np_res = tuple(torch.from_numpy(x) for x in np_res) + self.assertEqual(np_res, torch_res) + + # TODO: are these view ops? + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + + torch.testing.get_all_complex_dtypes())) + def test_atleast(self, device, dtype): + self._test_atleast_dim(torch.atleast_1d, np.atleast_1d, device, dtype) + self._test_atleast_dim(torch.atleast_2d, np.atleast_2d, device, dtype) + self._test_atleast_dim(torch.atleast_3d, np.atleast_3d, device, dtype) + + @onlyCPU + @dtypes(torch.float) + def test_broadcast_tensors(self, device, dtype): + x0 = torch.randn(2, 1, 3, dtype=dtype, device=device) + x1 = torch.randn(3, dtype=dtype, device=device) + x2 = torch.randn(3, 1, dtype=dtype, device=device) + expected_size = (2, 3, 3) + + y0, y1, y2 = torch.broadcast_tensors(x0, x1, x2) + self.assertTrue(y0.size() == expected_size) + self.assertTrue(y1.size() == expected_size) + self.assertTrue(y2.size() == expected_size) + + + @onlyCPU + def test_broadcast_shapes(self, device): + examples = [(), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)] + for s0 in examples: + x0 = torch.randn(s0) + expected = torch.broadcast_tensors(x0)[0].shape + actual = torch.broadcast_shapes(s0) + self.assertEqual(expected, actual) + + for s1 in examples: + x1 = torch.randn(s1) + expected = torch.broadcast_tensors(x0, x1)[0].shape + actual = torch.broadcast_shapes(s0, s1) + self.assertEqual(expected, actual) + + # Skip BFloat16 since numpy does not support it + @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False)) + def test_broadcast_to(self, device, dtype): + def can_broadcast(s0, s1): + # s0.dim() <= s1.dim(), reverse s0 and s1 to compare trailing dimension + s0 = tuple(reversed(s0)) + s1 = tuple(reversed(s1)) + for i in range(len(s0)): + if s0[i] != 1 and s0[i] != s1[i]: + return False + return True + + sizes = ( + (), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2) + ) + for s0, s1 in combinations(sizes, r=2): + t = make_tensor(s0, device, dtype, low=-9, high=9) + t_np = t.cpu().numpy() + + if can_broadcast(s0, s1): + res = torch.broadcast_to(t, s1) + np_res = np.broadcast_to(t_np, s1) + self.assertEqual(res, np_res) + else: + with self.assertRaisesRegex(RuntimeError, + r"The expanded size of the tensor \(\d\) " + r"must match the existing size \(\d\)"): + torch.broadcast_to(t, s1) + + def test_view(self, device): + tensor = torch.rand(15, device=device) + template = torch.rand(3, 5, device=device) + empty = torch.empty(0, device=device) + target = template.size() + self.assertEqual(tensor.view_as(template).size(), target) + self.assertEqual(tensor.view(3, 5).size(), target) + self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target) + self.assertEqual(tensor.view(-1, 5).size(), target) + self.assertEqual(tensor.view(3, -1).size(), target) + tensor_view = tensor.view(5, 3) + tensor_view.fill_(random.uniform(0, 1)) + self.assertEqual(empty.view_as(empty), empty) + self.assertEqual(empty.view(0), empty) + self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1])) + self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty) + + # test size inference with empty tensors + self.assertEqual(empty.view(-1).size(), torch.Size([0])) + self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0])) + + with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"): + empty.view(-1, 0) + + with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"): + empty.view(3, 0, -1, 0) + + self.assertRaises(RuntimeError, lambda: tensor.view(15, 0)) + self.assertRaises(RuntimeError, lambda: tensor.view(7, -1)) + self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1)) + + # test view when tensor is not contiguous in every dimension, but only + # contiguous dimensions are touched. + tensor = torch.rand(4, 2, 5, 1, 6, 2, 9, 3, device=device).transpose(-1, 2).transpose(-2, 3) + # size: [ 4, 2, 3, 9, 6, 2, 1, 5] + # stride: [3840, 1620, 1, 3, 54, 27, 324, 324] + # contiguous dim chunks: [__________, ____, ____, __________, ____, ____] + # merging 1 to chunk after: [__________, ____, ____, __________, __________] + contig_tensor = tensor.clone() + # [4, 2] => [8, 1] + # [3] => [3] + # [9] => [3, 3] + # [6, 2] => [4, 1, 3] + # [1, 5] => [5] + view_size = [8, 1, 3, 3, 3, 4, 1, 3, 5] + self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) + # [4, 2] => [2, 4] + # [3] => [3] + # [9] => [1, 9] + # [6, 2] => [2, 2, 3] + # [1, 5] => [5, 1] + view_size = [2, 4, 3, 1, 9, 2, 2, 3, 5, 1] + self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) + # adding size 1 dims + view_size = [1, 1, 2, 1, 4, 3, 1, 1, 9, 1, 2, 1, 2, 3, 1, 5, 1, 1] + self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) + + # invalid views + self.assertRaises(RuntimeError, lambda: tensor.view(-1)) + # crossing [4, 2], [3] + self.assertRaises(RuntimeError, lambda: tensor.view(24, 9, 6, 2, 1, 5)) + # crossing [6, 2], [1, 5] + self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 9, 6, 10)) + # crossing [9], [6, 2] + self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 54, 2, 1, 5)) + + # view with stride 0 dims + tensor = torch.empty(1, 1, device=device).expand(3, 4) # all dims are contiguous + contig_tensor = tensor.clone() + self.assertEqual(tensor.view(-1), contig_tensor.view(-1)) + self.assertEqual(tensor.view(1, -1, 1), contig_tensor.view(1, -1, 1)) + self.assertEqual(tensor.view(-1, 1), contig_tensor.view(-1, 1)) + self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1)) + self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1)) + + def test_contiguous(self, device): + x = torch.randn(1, 16, 5, 5, device=device) + self.assertTrue(x.is_contiguous()) + stride = list(x.stride()) + stride[0] = 20 + # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1 + x.set_(x.storage(), 0, x.size(), stride) + self.assertTrue(x.is_contiguous()) + + @onlyOnCPUAndCUDA + # Skip BFloat16 since numpy does not support it + @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False)) + def test_tensor_split_sections(self, device, dtype): + input_sizes = [ + (0,), + (10,), + (10, 0), + (0, 10), + (4, 10), + (12, 3), + ] + for input_size in input_sizes: + a_base = make_tensor(input_size, device, dtype, low=-9, high=9) + # Run tests on transposed input if it has at least 2 dims + for a in [a_base, a_base.t()] if a_base.dim() > 2 else [a_base]: + a_n = a.cpu().numpy() + for dim in range(-a.dim(), a.dim()): + for sections in range(1, 2 * a.size(dim)): + msg = f'input_size {input_size}, sections {sections}, dim {dim}' + result1 = torch.tensor_split(a, sections, dim) + result2 = torch.tensor_split(a, torch.tensor(sections, dtype=torch.int64), dim) + for r1, r2 in zip(result1, result2): + self.assertEqual(r1.device, torch.device(device), msg=msg) + self.assertEqual(r1.dtype, dtype, msg=msg) + self.assertEqual(r2.device, torch.device(device), msg=msg) + self.assertEqual(r2.dtype, dtype, msg=msg) + result_n = np.array_split(a_n, sections, dim) + self.assertEqual(result_n, result1, msg=msg) + self.assertEqual(result_n, result2, msg=msg) + + @onlyOnCPUAndCUDA + # Skip BFloat16 since numpy does not support it + @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False)) + def test_tensor_split_indices(self, device, dtype): + input_sizes = [ + (0,), + (10,), + (10, 0), + (0, 10), + (4, 10), + (12, 3), + ] + indices_args = [ + (), + (0,), + (3,), + (10,), + (-1,), + (-10,), + (2, -1), + (3, 4, 10), + (0, -1, 0, 10), + (1, 5, 2, 8), + ] + for input_size in input_sizes: + a_base = make_tensor(input_size, device, dtype, low=-9, high=9) + # Run tests on transposed input if it has at least 2 dims + for a in [a_base, a_base.t()] if a_base.dim() > 2 else [a_base]: + a_n = a.cpu().numpy() + for dim in range(-a.dim(), a.dim()): + for indices in indices_args: + result_1 = torch.tensor_split(a, indices, dim) + result_2 = torch.tensor_split(a, torch.tensor(indices, dtype=torch.int64), dim) + + msg = f'input_size {input_size}, indices {indices}, dim {dim}' + for r1, r2 in zip(result_1, result_2): + self.assertEqual(r1.device, torch.device(device), msg=msg) + self.assertEqual(r1.dtype, dtype, msg=msg) + self.assertEqual(r2.device, torch.device(device), msg=msg) + self.assertEqual(r2.dtype, dtype, msg=msg) + + result_n = np.array_split(a_n, indices, dim) + self.assertEqual(result_n, result_1, msg=msg) + self.assertEqual(result_n, result_2, msg=msg) + + @onlyOnCPUAndCUDA + def test_tensor_split_errors(self, device): + S = 10 + test_cases = [ + # input size, sections or indices, dim, error type, error message, numpy error type + [(S,), 10, 1, IndexError, r'Dimension out of range', IndexError], + [(), 10, 0, RuntimeError, r'tensor_split expected at least a 1-dimensional tensor, ' + + 'but got a tensor with 0 dims', IndexError], + [(S,), (10,), 1, IndexError, r'Dimension out of range', IndexError], + [(), (10,), 0, RuntimeError, r'tensor_split expected at least a 1-dimensional tensor, ' + + 'but got a tensor with 0 dims', IndexError], + [(S,), 0, 0, RuntimeError, r'number of sections must be larger than 0, got 0', ValueError], + [(S,), -1, 0, RuntimeError, r'number of sections must be larger than 0, got -1', ValueError], + ] + for input_size, sections_or_indices, dim, err, err_msg, numpy_err in test_cases: + a = torch.randn(input_size, device=device) + msg = f'input_size {input_size}, sections_or_indices {sections_or_indices}, dim {dim}' + with self.assertRaisesRegex(err, err_msg, msg=msg): + torch.tensor_split(a, sections_or_indices, dim) + with self.assertRaisesRegex(err, err_msg, msg=msg): + torch.tensor_split(a, torch.tensor(sections_or_indices), dim) + with self.assertRaises(numpy_err, msg=msg): + np.array_split(a.cpu().numpy(), sections_or_indices, dim) + + # addtional tests for tensor_split with tensor_indices_or_sections + with self.assertRaisesRegex(RuntimeError, + r'tensor_split expected tensor_indices_or_sections to have dtype of long, but got Float'): + torch.tensor_split(a, torch.tensor(1.1), dim) + + with self.assertRaisesRegex(RuntimeError, + r'tensor_split expected tensor_indices_or_sections to be a' + + ' zero-dimensional or one-dimensional tensor, but got a tensor with 2 dims'): + torch.tensor_split(torch.rand(S, device=device), torch.tensor(((1,),)), 0) + + def test_resize_all_dtypes_and_devices(self, device): + shape = (2, 2) + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) + x.resize_(shape) + self.assertEqual(shape, x.shape) + + def test_resize_as_all_dtypes_and_devices(self, device): + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) + y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device) + x.resize_as_(y) + self.assertEqual(y.shape, x.shape) + + def test_view_all_dtypes_and_devices(self, device): + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) + self.assertEqual(x.view(6).shape, [6]) + + +instantiate_device_type_tests(TestViewOps, globals()) +instantiate_device_type_tests(TestOldViewOps, globals()) + +if __name__ == '__main__': + run_tests() diff --git a/test/test_vmap.py b/test/test_vmap.py index abec2c0ae4897..d2386811d2711 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -1,12 +1,17 @@ from torch.testing._internal.common_utils import TestCase, run_tests import torch +import torch.nn.functional as F from torch import Tensor, vmap import functools +import itertools import warnings from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_utils import TEST_WITH_ROCM import types + +FALLBACK_REGEX = r'falling back to slow \(for loop( and stack)?\) implementation' + class TestVmapAPI(TestCase): def test_non_tensor_output_raises(self): with self.assertRaisesRegex(ValueError, "got type as the return"): @@ -24,6 +29,10 @@ def test_different_map_dim_size_raises(self): expected_msg = 'Expected all tensors to have the same size in the mapped dimension' with self.assertRaisesRegex(ValueError, expected_msg): vmap(torch.mul)(x, y) + with self.assertRaisesRegex(ValueError, expected_msg): + vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y)) + with self.assertRaisesRegex(ValueError, expected_msg): + vmap(lambda z: z['x'] + z['y'], in_dims=({'x': 0, 'y': 0},))({'x': x, 'y': y}) def test_func_with_no_inputs(self): expected_msg = 'got no inputs' @@ -124,11 +133,17 @@ def test_unsupported_op_err_msg(self): # Unsupported view op tensor = torch.randn(2, 3) msg = ( - "Batching rule not implemented for aten::as_strided; the " - "fallback path doesn't work on in-place or view ops" + r"Batching rule not implemented for aten::.+; the " + r"fallback path doesn't work on out= or view ops" ) with self.assertRaisesRegex(RuntimeError, msg): - vmap(torch.as_strided, (0, None, None))(tensor, [2, 3], [0, 0]) + vmap(torch.ravel)(tensor) + + def out_op(x, y): + return torch.abs(x, out=y) + + with self.assertRaisesRegex(RuntimeError, msg): + vmap(out_op)(tensor, tensor) # The fallback doesn't support TensorList with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'): @@ -139,15 +154,6 @@ def test_unsupported_op_err_msg(self): with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'): vmap(torch.Tensor.item)(tensor) - def test_unsupported_inplace_op_err_msg(self): - def foo(x): - return x.cos_() - - x = torch.randn(3) - with self.assertRaisesRegex( - RuntimeError, 'Batching rule not implemented'): - vmap(foo)(x) - def test_nonzero_out_dims(self): # Basic test tensor = torch.randn(2, 3) @@ -350,51 +356,68 @@ def foo(x): result = vmap(vmap(foo, 1, 1), 1, 1)(x) self.assertEqual(result, x * 2) + def test_accepts_nested_inputs(self): + B0 = 2 + x = torch.randn(2, 3) + y = torch.randn(2, 3) + + # Single layer of nesting + out = vmap(lambda z: z[0] + z[1])((x, y)) + self.assertEqual(out, x + y) + out = vmap(lambda z: z[0] + z[1], in_dims=(0,))((x, y)) + self.assertEqual(out, x + y) + out = vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y)) + self.assertEqual(out, x + y) + + out = vmap(lambda z: z[0] + z[1])([x, y]) + self.assertEqual(out, x + y) + out = vmap(lambda z: z[0] + z[1], in_dims=(0,))([x, y]) + self.assertEqual(out, x + y) + out = vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, y]) + self.assertEqual(out, x + y) + + out = vmap(lambda z: z['x'] + z['y'])({'x': x, 'y': y}) + self.assertEqual(out, x + y) + out = vmap(lambda z: z['x'] + z['y'], in_dims=(0,))({'x': x, 'y': y}) + self.assertEqual(out, x + y) + out = vmap(lambda z: z['x'] + z['y'], in_dims=({'x': 0, 'y': 0},))({'x': x, 'y': y}) + self.assertEqual(out, x + y) + + # Multiple layers of nesting + out_fn = vmap(lambda z: z['x'][0] + z['x'][1][0] + z['y'][0] + z['y'][1]) + out = out_fn({'x': [x, (x,)], 'y': [y, y]}) + self.assertEqual(out, x + x + y + y) + def test_in_dims_wrong_type_err_msg(self): x = torch.randn(3) y = torch.randn(3) - msg = 'expected `in_dims` to be int or tuple' + msg = r'expected `in_dims` to be int or a \(potentially nested\) tuple' with self.assertRaisesRegex(ValueError, msg): vmap(torch.mul, [0, 0])(x, y) with self.assertRaisesRegex(ValueError, msg): vmap(torch.mul, set({0, 0}))(x, y) with self.assertRaisesRegex(ValueError, msg): vmap(torch.mul, 'lol')(x, y) + with self.assertRaisesRegex(ValueError, msg): + vmap(lambda z: z[0] + z[1], in_dims=[0, 0])([x, y]) # The following should not throw vmap(torch.mul, (0, 0))(x, y) def test_not_enough_in_dims_err_msg(self): x = torch.randn(3) y = torch.randn(3) - msg = r'expected one `in_dim` per input \(got \w+ inputs\)' + msg = r'in_dims is not compatible with the structure of `inputs`' with self.assertRaisesRegex(ValueError, msg): vmap(torch.mul, (0,))(x, y) with self.assertRaisesRegex(ValueError, msg): vmap(torch.mul, (0, 0, 0))(x, y) - # The following should not throw - vmap(torch.mul, (0, 0))(x, y) - - def test_in_dims_must_be_flat_tuple_err_msg(self): - msg = 'in_dims must be a flat tuple containing ints and/or Nones' - - x = torch.randn(3) - y = torch.randn(3) - z = torch.randn(3) - - def foo(xy): - return xy[0] * xy[1] - - def bar(x, yz): - return x * yz[0] * yz[1] - - # NB: jax supports all of the following, we don't yet. with self.assertRaisesRegex(ValueError, msg): - vmap(foo, ((0, 0),))((x, y)) + vmap(lambda z: z[0] + z[1], in_dims=([0],))([x, y]) with self.assertRaisesRegex(ValueError, msg): - vmap(bar, (0, (0, 0)))(x, (y, z)) - with self.assertRaisesRegex(ValueError, msg): - vmap(foo, ({0: 0, 1: 0},))({0: x, 1: y}) + vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))([x, y]) + # The following should not throw + vmap(torch.mul, (0, 0))(x, y) def test_integer_in_dim_but_not_tensor_input_err_msg(self): def foo(xy): @@ -406,26 +429,14 @@ def bar(x, yz): x = torch.randn(2, 3) y = torch.randn(2, 3) - # jax supports these, we too can in the future. - msg = 'Got in_dim=0 for input 0, but input 0 is not a Tensor' - with self.assertRaisesRegex(ValueError, msg): - vmap(foo)((x, y)) - with self.assertRaisesRegex(ValueError, msg): - vmap(foo, (0,))((x, y)) - - # jax supports these as well, we too can in the future. - msg = 'Got in_dim=0 for input 1, but input 1 is not a Tensor' - with self.assertRaisesRegex(ValueError, msg): - vmap(foo)(x, (x, y)) - with self.assertRaisesRegex(ValueError, msg): - vmap(foo, (0, 0))(x, (x, y)) - # the following are errors in jax (and will always be errors) - msg = 'Got in_dim=0 for input 1, but input 1 is not a Tensor' + msg = 'Got in_dim=0 for an input but the input is of type' with self.assertRaisesRegex(ValueError, msg): vmap(torch.sum)(x, 0) with self.assertRaisesRegex(ValueError, msg): vmap(torch.sum, (0, 0))(x, 0) + with self.assertRaisesRegex(ValueError, msg): + vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, 1]) # The following should not throw vmap(torch.sum, (0, None))(x, 0) @@ -433,15 +444,20 @@ def test_in_dim_not_in_tensor_err_msg(self): def foo(x): return x * x - msg = r'Got in_dim=-?\w for input 0, but input 0 is a Tensor of dimensionality \w' + x = torch.randn(2, 3) + y = torch.randn(2, 3) + + msg = r'Got in_dim=-?\w for some input, but that input is a Tensor of dimensionality \w' with self.assertRaisesRegex(ValueError, msg): vmap(foo)(torch.randn([])) with self.assertRaisesRegex(ValueError, msg): vmap(foo, in_dims=(0,))(torch.randn([])) with self.assertRaisesRegex(ValueError, msg): - vmap(foo, in_dims=(-1,))(torch.randn(2, 3)) + vmap(foo, in_dims=(-1,))(x) with self.assertRaisesRegex(ValueError, msg): - vmap(foo, in_dims=(2,))(torch.randn(2, 3)) + vmap(foo, in_dims=(2,))(y) + with self.assertRaisesRegex(ValueError, msg): + vmap(lambda z: z[0] + z[1], in_dims=([3, 0],))([x, y]) # the following should not throw vmap(foo, in_dims=(0,))(torch.randn(2, 3)) vmap(foo, in_dims=(1,))(torch.randn(2, 3)) @@ -450,8 +466,38 @@ def _assert_uses_vmap_fallback(self, vmap_args, inputs): with warnings.catch_warnings(record=True) as wa: result = vmap(*vmap_args)(*inputs) self.assertEqual(len(wa), 2) - self.assertRegex(str(wa[-1].message), - r'falling back to slow \(for loop and stack\) implementation') + self.assertRegex(str(wa[-1].message), FALLBACK_REGEX) + + def test_fallback_zero_dim(self): + # NB: One day we will implement a batching rule for torch.atan2. + # If/when we do, this test should be replaced to test the fallback + # path on another operator to avoid bitrot. + op = torch.atan2 + x = torch.randn(11) + y = torch.randn(11) + self._assert_uses_vmap_fallback((op,), (x, y)) + + B0, B1 = 0, 3 + x = torch.randn(B0, 11) + y = torch.randn(11) + + msg = 'The fallback path does not support vmap over dims of size 0' + + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, (0, None))(x, y) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, (None, 0))(y, x) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op)(x, x) + + x = torch.randn(B0, B1, 11) + y = torch.randn(B1, 11) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, (0, None))(x, y) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, (None, 0))(y, x) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op)(x, x) def test_fallback_atan2(self): # NB: One day we will implement a batching rule for torch.atan2. @@ -464,13 +510,13 @@ def test_fallback_atan2(self): self._assert_uses_vmap_fallback((op,), (x, y)) - # fallback on torch.sub + # fallback on torch.atan2 x = torch.randn(7, 11, 5) y = torch.randn(5, 7, 11) result = vmap(op, (2, 0))(x, y) self.assertEqual(result, op(x.permute(2, 0, 1), y)) - # fallback on torch.sub, nested vmap + # fallback on torch.atan2, nested vmap x = torch.randn(7, 11, 5) y = torch.randn(5, 7, 11) result = vmap(vmap(op), (2, 0))(x, y) @@ -529,6 +575,128 @@ def test_fallback_multiple_returns(self): expected = torch.var_mean(tensor, dim=3) self.assertEqual(result, expected) + def test_inplace_fallback_unary(self): + # Test the in-place fallback on an in-place method that takes no + # additional Tensor arguments. This is the simplest case of the fallback. + # NB: One day we will implement a batching rule for acos_. + # If/when we do, this test should be replaced to test the fallback + # path on another operator to avoid bitrot. + op = Tensor.acos_ + B0, B1, B2 = 2, 3, 10000 + + x = torch.randn(B0, 5) + self._assert_uses_vmap_fallback((op,), (x,)) + + # Single vmap + x_orig = torch.rand(B0, 5) + x = x_orig.clone() + result = vmap(op)(x) + self.assertTrue(result is x) + self.assertEqual(result, x_orig.acos()) + + # Single vmap + different out_dim produces a view(!) + x_orig = torch.rand(B0, 5) + x = x_orig.clone() + result = vmap(op, out_dims=(1,))(x) + self.assertTrue(result._base is x) + self.assertEqual(result, x_orig.t().acos()) + + # Nested vmap + x_orig = torch.randn(B0, B1, 5) + x = x_orig.clone() + result = vmap(vmap(op))(x) + self.assertTrue(result is x) + self.assertEqual(result, x_orig.acos()) + + # Nested vmap, large batch size + x_orig = torch.randn(B0, B1, B2, 5) + x = x_orig.clone() + result = vmap(vmap(vmap(op)))(x) + self.assertTrue(result is x) + self.assertEqual(result, x_orig.acos()) + + def test_inplace_fallback_nary_same_levels(self): + # NB: One day we will implement a batching rule for atan2_ + # If/when we do, this test should be replaced to test the fallback + # path on another operator to avoid bitrot. + op = Tensor.atan2_ + outplace_op = torch.atan2 + + x = torch.randn(5, 7, 11) + y = torch.randn(5, 7, 11) + self._assert_uses_vmap_fallback((op,), (x, y)) + + # Single vmap + B0 = 5 + x_orig = torch.randn(7, 11, B0) + x = x_orig.clone() + y = torch.randn(B0, 7, 11) + vmap(op, (2, 0))(x, y) + self.assertEqual(x, outplace_op(x_orig, y.movedim(0, 2))) + + # Nested vmap + B0, B1 = 5, 7 + x_orig = torch.randn(B1, 11, B0) + x = x_orig.clone() + y = torch.randn(B0, B1, 11) + vmap(vmap(op), (2, 0))(x, y) + self.assertEqual(x, outplace_op(x_orig, y.movedim([0, 1], [2, 0]))) + + # big batch size (total 10000) + B0, B1, B2 = 100, 10, 10 + x_orig = torch.randn(B0, B1, B2, 5) + x = x_orig.clone() + y = torch.randn(B0, B1, B2) + result = vmap(vmap(vmap(op)))(x, y) + self.assertEqual(x, outplace_op(x_orig, y.view(B0, B1, B2, 1))) + + def test_inplace_fallback_nary_different_levels(self): + # NB: One day we will implement a batching rule for atan2_ + # If/when we do, this test should be replaced to test the fallback + # path on another operator to avoid bitrot. + op = Tensor.atan2_ + outplace_op = torch.atan2 + B0, B1, B2 = 2, 3, 5 + + x = torch.rand(B0, 7) + y = torch.rand(7) + self._assert_uses_vmap_fallback((op, (0, None)), (x, y)) + + # op(left, right): All of the levels in right are found in left + x_orig = torch.rand(B0, 7) + x = x_orig.clone() + y = torch.rand(7) + vmap(op, in_dims=(0, None))(x, y) + self.assertEqual(x, outplace_op(x_orig, y)) + + x_orig = torch.rand(B0, B1, 7) + x = x_orig.clone() + y = torch.rand(B0, 7) + vmap(vmap(op, in_dims=(0, None)))(x, y) + self.assertEqual(x, outplace_op(x_orig, y.view(B0, 1, 7))) + + # op(left, right): Some of the levels in right are not found in left + msg = r'vmap: aten::atan2_\(self, \*extra_args\) is not possible' + x = torch.rand(7) + y = torch.rand(B0, 7) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, in_dims=(None, 0))(x, y) + + x = torch.rand(B1, 7) + y = torch.rand(B0, 7) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 0))(x, y) + + x = torch.rand(B1, 7) + y = torch.rand(7, B0) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 1))(x, y) + + x = torch.rand(B0, 7) + y = torch.rand(B0, B1, 7) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(vmap(op, in_dims=(None, 0)))(x, y) + def test_backward_unsupported_interaction(self): x = torch.randn(3, requires_grad=True) y = torch.randn(5) @@ -598,6 +766,25 @@ def test_nn_module(self): result = vmap(model)(tensor) self.assertEqual(result, model(tensor)) + def test_fallback_with_undefined_grad(self): + B0 = 7 + x = torch.randn(2, 3, 4, 5, requires_grad=True) + weight = torch.randn(3, 3, 1, 1) + v = torch.randn(B0, 2, 3, 4, 5) + + def get_vjp(v): + result = torch.nn.functional.conv2d(x, weight) + grad_x, = torch.autograd.grad(result, x, v) + return grad_x + + # Runs vmap(get_vjp)(v), which should not error out. + # The backward formula for convolution returns an undefined + # Tensor for grad_bias because the original bias does not exist. + # + # In the future we'll probably add a batching rule for convolution + # backward. When this happens, we should modify this test to use a + # different op (and/or create and use a dummy operator) to avoid bitrot. + self._assert_uses_vmap_fallback([get_vjp], [v]) def slice_inputs(inputs, bdims, i): result = [] @@ -722,12 +909,11 @@ def _wrap_method_with_vmap_fallback_check(self, method): @functools.wraps(method) def wrapper(self, *args, **kwargs): - regex = r'falling back to slow \(for loop and stack\) implementation' with warnings.catch_warnings(record=True) as wa: warnings.simplefilter('always') method(*args, **kwargs) for captured_warning in wa: - self.assertNotRegex(str(captured_warning.message), regex, msg) + self.assertNotRegex(str(captured_warning.message), FALLBACK_REGEX, msg) return types.MethodType(wrapper, self) @allowVmapFallbackUsage @@ -765,8 +951,8 @@ def _vmap_test(self, *args, **kwargs): def _vmap_view_test(self, *args, **kwargs): self._vmap_test(*args, **kwargs, check_view=True) - def _test_unary(self, op, getter, device): - test = self._vmap_test + def _test_unary(self, op, getter, device, *args, **kwargs): + test = functools.partial(self._vmap_test, *args, **kwargs) B0, B1 = 7, 11 # Single vmap, various in_dims / out_dims @@ -816,6 +1002,36 @@ def test_unary_pointwise_ops(self): for op, getter in cases: self._test_unary(op, getter, 'cpu') + def test_clone(self): + # Some basic tests + self._test_unary(lambda x: x.clone(), TensorFactory.randn, 'cpu') + self._test_unary(lambda x: x.clone(memory_format=torch.preserve_format), + TensorFactory.randn, 'cpu') + self._test_unary(lambda x: x.clone(memory_format=torch.contiguous_format), + TensorFactory.randn, 'cpu') + + # Test that the per-examples are contiguous when using torch.contiguous_format + def clone_contiguous(x): + return x.clone(memory_format=torch.contiguous_format) + + B0, B1 = 3, 5 + x = torch.randn(2, B0, 7) + y = vmap(clone_contiguous, in_dims=1, out_dims=1)(x) + self.assertTrue(y.movedim(1, 0).is_contiguous()) + self.assertTrue(y[:, 0, :].is_contiguous()) + + x = torch.randn(2, B0, 7, B1) + y = vmap(vmap(clone_contiguous, in_dims=2), in_dims=1)(x) + self.assertTrue(y.is_contiguous()) + self.assertTrue(y[0][0].is_contiguous()) + + + msg = r'only supported with memory_format torch.preserve_format or torch.contiguous_format' + with self.assertRaisesRegex(RuntimeError, msg): + vmap(lambda x: x.clone(memory_format=torch.channels_last))(torch.randn(B0)) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(lambda x: x.clone(memory_format=torch.channels_last_3d))(torch.randn(B0)) + def test_binary_pointwise_ops(self): def get_number(getter): return getter([]).item() @@ -881,6 +1097,98 @@ def make_case(op, input_getter=TensorFactory.randn): # self._test_unary(lambda t: op(number, t), getter, device='cuda') # self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda') + def test_as_strided(self): + def _test(sizes, strides, offset, tensor, lambd): + result = vmap(lambda t: t.as_strided(sizes, strides, offset))(tensor) + expected = vmap(lambd)(tensor) + self.assertTrue(result._base is expected._base) + self.assertEqual(result, expected) + + # single vmap test + B0 = 5 + tensors = [ + # contiguous + torch.randn(B0, 2, 3), + # non-contiguous + torch.randn(B0, 3, 2).transpose(1, 2), + # non-zero storage offset + torch.randn(2, B0, 2, 3)[1], + # non-contiguous strides, zero storage offset + torch.randn(B0, 2, 4, 3, 7)[:, :, 0, :, 0], + # non-contiguous strides, non-zero storage offset + torch.randn(B0, 2, 4, 3, 7)[:, :, 2, :, 1], + ] + + for x in tensors: + S0, S1 = x.stride()[1:] + offset = x.storage_offset() + + # Broadcast + _test([5, 5, 2, 3], [0, 0, S0, S1], offset, x, lambda x: x.expand(5, 5, 2, 3)) + # transpose + _test([3, 2], [S1, S0], offset, x, lambda x: x.transpose(0, 1)) + # select + _test([2], [S0], offset + S1, x, lambda x: x[:, 1]) + + # Nested vmap test + B1 = 7 + x = torch.randn(B1, B0, 2, 3) + S0, S1 = x.stride()[2:] + result = vmap(vmap(lambda t: t.as_strided([5, 5, 2, 3], [0, 0, S0, S1])), in_dims=1)(x) + expected = vmap(vmap(lambda t: t.expand(5, 5, 2, 3)), in_dims=1)(x) + self.assertTrue(result._base is expected._base) + self.assertEqual(result, expected) + + # Check that mal-formatted size/strides doesn't crash + with self.assertRaisesRegex(RuntimeError, 'size and stride must have the same length'): + x = torch.randn(B0, 2, 3).transpose(0, 1) + vmap(lambda x: x.as_strided([1, 1, 1], [1, 1]))(x) + + # Sanity check #1: we require the batch dims to be at the front of the + # tensor (in memory layout). + msg = 'batch dims being vmapped over are at the front of the tensor' + with self.assertRaisesRegex(RuntimeError, msg): + x = torch.randn(2, B0, 3).transpose(0, 1) + vmap(lambda x: x.as_strided([2, 3], [B0 * 3, 1]))(x) + with self.assertRaisesRegex(RuntimeError, msg): + x = torch.randn(B0, 2, 3, B1).movedim(3, 1) + vmap(vmap(lambda x: x.as_strided([2, 3], [B1 * 3, B1])))(x) + + # All the Sanity check #2{a,b,c} cases check that + # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) + # doesn't index memory that is out of bounds of xs[i]. This condition + # is important to the correctness of the as_strided batching rule + # (see NOTE: [When will the as_strided_batching_rule fail?]) + + # Sanity check #2a: The maximum indexable location of + # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) + # is less than or equal to the maximum indexable location of xs[i]. + msg = 'This is not supported inside of vmap' + with self.assertRaisesRegex(RuntimeError, msg): + x = torch.randn(B0, 3) + vmap(lambda x: x.as_strided([3], [1], 1))(x) + with self.assertRaisesRegex(RuntimeError, msg): + x = torch.randn(B0, 3, 5) + vmap(lambda x: x.as_strided([4, 4], [4, 1], 0))(x) + with self.assertRaisesRegex(RuntimeError, msg): + x = torch.randn(B0, B1, 3, 5) + vmap(vmap(lambda x: x.as_strided([4, 4], [4, 1], 0)))(x) + + # Sanity check #2b: The min indexable location of + # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) + # is greater than or equal to the min indexable location of xs[i]. + with self.assertRaisesRegex(RuntimeError, msg): + x = torch.randn(2, B0, 3)[1] + vmap(lambda x: x.as_strided([3], [1], B0 * 3 - 1))(x) + + # Sanity check #2c: + # xs[i] is a zero-dim tensor, but + # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) + # is not + with self.assertRaisesRegex(RuntimeError, msg): + x = torch.randn(B0, 0, 3) + vmap(lambda x: x.as_strided([3], [1]))(x) + def test_bmm(self): op = torch.bmm test = self._vmap_test @@ -959,6 +1267,45 @@ def get(shape): result = vmap(op)(real_tensor) self.assertEqual(result.data_ptr(), real_tensor.data_ptr()) + def test_contiguous(self): + op = Tensor.contiguous + + self._test_unary(op, TensorFactory.randn, 'cpu') + + # check that contiguous returns the original tensor if the per-examples + # are already contiguous + B0 = 3 + x = torch.randn(B0, 2, 5, 7) + x = x.movedim(0, 2) + result = vmap(Tensor.contiguous, in_dims=2, out_dims=2)(x) + self.assertTrue(result is x) + + msg = 'NYI: querying is_contiguous inside of vmap for memory_format' + tensor = torch.randn(B0, 3) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(functools.partial(op, memory_format=torch.channels_last))(tensor) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(functools.partial(op, memory_format=torch.channels_last_3d))(tensor) + + def test_stride(self): + B0 = 3 + + x = torch.randn(B0, 2, 5, 7) + + def foo(x): + assert x.stride() == (7 * 5, 7, 1) + return x + + vmap(foo)(x) + + x = torch.randn(2, B0, 5, 7).movedim(1, 0) + + def bar(x): + assert x.stride() == (7 * 5 * B0, 7, 1) + return x + + vmap(bar)(x) + def test_chunk(self): test = self._vmap_view_test op = torch.chunk @@ -972,6 +1319,50 @@ def test_chunk(self): test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)), (torch.rand(B1, 2, B0, 64, B2),), in_dims=2) + def test_clamp(self): + clamp_cases = ( + (lambda t: t.clamp(min=-0.5), TensorFactory.randn), + (lambda t: t.clamp(max=0.5), TensorFactory.randn), + (lambda t: t.clamp(min=-0.5, max=0.5), TensorFactory.randn), + (lambda t: t.clamp_min(min=-0.5), TensorFactory.randn), + (lambda t: t.clamp_max(max=0.5), TensorFactory.randn), + ) + for op, getter in clamp_cases: + self._test_unary(op, getter, 'cpu') + + def test_comparison_ops(self): + test = functools.partial(self._vmap_test, check_propagates_grad=False) + + getter = TensorFactory.randn + B0, B1 = 7, 11 + + ops = ( + torch.eq, lambda x, y: x == y, + torch.gt, lambda x, y: x > y, + torch.ge, lambda x, y: x >= y, + torch.le, lambda x, y: x <= y, + torch.lt, lambda x, y: x < y, + torch.ne, lambda x, y: x != y, + ) + + for op in ops: + # Single vmap: op(Tensor, Tensor) + test(op, (getter([B0, 3]), getter([B0, 3]))) + test(op, (getter([B0]), getter([B0, 2, 3]))) + test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1)) + test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1), out_dims=1) + test(op, (getter([B0]), getter([2, 3])), in_dims=(0, None)) + test(op, (getter([2, 3]), getter([B0, 3])), in_dims=(0, None)) + + # Nested vmap: op(Tensor, Tensor) + test(vmap(op), (getter([B0, B1, 2, 3]), getter([B0, B1, 3]))) + test(vmap(op, in_dims=(None, 0)), + (getter([B0, 2, 3]), getter([B1, 3])), in_dims=(0, None)) + + # test number as inputs + number = getter([]).item() + self._test_unary(lambda t: op(t, number), getter, 'cpu', check_propagates_grad=False) + def test_diagonal(self): tensor = torch.randn(3, 5, 7, 11, 13) test = self._vmap_view_test @@ -1026,6 +1417,124 @@ def test_expand_as(self): test(vmap(op), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None)) test(vmap(vmap(op)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5))) + def test_fill_and_zero_inplace(self): + test = functools.partial(self._vmap_test, check_propagates_grad=False) + B0, B1 = 7, 11 + ops = ( + lambda t: t.fill_(0.1), + lambda t: t.fill_(torch.tensor(0.2)), + lambda t: t.zero_(), + ) + + for op in ops: + # Single vmap, various in_dims / out_dims + test(op, [TensorFactory.randn([B0, 3])]) + test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2) + test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2, out_dims=2) + + # Doubly nested vmap + test(vmap(op), [TensorFactory.randn([B0, B1])]) + test(vmap(op), [TensorFactory.randn([B1, 2, 5, B0, 3])], in_dims=2) + test(vmap(op, in_dims=2), [TensorFactory.randn([2, 5, B0, B1, 3])], + in_dims=2, out_dims=2) + + # test when value is a batched tensor for fill_ operator + B0, B1 = 3, 5 + test(Tensor.fill_, [TensorFactory.randn([B0, B1]), TensorFactory.randn(B0)]) + + with self.assertRaisesRegex(RuntimeError, + r"output with shape .+ doesn't match the broadcast shape"): + # Runtime Error is thrown when the tensor being written to isn't being vmapped over + vmap(Tensor.fill_, (None, 0))(TensorFactory.randn([B0, B1]), + TensorFactory.randn([B0])) + + def _test_complex_views(self, op, dtypes): + test = self._vmap_view_test + + def run_test(op, dtype): + def get(shape): + return torch.randn(shape, dtype=dtype) + + B0, B1 = 7, 11 + + # Single vmap, various in_dims / out_dims + test(op, [get([B0, 3])]) + test(op, [get([3, B0])], in_dims=1) + test(op, [get([2, 5, B0, 3])], in_dims=2) + test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2) + + # Doubly nested vmap + test(vmap(op), [get([B0, B1])]) + test(vmap(op), [get([B1, 2, 5, 3, B0])], in_dims=4) + test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], + in_dims=2, out_dims=2) + + for dtype in dtypes: + run_test(op, dtype) + + def test_real(self): + self._test_complex_views(torch.real, dtypes=[torch.cfloat, torch.cdouble]) + + def test_imag(self): + self._test_complex_views(torch.imag, dtypes=[torch.cfloat, torch.cdouble]) + + def test_view_as_real(self): + self._test_complex_views(torch.view_as_real, dtypes=[torch.cfloat, torch.cdouble]) + + def test_view_as_complex(self): + def run_test(dtype): + def get(shape): + return torch.randn(shape, dtype=dtype) + + op = torch.view_as_complex + test = self._vmap_view_test + B0, B1 = 7, 11 + + # Single vmap, various in_dims / out_dims + test(op, [get([B0, 3, 2])]) + test(op, [get([2, 5, B0, 3, 2])], in_dims=2) + test(op, [get([2, 5, B0, 3, 2])], in_dims=2, out_dims=2) + + # Doubly nested vmap + test(vmap(op), [get([B0, B1, 2])]) + test(vmap(op), [get([B1, 2, 5, B0, 3, 2])], in_dims=2) + test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3, 2])], + in_dims=2, out_dims=2) + + # Interesting case #1: Batch dim directly before dim of size 2 + test(op, [get([3, B0, 2])], in_dims=1) + test(vmap(op, in_dims=1), [get([3, B1, B0, 2])], in_dims=2) + + # Interesting case #2: Batch dim at end of tensor, success cases + # view_as_complex requires that the dim with size 2 have stride 1 + # in order for the view to function propertly + test(op, [get([B0, 2]).transpose(0, 1)], in_dims=1) + test(vmap(op, in_dims=1), [get([B0, B1, 2]).movedim(1, 2)]) + test(vmap(op, in_dims=2), [get([B0, 3, B1, 2]).movedim(2, 3)]) + + # Interesting case #3: Batch dim at end of tensor, failure cases + msg = "Tensor must have a last dimension with stride 1" + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, in_dims=1)(get([2, B0])) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(vmap(op, in_dims=1), in_dims=1)(get([2, B0, B1])) + + # Invalid input: no dimension of size 2 + msg = 'Input tensor must have one or more dimensions' + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op)(get([B0])) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(vmap(op))(get([B0, B1])) + + # Invalid input: Batch dim has size 2, but the logical last dim does + # not have size 2 + msg = 'Tensor must have a last dimension of size 2' + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, in_dims=1)(get([3, 2])) + + for dtype in [torch.float, torch.double]: + run_test(dtype) + def test_is_complex(self): ctensor = torch.randn(3, dtype=torch.cfloat) tensor = torch.randn(3) @@ -1039,6 +1548,62 @@ def foo(x): self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1])) self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0])) + def test_is_contiguous(self): + def foo(x): + if x.is_contiguous(): + return torch.tensor(1.) + else: + return torch.tensor(0.) + + B0, B1 = 3, 5 + + # Single batch dim + contig = torch.randn(B0, 2, 7) + self.assertEqual(vmap(foo)(contig), torch.ones(B0)) + + noncontig = torch.randn(2, B0, 7) + self.assertEqual(vmap(foo, in_dims=1)(noncontig), torch.zeros(B0)) + + noncontig = torch.randn(2, B0, 7).movedim(1, 0) + self.assertEqual(vmap(foo)(noncontig), torch.zeros(B0)) + + noncontig = torch.randn(2, 7, B0) + self.assertEqual(vmap(foo, in_dims=2)(noncontig), torch.zeros(B0)) + + # Multiple batch dims + contig = torch.randn(B0, B1, 3) + self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1)) + + contig = torch.randn(B1, B0, 3) + self.assertEqual(vmap(vmap(foo), in_dims=1)(contig), torch.ones(B0, B1)) + + contig = torch.randn(B1, B0, 3).movedim(0, 1) + self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1)) + + noncontig = torch.randn(B0, 3, B1) + self.assertEqual(vmap(vmap(foo, in_dims=1))(noncontig), torch.zeros(B0, B1)) + + # is_contiguous on empty tensor is True + def bar(x): + assert x.is_contiguous() + return x + + vmap(bar)(torch.randn(B0, 0, 3)) + vmap(bar, in_dims=1)(torch.randn(0, B0, 3)) + vmap(bar)(torch.randn(B0, 0, 3).transpose(-1, -2)) + + # is_contiguous with other memory formats + def baz(x, memory_format): + x.is_contiguous(memory_format=memory_format) + return x + + msg = 'NYI: querying is_contiguous inside of vmap for memory_format' + tensor = torch.randn(B0, 2, 7, 3) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(functools.partial(baz, memory_format=torch.channels_last))(tensor) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor) + def test_movedim(self): op = torch.movedim test = self._vmap_view_test @@ -1131,6 +1696,72 @@ def test_narrow(self): test(vmap(vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)), (torch.rand(B1, 2, B0, 5, B2), -1, 2, 3), in_dims=(2, None, None, None)) + def test_new_empty(self): + # Empty is non-deterministic so we just check that the shape of the + # output tensor is what we expect and that the vmap fallback isn't used. + op = Tensor.new_empty + + B0, B1 = 7, 11 + + result = vmap(lambda x: op(x, [2, 3]))(torch.randn(B0)) + self.assertEqual(result.shape, [B0, 2, 3]) + + result = vmap(lambda x: op(x, []))(torch.randn(B0)) + self.assertEqual(result.shape, [B0]) + + result = vmap(vmap(lambda x: op(x, [2, 3])))(torch.randn(B0, B1)) + self.assertEqual(result.shape, [B0, B1, 2, 3]) + + def test_new_empty_strided(self): + # Empty is non-deterministic so we just check that the size and shape + # of the output are what we expect and that the vmap fallback isn't used + B0, B1 = 7, 11 + + def _test_single_vmap(size, stride, B0): + x = torch.randn(B0) + result = vmap(lambda x: x.new_empty_strided(size, stride))(x) + S = torch.empty_strided(size, stride).storage().size() + self.assertEqual(result.shape, [B0] + size) + self.assertEqual(result.stride(), [S] + stride) + + def _test_double_vmap(size, stride, B0, B1): + x = torch.randn(B0, B1) + result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)))(x) + S = torch.empty_strided(size, stride).storage().size() + self.assertEqual(result.shape, [B0, B1] + size) + self.assertEqual(result.stride(), [B1 * S, S] + stride) + + x = torch.randn(B1, B0) + result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)), in_dims=1)(x) + S = x.new_empty_strided(size, stride).storage().size() + self.assertEqual(result.shape, [B0, B1] + size) + self.assertEqual(result.stride(), [B1 * S, S] + stride) + + # contiguous case + _test_single_vmap([2, 3, 5], [3 * 5, 5, 1], B0) + _test_double_vmap([2, 3, 5], [3 * 5, 5, 1], B0, B1) + + # expanded + _test_single_vmap([2, 3, 5], [0, 5, 1], B0) + _test_double_vmap([2, 3, 5], [0, 5, 1], B0, B1) + + # some of these cases are pretty strange, just verifying that if + # empty_strided allows them then BatchedTensor.new_empty_strided + # can as well + for shape in [[2, 3, 4], [0, 2, 0]]: + for strides in [[12, 4, 1], [2, 4, 6], [0, 0, 0]]: + _test_single_vmap(shape, strides, B0) + _test_double_vmap(shape, strides, B0, B1) + + def test_new_zeros(self): + op = Tensor.new_zeros + test = functools.partial(self._vmap_test, check_propagates_grad=False) + B0, B1 = 7, 11 + + test(lambda x: op(x, 2, 3), (torch.rand(B0),)) + test(lambda x: op(x, []), (torch.rand(B0),)) + test(vmap(lambda x: op(x, 3, 5)), (torch.rand(B0, B1),)) + def test_select(self): op = torch.select test = self._vmap_view_test @@ -1169,6 +1800,36 @@ def test_slice(self): test(vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2), (torch.rand(3, 5, B0, B1, B2),), in_dims=2) + def test_squeeze(self): + test = self._vmap_view_test + op = torch.squeeze + B0, B1 = 1, 11 + test(op, (torch.rand(B0),)) + test(op, (torch.rand(B0, 3, 5),)) + test(op, (torch.rand(1, B0, 5),), in_dims=1) + test(op, (torch.rand(B0, 0, 1, 5, 1),)) + test(op, (torch.rand(B0, 1, 1, 1, 1),)) + test(vmap(op), (torch.rand(B0, B1, 1),)) + test(vmap(op), (torch.rand(B1, 1, B0),), in_dims=2) + + def test_sum_dim(self): + test = self._vmap_test + B0, B1 = 5, 7 + + # Single vmap, various in_dims / out_dims + test(lambda x: x.sum(0), [torch.randn([B0])]) + test(lambda x: x.sum(-1), [torch.randn([B0])]) + test(lambda x: x.sum(0), [torch.randn([B0, 3])]) + test(lambda x: x.sum(-1), [torch.randn([2, 5, B0, 3])], in_dims=2) + test(lambda x: x.sum(2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2) + + # Doubly nested vmap + test(vmap(lambda x: x.sum(0)), [torch.randn([B0, B1])]) + test(vmap(lambda x: x.sum(-1)), [torch.randn([B0, B1])]) + test(vmap(lambda x: x.sum(-2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2) + test(vmap(lambda x: x.sum(2), in_dims=2), [torch.randn([2, 5, B0, B1, 3])], + in_dims=2, out_dims=2) + def test_reshape(self): test = self._vmap_test B0, B1, B2 = 7, 11, 13 @@ -1237,6 +1898,27 @@ def wrapped(*args, **kwargs): test(op, (torch.randn(B0, 2), torch.randint(10, [B0], dtype=torch.int64)), check_propagates_grad=False) + def test_tensor_split(self): + test = self._vmap_view_test + op = torch.tensor_split + B0, B1, B2 = 7, 11, 13 + + # tests for torch.tensor_split(self, indices_or_sections: int, dim) + test(op, (torch.rand(B0, 2, 1024), 5, -1), in_dims=(0, None, None)) + test(op, (torch.rand(2, B0, 1024), 150, 1), in_dims=(1, None, None)) + test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 256, 0), + in_dims=(2, None, None)) + test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)), + (torch.rand(B1, 2, B0, 64, B2),), in_dims=2) + + # tests for torch.tensor_split(self, indices_or_sections: List[int], dim) + test(op, (torch.rand(B0, 2, 1024), [50, 100, 378, 890], -1), in_dims=(0, None, None)) + test(op, (torch.rand(2, B0, 1024), [50, 100, 212, 345, 0, 378, 890], 1), in_dims=(1, None, None)) + test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), [50, 100, 212, 345, 0, 378, 890], 0), + in_dims=(2, None, None)) + test(vmap(vmap(lambda t: op(t, [4, 8, 9, 34, 29], 1), in_dims=2)), + (torch.rand(B1, 2, B0, 64, B2),), in_dims=2) + def test_split(self): test = self._vmap_view_test op = torch.split @@ -1258,6 +1940,35 @@ def test_split(self): test(vmap(vmap(lambda t: op(t, [4] * 8 + [8] * 4, 1), in_dims=2)), (torch.rand(B1, 2, B0, 64, B2),), in_dims=2) + def test_trace(self): + op = torch.trace + test = self._vmap_test + B0, B1, B2 = 7, 11, 13 + + test(op, (torch.rand(B0, 2, 5),)) + test(op, (torch.rand(2, B0, 5),), in_dims=1) + test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2) + test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2) + + def test_transpose(self): + op = torch.transpose + test = self._vmap_view_test + + B0, B1, B2 = 7, 11, 13 + test(lambda x: op(x, 0, 1), (torch.rand(B0, 2, 5),)) + test(lambda x: op(x, -1, -2), (torch.rand(B0, 2, 5),)) + test(lambda x: op(x, 3, 1), (torch.rand(B0, 2, 5, 4, 6),)) + test(lambda x: op(x, 1, 0), (torch.rand(2, B0, 5),), in_dims=1) + test(vmap(lambda x: op(x, 0, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2) + test(vmap(vmap(lambda x: op(x, 0, 1), in_dims=2)), + (torch.rand(B1, 2, B0, 5, B2),), in_dims=2) + + # Special case: scalar tensor + for dim1, dim2 in itertools.product([0, -1], [0, -1]): + x = torch.rand(B0) + result = vmap(lambda x: op(x, dim1, dim2))(x) + self.assertTrue(result is x) + def test_t(self): op = torch.t test = self._vmap_view_test @@ -1458,7 +2169,9 @@ def _vmap_test(self, *args, **kwargs): # output_process_fn: a function that maps the outputs to the part # that should be differentiated. # batch_size: the batch dim size for the batched grad - def _batched_grad_test(self, op, args, kwargs, output_process_fn=lambda x: x, batch_size=3): + def _batched_grad_test(self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3): + if kwargs is None: + kwargs = {} outputs = op(*args, **kwargs) outputs = differentiable(output_process_fn(outputs)) batched_vectors = tuple(construct_v(out, batch_size) for out in outputs) @@ -1482,7 +2195,9 @@ def vector_jacobian_product(*vectors): # Regression. # It might be useful to have a test that computes batched first gradients and # then uses those to compute batched second gradients in the future. - def _batched_grad_grad_test(self, op, args, kwargs, output_process_fn=lambda x: x, batch_size=3): + def _batched_grad_grad_test(self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3): + if kwargs is None: + kwargs = {} outputs = op(*args, **kwargs) outputs = differentiable(output_process_fn(outputs)) ones = tuple(torch.ones_like(out) for out in outputs) @@ -1509,12 +2224,12 @@ def _test_arithmetic(self, op, device, test_grad_grad=True): x = torch.randn(2, 3, requires_grad=True, device=device) y = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) scalar = 3.14 - self._batched_grad_test(op, (x, y), {}) - self._batched_grad_test(op, (scalar, y), {}) - self._batched_grad_test(op, (x, scalar), {}) + self._batched_grad_test(op, (x, y)) + self._batched_grad_test(op, (scalar, y)) + self._batched_grad_test(op, (x, scalar)) if test_grad_grad: - self._batched_grad_grad_test(op, (x, y), {}) + self._batched_grad_grad_test(op, (x, y)) def test_add(self, device): self._test_arithmetic(torch.add, device, test_grad_grad=False) @@ -1532,22 +2247,44 @@ def test_div(self, device): self._test_arithmetic(torch.div, device) self._test_arithmetic(lambda x, y: x / y, device) + @allowVmapFallbackUsage + def test_binary_cross_entropy(self, device): + x = F.sigmoid(torch.randn(3, 2, device=device, requires_grad=True)) + target = torch.rand(3, 2, device=device) + + op = functools.partial(F.binary_cross_entropy, target=target) + + self._batched_grad_test(op, (x,), {}) + self._batched_grad_grad_test(op, (x,), {}) + def test_expand(self, device): x = torch.randn(2, 3, device=device, requires_grad=True) def op(x): return x.expand(5, 5, 2, 3) - self._batched_grad_test(op, (x,), {}) + self._batched_grad_test(op, (x,)) + + @allowVmapFallbackUsage + def test_index(self, device): + x = torch.randn(2, 3, requires_grad=True, device=device) + index = torch.tensor([[0, 0], [1, 1]], device=device) + + def op(x): + y = x * x + return y[index] + + self._batched_grad_test(op, (x,)) + self._batched_grad_grad_test(op, (x,)) def test_lgamma(self, device): x = torch.randn(2, 3, requires_grad=True, device=device) - self._batched_grad_test(Tensor.lgamma, (x,), {}) - self._batched_grad_grad_test(Tensor.lgamma, (x,), {}) + self._batched_grad_test(Tensor.lgamma, (x,)) + self._batched_grad_grad_test(Tensor.lgamma, (x,)) def test_log(self, device): x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) - self._batched_grad_test(torch.log, (x,), {}) - self._batched_grad_grad_test(torch.log, (x,), {}) + self._batched_grad_test(torch.log, (x,)) + self._batched_grad_grad_test(torch.log, (x,)) def test_logsumexp(self, device): x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) @@ -1555,13 +2292,28 @@ def test_logsumexp(self, device): def op(x): return torch.logsumexp(x, -1) - self._batched_grad_test(op, (x,), {}) - self._batched_grad_grad_test(op, (x,), {}) + self._batched_grad_test(op, (x,)) + self._batched_grad_grad_test(op, (x,)) def test_log1p(self, device): x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) - self._batched_grad_test(torch.log1p, (x,), {}) - self._batched_grad_grad_test(torch.log1p, (x,), {}) + self._batched_grad_test(torch.log1p, (x,)) + self._batched_grad_grad_test(torch.log1p, (x,)) + + @allowVmapFallbackUsage + def test_max(self, device): + x = torch.randn(2, 3, requires_grad=True, device=device) + self._batched_grad_test(torch.max, (x,)) + + @allowVmapFallbackUsage + def test_median(self, device): + x = torch.randn(2, 3, requires_grad=True, device=device) + self._batched_grad_test(torch.median, (x,)) + + @allowVmapFallbackUsage + def test_min(self, device): + x = torch.randn(2, 3, requires_grad=True, device=device) + self._batched_grad_test(torch.min, (x,)) def test_permute(self, device): x = torch.randn(2, 3, 5, requires_grad=True, device=device) @@ -1569,7 +2321,7 @@ def test_permute(self, device): def op(x): return x.permute(2, 0, 1) - self._batched_grad_test(op, (x,), {}) + self._batched_grad_test(op, (x,)) def test_reshape(self, device): x = torch.randn(2, 3, 5, requires_grad=True, device=device) @@ -1577,12 +2329,12 @@ def test_reshape(self, device): def op(x): return x.reshape([2 * 3, 5]) - self._batched_grad_test(op, (x,), {}) + self._batched_grad_test(op, (x,)) def test_sigmoid(self, device): x = torch.randn(2, 3, requires_grad=True, device=device) - self._batched_grad_test(Tensor.sigmoid, (x,), {}) - self._batched_grad_grad_test(Tensor.sigmoid, (x,), {}) + self._batched_grad_test(Tensor.sigmoid, (x,)) + self._batched_grad_grad_test(Tensor.sigmoid, (x,)) def test_stack(self, device): x = torch.randn(2, 3, device=device, requires_grad=True) @@ -1590,26 +2342,104 @@ def test_stack(self, device): def op(x, y): return torch.stack([x, y]) - self._batched_grad_test(op, (x, y), {}) + self._batched_grad_test(op, (x, y)) def test_select(self, device): x = torch.randn(2, 3, device=device, requires_grad=True) - self._batched_grad_test(lambda x: x[1], (x,), {}) - self._batched_grad_test(lambda x: x.select(1, 2), (x,), {}) - self._batched_grad_test(lambda x: x.select(-1, 0), (x,), {}) + self._batched_grad_test(lambda x: x[1], (x,)) + self._batched_grad_test(lambda x: x.select(1, 2), (x,)) + self._batched_grad_test(lambda x: x.select(-1, 0), (x,)) def test_slice(self, device): x = torch.randn(2, 3, 5, device=device, requires_grad=True) - self._batched_grad_test(lambda x: x[0:1], (x,), {}) - self._batched_grad_test(lambda x: x[:, 1:3], (x,), {}) - self._batched_grad_test(lambda x: x[..., 1:3], (x,), {}) + self._batched_grad_test(lambda x: x[0:1], (x,)) + self._batched_grad_test(lambda x: x[:, 1:3], (x,)) + self._batched_grad_test(lambda x: x[..., 1:3], (x,)) + + def test_trace(self, device): + x = torch.randn(2, 3, device=device, requires_grad=True) + self._batched_grad_test(Tensor.trace, (x,)) + + @allowVmapFallbackUsage + def test_symeig(self, device): + def op(x): + return torch.symeig(x, eigenvectors=True)[0] + + x = torch.randn(3, 3, device=device, requires_grad=True) + self._batched_grad_test(op, (x,), {}) + self._batched_grad_grad_test(op, (x,), {}) + + def test_threshold(self, device): + x = torch.randn(2, 3, device=device, requires_grad=True) + self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,)) + + + @allowVmapFallbackUsage + def test_inplace_view(self, device): + leaf = torch.randn(4, 5, requires_grad=True) + + def func(leaf): + # Make sure the function is non-trivially twice differentiable + base = leaf * leaf + view = base[0] + view.cos_() + return view + + self._batched_grad_test(func, (leaf,), {}) + self._batched_grad_grad_test(func, (leaf,), {}) + + @allowVmapFallbackUsage + def test_inplace_manyview(self, device): + leaf = torch.randn(4, 4, 5, requires_grad=True) + + def func(leaf): + # Make sure the function is non-trivially twice differentiable + base = leaf * leaf + view = base.transpose(0, 2) + view = view[1] + view = view.diagonal() + view = view[::2] + view.cos_() + return view + + self._batched_grad_test(func, (leaf,), {}) + self._batched_grad_grad_test(func, (leaf,), {}) def test_diagonal(self, device): x = torch.randn(4, 5, device=device, requires_grad=True) - self._batched_grad_test(lambda x: x.diagonal(1, 0, 1), (x,), {}) + self._batched_grad_test(lambda x: x.diagonal(1, 0, 1), (x,)) x = torch.randn(3, 4, 5, device=device, requires_grad=True) - self._batched_grad_test(lambda x: x.diagonal(0, -1, -2), (x,), {}) + self._batched_grad_test(lambda x: x.diagonal(0, -1, -2), (x,)) + + @allowVmapFallbackUsage + def test_unrelated_output(self, device): + B0 = 3 + x = torch.randn([], requires_grad=True) + y = torch.randn([], requires_grad=True) + gy = torch.randn(B0, requires_grad=True) + + def vjp(v): + res, = torch.autograd.grad(y, x, v, allow_unused=True) + return torch.zeros_like(x) if res is None else res + + result = vmap(vjp)(gy) + self.assertEqual(result, torch.zeros(B0, *x.shape, device=device)) + + @allowVmapFallbackUsage + def test_unrelated_output_multiple_grad(self, device): + B0 = 3 + x = torch.randn([], requires_grad=True) + y = torch.randn([], requires_grad=True) + gy = torch.randn(B0, requires_grad=True) + + def vjp(v): + res, = torch.autograd.grad(y, x, v, allow_unused=True) + return torch.zeros_like(x) if res is None else res + + _ = vjp(gy[0]) + result = vmap(vjp)(gy) + self.assertEqual(result, torch.zeros(B0, *x.shape, device=device)) instantiate_device_type_tests( TestVmapBatchedGradient, diff --git a/test/test_xnnpack_integration.py b/test/test_xnnpack_integration.py index a40ec48f2f371..1e9bb617425a5 100644 --- a/test/test_xnnpack_integration.py +++ b/test/test_xnnpack_integration.py @@ -6,16 +6,18 @@ from torch.utils.mobile_optimizer import optimize_for_mobile from torch.testing import FileCheck import torch.testing._internal.hypothesis_utils as hu -from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing._internal.common_utils import TestCase, run_tests, slowTest from hypothesis import given, assume from hypothesis import strategies as st import io import itertools +from torch.testing._internal.common_utils import TEST_WITH_TSAN @unittest.skipUnless(torch.backends.xnnpack.enabled, " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.") +@unittest.skipIf(TEST_WITH_TSAN, "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.") class TestXNNPACKOps(TestCase): @given(batch_size=st.integers(0, 3), data_shape=hu.array_shapes(1, 3, 2, 64), @@ -161,6 +163,7 @@ def test_conv2d_transpose(self, @unittest.skipUnless(torch.backends.xnnpack.enabled, " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.") +@unittest.skipIf(TEST_WITH_TSAN, "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.") class TestXNNPACKSerDes(TestCase): @given(batch_size=st.integers(0, 3), data_shape=hu.array_shapes(1, 3, 2, 64), @@ -551,6 +554,7 @@ def forward(self, x): @unittest.skipUnless(torch.backends.xnnpack.enabled, " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.") +@unittest.skipIf(TEST_WITH_TSAN, "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.") class TestXNNPACKRewritePass(TestCase): @staticmethod def validate_transformed_module( @@ -911,6 +915,7 @@ def forward(self, x): @unittest.skipUnless(torch.backends.xnnpack.enabled, " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.") +@unittest.skipIf(TEST_WITH_TSAN, "TSAN is not fork-safe since we're forking in a multi-threaded environment") class TestXNNPACKConv1dTransformPass(TestCase): @staticmethod def validate_transform_conv1d_to_conv2d( @@ -1017,6 +1022,8 @@ def forward(self, x): pattern_count_optimized_map, data_shape) + # See https://github.com/pytorch/pytorch/issues/46066 + @slowTest def test_conv1d_with_relu_fc(self): batch_size_list = range(1, 3) input_channels_per_group_list = range(10, 12) diff --git a/test/type_hint_tests/module_list.py b/test/type_hint_tests/module_list.py index c31f8185a2083..14659f7c00a63 100644 --- a/test/type_hint_tests/module_list.py +++ b/test/type_hint_tests/module_list.py @@ -8,4 +8,4 @@ class BarModule(torch.nn.Module): pass ml: torch.nn.ModuleList = torch.nn.ModuleList([FooModule(), BarModule()]) -ml[0].children() == [] +ml[0].children() == [] # noqa: B015 diff --git a/test/type_hint_tests/namedtuple.py b/test/type_hint_tests/namedtuple.py index 6e4b60e181946..3d6eac372882d 100644 --- a/test/type_hint_tests/namedtuple.py +++ b/test/type_hint_tests/namedtuple.py @@ -4,10 +4,10 @@ t = torch.tensor([[3.0, 1.5], [2.0, 1.5]]) t_sort = t.sort() -t_sort[0][0, 0] == 1.5 -t_sort.indices[0, 0] == 1 -t_sort.values[0, 0] == 1.5 +t_sort[0][0, 0] == 1.5 # noqa: B015 +t_sort.indices[0, 0] == 1 # noqa: B015 +t_sort.values[0, 0] == 1.5 # noqa: B015 t_qr = torch.qr(t) -t_qr[0].shape == [2, 2] -t_qr.Q.shape == [2, 2] +t_qr[0].shape == [2, 2] # noqa: B015 +t_qr.Q.shape == [2, 2] # noqa: B015 diff --git a/test/type_hint_tests/opt_size.py b/test/type_hint_tests/opt_size.py new file mode 100644 index 0000000000000..f24e57e6e56f9 --- /dev/null +++ b/test/type_hint_tests/opt_size.py @@ -0,0 +1,6 @@ +import torch.nn as nn + +avg_pool1 = nn.AdaptiveAvgPool2d((1, None)) +avg_pool2 = nn.AdaptiveAvgPool2d((None, 1)) +max_pool1 = nn.AdaptiveMaxPool2d((1, None)) +max_pool2 = nn.AdaptiveMaxPool2d((None, 1)) diff --git a/third_party/NNPACK b/third_party/NNPACK index 24b55303f5cf6..c07e3a0400713 160000 --- a/third_party/NNPACK +++ b/third_party/NNPACK @@ -1 +1 @@ -Subproject commit 24b55303f5cf65d75844714513a0d1b1409809bd +Subproject commit c07e3a0400713d546e0dea2d5466dd22ea389c73 diff --git a/third_party/XNNPACK b/third_party/XNNPACK index 1b354636b5942..e1ffe154593a2 160000 --- a/third_party/XNNPACK +++ b/third_party/XNNPACK @@ -1 +1 @@ -Subproject commit 1b354636b5942826547055252f3b359b54acff95 +Subproject commit e1ffe154593a2e6714d3d2370739cf6fea1055c6 diff --git a/third_party/cpuinfo b/third_party/cpuinfo index 63b254577ed77..5916273f79a21 160000 --- a/third_party/cpuinfo +++ b/third_party/cpuinfo @@ -1 +1 @@ -Subproject commit 63b254577ed77a8004a9be6ac707f3dccc4e1fd9 +Subproject commit 5916273f79a21551890fd3d56fc5375a78d1598d diff --git a/third_party/fbgemm b/third_party/fbgemm index 1d710393d5b75..9b0131179f293 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 1d710393d5b7588f5de3b83f51c22bbddf095229 +Subproject commit 9b0131179f293a645bfd3409cd66fa5eecc393b0 diff --git a/third_party/fmt b/third_party/fmt index 9bdd1596cef1b..cd4af11efc9c6 160000 --- a/third_party/fmt +++ b/third_party/fmt @@ -1 +1 @@ -Subproject commit 9bdd1596cef1b57b9556f8bef32dc4a32322ef3e +Subproject commit cd4af11efc9c622896a3e4cb599fa28668ca3d05 diff --git a/third_party/foxi b/third_party/foxi index 4aba696ec8f31..6a4e19a2aaf7a 160000 --- a/third_party/foxi +++ b/third_party/foxi @@ -1 +1 @@ -Subproject commit 4aba696ec8f31794fd42880346dc586486205e0a +Subproject commit 6a4e19a2aaf7ae4b9fa9597526e65b395d5e79ad diff --git a/third_party/ideep b/third_party/ideep index ba885200dbbc1..f0280bb805c2d 160000 --- a/third_party/ideep +++ b/third_party/ideep @@ -1 +1 @@ -Subproject commit ba885200dbbc1f144c7b58eba487378eb324f281 +Subproject commit f0280bb805c2dedd4bb5dd4765cda7dfcd30989f diff --git a/third_party/kineto b/third_party/kineto new file mode 160000 index 0000000000000..e9198dd3066ee --- /dev/null +++ b/third_party/kineto @@ -0,0 +1 @@ +Subproject commit e9198dd3066ee6e5e20201d6ae6f86f092bb7123 diff --git a/third_party/mkl-dnn.BUILD b/third_party/mkl-dnn.BUILD index c4491b10e1113..9af253fde189a 100644 --- a/third_party/mkl-dnn.BUILD +++ b/third_party/mkl-dnn.BUILD @@ -7,9 +7,9 @@ template_rule( out = "include/dnnl_version.h", substitutions = { "@DNNL_VERSION_MAJOR@": "1", - "@DNNL_VERSION_MINOR@": "5", + "@DNNL_VERSION_MINOR@": "7", "@DNNL_VERSION_PATCH@": "0", - "@DNNL_VERSION_HASH@": "e2ac1fac44c5078ca927cb9b90e1b3066a0b2ed0", + "@DNNL_VERSION_HASH@": "2e4732679f0211bb311780d0f383cf2dce9baca7", }, ) @@ -30,6 +30,8 @@ cc_library( srcs = glob([ "src/common/*.cpp", "src/cpu/**/*.cpp", + ], exclude=[ + "src/cpu/aarch64/*.cpp", ]), hdrs = glob([ "include/*.h", @@ -38,7 +40,8 @@ cc_library( "src/cpu/**/*.hpp", "src/cpu/**/*.h", "src/common/*.hpp", - "src/cpu/rnn/*.hpp", + ], exclude=[ + "src/cpu/aarch64/*.hpp", ]) + [ "include/dnnl_version.h", "include/dnnl_config.h", diff --git a/third_party/onnx b/third_party/onnx index a82c6a7010e2e..54c38e6eaf557 160000 --- a/third_party/onnx +++ b/third_party/onnx @@ -1 +1 @@ -Subproject commit a82c6a7010e2e332d8f74ad5b0c726fd47c85376 +Subproject commit 54c38e6eaf557b844e70cebc00f39ced3321e9ad diff --git a/third_party/onnx.BUILD b/third_party/onnx.BUILD index 58e2c797bc38d..3079947e15dd5 100644 --- a/third_party/onnx.BUILD +++ b/third_party/onnx.BUILD @@ -8,6 +8,7 @@ py_binary( data = [ "onnx/onnx.in.proto", "onnx/onnx-operators.in.proto", + "onnx/onnx-data.in.proto", ], ) @@ -31,6 +32,16 @@ genrule( tools = [":gen_proto"], ) +genrule( + name = "generate_onnx_data_proto", + outs = [ + "onnx/onnx-data_onnx_torch.proto", + "onnx/onnx-data.pb.h", + ], + cmd = "$(location :gen_proto) -p onnx_torch -o $(@D)/onnx onnx-data -m >/dev/null && sed -i 's/onnx-data_onnx_torch.pb.h/onnx\\/onnx-data_onnx_torch.pb.h/g' $(@D)/onnx/onnx-data.pb.h", + tools = [":gen_proto"], +) + cc_library( name = "onnx", srcs = glob( @@ -73,6 +84,7 @@ cc_library( ]) + [ "onnx/onnx-ml.pb.h", "onnx/onnx-operators-ml.pb.h", + "onnx/onnx-data.pb.h", ], defines = [ "ONNX_ML=1", @@ -104,6 +116,7 @@ proto_library( srcs = [ "onnx/onnx-operators_onnx_torch-ml.proto", "onnx/onnx_onnx_torch-ml.proto", + "onnx/onnx-data_onnx_torch.proto", ], ) diff --git a/third_party/pthreadpool b/third_party/pthreadpool index 029c88620802e..fa75e65a58a5c 160000 --- a/third_party/pthreadpool +++ b/third_party/pthreadpool @@ -1 +1 @@ -Subproject commit 029c88620802e1361ccf41d1970bd5b07fd6b7bb +Subproject commit fa75e65a58a5c70c09c30d17a1fe1c1dff1093ae diff --git a/third_party/pybind11 b/third_party/pybind11 index 25abf7efba0b2..a1cb7c23d3b47 160000 --- a/third_party/pybind11 +++ b/third_party/pybind11 @@ -1 +1 @@ -Subproject commit 25abf7efba0b2990f5a6dfb0a31bc65c0f2f4d17 +Subproject commit a1cb7c23d3b47a2bca24b136e5222e497de9575a diff --git a/third_party/sleef b/third_party/sleef index 7f523de651585..e0a003ee838b7 160000 --- a/third_party/sleef +++ b/third_party/sleef @@ -1 +1 @@ -Subproject commit 7f523de651585fe25cade462efccca647dcc8d02 +Subproject commit e0a003ee838b75d11763aa9c3ef17bf71a725bff diff --git a/third_party/sleef.BUILD b/third_party/sleef.BUILD index 4466632fa895f..2db72c9d4ed47 100644 --- a/third_party/sleef.BUILD +++ b/third_party/sleef.BUILD @@ -178,12 +178,12 @@ genrule( genrule( name = "sleef_h", srcs = [ - "src/libm/sleeflibm_header.h.org", + "src/libm/sleeflibm_header.h.org.in", "src/libm/sleeflibm_footer.h.org", ], outs = ["build/include/sleef.h"], cmd = "{ " + "; ".join([ - "cat $(location src/libm/sleeflibm_header.h.org)", + "cat $(location src/libm/sleeflibm_header.h.org.in)", "$(location :mkrename) cinz_ 2 4 __m128d __m128 __m128i __m128i __SSE2__", "$(location :mkrename) cinz_ 2 4 __m128d __m128 __m128i __m128i __SSE2__ sse2", "$(location :mkrename) cinz_ 2 4 __m128d __m128 __m128i __m128i __SSE2__ sse4", diff --git a/third_party/tensorpipe b/third_party/tensorpipe index 9646e1a431997..eabfe528673aa 160000 --- a/third_party/tensorpipe +++ b/third_party/tensorpipe @@ -1 +1 @@ -Subproject commit 9646e1a431997edb1579972cef196d8fb97a77a5 +Subproject commit eabfe528673aa931239758820845a1e999e5ee3f diff --git a/third_party/tensorpipe.BUILD b/third_party/tensorpipe.BUILD index 29b2cd1d0d0ca..45b99e64ec9a5 100644 --- a/third_party/tensorpipe.BUILD +++ b/third_party/tensorpipe.BUILD @@ -74,6 +74,7 @@ header_template_rule( "#cmakedefine01 TENSORPIPE_HAS_SHM_TRANSPORT": "", "#cmakedefine01 TENSORPIPE_HAS_CMA_CHANNEL": "", "#cmakedefine01 TENSORPIPE_HAS_CUDA_IPC_CHANNEL": "", + "#cmakedefine01 TENSORPIPE_HAS_IBV_TRANSPORT": "", "#cmakedefine01 TENSORPIPE_SUPPORTS_CUDA": "", }, ) @@ -92,7 +93,13 @@ TENSORPIPE_HEADERS = glob([ TENSORPIPE_BASE_SRCS = glob([ "tensorpipe/*.cc", "tensorpipe/channel/*.cc", - "tensorpipe/common/*.cc", + "tensorpipe/common/address.cc", + "tensorpipe/common/epoll_loop.cc", + "tensorpipe/common/error.cc", + "tensorpipe/common/fd.cc", + "tensorpipe/common/ibv.cc", + "tensorpipe/common/socket.cc", + "tensorpipe/common/system.cc", "tensorpipe/core/*.cc", "tensorpipe/transport/*.cc", "tensorpipe/util/*/*.cc", @@ -106,7 +113,10 @@ TENSORPIPE_SRCS = TENSORPIPE_BASE_SRCS + glob([ ]) TENSORPIPE_SRCS_CUDA = TENSORPIPE_SRCS + glob([ + "tensorpipe/common/cuda_loop.cc", + "tensorpipe/channel/cuda_basic/*.cc", "tensorpipe/channel/cuda_ipc/*.cc", + "tensorpipe/channel/cuda_xth/*.cc", ]) cc_library( diff --git a/third_party/valgrind-headers/README.md b/third_party/valgrind-headers/README.md new file mode 100644 index 0000000000000..98173f37ad6e3 --- /dev/null +++ b/third_party/valgrind-headers/README.md @@ -0,0 +1,5 @@ +This folder contains 2 Valgrind headers, downloaded from +https://sourceware.org/git/?p=valgrind.git;a=blob;f=callgrind/callgrind.h;hb=HEAD +https://sourceware.org/git/?p=valgrind.git;a=blob;f=include/valgrind.h;hb=HEAD + + diff --git a/third_party/valgrind-headers/callgrind.h b/third_party/valgrind-headers/callgrind.h new file mode 100644 index 0000000000000..f078cc82b95da --- /dev/null +++ b/third_party/valgrind-headers/callgrind.h @@ -0,0 +1,129 @@ + +/* + ---------------------------------------------------------------- + + Notice that the following BSD-style license applies to this one + file (callgrind.h) only. The rest of Valgrind is licensed under the + terms of the GNU General Public License, version 2, unless + otherwise indicated. See the COPYING file in the source + distribution for details. + + ---------------------------------------------------------------- + + This file is part of callgrind, a valgrind tool for cache simulation + and call tree tracing. + + Copyright (C) 2003-2017 Josef Weidendorfer. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. The origin of this software must not be misrepresented; you must + not claim that you wrote the original software. If you use this + software in a product, an acknowledgment in the product + documentation would be appreciated but is not required. + + 3. Altered source versions must be plainly marked as such, and must + not be misrepresented as being the original software. + + 4. The name of the author may not be used to endorse or promote + products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS + OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE + GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------------------------------------------------------- + + Notice that the above BSD-style license applies to this one file + (callgrind.h) only. The entire rest of Valgrind is licensed under + the terms of the GNU General Public License, version 2. See the + COPYING file in the source distribution for details. + + ---------------------------------------------------------------- +*/ + +#ifndef __CALLGRIND_H +#define __CALLGRIND_H + +#include "valgrind.h" + +/* !! ABIWARNING !! ABIWARNING !! ABIWARNING !! ABIWARNING !! + This enum comprises an ABI exported by Valgrind to programs + which use client requests. DO NOT CHANGE THE ORDER OF THESE + ENTRIES, NOR DELETE ANY -- add new ones at the end. + + The identification ('C','T') for Callgrind has historical + reasons: it was called "Calltree" before. Besides, ('C','G') would + clash with cachegrind. + */ + +typedef + enum { + VG_USERREQ__DUMP_STATS = VG_USERREQ_TOOL_BASE('C','T'), + VG_USERREQ__ZERO_STATS, + VG_USERREQ__TOGGLE_COLLECT, + VG_USERREQ__DUMP_STATS_AT, + VG_USERREQ__START_INSTRUMENTATION, + VG_USERREQ__STOP_INSTRUMENTATION + } Vg_CallgrindClientRequest; + +/* Dump current state of cost centers, and zero them afterwards */ +#define CALLGRIND_DUMP_STATS \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DUMP_STATS, \ + 0, 0, 0, 0, 0) + +/* Dump current state of cost centers, and zero them afterwards. + The argument is appended to a string stating the reason which triggered + the dump. This string is written as a description field into the + profile data dump. */ +#define CALLGRIND_DUMP_STATS_AT(pos_str) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DUMP_STATS_AT, \ + pos_str, 0, 0, 0, 0) + +/* Zero cost centers */ +#define CALLGRIND_ZERO_STATS \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__ZERO_STATS, \ + 0, 0, 0, 0, 0) + +/* Toggles collection state. + The collection state specifies whether the happening of events + should be noted or if they are to be ignored. Events are noted + by increment of counters in a cost center */ +#define CALLGRIND_TOGGLE_COLLECT \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__TOGGLE_COLLECT, \ + 0, 0, 0, 0, 0) + +/* Start full callgrind instrumentation if not already switched on. + When cache simulation is done, it will flush the simulated cache; + this will lead to an artificial cache warmup phase afterwards with + cache misses which would not have happened in reality. */ +#define CALLGRIND_START_INSTRUMENTATION \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__START_INSTRUMENTATION, \ + 0, 0, 0, 0, 0) + +/* Stop full callgrind instrumentation if not already switched off. + This flushes Valgrinds translation cache, and does no additional + instrumentation afterwards, which effectivly will run at the same + speed as the "none" tool (ie. at minimal slowdown). + Use this to bypass Callgrind aggregation for uninteresting code parts. + To start Callgrind in this mode to ignore the setup phase, use + the option "--instr-atstart=no". */ +#define CALLGRIND_STOP_INSTRUMENTATION \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STOP_INSTRUMENTATION, \ + 0, 0, 0, 0, 0) + +#endif /* __CALLGRIND_H */ diff --git a/third_party/valgrind-headers/valgrind.h b/third_party/valgrind-headers/valgrind.h new file mode 100644 index 0000000000000..d33dd30932aa8 --- /dev/null +++ b/third_party/valgrind-headers/valgrind.h @@ -0,0 +1,7157 @@ +/* -*- c -*- + ---------------------------------------------------------------- + + Notice that the following BSD-style license applies to this one + file (valgrind.h) only. The rest of Valgrind is licensed under the + terms of the GNU General Public License, version 2, unless + otherwise indicated. See the COPYING file in the source + distribution for details. + + ---------------------------------------------------------------- + + This file is part of Valgrind, a dynamic binary instrumentation + framework. + + Copyright (C) 2000-2017 Julian Seward. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. The origin of this software must not be misrepresented; you must + not claim that you wrote the original software. If you use this + software in a product, an acknowledgment in the product + documentation would be appreciated but is not required. + + 3. Altered source versions must be plainly marked as such, and must + not be misrepresented as being the original software. + + 4. The name of the author may not be used to endorse or promote + products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS + OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE + GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------------------------------------------------------- + + Notice that the above BSD-style license applies to this one file + (valgrind.h) only. The entire rest of Valgrind is licensed under + the terms of the GNU General Public License, version 2. See the + COPYING file in the source distribution for details. + + ---------------------------------------------------------------- +*/ + + +/* This file is for inclusion into client (your!) code. + + You can use these macros to manipulate and query Valgrind's + execution inside your own programs. + + The resulting executables will still run without Valgrind, just a + little bit more slowly than they otherwise would, but otherwise + unchanged. When not running on valgrind, each client request + consumes very few (eg. 7) instructions, so the resulting performance + loss is negligible unless you plan to execute client requests + millions of times per second. Nevertheless, if that is still a + problem, you can compile with the NVALGRIND symbol defined (gcc + -DNVALGRIND) so that client requests are not even compiled in. */ + +#ifndef __VALGRIND_H +#define __VALGRIND_H + + +/* ------------------------------------------------------------------ */ +/* VERSION NUMBER OF VALGRIND */ +/* ------------------------------------------------------------------ */ + +/* Specify Valgrind's version number, so that user code can + conditionally compile based on our version number. Note that these + were introduced at version 3.6 and so do not exist in version 3.5 + or earlier. The recommended way to use them to check for "version + X.Y or later" is (eg) + +#if defined(__VALGRIND_MAJOR__) && defined(__VALGRIND_MINOR__) \ + && (__VALGRIND_MAJOR__ > 3 \ + || (__VALGRIND_MAJOR__ == 3 && __VALGRIND_MINOR__ >= 6)) +*/ +#define __VALGRIND_MAJOR__ 3 +#define __VALGRIND_MINOR__ 17 + + +#include + +/* Nb: this file might be included in a file compiled with -ansi. So + we can't use C++ style "//" comments nor the "asm" keyword (instead + use "__asm__"). */ + +/* Derive some tags indicating what the target platform is. Note + that in this file we're using the compiler's CPP symbols for + identifying architectures, which are different to the ones we use + within the rest of Valgrind. Note, __powerpc__ is active for both + 32 and 64-bit PPC, whereas __powerpc64__ is only active for the + latter (on Linux, that is). + + Misc note: how to find out what's predefined in gcc by default: + gcc -Wp,-dM somefile.c +*/ +#undef PLAT_x86_darwin +#undef PLAT_amd64_darwin +#undef PLAT_x86_win32 +#undef PLAT_amd64_win64 +#undef PLAT_x86_linux +#undef PLAT_amd64_linux +#undef PLAT_ppc32_linux +#undef PLAT_ppc64be_linux +#undef PLAT_ppc64le_linux +#undef PLAT_arm_linux +#undef PLAT_arm64_linux +#undef PLAT_s390x_linux +#undef PLAT_mips32_linux +#undef PLAT_mips64_linux +#undef PLAT_nanomips_linux +#undef PLAT_x86_solaris +#undef PLAT_amd64_solaris + + +#if defined(__APPLE__) && defined(__i386__) +# define PLAT_x86_darwin 1 +#elif defined(__APPLE__) && defined(__x86_64__) +# define PLAT_amd64_darwin 1 +#elif (defined(__MINGW32__) && defined(__i386__)) \ + || defined(__CYGWIN32__) \ + || (defined(_WIN32) && defined(_M_IX86)) +# define PLAT_x86_win32 1 +#elif (defined(__MINGW32__) && defined(__x86_64__)) \ + || (defined(_WIN32) && defined(_M_X64)) +/* __MINGW32__ and _WIN32 are defined in 64 bit mode as well. */ +# define PLAT_amd64_win64 1 +#elif defined(__linux__) && defined(__i386__) +# define PLAT_x86_linux 1 +#elif defined(__linux__) && defined(__x86_64__) && !defined(__ILP32__) +# define PLAT_amd64_linux 1 +#elif defined(__linux__) && defined(__powerpc__) && !defined(__powerpc64__) +# define PLAT_ppc32_linux 1 +#elif defined(__linux__) && defined(__powerpc__) && defined(__powerpc64__) && _CALL_ELF != 2 +/* Big Endian uses ELF version 1 */ +# define PLAT_ppc64be_linux 1 +#elif defined(__linux__) && defined(__powerpc__) && defined(__powerpc64__) && _CALL_ELF == 2 +/* Little Endian uses ELF version 2 */ +# define PLAT_ppc64le_linux 1 +#elif defined(__linux__) && defined(__arm__) && !defined(__aarch64__) +# define PLAT_arm_linux 1 +#elif defined(__linux__) && defined(__aarch64__) && !defined(__arm__) +# define PLAT_arm64_linux 1 +#elif defined(__linux__) && defined(__s390__) && defined(__s390x__) +# define PLAT_s390x_linux 1 +#elif defined(__linux__) && defined(__mips__) && (__mips==64) +# define PLAT_mips64_linux 1 +#elif defined(__linux__) && defined(__mips__) && (__mips==32) +# define PLAT_mips32_linux 1 +#elif defined(__linux__) && defined(__nanomips__) +# define PLAT_nanomips_linux 1 +#elif defined(__sun) && defined(__i386__) +# define PLAT_x86_solaris 1 +#elif defined(__sun) && defined(__x86_64__) +# define PLAT_amd64_solaris 1 +#else +/* If we're not compiling for our target platform, don't generate + any inline asms. */ +# if !defined(NVALGRIND) +# define NVALGRIND 1 +# endif +#endif + + +/* ------------------------------------------------------------------ */ +/* ARCHITECTURE SPECIFICS for SPECIAL INSTRUCTIONS. There is nothing */ +/* in here of use to end-users -- skip to the next section. */ +/* ------------------------------------------------------------------ */ + +/* + * VALGRIND_DO_CLIENT_REQUEST(): a statement that invokes a Valgrind client + * request. Accepts both pointers and integers as arguments. + * + * VALGRIND_DO_CLIENT_REQUEST_STMT(): a statement that invokes a Valgrind + * client request that does not return a value. + + * VALGRIND_DO_CLIENT_REQUEST_EXPR(): a C expression that invokes a Valgrind + * client request and whose value equals the client request result. Accepts + * both pointers and integers as arguments. Note that such calls are not + * necessarily pure functions -- they may have side effects. + */ + +#define VALGRIND_DO_CLIENT_REQUEST(_zzq_rlval, _zzq_default, \ + _zzq_request, _zzq_arg1, _zzq_arg2, \ + _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + do { (_zzq_rlval) = VALGRIND_DO_CLIENT_REQUEST_EXPR((_zzq_default), \ + (_zzq_request), (_zzq_arg1), (_zzq_arg2), \ + (_zzq_arg3), (_zzq_arg4), (_zzq_arg5)); } while (0) + +#define VALGRIND_DO_CLIENT_REQUEST_STMT(_zzq_request, _zzq_arg1, \ + _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + do { (void) VALGRIND_DO_CLIENT_REQUEST_EXPR(0, \ + (_zzq_request), (_zzq_arg1), (_zzq_arg2), \ + (_zzq_arg3), (_zzq_arg4), (_zzq_arg5)); } while (0) + +#if defined(NVALGRIND) + +/* Define NVALGRIND to completely remove the Valgrind magic sequence + from the compiled code (analogous to NDEBUG's effects on + assert()) */ +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + (_zzq_default) + +#else /* ! NVALGRIND */ + +/* The following defines the magic code sequences which the JITter + spots and handles magically. Don't look too closely at them as + they will rot your brain. + + The assembly code sequences for all architectures is in this one + file. This is because this file must be stand-alone, and we don't + want to have multiple files. + + For VALGRIND_DO_CLIENT_REQUEST, we must ensure that the default + value gets put in the return slot, so that everything works when + this is executed not under Valgrind. Args are passed in a memory + block, and so there's no intrinsic limit to the number that could + be passed, but it's currently five. + + The macro args are: + _zzq_rlval result lvalue + _zzq_default default value (result returned when running on real CPU) + _zzq_request request code + _zzq_arg1..5 request params + + The other two macros are used to support function wrapping, and are + a lot simpler. VALGRIND_GET_NR_CONTEXT returns the value of the + guest's NRADDR pseudo-register and whatever other information is + needed to safely run the call original from the wrapper: on + ppc64-linux, the R2 value at the divert point is also needed. This + information is abstracted into a user-visible type, OrigFn. + + VALGRIND_CALL_NOREDIR_* behaves the same as the following on the + guest, but guarantees that the branch instruction will not be + redirected: x86: call *%eax, amd64: call *%rax, ppc32/ppc64: + branch-and-link-to-r11. VALGRIND_CALL_NOREDIR is just text, not a + complete inline asm, since it needs to be combined with more magic + inline asm stuff to be useful. +*/ + +/* ----------------- x86-{linux,darwin,solaris} ---------------- */ + +#if defined(PLAT_x86_linux) || defined(PLAT_x86_darwin) \ + || (defined(PLAT_x86_win32) && defined(__GNUC__)) \ + || defined(PLAT_x86_solaris) + +typedef + struct { + unsigned int nraddr; /* where's the code? */ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "roll $3, %%edi ; roll $13, %%edi\n\t" \ + "roll $29, %%edi ; roll $19, %%edi\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({volatile unsigned int _zzq_args[6]; \ + volatile unsigned int _zzq_result; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %EDX = client_request ( %EAX ) */ \ + "xchgl %%ebx,%%ebx" \ + : "=d" (_zzq_result) \ + : "a" (&_zzq_args[0]), "0" (_zzq_default) \ + : "cc", "memory" \ + ); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %EAX = guest_NRADDR */ \ + "xchgl %%ecx,%%ecx" \ + : "=a" (__addr) \ + : \ + : "cc", "memory" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_EAX \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir *%EAX */ \ + "xchgl %%edx,%%edx\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "xchgl %%edi,%%edi\n\t" \ + : : : "cc", "memory" \ + ); \ + } while (0) + +#endif /* PLAT_x86_linux || PLAT_x86_darwin || (PLAT_x86_win32 && __GNUC__) + || PLAT_x86_solaris */ + +/* ------------------------- x86-Win32 ------------------------- */ + +#if defined(PLAT_x86_win32) && !defined(__GNUC__) + +typedef + struct { + unsigned int nraddr; /* where's the code? */ + } + OrigFn; + +#if defined(_MSC_VER) + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + __asm rol edi, 3 __asm rol edi, 13 \ + __asm rol edi, 29 __asm rol edi, 19 + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + valgrind_do_client_request_expr((uintptr_t)(_zzq_default), \ + (uintptr_t)(_zzq_request), (uintptr_t)(_zzq_arg1), \ + (uintptr_t)(_zzq_arg2), (uintptr_t)(_zzq_arg3), \ + (uintptr_t)(_zzq_arg4), (uintptr_t)(_zzq_arg5)) + +static __inline uintptr_t +valgrind_do_client_request_expr(uintptr_t _zzq_default, uintptr_t _zzq_request, + uintptr_t _zzq_arg1, uintptr_t _zzq_arg2, + uintptr_t _zzq_arg3, uintptr_t _zzq_arg4, + uintptr_t _zzq_arg5) +{ + volatile uintptr_t _zzq_args[6]; + volatile unsigned int _zzq_result; + _zzq_args[0] = (uintptr_t)(_zzq_request); + _zzq_args[1] = (uintptr_t)(_zzq_arg1); + _zzq_args[2] = (uintptr_t)(_zzq_arg2); + _zzq_args[3] = (uintptr_t)(_zzq_arg3); + _zzq_args[4] = (uintptr_t)(_zzq_arg4); + _zzq_args[5] = (uintptr_t)(_zzq_arg5); + __asm { __asm lea eax, _zzq_args __asm mov edx, _zzq_default + __SPECIAL_INSTRUCTION_PREAMBLE + /* %EDX = client_request ( %EAX ) */ + __asm xchg ebx,ebx + __asm mov _zzq_result, edx + } + return _zzq_result; +} + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned int __addr; \ + __asm { __SPECIAL_INSTRUCTION_PREAMBLE \ + /* %EAX = guest_NRADDR */ \ + __asm xchg ecx,ecx \ + __asm mov __addr, eax \ + } \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_EAX ERROR + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm { __SPECIAL_INSTRUCTION_PREAMBLE \ + __asm xchg edi,edi \ + } \ + } while (0) + +#else +#error Unsupported compiler. +#endif + +#endif /* PLAT_x86_win32 */ + +/* ----------------- amd64-{linux,darwin,solaris} --------------- */ + +#if defined(PLAT_amd64_linux) || defined(PLAT_amd64_darwin) \ + || defined(PLAT_amd64_solaris) \ + || (defined(PLAT_amd64_win64) && defined(__GNUC__)) + +typedef + struct { + unsigned long int nraddr; /* where's the code? */ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "rolq $3, %%rdi ; rolq $13, %%rdi\n\t" \ + "rolq $61, %%rdi ; rolq $51, %%rdi\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({ volatile unsigned long int _zzq_args[6]; \ + volatile unsigned long int _zzq_result; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %RDX = client_request ( %RAX ) */ \ + "xchgq %%rbx,%%rbx" \ + : "=d" (_zzq_result) \ + : "a" (&_zzq_args[0]), "0" (_zzq_default) \ + : "cc", "memory" \ + ); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %RAX = guest_NRADDR */ \ + "xchgq %%rcx,%%rcx" \ + : "=a" (__addr) \ + : \ + : "cc", "memory" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_RAX \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir *%RAX */ \ + "xchgq %%rdx,%%rdx\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "xchgq %%rdi,%%rdi\n\t" \ + : : : "cc", "memory" \ + ); \ + } while (0) + +#endif /* PLAT_amd64_linux || PLAT_amd64_darwin || PLAT_amd64_solaris */ + +/* ------------------------- amd64-Win64 ------------------------- */ + +#if defined(PLAT_amd64_win64) && !defined(__GNUC__) + +#error Unsupported compiler. + +#endif /* PLAT_amd64_win64 */ + +/* ------------------------ ppc32-linux ------------------------ */ + +#if defined(PLAT_ppc32_linux) + +typedef + struct { + unsigned int nraddr; /* where's the code? */ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "rlwinm 0,0,3,0,31 ; rlwinm 0,0,13,0,31\n\t" \ + "rlwinm 0,0,29,0,31 ; rlwinm 0,0,19,0,31\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({ unsigned int _zzq_args[6]; \ + unsigned int _zzq_result; \ + unsigned int* _zzq_ptr; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + _zzq_ptr = _zzq_args; \ + __asm__ volatile("mr 3,%1\n\t" /*default*/ \ + "mr 4,%2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = client_request ( %R4 ) */ \ + "or 1,1,1\n\t" \ + "mr %0,3" /*result*/ \ + : "=b" (_zzq_result) \ + : "b" (_zzq_default), "b" (_zzq_ptr) \ + : "cc", "memory", "r3", "r4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR */ \ + "or 2,2,2\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir *%R11 */ \ + "or 3,3,3\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or 5,5,5\n\t" \ + ); \ + } while (0) + +#endif /* PLAT_ppc32_linux */ + +/* ------------------------ ppc64-linux ------------------------ */ + +#if defined(PLAT_ppc64be_linux) + +typedef + struct { + unsigned long int nraddr; /* where's the code? */ + unsigned long int r2; /* what tocptr do we need? */ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "rotldi 0,0,3 ; rotldi 0,0,13\n\t" \ + "rotldi 0,0,61 ; rotldi 0,0,51\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({ unsigned long int _zzq_args[6]; \ + unsigned long int _zzq_result; \ + unsigned long int* _zzq_ptr; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + _zzq_ptr = _zzq_args; \ + __asm__ volatile("mr 3,%1\n\t" /*default*/ \ + "mr 4,%2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = client_request ( %R4 ) */ \ + "or 1,1,1\n\t" \ + "mr %0,3" /*result*/ \ + : "=b" (_zzq_result) \ + : "b" (_zzq_default), "b" (_zzq_ptr) \ + : "cc", "memory", "r3", "r4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR */ \ + "or 2,2,2\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR_GPR2 */ \ + "or 4,4,4\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->r2 = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir *%R11 */ \ + "or 3,3,3\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or 5,5,5\n\t" \ + ); \ + } while (0) + +#endif /* PLAT_ppc64be_linux */ + +#if defined(PLAT_ppc64le_linux) + +typedef + struct { + unsigned long int nraddr; /* where's the code? */ + unsigned long int r2; /* what tocptr do we need? */ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "rotldi 0,0,3 ; rotldi 0,0,13\n\t" \ + "rotldi 0,0,61 ; rotldi 0,0,51\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({ unsigned long int _zzq_args[6]; \ + unsigned long int _zzq_result; \ + unsigned long int* _zzq_ptr; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + _zzq_ptr = _zzq_args; \ + __asm__ volatile("mr 3,%1\n\t" /*default*/ \ + "mr 4,%2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = client_request ( %R4 ) */ \ + "or 1,1,1\n\t" \ + "mr %0,3" /*result*/ \ + : "=b" (_zzq_result) \ + : "b" (_zzq_default), "b" (_zzq_ptr) \ + : "cc", "memory", "r3", "r4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR */ \ + "or 2,2,2\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR_GPR2 */ \ + "or 4,4,4\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->r2 = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir *%R12 */ \ + "or 3,3,3\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or 5,5,5\n\t" \ + ); \ + } while (0) + +#endif /* PLAT_ppc64le_linux */ + +/* ------------------------- arm-linux ------------------------- */ + +#if defined(PLAT_arm_linux) + +typedef + struct { + unsigned int nraddr; /* where's the code? */ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "mov r12, r12, ror #3 ; mov r12, r12, ror #13 \n\t" \ + "mov r12, r12, ror #29 ; mov r12, r12, ror #19 \n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({volatile unsigned int _zzq_args[6]; \ + volatile unsigned int _zzq_result; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + __asm__ volatile("mov r3, %1\n\t" /*default*/ \ + "mov r4, %2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* R3 = client_request ( R4 ) */ \ + "orr r10, r10, r10\n\t" \ + "mov %0, r3" /*result*/ \ + : "=r" (_zzq_result) \ + : "r" (_zzq_default), "r" (&_zzq_args[0]) \ + : "cc","memory", "r3", "r4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* R3 = guest_NRADDR */ \ + "orr r11, r11, r11\n\t" \ + "mov %0, r3" \ + : "=r" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir *%R4 */ \ + "orr r12, r12, r12\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "orr r9, r9, r9\n\t" \ + : : : "cc", "memory" \ + ); \ + } while (0) + +#endif /* PLAT_arm_linux */ + +/* ------------------------ arm64-linux ------------------------- */ + +#if defined(PLAT_arm64_linux) + +typedef + struct { + unsigned long int nraddr; /* where's the code? */ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "ror x12, x12, #3 ; ror x12, x12, #13 \n\t" \ + "ror x12, x12, #51 ; ror x12, x12, #61 \n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({volatile unsigned long int _zzq_args[6]; \ + volatile unsigned long int _zzq_result; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + __asm__ volatile("mov x3, %1\n\t" /*default*/ \ + "mov x4, %2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* X3 = client_request ( X4 ) */ \ + "orr x10, x10, x10\n\t" \ + "mov %0, x3" /*result*/ \ + : "=r" (_zzq_result) \ + : "r" ((unsigned long int)(_zzq_default)), \ + "r" (&_zzq_args[0]) \ + : "cc","memory", "x3", "x4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* X3 = guest_NRADDR */ \ + "orr x11, x11, x11\n\t" \ + "mov %0, x3" \ + : "=r" (__addr) \ + : \ + : "cc", "memory", "x3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir X8 */ \ + "orr x12, x12, x12\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "orr x9, x9, x9\n\t" \ + : : : "cc", "memory" \ + ); \ + } while (0) + +#endif /* PLAT_arm64_linux */ + +/* ------------------------ s390x-linux ------------------------ */ + +#if defined(PLAT_s390x_linux) + +typedef + struct { + unsigned long int nraddr; /* where's the code? */ + } + OrigFn; + +/* __SPECIAL_INSTRUCTION_PREAMBLE will be used to identify Valgrind specific + * code. This detection is implemented in platform specific toIR.c + * (e.g. VEX/priv/guest_s390_decoder.c). + */ +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "lr 15,15\n\t" \ + "lr 1,1\n\t" \ + "lr 2,2\n\t" \ + "lr 3,3\n\t" + +#define __CLIENT_REQUEST_CODE "lr 2,2\n\t" +#define __GET_NR_CONTEXT_CODE "lr 3,3\n\t" +#define __CALL_NO_REDIR_CODE "lr 4,4\n\t" +#define __VEX_INJECT_IR_CODE "lr 5,5\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({volatile unsigned long int _zzq_args[6]; \ + volatile unsigned long int _zzq_result; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + __asm__ volatile(/* r2 = args */ \ + "lgr 2,%1\n\t" \ + /* r3 = default */ \ + "lgr 3,%2\n\t" \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + __CLIENT_REQUEST_CODE \ + /* results = r3 */ \ + "lgr %0, 3\n\t" \ + : "=d" (_zzq_result) \ + : "a" (&_zzq_args[0]), "0" (_zzq_default) \ + : "cc", "2", "3", "memory" \ + ); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + __GET_NR_CONTEXT_CODE \ + "lgr %0, 3\n\t" \ + : "=a" (__addr) \ + : \ + : "cc", "3", "memory" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_R1 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + __CALL_NO_REDIR_CODE + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + __VEX_INJECT_IR_CODE); \ + } while (0) + +#endif /* PLAT_s390x_linux */ + +/* ------------------------- mips32-linux ---------------- */ + +#if defined(PLAT_mips32_linux) + +typedef + struct { + unsigned int nraddr; /* where's the code? */ + } + OrigFn; + +/* .word 0x342 + * .word 0x742 + * .word 0xC2 + * .word 0x4C2*/ +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "srl $0, $0, 13\n\t" \ + "srl $0, $0, 29\n\t" \ + "srl $0, $0, 3\n\t" \ + "srl $0, $0, 19\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({ volatile unsigned int _zzq_args[6]; \ + volatile unsigned int _zzq_result; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + __asm__ volatile("move $11, %1\n\t" /*default*/ \ + "move $12, %2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* T3 = client_request ( T4 ) */ \ + "or $13, $13, $13\n\t" \ + "move %0, $11\n\t" /*result*/ \ + : "=r" (_zzq_result) \ + : "r" (_zzq_default), "r" (&_zzq_args[0]) \ + : "$11", "$12", "memory"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %t9 = guest_NRADDR */ \ + "or $14, $14, $14\n\t" \ + "move %0, $11" /*result*/ \ + : "=r" (__addr) \ + : \ + : "$11" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_T9 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir *%t9 */ \ + "or $15, $15, $15\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or $11, $11, $11\n\t" \ + ); \ + } while (0) + + +#endif /* PLAT_mips32_linux */ + +/* ------------------------- mips64-linux ---------------- */ + +#if defined(PLAT_mips64_linux) + +typedef + struct { + unsigned long nraddr; /* where's the code? */ + } + OrigFn; + +/* dsll $0,$0, 3 + * dsll $0,$0, 13 + * dsll $0,$0, 29 + * dsll $0,$0, 19*/ +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "dsll $0,$0, 3 ; dsll $0,$0,13\n\t" \ + "dsll $0,$0,29 ; dsll $0,$0,19\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({ volatile unsigned long int _zzq_args[6]; \ + volatile unsigned long int _zzq_result; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + __asm__ volatile("move $11, %1\n\t" /*default*/ \ + "move $12, %2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* $11 = client_request ( $12 ) */ \ + "or $13, $13, $13\n\t" \ + "move %0, $11\n\t" /*result*/ \ + : "=r" (_zzq_result) \ + : "r" (_zzq_default), "r" (&_zzq_args[0]) \ + : "$11", "$12", "memory"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* $11 = guest_NRADDR */ \ + "or $14, $14, $14\n\t" \ + "move %0, $11" /*result*/ \ + : "=r" (__addr) \ + : \ + : "$11"); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_T9 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir $25 */ \ + "or $15, $15, $15\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or $11, $11, $11\n\t" \ + ); \ + } while (0) + +#endif /* PLAT_mips64_linux */ + +#if defined(PLAT_nanomips_linux) + +typedef + struct { + unsigned int nraddr; /* where's the code? */ + } + OrigFn; +/* + 8000 c04d srl zero, zero, 13 + 8000 c05d srl zero, zero, 29 + 8000 c043 srl zero, zero, 3 + 8000 c053 srl zero, zero, 19 +*/ + +#define __SPECIAL_INSTRUCTION_PREAMBLE "srl[32] $zero, $zero, 13 \n\t" \ + "srl[32] $zero, $zero, 29 \n\t" \ + "srl[32] $zero, $zero, 3 \n\t" \ + "srl[32] $zero, $zero, 19 \n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({ volatile unsigned int _zzq_args[6]; \ + volatile unsigned int _zzq_result; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + __asm__ volatile("move $a7, %1\n\t" /* default */ \ + "move $t0, %2\n\t" /* ptr */ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* $a7 = client_request( $t0 ) */ \ + "or[32] $t0, $t0, $t0\n\t" \ + "move %0, $a7\n\t" /* result */ \ + : "=r" (_zzq_result) \ + : "r" (_zzq_default), "r" (&_zzq_args[0]) \ + : "$a7", "$t0", "memory"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* $a7 = guest_NRADDR */ \ + "or[32] $t1, $t1, $t1\n\t" \ + "move %0, $a7" /*result*/ \ + : "=r" (__addr) \ + : \ + : "$a7"); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_T9 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir $25 */ \ + "or[32] $t2, $t2, $t2\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or[32] $t3, $t3, $t3\n\t" \ + ); \ + } while (0) + +#endif +/* Insert assembly code for other platforms here... */ + +#endif /* NVALGRIND */ + + +/* ------------------------------------------------------------------ */ +/* PLATFORM SPECIFICS for FUNCTION WRAPPING. This is all very */ +/* ugly. It's the least-worst tradeoff I can think of. */ +/* ------------------------------------------------------------------ */ + +/* This section defines magic (a.k.a appalling-hack) macros for doing + guaranteed-no-redirection macros, so as to get from function + wrappers to the functions they are wrapping. The whole point is to + construct standard call sequences, but to do the call itself with a + special no-redirect call pseudo-instruction that the JIT + understands and handles specially. This section is long and + repetitious, and I can't see a way to make it shorter. + + The naming scheme is as follows: + + CALL_FN_{W,v}_{v,W,WW,WWW,WWWW,5W,6W,7W,etc} + + 'W' stands for "word" and 'v' for "void". Hence there are + different macros for calling arity 0, 1, 2, 3, 4, etc, functions, + and for each, the possibility of returning a word-typed result, or + no result. +*/ + +/* Use these to write the name of your wrapper. NOTE: duplicates + VG_WRAP_FUNCTION_Z{U,Z} in pub_tool_redir.h. NOTE also: inserts + the default behaviour equivalance class tag "0000" into the name. + See pub_tool_redir.h for details -- normally you don't need to + think about this, though. */ + +/* Use an extra level of macroisation so as to ensure the soname/fnname + args are fully macro-expanded before pasting them together. */ +#define VG_CONCAT4(_aa,_bb,_cc,_dd) _aa##_bb##_cc##_dd + +#define I_WRAP_SONAME_FNNAME_ZU(soname,fnname) \ + VG_CONCAT4(_vgw00000ZU_,soname,_,fnname) + +#define I_WRAP_SONAME_FNNAME_ZZ(soname,fnname) \ + VG_CONCAT4(_vgw00000ZZ_,soname,_,fnname) + +/* Use this macro from within a wrapper function to collect the + context (address and possibly other info) of the original function. + Once you have that you can then use it in one of the CALL_FN_ + macros. The type of the argument _lval is OrigFn. */ +#define VALGRIND_GET_ORIG_FN(_lval) VALGRIND_GET_NR_CONTEXT(_lval) + +/* Also provide end-user facilities for function replacement, rather + than wrapping. A replacement function differs from a wrapper in + that it has no way to get hold of the original function being + called, and hence no way to call onwards to it. In a replacement + function, VALGRIND_GET_ORIG_FN always returns zero. */ + +#define I_REPLACE_SONAME_FNNAME_ZU(soname,fnname) \ + VG_CONCAT4(_vgr00000ZU_,soname,_,fnname) + +#define I_REPLACE_SONAME_FNNAME_ZZ(soname,fnname) \ + VG_CONCAT4(_vgr00000ZZ_,soname,_,fnname) + +/* Derivatives of the main macros below, for calling functions + returning void. */ + +#define CALL_FN_v_v(fnptr) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_v(_junk,fnptr); } while (0) + +#define CALL_FN_v_W(fnptr, arg1) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_W(_junk,fnptr,arg1); } while (0) + +#define CALL_FN_v_WW(fnptr, arg1,arg2) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_WW(_junk,fnptr,arg1,arg2); } while (0) + +#define CALL_FN_v_WWW(fnptr, arg1,arg2,arg3) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_WWW(_junk,fnptr,arg1,arg2,arg3); } while (0) + +#define CALL_FN_v_WWWW(fnptr, arg1,arg2,arg3,arg4) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_WWWW(_junk,fnptr,arg1,arg2,arg3,arg4); } while (0) + +#define CALL_FN_v_5W(fnptr, arg1,arg2,arg3,arg4,arg5) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_5W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5); } while (0) + +#define CALL_FN_v_6W(fnptr, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_6W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5,arg6); } while (0) + +#define CALL_FN_v_7W(fnptr, arg1,arg2,arg3,arg4,arg5,arg6,arg7) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_7W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5,arg6,arg7); } while (0) + +/* ----------------- x86-{linux,darwin,solaris} ---------------- */ + +#if defined(PLAT_x86_linux) || defined(PLAT_x86_darwin) \ + || defined(PLAT_x86_solaris) + +/* These regs are trashed by the hidden call. No need to mention eax + as gcc can already see that, plus causes gcc to bomb. */ +#define __CALLER_SAVED_REGS /*"eax"*/ "ecx", "edx" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "movl %%esp,%%edi\n\t" \ + "andl $0xfffffff0,%%esp\n\t" +#define VALGRIND_RESTORE_STACK \ + "movl %%edi,%%esp\n\t" + +/* These CALL_FN_ macros assume that on x86-linux, sizeof(unsigned + long) == 4. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $12, %%esp\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $8, %%esp\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $4, %%esp\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $12, %%esp\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $8, %%esp\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $4, %%esp\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $12, %%esp\n\t" \ + "pushl 36(%%eax)\n\t" \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $8, %%esp\n\t" \ + "pushl 40(%%eax)\n\t" \ + "pushl 36(%%eax)\n\t" \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $4, %%esp\n\t" \ + "pushl 44(%%eax)\n\t" \ + "pushl 40(%%eax)\n\t" \ + "pushl 36(%%eax)\n\t" \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "pushl 48(%%eax)\n\t" \ + "pushl 44(%%eax)\n\t" \ + "pushl 40(%%eax)\n\t" \ + "pushl 36(%%eax)\n\t" \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_x86_linux || PLAT_x86_darwin || PLAT_x86_solaris */ + +/* ---------------- amd64-{linux,darwin,solaris} --------------- */ + +#if defined(PLAT_amd64_linux) || defined(PLAT_amd64_darwin) \ + || defined(PLAT_amd64_solaris) + +/* ARGREGS: rdi rsi rdx rcx r8 r9 (the rest on stack in R-to-L order) */ + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS /*"rax",*/ "rcx", "rdx", "rsi", \ + "rdi", "r8", "r9", "r10", "r11" + +/* This is all pretty complex. It's so as to make stack unwinding + work reliably. See bug 243270. The basic problem is the sub and + add of 128 of %rsp in all of the following macros. If gcc believes + the CFA is in %rsp, then unwinding may fail, because what's at the + CFA is not what gcc "expected" when it constructs the CFIs for the + places where the macros are instantiated. + + But we can't just add a CFI annotation to increase the CFA offset + by 128, to match the sub of 128 from %rsp, because we don't know + whether gcc has chosen %rsp as the CFA at that point, or whether it + has chosen some other register (eg, %rbp). In the latter case, + adding a CFI annotation to change the CFA offset is simply wrong. + + So the solution is to get hold of the CFA using + __builtin_dwarf_cfa(), put it in a known register, and add a + CFI annotation to say what the register is. We choose %rbp for + this (perhaps perversely), because: + + (1) %rbp is already subject to unwinding. If a new register was + chosen then the unwinder would have to unwind it in all stack + traces, which is expensive, and + + (2) %rbp is already subject to precise exception updates in the + JIT. If a new register was chosen, we'd have to have precise + exceptions for it too, which reduces performance of the + generated code. + + However .. one extra complication. We can't just whack the result + of __builtin_dwarf_cfa() into %rbp and then add %rbp to the + list of trashed registers at the end of the inline assembly + fragments; gcc won't allow %rbp to appear in that list. Hence + instead we need to stash %rbp in %r15 for the duration of the asm, + and say that %r15 is trashed instead. gcc seems happy to go with + that. + + Oh .. and this all needs to be conditionalised so that it is + unchanged from before this commit, when compiled with older gccs + that don't support __builtin_dwarf_cfa. Furthermore, since + this header file is freestanding, it has to be independent of + config.h, and so the following conditionalisation cannot depend on + configure time checks. + + Although it's not clear from + 'defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM)', + this expression excludes Darwin. + .cfi directives in Darwin assembly appear to be completely + different and I haven't investigated how they work. + + For even more entertainment value, note we have to use the + completely undocumented __builtin_dwarf_cfa(), which appears to + really compute the CFA, whereas __builtin_frame_address(0) claims + to but actually doesn't. See + https://bugs.kde.org/show_bug.cgi?id=243270#c47 +*/ +#if defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM) +# define __FRAME_POINTER \ + ,"r"(__builtin_dwarf_cfa()) +# define VALGRIND_CFI_PROLOGUE \ + "movq %%rbp, %%r15\n\t" \ + "movq %2, %%rbp\n\t" \ + ".cfi_remember_state\n\t" \ + ".cfi_def_cfa rbp, 0\n\t" +# define VALGRIND_CFI_EPILOGUE \ + "movq %%r15, %%rbp\n\t" \ + ".cfi_restore_state\n\t" +#else +# define __FRAME_POINTER +# define VALGRIND_CFI_PROLOGUE +# define VALGRIND_CFI_EPILOGUE +#endif + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "movq %%rsp,%%r14\n\t" \ + "andq $0xfffffffffffffff0,%%rsp\n\t" +#define VALGRIND_RESTORE_STACK \ + "movq %%r14,%%rsp\n\t" + +/* These CALL_FN_ macros assume that on amd64-linux, sizeof(unsigned + long) == 8. */ + +/* NB 9 Sept 07. There is a nasty kludge here in all these CALL_FN_ + macros. In order not to trash the stack redzone, we need to drop + %rsp by 128 before the hidden call, and restore afterwards. The + nastyness is that it is only by luck that the stack still appears + to be unwindable during the hidden call - since then the behaviour + of any routine using this macro does not match what the CFI data + says. Sigh. + + Why is this important? Imagine that a wrapper has a stack + allocated local, and passes to the hidden call, a pointer to it. + Because gcc does not know about the hidden call, it may allocate + that local in the redzone. Unfortunately the hidden call may then + trash it before it comes to use it. So we must step clear of the + redzone, for the duration of the hidden call, to make it safe. + + Probably the same problem afflicts the other redzone-style ABIs too + (ppc64-linux); but for those, the stack is + self describing (none of this CFI nonsense) so at least messing + with the stack pointer doesn't give a danger of non-unwindable + stack. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $136,%%rsp\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $136,%%rsp\n\t" \ + "pushq 72(%%rax)\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "pushq 80(%%rax)\n\t" \ + "pushq 72(%%rax)\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $136,%%rsp\n\t" \ + "pushq 88(%%rax)\n\t" \ + "pushq 80(%%rax)\n\t" \ + "pushq 72(%%rax)\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "pushq 96(%%rax)\n\t" \ + "pushq 88(%%rax)\n\t" \ + "pushq 80(%%rax)\n\t" \ + "pushq 72(%%rax)\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_amd64_linux || PLAT_amd64_darwin || PLAT_amd64_solaris */ + +/* ------------------------ ppc32-linux ------------------------ */ + +#if defined(PLAT_ppc32_linux) + +/* This is useful for finding out about the on-stack stuff: + + extern int f9 ( int,int,int,int,int,int,int,int,int ); + extern int f10 ( int,int,int,int,int,int,int,int,int,int ); + extern int f11 ( int,int,int,int,int,int,int,int,int,int,int ); + extern int f12 ( int,int,int,int,int,int,int,int,int,int,int,int ); + + int g9 ( void ) { + return f9(11,22,33,44,55,66,77,88,99); + } + int g10 ( void ) { + return f10(11,22,33,44,55,66,77,88,99,110); + } + int g11 ( void ) { + return f11(11,22,33,44,55,66,77,88,99,110,121); + } + int g12 ( void ) { + return f12(11,22,33,44,55,66,77,88,99,110,121,132); + } +*/ + +/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */ + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS \ + "lr", "ctr", "xer", \ + "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7", \ + "r0", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", \ + "r11", "r12", "r13" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "mr 28,1\n\t" \ + "rlwinm 1,1,0,0,27\n\t" +#define VALGRIND_RESTORE_STACK \ + "mr 1,28\n\t" + +/* These CALL_FN_ macros assume that on ppc32-linux, + sizeof(unsigned long) == 4. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "addi 1,1,-16\n\t" \ + /* arg9 */ \ + "lwz 3,36(11)\n\t" \ + "stw 3,8(1)\n\t" \ + /* args1-8 */ \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "addi 1,1,-16\n\t" \ + /* arg10 */ \ + "lwz 3,40(11)\n\t" \ + "stw 3,12(1)\n\t" \ + /* arg9 */ \ + "lwz 3,36(11)\n\t" \ + "stw 3,8(1)\n\t" \ + /* args1-8 */ \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + _argvec[11] = (unsigned long)arg11; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "addi 1,1,-32\n\t" \ + /* arg11 */ \ + "lwz 3,44(11)\n\t" \ + "stw 3,16(1)\n\t" \ + /* arg10 */ \ + "lwz 3,40(11)\n\t" \ + "stw 3,12(1)\n\t" \ + /* arg9 */ \ + "lwz 3,36(11)\n\t" \ + "stw 3,8(1)\n\t" \ + /* args1-8 */ \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + _argvec[11] = (unsigned long)arg11; \ + _argvec[12] = (unsigned long)arg12; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "addi 1,1,-32\n\t" \ + /* arg12 */ \ + "lwz 3,48(11)\n\t" \ + "stw 3,20(1)\n\t" \ + /* arg11 */ \ + "lwz 3,44(11)\n\t" \ + "stw 3,16(1)\n\t" \ + /* arg10 */ \ + "lwz 3,40(11)\n\t" \ + "stw 3,12(1)\n\t" \ + /* arg9 */ \ + "lwz 3,36(11)\n\t" \ + "stw 3,8(1)\n\t" \ + /* args1-8 */ \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_ppc32_linux */ + +/* ------------------------ ppc64-linux ------------------------ */ + +#if defined(PLAT_ppc64be_linux) + +/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */ + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS \ + "lr", "ctr", "xer", \ + "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7", \ + "r0", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", \ + "r11", "r12", "r13" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "mr 28,1\n\t" \ + "rldicr 1,1,0,59\n\t" +#define VALGRIND_RESTORE_STACK \ + "mr 1,28\n\t" + +/* These CALL_FN_ macros assume that on ppc64-linux, sizeof(unsigned + long) == 8. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+0]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+1]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+2]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+3]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+4]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+5]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+6]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+7]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+8]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+9]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-128\n\t" /* expand stack frame */ \ + /* arg9 */ \ + "ld 3,72(11)\n\t" \ + "std 3,112(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+10]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-128\n\t" /* expand stack frame */ \ + /* arg10 */ \ + "ld 3,80(11)\n\t" \ + "std 3,120(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(11)\n\t" \ + "std 3,112(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+11]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + _argvec[2+11] = (unsigned long)arg11; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-144\n\t" /* expand stack frame */ \ + /* arg11 */ \ + "ld 3,88(11)\n\t" \ + "std 3,128(1)\n\t" \ + /* arg10 */ \ + "ld 3,80(11)\n\t" \ + "std 3,120(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(11)\n\t" \ + "std 3,112(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+12]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + _argvec[2+11] = (unsigned long)arg11; \ + _argvec[2+12] = (unsigned long)arg12; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-144\n\t" /* expand stack frame */ \ + /* arg12 */ \ + "ld 3,96(11)\n\t" \ + "std 3,136(1)\n\t" \ + /* arg11 */ \ + "ld 3,88(11)\n\t" \ + "std 3,128(1)\n\t" \ + /* arg10 */ \ + "ld 3,80(11)\n\t" \ + "std 3,120(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(11)\n\t" \ + "std 3,112(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_ppc64be_linux */ + +/* ------------------------- ppc64le-linux ----------------------- */ +#if defined(PLAT_ppc64le_linux) + +/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */ + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS \ + "lr", "ctr", "xer", \ + "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7", \ + "r0", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", \ + "r11", "r12", "r13" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "mr 28,1\n\t" \ + "rldicr 1,1,0,59\n\t" +#define VALGRIND_RESTORE_STACK \ + "mr 1,28\n\t" + +/* These CALL_FN_ macros assume that on ppc64-linux, sizeof(unsigned + long) == 8. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+0]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+1]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+2]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+3]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+4]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+5]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+6]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+7]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+8]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+9]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-128\n\t" /* expand stack frame */ \ + /* arg9 */ \ + "ld 3,72(12)\n\t" \ + "std 3,96(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+10]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-128\n\t" /* expand stack frame */ \ + /* arg10 */ \ + "ld 3,80(12)\n\t" \ + "std 3,104(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(12)\n\t" \ + "std 3,96(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+11]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + _argvec[2+11] = (unsigned long)arg11; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-144\n\t" /* expand stack frame */ \ + /* arg11 */ \ + "ld 3,88(12)\n\t" \ + "std 3,112(1)\n\t" \ + /* arg10 */ \ + "ld 3,80(12)\n\t" \ + "std 3,104(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(12)\n\t" \ + "std 3,96(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+12]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + _argvec[2+11] = (unsigned long)arg11; \ + _argvec[2+12] = (unsigned long)arg12; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-144\n\t" /* expand stack frame */ \ + /* arg12 */ \ + "ld 3,96(12)\n\t" \ + "std 3,120(1)\n\t" \ + /* arg11 */ \ + "ld 3,88(12)\n\t" \ + "std 3,112(1)\n\t" \ + /* arg10 */ \ + "ld 3,80(12)\n\t" \ + "std 3,104(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(12)\n\t" \ + "std 3,96(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_ppc64le_linux */ + +/* ------------------------- arm-linux ------------------------- */ + +#if defined(PLAT_arm_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS "r0", "r1", "r2", "r3","r4", "r12", "r14" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +/* This is a bit tricky. We store the original stack pointer in r10 + as it is callee-saves. gcc doesn't allow the use of r11 for some + reason. Also, we can't directly "bic" the stack pointer in thumb + mode since r13 isn't an allowed register number in that context. + So use r4 as a temporary, since that is about to get trashed + anyway, just after each use of this macro. Side effect is we need + to be very careful about any future changes, since + VALGRIND_ALIGN_STACK simply assumes r4 is usable. */ +#define VALGRIND_ALIGN_STACK \ + "mov r10, sp\n\t" \ + "mov r4, sp\n\t" \ + "bic r4, r4, #7\n\t" \ + "mov sp, r4\n\t" +#define VALGRIND_RESTORE_STACK \ + "mov sp, r10\n\t" + +/* These CALL_FN_ macros assume that on arm-linux, sizeof(unsigned + long) == 4. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #4 \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "push {r0} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "push {r0, r1} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #4 \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "push {r0, r1, r2} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "push {r0, r1, r2, r3} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #4 \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "ldr r4, [%1, #36] \n\t" \ + "push {r0, r1, r2, r3, r4} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #40] \n\t" \ + "push {r0} \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "ldr r4, [%1, #36] \n\t" \ + "push {r0, r1, r2, r3, r4} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #4 \n\t" \ + "ldr r0, [%1, #40] \n\t" \ + "ldr r1, [%1, #44] \n\t" \ + "push {r0, r1} \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "ldr r4, [%1, #36] \n\t" \ + "push {r0, r1, r2, r3, r4} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #40] \n\t" \ + "ldr r1, [%1, #44] \n\t" \ + "ldr r2, [%1, #48] \n\t" \ + "push {r0, r1, r2} \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "ldr r4, [%1, #36] \n\t" \ + "push {r0, r1, r2, r3, r4} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_arm_linux */ + +/* ------------------------ arm64-linux ------------------------ */ + +#if defined(PLAT_arm64_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS \ + "x0", "x1", "x2", "x3","x4", "x5", "x6", "x7", "x8", "x9", \ + "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", \ + "x18", "x19", "x20", "x30", \ + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", \ + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", \ + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", \ + "v26", "v27", "v28", "v29", "v30", "v31" + +/* x21 is callee-saved, so we can use it to save and restore SP around + the hidden call. */ +#define VALGRIND_ALIGN_STACK \ + "mov x21, sp\n\t" \ + "bic sp, x21, #15\n\t" +#define VALGRIND_RESTORE_STACK \ + "mov sp, x21\n\t" + +/* These CALL_FN_ macros assume that on arm64-linux, + sizeof(unsigned long) == 8. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #0x20 \n\t" \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1, #72] \n\t" \ + "str x8, [sp, #0] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #0x20 \n\t" \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1, #72] \n\t" \ + "str x8, [sp, #0] \n\t" \ + "ldr x8, [%1, #80] \n\t" \ + "str x8, [sp, #8] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #0x30 \n\t" \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1, #72] \n\t" \ + "str x8, [sp, #0] \n\t" \ + "ldr x8, [%1, #80] \n\t" \ + "str x8, [sp, #8] \n\t" \ + "ldr x8, [%1, #88] \n\t" \ + "str x8, [sp, #16] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11, \ + arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #0x30 \n\t" \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1, #72] \n\t" \ + "str x8, [sp, #0] \n\t" \ + "ldr x8, [%1, #80] \n\t" \ + "str x8, [sp, #8] \n\t" \ + "ldr x8, [%1, #88] \n\t" \ + "str x8, [sp, #16] \n\t" \ + "ldr x8, [%1, #96] \n\t" \ + "str x8, [sp, #24] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_arm64_linux */ + +/* ------------------------- s390x-linux ------------------------- */ + +#if defined(PLAT_s390x_linux) + +/* Similar workaround as amd64 (see above), but we use r11 as frame + pointer and save the old r11 in r7. r11 might be used for + argvec, therefore we copy argvec in r1 since r1 is clobbered + after the call anyway. */ +#if defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM) +# define __FRAME_POINTER \ + ,"d"(__builtin_dwarf_cfa()) +# define VALGRIND_CFI_PROLOGUE \ + ".cfi_remember_state\n\t" \ + "lgr 1,%1\n\t" /* copy the argvec pointer in r1 */ \ + "lgr 7,11\n\t" \ + "lgr 11,%2\n\t" \ + ".cfi_def_cfa r11, 0\n\t" +# define VALGRIND_CFI_EPILOGUE \ + "lgr 11, 7\n\t" \ + ".cfi_restore_state\n\t" +#else +# define __FRAME_POINTER +# define VALGRIND_CFI_PROLOGUE \ + "lgr 1,%1\n\t" +# define VALGRIND_CFI_EPILOGUE +#endif + +/* Nb: On s390 the stack pointer is properly aligned *at all times* + according to the s390 GCC maintainer. (The ABI specification is not + precise in this regard.) Therefore, VALGRIND_ALIGN_STACK and + VALGRIND_RESTORE_STACK are not defined here. */ + +/* These regs are trashed by the hidden call. Note that we overwrite + r14 in s390_irgen_noredir (VEX/priv/guest_s390_irgen.c) to give the + function a proper return address. All others are ABI defined call + clobbers. */ +#if defined(__VX__) || defined(__S390_VX__) +#define __CALLER_SAVED_REGS "0", "1", "2", "3", "4", "5", "14", \ + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", \ + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", \ + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", \ + "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" +#else +#define __CALLER_SAVED_REGS "0", "1", "2", "3", "4", "5", "14", \ + "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7" +#endif + +/* Nb: Although r11 is modified in the asm snippets below (inside + VALGRIND_CFI_PROLOGUE) it is not listed in the clobber section, for + two reasons: + (1) r11 is restored in VALGRIND_CFI_EPILOGUE, so effectively it is not + modified + (2) GCC will complain that r11 cannot appear inside a clobber section, + when compiled with -O -fno-omit-frame-pointer + */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 1, 0(1)\n\t" /* target->r1 */ \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "d" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +/* The call abi has the arguments in r2-r6 and stack */ +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1, arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1, arg2, arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1, arg2, arg3, arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1, arg2, arg3, arg4, arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-168\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,168\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-176\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,176\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-184\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,184\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8, arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-192\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "mvc 184(8,15), 72(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,192\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8, arg9, arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-200\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "mvc 184(8,15), 72(1)\n\t" \ + "mvc 192(8,15), 80(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,200\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8, arg9, arg10, arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + _argvec[11] = (unsigned long)arg11; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-208\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "mvc 184(8,15), 72(1)\n\t" \ + "mvc 192(8,15), 80(1)\n\t" \ + "mvc 200(8,15), 88(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,208\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8, arg9, arg10, arg11, arg12)\ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + _argvec[11] = (unsigned long)arg11; \ + _argvec[12] = (unsigned long)arg12; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-216\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "mvc 184(8,15), 72(1)\n\t" \ + "mvc 192(8,15), 80(1)\n\t" \ + "mvc 200(8,15), 88(1)\n\t" \ + "mvc 208(8,15), 96(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,216\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + + +#endif /* PLAT_s390x_linux */ + +/* ------------------------- mips32-linux ----------------------- */ + +#if defined(PLAT_mips32_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS "$2", "$3", "$4", "$5", "$6", \ +"$7", "$8", "$9", "$10", "$11", "$12", "$13", "$14", "$15", "$24", \ +"$25", "$31" + +/* These CALL_FN_ macros assume that on mips-linux, sizeof(unsigned + long) == 4. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16\n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $4, 4(%1) \n\t" /* arg1*/ \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 24\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 24 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 32\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "nop\n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 32 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 32\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 32 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 40\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 40 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 40\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 36(%1) \n\t" \ + "sw $4, 32($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 40 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 48\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 36(%1) \n\t" \ + "sw $4, 32($29) \n\t" \ + "lw $4, 40(%1) \n\t" \ + "sw $4, 36($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 48 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 48\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 36(%1) \n\t" \ + "sw $4, 32($29) \n\t" \ + "lw $4, 40(%1) \n\t" \ + "sw $4, 36($29) \n\t" \ + "lw $4, 44(%1) \n\t" \ + "sw $4, 40($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 48 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 56\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 36(%1) \n\t" \ + "sw $4, 32($29) \n\t" \ + "lw $4, 40(%1) \n\t" \ + "sw $4, 36($29) \n\t" \ + "lw $4, 44(%1) \n\t" \ + "sw $4, 40($29) \n\t" \ + "lw $4, 48(%1) \n\t" \ + "sw $4, 44($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 56 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_mips32_linux */ + +/* ------------------------- nanomips-linux -------------------- */ + +#if defined(PLAT_nanomips_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS "$t4", "$t5", "$a0", "$a1", "$a2", \ +"$a3", "$a4", "$a5", "$a6", "$a7", "$t0", "$t1", "$t2", "$t3", \ +"$t8","$t9", "$at" + +/* These CALL_FN_ macros assume that on mips-linux, sizeof(unsigned + long) == 4. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + "lw $a4,20(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + "lw $a4,20(%1)\n\t" \ + "lw $a5,24(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + "lw $a4,20(%1)\n\t" \ + "lw $a5,24(%1)\n\t" \ + "lw $a6,28(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + "lw $a4,20(%1)\n\t" \ + "lw $a5,24(%1)\n\t" \ + "lw $a6,28(%1)\n\t" \ + "lw $a7,32(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + "addiu $sp, $sp, -16 \n\t" \ + "lw $t9,36(%1) \n\t" \ + "sw $t9, 0($sp) \n\t" \ + "lw $t9, 0(%1) \n\t" \ + "lw $a0, 4(%1) \n\t" \ + "lw $a1, 8(%1) \n\t" \ + "lw $a2,12(%1) \n\t" \ + "lw $a3,16(%1) \n\t" \ + "lw $a4,20(%1) \n\t" \ + "lw $a5,24(%1) \n\t" \ + "lw $a6,28(%1) \n\t" \ + "lw $a7,32(%1) \n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0 \n\t" \ + "addiu $sp, $sp, 16 \n\t" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + "addiu $sp, $sp, -16 \n\t" \ + "lw $t9,36(%1) \n\t" \ + "sw $t9, 0($sp) \n\t" \ + "lw $t9,40(%1) \n\t" \ + "sw $t9, 4($sp) \n\t" \ + "lw $t9, 0(%1) \n\t" \ + "lw $a0, 4(%1) \n\t" \ + "lw $a1, 8(%1) \n\t" \ + "lw $a2,12(%1) \n\t" \ + "lw $a3,16(%1) \n\t" \ + "lw $a4,20(%1) \n\t" \ + "lw $a5,24(%1) \n\t" \ + "lw $a6,28(%1) \n\t" \ + "lw $a7,32(%1) \n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0 \n\t" \ + "addiu $sp, $sp, 16 \n\t" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + "addiu $sp, $sp, -16 \n\t" \ + "lw $t9,36(%1) \n\t" \ + "sw $t9, 0($sp) \n\t" \ + "lw $t9,40(%1) \n\t" \ + "sw $t9, 4($sp) \n\t" \ + "lw $t9,44(%1) \n\t" \ + "sw $t9, 8($sp) \n\t" \ + "lw $t9, 0(%1) \n\t" \ + "lw $a0, 4(%1) \n\t" \ + "lw $a1, 8(%1) \n\t" \ + "lw $a2,12(%1) \n\t" \ + "lw $a3,16(%1) \n\t" \ + "lw $a4,20(%1) \n\t" \ + "lw $a5,24(%1) \n\t" \ + "lw $a6,28(%1) \n\t" \ + "lw $a7,32(%1) \n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0 \n\t" \ + "addiu $sp, $sp, 16 \n\t" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + "addiu $sp, $sp, -16 \n\t" \ + "lw $t9,36(%1) \n\t" \ + "sw $t9, 0($sp) \n\t" \ + "lw $t9,40(%1) \n\t" \ + "sw $t9, 4($sp) \n\t" \ + "lw $t9,44(%1) \n\t" \ + "sw $t9, 8($sp) \n\t" \ + "lw $t9,48(%1) \n\t" \ + "sw $t9,12($sp) \n\t" \ + "lw $t9, 0(%1) \n\t" \ + "lw $a0, 4(%1) \n\t" \ + "lw $a1, 8(%1) \n\t" \ + "lw $a2,12(%1) \n\t" \ + "lw $a3,16(%1) \n\t" \ + "lw $a4,20(%1) \n\t" \ + "lw $a5,24(%1) \n\t" \ + "lw $a6,28(%1) \n\t" \ + "lw $a7,32(%1) \n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0 \n\t" \ + "addiu $sp, $sp, 16 \n\t" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_nanomips_linux */ + +/* ------------------------- mips64-linux ------------------------- */ + +#if defined(PLAT_mips64_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS "$2", "$3", "$4", "$5", "$6", \ +"$7", "$8", "$9", "$10", "$11", "$12", "$13", "$14", "$15", "$24", \ +"$25", "$31" + +/* These CALL_FN_ macros assume that on mips64-linux, + sizeof(long long) == 8. */ + +#define MIPS64_LONG2REG_CAST(x) ((long long)(long)x) + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[1]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + __asm__ volatile( \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[2]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" /* arg1*/ \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[3]; \ + volatile unsigned long long _res; \ + _argvec[0] = _orig.nraddr; \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[4]; \ + volatile unsigned long long _res; \ + _argvec[0] = _orig.nraddr; \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[5]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[6]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[7]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[8]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[9]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[10]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + _argvec[9] = MIPS64_LONG2REG_CAST(arg9); \ + __asm__ volatile( \ + "dsubu $29, $29, 8\n\t" \ + "ld $4, 72(%1)\n\t" \ + "sd $4, 0($29)\n\t" \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "daddu $29, $29, 8\n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[11]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + _argvec[9] = MIPS64_LONG2REG_CAST(arg9); \ + _argvec[10] = MIPS64_LONG2REG_CAST(arg10); \ + __asm__ volatile( \ + "dsubu $29, $29, 16\n\t" \ + "ld $4, 72(%1)\n\t" \ + "sd $4, 0($29)\n\t" \ + "ld $4, 80(%1)\n\t" \ + "sd $4, 8($29)\n\t" \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "daddu $29, $29, 16\n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[12]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + _argvec[9] = MIPS64_LONG2REG_CAST(arg9); \ + _argvec[10] = MIPS64_LONG2REG_CAST(arg10); \ + _argvec[11] = MIPS64_LONG2REG_CAST(arg11); \ + __asm__ volatile( \ + "dsubu $29, $29, 24\n\t" \ + "ld $4, 72(%1)\n\t" \ + "sd $4, 0($29)\n\t" \ + "ld $4, 80(%1)\n\t" \ + "sd $4, 8($29)\n\t" \ + "ld $4, 88(%1)\n\t" \ + "sd $4, 16($29)\n\t" \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "daddu $29, $29, 24\n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[13]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + _argvec[9] = MIPS64_LONG2REG_CAST(arg9); \ + _argvec[10] = MIPS64_LONG2REG_CAST(arg10); \ + _argvec[11] = MIPS64_LONG2REG_CAST(arg11); \ + _argvec[12] = MIPS64_LONG2REG_CAST(arg12); \ + __asm__ volatile( \ + "dsubu $29, $29, 32\n\t" \ + "ld $4, 72(%1)\n\t" \ + "sd $4, 0($29)\n\t" \ + "ld $4, 80(%1)\n\t" \ + "sd $4, 8($29)\n\t" \ + "ld $4, 88(%1)\n\t" \ + "sd $4, 16($29)\n\t" \ + "ld $4, 96(%1)\n\t" \ + "sd $4, 24($29)\n\t" \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "daddu $29, $29, 32\n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#endif /* PLAT_mips64_linux */ + +/* ------------------------------------------------------------------ */ +/* ARCHITECTURE INDEPENDENT MACROS for CLIENT REQUESTS. */ +/* */ +/* ------------------------------------------------------------------ */ + +/* Some request codes. There are many more of these, but most are not + exposed to end-user view. These are the public ones, all of the + form 0x1000 + small_number. + + Core ones are in the range 0x00000000--0x0000ffff. The non-public + ones start at 0x2000. +*/ + +/* These macros are used by tools -- they must be public, but don't + embed them into other programs. */ +#define VG_USERREQ_TOOL_BASE(a,b) \ + ((unsigned int)(((a)&0xff) << 24 | ((b)&0xff) << 16)) +#define VG_IS_TOOL_USERREQ(a, b, v) \ + (VG_USERREQ_TOOL_BASE(a,b) == ((v) & 0xffff0000)) + +/* !! ABIWARNING !! ABIWARNING !! ABIWARNING !! ABIWARNING !! + This enum comprises an ABI exported by Valgrind to programs + which use client requests. DO NOT CHANGE THE NUMERIC VALUES OF THESE + ENTRIES, NOR DELETE ANY -- add new ones at the end of the most + relevant group. */ +typedef + enum { VG_USERREQ__RUNNING_ON_VALGRIND = 0x1001, + VG_USERREQ__DISCARD_TRANSLATIONS = 0x1002, + + /* These allow any function to be called from the simulated + CPU but run on the real CPU. Nb: the first arg passed to + the function is always the ThreadId of the running + thread! So CLIENT_CALL0 actually requires a 1 arg + function, etc. */ + VG_USERREQ__CLIENT_CALL0 = 0x1101, + VG_USERREQ__CLIENT_CALL1 = 0x1102, + VG_USERREQ__CLIENT_CALL2 = 0x1103, + VG_USERREQ__CLIENT_CALL3 = 0x1104, + + /* Can be useful in regression testing suites -- eg. can + send Valgrind's output to /dev/null and still count + errors. */ + VG_USERREQ__COUNT_ERRORS = 0x1201, + + /* Allows the client program and/or gdbserver to execute a monitor + command. */ + VG_USERREQ__GDB_MONITOR_COMMAND = 0x1202, + + /* Allows the client program to change a dynamic command line + option. */ + VG_USERREQ__CLO_CHANGE = 0x1203, + + /* These are useful and can be interpreted by any tool that + tracks malloc() et al, by using vg_replace_malloc.c. */ + VG_USERREQ__MALLOCLIKE_BLOCK = 0x1301, + VG_USERREQ__RESIZEINPLACE_BLOCK = 0x130b, + VG_USERREQ__FREELIKE_BLOCK = 0x1302, + /* Memory pool support. */ + VG_USERREQ__CREATE_MEMPOOL = 0x1303, + VG_USERREQ__DESTROY_MEMPOOL = 0x1304, + VG_USERREQ__MEMPOOL_ALLOC = 0x1305, + VG_USERREQ__MEMPOOL_FREE = 0x1306, + VG_USERREQ__MEMPOOL_TRIM = 0x1307, + VG_USERREQ__MOVE_MEMPOOL = 0x1308, + VG_USERREQ__MEMPOOL_CHANGE = 0x1309, + VG_USERREQ__MEMPOOL_EXISTS = 0x130a, + + /* Allow printfs to valgrind log. */ + /* The first two pass the va_list argument by value, which + assumes it is the same size as or smaller than a UWord, + which generally isn't the case. Hence are deprecated. + The second two pass the vargs by reference and so are + immune to this problem. */ + /* both :: char* fmt, va_list vargs (DEPRECATED) */ + VG_USERREQ__PRINTF = 0x1401, + VG_USERREQ__PRINTF_BACKTRACE = 0x1402, + /* both :: char* fmt, va_list* vargs */ + VG_USERREQ__PRINTF_VALIST_BY_REF = 0x1403, + VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF = 0x1404, + + /* Stack support. */ + VG_USERREQ__STACK_REGISTER = 0x1501, + VG_USERREQ__STACK_DEREGISTER = 0x1502, + VG_USERREQ__STACK_CHANGE = 0x1503, + + /* Wine support */ + VG_USERREQ__LOAD_PDB_DEBUGINFO = 0x1601, + + /* Querying of debug info. */ + VG_USERREQ__MAP_IP_TO_SRCLOC = 0x1701, + + /* Disable/enable error reporting level. Takes a single + Word arg which is the delta to this thread's error + disablement indicator. Hence 1 disables or further + disables errors, and -1 moves back towards enablement. + Other values are not allowed. */ + VG_USERREQ__CHANGE_ERR_DISABLEMENT = 0x1801, + + /* Some requests used for Valgrind internal, such as + self-test or self-hosting. */ + /* Initialise IR injection */ + VG_USERREQ__VEX_INIT_FOR_IRI = 0x1901, + /* Used by Inner Valgrind to inform Outer Valgrind where to + find the list of inner guest threads */ + VG_USERREQ__INNER_THREADS = 0x1902 + } Vg_ClientRequest; + +#if !defined(__GNUC__) +# define __extension__ /* */ +#endif + + +/* Returns the number of Valgrinds this code is running under. That + is, 0 if running natively, 1 if running under Valgrind, 2 if + running under Valgrind which is running under another Valgrind, + etc. */ +#define RUNNING_ON_VALGRIND \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* if not */, \ + VG_USERREQ__RUNNING_ON_VALGRIND, \ + 0, 0, 0, 0, 0) \ + + +/* Discard translation of code in the range [_qzz_addr .. _qzz_addr + + _qzz_len - 1]. Useful if you are debugging a JITter or some such, + since it provides a way to make sure valgrind will retranslate the + invalidated area. Returns no value. */ +#define VALGRIND_DISCARD_TRANSLATIONS(_qzz_addr,_qzz_len) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DISCARD_TRANSLATIONS, \ + _qzz_addr, _qzz_len, 0, 0, 0) + +#define VALGRIND_INNER_THREADS(_qzz_addr) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__INNER_THREADS, \ + _qzz_addr, 0, 0, 0, 0) + + +/* These requests are for getting Valgrind itself to print something. + Possibly with a backtrace. This is a really ugly hack. The return value + is the number of characters printed, excluding the "**** " part at the + start and the backtrace (if present). */ + +#if defined(__GNUC__) || defined(__INTEL_COMPILER) && !defined(_MSC_VER) +/* Modern GCC will optimize the static routine out if unused, + and unused attribute will shut down warnings about it. */ +static int VALGRIND_PRINTF(const char *format, ...) + __attribute__((format(__printf__, 1, 2), __unused__)); +#endif +static int +#if defined(_MSC_VER) +__inline +#endif +VALGRIND_PRINTF(const char *format, ...) +{ +#if defined(NVALGRIND) + (void)format; + return 0; +#else /* NVALGRIND */ +#if defined(_MSC_VER) || defined(__MINGW64__) + uintptr_t _qzz_res; +#else + unsigned long _qzz_res; +#endif + va_list vargs; + va_start(vargs, format); +#if defined(_MSC_VER) || defined(__MINGW64__) + _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0, + VG_USERREQ__PRINTF_VALIST_BY_REF, + (uintptr_t)format, + (uintptr_t)&vargs, + 0, 0, 0); +#else + _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0, + VG_USERREQ__PRINTF_VALIST_BY_REF, + (unsigned long)format, + (unsigned long)&vargs, + 0, 0, 0); +#endif + va_end(vargs); + return (int)_qzz_res; +#endif /* NVALGRIND */ +} + +#if defined(__GNUC__) || defined(__INTEL_COMPILER) && !defined(_MSC_VER) +static int VALGRIND_PRINTF_BACKTRACE(const char *format, ...) + __attribute__((format(__printf__, 1, 2), __unused__)); +#endif +static int +#if defined(_MSC_VER) +__inline +#endif +VALGRIND_PRINTF_BACKTRACE(const char *format, ...) +{ +#if defined(NVALGRIND) + (void)format; + return 0; +#else /* NVALGRIND */ +#if defined(_MSC_VER) || defined(__MINGW64__) + uintptr_t _qzz_res; +#else + unsigned long _qzz_res; +#endif + va_list vargs; + va_start(vargs, format); +#if defined(_MSC_VER) || defined(__MINGW64__) + _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0, + VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF, + (uintptr_t)format, + (uintptr_t)&vargs, + 0, 0, 0); +#else + _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0, + VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF, + (unsigned long)format, + (unsigned long)&vargs, + 0, 0, 0); +#endif + va_end(vargs); + return (int)_qzz_res; +#endif /* NVALGRIND */ +} + + +/* These requests allow control to move from the simulated CPU to the + real CPU, calling an arbitrary function. + + Note that the current ThreadId is inserted as the first argument. + So this call: + + VALGRIND_NON_SIMD_CALL2(f, arg1, arg2) + + requires f to have this signature: + + Word f(Word tid, Word arg1, Word arg2) + + where "Word" is a word-sized type. + + Note that these client requests are not entirely reliable. For example, + if you call a function with them that subsequently calls printf(), + there's a high chance Valgrind will crash. Generally, your prospects of + these working are made higher if the called function does not refer to + any global variables, and does not refer to any libc or other functions + (printf et al). Any kind of entanglement with libc or dynamic linking is + likely to have a bad outcome, for tricky reasons which we've grappled + with a lot in the past. +*/ +#define VALGRIND_NON_SIMD_CALL0(_qyy_fn) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */, \ + VG_USERREQ__CLIENT_CALL0, \ + _qyy_fn, \ + 0, 0, 0, 0) + +#define VALGRIND_NON_SIMD_CALL1(_qyy_fn, _qyy_arg1) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */, \ + VG_USERREQ__CLIENT_CALL1, \ + _qyy_fn, \ + _qyy_arg1, 0, 0, 0) + +#define VALGRIND_NON_SIMD_CALL2(_qyy_fn, _qyy_arg1, _qyy_arg2) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */, \ + VG_USERREQ__CLIENT_CALL2, \ + _qyy_fn, \ + _qyy_arg1, _qyy_arg2, 0, 0) + +#define VALGRIND_NON_SIMD_CALL3(_qyy_fn, _qyy_arg1, _qyy_arg2, _qyy_arg3) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */, \ + VG_USERREQ__CLIENT_CALL3, \ + _qyy_fn, \ + _qyy_arg1, _qyy_arg2, \ + _qyy_arg3, 0) + + +/* Counts the number of errors that have been recorded by a tool. Nb: + the tool must record the errors with VG_(maybe_record_error)() or + VG_(unique_error)() for them to be counted. */ +#define VALGRIND_COUNT_ERRORS \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + 0 /* default return */, \ + VG_USERREQ__COUNT_ERRORS, \ + 0, 0, 0, 0, 0) + +/* Several Valgrind tools (Memcheck, Massif, Helgrind, DRD) rely on knowing + when heap blocks are allocated in order to give accurate results. This + happens automatically for the standard allocator functions such as + malloc(), calloc(), realloc(), memalign(), new, new[], free(), delete, + delete[], etc. + + But if your program uses a custom allocator, this doesn't automatically + happen, and Valgrind will not do as well. For example, if you allocate + superblocks with mmap() and then allocates chunks of the superblocks, all + Valgrind's observations will be at the mmap() level and it won't know that + the chunks should be considered separate entities. In Memcheck's case, + that means you probably won't get heap block overrun detection (because + there won't be redzones marked as unaddressable) and you definitely won't + get any leak detection. + + The following client requests allow a custom allocator to be annotated so + that it can be handled accurately by Valgrind. + + VALGRIND_MALLOCLIKE_BLOCK marks a region of memory as having been allocated + by a malloc()-like function. For Memcheck (an illustrative case), this + does two things: + + - It records that the block has been allocated. This means any addresses + within the block mentioned in error messages will be + identified as belonging to the block. It also means that if the block + isn't freed it will be detected by the leak checker. + + - It marks the block as being addressable and undefined (if 'is_zeroed' is + not set), or addressable and defined (if 'is_zeroed' is set). This + controls how accesses to the block by the program are handled. + + 'addr' is the start of the usable block (ie. after any + redzone), 'sizeB' is its size. 'rzB' is the redzone size if the allocator + can apply redzones -- these are blocks of padding at the start and end of + each block. Adding redzones is recommended as it makes it much more likely + Valgrind will spot block overruns. `is_zeroed' indicates if the memory is + zeroed (or filled with another predictable value), as is the case for + calloc(). + + VALGRIND_MALLOCLIKE_BLOCK should be put immediately after the point where a + heap block -- that will be used by the client program -- is allocated. + It's best to put it at the outermost level of the allocator if possible; + for example, if you have a function my_alloc() which calls + internal_alloc(), and the client request is put inside internal_alloc(), + stack traces relating to the heap block will contain entries for both + my_alloc() and internal_alloc(), which is probably not what you want. + + For Memcheck users: if you use VALGRIND_MALLOCLIKE_BLOCK to carve out + custom blocks from within a heap block, B, that has been allocated with + malloc/calloc/new/etc, then block B will be *ignored* during leak-checking + -- the custom blocks will take precedence. + + VALGRIND_FREELIKE_BLOCK is the partner to VALGRIND_MALLOCLIKE_BLOCK. For + Memcheck, it does two things: + + - It records that the block has been deallocated. This assumes that the + block was annotated as having been allocated via + VALGRIND_MALLOCLIKE_BLOCK. Otherwise, an error will be issued. + + - It marks the block as being unaddressable. + + VALGRIND_FREELIKE_BLOCK should be put immediately after the point where a + heap block is deallocated. + + VALGRIND_RESIZEINPLACE_BLOCK informs a tool about reallocation. For + Memcheck, it does four things: + + - It records that the size of a block has been changed. This assumes that + the block was annotated as having been allocated via + VALGRIND_MALLOCLIKE_BLOCK. Otherwise, an error will be issued. + + - If the block shrunk, it marks the freed memory as being unaddressable. + + - If the block grew, it marks the new area as undefined and defines a red + zone past the end of the new block. + + - The V-bits of the overlap between the old and the new block are preserved. + + VALGRIND_RESIZEINPLACE_BLOCK should be put after allocation of the new block + and before deallocation of the old block. + + In many cases, these three client requests will not be enough to get your + allocator working well with Memcheck. More specifically, if your allocator + writes to freed blocks in any way then a VALGRIND_MAKE_MEM_UNDEFINED call + will be necessary to mark the memory as addressable just before the zeroing + occurs, otherwise you'll get a lot of invalid write errors. For example, + you'll need to do this if your allocator recycles freed blocks, but it + zeroes them before handing them back out (via VALGRIND_MALLOCLIKE_BLOCK). + Alternatively, if your allocator reuses freed blocks for allocator-internal + data structures, VALGRIND_MAKE_MEM_UNDEFINED calls will also be necessary. + + Really, what's happening is a blurring of the lines between the client + program and the allocator... after VALGRIND_FREELIKE_BLOCK is called, the + memory should be considered unaddressable to the client program, but the + allocator knows more than the rest of the client program and so may be able + to safely access it. Extra client requests are necessary for Valgrind to + understand the distinction between the allocator and the rest of the + program. + + Ignored if addr == 0. +*/ +#define VALGRIND_MALLOCLIKE_BLOCK(addr, sizeB, rzB, is_zeroed) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MALLOCLIKE_BLOCK, \ + addr, sizeB, rzB, is_zeroed, 0) + +/* See the comment for VALGRIND_MALLOCLIKE_BLOCK for details. + Ignored if addr == 0. +*/ +#define VALGRIND_RESIZEINPLACE_BLOCK(addr, oldSizeB, newSizeB, rzB) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__RESIZEINPLACE_BLOCK, \ + addr, oldSizeB, newSizeB, rzB, 0) + +/* See the comment for VALGRIND_MALLOCLIKE_BLOCK for details. + Ignored if addr == 0. +*/ +#define VALGRIND_FREELIKE_BLOCK(addr, rzB) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__FREELIKE_BLOCK, \ + addr, rzB, 0, 0, 0) + +/* Create a memory pool. */ +#define VALGRIND_CREATE_MEMPOOL(pool, rzB, is_zeroed) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CREATE_MEMPOOL, \ + pool, rzB, is_zeroed, 0, 0) + +/* Create a memory pool with some flags specifying extended behaviour. + When flags is zero, the behaviour is identical to VALGRIND_CREATE_MEMPOOL. + + The flag VALGRIND_MEMPOOL_METAPOOL specifies that the pieces of memory + associated with the pool using VALGRIND_MEMPOOL_ALLOC will be used + by the application as superblocks to dole out MALLOC_LIKE blocks using + VALGRIND_MALLOCLIKE_BLOCK. In other words, a meta pool is a "2 levels" + pool : first level is the blocks described by VALGRIND_MEMPOOL_ALLOC. + The second level blocks are described using VALGRIND_MALLOCLIKE_BLOCK. + Note that the association between the pool and the second level blocks + is implicit : second level blocks will be located inside first level + blocks. It is necessary to use the VALGRIND_MEMPOOL_METAPOOL flag + for such 2 levels pools, as otherwise valgrind will detect overlapping + memory blocks, and will abort execution (e.g. during leak search). + + Such a meta pool can also be marked as an 'auto free' pool using the flag + VALGRIND_MEMPOOL_AUTO_FREE, which must be OR-ed together with the + VALGRIND_MEMPOOL_METAPOOL. For an 'auto free' pool, VALGRIND_MEMPOOL_FREE + will automatically free the second level blocks that are contained + inside the first level block freed with VALGRIND_MEMPOOL_FREE. + In other words, calling VALGRIND_MEMPOOL_FREE will cause implicit calls + to VALGRIND_FREELIKE_BLOCK for all the second level blocks included + in the first level block. + Note: it is an error to use the VALGRIND_MEMPOOL_AUTO_FREE flag + without the VALGRIND_MEMPOOL_METAPOOL flag. +*/ +#define VALGRIND_MEMPOOL_AUTO_FREE 1 +#define VALGRIND_MEMPOOL_METAPOOL 2 +#define VALGRIND_CREATE_MEMPOOL_EXT(pool, rzB, is_zeroed, flags) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CREATE_MEMPOOL, \ + pool, rzB, is_zeroed, flags, 0) + +/* Destroy a memory pool. */ +#define VALGRIND_DESTROY_MEMPOOL(pool) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DESTROY_MEMPOOL, \ + pool, 0, 0, 0, 0) + +/* Associate a piece of memory with a memory pool. */ +#define VALGRIND_MEMPOOL_ALLOC(pool, addr, size) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_ALLOC, \ + pool, addr, size, 0, 0) + +/* Disassociate a piece of memory from a memory pool. */ +#define VALGRIND_MEMPOOL_FREE(pool, addr) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_FREE, \ + pool, addr, 0, 0, 0) + +/* Disassociate any pieces outside a particular range. */ +#define VALGRIND_MEMPOOL_TRIM(pool, addr, size) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_TRIM, \ + pool, addr, size, 0, 0) + +/* Resize and/or move a piece associated with a memory pool. */ +#define VALGRIND_MOVE_MEMPOOL(poolA, poolB) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MOVE_MEMPOOL, \ + poolA, poolB, 0, 0, 0) + +/* Resize and/or move a piece associated with a memory pool. */ +#define VALGRIND_MEMPOOL_CHANGE(pool, addrA, addrB, size) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_CHANGE, \ + pool, addrA, addrB, size, 0) + +/* Return 1 if a mempool exists, else 0. */ +#define VALGRIND_MEMPOOL_EXISTS(pool) \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0, \ + VG_USERREQ__MEMPOOL_EXISTS, \ + pool, 0, 0, 0, 0) + +/* Mark a piece of memory as being a stack. Returns a stack id. + start is the lowest addressable stack byte, end is the highest + addressable stack byte. */ +#define VALGRIND_STACK_REGISTER(start, end) \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0, \ + VG_USERREQ__STACK_REGISTER, \ + start, end, 0, 0, 0) + +/* Unmark the piece of memory associated with a stack id as being a + stack. */ +#define VALGRIND_STACK_DEREGISTER(id) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STACK_DEREGISTER, \ + id, 0, 0, 0, 0) + +/* Change the start and end address of the stack id. + start is the new lowest addressable stack byte, end is the new highest + addressable stack byte. */ +#define VALGRIND_STACK_CHANGE(id, start, end) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STACK_CHANGE, \ + id, start, end, 0, 0) + +/* Load PDB debug info for Wine PE image_map. */ +#define VALGRIND_LOAD_PDB_DEBUGINFO(fd, ptr, total_size, delta) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__LOAD_PDB_DEBUGINFO, \ + fd, ptr, total_size, delta, 0) + +/* Map a code address to a source file name and line number. buf64 + must point to a 64-byte buffer in the caller's address space. The + result will be dumped in there and is guaranteed to be zero + terminated. If no info is found, the first byte is set to zero. */ +#define VALGRIND_MAP_IP_TO_SRCLOC(addr, buf64) \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0, \ + VG_USERREQ__MAP_IP_TO_SRCLOC, \ + addr, buf64, 0, 0, 0) + +/* Disable error reporting for this thread. Behaves in a stack like + way, so you can safely call this multiple times provided that + VALGRIND_ENABLE_ERROR_REPORTING is called the same number of times + to re-enable reporting. The first call of this macro disables + reporting. Subsequent calls have no effect except to increase the + number of VALGRIND_ENABLE_ERROR_REPORTING calls needed to re-enable + reporting. Child threads do not inherit this setting from their + parents -- they are always created with reporting enabled. */ +#define VALGRIND_DISABLE_ERROR_REPORTING \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CHANGE_ERR_DISABLEMENT, \ + 1, 0, 0, 0, 0) + +/* Re-enable error reporting, as per comments on + VALGRIND_DISABLE_ERROR_REPORTING. */ +#define VALGRIND_ENABLE_ERROR_REPORTING \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CHANGE_ERR_DISABLEMENT, \ + -1, 0, 0, 0, 0) + +/* Execute a monitor command from the client program. + If a connection is opened with GDB, the output will be sent + according to the output mode set for vgdb. + If no connection is opened, output will go to the log output. + Returns 1 if command not recognised, 0 otherwise. */ +#define VALGRIND_MONITOR_COMMAND(command) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0, VG_USERREQ__GDB_MONITOR_COMMAND, \ + command, 0, 0, 0, 0) + + +/* Change the value of a dynamic command line option. + Note that unknown or not dynamically changeable options + will cause a warning message to be output. */ +#define VALGRIND_CLO_CHANGE(option) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CLO_CHANGE, \ + option, 0, 0, 0, 0) + + +#undef PLAT_x86_darwin +#undef PLAT_amd64_darwin +#undef PLAT_x86_win32 +#undef PLAT_amd64_win64 +#undef PLAT_x86_linux +#undef PLAT_amd64_linux +#undef PLAT_ppc32_linux +#undef PLAT_ppc64be_linux +#undef PLAT_ppc64le_linux +#undef PLAT_arm_linux +#undef PLAT_s390x_linux +#undef PLAT_mips32_linux +#undef PLAT_mips64_linux +#undef PLAT_nanomips_linux +#undef PLAT_x86_solaris +#undef PLAT_amd64_solaris + +#endif /* __VALGRIND_H */ diff --git a/tools/README.md b/tools/README.md index 5f915d510f86c..b940d378320b3 100644 --- a/tools/README.md +++ b/tools/README.md @@ -24,11 +24,16 @@ Build system pieces: * [setup_helpers](setup_helpers) - Helper code for searching for third-party dependencies on the user system. * [build_pytorch_libs.py](build_pytorch_libs.py) - cross-platform script that - builds all of the constituent libraries of PyTorch, + builds all of the constituent libraries of PyTorch, but not the PyTorch Python extension itself. * [build_libtorch.py](build_libtorch.py) - Script for building libtorch, a standalone C++ library without Python support. This build script is tested in CI. +* [fast_nvcc](fast_nvcc) - Mostly-transparent wrapper over nvcc that + parallelizes compilation when used to build CUDA files for multiple + architectures at once. + * [fast_nvcc.py](fast_nvcc/fast_nvcc.py) - Python script, entrypoint to the + fast nvcc wrapper. Developer tools which you might find useful: @@ -52,8 +57,6 @@ Important if you want to run on AMD GPU: Tools which are only situationally useful: -* [aten_mirror.sh](aten_mirror.sh) - Mirroring script responsible - for keeping https://github.com/zdevito/ATen up-to-date. * [docker](docker) - Dockerfile for running (but not developing) PyTorch, using the official conda binary distribution. Context: https://github.com/pytorch/pytorch/issues/1619 diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 026293a9281a8..9d4fa54c93b33 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -131,6 +131,20 @@ def is_hip_clang(): sources.write(line) print("%s updated" % gloo_cmake_file) +gloo_cmake_file = "third_party/gloo/cmake/Modules/Findrccl.cmake" +if os.path.exists(gloo_cmake_file): + do_write = False + with open(gloo_cmake_file, "r") as sources: + lines = sources.readlines() + newlines = [line.replace('RCCL_LIBRARY', 'RCCL_LIBRARY_PATH') for line in lines] + if lines == newlines: + print("%s skipped" % gloo_cmake_file) + else: + with open(gloo_cmake_file, "w") as sources: + for line in newlines: + sources.write(line) + print("%s updated" % gloo_cmake_file) + hipify_python.hipify( project_directory=proj_dir, output_directory=out_dir, diff --git a/tools/aten_mirror.sh b/tools/aten_mirror.sh deleted file mode 100755 index 6c787bbda568c..0000000000000 --- a/tools/aten_mirror.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/sh - -# This script is run by a cronjob managed by @zdevito -# which mirrors the ATen-specific directories of PyTorch -# to zdevito/ATen, for ease of use of projects that wish -# to depend solely on ATen. -# -# See also .travis.aten.yml, which is the Travis configuration -# for the ATen project (and ensures ATen is separately -# buildable.) - -if [[ -z "$EXTRACTED_REPO" ]]; then - echo "Need to set envvar EXTRACTED_REPO" - exit 1 -fi -if [[ -z "$FULL_REPO" ]]; then - echo "Need to set envvar FULL_REPO" - exit 1 -fi -rm -rf aten-export-repo -git clone $EXTRACTED_REPO aten-export-repo -cd aten-export-repo -git config user.name "Zach DeVito" -git config user.email "zdevito@fb.com" -git remote add fullrepo $FULL_REPO -git fetch fullrepo -git checkout -b temporary-split-branch fullrepo/master -# Cribbed from https://stackoverflow.com/questions/2982055/detach-many-subdirectories-into-a-new-separate-git-repository -# and https://stackoverflow.com/questions/42355621/git-filter-branch-moving-a-folder-with-index-filter-does-not-work -git filter-branch -f --index-filter 'git rm --cached -qr --ignore-unmatch -- . && git reset -q $GIT_COMMIT -- aten cmake third_party/tbb third_party/catch third_party/cpuinfo && (git ls-files -s | sed "s-.travis.aten.yml-.travis.yml-" | sed "s-.gitmodules.aten-.gitmodules-" | git update-index --index-info)' -git checkout master -git merge temporary-split-branch -git push diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 3bc0199ebf47f..aaac1d98a64f8 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -86,7 +86,7 @@ # e.g., it is used by _cudnn_rnn # # If you need a complex expression, e.g., with local variables, -# write a _backward function in tools/autograd/templates/Functions.cpp +# write a _backward function in torch/csrc/autograd/FunctionsManual.cpp # and invoke it from here. By the way, go read # https://github.com/zdevito/ATen/issues/163; this describes an # important hazard that occurs when porting backwards from Python to C++ @@ -162,33 +162,33 @@ self: grad * self.sgn() - name: acos(Tensor self) -> Tensor - self: grad * -((-self * self + 1).rsqrt()) + self: grad * -((-self * self + 1).rsqrt()).conj() - name: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor - self: grad - other: maybe_multiply(grad, alpha) + self: handle_r_to_c(self.scalar_type(), grad) + other: handle_r_to_c(other.scalar_type(), maybe_multiply(grad, alpha.conj())) - name: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor - self: grad + self: handle_r_to_c(self.scalar_type(), grad) - name: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor - self: maybe_multiply(grad, beta) - batch1: grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) }).bmm(batch2.transpose(1, 2)) * alpha - batch2: batch1.transpose(1, 2).bmm(grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) })) * alpha + self: maybe_multiply(grad, beta.conj()) + batch1: grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) }).bmm(batch2.transpose(1, 2).conj()) * alpha.conj() + batch2: batch1.transpose(1, 2).conj().bmm(grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) })) * alpha.conj() - name: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor - self: grad - tensor1: grad * value / tensor2 - tensor2: -grad * value * tensor1 / (tensor2 * tensor2) + self: handle_r_to_c(self.scalar_type(), grad) + tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (value / tensor2).conj()) + tensor2: handle_r_to_c(tensor2.scalar_type(), -grad * (value * tensor1 / (tensor2 * tensor2)).conj()) - name: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor - self: grad - tensor1: grad * tensor2 * value - tensor2: grad * tensor1 * value + self: handle_r_to_c(self.scalar_type(), grad) + tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (tensor2 * value).conj()) + tensor2: handle_r_to_c(tensor2.scalar_type(), grad * (tensor1 * value).conj()) - name: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor - self: maybe_multiply(grad, beta) - mat1: mm_mat1_backward(grad, mat2, mat1, alpha) + self: maybe_multiply(grad, beta.conj()) + mat1: mm_mat1_backward(grad, mat2, mat1.sizes(), mat1.strides(), alpha) mat2: mm_mat2_backward(grad, mat1, mat2.sizes(), mat2.strides(), alpha) - name: _sparse_addmm(Tensor self, Tensor sparse, Tensor dense, *, Scalar beta=1, Scalar alpha=1) -> Tensor @@ -197,9 +197,9 @@ dense: mm_mat2_backward(grad, sparse, dense.sizes(), dense.strides(), alpha) - name: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor - self: maybe_multiply(grad, beta) - mat: grad.ger(vec) * alpha - vec: mat.t().mv(grad) * alpha + self: maybe_multiply(grad, beta.conj()) + mat: grad.ger(vec.conj()) * alpha.conj() + vec: mat.t().conj().mv(grad) * alpha.conj() - name: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor self: maybe_multiply(grad, beta) @@ -213,7 +213,7 @@ self: grad - name: angle(Tensor self) -> Tensor - self: grad.to(self.scalar_type()) * (self*Scalar(c10::complex{0.0, 1.0})).conj() / self.abs().pow(2) + self: angle_backward(grad, self) # The four items below are necessary because TensorIterator doesn't work on # Variables (codegen does not unwrap the input Tensor for all() and any() ). @@ -230,19 +230,19 @@ self: not_implemented("all") - name: acosh(Tensor self) -> Tensor - self: grad * (self.pow(2) - 1).rsqrt() + self: grad * (self.pow(2) - 1).rsqrt().conj() - name: acosh_(Tensor(a!) self) -> Tensor(a!) self: not_implemented("inplace version of acosh") - name: asinh(Tensor self) -> Tensor - self: grad * (self.pow(2) + 1).rsqrt() + self: grad * (self.pow(2) + 1).rsqrt().conj() - name: asinh_(Tensor(a!) self) -> Tensor(a!) self: not_implemented("inplace version of asinh") - name: atanh(Tensor self) -> Tensor - self: grad * 1 / (1 - self.pow(2)) + self: grad * 1 / (1 - self.pow(2)).conj() - name: atanh_(Tensor(a!) self) -> Tensor(a!) self: not_implemented("inplace version of atanh") @@ -251,18 +251,18 @@ self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset) - name: asin(Tensor self) -> Tensor - self: grad * (-self * self + 1).rsqrt() + self: grad * (-self * self + 1).rsqrt().conj() - name: atan(Tensor self) -> Tensor - self: grad / (self * self + 1) + self: grad / (self * self + 1).conj() - name: atan2(Tensor self, Tensor other) -> Tensor self, other: atan2_backward(grad, self, other, grad_input_mask) - name: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor - self: maybe_multiply(grad, beta) - batch1: grad.bmm(batch2.transpose(1, 2)) * alpha - batch2: batch1.transpose(1, 2).bmm(grad) * alpha + self: maybe_multiply(grad, beta.conj()) + batch1: grad.bmm(batch2.transpose(1, 2).conj()) * alpha.conj() + batch2: batch1.transpose(1, 2).conj().bmm(grad) * alpha.conj() - name: bernoulli(Tensor self, *, Generator? generator=None) -> Tensor self: zeros_like(grad) @@ -275,8 +275,8 @@ self: zeros_like(grad) - name: bmm(Tensor self, Tensor mat2) -> Tensor - self: grad.bmm(mat2.transpose(1, 2)) - mat2: self.transpose(1, 2).bmm(grad) + self: grad.bmm(mat2.transpose(1, 2).conj()) + mat2: self.transpose(1, 2).conj().bmm(grad) - name: _bmm(Tensor self, Tensor mat2, *, bool deterministic=False) -> Tensor self: at::_bmm(grad, mat2.transpose(1, 2), deterministic) @@ -294,6 +294,9 @@ - name: cholesky(Tensor self, bool upper=False) -> Tensor self: cholesky_backward(grad, upper, result) +- name: linalg_cholesky(Tensor self) -> Tensor + self: cholesky_backward(grad, false, result) + - name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor self, input2: cholesky_solve_backward(grad, self, input2, result, upper) @@ -340,6 +343,13 @@ - name: _conj(Tensor self) -> Tensor self: grad.conj() +- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor + self: copysign_tensor_self_backward(grad, self, result) + other: zeros_like(other) + +- name: copysign.Scalar(Tensor self, Scalar other) -> Tensor + self: copysign_tensor_self_backward(grad, self, result) + - name: cos(Tensor self) -> Tensor self: grad * -self.sin().conj() @@ -404,12 +414,12 @@ self: div_tensor_self_backward(grad, at::scalar_to_tensor(other), self.scalar_type()) - name: dot(Tensor self, Tensor tensor) -> Tensor - self: grad * tensor - tensor: grad * self + self: handle_r_to_c(self.scalar_type(), grad * tensor.conj()) + tensor: handle_r_to_c(tensor.scalar_type(), grad * self.conj()) - name: vdot(Tensor self, Tensor other) -> Tensor - self: 'not_implemented("vdot: self")' - other: 'not_implemented("vdot: other")' + self: handle_r_to_c(self.scalar_type(), grad.conj() * other) + other: handle_r_to_c(other.scalar_type(), grad * self) - name: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor) self: _fused_dropout_backward(grad, result1, p) @@ -434,7 +444,7 @@ self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad - name: exp(Tensor self) -> Tensor - self: grad * result + self: grad * result.conj() - name: exp2(Tensor self) -> Tensor self: grad * result * M_LN2 @@ -498,8 +508,8 @@ self: not_implemented("geqrf") - name: ger(Tensor self, Tensor vec2) -> Tensor - self: grad.mv(vec2) - vec2: grad.t().mv(self) + self: grad.mv(vec2.conj()) + vec2: grad.t().mv(self.conj()) - name: indices(Tensor(a) self) -> Tensor(a) output_differentiability: [False] @@ -540,8 +550,16 @@ - name: i0(Tensor self) -> Tensor self: not_implemented("i0") +- name: igamma(Tensor self, Tensor other) -> Tensor + self: 'not_implemented("igamma: input")' + other: grad * exp((self - 1) * log(other) - other - lgamma(self)) + +- name: igammac(Tensor self, Tensor other) -> Tensor + self: 'not_implemented("igammac: input")' + other: -grad * exp((self - 1) * log(other) - other - lgamma(self)) + - name: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor - self: index_backward(zeros_like(self), indices, grad) + self: index_backward(grad.new_zeros(self.sizes(), self.options()), indices, grad) indices: TensorList() - name: index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) @@ -576,7 +594,10 @@ index: non_differentiable - name: inverse(Tensor self) -> Tensor - self: -at::matmul(result.transpose(-2, -1), at::matmul(grad, result.transpose(-2, -1))) + self: -at::matmul(result.conj().transpose(-2, -1), at::matmul(grad, result.conj().transpose(-2, -1))) + +- name: linalg_inv(Tensor self) -> Tensor + self: -at::matmul(result.conj().transpose(-2, -1), at::matmul(grad, result.conj().transpose(-2, -1))) - name: isnan(Tensor self) -> Tensor self: non_differentiable @@ -610,16 +631,16 @@ self: grad * polygamma(n + 1, self) - name: log(Tensor self) -> Tensor - self: grad.div(self) + self: grad.div(self.conj()) - name: log10(Tensor self) -> Tensor - self: grad / (self * 2.3025850929940456) + self: grad / (self.conj() * 2.3025850929940456) - name: log1p(Tensor self) -> Tensor self: log1p_backward(grad, self) - name: log2(Tensor self) -> Tensor - self: grad / (self * 0.6931471805599453) + self: grad / (self.conj() * 0.6931471805599453) - name: logaddexp(Tensor self, Tensor other) -> Tensor self: grad / (1 + exp(other - self)) @@ -629,6 +650,16 @@ self: grad / (1 + pow(2, other - self)) other: grad / (1 + pow(2, self - other)) +- name: xlogy.Tensor(Tensor self, Tensor other) -> Tensor + self: grad * at::xlogy((self != 0), other) + other: grad * self / other + +- name: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor + other: grad * self / other + +- name: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor + self: grad * at::xlogy((self != 0), other) + - name: logdet(Tensor self) -> Tensor self: logdet_backward(grad, self, result) @@ -695,6 +726,9 @@ - name: median(Tensor self) -> Tensor self: evenly_distribute_backward(grad, self, result) +- name: nanmedian(Tensor self) -> Tensor + self: evenly_distribute_backward(grad, self, result) + # This is in theory incorrect in the following case: # sorted list: [..., a, b, b, ..., b, b, c, ...] with median = b and the value # | at middle position of the @@ -712,6 +746,9 @@ - name: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim) +- name: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim) + - name: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim) @@ -729,7 +766,7 @@ self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim) - name: mm(Tensor self, Tensor mat2) -> Tensor - self: mm_mat1_backward(grad, mat2, self, 1) + self: mm_mat1_backward(grad, mat2, self.sizes(), self.strides(), 1) mat2: mm_mat2_backward(grad, self, mat2.sizes(), mat2.strides(), 1) - name: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) @@ -743,12 +780,15 @@ self: mul_tensor_backward(grad, at::scalar_to_tensor(other), self.scalar_type()) - name: mv(Tensor self, Tensor vec) -> Tensor - self: grad.ger(vec) - vec: self.t().mv(grad) + self: grad.ger(vec.conj()) + vec: self.conj().t().mv(grad) - name: mvlgamma(Tensor self, int p) -> Tensor self: mvlgamma_backward(grad, self, p) +- name: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor + self: grad * at::isfinite(self) + - name: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple()" @@ -757,8 +797,8 @@ save_mean: not_implemented("native_batch_norm_backward save_mean") save_invstd: not_implemented("native_batch_norm_backward save_invstd") -- name: native_layer_norm(Tensor input, Tensor? weight, Tensor? bias, int M, int N, float eps) -> (Tensor, Tensor, Tensor) - input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_layer_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, M, N, eps, grad_input_mask) : (grads[0].defined() ? native_layer_norm_backward(grads[0].is_contiguous() ? grads[0] : grads[0].contiguous(), input, result1, result2, weight, M, N, grad_input_mask) : std::tuple())" +- name: native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_layer_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, normalized_shape, eps, grad_input_mask) : (grads[0].defined() ? native_layer_norm_backward(grads[0].is_contiguous() ? grads[0] : grads[0].contiguous(), input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple())" - name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, int N, int C, int HxW, int group, float eps) -> (Tensor, Tensor, Tensor) input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward(grads[0].is_contiguous() ? grads[0] : grads[0].contiguous(), input.is_contiguous() ? input : input.contiguous(), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple())" @@ -859,8 +899,8 @@ index: non_differentiable source: grad.take(index) -- name: qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R) - self: qr_backward(grads, self, some, Q, R) +- name: linalg_qr(Tensor self, str mode='reduced') -> (Tensor Q, Tensor R) + self: linalg_qr_backward(grads, self, mode, Q, R) - name: rad2deg(Tensor self) -> Tensor self: rad2deg_backward(grad) @@ -875,7 +915,7 @@ self: zeros_like(grad) - name: reciprocal(Tensor self) -> Tensor - self: -grad * result * result + self: -grad * (result * result).conj() - name: remainder.Scalar(Tensor self, Scalar other) -> Tensor self: grad @@ -900,7 +940,7 @@ self: zeros_like(grad) - name: rsqrt(Tensor self) -> Tensor - self: -0.5 * grad * result.pow(3) + self: -0.5 * grad * result.pow(3).conj() - name: scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) self: grad.clone().scatter_(dim, index, 0) @@ -934,21 +974,32 @@ - name: sin(Tensor self) -> Tensor self: grad * self.cos().conj() +- name: sinc(Tensor self) -> Tensor + self: grad * ((M_PI * self * (M_PI * self).cos() - (M_PI * self).sin()) / (M_PI * self * self)).conj() + - name: sinh(Tensor self) -> Tensor self: grad * self.cosh().conj() -- name: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a) - self: slice_backward(grad, self.sizes(), dim, start, end, step) +- name: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a) + self: slice_backward_wrapper(grad, self.sizes(), dim, start, end, step) - name: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) self: slogdet_backward(grad, self, sign, logabsdet) output_differentiability: [false, true] +- name: linalg_slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) + self: slogdet_backward(grad, self, sign, logabsdet) + output_differentiability: [false, true] + - name: solve(Tensor self, Tensor A) -> (Tensor solution, Tensor LU) self: solve_backward_self(grad, self, A) A: solve_backward_A(grad, self, A, solution) -- name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) +- name: linalg_solve(Tensor input, Tensor other) -> Tensor + input: solve_backward_A(grad, other, input, result) + other: solve_backward_self(grad, other, input) + +- name: sort(Tensor self, int dim=-1, bool descending=False, bool stable=False) -> (Tensor values, Tensor indices) self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true) output_differentiability: [True, False] @@ -965,7 +1016,7 @@ self: split_with_sizes_backward(grads, split_sizes, dim, self.sizes(), self.options()) - name: sqrt(Tensor self) -> Tensor - self: grad / (2 * result) + self: grad / (2 * result.conj()) - name: squeeze(Tensor(a) self) -> Tensor(a) self: unsqueeze_to(grad, self.sizes()) @@ -986,11 +1037,11 @@ self: std_backward(result, grad, self, dim, unbiased, keepdim) - name: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor - self: grad - other: -grad * alpha + self: handle_r_to_c(self.scalar_type(), grad) + other: handle_r_to_c(other.scalar_type(), -grad * alpha.conj()) - name: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor - self: grad + self: handle_r_to_c(self.scalar_type(), grad) - name: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor self: -grad * alpha @@ -1011,12 +1062,18 @@ - name: nansum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim) -- name: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) +- name: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor U, Tensor S, Tensor V) self: svd_backward(grads, self, some, compute_uv, U, S, V) - name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors) self: symeig_backward(grads, self, eigenvectors, upper, eigenvalues, eigenvectors_return) +- name: linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors) + self: symeig_backward(grads, self, /*eigenvectors=*/true, /*upper=*/true, eigenvalues, eigenvectors) + +- name: linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor + self: non_differentiable + - name: t(Tensor(a) self) -> Tensor(a) self: grad.t() @@ -1037,7 +1094,7 @@ index: non_differentiable - name: tan(Tensor self) -> Tensor - self: grad * (1 + result.pow(2)) + self: grad * (1 + result.pow(2)).conj() - name: tanh(Tensor self) -> Tensor self: tanh_backward(grad, result) @@ -1067,13 +1124,13 @@ - name: trunc(Tensor self) -> Tensor self: zeros_like(grad) -- name: to_dense(Tensor self) -> Tensor +- name: to_dense(Tensor self, ScalarType? dtype=None) -> Tensor self: to_dense_backward(grad, self) - name: to_sparse(Tensor self) -> Tensor self: grad.to_dense() -- name: to_mkldnn(Tensor self) -> Tensor +- name: to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor self: to_mkldnn_backward(grad, self) - name: unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a) @@ -1086,8 +1143,25 @@ self: zeros_like(grad) - name: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor) + output_differentiability: [True, False] self: not_implemented("_unique") +- name: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, False, False] + self: not_implemented("unique_dim") + +- name: unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, False, False] + self: not_implemented("unique_consecutive") + +- name: unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, False, False] + self: not_implemented("unique_dim_consecutive") + +- name: _unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, False, False] + self: not_implemented("_unique2") + - name: _unsafe_view(Tensor self, int[] size) -> Tensor self: grad.reshape(self.sizes()) @@ -1106,6 +1180,9 @@ - name: view(Tensor(a) self, int[] size) -> Tensor(a) self: grad.reshape(self.sizes()) +- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) + output_differentiability: [False] + - name: view_as_real(Tensor(a) self) -> Tensor(a) self: at::view_as_complex(grad.contiguous()) # gx0 + 1j * gx1 @@ -1173,7 +1250,7 @@ weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse) - name: embedding_dense_backward(Tensor grad_output, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor - grad_output: embedding_dense_double_backward(grad, indices) + grad_output: embedding_dense_double_backward(grad, indices, padding_idx) indices: non_differentiable - name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor) @@ -1221,9 +1298,9 @@ self: nll_loss2d_backward(grad, self, target, weight, reduction, ignore_index, total_weight) target: non_differentiable -- name: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor - self: smooth_l1_loss_backward(grad, self, target, reduction) - target: smooth_l1_loss_backward(grad, target, self, reduction) +- name: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor + self: smooth_l1_loss_backward(grad, self, target, reduction, beta) + target: smooth_l1_loss_backward(grad, target, self, reduction, beta) - name: soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor self: soft_margin_loss_backward(grad, self, target, reduction) @@ -1239,10 +1316,16 @@ self: "GradMode::is_enabled() ? infinitely_differentiable_silu_backward(grad, self) : silu_backward(grad, self)" - name: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor - self: elu_backward(grad, alpha, scale, input_scale, result) + self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ false, self) + +- name: elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!) + self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ true, result) - name: celu(Tensor self, Scalar alpha=1.0) -> Tensor - self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), result) + self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ false, self) + +- name: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!) + self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result) - name: gelu(Tensor self) -> Tensor self: "GradMode::is_enabled() ? infinitely_differentiable_gelu_backward(grad, self) : gelu_backward(grad, self)" @@ -1296,6 +1379,10 @@ - name: _sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor self: _sparse_softmax_backward_data(grad, result, dim, self) +- name: _sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor + self: sparse_sparse_matmul_backward(grad, self, other, 0) + other: sparse_sparse_matmul_backward(grad, self, other, 1) + - name: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor self: softplus_backward(grad, self, beta, threshold, result) @@ -1491,9 +1578,9 @@ grad_output: avg_pool3d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) self: zeros_like(self) -- name: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, Tensor output) -> Tensor - grad_output: elu_backward(grad, alpha, scale, input_scale, output) - output: grad * grad_output * input_scale * (output < 0).type_as(grad) +- name: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor + grad_output: elu_backward(grad, alpha, scale, input_scale, is_result, self_or_result) + self_or_result: elu_double_backward(grad, grad_output, alpha, scale, input_scale, is_result, self_or_result) - name: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor grad_output: max_pool_double_backward(grad, indices, 2) @@ -1517,9 +1604,9 @@ target: zeros_like(grad) - name: l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor - grad_output: l1_loss_double_backward_grad_output(grad, self, target, reduction) - self: zeros_like(grad) - target: zeros_like(grad) + grad_output: l1_loss_double_backward_grad_output(grad, grad_output, self, target, reduction) + self: l1_loss_double_backward(grad, grad_output, self, target, reduction) + target: l1_loss_double_backward(grad, grad_output, target, target, reduction) - name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor grad_output: log_sigmoid_backward(grad, self, buffer) @@ -1589,10 +1676,10 @@ grad_output: replication_pad3d(grad, padding) self: zeros_like(self) -- name: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor - grad_output: smooth_l1_loss_double_backward_grad_output(grad, grad_output, self, target, reduction) - self: smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction) - target: -smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction) +- name: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor + grad_output: smooth_l1_loss_double_backward_grad_output(grad, grad_output, self, target, reduction, beta) + self: smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta) + target: -smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta) - name: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, Tensor output) -> Tensor grad_output: softplus_backward(grad, self, beta, threshold, output) @@ -1661,8 +1748,8 @@ output: grad * grad_output * (-2 * output + 1) - name: tanh_backward(Tensor grad_output, Tensor output) -> Tensor - grad_output: tanh_backward(grad, output) - output: -2 * output * grad * grad_output + grad_output: tanh_backward(grad, output.conj()) + output: grad.conj() * (-2 * output.conj() * grad_output) # cudnn - name: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) @@ -1727,13 +1814,21 @@ # Only frst three of _cudnn_rnn outputs can have gradients. # _cudnn_rnn outputs: (output, hy, cy, reserve, weight_buf) -- name: _cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) +- name: _cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, int hidden_size, int proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) dropout_state: non_differentiable output_differentiability: [True, True, True, False, False] - input, hx, cx, weight: "_cudnn_rnn_backward(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)" + input, hx, cx, weight: "_cudnn_rnn_backward(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)" -- name: _cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) +- name: _cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) dropout_state: non_differentiable + input: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + weight: not_implemented_list("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + hx: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + cx: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + output: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + grad_output: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + grad_hy: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + grad_cy: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) # miopen @@ -1779,8 +1874,14 @@ grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector(padding.size(), 0), groups, false, false, false, false, grad_input_mask) # fft -- name: _fft_with_size.norm_modes(Tensor self, int signal_ndim, bool complex_input, bool complex_output, bool inverse, int[] checked_signal_sizes, int normalization, bool onesided, int[] output_sizes) -> Tensor - self: fft_backward(self, grad, signal_ndim, complex_input, complex_output, inverse, checked_signal_sizes, normalization, onesided, output_sizes) +- name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor + self: fft_r2c_backward(grad, dim, normalization, onesided, self.size(dim.back())) + +- name: _fft_c2r(Tensor self, int[] dim, int normalization, int last_dim_size) -> Tensor + self: fft_c2r_backward(grad, dim, normalization) + +- name: _fft_c2c(Tensor self, int[] dim, int normalization, bool forward) -> Tensor + self: _fft_c2c(grad, dim, normalization, !forward) - name: unbind.int(Tensor(a) self, int dim=0) -> Tensor(a)[] self: unbind_backward(grads, dim) diff --git a/tools/autograd/gen_annotated_fn_args.py b/tools/autograd/gen_annotated_fn_args.py index 7b4b0ece8da60..943d9adab4a0c 100644 --- a/tools/autograd/gen_annotated_fn_args.py +++ b/tools/autograd/gen_annotated_fn_args.py @@ -1,10 +1,10 @@ """ For procedural tests needed for __torch_function__, we use this function to export method names and signatures as needed by the tests in -test/test_overrides.py. +test/test_overrides.py. -python -m tools.autograd.gen_autograd \ - build/aten/src/ATen/Declarations.yaml \ +python -m tools.autograd.gen_annotated_fn_args \ + aten/src/ATen/native/native_functions.yaml \ $OUTPUT_DIR \ tools/autograd @@ -13,64 +13,69 @@ torch/testing/_internal/generated """ -from .utils import write, CodeTemplate -from .gen_python_functions import ( - get_py_nn_functions, - get_py_torch_functions, - get_py_variable_methods, - op_name, -) +from collections import defaultdict +import argparse +import os import textwrap -from .gen_autograd import load_aten_declarations - -def gen_annotated(aten_path, out, template_path): - declarations = load_aten_declarations(aten_path) - annotated_args = [] - for func in recurse_dict(get_py_torch_functions(declarations)): - annotated_args.append(process_func("torch._C._VariableFunctions", func)) - - for func in recurse_dict(get_py_nn_functions(declarations)): - annotated_args.append(process_func("torch._C._nn", func)) - - for func in recurse_dict(get_py_variable_methods(declarations)): - annotated_args.append(process_func("torch.Tensor", func)) - - annotated_args = textwrap.indent("\n".join(annotated_args), " ") - env = {"annotated_args": annotated_args} - PY_ANNOTATED_ARGS = CodeTemplate.from_file(template_path + '/templates/annotated_fn_args.py') - write(out, 'annotated_fn_args.py', PY_ANNOTATED_ARGS, env) - - -def process_func(namespace, func): - args = func["arguments"] - out_args = [] - for arg in args: - if 'default' in arg or arg.get('kwarg_only', False) or arg.get('output', False): +from typing import Dict, List, Any + +from tools.codegen.gen import with_native_function, parse_native_yaml, FileManager +from tools.codegen.model import * +import tools.codegen.api.python as python +from .gen_python_functions import should_generate_py_binding, is_py_torch_function, is_py_nn_function, is_py_variable_method + +def gen_annotated(native_yaml_path: str, out: str, autograd_dir: str) -> None: + native_functions = parse_native_yaml(native_yaml_path) + mappings = ( + (is_py_torch_function, 'torch._C._VariableFunctions'), + (is_py_nn_function, 'torch._C._nn'), + (is_py_variable_method, 'torch.Tensor'), + ) + annotated_args: List[str] = [] + for pred, namespace in mappings: + groups: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list) + for f in native_functions: + if not should_generate_py_binding(f) or not pred(f): + continue + groups[f.func.name.name].append(f) + for group in groups.values(): + for f in group: + annotated_args.append(f'{namespace}.{gen_annotated_args(f)}') + + template_path = os.path.join(autograd_dir, 'templates') + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write_with_template('annotated_fn_args.py', 'annotated_fn_args.py', lambda: { + 'annotated_args': textwrap.indent('\n'.join(annotated_args), ' '), + }) + +@with_native_function +def gen_annotated_args(f: NativeFunction) -> str: + out_args: List[Dict[str, Any]] = [] + for arg in f.func.arguments.flat_positional: + if arg.default is not None: continue - out_args.append({k: arg[k] for k in ('name', 'simple_type', 'size') if k in arg}) - - return f"{namespace}.{op_name(func)}: {out_args!r}," - + out_arg: Dict[str, Any] = {} + out_arg['name'] = arg.name + out_arg['simple_type'] = python.argument_type_str(arg.type, simple_type=True) + size = python.argument_type_size(arg.type) + if size: + out_arg['size'] = size + out_args.append(out_arg) -def recurse_dict(d): - for e in d.values(): - for i in e: - yield i + return f'{f.func.name.name}: {repr(out_args)},' - -def main(): +def main() -> None: parser = argparse.ArgumentParser( description='Generate annotated_fn_args script') - parser.add_argument('declarations', metavar='DECL', - help='path to Declarations.yaml') + parser.add_argument('native_functions', metavar='NATIVE', + help='path to native_functions.yaml') parser.add_argument('out', metavar='OUT', help='path to output directory') parser.add_argument('autograd', metavar='AUTOGRAD', help='path to template directory') args = parser.parse_args() - gen_annotated(args.declarations, args.out, args.autograd) - + gen_annotated(args.native_functions, args.out, args.autograd) if __name__ == '__main__': main() diff --git a/tools/autograd/gen_autograd.py b/tools/autograd/gen_autograd.py index 82d908de61804..b930aca504df6 100644 --- a/tools/autograd/gen_autograd.py +++ b/tools/autograd/gen_autograd.py @@ -22,12 +22,8 @@ # import argparse -import copy import os -import yaml -import re -from collections import defaultdict -from .utils import YamlLoader, split_name_params, op_name_without_overload +from tools.codegen.selective_build.selector import SelectiveBuilder # See NOTE [ Autograd View Variables ] in variable.h for details. # If you update list VIEW_FUNCTIONS or RETURNS_VIEWS_OF_INPUT, @@ -87,230 +83,82 @@ RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()).union({ 'chunk', 'detach', 'contiguous', 'reshape', 'reshape_as', 'expand_as', 'view_as', 'real', 'imag', 'narrow', 'movedim', + 'tensor_split', 'swapdims', 'swapaxes' }) -def format_return_type(returns): - if len(returns) == 0: - return 'void' - elif len(returns) == 1: - return returns[0]['type'] - else: - return_types = [r['type'] for r in returns] - return 'std::tuple<{}>'.format(','.join(return_types)) - - -def get_simple_type(arg): - simple_type = arg['type'] - simple_type = simple_type.replace(' &', '').replace('const ', '') - simple_type = simple_type.replace('Generator *', 'Generator') - - opt_match = re.match(r'c10::optional<(.+)>', simple_type) - if opt_match: - simple_type = '{}?'.format(opt_match.group(1)) - return simple_type - -def has_tensoroptions_argument(declaration): - for argument in declaration['arguments']: - if 'TensorOptions' == argument['dynamic_type']: - return True - return False - -def process_schema_order_arg(schema_order_arg): - if schema_order_arg == 'dtype': - return 'optTypeMetaToScalarType(options.dtype_opt())' - elif schema_order_arg == 'layout': - return 'options.layout_opt()' - elif schema_order_arg == 'device': - return 'options.device_opt()' - elif schema_order_arg == 'pin_memory': - return 'options.pinned_memory_opt()' - elif schema_order_arg == 'memory_format': - return 'c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)' - else: - return schema_order_arg - - -def load_aten_declarations(path): - with open(path, 'r') as f: - declarations = yaml.load(f, Loader=YamlLoader) - - # enrich declarations with additional information - selected_declarations = [] - for declaration in declarations: - if declaration.get('deprecated'): - continue - - for arg in declaration['arguments']: - arg['simple_type'] = get_simple_type(arg) - for ret in declaration['returns']: - ret['simple_type'] = get_simple_type(ret) - - declaration['formals'] = [arg['type'] + ' ' + arg['name'] - for arg in declaration['arguments']] - declaration['schema_order_formals'] = [arg['type'] + ' ' + arg['name'] - for arg in declaration['schema_order_arguments']] - declaration['args'] = [arg['name'] for arg in declaration['arguments']] - declaration['schema_order_args'] = [arg['name'] for arg in declaration['schema_order_arguments']] - if has_tensoroptions_argument(declaration): - declaration['schema_order_args'] = [process_schema_order_arg(arg) for arg in declaration['schema_order_args']] - declaration['api_name'] = declaration['name'] - if declaration.get('overload_name'): - declaration['type_wrapper_name'] = "{}_{}".format( - declaration['name'], declaration['overload_name']) - else: - declaration['type_wrapper_name'] = declaration['name'] - declaration['operator_name_with_overload'] = declaration['schema_string'].split('(')[0] - declaration['unqual_operator_name_with_overload'] = declaration['operator_name_with_overload'].split('::')[1] - declaration['return_type'] = format_return_type(declaration['returns']) - - declaration['base_name'] = declaration['name'] - selected_declarations.append(declaration) - - return selected_declarations - - -def load_deprecated_signatures(aten_decls, deprecated_path): - def group_declarations_by_signature(): - d = defaultdict(list) - for declaration in aten_decls: - name = declaration['name'] - base_name = name[:-1] if declaration['inplace'] else name - simple_types = [arg['simple_type'] for arg in declaration['arguments']] - signature = '{}({})'.format(base_name, ', '.join(simple_types)) - d[signature].append(declaration) - return d - - with open(deprecated_path, 'r') as f: - deprecated_defs = yaml.load(f, Loader=YamlLoader) - declarations = [] - declarations_by_signature = group_declarations_by_signature() - - def get_signature(name, params, call_args): - # create a mapping of parameter name to parameter type - types = dict([param.split(' ')[::-1] for param in params if param != '*']) - # if the name in the call is not in the parameter list, assume it's - # a literal Scalar - rearranged_types = [types.get(arg, 'Scalar') for arg in call_args] - return '{}({})'.format(name, ', '.join(rearranged_types)) - - for deprecated in deprecated_defs: - aten_name, call_args = split_name_params(deprecated['aten']) - name, params = split_name_params(deprecated['name']) - signature = get_signature(aten_name, params, call_args) - - for declaration in declarations_by_signature[signature]: - declaration = copy.deepcopy(declaration) - declaration['deprecated'] = True - declaration['call_args'] = call_args - - call_arg_to_idx = {arg: i for i, arg in enumerate(call_args)} - original_args = declaration['arguments'] - - # Create an arguments list that uses the types from the original - # ATen declaration, but the ordering and parameter names from - # the deprecated overload. Any default parameter values from the - # original ATen declaration are ignored. - arguments = [] - kwarg_only = False - for param in params: - if param == '*': - kwarg_only = True - continue - _, param_name = param.split(' ') - original = original_args[call_arg_to_idx[param_name]] - arguments.append({ - 'name': param_name, - 'kwarg_only': kwarg_only, - 'type': original['type'], - 'simple_type': original['simple_type'], - 'dynamic_type': original['dynamic_type'], - 'output': original.get('output', False), - }) - declaration['arguments'] = arguments - declarations.append(declaration) - return declarations - - -def gen_autograd(aten_path, out, autograd_dir, disable_autograd=False, selected_op_list=None): - full_aten_decls = load_aten_declarations(aten_path) - - def filter_decls(aten_decls, selected_op_list): - if selected_op_list is None: - return aten_decls - return [decl for decl in aten_decls if op_name_without_overload(decl) in selected_op_list] - - aten_decls = filter_decls(full_aten_decls, selected_op_list) - +def gen_autograd( + aten_path: str, + native_functions_path: str, + out: str, + autograd_dir: str, + operator_selector: SelectiveBuilder, + disable_autograd: bool = False, +) -> None: # Parse and load derivatives.yaml from .load_derivatives import load_derivatives - autograd_functions = load_derivatives( - os.path.join(autograd_dir, 'derivatives.yaml'), full_aten_decls) + differentiability_infos = load_derivatives( + os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path) template_path = os.path.join(autograd_dir, 'templates') # Generate VariableType.h/cpp + from .gen_trace_type import gen_trace_type + from .gen_variable_type import gen_variable_type if not disable_autograd: - from .gen_variable_type import gen_variable_type - gen_variable_type(out, aten_decls, template_path) + gen_variable_type(out, native_functions_path, differentiability_infos, template_path, operator_selector) + + # operator filter not applied as tracing sources are excluded in selective build + gen_trace_type(out, native_functions_path, template_path) # Generate Functions.h/cpp from .gen_autograd_functions import gen_autograd_functions_lib gen_autograd_functions_lib( - out, autograd_functions, template_path) + out, differentiability_infos, template_path) # Generate variable_factories.h from .gen_variable_factories import gen_variable_factories - # Some non-selectable ops (e.g. prim ops) need factory methods so we pass in `full_aten_decls` here. - gen_variable_factories(out, full_aten_decls, template_path) - - -def gen_autograd_python(aten_path, out, autograd_dir): + gen_variable_factories(out, native_functions_path, template_path) - # TODO Deduplicate these four variable assignments - aten_decls = load_aten_declarations(aten_path) - - # Parse and load derivatives.yaml +def gen_autograd_python( + aten_path: str, + native_functions_path: str, + out: str, + autograd_dir: str, +) -> None: from .load_derivatives import load_derivatives - autograd_functions = load_derivatives( - os.path.join(autograd_dir, 'derivatives.yaml'), aten_decls) + differentiability_infos = load_derivatives( + os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path) template_path = os.path.join(autograd_dir, 'templates') - # Load deprecated signatures - deprecated = load_deprecated_signatures( - aten_decls, os.path.join(autograd_dir, 'deprecated.yaml')) - # Generate Functions.h/cpp from .gen_autograd_functions import gen_autograd_functions_python gen_autograd_functions_python( - out, autograd_functions, template_path) + out, differentiability_infos, template_path) # Generate Python bindings from . import gen_python_functions - gen_python_functions.gen_py_variable_methods( - out, aten_decls + deprecated, template_path) - gen_python_functions.gen_py_torch_functions( - out, aten_decls + deprecated, template_path) - gen_python_functions.gen_py_nn_functions( - out, aten_decls, template_path) - gen_python_functions.gen_py_fft_functions( - out, aten_decls, template_path) - gen_python_functions.gen_py_linalg_functions( - out, aten_decls, template_path) + deprecated_path = os.path.join(autograd_dir, 'deprecated.yaml') + gen_python_functions.gen( + out, native_functions_path, deprecated_path, template_path) -def main(): +def main() -> None: parser = argparse.ArgumentParser( description='Generate autograd C++ files script') parser.add_argument('declarations', metavar='DECL', help='path to Declarations.yaml') + parser.add_argument('native_functions', metavar='NATIVE', + help='path to native_functions.yaml') parser.add_argument('out', metavar='OUT', help='path to output directory') parser.add_argument('autograd', metavar='AUTOGRAD', help='path to autograd directory') args = parser.parse_args() - gen_autograd(args.declarations, args.out, args.autograd) + gen_autograd(args.declarations, args.native_functions, + args.out, args.autograd, + SelectiveBuilder.get_nop_selector()) if __name__ == '__main__': diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index 0cac0e0b9168e..4724b99a87429 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -4,11 +4,17 @@ # Functions.h/cpp: subclasses of autograd::Node # python_functions.h/cpp: Python bindings for the above classes # -import os import re -from .utils import nested_dict, CodeTemplate, write from .gen_autograd import VIEW_FUNCTIONS -from .utils import IDENT_REGEX + +from typing import List, Sequence, Tuple, Optional + +from tools.codegen.api.autograd import * +from tools.codegen.api.types import * +from tools.codegen.code_template import CodeTemplate +from tools.codegen.gen import FileManager +from tools.codegen.model import * +from tools.codegen.utils import * FUNCTION_DECLARATION = CodeTemplate("""\ struct TORCH_API ${op} : public ${superclass} { @@ -84,142 +90,152 @@ # TODO: This is probably not exhaustive, but it's a start UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS - -def gen_autograd_functions_lib(out, autograd_functions, template_path): - gen_autograd_functions(out, autograd_functions, template_path, "Functions") - - -def gen_autograd_functions_python(out, autograd_functions, template_path): - gen_autograd_functions(out, autograd_functions, template_path, "python_functions") - - -def gen_autograd_functions(out, autograd_functions, template_path, file_basename): +def gen_autograd_functions_lib( + out: str, + differentiability_infos: Sequence[DifferentiabilityInfo], + template_path: str, +) -> None: + gen_autograd_functions(out, differentiability_infos, template_path, "Functions") + +def gen_autograd_functions_python( + out: str, + differentiability_infos: Sequence[DifferentiabilityInfo], + template_path: str, +) -> None: + gen_autograd_functions(out, differentiability_infos, template_path, "python_functions") + +def gen_autograd_functions( + out: str, + differentiability_infos: Sequence[DifferentiabilityInfo], + template_path: str, + file_basename: str, +) -> None: """Functions.h and Functions.cpp body These contain the auto-generated subclasses of torch::autograd::Node for each every differentiable torch function. """ - function_definitions = [] - function_declarations = [] - py_function_initializers = [] - - for func in autograd_functions: - env = process_function(func) - - function_declarations.append(FUNCTION_DECLARATION.substitute(env)) - function_definitions.append(FUNCTION_DEFINITION.substitute(env)) - py_function_initializers.append(PY_FUNCTION_DEFINITION.substitute(env)) - - top_env = { - 'autograd_function_definitions': function_definitions, - 'autograd_function_declarations': function_declarations, - 'py_function_initializers': py_function_initializers, - } - - for suffix in [".h", ".cpp"]: - f = file_basename + suffix - templated_output = CodeTemplate.from_file(os.path.join(template_path, f)) - write(out, f, templated_output, top_env) - - -def process_function(func): - env = {} - saved_variables = [] - release_variables = [] - saved_list_sizes = [] - unpack = [] - asserts = [] - - env['compute_index_ranges'] = [] - for arg in func['args_with_derivatives']: - if arg['type'] == 'TensorList': - size = '{}_size_'.format(arg['name']) - saved_list_sizes.append('size_t {}_size_;'.format(arg['name'])) + # only create an autograd function if we are actually going to calculate a derivative + infos = list(filter(lambda info: info.args_with_derivatives, differentiability_infos)) + declarations = list(map(lambda f: process_function(f, FUNCTION_DECLARATION), infos)) + definitions = list(map(lambda f: process_function(f, FUNCTION_DEFINITION), infos)) + py_function_initializers = list(map(lambda f: process_function(f, PY_FUNCTION_DEFINITION), infos)) + + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + for suffix in ['.h', '.cpp']: + fname = file_basename + suffix + fm.write_with_template(fname, fname, lambda: { + 'generated_comment': '@' + f'generated from {fm.template_dir}/' + fname, + 'autograd_function_declarations': declarations, + 'autograd_function_definitions': definitions, + 'py_function_initializers': py_function_initializers, + }) + +def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str: + saved_variables: List[str] = [] + release_variables: List[str] = [] + saved_list_sizes: List[str] = [] + unpack: List[str] = [] + asserts: List[str] = [] + compute_index_ranges: List[str] = [] + + for arg in info.args_with_derivatives: + if arg.type == 'TensorList' or arg.type == 'const c10::List> &': + size = f'{arg.name}_size_' + saved_list_sizes.append(f'size_t {arg.name}_size_;') else: size = '1' - env['compute_index_ranges'].append('auto {}_ix = gen.range({});'.format(arg['name'], size)) - - def save_arg(arg, is_output): - name = arg['name'] - - if arg['type'] == 'Tensor' or arg['type'] == 'c10::optional' or arg['type'] == 'c10::optional&' or \ - (arg['type'] == 'Scalar' and is_output): - saved_variables.append('SavedVariable {}_;'.format(name)) - release_variables.append('{}_.reset_data();'.format(name)) - release_variables.append('{}_.reset_grad_function();'.format(name)) + compute_index_ranges.append(f'auto {arg.name}_ix = gen.range({size});') + + def save_var(var: SavedAttribute, is_output: bool) -> None: + name = var.name + if var.type == 'Tensor' or var.type == 'c10::optional' or var.type == 'c10::optional&' or \ + (var.type == 'Scalar' and is_output): + saved_variables.append(f'SavedVariable {name}_;') + release_variables.append(f'{name}_.reset_data();') + release_variables.append(f'{name}_.reset_grad_function();') ptr = 'shared_from_this()' if is_output else '' - unpack.append('auto {} = {}_.unpack({});'.format(name, name, ptr)) - elif arg['type'] == 'TensorList': - saved_variables.append('std::vector {}_;'.format(name)) - saved_variables.append('bool {}_released_ = false;'.format(name)) + unpack.append(f'auto {name} = {name}_.unpack({ptr});') + elif var.type == 'TensorList': + saved_variables.append(f'std::vector {name}_;') + saved_variables.append(f'bool {name}_released_ = false;') + # Just clear() is sufficient, we don't need to loop and clear each variable. + # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well. + release_variables.append(f'{name}_.clear();') + release_variables.append(f'{name}_released_ = true;') + unpack.append(f'auto {name} = unpack_list({name}_);') + asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);') + elif var.type == 'c10::List>': + saved_variables.append(f'std::vector {name}_;') + saved_variables.append(f'bool {name}_released_ = false;') # Just clear() is sufficient, we don't need to loop and clear each variable. # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well. - release_variables.append('{}_.clear();'.format(name)) - release_variables.append('{}_released_ = true;'.format(name)) - unpack.append('auto {} = unpack_list({}_);'.format(name, name)) - asserts.append('TORCH_CHECK(!{}_released_, ERR_BACKWARD_TWICE);'.format(name)) - elif arg['type'] == 'IntArrayRef': - saved_variables.append('std::vector {};'.format(name)) - elif arg['type'] == 'c10::optional': - saved_variables.append('c10::OptionalArray {};'.format(name)) - elif arg['type'] == 'c10::optional>': - saved_variables.append('c10::OptionalArray {};'.format(name)) - elif arg['type'] == 'int64_t': - saved_variables.append('{} {} = 0;'.format(arg['type'], name)) + release_variables.append(f'{name}_.clear();') + release_variables.append(f'{name}_released_ = true;') + unpack.append(f'auto {name} = unpack_opt_list({name}_);') + asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);') + elif var.type == 'IntArrayRef': + saved_variables.append(f'std::vector {name};') + elif var.type == 'c10::optional': + saved_variables.append(f'c10::OptionalArray {name};') + elif var.type == 'c10::optional>': + saved_variables.append(f'c10::OptionalArray {name};') + elif var.type == 'int64_t': + saved_variables.append(f'{var.type} {name} = 0;') else: - saved_variables.append('{} {};'.format(arg['type'], name)) + saved_variables.append(f'{var.type} {name};') - for arg in func['saved_inputs']: - save_arg(arg, is_output=False) - for arg in func['saved_outputs']: - save_arg(arg, is_output=True) - env['saved_variables'] = saved_variables - env['release_variables'] = release_variables - env['saved_list_sizes'] = saved_list_sizes - env['asserts'] = asserts + for var in info.all_saved_inputs: + save_var(var, is_output=False) + for var in info.all_saved_outputs: + save_var(var, is_output=True) # lock the mutex when we release variables and in Node::apply to protect thread safety # see Note [Thread Safety on Autograd Node] if len(release_variables) > 0: - env['thread_lock'] = "std::lock_guard lock(mutex_);" + thread_lock = 'std::lock_guard lock(mutex_);' else: - env['thread_lock'] = '' + thread_lock = '' - if uses_retain_variables(func): - env['will_release_variables'] = WILL_RELEASE_VARIABLES.substitute() + if uses_retain_variables(info): + will_release_variables = WILL_RELEASE_VARIABLES.substitute() else: - env['will_release_variables'] = '' + will_release_variables = '' - body = [] + body: List[str] = [] - if uses_single_grad(func): + if uses_single_grad(info): body.append('auto& grad = grads[0];') - def emit_derivative(derivative, args_with_derivatives): - formula = derivative['formula'] - var_names = derivative['var_names'] + def emit_derivative( + derivative: Derivative, + args_with_derivatives: Sequence[Binding], + ) -> Tuple[bool, str]: + formula = derivative.formula + var_names = derivative.var_names if len(var_names) == 1: checks_any_grad_defined = False if 'not_implemented' not in formula: matching_args = [ arg for arg in args_with_derivatives - if ('name' in arg) and (arg['name'] == var_names[0])] + if arg.name == var_names[0]] if len(matching_args) == 1: # We can add undefined grad support if the input variable is a Tensor - if ('simple_type' in matching_args[0].keys()) and (matching_args[0]['simple_type'] == 'Tensor'): + arg = matching_args[0] + if isinstance(arg.argument, Argument) and str(arg.argument.type) == 'Tensor': formula = 'any_grad_defined ? (' + formula + ') : Tensor()' checks_any_grad_defined = True return (checks_any_grad_defined, DERIVATIVE_SINGLE.substitute(name=var_names[0], derivative=formula)) else: if 'grad_input_mask' in formula: - masks = ['should_compute_output({{ {}_ix }}),'.format(n) for n in var_names] + masks = [f'should_compute_output({{ {n}_ix }}),' for n in var_names] grad_input_mask = GRAD_INPUT_MASK.substitute(masks=masks, n=len(var_names)) else: grad_input_mask = '' - idx_ranges = ', '.join("{}_ix".format(n) for n in var_names) - copy_ranges = [] + idx_ranges = ', '.join(f'{n}_ix' for n in var_names) + copy_ranges: List[str] = [] for i, n in enumerate(var_names): copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i)) return False, DERIVATIVE_MULTI.substitute( @@ -229,37 +245,45 @@ def emit_derivative(derivative, args_with_derivatives): body.extend(unpack) need_any_grad_defined_var = False - for derivative in func['derivatives']: - checks_any_grad_defined, derivative_text = emit_derivative(derivative, func['args_with_derivatives']) + for derivative in info.derivatives: + checks_any_grad_defined, derivative_text = emit_derivative(derivative, info.args_with_derivatives) body.append(derivative_text) need_any_grad_defined_var |= checks_any_grad_defined # Since single-output derivative formulas need to check if grads are # defined, only perform the check once, before all the formulas if need_any_grad_defined_var: - body.insert(-len(func['derivatives']), + body.insert(-len(info.derivatives), 'bool any_grad_defined = any_variable_defined(grads);') - env['body'] = body - if func['name'] in UNTRACEABLE_FUNCTIONS: - env['superclass'] = 'Node' + if info.name in UNTRACEABLE_FUNCTIONS: + superclass = 'Node' else: - env['superclass'] = 'TraceableFunction' - return nested_dict(env, func) - - -def uses_ident(func, ident): - if func is None: + superclass = 'TraceableFunction' + + return template.substitute( + op=info.op, + compute_index_ranges=compute_index_ranges, + saved_variables=saved_variables, + release_variables=release_variables, + saved_list_sizes=saved_list_sizes, + asserts=asserts, + thread_lock=thread_lock, + will_release_variables=will_release_variables, + body=body, + superclass=superclass, + ) + +def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool: + if info is None: return False - for derivative in func['derivatives']: - formula = derivative['formula'] + for derivative in info.derivatives: + formula = derivative.formula if re.search(IDENT_REGEX.format(ident), formula): return True return False +def uses_retain_variables(info: Optional[DifferentiabilityInfo]) -> bool: + return uses_ident(info, 'retain_variables') -def uses_retain_variables(func): - return uses_ident(func, 'retain_variables') - - -def uses_single_grad(func): - return uses_ident(func, 'grad') +def uses_single_grad(info: Optional[DifferentiabilityInfo]) -> bool: + return uses_ident(info, 'grad') diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 995dff38030bf..0450983a8e419 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -31,11 +31,26 @@ # from collections import defaultdict +import itertools import re -from .gen_variable_type import should_trace -from .utils import write, is_tensor_method +import yaml + +from .gen_trace_type import should_trace from tools.codegen.code_template import CodeTemplate +from tools.codegen.api.types import * +from tools.codegen.api.python import * +from tools.codegen.gen import cpp_string, parse_native_yaml, with_native_function, FileManager +from tools.codegen.model import * +from tools.codegen.utils import * + +from typing import Dict, Optional, List, Tuple, Set, Sequence, Callable + +try: + # use faster C loader if available + from yaml import CLoader as Loader +except ImportError: + from yaml import Loader # type: ignore # # declarations blocklist @@ -63,10 +78,10 @@ 'copy_sparse_to_sparse_', 'copy_', 'numpy_T', # this needs to be an attribute in Python, not a function 'nonzero(_(out|numpy))?', - 'set_quantizer_', # return types not supported yet 'set_data', '.*_overrideable', # overrideable functions for backend extension - 'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad', 'set_' + 'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad', 'set_', + '_fw_primal' ] # These function signatures are not exposed to Python. Note that this signature @@ -78,779 +93,296 @@ 'div(Tensor, Scalar)', 'div_(Tensor, Scalar)', ] -NATIVE_NAMESPACE_MAPPING = { - "torch": "THPVariableFunctionsModule", - "torch.nn": "THPNNVariableFunctionsModule", - "torch.fft": "THPFFTVariableFunctionsModule", - "torch.linalg": "THPLinalgVariableFunctionsModule", -} - -def should_generate_python_binding(declaration): - name = declaration['name'] +@with_native_function +def should_generate_py_binding(f: NativeFunction) -> bool: + name = cpp.name(f.func) for pattern in SKIP_PYTHON_BINDINGS: if re.match('^' + pattern + '$', name): return False - simple_types = [arg['simple_type'] for arg in declaration['arguments']] - signature = '{}({})'.format(name, ', '.join(simple_types)) + args = ', '.join(argument_type_str(arg.type) + for arg in signature(f).arguments()) + sig = f'{name}({args})' for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES: - if pattern == signature: + if pattern == sig: return False return True -# -# top-level codegen functions, called from gen_autograd -# - -def get_py_variable_methods(declarations): - """ - Get declarations (grouped by name) which should be generated - as methods on Tensor. - """ - def should_bind(declaration): - return (should_generate_python_binding(declaration) and - not is_nn_module_function(declaration) and - is_tensor_method(declaration)) - - return group_declarations_by_op_name([d for d in declarations if should_bind(d)]) - - -def gen_py_variable_methods(out, declarations, template_path): - """ - Generate Tensor methods. - """ - PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp') - - py_variable_methods = get_py_variable_methods(declarations) - - env = create_python_bindings(py_variable_methods, is_python_method=True, module=None) - - write(out, 'python_variable_methods.cpp', PY_VARIABLE_METHODS_CPP, env) - - -def get_py_nn_functions(declarations): - """ - Get declarations (grouped by name) which should be generated - as functions in the "nn" module. - """ - def should_bind(declaration): - return (should_generate_python_binding(declaration) and - is_nn_module_function(declaration)) - - return group_declarations_by_op_name([d for d in declarations if should_bind(d)]) - - -def gen_py_nn_functions(out, declarations, template_path): - """ - Generate functions in the "nn" module. - """ - PY_NN_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp') - - py_nn_functions = get_py_nn_functions(declarations) - - env = create_python_bindings(py_nn_functions, is_python_method=False, module="torch.nn") - - write(out, 'python_nn_functions.cpp', PY_NN_FUNCTIONS_CPP, env) - - -def get_py_fft_functions(declarations): - """ - Get declarations (grouped by name) which should be generated - as functions in the "fft" module. - """ - def should_bind(declaration): - return (should_generate_python_binding(declaration) and - is_fft_module_function(declaration)) - - return group_declarations_by_op_name([d for d in declarations if should_bind(d)]) - - -def gen_py_fft_functions(out, declarations, template_path): - """ - Generate functions in the "fft" module. - """ - PY_FFT_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_fft_functions.cpp') - - py_fft_functions = get_py_fft_functions(declarations) - - env = create_python_bindings(py_fft_functions, is_python_method=False, module="torch.fft") - - write(out, 'python_fft_functions.cpp', PY_FFT_FUNCTIONS_CPP, env) +def get_pycname(name: BaseOperatorName) -> str: + return f'THPVariable_{name}' -def get_py_linalg_functions(declarations): - """ - Get declarations (grouped by name) which should be generated - as functions in the "linalg" module. - """ - def should_bind(declaration): - return (should_generate_python_binding(declaration) and - is_linalg_module_function(declaration)) - - return group_declarations_by_op_name([d for d in declarations if should_bind(d)]) +def is_noarg(overloads: Sequence[PythonSignatureNativeFunctionPair]) -> bool: + return len(overloads) == 1 and overloads[0].signature.arguments_count() == 0 +def is_py_variable_method(f: NativeFunction) -> bool: + return f.python_module is None and Variant.method in f.variants -def gen_py_linalg_functions(out, declarations, template_path): - """ - Generate functions in the "linalg" module. - """ - PY_LINALG_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_linalg_functions.cpp') - - py_linalg_functions = get_py_linalg_functions(declarations) +def is_py_torch_function(f: NativeFunction) -> bool: + return f.python_module is None and Variant.function in f.variants - env = create_python_bindings(py_linalg_functions, is_python_method=False, module="torch.linalg") - - write(out, 'python_linalg_functions.cpp', PY_LINALG_FUNCTIONS_CPP, env) - - -def get_py_torch_functions(declarations): - """ - Get declarations (grouped by name) which should be generated - as functions in the "torch" module. - """ - def should_bind(declaration): - return (should_generate_python_binding(declaration) and - not is_nn_module_function(declaration) and - not is_fft_module_function(declaration) and - not is_linalg_module_function(declaration) and - is_torch_function(declaration)) +def is_py_nn_function(f: NativeFunction) -> bool: + return f.python_module == 'nn' - return group_declarations_by_op_name([d for d in declarations if should_bind(d)]) +def is_py_fft_function(f: NativeFunction) -> bool: + return f.python_module == 'fft' +def is_py_linalg_function(f: NativeFunction) -> bool: + return f.python_module == 'linalg' -def gen_py_torch_functions(out, declarations, template_path): - """ - Generate functions in the "torch" module. - """ - PY_TORCH_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp') +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Main Function +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # - py_torch_functions = get_py_torch_functions(declarations) +def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None: + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) - env = create_python_bindings(py_torch_functions, is_python_method=False, module="torch") + methods = load_signatures(native_yaml_path, deprecated_yaml_path, method=True) + create_python_bindings( + fm, methods, is_py_variable_method, None, 'python_variable_methods.cpp', method=True) - write(out, 'python_torch_functions.cpp', PY_TORCH_FUNCTIONS_CPP, env) + functions = load_signatures(native_yaml_path, deprecated_yaml_path, method=False) + create_python_bindings( + fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp', method=False) + create_python_bindings( + fm, functions, is_py_nn_function, 'torch.nn', 'python_nn_functions.cpp', method=False) -def group_declarations_by_op_name(declarations): - groups = defaultdict(list) - for d in declarations: - groups[op_name(d)].append(d) - return groups + create_python_bindings( + fm, functions, is_py_fft_function, 'torch.fft', 'python_fft_functions.cpp', method=False) + create_python_bindings( + fm, functions, is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp', method=False) -def create_python_bindings(python_functions, is_python_method, module): +def create_python_bindings( + fm: FileManager, + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], + module: Optional[str], + filename: str, + *, + method: bool, +) -> None: """Generates Python bindings to ATen functions""" - py_methods = [] - py_method_defs = [] - py_forwards = [] - - for name in sorted(python_functions.keys()): - overload_decls = python_functions[name] - py_methods.append(method_impl(name, overload_decls, is_python_method, module)) - py_method_defs.append(method_def(name, overload_decls, is_python_method, module)) - py_forwards.extend(forward_decls(name, overload_decls, is_python_method, module)) - - return { + py_methods: List[str] = [] + py_method_defs: List[str] = [] + py_forwards: List[str] = [] + + grouped: Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]] = defaultdict(list) + for pair in pairs: + if pred(pair.function): + grouped[pair.function.func.name.name].append(pair) + + for name in sorted(grouped.keys(), key=lambda x: str(x)): + overloads = grouped[name] + py_methods.append(method_impl(name, module, overloads, method=method)) + py_method_defs.append(method_def(name, module, overloads, method=method)) + py_forwards.extend(forward_decls(name, overloads, method=method)) + + fm.write_with_template(filename, filename, lambda: { + 'generated_comment': '@' + f'generated from {fm.template_dir}/{filename}', 'py_forwards': py_forwards, 'py_methods': py_methods, 'py_method_defs': py_method_defs, - } - - -# -# extracting and storing parsed args -# - -UNPACK_METHODS = { - 'const Tensor &': 'tensor', - 'Tensor &': 'tensor', - 'c10::optional': 'optionalTensor', - 'const c10::optional&': 'optionalTensor', - 'c10::optional': 'generator', - 'Storage': 'storage', - 'Storage &': 'storage', - 'const ScalarType &': 'scalartype', - 'const Device &': 'device', - 'c10::optional': 'toDimnameListOptional', - 'c10::optional': 'scalartypeOptional', - 'c10::optional': 'layoutOptional', - 'c10::optional': 'memoryformatOptional', - 'c10::optional': 'scalarOptional', - 'c10::optional': 'intlistOptional', - 'c10::optional': 'toInt64Optional', - 'c10::optional': 'toBoolOptional', - 'c10::optional': 'toDoubleOptional', - 'c10::optional>': 'doublelistOptional', - 'IntArrayRef': 'intlist', - 'Scalar': 'scalar', - 'ScalarType': 'scalartype', - 'Dimname': 'dimname', - 'DimnameList': 'dimnamelist', - 'TensorList': 'tensorlist', - 'int64_t': 'toInt64', - 'bool': 'toBool', - 'double': 'toDouble', - 'std::string': 'string', - 'c10::optional': 'stringOptional', -} - -UNPACK_WITH_SIZE_METHODS = { - 'TensorList': 'tensorlist_n<{}>', - 'DimnameList': 'dimnamelist', - 'IntArrayRef': 'intlist', - 'c10::optional': 'intlistOptional', -} - -UNPACK_WITH_DEFAULT_METHODS = { - 'const ScalarType &': 'scalartypeWithDefault', - 'const Device &': 'deviceWithDefault', - 'c10::optional': 'layoutWithDefault', -} - -def parsed_arg_expr(arg, arg_index): - # e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)' - typename = arg['type'] - - default_init = arg.get('python_default_init') - if default_init is not None: - # Note: only introduced by make_python_binding_args - default_init = arg['python_default_init'] - if typename not in UNPACK_WITH_DEFAULT_METHODS: - raise RuntimeError( - 'type \'{}\' is not supported in python_default_init'. - format(typename)) - unpack_with_default = UNPACK_WITH_DEFAULT_METHODS[typename] - return '_r.{}({}, {})'.format(unpack_with_default, arg_index, default_init) - - size = arg.get('size') - if size is not None: - if typename not in UNPACK_WITH_SIZE_METHODS: - raise RuntimeError( - 'type \'{}\' with definite size ({}) is not supported'. - format(typename, size)) - unpack_with_size = UNPACK_WITH_SIZE_METHODS[typename].format(size) - return '_r.{}({})'.format(unpack_with_size, arg_index) - - unpack = UNPACK_METHODS.get(typename) - if unpack is None: - raise RuntimeError('type \'{}\' is not supported'.format(typename)) - - return '_r.{}({})'.format(unpack, arg_index) - - -# TODO make this part of something more general, or get rid of it -def unpack_optional_dimname_list_hack(name, expr): - # optional> are special. The PythonArgParser returns an - # optional>, which cannot be implicitly converted to - # optional>. One needs to unwrap the optional and rewrap. - result = """\ - auto __{name} = {expr}; - c10::optional<{typ}> {name} = __{name} ? c10::make_optional({typ}(__{name}.value())) : c10::nullopt; - """.format(name=name, expr=expr, typ='DimnameList') - return [line.strip() for line in result.split('\n')] - - -def parse_arg(arg, arg_index, unpack_to_local=False): - # get parsed rhs - expr = parsed_arg_expr(arg, arg_index) - - # maybe unpack to local - name = arg['name'] - typename = arg['type'] - if typename == 'c10::optional': - inits = unpack_optional_dimname_list_hack(name, expr) - expr = name - elif unpack_to_local: - inits = ['auto {} = {};'.format(name, expr)] - expr = name - else: - inits = [] - - return expr, inits - - -# -# schema type to cpp type conversions -# some of these are to prevent dangling refs to temps, others are more obscure -# TODO don't know if these fold into more general conversions somehere, hope so -# - -TEMP_SAFE_CPP_DECL_TYPE = { - 'Tensor &': 'Tensor', -} - -def get_cpp_decl_type(typename, ensure_temp_safe=True): - if ensure_temp_safe: - typename = TEMP_SAFE_CPP_DECL_TYPE.get(typename, typename) - return typename - - -def get_cpp_formal(arg, ensure_temp_safe=True): - decl_type = get_cpp_decl_type(arg['type'], ensure_temp_safe) - return '{} {}'.format(decl_type, arg['name']) - - -# XXX: if you got here because of an assertion failure, it doesn't mean -# it's enough to just extend the list here. Before you do this, make sure -# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h. -SUPPORTED_RETURN_TYPES = { - 'Tensor', - 'std::tuple', - 'std::tuple', - 'std::tuple', - 'std::tuple', - 'std::tuple', - 'std::tuple', - 'std::tuple', - 'std::tuple', - 'std::tuple', - 'std::tuple', - 'std::vector', - 'Scalar', 'bool', 'int64_t', 'void*', 'void', - 'QScheme', 'double', - 'IntArrayRef', - 'ScalarType' -} - -def get_simple_return_type(declaration): - # Use the simple_return_type (Tensor) rather than the fancy return type - # (Tensor &). This is important because the dispatch lambdas take - # mutable arguments *by value*, not by reference. If you then return - # a reference to such an argument, you will now have a pointer to a - # dangling stack entry. Not good. - # - # You want: - # - # auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); }; - # ^^^^^^ - # - # *not* - # - # auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); }; - # ^^^^^^^ - # - # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing - # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a - # mutable reference to temporary. Maybe we could assign it to a - # variable itself.) - # - simple_return_type = declaration['return_type'].replace(' &', '') - if simple_return_type not in SUPPORTED_RETURN_TYPES: - raise RuntimeError(declaration['name'] + " returns unsupported type " + simple_return_type) - return simple_return_type - -# -# dispatch codegen -# - -def get_dispatch_callee(declaration): - # format the name of the receiving function or method - if is_tensor_method(declaration): - return 'self.{}'.format(declaration['name']) - elif is_torch_function(declaration): - namespace = function_namespace(declaration) - return '{}::{}'.format(namespace, declaration['name']) - else: - raise RuntimeError('could not dispatch, neither namespace function nor Tensor method') - - -def get_op_args(declaration, argmap): - # returns a list of argmap values in op call order, with two wrinkles: - # 1. 'self' is eliminated for methods, it's baked into the callee expression elsewhere - # 2. declaration['call_args'] shims legacy overrides and may contain constant values, - # not just names (see load_deprecated_signatures() in gen_autograd.py) - call_args_override = declaration.get('call_args') - if call_args_override: - # names or constants - keys = call_args_override - else: - # only names - keys = [param['name'] for param in declaration['arguments']] - - if is_tensor_method(declaration): - # exclude self for method calls - keys = [k for k in keys if k != 'self'] - - if call_args_override: - # assume missing keys are constants - return [argmap.get(k, k) for k in keys] - else: - return [argmap[k] for k in keys] - - -TENSOR_OPTIONS_DECL = CodeTemplate("""\ -const auto ${name} = TensorOptions() - .dtype(${dtype}) - .device(${device}) - .layout(${layout}) - .requires_grad(${requires_grad}) - .pinned_memory(${pin_memory}); -""") - -# addition to output-variant handler in which tensor options params -# (if present) are checked against properties of a tensor output param -# TODO remove hardcoding, use unpack logic from emit_single_dispatch -PY_VARIABLE_CHECK_OUT_TYPE_HACK = CodeTemplate("""\ -check_out_type_matches(_r.tensor(${out_idx}), _r.scalartype(${type_idx}), - _r.isNone(${type_idx}), _r.layoutOptional(${layout_idx}), - _r.device(${device_idx}), _r.isNone(${device_idx})); -""") - -# Unpack parsed args to locals, call the op, and wrap the result. -# Lambda is so GIL is back on by wrap() time (wrap can allocate) -PY_VARIABLE_WRAP = CodeTemplate("""\ -${inits} -auto dispatch_${name} = [](${lambda_formals}) -> ${simple_return_type} { - ${auto_no_gil} - return ${dispatch_callee}(${dispatch_args}); -}; -return wrap(${namedtuple_typeref}dispatch_${name}(${lambda_args})${set_requires_grad}); -""") - -# void return variant -PY_VARIABLE_RETURN_VOID = CodeTemplate("""\ -${inits} -auto dispatch_${name} = [](${lambda_formals}) -> ${simple_return_type} { - ${auto_no_gil} - ${dispatch_callee}(${dispatch_args}); -}; -dispatch_${name}(${lambda_args})${set_requires_grad}; -Py_RETURN_NONE; -""") - - -def emit_single_dispatch(declaration, is_python_method, output_gap=0): - """ - Emit dispatch code for a single declared overload. - """ - deprecated = '[deprecated] ' if declaration.get('deprecated', False) else '' - schema_comment = '// ' + deprecated + declaration['schema_string'] - inits = [schema_comment] - - pa = declaration['python_arglists'] - args = pa['input_args'] + pa['input_kwargs'] + pa['output_args'] - has_options = has_tensor_options(declaration) - - argmap = {} - - if is_python_method: - # self is passed directly to python binding, rather than parsed - argmap['self'] = {'value': 'self', 'formal': 'Tensor & self'} - - for i, arg in enumerate(args): - unpack = is_scatter(arg) or (has_options and is_tensor_self(arg)) - arg_expr, unpack_stmts = parse_arg(arg, i, unpack_to_local=unpack) - inits.extend(unpack_stmts) - if is_scatter(arg): - for j, elem in enumerate(arg['scatter_args']): - argmap[elem['name']] = { - 'value': '{}[{}]'.format(arg_expr, j), - 'formal': get_cpp_formal(elem, ensure_temp_safe=False), - } - else: - argmap[arg['name']] = {'value': arg_expr, 'formal': get_cpp_formal(arg)} - - # synthetic python binding args deliver op args - binding_argmap, binding_inits, set_requires_grad = \ - handle_python_binding_args(declaration, output_gap) - argmap.update(binding_argmap) - inits.extend(binding_inits) - - lambda_formals = [argmap[arg['name']]['formal'] for arg in declaration['arguments']] - lambda_args = [argmap[arg['name']]['value'] for arg in declaration['arguments']] - - dispatch_callee = get_dispatch_callee(declaration) - dispatch_args = get_op_args(declaration, {name: name for name, _ in argmap.items()}) - - auto_no_gil = [] if declaration['with_gil'] else ['pybind11::gil_scoped_release no_gil;'] - - simple_return_type = get_simple_return_type(declaration) - if simple_return_type == 'void': - template = PY_VARIABLE_RETURN_VOID - else: - template = PY_VARIABLE_WRAP - - return template.substitute( - name=declaration['name'], - inits=inits, - lambda_formals=lambda_formals, - lambda_args=lambda_args, - dispatch_callee=dispatch_callee, - dispatch_args=dispatch_args, - auto_no_gil=auto_no_gil, - set_requires_grad=set_requires_grad, - simple_return_type=simple_return_type, - namedtuple_typeref=declaration['namedtuple_typeref'], - ) - - -# arg['name'] to arg['simple_type'] for scattered tensor options fields -TENSOR_OPTIONS_FIELDS = { - 'dtype': 'ScalarType', - 'device': 'Device', - 'layout': 'Layout', - 'pin_memory': 'bool', - 'requires_grad': 'bool', -} - -def handle_python_binding_args(declaration, output_gap): - # map synthetic python binding args to op args and misc other stuff - # note: this logic shares arcane knowledge with make_python_binding_args - # and isn't completely airtight w.r.t. the possible contents of - # python_binding_args. TODO - - argmap = {} - inits = [] - set_requires_grad = '' - - pa = declaration['python_arglists'] - python_binding_args = pa['python_binding_args'] - - if len(python_binding_args) == 0: - # nothing to see here - return argmap, inits, set_requires_grad - - args = pa['input_args'] + pa['input_kwargs'] + pa['output_args'] - binding_arg_base = len(args) + output_gap - binding_arg_offsets = {arg['name']: i for i, arg in enumerate(python_binding_args)} - - def binding_arg_index(name): - return binding_arg_base + binding_arg_offsets[name] - - def parse_binding_arg(name): - binding_arg = python_binding_args[binding_arg_offsets[name]] - expr, _ = parse_arg(binding_arg, binding_arg_index(name)) - return expr - - has_output = len(pa['output_args']) == 1 - tensor_options_arg = get_tensor_options(declaration) - - if tensor_options_arg is not None: - # if our op has a tensor options arg, these are its scattered fields. - # first some checks - if has_output: - raise RuntimeError('{}: tensor options with output arg'.format(declaration['name'])) - for arg in python_binding_args: - typename = TENSOR_OPTIONS_FIELDS.get(arg['name']) - if typename is None: - raise RuntimeError( - '{}: unrecognized tensor options field \'{}\' in python binding arguments'. - format(declaration['name'], arg['name'])) - if typename != arg['simple_type']: - raise RuntimeError( - '{}: unrecognized type \'{}\' for tensor options field \'{}\' in python binding arguments'. - format(declaration['name'], arg['type'], arg['name'])) - python_binding_argnames = [arg['name'] for arg in python_binding_args] - if not all([key in python_binding_argnames for key in TENSOR_OPTIONS_FIELDS.keys()]): - raise RuntimeError( - '{}: incomplete tensor options args: {}'. - format(declaration['name'], [arg['name'] for arg in python_binding_args])) - # generate a gathering initialization of options struct - argname = tensor_options_arg['name'] - inits.append(TENSOR_OPTIONS_DECL.substitute({ - 'name': argname, - 'dtype': parse_binding_arg('dtype'), - 'layout': parse_binding_arg('layout'), - 'device': parse_binding_arg('device'), - 'requires_grad': parse_binding_arg('requires_grad'), - 'pin_memory': parse_binding_arg('pin_memory'), - })) - inits.append('torch::utils::maybe_initialize_cuda({});'.format(argname)) - # and add to op arg map - argmap['options'] = { - 'value': argname, - 'formal': get_cpp_formal(tensor_options_arg), - } - - else: - # not the scattered fields of a tensor options - sort of a grab bag - if 'dtype' in binding_arg_offsets: - # we're an output-arg variant, check these args against output tensor - if not has_output: - raise RuntimeError( - '{}: dtype in python_binding_args without output arg'. - format(declaration['name'])) - if not all([name in binding_arg_offsets for name in ['layout', 'device']]): - raise RuntimeError( - '{}: incomplete tensor options for output check'. - format(declaration['name'])) - check_type = PY_VARIABLE_CHECK_OUT_TYPE_HACK.substitute( - out_idx=get_python_output_index(declaration), - type_idx=binding_arg_index('dtype'), - layout_idx=binding_arg_index('layout'), - device_idx=binding_arg_index('device'), - ) - inits.append(check_type) - # we'll set requires_grad on outgoing tensor - if 'requires_grad' not in binding_arg_offsets: - raise RuntimeError( - '{}: expected "requires_grad" in python_binding_args absent tensor options arg but found [{}]'. - format(declaration['name'], [arg['name'] for arg in python_binding_args])) - requires_grad = parse_binding_arg('requires_grad') - set_requires_grad = '.set_requires_grad({})'.format(requires_grad) - - return argmap, inits, set_requires_grad - - -# handler for output/no-output overload pair -# (plugged into PY_VARIABLE_CASE as ${call_dispatch}) -PY_VARIABLE_OUT = CodeTemplate("""\ -if (_r.isNone(${out_idx})) { - ${call_dispatch} -} else { - ${call_dispatch_out} -} -""") - -# handler for a single parsed signature - may be a single overload or -# a pair of overloads that whose signatures only differ in output params -PY_VARIABLE_CASE = CodeTemplate("""\ -case ${i}: { - ${body} -} -""") - - -def emit_dispatch_case(i, dictionary, is_python_method): - """ - Emit dispatch code for a single parsed signature. This corresponds to either - a single overload, or a pair that differ only in output params. In the latter - case, a single signature is used for both and dispatching switches on the - presence/absence of passed output args. - - i: this signature's position in generated binding's signature list if number of - signatures > 1, otherwise None - - dictionary: contains a no-output overload declaration under 'base', and optionally - a second overload with outputs under 'out' - - true if we're generating a python method, in which case self is not parsed but - passed directly - """ - base_decl = dictionary['base'] - - if 'out' in dictionary: - # dispatch to output or no-output variant based on arg test - out_decl = dictionary['out'] - out_idx = get_python_output_index(out_decl) - output_gap = get_python_argc(out_decl) - get_python_argc(base_decl) - - call_dispatch = emit_single_dispatch(base_decl, is_python_method, output_gap) - call_dispatch_out = emit_single_dispatch(out_decl, is_python_method) - - # dispatch output and no-output variants, branch on _r.isNone() - body = PY_VARIABLE_OUT.substitute( - out_idx=out_idx, - call_dispatch=call_dispatch, - call_dispatch_out=call_dispatch_out, + }) + +def load_signatures( + native_yaml_path: str, + deprecated_yaml_path: str, + *, + method: bool, + skip_deprecated: bool = False, + pyi: bool = False, +) -> Sequence[PythonSignatureNativeFunctionPair]: + native_functions = list(filter(should_generate_py_binding, parse_native_yaml(native_yaml_path))) + + @with_native_function + def gen_signature_pairs(f: NativeFunction) -> PythonSignatureNativeFunctionPair: + return PythonSignatureNativeFunctionPair( + signature=signature(f, method=method, pyi=pyi), + function=f, ) - else: - # no-output version only - body = emit_single_dispatch(base_decl, is_python_method) - - if i is not None: - # generate case for ith overload - return PY_VARIABLE_CASE.substitute(i=i, body=body) - else: - # only one overload, omit case wrapper - return body - -# -# named tuple codegen -# -def namedtuple_fieldnames(declaration): - returns = declaration['returns'] - if len(returns) <= 1 or all(['field_name' not in x for x in returns]): - return [] - else: - def get_field_name(x): - # See Note [field_name versus name] - if 'field_name' not in x: - # When building on Windows, `PyStructSequence_UnnamedField` could not be - # resolved by the linker for some reason, which cause error in building: - # - # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol - # PyStructSequence_UnnamedField - # - # Thus, at this point in time, we do not support unnamed - # fields in namedtuple; you must either name all fields, - # or none of them. - raise ValueError("Unnamed field is not supported by codegen") - else: - return x['field_name'] - return [get_field_name(x) for x in returns] - -PY_NAMEDTUPLE_FIELDSDEF = CodeTemplate("""\ -static PyStructSequence_Field ${fieldsname}[] = { ${fields,} {nullptr} }; -""") - -PY_NAMEDTUPLE_TYPEDEF = CodeTemplate("""\ -static PyTypeObject ${typename}; -static bool ${typename}_initialized = false; -if (!${typename}_initialized) { - ${typename}_initialized = true; - static PyStructSequence_Desc desc = { "torch.return_types.${name}", nullptr, ${fieldsname}, ${size} }; - PyStructSequence_InitType(&${typename}, &desc); - ${typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr; -} -""") - - -def emit_namedtuple_typedefs(declarations): + pairs = list(map(gen_signature_pairs, native_functions)) + deprecated = load_deprecated_signatures(pairs, deprecated_yaml_path, method=method, pyi=pyi) + return pairs if skip_deprecated else pairs + deprecated + +def load_deprecated_signatures( + pairs: Sequence[PythonSignatureNativeFunctionPair], + deprecated_yaml_path: str, + *, + method: bool, + pyi: bool, +) -> List[PythonSignatureNativeFunctionPair]: + # The deprecated.yaml doesn't have complete type information, we need + # find and leverage the original ATen signature (to which it delegates + # the call) to generate the full python signature. + # We join the deprecated and the original signatures using type-only form. + + # native function -> type-only signature + @with_native_function + def signature_original(f: NativeFunction) -> str: + # remove inplace suffix but keep outplace suffix + opname = str(f.func.name.name.base) + if f.func.is_out_fn(): + opname += '_out' + if f.func.name.name.inplace and pyi: + opname += '_' + args = CppSignatureGroup.from_native_function(f, method=False).signature.arguments() + # Simply ignore TensorOptionsArguments as it does not exist in deprecated.yaml. + types = ', '.join(argument_type_str(a.argument.type) + for a in args if isinstance(a.argument, Argument)) + return f'{opname}({types})' + + # deprecated -> type-only native signature (according to the call order) + def signature_deprecated(opname: str, params: List[str], call_args: List[str]) -> str: + # create a mapping of parameter name to parameter type + types: Dict[str, str] = {} + for param in params: + if param == '*': + continue + type, name = param.split(' ') + types[name] = type + # if the name in the call is not in the parameter list, assume it's + # a literal Scalar + rearranged_types = ', '.join(types.get(arg, 'Scalar') for arg in call_args) + return f'{opname}({rearranged_types})' + + # group the original ATen signatures by type-only signature + grouped: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list) + for pair in pairs: + grouped[signature_original(pair.function)].append(pair) + + # find matching original signatures for each deprecated signature + results: List[PythonSignatureNativeFunctionPair] = [] + + with open(deprecated_yaml_path, 'r') as f: + deprecated_defs = yaml.load(f, Loader=Loader) + + for deprecated in deprecated_defs: + _, params = split_name_params(deprecated['name']) + aten_name, call_args = split_name_params(deprecated['aten']) + + for pair in grouped[signature_deprecated(aten_name, params, call_args)]: + # It uses the types from the original ATen declaration, but the + # ordering and parameter names from the deprecated overload. Any + # default parameter values from the original ATen declaration are + # ignored. + # Deprecated signature might reorder input_args and input_kwargs, + # but never changes output_args nor TensorOptions (if any?), + # so here we only look into these two types of args. + python_sig = pair.signature + src_args: Dict[str, PythonArgument] = {a.name: PythonArgument( + name=a.name, + type=a.type, + default=None, + default_init=None, + ) for a in itertools.chain(python_sig.input_args, python_sig.input_kwargs)} + + args: List[str] = [] + input_args: List[PythonArgument] = [] + input_kwargs: List[PythonArgument] = [] + + kwarg_only = False + for param in params: + if param == '*': + kwarg_only = True + continue + _, param_name = param.split(' ') + args.append(param_name) + + if param_name not in src_args: + # output argument + continue + + if not kwarg_only: + if not method or param_name != 'self': + input_args.append(src_args[param_name]) + else: + input_kwargs.append(src_args[param_name]) + + results.append(PythonSignatureNativeFunctionPair( + signature=PythonSignatureDeprecated( + name=python_sig.name, + input_args=tuple(input_args), + input_kwargs=tuple(input_kwargs), + output_args=python_sig.output_args, + tensor_options_args=python_sig.tensor_options_args, + method=python_sig.method, + deprecated_args_names=tuple(args), + deprecated_args_exprs=tuple(call_args), + returns=python_sig.returns, + ), + function=pair.function, + )) + + return results + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Named Tuple Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +@with_native_function +def gen_namedtuple_typename_key(f: NativeFunction) -> str: + name = cpp.name(f.func) + fieldnames = namedtuple_fieldnames(f.func.returns) + return '_'.join([name] + fieldnames) + +def emit_namedtuple_typedefs( + overloads: Sequence[PythonSignatureNativeFunctionPair] +) -> Tuple[List[str], Dict[str, str]]: """ Generate block of named tuple type def inits, and add typeref snippets to declarations that use them """ - flddefnames = {} # map from unique field name lists to field def name - flddefs = [] # field def declarations - typenames = {} # map from unique name + field name lists to typedef name - typedefs = [] # typedef declarations and init code - - for decl in declarations: - fieldnames = namedtuple_fieldnames(decl) - if fieldnames == []: - decl['namedtuple_typeref'] = '' + flddefnames: Dict[str, str] = {} # map from unique field name lists to field def name + flddefs: List[str] = [] # field def declarations + typenames: Dict[str, str] = {} # map from unique name + field name lists to typedef name + typedefs: List[str] = [] # typedef declarations and init code + + for overload in overloads: + fieldnames = namedtuple_fieldnames(overload.function.func.returns) + if not fieldnames: continue fn_key = '_'.join(fieldnames) fieldsname = flddefnames.get(fn_key) if fieldsname is None: - fieldsname = 'NamedTuple_fields{}'.format('' if flddefs == [] else len(fielddefs)) - fields = ['{{"{}", ""}}'.format(fn) for fn in fieldnames] - fieldsdef = PY_NAMEDTUPLE_FIELDSDEF.substitute( - fieldsname=fieldsname, - fields=fields - ) + fieldsname = f'NamedTuple_fields{"" if not flddefs else len(flddefs)}' flddefnames[fn_key] = fieldsname - flddefs.append(fieldsdef) + fields = ', '.join(f'{{"{fn}", ""}}' for fn in fieldnames) + flddefs.append(f"""\ +static PyStructSequence_Field {fieldsname}[] = {{ {fields}, {{nullptr}} }}; +""") - name = decl['name'] - key = '{}_{}'.format(name, '_'.join(fieldnames)) - typename = typenames.get(key) + name = cpp.name(overload.function.func) # use @with_native_function? + tn_key = gen_namedtuple_typename_key(overload.function) + typename = typenames.get(tn_key) if typename is None: - typename = 'NamedTuple{}'.format('' if typedefs == [] else len(typedefs)) - typedef = PY_NAMEDTUPLE_TYPEDEF.substitute( - name=name, - typename=typename, - size=len(fieldnames), - fieldsname=fieldsname - ) - typenames[key] = typename - typedefs.append(typedef) - - decl['namedtuple_typeref'] = '&{}, '.format(typename) + typename = f'NamedTuple{"" if not typedefs else len(typedefs)}' + typenames[tn_key] = typename + typedefs.append(f"""\ +static PyTypeObject {typename}; +static bool {typename}_initialized = false; +if (!{typename}_initialized) {{ + {typename}_initialized = true; + static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, {fieldsname}, {len(fieldnames)} }}; + PyStructSequence_InitType(&{typename}, &desc); + {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr; +}} +""") - return flddefs + typedefs + return flddefs + typedefs, typenames +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # -# method impl codegen +# Method Impl Codegen # - -def get_pycname(name): - return 'THPVariable_{}'.format(name) - - -def is_noarg_binding(overloads): - return len(overloads) == 1 and get_python_argc(overloads[0]) == 0 - +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # python binding for all overloads of a particular function/method PY_VARIABLE_METHOD_VARARGS = CodeTemplate(r"""\ @@ -873,6 +405,15 @@ def is_noarg_binding(overloads): """) +# handler for a single parsed signature - may be a single overload or +# a pair of overloads that whose signatures only differ in output params +# (plugged into PY_VARIABLE_METHOD_VARARGS as an item in ${dispatch}) +PY_VARIABLE_CASE = CodeTemplate("""\ +case ${overload_index}: { + ${body} +} +""") + # python binding for single-overload function/method PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate("""\ // ${name} @@ -905,638 +446,408 @@ def is_noarg_binding(overloads): """) -TORCH_FUNCTION_CHECK = CodeTemplate("""\ -if(_r.has_torch_function()) { - return handle_torch_function(_r, ${self_}, args, kwargs, ${namespace}, ${modulename}); -} -""") - -TORCH_FUNCTION_CHECK_NOARGS = CodeTemplate("""\ -if(check_has_torch_function(self_)) { - return handle_torch_function(self_, ${name}); -} -""") - -# NOTE: we type the unpacked self as Tensor not Variable to avoid return type -# discrepancies on method resolution (e.g. Variable::detach_ returns void -# rather than Tensor &) -UNPACK_SELF = "Tensor& self = reinterpret_cast(self_)->cdata;" - - -def method_impl(name, declarations, is_python_method, module): +def method_impl( + name: BaseOperatorName, + module: Optional[str], + overloads: Sequence[PythonSignatureNativeFunctionPair], + *, + method: bool +) -> str: """ Generate a python binding for all overloads of an op. """ - for declaration in declarations: - # formals for python binding signature - declaration['python_arglists'] = make_python_arglists(declaration, is_python_method) - pycname = get_pycname(name) + noarg = is_noarg(overloads) + namedtuple_inits, namedtuple_typenames = emit_namedtuple_typedefs(overloads) method_header = ['HANDLE_TH_ERRORS'] - method_header += emit_namedtuple_typedefs(declarations) - method_header += [UNPACK_SELF] if is_python_method else [] - - method_footer = ['END_HANDLE_TH_ERRORS'] - - check_has_torch_function = TORCH_FUNCTION_CHECK_NOARGS.substitute( - name='"' + name + '"', - ) if is_python_method else '' - - # emit dispatch - if is_noarg_binding(declarations): - dispatch = emit_single_dispatch(declaration, is_python_method) - return PY_VARIABLE_METHOD_NOARGS.substitute( - name=name, - pycname=pycname, - method_header=method_header, - dispatch=dispatch, - method_footer=method_footer, - check_has_torch_function=check_has_torch_function, - ) - - method_footer = ['Py_RETURN_NONE;'] + method_footer - - grouped = group_overloads(declarations, is_python_method) - is_singleton = len(grouped) == 1 - - signatures = [] - dispatch = [] - for i, dictionary in enumerate(grouped): - signature = dictionary['signature'] - signatures.append('"{}",'.format(signature)) - overload_index = i if not is_singleton else None - dispatch.append(emit_dispatch_case(overload_index, dictionary, is_python_method)) - - if is_singleton: + method_header += namedtuple_inits + method_header += [ + "Tensor& self = reinterpret_cast(self_)->cdata;" + ] if method else [] + + method_footer = ([] if noarg else ['Py_RETURN_NONE;']) + ['END_HANDLE_TH_ERRORS'] + + traceable = 'true' if all(should_trace(o.function) for o in overloads) else 'false' + + grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(overloads) + is_singleton = len(grouped_overloads) == 1 + signatures: List[str] = [] + dispatch: List[str] = [] + for overload_index, overload in enumerate(grouped_overloads): + signature = overload.signature.signature_str() + signatures.append(f'{cpp_string(str(signature))},') + dispatch_body = emit_dispatch_case(overload, namedtuple_typenames) + dispatch.append( + PY_VARIABLE_CASE.substitute(overload_index=overload_index, body=dispatch_body) + if not is_singleton else dispatch_body) + + if noarg: + template = PY_VARIABLE_METHOD_NOARGS + elif is_singleton: template = PY_VARIABLE_METHOD_VARARGS_SINGLETON else: template = PY_VARIABLE_METHOD_VARARGS - if module: - check_has_torch_function = TORCH_FUNCTION_CHECK.substitute( - namespace=NATIVE_NAMESPACE_MAPPING[module], - modulename='"' + module + '"', - self_="self_" if is_python_method else "nullptr", - ) - else: - check_has_torch_function = TORCH_FUNCTION_CHECK.substitute( - namespace="THPVariableClass", - modulename='"torch.Tensor"', - self_="self_" if is_python_method else "nullptr", - ) - - max_args = max([get_python_argc(decl) for decl in declarations]) - traceable = 'true' if all(should_trace(d) for d in declarations) else 'false' - return template.substitute( name=name, pycname=pycname, method_header=method_header, - max_args=max_args, + max_args=max(map(lambda o: o.signature.arguments_count(), overloads)), signatures=signatures, traceable=traceable, - check_has_torch_function=check_has_torch_function, + check_has_torch_function=gen_has_torch_function_check( + name=name, + module=module, + noarg=noarg, + method=method, + ), dispatch=dispatch, method_footer=method_footer, - self_="self_" if is_python_method else "nullptr", + self_="self_" if method else "nullptr", ) +def gen_has_torch_function_check( + name: BaseOperatorName, module: Optional[str], *, noarg: bool, method: bool +) -> str: + if noarg: + if method: + return f"""\ +if(check_has_torch_function(self_)) {{ + return handle_torch_function(self_, "{name}"); +}} +""" + else: + return '' + + self_ = "self_" if method else "nullptr" + namespace = { + "torch": "THPVariableFunctionsModule", + "torch.nn": "THPNNVariableFunctionsModule", + "torch.fft": "THPFFTVariableFunctionsModule", + "torch.linalg": "THPLinalgVariableFunctionsModule", + }[module] if module else "THPVariableClass" + + return f"""\ +if(_r.has_torch_function()) {{ + return handle_torch_function(_r, {self_}, args, kwargs, {namespace}, "{module or "torch.Tensor"}"); +}} +""" -# -# forward declarations -# - -PY_VARIABLE_FUNCTION_VARARGS_FORWARD_DECLARATION = CodeTemplate("""\ -static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs); -""") - -PY_VARIABLE_FUNCTION_NOARGS_FORWARD_DECLARATION = CodeTemplate("""\ -static PyObject * ${pycname}(PyObject* self_, PyObject* args); +# handler for output/no-output overload pair +PY_VARIABLE_OUT = CodeTemplate("""\ +if (_r.isNone(${out_idx})) { + ${call_dispatch} +} else { + ${call_dispatch_out} +} """) - -def forward_decls(name, declarations, is_python_method, module): - if is_python_method: - return [] - - if is_noarg_binding(declarations): - template = PY_VARIABLE_FUNCTION_NOARGS_FORWARD_DECLARATION +def emit_dispatch_case( + overload: PythonSignatureGroup, + namedtuple_typenames: Dict[str, str], +) -> str: + """ + Emit dispatch code for a single parsed signature. This corresponds to either + a single native function, or a pair that differ only in output params. In the + latter case, a single python signature is used for both and dispatching + switches on the presence/absence of passed output args. + """ + if overload.outplace is not None: + # dispatch output and no-output variants, branch on _r.isNone() + return PY_VARIABLE_OUT.substitute( + out_idx=overload.signature.output_idx(), + call_dispatch=emit_single_dispatch( + overload.signature, overload.base, namedtuple_typenames), + call_dispatch_out=emit_single_dispatch( + overload.signature, overload.outplace, namedtuple_typenames), + ) else: - template = PY_VARIABLE_FUNCTION_VARARGS_FORWARD_DECLARATION - - pycname = get_pycname(name) - return [template.substitute(pycname=pycname)] - + # no-output version only + return emit_single_dispatch( + overload.signature, overload.base, namedtuple_typenames) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # -# method def (binding table entry) codegen +# Forward Declarations Codegen # +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# Python binary operator dunder methods -BINARY_OP_NAMES = [ - '__lt__', '__le__', - '__gt__', '__ge__', - '__eq__', '__ne__', - - '__add__', '__radd__', '__iadd__', - '__sub__', '__rsub__', '__isub__', - '__mul__', '__rmul__', '__imul__', - '__matmul__', '__rmatmul__', '__imatmul__', - '__truediv__', '__rtruediv__', '__itruediv__', - '__floordiv__', '__rfloordiv__', '__ifloordiv__', - '__mod__', '__rmod__', '__imod__', - '__divmod__', '__rdivmod__', '__idivmod__', - '__pow__', '__rpow__', '__ipow__', - '__lshift__', '__rlshift__', '__ilshift__', - '__rshift__', '__rrshift__', '__irshift__', - '__and__', '__rand__', '__iand__', - '__xor__', '__rxor__', '__ixor__', - '__or__', '__ror__', '__ior__', -] - -# PyMethodDef entry for binary op, throws not implemented error -PY_VARIABLE_METHOD_BINOP_DEF = CodeTemplate("""\ -{"${name}", (PyCFunction)${pycfunc_voidcast}TypeError_to_NotImplemented_<${pycname}>, ${flags}, NULL},""") +def forward_decls( + name: BaseOperatorName, + overloads: Sequence[PythonSignatureNativeFunctionPair], + *, + method: bool +) -> Tuple[str, ...]: + if method: + return () -# PyMethodDef entry -PY_VARIABLE_METHOD_DEF = CodeTemplate("""\ -{"${name}", (PyCFunction)${pycfunc_voidcast}${pycname}, ${flags}, NULL},""") + pycname = get_pycname(name) + if is_noarg(overloads): + return (f"""\ +static PyObject * {pycname}(PyObject* self_, PyObject* args); +""",) + else: + return (f"""\ +static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs); +""",) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Method Def (Binding Table Entry) Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def method_def(name, declarations, is_python_method, module): +def method_def( + name: BaseOperatorName, + module: Optional[str], + overloads: Sequence[PythonSignatureNativeFunctionPair], + *, + method: bool +) -> str: """ Generate method def entry. """ pycname = get_pycname(name) - if is_noarg_binding(declarations): - pycfunc_voidcast = '' - flags = 'METH_NOARGS' if is_python_method else 'METH_VARARGS | METH_KEYWORDS' + if is_noarg(overloads): + pyfunc_cast = '' + flags = 'METH_NOARGS' if method else 'METH_VARARGS | METH_KEYWORDS' else: - pycfunc_voidcast = '(void(*)(void))' + pyfunc_cast = 'castPyCFunctionWithKeywords' flags = 'METH_VARARGS | METH_KEYWORDS' if module == "torch": flags += ' | METH_STATIC' - if name in BINARY_OP_NAMES: - def_template = PY_VARIABLE_METHOD_BINOP_DEF + if name.dunder_method: + # PyMethodDef entry for binary op, throws not implemented error + return f"""\ +{{"{name}", {pyfunc_cast}(TypeError_to_NotImplemented_<{pycname}>), {flags}, NULL}},""" else: - def_template = PY_VARIABLE_METHOD_DEF - - return def_template.substitute( - name=name, - pycname=pycname, - pycfunc_voidcast=pycfunc_voidcast, - flags=flags, - ) + # PyMethodDef entry + return f"""\ +{{"{name}", {pyfunc_cast}({pycname}), {flags}, NULL}},""" +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # -# overload sorting and grouping +# Overload Sorting and Grouping # +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def group_overloads(declarations, is_python_method): - """Returns a list of dictionaries containing the optional keys: - - "base": the regular ATen declaration (e.g. conv2d) - "out": the out variant (e.g. conv2d_out) - "signature": the signature used for Python argument parsing - - Note that we merge pairs of declarations with signatures that - are equivalent mod output arguments, and use a single entry in - the python_arg_parser sig list for both (output arguments become - optional) - """ - grouped = defaultdict(dict) +def group_overloads( + overloads: Sequence[PythonSignatureNativeFunctionPair], +) -> Sequence[PythonSignatureGroup]: + bases: Dict[str, PythonSignatureNativeFunctionPair] = {} + outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {} # first group by signature ignoring out arguments - for declaration in declarations: - signature = get_python_signature(declaration, is_python_method, skip_outputs=True) - v = grouped[signature] - if declaration['name'].endswith('_out'): - v['out'] = declaration - # prefer the signature with optional out=... arguments - v['signature'] = get_python_signature(declaration, is_python_method) + for overload in overloads: + sig = overload.signature.signature_str(skip_outputs=True) + if overload.function.func.is_out_fn(): + if sig in outplaces: + raise RuntimeError( + f'Found duplicated function definition:\n- {overload.function.func}.\n' + f'Existing definition:\n- {outplaces[sig].function.func}.' + ) + outplaces[sig] = overload else: - v['base'] = declaration - if 'signature' not in v: - v['signature'] = signature - - result = [] - for x, dictionary in sorted(grouped.items()): - if 'base' not in dictionary: + if sig in bases: + raise RuntimeError( + f'Found duplicated function definition:\n- {overload.function.func}.\n' + f'Existing definition:\n- {bases[sig].function.func}.' + ) + bases[sig] = overload + + for sig, out in outplaces.items(): + if sig not in bases: + candidates: List[str] = [] + for overload in overloads: + if str(overload.function.func.name.name) == str(out.function.func.name.name) \ + and not overload.function.func.is_out_fn() \ + and not overload.signature.deprecated: + candidates.append(overload.signature.signature_str(skip_outputs=True)) + out_sig = out.signature.signature_str() raise RuntimeError( - "'base' not in dictionary for {}. keys are {}".format( - x, list(dictionary.keys()))) - result.append(dictionary) - return sort_declarations(result) - + f'While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. ' + f'We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema ' + 'correctly in native_functions.yaml. We discovered the following candidate(s): \n' + + '\n'.join(f'- {candidate}' for candidate in candidates)) + + grouped: List[PythonSignatureGroup] = [] + for sig, base in bases.items(): + outplace = outplaces.get(sig) + grouped.append(PythonSignatureGroup( + # prefer the signature with optional out=... arguments because it's the + # superset that can be used to parse input for both base and outplace. + signature=outplace.signature if outplace is not None else base.signature, + base=base.function, + outplace=outplace.function if outplace is not None else None, + )) + + return sort_overloads(grouped) # This function declares a partial order on declarations, and sorts them according # to its linear extension. This is necessary, because there's some ambiguity in the # choice of overload, and we want a different order. # # See Note[Order of overloads matters] -def sort_declarations(grouped_decls): - - def dynamic_type(arg): - return arg['dynamic_type'] - - def is_coord_smaller(arg1, arg2): - return dynamic_type(arg1) == 'Scalar' and arg2['dynamic_type'] == 'Tensor' - - def is_smaller(d1, d2): - """Returns True if d1 < d2 in the partial order.""" - args1, args2 = d1['base']['arguments'], d2['base']['arguments'] - if len(args1) != len(args2): - return False - any_smaller = any(is_coord_smaller(arg1, arg2) for arg1, arg2 in zip(args1, args2)) - all_smaller_or_equal = all(dynamic_type(arg1) == dynamic_type(arg2) or - is_coord_smaller(arg1, arg2) - for arg1, arg2 in zip(args1, args2)) - return any_smaller and all_smaller_or_equal - - # Construct the relation graph - larger_than = defaultdict(set) - for i1, decl1 in enumerate(grouped_decls): - for i2, decl2 in enumerate(grouped_decls): - if is_smaller(decl1, decl2): - larger_than[i1].add(i2) - - if not larger_than: - return grouped_decls - - # Use a topological sort to sort decls according to the partial order. - sorted_deps = [(i, decl) for i, decl in enumerate(grouped_decls) - if i not in larger_than] - for i, decl in sorted_deps: - for i2 in sorted(larger_than.keys()): - larger = larger_than[i2] - larger.discard(i) - if not larger: - del larger_than[i2] - sorted_deps.append((i2, grouped_decls[i2])) - - return [decl for i, decl in sorted_deps] - - # -# python signature codegen +# A few examples of ambiguous python signature pairs. # - -SCHEMA_DEFAULT_CONVERSION_HACKS = { - 'nullptr': 'None', - 'c10::nullopt': 'None', - '{}': 'None', -} - -def get_schema_formal(arg, is_python_method): - name = arg['name'] - typename = arg['simple_type'] - - # TODO: remove this and make optional types in simple_type to be consistent across - # tensor and other types after make Tensor? be optional instead of undefined - if arg.get('is_nullable') and '?' not in typename: - typename = '{}?'.format(typename) - - # s/self/input/ outside method bindings. - # TODO remove this? doesn't rename in codegen, it's just for the parse string - if name == 'self' and typename == 'Tensor' and not is_python_method: - name = 'input' - - size = arg.get('size') - if size is not None: - if typename.endswith('?'): - typename = '{}[{}]?'.format(typename[:-1], size) - else: - typename = '{}[{}]'.format(typename, size) - - # default - default = arg.get('default') - if default is not None: - default = SCHEMA_DEFAULT_CONVERSION_HACKS.get(default, default) - return '{} {}={}'.format(typename, name, default) - else: - return '{} {}'.format(typename, name) - - -PYTHON_ARG_PARSER_SCHEMA = CodeTemplate("""\ -${name}(${schema_formals})${deprecated}""") - - -def get_python_signature(declaration, is_python_method, skip_outputs=False): - # Compute the Python function signature for argument parsing, - # as specified in torch/csrc/utils/python_arg_parser.h. WARNING: - # this is NOT the same type signature as specified by PEP 484 - # as understood by mypy; our format was independently developed - # and has some quirks to make it more suitable specifically - # for error parsing. - # - # For a translation to mypy-valid type signatures, see - # tools/gen_pyi.py. If you change any logic here, please - # check that file too. - - python_args = get_python_args(declaration) - if skip_outputs: - python_args = [arg for arg in python_args if not is_output(arg)] - - schema_formals = [get_schema_formal(arg, is_python_method) for arg in python_args] - positional_argc = len(declaration['python_arglists']['input_args']) - if len(python_args) > positional_argc: - schema_formals.insert(positional_argc, '*') - - # Python function signature. - # This is the string that we give to FunctionParameter, which is - # then parsed into the actual structure which we do parsing with. - name = op_name(declaration) - deprecated = '|deprecated' if declaration.get('deprecated', False) else '' - return PYTHON_ARG_PARSER_SCHEMA.substitute( - name=name, - schema_formals=schema_formals, - deprecated=deprecated, - ) - +# All parameters have the same type, except one taking Tensor the other taking +# Scalar. A numeric PyObject can be casted into Tensor, and a zero-dim Tensor +# object can be accepted as Scalar type parameter (see python_arg_parser.cpp). +# Therefore, same input arguments might be accepted by either python signature. +# We want to always parse the one taking Tensor first. # -# op args to python parsed args transform +# bitwise_and(Tensor input, Tensor other, *, Tensor out=None) +# bitwise_and(Tensor input, Scalar other, *, Tensor out=None) # - -def get_python_args(decl): - arglists = decl['python_arglists'] - return \ - arglists['input_args'] + \ - arglists['input_kwargs'] + \ - arglists['output_args'] + \ - arglists['python_binding_args'] - - -def get_python_argc(decl): - return sum([len(arglist) for arglist in decl['python_arglists'].values()]) - - -def get_python_output_index(decl): - arglists = decl['python_arglists'] - return len(arglists['input_args'] + arglists['input_kwargs']) - - -def make_python_arglists(declaration, is_python_method): - # produces python-ready args converted from declaration['args'], - # partitioned into sublists by category. subslists are order, so - # the final python arglist can be recovered by simple flattening - # (see get_python_args()) - - # partition args into sublists - - args = declaration['arguments'] - - input_args = [] - input_kwargs = [] - output_args = [] - - current_input_args = input_args - for arg in args: - if is_output(arg): - output_args.append(arg) - else: - if arg.get('kwarg_only', False): - current_input_args = input_kwargs - current_input_args.append(arg) - - # adjustments - - # positional inputs: - # - filter self when we're generating a method binding.else - there, it comes in as - # a separate Python param, not in args array - def include(arg): - return not (is_tensor_self(arg) and is_python_method) - input_args = [arg for arg in input_args if include(arg)] - - # keyword inputs: - # - filter options. after loading the yaml, an upstream step has gathered dtype, - # layout et al into a single tensor options arg. here we reintroduce the originals - input_kwargs = [arg for arg in input_kwargs if not is_tensor_options(arg)] - - # outputs: - # - coalesce multiple output args into a single 'out' arg w/type TensorList. - # - force a default. This is so we can use this sig for both out and non-out variants - num_outputs = len(output_args) - if num_outputs > 1: - for arg in output_args: - if not arg['simple_type'] == 'Tensor': - raise RuntimeError( - '{}: unsupported output argument type {}'. - format(declaration['name'], arg['type'])) - typename = 'TensorList' - output_args = [{ - 'default': 'None', - 'kwarg_only': True, - 'name': 'out', - 'output': True, - 'scatter_args': output_args, - 'simple_type': typename, - 'size': num_outputs, - 'type': typename, - }] - elif num_outputs == 1: - output_arg = output_args[0].copy() - output_arg['default'] = 'None' - output_args = [output_arg] - - # make python binding args - # these are the (re)scattered versions of the options arg omitted above. - # TODO because these aren't guaranteed to be 100% faithful to the original - # versions in the yaml, this recreation is a potential source of drift between - # eager and JIT. Pull this logic out to a shared place. - python_binding_args = make_python_binding_args(declaration) - - return { - 'input_args': input_args, - 'input_kwargs': input_kwargs, - 'output_args': output_args, - 'python_binding_args': python_binding_args, - } - +# If they have different number of parameters then they are not ambiguous - but +# the difference on output param can be ignored as it's optional. # -# python binding args +# multiply(Tensor input, Tensor other, *, Tensor out=None) +# multiply(Tensor input, Scalar other) # - -# TODO blowtorch -def dtype_default_type_hack(name): - if name.startswith('randperm') or name == 'tril_indices' or name == 'triu_indices': - return 'torch.int64' - else: - return 'None' - - -def make_python_binding_args(declaration): - """ - Given various properties of a declaration, build a set of scattered python binding args. - """ - name = declaration['name'] - python_binding_arguments = [] - has_tensor_input_arg = False - has_options_arg = False - for arg in declaration['arguments']: - if is_output(arg): - continue - typename = arg['simple_type'] - if typename in ['Tensor', 'TensorList']: - has_tensor_input_arg = True - elif typename == 'TensorOptions': - has_options_arg = True - if arg['name'] == 'requires_grad': - raise ValueError("argument named requires_grad not supported") - - has_tensor_return = False - for ret in declaration['returns']: - if ret['dynamic_type'] in ['Tensor', 'TensorList']: - # this probably won't work if one of the returns is not a tensor, but it will - # produce a compile-time error that is obvious - has_tensor_return = True - - category_override = declaration['category_override'] - is_like_function = name.endswith('_like') or category_override == 'like' - is_like_function_with_options = is_like_function and has_options_arg - is_new_function = name.startswith('new_') or category_override == 'new' - is_new_function_with_options = is_new_function and has_options_arg - is_factory_function = has_tensor_return and not has_tensor_input_arg or category_override == 'factory' - is_factory_or_like_or_new_function = has_tensor_return and (is_factory_function or is_like_function or is_new_function) - is_like_or_new_function_with_options = is_like_function_with_options or is_new_function_with_options - - if is_factory_function or has_options_arg: - default_type = dtype_default_type_hack(name) - py_default_dtype = 'self.scalar_type()' if is_like_or_new_function_with_options else None - dtype_arg = { - 'default': default_type, - 'dynamic_type': 'ScalarType', - 'kwarg_only': True, - 'name': 'dtype', - 'type': 'const ScalarType &', - 'simple_type': 'ScalarType', - 'python_default_init': py_default_dtype, - } - python_binding_arguments.append(dtype_arg) - - if is_factory_function or is_like_or_new_function_with_options: - py_default_layout = 'layout_from_backend(self.options().backend())' if is_like_or_new_function_with_options else None - layout_arg = { - 'default': 'torch.strided', - 'dynamic_type': 'Layout', - 'kwarg_only': True, - 'name': 'layout', - 'type': 'c10::optional', - 'simple_type': 'Layout', - 'python_default_init': py_default_layout, - } - python_binding_arguments.append(layout_arg) - py_default_device = 'self.device()' if is_like_or_new_function_with_options else None - device_arg = { - 'default': 'None', - 'dynamic_type': 'Device', - 'kwarg_only': True, - 'name': 'device', - 'type': 'const Device &', - 'simple_type': 'Device', - 'python_default_init': py_default_device - } - python_binding_arguments.append(device_arg) - pin_memory_arg = { - 'default': False, - 'dynamic_type': 'bool', - 'kwarg_only': True, - 'name': 'pin_memory', - 'type': 'bool', - 'simple_type': 'bool', - } - python_binding_arguments.append(pin_memory_arg) - - if is_factory_or_like_or_new_function: - requires_grad_arg = { - 'default': False, - 'dynamic_type': 'bool', - 'kwarg_only': True, - 'name': 'requires_grad', - 'type': 'bool', - 'simple_type': 'bool', - } - python_binding_arguments.append(requires_grad_arg) - - return python_binding_arguments - +# Both positional args and keyword-only args are considered together. # -# declaration derived props, utils, etc. -# declarations are dicts loaded from Declarations.yaml, -# passed to our codegen methods by callers in gen_autograd +# subtract(Tensor other, *, Scalar alpha=1) +# subtract(Scalar other, Scalar alpha=1) +# +# A few ambiguous cases which it does NOT handle yet. +# +# If there is any difference in other parameters besides the Tensor/Scalar +# difference, then they are not considered ambiguous by this method anymore. +# However, the difference could be too trivial to disambiguate. +# +# foo(Tensor input, Scalar other, Scalar bar) +# foo(Tensor input, Tensor other, double bar) +# +# If they are taking different number of parameters then they are not considered +# ambiguous anymore, even if the difference is only on optional kwargs. +# +# foo(Scalar other, Scalar alpha=1) +# foo(Tensor other, *, Scalar alpha=1, Scalar beta=1) # -def is_tensor_self(arg): - return arg['name'] == 'self' and arg['simple_type'] == 'Tensor' - - -def is_tensor_options(arg): - return arg['simple_type'] == 'TensorOptions' - - -def is_scatter(arg): - return arg.get('scatter_args') is not None - -def is_output(arg): - return arg.get('output', False) - - -def has_outputs(declaration): - return any([is_output(arg) for arg in declaration['arguments']]) - - -def get_tensor_options(declaration): - args = [arg for arg in declaration['arguments'] if is_tensor_options(arg)] - if len(args) == 0: - return None - if len(args) != 1: - raise RuntimeError( - '{}: multiple tensor options arguments'. - format(declaration['name'])) - return args[0] - - -def has_tensor_options(declaration): - return get_tensor_options(declaration) is not None - - -def is_torch_function(declaration): - return 'namespace' in declaration['method_of'] - - -def is_nn_module_function(declaration): - return declaration.get('python_module') == 'nn' - +def sort_overloads( + grouped_overloads: Sequence[PythonSignatureGroup] +) -> Sequence[PythonSignatureGroup]: -def is_fft_module_function(declaration): - return declaration.get('python_module') == 'fft' + def is_arg_smaller(t1: Type, t2: Type) -> bool: + return str(t1) == 'Scalar' and str(t2) == 'Tensor' -def is_linalg_module_function(declaration): - return declaration.get('python_module') == 'linalg' + def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool: + """Returns True if s1 < s2 in the partial order.""" + args1, args2 = s1.arguments(skip_outputs=True), s2.arguments(skip_outputs=True) + if len(args1) != len(args2): + return False + # TODO: should use some canonical form instead of 'str(arg.type)' - see comments + # above. The old codegen used the deprecated 'dynamic_type(arg.type)', which + # ignores the optional annotation, i.e. 'Scalar' and 'Scalar?'. + equal = all(arg1.type == arg2.type for arg1, arg2 in zip(args1, args2)) + smaller_or_equal = all(str(arg1.type) == str(arg2.type) + or is_arg_smaller(arg1.type, arg2.type) + for arg1, arg2 in zip(args1, args2)) + return smaller_or_equal and not equal + + # First sort by signature + grouped_overloads = sorted(grouped_overloads, key=lambda x: x.signature.signature_str()) + # Construct the relation graph + larger_than: Dict[int, Set[int]] = defaultdict(set) + for i1, overload1 in enumerate(grouped_overloads): + for i2, overload2 in enumerate(grouped_overloads): + if is_smaller(overload1.signature, overload2.signature): + larger_than[i1].add(i2) -def function_namespace(declaration): - # TODO look into why these can't all be 'torch' calls - if has_tensor_options(declaration) or op_name(declaration).endswith('_like'): - return 'torch' - else: - return 'at' + if not larger_than: + return list(grouped_overloads) + # Use a topological sort to sort overloads according to the partial order. + N = len(grouped_overloads) + sorted_ids: List[int] = list(filter(lambda x: x not in larger_than, range(N))) -def op_name(declaration): - name = declaration['name'] - if has_outputs(declaration): - if not name.endswith("_out"): - raise RuntimeError( - '{} has output params, expecting name ending with \'_out\''. - format(declaration['name'])) - return name[:-4] - else: - if name.endswith("_out"): - raise RuntimeError( - '{}: name ends with \'_out\', expecting output params'. - format(declaration['name'])) - return name + for idx in range(N): + # The size of sorted_ids will grow to N eventually. + i = sorted_ids[idx] + for j in sorted(larger_than.keys()): + larger = larger_than[j] + larger.discard(i) + if not larger: + del larger_than[j] + sorted_ids.append(j) + + return list(map(lambda x: grouped_overloads[x], sorted_ids)) + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Codegen API Integration +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +def emit_single_dispatch( + ps: PythonSignature, f: NativeFunction, namedtuple_typenames: Dict[str, str] +) -> str: + """ + Emit dispatch code for a single native function. + """ + @with_native_function + def go(f: NativeFunction) -> str: + # header comments + deprecated = '[deprecated] ' if ps.deprecated else '' + schema_comment = f'// {deprecated}aten::{f.func}' + + # dispatch lambda signature + name = cpp.name(f.func) + lambda_formals = ', '.join(map(lambda a: f"{a.type_str} {a.name}", + dispatch_lambda_args(ps, f))) + lambda_return = dispatch_lambda_return_str(f) + + # dispatch lambda body + dispatch_callee = cpp_dispatch_target(f) + dispatch_args = ', '.join(cpp_dispatch_exprs(f, python_signature=ps)) + + # from arg parser outputs to dispatch lambda arguments + parser_outputs = arg_parser_output_exprs(ps, f) + lambda_arg_exprs = dispatch_lambda_exprs(ps, f) + inits = '\n'.join(lambda_arg_exprs.inits) + lambda_args = ', '.join(lambda_arg_exprs.exprs) + + # scatter fields + # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky + # solution for enabling the 'requires_grad' argument for tensor methods + # new_full, new_empty, and new_zeros. A much better but more difficult to + # implement solution involves refactoring according to Ed's description here: + # https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589 + need_set_requires_grad = ps.tensor_options_args and (not has_tensor_options(f) or ( + ps.method and ('requires_grad' in parser_outputs))) + set_requires_grad = f'.set_requires_grad({parser_outputs["requires_grad"].expr})' \ + if need_set_requires_grad else '' + + if lambda_return == 'void': + return f"""\ +{schema_comment} +{inits} +auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{ + pybind11::gil_scoped_release no_gil; + {dispatch_callee}({dispatch_args}); +}}; +dispatch_{name}({lambda_args}){set_requires_grad}; +Py_RETURN_NONE; +""" + else: + typename = namedtuple_typenames.get(gen_namedtuple_typename_key(f)) + namedtuple_typeref = f'&{typename}, ' if typename is not None else '' + return f"""\ +{schema_comment} +{inits} +auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{ + pybind11::gil_scoped_release no_gil; + return {dispatch_callee}({dispatch_args}); +}}; +return wrap({namedtuple_typeref}dispatch_{name}({lambda_args}){set_requires_grad}); +""" + + return go(f) diff --git a/tools/autograd/gen_trace_type.py b/tools/autograd/gen_trace_type.py new file mode 100644 index 0000000000000..d8e68606e6bae --- /dev/null +++ b/tools/autograd/gen_trace_type.py @@ -0,0 +1,419 @@ +import itertools +from typing import Optional, List, Sequence, Union + +from tools.codegen.api.types import * +import tools.codegen.api.cpp as cpp +from tools.codegen.code_template import CodeTemplate +from tools.codegen.gen import with_native_function, parse_native_yaml, FileManager, mapMaybe +from tools.codegen.model import * + +# Note [Manual Backend kernels] +# For these ops, we want to manually register to dispatch key Backend and +# skip codegen-ed registeration to all keys before Backend. +# For codegen this means: +# - op set below must match ops with manual_kernel_registration=True in native_functions.yaml +# where we skip codegen backend kernels +# - all ops below are part of MANUAL_AUTOGRAD to skip codegen Autograd kernel registration +# - all ops below are part of MANUAL_TRACER to skip codegen Tracer kernel registration +# Note: we still register to dispatch key Profiler for these ops, keeping it untouched for now. +# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp +MANUAL_BACKEND = set([ + 'options', 'data', 'set_data', 'is_leaf', 'output_nr', '_version', 'retain_grad', + '_backward', 'requires_grad_', +]) + +# For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys. +# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp +MANUAL_AUTOGRAD_AND_TRACER = set([ + 'resize_', 'resize_as_', 'detach', 'detach_', 'copy_', '_fw_primal', +]) + +# Currently MANUAL_AUTOGRAD and MANUAL_TRACER share the same set of ops: +# union(MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER) +# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp +MANUAL_AUTOGRAD = MANUAL_TRACER = MANUAL_BACKEND | MANUAL_AUTOGRAD_AND_TRACER + +# These functions we don't want to record for tracing, because we always want +# to trace their constituent parts. This is a temporary hack in lieue +# of proper scopes, where subsequent compilation passes can ask for the unfolding +# on demand. Only concrete ATen methods can be disabled this way; it will have +# NO EFFECT otherwise. +DONT_RECORD_TRACE = { + 'convolution', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', + 'conv_transpose2d', 'conv_transpose3d', 'lstm_cell', 'gru_cell', + 'rnn_tanh_cell', 'rnn_relu_cell', 'linear', + # FIXME: figure out a better way when we support sparse tensors in jit + '_coalesced', +} + +def should_trace(f: NativeFunction) -> bool: + # Operations involving Storage or Type are not traceable at the moment + if any(str(arg.type) in {'Storage', 'Type', 'ConstQuantizerPtr'} + for arg in f.func.schema_order_arguments()): + return False + # We can't trace functions which don't have any Tensor or TensorList returns + if not any(r.type.is_tensor_like() for r in f.func.returns): + return False + return f.func.name.name.base not in DONT_RECORD_TRACE + +SELECT = CodeTemplate("""\ + +if (${cond}) { + ${true} +} else { + ${false} +} +""") + +OP_NAME = CodeTemplate("""\ +op_name = jit::Symbol::fromQualString("aten::${trace_name}"); +""") + +# These functions have their names recorded under trace renamed, +RENAME_TRACE = { + 'zero': 'zeros_like', # replacing aten::zero_ with aten::zeros_like + 'fill': 'full_like', # replacing aten::fill_ with aten::full_like +} + +def format_trace_op_name(f: NativeFunction) -> str: + # TODO: byte-for-byte compatible with old codegen behavior - should clean up + if f.func.kind() in (SchemaKind.functional, SchemaKind.out) or f.func.name.name.dunder_method: + # special case for *_out functions: the in-place and out-of-place ops + # are overloaded with the same name in the JIT + trace_name = str(f.func.name.name) + trace_name = RENAME_TRACE.get(trace_name, trace_name) + return OP_NAME.substitute(trace_name=trace_name) + + # otherwise, this is an in-place op and we need to emit both in- and + # out-of-place versions + outplace_trace_name = f.func.name.name.base + inplace_trace_name = cpp.name(f.func) + outplace_trace_name = RENAME_TRACE.get(outplace_trace_name, outplace_trace_name) + inplace_trace_name = RENAME_TRACE.get(inplace_trace_name, inplace_trace_name) + + return SELECT.substitute( + cond='tracer_state->force_outplace', + true=OP_NAME.substitute(trace_name=outplace_trace_name), + false=OP_NAME.substitute(trace_name=inplace_trace_name), + ) + +ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${input});""") + +def format_trace_inputs(f: NativeFunction) -> str: + + def dispatch_trace_input(arg: Union[Argument, TensorOptionsArguments]) -> Sequence[str]: + if isinstance(arg, TensorOptionsArguments): + name = 'options' + return [ + ADD_TRACE_INPUT.substitute(name=name, input='optTypeMetaToScalarType(options.dtype_opt())'), + ADD_TRACE_INPUT.substitute(name=name, input='options.layout()'), + ADD_TRACE_INPUT.substitute(name=name, input='options.device()'), + ADD_TRACE_INPUT.substitute(name=name, input='options.pinned_memory()'), + ] + else: + name = arg.name + if str(arg.type) == 'Tensor?[]': + return [f'jit::tracer::addInputs(node, "{name}", {name});'] + else: + return [ADD_TRACE_INPUT.substitute(name=name, input=name)] + + args: List[Union[Argument, TensorOptionsArguments]] = list(f.func.schema_order_arguments()) + + if f.func.is_out_fn(): + # *_out functions take the result as a separate argument, but we don't want to + # trace that argument directly. Instead, we trace its TensorOptions. + # So first, we need to remove the out argument from the list of arguments to trace. + # TODO: byte-for-byte compatible with old codegen behavior - it's incorrect to assume + # there is only one output argument. + args = args[:-1] + + trace_inputs = itertools.chain.from_iterable(dispatch_trace_input(arg) for arg in args) + + if f.func.is_out_fn(): + # for *_out functions, handle the result argument differently for inplace/outplace. + # For inplace: just add the input to the end to confirm with the JIT schema + name = f.func.arguments.out[0].name # TODO: old codegen behavior - should fix + inplace = ADD_TRACE_INPUT.substitute(name=name, input=name) + + # for outplace: do nothing, except if the function is a factory. + # Factories are a bit special because their out-of-place overloads + # take an extra TensorOptions argument, which is missing in the _out function + has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns) + has_tensor_input_arg = any(a.type.is_tensor_like() for a in f.func.arguments.flat_non_out) + is_factory_method = f.category_override == 'factory' or (has_tensor_return and not has_tensor_input_arg) + + # HACK: preserve old codegen behavior - the old codegen set the `is_factory_method` + # flag for the whole family of ops with the same basename if any of them is a + # factory method. For most cases the whole family of ops are indeed all factory + # method - 'normal' is the only exception. So we handle it specially here to avoid + # cloning the old logic. + if f.func.name.name.base == 'normal': + is_factory_method = True + + if is_factory_method: + outplace = [ + ADD_TRACE_INPUT.substitute(name='out', input='optTypeMetaToScalarType(out.options().dtype_opt())'), + ADD_TRACE_INPUT.substitute(name='out', input='out.options().layout()'), + ADD_TRACE_INPUT.substitute(name='out', input='out.options().device()'), + ADD_TRACE_INPUT.substitute(name='out', input='out.options().pinned_memory()'), + ] + else: + outplace = [] + + trace_inputs = itertools.chain( + trace_inputs, + [SELECT.substitute(cond='tracer_state->force_outplace', true='\n'.join(outplace), false=inplace)]) + + return '\n'.join(trace_inputs) + +# `torch.jit.trace` have undocumented keyword argument `_force_outplace`, +# which force jit to replace functions with outplace variants (for +# example `aten::add_` becomes `aten::add`). +# +# This replacement implemented in-place with minimum modifications of +# arguments stack (as it assumes that outplace call has the same arguments +# as inplace version). +# +# However there are no such substitutions available for `aten::fill_` +# and `aten::zero_` operators, as we never implemented `aten::fill` +# and `aten::zero`. So jit tracing hack replacing `aten::zero_` with +# `aten::zeros_like` and replacing `aten::fill_` with `aten::full_like`. +# +# But as they potentially can have different arguments, we also have +# to hack into the stack and add missing ones. +# +# A possible alternative would be: +# +# - Add `aten::fill` and `aten::zero` +# +# - Or keep `aten::zeros_like` arguments aligned with `aten::zero_` +# arguments (inside of the `native_functions.yaml`) +RENAME_TRACE_ADD_ARGS = { + 'fill': '''\ + jit::tracer::addInputs(node, "options", c10::optional()); + jit::tracer::addInputs(node, "options", layout_or_default(c10::nullopt)); + jit::tracer::addInputs(node, "options", device_or_default(c10::nullopt)); + jit::tracer::addInputs(node, "options", pinned_memory_or_default(c10::nullopt)); + c10::optional memory_format = c10::MemoryFormat::Preserve; + jit::tracer::addInputs(node, "memory_format", memory_format); +''', + 'zero': '''\ + jit::tracer::addInputs(node, "options", c10::optional()); + jit::tracer::addInputs(node, "options", layout_or_default(c10::nullopt)); + jit::tracer::addInputs(node, "options", device_or_default(c10::nullopt)); + jit::tracer::addInputs(node, "options", pinned_memory_or_default(c10::nullopt)); + c10::optional memory_format = c10::MemoryFormat::Preserve; + jit::tracer::addInputs(node, "memory_format", memory_format); +''', +} + +INPLACE_GUARD = CodeTemplate("""\ +jit::tracer::ensureUniqueIfOutOfPlaced("${name}", ${mutable_input}); +""") + +PRE_RECORD_TRACE = CodeTemplate("""\ +torch::jit::Node* node = nullptr; +std::shared_ptr tracer_state; +if (jit::tracer::isTracing()) { + tracer_state = jit::tracer::getTracingState(); + at::Symbol op_name; + ${set_op_name} + node = tracer_state->graph->create(op_name, /*num_outputs=*/0); + jit::tracer::recordSourceLocation(node); + ${add_trace_inputs} + tracer_state->graph->insertNode(node); + ${inplace_guard} + jit::tracer::setTracingState(nullptr); +} +""") + +def format_prerecord_trace(f: NativeFunction) -> str: + if not should_trace(f): + return '' + + # TODO: clean up old codegen behavior + is_inplace = f.func.kind() in (SchemaKind.inplace, SchemaKind.out) and not f.func.name.name.dunder_method + add_args = RENAME_TRACE_ADD_ARGS.get(f.func.name.name.base, '') if is_inplace else '' + additional_inputs = SELECT.substitute( + cond='tracer_state->force_outplace', + true=add_args, + false='', + ) if add_args else '' + + return PRE_RECORD_TRACE.substitute( + set_op_name=format_trace_op_name(f), + add_trace_inputs=format_trace_inputs(f) + additional_inputs, + inplace_guard=INPLACE_GUARD.substitute( + name=cpp.name(f.func), + mutable_input=f.func.arguments.out[0].name if f.func.arguments.out else 'self', + ) if is_inplace else '', + ) + +POST_RECORD_TRACE = CodeTemplate("""\ +if (tracer_state) { + jit::tracer::setTracingState(std::move(tracer_state)); + ${add_trace_outputs} +} +""") + +def format_postrecord_trace(f: NativeFunction) -> str: + if not should_trace(f): + return '' + + # For outplacing ops, *_out overloads require special handling to move the + # output *argument* to a return value + if f.func.is_out_fn(): + output_names_outplace = [arg.name for arg in f.func.arguments.out] + output_names_inplace = cpp.return_names(f) + + # Code size optimization: the common case is that the return value is + # the same for both variants + if output_names_outplace == output_names_inplace: + outputs = [f'jit::tracer::addOutput(node, {n});' for n in output_names_outplace] + return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs) + + selection = SELECT.substitute( + cond='force_outplace', + true='\n'.join(f'jit::tracer::addOutput(node, {n});' for n in output_names_outplace), + false='\n'.join(f'jit::tracer::addOutput(node, {n});' for n in output_names_inplace), + ) + return POST_RECORD_TRACE.substitute(add_trace_outputs=selection) + else: + output_names = cpp.return_names(f) + outputs = [f'jit::tracer::addOutput(node, {n});' for n in output_names] + return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs) + +def declare_returned_variables(f: NativeFunction) -> str: + modifies_arguments = f.func.kind() in (SchemaKind.inplace, SchemaKind.out) + if modifies_arguments: + return '' + if len(f.func.returns) == 1: + return '' + types = map(cpp.return_type, f.func.returns) + names = cpp.return_names(f) + return '\n'.join(f'{type} {name};' for type, name in zip(types, names)) + +def tie_return_values(f: NativeFunction) -> str: + if len(f.func.returns) == 1: + return f'auto {f.func.returns[0].name or "result"}' + names = cpp.return_names(f) + return f'std::tie({", ".join(names)})' + +def get_return_value(f: NativeFunction) -> str: + names = cpp.return_names(f) + if len(f.func.returns) == 1: + return names[0] + if f.func.kind() == SchemaKind.out: + return f'std::forward_as_tuple({", ".join(names)})' + else: + moved = ", ".join(f'std::move({name})' for name in names) + return f'std::make_tuple({moved})' + +TRACE_DISPATCH = CodeTemplate("""\ +static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("aten::${operator_name}", "${overload_name}") + .typed<${arg_types}>(); +${assign_return_values}c10::Dispatcher::singleton() + .redispatch<${ret_and_arg_types}>(${redispatch_args}); +""") + +def emit_trace_body(f: NativeFunction) -> List[str]: + trace_body: List[str] = [] + + trace_body.append(format_prerecord_trace(f)) + trace_body.append(declare_returned_variables(f)) + + dispatcher_sig = DispatcherSignature.from_schema(f.func) + dispatcher_exprs = dispatcher_sig.exprs() + + ret_and_arg_types = ', '.join([dispatcher_sig.returns_type()] + [a.type.cpp_type() for a in dispatcher_exprs]) + redispatch_args = ', '.join(['op', 'c10::DispatchKey::Tracer'] + [a.expr for a in dispatcher_exprs]) + + assign_return_values = f'{tie_return_values(f)} = ' \ + if f.func.kind() == SchemaKind.functional and f.func.returns else '' + + trace_body.append(TRACE_DISPATCH.substitute( + operator_name=f.func.name.name, + overload_name=f.func.name.overload_name, + arg_types=dispatcher_sig.type(), + assign_return_values=assign_return_values, + ret_and_arg_types=ret_and_arg_types, + redispatch_args=redispatch_args, + )) + + trace_body.append(format_postrecord_trace(f)) + if f.func.returns: + trace_body.append(f'return {get_return_value(f)};') + return trace_body + +METHOD_DEFINITION = CodeTemplate("""\ +${return_type} ${type_wrapper_name}(${formals}) { + ${type_definition_body} +} +""") + +def type_wrapper_name(f: NativeFunction) -> str: + if f.func.name.overload_name: + return f'{cpp.name(f.func)}_{f.func.name.overload_name}' + else: + return cpp.name(f.func) + +@with_native_function +def method_definition(f: NativeFunction) -> Optional[str]: + if cpp.name(f.func) in MANUAL_TRACER: + return None + + formals = ', '.join( + f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}' + for a in f.func.schema_order_arguments() + ) + + return METHOD_DEFINITION.substitute( + return_type=cpp.returns_type(f.func.returns), + type_wrapper_name=type_wrapper_name(f), + formals=formals, + type_definition_body=emit_trace_body(f), + ) + +WRAPPER_REGISTRATION = CodeTemplate("""\ +m.impl("${name}", + TORCH_FN(${class_type}::${type_wrapper_name}) +); +""") + +@with_native_function +def method_registration(f: NativeFunction) -> Optional[str]: + if cpp.name(f.func) in MANUAL_TRACER: + return None + + return WRAPPER_REGISTRATION.substitute( + name=f.func.name, + type_wrapper_name=type_wrapper_name(f), + class_type='TraceType', + ) + +def gen_trace_type_shard( + fm: FileManager, native_functions: Sequence[NativeFunction], suffix: str +) -> None: + fm.write_with_template('TraceType%s.cpp' % suffix, 'TraceType.cpp', lambda: { + 'generated_comment': '@' + f'generated from {fm.template_dir}/TraceType.cpp', + 'trace_method_definitions': list(mapMaybe(method_definition, native_functions)), + 'trace_wrapper_registrations': list(mapMaybe(method_registration, native_functions)), + }) + +def gen_trace_type(out: str, native_yaml_path: str, template_path: str) -> None: + # NOTE: see Note [Sharded File] at the top of the VariableType.cpp + # template regarding sharding of the generated files. + num_shards = 5 + shards: List[List[NativeFunction]] = [[] for _ in range(num_shards)] + + # functions are assigned arbitrarily but stably to a file based on hash + native_functions = list(sorted(parse_native_yaml(native_yaml_path), key=lambda f: cpp.name(f.func))) + for f in native_functions: + x = sum(ord(c) for c in cpp.name(f.func)) % num_shards + shards[x].append(f) + + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + for i, shard in enumerate(shards): + gen_trace_type_shard(fm, shard, '_%d' % i) + gen_trace_type_shard(fm, native_functions, 'Everything') diff --git a/tools/autograd/gen_variable_factories.py b/tools/autograd/gen_variable_factories.py index 72fdad71648d3..f8ab30dc45800 100644 --- a/tools/autograd/gen_variable_factories.py +++ b/tools/autograd/gen_variable_factories.py @@ -3,77 +3,81 @@ # This writes one file: variable_factories.h import re +from typing import Optional, List -from .utils import CodeTemplate, write - - -FUNCTION_TEMPLATE = CodeTemplate("""\ -inline at::Tensor ${name}(${formals}) { - at::Tensor tensor = ([&]() { - at::AutoNonVariableTypeMode non_var_type_mode(true); - return at::${name}(${actuals}); - })(); - at::Tensor result = - autograd::make_variable(std::move(tensor), /*requires_grad=*/${requires_grad}); - return result; -} -""") - +from tools.codegen.api.types import * +import tools.codegen.api.cpp as cpp +import tools.codegen.api.python as python +from tools.codegen.gen import with_native_function, parse_native_yaml, FileManager, mapMaybe +from tools.codegen.model import * OPTIONAL_TYPE_PATTERN = re.compile(r"c10::optional<(.+)>") TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)") - -def fully_qualified_type(argument_type): - def maybe_optional_type(t, opt_match): - return 'c10::optional<{}>'.format(t) if opt_match else t +# Add 'at::' to types defined in ATen namespace, e.g. Tensor, TensorList, IntArrayRef and etc. +# TODO: maybe update the cpp argument API to take optional namespace argument? +def fully_qualified_type(argument_type: str) -> str: + def maybe_optional_type(type: str, is_opt: bool) -> str: + return f'c10::optional<{type}>' if is_opt else type opt_match = OPTIONAL_TYPE_PATTERN.match(argument_type) + is_opt = opt_match is not None if opt_match: argument_type = argument_type[opt_match.start(1):opt_match.end(1)] match = TYPE_PATTERN.match(argument_type) if match is None: - return maybe_optional_type(argument_type, opt_match) + return maybe_optional_type(argument_type, is_opt) index = match.start(1) - qualified_type = "{}at::{}".format(argument_type[:index], argument_type[index:]) - return maybe_optional_type(qualified_type, opt_match) + qualified_type = f'{argument_type[:index]}at::{argument_type[index:]}' + return maybe_optional_type(qualified_type, is_opt) + +def gen_variable_factories(out: str, native_yaml_path: str, template_path: str) -> None: + native_functions = parse_native_yaml(native_yaml_path) + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write_with_template('variable_factories.h', 'variable_factories.h', lambda: { + 'generated_comment': '@' + f'generated from {fm.template_dir}/variable_factories.h', + 'function_definitions': list(mapMaybe(process_function, native_functions)), + }) +@with_native_function +def process_function(f: NativeFunction) -> Optional[str]: + name = cpp.name(f.func) + has_tensor_options = python.has_tensor_options(f) + is_factory = has_tensor_options or name.endswith("_like") -def gen_variable_factories(out, declarations, template_path): - function_definitions = [] - for decl in declarations: - has_tensor_options = any(a["simple_type"] == "TensorOptions" for a in decl["arguments"]) - is_namespace_fn = 'namespace' in decl['method_of'] - if (has_tensor_options or decl["name"].endswith("_like")) and is_namespace_fn: - function_definitions.append( - process_function( - decl, - has_tensor_options, - ) - ) - write(out, - "variable_factories.h", - CodeTemplate.from_file(template_path + "/variable_factories.h"), - {"function_definitions": function_definitions}) + if Variant.function not in f.variants or not is_factory: + return None + sig = CppSignatureGroup.from_native_function(f, method=False).signature + formals: List[str] = [] + exprs: List[str] = [] + requires_grad = 'false' + for arg in sig.arguments(): + qualified_type = fully_qualified_type(arg.type) + if arg.default: + formals.append(f'{qualified_type} {arg.name} = {arg.default}') + else: + formals.append(f'{qualified_type} {arg.name}') -def process_function(decl, has_tensor_options): - formals = [] - actuals = [] - for argument in decl["arguments"]: - type = fully_qualified_type(argument["type"]) - default = " = {}".format(argument["default"]) if "default" in argument else "" - formals.append("{} {}{}".format(type, argument["name"], default)) - actual = argument["name"] - if argument["simple_type"] == "TensorOptions": + if isinstance(arg.argument, TensorOptionsArguments): # note: we remove the requires_grad setting from the TensorOptions because # it is ignored anyways (and we actually have an assertion that it isn't set # which would fail otherwise). We handle requires_grad explicitly here # instead of passing it through to the kernel. - actual = "at::TensorOptions({}).requires_grad(c10::nullopt)".format(actual) - actuals.append(actual) - requires_grad = "options.requires_grad()" if has_tensor_options else "false" + exprs.append(f'at::TensorOptions({arg.name}).requires_grad(c10::nullopt)') + # Manually set the requires_grad bit on the result tensor. + requires_grad = f'{arg.name}.requires_grad()' + else: + exprs.append(arg.name) - return FUNCTION_TEMPLATE.substitute( - name=decl["name"], formals=formals, actuals=actuals, requires_grad=requires_grad - ) + return f"""\ +inline at::Tensor {name}({', '.join(formals)}) {{ + at::Tensor tensor = ([&]() {{ + at::AutoNonVariableTypeMode non_var_type_mode(true); + return at::{name}({', '.join(exprs)}); + }})(); + at::Tensor result = + autograd::make_variable(std::move(tensor), /*requires_grad=*/{requires_grad}); + return result; +}} +""" diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 89bf64d8149e3..3e48c053f2e01 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -22,104 +22,24 @@ # which will in turn dispatch back to VariableType for its # differentiable subcomponents. # +from dataclasses import dataclass -from .utils import CodeTemplate, nested_dict, write, uninplace_api_name from .gen_autograd import VIEW_FUNCTIONS, VIEW_FUNCTIONS_WITH_METADATA_CHANGE, \ MULTI_OUTPUT_SAFE_FUNCTIONS, RETURNS_VIEWS_OF_INPUT from .gen_autograd_functions import uses_single_grad - -# These functions we don't want to record for tracing, because we always want -# to trace their constituent parts. This is a temporary hack in lieue -# of proper scopes, where subsequent compilation passes can ask for the unfolding -# on demand. Only concrete ATen methods can be disabled this way; it will have -# NO EFFECT otherwise. -DONT_RECORD_TRACE = { - 'convolution', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', - 'conv_transpose2d', 'conv_transpose3d', 'lstm_cell', 'gru_cell', - 'rnn_tanh_cell', 'rnn_relu_cell', 'linear', - # FIXME: figure out a better way when we support sparse tensors in jit - '_coalesced_', -} - -# These functions have their names recorded under trace renamed, -RENAME_TRACE = { - 'zero': 'zeros_like', # replacing aten::zero_ with aten::zeros_like - 'fill': 'full_like', # replacing aten::fill_ with aten::full_like -} - -# `torch.jit.trace` have undocumented keyword argument `_force_outplace`, -# which force jit to replace functions with outplace variants (for -# example `aten::add_` becomes `aten::add`). -# -# This replacement implemented in-place with minimum modifications of -# arguments stack (as it assumes that outplace call has the same arguments -# as inplace version). -# -# However there are no such substitutions available for `aten::fill_` -# and `aten::zero_` operators, as we never implemented `aten::fill` -# and `aten::zero`. So jit tracing hack replacing `aten::zero_` with -# `aten::zeros_like` and replacing `aten::fill_` with `aten::full_like`. -# -# But as they potentially can have different arguments, we also have -# to hack into the stack and add missing ones. -# -# A possible alternative would be: -# -# - Add `aten::fill` and `aten::zero` -# -# - Or keep `aten::zeros_like` arguments aligned with `aten::zero_` -# arguments (inside of the `native_functions.yaml`) -RENAME_TRACE_ADD_ARGS = { - 'fill': '''\ - jit::tracer::addInputs(node, "options", TensorOptions()); - c10::optional memory_format = c10::MemoryFormat::Preserve; - jit::tracer::addInputs(node, "memory_format", memory_format); -''', - 'zero': '''\ - jit::tracer::addInputs(node, "options", TensorOptions()); - c10::optional memory_format = c10::MemoryFormat::Preserve; - jit::tracer::addInputs(node, "memory_format", memory_format); -''', -} - -# (declaration name, argument name) -> attribute name -RENAME_ATTRIBUTES = { - ('fill_', 'value'): 'fill_value' -} - -# These functions are not worth profiling because they are very cheap and may -# be called very often. -DONT_PROFILE = { - 'data_ptr', 'get_device', 'is_contiguous', 'is_cuda', 'is_distributed', - 'is_same_size', 'is_set_to', 'is_signed', 'is_sparse', 'numel', - 'size', 'storage_offset', 'stride', -} - -# Note [Manual catchAll kernels] -# For these ops, we want to manually register to dispatch key catchAll and -# skip codegen-ed registeration to all keys before catchAll. -# For codegen this means: -# - op set below must match ops with manual_kernel_registration=True in native_functions.yaml -# where we skip codegen catchall kernels -# - all ops below are part of MANUAL_AUTOGRAD to skip codegen Autograd kernel registration -# - all ops below are part of MANUAL_TRACER to skip codegen Tracer kernel registration -# Note: we still register to dispatch key Profiler for these ops, keeping it untouched for now. -# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp -MANUAL_CATCHALL = set([ - 'options', 'data', 'set_data', 'is_leaf', 'output_nr', '_version', 'retain_grad', - 'backward', 'requires_grad_', -]) - -# For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys. -# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp -MANUAL_AUTOGRAD_AND_TRACER = set([ - 'resize_', 'resize_as_', 'detach', 'detach_', 'copy_', -]) - -# Currently MANUAL_AUTOGRAD and MANUAL_TRACER share the same set of ops: -# union(MANUAL_CATCHALL, MANUAL_AUTOGRAD_AND_TRACER) -# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp -MANUAL_AUTOGRAD = MANUAL_TRACER = MANUAL_CATCHALL | MANUAL_AUTOGRAD_AND_TRACER +from .gen_trace_type import ( + MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER, MANUAL_AUTOGRAD, + declare_returned_variables, tie_return_values, get_return_value, type_wrapper_name, +) + +from tools.codegen.api.types import * +from tools.codegen.api.autograd import * +import tools.codegen.api.cpp as cpp +from tools.codegen.code_template import CodeTemplate +from tools.codegen.gen import with_native_function, parse_native_yaml, FileManager, mapMaybe +from tools.codegen.model import * +from tools.codegen.selective_build.selector import SelectiveBuilder +from typing import Callable, List, Optional, Sequence, Tuple, Union # We don't set or modify grad_fn on these methods. Generally, they return # tensors that have requires_grad=False. In-place functions listed here will @@ -139,7 +59,32 @@ 'quantize_per_tensor', 'quantize_per_channel', # Functions that return integers should not have output that require gradients 'argmax', 'argmin', 'argsort', 'searchsorted', - 'bucketize' + 'bucketize', + # Functions that return booleans are not differentiable + 'isnan', 'isposinf', 'isneginf', 'isinf' + # Functions return none are not differentiable + 'record_stream', +} + +# The C -> R functions at the time of adding this are still being audited and tested +# but will not error out. +# C -> C, R -> C functions for which backward is correctly implemented and tested +GRADIENT_IMPLEMENTED_FOR_COMPLEX = { + 't', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone', + 'repeat', 'expand', 'flip', 'fliplr', 'flipud', 'rot90', 'transpose', + 'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', + 'triu', 'chunk', 'zero_', 'eq_', 'ne_', 'add', '__radd__', 'sum', + '_conj', 'sin', 'cos', 'mul', 'sinc', 'sinh', 'cosh', '__rmul__', + 'sgn', 'asin', 'acos', 'sub', 'div', 'cat', 'view_as_complex', + 'neg', 'complex', 'select', '_s_where', 'as_strided', 'slice', 'constant_pad_nd', + 'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward', + 'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger', + 'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal', + 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'atanh', 'take', 'fill_', + 'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv', + 'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'linalg_qr', '_svd_helper', '_fft_c2c', '_fft_r2c', + 'linalg_solve', 'sqrt', 'stack', 'gather', 'index_select', 'index_add_', 'linalg_inv', + 'l1_loss_backward', 'baddbmm', 'addbmm', 'addmm', 'addmv' } # Some operators invalidate the grad_accumulator. Let's reset it. @@ -178,6 +123,21 @@ } """) +SAVE_OPTIONALTENSORLIST_STORAGE = CodeTemplate("""\ +std::vector> ${tensorlist_name}_storage_saved(${tensorlist_name}.size()); +for (const c10::optional& tensor : ${tensorlist_name}) + ${tensorlist_name}_storage_saved.push_back( + tensor.has_value() && tensor->has_storage() ? c10::optional(tensor->storage()) : c10::nullopt); +""") + +ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE = CodeTemplate("""\ +for (size_t i=0; i<${tensorlist_name}.size(); i++) { + if (${tensorlist_name}_storage_saved[i].has_value()) + AT_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of( + static_cast>(${tensorlist_name}[i])->storage())); +} +""") + SAVE_TENSOR_IMPL = CodeTemplate("""\ c10::intrusive_ptr ${tensor_name}_impl_saved; if (${tensor_name}.defined()) ${tensor_name}_impl_saved = ${tensor_name}.getIntrusivePtr(); @@ -200,6 +160,21 @@ } """) +SAVE_OPTIONALTENSORLIST_IMPL = CodeTemplate("""\ +std::vector> ${tensorlist_name}_impl_saved(${tensorlist_name}.size()); +for (size_t i=0; i<${tensorlist_name}.size(); i++) { + c10::optional t = ${tensorlist_name}[i]; + if (t.has_value() && t->defined()) ${tensorlist_name}_impl_saved[i] = t->getIntrusivePtr(); +} +""") + +ENFORCE_SAME_OPTIONALTENSORLIST_IMPL = CodeTemplate("""\ +for (size_t i=0; i<${tensorlist_name}.size(); i++) { + if (${tensorlist_name}_impl_saved[i]) + AT_ASSERT(${tensorlist_name}_impl_saved[i] == static_cast>(${tensorlist_name}[i])->getIntrusivePtr()); +} +""") + # The following list contains functions that we don't enforce the invariant on. DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = { # These functions are expected to change impl or storage of input tensors @@ -217,49 +192,41 @@ } """) -# NOTE[UnboxedOnly] Many of our codegen templates currently exist twice, once -# in an _UNBOXEDONLY_ variant and once without _UNBOXEDONLY_. This is because -# ops that are `use_c10_dispatcher: full` need different c++ code than ops -# that aren't `use_c10_dispatcher: full` yet. The _UNBOXEDONLY_ variants -# are for ops that aren't `use_c10_dispatcher: full` yet and those code templates -# can be deleted once all ops are `use_c10_dispatcher: full`. -# If you update one of the templates, you likely also have to update the other. - -# See NOTE[UnboxedOnly] -UNBOXEDONLY_WRAPPER_REGISTRATION = CodeTemplate("""\ -m.impl_UNBOXED("${unqual_operator_name_with_overload}", &${class_type}::${type_wrapper_name}); -""") - WRAPPER_REGISTRATION = CodeTemplate("""\ m.impl("${unqual_operator_name_with_overload}", - c10::impl::hacky_wrapper_for_legacy_signatures<${schema_order_cpp_signature}>(TORCH_FN(${class_type}::${type_wrapper_name})) + TORCH_FN(${class_type}::${type_wrapper_name}) ); """) UNPACK_TENSOR = CodeTemplate("""\ auto${ref} ${arg_name}_ = unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos});""") -UNPACK_OPTIONS = CodeTemplate("""\ -auto ${arg_name}_ = TensorOptions(${arg_name});""") - DECLARE_GRAD_FN = CodeTemplate("""\ std::shared_ptr<${op}> grad_fn; """) +SETUP_ANY_REQUIRES_GRAD = CodeTemplate("""\ +auto _any_requires_grad = compute_requires_grad( ${args_with_derivatives} ); +(void)_any_requires_grad; +""") + SETUP_DERIVATIVE = CodeTemplate("""\ -if (compute_requires_grad( ${args_with_derivatives} )) { +if (_any_requires_grad) { ${setup} } """) +SETUP_NONE_REQUIRES_GRAD = CodeTemplate("""\ +if (compute_requires_grad( ${args_to_check} )) { + throw_error_out_requires_grad("${base_name}"); +} +""") + ASSIGN_GRAD_FN = CodeTemplate("""\ grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteNode); grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} )); """) -CALL_DEFAULT = CodeTemplate("""\ -TypeDefault::${type_wrapper_name}(${args})""") - CALL_DISPATCH_VIA_NAMESPACE = CodeTemplate("""\ at::${api_name}(${unpacked_args})""") @@ -289,7 +256,7 @@ """) SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate("""\ -c10::optional> func=c10::nullopt; +std::function func=nullptr; if (${is_view_with_metadata_change} || !self.unsafeGetTensorImpl()->support_as_strided()) { ${replay_view_func} } @@ -320,299 +287,24 @@ } """) -SELECT = CodeTemplate("""\ - -if (${cond}) { - ${true} -} else { - ${false} -} -""") - -OP_NAME = CodeTemplate("""\ -op_name = jit::Symbol::fromQualString("aten::${trace_name}"); -""") - -PRE_RECORD_TRACE = CodeTemplate("""\ -torch::jit::Node* node = nullptr; -std::shared_ptr tracer_state; -if (jit::tracer::isTracing()) { - tracer_state = jit::tracer::getTracingState(); - at::Symbol op_name; - ${set_op_name} - node = tracer_state->graph->create(op_name, /*num_outputs=*/0); - jit::tracer::recordSourceLocation(node); - ${add_trace_inputs} - tracer_state->graph->insertNode(node); - ${inplace_guard} - jit::tracer::setTracingState(nullptr); -} -""") - -INPLACE_GUARD = CodeTemplate("""\ -jit::tracer::ensureUniqueIfOutOfPlaced("${name}", ${mutable_input}); -""") - -ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${input});""") - -POST_RECORD_TRACE = CodeTemplate("""\ -if (tracer_state) { - jit::tracer::setTracingState(std::move(tracer_state)); - ${add_trace_outputs} -} -""") - RUN_ONLY_IN_DEBUG_MODE = CodeTemplate("""\ #ifndef NDEBUG ${statements} #endif """) -# Generate a file that lists all functions and their schema string. Used for XLA -REGISTRATION_DECLARATION = CodeTemplate("""\ -${return_type} ${api_name}(${declaration_formals}); \ -// {"schema": "${schema_string}", "compound": "${compound}", "has_math_kernel": "${has_math_kernel}"} -""") - -# TraceType templates -# TODO: change `redispatch` to `NoTracerDispatchMode` + regular `call`. -# See NOTE[UnboxedOnly] -UNBOXED_TRACE_DISPATCH = CodeTemplate("""\ -static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("aten::${operator_name}", "${overload_name}") - .typed<${return_type} (${arg_types})>(); -${assign_return_values}c10::Dispatcher::singleton().redispatch<${ret_and_arg_types}>(${trace_dispatch_args}); -""") -TRACE_DISPATCH = CodeTemplate("""\ -static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("aten::${operator_name}", "${overload_name}") - .typed<${return_type} (${schema_order_arg_types})>(); -${assign_return_values}c10::Dispatcher::singleton() - .redispatch<${schema_order_ret_and_arg_types}>(${schema_order_trace_dispatch_args}); -""") - - -FACTORY_FUNCTION_NAMES = None - -# TODO The maybe_unwrap_optional_tensors is only needed because our at::native::xxx functions -# still take "Tensor" instead of "optional", so we need CPUType, TypeDefault, ... -# to do the same. Once at::native::xxx are converted, we can remove use_optional_tensor -# and use the use_optional_tensor=True behavior always. -def maybe_unwrap_optional_tensors(option, formals, args): - assert len(formals) == len(args), \ - "Assert we didn't screw up with method_args removing self but forgetting to remove it from formals" - if option['use_c10_dispatcher'] == 'full': - def maybe_unwrap_optional_tensor(formal, arg): - if formal['dynamic_type'] == 'Tensor' and formal['is_nullable']: - return "{}.has_value() ? *{} : at::Tensor()".format(arg, arg) - else: - return arg - return [maybe_unwrap_optional_tensor(formal, arg) for (formal, arg) in zip(formals, args)] - else: - assert option['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper' - return args - - -def find_factory_functions(declarations): - global FACTORY_FUNCTION_NAMES - FACTORY_FUNCTION_NAMES = set() - - for declaration in declarations: - if declaration['is_factory_method']: - FACTORY_FUNCTION_NAMES.add(declaration['api_name']) - - -def should_trace(declaration): - # Operations involving Storage or Type are not traceable at the moment - if any(arg['simple_type'] in {'Storage', 'Type', 'ConstQuantizerPtr'} for arg in declaration['arguments']): - return False - # We can't trace functions which don't have any Tensor or TensorList returns - if 'Tensor' not in declaration['return_type']: - return False - name = declaration['name'] - base_name = name[:-1] if declaration['inplace'] else name[:-4] if name.endswith('_out') else name - if base_name in DONT_RECORD_TRACE or name in DONT_RECORD_TRACE: - return False - return True - - -def is_out_overload(declaration): - return declaration['api_name'].endswith('_out') - - -def format_postrecord_trace(declaration): - # For outplacing ops, *_out overloads require special handling to move the - # output *argument* to a return value - if is_out_overload(declaration): - output_names_outplace = [arg['name'] for arg in declaration['arguments'] if arg.get('output', False)] - output_names_inplace = [r['name'] for r in declaration['returns']] - - # Code size optimization: the common case is that the return value is - # the same for both variants - if output_names_outplace == output_names_inplace: - outputs = ['jit::tracer::addOutput(node, {});'.format(n) for n in output_names_outplace] - return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs) - - local = {} - local['cond'] = 'force_outplace' - local['true'] = ['jit::tracer::addOutput(node, {});'.format(n) for n in output_names_outplace] - local['false'] = ['jit::tracer::addOutput(node, {});'.format(n) for n in output_names_inplace] - selection = SELECT.substitute(local) - return POST_RECORD_TRACE.substitute(add_trace_outputs=selection) - - output_names = [r['name'] for r in declaration['returns']] - outputs = ['jit::tracer::addOutput(node, {});'.format(n) for n in output_names] - return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs) - - -def format_trace_op_name(declaration): - is_inplace = declaration['api_name'] != uninplace_api_name(declaration['api_name']) - - if not is_inplace or is_out_overload(declaration): - # special case for *_out functions: the in-place and out-of-place ops - # are overloaded with the same name in the JIT - trace_name = uninplace_api_name(declaration['api_name']) - trace_name = RENAME_TRACE.get(trace_name, trace_name) - return OP_NAME.substitute(trace_name=trace_name) - - # otherwise, this is an in-place op and we need to emit both in- and - # out-of-place versions - outplace_trace_name = uninplace_api_name(declaration['api_name']) - inplace_trace_name = declaration['api_name'] - outplace_trace_name = RENAME_TRACE.get(outplace_trace_name, outplace_trace_name) - inplace_trace_name = RENAME_TRACE.get(inplace_trace_name, inplace_trace_name) - - select_params = {} - select_params['cond'] = 'tracer_state->force_outplace' - select_params['true'] = OP_NAME.substitute(trace_name=outplace_trace_name) - select_params['false'] = OP_NAME.substitute(trace_name=inplace_trace_name) - - return SELECT.substitute(select_params) - - -def format_trace_inputs(declaration): - def dispatch_trace_input(arg_spec): - name, value, simple_type, nullable = arg_spec - # XXX: For arg that have type of Tensor?[], tracer will pass allow_undefined to addInputs - if simple_type == 'TensorList' and nullable: - return '''jit::tracer::addInputs(node, "{}", {}, {});'''.format(name, value, "true") - else: - return ADD_TRACE_INPUT.substitute(name=name, input=value) - - trace_inputs = declaration['arguments'] - - if is_out_overload(declaration): - # *_out functions take the result as a first argument, but they are the - # last argument in the JIT schema. - out_input = trace_inputs[0] - trace_inputs = trace_inputs[1:] - - trace_input_spec = [(i['name'], i['name'], i['simple_type'], i.get('is_nullable')) for i in trace_inputs] - - trace_inputs = \ - '\n'.join(dispatch_trace_input(arg_spec) for arg_spec in trace_input_spec) +@dataclass(frozen=True) +class NativeFunctionWithDifferentiabilityInfo: + func: NativeFunction + info: Optional[DifferentiabilityInfo] - if is_out_overload(declaration): - # for *_out functions, handle the result argument differently for inplace/outplace. - # For inplace: just add the input to the end to confirm with the JIT schema - inplace = ADD_TRACE_INPUT.substitute(name=out_input['name'], input=out_input['name']) - - # for outplace: do nothing, except if the declaration is a factory. - # Factories are a bit special because their out-of-place overloads - # take an extra TensorOptions argument, which is missing in the _out function - trace_name = uninplace_api_name(declaration['api_name']) - has_factory_name = trace_name in FACTORY_FUNCTION_NAMES - if has_factory_name: - outplace = ADD_TRACE_INPUT.substitute(name='out', input='out.options()') - else: - outplace = '' - - trace_inputs += '\n' - trace_inputs += SELECT.substitute( - cond='tracer_state->force_outplace', true=outplace, false=inplace) - - return trace_inputs - - -def format_prerecord_trace(declaration): - local = {} - is_inplace = declaration['api_name'] != uninplace_api_name(declaration['api_name']) - - local['set_op_name'] = format_trace_op_name(declaration) - - is_inplace = declaration['api_name'] != uninplace_api_name(declaration['api_name']) - add_args = '' - if is_inplace: - api_name = uninplace_api_name(declaration['api_name']) - add_args = RENAME_TRACE_ADD_ARGS.get(api_name, '') - if add_args: - select_params = {} - select_params['cond'] = 'tracer_state->force_outplace' - select_params['true'] = add_args - select_params['false'] = '' - additional_inputs = SELECT.substitute(select_params) - else: - additional_inputs = '' - local['add_trace_inputs'] = format_trace_inputs(declaration) + additional_inputs - - local['inplace_guard'] = '' - if is_inplace: - local['inplace_guard'] = INPLACE_GUARD.substitute( - name=declaration['api_name'], - mutable_input=declaration['arguments'][0]['name']) - - return PRE_RECORD_TRACE.substitute(local) - - -def format_trace(declaration): - if not should_trace(declaration): - return ('', '') - return (format_prerecord_trace(declaration), format_postrecord_trace(declaration)) - - -# Methods shared by TraceType and VariableType to handle return variable declaration, tie and tuple. -def format_return_variables(declaration): - name = declaration['name'] - arguments = declaration['arguments'] - inplace = declaration['inplace'] - is_out_fn = name.endswith('_out') - modifies_arguments = inplace or is_out_fn - - def declare_returned_variables(): - if modifies_arguments: - return '' - if len(declaration['returns']) == 1: - return '' - # TODO: this will be ugly - names = [ret['type'] + ' ' + ret['name'] + ';' for ret in declaration['returns']] - return '\n'.join(names) - - def tie_return_values(): - if len(declaration['returns']) == 1: - return 'auto {}'.format(declaration['returns'][0]['name']) - names = [ret['name'] for ret in declaration['returns']] - return 'std::tie({})'.format(', '.join(names)) - - def get_return_value(): - if inplace: - return 'self' - if is_out_fn: - return_names = [arg['name'] for arg in arguments - if arg.get('output', False)] - if len(return_names) == 1: - return return_names[0] - return 'std::forward_as_tuple({})'.format(', '.join(return_names)) - - returns = declaration['returns'] - if len(returns) == 1: - return returns[0]['name'] - moved = ['std::move({})'.format(r['name']) for r in returns] - return 'std::make_tuple({})'.format(', '.join(moved)) - - return (declare_returned_variables(), tie_return_values(), get_return_value()) - - -def gen_variable_type(out, aten_declarations, template_path): +def gen_variable_type( + out: str, + native_yaml_path: str, + differentiability_infos: Sequence[DifferentiabilityInfo], + template_path: str, + operator_selector: SelectiveBuilder, +) -> None: """VariableType.h and VariableType.cpp body @@ -620,234 +312,190 @@ def gen_variable_type(out, aten_declarations, template_path): implementation of each function dispatches to the base tensor type to compute the output. The grad_fn is attached to differentiable functions. """ + fns = list(sorted(filter( + operator_selector.is_native_function_selected_for_training, + parse_native_yaml(native_yaml_path)), key=lambda f: cpp.name(f.func))) + fns_with_infos = match_differentiability_info(fns, differentiability_infos) - # WARNING: this function call modifies global mutable state - find_factory_functions(aten_declarations) - - aten_declarations = list(sorted(aten_declarations, key=lambda decl: decl['name'])) - - gen_variable_type_shard(out, aten_declarations, template_path, None, True) + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + gen_variable_type_shard(fm, fns_with_infos, 'VariableType.h', 'VariableType.h') # NOTE: see Note [Sharded File] at the top of the VariableType.cpp # template regarding sharding of the generated files. num_shards = 5 - shards = [[] for _ in range(num_shards)] + shards: List[List[NativeFunctionWithDifferentiabilityInfo]] = [[] for _ in range(num_shards)] # functions are assigned arbitrarily but stably to a file based on hash - for decl in aten_declarations: - x = sum(ord(c) for c in decl['name']) % num_shards - shards[x].append(decl) + for fn in fns_with_infos: + x = sum(ord(c) for c in cpp.name(fn.func.func)) % num_shards + shards[x].append(fn) for i, shard in enumerate(shards): - gen_variable_type_shard(out, shard, template_path, '_%d' % i, False) - gen_variable_type_shard(out, aten_declarations, template_path, 'Everything', False) - - REGISTRATION_DECLARATIONS_H = CodeTemplate.from_file(template_path + "/RegistrationDeclarations.h") - registration_declarations = [] - - for declaration in aten_declarations: - if declaration['use_c10_dispatcher'] == 'full': - declaration_formals = declaration['schema_order_formals'] - else: - assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper' - declaration_formals = declaration['formals'] - if dispatch_strategy(declaration) == 'use_derived': - registration_declarations.append( - REGISTRATION_DECLARATION.substitute(declaration, - declaration_formals=declaration_formals, - compound='False')) - else: - registration_declarations.append( - REGISTRATION_DECLARATION.substitute(declaration, - declaration_formals=declaration_formals, - compound='True')) - - env = { - 'registration_declarations': registration_declarations, - } - write(out, 'RegistrationDeclarations.h', REGISTRATION_DECLARATIONS_H, env) - - -def gen_variable_type_shard(out, aten_declarations, template_path, suffix, header): - VARIABLE_TYPE_H = CodeTemplate.from_file(template_path + '/VariableType.h') - VARIABLE_TYPE_CPP = CodeTemplate.from_file(template_path + '/VariableType.cpp') - TRACE_TYPE_CPP = CodeTemplate.from_file(template_path + '/TraceType.cpp') - - type_declarations = [] - type_definitions = [] - wrapper_registrations = [] - trace_method_definitions = [] - trace_wrapper_registrations = [] - - for declaration in aten_declarations: - formal_types = [arg['type'] for arg in declaration['arguments']] - type_declarations.append(METHOD_DECLARATION.substitute(declaration)) - strategy = dispatch_strategy(declaration) - if declaration['name'] not in MANUAL_AUTOGRAD and strategy == 'use_derived': - body = emit_body(declaration) + gen_variable_type_shard(fm, shard, 'VariableType.cpp', f'VariableType_{i}.cpp') + + gen_variable_type_shard(fm, fns_with_infos, 'VariableType.cpp', 'VariableTypeEverything.cpp') + +@with_native_function +def gen_formals(f: NativeFunction) -> str: + return ', '.join( + f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}' + for a in f.func.schema_order_arguments() + ) + +@with_native_function +def gen_wrapper_registration(f: NativeFunction) -> str: + return WRAPPER_REGISTRATION.substitute( + unqual_operator_name_with_overload=f.func.name, + type_wrapper_name=type_wrapper_name(f), + class_type='VariableType', + ) + +def gen_variable_type_shard( + fm: FileManager, + fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo], + template_name: str, + output_name: str, +) -> None: + type_declarations: List[str] = [] + type_definitions: List[str] = [] + wrapper_registrations: List[str] = [] + + for fn in fns_with_infos: + f = fn.func + name = cpp.name(f.func) + formals = gen_formals(f) + + type_declarations.append(METHOD_DECLARATION.substitute( + return_type=cpp.returns_type(f.func.returns), + type_wrapper_name=type_wrapper_name(f), + formals=formals, + )) + + if name not in MANUAL_AUTOGRAD and dispatch_strategy(fn) == 'use_derived': type_definitions.append(METHOD_DEFINITION.substitute( - declaration, type_definition_body=body)) - if declaration['use_c10_dispatcher'] == 'full': - wrapper_registrations.append(WRAPPER_REGISTRATION.substitute( - declaration, class_type='VariableType')) - else: - assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper' - wrapper_registrations.append(UNBOXEDONLY_WRAPPER_REGISTRATION.substitute( - declaration, class_type='VariableType')) - - # See Note [Manual catchAll kernels] - assert (declaration['name'] in MANUAL_CATCHALL) == declaration['manual_kernel_registration'] - - # Emit TraceType code - if declaration['name'] not in MANUAL_TRACER: - trace_body = emit_trace_body(declaration) - trace_method_definitions.append(METHOD_DEFINITION.substitute( - declaration, type_definition_body=trace_body)) - - if declaration['use_c10_dispatcher'] == 'full': - trace_wrapper_registrations.append(WRAPPER_REGISTRATION.substitute( - declaration, class_type='TraceType')) - else: - trace_wrapper_registrations.append(UNBOXEDONLY_WRAPPER_REGISTRATION.substitute( - declaration, class_type='TraceType')) - - env = { + return_type=cpp.returns_type(f.func.returns), + type_wrapper_name=type_wrapper_name(f), + type_definition_body=emit_body(fn), + formals=formals, + )) + wrapper_registrations.append(gen_wrapper_registration(f)) + + # See Note [Manual Backend kernels] + assert (name in MANUAL_BACKEND) == f.manual_kernel_registration + # If you want to register a kernel to Autograd, you must make the op abstract. + # In other words, this op must have dispatch section in native_functions.yaml. + if name in MANUAL_AUTOGRAD_AND_TRACER or (fn.info and fn.info.has_derivatives): + msg = (f'There\'s a formula for {name}(or its functional variant) in derivatives.yaml. ' + f'It\'s required to add a dispatch section for it with explicit supported backends e.g CPU/CUDA ' + f'or DefaultBackend in native_functions.yaml. Please see ' + f'https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword ' + f'for instructions to choose the right dispatch keyword.') + assert f.is_abstract, msg + + fm.write_with_template(output_name, template_name, lambda: { + 'generated_comment': '@' + f'generated from {fm.template_dir}/{template_name}', 'type_derived_method_declarations': type_declarations, 'type_derived_method_definitions': type_definitions, 'wrapper_registrations': wrapper_registrations, - 'trace_method_definitions': trace_method_definitions, - 'trace_wrapper_registrations': trace_wrapper_registrations, - } - if header: - write(out, 'VariableType.h', VARIABLE_TYPE_H, env) - else: - write(out, 'VariableType%s.cpp' % suffix, VARIABLE_TYPE_CPP, env) - write(out, 'TraceType%s.cpp' % suffix, TRACE_TYPE_CPP, env) - - -def emit_trace_body(declaration): - returns = declaration['returns'] - name = declaration['name'] - inplace = declaration['inplace'] - is_out_fn = name.endswith('_out') - modifies_arguments = inplace or is_out_fn - returns_void = len(returns) == 0 - - trace_body = [] - pre_record_trace, post_record_trace = format_trace(declaration) - declare_returned_variables, tie_return_values, get_return_value = format_return_variables(declaration) - - trace_body.append(pre_record_trace) - trace_body.append(declare_returned_variables) - - arg_types = ', '.join([a['type'] for a in declaration['arguments']]) - ret_and_arg_types = ', '.join([declaration['return_type']] + [a['type'] for a in declaration['arguments']]) - schema_order_arg_types = ', '.join([a['type'] for a in declaration['schema_order_arguments']]) - schema_order_ret_and_arg_types = ', '.join( - [declaration['return_type']] + [a['type'] for a in declaration['schema_order_arguments']]) - - trace_dispatch_args = ['op', 'c10::DispatchKey::Tracer'] + declaration['args'] - schema_order_trace_dispatch_args = ['op', 'c10::DispatchKey::Tracer'] + declaration['schema_order_args'] - assign_return_values = '{} = '.format(tie_return_values) if not modifies_arguments and not returns_void else '' - if declaration['use_c10_dispatcher'] == 'full': - call = TRACE_DISPATCH.substitute( - declaration, - schema_order_arg_types=schema_order_arg_types, - assign_return_values=assign_return_values, - schema_order_ret_and_arg_types=schema_order_ret_and_arg_types, - schema_order_trace_dispatch_args=schema_order_trace_dispatch_args, - ) - else: - assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper' - call = UNBOXED_TRACE_DISPATCH.substitute( - declaration, - arg_types=arg_types, - ret_and_arg_types=ret_and_arg_types, - trace_dispatch_args=trace_dispatch_args, - assign_return_values=assign_return_values, - ) - trace_body.append(call) - trace_body.append(post_record_trace) - if not returns_void: - trace_body.append('return {};'.format(get_return_value)) - return trace_body - - -def emit_body(declaration): - strategy = dispatch_strategy(declaration) - - arguments = declaration['arguments'] - returns = declaration['returns'] - func = declaration['derivative'] - name = declaration['name'] - inplace = declaration['inplace'] - is_out_fn = name.endswith('_out') - modifies_arguments = inplace or is_out_fn - returns_void = len(returns) == 0 - - base_name = name[:-1] if inplace else name[:-4] if is_out_fn else name + }) + +def emit_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]: + assert dispatch_strategy(fn) == 'use_derived' + f = fn.func + info = fn.info + + name = cpp.name(f.func) + inplace = f.func.kind() == SchemaKind.inplace + is_out_fn = f.func.kind() == SchemaKind.out + returns_void = len(f.func.returns) == 0 + base_name = f.func.name.name.base # TODO: should be str(f.func.name.name)? view_info = VIEW_FUNCTIONS.get(base_name, None) if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT: view_info = "self" - def is_differentiable(arg): - if 'TensorOptions' in arg['type']: - return False - if 'Tensor' not in arg['type']: - return False - if arg['name'] in declaration.get('non_differentiable_arg_names', []): - return False - return True + def is_differentiable(name: str, type: Type) -> bool: + return type.is_tensor_like() and (info is None or name not in info.non_differentiable_arg_names) + + def gen_differentiable_input( + arg: Union[Argument, SelfArgument, TensorOptionsArguments] + ) -> Optional[DifferentiableInput]: + if isinstance(arg, TensorOptionsArguments): + return None + a: Argument = arg.argument if isinstance(arg, SelfArgument) else arg + + # TODO: `cpp_type` is only to keep it byte-for-byte compatible with the old codegen, should remove. + # NB: This is not a clone of cpp.argument() - TensorOptionsArguments / faithful / binds are + # not handled properly as they are irrelevant for this codegen. + cpp_type = cpp.argument_type(a, binds=a.name).cpp_type() + + if not is_differentiable(a.name, a.type): + return None + return DifferentiableInput( + name=a.name, + type=a.type, + cpp_type=cpp_type, + ) + + @with_native_function + def gen_differentiable_inputs(f: NativeFunction) -> List[DifferentiableInput]: + return list(mapMaybe(gen_differentiable_input, f.func.arguments.non_out)) - def find_args_with_derivatives(differentiable_inputs): + def find_args_with_derivatives(differentiable_inputs: List[DifferentiableInput]) -> List[DifferentiableInput]: """Find arguments that have derivative definitions""" - if func is None: + if info is None or not info.has_derivatives: return differentiable_inputs - names = set(name for d in func['derivatives'] for name in d['var_names']) - differentiable = [arg for arg in differentiable_inputs if arg['name'] in names] + names = set(name for d in info.derivatives for name in d.var_names) + differentiable = [arg for arg in differentiable_inputs if arg.name in names] if len(differentiable) != len(names): - missing = names - set(arg['name'] for arg in differentiable) - raise RuntimeError('Missing arguments for derivatives: {} in {}'.format(missing, func['name'])) + missing = names - set(arg.name for arg in differentiable) + raise RuntimeError(f'Missing arguments for derivatives: {missing} in {info.name}') return differentiable - inputs = [arg for arg in arguments if not arg.get('output', False)] - differentiable_inputs = list(filter(is_differentiable, inputs)) + def gen_differentiable_outputs(f: NativeFunction) -> List[DifferentiableOutput]: + outputs: List[DifferentiableOutput] = [ + DifferentiableOutput(name=name, type=ret.type, cpp_type=cpp.return_type(ret)) + for name, ret in zip(cpp.return_names(f), f.func.returns)] + + output_differentiability = info.output_differentiability if info else None + if output_differentiability is not None: + differentiable_outputs: List[DifferentiableOutput] = [] + if False in output_differentiability and f.func.kind() == SchemaKind.inplace: + raise RuntimeError("output_differentiability=False for inplace operation (version_counter won't get updated)") + for differentiable, output in zip(output_differentiability, outputs): + if differentiable: + differentiable_outputs.append(output) + return differentiable_outputs + + candidate_differentiable_outputs = list(filter(lambda r: is_differentiable(r.name, r.type), outputs)) + + if uses_single_grad(info): + return candidate_differentiable_outputs[:1] + else: + return candidate_differentiable_outputs + + differentiable_inputs = gen_differentiable_inputs(f) args_with_derivatives = find_args_with_derivatives(differentiable_inputs) - non_differentiable_arg_names = declaration.get('non_differentiable_arg_names', []) - candidate_differentiable_outputs = list(filter(is_differentiable, returns)) - - if declaration['output_differentiability'] is not None: - differentiable_outputs = [] - output_differentiability = declaration['output_differentiability'] - if False in output_differentiability and inplace: - raise RuntimeError("output_differentiability=False for inplace operation (version_counter won't get updated)") - for differentiable, output in zip(output_differentiability, returns): - if differentiable: - differentiable_outputs.append(output) - elif uses_single_grad(func): - differentiable_outputs = candidate_differentiable_outputs[:1] - else: - differentiable_outputs = candidate_differentiable_outputs + differentiable_outputs = gen_differentiable_outputs(f) requires_derivative = ( base_name not in DONT_REQUIRE_DERIVATIVE and name not in DONT_REQUIRE_DERIVATIVE and - len(differentiable_inputs) > 0 and len(differentiable_outputs) > 0 and - strategy == 'use_derived') + len(differentiable_inputs) > 0 and len(differentiable_outputs) > 0) - if func is not None and not requires_derivative: - raise RuntimeError('ERROR: derivative ignored for {} -- specified an autograd function without derivative' - .format(name)) + if info is not None and info.has_derivatives and not requires_derivative: + raise RuntimeError(f'ERROR: derivative ignored for {name} -- specified an autograd function without derivative') - def emit_save_inputs(): - setup = [] - if func is None: + def emit_save_inputs() -> List[str]: + setup: List[str] = [] + if info is None or not info.has_derivatives: return setup - has_tensorlist_arg = any(arg['type'] == 'TensorList' for arg in func['args_with_derivatives']) + has_tensorlist_arg = any(is_tensor_list_type(arg.type) for arg in args_with_derivatives) # We don't want to save tensors if we know that they will never be used # when computing the derivative, so we add guards to those statements - def guard_for(arg): + def guard_for(arg: SavedAttribute) -> Optional[str]: + assert info is not None + # It's hard to determine the edge offset if we have TensorLists if has_tensorlist_arg: return None @@ -858,156 +506,173 @@ def guard_for(arg): # require_grad if the backward function even gets executed. I don't # have any good ideas for detecting those cases, so I simply disabled the # checks. - if 'backward' in func['name']: + if 'backward' in info.name: return None # If there's a single derivative we could compute, we already have # a requires_grad check that is sufficient - if len(func['args_with_derivatives']) <= 1: + if len(args_with_derivatives) <= 1: return None # We really only care about trimming down the amount of tensors we save - if arg['type'] != 'Tensor': + if arg.type != 'Tensor': return None # We want to emit simple guards, so we only allow that if checking one # input is enough to determine whether we need that value - used_in = [d for d in func['derivatives'] if arg in d['saved_inputs']] + used_in = [d for d in info.derivatives if arg in d.saved_inputs] assert len(used_in) > 0 if len(used_in) != 1: return None derivative = used_in[0] - if len(derivative['var_names']) != 1: + if len(derivative.var_names) != 1: return None - derivative_var_name = derivative['var_names'][0] + derivative_var_name = derivative.var_names[0] # Figure out the offset of the edge that uses this variable - for edge_off, arg in enumerate(func['args_with_derivatives']): - if arg['name'] == derivative_var_name: + for edge_off, a in enumerate(args_with_derivatives): + if a.name == derivative_var_name: break else: raise AssertionError() - return 'grad_fn->should_compute_output({})'.format(edge_off) + return f'grad_fn->should_compute_output({edge_off})' - setup.extend(save_variables(func['saved_inputs'], False, guard_for)) - for arg in func['args_with_derivatives']: - if arg['type'] == 'TensorList': - setup.append("grad_fn->{}_size_ = {}.size();".format(arg['name'], arg['name'])) + setup.extend(save_variables(info.all_saved_inputs, False, guard_for)) + for arg in args_with_derivatives: + if is_tensor_list_type(arg.type): + setup.append(f'grad_fn->{arg.name}_size_ = {arg.name}.size();') return setup - def setup_derivative(differentiable_inputs): - - env = {} - env['args_with_derivatives'] = [arg['name'] for arg in args_with_derivatives] - env['op'] = func['op'] if func is not None else 'NotImplemented' - env['op_ctor'] = '' if func is not None else '"{}"'.format(declaration['api_name']) - + def setup_derivative(differentiable_inputs: List[DifferentiableInput]) -> List[str]: + body: List[str] = [] if is_out_fn: - setup = ['throw_error_out_requires_grad("{}");'.format(base_name)] - body = [] + # For out functions, ensure that no input or output requires grad body.append(DECLARE_GRAD_FN.substitute(op='Node')) - body.append(SETUP_DERIVATIVE.substitute( - setup=setup, - args_with_derivatives=[arg['name'] for arg in differentiable_inputs])) - body.append(SETUP_DERIVATIVE.substitute( - setup=setup, - args_with_derivatives=[arg['name'] for arg in differentiable_outputs])) + body.append(SETUP_NONE_REQUIRES_GRAD.substitute( + base_name=base_name, + args_to_check=[arg.name for arg in differentiable_inputs])) + body.append(SETUP_NONE_REQUIRES_GRAD.substitute( + base_name=base_name, + args_to_check=[arg.name for arg in differentiable_outputs])) return body + op = info.op if info is not None and info.has_derivatives else 'NotImplemented' setup = [] - setup.extend(ASSIGN_GRAD_FN.substitute(env).split('\n')) + setup.extend(ASSIGN_GRAD_FN.substitute( + op=op, + op_ctor='' if info is not None and info.has_derivatives else f'"{cpp.name(f.func)}"', + args_with_derivatives=[arg.name for arg in args_with_derivatives], + ).split('\n')) setup.extend(emit_save_inputs()) - body = [] body.extend(emit_check_no_requires_grad(differentiable_inputs, args_with_derivatives)) - body.append(DECLARE_GRAD_FN.substitute(env)) - body.append(SETUP_DERIVATIVE.substitute(env, setup=setup)) + body.append(DECLARE_GRAD_FN.substitute(op=op)) + body.append(SETUP_DERIVATIVE.substitute(setup=setup)) + return body + + def emit_check_if_in_complex_autograd_allowlist() -> List[str]: + body: List[str] = [] + if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX: + return body + for arg in differentiable_outputs: + name = arg.name + # TODO: should be `arg.type.is_tensor_like()`? + if arg.cpp_type in ['Tensor', 'TensorList', 'const c10::List> &']: + body.append(f'throw_error_for_complex_autograd({name}, "{base_name}");') return body - def emit_check_no_requires_grad(tensor_args, args_with_derivatives): + def emit_check_no_requires_grad( + tensor_args: List[DifferentiableInput], + args_with_derivatives: List[DifferentiableInput], + ) -> List[str]: """Checks that arguments without derivatives don't require grad""" - body = [] + body: List[str] = [] for arg in tensor_args: if arg in args_with_derivatives: continue - name = arg['name'] - if name in non_differentiable_arg_names: + name = arg.name + if info and name in info.non_differentiable_arg_names: continue if name == 'output': # Double-backwards definitions sometimes take in 'input' and # 'output', but only define the derivative for input. continue - if arg['dynamic_type'] in {'IndexTensor', 'ByteTensor', 'BoolTensor'}: - continue - body.append('check_no_requires_grad({}, "{}");'.format(name, name)) + body.append(f'check_no_requires_grad({name}, "{name}");') return body - def save_variables(saved_variables, is_output, guard_for=lambda name: None): + def save_variables( + saved_variables: Sequence[SavedAttribute], + is_output: bool, + guard_for: Callable[[SavedAttribute], Optional[str]] = lambda name: None, + ) -> Sequence[str]: # assign the saved variables to the generated grad_fn - stmts = [] + stmts: List[str] = [] for arg in saved_variables: - name = arg['name'] - expr = arg.get('expr', arg['name']) - if arg['type'] == 'Tensor' or arg['type'] == 'c10::optional' or \ - arg['type'] == 'c10::optional&' or (is_output and arg['type'] == 'Scalar'): + name = arg.name + expr = arg.expr + if arg.type == 'Tensor' or arg.type == 'c10::optional' or \ + arg.type == 'c10::optional&' or (is_output and arg.type == 'Scalar'): name += '_' - var = arg['name'] + var = arg.name if var == 'self' and inplace: var = 'self.clone()' assert not is_output if inplace and is_output: var = 'self' - is_inplace_view = "{}.is_view()".format(var) - expr = 'SavedVariable({}, {}, {})'.format(var, str(is_output).lower(), is_inplace_view) + is_inplace_view = f'{var}.is_view()' + expr = f'SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})' else: - expr = 'SavedVariable({}, {})'.format(var, str(is_output).lower()) - elif arg['type'] == 'TensorList': + expr = f'SavedVariable({var}, {str(is_output).lower()})' + elif arg.type in ['TensorList', 'c10::List>']: name += '_' - expr = 'make_saved_variable_list({})'.format(arg['name']) - elif arg['type'] == 'IntArrayRef': + expr = f'make_saved_variable_list({arg.name})' + elif arg.type == 'IntArrayRef': expr = expr + ".vec()" guard = guard_for(arg) if guard is None: - stmts.append('grad_fn->{} = {};'.format(name, expr)) + stmts.append(f'grad_fn->{name} = {expr};') else: - stmts.append('if ({}) {{'.format(guard)) - stmts.append(' grad_fn->{} = {};'.format(name, expr)) + stmts.append(f'if ({guard}) {{') + stmts.append(f' grad_fn->{name} = {expr};') stmts.append('}') return stmts - def emit_dispatch_call(api_name, input_base, unpacked_args): + def emit_dispatch_call(f: NativeFunction, input_base: str, unpacked_args: Sequence[str]) -> str: """ Dispatch call via function in a namespace or method on Tensor.""" - if 'namespace' in declaration['method_of']: + if Variant.function in f.variants: call = CALL_DISPATCH_VIA_NAMESPACE.substitute( - api_name=api_name, + api_name=cpp.name( + f.func, + faithful_name_for_out_overloads=True, + ), unpacked_args=unpacked_args) else: call = CALL_DISPATCH_VIA_METHOD.substitute( - api_name=api_name, + api_name=cpp.name(f.func), var=input_base, unpacked_method_args=unpacked_args[1:]) return call - def emit_view_lambda(): + def emit_view_lambda(unpacked_bindings: List[Binding]) -> str: """ Generate an additional lambda function to recover views in backward when as_strided is not supported. See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details.""" input_base = 'input_base' replay_view_func = '' - updated_unpacked_args = [] - combined = nested_dict(env, declaration) - known_view_arg_simple_types = ['int64_t', 'int64_t?', 'bool', 'IntArrayRef'] - for arg in combined['unpacked_args']: + updated_unpacked_args: List[str] = [] + known_view_arg_simple_types: List[str] = ['int64_t', 'c10::optional', 'bool', 'IntArrayRef'] + for unpacked_binding in unpacked_bindings: + arg, arg_type = unpacked_binding.name, unpacked_binding.type if arg == 'self_': updated_unpacked_args.append(input_base) continue - arg_type = combined['unpacked_args_simple_type'][arg] if arg_type not in known_view_arg_simple_types: - raise TypeError('You are adding an {} {} argument to op {} in addition to known types: {}. ' - 'Please update the list or materialize it so that it can be closed over by value, ' - 'also add a test in pytorch/xla/test/test_operations.py where this code is exercised.' - .format(arg_type, arg, declaration['name'], ', '.join(known_view_arg_simple_types))) + known_types_str = ', '.join(known_view_arg_simple_types) + raise TypeError(f'You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to known types: ' + f'{known_types_str}. Please update the list or materialize it so that it can be closed ' + 'over by value, also add a test in pytorch/xla/test/test_operations.py where this code ' + 'is exercised.') if arg_type == 'IntArrayRef': # It's not safe to close over IntArrayRef by value, since this is a @@ -1015,7 +680,7 @@ def emit_view_lambda(): arg_vec = arg + '_vec' replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg, vec=arg_vec) updated_unpacked_args.append(arg_vec) - elif arg_type == 'int64_t?': + elif arg_type == 'c10::optional': # Materialize int64_t? to int64_t arg_value = arg + '_val' replay_view_func += OPTIONAL_TO_VAL.substitute(arg=arg, val=arg_value, default='0') @@ -1023,7 +688,7 @@ def emit_view_lambda(): else: updated_unpacked_args.append(arg) - replay_view_call = emit_dispatch_call(combined['api_name'], input_base, updated_unpacked_args) + replay_view_call = emit_dispatch_call(f, input_base, updated_unpacked_args) replay_view_func += REPLAY_VIEW_LAMBDA_FUNC.substitute( input_base=input_base, replay_view_call=replay_view_call) @@ -1034,65 +699,74 @@ def emit_view_lambda(): is_view_with_metadata_change=is_view_with_metadata_change, replay_view_func=replay_view_func) - def wrap_output(return_values, var): + def wrap_output(f: NativeFunction, unpacked_bindings: List[Binding], var: str) -> str: call = '' - rhs_value = None - if 'Tensor' not in declaration['return_type']: + rhs_value: Optional[str] = None + if not any(r.type.is_tensor_like() for r in f.func.returns): rhs_value = var elif view_info is not None: # See NOTE [ Autograd View Variables ] in variable.h for details. - differentiable_output_vars = {r['name'] for r in differentiable_outputs} + differentiable_output_vars = {r.name for r in differentiable_outputs} if not isinstance(view_info, str): - raise TypeError("The view info should be a string for {}, but it is: {}".format(base_name, view_info)) + raise TypeError(f'The view info should be a string for {base_name}, but it is: {view_info}') if len(differentiable_output_vars) == 0: # no output is differentiable (.indices() for SparseTensors for example) - rhs_value = 'as_view({}, {}, /* is_differentiable */ false)'.format(view_info, var) + rhs_value = f'as_view({view_info}, {var}, /* is_bw_differentiable */ false, /* is_fw_differentiable */ false)' elif len(differentiable_output_vars) == 1: # Single differentiable output (Tensor or Tensor[]) return_info = differentiable_outputs[0] # We only support simple Tensor or a TensorList for functions that return views - if not return_info['dynamic_type'] in ['Tensor', 'TensorList']: - raise RuntimeError("{} that return differentiable views can only return Tensor or Tensor[]".format(base_name)) + if not is_tensor_type(return_info.type) and not is_tensor_list_type(return_info.type): + raise RuntimeError(f'{base_name} that return differentiable views can only return Tensor or Tensor[]') # Only allow rebasing of the history if we return a single Tensor # If we are in a no grad block, raise a warning # See NOTE [ View + Inplace detection ] for more details about this logic - if return_info['dynamic_type'] == 'TensorList': + if is_tensor_list_type(return_info.type): if base_name in MULTI_OUTPUT_SAFE_FUNCTIONS: - creation_meta = "CreationMeta::MULTI_OUTPUT_SAFE" + creation_meta = 'CreationMeta::MULTI_OUTPUT_SAFE' else: - creation_meta = "CreationMeta::MULTI_OUTPUT_NODE" - rhs_value = ("as_view(/* base */ {}, /* output */ {}, /* is_differentiable */ true, " - "/* creation_meta */ {})").format(view_info, var, creation_meta) + creation_meta = 'CreationMeta::MULTI_OUTPUT_NODE' + call += (f'as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, ' + '/* is_fw_differentiable */ true, ' + f'/* creation_meta */ {creation_meta});') + rhs_value = f'std::move({var})' else: - call += emit_view_lambda() - creation_meta = "GradMode::is_enabled() ? CreationMeta::DEFAULT: CreationMeta::NO_GRAD_MODE" - rhs_value = ("as_view(/* base */ {}, /* output */ {}, /* is_differentiable */ true, " - "/* view_func */ func, /* creation_meta */ {})").format(view_info, var, creation_meta) + call += emit_view_lambda(unpacked_bindings) + creation_meta = 'GradMode::is_enabled() ? CreationMeta::DEFAULT: CreationMeta::NO_GRAD_MODE' + rhs_value = (f'as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, ' + '/* is_fw_differentiable */ true, ' + f'/* view_func */ func, /* creation_meta */ {creation_meta})') else: # This could be supported but we don't need it at the moment, so keeping things simple. - raise RuntimeError("Function that return multiple differentiable output " - "when at least one of them is view is not supported.") + raise RuntimeError('Function that return multiple differentiable output ' + 'when at least one of them is view is not supported.') else: - rhs_value = 'std::move({})'.format(var) + rhs_value = f'std::move({var})' assert rhs_value is not None - call += ASSIGN_RETURN_VALUE.substitute(return_values=return_values, + call += ASSIGN_RETURN_VALUE.substitute(return_values=tie_return_values(f), rhs_value=rhs_value) return call - def enforce_same_tensorimpl_and_storage(env, call): - save_ptrs_stmts = [] - enforce_same_ptrs_stmts = [] - if declaration['name'] not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE: - for arg in env.get('unpacked_args', []): - simple_type = env['unpacked_args_simple_type'][arg] - if simple_type == 'TensorList': + def enforce_same_tensorimpl_and_storage(call: str, unpacked_bindings: List[Binding]) -> str: + save_ptrs_stmts: List[str] = [] + enforce_same_ptrs_stmts: List[str] = [] + if cpp.name(f.func) not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE: + for unpacked_binding in unpacked_bindings: + arg = unpacked_binding.name + noref_cpp_type = unpacked_binding.ctype.cpp_type(strip_ref=True) + if noref_cpp_type == 'TensorList': save_ptrs_stmts += [SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg), SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)] enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg), ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)] - elif simple_type == 'Tensor': + elif noref_cpp_type == 'c10::List>': + save_ptrs_stmts += [SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg), + SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)] + enforce_same_ptrs_stmts += [ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg), + ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)] + elif noref_cpp_type == 'Tensor': save_ptrs_stmts += [SAVE_TENSOR_STORAGE.substitute(tensor_name=arg), SAVE_TENSOR_IMPL.substitute(tensor_name=arg)] enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=arg), @@ -1104,85 +778,75 @@ def enforce_same_tensorimpl_and_storage(env, call): RUN_ONLY_IN_DEBUG_MODE.substitute(statements=enforce_same_ptrs_stmts) return call - def emit_call(env, tie_return_values): - combined = nested_dict(env, declaration) - if strategy == 'use_derived': - # We only care about adding `at::AutoNonVariableTypeMode` guard for non-variable dispatch - # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure - # the baseType operations still dispatch to non-Variable type, even if the arguments passed - # in are now Variables. - # See NOTE [ Treating Variables as non-Variables in type dispatch ] for details. - base_type_call = emit_dispatch_call(combined['api_name'], 'self_', combined['unpacked_args']) - if not modifies_arguments and not returns_void: - call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute( - base_type_call=base_type_call) - - call += wrap_output(tie_return_values, 'tmp') - else: - call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute( - base_type_call=base_type_call) + def emit_call(f: NativeFunction, unpacked_bindings: List[Binding]) -> str: + # We only care about adding `at::AutoNonVariableTypeMode` guard for non-variable dispatch + # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure + # the baseType operations still dispatch to non-Variable type, even if the arguments passed + # in are now Variables. + # See NOTE [ Treating Variables as non-Variables in type dispatch ] for details. + unpacked_args = [b.name for b in unpacked_bindings] + base_type_call = emit_dispatch_call(f, 'self_', unpacked_args) + if not modifies_arguments(f) and not returns_void: + call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute( + base_type_call=base_type_call) + + call += wrap_output(f, unpacked_bindings, 'tmp') else: - args = maybe_unwrap_optional_tensors(declaration, declaration['arguments'], declaration['args']) - - call = CALL_DEFAULT.substitute(declaration, args=args) - if not modifies_arguments and not returns_void: - call = '{} = {}'.format(tie_return_values, call) - call = call + ';' - call = enforce_same_tensorimpl_and_storage(env, call) + call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute( + base_type_call=base_type_call) + call = enforce_same_tensorimpl_and_storage(call, unpacked_bindings) return call - def emit_history(): - fn = 'rebase' if modifies_arguments and view_info is None else 'set' - output_names = [r['name'] for r in differentiable_outputs] + def emit_history() -> str: + fn = 'rebase' if modifies_arguments(f) and view_info is None else 'set' + output_names = [r.name for r in differentiable_outputs] # TODO: flatten allocates a std::vector, which could be expensive outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute(outs=output_names) return SET_HISTORY.substitute(fn=fn, differentiable_outputs=outs) - def emit_save_outputs(): + def emit_save_outputs() -> str: if is_out_fn: # out functions don't currently support differentiation return '' - func = declaration['derivative'] - if func is not None: - stmts = save_variables(func['saved_outputs'], True) + if info is not None and info.has_derivatives: + stmts = save_variables(info.all_saved_outputs, True) if len(stmts) == 0: return '' return CONDITIONAL.substitute(cond='grad_fn', statements=stmts) return '' - def emit_check_inplace(): + def emit_any_requires_grad() -> List[str]: + return [SETUP_ANY_REQUIRES_GRAD.substitute( + args_with_derivatives=[arg.name for arg in args_with_derivatives]), ] + + def emit_check_inplace() -> List[str]: if not inplace: return [] - return ['check_inplace({});'.format(arg['name']) for arg in differentiable_outputs] + return [f'check_inplace({arg.name}, _any_requires_grad);' for arg in differentiable_outputs] - def emit_increment_version(): - if not modifies_arguments: + def emit_increment_version(f: NativeFunction) -> List[str]: + if not modifies_arguments(f): return [] - return ['increment_version({});'.format(arg['name']) for arg in returns] + return [f'increment_version({r});' for r in cpp.return_names(f)] - env = {} - combined = nested_dict(env, declaration) + body: List[str] = [] + unpack_args_stats, unpacked_bindings = unpack_args(f) - body = [] - - declare_returned_variables, tie_return_values, get_return_value = format_return_variables(declaration) - - if strategy != 'use_type': - body.extend(unpack_args(env, declaration)) + body.extend(unpack_args_stats) if requires_derivative: + body.extend(emit_any_requires_grad()) body.extend(emit_check_inplace()) body.extend(setup_derivative(differentiable_inputs)) - body.append(declare_returned_variables) + body.append(declare_returned_variables(f)) - body.append(emit_call(env, tie_return_values)) - if strategy == 'use_derived': - body.extend(emit_increment_version()) + body.append(emit_call(f, unpacked_bindings)) + body.extend(emit_increment_version(f)) if requires_derivative: # set_flags has to appear after version_counter, because rebase_history # requires that the counter is incremented before it is called body.append(emit_history()) - if requires_derivative: body.append(emit_save_outputs()) + body.extend(emit_check_if_in_complex_autograd_allowlist()) if base_name in RESET_GRAD_ACCUMULATOR: # `inplace` implies that there is exactly one output named `self`, # so we can keep the generated code easy. If you need to @@ -1191,50 +855,50 @@ def emit_increment_version(): assert inplace body.append('reset_grad_accumulator(self);') if not returns_void: - body.append('return {};'.format(get_return_value)) + body.append(f'return {get_return_value(f)};') return body - -def unpack_args(env, declaration): - def requires_unpack(arg): - return 'Tensor' in arg['dynamic_type'] - - body = [] - unpacked_args = [] - unpacked_args_simple_type = {} - for i, arg in enumerate(declaration['arguments']): - if not requires_unpack(arg): - unpacked_args.append(arg['name']) - unpacked_args_simple_type[arg['name']] = arg['simple_type'] +@with_native_function +def unpack_args(f: NativeFunction) -> Tuple[List[str], List[Binding]]: + body: List[str] = [] + unpacked_bindings: List[Binding] = [] + + bindings = [r for a in f.func.schema_order_arguments() + for r in cpp.argument(a, + method=False, + cpp_no_default_args=set(), + faithful=False, + has_tensor_options=False)] + + for i, binding in enumerate(bindings): + assert not isinstance(binding.argument, SelfArgument) + if isinstance(binding.argument, TensorOptionsArguments): + raise RuntimeError("VariableKernel shouldn't take TensorOptions") + + is_nullable = binding.argument.type.is_nullable() + if not binding.argument.type.is_tensor_like() or is_nullable: + unpacked_bindings.append(binding) continue - dynamic_type = arg['dynamic_type'] - if 'TensorOptions' not in dynamic_type: - is_nullable = arg.get('is_nullable', False) - ref = (not is_nullable) and dynamic_type not in ['TensorList'] - suffix = '_opt' if is_nullable and dynamic_type != 'TensorList' else '' - - body.append(UNPACK_TENSOR.substitute( - arg_name=arg['name'], - arg_pos=i, - suffix=suffix, - ref='&' if ref else '', - )) - else: - # Okay, we are abusing the definition of 'unpack' here a bit, - # although it's still getting the non-variable from the variable - # (in this case via TensorOptions rather than Variable/Tensor). - body.append(UNPACK_OPTIONS.substitute(arg_name=arg['name'])) - - unpacked_args.append(arg['name'] + '_') - unpacked_args_simple_type[arg['name'] + '_'] = arg['simple_type'] - - env['unpacked_args'] = unpacked_args - env['unpacked_args_simple_type'] = unpacked_args_simple_type - return body - - -def dispatch_strategy(declaration): + is_tensor_list = is_tensor_list_type(binding.argument.type) + ref = (not is_nullable) and not is_tensor_list + suffix = '_opt' if is_nullable and not is_tensor_list else '' + body.append(UNPACK_TENSOR.substitute( + arg_name=binding.name, + arg_pos=i, + suffix=suffix, + ref='&' if ref else '', + )) + unpacked_bindings.append(Binding( + name=binding.name + '_', + ctype=binding.ctype, + argument=binding.argument, + default=binding.default, + )) + + return body, unpacked_bindings + +def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str: """How are we going to call the underlying implementation of a declaration? There are two strategies: @@ -1254,7 +918,7 @@ def dispatch_strategy(declaration): get dispatched back to VariableType (which will ensure that they are differentiable.) """ - if declaration['abstract'] or declaration['derivative'] is not None: + if fn.func.is_abstract or (fn.info is not None and fn.info.has_derivatives): # If the function is abstract (not implemented on at::Type), we must # call the implementation on the derived type with unpacked tensors. @@ -1277,3 +941,48 @@ def dispatch_strategy(declaration): # actually implemented out of differentiable functions. (This # assumption might not hold, but then you'll see gradcheck fail.) return 'use_type' + +def is_tensor_type(t: Type) -> bool: + # TODO: Should handle optional here? + return t.is_tensor_like() and t.is_list_like() is None + +def is_tensor_list_type(t: Type) -> bool: + # TODO: Should handle optional here? + return t.is_tensor_like() and t.is_list_like() is not None + +def modifies_arguments(f: NativeFunction) -> bool: + return f.func.kind() in [SchemaKind.inplace, SchemaKind.out] + +def match_differentiability_info( + native_functions: List[NativeFunction], + differentiability_infos: Sequence[DifferentiabilityInfo], +) -> List[NativeFunctionWithDifferentiabilityInfo]: + """Sets the "derivative" key on declarations to matching autograd function + + In-place functions will use the out-of-place derivative definition if there + is no in-place specific derivative. + """ + + info_by_schema = {info.func.func: info for info in differentiability_infos} + functional_info_by_signature = { + info.func.func.signature(strip_default=True): info + for info in differentiability_infos + if info.func.func.kind() == SchemaKind.functional} + + def find_info(f: NativeFunction) -> Tuple[Optional[DifferentiabilityInfo], bool]: + if f.func in info_by_schema: + return info_by_schema[f.func], True + + # if there is no exact match look for the out-of-place signature. + # i.e mul() for mul_() or mul_out() + return functional_info_by_signature.get(f.func.signature(strip_default=True)), False + + result: List[NativeFunctionWithDifferentiabilityInfo] = [] + for f in native_functions: + info, is_exact_match = find_info(f) + result.append(NativeFunctionWithDifferentiabilityInfo( + func=f, + info=info, + )) + + return result diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py index 1e1c7c662227d..d5c742bb6fa5b 100644 --- a/tools/autograd/load_derivatives.py +++ b/tools/autograd/load_derivatives.py @@ -1,115 +1,115 @@ # Parses derivatives.yaml into autograd functions # -# Each autograd function is represented by dictionary containing a list of -# derivatives (also a dictionary). See `create_autograd_function` and -# `create_derivative` for the keys. -from collections import defaultdict -import copy +# Each autograd function is represented by `DifferentiabilityInfo` containing +# a list of `Derivative`. See `tools.codegen.api.autograd` for the data models. +from collections import defaultdict, Counter import re +from typing import Sequence, Any, Tuple, List, Set, Dict, Match, Optional import yaml -from .utils import YamlLoader -from .utils import IDENT_REGEX, split_name_params - -def load_derivatives(path, declarations): - with open(path, 'r') as f: - definitions = yaml.load(f, Loader=YamlLoader) - - declarations_by_signature = defaultdict(list) - declarations_by_schema = dict() - for declaration in declarations: - declarations_by_signature[get_signature(declaration)].append(declaration) - if declaration['schema_string']: - assert declaration['schema_string'] not in declarations_by_schema - declarations_by_schema[declaration['schema_string']] = declaration - - differentiability_infos = [ - process_definition(defn, declarations_by_signature, declarations_by_schema) +from tools.codegen.api.autograd import * +from tools.codegen.api.types import * +import tools.codegen.api.cpp as cpp +from tools.codegen.gen import parse_native_yaml, with_native_function +from tools.codegen.model import * +from tools.codegen.utils import * + +try: + # use faster C loader if available + from yaml import CLoader as Loader +except ImportError: + from yaml import Loader # type: ignore + +def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Sequence[DifferentiabilityInfo]: + with open(derivatives_yaml_path, 'r') as f: + definitions = yaml.load(f, Loader=Loader) + + functions = parse_native_yaml(native_yaml_path) + + # What's the difference between function schema v.s. signature? + # function schema is the complete declaration including mutability annotation / default value and etc. + # signature is the canonical schema for a group of functions (in-place/out/functional variants) + # that are semantically related. + functions_by_signature: Dict[FunctionSchema, List[NativeFunction]] = defaultdict(list) + functions_by_schema: Dict[str, NativeFunction] = dict() + for function in functions: + functions_by_signature[function.func.signature()].append(function) + assert str(function.func) not in functions_by_schema + functions_by_schema[str(function.func)] = function + + infos = [ + create_differentiability_info(defn, functions_by_signature, functions_by_schema) for defn in definitions] - autograd_functions = [d['autograd_fn'] for d in differentiability_infos if d['autograd_fn'] is not None] - ensure_unique_names(autograd_functions) - match_declarations_with_differentiability_info(declarations, differentiability_infos) - - return autograd_functions - - -def create_differentiability_info(signature, non_differentiable_arg_names, - output_differentiability, - autograd_fn): - return { - 'signature': signature, - 'non_differentiable_arg_names': non_differentiable_arg_names, - 'output_differentiability': output_differentiability, - 'autograd_fn': autograd_fn, - } - - -# How do you feel about pasting declaration inside autograd function... -def create_autograd_function(name, derivatives, args_with_derivatives, - declaration): - op = to_camel_case(name) + 'Backward' - op = op.replace('ForwardBackward', 'Backward') - return { - 'name': name, - 'op': op, - 'declaration': declaration, - 'args_with_derivatives': args_with_derivatives, - 'derivatives': derivatives, - 'saved_inputs': all_saved_variables(derivatives, 'saved_inputs'), - 'saved_outputs': all_saved_variables(derivatives, 'saved_outputs'), - } - - -def create_derivative(arguments, returns, name, formula, var_names): - def transform_return(r): - # In-place functions take in and return self. Call the modified version - # "output" so that it can be referred to in derivative definitions. - if r['name'] == 'self': - r = copy.deepcopy(r) - r['name'] = 'result' - return r - - returns = [transform_return(r) for r in returns] - formula, saved_inputs = saved_variables(formula, arguments) - formula, saved_outputs = saved_variables(formula, returns) + # To keep it byte-for-byte compatible with the old codegen, we assign op names as a separate + # step. We only assign op names to those with differentiable args, and only append suffix to + # duplicated op names. This can be simplified if the first of the duplicates can be named + # 'XyzBackward' instead of 'XyzBackward0' or unconditionally append '0' to singletons. + op_names = create_op_names(infos) + return [ + DifferentiabilityInfo( + name=info.name, + func=info.func, + op=op_name, + derivatives=info.derivatives, + all_saved_inputs=info.all_saved_inputs, + all_saved_outputs=info.all_saved_outputs, + args_with_derivatives=info.args_with_derivatives, + non_differentiable_arg_names=info.non_differentiable_arg_names, + output_differentiability=info.output_differentiability, + ) + for info, op_name in zip(infos, op_names)] + +@with_native_function +def cpp_arguments(f: NativeFunction) -> Sequence[Binding]: + return CppSignatureGroup.from_native_function(f, method=False).signature.arguments() + +def create_derivative(f: NativeFunction, formula: str, var_names: Tuple[str, ...]) -> Derivative: + arguments = cpp_arguments(f) + argument_names = tuple(a.name for a in arguments) + argument_types = tuple(a.type for a in arguments) + + return_names = tuple(n if n != 'self' else 'result' for n in cpp.return_names(f)) + return_types = tuple(cpp.return_type(r) for r in f.func.returns) + + formula, saved_inputs = saved_variables(formula, argument_names, argument_types, var_names) + formula, saved_outputs = saved_variables(formula, return_names, return_types, var_names) # Check that the referenced derivatives in the formula are in bounds for i in used_gradient_indices(formula): - if i >= len(returns): + if i >= len(f.func.returns): raise RuntimeError( - "Out of bounds grads access: derivative formula for {} " - "used grads[{}], but the forward only returns {} outputs." - .format(name, i, len(returns))) - - return { - 'formula': formula, - 'saved_inputs': saved_inputs, - 'saved_outputs': saved_outputs, - 'var_names': var_names, - } - - -def process_definition(defn, declarations_by_signature, declarations_by_schema): + f'Out of bounds grads access: derivative formula for {cpp.name(f.func)} ' + f'used grads[{i}], but the forward only returns {len(f.func.returns)} outputs.' + ) + + return Derivative( + formula=formula, + var_names=var_names, + saved_inputs=saved_inputs, + saved_outputs=saved_outputs, + ) + +def create_differentiability_info( + defn: Dict[Any, Any], + functions_by_signature: Dict[FunctionSchema, List[NativeFunction]], + functions_by_schema: Dict[str, NativeFunction], +) -> DifferentiabilityInfo: """Processes a single entry `defn` in derivatives.yaml""" - def canonical_declaration(declarations, name): - for declaration in declarations: - if declaration['name'] == name: - return declaration + def canonical_function(functions: Sequence[NativeFunction], name: str) -> NativeFunction: + for f in functions: + if cpp.name(f.func) == name: + return f # some functions only have in-place variants - assert name + '_' == declarations[0]['name'] - return declarations[0] + assert name + '_' == cpp.name(functions[0].func) + return functions[0] - def split_names(raw_names): + def split_names(raw_names: str) -> Tuple[str, ...]: """Given "foo, bar", return ["foo", "bar"].""" - return [x.strip() for x in raw_names.split(',')] - - def lookup_pred(pred, xs): - """Return the index of the first element of xs matching pred.""" - return next((i, x) for i, x in enumerate(xs) if pred(x)) + return tuple(x.strip() for x in raw_names.split(',')) - def check_grad_usage(defn_name, declaration, derivatives): + def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None: """ Check for some subtle mistakes one might make when writing derivatives. These mistakes will compile, but will be latent until a function is @@ -119,9 +119,9 @@ def check_grad_usage(defn_name, declaration, derivatives): used_grad = 0 used_grads = 0 fully_implemented = True - used_grads_indices = [] + used_grads_indices: List[int] = [] for d in derivatives: - formula = d['formula'] + formula = d.formula used_grad += len(re.findall(IDENT_REGEX.format('grad'), formula)) used_grads += len(re.findall(IDENT_REGEX.format('grads'), formula)) fully_implemented = \ @@ -132,134 +132,99 @@ def check_grad_usage(defn_name, declaration, derivatives): only_used_grads_indices = used_grads == len(used_grads_indices) if used_grad and used_grads: - raise RuntimeError("Derivative definition of {} in derivatives.yaml illegally " + raise RuntimeError(f"Derivative definition of {defn_name} in derivatives.yaml illegally " "mixes use of 'grad' and 'grads'. Consider replacing " - "occurrences of 'grad' with 'grads[0]'".format(defn_name)) + "occurrences of 'grad' with 'grads[0]'") if only_used_grads_indices and set(used_grads_indices) == {0}: - raise RuntimeError("Derivative definition of {} in derivatives.yaml solely " + raise RuntimeError(f"Derivative definition of {defn_name} in derivatives.yaml solely " "refers to 'grads[0]'. If the first output is indeed the " "only differentiable output, replace 'grads[0]' with 'grad'; " "otherwise, there is a likely error in your derivatives " - "declaration.".format(defn_name)) - - def set_up_derivatives(defn_name, defn, declaration): - # Determine the set of inputs which have derivatives - args_with_derivatives_set = set() - for raw_names in defn: - args_with_derivatives_set |= set(split_names(raw_names)) - - # Next, let us determine the list of inputs in order. - args_with_derivatives = [] - for arg in declaration['arguments']: - if arg['name'] not in args_with_derivatives_set: - continue - args_with_derivatives.append(arg) - + "declaration.") + + @with_native_function + def set_up_derivatives(f: NativeFunction) -> Tuple[ + Sequence[Derivative], + Sequence[Binding], + Sequence[str], + ]: # Set up the derivative information - derivatives = [] - non_differentiable_arg_names = [] + derivatives: List[Derivative] = [] + non_differentiable_arg_names: List[str] = [] + args_with_derivatives_set: Set[str] = set() for raw_names in sorted(defn.keys()): formula = defn[raw_names] names = split_names(raw_names) - derivative = create_derivative(declaration['arguments'], declaration['returns'], - declaration['name'], formula, names) if formula.lower().strip() == 'non_differentiable': - assert not sum([type(var_name) == list - for var_name in derivative['var_names']]), \ - "Variable names associated to a formula should be a flat list" - non_differentiable_arg_names += derivative['var_names'] + non_differentiable_arg_names += names else: + derivative = create_derivative(f, formula, names) derivatives.append(derivative) - args_with_derivatives = list(filter(lambda x: x['name'] not in non_differentiable_arg_names, - args_with_derivatives)) + args_with_derivatives_set |= set(names) + + overlap = args_with_derivatives_set.intersection(non_differentiable_arg_names) + if overlap: + raise RuntimeError(f'derivatives definition for {defn} have overlapped non_differentiable ' + f'and differentiable variables: {overlap}') + + # Next, let us determine the list of inputs in order. + # TODO: do we need eagerly calculate and save it here? Can it be derived + # from NativeFunction and `derivatives` on callsites instead? + args_with_derivatives = list(filter(lambda a: a.name in args_with_derivatives_set, cpp_arguments(f))) # Test to see if the use of 'grads' makes sense. - check_grad_usage(defn_name, declaration, derivatives) + check_grad_usage(defn_name, derivatives) return derivatives, args_with_derivatives, non_differentiable_arg_names - def unzip(xs): - return zip(*xs) - # NB: Removes 'name' from defn dictionary specification = defn.pop('name') - defn_name, params = split_name_params(specification) + defn_name, _ = split_name_params(specification) # NB: Removes 'output_differentiability' from defn dictionary # `None` means all differentiable. output_differentiability = defn.pop('output_differentiability', None) - schema_declaration = declarations_by_schema.get('aten::' + specification) - if not schema_declaration: - avail = [k.replace('aten::', '') for k, v in declarations_by_schema.items() - if k.replace('aten::', '').startswith(defn_name + '(') and len(v) > 0] - raise RuntimeError('could not find ATen declaration for schema: {} ' - '. Available signatures:\n{}'.format(specification, '\n'.join(avail))) + schema_function = functions_by_schema.get(specification) + if not schema_function: + avail = '\n'.join(k for k, v in functions_by_schema.items() if cpp.name(v.func) == defn_name) + raise RuntimeError(f'could not find ATen function for schema: {specification} ' + f'. Available signatures:\n{avail}') # now map this to the legacy schema; this isn't technically necessary, but we'd need some logic here # to map in-place schemas to the out-of-place variants. - signature = get_signature(schema_declaration) - declarations = declarations_by_signature[signature] - if len(declarations) == 0: - avail = [k for k, v in declarations_by_signature.items() - if k.startswith(defn_name + '(') and len(v) > 0] - raise RuntimeError('could not find ATen declaration for legacy signature: {} ' - 'corresponding to schema {}. Please report a bug to PyTorch. ' - 'Available signatures: {}'.format(signature, specification, ', '.join(avail))) - - canonical = canonical_declaration(declarations, defn_name) - if 'grad_input_mask' in (a['name'] for a in canonical['arguments']): - raise RuntimeError("Schema for {} has an argument named grad_input_mask, " + # TODO: maybe the logic to handle the legacy schema is no longer necessary? + signature = schema_function.func.signature() + functions = functions_by_signature[signature] + if len(functions) == 0: + avail = '\n'.join(str(k) for k, v in functions_by_signature.items() if cpp.name(k) == defn_name) + raise RuntimeError(f'could not find ATen function for legacy signature: {signature} ' + f'corresponding to schema {specification}. Please report a bug to PyTorch. ' + f'Available signatures:\n{avail}') + + canonical = canonical_function(functions, defn_name) + if 'grad_input_mask' in (a.name for a in cpp_arguments(canonical)): + raise RuntimeError(f"Schema for {defn_name} has an argument named grad_input_mask, " "but this name would be shadowed by our codegen. " - "Please use a different name in native_functions.yaml." - .format(defn_name)) - - derivatives, args_with_derivatives, non_differentiable_arg_names = set_up_derivatives(defn_name, defn, canonical) - autograd_fn = None - - # only create an autograd function if we are actually going to calculate a derivative - if len(args_with_derivatives) > 0: - autograd_fn = create_autograd_function(defn_name, derivatives, args_with_derivatives, - canonical) - - return create_differentiability_info(signature, non_differentiable_arg_names, - output_differentiability, autograd_fn) - - -def ensure_unique_names(autograd_functions): - # de-duplicate operation names - # you end up with something like: - # AddBackward0 - # AddBackward1 - # one for each overload - functions_by_name = defaultdict(list) - for func in autograd_functions: - functions_by_name[func['op']].append(func) - for op in functions_by_name.keys(): - overloads = functions_by_name[op] - if len(overloads) > 1: - for i, func in enumerate(overloads): - func['op'] += str(i) - - -def get_signature(declaration, use_base_variant=False): - name = declaration['name'] - arguments = declaration['arguments'] - if use_base_variant: - if declaration['inplace']: - assert name.endswith('_') - name = name[:-1] - elif name.endswith('_out'): - name = name[:-4] - arguments = [arg for arg in arguments if not arg.get('output', False)] - simple_types = [arg['simple_type'] for arg in arguments] - return '{}({})'.format(name, ', '.join(simple_types)) - + "Please use a different name in native_functions.yaml.") + + derivatives, args_with_derivatives, non_differentiable_arg_names = set_up_derivatives(canonical) + + return DifferentiabilityInfo( + name=defn_name, + func=canonical, + op=None, + derivatives=derivatives, + all_saved_inputs=dedup_vars([v for d in derivatives for v in d.saved_inputs]), + all_saved_outputs=dedup_vars([v for d in derivatives for v in d.saved_outputs]), + args_with_derivatives=args_with_derivatives, + non_differentiable_arg_names=non_differentiable_arg_names, + output_differentiability=output_differentiability, + ) GRAD_INDEX_REGEX = r'(?:^|\W)grads\[(\d+)\]' - -def used_gradient_indices(formula): +def used_gradient_indices(formula: str) -> List[int]: """Determine a list of gradient indices (the i in grads[i]) that are used by the formula. @@ -268,17 +233,30 @@ def used_gradient_indices(formula): """ return [int(i) for i in re.findall(GRAD_INDEX_REGEX, formula)] +def saved_variables( + formula: str, + arg_names: Tuple[str, ...], + arg_types: Tuple[str, ...], + var_names: Tuple[str, ...], +) -> Tuple[str, Tuple[SavedAttribute, ...]]: -def saved_variables(formula, args): - # find which arguments need to be saved - saved = [] + def stride_expr(name: str) -> str: + assert var_names == (name,), ( + 'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor ' + 'that ".strides()" is being called on.') + return f'strides_or_error({name}, "{name}")' - REPLACEMENTS = [ + REPLACEMENTS: List[Tuple[str, Dict[str, Any]]] = [ # replace self.sizes() with self_sizes (r'{}.sizes\(\)', { 'suffix': '_sizes', 'type': 'IntArrayRef', }), + # replace self.options() with self_options + (r'{}.options\(\)', { + 'suffix': '_options', + 'type': 'at::TensorOptions', + }), # replace zeros_like(self) with self_info (r'zeros_like\({}\)', { 'suffix': '_info', @@ -315,80 +293,83 @@ def saved_variables(formula, args): 'suffix': '_dim', 'type': 'int64_t', }), + # replace self.strides() with self_strides + (r'{}.strides\(\)', { + 'suffix': '_strides', + 'type': 'IntArrayRef', + 'expr': stride_expr, + }), ] - for arg in args: - if 'name' not in arg: - # some returned arguments do not have names - continue - - name = arg['name'] + # find which arguments need to be saved + saved: List[SavedAttribute] = [] + for name, type in zip(arg_names, arg_types): # First search the formula for expressions which can be evaluated # when the autograd Function is created to avoid saving variables for regex, info in REPLACEMENTS: - def repl(m): - suffix = info['suffix'] - suffix = suffix(m) if callable(suffix) else suffix - expr = info['expr'](name) if 'expr' in info else m.group(0) - saved.append({ - 'name': name + suffix, - 'type': info['type'], - 'expr': expr, - }) + def repl(m: Match[str]) -> str: + suffix: str = info['suffix'](m) if callable(info['suffix']) else info['suffix'] + expr: str = info['expr'](name) if 'expr' in info else m.group(0) + saved.append(SavedAttribute( + name=name + suffix, + type=info['type'], + expr=expr, + )) if 'res' in info: - return info['res'](name) + replacement: str = info['res'](name) + return replacement return name + suffix formula = re.sub(regex.format(name), repl, formula) # Find any variables which remain in the formula and save them if re.search(IDENT_REGEX.format(name), formula): - arg = copy.deepcopy(arg) - arg['type'] = arg['type'].replace('const ', '').replace(' &', '') - saved.append(arg) - - return formula, saved - - -def all_saved_variables(derivatives, key): - seen = set() - saved = [] - for d in derivatives: - for saved_arg in d[key]: - if saved_arg['name'] in seen: - continue - seen.add(saved_arg['name']) - saved.append(saved_arg) - return saved - - -def to_camel_case(name): - return ''.join([p.title() for p in name.split('_')]) - - -def match_declarations_with_differentiability_info(declarations, differentiability_infos): - """Sets the "derivative" and "output_differentiability" key on declarations - to matching differentiability info + saved.append(SavedAttribute( + name=name, + # TODO: change from string to type data model + type=type.replace('const ', '').replace(' &', ''), + expr=name, + )) + + return formula, tuple(saved) + +def create_op_name(info: DifferentiabilityInfo) -> Optional[str]: + # only assign an op name if we are actually going to calculate a derivative + if not info.args_with_derivatives: + return None + name = info.name + camel_case = ''.join([p.title() for p in name.split('_')]) + return (camel_case + 'Backward').replace('ForwardBackward', 'Backward') + +def create_op_names(infos: Sequence[DifferentiabilityInfo]) -> Sequence[Optional[str]]: + names = list(map(create_op_name, infos)) + dups = set(item for item, count in Counter(names).items() if count > 1) - In-place functions will use the out-of-place derivative definition if there - is no in-place specific derivative. - """ - - infos_by_signature = {f['signature']: f for f in differentiability_infos} - - def find_info(declaration): - signature = get_signature(declaration) - if signature in infos_by_signature: - return infos_by_signature[signature] - - # if there is no exact match look for the out-of-place signature. - # i.e mul() for mul_() or mul_out() - signature = get_signature(declaration, use_base_variant=True) - return infos_by_signature.get(signature) - - for declaration in declarations: - info = find_info(declaration) - declaration['derivative'] = info['autograd_fn'] if info else None - declaration['non_differentiable_arg_names'] = info['non_differentiable_arg_names'] if info else [] - declaration['output_differentiability'] = info['output_differentiability'] if info else None + # de-duplicate operation names + # you end up with something like: + # AddBackward0 + # AddBackward1 + # one for each overload + counter: Dict[str, int] = Counter() + dedup: List[Optional[str]] = [] + for name in names: + if name is None: + # Keep a placeholder + dedup.append(None) + elif name in dups: + dedup.append(f'{name}{counter[name]}') + counter[name] += 1 + else: + dedup.append(name) + return dedup + +def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]: + seen: Set[str] = set() + saved: List[SavedAttribute] = [] + for var in vars: + if var.name in seen: + continue + seen.add(var.name) + saved.append(var) + return saved diff --git a/tools/autograd/nested_dict.py b/tools/autograd/nested_dict.py deleted file mode 100644 index e1e0981419915..0000000000000 --- a/tools/autograd/nested_dict.py +++ /dev/null @@ -1,19 +0,0 @@ -# TODO: refactor nested_dict into common library with ATen -class nested_dict(object): - """ - A nested dict is a dictionary with a parent. If key lookup fails, - it recursively continues into the parent. Writes always happen to - the top level dict. - """ - - def __init__(self, base, parent): - self.base, self.parent = base, parent - - def __contains__(self, item): - return item in self.base or item in self.parent - - def __getitem__(self, x): - r = self.base.get(x) - if r is not None: - return r - return self.parent[x] diff --git a/tools/autograd/templates/Functions.h b/tools/autograd/templates/Functions.h index 03240e2a5a2be..0540bb65b33b1 100644 --- a/tools/autograd/templates/Functions.h +++ b/tools/autograd/templates/Functions.h @@ -32,6 +32,15 @@ inline std::vector unpack_list(at::ArrayRef xs) { }); } +inline c10::List> unpack_opt_list(at::ArrayRef xs) { + torch::List> result; + result.reserve(xs.size()); + for (const SavedVariable& v : xs) { + result.push_back(v.unpack()); + } + return result; +} + struct TypeAndSize { TypeAndSize() : options(at::TensorOptions()) {} /* implicit */ diff --git a/tools/autograd/templates/TraceType.cpp b/tools/autograd/templates/TraceType.cpp index d08c1e3cc5aae..3ac52ed08edca 100644 --- a/tools/autograd/templates/TraceType.cpp +++ b/tools/autograd/templates/TraceType.cpp @@ -1,6 +1,5 @@ #include "torch/csrc/autograd/VariableTypeUtils.h" -#include #include #include "torch/csrc/autograd/function.h" diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp index efddffbe76103..ba2f99369f8d1 100644 --- a/tools/autograd/templates/VariableType.cpp +++ b/tools/autograd/templates/VariableType.cpp @@ -1,9 +1,7 @@ #include "torch/csrc/autograd/VariableTypeUtils.h" #include "torch/csrc/autograd/FunctionsManual.h" -#include #include -#include // ${generated_comment} diff --git a/tools/autograd/templates/VariableType.h b/tools/autograd/templates/VariableType.h index 9062a4d08e344..fc8ffa5799c11 100644 --- a/tools/autograd/templates/VariableType.h +++ b/tools/autograd/templates/VariableType.h @@ -49,7 +49,6 @@ namespace VariableType { at::Tensor & unpack(Tensor & t, const char * name, int pos); const at::Tensor & unpack(const Tensor & t, const char * name, int pos); at::Tensor unpack_opt(const Tensor & t, const char * name, int pos); - c10::optional unpack_opt(const c10::optional & t, const char * name, int pos); std::vector unpack(at::TensorList tl, const char *name, int pos); }; diff --git a/tools/autograd/templates/python_fft_functions.cpp b/tools/autograd/templates/python_fft_functions.cpp index 7d0186538c981..a77547a6cc078 100644 --- a/tools/autograd/templates/python_fft_functions.cpp +++ b/tools/autograd/templates/python_fft_functions.cpp @@ -7,15 +7,31 @@ #include "torch/csrc/autograd/python_variable.h" #include "torch/csrc/autograd/utils/wrap_outputs.h" #include "torch/csrc/autograd/utils/python_arg_parsing.h" +#include "torch/csrc/autograd/generated/variable_factories.h" +#include "torch/csrc/utils/out_types.h" +#include "torch/csrc/utils/pycfunction_helpers.h" #include "torch/csrc/utils/python_arg_parser.h" #include "torch/csrc/utils/structseq.h" +#include "torch/csrc/utils/cuda_lazy_init.h" + +#include using at::Tensor; +using at::Device; +using at::Layout; using at::Scalar; -using at::MemoryFormat; -using at::Generator; +using at::ScalarType; +using at::Backend; +using at::OptionalDeviceGuard; +using at::DeviceGuard; +using at::TensorOptions; using at::IntArrayRef; +using at::Generator; +using at::TensorList; +using at::Dimname; +using at::DimnameList; +using torch::utils::check_out_type_matches; using namespace torch::autograd::utils; namespace torch { namespace autograd { diff --git a/tools/autograd/templates/python_linalg_functions.cpp b/tools/autograd/templates/python_linalg_functions.cpp index b02438e31189c..d361e740b8b3c 100644 --- a/tools/autograd/templates/python_linalg_functions.cpp +++ b/tools/autograd/templates/python_linalg_functions.cpp @@ -7,6 +7,7 @@ #include "torch/csrc/autograd/python_variable.h" #include "torch/csrc/autograd/utils/wrap_outputs.h" #include "torch/csrc/autograd/utils/python_arg_parsing.h" +#include "torch/csrc/utils/pycfunction_helpers.h" #include "torch/csrc/utils/python_arg_parser.h" #include "torch/csrc/utils/structseq.h" diff --git a/tools/autograd/templates/python_nn_functions.cpp b/tools/autograd/templates/python_nn_functions.cpp index e60de17790251..6e4f50a87dd22 100644 --- a/tools/autograd/templates/python_nn_functions.cpp +++ b/tools/autograd/templates/python_nn_functions.cpp @@ -7,6 +7,7 @@ #include "torch/csrc/autograd/python_variable.h" #include "torch/csrc/autograd/utils/wrap_outputs.h" #include "torch/csrc/autograd/utils/python_arg_parsing.h" +#include "torch/csrc/utils/pycfunction_helpers.h" #include "torch/csrc/utils/python_arg_parser.h" #include "torch/csrc/utils/structseq.h" @@ -71,7 +72,8 @@ static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObje ${py_forwards} static PyMethodDef nn_functions[] = { - {"_parse_to", (PyCFunction)(void(*)(void))THPVariable__parse_to, METH_VARARGS | METH_KEYWORDS, nullptr}, + {"_parse_to", castPyCFunctionWithKeywords(THPVariable__parse_to), + METH_VARARGS | METH_KEYWORDS, nullptr}, ${py_method_defs} {NULL} }; diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp index 62e9b8dd227fc..c42a869b3a98a 100644 --- a/tools/autograd/templates/python_torch_functions.cpp +++ b/tools/autograd/templates/python_torch_functions.cpp @@ -8,12 +8,20 @@ #include +// Undefine the copysign macro so that at::copysign works as intended with MSVC +// https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196 +#ifdef _MSC_VER +#undef copysign +#endif // _MSC_VER + #include "torch/csrc/autograd/python_variable.h" #include "torch/csrc/autograd/utils/wrap_outputs.h" #include "torch/csrc/Dtype.h" #include "torch/csrc/DynamicTypes.h" #include "torch/csrc/Exceptions.h" +#include "torch/csrc/utils/out_types.h" #include "torch/csrc/utils/pybind.h" +#include "torch/csrc/utils/pycfunction_helpers.h" #include "torch/csrc/utils/python_arg_parser.h" #include "torch/csrc/utils/tensor_layouts.h" #include "torch/csrc/utils/tensor_new.h" @@ -44,44 +52,15 @@ using at::Generator; using at::TensorList; using at::Dimname; using at::DimnameList; +using at::ArrayRef; +using torch::utils::check_out_type_matches; using namespace torch::autograd::utils; namespace torch { namespace autograd { static PyObject* THPVariableFunctionsModule = NULL; -static void check_out_type_matches(Tensor result, - ScalarType scalarType, bool scalarType_is_none, - c10::optional layout, - const Device& device, bool device_is_none) { - if (scalarType_is_none && !layout && device_is_none) { // common case - return; - } - if (!scalarType_is_none && result.scalar_type() != scalarType) { - AT_ERROR( - "dtype ", scalarType, - " does not match dtype of out parameter (", result.scalar_type(), ")"); - } - auto scalarType_arg = scalarType_is_none ? result.scalar_type() : scalarType; - auto device_type_arg = device_is_none ? result.device().type() : device.type(); - if (result.scalar_type() != scalarType_arg) { - AT_ERROR( - "scalar type ", scalarType_arg, - " does not match scalar type of out parameter (", result.scalar_type(), ")"); - } - if (layout && result.layout() != *layout) { - AT_ERROR( - "layout ", *layout, - " does not match layout of out parameter (", result.layout(), ")"); - } - if (result.device().type() != device_type_arg) { - AT_ERROR( - "device type ", device_type_arg, - " does not match device type of out parameter (", result.device().type(), ")"); - } -} - inline Tensor dispatch_arange(Scalar end, Tensor result) { pybind11::gil_scoped_release no_gil; return at::arange_out(result, end); @@ -493,23 +472,25 @@ static PyObject * TypeError_to_NotImplemented_(PyObject* self, PyObject* args, P // Any new ops added here should be accompanied with a comment why they are not // being registered through native_functions.yaml, and be tagged cpp / JIT static PyMethodDef torch_functions[] = { - {"arange", (PyCFunction)(void(*)(void))THPVariable_arange, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"as_tensor", (PyCFunction)(void(*)(void))THPVariable_as_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"dsmm", (PyCFunction)(void(*)(void))THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"from_numpy", (PyCFunction)THPVariable_from_numpy, METH_STATIC | METH_O, NULL}, - {"full", (PyCFunction)(void(*)(void))THPVariable_full, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"hsmm", (PyCFunction)(void(*)(void))THPVariable_hspmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"nonzero", (PyCFunction)(void(*)(void))THPVariable_nonzero, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"randint", (PyCFunction)(void(*)(void))THPVariable_randint, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"range", (PyCFunction)(void(*)(void))THPVariable_range, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"saddmm", (PyCFunction)(void(*)(void))THPVariable_sspaddmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"sparse_coo_tensor", (PyCFunction)(void(*)(void))THPVariable_sparse_coo_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"_sparse_coo_tensor_unsafe", (PyCFunction)(void(*)(void))THPVariable__sparse_coo_tensor_unsafe, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"_validate_sparse_coo_tensor_args", (PyCFunction)(void(*)(void))THPVariable__validate_sparse_coo_tensor_args, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"spmm", (PyCFunction)(void(*)(void))THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"tensor", (PyCFunction)(void(*)(void))THPVariable_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"get_device", (PyCFunction)(void(*)(void))THPVariable_get_device, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"numel", (PyCFunction)(void(*)(void))THPVariable_numel, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"arange", castPyCFunctionWithKeywords(THPVariable_arange), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"as_tensor", castPyCFunctionWithKeywords(THPVariable_as_tensor), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"dsmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"from_numpy", THPVariable_from_numpy, METH_STATIC | METH_O, NULL}, + {"full", castPyCFunctionWithKeywords(THPVariable_full), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"hsmm", castPyCFunctionWithKeywords(THPVariable_hspmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"nonzero", castPyCFunctionWithKeywords(THPVariable_nonzero), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"randint", castPyCFunctionWithKeywords(THPVariable_randint), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"range", castPyCFunctionWithKeywords(THPVariable_range), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"saddmm", castPyCFunctionWithKeywords(THPVariable_sspaddmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"sparse_coo_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"_sparse_coo_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"_validate_sparse_coo_tensor_args", castPyCFunctionWithKeywords(THPVariable__validate_sparse_coo_tensor_args), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"spmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"tensor", castPyCFunctionWithKeywords(THPVariable_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"get_device", castPyCFunctionWithKeywords(THPVariable_get_device), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"numel", castPyCFunctionWithKeywords(THPVariable_numel), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, ${py_method_defs} {NULL} }; @@ -582,29 +563,29 @@ static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* { HANDLE_TH_ERRORS static PythonArgParser parser({ - "nonzero(Tensor input, *, Tensor out=None)|deprecated", - "nonzero(Tensor input, *, bool as_tuple)", + "nonzero(Tensor input, *, bool as_tuple=False, Tensor out=None)", }); - ParsedArgs<2> parsed_args; + ParsedArgs<3> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); if(r.has_torch_function()){ return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch"); } - if (r.idx == 0) { - if (r.isNone(1)) { - return wrap(dispatch_nonzero(r.tensor(0))); - } else { - return wrap(dispatch_nonzero(r.tensor(0), r.tensor(1))); - } - } else { - if (r.toBool(1)) { - return wrap(dispatch_nonzero_numpy(r.tensor(0))); - } else { - return wrap(dispatch_nonzero(r.tensor(0))); - } + const auto as_tuple = r.toBool(1); + const auto has_out = !r.isNone(2); + + if (as_tuple) { + TORCH_CHECK(!has_out, "nonzero does not support the out kwarg when as_tuple is True"); + return wrap(dispatch_nonzero_numpy(r.tensor(0))); } + + if (has_out) { + return wrap(dispatch_nonzero(r.tensor(0), r.tensor(2))); + } + + return wrap(dispatch_nonzero(r.tensor(0))); + END_HANDLE_TH_ERRORS } diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index eaecf0d81845b..11dd227bbf8ad 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -2,6 +2,12 @@ #include +// Undefine the copysign macro so that at::copysign works as intended with MSVC +// https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196 +#ifdef _MSC_VER +#undef copysign +#endif // _MSC_VER + #include "torch/csrc/DynamicTypes.h" #include "torch/csrc/Exceptions.h" #include "torch/csrc/Size.h" @@ -12,11 +18,11 @@ #include "torch/csrc/autograd/utils/wrap_outputs.h" #include "torch/csrc/jit/frontend/tracer.h" #ifdef USE_CUDA -#include "torch/csrc/cuda/Stream.h" #include "torch/csrc/cuda/Event.h" #endif #include "torch/csrc/utils/cuda_lazy_init.h" #include "torch/csrc/utils/object_ptr.h" +#include "torch/csrc/utils/pycfunction_helpers.h" #include "torch/csrc/utils/python_arg_parser.h" #include "torch/csrc/utils/python_numbers.h" #include "torch/csrc/utils/python_strings.h" @@ -30,6 +36,7 @@ #include #include "c10/util/Optional.h" +#include "c10/core/Stream.h" #include @@ -40,6 +47,7 @@ using at::Backend; using at::Scalar; using at::ScalarType; using at::Tensor; +using c10::Stream; using namespace torch::autograd::utils; namespace torch { namespace autograd { @@ -48,7 +56,7 @@ static PyObject * THPVariable__is_view(PyObject *self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "_is_view"); + return handle_torch_function(self, "_is_view", args); } auto& self_ = reinterpret_cast(self)->cdata; if (self_.is_view()) { @@ -152,7 +160,7 @@ static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self_)) { - return handle_torch_function(self_, "get_device"); + return handle_torch_function(self_, "get_device", args, nullptr); } auto& self = reinterpret_cast(self_)->cdata; return wrap(self.get_device()); @@ -163,7 +171,7 @@ static PyObject * THPVariable_has_names(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self_)) { - return handle_torch_function(self_, "has_names"); + return handle_torch_function(self_, "has_names", args); } auto& self = reinterpret_cast(self_)->cdata; return wrap(self.has_names()); @@ -175,7 +183,7 @@ static PyObject * THPVariable_data_ptr(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self_)) { - return handle_torch_function(self_, "data_ptr"); + return handle_torch_function(self_, "data_ptr", args); } auto& self = reinterpret_cast(self_)->cdata; return wrap(self.data_ptr()); @@ -199,7 +207,7 @@ static PyObject * THPVariable_dim(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "dim"); + return handle_torch_function(self, "dim", args); } auto& self_ = reinterpret_cast(self)->cdata; return THPUtils_packInt64(self_.dim()); @@ -211,7 +219,7 @@ static PyObject * THPVariable_numel(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "numel"); + return handle_torch_function(self, "numel", args); } auto& self_ = reinterpret_cast(self)->cdata; return THPUtils_packInt64(self_.numel()); @@ -325,7 +333,7 @@ static bool dispatch_to_Bool(const Tensor & self) { static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "__float__"); + return handle_torch_function(self, "__float__", args); } jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; @@ -336,7 +344,7 @@ static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) { static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "__complex__"); + return handle_torch_function(self, "__complex__", args); } jit::tracer::warn("Converting a tensor to a Python complex", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; @@ -347,7 +355,7 @@ static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) { static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "__int__"); + return handle_torch_function(self, "__int__", args); } jit::tracer::warn("Converting a tensor to a Python integer", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; @@ -366,7 +374,7 @@ static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) { static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "__index__"); + return handle_torch_function(self, "__index__", args); } jit::tracer::warn("Converting a tensor to a Python index", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; @@ -388,7 +396,7 @@ static Tensor dispatch_invert(const Tensor & self) { static PyObject * THPVariable_invert(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "__invert__"); + return handle_torch_function(self, "__invert__", args); } auto& self_ = reinterpret_cast(self)->cdata; if (!isIntegralType(self_.scalar_type(), /*includeBool=*/true)) { @@ -683,7 +691,7 @@ static PyObject * THPVariable_element_size(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "element_size"); + return handle_torch_function(self, "element_size", args); } auto& self_ = reinterpret_cast(self)->cdata; return THPUtils_packInt64(self_.element_size()); @@ -704,27 +712,6 @@ static PyObject * THPVariable_numpy(PyObject* self, PyObject* arg) END_HANDLE_TH_ERRORS } -// TODO: move this to ATen. We would need to expose Stream objects in ATen. -static PyObject * THPVariable_record_stream(PyObject* self, PyObject* arg) -{ - HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { - auto args = py::make_tuple(py::handle(arg)); - return handle_torch_function(self, "record_stream", args.ptr()); - } -#ifdef USE_CUDA - auto& self_ = reinterpret_cast(self)->cdata; - if (!THCPStream_Check(arg)) { - return PyErr_Format(PyExc_TypeError, "expected Stream object"); - } - c10::cuda::CUDACachingAllocator::recordStream(self_.storage().data_ptr(), at::cuda::CUDAStream::unpack(((THCPStream*)arg)->cdata)); - Py_RETURN_NONE; -#else - throw std::runtime_error("PyTorch compiled without CUDA support"); -#endif - END_HANDLE_TH_ERRORS -} - static PyObject * THPVariable_requires_grad_(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_TH_ERRORS @@ -782,7 +769,7 @@ static PyObject * THPVariable_item(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "item"); + return handle_torch_function(self, "item", args); } jit::tracer::warn("Converting a tensor to a Python number", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; @@ -851,7 +838,7 @@ static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwar { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "new"); + return handle_torch_function(self, "new", args, kwargs); } auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); @@ -863,7 +850,7 @@ static PyObject * THPVariable_new_ones(PyObject* self, PyObject* args, PyObject* { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "new_ones"); + return handle_torch_function(self, "new_ones", args, kwargs); } auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); @@ -875,7 +862,7 @@ static PyObject * THPVariable_new_tensor(PyObject* self, PyObject* args, PyObjec { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "new_tensor"); + return handle_torch_function(self, "new_tensor", args, kwargs); } auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); @@ -954,7 +941,7 @@ static PyObject * THPVariable_tolist(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "tolist"); + return handle_torch_function(self, "tolist", args); } jit::tracer::warn("Converting a tensor to a Python list", jit::tracer::WARN_PYTHON_DATAFLOW); auto self_ = reinterpret_cast(self)->cdata; @@ -1023,7 +1010,7 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa static PyObject * THPVariable_bool_scalar(PyObject* self, PyObject* args) { if (check_has_torch_function(self)) { HANDLE_TH_ERRORS - return handle_torch_function(self, "__bool__"); + return handle_torch_function(self, "__bool__", args); END_HANDLE_TH_ERRORS } jit::tracer::warn("Converting a tensor to a Python boolean", jit::tracer::WARN_PYTHON_DATAFLOW); @@ -1126,73 +1113,79 @@ static PyObject* THPVariable_set_( // being registered through native_functions.yaml, and be tagged cpp / JIT PyMethodDef variable_methods[] = { // These magic methods are all implemented on python object to wrap NotImplementedError - {"__add__", (PyCFunction)(void(*)(void))TypeError_to_NotImplemented_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__radd__", (PyCFunction)(void(*)(void))TypeError_to_NotImplemented_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__iadd__", (PyCFunction)(void(*)(void))TypeError_to_NotImplemented_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__rmul__", (PyCFunction)(void(*)(void))TypeError_to_NotImplemented_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__mul__", (PyCFunction)(void(*)(void))TypeError_to_NotImplemented_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__imul__", (PyCFunction)(void(*)(void))TypeError_to_NotImplemented_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__sub__", (PyCFunction)(void(*)(void))TypeError_to_NotImplemented_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__isub__", (PyCFunction)(void(*)(void))TypeError_to_NotImplemented_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__div__", (PyCFunction)(void(*)(void))TypeError_to_NotImplemented_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__truediv__", (PyCFunction)(void(*)(void))TypeError_to_NotImplemented_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__floordiv__", (PyCFunction)(void(*)(void))TypeError_to_NotImplemented_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__idiv__", (PyCFunction)(void(*)(void))TypeError_to_NotImplemented_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__ifloordiv__", (PyCFunction)(void(*)(void))TypeError_to_NotImplemented_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__mod__", (PyCFunction)(void(*)(void))TypeError_to_NotImplemented_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__bool__", (PyCFunction)THPVariable_bool_scalar, METH_NOARGS, NULL}, - {"__float__", (PyCFunction)THPVariable_float_scalar, METH_NOARGS, NULL}, - {"__complex__", (PyCFunction)THPVariable_complex_scalar, METH_NOARGS, NULL}, - {"__int__", (PyCFunction)THPVariable_integral_scalar, METH_NOARGS, NULL}, - {"__long__", (PyCFunction)THPVariable_integral_scalar, METH_NOARGS, NULL}, - {"__index__", (PyCFunction)THPVariable_index_scalar, METH_NOARGS, NULL}, - {"__nonzero__", (PyCFunction)THPVariable_bool_scalar, METH_NOARGS, NULL}, - {"__invert__", (PyCFunction)THPVariable_invert, METH_NOARGS, NULL}, - {"__matmul__", (PyCFunction)(void(*)(void))TypeError_to_NotImplemented_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"_is_view", (PyCFunction)THPVariable__is_view, METH_NOARGS, NULL}, - {"apply_", (PyCFunction)THPVariable_apply_, METH_O, NULL}, - {"bfloat16", (PyCFunction)(void(*)(void))THPVariable_bfloat16, METH_VARARGS | METH_KEYWORDS, NULL}, - {"byte", (PyCFunction)(void(*)(void))THPVariable_byte, METH_VARARGS | METH_KEYWORDS, NULL}, - {"char", (PyCFunction)(void(*)(void))THPVariable_char, METH_VARARGS | METH_KEYWORDS, NULL}, - {"contiguous", (PyCFunction)(void(*)(void))THPVariable_contiguous, METH_VARARGS | METH_KEYWORDS, NULL}, - {"copy_", (PyCFunction)(void(*)(void))THPVariable_copy_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"cpu", (PyCFunction)(void(*)(void))THPVariable_cpu, METH_VARARGS | METH_KEYWORDS, NULL}, - {"cuda", (PyCFunction)(void(*)(void))THPVariable_cuda, METH_VARARGS | METH_KEYWORDS, NULL}, - {"data_ptr", (PyCFunction)THPVariable_data_ptr, METH_NOARGS, NULL}, - {"dim", (PyCFunction)THPVariable_dim, METH_NOARGS, NULL}, - {"has_names", (PyCFunction)THPVariable_has_names, METH_NOARGS, NULL}, - {"double", (PyCFunction)(void(*)(void))THPVariable_double, METH_VARARGS | METH_KEYWORDS, NULL}, - {"element_size", (PyCFunction)THPVariable_element_size, METH_NOARGS, NULL}, - {"float", (PyCFunction)(void(*)(void))THPVariable_float, METH_VARARGS | METH_KEYWORDS, NULL}, - {"get_device", (PyCFunction)THPVariable_get_device, METH_NOARGS, NULL}, - {"bool", (PyCFunction)(void(*)(void))THPVariable_bool, METH_VARARGS | METH_KEYWORDS, NULL}, - {"half", (PyCFunction)(void(*)(void))THPVariable_half, METH_VARARGS | METH_KEYWORDS, NULL}, - {"int", (PyCFunction)(void(*)(void))THPVariable_int, METH_VARARGS | METH_KEYWORDS, NULL}, - {"is_contiguous", (PyCFunction)(void(*)(void))THPVariable_is_contiguous, METH_VARARGS | METH_KEYWORDS, NULL}, - {"item", (PyCFunction)THPVariable_item, METH_NOARGS, NULL}, - {"long", (PyCFunction)(void(*)(void))THPVariable_long, METH_VARARGS | METH_KEYWORDS, NULL}, - {"map_", (PyCFunction)(void(*)(void))THPVariable_map_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"map2_", (PyCFunction)(void(*)(void))THPVariable_map2_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"ndimension", (PyCFunction)THPVariable_dim, METH_NOARGS, NULL}, - {"nelement", (PyCFunction)THPVariable_numel, METH_NOARGS, NULL}, - {"new", (PyCFunction)(void(*)(void))THPVariable_new, METH_VARARGS | METH_KEYWORDS, NULL}, - {"new_ones", (PyCFunction)(void(*)(void))THPVariable_new_ones, METH_VARARGS | METH_KEYWORDS, NULL}, - {"new_tensor", (PyCFunction)(void(*)(void))THPVariable_new_tensor, METH_VARARGS | METH_KEYWORDS, NULL}, - {"nonzero", (PyCFunction)(void(*)(void))THPVariable_nonzero, METH_VARARGS | METH_KEYWORDS, NULL}, - {"numel", (PyCFunction)THPVariable_numel, METH_NOARGS, NULL}, - {"numpy", (PyCFunction)THPVariable_numpy, METH_NOARGS, NULL}, - {"record_stream", (PyCFunction)THPVariable_record_stream, METH_O, NULL}, - {"requires_grad_", (PyCFunction)(void(*)(void))THPVariable_requires_grad_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"set_", (PyCFunction)(void (*)(void))THPVariable_set_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"short", (PyCFunction)(void(*)(void))THPVariable_short, METH_VARARGS | METH_KEYWORDS, NULL}, - {"size", (PyCFunction)(void(*)(void))THPVariable_size, METH_VARARGS | METH_KEYWORDS, NULL}, - {"storage", (PyCFunction)THPVariable_storage, METH_NOARGS, NULL}, - {"storage_offset", (PyCFunction)THPVariable_storage_offset, METH_NOARGS, NULL}, - {"storage_type", (PyCFunction)THPVariable_storage_type, METH_NOARGS, NULL}, - {"stride", (PyCFunction)(void(*)(void))THPVariable_stride, METH_VARARGS | METH_KEYWORDS, NULL}, - {"to", (PyCFunction)(void(*)(void))THPVariable_to, METH_VARARGS | METH_KEYWORDS, NULL}, - {"tolist", (PyCFunction)THPVariable_tolist, METH_NOARGS, NULL}, - {"type", (PyCFunction)(void(*)(void))THPVariable_type, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__add__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__radd__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__iadd__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__rmul__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__mul__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__imul__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__sub__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__isub__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__div__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__truediv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__floordiv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__idiv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__ifloordiv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__mod__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__imod__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__eq__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__ne__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__lt__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__le__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__gt__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__ge__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__bool__", THPVariable_bool_scalar, METH_NOARGS, NULL}, + {"__float__", THPVariable_float_scalar, METH_NOARGS, NULL}, + {"__complex__", THPVariable_complex_scalar, METH_NOARGS, NULL}, + {"__int__", THPVariable_integral_scalar, METH_NOARGS, NULL}, + {"__long__", THPVariable_integral_scalar, METH_NOARGS, NULL}, + {"__index__", THPVariable_index_scalar, METH_NOARGS, NULL}, + {"__nonzero__", THPVariable_bool_scalar, METH_NOARGS, NULL}, + {"__invert__", THPVariable_invert, METH_NOARGS, NULL}, + {"__matmul__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"_is_view", THPVariable__is_view, METH_NOARGS, NULL}, + {"apply_", THPVariable_apply_, METH_O, NULL}, + {"bfloat16", castPyCFunctionWithKeywords(THPVariable_bfloat16), METH_VARARGS | METH_KEYWORDS, NULL}, + {"byte", castPyCFunctionWithKeywords(THPVariable_byte), METH_VARARGS | METH_KEYWORDS, NULL}, + {"char", castPyCFunctionWithKeywords(THPVariable_char), METH_VARARGS | METH_KEYWORDS, NULL}, + {"contiguous", castPyCFunctionWithKeywords(THPVariable_contiguous), METH_VARARGS | METH_KEYWORDS, NULL}, + {"copy_", castPyCFunctionWithKeywords(THPVariable_copy_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"cpu", castPyCFunctionWithKeywords(THPVariable_cpu), METH_VARARGS | METH_KEYWORDS, NULL}, + {"cuda", castPyCFunctionWithKeywords(THPVariable_cuda), METH_VARARGS | METH_KEYWORDS, NULL}, + {"data_ptr", THPVariable_data_ptr, METH_NOARGS, NULL}, + {"dim", THPVariable_dim, METH_NOARGS, NULL}, + {"has_names", THPVariable_has_names, METH_NOARGS, NULL}, + {"double", castPyCFunctionWithKeywords(THPVariable_double), METH_VARARGS | METH_KEYWORDS, NULL}, + {"element_size", THPVariable_element_size, METH_NOARGS, NULL}, + {"float", castPyCFunctionWithKeywords(THPVariable_float), METH_VARARGS | METH_KEYWORDS, NULL}, + {"get_device", THPVariable_get_device, METH_NOARGS, NULL}, + {"bool", castPyCFunctionWithKeywords(THPVariable_bool), METH_VARARGS | METH_KEYWORDS, NULL}, + {"half", castPyCFunctionWithKeywords(THPVariable_half), METH_VARARGS | METH_KEYWORDS, NULL}, + {"int", castPyCFunctionWithKeywords(THPVariable_int), METH_VARARGS | METH_KEYWORDS, NULL}, + {"is_contiguous", castPyCFunctionWithKeywords(THPVariable_is_contiguous), METH_VARARGS | METH_KEYWORDS, NULL}, + {"item", THPVariable_item, METH_NOARGS, NULL}, + {"long", castPyCFunctionWithKeywords(THPVariable_long), METH_VARARGS | METH_KEYWORDS, NULL}, + {"map_", castPyCFunctionWithKeywords(THPVariable_map_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"map2_", castPyCFunctionWithKeywords(THPVariable_map2_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"ndimension", THPVariable_dim, METH_NOARGS, NULL}, + {"nelement", THPVariable_numel, METH_NOARGS, NULL}, + {"new", castPyCFunctionWithKeywords(THPVariable_new), METH_VARARGS | METH_KEYWORDS, NULL}, + {"new_ones", castPyCFunctionWithKeywords(THPVariable_new_ones), METH_VARARGS | METH_KEYWORDS, NULL}, + {"new_tensor", castPyCFunctionWithKeywords(THPVariable_new_tensor), METH_VARARGS | METH_KEYWORDS, NULL}, + {"nonzero", castPyCFunctionWithKeywords(THPVariable_nonzero), METH_VARARGS | METH_KEYWORDS, NULL}, + {"numel", THPVariable_numel, METH_NOARGS, NULL}, + {"numpy", THPVariable_numpy, METH_NOARGS, NULL}, + {"requires_grad_", castPyCFunctionWithKeywords(THPVariable_requires_grad_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"set_", castPyCFunctionWithKeywords(THPVariable_set_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"short", castPyCFunctionWithKeywords(THPVariable_short), METH_VARARGS | METH_KEYWORDS, NULL}, + {"size", castPyCFunctionWithKeywords(THPVariable_size), METH_VARARGS | METH_KEYWORDS, NULL}, + {"storage", THPVariable_storage, METH_NOARGS, NULL}, + {"storage_offset", THPVariable_storage_offset, METH_NOARGS, NULL}, + {"storage_type", THPVariable_storage_type, METH_NOARGS, NULL}, + {"stride", castPyCFunctionWithKeywords(THPVariable_stride), METH_VARARGS | METH_KEYWORDS, NULL}, + {"to", castPyCFunctionWithKeywords(THPVariable_to), METH_VARARGS | METH_KEYWORDS, NULL}, + {"tolist", THPVariable_tolist, METH_NOARGS, NULL}, + {"type", castPyCFunctionWithKeywords(THPVariable_type), METH_VARARGS | METH_KEYWORDS, NULL}, ${py_method_defs} {NULL} }; diff --git a/tools/autograd/utils.py b/tools/autograd/utils.py deleted file mode 100644 index 92f8fe89f5643..0000000000000 --- a/tools/autograd/utils.py +++ /dev/null @@ -1,90 +0,0 @@ -import re -import os -import yaml -from .nested_dict import nested_dict - - -__all__ = [ - 'CodeTemplate', 'IDENT_REGEX', 'YamlLoader', 'nested_dict', - 'split_name_params', 'write', -] - -from tools.codegen.code_template import CodeTemplate - -# You should use these lines, rather than doing it manually. -# Especially if you see this error! -# -# File "/usr/local/lib/python2.7/dist-packages/yaml/__init__.py", line 69, in load -# loader = Loader(stream) -# TypeError: 'module' object is not callable -try: - # use faster C loader if available - from yaml import CLoader as YamlLoader -except ImportError: - from yaml import Loader as YamlLoader - -GENERATED_COMMENT = CodeTemplate( - "@" + "generated from ${filename}") - -# Matches "foo" in "foo, bar" but not "foobar". Used to search for the -# occurrence of a parameter in the derivative formula -IDENT_REGEX = r'(^|\W){}($|\W)' - - -# TODO: Use a real parser here; this will get bamboozled -# by signatures that contain things like std::array (note the space) -def split_name_params(prototype): - name, overload_name, params = re.match(r'(\w+)(\.\w+)?\((.*)\)', prototype).groups() - return name, params.split(', ') - - -# When tracing, we record inplace operations as out-of-place operations, -# because we don't have a story for side effects in the IR yet. -# -# Doing this un-inplacing is a little delicate however; __and__ is NOT inplace! -# TODO: Do something more robust -def uninplace_api_name(api_name): - if api_name.endswith('_') and not api_name.endswith('__'): - api_name = api_name[:-1] - if api_name.endswith('_out'): - api_name = api_name[:-4] - return api_name - - -def write(dirname, name, template, env): - env['generated_comment'] = GENERATED_COMMENT.substitute(filename=template.filename) - path = os.path.join(dirname, name) - # See Note [Unchanging results for ninja] - try: - with open(path, 'r') as f: - old_val = f.read() - except IOError: - old_val = None - new_val = template.substitute(env) - if old_val != new_val: - with open(path, 'w') as f: - print("Writing {}".format(path)) - f.write(new_val) - else: - print("Skipped writing {}".format(path)) - -def is_tensor_method(declaration): - return 'Tensor' in declaration['method_of'] - -def is_out_variant(decl): - return decl['name'].endswith('_out') - -def op_name_without_overload(decl): - name = decl['name'] if not is_out_variant(decl) else decl['name'][:-4] - return 'aten::{}'.format(name) - -def load_op_list_and_strip_overload(op_list, op_list_path): - if op_list is None and op_list_path is None: - return None - if op_list is None: - op_list = [] - if op_list_path is not None: - with open(op_list_path, 'r') as f: - op_list += yaml.load(f, Loader=YamlLoader) - # strip out the overload part - return {opname.split('.', 1)[0] for opname in op_list} diff --git a/tools/build_pytorch_libs.py b/tools/build_pytorch_libs.py index e5f85df64a82c..2a40c90efdcde 100644 --- a/tools/build_pytorch_libs.py +++ b/tools/build_pytorch_libs.py @@ -1,5 +1,4 @@ import os -import sys from glob import glob import shutil @@ -8,22 +7,19 @@ def _overlay_windows_vcvars(env): - if sys.version_info >= (3, 5): - from distutils._msvccompiler import _get_vc_env - vc_arch = 'x64' if IS_64BIT else 'x86' - vc_env = _get_vc_env(vc_arch) - # Keys in `_get_vc_env` are always lowercase. - # We turn them into uppercase before overlaying vcvars - # because OS environ keys are always uppercase on Windows. - # https://stackoverflow.com/a/7797329 - vc_env = {k.upper(): v for k, v in vc_env.items()} - for k, v in env.items(): - uk = k.upper() - if uk not in vc_env: - vc_env[uk] = v - return vc_env - else: - return env + from distutils._msvccompiler import _get_vc_env + vc_arch = 'x64' if IS_64BIT else 'x86' + vc_env = _get_vc_env(vc_arch) + # Keys in `_get_vc_env` are always lowercase. + # We turn them into uppercase before overlaying vcvars + # because OS environ keys are always uppercase on Windows. + # https://stackoverflow.com/a/7797329 + vc_env = {k.upper(): v for k, v in vc_env.items()} + for k, v in env.items(): + uk = k.upper() + if uk not in vc_env: + vc_env[uk] = v + return vc_env def _create_build_env(): diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 26ab975373a89..7adb8de597923 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -7,9 +7,6 @@ GENERATED_CPP = [ "autograd/generated/VariableType_2.cpp", "autograd/generated/VariableType_3.cpp", "autograd/generated/VariableType_4.cpp", - "jit/generated/generated_unboxing_wrappers_0.cpp", - "jit/generated/generated_unboxing_wrappers_1.cpp", - "jit/generated/generated_unboxing_wrappers_2.cpp", "autograd/generated/TraceType_0.cpp", "autograd/generated/TraceType_1.cpp", "autograd/generated/TraceType_2.cpp", @@ -23,12 +20,22 @@ GENERATED_CPP = [ "autograd/generated/python_variable_methods.cpp", ] +# NVFuser runtime library +libtorch_nvfuser_runtime_sources = [ + "torch/csrc/jit/codegen/cuda/runtime/block_reduction.cu", + "torch/csrc/jit/codegen/cuda/runtime/broadcast.cu", + "torch/csrc/jit/codegen/cuda/runtime/fp16_support.cu", + "torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu", + "torch/csrc/jit/codegen/cuda/runtime/helpers.cu", + "torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu", + "torch/csrc/jit/codegen/cuda/runtime/tensor.cu", +] + +libtorch_nvfuser_generated_headers = ["{}.h".format(name[36:-3]) for name in libtorch_nvfuser_runtime_sources] + def libtorch_generated_sources(gencode_pattern): return [gencode_pattern.format(name) for name in [ "autograd/generated/Functions.cpp", - "jit/generated/generated_unboxing_wrappers_0.cpp", - "jit/generated/generated_unboxing_wrappers_1.cpp", - "jit/generated/generated_unboxing_wrappers_2.cpp", "autograd/generated/VariableType_0.cpp", "autograd/generated/VariableType_1.cpp", "autograd/generated/VariableType_2.cpp", @@ -39,13 +46,9 @@ def libtorch_generated_sources(gencode_pattern): "autograd/generated/TraceType_2.cpp", "autograd/generated/TraceType_3.cpp", "autograd/generated/TraceType_4.cpp", - ]] + [ - "torch/csrc/autograd/TraceTypeManual.cpp", - "torch/csrc/autograd/VariableTypeManual.cpp", - "torch/csrc/autograd/FunctionsManual.cpp", - ] + ]] -# copied from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/CMakeLists.txt +# copied from https://github.com/pytorch/pytorch/blob/f99a693cd9ff7a9b5fdc71357dac66b8192786d3/aten/src/ATen/core/CMakeLists.txt jit_core_headers = [ "torch/csrc/utils/memory.h", "torch/csrc/WindowsTorchApiMacro.h", @@ -73,12 +76,16 @@ jit_core_sources = [ "torch/csrc/jit/frontend/source_range.cpp", ] -# copied from https://github.com/pytorch/pytorch/blob/master/tools/cpp_build/torch/CMakeLists.txt +# copied from https://github.com/pytorch/pytorch/blob/0bde610c14b92d351b968a0228df29e92442b1cc/torch/CMakeLists.txt # There are some common files used in both internal lite-interpreter and full-jit. Making a separate # list for the shared files. core_sources_common = [ - "torch/csrc/autograd/profiler.cpp", + "torch/csrc/autograd/profiler_legacy.cpp", + "torch/csrc/autograd/profiler_kineto.cpp", + "torch/csrc/autograd/profiler_utils.cpp", + "torch/csrc/autograd/autograd_meta.cpp", + "torch/csrc/autograd/forward_grad.cpp", "torch/csrc/jit/frontend/edit_distance.cpp", "torch/csrc/jit/frontend/string_to_type.cpp", "torch/csrc/jit/mobile/type_parser.cpp", @@ -86,6 +93,7 @@ core_sources_common = [ "torch/csrc/jit/runtime/jit_exception.cpp", "torch/csrc/jit/runtime/operator.cpp", "torch/csrc/jit/runtime/print_handler.cpp", + "torch/csrc/jit/runtime/slice_indices_adjust.cpp", "torch/csrc/jit/runtime/register_ops_utils.cpp", "torch/csrc/jit/runtime/vararg_functions.cpp", "torch/csrc/jit/serialization/unpickler.cpp", @@ -115,7 +123,7 @@ core_trainer_sources = [ "torch/csrc/jit/serialization/type_name_uniquer.cpp", ] -core_sources_full = [ +core_sources_full_mobile = [ "torch/csrc/jit/api/function_impl.cpp", "torch/csrc/jit/api/module.cpp", "torch/csrc/jit/api/object.cpp", @@ -148,6 +156,8 @@ core_sources_full = [ "torch/csrc/jit/ir/scope.cpp", "torch/csrc/jit/ir/subgraph_matcher.cpp", "torch/csrc/jit/jit_log.cpp", + "torch/csrc/jit/jit_opt_limit.cpp", + "torch/csrc/jit/passes/annotate_warns.cpp", "torch/csrc/jit/passes/bailout_graph.cpp", "torch/csrc/jit/passes/batch_mm.cpp", "torch/csrc/jit/passes/canonicalize.cpp", @@ -164,7 +174,6 @@ core_sources_full = [ "torch/csrc/jit/passes/erase_number_types.cpp", "torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp", "torch/csrc/jit/passes/freeze_module.cpp", - "torch/csrc/jit/passes/reconstruct_scopes.cpp", "torch/csrc/jit/passes/fuse_linear.cpp", "torch/csrc/jit/passes/fuse_relu.cpp", "torch/csrc/jit/passes/graph_fuser.cpp", @@ -184,12 +193,15 @@ core_sources_full = [ "torch/csrc/jit/passes/lower_tuples.cpp", "torch/csrc/jit/passes/normalize_ops.cpp", "torch/csrc/jit/passes/peephole_list_idioms.cpp", + "torch/csrc/jit/passes/peephole_alias_sensitive.cpp", "torch/csrc/jit/passes/pass_manager.cpp", "torch/csrc/jit/passes/peephole.cpp", "torch/csrc/jit/passes/create_functional_graphs.cpp", "torch/csrc/jit/passes/remove_mutation.cpp", "torch/csrc/jit/passes/prepack_folding.cpp", "torch/csrc/jit/passes/fold_conv_bn.cpp", + "torch/csrc/jit/passes/frozen_conv_folding.cpp", + "torch/csrc/jit/passes/frozen_graph_optimizations.cpp", "torch/csrc/jit/passes/remove_expands.cpp", "torch/csrc/jit/passes/remove_dropout.cpp", "torch/csrc/jit/passes/requires_grad_analysis.cpp", @@ -202,6 +214,7 @@ core_sources_full = [ "torch/csrc/jit/passes/utils/subgraph_utils.cpp", "torch/csrc/jit/passes/xnnpack_rewrite.cpp", "torch/csrc/jit/passes/vulkan_rewrite.cpp", + "torch/csrc/jit/passes/metal_rewrite.cpp", "torch/csrc/jit/passes/quantization/helper.cpp", "torch/csrc/jit/passes/quantization/quantization_type.cpp", "torch/csrc/jit/passes/quantization/insert_observers.cpp", @@ -218,7 +231,6 @@ core_sources_full = [ "torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp", "torch/csrc/jit/runtime/profiling_record.cpp", "torch/csrc/jit/runtime/symbolic_script.cpp", - "torch/csrc/jit/runtime/static/impl.cpp", "torch/csrc/jit/serialization/import.cpp", "torch/csrc/jit/serialization/import_export_helpers.cpp", "torch/csrc/jit/serialization/import_source.cpp", @@ -226,10 +238,11 @@ core_sources_full = [ "torch/csrc/jit/serialization/python_print.cpp", "torch/csrc/jit/serialization/source_range_serialization.cpp", "torch/csrc/jit/tensorexpr/bounds_inference.cpp", + "torch/csrc/jit/tensorexpr/bounds_overlap.cpp", + "torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp", "torch/csrc/jit/tensorexpr/codegen.cpp", "torch/csrc/jit/tensorexpr/eval.cpp", "torch/csrc/jit/tensorexpr/expr.cpp", - "torch/csrc/jit/tensorexpr/function.cpp", "torch/csrc/jit/tensorexpr/hash_provider.cpp", "torch/csrc/jit/tensorexpr/ir.cpp", "torch/csrc/jit/tensorexpr/ir_mutator.cpp", @@ -242,6 +255,7 @@ core_sources_full = [ "torch/csrc/jit/tensorexpr/block_codegen.cpp", "torch/csrc/jit/tensorexpr/loopnest.cpp", "torch/csrc/jit/tensorexpr/mem_arena.cpp", + "torch/csrc/jit/tensorexpr/reduction.cpp", "torch/csrc/jit/tensorexpr/registerizer.cpp", "torch/csrc/jit/tensorexpr/tensor.cpp", "torch/csrc/jit/tensorexpr/types.cpp", @@ -252,6 +266,13 @@ core_sources_full = [ "torch/csrc/utils/variadic.cpp", ] +core_sources_full = core_sources_full_mobile + [ + "torch/csrc/jit/runtime/static/fusion.cpp", + "torch/csrc/jit/runtime/static/impl.cpp", + "torch/csrc/jit/runtime/static/ops.cpp", + "torch/csrc/jit/runtime/static/passes.cpp", +] + libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_trainer_sources) libtorch_distributed_sources = [ @@ -270,6 +291,8 @@ libtorch_distributed_sources = [ "torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp", "torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.cpp", "torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.cpp", + "torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.cpp", + "torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.cpp", "torch/csrc/distributed/rpc/message.cpp", "torch/csrc/distributed/rpc/profiler/remote_profiler_manager.cpp", "torch/csrc/distributed/rpc/profiler/server_process_global_profiler.cpp", @@ -280,8 +303,8 @@ libtorch_distributed_sources = [ "torch/csrc/distributed/rpc/request_callback_no_python.cpp", "torch/csrc/distributed/rpc/rpc_agent.cpp", "torch/csrc/distributed/rpc/rref_context.cpp", - "torch/csrc/distributed/rpc/rref_proto.cpp", "torch/csrc/distributed/rpc/rref_impl.cpp", + "torch/csrc/distributed/rpc/rref_proto.cpp", "torch/csrc/distributed/rpc/script_call.cpp", "torch/csrc/distributed/rpc/script_remote_call.cpp", "torch/csrc/distributed/rpc/script_resp.cpp", @@ -326,14 +349,14 @@ libtorch_extra_sources = libtorch_core_jit_sources + [ "torch/csrc/jit/serialization/export_module.cpp", "torch/csrc/jit/serialization/import_legacy.cpp", "torch/csrc/utils/byte_order.cpp", + "torch/csrc/utils/out_types.cpp", ] def libtorch_sources(gencode_pattern = ":generate-code[{}]"): return libtorch_generated_sources(gencode_pattern) + libtorch_core_sources + libtorch_distributed_sources + libtorch_extra_sources -libtorch_cuda_sources = [ +libtorch_cuda_core_sources = [ "torch/csrc/cuda/comm.cpp", - "torch/csrc/cuda/nccl.cpp", "torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp", "torch/csrc/autograd/profiler_cuda.cpp", "torch/csrc/autograd/functions/comm.cpp", @@ -360,8 +383,10 @@ libtorch_cuda_sources = [ "torch/csrc/jit/codegen/cuda/kernel_cache.cpp", "torch/csrc/jit/codegen/cuda/kernel_ir.cpp", "torch/csrc/jit/codegen/cuda/kernel_ir_builder.cpp", + "torch/csrc/jit/codegen/cuda/kernel_ir_printer.cpp", "torch/csrc/jit/codegen/cuda/lower_index.cpp", "torch/csrc/jit/codegen/cuda/lower_loops.cpp", + "torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp", "torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp", "torch/csrc/jit/codegen/cuda/lower_unroll.cpp", "torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp", @@ -382,6 +407,11 @@ libtorch_cuda_sources = [ "torch/csrc/jit/codegen/cuda/transform_rfactor.cpp", "torch/csrc/jit/codegen/cuda/type.cpp", "torch/csrc/jit/tensorexpr/cuda_codegen.cpp", + "torch/csrc/jit/runtime/register_cuda_ops.cpp", +] + +libtorch_cuda_sources = libtorch_cuda_core_sources + [ + "torch/csrc/cuda/nccl.cpp", ] torch_cpp_srcs = [ @@ -448,6 +478,7 @@ libtorch_python_cuda_core_sources = [ "torch/csrc/cuda/python_comm.cpp", "torch/csrc/cuda/Storage.cpp", "torch/csrc/cuda/Stream.cpp", + "torch/csrc/cuda/Graph.cpp", "torch/csrc/cuda/serialization.cpp", "torch/csrc/cuda/shared/cudart.cpp", "torch/csrc/cuda/shared/nvtx.cpp", @@ -472,10 +503,10 @@ libtorch_python_core_sources = [ "torch/csrc/MemoryFormat.cpp", "torch/csrc/QScheme.cpp", "torch/csrc/Module.cpp", - "torch/csrc/PtrWrapper.cpp", "torch/csrc/python_dimname.cpp", "torch/csrc/Size.cpp", "torch/csrc/Storage.cpp", + "torch/csrc/Stream.cpp", "torch/csrc/TypeInfo.cpp", "torch/csrc/api/src/python/init.cpp", "torch/csrc/autograd/functions/init.cpp", @@ -497,7 +528,9 @@ libtorch_python_core_sources = [ "torch/csrc/jit/passes/onnx/constant_fold.cpp", "torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp", "torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp", + "torch/csrc/jit/passes/onnx/list_model_parameters.cpp", "torch/csrc/jit/passes/onnx/function_substitution.cpp", + "torch/csrc/jit/passes/onnx/fold_if_node.cpp", "torch/csrc/jit/passes/onnx/helper.cpp", "torch/csrc/jit/passes/onnx/peephole.cpp", "torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp", @@ -506,6 +539,7 @@ libtorch_python_core_sources = [ "torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp", "torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp", "torch/csrc/jit/passes/onnx/shape_type_inference.cpp", + "torch/csrc/jit/python/pybind_utils.cpp", "torch/csrc/jit/python/python_arg_flatten.cpp", "torch/csrc/jit/python/python_custom_class.cpp", "torch/csrc/jit/python/python_interpreter.cpp", @@ -516,6 +550,7 @@ libtorch_python_core_sources = [ "torch/csrc/jit/python/python_sugared_value.cpp", "torch/csrc/jit/python/python_tree_views.cpp", "torch/csrc/jit/runtime/static/init.cpp", + "torch/csrc/jit/tensorexpr/tensorexpr_init.cpp", "torch/csrc/multiprocessing/init.cpp", "torch/csrc/onnx/init.cpp", "torch/csrc/serialization.cpp", @@ -541,11 +576,17 @@ libtorch_python_core_sources = [ "torch/csrc/utils/disable_torch_function.cpp", ] -libtorch_python_distributed_sources = [ - "torch/csrc/distributed/autograd/init.cpp", - "torch/csrc/distributed/c10d/comm.cpp", +libtorch_python_distributed_core_sources = [ + "torch/lib/c10d/comm.cpp", + "torch/lib/c10d/default_comm_hooks.cpp", + "torch/lib/c10d/frontend.cpp", + "torch/lib/c10d/reducer.cpp", + "torch/csrc/distributed/c10d/python_comm_hook.cpp", "torch/csrc/distributed/c10d/init.cpp", - "torch/csrc/distributed/c10d/reducer.cpp", +] + +libtorch_python_distributed_sources = libtorch_python_distributed_core_sources + [ + "torch/csrc/distributed/autograd/init.cpp", "torch/csrc/distributed/rpc/init.cpp", "torch/csrc/distributed/rpc/process_group_agent.cpp", "torch/csrc/distributed/rpc/py_rref.cpp", diff --git a/tools/clang_format_all.py b/tools/clang_format_all.py index 710a21e335142..77ca68d92b0b5 100755 --- a/tools/clang_format_all.py +++ b/tools/clang_format_all.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """ -A script that runs clang-format on all C/C++ files in CLANG_FORMAT_WHITELIST. There is +A script that runs clang-format on all C/C++ files in CLANG_FORMAT_ALLOWLIST. There is also a diff mode which simply checks if clang-format would make any changes, which is useful for CI purposes. @@ -14,22 +14,22 @@ import sys from clang_format_utils import get_and_check_clang_format, CLANG_FORMAT_PATH -# Whitelist of directories to check. All files that in that directory +# Allowlist of directories to check. All files that in that directory # (recursively) will be checked. -# If you edit this, please edit the whitelist in clang_format_ci.sh as well. -CLANG_FORMAT_WHITELIST = ["torch/csrc/jit/", "test/cpp/jit/", "test/cpp/tensorexpr/"] +# If you edit this, please edit the allowlist in clang_format_ci.sh as well. +CLANG_FORMAT_ALLOWLIST = ["torch/csrc/jit/", "test/cpp/jit/", "test/cpp/tensorexpr/"] # Only files with names matching this regex will be formatted. CPP_FILE_REGEX = re.compile(".*\\.(h|cpp|cc|c|hpp)$") -def get_whitelisted_files(): +def get_allowlisted_files(): """ - Parse CLANG_FORMAT_WHITELIST and resolve all directories. - Returns the set of whitelist cpp source files. + Parse CLANG_FORMAT_ALLOWLIST and resolve all directories. + Returns the set of allowlist cpp source files. """ matches = [] - for dir in CLANG_FORMAT_WHITELIST: + for dir in CLANG_FORMAT_ALLOWLIST: for root, dirnames, filenames in os.walk(dir): for filename in filenames: if CPP_FILE_REGEX.match(filename): @@ -77,7 +77,7 @@ async def file_clang_formatted_correctly(filename, semaphore, verbose=False): async def run_clang_format(max_processes, diff=False, verbose=False): """ - Run clang-format to all files in CLANG_FORMAT_WHITELIST that match CPP_FILE_REGEX. + Run clang-format to all files in CLANG_FORMAT_ALLOWLIST that match CPP_FILE_REGEX. """ # Check to make sure the clang-format binary exists. if not os.path.exists(CLANG_FORMAT_PATH): @@ -97,7 +97,7 @@ async def run_clang_format(max_processes, diff=False, verbose=False): # Format files in parallel. if diff: - for f in asyncio.as_completed([file_clang_formatted_correctly(f, semaphore, verbose) for f in get_whitelisted_files()]): + for f in asyncio.as_completed([file_clang_formatted_correctly(f, semaphore, verbose) for f in get_allowlisted_files()]): ok &= await f if ok: @@ -105,7 +105,7 @@ async def run_clang_format(max_processes, diff=False, verbose=False): else: print("Some files not formatted correctly") else: - await asyncio.gather(*[run_clang_format_on_file(f, semaphore, verbose) for f in get_whitelisted_files()]) + await asyncio.gather(*[run_clang_format_on_file(f, semaphore, verbose) for f in get_allowlisted_files()]) return ok @@ -134,7 +134,7 @@ def main(args): options = parse_args(args) # Get clang-format and make sure it is the right binary and it is in the right place. ok = get_and_check_clang_format(options.verbose) - # Invoke clang-format on all files in the directories in the whitelist. + # Invoke clang-format on all files in the directories in the allowlist. if ok: loop = asyncio.get_event_loop() ok = loop.run_until_complete(run_clang_format(options.max_processes, options.diff, options.verbose)) diff --git a/tools/clang_format_hash/linux64/clang-format-linux64 b/tools/clang_format_hash/linux64/clang-format-linux64 index 6a1e2ca2fd308..eb3292de7846e 100644 --- a/tools/clang_format_hash/linux64/clang-format-linux64 +++ b/tools/clang_format_hash/linux64/clang-format-linux64 @@ -1 +1 @@ -d1365110da598d148d8143a7f2ccfd8bac7df499 \ No newline at end of file +9073602de1c4e1748f2feea5a0782417b20e3043 \ No newline at end of file diff --git a/tools/clang_format_hash/mac/clang-format-mojave b/tools/clang_format_hash/mac/clang-format-mojave index 30801a239bf90..f3dfcc3c1ae3c 100644 --- a/tools/clang_format_hash/mac/clang-format-mojave +++ b/tools/clang_format_hash/mac/clang-format-mojave @@ -1 +1 @@ -020c7f38f14665c2ed82f3e8976c9074c2cfac0a \ No newline at end of file +b24cc8972344c4e01afbbae78d6a414f7638ff6f \ No newline at end of file diff --git a/tools/code_analyzer/build.sh b/tools/code_analyzer/build.sh index 88c107bf130ca..31c2f7c5a6c3b 100755 --- a/tools/code_analyzer/build.sh +++ b/tools/code_analyzer/build.sh @@ -101,25 +101,8 @@ analyze_torch_mobile() { call_analyzer } -convert_output_to_bazel() { - cd "${SRC_ROOT}" - - DEST="${BUILD_ROOT}/pt_deps.bzl" - - args=( - --op_dependency "${OUTPUT}" - --output "${DEST}" - ) - - if [ -n "${BASE_OPS_FILE}" ] && [ -f "${BASE_OPS_FILE}" ]; then - args+=( - --base_ops $(< ${BASE_OPS_FILE}) - ) - fi - - python -m tools.code_analyzer.op_deps_processor "${args[@]}" - - echo "Deployed file at: ${DEST}" +print_output_file_path() { + echo "Deployed file at: ${OUTPUT}" } analyze_test_project() { @@ -153,7 +136,7 @@ if [ -n "${ANALYZE_TORCH}" ]; then build_torch_mobile analyze_torch_mobile if [ -n "${DEPLOY}" ]; then - convert_output_to_bazel + print_output_file_path fi fi diff --git a/tools/code_analyzer/default_op_deps.yaml b/tools/code_analyzer/default_op_deps.yaml index 8a71f33bcca80..d6fac569cc177 100644 --- a/tools/code_analyzer/default_op_deps.yaml +++ b/tools/code_analyzer/default_op_deps.yaml @@ -1,20 +1,27 @@ - name: __ROOT__ depends: + - name: aten::_coalesced_ - name: aten::_empty_affine_quantized - name: aten::_empty_per_channel_affine_quantized - name: aten::_indices + - name: aten::_mkldnn_transpose - name: aten::_sparse_coo_tensor_unsafe - name: aten::_values - name: aten::_version - name: aten::add - name: aten::add_ + - name: aten::addmm_ - name: aten::any + - name: aten::as_strided - name: aten::as_strided_ - name: aten::cat - name: aten::chunk + - name: aten::clamp_max + - name: aten::clamp_min - name: aten::clone - name: aten::contiguous - name: aten::copy_ + - name: aten::dense_dim - name: aten::dequantize - name: aten::detach - name: aten::empty @@ -23,7 +30,6 @@ - name: aten::eq - name: aten::equal - name: aten::expand - - name: aten::fill_ - name: aten::is_complex - name: aten::is_floating_point - name: aten::is_leaf @@ -33,8 +39,8 @@ - name: aten::lt - name: aten::mm - name: aten::mul + - name: aten::mul_ - name: aten::narrow - - name: aten::ones - name: aten::ones_like - name: aten::output_nr - name: aten::q_per_channel_axis @@ -52,19 +58,36 @@ - name: aten::set_ - name: aten::set_data - name: aten::size + - name: aten::sparse_dim - name: aten::stride - name: aten::sub - name: aten::sum - name: aten::t - name: aten::to + - name: aten::transpose + - name: aten::transpose_ - name: aten::view - name: aten::zero_ - name: aten::zeros - name: aten::zeros_like - name: _quantized::add depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::to - name: _quantized::conv2d depends: - name: aten::eq @@ -98,10 +121,53 @@ - name: aten::is_nonzero - name: aten::squeeze_ - name: aten::unsqueeze +- name: _quantized::conv_transpose1d_prepack + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::equal + - name: aten::is_nonzero + - name: aten::item + - name: aten::mul + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::unsqueeze + - name: aten::zeros - name: _quantized::conv_transpose2d depends: - name: aten::eq - name: aten::is_nonzero +- name: _quantized::conv_transpose2d_prepack + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::equal + - name: aten::is_nonzero + - name: aten::item + - name: aten::mul + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::zeros +- name: _quantized::conv_transpose3d_prepack + depends: + - name: aten::eq + - name: aten::is_nonzero - name: _quantized::linear depends: - name: aten::eq @@ -112,8 +178,28 @@ - name: aten::is_nonzero - name: _quantized::linear_prepack depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty - name: aten::eq + - name: aten::equal - name: aten::is_nonzero + - name: aten::item + - name: aten::max + - name: aten::min + - name: aten::mul + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::quantize_per_tensor + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::zeros - name: _quantized::linear_prepack_fp16 depends: - name: aten::eq @@ -140,11 +226,6 @@ - name: aten::eq - name: aten::is_nonzero - name: aten::leaky_relu -- name: aten::Int - depends: - - name: aten::eq - - name: aten::is_nonzero - - name: aten::item - name: aten::__and__ depends: - name: aten::bitwise_and @@ -163,11 +244,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -185,11 +262,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -207,11 +280,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -229,11 +298,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -266,57 +331,37 @@ depends: - name: aten::eq - name: aten::is_nonzero -- name: aten::_addmv_impl_ - depends: - - name: aten::contiguous - - name: aten::eq - - name: aten::is_nonzero - - name: aten::size - - name: aten::stride -- name: aten::_addr +- name: aten::_add_relu depends: - - name: aten::_copy_from - name: aten::as_strided_ - name: aten::copy_ - - name: aten::copy_sparse_to_sparse_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::q_scale - - name: aten::q_zero_point - - name: aten::qscheme - name: aten::resize_ - name: aten::resize_as_ - - name: aten::set_quantizer_ - - name: aten::size - - name: aten::stride - name: aten::to - - name: aten::zero_ -- name: aten::_addr_ +- name: aten::_add_relu_ depends: - - name: aten::_copy_from - name: aten::as_strided_ - name: aten::copy_ - - name: aten::copy_sparse_to_sparse_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::q_scale - - name: aten::q_zero_point - - name: aten::qscheme - name: aten::resize_ - name: aten::resize_as_ - - name: aten::set_quantizer_ + - name: aten::to +- name: aten::_addmv_impl_ + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero - name: aten::size - name: aten::stride - - name: aten::to - - name: aten::zero_ - name: aten::_aminmax depends: - name: aten::as_strided_ @@ -336,7 +381,7 @@ - name: aten::stride - name: aten::to - name: aten::unsqueeze_ -- name: aten::_amp_non_finite_check_and_unscale_ +- name: aten::_amp_foreach_non_finite_check_and_unscale_ depends: - name: aten::eq - name: aten::is_nonzero @@ -344,6 +389,10 @@ depends: - name: aten::eq - name: aten::is_nonzero +- name: aten::_backward + depends: + - name: aten::eq + - name: aten::is_nonzero - name: aten::_baddbmm_mkl_ depends: - name: aten::eq @@ -690,9 +739,16 @@ - name: aten::size - name: aten::_dirichlet_grad depends: + - name: aten::as_strided_ + - name: aten::copy_ - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to - name: aten::_embedding_bag depends: - name: aten::contiguous @@ -896,6 +952,7 @@ - name: aten::unsqueeze - name: aten::_fft_with_size depends: + - name: aten::_fft_with_size - name: aten::eq - name: aten::is_nonzero - name: aten::_foreach_add @@ -948,6 +1005,16 @@ - name: aten::eq - name: aten::exp_ - name: aten::is_nonzero +- name: aten::_foreach_maximum + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::maximum +- name: aten::_foreach_minimum + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::minimum - name: aten::_foreach_mul depends: - name: aten::eq @@ -1186,15 +1253,6 @@ - name: aten::eq - name: aten::is_nonzero - name: aten::set_ -- name: aten::_multinomial_alias_draw - depends: - - name: aten::eq - - name: aten::is_nonzero -- name: aten::_multinomial_alias_setup - depends: - - name: aten::eq - - name: aten::is_nonzero - - name: aten::set_ - name: aten::_nnpack_available depends: - name: aten::eq @@ -1320,10 +1378,18 @@ - name: aten::to - name: aten::_sample_dirichlet depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::expand - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ - name: aten::sum + - name: aten::to - name: aten::zeros - name: aten::_saturate_weight_to_fp16 depends: @@ -1497,14 +1563,29 @@ - name: aten::is_nonzero - name: aten::_standard_gamma depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to - name: aten::zeros - name: aten::_standard_gamma_grad depends: + - name: aten::as_strided_ + - name: aten::copy_ - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to - name: aten::_std depends: - name: aten::eq @@ -1557,6 +1638,10 @@ - name: aten::is_nonzero - name: aten::mul - name: aten::sub +- name: aten::_test_string_default + depends: + - name: aten::eq + - name: aten::is_nonzero - name: aten::_thnn_differentiable_gru_cell_backward depends: - name: aten::add @@ -1737,6 +1822,7 @@ depends: - name: aten::abs - name: aten::eq + - name: aten::is_complex - name: aten::is_nonzero - name: aten::absolute depends: @@ -1752,6 +1838,7 @@ depends: - name: aten::acos - name: aten::as_strided_ + - name: aten::copy_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided @@ -1760,6 +1847,7 @@ - name: aten::is_leaf - name: aten::is_nonzero - name: aten::resize_ + - name: aten::resize_as_ - name: aten::to - name: aten::acos_ depends: @@ -1768,7 +1856,6 @@ - name: aten::is_nonzero - name: aten::acosh depends: - - name: aten::acosh - name: aten::as_strided_ - name: aten::copy_ - name: aten::empty @@ -1874,11 +1961,7 @@ - name: aten::empty_meta - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::permute - name: aten::resize_ - name: aten::resize_as_ @@ -1894,39 +1977,11 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to -- name: aten::add_relu - depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - - name: aten::eq - - name: aten::is_nonzero - - name: aten::resize_ - - name: aten::resize_as_ - - name: aten::to -- name: aten::add_relu_ - depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - - name: aten::eq - - name: aten::is_nonzero - - name: aten::resize_ - - name: aten::resize_as_ - - name: aten::to - name: aten::addbmm depends: - name: aten::addbmm @@ -2082,18 +2137,20 @@ - name: aten::zero_ - name: aten::addr depends: - - name: aten::_addr + - name: aten::add - name: aten::addr + - name: aten::copy_ - name: aten::eq - - name: aten::expand - name: aten::is_floating_point - name: aten::is_leaf - name: aten::is_nonzero - - name: aten::size + - name: aten::mul + - name: aten::outer + - name: aten::resize_ - name: aten::to - name: aten::addr_ depends: - - name: aten::_addr_ + - name: aten::addr - name: aten::eq - name: aten::is_nonzero - name: aten::affine_grid_generator @@ -2362,8 +2419,10 @@ - name: aten::sort - name: aten::as_strided depends: + - name: aten::as_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::permute - name: aten::as_strided_ depends: - name: aten::eq @@ -2389,7 +2448,6 @@ - name: aten::asinh depends: - name: aten::as_strided_ - - name: aten::asinh - name: aten::copy_ - name: aten::empty - name: aten::empty_like @@ -2408,12 +2466,14 @@ depends: - name: aten::as_strided_ - name: aten::atan + - name: aten::copy_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero - name: aten::resize_ + - name: aten::resize_as_ - name: aten::to - name: aten::atan2 depends: @@ -2450,7 +2510,6 @@ - name: aten::atanh depends: - name: aten::as_strided_ - - name: aten::atanh - name: aten::copy_ - name: aten::empty - name: aten::empty_like @@ -2529,10 +2588,6 @@ - name: aten::size - name: aten::zero_ - name: aten::zeros_like -- name: aten::backward - depends: - - name: aten::eq - - name: aten::is_nonzero - name: aten::baddbmm depends: - name: aten::addmm_ @@ -2746,8 +2801,16 @@ - name: aten::zero_ - name: aten::binomial depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to - name: aten::zeros - name: aten::bitwise_and depends: @@ -2758,11 +2821,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -2799,11 +2858,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -2822,11 +2877,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -2892,20 +2943,11 @@ - name: aten::is_nonzero - name: aten::bucketize depends: - - name: aten::as_strided_ - name: aten::contiguous - - name: aten::copy_ - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to - name: aten::can_cast @@ -3075,6 +3117,14 @@ - name: aten::is_nonzero - name: aten::resize_as_ - name: aten::size +- name: aten::choose_qparams_optimized + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::select - name: aten::chunk depends: - name: aten::chunk @@ -3164,6 +3214,7 @@ - name: aten::_empty_affine_quantized - name: aten::_empty_per_channel_affine_quantized - name: aten::as_strided_ + - name: aten::clone - name: aten::copy_ - name: aten::copy_sparse_to_sparse_ - name: aten::empty @@ -3172,6 +3223,7 @@ - name: aten::eq - name: aten::is_complex - name: aten::is_nonzero + - name: aten::permute - name: aten::q_per_channel_axis - name: aten::q_per_channel_scales - name: aten::q_per_channel_zero_points @@ -3208,6 +3260,12 @@ - name: aten::select - name: aten::size - name: aten::zero_ +- name: aten::column_stack + depends: + - name: aten::eq + - name: aten::hstack + - name: aten::is_nonzero + - name: aten::reshape - name: aten::combinations depends: - name: aten::arange @@ -3263,10 +3321,12 @@ - name: aten::size - name: aten::contiguous depends: + - name: aten::contiguous - name: aten::copy_ - name: aten::empty_like - name: aten::eq - name: aten::is_nonzero + - name: aten::permute - name: aten::conv1d depends: - name: aten::conv1d @@ -3386,21 +3446,40 @@ - name: aten::size - name: aten::stride - name: aten::to -- name: aten::copy_imag +- name: aten::copy_sparse_to_sparse_ depends: - name: aten::eq - name: aten::is_nonzero -- name: aten::copy_real +- name: aten::copysign depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero -- name: aten::copy_sparse_to_sparse_ + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::copysign_ depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to - name: aten::cos depends: - name: aten::as_strided_ + - name: aten::copy_ - name: aten::cos - name: aten::empty - name: aten::empty_like @@ -3408,6 +3487,7 @@ - name: aten::eq - name: aten::is_nonzero - name: aten::resize_ + - name: aten::resize_as_ - name: aten::to - name: aten::cos_ depends: @@ -3648,22 +3728,13 @@ - name: aten::is_nonzero - name: aten::deg2rad depends: - - name: aten::as_strided_ - - name: aten::copy_ - name: aten::deg2rad - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - name: aten::is_complex - name: aten::is_nonzero - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::deg2rad_ depends: - name: aten::deg2rad @@ -3684,7 +3755,6 @@ - name: aten::add_ - name: aten::all - name: aten::arange - - name: aten::contiguous - name: aten::diagonal - name: aten::eq - name: aten::fmod_ @@ -3799,11 +3869,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::permute - name: aten::resize_ - name: aten::resize_as_ @@ -3819,15 +3885,21 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to +- name: aten::divide + depends: + - name: aten::div + - name: aten::eq + - name: aten::is_nonzero +- name: aten::divide_ + depends: + - name: aten::div_ + - name: aten::eq + - name: aten::is_nonzero - name: aten::dot depends: - name: aten::dot @@ -3897,15 +3969,20 @@ depends: - name: aten::bmm - name: aten::diagonal + - name: aten::dot - name: aten::eq + - name: aten::flatten - name: aten::is_nonzero + - name: aten::movedim - name: aten::mul - name: aten::permute - name: aten::reshape - name: aten::size + - name: aten::squeeze - name: aten::sum - name: aten::unsqueeze - name: aten::view + - name: aten::zeros - name: aten::elu depends: - name: aten::as_strided_ @@ -3940,6 +4017,7 @@ - name: aten::eq - name: aten::index_select - name: aten::is_nonzero + - name: aten::masked_fill_ - name: aten::reshape - name: aten::view - name: aten::embedding_backward @@ -3984,6 +4062,7 @@ - name: aten::ne - name: aten::reshape - name: aten::size + - name: aten::to - name: aten::empty depends: - name: aten::empty @@ -4044,33 +4123,17 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to - name: aten::eq_ depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::equal depends: - name: aten::as_strided_ @@ -4363,10 +4426,261 @@ - name: aten::unsqueeze - name: aten::fft_fft depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_fft2 + depends: + - name: aten::_fft_with_size + - name: aten::constant_pad_nd - name: aten::eq - - name: aten::fft - name: aten::is_complex - name: aten::is_nonzero + - name: aten::movedim + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_fftfreq + depends: + - name: aten::arange + - name: aten::empty + - name: aten::eq + - name: aten::fft_fftfreq + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::slice +- name: aten::fft_fftn + depends: + - name: aten::_fft_with_size + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::movedim + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_fftshift + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::roll +- name: aten::fft_hfft + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_real +- name: aten::fft_ifft + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_ifft2 + depends: + - name: aten::_fft_with_size + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::movedim + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_ifftn + depends: + - name: aten::_fft_with_size + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::movedim + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_ifftshift + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::roll +- name: aten::fft_ihfft + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_complex +- name: aten::fft_irfft + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_real +- name: aten::fft_irfft2 + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::movedim + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_irfftn + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::movedim + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_rfft + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_complex +- name: aten::fft_rfft2 + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::movedim + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fft_rfftfreq + depends: + - name: aten::arange + - name: aten::empty + - name: aten::eq + - name: aten::fft_rfftfreq + - name: aten::is_nonzero + - name: aten::mul_ +- name: aten::fft_rfftn + depends: + - name: aten::_fft_with_size + - name: aten::conj + - name: aten::constant_pad_nd + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::movedim + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze - name: aten::view_as_complex - name: aten::view_as_real - name: aten::fill_ @@ -4461,13 +4775,9 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - name: aten::floor_divide - - name: aten::is_complex - name: aten::is_floating_point - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -4481,13 +4791,9 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - name: aten::floor_divide - - name: aten::is_complex - name: aten::is_floating_point - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -4566,6 +4872,8 @@ - name: aten::frobenius_norm depends: - name: aten::conj + - name: aten::copy_ + - name: aten::empty - name: aten::eq - name: aten::frobenius_norm - name: aten::is_complex @@ -4575,6 +4883,7 @@ - name: aten::mul - name: aten::norm - name: aten::real + - name: aten::resize_ - name: aten::sqrt - name: aten::sum - name: aten::to @@ -4643,35 +4952,19 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - name: aten::ge - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to - name: aten::ge_ depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - name: aten::ge - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::gelu depends: - name: aten::_empty_affine_quantized @@ -4759,13 +5052,9 @@ - name: aten::to - name: aten::ger depends: - - name: aten::_addr - - name: aten::empty - name: aten::eq - - name: aten::ger - name: aten::is_nonzero - - name: aten::resize_ - - name: aten::size + - name: aten::outer - name: aten::get_gradients depends: - name: aten::eq @@ -4927,35 +5216,19 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - name: aten::gt - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to - name: aten::gt_ depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - name: aten::gt - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::hamming_window depends: - name: aten::add_ @@ -5225,6 +5498,23 @@ - name: aten::size - name: aten::squeeze - name: aten::unsqueeze +- name: aten::igamma + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::igamma_ + depends: + - name: aten::eq + - name: aten::igamma + - name: aten::is_nonzero - name: aten::im2col depends: - name: aten::contiguous @@ -5248,6 +5538,7 @@ - name: aten::imag depends: - name: aten::eq + - name: aten::imag - name: aten::is_complex - name: aten::is_nonzero - name: aten::select @@ -5408,9 +5699,11 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq + - name: aten::fill_ - name: aten::is_nonzero - name: aten::resize_ - name: aten::resize_as_ + - name: aten::select - name: aten::to - name: aten::inverse depends: @@ -5575,6 +5868,7 @@ - name: aten::constant_pad_nd - name: aten::div - name: aten::eq + - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - name: aten::min @@ -5589,6 +5883,8 @@ - name: aten::transpose - name: aten::unsqueeze - name: aten::view + - name: aten::view_as_complex + - name: aten::view_as_real - name: aten::istitle depends: - name: aten::eq @@ -5612,6 +5908,22 @@ depends: - name: aten::eq - name: aten::is_nonzero +- name: aten::kaiser_window + depends: + - name: aten::arange + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::kaiser_window + - name: aten::narrow + - name: aten::ones + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to - name: aten::kl_div depends: - name: aten::eq @@ -5631,14 +5943,32 @@ - name: aten::zeros_like - name: aten::kl_div_backward depends: + - name: aten::as_strided_ + - name: aten::copy_ - name: aten::div + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::exp - name: aten::expand_as - name: aten::is_nonzero - name: aten::mul - name: aten::neg + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to - name: aten::zeros_like +- name: aten::kron + depends: + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::kron + - name: aten::permute + - name: aten::reshape + - name: aten::resize_ + - name: aten::tensordot - name: aten::kthvalue depends: - name: aten::clone @@ -5712,35 +6042,19 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - name: aten::le - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to - name: aten::le_ depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::le - - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::leaky_relu depends: - name: aten::_empty_affine_quantized @@ -5852,10 +6166,10 @@ - name: aten::linalg_norm depends: - name: aten::abs - - name: aten::add_ - name: aten::copy_ - name: aten::empty - name: aten::eq + - name: aten::fill_ - name: aten::flatten - name: aten::frobenius_norm - name: aten::is_nonzero @@ -5864,12 +6178,21 @@ - name: aten::norm - name: aten::nuclear_norm - name: aten::permute - - name: aten::pow - name: aten::resize_ - name: aten::sum - name: aten::svd - name: aten::to - name: aten::unsqueeze_ +- name: aten::linalg_tensorsolve + depends: + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::linalg_tensorsolve + - name: aten::movedim + - name: aten::reshape + - name: aten::resize_ + - name: aten::solve - name: aten::linear depends: - name: aten::add_ @@ -5909,6 +6232,7 @@ - name: aten::log depends: - name: aten::as_strided_ + - name: aten::copy_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided @@ -5918,10 +6242,12 @@ - name: aten::is_nonzero - name: aten::log - name: aten::resize_ + - name: aten::resize_as_ - name: aten::to - name: aten::log10 depends: - name: aten::as_strided_ + - name: aten::copy_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided @@ -5931,6 +6257,7 @@ - name: aten::is_nonzero - name: aten::log10 - name: aten::resize_ + - name: aten::resize_as_ - name: aten::to - name: aten::log10_ depends: @@ -5958,6 +6285,7 @@ - name: aten::log2 depends: - name: aten::as_strided_ + - name: aten::copy_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided @@ -5967,6 +6295,7 @@ - name: aten::is_nonzero - name: aten::log2 - name: aten::resize_ + - name: aten::resize_as_ - name: aten::to - name: aten::log2_ depends: @@ -6065,7 +6394,6 @@ - name: aten::add_ - name: aten::all - name: aten::arange - - name: aten::contiguous - name: aten::diagonal - name: aten::eq - name: aten::fill_ @@ -6290,35 +6618,19 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - name: aten::lt - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to - name: aten::lt_ depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::lt - - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::lu_solve depends: - name: aten::_lu_solve_helper @@ -6575,6 +6887,7 @@ - name: aten::eq - name: aten::is_nonzero - name: aten::max_pool1d_with_indices + - name: aten::quantized_max_pool1d - name: aten::size - name: aten::squeeze_ - name: aten::max_pool1d_with_indices @@ -6709,15 +7022,18 @@ - name: aten::median depends: - name: aten::clone + - name: aten::contiguous - name: aten::empty - name: aten::eq - - name: aten::fill_ - name: aten::is_nonzero - - name: aten::kthvalue - name: aten::median + - name: aten::resize_ - name: aten::select - name: aten::size - - name: aten::view + - name: aten::squeeze_ + - name: aten::stride + - name: aten::transpose_ + - name: aten::unsqueeze - name: aten::meshgrid depends: - name: aten::eq @@ -6945,11 +7261,8 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::mul - - name: aten::ones - name: aten::permute - name: aten::resize_ - name: aten::resize_as_ @@ -6964,11 +7277,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -7048,6 +7357,16 @@ - name: aten::sum - name: aten::topk - name: aten::uniform_ +- name: aten::multiply + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul +- name: aten::multiply_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul_ - name: aten::mv depends: - name: aten::_addmv_impl_ @@ -7112,6 +7431,76 @@ - name: aten::lgamma_ - name: aten::sum - name: aten::unsqueeze +- name: aten::nan_to_num + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::nan_to_num + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::nan_to_num_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::nan_to_num +- name: aten::nanmedian + depends: + - name: aten::clone + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::nanmedian + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::squeeze_ + - name: aten::stride + - name: aten::transpose_ + - name: aten::unsqueeze +- name: aten::nanquantile + depends: + - name: aten::all + - name: aten::any + - name: aten::broadcast_tensors + - name: aten::ceil_ + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::flatten + - name: aten::gather + - name: aten::ge + - name: aten::is_nonzero + - name: aten::isnan + - name: aten::item + - name: aten::le + - name: aten::lerp_ + - name: aten::logical_and_ + - name: aten::logical_not_ + - name: aten::lt + - name: aten::masked_fill + - name: aten::masked_fill_ + - name: aten::mul + - name: aten::nanquantile + - name: aten::resize_ + - name: aten::scalar_tensor + - name: aten::size + - name: aten::sort + - name: aten::squeeze_ + - name: aten::sub + - name: aten::sum + - name: aten::to + - name: aten::transpose + - name: aten::transpose_ + - name: aten::unsqueeze + - name: aten::unsqueeze_ + - name: aten::view - name: aten::nansum depends: - name: aten::as_strided @@ -7172,12 +7561,16 @@ depends: - name: aten::_empty_affine_quantized - name: aten::_empty_per_channel_affine_quantized + - name: aten::add + - name: aten::addcmul - name: aten::clone - name: aten::dense_dim - name: aten::empty - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::mul + - name: aten::native_batch_norm - name: aten::q_per_channel_axis - name: aten::q_per_channel_scales - name: aten::q_per_channel_zero_points @@ -7186,6 +7579,7 @@ - name: aten::qscheme - name: aten::sparse_dim - name: aten::sparse_resize_and_clear_ + - name: aten::view - name: aten::native_group_norm_backward depends: - name: aten::_empty_affine_quantized @@ -7260,35 +7654,19 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - - name: aten::mul - name: aten::ne - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to - name: aten::ne_ depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - name: aten::ne - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::neg depends: - name: aten::as_strided_ @@ -7309,27 +7687,28 @@ - name: aten::neg - name: aten::negative depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero - name: aten::neg - - name: aten::resize_ - - name: aten::resize_as_ - - name: aten::to - name: aten::negative_ depends: - name: aten::eq - name: aten::is_nonzero - - name: aten::neg + - name: aten::neg_ - name: aten::new_empty depends: - name: aten::empty - name: aten::eq - name: aten::is_nonzero + - name: aten::new_empty + - name: aten::permute +- name: aten::new_empty_strided + depends: + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::new_empty_strided + - name: aten::permute - name: aten::new_full depends: - name: aten::eq @@ -7339,6 +7718,8 @@ depends: - name: aten::eq - name: aten::is_nonzero + - name: aten::new_zeros + - name: aten::permute - name: aten::zeros - name: aten::nextafter depends: @@ -7494,13 +7875,15 @@ - name: aten::ne_ - name: aten::nuclear_norm depends: + - name: aten::copy_ + - name: aten::empty - name: aten::eq - name: aten::is_floating_point - name: aten::is_leaf - name: aten::is_nonzero - name: aten::nuclear_norm - name: aten::permute - - name: aten::set_ + - name: aten::resize_ - name: aten::sum - name: aten::svd - name: aten::to @@ -7582,8 +7965,10 @@ - name: aten::outer depends: - name: aten::eq - - name: aten::ger - name: aten::is_nonzero + - name: aten::mul + - name: aten::reshape + - name: aten::size - name: aten::output_nr depends: - name: aten::eq @@ -7625,7 +8010,7 @@ - name: aten::set_ - name: aten::pinverse depends: - - name: aten::diag_embed + - name: aten::conj - name: aten::empty - name: aten::eq - name: aten::gt @@ -7635,7 +8020,9 @@ - name: aten::narrow - name: aten::reciprocal - name: aten::svd + - name: aten::to - name: aten::transpose + - name: aten::unsqueeze - name: aten::where - name: aten::zeros - name: aten::pixel_shuffle @@ -7647,8 +8034,16 @@ - name: aten::size - name: aten::poisson depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to - name: aten::zeros - name: aten::poisson_nll_loss depends: @@ -7709,12 +8104,10 @@ - name: aten::empty_strided - name: aten::eq - name: aten::fill_ - - name: aten::is_complex - name: aten::is_floating_point - name: aten::is_leaf - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones + - name: aten::item - name: aten::permute - name: aten::pow - name: aten::resize_ @@ -7734,6 +8127,7 @@ - name: aten::eq - name: aten::fill_ - name: aten::is_nonzero + - name: aten::item - name: aten::resize_ - name: aten::resize_as_ - name: aten::result_type @@ -7833,29 +8227,40 @@ - name: aten::quantile depends: - name: aten::all - - name: aten::ceil + - name: aten::any + - name: aten::broadcast_tensors + - name: aten::ceil_ - name: aten::copy_ - name: aten::empty - name: aten::eq - name: aten::flatten - - name: aten::floor + - name: aten::gather - name: aten::ge - - name: aten::index_select - name: aten::is_nonzero + - name: aten::isnan - name: aten::item - name: aten::le - name: aten::lerp_ - name: aten::logical_and_ + - name: aten::logical_not_ + - name: aten::lt + - name: aten::masked_fill + - name: aten::masked_fill_ - name: aten::mul - - name: aten::permute - name: aten::quantile - - name: aten::reshape + - name: aten::resize_ - name: aten::scalar_tensor - name: aten::size - name: aten::sort + - name: aten::squeeze_ - name: aten::sub + - name: aten::sum - name: aten::to + - name: aten::transpose + - name: aten::transpose_ + - name: aten::unsqueeze - name: aten::unsqueeze_ + - name: aten::view - name: aten::quantize_per_channel depends: - name: aten::contiguous @@ -7952,6 +8357,13 @@ - name: aten::tanh - name: aten::tanh_ - name: aten::unsafe_chunk +- name: aten::quantized_max_pool1d + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::quantized_max_pool2d + - name: aten::squeeze + - name: aten::unsqueeze - name: aten::quantized_max_pool2d depends: - name: aten::_empty_affine_quantized @@ -7977,22 +8389,13 @@ - name: aten::tanh - name: aten::rad2deg depends: - - name: aten::as_strided_ - - name: aten::copy_ - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - name: aten::is_complex - name: aten::is_nonzero - name: aten::mul - - name: aten::ones - name: aten::rad2deg - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::rad2deg_ depends: - name: aten::eq @@ -8070,11 +8473,17 @@ - name: aten::is_nonzero - name: aten::range - name: aten::resize_ +- name: aten::ravel + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape - name: aten::real depends: - name: aten::eq - name: aten::is_complex - name: aten::is_nonzero + - name: aten::real - name: aten::select - name: aten::view_as_real - name: aten::reciprocal @@ -8097,6 +8506,10 @@ - name: aten::eq - name: aten::is_nonzero - name: aten::reciprocal +- name: aten::record_stream + depends: + - name: aten::eq + - name: aten::is_nonzero - name: aten::refine_names depends: - name: aten::alias @@ -8181,11 +8594,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -8198,11 +8607,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -8349,10 +8754,6 @@ - name: aten::size - name: aten::zero_ - name: aten::zeros_like -- name: aten::requires_grad - depends: - - name: aten::eq - - name: aten::is_nonzero - name: aten::requires_grad_ depends: - name: aten::eq @@ -8386,22 +8787,10 @@ - name: aten::sparse_dim - name: aten::result_type depends: - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::result_type - name: aten::scalar_tensor - - name: aten::to - name: aten::retain_grad depends: - name: aten::eq @@ -8523,6 +8912,11 @@ - name: aten::eq - name: aten::is_nonzero - name: aten::round +- name: aten::row_stack + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::vstack - name: aten::rowwise_prune depends: - name: aten::contiguous @@ -8549,69 +8943,37 @@ - name: aten::rrelu_with_noise depends: - name: aten::add - - name: aten::as_strided_ - name: aten::contiguous - name: aten::copy_ - name: aten::div - - name: aten::empty - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - name: aten::leaky_relu - - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::rrelu_with_noise_ depends: - name: aten::add - - name: aten::as_strided_ - name: aten::contiguous - name: aten::copy_ - name: aten::div - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - name: aten::leaky_relu - - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::rrelu_with_noise_backward depends: - name: aten::add - - name: aten::as_strided_ - - name: aten::copy_ - name: aten::div - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - name: aten::item - name: aten::leaky_relu_backward - name: aten::mul - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::sub - - name: aten::to - name: aten::rsplit depends: - name: aten::eq @@ -8648,11 +9010,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::permute - name: aten::resize_ - name: aten::resize_as_ @@ -8715,20 +9073,11 @@ - name: aten::to - name: aten::searchsorted depends: - - name: aten::as_strided_ - name: aten::contiguous - - name: aten::copy_ - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - name: aten::to - name: aten::select @@ -8782,6 +9131,25 @@ depends: - name: aten::eq - name: aten::is_nonzero +- name: aten::sgn + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::sgn + - name: aten::to +- name: aten::sgn_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::sgn - name: aten::sigmoid depends: - name: aten::_empty_affine_quantized @@ -8828,6 +9196,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq + - name: aten::is_complex - name: aten::is_nonzero - name: aten::resize_ - name: aten::resize_as_ @@ -8886,12 +9255,14 @@ - name: aten::sin depends: - name: aten::as_strided_ + - name: aten::copy_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero - name: aten::resize_ + - name: aten::resize_as_ - name: aten::sin - name: aten::to - name: aten::sin_ @@ -8945,7 +9316,6 @@ - name: aten::add_ - name: aten::all - name: aten::arange - - name: aten::contiguous - name: aten::diagonal - name: aten::eq - name: aten::fmod_ @@ -8991,6 +9361,7 @@ - name: aten::bmm - name: aten::contiguous - name: aten::copy_ + - name: aten::detach - name: aten::empty - name: aten::eq - name: aten::is_nonzero @@ -9109,6 +9480,7 @@ - name: aten::sspaddmm - name: aten::smooth_l1_loss depends: + - name: aten::abs_ - name: aten::as_strided_ - name: aten::copy_ - name: aten::empty @@ -9122,20 +9494,26 @@ - name: aten::resize_ - name: aten::resize_as_ - name: aten::smooth_l1_loss + - name: aten::sub - name: aten::sum - name: aten::to - name: aten::smooth_l1_loss_backward depends: - name: aten::as_strided_ - name: aten::copy_ + - name: aten::div - name: aten::empty - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero + - name: aten::l1_loss_backward + - name: aten::mul_ - name: aten::resize_ - name: aten::resize_as_ + - name: aten::sign_ - name: aten::smooth_l1_loss_backward + - name: aten::sub - name: aten::to - name: aten::zeros_like - name: aten::soft_margin_loss @@ -9238,28 +9616,25 @@ - name: aten::size - name: aten::sort depends: - - name: aten::_copy_from - name: aten::_make_per_tensor_quantized_tensor + - name: aten::arange + - name: aten::as_strided - name: aten::as_strided_ - name: aten::copy_ - - name: aten::copy_sparse_to_sparse_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - name: aten::int_repr - - name: aten::is_complex - name: aten::is_nonzero - name: aten::q_scale - name: aten::q_zero_point - - name: aten::qscheme - name: aten::resize_ - - name: aten::resize_as_ - - name: aten::set_quantizer_ - name: aten::size - name: aten::sort - name: aten::stride - name: aten::to + - name: aten::zero_ - name: aten::sparse_coo_tensor depends: - name: aten::_sparse_coo_tensor_with_dims @@ -9347,6 +9722,7 @@ - name: aten::size - name: aten::squeeze - name: aten::to + - name: aten::view - name: aten::squeeze_ depends: - name: aten::as_strided_ @@ -9414,19 +9790,24 @@ - name: aten::to - name: aten::stft depends: + - name: aten::_fft_with_size - name: aten::as_strided - name: aten::copy_ - name: aten::eq - name: aten::fill_ + - name: aten::is_complex - name: aten::is_nonzero - name: aten::mul - name: aten::narrow - - name: aten::rfft + - name: aten::reshape - name: aten::size + - name: aten::squeeze - name: aten::squeeze_ - name: aten::stride - name: aten::transpose_ - name: aten::unsqueeze + - name: aten::view_as_complex + - name: aten::view_as_real - name: aten::zeros - name: aten::stride depends: @@ -9444,11 +9825,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::permute - name: aten::resize_ - name: aten::resize_as_ @@ -9464,11 +9841,7 @@ - name: aten::empty_like - name: aten::empty_strided - name: aten::eq - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - - name: aten::ones - name: aten::resize_ - name: aten::resize_as_ - name: aten::scalar_tensor @@ -9487,6 +9860,7 @@ depends: - name: aten::as_strided - name: aten::as_strided_ + - name: aten::clone - name: aten::copy_ - name: aten::empty - name: aten::empty_like @@ -9539,25 +9913,11 @@ - name: aten::transpose_ - name: aten::take depends: - - name: aten::_copy_from - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::copy_sparse_to_sparse_ + - name: aten::contiguous - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::q_scale - - name: aten::q_zero_point - - name: aten::qscheme - name: aten::resize_ - - name: aten::resize_as_ - - name: aten::set_quantizer_ - - name: aten::size - - name: aten::stride - - name: aten::to - name: aten::take_backward depends: - name: aten::eq @@ -9567,6 +9927,7 @@ - name: aten::tan depends: - name: aten::as_strided_ + - name: aten::copy_ - name: aten::empty - name: aten::empty_like - name: aten::empty_strided @@ -9575,6 +9936,7 @@ - name: aten::is_leaf - name: aten::is_nonzero - name: aten::resize_ + - name: aten::resize_as_ - name: aten::tan - name: aten::to - name: aten::tan_ @@ -9617,8 +9979,17 @@ - name: aten::resize_ - name: aten::resize_as_ - name: aten::to +- name: aten::tensor_split + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::permute + - name: aten::size + - name: aten::slice + - name: aten::tensor_split - name: aten::tensordot depends: + - name: aten::copy_ - name: aten::eq - name: aten::is_floating_point - name: aten::is_leaf @@ -9656,6 +10027,7 @@ - name: aten::addmm_ - name: aten::contiguous - name: aten::copy_ + - name: aten::detach - name: aten::empty - name: aten::eq - name: aten::is_nonzero @@ -9666,7 +10038,6 @@ - name: aten::size - name: aten::unsqueeze - name: aten::view - - name: aten::zero_ - name: aten::thnn_conv_depthwise2d depends: - name: aten::eq @@ -9786,9 +10157,11 @@ - name: aten::zero_ - name: aten::trace depends: + - name: aten::empty - name: aten::eq - name: aten::is_nonzero - - name: aten::scalar_tensor + - name: aten::size + - name: aten::stride - name: aten::trace_backward depends: - name: aten::arange @@ -9924,17 +10297,9 @@ - name: aten::triu_indices - name: aten::true_divide depends: - - name: aten::as_strided_ - - name: aten::copy_ - name: aten::div - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - name: aten::is_nonzero - - name: aten::resize_ - - name: aten::resize_as_ - - name: aten::to - name: aten::true_divide_ depends: - name: aten::div_ @@ -10350,17 +10715,16 @@ - name: aten::view - name: aten::view_as_complex depends: - - name: aten::empty - name: aten::eq - name: aten::is_nonzero - - name: aten::set_ + - name: aten::permute + - name: aten::view_as_complex - name: aten::view_as_real depends: - - name: aten::empty - name: aten::eq - name: aten::is_complex - name: aten::is_nonzero - - name: aten::set_ + - name: aten::view_as_real - name: aten::vstack depends: - name: aten::atleast_2d @@ -10370,23 +10734,11 @@ - name: aten::where depends: - name: aten::_s_where - - name: aten::as_strided_ - - name: aten::copy_ - - name: aten::empty - - name: aten::empty_like - - name: aten::empty_strided - name: aten::eq - name: aten::expand - - name: aten::fill_ - - name: aten::is_complex - name: aten::is_nonzero - - name: aten::mul - name: aten::nonzero_numpy - - name: aten::ones - - name: aten::resize_ - - name: aten::resize_as_ - name: aten::scalar_tensor - - name: aten::to - name: aten::where - name: aten::zero_ depends: @@ -10464,6 +10816,7 @@ - name: aten::is_nonzero - name: quantized::add depends: + - name: aten::_empty_affine_quantized - name: aten::as_strided_ - name: aten::contiguous - name: aten::copy_ @@ -10478,6 +10831,7 @@ - name: aten::resize_ - name: aten::resize_as_ - name: aten::set_quantizer_ + - name: aten::size - name: aten::to - name: quantized::add_out depends: @@ -10970,8 +11324,24 @@ - name: aten::unsqueeze - name: quantized::conv_transpose1d_prepack depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty - name: aten::eq + - name: aten::equal - name: aten::is_nonzero + - name: aten::item + - name: aten::mul + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::unsqueeze + - name: aten::zeros - name: quantized::conv_transpose1d_unpack depends: - name: aten::clone @@ -11000,8 +11370,23 @@ - name: aten::is_nonzero - name: quantized::conv_transpose2d_prepack depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty - name: aten::eq + - name: aten::equal - name: aten::is_nonzero + - name: aten::item + - name: aten::mul + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::zeros - name: quantized::conv_transpose2d_stride depends: - name: aten::eq @@ -11014,6 +11399,42 @@ depends: - name: aten::eq - name: aten::is_nonzero +- name: quantized::conv_transpose3d + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv_transpose3d_dilation + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv_transpose3d_groups + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv_transpose3d_output_padding + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv_transpose3d_padding + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv_transpose3d_prepack + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv_transpose3d_stride + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv_transpose3d_transpose + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv_transpose3d_unpack + depends: + - name: aten::eq + - name: aten::is_nonzero - name: quantized::conv_unpack depends: - name: aten::eq @@ -11035,10 +11456,13 @@ - name: aten::to - name: quantized::embedding_bag_2bit_prepack depends: + - name: aten::choose_qparams_optimized - name: aten::contiguous - name: aten::empty - name: aten::eq - name: aten::is_nonzero + - name: aten::item + - name: aten::select - name: aten::size - name: quantized::embedding_bag_2bit_unpack depends: @@ -11046,12 +11470,19 @@ - name: aten::eq - name: aten::is_nonzero - name: aten::size +- name: quantized::embedding_bag_4bit + depends: + - name: aten::eq + - name: aten::is_nonzero - name: quantized::embedding_bag_4bit_prepack depends: + - name: aten::choose_qparams_optimized - name: aten::contiguous - name: aten::empty - name: aten::eq - name: aten::is_nonzero + - name: aten::item + - name: aten::select - name: aten::size - name: quantized::embedding_bag_4bit_rowwise_offsets depends: @@ -11060,7 +11491,6 @@ - name: aten::eq - name: aten::is_nonzero - name: aten::size - - name: aten::to - name: quantized::embedding_bag_4bit_unpack depends: - name: aten::empty @@ -11080,6 +11510,7 @@ - name: aten::size - name: quantized::embedding_bag_byte_rowwise_offsets depends: + - name: aten::contiguous - name: aten::empty - name: aten::eq - name: aten::is_nonzero @@ -11155,6 +11586,22 @@ - name: aten::is_nonzero - name: aten::q_scale - name: aten::q_zero_point +- name: quantized::leaky_relu + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to - name: quantized::linear depends: - name: aten::eq @@ -11169,8 +11616,28 @@ - name: aten::is_nonzero - name: quantized::linear_prepack depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty - name: aten::eq + - name: aten::equal - name: aten::is_nonzero + - name: aten::item + - name: aten::max + - name: aten::min + - name: aten::mul + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::quantize_per_tensor + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::zeros - name: quantized::linear_prepack_fp16 depends: - name: aten::eq @@ -11232,16 +11699,16 @@ depends: - name: aten::eq - name: aten::is_nonzero +- name: quantized::max_pool1d + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::quantized_max_pool1d - name: quantized::max_pool2d depends: - - name: aten::_empty_affine_quantized - - name: aten::contiguous - name: aten::eq - name: aten::is_nonzero - - name: aten::max_pool2d - - name: aten::q_scale - - name: aten::q_zero_point - - name: aten::size + - name: aten::quantized_max_pool2d - name: quantized::mul depends: - name: aten::_empty_affine_quantized @@ -11432,6 +11899,23 @@ - name: aten::resize_ - name: aten::resize_as_ - name: aten::to +- name: quantized::sigmoid + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::to - name: quantized::threshold depends: - name: aten::_empty_affine_quantized diff --git a/tools/code_analyzer/gen_op_registration_whitelist.py b/tools/code_analyzer/gen_op_registration_allowlist.py similarity index 79% rename from tools/code_analyzer/gen_op_registration_whitelist.py rename to tools/code_analyzer/gen_op_registration_allowlist.py index 5971864b21876..04a58c8f522b9 100644 --- a/tools/code_analyzer/gen_op_registration_whitelist.py +++ b/tools/code_analyzer/gen_op_registration_allowlist.py @@ -1,11 +1,11 @@ """ -This util is invoked from cmake to produce the op registration whitelist param +This util is invoked from cmake to produce the op registration allowlist param for `ATen/gen.py` for custom mobile build. For custom build with dynamic dispatch, it takes the op dependency graph of ATen and the list of root ops, and outputs all transitive dependencies of the root -ops as the whitelist. +ops as the allowlist. For custom build with static dispatch, the op dependency graph will be omitted, -and it will directly output root ops as the whitelist. +and it will directly output root ops as the allowlist. """ import argparse @@ -38,7 +38,7 @@ def load_root_ops(fname): return result -def gen_transitive_closure(dep_graph, root_ops): +def gen_transitive_closure(dep_graph, root_ops, train=False): result = set(root_ops) queue = root_ops[:] @@ -50,7 +50,10 @@ def gen_transitive_closure(dep_graph, root_ops): # and value = (set of ops reachable from C++ functions). Insert the special # `__ROOT__` key to include ops which can be called from C++ code directly, # in addition to ops that are called from TorchScript model. - queue.append('__ROOT__') + # '__ROOT__' is only needed for full-jit. Keep it only for training. + # TODO: when FL is migrated from full-jit to lite trainer, remove '__ROOT__' + if train: + queue.append('__ROOT__') while queue: cur = queue.pop() @@ -59,7 +62,10 @@ def gen_transitive_closure(dep_graph, root_ops): result.add(dep) queue.append(dep) - return ' '.join(sorted(result)) + return sorted(result) + +def gen_transitive_closure_str(dep_graph, root_ops): + return ' '.join(gen_transitive_closure(dep_graph, root_ops)) if __name__ == "__main__": @@ -77,4 +83,4 @@ def gen_transitive_closure(dep_graph, root_ops): deps = load_op_dep_graph(args.op_dependency) if args.op_dependency else {} root_ops = load_root_ops(args.root_ops) - print(gen_transitive_closure(deps, root_ops)) + print(gen_transitive_closure_str(deps, root_ops)) diff --git a/tools/code_analyzer/op_deps_processor.py b/tools/code_analyzer/op_deps_processor.py index 8d79b229cdfbd..6978ce75ec178 100644 --- a/tools/code_analyzer/op_deps_processor.py +++ b/tools/code_analyzer/op_deps_processor.py @@ -12,7 +12,7 @@ import argparse import yaml -from ..autograd.utils import CodeTemplate +from tools.codegen.code_template import CodeTemplate BAZEL_OUTPUT = CodeTemplate("""\ TORCH_DEPS = { diff --git a/tools/code_analyzer/run_analyzer.sh b/tools/code_analyzer/run_analyzer.sh index 40b333d553d88..dc8705cc39f76 100755 --- a/tools/code_analyzer/run_analyzer.sh +++ b/tools/code_analyzer/run_analyzer.sh @@ -15,12 +15,12 @@ echo "Analyze: ${INPUT}" # to operate, so for safety we match a more expansive set. "${ANALYZER_BIN}" \ -op_schema_pattern="^(_aten|_prim|aten|quantized|_quantized|prepacked|profiler|_test)::[a-zA-Z0-9_.]+(\(.*)?$" \ - -op_register_pattern="c10::RegisterOperators::(op|checkSchemaAndRegisterOp_)|c10::Module::(_?def|_?impl|impl_UNBOXED)|torch::Library::(_?def|_?impl|_?impl_UNBOXED)" \ + -op_register_pattern="c10::RegisterOperators::(op|checkSchemaAndRegisterOp_)|c10::Module::(_?def|_?impl)|torch::Library::(_?def|_?impl)" \ -op_invoke_pattern="c10::Dispatcher::findSchema" \ -root_symbol_pattern="torch::jit::[^(]" \ -torch_library_init_pattern="^.*TORCH_LIBRARY_init_([^(]+)(\(.*)?$" \ - -torch_library_init_pattern="^.*TORCH_LIBRARY_FRAGMENT_init_([^(]+)(\(.*)?$" \ - -torch_library_init_pattern="^.*TORCH_LIBRARY_IMPL_init_([^(]+)_([^_]+)(\(.*)?$" \ + -torch_library_init_pattern="^.*TORCH_LIBRARY_FRAGMENT_init_([_]*[^_]+)_[0-9]+(\(.*)?$" \ + -torch_library_init_pattern="^.*TORCH_LIBRARY_IMPL_init_([_]*[^_]+)_([^_]+)_[0-9]+(\(.*)?$" \ ${EXTRA_ANALYZER_FLAGS} \ "${INPUT}" \ > "${OUTPUT}" diff --git a/tools/codegen/api/autograd.py b/tools/codegen/api/autograd.py new file mode 100644 index 0000000000000..6f58eea6d1eac --- /dev/null +++ b/tools/codegen/api/autograd.py @@ -0,0 +1,122 @@ +from dataclasses import dataclass +from typing import Optional, Sequence, List, Tuple + +from tools.codegen.api.types import * +from tools.codegen.model import * + +# Represents a saved attribute involved in backward calculation. +# Note that it can be a derived property of an input argument, e.g.: +# we could save `other.scalar_type()` instead of the entire `other` tensor. +@dataclass(frozen=True) +class SavedAttribute: + # Name of the saved attribute. + # Suffix is appended if it's derived property, e.g.: `other_scalar_type` + name: str + + # The cpp type string. + # TODO: change from raw string to model.Type + type: str + + # The expression to read the derived property at save time, e.g.: + # `other.scalar_type()`. + expr: str + +# Represents a backward formula that calculates derivatives for one +# or more tensors. +@dataclass(frozen=True) +class Derivative: + # The formula string (legit C++ expression). + # Note that expressions against input arguments have been replaced with the + # corresponding saved attributes. + # E.g.: + # raw formula: `mul_tensor_backward(grad, self, other.scalar_type())` + # here: `mul_tensor_backward(grad, self, other_scalar_type)` + formula: str + + # Names of the arguments for which this formula calculates derivatives. + var_names: Tuple[str, ...] + + # Saved inputs that are referenced by the formula. + saved_inputs: Tuple[SavedAttribute, ...] + + # Saved outputs that are referenced by the formula. + saved_outputs: Tuple[SavedAttribute, ...] + +# Represents differentiability info for a NativeFunction. +@dataclass(frozen=True) +class DifferentiabilityInfo: + # The base name read from derivatives.yaml. + name: str + + # The matching native function. + # + # There can be multiple NativeFunction having the same base name: + # - different overloads with different types of input arguments; + # - in-place/out/functional variants of the same function; + # + # We first use the schema string (under the 'name' key) in derivatives.yaml + # to find the NativeFunction having the same schema string. + # Then we find the in-place/out/functional variants of the matching function. + # Among these variants, we choose the one having the same name as the + # derivatives.yaml entry. If there is no exact match, then we choose the + # in-place variant. + # TODO: maybe the logic to search for all variants is no longer necessary? + func: NativeFunction + + # The name of the generated autograd function. + # It's set only if we will calculate a derivative, i.e. + # 'args_with_derivatives' is not empty. + op: Optional[str] + + # The derivatives formulae for this function. + derivatives: Sequence[Derivative] + + # The union of 'saved_inputs' of all 'derivatives'. + all_saved_inputs: Sequence[SavedAttribute] + + # The union of 'saved_outputs' of all 'derivatives'. + all_saved_outputs: Sequence[SavedAttribute] + + # The function's input arguments for which it calculates derivatives. + # It's the union of 'var_names' of all 'derivatives', sorted by the + # argument order in the function schema. + args_with_derivatives: Sequence[Binding] + + # Names of arguments whose derivative formula is 'non_differentiable'. + non_differentiable_arg_names: Sequence[str] + + # Raw data read from derivatives.yaml. + output_differentiability: Optional[List[bool]] + + @property + def has_derivatives(self) -> bool: + return len(self.args_with_derivatives) > 0 + +# Represents a differentiable `Argument`. +# How is it different from the `Argument` type? +# - It's processed Arguments which are differentiable and only used in the +# context of the autograd codegen; +# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument; +@dataclass(frozen=True) +class DifferentiableInput: + name: str + type: Type + + # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove. + cpp_type: str + +# Represents a differentiable `Return`. +# How it it different from the `Return` type? +# - The name in `Return` is optional. Here it is always populated using the same +# `cpp.return_names()` method. +# TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant? +# - It's processed Returns which are differentiable, in compliance with the +# `output_differentiability` field defined in derivatives.yaml (if specified), +# and are only used in the context of the autograd codegen; +@dataclass(frozen=True) +class DifferentiableOutput: + name: str + type: Type + + # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove. + cpp_type: str diff --git a/tools/codegen/api/cpp.py b/tools/codegen/api/cpp.py index d8445f02ee548..0debd52ca896c 100644 --- a/tools/codegen/api/cpp.py +++ b/tools/codegen/api/cpp.py @@ -1,7 +1,6 @@ from tools.codegen.model import * -from tools.codegen.api.types import TensorOptionsArguments, CppArgument, ThisArgument -import tools.codegen.local as local -from typing import Optional, Sequence, Union, Callable, List +from tools.codegen.api.types import * +from typing import Optional, Sequence, Union, List, Set # This file describes the translation of JIT schema to the public C++ # API, which is what people use when they call functions like at::add. @@ -9,7 +8,7 @@ # Prominent characteristics of the C++ API: # # - dtype, layout, device and pin_memory are collected into -# a single C++ type TensorOptions (the legacy dispatcher API +# a single C++ type TensorOptions (the native functions API # also has this, but tensor options is really most relevant # for the C++ API; it makes calling kwarg factory functions # pleasant) @@ -23,99 +22,105 @@ # BTW: policy on name collisions: we try not to have types with # collisions, but functions are fair game to collide -def name(func: FunctionSchema) -> str: +def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str: name = str(func.name.name) if func.is_out_fn(): - name += '_out' + if faithful_name_for_out_overloads: + name += '_outf' + else: + name += '_out' + return name # Translation of "value types" in JIT schema to C++ API type. Value -# types look the same no matter if they are argument types are return +# types look the same no matter if they are argument types or return # types. Returns None if the type in question is not a value type. -def valuetype_type(t: Type) -> Optional[str]: +def valuetype_type(t: Type, *, binds: ArgName) -> Optional[CType]: if isinstance(t, BaseType): if t.name == BaseTy.Tensor: return None elif t.name == BaseTy.int: - return 'int64_t' + return BaseCType('int64_t', binds) elif t.name == BaseTy.float: - return 'double' + return BaseCType('double', binds) elif t.name == BaseTy.str: - return 'std::string' + return BaseCType('std::string', binds) elif t.name in [BaseTy.bool, BaseTy.QScheme, BaseTy.Scalar, BaseTy.ScalarType, BaseTy.Generator, BaseTy.Storage, BaseTy.Layout, BaseTy.Device, BaseTy.MemoryFormat, - BaseTy.Dimname, BaseTy.ConstQuantizerPtr]: + BaseTy.Dimname, BaseTy.Stream, BaseTy.ConstQuantizerPtr]: # These C++ names line up with their schema names - return t.name.name + return BaseCType(t.name.name, binds) else: raise AssertionError(f"unsupported type: {t}") elif isinstance(t, OptionalType): - elem = valuetype_type(t.elem) + elem = valuetype_type(t.elem, binds=binds) if elem is None: return None - return f"c10::optional<{elem}>" + return OptionalCType(elem) elif isinstance(t, ListType): if str(t.elem) == 'bool': assert t.size is not None - return f"std::array" + return BaseCType(f"std::array", binds) else: return None else: raise AssertionError(f"unrecognized type {repr(t)}") # Translation of types occuring in JIT arguments to a C++ argument type. -def argumenttype_type(t: Type, *, mutable: bool) -> str: +def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType: # If it's a value type, do the value type translation - r = valuetype_type(t) + r = valuetype_type(t, binds=binds) if r is not None: return r if isinstance(t, BaseType): if t.name == BaseTy.Tensor: if mutable: - return 'Tensor &' + return MutRefCType(BaseCType('Tensor', binds)) else: - return 'const Tensor &' + return ConstRefCType(BaseCType('Tensor', binds)) else: raise AssertionError(f"base type should have been value type {t}") elif isinstance(t, OptionalType): if str(t.elem) == 'Tensor': if mutable: - return 'Tensor &' # TODO: fix this discrepancy + return MutRefCType(BaseCType('Tensor', binds)) # TODO: fix this discrepancy else: - if local.use_c10_dispatcher() is UseC10Dispatcher.full: - return 'const c10::optional&' - else: - return 'const Tensor &' - elem = argumenttype_type(t.elem, mutable=mutable) - return f"c10::optional<{elem}>" + return ConstRefCType(OptionalCType(BaseCType('Tensor', binds))) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) + return OptionalCType(elem) elif isinstance(t, ListType): # TODO: remove these special cases, ArrayRef fallthrough works fine + # NB: CType throws away ArrayRef structure because it is not currently + # relevant in translation. When it becomes relevant, need to add back if str(t.elem) == 'int': - return "IntArrayRef" + return BaseCType("IntArrayRef", binds) elif str(t.elem) == 'Tensor': - return "TensorList" + return BaseCType("TensorList", binds) elif str(t.elem) == 'Dimname': - return "DimnameList" - # TODO: do something reasonable about lists of optional tensors - elif not local.use_c10_dispatcher() is UseC10Dispatcher.full and str(t.elem) == 'Tensor?': - return "TensorList" - elem = argumenttype_type(t.elem, mutable=mutable) + return BaseCType("DimnameList", binds) + elif str(t.elem) == 'Tensor?': + return ConstRefCType(BaseCType("c10::List>", binds)) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) # TODO: explicitly qualify namespace here - return f"ArrayRef<{elem}>" + return BaseCType(f"ArrayRef<{elem.cpp_type()}>", binds) else: raise AssertionError(f"unrecognized type {repr(t)}") # Translate a JIT argument into its C++ type -def argument_type(a: Argument) -> str: - return argumenttype_type(a.type, mutable=a.is_write) +def argument_type(a: Argument, *, binds: ArgName) -> CType: + return argumenttype_type(a.type, mutable=a.is_write, binds=binds) # Translation of a (non-multi) return type from JIT to C++ +# NB: if need translations on return types, make this return CType too. Need to +# take care; ArgName is misnomer now, and inputs are permitted to conflict with outputs +# so need to make sure you don't have trouble def returntype_type(t: Type, *, mutable: bool) -> str: - r = valuetype_type(t) + # placeholder is ignored + r = valuetype_type(t, binds="__placeholder__") if r is not None: - return r + return r.cpp_type() if isinstance(t, BaseType): if t.name == BaseTy.Tensor: @@ -144,95 +149,154 @@ def returns_type(rs: Sequence[Return]) -> str: args = ','.join(map(return_type, rs)) return f'std::tuple<{args}>' +def return_names(f: NativeFunction) -> Sequence[str]: + returns: List[str] = [] + for i, r in enumerate(f.func.returns): + # If we have an inplace function, the return argument is + # implicitly named self. + # TODO: Consider incorporating this into the data model + if f.func.name.name.inplace: + assert i == 0, "illegal inplace function with multiple returns" + name = 'self' + # If we are out function, the name is the name of the + # corresponding output function (r.name will get recorded + # in field_name later.) + elif f.func.is_out_fn(): + name = f.func.arguments.out[i].name + # If the return argument is explicitly named... + elif r.name: + name_conflict = any(r.name == a.name for a in f.func.schema_order_arguments()) + if name_conflict and not f.func.is_out_fn(): + name = f'{r.name}_return' + else: + name = r.name + # If there is no explicit name, we just name the output result, + # unless it's a multi-return, in which case it's result0, + # result1, etc (zero-indexed) + else: + name = 'result' if len(f.func.returns) == 1 else f'result{i}' + returns.append(name) + return returns + JIT_TO_CPP_DEFAULT = { 'False': 'false', 'True': 'true', 'None': 'c10::nullopt', # UGH this one is type directed 'Mean': 'at::Reduction::Mean', '[]': '{}', - '[0,1]': '{0,1}', # TODO: stop special casing 'contiguous_format': 'MemoryFormat::Contiguous', + 'long': 'at::kLong', } # Convert a JIT default into C++ expression representing the default def default_expr(d: str, t: Type) -> str: if d == 'None' and str(t) == 'Tensor?': return '{}' + if isinstance(t, BaseType) and t.name is BaseTy.str: + # Schema allows single quotes but C++ needs double + if len(d) >= 2 and d[0] == "'" and d[-1] == "'": + s = '' + i = 1 + while i + 1 < len(d): + if d[i] != '\\': + if d[i] == '"': + s += '\\"' + else: + s += d[i] + i += 1 + else: + if d[i + 1] == "'": + s += "'" + else: + s += d[i:i + 2] + i += 2 + + return f'"{s}"' + + if isinstance(t, OptionalType): + if d == 'None': + return 'c10::nullopt' + + return default_expr(d, t.elem) + + if isinstance(t, ListType): + if (d.startswith('[') and d.endswith(']')): + return '{' + d[1:-1] + '}' + elif t.size is None: + # NOTE: Sized lists can have scalar defaults + raise ValueError(f"Expected a list default '[...]' but found: '{d}'") + return JIT_TO_CPP_DEFAULT.get(d, d) # Convert an argument into its C++ API form -def argument(a: Union[Argument, TensorOptionsArguments, ThisArgument]) -> CppArgument: + +def argument( + a: Union[Argument, TensorOptionsArguments, SelfArgument], + *, cpp_no_default_args: Set[str], method: bool, faithful: bool, + has_tensor_options: bool +) -> List[Binding]: + def sub_argument(a: Union[Argument, TensorOptionsArguments, SelfArgument]) -> List[Binding]: + return argument( + a, cpp_no_default_args=cpp_no_default_args, method=method, faithful=faithful, + has_tensor_options=has_tensor_options) + if isinstance(a, Argument): - return CppArgument( - type=argument_type(a), + binds: ArgName + if a.name == "memory_format" and has_tensor_options: + binds = SpecialArgName.possibly_redundant_memory_format + else: + binds = a.name + default: Optional[str] = None + if a.name not in cpp_no_default_args and a.default is not None: + default = default_expr(a.default, a.type) + return [Binding( + ctype=argument_type(a, binds=binds), name=a.name, - default=default_expr(a.default, a.type) if a.default is not None else None, - argument=a, - ) - elif isinstance(a, ThisArgument): - return CppArgument( - type=argument_type(a.argument), - name="const_cast(*this)", # this is an abuse but it's convenient - default=None, - argument=a, - ) - elif isinstance(a, TensorOptionsArguments): - default = None - if all(x.default == "None" for x in a.all()): - default = '{}' - elif a.dtype.default == "long": - default = 'at::kLong' # TODO: this is wrong - return CppArgument( - type='const TensorOptions &', - name='options', default=default, argument=a, - ) + )] + elif isinstance(a, TensorOptionsArguments): + if faithful: + return sub_argument(a.dtype) + sub_argument(a.layout) + \ + sub_argument(a.device) + sub_argument(a.pin_memory) + else: + default = None + # Enforced by NativeFunction.__post_init__ + assert 'options' not in cpp_no_default_args + if all(x.default == "None" for x in a.all()): + default = '{}' + elif a.dtype.default == "long": + default = 'at::kLong' # TODO: this is wrong + return [Binding( + ctype=ConstRefCType(BaseCType('TensorOptions', 'options')), + name='options', + default=default, + argument=a, + )] + elif isinstance(a, SelfArgument): + if method: + # Caller is responsible for installing implicit this in context! + return [] + else: + return sub_argument(a.argument) else: assert_never(a) -def group_arguments( - func: FunctionSchema, *, method: bool = False -) -> Sequence[Union[Argument, TensorOptionsArguments, ThisArgument]]: - args: List[Union[Argument, ThisArgument, TensorOptionsArguments]] = [] - args.extend(func.out_arguments) - - if method: - args.extend(ThisArgument(a) if a.name == "self" else a for a in func.arguments) +def arguments( + arguments: Arguments, + *, faithful: bool, method: bool, cpp_no_default_args: Set[str] +) -> List[Binding]: + args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] + if faithful: + args.extend(arguments.non_out) + args.extend(arguments.out) else: - args.extend(func.arguments) - - # group up arguments for tensor options - - def pred(name: str, ty: Type) -> Callable[[Argument], bool]: - return lambda a: a.name == name and a.type in [ty, OptionalType(ty)] - predicates = [ # order matters - pred('dtype', Type.parse('ScalarType')), - pred('layout', Type.parse('Layout')), - pred('device', Type.parse('Device')), - pred('pin_memory', Type.parse('bool')), + args.extend(arguments.out) + args.extend(arguments.non_out) + return [ + r.no_default() if faithful else r for a in args + for r in argument( + a, faithful=faithful, method=method, + has_tensor_options=arguments.tensor_options is not None, + cpp_no_default_args=cpp_no_default_args) ] - - i = 0 - while i < len(func.kwarg_only_arguments): - # If there is enough space... - if i <= len(func.kwarg_only_arguments) - len(predicates): - # And the next len(predicates) arguments look like TensorOptions arguments - if all(p(a) for p, a in zip(predicates, func.kwarg_only_arguments[i : i + len(predicates)])): - # Group them together as one argument - args.append(TensorOptionsArguments( - dtype=func.kwarg_only_arguments[i], - layout=func.kwarg_only_arguments[i + 1], - device=func.kwarg_only_arguments[i + 2], - pin_memory=func.kwarg_only_arguments[i + 3], - )) - i += len(predicates) - continue - args.append(func.kwarg_only_arguments[i]) - i += 1 - - return args - -# Convert arguments to C++ API form -def arguments(func: FunctionSchema, *, method: bool = False) -> Sequence[CppArgument]: - return list(map(argument, group_arguments(func, method=method))) diff --git a/tools/codegen/api/dispatcher.py b/tools/codegen/api/dispatcher.py index 34960534275f5..bb65bc386e64f 100644 --- a/tools/codegen/api/dispatcher.py +++ b/tools/codegen/api/dispatcher.py @@ -1,13 +1,10 @@ from tools.codegen.model import * -from tools.codegen.api.types import CppArgument, DispatcherExpr, TensorOptionsArguments, \ - DispatcherArgument, ThisArgument, LegacyDispatcherArgument +from tools.codegen.api.types import * import tools.codegen.api.cpp as cpp -import tools.codegen.api.legacy_dispatcher as legacy_dispatcher -import tools.codegen.local as local import itertools -from typing import Sequence, Optional +from typing import Sequence, List, Union # This file describes the translation of JIT schema to the dispatcher # API, the *unboxed* calling convention by which invocations through @@ -28,82 +25,44 @@ # arguments. # -def argumenttype_type(t: Type, *, mutable: bool) -> str: - if local.use_c10_dispatcher() is UseC10Dispatcher.full: - # This is a faux amis. If it makes sense in the future to add - # more special cases here, or invert things so cpp.argument_type - # calls this, or just completely inline the function, please do - # it. - return cpp.argumenttype_type(t, mutable=mutable) - else: - # This is real sharing. If you're modifying this path, ask - # yourself why you are changing the legacy dispatcher protocol - # here and not in legacy_dispatcher. - return legacy_dispatcher.argumenttype_type(t, mutable=mutable) +def name(func: FunctionSchema) -> str: + return cpp.name(func) + +def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType: + # This is a faux amis. If it makes sense in the future to add + # more special cases here, or invert things so cpp.argument_type + # calls this, or just completely inline the function, please do + # it. + return cpp.argumenttype_type(t, mutable=mutable, binds=binds) -def argument_type(a: Argument) -> str: - return argumenttype_type(a.type, mutable=a.is_write) +def argument_type(a: Argument, *, binds: ArgName) -> CType: + return argumenttype_type(a.type, mutable=a.is_write, binds=binds) def returns_type(rs: Sequence[Return]) -> str: # At present, there is no difference. But there could be! return cpp.returns_type(rs) -def argument(a: Argument) -> DispatcherArgument: - if local.use_c10_dispatcher() is UseC10Dispatcher.full: - return DispatcherArgument( - type=argument_type(a), +def argument( + a: Union[Argument, TensorOptionsArguments, SelfArgument] +) -> List[Binding]: + if isinstance(a, Argument): + return [Binding( + ctype=argument_type(a, binds=a.name), name=a.name, argument=a, - ) - else: - la = legacy_dispatcher.argument(a) - return DispatcherArgument( - type=la.type, - name=la.name, - argument=la.argument, - ) - -def arguments(func: FunctionSchema) -> Sequence[DispatcherArgument]: - if local.use_c10_dispatcher() is UseC10Dispatcher.full: - return list(map(argument, itertools.chain(func.out_arguments, func.arguments, func.kwarg_only_arguments))) + )] + elif isinstance(a, SelfArgument): + return argument(a.argument) + elif isinstance(a, TensorOptionsArguments): + return argument(a.dtype) + argument(a.layout) + argument(a.device) + argument(a.pin_memory) else: - return [ - DispatcherArgument(type=la.type, name=la.name, argument=la.argument) - for la in legacy_dispatcher.arguments(func) - ] - -# Given a set of CppArguments in scope, return a sequence of dispatcher -# expressions that translate the cpp API into dispatcher API -def cppargument_exprs(a: CppArgument, *, tensor_options: Optional[CppArgument]) -> Sequence[DispatcherExpr]: - if isinstance(a.argument, TensorOptionsArguments): - if local.use_c10_dispatcher() is UseC10Dispatcher.full: - ta = a.argument - return [ - DispatcherExpr(type=argument_type(ta.dtype), expr=f'optTypeMetaToScalarType({a.name}.dtype_opt())'), - DispatcherExpr(type=argument_type(ta.layout), expr=f'{a.name}.layout_opt()'), - DispatcherExpr(type=argument_type(ta.device), expr=f'{a.name}.device_opt()'), - DispatcherExpr(type=argument_type(ta.pin_memory), expr=f'{a.name}.pinned_memory_opt()'), # weird discrep - ] - else: - return [DispatcherExpr(type='const TensorOptions &', expr=a.name)] - elif isinstance(a.argument, Argument): - if a.name == 'memory_format' and tensor_options is not None and local.use_c10_dispatcher() is UseC10Dispatcher.full: - return [DispatcherExpr( - type=argument_type(a.argument), - expr=f'c10::impl::check_tensor_options_and_extract_memory_format({tensor_options.name}, {a.name})') - ] - else: - return [DispatcherExpr(type=argument_type(a.argument), expr=a.name)] - elif isinstance(a.argument, ThisArgument): - return [DispatcherExpr(type=argument_type(a.argument.argument), expr=a.name)] - else: - assert_never(a.argument) - -def cpparguments_exprs(args: Sequence[CppArgument]) -> Sequence[DispatcherExpr]: - tensor_options = next((a for a in args if isinstance(a.argument, TensorOptionsArguments)), None) - return [r for a in args for r in cppargument_exprs(a, tensor_options=tensor_options)] + assert_never(a) -# I don't think this is entirely sound, but it should be reasonably -# close -def legacydispatcherarguments_exprs(args: Sequence[LegacyDispatcherArgument]) -> Sequence[DispatcherExpr]: - return cpparguments_exprs([CppArgument(type=a.type, name=a.name, default=None, argument=a.argument) for a in args]) +def arguments(func: FunctionSchema) -> List[Binding]: + return [ + r for a in itertools.chain( + func.arguments.positional, + func.arguments.kwarg_only, + func.arguments.out + ) for r in argument(a) + ] diff --git a/tools/codegen/api/legacy_dispatcher.py b/tools/codegen/api/legacy_dispatcher.py deleted file mode 100644 index db3d26c84fd03..0000000000000 --- a/tools/codegen/api/legacy_dispatcher.py +++ /dev/null @@ -1,74 +0,0 @@ -from tools.codegen.model import * - -from tools.codegen.api.types import TensorOptionsArguments, LegacyDispatcherArgument, ThisArgument -import tools.codegen.api.cpp as cpp - -from typing import Union, Sequence - -# This file describes the translation of JIT schema to the legacy -# dispatcher API. This looks a lot like the C++ API (which -# makes historical sense, because historically the dispatcher API -# and the C++ API exactly matched), but over time we have -# evolved the C++ API without actually changing our native:: -# kernels. To be deleted eventually. Dispatcher calls use -# this when you are not use_c10_dispatcher: full. - -def name(func: FunctionSchema) -> str: - name = str(func.name.name) - # TODO: delete this! - if func.is_out_fn(): - name += '_out' - if func.name.overload_name: - name += f'_{func.name.overload_name}' - return name - -def argumenttype_type(t: Type, *, mutable: bool) -> str: - if str(t) == 'Tensor?': - if mutable: - return 'Tensor &' - else: - return 'const Tensor &' - elif str(t) == 'Tensor?[]': - return 'TensorList' - return cpp.argumenttype_type(t, mutable=mutable) - -def returns_type(rs: Sequence[Return]) -> str: - return cpp.returns_type(rs) - -def argument_type(a: Argument) -> str: - return argumenttype_type(a.type, mutable=a.is_write) - -def argument(a: Union[Argument, ThisArgument, TensorOptionsArguments]) -> LegacyDispatcherArgument: - if isinstance(a, Argument): - return LegacyDispatcherArgument( - type=argument_type(a), - name=a.name, - default=cpp.default_expr(a.default, a.type) if a.default is not None else None, - argument=a, - ) - elif isinstance(a, ThisArgument): - # Erase ThisArgument from the distinction - return LegacyDispatcherArgument( - type=argument_type(a.argument), - name=a.argument.name, - default=None, - argument=a.argument, - ) - elif isinstance(a, TensorOptionsArguments): - # TODO: expunge this logic entirely - default = None - if all(x.default == "None" for x in a.all()): - default = '{}' - elif a.dtype.default == "long": - default = 'at::kLong' # TODO: this is wrong - return LegacyDispatcherArgument( - type='const TensorOptions &', - name='options', - default=default, - argument=a, - ) - else: - assert_never(a) - -def arguments(func: FunctionSchema) -> Sequence[LegacyDispatcherArgument]: - return list(map(argument, cpp.group_arguments(func))) diff --git a/tools/codegen/api/meta.py b/tools/codegen/api/meta.py new file mode 100644 index 0000000000000..259a0793257fc --- /dev/null +++ b/tools/codegen/api/meta.py @@ -0,0 +1,12 @@ +from tools.codegen.model import * +from tools.codegen.api.types import * + +# Follows dispatcher calling convention, but: +# - Mutable arguments not allowed. Meta functions are always +# written in functional form. Look at FunctionSchema.signature() +# - No tensor returns; instead we return a TensorMeta describing +# the tensor in question + +def name(g: StructuredNativeFunctions) -> str: + # use the overload name from the functional version + return str(g.functional.func.name).replace('.', '_') diff --git a/tools/codegen/api/native.py b/tools/codegen/api/native.py new file mode 100644 index 0000000000000..af82210b20f40 --- /dev/null +++ b/tools/codegen/api/native.py @@ -0,0 +1,125 @@ +from tools.codegen.model import * + +from tools.codegen.api.types import * +import tools.codegen.api.cpp as cpp +from tools.codegen import local + +from typing import Union, Sequence, List, Optional + +# This file describes the translation of JIT schema to the native functions API. +# This looks a lot like the C++ API (which makes historical sense, because the +# idea was you wrote native functions to implement functions in the C++ API), +# but over time we have evolved the C++ API without actually changing our +# native:: kernels. The intention is to make native API and dispatcher API +# line up as closely as possible, since this results in the least overhead +# (no translation is needed from dispatcher API to native API). +# +# When a function is not use_c10_dispatcher: full, the dispatcher API actually +# coincides with the native:: API (e.g., we do as dumb as pass through as +# possible). + +def name(func: FunctionSchema) -> str: + name = str(func.name.name) + # TODO: delete this! + if func.is_out_fn(): + name += '_out' + if func.name.overload_name: + name += f'_{func.name.overload_name}' + return name + +def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType: + if str(t) == 'Tensor?': + if mutable: + return MutRefCType(BaseCType('Tensor', binds)) + else: + return ConstRefCType(BaseCType('Tensor', binds)) + elif str(t) == 'Tensor?[]': + return BaseCType('const c10::List> &', binds) + return cpp.argumenttype_type(t, mutable=mutable, binds=binds) + +def returns_type(rs: Sequence[Return]) -> str: + return cpp.returns_type(rs) + +def argument_type(a: Argument, *, binds: ArgName) -> CType: + return argumenttype_type(a.type, mutable=a.is_write, binds=binds) + +def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments], *, is_out: bool) -> List[Binding]: + # Ideally, we NEVER default native functions. However, there are a number + # of functions that call native:: directly and rely on the defaulting + # existing. So for BC, we generate defaults for non-out variants (but not + # for out variants, where it is impossible to generate an appropriate + # default) + should_default = not is_out or local.use_c10_dispatcher() is not UseC10Dispatcher.full + if isinstance(a, Argument): + default: Optional[str] = None + if should_default and a.default is not None: + default = cpp.default_expr(a.default, a.type) + return [Binding( + ctype=argument_type(a, binds=a.name), + name=a.name, + default=default, + argument=a, + )] + elif isinstance(a, SelfArgument): + # Erase SelfArgument from the distinction + return argument(a.argument, is_out=is_out) + elif isinstance(a, TensorOptionsArguments): + if local.use_c10_dispatcher() == UseC10Dispatcher.hacky_wrapper_for_legacy_signatures: + # TODO: expunge this logic entirely + default = None + if should_default: + if all(x.default == "None" for x in a.all()): + default = '{}' + elif a.dtype.default == "long": + default = 'at::kLong' # TODO: this is wrong + return [Binding( + ctype=ConstRefCType(BaseCType('TensorOptions', 'options')), + name='options', + default=default, + argument=a, + )] + else: + assert local.use_c10_dispatcher() == UseC10Dispatcher.full + default = None + if should_default: + default = '{}' + # TODO: Not sure why the arguments assigned here are for + # TensorOptionsArguments and not the constituent pieces. It seems + # to matter + return [ + Binding( + ctype=OptionalCType(BaseCType('ScalarType', 'dtype')), + name='dtype', + default=default, + argument=a, + ), + Binding( + ctype=OptionalCType(BaseCType('Layout', 'layout')), + name='layout', + default=default, + argument=a, + ), + Binding( + ctype=OptionalCType(BaseCType('Device', 'device')), + name='device', + default=default, + argument=a, + ), + Binding( + ctype=OptionalCType(BaseCType('bool', 'pin_memory')), + name='pin_memory', + default=default, + argument=a, + )] + else: + assert_never(a) + +def arguments(func: FunctionSchema) -> List[Binding]: + args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] + if local.use_c10_dispatcher() is UseC10Dispatcher.full: + args.extend(func.arguments.non_out) + args.extend(func.arguments.out) + else: + args.extend(func.arguments.out) + args.extend(func.arguments.non_out) + return [r for arg in args for r in argument(arg, is_out=func.is_out_fn())] diff --git a/tools/codegen/api/python.py b/tools/codegen/api/python.py new file mode 100644 index 0000000000000..749513cb5c0d5 --- /dev/null +++ b/tools/codegen/api/python.py @@ -0,0 +1,1201 @@ +from dataclasses import dataclass +from typing import Optional, Union, Sequence, Set, List, Dict, Tuple + +from tools.codegen.api.types import * +import tools.codegen.api.cpp as cpp +from tools.codegen.gen import pythonify_default +from tools.codegen.model import * + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Data Models +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# [Notes] python binding codegen +# +# The Python binding codegen produces code that takes the input list of +# PyObjects, finds the matching ATen C++ function using PythonArgParser, +# converts the PyObjects into C++ types and calls the ATen C++ function: +# +# +--------+ parsing +------------------------+ binding +-----------------------+ +# | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch | +# +--------+ +------------------------+ +-----------------------+ +# +# The following examples demonstrate the data models the Python binding +# codegen needs to deal with and the tasks it needs to accomplish. It +# helps understand the purpose of the new data types we introduced below. +# +# - Function Schema (source of truth) +# +# aten::empty.names(int[] size, *, Dimname[]? names, +# ScalarType? dtype=None, Layout? layout=None, +# Device? device=None, bool? pin_memory=None, +# MemoryFormat? memory_format=None) -> Tensor +# +# - Python Signature +# +# It's used to generate input schema string for PythonArgParser. +# Note: TensorOptions fields are reordered and the additional +# 'requires_grad' field is added: +# +# empty(IntArrayRef size, *, DimnameList? names, +# MemoryFormat? memory_format=None, ScalarType dtype=None, +# Layout layout=torch.strided, Device device=None, +# bool pin_memory=False, bool requires_grad=False) +# +# - C++ Signature +# +# It's used to generate C++ lambda formals & dispatch call. +# Note: the scattered TensorOptions fields are packed into 'options'. +# +# auto dispatch_empty = +# [](IntArrayRef size, c10::optional names, +# const TensorOptions & options, +# c10::optional memory_format) -> Tensor { +# pybind11::gil_scoped_release no_gil; +# return torch::empty(size, names, options, memory_format); +# }; +# +# - Binding between Python Arguments and C++ Arguments +# +# Given a set of Python Arguments in scope, we need produce the +# binding expressions that translate the Python API into C++ API: +# +# Python Args Cpp Args Binding Exprs +# ----------------------------------------------------------------- +# 0: size size '_r.intlist(0)' +# 1: names names 'names' [special init] +# 2: memory_format -------+ +# 3: dtype -----+-|--> options 'options' [special packing] +# 4: layout / | +# 5: device / +--> memory_format '_r.memoryformatOptional(2)' +# 6: pin_memory / +# 7: requires_grad -+ +# +# So the full dispatch expression would look like: +# +# dispatch_empty(_r.intlist(0), names, options, +# _r.memoryformatOptional(2)) +# +# Where does 'names' come from? It involves special local init: +# +# auto __names = _r.toDimnameListOptional(1); +# c10::optional names = +# __names ? c10::make_optional(DimnameList(__names.value())) +# : c10::nullopt; +# +# Where does 'options' come from? It involves special local init +# for TensorOptions. Note that Python side has the additional +# 'requires_grad' field: +# +# const auto options = TensorOptions() +# .dtype(_r.scalartype(3)) +# .device(_r.device(5)) +# .layout(_r.layoutOptional(4)) +# .requires_grad(_r.toBool(7)) +# .pinned_memory(_r.toBool(6)); +# +# In some other cases one Python Argument can map to multiple C++ +# Arguments. For example: +# +# aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False) +# -> (Tensor values, Tensor indices) +# +# Python Args Cpp Args Binding Exprs +# --------------------------------------------------------------------- +# +----> max 'out[0]' +# /-----> max_values 'out[1] +# 0: input / self '_r.tensor(0)' +# 1: dim / dim '_r.dimname(1)' +# 2: keepdim / keepdim '_r.toBool(2)' +# 3: out -----+ [local init] out '_r.tensorlist_n<2>(3)' +# +# As demonstrated above, the binding can involve reordering, +# packing, unpacking and special local inits. +# +# +# Let's look at a concrete example: +# +# static PythonArgParser parser({ +# "abs(Tensor input, *, Tensor out=None)", +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ^ +# +--- Python Schema, represented by PythonSignature and PythonArgument +# +# }, /*traceable=*/true); +# +# ParsedArgs<2> parsed_args; +# auto _r = parser.parse(nullptr, args, kwargs, parsed_args); +# +# ... +# +# if (_r.isNone(1)) { +# ~~~~~~~~~~~~ <--- Scattered PythonArgParser output (arg name = 'out') +# represented by PythonArgParserOutputExpr +# +# // aten::abs(Tensor self) -> Tensor +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ^ +# +--- NativeFunction schema, base version +# +# auto dispatch_abs = [](const Tensor & self) -> Tensor { +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ^ +# +--- dispatch_lambda_args / dispatch_lambda_return_str +# generated from NativeFunction / CppSignature +# (deprecated PythonSignature is special) +# arguments are represented by DispatchLambdaArgument +# +# pybind11::gil_scoped_release no_gil; +# return self.abs(); +# ~~~~~~~~~~~ <--- cpp_dispatch_target / cpp_dispatch_exprs +# generated from NativeFunction / CppSignature +# }; +# return wrap(dispatch_abs(_r.tensor(0))); +# ~~~~~~~~~~~~~ +# ^ +# +--- dispatch_lambda_exprs +# binding PythonArgParserOutputExpr (python args) +# and DispatchLambdaArgument (c++ args) +# +# } else { +# // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ^ +# +--- NativeFunction schema, out-variant +# +# auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor { +# pybind11::gil_scoped_release no_gil; +# return at::abs_out(out, self); +# }; +# return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0))); +# } +# +# +# [Notes] python interface codegen +# The python dataclasses below are used used to generate both python binding code +# and pyi type hint signatures. +# In theory these two should look very similar, but there are number of differences +# in how pyi signatures vs. python_arg_parser signatures are generated. +# These differences have been encapsulated in signature_str() vs. signature_str_pyi() +# to display the full signatures, and argument_str() vs argument_str_pyi() to display arguments. +# For examples, only pyi signatures include return types. + +@dataclass(frozen=True) +class PythonReturns: + returns: Tuple[Return, ...] + + def named_tuple_pyi(self) -> Optional[Tuple[str, str]]: + python_returns = [argument_type_str_pyi(r.type) for r in self.returns] + field_names = namedtuple_fieldnames(self.returns) + if field_names: + namedtuple_name = '_'.join(['namedtuple'] + field_names) + tuple_args = [f'("{name}", {typ})' for name, typ in zip(field_names, python_returns)] + namedtuple_def = f'NamedTuple("{namedtuple_name}", [{", ".join(tuple_args)}])' + return namedtuple_name, namedtuple_def + return None + + def returns_str_pyi(self) -> str: + named_tuple = self.named_tuple_pyi() + if named_tuple is not None: + namedtuple_name, _ = named_tuple + return namedtuple_name + + python_returns = [argument_type_str_pyi(r.type) for r in self.returns] + if len(python_returns) > 1: + return 'Tuple[' + ', '.join(python_returns) + ']' + if len(python_returns) == 1: + return python_returns[0] + return 'None' + + +@dataclass(frozen=True) +class PythonArgument: + name: str + type: Type + default: Optional[str] + + # Used to generate the default init expr for some PythonArgParser outputs, e.g.: + # + # _r.layoutWithDefault(3, layout_from_backend(self.options().backend()))) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # ^ + # +--- default_init str + default_init: Optional[str] + + # Compute argument formal for python argument parsing. + # Needs to be consistent with torch/csrc/utils/python_arg_parser.h. + def argument_str(self, *, method: bool = False) -> str: + type_str = argument_type_str(self.type).replace('const ', '').replace(' &', '') + + name = self.name + # s/self/input/ outside method bindings + # [old codegen] TODO: remove this? doesn't rename in codegen, it's just + # for the parse string + if name == 'self' and type_str == 'Tensor' and not method: + name = 'input' + + # add default + if self.default is not None: + default = { + 'nullptr': 'None', + 'c10::nullopt': 'None', + '{}': 'None', + }.get(self.default, self.default) + return f'{type_str} {name}={default}' + else: + return f'{type_str} {name}' + + def argument_str_pyi(self, *, method: bool = False, deprecated: bool = False) -> str: + type_str = argument_type_str_pyi(self.type) + + name = self.name + # s/self/input/ outside method bindings + # [old codegen] TODO: remove this? doesn't rename in codegen, it's just + # for the parse string + if name == 'self' and type_str == 'Tensor' and not method and not deprecated: + name = 'input' + + if name == 'from': # from is a Python keyword... + name += '_' + + # pyi merges the _out and functional variants into the same signature, with an optional out arg + if name == 'out' and type_str == 'Tensor' and not deprecated: + type_str = 'Optional[' + type_str + ']' + + # pyi deprecated signatures don't get defaults for their out arg + treat_as_no_default = deprecated and isinstance(self, PythonOutArgument) and self.default == 'None' + + # add default + if self.default is not None and not treat_as_no_default: + if isinstance(self.type, ListType) and self.type.elem == BaseType(BaseTy.int) and \ + self.default.startswith('{') and self.default.endswith('}'): + default = '(' + self.default[1:-1] + ')' + else: + default = { + 'nullptr': 'None', + 'c10::nullopt': 'None', + '{}': 'None', + 'MemoryFormat::Contiguous': 'contiguous_format', + 'QScheme::PER_TENSOR_AFFINE': 'per_tensor_affine', + }.get(self.default, self.default) + return f'{name}: {type_str}={default}' + else: + return f'{name}: {type_str}' + +@dataclass(frozen=True) +class PythonOutArgument(PythonArgument): + # In Python signature multiple output fields are packed into one 'out' argument. + # When binding to C++, it's first binded to a local 'out' variable: + # 'auto out = _r.tensorlist_n<2>(2);', + # then binded to scattered C++ output arguments as 'out[0]', 'out[1]', and etc. + # TODO: maybe don't need keep scattered out fields for python signature? + outputs: Tuple[PythonArgument, ...] + + @staticmethod + def from_outputs(outputs: Tuple[PythonArgument, ...]) -> Optional['PythonOutArgument']: + if not outputs: + return None + + size = len(outputs) + if size == 1: + return PythonOutArgument( + name=outputs[0].name, + type=outputs[0].type, + default='None', + default_init=None, + outputs=outputs, + ) + elif size > 1: + if any(map(lambda a: not a.type.is_tensor_like(), outputs)): + raise RuntimeError(f'Unsupported output type: {outputs}') + return PythonOutArgument( + name='out', + # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None? + type=ListType(BaseType(BaseTy.Tensor), size), + default='None', + default_init=None, + outputs=outputs, + ) + raise AssertionError(r'Unexpected PythonOutArgument size') + +@dataclass(frozen=True) +class PythonSignature: + # Base operator name, without inplace/outplace suffix. + name: str + + # Positional arguments. + # TODO: create a dedicated SelfArgument type for 'self'? + input_args: Tuple[PythonArgument, ...] + + # Keyword arguments excluding the 'out' argument and scattered kwargs belonging + # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc). + input_kwargs: Tuple[PythonArgument, ...] + + output_args: Optional[PythonOutArgument] + + # Return types, which are only used by pyi + returns: PythonReturns + + # These are scattered kwargs arguments belonging to TensorOptions. + # When binding to C++, they are packed into a TensorOptions object 'options'. + # It's possible that the C++ signature doesn't take TensorOptions object (e.g. + # for out variant), in which case they will be used as scattered fields without + # being packed into 'options'. + # TODO: maybe create a PythonTensorOptionsArgument? + tensor_options_args: Tuple[PythonArgument, ...] + + # method or function signature? + method: bool + + @property + def deprecated(self) -> bool: + return False + + def arguments( + self, *, skip_outputs: bool = False, skip_tensor_options: bool = False + ) -> Tuple[Union[PythonArgument, PythonOutArgument], ...]: + result: List[Union[PythonArgument, PythonOutArgument]] = [] + result.extend(self.input_args) + result.extend(self.input_kwargs) + if self.output_args is not None and not skip_outputs: + result.append(self.output_args) + if not skip_tensor_options: + result.extend(self.tensor_options_args) + return tuple(result) + + def arguments_count(self) -> int: + return len(self.arguments()) + + def output_idx(self) -> int: + return len(self.input_args) + len(self.input_kwargs) + + # [old codegen] Compute the Python function signature for argument parsing, + # as specified in torch/csrc/utils/python_arg_parser.h. WARNING: + # this is NOT the same type signature as specified by PEP 484 + # as understood by mypy; our format was independently developed + # and has some quirks to make it more suitable specifically + # for error parsing. + # + # For a translation to mypy-valid type signatures, see + # signature_str_pyi(). + def signature_str(self, *, skip_outputs: bool = False) -> str: + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: List[str] = list(map(lambda a: a.argument_str(method=self.method), args)) + positional_argc = len(self.input_args) + if len(schema_formals) > positional_argc: + schema_formals.insert(positional_argc, '*') + + return f'{self.name}({", ".join(schema_formals)})' + + def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: List[str] = list(map(lambda a: a.argument_str_pyi(method=self.method), args)) + positional_argc = len(self.input_args) + if len(schema_formals) > positional_argc: + schema_formals.insert(positional_argc, '*') + + # only pyi signatures include returns + returns_str = self.returns.returns_str_pyi() + # pyi also includes self (with no typing/defaults) for methods + if self.method: + schema_formals.insert(0, "self") + return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...' + + def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]: + # only pyi uses vararg signatures + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: List[str] = list(map(lambda a: a.argument_str_pyi(method=self.method), args)) + # vararg only applies to pyi signatures. vararg variants are not generated for all signatures + num_args = self.arguments_count() + num_positionalargs = len(self.input_args) + + have_vararg_version = False + if num_args > 0: + vararg_type = args[0].type + if isinstance(vararg_type, ListType) and str(vararg_type.elem) == 'int' and num_positionalargs == 1: + have_vararg_version = True + + if not have_vararg_version: + return None + # Below are the major changes in vararg vs. regular pyi signatures + # vararg signatures also omit the asterix + schema_formals[0] = '*' + args[0].name + ': _int' + + returns_str = self.returns.returns_str_pyi() + # pyi also includes self (with no typing/defaults) for methods + if self.method: + schema_formals.insert(0, "self") + return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...' + +# The deprecated python signature involves some special logic, so create a +# dedicated data model to store these extra properties. +@dataclass(frozen=True) +class PythonSignatureDeprecated(PythonSignature): + # We need keep the order of arguments in deprecated signature. + # Particularly, method signature might have 'self' not at the beginning, e.g.: + # addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2) + # When generating lambda function signature we need follow the exact order (even for method=True): + # [](Scalar beta, const Tensor & self, const Tensor & mat1, const Tensor & mat2) -> Tensor + deprecated_args_names: Tuple[str, ...] + + # The deprecated signature might miss some arguments that the corresponding + # C++ signature expects. We need store the constant default values to pass in. + # For example: + # [deprecate signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2) + # [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + # [func call]: self.addmm(mat1, mat2, beta, 1) + # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case. + deprecated_args_exprs: Tuple[str, ...] + + @property + def deprecated(self) -> bool: + return True + + def signature_str(self, *, skip_outputs: bool = False) -> str: + return PythonSignature.signature_str(self, skip_outputs=skip_outputs) + '|deprecated' + + def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: List[str] = list(map(lambda a: a.argument_str_pyi(method=self.method, deprecated=True), args)) + positional_argc = len(self.input_args) + if len(schema_formals) > positional_argc: + schema_formals.insert(positional_argc, '*') + + returns_str = self.returns.returns_str_pyi() + return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...' + + def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]: + # the codegen doesn't include vararg variants for deprecated signatures + return None + +# This struct is used to hold the PythonSignature and its corresponding +# NativeFunction BEFORE grouping base and out-variant functions. +# Why not store NativeFunction in PythonSignature or construct PythonSignature +# from NativeFunction? Because they are not 1-1 mapped. +# One native function could have both deprecated and non-deprecated python +# signatures - NativeFunction doesn't contain information to construct the +# deprecated python signature. +# One python signature is used to handle both the base and the out-variant +# function - see 'PythonSignatureGroup'. +@dataclass(frozen=True) +class PythonSignatureNativeFunctionPair: + signature: PythonSignature + function: NativeFunction + +# We merge pairs of functions with signatures that are equivalent mod +# output arguments, and use a single entry in the python_arg_parser sig +# list for both (output arguments become optional). +@dataclass(frozen=True) +class PythonSignatureGroup: + # The signature used for Python argument parsing. The outplace signature + # is preferred if exists, because it can be used to parse inputs for both + # the out-place variant and the base version (with output omitted). + signature: PythonSignature + + # The regular ATen declaration (e.g. conv2d) + base: NativeFunction + + # The out variant (e.g. conv2d_out) + outplace: Optional[NativeFunction] + +# C++ function dispatch is wrapped in a lambda function. The lambda function +# has almost the same signature as the C++ function, only with some small +# variants - see details below. +# This data model is used to represent arguments of the lambda function +# signature. +@dataclass(frozen=True) +class DispatchLambdaArgument: + name: str + type_str: str + is_out_arg: bool + +# To pass PyObjects arguments to C++ function (via the lambda wrapper), +# we need first convert PyObjects into simple C++ objects. This work +# is done by PythonArgParser. +# This data model is used to represent the output of PythonArgParser. +# It has 1-1 mapping with PythonArgument in PythonSignature. +@dataclass(frozen=True) +class PythonArgParserOutputExpr: + # argument name + name: str + + # RHS expression to reference PythonArgParser output. + expr: str + + # In some special cases we need create different expr, e.g.: + # '_r.isNone(1)' instead of '_r.tensor(1)'. + index: int + + # The python argument it maps to. + argument: PythonArgument + + @property + def is_none_expr(self) -> str: + return f'_r.isNone({self.index})' + +# To pass PythonArgParser output to the lambda wrapper, we need bind +# PythonArgParserOutputExpr to DispatchLambdaArgument. +# They are not always 1-1 mapped, e.g. scattered TensorOptions fields +# need be packed into a TensorOptions object, which is the argument +# that the lambda function wrapper takes. +@dataclass(frozen=True) +class DispatchLambdaArgumentExprs: + # The exprs that provide the binding for lambda arguments, e.g.: + # + # 'self' -> '_r.tensor(0)' + # 'min' -> 'out[0]' / 'min_indices' -> 'out[1]' + # 'options' -> 'options' + # + # It has 1-1 mapping with DispatchLambdaArgument. + exprs: Sequence[str] + + # Special local inits, which might introduce new variables that + # the 'exprs' above reference, e.g.: + # + # 'auto out = _r.tensorlist_n<2>(2);' + # + inits: Sequence[str] + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Helper Functions +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature: + return CppSignatureGroup.from_native_function(f, method=method).signature + +def has_tensor_options(f: NativeFunction) -> bool: + return f.func.arguments.tensor_options is not None + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Python Signature +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +# 'simple_type' was introduced by the old codegen, which is slightly +# different from the python schema type, e.g.: doesn't have '?' suffix +# for optional Tensor/TensorList; doesn't have '[size]' suffix for list type. +def argument_type_str(t: Type, *, simple_type: bool = False) -> str: + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor: + return 'Tensor' + elif t.name == BaseTy.int: + return 'int64_t' + elif t.name == BaseTy.float: + return 'double' + elif t.name == BaseTy.str: + return 'std::string' + elif t.name in [BaseTy.bool, BaseTy.QScheme, BaseTy.Scalar, + BaseTy.ScalarType, BaseTy.Generator, BaseTy.Storage, + BaseTy.Layout, BaseTy.Device, BaseTy.MemoryFormat, + BaseTy.Dimname, BaseTy.Stream, BaseTy.ConstQuantizerPtr]: + # These python schema type names line up with their function schema names + return t.name.name + + elif isinstance(t, OptionalType): + if str(t.elem) == 'Tensor': + # Is it desired to keep '?' for simple_type with new style dispatcher? + return 'Tensor?' + elem = argument_type_str(t.elem, simple_type=simple_type) + if elem == 'Layout': + # TODO: fix this special case in PythonArgParser? + return 'Layout' + else: + return f'{elem}?' + + elif isinstance(t, ListType): + size = t.size if not simple_type else None + if str(t.elem) == 'bool': + assert t.size is not None + return f'std::array' + elif str(t.elem) == 'int': + return f'IntArrayRef[{size}]' if size is not None else 'IntArrayRef' + elif str(t.elem) == 'Tensor': + return f'TensorList[{size}]' if size is not None else 'TensorList' + elif str(t.elem) == 'Scalar': + return f'ScalarList[{size}]' if size is not None else 'ScalarList' + elif str(t.elem) == 'Tensor?': + if simple_type: + return 'c10::List>' + else: + return 'const c10::List> &' + elif str(t.elem) == 'Dimname': + return f'DimnameList[{size}]' if size is not None else 'DimnameList' + elem = argument_type_str(t.elem, simple_type=simple_type) + return f'ArrayRef<{elem}>' + + raise RuntimeError(f'unrecognized type {repr(t)}') + +def argument_type_size(t: Type) -> Optional[int]: + l = t.is_list_like() + if l is not None and str(l.elem) != 'bool': + return l.size + else: + return None + +def argument(a: Argument) -> PythonArgument: + return PythonArgument( + name=a.name, + type=a.type, + # TODO: directly translate a.default to python default + default=str(pythonify_default(cpp.default_expr(a.default, a.type))) + if a.default is not None else None, + default_init=None, + ) + +# Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen +def signature(f: NativeFunction, *, method: bool = False, pyi: bool = False) -> PythonSignature: + args: List[Argument] = [] + args.extend(f.func.arguments.pre_self_positional) + # Skip SelfArgument if this is method. + if not method and f.func.arguments.self_arg is not None: + args.append(f.func.arguments.self_arg.argument) + args.extend(f.func.arguments.post_self_positional) + args.extend(f.func.arguments.pre_tensor_options_kwarg_only) + # Skip TensorOptionsArguments. Python side TensorOptions + # arguments are created based on different rules - see below. + args.extend(f.func.arguments.post_tensor_options_kwarg_only) + args.extend(f.func.arguments.out) + + input_arg_set = set(a.name for a in f.func.arguments.flat_positional) + kwarg_only_set = set(a.name for a in f.func.arguments.flat_kwarg_only) + out_arg_set = set(a.name for a in f.func.arguments.out) + + input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args))) + input_kwargs = tuple(map(argument, filter(lambda a: a.name in kwarg_only_set, args))) + outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, args))) + + # Reintroduce the scattered fields of TensorOptions for Python. + # Compared to the cpp counterpart, the python arguments have new property + # (default_init) and a new argument 'requires_grad', which require some + # special handlings. + # [old codegen] TODO: because these aren't guaranteed to be 100% faithful + # to the original versions in the yaml, this recreation is a potential + # source of drift between eager and JIT. Pull this logic out to a shared place. + + has_tensor_input_arg = any(a.type.is_tensor_like() for a in f.func.arguments.flat_non_out) + if any(a.name == 'requires_grad' for a in f.func.schema_order_arguments()): + raise ValueError('argument named requires_grad is reserved, should not explicitly add it in the schema') + + # [old codegen] this probably won't work if one of the returns is not a tensor, + # but it will produce a compile-time error that is obvious. + has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns) + + name: str = cpp.name(f.func) + is_factory_function = f.category_override == 'factory' or (has_tensor_return and not has_tensor_input_arg) + is_like_or_new_function = f.category_override in ('new', 'like') or name.startswith('new_') or name.endswith('_like') + + tensor_options_args: List[PythonArgument] = [] + if is_factory_function or is_like_or_new_function: + tensor_options_args.append(PythonArgument( + name='dtype', + type=BaseType(BaseTy.ScalarType), + default='None' if pyi else _dtype_default_type_hack(name), + default_init='self.scalar_type()' if is_like_or_new_function else None, + )) + tensor_options_args.append(PythonArgument( + name='layout', + type=OptionalType(BaseType(BaseTy.Layout)), + default='strided' if pyi else 'torch.strided', + default_init='layout_from_backend(self.options().backend())' if is_like_or_new_function else None, + )) + tensor_options_args.append(PythonArgument( + name='device', + type=BaseType(BaseTy.Device), + default='None', + default_init='self.device()' if is_like_or_new_function else None, + )) + tensor_options_args.append(PythonArgument( + name='pin_memory', + type=BaseType(BaseTy.bool), + default='False', + default_init=None, + )) + tensor_options_args.append(PythonArgument( + name='requires_grad', + type=BaseType(BaseTy.bool), + default='False', + default_init=None, + )) + + returns = PythonReturns(returns=f.func.returns) + + return PythonSignature( + name=str(f.func.name.name), + input_args=input_args, + input_kwargs=input_kwargs, + output_args=PythonOutArgument.from_outputs(outputs), + tensor_options_args=tuple(tensor_options_args), + returns=returns, + method=method, + ) + +# TODO blowtorch +# note: removing this will be BC-breaking. A quick test shows that +# randperm will otherwise default its dtype to torch.float64 +def _dtype_default_type_hack(name: str) -> str: + if name.startswith('randperm') or name == 'tril_indices' or name == 'triu_indices': + return 'torch.int64' + else: + return 'None' +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Python Interface +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]: + if len(returns) <= 1 or all(map(lambda r: r.name is None, returns)): + return [] + else: + if any(map(lambda r: r.name is None, returns)): + # When building on Windows, `PyStructSequence_UnnamedField` could not be + # resolved by the linker for some reason, which cause error in building: + # + # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol + # PyStructSequence_UnnamedField + # + # Thus, at this point in time, we do not support unnamed + # fields in namedtuple; you must either name all fields, + # or none of them. + raise ValueError("Unnamed field is not supported by codegen") + + return list(map(lambda r: str(r.name), returns)) + +def argument_type_str_pyi(t: Type) -> str: + add_optional = False + if isinstance(t, OptionalType): + t = t.elem + add_optional = True + + if isinstance(t, BaseType): + if t.name == BaseTy.int: + ret = '_int' + elif t.name == BaseTy.float: + ret = '_float' + elif t.name == BaseTy.str: + ret = 'str' + elif t.name == BaseTy.Scalar: + ret = 'Number' + elif t.name == BaseTy.ScalarType: + ret = '_dtype' + elif t.name == BaseTy.bool: + ret = '_bool' + elif t.name == BaseTy.QScheme: + ret = '_qscheme' + elif t.name == BaseTy.Layout: + ret = '_layout' + elif t.name == BaseTy.Device: + ret = 'Union[_device, str, None]' + elif t.name == BaseTy.MemoryFormat: + ret = 'memory_format' + elif t.name == BaseTy.Dimname: + ret = 'Union[str, ellipsis, None]' + elif t.name in [BaseTy.Tensor, BaseTy.Generator, + BaseTy.Storage, BaseTy.Stream, BaseTy.str]: + # These python schema type names line up with their function schema names + ret = t.name.name + + elif isinstance(t, ListType): + if str(t.elem) == 'int': + ret = 'Union[_int, _size]' if t.size is not None else '_size' + elif t.is_tensor_like(): + # TODO: this doesn't seem right... + # Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]] + # It should probably translate to Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]] + if isinstance(t.elem, OptionalType): + add_optional = True + ret = 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]' if t.size is not None else \ + 'Union[Tuple[Tensor, ...], List[Tensor]]' + elif str(t.elem) == 'float': + ret = 'Sequence[float]' + else: + elem = argument_type_str_pyi(t.elem) + ret = f'Sequence[{elem}]' + + if add_optional: + ret = 'Optional[' + ret + ']' + return ret + + raise RuntimeError(f'unrecognized type {repr(t)}') + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# C++ Function Dispatch +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# This section provides APIs to generate the code that does C++ function +# dispatch. The C++ function call is wrapped by a lambda function. +# For example: +# +# // aten::selu_(Tensor(a!) self) -> Tensor(a!) +# auto dispatch_selu_ = [](Tensor self) -> Tensor { +# pybind11::gil_scoped_release no_gil; +# return at::selu_(self); +# }; +# +# The lambda function's signature follows the C++ signature in common +# cases, e.g.: +# +# // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor +# [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor +# +# For out variant the 'out' argument's type is changed from 'Tensor &' +# to 'Tensor'. It's because when calling the lambda it passes in the +# PythonArgParser output '_r.tensor(3)', which is stack allocated object +# and needs to pass by value. Also see comments in 'dispatch_lambda_return_str()'. +# +# // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +# [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor +# +# For multi-output case it can keep using reference type because the +# PythonArgParser output has been unpacked to local variables, e.g.: +# +# // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, +# // Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) +# [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple +# +# For deprecated python signature, it should follow deprecated python arg order. +# TODO: This is to keep same byte-for-byte result as the old codegen - maybe unnecessary? + +def dispatch_lambda_args(ps: PythonSignature, f: NativeFunction) -> Tuple[DispatchLambdaArgument, ...]: + # Start with cpp arguments - dispatch lambda signature always include 'self' + cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments() + + # Special reorder logic for deprecated python signature + if isinstance(ps, PythonSignatureDeprecated): + m: Dict[str, Binding] = dict((a.name, a) for a in cpp_args) + # reorder according to the deprecated signature + # ignore 'out' argument when binding to non-output function. + ordered_args = filter(lambda n: n != 'out' or f.func.is_out_fn(), + ps.deprecated_args_names) + cpp_args = list(map(lambda n: m[n], ordered_args)) + + out_args: Set[str] = set(a.name for a in f.func.arguments.out) + + # Convert from cpp argument to lambda argument + def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument: + type_str = cpp_arg.type + is_out_arg = cpp_arg.name in out_args + if ps.method and cpp_arg.name == 'self': + # For method's 'self', we can use 'Tensor &' and simply ignore mutability! + type_str = 'Tensor &' + else: + # For other cases we need prevent dangling refs to temps (unless it's + # unpacked scattered output) + # The reason is explained in the comments above and in 'dispatch_lambda_return_str()'. + # TODO: avoid this special handling? + ensure_temp_safe = len(out_args) <= 1 or not is_out_arg + if ensure_temp_safe: + type_str = { + 'Tensor &': 'Tensor', + }.get(type_str, type_str) + return DispatchLambdaArgument( + name=cpp_arg.name, + type_str=type_str, + is_out_arg=is_out_arg, + ) + + return tuple(map(dispatch_lambda_arg, cpp_args)) + +# [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean +# it's enough to just extend the list here. Before you do this, make sure +# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h. +SUPPORTED_RETURN_TYPES = { + 'Tensor', + 'std::tuple', + 'std::tuple', + 'std::tuple', + 'std::tuple', + 'std::tuple', + 'std::tuple', + 'std::tuple', + 'std::tuple', + 'std::tuple', + 'std::vector', + 'Scalar', 'bool', 'int64_t', 'void*', 'void', + 'QScheme', 'double', + 'IntArrayRef', + 'ScalarType' +} + +def dispatch_lambda_return_str(f: NativeFunction) -> str: + # [old codegen] Remove type annotation (e.g. 'Tensor' rather than 'Tensor &') + # because the dispatch lambdas take mutable arguments *by value*, not + # by reference. If you then return a reference to such an argument, you + # will now have a pointer to a dangling stack entry. Not good. + # + # You want: + # + # auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); }; + # ^^^^^^ + # + # *not* + # + # auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); }; + # ^^^^^^^ + # + # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing + # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a + # mutable reference to temporary. Maybe we could assign it to a + # variable itself.) + returns_without_annotation = tuple(map(lambda r: Return(r.name, r.type, None), f.func.returns)) + return_str = cpp.returns_type(returns_without_annotation) + if return_str not in SUPPORTED_RETURN_TYPES: + raise RuntimeError(f'{f.func.name} returns unsupported type {return_str}') + return return_str + +def cpp_dispatch_target(f: NativeFunction) -> str: + name = cpp.name(f.func) + if Variant.method in f.variants: + return f'self.{name}' + if Variant.function in f.variants: + if has_tensor_options(f) or f.func.name.name.base.endswith('_like'): + namespace = 'torch' + else: + namespace = 'at' + return f'{namespace}::{name}' + raise RuntimeError(f'could not dispatch, neither function nor method: {f.func}') + +def cpp_dispatch_exprs(f: NativeFunction, *, + python_signature: Optional[PythonSignature] = None, + ) -> Tuple[str, ...]: + cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments() + + exprs: Tuple[str, ...] = tuple() + if not isinstance(python_signature, PythonSignatureDeprecated): + # By default the exprs are consistent with the C++ signature. + exprs = tuple(map(lambda a: a.name, cpp_args)) + else: + # For deprecated python signature we may need fill in some constants. + exprs = tuple(filter(lambda n: n != 'out' or f.func.is_out_fn(), + python_signature.deprecated_args_exprs)) + + if Variant.method in f.variants: + exprs = tuple(filter('self'.__ne__, exprs)) + + return exprs + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Python / C++ Args Binding +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +# We explicitly enumerate the PythonArgParser unpacking methods for all +# supported types. This might be more verbose than necessary, partially +# because of the irregularity of unpacking method naming, partially +# because we want to mimic the old codegen behavior - to reject +# unexpected and/or unsupported cases which the old codegen rejects. +# For certain cases it is intentionally more restrictive than necessary, +# e.g.: it doesn't accepts doublelist with definite size. +def arg_parser_unpack_method(t: Type, has_default: bool) -> str: + if has_default and str(t) not in ('ScalarType', 'Device', 'Layout?'): + raise RuntimeError(f'type \'{t}\' does not supported unpacking with default') + + if isinstance(t, BaseType): + if t.name in [BaseTy.Tensor, BaseTy.Stream, BaseTy.Storage, + BaseTy.Scalar, BaseTy.Dimname]: + # These unpack methods line up with their schema names + return t.name.name.lower() + elif t.name == BaseTy.ScalarType: + return 'scalartypeWithDefault' if has_default else 'scalartype' + elif t.name == BaseTy.Device: + return 'deviceWithDefault' if has_default else 'device' + elif t.name == BaseTy.int: + return 'toInt64' + elif t.name == BaseTy.bool: + return 'toBool' + elif t.name == BaseTy.float: + return 'toDouble' + elif t.name == BaseTy.str: + return 'string' + + elif isinstance(t, OptionalType): + if str(t.elem) == 'Tensor': + return 'optionalTensor' + + elif isinstance(t.elem, BaseType): + if t.elem.name in [BaseTy.ScalarType, BaseTy.Scalar, + BaseTy.int, BaseTy.bool, + BaseTy.float, BaseTy.str]: + # Regular cases: append 'Optional' to elem's unpacking method + return arg_parser_unpack_method(t.elem, False) + 'Optional' + elif t.elem.name == BaseTy.MemoryFormat: + return 'memoryformatOptional' + elif t.elem.name == BaseTy.Generator: + return 'generator' + elif t.elem.name == BaseTy.Layout: + return 'layoutWithDefault' if has_default else 'layoutOptional' + + elif isinstance(t.elem, ListType): + if str(t.elem.elem) == 'int': + # accept definite size + return 'intlistOptional' + elif str(t.elem) == 'float[]': + return 'doublelistOptional' + elif str(t.elem) == 'Dimname[]': + return 'toDimnameListOptional' + + elif isinstance(t, ListType): + if str(t.elem) == 'Tensor': + # accept and use definite size + if t.size is not None: + return f'tensorlist_n<{t.size}>' + else: + return 'tensorlist' + elif str(t.elem) == 'Tensor?': + return 'list_of_optional_tensors' + elif str(t.elem) == 'Dimname': + # accept definite size + return 'dimnamelist' + elif str(t.elem) == 'int': + # accept definite size + return 'intlist' + elif str(t) == 'float[]': + return 'doublelist' + elif str(t) == 'Scalar[]': + return 'scalarlist' + raise RuntimeError(f'type \'{t}\' is not supported by PythonArgParser') + +# Return RHS expression for python argument using PythonArgParser output. +# e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)' +def arg_parser_output_expr( + arg_index: int, a: PythonArgument +) -> PythonArgParserOutputExpr: + has_default = a.default_init is not None + unpack_method = arg_parser_unpack_method(a.type, has_default) + default = f', {a.default_init}' if has_default else '' + expr = f'_r.{unpack_method}({arg_index}{default})' + + return PythonArgParserOutputExpr( + name=a.name, + expr=expr, + index=arg_index, + argument=a, + ) + +# Returns a map with key = arg_name and value = PythonArgParserOutputExpr. +def arg_parser_output_exprs( + ps: PythonSignature, f: NativeFunction +) -> Dict[str, PythonArgParserOutputExpr]: + return {e.name: e for i, a in enumerate(ps.arguments()) + for e in (arg_parser_output_expr(i, a), )} + +# argument name to type for scattered tensor options fields +TENSOR_OPTIONS_FIELDS = { + 'dtype': 'ScalarType', + 'device': 'Device', + 'layout': 'Layout?', + 'pin_memory': 'bool', + 'requires_grad': 'bool', +} + +# bind arg parser outputs (python args) with dispatch lambda arguments (c++ args). +def dispatch_lambda_exprs( + ps: PythonSignature, f: NativeFunction +) -> DispatchLambdaArgumentExprs: + # This method is to bind 'arg_parser_outputs' and 'lambda_args' by producing + # 'inits' and 'lambda_args_exprs' for each lambda argument using arg parser + # outputs. + arg_parser_outputs = arg_parser_output_exprs(ps, f) + lambda_args = dispatch_lambda_args(ps, f) + inits: List[str] = [] + lambda_args_exprs: Dict[str, str] = dict() + + has_toptions = has_tensor_options(f) + + # 1. special inits/unpacking to provide binding exprs for lambda arguments. + for a in ps.arguments(skip_tensor_options=True): + name = a.name + arg_parser_expr = arg_parser_outputs[a.name].expr + + if has_toptions and name == 'self': + # TODO: why this needs to be special case? + inits.extend([ + f'auto self = {arg_parser_expr};', + ]) + lambda_args_exprs[name] = name + elif isinstance(a, PythonOutArgument) and len(a.outputs) > 1 and f.func.is_out_fn(): + inits.extend([ + f'auto out = {arg_parser_expr};', + ]) + for i, out_arg in enumerate(a.outputs): + lambda_args_exprs[out_arg.name] = f'out[{i}]' + elif str(a.type) == 'Dimname[]?': + # [old codegen] + # TODO: make this part of something more general, or get rid of it. + # optional> are special. The PythonArgParser returns an + # optional>, which cannot be implicitly converted to + # optional>. One needs to unwrap the optional and rewrap. + inits.extend([ + f'auto __{name} = {arg_parser_expr};', + f'c10::optional {name} = __{name} ? c10::make_optional(DimnameList(__{name}.value())) : c10::nullopt;', + ]) + lambda_args_exprs[name] = name + else: + # default case - directly using PythonArgParser output expr + lambda_args_exprs[name] = arg_parser_expr + + # method's self is passed directly to python binding, rather than parsed + if ps.method: + lambda_args_exprs['self'] = 'self' + + # 2. special packing/checking for TensorOptions. + tensor_options_args_names = list(map(lambda a: a.name, ps.tensor_options_args)) + if has_toptions: + if f.func.is_out_fn(): + raise RuntimeError(f'{f.func}: tensor options with output arg') + for a in ps.tensor_options_args: + if a.name not in TENSOR_OPTIONS_FIELDS: + raise RuntimeError( + f'{f.func}: unrecognized tensor options field \'{a.name}\' in python binding arguments') + if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name): + raise RuntimeError( + f'{f.func}: unrecognized type \'{str(a.type)}\' for tensor options field \'{a.name}\'') + if not all(map(lambda a: a in tensor_options_args_names, TENSOR_OPTIONS_FIELDS.keys())): + raise RuntimeError( + f'{f.func}: incomplete tensor options args: {tensor_options_args_names}') + + inits.append(f'''\ +const auto options = TensorOptions() + .dtype({arg_parser_outputs['dtype'].expr}) + .device({arg_parser_outputs['device'].expr}) + .layout({arg_parser_outputs['layout'].expr}) + .requires_grad({arg_parser_outputs['requires_grad'].expr}) + .pinned_memory({arg_parser_outputs['pin_memory'].expr}); +torch::utils::maybe_initialize_cuda(options); +''') + lambda_args_exprs['options'] = 'options' + + # 3. special case - access scattered TensorOptions fields without packing + # TODO: maybe move to the generator side as it's not related to binding. + if not has_toptions and tensor_options_args_names: + if 'dtype' in tensor_options_args_names: + # we're an output-arg variant, check these args against output tensor + if not f.func.is_out_fn(): + raise RuntimeError( + f'{f.func}: dtype in tensor_options_args without output arg') + if not all(map(lambda a: a in tensor_options_args_names, ('layout', 'device'))): + raise RuntimeError( + f'{f.func}: incomplete tensor options for output check') + + inits.append(f"""\ +check_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dtype'].expr}, + {arg_parser_outputs['dtype'].is_none_expr}, {arg_parser_outputs['layout'].expr}, + {arg_parser_outputs['device'].expr}, {arg_parser_outputs['device'].is_none_expr}); +""") + # we'll set requires_grad on outgoing tensor + if 'requires_grad' not in tensor_options_args_names: + raise RuntimeError( + f'{f.func}: expected "requires_grad" in tensor_options_args absent, but found [{tensor_options_args_names}]') + + return DispatchLambdaArgumentExprs( + exprs=tuple(map(lambda a: lambda_args_exprs[a.name], lambda_args)), + inits=inits, + ) diff --git a/tools/codegen/api/translate.py b/tools/codegen/api/translate.py new file mode 100644 index 0000000000000..1b612080d8b64 --- /dev/null +++ b/tools/codegen/api/translate.py @@ -0,0 +1,143 @@ +from typing import Dict, Sequence, List, NoReturn +from tools.codegen.api.types import * + +# This file implements a small program synthesis engine that implements +# conversions between one API to another. +# +# The key data type in this file in CType, short for C++ semantic type. A CType +# represents a C++ type, plus semantic information about what it represents. +# For example, consider the argument "bool pin_memory"; its normal C++ type is +# "bool", but its C++ semantic type also keeps track that this represents a +# "pin_memory"; you can't just use a random other boolean in a context where you +# need a "pin_memory"! +# +# The translator takes a list of needed CTypes, and then figures out how +# to construct expressions with these CTypes from the given bindings. Many +# of these expressions are trivial (I need a Tensor other; there's a Tensor +# other scope); others are more nontrivial and may require packing/unpacking. +# Some examples of non-trivial action: +# +# - Need the "dtype" binding? Well, maybe "dtype" isn't available +# in the context, instead, "options" is, and you need to extract +# it from there. (Gather) +# +# - Need the "context" binding? Well, maybe "context" isn't available +# in the context, and you need to construct it from "dtype", "device", +# etc. (Scatter) +# +# - Need the "memory_format" binding? Well, actually, it's available +# from both "memory_format" and "options", so you had better make sure +# they are consistent. (Join) + +options_ctype = ConstRefCType(BaseCType("TensorOptions", "options")) + +class UnsatError(RuntimeError): + pass + +# Given a set of in-scope bindings and a set of target bindings, synthesize +# a list of expressions that uses only the in-scope bindings (bindings) that +# have all of the types of goals. You may want to use this function if +# you're generating code for a function like: +# +# void f({args}) { +# g({exprs}); // g is a different API +# } +# +# and you need to generate "exprs". +# +# TODO: Don't need full Binding for goals, CType will do +# TODO: Don't need full Binding for bindings, list of Expr will do +def translate(bindings: Sequence[Binding], goals: Sequence[Binding], *, method: bool = False) -> List[Expr]: + # Add all the bindings to the context + ctx: Dict[CType, str] = {} + for b in bindings: + ctx[b.ctype] = b.name + + # Add implicit bindings if the generated code is inside a Tensor method + if method: + ctx[MutRefCType(BaseCType("Tensor", "self"))] = "const_cast(*this)" + ctx[ConstRefCType(BaseCType("Tensor", "self"))] = "const_cast(*this)" + # This is better! Byte-for-byte compat + # ctx[ConstRefCType(BaseCType("Tensor", "self"))] = "*this" + + def unsat(goal: CType) -> NoReturn: + ctx_desc = '\n'.join(f" {t.cpp_type()} {e};" for t, e in ctx.items()) + raise UnsatError(f''' +Failed to synthesize the expression "{goal.cpp_type()} {goal.name}" +while trying to translate from: + + from_func({', '.join(b.defn() for b in bindings)}) + +to: + + to_func({', '.join(g.defn() for g in goals)}) + +When I failed, the following bindings were available in the context: + +{ctx_desc} + +This probably means there is a missing rule in the rules of tools.codegen.api.translate. +Check this module for more information. +''') + + # A shitty backtracking search implementation. It's shitty because it + # doesn't actually do backtracing or search. In particular, if + # direct=True, we won't try to do any fancy synthesis, just trivial + # conversions (e.g., "T a" is OK for "const T& a"). So all of the + # existing rules in this function simply try to solve immediately, + # and bail if things don't work out. + def solve(goal: CType, *, direct: bool) -> str: + def direct_solve(goal: CType) -> str: + return solve(goal, direct=True) + + if goal in ctx: + # Trivial + return ctx[goal] + + # If the goal is a const&, try solving for the value type first + if isinstance(goal, ConstRefCType): + try: + return solve(goal.elem, direct=direct) + except UnsatError: + pass + + if direct: + unsat(goal) + + # For now, all of these rules are mutually exclusive. + if goal == OptionalCType(BaseCType("MemoryFormat", "memory_format")): + memory_format = direct_solve( + OptionalCType(BaseCType("MemoryFormat", SpecialArgName.possibly_redundant_memory_format)) + ) + try: + options = direct_solve(options_ctype) + return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})" + except UnsatError: + return memory_format + + elif goal == BaseCType("TensorOptions", "options"): + dtype = direct_solve(OptionalCType(BaseCType("ScalarType", "dtype"))) + pin_memory = direct_solve(OptionalCType(BaseCType("bool", "pin_memory"))) + device = direct_solve(OptionalCType(BaseCType("Device", "device"))) + layout = direct_solve(OptionalCType(BaseCType("Layout", "layout"))) + return f'TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})' + + elif goal == OptionalCType(BaseCType("ScalarType", "dtype")): + options = direct_solve(options_ctype) + return f'optTypeMetaToScalarType({options}.dtype_opt())' + + elif goal == OptionalCType(BaseCType("Layout", "layout")): + options = direct_solve(options_ctype) + return f'{options}.layout_opt()' + + elif goal == OptionalCType(BaseCType("Device", "device")): + options = direct_solve(options_ctype) + return f'{options}.device_opt()' + + elif goal == OptionalCType(BaseCType("bool", "pin_memory")): + options = direct_solve(options_ctype) + return f'{options}.pinned_memory_opt()' + + unsat(goal) + + return [Expr(solve(g.ctype, direct=False), g.ctype) for g in goals] diff --git a/tools/codegen/api/types.py b/tools/codegen/api/types.py index cb315cfc7525b..39fb8bef38464 100644 --- a/tools/codegen/api/types.py +++ b/tools/codegen/api/types.py @@ -1,95 +1,273 @@ from tools.codegen.model import * from dataclasses import dataclass -from typing import Optional, Union, Sequence +from typing import Optional, Union, Sequence, TypeVar, List, Set +from enum import Enum + +_T = TypeVar('_T') + +# An ArgName is just the str name of the argument in schema; +# but in some special circumstances, we may add a little extra +# context. The Enum SpecialArgName covers all of these cases; +# grep for their construction sites to see when they can occr. + +SpecialArgName = Enum('SpecialArgName', ( + 'possibly_redundant_memory_format', +)) +ArgName = Union[str, SpecialArgName] + +# A CType is short for C++ semantic type. A CType represents a C++ type, plus +# semantic information about what it represents. For example, consider the +# argument "bool pin_memory"; its normal C++ type is "bool", but its C++ +# semantic type also keeps track that this represents a "pin_memory"; you can't +# just use a random other boolean in a context where you need a "pin_memory"! +# +# CTypes encode C++ type structure as needed for translation. Right now we +# track references and optional, but don't, for example, track ArrayRef. If +# you need trnsnlations that know about these types, beef up this data +# structure. -# Represents the implicit *this argument for method calls in C++ API @dataclass(frozen=True) -class ThisArgument: - argument: Argument +class BaseCType: + type: str + name: ArgName + + def cpp_type(self, *, strip_ref: bool = False) -> str: + return self.type -# Bundle of arguments that represent a TensorOptions in the C++ API. @dataclass(frozen=True) -class TensorOptionsArguments: - dtype: Argument - layout: Argument - device: Argument - pin_memory: Argument +class ConstRefCType: + elem: 'CType' - def all(self) -> Sequence[Argument]: - return [self.dtype, self.layout, self.device, self.pin_memory] + def cpp_type(self, *, strip_ref: bool = False) -> str: + if strip_ref: + return self.elem.cpp_type(strip_ref=strip_ref) + return f'const {self.elem.cpp_type()} &' + + @property + def name(self) -> ArgName: + return self.elem.name -# Describe a argument (e.g., the x in "f(int x)") in the C++ API @dataclass(frozen=True) -class CppArgument: - # C++ type, e.g., int - type: str - # C++ name, e.g., x +class MutRefCType: + elem: 'CType' + + def cpp_type(self, *, strip_ref: bool = False) -> str: + if strip_ref: + return self.elem.cpp_type(strip_ref=strip_ref) + return f'{self.elem.cpp_type()} &' + + @property + def name(self) -> ArgName: + return self.elem.name + +@dataclass(frozen=True) +class OptionalCType: + elem: 'CType' + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f'c10::optional<{self.elem.cpp_type()}>' + + @property + def name(self) -> ArgName: + return self.elem.name + +CType = Union[BaseCType, OptionalCType, ConstRefCType, MutRefCType] + +# A binding represents any C++ binding site for a formal parameter. +# We don't distinguish between binding sites for different APIs; +# instead, all of the important distinctions are encoded in CType, +# which you can use to figure out if a given Binding is appropriate +# for use in another context. (See tools.codegen.api.translate) + +@dataclass(frozen=True) +class Binding: name: str - # Only used by the header, but we work it out in all cases anyway - default: Optional[str] - # The JIT argument(s) this formal was derived from. May - # correspond to multiple arguments if this is TensorOptions! - # May also correspond to the implicit *this argument! - argument: Union[Argument, TensorOptionsArguments, ThisArgument] - - # Default string representation prints the most elaborated form - # of the formal - def __str__(self) -> str: + ctype: CType + argument: Union[Argument, TensorOptionsArguments, SelfArgument] + # TODO: maybe don't represent default here + default: Optional[str] = None + + @property + def type(self) -> str: + return self.ctype.cpp_type() + + def no_default(self) -> 'Binding': + return Binding( + name=self.name, + ctype=self.ctype, + default=None, + argument=self.argument, + ) + + def decl(self) -> str: mb_default = "" if self.default is not None: mb_default = f"={self.default}" return f"{self.type} {self.name}{mb_default}" - # However, you might also find the version with no default useful - def str_no_default(self) -> str: + def defn(self) -> str: return f"{self.type} {self.name}" +# An Expr is a C++ expression. It has a C++ string representing its syntax, +# as well as a CType saying what it provides. + @dataclass(frozen=True) -class CppExpr: - type: str +class Expr: expr: str + type: CType +# A CppSignature represents a single overload in the C++ API. For +# any given function schema, there may be multiple CppSignatures +# corresponding to it, based on how we desugar to C++. See also +# CppSignatureGroup. @dataclass(frozen=True) -class DispatcherExpr: - type: str - expr: str +class CppSignature: + # The schema this signature is derived from + func: FunctionSchema + + # Is this a C++ signature for a method, i.e. Tensor::my_op(...)? + method: bool + + # Is this a faithful C++ signature (i.e. following the JIT schema) or a convenience API + # (i.e. with a potential TensorOptions argument and out arguments in the front) + faithful: bool + + # The set of C++ arguments which should not have defaults applied to them + cpp_no_default_args: Set[str] + + # Is this a fallback C++ binding? Fallback bindings are enabled by + # manual_cpp_binding: True and are alternate, non-public API that + # lets manual C++ binding implementors access the binding that would + # have been automatically generated + fallback_binding: bool = False + + # Return the unpacked argument structure of this signature, + # discarding information about which arguments are semantically + # related to each other. + def arguments(self) -> Sequence[Binding]: + return cpp.arguments( + self.func.arguments, faithful=self.faithful, + method=self.method, cpp_no_default_args=self.cpp_no_default_args) + def name(self) -> str: + n = cpp.name(self.func, faithful_name_for_out_overloads=self.faithful) + if self.fallback_binding: + n = f"__dispatch_{n}" + return n + + # Render the C++ declaration for this signature + def decl(self) -> str: + returns_type = cpp.returns_type(self.func.returns) + cpp_args_str = ', '.join(a.decl() for a in self.arguments()) + return f"{returns_type} {self.name()}({cpp_args_str})" + + # Render the C++ definition for this signature, not including + # the body (with curly braces) + def defn(self, *, prefix: str = "") -> str: + returns_type = cpp.returns_type(self.func.returns) + cpp_args_str = ', '.join(a.defn() for a in self.arguments()) + name = prefix + self.name() + return f"{returns_type} {name}({cpp_args_str})" + + +# Represents group of all CppSignatures associated with a +# FunctionSchema. Right now, that's the regular, user-visible +# signature, as well as a "faithful" signature which doesn't +# have grouping. @dataclass(frozen=True) -class LegacyDispatcherExpr: - type: str - expr: str +class CppSignatureGroup: + func: FunctionSchema + signature: CppSignature + faithful_signature: Optional[CppSignature] + + @staticmethod + def from_native_function(f: NativeFunction, *, method: bool, fallback_binding: bool = False) -> 'CppSignatureGroup': + func = f.func + faithful_signature: Optional[CppSignature] + if func.arguments.tensor_options is not None or len(func.arguments.out) > 0: + faithful_signature = CppSignature( + func=func, + faithful=True, + method=method, + fallback_binding=fallback_binding, + cpp_no_default_args=f.cpp_no_default_args + ) + else: + faithful_signature = None + signature = CppSignature( + func=func, + faithful=False, + method=method, + fallback_binding=fallback_binding, + cpp_no_default_args=f.cpp_no_default_args + ) + return CppSignatureGroup( + func=func, + signature=signature, + faithful_signature=faithful_signature, + ) @dataclass(frozen=True) -class DispatcherArgument: - type: str - name: str - # dispatcher NEVER has defaults - argument: Union[Argument, TensorOptionsArguments] - # TensorOptionsArguments can occur when not using full c10 dispatch +class DispatcherSignature: + # The schema this signature is derived from + func: FunctionSchema - def __str__(self) -> str: - return f"{self.type} {self.name}" + def arguments(self) -> List[Binding]: + return dispatcher.arguments(self.func) + + def name(self) -> str: + return dispatcher.name(self.func) + + def defn(self, name: Optional[str] = None) -> str: + args_str = ', '.join(a.defn() for a in self.arguments()) + if name is None: + name = self.name() + return f"{self.returns_type()} {name}({args_str})" + + def exprs(self) -> List[Expr]: + return [Expr(a.name, a.ctype) for a in self.arguments()] + + def returns_type(self) -> str: + return dispatcher.returns_type(self.func.returns) + + # Return the C++ function type, e.g., something like int(bool) + def type(self) -> str: + dispatcher_args_types_str = ', '.join(a.type for a in self.arguments()) + return f'{self.returns_type()} ({dispatcher_args_types_str})' + + @staticmethod + def from_schema(func: FunctionSchema) -> 'DispatcherSignature': + return DispatcherSignature(func) @dataclass(frozen=True) -class LegacyDispatcherArgument: - type: str - name: str - # Legacy dispatcher arguments have defaults for some reasons (e.g., - # the function prototypes in CPUType.h are defaulted). There isn't - # really any good reason to do this, as these functions are only - # ever called from a context where all defaulted arguments are - # guaranteed to be given explicitly. - # TODO: Remove this - default: Optional[str] - argument: Union[Argument, TensorOptionsArguments] - - # Convention here is swapped because arguably legacy - # dispatcher shouldn't have defaults... - def __str__(self) -> str: - return f"{self.type} {self.name}" +class NativeSignature: + # The schema this signature is derived from + func: FunctionSchema - def str_with_default(self) -> str: - mb_default = "" - if self.default is not None: - mb_default = f"={self.default}" - return f"{self.type} {self.name}{mb_default}" + def name(self) -> str: + return native.name(self.func) + + def defn(self, name: Optional[str] = None) -> str: + args_str = ', '.join(a.defn() for a in self.arguments()) + if name is None: + name = self.name() + return f"{native.returns_type(self.func.returns)} {name}({args_str})" + + def ptr_type(self) -> str: + # don't include defaults in type signature! + args_str = ', '.join(a.defn() for a in self.arguments()) + return f'{native.returns_type(self.func.returns)} (*)({args_str})' + + def arguments(self) -> List[Binding]: + return native.arguments(self.func) + + def dispatcher_exprs(self) -> List[Expr]: + return translate.translate(self.arguments(), dispatcher.arguments(self.func), method=False) + + @staticmethod + def from_schema(func: FunctionSchema) -> 'NativeSignature': + return NativeSignature(func) + +# Functions only, no types +from tools.codegen.api import cpp, dispatcher, native, translate diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index 0871d8c55ae1c..08e9572131e32 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -2,21 +2,26 @@ import contextlib import textwrap import itertools -from typing import List, Dict, Optional, Iterator, Tuple, Set, Callable, Any, TypeVar, Union, Sequence +from typing import List, Dict, Optional, Iterator, Tuple, Set, Callable, Any, TypeVar, Union, Sequence, Iterable import yaml from enum import Enum -from collections import OrderedDict +from collections import OrderedDict, defaultdict import argparse import pathlib import functools +import json +from dataclasses import dataclass from tools.codegen.code_template import CodeTemplate from tools.codegen.model import * from tools.codegen.api.types import * import tools.codegen.api.cpp as cpp import tools.codegen.api.dispatcher as dispatcher -import tools.codegen.api.legacy_dispatcher as legacy_dispatcher +import tools.codegen.api.native as native +import tools.codegen.api.meta as meta +from tools.codegen.api.translate import translate import tools.codegen.local as local +from tools.codegen.selective_build.selector import SelectiveBuilder try: # use faster C loader if available @@ -92,37 +97,71 @@ def parse_native_yaml(path: str) -> List[NativeFunction]: T = TypeVar('T') S = TypeVar('S') +F = TypeVar('F', NativeFunction, StructuredNativeFunctions, Union[NativeFunction, StructuredNativeFunctions]) + +@contextlib.contextmanager +def native_function_manager(g: Union[StructuredNativeFunctions, NativeFunction]) -> Iterator[None]: + if isinstance(g, StructuredNativeFunctions): + # By default, we associate all errors with structured native functions + # with the out variant. In some cases, it might be better to have + # a more specific place to hang things; if so, use + # native_function_manager again on the inside + f = g.out + else: + f = g + with context(f'in {f.loc}:\n {f.func}'): + with local.parametrize( + use_c10_dispatcher=f.use_c10_dispatcher, + ): + yield + # Given a function that operates on NativeFunction, wrap it into a new function # that sets some appropriate context managers for that native function. # YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be sound # (you will get an error if we try to access the local variables without having # set them). -def with_native_function(func: Callable[[NativeFunction], T]) -> Callable[[NativeFunction], T]: +def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]: @functools.wraps(func) - def wrapper(f: NativeFunction) -> T: - with context(f'in {f.loc}:\n {f.func}'): - with local.parametrize( - use_c10_dispatcher=f.use_c10_dispatcher, - ): - return func(f) + def wrapper(f: F) -> T: + with native_function_manager(f): + return func(f) + return wrapper + +def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]: + @functools.wraps(func) + def wrapper(slf: S, f: F) -> T: + with native_function_manager(f): + return func(slf, f) return wrapper # These two functions purposely return generators in analogy to map() # so that you don't mix up when you need to list() them # Map over function that may return None; omit Nones from output sequence -def mapMaybe(func: Callable[[T], Optional[S]], xs: Sequence[T]) -> Iterator[S]: +def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]: for x in xs: r = func(x) if r is not None: yield r # Map over function that returns sequences and cat them all together -def concatMap(func: Callable[[T], Sequence[S]], xs: Sequence[T]) -> Iterator[S]: +def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]: for x in xs: for r in func(x): yield r +def cpp_string(s: str) -> str: + """Convert a python string into a c++ string literal """ + s = s.replace('\\', '\\\\') + s = s.replace('"', '\\"') + s = s.replace('\a', '\\a') + s = s.replace('\b', '\\b') + s = s.replace('\f', '\\f') + s = s.replace('\n', '\\n') + s = s.replace('\v', '\\v') + s = s.replace('\t', '\\t') + return f'"{s}"' + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # # C++ CODE GENERATION @@ -141,12 +180,35 @@ def concatMap(func: Callable[[T], Sequence[S]], xs: Sequence[T]) -> Iterator[S]: # code we want. Target = Enum('Target', ('DEFINITION', 'DECLARATION', 'REGISTRATION')) -# Generates {dispatch}Type.cpp and {dispatch}Type.h (e.g., CPUType.cpp -# and CPUType.h). This function is also reused to implement per-operator -# registration. It also generates TypeDefault.cpp and TypeDefault.h when -# dispatch is None. +# Dispatch keys that "support all backends". These codegen slightly differently +# then backend specific keys. +def is_generic_dispatch_key(dk: str) -> bool: + return dk in {'DefaultBackend', 'Math'} + +# CUDA specific dispatch keys +def is_cuda_dispatch_key(dk: str) -> bool: + return 'CUDA' in dk + +# Structured kernel generation is only supported for certain key types; +# otherwise use old-style +def is_structured_dispatch_key(dk: str) -> bool: + return dk in {'CUDA', 'CPU'} + +# Generates RegisterSchema.cpp. Depending on the selector, either +# all schemas are registered, or only some are (in the case of +# selective build) +@dataclass(frozen=True) +class RegisterSchema: + selector: SelectiveBuilder + + @method_with_native_function + def __call__(self, f: NativeFunction) -> Optional[str]: + if not self.selector.is_native_function_selected(f): + return None + return f'm.def({cpp_string(str(f.func))});\n' + +# Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp). # -# {dispatch}Type.cpp # - The primary function of this file is to register all of the # implementations for the given dispatch key to the dispatcher, # so they are available for use in PyTorch. If dispatch is @@ -160,85 +222,310 @@ def concatMap(func: Callable[[T], Sequence[S]], xs: Sequence[T]) -> Iterator[S]: # API without having to disambiguate which overload you want # (as would be the case if you directly registered native:: # functions). -# -# {dispatch}Type.h -# - In principle, this file shouldn't exist at all; historically, -# it existed so that we could directly access these functions -# outside of the registration API for the implementation of -# static dispatch. Should be deleted now! -# -# This function is also used for a secondary purpose: the registration -# logic is also reused to implement per-operator registration. -def compute_type_method( - dispatch: Optional[str], *, - target: Target, - # Which operators to actually generate code for. If None, generate - # code for all operators - op_registration_whitelist: Optional[Set[str]], - # Only valid for generating registrations. If True, only generate - # def() invocations (for schema registration); do not generate - # any impl() invocations for, e.g., catch-all kernels - def_only: bool = False -) -> Callable[[NativeFunction], Optional[str]]: - - if def_only: - assert target is Target.REGISTRATION and dispatch is None - - @with_native_function - def func(f: NativeFunction) -> Optional[str]: - if dispatch is not None: - if f.dispatch is None or dispatch not in f.dispatch: - return None +@dataclass(frozen=True) +class RegisterDispatchKey: + dispatch_key: str + + # TODO: Give more precise type Union[Literal[Target.DEFINITION, + # Target.REGISTRATION]]; requires Literal from typing_extensions + # which we don't have a dep for yet. + target: Target + + # Selector object to determine which operators to generate + # registration code for. + selector: SelectiveBuilder + + # Whether or not we are actually code-genning for ROCm + rocm: bool + + def __post_init__(self) -> None: + assert self.target is not Target.DECLARATION + + @method_with_native_function + def __call__(self, f: Union[StructuredNativeFunctions, NativeFunction]) -> List[str]: + if isinstance(f, StructuredNativeFunctions): + return self.gen_structured(f) + elif isinstance(f, NativeFunction): + r = self.gen_unstructured(f) + return [] if r is None else [r] + else: + assert_never(f) + + def gen_structured_class_set_output(self, k: SchemaKind, parent_class: str, generate_super: bool) -> str: + if generate_super: + set_output_super = f"{parent_class}::set_output(output_idx, sizes, strides, options, names);" + else: + set_output_super = "" + return f""" +void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, + TensorOptions options, DimnameList names) override {{ + {self.gen_structured_class_set_output_body(k)} + if (!names.empty()) namedinference::propagate_names(outputs_[output_idx], names); + // super must happen after, so that downstream can use maybe_get_output + // to retrieve the output + {set_output_super} +}} +""" + + def gen_structured_class_set_output_body(self, k: SchemaKind) -> str: + if self.dispatch_key == 'CUDA': + maybe_set_guard = """ +auto current_device = guard_.current_device(); +if (C10_UNLIKELY(current_device.has_value())) { + TORCH_INTERNAL_ASSERT(*current_device == options.device(), + "structured kernels don't support multi-device outputs"); +} else { + guard_.set_device(options.device()); +} +""" else: - if f.dispatch is not None and target is not Target.REGISTRATION: + maybe_set_guard = '' + + if k is SchemaKind.functional: + if self.dispatch_key == "Meta": + return """ +if (strides.empty()) { + outputs_[output_idx] = at::empty_meta(sizes, options); +} else { + TORCH_INTERNAL_ASSERT(0, "not implemented yet"); +} +""" + else: + expanded_topts = "optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), " \ + "options.device_opt(), options.pinned_memory_opt()" + if self.dispatch_key == "CPU": + empty_impl = "at::native::empty_cpu" + empty_strided_impl = "at::native::empty_strided_cpu" + elif self.dispatch_key == "CUDA": + empty_impl = "at::native::empty_cuda" + empty_strided_impl = "at::native::empty_strided_cuda" + else: + raise AssertionError("unsupported dispatch key") + return f""" +{maybe_set_guard} +if (strides.empty()) {{ + outputs_[output_idx] = {empty_impl}(sizes, {expanded_topts}, options.memory_format_opt()); +}} else {{ + outputs_[output_idx] = {empty_strided_impl}(sizes, strides, {expanded_topts}); +}} +""" + elif k is SchemaKind.inplace: + return maybe_set_guard + elif k is SchemaKind.out: + return f""" +{maybe_set_guard} +at::native::resize_output(outputs_[output_idx], sizes); +if (!strides.empty()) {{ + TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value()); + at::native::as_strided_(outputs_[output_idx], sizes, strides); +}} else if (options.memory_format_opt().has_value()) {{ + outputs_[output_idx].get().unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt()); +}} +""" + else: + assert_never(k) + + # returns the definition of a ctor, as well as how to construct + # this class to a variable named op + def gen_structured_class_ctor(self, k: SchemaKind, class_name: str) -> str: + if k is SchemaKind.functional: + return "" + elif k is SchemaKind.inplace: + # TODO: Make sure out argument is guaranteed to be self + return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}" + elif k is SchemaKind.out: + # TODO: Stop hardcoding out here + return f"{class_name}(Tensor& out) : outputs_{{std::ref(out)}} {{}}" + else: + assert_never(k) + + def gen_structured_class( + self, f: NativeFunction, k: SchemaKind, *, class_name: str, parent_class: str, generate_super: bool + ) -> str: + if k is SchemaKind.functional: + assert len(f.func.returns) == 1, "multi-return not supported yet" + output_type = "Tensor" + elif k is SchemaKind.inplace: + output_type = "std::reference_wrapper" + elif k is SchemaKind.out: + assert len(f.func.arguments.out) == 1, "multi-out structured not supported yet" + output_type = "std::reference_wrapper" + + if self.dispatch_key == 'CUDA': + if self.rocm: + guard_field = 'c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;' + else: + guard_field = 'c10::cuda::OptionalCUDAGuard guard_;' + else: + guard_field = '' + + return f""" +struct {class_name} final : public {parent_class} {{ + {self.gen_structured_class_ctor(k, class_name)} + {self.gen_structured_class_set_output(k, parent_class, generate_super)} + const Tensor& maybe_get_output(int64_t output_idx) override {{ + return outputs_[output_idx]; + }} + std::array<{output_type}, {len(f.func.returns)}> outputs_; + {guard_field} +}}; +""" + + def gen_structured(self, g: StructuredNativeFunctions) -> List[str]: + if self.dispatch_key == 'Meta': + assert self.dispatch_key not in g.out.dispatch, \ + "Do not explicitly specify Meta dispatch key on structured " \ + "functions, they will be automatically generated for you" + elif self.dispatch_key not in g.out.dispatch: + return [] + elif not is_structured_dispatch_key(self.dispatch_key): + return list(mapMaybe(self.gen_unstructured, g.functions())) + + # Inner helper function to close over g + # TODO: This function has a lot of similarity with gen_unstructured. If + # you edit this, you may need to also edit gen_unstructured. + @with_native_function + def gen_one(f: NativeFunction) -> Optional[str]: + assert self.target is not Target.DECLARATION + assert not f.manual_kernel_registration + + # TODO: put this into StructuredNativeFunctions itself + functional_func = g.out.func.signature() + functional_sig = DispatcherSignature.from_schema(functional_func) + + # TODO: is it meta or wot? Sort this out + functional_exprs = ', '.join( + e.expr for e in translate(functional_sig.arguments(), dispatcher.arguments(functional_func), method=False) + ) + + if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f): return None - if op_registration_whitelist is not None and \ - f"aten::{f.func.name.name}" not in op_registration_whitelist and target is Target.REGISTRATION: - return None + k = f.func.kind() + sig = NativeSignature.from_schema(f.func) + + if self.target is Target.DEFINITION: + if self.dispatch_key == 'Meta': + class_name = f"structured_{meta.name(g)}_meta_{k.name}" + parent_class = f"at::meta::{meta.name(g)}" + else: + class_name = f"structured_{g.out.dispatch[self.dispatch_key]}_{k.name}" + parent_class = f"at::native::structured_{g.out.dispatch[self.dispatch_key]}" + + if k is SchemaKind.functional: + assert len(f.func.returns) == 1, "multi-return not supported yet" + out_expr = "op.outputs_[0]" + ret_expr = "std::move(op.outputs_[0])" # small optimization + op_init = f"{class_name} op;" + elif k is SchemaKind.inplace: + out_expr = "self" + ret_expr = "self" + op_init = f"{class_name} op(self);" + elif k is SchemaKind.out: + assert len(f.func.arguments.out) == 1, "multi-out structured not supported yet" + out_expr = f.func.arguments.out[0].name + ret_expr = out_expr + op_init = f"{class_name} op({out_expr});" + + if self.dispatch_key == 'Meta': + impl_call = "" + else: + impl_call = f"op.impl({functional_exprs}, {out_expr});" + + # For an overview of what this template code looks like, see + # https://github.com/pytorch/rfcs/pull/9 + return f"""\ +namespace {{ + +{self.gen_structured_class( + f, k, + class_name=class_name, + parent_class=parent_class, + generate_super=g.out.structured_inherits is not None +)} + +{sig.defn()} {{ + {op_init} + op.meta({functional_exprs}); + {impl_call} + return {ret_expr}; +}} + +}} // anonymous namespace +""" - name = legacy_dispatcher.name(f.func) - returns_type = legacy_dispatcher.returns_type(f.func.returns) - args = legacy_dispatcher.arguments(f.func) - args_str = ', '.join(map(str, args)) - - if target is Target.DECLARATION: - return f"{returns_type} {name}({args_str});" - elif target is Target.DEFINITION: - if f.dispatch is None: - cpp_name = cpp.name(f.func) - impl_name = f"at::native::{cpp_name}" + elif self.target is Target.REGISTRATION: + dispatcher_sig = DispatcherSignature.from_schema(f.func) + + assert local.use_c10_dispatcher() is UseC10Dispatcher.full + return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));' else: - assert dispatch is not None - impl_name = f"at::native::{f.dispatch[dispatch]}" + assert_never(self.target) + # Silence mypy's "Missing return statement" error + return None + + return list(mapMaybe(gen_one, g.functions())) + + @method_with_native_function + def gen_unstructured(self, f: NativeFunction) -> Optional[str]: + # for mypy type refinement; would be fixed by TODO on target + assert self.target is not Target.DECLARATION + + if self.dispatch_key not in f.dispatch: + return None + if f.manual_kernel_registration: + return None - args_exprs_str = ', '.join(map(lambda a: a.name, args)) + if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f): + return None + + name = native.name(f.func) + returns_type = native.returns_type(f.func.returns) + args = native.arguments(f.func) + args_str = ', '.join(a.defn() for a in args) + + if self.target is Target.DEFINITION: + impl_name = f"at::native::{f.dispatch[self.dispatch_key]}" + + args_exprs_str = ', '.join(a.name for a in args) return_kw = " return " cuda_guard = "" - if dispatch is None or 'CUDA' in dispatch or 'Vulkan' == dispatch: - self_args = (a for a in f.func.arguments if a.name == "self") + if is_generic_dispatch_key(self.dispatch_key) or is_cuda_dispatch_key(self.dispatch_key): + self_arg = [f.func.arguments.self_arg.argument] if f.func.arguments.self_arg is not None else [] # There is precedence for which argument we use to do # device guard. This describes the precedence order. - candidate_args = itertools.chain(self_args, f.func.out_arguments, f.func.arguments) + candidate_args = itertools.chain( + self_arg, + f.func.arguments.out, + f.func.arguments.flat_positional + ) # Only tensor like arguments are eligible device_of = next((f'{a.name}' for a in candidate_args if a.type.is_tensor_like()), None) has_tensor_options = any(isinstance(a.argument, TensorOptionsArguments) for a in args) - # TODO: There is probably a simpler version of this that - # works just as well. - if f.device_guard and (dispatch is None or 'Vulkan' == dispatch) and has_tensor_options: - cuda_guard = """\ + if local.use_c10_dispatcher() == UseC10Dispatcher.full: + cuda_guard_from_tensor_options = """\ + const DeviceGuard device_guard(device_or_default(device)); +""" + else: + assert local.use_c10_dispatcher() is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures + cuda_guard_from_tensor_options = """\ const DeviceGuard device_guard(options.device()); """ - elif f.device_guard and dispatch is not None and 'CUDA' in dispatch and has_tensor_options: - cuda_guard = """\ + + # TODO: There is probably a simpler version of this that + # works just as well. + if f.device_guard and is_generic_dispatch_key(self.dispatch_key) and has_tensor_options: + cuda_guard = cuda_guard_from_tensor_options + elif f.device_guard and is_cuda_dispatch_key(self.dispatch_key) and has_tensor_options: + cuda_guard = f"""\ globalContext().lazyInitCUDA(); - const DeviceGuard device_guard(options.device()); + {cuda_guard_from_tensor_options} """ elif f.device_guard and device_of is not None: cuda_guard = f"""\ @@ -250,136 +537,142 @@ def func(f: NativeFunction) -> Optional[str]: """ return f"""\ +namespace {{ + {returns_type} {name}({args_str}) {{ {cuda_guard}{return_kw}{impl_name}({args_exprs_str}); }} + +}} // anonymous namespace """ - elif target is Target.REGISTRATION: - assert returns_type == dispatcher.returns_type(f.func.returns) - dispatcher_args = dispatcher.arguments(f.func) - dispatcher_args_types_str = ', '.join(map(lambda a: a.type, dispatcher_args)) - if dispatch is None or dispatch == 'Math': - type_name = f'TypeDefault::{name}' + elif self.target is Target.REGISTRATION: + if f.manual_kernel_registration: + return None else: - type_name = f'{dispatch}Type::{name}' - - # def registration only happens in TypeDefault - def_registration = "" - if dispatch is None: - def_registration = f'm.def("{f.func}");\n' + dispatcher_sig = DispatcherSignature.from_schema(f.func) - impl_registration = "" - if not def_only and not f.manual_kernel_registration and (dispatch is not None or f.dispatch is None): # Figure out which signature the function is if local.use_c10_dispatcher() is UseC10Dispatcher.full: - - payload = "c10::impl::hacky_wrapper_for_legacy_signatures<" \ - f"{returns_type} ({dispatcher_args_types_str})>(TORCH_FN({type_name}))" - + payload = f"TORCH_FN({name})" else: - payload = f"torch::CppFunction::makeUnboxedOnly(&{type_name})" - - # Annotate it with dispatch information if necessary - # - # NB: In the ordinary, TypeDerived code generation work flow, specification - # of the backend is handled by the enclosing block, so the torch::dispatch - # invocation here is strictly unnecessary. However, in the fbcode mobile - # only workflow using per-op registration, these registrations will get dumped - # in a TORCH_LIBRARY_FRAGMENT that does not have an ambient backend. So - # the torch::dispatch specification here is important! See - # Note [Redundancy in registration code is OK] for how we handle redundant info. - if dispatch is not None: - payload = f"torch::dispatch(DispatchKey::{dispatch},\n{payload})\n" - - impl_registration = f'm.impl("{f.func.name}",\n{payload});\n' - - return f"{def_registration}{impl_registration}" - else: - assert_never(target) + assert local.use_c10_dispatcher() is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures + payload = f""" +c10::impl::hacky_wrapper_for_legacy_signatures< + {dispatcher_sig.type()}, + {len(f.func.arguments.out)} +>(TORCH_FN({name})) +""" - return func + return f'm.impl("{f.func.name}",\n{payload});\n' + else: + assert_never(self.target) # Generates Function.cpp and Function.h. These files provide the # functional public C++ API, and the scaffolding to call into # the dispatcher from these functions. See also compute_tensor_method. -def compute_function(*, target: Target) -> Callable[[NativeFunction], Optional[str]]: - @with_native_function - def go(f: NativeFunction) -> Optional[str]: - if f.manual_kernel_registration: - return None +@dataclass(frozen=True) +class ComputeFunction: + target: Target + + @method_with_native_function + def __call__(self, f: NativeFunction) -> Optional[str]: if Variant.function not in f.variants: return None name = cpp.name(f.func) - cpp_returns_type = cpp.returns_type(f.func.returns) - cpp_args = cpp.arguments(f.func) - cpp_args_str = ', '.join(map(str, cpp_args)) + sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=f.manual_cpp_binding) - if target is Target.DECLARATION: - return f"CAFFE2_API {cpp_returns_type} {name}({cpp_args_str});" + if self.target is Target.DECLARATION: + result = f"TORCH_API {sig_group.signature.decl()};\n" + if sig_group.faithful_signature is not None: + result += f"TORCH_API {sig_group.faithful_signature.decl()};\n" + return result - assert target is Target.DEFINITION + assert self.target is Target.DEFINITION - dispatcher_exprs = dispatcher.cpparguments_exprs(cpp_args) - cpp_args_str_no_default = ', '.join(map(lambda a: a.str_no_default(), cpp_args)) - dispatcher_returns_type = dispatcher.returns_type(f.func.returns) - dispatcher_types_str = ', '.join(map(lambda a: a.type, dispatcher_exprs)) - dispatcher_exprs_str = ', '.join(map(lambda a: a.expr, dispatcher_exprs)) + def generate_defn(faithful: bool) -> str: + dispatcher_sig = DispatcherSignature.from_schema(f.func) - return f""" + if faithful and sig_group.faithful_signature is not None: + sig = sig_group.faithful_signature + else: + sig = sig_group.signature + + dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments()) + dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs) + + return f""" // aten::{f.func} -{cpp_returns_type} {name}({cpp_args_str_no_default}) {{ +{sig.defn()} {{ static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}") - .typed<{dispatcher_returns_type} ({dispatcher_types_str})>(); + .typed<{dispatcher_sig.type()}>(); return op.call({dispatcher_exprs_str}); }} """ - return go + + result = generate_defn(sig_group.faithful_signature is None) + if sig_group.faithful_signature is not None: + result += generate_defn(True) + + return result # Generates TensorBody.h (sic) and TensorMethods.cpp. These files provide the # object-oriented (method-based) public C++ API, and the scaffolding to call into # the dispatcher from these functions. See also compute_function. -def compute_tensor_method(*, target: Target) -> Callable[[NativeFunction], Optional[str]]: - @with_native_function - def go(f: NativeFunction) -> Optional[str]: +@dataclass(frozen=True) +class ComputeTensorMethod: + target: Target + + @method_with_native_function + def __call__(self, f: NativeFunction) -> Optional[str]: if Variant.method not in f.variants: return None assert not f.func.is_out_fn() - assert len(f.func.arguments) > 0 - assert sum(a.name == 'self' for a in f.func.arguments) == 1 + assert f.func.arguments.self_arg is not None name = cpp.name(f.func) - cpp_returns_type = cpp.returns_type(f.func.returns) - cpp_args = cpp.arguments(f.func, method=True) - cpp_args_exclude_this = [a for a in cpp_args if not isinstance(a.argument, ThisArgument)] - cpp_args_exclude_this_str = ', '.join(str(a) for a in cpp_args_exclude_this) - if target is Target.DECLARATION: - return f"{cpp_returns_type} {name}({cpp_args_exclude_this_str}) const;" + sig_group = CppSignatureGroup.from_native_function(f, method=True, fallback_binding=f.manual_cpp_binding) - assert target is Target.DEFINITION + if self.target is Target.DECLARATION: + result = f"{sig_group.signature.decl()} const;\n" + if sig_group.faithful_signature is not None: + result += f"{sig_group.faithful_signature.decl()} const;\n" + return result - dispatcher_exprs = dispatcher.cpparguments_exprs(cpp_args) - cpp_args_exclude_this_str_no_default = ', '.join(a.str_no_default() for a in cpp_args_exclude_this) - dispatcher_returns_type = dispatcher.returns_type(f.func.returns) - dispatcher_types_str = ', '.join(map(lambda a: a.type, dispatcher_exprs)) - dispatcher_exprs_str = ', '.join(map(lambda a: a.expr, dispatcher_exprs)) + assert self.target is Target.DEFINITION - return f""" + def generate_defn(faithful: bool) -> str: + dispatcher_sig = DispatcherSignature.from_schema(f.func) + + if faithful: + sig = sig_group.faithful_signature + assert sig is not None + else: + sig = sig_group.signature + + dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments(), method=True) + dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs) + + return f""" // aten::{f.func} -{cpp_returns_type} Tensor::{name}({cpp_args_exclude_this_str_no_default}) const {{ +{sig.defn(prefix="Tensor::")} const {{ static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}") - .typed<{dispatcher_returns_type} ({dispatcher_types_str})>(); + .typed<{dispatcher_sig.type()}>(); return op.call({dispatcher_exprs_str}); }} """ - return go + result = generate_defn(faithful=False) + if sig_group.faithful_signature is not None: + result += generate_defn(faithful=True) + + return result # Generates ATenOpList.cpp, a runtime accessible list of all aten # operators. @@ -393,92 +686,134 @@ def compute_aten_op(f: NativeFunction) -> str: # Generates NativeFunctions.h, a list of forward declarations of all # actual kernel definitions we keep in aten/src/ATen/native/ @with_native_function -def compute_native_function_declaration(f: NativeFunction) -> List[str]: - if f.dispatch is None: - ns = [cpp.name(f.func)] +def compute_native_function_declaration(g: Union[StructuredNativeFunctions, NativeFunction]) -> List[str]: + if isinstance(g, StructuredNativeFunctions): + # only out has dispatch + meta_name = meta.name(g) + rs = [] + seen: Set[Any] = set() + out_args = native.arguments(g.out.func) + for k, n in g.out.dispatch.items(): + if n in seen: + continue + if not is_structured_dispatch_key(k): + continue + seen.add(n) + rs.append(f"""\ +struct TORCH_API structured_{n} : public at::meta::{meta_name} {{ + void impl({', '.join(a.decl() for a in out_args)}); +}}; +""") + + seen = set() + for f in g.functions(): + returns_type = native.returns_type(f.func.returns) + args = native.arguments(f.func) + for k, n in f.dispatch.items(): + if n in seen: + continue + if is_structured_dispatch_key(k): + continue + seen.add(n) + args_str = ', '.join(a.decl() for a in args) + rs.append(f"TORCH_API {returns_type} {n}({args_str});") + + return rs + else: + f = g ns = list(f.dispatch.values()) - rs = [] - # Sometimes a function name shows up multiple times; only generate - # it once! - seen = set() - for n in ns: - if n in seen: - continue - if "legacy::" in n: - continue - seen.add(n) - returns_type = legacy_dispatcher.returns_type(f.func.returns) - args = legacy_dispatcher.arguments(f.func) - rs.append(f"CAFFE2_API {returns_type} {n}({', '.join(map(lambda a: a.str_with_default(), args))});") - - return rs + rs = [] + # Sometimes a function name shows up multiple times; only generate + # it once! + seen = set() + for n in ns: + if n in seen: + continue + if "legacy::" in n: + continue + seen.add(n) + returns_type = native.returns_type(f.func.returns) + args = native.arguments(f.func) + rs.append(f"TORCH_API {returns_type} {n}({', '.join(a.decl() for a in args)});") + + return rs + +def compute_meta_function_declaration(g: StructuredNativeFunctions) -> str: + with native_function_manager(g.out): + sig = g.signature() + name = meta.name(g) + args = native.arguments(sig) + args_str = ', '.join(a.decl() for a in args) + parent_class = g.out.structured_inherits + if parent_class is None: + parent_class = "at::impl::MetaBase" + return f"""\ +struct TORCH_API {name} : public {parent_class} {{ + void meta({args_str}); +}}; +""" -# Generates BackendSelectRegister.cpp, a series of kernels which provide +# Generates RegisterBackendSelect.cpp, a series of kernels which provide # specialized computation of dispatch key for operator signatures which cannot # be easily done automatically using templating. -def compute_backend_select(*, target: Target) -> Callable[[NativeFunction], Optional[str]]: - @with_native_function - def go(f: NativeFunction) -> Optional[str]: +@dataclass(frozen=True) +class ComputeBackendSelect: + target: Target + + @method_with_native_function + def __call__(self, f: NativeFunction) -> Optional[str]: if str(f.func.name.name).endswith('_like') or str(f.func.name.name).startswith('new_'): return None - name = legacy_dispatcher.name(f.func) - legacy_dispatcher_returns_type = legacy_dispatcher.returns_type(f.func.returns) - legacy_dispatcher_args = legacy_dispatcher.arguments(f.func) + name = native.name(f.func) + native_sig = NativeSignature.from_schema(f.func) - if not any(isinstance(a.argument, TensorOptionsArguments) for a in legacy_dispatcher_args): + if not any(isinstance(a.argument, TensorOptionsArguments) for a in native_sig.arguments()): return None - legacy_dispatcher_tensor_args = [ - a for a in legacy_dispatcher_args + native_tensor_args = [ + a for a in native_sig.arguments() if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like() ] - dispatcher_returns_type = dispatcher.returns_type(f.func.returns) - dispatcher_args = dispatcher.arguments(f.func) - dispatcher_exprs = dispatcher.legacydispatcherarguments_exprs(legacy_dispatcher_args) + dispatcher_sig = DispatcherSignature.from_schema(f.func) - if target is Target.DEFINITION: + sig: Union[NativeSignature, DispatcherSignature] + sig = dispatcher_sig + dispatcher_exprs = dispatcher_sig.exprs() + dispatch_key = "c10::computeDispatchKey(dtype, layout, device)" + + if self.target is Target.DEFINITION: # I don't think there's actually a good reason to generate # these two cases differently # The first case could probably be improved though- it calls dispatchTypeId(), # which looks at TLS dispatch keys- there should not be any by the time we reach backend select. - if legacy_dispatcher_tensor_args: - tensor_args = ', '.join(a.name for a in legacy_dispatcher_tensor_args) + if native_tensor_args: + tensor_args = ', '.join(a.name for a in native_tensor_args) compute_dk = f"""\ -DispatchKeySet _dk_set = DispatchKeySet(options.computeDispatchKey()) | c10::detail::multi_dispatch_key_set({tensor_args}); +DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args}); DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect); DispatchKey _dk = c10::impl::dispatchTypeId(_dk_set, _dk_mask);""" else: - compute_dk = "DispatchKey _dk = options.computeDispatchKey();" + compute_dk = f"DispatchKey _dk = {dispatch_key};" return f"""\ // aten::{f.func} -{legacy_dispatcher_returns_type} {name}({', '.join(a.str_with_default() for a in legacy_dispatcher_args)}) {{ +{sig.defn(name)} {{ static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}") - .typed<{dispatcher_returns_type} ({', '.join(a.type for a in dispatcher_args)})>(); + .typed<{dispatcher_sig.type()}>(); {compute_dk} - DispatchKey _autograd_dk = c10::getAutogradKeyFromBackend(_dk); - // This trick allows calling Autograd backend kernel first and then backend kernel, - // without adding another AutogradBackendSelect dispatch key. - DispatchKey _current_dk = at::impl::variable_excluded_from_dispatch() ? _dk : _autograd_dk; - return op.callWithDispatchKey(_current_dk, {', '.join(a.expr for a in dispatcher_exprs)}); + return op.callWithDispatchKey(_dk, {', '.join(a.expr for a in dispatcher_exprs)}); }} """ - elif target is Target.REGISTRATION: - if local.use_c10_dispatcher() is UseC10Dispatcher.full: - return f"""m.impl("aten::{f.func.name}", - c10::impl::hacky_wrapper_for_legacy_signatures<{dispatcher_returns_type} ({', '.join(a.type for a in dispatcher_args)})>( - TORCH_FN({name})));""" - else: - return f"""m.impl_UNBOXED("aten::{f.func.name}", {name});""" - elif target is Target.DECLARATION: + elif self.target is Target.REGISTRATION: + return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));""" + elif self.target is Target.DECLARATION: raise AssertionError() else: - assert_never(target) - return go + assert_never(self.target) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # @@ -533,7 +868,7 @@ def dynamic_type(t: Type) -> str: # also include Tensor[] if str(t) == 'Tensor': return 'Tensor' - return cpp.argumenttype_type(t, mutable=False) + return cpp.argumenttype_type(t, mutable=False, binds='__placeholder__').cpp_type() def compute_method_of_yaml(variants: Set[Variant]) -> List[str]: # This is written out explicitly to ensure that Tensor and @@ -588,32 +923,9 @@ def compute_returns_yaml(f: NativeFunction) -> Tuple[List[Dict[str, str]], Dict[ name_to_field_name: Dict[str, str] = {} # Compute the returns field of the YAML entry + names = cpp.return_names(f) returns = [] - for i, r in enumerate(f.func.returns): - # If we have an inplace function, the return argument is - # implicitly named self. - # TODO: Consider incorporating this into the data model - if f.func.name.name.inplace: - assert i == 0, "illegal inplace function with multiple returns" - name = 'self' - # If we are out function, the name is the name of the - # corresponding output function (r.name will get recorded - # in field_name later.) - elif f.func.is_out_fn(): - name = f.func.out_arguments[i].name - # If the return argument is explicitly named... - elif r.name: - name_conflict = any(r.name == a.name for a in f.func.schema_order_arguments()) - if name_conflict and not f.func.is_out_fn(): - name = f'{r.name}_return' - else: - name = r.name - # If there is no explicit name, we just name the output result, - # unless it's a multi-return, in which case it's result0, - # result1, etc (zero-indexed) - else: - name = 'result' if len(f.func.returns) == 1 else f'result{i}' - + for i, (r, name) in enumerate(zip(f.func.returns, names)): ret = { 'dynamic_type': dynamic_type(r.type), 'name': name, @@ -624,14 +936,14 @@ def compute_returns_yaml(f: NativeFunction) -> Tuple[List[Dict[str, str]], Dict[ # See Note [name and field_name] ret['field_name'] = r.name if f.func.is_out_fn(): - name_to_field_name[f.func.out_arguments[i].name] = r.name + name_to_field_name[f.func.arguments.out[i].name] = r.name returns.append(ret) return returns, name_to_field_name # arguments in yaml roughly corresponds to the public C++ API -def compute_cpp_argument_yaml(cpp_a: CppArgument, *, schema_order: bool, kwarg_only_set: Set[str], +def compute_cpp_argument_yaml(cpp_a: Binding, *, schema_order: bool, kwarg_only_set: Set[str], out_arg_set: Set[str], name_to_field_name: Dict[str, str]) -> object: if isinstance(cpp_a.argument, TensorOptionsArguments): arg: Dict[str, object] = { @@ -645,7 +957,7 @@ def compute_cpp_argument_yaml(cpp_a: CppArgument, *, schema_order: bool, kwarg_o if cpp_a.default is not None: arg['default'] = cpp_a.default return arg - elif isinstance(cpp_a.argument, ThisArgument): + elif isinstance(cpp_a.argument, SelfArgument): raise AssertionError() elif isinstance(cpp_a.argument, Argument): return compute_argument_yaml( @@ -659,7 +971,7 @@ def compute_argument_yaml(a: Argument, *, schema_order: bool, kwarg_only_set: Se 'dynamic_type': dynamic_type(a.type), 'is_nullable': a.type.is_nullable(), 'name': a.name, - 'type': cpp.argument_type(a), + 'type': cpp.argument_type(a, binds="__placeholder__").cpp_type(), } if a.default is not None: arg['default'] = pythonify_default(cpp.default_expr(a.default, a.type)) @@ -684,10 +996,11 @@ def compute_declaration_yaml(f: NativeFunction) -> object: # These sets are used to conveniently test if an argument is a # kwarg-only or out argument - kwarg_only_set = set(a.name for a in f.func.kwarg_only_arguments) - out_arg_set = set(a.name for a in f.func.out_arguments) + kwarg_only_set = set(a.name for a in f.func.arguments.flat_kwarg_only) + out_arg_set = set(a.name for a in f.func.arguments.out) - cpp_args = cpp.arguments(f.func) + sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False) + cpp_args = sig_group.signature.arguments() arguments = [ compute_cpp_argument_yaml( cpp_a, schema_order=False, @@ -704,7 +1017,13 @@ def compute_declaration_yaml(f: NativeFunction) -> object: for a in schema_order_jit_arguments ] - cpp_schema_order_types = [cpp.argument(a).type for a in schema_order_jit_arguments] + cpp_schema_order_types = [ + # NB: method here doesn't matter + r.type for a in schema_order_jit_arguments + for r in cpp.argument( + a, method=False, cpp_no_default_args=set(), faithful=False, has_tensor_options=False) + ] + cpp_returns = cpp.returns_type(f.func.returns) schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})" @@ -715,7 +1034,6 @@ def compute_declaration_yaml(f: NativeFunction) -> object: ('name', cpp.name(f.func)), ('operator_name', str(f.func.name.name)), ('overload_name', str(f.func.name.overload_name)), - ('use_c10_dispatcher', f.use_c10_dispatcher.name), ('manual_kernel_registration', f.manual_kernel_registration), ('category_override', f.category_override if f.category_override is not None else ''), ('matches_jit_signature', True), @@ -729,28 +1047,28 @@ def compute_declaration_yaml(f: NativeFunction) -> object: ('returns', returns), ('inplace', f.func.name.name.inplace), ('is_factory_method', is_factory_method), - # Note [Abstract ATen methods] - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # An abstract ATen method is one whose dispatch differs between - # types. These are implemented in derived types (with a - # standard (throwing) definition in Type). A concrete ATen - # method is one which has the same dispatch for all types; - # we just implement it in the base Type. This is exposed - # in Declarations.yaml via a field named 'abstract'. - # - # Although this is what we have historically exposed, it is - # actually not all that useful for end users, who are also interested - # whether or not there is an explicit entry in derivatives.yaml - # for the entry or not (as this affects whether or not the operation is - # overrideable or not.) Once this all gets cleaned up, this - # property will be obsolete. - ('abstract', f.dispatch is not None), + ('abstract', f.is_abstract), ('device_guard', f.device_guard), ('with_gil', False), ('deprecated', False), - ('has_math_kernel', f.dispatch is not None and 'Math' in f.dispatch), + ('has_math_kernel', 'Math' in f.dispatch), ]) +@with_native_function +def compute_registration_declarations(f: NativeFunction) -> str: + name = dispatcher.name(f.func) + returns_type = dispatcher.returns_type(f.func.returns) + args = dispatcher.arguments(f.func) + args_str = ', '.join(a.no_default().decl() for a in args) + comment_data : Dict[str, str] = { + 'schema': f'aten::{f.func}', + # TODO: What exactly is the semantics of the 'dispatch' field? + 'dispatch': str(f.dispatch.keys() != {'Math'}), + 'default': str(any(is_generic_dispatch_key(k) for k in f.dispatch)) + } + return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)} +""" + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # # RUN IT ALL @@ -796,9 +1114,10 @@ def write_with_template(self, filename: str, template_fn: str, env = env_callable() if isinstance(env, dict): # TODO: Update the comment reference to the correct location - comment = "@" + "generated by aten/src/ATen/gen.py" - comment += " from {}".format(os.path.basename(template_fn)) - env['generated_comment'] = comment + if 'generated_comment' not in env: + comment = "@" + "generated by aten/src/ATen/gen.py" + comment += " from {}".format(os.path.basename(template_fn)) + env['generated_comment'] = comment template = _read_template(os.path.join(self.template_dir, template_fn)) self._write_if_changed(filename, template.substitute(env)) elif isinstance(env, str): @@ -817,6 +1136,33 @@ def write_outputs(self, filename: str) -> None: filename, ''.join(name + ";" for name in sorted(self.filenames))) +def get_custom_build_selector( + provided_op_registration_allowlist: Optional[List[str]], + op_selection_yaml_path: Optional[str]) -> SelectiveBuilder: + assert not ( + provided_op_registration_allowlist is not None and + op_selection_yaml_path is not None), ( + "Both provided_op_registration_allowlist and " + + "op_selection_yaml_path can NOT be provided at the " + + "same time.") + + op_registration_allowlist: Optional[Set[str]] = None + if provided_op_registration_allowlist is not None: + op_registration_allowlist = set(provided_op_registration_allowlist) + + if op_registration_allowlist is not None: + selector = SelectiveBuilder.from_legacy_op_registration_allow_list( + op_registration_allowlist, + True, + False, + ) + elif op_selection_yaml_path is not None: + selector = SelectiveBuilder.from_yaml_path(op_selection_yaml_path) + else: + selector = SelectiveBuilder.get_nop_selector() + + return selector + def main() -> None: parser = argparse.ArgumentParser(description='Generate ATen source files') parser.add_argument( @@ -835,17 +1181,22 @@ def main() -> None: '--rocm', action='store_true', help='reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly') - # TODO: remove this, we should just unconditionally generate Vulkan - parser.add_argument( - '--vulkan', - action='store_true', - help='Generate Vulkan backend functions') + # TODO: --op_registration_whitelist will be removed when all call-sites + # for gen.py are moved over to using the operator YAML file for mobile + # custom build. parser.add_argument( '--op_registration_whitelist', nargs='*', help='filter op registrations by the whitelist (if set); ' 'each item is `namespace`::`operator name` without overload name; ' 'e.g.: aten::empty aten::conv2d ...') + parser.add_argument( + '--op_selection_yaml_path', + help='Provide a path to the operator selection (for custom build) YAML ' + 'that contains the information about the set of selected operators ' + 'and their categories (training, ...). Each operator is either a ' + 'full operator name with overload or just a bare operator name. ' + 'The operator names also contain the namespace prefix (e.g. aten::)') parser.add_argument( '--backend_whitelist', nargs='*', @@ -858,14 +1209,31 @@ def main() -> None: 'those that are not listed on --op_registration_whitelist') options = parser.parse_args() - op_registration_whitelist: Optional[Set[str]] - if options.op_registration_whitelist is not None: - op_registration_whitelist = set(options.op_registration_whitelist) - else: - op_registration_whitelist = None + selector = get_custom_build_selector( + options.op_registration_whitelist, + options.op_selection_yaml_path, + ) native_functions = parse_native_yaml(os.path.join(options.source_path, 'native/native_functions.yaml')) + pre_grouped_native_functions: Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]] + pre_grouped_native_functions = defaultdict(dict) + for f in native_functions: + d = pre_grouped_native_functions[f.func.signature()] + assert f.func.kind() not in d + d[f.func.kind()] = f + + def flatten_pre_group(d: Dict[SchemaKind, NativeFunction]) -> Sequence[Union[NativeFunction, StructuredNativeFunctions]]: + r = StructuredNativeFunctions.from_dict(d) + if r is None: + return list(d.values()) + else: + return [r] + + # TODO: how come ValuesView isn't a Sequence lol + grouped_native_functions = list(concatMap(flatten_pre_group, list(pre_grouped_native_functions.values()))) + structured_native_functions = [g for g in grouped_native_functions if isinstance(g, StructuredNativeFunctions)] + template_dir = os.path.join(options.source_path, "templates") # NB: It is mandatory to NOT use os.path.join here, as the install directory @@ -889,127 +1257,101 @@ def make_file_manager(install_dir: str) -> FileManager: cuda_fm = make_file_manager(options.install_dir) extra_cuda_headers = '''\ -#include +#include #include #include #include ''' if options.rocm: extra_cuda_headers = '''\ -#include +#include #include #include #include ''' - backends = ["CPU", "SparseCPU", "MkldnnCPU", "CUDA", "SparseCUDA", "QuantizedCPU", "QuantizedCUDA"] - if options.vulkan: - backends.append("Vulkan") + # NB: substrings in these dispatch keys matter, we do tests to see if + # a key contains, e.g., CUDA to classify it as a CUDA backend + dispatch_keys = [ + "CPU", + "SparseCPU", + "MkldnnCPU", + "CUDA", + "SparseCUDA", + "QuantizedCPU", + "QuantizedCUDA", + "Math", + "DefaultBackend", + # Meta is a magic key: it is automatically generated for structured + # kernels + "Meta", + ] if options.backend_whitelist: - backends = [b for b in backends if b in options.backend_whitelist] - - for dispatch in backends: - h_template = 'TypeDerived.h' - cpp_template = 'TypeDerived.cpp' - # TODO: delete this special case - if 'Sparse' in dispatch: - cpp_template = 'SparseTypeDerived.cpp' - - fm = cuda_fm if 'CUDA' in dispatch else cpu_fm - - fm.write_with_template(f'{dispatch}Type.h', h_template, lambda: { - 'Type': f'{dispatch}Type', - 'extra_cuda_headers': extra_cuda_headers if 'CUDA' in dispatch else '', # TODO: remove this - 'type_derived_method_declarations': list(mapMaybe( - compute_type_method(dispatch, target=Target.DECLARATION, op_registration_whitelist=op_registration_whitelist), - native_functions - )), - }) - fm.write_with_template(f'{dispatch}Type.cpp', cpp_template, lambda: { - 'Type': f'{dispatch}Type', - # TODO: remove this - 'extra_cuda_headers': extra_cuda_headers if 'CUDA' in dispatch else '', - # TODO: remove this - 'storage_tensor_headers': '#include ', - # TODO: remove this - 'Generator': 'CUDAGeneratorImpl' if 'CUDA' in dispatch else 'CPUGeneratorImpl', + dispatch_keys = [k for k in dispatch_keys if is_generic_dispatch_key(k) or k in options.backend_whitelist] + + for dispatch_key in dispatch_keys: + cpp_template = 'RegisterDispatchKey.cpp' + + fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm + + fm.write_with_template(f'Register{dispatch_key}.cpp', cpp_template, lambda: { + 'extra_cuda_headers': extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else '', 'legacy_th_headers': - '#include ' if dispatch == "CPU" else - '#include ' if dispatch == "CUDA" else + '#include ' if dispatch_key == "CPU" else + '#include ' if dispatch_key == "CUDA" else '', - 'Backend': dispatch, - 'type_derived_method_definitions': list(mapMaybe( - compute_type_method(dispatch, target=Target.DEFINITION, op_registration_whitelist=op_registration_whitelist), - native_functions + 'DispatchKey': dispatch_key, + 'dispatch_definitions': list(concatMap( + RegisterDispatchKey(dispatch_key, Target.DEFINITION, selector, rocm=options.rocm), + grouped_native_functions + )), + 'dispatch_registrations': list(concatMap( + RegisterDispatchKey(dispatch_key, Target.REGISTRATION, selector, rocm=options.rocm), + grouped_native_functions )), - 'function_registrations': list(mapMaybe( - compute_type_method( - dispatch, target=Target.REGISTRATION, op_registration_whitelist=op_registration_whitelist), - native_functions)), }) del fm - cpu_fm.write('TypeDefault.h', lambda: { - 'type_method_declarations': - list(mapMaybe( - compute_type_method(None, target=Target.DECLARATION, op_registration_whitelist=op_registration_whitelist), - native_functions)) + - list(mapMaybe( - compute_type_method('Math', target=Target.DECLARATION, op_registration_whitelist=op_registration_whitelist), - native_functions)), + # BackendSelect is generated specially + cpu_fm.write('RegisterBackendSelect.cpp', lambda: { + 'backend_select_method_definitions': + list(mapMaybe(ComputeBackendSelect(Target.DEFINITION), native_functions)), + 'backend_select_function_registrations': + list(mapMaybe(ComputeBackendSelect(Target.REGISTRATION), native_functions)), + }) + cpu_fm.write('MetaFunctions.h', lambda: { + 'declarations': list(map(compute_meta_function_declaration, structured_native_functions)), }) - cpu_fm.write('TypeDefault.cpp', lambda: { - 'type_method_definitions': - list(mapMaybe( - compute_type_method(None, target=Target.DEFINITION, op_registration_whitelist=op_registration_whitelist), - native_functions)) + - list(mapMaybe( - compute_type_method('Math', target=Target.DEFINITION, op_registration_whitelist=op_registration_whitelist), - native_functions)), - - 'function_registrations': list(mapMaybe( - compute_type_method(None, target=Target.REGISTRATION, op_registration_whitelist=op_registration_whitelist), - native_functions)), - - 'math_function_registrations': list(mapMaybe( - compute_type_method('Math', target=Target.REGISTRATION, op_registration_whitelist=op_registration_whitelist), - native_functions)), + + schema_selector = selector + if options.force_schema_registration: + schema_selector = SelectiveBuilder.get_nop_selector() + cpu_fm.write('RegisterSchema.cpp', lambda: { + 'schema_registrations': list(mapMaybe(RegisterSchema(schema_selector), native_functions)), }) + cpu_fm.write('Functions.h', lambda: { - 'function_declarations': list(mapMaybe(compute_function(target=Target.DECLARATION), native_functions)), + 'function_declarations': list(mapMaybe(ComputeFunction(Target.DECLARATION), native_functions)), }) cpu_fm.write('Functions.cpp', lambda: { - 'function_definitions': list(mapMaybe(compute_function(target=Target.DEFINITION), native_functions)), + 'function_definitions': list(mapMaybe(ComputeFunction(Target.DEFINITION), native_functions)), }) core_fm.write('TensorBody.h', lambda: { - 'tensor_method_declarations': list(mapMaybe(compute_tensor_method(target=Target.DECLARATION), native_functions)), + 'tensor_method_declarations': list(mapMaybe(ComputeTensorMethod(Target.DECLARATION), native_functions)), }) core_fm.write('TensorMethods.cpp', lambda: { - 'tensor_method_definitions': list(mapMaybe(compute_tensor_method(target=Target.DEFINITION), native_functions)), + 'tensor_method_definitions': list(mapMaybe(ComputeTensorMethod(Target.DEFINITION), native_functions)), }) core_fm.write('ATenOpList.cpp', lambda: { 'aten_ops': list(mapMaybe(compute_aten_op, native_functions)), }) cpu_fm.write('NativeFunctions.h', lambda: { - 'native_function_declarations': list(concatMap(compute_native_function_declaration, native_functions)), - }) - cpu_fm.write('BackendSelectRegister.cpp', lambda: { - 'backend_select_method_definitions': - list(mapMaybe(compute_backend_select(target=Target.DEFINITION), native_functions)), - 'backend_select_function_registrations': - list(mapMaybe(compute_backend_select(target=Target.REGISTRATION), native_functions)), + 'native_function_declarations': list(concatMap(compute_native_function_declaration, grouped_native_functions)), }) - if options.force_schema_registration: - def computeSchemaRegister() -> Dict[str, object]: - schema_registrations = list(mapMaybe( - compute_type_method(None, target=Target.REGISTRATION, op_registration_whitelist=None, def_only=True), - native_functions)) - return { - 'schema_registrations': schema_registrations, - } - cpu_fm.write('SchemaRegister.cpp', computeSchemaRegister) - - cpu_fm.write('Declarations.yaml', lambda: format_yaml(list(map(compute_declaration_yaml, native_functions)))) + cpu_fm.write('Declarations.yaml', lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions])) + cpu_fm.write('RegistrationDeclarations.h', lambda: { + 'registration_declarations': [compute_registration_declarations(f) for f in native_functions], + }) if options.output_dependencies: cpu_fm.write_outputs(options.output_dependencies) diff --git a/tools/codegen/model.py b/tools/codegen/model.py index b0c470c91b6a8..1128878fe45cc 100644 --- a/tools/codegen/model.py +++ b/tools/codegen/model.py @@ -1,7 +1,7 @@ import re from dataclasses import dataclass -from typing import List, Sequence, Dict, Optional, Iterator, Tuple, Set, NoReturn +from typing import List, Dict, Optional, Iterator, Tuple, Set, NoReturn, Sequence, Callable, Union from enum import Enum import itertools @@ -47,10 +47,9 @@ def __str__(self) -> str: # Valid values of the 'variants' field in native_functions.yaml Variant = Enum('Variant', ('function', 'method')) -UseC10Dispatcher = Enum('UseC10Dispatcher', ( - 'full', - 'with_codegenerated_unboxing_wrapper' -)) +class UseC10Dispatcher(Enum): + full = 0 + hacky_wrapper_for_legacy_signatures = 2 # The basic input to the code generation is native_functions.yaml. # The name "native", BTW, comes from the distinction between native @@ -74,7 +73,7 @@ class NativeFunction: func: 'FunctionSchema' # Corresponds to the 'use_c10_dispatcher' field. The default - # is 'with_codegenerated_unboxing_wrapper' + # is 'full' use_c10_dispatcher: UseC10Dispatcher # Whether or not to omit automatic generation of a DeviceGuard @@ -95,18 +94,65 @@ class NativeFunction: # registrations don't participate in codegen-based selective build! manual_kernel_registration: bool - # Distinguish between a missing dispatch dict (historically, this - # means to register a catch-all kernel) and a present but empty - # dispatch dict (this means register nothing; arguably, this should - # subsume manual_kernel_registration). + # Whether or not to skip generating TensorMethod/Functions bindings + # for this kernel. Technically, this doesn't actually skip generating + # the binding; instead, the binding gets generated to __dispatch_{funcname} + # so you can make use of the normal binding if you need it. + manual_cpp_binding: bool + + # A mapping of dispatch keys to names of functions implementing + # them. In native_functions.yaml, the dispatch entry is optional; in that + # case, that is equivalent to having written: + # + # dispatch: + # Math: $operator_name # # TODO: str key could be replaced with more explicit enum - dispatch: Optional[Dict[str, str]] + dispatch: Dict[str, str] # The location in the YAML file were this native function entry was # defined. This is for conveniently reporting error messages! loc: 'Location' + # Whether or not this out functions is a "structured kernel". Structured + # kernels are defined a little differently from normal kernels; in + # particular, their shape checking logic is defined separately from + # the kernel. Only out functions can be structured; other functions + # delegate to the out function using the structured_delegate keyword. + # Every structured kernel must have at least an out and a functional + # variant. + structured: bool + + # Whether or not this non-out function is a structured kernel, defined + # in terms of the out kernel referenced by the string here. + structured_delegate: Optional['OperatorName'] + + # Only valid for structured kernels. Specifies alternative of what + # to inherit from when defining the meta class for the structured + # operator. This will usually be TensorIteratorBase. This also + # changes the semantics of set_output to call the parent class. + structured_inherits: Optional[str] + + # Argument names whose default should be excluded from the C++ interface. + # Intended for resolving overload ambiguities between signatures. + cpp_no_default_args: Set[str] + + # Note [Abstract ATen methods] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # An abstract ATen method is one whose dispatch differs between + # types. These are implemented in derived types (with a + # standard (throwing) definition in Type). A concrete ATen + # method is one which has the same dispatch for all types; + # we just implement it in the base Type. This is exposed + # in Declarations.yaml via a field named 'abstract'. + @property + def is_abstract(self) -> bool: + if self.structured_delegate: + # Structured functions MUST have a dispatch table + return True + else: + return self.dispatch.keys() != {'Math'} + # NB: The benefit of defining a dataclass is that we automatically get # a constructor defined for all the fields we specify. No need # to explicitly write it out. @@ -123,14 +169,18 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction': assert isinstance(funcs, str), f'not a str: {funcs}' func = FunctionSchema.parse(funcs) - use_c10_dispatcher_s = e.pop('use_c10_dispatcher', None) - if use_c10_dispatcher_s is None: - use_c10_dispatcher = UseC10Dispatcher.with_codegenerated_unboxing_wrapper - elif use_c10_dispatcher_s == 'full': + cpp_no_default_args_list = e.pop('cpp_no_default_args', []) + assert isinstance(cpp_no_default_args_list, list) + cpp_no_default_args = set(cpp_no_default_args_list) + + use_c10_dispatcher_s = e.pop('use_c10_dispatcher', 'full') + if use_c10_dispatcher_s == 'full': use_c10_dispatcher = UseC10Dispatcher.full + elif use_c10_dispatcher_s == 'hacky_wrapper_for_legacy_signatures': + use_c10_dispatcher = UseC10Dispatcher.hacky_wrapper_for_legacy_signatures else: raise AssertionError( - f'use_c10_dispatcher must be unset or set to full, got {use_c10_dispatcher}') + f'use_c10_dispatcher must be full or hacky_wrapper_for_legacy_signatures, got {use_c10_dispatcher}') variants_s = e.pop('variants', 'function') assert isinstance(variants_s, str) @@ -146,9 +196,24 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction': manual_kernel_registration = e.pop('manual_kernel_registration', False) assert isinstance(manual_kernel_registration, bool), f'not a bool: {manual_kernel_registration}' + manual_cpp_binding = e.pop('manual_cpp_binding', False) + assert isinstance(manual_cpp_binding, bool), f'not a bool: {manual_cpp_binding}' + device_guard = e.pop('device_guard', True) assert isinstance(device_guard, bool), f'not a bool: {device_guard}' + structured = e.pop('structured', False) + assert isinstance(structured, bool), f'not a bool: {structured}' + + structured_delegate_s = e.pop('structured_delegate', None) + assert structured_delegate_s is None or isinstance(structured_delegate_s, str), f'not a str: {structured_delegate}' + structured_delegate: Optional[OperatorName] = None + if structured_delegate_s is not None: + structured_delegate = OperatorName.parse(structured_delegate_s) + + structured_inherits = e.pop('structured_inherits', None) + assert structured_inherits is None or isinstance(structured_inherits, str), f'not a str: {structured_inherits}' + python_module = e.pop('python_module', None) assert python_module is None or isinstance(python_module, str), f'not a str: {python_module}' @@ -157,9 +222,11 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction': raw_dispatch = e.pop('dispatch', None) assert raw_dispatch is None or isinstance(raw_dispatch, dict), e - dispatch: Optional[Dict[str, str]] = None + dispatch: Dict[str, str] = {} if raw_dispatch is not None: - dispatch = {} + assert not manual_kernel_registration, \ + "cannot specify both manual_kernel_registration and dispatch; with " \ + "manual registration, dispatch has no effect!" for ks, v in raw_dispatch.items(): if ks == '__line__': continue # not worth tracking line numbers for dispatch entries @@ -167,6 +234,14 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction': assert isinstance(v, str), e for k in ks.split(","): dispatch[k.strip()] = v + else: + from tools.codegen.api import cpp + dispatch['Math'] = cpp.name(func) + + assert not ('DefaultBackend' in dispatch and 'Math' in dispatch), \ + "cannot specify both DefaultBackend and Math on a single kernel; each " \ + "strictly subsumes the other. If you wanted to provide an explicit autograd " \ + "implementation, specify DefaultBackend; otherwise specify Math only" e.pop('__line__') assert not e, f"leftover entries: {e}" @@ -175,14 +250,27 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction': func=func, use_c10_dispatcher=use_c10_dispatcher, variants=variants, + structured=structured, + structured_delegate=structured_delegate, + structured_inherits=structured_inherits, manual_kernel_registration=manual_kernel_registration, + manual_cpp_binding=manual_cpp_binding, python_module=python_module, category_override=category_override, dispatch=dispatch, device_guard=device_guard, loc=loc, + cpp_no_default_args=cpp_no_default_args, ) + def validate_unstructured(self) -> None: + # TODO: probably better to accumulate these errors and report them all + # at once + assert not self.structured, "This function is structured, but there was " \ + "no valid functional variant of it." + assert self.structured_delegate, "This function delegates to another structured out function, " \ + "but no valid function was found (the delegate may not exist, or it has the wrong type)" + # __post_init__ functions in dataclasses can be used to do extra # validation after construction. # @@ -191,11 +279,83 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction': # Validation is for nontrivial invariants that cannot be (conveniently) # encoded in the type system. def __post_init__(self) -> None: - if self.func.out_arguments: + if self.func.arguments.out: assert self.variants == {Variant.function}, "Native functions with out arguments MUST " \ "be declared with only function variant; e.g., variants: function; " \ "otherwise you will tickle a Python argument binding bug " \ "(which usually manifests itself as the result variable being undefined.)" + if self.structured: + assert self.func.kind() == SchemaKind.out, "Put structured field on the out= " \ + "variant of a function; did you mean structured_delegate?" + assert self.device_guard, "device_guard: False is not respected by structured kernels" + if self.structured_delegate: + assert self.func.kind() != SchemaKind.out, "structured_delegate field not allowed " \ + "on out= functions; did you mean structured?" + assert self.device_guard, "device_guard: False is not respected by structured kernels" + # Technically, with the asserts above, this assert is impossible to + # happen + assert not (self.structured and self.structured_delegate), \ + "Cannot have both structured and structured_delegate on function" + defaulted_arguments = {a.name for a in self.func.schema_order_arguments() + if a.default is not None} + invalid_args = set.difference(self.cpp_no_default_args, defaulted_arguments) + assert len(invalid_args) == 0, f'Invalid cpp_no_default_args: {invalid_args}' + if self.structured or self.structured_delegate: + assert self.use_c10_dispatcher is UseC10Dispatcher.full, \ + "Structured kernels MUST be use_c10_dispatcher: full; port your argument order" + +SchemaKind = Enum('SchemaKind', ('functional', 'inplace', 'out')) + +# A structured kernel is guaranteed to have a functional and out variant, and +# optionally an inplace variant. +@dataclass(frozen=True) +class StructuredNativeFunctions: + functional: NativeFunction + inplace: Optional[NativeFunction] + out: NativeFunction + + def __post_init__(self) -> None: + test_sig: FunctionSchema = self.functional.func.signature() + for f in self.functions(): + if test_sig != f.func.signature(): + raise AssertionError( + "StructuredNativeFunctions constructed from two NativeFunctions " + f"that don't have matching signatures: {test_sig} != {f.func.signature()}" + ) + assert self.functional.func.kind() == SchemaKind.functional + assert self.functional.structured_delegate == self.out.func.name, \ + f"{self.functional.func.name} delegates to {self.functional.structured_delegate} " \ + f"but its actual delegate is {self.out.func.name}" + assert self.out.func.kind() == SchemaKind.out + assert self.out.structured + # For now, structured composite kernels are not supported (need some + # design work to figure out how to make the composite case work) + assert self.out.dispatch.keys() != {'Math'} + if self.inplace is not None: + assert self.inplace.func.kind() == SchemaKind.inplace + assert self.inplace.structured_delegate == self.out.func.name + + def signature(self) -> 'FunctionSchema': + return self.out.func.signature() + + def functions(self) -> Iterator[NativeFunction]: + yield self.out + yield self.functional + if self.inplace is not None: + yield self.inplace + + @staticmethod + def from_dict(d: Dict[SchemaKind, NativeFunction]) -> Optional['StructuredNativeFunctions']: + functional = d.get(SchemaKind.functional) + inplace = d.get(SchemaKind.inplace) + out = d.get(SchemaKind.out) + if functional is None or out is None or not out.structured: + return None + return StructuredNativeFunctions( + functional=functional, + inplace=inplace, + out=out, + ) # The function schema is undoubtedly the most important data structure # in all of the codegen, as it defines the type signature for operators, @@ -255,21 +415,17 @@ class FunctionSchema: # The name of the operator this function schema describes. name: 'OperatorName' - # NB: Sequence here is intentional, to make it read only - arguments: Sequence['Argument'] - kwarg_only_arguments: Sequence['Argument'] # but not including out args - # Unlike in the previous codegen, we have factored out 'out' arguments - # in the canonical representation, removing them from kwarg - # arguments. This choice is justified by numerous downstream - # transformations which treat out arguments specially; additionally, - # you can see that canonicity is not violated! - out_arguments: Sequence['Argument'] # these are also kwarg-only + arguments: 'Arguments' # TODO: Need to handle collisions with argument names at some point - returns: Sequence['Return'] + returns: Tuple['Return', ...] def schema_order_arguments(self) -> Iterator['Argument']: - return itertools.chain(self.arguments, self.kwarg_only_arguments, self.out_arguments) + return itertools.chain( + self.arguments.flat_positional, + self.arguments.flat_kwarg_only, + self.arguments.out + ) @staticmethod def parse(func: str) -> 'FunctionSchema': @@ -280,30 +436,42 @@ def parse(func: str) -> 'FunctionSchema': assert args[-1] == ")", "Expecting closing )" args = args[:-1] name = OperatorName.parse(ops) - arguments, kwarg_only_arguments, out_arguments = parse_arguments(args) + arguments = Arguments.parse(args) returns = parse_returns(return_decl) r = FunctionSchema( name=name, arguments=arguments, - kwarg_only_arguments=kwarg_only_arguments, - out_arguments=out_arguments, returns=returns ) assert str(r) == func, f'{str(r)} != {func}' return r def __post_init__(self) -> None: - for arg, ret in zip(self.out_arguments, self.returns): + for arg, ret in zip(self.arguments.out, self.returns): assert arg.annotation == ret.annotation, \ "Out arguments must have matching return Tensor; furthermore, " \ "the ith-argument needs to correspond to the ith return" - if self.out_arguments: - assert len(self.out_arguments) == len(self.returns), \ + # Invariant: we expect out arguments to appear as keyword arguments in the schema. + # This means that all mutable returns should be aliased to a keyword argument + # (except for "self", which we explicitly don't treat as an out argument because of its use in methods) + # See Note [is_out_fn] + out_and_self = list(self.arguments.out) + [arg for arg in self.arguments.flat_positional if arg.name == "self"] + mutable_returns = [ret for ret in self.returns if ret.annotation is not None and ret.annotation.is_write] + for ret in mutable_returns: + assert any([ret.annotation == arg.annotation for arg in out_and_self]), \ + "All mutable returns must be aliased either to a keyword argument, or to \"self\". " \ + "Did you forget to mark an out argument as keyword-only?" + if self.arguments.out: + assert len(self.arguments.out) == len(self.returns), \ "Must return as many arguments as there are out arguments" if self.name.name.inplace: # TODO: fixme if str(self.name) not in [ - '_amp_non_finite_check_and_unscale_', + '_amp_foreach_non_finite_check_and_unscale_', + '_foreach_add_.ScalarList', + '_foreach_sub_.ScalarList', + '_foreach_mul_.ScalarList', + '_foreach_div_.ScalarList', '_foreach_add_.Scalar', '_foreach_sub_.Scalar', '_foreach_mul_.Scalar', @@ -314,8 +482,37 @@ def __post_init__(self) -> None: '_foreach_div_.List', '_foreach_exp_', '_foreach_sqrt_', - '_foreach_addcmul_', - '_foreach_addcdiv_']: + '_foreach_abs_', + '_foreach_acos_', + '_foreach_asin_', + '_foreach_atan_', + '_foreach_ceil_', + '_foreach_cos_', + '_foreach_cosh_', + '_foreach_erf_', + '_foreach_erfc_', + '_foreach_expm1_', + '_foreach_floor_', + '_foreach_log_', + '_foreach_log10_', + '_foreach_log1p_', + '_foreach_log2_', + '_foreach_neg_', + '_foreach_tan_', + '_foreach_tanh_', + '_foreach_sin_', + '_foreach_sinh_', + '_foreach_round_', + '_foreach_lgamma_', + '_foreach_frac_', + '_foreach_reciprocal_', + '_foreach_sigmoid_', + '_foreach_trunc_', + '_foreach_addcmul_.Scalar', + '_foreach_addcdiv_.Scalar', + '_foreach_addcmul_.ScalarList', + '_foreach_addcdiv_.ScalarList', + '_foreach_zero_']: assert len(self.returns) == 1 def is_out_fn(self) -> bool: @@ -345,16 +542,66 @@ def is_out_fn(self) -> bool: # but just with extra kwargs for the output elements. This # is difficult to actually check for and historically # we only do this check in tools/ - return bool(self.out_arguments) + return bool(self.arguments.out) + + def kind(self) -> SchemaKind: + """ + What kind of schema is this? A functional schema is one + that returns a newly allocated output; an inplace schema + modifies the self argument inplace; an out schema writes + the result into an explicitly provided out argument. + """ + is_inplace = self.name.name.inplace + is_out = bool(self.arguments.out) + assert not (is_inplace and is_out) + if is_inplace: + return SchemaKind.inplace + elif is_out: + return SchemaKind.out + else: + return SchemaKind.functional + + def signature(self, *, strip_default: bool = False) -> 'FunctionSchema': + """ + Certain schemas are 'related', in that they are simply + inplace/out/functional versions of the same function. This method + factors these schemas into the "core" functional signature which + is equal across all versions. + + Here is what normalization happens to the schema to convert + it to a signature: + - The overload name is stripped (name is retained, since + it expresses semantic content about what the function does) + - Inplace is set False + - Out arguments are stripped + - Mutability annotations are stripped (this is sound + because you cannot overload on mutability annotation) + - Return names are stripped since they are not overloadable and + some variants have return names but some not + """ + + def strip_ret_annotation(r: Return) -> Return: + return Return( + name=None, + type=r.type, + annotation=None, + ) + + return FunctionSchema( + name=OperatorName( + name=BaseOperatorName( + base=self.name.name.base, + inplace=False, + dunder_method=self.name.name.dunder_method, + ), + overload_name="", # stripped + ), + arguments=self.arguments.signature(strip_default=strip_default), + returns=tuple(map(strip_ret_annotation, self.returns)), + ) def __str__(self) -> str: - all_arguments: List[str] = [] - all_arguments.extend(map(str, self.arguments)) - if self.kwarg_only_arguments or self.out_arguments: - all_arguments.append('*') - all_arguments.extend(map(str, self.kwarg_only_arguments)) - all_arguments.extend(map(str, self.out_arguments)) - all_arguments_str = ', '.join(all_arguments) + all_arguments_str = str(self.arguments) if len(self.returns) == 1: returns = str(self.returns[0]) # omit parentheses else: @@ -372,14 +619,14 @@ def __str__(self) -> str: class Annotation: # Typically only has one element. Not actually a set so # we can conveniently assume it is canonically ordered - alias_set: Sequence[str] + alias_set: Tuple[str, ...] is_write: bool @staticmethod def parse(ann: str) -> 'Annotation': m = re.match(r'^([a-z])(!?)$', ann) assert m is not None, f'unrecognized alias annotation {ann}' - alias_set = [m.group(1)] + alias_set = (m.group(1),) is_write = m.group(2) == '!' r = Annotation(alias_set=alias_set, is_write=is_write) assert str(r) == ann, f'{r} != {ann}' @@ -451,6 +698,7 @@ def is_list_like(self) -> Optional['ListType']: 'MemoryFormat', 'QScheme', 'Storage', + 'Stream', 'ConstQuantizerPtr', # TODO: rename )) @@ -639,6 +887,253 @@ def __str__(self) -> str: return f"{type} {self.name}" +# Represents the self argument for functions that may be methods +@dataclass(frozen=True) +class SelfArgument: + argument: Argument + +# Bundle of arguments that represent a TensorOptions. This is mostly +# relevant for the public C++ API but we bake it into the core data +# model because other APIs often have to interact with it +@dataclass(frozen=True) +class TensorOptionsArguments: + dtype: Argument + layout: Argument + device: Argument + pin_memory: Argument + + def all(self) -> Sequence[Argument]: + return [self.dtype, self.layout, self.device, self.pin_memory] + +@dataclass(frozen=True) +class Arguments: + # pre_self_positional is usually empty, but is notably non-empty + # for where.self, where the condition argument comes before the + # self argument + pre_self_positional: Tuple[Argument, ...] + self_arg: Optional[SelfArgument] + post_self_positional: Tuple[Argument, ...] + + pre_tensor_options_kwarg_only: Tuple[Argument, ...] + tensor_options: Optional[TensorOptionsArguments] + # post_tensor_options is typically memory format, which should be + # part of tensor options but isn't right now, and is usually + # placed after the tensor options arguments + post_tensor_options_kwarg_only: Tuple[Argument, ...] + + # Unlike in the previous codegen, we have factored out 'out' arguments + # in the canonical representation, removing them from kwarg + # arguments. This choice is justified by numerous downstream + # transformations which treat out arguments specially; additionally, + # you can see that canonicity is not violated! + out: Tuple[Argument, ...] # these are also kwarg-only + + @property + def flat_non_out(self) -> Sequence[Argument]: + ret: List[Argument] = [] + ret.extend(self.flat_positional) + ret.extend(self.flat_kwarg_only) + return ret + + @property + def flat_positional(self) -> Sequence[Argument]: + ret: List[Argument] = [] + ret.extend(self.pre_self_positional) + if self.self_arg is not None: + ret.append(self.self_arg.argument) + ret.extend(self.post_self_positional) + return ret + + # NB: doesn't contain out arguments + @property + def flat_kwarg_only(self) -> Sequence[Argument]: + ret: List[Argument] = [] + ret.extend(self.pre_tensor_options_kwarg_only) + if self.tensor_options is not None: + ret.extend(self.tensor_options.all()) + ret.extend(self.post_tensor_options_kwarg_only) + return ret + + @property + def non_out(self) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]: + ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = [] + ret.extend(self.positional) + ret.extend(self.kwarg_only) + return ret + + @property + def positional(self) -> Sequence[Union[Argument, SelfArgument]]: + ret: List[Union[Argument, SelfArgument]] = [] + ret.extend(self.pre_self_positional) + if self.self_arg is not None: + ret.append(self.self_arg) + ret.extend(self.post_self_positional) + return ret + + @property + def kwarg_only(self) -> Sequence[Union[Argument, TensorOptionsArguments]]: + ret: List[Union[Argument, TensorOptionsArguments]] = [] + ret.extend(self.pre_tensor_options_kwarg_only) + if self.tensor_options is not None: + ret.append(self.tensor_options) + ret.extend(self.post_tensor_options_kwarg_only) + return ret + + def signature(self, *, strip_default: bool = False) -> 'Arguments': + # dataclasses.replace could be used here, but it is less + # type safe so for now I've opted to type everything out + def strip_arg_annotation(a: Argument) -> Argument: + return Argument( + name=a.name, + type=a.type, + default=a.default if not strip_default else None, + annotation=None, + ) + + return Arguments( + pre_self_positional=tuple(map(strip_arg_annotation, self.pre_self_positional)), + self_arg=SelfArgument( + strip_arg_annotation(self.self_arg.argument) + ) if self.self_arg is not None else None, + post_self_positional=tuple(map(strip_arg_annotation, self.post_self_positional)), + pre_tensor_options_kwarg_only=tuple(map(strip_arg_annotation, self.pre_tensor_options_kwarg_only)), + # NB: tensor_options guaranteed to not have any alias annotations + tensor_options=self.tensor_options, + post_tensor_options_kwarg_only=tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)), + # out arguments are dropped in signature + out=(), + ) + + + @staticmethod + def _preparse(args: str) -> Tuple[List[Argument], List[Argument], List[Argument]]: + positional: List[Argument] = [] + kwarg_only: List[Argument] = [] + out: List[Argument] = [] + arguments_acc = positional + + # TODO: Use a real parser here; this will get bamboozled + # by signatures that contain things like std::array (note the space) + for arg in args.split(', '): + if not arg: + continue + if arg == '*': + assert arguments_acc is positional, "invalid syntax: kwarg-only specifier * can only occur once" + arguments_acc = kwarg_only + continue + parg = Argument.parse(arg) + # Currently, we rely directly on the invariant that there are NO + # kwarg-only mutating arguments. If you want to relax this, + # we will need a more semantic way of matching that takes + # into account return arguments. In that case, you will have + # to manage out computation a level up, in FunctionSchema. See Note + # [is_out_fn] + if parg.annotation is not None and parg.annotation.is_write: + if arguments_acc is positional: + pass # do nothing + elif arguments_acc is kwarg_only: + arguments_acc = out + else: + assert arguments_acc is not out + arguments_acc.append(parg) + + return positional, kwarg_only, out + + @staticmethod + def parse(args: str) -> 'Arguments': + """ + Input: 'int x, int y, int z' + """ + + # We do this in two phases. First we parse into three + # main categories: positional, kwarg_only, out. + # Then, we reparse positional and kwarg_only to separate + # out the self argument and tensor options arguments. + + positional, kwarg_only, out = Arguments._preparse(args) + + # Split self argument + self_ix = None + for i, a in enumerate(positional): + if a.name == "self": + self_ix = i + break + pre_self_positional: List[Argument] + self_arg: Optional[SelfArgument] + post_self_positional: List[Argument] + if self_ix is not None: + pre_self_positional = positional[:self_ix] + self_arg = SelfArgument(positional[self_ix]) + post_self_positional = positional[self_ix + 1:] + else: + pre_self_positional = [] + self_arg = None + post_self_positional = positional + + # Group tensor options arguments + pre_tensor_options_kwarg_only: List[Argument] = [] + tensor_options: Optional[TensorOptionsArguments] = None + post_tensor_options_kwarg_only: List[Argument] = [] + kwarg_only_acc = pre_tensor_options_kwarg_only + + def pred(name: str, ty: Type) -> Callable[[Argument], bool]: + return lambda a: a.name == name and a.type in [ty, OptionalType(ty)] + predicates = [ # order matters + pred('dtype', Type.parse('ScalarType')), + pred('layout', Type.parse('Layout')), + pred('device', Type.parse('Device')), + pred('pin_memory', Type.parse('bool')), + ] + + i = 0 + while i < len(kwarg_only): + # If there is enough space... + if i <= len(kwarg_only) - len(predicates): + # And the next len(predicates) arguments look like TensorOptions arguments + if all(p(a) for p, a in zip(predicates, kwarg_only[i : i + len(predicates)])): + assert kwarg_only_acc is pre_tensor_options_kwarg_only + # Group them together as one argument + tensor_options = TensorOptionsArguments( + dtype=kwarg_only[i], + layout=kwarg_only[i + 1], + device=kwarg_only[i + 2], + pin_memory=kwarg_only[i + 3], + ) + i += len(predicates) + kwarg_only_acc = post_tensor_options_kwarg_only + continue + kwarg_only_acc.append(kwarg_only[i]) + i += 1 + + return Arguments( + pre_self_positional=tuple(pre_self_positional), + self_arg=self_arg, + post_self_positional=tuple(post_self_positional), + pre_tensor_options_kwarg_only=tuple(pre_tensor_options_kwarg_only), + tensor_options=tensor_options, + post_tensor_options_kwarg_only=tuple(post_tensor_options_kwarg_only), + out=tuple(out), + ) + + + def __str__(self) -> str: + all_arguments: List[str] = [] + all_arguments.extend(map(str, self.flat_positional)) + if self.flat_kwarg_only or self.out: + all_arguments.append('*') + all_arguments.extend(map(str, self.flat_kwarg_only)) + all_arguments.extend(map(str, self.out)) + return ', '.join(all_arguments) + + def __post_init__(self) -> None: + # TODO: These invariants are weirdly asymmetric? + # TODO: Fancier types? + if self.self_arg is None: + assert not self.pre_self_positional + if self.tensor_options is None: + assert not self.post_tensor_options_kwarg_only + + # Names that validly are __iXXX__ indicating inplace operations. # Taken from https://www.python.org/dev/peps/pep-0203/#new-methods # NB: PyTorch hasn't actually implemented all of these @@ -725,53 +1220,13 @@ def __str__(self) -> str: # Helper functions for parsing argument lists (both inputs and returns) -def parse_returns(return_decl: str) -> Sequence[Return]: +def parse_returns(return_decl: str) -> Tuple[Return, ...]: """ Input: '()' Output: [] """ if return_decl == '()': - return [] + return () if return_decl[0] == '(' and return_decl[-1] == ')': return_decl = return_decl[1:-1] - returns = [] - for arg in return_decl.split(', '): - returns.append(Return.parse(arg)) - return returns - -def parse_arguments(args: str) -> Tuple[Sequence[Argument], Sequence[Argument], Sequence[Argument]]: - """ - Input: 'int x, int y, int z' - Output: positional args, kwarg only args - """ - arguments: List[Argument] = [] - kwarg_only_arguments: List[Argument] = [] - out_arguments: List[Argument] = [] - arguments_acc = arguments - - # TODO: Use a real parser here; this will get bamboozled - # by signatures that contain things like std::array (note the space) - for arg in args.split(', '): - if not arg: - continue - if arg == '*': - assert arguments_acc is arguments, "invalid syntax: kwarg-only specifier * can only occur once" - arguments_acc = kwarg_only_arguments - continue - parg = Argument.parse(arg) - # Currently, we rely directly on the invariant that there are NO - # kwarg-only mutating arguments. If you want to relax this, - # we will need a more semantic way of matching that takes - # into account return arguments. In that case, you will have - # to manage out_arguments computation a level up, in - # FunctionSchema. See Note [is_out_fn] - if parg.annotation is not None and parg.annotation.is_write: - if arguments_acc is arguments: - pass # do nothing - elif arguments_acc is kwarg_only_arguments: - arguments_acc = out_arguments - else: - assert arguments_acc is not out_arguments - arguments_acc.append(parg) - - return arguments, kwarg_only_arguments, out_arguments + return tuple(Return.parse(arg) for arg in return_decl.split(', ')) diff --git a/torch/utils/_benchmark/examples/__init__.py b/tools/codegen/selective_build/__init__.py similarity index 100% rename from torch/utils/_benchmark/examples/__init__.py rename to tools/codegen/selective_build/__init__.py diff --git a/tools/codegen/selective_build/operator.py b/tools/codegen/selective_build/operator.py new file mode 100644 index 0000000000000..68d4ba634fc1c --- /dev/null +++ b/tools/codegen/selective_build/operator.py @@ -0,0 +1,159 @@ +from typing import Dict, Optional, Tuple +from dataclasses import dataclass + +# This class holds information about a single operator used to determine +# the outcome of a selective/custom PyTorch build that doesn't include +# registration code for all the supported operators. This is done to +# reduce the size of the generated binary so that it can be deployed in +# situations where binary size comes at a premium. +# +@dataclass(frozen=True) +class SelectiveBuildOperator(): + # The name of the operator. This includes the aten::, etc... prefix + # The operator name may or may not have the overload name. If this + # operator name does not specify an overload name, the way to determine + # if this entry refers to the family of operators with this base name + # or just the operator with this name is to look at the value of the + # 'include_all_overloads' flag in this class. + name: str + + # True if this is a root operator (i.e. called directly from a + # TorchScript model, etc...). An operator is considered to be a + # root operator if it is called directly from any one of the models + # that this instance of the pytorch library was built for. Hence, it + # may not be a root operator in all of the models that are used in + # this instance of the pytorch library. + is_root_operator: bool + + # Is this operator used for on-device training? If True, then we need to + # use the information to generate code in VariableType_N.cpp for registration + # of training related operators. Again, this is True if this operator + # is used for training in one or more models used by this instance of the + # pytorch library. + is_used_for_training: bool + + # If True, it indicates that this operator instance (object) refers to an + # operator without the overload name and should apply to all overloads + # which have this operator name as the base name. This flag is applicable + # only for objects that have operator names without a DOT (period) character + # in them. + # + # Note: This flag is a temporary workaround to grandfather in the current + # static selective (custom) build mechanism, which largely ignores overload + # names when determining whether to select operators for registration + # purposes. + include_all_overloads: bool + + # Debug Information at the operator level + _debug_info: Optional[Tuple[str, ...]] + + @staticmethod + def from_yaml_dict(op_name: str, op_info: Dict[str, object]) -> 'SelectiveBuildOperator': + allowed_keys = {'name', 'is_root_operator', 'is_used_for_training', 'include_all_overloads', 'debug_info'} + + if len(set(op_info.keys()) - allowed_keys) > 0: + raise Exception("Got unexpected top level keys: {}".format( + ",".join(set(op_info.keys()) - allowed_keys), + )) + + if 'name' in op_info: + assert op_name == op_info['name'] + + is_root_operator = op_info.get('is_root_operator', True) + assert isinstance(is_root_operator, bool) + + is_used_for_training = op_info.get('is_used_for_training', True) + assert isinstance(is_used_for_training, bool) + + include_all_overloads = op_info.get('include_all_overloads', True) + assert isinstance(include_all_overloads, bool) + + debug_info: Optional[Tuple[str, ...]] = None + if 'debug_info' in op_info: + di_list = op_info['debug_info'] + assert isinstance(di_list, list) + debug_info = tuple(map(lambda x: str(x), di_list)) + + return SelectiveBuildOperator( + name=op_name, + is_root_operator=is_root_operator, + is_used_for_training=is_used_for_training, + include_all_overloads=include_all_overloads, + _debug_info=debug_info, + ) + + @staticmethod + def from_legacy_operator_name_without_overload(name: str) -> 'SelectiveBuildOperator': + return SelectiveBuildOperator( + name=name, + is_root_operator=True, + is_used_for_training=True, + include_all_overloads=True, + _debug_info=None, + ) + + def to_dict(self) -> Dict[str, object]: + ret: Dict[str, object] = { + 'is_root_operator': self.is_root_operator, + 'is_used_for_training': self.is_used_for_training, + 'include_all_overloads': self.include_all_overloads, + } + if self._debug_info is not None: + ret['debug_info'] = self._debug_info + + return ret + + +def merge_debug_info( + lhs: Optional[Tuple[str, ...]], + rhs: Optional[Tuple[str, ...]], +) -> Optional[Tuple[str, ...]]: + # Ensure that when merging, each entry shows up just once. + if lhs is None and rhs is None: + return None + + return tuple(set((lhs or ()) + (rhs or ()))) + + +def combine_operators( + lhs: 'SelectiveBuildOperator', + rhs: 'SelectiveBuildOperator') -> 'SelectiveBuildOperator': + if str(lhs.name) != str(rhs.name): + raise Exception( + "Expected both arguments to have the same name, but got '{}' and '{}' instead".format( + str(lhs.name), + str(rhs.name), + ) + ) + + return SelectiveBuildOperator( + name=lhs.name, + # Consider this operator to be a root operator if it is a + # root operator in any of the models used in this instance of + # the pytorch library. + is_root_operator=lhs.is_root_operator or rhs.is_root_operator, + # Consider this operator to be a training operator if it is + # an operator used for training in any of the models used + # in this instance of the pytorch library. + is_used_for_training=lhs.is_used_for_training or rhs.is_used_for_training, + include_all_overloads=lhs.include_all_overloads or rhs.include_all_overloads, + _debug_info=merge_debug_info(lhs._debug_info, rhs._debug_info), + ) + +def merge_operator_dicts( + lhs: Dict[str, SelectiveBuildOperator], + rhs: Dict[str, SelectiveBuildOperator], +) -> Dict[str, SelectiveBuildOperator]: + operators: Dict[str, SelectiveBuildOperator] = {} + for (op_name, op) in list(lhs.items()) + list(rhs.items()): + new_op = op + if op_name in operators: + new_op = combine_operators(operators[op_name], op) + + operators[op_name] = new_op + + return operators + + +def strip_operator_overload_name(op_name: str) -> str: + return op_name.split(".")[0] diff --git a/tools/codegen/selective_build/selector.py b/tools/codegen/selective_build/selector.py new file mode 100644 index 0000000000000..eeb15049075e8 --- /dev/null +++ b/tools/codegen/selective_build/selector.py @@ -0,0 +1,240 @@ +from typing import Dict, Set, Optional, Tuple, List +import yaml + +from dataclasses import dataclass + +from tools.codegen.model import NativeFunction +from tools.codegen.selective_build.operator import * + +# A SelectiveBuilder holds information extracted from the selective build +# YAML specification. +# +# It includes information about the build's selectivity, the debug_info +# associated with this selective build (opaque string), and the set of +# operators that should be included in the build. +# +@dataclass(frozen=True) +class SelectiveBuilder: + + # If true, then the build is not selective, and includes all + # operators. + include_all_operators: bool + + # Debug Information at the selective/custom build level. + _debug_info: Optional[Tuple[str, ...]] + + # A dictionary of operator -> operator metadata. + operators: Dict[str, SelectiveBuildOperator] + + # A dictionary of selected kernel tags and dtypes. Typically a + # PyTorch Operator Kernel (function) may have many code paths + # that are specialized for many many Tensor dtypes, so it's not + # one per kernel function, but there could be many per kernel + # function. The tag isn't a kernel function name, but some fragment + # of the kernel function implementation itself. + kernel_metadata: Dict[str, List[str]] + + # If true, then fragments for all dtypes for all kernel functions + # are included. This is typically set when any one of the + # operator lists is generated from a mechanism other than + # tracing based selective build. + include_all_kernel_dtypes: bool + + @staticmethod + def get_nop_selector() -> 'SelectiveBuilder': + return SelectiveBuilder.from_yaml_dict({'include_all_operators': True}) + + @staticmethod + def from_yaml_dict(data: Dict[str, object]) -> 'SelectiveBuilder': + valid_top_level_keys = { + 'include_all_kernel_dtypes', + 'include_all_operators', + 'debug_info', + 'operators', + 'kernel_metadata', + } + top_level_keys = set(data.keys()) + if len(top_level_keys - valid_top_level_keys) > 0: + raise Exception("Got unexpected top level keys: {}".format( + ",".join(top_level_keys - valid_top_level_keys), + )) + include_all_operators = data.get('include_all_operators', False) + assert isinstance(include_all_operators, bool) + + debug_info = None + if 'debug_info' in data: + di_list = data['debug_info'] + assert isinstance(di_list, list) + + debug_info = tuple(map(lambda x: str(x), di_list)) + + operators = {} + operators_dict = data.get('operators', {}) + assert isinstance(operators_dict, dict) + + for (k, v) in operators_dict.items(): + operators[k] = SelectiveBuildOperator.from_yaml_dict(k, v) + + kernel_metadata = {} + kernel_metadata_dict = data.get('kernel_metadata', {}) + assert isinstance(kernel_metadata_dict, dict) + + for (k, v) in kernel_metadata_dict.items(): + kernel_metadata[str(k)] = list(map(lambda dtype: str(dtype), v)) + + include_all_kernel_dtypes = data.get('include_all_kernel_dtypes', False) + assert isinstance(include_all_kernel_dtypes, bool) + + return SelectiveBuilder( + include_all_operators, + debug_info, + operators, + kernel_metadata, + include_all_kernel_dtypes, + ) + + @staticmethod + def from_yaml_str(config_contents: str) -> 'SelectiveBuilder': + contents = yaml.load(config_contents) + return SelectiveBuilder.from_yaml_dict(contents) + + @staticmethod + def from_yaml_path(config_path: str) -> 'SelectiveBuilder': + with open(config_path, 'r') as f: + contents = yaml.load(f) + return SelectiveBuilder.from_yaml_dict(contents) + + @staticmethod + def from_legacy_op_registration_allow_list( + allow_list: Set[str], + is_root_operator: bool, + is_used_for_training: bool) -> 'SelectiveBuilder': + operators = {} + for op in allow_list: + operators[op] = { + 'name': op, + 'is_root_operator': is_root_operator, + 'is_used_for_training': is_used_for_training, + 'include_all_overloads': True, + } + return SelectiveBuilder.from_yaml_dict({ + 'operators': operators, + 'include_all_kernel_dtypes': True, + }) + + def is_operator_selected(self, name: str) -> bool: + if self.include_all_operators: + return True + + if name in self.operators: + return True + name = strip_operator_overload_name(name) + return name in self.operators and self.operators[name].include_all_overloads + + def is_native_function_selected(self, func: NativeFunction) -> bool: + op_name = op_name_from_native_function(func) + return self.is_operator_selected(op_name) + + def is_operator_selected_for_training(self, name: str) -> bool: + if not self.is_operator_selected(name): + return False + if self.include_all_operators: + return True + + not_training_op = SelectiveBuildOperator( + name='', + is_root_operator=False, + is_used_for_training=False, + include_all_overloads=False, + _debug_info=None, + ) + op = not_training_op + if name in self.operators: + op = self.operators[name] + + name = strip_operator_overload_name(name) + base_op = not_training_op + if name in self.operators: + base_op = self.operators[name] + + return ( + op.is_used_for_training or + (base_op.include_all_overloads and base_op.is_used_for_training) + ) + + def is_native_function_selected_for_training(self, func: NativeFunction) -> bool: + op_name = op_name_from_native_function(func) + return self.is_operator_selected_for_training(op_name) + + def is_root_operator(self, name: str) -> bool: + if not self.is_operator_selected(name): + return False + if self.include_all_operators: + return True + + if name in self.operators: + op: SelectiveBuildOperator = self.operators[name] + return op.is_root_operator + name = strip_operator_overload_name(name) + if name not in self.operators: + return False + base_op: SelectiveBuildOperator = self.operators[name] + return base_op.include_all_overloads and base_op.is_root_operator + + def is_kernel_dtype_selected(self, kernel_tag: str, dtype: str) -> bool: + if self.include_all_operators or self.include_all_kernel_dtypes: + return True + + return kernel_tag in self.kernel_metadata and dtype in self.kernel_metadata[kernel_tag] + + def to_dict(self) -> Dict[str, object]: + ret: Dict[str, object] = { + 'include_all_kernel_dtypes': self.include_all_kernel_dtypes, + 'include_all_operators': self.include_all_operators, + } + operators = {} + for (op_name, op) in self.operators.items(): + operators[op_name] = op.to_dict() + ret['operators'] = operators + + if self._debug_info is not None: + ret['debug_info'] = self._debug_info + + ret['kernel_metadata'] = {k: list(v) for (k, v) in self.kernel_metadata.items()} + + return ret + + +def merge_kernel_metadata( + lhs: Dict[str, List[str]], + rhs: Dict[str, List[str]], +) -> Dict[str, List[str]]: + kernel_metadata: Dict[str, List[str]] = {} + for (tag_name, dtypes) in list(lhs.items()) + list(rhs.items()): + dtypes_copy = set(dtypes) + if tag_name in kernel_metadata: + dtypes_copy |= set(kernel_metadata[tag_name]) + + kernel_metadata[tag_name] = list(dtypes_copy) + + return kernel_metadata + +def combine_selective_builders(lhs: SelectiveBuilder, rhs: SelectiveBuilder) -> SelectiveBuilder: + include_all_operators = lhs.include_all_operators or rhs.include_all_operators + debug_info = merge_debug_info(lhs._debug_info, rhs._debug_info) + operators = merge_operator_dicts(lhs.operators, rhs.operators) + kernel_metadata = merge_kernel_metadata(lhs.kernel_metadata, rhs.kernel_metadata) + include_all_kernel_dtypes = lhs.include_all_kernel_dtypes or rhs.include_all_kernel_dtypes + return SelectiveBuilder( + include_all_operators, + debug_info, + operators, + kernel_metadata, + include_all_kernel_dtypes, + ) + + +def op_name_from_native_function(f: NativeFunction) -> str: + # This was originally read from the 'operator_name_with_overload' field in the + # declaration dict, which was the part before the first '(' in 'schema_string'. + return f'aten::{f.func.name}' diff --git a/tools/codegen/utils.py b/tools/codegen/utils.py new file mode 100644 index 0000000000000..093c0e0bb724e --- /dev/null +++ b/tools/codegen/utils.py @@ -0,0 +1,14 @@ +import re +from typing import Tuple, List + +# Matches "foo" in "foo, bar" but not "foobar". Used to search for the +# occurrence of a parameter in the derivative formula +IDENT_REGEX = r'(^|\W){}($|\W)' + +# TODO: Use a real parser here; this will get bamboozled +def split_name_params(schema: str) -> Tuple[str, List[str]]: + m = re.match(r'(\w+)(\.\w+)?\((.*)\)', schema) + if m is None: + raise RuntimeError(f'Unsupported function schema: {schema}') + name, _, params = m.groups() + return name, params.split(', ') diff --git a/tools/fast_nvcc/fast_nvcc.py b/tools/fast_nvcc/fast_nvcc.py new file mode 100755 index 0000000000000..2a8d1d7314538 --- /dev/null +++ b/tools/fast_nvcc/fast_nvcc.py @@ -0,0 +1,463 @@ +#!/usr/bin/env python3 + +import argparse +import asyncio +import collections +import csv +import hashlib +import itertools +import os +import pathlib +import re +import shlex +import shutil +import subprocess +import sys +import time + + +help_msg = '''fast_nvcc [OPTION]... -- [NVCC_ARG]... + +Run the commands given by nvcc --dryrun, in parallel. + +All flags for this script itself (see the "optional arguments" section +of --help) must be passed before the first "--". Everything after that +first "--" is passed directly to nvcc, with the --dryrun argument added. + +This script only works with the "normal" execution path of nvcc, so for +instance passing --help (after "--") doesn't work since the --help +execution path doesn't compile anything, so adding --dryrun there gives +nothing in stderr. +''' +parser = argparse.ArgumentParser(help_msg) +parser.add_argument( + '--faithful', + action='store_true', + help="don't modify the commands given by nvcc (slower)", +) +parser.add_argument( + '--graph', + metavar='FILE.dot', + help='write Graphviz DOT file with execution graph', +) +parser.add_argument( + '--nvcc', + metavar='PATH', + default='nvcc', + help='path to nvcc (default is just "nvcc")', +) +parser.add_argument( + '--save', + metavar='DIR', + help='copy intermediate files from each command into DIR', +) +parser.add_argument( + '--sequential', + action='store_true', + help='sequence commands instead of using the graph (slower)', +) +parser.add_argument( + '--table', + metavar='FILE.csv', + help='write CSV with times and intermediate file sizes', +) +parser.add_argument( + '--verbose', + metavar='FILE.txt', + help='like nvcc --verbose, but expanded and into a file', +) +default_config = parser.parse_args([]) + + +# docs about temporary directories used by NVCC +url_base = 'https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html' +url_vars = f'{url_base}#keeping-intermediate-phase-files' + + +# regex for temporary file names +re_tmp = r'(? '{filename}'") + uniqueified.append(line) + return uniqueified + + +def make_rm_force(commands): + """ + Add --force to all rm commands. + """ + return [f'{c} --force' if c.startswith('rm ') else c for c in commands] + + +def print_verbose_output(*, env, commands, filename): + """ + Human-readably write nvcc --dryrun data to stderr. + """ + padding = len(str(len(commands) - 1)) + with open(filename, 'w') as f: + for name, val in env.items(): + print(f'#{" "*padding}$ {name}={val}', file=f) + for i, command in enumerate(commands): + prefix = f'{str(i).rjust(padding)}$ ' + print(f'#{prefix}{command[0]}', file=f) + for part in command[1:]: + print(f'#{" "*len(prefix)}{part}', file=f) + + +def straight_line_dependencies(commands): + """ + Return a straight-line dependency graph. + """ + return [({i - 1} if i > 0 else set()) for i in range(len(commands))] + + +def files_mentioned(command): + """ + Return fully-qualified names of all tmp files referenced by command. + """ + return [f'/tmp/{match.group(1)}' for match in re.finditer(re_tmp, command)] + + +def nvcc_data_dependencies(commands): + """ + Return a list of the set of dependencies for each command. + """ + # fatbin needs to be treated specially because while the cicc steps + # do refer to .fatbin.c files, they do so through the + # --include_file_name option, since they're generating files that + # refer to .fatbin.c file(s) that will later be created by the + # fatbinary step; so for most files, we make a data dependency from + # the later step to the earlier step, but for .fatbin.c files, the + # data dependency is sort of flipped, because the steps that use the + # files generated by cicc need to wait for the fatbinary step to + # finish first + tmp_files = {} + fatbins = collections.defaultdict(set) + graph = [] + for i, line in enumerate(commands): + deps = set() + for tmp in files_mentioned(line): + if tmp in tmp_files: + dep = tmp_files[tmp] + deps.add(dep) + if dep in fatbins: + for filename in fatbins[dep]: + if filename in tmp_files: + deps.add(tmp_files[filename]) + if tmp.endswith('.fatbin.c') and not line.startswith('fatbinary'): + fatbins[i].add(tmp) + else: + tmp_files[tmp] = i + if line.startswith('rm ') and not deps: + deps.add(i - 1) + graph.append(deps) + return graph + + +def is_weakly_connected(graph): + """ + Return true iff graph is weakly connected. + """ + neighbors = [set() for _ in graph] + for node, predecessors in enumerate(graph): + for pred in predecessors: + neighbors[pred].add(node) + neighbors[node].add(pred) + # assume nonempty graph + stack = [0] + found = {0} + while stack: + node = stack.pop() + for neighbor in neighbors[node]: + if neighbor not in found: + found.add(neighbor) + stack.append(neighbor) + return len(found) == len(graph) + + +def warn_if_not_weakly_connected(graph): + """ + Warn the user if the execution graph is not weakly connected. + """ + if not is_weakly_connected(graph): + fast_nvcc_warn('execution graph is not (weakly) connected') + + +def print_dot_graph(*, commands, graph, filename): + """ + Print a DOT file displaying short versions of the commands in graph. + """ + def name(k): + return f'"{k} {os.path.basename(commands[k][0])}"' + with open(filename, 'w') as f: + print('digraph {', file=f) + # print all nodes, in case it's disconnected + for i in range(len(graph)): + print(f' {name(i)};', file=f) + for i, deps in enumerate(graph): + for j in deps: + print(f' {name(j)} -> {name(i)};', file=f) + print('}', file=f) + + +async def run_command(command, *, env, deps, gather_data, i, save): + """ + Run the command with the given env after waiting for deps. + """ + for task in deps: + await task + if gather_data: + t1 = time.monotonic() + proc = await asyncio.create_subprocess_shell( + command, + env=env, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + code = proc.returncode + results = {'exit_code': code, 'stdout': stdout, 'stderr': stderr} + if gather_data: + t2 = time.monotonic() + results['time'] = t2 - t1 + sizes = {} + for tmp_file in files_mentioned(command): + if os.path.exists(tmp_file): + sizes[tmp_file] = os.path.getsize(tmp_file) + else: + sizes[tmp_file] = 0 + results['files'] = sizes + if save: + dest = pathlib.Path(save) / str(i) + dest.mkdir() + for src in map(pathlib.Path, files_mentioned(command)): + if src.exists(): + shutil.copy2(src, dest / (src.name)) + return results + + +async def run_graph(*, env, commands, graph, gather_data, save): + """ + Return outputs/errors (and optionally time/file info) from commands. + """ + tasks = [] + for i, (command, indices) in enumerate(zip(commands, graph)): + deps = {tasks[j] for j in indices} + tasks.append(asyncio.create_task(run_command( + command, + env=env, + deps=deps, + gather_data=gather_data, + i=i, + save=save, + ))) + return [await task for task in tasks] + + +def print_command_outputs(command_results): + """ + Print captured stdout and stderr from commands. + """ + for result in command_results: + sys.stdout.write(result['stdout'].decode('ascii')) + sys.stderr.write(result['stderr'].decode('ascii')) + + +def write_log_csv(command_parts, command_results, *, filename): + """ + Write a CSV file of the times and /tmp file sizes from each command. + """ + tmp_files = [] + for result in command_results: + tmp_files.extend(result['files'].keys()) + with open(filename, 'w', newline='') as csvfile: + fieldnames = ['command', 'seconds'] + list(dict.fromkeys(tmp_files)) + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + for i, result in enumerate(command_results): + command = f'{i} {os.path.basename(command_parts[i][0])}' + row = {'command': command, 'seconds': result['time']} + writer.writerow({**row, **result['files']}) + + +def exit_code(results): + """ + Aggregate individual exit codes into a single code. + """ + for result in results: + code = result['exit_code'] + if code != 0: + return code + return 0 + + +def fast_nvcc(args, *, config=default_config): + """ + Emulate the result of calling the given nvcc binary with args. + + Should run faster than plain nvcc. + """ + warn_if_windows() + warn_if_tmpdir_flag(args) + dryrun_data = nvcc_dryrun_data(config.nvcc, args) + env = dryrun_data['env'] + warn_if_tmpdir_set(env) + commands = dryrun_data['commands'] + if not config.faithful: + commands = make_rm_force(unique_module_id_files(commands)) + command_parts = list(map(shlex.split, commands)) + if config.verbose: + print_verbose_output( + env=env, + commands=command_parts, + filename=config.verbose, + ) + graph = nvcc_data_dependencies(commands) + warn_if_not_weakly_connected(graph) + if config.graph: + print_dot_graph( + commands=command_parts, + graph=graph, + filename=config.graph, + ) + if config.sequential: + graph = straight_line_dependencies(commands) + results = asyncio.run(run_graph( + env=env, + commands=commands, + graph=graph, + gather_data=bool(config.table), + save=config.save, + )) + print_command_outputs(results) + if config.table: + write_log_csv(command_parts, results, filename=config.table) + return exit_code([dryrun_data] + results) + + +def our_arg(arg): + return arg != '--' + + +if __name__ == '__main__': + argv = sys.argv[1:] + us = list(itertools.takewhile(our_arg, argv)) + them = list(itertools.dropwhile(our_arg, argv)) + sys.exit(fast_nvcc(them[1:], config=parser.parse_args(us))) diff --git a/tools/generate_torch_version.py b/tools/generate_torch_version.py index b5ea62ff29bbd..8129f38eb0efe 100644 --- a/tools/generate_torch_version.py +++ b/tools/generate_torch_version.py @@ -2,6 +2,7 @@ import os import subprocess from pathlib import Path +from distutils.util import strtobool def get_sha(): try: @@ -27,7 +28,7 @@ def get_torch_version(sha=None): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Generate torch/version.py from build and environment metadata.") - parser.add_argument("--is_debug", type=bool, help="Whether this build is debug mode or not.") + parser.add_argument("--is_debug", type=strtobool, help="Whether this build is debug mode or not.") parser.add_argument("--cuda_version", type=str) parser.add_argument("--hip_version", type=str) @@ -47,7 +48,7 @@ def get_torch_version(sha=None): # NB: This is not 100% accurate, because you could have built the # library code with DEBUG, but csrc without DEBUG (in which case # this would claim to be a release build when it's not.) - f.write("debug = {}\n".format(repr(args.is_debug))) + f.write("debug = {}\n".format(repr(bool(args.is_debug)))) f.write("cuda = {}\n".format(repr(args.cuda_version))) f.write("git_version = {}\n".format(repr(sha))) f.write("hip = {}\n".format(repr(args.hip_version))) diff --git a/tools/jit/gen_unboxing_wrappers.py b/tools/jit/gen_unboxing_wrappers.py deleted file mode 100644 index 1af3fda918de3..0000000000000 --- a/tools/jit/gen_unboxing_wrappers.py +++ /dev/null @@ -1,530 +0,0 @@ -""" -To run this file by hand from the root of the PyTorch -repository, run: - -python -m tools.jit.gen_unboxing_wrappers \ - build/aten/src/ATen/Declarations.yaml \ - $OUTPUT_DIR \ - tools/jit/templates - -Where $OUTPUT_DIR is where you would like the files to be -generated. In the full build system, OUTPUT_DIR is -torch/csrc/jit/generated/ -""" - -# This file generates generated_unboxing_wrappers, which contains -# manual unboxing wrappers for ops that aren't use_c10_dispatcher: full -# because the templated unboxing logic in c10 doesn't support them yet. -# The ultimate goal is to make all ops use the templated unboxing and -# delete this codegen file. - -import argparse -import re -from itertools import groupby -from ..autograd.gen_autograd import load_aten_declarations -from ..autograd.gen_autograd import RETURNS_VIEWS_OF_INPUT -from ..autograd.utils import CodeTemplate, write, is_out_variant, op_name_without_overload - -# JIT has a type system of -# Scalar = int | float | bool # int is the largest int (int64_t), -# float is the largest float (double) we don't have the others because they are never held in tensors -# Type = Scalar # primitive numbers -# | Tensor # any tensor, as defined by at::Tensor -# | Type[] # a dynamically sized list[ of a type -# | Scalar[N] # a homogenous fixed size scalar list, single scalars can expand to this list -# | (Type1, Type2, ...) # a heterogeneous tuple -# | Layout | ScalarType | Device | Generator # special singleton types for built-in concepts in tensor lib - -# clean up the variety of C++ types in the ATen declarations -# to be in the restricted set of types that the IR represents -# note: no default values for this map, to make it clear what types -# can be passedthrough - -TYPE_MAP = { - 'std::array': 'bool[2]', - 'std::array': 'bool[3]', - 'std::array': 'bool[4]', - 'std::string': 'str', - 'std::string?': 'str?', - 'Scalar': 'Scalar', - 'MemoryFormat': 'MemoryFormat', - 'MemoryFormat?': 'MemoryFormat?', - 'QScheme': 'QScheme', - 'Scalar?': 'Scalar?', - 'Tensor': 'Tensor', - 'Tensor?': 'Tensor?', - 'TensorList': 'Tensor[]', - # this appears in return values instead of TensorList - # since TensorList is a ArrayRef in arguments but a vector - # in returns - 'std::vector': 'Tensor[]', - 'IntArrayRef': 'int[]', - 'IntArrayRef?': 'int[]?', - 'ArrayRef?': 'float[]?', - 'Layout': 'Layout', - 'Layout?': 'Layout?', - 'Device': 'Device', - 'Device?': 'Device?', - 'ScalarType': 'ScalarType', - 'ScalarType?': 'ScalarType?', - 'int64_t': 'int', - 'int64_t?': 'int?', - 'double': 'float', - 'double?': 'float?', - 'bool': 'bool', - 'bool?': 'bool?', - 'Generator': 'Generator?', - 'Generator?': 'Generator?', -} - - -def optional_type_of(arg, typ): - # optional type special handling for Tensor?[] and Tensor - # types that is missing a optional annotation - if arg.get('is_nullable') and '?' not in typ: - if typ == 'TensorList' or typ == 'Tensor[]': - typ = 'Tensor?[]' - else: - typ = '{}?'.format(typ) - return typ - - -def annotated_type_of(arg, typ): - anno = arg.get('annotation') - if anno: - typ = '{}({})'.format(typ, anno) - return typ - - -def jit_type_of(arg): - jit_type = arg.get('jit_type') - if not jit_type: - jit_type = TYPE_MAP[arg['simple_type']] - if is_sized_intlist_arg(arg): - jit_type = 'int[{}]'.format(arg['size']) - jit_type = optional_type_of(arg, jit_type) - jit_type = annotated_type_of(arg, jit_type) - arg['jit_type'] = jit_type - return jit_type - - -# map from aten 'simple_type' to the function that will turn a tensor into -# that type -FROM_IVALUE = { - 'Device': '{}.toDevice()', - 'Device?': '{}.toOptional()', - 'IntArrayRef': '{}.toIntVector()', - 'IntArrayRef?': '{}.toOptionalIntArray()', - 'ArrayRef?': '{}.toOptionalDoubleArray()', - 'Layout': '{}.toLayout()', - 'Layout?': '{}.toOptional()', - 'MemoryFormat': '{}.toMemoryFormat()', - 'MemoryFormat?': '{}.toOptional()', - 'QScheme': '{}.toQScheme()', - 'Scalar': '{}.toScalar()', - 'Scalar?': '{}.toOptional()', - 'ScalarType': '{}.toScalarType()', - 'ScalarType?': '{}.toOptional()', - 'Tensor': '{}.toTensor()', - 'Tensor?': 'toOptionalTensor({})', - 'Tensor?[]': 'toListOfOptionalTensor({})', - 'TensorList': '{}.toTensorVector()', - 'bool': '{}.toBool()', - 'bool?': '{}.toOptional()', - 'double': '{}.toDouble()', - 'double?': '{}.toOptional()', - 'int64_t': '{}.toInt()', - 'int64_t?': '{}.toOptional()', - 'std::string': '{}.toStringRef()', - 'std::string?': '{}.toOptional()', - 'Generator?': '{}.toOptional()', - 'std::array': 'as_bool_array<2>({}.toBoolList())', - 'std::array': 'as_bool_array<3>({}.toBoolList())', - 'std::array': 'as_bool_array<4>({}.toBoolList())', -} - - -def from_ivalue(arg, value): - typ = optional_type_of(arg, arg['simple_type']) - return FROM_IVALUE[typ].format(value) - - -CALL_UNBOXED_KERNEL = CodeTemplate("""\ -auto result_ = callUnboxedKernel<${return_type}${formals_types_with_leading_comma}>(unboxedKernel${args_with_leading_comma}); -""") -CALL_NAMESPACE = CodeTemplate("""\ -auto result_ = at::${name}( - ${args} -); -""") -CALL_METHOD = CodeTemplate("""\ -auto result_ = (${first}).${name}( - ${args} -); -""") -CALL_NAMESPACE_WITH_TENSOR_OPTIONS = CodeTemplate("""\ -const auto options = TensorOptions() - .dtype(${dtype}) - .layout(${layout}) - .device(${device}) - .pinned_memory(${pin_memory}); - auto result_ = torch::${name}(${args_with_tensor_options}); -""") -CALL_METHOD_WITH_TENSOR_OPTIONS = CodeTemplate("""\ -const auto options = TensorOptions() - .dtype(${dtype}) - .layout(${layout}) - .device(${device}) - .pinned_memory(${pin_memory}); -auto result_ = (${first}).${name}(${args_with_tensor_options}); -""") - -CONSTRUCTOR = CodeTemplate("""\ -[](OperatorKernel* unboxedKernel, const OperatorHandle&, Stack* stack) { - using namespace at; - ${lvalues} - ${call} - drop(*stack, ${num_inputs}); - pack(*stack, std::move(result_)); -} -""") - -OPERATOR = CodeTemplate("""\ - .op("${signature}", - ${op}) -""") - - -disallowed_types = { - 'Storage', - 'DimnameList?', - 'ConstQuantizerPtr', - 'Dimname', - 'DimnameList', -} - -default_only_types = {'Generator'} - - -def is_jit_arg(i, arg): - simple_type = arg['simple_type'] - if simple_type in disallowed_types: - return False - if simple_type in default_only_types and 'default' not in arg: - return False - if simple_type == 'Type': - return False - return True - - -def is_jit_op(decl): - # We currently don't support functions that return nothing - assert all(r['type'] != 'void' for r in decl['returns']) - if len(decl['returns']) == 0: - return False - - arguments = decl['arguments'] - - # there must be a single out variant - if is_out_variant(decl) and sum([not not arg.get('output') for arg in arguments]) > 1: - return False - - return (('namespace' in decl['method_of'] or 'Tensor' in decl['method_of']) and - all(is_jit_arg(i, arg) for i, arg in enumerate(decl['arguments'])) and - all(is_jit_arg(i, arg) for i, arg in enumerate(decl['returns']))) - - -def is_tensor_arg(arg): - return arg['simple_type'] in {'Tensor', 'TensorList'} - - -def is_sized_intlist_arg(arg): - """Returns True for arguments declared as IntArrayRef[k], but False for IntArrayRef.""" - return (arg['simple_type'] == 'IntArrayRef') and ('size' in arg) - - -def base_name(decl): - name = decl['name'] - return name[:-1] if decl.get('inplace', False) else name[:-4] if name.endswith('_out') else name - - -def is_view(decl): - return base_name(decl) in RETURNS_VIEWS_OF_INPUT - - -# Copied from ..autograd.gen_python_functions.SKIP_PYTHON_BINDINGS -BACKWARD_OP_PATTERNS = [ - '.*_backward', - '.*_backward_(out|input|weight|bias)', -] - -def is_backward_op(decl): - for pattern in BACKWARD_OP_PATTERNS: - if re.match('^' + pattern + '$', decl['name']): - return True - return False - - -# for each argument in decl, the location it should appear in the -# jit schema declaration. e.g. -# arguments = [x, y, z] # the order in aten -# jit_argument_order = [2, 0, 1] -# aten::my_arg(Tensor y, Tensor z, Tensor x) # the order in schema -# used to move 'out' arguments to the end of the list -def argument_order(decl): - return decl.get('jit_argument_order') or list(range(len(decl['arguments']))) - - -def gen_unboxing_wrappers( - declarations, - out, - template_path, - disable_autograd=False, - selected_op_list=None, - force_schema_registration=False, -): - GENERATED_UNBOXING_WRAPPERS_CPP = CodeTemplate.from_file(template_path + '/generated_unboxing_wrappers.cpp') - - ops = [] - - def get_invocation(decl, args, num_inputs): - - # because the arg list can get lengthy we put them on a separate line - def pack_arguments(args): - return ',\n'.join(args) - is_namespace_function = 'namespace' in decl['method_of'] - tensor_options_arg_index = decl.get('tensor_options_arg_index', None) - if tensor_options_arg_index is not None: - dtype = args[tensor_options_arg_index] - layout = args[tensor_options_arg_index + 1] - device = args[tensor_options_arg_index + 2] - pin_memory = args[tensor_options_arg_index + 3] - args_with_tensor_options = args[:tensor_options_arg_index] + \ - ['options'] + args[(tensor_options_arg_index + 4):] - if is_namespace_function: - return CALL_NAMESPACE_WITH_TENSOR_OPTIONS.substitute( - name=decl['name'], dtype=dtype, layout=layout, - device=device, pin_memory=pin_memory, - args_with_tensor_options=pack_arguments(args_with_tensor_options)) - else: - return CALL_METHOD_WITH_TENSOR_OPTIONS.substitute( - name=decl['name'], dtype=dtype, layout=layout, - device=device, pin_memory=pin_memory, - args_with_tensor_options=pack_arguments(args_with_tensor_options[1:]), - first=args_with_tensor_options[0], num_inputs=num_inputs) - elif decl['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper': - if len(decl['returns']) == 0: - return_type = "void" - elif len(decl['returns']) == 1: - return_type = decl['returns'][0]['type'] - else: - return_type = "std::tuple<{}>".format(", ".join([r['type'] for r in decl['returns']])) - for a in decl['arguments']: - if 'type' not in a: - raise Exception(decl) - argument_types_with_leading_comma = ", ".join([a['type'] for a in decl['arguments']]) - if argument_types_with_leading_comma != "": - argument_types_with_leading_comma = ", " + argument_types_with_leading_comma - args_with_leading_comma = pack_arguments(args) - if args_with_leading_comma != "": - args_with_leading_comma = ", " + args_with_leading_comma - return CALL_UNBOXED_KERNEL.substitute(name=decl['name'], - args_with_leading_comma=args_with_leading_comma, - num_inputs=num_inputs, - return_type=return_type, - formals_types_with_leading_comma=argument_types_with_leading_comma) - else: - assert decl['use_c10_dispatcher'] == 'full' - if is_namespace_function: - return CALL_NAMESPACE.substitute(name=decl['name'], - args=pack_arguments(args), - num_inputs=num_inputs) - else: - return CALL_METHOD.substitute( - name=decl['name'], first=args[0], - args=pack_arguments(args[1:]), num_inputs=num_inputs) - - def requires_lvalue(arg): - jit_type = jit_type_of(arg) - return jit_type.startswith('Tensor') and '!' in jit_type - - def emit_decl_variant(decl): - if ('emit_dummy_placeholder' in decl): - return "DUMMY_OPERATION" - kw_assignments = [] - - # mutable arguments in aten are passed as non const references - # these must be lvalues, so we have to put them in variables - # before calling the function - lvalues = [] - - arguments = [] - num_inputs = len(decl['arguments']) - op_capture = '' - order = argument_order(decl) - for i, arg in enumerate(decl['arguments']): - value = from_ivalue(arg, '(std::move(peek(*stack, {}, {})))'.format(order[i], num_inputs)) - if requires_lvalue(arg): - lvalues.append('auto {} = {};\n'.format(arg['name'], value)) - value = arg['name'] - arguments.append(value) - - call = get_invocation(decl, arguments, num_inputs) - - returns = decl['returns'] - - if decl['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper': - constructor = CONSTRUCTOR.substitute(name=decl['name'], - call=call, - kw_assignments=kw_assignments, - num_inputs=num_inputs, - op_capture=op_capture, - lvalues=lvalues) - else: - assert decl['use_c10_dispatcher'] == 'full' - - return constructor - - def filter_decls(jit_decls, disable_autograd, selected_op_list, force_schema_registration): - result = [] - for decl in jit_decls: - if disable_autograd and is_backward_op(decl): - continue - op_name = op_name_without_overload(decl) - if selected_op_list is not None and op_name not in selected_op_list: - if force_schema_registration: - decl['emit_dummy_placeholder'] = True - else: - continue - result.append(decl) - return result - - # This function declares an order on declarations. This is necessary because - # there is some ambiguity in the choice of overload: if an argument is overloaded - # to accept both Scalar and Tensor, the schema with the Tensor should come first - # TODO: this can (probably) be removed when we remove the implicit conversion - # from Tensor -> Number. - def sort_decls(jit_decls): - def declkey(decl): - # key = sum_{i < len(args)} {1 if arg is tensor else 2} * (3 ** i) - # This is a ternary encoding where - # 0: No argument at this position - # 1: Tensor argument at this position - # 2: Some other argument at this position. - args = decl['arguments'] - result = 0 - for i in range(len(args)): - result += (3 ** i) * (1 if args[i]['simple_type'] == 'Tensor' else 2) - return result - - # NB: itertools.groupby requires the list be sorted. - sorted_decls = sorted(jit_decls, key=lambda decl: decl['name']) - grouped_decls = [list(g) for _, g in - groupby(sorted_decls, key=lambda decl: decl['name'])] - return [sorted(g, key=declkey) for g in grouped_decls] - - aten_decls = load_aten_declarations(declarations) - jit_decls = [d for d in aten_decls if is_jit_op(d)] - - # add arguments dtype and device for functions like zeros - def expand_options(decl, i, arg): - if arg['simple_type'] != 'TensorOptions': - return [arg] - assert decl.get('tensor_options_arg_index') != i - decl['tensor_options_arg_index'] = i - tensor_options_expansion = [ - # XXX - until we actually have first-class interpreter types for these - # concepts, the default values to be encoded in Tensors - # If you change this, you also need to update [TensorOptions in script] - # in the tracer code. - # dtype is specified as an int64_t of at::ScalarType - {'name': 'dtype', 'simple_type': 'ScalarType'}, - # layout is specified as an int64_t of at::Layout - {'name': 'layout', 'simple_type': 'Layout'}, - # device is specified as an IntArrayRef of { at::Device::Type, device_id } - {'name': 'device', 'simple_type': 'Device'}, - # pin_memory is specified as a boolean - {'name': 'pin_memory', 'simple_type': 'bool', 'default': False}, - ] - # TODO: Don't repack this into TensorOptions. Needs various changes in downstream code. - if 'default' in arg: - for el in tensor_options_expansion: - el['simple_type'] += '?' - el['default'] = 'None' - if 'default' in arg and arg['default'] == 'at::kLong': - tensor_options_expansion[0]['default'] = 'long' - if 'kwarg_only' in arg and arg['kwarg_only']: - for el in tensor_options_expansion: - el['kwarg_only'] = True - return tensor_options_expansion - - additional_jit_decls = [] - - for decl in jit_decls: - decl['arguments'] = [a for i, arg in enumerate(decl['arguments']) for a in expand_options(decl, i, arg)] - if is_out_variant(decl): - reorder_out_args(decl) - - jit_decls.extend(additional_jit_decls) - jit_decls = filter_decls(jit_decls, disable_autograd, selected_op_list, force_schema_registration) - - # generation is deterministic - jit_decl_groups = sort_decls(jit_decls) - - # NOTE: see Note [Sharded File] at the top of the generated_unboxing_wrappers.cpp - # template regarding sharding of the generated files. - # - # If you edit the number of shards here, you will also have to - # modify generate_code.py, torch/CMakeLists.txt, and the TARGETS - # files. - num_shards = 3 - shards = [[] for _ in range(num_shards)] - - # ops are assigned arbitrarily but stably to a file based on hash - for group in jit_decl_groups: - x = sum(ord(c) for c in group[0]['name']) % num_shards - for decl in group: - if decl['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper': - shards[x].append(OPERATOR.substitute(signature=decl['schema_string'], - op=emit_decl_variant(decl))) - else: - assert decl['use_c10_dispatcher'] == 'full' - - for i, shard in enumerate(shards): - env = { - 'constructors': shard, - } - write(out, 'generated_unboxing_wrappers_%d.cpp' % i, GENERATED_UNBOXING_WRAPPERS_CPP, env) - - -default_map = {'{}': 'None', 'nullptr': 'None', 'c10::nullopt': 'None'} - - -def reorder_out_args(decl): - first_arg = decl['arguments'][0] - assert(first_arg['output']) - # the output variant must go at the end - # note: this is an annoying side effect of using a single '*' - # to denote kwarg_only - nargs = len(decl['arguments']) - decl['jit_argument_order'] = [nargs - 1] + list(range(nargs - 1)) - - -def is_kwarg_only(a): - return a.get('kwarg_only') or a.get('output') - -def main(): - parser = argparse.ArgumentParser( - description='Generate JIT op dispatch') - parser.add_argument('declarations', metavar='DECL', - help='path to Declarations.yaml') - parser.add_argument('out', metavar='OUT', - help='path to output directory') - parser.add_argument('template_path', metavar='TEMPLATE_PATH', - help='path to templates directory') - args = parser.parse_args() - gen_unboxing_wrappers(args.declarations, args.out, args.template_path) - - -if __name__ == '__main__': - main() diff --git a/tools/jit/templates/generated_unboxing_wrappers.cpp b/tools/jit/templates/generated_unboxing_wrappers.cpp deleted file mode 100644 index cd8d12f6b15e8..0000000000000 --- a/tools/jit/templates/generated_unboxing_wrappers.cpp +++ /dev/null @@ -1,132 +0,0 @@ -#include "torch/csrc/jit/runtime/operator.h" -#include "torch/csrc/jit/runtime/custom_operator.h" -#include "torch/csrc/jit/frontend/function_schema_parser.h" - -#include "torch/csrc/autograd/profiler.h" -#include "torch/csrc/autograd/generated/variable_factories.h" - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -// ${generated_comment} - -// This file contains manual unboxing wrappers for ops that aren't -// use_c10_dispatcher: full because the templated unboxing logic in c10 doesn't -// support them yet. The ultimate goal is to make all ops use the templated -// unboxing and delete this codegen file. - -// NOTE [Sharded File]: This file is generated in a sharded fashion to speed up -// incremental rebuilds. See the comment at the top of -// templates/VariableType.cpp for an analogous, in-depth discussion. - -namespace torch { namespace jit { - -using autograd::Variable; -using autograd::variable_list; -using at::Scalar; -using at::ScalarType; -using at::Tensor; -using at::TensorOptions; -using at::DeviceGuard; -using at::MemoryFormat; - -using ::c10::fmap; -using ::c10::filter; -using c10::OperatorKernel; -using c10::OperatorHandle; -using c10::KernelFunction; -using c10::RegistrationHandleRAII; -using c10::Stack; - -namespace { - -template -Return callUnboxedKernel(OperatorKernel* unboxedKernel, Args... args) { - using FuncType = Return (Args...); - auto* typedUnboxedKernel = static_cast*>(unboxedKernel); - return (*typedUnboxedKernel)(std::forward(args)...); -} - -// TODO: remove the toOptionalTensor and toListOfOptionalTensor -// when we remove the undefined tensor semantic from TH - -// XXX: This function is to specialize IValue for tensor type in -// interpreter, it should only be used in this file -at::Tensor toOptionalTensor(const IValue& v) { - if (v.isNone()) { - return at::Tensor(); - } - return v.toTensor(); -} - -// XXX: This function is to specialize IValue for list of optional -// tensor type in interpreter, it should only be used in this file -std::vector toListOfOptionalTensor(const IValue& v) { - // v is a list of optional tensor, loop over as generic list - auto vlist = v.toListRef(); - std::vector res; - - for (const IValue &v: vlist) { - res.emplace_back(toOptionalTensor(v)); - } - return res; -} - -template -std::array as_bool_array(const c10::List& list) { - std::array res; - AT_ASSERT(list.size() == N); - std::copy(list.begin(), list.end(), res.begin()); - return res; -} - -KernelFunction::InternalBoxedKernelFunction *DUMMY_OPERATION = - [](c10::OperatorKernel *, const c10::OperatorHandle &, std::vector *) -> void { - TORCH_CHECK(false, "Operator has been stripped in the custom build.") - }; - -class Registerer final { -public: - Registerer&& op(const std::string& schemaStr, KernelFunction::InternalBoxedKernelFunction* boxed_kernel_wrapper) && { - static auto& dispatcher = c10::Dispatcher::singleton(); - auto schema = parseSchema(schemaStr); - schema.setAliasAnalysis(AliasAnalysisKind::FROM_SCHEMA); - c10::OperatorName name = schema.operator_name(); - RegistrationHandleRAII registration = dispatcher.registerName(name); - auto op = dispatcher.findOp(name).value(); - registrationHandles_.push_back(std::move(registration)); - dispatcher.setManuallyBoxedKernelFor_(op, boxed_kernel_wrapper); - return std::move(*this); - } - - Registerer() = default; - Registerer(const Registerer&) = delete; - Registerer& operator=(const Registerer&) = delete; - Registerer(Registerer&&) noexcept = default; - Registerer& operator=(Registerer&&) noexcept = default; -private: - std::vector registrationHandles_; -}; - -static auto registry = Registerer() - // Generated operators - ${constructors} - ; - -} // anon namespace - - -}} // namespace torch::jit diff --git a/tools/nightly.py b/tools/nightly.py index a6dc185628b71..55a90e3fd9fb0 100755 --- a/tools/nightly.py +++ b/tools/nightly.py @@ -135,7 +135,7 @@ def logging_rotate() -> None: @contextlib.contextmanager def logging_manager(*, debug: bool = False) -> Iterator[None]: """Setup logging. If a failure starts here we won't - be able to save the user ina reasonable way. + be able to save the user in a reasonable way. Logging structure: there is one logger (the root logger) and in processes all events. There are two handlers: @@ -210,7 +210,7 @@ def check_branch(subcommand, branch): @contextlib.contextmanager def timer(logger, prefix): - """Timed conetxt manager""" + """Timed context manager""" start_time = time.time() yield logger.info(f"{prefix} took {time.time() - start_time:.3f} [s]") @@ -322,10 +322,10 @@ def pytorch_install(url): def _site_packages(dirname, platform): if platform.startswith("win"): - os.path.join(pytdir.name, "Lib", "site-packages") + template = os.path.join(dirname, "Lib", "site-packages") else: template = os.path.join(dirname, "lib", "python*.*", "site-packages") - spdir = glob.glob(template)[0] + spdir = glob.glob(template)[0] return spdir diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 7079c67502238..7a37ab134dcb6 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -1,19 +1,14 @@ - -import os import collections from pprint import pformat -import yaml -import re import argparse -from ..autograd.utils import YamlLoader, CodeTemplate, write -from ..autograd.gen_python_functions import ( - get_py_torch_functions, - get_py_variable_methods, - namedtuple_fieldnames, -) -from ..autograd.gen_autograd import load_aten_declarations +from tools.codegen.model import * +from tools.codegen.api.python import * +from tools.codegen.gen import FileManager +from typing import Sequence, List, Dict + +from ..autograd.gen_python_functions import should_generate_py_binding, load_signatures, group_overloads """ This module implements generation of type stubs for PyTorch, @@ -32,18 +27,38 @@ (the latter case should be pretty rare). - We go through automatically bound functions based on the - type information recorded in Declarations.yaml and + type information recorded in native_functions.yaml and generate type hints for them (generate_type_hints) There are a number of type hints which we've special-cased; read gen_pyi for the gory details. """ +def get_py_torch_functions( + python_funcs: Sequence[PythonSignatureNativeFunctionPair], + method: bool = False, +) -> Sequence[PythonSignatureGroup]: + """ + Get declarations (grouped by name) which should be generated + as either functions in the "torch" module or methods on Tensor. + """ + def should_bind_function(python_func: PythonSignatureNativeFunctionPair) -> bool: + return (should_generate_py_binding(python_func.function) and + not python_func.function.python_module and + Variant.function in python_func.function.variants) + + def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool: + return (should_generate_py_binding(python_func.function) and + not python_func.function.python_module and + Variant.method in python_func.function.variants) + + should_bind = should_bind_method if method else should_bind_function + return group_overloads([f for f in python_funcs if should_bind(f)]) + + # TODO: Consider defining some aliases for our Union[...] types, to make # the stubs to read on the human eye. -needed_modules = set() - DEVICE_PARAM = "device: Union[_device, str, None]=None" FACTORY_PARAMS = f"dtype: Optional[_dtype]=None, {DEVICE_PARAM}, requires_grad: _bool=False" @@ -105,89 +120,6 @@ 'floor_divide', 'floor_divide_', 'floor_divide_out', ] - -def type_to_python(typename, size=None): - """type_to_python(typename: str, size: str) -> str - - Transforms a Declarations.yaml type name into a Python type specification - as used for type hints. - """ - typename = typename.replace(' ', '') # normalize spaces, e.g., 'Generator *' - - # Disambiguate explicitly sized int/tensor lists from implicitly - # sized ones. These permit non-list inputs too. (IntArrayRef[] and - # TensorList[] are not real types; this is just for convenience.) - if typename in {'IntArrayRef', 'TensorList'} and size is not None: - typename += '[]' - - typename = { - 'Device': 'Device', - 'Generator': 'Generator', - 'IntegerTensor': 'Tensor', - 'Scalar': 'Number', - 'ScalarType': '_dtype', - 'Storage': 'Storage', - 'BoolTensor': 'Tensor', - 'IndexTensor': 'Tensor', - 'Tensor': 'Tensor', - 'MemoryFormat': 'memory_format', - 'IntArrayRef': '_size', - 'IntArrayRef[]': 'Union[_int, _size]', - 'TensorList': 'Union[Tuple[Tensor, ...], List[Tensor]]', - 'TensorList[]': 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]', - 'bool': '_bool', - 'double': '_float', - 'int64_t': '_int', - 'accreal': 'Number', - 'real': 'Number', - 'void*': '_int', # data_ptr - 'void': 'None', - 'std::string': 'str', - 'Dimname': 'Union[str, ellipsis, None]', - 'DimnameList': 'Sequence[Union[str, ellipsis, None]]', - 'QScheme': '_qscheme', - }[typename] - - return typename - - -def arg_to_type_hint(arg): - """arg_to_type_hint(arg) -> str - - This takes one argument in a Declarations and returns a string - representing this argument in a type hint signature. - """ - name = arg['name'] - if name == 'from': # from is a Python keyword... - name += '_' - typename = type_to_python(arg['dynamic_type'], arg.get('size')) - if arg.get('is_nullable'): - typename = 'Optional[' + typename + ']' - if 'default' in arg: - default = arg['default'] - if default == 'nullptr': - default = None - elif default == 'c10::nullopt': - default = None - elif isinstance(default, str) and default.startswith('{') and default.endswith('}'): - if arg['dynamic_type'] == 'Tensor' and default == '{}': - default = None - elif arg['dynamic_type'] == 'Generator' and default == '{}': - default = None - elif arg['dynamic_type'] == 'IntArrayRef': - default = '(' + default[1:-1] + ')' - else: - raise Exception("Unexpected default constructor argument of type {}".format(arg['dynamic_type'])) - elif default == 'MemoryFormat::Contiguous': - default = 'contiguous_format' - elif default == 'QScheme::PER_TENSOR_AFFINE': - default = 'per_tensor_affine' - default = '={}'.format(default) - else: - default = '' - return name + ': ' + typename + default - - binary_ops = ('add', 'sub', 'mul', 'div', 'pow', 'lshift', 'rshift', 'mod', 'truediv', 'matmul', 'floordiv', 'radd', 'rsub', 'rmul', 'rtruediv', 'rfloordiv', 'rpow', # reverse arithmetic @@ -201,7 +133,7 @@ def arg_to_type_hint(arg): all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops -def sig_for_ops(opname): +def sig_for_ops(opname: str) -> List[str]: """sig_for_ops(opname : str) -> List[str] Returns signatures for operator special functions (__add__ etc.)""" @@ -231,121 +163,35 @@ def sig_for_ops(opname): else: raise Exception("unknown op", opname) +def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]: + type_hints: List[str] = [] -def generate_type_hints(fname, decls, namedtuples, is_tensor=False): - """generate_type_hints(fname, decls, is_tensor=False) - - Generates type hints for the declarations pertaining to the function - :attr:`fname`. attr:`decls` are the declarations from the parsed - Declarations.yaml. - :attr:`namedtuples` is a dictionary for accumulating NamedTuple definitions. - The :attr:`is_tensor` flag indicates whether we are parsing - members of the Tensor class (true) or functions in the - `torch` namespace (default, false). - - This function currently encodes quite a bit about the semantics of - the translation C++ -> Python. - """ - if fname in blocklist: - return [] - - type_hints = [] - dnames = ([d['name'] for d in decls]) - has_out = fname + '_out' in dnames - - if has_out: - decls = [d for d in decls if d['name'] != fname + '_out'] - - for decl in decls: - render_kw_only_separator = True # whether we add a '*' if we see a keyword only argument - python_args = [] - - has_tensor_options = 'TensorOptions' in (a['dynamic_type'] for a in decl['arguments']) - - for a in decl['arguments']: - if a['dynamic_type'] != 'TensorOptions': - if a.get('kwarg_only', False) and render_kw_only_separator: - python_args.append('*') - render_kw_only_separator = False - try: - python_args.append(arg_to_type_hint(a)) - except Exception: - print("Error while processing function {}".format(fname)) - raise - - if 'self: Tensor' in python_args: - self_index = python_args.index('self: Tensor') - python_args.remove('self: Tensor') - if is_tensor: - python_args = ['self'] + python_args - else: - python_args.insert(self_index, 'input: Tensor') - else: - if is_tensor: - raise Exception("method without self is unexpected") - - if has_out: - if render_kw_only_separator: - python_args.append('*') - render_kw_only_separator = False - python_args.append('out: Optional[Tensor]=None') - - if has_tensor_options: - if render_kw_only_separator: - python_args.append('*') - render_kw_only_separator = False - python_args += ["dtype: _dtype=None", - "layout: _layout=strided", - "device: Union[_device, str, None]=None", - "requires_grad:_bool=False"] - - python_args_s = ', '.join(python_args) - python_returns = [type_to_python(r['dynamic_type']) for r in decl['returns']] - field_names = namedtuple_fieldnames(decl) - - if field_names: - namedtuple_name = '_'.join(['namedtuple'] + field_names) - tuple_args = ['("{}", {})'.format(name, typ) for name, typ in zip(field_names, python_returns)] - namedtuple_def = 'NamedTuple("{}", [{}])'.format(namedtuple_name, ', '.join(tuple_args)) - if namedtuple_name in namedtuples: - assert namedtuples[namedtuple_name] == namedtuple_def - else: - namedtuples[namedtuple_name] = namedtuple_def - python_returns_s = namedtuple_name - elif len(python_returns) > 1: - python_returns_s = 'Tuple[' + ', '.join(python_returns) + ']' - elif len(python_returns) == 1: - python_returns_s = python_returns[0] - else: - python_returns_s = 'None' - - type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s) - numargs = len(decl['arguments']) - vararg_pos = int(is_tensor) - have_vararg_version = (numargs > vararg_pos and - decl['arguments'][vararg_pos]['dynamic_type'] in {'IntArrayRef'} and - (numargs == vararg_pos + 1 or python_args[vararg_pos + 1] == '*') and - (not is_tensor or decl['arguments'][0]['name'] == 'self')) + # Some deprecated ops that are on the blocklist are still included in pyi + if sig_group.signature.name in blocklist and not sig_group.signature.deprecated: + return type_hints + # deprecated signatures have separate entries for their functional and out variants + # (as opposed to the native ops, which fuse the two into a single signature). + # generate the functional variant here, if an out variant exists. + if sig_group.signature.deprecated and sig_group.outplace is not None: + type_hint = sig_group.signature.signature_str_pyi(skip_outputs=True) type_hints.append(type_hint) - if have_vararg_version: - # Two things come into play here: PyTorch has the "magic" that if the first and only positional argument - # is an IntArrayRef, it will be used as a vararg variant. - # The following outputs the vararg variant, the "pass a list variant" is output above. - # The other thing is that in Python, the varargs are annotated with the element type, not the list type. - typelist = decl['arguments'][vararg_pos]['dynamic_type'] - vararg_type = '_int' - # replace first argument and eliminate '*' if present - python_args = ((['self'] if is_tensor else []) + ['*' + decl['arguments'][vararg_pos]['name'] + - ': ' + vararg_type] + python_args[vararg_pos + 2:]) - python_args_s = ', '.join(python_args) - type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s) - type_hints.append(type_hint) + # PythonSignatureGroups that have both a functional + out variant get a single signature, with an optional out argument + # Generates the out variant if one exists. Otherwise, generate the functional variant + type_hint = sig_group.signature.signature_str_pyi( + skip_outputs=sig_group.outplace is None) + type_hints.append(type_hint) + + # Some operators also additionally have a vararg variant of their signature + type_hint_vararg = sig_group.signature.signature_str_pyi_vararg( + skip_outputs=sig_group.outplace is None) + if type_hint_vararg: + type_hints.append(type_hint_vararg) return type_hints -def gen_nn_functional(out): +def gen_nn_functional(fm: FileManager) -> None: # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered # through an `_add_docstr` call imports = [ @@ -362,6 +208,7 @@ def gen_nn_functional(out): 'celu_', 'rrelu_', 'pixel_shuffle', + 'pixel_unshuffle', 'channel_shuffle', 'pdist', 'cosine_similarity', @@ -392,20 +239,22 @@ def gen_nn_functional(out): import_code = ["from .. import {0} as {0}".format(_) for _ in imports] # TODO make these types more precise dispatch_code = ["{}: Callable".format(_) for _ in (dispatches + from_c)] - stubs = CodeTemplate.from_file(os.path.join('torch', 'nn', 'functional.pyi.in')) - env = { + fm.write_with_template('torch/nn/functional.pyi', 'torch/nn/functional.pyi.in', lambda: { 'imported_hints': import_code, - 'dispatched_hints': dispatch_code - } - write(out, 'torch/nn/functional.pyi', stubs, env) + 'dispatched_hints': dispatch_code, + }) - stubs = CodeTemplate.from_file(os.path.join('torch', '_C', '_nn.pyi.in')) - write(out, 'torch/_C/_nn.pyi', stubs, env) + # functional.pyi already contains the definitions for those functions + # so, we don't export then to it + from_c.extend(['hardtanh', 'leaky_relu', 'hardsigmoid']) + dispatch_code = ["{}: Callable".format(_) for _ in (dispatches + from_c)] + fm.write_with_template('torch/_C/_nn.pyi', 'torch/_C/_nn.pyi.in', lambda: { + 'imported_hints': import_code, + 'dispatched_hints': dispatch_code, + }) -def gen_nn_pyi(out): - gen_nn_functional(out) -def gen_pyi(declarations_path, out): +def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -> None: """gen_pyi() This function generates a pyi file for torch. @@ -418,16 +267,13 @@ def gen_pyi(declarations_path, out): # checking. If you are update this, consider if your change # also needs to update the other file. - # Load information from YAML - declarations = load_aten_declarations(declarations_path) - # Dictionary for NamedTuple definitions - namedtuples = {} + namedtuples: Dict[str, str] = {} # Generate type signatures for top-level functions # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - unsorted_function_hints = collections.defaultdict(list) + unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list) unsorted_function_hints.update({ 'set_flush_denormal': ['def set_flush_denormal(mode: _bool) -> _bool: ...'], 'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'], @@ -448,6 +294,9 @@ def gen_pyi(declarations_path, out): 'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],' ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,' ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'], + '_sparse_coo_tensor_unsafe': ['def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],' + ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,' + ' requires_grad: bool = False) -> Tensor: ...'], 'range': ['def range(start: Number, end: Number,' ' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...' .format(FACTORY_PARAMS)], @@ -465,10 +314,12 @@ def gen_pyi(declarations_path, out): ' generator: Optional[Generator]=None, {}) -> Tensor: ...' .format(FACTORY_PARAMS)], 'full': ['def full(size: _size, fill_value: Number, *,' - ' out: Optional[Tensor]=None, {}) -> Tensor: ...' + ' out: Optional[Tensor]=None,' + ' layout: _layout=strided, {}) -> Tensor: ...' .format(FACTORY_PARAMS), 'def full(size: _size, fill_value: Number, *,' - ' names: List[Union[str, None]], {}) -> Tensor: ...' + ' names: List[Union[str, None]],' + ' layout: _layout=strided, {}) -> Tensor: ...' .format(FACTORY_PARAMS)], 'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...'], 'nonzero': ['def nonzero(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...', @@ -485,21 +336,20 @@ def gen_pyi(declarations_path, out): ' other: Union[Tensor, Number],' ' *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop)) - function_declarations = get_py_torch_functions(declarations) - for name in sorted(function_declarations.keys()): - unsorted_function_hints[name] += generate_type_hints(name, function_declarations[name], namedtuples) - - # Generate type signatures for deprecated functions - - # TODO: Maybe we shouldn't generate type hints for deprecated - # functions :) However, examples like those addcdiv rely on these. - with open('tools/autograd/deprecated.yaml', 'r') as f: - deprecated = yaml.load(f, Loader=YamlLoader) - for d in deprecated: - name, sig = re.match(r"^([^\(]+)\(([^\)]*)", d['name']).groups() - sig = ['*' if p.strip() == '*' else p.split() for p in sig.split(',')] - sig = ['*' if p == '*' else (p[1] + ': ' + type_to_python(p[0])) for p in sig] - unsorted_function_hints[name].append("def {}({}) -> Tensor: ...".format(name, ', '.join(sig))) + function_signatures = load_signatures(native_yaml_path, deprecated_yaml_path, method=False, pyi=True) + sig_groups = get_py_torch_functions(function_signatures) + for group in sorted(sig_groups, key=lambda g: g.signature.name): + name = group.signature.name + unsorted_function_hints[name] += generate_type_hints(group) + + named_tuple = group.signature.returns.named_tuple_pyi() + if named_tuple is not None and not group.signature.deprecated: + # deprecated namedtuples are currently not included for torch functions + tuple_name, tuple_def = named_tuple + if tuple_name in namedtuples: + assert namedtuples[tuple_name] == tuple_def + else: + namedtuples[tuple_name] = tuple_def function_hints = [] for name, hints in sorted(unsorted_function_hints.items()): @@ -510,26 +360,26 @@ def gen_pyi(declarations_path, out): # Generate type signatures for Tensor methods # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - unsorted_tensor_method_hints = collections.defaultdict(list) + unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list) unsorted_tensor_method_hints.update({ 'size': ['def size(self) -> Size: ...', 'def size(self, _int) -> _int: ...'], 'stride': ['def stride(self) -> Tuple[_int]: ...', 'def stride(self, _int) -> _int: ...'], - 'new_ones': ['def new_ones(self, size: {}, {}) -> Tensor: ...'. - format(type_to_python('IntArrayRef'), FACTORY_PARAMS)], + 'new_ones': ['def new_ones(self, size: _size, {}) -> Tensor: ...'. + format(FACTORY_PARAMS)], 'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)], # new and __init__ have the same signatures differ only in return type # Adapted from legacy_tensor_ctor and legacy_tensor_new 'new': ['def new(self, *args: Any, {}) ->Tensor: ...'.format(DEVICE_PARAM), 'def new(self, storage: Storage) -> Tensor: ...', 'def new(self, other: Tensor) -> Tensor: ...', - 'def new(self, size: {}, *, {}) -> Tensor: ...'.format(type_to_python('IntArrayRef'), DEVICE_PARAM), + 'def new(self, size: _size, *, {}) -> Tensor: ...'.format(DEVICE_PARAM), ], '__init__': ['def __init__(self, *args: Any, {}) -> None: ...'.format(DEVICE_PARAM), 'def __init__(self, storage: Storage) -> None: ...', 'def __init__(self, other: Tensor) -> None: ...', - 'def __init__(self, size: {}, *, {}) -> None: ...'.format(type_to_python('IntArrayRef'), DEVICE_PARAM), + 'def __init__(self, size: _size, *, {}) -> None: ...'.format(DEVICE_PARAM), ], 'as_subclass': ["def as_subclass(self, cls: Tensor) -> Tensor: ..."], # clamp has no default values in the Declarations @@ -565,6 +415,7 @@ def gen_pyi(declarations_path, out): 'is_quantized': ['is_quantized: _bool'], 'is_meta': ['is_meta: _bool'], 'is_mkldnn': ['is_mkldnn: _bool'], + 'is_vulkan': ['is_vulkan: _bool'], 'storage_offset': ['def storage_offset(self) -> _int: ...'], 'to': ['def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...', 'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, ' @@ -603,10 +454,23 @@ def gen_pyi(declarations_path, out): for name in simple_conversions: unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name)) - tensor_method_declarations = get_py_variable_methods(declarations) - for name in sorted(tensor_method_declarations.keys()): - unsorted_tensor_method_hints[name] += \ - generate_type_hints(name, tensor_method_declarations[name], namedtuples, is_tensor=True) + # pyi tensor methods don't currently include deprecated signatures for some reason + # TODO: we should probably add them in + tensor_method_signatures = load_signatures(native_yaml_path, deprecated_yaml_path, method=True, skip_deprecated=True, pyi=True) + tensor_method_sig_groups = get_py_torch_functions(tensor_method_signatures, method=True) + + for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name): + name = group.signature.name + unsorted_tensor_method_hints[name] += generate_type_hints(group) + + named_tuple = group.signature.returns.named_tuple_pyi() + if named_tuple is not None and not group.signature.deprecated: + # deprecated namedtuples are currently not included for torch functions + tuple_name, tuple_def = named_tuple + if tuple_name in namedtuples: + assert namedtuples[tuple_name] == tuple_def + else: + namedtuples[tuple_name] = tuple_def for op in all_ops: name = '__{}__'.format(op) @@ -630,11 +494,14 @@ def gen_pyi(declarations_path, out): # TODO: These are deprecated, maybe we shouldn't type hint them legacy_storage_base_hints = [] - for c in ('Double', 'Float', 'Long', 'Int', - 'Short', 'Char', 'Byte', 'Bool', - 'Half', 'BFloat16', 'ComplexDouble', - 'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32'): + dt = ('Double', 'Float', 'Long', 'Int', + 'Short', 'Char', 'Byte', 'Bool', + 'Half', 'BFloat16', 'ComplexDouble', + 'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32', 'QUInt4x2') + for c in dt: legacy_storage_base_hints.append('class {}StorageBase(object): ...'.format(c)) + for c in dt: + legacy_storage_base_hints.append('class Cuda{}StorageBase(object): ...'.format(c)) legacy_class_hints = [] for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor', @@ -651,7 +518,7 @@ def gen_pyi(declarations_path, out): ['float32', 'float', 'float64', 'double', 'float16', 'bfloat16', 'half', 'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64', 'long', 'complex32', 'complex64', 'cfloat', 'complex128', 'cdouble', - 'quint8', 'qint8', 'qint32', 'bool']] + 'quint8', 'qint8', 'qint32', 'bool', 'quint4x2']] # Generate __all__ directive # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -675,27 +542,36 @@ def gen_pyi(declarations_path, out): 'dtype_class_hints': dtype_class_hints, 'all_directive': all_directive } - TORCH_C_TYPE_STUBS = CodeTemplate.from_file(os.path.join('torch', '_C', '__init__.pyi.in')) - TORCH_C_VARIABLE_FUNCTIONS_TYPE_STUBS = \ - CodeTemplate.from_file(os.path.join('torch', '_C', '_VariableFunctions.pyi.in')) - - write(out, 'torch/_C/__init__.pyi', TORCH_C_TYPE_STUBS, env) - write(out, 'torch/_C/_VariableFunctions.pyi', TORCH_C_VARIABLE_FUNCTIONS_TYPE_STUBS, env) - write(out, 'torch/_VF.pyi', TORCH_C_VARIABLE_FUNCTIONS_TYPE_STUBS, env) - gen_nn_pyi(out) + fm.write_with_template('torch/_C/__init__.pyi', 'torch/_C/__init__.pyi.in', lambda: { + 'generated_comment': '@' + 'generated from torch/_C/__init__.pyi.in', + **env, + }) + fm.write_with_template('torch/_C/_VariableFunctions.pyi', 'torch/_C/_VariableFunctions.pyi.in', lambda: { + 'generated_comment': '@' + 'generated from torch/_C/_VariableFunctions.pyi.in', + **env, + }) + fm.write_with_template('torch/_VF.pyi', 'torch/_C/_VariableFunctions.pyi.in', lambda: { + 'generated_comment': '@' + 'generated from torch/_C/_VariableFunctions.pyi.in', + **env, + }) + gen_nn_functional(fm) -def main(): +def main() -> None: parser = argparse.ArgumentParser( description='Generate type stubs for PyTorch') - parser.add_argument('--declarations-path', metavar='DECL', - default='torch/share/ATen/Declarations.yaml', - help='path to Declarations.yaml') + parser.add_argument('--native-functions-path', metavar='NATIVE', + default='aten/src/ATen/native/native_functions.yaml', + help='path to native_functions.yaml') + parser.add_argument('--deprecated-functions-path', metavar='DEPRECATED', + default='tools/autograd/deprecated.yaml', + help='path to deprecated.yaml') parser.add_argument('--out', metavar='OUT', default='.', help='path to output directory') args = parser.parse_args() - gen_pyi(args.declarations_path, args.out) + fm = FileManager(install_dir=args.out, template_dir='.', dry_run=False) + gen_pyi(args.native_functions_path, args.deprecated_functions_path, fm) if __name__ == '__main__': diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py index 83253cc3a5269..f1809552cd40b 100644 --- a/tools/setup_helpers/cmake.py +++ b/tools/setup_helpers/cmake.py @@ -7,7 +7,6 @@ import re from subprocess import check_call, check_output import sys -import distutils import distutils.sysconfig from distutils.version import LooseVersion @@ -32,7 +31,7 @@ def _mkdir_p(d): def convert_cmake_value_to_python_value(cmake_value, cmake_type): r"""Convert a CMake value in a string form to a Python value. - Arguments: + Args: cmake_value (string): The CMake value in a string form (e.g., "ON", "OFF", "1"). cmake_type (string): The CMake type of :attr:`cmake_value`. @@ -56,7 +55,7 @@ def convert_cmake_value_to_python_value(cmake_value, cmake_type): def get_cmake_cache_variables_from_file(cmake_cache_file): r"""Gets values in CMakeCache.txt into a dictionary. - Arguments: + Args: cmake_cache_file: A CMakeCache.txt file object. Returns: dict: A ``dict`` containing the value of cached CMake variables. @@ -168,7 +167,7 @@ def generate(self, version, cmake_python_library, build_python, build_test, my_e ninja_deps_file = os.path.join(self.build_dir, '.ninja_deps') if IS_WINDOWS and USE_NINJA and os.path.exists(ninja_deps_file): # Cannot rerun ninja on Windows due to a ninja bug. - # The workground is to remove `.ninja_deps`. + # The workaround is to remove `.ninja_deps`. os.remove(ninja_deps_file) args = [] @@ -245,6 +244,7 @@ def generate(self, version, cmake_python_library, build_python, build_test, my_e 'MKL_THREADING', 'MKLDNN_CPU_RUNTIME', 'MSVC_Z7_OVERRIDE', + 'CAFFE2_USE_MSVC_STATIC_RUNTIME', 'Numa_INCLUDE_DIR', 'Numa_LIBRARIES', 'ONNX_ML', diff --git a/tools/setup_helpers/env.py b/tools/setup_helpers/env.py index df6bea52b4277..f04f10cc287c5 100644 --- a/tools/setup_helpers/env.py +++ b/tools/setup_helpers/env.py @@ -43,7 +43,7 @@ class BuildType(object): is ``None``, then the build type will be inferred from ``CMakeCache.txt``. If ``CMakeCache.txt`` does not exist, os.environ['CMAKE_BUILD_TYPE'] will be used. - Arguments: + Args: cmake_build_type_env (str): The value of os.environ['CMAKE_BUILD_TYPE']. If None, the actual build type will be inferred. diff --git a/tools/setup_helpers/generate_code.py b/tools/setup_helpers/generate_code.py index 8a290f4d68937..f939203c46efb 100644 --- a/tools/setup_helpers/generate_code.py +++ b/tools/setup_helpers/generate_code.py @@ -1,11 +1,18 @@ import argparse import os import sys +import yaml + +try: + # use faster C loader if available + from yaml import CLoader as YamlLoader +except ImportError: + from yaml import Loader as YamlLoader source_files = {'.py', '.cpp', '.h'} DECLARATIONS_PATH = 'torch/share/ATen/Declarations.yaml' - +NATIVE_FUNCTIONS_PATH = 'aten/src/ATen/native/native_functions.yaml' # TODO: This is a little inaccurate, because it will also pick # up setup_helper scripts which don't affect code generation @@ -22,19 +29,16 @@ def all_generator_source(): def generate_code(ninja_global=None, declarations_path=None, nn_path=None, + native_functions_path=None, install_dir=None, subset=None, disable_autograd=False, - selected_op_list_path=None, - selected_op_list=None, - force_schema_registration=False): - # cwrap depends on pyyaml, so we can't import it earlier - root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - sys.path.insert(0, root) + force_schema_registration=False, + operator_selector=None): from tools.autograd.gen_autograd import gen_autograd, gen_autograd_python from tools.autograd.gen_annotated_fn_args import gen_annotated - from tools.autograd.utils import load_op_list_and_strip_overload - from tools.jit.gen_unboxing_wrappers import gen_unboxing_wrappers + from tools.codegen.selective_build.selector import SelectiveBuilder + # Build ATen based Variable classes if install_dir is None: @@ -53,36 +57,85 @@ def generate_code(ninja_global=None, tools_jit_templates = os.path.join(data_dir, 'tools', 'jit', 'templates') if subset == "pybindings" or not subset: - gen_autograd_python(declarations_path or DECLARATIONS_PATH, autograd_gen_dir, autograd_dir) + gen_autograd_python( + declarations_path or DECLARATIONS_PATH, + native_functions_path or NATIVE_FUNCTIONS_PATH, + autograd_gen_dir, + autograd_dir) + + if operator_selector is None: + operator_selector = SelectiveBuilder.get_nop_selector() if subset == "libtorch" or not subset: - selected_op_list = load_op_list_and_strip_overload(selected_op_list, selected_op_list_path) gen_autograd( declarations_path or DECLARATIONS_PATH, + native_functions_path or NATIVE_FUNCTIONS_PATH, autograd_gen_dir, autograd_dir, disable_autograd=disable_autograd, - selected_op_list=selected_op_list, + operator_selector=operator_selector, ) - gen_unboxing_wrappers( - declarations_path or DECLARATIONS_PATH, - jit_gen_dir, - tools_jit_templates, - disable_autograd=disable_autograd, - selected_op_list=selected_op_list, - force_schema_registration=force_schema_registration) if subset == "python" or not subset: gen_annotated( - declarations_path or DECLARATIONS_PATH, + native_functions_path or NATIVE_FUNCTIONS_PATH, python_install_dir, autograd_dir) +def get_selector_from_legacy_operator_selection_list( + selected_op_list_path: str, +): + with open(selected_op_list_path, 'r') as f: + # strip out the overload part + # It's only for legacy config - do NOT copy this code! + selected_op_list = { + opname.split('.', 1)[0] for opname in yaml.load(f, Loader=YamlLoader) + } + + # Internal build doesn't use this flag any more. Only used by OSS + # build now. Every operator should be considered a root operator + # (hence generating unboxing code for it, which is consistent with + # the current behaviour), and also be considered as used for + # training, since OSS doesn't support training on mobile for now. + # + is_root_operator = True + is_used_for_training = True + + from tools.codegen.selective_build.selector import SelectiveBuilder + selector = SelectiveBuilder.from_legacy_op_registration_allow_list( + selected_op_list, + is_root_operator, + is_used_for_training, + ) + + return selector + + +def get_selector(selected_op_list_path, operators_yaml_path): + # cwrap depends on pyyaml, so we can't import it earlier + root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + sys.path.insert(0, root) + from tools.codegen.selective_build.selector import SelectiveBuilder + + assert not (selected_op_list_path is not None and + operators_yaml_path is not None), \ + ("Expected at most one of selected_op_list_path and " + + "operators_yaml_path to be set.") + + if selected_op_list_path is None and operators_yaml_path is None: + return SelectiveBuilder.get_nop_selector() + elif selected_op_list_path is not None: + return get_selector_from_legacy_operator_selection_list(selected_op_list_path) + else: + return SelectiveBuilder.from_yaml_path(operators_yaml_path) + + def main(): parser = argparse.ArgumentParser(description='Autogenerate code') parser.add_argument('--declarations-path') + parser.add_argument('--native-functions-path') parser.add_argument('--nn-path') parser.add_argument('--ninja-global') parser.add_argument('--install_dir') @@ -98,14 +151,11 @@ def main(): ) parser.add_argument( '--selected-op-list-path', - help='Path to the yaml file that contains the list of operators to include for custom build.', + help='Path to the YAML file that contains the list of operators to include for custom build.', ) parser.add_argument( - '--selected-op-list', - nargs="*", - type=str, - help="""List of operator names to include for custom build, in addition to those in selected-op-list-path. - For example, --selected-op-list aten::add.Tensor aten::_convolution.""", + '--operators_yaml_path', + help='Path to the model YAML file that contains the list of operators to include for custom build.', ) parser.add_argument( '--force_schema_registration', @@ -114,16 +164,18 @@ def main(): 'listed on --selected-op-list' ) options = parser.parse_args() + generate_code( options.ninja_global, options.declarations_path, options.nn_path, + options.native_functions_path, options.install_dir, options.subset, options.disable_autograd, - options.selected_op_list_path, - options.selected_op_list, options.force_schema_registration, + # options.selected_op_list + operator_selector=get_selector(options.selected_op_list_path, options.operators_yaml_path), ) diff --git a/tools/shared/cwrap_common.py b/tools/shared/cwrap_common.py index ab50de1769b09..a9e0811d13b04 100644 --- a/tools/shared/cwrap_common.py +++ b/tools/shared/cwrap_common.py @@ -130,7 +130,7 @@ def add_argument(self, arg): self.arguments.append(arg) def __repr__(self): - return self.name + '(' + ', '.join(map(lambda a: a.__repr__(), self.arguments)) + ')' + return self.name + '(' + ', '.join(a.__repr__() for a in self.arguments) + ')' class Argument(object): @@ -151,13 +151,13 @@ def parse_header(path): # Remove empty lines and prebackend directives lines = filter(lambda l: l and not l.startswith('#'), lines) # Remove line comments - lines = map(lambda l: l.partition('//'), lines) + lines = (l.partition('//') for l in lines) # Select line and comment part - lines = map(lambda l: (l[0].strip(), l[2].strip()), lines) + lines = ((l[0].strip(), l[2].strip()) for l in lines) # Remove trailing special signs - lines = map(lambda l: (l[0].rstrip(');').rstrip(','), l[1]), lines) + lines = ((l[0].rstrip(');').rstrip(','), l[1]) for l in lines) # Split arguments - lines = map(lambda l: (l[0].split(','), l[1]), lines) + lines = ((l[0].split(','), l[1]) for l in lines) # Flatten lines new_lines = [] for l, c in lines: @@ -166,7 +166,7 @@ def parse_header(path): lines = new_lines del new_lines # Remove unnecessary whitespace - lines = map(lambda l: (l[0].strip(), l[1]), lines) + lines = ((l[0].strip(), l[1]) for l in lines) # Remove empty lines lines = filter(lambda l: l[0], lines) generic_functions = [] @@ -178,8 +178,8 @@ def parse_header(path): else: fn_name = fn_name[:-1] generic_functions.append(Function(fn_name)) - elif l.startswith('THC_API void THNN_'): - fn_name = l[len('THC_API void THNN_'):] + elif l.startswith('TORCH_CUDA_API void THNN_'): + fn_name = l[len('TORCH_CUDA_API void THNN_'):] if fn_name[0] == '(' and fn_name[-2] == ')': fn_name = fn_name[1:-2] else: diff --git a/tools/shared/module_loader.py b/tools/shared/module_loader.py index 2ba555ffffacc..51c57aa161c93 100644 --- a/tools/shared/module_loader.py +++ b/tools/shared/module_loader.py @@ -1,13 +1,6 @@ -import sys - - def import_module(name, path): - if sys.version_info >= (3, 5): - import importlib.util - spec = importlib.util.spec_from_file_location(name, path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module - else: - from importlib.machinery import SourceFileLoader - return SourceFileLoader(name, path).load_module() + import importlib.util + spec = importlib.util.spec_from_file_location(name, path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index b78dc4a362a7f..9b1d6fd4a55f3 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -46,7 +46,7 @@ append_filelist("libtorch_python_core_sources" TORCH_PYTHON_SRCS) # NB: This has to match the condition under which the JIT test directory # is included (at the time of writing that's in caffe2/CMakeLists.txt). -if(BUILD_TEST AND NOT USE_ROCM) +if(BUILD_TEST) add_definitions(-DBUILDING_TESTS) list(APPEND TORCH_PYTHON_SRCS ${TORCH_ROOT}/test/cpp/jit/torch_python_test.cpp @@ -66,6 +66,8 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES ${CMAKE_BINARY_DIR}/third_party ${CMAKE_BINARY_DIR}/third_party/onnx + ${TORCH_ROOT}/third_party/valgrind-headers + ${TORCH_ROOT}/third_party/gloo ${TORCH_ROOT}/third_party/onnx ${pybind11_INCLUDE_DIRS} @@ -160,30 +162,37 @@ endif() if(USE_DISTRIBUTED) list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_DISTRIBUTED) - if(NOT MSVC) + if(WIN32) + append_filelist("libtorch_python_distributed_core_sources" TORCH_PYTHON_SRCS) + else() + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_RPC) append_filelist("libtorch_python_distributed_sources" TORCH_PYTHON_SRCS) - # Disable certain warnings for GCC-9.X - if(CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0.0)) - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/rpc/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/rpc/testing/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - endif() - list(APPEND TORCH_PYTHON_LINK_LIBRARIES c10d) - list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D) - if(USE_TENSORPIPE) - list(APPEND TORCH_PYTHON_LINK_LIBRARIES tensorpipe) - list(APPEND TORCH_PYTHON_PUBLIC_COMPILE_DEFINITIONS USE_TENSORPIPE) - endif() endif() + # Disable certain warnings for GCC-9.X + if(CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0.0)) + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/rpc/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/rpc/testing/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") + endif() + if(USE_TENSORPIPE) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES tensorpipe) + list(APPEND TORCH_PYTHON_PUBLIC_COMPILE_DEFINITIONS USE_TENSORPIPE) + endif() + list(APPEND TORCH_PYTHON_LINK_LIBRARIES c10d) + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D) endif() -if(USE_NCCL) +if(USE_NCCL AND NOT WIN32) list(APPEND TORCH_PYTHON_SRCS ${TORCH_SRC_DIR}/csrc/cuda/python_nccl.cpp) list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_NCCL) endif() +if(USE_VALGRIND AND NOT WIN32) + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_VALGRIND) +endif() + # In the most recent CMake versions, a new 'TRANSFORM' subcommand of 'list' allows much of the boilerplate of defining the lists # of type stub files to be omitted. # For comptability with older CMake versions, we omit it for now, but leave it as a comment in case comptability with the older @@ -225,9 +234,9 @@ add_custom_command( "${TORCH_SRC_DIR}/nn/functional.pyi" COMMAND "${PYTHON_EXECUTABLE}" -mtools.pyi.gen_pyi - --declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml" + --native-functions-path "aten/src/ATen/native/native_functions.yaml" + --deprecated-functions-path "tools/autograd/deprecated.yaml" DEPENDS - "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml" "${TORCH_SRC_DIR}/_C/__init__.pyi.in" "${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi.in" "${TORCH_SRC_DIR}/nn/functional.pyi.in" diff --git a/torch/_C/_VariableFunctions.pyi.in b/torch/_C/_VariableFunctions.pyi.in index 1360ef079725e..1afd8e6c73d72 100644 --- a/torch/_C/_VariableFunctions.pyi.in +++ b/torch/_C/_VariableFunctions.pyi.in @@ -1,6 +1,6 @@ # ${generated_comment} -from torch import Tensor, Generator, strided, memory_format, contiguous_format +from torch import Tensor, Generator, strided, memory_format, contiguous_format, strided from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar from torch._six import inf diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 2543e724b1e0b..18b3e7222886f 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -4,9 +4,10 @@ import torch from torch import Tensor from enum import Enum from pathlib import Path -from typing import (Any, BinaryIO, Callable, ContextManager, Dict, Iterator, List, NamedTuple, - Optional, overload, Sequence, Tuple, TypeVar, Type, Union, Generic, - Set, AnyStr) +from typing import ( + Any, BinaryIO, Callable, ContextManager, Dict, Iterable, Iterator, List, + NamedTuple, Optional, overload, Sequence, Tuple, TypeVar, Type, Union, + Generic, Set, AnyStr) from torch._six import inf from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage @@ -28,13 +29,20 @@ class device: # THPDevice_pynew @overload - def __init__(self, device: Union[_int, str]) -> None: ... + def __init__(self, device: Union[_device, _int, str]) -> None: ... @overload def __init__(self, type: str, index: _int) -> None: ... def __reduce__(self) -> Tuple[Any, ...]: ... # THPDevice_reduce +# Defined in torch/csrc/Stream.cpp +class Stream: + _cdata: _int # Stream handle + device: device # The device of the stream + + ... + # Defined in torch/csrc/Size.cpp class Size(Tuple[_int, ...]): # TODO: __reduce__ @@ -93,6 +101,7 @@ def DisableTorchFunction(): ... # Defined in torch/csrc/utils/tensor_layouts.cpp strided : layout = ... sparse_coo : layout = ... +_mkldnn : layout = ... # Defined in torch/csrc/MemoryFormat.cpp class memory_format: ... @@ -106,7 +115,7 @@ preserve_format: memory_format = ... # Defined in torch/csrc/QScheme.cpp class qscheme: ... -# Defined in torch/csrc/utils/tensor_qschemes.cpp +# Defined in torch/csrc/utils/tensor_qschemes.h per_tensor_affine: qscheme = ... per_channel_affine: qscheme = ... per_tensor_symmetric: qscheme = ... @@ -136,9 +145,12 @@ class Future(object): def __init__(self) -> None: ... def done(self) -> _bool: ... def wait(self) -> Any: ... + def add_done_callback(self, callback: Callable) -> None: ... def then(self, callback: Callable) -> Future: ... def set_result(self, result: Any) -> None: ... +def _jit_set_num_profiled_runs(num: _size) -> _size: ... + # Defined in torch/csrc/jit/passes/xnnpack_rewrite.h class MobileOptimizerType: ... @@ -154,7 +166,11 @@ def wait(fut: Future) -> Any: ... def _collect_all(futures: List[Future]) -> Future: ... def unify_type_list(types: List[JitType]) -> JitType: ... -def _freeze_module(module: ScriptModule, preserved_attrs: List[str], freeze_interfaces: _bool = True) -> ScriptModule: ... +def _freeze_module(module: ScriptModule, + preserved_attrs: List[str] = [], + freeze_interfaces: _bool = True, + preserveParameters: _bool = True) -> ScriptModule: ... +def _jit_pass_optimize_frozen_graph(Graph) -> None: ... def _is_tracing() -> _bool: ... def _jit_init() -> _bool: ... def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ... @@ -165,18 +181,39 @@ def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule', preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ... def _jit_pass_vulkan_optimize_for_mobile(module: 'torch.jit.ScriptModule', preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ... +def _jit_pass_metal_optimize_for_mobile(module: 'torch.jit.ScriptModule', + preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ... def _jit_pass_inline(Graph) -> None: ... +def _jit_pass_constant_propagation(Graph) -> None: ... def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ... +def _jit_check_alias_annotation(g: Graph, args: Tuple[Any, ...], unqualified_op_name: str): ... def _jit_can_fuse_on_cpu() -> _bool: ... def _jit_can_fuse_on_gpu() -> _bool: ... +def _debug_get_fusion_group_inlining() -> _bool: ... +def _debug_set_fusion_group_inlining(enable: _bool): ... def _jit_texpr_fuser_enabled() -> _bool: ... def _jit_nvfuser_enabled() -> _bool: ... +def _llvm_enabled() -> _bool: ... def _jit_override_can_fuse_on_cpu(override: _bool): ... def _jit_override_can_fuse_on_gpu(override: _bool): ... def _jit_set_texpr_fuser_enabled(enable: _bool): ... def _jit_set_nvfuser_enabled(enable: _bool) -> _bool: ... def _jit_pass_canonicalize(graph: Graph): ... def _jit_pass_erase_shape_information(graph: Graph): ... +def _jit_pass_fold_convbn(module: 'torch.jit.ScriptModule'): ... +def _jit_pass_insert_observers(module: 'torch.jit.ScriptModule', + method_name: str, + qconfig_dict: Dict[str, Any], + inplace: _bool, + quant_type: _int): ... +def _jit_pass_insert_quant_dequant(module: 'torch.jit.ScriptModule', + method_name: str, + inplace: _bool, + debug: _bool, + quant_type: _int): ... +def _jit_pass_quant_finalize(module: 'torch.jit.ScriptModule', + quant_type: _int, + preserved_attrs: Sequence[str]): ... def _jit_set_profiling_executor(profiling_flag: _bool) -> _bool: ... def _jit_set_profiling_mode(profiling_flag: _bool) -> _bool: ... def _jit_try_infer_type(obj: Any) -> JitType: ... @@ -185,9 +222,12 @@ def _jit_get_trigger_value(trigger_name: str) -> _int: ... # Defined in torch/csrc/jit/python/script_init.cpp ResolutionCallback = Callable[[str], Callable[..., Any]] +# Defined in torch/csrc/jit/python/script_init.cpp +# and torch/csrc/jit/python/init.cpp def _create_function_from_graph(qualname: str, graph: Graph) -> Graph: ... def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ... def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ... +def _jit_assert_is_instance(obj: Any, type: JitType): ... def _jit_clear_class_registry() -> None: ... def _jit_set_emit_hooks(ModuleHook: Optional[Callable], FunctionHook: Optional[Callable]) -> None: ... def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ... @@ -213,6 +253,55 @@ def _resolve_type_from_object(obj: Any, range: SourceRange, rcb: ResolutionCallb def _create_module_with_type(ty: JitType) -> ScriptModule: ... def _run_emit_module_hook(m: ScriptModule): ... def _replace_overloaded_method_decl(overload_decl: Decl, implementation_def: Def, new_name: str) -> Def: ... + +def _jit_pass_lower_all_tuples(graph: Graph) -> None: ... +def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str, Dict[_int, str]], input_names: List[str]) -> None: ... +def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, opset_version: _int) -> None: ... +def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], onnx_shape_inference: _bool = False) -> None: ... +def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph) -> None: ... +def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ... +def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ... +def _jit_pass_peephole(graph: Graph, addmm_fusion_enabled: _bool) -> None: ... +def _jit_pass_fuse_addmm(graph: Graph) -> None: ... +def _jit_pass_onnx_preprocess(graph: Graph) -> None: ... +def _jit_pass_onnx_prepare_inplace_ops_for_onnx(graph: Graph) -> None: ... +def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ... +def _jit_pass_onnx_remove_print(graph: Graph) -> None: ... +def _jit_pass_onnx_preprocess_caffe2(graph: Graph) -> None: ... +def _jit_pass_onnx_unpack_quantized_weights( + graph: Graph, + paramsDict: Dict[str, IValue] +) -> Dict[str, IValue]: ... +def _jit_pass_onnx_quantization_insert_permutes( + graph: Graph, + paramsDict: Dict[str, IValue] +) -> Dict[str, IValue]: ... +def _jit_pass_custom_pattern_based_rewrite_graph(pattern: str, fused_node_name: str, graph: Graph) -> None: ... +def _jit_onnx_list_model_parameters(module: ScriptModule) -> Tuple[ScriptModule, List[IValue]]: ... +def _jit_pass_erase_number_types(graph: Graph) -> None: ... +def _jit_pass_onnx(graph: Graph, _jit_pass_onnx: _onnx.OperatorExportTypes) -> Graph: ... +def _jit_pass_onnx_scalar_type_analysis(graph: Graph) -> None: ... +def _jit_pass_onnx_peephole(graph: Graph, opset_version: _int, fixed_batch_size: _bool) -> None: ... +def _jit_pass_dce_allow_deleting_nodes_with_side_effects(graph: Graph) -> None: ... +def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ... +def _jit_pass_onnx_fold_if(graph: Graph) -> None: ... +def _jit_pass_lower_graph(graph: Graph, m: Module) -> Tuple[Graph, List[IValue]]: ... +def _jit_pass_inline_fork_wait(graph: Graph) -> None: ... +def _jit_pass_onnx_eval_peephole(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ... +def _jit_pass_onnx_constant_fold(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> Dict[str, IValue]: ... +def _jit_pass_onnx_eliminate_unused_items(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ... +def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ... +def _jit_pass_filter_non_tensor_arguments(params: Dict[str, IValue]) -> Dict[str, Tensor]: ... +def _jit_decay_packed_param_input_types(graph: Graph) -> None: ... +def _jit_pass_onnx_node_shape_type_inference(n: Node, opset_version: _int) -> None: ... +def _jit_pass_onnx_block( + old_block: Block, + new_block: Block, + operator_export_type: _onnx.OperatorExportTypes, + env: Dict[Value, Value] +) -> None: ... +def _jit_pass_fixup_onnx_controlflow_node(n: Node, opset_version: _int) -> Node: ... + def _jit_script_interface_compile(name: str, class_def: ClassDef, rcb: ResolutionCallback, is_module: _bool): ... def _jit_script_compile_overload( qualname: str, @@ -248,12 +337,47 @@ def import_ir_module_from_buffer( extra_files: Dict[str, Any] ) -> ScriptModule: ... +def _assign_output_shapes(graph: Graph, inputs: List[Tensor]) -> Graph: ... +def _check_onnx_proto(proto: str) -> None: ... +def _propagate_and_assign_input_shapes( + graph: Graph, + inputs: Tuple[Tensor, ...], + with_grad: _bool, + propagate: _bool +) -> Graph: ... + +# Defined in torch/csrc/jit/runtime/graph_executor.h +class GraphExecutorState: + ... + # Defined in torch/torch/csrc/jit/ir/ir.h class Graph: + def eraseInput(self, i: _int) -> None: ... ... +# Defined in torch/csrc/jit/ir/ir.h +class Value: + ... + +# Defined in torch/csrc/jit/ir/ir.h +class Block: + ... + +# Defined in torch/csrc/jit/ir/ir.h +class Node: + ... + + # Defined in torch/aten/src/ATen/core/function_schema.h +class Argument: + name: str + type: JitType + default_value: Optional[Any] + def has_default_value(self) -> _bool: ... + ... class FunctionSchema: + arguments: List[Argument] + returns: List[Argument] ... # Defined in torch/csrc/jit/python/script_init.cpp @@ -268,6 +392,8 @@ class ConcreteModuleTypeBuilder: def add_builtin_function(self, name: str, symbol_name: str): ... def add_failed_attribute(self, name: str, failure_reason: str): ... def add_function_attribute(self, name: str, ty: JitType, func: Callable[..., Any]): ... + def add_ignored_attribute(self, name: str): ... + def add_ignored_attributes(self, names: List[str]): ... class ConcreteModuleType: def get_constants(self) -> Dict[str, Any]: ... @@ -290,6 +416,7 @@ class CompilationUnit: def __init__(self) -> None: ... def find_function(self, name: str) -> ScriptFunction: ... def define(self, script: str, rcb: ResolutionCallback): ... + def get_interface(self, name: str) -> InterfaceType: ... class ScriptModule: def setattr(self, name: str, value: Any): ... @@ -309,8 +436,8 @@ class ScriptFunction: def qualified_name(self) -> str: ... class ScriptMethod: + graph: Graph ... - class ModuleDict: def __init__(self, mod: ScriptModule) -> None: ... def items(self) -> List[Tuple[str, Any]]: ... @@ -321,6 +448,10 @@ class ParameterDict: class BufferDict: def __init__(self, mod: ScriptModule) -> None: ... +# Defined in torch/csrc/jit/api/module.h +class Module: + ... + # Defined in torch/csrc/Module.cpp def _initExtension(shm_manager_path: str) -> None: ... # THPModule_initExtension def _autograd_init() -> _bool: ... # THPAutograd_initExtension @@ -334,6 +465,7 @@ def _crash_if_csrc_asan() -> _int: ... # THPModule_crashIfCsrcASAN def _crash_if_csrc_ubsan() -> _int: ... # THPModule_crashIfCsrcUBSAN def _crash_if_aten_asan() -> _int: ... # THPModule_crashIfATenASAN def _show_config() -> str: ... # THPModule_showConfig +def _cxx_flags() -> str: ... # THPModule_cxxFlags def _parallel_info() -> str: ... # THPModule_parallelInfo def _set_backcompat_broadcast_warn(arg: _bool) -> None: ... # THPModule_setBackcompatBroadcastWarn def _get_backcompat_broadcast_warn() -> _bool: ... # THPModule_getBackcompatBroadcastWarn @@ -368,9 +500,17 @@ def _get_qengine() -> _int: ... # THPModule_qEngine def _set_qengine(qegine: _int) -> None: ... # THPModule_setQEngine def _supported_qengines() -> List[_int]: ... # THPModule_supportedQEngines def _is_xnnpack_enabled() -> _bool: ... # THPModule_isEnabledXNNPACK +def _has_torch_function(args: Iterable[Any]) -> _bool: ... # THPModule_has_torch_function +def _has_torch_function_unary(Any) -> _bool: ... # THPModule_has_torch_function_unary +def _has_torch_function_variadic(*args: Any) -> _bool: ... # THPModule_has_torch_function_variadic def _vmapmode_increment_nesting() -> _int: ... # THPModule_vmapmode_increment_nesting def _vmapmode_decrement_nesting() -> _int: ... # THPModule_vmapmode_decrement_nesting def _log_api_usage_once(str) -> None: ... # LogAPIUsageOnceFromPython +def _demangle(str) -> str: ... # c10::demangle + +# Defined in `valgrind.h` and `callgrind.h` respecitively. +def _valgrind_supported_platform() -> _bool: ... # NVALGRIND +def _valgrind_toggle() -> None: ... # CALLGRIND_TOGGLE_COLLECT has_openmp: _bool has_mkl: _bool @@ -387,10 +527,16 @@ def is_grad_enabled() -> _bool: ... def set_autocast_enabled(enabled: _bool) -> None: ... def is_autocast_enabled() -> _bool: ... def clear_autocast_cache() -> None: ... -def autocast_increment_nesting() -> None: ... -def autocast_decrement_nesting() -> None: ... +def autocast_increment_nesting() -> _int: ... +def autocast_decrement_nesting() -> _int: ... def set_anomaly_enabled(enabled: _bool) -> None: ... def is_anomaly_enabled() -> _bool: ... +def _enter_dual_level() -> _int: ... +def _exit_dual_level(level: _int) -> None: ... +def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ... +def _unpack_dual(tensor: Tensor, level: _int) -> Tensor: ... +def __set_forward_AD_enabled(enabled: _bool) -> None: ... +def __is_forward_AD_enabled() -> _bool: ... # Defined in torch/csrc/jit/python/script_init.cpp class LoggerBase(object): @@ -475,6 +621,10 @@ ${legacy_storage_base_hints} # TODO: where ${legacy_class_hints} +# Defined in torch/csrc/autograd/python_engine.cpp +class _ImperativeEngine: + ... + # Defined in torch/csrc/autograd/python_variable.cpp class _TensorBase(object): requires_grad: _bool @@ -491,17 +641,31 @@ class _TensorBase(object): _version: _int _base: Optional[Tensor] grad_fn: Any + _grad: Optional[Tensor] + _backward_hooks: Optional[Dict[_int, Callable[[Tensor], Optional[Tensor]]]] ${tensor_method_hints} +# Defined in torch/csrc/multiprocessing/init.cpp +def _multiprocessing_init() -> None: ... + # Defined in torch/csrc/cuda/Module.cpp def _cuda_getCurrentStream(device: _int) -> _int: ... def _cuda_getDefaultStream(device: _int) -> _int: ... def _cuda_getCurrentBlasHandle() -> _int: ... +def _cuda_setDevice(device: _int) -> None: ... +def _cuda_getDevice() -> _int: ... +def _cuda_getDeviceCount() -> _int: ... +def _cuda_sleep(cycles: _int) -> None: ... +def _cuda_synchronize() -> None: ... +def _cuda_ipc_collect() -> None: ... +def _cuda_getArchFlags() -> Optional[str]: ... +def _cuda_init() -> None: ... def _cuda_setStream(cuda_stream: _int) -> None: ... def _cuda_getCompiledVersion() -> _int: ... def _cuda_cudaHostAllocator() -> _int: ... def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ... def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ... +def _cuda_setMemoryFraction(fraction: _float, device: _int) -> None: ... def _cuda_emptyCache() -> None: ... def _cuda_memoryStats(device: _int) -> Dict[str, Any]: ... def _cuda_resetAccumulatedMemoryStats(device: _int) -> None: ... @@ -509,8 +673,35 @@ def _cuda_resetPeakMemoryStats(device: _int) -> None: ... def _cuda_memorySnapshot() -> List[Dict[str, Any]]: ... def _cuda_lock_mutex() -> None: ... def _cuda_unlock_mutex() -> None: ... +def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ... def _nccl_version() -> _int: ... def _nccl_unique_id() -> bytes: ... +def _nccl_init_rank(nranks: _int, comm_id: bytes, rank: _int) -> object: ... +def _nccl_reduce(input: Sequence[Tensor], + output: Tensor, + root: _int, + op: _int, + streams: Optional[Sequence[_CudaStreamBase]], + comms: Optional[Sequence[object]]) -> None: ... +def _nccl_all_reduce(input: Sequence[Tensor], + output: Sequence[Tensor], + op: _int, + streams: Optional[Sequence[_CudaStreamBase]], + comms: Optional[Sequence[object]]) -> None: ... +def _nccl_broadcast(input: Sequence[Tensor], + root: _int, + streams: Optional[Sequence[_CudaStreamBase]], + comms: Optional[Sequence[object]]) -> None: ... +def _nccl_all_gather(input: Sequence[Tensor], + output: Sequence[Tensor], + streams: Optional[Sequence[_CudaStreamBase]], + comms: Optional[Sequence[object]]) -> None: ... +def _nccl_reduce_scatter(input: Sequence[Tensor], + output: Sequence[Tensor], + op: _int, + streams: Optional[Sequence[_CudaStreamBase]], + comms: Optional[Sequence[object]]) -> None: ... + class _CudaDeviceProperties: name: str @@ -521,8 +712,23 @@ class _CudaDeviceProperties: is_integrated: _int is_multi_gpu_board: _int +# Defined in torch/csrc/cuda/python_comm.cpp +def _broadcast(tensor: Tensor, devices: List[_int]) -> List[Tensor]: ... +def _broadcast_out(tensor: Tensor, out_tensors: List[Tensor]) -> List[Tensor]: ... +def _broadcast_coalesced( + tensors: List[Tensor], + devices: List[_int], + buffer_size: _int +) -> List[List[Tensor]]: ... + +def _scatter(tensor: Tensor, devices: List[_int], chunk_sizes: Optional[List[_int]], dim: _int, streams: Optional[List[Stream]]) -> List[Tensor]: ... +def _scatter_out(tensor: Tensor, out_tensors: List[Tensor], dim: _int, streams: Optional[List[Stream]]) -> List[Tensor]: ... +def _gather(tensors: List[Tensor], dim: _int, destination_index: Optional[_int]) -> Tensor: ... +def _gather_out(tensors: List[Tensor], out_tensor: Tensor, dim: _int) -> Tensor: ... + # Defined in torch/csrc/cuda/Stream.cpp class _CudaStreamBase: + _cdata: _int device: _device cuda_stream: _int priority: _int @@ -547,6 +753,10 @@ class _CudaEventBase: def synchronize(self) -> None: ... def ipc_handle(self) -> bytes: ... +# Defined in torch/csrc/cuda/Graph.cpp +class _CudaGraphBase: + ... + # Defined in torch/csrc/DataLoader.cpp def _set_worker_signal_handlers(*arg: Any) -> None: ... # THPModule_setWorkerSignalHandlers def _set_worker_pids(key: _int, child_pids: Tuple[_int, ...]) -> None: ... # THPModule_setWorkerPIDs @@ -554,7 +764,13 @@ def _remove_worker_pids(loader_id: _int) -> None: ... # THPModule_removeWorkerP def _error_if_any_worker_fails() -> None: ... # THPModule_errorIfAnyWorkerFails # Defined in torch/csrc/jit/python/python_tracer.cpp -class TracingState: ... +class TracingState: + def push_scope(self, scope_name: str) -> None: ... + def pop_scope(self) -> None: ... + def current_scope(self) -> str: ... + def set_graph(self, graph: Graph) -> None: ... + def graph(self) -> Graph: ... + ... def _create_graph_by_tracing( func: Callable[..., Any], @@ -606,6 +822,10 @@ class DeviceObjType(JitType): @staticmethod def get() -> DeviceObjType: ... +class StreamObjType(JitType): + @staticmethod + def get() -> StreamObjType: ... + class ListType(JitType): def __init__(self, a: JitType) -> None: ... def getElementType(self) -> JitType: ... @@ -661,6 +881,8 @@ class EnumType(JitType): class TensorType(JitType): @classmethod def get(cls) -> TensorType: ... + @classmethod + def getInferred(cls) -> TensorType: ... # Defined in torch/csrc/jit/python/python_tree_views.cpp class SourceRange: @@ -681,3 +903,15 @@ class Def(TreeView): class Decl(TreeView): ... + +# Defined in torch/csrc/distributed/rpc/init.cpp +def _rpc_init() -> _bool: ... + +# Defined in torch/csrc/distributed/autograd/init.cpp +def _dist_autograd_init() -> _bool: ... + +# Defined in torch/csrc/distributed/c10d/init.cpp +def _c10d_init() -> _bool: ... + +# Defined in torch/csrc/distributed/rpc/testing/init.cpp +def _faulty_agent_init() -> _bool: ... diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index 653b705fe135d..15a286f2370cd 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -1,19 +1,34 @@ -from typing import List +from typing import List, Set from enum import Enum # Defined in tools/autograd/init.cpp class ProfilerState(Enum): - Disable = 0 - CPU = 1 - CUDA = 2 - NVTX = 3 + Disable = ... + CPU = ... + CUDA = ... + NVTX = ... + KINETO = ... +class ProfilerActivity(Enum): + CPU = ... + CUDA = ... -class ProfilerConfig: - def __init__(self, state: ProfilerState, report_input_shapes: bool, profile_memory: bool) -> None: ... +class DeviceType(Enum): + CPU = ... + CUDA = ... ... +class ProfilerConfig: + def __init__( + self, + state: ProfilerState, + report_input_shapes: bool, + profile_memory: bool, + with_stack: bool, + with_flops: bool + ) -> None: ... + ... class ProfilerEvent: def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ... @@ -30,11 +45,28 @@ class ProfilerEvent: def sequence_nr(self) -> int: ... def shapes(self) -> List[List[int]]: ... def thread_id(self) -> int: ... + def flops(self) -> float: ... ... +class KinetoEvent: + def name(self) -> str: ... + def device_index(self) -> int: ... + def start_us(self) -> int: ... + def duration_us(self) -> int: ... + ... + +class ProfilerResult: + def events(self) -> List[KinetoEvent]: ... + def legacy_events(self) -> List[List[ProfilerEvent]]: ... + def save(self, str) -> None: ... -def _enable_profiler(config: ProfilerConfig) -> None: ... -def _disable_profiler() -> List[List[ProfilerEvent]]: ... +def _enable_profiler(config: ProfilerConfig, activities: Set[ProfilerActivity]) -> None: ... +def _prepare_profiler(config: ProfilerConfig, activities: Set[ProfilerActivity]) -> None: ... +def _disable_profiler() -> ProfilerResult: ... def _profiler_enabled() -> bool: ... +def kineto_available() -> bool: ... def _enable_record_function(enable: bool) -> None: ... def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ... + +def _enable_profiler_legacy(config: ProfilerConfig) -> None: ... +def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ... diff --git a/torch/_C/_distributed_autograd.pyi b/torch/_C/_distributed_autograd.pyi new file mode 100644 index 0000000000000..39cbb984c635c --- /dev/null +++ b/torch/_C/_distributed_autograd.pyi @@ -0,0 +1,25 @@ +import torch +from typing import Dict, List, Set, Any + +# This module is defined in torch/csrc/distributed/autograd/init.cpp + +class DistAutogradContext: + def _context_id(self) -> int: ... + def _recv_functions(self) -> Dict[int, Any]: ... + def _send_functions(self) -> Dict[int, Any]: ... + def _known_worker_ids(self) -> Set[int]: ... + +def _new_context() -> DistAutogradContext: ... +def _release_context(context_id: int) -> None: ... +def _get_max_id() -> int: ... +def _is_valid_context(worker_id: int) -> bool: ... +def _retrieve_context(context_id: int) -> DistAutogradContext: ... +def _current_context() -> DistAutogradContext: ... +def _init(worker_id: int) -> None: ... +def _get_debug_info() -> Dict[str, str]: ... +def backward( + context_id: int, + roots: List[torch.Tensor], + retain_graph = False +) -> None: ... +def get_gradients(context_id: int) -> Dict[torch.Tensor, torch.Tensor]: ... diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi new file mode 100644 index 0000000000000..5ac2c0a8315d4 --- /dev/null +++ b/torch/_C/_distributed_c10d.pyi @@ -0,0 +1,354 @@ +from torch import Tensor +from enum import Enum +from typing import Optional, List, Any, overload +from datetime import timedelta + +# This module is defined in torch/csrc/distributed/c10d/init.cpp + +_DEFAULT_FIRST_BUCKET_BYTES: int +_DEFAULT_NO_TIMEOUT: timedelta + +class BuiltinCommHookType(Enum): + ALLREDUCE = ... + FP16_COMPRESS = ... + +def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ... +def _register_builtin_comm_hook(reducer: Reducer, comm_hook_type: BuiltinCommHookType): ... + +class _GradBucket: + def __init__(self, tensors: List[Tensor]): ... + def get_tensors(self) -> List[Tensor]: ... + +class Reducer: + def __init__( + self, + replicas: List[List[Tensor]], + bucket_indices: List[List[int]], + process_group: ProcessGroup, + expect_sparse_gradients: List[List[bool]], + bucket_bytes_cap: int, + find_unused_parameters: bool, + gradient_as_bucket_view: bool, + ): ... + def initialize_buckets(self, bucket_indices: List[List[int]]): ... + ... + +class ReduceOp(Enum): + SUM = ... + PRODUCT = ... + MIN = ... + MAX = ... + BAND = ... + BOR = ... + BXOR = ... + UNUSED = ... + +class BroadcastOptions: + rootRank: int + rootTensor: int + timeout: timedelta + +class AllreduceOptions: + reduceOp: ReduceOp + timeout: timedelta + +class AllreduceCoalescedOptions(AllreduceOptions): + ... + +class ReduceOptions: + reduceOp: ReduceOp + rootRank: int + rootTensor: int + timeout: timedelta + +class AllGatherOptions: + timeout: timedelta + +class GatherOptions: + rootRank: int + timeout: timedelta + +class ScatterOptions: + rootRank: int + timeout: timedelta + +class ReduceScatterOptions: + reduceOp: ReduceOp + timeout: timedelta + +class BarrierOptions: + device_ids: List[int] + timeout: timedelta + +class AllToAllOptions: + timeout: timedelta + +class Store: + def set(self, key: str, value: str): ... + def get(self, key: str) -> bytes: ... + def add(self, key: str, value: int) -> int: ... + def delete_key(self, key: str) -> bool: ... + def num_keys(self) -> int: ... + def set_timeout(self, timeout: timedelta): ... + @overload + def wait(self, keys: List[str]): ... + @overload + def wait(self, keys: List[str], timeout: timedelta): ... + +class FileStore(Store): + def __init__( + self, + path: str, + numWorkers: int + ): ... + +class HashStore(Store): + def __init__(self): ... + +class TCPStore(Store): + def __init__( + self, + host_name: str, + port: int, + world_size: int, + is_master: bool, + timeout: timedelta, + ): ... + +class PrefixStore(Store): + def __init__( + self, + prefix: str, + store: Store + ): ... + +class Work: + def is_completed(self) -> bool: ... + def is_success(self) -> bool: ... + def exception(self) -> Any: ... + def wait(self, timeout: timedelta = _DEFAULT_NO_TIMEOUT) -> bool: ... + def source_rank(self) -> int: ... + def _source_rank(self) -> int: ... + def result(self) -> List[Tensor]: ... + def synchronize(self): ... + ... + +class ProcessGroup: + def __init__(self): ... + def rank(self) -> int: ... + def size(self) -> int: ... + @overload + def broadcast( + self, + tensors: List[Tensor], + opts = BroadcastOptions(), + ) -> Work: ... + @overload + def broadcast( + self, + tensor: Tensor, + root: int, + ) -> Work: ... + @overload + def allreduce( + self, + tensors: List[Tensor], + opts: AllreduceOptions = AllreduceOptions(), + ) -> Work: ... + @overload + def allreduce( + self, + tensors: List[Tensor], + op = ReduceOp.SUM, + ) -> Work: ... + @overload + def allreduce( + self, + tensor: Tensor, + op = ReduceOp.SUM, + ) -> Work: ... + def allreduce_coalesced( + self, + tensors: List[Tensor], + opts = AllreduceCoalescedOptions(), + ) -> Work: ... + @overload + def reduce( + self, + tensors: List[Tensor], + opts = ReduceOptions(), + ) -> Work: ... + @overload + def reduce( + self, + tensor: Tensor, + root: int, + op = ReduceOp.SUM, + ) -> Work: ... + @overload + def allgather( + self, + output_tensors: List[List[Tensor]], + input_tensors: List[Tensor], + opts = AllGatherOptions(), + ) -> Work: ... + @overload + def allgather( + self, + output_tensors: List[Tensor], + input_tensor: Tensor, + ) -> Work: ... + def allgather_coalesced( + self, + output_lists: List[List[Tensor]], + input_list: List[Tensor], + opts = AllGatherOptions(), + ) -> Work: ... + @overload + def gather( + self, + output_tensors: List[List[Tensor]], + input_tensors: List[Tensor], + opts = GatherOptions(), + ) -> Work: ... + @overload + def gather( + self, + output_tensors: List[Tensor], + input_tensor: Tensor, + root: int, + ) -> Work: ... + @overload + def scatter( + self, + output_tensors: List[Tensor], + input_tensors: List[List[Tensor]], + opts = ScatterOptions(), + ) -> Work: ... + @overload + def scatter( + self, + output_tensor: Tensor, + input_tensors: List[Tensor], + root: int, + ) -> Work: ... + @overload + def reduce_scatter( + self, + output_tensors: List[Tensor], + input_tensors: List[List[Tensor]], + opts = ReduceScatterOptions(), + ) -> Work: ... + @overload + def reduce_scatter( + self, + output_tensors: Tensor, + input_tensor: List[Tensor], + ) -> Work: ... + @overload + def alltoall_base( + self, + output_tensor: Tensor, + input_tensor: Tensor, + output_split_sizes: List[int], + input_split_sizes: List[int], + opts = AllToAllOptions(), + ) -> Work: ... + @overload + def alltoall_base( + self, + output: Tensor, + input: Tensor, + output_split_sizes: List[int], + input_split_sizes: List[int], + ) -> Work: ... + @overload + def alltoall( + self, + output_tensor: List[Tensor], + input_tensor: List[Tensor], + opts = AllToAllOptions(), + ) -> Work: ... + @overload + def alltoall( + self, + output: List[Tensor], + input: List[Tensor], + ) -> Work: ... + def send( + self, + tensors: List[Tensor], + dstRank: int, + tag: int, + ) -> Work: ... + def recv( + self, + tensors: List[Tensor], + srcRank: int, + tag: int, + ) -> Work: ... + def recv_anysource( + self, + tensors: List[Tensor], + tag: int + ) -> Work: ... + def barrier( + self, + opts = BarrierOptions() + ) -> Work: ... + +class ProcessGroupRoundRobin(ProcessGroup): ... +def _round_robin_process_groups( + process_groups: List[ProcessGroup], +) -> ProcessGroupRoundRobin: ... + + +class ProcessGroupGloo(ProcessGroup): + class Device: ... + def __init__( + self, + store: Store, + rank: int, + size: int, + timeout: timedelta, + ): ... + @staticmethod + def create_device(hostname = str(), interface = str()) -> Device: ... + ... + +class ProcessGroupNCCL(ProcessGroup): + def __init__( + self, + store: Store, + rank: int, + size: int, + timeout: timedelta, + ): ... + @staticmethod + def _group_start() -> None: ... + @staticmethod + def _group_end() -> None: ... + ... + +class ProcessGroupMPI(ProcessGroup): + def __init__( + self, + rank: int, + size: int, + pgComm: int, + ): ... + @staticmethod + def create(ranks: List[int]) -> ProcessGroupMPI: ... + +def _compute_bucket_assignment_by_size( + tensors: List[Tensor], + bucket_size: int, + expect_sparse_gradient: List[bool], + tensor_indices: List[int]) -> List[List[int]]: ... +def _broadcast_coalesced( + process_group: ProcessGroup, + tensors: List[Tensor], + buffer_size: int, + src: int, +): ... +def _test_python_store(store: Store): ... diff --git a/torch/_C/_distributed_rpc.pyi b/torch/_C/_distributed_rpc.pyi new file mode 100644 index 0000000000000..0df9df60aeb2b --- /dev/null +++ b/torch/_C/_distributed_rpc.pyi @@ -0,0 +1,194 @@ +from typing import Tuple, Dict, Optional, List, Any, overload +from datetime import timedelta +import enum +import torch +from . import Future +from ._autograd import ProfilerConfig, ProfilerState, ProfilerEvent +from ._distributed_c10d import ProcessGroup, Store + +# This module is defined in torch/csrc/distributed/rpc/init.cpp + +_DEFAULT_NUM_SEND_RECV_THREADS: int +_DEFAULT_INIT_METHOD: str +_DEFAULT_NUM_WORKER_THREADS: int +_UNSET_RPC_TIMEOUT: float +_DEFAULT_RPC_TIMEOUT_SEC: float + +class RpcBackendOptions: + rpc_timeout: float + init_method: str + def __init__( + self, + rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC, + init_method: str = _DEFAULT_INIT_METHOD, + ): ... + +class WorkerInfo: + def __init__(self, name: str, worker_id: int): ... + @property + def name(self) -> str: ... + @property + def id(self) -> int: ... + def __eq__(self, other: object) -> bool: ... + def __repr__(self) -> str: ... + +class RpcAgent: + def join(self): ... + def sync(self): ... + def shutdown(self): ... + @overload + def get_worker_info(self) -> WorkerInfo: ... + @overload + def get_worker_info(self, workerName: str) -> WorkerInfo: ... + def get_worker_infos(self) -> List[WorkerInfo]: ... + def get_debug_info(self) -> Dict[str, str]: ... + def get_metrics(self) -> Dict[str, str]: ... + +class PyRRef: + def __init__(self, value: Any, type_hint: Any = None): ... + def is_owner(self) -> bool: ... + def confirmed_by_owner(self) -> bool: ... + def owner(self) -> WorkerInfo: ... + def owner_name(self) -> str: ... + def to_here(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ... + def local_value(self) -> Any: ... + def rpc_sync(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ... + def rpc_async(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ... + def remote(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ... + def _serialize(self) -> Tuple: ... + @staticmethod + def _deserialize(tp: Tuple) -> 'PyRRef': ... + def _get_type(self) -> Any: ... + def _get_future(self) -> Future: ... + def _get_profiling_future(self) -> Future: ... + def _set_profiling_future(self, profilingFuture: Future): ... + def __repr__(self) -> str: ... + ... + +class ProcessGroupRpcBackendOptions(RpcBackendOptions): + num_send_recv_threads: int + def __init__( + self, + num_send_recv_threads: int, + rpc_timeout: float, + init_method: str + ): ... + +class ProcessGroupAgent(RpcAgent): + def __init__( + self, + worker_name: str, + pg: ProcessGroup, + numSendRecvThreads: int, + rpcTimeout: timedelta + ): ... + @overload + def get_worker_info(self) -> WorkerInfo: ... + @overload + def get_worker_info(self, workerName: str) -> WorkerInfo: ... + @overload + def get_worker_info(self, id: int) -> WorkerInfo: ... + def get_worker_infos(self) -> List[WorkerInfo]: ... + def join(self): ... + def shutdown(self): ... + def sync(self): ... + +class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions): + num_worker_threads: int + device_maps: Dict[str, Dict[int, int]] + def __init__( + self, + num_worker_threads: int, + _transports: Optional[List], + _channels: Optional[List], + rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC, + init_method: str = _DEFAULT_INIT_METHOD, + device_maps: Dict[str, Dict[int, int]] = dict()): ... + def set_device_map(self, to: str, device_map: Dict[str, Dict[int, int]]): ... + +class TensorPipeAgent(RpcAgent): + def __init__( + self, + store: Store, + name: str, + worker_id: int, + world_size: int, + pg: ProcessGroup, + opts: _TensorPipeRpcBackendOptionsBase, + ): ... + def join(self): ... + def shutdown(self): ... + @overload + def get_worker_info(self) -> WorkerInfo: ... + @overload + def get_worker_info(self, workerName: str) -> WorkerInfo: ... + @overload + def get_worker_info(self, id: int) -> WorkerInfo: ... + def get_worker_infos(self) -> List[WorkerInfo]: ... + def _set_reverse_device_maps(self, reverseDeviceMaps: Dict[str, Dict[int, int]]): ... + +def _is_current_rpc_agent_set() -> bool: ... +def _get_current_rpc_agent()-> RpcAgent: ... +def _set_and_start_rpc_agent(agent: RpcAgent): ... +def _reset_current_rpc_agent(): ... +def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ... +def _destroy_rref_context(ignoreRRefLeak: bool): ... +def _rref_context_get_debug_info() -> Dict[str, str]: ... +def _cleanup_python_rpc_handler(): ... +def _invoke_rpc_builtin( + dst: WorkerInfo, + opName: str, + rpcTimeoutSeconds: float, + *args: Any, + **kwargs: Any + ): ... +def _invoke_rpc_python_udf( + dst: WorkerInfo, + pickledPythonUDF: str, + tensors: List[torch.Tensor], + rpcTimeoutSeconds: float, + isAsyncExecution: bool + ): ... +def _invoke_rpc_torchscript( + dstWorkerName: str, + qualifiedNameStr: str, + argsTuple: Tuple, + kwargsDict: Dict, + rpcTimeoutSeconds: float, + isAsyncExecution: bool, + ): ... +def _invoke_remote_builtin( + dst: WorkerInfo, + opName: str, + rpcTimeoutSeconds: float, + *args: Any, + **kwargs: Any + ): ... +def _invoke_remote_python_udf( + dst: WorkerInfo, + pickledPythonUDF: str, + tensors: List[torch.Tensor], + rpcTimeoutSeconds: float, + isAsyncExecution: bool, + ): ... +def _invoke_remote_torchscript( + dstWorkerName: WorkerInfo, + qualifiedNameStr: str, + rpcTimeoutSeconds: float, + isAsyncExecution: bool, + *args: Any, + **kwargs: Any + ): ... +def get_rpc_timeout() -> float: ... +def enable_gil_profiling(flag: bool): ... +def _set_rpc_timeout(rpcTimeoutSeconds: float): ... + +class RemoteProfilerManager: + @staticmethod + def set_current_profiling_key(key: str): ... + +def _enable_server_process_global_profiler(new_config: ProfilerConfig): ... +def _disable_server_process_global_profiler() -> List[List[List[ProfilerEvent]]]: ... +def _set_profiler_node_id(default_node_id: int): ... +def _enable_jit_rref_pickle(): ... +def _disable_jit_rref_pickle(): ... diff --git a/torch/_C/_distributed_rpc_testing.pyi b/torch/_C/_distributed_rpc_testing.pyi new file mode 100644 index 0000000000000..f5648138f767b --- /dev/null +++ b/torch/_C/_distributed_rpc_testing.pyi @@ -0,0 +1,42 @@ +from ._distributed_c10d import ProcessGroup +from ._distributed_rpc import ProcessGroupAgent, ProcessGroupRpcBackendOptions, WorkerInfo +from typing import List, Dict, overload +from datetime import timedelta + +# This module is defined in torch/csrc/distributed/rpc/testing/init.cpp + +class FaultyProcessGroupRpcBackendOptions(ProcessGroupRpcBackendOptions): + def __init__( + self, + num_send_recv_threads: int, + rpc_timeout: float, + init_method: str, + messages_to_fail: List[str], + messages_to_delay: Dict[str, float], + num_fail_sends: int + ): ... + num_send_recv_threads: int + messages_to_fail: List[str] + messages_to_delay: Dict[str, float] + num_fail_sends: int + +class FaultyProcessGroupAgent(ProcessGroupAgent): + def __init__( + self, + name: str, + process_group: ProcessGroup, + num_send_recv_threads: int, + rpc_timeout: timedelta, + messages_to_fail: List[str], + messages_to_delay: Dict[str, float], + num_fail_sends: int + ): ... + def join(self): ... + def shutdown(self): ... + @overload + def get_worker_info(self) -> WorkerInfo: ... + @overload + def get_worker_info(self, workerName: str) -> WorkerInfo: ... + @overload + def get_worker_info(self, id: int) -> WorkerInfo: ... + def get_worker_infos(self) -> List[WorkerInfo]: ... diff --git a/torch/_C/_functions.pyi b/torch/_C/_functions.pyi new file mode 100644 index 0000000000000..bbc6606722f3f --- /dev/null +++ b/torch/_C/_functions.pyi @@ -0,0 +1,12 @@ +from torch import Tensor +from typing import AnyStr, List + +class UndefinedGrad: + def __init__(self) -> None: ... + def __call__(self, *inputs: Tensor) -> List[Tensor]: ... + ... + +class DelayedError: + def __init__(self, msg: AnyStr, num_inputs: int) -> None: ... + def __call__(self, inputs: List[Tensor]) -> List[Tensor]: ... + ... \ No newline at end of file diff --git a/torch/_C/_nn.pyi.in b/torch/_C/_nn.pyi.in index ee4c130f540bb..f465b16cfbeec 100644 --- a/torch/_C/_nn.pyi.in +++ b/torch/_C/_nn.pyi.in @@ -1,5 +1,6 @@ -from torch import Tensor -from typing import Callable, Optional, List +from torch import Tensor, memory_format +from typing import Callable, Optional, List, overload, Tuple +from torch.types import _bool, _dtype, _device # Defined in tools/autograd/templates/python_nn_functions.cpp @@ -10,4 +11,15 @@ def mkldnn_linear(input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tens # Defined at aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp def mkldnn_reorder_conv2d_weight(self: Tensor, padding: List, stride: List, dilatation: List, groups: int) -> Tensor: ... -def mkldnn_reorder_conv3d_weight(self: Tensor, padding: List, stride: List, dilatation: List, groups: int) -> Tensor: ... \ No newline at end of file +def mkldnn_reorder_conv3d_weight(self: Tensor, padding: List, stride: List, dilatation: List, groups: int) -> Tensor: ... + +# Defined at tools/autograd/templates/python_nn_functions.cpp +@overload +def _parse_to(device: _device, dtype: _dtype, non_blocking: _bool, copy: _bool, *, + memory_format: memory_format) -> Tuple[_device, _dtype, _bool, memory_format]: ... +@overload +def _parse_to(dtype: _dtype, non_blocking: _bool, copy: _bool, *, + memory_format: memory_format) -> Tuple[_device, _dtype, _bool, memory_format]: ... +@overload +def _parse_to(tensor: Tensor, non_blocking: _bool, copy: _bool, *, + memory_format: memory_format) -> Tuple[_device, _dtype, _bool, memory_format]: ... \ No newline at end of file diff --git a/torch/_C/_onnx.pyi b/torch/_C/_onnx.pyi index 51f16566ce6c9..7ab3cd9c567d2 100644 --- a/torch/_C/_onnx.pyi +++ b/torch/_C/_onnx.pyi @@ -29,6 +29,7 @@ class OperatorExportTypes(Enum): ONNX_ATEN = ... ONNX_ATEN_FALLBACK = ... RAW = ... + ONNX_FALLTHROUGH = ... class TrainingMode(Enum): EVAL = ... diff --git a/torch/__config__.py b/torch/__config__.py index e4c3fde9ec3ce..edddcbce46459 100644 --- a/torch/__config__.py +++ b/torch/__config__.py @@ -9,8 +9,11 @@ def show(): return torch._C._show_config() # TODO: In principle, we could provide more structured version/config -# information here. We're not for now; considering doing so if someone -# asks for it. +# information here. For now only CXX_FLAGS is exposed, as Timer +# uses them. +def _cxx_flags(): + """Returns the CXX_FLAGS used when building PyTorch.""" + return torch._C._cxx_flags() def parallel_info(): r"""Returns detailed string with parallelization settings""" diff --git a/torch/__init__.py b/torch/__init__.py index da9eecad7df50..9ae1010a3ba87 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -326,14 +326,16 @@ def set_default_dtype(d): _C._set_default_dtype(d) def set_deterministic(d): - r""" Sets whether native PyTorch operations must use deterministic - algorithms. When True, operations without deterministic algorithms - will throw a :class:RuntimeError when called. + r""" Sets whether PyTorch operations must use "deterministic" + algorithms. That is, algorithms which, given the same input, and when + run on the same software and hardware, always produce the same output. + When True, operations will use deterministic algorithms when available, + and if only nondeterministic algorithms are available they will throw a + :class:RuntimeError when called. .. warning:: - This feature is a beta feature, so it does not affect every - nondeterministic operation yet. The following operations are - affected by this flag. + This feature is in beta, and its design and implementation may change + in the future. The following normally-nondeterministic operations will act deterministically when `d=True`: @@ -357,11 +359,13 @@ def set_deterministic(d): * :class:`torch.nn.FractionalMaxPool2d` when called on a CUDA tensor that requires grad * :class:`torch.nn.FractionalMaxPool3d` when called on a CUDA tensor that requires grad * :func:`torch.nn.functional.interpolate` when called on a CUDA tensor that requires grad - and one of the following modes is used: - - `linear` - - `bilinear` - - `bicubic` - - `trilinear` + and one of the following modes is used: + + - `linear` + - `bilinear` + - `bicubic` + - `trilinear` + * :class:`torch.nn.ReflectionPad1d` when called on a CUDA tensor that requires grad * :class:`torch.nn.ReflectionPad2d` when called on a CUDA tensor that requires grad * :class:`torch.nn.ReplicationPad1d` when called on a CUDA tensor that requires grad @@ -372,10 +376,13 @@ def set_deterministic(d): * :class:`torch.nn.EmbeddingBag` when called on a CUDA tensor that requires grad * :func:`torch.scatter_add_` when called on a CUDA tensor * :func:`torch.index_add_` when called on a CUDA tensor + * :func:`torch.index_copy` * :func:`torch.index_select` when called on a CUDA tensor that requires grad * :func:`torch.repeat_interleave` when called on a CUDA tensor that requires grad * :func:`torch.histc` when called on a CUDA tensor * :func:`torch.bincount` when called on a CUDA tensor + * :func:`torch.kthvalue` with called on a CUDA tensor + * :func:`torch.median` with indices output when called on a CUDA tensor A handful of CUDA operations are nondeterministic if the CUDA version is 10.2 or greater, unless the environment variable `CUBLAS_WORKSPACE_CONFIG=:4096:8` @@ -465,11 +472,13 @@ class QInt8Storage(_C.QInt8StorageBase, _StorageBase): class QInt32Storage(_C.QInt32StorageBase, _StorageBase): pass +class QUInt4x2Storage(_C.QUInt4x2StorageBase, _StorageBase): + pass _storage_classes = { DoubleStorage, FloatStorage, LongStorage, IntStorage, ShortStorage, CharStorage, ByteStorage, HalfStorage, BoolStorage, QUInt8Storage, QInt8Storage, - QInt32Storage, BFloat16Storage, ComplexFloatStorage, ComplexDoubleStorage + QInt32Storage, BFloat16Storage, ComplexFloatStorage, ComplexDoubleStorage, QUInt4x2Storage } # The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings() @@ -538,6 +547,21 @@ def manager_path(): del BFloat16StorageBase del ComplexDoubleStorageBase del ComplexFloatStorageBase +del QUInt4x2StorageBase + +################################################################################ +# Define _assert +################################################################################ + +# needs to be before the submodule imports to avoid circular dependencies +def _assert(condition, message): + r"""A wrapper around Python's assert which is symbolically traceable. + """ + from .overrides import has_torch_function, handle_torch_function + + if type(condition) is not torch.Tensor and has_torch_function((condition,)): + return handle_torch_function(_assert, (condition,), condition, message) + assert condition, message ################################################################################ # Import most common subpackages @@ -546,12 +570,14 @@ def manager_path(): import torch.cuda import torch.autograd from torch.autograd import no_grad, enable_grad, set_grad_enabled -# import torch.fft # TODO: enable once torch.fft() is removed +import torch.fft import torch.futures import torch.nn import torch.nn.intrinsic +import torch.nn.quantizable import torch.nn.quantized import torch.optim +import torch.optim._multi_tensor import torch.multiprocessing import torch.sparse import torch.utils.backcompat @@ -571,6 +597,7 @@ def manager_path(): import torch.utils.data import torch.__config__ import torch.__future__ +import torch.profiler _C._init_names(list(torch._storage_classes)) diff --git a/torch/_autograd_functions.py b/torch/_autograd_functions.py new file mode 100644 index 0000000000000..2ea67aa38e65e --- /dev/null +++ b/torch/_autograd_functions.py @@ -0,0 +1,79 @@ +import torch + +class _LU(torch.autograd.Function): + @staticmethod + def forward(ctx, self, pivot=True, get_infos=False): + LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos)) + ctx.save_for_backward(LU, pivots) + ctx.mark_non_differentiable(pivots, infos) + return LU, pivots, infos + + @staticmethod + def backward(ctx, LU_grad, pivots_grad, infors_grad): + """ + Here we derive the gradients for the LU decomposition. + LIMITATIONS: square inputs of full rank. + If not stated otherwise, for tensors A and B, + `A B` means the matrix product of A and B. + + Forward AD: + Note that PyTorch returns packed LU, it is a mapping + A -> (B:= L + U - I, P), such that A = P L U, and + P is a permutation matrix, and is non-differentiable. + + Using B = L + U - I, A = P L U, we get + + dB = dL + dU and (*) + P^T dA = dL U + L dU (**) + + By left/right multiplication of (**) with L^{-1}/U^{-1} we get: + L^{-1} P^T dA U^{-1} = L^{-1} dL + dU U^{-1}. + + Note that L^{-1} dL is lower-triangular with zero diagonal, + and dU U^{-1} is upper-triangular. + Define 1_U := triu(ones(n, n)), and 1_L := ones(n, n) - 1_U, so + + L^{-1} dL = 1_L * (L^{-1} P^T dA U^{-1}), + dU U^{-1} = 1_U * (L^{-1} P^T dA U^{-1}), where * denotes the Hadamard product. + + Hence we finally get: + dL = L 1_L * (L^{-1} P^T dA U^{-1}), + dU = 1_U * (L^{-1} P^T dA U^{-1}) U + + Backward AD: + The backward sensitivity is then: + Tr(B_grad^T dB) = Tr(B_grad^T dL) + Tr(B_grad^T dU) = [1] + [2]. + + [1] = Tr(B_grad^T dL) = Tr(B_grad^T L 1_L * (L^{-1} P^T dA U^{-1})) + = [using Tr(A (B * C)) = Tr((A * B^T) C)] + = Tr((B_grad^T L * 1_L^T) L^{-1} P^T dA U^{-1}) + = [cyclic property of trace] + = Tr(U^{-1} (B_grad^T L * 1_L^T) L^{-1} P^T dA) + = Tr((P L^{-T} (L^T B_grad * 1_L) U^{-T})^T dA). + Similar, [2] can be rewritten as: + [2] = Tr(P L^{-T} (B_grad U^T * 1_U) U^{-T})^T dA, hence + Tr(A_grad^T dA) = [1] + [2] + = Tr((P L^{-T} (L^T B_grad * 1_L + B_grad U^T * 1_U) U^{-T})^T dA), so + A_grad = P L^{-T} (L^T B_grad * 1_L + B_grad U^T * 1_U) U^{-T}. + + In the code below we use the name `LU` instead of `B`, so that there is no confusion + in the derivation above between the matrix product and a two-letter variable name. + """ + LU, pivots = ctx.saved_tensors + P, L, U = torch.lu_unpack(LU, pivots) + + # To make sure MyPy infers types right + assert (L is not None) and (U is not None) + + I = LU_grad.new_zeros(LU_grad.shape) + I.diagonal(dim1=-2, dim2=-1).fill_(1) + + Lt_inv = torch.triangular_solve(I, L, upper=False).solution.transpose(-1, -2) + Ut_inv = torch.triangular_solve(I, U, upper=True).solution.transpose(-1, -2) + + phi_L = (L.transpose(-1, -2) @ LU_grad).tril_() + phi_L.diagonal(dim1=-2, dim2=-1).fill_(0.0) + phi_U = (LU_grad @ U.transpose(-1, -2)).triu_() + + self_grad_perturbed = Lt_inv @ (phi_L + phi_U) @ Ut_inv + return P @ self_grad_perturbed, None, None diff --git a/torch/_classes.py b/torch/_classes.py index ecf9987c69b94..90cc6c7672f6d 100644 --- a/torch/_classes.py +++ b/torch/_classes.py @@ -40,7 +40,7 @@ def load_library(self, path): ``torch.classes.loaded_libraries`` attribute, a set that may be inspected for the paths of all libraries loaded using this function. - Arguments: + Args: path (str): A path to a shared library to load. """ torch.ops.load_library(path) diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 5fa2ee639a9f0..4ccffc2c83629 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -8,8 +8,10 @@ import collections import enum import inspect +import ast import weakref import warnings +from textwrap import dedent import torch import sys # This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`. @@ -78,7 +80,7 @@ def parseExpr(expr, module): value, len_parsed = parseNestedExpr(expr, module) assert len_parsed == len(expr), "whole expression was not parsed, falling back to c++ parser" return value - except Exception as e: + except Exception: """ The python resolver fails in several cases in known unit tests, and is intended to fall back gracefully to the c++ resolver in general. For example, python 2 style @@ -226,6 +228,96 @@ def can_compile_class(cls): return all(has_code) +def get_annotation_str(annotation): + """ + Convert an AST node containing a type annotation to the string present in the source + that represents the same annotation. + """ + if isinstance(annotation, ast.Name): + return annotation.id + elif isinstance(annotation, ast.Attribute): + return '.'.join([get_annotation_str(annotation.value), annotation.attr]) + elif isinstance(annotation, ast.Subscript): + return f"{get_annotation_str(annotation.value)}[{get_annotation_str(annotation.slice.value)}]" # type: ignore + elif isinstance(annotation, ast.Tuple): + return ','.join([get_annotation_str(elt) for elt in annotation.elts]) + elif isinstance(annotation, ast.Constant) or isinstance(annotation, ast.NameConstant): + return f"{annotation.value}" + + # If an AST node is not handled here, it's probably handled in ScriptTypeParser. + return None + + +def get_type_hint_captures(fn): + """ + Get a dictionary containing type resolution mappings necessary to resolve types + for the literal annotations on 'fn'. These are not considered to be closed-over by fn + and must be obtained separately (e.g. using this function). + + Args: + fn: A callable. + Returns: + A Dict[str, Any] containing a mapping from the literal annotations used on + fn to the Python objects they refer to. + """ + # Gather a dictionary of parameter name -> type, skipping any parameters whose annotated + # types are strings. These are only understood by TorchScript in the context of a type annotation + # that refers to a class in its own definition, but trying to include a mapping for this in the result + # function would cause infinite recursion because the class is currently being compiled. + # In addition, there is logic in ScriptTypeParser to handle this. + signature = inspect.signature(fn) + name_to_type = { + name: parameter.annotation + for name, parameter in signature.parameters.items() + if parameter.annotation is not inspect.Parameter.empty and not isinstance(parameter.annotation, str) + } + + # Then, get the literal type annotations from the function declaration + # by source inspection. This accounts for the case in which aliases are used + # to annotate the arguments (e.g device_t = torch.device, and then d: device_t). + src = inspect.getsource(fn) + + # frontend.py cannot be used here because it includes _jit_internal, so use ast instead. + a = ast.parse(dedent(src)) + if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef): + raise RuntimeError(f"Expected {fn} to be a function") + f = a.body[0] + + # Prepare a dictionary of source annotation -> type, which will be the final result of this function, + # by using the parsed AST (f) to reconstruct source annotations as strings for each parameter and mapping + # them to the type object corresponding to the annotation via name_to_type using the parameter name. + annotation_to_type = {} + + for arg in f.args.args: + # Get the source type annotation string for this argument if possible. + arg_annotation_str = get_annotation_str(arg.annotation) if arg.annotation else None + + # If the argument has no annotation or get_annotation_str cannot convert it to a string, + # arg_annotation_str will be None. Skip this arg; ScriptTypeParser will probably handle + # this in the latter case. + if arg_annotation_str is None: + continue + + # Insert {arg_annotation_str: type} into annotation_to_type if possible. One reason arg_name may not + # be present in name_to_type is that the annotation itself is a string and not a type object + # (common for self-refential annotations in classes). Once again, let ScriptTypeParser handle this. + arg_name = arg.arg + if arg_name in name_to_type: + annotation_to_type[arg_annotation_str] = name_to_type[arg_name] + + # If there is a valid return annotation, include it in annotation_to_type. As with argument annotations, + # the literal annotation has to be convertible to a string by get_annotation_str, and the actual type + # of the annotation cannot be a string. + literal_return_annotation = get_annotation_str(f.returns) + valid_literal_annotation = literal_return_annotation is not None + return_annotation = signature.return_annotation + valid_return_annotation_type = return_annotation is not inspect.Parameter.empty and not isinstance(return_annotation, str) + if valid_literal_annotation and valid_return_annotation_type: + annotation_to_type[literal_return_annotation] = return_annotation + + return annotation_to_type + + def createResolutionCallbackForClassMethods(cls): """ This looks at all the methods defined in a class and pulls their closed-over @@ -238,6 +330,7 @@ def createResolutionCallbackForClassMethods(cls): for fn in fns: captures.update(get_closure(fn)) + captures.update(get_type_hint_captures(fn)) def lookup_in_class(key): if key in captures: @@ -366,9 +459,9 @@ def unused(fn): import torch.nn as nn class MyModule(nn.Module): - def __init__(self, use_memory_efficent): + def __init__(self, use_memory_efficient): super(MyModule, self).__init__() - self.use_memory_efficent = use_memory_efficent + self.use_memory_efficient = use_memory_efficient @torch.jit.unused def memory_efficient(self, x): @@ -383,13 +476,22 @@ def forward(self, x): else: return x + 10 - m = torch.jit.script(MyModule(use_memory_efficent=False)) + m = torch.jit.script(MyModule(use_memory_efficient=False)) m.save("m.pt") m = torch.jit.script(MyModule(use_memory_efficient=True)) # exception raised m(torch.rand(100)) """ + if isinstance(fn, property): + prop = fn + setattr(prop.fget, "_torchscript_modifier", FunctionModifiers.UNUSED) # noqa: B010 + + if prop.fset: + setattr(prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED) # noqa: B010 + + return prop + fn._torchscript_modifier = FunctionModifiers.UNUSED return fn @@ -628,11 +730,7 @@ def _get_overloaded_methods(method, mod_class): def is_tuple(ann): if ann is Tuple: - raise RuntimeError( - "Attempted to use Tuple without a " - "contained type. Please add a contained type, e.g. " - "Tuple[int]" - ) + raise_error_container_parameter_missing("Tuple") # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule if not hasattr(ann, '__module__'): @@ -643,11 +741,7 @@ def is_tuple(ann): def is_list(ann): if ann is List: - raise RuntimeError( - "Attempted to use List without a " - "contained type. Please add a contained type, e.g. " - "List[int]" - ) + raise_error_container_parameter_missing("List") if not hasattr(ann, '__module__'): return False @@ -657,11 +751,7 @@ def is_list(ann): def is_dict(ann): if ann is Dict: - raise RuntimeError( - "Attempted to use Dict without " - "contained types. Please add contained type, e.g. " - "Dict[int, int]" - ) + raise_error_container_parameter_missing("Dict") if not hasattr(ann, '__module__'): return False @@ -671,11 +761,7 @@ def is_dict(ann): def is_optional(ann): if ann is Optional: - raise RuntimeError( - "Attempted to use Optional without a " - "contained type. Please add a contained type, e.g. " - "Optional[int]" - ) + raise_error_container_parameter_missing("Optional") # Optional[T] is just shorthand for Union[T, None], so check for both def safe_is_subclass(the_type, super_type): @@ -724,7 +810,7 @@ def is_rref(ann): def is_final(ann): return ann.__module__ in {'typing', 'typing_extensions'} and \ - (getattr(ann, '__origin__', None) is Final) + (getattr(ann, '__origin__', None) is Final or isinstance(ann, type(Final))) # allows BroadcastingList instance to be subscriptable class BroadcastingListCls(object): @@ -846,7 +932,7 @@ def _get_named_tuple_properties(obj): the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range()) annotations.append(the_type) else: - annotations.append(torch._C.TensorType.get()) + annotations.append(torch._C.TensorType.getInferred()) return type(obj).__name__, fields, annotations @@ -876,3 +962,111 @@ def _is_exception(obj): if not inspect.isclass(obj): return False return issubclass(obj, Exception) + +def raise_error_container_parameter_missing(target_type): + if target_type == 'Dict': + raise RuntimeError( + "Attempted to use Dict without " + "contained types. Please add contained type, e.g. " + "Dict[int, int]" + ) + raise RuntimeError( + f"Attempted to use {target_type} without a " + "contained type. Please add a contained type, e.g. " + f"{target_type}[int]" + ) + + +def get_origin(target_type): + return getattr(target_type, "__origin__", None) + + +def get_args(target_type): + return getattr(target_type, "__args__", None) + + +def check_args_exist(target_type): + if target_type is List or target_type is list: + raise_error_container_parameter_missing("List") + elif target_type is Tuple or target_type is tuple: + raise_error_container_parameter_missing("Tuple") + elif target_type is Dict or target_type is dict: + raise_error_container_parameter_missing("Dict") + elif target_type is None or target_type is Optional: + raise_error_container_parameter_missing("Optional") + + +# supports List/Dict/Tuple and Optional types +# TODO support future +def container_checker(obj, target_type): + origin_type = get_origin(target_type) + check_args_exist(target_type) + if origin_type is list or origin_type is List: + if not isinstance(obj, list): + return False + arg_type = get_args(target_type)[0] + arg_origin = get_origin(arg_type) + for el in obj: + # check if nested container, ex: List[List[str]] + if arg_origin: # processes nested container, ex: List[List[str]] + if not container_checker(el, arg_type): + return False + elif not isinstance(el, arg_type): + return False + return True + elif origin_type is Dict or origin_type is dict: + if not isinstance(obj, dict): + return False + key_type = get_args(target_type)[0] + val_type = get_args(target_type)[1] + for key, val in obj.items(): + # check if keys are of right type + if not isinstance(key, key_type): + return False + val_origin = get_origin(val_type) + if val_origin: + if not container_checker(val, val_type): + return False + elif not isinstance(val, val_type): + return False + return True + elif origin_type is Tuple or origin_type is tuple: + if not isinstance(obj, tuple): + return False + arg_types = get_args(target_type) + if len(obj) != len(arg_types): + return False + for el, el_type in zip(obj, arg_types): + el_origin = get_origin(el_type) + if el_origin: + if not container_checker(el, el_type): + return False + elif not isinstance(el, el_type): + return False + return True + elif origin_type is Union: # actually handles Optional Case + if obj is None: # check before recursion because None is always fine + return True + optional_type = get_args(target_type)[0] + optional_origin = get_origin(optional_type) + if optional_origin: + return container_checker(obj, optional_type) + elif isinstance(obj, optional_type): + return True + return False + + +def _isinstance(obj, target_type) -> bool: + origin_type = get_origin(target_type) + if origin_type: + return container_checker(obj, target_type) + + # Check to handle weird python type behaviors + # 1. python 3.6 returns None for origin of containers without + # contained type (intead of returning outer container type) + # 2. non-typed optional origin returns as none instead + # of as optional in 3.6-3.8 + check_args_exist(target_type) + + # handle non-containers + return isinstance(obj, target_type) diff --git a/torch/_linalg_utils.py b/torch/_linalg_utils.py index d86c944d0b350..5237d4371f418 100644 --- a/torch/_linalg_utils.py +++ b/torch/_linalg_utils.py @@ -2,10 +2,10 @@ """ -from typing import Optional, Tuple - -import torch from torch import Tensor +import torch + +from typing import Optional, Tuple def is_sparse(A): @@ -29,8 +29,7 @@ def get_floating_dtype(A): return torch.float32 -def matmul(A, B): - # type: (Optional[Tensor], Tensor) -> Tensor +def matmul(A: Optional[Tensor], B: Tensor) -> Tensor: """Multiply two matrices. If A is None, return B. A can be sparse or dense. B is always @@ -66,15 +65,13 @@ def transjugate(A): return conjugate(transpose(A)) -def bform(X, A, Y): - # type: (Tensor, Optional[Tensor], Tensor) -> Tensor +def bform(X: Tensor, A: Optional[Tensor], Y: Tensor) -> Tensor: """Return bilinear form of matrices: :math:`X^T A Y`. """ return matmul(transpose(X), matmul(A, Y)) -def qform(A, S): - # type: (Optional[Tensor], Tensor) -> Tensor +def qform(A: Optional[Tensor], S: Tensor): """Return quadratic form :math:`S^T A S`. """ return bform(S, A, S) @@ -91,8 +88,7 @@ def basis(A): return Q -def symeig(A, largest=False, eigenvectors=True): - # type: (Tensor, Optional[bool], Optional[bool]) -> Tuple[Tensor, Tensor] +def symeig(A: Tensor, largest: Optional[bool] = False, eigenvectors: Optional[bool] = True) -> Tuple[Tensor, Tensor]: """Return eigenpairs of A with specified ordering. """ if largest is None: diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py index b0cbf45b252b1..dfe52774f8ca9 100644 --- a/torch/_lobpcg.py +++ b/torch/_lobpcg.py @@ -13,23 +13,343 @@ __all__ = ['lobpcg'] +def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U): + # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0 + F = D.unsqueeze(-2) - D.unsqueeze(-1) + F.diagonal(dim1=-2, dim2=-1).fill_(float('inf')) + F.pow_(-1) + + # A.grad = U (D.grad + (U^T U.grad * F)) U^T + Ut = U.transpose(-1, -2).contiguous() + res = torch.matmul( + U, + torch.matmul( + torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F, + Ut + ) + ) + + return res + + +def _polynomial_coefficients_given_roots(roots): + """ + Given the `roots` of a polynomial, find the polynomial's coefficients. + + If roots = (r_1, ..., r_n), then the method returns + coefficients (a_0, a_1, ..., a_n (== 1)) so that + p(x) = (x - r_1) * ... * (x - r_n) + = x^n + a_{n-1} * x^{n-1} + ... a_1 * x_1 + a_0 + + Note: for better performance requires writing a low-level kernel + """ + poly_order = roots.shape[-1] + poly_coeffs_shape = list(roots.shape) + # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0, + # so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)}, + # but we insert one extra coefficient to enable better vectorization below + poly_coeffs_shape[-1] += 2 + poly_coeffs = roots.new_zeros(poly_coeffs_shape) + poly_coeffs[..., 0] = 1 + poly_coeffs[..., -1] = 1 + + # perform the Horner's rule + for i in range(1, poly_order + 1): + # note that it is computationally hard to compute backward for this method, + # because then given the coefficients it would require finding the roots and/or + # calculating the sensitivity based on the Vieta's theorem. + # So the code below tries to circumvent the explicit root finding by series + # of operations on memory copies imitating the Horner's method. + # The memory copies are required to construct nodes in the computational graph + # by exploting the explicit (not in-place, separate node for each step) + # recursion of the Horner's method. + # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity. + poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs + out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1) + out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow(-1, poly_order - i + 1, i + 1) + poly_coeffs = poly_coeffs_new + + return poly_coeffs.narrow(-1, 1, poly_order + 1) + + +def _polynomial_value(poly, x, zero_power, transition): + """ + A generic method for computing poly(x) using the Horner's rule. + + Args: + poly (Tensor): the (possibly batched) 1D Tensor representing + polynomial coefficients such that + poly[..., i] = (a_{i_0}, ..., a{i_n} (==1)), and + poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n + + x (Tensor): the value (possible batched) to evalate the polynomial `poly` at. + + zero_power (Tensor): the represenation of `x^0`. It is application-specific. + + transition (Callable): the function that accepts some intermediate result `int_val`, + the `x` and a specific polynomial coefficient + `poly[..., k]` for some iteration `k`. + It basically performs one iteration of the Horner's rule + defined as `x * int_val + poly[..., k] * zero_power`. + Note that `zero_power` is not a parameter, + because the step `+ poly[..., k] * zero_power` depends on `x`, + whether it is a vector, a matrix, or something else, so this + functionality is delegated to the user. + """ + + res = zero_power.clone() + for k in range(poly.size(-1) - 2, -1, -1): + res = transition(res, x, poly[..., k]) + return res + +def _matrix_polynomial_value(poly, x, zero_power=None): + """ + Evaluates `poly(x)` for the (batched) matrix input `x`. + Check out `_polynomial_value` function for more details. + """ -def lobpcg(A, # type: Tensor - k=None, # type: Optional[int] - B=None, # type: Optional[Tensor] - X=None, # type: Optional[Tensor] - n=None, # type: Optional[int] - iK=None, # type: Optional[Tensor] - niter=None, # type: Optional[int] - tol=None, # type: Optional[float] - largest=None, # type: Optional[bool] - method=None, # type: Optional[str] - tracker=None, # type: Optional[None] - ortho_iparams=None, # type: Optional[Dict[str, int]] - ortho_fparams=None, # type: Optional[Dict[str, float]] - ortho_bparams=None, # type: Optional[Dict[str, bool]] - ): - # type: (...) -> Tuple[Tensor, Tensor] + # matrix-aware Horner's rule iteration + def transition(curr_poly_val, x, poly_coeff): + res = x.matmul(curr_poly_val) + res.diagonal(dim1=-2, dim2=-1).add_(poly_coeff.unsqueeze(-1)) + return res + + if zero_power is None: + zero_power = torch.eye(x.size(-1), x.size(-1), dtype=x.dtype, device=x.device) \ + .view(*([1] * len(list(x.shape[:-2]))), x.size(-1), x.size(-1)) + + return _polynomial_value(poly, x, zero_power, transition) + +def _vector_polynomial_value(poly, x, zero_power=None): + """ + Evaluates `poly(x)` for the (batched) vector input `x`. + Check out `_polynomial_value` function for more details. + """ + + # vector-aware Horner's rule iteration + def transition(curr_poly_val, x, poly_coeff): + res = torch.addcmul(poly_coeff.unsqueeze(-1), x, curr_poly_val) + return res + + if zero_power is None: + zero_power = x.new_ones(1).expand(x.shape) + + return _polynomial_value(poly, x, zero_power, transition) + +def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest): + # compute a projection operator onto an orthogonal subspace spanned by the + # columns of U defined as (I - UU^T) + Ut = U.transpose(-2, -1).contiguous() + proj_U_ortho = -U.matmul(Ut) + proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1) + + # compute U_ortho, a basis for the orthogonal complement to the span(U), + # by projecting a random [..., m, m - k] matrix onto the subspace spanned + # by the columns of U. + # + # fix generator for determinism + gen = torch.Generator(A.device) + + # orthogonal complement to the span(U) + U_ortho = proj_U_ortho.matmul( + torch.randn( + (*A.shape[:-1], A.size(-1) - D.size(-1)), + dtype=A.dtype, + device=A.device, + generator=gen + ) + ) + U_ortho_t = U_ortho.transpose(-2, -1).contiguous() + + # compute the coefficients of the characteristic polynomial of the tensor D. + # Note that D is diagonal, so the diagonal elements are exactly the roots + # of the characteristic polynomial. + chr_poly_D = _polynomial_coefficients_given_roots(D) + + # the code belows finds the explicit solution to the Sylvester equation + # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U + # and incorporates it into the whole gradient stored in the `res` variable. + # + # Equivalent to the following naive implementation: + # res = A.new_zeros(A.shape) + # p_res = A.new_zeros(*A.shape[:-1], D.size(-1)) + # for k in range(1, chr_poly_D.size(-1)): + # p_res.zero_() + # for i in range(0, k): + # p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2) + # res -= chr_poly_D[k] * (U_ortho @ poly_D_at_A.inverse() @ U_ortho_t @ p_res @ U.t()) + # + # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity + # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g, + # and we need to compute g(U_grad, A, U, D) + # + # The naive implementation is based on the paper + # Hu, Qingxi, and Daizhan Cheng. + # "The polynomial solution to the Sylvester matrix equation." + # Applied mathematics letters 19.9 (2006): 859-864. + # + # We can modify the computation of `p_res` from above in a more efficient way + # p_res = U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k)).unsqueeze(-2) + # + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2) + # + ... + # + A.matrix_power(k - 1) U_grad * chr_poly_D[k] + # Note that this saves us from redundant matrix products with A (elimination of matrix_power) + U_grad_projected = U_grad + series_acc = U_grad_projected.new_zeros(U_grad_projected.shape) + for k in range(1, chr_poly_D.size(-1)): + poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D) + series_acc += U_grad_projected * poly_D.unsqueeze(-2) + U_grad_projected = A.matmul(U_grad_projected) + + # compute chr_poly_D(A) which essentially is: + # + # chr_poly_D_at_A = A.new_zeros(A.shape) + # for k in range(chr_poly_D.size(-1)): + # chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k) + # + # Note, however, for better performance we use the Horner's rule + chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A) + + # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t + chr_poly_D_at_A_to_U_ortho = torch.matmul( + U_ortho_t, + torch.matmul( + chr_poly_D_at_A, + U_ortho + ) + ) + # we need to invert 'chr_poly_D_at_A_to_U_ortho`, for that we compute its + # Cholesky decomposition and then use `torch.cholesky_solve` for better stability. + # Cholesky decomposition requires the input to be positive-definite. + # Note that `chr_poly_D_at_A_to_U_ortho` is positive-definite if + # 1. `largest` == False, or + # 2. `largest` == True and `k` is even + # under the assumption that `A` has distinct eigenvalues. + # + # check if `chr_poly_D_at_A_to_U_ortho` is positive-definite or negative-definite + chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (k % 2 == 1)) else +1 + chr_poly_D_at_A_to_U_ortho_L = torch.cholesky( + chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho + ) + + # compute the gradient part in span(U) + res = _symeig_backward_complete_eigenspace( + D_grad, U_grad, A, D, U + ) + + # incorporate the Sylvester equation solution into the full gradient + # it resides in span(U_ortho) + res -= U_ortho.matmul( + chr_poly_D_at_A_to_U_ortho_sign * torch.cholesky_solve( + U_ortho_t.matmul(series_acc), + chr_poly_D_at_A_to_U_ortho_L + ) + ).matmul(Ut) + + return res + +def _symeig_backward(D_grad, U_grad, A, D, U, largest): + # if `U` is square, then the columns of `U` is a complete eigenspace + if U.size(-1) == U.size(-2): + return _symeig_backward_complete_eigenspace( + D_grad, U_grad, A, D, U + ) + else: + return _symeig_backward_partial_eigenspace( + D_grad, U_grad, A, D, U, largest + ) + +class LOBPCGAutogradFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, # type: ignore[override] + A: Tensor, + k: Optional[int] = None, + B: Optional[Tensor] = None, + X: Optional[Tensor] = None, + n: Optional[int] = None, + iK: Optional[Tensor] = None, + niter: Optional[int] = None, + tol: Optional[float] = None, + largest: Optional[bool] = None, + method: Optional[str] = None, + tracker: Optional[None] = None, + ortho_iparams: Optional[Dict[str, int]] = None, + ortho_fparams: Optional[Dict[str, float]] = None, + ortho_bparams: Optional[Dict[str, bool]] = None + ) -> Tuple[Tensor, Tensor]: + + # makes sure that input is contiguous for efficiency. + # Note: autograd does not support dense gradients for sparse input yet. + A = A.contiguous() if (not A.is_sparse) else A + if B is not None: + B = B.contiguous() if (not B.is_sparse) else B + + D, U = _lobpcg( + A, k, B, X, + n, iK, niter, tol, largest, method, tracker, + ortho_iparams, ortho_fparams, ortho_bparams + ) + + ctx.save_for_backward(A, B, D, U, largest) + + return D, U + + @staticmethod + def backward(ctx, D_grad, U_grad): + A_grad = B_grad = None + grads = [None] * 14 + + A, B, D, U, largest = ctx.saved_tensors + + # lobpcg.backward has some limitations. Checks for unsupported input + if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]): + raise ValueError( + 'lobpcg.backward does not support sparse input yet.' + 'Note that lobpcg.forward does though.' + ) + if A.dtype in (torch.complex64, torch.complex128) or \ + B is not None and B.dtype in (torch.complex64, torch.complex128): + raise ValueError( + 'lobpcg.backward does not support complex input yet.' + 'Note that lobpcg.forward does though.' + ) + if B is not None: + raise ValueError( + 'lobpcg.backward does not support backward with B != I yet.' + ) + + if largest is None: + largest = True + + # symeig backward + if B is None: + A_grad = _symeig_backward( + D_grad, U_grad, A, D, U, largest + ) + + # A has index 0 + grads[0] = A_grad + # B has index 2 + grads[2] = B_grad + return tuple(grads) + + +def lobpcg(A: Tensor, + k: Optional[int] = None, + B: Optional[Tensor] = None, + X: Optional[Tensor] = None, + n: Optional[int] = None, + iK: Optional[Tensor] = None, + niter: Optional[int] = None, + tol: Optional[float] = None, + largest: Optional[bool] = None, + method: Optional[str] = None, + tracker: Optional[None] = None, + ortho_iparams: Optional[Dict[str, int]] = None, + ortho_fparams: Optional[Dict[str, float]] = None, + ortho_bparams: Optional[Dict[str, bool]] = None + ) -> Tuple[Tensor, Tensor]: """Find the k largest (or smallest) eigenvalues and the corresponding eigenvectors of a symmetric positive defined generalized @@ -53,7 +373,18 @@ def lobpcg(A, # type: Tensor not recommended but there exist cases where the usage of the basic method may be preferred. - Arguments: + .. warning:: The backward method does not support sparse and complex inputs. + It works only when `B` is not provided (i.e. `B == None`). + We are actively working on extensions, and the details of + the algorithms are going to be published promptly. + + .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not. + To make sure that `A.grad` is symmetric, so that `A - t * A.grad` is symmetric + in first-order optimization routines, prior to running `lobpcg` + we do the following symmetrization map: `A -> (A + A.t()) / 2`. + The map is performed only when the `A` requires gradients. + + Args: A (Tensor): the input tensor of size :math:`(*, m, m)` @@ -175,6 +506,51 @@ def lobpcg(A, # type: Tensor ortho_fparams=ortho_fparams, ortho_bparams=ortho_bparams) + if not torch._jit_internal.is_scripting(): + if A.requires_grad or (B is not None and B.requires_grad): + # While it is expected that `A` is symmetric, + # the `A_grad` might be not. Therefore we perform the trick below, + # so that `A_grad` becomes symmetric. + # The symmetrization is important for first-order optimization methods, + # so that (A - alpha * A_grad) is still a symmetric matrix. + # Same holds for `B`. + A_sym = (A + A.transpose(-2, -1)) / 2 + B_sym = (B + B.transpose(-2, -1)) / 2 if (B is not None) else None + + return LOBPCGAutogradFunction.apply( + A_sym, k, B_sym, X, n, iK, niter, tol, largest, + method, tracker, ortho_iparams, ortho_fparams, ortho_bparams + ) + else: + if A.requires_grad or (B is not None and B.requires_grad): + raise RuntimeError( + 'Script and require grads is not supported atm.' + 'If you just want to do the forward, use .detach()' + 'on A and B before calling into lobpcg' + ) + + return _lobpcg( + A, k, B, X, + n, iK, niter, tol, largest, method, tracker, + ortho_iparams, ortho_fparams, ortho_bparams + ) + +def _lobpcg(A: Tensor, + k: Optional[int] = None, + B: Optional[Tensor] = None, + X: Optional[Tensor] = None, + n: Optional[int] = None, + iK: Optional[Tensor] = None, + niter: Optional[int] = None, + tol: Optional[float] = None, + largest: Optional[bool] = None, + method: Optional[str] = None, + tracker: Optional[None] = None, + ortho_iparams: Optional[Dict[str, int]] = None, + ortho_fparams: Optional[Dict[str, float]] = None, + ortho_bparams: Optional[Dict[str, bool]] = None + ) -> Tuple[Tensor, Tensor]: + # A must be square: assert A.shape[-2] == A.shape[-1], A.shape if B is not None: @@ -230,7 +606,7 @@ def lobpcg(A, # type: Tensor bparams['ortho_use_drop'] = bparams.get('ortho_use_drop', False) if not torch.jit.is_scripting(): - LOBPCG.call_tracker = LOBPCG_call_tracker + LOBPCG.call_tracker = LOBPCG_call_tracker # type: ignore if len(A.shape) > 2: N = int(torch.prod(torch.tensor(A.shape[:-2]))) @@ -252,7 +628,7 @@ def lobpcg(A, # type: Tensor bXret[i] = worker.X[:, :k] if not torch.jit.is_scripting(): - LOBPCG.call_tracker = LOBPCG_call_tracker_orig + LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k)) @@ -264,7 +640,7 @@ def lobpcg(A, # type: Tensor worker.run() if not torch.jit.is_scripting(): - LOBPCG.call_tracker = LOBPCG_call_tracker_orig + LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore return worker.E[:k], worker.X[:, :k] @@ -549,7 +925,7 @@ def _get_rayleigh_ritz_transform(self, S): matrix product `D M` with element-wise product `M * d`. Also, creating the diagonal matrix `D` is avoided. - Arguments: + Args: S (Tensor): the matrix basis for the search subspace, size is :math:`(m, n)`. @@ -581,7 +957,7 @@ def _get_svqb(self, modification of the corresponding algorithm introduced in [StathopolousWu2002]. - Arguments: + Args: U (Tensor) : initial approximation, size is (m, n) drop (bool) : when True, drop columns that @@ -647,7 +1023,7 @@ def _get_ortho(self, U, V): .. note:: If all U columns are B-collinear to V then the returned tensor U will be empty. - Arguments: + Args: U (Tensor) : initial approximation, size is (m, n) V (Tensor) : B-orthogonal external basis, size is (m, k) diff --git a/torch/_lowrank.py b/torch/_lowrank.py index 8bb2dae8cfaab..b1ca6e6f5949a 100644 --- a/torch/_lowrank.py +++ b/torch/_lowrank.py @@ -3,20 +3,18 @@ __all__ = ['svd_lowrank', 'pca_lowrank'] -from typing import Tuple, Optional - -import torch from torch import Tensor +import torch from . import _linalg_utils as _utils from .overrides import has_torch_function, handle_torch_function +from typing import Optional, Tuple -def get_approximate_basis(A, # type: Tensor - q, # type: int - niter=2, # type: Optional[int] - M=None # type: Optional[Tensor] - ): - # type: (...) -> Tensor +def get_approximate_basis(A: Tensor, + q: int, + niter: Optional[int] = 2, + M: Optional[Tensor] = None + ) -> Tensor: """Return tensor :math:`Q` with :math:`q` orthonormal columns such that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is specified, then :math:`Q` is such that :math:`Q Q^H (A - M)` @@ -37,7 +35,7 @@ def get_approximate_basis(A, # type: Tensor .. note:: To obtain repeatable results, reset the seed for the pseudorandom number generator - Arguments:: + Args:: A (Tensor): the input tensor of size :math:`(*, m, n)` q (int): the dimension of subspace spanned by :math:`Q` @@ -82,8 +80,8 @@ def get_approximate_basis(A, # type: Tensor return Q -def svd_lowrank(A, q=6, niter=2, M=None): - # type: (Tensor, Optional[int], Optional[int], Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor] +def svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2, + M: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]: r"""Return the singular value decomposition ``(U, S, V)`` of a matrix, batches of matrices, or a sparse matrix :math:`A` such that :math:`A \approx U diag(S) V^T`. In case :math:`M` is given, then @@ -103,7 +101,7 @@ def svd_lowrank(A, q=6, niter=2, M=None): will be useful for huge sparse matrices that ``torch.svd`` cannot handle. - Arguments:: + Args:: A (Tensor): the input tensor of size :math:`(*, m, n)` q (int, optional): a slightly overestimated rank of A. @@ -130,8 +128,8 @@ def svd_lowrank(A, q=6, niter=2, M=None): return _svd_lowrank(A, q=q, niter=niter, M=M) -def _svd_lowrank(A, q=6, niter=2, M=None): - # type: (Tensor, Optional[int], Optional[int], Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor] +def _svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2, + M: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]: q = 6 if q is None else q m, n = A.shape[-2:] matmul = _utils.matmul @@ -143,15 +141,19 @@ def _svd_lowrank(A, q=6, niter=2, M=None): # Algorithm 5.1 in Halko et al 2009, slightly modified to reduce # the number conjugate and transpose operations - if m < n: - # computing the SVD approximation of a transpose in order to - # keep B shape minimal + if m < n or n > q: + # computing the SVD approximation of a transpose in + # order to keep B shape minimal (the m < n case) or the V + # shape small (the n > q case) Q = get_approximate_basis(A_t, q, niter=niter, M=M_t) Q_c = _utils.conjugate(Q) if M is None: B_t = matmul(A, Q_c) else: B_t = matmul(A, Q_c) - matmul(M, Q_c) + assert B_t.shape[-2] == m, (B_t.shape, m) + assert B_t.shape[-1] == q, (B_t.shape, q) + assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape U, S, V = torch.svd(B_t) V = Q.matmul(V) else: @@ -161,14 +163,18 @@ def _svd_lowrank(A, q=6, niter=2, M=None): B = matmul(A_t, Q_c) else: B = matmul(A_t, Q_c) - matmul(M_t, Q_c) - U, S, V = torch.svd(_utils.transpose(B)) + B_t = _utils.transpose(B) + assert B_t.shape[-2] == q, (B_t.shape, q) + assert B_t.shape[-1] == n, (B_t.shape, n) + assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape + U, S, V = torch.svd(B_t) U = Q.matmul(U) return U, S, V -def pca_lowrank(A, q=None, center=True, niter=2): - # type: (Tensor, Optional[int], bool, int) -> Tuple[Tensor, Tensor, Tensor] +def pca_lowrank(A: Tensor, q: Optional[int] = None, center: bool = True, + niter: int = 2) -> Tuple[Tensor, Tensor, Tensor]: r"""Performs linear Principal Component Analysis (PCA) on a low-rank matrix, batches of such matrices, or sparse matrix. @@ -203,7 +209,7 @@ def pca_lowrank(A, q=None, center=True, niter=2): .. note:: To obtain repeatable results, reset the seed for the pseudorandom number generator - Arguments: + Args: A (Tensor): the input tensor of size :math:`(*, m, n)` diff --git a/torch/_ops.py b/torch/_ops.py index 70edc28220252..dd0c8cd19fdea 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -94,7 +94,7 @@ def load_library(self, path): ``torch.ops.loaded_libraries`` attribute, a set that may be inspected for the paths of all libraries loaded using this function. - Arguments: + Args: path (str): A path to a shared library to load. """ path = torch._utils_internal.resolve_library_path(path) diff --git a/torch/_six.py b/torch/_six.py index c53feed94ccec..00f9fa6b7f957 100644 --- a/torch/_six.py +++ b/torch/_six.py @@ -33,7 +33,6 @@ FileNotFoundError = builtins.FileNotFoundError StringIO = io.StringIO container_abcs = collections.abc -PY3 = sys.version_info[0] == 3 PY37 = sys.version_info[0] == 3 and sys.version_info[1] >= 7 def with_metaclass(meta: type, *bases) -> type: diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 55c5613cdcc3a..83bc041136728 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -3,6 +3,7 @@ import torch._C from torch._C import _add_docstr as add_docstr from ._torch_docs import parse_kwargs +from ._torch_docs import reproducibility_notes def add_docstr_all(method, docstr): @@ -111,6 +112,28 @@ def add_docstr_all(method, docstr): """.format(**new_common_args)) +add_docstr_all('new_empty_strided', + r""" +new_empty_strided(size, stride, dtype=None, device=None, requires_grad=False) -> Tensor + +Returns a Tensor of size :attr:`size` and strides :attr:`stride` filled with +uninitialized data. By default, the returned Tensor has the same +:class:`torch.dtype` and :class:`torch.device` as this tensor. + +Args: + {dtype} + {device} + {requires_grad} + +Example:: + + >>> tensor = torch.ones(()) + >>> tensor.new_empty_strided((2, 3), (3, 1)) + tensor([[ 5.8182e-18, 4.5765e-41, -1.0545e+30], + [ 3.0949e-41, 4.4842e-44, 0.0000e+00]]) + +""".format(**new_common_args)) + add_docstr_all('new_ones', r""" new_ones(size, dtype=None, device=None, requires_grad=False) -> Tensor @@ -332,6 +355,20 @@ def add_docstr_all(method, docstr): In-place version of :meth:`~Tensor.addmv` """) +add_docstr_all('sspaddmm', + r""" +sspaddmm(mat1, mat2, *, beta=1, alpha=1) -> Tensor + +See :func:`torch.sspaddmm` +""") + +add_docstr_all('smm', + r""" +smm(mat) -> Tensor + +See :func:`torch.smm` +""") + add_docstr_all('addr', r""" addr(vec1, vec2, *, beta=1, alpha=1) -> Tensor @@ -395,45 +432,9 @@ def add_docstr_all(method, docstr): add_docstr_all('all', r""" -.. function:: all() -> bool - -Returns True if all elements in the tensor are True, False otherwise. - -Example:: - - >>> a = torch.rand(1, 2).bool() - >>> a - tensor([[False, True]], dtype=torch.bool) - >>> a.all() - tensor(False, dtype=torch.bool) - -.. function:: all(dim, keepdim=False, out=None) -> Tensor - -Returns True if all elements in each row of the tensor in the given -dimension :attr:`dim` are True, False otherwise. - -If :attr:`keepdim` is ``True``, the output tensor is of the same size as -:attr:`input` except in the dimension :attr:`dim` where it is of size 1. -Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting -in the output tensor having 1 fewer dimension than :attr:`input`. - -Args: - dim (int): the dimension to reduce - keepdim (bool): whether the output tensor has :attr:`dim` retained or not - out (Tensor, optional): the output tensor - -Example:: +all(dim=None, keepdim=False) -> Tensor - >>> a = torch.rand(4, 2).bool() - >>> a - tensor([[True, True], - [True, False], - [True, True], - [True, True]], dtype=torch.bool) - >>> a.all(dim=1) - tensor([ True, False, True, True], dtype=torch.bool) - >>> a.all(dim=0) - tensor([ True, False], dtype=torch.bool) +See :func:`torch.all` """) add_docstr_all('allclose', @@ -452,45 +453,9 @@ def add_docstr_all(method, docstr): add_docstr_all('any', r""" -.. function:: any() -> bool - -Returns True if any elements in the tensor are True, False otherwise. - -Example:: - - >>> a = torch.rand(1, 2).bool() - >>> a - tensor([[False, True]], dtype=torch.bool) - >>> a.any() - tensor(True, dtype=torch.bool) +any(dim=None, keepdim=False) -> Tensor -.. function:: any(dim, keepdim=False, out=None) -> Tensor - -Returns True if any elements in each row of the tensor in the given -dimension :attr:`dim` are True, False otherwise. - -If :attr:`keepdim` is ``True``, the output tensor is of the same size as -:attr:`input` except in the dimension :attr:`dim` where it is of size 1. -Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting -in the output tensor having 1 fewer dimension than :attr:`input`. - -Args: - dim (int): the dimension to reduce - keepdim (bool): whether the output tensor has :attr:`dim` retained or not - out (Tensor, optional): the output tensor - -Example:: - - >>> a = torch.randn(4, 2) < 0 - >>> a - tensor([[ True, True], - [False, True], - [ True, True], - [False, False]]) - >>> a.any(1) - tensor([ True, True, True, False]) - >>> a.any(0) - tensor([True, True]) +See :func:`torch.any` """) add_docstr_all('apply_', @@ -732,6 +697,13 @@ def add_docstr_all(method, docstr): In-place version of :meth:`~Tensor.bitwise_xor` """) +add_docstr_all('broadcast_to', + r""" +broadcast_to(shape) -> Tensor + +See :func:`torch.broadcast_to`. +""") + add_docstr_all('logical_and', r""" logical_and() -> Tensor @@ -873,6 +845,19 @@ def add_docstr_all(method, docstr): See :func:`torch.clone` """.format(**common_args)) +add_docstr_all('coalesce', + r""" +coalesce() -> Tensor + +Returns a coalesced copy of :attr:`self` if :attr:`self` is an +:ref:`uncoalesced tensor `. + +Returns :attr:`self` if :attr:`self` is a coalesced tensor. + +.. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. +""") + add_docstr_all('contiguous', r""" contiguous(memory_format=torch.contiguous_format) -> Tensor @@ -911,6 +896,19 @@ def add_docstr_all(method, docstr): See :func:`torch.conj` """) +add_docstr_all('copysign', + r""" +copysign(other) -> Tensor + +See :func:`torch.copysign` +""") + +add_docstr_all('copysign_', r""" +copysign_(other) -> Tensor + +In-place version of :meth:`~Tensor.copysign` +""") + add_docstr_all('cos', r""" cos() -> Tensor @@ -1013,6 +1011,13 @@ def add_docstr_all(method, docstr): See :func:`torch.cumprod` """) +add_docstr_all('cumprod_', + r""" +cumprod_(dim, dtype=None) -> Tensor + +In-place version of :meth:`~Tensor.cumprod` +""") + add_docstr_all('cumsum', r""" cumsum(dim, dtype=None) -> Tensor @@ -1020,6 +1025,13 @@ def add_docstr_all(method, docstr): See :func:`torch.cumsum` """) +add_docstr_all('cumsum_', + r""" +cumsum_(dim, dtype=None) -> Tensor + +In-place version of :meth:`~Tensor.cumsum` +""") + add_docstr_all('data_ptr', r""" data_ptr() -> int @@ -1038,10 +1050,12 @@ def add_docstr_all(method, docstr): r""" dense_dim() -> int -If :attr:`self` is a sparse COO tensor (i.e., with ``torch.sparse_coo`` layout), -this returns the number of dense dimensions. Otherwise, this throws an error. +Return the number of dense dimensions in a :ref:`sparse tensor ` :attr:`self`. -See also :meth:`Tensor.sparse_dim`. +.. warning:: + Throws an error if :attr:`self` is not a sparse tensor. + +See also :meth:`Tensor.sparse_dim` and :ref:`hybrid tensors `. """) add_docstr_all('diag', @@ -1180,7 +1194,7 @@ def add_docstr_all(method, docstr): add_docstr_all('dot', r""" -dot(tensor2) -> Tensor +dot(other) -> Tensor See :func:`torch.dot` """) @@ -1478,6 +1492,12 @@ def add_docstr_all(method, docstr): See :func:`torch.ger` """) +add_docstr_all('inner', r""" +inner(other) -> Tensor + +See :func:`torch.inner`. +""") + add_docstr_all('outer', r""" outer(vec2) -> Tensor @@ -1512,13 +1532,40 @@ def add_docstr_all(method, docstr): In-place version of :meth:`~Tensor.i0` """) +add_docstr_all('igamma', + r""" +igamma(other) -> Tensor + +See :func:`torch.igamma` +""") + +add_docstr_all('igamma_', + r""" +igamma_(other) -> Tensor + +In-place version of :meth:`~Tensor.igamma` +""") + +add_docstr_all('igammac', + r""" +igammac(other) -> Tensor +See :func:`torch.igammac` +""") + +add_docstr_all('igammac_', + r""" +igammac_(other) -> Tensor +In-place version of :meth:`~Tensor.igammac` +""") + add_docstr_all('indices', r""" indices() -> Tensor -If :attr:`self` is a sparse COO tensor (i.e., with ``torch.sparse_coo`` layout), -this returns a view of the contained indices tensor. Otherwise, this throws an -error. +Return the indices tensor of a :ref:`sparse COO tensor `. + +.. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. See also :meth:`Tensor.values`. @@ -1546,9 +1593,10 @@ def add_docstr_all(method, docstr): r""" values() -> Tensor -If :attr:`self` is a sparse COO tensor (i.e., with ``torch.sparse_coo`` layout), -this returns a view of the contained values tensor. Otherwise, this throws an -error. +Return the values tensor of a :ref:`sparse COO tensor `. + +.. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. See also :meth:`Tensor.indices`. @@ -1628,16 +1676,11 @@ def add_docstr_all(method, docstr): match :attr:`self`, or an error will be raised. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {forward_reproducibility_note} Args: dim (int): dimension along which to index - index (LongTensor): indices of :attr:`tensor` to select from + index (IntTensor or LongTensor): indices of :attr:`tensor` to select from tensor (Tensor): the tensor containing values to add Example:: @@ -1651,7 +1694,7 @@ def add_docstr_all(method, docstr): [ 8., 9., 10.], [ 1., 1., 1.], [ 5., 6., 7.]]) -""") +""".format(**reproducibility_notes)) add_docstr_all('index_copy_', r""" @@ -1666,6 +1709,11 @@ def add_docstr_all(method, docstr): length of :attr:`index` (which must be a vector), and all other dimensions must match :attr:`self`, or an error will be raised. +.. note:: + If :attr:`index` contains duplicate entries, multiple elements from + :attr:`tensor` will be copied to the same index of :attr:`self`. The result + is nondeterministic since it depends on which copy occurs last. + Args: dim (int): dimension along which to index index (LongTensor): indices of :attr:`tensor` to select from @@ -1707,26 +1755,26 @@ def add_docstr_all(method, docstr): add_docstr_all('index_put_', r""" -index_put_(indices, value, accumulate=False) -> Tensor +index_put_(indices, values, accumulate=False) -> Tensor -Puts values from the tensor :attr:`value` into the tensor :attr:`self` using +Puts values from the tensor :attr:`values` into the tensor :attr:`self` using the indices specified in :attr:`indices` (which is a tuple of Tensors). The -expression ``tensor.index_put_(indices, value)`` is equivalent to -``tensor[indices] = value``. Returns :attr:`self`. +expression ``tensor.index_put_(indices, values)`` is equivalent to +``tensor[indices] = values``. Returns :attr:`self`. -If :attr:`accumulate` is ``True``, the elements in :attr:`value` are added to +If :attr:`accumulate` is ``True``, the elements in :attr:`values` are added to :attr:`self`. If accumulate is ``False``, the behavior is undefined if indices contain duplicate elements. Args: indices (tuple of LongTensor): tensors used to index into `self`. - value (Tensor): tensor of same dtype as `self`. + values (Tensor): tensor of same dtype as `self`. accumulate (bool): whether to accumulate into self """) add_docstr_all('index_put', r""" -index_put(tensor1, indices, value, accumulate=False) -> Tensor +index_put(tensor1, indices, values, accumulate=False) -> Tensor Out-place version of :meth:`~Tensor.index_put_`. `tensor1` corresponds to `self` in :meth:`torch.Tensor.index_put_`. @@ -1741,25 +1789,31 @@ def add_docstr_all(method, docstr): add_docstr_all('sparse_mask', r""" -sparse_mask(input, mask) -> Tensor +sparse_mask(mask) -> Tensor -Returns a new SparseTensor with values from Tensor :attr:`input` filtered -by indices of :attr:`mask` and values are ignored. :attr:`input` and :attr:`mask` -must have the same shape. +Returns a new :ref:`sparse tensor ` with values from a +strided tensor :attr:`self` filtered by the indices of the sparse +tensor :attr:`mask`. The values of :attr:`mask` sparse tensor are +ignored. :attr:`self` and :attr:`mask` tensors must have the same +shape. + +.. note:: + + The returned sparse tensor has the same indices as the sparse tensor + :attr:`mask`, even when the corresponding values in :attr:`self` are + zeros. Args: - input (Tensor): an input Tensor - mask (SparseTensor): a SparseTensor which we filter :attr:`input` based on its indices + mask (Tensor): a sparse tensor whose indices are used as a filter Example:: - >>> nnz = 5 - >>> dims = [5, 5, 2, 2] - >>> I = torch.cat([torch.randint(0, dims[0], size=(nnz,)), - torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz) - >>> V = torch.randn(nnz, dims[2], dims[3]) - >>> size = torch.Size(dims) - >>> S = torch.sparse_coo_tensor(I, V, size).coalesce() + >>> nse = 5 + >>> dims = (5, 5, 2, 2) + >>> I = torch.cat([torch.randint(0, dims[0], size=(nse,)), + torch.randint(0, dims[1], size=(nse,))], 0).reshape(2, nse) + >>> V = torch.randn(nse, dims[2], dims[3]) + >>> S = torch.sparse_coo_tensor(I, V, dims).coalesce() >>> D = torch.randn(dims) >>> D.sparse_mask(S) tensor(indices=tensor([[0, 0, 0, 2], @@ -1834,6 +1888,19 @@ def add_docstr_all(method, docstr): See :func:`torch.isreal` """) +add_docstr_all('is_coalesced', + r""" +is_coalesced() -> bool + +Returns ``True`` if :attr:`self` is a :ref:`sparse COO tensor +` that is coalesced, ``False`` otherwise. + +.. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. + +See :meth:`coalesce` and :ref:`uncoalesced tensors `. +""") + add_docstr_all('is_contiguous', r""" is_contiguous(memory_format=torch.contiguous_format) -> bool @@ -1896,6 +1963,13 @@ def add_docstr_all(method, docstr): """) +add_docstr_all('kron', + r""" +kron(other) -> Tensor + +See :func:`torch.kron` +""") + add_docstr_all('kthvalue', r""" kthvalue(k, dim=None, keepdim=False) -> (Tensor, LongTensor) @@ -1903,6 +1977,20 @@ def add_docstr_all(method, docstr): See :func:`torch.kthvalue` """) +add_docstr_all('ldexp', + r""" +ldexp(other) -> Tensor + +See :func:`torch.ldexp` +""") + +add_docstr_all('ldexp_', + r""" +ldexp_(other) -> Tensor + +In-place version of :meth:`~Tensor.ldexp` +""") + add_docstr_all('lcm', r""" lcm(other) -> Tensor @@ -2206,6 +2294,13 @@ def callable(a, b) -> number See :func:`torch.median` """) +add_docstr_all('nanmedian', + r""" +nanmedian(dim=None, keepdim=False) -> (Tensor, LongTensor) + +See :func:`torch.nanmedian` +""") + add_docstr_all('min', r""" min(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) @@ -2254,6 +2349,12 @@ def callable(a, b) -> number See :func:`torch.movedim` """) +add_docstr_all('moveaxis', r""" +moveaxis(source, destination) -> Tensor + +See :func:`torch.moveaxis` +""") + add_docstr_all('mul', r""" mul(value) -> Tensor @@ -2342,6 +2443,18 @@ def callable(a, b) -> number Alias for :meth:`~Tensor.dim()` """) +add_docstr_all('nan_to_num', r""" +nan_to_num(nan=0.0, posinf=None, neginf=None) -> Tensor + +See :func:`torch.nan_to_num`. +""") + +add_docstr_all('nan_to_num_', r""" +nan_to_num_(nan=0.0, posinf=None, neginf=None) -> Tensor + +In-place version of :meth:`~Tensor.nan_to_num`. +""") + add_docstr_all('ne', r""" ne(other) -> Tensor @@ -2510,6 +2623,20 @@ def callable(a, b) -> number In-place version of :meth:`~Tensor.pow` """) +add_docstr_all('float_power', + r""" +float_power(exponent) -> Tensor + +See :func:`torch.float_power` +""") + +add_docstr_all('float_power_', + r""" +float_power_(exponent) -> Tensor + +In-place version of :meth:`~Tensor.float_power` +""") + add_docstr_all('prod', r""" prod(dim=None, keepdim=False, dtype=None) -> Tensor @@ -2655,6 +2782,13 @@ def callable(a, b) -> number In-place version of :meth:`~Tensor.deg2rad` """) +add_docstr_all('ravel', + r""" +ravel(input) -> Tensor + +see :func:`torch.ravel` +""") + add_docstr_all('reciprocal', r""" reciprocal() -> Tensor @@ -2921,14 +3055,25 @@ def callable(a, b) -> number This is the reverse operation of the manner described in :meth:`~Tensor.gather`. -:attr:`self`, :attr:`index` and :attr:`src` (if it is a Tensor) should have same -number of dimensions. It is also required that ``index.size(d) <= src.size(d)`` -for all dimensions ``d``, and that ``index.size(d) <= self.size(d)`` for all -dimensions ``d != dim``. +:attr:`self`, :attr:`index` and :attr:`src` (if it is a Tensor) should all have +the same number of dimensions. It is also required that +``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that +``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. +Note that ``index`` and ``src`` do not broadcast. Moreover, as for :meth:`~Tensor.gather`, the values of :attr:`index` must be -between ``0`` and ``self.size(dim) - 1`` inclusive, and all values in a row -along the specified dimension :attr:`dim` must be unique. +between ``0`` and ``self.size(dim) - 1`` inclusive. + +.. warning:: + + When indices are not unique, the behavior is non-deterministic (one of the + values from ``src`` will be picked arbitrarily) and the gradient will be + incorrect (it will be propagated to all locations in the source that + correspond to the same index)! + +.. note:: + + The backward pass is implemented only for ``src.shape == index.shape``. Additionally accepts an optional :attr:`reduce` argument that allows specification of an optional reduction operation, which is applied to all @@ -2948,41 +3093,41 @@ def callable(a, b) -> number Reducing with the addition operation is the same as using :meth:`~torch.Tensor.scatter_add_`. -Note: - Reduction is not yet implemented for the CUDA backend. - Args: dim (int): the axis along which to index - index (LongTensor): the indices of elements to scatter, - can be either empty or the same size of src. - When empty, the operation returns identity - src (Tensor): the source element(s) to scatter, - incase `value` is not specified - value (float): the source element(s) to scatter, - incase `src` is not specified - reduce (string): reduction operation to apply, - can be either 'add' or 'multiply'. + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + src (Tensor or float): the source element(s) to scatter. + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. Example:: - >>> x = torch.rand(2, 5) - >>> x - tensor([[ 0.3992, 0.2908, 0.9044, 0.4850, 0.6004], - [ 0.5735, 0.9006, 0.6797, 0.4152, 0.1732]]) - >>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) - tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004], - [ 0.0000, 0.2908, 0.0000, 0.4152, 0.0000], - [ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]]) - - >>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23) - >>> z - tensor([[ 0.0000, 0.0000, 1.2300, 0.0000], - [ 0.0000, 0.0000, 0.0000, 1.2300]]) + >>> src = torch.arange(1, 11).reshape((2, 5)) + >>> src + tensor([[ 1, 2, 3, 4, 5], + [ 6, 7, 8, 9, 10]]) + >>> index = torch.tensor([[0, 1, 2, 0]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src) + tensor([[1, 0, 0, 4, 0], + [0, 2, 0, 0, 0], + [0, 0, 3, 0, 0]]) + >>> index = torch.tensor([[0, 1, 2], [0, 1, 4]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src) + tensor([[1, 2, 3, 0, 0], + [6, 7, 0, 0, 8], + [0, 0, 0, 0, 0]]) + + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='multiply') + tensor([[2.0000, 2.0000, 2.4600, 2.0000], + [2.0000, 2.0000, 2.0000, 2.4600]]) + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='add') + tensor([[2.0000, 2.0000, 3.2300, 2.0000], + [2.0000, 2.0000, 2.0000, 3.2300]]) - >>> z = torch.ones(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23, reduce='multiply') - >>> z - tensor([[1.0000, 1.0000, 1.2300, 1.0000], - [1.0000, 1.0000, 1.0000, 1.2300]]) """) add_docstr_all('scatter_add_', @@ -3005,35 +3150,37 @@ def callable(a, b) -> number :attr:`self`, :attr:`index` and :attr:`src` should have same number of dimensions. It is also required that ``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that ``index.size(d) <= self.size(d)`` for all dimensions -``d != dim``. +``d != dim``. Note that ``index`` and ``src`` do not broadcast. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {forward_reproducibility_note} + +.. note:: + + The backward pass is implemented only for ``src.shape == index.shape``. Args: dim (int): the axis along which to index - index (LongTensor): the indices of elements to scatter and add, - can be either empty or the same size of src. - When empty, the operation returns identity. + index (LongTensor): the indices of elements to scatter and add, can be + either empty or of the same dimensionality as ``src``. When empty, the + operation returns ``self`` unchanged. src (Tensor): the source elements to scatter and add Example:: - >>> x = torch.rand(2, 5) - >>> x - tensor([[0.7404, 0.0427, 0.6480, 0.3806, 0.8328], - [0.7953, 0.2009, 0.9154, 0.6782, 0.9620]]) - >>> torch.ones(3, 5).scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) - tensor([[1.7404, 1.2009, 1.9154, 1.3806, 1.8328], - [1.0000, 1.0427, 1.0000, 1.6782, 1.0000], - [1.7953, 1.0000, 1.6480, 1.0000, 1.9620]]) + >>> src = torch.ones((2, 5)) + >>> index = torch.tensor([[0, 1, 2, 0, 0]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src) + tensor([[1., 0., 0., 1., 1.], + [0., 1., 0., 0., 0.], + [0., 0., 1., 0., 0.]]) + >>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src) + tensor([[2., 0., 0., 1., 1.], + [0., 2., 0., 0., 0.], + [0., 0., 2., 1., 1.]]) -""") +""".format(**reproducibility_notes)) add_docstr_all('select', r""" @@ -3149,6 +3296,20 @@ def callable(a, b) -> number In-place version of :meth:`~Tensor.sin` """) +add_docstr_all('sinc', + r""" +sinc() -> Tensor + +See :func:`torch.sinc` +""") + +add_docstr_all('sinc_', + r""" +sinc_() -> Tensor + +In-place version of :meth:`~Tensor.sinc` +""") + add_docstr_all('sinh', r""" sinh() -> Tensor @@ -3191,6 +3352,13 @@ def callable(a, b) -> number See :func:`torch.sort` """) +add_docstr_all('msort', + r""" +msort() -> Tensor + +See :func:`torch.msort` +""") + add_docstr_all('argsort', r""" argsort(dim=-1, descending=False) -> LongTensor @@ -3202,10 +3370,59 @@ def callable(a, b) -> number r""" sparse_dim() -> int -If :attr:`self` is a sparse COO tensor (i.e., with ``torch.sparse_coo`` layout), -this returns the number of sparse dimensions. Otherwise, this throws an error. +Return the number of sparse dimensions in a :ref:`sparse tensor ` :attr:`self`. + +.. warning:: + Throws an error if :attr:`self` is not a sparse tensor. + +See also :meth:`Tensor.dense_dim` and :ref:`hybrid tensors `. +""") + +add_docstr_all('sparse_resize_', + r""" +sparse_resize_(size, sparse_dim, dense_dim) -> Tensor + +Resizes :attr:`self` :ref:`sparse tensor ` to the desired +size and the number of sparse and dense dimensions. + +.. note:: + If the number of specified elements in :attr:`self` is zero, then + :attr:`size`, :attr:`sparse_dim`, and :attr:`dense_dim` can be any + size and positive integers such that ``len(size) == sparse_dim + + dense_dim``. + + If :attr:`self` specifies one or more elements, however, then each + dimension in :attr:`size` must not be smaller than the corresponding + dimension of :attr:`self`, :attr:`sparse_dim` must equal the number + of sparse dimensions in :attr:`self`, and :attr:`dense_dim` must + equal the number of dense dimensions in :attr:`self`. + +.. warning:: + Throws an error if :attr:`self` is not a sparse tensor. + +Args: + size (torch.Size): the desired size. If :attr:`self` is non-empty + sparse tensor, the desired size cannot be smaller than the + original size. + sparse_dim (int): the number of sparse dimensions + dense_dim (int): the number of dense dimensions +""") + +add_docstr_all('sparse_resize_and_clear_', + r""" +sparse_resize_and_clear_(size, sparse_dim, dense_dim) -> Tensor -See also :meth:`Tensor.dense_dim`. +Removes all specified elements from a :ref:`sparse tensor +` :attr:`self` and resizes :attr:`self` to the desired +size and the number of sparse and dense dimensions. + +.. warning: + Throws an error if :attr:`self` is not a sparse tensor. + +Args: + size (torch.Size): the desired size. + sparse_dim (int): the number of sparse dimensions + dense_dim (int): the number of dense dimensions """) add_docstr_all('sqrt', @@ -3367,6 +3584,31 @@ def callable(a, b) -> number See :func:`torch.symeig` """) +add_docstr_all('swapdims', r""" +swapdims(dim0, dim1) -> Tensor + +See :func:`torch.swapdims` +""") + +add_docstr_all('swapdims_', + r""" +swapdims_(dim0, dim1) -> Tensor + +In-place version of :meth:`~Tensor.swapdims` +""") + +add_docstr_all('swapaxes', r""" +swapaxes(axis0, axis1) -> Tensor + +See :func:`torch.swapaxes` +""") + +add_docstr_all('swapaxes_', r""" +swapaxes_(axis0, axis1) -> Tensor + +In-place version of :meth:`~Tensor.swapaxes` +""") + add_docstr_all('t', r""" t() -> Tensor @@ -3381,6 +3623,13 @@ def callable(a, b) -> number In-place version of :meth:`~Tensor.t` """) +add_docstr_all('tile', + r""" +tile(*reps) -> Tensor + +See :func:`torch.tile` +""") + add_docstr_all('to', r""" to(*args, **kwargs) -> Tensor @@ -3619,11 +3868,33 @@ def callable(a, b) -> number See :func:`torch.topk` """) +add_docstr_all('to_dense', + r""" +to_dense() -> Tensor + +Creates a strided copy of :attr:`self`. + +.. warning:: + Throws an error if :attr:`self` is a strided tensor. + +Example:: + + >>> s = torch.sparse_coo_tensor( + torch.tensor([[1, 1], + [0, 2]]), + torch.tensor([9, 10]), + size=(3, 3)) + >>> s.to_dense() + tensor([[ 0, 0, 0], + [ 9, 0, 10], + [ 0, 0, 0]]) +""") + add_docstr_all('to_sparse', r""" to_sparse(sparseDims) -> Tensor Returns a sparse copy of the tensor. PyTorch supports sparse tensors in -:ref:`coordinate format `. +:ref:`coordinate format `. Args: sparseDims (int, optional): the number of sparse dimensions to include in the new sparse tensor @@ -3855,7 +4126,7 @@ def callable(a, b) -> number add_docstr_all('vdot', r""" -dot(other) -> Tensor +vdot(other) -> Tensor See :func:`torch.vdot` """) @@ -3911,6 +4182,51 @@ def callable(a, b) -> number >>> torch.equal(b, c) False + +.. function:: view(dtype) -> Tensor + +Returns a new tensor with the same data as the :attr:`self` tensor but of a +different :attr:`dtype`. :attr:`dtype` must have the same number of bytes per +element as :attr:`self`'s dtype. + +.. warning:: + + This overload is not supported by TorchScript, and using it in a Torchscript + program will cause undefined behavior. + + +Args: + dtype (:class:`torch.dtype`): the desired dtype + +Example:: + + >>> x = torch.randn(4, 4) + >>> x + tensor([[ 0.9482, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + >>> x.dtype + torch.float32 + + >>> y = x.view(torch.int32) + >>> y + tensor([[ 1064483442, -1124191867, 1069546515, -1089989247], + [-1105482831, 1061112040, 1057999968, -1084397505], + [-1071760287, -1123489973, -1097310419, -1084649136], + [-1101533110, 1073668768, -1082790149, -1088634448]], + dtype=torch.int32) + >>> y[0, 0] = 1000000000 + >>> x + tensor([[ 0.0047, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + + >>> x.view(torch.int16) + Traceback (most recent call last): + File "", line 1, in + RuntimeError: Viewing a tensor as a new dtype with a different number of bytes per element is not supported. """) add_docstr_all('view_as', @@ -4033,6 +4349,13 @@ def callable(a, b) -> number See :func:`torch.unsafe_split` """) +add_docstr_all('tensor_split', + r""" +tensor_split(indices_or_sections, dim=0) -> List of Tensors + +See :func:`torch.tensor_split` +""") + add_docstr_all('stft', r""" stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0) -> Tensor @@ -4048,33 +4371,6 @@ def callable(a, b) -> number See :func:`torch.istft` """) -add_docstr_all('fft', r""" -fft(signal_ndim, normalized=False) -> Tensor - -See :func:`torch.fft` -""") - -add_docstr_all('ifft', - r""" -ifft(signal_ndim, normalized=False) -> Tensor - -See :func:`torch.ifft` -""") - -add_docstr_all('rfft', - r""" -rfft(signal_ndim, normalized=False, onesided=True) -> Tensor - -See :func:`torch.rfft` -""") - -add_docstr_all('irfft', - r""" -irfft(signal_ndim, normalized=False, onesided=True, signal_sizes=None) -> Tensor - -See :func:`torch.irfft` -""") - add_docstr_all('det', r""" det() -> Tensor @@ -4170,6 +4466,20 @@ def callable(a, b) -> number Out-of-place version of :meth:`torch.Tensor.masked_scatter_` """) +add_docstr_all('xlogy', + r""" +xlogy(other) -> Tensor + +See :func:`torch.xlogy` +""") + +add_docstr_all('xlogy_', + r""" +xlogy_(other) -> Tensor + +In-place version of :meth:`~Tensor.xlogy` +""") + add_docstr_all('masked_fill', r""" masked_fill(mask, value) -> Tensor @@ -4262,6 +4572,11 @@ def callable(a, b) -> number are like normal tensors, but they carry no data. """) +add_docstr_all('is_sparse', + r""" +Is ``True`` if the Tensor uses sparse storage layout, ``False`` otherwise. +""") + add_docstr_all('device', r""" Is the :class:`torch.device` where this Tensor is. diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py index 8146ea6be2017..1aef783ee66fe 100644 --- a/torch/_tensor_str.py +++ b/torch/_tensor_str.py @@ -1,14 +1,15 @@ import math import torch from torch._six import inf +from typing import Optional class __PrinterOptions(object): - precision = 4 - threshold = 1000 - edgeitems = 3 - linewidth = 80 - sci_mode = None + precision: int = 4 + threshold: float = 1000 + edgeitems: int = 3 + linewidth: int = 80 + sci_mode: Optional[bool] = None PRINT_OPTS = __PrinterOptions() @@ -274,11 +275,16 @@ def get_summarized_data(self): else: return torch.stack([get_summarized_data(x) for x in self]) -def _str_intern(self): +def _str_intern(inp): prefix = 'tensor(' indent = len(prefix) suffixes = [] + # This is used to extract the primal value and thus disable the forward AD + # within this function. + # TODO(albanD) This needs to be updated when more than one level is supported + self, tangent = torch.autograd.forward_ad.unpack_dual(inp) + # Note [Print tensor device]: # A general logic here is we only print device when it doesn't match # the device specified in default tensor type. @@ -354,17 +360,22 @@ def _str_intern(self): if self.layout != torch.strided: suffixes.append('layout=' + str(self.layout)) - if self.grad_fn is not None: - name = type(self.grad_fn).__name__ + # Use inp here to get the original grad_fn and not the one generated by the forward grad + # unpacking. + if inp.grad_fn is not None: + name = type(inp.grad_fn).__name__ if name == 'CppFunction': - name = self.grad_fn.name().rsplit('::', 1)[-1] + name = inp.grad_fn.name().rsplit('::', 1)[-1] suffixes.append('grad_fn=<{}>'.format(name)) - elif self.requires_grad: + elif inp.requires_grad: suffixes.append('requires_grad=True') if self.has_names(): suffixes.append('names={}'.format(self.names)) + if tangent is not None: + suffixes.append('tangent={}'.format(tangent)) + return _add_suffixes(prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse) def _str(self): diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 7b00ddbd1505f..3a226873cbe5a 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -110,6 +110,19 @@ def merge_dicts(*dicts): "tf32_note": """This operator supports :ref:`TensorFloat32`.""" } + +reproducibility_notes = { + "forward_reproducibility_note": """This operation may behave nondeterministically when given tensors on \ +a CUDA device. See :doc:`/notes/randomness` for more information.""", + "backward_reproducibility_note": """This operation may produce nondeterministic gradients when given tensors on \ +a CUDA device. See :doc:`/notes/randomness` for more information.""", + "cudnn_reproducibility_note": """In some circumstances when given tensors on a CUDA device \ +and using CuDNN, this operator may select a nondeterministic algorithm to increase performance. If this is \ +undesirable, you can try to make the operation deterministic (potentially at \ +a performance cost) by setting ``torch.backends.cudnn.deterministic = True``. \ +See :doc:`/notes/randomness` for more information.""" +} + add_docstr(torch.abs, r""" abs(input, *, out=None) -> Tensor @@ -427,8 +440,8 @@ def merge_dicts(*dicts): Args: input (Tensor): matrix to be added - mat1 (Tensor): the first matrix to be multiplied - mat2 (Tensor): the second matrix to be multiplied + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied Keyword args: beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) @@ -445,6 +458,39 @@ def merge_dicts(*dicts): [ 0.7573, -3.9555, -2.8681]]) """.format(**common_args, **tf32_notes)) +add_docstr(torch.sspaddmm, + r""" +sspaddmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + +Matrix multiplies a sparse tensor :attr:`mat1` with a dense tensor +:attr:`mat2`, then adds the sparse tensor :attr:`input` to the result. + +Note: This function is equivalent to :func:`torch.addmm`, except +:attr:`input` and :attr:`mat1` are sparse. + +Args: + input (Tensor): a sparse matrix to be added + mat1 (Tensor): a sparse matrix to be matrix multiplied + mat2 (Tensor): a dense matrix to be matrix multiplied + +Keyword args: + beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + {out} +""".format(**common_args)) + +add_docstr(torch.smm, + r""" +smm(input, mat) -> Tensor + +Performs a matrix multiplication of the sparse matrix :attr:`input` +with the dense matrix :attr:`mat`. + +Args: + input (Tensor): a sparse matrix to be matrix multiplied + mat (Tensor): a dense matrix to be matrix multiplied +""") + add_docstr(torch.addmv, r""" addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor @@ -472,8 +518,8 @@ def merge_dicts(*dicts): Args: input (Tensor): vector to be added - mat (Tensor): matrix to be multiplied - vec (Tensor): vector to be multiplied + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied Keyword args: beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) @@ -512,15 +558,6 @@ def merge_dicts(*dicts): :math:`(n \times m)` and :attr:`out` will be a matrix of size :math:`(n \times m)`. -For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and -:attr:`alpha` must be real numbers, otherwise they should be integers - -.. warning:: - This function is deprecated and may be removed in a future release. - It can be implemented using :func:`torch.outer` as - ``alpha * torch.outer(vec1, vec2) + beta * input`` when :attr:`beta` is not zero, - and as ``alpha * torch.outer(vec1, vec2)`` when :attr:`beta` is zero. - Args: input (Tensor): matrix to be added vec1 (Tensor): the first vector of the outer product @@ -573,6 +610,113 @@ def merge_dicts(*dicts): True """) +add_docstr(torch.all, + r""" +all(input) -> Tensor + +Tests if all elements in :attr:`input` evaluate to `True`. + +.. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + +Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + +.. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + +For each row of :attr:`input` in the given dimension :attr:`dim`, +returns `True` if all elements in the row evaluate to `True` and `False` otherwise. + +{keepdim_details} + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + {out} + +Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) +""".format(**single_dim_common)) + +add_docstr(torch.any, + r""" +any(input) -> Tensor + +Args: + {input} + +Tests if any element in :attr:`input` evaluates to `True`. + +.. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + +Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + +.. function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + +For each row of :attr:`input` in the given dimension :attr:`dim`, +returns `True` if any element in the row evaluate to `True` and `False` otherwise. + +{keepdim_details} + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) +""".format(**single_dim_common)) + add_docstr(torch.angle, r""" angle(input, *, out=None) -> Tensor @@ -588,6 +732,11 @@ def merge_dicts(*dicts): Keyword args: {out} +.. note:: Starting in PyTorch 1.8, angle returns pi for negative real numbers, + zero for non-negative real numbers, and propagates NaNs. Previously + the function would return zero for all real numbers and not propagate + floating-point NaNs. + Example:: >>> torch.angle(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))*180/3.14159 @@ -938,12 +1087,7 @@ def merge_dicts(*dicts): ``out[n] += 1``. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {backward_reproducibility_note} Arguments: input (Tensor): 1-d int tensor @@ -968,7 +1112,7 @@ def merge_dicts(*dicts): >>> input.bincount(weights) tensor([0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 0.5000]) -""") +""".format(**reproducibility_notes)) add_docstr(torch.bitwise_not, r""" @@ -989,7 +1133,6 @@ def merge_dicts(*dicts): tensor([ 0, 1, -4], dtype=torch.int8) """.format(**common_args)) -# TODO: see https://github.com/pytorch/pytorch/issues/43667 add_docstr(torch.bmm, r""" bmm(input, mat2, *, deterministic=False, out=None) -> Tensor @@ -1098,6 +1241,26 @@ def merge_dicts(*dicts): tensor([ True, False, False]) """.format(**common_args)) +add_docstr(torch.broadcast_to, + r""" +broadcast_to(input, shape) -> Tensor + +Broadcasts :attr:`input` to the shape :attr:`\shape`. +Equivalent to calling ``input.expand(shape)``. See :meth:`~Tensor.expand` for details. + +Args: + {input} + shape (list, tuple, or :class:`torch.Size`): the new shape. + +Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> torch.broadcast_to(x, (3, 3)) + tensor([[1, 2, 3], + [1, 2, 3], + [1, 2, 3]]) +""".format(**common_args)) + add_docstr(torch.stack, r""" stack(tensors, dim=0, *, out=None) -> Tensor @@ -1209,6 +1372,67 @@ def merge_dicts(*dicts): """.format(**common_args)) +add_docstr(torch.tensor_split, + r""" +tensor_split(input, indices_or_sections, dim=0) -> List of Tensors + +Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`, +along dimension :attr:`dim` according to the indices or number of sections specified +by :attr:`indices_or_sections`. This function is based on NumPy's +:func:`numpy.array_split`. + +Args: + input (Tensor): the tensor to split + indices_or_sections (Tensor, int or list or tuple of ints): + If :attr:`indices_or_sections` is an integer ``n`` or a zero dimensional long tensor + with value ``n``, :attr:`input` is split into ``n`` sections along dimension :attr:`dim`. + If :attr:`input` is divisible by ``n`` along dimension :attr:`dim`, each + section will be of equal size, :code:`input.size(dim) / n`. If :attr:`input` + is not divisible by ``n``, the sizes of the first :code:`int(input.size(dim) % n)` + sections will have size :code:`int(input.size(dim) / n) + 1`, and the rest will + have size :code:`int(input.size(dim) / n)`. + + If :attr:`indices_or_sections` is a list or tuple of ints, or a one-dimensional long + tensor, then :attr:`input` is split along dimension :attr:`dim` at each of the indices + in the list, tuple or tensor. For instance, :code:`indices_or_sections=[2, 3]` and :code:`dim=0` + would result in the tensors :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`. + + If indices_or_sections is a tensor, it must be a zero-dimensional or one-dimensional + long tensor on the CPU. + + dim (int, optional): dimension along which to split the tensor. Default: ``0`` + +Example:: + >>> x = torch.arange(8) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) + + >>> x = torch.arange(7) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) + >>> torch.tensor_split(x, (1, 6)) + (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) + + >>> x = torch.arange(14).reshape(2, 7) + >>> x + tensor([[ 0, 1, 2, 3, 4, 5, 6], + [ 7, 8, 9, 10, 11, 12, 13]]) + >>> torch.tensor_split(x, 3, dim=1) + (tensor([[0, 1, 2], + [7, 8, 9]]), + tensor([[ 3, 4], + [10, 11]]), + tensor([[ 5, 6], + [12, 13]])) + >>> torch.tensor_split(x, (1, 6), dim=1) + (tensor([[0], + [7]]), + tensor([[ 1, 2, 3, 4, 5], + [ 8, 9, 10, 11, 12]]), + tensor([[ 6], + [13]])) +""") + add_docstr(torch.chunk, r""" chunk(input, chunks, dim=0) -> List of Tensors @@ -1448,6 +1672,11 @@ def merge_dicts(*dicts): Returns a new tensor with the reciprocal of the elements of :attr:`input` +.. note:: + Unlike NumPy's reciprocal, torch.reciprocal supports integral inputs. Integral + inputs to reciprocal are automatically :ref:`promoted ` to + the default scalar type. + .. math:: \text{out}_{i} = \frac{1}{\text{input}_{i}} """ + r""" @@ -1548,6 +1777,9 @@ def merge_dicts(*dicts): batches of 2D matrices. If the inputs are batches, then returns batched outputs `c` +Supports real-valued and complex-valued inputs. +For the complex-valued inputs the transpose operator above is the conjugate transpose. + Args: input (Tensor): input matrix :math:`b` of size :math:`(*, m, k)`, where :math:`*` is zero or more batch dimensions @@ -1651,18 +1883,12 @@ def merge_dicts(*dicts): add_docstr(torch.clamp, r""" clamp(input, min, max, *, out=None) -> Tensor -Clamp all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]` and return -a resulting tensor: +Clamp all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]`. +Let min_value and max_value be :attr:`min` and :attr:`max`, respectively, this returns: .. math:: - y_i = \begin{cases} - \text{min} & \text{if } x_i < \text{min} \\ - x_i & \text{if } \text{min} \leq x_i \leq \text{max} \\ - \text{max} & \text{if } x_i > \text{max} - \end{cases} + y_i = \min(\max(x_i, \text{min\_value}), \text{max\_value}) """ + r""" -If :attr:`input` is of type `FloatTensor` or `DoubleTensor`, args :attr:`min` -and :attr:`max` must be real numbers, otherwise they should be integers. Args: {input} @@ -1684,9 +1910,6 @@ def merge_dicts(*dicts): Clamps all elements in :attr:`input` to be larger or equal :attr:`min`. -If :attr:`input` is of type `FloatTensor` or `DoubleTensor`, :attr:`value` -should be a real number, otherwise it should be an integer. - Args: {input} @@ -1706,9 +1929,6 @@ def merge_dicts(*dicts): Clamps all elements in :attr:`input` to be smaller or equal :attr:`max`. -If :attr:`input` is of type `FloatTensor` or `DoubleTensor`, :attr:`value` -should be a real number, otherwise it should be an integer. - Args: {input} @@ -1731,6 +1951,40 @@ def merge_dicts(*dicts): Alias for :func:`torch.clamp`. """.format(**common_args)) +add_docstr(torch.column_stack, + r""" +column_stack(tensors, *, out=None) -> Tensor + +Creates a new tensor by horizontally stacking the tensors in :attr:`tensors`. + +Equivalent to ``torch.hstack(tensors)``, except each zero or one dimensional tensor ``t`` +in :attr:`tensors` is first reshaped into a ``(t.numel(), 1)`` column before being stacked horizontally. + +Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.column_stack((a, b)) + tensor([[1, 4], + [2, 5], + [3, 6]]) + >>> a = torch.arange(5) + >>> b = torch.arange(10).reshape(5, 2) + >>> torch.column_stack((a, b, b)) + tensor([[0, 0, 1, 0, 1], + [1, 2, 3, 2, 3], + [2, 4, 5, 4, 5], + [3, 6, 7, 6, 7], + [4, 8, 9, 8, 9]]) + +""".format(**common_args)) + add_docstr(torch.complex, r""" complex(real, imag, *, out=None) -> Tensor @@ -1816,6 +2070,53 @@ def merge_dicts(*dicts): tensor([-1 - 1j, -2 - 2j, 3 + 3j]) """.format(**common_args)) +add_docstr(torch.copysign, + r""" +copysign(input, other, *, out=None) -> Tensor + +Create a new floating-point tensor with the magnitude of :attr:`input` and the sign of :attr:`other`, elementwise. + +.. math:: + \text{out}_{i} = \begin{cases} + -|\text{input}_{i}| & \text{if} \text{other}_{i} \leq -0.0 \\ + |\text{input}_{i}| & \text{if} \text{other}_{i} \geq 0.0 \\ + \end{cases} +""" + r""" + +Supports :ref:`broadcasting to a common shape `, +and integer and float inputs. + +Args: + input (Tensor): magnitudes. + other (Tensor or Number): contains value(s) whose signbit(s) are + applied to the magnitudes in :attr:`input`. + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(5) + >>> a + tensor([-1.2557, -0.0026, -0.5387, 0.4740, -0.9244]) + >>> torch.copysign(a, 1) + tensor([1.2557, 0.0026, 0.5387, 0.4740, 0.9244]) + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.7079, 0.2778, -1.0249, 0.5719], + [-0.0059, -0.2600, -0.4475, -1.3948], + [ 0.3667, -0.9567, -2.5757, -0.1751], + [ 0.2046, -0.0742, 0.2998, -0.1054]]) + >>> b = torch.randn(4) + tensor([ 0.2373, 0.3120, 0.3190, -1.1128]) + >>> torch.copysign(a, b) + tensor([[ 0.7079, 0.2778, 1.0249, -0.5719], + [ 0.0059, 0.2600, 0.4475, -1.3948], + [ 0.3667, 0.9567, 2.5757, -0.1751], + [ 0.2046, 0.0742, 0.2998, -0.1054]]) + +""".format(**common_args)) + add_docstr(torch.cos, r""" cos(input, *, out=None) -> Tensor @@ -1863,6 +2164,11 @@ def merge_dicts(*dicts): tensor([ 0.1632, 1.1835, -0.6979, -0.7325]) >>> torch.cosh(a) tensor([ 1.0133, 1.7860, 1.2536, 1.2805]) + +.. note:: + When :attr:`input` is on the CPU, the implementation of torch.cosh may use + the Sleef library, which rounds very large results to infinity or negative + infinity. See `here `_ for details. """.format(**common_args)) add_docstr(torch.cross, @@ -2332,8 +2638,7 @@ def merge_dicts(*dicts): [ 1.0500, 0.7336, -0.3836, -1.1015]]]) """.format(**common_args)) -add_docstr(torch.digamma, - r""" +add_docstr(torch.digamma, r""" digamma(input, *, out=None) -> Tensor Computes the logarithmic derivative of the gamma function on `input`. @@ -2347,6 +2652,11 @@ def merge_dicts(*dicts): Keyword args: {out} +.. note:: This function is similar to SciPy's `scipy.special.digamma`. + +.. note:: From PyTorch 1.8 onwards, the digamma function returns `-Inf` for `0`. + Previously it returned `NaN` for `0`. + Example:: >>> a = torch.tensor([1, 0.5]) @@ -2442,11 +2752,21 @@ def merge_dicts(*dicts): add_docstr(torch.dot, r""" -dot(input, tensor) -> Tensor +dot(input, other, *, out=None) -> Tensor -Computes the dot product (inner product) of two tensors. +Computes the dot product of two 1D tensors. -.. note:: This function does not :ref:`broadcast `. +.. note:: + + Unlike NumPy's dot, torch.dot intentionally only supports computing the dot product + of two 1D tensors with the same number of elements. + +Args: + input (Tensor): first tensor in the dot product, must be 1D. + other (Tensor): second tensor in the dot product, must be 1D. + +Keyword args: + {out} Example:: @@ -2458,15 +2778,18 @@ def merge_dicts(*dicts): r""" vdot(input, other, *, out=None) -> Tensor -Computes the dot product (inner product) of two tensors. The vdot(a, b) function -handles complex numbers differently than dot(a, b). If the first argument is complex, -the complex conjugate of the first argument is used for the calculation of the dot product. +Computes the dot product of two 1D tensors. The vdot(a, b) function handles complex numbers +differently than dot(a, b). If the first argument is complex, the complex conjugate of the +first argument is used for the calculation of the dot product. -.. note:: This function does not :ref:`broadcast `. +.. note:: + + Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product + of two 1D tensors with the same number of elements. Args: - input (Tensor): first tensor in the dot product. Its conjugate is used if it's complex. - other (Tensor): second tensor in the dot product. + input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex. + other (Tensor): second tensor in the dot product, must be 1D. Keyword args: {out} @@ -2491,7 +2814,10 @@ def merge_dicts(*dicts): .. note:: Since eigenvalues and eigenvectors might be complex, backward pass is supported only - for :func:`torch.symeig` + if eigenvalues and eigenvectors are all real valued. + + When :attr:`input` is on CUDA, :func:`torch.eig() ` causes + host-device synchronization. Args: input (Tensor): the square matrix of shape :math:`(n \times n)` for which the eigenvalues and eigenvectors @@ -2517,6 +2843,32 @@ def merge_dicts(*dicts): true eigenvectors can be computed as :math:`\text{true eigenvector}[j] = eigenvectors[:, j] + i \times eigenvectors[:, j + 1]`, :math:`\text{true eigenvector}[j + 1] = eigenvectors[:, j] - i \times eigenvectors[:, j + 1]`. + +Example:: + + Trivial example with a diagonal matrix. By default, only eigenvalues are computed: + + >>> a = torch.diag(torch.tensor([1, 2, 3], dtype=torch.double)) + >>> e, v = torch.eig(a) + >>> e + tensor([[1., 0.], + [2., 0.], + [3., 0.]], dtype=torch.float64) + >>> v + tensor([], dtype=torch.float64) + + Compute also the eigenvectors: + + >>> e, v = torch.eig(a, eigenvectors=True) + >>> e + tensor([[1., 0.], + [2., 0.], + [3., 0.]], dtype=torch.float64) + >>> v + tensor([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]], dtype=torch.float64) + """) add_docstr(torch.eq, r""" @@ -2688,7 +3040,6 @@ def merge_dicts(*dicts): tensor([ 0., 1.]) """.format(**common_args)) -# TODO: see https://github.com/pytorch/pytorch/issues/43667 add_docstr(torch.eye, r""" eye(n, m=None, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor @@ -2698,6 +3049,8 @@ def merge_dicts(*dicts): Args: n (int): the number of rows m (int, optional): the number of columns with default being :attr:`n` + +Keyword arguments: {out} {dtype} {layout} @@ -2740,20 +3093,27 @@ def merge_dicts(*dicts): tensor([-1., 1., -1., -1.]) """.format(**common_args)) -add_docstr(torch.floor_divide, - r""" +add_docstr(torch.floor_divide, r""" floor_divide(input, other, *, out=None) -> Tensor -Return the division of the inputs rounded down to the nearest integer. See :func:`torch.div` -for type promotion and broadcasting rules. +.. warning:: + This function's name is a misnomer. It actually rounds the + quotient towards zero instead of taking its floor. This behavior + will be deprecated in a future PyTorch release. + +Computes :attr:`input` divided by :attr:`other`, elementwise, and rounds each +quotient towards zero. Equivalently, it truncates the quotient(s): .. math:: - \text{{out}}_i = \left\lfloor \frac{{\text{{input}}_i}}{{\text{{other}}_i}} \right\rfloor + \text{{out}}_i = \text{trunc} \left( \frac{{\text{{input}}_i}}{{\text{{other}}_i}} \right) """ + r""" + +Supports broadcasting to a common shape, type promotion, and integer and float inputs. + Args: - input (Tensor): the numerator tensor - other (Tensor or Scalar): the denominator + input (Tensor or Number): the dividend + other (Tensor or Number): the divisor Keyword args: {out} @@ -2777,12 +3137,18 @@ def merge_dicts(*dicts): The dividend and divisor may contain both for integer and floating point numbers. The remainder has the same sign as the dividend :attr:`input`. -When :attr:`other` is a tensor, the shapes of :attr:`input` and -:attr:`other` must be :ref:`broadcastable `. +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer and float inputs. + +.. note:: + + When the divisor is zero, returns ``NaN`` for floating point dtypes + on both CPU and GPU; raises ``RuntimeError`` for integer division by + zero on CPU; Integer division by zero on GPU may return any value. Args: input (Tensor): the dividend - other (Tensor or float): the divisor, which may be either a number or a tensor of the same shape as the dividend + other (Tensor or Scalar): the divisor Keyword args: {out} @@ -2791,9 +3157,8 @@ def merge_dicts(*dicts): >>> torch.fmod(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) tensor([-1., -0., -1., 1., 0., 1.]) - >>> torch.fmod(torch.tensor([1., 2, 3, 4, 5]), 1.5) - tensor([ 1.0000, 0.5000, 0.0000, 1.0000, 0.5000]) - + >>> torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5) + tensor([1.0000, 0.5000, 0.0000, 1.0000, 0.5000]) """.format(**common_args)) @@ -2842,7 +3207,17 @@ def merge_dicts(*dicts): r""" flatten(input, start_dim=0, end_dim=-1) -> Tensor -Flattens a contiguous range of dims in a tensor. +Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` +are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. +The order of elements in :attr:`input` is unchanged. + +Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, +or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can +be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the +flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + +.. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. Args: {input} @@ -2862,7 +3237,6 @@ def merge_dicts(*dicts): [5, 6, 7, 8]]) """.format(**common_args)) -# TODO: see https://github.com/pytorch/pytorch/issues/43667 add_docstr(torch.gather, r""" gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor @@ -2875,23 +3249,24 @@ def merge_dicts(*dicts): out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 -If :attr:`input` is an n-dimensional tensor with size -:math:`(x_0, x_1..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` -and ``dim = i``, then :attr:`index` must be an :math:`n`-dimensional tensor with -size :math:`(x_0, x_1, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})` where :math:`y \geq 1` -and :attr:`out` will have the same size as :attr:`index`. -""" + r""" +:attr:`input` and :attr:`index` must have the same number of dimensions. +It is also required that ``index.size(d) <= input.size(d)`` for all +dimensions ``d != dim``. :attr:`out` will have the same shape as :attr:`index`. +Note that ``input`` and ``index`` do not broadcast against each other. + Args: input (Tensor): the source tensor dim (int): the axis along which to index index (LongTensor): the indices of elements to gather - sparse_grad(bool,optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor. + +Keyword arguments: + sparse_grad (bool, optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor. out (Tensor, optional): the destination tensor Example:: - >>> t = torch.tensor([[1,2],[3,4]]) - >>> torch.gather(t, 1, torch.tensor([[0,0],[1,0]])) + >>> t = torch.tensor([[1, 2], [3, 4]]) + >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])) tensor([[ 1, 1], [ 4, 3]]) """) @@ -2985,6 +3360,64 @@ def merge_dicts(*dicts): """) +add_docstr(torch.inner, r""" +inner(input, other, *, out=None) -> Tensor + +Computes the dot product for 1D tensors. For higher dimensions, sums the product +of elements from :attr:`input` and :attr:`other` along their last dimension. + +.. note:: + + If either :attr:`input` or :attr:`other` is a scalar, the result is equivalent + to `torch.mul(input, other)`. + + If both :attr:`input` and :attr:`other` are non-scalars, the size of their last + dimension must match and the result is equivalent to `torch.tensordot(input, + other, dims=([-1], [-1]))` + +Args: + input (Tensor): First input tensor + other (Tensor): Second input tensor + +Keyword args: + out (Tensor, optional): Optional output tensor to write result into. The output + shape is `input.shape[:-1] + other.shape[:-1]`. + +Example:: + + # Dot product + >>> torch.inner(torch.tensor([1, 2, 3]), torch.tensor([0, 2, 1])) + tensor(7) + + # Multidimensional input tensors + >>> a = torch.randn(2, 3) + >>> a + tensor([[0.8173, 1.0874, 1.1784], + [0.3279, 0.1234, 2.7894]]) + >>> b = torch.randn(2, 4, 3) + >>> b + tensor([[[-0.4682, -0.7159, 0.1506], + [ 0.4034, -0.3657, 1.0387], + [ 0.9892, -0.6684, 0.1774], + [ 0.9482, 1.3261, 0.3917]], + + [[ 0.4537, 0.7493, 1.1724], + [ 0.2291, 0.5749, -0.2267], + [-0.7920, 0.3607, -0.3701], + [ 1.3666, -0.5850, -1.7242]]]) + >>> torch.inner(a, b) + tensor([[[-0.9837, 1.1560, 0.2907, 2.6785], + [ 2.5671, 0.5452, -0.6912, -1.5509]], + + [[ 0.1782, 2.9843, 0.7366, 1.5672], + [ 3.5115, -0.4864, -1.2476, -4.4337]]]) + + # Scalar input + >>> torch.inner(a, torch.tensor(2)) + tensor([[1.6347, 2.1748, 2.3567], + [0.6558, 0.2469, 5.5787]]) +""") + add_docstr(torch.outer, r""" outer(input, vec2, *, out=None) -> Tensor @@ -3037,6 +3470,8 @@ def merge_dicts(*dicts): batches of 2D matrices. If the inputs are batches, then returns batched outputs `solution, LU`. +Supports real-valued and complex-valued inputs. + .. note:: Irrespective of the original strides, the returned matrices @@ -3222,35 +3657,126 @@ def merge_dicts(*dicts): """.format(**common_args)) -add_docstr(torch.index_select, +add_docstr(torch.igamma, r""" -index_select(input, dim, index, *, out=None) -> Tensor +igamma(input, other, *, out=None) -> Tensor -Returns a new tensor which indexes the :attr:`input` tensor along dimension -:attr:`dim` using the entries in :attr:`index` which is a `LongTensor`. +Computes the regularized lower incomplete gamma function: -The returned tensor has the same number of dimensions as the original tensor -(:attr:`input`). The :attr:`dim`\ th dimension has the same size as the length -of :attr:`index`; other dimensions have the same size as in the original tensor. +.. math:: + \text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_0^{\text{other}_i} t^{\text{input}_i-1} e^{-t} dt -.. note:: The returned tensor does **not** use the same storage as the original - tensor. If :attr:`out` has a different shape than expected, we - silently change it to the correct shape, reallocating the underlying - storage if necessary. +where both :math:`\text{input}_i` and :math:`\text{other}_i` are weakly positive +and at least one is strictly positive. +If both are zero or either is negative then :math:`\text{out}_i=\text{nan}`. +:math:`\Gamma(\cdot)` in the equation above is the gamma function, + +.. math:: + \Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt. + +See :func:`torch.igammac` and :func:`torch.lgamma` for related functions. + +Supports :ref:`broadcasting to a common shape ` +and float inputs. +.. note:: + The backward pass with respect to :attr:`input` is not yet supported. + Please open an issue on PyTorch's Github to request it. + +""" + r""" Args: - {input} - dim (int): the dimension in which we index - index (LongTensor): the 1-D tensor containing the indices to index + input (Tensor): the first non-negative input tensor + other (Tensor): the second non-negative input tensor Keyword args: {out} Example:: - >>> x = torch.randn(3, 4) - >>> x - tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + >>> a1 = torch.tensor([4.0]) + >>> a2 = torch.tensor([3.0, 4.0, 5.0]) + >>> a = torch.igammac(a1, a2) + tensor([0.3528, 0.5665, 0.7350]) + tensor([0.3528, 0.5665, 0.7350]) + >>> b = torch.igamma(a1, a2) + torch.igammac(a1, a2) + tensor([1., 1., 1.]) + +""".format(**common_args)) + +add_docstr(torch.igammac, + r""" +igammac(input, other, *, out=None) -> Tensor + +Computes the regularized upper incomplete gamma function: + +.. math:: + \text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_{\text{other}_i}^{\infty} t^{\text{input}_i-1} e^{-t} dt + +where both :math:`\text{input}_i` and :math:`\text{other}_i` are weakly positive +and at least one is strictly positive. +If both are zero or either is negative then :math:`\text{out}_i=\text{nan}`. +:math:`\Gamma(\cdot)` in the equation above is the gamma function, + +.. math:: + \Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt. + +See :func:`torch.igamma` and :func:`torch.lgamma` for related functions. + +Supports :ref:`broadcasting to a common shape ` +and float inputs. + +.. note:: + The backward pass with respect to :attr:`input` is not yet supported. + Please open an issue on PyTorch's Github to request it. + +""" + r""" +Args: + input (Tensor): the first non-negative input tensor + other (Tensor): the second non-negative input tensor + +Keyword args: + {out} + +Example:: + + >>> a1 = torch.tensor([4.0]) + >>> a2 = torch.tensor([3.0, 4.0, 5.0]) + >>> a = torch.igammac(a1, a2) + tensor([0.6472, 0.4335, 0.2650]) + >>> b = torch.igamma(a1, a2) + torch.igammac(a1, a2) + tensor([1., 1., 1.]) + +""".format(**common_args)) + +add_docstr(torch.index_select, + r""" +index_select(input, dim, index, *, out=None) -> Tensor + +Returns a new tensor which indexes the :attr:`input` tensor along dimension +:attr:`dim` using the entries in :attr:`index` which is a `LongTensor`. + +The returned tensor has the same number of dimensions as the original tensor +(:attr:`input`). The :attr:`dim`\ th dimension has the same size as the length +of :attr:`index`; other dimensions have the same size as in the original tensor. + +.. note:: The returned tensor does **not** use the same storage as the original + tensor. If :attr:`out` has a different shape than expected, we + silently change it to the correct shape, reallocating the underlying + storage if necessary. + +Args: + {input} + dim (int): the dimension in which we index + index (IntTensor or LongTensor): the 1-D tensor containing the indices to index + +Keyword args: + {out} + +Example:: + + >>> x = torch.randn(3, 4) + >>> x + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], [-0.4664, 0.2647, -0.1228, -1.1068], [-1.1734, -0.6571, 0.7230, -0.6004]]) >>> indices = torch.tensor([0, 2]) @@ -3271,6 +3797,8 @@ def merge_dicts(*dicts): of 2D square tensors, in which case this function would return a tensor composed of individual inverses. +Supports real and complex input. + .. note:: Irrespective of the original strides, the returned tensors will be @@ -3283,7 +3811,7 @@ def merge_dicts(*dicts): Keyword args: {out} -Example:: +Examples:: >>> x = torch.rand(4, 4) >>> y = torch.inverse(x) @@ -3295,12 +3823,29 @@ def merge_dicts(*dicts): [ 0.0000, -0.0000, -0.0000, 1.0000]]) >>> torch.max(torch.abs(z - torch.eye(4))) # Max non-zero tensor(1.1921e-07) + >>> # Batched inverse example >>> x = torch.randn(2, 3, 4, 4) >>> y = torch.inverse(x) >>> z = torch.matmul(x, y) >>> torch.max(torch.abs(z - torch.eye(4).expand_as(x))) # Max non-zero tensor(1.9073e-06) + + >>> x = torch.rand(4, 4, dtype=torch.cdouble) + >>> y = torch.inverse(x) + >>> z = torch.mm(x, y) + >>> z + tensor([[ 1.0000e+00+0.0000e+00j, -1.3878e-16+3.4694e-16j, + 5.5511e-17-1.1102e-16j, 0.0000e+00-1.6653e-16j], + [ 5.5511e-16-1.6653e-16j, 1.0000e+00+6.9389e-17j, + 2.2204e-16-1.1102e-16j, -2.2204e-16+1.1102e-16j], + [ 3.8858e-16-1.2490e-16j, 2.7756e-17+3.4694e-17j, + 1.0000e+00+0.0000e+00j, -4.4409e-16+5.5511e-17j], + [ 4.4409e-16+5.5511e-16j, -3.8858e-16+1.8041e-16j, + 2.2204e-16+0.0000e+00j, 1.0000e+00-3.4694e-16j]], + dtype=torch.complex128) + >>> torch.max(torch.abs(z - torch.eye(4, dtype=torch.cdouble))) # Max non-zero + tensor(7.5107e-16, dtype=torch.float64) """.format(**common_args)) add_docstr(torch.isinf, r""" @@ -3313,7 +3858,7 @@ def merge_dicts(*dicts): Complex values are infinite when their real or imaginary part is infinite. - Arguments: + Args: {input} Returns: @@ -3398,7 +3943,7 @@ def merge_dicts(*dicts): Real values are finite when they are not NaN, negative infinity, or infinity. Complex values are finite when both their real and imaginary parts are finite. - Arguments: + Args: {input} Returns: @@ -3451,7 +3996,7 @@ def merge_dicts(*dicts): is_floating_point(input) -> (bool) Returns True if the data type of :attr:`input` is a floating point data type i.e., -one of ``torch.float64``, ``torch.float32`` and ``torch.float16``. +one of ``torch.float64``, ``torch.float32``, ``torch.float16``, and ``torch.bfloat16``. Args: {input} @@ -3500,6 +4045,64 @@ def merge_dicts(*dicts): RuntimeError: bool value of Tensor with no values is ambiguous """.format(**common_args)) +add_docstr(torch.kron, + r""" +kron(input, other, *, out=None) -> Tensor + +Computes the Kronecker product, denoted by :math:`\otimes`, of :attr:`input` and :attr:`other`. + +If :attr:`input` is a :math:`(a_0 \times a_1 \times \dots \times a_n)` tensor and :attr:`other` is a +:math:`(b_0 \times b_1 \times \dots \times b_n)` tensor, the result will be a +:math:`(a_0*b_0 \times a_1*b_1 \times \dots \times a_n*b_n)` tensor with the following entries: + +.. math:: + (\text{input} \otimes \text{other})_{k_0, k_1, \dots, k_n} = + \text{input}_{i_0, i_1, \dots, i_n} * \text{other}_{j_0, j_1, \dots, j_n}, + +where :math:`k_t = i_t * b_t + j_t` for :math:`0 \leq t \leq n`. +If one tensor has fewer dimensions than the other it is unsqueezed until it has the same number of dimensions. + +Supports real-valued and complex-valued inputs. + +.. note:: + This function generalizes the typical definition of the Kronecker product for two matrices to two tensors, + as described above. When :attr:`input` is a :math:`(m \times n)` matrix and :attr:`other` is a + :math:`(p \times q)` matrix, the result will be a :math:`(p*m \times q*n)` block matrix: + + .. math:: + \mathbf{A} \otimes \mathbf{B}=\begin{bmatrix} + a_{11} \mathbf{B} & \cdots & a_{1 n} \mathbf{B} \\ + \vdots & \ddots & \vdots \\ + a_{m 1} \mathbf{B} & \cdots & a_{m n} \mathbf{B} \end{bmatrix} + + where :attr:`input` is :math:`\mathbf{A}` and :attr:`other` is :math:`\mathbf{B}`. + +Arguments: + input (Tensor) + other (Tensor) + +Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + +Examples:: + + >>> mat1 = torch.eye(2) + >>> mat2 = torch.ones(2, 2) + >>> torch.kron(mat1, mat2) + tensor([[1., 1., 0., 0.], + [1., 1., 0., 0.], + [0., 0., 1., 1.], + [0., 0., 1., 1.]]) + + >>> mat1 = torch.eye(2) + >>> mat2 = torch.arange(1, 5).reshape(2, 2) + >>> torch.kron(mat1, mat2) + tensor([[1., 2., 0., 0.], + [3., 4., 0., 0.], + [0., 0., 1., 2.], + [0., 0., 3., 4.]]) +""") + add_docstr(torch.kthvalue, r""" kthvalue(input, k, dim=None, keepdim=False, *, out=None) -> (Tensor, LongTensor) @@ -3516,6 +4119,11 @@ def merge_dicts(*dicts): (see :func:`torch.squeeze`), resulting in both the :attr:`values` and :attr:`indices` tensors having 1 fewer dimension than the :attr:`input` tensor. +.. note:: + When :attr:`input` is a CUDA tensor and there are multiple valid + :attr:`k` th values, this function may nondeterministically return + :attr:`indices` for any of them. + Args: {input} k (int): k for the k-th smallest element @@ -3571,6 +4179,35 @@ def merge_dicts(*dicts): tensor([15, 30, 15]) """.format(**common_args)) +add_docstr(torch.ldexp, r""" +ldexp(input, other, *, out=None) -> Tensor + +Multiplies :attr:`input` by 2**:attr:`other`. + +.. math:: + \text{{out}}_i = \text{{input}}_i * 2^\text{{other}}_i +""" + r""" + +Typically this function is used to construct floating point numbers by multiplying +mantissas in :attr:`input` with integral powers of two created from the exponents +in :attr:'other'. + +Args: + {input} + other (Tensor): a tensor of exponents, typically integers. + +Keyword args: + {out} + +Example:: + >>> torch.ldexp(torch.tensor([1.]), torch.tensor([1])) + tensor([2.]) + >>> torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])) + tensor([ 2., 4., 8., 16.]) + + +""".format(**common_args)) + add_docstr(torch.le, r""" le(input, other, *, out=None) -> Tensor @@ -3659,7 +4296,6 @@ def merge_dicts(*dicts): tensor([ 0.5724, 0.0000, -0.1208]) """.format(**common_args)) -# TODO: update kwargs formatting (see https://github.com/pytorch/pytorch/issues/43667) add_docstr(torch.linspace, r""" linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor @@ -3668,9 +4304,9 @@ def merge_dicts(*dicts): .. math:: (\text{start}, - \text{start} + \frac{\text{end} - \text{start}}{\text{steps}}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, \ldots, - \text{start} + (\text{steps} - 1) * \frac{\text{end} - \text{start}}{\text{steps}}, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, \text{end}) """ + """ @@ -3686,6 +4322,8 @@ def merge_dicts(*dicts): start (float): the starting value for the set of points end (float): the ending value for the set of points steps (int): size of the constructed tensor + +Keyword arguments: {out} {dtype} {layout} @@ -3865,6 +4503,48 @@ def merge_dicts(*dicts): {out} """.format(**common_args)) +add_docstr(torch.xlogy, + r""" +xlogy(input, other, *, out=None) -> Tensor + +Computes ``input * log(other)`` with the following cases. + +.. math:: + \text{out}_{i} = \begin{cases} + \text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\ + 0 & \text{if } \text{input}_{i} = 0.0 \\ + \text{input}_{i} * \log{(\text{other}_{i})} & \text{otherwise} + \end{cases} + +Similar to SciPy's `scipy.special.xlogy`. + +""" + r""" + +Args: + input (Number or Tensor) + other (Number or Tensor) + +.. note:: At least one of :attr:`input` or :attr:`other` must be a tensor. + +Keyword args: + {out} + +Example:: + + >>> x = torch.zeros(5,) + >>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')]) + >>> torch.xlogy(x, y) + tensor([0., 0., 0., 0., nan]) + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([3, 2, 1]) + >>> torch.xlogy(x, y) + tensor([1.0986, 1.3863, 0.0000]) + >>> torch.xlogy(x, 4) + tensor([1.3863, 2.7726, 4.1589]) + >>> torch.xlogy(2, y) + tensor([2.1972, 1.3863, 0.0000]) +""".format(**common_args)) + add_docstr(torch.logical_and, r""" logical_and(input, other, *, out=None) -> Tensor @@ -3980,7 +4660,6 @@ def merge_dicts(*dicts): tensor([ True, True, False, False]) """.format(**common_args)) -# TODO: update kwargs formatting (see https://github.com/pytorch/pytorch/issues/43667) add_docstr(torch.logspace, """ logspace(start, end, steps, base=10.0, *, \ out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor @@ -3993,9 +4672,9 @@ def merge_dicts(*dicts): .. math:: (\text{base}^{\text{start}}, - \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps}})}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, \ldots, - \text{base}^{(\text{start} + (\text{steps} - 1) * \frac{\text{end} - \text{start}}{ \text{steps}})}, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, \text{base}^{\text{end}}) """ + """ @@ -4011,7 +4690,9 @@ def merge_dicts(*dicts): start (float): the starting value for the set of points end (float): the ending value for the set of points steps (int): size of the constructed tensor - base (float): base of the logarithm function. Default: ``10.0``. + base (float, optional): base of the logarithm function. Default: ``10.0``. + +Keyword arguments: {out} {dtype} {layout} @@ -4165,11 +4846,13 @@ def merge_dicts(*dicts): add_docstr(torch.lu_solve, r""" -lu_solve(input, LU_data, LU_pivots, *, out=None) -> Tensor +lu_solve(b, LU_data, LU_pivots, *, out=None) -> Tensor Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted LU factorization of A from :meth:`torch.lu`. +This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`. + Arguments: b (Tensor): the RHS tensor of size :math:`(*, m, k)`, where :math:`*` is zero or more batch dimensions. @@ -4297,16 +4980,15 @@ def merge_dicts(*dicts): add_docstr(torch.matrix_exp, r""" -matrix_power(input) -> Tensor - Returns the matrix exponential. Supports batched input. For a matrix ``A``, the matrix exponential is defined as .. math:: - \exp^A = \sum_{k=0}^\infty A^k / k!. + \mathrm{e}^A = \sum_{k=0}^\infty A^k / k! """ + r""" The implementation is based on: + Bader, P.; Blanes, S.; Casas, F. Computing the Matrix Exponential with an Optimized Taylor Polynomial Approximation. Mathematics 2019, 7, 1174. @@ -4566,13 +5248,12 @@ def merge_dicts(*dicts): r""" median(input) -> Tensor -Returns the median value of all elements in the :attr:`input` tensor. +Returns the median of the values in :attr:`input`. .. note:: The median is not unique for :attr:`input` tensors with an even number of elements. In this case the lower of the two medians is returned. To - compute the mean of both medians in :attr:`input`, use :func:`torch.quantile` - with ``q=0.5`` instead. + compute the mean of both medians, use :func:`torch.quantile` with ``q=0.5`` instead. .. warning:: This function produces deterministic (sub)gradients unlike ``median(dim=0)`` @@ -4590,9 +5271,8 @@ def merge_dicts(*dicts): .. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) -Returns a namedtuple ``(values, indices)`` where ``values`` is the median -value of each row of the :attr:`input` tensor in the given dimension -:attr:`dim`. And ``indices`` is the index location of each median value found. +Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` +in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`. By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. @@ -4614,14 +5294,15 @@ def merge_dicts(*dicts): Do not expect the same result when run on CPU and GPU in general. For the same reason do not expect the gradients to be deterministic. - Args: {input} {dim} {keepdim} Keyword args: - out (tuple, optional): the result tuple of two output tensors (max, max_indices) + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. Example:: @@ -4635,6 +5316,60 @@ def merge_dicts(*dicts): torch.return_types.median(values=tensor([-0.3982, 0.2270, 0.2488, 0.4742]), indices=tensor([1, 4, 4, 3])) """.format(**single_dim_common)) +add_docstr(torch.nanmedian, + r""" +nanmedian(input) -> Tensor + +Returns the median of the values in :attr:`input`, ignoring ``NaN`` values. + +This function is identical to :func:`torch.median` when there are no ``NaN`` values in :attr:`input`. +When :attr:`input` has one or more ``NaN`` values, :func:`torch.median` will always return ``NaN``, +while this function will return the median of the non-``NaN`` elements in :attr:`input`. +If all the elements in :attr:`input` are ``NaN`` it will also return ``NaN``. + +Args: + {input} + +Example:: + + >>> a = torch.tensor([1, float('nan'), 3, 2]) + >>> a.median() + tensor(nan) + >>> a.nanmedian() + tensor(2.) + +.. function:: nanmedian(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + +Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` +in the dimension :attr:`dim`, ignoring ``NaN`` values, and ``indices`` contains the index of the median values +found in the dimension :attr:`dim`. + +This function is identical to :func:`torch.median` when there are no ``NaN`` values in a reduced row. When a reduced row has +one or more ``NaN`` values, :func:`torch.median` will always reduce it to ``NaN``, while this function will reduce it to the +median of the non-``NaN`` elements. If all the elements in a reduced row are ``NaN`` then it will be reduced to ``NaN``, too. + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + +Example:: + + >>> a = torch.tensor([[2, 3, 1], [float('nan'), 1, float('nan')]]) + >>> a + tensor([[2., 3., 1.], + [nan, 1., nan]]) + >>> a.median(0) + torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1])) + >>> a.nanmedian(0) + torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0])) +""".format(**single_dim_common)) + add_docstr(torch.quantile, r""" quantile(input, q) -> Tensor @@ -4858,36 +5593,15 @@ def merge_dicts(*dicts): add_docstr(torch.argmin, r""" -argmin(input) -> LongTensor +argmin(input, dim=None, keepdim=False) -> LongTensor -Returns the indices of the minimum value of all elements in the :attr:`input` tensor. +Returns the indices of the minimum value(s) of the flattened tensor or along a dimension This is the second value returned by :meth:`torch.min`. See its documentation for the exact semantics of this method. .. note:: If there are multiple minimal values then the indices of the first minimal value are returned. -Args: - {input} - -Example:: - - >>> a = torch.randn(4, 4) - >>> a - tensor([[ 0.1139, 0.2254, -0.1381, 0.3687], - [ 1.0100, -1.1975, -0.0102, -0.4732], - [-0.9240, 0.1207, -0.7506, -1.0213], - [ 1.7809, -1.2960, 0.9384, 0.1438]]) - >>> torch.argmin(a) - tensor(13) - -.. function:: argmin(input, dim, keepdim=False) -> LongTensor - -Returns the indices of the minimum values of a tensor across a dimension. - -This is the second value returned by :meth:`torch.min`. See its -documentation for the exact semantics of this method. - Args: {input} {dim} If ``None``, the argmin of the flattened input is returned. @@ -4901,8 +5615,15 @@ def merge_dicts(*dicts): [ 1.0100, -1.1975, -0.0102, -0.4732], [-0.9240, 0.1207, -0.7506, -1.0213], [ 1.7809, -1.2960, 0.9384, 0.1438]]) + >>> torch.argmin(a) + tensor(13) >>> torch.argmin(a, dim=1) tensor([ 2, 1, 3, 1]) + >>> torch.argmin(a, dim=1, keepdim=True) + tensor([[2], + [1], + [3], + [1]]) """.format(**single_dim_common)) add_docstr(torch.mm, @@ -4917,11 +5638,14 @@ def merge_dicts(*dicts): .. note:: This function does not :ref:`broadcast `. For broadcasting matrix products, see :func:`torch.matmul`. +Supports strided and sparse 2-D tensors as inputs, autograd with +respect to strided inputs. + {tf32_note} Args: - input (Tensor): the first matrix to be multiplied - mat2 (Tensor): the second matrix to be multiplied + input (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied Keyword args: {out} @@ -4935,6 +5659,23 @@ def merge_dicts(*dicts): [-0.0760, -3.6705, 2.4784]]) """.format(**common_args, **tf32_notes)) +add_docstr(torch.hspmm, + r""" +hspmm(mat1, mat2, *, out=None) -> Tensor + +Performs a matrix multiplication of a :ref:`sparse COO matrix +` :attr:`mat1` and a strided matrix :attr:`mat2`. The +result is a (1 + 1)-dimensional :ref:`hybrid COO matrix +`. + +Args: + mat1 (Tensor): the first sparse matrix to be matrix multiplied + mat2 (Tensor): the second strided matrix to be matrix multiplied + +Keyword args: + {out} +""") + add_docstr(torch.matmul, r""" matmul(input, other, *, out=None) -> Tensor @@ -4957,8 +5698,14 @@ def merge_dicts(*dicts): 1 is appended to its dimension for the purpose of the batched matrix multiple and removed after. The non-matrix (i.e. batch) dimensions are :ref:`broadcasted ` (and thus must be broadcastable). For example, if :attr:`input` is a + :math:`(j \times 1 \times n \times n)` tensor and :attr:`other` is a :math:`(k \times n \times n)` + tensor, :attr:`out` will be a :math:`(j \times k \times n \times n)` tensor. + + Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs + are broadcastable, and not the matrix dimensions. For example, if :attr:`input` is a :math:`(j \times 1 \times n \times m)` tensor and :attr:`other` is a :math:`(k \times m \times p)` - tensor, :attr:`out` will be an :math:`(j \times k \times n \times p)` tensor. + tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the + matrix dimensions) are different. :attr:`out` will be a :math:`(j \times k \times n \times p)` tensor. {tf32_note} @@ -5262,6 +6009,103 @@ def merge_dicts(*dicts): [[-0.8437, 0.1727, -0.1398]]]) """.format(**common_args)) +add_docstr(torch.moveaxis, r""" +moveaxis(input, source, destination) -> Tensor + +Alias for :func:`torch.movedim`. + +This function is equivalent to NumPy's moveaxis function. + +Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.moveaxis(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.moveaxis(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.moveaxis(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.moveaxis(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) +""".format(**common_args)) + +add_docstr(torch.swapdims, r""" +swapdims(input, dim0, dim1) -> Tensor + +Alias for :func:`torch.transpose`. + +This function is equivalent to NumPy's swapaxes function. + +Examples:: + + >>> x = torch.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.swapdims(x, 0, 1) + tensor([[[0, 1], + [4, 5]], + + [[2, 3], + [6, 7]]]) + >>> torch.swapdims(x, 0, 2) + tensor([[[0, 4], + [2, 6]], + + [[1, 5], + [3, 7]]]) +""".format(**common_args)) + +add_docstr(torch.swapaxes, r""" +swapaxes(input, axis0, axis1) -> Tensor + +Alias for :func:`torch.transpose`. + +This function is equivalent to NumPy's swapaxes function. + +Examples:: + + >>> x = torch.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.swapaxes(x, 0, 1) + tensor([[[0, 1], + [4, 5]], + + [[2, 3], + [6, 7]]]) + >>> torch.swapaxes(x, 0, 2) + tensor([[[0, 4], + [2, 6]], + + [[1, 5], + [3, 7]]]) +""".format(**common_args)) + add_docstr(torch.narrow, r""" narrow(input, dim, start, length) -> Tensor @@ -5288,6 +6132,41 @@ def merge_dicts(*dicts): [ 8, 9]]) """) +add_docstr(torch.nan_to_num, + r""" +nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None) -> Tensor + +Replaces :literal:`NaN`, positive infinity, and negative infinity values in :attr:`input` +with the values specified by :attr:`nan`, :attr:`posinf`, and :attr:`neginf`, respectively. +By default, :literal:`NaN`s are replaced with zero, positive infinity is replaced with the +greatest finite value representable by :attr:`input`'s dtype, and negative infinity +is replaced with the least finite value representable by :attr:`input`'s dtype. + +Args: + {input} + nan (Number, optional): the value to replace :literal:`NaN`\s with. Default is zero. + posinf (Number, optional): if a Number, the value to replace positive infinity values with. + If None, positive infinity values are replaced with the greatest finite value representable by :attr:`input`'s dtype. + Default is None. + neginf (Number, optional): if a Number, the value to replace negative infinity values with. + If None, negative infinity values are replaced with the lowest finite value representable by :attr:`input`'s dtype. + Default is None. + +Keyword args: + {out} + +Example:: + + >>> x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14]) + >>> torch.nan_to_num(x) + tensor([ 0.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00]) + >>> torch.nan_to_num(x, nan=2.0) + tensor([ 2.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00]) + >>> torch.nan_to_num(x, nan=2.0, posinf=1.0) + tensor([ 2.0000e+00, 1.0000e+00, -3.4028e+38, 3.1400e+00]) + +""".format(**common_args)) + add_docstr(torch.ne, r""" ne(input, other, *, out=None) -> Tensor @@ -5559,7 +6438,6 @@ def merge_dicts(*dicts): """.format(**common_args)) -# TODO: see https://github.com/pytorch/pytorch/issues/43667 add_docstr(torch.ones, r""" ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor @@ -5570,6 +6448,8 @@ def merge_dicts(*dicts): Args: size (int...): a sequence of integers defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple. + +Keyword arguments: {out} {dtype} {layout} @@ -5587,7 +6467,6 @@ def merge_dicts(*dicts): """.format(**factory_common_args)) -# TODO: see https://github.com/pytorch/pytorch/issues/43667 add_docstr(torch.ones_like, r""" ones_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor @@ -5603,6 +6482,8 @@ def merge_dicts(*dicts): Args: {input} + +Keyword arguments: {dtype} {layout} {device} @@ -5658,7 +6539,7 @@ def merge_dicts(*dicts): add_docstr(torch.poisson, r""" -poisson(input *, generator=None) -> Tensor +poisson(input, generator=None) -> Tensor Returns a tensor of the same size as :attr:`input` with each element sampled from a Poisson distribution with rate parameter given by the corresponding @@ -5787,6 +6668,46 @@ def merge_dicts(*dicts): tensor([ 2., 4., 8., 16.]) """.format(**common_args)) +add_docstr(torch.float_power, + r""" +float_power(input, exponent, *, out=None) -> Tensor + +Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. +If neither input is complex returns a ``torch.float64`` tensor, +and if one or more inputs is complex returns a ``torch.complex128`` tensor. + +.. note:: + This function always computes in double precision, unlike :func:`torch.pow`, + which implements more typical :ref:`type promotion `. + This is useful when the computation needs to be performed in a wider or more precise dtype, + or the results of the computation may contain fractional values not representable in the input dtypes, + like when an integer base is raised to a negative integer exponent. + +Args: + input (Tensor or Number): the base value(s) + exponent (Tensor or Number): the exponent value(s) + +Keyword args: + {out} + +Example:: + + >>> a = torch.randint(10, (4,)) + >>> a + tensor([6, 4, 7, 1]) + >>> torch.float_power(a, 2) + tensor([36., 16., 49., 1.], dtype=torch.float64) + + >>> a = torch.arange(1, 5) + >>> a + tensor([ 1, 2, 3, 4]) + >>> exp = torch.tensor([2, -3, 4, -5]) + >>> exp + tensor([ 2, -3, 4, -5]) + >>> torch.float_power(a, exp) + tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64) +""".format(**common_args)) + add_docstr(torch.prod, r""" prod(input, *, dtype=None) -> Tensor @@ -5857,7 +6778,7 @@ def merge_dicts(*dicts): add_docstr(torch.qr, r""" -qr(input, some=True, out=None) -> (Tensor, Tensor) +qr(input, some=True, *, out=None) -> (Tensor, Tensor) Computes the QR decomposition of a matrix or a batch of matrices :attr:`input`, and returns a namedtuple (Q, R) of tensors such that :math:`\text{input} = Q R` @@ -5867,29 +6788,42 @@ def merge_dicts(*dicts): If :attr:`some` is ``True``, then this function returns the thin (reduced) QR factorization. Otherwise, if :attr:`some` is ``False``, this function returns the complete QR factorization. +.. warning:: ``torch.qr`` is deprecated. Please use ``torch.linalg.`` :func:`~torch.linalg.qr` + instead. + + **Differences with** ``torch.linalg.qr``: + + * ``torch.linalg.qr`` takes a string parameter ``mode`` instead of ``some``: + + - ``some=True`` is equivalent of ``mode='reduced'``: both are the + default + + - ``some=False`` is equivalent of ``mode='complete'``. + + .. warning:: If you plan to backpropagate through QR, note that the current backward implementation is only well-defined when the first :math:`\min(input.size(-1), input.size(-2))` columns of :attr:`input` are linearly independent. This behavior will propably change once QR supports pivoting. -.. note:: precision may be lost if the magnitudes of the elements of :attr:`input` - are large - -.. note:: While it should always give you a valid decomposition, it may not - give you the same one across platforms - it will depend on your - LAPACK implementation. +.. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs, + and may produce different (valid) decompositions on different device types + or different platforms. Args: input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more batch dimensions consisting of matrices of dimension :math:`m \times n`. some (bool, optional): Set to ``True`` for reduced QR decomposition and ``False`` for - complete QR decomposition. - out (tuple, optional): tuple of `Q` and `R` tensors - satisfying :code:`input = torch.matmul(Q, R)`. - The dimensions of `Q` and `R` are :math:`(*, m, k)` and :math:`(*, k, n)` - respectively, where :math:`k = \min(m, n)` if :attr:`some:` is ``True`` and - :math:`k = m` otherwise. + complete QR decomposition. If `k = min(m, n)` then: + + * ``some=True`` : returns `(Q, R)` with dimensions (m, k), (k, n) (default) + + * ``'some=False'``: returns `(Q, R)` with dimensions (m, m), (m, n) + +Keyword args: + out (tuple, optional): tuple of `Q` and `R` tensors. + The dimensions of `Q` and `R` are detailed in the description of :attr:`some` above. Example:: @@ -5921,7 +6855,7 @@ def merge_dicts(*dicts): add_docstr(torch.rad2deg, r""" -rad2deg(input, out=None) -> Tensor +rad2deg(input, *, out=None) -> Tensor Returns a new tensor with each of the elements of :attr:`input` converted from angles in radians to degrees. @@ -5944,7 +6878,7 @@ def merge_dicts(*dicts): add_docstr(torch.deg2rad, r""" -deg2rad(input, out=None) -> Tensor +deg2rad(input, *, out=None) -> Tensor Returns a new tensor with each of the elements of :attr:`input` converted from angles in degrees to radians. @@ -6001,7 +6935,7 @@ def merge_dicts(*dicts): add_docstr(torch.rand, r""" -rand(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor +rand(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor Returns a tensor filled with random numbers from a uniform distribution on the interval :math:`[0, 1)` @@ -6011,6 +6945,8 @@ def merge_dicts(*dicts): Args: size (int...): a sequence of integers defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple. + +Keyword args: {out} {dtype} {layout} @@ -6028,7 +6964,7 @@ def merge_dicts(*dicts): add_docstr(torch.rand_like, r""" -rand_like(input, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor +rand_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor Returns a tensor with the same size as :attr:`input` that is filled with random numbers from a uniform distribution on the interval :math:`[0, 1)`. @@ -6037,6 +6973,8 @@ def merge_dicts(*dicts): Args: {input} + +Keyword args: {dtype} {layout} {device} @@ -6055,7 +6993,7 @@ def merge_dicts(*dicts): The shape of the tensor is defined by the variable argument :attr:`size`. -.. note: +.. note:: With the global dtype default (``torch.float32``), this function returns a tensor with dtype ``torch.int64``. @@ -6063,6 +7001,8 @@ def merge_dicts(*dicts): low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. high (int): One above the highest integer to be drawn from the distribution. size (tuple): a tuple defining the shape of the output tensor. + +Keyword args: {generator} {out} {dtype} @@ -6090,7 +7030,7 @@ def merge_dicts(*dicts): add_docstr(torch.randint_like, """ -randint_like(input, low=0, high, dtype=None, layout=torch.strided, device=None, requires_grad=False, \ +randint_like(input, low=0, high, \\*, dtype=None, layout=torch.strided, device=None, requires_grad=False, \ memory_format=torch.preserve_format) -> Tensor Returns a tensor with the same shape as Tensor :attr:`input` filled with @@ -6105,6 +7045,8 @@ def merge_dicts(*dicts): {input} low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. high (int): One above the highest integer to be drawn from the distribution. + +Keyword args: {dtype} {layout} {device} @@ -6115,7 +7057,7 @@ def merge_dicts(*dicts): add_docstr(torch.randn, r""" -randn(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor +randn(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor Returns a tensor filled with random numbers from a normal distribution with mean `0` and variance `1` (also called the standard normal @@ -6129,6 +7071,8 @@ def merge_dicts(*dicts): Args: size (int...): a sequence of integers defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple. + +Keyword args: {out} {dtype} {layout} @@ -6146,7 +7090,7 @@ def merge_dicts(*dicts): add_docstr(torch.randn_like, r""" -randn_like(input, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor +randn_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor Returns a tensor with the same size as :attr:`input` that is filled with random numbers from a normal distribution with mean 0 and variance 1. @@ -6155,6 +7099,8 @@ def merge_dicts(*dicts): Args: {input} + +Keyword args: {dtype} {layout} {device} @@ -6164,19 +7110,24 @@ def merge_dicts(*dicts): """.format(**factory_like_common_args)) add_docstr(torch.randperm, - r""" -randperm(n, out=None, dtype=torch.int64, layout=torch.strided, device=None, requires_grad=False) -> LongTensor - + """ +randperm(n, *, generator=None, out=None, dtype=torch.int64,layout=torch.strided, \ +device=None, requires_grad=False, pin_memory=False) -> Tensor +""" + r""" Returns a random permutation of integers from ``0`` to ``n - 1``. Args: n (int): the upper bound (exclusive) + +Keyword args: + {generator} {out} dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. Default: ``torch.int64``. {layout} {device} {requires_grad} + {pin_memory} Example:: @@ -6186,7 +7137,7 @@ def merge_dicts(*dicts): add_docstr(torch.tensor, r""" -tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False) -> Tensor +tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) -> Tensor Constructs a tensor with :attr:`data`. @@ -6207,6 +7158,8 @@ def merge_dicts(*dicts): Args: {data} + +Keyword args: {dtype} {device} {requires_grad} @@ -6237,7 +7190,7 @@ def merge_dicts(*dicts): add_docstr(torch.range, r""" -range(start=0, end, step=1, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor +range(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor Returns a 1-D tensor of size :math:`\left\lfloor \frac{\text{end} - \text{start}}{\text{step}} \right\rfloor + 1` with values from :attr:`start` to :attr:`end` with step :attr:`step`. Step is @@ -6254,6 +7207,8 @@ def merge_dicts(*dicts): start (float): the starting value for the set of points. Default: ``0``. end (float): the ending value for the set of points step (float): the gap between each pair of adjacent points. Default: ``1``. + +Keyword args: {out} {dtype} If `dtype` is not given, infer the data type from the other input arguments. If any of `start`, `end`, or `stop` are floating-point, the @@ -6274,7 +7229,7 @@ def merge_dicts(*dicts): add_docstr(torch.arange, r""" -arange(start=0, end, step=1, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor +arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` with values from the interval ``[start, end)`` taken with common difference @@ -6291,6 +7246,8 @@ def merge_dicts(*dicts): start (Number): the starting value for the set of points. Default: ``0``. end (Number): the ending value for the set of points step (Number): the gap between each pair of adjacent points. Default: ``1``. + +Keyword args: {out} {dtype} If `dtype` is not given, infer the data type from the other input arguments. If any of `start`, `end`, or `stop` are floating-point, the @@ -6311,29 +7268,54 @@ def merge_dicts(*dicts): tensor([ 1.0000, 1.5000, 2.0000]) """.format(**factory_common_args)) +add_docstr(torch.ravel, + r""" +ravel(input) -> Tensor + +Return a contiguous flattened tensor. A copy is made only if needed. + +Args: + {input} + +Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.ravel(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) +""".format(**common_args)) + add_docstr(torch.remainder, r""" -remainder(input, other, out=None) -> Tensor +remainder(input, other, *, out=None) -> Tensor Computes the element-wise remainder of division. The dividend and divisor may contain both for integer and floating point numbers. The remainder has the same sign as the divisor :attr:`other`. -When :attr:`other` is a tensor, the shapes of :attr:`input` and -:attr:`other` must be :ref:`broadcastable `. +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer and float inputs. + +.. note:: + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + See :func:`torch.fmod` for how division by zero is handled. Args: input (Tensor): the dividend - other (Tensor or float): the divisor that may be either a number or a - Tensor of the same shape as the dividend + other (Tensor or Scalar): the divisor + +Keyword args: {out} Example:: >>> torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) tensor([ 1., 0., 1., 1., 0., 1.]) - >>> torch.remainder(torch.tensor([1., 2, 3, 4, 5]), 1.5) + >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5) tensor([ 1.0000, 0.5000, 0.0000, 1.0000, 0.5000]) .. seealso:: @@ -6344,7 +7326,7 @@ def merge_dicts(*dicts): add_docstr(torch.renorm, r""" -renorm(input, p, dim, maxnorm, out=None) -> Tensor +renorm(input, p, dim, maxnorm, *, out=None) -> Tensor Returns a tensor where each sub-tensor of :attr:`input` along dimension :attr:`dim` is normalized such that the `p`-norm of the sub-tensor is lower @@ -6357,6 +7339,8 @@ def merge_dicts(*dicts): p (float): the power for the norm computation dim (int): the dimension to slice over to get the sub-tensors maxnorm (float): the maximum norm to keep each sub-tensor under + +Keyword args: {out} Example:: @@ -6427,16 +7411,24 @@ def merge_dicts(*dicts): torch.uint8 """) +add_docstr(torch.row_stack, + r""" +row_stack(tensors, *, out=None) -> Tensor + +Alias of :func:`torch.vstack`. +""".format(**common_args)) add_docstr(torch.round, r""" -round(input, out=None) -> Tensor +round(input, *, out=None) -> Tensor Returns a new tensor with each of the elements of :attr:`input` rounded to the closest integer. Args: {input} + +Keyword args: {out} Example:: @@ -6450,7 +7442,7 @@ def merge_dicts(*dicts): add_docstr(torch.rsqrt, r""" -rsqrt(input, out=None) -> Tensor +rsqrt(input, *, out=None) -> Tensor Returns a new tensor with the reciprocal of the square-root of each of the elements of :attr:`input`. @@ -6460,6 +7452,8 @@ def merge_dicts(*dicts): """ + r""" Args: {input} + +Keyword args: {out} Example:: @@ -6471,6 +7465,20 @@ def merge_dicts(*dicts): tensor([ nan, 1.8351, 0.8053, nan]) """.format(**common_args)) +add_docstr(torch.scatter, + r""" +scatter(input, dim, index, src) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.scatter_` +""") + +add_docstr(torch.scatter_add, + r""" +scatter_add(input, dim, index, src) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.scatter_add_` +""") + add_docstr(torch.set_flush_denormal, r""" set_flush_denormal(mode) -> bool @@ -6543,7 +7551,7 @@ def merge_dicts(*dicts): add_docstr(torch.logit, r""" -logit(input, eps=None, out=None) -> Tensor +logit(input, eps=None, *, out=None) -> Tensor Returns a new tensor with the logit of the elements of :attr:`input`. :attr:`input` is clamped to [eps, 1 - eps] when eps is not None. @@ -6561,6 +7569,8 @@ def merge_dicts(*dicts): Args: {input} eps (float, optional): the epsilon for input clamp bound. Default: ``None`` + +Keyword args: {out} Example:: @@ -6574,7 +7584,7 @@ def merge_dicts(*dicts): add_docstr(torch.sign, r""" -sign(input, out=None) -> Tensor +sign(input, *, out=None) -> Tensor Returns a new tensor with the signs of the elements of :attr:`input`. @@ -6583,6 +7593,8 @@ def merge_dicts(*dicts): """ + r""" Args: {input} + +Keyword args: {out} Example:: @@ -6629,40 +7641,70 @@ def merge_dicts(*dicts): {input} Keyword args: - {out} + {out} + +Example:: + + >>> x=torch.tensor([3+4j, 7-24j, 0, 1+2j]) + >>> x.sgn() + tensor([0.6000+0.8000j, 0.2800-0.9600j, 0.0000+0.0000j, 0.4472+0.8944j]) +""".format(**common_args)) + +add_docstr(torch.sin, + r""" +sin(input, *, out=None) -> Tensor + +Returns a new tensor with the sine of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \sin(\text{input}_{i}) +""" + r""" +Args: + {input} + +Keyword args: + {out} Example:: - >>> x=torch.tensor([3+4j, 7-24j, 0, 1+2j]) - >>> x.sgn() - tensor([0.6000+0.8000j, 0.2800-0.9600j, 0.0000+0.0000j, 0.4472+0.8944j]) + >>> a = torch.randn(4) + >>> a + tensor([-0.5461, 0.1347, -2.7266, -0.2746]) + >>> torch.sin(a) + tensor([-0.5194, 0.1343, -0.4032, -0.2711]) """.format(**common_args)) -add_docstr(torch.sin, +add_docstr(torch.sinc, r""" -sin(input, out=None) -> Tensor +sinc(input, *, out=None) -> Tensor -Returns a new tensor with the sine of the elements of :attr:`input`. +Computes the normalized sinc of :attr:`input.` .. math:: - \text{out}_{i} = \sin(\text{input}_{i}) + \text{out}_{i} = + \begin{cases} + 1, & \text{if}\ \text{input}_{i}=0 \\ + \sin(\pi \text{input}_{i}) / (\pi \text{input}_{i}), & \text{otherwise} + \end{cases} """ + r""" Args: {input} + +Keyword args: {out} Example:: >>> a = torch.randn(4) >>> a - tensor([-0.5461, 0.1347, -2.7266, -0.2746]) - >>> torch.sin(a) - tensor([-0.5194, 0.1343, -0.4032, -0.2711]) + tensor([ 0.2252, -0.2948, 1.0267, -1.1566]) + >>> torch.sinc(a) + tensor([ 0.9186, 0.8631, -0.0259, -0.1300]) """.format(**common_args)) add_docstr(torch.sinh, r""" -sinh(input, out=None) -> Tensor +sinh(input, *, out=None) -> Tensor Returns a new tensor with the hyperbolic sine of the elements of :attr:`input`. @@ -6672,6 +7714,8 @@ def merge_dicts(*dicts): """ + r""" Args: {input} + +Keyword args: {out} Example:: @@ -6681,11 +7725,16 @@ def merge_dicts(*dicts): tensor([ 0.5380, -0.8632, -0.1265, 0.9399]) >>> torch.sinh(a) tensor([ 0.5644, -0.9744, -0.1268, 1.0845]) + +.. note:: + When :attr:`input` is on the CPU, the implementation of torch.sinh may use + the Sleef library, which rounds very large results to infinity or negative + infinity. See `here `_ for details. """.format(**common_args)) add_docstr(torch.sort, r""" -sort(input, dim=-1, descending=False, out=None) -> (Tensor, LongTensor) +sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) Sorts the elements of the :attr:`input` tensor along a given dimension in ascending order by value. @@ -6695,14 +7744,23 @@ def merge_dicts(*dicts): If :attr:`descending` is ``True`` then the elements are sorted in descending order by value. +If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving +the order of equivalent elements. + A namedtuple of (values, indices) is returned, where the `values` are the sorted values and `indices` are the indices of the elements in the original `input` tensor. +.. warning:: `stable=True` only works on the CPU for now. + Args: {input} dim (int, optional): the dimension to sort along descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. + +Keyword args: out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can be optionally given to be used as output buffers @@ -6762,14 +7820,46 @@ def merge_dicts(*dicts): [3, 2, 1, 0]]) """.format(**common_args)) +add_docstr(torch.msort, + r""" +msort(input, *, out=None) -> Tensor + +Sorts the elements of the :attr:`input` tensor along its first dimension +in ascending order by value. + +.. note:: `torch.msort(t)` is equivalent to `torch.sort(t, dim=0)[0]`. + See also :func:`torch.sort`. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> t = torch.randn(3, 4) + >>> t + tensor([[-0.1321, 0.4370, -1.2631, -1.1289], + [-2.0527, -1.1250, 0.2275, 0.3077], + [-0.0881, -0.1259, -0.5495, 1.0284]]) + >>> torch.msort(t) + tensor([[-2.0527, -1.1250, -1.2631, -1.1289], + [-0.1321, -0.1259, -0.5495, 0.3077], + [-0.0881, 0.4370, 0.2275, 1.0284]]) +""".format(**common_args)) + add_docstr(torch.sparse_coo_tensor, r""" -sparse_coo_tensor(indices, values, size=None, dtype=None, device=None, requires_grad=False) -> Tensor +sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor + +Constructs a :ref:`sparse tensor in COO(rdinate) format +` with specified values at the given +:attr:`indices`. -Constructs a sparse tensors in COO(rdinate) format with non-zero elements at the given :attr:`indices` -with the given :attr:`values`. A sparse tensor can be `uncoalesced`, in that case, there are duplicate -coordinates in the indices, and the value at that index is the sum of all duplicate value entries: -`torch.sparse`_. +.. note:: + + This function returns an :ref:`uncoalesced tensor `. Args: indices (array_like): Initial data for the tensor. Can be a list, tuple, @@ -6782,6 +7872,8 @@ def merge_dicts(*dicts): size (list, tuple, or :class:`torch.Size`, optional): Size of the sparse tensor. If not provided the size will be inferred as the minimum size big enough to hold all non-zero elements. + +Keyword args: dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. Default: if None, infers data type from :attr:`values`. device (:class:`torch.device`, optional): the desired device of returned tensor. @@ -6841,7 +7933,7 @@ def merge_dicts(*dicts): add_docstr(torch.sqrt, r""" -sqrt(input, out=None) -> Tensor +sqrt(input, *, out=None) -> Tensor Returns a new tensor with the square-root of the elements of :attr:`input`. @@ -6850,6 +7942,8 @@ def merge_dicts(*dicts): """ + r""" Args: {input} + +Keyword args: {out} Example:: @@ -6863,12 +7957,14 @@ def merge_dicts(*dicts): add_docstr(torch.square, r""" -square(input, out=None) -> Tensor +square(input, *, out=None) -> Tensor Returns a new tensor with the square of the elements of :attr:`input`. Args: {input} + +Keyword args: {out} Example:: @@ -6882,7 +7978,7 @@ def merge_dicts(*dicts): add_docstr(torch.squeeze, r""" -squeeze(input, dim=None, out=None) -> Tensor +squeeze(input, dim=None, *, out=None) -> Tensor Returns a tensor with all the dimensions of :attr:`input` of size `1` removed. @@ -6906,6 +8002,8 @@ def merge_dicts(*dicts): {input} dim (int, optional): if given, the input will be squeezed only in this dimension + +Keyword args: {out} Example:: @@ -7062,12 +8160,14 @@ def merge_dicts(*dicts): add_docstr(torch.sum, r""" -sum(input, dtype=None) -> Tensor +sum(input, *, dtype=None) -> Tensor Returns the sum of all elements in the :attr:`input` tensor. Args: {input} + +Keyword args: {dtype} Example:: @@ -7078,7 +8178,7 @@ def merge_dicts(*dicts): >>> torch.sum(a) tensor(-0.5475) -.. function:: sum(input, dim, keepdim=False, dtype=None) -> Tensor +.. function:: sum(input, dim, keepdim=False, *, dtype=None) -> Tensor Returns the sum of each row of the :attr:`input` tensor in the given dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, @@ -7090,6 +8190,8 @@ def merge_dicts(*dicts): {input} {dim} {keepdim} + +Keyword args: {dtype} Example:: @@ -7109,7 +8211,7 @@ def merge_dicts(*dicts): add_docstr(torch.nansum, r""" -nansum(input, dtype=None) -> Tensor +nansum(input, *, dtype=None) -> Tensor Returns the sum of all elements, treating Not a Numbers (NaNs) as zero. @@ -7125,7 +8227,7 @@ def merge_dicts(*dicts): >>> torch.nansum(a) tensor(7.) -.. function:: nansum(input, dim, keepdim=False, dtype=None) -> Tensor +.. function:: nansum(input, dim, keepdim=False, *, dtype=None) -> Tensor Returns the sum of each row of the :attr:`input` tensor in the given dimension :attr:`dim`, treating Not a Numbers (NaNs) as zero. @@ -7156,18 +8258,49 @@ def merge_dicts(*dicts): add_docstr(torch.svd, r""" -svd(input, some=True, compute_uv=True, out=None) -> (Tensor, Tensor, Tensor) +svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) + +Computes the singular value decomposition of either a matrix or batch of +matrices :attr:`input`." The singular value decomposition is represented as a +namedtuple ``(U, S, V)``, such that :math:`input = U \mathbin{@} diag(S) \times +V^T`, where :math:`V^T` is the transpose of ``V``. If :attr:`input` is a batch +of tensors, then ``U``, ``S``, and ``V`` are also batched with the same batch +dimensions as :attr:`input`. + +If :attr:`some` is ``True`` (default), the method returns the reduced singular +value decomposition i.e., if the last two dimensions of :attr:`input` are +``m`` and ``n``, then the returned `U` and `V` matrices will contain only +:math:`min(n, m)` orthonormal columns. + +If :attr:`compute_uv` is ``False``, the returned `U` and `V` will be +zero-filled matrices of shape :math:`(m \times m)` and :math:`(n \times n)` +respectively, and the same device as :attr:`input`. The :attr:`some` +argument has no effect when :attr:`compute_uv` is False. + +The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. ``S`` will +always be real-valued, even if :attr:`input` is complex. + +.. warning:: ``torch.svd`` is deprecated. Please use ``torch.linalg.`` + :func:`~torch.linalg.svd` instead, which is similar to NumPy's + ``numpy.linalg.svd``. + +.. note:: **Differences with** ``torch.linalg.`` :func:`~torch.linalg.svd`: -This function returns a namedtuple ``(U, S, V)`` which is the singular value -decomposition of a input real matrix or batches of real matrices :attr:`input` such that -:math:`input = U \times diag(S) \times V^T`. + * :attr:`some` is the opposite of ``torch.linalg.`` + :func:`~torch.linalg.svd`'s :attr:`full_matricies`. Note that + default value for both is ``True``, so the default behavior is + effectively the opposite. -If :attr:`some` is ``True`` (default), the method returns the reduced singular value decomposition -i.e., if the last two dimensions of :attr:`input` are ``m`` and ``n``, then the returned -`U` and `V` matrices will contain only :math:`min(n, m)` orthonormal columns. + * it returns ``V``, whereas ``torch.linalg.`` + :func:`~torch.linalg.svd` returns ``Vh``. The result is that + when using ``svd`` you need to manually transpose + ``V`` in order to reconstruct the original matrix. -If :attr:`compute_uv` is ``False``, the returned `U` and `V` matrices will be zero matrices -of shape :math:`(m \times m)` and :math:`(n \times n)` respectively. :attr:`some` will be ignored here. + * If :attr:`compute_uv=False`, it returns zero-filled tensors for + ``U`` and ``Vh``, whereas :meth:`~torch.linalg.svd` returns + empty tensors. + +Supports real-valued and complex-valued input. .. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices, then the singular values of each matrix in the batch is returned in descending order. @@ -7176,28 +8309,30 @@ def merge_dicts(*dicts): algorithm) instead of `?gesvd` for speed. Analogously, the SVD on GPU uses the MAGMA routine `gesdd` as well. -.. note:: Irrespective of the original strides, the returned matrix `U` - will be transposed, i.e. with strides :code:`U.contiguous().transpose(-2, -1).stride()` +.. note:: The returned matrix `U` will be transposed, i.e. with strides + :code:`U.contiguous().transpose(-2, -1).stride()`. -.. note:: Extra care needs to be taken when backward through `U` and `V` - outputs. Such operation is really only stable when :attr:`input` is - full rank with all distinct singular values. Otherwise, ``NaN`` can - appear as the gradients are not properly defined. Also, notice that - double backward will usually do an additional backward through `U` and - `V` even if the original backward is only on `S`. +.. note:: Gradients computed using `U` and `V` may be unstable if + :attr:`input` is not full rank or has non-unique singular values. .. note:: When :attr:`some` = ``False``, the gradients on :code:`U[..., :, min(m, n):]` and :code:`V[..., :, min(m, n):]` will be ignored in backward as those vectors can be arbitrary bases of the subspaces. -.. note:: When :attr:`compute_uv` = ``False``, backward cannot be performed since `U` and `V` - from the forward pass is required for the backward operation. +.. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is True. + + +.. note:: With the complex-valued input the backward operation works correctly only + for gauge invariant loss functions. Please look at `Gauge problem in AD`_ for more details. Args: input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more batch dimensions consisting of :math:`m \times n` matrices. - some (bool, optional): controls the shape of returned `U` and `V` - compute_uv (bool, optional): option whether to compute `U` and `V` or not + some (bool, optional): controls whether to compute the reduced or full decomposition, and + consequently the shape of returned ``U`` and ``V``. Defaults to True. + compute_uv (bool, optional): option whether to compute `U` and `V` or not. Defaults to True. + +Keyword args: out (tuple, optional): the output tuple of tensors Example:: @@ -7228,11 +8363,13 @@ def merge_dicts(*dicts): >>> u, s, v = torch.svd(a_big) >>> torch.dist(a_big, torch.matmul(torch.matmul(u, torch.diag_embed(s)), v.transpose(-2, -1))) tensor(2.6503e-06) + +.. _Gauge problem in AD: https://re-ra.xyz/Gauge-Problem-in-Automatic-Differentiation/ """) add_docstr(torch.symeig, r""" -symeig(input, eigenvectors=False, upper=True, out=None) -> (Tensor, Tensor) +symeig(input, eigenvectors=False, upper=True, *, out=None) -> (Tensor, Tensor) This function returns eigenvalues and eigenvectors of a real symmetric matrix :attr:`input` or a batch of real symmetric matrices, @@ -7265,8 +8402,10 @@ def merge_dicts(*dicts): Args: input (Tensor): the input tensor of size :math:`(*, n, n)` where `*` is zero or more batch dimensions consisting of symmetric matrices. - eigenvectors(boolean, optional): controls whether eigenvectors have to be computed + eigenvectors(bool, optional): controls whether eigenvectors have to be computed upper(boolean, optional): controls whether to consider upper-triangular or lower-triangular region + +Keyword args: out (tuple, optional): the output tuple of (Tensor, Tensor) Returns: @@ -7345,6 +8484,11 @@ def merge_dicts(*dicts): Reverse the order of a n-D tensor along given axis in dims. +.. note:: + `torch.flip` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flip`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.flip` is expected to be slower than `np.flip`. + Args: {input} dims (a list or tuple): axis to flip on @@ -7370,13 +8514,18 @@ def merge_dicts(*dicts): r""" fliplr(input) -> Tensor -Flip array in the left/right direction, returning a new tensor. +Flip tensor in the left/right direction, returning a new tensor. Flip the entries in each row in the left/right direction. Columns are preserved, but appear in a different order than before. Note: - Equivalent to input[:,::-1]. Requires the array to be at least 2-D. + Requires the tensor to be at least 2-D. + +.. note:: + `torch.fliplr` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.fliplr`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.fliplr` is expected to be slower than `np.fliplr`. Args: input (Tensor): Must be at least 2-dimensional. @@ -7396,13 +8545,18 @@ def merge_dicts(*dicts): r""" flipud(input) -> Tensor -Flip array in the up/down direction, returning a new tensor. +Flip tensor in the up/down direction, returning a new tensor. Flip the entries in each column in the up/down direction. Rows are preserved, but appear in a different order than before. Note: - Equivalent to input[::-1,...]. Requires the array to be at least 1-D. + Requires the tensor to be at least 1-D. + +.. note:: + `torch.flipud` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flipud`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.flipud` is expected to be slower than `np.flipud`. Args: input (Tensor): Must be at least 1-dimensional. @@ -7519,7 +8673,7 @@ def merge_dicts(*dicts): add_docstr(torch.tan, r""" -tan(input, out=None) -> Tensor +tan(input, *, out=None) -> Tensor Returns a new tensor with the tangent of the elements of :attr:`input`. @@ -7528,6 +8682,8 @@ def merge_dicts(*dicts): """ + r""" Args: {input} + +Keyword args: {out} Example:: @@ -7541,7 +8697,7 @@ def merge_dicts(*dicts): add_docstr(torch.tanh, r""" -tanh(input, out=None) -> Tensor +tanh(input, *, out=None) -> Tensor Returns a new tensor with the hyperbolic tangent of the elements of :attr:`input`. @@ -7551,6 +8707,8 @@ def merge_dicts(*dicts): """ + r""" Args: {input} + +Keyword args: {out} Example:: @@ -7564,7 +8722,7 @@ def merge_dicts(*dicts): add_docstr(torch.topk, r""" -topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor) +topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor) Returns the :attr:`k` largest elements of the given :attr:`input` tensor along a given dimension. @@ -7587,6 +8745,8 @@ def merge_dicts(*dicts): smallest elements sorted (bool, optional): controls whether to return the elements in sorted order + +Keyword args: out (tuple, optional): the output tuple of (Tensor, LongTensor) that can be optionally given to be used as output buffers @@ -7623,7 +8783,7 @@ def merge_dicts(*dicts): Returns a tensor that is a transposed version of :attr:`input`. The given dimensions :attr:`dim0` and :attr:`dim1` are swapped. -The resulting :attr:`out` tensor shares it's underlying storage with the +The resulting :attr:`out` tensor shares its underlying storage with the :attr:`input` tensor, so changing the content of one would change the content of the other. @@ -7658,6 +8818,8 @@ def merge_dicts(*dicts): batches of 2D matrices. If the inputs are batches, then returns batched outputs `X` +Supports real-valued and complex-valued inputs. + Args: input (Tensor): multiple right-hand sides of size :math:`(*, m, k)` where :math:`*` is zero of more batch dimensions (:math:`b`) @@ -7696,7 +8858,7 @@ def merge_dicts(*dicts): add_docstr(torch.tril, r""" -tril(input, diagonal=0, out=None) -> Tensor +tril(input, diagonal=0, *, out=None) -> Tensor Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices :attr:`input`, the other elements of the result tensor :attr:`out` are set to 0. @@ -7715,6 +8877,8 @@ def merge_dicts(*dicts): Args: {input} diagonal (int, optional): the diagonal to consider + +Keyword args: {out} Example:: @@ -7751,7 +8915,7 @@ def merge_dicts(*dicts): # as common args. add_docstr(torch.tril_indices, r""" -tril_indices(row, col, offset=0, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor +tril_indices(row, col, offset=0, *, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor Returns the indices of the lower triangular part of a :attr:`row`-by- :attr:`col` matrix in a 2-by-N Tensor, where the first row contains row @@ -7778,6 +8942,8 @@ def merge_dicts(*dicts): col (``int``): number of columns in the 2-D matrix. offset (``int``): diagonal offset from the main diagonal. Default: if not provided, 0. + +Keyword args: dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. Default: if ``None``, ``torch.long``. {device} @@ -7802,7 +8968,7 @@ def merge_dicts(*dicts): add_docstr(torch.triu, r""" -triu(input, diagonal=0, out=None) -> Tensor +triu(input, diagonal=0, *, out=None) -> Tensor Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices :attr:`input`, the other elements of the result tensor :attr:`out` are set to 0. @@ -7821,6 +8987,8 @@ def merge_dicts(*dicts): Args: {input} diagonal (int, optional): the diagonal to consider + +Keyword args: {out} Example:: @@ -7865,7 +9033,7 @@ def merge_dicts(*dicts): # as common args. add_docstr(torch.triu_indices, r""" -triu_indices(row, col, offset=0, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor +triu_indices(row, col, offset=0, *, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor Returns the indices of the upper triangular part of a :attr:`row` by :attr:`col` matrix in a 2-by-N Tensor, where the first row contains row @@ -7892,6 +9060,8 @@ def merge_dicts(*dicts): col (``int``): number of columns in the 2-D matrix. offset (``int``): diagonal offset from the main diagonal. Default: if not provided, 0. + +Keyword args: dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. Default: if ``None``, ``torch.long``. {device} @@ -7922,13 +9092,15 @@ def merge_dicts(*dicts): add_docstr(torch.trunc, r""" -trunc(input, out=None) -> Tensor +trunc(input, *, out=None) -> Tensor Returns a new tensor with the truncated integer values of the elements of :attr:`input`. Args: {input} + +Keyword args: {out} Example:: @@ -8079,7 +9251,7 @@ def merge_dicts(*dicts): add_docstr(torch.zeros, r""" -zeros(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor +zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor Returns a tensor filled with the scalar value `0`, with the shape defined by the variable argument :attr:`size`. @@ -8087,6 +9259,8 @@ def merge_dicts(*dicts): Args: size (int...): a sequence of integers defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple. + +Keyword args: {out} {dtype} {layout} @@ -8105,7 +9279,7 @@ def merge_dicts(*dicts): add_docstr(torch.zeros_like, r""" -zeros_like(input, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor +zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor Returns a tensor filled with the scalar value `0`, with the same size as :attr:`input`. ``torch.zeros_like(input)`` is equivalent to @@ -8118,6 +9292,8 @@ def merge_dicts(*dicts): Args: {input} + +Keyword args: {dtype} {layout} {device} @@ -8134,7 +9310,7 @@ def merge_dicts(*dicts): add_docstr(torch.empty, r""" -empty(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor +empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor Returns a tensor filled with uninitialized data. The shape of the tensor is defined by the variable argument :attr:`size`. @@ -8142,6 +9318,8 @@ def merge_dicts(*dicts): Args: size (int...): a sequence of integers defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple. + +Keyword args: {out} {dtype} {layout} @@ -8161,7 +9339,7 @@ def merge_dicts(*dicts): add_docstr(torch.empty_like, r""" -empty_like(input, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor +empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor Returns an uninitialized tensor with the same size as :attr:`input`. ``torch.empty_like(input)`` is equivalent to @@ -8169,6 +9347,8 @@ def merge_dicts(*dicts): Args: {input} + +Keyword args: {dtype} {layout} {device} @@ -8184,7 +9364,7 @@ def merge_dicts(*dicts): add_docstr(torch.empty_strided, r""" -empty_strided(size, stride, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor +empty_strided(size, stride, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor Returns a tensor filled with uninitialized data. The shape and strides of the tensor is defined by the variable argument :attr:`size` and :attr:`stride` respectively. @@ -8200,6 +9380,8 @@ def merge_dicts(*dicts): Args: size (tuple of ints): the shape of the output tensor stride (tuple of ints): the strides of the output tensor + +Keyword args: {dtype} {layout} {device} @@ -8228,6 +9410,8 @@ def merge_dicts(*dicts): size (int...): a list, tuple, or :class:`torch.Size` of integers defining the shape of the output tensor. fill_value (Scalar): the value to fill the output tensor with. + +Keyword args: {out} {dtype} {layout} @@ -8243,7 +9427,7 @@ def merge_dicts(*dicts): add_docstr(torch.full_like, """ -full_like(input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, \ +full_like(input, fill_value, \\*, dtype=None, layout=torch.strided, device=None, requires_grad=False, \ memory_format=torch.preserve_format) -> Tensor Returns a tensor with the same size as :attr:`input` filled with :attr:`fill_value`. @@ -8253,6 +9437,8 @@ def merge_dicts(*dicts): Args: {input} fill_value: the number to fill the output tensor with. + +Keyword args: {dtype} {layout} {device} @@ -8455,10 +9641,15 @@ def merge_dicts(*dicts): could be unstable. Double-backward will also be unstable due to the usage of SVD internally. See :meth:`~torch.svd` for more details. +.. note:: + Supports real and complex inputs. + Batched version for complex inputs is only supported on the CPU. + Arguments: - input (Tensor): The input tensor of size :math:`(*, m, n)` where :math:`*` is zero or more batch dimensions - rcond (float): A floating point value to determine the cutoff for small singular values. - Default: 1e-15 + input (Tensor): The input tensor of size :math:`(*, m, n)` where :math:`*` is + zero or more batch dimensions. + rcond (float, optional): A floating point value to determine the cutoff for + small singular values. Default: ``1e-15``. Returns: The pseudo-inverse of :attr:`input` of dimensions :math:`(*, n, m)` @@ -8495,393 +9686,9 @@ def merge_dicts(*dicts): .. _[2]: https://www.jstor.org/stable/2156365 """) -add_docstr(torch.fft, r""" -fft(input, signal_ndim, normalized=False) -> Tensor - -Complex-to-complex Discrete Fourier Transform - -This method computes the complex-to-complex discrete Fourier transform. -Ignoring the batch dimensions, it computes the following expression: - -.. math:: - X[\omega_1, \dots, \omega_d] = - \sum_{n_1=0}^{N_1-1} \dots \sum_{n_d=0}^{N_d-1} x[n_1, \dots, n_d] - e^{-j\ 2 \pi \sum_{i=0}^d \frac{\omega_i n_i}{N_i}}, - -where :math:`d` = :attr:`signal_ndim` is number of dimensions for the -signal, and :math:`N_i` is the size of signal dimension :math:`i`. - -This method supports 1D, 2D and 3D complex-to-complex transforms, indicated -by :attr:`signal_ndim`. :attr:`input` must be a tensor with last dimension -of size 2, representing the real and imaginary components of complex -numbers, and should have at least ``signal_ndim + 1`` dimensions with optionally -arbitrary number of leading batch dimensions. If :attr:`normalized` is set to -``True``, this normalizes the result by dividing it with -:math:`\sqrt{\prod_{i=1}^K N_i}` so that the operator is unitary. - -Returns the real and the imaginary parts together as one tensor of the same -shape of :attr:`input`. - -The inverse of this function is :func:`~torch.ifft`. - -.. deprecated:: 1.7.0 - The function :func:`torch.fft` is deprecated and will be removed in - PyTorch 1.8. Use the new :ref:`torch.fft ` module - functions, instead, by importing :ref:`torch.fft ` and - calling :func:`torch.fft.fft` or :func:`torch.fft.fftn`. - -.. note:: - For CUDA tensors, an LRU cache is used for cuFFT plans to speed up - repeatedly running FFT methods on tensors of same geometry with same - configuration. See :ref:`cufft-plan-cache` for more details on how to - monitor and control the cache. - -.. warning:: - If the torch.fft module is imported then "torch.fft" will refer to the - module and not this function. Use :meth:`torch.Tensor.fft` instead. - -.. warning:: - Due to limited dynamic range of half datatype, performing this operation in half - precision may cause the first element of result to overflow for certain inputs. - -.. warning:: - For CPU tensors, this method is currently only available with MKL. Use - :func:`torch.backends.mkl.is_available` to check if MKL is installed. - -Arguments: - input (Tensor): the input tensor of at least :attr:`signal_ndim` ``+ 1`` - dimensions - signal_ndim (int): the number of dimensions in each signal. - :attr:`signal_ndim` can only be 1, 2 or 3 - normalized (bool, optional): controls whether to return normalized results. - Default: ``False`` - -Returns: - Tensor: A tensor containing the complex-to-complex Fourier transform result - -Example:: - - >>> # unbatched 2D FFT - >>> x = torch.randn(4, 3, 2) - >>> torch.fft(x, 2) - tensor([[[-0.0876, 1.7835], - [-2.0399, -2.9754], - [ 4.4773, -5.0119]], - - [[-1.5716, 2.7631], - [-3.8846, 5.2652], - [ 0.2046, -0.7088]], - - [[ 1.9938, -0.5901], - [ 6.5637, 6.4556], - [ 2.9865, 4.9318]], - - [[ 7.0193, 1.1742], - [-1.3717, -2.1084], - [ 2.0289, 2.9357]]]) - >>> # batched 1D FFT - >>> torch.fft(x, 1) - tensor([[[ 1.8385, 1.2827], - [-0.1831, 1.6593], - [ 2.4243, 0.5367]], - - [[-0.9176, -1.5543], - [-3.9943, -2.9860], - [ 1.2838, -2.9420]], - - [[-0.8854, -0.6860], - [ 2.4450, 0.0808], - [ 1.3076, -0.5768]], - - [[-0.1231, 2.7411], - [-0.3075, -1.7295], - [-0.5384, -2.0299]]]) - >>> # arbitrary number of batch dimensions, 2D FFT - >>> x = torch.randn(3, 3, 5, 5, 2) - >>> y = torch.fft(x, 2) - >>> y.shape - torch.Size([3, 3, 5, 5, 2]) - -""") - -add_docstr(torch.ifft, - r""" -ifft(input, signal_ndim, normalized=False) -> Tensor - -Complex-to-complex Inverse Discrete Fourier Transform - -This method computes the complex-to-complex inverse discrete Fourier -transform. Ignoring the batch dimensions, it computes the following -expression: - -.. math:: - X[\omega_1, \dots, \omega_d] = - \frac{1}{\prod_{i=1}^d N_i} \sum_{n_1=0}^{N_1-1} \dots \sum_{n_d=0}^{N_d-1} x[n_1, \dots, n_d] - e^{\ j\ 2 \pi \sum_{i=0}^d \frac{\omega_i n_i}{N_i}}, - -where :math:`d` = :attr:`signal_ndim` is number of dimensions for the -signal, and :math:`N_i` is the size of signal dimension :math:`i`. - -The argument specifications are almost identical with :func:`~torch.fft`. -However, if :attr:`normalized` is set to ``True``, this instead returns the -results multiplied by :math:`\sqrt{\prod_{i=1}^d N_i}`, to become a unitary -operator. Therefore, to invert a :func:`~torch.fft`, the :attr:`normalized` -argument should be set identically for :func:`~torch.fft`. - -Returns the real and the imaginary parts together as one tensor of the same -shape of :attr:`input`. - -The inverse of this function is :func:`~torch.fft`. - -.. deprecated:: 1.7.0 - The function :func:`torch.ifft` is deprecated and will be removed in a - future PyTorch release. Use the new :ref:`torch.fft ` - module functions, instead, by importing :ref:`torch.fft ` - and calling :func:`torch.fft.ifft` or :func:`torch.fft.ifftn`. - -.. note:: - For CUDA tensors, an LRU cache is used for cuFFT plans to speed up - repeatedly running FFT methods on tensors of same geometry with same - configuration. See :ref:`cufft-plan-cache` for more details on how to - monitor and control the cache. - -.. warning:: - Due to limited dynamic range of half datatype, performing this operation in half - precision may cause the first element of result to overflow for certain inputs. - -.. warning:: - For CPU tensors, this method is currently only available with MKL. Use - :func:`torch.backends.mkl.is_available` to check if MKL is installed. - -Arguments: - input (Tensor): the input tensor of at least :attr:`signal_ndim` ``+ 1`` - dimensions - signal_ndim (int): the number of dimensions in each signal. - :attr:`signal_ndim` can only be 1, 2 or 3 - normalized (bool, optional): controls whether to return normalized results. - Default: ``False`` - -Returns: - Tensor: A tensor containing the complex-to-complex inverse Fourier transform result - -Example:: - - >>> x = torch.randn(3, 3, 2) - >>> x - tensor([[[ 1.2766, 1.3680], - [-0.8337, 2.0251], - [ 0.9465, -1.4390]], - - [[-0.1890, 1.6010], - [ 1.1034, -1.9230], - [-0.9482, 1.0775]], - - [[-0.7708, -0.8176], - [-0.1843, -0.2287], - [-1.9034, -0.2196]]]) - >>> y = torch.fft(x, 2) - >>> torch.ifft(y, 2) # recover x - tensor([[[ 1.2766, 1.3680], - [-0.8337, 2.0251], - [ 0.9465, -1.4390]], - - [[-0.1890, 1.6010], - [ 1.1034, -1.9230], - [-0.9482, 1.0775]], - - [[-0.7708, -0.8176], - [-0.1843, -0.2287], - [-1.9034, -0.2196]]]) - -""") - -add_docstr(torch.rfft, - r""" -rfft(input, signal_ndim, normalized=False, onesided=True) -> Tensor - -Real-to-complex Discrete Fourier Transform - -This method computes the real-to-complex discrete Fourier transform. It is -mathematically equivalent with :func:`~torch.fft` with differences only in -formats of the input and output. - -This method supports 1D, 2D and 3D real-to-complex transforms, indicated -by :attr:`signal_ndim`. :attr:`input` must be a tensor with at least -``signal_ndim`` dimensions with optionally arbitrary number of leading batch -dimensions. If :attr:`normalized` is set to ``True``, this normalizes the result -by dividing it with :math:`\sqrt{\prod_{i=1}^K N_i}` so that the operator is -unitary, where :math:`N_i` is the size of signal dimension :math:`i`. - -The real-to-complex Fourier transform results follow conjugate symmetry: - -.. math:: - X[\omega_1, \dots, \omega_d] = X^*[N_1 - \omega_1, \dots, N_d - \omega_d], - -where the index arithmetic is computed modulus the size of the corresponding -dimension, :math:`\ ^*` is the conjugate operator, and -:math:`d` = :attr:`signal_ndim`. :attr:`onesided` flag controls whether to avoid -redundancy in the output results. If set to ``True`` (default), the output will -not be full complex result of shape :math:`(*, 2)`, where :math:`*` is the shape -of :attr:`input`, but instead the last dimension will be halfed as of size -:math:`\lfloor \frac{N_d}{2} \rfloor + 1`. - -The inverse of this function is :func:`~torch.irfft`. - -.. deprecated:: 1.7.0 - The function :func:`torch.rfft` is deprecated and will be removed in a - future PyTorch release. Use the new :ref:`torch.fft ` - module functions, instead, by importing :ref:`torch.fft ` - and calling :func:`torch.fft.rfft` for one-sided output, or - :func:`torch.fft.fft` for two-sided output. - -.. note:: - For CUDA tensors, an LRU cache is used for cuFFT plans to speed up - repeatedly running FFT methods on tensors of same geometry with same - configuration. See :ref:`cufft-plan-cache` for more details on how to - monitor and control the cache. - -.. warning:: - Due to limited dynamic range of half datatype, performing this operation in half - precision may cause the first element of result to overflow for certain inputs. - -.. warning:: - For CPU tensors, this method is currently only available with MKL. Use - :func:`torch.backends.mkl.is_available` to check if MKL is installed. - -Arguments: - input (Tensor): the input tensor of at least :attr:`signal_ndim` dimensions - signal_ndim (int): the number of dimensions in each signal. - :attr:`signal_ndim` can only be 1, 2 or 3 - normalized (bool, optional): controls whether to return normalized results. - Default: ``False`` - onesided (bool, optional): controls whether to return half of results to - avoid redundancy. Default: ``True`` - -Returns: - Tensor: A tensor containing the real-to-complex Fourier transform result - -Example:: - - >>> x = torch.randn(5, 5) - >>> torch.rfft(x, 2).shape - torch.Size([5, 3, 2]) - >>> torch.rfft(x, 2, onesided=False).shape - torch.Size([5, 5, 2]) - -""") - - -add_docstr(torch.irfft, - r""" -irfft(input, signal_ndim, normalized=False, onesided=True, signal_sizes=None) -> Tensor - -Complex-to-real Inverse Discrete Fourier Transform - -This method computes the complex-to-real inverse discrete Fourier transform. -It is mathematically equivalent with :func:`ifft` with differences only in -formats of the input and output. - -The argument specifications are almost identical with :func:`~torch.ifft`. -Similar to :func:`~torch.ifft`, if :attr:`normalized` is set to ``True``, -this normalizes the result by multiplying it with -:math:`\sqrt{\prod_{i=1}^K N_i}` so that the operator is unitary, where -:math:`N_i` is the size of signal dimension :math:`i`. - -.. note:: - Due to the conjugate symmetry, :attr:`input` do not need to contain the full - complex frequency values. Roughly half of the values will be sufficient, as - is the case when :attr:`input` is given by :func:`~torch.rfft` with - ``rfft(signal, onesided=True)``. In such case, set the :attr:`onesided` - argument of this method to ``True``. Moreover, the original signal shape - information can sometimes be lost, optionally set :attr:`signal_sizes` to be - the size of the original signal (without the batch dimensions if in batched - mode) to recover it with correct shape. - - Therefore, to invert an :func:`~torch.rfft`, the :attr:`normalized` and - :attr:`onesided` arguments should be set identically for :func:`~torch.irfft`, - and preferably a :attr:`signal_sizes` is given to avoid size mismatch. See the - example below for a case of size mismatch. - - See :func:`~torch.rfft` for details on conjugate symmetry. - -The inverse of this function is :func:`~torch.rfft`. - -.. deprecated:: 1.7.0 - The function :func:`torch.irfft` is deprecated and will be removed in a - future PyTorch release. Use the new :ref:`torch.fft ` - module functions, instead, by importing :ref:`torch.fft ` - and calling :func:`torch.fft.irfft` for one-sided input, or - :func:`torch.fft.ifft` for two-sided input. - -.. warning:: - Generally speaking, input to this function should contain values - following conjugate symmetry. Note that even if :attr:`onesided` is - ``True``, often symmetry on some part is still needed. When this - requirement is not satisfied, the behavior of :func:`~torch.irfft` is - undefined. Since :func:`torch.autograd.gradcheck` estimates numerical - Jacobian with point perturbations, :func:`~torch.irfft` will almost - certainly fail the check. - -.. note:: - For CUDA tensors, an LRU cache is used for cuFFT plans to speed up - repeatedly running FFT methods on tensors of same geometry with same - configuration. See :ref:`cufft-plan-cache` for more details on how to - monitor and control the cache. - -.. warning:: - Due to limited dynamic range of half datatype, performing this operation in half - precision may cause the first element of result to overflow for certain inputs. - -.. warning:: - For CPU tensors, this method is currently only available with MKL. Use - :func:`torch.backends.mkl.is_available` to check if MKL is installed. - -Arguments: - input (Tensor): the input tensor of at least :attr:`signal_ndim` ``+ 1`` - dimensions - signal_ndim (int): the number of dimensions in each signal. - :attr:`signal_ndim` can only be 1, 2 or 3 - normalized (bool, optional): controls whether to return normalized results. - Default: ``False`` - onesided (bool, optional): controls whether :attr:`input` was halfed to avoid - redundancy, e.g., by :func:`rfft`. Default: ``True`` - signal_sizes (list or :class:`torch.Size`, optional): the size of the original - signal (without batch dimension). Default: ``None`` - -Returns: - Tensor: A tensor containing the complex-to-real inverse Fourier transform result - -Example:: - - >>> x = torch.randn(4, 4) - >>> torch.rfft(x, 2, onesided=True).shape - torch.Size([4, 3, 2]) - >>> - >>> # notice that with onesided=True, output size does not determine the original signal size - >>> x = torch.randn(4, 5) - - >>> torch.rfft(x, 2, onesided=True).shape - torch.Size([4, 3, 2]) - >>> - >>> # now we use the original shape to recover x - >>> x - tensor([[-0.8992, 0.6117, -1.6091, -0.4155, -0.8346], - [-2.1596, -0.0853, 0.7232, 0.1941, -0.0789], - [-2.0329, 1.1031, 0.6869, -0.5042, 0.9895], - [-0.1884, 0.2858, -1.5831, 0.9917, -0.8356]]) - >>> y = torch.rfft(x, 2, onesided=True) - >>> torch.irfft(y, 2, onesided=True, signal_sizes=x.shape) # recover x - tensor([[-0.8992, 0.6117, -1.6091, -0.4155, -0.8346], - [-2.1596, -0.0853, 0.7232, 0.1941, -0.0789], - [-2.0329, 1.1031, 0.6869, -0.5042, 0.9895], - [-0.1884, 0.2858, -1.5831, 0.9917, -0.8356]]) - -""") - - add_docstr(torch.hann_window, """ -hann_window(window_length, periodic=True, dtype=None, \ +hann_window(window_length, periodic=True, *, dtype=None, \ layout=torch.strided, device=None, requires_grad=False) -> Tensor """ + r""" Hann window function. @@ -8908,6 +9715,8 @@ def merge_dicts(*dicts): window_length (int): the size of returned window periodic (bool, optional): If True, returns a window to be used as periodic function. If False, return a symmetric window. + +Keyword args: {dtype} Only floating point types are supported. layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only ``torch.strided`` (dense layout) is supported. @@ -8922,7 +9731,7 @@ def merge_dicts(*dicts): add_docstr(torch.hamming_window, """ -hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, dtype=None, \ +hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, \ layout=torch.strided, device=None, requires_grad=False) -> Tensor """ + r""" Hamming window function. @@ -8953,6 +9762,8 @@ def merge_dicts(*dicts): function. If False, return a symmetric window. alpha (float, optional): The coefficient :math:`\alpha` in the equation above beta (float, optional): The coefficient :math:`\beta` in the equation above + +Keyword args: {dtype} Only floating point types are supported. layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only ``torch.strided`` (dense layout) is supported. @@ -8967,7 +9778,7 @@ def merge_dicts(*dicts): add_docstr(torch.bartlett_window, """ -bartlett_window(window_length, periodic=True, dtype=None, \ +bartlett_window(window_length, periodic=True, *, dtype=None, \ layout=torch.strided, device=None, requires_grad=False) -> Tensor """ + r""" Bartlett window function. @@ -8996,6 +9807,8 @@ def merge_dicts(*dicts): window_length (int): the size of returned window periodic (bool, optional): If True, returns a window to be used as periodic function. If False, return a symmetric window. + +Keyword args: {dtype} Only floating point types are supported. layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only ``torch.strided`` (dense layout) is supported. @@ -9010,7 +9823,7 @@ def merge_dicts(*dicts): add_docstr(torch.blackman_window, """ -blackman_window(window_length, periodic=True, dtype=None, \ +blackman_window(window_length, periodic=True, *, dtype=None, \ layout=torch.strided, device=None, requires_grad=False) -> Tensor """ + r""" Blackman window function. @@ -9036,6 +9849,8 @@ def merge_dicts(*dicts): window_length (int): the size of returned window periodic (bool, optional): If True, returns a window to be used as periodic function. If False, return a symmetric window. + +Keyword args: {dtype} Only floating point types are supported. layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only ``torch.strided`` (dense layout) is supported. @@ -9062,7 +9877,7 @@ def merge_dicts(*dicts): out_i = I_0 \left( \beta \sqrt{1 - \left( {\frac{i - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta ) Calling ``torch.kaiser_window(L, B, periodic=True)`` is equivalent to calling -``torch.kaiser_window(L + 1, B, periodic=False)[:-1])``. +``torch.kaiser_window(L + 1, B, periodic=False)[:-1])``. The :attr:`periodic` argument is intended as a helpful shorthand to produce a periodic window as input to functions like :func:`torch.stft`. @@ -9230,6 +10045,8 @@ def merge_dicts(*dicts): Arguments: y (Tensor): The values of the function to integrate + +Keyword args: dx (float): The distance between points at which `y` is sampled. dim (int): The dimension along which to integrate. By default, use the last dimension. @@ -9259,8 +10076,7 @@ def merge_dicts(*dicts): array. Returns: - Tensor: Repeated tensor which has the same shape as input, except along the - given axis. + Tensor: Repeated tensor which has the same shape as input, except along the given axis. Example:: @@ -9285,6 +10101,45 @@ def merge_dicts(*dicts): `1` appears `n2` times, `2` appears `n3` times, etc. """.format(**common_args)) +add_docstr(torch.tile, r""" +tile(input, reps) -> Tensor + +Constructs a tensor by repeating the elements of :attr:`input`. +The :attr:`reps` argument specifies the number of repetitions +in each dimension. + +If :attr:`reps` specifies fewer dimensions than :attr:`input` has, then +ones are prepended to :attr:`reps` until all dimensions are specified. +For example, if :attr:`input` has shape (8, 6, 4, 2) and :attr:`reps` +is (2, 2), then :attr:`reps` is treated as (1, 1, 2, 2). + +Analogously, if :attr:`input` has fewer dimensions than :attr:`reps` +specifies, then :attr:`input` is treated as if it were unsqueezed at +dimension zero until it has as many dimensions as :attr:`reps` specifies. +For example, if :attr:`input` has shape (4, 2) and :attr:`reps` +is (3, 3, 2, 2), then :attr:`input` is treated as if it had the +shape (1, 1, 4, 2). + +.. note:: + + This function is similar to NumPy's tile function. + +Args: + input (Tensor): the tensor whose elements to repeat. + reps (tuple): the number of repetitions per dimension. + +Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.tile((2,)) + tensor([1, 2, 3, 1, 2, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.tile(y, (2, 2)) + tensor([[1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4]]) +""") add_docstr(torch.quantize_per_tensor, r""" @@ -9465,7 +10320,7 @@ def merge_dicts(*dicts): add_docstr(torch.searchsorted, r""" -searchsorted(sorted_sequence, values, out_int32=False, right=False, out=None) -> Tensor +searchsorted(sorted_sequence, values, *, out_int32=False, right=False, out=None) -> Tensor Find the indices from the *innermost* dimension of :attr:`sorted_sequence` such that, if the corresponding values in :attr:`values` were inserted before the indices, the order of the @@ -9483,21 +10338,23 @@ def merge_dicts(*dicts): - *returned index satisfies* * - 1-D - False - - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` + - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` * - 1-D - True - - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` + - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` * - N-D - False - - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` + - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` * - N-D - True - - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` + - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` Args: sorted_sequence (Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the *innermost* dimension. values (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + +Keyword args: out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. Default value is False, i.e. default output data type is torch.int64. right (bool, optional): if False, return the first suitable location that is found. If True, return the @@ -9540,7 +10397,7 @@ def merge_dicts(*dicts): add_docstr(torch.bucketize, r""" -bucketize(input, boundaries, out_int32=False, right=False, out=None) -> Tensor +bucketize(input, boundaries, *, out_int32=False, right=False, out=None) -> Tensor Returns the indices of the buckets to which each value in the :attr:`input` belongs, where the boundaries of the buckets are set by :attr:`boundaries`. Return a new tensor with the same size @@ -9554,13 +10411,15 @@ def merge_dicts(*dicts): * - :attr:`right` - *returned index satisfies* * - False - - ``boundaries[i-1] <= input[m][n]...[l][x] < boundaries[i]`` - * - True - ``boundaries[i-1] < input[m][n]...[l][x] <= boundaries[i]`` + * - True + - ``boundaries[i-1] <= input[m][n]...[l][x] < boundaries[i]`` Args: input (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). boundaries (Tensor): 1-D tensor, must contain a monotonically increasing sequence. + +Keyword args: out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. Default value is False, i.e. default output data type is torch.int64. right (bool, optional): if False, return the first suitable location that is found. If True, return the diff --git a/torch/_utils.py b/torch/_utils.py index 11f378a4d7f9b..75eadd4a990e1 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -1,12 +1,13 @@ import torch import torch._six -from typing import Optional +from typing import Optional, List, DefaultDict import warnings from collections import defaultdict import sys import traceback + def _type(self, dtype=None, non_blocking=False, **kwargs): """Returns the type if `dtype` is not provided, else casts this object to the specified type. @@ -36,9 +37,9 @@ def _type(self, dtype=None, non_blocking=False, **kwargs): raise RuntimeError("Cannot cast sparse tensor to dense tensor") new_module_name = dtype.__module__.replace('.sparse', '') new_values_type_name = new_module_name + '.' + dtype.__name__ - new_values = torch._values(self).type(new_values_type_name, non_blocking) + new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking) new_indices_type_name = new_module_name + '.LongTensor' - new_indices = torch._indices(self).type(new_indices_type_name, non_blocking) + new_indices = torch.Tensor._indices(self).type(new_indices_type_name, non_blocking) return dtype(new_indices, new_values, self.size()) if dtype.is_sparse: raise RuntimeError("Cannot cast dense tensor to sparse tensor") @@ -71,8 +72,8 @@ def _cuda(self, device=None, non_blocking=False, **kwargs): with torch.cuda.device(device): if self.is_sparse: new_type = getattr(torch.cuda.sparse, self.__class__.__name__) - indices = torch._indices(self).cuda(device, non_blocking) - values = torch._values(self).cuda(device, non_blocking) + indices = torch.Tensor._indices(self).cuda(device, non_blocking) + values = torch.Tensor._values(self).cuda(device, non_blocking) return new_type(indices, values, self.size()) else: new_type = getattr(torch.cuda, self.__class__.__name__) @@ -143,7 +144,7 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, bac return tensor -_sparse_tensors_to_validate = [] +_sparse_tensors_to_validate: List["torch.Tensor"] = [] # In _legacy_load() in serialization.py we unpickle storages after the sparse # tensors have been already unpickled. Those storages contain data necessary for @@ -247,7 +248,7 @@ def _flatten_dense_tensors(tensors): buffer. Element-wise operation on this buffer will be equivalent to operating individually. - Arguments: + Args: tensors (Iterable[Tensor]): dense tensors to flatten. Returns: @@ -263,15 +264,15 @@ def _flatten_sparse_tensors(tensors): """Flatten sparse tensors into two contiguous 1D buffers, one of indices and one of values. Assume tensors are of same sparse type. - Arguments: + Args: tensors (Iterable[Tensor]): sparse tensors to flatten. Returns: A tuple of two contiguous 1D buffers, one containing input tensors' indices and the other containing the values. """ - flat_indices = _flatten_dense_tensors([torch._indices(t) for t in tensors]) - flat_values = _flatten_dense_tensors([torch._values(t) for t in tensors]) + flat_indices = _flatten_dense_tensors([torch.Tensor._indices(t) for t in tensors]) + flat_values = _flatten_dense_tensors([torch.Tensor._values(t) for t in tensors]) return flat_indices, flat_values @@ -279,7 +280,7 @@ def _unflatten_dense_tensors(flat, tensors): """View a flat buffer using the sizes of tensors. Assume that tensors are of same dense type, and that flat is given by _flatten_dense_tensors. - Arguments: + Args: flat (Tensor): flattened dense tensors to unflatten. tensors (Iterable[Tensor]): dense tensors whose sizes will be used to unflatten flat. @@ -302,7 +303,7 @@ def _unflatten_sparse_tensors(flat, tensors): tensors. Assume that tensors are of same sparse type, and that flat is given by _flatten_sparse_tensors. - Arguments: + Args: flat (tuple(Tensor, Tensor)): flattened indices and values of sparse tensors to unflatten. tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to @@ -313,8 +314,8 @@ def _unflatten_sparse_tensors(flat, tensors): flat. """ flat_indices, flat_values = flat - indices = _unflatten_dense_tensors(flat_indices, [torch._indices(t) for t in tensors]) - values = _unflatten_dense_tensors(flat_values, [torch._values(t) for t in tensors]) + indices = _unflatten_dense_tensors(flat_indices, [torch.Tensor._indices(t) for t in tensors]) + values = _unflatten_dense_tensors(flat_values, [torch.Tensor._values(t) for t in tensors]) outputs = [] for t, i, v in zip(tensors, indices, values): outputs.append(t.new(i, v, t.size())) @@ -326,7 +327,7 @@ def _reorder_tensors_as(tensors, ordered_tensors): types, e.g., from _take_tensors. Reorder them to be of same order as ordered_tensors. - Arguments: + Args: tensors (Iterable[Tensor]): tensors to be reordered. They should be of the same order as ordered_tensors within their own types. ordered_tensors (Iterable[Tensor]): tensors whose order will be the @@ -339,8 +340,8 @@ def _reorder_tensors_as(tensors, ordered_tensors): type_dict = defaultdict(list) for tensor in tensors: type_dict[tensor.type()].append(tensor) - type_dict = {t: iter(coll) for t, coll in type_dict.items()} - return tuple(next(type_dict[tensor.type()]) for tensor in ordered_tensors) + type_dict_ = {t: iter(coll) for t, coll in type_dict.items()} + return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors) def _take_tensors(tensors, size_limit): @@ -355,12 +356,12 @@ def _take_tensors(tensors, size_limit): Blocks of tensors of same type and within size_limit. The yielded tensors are only ordered as the original sequence within its types. """ - buf_dict = defaultdict(lambda: [[], 0]) + buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0]) for tensor in tensors: t = tensor.type() if tensor.is_sparse: - indices = torch._indices(tensor) - values = torch._values(tensor) + indices = torch.Tensor._indices(tensor) + values = torch.Tensor._values(tensor) size = indices.numel() * indices.element_size() + values.numel() * values.element_size() else: size = tensor.numel() * tensor.element_size() @@ -437,7 +438,7 @@ def _get_available_device_type(): def _get_device_attr(get_member): device_type = _get_available_device_type() - if device_type.lower() == "cuda": + if device_type and device_type.lower() == "cuda": return get_member(torch.cuda) # add more available device types here return None @@ -491,3 +492,12 @@ def _get_device_index(device, optional=False, allow_cpu=False) -> int: raise ValueError('Expected a torch.device with a specified index ' 'or an integer, but got:{}'.format(device)) return device_idx + + +def _handle_complex(tensor): + """ + Returns a real view of a tensor if complex dtype else just the tensor + need to check if a UninitializedParameter because otherwise checking is_complex is an error for a LazyModule + """ + return torch.view_as_real(tensor) if not isinstance(tensor, + torch.nn.UninitializedParameter) and tensor.is_complex() else tensor diff --git a/torch/_vmap_internals.py b/torch/_vmap_internals.py index b306b63438e64..26f32cfd9ffdb 100644 --- a/torch/_vmap_internals.py +++ b/torch/_vmap_internals.py @@ -1,17 +1,18 @@ import torch import functools from torch import Tensor -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union, List +from torch.utils._pytree import tree_flatten, tree_unflatten, _broadcast_to_and_flatten import warnings -in_dims_t = Union[int, Tuple[Optional[int], ...]] +in_dims_t = Union[int, Tuple] out_dims_t = Union[int, Tuple[int, ...]] # Checks that all args-to-be-batched have the same batch dim size def _validate_and_get_batch_size( - in_dims_as_tuple: Tuple[Optional[int], ...], - args: Tuple) -> int: - batch_sizes = [arg.size(in_dim) for in_dim, arg in zip(in_dims_as_tuple, args) + flat_in_dims: List[Optional[int]], + flat_args: List) -> int: + batch_sizes = [arg.size(in_dim) for in_dim, arg in zip(flat_in_dims, flat_args) if in_dim is not None] if batch_sizes and any([size != batch_sizes[0] for size in batch_sizes]): raise ValueError( @@ -19,40 +20,6 @@ def _validate_and_get_batch_size( f'dimension, got sizes {batch_sizes} for the mapped dimension') return batch_sizes[0] -# Check compatibility of `in_dims` and `args`. More specifically, checks the following: -# Wherever an in_dim is not None, then the corresponding index in args must be -# a Tensor. Furthermore, tensor must have the `in_dim` (0 <= in_dim < tensor.dim()) -def _check_args_can_be_mapped_with_in_dims( - in_dims_as_tuple: Tuple[Optional[int], ...], - args: Tuple, - func: Callable, - in_dims: in_dims_t) -> None: - for idx, (in_dim, arg) in enumerate(zip(in_dims_as_tuple, args)): - if in_dim is None: - continue - if not isinstance(in_dim, int): - raise ValueError( - f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): in_dims ' - f'must be a flat tuple containing ints and/or Nones. If you were ' - f'trying to vmap over a Tensor inside a Python collection in ' - f'`inputs`, we do not yet support that.') - if not isinstance(arg, Tensor): - raise ValueError( - f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): Got ' - f'in_dim={in_dim} for input {idx}, but input {idx} is not a ' - f'Tensor (got {type(arg)}) so it cannot be vmap\'ed over. ' - f'If you were trying to vmap over a Tensor inside a Python ' - f'collection in `inputs`, we do not yet support that; otherwise, ' - f'use None as the respective in_dim for input {idx}.') - # NB: We don't do dimension wrapping here. Consider allowing it in the - # future if there is demand. - if in_dim >= 0 and in_dim < arg.dim(): - continue - raise ValueError( - f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): Got in_dim={in_dim} ' - f'for input {idx}, but input {idx} is a Tensor of dimensionality ' - f'{arg.dim()} so expected in_dim to satisfy 0 <= in_dim < {arg.dim()}.') - def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int: if isinstance(batched_outputs, tuple): return len(batched_outputs) @@ -73,28 +40,49 @@ def _create_batched_inputs( in_dims: in_dims_t, args: Tuple, vmap_level: int, func: Callable) -> Tuple[Tuple, int]: if not isinstance(in_dims, int) and not isinstance(in_dims, tuple): raise ValueError( - f'vmap({_get_name(func)}, in_dims={in_dims}, ...): expected `in_dims` to ' - f'be int or tuple, got: {type(in_dims)}.') - - # NB: Checks that len(in_dims) == len(args) (if in_dims is a tuple). - in_dims_as_tuple = _as_tuple( - in_dims, len(args), - lambda: f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): expected ' - f'one `in_dim` per input (got {len(args)} inputs) of {_get_name(func)}') - + f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' + f'expected `in_dims` to be int or a (potentially nested) tuple ' + f'matching the structure of inputs, got: {type(in_dims)}.') if len(args) == 0: raise ValueError( f'vmap({_get_name(func)})(): got no inputs. Maybe you forgot to add ' f'inputs, or you are trying to vmap over a function with no inputs. ' f'The latter is unsupported.') - _check_args_can_be_mapped_with_in_dims(in_dims_as_tuple, args, func, in_dims) - batch_size = _validate_and_get_batch_size(in_dims_as_tuple, args) + flat_args, args_spec = tree_flatten(args) + flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec) + if flat_in_dims is None: + raise ValueError( + f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' + f'in_dims is not compatible with the structure of `inputs`. ' + f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs ' + f'has structure {args_spec}.') + + for arg, in_dim in zip(flat_args, flat_in_dims): + if not isinstance(in_dim, int) and in_dim is not None: + raise ValueError( + f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' + f'Got in_dim={in_dim} for an input but in_dim must be either ' + f'an integer dimension or None.') + if isinstance(in_dim, int) and not isinstance(arg, Tensor): + raise ValueError( + f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' + f'Got in_dim={in_dim} for an input but the input is of type ' + f'{type(arg)}. We cannot vmap over non-Tensor arguments, ' + f'please use None as the respective in_dim') + if in_dim is not None and (in_dim < 0 or in_dim >= arg.dim()): + raise ValueError( + f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' + f'Got in_dim={in_dim} for some input, but that input is a Tensor ' + f'of dimensionality {arg.dim()} so expected in_dim to satisfy ' + f'0 <= in_dim < {arg.dim()}.') + + batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args) # See NOTE [Ignored _remove_batch_dim, _add_batch_dim] - batched_inputs = tuple(arg if in_dim is None else - torch._add_batch_dim(arg, in_dim, vmap_level) # type: ignore - for in_dim, arg in zip(in_dims_as_tuple, args)) - return batched_inputs, batch_size + batched_inputs = [arg if in_dim is None else + torch._add_batch_dim(arg, in_dim, vmap_level) # type: ignore + for in_dim, arg in zip(flat_in_dims, flat_args)] + return tree_unflatten(batched_inputs, args_spec), batch_size # Undos the batching (and any batch dimensions) associated with the `vmap_level`. def _unwrap_batched( @@ -149,7 +137,7 @@ def _get_name(func: Callable): # Not all callables have __name__, in fact, only static functions/methods do. # A callable created via functools.partial or an nn.Module, to name some # examples, don't have a __name__. - fn_name = repr(func) + return repr(func) # vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors, # sends those into func, and then unwraps the output BatchedTensors. Operations @@ -178,10 +166,10 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Ca Args: func (function): A Python function that takes one or more arguments. Must return one or more Tensors. - in_dims (int or Tuple[Optional[int]]): Specifies which dimension of the - inputs should be mapped over. If `in_dims` is a Tuple, then it should have - one element per input. If the `in_dim` for a particular input is - None, then that indicates there is no map dimension. Default: 0. + in_dims (int or nested structure): Specifies which dimension of the + inputs should be mapped over. `in_dims` should have a structure + like the inputs. If the `in_dim` for a particular input is None, + then that indicates there is no map dimension. Default: 0. out_dims (int or Tuple[int]): Specifies where the mapped dimension should appear in the outputs. If `out_dims` is a Tuple, then it should have one element per output. Default: 0. diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index d515eb49695da..a013c9eb73261 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -18,8 +18,8 @@ from .grad_mode import no_grad, enable_grad, set_grad_enabled from .anomaly_mode import detect_anomaly, set_detect_anomaly from ..overrides import has_torch_function, handle_torch_function -from . import profiler from . import functional +from . import forward_ad __all__ = ['Variable', 'Function', 'backward', 'grad_mode'] @@ -71,6 +71,7 @@ def backward( retain_graph: Optional[bool] = None, create_graph: bool = False, grad_variables: Optional[_TensorOrTensors] = None, + inputs: Optional[Sequence[torch.Tensor]] = None, ) -> None: r"""Computes the sum of gradients of given tensors w.r.t. graph leaves. @@ -95,7 +96,13 @@ def backward( If you have to use this function, make sure to reset the ``.grad`` fields of your parameters to ``None`` after use to break the cycle and avoid the leak. - Arguments: + .. note:: + + If you run any forward ops, create ``grad_tensors``, and/or call ``backward`` + in a user-specified CUDA stream context, see + :ref:`Stream semantics of backward passes`. + + Args: tensors (sequence of Tensor): Tensors of which the derivative will be computed. grad_tensors (sequence of (Tensor or None)): The "vector" in the Jacobian-vector @@ -110,6 +117,11 @@ def backward( create_graph (bool, optional): If ``True``, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to ``False``. + inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be + accumulated into ``.grad``. All other Tensors will be ignored. If not + provided, the gradient is accumulated into all the leaf Tensors that were + used to compute the attr::tensors. All the provided inputs must be leaf + Tensors. """ if grad_variables is not None: warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.") @@ -119,8 +131,11 @@ def backward( raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) " "arguments both passed to backward(). Please only " "use 'grad_tensors'.") + if inputs is not None and len(inputs) == 0: + raise RuntimeError("'inputs' argument to backward() cannot be empty.") tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors) + inputs = tuple(inputs) if inputs is not None else tuple() grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors)) grad_tensors_ = _make_grads(tensors, grad_tensors_) @@ -128,8 +143,8 @@ def backward( retain_graph = create_graph Variable._execution_engine.run_backward( - tensors, grad_tensors_, retain_graph, create_graph, - allow_unreachable=True) # allow_unreachable flag + tensors, grad_tensors_, retain_graph, create_graph, inputs, + allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag def grad( @@ -153,7 +168,13 @@ def grad( leaves will still be computed, and will be accumulated into their ``.grad`` attribute. - Arguments: + .. note:: + + If you run any forward ops, create ``grad_outputs``, and/or call ``grad`` + in a user-specified CUDA stream context, see + :ref:`Stream semantics of backward passes`. + + Args: outputs (sequence of Tensor): outputs of the differentiated function. inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be returned (and not accumulated into ``.grad``). @@ -184,7 +205,7 @@ def grad( grad_outputs=grad_outputs, retain_graph=retain_graph, create_graph=create_graph, - only_inputs=only_inputs, + only_inputs=only_inputs, allow_unused=allow_unused, ) @@ -201,7 +222,7 @@ def grad( return Variable._execution_engine.run_backward( outputs, grad_outputs_, retain_graph, create_graph, - inputs, allow_unused) + inputs, allow_unused, accumulate_grad=False) # This function applies in case of gradient checkpointing for memory @@ -230,6 +251,12 @@ def variable(*args, **kwargs): raise RuntimeError("autograd initialization failed") # Import all native method/classes -from torch._C._autograd import (ProfilerState, ProfilerConfig, ProfilerEvent, - _enable_profiler, _disable_profiler, _profiler_enabled, - _enable_record_function, _set_empty_test_observer) +from torch._C._autograd import (DeviceType, ProfilerActivity, ProfilerState, ProfilerConfig, ProfilerEvent, + _enable_profiler_legacy, _disable_profiler_legacy, _profiler_enabled, + _enable_record_function, _set_empty_test_observer, kineto_available) + +if kineto_available(): + from torch._C._autograd import (ProfilerResult, KinetoEvent, + _prepare_profiler, _enable_profiler, _disable_profiler) + +from . import profiler diff --git a/torch/autograd/anomaly_mode.py b/torch/autograd/anomaly_mode.py index 6a33f1780ef10..97def2dea3d0f 100644 --- a/torch/autograd/anomaly_mode.py +++ b/torch/autograd/anomaly_mode.py @@ -89,7 +89,7 @@ class set_detect_anomaly(object): See ``detect_anomaly`` above for details of the anomaly detection behaviour. - Arguments: + Args: mode (bool): Flag whether to enable anomaly detection (``True``), or disable (``False``). diff --git a/torch/autograd/forward_ad.py b/torch/autograd/forward_ad.py new file mode 100644 index 0000000000000..3ad989621c889 --- /dev/null +++ b/torch/autograd/forward_ad.py @@ -0,0 +1,116 @@ +import torch +from .grad_mode import _DecoratorContextManager + +from typing import Any + +# TODO(alband): Once most of the formulas are implemented, these functions need to be added +# to the main doc to make them fully "public". + +# Global variable used to make the python API simpler to use +_current_level = -1 + +def enter_dual_level(): + r"""Function that can be used to enter a new forward grad level. + This level can be used to make and unpack dual Tensors to compute + forward gradients. + + This function also updates the current level that is used by default + by the other functions in this API. + """ + global _current_level + new_level = torch._C._enter_dual_level() + if new_level != _current_level + 1: + raise RuntimeError("Entering a new forward AD level but the current level " + "is not valid. Make sure you did not modified it directly.") + _current_level = new_level + return new_level + +def exit_dual_level(*, level=None): + r"""Function that can be used to exit a forward grad level. + This function deletes all the gradients associated with this + level. Only deleting the latest entered level is allowed. + + This function also updates the current level that is used by default + by the other functions in this API. + """ + global _current_level + if level is None: + level = _current_level + if level != _current_level: + raise RuntimeError("Trying to exit a forward AD level that was not the last one " + "that was created. This is not supported.") + torch._C._exit_dual_level(level=level) + _current_level = level - 1 + +def make_dual(tensor, tangent, *, level=None): + r"""Function that creates a "dual object" that can be used to compute forward AD gradients + based on the given Tensor and its tangent. It returns a new Tensor that shares memory with + :attr:`tensor` and the :attr:`tangent` is used as-is. + + This function is backward differentiable. + + Given a function `f` whose jacobian is `J`, it allows to compute the jacobian vector product, + named `jvp`, between `J` and a given vector `v` as follows. + + Example:: + >>> inp = make_dual(x, v) + >>> out = f(inp) + >>> y, jvp = unpack_dual(out) + + """ + if level is None: + level = _current_level + + if level < 0: + raise RuntimeError("Trying to create a dual Tensor for forward AD but no level " + "exists, make sure to enter_dual_level() first.") + + return torch.make_dual(tensor, tangent, level=level) + +def unpack_dual(tensor, *, level=None): + r"""Function that unpacks a "dual object" to recover two plain tensors, one representing + the primal and the other the tangent (both are views of :attr:`tensor`. Neither of these + tensors can be dual tensor of level :attr:`level`. + + This function is backward differentiable. + """ + if level is None: + level = _current_level + + if level < 0: + return tensor, None + + return torch.unpack_dual(tensor, level=level) + +class dual_level(_DecoratorContextManager): + r"""Context-manager that controls the current forward ad level. It + appropriately enters and exit the dual level. + + This function also updates the current level that is used by default + by the other functions in this API. + + Example:: + + >>> x = torch.tensor([1]) + >>> x_t = torch.tensor([1]) + >>> with dual_level(): + ... inp = make_dual(x, x_t) + ... # Do computations with inp + ... out = your_fn(inp) + ... _, grad = unpack_dual(out) + >>> grad is None + False + >>> # After exiting the level, the grad is deleted + >>> _, grad_after = unpack_dual(out) + >>> grad is None + True + + """ + def __init__(self): + super().__init__() + + def __enter__(self): + return enter_dual_level() + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + exit_dual_level() diff --git a/torch/autograd/function.py b/torch/autograd/function.py index 6714444acdcfa..0d546ceb28d63 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -1,11 +1,12 @@ import torch import torch._C as _C +from torch._C import _functions import torch.utils.hooks as hooks from torch._six import with_metaclass import functools import warnings from collections import OrderedDict -from typing import Any +from typing import Any, List, Optional class _ContextMethodMixin(object): @@ -84,7 +85,8 @@ class BackwardCFunction(_C._FunctionBase, _ContextMethodMixin, _HookMixin): _is_legacy = False def apply(self, *args): - return self._forward_cls.backward(self, *args) + # _forward_cls is defined by derived class + return self._forward_cls.backward(self, *args) # type: ignore class FunctionMeta(type): @@ -115,8 +117,8 @@ def __init__(cls, name, bases, attrs): return super(FunctionMeta, cls).__init__(name, bases, attrs) - -class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)): +# mypy doesn't understand `with_metaclass` from torch._six +class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)): # type: ignore r"""Records operation history and defines formulas for differentiating ops. See the Note on extending the autograd engine for more details on how to use @@ -227,7 +229,7 @@ def wrapper(ctx, *args): if not isinstance(outputs, tuple): outputs = (outputs,) - err_fn = torch._C._functions.DelayedError( + err_fn = _functions.DelayedError( b"trying to differentiate twice a function that was marked" b"with @once_differentiable", len(outputs)) @@ -330,7 +332,7 @@ def _unflatten(input, proto): # unflatten a list or tuple input into a nested list/tuple structure # specified by proto def unflatten_helper(input, proto): - res = [] + res: List[Optional[torch.Tensor]] = [] if hasattr(proto, "_jit_wrap"): return proto._jit_wrap(input) if not isinstance(proto, (list, tuple)): @@ -379,16 +381,16 @@ def _do_backward(self, gradients, retain_variables): del self._to_save_nested return result - def backward(self, *gradients: Any) -> Any: + def backward(self, *gradients: Any) -> Any: # type: ignore nested_gradients = _unflatten(gradients, self._nested_output) - result = self.backward_extended(*nested_gradients) + result = self.backward_extended(*nested_gradients) # type: ignore return tuple(_iter_None_tensors(result)) __call__ = _do_forward - def forward(self, *args: Any) -> Any: + def forward(self, *args: Any) -> Any: # type: ignore nested_tensors = _map_tensor_data(self._nested_input) - result = self.forward_extended(*nested_tensors) + result = self.forward_extended(*nested_tensors) # type: ignore del self._nested_input self._nested_output = result return tuple(_iter_tensors(result)) diff --git a/torch/autograd/functional.py b/torch/autograd/functional.py index 2a1d0ef55fd9a..70961cef97446 100644 --- a/torch/autograd/functional.py +++ b/torch/autograd/functional.py @@ -1,4 +1,5 @@ import torch +from typing import Tuple, List # Utility functions @@ -131,8 +132,8 @@ def _autograd_grad(outputs, inputs, grad_outputs=None, create_graph=False, retai assert isinstance(grad_outputs, tuple) assert len(outputs) == len(grad_outputs) - new_outputs = tuple() - new_grad_outputs = tuple() + new_outputs: Tuple[torch.Tensor, ...] = tuple() + new_grad_outputs: Tuple[torch.Tensor, ...] = tuple() for out, grad_out in zip(outputs, grad_outputs): if out is not None and out.requires_grad: new_outputs += (out,) @@ -153,7 +154,7 @@ def _fill_in_zeros(grads, refs, strict, create_graph, stage): if stage not in ["back", "back_trick", "double_back", "double_back_trick"]: raise RuntimeError("Invalid stage argument '{}' to _fill_in_zeros".format(stage)) - res = tuple() + res: Tuple[torch.Tensor, ...] = tuple() for i, grads_i in enumerate(grads): if grads_i is None: if strict: @@ -381,15 +382,15 @@ def jacobian(func, inputs, create_graph=False, strict=False): Defaults to ``False``. Returns: - Jacobian (Tensor or nested tuple of Tensors): if there are a single - input and output, this will be a single Tensor containing the - Jacobian for the linearized inputs and output. If one of the two is - a tuple, then the Jacobian will be a tuple of Tensors. If both of - them are tuples, then the Jacobian will be a tuple of tuple of - Tensors where ``Jacobian[i][j]`` will contain the Jacobian of the - ``i``\th output and ``j``\th input and will have as size the - concatenation of the sizes of the corresponding output and the - corresponding input. + Jacobian (Tensor or nested tuple of Tensors): if there is a single + input and output, this will be a single Tensor containing the + Jacobian for the linearized inputs and output. If one of the two is + a tuple, then the Jacobian will be a tuple of Tensors. If both of + them are tuples, then the Jacobian will be a tuple of tuple of + Tensors where ``Jacobian[i][j]`` will contain the Jacobian of the + ``i``\th output and ``j``\th input and will have as size the + concatenation of the sizes of the corresponding output and the + corresponding input. Example: @@ -427,10 +428,11 @@ def jacobian(func, inputs, create_graph=False, strict=False): "jacobian") _check_requires_grad(outputs, "outputs", strict=strict) - jacobian = tuple() + jacobian: Tuple[torch.Tensor, ...] = tuple() for i, out in enumerate(outputs): - jac_i = tuple([] for _ in range(len(inputs))) + # mypy complains that expression and variable have different types due to the empty list + jac_i: Tuple[List[torch.Tensor]] = tuple([] for _ in range(len(inputs))) # type: ignore for j in range(out.nelement()): vj = _autograd_grad((out.reshape(-1)[j],), inputs, retain_graph=True, create_graph=create_graph) @@ -476,12 +478,12 @@ def hessian(func, inputs, create_graph=False, strict=False): Defaults to ``False``. Returns: - Hessian (Tensor or a tuple of tuple of Tensors) if there are a single input, - this will be a single Tensor containing the Hessian for the input. - If it is a tuple, then the Hessian will be a tuple of tuples where - ``Hessian[i][j]`` will contain the Hessian of the ``i``\th input - and ``j``\th input with size the sum of the size of the ``i``\th input plus - the size of the ``j``\th input. + Hessian (Tensor or a tuple of tuple of Tensors): if there is a single input, + this will be a single Tensor containing the Hessian for the input. + If it is a tuple, then the Hessian will be a tuple of tuples where + ``Hessian[i][j]`` will contain the Hessian of the ``i``\th input + and ``j``\th input with size the sum of the size of the ``i``\th input plus + the size of the ``j``\th input. Example: @@ -660,7 +662,9 @@ def hvp(func, inputs, v=None, create_graph=False, strict=False): hvp for said inputs, which is the expected mathematical value. Defaults to ``False``. Returns: - func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` + output (tuple): tuple with: + func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` + hvp (tuple of Tensors or Tensor): result of the dot product with the same shape as the inputs. diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py index bbd96e941a542..ebbb7dea41e95 100644 --- a/torch/autograd/grad_mode.py +++ b/torch/autograd/grad_mode.py @@ -1,3 +1,4 @@ +import sys import torch import functools import inspect @@ -31,13 +32,46 @@ def _wrap_generator(self, func): @functools.wraps(func) def generator_context(*args, **kwargs): gen = func(*args, **kwargs) - while True: - try: - with self.__class__(): - x = next(gen) - yield x - except StopIteration: - break + + # Generators are suspended and unsuspended at `yield`, hence we + # make sure the grad mode is properly set every time the execution + # flow returns into the wrapped generator and restored when it + # returns through our `yield` to our caller (see PR #49017). + cls = type(self) + try: + # Issuing `None` to a generator fires it up + with cls(): + response = gen.send(None) + + while True: + try: + # Forward the response to our caller and get its next request + request = yield response + + except GeneratorExit: + # Inform the still active generator about its imminent closure + with cls(): + gen.close() + raise + + except BaseException: + # Propagate the exception thrown at us by the caller + with cls(): + response = gen.throw(*sys.exc_info()) + + else: + # Pass the last request to the generator and get its response + with cls(): + response = gen.send(request) + + # We let the exceptions raised above by the generator's `.throw` or + # `.send` methods bubble up to our caller, except for StopIteration + except StopIteration as e: + # The generator informed us that it is done: take whatever its + # returned value (if any) was and indicate that we're done too + # by returning it (see docs for python's return-statement). + return e.value + return generator_context def __enter__(self) -> None: @@ -138,7 +172,7 @@ class set_grad_enabled(object): This context manager is thread local; it will not affect computation in other threads. - Arguments: + Args: mode (bool): Flag whether to enable grad (``True``), or disable (``False``). This can be used to conditionally enable gradients. diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index 7ca1fccfce54a..829391a52cfd3 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -5,7 +5,9 @@ from torch.overrides import is_tensor_like from itertools import product import warnings -from typing import Callable, Union, Optional +from typing import Callable, Union, Optional, Iterable, List +from torch._vmap_internals import vmap +import functools def zero_gradients(x): if isinstance(x, torch.Tensor): @@ -29,15 +31,16 @@ def make_jacobian(input, num_out): lambda x: x is not None, (make_jacobian(elem, num_out) for elem in input))) if not jacobians: return None - return type(input)(jacobians) + return type(input)(jacobians) # type: ignore else: return None -def iter_tensors(x, only_requiring_grad=False): +def iter_tensors(x: Union[torch.Tensor, Iterable[torch.Tensor]], only_requiring_grad: bool = False) -> Iterable[torch.Tensor]: if is_tensor_like(x): - if x.requires_grad or not only_requiring_grad: - yield x + # mypy doesn't narrow type of `x` to torch.Tensor + if x.requires_grad or not only_requiring_grad: # type: ignore + yield x # type: ignore elif isinstance(x, container_abcs.Iterable) and not isinstance(x, str): for elem in x: for result in iter_tensors(elem, only_requiring_grad): @@ -102,13 +105,11 @@ def fn_out(): d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj() elif ds_dx.is_complex(): # R -> C # w_d = conj_w_d = 0.5 * ds_dx - dL_dz_conj = 0.5 * (grad_out.conjugate() * ds_dx + grad_out * ds_dx.conj()) - # The above formula is derived for a C -> C function that's a part of - # bigger function with real valued output. From separate calculations, - # it can be verified that the gradient for R -> C function - # equals to real value of the result obtained from the generic formula for - # C -> C functions used above. - d[d_idx] = torch.real(dL_dz_conj) + # dL_dz_conj = 0.5 * [grad_out.conj() * ds_dx + grad_out * ds_dx.conj()] + # = 0.5 * [grad_out.conj() * ds_dx + (grad_out.conj() * ds_dx).conj()] + # = 0.5 * 2 * real(grad_out.conj() * ds_dx) + # = real(grad_out.conj() * ds_dx) + d[d_idx] = torch.real(grad_out.conjugate() * ds_dx) else: # R -> R d[d_idx] = ds_dx * grad_out @@ -139,7 +140,7 @@ def get_stride(size): indices = x_indices[i].tolist() + list(x_idx) d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size))) update_jacobians(x_value, x_idx, d_tensor, d_idx) - elif x_tensor.layout == torch._mkldnn: + elif x_tensor.layout == torch._mkldnn: # type: ignore # Use .data here to get around the version check x_tensor = x_tensor.data if len(input) != 1: @@ -165,7 +166,7 @@ def get_analytical_jacobian(input, output, nondet_tol=0.0, grad_out=1.0): if output.is_sparse: raise ValueError('Sparse output is not supported at gradcheck yet. ' 'Please call to_dense() on the output of fn for gradcheck.') - if output.layout == torch._mkldnn: + if output.layout == torch._mkldnn: # type: ignore raise ValueError('MKLDNN output is not supported at gradcheck yet. ' 'Please call to_dense() on the output of fn for gradcheck.') diff_input_list = list(iter_tensors(input, True)) @@ -202,6 +203,59 @@ def get_analytical_jacobian(input, output, nondet_tol=0.0, grad_out=1.0): return jacobian, reentrant, correct_grad_sizes, correct_grad_types +def failed_batched_grad_test_msg(output_idx, input_idx, res, exp): + return f""" +For output {output_idx} and input {input_idx}, batched grad computation failed. + +If you are adding a new operator, a fast workaround is to use gradcheck with +`check_batched_grad=False` and filing an issue. If the test is autogenerated +from e.g. common_method_invocations, there should be a deny list you can add +the test to. If you're interested in making your new operator work with +batched grad computation, read on. + +If you're modifying an existing operator that supports batched grad computation, +or wish to make a new operator work with batched grad computation, please read +the following. + +To compute batched grads (e.g., jacobians, hessians), we vmap over the backward +computation. The most common failure case is if there is a 'vmap-incompatible +operation' in the backward pass. Please see +NOTE: [How to write vmap-compatible backward formulas] +in the codebase for an explanation of how to fix this. + +Got: +{res} + +Expected: +{exp} +""".strip() + +def test_batched_grad(fail_test, input, output, output_idx): + diff_input_list = list(iter_tensors(input, True)) + grad = functools.partial(torch.autograd.grad, output, diff_input_list, retain_graph=True, allow_unused=True) + + def vjp(v): + results = grad(v) + results = tuple(grad if grad is not None else + torch.zeros([], dtype=inp.dtype, device=inp.device).expand(inp.shape) + for grad, inp in zip(results, diff_input_list)) + return results + + grad_outputs = [torch.randn_like(output) for _ in range(2)] + + expected = [vjp(gO) for gO in grad_outputs] + expected = [torch.stack(shards) for shards in zip(*expected)] + + # Squash warnings since these are expected to happen in most cases + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="Batching rule not implemented") + warnings.filterwarnings("ignore", message="torch.vmap is an experimental prototype") + result = vmap(vjp)(torch.stack(grad_outputs)) + for input_idx, (res, exp) in enumerate(zip(result, expected)): + if torch.allclose(res, exp): + continue + return fail_test(failed_batched_grad_test_msg(output_idx, input_idx, res, exp)) + def _as_tuple(x): if istuple(x): @@ -235,7 +289,8 @@ def gradcheck( check_sparse_nnz: bool = False, nondet_tol: float = 0.0, check_undefined_grad: bool = True, - check_grad_dtypes: bool = False + check_grad_dtypes: bool = False, + check_batched_grad: bool = False, ) -> bool: r"""Check gradients computed via small finite differences against analytical gradients w.r.t. tensors in :attr:`inputs` that are of floating point or complex type @@ -277,8 +332,10 @@ def gradcheck( nondet_tol (float, optional): tolerance for non-determinism. When running identical inputs through the differentiation, the results must either match exactly (default, 0.0) or be within this tolerance. - check_undefined_grad (bool, options): if True, check if undefined output grads - are supported and treated as zeros + check_undefined_grad (bool, optional): if True, check if undefined output grads + are supported and treated as zeros, for ``Tensor`` outputs. + check_batched_grad (bool, optional): if True, check if we can compute + batched gradients using prototype vmap support. Defaults to False. Returns: True if all differences satisfy allclose condition @@ -298,20 +355,20 @@ def fail_test(msg): if is_tensor_like(inp) and inp.requires_grad: if not (inp.dtype == torch.float64 or inp.dtype == torch.complex128): warnings.warn( - 'The {}th input requires gradient and ' + f'Input #{idx} requires gradient and ' 'is not a double precision floating point or complex. ' 'This check will likely fail if all the inputs are ' 'not of double precision floating point or complex. ') content = inp._values() if inp.is_sparse else inp # TODO: To cover more problematic cases, replace stride = 0 check with # "any overlap in memory" once we have a proper function to check it. - if content.layout is not torch._mkldnn and \ - not all(st > 0 or sz <= 1 for st, sz in zip(content.stride(), content.size())): - raise RuntimeError( - 'The {}th input has a dimension with stride 0. gradcheck only ' - 'supports inputs that are non-overlapping to be able to ' - 'compute the numerical gradients correctly. You should call ' - '.contiguous on the input before passing it to gradcheck.') + if content.layout is not torch._mkldnn: # type: ignore + if not all(st > 0 or sz <= 1 for st, sz in zip(content.stride(), content.size())): + raise RuntimeError( + 'The {}th input has a dimension with stride 0. gradcheck only ' + 'supports inputs that are non-overlapping to be able to ' + 'compute the numerical gradients correctly. You should call ' + '.contiguous on the input before passing it to gradcheck.') any_input_requiring_grad = True inp.retain_grad() if not any_input_requiring_grad: @@ -402,33 +459,50 @@ def not_reentrant_error(error_str=''): if out_is_complex and not reentrant_with_imag_grad_out: return not_reentrant_error(' (calculated using complex valued grad output)') + if check_batched_grad: + assert reentrant, ('Batched gradient checking makes the assumption that ' + 'backward is reentrant. This assertion should never ' + 'be triggered: we expect gradcheck to have early ' + 'exited before reaching this point if backward is ' + 'not reentrant. Please file us a bug report.') + # NB: test_batched_grad compares two autograd.grad invocations with a single + # vmap(autograd.grad) invocation. It's not exactly a "gradcheck" in the + # sense that we're not comparing an analytical jacobian with a numeric one, + # but it is morally similar (we could have computed a full analytic jac + # via vmap, but that is potentially slow) + test_batched_grad(fail_test, tupled_inputs, o, j) + # check if the backward multiplies by grad_output output = _differentiable_outputs(func(*tupled_inputs)) if any([o.requires_grad for o in output]): - diff_input_list = list(iter_tensors(tupled_inputs, True)) + diff_input_list: List[torch.Tensor] = list(iter_tensors(tupled_inputs, True)) if not diff_input_list: raise RuntimeError("no Tensors requiring grad found in input") grads_input = torch.autograd.grad(output, diff_input_list, [torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) for o in output], allow_unused=True) - for gi, i in zip(grads_input, diff_input_list): + for gi, di in zip(grads_input, diff_input_list): if gi is None: continue if isinstance(gi, torch.Tensor) and gi.layout != torch.strided: - if gi.layout != i.layout: - return fail_test('grad is incorrect layout (' + str(gi.layout) + ' is not ' + str(i.layout) + ')') + if gi.layout != di.layout: + return fail_test('grad is incorrect layout (' + str(gi.layout) + ' is not ' + str(di.layout) + ')') if gi.layout == torch.sparse_coo: - if gi.sparse_dim() != i.sparse_dim(): + if gi.sparse_dim() != di.sparse_dim(): return fail_test('grad is sparse tensor, but has incorrect sparse_dim') - if gi.dense_dim() != i.dense_dim(): + if gi.dense_dim() != di.dense_dim(): return fail_test('grad is sparse tensor, but has incorrect dense_dim') gi = gi.to_dense() - i = i.to_dense() - if not gi.eq(0).all(): + di = di.to_dense() + + if check_sparse_nnz: + if not torch.allclose(gi, torch.zeros_like(gi)): + return fail_test('backward not multiplied by grad_output') + elif not gi.eq(0).all(): return fail_test('backward not multiplied by grad_output') - if gi.dtype != i.dtype or gi.device != i.device or gi.is_sparse != i.is_sparse: + if gi.dtype != di.dtype or gi.device != di.device or gi.is_sparse != di.is_sparse: return fail_test("grad is incorrect type") - if gi.size() != i.size(): + if gi.size() != di.size(): return fail_test('grad is incorrect size') if check_undefined_grad: @@ -463,7 +537,11 @@ def check_undefined_grad_support(output_to_check): return True # All backward functions must work properly if all output grads are undefined - outputs_to_check = [[torch._C._functions.UndefinedGrad()(o) for o in _differentiable_outputs(func(*tupled_inputs))]] + outputs_to_check = [[ + torch._C._functions.UndefinedGrad()(o) for o in _differentiable_outputs(func(*tupled_inputs)) + # This check filters out Tensor-likes that aren't instances of Tensor. + if isinstance(o, torch.Tensor) + ]] # If there are multiple output grads, we should be able to undef one at a time without error if len(outputs_to_check[0]) > 1: @@ -491,7 +569,8 @@ def gradgradcheck( raise_exception: bool = True, nondet_tol: float = 0.0, check_undefined_grad: bool = True, - check_grad_dtypes: bool = False + check_grad_dtypes: bool = False, + check_batched_grad: bool = False, ) -> bool: r"""Check gradients of gradients computed via small finite differences against analytical gradients w.r.t. tensors in :attr:`inputs` and @@ -536,8 +615,10 @@ def gradgradcheck( exactly (default, 0.0) or be within this tolerance. Note that a small amount of nondeterminism in the gradient will lead to larger inaccuracies in the second derivative. - check_undefined_grad (bool, options): if True, check if undefined output grads + check_undefined_grad (bool, optional): if True, check if undefined output grads are supported and treated as zeros + check_batched_grad (bool, optional): if True, check if we can compute + batched gradients using prototype vmap support. Defaults to False. Returns: True if all differences satisfy allclose condition @@ -568,6 +649,7 @@ def new_func(*args): grad_inputs = torch.autograd.grad(outputs, input_args, grad_outputs, create_graph=True) return grad_inputs - return gradcheck(new_func, tupled_inputs + tupled_grad_outputs, eps, atol, rtol, raise_exception, - nondet_tol=nondet_tol, check_undefined_grad=check_undefined_grad, - check_grad_dtypes=check_grad_dtypes) + return gradcheck( + new_func, tupled_inputs + tupled_grad_outputs, eps, atol, rtol, raise_exception, + nondet_tol=nondet_tol, check_undefined_grad=check_undefined_grad, + check_grad_dtypes=check_grad_dtypes, check_batched_grad=check_batched_grad) diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 94b1aae844f14..a3d0da1aef9d5 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -1,11 +1,16 @@ import itertools from typing import Any import torch +from torch.autograd import DeviceType from torch.futures import Future from collections import defaultdict, namedtuple from operator import attrgetter +from typing import Dict, List, Tuple, Optional + +import math + try: # Available in Python >= 3.2 from contextlib import ContextDecorator @@ -13,6 +18,13 @@ import functools class ContextDecorator(object): # type: ignore[no-redef] + + def __enter__(self): + raise NotImplementedError + + def __exit__(self, exc_type, exc_val, exc_tb): + raise NotImplementedError + def __call__(self, func): @functools.wraps(func) def wrapped(*args, **kwargs): @@ -27,15 +39,41 @@ class EventList(list): def __init__(self, *args, **kwargs): use_cuda = kwargs.pop('use_cuda', True) profile_memory = kwargs.pop('profile_memory', False) + with_flops = kwargs.pop('with_flops', False) super(EventList, self).__init__(*args, **kwargs) - self._cpu_children_populated = False self._use_cuda = use_cuda self._profile_memory = profile_memory + self._tree_built = False + self._with_flops = with_flops + + def _build_tree(self): + self._populate_cpu_children() + self._remove_dup_nodes() + self._set_backward_stacktraces() + self._tree_built = True def __str__(self): return self.table() - def populate_cpu_children(self): + def _remove_dup_nodes(self): + while True: + to_delete = [] + for idx in range(len(self)): + if (self[idx].cpu_parent is not None and + self[idx].cpu_parent.name == self[idx].name and + len(self[idx].cpu_parent.cpu_children) == 1): + self[idx].cpu_parent.cpu_children = self[idx].cpu_children + self[idx].cpu_parent.kernels = self[idx].kernels # lift kernels up + for ch in self[idx].cpu_children: + ch.cpu_parent = self[idx].cpu_parent + to_delete.append(idx) + if len(to_delete) == 0: + break + new_evts = [ev for ind, ev in enumerate(self) if ind not in to_delete] + self.clear() + self.extend(new_evts) + + def _populate_cpu_children(self): """Populates child events into each underlying FunctionEvent object. One event is a child of another if [s1, e1) is inside [s2, e2). Where s1 and e1 would be start and end of the child event's interval. And @@ -47,13 +85,11 @@ def populate_cpu_children(self): If for any reason two intervals intersect only partially, this function will not record a parent child relationship between then. """ - if self.cpu_children_populated: - return # Some events can be async (i.e. start and end on different threads), # since it's generally undefined how to attribute children ranges to # async ranges, we do not use them when calculating nested ranges and stats - sync_events = [evt for evt in self if not evt.is_async] + sync_events = [evt for evt in self if not evt.is_async and evt.device_type == DeviceType.CPU] events = sorted( sync_events, key=attrgetter("thread"), @@ -78,17 +114,17 @@ def populate_cpu_children(self): # Algorithm has O(N * log(N)) complexity where N is number of # intervals for thread_id, thread_events in threads: - thread_events = sorted( + thread_events_ = sorted( thread_events, - key=lambda event: [event.cpu_interval.start, -event.cpu_interval.end], + key=lambda event: [event.time_range.start, -event.time_range.end], ) - current_events = [] + current_events: List[FunctionEvent] = [] cur_end = 0 - for event in thread_events: + for event in thread_events_: while len(current_events) > 0: parent = current_events[-1] - if event.cpu_interval.start >= parent.cpu_interval.end or \ - event.cpu_interval.end > parent.cpu_interval.end: + if event.time_range.start >= parent.time_range.end or \ + event.time_range.end > parent.time_range.end: # this can't be a parent current_events.pop() else: @@ -103,20 +139,40 @@ def populate_cpu_children(self): current_events.append(event) - self._cpu_children_populated = True + def _set_backward_stacktraces(self): + def bw_parent(evt): + if evt is None: + return None + elif evt.scope == 1: # BACKWARD_FUNCTION + return evt + else: + return bw_parent(evt.cpu_parent) + + fwd_stacks = {} + for evt in self: + if bw_parent(evt) is None and evt.stack is not None: + t = (evt.sequence_nr, evt.thread) + if t not in fwd_stacks: + fwd_stacks[t] = evt.stack + + for evt in self: + p = bw_parent(evt) + if p is not None: + assert p.fwd_thread is not None + t = (p.sequence_nr, p.fwd_thread) + if t in fwd_stacks: + evt.stack = fwd_stacks[t] + else: + evt.stack = [] @property def self_cpu_time_total(self): return sum([event.self_cpu_time_total for event in self]) - @property - def cpu_children_populated(self): - return self._cpu_children_populated - - def table(self, sort_by=None, row_limit=100, header=None, top_level_events_only=False): + def table(self, sort_by=None, row_limit=100, max_src_column_width=75, header=None, top_level_events_only=False): """Prints an EventList as a nicely formatted table. - Arguments: + Args: sort_by (str, optional): Attribute used to sort entries. By default they are printed in the same order as they were registered. Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``, @@ -135,9 +191,11 @@ def table(self, sort_by=None, row_limit=100, header=None, top_level_events_only= self, sort_by=sort_by, row_limit=row_limit, + max_src_column_width=max_src_column_width, header=header, use_cuda=self._use_cuda, profile_memory=self._profile_memory, + with_flops=self._with_flops, top_level_events_only=top_level_events_only) def export_chrome_trace(self, path): @@ -145,7 +203,7 @@ def export_chrome_trace(self, path): The checkpoint can be later loaded and inspected under ``chrome://tracing`` URL. - Arguments: + Args: path (str): Path where the trace will be written. """ import os @@ -166,8 +224,8 @@ def export_chrome_trace(self, path): '"args": {}}, ' % ( evt.name, - evt.cpu_interval.start, - evt.cpu_interval.elapsed_us(), + evt.time_range.start, + evt.time_range.elapsed_us(), evt.thread if not evt.is_remote else f'" node_id:{evt.node_id}, thread_id:{evt.thread} "', @@ -183,7 +241,7 @@ def export_chrome_trace(self, path): '"pid": "CPU functions", ' '"id": %s, ' '"cat": "cpu_to_cuda", ' - '"args": {}}, ' % (evt.name, evt.cpu_interval.start, + '"args": {}}, ' % (evt.name, evt.time_range.start, evt.thread, next_id)) f.write('{"name": "%s", ' '"ph": "f", ' @@ -208,29 +266,63 @@ def export_chrome_trace(self, path): f.truncate() f.write("]") - def key_averages(self, group_by_input_shapes=False): + def supported_export_stacks_metrics(self): + return ["self_cpu_time_total", "self_cuda_time_total"] + + def export_stacks(self, path: str, metric: str): + if metric not in self.supported_export_stacks_metrics(): + raise ValueError("metric should be one of: " + str(self.supported_export_stacks_metrics())) + translate_table = str.maketrans(" ;\t\n", "____") + with open(path, 'w') as f: + for evt in self: + if evt.stack and len(evt.stack) > 0: + metric_value = getattr(evt, metric) + if int(metric_value) > 0: + stack_str = "" + for entry in reversed(evt.stack): + stack_str += entry.translate(translate_table) + stack_str += ";" + stack_str = stack_str[:-1] + " " + str(int(metric_value)) + f.write(stack_str + "\n") + + def key_averages(self, group_by_input_shapes=False, group_by_stack_n=0): """Averages all function events over their keys. - @param group_by_input_shapes The key would become - (event name, input dimensions) rather than just event name. - This is useful to see which dimensionality contributes to the runtime - the most and may help with dimension specific optimizations or - choosing best candidates for quantization (aka fitting a roof line) + Args: + group_by_input_shapes: group entries by + (event name, input shapes) rather than just event name. + This is useful to see which input shapes contribute to the runtime + the most and may help with size-specific optimizations or + choosing the best candidates for quantization (aka fitting a roof line) + + group_by_stack_n: group by top n stack trace entries Returns: An EventList containing FunctionEventAvg objects. """ - self.populate_cpu_children() - stats = defaultdict(FunctionEventAvg) + assert self._tree_built + stats: Dict[Tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg) - def get_key(event, group_by_input_shapes): - if not group_by_input_shapes: - return (event.key, event.node_id) - return (event.key, str(event.input_shapes), event.node_id) + def get_key(event, group_by_input_shapes, group_by_stack_n) -> Tuple[str, ...]: + key = [str(event.key), str(event.node_id), str(event.device_type), str(event.is_legacy)] + if group_by_input_shapes: + key.append(str(event.input_shapes)) + if group_by_stack_n > 0: + key += event.stack[:group_by_stack_n] + return tuple(key) for evt in self: - stats[get_key(evt, group_by_input_shapes)].add( - evt, group_by_input_shapes) - return EventList(stats.values(), use_cuda=self._use_cuda, profile_memory=self._profile_memory) + stats[get_key(evt, group_by_input_shapes, group_by_stack_n)].add(evt) + + avg_list = EventList( + stats.values(), + use_cuda=self._use_cuda, + profile_memory=self._profile_memory, + with_flops=self._with_flops) + for evt in avg_list: + evt.stack = evt.stack[:group_by_stack_n] + if not group_by_input_shapes: + evt.input_shapes = "" + return avg_list def total_average(self): """Averages all events. @@ -253,7 +345,7 @@ class profile(object): only report runtime of PyTorch functions. Note: profiler is thread local and is automatically propagated into the async tasks - Arguments: + Args: enabled (bool, optional): Setting this to False makes this context manager a no-op. Default: ``True``. @@ -272,10 +364,23 @@ class profile(object): self cpu time might be artificially increased because of the shape collection. + with_flops (bool, optional): If with_flops is set, the profiler will estimate + the FLOPS (floating pointer operations per second) value using the operator's input shape + and total CPU time. This allows one to estimate the hardware performance. Currently, + this option only works for the GEMM and CONV operator, default: ``False`` + profile_memory (bool, optional): Whether to report memory usage, default: ``False`` + with_stack (bool, optional): record source information (file and line number) for the ops + + use_kineto (bool, default False): experimental support for Kineto profiler + + use_cpu (default True) - whether to profile CPU events; setting to False requires + use_kineto=True and can be used to lower the overhead for GPU-only profiling + .. warning: - Enabling memory profiling incurs additional profiler overhead + Enabling memory profiling or source attribution incurs additional profiler + overhead .. warning: This context managers should not be called recursively, i.e. no nested @@ -311,37 +416,98 @@ def __init__( enabled=True, use_cuda=False, record_shapes=False, - profile_memory=False): - self.enabled = enabled - self.use_cuda = use_cuda - self.function_events = None + with_flops=False, + profile_memory=False, + with_stack=False, + use_kineto=False, + use_cpu=True): + self.enabled: bool = enabled if not self.enabled: return + self.use_cuda = use_cuda + self.function_events = None self.entered = False self.record_shapes = record_shapes + self.with_flops = with_flops + self.record_shapes |= self.with_flops self.profile_memory = profile_memory + self.with_stack = with_stack + self.use_cpu = use_cpu + self.kineto_results = None + if not self.use_cpu: + assert use_kineto, \ + "Device-only events supported only with Kineto (use_kineto=True)" + + self.profiler_kind = None + self.kineto_activities = set() + if use_kineto: + self.profiler_kind = torch.autograd.ProfilerState.KINETO + if self.use_cpu: + self.kineto_activities.add(torch.autograd.ProfilerActivity.CPU) + if self.use_cuda: + self.kineto_activities.add( + # uses CUPTI + torch.autograd.ProfilerActivity.CUDA) + assert len(self.kineto_activities) > 0, \ + "No activities specified for Kineto profiler" + elif self.use_cuda: + # legacy CUDA mode + self.profiler_kind = torch.autograd.ProfilerState.CUDA + else: + self.profiler_kind = torch.autograd.ProfilerState.CPU + + if self.profiler_kind == torch.autograd.ProfilerState.KINETO: + assert ( + torch.autograd.kineto_available() + ), """Requested Kineto profiling but Kineto is not available, + make sure PyTorch is built with USE_KINETO=1""" + + def config(self): + assert self.profiler_kind is not None + return torch.autograd.ProfilerConfig( + self.profiler_kind, + self.record_shapes, + self.profile_memory, + self.with_stack, + self.with_flops) def __enter__(self): if not self.enabled: return if self.entered: - raise RuntimeError("autograd profiler traces are not reentrant") + raise RuntimeError("profiler context manager is not reentrant") self.entered = True - profiler_kind = torch.autograd.ProfilerState.CUDA if self.use_cuda \ - else torch.autograd.ProfilerState.CPU - - config = torch.autograd.ProfilerConfig(profiler_kind, self.record_shapes, self.profile_memory) - torch.autograd._enable_profiler(config) + if self.kineto_activities: + torch.autograd._prepare_profiler(self.config(), self.kineto_activities) + torch.autograd._enable_profiler(self.config(), self.kineto_activities) + else: + torch.autograd._enable_profiler_legacy(self.config()) return self + def _prepare_kineto_trace(self): + assert self.kineto_activities + self.entered = True + torch.autograd._prepare_profiler(self.config(), self.kineto_activities) + + def _start_kineto_trace(self): + assert self.kineto_activities + torch.autograd._enable_profiler(self.config(), self.kineto_activities) + def __exit__(self, exc_type, exc_val, exc_tb): if not self.enabled: return - records = torch.autograd._disable_profiler() + if self.kineto_activities: + self.kineto_results = torch.autograd._disable_profiler() + parsed_results = parse_kineto_results(self.kineto_results) + else: + records = torch.autograd._disable_profiler_legacy() + parsed_results = parse_legacy_records(records) self.function_events = EventList( - parse_cpu_trace(records), + parsed_results, use_cuda=self.use_cuda, - profile_memory=self.profile_memory) + profile_memory=self.profile_memory, + with_flops=self.with_flops) + self.function_events._build_tree() return False def __repr__(self): @@ -352,34 +518,45 @@ def __repr__(self): def __str__(self): if self.function_events is None: return '' - self.function_events.populate_cpu_children() return str(self.function_events) def _check_finish(self): if self.function_events is None: raise RuntimeError("can't export a trace that didn't finish running") - self.function_events.populate_cpu_children() - def table(self, sort_by=None, row_limit=100, header=None, top_level_events_only=False): + def table(self, sort_by=None, row_limit=100, max_src_column_width=75, header=None, top_level_events_only=False): self._check_finish() + assert self.function_events is not None return self.function_events.table( - sort_by=sort_by, row_limit=row_limit, header=header, + sort_by=sort_by, row_limit=row_limit, max_src_column_width=max_src_column_width, header=header, top_level_events_only=top_level_events_only ) table.__doc__ = EventList.table.__doc__ def export_chrome_trace(self, path): self._check_finish() - return self.function_events.export_chrome_trace(path) + if self.kineto_results is not None: + self.kineto_results.save(path) + else: + assert self.function_events is not None + return self.function_events.export_chrome_trace(path) export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__ - def key_averages(self, group_by_input_shape=False): + def export_stacks(self, path: str, metric: str = "self_cpu_time_total"): + self._check_finish() + assert self.function_events is not None, "Expected profiling results" + assert self.with_stack, "export_stacks() requires with_stack=True" + return self.function_events.export_stacks(path, metric) + + def key_averages(self, group_by_input_shape=False, group_by_stack_n=0): self._check_finish() - return self.function_events.key_averages(group_by_input_shape) + assert self.function_events is not None, "Expected profiling results" + return self.function_events.key_averages(group_by_input_shape, group_by_stack_n) key_averages.__doc__ = EventList.key_averages.__doc__ def total_average(self): self._check_finish() + assert self.function_events is not None, "Expected profiling results" return self.function_events.total_average() total_average.__doc__ = EventList.total_average.__doc__ @@ -389,6 +566,7 @@ def self_cpu_time_total(self): all self times across all the events. """ self._check_finish() + assert self.function_events is not None return self.function_events.self_cpu_time_total @@ -397,7 +575,7 @@ class record_function(ContextDecorator): Python code (or function) when running autograd profiler. It is useful when tracing the code profile. - Arguments: + Args: name (str): Label assigned to the block of code. node_id (int): ID of node, for distributed profiling. Unset in non-distributed cases. @@ -451,7 +629,7 @@ def _call_end_callbacks_on_future(self, fut: Future[Any]) -> Future[Any]: once to attach the callback onto the future, and will throw if called multiple times. - Arguments: + Args: fut: (torch._C.Future): future for which to schedule callback for. @@ -489,7 +667,7 @@ class emit_nvtx(object): This context manager should not be called recursively, i.e. at most one instance should be enabled at any given time. - Arguments: + Args: enabled (bool, optional, default=True): Setting ``enabled=False`` makes this context manager a no-op. Default: ``True``. record_shapes (bool, optional, default=False): If ``record_shapes=True``, the nvtx range wrapping @@ -564,12 +742,13 @@ def __enter__(self): raise RuntimeError("NVTX annotation context manager is not reentrant") self.entered = True torch.cuda.synchronize() - torch.autograd._enable_profiler( + torch.autograd._enable_profiler_legacy( torch.autograd.ProfilerConfig( torch.autograd.ProfilerState.NVTX, self.record_shapes, - False - ) + False, + False, + False) ) return self @@ -577,14 +756,14 @@ def __exit__(self, exc_type, exc_val, exc_tb): if not self.enabled: return torch.cuda.synchronize() - torch.autograd._disable_profiler() + torch.autograd._disable_profiler_legacy() return False def load_nvprof(path): """Opens an nvprof trace file and parses autograd annotations. - Arguments: + Args: path (str): path to nvprof trace """ return EventList(parse_nvprof_trace(path)) @@ -639,14 +818,15 @@ class FormattedTimesMixin(object): cpu_time_total_str = attr_formatter('cpu_time_total') cuda_time_total_str = attr_formatter('cuda_time_total') self_cpu_time_total_str = attr_formatter('self_cpu_time_total') + self_cuda_time_total_str = attr_formatter('self_cuda_time_total') @property def cpu_time(self): - return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count + return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count # type: ignore @property def cuda_time(self): - return 0.0 if self.count == 0 else 1.0 * self.cuda_time_total / self.count + return 0.0 if self.count == 0 else 1.0 * self.cuda_time_total / self.count # type: ignore class Interval(object): @@ -664,26 +844,35 @@ def elapsed_us(self): class FunctionEvent(FormattedTimesMixin): """Profiling information about a single function.""" def __init__( - self, id, node_id, name, thread, cpu_start, cpu_end, input_shapes=None, - cpu_memory_usage=0, cuda_memory_usage=0, is_async=False, is_remote=True, - sequence_nr=-1): - self.id = id - self.node_id = node_id - self.name = name - self.cpu_interval = Interval(cpu_start, cpu_end) - self.thread = thread - self.kernels = [] - self.count = 1 - self.cpu_children = [] - self.cpu_parent = None - self.input_shapes = input_shapes - self.cpu_memory_usage = cpu_memory_usage - self.cuda_memory_usage = cuda_memory_usage - self.is_async = is_async - self.is_remote = is_remote - self.sequence_nr = sequence_nr + self, id, name, thread, start_us, end_us, fwd_thread=None, input_shapes=None, + stack=None, scope=0, cpu_memory_usage=0, cuda_memory_usage=0, is_async=False, + is_remote=False, sequence_nr=-1, node_id=-1, device_type=DeviceType.CPU, device_index=0, + is_legacy=False, flops=None): + self.id: int = id + self.node_id: int = node_id + self.name: str = name + self.time_range: Interval = Interval(start_us, end_us) + self.thread: int = thread + self.fwd_thread: Optional[int] = fwd_thread + self.kernels: List[Kernel] = [] + self.count: int = 1 + self.cpu_children: List[FunctionEvent] = [] + self.cpu_parent: Optional[FunctionEvent] = None + self.input_shapes: Tuple[int, ...] = input_shapes + self.stack: List = stack + self.scope: int = scope + self.cpu_memory_usage: int = cpu_memory_usage + self.cuda_memory_usage: int = cuda_memory_usage + self.is_async: bool = is_async + self.is_remote: bool = is_remote + self.sequence_nr: int = sequence_nr + self.device_type: DeviceType = device_type + self.device_index: int = device_index + self.is_legacy: bool = is_legacy + self.flops: Optional[float] = flops def append_kernel(self, name, device, start, end): + assert self.device_type == DeviceType.CPU self.kernels.append(Kernel(name, device, Interval(start, end))) def append_cpu_child(self, child): @@ -692,7 +881,9 @@ def append_cpu_child(self, child): One is supposed to append only direct children to the event to have correct self cpu time being reported. """ + assert(self.device_type == DeviceType.CPU) assert(isinstance(child, FunctionEvent)) + assert(child.device_type == DeviceType.CPU) self.cpu_children.append(child) def set_cpu_parent(self, parent): @@ -702,14 +893,16 @@ def set_cpu_parent(self, parent): the child's range interval is completely inside the parent's. We use this connection to determine the event is from top-level op or not. """ + assert(self.device_type == DeviceType.CPU) assert(isinstance(parent, FunctionEvent)) + assert(parent.device_type == DeviceType.CPU) self.cpu_parent = parent # Note: async events don't have children, are not used when computing 'self' # metrics of other events, have only total cpu time @property def self_cpu_memory_usage(self): - if self.is_async: + if self.is_async or self.device_type != DeviceType.CPU: return 0 return self.cpu_memory_usage - sum( [child.cpu_memory_usage for child in self.cpu_children] @@ -717,7 +910,7 @@ def self_cpu_memory_usage(self): @property def self_cuda_memory_usage(self): - if self.is_async: + if self.is_async or self.device_type != DeviceType.CPU: return 0 return self.cuda_memory_usage - sum( [child.cuda_memory_usage for child in self.cpu_children] @@ -725,7 +918,7 @@ def self_cuda_memory_usage(self): @property def self_cpu_time_total(self): - if self.is_async: + if self.is_async or self.device_type != DeviceType.CPU: return 0 return self.cpu_time_total - sum( [child.cpu_time_total for child in self.cpu_children] @@ -733,11 +926,37 @@ def self_cpu_time_total(self): @property def cuda_time_total(self): - return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels) + if self.is_async: + return 0 + if self.device_type == DeviceType.CPU: + if not self.is_legacy: + # account for the kernels in the children ops + return (sum(kinfo.interval.elapsed_us() for kinfo in self.kernels) + + sum(ch.cuda_time_total for ch in self.cpu_children)) + else: + # each legacy cpu events has a single (fake) kernel + return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels) + else: + assert self.device_type == DeviceType.CUDA + return self.time_range.elapsed_us() + + @property + def self_cuda_time_total(self): + if self.is_async: + return 0 + if self.device_type == DeviceType.CPU: + return self.cuda_time_total - \ + sum([child.cuda_time_total for child in self.cpu_children]) + else: + assert(self.device_type == DeviceType.CUDA) + return self.cuda_time_total @property def cpu_time_total(self): - return self.cpu_interval.elapsed_us() + if self.device_type == DeviceType.CPU: + return self.time_range.elapsed_us() + else: + return 0 @property def key(self): @@ -745,14 +964,16 @@ def key(self): def __repr__(self): return ( - ''.format( + 'cpu_memory_usage={} cuda_memory_usage={} is_async={} is_remote={} seq_nr={} is_legacy={}>'.format( self.id, + self.name, + self.device_type, self.node_id, self.cpu_time_str, - self.cpu_interval.start, - self.cpu_interval.end, + self.time_range.start, + self.time_range.end, str([child.id for child in self.cpu_children]), self.cuda_time_str, self.name, @@ -763,6 +984,7 @@ def __repr__(self): self.is_async, self.is_remote, self.sequence_nr, + self.is_legacy, ) ) @@ -770,23 +992,29 @@ def __repr__(self): class FunctionEventAvg(FormattedTimesMixin): """Used to average stats over multiple FunctionEvent objects.""" def __init__(self): - self.key = None - self.count = 0 - self.node_id = 0 - self.is_async = False - self.is_remote = False - self.cpu_time_total = 0 - self.cuda_time_total = 0 - self.self_cpu_time_total = 0 - self.input_shapes = None - self.cpu_memory_usage = 0 - self.cuda_memory_usage = 0 - self.self_cpu_memory_usage = 0 - self.self_cuda_memory_usage = 0 - self.cpu_children = None - self.cpu_parent = None - - def add(self, other, group_by_input_shapes=False): + self.key: Optional[str] = None + self.count: int = 0 + self.node_id: int = 0 + self.is_async: bool = False + self.is_remote: bool = False + self.cpu_time_total: int = 0 + self.cuda_time_total: int = 0 + self.self_cpu_time_total: int = 0 + self.self_cuda_time_total: int = 0 + self.input_shapes: Optional[List[List[int]]] = None + self.stack: Optional[List] = None + self.scope: Optional[int] = None + self.cpu_memory_usage: int = 0 + self.cuda_memory_usage: int = 0 + self.self_cpu_memory_usage: int = 0 + self.self_cuda_memory_usage: int = 0 + self.cpu_children: Optional[List[FunctionEvent]] = None + self.cpu_parent: Optional[FunctionEvent] = None + self.device_type: DeviceType = DeviceType.CPU + self.is_legacy: bool = False + self.flops: float = 0.0 + + def add(self, other): if self.key is None: # First function being recorded as part of FunctionEventAvg, propagate # fields. @@ -796,23 +1024,28 @@ def add(self, other, group_by_input_shapes=False): self.is_remote = other.is_remote self.cpu_parent = other.cpu_parent self.cpu_children = other.cpu_children - if group_by_input_shapes: - self.input_shapes = other.input_shapes - assert ( - not group_by_input_shapes or - other.input_shapes == self.input_shapes - ) + self.input_shapes = other.input_shapes + self.stack = other.stack + self.scope = other.scope + self.device_type = other.device_type + self.is_legacy = other.is_legacy + assert isinstance(other, (FunctionEvent, FunctionEventAvg)) assert other.key == self.key self.cpu_time_total += other.cpu_time_total self.cuda_time_total += other.cuda_time_total self.self_cpu_time_total += other.self_cpu_time_total + self.self_cuda_time_total += other.self_cuda_time_total self.cpu_memory_usage += other.cpu_memory_usage self.cuda_memory_usage += other.cuda_memory_usage self.self_cpu_memory_usage += other.self_cpu_memory_usage self.self_cuda_memory_usage += other.self_cuda_memory_usage self.count += other.count + if self.flops is None: + self.flops = other.flops + elif other.flops is not None: + self.flops += other.flops return self def __iadd__(self, other): @@ -821,11 +1054,12 @@ def __iadd__(self, other): def __repr__(self): return ( ' ' - 'cpu_memory_usage={} cuda_memory_usage={}'.format( + ' self_cuda_time={} cuda_time={} input_shapes={} ' + 'cpu_memory_usage={} cuda_memory_usage={}>'.format( self.key, self.self_cpu_time_total_str, self.cpu_time_str, + self.self_cuda_time_total_str, self.cuda_time_str, str(self.input_shapes), self.cpu_memory_usage, @@ -845,14 +1079,111 @@ def __missing__(self, key): self[key] = torch._C._demangle(key) if len(key) > 1 else key return self[key] +def filter_stack_entry(entry): + filtered_entries = [ + ("autograd/__init__", "_make_grads"), + ("autograd/__init__", "backward"), + ("torch/tensor", "backward"), + ("_internal/common_utils", "prof_callable"), + ("_internal/common_utils", "prof_func_call"), + ("_internal/common_utils", "prof_meth_call"), + ] + return all([not (f[0] in entry and f[1] in entry) for f in filtered_entries]) -################################################################################ -# CPU checkpoints +def filter_name(name): + # ignoring the following utility ops + filtered_out_names = [ + "profiler::_record_function_enter", + "profiler::_record_function_exit", + "aten::is_leaf", + "aten::output_nr", + "aten::_version", + ] + return name in filtered_out_names + +# Parsing of kineto profiler events +def parse_kineto_results(result): + # result.events() has most of the events - PyTorch op-level and device-level events + # result.legacy_events() has events not yet ported to kineto + # (e.g. start/stop marks, tensor memory allocator events) + + # First, find __start_profile mark to get the absolute time of the start of the trace; + # save memory allocation records + start_record = None + mem_records = [] + for record in itertools.chain(*result.legacy_events()): + if record.kind() == 'mark' and record.name() == '__start_profile': + assert start_record is None + start_record = record + if record.kind() == 'memory_alloc': + mem_records.append(record) + assert start_record is not None, "Invalid profiler output, __start_profile is missing" -def parse_cpu_trace(thread_records): + # Create and return FunctionEvent list + string_table = StringTable() + function_events = [] + cuda_corr_map: Dict[int, List[torch.autograd.KinetoEvent]] = {} + for kineto_event in result.events(): + if filter_name(kineto_event.name()): + continue + rel_start_us = kineto_event.start_us() - start_record.start_us() + rel_end_us = rel_start_us + kineto_event.duration_us() + abs_end_us = kineto_event.start_us() + kineto_event.duration_us() + + cpu_memory_usage = 0 + cuda_memory_usage = 0 + if kineto_event.device_type() == DeviceType.CPU: + # find the corresponding memory allocation events + for mem_record in mem_records: + if (mem_record.start_us() >= kineto_event.start_us() and + mem_record.start_us() <= abs_end_us): + cpu_memory_usage += mem_record.cpu_memory_usage() + cuda_memory_usage += mem_record.cuda_memory_usage() + is_async = kineto_event.start_thread_id() != kineto_event.end_thread_id() + fe = FunctionEvent( + id=kineto_event.correlation_id(), + name=string_table[kineto_event.name()], + thread=kineto_event.start_thread_id(), + start_us=rel_start_us, + end_us=rel_end_us, + fwd_thread=kineto_event.fwd_thread_id(), + input_shapes=kineto_event.shapes(), + stack=[entry for entry in kineto_event.stack() if filter_stack_entry(entry)], + scope=kineto_event.scope(), + cpu_memory_usage=cpu_memory_usage, + cuda_memory_usage=cuda_memory_usage, + is_async=is_async, + sequence_nr=kineto_event.sequence_nr(), + device_type=kineto_event.device_type(), + device_index=kineto_event.device_index(), + ) + function_events.append(fe) + if kineto_event.device_type() == DeviceType.CUDA: + corr_id = kineto_event.linked_correlation_id() + if corr_id > 0: + if corr_id not in cuda_corr_map: + cuda_corr_map[corr_id] = [] + cuda_corr_map[corr_id].append(kineto_event) + + # associate CUDA kernels with CPU events + for fe in function_events: + if (fe.device_type == DeviceType.CPU and not fe.is_async and + fe.id in cuda_corr_map): + for k_evt in cuda_corr_map[fe.id]: + fe.append_kernel( + k_evt.name(), + k_evt.device_index(), + k_evt.start_us(), + k_evt.start_us() + k_evt.duration_us()) + + function_events.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end]) + return function_events + +# Parsing of legacy profiler events +def parse_legacy_records(thread_records): def get_record_key(record): """ - Returns a tuple to be used by parse_cpu_trace for correlating start and + Returns a tuple to be used by parse_legacy_records for correlating start and end records. """ return (record.handle(), record.node_id()) @@ -864,15 +1195,6 @@ def get_record_key(record): record_stack = [] string_table = StringTable() - # ignoring the following utility ops - filtered_out_names = [ - "profiler::_record_function_enter", - "profiler::_record_function_exit", - "aten::is_leaf", - "aten::output_nr", - "aten::_version", - ] - # cuda start events and the overall profiler start event don't happen # at exactly the same time because we need to record an event on each device # and each record takes ~4us. So we adjust here by the difference @@ -880,6 +1202,7 @@ def get_record_key(record): # and the CPU time of the cuda start event for the device def adjusted_time(cuda_record, cuda_records_map): assert cuda_record.device() != -1 + assert start_record is not None cuda_time_0 = cuda_records_map[(cuda_record.node_id(), cuda_record.device())] return cuda_time_0.cuda_elapsed_us(cuda_record) + start_record.cpu_elapsed_us(cuda_time_0) @@ -908,7 +1231,7 @@ def adjusted_time(cuda_record, cuda_records_map): prev_record = None for record in thread_record_list: record_key = get_record_key(record) - if (record.name() in filtered_out_names or + if (filter_name(record.name()) or record_key in filtered_handles): filtered_handles.add(record_key) continue @@ -943,20 +1266,27 @@ def adjusted_time(cuda_record, cuda_records_map): cuda_memory_usage = cuda_memory_allocs[record_key] is_async = start.thread_id() != record.thread_id() is_remote_event = record.is_remote() + start_flops = start.flops() fe = FunctionEvent( id=record.handle(), node_id=record.node_id(), name=string_table[start.name()], thread=start.thread_id(), - cpu_start=start_record.cpu_elapsed_us(start), - cpu_end=start_record.cpu_elapsed_us(record), + start_us=start_record.cpu_elapsed_us(start), + end_us=start_record.cpu_elapsed_us(record), + fwd_thread=start.fwd_thread_id(), input_shapes=start.shapes(), + stack=[entry for entry in start.stack() if filter_stack_entry(entry)], + scope=start.scope(), cpu_memory_usage=cpu_memory_usage, cuda_memory_usage=cuda_memory_usage, is_async=is_async, is_remote=is_remote_event, sequence_nr=start.sequence_nr(), + device_type=DeviceType.CPU, + is_legacy=True, + flops=start_flops, ) # note: async events have only cpu total time if not is_async and start.has_cuda(): @@ -985,7 +1315,7 @@ def adjusted_time(cuda_record, cuda_records_map): # granularity of the given clock tick)--we always show # the outermost nested call first. This adds stability # in how FunctionEvents appear - functions.sort(key=lambda evt: [evt.cpu_interval.start, -evt.cpu_interval.end]) + functions.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end]) return functions @@ -1029,9 +1359,11 @@ def parse_nvprof_trace(path): for row in conn.execute(marker_query): unique.see(row['marker_id']) evt = FunctionEvent(id=row['marker_id'], + node_id=0, # missing a node_id when calling FunctionEvent. This is just to ensure + # that pytorch doesn't crash when creating a FunctionEvent() object name=strings[row['name']], - cpu_start=row['start_time'], - cpu_end=row['end_time'], + start_us=row['start_time'], + end_us=row['end_time'], thread=0) # TODO: find in sqlite database functions.append(evt) functions_map[evt.id] = evt @@ -1062,7 +1394,7 @@ def parse_nvprof_trace(path): row['kernel_start'], row['kernel_end']) - functions.sort(key=lambda evt: evt.cpu_interval.start) + functions.sort(key=lambda evt: evt.time_range.start) return functions @@ -1075,7 +1407,9 @@ def build_table( sort_by=None, header=None, row_limit=100, + max_src_column_width=75, use_cuda=True, + with_flops=False, profile_memory=False, top_level_events_only=False): """Prints a summary of events (which can be a list of FunctionEvent or FunctionEventAvg).""" @@ -1085,25 +1419,44 @@ def build_table( if sort_by is not None: events = EventList(sorted( events, key=lambda evt: getattr(evt, sort_by), reverse=True - ), use_cuda=use_cuda, profile_memory=profile_memory) + ), use_cuda=use_cuda, profile_memory=profile_memory, with_flops=with_flops) has_input_shapes = any( - [event.input_shapes is not None for event in events]) + [(event.input_shapes is not None and len(event.input_shapes) > 0) for event in events]) + + MAX_NAME_COLUMN_WIDTH = 55 name_column_width = max([len(evt.key) for evt in events]) + 4 - DEFAULT_COLUMN_WIDTH = 15 - SHAPES_COLUMN_WIDTH = 45 + name_column_width = min(name_column_width, MAX_NAME_COLUMN_WIDTH) + + DEFAULT_COLUMN_WIDTH = 12 + + shapes_column_width = max([len(str(evt.input_shapes)) for evt in events]) + 4 + shapes_column_width = min(shapes_column_width, 45) + + flops_column_width = DEFAULT_COLUMN_WIDTH + + src_column_width = None + stacks = [] + for evt in events: + if evt.stack is not None and len(evt.stack) > 0: + stacks.append(evt.stack) + has_stack = len(stacks) > 0 + if has_stack: + src_column_width = max([max([len(entry) for entry in stack]) for stack in stacks]) + 4 + src_column_width = min(src_column_width, max_src_column_width) headers = [ 'Name', - 'Self CPU total %', - 'Self CPU total', + 'Self CPU %', + 'Self CPU', 'CPU total %', 'CPU total', 'CPU time avg', ] if use_cuda: headers.extend([ - 'CUDA total %', + 'Self CUDA', + 'Self CUDA %', 'CUDA total', 'CUDA time avg', ]) @@ -1118,7 +1471,7 @@ def build_table( 'Self CUDA Mem', ]) headers.append( - 'Number of Calls' + '# of Calls' ) # Only append Node ID if any event has a valid (>= 0) Node ID append_node_id = any([evt.node_id != -1 for evt in events]) @@ -1127,14 +1480,29 @@ def build_table( # Have to use a list because nonlocal is Py3 only... SPACING_SIZE = 2 - row_format = [""] - header_sep = [""] - line_length = [-SPACING_SIZE] - - def add_column(padding): - row_format[0] += '{: <' + str(padding) + '} ' - header_sep[0] += '-' * padding + ' ' - line_length[0] += padding + SPACING_SIZE + row_format_lst = [""] + header_sep_lst = [""] + line_length_lst = [-SPACING_SIZE] + MAX_STACK_ENTRY = 5 + + def add_column(padding, text_dir='>'): + row_format_lst[0] += '{: ' + text_dir + str(padding) + '}' + (' ' * SPACING_SIZE) + header_sep_lst[0] += '-' * padding + (' ' * SPACING_SIZE) + line_length_lst[0] += padding + SPACING_SIZE + + def auto_scale_flops(flops): + flop_headers = [ + 'FLOPS', + 'KFLOPS', + 'MFLOPS', + 'GFLOPS', + 'TFLOPS', + 'PFLOPS', + ] + assert flops > 0 + log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1))) + assert log_flops >= 0 and log_flops < len(flop_headers) + return (pow(10, (math.floor(log_flops) * -3.0)), flop_headers[int(log_flops)]) add_column(name_column_width) for _ in headers[1:]: @@ -1142,12 +1510,34 @@ def add_column(padding): if has_input_shapes: headers.append('Input Shapes') - add_column(SHAPES_COLUMN_WIDTH) + add_column(shapes_column_width) + + if has_stack: + headers.append('Source Location') + add_column(src_column_width, text_dir='<') + + if with_flops: + # Auto-scaling of flops header + US_IN_SECOND = 1000.0 * 1000.0 # cpu_time_total is in us + raw_flops = [] + for evt in events: + if evt.flops > 0: + if evt.cuda_time_total != 0: + evt.flops = float(evt.flops) / evt.cuda_time_total * US_IN_SECOND + else: + evt.flops = float(evt.flops) / evt.cpu_time_total * US_IN_SECOND + raw_flops.append(evt.flops) + if len(raw_flops) != 0: + (flops_scale, flops_header) = auto_scale_flops(min(raw_flops)) + headers.append(flops_header) + add_column(flops_column_width) + else: + with_flops = False # can't find any valid flops - row_format = row_format[0] - header_sep = header_sep[0] - line_length = line_length[0] - add_column = None + row_format = row_format_lst[0] + header_sep = header_sep_lst[0] + line_length = line_length_lst[0] + add_column = None # type: ignore # Have to use a list because nonlocal is Py3 only... result = [] @@ -1157,7 +1547,16 @@ def append(s): result.append('\n') # Yes, newline after the end as well self_cpu_time_total = sum([event.self_cpu_time_total for event in events]) - cuda_time_total = sum([evt.cuda_time_total for evt in events]) + cuda_time_total = 0 + for evt in events: + if evt.device_type == DeviceType.CPU: + # in legacy profiler, kernel info is stored in cpu events + if evt.is_legacy: + cuda_time_total += evt.self_cuda_time_total + elif evt.device_type == DeviceType.CUDA: + # in kineto mode, there're events with the correct device type (e.g. CUDA) + cuda_time_total += evt.self_cuda_time_total + # Actual printing if header is not None: append('=' * line_length) @@ -1178,8 +1577,11 @@ def append(s): continue else: event_limit += 1 + name = evt.key + if len(name) >= MAX_NAME_COLUMN_WIDTH - 3: + name = name[:(MAX_NAME_COLUMN_WIDTH - 3)] + "..." row_values = [ - evt.key, # Name + name, # Self CPU total, 0 for async events. % format_time_share(evt.self_cpu_time_total, self_cpu_time_total), @@ -1191,8 +1593,9 @@ def append(s): ] if use_cuda: row_values.extend([ + evt.self_cuda_time_total_str, # CUDA time total % - format_time_share(evt.cuda_time_total, cuda_time_total), + format_time_share(evt.self_cuda_time_total, cuda_time_total), evt.cuda_time_total_str, evt.cuda_time_str, # Cuda time avg ]) @@ -1217,9 +1620,26 @@ def append(s): if append_node_id: row_values.append(evt.node_id) if has_input_shapes: - row_values.append(str(evt.input_shapes)[:SHAPES_COLUMN_WIDTH]) + row_values.append(str(evt.input_shapes)[:shapes_column_width]) + if with_flops: + if evt.flops <= 0: + row_values.append("--") + else: + row_values.append('{0:8.3f}'.format(evt.flops * flops_scale)) + if has_stack: + src_field = "" + if len(evt.stack) > 0: + src_field = evt.stack[0][:src_column_width] + row_values.append(src_field) append(row_format.format(*row_values)) + if has_stack: + empty_headers = [""] * (len(headers) - 1) + for entry in evt.stack[1:MAX_STACK_ENTRY]: + append(row_format.format(*(empty_headers + [entry[:src_column_width]]))) + empty_headers.append("") + append(row_format.format(*empty_headers)) + append(header_sep) append("Self CPU time total: {}".format(format_time(self_cpu_time_total))) if use_cuda: diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py index 1008d741a6cf8..307f82db34dbc 100644 --- a/torch/autograd/variable.py +++ b/torch/autograd/variable.py @@ -7,9 +7,10 @@ def __instancecheck__(cls, other): return isinstance(other, torch.Tensor) -class Variable(with_metaclass(VariableMeta, torch._C._LegacyVariableBase)): +# mypy doesn't understand torch._six.with_metaclass +class Variable(with_metaclass(VariableMeta, torch._C._LegacyVariableBase)): # type: ignore pass from torch._C import _ImperativeEngine as ImperativeEngine -Variable._execution_engine = ImperativeEngine() +Variable._execution_engine = ImperativeEngine() # type: ignore diff --git a/torch/utils/_benchmark/op_fuzzers/__init__.py b/torch/backends/_nnapi/__init__.py similarity index 100% rename from torch/utils/_benchmark/op_fuzzers/__init__.py rename to torch/backends/_nnapi/__init__.py diff --git a/torch/backends/_nnapi/prepare.py b/torch/backends/_nnapi/prepare.py new file mode 100644 index 0000000000000..bb1ea95c8e5be --- /dev/null +++ b/torch/backends/_nnapi/prepare.py @@ -0,0 +1,187 @@ +from typing import Optional, List + +import torch +from torch.backends._nnapi.serializer import serialize_model + +class NnapiModule(torch.nn.Module): + """Torch Module that wraps an NNAPI Compilation. + + This module handles preparing the weights, initializing the + NNAPI TorchBind object, and adjusting the memory formats + of all inputs and outputs. + """ + + comp: Optional[torch.classes._nnapi.Compilation] + + def __init__( + self, + ser_model: torch.Tensor, + weights: List[torch.Tensor], + inp_mem_fmts: List[int], + out_mem_fmts: List[int], + out_templates: List[torch.Tensor]): + super().__init__() + self.ser_model = ser_model + self.weights = weights + self.inp_mem_fmts = inp_mem_fmts + self.out_mem_fmts = out_mem_fmts + self.out_templates = out_templates + self.comp = None + + @torch.jit.export + def init(self): + assert self.comp is None + self.weights = [w.contiguous() for w in self.weights] + comp = torch.classes._nnapi.Compilation() + comp.init(self.ser_model, self.weights) + self.comp = comp + + def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]: + comp = self.comp + assert comp is not None + outs = [torch.empty_like(out) for out in self.out_templates] + + assert len(args) == len(self.inp_mem_fmts) + fixed_args = [] + for idx in range(len(args)): + fmt = self.inp_mem_fmts[idx] + # These constants match the values in DimOrder in serializer.py + # TODO: See if it's possible to use those directly. + if fmt == 0: + fixed_args.append(args[idx].contiguous()) + elif fmt == 1: + fixed_args.append(args[idx].permute(0, 2, 3, 1).contiguous()) + else: + raise Exception("Invalid mem_fmt") + comp.run(fixed_args, outs) + assert len(outs) == len(self.out_mem_fmts) + for idx in range(len(self.out_templates)): + fmt = self.out_mem_fmts[idx] + # These constants match the values in DimOrder in serializer.py + # TODO: See if it's possible to use those directly. + if fmt == 0: + pass + elif fmt == 1: + outs[idx] = outs[idx].permute(0, 3, 1, 2) + else: + raise Exception("Invalid mem_fmt") + return outs + + +class NnapiInitWrapper(torch.nn.Module): + """Wrapper module to ensure NNAPI init is called.""" + def __init__(self, nnapi_module): + super().__init__() + self.nnapi_module = nnapi_module + + def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]: + return self.nnapi_module(args) + + @torch.jit.export + def __getstate__(self): + return self.nnapi_module + + @torch.jit.export + def __setstate__(self, nnapi_module): + self.training = False + self.nnapi_module = nnapi_module + self.nnapi_module.init() + + +class ListWrapper(torch.nn.Module): + """NNAPI list-ifying wrapper. + + NNAPI always expects a list of inputs. This module provides a + single-tensor input interface for models that want it. + """ + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, t: torch.Tensor) -> List[torch.Tensor]: + return self.mod([t]) + +class DelistWrapper(torch.nn.Module): + """NNAPI de-list-ifying wrapper. + + NNAPI always provides a list of outputs. This module provides a + single-tensor output interface for models that want it. + """ + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, ts: List[torch.Tensor]) -> torch.Tensor: + outs = self.mod(ts) + assert len(outs) == 1 + return outs[0] + +class ListDelistWrapper(torch.nn.Module): + """NNAPI list-ifying and de-list-ifying wrapper. + + NNAPI always expects a list of inputs and provides a list of outputs. + This module provides a single-tensor input/output interface + for models that want it. + """ + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, t: torch.Tensor) -> torch.Tensor: + outs = self.mod([t]) + assert len(outs) == 1 + return outs[0] + + +def _condensed_zeros_like(t): + """Get a small-storage deterministic tensor with the same shape and dtype as t + + Similar to `torch.zeros(1, dtype=out.dtype).expand(out.shape)`, + but this works with quantized dtypes as well. + + Similar to `torch.empty(1, dtype=out.dtype).expand(out.shape)`, + but always returns the same data. + """ + + ret = torch.empty_like(t).flatten()[1].clone().expand(t.shape) + assert ret.storage().size() == 1 + ret.storage()[0] = 0 + return ret + + +def convert_model_to_nnapi(model, inputs): + model = torch.jit.freeze(model) + + if isinstance(inputs, torch.Tensor): + inputs = [inputs] + list_inputs = True + else: + list_inputs = False + + outputs = model(*inputs) + + if isinstance(outputs, torch.Tensor): + outputs = [outputs] + delist_outputs = True + else: + delist_outputs = False + + ser_model, used_weights, inp_mem_fmts, out_mem_fmts = serialize_model(model, inputs) + ser_model_tensor = torch.tensor(list(ser_model), dtype=torch.uint8) + + out_templates = [_condensed_zeros_like(out) for out in outputs] + nnapi_model = NnapiInitWrapper(NnapiModule( + ser_model_tensor, + used_weights, + inp_mem_fmts, + out_mem_fmts, + out_templates)) + + if list_inputs and delist_outputs: + nnapi_model = ListDelistWrapper(nnapi_model) + elif list_inputs: + nnapi_model = ListWrapper(nnapi_model) + elif delist_outputs: + nnapi_model = DelistWrapper(nnapi_model) + + return torch.jit.script(nnapi_model) diff --git a/torch/backends/_nnapi/serializer.py b/torch/backends/_nnapi/serializer.py new file mode 100644 index 0000000000000..bbe141989009d --- /dev/null +++ b/torch/backends/_nnapi/serializer.py @@ -0,0 +1,1363 @@ +import enum +import struct +import array +import logging +from typing import ( + Tuple, + NamedTuple, +) + +import torch + + +# TODO: Add type annotations +# TODO: Check tensor types for ops + + +LOG = logging.getLogger("nnapi_serialize") + + +class NNAPI_OperandCode(object): + FLOAT32 = 0 + INT32 = 1 + UINT32 = 2 + TENSOR_FLOAT32 = 3 + TENSOR_INT32 = 4 + TENSOR_QUANT8_ASYMM = 5 + BOOL = 6 + TENSOR_QUANT16_SYMM = 7 + TENSOR_FLOAT16 = 8 + TENSOR_BOOL8 = 9 + FLOAT16 = 10 + TENSOR_QUANT8_SYMM_PER_CHANNEL = 11 + TENSOR_QUANT16_ASYMM = 12 + + +class NNAPI_OperationCode(object): + ADD = 0 + AVERAGE_POOL_2D = 1 + CONCATENATION = 2 + CONV_2D = 3 + DEPTHWISE_CONV_2D = 4 + DEPTH_TO_SPACE = 5 + DEQUANTIZE = 6 + EMBEDDING_LOOKUP = 7 + FLOOR = 8 + FULLY_CONNECTED = 9 + HASHTABLE_LOOKUP = 10 + L2_NORMALIZATION = 11 + L2_POOL_2D = 12 + LOCAL_RESPONSE_NORMALIZATION = 13 + LOGISTIC = 14 + LSH_PROJECTION = 15 + LSTM = 16 + MAX_POOL_2D = 17 + MUL = 18 + RELU = 19 + RELU1 = 20 + RELU6 = 21 + RESHAPE = 22 + RESIZE_BILINEAR = 23 + RNN = 24 + SOFTMAX = 25 + SPACE_TO_DEPTH = 26 + SVDF = 27 + TANH = 28 + BATCH_TO_SPACE_ND = 29 + DIV = 30 + MEAN = 31 + PAD = 32 + SPACE_TO_BATCH_ND = 33 + SQUEEZE = 34 + STRIDED_SLICE = 35 + SUB = 36 + TRANSPOSE = 37 + ABS = 38 + ARGMAX = 39 + ARGMIN = 40 + AXIS_ALIGNED_BBOX_TRANSFORM = 41 + BIDIRECTIONAL_SEQUENCE_LSTM = 42 + BIDIRECTIONAL_SEQUENCE_RNN = 43 + BOX_WITH_NMS_LIMIT = 44 + CAST = 45 + CHANNEL_SHUFFLE = 46 + DETECTION_POSTPROCESSING = 47 + EQUAL = 48 + EXP = 49 + EXPAND_DIMS = 50 + GATHER = 51 + GENERATE_PROPOSALS = 52 + GREATER = 53 + GREATER_EQUAL = 54 + GROUPED_CONV_2D = 55 + HEATMAP_MAX_KEYPOINT = 56 + INSTANCE_NORMALIZATION = 57 + LESS = 58 + LESS_EQUAL = 59 + LOG = 60 + LOGICAL_AND = 61 + LOGICAL_NOT = 62 + LOGICAL_OR = 63 + LOG_SOFTMAX = 64 + MAXIMUM = 65 + MINIMUM = 66 + NEG = 67 + NOT_EQUAL = 68 + PAD_V2 = 69 + POW = 70 + PRELU = 71 + QUANTIZE = 72 + QUANTIZED_16BIT_LSTM = 73 + RANDOM_MULTINOMIAL = 74 + REDUCE_ALL = 75 + REDUCE_ANY = 76 + REDUCE_MAX = 77 + REDUCE_MIN = 78 + REDUCE_PROD = 79 + REDUCE_SUM = 80 + ROI_ALIGN = 81 + ROI_POOLING = 82 + RSQRT = 83 + SELECT = 84 + SIN = 85 + SLICE = 86 + SPLIT = 87 + SQRT = 88 + TILE = 89 + TOPK_V2 = 90 + TRANSPOSE_CONV_2D = 91 + UNIDIRECTIONAL_SEQUENCE_LSTM = 92 + UNIDIRECTIONAL_SEQUENCE_RNN = 93 + RESIZE_NEAREST_NEIGHBOR = 94 + + +class NNAPI_FuseCode(object): + FUSED_NONE = 0 + FUSED_RELU = 1 + FUSED_RELU1 = 2 + FUSED_RELU6 = 3 + + +class OperandValueSourceType(object): + IMMEDIATE = 0 + NUMBERED_BUFFER = 2 + NUMBERED_MEMORY = 3 + + +# Scalar types that appear explicitly in models. +# These must be kept in sync with +# AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS. +# TODO: Expose these directly to Python to avoid maintaining this list. +class TorchScalarTypes(enum.Enum): + QUINT8 = 13 + + +def approx_equal(lhs, rhs, tolerance=1e-6): + return abs(lhs - rhs) <= tolerance * min(lhs, rhs) + + +def tensor_size(op_type, dims): + ITEM_SIZES = { + NNAPI_OperandCode.TENSOR_FLOAT32: 4, + NNAPI_OperandCode.TENSOR_INT32: 4, + NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: 1, + NNAPI_OperandCode.TENSOR_QUANT16_SYMM: 2, + } + size = ITEM_SIZES[op_type] + for d in dims: + size *= d + return size + + +class ConvPoolArgs2d(NamedTuple): + """Configuration arguments for a convolution.""" + kernel_h: int + kernel_w: int + stride_h: int + stride_w: int + pad_t: int + pad_b: int + pad_l: int + pad_r: int + dilation_h: int + dilation_w: int + group: int + + +class DimOrder(enum.Enum): + PRESUMED_CONTIGUOUS = 0 + CHANNELS_LAST = 1 + SCALAR_OR_VECTOR = 2 + UNKNOWN_CONSTANT = 999 + + +class Operand(NamedTuple): + """Represenation of an NNAPI operand.""" + + # NNAPI operand type. One of NNAPI_OperandCode. + # TODO: Make this an enum. + op_type: int + + # This is always the PyTorch shape, which is NCHW for feature maps. + # The actual NNAPI operand might have a transposed shape. + shape: Tuple[int, ...] + + # Specifies how the shape of the operand that we define in NNAPI + # relates to the shape we track above. + # - PRESUMED_CONTIGUOUS: physical NNAPI operand will exactly match + # the shape of the PyTorch tensor. + # - CHANNELS_LAST: The PyTorch tensor is expected to be NCHW, and + # the NNAPI operand will be represented explicitly as NHWC. + dim_order: DimOrder + + # Quantization params + scale: float + zero_point: int + + def use_nchw(self): + if self.dim_order is DimOrder.PRESUMED_CONTIGUOUS: + return True + if self.dim_order is DimOrder.CHANNELS_LAST: + return False + raise Exception("Unknown dim order") + + +def broadcast_shapes(shape1, shape2): + assert len(shape1) > 0 + assert len(shape2) > 0 + s1 = list(shape1) + s2 = list(shape2) + # TODO: Support non-equal-rank broadcast where semantics match. + # This can be tricky for NHWC tensors because dimension orders + # don't match between PT and NNAPI, even though semantics match. + if len(s1) > len(s2): + # s2 = [1] * (len(s1) - len(s2)) + s2 + raise Exception("Non-equal-rank broadcast is not supported yet.") + if len(s2) > len(s1): + # s3 = [1] * (len(s2) - len(s1)) + s1 + raise Exception("Non-equal-rank broadcast is not supported yet.") + ret = [] + for d1, d2 in zip(s1, s2): + if d1 == 1: + ret.append(d2) + elif d2 == 1: + ret.append(d1) + elif d1 == d2: + ret.append(d1) + else: + raise Exception("Cannot broadcast shapes: {} and {}".format(shape1, shape2)) + return tuple(ret) + + +def get_conv_pool_shape(image_shape, args, out_ch, transpose): + batch, in_c, in_h, in_w = image_shape + + # TODO: Handle dilation + if args.dilation_h != 1 or args.dilation_w != 1: + raise Exception("Dilation not supported yet.") + + if transpose: + out_h = (in_h - 1) * args.stride_h + args.kernel_h - args.pad_t - args.pad_b + out_w = (in_w - 1) * args.stride_w + args.kernel_w - args.pad_l - args.pad_l + else: + out_h = (in_h - args.kernel_h + args.pad_t + args.pad_b) // args.stride_h + 1 + out_w = (in_w - args.kernel_w + args.pad_l + args.pad_r) // args.stride_w + 1 + + # Handle variable-sized tensors. + if in_h == 0: + out_h = 0 + if in_w == 0: + out_w = 0 + + out_shape = (batch, out_ch, out_h, out_w) + return out_shape + + +def fix_shape(shape, dim_order): + # Return the actual shape that an operand should have in NNAPI, + # given a PyTorch shape and dimension order. This is where we + # convert from PyTorch's "always NCHW" shape to explicit NHWC. + if dim_order is DimOrder.PRESUMED_CONTIGUOUS: + return shape + if dim_order is DimOrder.CHANNELS_LAST: + return tuple([shape[0]] + list(shape[2:]) + [shape[1]]) + if dim_order is DimOrder.SCALAR_OR_VECTOR: + assert len(shape) == 0 or len(shape) == 1 + return shape + if dim_order is DimOrder.UNKNOWN_CONSTANT: + # XXX think this through + return shape + raise Exception(f"Bad dim_order: {dim_order!r}.") + + +class _NnapiSerializer(object): + def __init__(self, config): + self.operands = [] + self.values = [] + self.operations = [] + self.value_data = [] + self.operation_args = [] + self.inputs = [] + self.outputs = [] + + self.modules = {} + self.constants = {} + self.tensor_tuples = {} + self.jitval_operand_map = {} + self.cached_immediates = {} + self.used_weights = [] + self.weight_offset = 0 + + if config is None: + config = {} + + # XXX get rid of this + self.solid_weights = config.get("solid_weights", False) + + def add_tensor_operand(self, jitval, oper): + assert isinstance(oper, Operand) + if jitval in self.jitval_operand_map: + raise Exception("Duplicate tensor: %r" % jitval) + + operand_id = len(self.operands) + self.operands.append(oper) + self.jitval_operand_map[jitval] = operand_id + return operand_id + + @staticmethod + def torch_tensor_to_operand(tensor, dim_order): + dtype = str(tensor.dtype).replace("torch.", "") + scale = 0.0 + zero_point = 0 + if dtype == "float32": + op_type = NNAPI_OperandCode.TENSOR_FLOAT32 + elif dtype == "quint8": + op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM + scale = tensor.q_scale() + zero_point = tensor.q_zero_point() + elif dtype == "qint32": + op_type = NNAPI_OperandCode.TENSOR_INT32 + scale = tensor.q_scale() + zero_point = tensor.q_zero_point() + assert zero_point == 0 + else: + raise Exception(f"Can't handle input with dtype '{tensor.dtype}'") + return Operand( + shape=tuple(tensor.shape), + op_type=op_type, + dim_order=dim_order, + scale=scale, + zero_point=zero_point, + ) + + def add_tensor_operand_for_input(self, jitval, tensor): + dim_order = ( + DimOrder.CHANNELS_LAST if getattr(tensor, "nnapi_nhwc", False) + else DimOrder.PRESUMED_CONTIGUOUS) + toper = self.torch_tensor_to_operand(tensor, dim_order) + operand_id = self.add_tensor_operand(jitval, toper) + self.inputs.append(operand_id) + return operand_id + + def add_tensor_operand_for_weight(self, tensor): + toper = self.torch_tensor_to_operand(tensor, DimOrder.UNKNOWN_CONSTANT) + operand_id = len(self.operands) + self.operands.append(toper) + tsize = tensor_size(toper.op_type, toper.shape) + psize = ((tsize - 1) | 0x3) + 1 + self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER)) + if self.solid_weights: + buf_num = 0 + offset = self.weight_offset + self.weight_offset += psize + else: + buf_num = len(self.used_weights) + offset = 0 + self.value_data.append(struct.pack( + "iii", + buf_num, + offset, + tsize)) + self.used_weights.append(tensor) + return operand_id + + def add_immediate_operand(self, code, value, dims): + assert isinstance(dims, tuple) + cache_key = (code, value) + if cache_key not in self.cached_immediates: + operand_id = len(self.operands) + self.operands.append(Operand(code, dims, DimOrder.SCALAR_OR_VECTOR, 0.0, 0)) + self.values.append((operand_id, OperandValueSourceType.IMMEDIATE)) + self.value_data.append(value) + self.cached_immediates[cache_key] = operand_id + return self.cached_immediates[cache_key] + + def add_immediate_int_scalar(self, value): + return self.add_immediate_operand( + NNAPI_OperandCode.INT32, + struct.pack("i", value), + ()) + + def add_immediate_float_scalar(self, value): + return self.add_immediate_operand( + NNAPI_OperandCode.FLOAT32, + struct.pack("f", value), + ()) + + def add_immediate_bool_scalar(self, value): + return self.add_immediate_operand( + NNAPI_OperandCode.BOOL, + b"\x01" if value else b"\x00", + ()) + + def add_immediate_int_vector(self, value): + return self.add_immediate_operand( + NNAPI_OperandCode.TENSOR_INT32, + array.array("i", value).tobytes(), + (len(value),)) + + def get_tensor_operand_by_jitval(self, jitval): + operand_id = self.jitval_operand_map[jitval] + return (operand_id, self.operands[operand_id]) + + def get_tensor_operand_or_constant(self, jitval): + operand_id = self.jitval_operand_map.get(jitval) + if operand_id is None: + _, value = self.get_constant_value(jitval, "TensorType") + operand_id = self.add_tensor_operand_for_weight(value) + return (operand_id, self.operands[operand_id]) + + def get_tensor_operand_for_weight(self, jitval): + _, value = self.get_constant_value(jitval, "TensorType") + operand_id = self.add_tensor_operand_for_weight(value) + return (operand_id, self.operands[operand_id]) + + def add_operation(self, opcode, inputs, outputs): + self.operations.append((opcode, len(inputs), len(outputs))) + self.operation_args.extend(inputs + outputs) + + def add_tensor_tuple(self, jitval, values): + assert jitval not in self.tensor_tuples + self.tensor_tuples[jitval] = values + + def add_constant_value(self, jitval, ctype, value): + assert jitval not in self.constants + self.constants[jitval] = (ctype, value) + + def get_constant_value(self, jitval, typekind=None): + record = self.constants.get(jitval) + if record is None: + raise Exception(f"Could not find constant value for '{jitval!r}'.") + ctype, _ = record + if typekind is not None and ctype.kind() != typekind: + raise Exception( + f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'") + return record + + def get_size_arg(self, jitval): + ctype, value = self.get_constant_value(jitval) + if ctype.kind() == "ListType": + assert ctype.getElementType().kind() == "IntType" + return value + raise Exception(f"Can't handle size arg of type '{ctype!r}' for '{jitval!r}'") + + def get_conv_pool_args_2d_from_pack(self, kernel_size, packed_config): + pc = [i.item() for i in packed_config] + assert pc[0] == 2 + strides = [pc[1], pc[2]] + paddings = [pc[3], pc[4]] + dilations = [pc[5], pc[6]] + output_padding = [pc[7], pc[8]] + group_num = pc[9] + transpose = pc[10] + + assert len(pc) == 11 + assert output_padding == [0, 0] + assert transpose == 0 + + return self.get_conv_pool_args_2d_common(kernel_size, strides, paddings, dilations, group_num) + + def get_conv_pool_args_2d_from_jit(self, kernel_size, stride, padding, dilation, group=None): + strides = self.get_size_arg(stride) + paddings = self.get_size_arg(padding) + dilations = self.get_size_arg(dilation) + if group is not None: + _, group_num = self.get_constant_value(group, "IntType") + else: + group_num = None + return self.get_conv_pool_args_2d_common(kernel_size, strides, paddings, dilations, group_num) + + def get_conv_pool_args_2d_common(self, kernel_size, strides, paddings, dilations, group_num): + kernels = list(kernel_size) + + assert len(kernels) == 2 + assert len(strides) == 2 + assert len(paddings) == 2 + assert len(dilations) == 2 + + # NNAPI uses 4 values for padding. + ph, pw = paddings + real_paddings = [ph, ph, pw, pw] + + return ConvPoolArgs2d(*(kernels + strides + real_paddings + dilations + [group_num])) + + def serialize_model(self, model, inputs): + self.add_immediate_bool_scalar(False) + self.add_immediate_bool_scalar(True) + + inp_dim_orders = [] + out_dim_orders = [] + + self_jitval = next(model.graph.inputs()) + self.add_constant_value(self_jitval, self_jitval.type(), model) + + for input_value, input_tensor in zip(list(model.graph.inputs())[1:], inputs): + op_id = self.add_tensor_operand_for_input(input_value, input_tensor) + inp_dim_orders.append(self.operands[op_id].dim_order.value) + + for idx, node in enumerate(model.graph.nodes()): + LOG.debug("Processing node #%d: %r", idx, node) + self.add_node(node) + + retn = model.graph.return_node() + assert retn.inputsSize() == 1 + assert retn.outputsSize() == 0 + retn_input = retn.inputsAt(0) + if retn_input.type().kind() == "TensorType": + op_id = self.jitval_operand_map[retn_input] + # TODO: Make outputs a local variable? + self.outputs.append(op_id) + out_dim_orders.append(self.operands[op_id].dim_order.value) + elif retn_input.type().kind() == "TupleType": + for v in self.tensor_tuples[retn_input]: + op_id = self.jitval_operand_map[v] + self.outputs.append(op_id) + out_dim_orders.append(self.operands[op_id].dim_order.value) + + model = [] + + version = 1 + header = struct.pack( + "iiiiii", + version, + len(self.operands), + len(self.values), + len(self.operations), + len(self.inputs), + len(self.outputs), + ) + model.append(header) + + serialized_values, serialized_value_data = self.serialize_values() + + model.extend(struct.pack("iifi", t, len(d), s, z) for (t, d, _m, s, z) in self.operands) + model.extend(serialized_values) + model.extend(struct.pack("iii", *x) for x in self.operations) + model.extend(self.serialize_ints(fix_shape(dims, mf)) for (_, dims, mf, _, _) in self.operands) + model.extend(serialized_value_data) + model.append(self.serialize_ints(self.operation_args)) + model.append(self.serialize_ints(self.inputs)) + model.append(self.serialize_ints(self.outputs)) + + # return (b"".join(model), self.used_weight_tensor_names) + return (b"".join(model), self.used_weights, inp_dim_orders, out_dim_orders) + + def serialize_values(self): + serialized_values = [] + serialized_value_data = [] + assert len(self.values) == len(self.value_data) + for ((op_index, source_type), data) in zip(self.values, self.value_data): + source_length = len(data) + + # Pad with 0 bytes out to a multiple of 4 for alignment. + physical_length = ((source_length - 1) | 0x3) + 1 + padded_data = data + (b"\0" * (physical_length - source_length)) + + serialized_values.append(struct.pack("iii", op_index, source_type, source_length)) + serialized_value_data.append(padded_data) + + return serialized_values, serialized_value_data + + @staticmethod + def serialize_ints(ints): + return struct.pack("i" * len(ints), *ints) + + ADDER_MAP = { + "prim::GetAttr": lambda self, node: + self.add_getattr(node), + "prim::Constant": lambda self, node: + self.add_constant_node(node), + "prim::ListConstruct": lambda self, node: + self.add_list_construct(node), + "prim::TupleConstruct": lambda self, node: + self.add_tuple_construct(node), + "aten::reshape": lambda self, node: + self.add_reshape(node), + "aten::quantize_per_tensor": lambda self, node: + self.add_quantize(node), + "aten::dequantize": lambda self, node: + self.add_dequantize(node), + "aten::add": lambda self, node: + self.add_add_sub_op(node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE), + "aten::sub": lambda self, node: + self.add_add_sub_op(node, NNAPI_OperationCode.SUB, NNAPI_FuseCode.FUSED_NONE), + "aten::mul": lambda self, node: + self.add_pointwise_simple_binary_broadcast_op(node, NNAPI_OperationCode.MUL), + "aten::relu": lambda self, node: + self.add_pointwise_simple_unary_op(node, NNAPI_OperationCode.RELU), + "aten::sigmoid": lambda self, node: + self.add_pointwise_simple_unary_op(node, NNAPI_OperationCode.LOGISTIC), + "aten::hardtanh": lambda self, node: + self.add_hardtanh(node), + "aten::max_pool2d": lambda self, node: + self.add_pool2d_node(node, NNAPI_OperationCode.MAX_POOL_2D), + "aten::adaptive_avg_pool2d": lambda self, node: + self.add_adaptive_avg_pool2d(node), + "aten::upsample_nearest2d": lambda self, node: + self.add_upsample_nearest2d(node), + "aten::prelu": lambda self, node: + self.add_prelu_op(node), + "aten::addmm": lambda self, node: + self.add_addmm(node), + "aten::_convolution": lambda self, node: + self.add_conv_underscore(node), + "aten::conv2d": lambda self, node: + self.add_conv2d(node), + "quantized::linear": lambda self, node: + self.add_qlinear(node), + "quantized::conv2d": lambda self, node: + self.add_qconv2d(node, NNAPI_FuseCode.FUSED_NONE), + "quantized::conv2d_relu": lambda self, node: + self.add_qconv2d(node, NNAPI_FuseCode.FUSED_RELU), + "quantized::add": lambda self, node: + self.add_qadd(node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE), + "quantized::add_relu": lambda self, node: + self.add_qadd(node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_RELU), + } + + def add_node(self, node): + adder = self.ADDER_MAP.get(node.kind()) + if not adder: + raise Exception("Unsupported node kind (%r) in node %r" % (node.kind(), node)) + adder(self, node) + + def add_getattr(self, node): + assert node.inputsSize() == 1 + assert node.outputsSize() == 1 + obj_ctype, obj = self.get_constant_value(node.inputsAt(0)) + assert str(obj_ctype).startswith("__torch__.") + name = node.s("name") + value = getattr(obj, name) + output = node.outputsAt(0) + ctype = output.type() + self.add_constant_value(output, ctype, value) + + def add_constant_node(self, node): + assert node.inputsSize() == 0 + assert node.outputsSize() == 1 + output = node.outputsAt(0) + ctype = output.type() + value = output.toIValue() + self.add_constant_value(output, ctype, value) + + def add_list_construct(self, node): + assert node.outputsSize() == 1 + output = node.outputsAt(0) + ctype = output.type() + values = [] + for inp in node.inputs(): + _, val = self.get_constant_value(inp) + values.append(val) + self.add_constant_value(output, ctype, values) + + def add_tuple_construct(self, node): + assert node.outputsSize() == 1 + output = node.outputsAt(0) + values = [] + for inp in node.inputs(): + values.append(inp) + self.add_tensor_tuple(output, values) + + def add_reshape(self, node): + assert node.inputsSize() == 2 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + + shape_ctype, shape = self.get_constant_value(node.inputsAt(1)) + assert shape_ctype.kind() == "ListType" + assert shape_ctype.getElementType().kind() == "IntType" + is_trivial_reshape = len(shape) == 2 and shape[1] == -1 + + if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_reshape: + raise Exception( + "Currently, reshape is only supported on NHWC tensors if the target size is [X, -1].") + + # Bit of a hack here. Use a real tensor to infer the output shape. + out_shape = torch.zeros(1).expand(in_oper.shape).reshape(shape).shape + out_oper = in_oper._replace(shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS) + + inputs = [None] * 2 + inputs[0] = in_id + inputs[1] = self.add_immediate_int_vector(shape) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs) + + def add_quantize(self, node): + assert node.inputsSize() == 4 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + if in_oper.dim_order != DimOrder.CHANNELS_LAST: + raise Exception( + "Most hardware backends prefer NHWC quantized tensors. " + "Try setting `t.nnapi_nhwc = True` on your tensor inputs. ") + _, scale = self.get_constant_value(node.inputsAt(1), "FloatType") + _, zero_point = self.get_constant_value(node.inputsAt(2), "IntType") + _, scalar_type = self.get_constant_value(node.inputsAt(3), "IntType") + if scalar_type != TorchScalarTypes.QUINT8.value: + raise Exception( + "PyTorch NNAPI export only supports quantized tensors " + "with the quint8 dtype.") + op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM + + out_oper = in_oper._replace( + op_type=op_type, + scale=scale, + zero_point=zero_point, + ) + + inputs = [None] * 1 + inputs[0] = in_id + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.QUANTIZE, inputs, outputs) + + def add_dequantize(self, node): + assert node.inputsSize() == 1 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + out_oper = in_oper._replace( + op_type=NNAPI_OperandCode.TENSOR_FLOAT32, + scale=0.0, + zero_point=0, + ) + + inputs = [None] * 1 + inputs[0] = in_id + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.DEQUANTIZE, inputs, outputs) + + def add_pointwise_simple_unary_op(self, node, opcode): + assert node.inputsSize() == 1 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + + inputs = [None] * 1 + inputs[0] = in_id + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), in_oper) + + self.add_operation(opcode, inputs, outputs) + + def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None): + """Helper for pointwise binary broadcast ops with superfluous extra args""" + assert node.outputsSize() == 1 + + assert node.inputsAt(0).type().kind() == "TensorType" + assert node.inputsAt(1).type().kind() == "TensorType" + + # TODO: Should support constant as either operand. + in0_id, in0_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1)) + + assert in0_oper.op_type == in1_oper.op_type + assert in0_oper.dim_order == in1_oper.dim_order + # NOTE: PyTorch and NNAPI have the same broadcast semantics. + out_shape = broadcast_shapes(in0_oper.shape, in1_oper.shape) + out_oper = in0_oper._replace(shape=out_shape) + if qparams is not None: + scale, zp = qparams + out_oper = out_oper._replace(scale=scale, zero_point=zp) + + inputs = [None] * 3 + inputs[0] = in0_id + inputs[1] = in1_id + inputs[2] = self.add_immediate_int_scalar(fuse_code) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(opcode, inputs, outputs) + + def add_pointwise_simple_binary_broadcast_op(self, node, opcode): + assert node.inputsSize() == 2 + self._do_add_binary(node, opcode) + + def add_add_sub_op(self, node, opcode, fuse_code): + assert node.inputsSize() == 3 + + _, alpha = self.get_constant_value(node.inputsAt(2), "IntType") + if alpha != 1: + raise Exception("NNAPI does not support add/sub with alpha.") + + self._do_add_binary(node, opcode, fuse_code) + + def add_qadd(self, node, opcode, fuse_code): + assert node.inputsSize() == 4 + + _, scale = self.get_constant_value(node.inputsAt(2), "FloatType") + _, zero_point = self.get_constant_value(node.inputsAt(3), "IntType") + + self._do_add_binary(node, opcode, fuse_code, qparams=(scale, zero_point)) + + def add_hardtanh(self, node): + assert node.inputsSize() == 3 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + _, min_val = self.get_constant_value(node.inputsAt(1), "FloatType") + _, max_val = self.get_constant_value(node.inputsAt(2), "FloatType") + + op_map = { + 1: NNAPI_OperationCode.RELU1, + 6: NNAPI_OperationCode.RELU6, + } + + if min_val != 0 or max_val not in op_map: + raise Exception("NNAPI only supports hardtanh with args (0, 1) or (0, 6).") + opcode = op_map[max_val] + + inputs = [None] * 1 + inputs[0] = in_id + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), in_oper) + + self.add_operation(opcode, inputs, outputs) + + def add_prelu_op(self, node): + assert node.inputsSize() == 2 + assert node.outputsSize() == 1 + + assert node.inputsAt(0).type().kind() == "TensorType" + assert node.inputsAt(1).type().kind() == "TensorType" + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + w_id, w_oper = self.get_tensor_operand_for_weight(node.inputsAt(1)) + assert len(w_oper.shape) == 1 + assert w_oper.shape[0] > 0 + if w_oper.shape[0] > 1: + if in_oper.use_nchw(): + # TODO: Support this by adding trailing 1 dims. + raise Exception("Per-channel PReLU only supports channels_last right now.") + + inputs = [None] * 2 + inputs[0] = in_id + inputs[1] = w_id + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), in_oper) + + self.add_operation(NNAPI_OperationCode.PRELU, inputs, outputs) + + def add_pool2d_node(self, node, opcode): + assert node.inputsSize() == 6 + assert node.outputsSize() == 1 + image, kernel, stride, padding, dilation, ceil_mode = node.inputs() + + stride = stride or kernel + + # TODO: Validate ceil_mode semantics. + + args = self.get_conv_pool_args_2d_from_jit(self.get_size_arg(kernel), stride, padding, dilation) + if args.dilation_h != 1 or args.dilation_w != 1: + raise Exception("NNAPI does not support dilated pooling.") + + image_id, image_oper = self.get_tensor_operand_by_jitval(image) + assert len(image_oper.shape) == 4 + + out_shape = get_conv_pool_shape(image_oper.shape, args, image_oper.shape[1], False) + use_nchw = image_oper.use_nchw() + + inputs = [None] * 11 + inputs[0] = image_id + inputs[1] = self.add_immediate_int_scalar(args.pad_l) + inputs[2] = self.add_immediate_int_scalar(args.pad_r) + inputs[3] = self.add_immediate_int_scalar(args.pad_t) + inputs[4] = self.add_immediate_int_scalar(args.pad_b) + inputs[5] = self.add_immediate_int_scalar(args.stride_w) + inputs[6] = self.add_immediate_int_scalar(args.stride_h) + inputs[7] = self.add_immediate_int_scalar(args.kernel_w) + inputs[8] = self.add_immediate_int_scalar(args.kernel_h) + inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) + inputs[10] = self.add_immediate_bool_scalar(use_nchw) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), image_oper._replace(shape=out_shape)) + + self.add_operation(opcode, inputs, outputs) + + def add_adaptive_avg_pool2d(self, node): + assert node.inputsSize() == 2 + assert node.outputsSize() == 1 + + image_id, image_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + assert len(image_oper.shape) == 4 + + size_ctype, size_arg = self.get_constant_value(node.inputsAt(1)) + assert size_ctype.kind() == "ListType" + assert size_ctype.getElementType().kind() == "IntType" + if size_arg != [1, 1]: + raise Exception("NNAPI only supports adaptive_avg_pool2d with output size (1, 1).") + + out_shape = image_oper.shape[0:2] + tuple(size_arg) + use_nchw = image_oper.use_nchw() + + inputs = [None] * 11 + inputs[0] = image_id + inputs[1] = self.add_immediate_int_scalar(0) + inputs[2] = self.add_immediate_int_scalar(0) + inputs[3] = self.add_immediate_int_scalar(0) + inputs[4] = self.add_immediate_int_scalar(0) + inputs[5] = self.add_immediate_int_scalar(1) + inputs[6] = self.add_immediate_int_scalar(1) + inputs[7] = self.add_immediate_int_scalar(image_oper.shape[3]) + inputs[8] = self.add_immediate_int_scalar(image_oper.shape[2]) + inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) + inputs[10] = self.add_immediate_bool_scalar(use_nchw) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), image_oper._replace(shape=out_shape)) + + self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs) + + def add_upsample_nearest2d(self, node): + assert node.inputsSize() == 3 + assert node.outputsSize() == 1 + image, size_jit, scale_jit = node.inputs() + size_ctype, size_arg = self.get_constant_value(size_jit) + scale_ctype, scale_arg = self.get_constant_value(scale_jit) + + image_id, image_oper = self.get_tensor_operand_by_jitval(image) + assert len(image_oper.shape) == 4 + + if size_ctype.kind() != "NoneType" and scale_ctype.kind() != "NoneType": + raise Exception("Size and scale cannot both be non-None.") + elif size_ctype.kind() != "NoneType": + assert size_ctype.kind() == "ListType" + assert size_ctype.getElementType().kind() == "IntType" + assert scale_ctype.kind() == "NoneType" + assert scale_arg is None + assert isinstance(size_arg, list) + assert size_arg + assert all(isinstance(val, int) for val in size_arg) + if len(size_arg) == 1: + size_arg = size_arg * 2 + assert len(size_arg) == 2 + out_h = size_arg[0] + out_w = size_arg[1] + arg_h = self.add_immediate_int_scalar(out_h) + arg_w = self.add_immediate_int_scalar(out_w) + elif scale_ctype.kind() != "NoneType": + assert scale_ctype.kind() == "ListType" + assert scale_ctype.getElementType().kind() == "FloatType" + assert size_ctype.kind() == "NoneType" + assert size_arg is None + assert isinstance(scale_arg, list) + assert scale_arg + assert all(isinstance(val, float) for val in scale_arg) + if len(scale_arg) == 1: + scale_arg = scale_arg * 2 + assert len(scale_arg) == 2 + out_h = int(scale_arg[0] * image_oper.shape[2]) + out_w = int(scale_arg[1] * image_oper.shape[3]) + arg_h = self.add_immediate_float_scalar(scale_arg[0]) + arg_w = self.add_immediate_float_scalar(scale_arg[1]) + else: + raise Exception("Size and scale cannot both be None.") + + out_shape = (image_oper.shape[0], image_oper.shape[1], out_h, out_w) + use_nchw = image_oper.use_nchw() + + inputs = [None] * 4 + inputs[0] = image_id + inputs[1] = arg_w + inputs[2] = arg_h + inputs[3] = self.add_immediate_bool_scalar(use_nchw) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), image_oper._replace(shape=out_shape)) + + self.add_operation(NNAPI_OperationCode.RESIZE_NEAREST_NEIGHBOR, inputs, outputs) + + def add_addmm(self, node): + assert node.inputsSize() == 5 + assert node.outputsSize() == 1 + jit_bias, jit_input, jit_weight, jit_beta, jit_alpha = node.inputs() + + for jitval in (jit_beta, jit_alpha): + scale_ctype, scale_value = self.get_constant_value(jitval) + assert scale_ctype.kind() in ("IntType", "FloatType") + if scale_value != 1: + raise Exception("NNAPI Fully-Connected does not support alpha and beta.") + + input_id, input_oper = self.get_tensor_operand_by_jitval(jit_input) + bias_id, bias_oper = self.get_tensor_operand_for_weight(jit_bias) + + assert len(input_oper.shape) == 2 + assert len(bias_oper.shape) == 1 + + # TODO: Transform at load time to share weights with CPU model. + _, weight_tensor = self.get_constant_value(jit_weight, "TensorType") + assert len(weight_tensor.shape) == 2 + nnapi_weight_tensor = weight_tensor.t().contiguous() + weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor) + weight_oper = self.operands[weight_id] + + out_shape = (input_oper.shape[0], weight_oper.shape[0]) + + inputs = [None] * 4 + inputs[0] = input_id + inputs[1] = weight_id + inputs[2] = bias_id + inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), input_oper._replace(shape=out_shape)) + + self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs) + + def add_qlinear(self, node): + assert node.inputsSize() == 4 + assert node.outputsSize() == 1 + ( + jit_input, + jit_packed_weight, + jit_scale, + jit_zero_point, + ) = node.inputs() + + input_id, input_oper = self.get_tensor_operand_by_jitval(jit_input) + # TODO: Support automatic reshape + assert len(input_oper.shape) == 2 + + _, out_scale = self.get_constant_value(jit_scale, "FloatType") + _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType") + weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight) + assert weight_ctype.name() == "LinearPackedParamsBase" + raw_weight, raw_bias = packed_weight.__getstate__()[0] + assert raw_bias is not None + + assert len(raw_weight.shape) == 2 + assert len(raw_bias.shape) == 1 + assert raw_bias.shape[0] == raw_weight.shape[0] + assert raw_weight.shape[1] == input_oper.shape[1] + + assert raw_weight.qscheme() == torch.per_tensor_affine + if raw_weight.dtype == torch.quint8: + unsigned_weight = raw_weight + else: + assert raw_weight.dtype == torch.qint8 + unsigned_weight = torch._make_per_tensor_quantized_tensor( + (raw_weight.int_repr().int() + 128).to(torch.uint8), + scale=raw_weight.q_scale(), + zero_point=raw_weight.q_zero_point() + 128) + weight_scale = unsigned_weight.q_scale() + bias_scale = input_oper.scale * weight_scale + int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32) + bias_id = self.add_tensor_operand_for_weight(int_bias) + + multiplier = input_oper.scale * weight_scale / out_scale + assert multiplier > 0 + if multiplier >= 1: + raise Exception( + "Quantized convolution multiplier is greater than 1. " + "This is supported by NNAPI, but not by most hardware backends. " + "Try training a model without quantization-aware training. ") + + # TODO: Transform at load time to share weights with CPU model. + nnapi_weight_tensor = unsigned_weight.contiguous() + weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor) + weight_oper = self.operands[weight_id] + + out_shape = (input_oper.shape[0], weight_oper.shape[0]) + out_oper = input_oper._replace( + shape=out_shape, + scale=out_scale, + zero_point=out_zero_point, + ) + + inputs = [None] * 4 + inputs[0] = input_id + inputs[1] = weight_id + inputs[2] = bias_id + inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs) + + def get_optional_bias(self, jit_bias, weight_tensor): + ctype, value = self.get_constant_value(jit_bias) + if ctype.kind() == "NoneType": + nnapi_bias_tensor = torch.zeros(weight_tensor.size()[0], dtype=weight_tensor.dtype) + bias_id = self.add_tensor_operand_for_weight(nnapi_bias_tensor) + bias_oper = self.operands[bias_id] + return bias_id, bias_oper + else: + return self.get_tensor_operand_for_weight(jit_bias) + + def add_conv2d(self, node): + assert node.inputsSize() == 7 + assert node.outputsSize() == 1 + + ( + jit_image, + jit_weight, + jit_bias, + jit_stride, + jit_pad, + jit_dilation, + jit_groups, + ) = node.inputs() + + _, weight_tensor = self.get_constant_value(jit_weight, "TensorType") + bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor) + args = self.get_conv_pool_args_2d_from_jit( + weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups) + + return self.add_conv2d_common( + node.outputsAt(0), + 0.0, + 0, + jit_image, + weight_tensor, + bias_id, + args, + False, # transpose + NNAPI_FuseCode.FUSED_NONE, + ) + + def add_conv_underscore(self, node): + assert node.inputsSize() == 13 + assert node.outputsSize() == 1 + + ( + jit_image, + jit_weight, + jit_bias, + jit_stride, + jit_pad, + jit_dilation, + jit_transpose, + _, + jit_groups, + _, + _, + _, + _, + ) = node.inputs() + + # XXX check jit_transpose + + _, weight_tensor = self.get_constant_value(jit_weight, "TensorType") + bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor) + args = self.get_conv_pool_args_2d_from_jit( + weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups) + + return self.add_conv2d_common( + node.outputsAt(0), + 0.0, + 0, + jit_image, + weight_tensor, + bias_id, + args, + False, # transpose + NNAPI_FuseCode.FUSED_NONE, + ) + + def add_qconv2d(self, node, fuse_code): + assert node.inputsSize() == 4 + assert node.outputsSize() == 1 + + ( + jit_image, + jit_packed_weight, + jit_scale, + jit_zero_point, + ) = node.inputs() + + _, out_scale = self.get_constant_value(jit_scale, "FloatType") + _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType") + weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight) + assert weight_ctype.name() == "Conv2dPackedParamsBase" + ( + pack_version, + tensors, + opt_tensors, + ) = packed_weight.__getstate__()[0] + assert pack_version == "2" + packed_config, raw_weight = tensors + raw_bias, = opt_tensors + assert raw_bias is not None + args = self.get_conv_pool_args_2d_from_pack(raw_weight.shape[2:4], packed_config) + + assert raw_weight.qscheme() == torch.per_tensor_affine + if raw_weight.dtype == torch.quint8: + unsigned_weight = raw_weight + else: + assert raw_weight.dtype == torch.qint8 + unsigned_weight = torch._make_per_tensor_quantized_tensor( + (raw_weight.int_repr().int() + 128).to(torch.uint8), + scale=raw_weight.q_scale(), + zero_point=raw_weight.q_zero_point() + 128) + weight_scale = unsigned_weight.q_scale() + _, image_oper = self.get_tensor_operand_by_jitval(jit_image) + bias_scale = image_oper.scale * weight_scale + int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32) + bias_id = self.add_tensor_operand_for_weight(int_bias) + + multiplier = image_oper.scale * weight_scale / out_scale + assert multiplier > 0 + if multiplier >= 1: + raise Exception( + "Quantized convolution multiplier is greater than 1. " + "This is supported by NNAPI, but not by most hardware backends. " + "Try training a model without quantization-aware training. ") + + return self.add_conv2d_common( + node.outputsAt(0), + out_scale, + out_zero_point, + jit_image, + unsigned_weight, + bias_id, + args, + False, # transpose + fuse_code, + ) + + def add_conv2d_common( + self, + jit_out, + out_scale, + out_zero_point, + jit_image, + weight_tensor, + bias_id, + args, + transpose, + fuse_code): + image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image) + in_c = image_oper.shape[1] + + if args.group == 1: + # Full convolution + depthwise = False + weight_permutation = (0, 2, 3, 1) + elif args.group == in_c: + # Depthwise convolution + depthwise = True + weight_permutation = (1, 2, 3, 0) + else: + raise Exception("Group convolution not supported yet.") + + # TODO: Transform at load time to share weights with CPU model. + nnapi_weight_tensor = weight_tensor.permute(*weight_permutation).contiguous() + weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor) + weight_oper = self.operands[weight_id] + + bias_oper = self.operands[bias_id] + + if image_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32: + assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32 + assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32 + elif image_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: + assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM + assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_INT32 + assert approx_equal(image_oper.scale * weight_oper.scale, bias_oper.scale) + assert bias_oper.zero_point == 0 + else: + raise Exception( + "Unsupported input type for conv2d: {}" + .format(image_oper.op_type)) + + assert len(image_oper.shape) == 4 + assert len(weight_oper.shape) == 4 + assert len(bias_oper.shape) == 1 + + if depthwise: + # Depthwise convolution + one, kern_h, kern_w, out_c = weight_oper.shape + assert one == 1 + assert out_c % in_c == 0 + channel_multiplier = out_c // in_c + assert channel_multiplier == 1 # Don't support multiplier + assert out_c == in_c + else: + # Full convolution + kern_nf, kern_h, kern_w, kern_d = weight_oper.shape + out_c = kern_nf + assert kern_d == in_c + + assert out_c == bias_oper.shape[0] + + out_shape = get_conv_pool_shape(image_oper.shape, args, out_c, transpose) + out_oper = image_oper._replace( + shape=out_shape, + scale=out_scale, + zero_point=out_zero_point, + ) + + use_nchw = image_oper.use_nchw() + + if depthwise: + num_args = 12 + opcode = NNAPI_OperationCode.DEPTHWISE_CONV_2D + else: + num_args = 11 + if transpose: + opcode = NNAPI_OperationCode.TRANSPOSE_CONV_2D + else: + opcode = NNAPI_OperationCode.CONV_2D + + inputs = [None] * num_args + inputs[0] = image_id + inputs[1] = weight_id + inputs[2] = bias_id + inputs[3] = self.add_immediate_int_scalar(args.pad_l) + inputs[4] = self.add_immediate_int_scalar(args.pad_r) + inputs[5] = self.add_immediate_int_scalar(args.pad_t) + inputs[6] = self.add_immediate_int_scalar(args.pad_b) + inputs[7] = self.add_immediate_int_scalar(args.stride_w) + inputs[8] = self.add_immediate_int_scalar(args.stride_h) + if depthwise: + inputs[9] = self.add_immediate_int_scalar(1) + inputs[10] = self.add_immediate_int_scalar(fuse_code) + inputs[11] = self.add_immediate_bool_scalar(use_nchw) + else: + inputs[9] = self.add_immediate_int_scalar(fuse_code) + inputs[10] = self.add_immediate_bool_scalar(use_nchw) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(jit_out, out_oper) + + self.add_operation(opcode, inputs, outputs) + + +def serialize_model(module, inputs, config=None): + return _NnapiSerializer(config).serialize_model(module, inputs) diff --git a/torch/contrib/_tensorboard_vis.py b/torch/contrib/_tensorboard_vis.py index 8f4ca71ff202f..b1b8d35a511d7 100644 --- a/torch/contrib/_tensorboard_vis.py +++ b/torch/contrib/_tensorboard_vis.py @@ -1,6 +1,7 @@ import time from collections import defaultdict from functools import partial +from typing import DefaultDict import torch @@ -52,7 +53,7 @@ def visualize(graph, name_prefix='', pb_graph=None, executors_it=None): def visualize_graph_executor(state, name_prefix, pb_graph, inline_graph): """Appends the state of a given GraphExecutor to the graph protobuf. - Arguments: + Args: state (GraphExecutor or GraphExecutorState): GraphExecutor to display. name_prefix (str): Name prefix of the containing subgraph. pb_graph (GraphDef): graph to append to. @@ -104,7 +105,7 @@ def inline_graph(subgraph, name, node): for out, val in zip(subgraph.outputs(), node.outputs()): value_map[val.unique()] = rec_value_map[out.unique()] - op_id_counter = defaultdict(int) + op_id_counter: DefaultDict[str, int] = defaultdict(int) def name_for(node): kind = node.kind()[node.kind().index('::') + 2:] diff --git a/torch/csrc/CudaIPCTypes.cpp b/torch/csrc/CudaIPCTypes.cpp index b29fc1da0f1ab..7a61031f08864 100644 --- a/torch/csrc/CudaIPCTypes.cpp +++ b/torch/csrc/CudaIPCTypes.cpp @@ -6,7 +6,7 @@ #include #ifdef _MSC_VER -#include +#include #else #include #include diff --git a/torch/csrc/DataLoader.cpp b/torch/csrc/DataLoader.cpp index 044c3b04221e5..58c74d2b2a1ce 100644 --- a/torch/csrc/DataLoader.cpp +++ b/torch/csrc/DataLoader.cpp @@ -215,9 +215,9 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module, PyObject *_ig #endif PyMethodDef DataLoaderMethods[] = { - {"_set_worker_signal_handlers", (PyCFunction)THPModule_setWorkerSignalHandlers, METH_NOARGS, nullptr}, - {"_set_worker_pids", (PyCFunction)THPModule_setWorkerPIDs, METH_VARARGS, nullptr}, - {"_remove_worker_pids", (PyCFunction)THPModule_removeWorkerPIDs, METH_O, nullptr}, - {"_error_if_any_worker_fails", (PyCFunction)THPModule_errorIfAnyWorkerFails, METH_NOARGS, nullptr}, + {"_set_worker_signal_handlers", THPModule_setWorkerSignalHandlers, METH_NOARGS, nullptr}, + {"_set_worker_pids", THPModule_setWorkerPIDs, METH_VARARGS, nullptr}, + {"_remove_worker_pids", THPModule_removeWorkerPIDs, METH_O, nullptr}, + {"_error_if_any_worker_fails", THPModule_errorIfAnyWorkerFails, METH_NOARGS, nullptr}, {nullptr, nullptr, 0, nullptr} }; diff --git a/torch/csrc/Device.cpp b/torch/csrc/Device.cpp index 3ce7f09d0cf0e..fef36f56c420b 100644 --- a/torch/csrc/Device.cpp +++ b/torch/csrc/Device.cpp @@ -30,7 +30,10 @@ PyObject *THPDevice_repr(THPDevice *self) std::ostringstream oss; oss << "device(type=\'" << self->device.type() << "\'"; if (self->device.has_index()) { - oss << ", index=" << self->device.index(); + // `self->device.index()` returns uint8_t which is treated as ascii while printing, + // hence casting it to uint16_t. + // https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout + oss << ", index=" << static_cast(self->device.index()); } oss << ")"; return THPUtils_packString(oss.str().c_str()); @@ -138,9 +141,10 @@ PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) { END_HANDLE_TH_ERRORS } -PyObject *THPDevice_reduce(THPDevice *self, PyObject *noargs) +PyObject *THPDevice_reduce(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS + auto self = (THPDevice*)_self; auto ret = THPObjectPtr{PyTuple_New(2)}; if (!ret) throw python_error(); @@ -174,7 +178,7 @@ static struct PyGetSetDef THPDevice_properties[] = { }; static PyMethodDef THPDevice_methods[] = { - {"__reduce__", (PyCFunction)THPDevice_reduce, METH_NOARGS, nullptr}, + {"__reduce__", THPDevice_reduce, METH_NOARGS, nullptr}, {nullptr} /* Sentinel */ }; diff --git a/torch/csrc/Dtype.cpp b/torch/csrc/Dtype.cpp index 2243b31d9e06f..58961f3348c83 100644 --- a/torch/csrc/Dtype.cpp +++ b/torch/csrc/Dtype.cpp @@ -51,12 +51,13 @@ PyObject *THPDtype_is_signed(THPDtype *self, PyObject *noargs) END_HANDLE_TH_ERRORS } -PyObject *THPDtype_reduce(THPDtype *self, PyObject *noargs) +PyObject *THPDtype_reduce(PyObject *_self, PyObject *noargs) { /* * For singletons, a string is returned. The string should be interpreted * as the name of a global variable. */ + auto self = (THPDtype*)_self; return THPUtils_packString(self->name); } @@ -70,7 +71,7 @@ static struct PyGetSetDef THPDtype_properties[] = { }; static PyMethodDef THPDtype_methods[] = { - {"__reduce__", (PyCFunction)THPDtype_reduce, METH_NOARGS, nullptr}, + {"__reduce__", THPDtype_reduce, METH_NOARGS, nullptr}, {nullptr} /* Sentinel */ }; diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp index 6969ac0449c06..92e8a93c284e5 100644 --- a/torch/csrc/DynamicTypes.cpp +++ b/torch/csrc/DynamicTypes.cpp @@ -59,11 +59,10 @@ at::DeprecatedTypeProperties* get_type(at::Backend backend, at::ScalarType scala PyTypeObject* getPyTypeObject( const at::Storage& storage, - const caffe2::TypeMeta& dtype) { + const caffe2::TypeMeta dtype) { at::ScalarType scalarType = at::typeMetaToScalarType(dtype); - at::TensorOptions options = at::TensorOptions(storage.device_type()).dtype(scalarType); auto attype = &at::getDeprecatedTypeProperties( - at::dispatchKeyToBackend(at::computeDispatchKey(options)), + at::dispatchKeyToBackend(c10::computeDispatchKey(scalarType, c10::nullopt, storage.device_type())), scalarType); auto it = attype_to_py_storage_type.find(attype); if (it != attype_to_py_storage_type.end()) { @@ -107,7 +106,7 @@ THPLayout* getTHPLayout(at::Layout layout) { PyObject* createPyObject( const at::Storage& storage, - const caffe2::TypeMeta& data_type) { + const caffe2::TypeMeta data_type) { auto type = getPyTypeObject(storage, data_type); auto obj = THPObjectPtr(type->tp_alloc(type, 0)); if (!obj) throw python_error(); diff --git a/torch/csrc/DynamicTypes.h b/torch/csrc/DynamicTypes.h index 0877fb317cb33..d93d0e3b5cf5a 100644 --- a/torch/csrc/DynamicTypes.h +++ b/torch/csrc/DynamicTypes.h @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -29,7 +30,7 @@ void registerLayoutObject(THPLayout *thp_layout, at::Layout layout); PyObject* createPyObject( const at::Storage& storage, - const caffe2::TypeMeta& data_type); + const caffe2::TypeMeta data_type); at::Storage createStorage(PyObject* obj); bool isStorage(PyObject* obj); diff --git a/torch/csrc/Exceptions.cpp b/torch/csrc/Exceptions.cpp index eb735b73d5413..73042117a45cc 100644 --- a/torch/csrc/Exceptions.cpp +++ b/torch/csrc/Exceptions.cpp @@ -159,6 +159,13 @@ ValueError::ValueError(const char *format, ...) { va_end(fmt_args); } +AttributeError::AttributeError(const char* format, ...) { + va_list fmt_args; + va_start(fmt_args, format); + msg = formatMessage(format, fmt_args); + va_end(fmt_args); +} + void PyWarningHandler::process( const c10::SourceLocation& source_location, const std::string& msg, diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index 66e335b2bc764..c9d096270d2a4 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -291,6 +291,14 @@ struct NotImplementedError : public PyTorchError { } }; +// Translates to Python AttributeError +struct AttributeError : public PyTorchError { + AttributeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3); + PyObject* python_type() override { + return PyExc_AttributeError; + } +}; + struct WarningMeta { WarningMeta(const c10::SourceLocation& _source_location, const std::string& _msg, diff --git a/torch/csrc/Generator.cpp b/torch/csrc/Generator.cpp index 0884223f765d4..2bc478f36007a 100644 --- a/torch/csrc/Generator.cpp +++ b/torch/csrc/Generator.cpp @@ -15,7 +15,6 @@ #include #ifdef USE_CUDA -#include #include #endif @@ -74,55 +73,45 @@ static PyObject * THPGenerator_pynew(PyTypeObject *type, PyObject *args, PyObjec END_HANDLE_TH_ERRORS } -static PyObject * THPGenerator_getState(THPGenerator *self, PyObject *noargs) +static PyObject * THPGenerator_getState(PyObject *_self, PyObject *noargs) { using namespace torch::autograd; HANDLE_TH_ERRORS - Variable var = torch::empty({0}, at::device(at::kCPU).dtype(at::kByte)); - if (self->cdata.device().type() == at::kCPU) { - THByteTensor_getRNGState(self->cdata, (THByteTensor*)(var.unsafeGetTensorImpl())); - } else { -#ifdef USE_CUDA - TORCH_INTERNAL_ASSERT(self->cdata.device().type() == at::kCUDA); - THCRandom_getRNGState(self->cdata, (THByteTensor*)(var.unsafeGetTensorImpl())); -#else - TORCH_INTERNAL_ASSERT(false, "PyTorch not compiled with CUDA"); -#endif - } - return THPVariable_Wrap(std::move(var)); + auto& gen = ((THPGenerator*)_self)->cdata; + + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen.mutex()); + auto state_tensor = gen.get_state(); + + return THPVariable_Wrap(std::move(state_tensor)); END_HANDLE_TH_ERRORS } -static PyObject * THPGenerator_setState(THPGenerator *self, PyObject *_new_state) +static PyObject * THPGenerator_setState(PyObject *_self, PyObject *_new_state) { using namespace torch::autograd; + HANDLE_TH_ERRORS if (!THPVariable_Check(_new_state)) { throw torch::TypeError("expected a torch.ByteTensor, but got %s", Py_TYPE(_new_state)->tp_name); } - auto& tensor = ((THPVariable*)_new_state)->cdata; - if (tensor.layout() != kStrided || tensor.device().type() != kCPU || tensor.scalar_type() != kByte) { - auto type_name = torch::utils::options_to_string(tensor.options()); - throw torch::TypeError("expected a torch.ByteTensor, but got %s", type_name.c_str()); - } - if (self->cdata.device().type() == at::kCPU) { - THByteTensor_setRNGState(self->cdata, (THByteTensor*)tensor.unsafeGetTensorImpl()); - } else { -#ifdef USE_CUDA - TORCH_INTERNAL_ASSERT(self->cdata.device().type() == at::kCUDA); - THCRandom_setRNGState(self->cdata, (THByteTensor*)tensor.unsafeGetTensorImpl()); -#else - TORCH_INTERNAL_ASSERT(false, "PyTorch not compiled with CUDA"); -#endif - } + auto self = (THPGenerator*)_self; + auto& gen = self->cdata; + auto& new_state_tensor = ((THPVariable*)_new_state)->cdata; + + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen.mutex()); + gen.set_state(new_state_tensor); + Py_INCREF(self); return (PyObject*)self; END_HANDLE_TH_ERRORS } -static PyObject * THPGenerator_manualSeed(THPGenerator *self, PyObject *seed) +static PyObject * THPGenerator_manualSeed(PyObject *_self, PyObject *seed) { HANDLE_TH_ERRORS + auto self = (THPGenerator*)_self; auto generator = self->cdata; THPUtils_assert(THPUtils_checkLong(seed), "manual_seed expected a long, " "but got %s", THPUtils_typename(seed)); @@ -150,19 +139,21 @@ static PyObject * THPGenerator_manualSeed(THPGenerator *self, PyObject *seed) END_HANDLE_TH_ERRORS } -static PyObject * THPGenerator_seed(THPGenerator *self, PyObject *noargs) +static PyObject * THPGenerator_seed(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS // See Note [Acquire lock when using random generators] + auto self = (THPGenerator*)_self; std::lock_guard lock(self->cdata.mutex()); uint64_t seed_val = self->cdata.seed(); return THPUtils_packUInt64(seed_val); END_HANDLE_TH_ERRORS } -static PyObject * THPGenerator_initialSeed(THPGenerator *self, PyObject *noargs) +static PyObject * THPGenerator_initialSeed(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS + auto self = (THPGenerator*)_self; return THPUtils_packUInt64(self->cdata.current_seed()); END_HANDLE_TH_ERRORS } @@ -179,11 +170,11 @@ static struct PyGetSetDef THPGenerator_properties[] = { }; static PyMethodDef THPGenerator_methods[] = { - {"get_state", (PyCFunction)THPGenerator_getState, METH_NOARGS, nullptr}, - {"set_state", (PyCFunction)THPGenerator_setState, METH_O, nullptr}, - {"manual_seed", (PyCFunction)THPGenerator_manualSeed, METH_O, nullptr}, - {"seed", (PyCFunction)THPGenerator_seed, METH_NOARGS, nullptr}, - {"initial_seed", (PyCFunction)THPGenerator_initialSeed, METH_NOARGS, nullptr}, + {"get_state", THPGenerator_getState, METH_NOARGS, nullptr}, + {"set_state", THPGenerator_setState, METH_O, nullptr}, + {"manual_seed", THPGenerator_manualSeed, METH_O, nullptr}, + {"seed", THPGenerator_seed, METH_NOARGS, nullptr}, + {"initial_seed", THPGenerator_initialSeed, METH_NOARGS, nullptr}, {nullptr} }; diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index ed4aa21a8f768..ca999652db5c7 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -39,6 +40,7 @@ #include #include #include +#include #include #include #include @@ -61,6 +63,10 @@ #endif #endif +#if defined(USE_VALGRIND) +#include +#endif + #define WITH_NUMPY_IMPORT_ARRAY #include @@ -80,9 +86,9 @@ static PyObject * THPModule_initNames(PyObject *self, PyObject *arg) THPObjectPtr types(PySequence_Fast(arg, "expected a sequence")); if (!types) return nullptr; - int num_classes = PySequence_Fast_GET_SIZE(types.get()); + auto num_classes = PySequence_Fast_GET_SIZE(types.get()); names.reserve(names.size() + num_classes); - for (size_t i = 0; i < num_classes; i++) { + for (Py_ssize_t i = 0; i < num_classes; i++) { PyObject* obj = PySequence_Fast_GET_ITEM(types.get(), i); THPUtils_assert(PyType_Check(obj), "expected a PyTypeObject"); PyTypeObject* type = (PyTypeObject*)obj; @@ -127,6 +133,7 @@ static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manag THPByteStorage_postInit(module); THPBoolStorage_postInit(module); THPQUInt8Storage_postInit(module); + THPQUInt4x2Storage_postInit(module); THPQInt8Storage_postInit(module); THPQInt32Storage_postInit(module); THPBFloat16Storage_postInit(module); @@ -147,27 +154,28 @@ static PyObject * THPModule_crashIfCsrcASAN(PyObject *module, PyObject *arg) { "but got %s", THPUtils_typename(arg)); //NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, modernize-avoid-c-arrays) volatile char x[3]; - x[static_cast(THPUtils_unpackLong(arg))] = 0; - return PyLong_FromLong(x[0]); + x[THPUtils_unpackInt(arg)] = 0; + //NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) + return THPUtils_packInt32(x[0]); } static PyObject * THPModule_crashIfCsrcUBSAN(PyObject *module, PyObject *arg) { THPUtils_assert(THPUtils_checkLong(arg), "crash_if_csrc_ubsan expects an int, " "but got %s", THPUtils_typename(arg)); - int32_t x = static_cast(THPUtils_unpackLong(arg)); + int32_t x = THPUtils_unpackInt(arg); double y = 1.0 / x; - return PyLong_FromLong((int)y); + return THPUtils_packInt32((int)y); } static PyObject * THPModule_crashIfATenASAN(PyObject *module, PyObject *arg) { THPUtils_assert(THPUtils_checkLong(arg), "crash_if_aten_asan expects an int, " "but got %s", THPUtils_typename(arg)); - return PyLong_FromLong(at::_crash_if_asan(static_cast(THPUtils_unpackLong(arg)))); + return THPUtils_packInt32(at::_crash_if_asan(THPUtils_unpackInt(arg))); } static PyObject * THPModule_getNumThreads(PyObject *module, PyObject *noargs) { - return PyLong_FromLong(at::get_num_threads()); + return THPUtils_packInt32(at::get_num_threads()); } static PyObject * THPModule_setNumThreads(PyObject *module, PyObject *arg) @@ -182,7 +190,7 @@ static PyObject * THPModule_setNumThreads(PyObject *module, PyObject *arg) static PyObject * THPModule_getNumInteropThreads(PyObject *module, PyObject *noargs) { - return PyLong_FromLong(at::get_num_interop_threads()); + return THPUtils_packInt32(at::get_num_interop_threads()); } static PyObject * THPModule_setNumInteropThreads(PyObject *module, PyObject *arg) @@ -329,6 +337,13 @@ static PyObject *THPModule_showConfig(PyObject *module, PyObject *noargs) END_HANDLE_TH_ERRORS } +static PyObject *THPModule_cxxFlags(PyObject *module, PyObject *noargs) +{ + HANDLE_TH_ERRORS + return THPUtils_packString(at::get_cxx_flags()); + END_HANDLE_TH_ERRORS +} + static PyObject *THPModule_parallelInfo(PyObject *module, PyObject *noargs) { HANDLE_TH_ERRORS @@ -527,12 +542,12 @@ PyObject *THPModule_setQEngine(PyObject */* unused */, PyObject *arg) Py_RETURN_NONE; } -PyObject *THPModule_qEngine(PyObject */* unused */) +PyObject *THPModule_qEngine(PyObject *_unused, PyObject *noargs) { return THPUtils_packInt64(static_cast(at::globalContext().qEngine())); } -PyObject *THPModule_supportedQEngines(PyObject */* unused */) +PyObject *THPModule_supportedQEngines(PyObject *_unused, PyObject *noargs) { auto qengines = at::globalContext().supportedQEngines(); auto list = THPObjectPtr(PyList_New(qengines.size())); @@ -546,7 +561,7 @@ PyObject *THPModule_supportedQEngines(PyObject */* unused */) return list.release(); } -PyObject *THPModule_isEnabledXNNPACK(PyObject * /* unused */) +PyObject *THPModule_isEnabledXNNPACK(PyObject *_unused, PyObject *noargs) { if (at::globalContext().isXNNPACKAvailable()) Py_RETURN_TRUE; else Py_RETURN_FALSE; @@ -566,54 +581,58 @@ static PyObject * THPModule_vmapmode_decrement_nesting(PyObject* _unused, PyObje //NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, modernize-avoid-c-arrays) static PyMethodDef TorchMethods[] = { - {"_initExtension", (PyCFunction)THPModule_initExtension, METH_O, nullptr}, - {"_autograd_init", (PyCFunction)THPAutograd_initExtension, METH_NOARGS, nullptr}, - {"_add_docstr", (PyCFunction)THPModule_addDocStr, METH_VARARGS, nullptr}, - {"_init_names", (PyCFunction)THPModule_initNames, METH_O, nullptr}, - {"_has_distributed",(PyCFunction)THPModule_hasDistributed, METH_NOARGS, nullptr}, - {"_set_default_tensor_type", (PyCFunction)THPModule_setDefaultTensorType, METH_O, nullptr}, - {"_set_default_dtype", (PyCFunction)THPModule_setDefaultDtype, METH_O, nullptr}, - {"_infer_size", (PyCFunction)THPModule_inferSize, METH_VARARGS, nullptr}, - {"_crash_if_csrc_asan", (PyCFunction)THPModule_crashIfCsrcASAN, METH_O, nullptr}, - {"_crash_if_csrc_ubsan", (PyCFunction)THPModule_crashIfCsrcUBSAN, METH_O, nullptr}, - {"_crash_if_aten_asan", (PyCFunction)THPModule_crashIfATenASAN, METH_O, nullptr}, - {"_show_config", (PyCFunction)THPModule_showConfig, METH_NOARGS, nullptr}, - {"_parallel_info", (PyCFunction)THPModule_parallelInfo, METH_NOARGS, nullptr}, - {"_set_backcompat_broadcast_warn", (PyCFunction)THPModule_setBackcompatBroadcastWarn, METH_O, nullptr}, - {"_get_backcompat_broadcast_warn", (PyCFunction)THPModule_getBackcompatBroadcastWarn, METH_NOARGS, nullptr}, - {"_set_backcompat_keepdim_warn", (PyCFunction)THPModule_setBackcompatKeepdimWarn, METH_O, nullptr}, - {"_get_backcompat_keepdim_warn", (PyCFunction)THPModule_getBackcompatKeepdimWarn, METH_NOARGS, nullptr}, - {"get_num_threads", (PyCFunction)THPModule_getNumThreads, METH_NOARGS, nullptr}, - {"set_num_threads", (PyCFunction)THPModule_setNumThreads, METH_O, nullptr}, - {"get_num_interop_threads", (PyCFunction)THPModule_getNumInteropThreads, METH_NOARGS, nullptr}, - {"set_num_interop_threads", (PyCFunction)THPModule_setNumInteropThreads, METH_O, nullptr}, - {"_get_cudnn_enabled", (PyCFunction)THPModule_userEnabledCuDNN, METH_NOARGS, nullptr}, - {"_set_cudnn_enabled", (PyCFunction)THPModule_setUserEnabledCuDNN, METH_O, nullptr}, - {"_get_mkldnn_enabled", (PyCFunction)THPModule_userEnabledMkldnn, METH_NOARGS, nullptr}, - {"_set_mkldnn_enabled", (PyCFunction)THPModule_setUserEnabledMkldnn, METH_O, nullptr}, - {"_get_cudnn_allow_tf32", (PyCFunction)THPModule_allowTF32CuDNN, METH_NOARGS, nullptr}, - {"_set_cudnn_allow_tf32", (PyCFunction)THPModule_setAllowTF32CuDNN, METH_O, nullptr}, - {"_get_cudnn_benchmark", (PyCFunction)THPModule_benchmarkCuDNN, METH_NOARGS, nullptr}, - {"_set_cudnn_benchmark", (PyCFunction)THPModule_setBenchmarkCuDNN, METH_O, nullptr}, - {"_get_cudnn_deterministic", (PyCFunction)THPModule_deterministicCuDNN, METH_NOARGS, nullptr}, - {"_set_cudnn_deterministic", (PyCFunction)THPModule_setDeterministicCuDNN, METH_O, nullptr}, - {"_get_deterministic", (PyCFunction)THPModule_deterministic, METH_NOARGS, nullptr}, - {"_set_deterministic", (PyCFunction)THPModule_setDeterministic, METH_O, nullptr}, - {"_get_cublas_allow_tf32", (PyCFunction)THPModule_allowTF32CuBLAS, METH_NOARGS, nullptr}, - {"_set_cublas_allow_tf32", (PyCFunction)THPModule_setAllowTF32CuBLAS, METH_O, nullptr}, - {"_vmapmode_increment_nesting", (PyCFunction)THPModule_vmapmode_increment_nesting, METH_NOARGS, nullptr}, - {"_vmapmode_decrement_nesting", (PyCFunction)THPModule_vmapmode_decrement_nesting, METH_NOARGS, nullptr}, - {"_to_dlpack", (PyCFunction)THPModule_toDLPack, METH_O, nullptr}, - {"_from_dlpack", (PyCFunction)THPModule_fromDLPack, METH_O, nullptr}, - {"set_flush_denormal", (PyCFunction)THPModule_setFlushDenormal, METH_O, nullptr}, - {"get_default_dtype", (PyCFunction)THPModule_getDefaultDtype, METH_NOARGS, nullptr}, - {"_get_default_device", (PyCFunction)THPModule_getDefaultDevice, METH_NOARGS, nullptr}, - {"_get_qengine", (PyCFunction)THPModule_qEngine, METH_NOARGS, nullptr}, - {"_set_qengine", (PyCFunction)THPModule_setQEngine, METH_O, nullptr}, - {"_supported_qengines", (PyCFunction)THPModule_supportedQEngines, METH_NOARGS, nullptr}, - {"_is_xnnpack_enabled", (PyCFunction)THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr}, - {"_is_torch_function_enabled", (PyCFunction)THPModule_isEnabledTorchFunction, METH_NOARGS, nullptr}, - {"_disabled_torch_function_impl", (PyCFunction)THPModule_disable_torch_function, METH_VARARGS, nullptr}, + {"_initExtension", THPModule_initExtension, METH_O, nullptr}, + {"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr}, + {"_add_docstr", THPModule_addDocStr, METH_VARARGS, nullptr}, + {"_init_names", THPModule_initNames, METH_O, nullptr}, + {"_has_distributed",THPModule_hasDistributed, METH_NOARGS, nullptr}, + {"_set_default_tensor_type", THPModule_setDefaultTensorType, METH_O, nullptr}, + {"_set_default_dtype", THPModule_setDefaultDtype, METH_O, nullptr}, + {"_infer_size", THPModule_inferSize, METH_VARARGS, nullptr}, + {"_crash_if_csrc_asan", THPModule_crashIfCsrcASAN, METH_O, nullptr}, + {"_crash_if_csrc_ubsan", THPModule_crashIfCsrcUBSAN, METH_O, nullptr}, + {"_crash_if_aten_asan", THPModule_crashIfATenASAN, METH_O, nullptr}, + {"_show_config", THPModule_showConfig, METH_NOARGS, nullptr}, + {"_cxx_flags", THPModule_cxxFlags, METH_NOARGS, nullptr}, + {"_parallel_info", THPModule_parallelInfo, METH_NOARGS, nullptr}, + {"_set_backcompat_broadcast_warn", THPModule_setBackcompatBroadcastWarn, METH_O, nullptr}, + {"_get_backcompat_broadcast_warn", THPModule_getBackcompatBroadcastWarn, METH_NOARGS, nullptr}, + {"_set_backcompat_keepdim_warn", THPModule_setBackcompatKeepdimWarn, METH_O, nullptr}, + {"_get_backcompat_keepdim_warn", THPModule_getBackcompatKeepdimWarn, METH_NOARGS, nullptr}, + {"get_num_threads", THPModule_getNumThreads, METH_NOARGS, nullptr}, + {"set_num_threads", THPModule_setNumThreads, METH_O, nullptr}, + {"get_num_interop_threads", THPModule_getNumInteropThreads, METH_NOARGS, nullptr}, + {"set_num_interop_threads", THPModule_setNumInteropThreads, METH_O, nullptr}, + {"_get_cudnn_enabled", THPModule_userEnabledCuDNN, METH_NOARGS, nullptr}, + {"_set_cudnn_enabled", THPModule_setUserEnabledCuDNN, METH_O, nullptr}, + {"_get_mkldnn_enabled", THPModule_userEnabledMkldnn, METH_NOARGS, nullptr}, + {"_set_mkldnn_enabled", THPModule_setUserEnabledMkldnn, METH_O, nullptr}, + {"_get_cudnn_allow_tf32", THPModule_allowTF32CuDNN, METH_NOARGS, nullptr}, + {"_set_cudnn_allow_tf32", THPModule_setAllowTF32CuDNN, METH_O, nullptr}, + {"_get_cudnn_benchmark", THPModule_benchmarkCuDNN, METH_NOARGS, nullptr}, + {"_set_cudnn_benchmark", THPModule_setBenchmarkCuDNN, METH_O, nullptr}, + {"_get_cudnn_deterministic", THPModule_deterministicCuDNN, METH_NOARGS, nullptr}, + {"_set_cudnn_deterministic", THPModule_setDeterministicCuDNN, METH_O, nullptr}, + {"_get_deterministic", THPModule_deterministic, METH_NOARGS, nullptr}, + {"_set_deterministic", THPModule_setDeterministic, METH_O, nullptr}, + {"_get_cublas_allow_tf32", THPModule_allowTF32CuBLAS, METH_NOARGS, nullptr}, + {"_set_cublas_allow_tf32", THPModule_setAllowTF32CuBLAS, METH_O, nullptr}, + {"_vmapmode_increment_nesting", THPModule_vmapmode_increment_nesting, METH_NOARGS, nullptr}, + {"_vmapmode_decrement_nesting", THPModule_vmapmode_decrement_nesting, METH_NOARGS, nullptr}, + {"_to_dlpack", THPModule_toDLPack, METH_O, nullptr}, + {"_from_dlpack", THPModule_fromDLPack, METH_O, nullptr}, + {"set_flush_denormal", THPModule_setFlushDenormal, METH_O, nullptr}, + {"get_default_dtype", THPModule_getDefaultDtype, METH_NOARGS, nullptr}, + {"_get_default_device", THPModule_getDefaultDevice, METH_NOARGS, nullptr}, + {"_get_qengine", THPModule_qEngine, METH_NOARGS, nullptr}, + {"_set_qengine", THPModule_setQEngine, METH_O, nullptr}, + {"_supported_qengines", THPModule_supportedQEngines, METH_NOARGS, nullptr}, + {"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr}, + {"_is_torch_function_enabled", THPModule_isEnabledTorchFunction, METH_NOARGS, nullptr}, + {"_disabled_torch_function_impl", THPModule_disable_torch_function, METH_VARARGS, nullptr}, + {"_has_torch_function", THPModule_has_torch_function, METH_O, nullptr}, + {"_has_torch_function_unary", THPModule_has_torch_function_unary, METH_O, nullptr}, + {"_has_torch_function_variadic", MAYBE_WRAP_FASTCALL(THPModule_has_torch_function_variadic), MAYBE_METH_FASTCALL, nullptr}, {nullptr, nullptr, 0, nullptr} }; @@ -632,6 +651,7 @@ bool THCPComplexFloatStorage_init(PyObject *module); void THCPStream_init(PyObject *module); void THCPEvent_init(PyObject *module); +void THCPGraph_init(PyObject *module); #ifdef USE_CUDA PyMethodDef* THCPModule_methods(); @@ -688,9 +708,9 @@ PyObject* initModule() { #ifdef USE_CUDA THPUtils_addPyMethodDefs(methods, THCPModule_methods()); #endif -#ifdef USE_DISTRIBUTED -#ifdef USE_C10D +#if defined(USE_DISTRIBUTED) && defined(USE_C10D) THPUtils_addPyMethodDefs(methods, torch::distributed::c10d::python_functions()); +#ifndef _WIN32 THPUtils_addPyMethodDefs(methods, torch::distributed::rpc::python_functions()); THPUtils_addPyMethodDefs( methods, torch::distributed::autograd::python_functions()); @@ -706,7 +726,6 @@ PyObject* initModule() { methods.data() }; ASSERT_TRUE(module = PyModule_Create(&torchmodule)); - ASSERT_TRUE(THPWrapper_init(module)); ASSERT_TRUE(THPGenerator_init(module)); ASSERT_TRUE(THPException_init(module)); THPSize_init(module); @@ -716,6 +735,7 @@ PyObject* initModule() { THPMemoryFormat_init(module); THPQScheme_init(module); THPDevice_init(module); + THPStream_init(module); ASSERT_TRUE(THPVariable_initModule(module)); ASSERT_TRUE(THPFunction_initModule(module)); ASSERT_TRUE(THPEngine_initModule(module)); @@ -746,6 +766,7 @@ PyObject* initModule() { ASSERT_TRUE(THPQUInt8Storage_init(module)); ASSERT_TRUE(THPQInt8Storage_init(module)); ASSERT_TRUE(THPQInt32Storage_init(module)); + ASSERT_TRUE(THPQUInt4x2Storage_init(module)); ASSERT_TRUE(THPBFloat16Storage_init(module)); ASSERT_TRUE(THPComplexDoubleStorage_init(module)); ASSERT_TRUE(THPComplexFloatStorage_init(module)); @@ -770,6 +791,7 @@ PyObject* initModule() { THCPStream_init(module); THCPEvent_init(module); + THCPGraph_init(module); #endif auto set_module_attr = [&](const char* name, PyObject* v, bool incref = true) { @@ -821,6 +843,26 @@ Call this whenever a new thread is created in order to propagate values from ASSERT_TRUE(set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False)); ASSERT_TRUE(set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False)); + py_module.def( + "_valgrind_supported_platform", [](){ + #if defined(USE_VALGRIND) + return true; + #else + return false; + #endif + } + ); + + py_module.def( + "_valgrind_toggle", [](){ + #if defined(USE_VALGRIND) + CALLGRIND_TOGGLE_COLLECT; + #else + TORCH_CHECK(false, "Valgrind is not supported."); + #endif + } + ); + #ifdef USE_CUDA PyObject *has_cuda = Py_True; #else @@ -836,7 +878,30 @@ Call this whenever a new thread is created in order to propagate values from ASSERT_TRUE(set_module_attr("_GLIBCXX_USE_CXX11_ABI", Py_False)); #endif - auto defaultGenerator = at::detail::getDefaultCPUGenerator(); +// See note [Pybind11 ABI constants] +#define SET_STR_DEFINE(name) \ + ASSERT_TRUE(set_module_attr("_" # name, THPUtils_packString(name))) + +#ifdef PYBIND11_COMPILER_TYPE + SET_STR_DEFINE(PYBIND11_COMPILER_TYPE); +#else + ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_COMPILER_TYPE), Py_None)); +#endif + +#ifdef PYBIND11_STDLIB + SET_STR_DEFINE(PYBIND11_STDLIB); +#else + ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_STDLIB), Py_None)); +#endif + +#ifdef PYBIND11_BUILD_ABI + SET_STR_DEFINE(PYBIND11_BUILD_ABI); +#else + ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_BUILD_ABI), Py_None)); +#endif +#undef SET_STR_DEFINE + + const auto& defaultGenerator = at::detail::getDefaultCPUGenerator(); THPDefaultCPUGenerator = (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator); // This reference is meant to be given away, so no need to incref here. ASSERT_TRUE(set_module_attr("default_generator", (PyObject*)THPDefaultCPUGenerator, /* incref= */ false)); diff --git a/torch/csrc/PtrWrapper.cpp b/torch/csrc/PtrWrapper.cpp deleted file mode 100644 index aa48c49949b9b..0000000000000 --- a/torch/csrc/PtrWrapper.cpp +++ /dev/null @@ -1,102 +0,0 @@ -#include -#include -#include - -static PyObject* THPWrapperClass = nullptr; - -struct THPWrapper { - PyObject_HEAD - void *data; - void (*destructor)(void*); -}; - -PyObject * THPWrapper_New(void *data, void (*destructor)(void*)) -{ - PyObject *args = PyTuple_New(0); - if (!args) { - return nullptr; - } - PyObject *result = PyObject_Call(THPWrapperClass, args, nullptr); - if (result) { - THPWrapper* wrapper = (THPWrapper*) result; - wrapper->data = data; - wrapper->destructor = destructor; - } - Py_DECREF(args); - return result; -} - -bool THPWrapper_check(PyObject * obj) -{ - return (PyObject*)Py_TYPE(obj) == THPWrapperClass; -} - -void * THPWrapper_get(PyObject * obj) -{ - return ((THPWrapper*)obj)->data; -} - -static PyObject * THPWrapper_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs) -{ - PyObject* self = type->tp_alloc(type, 0); - THPWrapper* wrapper = (THPWrapper*) self; - wrapper->data = nullptr; - wrapper->destructor = nullptr; - return self; -} - -static void THPWrapper_dealloc(THPWrapper* self) -{ - self->destructor(self->data); - Py_TYPE(self)->tp_free((PyObject*)self); -} - -PyTypeObject THPWrapperType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch._C._PtrWrapper", /* tp_name */ - sizeof(THPWrapper), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)THPWrapper_dealloc, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - nullptr, /* tp_methods */ - nullptr, /* tp_members */ - nullptr, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - THPWrapper_pynew, /* tp_new */ -}; - -bool THPWrapper_init(PyObject *module) -{ - THPWrapperClass = (PyObject*)&THPWrapperType; - if (PyType_Ready(&THPWrapperType) < 0) - return false; - Py_INCREF(&THPWrapperType); - return true; -} diff --git a/torch/csrc/PtrWrapper.h b/torch/csrc/PtrWrapper.h deleted file mode 100644 index 985193c74c9b5..0000000000000 --- a/torch/csrc/PtrWrapper.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef THP_PTR_WRAPPER_H -#define THP_PTR_WRAPPER_H - -#include - -/** - * Python wrapper around arbitrary opaque C++ class - */ - -bool THPWrapper_init(PyObject *module); - -PyObject * THPWrapper_New(void *data, void (*destructor)(void*)); -void * THPWrapper_get(PyObject * obj); -bool THPWrapper_check(PyObject * obj); - -#endif diff --git a/torch/csrc/QScheme.cpp b/torch/csrc/QScheme.cpp index 274ccf890df3c..77036d460984c 100644 --- a/torch/csrc/QScheme.cpp +++ b/torch/csrc/QScheme.cpp @@ -22,12 +22,13 @@ PyObject *THPQScheme_New(at::QScheme qscheme, const std::string& name) return self.release(); } -PyObject *THPQScheme_reduce(THPQScheme *self, PyObject *noargs) { +PyObject *THPQScheme_reduce(PyObject *_self, PyObject *noargs) { + auto self = (THPQScheme*)_self; return THPUtils_packString(self->name); } static PyMethodDef THPQScheme_methods[] = { - {"__reduce__", (PyCFunction)THPQScheme_reduce, METH_NOARGS, nullptr}, + {"__reduce__", THPQScheme_reduce, METH_NOARGS, nullptr}, {nullptr} /* Sentinel */ }; diff --git a/torch/csrc/Size.cpp b/torch/csrc/Size.cpp index 6df1b33d09819..ac88501358bdd 100644 --- a/torch/csrc/Size.cpp +++ b/torch/csrc/Size.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -84,7 +85,7 @@ static PyObject * THPSize_repr(THPSize *self) if (i != 0) { repr += ", "; } - repr += std::to_string(PyLong_AsLong(PyTuple_GET_ITEM(self, i))); + repr += std::to_string(THPUtils_unpackLong(PyTuple_GET_ITEM(self, i))); } repr += "])"; return THPUtils_packString(repr); @@ -130,20 +131,22 @@ static PyMappingMethods THPSize_as_mapping = { nullptr }; -static PyObject *THPSize_numel(THPSize *self, PyObject *noargs) +static PyObject *THPSize_numel(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS + auto self = (THPSize*)_self; int64_t numel = 1; for (Py_ssize_t i = 0; i < PyTuple_Size((PyObject*)self); ++i) { - numel *= PyLong_AsLong(PyTuple_GET_ITEM(self, i)); + numel *= THPUtils_unpackLong(PyTuple_GET_ITEM(self, i)); } return THPUtils_packInt64(numel); END_HANDLE_TH_ERRORS } -static PyObject *THPSize_reduce(THPSize *self, PyObject *noargs) +static PyObject *THPSize_reduce(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS + auto self = (THPSize*)_self; auto ret = THPObjectPtr{PyTuple_New(2)}; if (!ret) throw python_error(); @@ -168,8 +171,8 @@ static PyObject *THPSize_reduce(THPSize *self, PyObject *noargs) } static PyMethodDef THPSize_methods[] = { - {"numel", (PyCFunction)THPSize_numel, METH_NOARGS, nullptr}, - {"__reduce__", (PyCFunction)THPSize_reduce, METH_NOARGS, nullptr}, + {"numel", THPSize_numel, METH_NOARGS, nullptr}, + {"__reduce__", THPSize_reduce, METH_NOARGS, nullptr}, {nullptr} }; diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index 23d642c1b7844..30f109e7cef5a 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -1,6 +1,6 @@ #include #ifdef _MSC_VER -#include +#include #endif #include diff --git a/torch/csrc/Storage.h b/torch/csrc/Storage.h index 5e708f2b4f2dc..e7c8bfdbe4f26 100644 --- a/torch/csrc/Storage.h +++ b/torch/csrc/Storage.h @@ -35,7 +35,8 @@ PyObject_IsInstance(obj, THPComplexDoubleStorageClass) #define THPComplexFloatStorage_Check(obj) \ PyObject_IsInstance(obj, THPComplexFloatStorageClass) - +#define THPQUInt4x2Storage_Check(obj) \ + PyObject_IsInstance(obj, THPQUInt8StorageClass) #define THPDoubleStorage_CData(obj) (obj)->cdata #define THPFloatStorage_CData(obj) (obj)->cdata @@ -52,6 +53,7 @@ #define THPBFloat16Storage_CData(obj) (obj)->cdata #define THPComplexDoubleStorage_CData(obj) (obj)->cdata #define THPComplexFloatStorage_CData(obj) (obj)->cdata +#define THPQUInt4x2Storage_CData(obj) (obj)->cdata #define THPStorageType TH_CONCAT_3(THP,Real,StorageType) #define THPStorageBaseStr TH_CONCAT_STRING_2(Real,StorageBase) diff --git a/torch/csrc/Stream.cpp b/torch/csrc/Stream.cpp new file mode 100644 index 0000000000000..8e2458faf6f44 --- /dev/null +++ b/torch/csrc/Stream.cpp @@ -0,0 +1,117 @@ +#include +#include +#include +#include + +#include + +PyTypeObject *THPStreamClass = nullptr; + +static PyObject* THPStream_pynew( + PyTypeObject *type, PyObject *args, PyObject *kwargs) { + HANDLE_TH_ERRORS + uint64_t cdata = 0; + static char *kwlist[] = {"_cdata", nullptr}; + if (!PyArg_ParseTupleAndKeywords( + args, kwargs, "|K", kwlist, &cdata)) { + return nullptr; + } + + THPObjectPtr ptr(type->tp_alloc(type, 0)); + if (!ptr) { + return nullptr; + } + + THPStream* self = (THPStream *)ptr.get(); + self->cdata = cdata; + return (PyObject *)ptr.release(); + END_HANDLE_TH_ERRORS +} + +static void THPStream_dealloc(THPStream *self) { + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static PyObject * THPStream_get_device(THPStream *self, void *unused) { + HANDLE_TH_ERRORS + return THPDevice_New(c10::Stream::unpack(self->cdata).device()); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPStream_eq(THPStream *self, THPStream *other) { + HANDLE_TH_ERRORS + return PyBool_FromLong(self->cdata == other->cdata); + END_HANDLE_TH_ERRORS +} + +static struct PyMemberDef THPStream_members[] = { + {(char*)"_cdata", + T_ULONGLONG, offsetof(THPStream, cdata), READONLY, nullptr}, + {nullptr} +}; + +static struct PyGetSetDef THPStream_properties[] = { + {"device", (getter)THPStream_get_device, nullptr, nullptr, nullptr}, + {nullptr} +}; + +static PyMethodDef THPStream_methods[] = { + {(char*)"__eq__", (PyCFunction)THPStream_eq, METH_O, nullptr}, + {nullptr} +}; + +PyTypeObject THPStreamType = { + PyVarObject_HEAD_INIT(nullptr, 0) + "torch.Stream", /* tp_name */ + sizeof(THPStream), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)THPStream_dealloc, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + nullptr, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + THPStream_methods, /* tp_methods */ + THPStream_members, /* tp_members */ + THPStream_properties, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + THPStream_pynew, /* tp_new */ +}; + + +void THPStream_init(PyObject *module) +{ + THPStreamClass = &THPStreamType; + Py_TYPE(&THPStreamType) = &PyType_Type; + if (PyType_Ready(&THPStreamType) < 0) { + throw python_error(); + } + Py_INCREF(&THPStreamType); + if (PyModule_AddObject( + module, "Stream", (PyObject *)&THPStreamType) < 0) { + throw python_error(); + } +} diff --git a/torch/csrc/Stream.h b/torch/csrc/Stream.h new file mode 100644 index 0000000000000..8d507977e12aa --- /dev/null +++ b/torch/csrc/Stream.h @@ -0,0 +1,18 @@ +#ifndef THP_STREAM_INC +#define THP_STREAM_INC + +#include + +struct THPStream { + PyObject_HEAD + uint64_t cdata; +}; +extern PyTypeObject *THPStreamClass; + +void THPStream_init(PyObject *module); + +inline bool THPStream_Check(PyObject* obj) { + return THPStreamClass && PyObject_IsInstance(obj, (PyObject*)THPStreamClass); +} + +#endif // THP_STREAM_INC diff --git a/torch/csrc/THP.h b/torch/csrc/THP.h index edf4621765f8b..26f6c06b3d20b 100644 --- a/torch/csrc/THP.h +++ b/torch/csrc/THP.h @@ -31,7 +31,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/WindowsTorchApiMacro.h b/torch/csrc/WindowsTorchApiMacro.h index 7f8ef4e016777..44ae3b3b8180c 100644 --- a/torch/csrc/WindowsTorchApiMacro.h +++ b/torch/csrc/WindowsTorchApiMacro.h @@ -2,6 +2,8 @@ #include -// There's no difference between aten, torch and caffe2 libs any more -// TODO: clean up the naming for consistency -#define TORCH_API CAFFE2_API +#ifdef _WIN32 +#define TORCH_PYTHON_API +#else +#define TORCH_PYTHON_API TORCH_API +#endif diff --git a/torch/csrc/api/include/torch/all.h b/torch/csrc/api/include/torch/all.h index 5bcc8eec93abd..5717bccf6017e 100644 --- a/torch/csrc/api/include/torch/all.h +++ b/torch/csrc/api/include/torch/all.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include diff --git a/torch/csrc/api/include/torch/cuda.h b/torch/csrc/api/include/torch/cuda.h index 5f6f2a9eb8a93..a7e063b90af97 100644 --- a/torch/csrc/api/include/torch/cuda.h +++ b/torch/csrc/api/include/torch/cuda.h @@ -23,5 +23,8 @@ void TORCH_API manual_seed(uint64_t seed); /// Sets the seed for all available GPUs. void TORCH_API manual_seed_all(uint64_t seed); +/// Waits for all kernels in all streams on a CUDA device to complete. +void TORCH_API synchronize(int64_t device_index = -1); + } // namespace cuda } // namespace torch diff --git a/torch/csrc/api/include/torch/fft.h b/torch/csrc/api/include/torch/fft.h index 1c119ed752268..a176d37d0e8be 100644 --- a/torch/csrc/api/include/torch/fft.h +++ b/torch/csrc/api/include/torch/fft.h @@ -35,6 +35,36 @@ inline Tensor ifft(const Tensor& self, return torch::fft_ifft(self, n, dim, norm); } +/// Computes the 2-dimensional fast Fourier transform over the given dimensions. +/// See https://pytorch.org/docs/master/fft.html#torch.fft.fft2. +/// +/// Example: +/// ``` +/// auto t = torch::randn({128, 128}, dtype=kComplexDouble); +/// torch::fft::fft2(t); +/// ``` +inline Tensor fft2(const Tensor& self, + c10::optional s=c10::nullopt, + IntArrayRef dim={-2, -1}, + c10::optional norm=c10::nullopt) { + return torch::fft_fft2(self, s, dim, norm); +} + +/// Computes the inverse of torch.fft.fft2 +/// See https://pytorch.org/docs/master/fft.html#torch.fft.ifft2. +/// +/// Example: +/// ``` +/// auto t = torch::randn({128, 128}, dtype=kComplexDouble); +/// torch::fft::ifft2(t); +/// ``` +inline Tensor ifft2(const Tensor& self, + c10::optional s=c10::nullopt, + IntArrayRef dim={-2, -1}, + c10::optional norm=c10::nullopt) { + return torch::fft_ifft2(self, s, dim, norm); +} + /// Computes the N dimensional fast Fourier transform over given dimensions. /// See https://pytorch.org/docs/master/fft.html#torch.fft.fftn. /// @@ -99,6 +129,36 @@ inline Tensor irfft(const Tensor& self, return torch::fft_irfft(self, n, dim, norm); } +/// Computes the 2-dimensional FFT of real input. Returns a onesided Hermitian output. +/// See https://pytorch.org/docs/master/fft.html#torch.fft.rfft2 +/// +/// Example: +/// ``` +/// auto t = torch::randn({128, 128}, dtype=kDouble); +/// torch::fft::rfft2(t); +/// ``` +inline Tensor rfft2(const Tensor& self, + c10::optional s=c10::nullopt, + IntArrayRef dim={-2, -1}, + c10::optional norm=c10::nullopt) { + return torch::fft_rfft2(self, s, dim, norm); +} + +/// Computes the inverse of torch.fft.rfft2. +/// See https://pytorch.org/docs/master/fft.html#torch.fft.irfft2. +/// +/// Example: +/// ``` +/// auto t = torch::randn({128, 128}, dtype=kComplexDouble); +/// torch::fft::irfft2(t); +/// ``` +inline Tensor irfft2(const Tensor& self, + c10::optional s=c10::nullopt, + IntArrayRef dim={-2, -1}, + c10::optional norm=c10::nullopt) { + return torch::fft_irfft2(self, s, dim, norm); +} + /// Computes the N dimensional FFT of real input with onesided Hermitian output. /// See https://pytorch.org/docs/master/fft.html#torch.fft.rfftn /// @@ -166,4 +226,66 @@ inline Tensor ihfft(const Tensor& self, return torch::fft_ihfft(self, n, dim, norm); } +/// Computes the discrete Fourier Transform sample frequencies for a signal of size n. +/// +/// See https://pytorch.org/docs/master/fft.html#torch.fft.fftfreq +/// +/// Example: +/// ``` +/// auto frequencies = torch::fft::fftfreq(128, torch::kDouble); +/// ``` +inline Tensor fftfreq(int64_t n, double d, const TensorOptions& options={}) { + return torch::fft_fftfreq(n, d, options); +} + +inline Tensor fftfreq(int64_t n, const TensorOptions& options={}) { + return torch::fft_fftfreq(n, /*d=*/1.0, options); +} + +/// Computes the sample frequencies for torch.fft.rfft with a signal of size n. +/// +/// Like torch.fft.rfft, only the positive frequencies are included. +/// See https://pytorch.org/docs/master/fft.html#torch.fft.rfftfreq +/// +/// Example: +/// ``` +/// auto frequencies = torch::fft::rfftfreq(128, torch::kDouble); +/// ``` +inline Tensor rfftfreq(int64_t n, double d, const TensorOptions& options) { + return torch::fft_rfftfreq(n, d, options); +} + +inline Tensor rfftfreq(int64_t n, const TensorOptions& options) { + return torch::fft_rfftfreq(n, /*d=*/1.0, options); +} + +/// Reorders n-dimensional FFT output to have negative frequency terms first, by +/// a torch.roll operation. +/// +/// See https://pytorch.org/docs/master/fft.html#torch.fft.fftshift +/// +/// Example: +/// ``` +/// auto x = torch::randn({127, 4}); +/// auto centred_fft = torch::fft::fftshift(torch::fft::fftn(x)); +/// ``` +inline Tensor fftshift(const Tensor& x, c10::optional dim=c10::nullopt) { + return torch::fft_fftshift(x, dim); +} + +/// Inverse of torch.fft.fftshift +/// +/// See https://pytorch.org/docs/master/fft.html#torch.fft.ifftshift +/// +/// Example: +/// ``` +/// auto x = torch::randn({127, 4}); +/// auto shift = torch::fft::fftshift(x) +/// auto unshift = torch::fft::ifftshift(shift); +/// assert(torch::allclose(x, unshift)); +/// ``` +inline Tensor ifftshift(const Tensor& x, c10::optional dim=c10::nullopt) { + return torch::fft_ifftshift(x, dim); +} + }} // torch::fft diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h index 5ce90dcc972e5..26d9e66a1f6b3 100644 --- a/torch/csrc/api/include/torch/linalg.h +++ b/torch/csrc/api/include/torch/linalg.h @@ -8,10 +8,42 @@ namespace linalg { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { +inline Tensor cholesky(const Tensor& self) { + return torch::linalg_cholesky(self); +} + +inline Tensor cholesky_out(Tensor& result, const Tensor& self) { + return torch::linalg_cholesky_out(result, self); +} + inline Tensor det(const Tensor& self) { return torch::linalg_det(self); } +inline std::tuple slogdet(const Tensor& input) { + return torch::linalg_slogdet(input); +} + +inline std::tuple slogdet_out(Tensor& sign, Tensor& logabsdet, const Tensor& input) { + return torch::linalg_slogdet_out(sign, logabsdet, input); +} + +inline std::tuple eigh(const Tensor& self, std::string uplo) { + return torch::linalg_eigh(self, uplo); +} + +inline std::tuple eigh_out(Tensor& eigvals, Tensor& eigvecs, const Tensor& self, std::string uplo) { + return torch::linalg_eigh_out(eigvals, eigvecs, self, uplo); +} + +inline Tensor eigvalsh(const Tensor& self, std::string uplo) { + return torch::linalg_eigvalsh(self, uplo); +} + +inline Tensor& eigvalsh_out(Tensor& result, const Tensor& self, std::string uplo) { + return torch::linalg_eigvalsh_out(result, self, uplo); +} + inline Tensor norm(const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { return torch::linalg_norm(self, opt_ord, opt_dim, keepdim, opt_dtype); } @@ -28,15 +60,114 @@ inline Tensor& norm_out(Tensor& result, const Tensor& self, std::string ord, opt return torch::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); } +inline Tensor matrix_rank(const Tensor input, optional tol, bool hermitian) { + return torch::linalg_matrix_rank(input, tol, hermitian); +} + +inline Tensor& matrix_rank_out(Tensor& result, const Tensor input, optional tol, bool hermitian) { + return torch::linalg_matrix_rank_out(result, input, tol, hermitian); +} + +inline Tensor pinv(const Tensor& input, double rcond, bool hermitian) { + return torch::linalg_pinv(input, rcond, hermitian); +} + +inline Tensor& pinv_out(Tensor& result, const Tensor& input, double rcond, bool hermitian) { + return torch::linalg_pinv_out(result, input, rcond, hermitian); +} + +inline Tensor solve(const Tensor& input, const Tensor& other) { + return torch::linalg_solve(input, other); +} + +inline Tensor& solve_out(Tensor& result, const Tensor& input, const Tensor& other) { + return torch::linalg_solve_out(result, input, other); +} + +inline Tensor tensorinv(const Tensor& self, int64_t ind) { + return torch::linalg_tensorinv(self, ind); +} + +inline Tensor& tensorinv_out(Tensor& result,const Tensor& self, int64_t ind) { + return torch::linalg_tensorinv_out(result, self, ind); +} + +inline Tensor tensorsolve(const Tensor& self, const Tensor& other, optional dims) { + return torch::linalg_tensorsolve(self, other, dims); +} + +inline Tensor& tensorsolve_out(Tensor& result, const Tensor& self, const Tensor& other, optional dims) { + return torch::linalg_tensorsolve_out(result, self, other, dims); +} + +inline Tensor inv(const Tensor& input) { + return torch::linalg_inv(input); +} + +inline Tensor& inv_out(Tensor& result, const Tensor& input) { + return torch::linalg_inv_out(result, input); +} + } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ +/// Cholesky decomposition +/// +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.cholesky +/// +/// Example: +/// ``` +/// auto A = torch::randn({4, 4}); +/// auto A = torch::matmul(A, A.t()); +/// auto L = torch::linalg::cholesky(A); +/// assert(torch::allclose(torch::matmul(L, L.t()), A)); +/// ``` +inline Tensor cholesky(const Tensor& self) { + return detail::cholesky(self); +} + +inline Tensor cholesky_out(Tensor& result, const Tensor& self) { + return detail::cholesky_out(result, self); +} /// See the documentation of torch.linalg.det inline Tensor linalg_det(const Tensor& self) { return detail::det(self); } +/// Computes the sign and (natural) logarithm of the determinant +/// +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.slogdet +inline std::tuple slogdet(const Tensor& input) { + return detail::slogdet(input); +} + +inline std::tuple slogdet_out(Tensor& sign, Tensor& logabsdet, const Tensor& input) { + return detail::slogdet_out(sign, logabsdet, input); +} + +/// Computes eigenvalues and eigenvectors +/// +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.eigh +inline std::tuple eigh(const Tensor& self, std::string uplo) { + return detail::eigh(self, uplo); +} + +inline std::tuple eigh_out(Tensor& eigvals, Tensor& eigvecs, const Tensor& self, std::string uplo) { + return detail::eigh_out(eigvals, eigvecs, self, uplo); +} + +/// Computes eigenvalues +/// +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.eigvalsh +inline Tensor eigvalsh(const Tensor& self, std::string uplo) { + return detail::eigvalsh(self, uplo); +} + +inline Tensor& eigvalsh_out(Tensor& result, const Tensor& self, std::string uplo) { + return detail::eigvalsh_out(result, self, uplo); +} + inline Tensor linalg_norm(const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { return detail::norm(self, opt_ord, opt_dim, keepdim, opt_dtype); } @@ -53,4 +184,82 @@ inline Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string o return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); } +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.matrix_rank +inline Tensor matrix_rank(const Tensor input, optional tol, bool hermitian) { + return detail::matrix_rank(input, tol, hermitian); +} + +inline Tensor& matrix_rank_out(Tensor& result, const Tensor input, optional tol, bool hermitian) { + return detail::matrix_rank_out(result, input, tol, hermitian); +} + +/// Computes pseudo-inverse +/// +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.pinv +inline Tensor pinv(const Tensor& input, double rcond=1e-15, bool hermitian=false) { + return detail::pinv(input, rcond, hermitian); +} + +inline Tensor& pinv_out(Tensor& result, const Tensor& input, double rcond=1e-15, bool hermitian=false) { + return detail::pinv_out(result, input, rcond, hermitian); +} + +/// Computes a tensor `x` such that `matmul(input, x) = other`. +/// +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.solve +inline Tensor solve(const Tensor& input, const Tensor& other) { + return detail::solve(input, other); +} + +inline Tensor& solve_out(Tensor& result, const Tensor& input, const Tensor& other) { + return detail::solve_out(result, input, other); +} + +/// Computes the inverse of a tensor +/// +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.tensorinv +/// +/// Example: +/// ``` +/// auto a = torch::eye(4*6).reshape({4, 6, 8, 3}); +/// int64_t ind = 2; +/// auto ainv = torch::linalg::tensorinv(a, ind); +/// ``` +inline Tensor tensorinv(const Tensor& self, int64_t ind) { + return detail::tensorinv(self, ind); +} + +inline Tensor& tensorinv_out(Tensor& result, const Tensor& self, int64_t ind) { + return detail::tensorinv_out(result, self, ind); +} + +/// Computes a tensor `x` such that `tensordot(input, x, dims=x.dim()) = other`. +/// +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.tensorsolve +/// +/// Example: +/// ``` +/// auto a = torch::eye(2*3*4).reshape({2*3, 4, 2, 3, 4}); +/// auto b = torch::randn(2*3, 4); +/// auto x = torch::linalg::tensorsolve(a, b); +/// ``` +inline Tensor tensorsolve(const Tensor& input, const Tensor& other, optional dims) { + return detail::tensorsolve(input, other, dims); +} + +inline Tensor& tensorsolve_out(Tensor& result, const Tensor& input, const Tensor& other, optional dims) { + return detail::tensorsolve_out(result, input, other, dims); +} + +/// Computes a tensor `inverse_input` such that `dot(input, inverse_input) = eye(input.size(0))`. +/// +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.inv +inline Tensor inv(const Tensor& input) { + return detail::inv(input); +} + +inline Tensor& inv_out(Tensor& result, const Tensor& input) { + return detail::inv_out(result, input); +} + }} // torch::linalg diff --git a/torch/csrc/api/include/torch/nn/functional/loss.h b/torch/csrc/api/include/torch/nn/functional/loss.h index b5a06f4cfb146..6ed3c37311c09 100644 --- a/torch/csrc/api/include/torch/nn/functional/loss.h +++ b/torch/csrc/api/include/torch/nn/functional/loss.h @@ -307,9 +307,9 @@ inline Tensor cosine_embedding_loss( // ============================================================================ -inline Tensor _smooth_l1_loss(const Tensor& input, const Tensor& target) { +inline Tensor _smooth_l1_loss(const Tensor& input, const Tensor& target, double beta = 1.) { auto t = torch::abs(input - target); - return torch::where(t < 1, 0.5 * torch::pow(t, 2), t - 0.5); + return torch::where(t < beta, 0.5 * torch::pow(t, 2) / beta, t - 0.5 * beta); } #ifndef DOXYGEN_SHOULD_SKIP_THIS @@ -317,7 +317,8 @@ namespace detail { inline Tensor smooth_l1_loss( const Tensor& input, const Tensor& target, - SmoothL1LossFuncOptions::reduction_t reduction) { + SmoothL1LossFuncOptions::reduction_t reduction, + double beta = 1.) { if (target.sizes() != input.sizes()) { TORCH_WARN("Using a target size (", target.sizes(), ") that is different to the input size (", input.sizes(), "). ", "This will likely lead to incorrect results due to broadcasting. ", @@ -325,7 +326,7 @@ inline Tensor smooth_l1_loss( } std::vector expanded_tensors = torch::broadcast_tensors({input, target}); - return torch::smooth_l1_loss(expanded_tensors[0], expanded_tensors[1], enumtype::reduction_get_enum(reduction)); + return torch::smooth_l1_loss(expanded_tensors[0], expanded_tensors[1], enumtype::reduction_get_enum(reduction), beta); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ @@ -344,8 +345,9 @@ inline Tensor smooth_l1_loss( inline Tensor smooth_l1_loss( const Tensor& input, const Tensor& target, - const SmoothL1LossFuncOptions& options = {}) { - return detail::smooth_l1_loss(input, target, options.reduction()); + const SmoothL1LossFuncOptions& options = {}, + double beta = 1.) { + return detail::smooth_l1_loss(input, target, options.reduction(), beta); } // ============================================================================ @@ -525,6 +527,85 @@ inline Tensor triplet_margin_loss( // ============================================================================ +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor triplet_margin_with_distance_loss( + const Tensor& anchor, + const Tensor& positive, + const Tensor& negative, + c10::optional distance_function, + double margin, + bool swap, + TripletMarginWithDistanceLossFuncOptions::reduction_t reduction) { + Tensor dist_pos, dist_neg; + if (distance_function.has_value()) { + auto distance_function_impl = distance_function.value(); + dist_pos = distance_function_impl(anchor, positive); + dist_neg = distance_function_impl(anchor, negative); + } else { + dist_pos = pairwise_distance(anchor, positive); + dist_neg = pairwise_distance(anchor, negative); + } + + if (swap) { + Tensor dist_swap; + if (distance_function.has_value()) { + dist_swap = distance_function.value()(positive, negative); + } else { + dist_swap = pairwise_distance(positive, negative); + } + dist_neg = torch::min(dist_neg, dist_swap); + } + + auto loss = torch::clamp_min(dist_pos - dist_neg + margin, 0); + + Tensor ret; + if (c10::get_if(&reduction)) { + ret = loss; + } else if (c10::get_if(&reduction)) { + ret = loss.mean(); + } else if (c10::get_if(&reduction)) { + ret = loss.sum(); + } else { + ret = anchor; + TORCH_INTERNAL_ASSERT( + false, + enumtype::get_enum_name(reduction), + " is not valid"); + } + return ret; +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.triplet_margin_with_distance_loss +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::TripletMarginWithDistanceLossFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::triplet_margin_with_distance_loss(anchor, positive, negative, F::TripletMarginWithDistanceLossFuncOptions().margin(1.0)); +/// ``` +inline Tensor triplet_margin_with_distance_loss( + const Tensor& anchor, + const Tensor& positive, + const Tensor& negative, + const TripletMarginWithDistanceLossFuncOptions& options = {}) { + return detail::triplet_margin_with_distance_loss( + anchor, + positive, + negative, + options.distance_function(), + options.margin(), + options.swap(), + options.reduction()); +} + +// ============================================================================ + #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { inline Tensor ctc_loss(const Tensor& log_probs, diff --git a/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h b/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h index 7ea98bf07d99a..32161d04d806f 100644 --- a/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h +++ b/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h @@ -16,6 +16,10 @@ inline Tensor pixel_shuffle( upscale_factor ); } + +inline Tensor pixel_unshuffle(const Tensor& input, int64_t downscale_factor) { + return torch::pixel_unshuffle(input, downscale_factor); +} } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ @@ -36,6 +40,12 @@ inline Tensor pixel_shuffle( return detail::pixel_shuffle(input, options.upscale_factor()); } +inline Tensor pixel_unshuffle( + const Tensor& input, + const PixelUnshuffleFuncOptions& options) { + return detail::pixel_unshuffle(input, options.downscale_factor()); +} + } // namespace functional } // namespace nn } // namespace torch diff --git a/torch/csrc/api/include/torch/nn/functional/vision.h b/torch/csrc/api/include/torch/nn/functional/vision.h index e1041cb21d8c4..1fe084d02c79d 100644 --- a/torch/csrc/api/include/torch/nn/functional/vision.h +++ b/torch/csrc/api/include/torch/nn/functional/vision.h @@ -61,8 +61,10 @@ inline Tensor grid_sample( if (c10::get_if(&mode)) { mode_enum = 0; - } else { /// mode == 'nearest' + } else if (c10::get_if(&mode)) { mode_enum = 1; + } else { /// mode == 'bicubic' + mode_enum = 2; } if (c10::get_if(&padding_mode)) { diff --git a/torch/csrc/api/include/torch/nn/modules.h b/torch/csrc/api/include/torch/nn/modules.h index 863fe0b0c9ea2..cbb0f9bc10ddf 100644 --- a/torch/csrc/api/include/torch/nn/modules.h +++ b/torch/csrc/api/include/torch/nn/modules.h @@ -6,6 +6,7 @@ // Containers #include #include +#include #include #include #include diff --git a/torch/csrc/api/include/torch/nn/modules/container/moduledict.h b/torch/csrc/api/include/torch/nn/modules/container/moduledict.h new file mode 100644 index 0000000000000..67fb343fb6f38 --- /dev/null +++ b/torch/csrc/api/include/torch/nn/modules/container/moduledict.h @@ -0,0 +1,243 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch { +namespace nn { + +/// An OrderedDict of `Module`s that registers its elements by their `key`s. +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::OrderedDict> ordereddict = { +/// {"linear", Linear(10, 3).ptr()}, +/// {"conv", Conv2d(1, 2, 3).ptr()}, +/// {"dropout", Dropout(0.5).ptr()}, +/// }; +/// torch::nn::ModuleDict dict1(ordereddict); +/// +/// for (const auto &module : *dict1) { +/// module->pretty_print(std::cout); +/// } +/// +/// std::vector>> list = { +/// {"linear", Linear(10, 3).ptr()}, +/// {"conv", Conv2d(1, 2, 3).ptr()}, +/// {"dropout", Dropout(0.5).ptr()}, +/// }; +/// torch::nn::ModuleDict dict2(list); +/// +/// for (const auto &module : *dict2) { +/// module->pretty_print(std::cout); +/// } +/// +/// \endrst +/// +/// Why should you use `ModuleDict` instead of a simple `map` or `OrderedDict`? +/// The value a `ModuleDict` provides over manually calling an ordered map of +/// modules is that it allows treating the whole container *as a single module*, +/// such that performing a transformation on the `ModuleDict` applies to each of the +/// modules it stores (which are each a registered submodule of the `ModuleDict`). +/// For example, calling `.to(torch::kCUDA)` on a `ModuleDict` will move each module +/// in the map to CUDA memory. For example: +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::OrderedDict> ordereddict = { +/// {"linear", Linear(10, 3).ptr()}, +/// {"conv", Conv2d(1, 2, 3).ptr()}, +/// {"dropout", Dropout(0.5).ptr()}, +/// }; +/// torch::nn::ModuleDict dict(ordereddict); +/// +/// // Convert all modules to CUDA. +/// dict->to(torch::kCUDA); +/// +/// \endrst +/// +/// Finally, `ModuleDict` provides a lightweight container API, such as allowing +/// iteration over submodules, positional access, adding new modules from a vector +/// of key-module pairs or an `OrderedDict` or another `ModuleDict` after +/// construction via `update`. +class ModuleDictImpl : public Cloneable { + public: + using Iterator = torch::OrderedDict>::Iterator; + using ConstIterator = torch::OrderedDict>::ConstIterator; + + ModuleDictImpl() = default; + + /// Constructs the `ModuleDict` from a list of string-Module pairs. + explicit ModuleDictImpl( + const std::vector>>& modules) { + update(modules); + } + + /// Constructs the `ModuleDict` from an `OrderedDict`. + explicit ModuleDictImpl( + const torch::OrderedDict>& modules) { + update(modules); + } + + /// Return the items in the `ModuleDict`. + std::vector>> items() const { + return modules_.pairs(); + } + + /// Return the keys in the `ModuleDict`. + std::vector keys() const { + return modules_.keys(); + } + + /// Return the values in the `ModuleDict`. + std::vector> values() const { + return modules_.values(); + } + + /// Return an iterator to the start of `ModuleDict`. + Iterator begin() { + return modules_.begin(); + } + + /// Return a const iterator to the start of `ModuleDict`. + ConstIterator begin() const { + return modules_.begin(); + } + + /// Return an iterator to the end of `ModuleDict`. + Iterator end() { + return modules_.end(); + } + + /// Return a const iterator to the end of `ModuleDict`. + ConstIterator end() const { + return modules_.end(); + } + + /// Return the number of items currently stored in the `ModuleDict`. + size_t size() const noexcept { + return modules_.size(); + } + + /// Return true if the `ModuleDict` is empty, otherwise return false. + bool empty() const noexcept { + return modules_.is_empty(); + } + + /// Check if the centain parameter with the key in the `ModuleDict`. + bool contains(const std::string& key) const noexcept { + return modules_.contains(key); + } + + /// Remove all items from the `ModuleDict`. + void clear() { + // Not remove the registration of modules to make it consistent with python version. + modules_.clear(); + } + + /// Special cloning function for `ModuleDict` because it does not use + /// `reset()`. + std::shared_ptr clone( + const optional& device = nullopt) const override { + auto clone = std::make_shared(); + for (const auto& module : modules_) { + clone->insert(module.key(), module.value()->clone(device)); + } + return clone; + } + + /// `reset()` is empty for `ModuleDict`, since it does not have parameters of + /// its own. + void reset() override {} + + /// Pretty prints the `ModuleDict` into the given `stream`. + void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::ModuleDict"; + } + + /// Attempts to returns the `Module` associated with the given `key`. Throws + /// an exception if no such `key` is stored in the `ModuleDict`. Check + /// contains(key) before for a non-throwing way of access. + std::shared_ptr operator[](const std::string& key) const { + return modules_[key]; + } + + /// Attempts to return the module at the given key as the requested type. + /// Throws an exception if no such `key` is stored in the `ModuleDict`. + /// Check contains(key) before for a non-throwing way of access. + template + T& at(const std::string& key) { + static_assert( + torch::detail::is_module::value, + "Can only call ModuleList::at with an nn::Module type"); + return *modules_[key]->as(); + } + + /// Attempts to return the module at the given key as the requested type. + /// Throws an exception if no such `key` is stored in the `ModuleDict`. + /// Check contains(key) before for a non-throwing way of access. + template + const T& at(const std::string& key) const { + static_assert( + torch::detail::is_module::value, + "Can only call ModuleList::at with an nn::Module type"); + return *modules_[key]->as(); + } + + /// Removes and returns the `Module` associated with the given `key`. + /// Throws an exception if no such `key` is stored in the `ModuleDict`. + /// Check contains(key) before for a non-throwing way of access. + std::shared_ptr pop(const std::string& key) { + auto module = modules_[key]; + modules_.erase(key); + // Not remove the registration of the module to make it consistent with python version. + return module; + } + + /// Updated the `ModuleDict` with a vector of key-module pairs. + void update( + const std::vector>>& modules) { + for (auto& item : modules) { + insert(item.first, item.second); + } + } + + /// Updated the `ModuleDict` with key-value pairs from `OrderedDict` or `ModuleDict`. + template + void update(const Container& container) { + for (auto& item : container) { + insert(item.key(), item.value()); + } + } + +private: + /// Private `OrderedDict` holding the key-Module pairs. + torch::OrderedDict> modules_; + + /// Insert a key-module pair by overwriting existing keys, + /// and register or replace the `Module`. + void insert(const std::string& key, std::shared_ptr module) { + if (contains(key)) { + modules_[key] = std::move(module); + replace_module(key, modules_[key]); + } + else { + modules_.insert(key, std::move(module)); + register_module(key, modules_.back().value()); + } + } + +}; + +/// A `ModuleHolder` subclass for `ModuleDictImpl`. +/// See the documentation for `ModuleDictImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(ModuleDict); + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h b/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h index 0b4c83eb3d55d..472d8e827585a 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h +++ b/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h @@ -57,7 +57,7 @@ class ParameterListImpl : public Cloneable { void append(const torch::Tensor& param) { bool requires_grad = param.requires_grad(); register_parameter( - c10::to_string(parameters_.size()), std::move(param), requires_grad); + c10::to_string(parameters_.size()), param, requires_grad); } /// push the a given parameter at the end of the list diff --git a/torch/csrc/api/include/torch/nn/modules/container/sequential.h b/torch/csrc/api/include/torch/nn/modules/container/sequential.h index b7e3ddd2ba376..4b0fb65a92279 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/sequential.h +++ b/torch/csrc/api/include/torch/nn/modules/container/sequential.h @@ -107,7 +107,7 @@ class SequentialImpl : public Cloneable { explicit SequentialImpl(torch::OrderedDict&& ordered_dict) { modules_.reserve(ordered_dict.size()); for (auto& item : ordered_dict) { - push_back(std::move(item.key()), std::move(item.value())); + push_back(item.key(), std::move(item.value())); } } diff --git a/torch/csrc/api/include/torch/nn/modules/loss.h b/torch/csrc/api/include/torch/nn/modules/loss.h index d136f9cb7ee99..8c93088648422 100644 --- a/torch/csrc/api/include/torch/nn/modules/loss.h +++ b/torch/csrc/api/include/torch/nn/modules/loss.h @@ -309,7 +309,7 @@ struct TORCH_API SmoothL1LossImpl : public Cloneable { TORCH_MODULE(SmoothL1Loss); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelMarginLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - + /// Creates a criterion that optimizes a multi-class multi-classification /// hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) /// and output :math:`y` (which is a 2D `Tensor` of target class indices). @@ -421,9 +421,9 @@ TORCH_MODULE(MultiLabelSoftMarginLoss); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Creates a criterion that measures the triplet loss given an input -/// tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater +/// tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater /// than :math:`0`. This is used for measuring a relative similarity between -/// samples. A triplet is composed by `a`, `p` and `n` (i.e., `anchor`, +/// samples. A triplet is composed by `a`, `p` and `n` (i.e., `anchor`, /// `positive examples` and `negative examples` respectively). The /// shapes of all input tensors should be :math:`(N, D)`. /// See https://pytorch.org/docs/master/nn.html#torch.nn.TripletMarginLoss to learn @@ -461,6 +461,50 @@ struct TORCH_API TripletMarginLossImpl : public Cloneable /// module storage semantics. TORCH_MODULE(TripletMarginLoss); +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginWithDistanceLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that measures the triplet loss given input +/// tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor, +/// positive, and negative examples, respectively); and a nonnegative, real-valued function +/// ("distance function") used to compute the relationships between the anchor +/// and positive example ("positive distance") and the anchor and negative +/// example ("negative distance"). +/// See https://pytorch.org/docs/master/nn.html#torch.nn.TripletMarginWithDistanceLoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::TripletMarginWithDistanceLossOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// TripletMarginWithDistanceLoss model(TripletMarginWithDistanceLossOptions().margin(3).swap(false)); +/// ``` +struct TORCH_API TripletMarginWithDistanceLossImpl : public Cloneable { + explicit TripletMarginWithDistanceLossImpl( + TripletMarginWithDistanceLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `TripletMarginWithDistanceLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward( + const Tensor& anchor, + const Tensor& positive, + const Tensor& negative); + + /// The options with which this `Module` was constructed. + TripletMarginWithDistanceLossOptions options; +}; + +/// A `ModuleHolder` subclass for `TripletMarginWithDistanceLossImpl`. +/// See the documentation for `TripletMarginWithDistanceLossImpl` class to learn what methods it +/// provides, and examples of how to use `TripletMarginWithDistanceLoss` with +/// `torch::nn::TripletMarginWithDistanceLossOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(TripletMarginWithDistanceLoss); + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CTCLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// The Connectionist Temporal Classification loss. @@ -626,9 +670,9 @@ TORCH_MODULE(NLLLoss); struct TORCH_API CrossEntropyLossImpl : public Cloneable { explicit CrossEntropyLossImpl( const CrossEntropyLossOptions& options_ = {}); - + void reset() override; - + /// Pretty prints the `CrossEntropyLoss` module into the given `stream`. void pretty_print(std::ostream& stream) const override; diff --git a/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h b/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h index 98d4be45e04a6..08278ea2162e5 100644 --- a/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h +++ b/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h @@ -12,12 +12,13 @@ namespace nn { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PixelShuffle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` -/// to a tensor of shape :math:`(*, C, H \times r, W \times r)`. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.PixelShuffle to learn -/// about the exact behavior of this module. +/// to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an +/// upscale factor. See +/// https://pytorch.org/docs/master/nn.html#torch.nn.PixelShuffle to learn about +/// the exact behavior of this module. /// -/// See the documentation for `torch::nn::PixelShuffleOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::PixelShuffleOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` @@ -44,5 +45,42 @@ struct TORCH_API PixelShuffleImpl : public torch::nn::Cloneable { + explicit PixelUnshuffleImpl(const PixelUnshuffleOptions& options_); + + /// Pretty prints the `PixelUnshuffle` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input); + + void reset() override; + + /// The options with which this `Module` was constructed. + PixelUnshuffleOptions options; +}; + +/// A `ModuleHolder` subclass for `PixelUnshuffleImpl`. +/// See the documentation for `PixelUnshuffleImpl` class to learn what methods +/// it provides, and examples of how to use `PixelUnshuffle` with +/// `torch::nn::PixelUnshuffleOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(PixelUnshuffle); + } // namespace nn } // namespace torch diff --git a/torch/csrc/api/include/torch/nn/modules/rnn.h b/torch/csrc/api/include/torch/nn/modules/rnn.h index d3244532a7b6d..ac6c58441f390 100644 --- a/torch/csrc/api/include/torch/nn/modules/rnn.h +++ b/torch/csrc/api/include/torch/nn/modules/rnn.h @@ -160,6 +160,8 @@ class TORCH_API LSTMImpl : public detail::RNNImplBase { protected: void check_forward_args(const Tensor& input, std::tuple hidden, const Tensor& batch_sizes) const; + std::tuple get_expected_cell_size(const Tensor& input, const Tensor& batch_sizes) const; + std::tuple permute_hidden(std::tuple hx, const Tensor& permutation) const; std::tuple> forward_helper( diff --git a/torch/csrc/api/include/torch/nn/modules/transformercoder.h b/torch/csrc/api/include/torch/nn/modules/transformercoder.h index 04518a177333a..6b69f53ecf33c 100644 --- a/torch/csrc/api/include/torch/nn/modules/transformercoder.h +++ b/torch/csrc/api/include/torch/nn/modules/transformercoder.h @@ -32,6 +32,8 @@ namespace nn { class TORCH_API TransformerEncoderImpl : public Cloneable { public: + TransformerEncoderImpl(TransformerEncoderLayer encoder_layer, int64_t num_layers) + : TransformerEncoderImpl(TransformerEncoderOptions(encoder_layer, num_layers)) {} explicit TransformerEncoderImpl(TransformerEncoderOptions options_); Tensor forward( diff --git a/torch/csrc/api/include/torch/nn/modules/transformerlayer.h b/torch/csrc/api/include/torch/nn/modules/transformerlayer.h index db003c224552c..d6021e189781a 100644 --- a/torch/csrc/api/include/torch/nn/modules/transformerlayer.h +++ b/torch/csrc/api/include/torch/nn/modules/transformerlayer.h @@ -70,7 +70,7 @@ class TORCH_API TransformerEncoderLayerImpl : public Cloneable reduction_t; + typedef std::function distance_function_t; + + /// Specifies a nonnegative, real-valued function that quantifies the + /// closeness of two tensors. If not specified, `F::pairwise_distance` will + /// be used. Default: nullopt + TORCH_ARG(c10::optional, distance_function) = c10::nullopt; + /// Specifies a nonnegative margin representing the minimum difference + /// between the positive and negative distances required for the loss to be 0. + /// Larger margins penalize cases where the negative examples are not distance + /// enough from the anchors, relative to the positives. Default: 1 + TORCH_ARG(double, margin) = 1.0; + /// Whether to use the distance swap described in the paper Learning shallow + /// convolutional feature descriptors with triplet losses by V. Balntas, + /// E. Riba et al. If True, and if the positive example is closer to the + /// negative example than the anchor is, swaps the positive example and the + /// anchor in the loss computation. Default: False + TORCH_ARG(bool, swap) = false; + /// Specifies the reduction to apply to the output. Default: Mean + TORCH_ARG(reduction_t, reduction) = torch::kMean; +}; + +namespace functional { +/// Options for `torch::nn::functional::triplet_margin_with_distance_loss`. +/// +/// See the documentation for `torch::nn::TripletMarginWithDistanceLossOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::triplet_margin_with_distance_loss(anchor, positive, negative, F::TripletMarginWithDistanceLossFuncOptions().margin(1.0)); +/// ``` +using TripletMarginWithDistanceLossFuncOptions = TripletMarginWithDistanceLossOptions; +} // namespace functional + +// ============================================================================ + /// Options for the `CTCLoss` module. /// /// Example: diff --git a/torch/csrc/api/include/torch/nn/options/pixelshuffle.h b/torch/csrc/api/include/torch/nn/options/pixelshuffle.h index e72e6931e49ae..e28e0053e98b8 100644 --- a/torch/csrc/api/include/torch/nn/options/pixelshuffle.h +++ b/torch/csrc/api/include/torch/nn/options/pixelshuffle.h @@ -21,6 +21,20 @@ struct TORCH_API PixelShuffleOptions { TORCH_ARG(int64_t, upscale_factor); }; +/// Options for the `PixelUnshuffle` module. +/// +/// Example: +/// ``` +/// PixelUnshuffle model(PixelUnshuffleOptions(5)); +/// ``` +struct TORCH_API PixelUnshuffleOptions { + /* implicit */ PixelUnshuffleOptions(int64_t downscale_factor) + : downscale_factor_(downscale_factor) {} + + /// Factor to decrease spatial resolution by + TORCH_ARG(int64_t, downscale_factor); +}; + namespace functional { /// Options for `torch::nn::functional::pixel_shuffle`. /// @@ -33,6 +47,18 @@ namespace functional { /// F::pixel_shuffle(x, F::PixelShuffleFuncOptions(2)); /// ``` using PixelShuffleFuncOptions = PixelShuffleOptions; + +/// Options for `torch::nn::functional::pixel_unshuffle`. +/// +/// See the documentation for `torch::nn::PixelUnshuffleOptions` class to learn +/// what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::pixel_unshuffle(x, F::PixelUnshuffleFuncOptions(2)); +/// ``` +using PixelUnshuffleFuncOptions = PixelUnshuffleOptions; } // namespace functional } // namespace nn diff --git a/torch/csrc/api/include/torch/nn/options/rnn.h b/torch/csrc/api/include/torch/nn/options/rnn.h index ae37693399ef2..09bbfe0fa2f4d 100644 --- a/torch/csrc/api/include/torch/nn/options/rnn.h +++ b/torch/csrc/api/include/torch/nn/options/rnn.h @@ -38,6 +38,8 @@ struct TORCH_API RNNOptionsBase { TORCH_ARG(double, dropout) = 0.0; /// Whether to make the RNN bidirectional. TORCH_ARG(bool, bidirectional) = false; + /// Cell projection dimension. If 0, projections are not added. Can only be used for LSTMs. + TORCH_ARG(int64_t, proj_size) = 0; }; } // namespace detail @@ -108,6 +110,8 @@ struct TORCH_API LSTMOptions { TORCH_ARG(double, dropout) = 0.0; /// If ``true``, becomes a bidirectional LSTM. Default: ``false`` TORCH_ARG(bool, bidirectional) = false; + /// Cell projection dimension. If 0, projections are not added + TORCH_ARG(int64_t, proj_size) = 0; }; /// Options for the `GRU` module. @@ -148,7 +152,7 @@ namespace detail { struct TORCH_API RNNCellOptionsBase { RNNCellOptionsBase(int64_t input_size, int64_t hidden_size, bool bias, int64_t num_chunks); virtual ~RNNCellOptionsBase() = default; - + TORCH_ARG(int64_t, input_size); TORCH_ARG(int64_t, hidden_size); TORCH_ARG(bool, bias); diff --git a/torch/csrc/api/src/cuda.cpp b/torch/csrc/api/src/cuda.cpp index d40cd8611c429..b8f3ffa0ee0ab 100644 --- a/torch/csrc/api/src/cuda.cpp +++ b/torch/csrc/api/src/cuda.cpp @@ -1,6 +1,7 @@ #include #include +#include #include @@ -49,5 +50,13 @@ void manual_seed_all(uint64_t seed) { } } +void synchronize(int64_t device_index) { + TORCH_CHECK(is_available(), "No CUDA GPUs are available"); + int64_t num_gpus = cuda::device_count(); + TORCH_CHECK(device_index == -1 || device_index < num_gpus, + "Device index out of range: ", device_index); + at::detail::getCUDAHooks().deviceSynchronize(device_index); +} + } // namespace cuda } // namespace torch diff --git a/torch/csrc/api/src/nn/modules/loss.cpp b/torch/csrc/api/src/nn/modules/loss.cpp index 43ab1119def9c..4b41b88c420c6 100644 --- a/torch/csrc/api/src/nn/modules/loss.cpp +++ b/torch/csrc/api/src/nn/modules/loss.cpp @@ -180,6 +180,33 @@ Tensor TripletMarginLossImpl::forward( // ============================================================================ +TripletMarginWithDistanceLossImpl::TripletMarginWithDistanceLossImpl( + TripletMarginWithDistanceLossOptions options_) + : options(std::move(options_)) {} + +void TripletMarginWithDistanceLossImpl::reset() {} + +void TripletMarginWithDistanceLossImpl::pretty_print(std::ostream& stream) const { + stream << "torch::nn::TripletMarginWithDistanceLoss(margin=" << options.margin() + << std::boolalpha << ", swap=" << options.swap() << ")"; +} + +Tensor TripletMarginWithDistanceLossImpl::forward( + const Tensor& anchor, + const Tensor& positive, + const Tensor& negative) { + return F::detail::triplet_margin_with_distance_loss( + anchor, + positive, + negative, + options.distance_function(), + options.margin(), + options.swap(), + options.reduction()); +} + +// ============================================================================ + MultiLabelMarginLossImpl::MultiLabelMarginLossImpl( const torch::nn::MultiLabelMarginLossOptions& options_) : options(options_) {} @@ -223,9 +250,9 @@ void SmoothL1LossImpl::pretty_print(std::ostream& stream) const { Tensor SmoothL1LossImpl::forward(const Tensor& input, const Tensor& target) { return F::detail::smooth_l1_loss(input, target, options.reduction()); } - + // ============================================================================ - + CTCLossImpl::CTCLossImpl(const CTCLossOptions& options_) : options(options_) {} void CTCLossImpl::reset() {} diff --git a/torch/csrc/api/src/nn/modules/pixelshuffle.cpp b/torch/csrc/api/src/nn/modules/pixelshuffle.cpp index dd2d34655979e..7062b07fe5d78 100644 --- a/torch/csrc/api/src/nn/modules/pixelshuffle.cpp +++ b/torch/csrc/api/src/nn/modules/pixelshuffle.cpp @@ -21,5 +21,19 @@ Tensor PixelShuffleImpl::forward( return F::detail::pixel_shuffle(input, options.upscale_factor()); } +PixelUnshuffleImpl::PixelUnshuffleImpl(const PixelUnshuffleOptions& options_) + : options(options_) {} + +void PixelUnshuffleImpl::pretty_print(std::ostream& stream) const { + stream << "torch::nn::PixelUnshuffle(downscale_factor=" + << options.downscale_factor() << ")"; +} + +void PixelUnshuffleImpl::reset() {} + +Tensor PixelUnshuffleImpl::forward(const Tensor& input) { + return F::detail::pixel_unshuffle(input, options.downscale_factor()); +} + } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/nn/modules/rnn.cpp b/torch/csrc/api/src/nn/modules/rnn.cpp index 634dcf03d9d14..b645a84907786 100644 --- a/torch/csrc/api/src/nn/modules/rnn.cpp +++ b/torch/csrc/api/src/nn/modules/rnn.cpp @@ -71,6 +71,17 @@ void RNNImplBase::reset() { "num_layers=", options_base.num_layers()); } + TORCH_CHECK( + 0 <= options_base.proj_size() && options_base.proj_size() < options_base.hidden_size(), + "proj_size has to be a positive integer, smaller than ", + "hidden_size or zero to disable projections"); + + if (options_base.proj_size() > 0) { + TORCH_CHECK( + c10::get_if(&options_base.mode()), + "proj_size argument is only supported for LSTM, not RNN or GRU"); + } + int64_t gate_size = 0; if (c10::get_if(&options_base.mode())) { gate_size = 4 * options_base.hidden_size(); @@ -89,21 +100,29 @@ void RNNImplBase::reset() { for (int64_t layer = 0; layer < options_base.num_layers(); layer++) { for (int64_t direction = 0; direction < num_directions; direction++) { - int64_t layer_input_size = layer == 0 ? options_base.input_size() : options_base.hidden_size() * num_directions; + int64_t real_hidden_size = options_base.proj_size() > 0 ? options_base.proj_size() : options_base.hidden_size(); + int64_t layer_input_size = layer == 0 ? options_base.input_size() : real_hidden_size * num_directions; auto w_ih = torch::empty({gate_size, layer_input_size}); - auto w_hh = torch::empty({gate_size, options_base.hidden_size()}); + auto w_hh = torch::empty({gate_size, real_hidden_size}); auto b_ih = torch::empty({gate_size}); // Second bias vector included for CuDNN compatibility. Only one // bias vector is needed in standard definition. auto b_hh = torch::empty({gate_size}); - std::vector layer_params = {w_ih, w_hh, b_ih, b_hh}; + std::vector layer_params = {w_ih, w_hh}; std::string suffix = direction == 1 ? "_reverse" : ""; std::vector param_names = {"weight_ih_l{layer}{suffix}", "weight_hh_l{layer}{suffix}"}; if (options_base.bias()) { param_names.emplace_back("bias_ih_l{layer}{suffix}"); param_names.emplace_back("bias_hh_l{layer}{suffix}"); + layer_params.emplace_back(b_ih); + layer_params.emplace_back(b_hh); + } + if (options_base.proj_size() > 0) { + auto w_hr = torch::empty({options_base.proj_size(), options_base.hidden_size()}); + layer_params.emplace_back(w_hr); + param_names.emplace_back("weight_hr_l{layer}{suffix}"); } for (size_t i = 0; i < param_names.size(); i++) { // NOLINT(modernize-loop-convert) std::string x = std::regex_replace(param_names[i], std::regex("\\{layer\\}"), c10::str(layer)); @@ -180,12 +199,17 @@ void RNNImplBase::flatten_parameters() { { torch::NoGradGuard no_grad; if (torch::_use_cudnn_rnn_flatten_weight()) { + int64_t num_weights = options_base.bias() ? 4 : 2; + if (options_base.proj_size() > 0) { + ++num_weights; + } torch::_cudnn_rnn_flatten_weight( flat_weights_, - options_base.bias() ? 4 : 2, + num_weights, options_base.input_size(), static_cast(get_cudnn_mode_for_rnn(options_base.mode())), options_base.hidden_size(), + options_base.proj_size(), options_base.num_layers(), options_base.batch_first(), options_base.bidirectional()); @@ -260,7 +284,8 @@ std::tuple RNNImplBase::get_expected_hidden_ mini_batch = options_base.batch_first() ? input.size(0) : input.size(1); } int64_t num_directions = options_base.bidirectional() ? 2 : 1; - return std::make_tuple(options_base.num_layers() * num_directions, mini_batch, options_base.hidden_size()); + int64_t real_hidden_size = options_base.proj_size() > 0 ? options_base.proj_size() : options_base.hidden_size(); + return std::make_tuple(options_base.num_layers() * num_directions, mini_batch, real_hidden_size); } template @@ -306,8 +331,11 @@ void RNNImplBase::pretty_print(std::ostream& stream) const { << ", bias=" << options_base.bias() << ", batch_first=" << options_base.batch_first() << ", dropout=" << options_base.dropout() - << ", bidirectional=" << options_base.bidirectional() - << ")"; + << ", bidirectional=" << options_base.bidirectional(); + if (options_base.proj_size() > 0) { + stream << ", proj_size=" << options_base.proj_size(); + } + stream << ")"; } template @@ -438,16 +466,27 @@ LSTMImpl::LSTMImpl(const LSTMOptions& options_) .bias(options_.bias()) .batch_first(options_.batch_first()) .dropout(options_.dropout()) - .bidirectional(options_.bidirectional())), + .bidirectional(options_.bidirectional()) + .proj_size(options_.proj_size())), options(options_) {} +std::tuple LSTMImpl::get_expected_cell_size( + const Tensor& input, const Tensor& batch_sizes) const { + int64_t mini_batch = 0; + if (batch_sizes.defined()) { + mini_batch = batch_sizes[0].item(); + } else { + mini_batch = options_base.batch_first() ? input.size(0) : input.size(1); + } + int64_t num_directions = options_base.bidirectional() ? 2 : 1; + return std::make_tuple(options_base.num_layers() * num_directions, mini_batch, options_base.hidden_size()); +} + void LSTMImpl::check_forward_args(const Tensor& input, std::tuple hidden, const Tensor& batch_sizes) const { this->check_input(input, batch_sizes); - auto expected_hidden_size = this->get_expected_hidden_size(input, batch_sizes); - - this->check_hidden_size(std::get<0>(hidden), expected_hidden_size, + this->check_hidden_size(std::get<0>(hidden), this->get_expected_hidden_size(input, batch_sizes), "Expected hidden[0] size {1}, got {2}"); - this->check_hidden_size(std::get<1>(hidden), expected_hidden_size, + this->check_hidden_size(std::get<1>(hidden), this->get_expected_cell_size(input, batch_sizes), "Expected hidden[1] size {1}, got {2}"); } @@ -471,10 +510,14 @@ std::tuple> LSTMImpl::forward_helper( std::tuple hx; if (!hx_opt.has_value()) { int64_t num_directions = options.bidirectional() ? 2 : 1; - auto zeros = torch::zeros({options.num_layers() * num_directions, - max_batch_size, options.hidden_size()}, - torch::dtype(input.dtype()).device(input.device())); - hx = std::make_tuple(zeros, zeros); + int64_t real_hidden_size = options.proj_size() > 0 ? options.proj_size() : options.hidden_size(); + auto h_zeros = torch::zeros({options.num_layers() * num_directions, + max_batch_size, real_hidden_size}, + torch::dtype(input.dtype()).device(input.device())); + auto c_zeros = torch::zeros({options.num_layers() * num_directions, + max_batch_size, options.hidden_size()}, + torch::dtype(input.dtype()).device(input.device())); + hx = std::make_tuple(h_zeros, c_zeros); } else { hx = hx_opt.value(); // Each batch of the hidden state should match the input sequence that @@ -650,13 +693,13 @@ void RNNCellImplBase::pretty_print(std::ostream& stream) const { if (!nonlinearity_str.empty() && nonlinearity_str != "kTanh") { stream << ", nonlinearity=" << nonlinearity_str; } - stream << ")"; + stream << ")"; } template void RNNCellImplBase::check_forward_input(const Tensor& input) const { TORCH_CHECK( - input.size(1) == options_base.input_size(), + input.size(1) == options_base.input_size(), "input has inconsistent input_size: got ", input.size(1), " expected ", options_base.input_size()); } diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 5a91038a94e62..27dd4ccce6498 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -9,8 +9,12 @@ #include #include #include +#include #include #include +#include +#include +#include #include #include @@ -29,14 +33,29 @@ using at::Scalar; using at::IntArrayRef; using at::TensorList; +const char* kCudnnDoubleBackwardMsg = "Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API. To run double backwards, please disable the CuDNN backend temporarily while running the forward pass of your RNN. For example: \nwith torch.backends.cudnn.flags(enabled=False):\n output = model(inputs)"; + + bool isDefined(const c10::optional& t) { return t.has_value() && t->defined(); } +bool isFwGradDefined(const c10::optional& t) { + return t.has_value() && t->defined() && t->fw_grad(/*level */ 0).defined(); +} + Tensor toLegacyTensor(const c10::optional& t) { return t.has_value() ? *t : Tensor(); } +Tensor toLegacyFwGrad(const c10::optional& t) { + return (t.has_value() && t->defined()) ? t->fw_grad(/*level */ 0) : Tensor(); +} + +Tensor toLegacyPrimal(const c10::optional& t) { + return (t.has_value() && t->defined()) ? t->_fw_primal(/*level */ 0) : Tensor(); +} + void copy_range(variable_list& out, IndexRange range, const Tensor & t) { AT_ASSERT(range.second <= out.size()); AT_ASSERTM(range.second - range.first == 1, "inconsistent range for Tensor output"); @@ -49,9 +68,27 @@ void copy_range(variable_list& out, IndexRange range, at::ArrayRef t) { std::copy(t.begin(), t.end(), out.begin() + range.first); } -Tensor not_implemented(const char* name) { - throw std::runtime_error( - std::string("the derivative for '") + name + "' is not implemented"); +Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, const Tensor & result) { + auto ratio = result / self; + ratio.masked_fill_(self == 0, 0); + return grad * ratio; +} + +template +T not_implemented_base(const char* name, const char* reason) { + std::string msg = c10::str("the derivative for '", name, "' is not implemented."); + if (strlen(reason) > 0) { + msg = c10::str(msg, " ", reason); + }; + throw std::runtime_error(msg); +} + +Tensor not_implemented(const char* name, const char* reason) { + return not_implemented_base(name, reason); +} + +std::vector not_implemented_list(const char* name, const char* reason) { + return not_implemented_base>(name, reason); } Tensor maybe_multiply(const Tensor & t, const Scalar & s) { @@ -87,6 +124,22 @@ static Tensor wrapped_scalar_tensor(Scalar scalar) { return tensor; } +Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result) { + if (!at::isComplexType(self_st) && gradient_result.is_complex()) { + // R -> C + return at::real(gradient_result); + } + return gradient_result; +} + +Tensor handle_r_to_c(Tensor self, Tensor gradient_result) { + if (!self.is_complex() && gradient_result.is_complex()) { + // R -> C + return at::real(gradient_result); + } + return gradient_result; +} + Tensor restore_reduced_dims(const Tensor &output, IntArrayRef dims, bool keepdim) { if (keepdim) { return output; @@ -123,22 +176,38 @@ std::tuple _euclidean_dist_backward(const Tensor & grad, const T x2 * ratio.sum(-2, false).unsqueeze(-1) - ratio.transpose(-2, -1).matmul(x1)}; } -Tensor norm_backward(const Tensor & grad, const Tensor & self, const optional & p_, const Tensor & norm) { +Tensor norm_backward(const Tensor& grad, const Tensor& self, const optional & p_, const Tensor& norm) { + return norm_backward(grad, self, p_, norm, {}, true); +} + +Tensor norm_backward(Tensor grad, const Tensor& self, const optional & p_, Tensor norm, IntArrayRef dim, bool keepdim) { + size_t ndim = self.sizes().size(); double p = p_.value_or(2.0).toDouble(); Tensor self_scaled; Tensor scale_v; + + if (!keepdim && self.dim() != 0) { + grad = unsqueeze_multiple(grad, dim, ndim); + norm = unsqueeze_multiple(norm, dim, ndim); + } + if (p == 0.0) { return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } else if (p == 1.0) { - return self.sign() * grad; + return self.sgn() * grad; } else if (p == 2.0) { self_scaled = self; scale_v = grad / norm; } else if (std::isinf(p)) { - self_scaled = self.sign() * (self.abs() == norm).type_as(self); - scale_v = grad.clone(at::MemoryFormat::Preserve); + Tensor is_eq_max = (self.abs() == norm).logical_or_(self.isnan().logical_and_(norm.isnan())).type_as(self); + self_scaled = self.sign() * is_eq_max; + Tensor nb_max = is_eq_max.count_nonzero(dim); + if (self.dim() != 0) { + nb_max = unsqueeze_multiple(nb_max, dim, ndim); + } + scale_v = grad / nb_max; } else if (p < 2.0) { - self_scaled = self.sign() * self.abs().pow(p - 1); + self_scaled = self.sgn() * self.abs().pow(p - 1); scale_v = grad / norm.pow(p - 1); } else { self_scaled = self * self.abs().pow(p - 2); @@ -149,36 +218,19 @@ Tensor norm_backward(const Tensor & grad, const Tensor & self, const optional & p_, Tensor norm, IntArrayRef dim, bool keepdim) { - IntArrayRef sizes = self.sizes(); - if (!keepdim && self.dim() != 0) { - if (dim.size()==1) { - grad = grad.unsqueeze(dim[0]); - norm = norm.unsqueeze(dim[0]); - } else { - auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, sizes.size()); - for (size_t i = 0; i < sizes.size(); i++){ - if (dims_to_unsqueeze[i]) { - grad = grad.unsqueeze(i); - norm = norm.unsqueeze(i); - } - } - } - } - return norm_backward(grad, self, p_, norm); -} - -Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent_) { - double exponent = exponent_.toDouble(); - if (exponent == 0.0) { +Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent) { + if (exponent.equal(0.0)) { return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } else { - return grad * exponent * self.pow(exponent - 1); + auto grad_lambda = [&](auto exp) { return grad * (exp * self.pow(exp - 1)).conj(); }; + Tensor out = (exponent.isComplex()) ? grad_lambda(exponent.toComplexDouble()) : grad_lambda(exponent.toDouble()); + return handle_r_to_c(self, out); } } Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & exponent) { - return at::where(exponent == 0.0, at::zeros({}, grad.options()), grad * exponent * self.pow(exponent - 1)); + auto out = at::where(exponent == 0.0, at::zeros({}, grad.options()), grad * (exponent * self.pow(exponent - 1)).conj()); + return handle_r_to_c(self, out); } // Caveats: @@ -190,18 +242,45 @@ Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & expone // d(a^b)/db = 0 for a > 0 and b -> +0. // Currently, tensorflow agrees with us. Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& exponent, Tensor result) { - return grad * at::where(at::logical_and(self == 0, exponent >= 0), + Tensor cond; + if (exponent.is_complex()) { + auto is_real_exp = at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0); + cond = at::logical_and(self == 0, is_real_exp); + } else { + cond = at::logical_and(self == 0, exponent >= 0); + } + auto out = grad * at::where(cond, at::zeros({}, grad.options()), - result * self.log()); + (result * self.log()).conj()); + return handle_r_to_c(exponent, out); } Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exponent, Tensor result) { - if (base.toDouble() == 0) { - return grad * at::where(exponent >= 0, + auto grad_lambda = [](Tensor a, Scalar b) { return (a * b.log()).conj(); }; + if (base.equal(0.0)) { + auto cond = [](auto exp) { + if (exp.is_complex()) { + return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0); + } else { + return exp >=0; + } + }; + auto out = grad * at::where(cond(exponent), at::zeros({}, grad.options()), - result * std::log(base.toDouble())); + grad_lambda(result, base)); + return handle_r_to_c(exponent, out); + } else { + auto out = grad * grad_lambda(result, base); + return handle_r_to_c(exponent, out); + } +} + +Tensor angle_backward(Tensor grad, const Tensor& self) { + if (self.is_complex()) { + return at::where(self == 0.0, at::zeros({}, self.options()), + grad * self / self.abs().pow(2) * Scalar(c10::complex{0.0, 1.0})); } else { - return grad * result * std::log(base.toDouble()); + return at::zeros_like(self, at::MemoryFormat::Preserve); } } @@ -218,35 +297,23 @@ Tensor sgn_backward(Tensor result, Tensor grad, Tensor self) { // https://arxiv.org/pdf/1701.00392.pdf Section 4.20 return at::where(abs == 0.0, at::zeros({}, grad.options()), (grad/abs - (at::real(grad/self) * result))); } else { - return at::zeros_like(grad, at::MemoryFormat::Preserve); + return at::zeros_like(self, at::MemoryFormat::Preserve); } } Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st) { - auto result = grad * other.conj(); - if (!at::isComplexType(self_st) && result.is_complex()) { - // R -> C - result = at::real(result); - } - return result; + auto out = grad * other.conj(); + return handle_r_to_c(self_st, out); } Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st) { auto result = grad / other.conj(); - if (!at::isComplexType(self_st) && result.is_complex()) { - // R -> C - result = at::real(result); - } - return result; + return handle_r_to_c(self_st, result); } Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other) { auto result = -grad * ((self / other) / other).conj(); - if (!other.is_complex() && result.is_complex()) { - // R -> C - result = at::real(result); - } - return result; + return handle_r_to_c(other, result); } Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) { @@ -381,15 +448,21 @@ Tensor prod_backward(Tensor grad, const Tensor& input, Tensor result, int64_t di } Tensor solve_backward_self(const Tensor & grad, const Tensor & self, const Tensor & A) { - return std::get<0>(at::solve(grad, A.transpose(-2, -1))); + return at::linalg_solve(A.conj().transpose(-2, -1), grad); } Tensor solve_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) { Tensor grad_self = solve_backward_self(grad, self, A); if (self.ndimension() == 2 && A.ndimension() == 2) { - return -at::mm(grad_self, solution.transpose(-2, -1)); + return -at::mm(grad_self, solution.conj().transpose(-2, -1)); + } + // if self was unsqueezed from (..., M) to (..., M, 1) + auto batched_rhs_shape = IntArrayRef(A.sizes().data(), A.dim()-1); // A.shape[:-1] + bool is_rhs_broadcasted = self.dim() == 1 || (A.dim()-1 == self.dim() && self.sizes().equals(batched_rhs_shape)); + if (is_rhs_broadcasted) { + return -at::matmul(grad_self.unsqueeze(-1), solution.unsqueeze(-1).conj().transpose(-2, -1)); } - return -at::matmul(grad_self, solution.transpose(-2, -1)); + return -at::matmul(grad_self, solution.conj().transpose(-2, -1)); } Tensor cumsum_backward(const Tensor & x, int64_t dim) { @@ -514,17 +587,43 @@ Tensor clamp_backward(const Tensor & grad, const Tensor &self, const optionalset_indices_and_values_unsafe(mask_indices, r_values); + return output; +} + +Tensor sparse_sparse_matmul_backward( + const Tensor& grad, + const Tensor& a, + const Tensor& b, + int64_t grad_order) { + /* + To implement the backward algorithm for sparse matrix-matrix matmul (SPMM) we can start from the following definition + for dense tensors: + + c = a @ b + then + a_grad = c_grad @ b^T + b_grad = a^T @ c_grad + + So for sparse matrices we can use the following definition: + + if grad_order == 0: + a_grad = sparse_matrix_mask(c_grad @ b^T, mask=a) + else: + b_grad = sparse_matrix_mask(a^T @ c_grad, mask=b) + */ + TORCH_CHECK( + grad_order == 0 || grad_order == 1, + ": grad_order not in [0, 1] at sparse_sparse_matmul_backward function"); + if (grad_order == 0) { + auto a_grad = _sparse_sparse_matmul(grad, b.t()); + return _sparse_matrix_mask(a_grad.coalesce(), a.coalesce()); + } + auto b_grad = _sparse_sparse_matmul(a.t(), grad); + return _sparse_matrix_mask(b_grad.coalesce(), b.coalesce()); +} + Tensor renorm_backward(const Tensor & grad, const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) { auto transposed_sizes = self.transpose(dim, 0).sizes().vec(); auto flatten = [&](const Tensor & t) { @@ -581,17 +731,6 @@ Tensor renorm_backward(const Tensor & grad, const Tensor & self, Scalar p, int64 return at::where(mask, grad, grad_norm); } -Tensor sum_tensorlist(TensorList tl) { - if (tl.size() == 0) { - throw std::runtime_error("Can't sum tensorlist of size 0"); - } - Tensor sum = tl[0]; - for(size_t i = 1; i < tl.size(); ++i) { - sum = sum + tl[i]; - } - return sum; -} - Tensor repeat_backward(Tensor grad, IntArrayRef repeats, IntArrayRef input_shape) { auto find_iter = std::find(repeats.cbegin(), repeats.cend(), 0); if (find_iter != repeats.cend()) { @@ -602,13 +741,69 @@ Tensor repeat_backward(Tensor grad, IntArrayRef repeats, IntArrayRef input_shape for (int64_t i = 0; i < num_unsqueezed; ++i) { grad = grad.sum(0, false); } - for (size_t j = num_unsqueezed; j < repeats.size(); ++j) { - int64_t repeat = repeats[j]; - if (repeat == 1) { - continue; + + at::DimVector grad_size, sum_dims; + for (size_t dim = 0; dim < input_dims; ++dim) { + int64_t repeat = repeats[dim + num_unsqueezed]; + // Reshape gradient (repeat > 1) + // Index: [..., dim , ...] [..., dim , dim+1 , ...] + // Shape: From [..., dimsize, ...] to [..., repeat, dimsize/repeat, ...] + // The gradient tensor at 'dim' is reshaped to 'repeat' times of input tensor. + // Then, sum up gradients over repeated tensors along 'dim', and reduce shape + // from 'repeat * dimsize/repeat' to 'dimsize/repeat' ('input_dimsize'). + // Example: + // Size(3, 2) Size(6, 2) + // [[v1_0, v1_1], + // [v1_2, v1_3], + // [[v0, v1], repeat(2, 1) [v1_4, v1_5], + // [v2, v3], -------------> [v2_0, v2_1], + // [v4, v5]] [v2_2, v2_3], + // [v2_4, v2_5]] + // + // input grad (3, 2) reshape (2, 3, 2) output grad (6, 2) + // [[[g1_0, g1_1], [[g1_0, g1_1], + // [g1_2, g1_3], [g1_2, g1_3], + // [[g1_0+g2_0, g1_1+g2_1], [g1_4, g1_5]], [g1_4, g1_5], + // [g1_0+g2_0, g1_1+g2_1], [g2_0, g2_1], + // [g1_0+g2_0, g1_1+g2_1]] [[g2_0, g2_1], [g2_2, g2_3], + // [g2_2, g2_3], [g2_4, g2_5]] + // [g2_4, g2_5]]] + // If gradient tensor is reshaped to [..., dimsize/repeat, repeat, ...] and then + // sum over 'dim+1'. The gradient for input is not correctly aligned with input. + // Example: + // input grad (3, 2) reshape (3, 2, 2) output grad (6, 2) + // [[[g1_0, g1_1], + // [g1_2, g1_3]], [[g1_0, g1_1], + // [g1_2, g1_3], + // [[g1_0+g1_2, g1_1+g1_3], [[g1_4, g1_5], [g1_4, g1_5], + // [g1_4+g2_0, g1_5+g2_1], [g2_0, g2_1]], [g2_0, g2_1], + // [g2_2+g2_4, g2_3+g2_5]] [g2_2, g2_3], + // [[g2_2, g2_3], [g2_4, g2_5]] + // [g2_4, g2_5]]] + if (repeat != 1) { + grad_size.push_back(repeat); + sum_dims.push_back(grad_size.size() - 1); } - int64_t dim = j - num_unsqueezed; - grad = sum_tensorlist(grad.chunk(repeat, dim)); + // Don't need to reshape gradient into (repeat, input_shape[dim]) (repeat == 1) + grad_size.push_back(input_shape[dim]); + } + // One-time Reshape & Sum + // Reshape gradient to grad_size: + // 1. If repeat equals to 1, append input size at that dimension, + // 2. If repeat is larger than 1, append both repeat and input size at that dimension. + // Sum over all "repeat" dimensions from sum_dims: + // Example: + // Input Size (2, 3, 4, 5) + // repeat [4, 1, 9, 3] + // output/grad Size (8, 3, 36, 15) + // grad_size [4, 2, 3, 9, 4, 3, 5] + // sum_dims [0, 3, 5] + + // When repeat 1 time over all original dimensions, the empty sum_dims will reduce + // the whole grad tensor into a scalar rather than keeping original dimensions. + if (!sum_dims.empty()) { + grad = grad.reshape(grad_size); + grad = grad.sum(sum_dims); } return grad; } @@ -624,13 +819,12 @@ Tensor _fused_dropout_backward(Tensor grad, Tensor mask, double p1m) { } Tensor evenly_distribute_backward(Tensor grad, const Tensor & input, const Tensor & value) { - auto mask = (input == value); - auto count = mask.sum(); - auto grad_input = grad / count; if (input.is_cuda()) { - return mask * grad_input; + auto mask = (input == value).logical_or_(input.isnan().logical_and_(value.isnan())); + return mask * (grad / mask.sum()); } else { - return at::zeros_like(input).masked_fill_(mask, grad_input); + auto mask = value.isnan().item() ? input.isnan() : input == value; + return grad.new_zeros(input.sizes(), input.options()).masked_fill_(mask, grad / mask.sum()); } } @@ -649,11 +843,11 @@ Tensor var_backward(Tensor grad, const Tensor & self, IntArrayRef dim, bool unbi } Tensor std_backward(const Tensor & result, const Tensor & grad, const Tensor & self, bool unbiased) { - return var_backward(grad / (result * 2), self, unbiased); + return var_backward((grad / (result * 2)).masked_fill_(result == 0, 0), self, unbiased); } Tensor std_backward(const Tensor & result, Tensor grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) { - return var_backward(grad / (result * 2), self, dim, unbiased, keepdim); + return var_backward((grad / (result * 2)).masked_fill_(result == 0, 0), self, dim, unbiased, keepdim); } Tensor mean_backward(Tensor grad, const IntArrayRef sizes, IntArrayRef dim, bool keepdim) { @@ -716,15 +910,15 @@ Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) { // leads to stable gradient updates, and retains symmetry of the updated matrix if it // were updated by a gradient based algorithm. if (upper) { - L = L.transpose(-1, -2); - grad = grad.transpose(-1, -2); + L = L.transpose(-1, -2).conj(); + grad = grad.transpose(-1, -2).conj(); } auto L_inverse = std::get<0>(at::triangular_solve(at::eye(L.size(-1), L.options()), L, /*upper=*/false)); - auto phi = at::matmul(L.transpose(-1, -2), grad); + auto phi = at::matmul(L.transpose(-1, -2).conj(), grad); phi.tril_().diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).mul_(0.5); - auto grad_input = at::matmul(at::matmul(L_inverse.transpose(-1, -2), phi), L_inverse); - return grad_input.add(grad_input.transpose(-1, -2)).mul_(0.5); // Symmetrizing the gradient + auto grad_input = at::matmul(at::matmul(L_inverse.transpose(-1, -2).conj(), phi), L_inverse); + return grad_input.add(grad_input.transpose(-1, -2).conj()).mul_(0.5); // Symmetrizing the gradient } Tensor cholesky_inverse_backward(Tensor grad, Tensor L, bool upper, Tensor inverse) { @@ -911,13 +1105,46 @@ Tensor log_softmax_double_backward(const Tensor & grad, const Tensor & grad_outp return z * grad_output.sum(dim, true) * ((grad * z).sum(dim, true) - grad); } +// NOTE: [How to write vmap-compatible backward formulas] +// +// See NOTE: [vmap-incompatible in-place operations] for what it means for an +// in-place operation to be incompatible with vmap. +// +// If an in-place operation used in a backward formula is vmap-incompatible, +// then as developers we have the following options: +// +// - If the in-place operation directly followed the creation of a tensor with +// a factory function like at::zeros(...), we should replace the factory with a +// corresponding grad.new_zeros(...) call. The grad.new_zeros(...) call +// propagates the batch dims to the resulting tensor. +// For example: +// Before: at::zeros(input.sizes(), grad.options()).copy_(grad) +// After: grad.new_zeros(input.sizes()).copy_(grad) +// +// - If the in-place operation followed some sequence of operations, if the +// we want to be able to vmap over the backward formula as-is (this is +// usually the case for simple (<15loc) backward formulas), then use +// inplaceIsVmapCompatible to guard the operation. For example: +// c = a * b +// Before: c.mul_(grad) +// After: c = at::inplaceIsVmapCompatible(c, grad) ? c.mul_(grad) : c * grad +// +// - If we don't want to vmap directly over the backward formula (e.g., if the +// backward formula is too complicated or has a lot of vmap-incompatible +// operations, then register the backward formula as an operator and eventually +// write a batching rule for it. + Tensor binary_cross_entropy_double_backward(const Tensor & grad_output, const Tensor & grad, const Tensor & input, const Tensor & target, const c10::optional& weight, int64_t reduction) { auto eps = 1e-12; auto inp_pl_eps = input + eps; auto one_m_inp_pl_eps = 1 - input + eps; // gradient wrt input auto gI = (input * input - 2 * input * target + target) / (inp_pl_eps.pow(2) * one_m_inp_pl_eps.pow(2)); - gI *= (grad * grad_output); + if (at::inplaceIsVmapCompatible(gI, grad)) { + gI *= (grad * grad_output); + } else { + gI = gI * (grad * grad_output); + } if (isDefined(weight)) { gI *= *weight; @@ -934,7 +1161,11 @@ Tensor binary_cross_entropy_double_backward_grad_output(const Tensor & grad, con auto eps = 1e-12; // gradient wrt grad_output auto ggO = (input - target) / ((input + eps) * (1 - input + eps)); - ggO *= grad; + if (at::inplaceIsVmapCompatible(ggO, grad)) { + ggO *= grad; + } else { + ggO = ggO * grad; + } if (isDefined(weight)) { ggO *= *weight; @@ -947,30 +1178,47 @@ Tensor binary_cross_entropy_double_backward_grad_output(const Tensor & grad, con return ggO; } -Tensor l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) { - auto output = l1_loss_backward(grad, input, target, at::Reduction::None); +Tensor l1_loss_double_backward(const Tensor & grad, const Tensor & grad_output, const Tensor & self, const Tensor & other, int64_t reduction) { + if (!self.is_complex()) { + return at::zeros_like(grad); + } else { + auto diff = self - other; + auto output = grad_output * sgn_backward(diff.sgn(), grad, diff); + if (reduction == at::Reduction::Mean) { + output /= self.numel(); + } + return output; + } +} + +Tensor l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) { + auto output = at::l1_loss_backward(grad.conj(), input, target, at::Reduction::None); if (reduction == at::Reduction::Mean) { return output.mean(); } else if (reduction == at::Reduction::Sum) { return output.sum(); } - return output; + return handle_r_to_c(grad_output, output); } -Tensor smooth_l1_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) { +Tensor smooth_l1_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, double beta) { + // special case to protect against a divide-by-zero. + if (beta == 0) { + return at::zeros(grad.sizes(), grad.options()); + } auto d = (input - target).abs(); - auto grad_input = grad * (d < 1).type_as(grad); + auto grad_input = grad * (d < beta).type_as(grad) / beta; if (reduction == at::Reduction::Mean) { grad_input /= input.numel(); } return grad_input; } -Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) { +Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction, double beta) { if (reduction == at::Reduction::None) { - return smooth_l1_loss_backward(grad, input, target, reduction); + return smooth_l1_loss_backward(grad, input, target, reduction, beta); } - auto r = smooth_l1_loss_backward(ones_like(grad_output), input, target, reduction); + auto r = smooth_l1_loss_backward(ones_like(grad_output), input, target, reduction, beta); return (r * grad).sum(); } @@ -1531,7 +1779,7 @@ Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayR _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset), _min_storage_size(out_sizes_, out_strides_, out_effective_offset) ); - auto storage = at::zeros({base_size}, grad.options()); + auto storage = grad.new_zeros({base_size}); // prepare indices tensor if we will do index_add_ later c10::optional flatten_full_indices; @@ -1657,6 +1905,35 @@ std::tuple prelu_double_backward( } } +Tensor elu_double_backward( + const Tensor& grad, + const Tensor& grad_output, + Scalar alpha, + Scalar scale, + Scalar input_scale, + bool is_result, + const Tensor& self_or_result) { + + if (is_result) { + return grad * grad_output * input_scale * (self_or_result < 0).type_as(grad); + } else { + return at::elu_backward(grad * grad_output * input_scale, alpha, scale, input_scale, is_result, self_or_result) * (self_or_result < 0).type_as(grad); + } +} + +Tensor slice_backward_wrapper( + const at::Tensor& grad, + const c10::IntArrayRef& input_sizes, + int64_t dim, + c10::optional start, + c10::optional end, + int64_t step) { + auto start_val = start.has_value() ? start.value() : 0; + auto end_val = end.has_value() ? end.value() : INT64_MAX; + + return slice_backward(grad, input_sizes, dim, start_val, end_val, step); +} + // https://j-towns.github.io/papers/svd-derivative.pdf // // This makes no assumption on the signs of sigma. @@ -1672,16 +1949,21 @@ Tensor svd_backward(const std::vector &grads, const T auto gsigma = grads[1]; auto u = raw_u; - auto v = raw_v; + // Currently torch.svd for complex dtypes returns the conjugate of V, + // while the backward formula is derived with just V (without the conjugation) + // therefore here we need to conjugate the V output of SVD and grads[2]. + // Once https://github.com/pytorch/pytorch/issues/45821 is resolved + // extra .conj(), that are marked below in the code, shall be removed. + auto v = raw_v.conj(); // TODO: remove .conj() auto gu = grads[0]; - auto gv = grads[2]; + auto gv = grads[2].conj(); // TODO: remove .conj() if (!some) { // We ignore the free subspace here because possible base vectors cancel // each other, e.g., both -v and +v are valid base for a dimension. // Don't assume behavior of any particular implementation of svd. u = raw_u.narrow(-1, 0, k); - v = raw_v.narrow(-1, 0, k); + v = raw_v.narrow(-1, 0, k).conj(); // TODO: remove .conj() if (gu.defined()) { gu = gu.narrow(-1, 0, k); } @@ -1689,11 +1971,13 @@ Tensor svd_backward(const std::vector &grads, const T gv = gv.narrow(-1, 0, k); } } - auto vt = v.transpose(-2, -1); + auto vh = v.conj().transpose(-2, -1); Tensor sigma_term; if (gsigma.defined()) { - sigma_term = at::matmul(u, at::matmul(gsigma.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1), vt)); + gsigma = gsigma.to(self.dtype()); + // computes u @ diag(gsigma) @ vh + sigma_term = at::matmul(u * gsigma.unsqueeze(-2), vh); } else { sigma_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } @@ -1703,11 +1987,8 @@ Tensor svd_backward(const std::vector &grads, const T return sigma_term; } - auto ut = u.transpose(-2, -1); - auto im = at::eye(m, self.options()); - auto in = at::eye(n, self.options()); - auto sigma_mat = sigma.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1); - auto sigma_mat_inv = sigma.pow(-1).diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1); + auto uh = u.conj().transpose(-2, -1); + auto sigma_inv = sigma.pow(-1); auto sigma_sq = sigma.pow(2); auto F = sigma_sq.unsqueeze(-2) - sigma_sq.unsqueeze(-1); // The following two lines invert values of F, and fills the diagonal with 0s. @@ -1719,26 +2000,44 @@ Tensor svd_backward(const std::vector &grads, const T Tensor u_term, v_term; if (gu.defined()) { - u_term = at::matmul(u, at::matmul(F.mul(at::matmul(ut, gu) - at::matmul(gu.transpose(-2, -1), u)), sigma_mat)); + auto guh = gu.conj().transpose(-2, -1); + u_term = at::matmul(u, F.mul(at::matmul(uh, gu) - at::matmul(guh, u)) * sigma.unsqueeze(-2)); if (m > k) { - u_term = u_term + at::matmul(im - at::matmul(u, ut), at::matmul(gu, sigma_mat_inv)); + // projection operator onto subspace orthogonal to span(U) defined as I - UU^H + auto proj_on_ortho_u = -at::matmul(u, uh); + proj_on_ortho_u.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).add_(1); + u_term = u_term + proj_on_ortho_u.matmul(gu * sigma_inv.unsqueeze(-2)); } - u_term = at::matmul(u_term, vt); + u_term = at::matmul(u_term, vh); } else { u_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } if (gv.defined()) { - auto gvt = gv.transpose(-2, -1); - v_term = at::matmul(sigma_mat, at::matmul(F.mul(at::matmul(vt, gv) - at::matmul(gvt, v)), vt)); + auto gvh = gv.conj().transpose(-2, -1); + v_term = sigma.unsqueeze(-1) * at::matmul(F.mul(at::matmul(vh, gv) - at::matmul(gvh, v)), vh); if (n > k) { - v_term = v_term + at::matmul(sigma_mat_inv, at::matmul(gvt, in - at::matmul(v, vt))); + // projection operator onto subspace orthogonal to span(V) defined as I - VV^H + auto proj_on_v_ortho = -at::matmul(v, vh); + proj_on_v_ortho.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).add_(1); + v_term = v_term + sigma_inv.unsqueeze(-1) * at::matmul(gvh, proj_on_v_ortho); } v_term = at::matmul(u, v_term); } else { v_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } + // for complex-valued input there is an additional term + // https://giggleliu.github.io/2019/04/02/einsumbp.html + // https://arxiv.org/abs/1909.02659 + if (self.is_complex() && gu.defined()) { + Tensor L = at::matmul(uh, gu).diagonal(0, -2, -1); + at::real(L).zero_(); + at::imag(L).mul_(sigma_inv); + Tensor imag_term = at::matmul(u * L.unsqueeze(-2), vh); + return u_term + sigma_term + v_term + imag_term; + } + return u_term + sigma_term + v_term; } @@ -1810,101 +2109,103 @@ Tensor symeig_backward(const std::vector &grads, cons auto glambda = grads[0]; auto gv = grads[1]; - auto vt = v.transpose(-2, -1); + auto vh = v.conj().transpose(-2, -1); Tensor result; if (gv.defined()) { Tensor F = lambda.unsqueeze(-2) - lambda.unsqueeze(-1); F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY); F.pow_(-1); - F.mul_(at::matmul(vt, gv)); - result = at::matmul(v, at::matmul(F, vt)); + result = at::matmul(v, at::matmul(F * at::matmul(vh, gv), vh)); } else { result = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } if (glambda.defined()) { - result.add_(at::matmul(at::matmul(v, at::diag_embed(glambda, /*offset=*/0, /*dim1=*/-2, /*dim2=*/-1)), vt)); + glambda = glambda.to(self.dtype()); + // computes v @ diag(glambda) @ vh + Tensor glambda_term = at::matmul(v * glambda.unsqueeze(-2), vh); + if (at::inplaceIsVmapCompatible(result, glambda_term)) { + result.add_(glambda_term); + } else { + result = result + glambda_term; + } } - return result.add(result.transpose(-2, -1)).mul_(0.5); + return result.add(result.conj().transpose(-2, -1)).mul_(0.5); } -Tensor qr_backward(const std::vector &grads, const Tensor& self, - bool some, const Tensor& q, const Tensor& r){ +Tensor linalg_qr_backward(const std::vector &grads, const Tensor& self, + std::string mode, const Tensor& q, const Tensor& r){ + bool compute_q, reduced; + std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode); + TORCH_CHECK(compute_q, "The derivative of qr is not implemented when mode='r'. " + "Please use torch.linalg.qr(..., mode='reduced')"); + auto square_deep_case_backward = [](const Tensor& grad_Q, const Tensor& grad_R, const Tensor& A, const Tensor& Q, const Tensor& R) -> Tensor { - // For square and deep (tall) case we refer - // Walter, S.F and Lehmann, L., Algorithmic Differentiation of Linear - // Algebra Functions with Application in Optimum Experimental Design - // (Extended Version) The derivative for the QR decomposition is adapted - // from Eq. 42 of the above reference. - - // Compute R (R')^{T} + // For square and deep (tall) case we refer: + // Matthias Seeger, Asmus Hetzel, Zhenwen Dai, Eric Meissner, Neil D. Lawrence (2018). Auto-Differentiating Linear Algebra. + // https://arxiv.org/abs/1710.08717 Section 4.3 LQ Decomposition (Note that LQ decomposition is the transpose of QR decomposition) + // Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable Programming Tensor Networks. + // https://arxiv.org/abs/1903.09650 Section 3. QR factorization + // For derivations of complex-valued input case, see https://giggleliu.github.io/2019/04/02/einsumbp.html + + // Compute R grad_R^H Tensor R_term; if (grad_R.defined()) { - R_term = at::matmul(R, grad_R.transpose(-2, -1)); + R_term = at::matmul(R, grad_R.conj().transpose(-2, -1)); } else { // R is ... x N x N, grad_R is ... x N x N and grad_R.T is ... x N x N R_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } - // Compute Q^{T} Q' + // Compute grad_Q^H Q Tensor Q_term; if (grad_Q.defined()) { - Q_term = at::matmul(Q.transpose(-2, -1), grad_Q); + Q_term = at::matmul(grad_Q.conj().transpose(-2, -1), Q); } else { // Q is ... x M x N, Q.T is ... x N x M and grad_Q is ... x M x N Q_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } - // We want to compute: (rhs_solve_1 . R^{-T}) - // Note that (rhs_solve_1 . R^{-T}) = (R^{-1} . rhs_solve_1^{T})^{T} + Tensor M = R_term - Q_term; + + // Compute M = (tril(M) + tril(M).conj().transpose(-2, -1)) * 0.5 Identity + Tensor M_tril = at::tril(M); + M = M_tril + M_tril.conj().transpose(-2, -1); + M.diagonal(0, -2, -1).mul_(0.5); + + Tensor rhs_term; + if (grad_Q.defined()) { + rhs_term = grad_Q + at::matmul(Q, M); + } else { + rhs_term = at::matmul(Q, M); + } + + // We want to compute: (rhs_term @ R^{-H}) + // Note that (rhs_term @ R^{-H}) = (R^{-1} @ rhs_solve_1^H)^H // Since R is upper triangular, we can do this using - // triangular_solve(rhs_solve_1^{T}, R)^{T} - auto rhs_solve_1 = - R_term - R_term.transpose(-2, -1) + Q_term - Q_term.transpose(-2, -1); - rhs_solve_1 = at::tril(rhs_solve_1, /*k=*/-1); - Tensor solve_soln_1; - std::tie(solve_soln_1, std::ignore) = at::triangular_solve( - rhs_solve_1.transpose(-2, -1), + // triangular_solve(rhs_term^H, R)^H + Tensor grad_A; + std::tie(grad_A, std::ignore) = at::triangular_solve( + rhs_term.conj().transpose(-2, -1), R, /*upper=*/true, /*transpose=*/false, /*unitriangular=*/false); - Tensor grad_A; - if (grad_R.defined()) { - grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1) + grad_R); - } else { - grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1)); - } - // Successive computations involve computation of QQ^{T} which is identity when A is square - if (A.size(-1) != A.size(-2)) { - Tensor rhs_solve_2; - // We use the same trick from above for this computation - if (grad_Q.defined()) { - rhs_solve_2 = grad_Q - at::matmul(Q, Q_term); - } else { - rhs_solve_2 = -at::matmul(Q, Q_term); - } - Tensor solve_soln_2; - std::tie(solve_soln_2, std::ignore) = at::triangular_solve(rhs_solve_2.transpose(-2, -1), R, - /*upper=*/true, /*transpose=*/false, - /*unitriangular=*/false); - grad_A.add_(solve_soln_2.transpose(-2, -1)); - } - return grad_A; + return grad_A.conj().transpose(-2, -1); }; auto m = self.size(-2); auto n = self.size(-1); TORCH_CHECK( - ((m <= n && (!some)) || some), - "The derivative is not implemented when nrows > ncols and complete QR. "); + ((m <= n && (!reduced)) || reduced), + "The derivative of qr is not implemented when mode='complete' and nrows > ncols."); auto grad_Q = grads[0]; auto grad_R = grads[1]; @@ -1917,7 +2218,7 @@ Tensor qr_backward(const std::vector &grads, const Te // grad_R = [grad_U | grad_V] and grad_A = [grad_X | grad_Y]. // To obtain grad_X we reuse the gradient formula from the square case. // Formulae: grad_X = square_case_grad(grad_Q_prime, grad_U, Q, U), - // where grad_Q_prime = grad_Q + Y @ grad_V.T + // where grad_Q_prime = grad_Q + Y @ grad_V^H // and grad_Y = Q @ grad_V. // Then concatenate grads to get grad_A = [grad_X | grad_Y]. @@ -1929,8 +2230,8 @@ Tensor qr_backward(const std::vector &grads, const Te grad_V = grad_R.narrow(-1, m, n - m); // reuse grad_R to store grad_U grad_R = grad_R.narrow(-1, 0, m); - // grad_Q_prime starts with the value of Y @ grad_V.T - grad_Q_prime = at::matmul(Y, grad_V.transpose(-2, -1)); + // grad_Q_prime starts with the value of Y @ grad_V^H + grad_Q_prime = at::matmul(Y, grad_V.conj().transpose(-2, -1)); } else { // when grad_R is not defined then grad_V and grad_Q_prime // get initialized with zeros @@ -1971,15 +2272,17 @@ Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) return nonsingular_case_backward(grad, self, det); } } else { - auto nonzero_det_indices = at::where(det); + auto nonzero_det_indices = at::native::toListOfOptionalTensors(at::where(det)); + c10::optional first_nonzero_det_index = nonzero_det_indices[0]; - if (nonzero_det_indices[0].size(0) == det.numel()) { // all determinants are nonzero (non-singular) + if (first_nonzero_det_index->size(0) == det.numel()) { // all determinants are nonzero (non-singular) return nonsingular_case_backward(grad, self, det); } - auto zero_det_indices = at::where(det == 0); + auto zero_det_indices = at::native::toListOfOptionalTensors(at::where(det == 0)); + c10::optional first_zero_det_index = zero_det_indices[0]; - if (zero_det_indices[0].size(0) == det.numel()) { // all determinants are zero (singular) + if (first_zero_det_index->size(0) == det.numel()) { // all determinants are zero (singular) return singular_case_backward(grad, self, det); } @@ -2021,15 +2324,17 @@ Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& lo return singular_case_backward(grad, self); } } else { - auto finite_logdet_indices = at::where(logdet != -INFINITY); + auto finite_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet != -INFINITY)); + c10::optional first_finite_logdet_index = finite_logdet_indices[0]; - if (finite_logdet_indices[0].size(0) == logdet.numel()) { // all log determinants are finite (non-singular) + if (first_finite_logdet_index->size(0) == logdet.numel()) { // all log determinants are finite (non-singular) return nonsingular_case_backward(grad, self); } - auto neginf_logdet_indices = at::where(logdet == -INFINITY); + auto neginf_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet == -INFINITY)); + c10::optional first_neginf_logdet_index = neginf_logdet_indices[0]; - if (neginf_logdet_indices[0].size(0) == logdet.numel()) { // all log determinants are -inf (singular) + if (first_neginf_logdet_index->size(0) == logdet.numel()) { // all log determinants are -inf (singular) return singular_case_backward(grad, self); } @@ -2054,6 +2359,7 @@ Tensor slogdet_backward(const Tensor& grad_logabsdet, const Tensor& signdet, const Tensor& logabsdet) { auto singular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor { Tensor u, sigma, v; + // TODO: replace self.svd with linalg_svd std::tie(u, sigma, v) = self.svd(); // sigma has all non-negative entries (also with at least one zero entry) // so logabsdet = \sum log(abs(sigma)) @@ -2063,25 +2369,29 @@ Tensor slogdet_backward(const Tensor& grad_logabsdet, }; auto nonsingular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor { - return unsqueeze_multiple(grad_logabsdet, {-1, -2}, self.dim()) * self.inverse().transpose(-2, -1); + // TODO: replace self.inverse with linalg_inverse + return unsqueeze_multiple(grad_logabsdet, {-1, -2}, self.dim()) * self.inverse().conj().transpose(-2, -1); }; if (self.dim() == 2) { - if (signdet.item() == 0) { + bool is_singular = self.is_complex() ? signdet.abs().item() == 0 : signdet.item() == 0; + if (is_singular) { return singular_case_backward(grad_logabsdet, self); } else { return nonsingular_case_backward(grad_logabsdet, self); } } else { - auto nonzero_signdet_indices = at::where(signdet); + auto nonzero_signdet_indices = at::native::toListOfOptionalTensors(self.is_complex() ? at::where(signdet.abs()) : at::where(signdet)); + c10::optional first_nonzero_signdet_index = nonzero_signdet_indices[0]; - if (nonzero_signdet_indices[0].size(0) == logabsdet.numel()) { // all log determinants are finite (non-singular) + if (first_nonzero_signdet_index->size(0) == logabsdet.numel()) { // all log determinants are finite (non-singular) return nonsingular_case_backward(grad_logabsdet, self); } - auto zero_signdet_indices = at::where(signdet == 0); + auto zero_signdet_indices = at::native::toListOfOptionalTensors(at::where(signdet == 0)); + c10::optional first_zero_signdet_index = zero_signdet_indices[0]; - if (zero_signdet_indices[0].size(0) == logabsdet.numel()) { // all log determinants are -inf (singular) + if (first_zero_signdet_index->size(0) == logabsdet.numel()) { // all log determinants are -inf (singular) return singular_case_backward(grad_logabsdet, self); } @@ -2112,9 +2422,9 @@ std::tuple triangular_solve_backward( Tensor grad_b, grad_a; if (grad_x.defined() || grad_m.defined()) { if (grad_x.defined()) { - grad_b = std::get<0>(grad_x.triangular_solve(a, upper, !transpose, unitriangular)); + grad_b = std::get<0>(grad_x.triangular_solve(a.conj(), upper, !transpose, unitriangular)); if (output_mask[1]) { - grad_a = transpose ? -x.matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2)); + grad_a = transpose ? -x.conj().matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2).conj()); if (upper) { grad_a = grad_a.triu((int) unitriangular); } else { @@ -2142,8 +2452,8 @@ std::tuple cholesky_solve_backward( if (grad_x.defined()) { grad_self = grad_x.cholesky_solve(input2, /*upper=*/upper); - Tensor common_term = at::matmul(grad_self, result.transpose(-2, -1)); - common_term = common_term + common_term.transpose(-2, -1); + Tensor common_term = at::matmul(grad_self, result.conj().transpose(-2, -1)); + common_term = common_term + common_term.conj().transpose(-2, -1); if (upper) { grad_input2 = -at::matmul(input2, common_term); @@ -2154,94 +2464,72 @@ std::tuple cholesky_solve_backward( return std::tuple{grad_self, grad_input2}; } -// Generally speaking, fft's backward is ifft. -Tensor fft_backward(const Tensor& self, const Tensor& grad, int64_t signal_ndim, - bool complex_input, bool complex_output, - bool inverse, IntArrayRef checked_signal_sizes, - int64_t normalization, bool onesided, - IntArrayRef output_sizes) { - Tensor gI; - if (!complex_input && complex_output) { - // Forward is R2C - // Do inverse C2C and project onto real plane because grad can be - // asymmetrical so C2R can't be used. - if (onesided) { - // Forward is R2C (onesided) - // Think of onesided R2C rfft as - // 1. view as complex numbers (fill complex dim with zeros) - // 2. C2C fft - // 3. discard half of results - // So backward is - // 1. fill the other half with zeros (with `zero_grad_shape` below) - // (C2C ifft only take twosided inputs so we need to fill here) - // 2. inverse C2C ifft - // 3. discard the complex dim - int64_t zero_length = checked_signal_sizes[signal_ndim - 1] - grad.size(signal_ndim); - auto complex_full_grad = grad; - if (zero_length > 0) { - std::vector zero_grad_shape(signal_ndim + 2); - zero_grad_shape[0] = self.size(0); - for (int64_t i = 1; i < signal_ndim; i++) { - zero_grad_shape[i] = checked_signal_sizes[i - 1]; - } - zero_grad_shape[signal_ndim] = zero_length; - zero_grad_shape[signal_ndim + 1] = 2; - complex_full_grad = at::cat({ grad, at::zeros(zero_grad_shape, grad.options()) }, signal_ndim); - } - gI = _fft_with_size(complex_full_grad, signal_ndim, - /* complex_input */ true, /* complex_output */ true, - !inverse, checked_signal_sizes, normalization, - /* onesided */ false, complex_full_grad.sizes()).select(-1, 0); - } else { - gI = _fft_with_size(grad, signal_ndim, /* complex_input */ true, - /* complex_output */ true, !inverse, - checked_signal_sizes, normalization, - /* onesided */ false, grad.sizes()).select(-1, 0); - } - } else if (complex_input && !complex_output && onesided) { - // Forward is C2R (onesided) - // Think of onesided C2R irfft as - // 1. fill the other half by conjugate symmetry - // 2. inverse C2C ifft - // 3. discard the complex dimension - // So backward is - // 1. R2C rfft (essentially add dummy complex dimension, and dft) - // 2. accumulate gradient by conjugate symmetry - // since rfft results follow conjugate symmetry, we only need to - // double some entries from onesided rfft results, i.e., the ones with - // their reflected indices also landing out of the onesided range. So - // consider the index of last dim: - // i. idx = 0. - // Reflected to (N - 0) % N = 0. Not doubled. - // ii 0 < idx < floor(N/2) (last). - // N > N - idx > ceil(N/2) - // Reflected to () - // iii. idx = floor(N/2) = N/2 (last) when N even. - // Reflected to (N - N/2) % N = N/2. Not doubled. - // iv. idx = floor(N/2) = (N-1)/2 (last) when N odd. - // Reflected to (N - (N-1)/2) % N = (N+1)/2. Doubled. - // Therefore, needs to double - // idx = 1, 2, ..., N/2 - 1 when N even - // idx = 1, 2, ..., (N-1)/2 when N odd - // that is - // idx = 1, 2, ..., N - (floor(N/2) + 1) - // = 1, 2, ..., N - onesided_length - gI = _fft_with_size(grad, signal_ndim, /* complex_input */ false, - /* complex_output */ true, /* inverse */ false, - checked_signal_sizes, normalization, /* onesided */ true, - self.sizes()); - int64_t double_length = checked_signal_sizes[signal_ndim - 1] - self.size(signal_ndim); - if (double_length > 0) { // also covers case when signal size is zero - gI.narrow(signal_ndim, 1, double_length).mul_(2); - } - } else { - gI = _fft_with_size(grad, signal_ndim, complex_output, complex_input, - !inverse, checked_signal_sizes, normalization, onesided, - self.sizes()); +Tensor fft_c2r_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization) { + // Forward is C2R (onesided) + // Think of onesided C2R irfft as + // 1. fill the other half by conjugate symmetry + // 2. inverse C2C ifft + // 3. discard the complex dimension + // So backward is + // 1. R2C rfft (essentially add dummy complex dimension, and dft) + // 2. accumulate gradient by conjugate symmetry + // since rfft results follow conjugate symmetry, we only need to + // double some entries from onesided rfft results, i.e., the ones with + // their reflected indices also landing out of the onesided range. So + // consider the index of last dim: + // i. idx = 0. + // Reflected to (N - 0) % N = 0. Not doubled. + // ii 0 < idx < floor(N/2) (last). + // N > N - idx > ceil(N/2) + // Reflected to () + // iii. idx = floor(N/2) = N/2 (last) when N even. + // Reflected to (N - N/2) % N = N/2. Not doubled. + // iv. idx = floor(N/2) = (N-1)/2 (last) when N odd. + // Reflected to (N - (N-1)/2) % N = (N+1)/2. Doubled. + // Therefore, needs to double + // idx = 1, 2, ..., N/2 - 1 when N even + // idx = 1, 2, ..., (N-1)/2 when N odd + // that is + // idx = 1, 2, ..., N - (floor(N/2) + 1) + // = 1, 2, ..., N - onesided_length + auto gI = at::_fft_r2c(grad, dim, normalization, /*onesided=*/true); + + auto double_length = grad.size(dim.back()) - gI.size(dim.back()); + if (double_length > 0) { // also covers case when signal size is zero + gI.narrow(dim.back(), 1, double_length).mul_(2); } return gI; } +Tensor fft_r2c_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization, + bool onesided, int64_t last_dim_size) { + if (!onesided) { + return at::real(at::_fft_c2c(grad, dim, normalization, /*forward=*/false)); + } + + // Forward is R2C (onesided) + // Think of onesided R2C rfft as + // 1. view as complex numbers (fill complex dim with zeros) + // 2. C2C fft + // 3. discard half of results + // So backward is + // 1. fill the other half with zeros (with `zero_grad_shape` below) + // (C2C ifft only take twosided inputs so we need to fill here) + // 2. inverse C2C ifft + // 3. discard the complex dim + auto half_sizes = grad.sizes(); + at::DimVector new_grad_shape(half_sizes.begin(), half_sizes.end()); + const auto last_dim = at::maybe_wrap_dim(dim.back(), half_sizes.size()); + new_grad_shape[last_dim] = last_dim_size; + + const auto zero_length = last_dim_size - grad.size(dim.back()); + auto complex_full_grad = zero_length > 0 ? at::zeros(new_grad_shape, grad.options()) : grad; + if (zero_length > 0) { + complex_full_grad.slice(last_dim, 0, half_sizes[last_dim]).copy_(grad); + } + return at::real(at::_fft_c2c(complex_full_grad, dim, normalization, /*forward=*/false)); +} + // Helper for batchnorm_double_backward Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim=true) { auto r = to_sum.sum(0, keepdim); @@ -2412,10 +2700,19 @@ infinitely_differentiable_native_layer_norm_backward( const Tensor& mean, const Tensor& rstd, const c10::optional& gamma, - int64_t M, - int64_t N, + IntArrayRef normalized_shape, double eps, std::array grad_input_mask) { + + const int normalized_ndim = normalized_shape.size(); + const auto input_shape = X.sizes(); + const auto input_ndim = X.dim(); + const int axis = input_ndim - normalized_ndim; + const int64_t M = + at::prod_intlist(input_shape.cbegin(), input_shape.cbegin() + axis); + const int64_t N = + at::prod_intlist(input_shape.cbegin() + axis, input_shape.cend()); + Tensor dX; Tensor dgamma; Tensor dbeta; @@ -2610,7 +2907,7 @@ Tensor log1p_backward(const Tensor& grad, const Tensor& self) { "Use a different mathematical operation which preserves sparsity of gradients, ", "or report a bug if you think this is an error."); } - return grad / (self + 1); + return grad / (self + 1).conj(); } Tensor sparse_constructor_values_backward(const Tensor& sparse_grad_out, const Tensor& indices, IntArrayRef values_shape) { @@ -2631,21 +2928,23 @@ Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) { return at::constant_pad_nd(grad, negated_pad, 0); } -Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices) { - // since first backward takes care of padding_idx - // and scaling by frequency, we don't need to worry - // about it here. +Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) { + // since first backward takes care of scaling by frequency, + // we don't need to worry about it here. auto gg_weight = grad.index_select(0, indices.reshape(-1)); // reshape gradient as per the shape of indices auto size = indices.sizes().vec(); size.push_back(-1); + if (padding_idx >= 0) { + gg_weight.masked_fill_((indices == padding_idx).reshape({-1, 1}), 0); + } return gg_weight.view(size); } -Tensor index_backward(Tensor zeros_like_self, TensorList indices, const Tensor& grad) { - return at::_index_put_impl_(zeros_like_self, indices, grad, true, true); +Tensor index_backward(Tensor zeros_like_self, const torch::List>& indices, const Tensor& grad) { + return at::_index_put_impl_(zeros_like_self, indices, grad, true, true); } Tensor _cudnn_ctc_loss_backward(const Tensor& grad_out, const Tensor& loss, const Tensor& raw_grad, bool zero_infinity) { diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 1e2af0772b523..4cce8cfd22c8a 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -17,6 +17,8 @@ namespace autograd { namespace generated { namespace details { +extern const char* kCudnnDoubleBackwardMsg; + // A simple way to imperatively compute index ranges for slots // that have been flattened struct IndexRangeGenerator { @@ -29,10 +31,17 @@ struct IndexRangeGenerator { size_t i = 0; }; +bool isFwGradDefined(const c10::optional& t); +Tensor toLegacyFwGrad(const c10::optional& t); +Tensor toLegacyPrimal(const c10::optional& t); + bool any_variable_defined(variable_list& variables); void copy_range(variable_list& out, IndexRange range, const at::Tensor & t); void copy_range(variable_list& out, IndexRange range, at::ArrayRef t); -at::Tensor not_implemented(const char* name); +at::Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, const Tensor & result); +at::Tensor not_implemented(const char* name, const char* reason=""); +std::vector not_implemented_list(const char* name, const char* reason=""); +at::Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result); at::Tensor maybe_multiply(const at::Tensor & t, const at::Scalar & s); int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim); Tensor restore_reduced_dims(const Tensor &output, IntArrayRef dims, bool keepdim); @@ -43,6 +52,7 @@ at::Tensor pow_backward(at::Tensor grad, const at::Tensor & self, const at::Scal at::Tensor pow_backward_self(at::Tensor grad, const at::Tensor & self, const at::Tensor & exponent); at::Tensor pow_backward_exponent(at::Tensor grad, const at::Tensor& self, const at::Tensor& exponent, at::Tensor result); at::Tensor pow_backward_exponent(at::Tensor grad, const at::Scalar & base, const at::Tensor& exponent, at::Tensor result); +at::Tensor angle_backward(at::Tensor grad, const at::Tensor& self); at::Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st); at::Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st); at::Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other); @@ -68,11 +78,12 @@ at::Tensor unsqueeze_to(const at::Tensor & self, at::IntArrayRef sizes); at::Tensor unsqueeze_to(const at::Tensor & self, int64_t dim, at::IntArrayRef sizes); std::vector cat_tensors_backward(const at::Tensor & grad, const std::vector> &sizes, int64_t dim); at::Tensor clamp_backward(const at::Tensor & grad, const at::Tensor &self, const optional & min, const optional & max); -at::Tensor mm_mat1_backward(const at::Tensor & grad, const at::Tensor & mat2, const at::Tensor & mat1, const at::Scalar & alpha); +at::IntArrayRef strides_or_error(const Tensor & input, c10::string_view const & input_name); +at::Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, at::IntArrayRef mat1_sizes, at::IntArrayRef mat1_strides, const Scalar & alpha); at::Tensor mm_mat2_backward(const at::Tensor & grad, const at::Tensor & mat1, at::IntArrayRef sizes, at::IntArrayRef strides, const at::Scalar & alpha); at::Tensor _sparse_addmm_sparse_backward(const at::Tensor& grad, const at::Tensor& sparse_, const at::Tensor& dense, const at::Scalar& alpha); +at::Tensor sparse_sparse_matmul_backward(const at::Tensor& grad, const at::Tensor& mat1, const at::Tensor& mat2,int64_t grad_order); at::Tensor renorm_backward(const at::Tensor & grad, const at::Tensor & self, at::Scalar p, int64_t dim, at::Scalar maxnorm); -at::Tensor sum_tensorlist(at::TensorList tl); at::Tensor repeat_backward(at::Tensor grad, at::IntArrayRef repeats, at::IntArrayRef input_shape); at::Tensor _fused_dropout_backward(at::Tensor grad, at::Tensor mask, double p1m); at::Tensor evenly_distribute_backward(at::Tensor grad, const at::Tensor & input, const at::Tensor & value); @@ -103,9 +114,10 @@ at::Tensor softmax_double_backward(const at::Tensor & grad, const at::Tensor & g at::Tensor log_softmax_double_backward(const at::Tensor & grad, const at::Tensor & grad_output, int dim, const at::Tensor & output); at::Tensor binary_cross_entropy_double_backward(const at::Tensor & grad_output, const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, const c10::optional& weight, int64_t reduction); at::Tensor binary_cross_entropy_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, const c10::optional& weight, int64_t reduction); -at::Tensor l1_loss_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, int64_t reduction); -at::Tensor smooth_l1_loss_double_backward(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, int64_t reduction); -at::Tensor smooth_l1_loss_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & target, int64_t reduction); +at::Tensor l1_loss_double_backward(const at::Tensor & grad, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & target, int64_t reduction); +at::Tensor l1_loss_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & target, int64_t reduction); +at::Tensor smooth_l1_loss_double_backward(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, int64_t reduction, double beta); +at::Tensor smooth_l1_loss_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & target, int64_t reduction, double beta); at::Tensor mse_loss_double_backward(const at::Tensor & grad, const at::Tensor & input, int64_t reduction); at::Tensor mse_loss_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & target, int64_t reduction); at::Tensor soft_margin_loss_double_backward(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, int64_t reduction); @@ -115,12 +127,20 @@ at::Tensor logdet_backward(const at::Tensor & grad, const at::Tensor& self, cons at::Tensor slogdet_backward(const at::Tensor& grad_logabsdet, const at::Tensor& self, const at::Tensor& signdet, const at::Tensor& logabsdet); at::Tensor log1p_backward(const at::Tensor& grad, const at::Tensor& self); at::Tensor sparse_constructor_values_backward(const at::Tensor& sparse_grad_out, const at::Tensor& indices, at::IntArrayRef values_shape); -at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices); -at::Tensor index_backward(at::Tensor zeros_like_self, at::TensorList indices, const at::Tensor& grad); +at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices, int64_t padding_idx); +at::Tensor index_backward(at::Tensor zeros_like_self, const torch::List>& indices, const at::Tensor& grad); at::Tensor _cudnn_ctc_loss_backward(const at::Tensor& grad_out, const at::Tensor& loss, const at::Tensor& raw_grad, bool zero_infinity); +at::Tensor elu_double_backward(const Tensor& grad, const Tensor& grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, const Tensor& self_or_result); Tensor svd_backward(const std::vector &grads, const Tensor& self, bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v); +Tensor slice_backward_wrapper( + const at::Tensor& grad, + const c10::IntArrayRef& input_sizes, + int64_t dim, + c10::optional start, + c10::optional end, + int64_t step); Tensor symeig_backward(const std::vector &grads, const Tensor& self, bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v); std::tuple triangular_solve_backward( @@ -131,8 +151,8 @@ std::tuple triangular_solve_backward( std::tuple _trilinear_backward(const Tensor& grad_out, const Tensor& i1, const Tensor& i2, const Tensor& i3, IntArrayRef expand1, IntArrayRef expand2, IntArrayRef expand3, IntArrayRef sumdim, int64_t unroll_dim, std::array grad_mask); -Tensor qr_backward(const std::vector &grads, const Tensor& self, - bool some, const Tensor& Q, const Tensor& R); +Tensor linalg_qr_backward(const std::vector &grads, const Tensor& self, + std::string mode, const Tensor& Q, const Tensor& R); Tensor eig_backward(const std::vector &grads, const Tensor& self, bool eigenvectors, const Tensor& lambda, const Tensor& v); Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det); @@ -157,6 +177,9 @@ Tensor fft_backward(const Tensor& self, const Tensor& grad, int64_t signal_ndim, bool inverse, IntArrayRef checked_signal_sizes, int64_t normalization, bool onesided, IntArrayRef output_sizes); +Tensor fft_r2c_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization, + bool onesided, int64_t last_dim_size); +Tensor fft_c2r_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization); Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad); std::tuple cholesky_solve_backward( const Tensor& grad_x, const Tensor& self, @@ -193,8 +216,7 @@ infinitely_differentiable_native_layer_norm_backward( const Tensor& mean, const Tensor& rstd, const c10::optional& gamma, - int64_t M, - int64_t N, + IntArrayRef normalized_shape, double eps, std::array grad_input_mask); diff --git a/torch/csrc/autograd/TraceTypeManual.cpp b/torch/csrc/autograd/TraceTypeManual.cpp index 625f3c40e9a13..148725eecdea0 100644 --- a/torch/csrc/autograd/TraceTypeManual.cpp +++ b/torch/csrc/autograd/TraceTypeManual.cpp @@ -131,7 +131,7 @@ TORCH_LIBRARY_IMPL(aten, Tracer, m) { m.impl("copy_", copy_); // Skip tracing for the following ops by registering fallthrough kernel explicitly. - m.impl("backward", CppFunction::makeFallthrough()); + m.impl("_backward", CppFunction::makeFallthrough()); m.impl("set_data", CppFunction::makeFallthrough()); m.impl("data", CppFunction::makeFallthrough()); m.impl("is_leaf", CppFunction::makeFallthrough()); @@ -139,6 +139,7 @@ TORCH_LIBRARY_IMPL(aten, Tracer, m) { m.impl("_version", CppFunction::makeFallthrough()); m.impl("requires_grad_", CppFunction::makeFallthrough()); m.impl("retain_grad", CppFunction::makeFallthrough()); + m.impl("_fw_primal", CppFunction::makeFallthrough()); } } // namespace diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index 9dfc4573188a6..667064c2e1940 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -1,11 +1,13 @@ #include #include #include +#include #include #include #include #include #include +#include using namespace at; using namespace torch::autograd::generated; @@ -64,10 +66,6 @@ Tensor unpack_opt(const Tensor & t, const char * name, int pos) { return unpack(t, name, pos); } -c10::optional unpack_opt(const c10::optional & t, const char * name, int pos) { - return t; -} - std::vector unpack(at::TensorList tl, const char *name, int pos) { std::vector ret(tl.size()); for (size_t i = 0; i < tl.size(); ++i) { @@ -82,15 +80,17 @@ std::vector unpack(at::TensorList tl, const char *name, int pos) { namespace { -void backward( +void _backward( const Tensor& self, + TensorList inputs, const c10::optional& gradient, c10::optional keep_graph, bool create_graph) { // TODO torch::autograd::backward should take the c10::optional gradient directly // instead of us having to unwrap it to Tensor _gradient here. Tensor _gradient = gradient.has_value() ? *gradient : Tensor(); - torch::autograd::backward({self}, {_gradient}, std::move(keep_graph), create_graph); + std::vector input_vars(inputs.begin(), inputs.end()); + torch::autograd::backward({self}, {_gradient}, keep_graph, create_graph, input_vars); } void set_data(Tensor & self, const Tensor & new_data) { @@ -191,17 +191,49 @@ void retain_grad(Tensor & self) { impl::get_autograd_meta(self)->retains_grad_ = true; } +// Taken from codegened version +Tensor _fw_primal(const Tensor & self, int64_t level) { + auto& self_ = unpack(self, "self", 0); + std::shared_ptr grad_fn; + if (compute_requires_grad( self )) { + grad_fn = std::make_shared(); + grad_fn->set_next_edges(collect_next_edges( self )); + } + auto tmp = ([&]() { + at::AutoNonVariableTypeMode non_var_type_mode(true); + return self_.alias(); + })(); + std::function func=nullptr; + if (!self.unsafeGetTensorImpl()->support_as_strided()) { + auto size_vec = self.sizes().vec(); + func = [=](const at::Tensor& input_base) { + return input_base.view(size_vec); + }; + } + auto result = as_view(/* base */ self, /* output */ tmp, /* is_bw_differentiable */ true, + /* is_fw_differentiable */ false, /* view_func */ func, /* creation_meta */ CreationMeta::DEFAULT); + if (grad_fn) { + set_history(flatten_tensor_args( result ), grad_fn); + } + if (generated::details::isFwGradDefined(self)) { + // Modified from original codegen + // We explicitly want to ignore the forward grad at the given level + TORCH_CHECK(level == 0, "Invalid level given to _fw_primal"); + // End modified from original codegen + } + return result; +} + // We don't have an outplace copy, so this can't be generated automatically Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) { - jit::Value* output = nullptr; // TODO: once copy is exposed in Declarations.yaml we may be able to bind // it automatically auto& self_ = unpack(self, "self", 0); auto& src_ = unpack(src, "src", 1); - check_inplace(self); std::shared_ptr grad_fn; auto requires_grad = compute_requires_grad(self, src); requires_grad &= isDifferentiableType(self.scalar_type()); + check_inplace(self, requires_grad); if (requires_grad) { grad_fn = std::make_shared(); grad_fn->set_next_edges(collect_next_edges(self, src)); @@ -214,6 +246,24 @@ Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) { } increment_version(self); rebase_history(self , std::move(grad_fn)); + + if (isDifferentiableType(self.scalar_type()) && + (generated::details::isFwGradDefined(self) || generated::details::isFwGradDefined(src))) { + auto self_fw_grad = generated::details::toLegacyFwGrad(self); + auto src_fw_grad = generated::details::toLegacyFwGrad(src); + Tensor new_fw_grad; + if (self_fw_grad.defined()) { + if (src_fw_grad.defined()) { + new_fw_grad = self_fw_grad.copy_(src_fw_grad); + } else { + new_fw_grad = self_fw_grad.fill_(0); + } + } else { + new_fw_grad = src_fw_grad; + } + self.set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ true); + } + return self; } @@ -227,8 +277,13 @@ Tensor& resize_( } { at::AutoNonVariableTypeMode non_var_type_mode(true); - self_.resize_(size, std::move(optional_memory_format)); + self_.resize_(size, optional_memory_format); + } + + if (self.fw_grad(/* level */ 0).defined()) { + AT_ERROR("cannot resize variables that has a forward grad"); } + return self; } @@ -243,15 +298,30 @@ Tensor& resize_as_( } { at::AutoNonVariableTypeMode non_var_type_mode(true); - at::resize_as_(self_, the_template_, std::move(optional_memory_format)); + at::resize_as_(self_, the_template_, optional_memory_format); + } + + // Handle fw grad + if (self.fw_grad(/* level */ 0).defined()) { + AT_ERROR("cannot resize variables that has a forward grad"); } return self; } Tensor detach(const Tensor & self) { RECORD_FUNCTION("detach", std::vector({self})); - auto result = make_variable_non_differentiable_view(self, self, /*allow_tensor_metadata_change=*/false); + std::function func=nullptr; + auto result = as_view(/* base */ self, /* output */ self, /* is_bw_differentiable */ false, + /* is_fw_differentiable */ true, /* view_func */ func, /* creation_meta */ CreationMeta::DEFAULT, + /*allow_tensor_metadata_change=*/false); namedinference::propagate_names(result, self); + + // detach only backward gradients for both primal and tangent + if (self.fw_grad(/* level */ 0).defined()) { + auto new_fw_grad = self.fw_grad(/* level */ 0).detach(); + result.set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ false); + } + return result; } @@ -261,7 +331,7 @@ Tensor & detach_(Tensor & self) { // NB: is_view() ==> get_autograd_meta() auto diff_view_meta = static_cast(torch::autograd::impl::get_autograd_meta(self)); // See NOTE [ View + Inplace detection ] - if (diff_view_meta->creation_meta == CreationMeta::MULTI_OUTPUT_SAFE) { + if (diff_view_meta->get_creation_meta() == CreationMeta::MULTI_OUTPUT_SAFE) { TORCH_WARN("This view is an output of a function that " "returns multiple views. Detaching such views inplace " "is being deprecated and will be forbidden " @@ -269,7 +339,8 @@ Tensor & detach_(Tensor & self) { "of detach_(). Alternatively, create this view with an " "`unsafe_` version of the function that produced it."); } else { - AT_ERROR("If you are using DistributedDataParallel (DDP) for training, " + AT_ERROR("Can't detach views in-place. Use detach() instead. " + "If you are using DistributedDataParallel (DDP) for training, " "and gradient_as_bucket_view is set as True, gradients are " "views of DDP buckets, and hence detach_() cannot be called " "on these gradients. To fix this error, please refer to the " @@ -287,90 +358,51 @@ Tensor & detach_(Tensor & self) { autograd_meta->set_requires_grad(false, self.unsafeGetTensorImpl()); autograd_meta->grad_fn_.reset(); autograd_meta->output_nr_ = 0; + + // detach only backward gradients for both primal and tangent + if (self.fw_grad(/* level */ 0).defined()) { + self.fw_grad(/* level */ 0).detach_(); + } + return self; } -// Some ops in the following registration list are registered as catch-all kernels, -// some as catch-all kernels and additionally as backend kernels for Autograd. -// The reason for this is that ops that also use dispatch (e.g. register CPU/CUDA/QuantizedCPU -// kernels) need to get a separate Autograd kernel instead of a catch-all kernel, -// otherwise we won't ever call it for CPU/CUDA/QuantizedCPU tensors, because the backend -// kernel has a higher priority than catch-all kernels. -// Unfortunately, this setup doesn't work in NonVariableTypeMode because that will -// skip past variable kernels. So for ops that we want to use in NonVariableTypeMode -// (and that don't use dispatch), we register them as catch-all kernels instead. +// Ops in the following registration list are registered as +// (1) Math kernels +// (2) Autograd kernels +// (3) DefaultBackend kernels and additionally Autograd kernels +// The reason for (3) is that ops that also use dispatch (e.g. register CPU/CUDA/QuantizedCPU +// kernels) will skip picking up Math kernels for Autograd, so we register them to both +// DefaultBackend and Autograd instead. See +// https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword +// for more details. // Invariant: -// - Ops registered to catchAll below must match `MANUAL_CATCHALL` set in tools/autograd/gen_variable_type.py. +// - Ops registered to Math or DefaultBackend below must match `MANUAL_BACKEND` set in tools/autograd/gen_variable_type.py. // and they have manual_kernel_registration=True in native_functions.yaml. // - Ops registered to DispatchKey::Autograd below must be included in `MANUAL_AUTOGRAD` in tools/autograd/gen_variable_type.py -static auto registry = torch::RegisterOperators() - .op(torch::RegisterOperators::options() - .schema("aten::resize_(Tensor(a!) self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)") - .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA) - .kernel(DispatchKey::Autograd)) - .op(torch::RegisterOperators::options() - .schema("aten::resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!)") - .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA) - .kernel(DispatchKey::Autograd)) - .op(torch::RegisterOperators::options() - .schema("aten::detach(Tensor(a) self) -> Tensor(a)") - .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA) - .kernel(DispatchKey::Autograd)) - .op(torch::RegisterOperators::options() - .schema("aten::detach_(Tensor(a!) self) -> Tensor(a!)") - .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA) - .kernel(DispatchKey::Autograd)) - .op(torch::RegisterOperators::options() - .schema("aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)") - .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA) - .kernel(DispatchKey::Autograd)) - .op(torch::RegisterOperators::options() - .schema("aten::backward(Tensor self, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()") - .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA) - // For backward(), we need the catch-all kernel (see comment above), but we also need the Autograd backend - // kernel, because when called with a VariableTensorId tensor, it goes through the variable fallback kernel, - // which calls callBoxed(), which doesn't support optional tensor arguments yet and backward() has an optional - // tensor argument. - // TODO Once callBoxed() supports optional tensor arguments, we can enable `use_c10_dispatcher: full` for backward() - // and remove the backend Autograd kernel here, only leaving the catch-all kernel. - .kernel(DispatchKey::Autograd) - .catchAllKernel()) - .op(torch::RegisterOperators::options() - .schema("aten::set_data(Tensor(a!) self, Tensor new_data) -> ()") - .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA) - .catchAllKernel()) - .op(torch::RegisterOperators::options() - .schema("aten::data(Tensor self) -> Tensor") - .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA) - .catchAllKernel()) - .op(torch::RegisterOperators::options() - .schema("aten::is_leaf(Tensor self) -> bool") - .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA) - .catchAllKernel()) - .op(torch::RegisterOperators::options() - .schema("aten::output_nr(Tensor self) -> int") - .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA) - .catchAllKernel()) - .op(torch::RegisterOperators::options() - .schema("aten::_version(Tensor self) -> int") - .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA) - .catchAllKernel()) - .op(torch::RegisterOperators::options() - .schema("aten::requires_grad_(Tensor(a!) self, bool requires_grad=True) -> Tensor(a!)") - .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA) - // For requires_grad_(), we need the catch-all kernel (see comment above), but we also need the Autograd backend - // kernel, because when called with a VariableTensorId tensor, it goes through the variable fallback kernel, - // which calls callBoxed(), which doesn't support mutable tensor arguments yet and requires_grad_() has a mutable - // tensor argument. - // TODO Once callBoxed() supports mutable tensor arguments, we can enable `use_c10_dispatcher: full` for requires_grad_() - // and remove the backend Autograd kernel here, only leaving the catch-all kernel. - .kernel(DispatchKey::Autograd) - .catchAllKernel()) - .op(torch::RegisterOperators::options() - .schema("aten::retain_grad(Tensor(a!) self) -> ()") - .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA) - .catchAllKernel()) - ; + +TORCH_LIBRARY_IMPL(aten, Autograd, m) { + m.impl("resize_", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::resize_))); + m.impl("resize_as_", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::resize_as_))); + m.impl("detach", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::detach))); + m.impl("detach_", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::detach_))); + m.impl("copy_", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::copy_))); + m.impl("_fw_primal", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::_fw_primal))); +} + +TORCH_LIBRARY_IMPL(aten, DefaultBackend, m) { + m.impl("_backward", torch::dispatch(DispatchKey::DefaultBackend, TORCH_FN(VariableType::_backward))); + m.impl("requires_grad_", torch::dispatch(DispatchKey::DefaultBackend, TORCH_FN(VariableType::requires_grad_))); +} + +TORCH_LIBRARY_IMPL(aten, Math, m) { + m.impl("set_data", torch::dispatch(DispatchKey::Math, TORCH_FN(VariableType::set_data))); + m.impl("data", torch::dispatch(DispatchKey::Math, TORCH_FN(VariableType::data))); + m.impl("is_leaf", torch::dispatch(DispatchKey::Math, TORCH_FN(VariableType::is_leaf))); + m.impl("output_nr", torch::dispatch(DispatchKey::Math, TORCH_FN(VariableType::output_nr))); + m.impl("_version", torch::dispatch(DispatchKey::Math, TORCH_FN(VariableType::_version))); + m.impl("retain_grad", torch::dispatch(DispatchKey::Math, TORCH_FN(VariableType::retain_grad))); +} } // namespace }}} // namespace torch::autograd::VariableType diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h index 692972533adcc..2894a75fed69c 100644 --- a/torch/csrc/autograd/VariableTypeUtils.h +++ b/torch/csrc/autograd/VariableTypeUtils.h @@ -39,25 +39,35 @@ using namespace torch::autograd::generated; namespace torch { namespace autograd { -inline void check_inplace(const Tensor& tensor) { - auto& var = static_cast(tensor); - if (var.requires_grad() && GradMode::is_enabled()) { - if (var.is_view()) { +// The requires_grad argument is used to know if the inplace operation needs +// gradient to be setup for it. +// In particular, we can have tensor.requires_grad() != requires_grad when writing +// a Tensor that requires gradients inplace into a Tensor that does not require gradients: +// a = torch.rand(2) +// b = torch.rand(2, requires_grad=True) +// a.copy_(b) +inline void check_inplace(const Tensor& tensor, bool requires_grad) { + if (requires_grad && GradMode::is_enabled()) { + if (tensor.is_view()) { // NB: is_view() ==> get_autograd_meta() - auto diff_view_meta = static_cast(impl::get_autograd_meta(var)); + auto diff_view_meta = static_cast(impl::get_autograd_meta(tensor)); // This can throw or warn handle_view_on_rebase(diff_view_meta); + if (tensor.requires_grad() && tensor._base().is_leaf()) { + AT_ERROR( + "a view of a leaf Variable that requires grad is being used in an in-place operation."); + } } - if (var.is_leaf()) { + if (tensor.requires_grad() && tensor.is_leaf()) { AT_ERROR( "a leaf Variable that requires grad is being used in an in-place operation."); } } } -inline void check_inplace(const TensorList tensors) { +inline void check_inplace(const TensorList tensors, bool requires_grad) { for (const auto& tensor : tensors) { - check_inplace(tensor); + check_inplace(tensor, requires_grad); } } @@ -67,6 +77,19 @@ inline void throw_error_out_requires_grad(const char* name) { "but one of the arguments requires grad."); } +inline void throw_error_for_complex_autograd(const Tensor& tensor, const char* name) { + if (tensor.requires_grad()) { + TORCH_CHECK(!tensor.is_complex(), name, + " does not support automatic differentiation for outputs with complex dtype."); + } +} + +inline void throw_error_for_complex_autograd(const TensorList& tensorlist, const char* name) { + for (auto tensor: tensorlist) { + throw_error_for_complex_autograd(tensor, name); + } +} + // TODO: Blegh, bare references inline void rebase_history(Variable& var, std::shared_ptr grad_fn) { @@ -111,88 +134,111 @@ template inline variable_list flatten_tensor_args(Args&&... ar } // See NOTE [ Autograd View Variables ] for details. -inline Tensor as_view(const Tensor & base, Tensor tensor, bool is_differentiable, - c10::optional> view_func=c10::nullopt, - CreationMeta creation_meta=CreationMeta::DEFAULT) { - auto base_var = Variable(base); - if (base_var.is_view()) { - // Set `view_func` using the root base as input. - // `view_func` is used to recover views in backward when either as_strided is not supported - // or the view function changes the metadata which is not recorded by as_strided - // See Note [View + Inplace update on base tensor] and [View + Inplace update on view tensor] - // for more details how we use this function in backward. - auto diff_view_meta = static_cast(torch::autograd::impl::get_autograd_meta(base_var)); - if (view_func.has_value()) { - auto fn = view_func.value(); - // both current_view and it's parent have a view_func - if (diff_view_meta->has_view_fn()) { - auto prev_fn = diff_view_meta->view_fn(); - view_func = [=](const at::Tensor& root_base) { - auto temp = prev_fn(root_base); - return fn(temp); - }; +inline Tensor as_view(const Tensor & base, const Tensor & tensor, bool is_bw_differentiable, + bool is_fw_differentiable, std::function view_func=nullptr, + CreationMeta creation_meta=CreationMeta::DEFAULT, bool allow_tensor_metadata_change=true) { + if (!isForwardADEnabled()) { + // Fast codepath for backward only code + // It is useful as it avoids the creation of the temporary c10 which makes + // a significant difference when measuring instruction count for a single "t.view(-1)" call from c++. + if (is_bw_differentiable) { + if (base.is_view()) { + auto diff_view_meta = static_cast(torch::autograd::impl::get_autograd_meta(base)); + const auto& base_bw_info = diff_view_meta->get_backward_view(); + return make_variable_differentiable_view(tensor, base_bw_info.chain(base, tensor, view_func), + c10::nullopt, creation_meta, allow_tensor_metadata_change); } else { - // current_view has a view_func and but it's parent doesn't have one - if(base_var.unsafeGetTensorImpl()->support_as_strided()) { - auto size = base.sizes().vec(); - auto stride = base.strides().vec(); - auto storage_offset = base.storage_offset(); - view_func = [=](const at::Tensor& root_base) { - auto temp = root_base.as_strided(size, stride, storage_offset); - return fn(temp); - }; - } else { - // When base_var is a view but doesn't carry a view_fn in DifferentiableViewMeta, it's - // a view that doesn't support inplace update, e.g. unbind. - // In this case we should throw an error when inplace update happens in **forward**. - // One would naturally think the following function will be first called in backward pass. - // But the first call site is indeed in **forward** pass when we refresh `grad_fn` - // triggered by inplace update. - // Search Note [View + Inplace update for view tensor] to for the call site. - view_func = [=](const at::Tensor& root_base) { - TORCH_CHECK(false, "This view is the output of a function that returns multiple views." - "Such functions do not allow the output views to be modified inplace." - "You should replace the inplace operation by an out-of-place one"); - return root_base; - }; - } + return make_variable_differentiable_view(tensor, ViewInfo(base, view_func), + c10::nullopt, creation_meta, allow_tensor_metadata_change); } - } else if(diff_view_meta->has_view_fn()) { - // if current_view doesn't have a view_func but it's parent has one - auto prev_view_fn = diff_view_meta->view_fn(); - auto size = tensor.sizes().vec(); - auto stride = tensor.strides().vec(); - auto storage_offset = tensor.storage_offset(); - view_func = [=](const at::Tensor& root_base) { - auto temp = prev_view_fn(root_base); - return temp.as_strided(size, stride, storage_offset); - }; + } else { + TORCH_CHECK(creation_meta == CreationMeta::DEFAULT, + "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT"); + return make_variable_non_differentiable_view(base, std::move(tensor), allow_tensor_metadata_change); } - base_var = base_var._base(); } - if (is_differentiable) { - return make_variable_differentiable_view(std::move(base_var), std::move(tensor), creation_meta, std::move(view_func)); + // Create both the forward and backward info that are needed + c10::optional new_bw_info; + c10::optional new_fw_info; + + if (is_bw_differentiable) { + if (base.is_view()) { + auto diff_view_meta = static_cast(torch::autograd::impl::get_autograd_meta(base)); + const auto& base_bw_info = diff_view_meta->get_backward_view(); + new_bw_info = base_bw_info.chain(base, tensor, view_func); + } else { + new_bw_info = ViewInfo(base, view_func); + } } else { TORCH_CHECK(creation_meta == CreationMeta::DEFAULT, - "Non-differentiable views must have creation_meta=CreationMeta::DEFAULT"); - return make_variable_non_differentiable_view(std::move(base_var), std::move(tensor)); + "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT"); + } + + if (is_fw_differentiable) { + // Check if base is a forward differentiable view + auto base_meta = torch::autograd::impl::get_autograd_meta(base); + auto is_view = base_meta && base_meta->is_view_; + if (is_view && static_cast(base_meta)->has_fw_view()) { + auto diff_view_meta = static_cast(base_meta); + const auto& base_fw_info = diff_view_meta->get_forward_view(); + new_fw_info = base_fw_info.chain(base, tensor, view_func); + } else { + new_fw_info = ViewInfo(base, view_func); + } + } + + if (is_fw_differentiable || is_bw_differentiable) { + return make_variable_differentiable_view(tensor, std::move(new_bw_info), std::move(new_fw_info), + creation_meta, allow_tensor_metadata_change); + } else { + return make_variable_non_differentiable_view(base, tensor, allow_tensor_metadata_change); } } // See NOTE [ Autograd View Variables ] for details. -inline std::vector as_view(const Tensor & base, std::vector tensors, bool is_differentiable, - CreationMeta creation_meta=CreationMeta::DEFAULT) { - auto base_var = Variable(base); - if (base_var.is_view()) { - base_var = base_var._base(); +inline std::vector as_view(const Tensor & base, std::vector& tensors, bool is_bw_differentiable, + bool is_fw_differentiable, CreationMeta creation_meta=CreationMeta::DEFAULT) { + c10::optional new_bw_info = c10::nullopt; + c10::optional new_fw_info = c10::nullopt; + + if (is_bw_differentiable) { + if (base.is_view()) { + auto diff_view_meta = static_cast(torch::autograd::impl::get_autograd_meta(base)); + const auto& base_bw_info = diff_view_meta->get_backward_view(); + TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::MULTI_OUTPUT_NODE || creation_meta == CreationMeta::MULTI_OUTPUT_SAFE, + "Functions that result multiple view must have a creation meta reflecting this behavior."); + // It is ok to create a ViewInfo where only the base is correct in this case as inplace operations on such views are + // not allowed + new_bw_info = ViewInfo(base_bw_info.base_, /* view_func */ nullptr); + } else { + new_bw_info = ViewInfo(base, /* view_func */ nullptr); + } + } else { + TORCH_CHECK(creation_meta == CreationMeta::DEFAULT, + "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT"); } + if (isForwardADEnabled() && is_fw_differentiable) { + // Check if base is a forward differentiabble view + auto base_meta = torch::autograd::impl::get_autograd_meta(base); + auto is_view = base_meta && base_meta->is_view_; + if (is_view && static_cast(base_meta)->has_fw_view()) { + auto diff_view_meta = static_cast(base_meta); + const auto& base_fw_info = diff_view_meta->get_forward_view(); + TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::MULTI_OUTPUT_NODE || creation_meta == CreationMeta::MULTI_OUTPUT_SAFE, + "Functions that result multiple view must have a creation meta reflecting this behavior."); + // It is ok to create a ViewInfo where only the base is correct in this case as inplace operations on such views are + // not allowed + new_fw_info = ViewInfo(base_fw_info.base_, /* view_func */ nullptr); + } else { + new_fw_info = ViewInfo(base, /* view_func */ nullptr); + } + } + for(Tensor &tensor : tensors) { - if (is_differentiable) { - tensor = make_variable_differentiable_view(base_var, std::move(tensor), creation_meta); + if (is_fw_differentiable || is_bw_differentiable) { + tensor = make_variable_differentiable_view(tensor, new_bw_info, new_fw_info, creation_meta); } else { - TORCH_CHECK(creation_meta == CreationMeta::DEFAULT, - "Non-differentiable views must have creation_meta=CreationMeta::DEFAULT"); - tensor = make_variable_non_differentiable_view(base_var, std::move(tensor)); + tensor = make_variable_non_differentiable_view(base, tensor); } } return tensors; @@ -220,12 +266,31 @@ inline void check_no_requires_grad(TensorList tensors, const char* name) { } } +inline void check_no_requires_grad(const c10::List>& tensors, const char* name) { + for (c10::optional tensor : tensors) { + if (tensor.has_value()) { + check_no_requires_grad(*tensor, name); + } + } +} + // Assumed that saved tensor lists are never inplace outputs inline std::vector make_saved_variable_list(TensorList tensors) { return fmap(tensors, [](const Tensor& tensor) -> SavedVariable { return SavedVariable{tensor, false /* is output */}; }); } +// Assumed that saved tensor lists are never inplace outputs +inline std::vector make_saved_variable_list(const c10::List>& tensors) { + return fmap(tensors, [](const c10::optional& tensor) -> SavedVariable { + if (tensor.has_value()) { + return SavedVariable{*tensor, false /* is output */}; + } else { + return SavedVariable{Tensor(), false /* is output */}; + } + }); +} + inline std::vector> to_args_sizes(TensorList tensors) { std::vector> args_sizes(tensors.size()); for (size_t i = 0; i < tensors.size(); ++i) { diff --git a/torch/csrc/autograd/anomaly_mode.cpp b/torch/csrc/autograd/anomaly_mode.cpp index bbb76fba656f8..e8afa6f8fc523 100644 --- a/torch/csrc/autograd/anomaly_mode.cpp +++ b/torch/csrc/autograd/anomaly_mode.cpp @@ -1,9 +1,77 @@ +#include +#include #include +#include +#include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { bool AnomalyMode::_enabled = false; +namespace { +std::mutex& get_anomaly_guard_lock() { + static std::mutex anomaly_guard_lock{}; + return anomaly_guard_lock; +} + +uint32_t& get_anomaly_counter() { + static uint32_t counter = 0; + return counter; +} +} // namespace + +DetectAnomalyGuard::DetectAnomalyGuard() { + TORCH_WARN_ONCE( + "This mode should be enabled only for debugging as the different tests will slow down your program execution."); + std::lock_guard lock(get_anomaly_guard_lock()); + uint32_t& counter = get_anomaly_counter(); + counter++; + AnomalyMode::set_enabled(true); +} + +DetectAnomalyGuard::~DetectAnomalyGuard() { + std::lock_guard lock(get_anomaly_guard_lock()); + uint32_t& counter = get_anomaly_counter(); + counter--; + AnomalyMode::set_enabled(counter > 0); +} + AnomalyMetadata::~AnomalyMetadata() = default; -}} +void AnomalyMetadata::store_stack() { + traceback_ = c10::get_backtrace(/* frames_to_skip */ 1); +} + +void AnomalyMetadata::print_stack(const std::string& current_node_name) { + TORCH_WARN( + "Error detected in ", + current_node_name, + ". ", + "Traceback of forward call that caused the error:\n", + traceback_); + + auto& cur_parent = parent_; + // if there is no "parent_" in metadata, then it means this metadata's node + // is the root and stop printing the traceback + while (cur_parent) { + auto parent_metadata = cur_parent->metadata(); + TORCH_WARN( + "\n\n", + "Previous calculation was induced by ", + cur_parent->name(), + ". " + "Traceback of forward call that induced the previous calculation:\n", + parent_metadata->traceback_); + // get the parent of this node, if this node is a root, pyparent is simply + // null + cur_parent = parent_metadata->parent_; + } +} + +void AnomalyMetadata::assign_parent(const std::shared_ptr& parent_node) { + parent_ = parent_node; +} + +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/anomaly_mode.h b/torch/csrc/autograd/anomaly_mode.h index 013600b230fc7..a4e4210dabfe2 100644 --- a/torch/csrc/autograd/anomaly_mode.h +++ b/torch/csrc/autograd/anomaly_mode.h @@ -21,12 +21,43 @@ struct TORCH_API AnomalyMode { static bool _enabled; }; +/// A RAII guard that enables Anomaly Detection Mode. +/// +/// Anomaly detection mode is useful for debugging problems happening +/// in the backward, such as unexpectedly modified tensors or NaNs +/// occuring in the backward. +/// +/// The enabling of anomaly mode is global - as soon as there is one +/// such guard, it is enabled for all computation and threads. It also +/// comes with a significant performance penalty. +/// +/// Example: +/// @code +/// auto x = torch::tensor({1.}, torch::requires_grad()); +/// { +/// torch::autograd::DetectAnomalyGuard detect_anomaly; +/// auto x = torch::tensor({5.0}, torch::requires_grad()); +/// auto y = x * x; +/// auto z = y * y; +/// y += 1; +/// z.backward(); +/// } +/// @endcode +class TORCH_API DetectAnomalyGuard { + public: + DetectAnomalyGuard(); + ~DetectAnomalyGuard(); +}; struct TORCH_API AnomalyMetadata { virtual ~AnomalyMetadata(); - virtual void store_stack() = 0; - virtual void print_stack(const std::string& current_node_name) = 0; - virtual void assign_parent(const std::shared_ptr& parent_node) = 0; + virtual void store_stack(); + virtual void print_stack(const std::string& current_node_name); + virtual void assign_parent(const std::shared_ptr& parent_node); + + private: + std::string traceback_; + std::shared_ptr parent_; }; }} diff --git a/torch/csrc/autograd/autograd.cpp b/torch/csrc/autograd/autograd.cpp index ab02a03279a1e..e1e70586a079c 100644 --- a/torch/csrc/autograd/autograd.cpp +++ b/torch/csrc/autograd/autograd.cpp @@ -68,17 +68,14 @@ variable_list run_backward( bool keep_graph, bool create_graph, const variable_list& inputs, - bool allow_unused) { + bool allow_unused, + bool accumulate_grad) { size_t num_tensors = outputs.size(); edge_list roots; roots.reserve(num_tensors); for (size_t i = 0; i < num_tensors; i++) { const Variable& output = outputs[i]; auto gradient_edge = impl::gradient_edge(output); - if(output.is_complex()) { - TORCH_WARN_ONCE("Complex backward is not fully supported yet and could lead to wrong ", - "gradients for functions we have not fixed yet"); - } TORCH_CHECK( gradient_edge.function, "element ", i, " of tensors does not require grad and does not have a grad_fn"); @@ -96,6 +93,12 @@ variable_list run_backward( if (!grad_fn) { grad_fn = impl::try_get_grad_accumulator(input); } + if (accumulate_grad) { + TORCH_CHECK( + input.is_leaf(), + "One of the differentiated Tensors given as 'inputs' to backward is not a leaf Tensor" + ) + } TORCH_CHECK( input.requires_grad(), "One of the differentiated Tensors does not require grad"); @@ -108,7 +111,7 @@ variable_list run_backward( } variable_list grad_inputs = Engine::get_default_engine().execute( - roots, grad_outputs, keep_graph, create_graph, output_edges); + roots, grad_outputs, keep_graph, create_graph, accumulate_grad, output_edges); // check if grad_inputs contains None or not base on the allow_unused flag if (!inputs.empty() && !allow_unused) { size_t num_inputs = inputs.size(); @@ -128,12 +131,13 @@ void backward( const variable_list& tensors, const variable_list& grad_tensors, c10::optional retain_graph, - bool create_graph) { + bool create_graph, + const variable_list& inputs) { variable_list gradients = _make_grads(tensors, grad_tensors); if (!retain_graph) { retain_graph = create_graph; } - run_backward(tensors, gradients, retain_graph.value(), create_graph, {}, /*allow_unused=*/true); + run_backward(tensors, gradients, retain_graph.value(), create_graph, inputs, /*allow_unused=*/true, /*accumulate_grad=*/true); } variable_list grad( @@ -148,8 +152,21 @@ variable_list grad( retain_graph = create_graph; } return run_backward( - outputs, gradients, retain_graph.value(), create_graph, inputs, allow_unused); + outputs, gradients, retain_graph.value(), create_graph, inputs, allow_unused, /*accumulate_grad=*/false); +} + + +namespace forward_ad { + +uint64_t enter_dual_level() { + return ForwardADLevel::get_next_idx(); } +void exit_dual_level(uint64_t level) { + ForwardADLevel::release_idx(level); +} + +} // namespace forward_ad + } // namespace autograd } // namespace torch diff --git a/torch/csrc/autograd/autograd.h b/torch/csrc/autograd/autograd.h index 90b8fcd3c4b4b..7f905b21c3b66 100644 --- a/torch/csrc/autograd/autograd.h +++ b/torch/csrc/autograd/autograd.h @@ -32,11 +32,16 @@ namespace autograd { /// value of `create_graph`. /// \param create_graph If `true`, graph of the derivative will be constructed, allowing /// to compute higher order derivative products. Defaults to `false`. +/// \param inputs Inputs w.r.t. which the gradient will be accumulated into +/// `at::Tensor::grad`. All other Tensors will be ignored. If not provided, the gradient +/// is accumulated into all the leaf Tensors that were used to compute param `tensors`. +/// All the provided inputs must be leaf Tensors. TORCH_API void backward( const variable_list& tensors, const variable_list& grad_tensors = {}, c10::optional retain_graph = c10::nullopt, - bool create_graph = false); + bool create_graph = false, + const variable_list& inputs = {}); /// Computes and returns the sum of gradients of outputs with respect to the inputs. /// @@ -70,5 +75,20 @@ TORCH_API variable_list grad( bool create_graph = false, bool allow_unused = false); +namespace forward_ad { + +/// Creates a new dual level and returns its index. This level index should then be used to call +/// into the other functions below. +/// This API supports entering a new level before the previous one is exited. We call them nested +/// forward AD levels. These can be used to compute higher order derivatives. +TORCH_API uint64_t enter_dual_level(); + +/// Exits the given level. This will clear up all the gradients from this level and all dual Tensors +/// that had gradients for this level will become regular Tensors again. +/// This function can only be used to exit the innermost nesting level and so exiting must happen in +/// reverse order compared to the entering that was done with the function above. +TORCH_API void exit_dual_level(uint64_t level); + +} // namespace forward_ad } // namespace autograd } // namespace torch diff --git a/torch/csrc/autograd/autograd_meta.cpp b/torch/csrc/autograd/autograd_meta.cpp new file mode 100644 index 0000000000000..b06b0ff08c886 --- /dev/null +++ b/torch/csrc/autograd/autograd_meta.cpp @@ -0,0 +1,218 @@ +#include + +namespace torch { +namespace autograd { + +using at::Tensor; + +// [Forward Grad View/inplace] +// It is important to us to allow view and inplace to work with dual Tensors. These operations +// should either compute the right gradient or raise a user-friendly error. + +// The basic case where all Tensors are dual Tensors is as follows: +// # Have: +// # foo is a dual Tensor that is not a view +// # bar is a dual Tensor of appropriate size (depending on cases) that is not a view +// +// # Case 1: no view +// foo.copy_(bar) +// +// # Case 2: with view, propagate from view to base +// view = foo[0] +// view.copy_(bar) +// +// # Case 3: with view, propagate from base to view +// view = foo[0] +// foo.copy_(bar) +// +// # In both cases, the forward grad of foo must be properly updated. +// # In the second and third cases, the forward grad of view must match +// # the one of foo for the subset they have in common. +// +// All these cases can be handled by the following layout constraint on the forward grad: +// - A Tensor and its forward grad (for all levels) must have the same metadata (size, stride +// and storage offset). Storage offset must be in this metadata because of as_strided. +// - View operations must create a forward grad that is a view of the base's forward grad. +// - Inplace operations must modify the input's forward grad inplace. +// +// This layout constraint is ensured in the `set_fw_grad` function below + + +// More complex cases arrise when non-dual Tensor interact with dual Tensors. +// The two most important cases are: +// +// # Have: +// # foo is a regular Tensor that is not a view +// # bar is a dual Tensor of appropriate size (depending on cases) that is not a view +// +// # Case 4: Changes on the view must propagate to its base +// view = foo[0] +// # view is still a regular Tensor here +// view.copy_(bar) +// # Now both view and foo are dual Tensor with appropriate forward grad +// +// # Case 5: Changes on the base must propagate on all its views +// view = foo[0] +// # view is still a regular Tensor here +// base.copy_(bar) +// # Now both view and foo are dual Tensor with appropriate forward grad +// +// # NB there is a case 6 involving changes on a view propagating to other views +// # but it is fully described by the two others and is skipped in this discussion. +// +// Case 4 is handled by set_fw_grad by properly setting the forward grad of the base if needed. +// Case 5 is handled in fw_grad by reading the forward grad from the base if needed. + + +namespace { + // Check if two Tensor have the same storage offset, sizes and strides + bool has_same_meta(const Variable& base, const Variable& other) { + if (!base.defined() || !other.defined()) { + return false; + } + if (base.storage_offset() != other.storage_offset()) { + return false; + } + if (base.dim() != other.dim()) { + return false; + } + for (int64_t i=0; i lock(mutex_); + if (!fw_grad_) { + fw_grad_ = std::make_shared(); + } + } + if (fw_grad_->contains(level)) { + // Setting the forward grad again is only allowed if it is a no-op. + // We do allow this case to simplify writing codegen for inplace ops. + TORCH_INTERNAL_ASSERT(new_grad_.defined(), "Cannot set a forward grad that is an undefined Tensor. Use " + "_fw_primal(level) to get a new Tensor with this forward grad unset."); + + TORCH_INTERNAL_ASSERT(is_inplace_op, "Only inplace operations can re-set the forward grad of a Tensor that " + "already has one."); + + TORCH_INTERNAL_ASSERT(fw_grad_->value(level).is_same(new_grad_), "Cannot set a value of a forward grad if it " + "already exists. Inplace operations should modify it inplace."); + } else { + // TODO(alband) remove this spurious version counter bump + auto new_grad = new_grad_; + + if (is_inplace_op && is_view_) { + auto this_view_meta = static_cast(this); + + // For inplace ops on a Tensor that does not already have a forward grad and is a view, we propagate + // the tangent to the base and ensure that the new_grad is a view of that base's tangent. + // This ensure that case 4 from [Forward Grad View/inplace] above works fine + // What happens in this long if statement is: + // - Check if the base already has a grad + // - If not, set a new fw_grad for it full of zeros + // - Take a view of the base's forward grad + // - Copy the given new_grad into this view + // - Use this view as the new new_grad + if (this_view_meta->has_fw_view()) { + auto view_info = this_view_meta->get_forward_view(); + auto& base = view_info.base_; + + if (!base.fw_grad(level).defined()) { + // Enforce same meta here to make sure that the view op below is always valid + Tensor new_base_fw_grad; + if (has_same_meta(new_grad, base)) { + // TODO extend this special case to when the underlying storage of new_grad + // can be re-used. + new_base_fw_grad = new_grad; + } else { + new_base_fw_grad = new_with_same_meta(base); + + // Update new_grad to be a view of the base + Tensor new_fw_grad_value; + if (view_info.has_view_fn()) { + new_fw_grad_value = view_info.view_fn()(new_base_fw_grad); + } else { + new_fw_grad_value = new_base_fw_grad.as_strided(self.sizes(), self.strides(), self.storage_offset()); + } + + new_fw_grad_value.copy_(new_grad); + new_grad = new_fw_grad_value; + } + + base.set_fw_grad(new_base_fw_grad, level, /* is_inplace_op */ false); + } + } + } + + // Enforce the basic layout constraint + if (!has_same_meta(new_grad, self)) { + Tensor new_grad_with_meta = new_with_same_meta(self); + new_grad_with_meta.copy_(new_grad); + new_grad = new_grad_with_meta; + } + + fw_grad_->set_value(new_grad, level); + } +} + +const Variable& AutogradMeta::fw_grad(uint64_t level, const Variable& self) const { + // Ensure that concurent fw_grad() "reads" are thread safe + std::lock_guard lock(mutex_); + + const auto& direct_fw_grad = fw_grad_ ? fw_grad_->value(level) : ForwardGrad::undef_grad(); + + if (!direct_fw_grad.defined() && is_view_) { + // For view that don't have a forward grad, check if their base has one that + // has been defined by an inplace operation. + // This ensure that case 5 from [Forward Grad View/inplace] above works fine + auto const_view_meta = static_cast(this); + // This is ok to do as we ONLY modify fw_grad_ and this field is properly locked in all methods + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + auto this_view_meta = const_cast(const_view_meta); + if (this_view_meta->has_fw_view()) { + const auto& view_info = this_view_meta->get_forward_view(); + const auto& base = view_info.base_; + + const auto& base_val = base.fw_grad(level); + if (base_val.defined()) { + // Lazy initialization of fw_grad_ + this_view_meta->fw_grad_ = std::make_shared(); + + Variable new_val; + if (view_info.has_view_fn()) { + new_val = view_info.view_fn()(base_val); + } else { + new_val = base_val.as_strided(self.sizes(), self.strides(), self.storage_offset()); + } + + this_view_meta->fw_grad_->set_value(new_val, level); + return this_view_meta->fw_grad_->value(level); + } + } + } + return direct_fw_grad; +} + +}} // torch::autograd diff --git a/torch/csrc/autograd/custom_function.cpp b/torch/csrc/autograd/custom_function.cpp index c7520315d9f33..f6d28ec342b67 100644 --- a/torch/csrc/autograd/custom_function.cpp +++ b/torch/csrc/autograd/custom_function.cpp @@ -124,7 +124,7 @@ variable_list _wrap_outputs(const variable_list &input_vars, if (!(is_input && is_modified) && var.is_view()) { // NB: is_view() ==> get_autograd_meta() auto diff_view_meta = static_cast(impl::get_autograd_meta(var)); - diff_view_meta->creation_meta = CreationMeta::IN_CUSTOM_FUNCTION; + diff_view_meta->set_creation_meta(CreationMeta::IN_CUSTOM_FUNCTION); } if (is_differentiable) { @@ -142,7 +142,7 @@ variable_list _wrap_outputs(const variable_list &input_vars, if (var.is_view()) { // NB: is_view() ==> get_autograd_meta() auto diff_view_meta = static_cast(impl::get_autograd_meta(var)); - diff_view_meta->creation_meta = CreationMeta::MULTI_OUTPUT_NODE; + diff_view_meta->set_creation_meta(CreationMeta::MULTI_OUTPUT_NODE); } } } diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h index 9aefb552ae88f..055fefda67ccf 100644 --- a/torch/csrc/autograd/custom_function.h +++ b/torch/csrc/autograd/custom_function.h @@ -273,12 +273,12 @@ variable_list CppNode::apply(variable_list&& inputs) { auto outputs = T::backward(&ctx_, backward_inputs); int num_forward_inputs = is_variable_input_.size(); - int num_outputs = outputs.size(); + auto num_outputs = outputs.size(); // Returning too many results is ok, but only as long as they're all undefined. // Truncate the result vector in that case. if (num_outputs > num_forward_inputs) { bool all_undef = true; - for (int i = num_forward_inputs; i < num_outputs; ++i) { + for (size_t i = num_forward_inputs; i < num_outputs; ++i) { all_undef &= (!outputs[i].defined()); } if (all_undef) { diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 62ca26e469399..af295feba51ae 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -227,7 +227,7 @@ Engine::~Engine() { // Do not wait for termination of global threads on Windows // Because CRT terminates DLL threads before calling // global object destructors -#if !defined(_WIN32) || !defined(C10_BUILD_SHARED_LIBS) +#if !defined(_WIN32) || defined(C10_USE_MSVC_STATIC_RUNTIME) std::unique_lock lk(non_reentrant_device_thread_mutex_); while(non_reentrant_device_thread_count_.load() != 0) { non_reentrant_device_thread_condvar_.wait(lk); @@ -513,12 +513,10 @@ void GraphTask::exec_post_processing() { } void GraphTask::set_exception_without_signal(const std::shared_ptr& fn) { - std::unique_lock lock(mutex_); - if (!has_error_.load()) { + if (!has_error_.exchange(true)) { if (AnomalyMode::is_enabled() && fn) { fn->metadata()->print_stack(fn->name()); } - has_error_ = true; } } @@ -845,6 +843,7 @@ auto Engine::execute(const edge_list& roots, const variable_list& inputs, bool keep_graph, bool create_graph, + bool accumulate_grad, const edge_list& outputs) -> variable_list { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) validate_outputs(roots, const_cast(inputs), [](const std::string& msg) { @@ -854,7 +853,7 @@ auto Engine::execute(const edge_list& roots, // A frech first time Engine::execute call should start on the CPU device, initialize // a new thread local ready queue on CPU or reuse the existing one (if there is one // allocated already, i.e. consecutive backward calls, re-entrant backward calls), - // then memorize the local_ready_queue in GraphTask + // then memoize the local_ready_queue in GraphTask init_local_ready_queue(); bool not_reentrant_backward_call = worker_device == NO_DEVICE; @@ -864,15 +863,34 @@ auto Engine::execute(const edge_list& roots, /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1, /* cpu_ready_queue */ local_ready_queue); + // If we receive a single root, skip creating extra root node + bool skip_dummy_node = roots.size() == 1; + auto graph_root = skip_dummy_node ? + roots.at(0).function : + std::make_shared(roots, inputs); + // Now compute the dependencies for all executable functions and queue the root - auto graph_root = std::make_shared(roots, inputs); compute_dependencies(graph_root.get(), *graph_task); if (!outputs.empty()) { - graph_task->init_to_execute(*graph_root, outputs); + graph_task->init_to_execute(*graph_root, outputs, accumulate_grad); } - execute_with_graph_task(graph_task, graph_root); + if (skip_dummy_node) { + InputBuffer input_buffer(roots.at(0).function->num_inputs()); + auto input = inputs.at(0); + + const auto input_stream = InputMetadata(input).stream(); + const auto opt_next_stream = roots.at(0).function->stream(c10::DeviceType::CUDA); + input_buffer.add(roots.at(0).input_nr, + std::move(input), + input_stream, + opt_next_stream); + + execute_with_graph_task(graph_task, graph_root, std::move(input_buffer)); + } else { + execute_with_graph_task(graph_task, graph_root, InputBuffer(variable_list())); + } // Avoid a refcount bump for the Future, since we check for refcount in // DistEngine (see TORCH_INTERNAL_ASSERT(futureGrads.use_count() == 1) // in dist_engine.cpp). @@ -891,13 +909,13 @@ void Engine::initialize_device_threads_pool() { std::shared_ptr Engine::execute_with_graph_task( const std::shared_ptr& graph_task, - std::shared_ptr graph_root) { + std::shared_ptr graph_root, + InputBuffer&& input_buffer) { initialize_device_threads_pool(); // Lock mutex for GraphTask. std::unique_lock lock(graph_task->mutex_); - ready_queue(graph_task->cpu_ready_queue_, at::kCPU)->push( - NodeTask(graph_task, std::move(graph_root), InputBuffer(0))); + auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device()); // worker_device == NO_DEVICE it's a CPU thread and it's trying to drive the // autograd engine with corresponding GraphTask, and its NOT a re-entrant call @@ -910,8 +928,12 @@ std::shared_ptr Engine::execute_with_graph_task( // set the graph_task owner to the current device graph_task->owner_ = worker_device; - // The owning thread start to drive the engine execution with the GraphTask - // that has already been pushed to the current CPU thread's ready_queue + // Now that all the non-thread safe fields of the graph_task have been populated, + // we can enqueue it. + queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer))); + + // The owning thread start to drive the engine execution for any CPU task that + // was just pushed or will be added later from other worker threads lock.unlock(); thread_main(graph_task); TORCH_INTERNAL_ASSERT(graph_task->future_result_->completed()); @@ -924,6 +946,11 @@ std::shared_ptr Engine::execute_with_graph_task( // If worker_device is any devices (i.e. CPU, CUDA): this is a re-entrant // backward call from that device. graph_task->owner_ = worker_device; + + // Now that all the non-thread safe fields of the graph_task have been populated, + // we can enqueue it. + queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer))); + if (current_depth >= max_recursion_depth_) { // See Note [Reentrant backwards] // If reached the max depth, switch to a different thread @@ -1024,6 +1051,8 @@ auto Engine::ready_queue_by_index(std::shared_ptr cpu_ready_queue, i TORCH_INTERNAL_ASSERT(cpu_ready_queue); return cpu_ready_queue; } else { + // Static cast is ok here as the number of device should never overflow an int. + TORCH_INTERNAL_ASSERT(0 <= device_index && device_index < static_cast(device_ready_queues_.size())); // See Note [Allocating GPUs to autograd threads] // NB: This function would become obsolete if we truly allocated a CPU thread // per device, rather than colocate. @@ -1081,16 +1110,20 @@ void Engine::add_thread_pool_task(const std::weak_ptr& graph_task) { thread_pool_shared_->work_.notify_one(); } -void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs) { +void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs, bool accumulate_grad) { exec_info_[&graph_root].needed_ = true; - int output_idx = 0; for (auto & output_edge : outputs) { Node *output = output_edge.function.get(); auto & info = exec_info_[output]; - if (!info.captures_) - info.captures_ = make_unique>(); - info.captures_->emplace_back(output_edge.input_nr, output_idx++); + if (accumulate_grad) { + info.needed_ = true; + } else { + if (!info.captures_) { + info.captures_ = make_unique>(); + } + info.captures_->emplace_back(output_edge.input_nr, output_idx++); + } } captured_vars_.resize(output_idx); @@ -1119,6 +1152,7 @@ void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs) { std::vector stack; std::unordered_set seen; for (const auto & input : graph_root.next_edges()) { + if (!input.function.get()) continue; if (seen.count(input.function.get()) > 0) continue; stack.emplace_back(input.function.get()); while (!stack.empty()) { @@ -1138,7 +1172,7 @@ void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs) { auto it = exec_info_.find(edge.function.get()); return it != exec_info_.end() && it->second.should_execute(); }); - exec_info_[frame.fn_].needed_ = needed; + exec_info_[frame.fn_].needed_ |= needed; stack.pop_back(); } } diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h index 0dde6e735d10c..7892f47521c58 100644 --- a/torch/csrc/autograd/engine.h +++ b/torch/csrc/autograd/engine.h @@ -109,7 +109,7 @@ struct GraphTask: std::enable_shared_from_this { std::unordered_set leaf_streams; - void init_to_execute(Node& graph_root, const edge_list& outputs); + void init_to_execute(Node& graph_root, const edge_list& outputs, bool accumulate_grad); // The value of worker_device in the thread that created this task. // See Note [Reentrant backwards] @@ -144,7 +144,7 @@ struct GraphTask: std::enable_shared_from_this { // CPU threads are dedicated to processing CPU work for the backward they invoked. // So any given graph task maintains its own cpu_ready_queue_ where you should send - // work for it to be done. We memorize the cpu_ready_queue_ per GraphTask so that + // work for it to be done. We memoize the cpu_ready_queue_ per GraphTask so that // we know which ready queue we should push to if we are on device thread (i.e. GPU) // and but next NodeTask should be run on CPU. std::shared_ptr cpu_ready_queue_; @@ -272,6 +272,7 @@ struct TORCH_API Engine { const variable_list& inputs, bool keep_graph, bool create_graph, + bool accumulate_grad, const edge_list& outputs = {}); // Given a pre-populated GraphTask and GraphRoot, computes the backward pass @@ -281,10 +282,11 @@ struct TORCH_API Engine { // machinery and shouldn't be exposed to users in anyway. virtual std::shared_ptr execute_with_graph_task( const std::shared_ptr& graph_task, - std::shared_ptr graph_root); + std::shared_ptr graph_root, + InputBuffer&& input_buffer); virtual std::unique_ptr make_anomaly_metadata() { - return nullptr; + return std::make_unique(); } // We pass cpu_ready_queue to evaluate_function, so that it knows diff --git a/torch/csrc/autograd/forward_grad.cpp b/torch/csrc/autograd/forward_grad.cpp new file mode 100644 index 0000000000000..bb8f19f252a8b --- /dev/null +++ b/torch/csrc/autograd/forward_grad.cpp @@ -0,0 +1,90 @@ +#include + +namespace torch { namespace autograd { + +namespace { + // See discussion in forward_grad.h for why these are global variables and not + // thread local + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) + static std::mutex all_forward_levels_mutex_; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) + static uint64_t next_forward_idx_ = 0; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) + static std::vector> all_forward_levels_; + + const static at::Tensor singleton_undefined_tensor; + + // Temporary flag to disable forward mode + // TODO(alband) remove these when perf issues are solved + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) + static bool is_forward_grad_enabled = false; +} + +uint64_t ForwardADLevel::get_next_idx() { + std::lock_guard lock(all_forward_levels_mutex_); + TORCH_CHECK(next_forward_idx_ == 0, "Nested forward mode AD is not supported at the moment"); + auto new_index = next_forward_idx_++; + TORCH_INTERNAL_ASSERT(new_index == all_forward_levels_.size()); + all_forward_levels_.push_back(std::make_shared(new_index)); + return new_index; +} + +void ForwardADLevel::release_idx(uint64_t idx) { + std::lock_guard lock(all_forward_levels_mutex_); + TORCH_CHECK(idx == all_forward_levels_.size() - 1, "Exiting a forward AD level that is not the " + "last that was created is not support. Ensure they are released in the reverse " + "order they were created."); + TORCH_CHECK(idx >= 0, "No forward AD level was created so you cannot exit it."); + next_forward_idx_--; + all_forward_levels_.pop_back(); + +} +std::shared_ptr ForwardADLevel::get_by_idx(uint64_t idx) { + std::lock_guard lock(all_forward_levels_mutex_); + TORCH_CHECK(idx < all_forward_levels_.size(), "Trying to access a forward AD level with an invalid index. " + "This index was either not created or is already deleted."); + return all_forward_levels_[idx]; +} + +std::shared_ptr ForwardADLevel::try_get_by_idx(uint64_t idx) { + std::lock_guard lock(all_forward_levels_mutex_); + if (idx < all_forward_levels_.size()) { + return all_forward_levels_[idx]; + } else { + return nullptr; + } +} + +ForwardADLevel::~ForwardADLevel() { + std::lock_guard lock(mutex_); + auto it = grads_.begin(); + while (it != grads_.end()) { + // Warning this will lock *it mutex + // This is ok as this function is the *only* one to call back into another class's method. + (*it)->reset(idx_, /* update_level */ false); + it = grads_.erase(it); + } +} + +const at::Tensor& ForwardGrad::value(uint64_t level) const { + std::lock_guard lock(mutex_); + const auto& it = content_.find(level); + return it == content_.end() ? singleton_undefined_tensor : (*it).second; +} + +const at::Tensor& ForwardGrad::undef_grad() { + return singleton_undefined_tensor; +} + +// Temporary functions to disable forward AD +// TODO(alband) remove these when perf issues are solved +bool isForwardADEnabled() { + return is_forward_grad_enabled; +} + +void setForwardADEnabled(bool value) { + is_forward_grad_enabled = value; +} + +}} // namespace torch::autograd diff --git a/torch/csrc/autograd/forward_grad.h b/torch/csrc/autograd/forward_grad.h new file mode 100644 index 0000000000000..2f0e66034f38e --- /dev/null +++ b/torch/csrc/autograd/forward_grad.h @@ -0,0 +1,193 @@ +#pragma once + +#include + + +namespace torch { namespace autograd { + +// [ Using ForwardGrad ] +// ForwardGrad needs to be a shared_ptr to satisfy constraints of its inner design. But +// this shared_ptr must be uniquely associated with the object that stores it (as of +// writing, either AutogradMeta or SavedVariable). This object is called the "owning object" +// in the discussions below. This owning object must call `ForwardGrad::clear()` when it +// is destroyed to ensure that the ForwardGrad is properly de-allocated. + +struct ForwardGrad; + +// This file contains two classes that are used to store forward AD gradients and +// ensure that they are scoped properly. +// Because forward AD runs concurrently with the evaluation of the function, we need +// a mechanism to separate different forward AD invocations and be able to compute the +// right gradients. We model such invocations as levels here. +// The particular scoping issue mentioned above has two main drivers: +// - Ensure that we can conveniently use forward AD within a high level API without +// leaking the forward AD states outside. +// - Ensure that we can keep the level that we expose to the user API simple (an integer +// that represents the nesting depth) while avoiding confusions when the level index +// is re-used. + +// The important external APIs from this file are: +// - ForwardADLevel::get_next_idx() that can be used to enter a new level and get its index +// - ForwardADLevel::release_idx() that can be used to exit a given level. +// - ForwardGrad() can be used to store a given forward gradient that will handle the level +// tracking automatically. + +// The basic implementation strategy is as follows: +// Every tensor has a ForwardGrad, maintaining a map from levels to tangents. +// ForwardGrad is responsible for registering itself to the appropriate ForwardADLevel when a new +// tangent is added to it via ForwardGrad::set_value and to un-register itself from this same level +// if that tangent is removed via ForwardGrad::reset. +// The ForwardADLevel is created when a new level is entered via ForwardADLevel::get_next_idx. +// A reference to the new ForwardADLevel is stored into a global (for the whole process) vector that +// ensure it can be accessed via ForwardADLevel::get_by_idx. This reference is deleted when the index is +// released by the user when calling ForwardADLevel::release_idx. +// When it is destructed, the ForwardADLevel is responsible for clearing all the tangents for its +// level stored in all the ForwardGrad that registered with it. +// +// This process-wide level design, compared to a thread local one, allows us to use very simple user facing +// handle for the level (an int) while enabling cross-thread forward AD. +// The only required synchronization for the user is when entering and exiting the levels. +// Some discussion on alternative design is in https://github.com/pytorch/pytorch/pull/49097#discussion_r543716453 +// and can be refined in the future. + +// Correctness of concurrency: +// Each class uses its own lock when reading or modifying internal storages. This allows in particular +// to safely remove tangents from ForwardGrad when the ForwardADLevel is being exited. +// We ensure no deadlock by ensuring that a methods never calls into another class's method while +// the local class's lock is held except in one single case: calling from ForwardADLevel's destructor +// into ForwardGrad::reset with update_level=false. + +// The lifetime of these objects is as follows: +// The ForwardADLevel can be in three states: +// - Initialized: where one of its reference is held by the global vector and there may be more +// references held by temporary variables in ForwardGrad's methods. +// - About to be destructed: where "release_idx" has been called and the only reason for the +// ForwardADLevel not to be destructed right away is that some methods in ForwardGrad have +// owning reference to it. This is done so that a ForwardADLevel can never be destructed when +// a ForwardGrad is registered with it and in the process of adding something to its internal state. +// - Being destructed: Here the ForwardADLevel is not referenced anymore and can be safely reset +// all of the ForwardGrad. Note that we can have more than one reset being called here (which is ok) +// but we are guaranteed that there is at least one. +// The ForwardGrad is simpler as there is no intermediary state and no special destructor for. The logic to +// unregister it from the different ForwardADLevel is done when the owning object (AutogradMeta or +// SavedVariable) is being destroyed. + +// Other considered design: +// To avoid having the ForwardGrad::clear, we considered storing weak_ptr inside the ForwardADLevel. While this +// would work, it would mean that the set inside the ForwardADLevel would only grow unless we do an +// expensive linear scan to remove all the dangling weak pointers. Hence this approach was not used. + +// Data structures in this file are optimized for this maximum number of levels. +// The number of levels corresponds to the degree of the gradient being +// computed using forward AD and we don't expect more than second order gradients +// to be common. +#define EXPECTED_MAX_LEVEL 2 + +struct TORCH_API ForwardADLevel { + ForwardADLevel(uint64_t idx): idx_(idx) {} + ~ForwardADLevel(); + + static uint64_t get_next_idx(); + static void release_idx(uint64_t idx); + static std::shared_ptr get_by_idx(uint64_t idx); + static std::shared_ptr try_get_by_idx(uint64_t idx); + + void erase(const std::shared_ptr& grad) { + std::lock_guard lock(mutex_); + grads_.erase(grad); + } + + void insert(const std::shared_ptr& grad) { + std::lock_guard lock(mutex_); + grads_.insert(grad); + } + +private: + std::unordered_set> grads_; + std::mutex mutex_; + uint64_t idx_; + +}; + +struct TORCH_API ForwardGrad : std::enable_shared_from_this { + + ForwardGrad() {} + + // This function must only be called when AutogradMeta or SavedVariable is being + // destructed as it ensures that: + // - The only (potential) other references to this ForwardGrad are the + // different level it is registered to + // - No other thread will try to call `set_value` or `value` ever from now on + // - Any of the ForwardADLevel that this ForwardGrad is registered with might + // call `reset` at any point during this function + void clear() { + c10::SmallVector levels_idx; + + { + std::lock_guard lock(mutex_); + for (auto& c: content_) { + levels_idx.push_back(c.first); + } + } + + for (auto l_idx: levels_idx) { + // Use "try" version here as another thread might have deleted this + // level before we got here + // This is an owning reference as we want to keep the level alive + // until we successfully unregister ourselves + auto level = ForwardADLevel::try_get_by_idx(l_idx); + if (level) { + level->erase(shared_from_this()); + } + } + } + + void set_value(const at::Tensor& value, uint64_t level) { + // Owning reference to ensure the forward_level is not destroyed + // while we are updating our internal state + auto forward_level = ForwardADLevel::get_by_idx(level); + forward_level->insert(shared_from_this()); + + std::lock_guard lock(mutex_); + content_.insert({level, value}); + } + + // This function removes the tangent for a given level from this ForwardGrad + // Use the update_level flag to disable notifying the level about this reset + // This flag is most notably used by the ForwardADLevel destructor. + void reset(uint64_t level, bool update_level=true) { + if (update_level) { + ForwardADLevel::get_by_idx(level)->erase(shared_from_this()); + } + + std::lock_guard lock(mutex_); + content_.erase(level); + } + + const at::Tensor& value(uint64_t level) const; + + bool contains(uint64_t level) { + std::lock_guard lock(mutex_); + return content_.count(level) > 0; + } + + bool empty() const { + return content_.empty(); + } + + static const at::Tensor& undef_grad(); + + +private: + // TODO(albanD): replace this with a SmallVector + std::unordered_map content_; + mutable std::mutex mutex_; + +}; + +// Temporary functions to disable forward AD +// TODO(alband) remove these when perf issues are solved +bool TORCH_API isForwardADEnabled(); +void TORCH_API setForwardADEnabled(bool value); + +}} // namespace torch::autograd diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index 17d4f5473880a..44171e1a3b1ba 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -114,6 +114,10 @@ struct TORCH_API Node : std::enable_shared_from_this { // We are tracking the parents to track multiple backward operations. assign_parent(); } + + if (profiler::profilerEnabled()) { + thread_id_ = at::RecordFunction::currentThreadId(); + } } explicit Node(edge_list&& next_edges = edge_list()) @@ -129,13 +133,33 @@ struct TORCH_API Node : std::enable_shared_from_this { /// Evaluates the function on the given inputs and returns the result of the /// function call. variable_list operator()(variable_list&& inputs) { - RECORD_FUNCTION( - name(), std::vector(inputs.begin(), inputs.end()), sequence_nr()); // In the first iteration of named tensors, autograd ignores names and // operates on unnamed tensors. In the long term, autograd should // probably operate with names. at::NoNamesGuard no_names_guard; - return apply(std::move(inputs)); + + bool pre_sampled = false; + if (at::shouldRunRecordFunction(&pre_sampled)) { + // Using RecordFunction to trogger observers in the backward pass + at::RecordFunction guard(at::RecordScope::BACKWARD_FUNCTION, pre_sampled); + if (guard.isActive()) { + // Using sequence number and thread id to correlate with + // the forward pass function + guard.setForwardThreadId(thread_id_); + if (guard.needsInputs()) { + guard.before( + name(), + std::vector(inputs.begin(), inputs.end()), + sequence_nr()); + } else { + guard.before(name(), sequence_nr()); + } + } + // keeping stack guard object alive during the call + return apply(std::move(inputs)); + } else { + return apply(std::move(inputs)); + } } // Graph Connectivity API @@ -241,6 +265,11 @@ struct TORCH_API Node : std::enable_shared_from_this { // assigning a node as a parent to this node void assign_parent(); + /// Id of the thread that created Node + uint64_t thread_id() const noexcept { + return thread_id_; + } + /// Returns the name of the dynamic type of the function, for debugging. virtual std::string name() const; @@ -362,6 +391,9 @@ struct TORCH_API Node : std::enable_shared_from_this { // fields. const uint64_t sequence_nr_; + // Id of the thread that created the instance + uint64_t thread_id_ = 0; + // Note [Thread Safety on Autograd Node] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Autograd Engine let the owning thread which calls Engine::execute to drive the diff --git a/torch/csrc/autograd/functions/accumulate_grad.h b/torch/csrc/autograd/functions/accumulate_grad.h index dafd07f64b84f..fdc66e9cd4225 100644 --- a/torch/csrc/autograd/functions/accumulate_grad.h +++ b/torch/csrc/autograd/functions/accumulate_grad.h @@ -4,6 +4,7 @@ #include #include #include +#include #include @@ -149,6 +150,14 @@ struct TORCH_API AccumulateGrad : public Node { auto result = new_grad + variable_grad; CHECK_RESULT(result, variable); update_grad(std::move(result)); + } else if (!at::inplaceIsVmapCompatible(variable_grad, new_grad)) { + // Ideally we'd perform an in-place operation to avoid changing + // the grad tensor. However, if that's impossible because the grads + // are vmap-incompatible (See NOTE: [vmap-incompatible in-place operations]), + // then we just add them out-of-place. + auto result = variable_grad + new_grad; + CHECK_RESULT(result, variable); + update_grad(std::move(result)); } else { // In this case we can avoid changing the grad tensor. There are three // scenarios when we'll hit this case: diff --git a/torch/csrc/autograd/functions/basic_ops.cpp b/torch/csrc/autograd/functions/basic_ops.cpp index 6ce068bd58de9..b5991b87f8352 100644 --- a/torch/csrc/autograd/functions/basic_ops.cpp +++ b/torch/csrc/autograd/functions/basic_ops.cpp @@ -47,4 +47,8 @@ auto UndefinedGradBackward::apply(variable_list&& output_grads) -> variable_list return input_grads; } +auto Identity::apply(variable_list&& grads) -> variable_list { + return std::move(grads); +} + }} // namespace torch::autograd diff --git a/torch/csrc/autograd/functions/basic_ops.h b/torch/csrc/autograd/functions/basic_ops.h index 48f20ec408b07..8a312b7baf0cc 100644 --- a/torch/csrc/autograd/functions/basic_ops.h +++ b/torch/csrc/autograd/functions/basic_ops.h @@ -68,7 +68,13 @@ struct TORCH_API UndefinedGradBackward : public Node { struct TORCH_API GraphRoot : public Node { GraphRoot(edge_list functions, variable_list inputs) : Node(std::move(functions)), - outputs(std::move(inputs)) {} + outputs(std::move(inputs)) { + // Ensures calls to stream() on a GraphRoot instance reflect current stream(s) + // on devices of root grad tensors at the time the instance is constructed. + for (const auto& t : outputs) { + add_input_metadata(t); + } + } variable_list apply(variable_list&& inputs) override { return outputs; @@ -77,4 +83,8 @@ struct TORCH_API GraphRoot : public Node { variable_list outputs; }; +struct TORCH_API Identity : public Node { + variable_list apply(variable_list&& inputs) override; +}; + }} diff --git a/torch/csrc/autograd/functions/comm.cpp b/torch/csrc/autograd/functions/comm.cpp index adb167656a44f..a5817b72754ad 100644 --- a/torch/csrc/autograd/functions/comm.cpp +++ b/torch/csrc/autograd/functions/comm.cpp @@ -119,10 +119,17 @@ variable_list Gather::apply(variable_list&& inputs) { } } - // This is special logic for torch::cuda::gather! - const auto destination_index = - destination_device_.is_cpu() ? -1 : destination_device_.index(); - auto variable = torch::cuda::gather(tensors, dim_, destination_index); + // Disable the autograd during the actual computation + // torch::cuda::gather does not return a view or change things inplace + // so no need for extra logic here + at::Tensor variable; + { + at::AutoNonVariableTypeMode non_var_type_mode(true); + // This is special logic for torch::cuda::gather! + const auto destination_index = + destination_device_.is_cpu() ? -1 : destination_device_.index(); + variable = torch::cuda::gather(tensors, dim_, destination_index); + } if (grad_fn) { set_history(variable, grad_fn); } diff --git a/torch/csrc/autograd/functions/tensor.cpp b/torch/csrc/autograd/functions/tensor.cpp index 7b6ab0769e4d2..34443d7035327 100644 --- a/torch/csrc/autograd/functions/tensor.cpp +++ b/torch/csrc/autograd/functions/tensor.cpp @@ -42,7 +42,7 @@ auto CopyBackwards::apply(variable_list&& grads) -> variable_list { CopySlices::CopySlices( const Variable& base_var, at::TensorGeometry view_, - c10::optional> view_fn_, + std::function view_fn_, std::shared_ptr fn_) : Node(), base(base_var), @@ -75,13 +75,12 @@ auto CopySlices::apply(variable_list&& inputs) -> variable_list { throw std::runtime_error(ERR_BACKWARD_TWICE); } - auto result = at::empty_strided(base.sizes(), base.strides(), grad.options()); + auto result = grad.new_empty_strided(base.sizes(), base.strides()); result.copy_(grad); at::Tensor grad_slice; - if (view_fn.has_value()) { - auto fn = view_fn.value(); - grad_slice = fn(result); + if (view_fn) { + grad_slice = view_fn(result); } else { auto offset = view.storage_offset() - base.storage_offset(); grad_slice = result.as_strided(view.sizes(), view.strides(), offset); diff --git a/torch/csrc/autograd/functions/tensor.h b/torch/csrc/autograd/functions/tensor.h index 7a71d6430c7e2..1e2bbc3d3ecfa 100644 --- a/torch/csrc/autograd/functions/tensor.h +++ b/torch/csrc/autograd/functions/tensor.h @@ -86,7 +86,7 @@ struct TORCH_API CopySlices : public Node { CopySlices( const Variable& base_var, at::TensorGeometry view_, - c10::optional> view_fn_, + std::function view_fn_, std::shared_ptr fn_); variable_list apply(variable_list&& inputs) override; @@ -96,7 +96,7 @@ struct TORCH_API CopySlices : public Node { // view and view_fn are redundant and view_fn will be used if available. // See Note [View + Inplace update for base tensor] for details. at::TensorGeometry view; - c10::optional> view_fn; + std::function view_fn; std::shared_ptr fn; }; diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 69759d1948b21..975f1bf954a09 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -1,12 +1,17 @@ #include +#include #include #include +#include #include #include #include #include #include +#include +#include +#include PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { using namespace torch::autograd::profiler; @@ -34,46 +39,153 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { auto _C_m = py::handle(torch_C_module).cast(); auto m = _C_m.def_submodule("_autograd", "autograd bindings"); + auto parameter_module = THPObjectPtr(PyImport_ImportModule("torch.nn.parameter")); + if (!parameter_module) + return nullptr; + + // NOTE: "leaks" ParameterClass + ParameterClass = PyObject_GetAttrString(parameter_module, "Parameter"); + if (!ParameterClass) + return nullptr; py::enum_(m, "ProfilerState") .value("Disabled", ProfilerState::Disabled) .value("CPU", ProfilerState::CPU) .value("CUDA", ProfilerState::CUDA) - .value("NVTX", ProfilerState::NVTX); + .value("NVTX", ProfilerState::NVTX) + .value("KINETO", ProfilerState::KINETO); + + py::enum_(m, "ProfilerActivity") + .value("CPU", ActivityType::CPU) + .value("CUDA", ActivityType::CUDA); py::class_(m, "ProfilerConfig") - .def(py::init()); - - py::class_(m, "ProfilerEvent") - .def("kind", &Event::kind) - .def("name", [](const Event& e) { return e.name(); }) - .def("thread_id", &Event::thread_id) - .def("device", &Event::device) - .def("cpu_elapsed_us", &Event::cpu_elapsed_us) - .def("cuda_elapsed_us", &Event::cuda_elapsed_us) - .def("has_cuda", &Event::has_cuda) - .def("shapes", &Event::shapes) - .def("cpu_memory_usage", &Event::cpu_memory_usage) - .def("cuda_memory_usage", &Event::cuda_memory_usage) - .def("handle", &Event::handle) - .def("node_id", &Event::node_id) - .def("is_remote", &Event::isRemote) - .def("sequence_nr", &Event::sequence_nr); + .def(py::init()); + + py::class_(m, "ProfilerEvent") + .def("kind", &LegacyEvent::kindStr) + .def("name", [](const LegacyEvent& e) { return e.name(); }) + .def("thread_id", &LegacyEvent::threadId) + .def("fwd_thread_id", &LegacyEvent::fwdThreadId) + .def("device", &LegacyEvent::device) + .def("cpu_elapsed_us", &LegacyEvent::cpuElapsedUs) + .def("cuda_elapsed_us", &LegacyEvent::cudaElapsedUs) + .def("has_cuda", &LegacyEvent::hasCuda) + .def("shapes", &LegacyEvent::shapes) + .def("cpu_memory_usage", &LegacyEvent::cpuMemoryUsage) + .def("cuda_memory_usage", &LegacyEvent::cudaMemoryUsage) + .def("handle", &LegacyEvent::handle) + .def("node_id", &LegacyEvent::nodeId) + .def("is_remote", &LegacyEvent::isRemote) + .def("sequence_nr", &LegacyEvent::sequenceNr) + .def("stack", &LegacyEvent::stack) + .def("scope", &LegacyEvent::scope) + .def("correlation_id", &LegacyEvent::correlationId) + .def("start_us", &LegacyEvent::cpuUs) + .def("flops", &LegacyEvent::flops); + + py::enum_(m, "DeviceType") + .value("CPU", c10::DeviceType::CPU) + .value("CUDA", c10::DeviceType::CUDA) + .value("MKLDNN", c10::DeviceType::MKLDNN) + .value("OPENGL", c10::DeviceType::OPENGL) + .value("OPENCL", c10::DeviceType::OPENCL) + .value("IDEEP", c10::DeviceType::IDEEP) + .value("HIP", c10::DeviceType::HIP) + .value("FPGA", c10::DeviceType::FPGA) + .value("MSNPU", c10::DeviceType::MSNPU) + .value("XLA", c10::DeviceType::XLA) + .value("Vulkan", c10::DeviceType::Vulkan) + .value("Metal", c10::DeviceType::Metal); + +#ifdef USE_KINETO + py::class_(m, "KinetoEvent") + // name of the event + .def("name", &KinetoEvent::name) + // PyTorch thread id of the start callback + .def("start_thread_id", [](const KinetoEvent& e) { + return e.startThreadId(); + }) + // PyTorch thread id of the end callback + .def("end_thread_id", [](const KinetoEvent& e) { + return e.endThreadId(); + }) + // for events of scope BACKWARD_FUNCTION - PyTorch thread id + // of the corresponding forward op + .def("fwd_thread_id", [](const KinetoEvent& e) { + return e.fwdThreadId(); + }) + // together with fwd_thread_id, used to uniquely identify + // the forward op + .def("sequence_nr", [](const KinetoEvent& e) { + return e.sequenceNr(); + }) + // absolute start time (since unix epoch) in us + .def("start_us", &KinetoEvent::startUs) + // duration in us + .def("duration_us", &KinetoEvent::durationUs) + // used for correlation between high-level PyTorch events + // and low-level device events + .def("correlation_id", [](const KinetoEvent& e) { + return e.correlationId(); + }) + // shapes of input tensors + .def("shapes", [](const KinetoEvent& e) { + if (e.hasShapes()) { + return e.shapes(); + } else { + return std::vector>(); + } + }) + // stack traces of the PyTorch CPU events + .def("stack", [](const KinetoEvent& e) { + if (e.hasStack()) { + return e.stack(); + } else { + return std::vector(); + } + }) + // type of the RecordFunction that generated a PyTorch CPU event + // (op, torchscript function, user label, etc) + .def("scope", [](const KinetoEvent& e) { + return e.scope(); + }) + // device number, for CPU - process id + .def("device_index", &KinetoEvent::deviceIndex) + // for CUDA - stream id, for CPU - start thread id + .def("device_resource_id", &KinetoEvent::deviceResourceId) + // device type + .def("device_type", [](const KinetoEvent& e) { + return e.deviceType(); + }) + // correlation id of a linked event + .def("linked_correlation_id", &KinetoEvent::linkedCorrelationId); + + py::class_(m, "ProfilerResult") + .def("events", &ProfilerResult::events) + .def("legacy_events", &ProfilerResult::legacy_events) + .def("save", &ProfilerResult::save); m.def("_enable_profiler", enableProfiler); + m.def("_disable_profiler", disableProfiler); + m.def("_prepare_profiler", prepareProfiler); +#endif + + m.def("kineto_available", kinetoAvailable); + + m.def("_enable_profiler_legacy", enableProfilerLegacy); + py::class_(m, "_ProfilerDisableOptions") + .def(py::init()); m.def( - "_disable_profiler", - disableProfiler, - py::arg("cleanup_tls_states") = true, - py::arg("consolidate") = true); + "_disable_profiler_legacy", + disableProfilerLegacy, + py::arg("profiler_disable_options") = ProfilerDisableOptions()); m.def("_profiler_enabled", profilerEnabled); m.def("_enable_record_function", [](bool enable) { at::enableRecordFunction(enable); }); m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) { - auto cb = at::RecordFunctionCallback( - [](const at::RecordFunction&) {}, - [](const at::RecordFunction&) {}) + auto cb = at::RecordFunctionCallback(nullptr) .needsInputs(true) .samplingProb(sampling_prob); if (is_global) { @@ -130,6 +242,26 @@ static PyObject * autocast_decrement_nesting(PyObject* _unused, PyObject *arg) { END_HANDLE_TH_ERRORS } +static PyObject * set_forward_AD_enabled(PyObject* _unused, PyObject *arg) { + HANDLE_TH_ERRORS + if (!PyBool_Check(arg)) { + throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name); + } + setForwardADEnabled(arg == Py_True); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject * is_forward_AD_enabled(PyObject* _unused, PyObject *arg) { + HANDLE_TH_ERRORS + if (isForwardADEnabled()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + static PyObject * set_grad_enabled(PyObject* _unused, PyObject *arg) { HANDLE_TH_ERRORS if (!PyBool_Check(arg)) { @@ -170,17 +302,43 @@ static PyObject * is_anomaly_mode_enabled(PyObject* _unused, PyObject *arg) { END_HANDLE_TH_ERRORS } +static PyObject * python_enter_dual_level(PyObject* _unused, PyObject* arg) { + HANDLE_TH_ERRORS + // It is unlikely that the depth of forward nesting will overflow int64_t so we + // just static cast here. + return utils::wrap(static_cast(forward_ad::enter_dual_level())); + END_HANDLE_TH_ERRORS +} + +static PyObject * python_exit_dual_level(PyObject* _unused, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "exit_dual_level(int64_t level)" + }); + + ParsedArgs<1> parsed_args; + auto _r = parser.parse(args, kwargs, parsed_args); + + forward_ad::exit_dual_level(_r.toInt64(0)); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + // autograd methods on torch._C static PyMethodDef methods[] = { // NOLINT - {"_set_grad_enabled", (PyCFunction)set_grad_enabled, METH_O, nullptr}, - {"is_grad_enabled", (PyCFunction)is_grad_enabled, METH_NOARGS, nullptr}, - {"set_autocast_enabled", (PyCFunction)set_autocast_enabled, METH_O, nullptr}, - {"is_autocast_enabled", (PyCFunction)is_autocast_enabled, METH_NOARGS, nullptr}, - {"clear_autocast_cache", (PyCFunction)clear_autocast_cache, METH_NOARGS, nullptr}, - {"autocast_increment_nesting", (PyCFunction)autocast_increment_nesting, METH_NOARGS, nullptr}, - {"autocast_decrement_nesting", (PyCFunction)autocast_decrement_nesting, METH_NOARGS, nullptr}, - {"set_anomaly_enabled", (PyCFunction)set_anomaly_mode_enabled, METH_O, nullptr}, - {"is_anomaly_enabled", (PyCFunction)is_anomaly_mode_enabled, METH_NOARGS, nullptr}, + {"_set_grad_enabled", set_grad_enabled, METH_O, nullptr}, + {"is_grad_enabled", is_grad_enabled, METH_NOARGS, nullptr}, + {"_set_forward_AD_enabled", set_forward_AD_enabled, METH_O, nullptr}, + {"_is_forward_AD_enabled", is_forward_AD_enabled, METH_NOARGS, nullptr}, + {"set_autocast_enabled", set_autocast_enabled, METH_O, nullptr}, + {"is_autocast_enabled", is_autocast_enabled, METH_NOARGS, nullptr}, + {"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr}, + {"autocast_increment_nesting", autocast_increment_nesting, METH_NOARGS, nullptr}, + {"autocast_decrement_nesting", autocast_decrement_nesting, METH_NOARGS, nullptr}, + {"set_anomaly_enabled", set_anomaly_mode_enabled, METH_O, nullptr}, + {"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr}, + {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr}, + {"_exit_dual_level", castPyCFunctionWithKeywords(python_exit_dual_level), METH_VARARGS | METH_KEYWORDS, nullptr}, {nullptr, nullptr, 0, nullptr} }; diff --git a/torch/csrc/autograd/input_buffer.h b/torch/csrc/autograd/input_buffer.h index 02bcde5d9f968..6aa498f056852 100644 --- a/torch/csrc/autograd/input_buffer.h +++ b/torch/csrc/autograd/input_buffer.h @@ -21,6 +21,7 @@ struct InputBuffer { : buffer(size) {} InputBuffer(const InputBuffer& other) = delete; InputBuffer(InputBuffer&& other) = default; + explicit InputBuffer(variable_list&& inputs): buffer(std::move(inputs)) {}; InputBuffer& operator=(InputBuffer&& other) = default; // Accumulates the variable at a specified index. diff --git a/torch/csrc/autograd/profiler.h b/torch/csrc/autograd/profiler.h index 6a7c5095a0716..7ac44096cda70 100644 --- a/torch/csrc/autograd/profiler.h +++ b/torch/csrc/autograd/profiler.h @@ -1,402 +1,4 @@ #pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#ifndef _WIN32 -#include -#endif -#if defined(C10_IOS) && defined(C10_MOBILE) -#include // for gettimeofday() -#endif - -#include - -struct CUevent_st; -typedef std::shared_ptr CUDAEventStub; - -namespace torch { namespace autograd { - -struct Node; - -namespace profiler { - -struct TORCH_API CUDAStubs { - virtual void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) { - fail(); - } - virtual float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) { - fail(); - return 0.f; - } - virtual void nvtxMarkA(const char* name) { - fail(); - } - virtual void nvtxRangePushA(const char* name) { - fail(); - } - virtual void nvtxRangePop() { - fail(); - } - virtual bool enabled() { - return false; - } - virtual void onEachDevice(std::function op) { - fail(); - } - virtual void synchronize() { - fail(); - } - virtual ~CUDAStubs(); - -private: - void fail() { - AT_ERROR("CUDA used in profiler but not enabled."); - } -}; - -TORCH_API void registerCUDAMethods(CUDAStubs* stubs); - -constexpr inline size_t ceilToMultiple(size_t a, size_t b) { - return ((a + b - 1) / b) * b; -} - -inline int64_t getTime() { -#if defined(C10_IOS) && defined(C10_MOBILE) -// clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS can't rely on -// CLOCK_REALTIME, as it is defined no matter if clock_gettime is implemented or not - struct timeval now; - gettimeofday(&now, NULL); - return static_cast(now.tv_sec) * 1000000000 + static_cast(now.tv_usec) * 1000; -#elif defined(_WIN32) || defined(__MACH__) - using namespace std::chrono; - using clock = std::conditional::type; - return duration_cast(clock::now().time_since_epoch()).count(); -#else - // clock_gettime is *much* faster than std::chrono implementation on Linux - struct timespec t{}; - clock_gettime(CLOCK_MONOTONIC, &t); - return static_cast(t.tv_sec) * 1000000000 + static_cast(t.tv_nsec); -#endif -} - -enum class C10_API_ENUM ProfilerState { - Disabled, - CPU, // CPU-only profiling - CUDA, // CPU + CUDA events - NVTX, // only emit NVTX markers -}; - -struct TORCH_API ProfilerConfig { - ProfilerConfig( - ProfilerState state, - bool report_input_shapes, - bool profile_memory) - : state(state), - report_input_shapes(report_input_shapes), - profile_memory(profile_memory) {} - ~ProfilerConfig(); - ProfilerState state; - bool report_input_shapes; - bool profile_memory; - - // Returns IValues corresponding to ProfilerConfig struct, to be used for - // serialization. - at::IValue toIValue() const; - - // Reconstructs a ProfilerConfig from IValues given by toIValue. - static ProfilerConfig fromIValue(const at::IValue& profilerConfigIValue); - -}; - -enum class C10_API_ENUM EventKind : uint16_t { - Mark, - PushRange, - PopRange, - MemoryAlloc, -}; - -struct TORCH_API Event final { - Event( - EventKind kind, - at::StringView name, - uint16_t thread_id, - bool record_cuda, - at::RecordFunctionHandle handle = 0, - std::vector>&& shapes = {}, - int node_id = -1) - : name_(std::move(name)), - kind_(kind), - thread_id_(thread_id), - handle_(handle), - shapes_(shapes), - node_id_(node_id) { - record(record_cuda); - } - - // Constructor to be used in conjunction with Event::fromIValue. - Event( - EventKind kind, - at::StringView name, - uint16_t thread_id, - at::RecordFunctionHandle handle, - std::vector>&& shapes, - int node_id, - bool is_remote, - int64_t cpu_memory_usage, - int64_t cpu_ns, - bool cuda_recorded, - int64_t cuda_memory_usage = 0, - int device = -1, - double cuda_us = -1) - : cpu_ns_(cpu_ns), - name_(std::move(name)), - kind_(kind), - thread_id_(thread_id), - handle_(handle), - shapes_(shapes), - cpu_memory_usage_(cpu_memory_usage), - cuda_memory_usage_(cuda_memory_usage), - device_(device), - node_id_(node_id), - is_remote_(is_remote), - cuda_us_(cuda_us) { - // Sanity check values that were deserialized - TORCH_INTERNAL_ASSERT(cpu_ns_ > 0); - if (cuda_recorded) { - TORCH_INTERNAL_ASSERT(device_ >= 0); - TORCH_INTERNAL_ASSERT(cuda_us_ >= 0); - } - } - - // Returns IValues corresponding to event structure, to be used for - // serialization. - at::IValue toIValue() const; - - // Reconstructs an event from IValues given by toIValue. - static Event fromIValue(const at::IValue& eventIValue); - - void record(bool record_cuda); - std::string kind() const { - switch(kind_) { - case EventKind::Mark: return "mark"; - case EventKind::PushRange: return "push"; - case EventKind::PopRange: return "pop"; - case EventKind::MemoryAlloc: return "memory_alloc"; - } - throw std::runtime_error("unknown EventKind"); - } - - // Get enum kind of this event. - EventKind eventKind() const { - return kind_; - } - - const char* name() const { - return name_.str(); - } - uint16_t thread_id() const { - return thread_id_; - } - std::vector> shapes() const { - return shapes_; - } - double cpu_elapsed_us(const Event & e) const { - return (e.cpu_ns_ - cpu_ns_)/(1000.0); - } - - double cpu_us() const { - return cpu_ns_ / (1000.0); - } - - double cuda_elapsed_us(const Event & e) const; - bool has_cuda() const { - return cuda_event != nullptr || (isRemote() && device_ != -1); - } - int device() const { - return device_; - } - - void updateMemoryStats(int64_t alloc_size, c10::Device device) { - if (device.type() == c10::DeviceType::CUDA || - device.type() == c10::DeviceType::HIP) { - cuda_memory_usage_ = alloc_size; - } else if (device.type() == c10::DeviceType::CPU || - device.type() == c10::DeviceType::MKLDNN || - device.type() == c10::DeviceType::IDEEP) { - cpu_memory_usage_ = alloc_size; - } else { - LOG(WARNING) << "Unsupported memory profiling device: " << device; - } - } - - int64_t cpu_memory_usage() const { - return cpu_memory_usage_; - } - - int64_t cuda_memory_usage() const { - return cuda_memory_usage_; - } - - at::RecordFunctionHandle handle() const { - return handle_; - } - - // Node ID corresponding to this event. - int node_id( ) const { - return node_id_; - } - - // Set Node ID on this event. - void setNodeId(int node_id) { - node_id_ = node_id; - } - - void setName(at::StringView newName_) { - name_ = std::move(newName_); - } - - bool isRemote() const { - return is_remote_; - } - - void setCudaUs(int64_t cuda_us) { - cuda_us_ = cuda_us; - } - - void setSequenceNr(int64_t sequence_nr) { - sequence_nr_ = sequence_nr; - } - - int64_t sequence_nr() const { - return sequence_nr_; - } - - private: - // signed to allow for negative intervals, initialized for safety. - int64_t cpu_ns_ = 0; - at::StringView name_; - EventKind kind_; - uint16_t thread_id_; - at::RecordFunctionHandle handle_ {0}; - std::vector> shapes_; - int64_t cpu_memory_usage_ = 0; - int64_t cuda_memory_usage_ = 0; - int device_ = -1; - CUDAEventStub cuda_event = nullptr; - int node_id_ = 0; - bool is_remote_ = false; - int64_t cuda_us_ = -1; - int64_t sequence_nr_ = -1; -}; - -// a linked-list of fixed sized vectors, to avoid -// a std::vector resize from taking a large amount of time inside -// a profiling event -struct RangeEventList { - RangeEventList() { - events_.reserve(kReservedCapacity); - } - - template - void record(Args&&... args) { - std::lock_guard guard(mutex_); - events_.emplace_back(std::forward(args)...); - } - - std::vector consolidate() { - std::lock_guard lock(mutex_); - std::vector result; - result.insert( - result.begin(), - std::make_move_iterator(events_.begin()), - std::make_move_iterator(events_.end())); - events_.erase(events_.begin(), events_.end()); - return result; - } - - size_t size() { - std::lock_guard lock(mutex_); - return events_.size(); - } - - private: - // This mutex is used to serialize access when different threads are writing - // to the same instance of RangeEventList. - std::mutex mutex_; - std::vector events_; - - static const size_t kReservedCapacity = 1024; -}; - -using thread_event_lists = std::vector>; -// NOTE: profiler mode is thread local, with automatic propagation -// across thread boundary (e.g. at::launch tasks) -TORCH_API void enableProfiler(const ProfilerConfig&); -TORCH_API thread_event_lists disableProfiler(bool cleanupTLSState = true, bool consolidate = true); -// adds profiledEvents to the current thread local recorded events. Each event -// will be marked with node ID given by fromNodeId. -TORCH_API void addEventList(std::vector&& profiledEvents); -// Returns if the profiler is currently enabled in the current thread. -TORCH_API bool profilerEnabled(); -// Retrieve the thread_local ProfilerConfig. -TORCH_API ProfilerConfig getProfilerConfig(); -// Writes profiled events to a stream. -TORCH_API void writeProfilerEventsToStream(std::ostream& out, const std::vector& events); - -// Usage: -// { -// RecordProfile guard("filename.trace"); -// // code you want to profile -// } -// Then open filename.trace in chrome://tracing -struct TORCH_API RecordProfile { - RecordProfile(std::ostream& out); - RecordProfile(const std::string& filename); - - ~RecordProfile(); -private: - void init(); - std::unique_ptr file_; - std::ostream& out_; - void processEvents(const std::vector& events); -}; - -// A guard that enables the profiler, taking in an optional callback to process -// the results -// Usage: -// { -// TLSProfilerGuard g([](thread_event_lists profilerResults) { -// // process profilerResults -// }); -// Code to profile -// } -struct TORCH_API TLSProfilerGuard { - explicit TLSProfilerGuard( - const ProfilerConfig& cfg, - c10::optional> - resultCallback = c10::nullopt) - : cb_(std::move(resultCallback)) { - enableProfiler(cfg); - } - ~TLSProfilerGuard() { - thread_event_lists event_lists = disableProfiler(); - if (cb_) { - (*cb_)(event_lists); - } - } - - private: - c10::optional> cb_; -}; - -} // namespace profiler -}} // namespace torch::autograd +#include +#include diff --git a/torch/csrc/autograd/profiler_cuda.cpp b/torch/csrc/autograd/profiler_cuda.cpp index ad677dbc6680c..14dff19629fde 100644 --- a/torch/csrc/autograd/profiler_cuda.cpp +++ b/torch/csrc/autograd/profiler_cuda.cpp @@ -32,7 +32,7 @@ static inline void cudaCheck(cudaError_t result, const char * file, int line) { #define TORCH_CUDA_CHECK(result) cudaCheck(result,__FILE__,__LINE__); struct CUDAMethods : public CUDAStubs { - void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) override { + void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) const override { TORCH_CUDA_CHECK(cudaGetDevice(device)); CUevent_st* cuda_event_ptr; TORCH_CUDA_CHECK(cudaEventCreate(&cuda_event_ptr)); @@ -43,23 +43,28 @@ struct CUDAMethods : public CUDAStubs { *cpu_ns = getTime(); TORCH_CUDA_CHECK(cudaEventRecord(cuda_event_ptr, stream)); } - float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) override { + + float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) const override{ TORCH_CUDA_CHECK(cudaEventSynchronize(event->get())); TORCH_CUDA_CHECK(cudaEventSynchronize(event2->get())); float ms; TORCH_CUDA_CHECK(cudaEventElapsedTime(&ms, event->get(), event2->get())); return ms*1000.0; } - void nvtxMarkA(const char* name) override { + + void nvtxMarkA(const char* name) const override { ::nvtxMark(name); } - void nvtxRangePushA(const char* name) override { + + void nvtxRangePushA(const char* name) const override { ::nvtxRangePushA(name); } - void nvtxRangePop() override { + + void nvtxRangePop() const override { ::nvtxRangePop(); } - void onEachDevice(std::function op) override { + + void onEachDevice(std::function op) const override { at::cuda::OptionalCUDAGuard device_guard; int count = at::cuda::device_count(); for(int i = 0; i < count; i++) { @@ -67,13 +72,14 @@ struct CUDAMethods : public CUDAStubs { op(i); } } - void synchronize() override { + + void synchronize() const override { cudaDeviceSynchronize(); } - bool enabled() override { + + bool enabled() const override { return true; } - }; struct RegisterCUDAMethods { diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp new file mode 100644 index 0000000000000..1c3c351eeb09e --- /dev/null +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -0,0 +1,369 @@ +#include + +#include +#include + +#include + +#ifdef USE_KINETO +#include +#include +#endif + +namespace torch { namespace autograd { namespace profiler { + +#ifdef USE_KINETO +namespace { +// TODO: consider TLS (tid + tls counter) +uint64_t next_correlation_id() { + static std::atomic corr_id_ {1}; + return corr_id_++; +} + +inline int64_t getTimeUs() { + using namespace std::chrono; + return duration_cast(high_resolution_clock::now().time_since_epoch()).count(); +} + +std::string shapesToStr(const std::vector>& shapes); + +struct TORCH_API KinetoThreadLocalState : public ProfilerThreadLocalState { + using ProfilerThreadLocalState::ProfilerThreadLocalState; + virtual ~KinetoThreadLocalState() override = default; + + void reportClientActivity( + const at::RecordFunction& fn, + const KinetoObserverContext* ctx) { + if (!ctx) { + return; + } + libkineto::ClientTraceActivity op; + op.startTime = ctx->startUs; + op.endTime = getTimeUs(); + op.opType = std::string(fn.name().str()); + op.device = 0; + op.threadId = ctx->startThreadId; + op.correlation = ctx->correlationId; + // optimization - postpone shapesToStr till finalizeCPUTrace + // is called from disableProfiler + // if (ctx->shapes && !ctx->shapes->empty()) { + // op.inputDims = shapesToStr(*ctx->shapes); + // } + + // Not setting atm + op.inputTypes = "[]"; + op.arguments = "[]"; + op.outputDims = "[]"; + op.outputTypes = "[]"; + op.inputNames = "[]"; + op.outputNames = "[]"; + + // + op.threadId = pthread_self(); + { + std::lock_guard guard(state_mutex_); + kineto_events_.emplace_back(); + kineto_events_.back() + .activity(op) + .startThreadId(ctx->startThreadId) + .endThreadId(ctx->endThreadId) + .sequenceNr(ctx->sequenceNr) + .fwdThreadId(ctx->fwdThreadId) + .scope(ctx->recFunScope); + if (ctx->shapes && !ctx->shapes->empty()) { + kineto_events_.back().shapes(*ctx->shapes); + } + if (ctx->stack && !ctx->stack->empty()) { + kineto_events_.back().stack(*ctx->stack); + } + cpu_trace->activities.emplace_back(std::move(op)); + } + } + + // TODO: use kineto + void reportMemoryUsage( + void* /* unused */, + int64_t alloc_size, + c10::Device device) override { + if (config_.profile_memory && config_.state != ProfilerState::Disabled) { + uint64_t thread_id = at::RecordFunction::currentThreadId(); + LegacyEvent evt( + EventKind::MemoryAlloc, + at::StringView(""), + thread_id, + config_.state == ProfilerState::CUDA); + evt.setCpuUs(getTimeUs()); // upd. time using Kineto's clock + evt.updateMemoryStats(alloc_size, device); + getEventList(thread_id).record(std::move(evt)); + } + } + + void addTraceEvents(libkineto::ActivityTraceInterface& trace) { + const auto& events = *(trace.activities()); + for (const auto& ev_ptr : events) { + // ClientTraceActivity events are already processed + if (ev_ptr->type() != libkineto::ActivityType::CPU_OP) { + kineto_events_.emplace_back(); + kineto_events_.back() + .activity(*ev_ptr); + } + } + } + + void finalizeCPUTrace() { + TORCH_INTERNAL_ASSERT(cpu_trace->activities.size() == kineto_events_.size()); + for (auto idx = 0; idx < cpu_trace->activities.size(); ++idx) { + if (kineto_events_[idx].hasShapes()) { + cpu_trace->activities[idx].inputDims = shapesToStr(kineto_events_[idx].shapes()); + } else { + cpu_trace->activities[idx].inputDims = "[]"; + } + } + } + + std::vector kineto_events_; + std::unique_ptr cpu_trace = + std::make_unique(); +}; + +KinetoThreadLocalState* getProfilerTLSState() { + const auto& state = c10::ThreadLocalDebugInfo::get( + c10::DebugInfoKind::PROFILER_STATE); + return static_cast(state); +} + +void pushProfilingCallbacks() { + auto state_ptr = getProfilerTLSState(); + TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set"); + auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback( + [](const at::RecordFunction& fn) -> std::unique_ptr { + auto state_ptr = getProfilerTLSState(); + if (!state_ptr || state_ptr->config().state != ProfilerState::KINETO) { + return std::make_unique(); + } + + auto corr_id = next_correlation_id(); + libkineto::api().activityProfiler().pushCorrelationId(corr_id); + + auto ctx_ptr = std::make_unique(); + ctx_ptr->startUs = getTimeUs(); + ctx_ptr->correlationId = corr_id; + ctx_ptr->startThreadId = at::RecordFunction::currentThreadId(); + + if (state_ptr->config().report_input_shapes) { + ctx_ptr->shapes = inputSizes(fn); + } + + ctx_ptr->sequenceNr = fn.seqNr(); + ctx_ptr->fwdThreadId = fn.forwardThreadId(); + ctx_ptr->recFunScope = (uint8_t)fn.scope(); + +#ifndef C10_MOBILE + // backward nodes source range corresponds to the forward node + // TODO: consider using C++ stack trace + if (state_ptr->config().with_stack && + fn.scope() != at::RecordScope::BACKWARD_FUNCTION) { + auto cs = prepareCallstack(jit::currentCallstack()); + if (cs.empty()) { + cs = prepareCallstack(jit::tracer::pythonCallstack()); + } + ctx_ptr->stack = callstackStr(cs); + } +#endif + return ctx_ptr; + }, + [](const at::RecordFunction& fn, at::ObserverContext* ctx_ptr) { + auto state_ptr = getProfilerTLSState(); + if (!state_ptr || state_ptr->config().state != ProfilerState::KINETO) { + return; + } + auto* kineto_ctx_ptr = static_cast(ctx_ptr); + TORCH_INTERNAL_ASSERT(kineto_ctx_ptr != nullptr); + + kineto_ctx_ptr->endThreadId = at::RecordFunction::currentThreadId(); + + state_ptr->reportClientActivity(fn, kineto_ctx_ptr); + libkineto::api().activityProfiler().popCorrelationId(); + }) + .needsInputs(state_ptr->config().report_input_shapes) + .needsIds(true)); + state_ptr->setCallbackHandle(handle); +} + +std::string shapesToStr(const std::vector>& shapes) { + std::ostringstream oss; + oss << "["; + for (auto t_idx = 0; t_idx < shapes.size(); ++t_idx) { + if (t_idx > 0) { + oss << ", "; + } + oss << "["; + for (auto s_idx = 0; s_idx < shapes[t_idx].size(); ++s_idx) { + if (s_idx > 0) { + oss << ", "; + } + oss << shapes[t_idx][s_idx]; + } + oss << "]"; + } + oss << "]"; + return oss.str(); +} + +} // namespace + +void prepareProfiler( + const ProfilerConfig& config, + const std::set& activities) { + TORCH_CHECK(config.state == ProfilerState::KINETO, + "Supported only in Kineto profiler"); + + std::set cpuTypes = { + libkineto::ActivityType::CPU_OP, + libkineto::ActivityType::EXTERNAL_CORRELATION, + libkineto::ActivityType::CUDA_RUNTIME, + }; + + std::set cudaTypes = { + libkineto::ActivityType::GPU_MEMCPY, + libkineto::ActivityType::GPU_MEMSET, + libkineto::ActivityType::CONCURRENT_KERNEL, + // also including CUDA_RUNTIME + libkineto::ActivityType::CUDA_RUNTIME, + }; + + std::set k_activities; + if (activities.count(ActivityType::CPU)) { + k_activities.insert(cpuTypes.begin(), cpuTypes.end()); + } + if (activities.count(ActivityType::CUDA)) { + k_activities.insert(cudaTypes.begin(), cudaTypes.end()); + } + + if (!libkineto::api().isProfilerRegistered()) { + libkineto_init(); + libkineto::api().suppressLogMessages(); + } + + if (!libkineto::api().isProfilerInitialized()) { + libkineto::api().initProfilerIfRegistered(); + } + + libkineto::api().activityProfiler().prepareTrace(k_activities); +} + +void enableProfiler( + const ProfilerConfig& config, + const std::set& activities) { + TORCH_CHECK(config.state == ProfilerState::KINETO); + TORCH_CHECK(!activities.empty(), "No activities specified for Kineto profiler"); + + auto state_ptr = getProfilerTLSState(); + TORCH_CHECK(!state_ptr, "Profiler is already enabled on this thread"); + auto state = std::make_shared(config); + c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state); + + state->cpu_trace = std::make_unique(); + state->cpu_trace->span.startTime = getTimeUs(); + // TODO: number of GPU ops + state->cpu_trace->gpuOpCount = -1; + state->cpu_trace->span.name = "PyTorch Profiler"; + + if (activities.count(ActivityType::CPU)) { + pushProfilingCallbacks(); + } + + libkineto::api().activityProfiler().startTrace(); + + state->mark("__start_profile", false); +} + +std::unique_ptr disableProfiler() { + // all the DebugInfoBase objects are scope based and supposed to use DebugInfoGuard + auto state = c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE); + + auto state_ptr = static_cast(state.get()); + TORCH_CHECK(state_ptr && state_ptr->config().state == ProfilerState::KINETO, + "Can't disable Kineto profiler when it's not running"); + + if (state_ptr->hasCallbackHandle()) { + at::removeCallback(state_ptr->callbackHandle()); + } + + state_ptr->mark("__stop_profile"); + + state_ptr->cpu_trace->span.endTime = getTimeUs(); + + state_ptr->finalizeCPUTrace(); + libkineto::api().activityProfiler().transferCpuTrace(std::move(state_ptr->cpu_trace)); + + auto trace = std::move(libkineto::api().activityProfiler().stopTrace()); + TORCH_CHECK(trace); + state_ptr->addTraceEvents(*trace); + return std::make_unique( + std::move(state_ptr->kineto_events_), + std::move(state_ptr->consolidate()), + std::move(trace)); +} + +KinetoEvent& KinetoEvent::activity(const libkineto::TraceActivity& activity) { + name_ = activity.name(); + device_index_ = activity.deviceId(); + device_resource_id_ = activity.resourceId(); + start_us_ = activity.timestamp(); + duration_us_ = activity.duration(); + correlation_id_ = activity.correlationId(); + activity_type_ = (uint8_t)activity.type(); + if (activity.linkedActivity()) { + linked_correlation_id_ = activity.linkedActivity()->correlationId(); + } + return *this; +} + +c10::DeviceType KinetoEvent::deviceType() const { + switch (activity_type_) { + case (uint8_t)libkineto::ActivityType::CPU_OP: + return c10::DeviceType::CPU; + case (uint8_t)libkineto::ActivityType::GPU_MEMCPY: + return c10::DeviceType::CUDA; + case (uint8_t)libkineto::ActivityType::GPU_MEMSET: + return c10::DeviceType::CUDA; + case (uint8_t)libkineto::ActivityType::CONCURRENT_KERNEL: + return c10::DeviceType::CUDA; + case (uint8_t)libkineto::ActivityType::EXTERNAL_CORRELATION: + return c10::DeviceType::CPU; + case (uint8_t)libkineto::ActivityType::CUDA_RUNTIME: + return c10::DeviceType::CPU; + } + TORCH_CHECK(false, "Unknown activity type"); +} + +KinetoEvent::KinetoEvent() : activity_type_((uint8_t)libkineto::ActivityType::CPU_OP) {} + +ProfilerResult::ProfilerResult( + std::vector events, + thread_event_lists legacy_events, + std::unique_ptr trace) + : events_(std::move(events)), + legacy_events_(std::move(legacy_events)), + trace_(std::move(trace)) {} +ProfilerResult::~ProfilerResult() {} + +void ProfilerResult::save(const std::string& path) { + // Kineto's save is destructive + TORCH_CHECK(!saved_, "Trace is already saved"); + trace_->save(path); + saved_ = true; +} + +#endif + +bool kinetoAvailable() { +#ifdef USE_KINETO + return true; +#else + return false; +#endif +} + +}}} diff --git a/torch/csrc/autograd/profiler_kineto.h b/torch/csrc/autograd/profiler_kineto.h new file mode 100644 index 0000000000000..a1c2b2122e415 --- /dev/null +++ b/torch/csrc/autograd/profiler_kineto.h @@ -0,0 +1,213 @@ +#pragma once + +#include + +#ifdef USE_KINETO +namespace libkineto { +class TraceActivity; +class ActivityTraceInterface; +} +#endif + +namespace torch { +namespace autograd { +namespace profiler { + +enum class C10_API_ENUM ActivityType { + CPU = 0, + CUDA, // CUDA kernels, runtime + NUM_KINETO_ACTIVITIES, // must be the last one +}; + +#ifdef USE_KINETO + +struct KinetoObserverContext : public at::ObserverContext { + int64_t startUs; + uint64_t correlationId; + uint64_t startThreadId; + uint64_t endThreadId; + c10::optional>> shapes; + int64_t sequenceNr; + uint64_t fwdThreadId; + uint8_t recFunScope; + c10::optional> stack; +}; + +struct TORCH_API KinetoEvent { + KinetoEvent(); + + uint64_t startThreadId() const { + return start_thread_id_; + } + + uint64_t endThreadId() const { + return end_thread_id_; + } + + uint8_t activityType() const { + return activity_type_; + } + + uint64_t fwdThreadId() const { + return fwd_thread_id_; + } + + bool hasShapes() const { + return shapes_ != c10::nullopt; + } + + const std::vector>& shapes() const { + return *shapes_; + } + + int64_t sequenceNr() const { + return sequence_nr_; + } + + bool hasStack() const { + return stack_ != c10::nullopt; + } + + const std::vector& stack() const { + return *stack_; + } + + uint8_t scope() const { + return scope_; + } + + KinetoEvent& startThreadId(uint64_t start_thread_id) { + start_thread_id_ = start_thread_id; + return *this; + } + + KinetoEvent& endThreadId(uint64_t end_thread_id) { + end_thread_id_ = end_thread_id; + return *this; + } + + KinetoEvent& fwdThreadId(uint64_t fwd_thread_id) { + fwd_thread_id_ = fwd_thread_id; + return *this; + } + + KinetoEvent& shapes(const std::vector>& shapes) { + shapes_ = shapes; + return *this; + } + + KinetoEvent& sequenceNr(int64_t sequence_nr) { + sequence_nr_ = sequence_nr; + return *this; + } + + KinetoEvent& stack(const std::vector& st) { + stack_ = st; + return *this; + } + + KinetoEvent& scope(uint8_t scope) { + scope_ = scope; + return *this; + } + + // Kineto fields + + KinetoEvent& activity(const libkineto::TraceActivity& activity); + + std::string name() const { + return name_; + } + + uint64_t deviceIndex() const { + return device_index_; + } + + uint64_t startUs() const { + return start_us_; + } + + uint64_t durationUs() const { + return duration_us_; + } + + uint64_t correlationId() const { + return correlation_id_; + } + + KinetoEvent& correlationId(uint64_t correlation_id) { + correlation_id_ = correlation_id; + return *this; + } + + uint64_t linkedCorrelationId() const { + return linked_correlation_id_; + } + + int64_t deviceResourceId() const { + return device_resource_id_; + } + + c10::DeviceType deviceType() const; + + uint64_t start_thread_id_ = 0; + uint64_t end_thread_id_ = 0; + uint64_t fwd_thread_id_ = 0; + int64_t sequence_nr_ = -1; + uint8_t scope_ = 0; + + uint8_t activity_type_; + c10::optional>> shapes_; + c10::optional> stack_; + + std::string name_; + uint64_t device_index_ = 0; + uint64_t start_us_ = 0; + uint64_t duration_us_ = 0; + uint64_t correlation_id_ = 0; + uint64_t linked_correlation_id_ = 0; + int64_t device_resource_id_ = 0; +}; + +// Consolidating events returned directly from Kineto +// with events manually created by us (e.g. start/stop marks, +// memory allocation events) +struct TORCH_API ProfilerResult { + ProfilerResult( + std::vector events, + thread_event_lists legacy_events, + std::unique_ptr trace); + ~ProfilerResult(); + + const std::vector& events() const { + return events_; + } + + const thread_event_lists& legacy_events() const { + return legacy_events_; + } + + void save(const std::string& path); + + private: + bool saved_ = false; + std::vector events_; + thread_event_lists legacy_events_; + std::unique_ptr trace_; +}; + +TORCH_API void enableProfiler( + const ProfilerConfig& config, + const std::set& activities); + +TORCH_API std::unique_ptr disableProfiler(); + +TORCH_API void prepareProfiler( + const ProfilerConfig& config, + const std::set& activities); +#endif // USE_KINETO + +TORCH_API bool kinetoAvailable(); + +} // namespace profiler +}} // namespace torch::autograd diff --git a/torch/csrc/autograd/profiler.cpp b/torch/csrc/autograd/profiler_legacy.cpp similarity index 50% rename from torch/csrc/autograd/profiler.cpp rename to torch/csrc/autograd/profiler_legacy.cpp index bab21ee5a7a81..85272677a06b8 100644 --- a/torch/csrc/autograd/profiler.cpp +++ b/torch/csrc/autograd/profiler_legacy.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -22,35 +23,33 @@ namespace torch { namespace autograd { namespace profiler { -namespace { - - enum EventIValueIdx { - KIND = 0, - NAME, - THREAD_ID, - HANDLE, - NODE_ID, - CPU_MEM_USAGE, - CPU_NS, - CUDA_RECORDED, - CUDA_MEM_USAGE, - CUDA_DEVICE, - CUDA_US, - NUM_EVENT_IVALUE_IDX // must be last in list - }; - - enum ProfilerIValueIdx { - STATE = 0, - REPORT_INPUT_SHAPES, - PROFILE_MEMORY, - NUM_PROFILER_CFG_IVALUE_IDX // must be last in list - }; +std::vector prepareCallstack(const std::vector& cs) { + std::vector entries; + entries.reserve(cs.size()); + for (const auto& entry : cs) { + auto& range = entry.range; + if (range.source()) { + auto& src = range.source(); + if (src && src->filename()) { + auto line = src->starting_line_no() + + src->lineno_for_offset(range.start()); + entries.emplace_back(FileLineFunc{*(src->filename()), line, entry.filename}); + } + } + } + return entries; +} -CUDAStubs default_stubs; -constexpr CUDAStubs* default_stubs_addr = &default_stubs; -// Constant initialization, so it is guaranteed to be initialized before -// static initialization calls which may invoke registerCUDAMethods -static CUDAStubs* cuda_stubs = default_stubs_addr; +std::vector callstackStr(const std::vector& cs) { + std::vector cs_str; + cs_str.reserve(cs.size()); + for (const auto& entry : cs) { + std::stringstream loc; + loc << entry.filename << "(" << entry.line << "): " << entry.funcname; + cs_str.push_back(loc.str()); + } + return cs_str; +} // We decompose the profiler logic into the following components: // @@ -116,8 +115,9 @@ static CUDAStubs* cuda_stubs = default_stubs_addr; // - TorchScript functions/methods // - user defined named ranges (see `record_function` python context manager) // -// Profiler setups a pair of callbacks that record profiling events and save them -// into the thread local profiler struct (ThreadLocalDebugInfo, PROFILER_STATE slot) +// Profiler setups a pair of callbacks that record profiling events and save +// them into the thread local profiler struct (ThreadLocalDebugInfo, +// PROFILER_STATE slot) // // // Thus, the overall logic is: @@ -142,244 +142,314 @@ static CUDAStubs* cuda_stubs = default_stubs_addr; // - save profiling events into the profiling state // +namespace { +const CUDAStubs default_stubs; +constexpr const CUDAStubs* default_stubs_addr = &default_stubs; +// Constant initialization, so it is guaranteed to be initialized before +// static initialization calls which may invoke registerCUDAMethods +inline const CUDAStubs*& cuda_stubs() { + static const CUDAStubs* stubs_ = default_stubs_addr; + return stubs_; +} +} + // Profiler state -struct ProfilerThreadLocalState - : public c10::MemoryReportingInfoBase { - explicit ProfilerThreadLocalState( - const ProfilerConfig& config) - : config_(config), remoteProfiledEvents_{c10::nullopt} {} - ~ProfilerThreadLocalState() override = default; - - inline const ProfilerConfig& config() const { - return config_; - } +const ProfilerConfig& ProfilerThreadLocalState::config() const { + return config_; +} - thread_event_lists consolidate() { - std::lock_guard g(state_mutex_); - thread_event_lists result; - for (auto& kv : event_lists_map_) { - auto& list = kv.second; - result.emplace_back(list->consolidate()); - } - // Consolidate remote events if applicable as well. - if (remoteProfiledEvents_) { - result.insert( - result.end(), - std::make_move_iterator(remoteProfiledEvents_->begin()), - std::make_move_iterator(remoteProfiledEvents_->end())); - } - return result; +thread_event_lists ProfilerThreadLocalState::consolidate() { + std::lock_guard g(state_mutex_); + thread_event_lists result; + for (auto& kv : event_lists_map_) { + auto& list = kv.second; + result.emplace_back(list->consolidate()); + } + // Consolidate remote events if applicable as well. + if (remoteProfiledEvents_) { + result.insert( + result.end(), + std::make_move_iterator(remoteProfiledEvents_->begin()), + std::make_move_iterator(remoteProfiledEvents_->end())); } + return result; +} - void mark( - std::string name, - bool include_cuda = true) { - if (config_.state == ProfilerState::Disabled) { - return; - } - if (config_.state == ProfilerState::NVTX) { - cuda_stubs->nvtxMarkA(name.c_str()); - } else { - Event evt( +void ProfilerThreadLocalState::mark(std::string name, bool include_cuda) { + if (config_.state == ProfilerState::Disabled) { + return; + } + if (config_.state == ProfilerState::NVTX) { + cuda_stubs()->nvtxMarkA(name.c_str()); + } else { + LegacyEvent evt( EventKind::Mark, at::StringView(std::move(name)), at::RecordFunction::currentThreadId(), - include_cuda && config_.state == ProfilerState::CUDA - ); - evt.setNodeId(at::RecordFunction::getDefaultNodeId()); - getEventList().record(std::move(evt)); - } + include_cuda && config_.state == ProfilerState::CUDA); + evt.setNodeId(at::RecordFunction::getDefaultNodeId()); + getEventList().record(std::move(evt)); } +} - void setOrAddRemoteProfiledEvents(std::vector&& remoteProfiledEvents) { - // Lock to serialize access from multiple callback threads. - std::lock_guard guard(state_mutex_); - if (remoteProfiledEvents_) { - (*remoteProfiledEvents_).emplace_back(remoteProfiledEvents); - } else { - remoteProfiledEvents_ = {std::move(remoteProfiledEvents)}; - } +void ProfilerThreadLocalState::setOrAddRemoteProfiledEvents( + std::vector&& remoteProfiledEvents) { + // Lock to serialize access from multiple callback threads. + std::lock_guard guard(state_mutex_); + if (remoteProfiledEvents_) { + (*remoteProfiledEvents_).emplace_back(remoteProfiledEvents); + } else { + remoteProfiledEvents_ = {std::move(remoteProfiledEvents)}; } +} - void pushRange( - const at::StringView& name, - const char* msg = "", - int64_t sequence_nr = -1, - std::vector>&& shapes = {}, - at::RecordFunctionHandle handle = 0) { - if (config_.state == ProfilerState::Disabled) { - return; - } - if (config_.state == ProfilerState::NVTX) { - cuda_stubs->nvtxRangePushA(getNvtxStr( - name, msg, sequence_nr, shapes).c_str()); - } else { - Event evt(EventKind::PushRange, - name, - at::RecordFunction::currentThreadId(), - config_.state == ProfilerState::CUDA, - handle, - std::move(shapes), - at::RecordFunction::getDefaultNodeId()); - evt.setSequenceNr(sequence_nr); - getEventList().record(std::move(evt)); - } +void ProfilerThreadLocalState::pushRange( + const at::RecordFunction& fn, + const bool record_cuda, + const char* msg, + std::vector>&& shapes) { + if (config_.state == ProfilerState::Disabled) { + return; } - - void popRange(uint64_t thread_id, at::RecordFunctionHandle handle) { - if (config_.state == ProfilerState::Disabled) { - return; + if (config_.state == ProfilerState::NVTX) { + cuda_stubs()->nvtxRangePushA(getNvtxStr( + fn.name(), msg, fn.seqNr(), shapes).c_str()); + } else { + LegacyEvent evt( + EventKind::PushRange, + fn.name(), + at::RecordFunction::currentThreadId(), + record_cuda, + fn.handle(), + std::move(shapes), + at::RecordFunction::getDefaultNodeId()); + evt.setSequenceNr(fn.seqNr()); + evt.setFwdThreadId(fn.forwardThreadId()); + evt.setScope((uint8_t)fn.scope()); + if (config_.with_flops) { + evt.setExtraArgs(saveExtraArgs(fn)); + evt.setFlops(computeFlops(std::string(fn.name().str()), evt.extraArgs())); } - if (config_.state == ProfilerState::NVTX) { - cuda_stubs->nvtxRangePop(); - } else { - // In some cases RecordFunction (and popRange) may be - // called on a different thread than pushRange - // As a convention, we put the async pop on the original - // thread and save current thread id in pop event - Event evt(EventKind::PopRange, - at::StringView(""), - at::RecordFunction::currentThreadId(), - config_.state == ProfilerState::CUDA, - handle); - evt.setNodeId(at::RecordFunction::getDefaultNodeId()); - getEventList(thread_id).record(std::move(evt)); +#ifndef C10_MOBILE + // backward nodes source range corresponds to the forward node + // TODO: consider using C++ stack trace + if (config_.with_stack && fn.scope() != at::RecordScope::BACKWARD_FUNCTION) { + auto cs = prepareCallstack(jit::currentCallstack()); + if (cs.empty()) { + cs = prepareCallstack(jit::tracer::pythonCallstack()); + } + evt.setStack(callstackStr(cs)); } +#endif + getEventList().record(std::move(evt)); } +} - void setCallbackHandle(at::CallbackHandle handle) { - handle_ = handle; +void ProfilerThreadLocalState::popRange(const at::RecordFunction& fn, const bool record_cuda) { + if (config_.state == ProfilerState::Disabled) { + return; } - - at::CallbackHandle callbackHandle() const { - return handle_; + if (config_.state == ProfilerState::NVTX) { + cuda_stubs()->nvtxRangePop(); + } else { + // In some cases RecordFunction (and popRange) may be + // called on a different thread than pushRange + // As a convention, we put the async pop on the original + // thread and save current thread id in pop event + LegacyEvent evt( + EventKind::PopRange, + at::StringView(""), + at::RecordFunction::currentThreadId(), + record_cuda, + fn.handle()); + evt.setNodeId(at::RecordFunction::getDefaultNodeId()); + getEventList(fn.threadId()).record(std::move(evt)); } +} - void reportMemoryUsage( - void* /* unused */, int64_t alloc_size, c10::Device device) override { - if (config_.profile_memory && config_.state != ProfilerState::Disabled) { - uint64_t thread_id = at::RecordFunction::currentThreadId(); - Event evt( - EventKind::MemoryAlloc, - at::StringView(""), - thread_id, - config_.state == ProfilerState::CUDA); - evt.updateMemoryStats(alloc_size, device); - getEventList(thread_id).record(std::move(evt)); - } +void ProfilerThreadLocalState::reportMemoryUsage( + void* /* unused */, + int64_t alloc_size, + c10::Device device) { + if (config_.profile_memory && config_.state != ProfilerState::Disabled) { + uint64_t thread_id = at::RecordFunction::currentThreadId(); + LegacyEvent evt( + EventKind::MemoryAlloc, + at::StringView(""), + thread_id, + config_.state == ProfilerState::CUDA); + evt.updateMemoryStats(alloc_size, device); + getEventList(thread_id).record(std::move(evt)); } +} - bool memoryProfilingEnabled() const override { - return config_.profile_memory; - } +bool ProfilerThreadLocalState::memoryProfilingEnabled() const { + return config_.profile_memory; +} - private: - std::string getNvtxStr( - const at::StringView& name, - const char* msg, - int64_t sequence_nr, - const std::vector>& shapes) const { - if (sequence_nr >= 0 || shapes.size() > 0) { - std::stringstream s; - if (sequence_nr >= 0) { - s << name.str() << msg << sequence_nr; - } - if (shapes.size() > 0) { - s << ", sizes = ["; - for (size_t idx = 0; idx < shapes.size(); ++idx) { - if (shapes[idx].size() > 0) { - s << "["; - for (size_t dim = 0; dim < shapes[idx].size(); ++dim) { - s << shapes[idx][dim]; - if (dim < shapes[idx].size() - 1) { - s << ", "; - } +std::string ProfilerThreadLocalState::getNvtxStr( + const at::StringView& name, + const char* msg, + int64_t sequence_nr, + const std::vector>& shapes) const { + if (sequence_nr >= 0 || shapes.size() > 0) { + std::stringstream s; +#ifdef __HIP_PLATFORM_HCC__ + s << name.str(); +#endif + if (sequence_nr >= 0) { +#ifdef __HIP_PLATFORM_HCC__ + s << msg << sequence_nr; +#else + s << name.str() << msg << sequence_nr; +#endif + } + if (shapes.size() > 0) { + s << ", sizes = ["; + for (size_t idx = 0; idx < shapes.size(); ++idx) { + if (shapes[idx].size() > 0) { + s << "["; + for (size_t dim = 0; dim < shapes[idx].size(); ++dim) { + s << shapes[idx][dim]; + if (dim < shapes[idx].size() - 1) { + s << ", "; } - s << "]"; - } else { - s << "[]"; - } - if (idx < shapes.size() - 1) { - s << ", "; } + s << "]"; + } else { + s << "[]"; + } + if (idx < shapes.size() - 1) { + s << ", "; } - s << "]"; } - return s.str(); - } else { - return name.str(); + s << "]"; } + return s.str(); + } else { + return name.str(); + } +} + +RangeEventList& ProfilerThreadLocalState::getEventList(int64_t thread_id) { + if (thread_id < 0) { + thread_id = at::RecordFunction::currentThreadId(); + } + RangeEventList* list_ptr = nullptr; + std::lock_guard guard(state_mutex_); + auto it = event_lists_map_.find(thread_id); + if (it != event_lists_map_.end()) { + list_ptr = it->second.get(); + } else { + auto event_list = std::make_shared(); + event_lists_map_[thread_id] = event_list; + list_ptr = event_list.get(); } + return *list_ptr; +} - RangeEventList& getEventList(int64_t thread_id = -1) { - if (thread_id < 0) { - thread_id = at::RecordFunction::currentThreadId(); +std::vector> inputSizes(const at::RecordFunction& fn) { + std::vector> sizes; + sizes.reserve(fn.inputs().size()); + for (const c10::IValue& input : fn.inputs()) { + if (!input.isTensor()) { + sizes.emplace_back(); + continue; } - RangeEventList* list_ptr = nullptr; - std::lock_guard guard(state_mutex_); - auto it = event_lists_map_.find(thread_id); - if (it != event_lists_map_.end()) { - list_ptr = it->second.get(); + const at::Tensor& tensor = input.toTensor(); + if (tensor.defined()) { + sizes.push_back(input.toTensor().sizes().vec()); } else { - auto event_list = std::make_shared(); - event_lists_map_[thread_id] = event_list; - list_ptr = event_list.get(); + sizes.emplace_back(); } - return *list_ptr; } + return sizes; +} - std::mutex state_mutex_; - std::unordered_map> - event_lists_map_; +namespace { + +enum EventIValueIdx { + KIND = 0, + NAME, + THREAD_ID, + HANDLE, + NODE_ID, + CPU_MEM_USAGE, + CPU_NS, + CUDA_RECORDED, + CUDA_MEM_USAGE, + CUDA_DEVICE, + CUDA_US, + SHAPES, + NUM_EVENT_IVALUE_IDX // must be last in list +}; - ProfilerConfig config_ = ProfilerConfig(ProfilerState::Disabled, false, false); - at::CallbackHandle handle_ = 0; - c10::optional>> remoteProfiledEvents_; +enum ProfilerIValueIdx { + STATE = 0, + REPORT_INPUT_SHAPES, + PROFILE_MEMORY, + NUM_PROFILER_CFG_IVALUE_IDX // must be last in list +}; + +const std::unordered_set disable_cuda_profiling = { + "aten::view", + "aten::t", + "aten::transpose", + "aten::stride", + "aten::empty", + "aten::empty_like", + "aten::empty_strided", + "aten::as_strided", + "aten::expand", + "aten::resize_", + "aten::squeeze", + "aten::unsqueeze", + "aten::slice", + "aten::_unsafe_view", + "aten::size" }; ProfilerThreadLocalState* getProfilerTLSState() { - const auto& state = c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE); - return dynamic_cast(state.get()); + return static_cast( + c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE)); } -void pushProfilingCallbacks() { +void pushProfilingCallbacksLegacy() { auto state_ptr = getProfilerTLSState(); TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set"); auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback( - [](const at::RecordFunction& fn) { + [](const at::RecordFunction& fn) -> std::unique_ptr { auto state_ptr = getProfilerTLSState(); if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) { - return; + return nullptr; + } + bool record_cuda = + state_ptr->config().state == ProfilerState::CUDA; + if (record_cuda && disable_cuda_profiling.find(fn.name().str()) != disable_cuda_profiling.end()) { + record_cuda = false; } auto* msg = (fn.seqNr() >= 0) ? ", seq = " : ""; if (state_ptr->config().report_input_shapes) { - std::vector> inputSizes; - inputSizes.reserve(fn.inputs().size()); - for (const c10::IValue& input : fn.inputs()) { - if (!input.isTensor()) { - inputSizes.emplace_back(); - continue; - } - const at::Tensor& tensor = input.toTensor(); - if (tensor.defined()) { - inputSizes.push_back(input.toTensor().sizes().vec()); - } else { - inputSizes.emplace_back(); - } - } - state_ptr->pushRange( - fn.name(), msg, fn.seqNr(), std::move(inputSizes), fn.handle()); + auto sizes = inputSizes(fn); + state_ptr->pushRange(fn, record_cuda, msg, std::move(sizes)); } else { - state_ptr->pushRange(fn.name(), msg, fn.seqNr(), {}, fn.handle()); + state_ptr->pushRange(fn, record_cuda, msg); } + + return nullptr; }, - [](const at::RecordFunction& fn) { + [](const at::RecordFunction& fn, at::ObserverContext*) { auto state_ptr = getProfilerTLSState(); if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) { return; } - state_ptr->popRange(fn.getStartCallbacksThreadId(), fn.handle()); + bool record_cuda = + state_ptr->config().state == ProfilerState::CUDA; + if (record_cuda && disable_cuda_profiling.find(fn.name().str()) != disable_cuda_profiling.end()) { + record_cuda = false; + } + state_ptr->popRange(fn, record_cuda); }) .needsInputs(state_ptr->config().report_input_shapes) .needsIds(true)); @@ -391,11 +461,9 @@ const int kCUDAWarmupStart = 5; } // namespace void registerCUDAMethods(CUDAStubs* stubs) { - cuda_stubs = stubs; + cuda_stubs() = stubs; } -ProfilerConfig::~ProfilerConfig() = default; - at::IValue ProfilerConfig::toIValue() const { c10::impl::GenericList eventIValueList(at::AnyType::get()); eventIValueList.reserve(NUM_PROFILER_CFG_IVALUE_IDX); @@ -436,38 +504,42 @@ bool profilerEnabled() { return state_ptr && state_ptr->config().state != ProfilerState::Disabled; } -void enableProfiler(const ProfilerConfig& new_config) { - TORCH_CHECK(new_config.state != ProfilerState::NVTX || cuda_stubs->enabled(), +void enableProfilerLegacy(const ProfilerConfig& new_config) { + TORCH_CHECK(new_config.state != ProfilerState::NVTX || cuda_stubs()->enabled(), "Can't use NVTX profiler - PyTorch was compiled without CUDA"); + TORCH_CHECK(new_config.state != ProfilerState::KINETO); + auto state_ptr = getProfilerTLSState(); TORCH_CHECK(!state_ptr, "Profiler is already enabled on this thread"); auto state = std::make_shared(new_config); c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state); - pushProfilingCallbacks(); + pushProfilingCallbacksLegacy(); if (new_config.state == ProfilerState::CUDA) { // event recording appears to have some startup overhead, so we need to // to generate some dummy events first before recording synchronization events for (int idx = 0; idx < kCUDAWarmupStart; ++idx) { - cuda_stubs->onEachDevice([state](int /* unused */) { + cuda_stubs()->onEachDevice([state](int /* unused */) { state->mark("__cuda_startup"); - cuda_stubs->synchronize(); + cuda_stubs()->synchronize(); }); } // cuda events must be on the same device, so we need a start event recorded // for each gpu. we then use this event to synchronize time on the GPU // with the CPU clock. - cuda_stubs->onEachDevice([state](int d) { + cuda_stubs()->onEachDevice([state](int d) { state->mark("__cuda_start_event"); }); } state->mark("__start_profile", false); } -thread_event_lists disableProfiler(bool cleanupTLSState, bool consolidate) { +thread_event_lists disableProfilerLegacy(c10::optional profilerDisableOptions) { + auto cleanupTLSState = profilerDisableOptions ? profilerDisableOptions->cleanupTLSState : true; + auto consolidate = profilerDisableOptions ? profilerDisableOptions->consolidate : true; // all the DebugInfoBase objects are scope based and supposed to use DebugInfoGuard std::shared_ptr state; if (cleanupTLSState) { @@ -493,21 +565,21 @@ thread_event_lists disableProfiler(bool cleanupTLSState, bool consolidate) { return state_ptr->consolidate(); } -void addEventList(std::vector&& profiledEvents) { +void addEventList(std::vector&& profiledEvents) { auto state_ptr = getProfilerTLSState(); TORCH_CHECK(state_ptr, "Profiler must be enabled."); state_ptr->setOrAddRemoteProfiledEvents(std::move(profiledEvents)); } -void Event::record(bool record_cuda) { +void LegacyEvent::record(bool record_cuda) { if (record_cuda) { - cuda_stubs->record(&device_, &cuda_event, &cpu_ns_); + cuda_stubs()->record(&device_, &cuda_event, &cpu_ns_); return; } cpu_ns_ = getTime(); } -/* static */ Event Event::fromIValue(const at::IValue& eventIValue) { +/* static */ LegacyEvent LegacyEvent::fromIValue(const at::IValue& eventIValue) { TORCH_INTERNAL_ASSERT( eventIValue.isList(), "Expected IValue to contain type c10::impl::GenericList"); @@ -516,16 +588,40 @@ void Event::record(bool record_cuda) { ivalues.size() >= NUM_EVENT_IVALUE_IDX, "Expected at least ", NUM_EVENT_IVALUE_IDX, - " elements to reconstruct Event."); + " elements to reconstruct LegacyEvent."); - Event evt( + // Reconstruct input shapes from ivalues. + auto shapeListIValue = ivalues.get(EventIValueIdx::SHAPES); + TORCH_INTERNAL_ASSERT( + shapeListIValue.isList(), + "Expected profiler shapes IValue to contain type c10::impl::GenericList." + ); + + auto shapeList = shapeListIValue.toList(); + std::vector> shapes; + shapes.reserve(shapeList.size()); + for (size_t i = 0 ; i < shapeList.size(); ++i) { + std::vector s; + auto shapeIValue = shapeList.get(i); + TORCH_INTERNAL_ASSERT( + shapeIValue.isList(), + "Expected each profiler shape element to contain shapes of type c10::impl::GenericList.") + auto curShapesList = shapeIValue.toList(); + s.reserve(curShapesList.size()); + for (size_t j = 0; j < curShapesList.size(); ++j) { + s.emplace_back(curShapesList.get(j).toInt()); + } + shapes.emplace_back(s); + } + + LegacyEvent evt( static_cast( ivalues.get(EventIValueIdx::KIND).toInt()), // EventKind at::StringView(ivalues.get(EventIValueIdx::NAME).toStringRef()), // name ivalues.get(EventIValueIdx::THREAD_ID).toInt(), // thread_id static_cast( ivalues.get(EventIValueIdx::HANDLE).toDouble()), // handle - {}, // TODO: record shapes + std::move(shapes), // input shapes ivalues.get(EventIValueIdx::NODE_ID).toInt(), // node id true, // is remote ivalues.get(EventIValueIdx::CPU_MEM_USAGE).toInt(), // cpu_mem_usage @@ -538,27 +634,40 @@ void Event::record(bool record_cuda) { return evt; } -at::IValue Event::toIValue() const { +at::IValue LegacyEvent::toIValue() const { c10::impl::GenericList eventIValueList(at::AnyType::get()); eventIValueList.reserve(NUM_EVENT_IVALUE_IDX); eventIValueList.emplace_back(static_cast(kind_)); eventIValueList.emplace_back(std::string(name_.str())); - eventIValueList.emplace_back(thread_id_); + eventIValueList.emplace_back(static_cast(thread_id_)); eventIValueList.emplace_back(static_cast(handle_)); eventIValueList.emplace_back(node_id_); eventIValueList.emplace_back(cpu_memory_usage_); eventIValueList.emplace_back(cpu_ns_); // CUDA event information - bool cuda_profiling_enabled = has_cuda(); + bool cuda_profiling_enabled = hasCuda(); eventIValueList.emplace_back(cuda_profiling_enabled); eventIValueList.emplace_back(static_cast(cuda_memory_usage_)); eventIValueList.emplace_back(device_); eventIValueList.emplace_back(cuda_us_); + // Shapes + c10::impl::GenericList shapesList = + c10::impl::GenericList(at::ListType::create(at::IntType::get())); + shapesList.reserve(shapes_.size()); + for (const auto& shape : shapes_) { + c10::impl::GenericList s = c10::impl::GenericList(at::IntType::get()); + s.reserve(shape.size()); + for (const auto& k : shape) { + s.emplace_back(k); + } + shapesList.emplace_back(s); + } + eventIValueList.emplace_back(shapesList); return at::IValue(eventIValueList); } -double Event::cuda_elapsed_us(const Event& e) const { - TORCH_CHECK(e.has_cuda() && has_cuda(), "Events were not recorded for CUDA"); +double LegacyEvent::cudaElapsedUs(const LegacyEvent& e) const { + TORCH_CHECK(e.hasCuda() && hasCuda(), "Events were not recorded for CUDA"); TORCH_CHECK( e.device() == device(), c10::str( @@ -568,13 +677,12 @@ double Event::cuda_elapsed_us(const Event& e) const { TORCH_INTERNAL_ASSERT(cuda_us_ >= 0 && e.cuda_us_ >= 0); return static_cast(e.cuda_us_ - cuda_us_); } - return cuda_stubs->elapsed(&cuda_event, &e.cuda_event); + return cuda_stubs()->elapsed(&cuda_event, &e.cuda_event); } CUDAStubs::~CUDAStubs() = default; - -static jit::CodeTemplate event_template(R"( +static const jit::CodeTemplate event_template(R"( { "name": "${name}", "ph": "X", @@ -585,10 +693,10 @@ static jit::CodeTemplate event_template(R"( "args": {} })"); -void writeProfilerEventsToStream(std::ostream& out, const std::vector& events) { +void writeProfilerEventsToStream(std::ostream& out, const std::vector& events) { TORCH_CHECK(out, "Could not open file"); - Event* profiler_start = nullptr; - for (Event* e : events) { + LegacyEvent* profiler_start = nullptr; + for (LegacyEvent* e : events) { if (0 == strcmp(e->name(), "__start_profile")) { profiler_start = e; break; @@ -602,34 +710,33 @@ void writeProfilerEventsToStream(std::ostream& out, const std::vector& e return std::hash()(p.first) ^ std::hash()(p.second); } }; - std::unordered_map, Event*, PairHash> events_map; + std::unordered_map, LegacyEvent*, PairHash> events_map; out << "[\n"; bool first = true; - for (Event* evt : events) { - if (evt->kind() == "push") { - events_map[std::make_pair(evt->handle(), evt->node_id())] = evt; - } else if (evt->kind() == "pop") { + for (LegacyEvent* evt : events) { + if (evt->kindStr() == "push") { + events_map[std::make_pair(evt->handle(), evt->nodeId())] = evt; + } else if (evt->kindStr() == "pop") { if (!first) { out << ",\n"; } first = false; - auto it = events_map.find(std::make_pair(evt->handle(), evt->node_id())); + auto it = events_map.find(std::make_pair(evt->handle(), evt->nodeId())); TORCH_CHECK(it != events_map.end(), "Unmatched pop event"); - Event* evt_start = it->second; + LegacyEvent* evt_start = it->second; events_map.erase(it); jit::TemplateEnv env; env.s("name", evt_start->name()); - env.d("ts", profiler_start->cpu_elapsed_us(*evt_start)); - env.d("dur", evt_start->cpu_elapsed_us(*evt)); - env.d("tid", evt_start->thread_id()); + env.d("ts", profiler_start->cpuElapsedUs(*evt_start)); + env.d("dur", evt_start->cpuElapsedUs(*evt)); + env.d("tid", evt_start->threadId()); out << event_template.format(env); } } out << "]\n"; } - RecordProfile::RecordProfile(std::ostream& out) : out_(out) { init(); @@ -641,27 +748,27 @@ RecordProfile::RecordProfile(const std::string& filename) } void RecordProfile::init() { - enableProfiler(ProfilerConfig( - ProfilerState::CPU, - /* report_input_shapes */ false, - /* profile_memory */ false)); + enableProfilerLegacy(ProfilerConfig(ProfilerState::CPU)); } RecordProfile::~RecordProfile() { - thread_event_lists event_lists = disableProfiler(); - std::vector events; - for (auto& l : event_lists) { - for (auto& e : l) { - events.push_back(&e); + try { + thread_event_lists event_lists = disableProfilerLegacy(); + std::vector events; + for (auto& l : event_lists) { + for (auto& e : l) { + events.push_back(&e); + } } - } - processEvents(events); - if (file_){ - file_->close(); + processEvents(events); + } catch (const std::exception& e) { + LOG(ERROR) << e.what() << std::endl; + } catch (...) { + LOG(ERROR) << "Unknown error" << std::endl; } } -void RecordProfile::processEvents(const std::vector& events) { +void RecordProfile::processEvents(const std::vector& events) { writeProfilerEventsToStream(out_, events); } diff --git a/torch/csrc/autograd/profiler_legacy.h b/torch/csrc/autograd/profiler_legacy.h new file mode 100644 index 0000000000000..23169cd33450f --- /dev/null +++ b/torch/csrc/autograd/profiler_legacy.h @@ -0,0 +1,567 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifndef _WIN32 +#include +#endif +#if defined(C10_IOS) && defined(C10_MOBILE) +#include // for gettimeofday() +#endif + +#include + +#include + +struct CUevent_st; +typedef std::shared_ptr CUDAEventStub; + +namespace torch { namespace autograd { + +struct Node; + +namespace profiler { + +struct TORCH_API CUDAStubs { + virtual void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) const { + fail(); + } + virtual float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) const { + fail(); + return 0.f; + } + virtual void nvtxMarkA(const char* name) const { + fail(); + } + virtual void nvtxRangePushA(const char* name) const { + fail(); + } + virtual void nvtxRangePop() const { + fail(); + } + virtual bool enabled() const { + return false; + } + virtual void onEachDevice(std::function op) const { + fail(); + } + virtual void synchronize() const { + fail(); + } + virtual ~CUDAStubs(); + +private: + void fail() const { + AT_ERROR("CUDA used in profiler but not enabled."); + } +}; + +TORCH_API void registerCUDAMethods(CUDAStubs* stubs); + +constexpr inline size_t ceilToMultiple(size_t a, size_t b) { + return ((a + b - 1) / b) * b; +} + +inline int64_t getTime() { +#if defined(C10_IOS) && defined(C10_MOBILE) +// clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS can't rely on +// CLOCK_REALTIME, as it is defined no matter if clock_gettime is implemented or not + struct timeval now; + gettimeofday(&now, NULL); + return static_cast(now.tv_sec) * 1000000000 + static_cast(now.tv_usec) * 1000; +#elif defined(_WIN32) || defined(__MACH__) + using namespace std::chrono; + using clock = std::conditional::type; + return duration_cast(clock::now().time_since_epoch()).count(); +#else + // clock_gettime is *much* faster than std::chrono implementation on Linux + struct timespec t{}; + clock_gettime(CLOCK_MONOTONIC, &t); + return static_cast(t.tv_sec) * 1000000000 + static_cast(t.tv_nsec); +#endif +} + +enum class C10_API_ENUM EventKind : uint16_t { + Mark, + PushRange, + PopRange, + MemoryAlloc, +}; + +// To be deprecated, once we switch to Kineto profiling +struct TORCH_API LegacyEvent { + LegacyEvent( + EventKind kind, + at::StringView name, + uint16_t thread_id, + bool record_cuda, + at::RecordFunctionHandle handle = 0, + std::vector>&& shapes = {}, + int node_id = -1) + : name_(std::move(name)), + kind_(kind), + thread_id_(thread_id), + handle_(handle), + shapes_(shapes), + node_id_(node_id) { + record(record_cuda); + } + + // Constructor to be used in conjunction with LegacyEvent::fromIValue. + LegacyEvent( + EventKind kind, + at::StringView name, + uint16_t thread_id, + at::RecordFunctionHandle handle, + std::vector>&& shapes, + int node_id, + bool is_remote, + int64_t cpu_memory_usage, + int64_t cpu_ns, + bool cuda_recorded, + int64_t cuda_memory_usage = 0, + int device = -1, + double cuda_us = -1) + : cpu_ns_(cpu_ns), + name_(std::move(name)), + kind_(kind), + thread_id_(thread_id), + handle_(handle), + shapes_(shapes), + cpu_memory_usage_(cpu_memory_usage), + cuda_memory_usage_(cuda_memory_usage), + device_(device), + node_id_(node_id), + is_remote_(is_remote), + cuda_us_(cuda_us) { + // Sanity check values that were deserialized + TORCH_INTERNAL_ASSERT(cpu_ns_ > 0); + if (cuda_recorded) { + TORCH_INTERNAL_ASSERT(device_ >= 0); + TORCH_INTERNAL_ASSERT(cuda_us_ >= 0); + } + } + + // Returns IValues corresponding to event structure, to be used for + // serialization. + at::IValue toIValue() const; + + // Reconstructs an event from IValues given by toIValue. + static LegacyEvent fromIValue(const at::IValue& eventIValue); + + void record(bool record_cuda); + + std::string kindStr() const { + switch (kind_) { + case EventKind::Mark: return "mark"; + case EventKind::PushRange: return "push"; + case EventKind::PopRange: return "pop"; + case EventKind::MemoryAlloc: return "memory_alloc"; + } + throw std::runtime_error("unknown event kind"); + } + + const char* name() const { + return name_.str(); + } + + uint64_t threadId() const { + return thread_id_; + } + + std::vector> shapes() const { + return shapes_; + } + + double cpuElapsedUs(const LegacyEvent& e) const { + return (e.cpu_ns_ - cpu_ns_)/(1000.0); + } + + void setCpuUs(int64_t cpu_us) { + cpu_ns_ = cpu_us * 1000.0; + } + + double cpuUs() const { + return cpu_ns_ / (1000.0); + } + + double cudaElapsedUs(const LegacyEvent& e) const; + + bool hasCuda() const { + return cuda_event != nullptr || (isRemote() && device_ != -1); + } + + int device() const { + return device_; + } + + void updateMemoryStats(int64_t alloc_size, c10::Device device) { + if (device.type() == c10::DeviceType::CUDA || + device.type() == c10::DeviceType::HIP) { + cuda_memory_usage_ = alloc_size; + } else if (device.type() == c10::DeviceType::CPU || + device.type() == c10::DeviceType::MKLDNN || + device.type() == c10::DeviceType::IDEEP) { + cpu_memory_usage_ = alloc_size; + } else { + LOG(WARNING) << "Unsupported memory profiling device: " << device; + } + } + + int64_t cpuMemoryUsage() const { + return cpu_memory_usage_; + } + + int64_t cudaMemoryUsage() const { + return cuda_memory_usage_; + } + + at::RecordFunctionHandle handle() const { + return handle_; + } + + // Node ID corresponding to this event. + int nodeId( ) const { + return node_id_; + } + + // Set Node ID on this event. + void setNodeId(int node_id) { + node_id_ = node_id; + } + + void setName(at::StringView newName_) { + name_ = std::move(newName_); + } + + bool isRemote() const { + return is_remote_; + } + + void setCudaUs(int64_t cuda_us) { + cuda_us_ = cuda_us; + } + + void setSequenceNr(int64_t sequence_nr) { + sequence_nr_ = sequence_nr; + } + + int64_t sequenceNr() const { + return sequence_nr_; + } + + void setCorrelationId(uint64_t correlation_id) { + correlation_id_ = correlation_id; + } + + uint64_t correlationId() const { + return correlation_id_; + } + + const std::vector& stack() const { + return stack_; + } + + void setStack(const std::vector& stack) { + stack_ = stack; + } + + uint64_t fwdThreadId() const { + return fwd_thread_id_; + } + + void setFwdThreadId(uint64_t fwd_thread_id) { + fwd_thread_id_ = fwd_thread_id; + } + + uint8_t scope() const { + return scope_; + } + + void setScope(uint8_t scope) { + scope_ = scope; + } + + const std::unordered_map& extraArgs() const { + return extra_args_; + } + + void setExtraArgs(std::unordered_map&& save_args) { + extra_args_ = std::move(save_args); + } + + uint64_t flops() { + return flops_; + } + + void setFlops(uint64_t flops) { + flops_ = flops; + } + + private: + // signed to allow for negative intervals, initialized for safety. + int64_t cpu_ns_ = 0; + at::StringView name_; + EventKind kind_; + uint64_t thread_id_; + uint64_t fwd_thread_id_; + at::RecordFunctionHandle handle_ {0}; + std::vector> shapes_; + int64_t cpu_memory_usage_ = 0; + int64_t cuda_memory_usage_ = 0; + int device_ = -1; + CUDAEventStub cuda_event = nullptr; + int node_id_ = 0; + bool is_remote_ = false; + int64_t cuda_us_ = -1; + int64_t sequence_nr_ = -1; + + std::vector stack_; + uint8_t scope_; + uint64_t correlation_id_; + // Extra arguments for computing op flops + std::unordered_map extra_args_; + uint64_t flops_; +}; + +// a linked-list of fixed sized vectors, to avoid +// a std::vector resize from taking a large amount of time inside +// a profiling event +struct RangeEventList { + RangeEventList() { + events_.reserve(kReservedCapacity); + } + + template + void record(Args&&... args) { + std::lock_guard guard(mutex_); + events_.emplace_back(std::forward(args)...); + } + + std::vector consolidate() { + std::lock_guard lock(mutex_); + std::vector result; + result.insert( + result.begin(), + std::make_move_iterator(events_.begin()), + std::make_move_iterator(events_.end())); + events_.erase(events_.begin(), events_.end()); + return result; + } + + size_t size() { + std::lock_guard lock(mutex_); + return events_.size(); + } + + private: + // This mutex is used to serialize access when different threads are writing + // to the same instance of RangeEventList. + std::mutex mutex_; + std::vector events_; + + static const size_t kReservedCapacity = 1024; +}; + +enum class C10_API_ENUM ProfilerState { + Disabled = 0, + CPU, // CPU-only profiling + CUDA, // CPU + CUDA events + NVTX, // only emit NVTX markers + KINETO, // use libkineto + NUM_PROFILER_STATES, // must be the last one +}; + +struct TORCH_API ProfilerConfig { + ProfilerConfig( + ProfilerState state, + bool report_input_shapes = false, + bool profile_memory = false, + bool with_stack = false, + bool with_flops = false) + : state(state), + report_input_shapes(report_input_shapes), + profile_memory(profile_memory), + with_stack(with_stack), + with_flops(with_flops) {} + ~ProfilerConfig() = default; + ProfilerState state; + bool report_input_shapes; + bool profile_memory; + bool with_stack; + bool with_flops; + + // Returns IValues corresponding to ProfilerConfig struct, to be used for + // serialization. + at::IValue toIValue() const; + + // Reconstructs a ProfilerConfig from IValues given by toIValue. + static ProfilerConfig fromIValue(const at::IValue& profilerConfigIValue); +}; + +// A struct to control settings of disableProfiler options. +struct TORCH_API ProfilerDisableOptions { + ProfilerDisableOptions() = default; + ProfilerDisableOptions(bool shouldCleanupTLSState, bool shouldConsolidate) + : cleanupTLSState(shouldCleanupTLSState), + consolidate(shouldConsolidate) {} + // Whether we should clean up profiler states that are thread local, such as + // ThreadLocalDebugInfo and thread local RecordFunction callbacks. + bool cleanupTLSState = true; + // Whether we should consolidate all currently recorded profiled events. If + // false, will not consolidate and other threads can continue to write to the + // event lists. + bool consolidate = true; +}; + +// NOTE: profiler mode is thread local, with automatic propagation +// across thread boundary (e.g. at::launch tasks) +TORCH_API void enableProfilerLegacy(const ProfilerConfig&); +using thread_event_lists = std::vector>; +TORCH_API thread_event_lists disableProfilerLegacy(c10::optional profilerDisableOptions = c10::nullopt); + +// adds profiledEvents to the current thread local recorded events. Each event +// will be marked with node ID given by fromNodeId. +TORCH_API void addEventList(std::vector&& profiledEvents); +// Returns if the profiler is currently enabled in the current thread. +TORCH_API bool profilerEnabled(); +// Retrieve the thread_local ProfilerConfig. +TORCH_API ProfilerConfig getProfilerConfig(); +// Writes profiled events to a stream. +TORCH_API void writeProfilerEventsToStream(std::ostream& out, const std::vector& events); + +// Usage: +// { +// RecordProfile guard("filename.trace"); +// // code you want to profile +// } +// Then open filename.trace in chrome://tracing +struct TORCH_API RecordProfile { + RecordProfile(std::ostream& out); + RecordProfile(const std::string& filename); + + ~RecordProfile(); +private: + void init(); + std::unique_ptr file_; + std::ostream& out_; + void processEvents(const std::vector& events); +}; + +// A guard that enables the profiler, taking in an optional callback to process +// the results +// Usage: +// { +// TLSProfilerGuard g([](thread_event_lists profilerResults) { +// // process profilerResults +// }); +// Code to profile +// } +struct TORCH_API TLSProfilerGuard { + explicit TLSProfilerGuard( + const ProfilerConfig& cfg, + c10::optional> + resultCallback = c10::nullopt, + c10::optional profilerDisableOptions = + c10::nullopt) + : cb_(std::move(resultCallback)), + profilerDisableOptions_(std::move(profilerDisableOptions)) { + enableProfilerLegacy(cfg); + } + ~TLSProfilerGuard() { + thread_event_lists event_lists = disableProfilerLegacy(profilerDisableOptions_); + if (cb_) { + try { + (*cb_)(event_lists); + } catch (const std::exception& e) { + LOG(ERROR) << "Got error processing profiler events: " << e.what(); + } + } + } + + private: + c10::optional> cb_; + const c10::optional profilerDisableOptions_; +}; + +struct TORCH_API FileLineFunc { + std::string filename; + size_t line; + std::string funcname; +}; +TORCH_API std::vector prepareCallstack(const std::vector& cs); +TORCH_API std::vector callstackStr(const std::vector& cs); +TORCH_API std::vector> inputSizes(const at::RecordFunction& fn); + +struct TORCH_API ProfilerThreadLocalState : public c10::MemoryReportingInfoBase { + explicit ProfilerThreadLocalState(const ProfilerConfig& config) + : config_(config), remoteProfiledEvents_{c10::nullopt} {} + ~ProfilerThreadLocalState() override = default; + + const ProfilerConfig& config() const; + + thread_event_lists consolidate(); + + void mark(std::string name, bool include_cuda = true); + + void setOrAddRemoteProfiledEvents( + std::vector&& remoteProfiledEvents); + + void pushRange( + const at::RecordFunction& fn, + const bool record_cuda, + const char* msg = "", + std::vector>&& shapes = {}); + + void popRange(const at::RecordFunction& fn, const bool record_cuda); + + void setCallbackHandle(at::CallbackHandle handle) { + handle_ = handle; + } + + at::CallbackHandle callbackHandle() const { + return handle_; + } + + bool hasCallbackHandle() { + return handle_ > 0; + } + + void reportMemoryUsage( + void* /* unused */, + int64_t alloc_size, + c10::Device device) override; + + bool memoryProfilingEnabled() const override; + + protected: + std::string getNvtxStr( + const at::StringView& name, + const char* msg, + int64_t sequence_nr, + const std::vector>& shapes) const; + + RangeEventList& getEventList(int64_t thread_id = -1); + + std::mutex state_mutex_; + std::unordered_map> + event_lists_map_; + + ProfilerConfig config_ = ProfilerConfig(ProfilerState::Disabled); + at::CallbackHandle handle_ = 0; + c10::optional>> remoteProfiledEvents_; +}; + + +} // namespace profiler +}} // namespace torch::autograd diff --git a/torch/csrc/autograd/profiler_utils.cpp b/torch/csrc/autograd/profiler_utils.cpp new file mode 100644 index 0000000000000..d62d8c90c9621 --- /dev/null +++ b/torch/csrc/autograd/profiler_utils.cpp @@ -0,0 +1,215 @@ +#include + +namespace torch { namespace autograd { namespace profiler { + +static constexpr auto kConv2dStride = 3; +static constexpr auto kConv2dPadding = 4; +static constexpr auto kConv2dDilation = 5; +static constexpr auto kConv2dGroups = 6; + +// List of supported operators +static constexpr auto kConv2dOp = "aten::conv2d"; +static constexpr auto kGemmOp = "aten::mm"; +static constexpr auto kMulOp = "aten::mul.Tensor"; +static constexpr auto kAddOp = "aten::add.Tensor"; + +static constexpr auto kInputSize = "input_size"; +static constexpr auto kWeightSize = "weight_size"; +static constexpr auto kStride = "stride"; +static constexpr auto kPadding = "padding"; +static constexpr auto kDilation = "dilation"; +static constexpr auto kGroups = "groups"; +static constexpr auto kMatSize = "mat_size"; +static constexpr auto kMat1Size = "mat1_size"; +static constexpr auto kMat2Size = "mat2_size"; + +static bool validateInput(const std::string &op_name, size_t min_size, + const std::vector& inputs, + const std::vector& should_be_tensor) { + std::stringstream ss; + if (inputs.size() < min_size) { + ss << "Failed to save extra arguments for flops compuation of op " + << op_name + << ", min size: " << min_size + << ", actual size: " << inputs.size(); + TORCH_WARN(ss.str()); + return false; + } + for (auto index : should_be_tensor) { + if (!inputs[index].isTensor()) { + ss << "Failed to save extra arguments for flops compuation of op " + << op_name + << ", input[" << index + << "] must be a tensor."; + TORCH_WARN(ss.str()); + return false; + } + } + return true; +} + +std::unordered_map saveExtraArgs(const at::RecordFunction& fn) { + // for specific types of fn, return the saved extra args for computing flops + std::unordered_map map; + std::vector inputs = fn.inputs(); + std::string fname(fn.name().str()); + + if (inputs.empty()) { + // Input shape is unavailable, return empty map + return map; + } + + if (fname == kConv2dOp) { + std::vector tensors{0, 1}; + bool check = validateInput(fname, kConv2dGroups + 1, inputs, tensors); + if (!check) { + return map; + } + + at::Tensor input = inputs[0].toTensor(); + at::Tensor weight = inputs[1].toTensor(); + if (weight.sizes().size() != 4) { + TORCH_WARN("Failed to compute flops for op aten::conv2d because it requires a 4D kernel tensor."); + return map; + } + map[kInputSize] = at::IValue(input.sizes()); + map[kWeightSize] = at::IValue(weight.sizes()); + map[kStride] = inputs[kConv2dStride]; + map[kPadding] = inputs[kConv2dPadding]; + map[kDilation] = inputs[kConv2dDilation]; + map[kGroups] = inputs[kConv2dGroups]; + } else if (fname == kGemmOp) { + std::vector tensors{0, 1}; + bool check = validateInput(fname, 2, inputs, tensors); + if (!check) { + return map; + } + + at::Tensor left = inputs[0].toTensor(); + at::Tensor right = inputs[1].toTensor(); + map[kMat1Size] = at::IValue(left.sizes()); + map[kMat2Size] = at::IValue(right.sizes()); + } else if (fname == kMulOp) { + std::vector tensors{0}; + bool check = validateInput(fname, 1, inputs, tensors); + if (!check) { + return map; + } + + at::Tensor mat = inputs[0].toTensor(); + map[kMatSize] = at::IValue(mat.sizes()); + } else if (fname == kAddOp) { + std::vector tensors{0}; + bool check = validateInput(fname, 1, inputs, tensors); + if (!check) { + return map; + } + + at::Tensor mat = inputs[0].toTensor(); + map[kMatSize] = at::IValue(mat.sizes()); + } + + return map; +} + +uint64_t computeFlops(const std::string &op_name, const std::unordered_map &extra_args) { + if (op_name == kConv2dOp) { + if (extra_args.find(kInputSize) == extra_args.end() + || extra_args.find(kWeightSize) == extra_args.end()) { + TORCH_WARN("Calculating flops for aten::conv2d requires input_size and weight_size in saved arguments."); + return 0; + } + auto input_sizes_ref = extra_args.at(kInputSize); + auto kernel_sizes_ref = extra_args.at(kWeightSize); + if (!input_sizes_ref.isIntList() || !kernel_sizes_ref.isIntList()) { + TORCH_WARN("Failed to compute flops for op aten::conv2d because it requires input and weight tensor sizes."); + return 0; + } + + const std::vector input_sizes = input_sizes_ref.toIntVector(); + const std::vector kernel_sizes = kernel_sizes_ref.toIntVector(); + if (input_sizes.size() != 4 || kernel_sizes.size() != 4) { + TORCH_WARN("Failed to compute flops for op aten::conv2d because both input and weight must be size 4."); + return 0; + } + // format of the input is defined in torch.nn.quantized.functional.conv2d() + uint64_t minibatch = 0, in_channels = 0, input_h = 0, input_w = 0; + uint64_t out_channels = 0, kernel_h = 0, kernel_w = 0; + const uint64_t conv2d_multiply_factor = 2; + std::tie(minibatch, in_channels, input_h, input_w) = std::make_tuple(input_sizes[0], input_sizes[1], + input_sizes[2], input_sizes[3]); + std::tie(out_channels, std::ignore, kernel_h, kernel_w) = std::make_tuple(kernel_sizes[0], kernel_sizes[1], + kernel_sizes[2], kernel_sizes[3]); + + // grouping is NOT properly handled yet + return conv2d_multiply_factor * minibatch * input_h * input_w * kernel_h * kernel_w * in_channels * out_channels; + } else if (op_name == kGemmOp) { + if (extra_args.find(kMat1Size) == extra_args.end() + || extra_args.find(kMat2Size) == extra_args.end()) { + TORCH_WARN("Calculating flops for aten::mm requires mat1_size and mat2_size in saved arguments."); + return 0; + } + auto mat1_sizes_ref = extra_args.at(kMat1Size); + auto mat2_sizes_ref = extra_args.at(kMat2Size); + if (!mat1_sizes_ref.isIntList() || !mat2_sizes_ref.isIntList()) { + TORCH_WARN("Failed to compute flops for op aten::mm because it requires mat1_size and mat2_size to be IntList."); + return 0; + } + + std::vector mat1_size = mat1_sizes_ref.toIntVector(); + std::vector mat2_size = mat2_sizes_ref.toIntVector(); + if (mat1_size.size() == 0) { + return 0; + } else { + int64_t overlap_dim = mat1_size.back(); + uint64_t flops = 1; + for(int64_t dim : mat1_size) { + flops *= dim; + } + flops /= overlap_dim; + for(int64_t dim : mat2_size) { + flops *= dim; + } + return flops; + } + } else if (op_name == kMulOp) { + if (extra_args.find(kMatSize) == extra_args.end()) { + TORCH_WARN("Calculating flops for aten::mul.Tensor requires mat_size in saved arguments."); + return 0; + } + auto mat_sizes = extra_args.at(kMatSize); + if (!mat_sizes.isIntList()) { + TORCH_WARN("Failed to compute flops for op aten::mul because it requires mat_size to be IntList."); + return 0; + } + + std::vector mat_size = mat_sizes.toIntVector(); + uint64_t flops = 1; + for(int64_t dim : mat_size) { + flops *= dim; + } + return flops; + } else if (op_name == kAddOp) { + if (extra_args.find(kMatSize) == extra_args.end()) { + TORCH_WARN("Calculating flops for aten::add.Tensor requires mat_size in saved arguments."); + return 0; + } + auto mat_sizes = extra_args.at(kMatSize); + if (!mat_sizes.isIntList()) { + TORCH_WARN("Failed to compute flops for op aten::add because it requires mat_size to be IntList."); + return 0; + } + + std::vector mat_size = mat_sizes.toIntVector(); + uint64_t flops = 1; + for(int64_t dim : mat_size) { + flops *= dim; + } + return flops; + } + return 0; +} + +} // namespace profiler +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/profiler_utils.h b/torch/csrc/autograd/profiler_utils.h new file mode 100644 index 0000000000000..959821983e8be --- /dev/null +++ b/torch/csrc/autograd/profiler_utils.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch { namespace autograd { +namespace profiler { + +std::unordered_map TORCH_API saveExtraArgs(const at::RecordFunction& fn); + +uint64_t TORCH_API computeFlops(const std::string &op_name, + const std::unordered_map &extra_args); + +}}} diff --git a/torch/csrc/autograd/python_cpp_function.cpp b/torch/csrc/autograd/python_cpp_function.cpp index cbfe6d6e3b0bf..8902a5a2eb90d 100644 --- a/torch/csrc/autograd/python_cpp_function.cpp +++ b/torch/csrc/autograd/python_cpp_function.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -114,7 +115,7 @@ PyObject* THPCppFunction_next_functions(THPCppFunction* self, PyObject* hook) PyObject *py_fn = functionToPyObject(c_tuple.function); if (!py_fn) return nullptr; PyTuple_SET_ITEM(tuple.get(), 0, py_fn); - PyObject *py_idx = PyLong_FromLong(c_tuple.input_nr); + PyObject *py_idx = THPUtils_packUInt32(c_tuple.input_nr); if (!py_idx) return nullptr; PyTuple_SET_ITEM(tuple.get(), 1, py_idx); PyTuple_SET_ITEM(py_functions.get(), i, tuple.release()); diff --git a/torch/csrc/autograd/python_cpp_function.h b/torch/csrc/autograd/python_cpp_function.h index 1b51b69adcab7..fb1ba9c1a278c 100644 --- a/torch/csrc/autograd/python_cpp_function.h +++ b/torch/csrc/autograd/python_cpp_function.h @@ -31,9 +31,9 @@ PyObject* CppFunction_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds) } #define THP_FUNCTION_DEFAULT_METHODS \ - {(char*)"_register_hook_dict", (PyCFunction)THPCppFunction_register_hook_dict, METH_O, nullptr}, \ - {(char*)"register_hook", (PyCFunction)THPCppFunction_register_hook, METH_O, nullptr}, \ - {(char*)"name", (PyCFunction)THPCppFunction_name, METH_NOARGS, nullptr} + {(char*)"_register_hook_dict", THPCppFunction_register_hook_dict, METH_O, nullptr}, \ + {(char*)"register_hook", THPCppFunction_register_hook, METH_O, nullptr}, \ + {(char*)"name", THPCppFunction_name, METH_NOARGS, nullptr} #define THP_FUNCTION_DEFAULT_PROPERTIES \ {(char*)"next_functions", (getter)THPCppFunction_next_functions, nullptr, nullptr, nullptr}, \ diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp index f4c88225efc80..a9c7d709466ed 100644 --- a/torch/csrc/autograd/python_engine.cpp +++ b/torch/csrc/autograd/python_engine.cpp @@ -1,13 +1,13 @@ #include #include -#include #include #include #include #include #include #include +#include #include #include #include @@ -86,13 +86,14 @@ variable_list PythonEngine::execute( const variable_list& inputs, bool keep_graph, bool create_graph, + bool accumulate_grad, const edge_list& outputs) { TORCH_CHECK(!PyGILState_Check(), "The autograd engine was called while holding the GIL. If you are using the C++ " "API, the autograd engine is an expensive operation that does not require the " "GIL to be held so you should release it with 'pybind11::gil_scoped_release no_gil;'" ". If you are not using the C++ API, please report a bug to the pytorch team.") try { - return Engine::execute(roots, inputs, keep_graph, create_graph, outputs); + return Engine::execute(roots, inputs, keep_graph, create_graph, accumulate_grad, outputs); } catch (python_error& e) { e.restore(); throw; @@ -101,9 +102,10 @@ variable_list PythonEngine::execute( std::shared_ptr PythonEngine::execute_with_graph_task( const std::shared_ptr& graph_task, - std::shared_ptr graph_root) { + std::shared_ptr graph_root, + InputBuffer&& input_buffer) { try { - return Engine::execute_with_graph_task(graph_task, graph_root); + return Engine::execute_with_graph_task(graph_task, graph_root, std::move(input_buffer)); } catch (python_error& e) { pybind11::gil_scoped_acquire gil; if (!PyErr_Occurred()) { @@ -118,7 +120,7 @@ std::shared_ptr PythonEngine::execute_with_graph_task( PyObject *THPEngineClass = nullptr; // Implementation of torch._C._EngineBase.run_backward -PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs) +PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwargs) { HANDLE_TH_ERRORS PyObject *tensors = nullptr; @@ -127,14 +129,14 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar unsigned char create_graph = 0; PyObject *inputs = nullptr; unsigned char allow_unreachable = 0; - const char *accepted_kwargs[] = { + unsigned char accumulate_grad = 0; // Indicate whether to accumulate grad into leaf Tensors or capture + const char *accepted_kwargs[] = { // NOLINT "tensors", "grad_tensors", "keep_graph", "create_graph", "inputs", - "allow_unreachable", nullptr + "allow_unreachable", "accumulate_grad", nullptr }; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Ob", (char**)accepted_kwargs, - &tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable)) + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Obb", (char**)accepted_kwargs, + &tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable, &accumulate_grad)) return nullptr; - THPUtils_assert(PyTuple_Check(tensors), "tensors argument is expected to " "be a tuple, but got %s", THPUtils_typename(tensors)); THPUtils_assert(PyTuple_Check(grad_tensors), "grad_tensors argument is " @@ -146,7 +148,7 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar "gradients", num_tensors, num_gradients); // The user either called autograd.backward(...) or autograd.grad(...) to get here - bool backward_api_called = inputs == nullptr; + bool backward_api_called = accumulate_grad; TORCH_CHECK(!backward_api_called || at::impl::VmapMode::current_vmap_level() == 0, "backward() called inside torch.vmap. This is not supported, " "please call backward() outside torch.vmap or instead use " @@ -167,10 +169,6 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar "vmapped tensors (output ", i, " is being vmapped over). Please " "call autograd.grad() outside torch.vmap or file a bug report " "with your use case.") - if(variable.is_complex()) { - TORCH_WARN_ONCE("Complex backward is not fully supported yet and could lead to wrong ", - "gradients for functions we have not fixed yet"); - } auto gradient_edge = torch::autograd::impl::gradient_edge(variable); THPUtils_assert(gradient_edge.function, "element %d of tensors does not require grad and does not have a grad_fn", i); @@ -196,7 +194,7 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar } std::vector output_edges; - if (!backward_api_called) { + if (inputs != nullptr) { int num_inputs = PyTuple_GET_SIZE(inputs); output_edges.reserve(num_inputs); for (int i = 0; i < num_inputs; ++i) { @@ -213,7 +211,11 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar const auto output_nr = input_var->cdata.output_nr(); auto grad_fn = input_var->cdata.grad_fn(); if (!grad_fn) { - grad_fn = torch::autograd::impl::try_get_grad_accumulator(input_var->cdata); + grad_fn = torch::autograd::impl::try_get_grad_accumulator(input_var->cdata); + } + if (accumulate_grad) { + THPUtils_assert(input_var->cdata.is_leaf(), + "One of the differentiated Tensors given as 'inputs' to backward is not a leaf Tensor"); } THPUtils_assert(input_var->cdata.requires_grad(), "One of the differentiated Tensors does not require grad"); @@ -229,10 +231,10 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar { pybind11::gil_scoped_release no_gil; auto& engine = python::PythonEngine::get_python_engine(); - outputs = engine.execute(roots, grads, keep_graph, create_graph, output_edges); + outputs = engine.execute(roots, grads, keep_graph, create_graph, accumulate_grad, output_edges); } - if (!backward_api_called) { + if (!backward_api_called && inputs != nullptr) { int num_inputs = PyTuple_GET_SIZE(inputs); THPObjectPtr py_outputs {PyTuple_New(num_inputs)}; if (!py_outputs) return nullptr; @@ -281,9 +283,11 @@ PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) } static struct PyMethodDef THPEngine_methods[] = { - {(char*)"run_backward", (PyCFunction)(void(*)(void))THPEngine_run_backward, METH_VARARGS | METH_KEYWORDS, nullptr}, - {(char*)"queue_callback", (PyCFunction)THPEngine_queue_callback, METH_O, nullptr}, - {(char*)"is_checkpoint_valid", (PyCFunction)THPEngine_is_checkpoint_valid, METH_NOARGS, nullptr}, + {(char*)"run_backward", + castPyCFunctionWithKeywords(THPEngine_run_backward), + METH_VARARGS | METH_KEYWORDS, nullptr}, + {(char*)"queue_callback", THPEngine_queue_callback, METH_O, nullptr}, + {(char*)"is_checkpoint_valid", THPEngine_is_checkpoint_valid, METH_NOARGS, nullptr}, {nullptr} }; diff --git a/torch/csrc/autograd/python_engine.h b/torch/csrc/autograd/python_engine.h index 7d722d43d504c..3a54484d4d367 100644 --- a/torch/csrc/autograd/python_engine.h +++ b/torch/csrc/autograd/python_engine.h @@ -23,11 +23,13 @@ struct PythonEngine : public Engine { const variable_list& inputs, bool keep_graph, bool create_graph, + bool accumulate_grad, const edge_list& outputs = {}) override; std::shared_ptr execute_with_graph_task( const std::shared_ptr& graph_task, - std::shared_ptr graph_root) override; + std::shared_ptr graph_root, + InputBuffer&& input_buffer) override; std::unique_ptr make_anomaly_metadata() override; private: diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 8332d51f331f2..bc0922466914d 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -611,9 +611,9 @@ PyObject* process_outputs(PyObject *op_obj, const std::shared_ptr& cdata return outputs.release(); } -PyObject* THPFunction_name(THPFunction *self, PyObject* noargs) { +PyObject* THPFunction_name(PyObject *self, PyObject* noargs) { HANDLE_TH_ERRORS - auto cdata = self->cdata.lock(); + auto cdata = ((THPFunction*)self)->cdata.lock(); return THPUtils_packString(cdata->name()); END_HANDLE_TH_ERRORS } @@ -733,7 +733,7 @@ static void _trim_grad_input(const std::shared_ptr& cdata, THPFunction * } } -PyObject * THPFunction_do_backward(THPFunction *self, PyObject *args) +PyObject * THPFunction_do_backward(PyObject *_self, PyObject *args) { try { Py_ssize_t num_args = args ? PyTuple_GET_SIZE(args) : 0; @@ -744,6 +744,8 @@ PyObject * THPFunction_do_backward(THPFunction *self, PyObject *args) THPUtils_invalidArguments(args, nullptr, "_do_backward", 1, "(tuple, bool)"); return nullptr; } + + auto self = (THPFunction*)_self; auto cdata = self->cdata.lock(); // In obscure situations, cdata might be nullptr because it's expired. THAT // is an internal error and I'd like to know about it, but since this is @@ -800,13 +802,14 @@ PyObject * THPFunction_do_backward(THPFunction *self, PyObject *args) // Other methods / attributes //////////////////////////////////////////////////////////////////////////////// -PyObject* THPFunction__register_hook_dict(THPFunction *self, PyObject *_var) +PyObject* THPFunction__register_hook_dict(PyObject *_self, PyObject *_var) { HANDLE_TH_ERRORS THPUtils_assert(THPVariable_Check(_var), "_register_hook_dict expected a variable"); THPVariable *var = (THPVariable*)_var; std::unique_ptr hook(new PyFunctionPreHook( var->backward_hooks, var->cdata.output_nr())); + auto self = (THPFunction*)_self; auto cdata = self->cdata.lock(); TORCH_CHECK(cdata, "Legacy autograd function had register_hook called before the function was " @@ -818,9 +821,10 @@ PyObject* THPFunction__register_hook_dict(THPFunction *self, PyObject *_var) END_HANDLE_TH_ERRORS } -PyObject* THPFunction_register_hook(THPFunction *self, PyObject *hook) +PyObject* THPFunction_register_hook(PyObject *_self, PyObject *hook) { HANDLE_TH_ERRORS + auto self= (THPFunction*)_self; auto cdata = self->cdata.lock(); TORCH_CHECK(cdata, "Legacy autograd function had _register_hook called before the function was " @@ -1012,11 +1016,11 @@ static struct PyGetSetDef THPFunction_properties[] = { }; static struct PyMethodDef THPFunction_methods[] = { - {(char*)"name", (PyCFunction)THPFunction_name, METH_NOARGS, nullptr}, - {(char*)"apply", (PyCFunction)THPFunction_apply, METH_CLASS | METH_VARARGS, nullptr}, - {(char*)"_do_backward", (PyCFunction)THPFunction_do_backward, METH_VARARGS, nullptr}, - {(char*)"_register_hook_dict", (PyCFunction)THPFunction__register_hook_dict, METH_O, nullptr}, - {(char*)"register_hook", (PyCFunction)THPFunction_register_hook, METH_O, nullptr}, + {(char*)"name", THPFunction_name, METH_NOARGS, nullptr}, + {(char*)"apply", THPFunction_apply, METH_CLASS | METH_VARARGS, nullptr}, + {(char*)"_do_backward", THPFunction_do_backward, METH_VARARGS, nullptr}, + {(char*)"_register_hook_dict", THPFunction__register_hook_dict, METH_O, nullptr}, + {(char*)"register_hook", THPFunction_register_hook, METH_O, nullptr}, {nullptr} }; diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 81e10a9a1d1b4..4b3ad2c278f2f 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -43,6 +44,11 @@ namespace py = pybind11; PyObject *THPVariableClass = nullptr; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +PyObject *ParameterClass = nullptr; + +// clang-tidy gets confused by static const +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) static const char* VOLATILE_WARNING = "volatile was removed and now has no effect. Use " "`with torch.no_grad():` instead."; @@ -145,8 +151,9 @@ static PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject } // Instantiates a subclass of self with the same data. -static PyObject* THPVariable_as_subclass(THPVariable* self, PyObject* args, PyObject* kwargs) { +static PyObject* THPVariable_as_subclass(PyObject* _self, PyObject* args, PyObject* kwargs) { HANDLE_TH_ERRORS + auto self = (THPVariable*)_self; static PythonArgParser parser({ "as_subclass(PyObject* cls)", }); @@ -388,19 +395,19 @@ PyObject *THPVariable_get_ndim(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_names(THPVariable *self, void *unused) +PyObject *THPVariable_get_names(PyObject *self, void *unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { - return handle_torch_function_getter(self, "names"); + if (check_has_torch_function(self)) { + return handle_torch_function_getter((THPVariable*)self, "names"); } // The long-term plan is to return a list of (python) torch.Dimname. // However, for now, return a list of string. - size_t size = self->cdata.dim(); + size_t size = ((THPVariable *)self)->cdata.dim(); THPObjectPtr tuple(PyTuple_New(size)); if (!tuple) throw python_error(); - const auto dimnames = self->cdata.names(); + const auto dimnames = ((THPVariable *)self)->cdata.names(); for (size_t i = 0; i < size; ++i) { PyObject* str; if (dimnames[i].type() == at::NameType::WILDCARD) { @@ -423,12 +430,12 @@ PyObject *THPVariable_get_names(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -int THPVariable_set_names(THPVariable *self, PyObject *names) { +int THPVariable_set_names(PyObject *self, PyObject *names, void *unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { - return handle_torch_function_setter(self, "names", names); + if (check_has_torch_function(self)) { + return handle_torch_function_setter((THPVariable*)self, "names", names); } - auto& var = self->cdata; + auto& var = ((THPVariable *)self)->cdata; if (names == Py_None) { at::internal_set_names_inplace(var, at::nullopt); } else { @@ -566,6 +573,17 @@ PyObject *THPVariable_is_mkldnn(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } +PyObject *THPVariable_is_vulkan(THPVariable *self, void *unused) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "is_vulkan"); + } + auto& self_ = self->cdata; + return torch::autograd::utils::wrap(self_.is_vulkan()); + END_HANDLE_TH_ERRORS +} + PyObject *THPVariable_is_quantized(THPVariable *self, void *unused) { HANDLE_TH_ERRORS @@ -695,6 +713,7 @@ static struct PyGetSetDef THPVariable_properties[] = { {"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr}, {"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr}, {"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr}, + {"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr}, {"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr}, {"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr}, {"is_meta", (getter)THPVariable_is_meta, nullptr, nullptr, nullptr}, @@ -715,8 +734,10 @@ static PyMappingMethods THPVariable_as_mapping = { }; static PyMethodDef extra_methods[] = { - {"as_subclass", (PyCFunction)THPVariable_as_subclass, METH_VARARGS | METH_KEYWORDS, nullptr}, - {"_make_subclass", (PyCFunction)THPVariable_make_subclass, METH_STATIC | METH_VARARGS | METH_KEYWORDS, nullptr}, + {"as_subclass", castPyCFunctionWithKeywords(THPVariable_as_subclass), + METH_VARARGS | METH_KEYWORDS, nullptr}, + {"_make_subclass", castPyCFunctionWithKeywords(THPVariable_make_subclass), + METH_STATIC | METH_VARARGS | METH_KEYWORDS, nullptr}, {nullptr} }; diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h index 708a9c4e0ab55..41a2ccaeaedc9 100644 --- a/torch/csrc/autograd/python_variable.h +++ b/torch/csrc/autograd/python_variable.h @@ -19,12 +19,24 @@ struct THPVariable { }; THP_API PyObject *THPVariableClass; +THP_API PyObject *ParameterClass; bool THPVariable_initModule(PyObject *module); THP_API PyObject * THPVariable_Wrap(torch::autograd::Variable var); +static inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) { + // Check that a python object is a `Tensor`, but not a `Tensor` subclass. + // (A subclass could have different semantics.) The one exception is + // Parameter, which is used for Python bookkeeping but is equivalent to + // Tensor as far as C++ is concerned. + return ( + tp == (PyTypeObject*)THPVariableClass || + tp == (PyTypeObject*)ParameterClass + ); +} + static inline bool THPVariable_CheckExact(PyObject *obj) { - return Py_TYPE(obj) == (PyTypeObject*)THPVariableClass; + return THPVariable_CheckTypeExact(Py_TYPE(obj)); } inline bool THPVariable_Check(PyObject *obj) diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index 4b38d924c91b8..285161a49ef29 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -351,6 +351,10 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { } auto& self_ = reinterpret_cast(self)->cdata; + if (self_.is_sparse()) + { + throw TypeError("Cannot assign to a sparse tensor"); + } OptionalDeviceGuard device_guard(device_of(self_)); at::Device self_device = self_.device(); Variable value; diff --git a/torch/csrc/autograd/record_function_ops.cpp b/torch/csrc/autograd/record_function_ops.cpp index 633d0f1772959..da8cd22fbbc9e 100644 --- a/torch/csrc/autograd/record_function_ops.cpp +++ b/torch/csrc/autograd/record_function_ops.cpp @@ -65,10 +65,10 @@ c10::intrusive_ptr _call_end_callbacks_on_fut( } // Internal only, do not use directly, use Python's record_function() -static auto registry = - RegisterOperators() - .op("profiler::_record_function_enter", &record_function_enter) - .op("profiler::_record_function_exit", &record_function_exit); +TORCH_LIBRARY_FRAGMENT(profiler, m) { + m.def("_record_function_enter", &record_function_enter); + m.def("_record_function_exit", &record_function_exit); +} // Needed to register JIT operator in operator registry below c10::AliasAnalysisKind aliasAnalysisFromSchema() { diff --git a/torch/csrc/autograd/saved_variable.cpp b/torch/csrc/autograd/saved_variable.cpp index 580e94ea1c27f..d8058a1748c54 100644 --- a/torch/csrc/autograd/saved_variable.cpp +++ b/torch/csrc/autograd/saved_variable.cpp @@ -24,6 +24,12 @@ SavedVariable::SavedVariable(const Variable& variable, bool is_output, bool is_i // These copies are all shared_ptr copies, so slightly more expensive. // Do them here instead of in the init list in case data is undefined. data_ = variable.tensor_data(); + // TODO(albanD) This needs to be updated when moving to multiple levels + const auto& fw_grad = variable.fw_grad(/* level */ 0); + if (fw_grad.defined()) { + fw_grad_ = std::make_shared(); + fw_grad_->set_value(fw_grad, /* level */ 0); + } if (variable.is_leaf()) { grad_accumulator_ = impl::grad_accumulator(variable); } else if (!is_output) { @@ -100,12 +106,22 @@ Variable SavedVariable::unpack(std::shared_ptr saved_for) const { throw std::logic_error("No grad accumulator for a saved leaf!"); impl::set_grad_accumulator(var, grad_accumulator_); + // NB: var here is never a view so there is no need to make anything special + // for the case where the saved Tensor was a view. This whole argument relies + // on the fact that the Tensor returned by this function is never + // modified in-place. + if (fw_grad_ && !fw_grad_->empty()) { + // TODO(albanD) This needs to be updated when moving to multiple levels + auto new_fw_grad = fw_grad_->value(/* level */ 0); + var.set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ false); + } + return var; } const char* ERR_BACKWARD_TWICE = "Trying to backward through the graph a second time, but the saved intermediate " "results have already been freed. Specify retain_graph=True when calling " - "backward the first time."; + ".backward() or autograd.grad() the first time."; }} // namespace torch::autograd diff --git a/torch/csrc/autograd/saved_variable.h b/torch/csrc/autograd/saved_variable.h index f9533d3629e0d..dde0ffa18a21a 100644 --- a/torch/csrc/autograd/saved_variable.h +++ b/torch/csrc/autograd/saved_variable.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include @@ -23,6 +24,12 @@ class TORCH_API SavedVariable { SavedVariable(const c10::optional& variable, bool is_output, bool is_inplace_view=false); SavedVariable(SavedVariable&&) = default; SavedVariable& operator=(SavedVariable&&) = default; + ~SavedVariable() { + if (fw_grad_) { + // See note [ Using ForwardGrad ] + fw_grad_->clear(); + } + } /// Reconstructs the saved variable. Pass `saved_for` as the gradient /// function if constructing the `SavedVariable` with it would have caused a @@ -40,6 +47,11 @@ class TORCH_API SavedVariable { private: at::Tensor data_; + // This field is used to store the forward AD gradients associated with + // the saved Tensor. Note that this shared_ptr must never be shared with + // either the saved Tensor or the unpacked Tensor. See note [ Using ForwardGrad ] + std::shared_ptr fw_grad_; + // The gradient function associated with this node. If has_grad_fn // is false, then this is a leaf node. Note that the grad_fn is not saved if // it would create a circular reference. In that case, the grad_fn must be diff --git a/torch/csrc/autograd/utils/grad_layout_contract.h b/torch/csrc/autograd/utils/grad_layout_contract.h index 9e60dc3397a4b..4d1787d55c79b 100644 --- a/torch/csrc/autograd/utils/grad_layout_contract.h +++ b/torch/csrc/autograd/utils/grad_layout_contract.h @@ -25,7 +25,7 @@ inline at::Tensor clone_obey_contract(const at::Tensor& new_grad, const at::Tens // (1) // Does this dicey-looking sequence attach the result to new_grad's // history if GradMode::is_enabled()? Yes, and @alband says it should. - return std::move(at::empty_strided(variable.sizes(), variable.strides(), + return std::move(new_grad.new_empty_strided(variable.sizes(), variable.strides(), variable.options().memory_format(c10::nullopt)) .copy_(new_grad)); } else { diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index 64201d867d839..f28ef0f67a34b 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -20,28 +21,84 @@ #include #include #include +#include namespace torch { namespace autograd { -DifferentiableViewMeta::DifferentiableViewMeta(at::TensorImpl* self_impl, Variable base, - c10::optional> view_fn, +DifferentiableViewMeta::DifferentiableViewMeta(at::TensorImpl* self_impl, + c10::optional backward_info, + c10::optional forward_info, CreationMeta creation_meta) - : AutogradMeta(self_impl), creation_meta(creation_meta) { - base_ = std::move(base); - view_fn_ = std::move(view_fn); - TORCH_CHECK(base_.defined(), "base is undefined"); - if (base_.is_view()) { - base_ = base_._base(); - } + : AutogradMeta(self_impl), + backward_info_(std::move(backward_info)), + forward_info_(std::move(forward_info)), + creation_meta(creation_meta) { is_view_ = true; - self_impl->set_version_counter(impl::version_counter(base_)); - attr_version = self_impl->version_counter().current_version(); + if (backward_info_.has_value()) { + self_impl->set_version_counter(impl::version_counter(backward_info_.value().base_)); + attr_version = self_impl->version_counter().current_version(); + } } -DifferentiableViewMeta::~DifferentiableViewMeta() { - base_.reset(); +// Chain this view info with the new view op between base and tensor +ViewInfo ViewInfo::chain(const Variable & base, const Variable & tensor, + std::function view_func) const { + // Set `view_func` using the root base as input. + // `view_func` is used to recover views in backward when either as_strided is not supported + // or the view function changes the metadata which is not recorded by as_strided + // See Note [View + Inplace update on base tensor] and [View + Inplace update on view tensor] + // for more details how we use this function in backward. + if (view_func) { + // both current_view and it's parent have a view_func + if (view_fn_) { + // Copy parent view function to gain ownership + auto prev_fn = view_fn_; + view_func = [=](const at::Tensor& root_base) { + auto temp = prev_fn(root_base); + return view_func(temp); + }; + } else { + // current_view has a view_func and but it's parent doesn't have one + if (base.unsafeGetTensorImpl()->support_as_strided()) { + auto size = base.sizes().vec(); + auto stride = base.strides().vec(); + auto storage_offset = base.storage_offset(); + view_func = [=](const at::Tensor& root_base) { + auto temp = root_base.as_strided(size, stride, storage_offset); + return view_func(temp); + }; + } else { + // When base is a view but doesn't carry a view_fn in DifferentiableViewMeta, it's + // a view that doesn't support inplace update, e.g. unbind. + // In this case we should throw an error when inplace update happens in **forward**. + // One would naturally think the following function will be first called in backward pass. + // But the first call site is indeed in **forward** pass when we refresh `grad_fn` + // triggered by inplace update. + // Search Note [View + Inplace update for view tensor] to for the call site. + view_func = [=](const at::Tensor& root_base) { + TORCH_CHECK(false, "This view is the output of a function that returns multiple views." + "Such functions do not allow the output views to be modified inplace." + "You should replace the inplace operation by an out-of-place one"); + return root_base; + }; + } + } + } else if(view_fn_) { + // if current_view doesn't have a view_func but it's parent has one + // Copy parent view function to gain ownership + auto prev_view_fn = view_fn_; + auto size = tensor.sizes().vec(); + auto stride = tensor.strides().vec(); + auto storage_offset = tensor.storage_offset(); + view_func = [=](const at::Tensor& root_base) { + auto temp = prev_view_fn(root_base); + return temp.as_strided(size, stride, storage_offset); + }; + } + + return ViewInfo(base_, view_func); } namespace { @@ -81,19 +138,23 @@ namespace impl { auto diff_view_meta = static_cast(get_autograd_meta(self)); // See NOTE [ View + Inplace detection ] - if (diff_view_meta->creation_meta != CreationMeta::MULTI_OUTPUT_SAFE) { + auto creation_meta = diff_view_meta->get_creation_meta(); + if (creation_meta != CreationMeta::MULTI_OUTPUT_SAFE) { // Do not use handle_view_on_rebase here as check_inplace should have been called before this // and either throw an error or clear the warning - TORCH_INTERNAL_ASSERT(diff_view_meta->creation_meta == CreationMeta::DEFAULT); + // Temporary error message as a full fix is too risky for now + // Should be an internal assert again + TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::DEFAULT); TORCH_INTERNAL_ASSERT(gradient_edge.input_nr == 0); TORCH_INTERNAL_ASSERT(gradient_edge.function); TORCH_CHECK( gradient_edge.function->num_inputs() == 1, "Functions which modify views in-place must return a single Variable"); + auto view_info = diff_view_meta->get_backward_view(); diff_view_meta->output_nr_ = gradient_edge.input_nr; auto copy_slices = std::make_shared( - diff_view_meta->base_, at::TensorGeometry(self), diff_view_meta->view_fn_, std::move(gradient_edge.function)); - set_gradient_edge(diff_view_meta->base_, {std::move(copy_slices), 0}); + view_info.base_, at::TensorGeometry(self), view_info.view_fn_, std::move(gradient_edge.function)); + set_gradient_edge(view_info.base_, {std::move(copy_slices), 0}); self.grad_fn(); // trigger an update to the view's grad_fn return; } @@ -179,7 +240,7 @@ namespace impl { if (self.is_view()) { // NB: is_view() ==> get_autograd_meta() auto diff_view_meta = static_cast(meta); - diff_view_meta->attr_version = self._version(); + diff_view_meta->set_attr_version(self._version()); } } @@ -296,12 +357,14 @@ Tensor VariableHooks::tensor_data(const Tensor& self) const { return at::Tensor(self_impl_copy); } -// View Variables +// Backward View Variables //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ bool VariableHooks::is_view(const Tensor& self) const { - if (torch::autograd::impl::get_autograd_meta(self)) { - return torch::autograd::impl::get_autograd_meta(self)->is_view_; + auto meta = torch::autograd::impl::get_autograd_meta(self); + if (meta && meta->is_view_) { + auto diff_view_meta = static_cast(meta); + return diff_view_meta->has_bw_view(); } else { return false; } @@ -311,9 +374,10 @@ const Tensor& VariableHooks::base(const Tensor& self) const { if (self.is_view()) { // is_view() implies get_autograd_meta() auto diff_view_meta = static_cast(torch::autograd::impl::get_autograd_meta(self)); - return diff_view_meta->base_; + TORCH_CHECK(diff_view_meta->has_bw_view(), "Can't get base of non-backward view Tensor"); + return diff_view_meta->get_backward_view().base_; } else { - throw std::runtime_error("Can't get base of non-view Variable"); + throw std::runtime_error("Can't get base of non-view Tensor"); } } @@ -340,13 +404,14 @@ const std::shared_ptr& VariableHooks::grad_fn(const Tenso auto diff_view_meta = static_cast(torch::autograd::impl::get_autograd_meta(self)); // See NOTE [ View + Inplace detection ] - if (diff_view_meta->creation_meta != CreationMeta::MULTI_OUTPUT_SAFE) { + if (diff_view_meta->get_creation_meta() != CreationMeta::MULTI_OUTPUT_SAFE) { std::lock_guard lock(diff_view_meta->mutex_); - if (!diff_view_meta->grad_fn_ && !diff_view_meta->base_.requires_grad()) { + auto view_info = diff_view_meta->get_backward_view(); + if (!diff_view_meta->grad_fn_ && !view_info.base_.requires_grad()) { return diff_view_meta->grad_fn_; } auto current_version = self._version(); - if (diff_view_meta->attr_version != current_version) { + if (diff_view_meta->get_attr_version() != current_version) { // This is an indirect rebase_history due to another view or the base being modified inplace handle_view_on_rebase(diff_view_meta, /* indirect */ true); TORCH_INTERNAL_ASSERT(diff_view_meta->output_nr_ == 0); @@ -375,24 +440,24 @@ const std::shared_ptr& VariableHooks::grad_fn(const Tenso // // TODO: Potentially the following logic can be replaced by special logic in VariableType_x.cpp // that would provide a way to recreate the grad_fn chain. - if (diff_view_meta->has_view_fn()) { - auto view_fn = diff_view_meta->view_fn(); - auto diff_view = view_fn(diff_view_meta->base_); + if (view_info.has_view_fn()) { + auto view_fn = view_info.view_fn(); + auto diff_view = view_fn(view_info.base_); diff_view_meta->grad_fn_ = diff_view.grad_fn(); } else { auto fn = std::make_shared(); - fn->self_geometry = at::TensorGeometry(diff_view_meta->base_); + fn->self_geometry = at::TensorGeometry(view_info.base_); fn->size = self.sizes().vec(); fn->stride = self.strides().vec(); fn->storage_offset = self.storage_offset(); - fn->set_next_edges(torch::autograd::collect_next_edges(diff_view_meta->base_)); + fn->set_next_edges(torch::autograd::collect_next_edges(view_info.base_)); fn->add_input_metadata( - diff_view_meta->base_.options(), + view_info.base_.options(), self.sizes(), // Note: sizes(), not base_.sizes(), is intentional - diff_view_meta->base_.device()); + view_info.base_.device()); diff_view_meta->grad_fn_ = std::move(fn); } - diff_view_meta->attr_version = current_version; + diff_view_meta->set_attr_version(current_version); } return diff_view_meta->grad_fn_; } @@ -427,7 +492,8 @@ unsigned VariableHooks::_register_hook(const Tensor& self, std::functioncreation_meta != CreationMeta::DEFAULT) { + auto creation_meta = diff_view_meta->get_creation_meta(); + if (creation_meta != CreationMeta::DEFAULT) { auto grad_fn = diff_view_meta->grad_fn_.get(); std::string msg; std::string modified_obj; @@ -444,24 +510,24 @@ void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect msg = c10::str("A view was created in no_grad mode and ", modified_obj, " modified inplace with grad mode enabled."); } - if (diff_view_meta->creation_meta == CreationMeta::MULTI_OUTPUT_NODE) { + if (creation_meta == CreationMeta::MULTI_OUTPUT_NODE) { TORCH_CHECK(false, msg, " This view is the output of a function that returns multiple views. Such functions do not" " allow the output views to be modified inplace. You should replace the inplace operation by an" " out-of-place one."); } else { - if (diff_view_meta->creation_meta == CreationMeta::NO_GRAD_MODE) { + if (creation_meta == CreationMeta::NO_GRAD_MODE) { TORCH_INTERNAL_ASSERT(!grad_fn); msg = c10::str(msg, " Given that this use case is ambiguous and error-prone, it is deprecated and will be forbidden" " starting 1.6 (see https://github.com/pytorch/pytorch/pull/32839 for more details about this). You" " can clarify your code and remove this warning by moving both the view and the inplace either both" " inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want" " the inplace to be tracked)."); - } else if (diff_view_meta->creation_meta == CreationMeta::IN_CUSTOM_FUNCTION) { + } else if (creation_meta == CreationMeta::IN_CUSTOM_FUNCTION) { msg = c10::str(msg, " This view was created inside a custom Function (or because an input was returned as-is) and the" " autograd logic to handle view+inplace would override the custom backward associated with the custom" " Function, leading to incorrect gradients. This behavior is deprecated and will be forbidden starting" " version 1.6. You can remove this warning by cloning the output of the custom Function."); - } else if (diff_view_meta->creation_meta == CreationMeta::MULTI_OUTPUT_SAFE) { + } else if (creation_meta == CreationMeta::MULTI_OUTPUT_SAFE) { msg = c10::str(msg, " This view is an output of a function that " "returns multiple views. Inplace operators on such " "views are being deprecated and will be forbidden " @@ -472,9 +538,10 @@ void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect TORCH_INTERNAL_ASSERT(false, "Invalid CreationMeta state"); } - if (!indirect && !grad_fn) { + if (!indirect && !grad_fn && diff_view_meta->requires_grad()) { // This view is (wrongly) detected as a leaf that requires grad and would raise the surprising: "a leaf Variable that - // requires grad is being used in an in-place operation." after the warning. So we make the warning an error directly. + // requires grad is being used in an in-place operation." after the warning from the `check_inplace` function in + // VariabbleTypeUtils.h. So we make the warning an error directly. TORCH_CHECK(false, msg); } else { TORCH_WARN(msg); @@ -484,8 +551,10 @@ void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect // We warn only once per view // Note that if a Tensor is modified inplace from two threads at the same time, this is not thread safe and can warn // multiple time. This is ok as it should be a rare event. - diff_view_meta->creation_meta = CreationMeta::DEFAULT; + diff_view_meta->set_creation_meta(CreationMeta::DEFAULT); } } + + }} // namespace torch::autograd diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index 352e315de7ada..9cdf40fe2c634 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -193,6 +194,17 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { std::shared_ptr grad_fn_; std::weak_ptr grad_accumulator_; + // This field is used to store all the forward AD gradients + // associated with this AutogradMeta (and the Tensor it corresponds to) + // There is a semantic 1:1 correspondence between AutogradMeta and + // ForwardGrad but: + // - This field is lazily populated. + // - This field is a shared_ptr but it must never be + // shared by multiple Tensors. See Note [ Using ForwardGrad ] + // Any transition from not_initialized to initialized + // must be protected by mutex_ + std::shared_ptr fw_grad_; + std::vector> hooks_; std::shared_ptr cpp_hooks_list; @@ -211,9 +223,11 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { uint32_t output_nr_; // Mutex to ensure that concurrent read operations that modify internal - // state are still thread-safe. Used by grad_fn() and - // grad_accumulator(). - std::mutex mutex_; + // state are still thread-safe. Used by grad_fn(), grad_accumulator(), + // fw_grad() and set_fw_grad() + // This is mutable because we need to be able to acquire this from const + // version of this class for the functions above + mutable std::mutex mutex_; /// Sets the `requires_grad` property of `Variable`. This should be true for /// leaf variables that want to accumulate gradients, and false for all other @@ -238,6 +252,10 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { return grad_; } + const Variable& fw_grad(uint64_t level, const Variable& self) const override; + + void set_fw_grad(const Variable& new_grad, const Variable& self, uint64_t level, bool is_inplace_op) override; + AutogradMeta(at::TensorImpl* self_impl = nullptr, bool requires_grad = false, Edge gradient_edge = Edge() ) { grad_fn_ = std::move(gradient_edge.function); requires_grad_ = false; @@ -254,6 +272,55 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { !grad_fn_ || !requires_grad_, "requires_grad should be false if grad_fn is set"); } + + ~AutogradMeta() override { + // If AutogradMeta is being destroyed, it means that there is no other reference to its + // corresponding Tensor. It implies that no other thread can be using this object and so there is + // no need to lock mutex_ here to guard the check if fw_grad_ is populated. + if (fw_grad_) { + // See note [ Using ForwardGrad ] + fw_grad_->clear(); + } + } +}; + +struct TORCH_API ViewInfo { + /// The base `Variable` + /// If this ViewInfo represents a forward (respectively backward) AD gradient, + /// then this Tensor cannot be a forward (respectively backward) view. + Variable base_; + + /// By default we use as_strided to recover views which is more efficient. + /// view_fn is only saved when as_strided is not supported. + /// If view_fn has value, we use it to recover views in backward. + std::function view_fn_; + + /// Accessors for the view function + bool has_view_fn() const { + return view_fn_ != nullptr; + } + + std::function view_fn() const { + TORCH_CHECK(has_view_fn(), "Can only access the view function if it exists."); + return view_fn_; + } + + /// The chain function can be used to build a new ViewInfo for a differentiable view + /// function. It will return a new view info that accurately represents how "tensor" is + /// a view of this instance's "base_". + /// The "base" and "tensor" are respectively the input and output of the differentiable + /// view function that happened. They are required to properly set the optional + /// view_fn_ when it is not provided. + /// The "view_func", if provided, should be a function that allows to re-do the view + /// between "base" and "tensor". + ViewInfo chain(const Variable & base, const Variable & tensor, + std::function view_func=nullptr) const; + + ViewInfo(Variable base, std::function view_fn) : + base_(std::move(base)), + view_fn_(std::move(view_fn)) { + TORCH_CHECK(base_.defined(), "base is undefined"); + } }; //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -274,6 +341,27 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { /// /// Differentiable Views /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +/// This class allows to track both forward and backward AD differentiable views. +/// These views can have different base as non-differentiable view for forward +/// and backward mode AD are not the same. +/// +/// Most function are either both forward and backward differentiable views (for +/// example: view, select, narrow, transpose, etc) or both not forward and not +/// backward differentiable views (for example: indices, values, eq, lt, etc). +/// But there are also functions that are forward but not backward differentiable +/// views (only detach for now) or functions that are backward but not forward +/// differentiable view (only make_dual and unpack dual for now). +/// +/// A concrete example of two views with different bases is as follow: +/// +/// # Have: +/// # dual is a dual Tensor that is neither a forward or backward view +/// detached_dual = dual.detach() +/// view = detached_dual.view_as(dual) +/// # The forward base of view is dual +/// # The backward base of view is detached_dual +/// +/// - Backward Mode View /// Differentiable views are the view variables where you want gradients to flow /// back to the base variables. Out-of-place operations on views are quite /// straightforward, but in-place ones are very tricky. Even if the base @@ -300,6 +388,34 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { /// var[1] filled with all ones and /// zeros everywhere else /// +/// - Forward Mode View +/// Forward differentiable views follow the same semantic as backward ones but +/// show up differently as they are computed along with the forward evaluation. +/// The hard examples above are thus very similar +/// +/// (1) in-place operation on view, e.g., +/// +/// # Have: +/// # base is a regular Tensor +/// # var is a dual Tensor whose tangent is all ones +/// base[1] = var # i.e., base[1].copy_(var) +/// # Now, base is a dual Tensor +/// _, fw_grad = fwAD.unpack_dual(base) <- fw_grad should be a tensor with +/// fw_grad[1] filled with all ones and +/// zeros everywhere else +/// +/// (2) in-place operation on base after view is created, e.g., +/// +/// # Have: +/// # base is a regular Tensor +/// # var is a dual Tensor whose tangent is all ones +/// view = base[1] +/// base.copy_(var) +/// _, fw_grad = fwAD.unpack_dual(view) <- fw_grad should be an all ones tensor +/// +/// See Note [Forward Grad View/inplace] for more details on how we handle these hard cases. +/// +/// /// DifferentiableViewMeta is created to support gradient tracking of /// such **in-place** operations. In particular, /// + if an in-place op is done on base, the grad_fn field of the view may @@ -392,37 +508,66 @@ enum class CreationMeta: uint8_t { DEFAULT, IN_CUSTOM_FUNCTION, MULTI_OUTPUT_NOD TORCH_API void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect=false); struct TORCH_API DifferentiableViewMeta : public AutogradMeta { - /// The base `Variable` (never a view). - Variable base_; +private: + /// Informations about the views + c10::optional backward_info_; + c10::optional forward_info_; + + /// The two following fields are extra information that we track to ensure that + /// any operation on this backward view is valid. /// The value of the version_counter at the time grad_fn was created. The - /// grad_fn field is stale if attr_version != - /// version_counter.current_version(). + /// grad_fn field is stale if attr_version != version_counter.current_version(). uint32_t attr_version; - - /// By default we use as_strided to recover views which is more efficient. - /// view_fn is only saved when as_strided is not supported. - /// If view_fn has value, we use it to recover views in backward. - c10::optional> view_fn_; - CreationMeta creation_meta; +public: + /// requires_grad is a backward AD field so we only use the view specific logic + /// for backward differentiable views bool requires_grad() const override { - return requires_grad_ || grad_fn_ || (is_view_ && base_.requires_grad()); + return requires_grad_ || grad_fn_ || (has_bw_view() && get_backward_view().base_.requires_grad()); } - bool has_view_fn() const { - return view_fn_.has_value(); + bool has_bw_view() const { + return backward_info_.has_value(); + } + + const ViewInfo& get_backward_view() const { + TORCH_CHECK(has_bw_view(), "backward view info can only exist for backward views."); + return backward_info_.value(); + } + + uint32_t get_attr_version() const { + TORCH_CHECK(has_bw_view(), "attr_version can only exist for backward views."); + return attr_version; + } + + void set_attr_version(uint32_t new_attr_version) { + TORCH_CHECK(has_bw_view(), "attr_version can only exist for backward views."); + attr_version = new_attr_version; } - std::function view_fn() const { - TORCH_CHECK(has_view_fn(), "view_fn is not set."); - return view_fn_.value(); + CreationMeta get_creation_meta() const { + TORCH_CHECK(has_bw_view(), "creation_meta can only exist for backward views."); + return creation_meta; } - DifferentiableViewMeta(at::TensorImpl* self_impl, Variable base, c10::optional> view_fn, - CreationMeta creation_meta=CreationMeta::DEFAULT); - ~DifferentiableViewMeta(); + void set_creation_meta(CreationMeta new_creation_meta) { + TORCH_CHECK(has_bw_view(), "creation_meta can only exist for backward views."); + creation_meta = new_creation_meta; + } + + bool has_fw_view() const { + return forward_info_.has_value(); + } + + const ViewInfo& get_forward_view() const { + TORCH_CHECK(has_fw_view(), "forward view info can only exist for forward views."); + return forward_info_.value(); + } + + DifferentiableViewMeta(at::TensorImpl* self_impl, c10::optional backward_info, + c10::optional forward_info, CreationMeta creation_meta=CreationMeta::DEFAULT); }; //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -448,18 +593,32 @@ struct TORCH_API DifferentiableViewMeta : public AutogradMeta { // See NOTE [ Autograd View Variables ] for details. // Differentiable view. Track history with DifferentiableViewMeta. inline Variable make_variable_differentiable_view( - Variable base, - at::Tensor data, + const at::Tensor& data, + c10::optional backward_info, + c10::optional forward_info, CreationMeta creation_meta, - c10::optional> view_func = c10::nullopt) { + bool allow_tensor_metadata_change = true) { if (data.defined()) { - auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( - /*version_counter=*/0, - /*allow_tensor_metadata_change=*/true); - data_impl_copy->set_autograd_meta(std::make_unique( - data_impl_copy.get(), std::move(base), std::move(view_func), + // If we already did a TensorImpl allocation for data, just reuse it. + // Otherwise(e.g tensor.swapdim(0, 0) when we return the same tensor as input), + // we have to use shallow_copy_and_detach to create a new TensorImpl to avoid + // moving leaf node into graph interior. This guarantees only 1 TensorImpl + // allocation happens in view ops. + if (data.getIntrusivePtr().unique() && data.getIntrusivePtr()->unique_version()) { + at::TensorImpl* data_impl = data.unsafeGetTensorImpl(); + data_impl->set_autograd_meta(std::make_unique( + data_impl, std::move(backward_info), std::move(forward_info), creation_meta)); - return Variable(data_impl_copy); + return data; + } else { + c10::intrusive_ptr data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( + /*version_counter=*/0, + /*allow_tensor_metadata_change=*/true); + data_impl_copy->set_autograd_meta(std::make_unique( + data_impl_copy.get(), std::move(backward_info), std::move(forward_info), + creation_meta)); + return Variable(data_impl_copy); + } } return Variable(); } @@ -468,9 +627,12 @@ inline Variable make_variable_differentiable_view( // Non-differentiable view. Just share version counter. inline Variable make_variable_non_differentiable_view( Variable base, - at::Tensor data, + const at::Tensor& data, bool allow_tensor_metadata_change = true) { if (data.defined()) { + // Currently all of non-differentiable view ops(detach/_indices/_values) + // share the same TensorImpl as their base Tensor. Thus a new TensorImpl + // allocation here is required. auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( /*version_counter=*/impl::version_counter(base), /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); diff --git a/torch/csrc/cuda/Event.cpp b/torch/csrc/cuda/Event.cpp index 3224d6956d4be..5eb549e76c240 100644 --- a/torch/csrc/cuda/Event.cpp +++ b/torch/csrc/cuda/Event.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -45,8 +46,9 @@ static PyObject * THCPEvent_pynew( } static PyObject * THCPEvent_from_ipc_handle( - PyTypeObject *type, PyObject *args, PyObject *kwargs) { + PyObject *_type, PyObject *args, PyObject *kwargs) { HANDLE_TH_ERRORS + auto type = (PyTypeObject*)_type; static torch::PythonArgParser parser({ "from_ipc_handle(Device device, std::string ipc_handle)", @@ -98,16 +100,20 @@ static PyObject * THCPEvent_get_device(THCPEvent *self, void *unused) { END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_record(THCPEvent *self, THCPStream *stream) { +static PyObject * THCPEvent_record(PyObject *_self, PyObject *_stream) { HANDLE_TH_ERRORS + auto self = (THCPEvent*)_self; + auto stream = (THCPStream*)_stream; self->cuda_event.record(stream->cuda_stream); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_wait(THCPEvent *self, THCPStream *stream) { +static PyObject * THCPEvent_wait(PyObject *_self, PyObject *_stream) { HANDLE_TH_ERRORS { + auto self = (THCPEvent*)_self; + auto stream = (THCPStream*)_stream; pybind11::gil_scoped_release no_gil; self->cuda_event.block(stream->cuda_stream); } @@ -115,21 +121,25 @@ static PyObject * THCPEvent_wait(THCPEvent *self, THCPStream *stream) { END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_query(THCPEvent *self, PyObject *noargs) { +static PyObject * THCPEvent_query(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS + auto self = (THCPEvent*)_self; return PyBool_FromLong(self->cuda_event.query()); END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_elapsed_time(THCPEvent *self, THCPEvent *other) { +static PyObject * THCPEvent_elapsed_time(PyObject *_self, PyObject *_other) { HANDLE_TH_ERRORS + auto self = (THCPEvent*)_self; + auto other = (THCPEvent*)_other; return PyFloat_FromDouble(self->cuda_event.elapsed_time(other->cuda_event)); END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_synchronize(THCPEvent *self, PyObject *noargs) { +static PyObject * THCPEvent_synchronize(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS { + auto self = (THCPEvent*)_self; pybind11::gil_scoped_release no_gil; self->cuda_event.synchronize(); } @@ -137,8 +147,9 @@ static PyObject * THCPEvent_synchronize(THCPEvent *self, PyObject *noargs) { END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_ipc_handle(THCPEvent *self, PyObject *noargs) { +static PyObject * THCPEvent_ipc_handle(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS + auto self = (THCPEvent*)_self; cudaIpcEventHandle_t handle; self->cuda_event.ipc_handle(&handle); return PyBytes_FromStringAndSize((const char *)&handle, sizeof(handle)); @@ -152,15 +163,16 @@ static struct PyGetSetDef THCPEvent_properties[] = { }; static PyMethodDef THCPEvent_methods[] = { - {(char*)"from_ipc_handle", (PyCFunction)(void(*)(void))THCPEvent_from_ipc_handle, + {(char*)"from_ipc_handle", + castPyCFunctionWithKeywords(THCPEvent_from_ipc_handle), METH_CLASS | METH_VARARGS | METH_KEYWORDS, nullptr}, - {(char*)"record", (PyCFunction)THCPEvent_record, METH_O, nullptr}, - {(char*)"wait", (PyCFunction)THCPEvent_wait, METH_O, nullptr}, - {(char*)"query", (PyCFunction)THCPEvent_query, METH_NOARGS, nullptr}, - {(char*)"elapsed_time", (PyCFunction)THCPEvent_elapsed_time, METH_O, nullptr}, - {(char*)"synchronize", (PyCFunction)THCPEvent_synchronize, + {(char*)"record", THCPEvent_record, METH_O, nullptr}, + {(char*)"wait", THCPEvent_wait, METH_O, nullptr}, + {(char*)"query", THCPEvent_query, METH_NOARGS, nullptr}, + {(char*)"elapsed_time", THCPEvent_elapsed_time, METH_O, nullptr}, + {(char*)"synchronize", THCPEvent_synchronize, METH_NOARGS, nullptr}, - {(char*)"ipc_handle", (PyCFunction)THCPEvent_ipc_handle, + {(char*)"ipc_handle", THCPEvent_ipc_handle, METH_NOARGS, nullptr}, {nullptr} }; diff --git a/torch/csrc/cuda/Graph.cpp b/torch/csrc/cuda/Graph.cpp new file mode 100644 index 0000000000000..b258f00bcf903 --- /dev/null +++ b/torch/csrc/cuda/Graph.cpp @@ -0,0 +1,46 @@ +#include + +#include + +#include +#include + +#include + +// Cargo culted partially from csrc/distributed/c10d/init.cpp +// and partially from csrc/cuda/Stream.cpp. +// THCPStream_init is also declared at global scope. + +// Because THCPGraph_init is forward declared in the only consumer (csrc/Module.cpp) +// I don't think we need a Graph.h. + +template +using shared_ptr_class_ = py::class_>; + +void THCPGraph_init(PyObject *module) { + // Pybind11 patch notes say "py::module_" is more up-to-date syntax, + // but CI linter and some builds prefer "module". + auto torch_C_m = py::handle(module).cast(); + + shared_ptr_class_<::at::cuda::CUDAGraph>(module, "_CudaGraphBase") + .def(py::init<>()) + .def("capture_begin", + &::at::cuda::CUDAGraph::capture_begin, + py::call_guard(), + R"(``capture_begin`` begins Cuda graph capture on the current stream.)") + .def("capture_end", + &::at::cuda::CUDAGraph::capture_end, + py::call_guard(), + R"(``capture_end`` ends Cuda graph capture on the current stream. + After ``capture_end``, ``replay`` may be called on this instance.)") + .def("replay", + &::at::cuda::CUDAGraph::replay, + py::call_guard(), + R"(``replay`` replays the Cuda graph captured by this instance.)") + // reset is called in __del__ on the Python side + // (see class Graph in torch/cuda/streams.py for reasons and caveats) + .def("reset", + &::at::cuda::CUDAGraph::reset, + py::call_guard(), + R"(``reset`` deletes the graph currently held by this instance.)"); +} diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 253816c7ea9c3..6d0aad7dda127 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -77,7 +78,32 @@ PyObject * THCPModule_getDevice_wrap(PyObject *self, PyObject *noargs) HANDLE_TH_ERRORS torch::utils::cuda_lazy_init(); auto device = static_cast(c10::cuda::current_device()); - return PyLong_FromLong(device); + return THPUtils_packInt32(device); + END_HANDLE_TH_ERRORS +} + +PyObject * THCPModule_canDeviceAccessPeer_wrap(PyObject *self, PyObject *args) +{ + HANDLE_TH_ERRORS + PyObject* arg1 = nullptr; + PyObject* arg2 = nullptr; + if(!PyArg_ParseTuple(args, "OO", &arg1, &arg2)) { + THPUtils_invalidArguments( + args, + nullptr, + "can_device_peer_access", + 1, + "(int device, int peer_device);"); + return nullptr; + } + THPUtils_assert(THPUtils_checkLong(arg1), "invalid argument to canDeviceAccessPeer"); + THPUtils_assert(THPUtils_checkLong(arg2), "invalid argument to canDeviceAccessPeer"); + int64_t device = THPUtils_unpackLong(arg1); + int64_t peer_device = THPUtils_unpackLong(arg2); + + torch::utils::cuda_lazy_init(); + auto can_access = at::cuda::canDeviceAccessPeer(device, peer_device); + return PyBool_FromLong(can_access); END_HANDLE_TH_ERRORS } @@ -85,7 +111,7 @@ PyObject * THCPModule_getDeviceCount_wrap(PyObject *self, PyObject *noargs) { HANDLE_TH_ERRORS poison_fork(); - return PyLong_FromLong(at::cuda::device_count()); + return THPUtils_packUInt64(at::cuda::device_count()); END_HANDLE_TH_ERRORS } @@ -150,7 +176,7 @@ PyObject * THCPModule_setStream_wrap(PyObject *self, PyObject *obj) PyObject * THCPModule_getCompiledVersion(PyObject *self, PyObject *noargs) { - return PyLong_FromLong((long) CUDA_VERSION); + return THPUtils_packInt64((int64_t) CUDA_VERSION); } PyObject * THCPModule_cudaHostAllocator(PyObject *_unused, PyObject *noargs) @@ -263,6 +289,28 @@ PyObject * THCPModule_hasPrimaryContext(PyObject *_unused, PyObject *arg) END_HANDLE_TH_ERRORS } +PyObject * THCPModule_setMemoryFraction(PyObject *_unused, PyObject *args) +{ + HANDLE_TH_ERRORS + PyObject* fraction_o = nullptr; + PyObject* device_o = nullptr; + if(!PyArg_ParseTuple(args, "OO", &fraction_o, &device_o)) { + THPUtils_invalidArguments( + args, + nullptr, + "set_memory_fraction", + 1, + "(double fraction, int device);"); + return nullptr; + } + double fraction = PyFloat_AsDouble(fraction_o); + int64_t device = PyLong_AsLongLong(device_o); + + c10::cuda::CUDACachingAllocator::setMemoryFraction(fraction, device); + END_HANDLE_TH_ERRORS + Py_RETURN_NONE; +} + PyObject * THCPModule_emptyCache(PyObject *_unused, PyObject *noargs) { HANDLE_TH_ERRORS @@ -484,42 +532,44 @@ PyObject * THCPModule_getCurrentBlasHandle_wrap(PyObject *self, PyObject *noargs } static struct PyMethodDef _THCPModule_methods[] = { - {"_cuda_init", (PyCFunction)THCPModule_initExtension, METH_NOARGS, nullptr}, - {"_cuda_setDevice", (PyCFunction)THCPModule_setDevice_wrap, METH_O, nullptr}, - {"_cuda_getDevice", (PyCFunction)THCPModule_getDevice_wrap, METH_NOARGS, nullptr}, - {"_cuda_getDeviceCount", (PyCFunction)THCPModule_getDeviceCount_wrap, METH_NOARGS, nullptr}, - {"_cuda_getArchFlags", (PyCFunction)THCPModule_getArchFlags, METH_NOARGS, nullptr}, - {"_cuda_isInBadFork", (PyCFunction)THCPModule_isInBadFork, METH_NOARGS, nullptr}, + {"_cuda_init", THCPModule_initExtension, METH_NOARGS, nullptr}, + {"_cuda_setDevice", THCPModule_setDevice_wrap, METH_O, nullptr}, + {"_cuda_getDevice", THCPModule_getDevice_wrap, METH_NOARGS, nullptr}, + {"_cuda_getDeviceCount", THCPModule_getDeviceCount_wrap, METH_NOARGS, nullptr}, + {"_cuda_canDeviceAccessPeer", THCPModule_canDeviceAccessPeer_wrap, METH_VARARGS, nullptr}, + {"_cuda_getArchFlags", THCPModule_getArchFlags, METH_NOARGS, nullptr}, + {"_cuda_isInBadFork", THCPModule_isInBadFork, METH_NOARGS, nullptr}, {"_cuda_getCurrentStream", - (PyCFunction)THCPModule_getCurrentStream_wrap, METH_O, nullptr}, + THCPModule_getCurrentStream_wrap, METH_O, nullptr}, {"_cuda_getDefaultStream", - (PyCFunction)THCPModule_getDefaultStream_wrap, METH_O, nullptr}, - {"_cuda_getCurrentBlasHandle", (PyCFunction)THCPModule_getCurrentBlasHandle_wrap, METH_NOARGS, nullptr}, - {"_cuda_setStream", (PyCFunction)THCPModule_setStream_wrap, METH_O, nullptr}, - {"_cuda_getCompiledVersion", (PyCFunction)THCPModule_getCompiledVersion, METH_NOARGS, nullptr}, - {"_cuda_hasPrimaryContext", (PyCFunction) THCPModule_hasPrimaryContext, METH_O, nullptr}, - {"_cuda_emptyCache", (PyCFunction) THCPModule_emptyCache, METH_NOARGS, nullptr}, - {"_cuda_memoryStats", (PyCFunction) THCPModule_memoryStats, METH_O, nullptr}, - {"_cuda_resetAccumulatedMemoryStats", (PyCFunction) THCPModule_resetAccumulatedMemoryStats, METH_O, nullptr}, - {"_cuda_resetPeakMemoryStats", (PyCFunction) THCPModule_resetPeakMemoryStats, METH_O, nullptr}, - {"_cuda_memorySnapshot", (PyCFunction) THCPModule_memorySnapshot, METH_NOARGS, nullptr}, - {"_cuda_cudaHostAllocator", (PyCFunction)THCPModule_cudaHostAllocator, METH_NOARGS, nullptr}, - {"_cuda_cudaCachingAllocator_raw_alloc", (PyCFunction)THCPModule_cudaCachingAllocator_raw_alloc, METH_VARARGS, nullptr}, - {"_cuda_cudaCachingAllocator_raw_delete", (PyCFunction)THCPModule_cudaCachingAllocator_raw_delete, METH_O, nullptr}, - {"_cuda_synchronize", (PyCFunction)THCPModule_cudaSynchronize, METH_NOARGS, nullptr}, - {"_cuda_ipc_collect", (PyCFunction)THCPModule_cudaIPCCollect, METH_NOARGS, nullptr}, - {"_cuda_sleep", (PyCFunction)THCPModule_cudaSleep, METH_O, nullptr}, - {"_cuda_lock_mutex", (PyCFunction)THCPModule_cudaLockMutex, METH_NOARGS, nullptr}, - {"_cuda_unlock_mutex", (PyCFunction)THCPModule_cudaUnlockMutex, METH_NOARGS, nullptr}, + THCPModule_getDefaultStream_wrap, METH_O, nullptr}, + {"_cuda_getCurrentBlasHandle", THCPModule_getCurrentBlasHandle_wrap, METH_NOARGS, nullptr}, + {"_cuda_setStream", THCPModule_setStream_wrap, METH_O, nullptr}, + {"_cuda_getCompiledVersion", THCPModule_getCompiledVersion, METH_NOARGS, nullptr}, + {"_cuda_hasPrimaryContext", THCPModule_hasPrimaryContext, METH_O, nullptr}, + {"_cuda_setMemoryFraction", THCPModule_setMemoryFraction, METH_VARARGS, nullptr}, + {"_cuda_emptyCache", THCPModule_emptyCache, METH_NOARGS, nullptr}, + {"_cuda_memoryStats", THCPModule_memoryStats, METH_O, nullptr}, + {"_cuda_resetAccumulatedMemoryStats", THCPModule_resetAccumulatedMemoryStats, METH_O, nullptr}, + {"_cuda_resetPeakMemoryStats", THCPModule_resetPeakMemoryStats, METH_O, nullptr}, + {"_cuda_memorySnapshot", THCPModule_memorySnapshot, METH_NOARGS, nullptr}, + {"_cuda_cudaHostAllocator", THCPModule_cudaHostAllocator, METH_NOARGS, nullptr}, + {"_cuda_cudaCachingAllocator_raw_alloc", THCPModule_cudaCachingAllocator_raw_alloc, METH_VARARGS, nullptr}, + {"_cuda_cudaCachingAllocator_raw_delete", THCPModule_cudaCachingAllocator_raw_delete, METH_O, nullptr}, + {"_cuda_synchronize", THCPModule_cudaSynchronize, METH_NOARGS, nullptr}, + {"_cuda_ipc_collect", THCPModule_cudaIPCCollect, METH_NOARGS, nullptr}, + {"_cuda_sleep", THCPModule_cudaSleep, METH_O, nullptr}, + {"_cuda_lock_mutex", THCPModule_cudaLockMutex, METH_NOARGS, nullptr}, + {"_cuda_unlock_mutex", THCPModule_cudaUnlockMutex, METH_NOARGS, nullptr}, #ifdef USE_NCCL - {"_nccl_version", (PyCFunction)THCPModule_nccl_version, METH_NOARGS, nullptr}, - {"_nccl_unique_id", (PyCFunction)THCPModule_nccl_unique_id, METH_NOARGS, nullptr}, - {"_nccl_init_rank", (PyCFunction)THCPModule_nccl_init_rank, METH_VARARGS, nullptr}, - {"_nccl_reduce", (PyCFunction)THCPModule_nccl_reduce, METH_VARARGS, nullptr}, - {"_nccl_all_reduce", (PyCFunction)THCPModule_nccl_all_reduce, METH_VARARGS, nullptr}, - {"_nccl_broadcast", (PyCFunction)THCPModule_nccl_broadcast, METH_VARARGS, nullptr}, - {"_nccl_all_gather", (PyCFunction)THCPModule_nccl_all_gather, METH_VARARGS, nullptr}, - {"_nccl_reduce_scatter", (PyCFunction)THCPModule_nccl_reduce_scatter, METH_VARARGS, nullptr}, + {"_nccl_version", THCPModule_nccl_version, METH_NOARGS, nullptr}, + {"_nccl_unique_id", THCPModule_nccl_unique_id, METH_NOARGS, nullptr}, + {"_nccl_init_rank", THCPModule_nccl_init_rank, METH_VARARGS, nullptr}, + {"_nccl_reduce", THCPModule_nccl_reduce, METH_VARARGS, nullptr}, + {"_nccl_all_reduce", THCPModule_nccl_all_reduce, METH_VARARGS, nullptr}, + {"_nccl_broadcast", THCPModule_nccl_broadcast, METH_VARARGS, nullptr}, + {"_nccl_all_gather", THCPModule_nccl_all_gather, METH_VARARGS, nullptr}, + {"_nccl_reduce_scatter", THCPModule_nccl_reduce_scatter, METH_VARARGS, nullptr}, #endif {nullptr} }; diff --git a/torch/csrc/cuda/Stream.cpp b/torch/csrc/cuda/Stream.cpp index 775d48e16f351..e0bb1922ff096 100644 --- a/torch/csrc/cuda/Stream.cpp +++ b/torch/csrc/cuda/Stream.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -65,11 +66,11 @@ static PyObject * THCPStream_get_cuda_stream(THCPStream *self, void *unused) { static PyObject * THCPStream_get_priority(THCPStream *self, void *unused) { HANDLE_TH_ERRORS - return PyLong_FromLong(self->cuda_stream.priority()); + return THPUtils_packInt64(self->cuda_stream.priority()); END_HANDLE_TH_ERRORS } -static PyObject * THCPStream_priority_range() { +static PyObject * THCPStream_priority_range(PyObject *_unused, PyObject* noargs) { HANDLE_TH_ERRORS int least_priority, greatest_priority; std::tie(least_priority, greatest_priority) = @@ -78,36 +79,37 @@ static PyObject * THCPStream_priority_range() { END_HANDLE_TH_ERRORS } -static PyObject * THCPStream_query(THCPStream *self, PyObject *noargs) { +static PyObject * THCPStream_query(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS + auto self = (THCPStream*)_self; return PyBool_FromLong(self->cuda_stream.query()); END_HANDLE_TH_ERRORS } -static PyObject * THCPStream_synchronize(THCPStream *self, PyObject *noargs) { +static PyObject * THCPStream_synchronize(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS { pybind11::gil_scoped_release no_gil; + auto self = (THCPStream*)_self; self->cuda_stream.synchronize(); } Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -static PyObject * THCPStream_eq(THCPStream *self, THCPStream *other) { +static PyObject * THCPStream_eq(PyObject *_self, PyObject *_other) { HANDLE_TH_ERRORS + auto self = (THCPStream*)_self; + auto other = (THCPStream*)_other; return PyBool_FromLong(self->cuda_stream == other->cuda_stream); END_HANDLE_TH_ERRORS } static struct PyMemberDef THCPStream_members[] = { - {(char*)"_cdata", - T_ULONGLONG, offsetof(THCPStream, cdata), READONLY, nullptr}, {nullptr} }; static struct PyGetSetDef THCPStream_properties[] = { - {"device", (getter)THCPStream_get_device, nullptr, nullptr, nullptr}, {"cuda_stream", (getter)THCPStream_get_cuda_stream, nullptr, nullptr, nullptr}, {"priority", (getter)THCPStream_get_priority, nullptr, nullptr, nullptr}, @@ -115,12 +117,12 @@ static struct PyGetSetDef THCPStream_properties[] = { }; static PyMethodDef THCPStream_methods[] = { - {(char*)"query", (PyCFunction)THCPStream_query, METH_NOARGS, nullptr}, + {(char*)"query", THCPStream_query, METH_NOARGS, nullptr}, {(char*)"synchronize", - (PyCFunction)THCPStream_synchronize, METH_NOARGS, nullptr}, + THCPStream_synchronize, METH_NOARGS, nullptr}, {(char*)"priority_range", - (PyCFunction)(void(*)(void))THCPStream_priority_range, METH_STATIC | METH_NOARGS, nullptr}, - {(char*)"__eq__", (PyCFunction)THCPStream_eq, METH_O, nullptr}, + THCPStream_priority_range, METH_STATIC | METH_NOARGS, nullptr}, + {(char*)"__eq__", THCPStream_eq, METH_O, nullptr}, {nullptr} }; @@ -154,7 +156,7 @@ PyTypeObject THCPStreamType = { 0, /* tp_iternext */ THCPStream_methods, /* tp_methods */ THCPStream_members, /* tp_members */ - THCPStream_properties, /* tp_getset */ + THCPStream_properties, /* tp_getset */ 0, /* tp_base */ 0, /* tp_dict */ 0, /* tp_descr_get */ @@ -168,6 +170,8 @@ PyTypeObject THCPStreamType = { void THCPStream_init(PyObject *module) { + Py_INCREF(THPStreamClass); + THCPStreamType.tp_base = THPStreamClass; THCPStreamClass = (PyObject*)&THCPStreamType; if (PyType_Ready(&THCPStreamType) < 0) { throw python_error(); diff --git a/torch/csrc/cuda/Stream.h b/torch/csrc/cuda/Stream.h index c98d1352e399b..71acdc5b7d317 100644 --- a/torch/csrc/cuda/Stream.h +++ b/torch/csrc/cuda/Stream.h @@ -1,13 +1,12 @@ #ifndef THCP_STREAM_INC #define THCP_STREAM_INC +#include #include #include #include -struct THCPStream { - PyObject_HEAD - uint64_t cdata; +struct THCPStream : THPStream{ at::cuda::CUDAStream cuda_stream; }; extern PyObject *THCPStreamClass; diff --git a/torch/csrc/cuda/comm.cpp b/torch/csrc/cuda/comm.cpp index ca341305ec1d1..1f85b0e1eba54 100644 --- a/torch/csrc/cuda/comm.cpp +++ b/torch/csrc/cuda/comm.cpp @@ -130,7 +130,7 @@ std::vector broadcast(const Tensor& tensor, IntArrayRef devices) { // When splitting, the view operations will make all Variables broadcast // together to share a single version counter, because they are all views of the // large Variable. However, that large Variable is immediately discarded and all -// these Varaibles do not share storage at all. +// these Variables do not share storage at all. // // For example, when two buffers are broadcast together in `DataParallel` and // one of them is modified in-place during `forward` but the other is needed in diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp index 6cef307c7cceb..5efb77ea536ae 100644 --- a/torch/csrc/cuda/nccl.cpp +++ b/torch/csrc/cuda/nccl.cpp @@ -21,6 +21,10 @@ ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) { return reinterpret_cast(var); } +ncclComm_t to_nccl_comm(torch::cuda::nccl::ncclComm_t var) { + return reinterpret_cast(var); +} + ncclUniqueId* to_nccl_unique_id(torch::cuda::nccl::ncclUniqueId* var) { return reinterpret_cast(var); } @@ -107,16 +111,20 @@ using namespace at; namespace detail { +static inline void NCCL_CHECK(ncclResult_t result) { + NCCL_CHECK(from_nccl_result(result)); +} + struct AutoNcclGroup { AutoNcclGroup() { (c10::cuda::CUDACachingAllocator::getFreeMutex())->lock(); #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) - NCCL_CHECK(from_nccl_result(ncclGroupStart())); + NCCL_CHECK(ncclGroupStart()); #endif } ~AutoNcclGroup() { #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) - NCCL_CHECK(from_nccl_result(ncclGroupEnd())); + NCCL_CHECK(ncclGroupEnd()); #endif (c10::cuda::CUDACachingAllocator::getFreeMutex())->unlock(); } @@ -133,8 +141,8 @@ struct NcclCommList { int ndevices; NcclCommList(const std::vector& devices) : comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) { - NCCL_CHECK(from_nccl_result( - ncclCommInitAll(to_nccl_comm(comms.get()), devices.size(), devices.data()))); + NCCL_CHECK( + ncclCommInitAll(to_nccl_comm(comms.get()), devices.size(), devices.data())); } NcclCommList(NcclCommList&& foo) = default; ~NcclCommList() { @@ -326,7 +334,7 @@ void get_unique_id(ncclUniqueId& id) { #ifdef USE_NCCL using namespace torch::cuda::nccl::detail; - NCCL_CHECK(from_nccl_result(ncclGetUniqueId(to_nccl_unique_id(&id)))); + NCCL_CHECK(ncclGetUniqueId(to_nccl_unique_id(&id))); #else AT_ERROR("PyTorch built without NCCL support"); #endif @@ -337,11 +345,11 @@ ncclComm_t comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank) { using namespace torch::cuda::nccl::detail; ncclComm_t comm; ncclUniqueId id = comm_id; - NCCL_CHECK(from_nccl_result(ncclCommInitRank( + NCCL_CHECK(ncclCommInitRank( to_nccl_comm(&comm), nranks, *(to_nccl_unique_id(&id)), - rank))); + rank)); return comm; #else return nullptr; @@ -362,8 +370,7 @@ void comm_destroy(ncclComm_t comm) #ifdef USE_NCCL using namespace torch::cuda::nccl::detail; - NCCL_CHECK(from_nccl_result(ncclCommDestroy( - *(to_nccl_comm(&comm))))); + NCCL_CHECK(ncclCommDestroy(to_nccl_comm(comm))); #endif } @@ -420,8 +427,8 @@ void broadcast( count_max, ")"); ncclComm_t comm = comms[i]; - NCCL_CHECK(from_nccl_result(ncclBcast( - tensors[i].data_ptr(), numel, data_type, 0, *(to_nccl_comm(&comm)), stream))); + NCCL_CHECK(ncclBcast( + tensors[i].data_ptr(), numel, data_type, 0, to_nccl_comm(comm), stream)); } #else AT_ERROR("PyTorch built without NCCL support"); @@ -460,15 +467,15 @@ void reduce( : streams[i]->stream(); ncclComm_t comm = comms_ref[i]; - NCCL_CHECK(from_nccl_result(ncclReduce( + NCCL_CHECK(ncclReduce( inputs[i].data_ptr(), root == i ? output.data_ptr() : nullptr, count, data_type, to_nccl_red_op(op), root, - *(to_nccl_comm(&comm)), - stream))); + to_nccl_comm(comm), + stream)); } #else AT_ERROR("PyTorch built without NCCL support"); @@ -512,14 +519,14 @@ void all_reduce( : streams[i]->stream(); ncclComm_t comm = comms_ref[i]; - NCCL_CHECK(from_nccl_result(ncclAllReduce( + NCCL_CHECK(ncclAllReduce( inputs[i].data_ptr(), outputs[i].data_ptr(), count, data_type, to_nccl_red_op(op), - *(to_nccl_comm(&comm)), - stream))); + to_nccl_comm(comm), + stream)); } #else AT_ERROR("PyTorch built without NCCL support"); @@ -554,14 +561,14 @@ void reduce_scatter( : streams[i]->stream(); ncclComm_t comm = comms_ref[i]; - NCCL_CHECK(from_nccl_result(ncclReduceScatter( + NCCL_CHECK(ncclReduceScatter( inputs[i].data_ptr(), outputs[i].data_ptr(), count, data_type, to_nccl_red_op(op), - *(to_nccl_comm(&comm)), - stream))); + to_nccl_comm(comm), + stream)); } #else AT_ERROR("PyTorch built without NCCL support"); @@ -596,27 +603,111 @@ void all_gather( ncclComm_t comm = comms_ref[i]; #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) - NCCL_CHECK(from_nccl_result(ncclAllGather( + NCCL_CHECK(ncclAllGather( inputs[i].data_ptr(), outputs[i].data_ptr(), count, data_type, - *(to_nccl_comm(&comm)), - stream))); + to_nccl_comm(comm), + stream)); #else - NCCL_CHECK(from_nccl_result(ncclAllGather( + NCCL_CHECK(ncclAllGather( inputs[i].data_ptr(), count, data_type, outputs[i].data_ptr(), - *(to_nccl_comm(&comm)), - stream))); + to_nccl_comm(comm), + stream)); #endif } #else AT_ERROR("PyTorch built without NCCL support"); #endif } + +void all2all(at::Tensor& input, + at::Tensor& output, + int size, + ncclComm_t _comm, + at::cuda::CUDAStream& stream) { +#ifdef USE_NCCL +#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27 + using namespace torch::cuda::nccl::detail; + + int numranks; + auto type = to_nccl_data_type(input); + size_t count = input.numel() / size; + size_t rankdiff = input.nbytes() / size; + const auto* sendbuff = reinterpret_cast(input.data_ptr()); + auto* recvbuff = reinterpret_cast(output.data_ptr()); + auto comm = to_nccl_comm(_comm); + NCCL_CHECK(ncclCommCount(comm, &numranks)); + NCCL_CHECK(ncclGroupStart()); + for (int r = 0; r < numranks; r++) { + // NCCL uses 0 byte message for synchronization + // Avoid send/recv when message size is zero + if (count != 0) { + NCCL_CHECK(ncclSend(sendbuff + r * rankdiff, count, type, r, comm, stream)); + NCCL_CHECK(ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream)); + } + } + NCCL_CHECK(ncclGroupEnd()); +#else + AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0"); +#endif +#else + AT_ERROR("PyTorch built without NCCL support"); +#endif +} + +void send( + const at::Tensor& input, + ncclComm_t comm, + at::cuda::CUDAStream stream, + int dst) { +#ifdef USE_NCCL +#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ + (NCCL_MINOR >= 7) + using namespace torch::cuda::nccl::detail; + NCCL_CHECK(ncclSend( + input.data_ptr(), + input.numel(), + to_nccl_data_type(input), + dst, + to_nccl_comm(comm), + stream.stream())); +#else + AT_ERROR("Send is only supported for NCCL lib version >= 2.7.0"); +#endif +#else + AT_ERROR("PyTorch built without NCCL support"); +#endif +} + +void recv( + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream stream, + int src) { +#ifdef USE_NCCL +#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ + (NCCL_MINOR >= 7) + using namespace torch::cuda::nccl::detail; + NCCL_CHECK(ncclRecv( + output.data_ptr(), + output.numel(), + to_nccl_data_type(output), + src, + to_nccl_comm(comm), + stream.stream())); +#else + AT_ERROR("Recv is only supported for NCCL lib version >= 2.7.0"); +#endif +#else + AT_ERROR("PyTorch built without NCCL support"); +#endif +} + } // namespace nccl } // namespace cuda } // namespace torch diff --git a/torch/csrc/cuda/nccl.h b/torch/csrc/cuda/nccl.h index 3550cf70aa58b..4cbae2e0208a7 100644 --- a/torch/csrc/cuda/nccl.h +++ b/torch/csrc/cuda/nccl.h @@ -136,6 +136,24 @@ TORCH_CUDA_API void all_gather( const stream_list& streams = {}, const comm_list& user_comms = {}); +TORCH_CUDA_API void all2all( + at::Tensor& input, + at::Tensor& output, + int size, + ncclComm_t comm, + at::cuda::CUDAStream& stream); + +TORCH_CUDA_API void send( + const at::Tensor& input, + ncclComm_t comm, + at::cuda::CUDAStream stream, + int dst); + +TORCH_CUDA_API void recv( + at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream stream, + int src); } // namespace nccl } // namespace cuda } // namespace torch diff --git a/torch/csrc/cuda/python_nccl.cpp b/torch/csrc/cuda/python_nccl.cpp index 403bcb2b85dab..35dbeae3f3aab 100644 --- a/torch/csrc/cuda/python_nccl.cpp +++ b/torch/csrc/cuda/python_nccl.cpp @@ -199,7 +199,9 @@ PyObject* THCPModule_nccl_broadcast(PyObject* self, PyObject* args) { nullptr, "nccl_broadcast", 1, - "(sequence[Tensor] inputs, int root)"); + "(sequence[Tensor] inputs, int root" + " sequence[torch.cuda.Stream] streams," + " sequence[torch.cuda.nccl.Communicator] comms)"); return nullptr; } @@ -228,7 +230,9 @@ PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) { nullptr, "nccl_all_gather", 1, - "(sequence[Tensor] inputs, sequence[Tensor] outputs"); + "(sequence[Tensor] inputs, sequence[Tensor] outputs" + " sequence[torch.cuda.Stream] streams," + " sequence[torch.cuda.nccl.Communicator] comms)"); return nullptr; } @@ -258,7 +262,9 @@ PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) { nullptr, "nccl_reduce_scatter", 1, - "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op"); + "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op" + " sequence[torch.cuda.Stream] streams," + " sequence[torch.cuda.nccl.Communicator] comms)"); return nullptr; } diff --git a/torch/csrc/distributed/autograd/context/container.cpp b/torch/csrc/distributed/autograd/context/container.cpp index ee3939010e8a8..6948de958b843 100644 --- a/torch/csrc/distributed/autograd/context/container.cpp +++ b/torch/csrc/distributed/autograd/context/container.cpp @@ -245,14 +245,17 @@ void DistAutogradContainer::sendReleaseContextRpc( CleanupAutogradContextReq(context_id).toMessage(), options); + std::weak_ptr wp = cleanupFuture; cleanupFuture->addCallback( - [worker_id](const rpc::FutureMessage& cleanupFuture) { - if (cleanupFuture.hasError()) { + [worker_id, wp]() { + auto future = wp.lock(); + TORCH_INTERNAL_ASSERT(future); + if (future->hasError()) { std::string errorMsg = c10::str( "Could not release Dist Autograd Context on node ", worker_id, ": ", - cleanupFuture.error()->what()); + future->tryRetrieveErrorMessage()); LOG(ERROR) << errorMsg; return; } diff --git a/torch/csrc/distributed/autograd/context/context.cpp b/torch/csrc/distributed/autograd/context/context.cpp index 6527fc25b92b8..526ca053dd409 100644 --- a/torch/csrc/distributed/autograd/context/context.cpp +++ b/torch/csrc/distributed/autograd/context/context.cpp @@ -123,26 +123,27 @@ void DistAutogradContext::resetGraphTask() { } void DistAutogradContext::addOutstandingRpc( - const std::shared_ptr& futureMessage) { - futureMessage->addCallback([this](const rpc::FutureMessage& futureMessage) { - if (futureMessage.hasError()) { + const std::shared_ptr& jitFuture) { + std::weak_ptr wp = jitFuture; + jitFuture->addCallback([this, wp]() { + auto future = wp.lock(); + if (future->hasError()) { // If we have an error, let the local autograd engine know about it. std::unique_lock lock(lock_); if (graphTask_) { graphTask_->set_exception_without_signal(nullptr); lock.unlock(); if (!graphTask_->future_completed_.exchange(true)) { - graphTask_->future_result_->setErrorIfNeeded( - std::make_exception_ptr(*futureMessage.error())); + graphTask_->future_result_->setErrorIfNeeded(future->exception_ptr()); } } else { LOG(WARNING) << "Ignoring error since GraphTask is no longer valid: " - << (*futureMessage.error()).what(); + << future->tryRetrieveErrorMessage(); } } }); std::lock_guard guard(lock_); - outStandingRpcs_.push_back(futureMessage); + outStandingRpcs_.push_back(jitFuture); } void DistAutogradContext::clearOutstandingRpcs() { @@ -170,8 +171,10 @@ std::shared_ptr DistAutogradContext:: state->future->markCompleted(c10::IValue()); } else { for (auto& rpc : outStandingRpcs) { - rpc->addCallback([state](const rpc::FutureMessage& rpc) { - if (rpc.hasError()) { + std::weak_ptr wp = rpc; + rpc->addCallback([state, wp]() { + auto future = wp.lock(); + if (future->hasError()) { // If there's an error, we want to setError() on the future, // unless another error has already been sent - use a CAS to // guard. @@ -183,7 +186,7 @@ std::shared_ptr DistAutogradContext:: bool expectedAlreadySent = false; if (state->alreadySentError.compare_exchange_strong( expectedAlreadySent, true)) { - state->future->setError(std::make_exception_ptr(*rpc.error())); + state->future->setError(future->exception_ptr()); } return; } diff --git a/torch/csrc/distributed/autograd/context/context.h b/torch/csrc/distributed/autograd/context/context.h index 47d915bde0ace..b611040af448c 100644 --- a/torch/csrc/distributed/autograd/context/context.h +++ b/torch/csrc/distributed/autograd/context/context.h @@ -52,7 +52,7 @@ class TORCH_API DistAutogradContext { // Adds a future message recording an outstanding RPC. void addOutstandingRpc( - const std::shared_ptr& futureMessage); + const std::shared_ptr& jitFuture); // Returns all gradients. const c10::Dict getGradients() const; @@ -134,7 +134,7 @@ class TORCH_API DistAutogradContext { // List of futures for RPCs initiated by this node to propagate gradients to // other nodes. The distributed autograd engine on this node can return // successfully only if all these futures are done and are successful. - std::vector> outStandingRpcs_; + std::vector> outStandingRpcs_; // Lock to protect concurrent modification of the context. mutable std::mutex lock_; @@ -147,7 +147,7 @@ using ContextPtr = std::shared_ptr; // doesn't know the current context. It's just a util class. class TORCH_API ThreadLocalDistAutogradContext { public: - // Store 'new_context' to the thread local varaible maintained by this class. + // Store 'new_context' to the thread local variable maintained by this class. explicit ThreadLocalDistAutogradContext(ContextPtr&& new_context); ~ThreadLocalDistAutogradContext(); diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp index d642ef53101ae..20f0e46304e4e 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp +++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp @@ -279,7 +279,7 @@ void DistEngine::computeDependencies( // Create a dummy GraphRoot and run init_to_execute with it. GraphRoot dummyRoot(edges, {}); - graphTask->init_to_execute(dummyRoot, outputEdges); + graphTask->init_to_execute(dummyRoot, outputEdges, /*accumulate_grad=*/false); for (auto& mapEntry : graphTask->exec_info_) { auto& execInfo = mapEntry.second; if (!execInfo.captures_) { diff --git a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp index 99951f098a229..509c5c6cbd08f 100644 --- a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp +++ b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp @@ -47,11 +47,12 @@ variable_list RecvRpcBackward::apply(variable_list&& grads) { // Send the gradients over to the appropriate node. auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent(); - auto futureMessage = rpcAgent->send( - rpcAgent->getWorkerInfo(fromWorkerId_), std::move(gradCall).toMessage()); + auto jitFuture = rpcAgent->send( + rpcAgent->getWorkerInfo(fromWorkerId_), + std::move(gradCall).toMessage()); // Record the future in the context. - sharedContext->addOutstandingRpc(futureMessage); + sharedContext->addOutstandingRpc(jitFuture); // 'recv' function sends the gradients over the wire using RPC, it doesn't // need to return anything for any downstream autograd function. diff --git a/torch/csrc/distributed/autograd/init.cpp b/torch/csrc/distributed/autograd/init.cpp index 9ab16fb6a93c6..ad6dfa7d8f463 100644 --- a/torch/csrc/distributed/autograd/init.cpp +++ b/torch/csrc/distributed/autograd/init.cpp @@ -15,14 +15,22 @@ namespace { template using shared_ptr_class_ = py::class_>; -PyObject* dist_autograd_init(PyObject* /* unused */) { +PyObject* dist_autograd_init(PyObject* _unused, PyObject* noargs) { auto autograd_module = THPObjectPtr(PyImport_ImportModule("torch.distributed.autograd")); if (!autograd_module) { throw python_error(); } - auto module = py::handle(autograd_module).cast(); + auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C")); + if (!torch_C_module) { + throw python_error(); + } + + auto torch_C_m = py::handle(torch_C_module).cast(); + auto m = torch_C_m.def_submodule("_distributed_autograd", "distributed autograd bindings"); + + auto module = py::handle(m).cast(); auto distAutogradContext = shared_ptr_class_(module, "DistAutogradContext") @@ -196,7 +204,7 @@ Example:: static PyMethodDef methods[] = { // NOLINT {"_dist_autograd_init", - (PyCFunction)dist_autograd_init, + dist_autograd_init, METH_NOARGS, nullptr}, {nullptr, nullptr, 0, nullptr}}; diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.cpp b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.cpp index 3656a1b9dae4a..2336711d07e93 100644 --- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.cpp +++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.cpp @@ -13,7 +13,7 @@ constexpr auto kProfileEventsStartIdx = 3; RpcWithProfilingResp::RpcWithProfilingResp( rpc::MessageType messageType, rpc::Message&& wrappedMessage, - std::vector profiledEvents, + std::vector profiledEvents, rpc::ProfilingId profilingId) : messageType_(messageType), wrappedMessage_(std::move(wrappedMessage)), @@ -32,7 +32,7 @@ RpcWithProfilingResp::RpcWithProfilingResp( std::unique_ptr wrappedRpc, rpc::MessageType wrappedMessageType, std::vector tensors, - std::vector profiledEvents, + std::vector profiledEvents, rpc::ProfilingId profilingId) : messageType_(messageType), wrappedRpc_(std::move(wrappedRpc)), @@ -52,7 +52,7 @@ rpc::MessageType RpcWithProfilingResp::wrappedMessageType() const { return wrappedMessageType_; } -std::vector RpcWithProfilingResp:: +std::vector RpcWithProfilingResp:: getProfiledEvents() const { return profiledEvents_; } @@ -119,15 +119,15 @@ std::unique_ptr RpcWithProfilingResp::fromMessage( static_cast(tupleElements[0].toInt()); rpc::ProfilingId profilingId = rpc::ProfilingId::fromIValue(tupleElements[1]); int profiledEventsSize = tupleElements[2].toInt(); - std::vector remoteEvents; + std::vector remoteEvents; remoteEvents.reserve(profiledEventsSize); for (int i = kProfileEventsStartIdx; i < kProfileEventsStartIdx + profiledEventsSize; ++i) { TORCH_CHECK(i < tupleElements.size()); // Reconstruct remote event from the ivalues. - torch::autograd::profiler::Event fromIvalueEvent = - torch::autograd::profiler::Event::fromIValue(tupleElements[i]); + torch::autograd::profiler::LegacyEvent fromIvalueEvent = + torch::autograd::profiler::LegacyEvent::fromIValue(tupleElements[i]); remoteEvents.push_back(std::move(fromIvalueEvent)); } diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h index c4dc088017f04..ad7b54ea8b82f 100644 --- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h +++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h @@ -15,7 +15,7 @@ class TORCH_API RpcWithProfilingResp : public rpc::RpcCommandBase { RpcWithProfilingResp( rpc::MessageType messageType, rpc::Message&& wrappedMessage, - std::vector profiledEvents, + std::vector profiledEvents, rpc::ProfilingId profilingId); // For receving RPCs. Used in from message when converting a message received @@ -25,13 +25,13 @@ class TORCH_API RpcWithProfilingResp : public rpc::RpcCommandBase { std::unique_ptr wrappedRpc, rpc::MessageType wrappedMessageType, std::vector tensors, - std::vector profiledEvents, + std::vector profiledEvents, rpc::ProfilingId profilingId); rpc::Message toMessageImpl() && override; static std::unique_ptr fromMessage( const rpc::Message& message); // Retrieve remote Events - std::vector getProfiledEvents() const; + std::vector getProfiledEvents() const; // Retrieve the globally unique profiling ID corresponding to this command. const rpc::ProfilingId& getProfilingId() const; // Retrieve the original RPC which this ProfilingRPC wraps. @@ -51,7 +51,7 @@ class TORCH_API RpcWithProfilingResp : public rpc::RpcCommandBase { std::unique_ptr wrappedRpc_; rpc::MessageType wrappedMessageType_; std::vector tensors_; - const std::vector profiledEvents_; + const std::vector profiledEvents_; const rpc::ProfilingId profilingId_; }; } // namespace autograd diff --git a/torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.cpp b/torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.cpp new file mode 100644 index 0000000000000..00230eda638b6 --- /dev/null +++ b/torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.cpp @@ -0,0 +1,77 @@ +#include +#include +#include + +namespace torch { +namespace distributed { +namespace autograd { + +using rpc::Message; +using rpc::MessageType; + +RRefBackwardReq::RRefBackwardReq( + const rpc::RRefId& rrefId, + int64_t autogradContextId, + bool retainGraph) + : rrefId_(rrefId), + autogradContextId_(autogradContextId), + retainGraph_(retainGraph) {} + +Message RRefBackwardReq::toMessageImpl() && { + std::vector ivalues; + + // Add all the fields. + ivalues.emplace_back(rrefId_.toIValue()); + ivalues.emplace_back(autogradContextId_); + ivalues.emplace_back(retainGraph_); + + // Now pickle using JIT pickler. + std::vector tensorTable; + std::vector payload = + jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable); + + return Message( + std::move(payload), + std::move(tensorTable), + MessageType::RREF_BACKWARD_REQ); +} + +std::unique_ptr RRefBackwardReq::fromMessage( + const Message& message) { + // Unpickle the message and retrieve tupleElements. + auto payload = static_cast(message.payload().data()); + auto payload_size = message.payload().size(); + IValue tuple = jit::unpickle( + payload, + payload_size, + *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(), + &message.tensors()); + std::vector tupleElements = tuple.toTuple()->elements(); + + // Build RRefBackwardReq. + TORCH_INTERNAL_ASSERT(tupleElements.size() == 3); + + // Retrieve all fields. + bool retainGraph = tupleElements[2].toBool(); + int64_t autogradContextId = tupleElements[1].toInt(); + rpc::RRefId rrefId = rpc::RRefId::fromIValue(tupleElements[0]); + + return std::make_unique( + rrefId, autogradContextId, retainGraph); +} + +const rpc::RRefId& RRefBackwardReq::getRRefId() const { + return rrefId_; +} + +int64_t RRefBackwardReq::getAutogradContextId() const { + return autogradContextId_; +} + +bool RRefBackwardReq::retainGraph() const { + return retainGraph_; +} + +} // namespace autograd +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h b/torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h new file mode 100644 index 0000000000000..8e95c9b6f99c3 --- /dev/null +++ b/torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace distributed { +namespace autograd { + +// Internal system RPC to invoke distributed backward pass on remote nodes when +// 'rref.backward()' is invoked. +class TORCH_API RRefBackwardReq : public rpc::RpcCommandBase { + public: + RRefBackwardReq( + const rpc::RRefId& rrefId, + int64_t autogradContextId, + bool retainGraph = false); + + const rpc::RRefId& getRRefId() const; + + int64_t getAutogradContextId() const; + + bool retainGraph() const; + + // Serialization and deserialization methods. + rpc::Message toMessageImpl() && override; + static std::unique_ptr fromMessage( + const rpc::Message& message); + + private: + const rpc::RRefId rrefId_; + const int64_t autogradContextId_; + const bool retainGraph_; +}; + +} // namespace autograd +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.cpp b/torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.cpp new file mode 100644 index 0000000000000..2b1e9e9b6e9fc --- /dev/null +++ b/torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.cpp @@ -0,0 +1,19 @@ +#include + +namespace torch { +namespace distributed { +namespace autograd { + +rpc::Message RRefBackwardResp::toMessageImpl() && { + return rpc::Message({}, {}, rpc::MessageType::RREF_BACKWARD_RESP); +} + +std::unique_ptr RRefBackwardResp::fromMessage( + const rpc::Message& message) { + TORCH_INTERNAL_ASSERT(message.type() == rpc::MessageType::RREF_BACKWARD_RESP); + return std::unique_ptr(); +} + +} // namespace autograd +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.h b/torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.h new file mode 100644 index 0000000000000..5e7ce7cf36fd0 --- /dev/null +++ b/torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +namespace torch { +namespace distributed { +namespace autograd { + +// Response for the RRefBackwardReq. +class TORCH_API RRefBackwardResp : public rpc::RpcCommandBase { + public: + RRefBackwardResp() = default; + rpc::Message toMessageImpl() && override; + static std::unique_ptr fromMessage( + const rpc::Message& message); +}; + +} // namespace autograd +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/autograd/utils.cpp b/torch/csrc/distributed/autograd/utils.cpp index 464d8248d8a4b..08bb99471686b 100644 --- a/torch/csrc/distributed/autograd/utils.cpp +++ b/torch/csrc/distributed/autograd/utils.cpp @@ -16,7 +16,7 @@ namespace autograd { using torch::distributed::autograd::AutogradMetadata; using torch::distributed::autograd::RpcWithAutograd; -using torch::distributed::rpc::FutureMessage; +using torch::distributed::rpc::JitFuture; using torch::distributed::rpc::Message; using torch::distributed::rpc::MessageType; using torch::distributed::rpc::RpcAgent; @@ -138,7 +138,7 @@ Message getMessageWithAutograd( return std::move(*rpcWithAutograd).toMessage(); } -std::shared_ptr sendMessageWithAutograd( +std::shared_ptr sendMessageWithAutograd( RpcAgent& agent, const WorkerInfo& dst, torch::distributed::rpc::Message&& wrappedRpcMsg, @@ -151,7 +151,7 @@ std::shared_ptr sendMessageWithAutograd( MessageType::FORWARD_AUTOGRAD_REQ, forceGradRecording); - std::shared_ptr fut; + std::shared_ptr fut; // If profiler is enabled, wrap this message with profiling metadata that will // tell the remote end to process this request with the profiler enabled. if (!forceDisableProfiling && torch::autograd::profiler::profilerEnabled()) { diff --git a/torch/csrc/distributed/autograd/utils.h b/torch/csrc/distributed/autograd/utils.h index 2a0a066e1a952..07ba45ed60d7c 100644 --- a/torch/csrc/distributed/autograd/utils.h +++ b/torch/csrc/distributed/autograd/utils.h @@ -45,7 +45,7 @@ TORCH_API rpc::Message getMessageWithAutograd( bool forceGradRecording = false); // Send message after autograd checking -TORCH_API std::shared_ptr +TORCH_API std::shared_ptr sendMessageWithAutograd( rpc::RpcAgent& agent, const rpc::WorkerInfo& dst, diff --git a/torch/csrc/distributed/c10d/comm.h b/torch/csrc/distributed/c10d/comm.h deleted file mode 100644 index e2b501f08affd..0000000000000 --- a/torch/csrc/distributed/c10d/comm.h +++ /dev/null @@ -1,92 +0,0 @@ -#pragma once - -#include - -#include -#include -#include - -namespace c10d { - -// Broadcast many tensors to all processes in the process group. -void broadcast_coalesced( - std::shared_ptr process_group, - at::TensorList tensors, - size_t buffer_size, - int rank = 0); - -// This class passes bucket contents tensor (for multiple replicas) to -// DDP communication hook. -// Optionally in the future this can be enhanced with parameter to bucket -// mappings as well. -class GradBucket { - public: - explicit GradBucket(const std::vector& tensors) - : tensors_(tensors) {} - // Each tensor in the list that getTensors returns refers to the replica on - // each device. There will be multiple replicas only in the case of single - // process multiple device mode. In the single process single device mode, - // this list would consist of only a single tensor. - const std::vector& getTensors() const { - return tensors_; - } - - private: - std::vector tensors_; -}; - -// DDP's c10d reducer allows communication hooks defined as a sub class -// of CommHookInterface. CommHookInterface is an abstract class and can -// be used to implement both Python and CPP hooks. -struct TORCH_API CommHookInterface { - public: - virtual ~CommHookInterface() {} - - // runHook takes a GradBucket type bucket and passes the tensors of - // this grad bucket to hook's callback. This function is called once - // the bucket is ready. The hook can perform whatever processing is - // needed and return a Future that will hold the new value of the grad - // bucket's tensors once ready. - virtual c10::intrusive_ptr runHook( - const GradBucket& bucket) = 0; - - // Once the grad bucket of Future is ready, c10d reducer will call this - // function to get the resulting tensors of the grad bucket. Then c10d - // reducer will use these tensors and copy grads to the grads of individual - // parameters. - virtual std::vector processFuture(c10::IValue future_value) = 0; -}; - -// PythonCommHook enables registering a python hook to c10d reducer and is a -// sub class of CommHookInterface. -class TORCH_API PythonCommHook : public CommHookInterface { - public: - // The constructor takes a state and a callable hook. Inputs are Python - // objects. The state is passed to the hook in runHook function can be used to - // maintain and update any state information that users would like to maintain - // as part of the training process. The hook can perform whatever processing - // user specified and return a Future indicating completion of any async work. - PythonCommHook(py::object state, py::object hook); - - ~PythonCommHook() override { - py::gil_scoped_acquire ag; - state_.dec_ref(); - hook_.dec_ref(); - // explicitly setting PyObject* state_ and hook_ to nullptr to prevent - // py::object's dtor to decref on the PyObject again. - // See Note [Destructing py::object] in python_ivalue.h - state_.ptr() = nullptr; - hook_.ptr() = nullptr; - } - - c10::intrusive_ptr runHook( - const GradBucket& bucket) override; - - std::vector processFuture(c10::IValue future_value) override; - - private: - py::object state_; - py::object hook_; -}; - -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 165d6a1c8603b..fc798e537b2f2 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1,7 +1,12 @@ #include +#include #include +#include +#ifndef _WIN32 #include +#include +#endif #include #ifdef USE_C10D_GLOO @@ -17,17 +22,20 @@ #endif #include -#include -#include +#include #include +#include +#include +#include #include -#include -#include +#include #include #include #include +#include + namespace torch { namespace distributed { namespace c10d { @@ -51,6 +59,13 @@ std::vector split(char separator, const std::string& string) { template using shared_ptr_class_ = py::class_>; +constexpr auto kDeprecationWarning = + "{} API is being deprecated, please ping " + "https://github.com/pytorch/pytorch/issues/46291 " + "if you see this warning"; +template +using intrusive_ptr_class_ = py::class_>; + // PythonStore is a pybind11 trampoline class to allow a Python // class to inherit from c10d.Store and implement its interface. class PythonStore : public ::c10d::Store { @@ -92,6 +107,14 @@ class PythonStore : public ::c10d::Store { PYBIND11_OVERLOAD_PURE(int64_t, ::c10d::Store, add, key, value); } + int64_t getNumKeys() override { + PYBIND11_OVERLOAD_PURE(int64_t, ::c10d::Store, getNumKeys); + } + + bool deleteKey(const std::string& key) override { + PYBIND11_OVERLOAD_PURE(bool, ::c10d::Store, deleteKey, key); + } + bool check(const std::vector& keys) override { PYBIND11_OVERLOAD_PURE(bool, ::c10d::Store, check, keys); } @@ -107,38 +130,75 @@ class PythonStore : public ::c10d::Store { } }; -// This method is called from DDP's Python API. Its inputs are -// a c10d reducer object, state, and callable comm_hook. State and -// comm_hook inputs are Python objects and this function creates a -// c10d PythonCommHook object using these inputs. It later calls -// register_comm_hook function of the reducer input to register that -// PythonCommHook object. +// Called from DDP's Python API to create a c10d Python comm hook object. +// The input state and callable comm_hook are Python objects. It later calls +// register_comm_hook function of the reducer input to register the hook. void _register_comm_hook( ::c10d::Reducer& reducer, py::object state, py::object comm_hook) { reducer.register_comm_hook(std::make_unique<::c10d::PythonCommHook>( std::move(state), std::move(comm_hook))); -}; +} + +// Called from DDP's Python API to create a c10d C++ comm hook. +// The input is an enum hook type. It later calls register_builtin_comm_hook +// function of the reducer input to set the hook type. +void _register_builtin_comm_hook( + ::c10d::Reducer& reducer, + ::c10d::BuiltinCommHookType comm_hook_type) { + reducer.register_builtin_comm_hook(comm_hook_type); +} -PyObject* c10d_init(PyObject* _unused) { +PyObject* c10d_init(PyObject* _unused, PyObject* noargs) { C10_LOG_API_USAGE_ONCE("c10d.python.import"); auto c10d_module = THPObjectPtr(PyImport_ImportModule("torch.distributed")); if (!c10d_module) { throw python_error(); } - auto module = py::handle(c10d_module).cast(); + auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C")); + if (!torch_C_module) { + throw python_error(); + } - module.def( - "_register_comm_hook", - &_register_comm_hook, - py::arg("ddp_model"), - py::arg("state"), - py::arg("comm_hook")); + auto torch_C_m = py::handle(torch_C_module).cast(); + auto m = + torch_C_m.def_submodule("_distributed_c10d", "distributed c10d bindings"); + + auto module = py::handle(m).cast(); + + module + .def( + "_register_comm_hook", + &_register_comm_hook, + py::arg("reducer"), + py::arg("state"), + py::arg("comm_hook"), + py::call_guard()) + .def( + "_register_builtin_comm_hook", + &_register_builtin_comm_hook, + py::arg("reducer"), + py::arg("comm_hook_type")); shared_ptr_class_<::c10d::GradBucket>(module, "_GradBucket") - .def(py::init&>(), py::arg("tensors")) + .def( + py::init< + size_t, + const std::vector&, + const std::vector&, + const std::vector&, + const std::vector&>(), + py::arg("index"), + py::arg("tensors"), + py::arg("offsets"), + py::arg("lengths"), + py::arg("sizes_list")) + .def( + "get_index", + &::c10d::GradBucket::getIndex, + py::call_guard()) .def( "get_tensors", &::c10d::GradBucket::getTensors, @@ -149,14 +209,31 @@ PyObject* c10d_init(PyObject* _unused) { replicas only in the case of single process multiple device mode. In the single process single device mode, this list would consist of only a single tensor. - )"); + )") + .def( + "get_offsets", + &::c10d::GradBucket::getOffsets, + py::call_guard()) + .def( + "get_lengths", + &::c10d::GradBucket::getLengths, + py::call_guard()) + .def( + "get_sizes_list", + &::c10d::GradBucket::getSizesVec, + py::call_guard()); + + py::enum_<::c10d::BuiltinCommHookType>(module, "BuiltinCommHookType", R"( +An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``.)") + .value("ALLREDUCE", ::c10d::BuiltinCommHookType::ALLREDUCE) + .value("FP16_COMPRESS", ::c10d::BuiltinCommHookType::FP16_COMPRESS); shared_ptr_class_<::c10d::Reducer>(module, "Reducer") .def( py::init< std::vector>, std::vector>, - std::shared_ptr<::c10d::ProcessGroup>, + c10::intrusive_ptr<::c10d::ProcessGroup>, std::vector>, int64_t, bool, @@ -210,6 +287,8 @@ An enum-like class for available reduction operations: ``SUM``, ``PRODUCT``, Note that ``BAND``, ``BOR``, and ``BXOR`` reductions are not available when using the ``NCCL`` backend. +Additionally, ``MAX``, ``MIN`` and ``PRODUCT`` are not supported for complex tensors. + The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``. They are used in specifying strategies for reduction collectives, e.g., :func:`reduce`, :func:`all_reduce_multigpu`, etc.)") @@ -266,6 +345,7 @@ They are used in specifying strategies for reduction collectives, e.g., py::class_<::c10d::BarrierOptions>(module, "BarrierOptions") .def(py::init<>()) + .def_readwrite("device_ids", &::c10d::BarrierOptions::device_ids) .def_readwrite("timeout", &::c10d::BarrierOptions::timeout); py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions") @@ -273,8 +353,14 @@ They are used in specifying strategies for reduction collectives, e.g., .def_readwrite("timeout", &::c10d::AllToAllOptions::timeout); auto store = - py::class_<::c10d::Store, std::shared_ptr<::c10d::Store>, PythonStore>( - module, "Store") + py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>( + module, + "Store", + R"( +Base class for all store implementations, such as the 3 provided by PyTorch +distributed: (:class:`~torch.distributed.TCPStore`, :class:`~torch.distributed.FileStore`, +and :class:`~torch.distributed.HashStore`). +)") // Default constructor. .def(py::init<>()) // Convert from std::string to std::vector. @@ -286,7 +372,24 @@ They are used in specifying strategies for reduction collectives, e.g., std::vector value_(value.begin(), value.end()); store.set(key, value_); }, - py::call_guard()) + py::call_guard(), + R"( +Inserts the key-value pair into the store based on the supplied ``key`` and +``value``. If ``key`` already exists in the store, it will overwrite the old +value with the new supplied ``value``. + +Arguments: + key (str): The key to be added to the store. + value (str): The value associated with ``key`` to be added to the store. + +Example:: + >>> import torch.distributed as dist + >>> from datetime import timedelta + >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) + >>> store.set("first_key", "first_value") + >>> # Should return "first_value" + >>> store.get("first_key") +)") // Convert from std::vector to py::bytes. // The returned value is not guaranteed to be valid UTF-8. .def( @@ -296,21 +399,148 @@ They are used in specifying strategies for reduction collectives, e.g., return py::bytes( reinterpret_cast(value.data()), value.size()); }, - py::call_guard()) + py::call_guard(), + R"( +Retrieves the value associated with the given ``key`` in the store. If ``key`` is not +present in the store, the function will wait for ``timeout``, which is defined +when initializing the store, before throwing an exception. + +Arguments: + key (str): The function will return the value associated with this key. + +Returns: + Value associated with ``key`` if ``key`` is in the store. + +Example:: + >>> import torch.distributed as dist + >>> from datetime import timedelta + >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) + >>> store.set("first_key", "first_value") + >>> # Should return "first_value" + >>> store.get("first_key") +)") .def( "add", &::c10d::Store::add, - py::call_guard()) + py::call_guard(), + R"( +The first call to add for a given ``key`` creates a counter associated +with ``key`` in the store, initialized to ``amount``. Subsequent calls to add +with the same ``key`` increment the counter by the specified ``amount``. +Calling :meth:`~torch.distributed.store.add` with a key that has already +been set in the store by :meth:`~torch.distributed.store.set` will result +in an exception. + +Arguments: + key (str): The key in the store whose counter will be incremented. + amount (int): The quantity by which the counter will be incremented. + +Example:: + >>> import torch.distributed as dist + >>> from datetime import timedelta + >>> # Using TCPStore as an example, other store types can also be used + >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) + >>> store.add("first_key", 1) + >>> store.add("first_key", 6) + >>> # Should return 7 + >>> store.get("first_key") +)") + .def( + "delete_key", + &::c10d::Store::deleteKey, + py::call_guard(), + R"( +Deletes the key-value pair associated with ``key`` from the store. Returns +`true` if the key was successfully deleted, and `false` if it was not. + +.. warning:: + The ``delete_key`` API is only supported by the :class:`~torch.distributed.TCPStore` and :class:`~torch.distributed.HashStore`. Using this API + with the :class:`~torch.distributed.FileStore` will result in an exception. + +Arguments: + key (str): The key to be deleted from the store + +Returns: + `True` if ``key`` was deleted, otherwise `False`. + +Example:: + >>> import torch.distributed as dist + >>> from datetime import timedelta + >>> # Using TCPStore as an example, HashStore can also be used + >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) + >>> store.set("first_key") + >>> # This should return true + >>> store.delete_key("first_key") + >>> # This should return false + >>> store.delete_key("bad_key") +)") + .def( + "num_keys", + &::c10d::Store::getNumKeys, + py::call_guard(), + R"( +Returns the number of keys set in the store. Note that this number will typically +be one greater than the number of keys added by :meth:`~torch.distributed.store.set` +and :meth:`~torch.distributed.store.add` since one key is used to coordinate all +the workers using the store. + +.. warning:: + When used with the :class:`~torch.distributed.TCPStore`, ``num_keys`` returns the number of keys written to the underlying file. If the store is destructed and another store is created with the same file, the original keys will be retained. + +Returns: + The number of keys present in the store. + +Example:: + >>> import torch.distributed as dist + >>> from datetime import timedelta + >>> # Using TCPStore as an example, other store types can also be used + >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) + >>> store.set("first_key", "first_value") + >>> # This should return 2 + >>> store.num_keys() +)") .def( "set_timeout", &::c10d::Store::setTimeout, - py::call_guard()) + py::call_guard(), + R"( +Sets the store's default timeout. This timeout is used during initialization and in +:meth:`~torch.distributed.store.wait` and :meth:`~torch.distributed.store.get`. + +Arguments: + timeout (timedelta): timeout to be set in the store. + +Example:: + >>> import torch.distributed as dist + >>> from datetime import timedelta + >>> # Using TCPStore as an example, other store types can also be used + >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) + >>> store.set_timeout(timedelta(seconds=10)) + >>> # This will throw an exception after 10 seconds + >>> store.wait(["bad_key"]) +)") .def( "wait", [](::c10d::Store& store, const std::vector& keys) { store.wait(keys); }, - py::call_guard()) + py::call_guard(), + R"( +Waits for each key in ``keys`` to be added to the store. If not all keys are +set before the ``timeout`` (set during store initialization), then ``wait`` +will throw an exception. + +Arguments: + keys (list): List of keys on which to wait until they are set in the store. + +Example:: + >>> import torch.distributed as dist + >>> from datetime import timedelta + >>> # Using TCPStore as an example, other store types can also be used + >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) + >>> # This will throw an exception after 30 seconds + >>> store.wait(["bad_key"]) +)") .def( "wait", [](::c10d::Store& store, @@ -318,15 +548,93 @@ They are used in specifying strategies for reduction collectives, e.g., const std::chrono::milliseconds& timeout) { store.wait(keys, timeout); }, - py::call_guard()); - - shared_ptr_class_<::c10d::FileStore>(module, "FileStore", store) + py::call_guard(), + R"( +Waits for each key in ``keys`` to be added to the store, and throws an exception +if the keys have not been set by the supplied ``timeout``. + +Arguments: + keys (list): List of keys on which to wait until they are set in the store. + timeout (timedelta): Time to wait for the keys to be added before throwing an exception. + +Example:: + >>> import torch.distributed as dist + >>> from datetime import timedelta + >>> # Using TCPStore as an example, other store types can also be used + >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) + >>> # This will throw an exception after 10 seconds + >>> store.wait(["bad_key"], timedelta(seconds=10)) +)"); + + intrusive_ptr_class_<::c10d::FileStore>( + module, + "FileStore", + store, + R"( +A store implementation that uses a file to store the underlying key-value pairs. + +Arguments: + file_name (str): path of the file in which to store the key-value pairs + world_size (int): The total number of processes using the store + +Example:: + >>> import torch.distributed as dist + >>> store1 = dist.FileStore("/tmp/filestore", 2) + >>> store2 = dist.FileStore("/tmp/filestore", 2) + >>> # Use any of the store methods from either the client or server after initialization + >>> store1.set("first_key", "first_value") + >>> store2.get("first_key") + + )") .def(py::init()); - shared_ptr_class_<::c10d::HashStore>(module, "HashStore", store) +#ifndef _WIN32 + intrusive_ptr_class_<::c10d::HashStore>( + module, + "HashStore", + store, + R"( +A thread-safe store implementation based on an underlying hashmap. This store can be used +within the same process (for example, by other threads), but cannot be used across processes. + +Example:: + >>> import torch.distributed as dist + >>> store = dist.HashStore() + >>> # store can be used from other threads + >>> # Use any of the store methods after initialization + >>> store.set("first_key", "first_value") + )") .def(py::init<>()); +#endif - shared_ptr_class_<::c10d::TCPStore>(module, "TCPStore", store) + intrusive_ptr_class_<::c10d::TCPStore>( + module, + "TCPStore", + store, + R"( +A TCP-based distributed key-value store implementation. The server store holds +the data, while the client stores can connect to the server store over TCP and +perform actions such as :meth:`~torch.distributed.store.set` to insert a key-value +pair, :meth:`~torch.distributed.store.get` to retrieve a key-value pair, etc. + +Arguments: + host_name (str): The hostname or IP Address the server store should run on. + port (int): The port on which the server store should listen for incoming requests. + world_size (int): The total number of store users (number of clients + 1 for the server). + is_master (bool): True when initializing the server store, False for client stores. + timeout (timedelta): Timeout used by the store during initialization and for methods such as :meth:`~torch.distributed.store.get` and :meth:`~torch.distributed.store.wait`. + +Example:: + >>> import torch.distributed as dist + >>> from datetime import timedelta + >>> # Run on process 1 (server) + >>> server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30)) + >>> # Run on process 2 (client) + >>> client_store = dist.TCPStore("127.0.0.1", 1234, 2, False) + >>> # Use any of the store methods from either the client or server after initialization + >>> server_store.set("first_key", "first_value") + >>> client_store.get("first_key") + )") .def( py::init< const std::string&, @@ -337,15 +645,29 @@ They are used in specifying strategies for reduction collectives, e.g., py::arg("host_name"), py::arg("port"), py::arg("world_size"), - py::arg("is_master"), + // using noconvert() requires this argument to be True or False + // prevents accidental implicit conversion to bool + py::arg("is_master").noconvert(), py::arg("timeout") = std::chrono::milliseconds(::c10d::Store::kDefaultTimeout)); - shared_ptr_class_<::c10d::PrefixStore>(module, "PrefixStore", store) - .def(py::init>()); + intrusive_ptr_class_<::c10d::PrefixStore>( + module, + "PrefixStore", + store, + R"( +A wrapper around any of the 3 key-value stores (:class:`~torch.distributed.TCPStore`, +:class:`~torch.distributed.FileStore`, and :class:`~torch.distributed.HashStore`) +that adds a prefix to each key inserted to the store. + +Arguments: + prefix (str): The prefix string that is prepended to each key before being inserted into the store. + store (torch.distributed.store): A store object that forms the underlying key-value store. + )") + .def(py::init>()); auto processGroup = - shared_ptr_class_<::c10d::ProcessGroup>(module, "ProcessGroup") + intrusive_ptr_class_<::c10d::ProcessGroup>(module, "ProcessGroup") .def("rank", &::c10d::ProcessGroup::getRank) .def("size", &::c10d::ProcessGroup::getSize) @@ -607,22 +929,24 @@ They are used in specifying strategies for reduction collectives, e.g., py::arg("opts") = ::c10d::BarrierOptions(), py::call_guard()); +#ifndef _WIN32 module.def( "_round_robin_process_groups", - [](std::vector> processGroups) - -> std::shared_ptr<::c10d::ProcessGroup> { + [](std::vector> processGroups) + -> c10::intrusive_ptr<::c10d::ProcessGroup> { if (processGroups.size() == 0) { throw std::invalid_argument("Specify at least 1 process group"); } const auto& first = processGroups.front(); - return std::make_shared<::c10d::ProcessGroupRoundRobin>( + return c10::make_intrusive<::c10d::ProcessGroupRoundRobin>( first->getRank(), first->getSize(), std::move(processGroups)); }, py::arg("process_groups"), py::call_guard()); +#endif #ifdef USE_C10D_GLOO - auto processGroupGloo = shared_ptr_class_<::c10d::ProcessGroupGloo>( + auto processGroupGloo = intrusive_ptr_class_<::c10d::ProcessGroupGloo>( module, "ProcessGroupGloo", processGroup); shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device"); @@ -651,23 +975,24 @@ They are used in specifying strategies for reduction collectives, e.g., py::arg("interface") = ""); processGroupGloo - .def(py::init< - const std::shared_ptr<::c10d::Store>&, - int, - int, - ::c10d::ProcessGroupGloo::Options>(), - py::call_guard()) .def( - py::init([](const std::shared_ptr<::c10d::Store>& store, + py::init< + const c10::intrusive_ptr<::c10d::Store>&, + int, + int, + ::c10d::ProcessGroupGloo::Options>(), + py::call_guard()) + .def( + py::init([](const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, std::chrono::milliseconds timeout) { ::c10d::ProcessGroupGloo::Options options; // Use interfaces listed in "GLOO_SOCKET_IFNAME", if set. - char* ifnameEnv = getenv(GLOO_SOCKET_IFNAME_ENV); + char* ifnameEnv = getenv(::c10d::GLOO_SOCKET_IFNAME_ENV); if (ifnameEnv) { - for (const auto& iface : split(',', ifnameEnv)) { + for (const auto& iface : ::c10d::split(',', ifnameEnv)) { options.devices.push_back( ::c10d::ProcessGroupGloo::createDeviceForInterface(iface)); } @@ -681,7 +1006,7 @@ They are used in specifying strategies for reduction collectives, e.g., options.timeout = timeout; options.threads = options.devices.size() * 2; - return std::make_shared<::c10d::ProcessGroupGloo>( + return c10::make_intrusive<::c10d::ProcessGroupGloo>( store, rank, size, options); }), py::arg("store"), @@ -692,64 +1017,99 @@ They are used in specifying strategies for reduction collectives, e.g., #endif #ifdef USE_C10D_NCCL - auto processGroupNCCL = shared_ptr_class_<::c10d::ProcessGroupNCCL>( - module, "ProcessGroupNCCL", processGroup) - .def(py::init< - const std::shared_ptr<::c10d::Store>&, - int, - int, - ::c10d::ProcessGroupNCCL::Options>(), - py::call_guard()) - .def( - py::init([](const std::shared_ptr<::c10d::Store>& store, - int rank, - int size, - const std::chrono::milliseconds& timeout){ - ::c10d::ProcessGroupNCCL::Options options; - options.isHighPriorityStream = false; - options.opTimeout = timeout; - return std::make_shared<::c10d::ProcessGroupNCCL>( - store, rank, size, options); - }), - py::arg("store"), - py::arg("rank"), - py::arg("size"), - py::arg("timeout") = std::chrono::milliseconds( - ::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis), - py::call_guard()); + auto processGroupNCCL = + intrusive_ptr_class_<::c10d::ProcessGroupNCCL>( + module, "ProcessGroupNCCL", processGroup) + .def( + py::init< + const c10::intrusive_ptr<::c10d::Store>&, + int, + int, + c10::intrusive_ptr<::c10d::ProcessGroupNCCL::Options>>(), + py::call_guard()) + .def( + py::init([](const c10::intrusive_ptr<::c10d::Store>& store, + int rank, + int size, + const std::chrono::milliseconds& timeout) { + auto options = ::c10d::ProcessGroupNCCL::Options::create(); + options->isHighPriorityStream = false; + options->opTimeout = timeout; + return c10::make_intrusive<::c10d::ProcessGroupNCCL>( + store, rank, size, options); + }), + py::arg("store"), + py::arg("rank"), + py::arg("size"), + py::arg("timeout") = std::chrono::milliseconds( + ::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis), + py::call_guard()); - py::class_<::c10d::ProcessGroupNCCL::Options>(processGroupNCCL, "Options") + intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>( + processGroupNCCL, "Options") .def(py::init<>()) - .def_readwrite("is_high_priority", &::c10d::ProcessGroupNCCL::Options::isHighPriorityStream) - .def_readwrite("op_timeout", &::c10d::ProcessGroupNCCL::Options::opTimeout); + .def_readwrite( + "is_high_priority", + &::c10d::ProcessGroupNCCL::Options::isHighPriorityStream) + .def_readwrite( + "op_timeout", &::c10d::ProcessGroupNCCL::Options::opTimeout); + processGroupNCCL.def_static( + "_group_start", []() { ::c10d::ProcessGroupNCCL::groupStart(); }); + processGroupNCCL.def_static( + "_group_end", []() { ::c10d::ProcessGroupNCCL::groupEnd(); }); #endif #ifdef USE_C10D_MPI - auto processGroupMPI = shared_ptr_class_<::c10d::ProcessGroupMPI>( + auto processGroupMPI = intrusive_ptr_class_<::c10d::ProcessGroupMPI>( module, "ProcessGroupMPI", processGroup); // Define static create function instead of a constructor, because // this function may return null. This happens if this process is not // part of a sub group that is to be created. processGroupMPI.def_static( - "create", - [](std::vector ranks) { - return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks); - }, - py::call_guard()); + "create", + [](std::vector ranks) { + return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks); + }, + py::call_guard()); #endif - shared_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work") + intrusive_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work") .def("is_completed", &::c10d::ProcessGroup::Work::isCompleted) - .def("is_success", &::c10d::ProcessGroup::Work::isSuccess) - .def("exception", &::c10d::ProcessGroup::Work::exception) - .def("source_rank", &::c10d::ProcessGroup::Work::sourceRank) + .def( + "is_success", + [](::c10d::ProcessGroup::Work& work) -> bool { + TORCH_WARN_ONCE(fmt::format( + kDeprecationWarning, "ProcessGroup::Work::is_success")); + return work.isSuccess(); + }) + .def( + "exception", + [](::c10d::ProcessGroup::Work& work) -> std::exception_ptr { + TORCH_WARN_ONCE(fmt::format( + kDeprecationWarning, "ProcessGroup::Work::exception")); + return work.exception(); + }) + .def( + "source_rank", + [](::c10d::ProcessGroup::Work& work) -> int { + TORCH_WARN_ONCE(fmt::format( + kDeprecationWarning, "ProcessGroup::Work::source_rank")); + return work.sourceRank(); + }) + .def("_source_rank", &::c10d::ProcessGroup::Work::sourceRank) .def( "result", [](::c10d::ProcessGroup::Work& work) -> std::vector { return work.result(); }) - .def("synchronize", &::c10d::ProcessGroup::Work::synchronize) + .def( + "synchronize", + [](::c10d::ProcessGroup::Work& work) -> void { + TORCH_WARN_ONCE(fmt::format( + kDeprecationWarning, "ProcessGroup::Work::synchronize")); + work.synchronize(); + }) .def( "wait", &::c10d::ProcessGroup::Work::wait, @@ -777,27 +1137,26 @@ They are used in specifying strategies for reduction collectives, e.g., >>> work = process_group.allreduce(tensors) >>> return work.get_future() - >>> ddp_model._register_comm_hook(state = None, hook = allreduce) + >>> ddp_model._egister_comm_hook(state = None, hook = allreduce) .. warning :: - ``get_future`` API supports only NCCL backend and single-process single-device mode. + ``get_future`` API supports only NCCL backend. The ``torch._C.Future`` object returned by this API can be used in - ``DistributedDataParallel._register_comm_hook``, but it is subject to some subtle - differences compared to ``torch.futures.Future`` due to compromises made for performance - reasons. + ``DistributedDataParallel.register_comm_hook``, and adds some CUDA-specific + features on top of ``torch.futures.Future``. In the example above, ``allreduce`` work will be done on GPU using NCCL backend, ``fut.wait()`` will return after synchronizing the appropriate NCCL streams - with PyTorch's default device streams to ensure we can have asynchronous CUDA + with PyTorch's current device streams to ensure we can have asynchronous CUDA execution and it does not wait for the entire operation to complete on GPU. Note that - ``FutureNCCL`` does not support ``NCCL_BLOCKING_WAIT`` flag or NCCL's ``barrier()``. + ``CUDAFuture`` does not support ``NCCL_BLOCKING_WAIT`` flag or NCCL's ``barrier()``. In addition, if a callback function was added by ``fut.then()``, it will wait until ``WorkNCCL``'s NCCL streams synchronize with ``ProcessGroupNCCL``'s dedicated callback stream and invoke the callback inline after running the callback on the callback stream. - ``fut.then()`` will return another ``FutureNCCL`` that holds the return value of the + ``fut.then()`` will return another ``CUDAFuture`` that holds the return value of the callback and a ``CUDAEvent`` that recorded the callback stream. - Note that ``fut.done()`` returns if the enire operation is completed on the GPU. + Note that ``fut.done()`` returns only whether the operation has been enqueued on the GPU. )"); module.def( @@ -814,7 +1173,7 @@ They are used in specifying strategies for reduction collectives, e.g., // Define a lambda such that the pybind11 prototype can take a std::vector // for the tensor list argument, but still pass it to the underlying // function as a c10::ArrayRef. - [](std::shared_ptr<::c10d::ProcessGroup> process_group, + [](c10::intrusive_ptr<::c10d::ProcessGroup> process_group, std::vector tensors, // NOLINT size_t buffer_size, int rank) { @@ -835,7 +1194,7 @@ They are used in specifying strategies for reduction collectives, e.g., // Python side of the world. Calling Python functions on a Python object // completely bypasses pybind11. We need to test that the overloaded // functions call into Python and behave like we expect. - [](std::shared_ptr<::c10d::Store> store) { + [](c10::intrusive_ptr<::c10d::Store> store) { auto add = [&store](const std::string& key, int64_t value) { store->add(key, value); }; @@ -881,15 +1240,440 @@ They are used in specifying strategies for reduction collectives, e.g., py::call_guard()); module.attr("_DEFAULT_FIRST_BUCKET_BYTES") = ::c10d::kDefaultFirstBucketBytes; + module.attr("_DEFAULT_NO_TIMEOUT") = py::cast(kNoTimeout); Py_RETURN_TRUE; } +#undef PROCESS_GROUP_DEPRECATION_WARNING + +// NOTE: Below are TorchBind bindings for c10d, these bindings will +// live together with those pybind11 bindings above until we resolve +// all the TorchBind issues and merge these two together. we shouldn't +// document this until we finish the migration. + +static const auto StoreTorchBind = + torch::class_<::c10d::Store>("dist_c10d", "Store"); + +static const auto FileStoreTorchBind = + torch::class_<::c10d::FileStore>("dist_c10d", "FileStore") + .def(torch::init([](const std::string& path, + int64_t num_workers) { + return c10::make_intrusive<::c10d::FileStore>( + path, num_workers); + })); + +static const auto TCPStoreTorchBind = + torch::class_<::c10d::TCPStore>("dist_c10d", "TCPStore") + .def(torch::init([](const std::string& host_name, + int64_t port, + int64_t world_size, + bool is_master, + int64_t timeout) { + auto timeout_miliseconds = std::chrono::milliseconds(timeout); + return c10::make_intrusive<::c10d::TCPStore>( + host_name, port, world_size, is_master, timeout_miliseconds); + })); + +// TODO: This should really take Store as constructor argument instead of +// TCPStore, but the fact that TorchScript does not support polymorphism +// forced us to cast in C++ instead of automatic casting +static const auto PrefixStoreTorchBind = + torch::class_<::c10d::PrefixStore>("dist_c10d", "PrefixStore") + .def(torch::init([](const std::string& prefix, + const c10::intrusive_ptr<::c10d::Store>& store) { + return c10::make_intrusive<::c10d::PrefixStore>( + prefix, store); + })); + + +// Torchbind the ProcessGroup to make it available in TorchScript +static const auto ProcessGroupWorkTorchBind = + torch::class_<::c10d::ProcessGroup::Work>("dist_c10d", "Work") + .def(torch::init<>()) + .def( + "wait", + [](const c10::intrusive_ptr<::c10d::ProcessGroup::Work>& work) + -> bool { + // TODO: make std::chrono::millisecond works with TorchBind to + // provide the full API in python + return work->wait(); + }) + .def("result", &::c10d::ProcessGroup::Work::result); + +// TODO: Support argument names in Python API. +static const auto ProcessGroupTorchBind = + torch::class_<::c10d::ProcessGroup>("dist_c10d", "ProcessGroup") + .def_pickle( + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self) { + auto name = + ::c10d::DistributedC10d::get()->getNameOfProcessGroup(self); + return std::vector{name}; + }, + [](std::vector state) { + TORCH_CHECK( + state.size() == 1, + "Expecting exactly 1 state when restoring ProcessGroup, got: ", + state.size()); + const auto& process_group_name = state.front(); + auto process_group = + ::c10d::DistributedC10d::get()->getProcessGroupByName( + process_group_name); + TORCH_CHECK( + process_group.defined(), + "Needed process group not found, ", + "please create a process group with name: ", + process_group_name); + return process_group; + }) + .def( + "rank", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self) { + return static_cast(self->getRank()); + }) + .def( + "size", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self) { + return static_cast(self->getSize()); + }) + // TODO: make BroadcastOptions compatible with TorchBind to provide + // the full API in python. + /* + // TODO: Enable this method when TorchBind supports overloading. + .def( + "broadcast", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + std::vector data) { return self->broadcast(data); }) + */ + .def( + "broadcast", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + at::Tensor tensor, + int64_t rootRank) { + ::c10d::BroadcastOptions opts; + opts.rootRank = rootRank; + std::vector tensors = {std::move(tensor)}; + return self->broadcast(tensors, opts); + }) + // TODO: make AllreduceOptions compatible with TorchBind to provide + // the full API in python. + .def( + "allreduce", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + std::vector tensors) { + return self->allreduce(tensors); + }) + /* + // TODO: Enable these methods when TorchBind supports overloading. + // TODO: Enable these methods when ReduceOp can be torchbinded. + .def( + "allreduce", + [](c10::intrusive_ptr<::c10d::ProcessGroup>& self, + std::vector& tensors, + c10::intrusive_ptr<::c10d::ReduceOp> op) { + ::c10d::AllreduceOptions opts; + opts.reduceOp = *op; + return self->allreduce(tensors, opts); + } + ) + .def( + "allreduce", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + at::Tensor& tensor, + c10::intrusive_ptr<::c10d::ReduceOp> op) { + ::c10d::AllreduceOptions opts; + opts.reduceOp = *op; + std::vector tensors = {tensor}; + return self->allreduce(tensors, opts); + } + ) + */ + // TODO: make AllreduceCoalescedOptions compatible with TorchBind to + // provide the full API in python. + .def( + "allreduce_coalesced", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + std::vector tensors) { + ::c10d::AllreduceCoalescedOptions opts; + return self->allreduce_coalesced(tensors, opts); + }) + .def( + "reduce", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + std::vector tensors) { + ::c10d::ReduceOptions opts; + return self->reduce(tensors, opts); + }) + /* + // TODO: Enable this when c10d::ReduceOp is TorchBind compatible. + .def( + "reduce", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + at::Tensor& tensor, + int rootRank, + c10::intrusive_ptr<::c10d::ReduceOp> op) { + ::c10d::ReduceOptions opts; + opts.reduceOp = *op; + opts.rootRank = rootRank; + std::vector tensors = {tensor}; + return self->reduce(tensors, opts); + }) + */ + .def( + "allgather", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + std::vector> outputTensors, + std::vector inputTensors) { + ::c10d::AllgatherOptions opts; + return self->allgather(outputTensors, inputTensors, opts); + }) + /* + // TODO: Enable these methods when TorchBind supports overloading. + .def( + "allgather", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + std::vector output, + at::Tensor input) { + std::vector> outputs = { + std::move(output)}; + std::vector inputs = {std::move(input)}; + ::c10d::AllgatherOptions opts; + return self->allgather(outputs, inputs, opts); + }) + */ + .def( + "allgather_coalesced", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + std::vector> output_lists, + std::vector input_list) { + ::c10d::AllgatherOptions opts; + return self->allgather_coalesced(output_lists, input_list, opts); + }) + /* + // TODO: Enable this method when TorchBind supports overloading. + .def( + "gather", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + std::vector> output_tensors, + std::vector input_tensors) { + ::c10d::GatherOptions opts; + return self->gather(output_tensors, input_tensors, opts); + }) + */ + .def( + "gather", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + std::vector output, + at::Tensor input, + int64_t rootRank) { + ::c10d::GatherOptions opts; + opts.rootRank = rootRank; + std::vector> outputs = { + std::move(output)}; + std::vector inputs = {std::move(input)}; + return self->gather(outputs, inputs, opts); + }) + /* + // TODO: Enable this method when TorchBind supports overloading. + .def( + "scatter", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + std::vector outputTensors, + std::vector> inputTensors) { + ::c10d::ScatterOptions opts; + self->scatter(outputTensors, inputTensors, opts); + }) + */ + .def( + "scatter", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + at::Tensor output, + std::vector input, + int64_t rootRank) { + ::c10d::ScatterOptions opts; + opts.rootRank = rootRank; + std::vector> inputs = {std::move(input)}; + std::vector outputs = {std::move(output)}; + return self->scatter(outputs, inputs, opts); + }) + /* + // TODO: Enable this method when TorchBind supports overloading. + // TODO: Enable this method when TorchBind supports + ReduceScatterOptions. .def( "reduce_scatter", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + std::vector outputTensors, + std::vector> inputTensors) { + ::c10d::ReduceScatterOptions opts; + return self->reduce_scatter(outputTensors, inputTensors, opts); + }) + */ + .def( + "reduce_scatter", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + at::Tensor output, + std::vector input) { + std::vector outputs = {std::move(output)}; + std::vector> inputs = {std::move(input)}; + ::c10d::ReduceScatterOptions opts; + return self->reduce_scatter(outputs, inputs, opts); + }) + .def( + "alltoall_base", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + at::Tensor outputTensor, + at::Tensor inputTensor, + std::vector outputSplitSizes, + std::vector inputSplitSizes) { + ::c10d::AllToAllOptions opts; + return self->alltoall_base( + outputTensor, + inputTensor, + outputSplitSizes, + inputSplitSizes, + opts); + }) + .def( + "alltoall", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + std::vector outputTensors, + std::vector inputTensors) { + ::c10d::AllToAllOptions opts; + return self->alltoall(outputTensors, inputTensors, opts); + }) + .def( + "send", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + std::vector tensors, + int64_t dstRank, + int64_t tag) { + return self->send( + tensors, static_cast(dstRank), static_cast(tag)); + }) + .def( + "recv", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + std::vector tensors, + int64_t srcRank, + int64_t tag) { + return self->recv( + tensors, static_cast(srcRank), static_cast(tag)); + }) + .def( + "recv_anysource", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + std::vector tensors, + int64_t tag) { + return self->recvAnysource(tensors, static_cast(tag)); + }) + .def( + "barrier", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self) { + ::c10d::BarrierOptions opts; + return self->barrier(opts); + }); + +#ifdef USE_C10D_NCCL +// XXX: Ideally the Options of ProcessGroupNCCL should be +// bound using `def_readwrite` like in pybind11, but we +// didn't do that because: 1. no milisecond support yet +// 2. no def_readwrite or property support yet. +// TODO: make this binding the same as pybind11 +static const auto ProcessGroupNCCLOptionsTorchBind = + torch::class_<::c10d::ProcessGroupNCCL::Options>( + "dist_c10d", + "ProcessGroupNCCLOptions") + .def(torch::init([](int64_t timeout, bool isHighPriorityStream) { + auto opTimeout = std::chrono::milliseconds(timeout); + return ::c10d::ProcessGroupNCCL::Options::create( + opTimeout, isHighPriorityStream); + })); + +static const auto ProcessGroupNCCLTorchBind = + torch::class_<::c10d::ProcessGroupNCCL>("dist_c10d", "ProcessGroupNCCL") + .def_pickle( + [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) { + auto base_process_group = + static_intrusive_pointer_cast<::c10d::ProcessGroup>(self); + auto name = + ::c10d::DistributedC10d::get()->getNameOfProcessGroup(self); + return std::vector{name}; + }, + [](std::vector state) { + TORCH_CHECK( + state.size() == 1, + "Expecting exactly 1 state when restoring ProcessGroupNCCL, got: ", + state.size()); + const auto& process_group_name = state.front(); + auto base_process_group = + ::c10d::DistributedC10d::get()->getProcessGroupByName( + process_group_name); + TORCH_CHECK( + base_process_group.defined(), + "Needed process group not found, ", + "please create a process group with name: ", + process_group_name); + c10::intrusive_ptr<::c10d::ProcessGroupNCCL> process_group_nccl = + dynamic_intrusive_pointer_cast<::c10d::ProcessGroupNCCL>( + base_process_group); + TORCH_CHECK( + process_group_nccl.defined(), + "Process group ", + process_group_name, + " isn't configured for NCCL backend"); + return process_group_nccl; + }) + .def(torch::init( + [](const c10::intrusive_ptr<::c10d::Store>& store, + int64_t rank, + int64_t size, + c10::intrusive_ptr<::c10d::ProcessGroupNCCL::Options> options, + const std::string& name) { + auto pg = c10::make_intrusive<::c10d::ProcessGroupNCCL>(store, rank, size, options); + ::c10d::DistributedC10d::get()->registerProcessGroupName( + pg, name); + return pg; + })) + .def( + "alltoall_base", + [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self, + at::Tensor output, + at::Tensor input, + std::vector outputSplitSizes, + std::vector inputSplitSizes) { + return self->alltoall_base( + output, + input, + outputSplitSizes, + inputSplitSizes, + ::c10d::AllToAllOptions()); + + }) + .def("size", [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) { + return (int64_t) self->getSize(); + }) + .def("rank", [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) { + return (int64_t) self->getRank(); + }); +#endif + +static const auto DistributedC10dFrontendTorchBind = + torch::class_<::c10d::DistributedC10d>("dist_c10d", "frontend") + .def(torch::init([]() { return ::c10d::DistributedC10d::get(); })) + .def( + "new_process_group_helper", + &::c10d::DistributedC10d::newProcessGroupHelper) + .def( + "get_process_group_by_name", + &::c10d::DistributedC10d::getProcessGroupByName) + .def( + "get_name_of_process_group", + &::c10d::DistributedC10d::getNameOfProcessGroup); + } // namespace // c10d methods on torch._C static PyMethodDef methods[] = { // NOLINT - {"_c10d_init", (PyCFunction)c10d_init, METH_NOARGS, nullptr}, + {"_c10d_init", c10d_init, METH_NOARGS, nullptr}, {nullptr, nullptr, 0, nullptr}}; PyMethodDef* python_functions() { diff --git a/torch/csrc/distributed/c10d/python_comm_hook.cpp b/torch/csrc/distributed/c10d/python_comm_hook.cpp new file mode 100644 index 0000000000000..594fc99bbba96 --- /dev/null +++ b/torch/csrc/distributed/c10d/python_comm_hook.cpp @@ -0,0 +1,60 @@ +#include + +#include +#include +#include +#include + +namespace c10d { + +PythonCommHook::~PythonCommHook() { + py::gil_scoped_acquire ag; + state_.dec_ref(); + hook_.dec_ref(); + // Explicitly set state_ and hook_ to nullptr to prevent py::object's dtor + // to decref on the PyObject again. + // See Note [Destructing py::object] in python_ivalue.h + state_.ptr() = nullptr; + hook_.ptr() = nullptr; +} + +c10::intrusive_ptr PythonCommHook::runHook( + GradBucket& bucket) { + py::gil_scoped_acquire acquire; + + py::object py_fut = hook_(state_, bucket); + + try { + return py_fut.cast>()->fut; + } catch (const py::cast_error& e) { + auto type = py_fut.get_type(); + auto errMsg = c10::str( + e.what(), + ". DDP communication hook's callback must return a " + "torch.futures.Future or torch._C.Future object, but got ", + type.attr("__module__").cast(), + ".", + type.attr("__qualname__").cast()); + throw std::runtime_error(errMsg); + } +} + +std::vector PythonCommHook::parseHookResult( + const c10::IValue& result) { + TORCH_INTERNAL_ASSERT( + result.isPyObject() || result.isTensorList(), + "expected the hook result is either a PyObject or TensorList"); + + if (result.isPyObject()) { + py::gil_scoped_acquire ag; + py::object obj = torch::jit::toPyObject(result); + auto value = torch::jit::toIValue( + obj, c10::ListType::create(c10::TensorType::get())); + + return value.toTensorVector(); + } + + return result.toTensorVector(); +} + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/python_comm_hook.h b/torch/csrc/distributed/c10d/python_comm_hook.h new file mode 100644 index 0000000000000..64879d78b2537 --- /dev/null +++ b/torch/csrc/distributed/c10d/python_comm_hook.h @@ -0,0 +1,34 @@ +#pragma once + +#include + +#include +#include +#include +#include + +namespace c10d { + +class TORCH_PYTHON_API PythonCommHook : public CommHookInterface { + public: + // Takes a state and a callable hook. The inputs are Python objects. + // The state is passed to the hook in runHook method, and it can be used to + // maintain and update any state information during the execution of the hook. + // The hook performs user-specified processing and returns a future indicating + // asychronous communication of gradients. + PythonCommHook(py::object state, py::object hook) + : state_(std::move(state)), hook_(std::move(hook)) {} + + ~PythonCommHook() override; + + c10::intrusive_ptr runHook(GradBucket& bucket) override; + + std::vector parseHookResult(const c10::IValue& result) override; + + private: + // Only needed for stateful communication. + py::object state_; + py::object hook_; +}; + +} // namespace c10d diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp index 34023afdce915..8e4cd346fc2f4 100644 --- a/torch/csrc/distributed/rpc/init.cpp +++ b/torch/csrc/distributed/rpc/init.cpp @@ -31,14 +31,23 @@ constexpr std::chrono::milliseconds kDeleteAllUsersTimeout(100000); template using shared_ptr_class_ = py::class_>; -PyObject* rpc_init(PyObject* /* unused */) { +PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { auto rpc_module = THPObjectPtr(PyImport_ImportModule("torch.distributed.rpc")); if (!rpc_module) { throw python_error(); } - auto module = py::handle(rpc_module).cast(); + auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C")); + if (!torch_C_module) { + throw python_error(); + } + + auto torch_C_m = py::handle(torch_C_module).cast(); + auto m = + torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings"); + + auto module = py::handle(m).cast(); auto rpcBackendOptions = shared_ptr_class_( @@ -81,7 +90,7 @@ PyObject* rpc_init(PyObject* /* unused */) { be constructed directly, rather, an instance can be retrieved through :meth:`~torch.distributed.rpc.get_worker_info` and the result can be passed in to functions such as - :meth:`~torch.distributed.rpc.rpc_sync`, :class:`~torch.distributed.rpc.rpc_async`, + :meth:`~torch.distributed.rpc.rpc_sync`, :meth:`~torch.distributed.rpc.rpc_async`, :meth:`~torch.distributed.rpc.remote` to avoid copying a string on every invocation.)") .def( @@ -114,6 +123,20 @@ PyObject* rpc_init(PyObject* /* unused */) { "join", &RpcAgent::join, py::call_guard()) .def( "sync", &RpcAgent::sync, py::call_guard()) + .def( + "shutdown", + &RpcAgent::shutdown, + py::call_guard()) + .def( + "get_worker_info", + (const WorkerInfo& (RpcAgent::*)(void) const) & + RpcAgent::getWorkerInfo, + py::call_guard()) + .def( + "get_worker_info", + (const WorkerInfo& (RpcAgent::*)(const std::string&) const) & + RpcAgent::getWorkerInfo, + py::call_guard()) .def( "get_worker_infos", &RpcAgent::getWorkerInfos, @@ -223,7 +246,7 @@ PyObject* rpc_init(PyObject* /* unused */) { to the local node and returns it. If the current node is the owner, returns a reference to the local value. - Arguments: + Args: timeout (float, optional): Timeout for ``to_here``. If the call does not complete within this timeframe, an exception indicating so will be raised. If this @@ -240,9 +263,11 @@ PyObject* rpc_init(PyObject* /* unused */) { )") .def( "rpc_sync", - [](const PyRRef& self) { - return self.createRRefProxy(RRefProxyType::RPC_SYNC); + [](const PyRRef& self, float timeoutSeconds) { + return self.createRRefProxy( + RRefProxyType::RPC_SYNC, timeoutSeconds); }, + py::arg("timeout") = kUnsetRpcTimeout, py::call_guard(), R"( Create a helper proxy to easily launch an ``rpc_sync`` using @@ -256,6 +281,12 @@ PyObject* rpc_init(PyObject* /* unused */) { >>> >>> rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs)) + Args: + timeout (float, optional): Timeout for ``rref.rpc_sync()``. + If the call does not complete within this timeframe, an + exception indicating so will be raised. If this argument + is not provided, the default RPC timeout will be used. + Example:: >>> from torch.distributed import rpc >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1)) @@ -264,9 +295,11 @@ PyObject* rpc_init(PyObject* /* unused */) { )") .def( "rpc_async", - [](const PyRRef& self) { - return self.createRRefProxy(RRefProxyType::RPC_ASYNC); + [](const PyRRef& self, float timeoutSeconds) { + return self.createRRefProxy( + RRefProxyType::RPC_ASYNC, timeoutSeconds); }, + py::arg("timeout") = kUnsetRpcTimeout, py::call_guard(), R"( Create a helper proxy to easily launch an ``rpc_async`` using @@ -280,6 +313,12 @@ PyObject* rpc_init(PyObject* /* unused */) { >>> >>> rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs)) + Args: + timeout (float, optional): Timeout for ``rref.rpc_async()``. + If the call does not complete within this timeframe, an + exception indicating so will be raised. If this argument + is not provided, the default RPC timeout will be used. + Example:: >>> from torch.distributed import rpc >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1)) @@ -288,9 +327,11 @@ PyObject* rpc_init(PyObject* /* unused */) { )") .def( "remote", - [](const PyRRef& self) { - return self.createRRefProxy(RRefProxyType::REMOTE); + [](const PyRRef& self, float timeoutSeconds) { + return self.createRRefProxy( + RRefProxyType::REMOTE, timeoutSeconds); }, + py::arg("timeout") = kUnsetRpcTimeout, py::call_guard(), R"( Create a helper proxy to easily launch a ``remote`` using @@ -304,6 +345,16 @@ PyObject* rpc_init(PyObject* /* unused */) { >>> >>> rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs)) + Args: + timeout (float, optional): Timeout for ``rref.remote()``. If + the creation of this :class:`~torch.distributed.rpc.RRef` + is not successfully completed within the timeout, then the + next time there is an attempt to use the RRef + (such as ``to_here``), a timeout will be raised. If not + provided, the default RPC timeout will be used. Please see + ``rpc.remote()`` for specific timeout semantics for + :class:`~torch.distributed.rpc.RRef`. + Example:: >>> from torch.distributed import rpc >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1)) @@ -350,6 +401,7 @@ PyObject* rpc_init(PyObject* /* unused */) { // Intentionally not releasing GIL, as most accesses just // retrieve cached type py::object &PyRRef::getRRefType, + py::arg("timeout") = kUnsetRpcTimeout, R"( Returns the type of the data object referenced by this ``RRef``. On the owner, this is same as @@ -357,6 +409,14 @@ PyObject* rpc_init(PyObject* /* unused */) { RPC to fetch the ``type`` object from the owner. After this function is run once, the ``type`` object is cached by the ``RRef``, and subsequent invocations no longer trigger RPC. + + Args: + rref (torch.distributed.rpc.RRef): The RRef to get type of. + timeout (float, optional): Timeout, in seconds for + ``_get_type``. If the call does not complete within + this timeframe, an exception indicating so will be + raised. If this argument is not provided, the default + RPC timeout will be used. )") .def( "_get_future", @@ -393,6 +453,44 @@ PyObject* rpc_init(PyObject* /* unused */) { Set future that is completed when the profiling event corresponding to the creation of this RRef on the remote node has been recorded. )") + .def( + "backward", + [](PyRRef& self, + int64_t dist_autograd_ctx_id, + bool retain_graph) { + self.backward(dist_autograd_ctx_id, retain_graph); + }, + py::arg("dist_autograd_ctx_id") = -1, + py::arg("retain_graph") = false, + py::call_guard(), + R"( + Runs the backward pass using the RRef as the root of the + backward pass. If ``dist_autograd_ctx_id`` is provided, + we perform a distributed backward pass using the provided + ctx_id starting from the owner of the RRef. In this case, + :meth:`~torch.distributed.autograd.get_gradients` should be + used to retrieve the gradients. If ``dist_autograd_ctx_id`` + is ``None``, it is assumed that this is a local autograd graph + and we only perform a local backward pass. In the local case, + the node calling this API has to be the owner of the RRef. + The value of the RRef is expected to be a scalar Tensor. + + Args: + dist_autograd_ctx_id (int, optional): The distributed + autograd context id for which we should retrieve the + gradients (default: -1). + retain_graph(bool, optional): If ``False``, the graph used to + compute the grad will be freed. Note that in nearly all + cases setting this option to ``True`` is not needed and + often can be worked around in a much more efficient way. + Usually, you need to set this to ``True`` to run backward + multiple times (default: False). + + Example:: + >>> import torch.distributed.autograd as dist_autograd + >>> with dist_autograd.context() as context_id: + >>> rref.backward(context_id) + )") // not releasing GIL to avoid context switch .def("__repr__", &PyRRef::str); @@ -404,7 +502,7 @@ PyObject* rpc_init(PyObject* /* unused */) { The backend options class for ``ProcessGroupAgent``, which is derived from ``RpcBackendOptions``. - Arguments: + Args: num_send_recv_threads (int, optional): The number of threads in the thread-pool used by ``ProcessGroupAgent`` (default: 4). rpc_timeout (float, optional): The default timeout, in seconds, @@ -434,7 +532,7 @@ PyObject* rpc_init(PyObject* /* unused */) { shared_ptr_class_(module, "ProcessGroupAgent", rpcAgent) .def(py::init([](std::string workerName, - const std::shared_ptr<::c10d::ProcessGroup>& pg, + const c10::intrusive_ptr<::c10d::ProcessGroup>& pg, int numSendRecvThreads, std::chrono::milliseconds rpcTimeout) { return std::make_unique( @@ -446,12 +544,17 @@ PyObject* rpc_init(PyObject* /* unused */) { })) .def( "get_worker_info", - (const WorkerInfo& (ProcessGroupAgent::*)(void)const) & + (const WorkerInfo& (ProcessGroupAgent::*)(void) const) & RpcAgent::getWorkerInfo, py::call_guard()) .def( "get_worker_info", - (const WorkerInfo& (ProcessGroupAgent::*)(const std::string&)const) & + (const WorkerInfo& (ProcessGroupAgent::*)(const std::string&) const) & + ProcessGroupAgent::getWorkerInfo, + py::call_guard()) + .def( + "get_worker_info", + (const WorkerInfo& (ProcessGroupAgent::*)(worker_id_t id) const) & ProcessGroupAgent::getWorkerInfo, py::call_guard()) .def( @@ -511,11 +614,11 @@ PyObject* rpc_init(PyObject* /* unused */) { shared_ptr_class_(module, "TensorPipeAgent", rpcAgent) .def( - py::init([](const std::shared_ptr<::c10d::Store>& store, + py::init([](const c10::intrusive_ptr<::c10d::Store>& store, std::string selfName, worker_id_t selfId, int worldSize, - std::shared_ptr<::c10d::ProcessGroup> processGroup, + c10::intrusive_ptr<::c10d::ProcessGroup> processGroup, TensorPipeRpcBackendOptions opts) { return std::make_shared( store, @@ -542,12 +645,17 @@ PyObject* rpc_init(PyObject* /* unused */) { py::call_guard()) .def( "get_worker_info", - (const WorkerInfo& (TensorPipeAgent::*)(void)const) & + (const WorkerInfo& (TensorPipeAgent::*)(void) const) & RpcAgent::getWorkerInfo, py::call_guard()) .def( "get_worker_info", - (const WorkerInfo& (TensorPipeAgent::*)(const std::string&)const) & + (const WorkerInfo& (TensorPipeAgent::*)(const std::string&) const) & + TensorPipeAgent::getWorkerInfo, + py::call_guard()) + .def( + "get_worker_info", + (const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id) const) & TensorPipeAgent::getWorkerInfo, py::call_guard()) .def( @@ -712,7 +820,7 @@ PyObject* rpc_init(PyObject* /* unused */) { Set whether GIL wait times should be enabled or not. This incurs a slight overhead cost. Default is disabled for performance reasons. - Arguments: + Args: flag (bool): True to set GIL profiling, False to disable. )"); @@ -731,7 +839,7 @@ PyObject* rpc_init(PyObject* /* unused */) { :meth:`~torch.distributed.rpc.rpc_sync` and :meth:`~torch.distributed.rpc.rpc_async`. - Arguments: + Args: rpcTimeoutSeconds (float): Timeout value in seconds. )"); @@ -777,7 +885,7 @@ PyObject* rpc_init(PyObject* /* unused */) { } // namespace static PyMethodDef methods[] = { // NOLINT - {"_rpc_init", (PyCFunction)rpc_init, METH_NOARGS, nullptr}, + {"_rpc_init", rpc_init, METH_NOARGS, nullptr}, {nullptr, nullptr, 0, nullptr}}; PyMethodDef* python_functions() { diff --git a/torch/csrc/distributed/rpc/macros.h b/torch/csrc/distributed/rpc/macros.h new file mode 100644 index 0000000000000..2763dd0207bef --- /dev/null +++ b/torch/csrc/distributed/rpc/macros.h @@ -0,0 +1,5 @@ +#pragma once + +#if defined(USE_CUDA) && !defined(__HIP_PLATFORM_HCC__) +#define USE_CUDA_NOT_ROCM +#endif diff --git a/torch/csrc/distributed/rpc/message.cpp b/torch/csrc/distributed/rpc/message.cpp index 089a84484be9e..b35e9149d1e62 100644 --- a/torch/csrc/distributed/rpc/message.cpp +++ b/torch/csrc/distributed/rpc/message.cpp @@ -76,40 +76,11 @@ MessageType Message::type() const { } bool Message::isRequest() const { - return MessageType::SCRIPT_CALL == type_ || // dist.rpc on builtin ops - MessageType::PYTHON_CALL == type_ || // dist.rpc on Python UDFs - MessageType::SCRIPT_REMOTE_CALL == type_ || // dist.remote on builtin ops - MessageType::PYTHON_REMOTE_CALL == type_ || // dist.remote on Python UDFs - // RRef related internal messages - MessageType::SCRIPT_RREF_FETCH_CALL == type_ || - MessageType::PYTHON_RREF_FETCH_CALL == type_ || - MessageType::RREF_USER_DELETE == type_ || - MessageType::RREF_CHILD_ACCEPT == type_ || - MessageType::RREF_FORK_REQUEST == type_ || - // Autograd message - MessageType::BACKWARD_AUTOGRAD_REQ == type_ || - MessageType::FORWARD_AUTOGRAD_REQ == type_ || - // Cleanup Autograd context request - MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ == type_ || - // Run with profiling request - MessageType::RUN_WITH_PROFILING_REQ == type_; + return MessageTypeFlags::REQUEST_TYPE & type_; } bool Message::isResponse() const { - return MessageType::SCRIPT_RET == type_ || // ret of dist.rpc on builtin ops - MessageType::PYTHON_RET == type_ || // ret of dist.rpc on Python UDFs - MessageType::REMOTE_RET == type_ || // ret of dist.remote - MessageType::SCRIPT_RREF_FETCH_RET == type_ || // ret on RRef::toHere() - MessageType::PYTHON_RREF_FETCH_RET == type_ || // ret on RRef::toHere() - MessageType::EXCEPTION == type_ || // propagate back exceptions - MessageType::RREF_ACK == type_ || // ret of other types - // Autograd response - MessageType::BACKWARD_AUTOGRAD_RESP == type_ || - MessageType::FORWARD_AUTOGRAD_RESP == type_ || - // Cleanup autograd context response - MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP == type_ || - // Run with profiling response - MessageType::RUN_WITH_PROFILING_RESP == type_; + return MessageTypeFlags::RESPONSE_TYPE & type_; } int64_t Message::id() const { @@ -133,6 +104,18 @@ Message createExceptionResponse(const std::string& exceptionStr, int64_t id) { id); } +namespace { + +// NB: need to call torch::class_ to register Message in the map returned by +// c10::getCustomClassTypeMap(). Otherwise, Message cannot be wrapped within +// an IValue. +// NB: add this line here instead of in rpc/init.cpp because 1) we have C++ +// only tests that won't run rpc/init.cpp; 2) Message is not meant to be +// visible from Python. +static const auto message = torch::class_("rpc", "_Message"); + +} // namespace + } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/rpc/message.h b/torch/csrc/distributed/rpc/message.h index 3f7c930216639..3d2d623e821f3 100644 --- a/torch/csrc/distributed/rpc/message.h +++ b/torch/csrc/distributed/rpc/message.h @@ -16,48 +16,77 @@ enum RPCErrorType { FaultyProcessGroupAgent for testing */ }; +// The enum values are bitwise ORed with MessageType +// They are bit flags starting from 0x100 and should have +// value such as 0x100, 0x200, 0x400, 0x800, 0xF00, etc. +enum MessageTypeFlags { + REQUEST_TYPE = 0x100, + RESPONSE_TYPE = 0x200, +}; + +// Message types must have values between 0 to 255 enum MessageType { // messages for dist.rpc on builtin operators - SCRIPT_CALL = 0, - SCRIPT_RET = 1, + SCRIPT_CALL = 0 | MessageTypeFlags::REQUEST_TYPE, + SCRIPT_RET = 1 | MessageTypeFlags::RESPONSE_TYPE, // messages for dist.rpc on Python UDF - PYTHON_CALL = 2, - PYTHON_RET = 3, + PYTHON_CALL = 2 | MessageTypeFlags::REQUEST_TYPE, + PYTHON_RET = 3 | MessageTypeFlags::RESPONSE_TYPE, // messages for dist.remote on builtin operators and Python UDF - SCRIPT_REMOTE_CALL = 4, // A remote call on a builtin operator - PYTHON_REMOTE_CALL = 5, // A remote call on a Python UDF - REMOTE_RET = 6, // Response for remote calls for UDF, builtin, or script + SCRIPT_REMOTE_CALL = + 4 | MessageTypeFlags::REQUEST_TYPE, // A remote call on a builtin operator + PYTHON_REMOTE_CALL = + 5 | MessageTypeFlags::REQUEST_TYPE, // A remote call on a Python UDF + REMOTE_RET = + 6 | MessageTypeFlags::RESPONSE_TYPE, // Response for remote calls for UDF, + // builtin, or script // RRef related internal messages - SCRIPT_RREF_FETCH_CALL = 7, // A UserRRef fetches value from owner - PYTHON_RREF_FETCH_CALL = 8, // A UserRRef fetches value from owner - SCRIPT_RREF_FETCH_RET = 9, // An OwnerRRef sends ivalue to user - PYTHON_RREF_FETCH_RET = 10, // An OwnerRRef sends py::object to user - RREF_USER_DELETE = 11, // A UserRRef tells the owner to deref - RREF_FORK_REQUEST = 12, // A child UserRRef tells the owner about itself - RREF_CHILD_ACCEPT = 13, // A child UserRRef tells parent that owner knows it - RREF_ACK = 14, // ACK to internal RRef messages + SCRIPT_RREF_FETCH_CALL = + 7 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef fetches value + // from owner + PYTHON_RREF_FETCH_CALL = + 8 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef fetches + // value from owner + SCRIPT_RREF_FETCH_RET = + 9 | MessageTypeFlags::RESPONSE_TYPE, // An OwnerRRef sends ivalue to user + PYTHON_RREF_FETCH_RET = 10 | + MessageTypeFlags::RESPONSE_TYPE, // An OwnerRRef sends py::object to user + RREF_USER_DELETE = 11 | + MessageTypeFlags::REQUEST_TYPE, // A UserRRef tells the owner to deref + RREF_FORK_REQUEST = + 12 | MessageTypeFlags::REQUEST_TYPE, // A child UserRRef tells the owner + // about itself + RREF_CHILD_ACCEPT = + 13 | MessageTypeFlags::REQUEST_TYPE, // A child UserRRef tells parent that + // owner knows it + RREF_ACK = + 14 | MessageTypeFlags::RESPONSE_TYPE, // ACK to internal RRef messages // Messages with autograd info - FORWARD_AUTOGRAD_REQ = 15, - FORWARD_AUTOGRAD_RESP = 16, + FORWARD_AUTOGRAD_REQ = 15 | MessageTypeFlags::REQUEST_TYPE, + FORWARD_AUTOGRAD_RESP = 16 | MessageTypeFlags::RESPONSE_TYPE, // Messages to propagate gradients on the backward pass. - BACKWARD_AUTOGRAD_REQ = 17, - BACKWARD_AUTOGRAD_RESP = 18, + BACKWARD_AUTOGRAD_REQ = 17 | MessageTypeFlags::REQUEST_TYPE, + BACKWARD_AUTOGRAD_RESP = 18 | MessageTypeFlags::RESPONSE_TYPE, // Messages to tell workers to clean up their autograd context. - CLEANUP_AUTOGRAD_CONTEXT_REQ = 19, - CLEANUP_AUTOGRAD_CONTEXT_RESP = 20, + CLEANUP_AUTOGRAD_CONTEXT_REQ = 19 | MessageTypeFlags::REQUEST_TYPE, + CLEANUP_AUTOGRAD_CONTEXT_RESP = 20 | MessageTypeFlags::RESPONSE_TYPE, // Messages that tell workers to run requests with profiling enabled. - RUN_WITH_PROFILING_REQ = 21, - RUN_WITH_PROFILING_RESP = 22, + RUN_WITH_PROFILING_REQ = 21 | MessageTypeFlags::REQUEST_TYPE, + RUN_WITH_PROFILING_RESP = 22 | MessageTypeFlags::RESPONSE_TYPE, + + // Messages to support RRef.backward(). + RREF_BACKWARD_REQ = 23 | MessageTypeFlags::REQUEST_TYPE, + RREF_BACKWARD_RESP = 24 | MessageTypeFlags::RESPONSE_TYPE, // Other internal message types - EXCEPTION = 55, + EXCEPTION = 55 | MessageTypeFlags::RESPONSE_TYPE, UNKNOWN = 60 }; @@ -80,7 +109,7 @@ enum MessageType { // Layers above ``RpcAgent`` only converts ScriptCall, ScriptResp, PythonCall, // and PythonResp into a Message, and it is up to the RpcAgent // implementation to determine how to serialize a message. -class TORCH_API Message final { +class TORCH_API Message final : public torch::CustomClassHolder { public: Message(); @@ -141,9 +170,6 @@ TORCH_API Message createExceptionResponse(const std::exception& e, int64_t id); TORCH_API Message createExceptionResponse(const std::string& exceptionStr, int64_t id); -// FutureMessage is an internal type used in the communication layer. All -// user-facing surface APIs should use JitFuture instead. -using FutureMessage = torch::utils::Future; using JitFuture = c10::ivalue::Future; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp index d97577724a55c..9c1a703cfa6db 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/process_group_agent.cpp @@ -90,7 +90,7 @@ void ProcessGroupAgent::collectNames() { ProcessGroupAgent::ProcessGroupAgent( std::string workerName, - std::shared_ptr pg, + c10::intrusive_ptr<::c10d::ProcessGroup> pg, int numSendRecvThreads, std::chrono::milliseconds rpcTimeout, std::unique_ptr cb) @@ -153,6 +153,8 @@ const WorkerInfo& ProcessGroupAgent::getWorkerInfo( } const WorkerInfo& ProcessGroupAgent::getWorkerInfo(worker_id_t id) const { + TORCH_CHECK( + id >= 0 && id < allWorkerInfo_.size(), "Invalid destination: ", id); return allWorkerInfo_[id]; } @@ -285,7 +287,7 @@ void ProcessGroupAgent::shutdownImpl() { threadPool_.waitWorkComplete(); } -std::shared_ptr ProcessGroupAgent::send( +std::shared_ptr ProcessGroupAgent::send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds) { @@ -317,7 +319,7 @@ std::shared_ptr ProcessGroupAgent::send( pg_->getRank()); auto requestId = nextId(); - auto future = std::make_shared(); + auto future = std::make_shared(at::AnyClassType::get()); if (message.isRequest()) { // millisecond level precision of when request started. auto futureStartTime = std::chrono::steady_clock::now(); @@ -360,7 +362,7 @@ std::shared_ptr ProcessGroupAgent::send( message.setId(requestId); ++clientActiveCalls_; } else { - future->markCompleted(Message()); + future->markCompleted(IValue()); } // Sending to ourselves: bypass the send logic and enqueue directly @@ -380,6 +382,7 @@ std::shared_ptr ProcessGroupAgent::send( // the C++ land. Hence, we have to explicitly use the ``WorkerInfo`` in the // C++ land. enqueueSend(SendWork(allWorkerInfo_[to.id_], std::move(message))); + return future; } @@ -396,7 +399,7 @@ void ProcessGroupAgent::handleSend(const SendWork& work) { // ProcessGroup is not thread-safe when sending with the same tag, // hence the lock - std::vector> pendingSends; + std::vector> pendingSends; const auto dst = work.to_.id_; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) @@ -511,22 +514,24 @@ bool ProcessGroupAgent::handleRecv(RecvWork& work) { std::move(data.first), std::move(data.second), work.type_, work.id_); if (message.isRequest()) { ++serverActiveCalls_; - std::shared_ptr futureResponse; + std::shared_ptr futureResponse; try { futureResponse = cb_->operator()(message); } catch (const std::exception& e) { - futureResponse = std::make_shared(); - futureResponse->setError(e.what()); + futureResponse = std::make_shared(at::AnyClassType::get()); + futureResponse->setError(std::current_exception()); } if (futureResponse->completed()) { --serverActiveCalls_; if (!futureResponse->hasError()) { - send(work.from_, std::move(*futureResponse).moveValue()); + send( + work.from_, + std::move(*futureResponse->value().toCustomClass())); } else { send( work.from_, createExceptionResponse( - futureResponse->error()->what(), message.id())); + futureResponse->tryRetrieveErrorMessage(), message.id())); } } else { ++serverActiveAsyncCalls_; @@ -535,28 +540,30 @@ bool ProcessGroupAgent::handleRecv(RecvWork& work) { // Use a weak_ptr, so we can std::move the future's value. auto fromId = work.from_.id_; auto requestId = work.id_; - futureResponse->addCallback([this, - fromId, - requestId, - weak = std::weak_ptr( - futureResponse)]() { - auto futureResponse = weak.lock(); - TORCH_INTERNAL_ASSERT(futureResponse); - --serverActiveCalls_; - --serverActiveAsyncCalls_; - if (!futureResponse->hasError()) { - send(getWorkerInfo(fromId), std::move(*futureResponse).moveValue()); - } else { - send( - getWorkerInfo(fromId), - createExceptionResponse( - futureResponse->error()->what(), requestId)); - } - }); + futureResponse->addCallback( + [this, + fromId, + requestId, + weak = std::weak_ptr(futureResponse)]() { + auto futureResponse = weak.lock(); + TORCH_INTERNAL_ASSERT(futureResponse); + --serverActiveCalls_; + --serverActiveAsyncCalls_; + if (!futureResponse->hasError()) { + send( + getWorkerInfo(fromId), + std::move(*futureResponse->value().toCustomClass())); + } else { + send( + getWorkerInfo(fromId), + createExceptionResponse( + futureResponse->tryRetrieveErrorMessage(), requestId)); + } + }); } } else if (message.isResponse()) { auto id = message.id(); - std::shared_ptr fm = nullptr; + std::shared_ptr jitFuture = nullptr; { std::lock_guard lock{futureMutex_}; const auto& futureInfo = futures_.find(id); @@ -568,7 +575,7 @@ bool ProcessGroupAgent::handleRecv(RecvWork& work) { return false; } // Use futureInfo before destructing it. - fm = futureInfo->second.future_; + jitFuture = futureInfo->second.future_; auto endTime = futureInfo->second.endTime_; futures_.erase(id); // look up the corresponding future by its time out and request @@ -587,10 +594,11 @@ bool ProcessGroupAgent::handleRecv(RecvWork& work) { futureCV_.notify_all(); --clientActiveCalls_; if (message.type() == MessageType::EXCEPTION) { - fm->setError( - std::string(message.payload().begin(), message.payload().end())); + jitFuture->setError(std::make_exception_ptr(std::runtime_error( + std::string(message.payload().begin(), message.payload().end())))); } else { - fm->markCompleted(std::move(message)); + jitFuture->markCompleted( + IValue(c10::make_intrusive(std::move(message)))); } } else { // TODO: pass the error back to the caller instead of crashing here. @@ -641,7 +649,7 @@ void ProcessGroupAgent::markFutureWithError(Message& message) { } void ProcessGroupAgent::markFutureWithError(int64_t id, std::string errorMsg) { - std::shared_ptr fm = nullptr; + std::shared_ptr jitFuture = nullptr; { std::lock_guard lock{futureMutex_}; const auto& futureInfo = futures_.find(id); @@ -651,7 +659,7 @@ void ProcessGroupAgent::markFutureWithError(int64_t id, std::string errorMsg) { // out and been processed accordingly. return; } - fm = futureInfo->second.future_; + jitFuture = futureInfo->second.future_; auto rpcEndTime = futureInfo->second.endTime_; futures_.erase(id); // look up the corresponding future by its time out and request ID, @@ -669,7 +677,7 @@ void ProcessGroupAgent::markFutureWithError(int64_t id, std::string errorMsg) { } --clientActiveCalls_; - fm->setError(std::move(errorMsg)); + jitFuture->setError(std::make_exception_ptr(std::runtime_error(errorMsg))); futureCV_.notify_all(); } @@ -801,7 +809,8 @@ void ProcessGroupAgent::pollTimedOutRPCs() { if (!timedOutFuture.future_->hasError()) { --clientActiveCalls_; - timedOutFuture.future_->setError(std::move(err)); + timedOutFuture.future_->setError( + std::make_exception_ptr(std::runtime_error(err))); // The future timed out and will not be processed by handleRecv(), even // if we eventually get a response. In order to keep track of all // send/recv pairs, we increment the count here. diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h index 1bc8db9ebf208..8d2471a7d113d 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.h +++ b/torch/csrc/distributed/rpc/process_group_agent.h @@ -61,7 +61,7 @@ class TORCH_API ProcessGroupAgent : public RpcAgent { public: ProcessGroupAgent( std::string workerName, - std::shared_ptr pg, + c10::intrusive_ptr<::c10d::ProcessGroup> pg, int numSendRecvThreads, std::chrono::milliseconds rpcTimeout, std::unique_ptr cb); @@ -88,7 +88,7 @@ class TORCH_API ProcessGroupAgent : public RpcAgent { // This method wraps the destination information and the message into a // SendWork object, and put the SendWork into a queue. Another thread will // consume SendWork from the queue and send it out. - std::shared_ptr send( + std::shared_ptr send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds = kUnsetRpcTimeout) override; @@ -130,16 +130,16 @@ class TORCH_API ProcessGroupAgent : public RpcAgent { // additional information to manage timeouts and destination information, // which is needed for termination detection. struct FutureInfo { - std::shared_ptr future_; + std::shared_ptr future_; steady_clock_time_point endTime_; int dstRank_; std::chrono::milliseconds timeout_; FutureInfo( - const std::shared_ptr& future, + std::shared_ptr future, const steady_clock_time_point& endTime, int dstRank, const std::chrono::milliseconds timeout) - : future_(future), + : future_(std::move(future)), endTime_(endTime), dstRank_(dstRank), timeout_(timeout) {} @@ -209,7 +209,7 @@ class TORCH_API ProcessGroupAgent : public RpcAgent { return ++nextId_; } - std::shared_ptr pg_; + c10::intrusive_ptr<::c10d::ProcessGroup> pg_; // worker name -> rank std::unordered_map nameMap_; std::vector allWorkerInfo_; @@ -230,14 +230,14 @@ class TORCH_API ProcessGroupAgent : public RpcAgent { // Lock and shared ptr to currently pending work, set in listenloop() and // interruptible in shutdown(). std::mutex recvWorkMutex_; - std::shared_ptr recvWork_; + c10::intrusive_ptr recvWork_; // Map of dst rank to current oustanding sends that we are waiting on. In the // case of a call to ::shutdown() while we are still waiting on these sends, // the pending sends contained in this map will be aborted, allowing the // waiting thread to be unblocked. std::unordered_map< worker_id_t, - std::set>> + std::set>> currentPendingSends_; // Lock to serialize access to the above map. std::mutex pendingSendMutex_; diff --git a/torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h b/torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h index f4baed5218b6c..b45026b184fe7 100644 --- a/torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h +++ b/torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h @@ -51,8 +51,7 @@ class State { // parse_cpu_trace(result) for results of all profile range. std::mutex resultsMutex_; std::vector results_; - const ProfilerConfig config_ = - ProfilerConfig(ProfilerState::Disabled, false, false); + const ProfilerConfig config_ = ProfilerConfig(ProfilerState::Disabled); }; class StateStackEntry; diff --git a/torch/csrc/distributed/rpc/py_rref.cpp b/torch/csrc/distributed/rpc/py_rref.cpp index 823e21d20b4bd..e9056db2cd8aa 100644 --- a/torch/csrc/distributed/rpc/py_rref.cpp +++ b/torch/csrc/distributed/rpc/py_rref.cpp @@ -1,5 +1,8 @@ #include +#include +#include +#include #include #include #include @@ -134,8 +137,7 @@ c10::intrusive_ptr PyRRef::getFuture() const { // Marking hasValue to false, as this Future is only used for signaling // profiler to update profiling result and the profiler does not retrieve // any value from it. - return wrapFutureMessageInJitFuture( - rref_->getOwnerCreationFuture(), false /* hasValue */); + return toPyJitFuture(rref_->getOwnerCreationFuture(), false /* hasValue */); } c10::intrusive_ptr PyRRef::getProfilingFuture() const { @@ -226,20 +228,22 @@ std::string PyRRef::str() const { } } -py::object PyRRef::createRRefProxy(const RRefProxyType& type) const { +py::object PyRRef::createRRefProxy( + const RRefProxyType& type, + float timeoutSeconds) const { auto& pythonRpcHandler = PythonRpcHandler::getInstance(); pybind11::gil_scoped_acquire ag; auto& functions = pythonRpcHandler.getRRefProxyFunctions(); auto& ctor = functions.rrefProxyCtor_; switch (type) { case RRefProxyType::RPC_SYNC: { - return ctor(*this, functions.rpcSync_); + return ctor(*this, functions.rpcSync_, timeoutSeconds); } case RRefProxyType::RPC_ASYNC: { - return ctor(*this, functions.rpcAsync_); + return ctor(*this, functions.rpcAsync_, timeoutSeconds); } case RRefProxyType::REMOTE: { - return ctor(*this, functions.remote_); + return ctor(*this, functions.remote_, timeoutSeconds); } default: { TORCH_INTERNAL_ASSERT(false, "Unrecognized RRefProxy type ", type); @@ -247,14 +251,15 @@ py::object PyRRef::createRRefProxy(const RRefProxyType& type) const { } } -py::object PyRRef::getRRefType() { +py::object PyRRef::getRRefType(float timeout) { // GIL is not released when calling this function. if (!type_.has_value()) { pybind11::gil_scoped_release release; auto& pythonRpcHandler = PythonRpcHandler::getInstance(); auto& typeFuncs = pythonRpcHandler.getRRefTypeFunctions(); pybind11::gil_scoped_acquire acquire; - type_ = isOwner() ? typeFuncs.onOwner_(*this) : typeFuncs.onUser_(*this); + type_ = isOwner() ? typeFuncs.onOwner_(*this) + : typeFuncs.onUser_(*this, timeout); } return *type_; @@ -283,6 +288,59 @@ c10::IValue PyRRef::toIValue() const { return IValue(rrefPtr); } +void PyRRef::backward(int64_t autogradContextId, bool retainGraph) { + backward(autogradContextId, retainGraph, rref_); +} + +void PyRRef::backward( + int64_t autogradContextId, + bool retainGraph, + const c10::intrusive_ptr& rref) { + if (rref->isOwner()) { + auto value = + c10::static_intrusive_pointer_cast(rref)->getValue(); + + // If we have a PyObj, retrieve the underlying tensor. + if (rref->isPyObj()) { + py::gil_scoped_acquire gil; + py::object obj = torch::jit::toPyObject(value); + try { + value = torch::jit::toIValue(obj, c10::TensorType::get()); + } catch (py::cast_error& e) { + throw std::runtime_error( + "RRef should contain a tensor for .backward()"); + } + } + + TORCH_CHECK( + value.isTensor(), "RRef should contain a tensor for .backward()"); + auto root = value.toTensor(); + + if (autogradContextId == -1) { + torch::autograd::backward({root}); + } else { + torch::distributed::autograd::backward( + autogradContextId, {root}, retainGraph); + } + + } else { + TORCH_CHECK( + autogradContextId != -1, + "User RRefs require 'dist_autograd_ctx_id' to be specified"); + + autograd::RRefBackwardReq rrefBackwardReq( + rref->rrefId(), autogradContextId, retainGraph); + + // Invoke distributed backward remotely. + auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent(); + rpcAgent + ->send( + rpcAgent->getWorkerInfo(rref->owner()), + std::move(rrefBackwardReq).toMessage()) + ->waitAndThrow(); + } +} + } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/rpc/py_rref.h b/torch/csrc/distributed/rpc/py_rref.h index 3cc7ab73f0198..a160ac72accef 100644 --- a/torch/csrc/distributed/rpc/py_rref.h +++ b/torch/csrc/distributed/rpc/py_rref.h @@ -46,10 +46,23 @@ class PYBIND11_EXPORT PyRRef { // create a proxy on this RRef, which can be used to launch RPC on the owner // of this RRef to run functions on the object referenced by this RRef. - py::object createRRefProxy(const RRefProxyType& mode) const; + py::object createRRefProxy( + const RRefProxyType& mode, + float timeoutSeconds = rpc::kUnsetRpcTimeout) const; - // get the type of the data object referenced by this RRef. - py::object getRRefType(); + // get the type of the data object referenced by this RRef. Timeout argument + // is only used in the first invocation of this function as an argument to the + // RPC to the owner node of the RRef. + py::object getRRefType(float timeout = rpc::kUnsetRpcTimeout); + + // Run the backward pass with the RRef as the root. + void backward(int64_t autogradContextId, bool retainGraph); + + // Helper static function to run backward on a given rref. + static void backward( + int64_t autogradContextId, + bool retainGraph, + const c10::intrusive_ptr& rref); private: c10::intrusive_ptr rref_; diff --git a/torch/csrc/distributed/rpc/python_functions.cpp b/torch/csrc/distributed/rpc/python_functions.cpp index b7c16639b19b1..1a399b403ab12 100644 --- a/torch/csrc/distributed/rpc/python_functions.cpp +++ b/torch/csrc/distributed/rpc/python_functions.cpp @@ -1,10 +1,10 @@ -#include #include #include #include #include #include #include +#include #include #include #include @@ -24,7 +24,7 @@ namespace rpc { namespace { -IValue toIValue(const Message& message) { +IValue toPyIValue(const Message& message) { MessageType msgType = message.type(); auto response = deserializeResponse(message, msgType); switch (msgType) { @@ -61,36 +61,42 @@ std::shared_ptr matchBuiltinOp( const py::kwargs& kwargs, Stack& stack) { Symbol symbol = Symbol::fromQualString(opName); - std::vector> candidates; + std::shared_ptr matchedOperator; if (symbol.is_aten()) { - for (const auto& op : torch::jit::getAllOperatorsFor(symbol)) { - try { - // FIXME: This is temporary solution. We should at least refactor - // ``createStackForSchema`` to avoid throwing an error. - stack = torch::jit::createStackForSchema( - op->schema(), args, kwargs, c10::nullopt); - } catch (std::runtime_error& e) { - VLOG(1) << "Couldn't match schema: " << op->schema() - << " to args: " << args << " and kwargs: " << kwargs - << ", reason: " << e.what(); - continue; - } - - // Prefer C10 ops so that they go through C10 dispatch. We expect the - // total # of possible overloaded ops to be small (i.e. it is 10 for - // torch.add) so a worst-case linear search should not incur significant - // extra overhead. + // Prefer C10 ops so that they go through C10 dispatch. We expect the + // total # of possible overloaded ops (i.e. size of below ops list) to be + // small (i.e. it is 10 for torch.add) so a worst-case linear search should + // not incur significant extra overhead. + auto ops = torch::jit::getAllOperatorsFor(symbol); + std::vector> c10OpsForSymbol; + for (auto it = ops.begin(); it != ops.end();) { + std::shared_ptr op = *it; if (op->isC10Op()) { - return op; + c10OpsForSymbol.emplace_back(std::move(op)); + it = ops.erase(it); + } else { + ++it; } - candidates.emplace_back(op); } + + // Don't throw on failures in this call, since we are not examining on all + // operators here, and the matched operator may indeed not be a c10 op. + std::pair, torch::jit::Stack> + opWithStack; + try { + opWithStack = torch::jit::getOpWithStack(c10OpsForSymbol, args, kwargs); + } catch (const std::runtime_error& e) { + opWithStack = torch::jit::getOpWithStack(ops, args, kwargs); + } + matchedOperator = std::get<0>(opWithStack); + stack = std::get<1>(opWithStack); } - // Ensure that we generated some candidates. + // We should never hit this path, since if !matchedOperator, then the last + // call to getOpWithStack should have thrown. TORCH_CHECK( - !candidates.empty(), + matchedOperator != nullptr, "Failed to match operator name ", opName, " and arguments " @@ -99,10 +105,11 @@ std::shared_ptr matchBuiltinOp( ", kwargs: ", kwargs, ") to a builtin operator"); - return candidates[0]; + + return matchedOperator; } -std::shared_ptr sendPythonRemoteCall( +std::shared_ptr sendPythonRemoteCall( const WorkerInfo& dst, SerializedPyObj serializedPyObj, const IValue& rrefId, @@ -127,42 +134,40 @@ std::shared_ptr sendPythonRemoteCall( using namespace torch::distributed::autograd; -c10::intrusive_ptr wrapFutureMessageInJitFuture( - const std::shared_ptr& futureResponseMessage, +c10::intrusive_ptr toPyJitFuture( + const std::shared_ptr& messageJitFuture, bool hasValue) { if (hasValue) { - c10::intrusive_ptr jitFuture = + c10::intrusive_ptr pyJitFuture = c10::make_intrusive(PyObjectType::get()); - std::weak_ptr wp = futureResponseMessage; - futureResponseMessage->addCallback( - at::wrapPropagateTLSState([jitFuture, wp]() { - auto futureResponseMessage = wp.lock(); - if (futureResponseMessage->hasError()) { - jitFuture->setError( - std::make_exception_ptr(*futureResponseMessage->error())); + std::weak_ptr wp = messageJitFuture; + messageJitFuture->addCallback( + at::wrapPropagateTLSState([pyJitFuture, wp]() { + auto future = wp.lock(); + if (future->hasError()) { + pyJitFuture->setError(future->exception_ptr()); } else { - jitFuture->markCompleted( - toIValue(futureResponseMessage->constValue())); + pyJitFuture->markCompleted( + toPyIValue(*future->value().toCustomClass())); } })); - return jitFuture; + return pyJitFuture; } else { - c10::intrusive_ptr jitFuture = + c10::intrusive_ptr pyJitFuture = c10::make_intrusive(NoneType::get()); - std::weak_ptr wp = futureResponseMessage; - futureResponseMessage->addCallback( - at::wrapPropagateTLSState([wp, jitFuture]() { - auto futureResponseMessage = wp.lock(); - if (futureResponseMessage->hasError()) { - jitFuture->setError( - std::make_exception_ptr(*futureResponseMessage->error())); + std::weak_ptr wp = messageJitFuture; + messageJitFuture->addCallback( + at::wrapPropagateTLSState([wp, pyJitFuture]() { + auto future = wp.lock(); + if (future->hasError()) { + pyJitFuture->setError(future->exception_ptr()); } else { - jitFuture->markCompleted(IValue()); + pyJitFuture->markCompleted(IValue()); } })); - return jitFuture; + return pyJitFuture; } } @@ -179,7 +184,7 @@ c10::intrusive_ptr pyRpcBuiltin( py::gil_scoped_release release; auto scriptCall = std::make_unique(op, std::move(stack)); auto agent = RpcAgent::getCurrentRpcAgent(); - return wrapFutureMessageInJitFuture(sendMessageWithAutograd( + return toPyJitFuture(sendMessageWithAutograd( *agent, dst, std::move(*scriptCall).toMessage(), @@ -200,7 +205,7 @@ c10::intrusive_ptr pyRpcPythonUdf( std::move(serializedPyObj), isAsyncExecution); auto agent = RpcAgent::getCurrentRpcAgent(); - return wrapFutureMessageInJitFuture(sendMessageWithAutograd( + return toPyJitFuture(sendMessageWithAutograd( *agent, dst, std::move(*pythonCall).toMessage(), @@ -268,20 +273,19 @@ PyRRef pyRemoteBuiltin( auto scriptRemoteCall = std::make_unique( op, std::move(stack), userRRef->rrefId(), userRRef->forkId()); - auto fm = sendMessageWithAutograd( + auto jitFuture = sendMessageWithAutograd( *agent, dst, std::move(*scriptRemoteCall).toMessage(), /*forceGradRecord */ false, /* timeout */ rpcTimeoutSeconds); - userRRef->registerOwnerCreationFuture(fm); + userRRef->registerOwnerCreationFuture(jitFuture); ctx.addPendingUser(userRRef->forkId(), userRRef); - std::weak_ptr wp = fm; - fm->addCallback( + std::weak_ptr wp = jitFuture; + jitFuture->addCallback( at::wrapPropagateTLSState([wp, forkId{userRRef->forkId()}]() { - auto fm = wp.lock(); - callback::confirmPendingUser(*fm, forkId); + callback::confirmPendingUser(*wp.lock(), forkId); })); return PyRRef(userRRef); } else { @@ -291,22 +295,20 @@ PyRRef pyRemoteBuiltin( auto scriptRemoteCall = std::make_unique( op, std::move(stack), ownerRRef->rrefId(), ownerRRef->rrefId()); - auto fm = sendMessageWithAutograd( + auto jitFuture = sendMessageWithAutograd( *agent, dst, std::move(*scriptRemoteCall).toMessage(), /* forceGradRecord */ false, /* timeout */ rpcTimeoutSeconds); - ownerRRef->registerOwnerCreationFuture(fm); - + ownerRRef->registerOwnerCreationFuture(jitFuture); // Builtin operators does not return py::object, and hence does not require // GIL for destructing the potentially deleted OwerRRef. - std::weak_ptr wp = fm; - fm->addCallback(at::wrapPropagateTLSState( + std::weak_ptr wp = jitFuture; + jitFuture->addCallback(at::wrapPropagateTLSState( [wp, ownerRRefId = ownerRRef->rrefId()]() { - auto fm = wp.lock(); - callback::finishCreatingOwnerRRef(*fm, ownerRRefId); + callback::finishCreatingOwnerRRef(*wp.lock(), ownerRRefId); })); return PyRRef(ownerRRef); } @@ -325,7 +327,7 @@ PyRRef pyRemotePythonUdf( if (ctx.getWorkerId() != dst.id_) { auto userRRef = ctx.createUserRRef(dst.id_, PyObjectType::get()); - auto fm = sendPythonRemoteCall( + auto jitFuture = sendPythonRemoteCall( dst, std::move(serializedPyObj), userRRef->rrefId().toIValue(), @@ -333,14 +335,12 @@ PyRRef pyRemotePythonUdf( rpcTimeoutSeconds, isAsyncExecution); - userRRef->registerOwnerCreationFuture(fm); - + userRRef->registerOwnerCreationFuture(jitFuture); ctx.addPendingUser(userRRef->forkId(), userRRef); - std::weak_ptr wp = fm; - fm->addCallback( + std::weak_ptr wp = jitFuture; + jitFuture->addCallback( at::wrapPropagateTLSState([wp, forkId{userRRef->forkId()}]() { - auto fm = wp.lock(); - callback::confirmPendingUser(*fm, forkId); + callback::confirmPendingUser(*wp.lock(), forkId); })); return PyRRef(userRRef); } else { @@ -348,7 +348,7 @@ PyRRef pyRemotePythonUdf( auto ownerRRef = ctx.createOwnerRRef(PyObjectType::get()); // prevent this owner RRef being deleted due to other forks ctx.addSelfAsFork(ownerRRef); - auto fm = sendPythonRemoteCall( + auto jitFuture = sendPythonRemoteCall( dst, std::move(serializedPyObj), ownerRRef->rrefId().toIValue(), @@ -356,13 +356,12 @@ PyRRef pyRemotePythonUdf( rpcTimeoutSeconds, isAsyncExecution); - ownerRRef->registerOwnerCreationFuture(fm); - std::weak_ptr wp = fm; - fm->addCallback(at::wrapPropagateTLSState( + ownerRRef->registerOwnerCreationFuture(jitFuture); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback(at::wrapPropagateTLSState( [wp, ownerRRefId = ownerRRef->rrefId()]() { - auto fm = wp.lock(); auto deletedRRef = - callback::finishCreatingOwnerRRef(*fm, ownerRRefId); + callback::finishCreatingOwnerRRef(*wp.lock(), ownerRRefId); if (deletedRRef && deletedRRef->isPyObj()) { py::gil_scoped_acquire ag; deletedRRef.reset(); diff --git a/torch/csrc/distributed/rpc/python_functions.h b/torch/csrc/distributed/rpc/python_functions.h index 56c0910968289..15bc0b2af8a0a 100644 --- a/torch/csrc/distributed/rpc/python_functions.h +++ b/torch/csrc/distributed/rpc/python_functions.h @@ -9,16 +9,16 @@ namespace torch { namespace distributed { namespace rpc { -// Converts an internal FutureMessage type into a user-facing FutureIValue type -// by creating a new FutureIValue and call its markCompleted as a callback in -// the given FutureMessage. +// Converts an internal ivalue::Future of Message into a user-facing +// ivalue::Future of py::object type by creating a new ivalue::Future and call +// its markCompleted as a callback in the given ivalue::Future. // If hasValue is true, the Message will be converted into a py::object and then -// wrap it with an IValue. If hasValue is false, this FutureIValue is only used -// for signaling and launching callbacks. In this case, the message will be -// discarded and then set the FutureIValue using an empty IValue or the given +// wrap it with an IValue. If hasValue is false, this ivalue::Future is only +// used for signaling and launching callbacks. In this case, the message will be +// discarded and then set the ivalue::Future using an empty IValue or the given // FutureError if there is an error. -c10::intrusive_ptr wrapFutureMessageInJitFuture( - const std::shared_ptr& futureResponseMessage, +c10::intrusive_ptr toPyJitFuture( + const std::shared_ptr& messageJitFuture, bool hasValue = true); c10::intrusive_ptr pyRpcBuiltin( diff --git a/torch/csrc/distributed/rpc/request_callback.cpp b/torch/csrc/distributed/rpc/request_callback.cpp index 44b7cb6eb2e5b..33703c9235230 100644 --- a/torch/csrc/distributed/rpc/request_callback.cpp +++ b/torch/csrc/distributed/rpc/request_callback.cpp @@ -9,8 +9,7 @@ namespace rpc { using namespace torch::distributed::autograd; -std::shared_ptr RequestCallback::operator()( - Message& request) const { +std::shared_ptr RequestCallback::operator()(Message& request) const { // NB: cannot clear autograd context id here because the processMessage method // might pause waiting for all RRefs in the arguments to be confirmed by their // owners and resumne processing in a different thread. Hence, the diff --git a/torch/csrc/distributed/rpc/request_callback.h b/torch/csrc/distributed/rpc/request_callback.h index 95847eb6153ad..128cf9590034e 100644 --- a/torch/csrc/distributed/rpc/request_callback.h +++ b/torch/csrc/distributed/rpc/request_callback.h @@ -12,7 +12,7 @@ namespace rpc { class TORCH_API RequestCallback { public: // Invoke the callback. - std::shared_ptr operator()(Message& request) const; + std::shared_ptr operator()(Message& request) const; virtual ~RequestCallback() {} @@ -24,8 +24,7 @@ class TORCH_API RequestCallback { // message containing an exception. Different rpc agent implementations are // expected to ensure delivery of the response/exception based on their // implementation specific mechanisms. - virtual std::shared_ptr processMessage( - Message& request) const = 0; + virtual std::shared_ptr processMessage(Message& request) const = 0; }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp index b68cb4092b678..684ca5576a563 100644 --- a/torch/csrc/distributed/rpc/request_callback_impl.cpp +++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp @@ -12,6 +12,8 @@ #include #include #include +#include +#include #include #include #include @@ -86,12 +88,12 @@ std::unique_ptr deserializePythonRpcCommandReference( void processAsyncExecution( const py::object& pyFn, const int64_t messageId, - const std::shared_ptr& responseFuture, + const std::shared_ptr& responseFuture, std::function&)> postProcessing) { + const std::shared_ptr&)> postProcessing) { std::shared_ptr pyFuture; auto& pythonRpcHandler = PythonRpcHandler::getInstance(); { @@ -146,11 +148,12 @@ std::unique_ptr RequestCallbackImpl:: } void RequestCallbackImpl::processScriptCall( - ScriptCall& scriptCall, + RpcCommandBase& rpc, const std::function& markComplete, - std::vector& stack, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { + auto& scriptCall = static_cast(rpc); + auto& stack = scriptCall.stackRef(); if (processScriptCallOp(scriptCall, markComplete, stack)) { return; } @@ -173,13 +176,14 @@ void RequestCallbackImpl::processScriptCall( try { Message m = ScriptResp(valueJitFuture->value()).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); - } catch (const std::exception& e) { - responseFuture->setError(e.what()); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); } }); - } catch (const std::exception& e) { - responseFuture->setError(e.what()); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); } }); } else { @@ -192,9 +196,10 @@ void RequestCallbackImpl::processScriptCall( try { Message m = ScriptResp(jitFuture->value()).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); - } catch (const std::exception& e) { - responseFuture->setError(e.what()); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); } }); } @@ -204,7 +209,7 @@ void RequestCallbackImpl::processPythonCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& upc = static_cast(rpc); if (upc.isAsyncExecution()) { try { @@ -215,17 +220,18 @@ void RequestCallbackImpl::processPythonCall( [](const py::object& result, const int64_t messageId, PythonRpcHandler& pythonRpcHandler, - const std::shared_ptr& responseFuture) { + const std::shared_ptr& responseFuture) { auto serializedPyObj = pythonRpcHandler.serialize(result); py::gil_scoped_release release; auto m = std::move(PythonResp(std::move(serializedPyObj))).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); }); } catch (std::exception& e) { - responseFuture->markCompleted( - createExceptionResponse(e.what(), messageId)); + responseFuture->markCompleted(IValue(c10::make_intrusive( + createExceptionResponse(e.what(), messageId)))); } } else { auto& pythonRpcHandler = PythonRpcHandler::getInstance(); @@ -332,7 +338,7 @@ void RequestCallbackImpl::processPythonRemoteCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto& uprc = static_cast(rpc); const auto& rrefId = uprc.rrefId(); @@ -370,20 +376,22 @@ void RequestCallbackImpl::processPythonRemoteCall( const py::object& result, const int64_t messageId, PythonRpcHandler& /* unused */, - const std::shared_ptr& responseFuture) { + const std::shared_ptr& responseFuture) { IValue py_ivalue = jit::toIValue(result, PyObjectType::get()); py::gil_scoped_release release; ownerRRef->setValue(std::move(py_ivalue)); auto m = RemoteRet(rrefId, forkId).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); }); } catch (std::exception& e) { ownerRRef->setError(std::current_exception()); auto m = RemoteRet(rrefId, forkId).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); } } else { IValue py_ivalue; @@ -411,14 +419,14 @@ void RequestCallbackImpl::processPythonRemoteCall( void RequestCallbackImpl::processPythonRRefFetchCall( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { // Making this lambda mutable to allow move-capture it in callbacks auto postProcessing = [responseFuture]( const c10::intrusive_ptr& rref, int64_t messageId) mutable { auto whenValueSet = rref->getFuture(); if (whenValueSet->hasError()) { - responseFuture->setError(whenValueSet->tryRetrieveErrorMessage()); + responseFuture->setError(whenValueSet->exception_ptr()); return; } try { @@ -434,15 +442,17 @@ void RequestCallbackImpl::processPythonRRefFetchCall( Message m = PythonRRefFetchRet(std::move(*result).toIValues()).toMessage(); m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); } catch (py::error_already_set& e) { // py::error_already_set requires GIL to destruct, take special care. - responseFuture->setError(e.what()); + responseFuture->setError( + std::make_exception_ptr(std::runtime_error(e.what()))); py::gil_scoped_acquire acquire; e.restore(); PyErr_Clear(); - } catch (const std::exception& e) { - responseFuture->setError(e.what()); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); } }; @@ -484,7 +494,7 @@ void RequestCallbackImpl::processRpcWithErrors( RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { try { processRpc(rpc, messageType, messageId, responseFuture); } catch (py::error_already_set& e) { @@ -502,6 +512,61 @@ void RequestCallbackImpl::processRpcWithErrors( } } +bool RequestCallbackImpl::cudaAvailable() const { +#ifdef USE_CUDA + return true; +#else + return false; +#endif +} + +void RequestCallbackImpl::processRRefBackward( + RpcCommandBase& rpc, + const int64_t messageId, + const std::shared_ptr& responseFuture) const { + auto& rrefBackwardReq = static_cast(rpc); + + // Get all fields + const auto& rrefId = rrefBackwardReq.getRRefId(); + const auto& autogradContextId = rrefBackwardReq.getAutogradContextId(); + const auto& retainGraph = rrefBackwardReq.retainGraph(); + + auto futureOwner = RRefContext::getInstance().getOwnerRRef(rrefId); + futureOwner->addCallback([responseFuture, + messageId, + futureOwner, + autogradContextId, + retainGraph]() { + const auto& rref = futureOwner->constValue(); + auto whenValueSet = rref->getFuture(); + + whenValueSet->addCallback([responseFuture, + messageId, + rref, + whenValueSet, + autogradContextId, + retainGraph]() { + if (whenValueSet->hasError()) { + responseFuture->setError(whenValueSet->exception_ptr()); + return; + } + + try { + // Run backward (TODO: make this async?). + PyRRef::backward(autogradContextId, retainGraph, rref); + + // Return the response. + Message m = RRefBackwardResp().toMessage(); + m.setId(messageId); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); + } + }); + }); +} + } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/rpc/request_callback_impl.h b/torch/csrc/distributed/rpc/request_callback_impl.h index 0591cc88c7d05..2883359af303c 100644 --- a/torch/csrc/distributed/rpc/request_callback_impl.h +++ b/torch/csrc/distributed/rpc/request_callback_impl.h @@ -18,14 +18,13 @@ class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython { RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const override; + const std::shared_ptr& responseFuture) const override; void processScriptCall( - ScriptCall& scriptCall, + RpcCommandBase& rpc, const std::function& markComplete, - std::vector& stack, const int64_t messageId, - const std::shared_ptr& responseFuture) const override; + const std::shared_ptr& responseFuture) const override; TypePtr getScriptRemoteCallType( ScriptRemoteCall& scriptRemoteCall) const override; @@ -40,12 +39,12 @@ class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython { RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const override; + const std::shared_ptr& responseFuture) const override; void processPythonRRefFetchCall( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const override; + const std::shared_ptr& responseFuture) const override; void handleRRefDelete(c10::intrusive_ptr& rref) const override; @@ -53,7 +52,14 @@ class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython { RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const override; + const std::shared_ptr& responseFuture) const override; + + bool cudaAvailable() const override; + + void processRRefBackward( + RpcCommandBase& rpc, + const int64_t messageId, + const std::shared_ptr& responseFuture) const override; }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp index 9aa1f2b2aa551..09c56dc960c96 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp +++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -8,6 +7,8 @@ #include #include #include +#include +#include #include #include #include @@ -47,13 +48,13 @@ std::unique_ptr RequestCallbackNoPython:: return rpc; } -std::shared_ptr RequestCallbackNoPython::processMessage( +std::shared_ptr RequestCallbackNoPython::processMessage( Message& request) const { // We need two futures here because it could pause twice when processing a // RPC message: // 1) waiting for all RRefs in the arguments to become confirmed; // 2) waiting for processRpc to finish. - auto retFuture = std::make_shared(); + auto retFuture = std::make_shared(at::AnyClassType::get()); auto& rrefContext = RRefContext::getInstance(); try { rrefContext.recordThreadLocalPendingRRefs(); @@ -80,7 +81,7 @@ std::shared_ptr RequestCallbackNoPython::processMessage( if (serverProcessGlobalProfilerStateStackEntryPtr) { // Initialize thread-local profiler state from process-global // profiler state. - ::torch::autograd::profiler::enableProfiler( + ::torch::autograd::profiler::enableProfilerLegacy( serverProcessGlobalProfilerStateStackEntryPtr->statePtr() ->config()); } @@ -92,7 +93,7 @@ std::shared_ptr RequestCallbackNoPython::processMessage( if (serverProcessGlobalProfilerStateStackEntryPtr) { // Restore thread-local profiler state. ::torch::autograd::profiler::thread_event_lists event_lists = - ::torch::autograd::profiler::disableProfiler(); + ::torch::autograd::profiler::disableProfilerLegacy(); // Put thread_local event_lists into the process-global profiler // state. profiler::processglobal::pushResultRecursive( @@ -110,7 +111,7 @@ void RequestCallbackNoPython::processRpcWithErrors( RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { try { processRpc(rpc, messageType, messageId, responseFuture); } catch (std::exception& e) { @@ -119,11 +120,12 @@ void RequestCallbackNoPython::processRpcWithErrors( } void RequestCallbackNoPython::processScriptCall( - ScriptCall& scriptCall, + RpcCommandBase& rpc, const std::function& markComplete, - std::vector& stack, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& /* unused */) const { + auto& scriptCall = static_cast(rpc); + auto& stack = scriptCall.stackRef(); TORCH_CHECK( scriptCall.hasOp(), "Only supports the case where ScriptCall has an op"); processScriptCallOp(scriptCall, markComplete, stack); @@ -159,7 +161,7 @@ void RequestCallbackNoPython::processPythonCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& /* unused */) const { C10_THROW_ERROR(Error, "Python call not supported!"); } @@ -167,7 +169,7 @@ void RequestCallbackNoPython::processPythonRemoteCall( RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& /* unused */) const { C10_THROW_ERROR(Error, "Python call not supported!"); } @@ -181,6 +183,51 @@ void RequestCallbackNoPython::processScriptRemoteCall( processScriptRemoteCallOp(scriptRemoteCall, postProcessing, stack, ownerRRef); } +void RequestCallbackNoPython::processBaseScriptRemoteCall( + RpcCommandBase& rpc, + const std::function& markComplete, + const int64_t messageId, + const std::shared_ptr& responseFuture) const { + auto& scriptRemoteCall = static_cast(rpc); + auto rrefId = scriptRemoteCall.retRRefId(); + auto forkId = scriptRemoteCall.retForkId(); + auto& ctx = RRefContext::getInstance(); + + auto postProcessing = [rrefId, forkId, messageId, responseFuture]() { + if (rrefId != forkId) { + // Caller is a user and callee is the owner, add fork + // + // NB: rrefId == forkId is true if and only if calling remote to + // self. In that case both the caller and the callee will access + // the OwnerRRef. Hence, on the callee side (here), it should not + // call addForkOfOwner as it is not a fork. To allow callee to + // distinguish when this request is sent to self, the caller will + // set forkId using rrefId (OwnerRRef does not have a forkId + // anyway). + RRefContext::getInstance().addForkOfOwner(rrefId, forkId); + } + Message m = RemoteRet(rrefId, forkId).toMessage(); + m.setId(messageId); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); + }; + + // scriptRemoteCall is only alive within this block, use reference to + // avoid copy. If the underlying code runs with a continuation, runAsync() + // below will std::move the appropriate portion of the stack. + TypePtr returnType = getScriptRemoteCallType(scriptRemoteCall); + c10::intrusive_ptr ownerRRef; + if (rrefId == forkId) { + // Creating an owner RRef on self, should already exist in owners map + ownerRRef = ctx.getOwnerRRef(rrefId, /* forceCreated */ true)->constValue(); + } else { + ownerRRef = ctx.getOrCreateOwnerRRef(rrefId, returnType); + } + + auto& stack = scriptRemoteCall.stackRef(); + processScriptRemoteCall(scriptRemoteCall, postProcessing, stack, ownerRRef); +} + bool RequestCallbackNoPython::processScriptRemoteCallOp( ScriptRemoteCall& scriptRemoteCall, const std::function& postProcessing, @@ -209,28 +256,318 @@ bool RequestCallbackNoPython::processScriptRemoteCallOp( return false; } +void RequestCallbackNoPython::processScriptRRefFetchCall( + RpcCommandBase& rpc, + const std::function& markComplete, + const int64_t messageId, + const std::shared_ptr& responseFuture) const { + auto& srf = static_cast(rpc); + auto& ctx = RRefContext::getInstance(); + + auto futureOwner = ctx.getOwnerRRef(srf.rrefId()); + + if (futureOwner->completed()) { // optional fast-path + // the OwnerRRef has been created + const auto& rref = futureOwner->constValue(); + if (rref->hasValue()) { + markComplete(ScriptRRefFetchRet({rref->getValue()}).toMessage()); + return; + } + } + + futureOwner->addCallback([responseFuture, messageId, futureOwner]() { + const auto& rref = futureOwner->constValue(); + auto whenValueSet = rref->getFuture(); + + // Our response is satisfied when the rpc.remote() request + // finishes executing on the owner. + whenValueSet->addCallback( + [responseFuture, messageId, rref, whenValueSet]() { + if (whenValueSet->hasError()) { + responseFuture->setError(whenValueSet->exception_ptr()); + return; + } + try { + Message m = ScriptRRefFetchRet({rref->getValue()}).toMessage(); + m.setId(messageId); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); + } catch (const std::exception& /* unused */) { + responseFuture->setError(std::current_exception()); + } + }); + }); +} + void RequestCallbackNoPython::processPythonRRefFetchCall( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& /* unused */) const { C10_THROW_ERROR(Error, "Python call not supported!"); } +void RequestCallbackNoPython::processRRefUserDelete( + RpcCommandBase& rpc, + const std::function& markComplete) const { + auto& rud = static_cast(rpc); + auto& ctx = RRefContext::getInstance(); + auto deletedRRef = ctx.delForkOfOwner(rud.rrefId(), rud.forkId()); + handleRRefDelete(deletedRRef); + markComplete(std::move(RRefAck()).toMessage()); +} + void RequestCallbackNoPython::handleRRefDelete( c10::intrusive_ptr& rref) const { TORCH_CHECK(!rref->isPyObj(), "RRefs with python objects not supported!"); } +void RequestCallbackNoPython::processRRefChildAccept( + RpcCommandBase& rpc, + const std::function& markComplete) const { + auto& rca = static_cast(rpc); + auto& ctx = RRefContext::getInstance(); + ctx.delPendingChild(rca.forkId()); + markComplete(std::move(RRefAck()).toMessage()); +} + +void RequestCallbackNoPython::processRRefForkRequest( + RpcCommandBase& rpc, + const std::function& markComplete) const { + auto& rfr = static_cast(rpc); + auto& ctx = RRefContext::getInstance(); + ctx.addForkOfOwnerIfNotPresent(rfr.rrefId(), rfr.forkId()); + markComplete(RRefAck().toMessage()); +} + +void RequestCallbackNoPython::processForwardAutogradReq( + RpcCommandBase& rpc, + const int64_t messageId, + const std::shared_ptr& responseFuture) const { + auto& rpcWithAutograd = static_cast(rpc); + + // Attach 'recv' autograd function. + auto autogradContext = addRecvRpcBackward( + rpcWithAutograd.autogradMetadata(), + rpcWithAutograd.tensors(), + rpcWithAutograd.fromWorkerId()); + // For this recv thread on server side, before processRpc(), + // set current_context_id_ to be context_id passed from client. + // In this way, if there is nested rpc call in python rpc call, original + // context_id from client can be passed in the chain calls. + TORCH_INTERNAL_ASSERT( + autogradContext != nullptr, + "autogradContext is nullptr, FORWARD_AUTOGRAD_REQ should always get " + "or create valid autogradContext in addRecvRpcBackward."); + + DistAutogradContextGuard ctxGuard(autogradContext->contextId()); + + // Process the original RPC. + auto wrappedMessageType = rpcWithAutograd.wrappedMessageType(); + // Make an overall future for the wrapped response. + auto wrappedRpcResponseFuture = + std::make_shared(at::AnyClassType::get()); + // Kick off processing for the nested RPC command. + // wrappedRpcResponseFuture will be a Future to the result. + processRpc( + rpcWithAutograd.wrappedRpc(), + wrappedMessageType, + messageId, + wrappedRpcResponseFuture); + + auto fromWorkerId = rpcWithAutograd.fromWorkerId(); + // The original future needs to be marked as completed when the wrapped + // one completes, with the autograd context information wrapped. + // Uses weak_ptr so we can std::move the value. + wrappedRpcResponseFuture->addCallback( + [responseFuture, + messageId, + fromWorkerId, + weak = std::weak_ptr(wrappedRpcResponseFuture), + ctxId = autogradContext->contextId()]() { + // As this callback can be invoked by a different thread, we have to + // make sure that the thread_local states in the previous thread is + // correctly propagated. + // NB: The execution of TorchScript functions can also run on a + // different thread, which is addressed by + // https://github.com/pytorch/pytorch/pull/36395 + // NB: when adding async UDF support, we should also propagate + // thread_local states there. + // TODO: Land on a general solution for RPC ThreadLocalState. See + // https://github.com/pytorch/pytorch/issues/38510 + DistAutogradContextGuard cbCtxGuard(ctxId); + + auto wrappedRpcResponseFuture = weak.lock(); + TORCH_INTERNAL_ASSERT(wrappedRpcResponseFuture); + if (wrappedRpcResponseFuture->hasError()) { + // Propagate error to responseFuture if we had one. + responseFuture->setError(wrappedRpcResponseFuture->exception_ptr()); + } else { + auto msg = getMessageWithAutograd( + fromWorkerId, + std::move( + *wrappedRpcResponseFuture->value().toCustomClass()), + MessageType::FORWARD_AUTOGRAD_RESP); + msg.setId(messageId); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(msg)))); + } + }); +} + +void RequestCallbackNoPython::processBackwardAutogradReq( + RpcCommandBase& rpc, + const int64_t messageId, + const std::shared_ptr& responseFuture) const { + auto& gradientsCall = static_cast(rpc); + const auto& autogradMetadata = gradientsCall.getAutogradMetadata(); + + // Retrieve the appropriate autograd context. + auto autogradContext = DistAutogradContainer::getInstance().retrieveContext( + autogradMetadata.autogradContextId); + + // Lookup the appropriate 'send' function to enqueue. + std::shared_ptr sendFunction = + autogradContext->retrieveSendFunction(autogradMetadata.autogradMessageId); + + // Attach the gradients to the send function. + sendFunction->setGrads(gradientsCall.getGrads()); + + // Now execute the autograd graph using the "distributed engine." + auto execFuture = DistEngine::getInstance().executeSendFunctionAsync( + autogradContext, sendFunction, gradientsCall.retainGraph()); + + // Our response is satisfied when the rpcs come back. + execFuture->addCallback([responseFuture, messageId, execFuture]() { + if (!execFuture->hasError()) { + Message m = std::move(PropagateGradientsResp()).toMessage(); + m.setId(messageId); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); + } else { + responseFuture->setError(execFuture->exception_ptr()); + } + }); +} + +void RequestCallbackNoPython::processCleanupAutogradContextReq( + RpcCommandBase& rpc, + const std::function& markComplete) const { + auto& cleanupContextReq = static_cast(rpc); + auto cleanupContextId = cleanupContextReq.getContextId(); + // release the context if it still exists on this thread. We need to + // check if it exists since it may have been deleted by an in-flight + // RPC. This can create nested RPCs if there are other nodes that get + // notified to clean up their context. + DistAutogradContainer::getInstance().releaseContextIfPresent( + cleanupContextId); + markComplete(std::move(CleanupAutogradContextResp()).toMessage()); +} + +void RequestCallbackNoPython::processRunWithProfilingReq( + RpcCommandBase& rpc, + const int64_t messageId, + const std::shared_ptr& responseFuture) const { + auto& rpcWithProfilingReq = static_cast(rpc); + auto wrappedMsgType = rpcWithProfilingReq.wrappedMessageType(); + auto profilingConfig = rpcWithProfilingReq.getProfilingConfig(); + // If requested with CUDA from caller but CUDA is not available on this + // machine, fallback to CPU and log a warning instead of crashing. + if (profilingConfig.state == torch::autograd::profiler::ProfilerState::CUDA && + !this->cudaAvailable()) { + profilingConfig = torch::autograd::profiler::ProfilerConfig( + torch::autograd::profiler::ProfilerState::CPU, + profilingConfig.report_input_shapes, + profilingConfig.profile_memory); + + LOG(WARNING) << "Profiler was requested to be enabled with CUDA on this " + "node, but CUDA is not available. " + << "Falling back to CPU profiling only."; + } + TORCH_INTERNAL_ASSERT( + profilingConfig.state != torch::autograd::profiler::ProfilerState::CUDA || + this->cudaAvailable(), + "Profiler state set to CUDA but CUDA not available."); + const auto profilingKeyId = rpcWithProfilingReq.getProfilingId(); + auto wrappedRpcResponseFuture = + std::make_shared(at::AnyClassType::get()); + // Enable the profiler with the config from the sender. + // When enabling on the main thread, ensure profiler states are cleaned + // up, but defer consolidation of all profiled events to the continuation + // below. + torch::autograd::profiler::ProfilerDisableOptions requestThreadOptions( + true /* cleanup TLS state */, false /* consolidate events */); + { + torch::autograd::profiler::TLSProfilerGuard g( + profilingConfig, c10::nullopt, requestThreadOptions); + TORCH_INTERNAL_ASSERT( + torch::autograd::profiler::profilerEnabled(), + "Expected profiler to be enabled!"); + // Kick off processing for nested work and get Future result in + // wrappedRpcResponseFuture + processRpc( + rpcWithProfilingReq.wrappedRpc(), + wrappedMsgType, + messageId, + wrappedRpcResponseFuture); + + wrappedRpcResponseFuture->addCallback( + at::wrapPropagateTLSState([wrappedRpcResponseFuture, + responseFuture, + profilingKeyId, + profilingConfig] { + std::vector profiledEvents; + // Defer consolidation of profiler events until async work has + // completed (such as async UDF) + + TORCH_INTERNAL_ASSERT( + torch::autograd::profiler::profilerEnabled(), + "Expected profiler to be enabled!"); + + // On continuation thread, don't clean up profiler states, since + // they will be cleaned up by main thread, and consolidate all + // events so we obtain asynchronously run events. + torch::autograd::profiler::ProfilerDisableOptions opts(false, true); + auto event_lists = + torch::autograd::profiler::disableProfilerLegacy(opts); + if (wrappedRpcResponseFuture->hasError()) { + // Propagate error + // No need to propagate remote events in the case of an error. + responseFuture->setError(wrappedRpcResponseFuture->exception_ptr()); + } else { + populateRemoteProfiledEvents( + profiledEvents, profilingConfig, event_lists); + auto rpcWithProfilingResp = std::make_unique( + MessageType::RUN_WITH_PROFILING_RESP, + std::move(*wrappedRpcResponseFuture->value() + .toCustomClass()), + profiledEvents, + profilingKeyId); + responseFuture->markCompleted(IValue(c10::make_intrusive( + std::move(*rpcWithProfilingResp).toMessage()))); + } + })); + // Exiting the scope will disable the profiler on this thread with the + // options specified above. + } +} + +void RequestCallbackNoPython::processRRefBackward( + RpcCommandBase& rpc, + const int64_t messageId, + const std::shared_ptr& /* unused */) const { + C10_THROW_ERROR(Error, "Python call not supported!"); +} + void RequestCallbackNoPython::processRpc( RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const { + const std::shared_ptr& responseFuture) const { auto markComplete = [messageId, &responseFuture](Message m) { m.setId(messageId); - responseFuture->markCompleted(std::move(m)); + responseFuture->markCompleted( + IValue(c10::make_intrusive(std::move(m)))); }; - // TODO: RpcCommandBase should have an abstract execute() method that we can // call here instead of having another switch statement here. Even better we // could have abstract classes RpcRequest and RpcResp which inherit from @@ -239,12 +576,7 @@ void RequestCallbackNoPython::processRpc( // to a python object. switch (messageType) { case MessageType::SCRIPT_CALL: { - auto& scriptCall = static_cast(rpc); - - // scriptCall is only alive within this block, use reference to avoid copy - auto& stack = scriptCall.stackRef(); - processScriptCall( - scriptCall, markComplete, stack, messageId, responseFuture); + processScriptCall(rpc, markComplete, messageId, responseFuture); return; } case MessageType::PYTHON_CALL: { @@ -252,45 +584,7 @@ void RequestCallbackNoPython::processRpc( return; } case MessageType::SCRIPT_REMOTE_CALL: { - auto& scriptRemoteCall = static_cast(rpc); - auto rrefId = scriptRemoteCall.retRRefId(); - auto forkId = scriptRemoteCall.retForkId(); - auto& ctx = RRefContext::getInstance(); - - auto postProcessing = [rrefId, forkId, messageId, responseFuture]() { - if (rrefId != forkId) { - // Caller is a user and callee is the owner, add fork - // - // NB: rrefId == forkId is true if and only if calling remote to - // self. In that case both the caller and the callee will access - // the OwnerRRef. Hence, on the callee side (here), it should not - // call addForkOfOwner as it is not a fork. To allow callee to - // distinguish when this request is sent to self, the caller will - // set forkId using rrefId (OwnerRRef does not have a forkId - // anyway). - RRefContext::getInstance().addForkOfOwner(rrefId, forkId); - } - Message m = RemoteRet(rrefId, forkId).toMessage(); - m.setId(messageId); - responseFuture->markCompleted(std::move(m)); - }; - - // scriptRemoteCall is only alive within this block, use reference to - // avoid copy. If the underlying code runs with a continuation, runAsync() - // below will std::move the appropriate portion of the stack. - TypePtr returnType = getScriptRemoteCallType(scriptRemoteCall); - c10::intrusive_ptr ownerRRef; - if (rrefId == forkId) { - // Creating an owner RRef on self, should already exist in owners map - ownerRRef = - ctx.getOwnerRRef(rrefId, /* forceCreated */ true)->constValue(); - } else { - ownerRRef = ctx.getOrCreateOwnerRRef(rrefId, returnType); - } - - auto& stack = scriptRemoteCall.stackRef(); - processScriptRemoteCall( - scriptRemoteCall, postProcessing, stack, ownerRRef); + processBaseScriptRemoteCall(rpc, markComplete, messageId, responseFuture); return; } case MessageType::PYTHON_REMOTE_CALL: { @@ -298,43 +592,7 @@ void RequestCallbackNoPython::processRpc( return; } case MessageType::SCRIPT_RREF_FETCH_CALL: { - auto& srf = static_cast(rpc); - auto& ctx = RRefContext::getInstance(); - - auto futureOwner = ctx.getOwnerRRef(srf.rrefId()); - - if (futureOwner->completed()) { // optional fast-path - // the OwnerRRef has been created - const auto& rref = futureOwner->constValue(); - if (rref->hasValue()) { - markComplete(ScriptRRefFetchRet({rref->getValue()}).toMessage()); - return; - } - } - - futureOwner->addCallback([responseFuture, messageId, futureOwner]() { - const auto& rref = futureOwner->constValue(); - auto whenValueSet = rref->getFuture(); - - // Our response is satisfied when the rpc.remote() request - // finishes executing on the owner. - whenValueSet->addCallback( - [responseFuture, messageId, rref, whenValueSet]() { - if (whenValueSet->hasError()) { - responseFuture->setError( - whenValueSet->tryRetrieveErrorMessage()); - return; - } - try { - Message m = ScriptRRefFetchRet({rref->getValue()}).toMessage(); - m.setId(messageId); - responseFuture->markCompleted(std::move(m)); - } catch (const std::exception& e) { - responseFuture->setError(e.what()); - } - }); - }); - + processScriptRRefFetchCall(rpc, markComplete, messageId, responseFuture); return; } case MessageType::PYTHON_RREF_FETCH_CALL: { @@ -342,266 +600,35 @@ void RequestCallbackNoPython::processRpc( return; } case MessageType::RREF_USER_DELETE: { - auto& rud = static_cast(rpc); - auto& ctx = RRefContext::getInstance(); - auto deletedRRef = ctx.delForkOfOwner(rud.rrefId(), rud.forkId()); - handleRRefDelete(deletedRRef); - markComplete(std::move(RRefAck()).toMessage()); + processRRefUserDelete(rpc, markComplete); return; } case MessageType::RREF_CHILD_ACCEPT: { - auto& rca = static_cast(rpc); - auto& ctx = RRefContext::getInstance(); - ctx.delPendingChild(rca.forkId()); - markComplete(std::move(RRefAck()).toMessage()); + processRRefChildAccept(rpc, markComplete); return; } case MessageType::RREF_FORK_REQUEST: { - auto& rfr = static_cast(rpc); - auto& ctx = RRefContext::getInstance(); - ctx.addForkOfOwnerIfNotPresent(rfr.rrefId(), rfr.forkId()); - markComplete(RRefAck().toMessage()); + processRRefForkRequest(rpc, markComplete); return; } case MessageType::FORWARD_AUTOGRAD_REQ: { - auto& rpcWithAutograd = static_cast(rpc); - - // Attach 'recv' autograd function. - auto autogradContext = addRecvRpcBackward( - rpcWithAutograd.autogradMetadata(), - rpcWithAutograd.tensors(), - rpcWithAutograd.fromWorkerId()); - // For this recv thread on server side, before processRpc(), - // set current_context_id_ to be context_id passed from client. - // In this way, if there is nested rpc call in python rpc call, original - // context_id from client can be passed in the chain calls. - auto& autogradContainer = DistAutogradContainer::getInstance(); - TORCH_INTERNAL_ASSERT( - autogradContext != nullptr, - "autogradContext is nullptr, FORWARD_AUTOGRAD_REQ should always get " - "or create valid autogradContext in addRecvRpcBackward."); - - DistAutogradContextGuard ctxGuard(autogradContext->contextId()); - - // Process the original RPC. - auto wrappedMessageType = rpcWithAutograd.wrappedMessageType(); - // Make an overall future for the wrapped response. - auto wrappedRpcResponseFuture = std::make_shared(); - // Kick off processing for the nested RPC command. - // wrappedRpcResponseFuture will be a Future to the result. - processRpc( - rpcWithAutograd.wrappedRpc(), - wrappedMessageType, - messageId, - wrappedRpcResponseFuture); - - auto fromWorkerId = rpcWithAutograd.fromWorkerId(); - // The original future needs to be marked as completed when the wrapped - // one completes, with the autograd context information wrapped. - // Uses weak_ptr so we can std::move the value. - wrappedRpcResponseFuture->addCallback( - [responseFuture, - messageId, - fromWorkerId, - weak = std::weak_ptr(wrappedRpcResponseFuture), - ctxId = autogradContext->contextId()]() { - // As this callback can be invoked by a different thread, we have to - // make sure that the thread_local states in the previous thread is - // correctly propagated. - // NB: The execution of TorchScript functions can also run on a - // different thread, which is addressed by - // https://github.com/pytorch/pytorch/pull/36395 - // NB: when adding async UDF support, we should also propagate - // thread_local states there. - // TODO: Land on a general solution for RPC ThreadLocalState. See - // https://github.com/pytorch/pytorch/issues/38510 - DistAutogradContextGuard cbCtxGuard(ctxId); - - auto wrappedRpcResponseFuture = weak.lock(); - TORCH_INTERNAL_ASSERT(wrappedRpcResponseFuture); - if (wrappedRpcResponseFuture->hasError()) { - // Propagate error to responseFuture if we had one. - responseFuture->setError( - wrappedRpcResponseFuture->error()->what()); - } else { - auto msg = getMessageWithAutograd( - fromWorkerId, - std::move(*wrappedRpcResponseFuture).moveValue(), - MessageType::FORWARD_AUTOGRAD_RESP); - msg.setId(messageId); - responseFuture->markCompleted(std::move(msg)); - } - }); + processForwardAutogradReq(rpc, messageId, responseFuture); return; } case MessageType::BACKWARD_AUTOGRAD_REQ: { - auto& gradientsCall = static_cast(rpc); - const auto& autogradMetadata = gradientsCall.getAutogradMetadata(); - - // Retrieve the appropriate autograd context. - auto autogradContext = - DistAutogradContainer::getInstance().retrieveContext( - autogradMetadata.autogradContextId); - - // Lookup the appropriate 'send' function to enqueue. - std::shared_ptr sendFunction = - autogradContext->retrieveSendFunction( - autogradMetadata.autogradMessageId); - - // Attach the gradients to the send function. - sendFunction->setGrads(gradientsCall.getGrads()); - - // Now execute the autograd graph using the "distributed engine." - auto execFuture = DistEngine::getInstance().executeSendFunctionAsync( - autogradContext, sendFunction, gradientsCall.retainGraph()); - - // Our response is satisfied when the rpcs come back. - execFuture->addCallback([responseFuture, messageId, execFuture]() { - if (!execFuture->hasError()) { - Message m = std::move(PropagateGradientsResp()).toMessage(); - m.setId(messageId); - responseFuture->markCompleted(std::move(m)); - } else { - responseFuture->setError(execFuture->tryRetrieveErrorMessage()); - } - }); + processBackwardAutogradReq(rpc, messageId, responseFuture); return; }; case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: { - auto& cleanupContextReq = static_cast(rpc); - auto cleanupContextId = cleanupContextReq.getContextId(); - // release the context if it still exists on this thread. We need to - // check if it exists since it may have been deleted by an in-flight - // RPC. This can create nested RPCs if there are other nodes that get - // notified to clean up their context. - DistAutogradContainer::getInstance().releaseContextIfPresent( - cleanupContextId); - markComplete(std::move(CleanupAutogradContextResp()).toMessage()); + processCleanupAutogradContextReq(rpc, markComplete); return; } case MessageType::RUN_WITH_PROFILING_REQ: { - auto& rpcWithProfilingReq = static_cast(rpc); - auto wrappedMsgType = rpcWithProfilingReq.wrappedMessageType(); - const auto profilingConfig = rpcWithProfilingReq.getProfilingConfig(); - const auto profilingKeyId = rpcWithProfilingReq.getProfilingId(); - auto wrappedRpcResponseFuture = std::make_shared(); - // Enable the profiler with the config from the sender. - std::vector profiledEvents; - { - torch::autograd::profiler::TLSProfilerGuard g( - profilingConfig, - [&profiledEvents, profilingConfig]( - const std::vector>& event_lists) { - // Gather all events into a vector - for (auto& l : event_lists) { - for (auto& e : l) { - profiledEvents.push_back(e); - } - } - // find __start_profile event and __cuda_start_event. - bool cuda_profiling_enabled = profilingConfig.state == - torch::autograd::profiler::ProfilerState::CUDA; - bool found_cpu_start = false; - const torch::autograd::profiler::Event* profilerStart = nullptr; - // Each device has its own cudaProfilerStart, so we must take - // care to use the correct one depending on the device the - // operation ran on. - std::unordered_map - cudaProfilerStarts; - for (auto& e : profiledEvents) { - if (!found_cpu_start && - 0 == strcmp(e.name(), "__start_profile")) { - profilerStart = &e; - found_cpu_start = true; - } - if (cuda_profiling_enabled && - 0 == strcmp(e.name(), "__cuda_start_event")) { - e.setCudaUs(e.cpu_us()); - auto device = e.device(); - TORCH_CHECK( - device != -1, - "CUDA profiling was enabled but could not find CUDA device."); - TORCH_CHECK( - cudaProfilerStarts.find(device) == - cudaProfilerStarts.end(), - c10::str( - "Duplicate __cuda_start_event found for ", device)); - cudaProfilerStarts[device] = &e; - } - // TODO: determine no. of CUDA devices and break here if we have - // a cudaProfilerStart for all of them, in the case of cuda - // profiling. - if (found_cpu_start && !cuda_profiling_enabled) { - break; - } - } - // We should always find __start_profile. - TORCH_CHECK( - profilerStart != nullptr, - "Expected to find __start_profile event."); - // Should have >= 1 CUDA start event. - // TODO: we can enhance this assert by ensuring we have found a - // start for every available CUDA device. - TORCH_CHECK( - !cuda_profiling_enabled || cudaProfilerStarts.size() > 0, - "Profiler was enabled with CUDA recording, but did not find __cuda_start_event."); - - if (cuda_profiling_enabled) { - // Compute and set global time for when this CUDA kernel was - // launched/ended, since deserialized event will not have a - // corresponding CUDA event. - for (auto& e : profiledEvents) { - if (e.has_cuda()) { - auto cuda_device = e.device(); - TORCH_CHECK( - cuda_device != -1, - "CUDA profiling was enabled but could not find CUDA device."); - auto it = cudaProfilerStarts.find(cuda_device); - TORCH_CHECK( - it != cudaProfilerStarts.end(), - c10::str( - "Failed to find __cuda_start_event for device ", - cuda_device)); - auto cudaProfilerStartEvent = it->second; - double cuda_elapsed_us = - cudaProfilerStartEvent->cuda_elapsed_us(e); - int64_t cuda_us = - cuda_elapsed_us + cudaProfilerStartEvent->cpu_us(); - e.setCudaUs(cuda_us); - } - } - } - }); - TORCH_INTERNAL_ASSERT( - torch::autograd::profiler::profilerEnabled(), - "Expected profiler to be enabled!"); - // Kick off processing for nested work and get Future result in - // wrappedRpcResponseFuture - processRpc( - rpcWithProfilingReq.wrappedRpc(), - wrappedMsgType, - messageId, - wrappedRpcResponseFuture); - } - wrappedRpcResponseFuture->addCallback([wrappedRpcResponseFuture, - responseFuture, - profiledEvents = - std::move(profiledEvents), - profilingKeyId] { - if (wrappedRpcResponseFuture->hasError()) { - // Propagate error - responseFuture->setError(wrappedRpcResponseFuture->error()->what()); - } else { - auto rpcWithProfilingResp = std::make_unique( - MessageType::RUN_WITH_PROFILING_RESP, - std::move(*wrappedRpcResponseFuture).moveValue(), - profiledEvents, - profilingKeyId); - responseFuture->markCompleted( - std::move(*rpcWithProfilingResp).toMessage()); - } - }); + processRunWithProfilingReq(rpc, messageId, responseFuture); + return; + } + case MessageType::RREF_BACKWARD_REQ: { + processRRefBackward(rpc, messageId, responseFuture); return; } default: { @@ -611,7 +638,7 @@ void RequestCallbackNoPython::processRpc( } } -Message RequestCallbackNoPython::handleError( +IValue RequestCallbackNoPython::handleError( const std::exception& e, const MessageType messageType, int64_t messageId) const { @@ -624,7 +651,16 @@ Message RequestCallbackNoPython::handleError( DistAutogradContainer::getInstance().getWorkerId(), ": ", e.what()); - return createExceptionResponse(errorMsg, messageId); + return IValue(c10::make_intrusive( + createExceptionResponse(errorMsg, messageId))); +} + +bool RequestCallbackNoPython::cudaAvailable() const { +#ifdef USE_CUDA + return true; +#else + return false; +#endif } } // namespace rpc diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.h b/torch/csrc/distributed/rpc/request_callback_no_python.h index dd54ea0094174..9932c4744900a 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.h +++ b/torch/csrc/distributed/rpc/request_callback_no_python.h @@ -14,8 +14,7 @@ namespace rpc { // RequestCallback implementation with no Python dependencies. class TORCH_API RequestCallbackNoPython : public RequestCallback { public: - std::shared_ptr processMessage( - Message& request) const override; + std::shared_ptr processMessage(Message& request) const override; protected: virtual std::unique_ptr deserializePythonRpcCommand( @@ -23,11 +22,10 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { const MessageType& messageType) const; virtual void processScriptCall( - ScriptCall& scriptCall, + RpcCommandBase& rpc, const std::function& markComplete, - std::vector& stack, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; bool processScriptCallOp( ScriptCall& scriptCall, @@ -38,7 +36,7 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; virtual TypePtr getScriptRemoteCallType( ScriptRemoteCall& scriptRemoteCall) const; @@ -49,6 +47,12 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { std::vector& stack, const c10::intrusive_ptr& ownerRRef) const; + void processBaseScriptRemoteCall( + RpcCommandBase& rpc, + const std::function& markComplete, + const int64_t messageId, + const std::shared_ptr& responseFuture) const; + bool processScriptRemoteCallOp( ScriptRemoteCall& scriptRemoteCall, const std::function& postProcessing, @@ -59,12 +63,49 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { RpcCommandBase& rpc, const std::function& markComplete, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; + + void processScriptRRefFetchCall( + RpcCommandBase& rpc, + const std::function& markComplete, + const int64_t messageId, + const std::shared_ptr& responseFuture) const; virtual void processPythonRRefFetchCall( RpcCommandBase& rpc, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; + + void processRRefUserDelete( + RpcCommandBase& rpc, + const std::function& markComplete) const; + + void processRRefChildAccept( + RpcCommandBase& rpc, + const std::function& markComplete) const; + + void processRRefForkRequest( + RpcCommandBase& rpc, + const std::function& markComplete) const; + + void processForwardAutogradReq( + RpcCommandBase& rpc, + const int64_t messageId, + const std::shared_ptr& responseFuture) const; + + void processBackwardAutogradReq( + RpcCommandBase& rpc, + const int64_t messageId, + const std::shared_ptr& responseFuture) const; + + void processCleanupAutogradContextReq( + RpcCommandBase& rpc, + const std::function& markComplete) const; + + void processRunWithProfilingReq( + RpcCommandBase& rpc, + const int64_t messageId, + const std::shared_ptr& responseFuture) const; virtual void handleRRefDelete(c10::intrusive_ptr& rref) const; @@ -72,18 +113,25 @@ class TORCH_API RequestCallbackNoPython : public RequestCallback { RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; virtual void processRpcWithErrors( RpcCommandBase& rpc, const MessageType& messageType, const int64_t messageId, - const std::shared_ptr& responseFuture) const; + const std::shared_ptr& responseFuture) const; - Message handleError( + IValue handleError( const std::exception& e, const MessageType messageType, int64_t messageId) const; + + virtual bool cudaAvailable() const; + + virtual void processRRefBackward( + RpcCommandBase& rpc, + const int64_t messageId, + const std::shared_ptr& responseFuture) const; }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/rpc_agent.cpp b/torch/csrc/distributed/rpc/rpc_agent.cpp index 4d9f6db392200..2033b2b771e2a 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.cpp +++ b/torch/csrc/distributed/rpc/rpc_agent.cpp @@ -45,7 +45,7 @@ void RpcAgent::shutdown() { shutdownImpl(); } -std::shared_ptr RpcAgent::sendWithRetries( +std::shared_ptr RpcAgent::sendWithRetries( const WorkerInfo& to, Message&& message, RpcRetryOptions retryOptions) { @@ -57,12 +57,12 @@ std::shared_ptr RpcAgent::sendWithRetries( retryOptions.rpcRetryDuration.count() >= 0, "rpcRetryDuration cannot be negative."); - auto originalFuture = std::make_shared(); + auto originalFuture = std::make_shared(at::AnyClassType::get()); steady_clock_time_point newTime = computeNewRpcRetryTime(retryOptions, /* retryCount */ 0); // Making a copy of the message so it can be retried after the first send. Message msgCopy = message; - auto fm = send(to, std::move(message)); + auto jitFuture = send(to, std::move(message)); auto firstRetryRpc = std::make_shared( to, std::move(msgCopy), @@ -70,13 +70,13 @@ std::shared_ptr RpcAgent::sendWithRetries( /* retryCount */ 0, retryOptions); // Use weak_ptr so that the value can be std::moved in rpcRetryCallback. - fm->addCallback([this, - newTime, - firstRetryRpc, - weak = std::weak_ptr(fm)]() { - auto fm = weak.lock(); - TORCH_INTERNAL_ASSERT(fm); - rpcRetryCallback(fm, newTime, firstRetryRpc); + jitFuture->addCallback([this, + newTime, + firstRetryRpc, + wp = std::weak_ptr(jitFuture)]() { + auto future = wp.lock(); + TORCH_INTERNAL_ASSERT(future); + rpcRetryCallback(future, newTime, firstRetryRpc); }); return originalFuture; @@ -85,11 +85,10 @@ std::shared_ptr RpcAgent::sendWithRetries( void RpcAgent::retryExpiredRpcs() { // Stores the retried futures so callbacks can be added outside the lock. std::vector< - std::pair, std::shared_ptr>> + std::pair, std::shared_ptr>> futures; // Stores futures and exception messages for non-retriable error-ed futures. - std::vector, std::string>> - errorFutures; + std::vector, std::string>> errorFutures; while (rpcAgentRunning_.load()) { std::unique_lock lock(rpcRetryMutex_); @@ -126,15 +125,15 @@ void RpcAgent::retryExpiredRpcs() { auto& earliestRpc = *it; // Making a copy of the message so it can be retried in the future. Message msgCopy = earliestRpc->message_; - std::shared_ptr fm; + std::shared_ptr jitFuture; // send() will throw an exception if an RPC is retried while the agent is // shutdown. We must catch this exception and mark the original future // with an error, since this RPC never succeeded and can no longer be // retried. try { - fm = send(earliestRpc->to_, std::move(msgCopy)); - futures.emplace_back(fm, earliestRpc); + jitFuture = send(earliestRpc->to_, std::move(msgCopy)); + futures.emplace_back(jitFuture, earliestRpc); } catch (std::exception& e) { // We must store the futures and exception messages here and only mark // the futures with an error after releasing the lock. @@ -158,20 +157,20 @@ void RpcAgent::retryExpiredRpcs() { // We attach callbacks to the futures outside of the lock to prevent // potential deadlocks. for (const auto& it : futures) { - auto fm = it.first; + auto jitFuture = it.first; auto earliestRpc = it.second; steady_clock_time_point newTime = computeNewRpcRetryTime( earliestRpc->options_, earliestRpc->retryCount_); earliestRpc->retryCount_++; // Use weak_ptr so that the value can be std::moved in rpcRetryCallback. - fm->addCallback([this, - newTime, - earliestRpc, - weak = std::weak_ptr(fm)]() { - auto fm = weak.lock(); - TORCH_INTERNAL_ASSERT(fm); - rpcRetryCallback(fm, newTime, earliestRpc); + jitFuture->addCallback([this, + newTime, + earliestRpc, + wp = std::weak_ptr(jitFuture)]() { + auto future = wp.lock(); + TORCH_INTERNAL_ASSERT(future); + rpcRetryCallback(future, newTime, earliestRpc); }); } futures.clear(); @@ -181,17 +180,18 @@ void RpcAgent::retryExpiredRpcs() { for (const auto& it : errorFutures) { auto errorFuture = it.first; auto errorMsg = it.second; - errorFuture->setError(errorMsg); + errorFuture->setError( + std::make_exception_ptr(std::runtime_error(errorMsg))); } errorFutures.clear(); } } void RpcAgent::rpcRetryCallback( - const std::shared_ptr& futureMessage, + const std::shared_ptr& jitFuture, steady_clock_time_point newTime, std::shared_ptr earliestRpc) { - if (futureMessage->hasError()) { + if (jitFuture->hasError()) { // Adding one since we want to include the original send as well and not // just the retry count. LOG(INFO) << "Send try " << (earliestRpc->retryCount_ + 1) << " failed"; @@ -203,7 +203,7 @@ void RpcAgent::rpcRetryCallback( "RPC Agent is no longer running on Node ", RpcAgent::getWorkerInfo().id_, ". Cannot retry message."); - earliestRpc->originalFuture_->setError(*futureMessage->error()); + earliestRpc->originalFuture_->setError(jitFuture->exception_ptr()); } else if (earliestRpc->retryCount_ < earliestRpc->options_.maxRetries) { // If the previous future completed with an error and we haven't // completed maxRetries send attempts, we move the earliestRpc @@ -223,12 +223,12 @@ void RpcAgent::rpcRetryCallback( "The RPC has not succeeded after the specified number of max retries (", earliestRpc->options_.maxRetries, ")."); - earliestRpc->originalFuture_->setError(errorMessage); + earliestRpc->originalFuture_->setError( + std::make_exception_ptr(std::runtime_error(errorMessage))); } } else { // This try succeeded, so we can make the original future as complete. - earliestRpc->originalFuture_->markCompleted( - std::move(*futureMessage).moveValue()); + earliestRpc->originalFuture_->markCompleted(jitFuture->value()); } } diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h index 34b77a085510f..bfc6c38c07a1f 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.h +++ b/torch/csrc/distributed/rpc/rpc_agent.h @@ -105,7 +105,7 @@ struct TORCH_API RpcRetryInfo { RpcRetryInfo( const WorkerInfo& to, Message&& message, - std::shared_ptr originalFuture, + std::shared_ptr originalFuture, int retryCount, RpcRetryOptions options) : to_(to), @@ -117,7 +117,7 @@ struct TORCH_API RpcRetryInfo { const WorkerInfo& to_; Message message_; // Future that is returned to the caller of sendWithRetries(). - std::shared_ptr originalFuture_; + std::shared_ptr originalFuture_; // Number of send attempts completed so far. int retryCount_; RpcRetryOptions options_; @@ -151,13 +151,13 @@ class TORCH_API RpcAgent { virtual ~RpcAgent(); // Send a message to the ``RpcAgent`` of id ``to`` and returns a - // ``FutureMessage`` ptr. The implementation must be asynchronous, i.e., it + // ``JitFuture`` ptr. The implementation must be asynchronous, i.e., it // cannot block until it receives the response. // - // If ``message.isRequest()`` is true, the ``FutureMessage`` will be + // If ``message.isRequest()`` is true, the ``JitFuture`` will be // completed when the response arrives. For other message types, the Future // should be ignored by the caller. - virtual std::shared_ptr send( + virtual std::shared_ptr send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds = kUnsetRpcTimeout) = 0; @@ -167,14 +167,14 @@ class TORCH_API RpcAgent { // time using an exponential backoff algorithm. // // Sends ``message`` to the ``RpcAgent`` of id ``to`` and returns a - // ``FutureMessage`` ptr, just like send(). Caller can specify the maximum + // ``JitFuture`` ptr, just like send(). Caller can specify the maximum // number of retries for this RPC (default is 5), initial duration between // sends (default is 1000ms), and backoff constant (default is 1.5) by // passing in the RpcRetryOptions struct. This API might end up // executing a method twice on the remote end (it does not guarantee // exactly-once semantics). Therefore, the user must ensure their requests // are idempotent. - std::shared_ptr sendWithRetries( + std::shared_ptr sendWithRetries( const WorkerInfo& to, Message&& message, RpcRetryOptions retryOptions = RpcRetryOptions()); @@ -299,7 +299,7 @@ class TORCH_API RpcAgent { // error and do not retry again. In case 3, we move the RpcRetryInfo struct // to another time point in the map to schedule the RPC for a future send. void rpcRetryCallback( - const std::shared_ptr& message, + const std::shared_ptr& message, steady_clock_time_point newTime, std::shared_ptr earliestRpc); diff --git a/torch/csrc/distributed/rpc/rref_context.cpp b/torch/csrc/distributed/rpc/rref_context.cpp index dd64ee5c9445d..ce257c50a7a4b 100644 --- a/torch/csrc/distributed/rpc/rref_context.cpp +++ b/torch/csrc/distributed/rpc/rref_context.cpp @@ -14,11 +14,12 @@ thread_local bool RRefContext::recording_ = false; namespace callback { void confirmPendingUser( - const FutureMessage& futureMessage, + const JitFuture& jitFuture, const ForkId& expectedForkId) { - if (!futureMessage.hasError()) { - auto msgType = futureMessage.constValue().type(); - auto rpc = deserializeResponse(futureMessage.constValue(), msgType); + if (!jitFuture.hasError()) { + auto msgPtr = jitFuture.constValue().toCustomClass(); + auto msgType = msgPtr->type(); + auto rpc = deserializeResponse(*msgPtr, msgType); auto rr = dynamic_cast(rpc.get()); TORCH_INTERNAL_ASSERT(rr->forkId() == expectedForkId); } else { @@ -34,30 +35,31 @@ void confirmPendingUser( // the user application will use the RRef before the errors are handled. In // this case, errors may not be raised as they have not yet been handled. auto rref_ptr = RRefContext::getInstance().getPendingUser(expectedForkId); - auto errorType = getRPCErrorType(futureMessage); - rref_ptr->handleError(errorType, futureMessage); + auto errorType = getRPCErrorType(jitFuture); + rref_ptr->handleError(errorType, jitFuture); } RRefContext::getInstance().delPendingUser(expectedForkId); } c10::intrusive_ptr finishCreatingOwnerRRef( - const FutureMessage& futureMessage, + const JitFuture& jitFuture, const RRefId& rrefId) { - if (futureMessage.hasError()) { + if (jitFuture.hasError()) { auto& ctx = RRefContext::getInstance(); // We expect to run this callback only after the OwnerRRef has been created, // since this is only invoked when sending to self. auto rref_ptr = ctx.getOwnerRRef(rrefId, /* ensure created */ true)->constValue(); - auto errorType = getRPCErrorType(futureMessage); - rref_ptr->handleError(errorType, futureMessage); + auto errorType = getRPCErrorType(jitFuture); + rref_ptr->handleError(errorType, jitFuture); // OwnerRRefs do not have a forkId, so don't need to assert here. auto deletedRRef = ctx.delForkOfOwner(rref_ptr->rrefId(), rref_ptr->rrefId()); return deletedRRef; } else { - auto msgType = futureMessage.constValue().type(); - auto rpc = deserializeResponse(futureMessage.constValue(), msgType); + auto msgPtr = jitFuture.constValue().toCustomClass(); + auto msgType = msgPtr->type(); + auto rpc = deserializeResponse(*msgPtr, msgType); auto rr = dynamic_cast(rpc.get()); TORCH_INTERNAL_ASSERT( rr->rrefId() == rr->forkId(), @@ -102,10 +104,11 @@ std::vector> RRefContext::destroyInstance( return deletedRRefs; } -void RRefContext::handleException(const FutureMessage& fm) { - if (fm.hasError()) { - VLOG(1) << "Got exception: " << fm.error()->what(); - throw std::runtime_error(fm.error()->what()); +void RRefContext::handleException(const JitFuture& jitFuture) { + if (jitFuture.hasError()) { + auto errMsg = jitFuture.tryRetrieveErrorMessage(); + VLOG(1) << "Got exception: " << errMsg; + throw std::runtime_error(errMsg); } } @@ -209,12 +212,13 @@ void RRefContext::delUser( // which is now idempotent. See the comment at RRefContext::delForkOfOwner // for more details. ++numPendingFutures_; - auto fm = agent_->sendWithRetries( + auto jitFuture = agent_->sendWithRetries( agent_->getWorkerInfo(owner), RRefUserDelete(rrefId, forkId).toMessage()); - fm->addCallback([this](const FutureMessage& fm) { - handleException(fm); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback([this, wp]() { + handleException(*wp.lock()); --numPendingFutures_; }); } @@ -483,21 +487,24 @@ void RRefContext::notifyOwnerAndParentOfFork( // into forks_. Because, there will be no real `UserRRef` associated // with this fork ID. ++numPendingFutures_; - auto fm = agent_->sendWithRetries( + auto jitFuture = agent_->sendWithRetries( agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage()); - fm->addCallback([this](const FutureMessage& fm) { - handleException(fm); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback([this, wp]() { + handleException(*wp.lock()); --numPendingFutures_; }); } else { ++numPendingFutures_; - auto fm = agent_->sendWithRetries( + auto jitFuture = agent_->sendWithRetries( agent_->getWorkerInfo(rref->owner()), RRefForkRequest(rref->rrefId(), forkId).toMessage()); addPendingUser(forkId, rref); - fm->addCallback([this, forkId, parent](const FutureMessage& fm) { - handleException(fm); + + std::weak_ptr wp = jitFuture; + jitFuture->addCallback([this, forkId, parent, wp]() { + handleException(*wp.lock()); this->finishForkRequest(forkId, parent); // Decrease after calling finishForkRequest because, as that creates a new // future, it might otherwise cause the count to briefly go to zero. @@ -676,11 +683,12 @@ void RRefContext::clearRecordedPendingRRefsOnError() { void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) { delPendingUser(forkId); ++numPendingFutures_; - auto fm = agent_->sendWithRetries( + auto jitFuture = agent_->sendWithRetries( agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage()); - fm->addCallback([this](const FutureMessage& fm) { - handleException(fm); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback([this, wp]() { + handleException(*wp.lock()); --numPendingFutures_; }); } diff --git a/torch/csrc/distributed/rpc/rref_context.h b/torch/csrc/distributed/rpc/rref_context.h index cf89980e7f718..1e3537a6dfd3e 100644 --- a/torch/csrc/distributed/rpc/rref_context.h +++ b/torch/csrc/distributed/rpc/rref_context.h @@ -16,16 +16,14 @@ namespace rpc { namespace callback { // It's the callback for RemoteCall. -void TORCH_API confirmPendingUser( - const FutureMessage& futureMessage, - const ForkId& expectedForkId); +void TORCH_API +confirmPendingUser(const JitFuture& jitFuture, const ForkId& expectedForkId); // It's the callback for finishing creating owner rref, it returned deletedRRef, // so that the deletedRRef can be handled under GIL in python_functions.cpp if // deletedRRef contains python object. -c10::intrusive_ptr TORCH_API finishCreatingOwnerRRef( - const FutureMessage& futureMessage, - const RRefId& rrefId); +c10::intrusive_ptr TORCH_API +finishCreatingOwnerRRef(const JitFuture& jitFuture, const RRefId& rrefId); } // namespace callback using torch::utils::Future; @@ -42,7 +40,7 @@ class TORCH_API RRefContext { static std::vector> destroyInstance( bool ignoreRRefLeak = true); - static void handleException(const FutureMessage& fm); + static void handleException(const JitFuture& jitFuture); RRefContext(const RRefContext&) = delete; RRefContext(RRefContext&& other) = delete; diff --git a/torch/csrc/distributed/rpc/rref_impl.cpp b/torch/csrc/distributed/rpc/rref_impl.cpp index 6c6a377a46524..014b49cbe6121 100644 --- a/torch/csrc/distributed/rpc/rref_impl.cpp +++ b/torch/csrc/distributed/rpc/rref_impl.cpp @@ -1,10 +1,10 @@ -#include #include #include #include #include #include #include +#include #include #include @@ -65,23 +65,21 @@ RRefForkData RRef::fork() const { getTypeStr(type_)); } -void RRef::handleError( - RPCErrorType errorType, - const FutureMessage& futMessage) { +void RRef::handleError(RPCErrorType errorType, const JitFuture& jitFuture) { static std::unordered_map< RPCErrorType, - std::function, + std::function, std::hash> errorHandlers = { {RPCErrorType::TIMEOUT, - [this](const FutureMessage& /* unused */) { setTimedOut(); }}, + [this](const JitFuture& /* unused */) { setTimedOut(); }}, {RPCErrorType::INTENTIONAL_FAILURE, - [this](const FutureMessage& /* unused */) { setTimedOut(); }}, - {RPCErrorType::UNKNOWN_ERROR, [](const FutureMessage& fm) { + [this](const JitFuture& /* unused */) { setTimedOut(); }}, + {RPCErrorType::UNKNOWN_ERROR, [](const JitFuture& jitFuture) { // Default error handler - RRefContext::handleException(fm); + RRefContext::handleException(jitFuture); }}}; - errorHandlers.find(errorType)->second(futMessage); + errorHandlers.find(errorType)->second(jitFuture); } ////////////////////////// UserRRef ///////////////////////////////////// @@ -170,7 +168,7 @@ IValue UserRRef::toHere(const float timeoutSeconds) const { // toHere is profiled as a blocking call, and does not execute operations on // the remote node. Hence, don't wrap it with a profiling message since we // don't need the profiler to be enabled remotely. - auto futureResponse = autograd::sendMessageWithAutograd( + auto jitFuture = autograd::sendMessageWithAutograd( *agent, agent->getWorkerInfo(ownerId_), std::move(msgToSend), @@ -181,9 +179,10 @@ IValue UserRRef::toHere(const float timeoutSeconds) const { // TODO: we should ideally be able to interrupt this blocking wait if we check // getTimedOut() and it is true // (https://github.com/pytorch/pytorch/issues/39411). - const Message& message = futureResponse->wait(); - MessageType msgType = message.type(); - auto response = deserializeResponse(message, msgType); + jitFuture->waitAndThrow(); + auto messagePtr = jitFuture->constValue().toCustomClass(); + MessageType msgType = messagePtr->type(); + auto response = deserializeResponse(*messagePtr, msgType); TORCH_INTERNAL_ASSERT( msgType == MessageType::SCRIPT_RREF_FETCH_RET || msgType == MessageType::PYTHON_RREF_FETCH_RET, diff --git a/torch/csrc/distributed/rpc/rref_impl.h b/torch/csrc/distributed/rpc/rref_impl.h index 29aa355908fa9..c7f8122714685 100644 --- a/torch/csrc/distributed/rpc/rref_impl.h +++ b/torch/csrc/distributed/rpc/rref_impl.h @@ -228,12 +228,12 @@ class TORCH_API RRef : public RRefInterface { // node. Note that this is only set when processing requests invoked with // rpc.remote. This is only used to get the future corresponding to the rref // for profiling use cases. - inline void registerOwnerCreationFuture(std::shared_ptr fut) { + inline void registerOwnerCreationFuture(std::shared_ptr fut) { ownerCreationFuture_ = std::move(fut); } // Get the future corresponding to the creation of this rref. - inline std::shared_ptr getOwnerCreationFuture() const { + inline std::shared_ptr getOwnerCreationFuture() const { return ownerCreationFuture_; } @@ -243,7 +243,7 @@ class TORCH_API RRef : public RRefInterface { } // Dispatches an error to the correct handler based on its RPCErrorType. - void handleError(RPCErrorType errorType, const FutureMessage& futMessage); + void handleError(RPCErrorType errorType, const JitFuture& JitFuture); // Send delete UserRRef request to Owner, // if the request hasn't been sent yet. @@ -272,7 +272,7 @@ class TORCH_API RRef : public RRefInterface { // it could be any TypePtr that JIT support, including PyObjectType const TypePtr type_; // Future corresponding to request to create RRef on remote node. - std::shared_ptr ownerCreationFuture_; + std::shared_ptr ownerCreationFuture_; }; // ``UserRRef`` represents a user of an RRef. Besides the ``RRefId``, each user diff --git a/torch/csrc/distributed/rpc/rref_proto.cpp b/torch/csrc/distributed/rpc/rref_proto.cpp index ebf49f0ad23bf..d82d24d9c3e7c 100644 --- a/torch/csrc/distributed/rpc/rref_proto.cpp +++ b/torch/csrc/distributed/rpc/rref_proto.cpp @@ -1,5 +1,5 @@ -#include #include +#include #include #include diff --git a/torch/csrc/distributed/rpc/script_call.cpp b/torch/csrc/distributed/rpc/script_call.cpp index 503fee77aa40b..56047ce76cf88 100644 --- a/torch/csrc/distributed/rpc/script_call.cpp +++ b/torch/csrc/distributed/rpc/script_call.cpp @@ -1,5 +1,5 @@ -#include #include +#include #include namespace torch { diff --git a/torch/csrc/distributed/rpc/script_remote_call.cpp b/torch/csrc/distributed/rpc/script_remote_call.cpp index fcd31a5cf7b78..f04eb9d14e3e2 100644 --- a/torch/csrc/distributed/rpc/script_remote_call.cpp +++ b/torch/csrc/distributed/rpc/script_remote_call.cpp @@ -1,5 +1,5 @@ -#include #include +#include #include #include diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index 11c5408c2c35b..4f56c916cb988 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -10,6 +10,10 @@ #include #include +#ifdef USE_CUDA_NOT_ROCM +#include +#endif + namespace torch { namespace distributed { namespace rpc { @@ -29,49 +33,16 @@ const std::string kClientActiveCalls = "agent.client_active_calls"; const std::string kServerActiveCalls = "agent.server_active_calls"; const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls"; -inline void checkCPUTensor(const torch::Tensor& tensor) { - TORCH_CHECK( - tensor.device() == at::kCPU, - "TensorPipeAgent only supports CPU tensors by default. Sending " - "GPU tensors using RPC requires explicitly configurations using " - "`set_device_map` on `TensorPipeRpcBackendOptions`. Got a tensor " - "with device ", - tensor.device(), - ", but no device map is specified."); -} - -std::vector getDevicesForTensors( - const std::string& remoteName, - const std::vector& tensors, - const std::unordered_map& deviceMaps) { - const auto workerIter = deviceMaps.find(remoteName); - if (workerIter == deviceMaps.end()) { - for (const auto& tensor : tensors) { - checkCPUTensor(tensor); - } - return {}; - } else { - std::vector deviceIndices; - deviceIndices.reserve(tensors.size()); - const auto& deviceMap = workerIter->second; - for (const auto& tensor : tensors) { - const auto deviceIter = deviceMap.find(tensor.device().index()); - if (deviceIter == deviceMap.end()) { - checkCPUTensor(tensor); - deviceIndices.push_back(-1); - } else { - deviceIndices.push_back(deviceIter->second); - } - } - return deviceIndices; - } -} - } // namespace +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) C10_DEFINE_REGISTRY(TensorPipeTransportRegistry, TransportRegistration); -C10_DEFINE_REGISTRY(TensorPipeChannelRegistry, ChannelRegistration); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +C10_DEFINE_REGISTRY(TensorPipeCpuChannelRegistry, CpuChannelRegistration); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +C10_DEFINE_REGISTRY(TensorPipeCudaChannelRegistry, CudaChannelRegistration); std::string TensorPipeAgent::guessUvAddress( tensorpipe::transport::uv::Context& uvContext) { @@ -112,6 +83,15 @@ constexpr int64_t kMultiplexedUvChannelPriority = 100; // The basic channel reuses a transport as a channel, and is thus our fallback. constexpr int64_t kBasicChannelPriority = 0; +#if TENSORPIPE_HAS_CUDA_IPC_CHANNEL && defined(USE_CUDA_NOT_ROCM) +constexpr int64_t kCudaIpcChannelPriority = 300; +#endif + +#ifdef USE_CUDA_NOT_ROCM +constexpr int64_t kCudaXthChannelPriority = 400; +constexpr int64_t kCudaBasicChannelPriority = 100; +#endif + std::unique_ptr makeUvTransport() { auto context = std::make_shared(); std::string address = TensorPipeAgent::guessUvAddress(*context); @@ -121,6 +101,7 @@ std::unique_ptr makeUvTransport() { // The UV transport is implemented using standard TCP connections. It leverages // libuv (https://github.com/libuv/libuv) in order to be cross-platform. +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) C10_REGISTER_CREATOR(TensorPipeTransportRegistry, uv, makeUvTransport); #if TENSORPIPE_HAS_SHM_TRANSPORT @@ -147,26 +128,28 @@ std::unique_ptr makeShmTransport() { // memory (plus UNIX domain sockets to bootstrap the connection and exchange // file descriptors). It is Linux-only due to some advanced features (O_TMPFILE, // eventfd, ...). +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) C10_REGISTER_CREATOR(TensorPipeTransportRegistry, shm, makeShmTransport); #endif -std::unique_ptr makeBasicChannel() { +std::unique_ptr makeBasicChannel() { auto context = std::make_shared(); - return std::make_unique( - ChannelRegistration{std::move(context), kBasicChannelPriority}); + return std::make_unique( + CpuChannelRegistration{std::move(context), kBasicChannelPriority}); } // The basic channel is just a straightforward adapter wrapper that allows any // transport to be used as a channel. -C10_REGISTER_CREATOR(TensorPipeChannelRegistry, basic, makeBasicChannel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +C10_REGISTER_CREATOR(TensorPipeCpuChannelRegistry, basic, makeBasicChannel); #if TENSORPIPE_HAS_CMA_CHANNEL -std::unique_ptr makeCmaChannel() { +std::unique_ptr makeCmaChannel() { auto context = std::make_shared(); - return std::make_unique( - ChannelRegistration{std::move(context), kCmaChannelPriority}); + return std::make_unique( + CpuChannelRegistration{std::move(context), kCmaChannelPriority}); } // The CMA channel uses the Linux cross-memory attach syscalls (process_vm_readv @@ -174,13 +157,14 @@ std::unique_ptr makeCmaChannel() { // process (as long as they belong to the same user and other security // constraints are satisfied). It does, more or less, what GDB does when it's // attached to a running process. -C10_REGISTER_CREATOR(TensorPipeChannelRegistry, cma, makeCmaChannel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +C10_REGISTER_CREATOR(TensorPipeCpuChannelRegistry, cma, makeCmaChannel); #endif constexpr static int kNumUvThreads = 16; -std::unique_ptr makeMultiplexedUvChannel() { +std::unique_ptr makeMultiplexedUvChannel() { std::vector> contexts; std::vector> listeners; for (int laneIdx = 0; laneIdx < kNumUvThreads; ++laneIdx) { @@ -191,8 +175,8 @@ std::unique_ptr makeMultiplexedUvChannel() { } auto context = std::make_shared( std::move(contexts), std::move(listeners)); - return std::make_unique( - ChannelRegistration{std::move(context), kMultiplexedUvChannelPriority}); + return std::make_unique(CpuChannelRegistration{ + std::move(context), kMultiplexedUvChannelPriority}); } // The multiplexed UV channel encapsulates multiple UV transports (each with its @@ -201,11 +185,84 @@ std::unique_ptr makeMultiplexedUvChannel() { // is split in equal chunks and each chunks is sent on a different connection // and thus driven by a different thread. This is needed to reach very high // bandwidths. +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) C10_REGISTER_CREATOR( - TensorPipeChannelRegistry, + TensorPipeCpuChannelRegistry, mpt_uv, makeMultiplexedUvChannel); +#if TENSORPIPE_HAS_CUDA_IPC_CHANNEL && defined(USE_CUDA_NOT_ROCM) + +std::unique_ptr makeCudaIpcChannel() { + auto context = std::make_shared(); + return std::make_unique( + CudaChannelRegistration{std::move(context), kCudaIpcChannelPriority}); +} + +// The cuda_ipc channels use cudaMemcpy to transmit CUDA tensor across processes +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +C10_REGISTER_CREATOR( + TensorPipeCudaChannelRegistry, + cuda_ipc, + makeCudaIpcChannel); + +#endif + +#ifdef USE_CUDA_NOT_ROCM + +std::unique_ptr makeCudaXthChannel() { + auto context = std::make_shared(); + return std::make_unique( + CudaChannelRegistration{std::move(context), kCudaXthChannelPriority}); +} + +// The cuda_xth channel supports same-process GPU-to-GPU comm +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +C10_REGISTER_CREATOR( + TensorPipeCudaChannelRegistry, + cuda_xth, + makeCudaXthChannel); + +std::unique_ptr makeCudaBasicChannel() { + auto context = std::make_shared( + std::make_shared()); + return std::make_unique( + CudaChannelRegistration{std::move(context), kCudaBasicChannelPriority}); +} + +// The cuda_basic is the fallback channel for GPU-to-GPU comm +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +C10_REGISTER_CREATOR( + TensorPipeCudaChannelRegistry, + cuda_basic, + makeCudaBasicChannel); + +#endif + +} // namespace + +namespace { + +// This is a wrapper of CUDAMultiStreamGuard to run in both CUDA-enabled and +// CPU-only environments. When CUDA is not available, all methods are no-ops. +struct MultiStreamGuard { + MultiStreamGuard(const MultiStreamGuard& other) = delete; + MultiStreamGuard(MultiStreamGuard&& other) = delete; + MultiStreamGuard& operator=(const MultiStreamGuard& rhs) = delete; + MultiStreamGuard& operator=(MultiStreamGuard&& rhs) = delete; + +#ifndef USE_CUDA_NOT_ROCM + explicit MultiStreamGuard( + const std::shared_ptr& /* unused */) {} +#else + explicit MultiStreamGuard(const std::shared_ptr& ctx) + : guard(ctx->getReservedStreams()) {} + + private: + at::cuda::CUDAMultiStreamGuard guard; +#endif +}; + } // namespace ////////////////////////// MetricsTracker ///////////////////////////////// @@ -226,6 +283,28 @@ float TensorPipeAgent::TimeSeriesMetricsTracker::computeAverage() const { //////////////////////// TensorpipeRpcAgent ///////////////////////////////// +void TensorPipeAgent::removeFromTimeoutMap( + uint64_t messageId, + steady_clock_time_point expirationTime) { + // Remove entry from timeoutMap_. + { + std::unique_lock lock(timeoutMapMutex_); + auto& timedOutFuturesVector = timeoutMap_[expirationTime]; + for (auto it = timedOutFuturesVector.begin(); + it != timedOutFuturesVector.end(); + it++) { + if (it->messageId == messageId) { + it = timedOutFuturesVector.erase(it); + break; + } + } + + if (timedOutFuturesVector.empty()) { + timeoutMap_.erase(expirationTime); + } + } +} + void TensorPipeAgent::collectNames() { const worker_id_t selfId = workerInfo_.id_; const std::string& selfName = workerInfo_.name_; @@ -258,11 +337,11 @@ void TensorPipeAgent::collectNames() { } TensorPipeAgent::TensorPipeAgent( - const std::shared_ptr<::c10d::Store>& store, + const c10::intrusive_ptr<::c10d::Store>& store, std::string selfName, worker_id_t selfId, int worldSize, - std::shared_ptr processGroup, + c10::intrusive_ptr<::c10d::ProcessGroup> processGroup, TensorPipeRpcBackendOptions opts, std::unique_ptr cb) : RpcAgent( @@ -323,26 +402,42 @@ void TensorPipeAgent::startImpl() { priority, std::move(key), std::move(reg->transport)); } - for (auto& key : TensorPipeChannelRegistry()->Keys()) { - int64_t priority = -1; - if (opts_.channels.has_value()) { - auto iter = - std::find(opts_.channels->begin(), opts_.channels->end(), key); - if (iter == opts_.channels->end()) { - continue; + auto registerChannel = [this](const auto& registry) { + // The registry argument is either TensorPipeCpuChannelRegistry or + // TensorPipeCudaChannelRegistry. + for (auto& key : registry->Keys()) { + int64_t priority = -1; + if (opts_.channels.has_value()) { + auto iter = + std::find(opts_.channels->begin(), opts_.channels->end(), key); + if (iter == opts_.channels->end()) { + continue; + } + // Assign priorities in reverse order of occurrence in the vector, so + // that a channel that comes before another receives a higher priority. + priority = + opts_.channels->size() - 1 - (iter - opts_.channels->begin()); } - // Assign priorities in reverse order of occurrence in the vector, so that - // a channel that comes before another receives a higher priority. - priority = opts_.channels->size() - 1 - (iter - opts_.channels->begin()); - } - std::unique_ptr reg = - TensorPipeChannelRegistry()->Create(key); - if (priority == -1) { - priority = reg->priority; + + // The reg var is either a std::unique_ptr or a + // std::unique_ptr, depending on the type of the + // registry. + auto reg = registry->Create(key); + if (priority == -1) { + priority = reg->priority; + } + context_->registerChannel( + priority, std::move(key), std::move(reg->channel)); } - context_->registerChannel( - priority, std::move(key), std::move(reg->channel)); - } + }; + + registerChannel(TensorPipeCpuChannelRegistry()); + +#ifdef USE_CUDA_NOT_ROCM + + registerChannel(TensorPipeCudaChannelRegistry()); + +#endif listener_ = context_->listen(addresses); @@ -403,26 +498,31 @@ void TensorPipeAgent::onListenerAccepted( void TensorPipeAgent::pipeRead( const std::shared_ptr& pipe, - std::function fn) { + std::function)> fn) noexcept { pipe->readDescriptor([fn{std::move(fn)}, pipe]( const tensorpipe::Error& error, tensorpipe::Message tpMessage) mutable { if (error) { - fn(error, Message()); + fn(error, Message(), nullptr); return; } - TensorpipeReadBuffers tpBuffers = tensorpipeAllocate(tpMessage); + auto ctx = createLazyStreamContext(); + TensorpipeReadBuffers tpBuffers = tensorpipeAllocate(tpMessage, ctx); pipe->read( std::move(tpMessage), [tpBuffers{ std::make_shared(std::move(tpBuffers))}, - fn{std::move(fn)}]( + fn{std::move(fn)}, + ctx{std::move(ctx)}]( const tensorpipe::Error& error, tensorpipe::Message tpMessage) mutable { if (error) { - fn(error, Message()); + fn(error, Message(), nullptr); return; } @@ -431,7 +531,7 @@ void TensorPipeAgent::pipeRead( Message rpcMessage = tensorpipeDeserialize( std::move(tpMessage), std::move(*tpBuffers)); - fn(error, std::move(rpcMessage)); + fn(error, std::move(rpcMessage), std::move(ctx)); }); }); } @@ -439,22 +539,21 @@ void TensorPipeAgent::pipeRead( void TensorPipeAgent::pipeWrite( const std::shared_ptr& pipe, Message&& rpcMessage, - std::function fn) { + std::vector&& devices, + std::shared_ptr ctx, + std::function fn) noexcept { tensorpipe::Message tpMessage; TensorpipeWriteBuffers tpBuffers; - const auto& deviceMaps = - rpcMessage.isRequest() ? opts_.deviceMaps : reverseDeviceMaps_; - auto devices = getDevicesForTensors( - pipe->getRemoteName(), rpcMessage.tensors(), deviceMaps); std::tie(tpMessage, tpBuffers) = - tensorpipeSerialize(std::move(rpcMessage), std::move(devices)); + tensorpipeSerialize(std::move(rpcMessage), std::move(devices), ctx); pipe->write( std::move(tpMessage), [tpBuffers{ std::make_shared(std::move(tpBuffers))}, - fn{std::move(fn)}]( + fn{std::move(fn)}, + ctx{std::move(ctx)}]( const tensorpipe::Error& error, tensorpipe::Message /* unused */) { fn(error); }); @@ -462,8 +561,9 @@ void TensorPipeAgent::pipeWrite( void TensorPipeAgent::sendCompletedResponseMessage( std::shared_ptr& pipe, - std::shared_ptr& futureResponseMessage, - uint64_t messageId) { + std::shared_ptr& futureResponseMessage, + uint64_t messageId, + std::shared_ptr ctx) { if (!rpcAgentRunning_.load()) { LOG(WARNING) << "RPC agent for " << workerInfo_.name_ << " won't send response to request #" << messageId << " to " @@ -475,52 +575,22 @@ void TensorPipeAgent::sendCompletedResponseMessage( << " is sending response to request #" << messageId << " to " << pipe->getRemoteName(); - const c10::optional error = - futureResponseMessage->error(); - Message&& responseMessage = std::move(*futureResponseMessage).moveValue(); - responseMessage.setId(messageId); - if (!error) { - const auto& iter = reverseDeviceMaps_.find(pipe->getRemoteName()); - if (iter == opts_.deviceMaps.end()) { - for (const auto& t : responseMessage.tensors()) { - if (!t.device().is_cpu()) { - responseMessage = createExceptionResponse( - c10::str( - "TensorPipe RPC backend only supports CPU tensors by default," - " please move your tensors to CPU before sending them over " - "RPC, or call `set_device_map` on " - "`TensorPipeRpcBackendOptions` to explicitly configure " - "device mapping. Response device mapping is not available for " - "destination ", - pipe->getRemoteName(), - ", but found tensor on device: ", - t.device()), - responseMessage.id()); - break; - } - } - } else { - const auto& deviceMap = iter->second; - for (const auto& t : responseMessage.tensors()) { - if (!t.device().is_cpu() && - deviceMap.find(t.device().index()) == deviceMap.end()) { - responseMessage = createExceptionResponse( - c10::str( - "TensorPipe RPC backend only supports CPU tensors by default." - " Response device mapping is not available for destination ", - pipe->getRemoteName(), - " for device ", - t.device(), - " but received a tensor on that device."), - responseMessage.id()); - break; - } - } + if (!futureResponseMessage->hasError()) { + Message&& responseMessage = + std::move(*futureResponseMessage->value().toCustomClass()); + responseMessage.setId(messageId); + std::vector devices; + try { + devices = getDevicesForTensors(pipe->getRemoteName(), responseMessage); + } catch (const std::exception& e) { + responseMessage = createExceptionResponse(e.what(), messageId); } pipeWrite( pipe, std::move(responseMessage), + std::move(devices), + std::move(ctx), [this, pipe, messageId](const tensorpipe::Error& error) { if (error) { LOG(WARNING) @@ -538,7 +608,10 @@ void TensorPipeAgent::sendCompletedResponseMessage( } else { pipeWrite( pipe, - createExceptionResponse(error->what(), responseMessage.id()), + createExceptionResponse( + futureResponseMessage->tryRetrieveErrorMessage(), messageId), + /* devices */ {}, + std::move(ctx), [this, pipe, messageId](const tensorpipe::Error& error) { if (error) { LOG(WARNING) @@ -560,14 +633,16 @@ void TensorPipeAgent::respond(std::shared_ptr& pipe) { pipeRead( pipe, [this, pipe]( - const tensorpipe::Error& error, Message&& requestMessage) mutable { + const tensorpipe::Error& error, + Message&& requestMessage, + std::shared_ptr ctx) mutable { if (error) { // FIXME This is not a correct way to check whether this error was // "intentionally" caused by the remote end shutting down. We should // find a better way, Perhaps sending an empty message? if ((error.isOfType() && !rpcAgentRunning_.load()) || - error.isOfType()) { + error.isOfType()) { // This is expected. } else { LOG(WARNING) @@ -593,34 +668,41 @@ void TensorPipeAgent::respond(std::shared_ptr& pipe) { threadPool_.run([this, pipe, messageId, - requestMessage{std::move(requestMessage)}]() mutable { + requestMessage{std::move(requestMessage)}, + ctx{std::move(ctx)}]() mutable { + // create guards again as this function runs on a different thread + MultiStreamGuard guard(ctx); VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is running request #" << messageId << " from " << pipe->getRemoteName() << " in thread pool"; - std::shared_ptr futureResponseMessage; + std::shared_ptr futureResponseMessage; try { futureResponseMessage = cb_->operator()(requestMessage); - } catch (const std::exception& e) { - futureResponseMessage = std::make_shared(); - futureResponseMessage->setError(e.what()); + } catch (const std::exception& /* unused */) { + futureResponseMessage = + std::make_shared(at::AnyClassType::get()); + futureResponseMessage->setError(std::current_exception()); } // Shortcut if immediately done if (futureResponseMessage->completed()) { decreaseCallCount(serverActiveCalls_); sendCompletedResponseMessage( - pipe, futureResponseMessage, messageId); + pipe, futureResponseMessage, messageId, std::move(ctx)); } else { // Not complete yet increaseCallCount(serverActiveAsyncCalls_); - futureResponseMessage->addCallback( - [this, pipe, futureResponseMessage, messageId]() mutable { - decreaseCallCount(serverActiveCalls_); - decreaseCallCount(serverActiveAsyncCalls_); - sendCompletedResponseMessage( - pipe, futureResponseMessage, messageId); - }); + futureResponseMessage->addCallback([this, + pipe, + futureResponseMessage, + messageId, + ctx{std::move(ctx)}]() mutable { + decreaseCallCount(serverActiveCalls_); + decreaseCallCount(serverActiveAsyncCalls_); + sendCompletedResponseMessage( + pipe, futureResponseMessage, messageId, std::move(ctx)); + }); } VLOG(1) << "RPC agent for " << workerInfo_.name_ @@ -630,7 +712,7 @@ void TensorPipeAgent::respond(std::shared_ptr& pipe) { }); } -std::shared_ptr TensorPipeAgent::send( +std::shared_ptr TensorPipeAgent::send( const WorkerInfo& toWorkerInfo, Message&& requestMessage, const float rpcTimeoutSeconds) { @@ -663,14 +745,20 @@ std::shared_ptr TensorPipeAgent::send( ClientPipe& clientPipe = it->second; auto& pendingResponseMessage = clientPipe.pendingResponseMessage_; - auto futureResponseMessage = std::make_shared(); + auto futureResponseMessage = std::make_shared( + reverseDeviceMaps_.empty() && opts_.deviceMaps.empty()); uint64_t messageId = nextMessageID_++; requestMessage.setId(messageId); pendingResponseMessage[messageId] = futureResponseMessage; lock.unlock(); - futureResponseMessage->futMsg.addCallback([this]() { + // Get devices for tensors in the request message. This can throw if device + // maps are not configured properly for this request. + auto devices = + getDevicesForTensors(clientPipe.pipe_->getRemoteName(), requestMessage); + + futureResponseMessage->jitFuture->addCallback([this]() { TORCH_INTERNAL_ASSERT( this->threadPool_.inThreadPool(), "Future marked complete from outside the thread pool"); @@ -687,15 +775,17 @@ std::shared_ptr TensorPipeAgent::send( // documentation, a user-provided timeout of 0 indicates the RPC should never // expire (infinite timeout), so there is no need to track it in the // timeoutMap_. + steady_clock_time_point expirationTime; if (timeout.count() != 0) { // Compute the expiration time for this message based on the timeout - auto expirationTime = computeRpcMessageExpiryTime(timeout); + expirationTime = computeRpcMessageExpiryTime(timeout); // Add the Future to the right vector in the timeoutMap_ { std::unique_lock lock(timeoutMapMutex_); auto& timeoutFuturesVector = timeoutMap_[expirationTime]; - timeoutFuturesVector.emplace_back(futureResponseMessage, timeout); + timeoutFuturesVector.emplace_back( + messageId, futureResponseMessage, timeout); } timeoutThreadCV_.notify_one(); } @@ -703,10 +793,15 @@ std::shared_ptr TensorPipeAgent::send( VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is sending request #" << messageId << " to " << clientPipe.pipe_->getRemoteName(); + auto ctx = createLazyStreamContext(); + ctx->waitForCurrentStreams(requestMessage.tensors()); pipeWrite( clientPipe.pipe_, std::move(requestMessage), - [this, &clientPipe, messageId](const tensorpipe::Error& error) mutable { + std::move(devices), + std::move(ctx), + [this, &clientPipe, messageId, expirationTime]( + const tensorpipe::Error& error) mutable { if (error) { if (error.isOfType() && !rpcAgentRunning_.load()) { @@ -731,8 +826,10 @@ std::shared_ptr TensorPipeAgent::send( pipeRead( clientPipe.pipe_, - [this, &clientPipe]( - const tensorpipe::Error& error, Message&& responseMessage) { + [this, &clientPipe, expirationTime]( + const tensorpipe::Error& error, + Message&& responseMessage, + std::shared_ptr ctx) { if (error) { if (error.isOfType() && !rpcAgentRunning_.load()) { @@ -757,6 +854,9 @@ std::shared_ptr TensorPipeAgent::send( for (auto& p : pendingMsgs) { markFutureWithError(std::move(p.second), errorMsg); } + + // Remove entry from timeoutMap_. + removeFromTimeoutMap(responseMessage.id(), expirationTime); return; } @@ -767,7 +867,7 @@ std::shared_ptr TensorPipeAgent::send( << " received response #" << messageId << " from " << clientPipe.pipe_->getRemoteName(); - std::shared_ptr futureResponseMessage; + std::shared_ptr futureResponseMessage; { std::lock_guard lock(mutex_); // A read error will lead all following callbacks to be @@ -784,6 +884,9 @@ std::shared_ptr TensorPipeAgent::send( clientPipe.pendingResponseMessage_.erase(it); } + // Remove entry from timeoutMap_. + removeFromTimeoutMap(messageId, expirationTime); + if (responseMessage.type() == MessageType::EXCEPTION) { markFutureWithError( std::move(futureResponseMessage), @@ -793,13 +896,13 @@ std::shared_ptr TensorPipeAgent::send( } else { markFutureAsComplete( std::move(futureResponseMessage), - std::move(responseMessage)); + std::move(responseMessage), + std::move(ctx)); } }); }); - return std::shared_ptr( - futureResponseMessage, &futureResponseMessage->futMsg); + return futureResponseMessage->jitFuture; } void TensorPipeAgent::pollTimeoutRpcs() { @@ -827,10 +930,8 @@ void TensorPipeAgent::pollTimeoutRpcs() { // Move all these futures to a separate vector so we can process them // outside the lock. - std::vector, - std::chrono::milliseconds>> - timedOutFutures = std::move(timeoutMap_.begin()->second); + std::vector timedOutFutures = + std::move(timeoutMap_.begin()->second); // We can safely remove this key from the timeoutMap_ since all these // futures will be processed. timeoutMap_.erase(timeoutMap_.begin()); @@ -840,11 +941,12 @@ void TensorPipeAgent::pollTimeoutRpcs() { // Set an error on futures added to the timedOutFutures vector. We do this // outside the lock to prevent potential lock-order-inversions by callbacks // triggered by the setError call. - for (auto& futureTimeoutPair : timedOutFutures) { + for (auto& timeoutMetadata : timedOutFutures) { std::string errorMsg = - fmt::format(kRpcTimeoutErrorStr, futureTimeoutPair.second.count()); + fmt::format(kRpcTimeoutErrorStr, timeoutMetadata.timeout.count()); auto err = makeRPCError(errorMsg, RPCErrorType::TIMEOUT); - markFutureWithError(std::move(futureTimeoutPair.first), std::move(err)); + markFutureWithError( + std::move(timeoutMetadata.responseFuture), std::move(err)); } } } @@ -1046,16 +1148,20 @@ void TensorPipeAgent::decreaseCallCount(int32_t& count) { } void TensorPipeAgent::markFutureAsComplete( - std::shared_ptr futureMessage, - Message message) { - if (!futureMessage->isComplete.test_and_set()) { + std::shared_ptr atomicFuture, + Message message, + std::shared_ptr ctx) { + if (!atomicFuture->isComplete.test_and_set()) { // Completing the future will run its callbacks, which could execute // arbitrary user code. To prevent blocking or stalling the TensorPipe event // loops, we defer this to a worker thread. threadPool_.run([this, - futureMessage{std::move(futureMessage)}, - message{std::move(message)}]() mutable { - futureMessage->futMsg.markCompleted(std::move(message)); + atomicFuture{std::move(atomicFuture)}, + message{std::move(message)}, + ctx{std::move(ctx)}]() mutable { + MultiStreamGuard guard(ctx); + atomicFuture->jitFuture->markCompleted( + IValue(c10::make_intrusive(std::move(message)))); // The future's callbacks may schedule further RPCs, increasing the count. // Thus we must decrease it after completing the future, otherwise it may // briefly dip to zero and trick join into thinking all work is done. @@ -1065,16 +1171,17 @@ void TensorPipeAgent::markFutureAsComplete( } void TensorPipeAgent::markFutureWithError( - std::shared_ptr futureMessage, + std::shared_ptr atomicFuture, std::string errorMsg) { - if (!futureMessage->isComplete.test_and_set()) { + if (!atomicFuture->isComplete.test_and_set()) { // Completing the future will run its callbacks, which could execute // arbitrary user code. To prevent blocking or stalling the TensorPipe event // loops, we defer this to a worker thread. threadPool_.run([this, - futureMessage{std::move(futureMessage)}, + atomicFuture{std::move(atomicFuture)}, errorMsg{std::move(errorMsg)}]() mutable { - futureMessage->futMsg.setError(std::move(errorMsg)); + atomicFuture->jitFuture->setError( + std::make_exception_ptr(std::runtime_error(errorMsg))); // The future's callbacks may schedule further RPCs, increasing the count. // Thus we must decrease it after completing the future, otherwise it may // briefly dip to zero and trick join into thinking all work is done. @@ -1083,6 +1190,72 @@ void TensorPipeAgent::markFutureWithError( } } +std::vector TensorPipeAgent::getDevicesForTensors( + const std::string& remoteName, + const Message& message) const { + const auto& deviceMaps = + message.isRequest() ? opts_.deviceMaps : reverseDeviceMaps_; + + const auto errStr = c10::str( + "TensorPipe RPC backend only supports CPU tensors by default, please " + "move your tensors to CPU before sending them over RPC, or call " + "`set_device_map` on `TensorPipeRpcBackendOptions` to explicitly " + "configure device mapping. ", + message.isRequest() ? "Request" : "Response", + " device mapping is not available for destination ", + remoteName); + + const auto& iter = deviceMaps.find(remoteName); + if (iter == deviceMaps.end()) { + for (const auto& t : message.tensors()) { + TORCH_CHECK( + t.device().is_cpu(), + errStr, + ", but found tensor on device: ", + t.device()); + } + return {}; + } else { + std::vector deviceIndices; + deviceIndices.reserve(message.tensors().size()); + const auto& deviceMap = iter->second; + bool hasCudaTensor = false; + for (const auto& t : message.tensors()) { + if (t.device().is_cpu()) { + deviceIndices.push_back(-1); + } else { + const auto deviceIter = deviceMap.find(t.device().index()); + TORCH_CHECK( + deviceIter != deviceMap.end(), + errStr, + " for device ", + t.device(), + " but received a tensor on that device."); + deviceIndices.push_back(deviceIter->second); + hasCudaTensor = true; + } + } + if (!hasCudaTensor) { + deviceIndices.clear(); + } + return deviceIndices; + } +} + +size_t TensorPipeAgent::timeoutMapSize() { + std::unique_lock lock(timeoutMapMutex_); + return timeoutMap_.size(); +} + +size_t TensorPipeAgent::numPendingResponses() { + size_t totalPending = 0; + std::unique_lock lock(mutex_); + for (const auto& entry : connectedPipes_) { + totalPending += entry.second.pendingResponseMessage_.size(); + } + return totalPending; +} + } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.h b/torch/csrc/distributed/rpc/tensorpipe_agent.h index a3df040112855..b7117475e1f37 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.h +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.h @@ -2,6 +2,7 @@ #ifdef USE_TENSORPIPE +#include #include #include @@ -9,14 +10,24 @@ #include #include #include +#include #include +#ifdef USE_CUDA_NOT_ROCM +#include +#endif + // Forward-declare the TensorPipe classes we need, to avoid including its // headers in PyTorch's ones and thus have it become a public dependency. namespace tensorpipe { class CpuBuffer; + +#ifdef USE_CUDA_NOT_ROCM +class CudaBuffer; +#endif + class Context; class Error; class Listener; @@ -34,6 +45,11 @@ namespace channel { template class Context; using CpuContext = Context; + +#ifdef USE_CUDA_NOT_ROCM +using CudaContext = Context; +#endif + } // namespace channel using DeviceMap = std::unordered_map; @@ -44,6 +60,8 @@ namespace torch { namespace distributed { namespace rpc { +struct LazyStreamContext; + using steady_clock_time_point = std::chrono::time_point; @@ -53,14 +71,26 @@ struct TransportRegistration { std::string address; }; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) C10_DECLARE_REGISTRY(TensorPipeTransportRegistry, TransportRegistration); -struct ChannelRegistration { +struct CpuChannelRegistration { std::shared_ptr channel; int64_t priority; }; -C10_DECLARE_REGISTRY(TensorPipeChannelRegistry, ChannelRegistration); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +C10_DECLARE_REGISTRY(TensorPipeCpuChannelRegistry, CpuChannelRegistration); + +struct CudaChannelRegistration { +#ifdef USE_CUDA_NOT_ROCM + std::shared_ptr channel; + int64_t priority; +#endif +}; + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +C10_DECLARE_REGISTRY(TensorPipeCudaChannelRegistry, CudaChannelRegistration); constexpr auto kDefaultNumWorkerThreads = 16; @@ -94,7 +124,8 @@ struct TensorPipeRpcBackendOptions : public RpcBackendOptions { if (channels.has_value()) { for (const std::string& channelName : channels.value()) { TORCH_CHECK( - TensorPipeChannelRegistry()->Has(channelName), + TensorPipeCudaChannelRegistry()->Has(channelName) || + TensorPipeCpuChannelRegistry()->Has(channelName), "Unknown channel: ", channelName); } @@ -141,18 +172,18 @@ struct AggregatedNetworkData { class TensorPipeAgent : public RpcAgent { public: TensorPipeAgent( - const std::shared_ptr<::c10d::Store>& store, + const c10::intrusive_ptr<::c10d::Store>& store, std::string selfName, worker_id_t selfId, int worldSize, - std::shared_ptr processGroup, + c10::intrusive_ptr<::c10d::ProcessGroup> processGroup, TensorPipeRpcBackendOptions opts, std::unique_ptr cb); TensorPipeAgent(const TensorPipeAgent&) = delete; TensorPipeAgent& operator=(const TensorPipeAgent&) = delete; - std::shared_ptr send( + std::shared_ptr send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds = kUnsetRpcTimeout) override; @@ -191,6 +222,16 @@ class TensorPipeAgent : public RpcAgent { tensorpipe::transport::uv::Context& uvContext); private: + FRIEND_TEST(TestE2ETensorPipe, TestTrainingLoop); + size_t timeoutMapSize(); + size_t numPendingResponses(); + + // Removes the given messageId with the given expirationTime from the + // timeoutMap_. + void removeFromTimeoutMap( + uint64_t messageId, + steady_clock_time_point expirationTime); + // Populates workerIdToInfo_ and workerNameToInfo_ using addressStore_ void collectNames(); @@ -200,14 +241,19 @@ class TensorPipeAgent : public RpcAgent { // by client, and read request messages by server. void pipeRead( const std::shared_ptr&, - std::function); + std::function)>) noexcept; // TensorPipe write function that could be used to write response // messages by server, and write request messages by client. void pipeWrite( const std::shared_ptr&, Message&& message, - std::function); + std::vector&& devices, + std::shared_ptr ctx, + std::function) noexcept; // Callback of listener accept() void onListenerAccepted( @@ -219,8 +265,9 @@ class TensorPipeAgent : public RpcAgent { void sendCompletedResponseMessage( std::shared_ptr& pipe, - std::shared_ptr& futureResponseMessage, - uint64_t messageId); + std::shared_ptr& futureResponseMessage, + uint64_t messageId, + std::shared_ptr ctx); // Collects metrics from successful RPC calls void trackNetworkData( @@ -233,15 +280,53 @@ class TensorPipeAgent : public RpcAgent { uint64_t requestSize, const std::string& destWorkerName); + inline std::vector getDevicesForTensors( + const std::string& remoteName, + const Message& message) const; + +#ifdef USE_CUDA_NOT_ROCM + // An RPC-specific CUDAFuture subclass. It overrides the extractDataPtrs + // function to handle and only handle RPC Messages. + struct TORCH_CUDA_API RpcCUDAFuture final : at::cuda::CUDAFuture { + public: + using at::cuda::CUDAFuture::CUDAFuture; + + protected: + std::vector> extractDataPtrs( + const at::IValue& value) override { + const auto message = value.toCustomClass(); + TORCH_INTERNAL_ASSERT( + message, "Passed a non-Message type to RpcCUDAFuture"); + std::vector> data_ptrs; + for (const auto& tensor : message->tensors()) { + data_ptrs.emplace_back(tensor.storage().data_ptr()); + } + return data_ptrs; + } + }; +#endif + // When a request+response completes, we need to mark the future message as // complete. However, if its timeout has already expired, it already has an // error set. There is no atomic "test-and-set" way to mark a future complete // only if it isn't yet. It does exist for errors (setErrorIfNeeded) but, even // then, it ends up printing a log message, which may worry the user. To solve // both issues we use a separate atomic flag to know the status of the future. - struct AtomicFutureMessage { - FutureMessage futMsg; + struct AtomicJitFuture { + AtomicJitFuture(bool noCuda = true) { +#ifdef USE_CUDA_NOT_ROCM + if (!noCuda) { + jitFuture = std::make_shared(at::AnyClassType::get()); + } else { +#else + { +#endif + jitFuture = std::make_shared(at::AnyClassType::get()); + } + } + std::atomic_flag isComplete = ATOMIC_FLAG_INIT; + std::shared_ptr jitFuture; }; // Maintains state per client pipe to track pending response messages and @@ -254,7 +339,7 @@ class TensorPipeAgent : public RpcAgent { std::shared_ptr pipe_; bool readError_{false}; // Map from Message Request ID's to corresponding futures. - std::unordered_map> + std::unordered_map> pendingResponseMessage_; }; @@ -278,17 +363,27 @@ class TensorPipeAgent : public RpcAgent { // The join method is required to behave like a barrier and perform collective // operations. For simplicity and reliability, we offload this to a process // group, but probably one day we might want to re-implement them using RPCs. - const std::shared_ptr processGroup_; + const c10::intrusive_ptr<::c10d::ProcessGroup> processGroup_; mutable std::mutex mutex_; uint64_t nextMessageID_{0}; + // Metadata used for tracking of whether certain RPCs have timed out or not. + struct TimeoutMessageMetadata { + TimeoutMessageMetadata( + uint64_t messageId_, + std::shared_ptr responseFuture_, + std::chrono::milliseconds timeout_) + : messageId(messageId_), + responseFuture(responseFuture_), + timeout(timeout_) {} + uint64_t messageId; + std::shared_ptr responseFuture; + std::chrono::milliseconds timeout; + }; + // Map to store the expiration times for each message. - std::map< - steady_clock_time_point, - std::vector, - std::chrono::milliseconds>>> + std::map> timeoutMap_; // Thread that will poll the timeoutMap_ for timed out messages and mark them @@ -360,10 +455,11 @@ class TensorPipeAgent : public RpcAgent { // Helpers to set the state of the requests. void markFutureAsComplete( - std::shared_ptr futureMessage, - Message message); + std::shared_ptr atomicFuture, + Message message, + std::shared_ptr ctx); void markFutureWithError( - std::shared_ptr futureMessage, + std::shared_ptr atomicFuture, std::string errorMsg); }; diff --git a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp index 54da8d8c52c1e..1d17e4451372e 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp @@ -2,7 +2,11 @@ #ifdef USE_TENSORPIPE -#include +#ifdef USE_CUDA_NOT_ROCM +#include +#include +#include +#endif #include @@ -37,7 +41,8 @@ inline c10::Device indexToDevice(c10::DeviceIndex index) { std::tuple tensorpipeSerialize( Message&& rpcMessage, - std::vector deviceIndices) { + std::vector deviceIndices, + const std::shared_ptr& ctx) { tensorpipe::Message tpMessage; TensorpipeWriteBuffers buffers; @@ -62,16 +67,7 @@ std::tuple tensorpipeSerialize( tensorpipe::Message::Payload{payloadPtr, buffers.payload.size()}); // Tensors - if (deviceIndices.empty()) { - buffers.tensors = cloneSparseTensors(rpcMessage.tensors()).vec(); - } else { - std::vector tensors; - tensors.reserve(rpcMessage.tensors().size()); - for (const auto& tensor : rpcMessage.tensors()) { - tensors.emplace_back(tensor.cpu()); - } - buffers.tensors = cloneSparseTensors(tensors).vec(); - } + buffers.tensors = cloneSparseTensors(rpcMessage.tensors()).vec(); torch::jit::Pickler pickler([&](const void* buf, size_t sz) -> size_t { buffers.pickle.insert( @@ -88,8 +84,10 @@ std::tuple tensorpipeSerialize( buffers.pickle.data(), buffers.pickle.size()}); const auto& tensorDataVec = pickler.tensorData(); for (size_t i = 0; i < tensorDataVec.size(); ++i) { + // This is different from jit::getWriteableTensorData as it avoids copying + // tensor to CPU. const auto& tensorData = - jit::getWriteableTensorData(tensorDataVec[i]); + jit::getWriteableTensorData(tensorDataVec[i], /* toCpu */ false); // Enforce memory copy if tensor is created from torch::from_blob, means // that the tensor doesn't own the memory. std::string metadata = @@ -107,16 +105,37 @@ std::tuple tensorpipeSerialize( // it uses non-const ptrs even though it doesn't modify them when writing. // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) char* tensorPtr = const_cast(tensorData.data()); - tpMessage.tensors.push_back(tensorpipe::Message::Tensor{ - tensorpipe::CpuBuffer{tensorPtr, tensorData.sizeInBytes()}, - std::move(metadata)}); + if (tensorDataVec[i].device().is_cpu()) { + tpMessage.tensors.push_back(tensorpipe::Message::Tensor{ + tensorpipe::CpuBuffer{tensorPtr, tensorData.sizeInBytes()}, + std::move(metadata)}); +#ifdef USE_CUDA_NOT_ROCM + } else if (tensorDataVec[i].device().is_cuda()) { + auto stream = ctx->getStream(tensorDataVec[i].device().index()); + tpMessage.tensors.push_back(tensorpipe::Message::Tensor{ + tensorpipe::CudaBuffer{ + tensorPtr, tensorData.sizeInBytes(), stream.stream()}, + std::move(metadata)}); + // record tensor data ptrs on TensorPipe streams, so that the tensors + // won't be destructed before TensorPipe finishing sending them. + c10::cuda::CUDACachingAllocator::recordStream( + tensorDataVec[i].storage().data_ptr(), stream); +#endif + } else { + TORCH_CHECK( + false, + "Attempting to send a Tensor with unexpected device type ", + tensorDataVec[i].device()); + } } } return std::make_tuple(std::move(tpMessage), std::move(buffers)); } -TensorpipeReadBuffers tensorpipeAllocate(tensorpipe::Message& tpMessage) { +TensorpipeReadBuffers tensorpipeAllocate( + tensorpipe::Message& tpMessage, + const std::shared_ptr& ctx) { TensorpipeReadBuffers buffers; TORCH_INTERNAL_ASSERT( @@ -153,9 +172,26 @@ TensorpipeReadBuffers tensorpipeAllocate(tensorpipe::Message& tpMessage) { tpMessage.payloads[kTpMessagePickleIdx].data = buffers.pickle.data(); for (auto& tensor : tpMessage.tensors) { - buffers.tensors.emplace_back( - at::getCPUAllocator()->allocate(tensor.buffer.cpu.length)); - tensor.buffer.cpu.ptr = buffers.tensors.back().get(); + if (tensor.buffer.type == tensorpipe::DeviceType::kCpu) { + buffers.tensors.emplace_back( + at::getCPUAllocator()->allocate(tensor.buffer.cpu.length)); + tensor.buffer.cpu.ptr = buffers.tensors.back().get(); +#ifdef USE_CUDA_NOT_ROCM + } else if (tensor.buffer.type == tensorpipe::DeviceType::kCuda) { + auto deviceIndex = std::stoi(tensor.metadata); + auto stream = ctx->getStream(deviceIndex); + // CUDACachingAllocator will call recordStream accordingly on the current + // stream. + at::cuda::CUDAStreamGuard guard(stream); + buffers.tensors.emplace_back( + c10::cuda::CUDACachingAllocator::get()->allocate( + tensor.buffer.cuda.length)); + tensor.buffer.cuda.ptr = buffers.tensors.back().get(); + tensor.buffer.cuda.stream = stream.stream(); +#endif + } else { + TORCH_INTERNAL_ASSERT(false, "Unrecognized TensorPipe buffer type."); + } } return buffers; @@ -186,27 +222,31 @@ Message tensorpipeDeserialize( // No need to pass typeResolver here, as it always processes string and // tensors only torch::jit::Unpickler unpickler( - pickleReadFunc, nullptr, nullptr, tensorReadFunc, {}); + pickleReadFunc, + nullptr, + nullptr, + tensorReadFunc, + {}, + /* use_storage_device*/ true); + auto ival = unpickler.parse_ivalue(); for (auto&& t : ival.toTensorList()) { tensors.emplace_back(std::move(t)); } - // NB: This is a temporary solution. When TensorPipe Tensor.data can point to - // a CUDA memory address, we should directly use CUDACachingAllocator to - // create CUDA buffers in tensorpipeAllocate. for (size_t i = 0; i < message.tensors.size(); ++i) { auto& tensor = message.tensors[i]; if (!tensor.metadata.empty()) { TORCH_INTERNAL_ASSERT( - message.tensors.size() == tensors.size(), - "Number of device indices must match the number of tensors in the " - "RPC message. But got ", - tensors.size(), - " tensors with ", - message.tensors.size(), - " device indices."); - tensors[i] = tensors[i].to(indexToDevice(std::stoi(tensor.metadata))); + tensors[i].device() == indexToDevice(std::stoi(tensor.metadata)), + "Tensor ", + i, + " in message ", + *buffers.id, + " was expected to be received on device ", + tensor.metadata, + ", but got it on ", + tensors[i].device()); } } diff --git a/torch/csrc/distributed/rpc/tensorpipe_utils.h b/torch/csrc/distributed/rpc/tensorpipe_utils.h index a5c6d23bb3b06..a06e5f0f0e306 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_utils.h +++ b/torch/csrc/distributed/rpc/tensorpipe_utils.h @@ -2,8 +2,16 @@ #ifdef USE_TENSORPIPE +#include #include +#ifdef USE_CUDA_NOT_ROCM +#include +#include +#include +#include +#endif + namespace tensorpipe { class Message; } // namespace tensorpipe @@ -12,6 +20,100 @@ namespace torch { namespace distributed { namespace rpc { +#ifdef USE_CUDA_NOT_ROCM +using at::cuda::CUDAStream; +#endif + +// A general device context class for both CPU and CUDA. If CUDA is not +// available, all CUDA-related methods will be no-ops. +struct TORCH_API LazyStreamContext { + LazyStreamContext(const LazyStreamContext& other) = delete; + LazyStreamContext(LazyStreamContext&& other) = delete; + LazyStreamContext& operator=(const LazyStreamContext& rhs) = delete; + LazyStreamContext& operator=(LazyStreamContext&& rhs) & = delete; + + LazyStreamContext() = default; + virtual ~LazyStreamContext() = default; + virtual void waitForCurrentStreams(const std::vector& = {}) {} + +#ifdef USE_CUDA_NOT_ROCM + virtual std::vector getReservedStreams() const { + throw std::runtime_error( + "Attempting to access CUDA streams, but torch is not built with CUDA"); + } + + virtual CUDAStream getStream(c10::DeviceIndex index) { + throw std::runtime_error(c10::str( + "Attempting to access CUDA stream of device ", + index, + ", but torch is not built with CUDA")); + } +#endif +}; + +#ifndef USE_CUDA_NOT_ROCM + +// CUDA is not available, use CPU device context. +inline std::shared_ptr createLazyStreamContext() { + return std::make_shared(); +} + +#else + +// CUDA is available. Implement CUDA-related operations. +struct TORCH_CUDA_API CudaLazyStreamContext : public LazyStreamContext { + using LazyStreamContext::LazyStreamContext; + + // let streams in this context wiat for current streams. + void waitForCurrentStreams( + const std::vector& tensors = {}) override { + for (const auto& tensor : tensors) { + if (tensor.is_cuda()) { + getStream(tensor.device().index()); + } + } + + for (const auto& entry : streams_) { + at::cuda::CUDAEvent event; + event.record(at::cuda::getCurrentCUDAStream(entry.first)); + event.block(entry.second); + } + } + + // get all streams used in this context + std::vector getReservedStreams() const override { + std::vector reservedStreams; + reservedStreams.reserve(streams_.size()); + for (const auto& entry : streams_) { + reservedStreams.push_back(entry.second); + } + return reservedStreams; + } + + // get a stream for the given device. If it is the first time using that + // device, allocate a new stream and store it in the map. + CUDAStream getStream(c10::DeviceIndex index) override { + auto iter = streams_.find(index); + if (iter == streams_.end()) { + auto cudaStream = at::cuda::getStreamFromPool( + /* isHighPriority */ false, /* device */ index); + streams_.emplace(index, cudaStream); + return cudaStream; + } else { + return iter->second; + } + } + + private: + std::unordered_map streams_; +}; + +inline std::shared_ptr createLazyStreamContext() { + return std::make_shared(); +} + +#endif + // A struct that holds pointers that keep alive all the memory that will be // accessed by TensorPipe during a write operation. struct TensorpipeWriteBuffers { @@ -43,14 +145,18 @@ struct TensorpipeReadBuffers { TORCH_API std::tuple tensorpipeSerialize( Message&& rpcMessage, - std::vector devices = {}); + std::vector devices = {}, + const std::shared_ptr& = + std::make_shared()); // Allocate the buffers that will hold the incoming data. They will be managed // by the returned holder, which must be kept alive until the asynchronous read // has finished. Pointers to these buffers will be stored in-place in the // TensorPipe message. -TORCH_API TensorpipeReadBuffers -tensorpipeAllocate(tensorpipe::Message& tpMessage); +TORCH_API TensorpipeReadBuffers tensorpipeAllocate( + tensorpipe::Message& tpMessage, + const std::shared_ptr& ctx = + std::make_shared()); // Convert a TensorPipe message back into an RPC message. This requires the data // to be available and can thus only be performed once the asynchronous read has diff --git a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp index a1be688a285e2..57dbef3a549ba 100644 --- a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp @@ -1,5 +1,5 @@ -#include #include +#include #include namespace torch { @@ -12,7 +12,7 @@ std::string fromVec(const std::vector& vec) { FaultyProcessGroupAgent::FaultyProcessGroupAgent( std::string workerName, - std::shared_ptr pg, + c10::intrusive_ptr<::c10d::ProcessGroup> pg, int numSendRecvThreads, std::chrono::milliseconds rpcTimeout, const std::vector& messagesToFail, @@ -56,7 +56,7 @@ std::unordered_map> FaultyProcessGroupAgent:: return delayMessages; } -std::shared_ptr FaultyProcessGroupAgent::send( +std::shared_ptr FaultyProcessGroupAgent::send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds) { @@ -78,11 +78,11 @@ std::shared_ptr FaultyProcessGroupAgent::send( if (failMessageCountMap_[key] < failNumSends_) { failMessageCountMap_[key]++; lock.unlock(); - auto fm = std::make_shared(); - fm->setError(makeRPCError( + auto jitFuture = std::make_shared(at::AnyClassType::get()); + jitFuture->setError(std::make_exception_ptr(std::runtime_error(makeRPCError( c10::str("Send attempt failed intentionally for ", key), - RPCErrorType::INTENTIONAL_FAILURE)); - return fm; + RPCErrorType::INTENTIONAL_FAILURE)))); + return jitFuture; } else { lock.unlock(); return ProcessGroupAgent::send(to, std::move(message), rpcTimeoutSeconds); diff --git a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h index f240f6847c441..8cbe4c9a137dd 100644 --- a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h +++ b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h @@ -35,7 +35,7 @@ class FaultyProcessGroupAgent : public ProcessGroupAgent { public: FaultyProcessGroupAgent( std::string workerName, - std::shared_ptr pg, + c10::intrusive_ptr pg, int numSendRecvThreads, std::chrono::milliseconds rpcTimeout, const std::vector& messagesToFail, @@ -43,7 +43,7 @@ class FaultyProcessGroupAgent : public ProcessGroupAgent { int failNumSends = 0); // Faulty send function for this class. - std::shared_ptr send( + std::shared_ptr send( const WorkerInfo& to, Message&& message, const float rpcTimeoutSeconds = diff --git a/torch/csrc/distributed/rpc/testing/init.cpp b/torch/csrc/distributed/rpc/testing/init.cpp index cdb67e2ea6b56..17ed268174c50 100644 --- a/torch/csrc/distributed/rpc/testing/init.cpp +++ b/torch/csrc/distributed/rpc/testing/init.cpp @@ -17,16 +17,18 @@ namespace { template using shared_ptr_class_ = py::class_>; -PyObject* faulty_agent_init(PyObject* /* unused */) { +PyObject* faulty_agent_init(PyObject* _unused, PyObject* noargs) { // Add the FaultyProcessGroupAgent and its backend options object to the - // python module torch.distributed.rpc._testing - auto faulty_agent_module = - THPObjectPtr(PyImport_ImportModule("torch.distributed.rpc._testing")); - if (!faulty_agent_module) { + // python module torch._C._distributed_rpc_testing + auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C")); + if (!torch_C_module) { throw python_error(); } - auto module = py::handle(faulty_agent_module).cast(); + auto torch_C_m = py::handle(torch_C_module).cast(); + auto m = torch_C_m.def_submodule( + "_distributed_rpc_testing", "distributed rpc testing bindings"); + auto module = py::handle(m).cast(); // Import the rpc_module so we can subclass ProcessGroupAgent py::module rpc_module = py::module::import("torch.distributed.rpc"); @@ -66,7 +68,7 @@ PyObject* faulty_agent_init(PyObject* /* unused */) { .def( py::init< std::string, - std::shared_ptr<::c10d::ProcessGroup>, + c10::intrusive_ptr<::c10d::ProcessGroup>, int, std::chrono::milliseconds, const std::vector&, @@ -89,12 +91,17 @@ PyObject* faulty_agent_init(PyObject* /* unused */) { py::call_guard()) .def( "get_worker_info", - (const WorkerInfo& (ProcessGroupAgent::*)(void)const) & + (const WorkerInfo& (ProcessGroupAgent::*)(void) const) & RpcAgent::getWorkerInfo, py::call_guard()) .def( "get_worker_info", - (const WorkerInfo& (ProcessGroupAgent::*)(const std::string&)const) & + (const WorkerInfo& (ProcessGroupAgent::*)(const std::string&) const) & + ProcessGroupAgent::getWorkerInfo, + py::call_guard()) + .def( + "get_worker_info", + (const WorkerInfo& (ProcessGroupAgent::*)(worker_id_t id) const) & ProcessGroupAgent::getWorkerInfo, py::call_guard()) .def( @@ -109,10 +116,7 @@ PyObject* faulty_agent_init(PyObject* /* unused */) { } // namespace static PyMethodDef methods[] = { // NOLINT - {"_faulty_agent_init", - (PyCFunction)faulty_agent_init, - METH_NOARGS, - nullptr}, + {"_faulty_agent_init", faulty_agent_init, METH_NOARGS, nullptr}, {nullptr, nullptr, 0, nullptr}}; PyMethodDef* python_functions() { diff --git a/torch/csrc/distributed/rpc/torchscript_functions.cpp b/torch/csrc/distributed/rpc/torchscript_functions.cpp index a9cd006439e88..d19f9b97ddf5f 100644 --- a/torch/csrc/distributed/rpc/torchscript_functions.cpp +++ b/torch/csrc/distributed/rpc/torchscript_functions.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -8,13 +7,14 @@ #include #include #include +#include #include namespace torch { namespace distributed { namespace rpc { -c10::intrusive_ptr rpcTorchscript( +c10::intrusive_ptr rpcTorchscript( const std::string& dstWorkerName, const c10::QualifiedName& qualifiedName, const c10::FunctionSchema& functionSchema, @@ -43,14 +43,14 @@ c10::intrusive_ptr rpcTorchscript( auto scriptCall = std::make_unique( qualifiedName, std::move(stack), isAsyncExecution); auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent(); - auto futMessage = autograd::sendMessageWithAutograd( + auto jitFuture = autograd::sendMessageWithAutograd( *rpcAgentPtr, rpcAgentPtr->getWorkerInfo(dstWorkerName), std::move(*scriptCall).toMessage(), true /*forceGradRecording*/, rpcTimeoutSeconds); - // Get function return type to construct c10::ivalue::Future. + // Get function return type to construct JitFuture. auto returns = functionSchema.returns(); // Script call only allows single IValue returned. TORCH_INTERNAL_ASSERT( @@ -62,15 +62,15 @@ c10::intrusive_ptr rpcTorchscript( // Create a JIT future and pass it to futMessage's callback to set state // of the JIT future. - auto futPtr = c10::make_intrusive(returnType); - std::weak_ptr wp = futMessage; - futMessage->addCallback(at::wrapPropagateTLSState([futPtr, wp]() { - auto futMessage = wp.lock(); - if (futMessage->hasError()) { - c10::ivalue::Future::FutureError jitFutErr(futMessage->error()->what()); - futPtr->setError(std::make_exception_ptr(jitFutErr)); + auto futPtr = c10::make_intrusive(returnType); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback(at::wrapPropagateTLSState([futPtr, wp]() { + auto future = wp.lock(); + if (future->hasError()) { + futPtr->setError(future->exception_ptr()); } else { - futPtr->markCompleted(deserializeRespToIValue(futMessage->constValue())); + futPtr->markCompleted(deserializeRespToIValue( + *future->constValue().toCustomClass())); } })); if (shouldProfile) { @@ -112,21 +112,19 @@ c10::intrusive_ptr remoteTorchscript( userRRefPtr->forkId(), isAsyncExecution); - auto fm = torch::distributed::autograd::sendMessageWithAutograd( + auto jitFuture = torch::distributed::autograd::sendMessageWithAutograd( *rpcAgentPtr, dstWorkerInfo, std::move(*scriptRemoteCall).toMessage(), true /*forceGradRecording*/, rpcTimeoutSeconds /* timeout */); - userRRefPtr->registerOwnerCreationFuture(fm); - + userRRefPtr->registerOwnerCreationFuture(jitFuture); ctx.addPendingUser(userRRefPtr->forkId(), userRRefPtr); - std::weak_ptr wp = fm; - fm->addCallback( + std::weak_ptr wp = jitFuture; + jitFuture->addCallback( at::wrapPropagateTLSState([wp, forkId{userRRefPtr->forkId()}]() { - auto fm = wp.lock(); - callback::confirmPendingUser(*fm, forkId); + callback::confirmPendingUser(*wp.lock(), forkId); })); return userRRefPtr; @@ -142,19 +140,18 @@ c10::intrusive_ptr remoteTorchscript( ownerRRefPtr->rrefId(), isAsyncExecution); - auto fm = torch::distributed::autograd::sendMessageWithAutograd( + auto jitFuture = torch::distributed::autograd::sendMessageWithAutograd( *rpcAgentPtr, dstWorkerInfo, std::move(*scriptRemoteCall).toMessage(), true /*forceGradRecording*/, rpcTimeoutSeconds /* timeout */); - ownerRRefPtr->registerOwnerCreationFuture(fm); - std::weak_ptr wp = fm; - fm->addCallback(at::wrapPropagateTLSState( + ownerRRefPtr->registerOwnerCreationFuture(jitFuture); + std::weak_ptr wp = jitFuture; + jitFuture->addCallback(at::wrapPropagateTLSState( [wp, ownerRRefId = ownerRRefPtr->rrefId()]() { - auto fm = wp.lock(); - callback::finishCreatingOwnerRRef(*fm, ownerRRefId); + callback::finishCreatingOwnerRRef(*wp.lock(), ownerRRefId); })); return ownerRRefPtr; } diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp index 981cfd50f95e9..0f137a72a2522 100644 --- a/torch/csrc/distributed/rpc/utils.cpp +++ b/torch/csrc/distributed/rpc/utils.cpp @@ -9,6 +9,8 @@ #include #include #include +#include +#include #include #include #include @@ -34,7 +36,7 @@ void processRemoteProfiledEvents( "Profiler was expected to be enabled. This can happen in callback " " continutations that run in different threads, and the TLS of the " " profiler was not propagated."); - std::vector events = + std::vector events = rpcWithProfilingResp.getProfiledEvents(); const auto& profilingId = rpcWithProfilingResp.getProfilingId(); auto& remoteProfilerManager = RemoteProfilerManager::getInstance(); @@ -44,7 +46,7 @@ void processRemoteProfiledEvents( std::for_each( events.begin(), events.end(), - [&keyPrefixStr](torch::autograd::profiler::Event& event) { + [&keyPrefixStr](torch::autograd::profiler::LegacyEvent& event) { std::string name = keyPrefixStr + std::string(event.name()); event.setName(at::StringView(name)); }); @@ -56,15 +58,15 @@ void processRemoteProfiledEvents( const std::string kRPCErrorPrefix = std::string("RPCErr"); -RPCErrorType getRPCErrorType(const FutureMessage& fm) { +RPCErrorType getRPCErrorType(const JitFuture& jitFuture) { TORCH_INTERNAL_ASSERT( - fm.hasError(), - "FutureMessage passed to getRPCErrorType does not have an error."); + jitFuture.hasError(), + "JitFuture of Message passed to getRPCErrorType does not have an error."); // Attempt to parse for error string given by makeRPCError, otherwise return // unknown error. // Note that this function expects errors formatted with makeRPCError(). - auto err = std::string(fm.error()->what()); + auto err = jitFuture.tryRetrieveErrorMessage(); size_t pos = err.find(kRPCErrorPrefix); if (pos != std::string::npos) { // Parse the RPCErrorType. @@ -134,6 +136,9 @@ std::unique_ptr deserializeRequest(const Message& request) { case MessageType::RUN_WITH_PROFILING_REQ: { return autograd::RpcWithProfilingReq::fromMessage(request); } + case MessageType::RREF_BACKWARD_REQ: { + return autograd::RRefBackwardReq::fromMessage(request); + } default: { TORCH_INTERNAL_ASSERT( false, "Request type ", request.type(), " not supported."); @@ -198,6 +203,9 @@ std::unique_ptr deserializeResponse( auto wrappedRPC = std::move(rpcWithProfilingResp).moveWrappedRpc(); return wrappedRPC; } + case MessageType::RREF_BACKWARD_RESP: { + return autograd::RRefBackwardResp::fromMessage(response); + } default: { TORCH_INTERNAL_ASSERT( false, "Response type ", response.type(), " not supported."); @@ -371,9 +379,10 @@ std::string wireSerialize( // converts CUDA tensor to cpu and data() might get destructed as we go // out of scope of this loop. auto writeableTensorData = jit::getWriteableTensorData(tensorData[i]); - entries.push_back({c10::to_string(i), - writeableTensorData.data(), - writeableTensorData.sizeInBytes()}); + entries.push_back( + {c10::to_string(i), + writeableTensorData.data(), + writeableTensorData.sizeInBytes()}); } } @@ -501,6 +510,86 @@ std::vector readWrappedPayload( payload.resize(payload.size() - additionalPayloadSize); return tupleElements; } + +void populateRemoteProfiledEvents( + std::vector& profiledEvents, + const torch::autograd::profiler::ProfilerConfig& profilingConfig, + const std::vector>& + eventLists) { + // Gather all events into a vector + for (auto& l : eventLists) { + for (auto& e : l) { + profiledEvents.push_back(e); + } + } + // find __start_profile event and __cuda_start_event. + bool cudaProfilingEnabled = + profilingConfig.state == torch::autograd::profiler::ProfilerState::CUDA; + bool foundCpuStart = false; + const torch::autograd::profiler::LegacyEvent* profilerStart = nullptr; + // Each device has its own cudaProfilerStart, so we must take + // care to use the correct one depending on the device the + // operation ran on. + std::unordered_map + cudaProfilerStarts; + for (auto& e : profiledEvents) { + if (!foundCpuStart && 0 == strcmp(e.name(), "__start_profile")) { + profilerStart = &e; + foundCpuStart = true; + } else if ( + cudaProfilingEnabled && 0 == strcmp(e.name(), "__cuda_start_event")) { + e.setCudaUs(e.cpuUs()); + auto device = e.device(); + TORCH_CHECK( + device != -1, + "CUDA profiling was enabled but could not find CUDA device."); + TORCH_CHECK( + cudaProfilerStarts.find(device) == cudaProfilerStarts.end(), + c10::str("Duplicate __cuda_start_event found for ", device)); + cudaProfilerStarts[device] = &e; + } + + // TODO: determine no. of CUDA devices and break here if we have + // a cudaProfilerStart for all of them, in the case of cuda + // profiling. + if (foundCpuStart && !cudaProfilingEnabled) { + break; + } + } + // We should always find __start_profile. + TORCH_CHECK( + profilerStart != nullptr, "Expected to find __start_profile event."); + // Should have >= 1 CUDA start event if cudaProfilingEnabled. + // TODO: we can enhance this assert by ensuring we have found a + // start for every available CUDA device. + TORCH_CHECK( + !cudaProfilingEnabled || cudaProfilerStarts.size() > 0, + "Profiler was enabled with CUDA recording, but did not find __cuda_start_event."); + + if (cudaProfilingEnabled) { + // Compute and set global time for when this CUDA kernel was + // launched/ended, since deserialized event will not have a + // corresponding CUDA event. + for (auto& e : profiledEvents) { + if (e.hasCuda()) { + auto cudaDevice = e.device(); + TORCH_CHECK( + cudaDevice != -1, + "CUDA profiling was enabled but could not find CUDA device."); + auto it = cudaProfilerStarts.find(cudaDevice); + TORCH_CHECK( + it != cudaProfilerStarts.end(), + c10::str( + "Failed to find __cuda_start_event for device ", cudaDevice)); + auto cudaProfilerStartEvent = it->second; + double cudaElapsedUs = cudaProfilerStartEvent->cudaElapsedUs(e); + int64_t cudaUs = cudaElapsedUs + cudaProfilerStartEvent->cpuUs(); + e.setCudaUs(cudaUs); + } + } + } +} + } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/rpc/utils.h b/torch/csrc/distributed/rpc/utils.h index 806b52208eb0d..4d27daff6ffef 100644 --- a/torch/csrc/distributed/rpc/utils.h +++ b/torch/csrc/distributed/rpc/utils.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -14,7 +15,7 @@ namespace distributed { namespace rpc { // Parse error message and return RPCErrorType based on the message. -TORCH_API RPCErrorType getRPCErrorType(const FutureMessage& fm); +TORCH_API RPCErrorType getRPCErrorType(const JitFuture& jitFuture); // Create an error string given the error description and error type TORCH_API std::string makeRPCError( const std::string& rpcErrorStr, @@ -78,6 +79,14 @@ TORCH_API std::vector readWrappedPayload( std::vector& payload, const rpc::Message& message); +// Takes a list of events from autograd profiler and populates them into +// profiledEvents to be carried over RPC. +TORCH_API void populateRemoteProfiledEvents( + std::vector& profiledEvents, + const torch::autograd::profiler::ProfilerConfig& profilerConfig, + const std::vector>& + eventLists); + } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/csrc/generic/StorageMethods.cpp b/torch/csrc/generic/StorageMethods.cpp index 0b1be66cf355c..0be4a7aa7b272 100644 --- a/torch/csrc/generic/StorageMethods.cpp +++ b/torch/csrc/generic/StorageMethods.cpp @@ -1,4 +1,6 @@ #include +#include +#include #ifdef USE_CUDA #include @@ -10,16 +12,18 @@ #define LSEEK lseek #endif -static PyObject * THPStorage_(size)(THPStorage *self, PyObject *noargs) +static PyObject * THPStorage_(size)(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS - return PyLong_FromLong(self->cdata->nbytes() / sizeof(scalar_t)); + auto self = (THPStorage*)_self; + return THPUtils_packUInt64(self->cdata->nbytes() / sizeof(scalar_t)); END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(dataPtr)(THPStorage *self, PyObject *noargs) +static PyObject * THPStorage_(dataPtr)(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS + auto self = (THPStorage*)_self; return PyLong_FromVoidPtr(THWStorage_(data)(LIBRARY_STATE self->cdata)); END_HANDLE_TH_ERRORS } @@ -31,9 +35,10 @@ static PyObject * THPStorage_(copy_)(PyObject *self, PyObject *args, PyObject *k END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(isPinned)(THPStorage *self, PyObject *noargs) +static PyObject * THPStorage_(isPinned)(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS + auto self = (THPStorage*)_self; #if defined(USE_CUDA) return PyBool_FromLong(at::globalContext().isPinnedPtr(THWStorage_(data)(LIBRARY_STATE self->cdata))); #else @@ -42,16 +47,18 @@ static PyObject * THPStorage_(isPinned)(THPStorage *self, PyObject *noargs) END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(elementSize)(THPStorage *self, PyObject *noargs) +static PyObject * THPStorage_(elementSize)(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS - return PyLong_FromLong(THWStorage_(elementSize)(LIBRARY_STATE_NOARGS)); + auto self = (THPStorage*)_self; + return THPUtils_packInt64(THWStorage_(elementSize)(LIBRARY_STATE_NOARGS)); END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(new)(THPStorage *self, PyObject *noargs) +static PyObject * THPStorage_(new)(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS + auto self = (THPStorage*)_self; THWStoragePtr new_storage(THWStorage_(new)(LIBRARY_STATE_NOARGS)); PyObject *_ret = THPStorage_(New)(new_storage); new_storage.release(); @@ -59,9 +66,10 @@ static PyObject * THPStorage_(new)(THPStorage *self, PyObject *noargs) END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(resize_)(THPStorage *self, PyObject *number_arg) +static PyObject * THPStorage_(resize_)(PyObject *_self, PyObject *number_arg) { HANDLE_TH_ERRORS + auto self = (THPStorage*)_self; THPUtils_assert(THPUtils_checkLong(number_arg), "resize_ expects an int, " "but got %s", THPUtils_typename(number_arg)); int64_t newsize = THPUtils_unpackLong(number_arg); @@ -72,9 +80,10 @@ static PyObject * THPStorage_(resize_)(THPStorage *self, PyObject *number_arg) END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(fill_)(THPStorage *self, PyObject *number_arg) +static PyObject * THPStorage_(fill_)(PyObject *_self, PyObject *number_arg) { HANDLE_TH_ERRORS + auto self = (THPStorage*)_self; THPUtils_assert(THPUtils_(checkReal)(number_arg), "fill_ expects %s, " "but got %s", THPUtils_typeTraits::python_type_str, THPUtils_typename(number_arg)); @@ -218,12 +227,13 @@ static PyObject * THPStorage_(fromFile)(PyObject *_unused, PyObject *args, PyObj END_HANDLE_TH_ERRORS } -PyObject * THPStorage_(writeFile)(THPStorage *self, PyObject *args) +PyObject * THPStorage_(writeFile)(PyObject *_self, PyObject *args) { HANDLE_TH_ERRORS - PyObject *file = PyTuple_GET_ITEM(args, 0); - bool is_real_file = PyTuple_GET_ITEM(args, 1) == Py_True; - bool save_size = PyTuple_GET_ITEM(args, 2) == Py_True; + auto self = (THPStorage*)_self; + PyObject *file = PyTuple_GetItem(args, 0); + bool is_real_file = PyTuple_GetItem(args, 1) == Py_True; + bool save_size = PyTuple_GetItem(args, 2) == Py_True; if (!is_real_file) { THPStorage_(writeFileRaw)(self->cdata, file, save_size); @@ -252,9 +262,10 @@ PyObject * THPStorage_(newWithFile)(PyObject *_unused, PyObject *file) END_HANDLE_TH_ERRORS } -static PyObject *THPStorage_(setFromFile)(THPStorage *self, PyObject *args) +static PyObject *THPStorage_(setFromFile)(PyObject *_self, PyObject *args) { HANDLE_TH_ERRORS + auto self = (THPStorage*)_self; PyObject *file = PyTuple_GET_ITEM(args, 0); PyObject *offset = PyTuple_GET_ITEM(args, 1); bool is_real_file = PyTuple_GET_ITEM(args, 2) == Py_True; @@ -301,17 +312,19 @@ static PyObject *THPStorage_(setFromFile)(THPStorage *self, PyObject *args) } #ifdef THC_GENERIC_FILE -PyObject * THPStorage_(getDevice)(THPStorage *self, PyObject *noargs) +PyObject * THPStorage_(getDevice)(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS - return PyLong_FromLong(THCStorage_(getDevice)(LIBRARY_STATE self->cdata)); + auto self = (THPStorage*)_self; + return THPUtils_packInt32(THCStorage_(getDevice)(LIBRARY_STATE self->cdata)); END_HANDLE_TH_ERRORS } #endif -PyObject * THPStorage_(_setCdata)(THPStorage *self, PyObject *new_cdata) +PyObject * THPStorage_(_setCdata)(PyObject *_self, PyObject *new_cdata) { HANDLE_TH_ERRORS + auto self = (THPStorage*)_self; THPUtils_assert(THPUtils_checkLong(new_cdata), "given an invalid argument to " "_set_cdata - expected an int or long, but got %s", THPUtils_typename(new_cdata)); @@ -325,24 +338,27 @@ PyObject * THPStorage_(_setCdata)(THPStorage *self, PyObject *new_cdata) } static PyMethodDef THPStorage_(methods)[] = { - {"copy_", (PyCFunction)(void(*)(void))THPStorage_(copy_), METH_VARARGS | METH_KEYWORDS, nullptr}, - {"element_size", (PyCFunction)THPStorage_(elementSize), METH_NOARGS, nullptr}, - {"fill_", (PyCFunction)THPStorage_(fill_), METH_O, nullptr}, - {"new", (PyCFunction)THPStorage_(new), METH_NOARGS, nullptr}, - {"resize_", (PyCFunction)THPStorage_(resize_), METH_O, nullptr}, - {"size", (PyCFunction)THPStorage_(size), METH_NOARGS, nullptr}, - {"data_ptr", (PyCFunction)THPStorage_(dataPtr), METH_NOARGS, nullptr}, - {"is_pinned", (PyCFunction)THPStorage_(isPinned), METH_NOARGS, nullptr}, - {"_write_file", (PyCFunction)THPStorage_(writeFile), METH_VARARGS, nullptr}, - {"_new_with_file", (PyCFunction)(void(*)(void))THPStorage_(newWithFile), METH_O | METH_STATIC, nullptr}, - {"_set_from_file", (PyCFunction)THPStorage_(setFromFile), METH_VARARGS, nullptr}, + {"copy_", castPyCFunctionWithKeywords(THPStorage_(copy_)), + METH_VARARGS | METH_KEYWORDS, nullptr}, + {"element_size", THPStorage_(elementSize), METH_NOARGS, nullptr}, + {"fill_", THPStorage_(fill_), METH_O, nullptr}, + {"new", THPStorage_(new), METH_NOARGS, nullptr}, + {"resize_", THPStorage_(resize_), METH_O, nullptr}, + {"size", THPStorage_(size), METH_NOARGS, nullptr}, + {"data_ptr", THPStorage_(dataPtr), METH_NOARGS, nullptr}, + {"is_pinned", THPStorage_(isPinned), METH_NOARGS, nullptr}, + {"_write_file", THPStorage_(writeFile), METH_VARARGS, nullptr}, + {"_new_with_file", THPStorage_(newWithFile), METH_O | METH_STATIC, nullptr}, + {"_set_from_file", THPStorage_(setFromFile), METH_VARARGS, nullptr}, #if !defined(THC_GENERIC_FILE) - {"from_buffer", (PyCFunction)(void(*)(void))THPStorage_(fromBuffer), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"from_buffer", castPyCFunctionWithKeywords(THPStorage_(fromBuffer)), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, #endif - {"from_file", (PyCFunction)(void(*)(void))THPStorage_(fromFile), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"from_file", castPyCFunctionWithKeywords(THPStorage_(fromFile)), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, #ifdef THC_GENERIC_FILE - {"get_device", (PyCFunction)THPStorage_(getDevice), METH_NOARGS, nullptr}, + {"get_device", THPStorage_(getDevice), METH_NOARGS, nullptr}, #endif - {"_set_cdata", (PyCFunction)THPStorage_(_setCdata), METH_O, nullptr}, + {"_set_cdata", THPStorage_(_setCdata), METH_O, nullptr}, {nullptr} }; diff --git a/torch/csrc/generic/StorageSharing.cpp b/torch/csrc/generic/StorageSharing.cpp index 018c272159ffb..ce7d67c6445f6 100644 --- a/torch/csrc/generic/StorageSharing.cpp +++ b/torch/csrc/generic/StorageSharing.cpp @@ -4,11 +4,13 @@ #include #endif +#include #include -static PyObject * THPStorage_(sharedDecref)(THPStorage *self, PyObject *noargs) +static PyObject * THPStorage_(sharedDecref)(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS + auto self = (THPStorage*)_self; #ifndef THC_GENERIC_FILE THWStorage *storage = self->cdata; THManagedMapAllocator *ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr()); @@ -21,9 +23,10 @@ static PyObject * THPStorage_(sharedDecref)(THPStorage *self, PyObject *noargs) END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(sharedIncref)(THPStorage *self, PyObject *noargs) +static PyObject * THPStorage_(sharedIncref)(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS + auto self = (THPStorage*)_self; #ifndef THC_GENERIC_FILE THWStorage *storage = self->cdata; THManagedMapAllocator *ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr()); @@ -69,9 +72,10 @@ static PyObject * THPStorage_(pyNewFilenameStorage)(PyObject *_unused, PyObject END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(shareFilename)(THPStorage *self, PyObject *noargs) +static PyObject * THPStorage_(shareFilename)(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS + auto self = (THPStorage*)_self; THWStorage *storage = self->cdata; THManagedMapAllocator *ctx; // Storage is already in shared memory, just return a handle @@ -92,7 +96,7 @@ static PyObject * THPStorage_(shareFilename)(THPStorage *self, PyObject *noargs) if (!manager_handle) return nullptr; THPObjectPtr storage_handle(PyBytes_FromString(ctx->filename())); if (!storage_handle) return nullptr; - THPObjectPtr size(PyLong_FromLong(storage->nbytes() / sizeof(scalar_t))); + THPObjectPtr size(THPUtils_packUInt64(storage->nbytes() / sizeof(scalar_t))); if (!size) return nullptr; THPObjectPtr tuple(PyTuple_New(3)); @@ -151,9 +155,10 @@ static PyObject * THPStorage_(pyNewFdStorage)(PyObject *_unused, PyObject *args) END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(shareFd)(THPStorage *self, PyObject *noargs) +static PyObject * THPStorage_(shareFd)(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS + auto self = (THPStorage*)_self; THWStorage *storage = self->cdata; THMapAllocator *ctx; // Storage is already in shared memory, just return a handle @@ -168,9 +173,9 @@ static PyObject * THPStorage_(shareFd)(THPStorage *self, PyObject *noargs) AT_ASSERT(ctx); } - THPObjectPtr storage_handle(PyLong_FromLong(ctx->fd())); + THPObjectPtr storage_handle(THPUtils_packInt32(ctx->fd())); if (!storage_handle) return nullptr; - THPObjectPtr size(PyLong_FromLong(storage->nbytes() / sizeof(scalar_t))); + THPObjectPtr size(THPUtils_packUInt64(storage->nbytes() / sizeof(scalar_t))); if (!size) return nullptr; THPObjectPtr tuple(PyTuple_New(2)); @@ -214,9 +219,10 @@ static PyObject * THPStorage_(newSharedFd)(PyObject *_unused, PyObject *args) #else // THC_GENERIC_FILE -static PyObject * THPStorage_(shareCuda)(THPStorage *self, PyObject *noargs) +static PyObject * THPStorage_(shareCuda)(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS + auto self = (THPStorage*)_self; THWStorage *storage = self->cdata; if (storage->received_cuda()) { @@ -226,14 +232,14 @@ static PyObject * THPStorage_(shareCuda)(THPStorage *self, PyObject *noargs) at::DeviceGuard device_guard(storage->device()); THPObjectPtr tuple(PyTuple_New(8)); - THPObjectPtr device(PyLong_FromLong(storage->device().index())); + THPObjectPtr device(THPUtils_packInt32(storage->device().index())); THPObjectPtr _handle(Py_None); Py_INCREF(Py_None); - THPObjectPtr size_bytes(PyLong_FromLong(storage->nbytes())); - THPObjectPtr _offset_bytes(PyLong_FromLong(0)); + THPObjectPtr size_bytes(THPUtils_packUInt64(storage->nbytes())); + THPObjectPtr _offset_bytes(THPUtils_packInt32(0)); THPObjectPtr _ref_counter(Py_None); Py_INCREF(Py_None); - THPObjectPtr _ref_counter_offset(PyLong_FromLong(0)); + THPObjectPtr _ref_counter_offset(THPUtils_packInt32(0)); THPObjectPtr _event_handle(Py_None); Py_INCREF(Py_None); THPObjectPtr _event_sync_required(Py_None); @@ -256,7 +262,7 @@ static PyObject * THPStorage_(shareCuda)(THPStorage *self, PyObject *noargs) auto sent_data = static_cast(storage->data_ptr().get_context()); sent_data->set_original_ptr(std::move(old_data_ptr)); _ref_counter = PyBytes_FromString((sent_data->handle()).c_str()); - _ref_counter_offset = PyLong_FromLong(sent_data->offset()); + _ref_counter_offset = THPUtils_packInt64(sent_data->offset()); cudaIpcEventHandle_t ipc_event_handle; @@ -454,8 +460,9 @@ static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args) // pointer. // // NB: This does NOT preserve object identity when you call it multiple times -static PyObject * THPStorage_(weakRef)(THPStorage *self, PyObject *args) { +static PyObject * THPStorage_(weakRef)(PyObject *_self, PyObject *args) { HANDLE_TH_ERRORS + auto self = (THPStorage*)_self; THStorage* storage = self->cdata; return PyLong_FromVoidPtr(c10::raw::intrusive_ptr::make_weak(storage)); END_HANDLE_TH_ERRORS @@ -498,9 +505,10 @@ PyObject * THPStorage_(expired)(PyObject *_unused, PyObject *arg) END_HANDLE_TH_ERRORS } -PyObject * THPStorage_(sharedFd)(THPStorage *self, PyObject *noargs) +PyObject * THPStorage_(sharedFd)(PyObject *_self, PyObject *noargs) { HANDLE_TH_ERRORS + auto self = (THPStorage*)_self; THMapAllocator *ctx = nullptr; #ifndef THC_GENERIC_FILE THWStorage *storage = self->cdata; @@ -508,12 +516,13 @@ PyObject * THPStorage_(sharedFd)(THPStorage *self, PyObject *noargs) #endif THPUtils_assert(ctx, "couldn't retrieve a shared file descriptor"); - return PyLong_FromLong(ctx->fd()); + return THPUtils_packInt32(ctx->fd()); END_HANDLE_TH_ERRORS } -PyObject * THPStorage_(isShared)(THPStorage *self, PyObject *noargs) +PyObject * THPStorage_(isShared)(PyObject *_self, PyObject *noargs) { + auto self = (THPStorage*)_self; #ifdef THC_GENERIC_FILE Py_RETURN_TRUE; #else @@ -527,25 +536,25 @@ PyObject * THPStorage_(isShared)(THPStorage *self, PyObject *noargs) } static PyMethodDef THPStorage_(sharingMethods)[] = { - {"_new_with_weak_ptr", (PyCFunction)(void(*)(void))THPStorage_(newWithWeakPtr), METH_O | METH_CLASS, nullptr}, + {"_new_with_weak_ptr", THPStorage_(newWithWeakPtr), METH_O | METH_CLASS, nullptr}, #ifdef THC_GENERIC_FILE - {"_share_cuda_", (PyCFunction)THPStorage_(shareCuda), METH_NOARGS, nullptr}, - {"_new_shared_cuda", (PyCFunction)(void(*)(void))THPStorage_(newSharedCuda), METH_VARARGS | METH_STATIC, nullptr}, - {"_release_ipc_counter", (PyCFunction)(void(*)(void))THPStorage_(releaseIPCCounter), METH_VARARGS | METH_STATIC, nullptr}, + {"_share_cuda_", THPStorage_(shareCuda), METH_NOARGS, nullptr}, + {"_new_shared_cuda", THPStorage_(newSharedCuda), METH_VARARGS | METH_STATIC, nullptr}, + {"_release_ipc_counter", THPStorage_(releaseIPCCounter), METH_VARARGS | METH_STATIC, nullptr}, #else - {"_share_fd_", (PyCFunction)THPStorage_(shareFd), METH_NOARGS, nullptr}, - {"_new_shared_fd", (PyCFunction)(void(*)(void))THPStorage_(newSharedFd), METH_VARARGS | METH_STATIC, nullptr}, - {"_new_using_fd", (PyCFunction)(void(*)(void))THPStorage_(pyNewFdStorage), METH_VARARGS | METH_STATIC, nullptr}, - {"_share_filename_", (PyCFunction)THPStorage_(shareFilename), METH_NOARGS, nullptr}, - {"_new_shared_filename", (PyCFunction)(void(*)(void))THPStorage_(newSharedFilename), METH_VARARGS | METH_STATIC, nullptr}, - {"_new_using_filename", (PyCFunction)(void(*)(void))THPStorage_(pyNewFilenameStorage), METH_VARARGS | METH_STATIC, nullptr}, + {"_share_fd_", THPStorage_(shareFd), METH_NOARGS, nullptr}, + {"_new_shared_fd", THPStorage_(newSharedFd), METH_VARARGS | METH_STATIC, nullptr}, + {"_new_using_fd", THPStorage_(pyNewFdStorage), METH_VARARGS | METH_STATIC, nullptr}, + {"_share_filename_", THPStorage_(shareFilename), METH_NOARGS, nullptr}, + {"_new_shared_filename", THPStorage_(newSharedFilename), METH_VARARGS | METH_STATIC, nullptr}, + {"_new_using_filename", THPStorage_(pyNewFilenameStorage), METH_VARARGS | METH_STATIC, nullptr}, #endif - {"_weak_ref", (PyCFunction)THPStorage_(weakRef), METH_NOARGS, nullptr}, - {"_free_weak_ref", (PyCFunction)(void(*)(void))THPStorage_(freeWeakRef), METH_O | METH_STATIC, nullptr}, - {"_expired", (PyCFunction)(void(*)(void))THPStorage_(expired), METH_O | METH_STATIC, nullptr}, - {"_shared_decref", (PyCFunction)THPStorage_(sharedDecref), METH_NOARGS, nullptr}, - {"_shared_incref", (PyCFunction)THPStorage_(sharedIncref), METH_NOARGS, nullptr}, - {"_get_shared_fd", (PyCFunction)THPStorage_(sharedFd), METH_NOARGS, nullptr}, - {"is_shared", (PyCFunction)THPStorage_(isShared), METH_NOARGS, nullptr}, + {"_weak_ref", THPStorage_(weakRef), METH_NOARGS, nullptr}, + {"_free_weak_ref", THPStorage_(freeWeakRef), METH_O | METH_STATIC, nullptr}, + {"_expired", THPStorage_(expired), METH_O | METH_STATIC, nullptr}, + {"_shared_decref", THPStorage_(sharedDecref), METH_NOARGS, nullptr}, + {"_shared_incref", THPStorage_(sharedIncref), METH_NOARGS, nullptr}, + {"_get_shared_fd", THPStorage_(sharedFd), METH_NOARGS, nullptr}, + {"is_shared", THPStorage_(isShared), METH_NOARGS, nullptr}, {nullptr} }; diff --git a/torch/csrc/jit/OVERVIEW.md b/torch/csrc/jit/OVERVIEW.md index 5611098194fde..c850b16c374c7 100644 --- a/torch/csrc/jit/OVERVIEW.md +++ b/torch/csrc/jit/OVERVIEW.md @@ -51,6 +51,7 @@ Sections start with a reference to the source file where the code related to the - [Interpreter](#interpreter) - [Graph Executor](#graph-executor) - [JIT Logging](#jit-logging) + - [JIT Optimization Limitter](#jit-optimization-limitter) - [DifferentiableGraphOp](#differentiablegraphop) - [Interpreter](#interpreter-1) - [FusionGroup](#fusiongroup) @@ -180,7 +181,7 @@ For Nodes representing built-in Operators, the method `Node::schema` can also lo All of the strings correspond to different `FunctionSchema` objects. A `Node` can be queried for its schema using the `schema()` method (it will check the argument types, and will try to match one of the options for its `kind()`). -Note that the chosen overload is not shown in any way in the textual output. If you're unsure which function does a node resolve to, you might need to check the type annotations of its input values. +Note that the chosen overload is not shown in any way in the textual output. If you're unsure which function a node resolves to, you might need to check the type annotations of its input values. Each node also has a set of attributes which are named integers, strings, floats, Tensors, and subgraphs, or lists of these types. These are used by special primitive operators to encode additional data in the Node. For instance `prim::Constant` defines a compile-time constant value. For Tensor constants, it will have a single Tensor attribute with the name `attr::value` which contains the value of the constant. @@ -204,7 +205,7 @@ Iterators for the `nodes()` list are invalided when the current Node they point Block also contain a list of input and output values. The meaning of these values depends on where the block is used. For the Graph's top-level block, these are inputs and outputs to the Graph, and line up with the FunctionSchema associated with a Method. -**Control-flow** is represented with using sub-blocks rather than a control-flow graph representation. A `prim::If` has one block for the true branch and one block for the else.A `prim:Loop` has a block for the loop body (there is no condition block, instead the end of the loop body computes whether to re-enter the loop body). This representation ensures we have structured control-flow. This limitation makes a lot of optimizations easier and is true for the vast majority of networks. A Node can lookup what Block it is in, and a Block and can look up its parent (either the Node that has it as a subblock, or `nullptr` for the main Block). +**Control-flow** is represented with using sub-blocks rather than a control-flow graph representation. A `prim::If` has one block for the true branch and one block for the else.A `prim:Loop` has a block for the loop body (there is no condition block, instead the end of the loop body computes whether to re-enter the loop body). This representation ensures we have structured control-flow. This limitation makes a lot of optimizations easier and is true for the vast majority of networks. A Node can look up what Block it is in, and a Block and can look up its parent (either the Node that has it as a subblock, or `nullptr` for the main Block). ### If ### For if-statements (`prim::If`) the Blocks have no inputs, and the outputs are the new values of variables in the outer block whose values were altered in an if-statement. @@ -383,7 +384,7 @@ The entry point from Python into C++ for tracing using `torch.jit.trace` is `_cr A thread local instance of the TracingState object maintains a mapping between actual data being computed during the trace (e.g. Tensors) stored in IValues, and the abstract `Value*` in the Graph that would compute that value. The functions `void setValueTrace(const IValue&, Value*)` and `Value* getValueTrace(const IValue&)` are used by the tracer to maintain this mapping. -An initial IValue to Value mapping is setup up between the inputs to the function being traced and symbolic Value inputs to the Graph being constructed. If we are tracing a `torch.nn.Module`, the tracer also adds Parameters and sub-Modules to the Module being constructed that correspond to the Python `torch.nn.Module` being traced. These values are also added as mapping so that uses of the Parameters in the trace will create uses of the Parameters in the Graph. +An initial IValue to Value mapping is set up between the inputs to the function being traced and symbolic Value inputs to the Graph being constructed. If we are tracing a `torch.nn.Module`, the tracer also adds Parameters and sub-Modules to the Module being constructed that correspond to the Python `torch.nn.Module` being traced. These values are also added as mapping so that uses of the Parameters in the trace will create uses of the Parameters in the Graph. As the trace runs, individual operators create Nodes in the Graph being traced to record what happens. This code is currently generated per operator in [tools/autograd/gen_variable_type.py](../../../tools/autograd/gen_variable_type.py). It results in code that looks like the following: @@ -886,7 +887,7 @@ def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh): return hy, cy ``` -After going through the the frontend, we get start with this unoptimized graph: +After going through the the frontend, we start with this unoptimized graph: ``` graph(%x : Tensor, @@ -1184,6 +1185,33 @@ By default, types in the graph are printed with maximum verbosity. The verbosit * `2`: Also print strides * `3`: Also print device type and whether gradient is required +## JIT Optimization Limitter ## + +[jit_opt_limit.h](jit_opt_limit.h) + +Often times, we need to limit the number of optimizations for any lowering passes for debugging purposes. + +`TorchScript` offers a simple optimization limit checker that can be configured through environment variable `PYTORCH_JIT_OPT_LIMIT`. The purpose is to limit how many optimization you can make per pass. This is useful for debugging any passes. + +Opt limit checker is enabled on a per file basis (hence per pass). For example, in `constant_propagation.cpp`, `PYTORCH_JIT_OPT_LIMIT` should be set to `constant_propagation=` where `` is the number of optimizations you want to make for the pass. (i.e. +`PYTORCH_JIT_OPT_LIMIT="constant_propagation="`). + +Multiple files can be configured by separating each file name with a colon +`:` as in the following example, +`PYTORCH_JIT_OPT_LIMIT="constant_propagation=:dead_code_elimination="` + +You can call opt limiter by calling `JIT_OPT_LIMIT()`. It will return true if +we haven't reached the optimization limit yet. Otherwise, it will return +false. Typical usage: + +```cpp +auto allowed = JIT_OPT_LIMIT(); +if (!allowed) { + GRAPH_DUMP(...); //supplied from jit_log + return; +} +``` + ## DifferentiableGraphOp ## [runtime/graph_executor.cpp](runtime/graph_executor.cpp) @@ -1209,7 +1237,7 @@ a = torch.rand(2, 3) b = a # At this point, `a` and `b` share their storage. c = b[0] -# `c` is shares storage with `a` and `b`, but only sees a slice of the allocated memory. +# `c` shares storage with `a` and `b`, but only sees a slice of the allocated memory. ``` Some operators will *mutate* one or more of their operands in-place. These are typically denoted with a trailing underscore, or by taking an `out` argument as input: @@ -1293,7 +1321,7 @@ So to determine whether `a` and `b` may alias, we traverse the `AliasTracker` D ### Writing optimization passes with `AliasDb` `AliasDb` provides a high-level interface to help people write mutability-safe optimization passes. -In particular, `moveAfterTopologicallyValid()` (and it's `moveBefore` variant) will reorder nodes in a way that preserves data dependencies and avoids any data hazards. The rules for this are that all mutable *writes* to a given memory location must occur in the same order (avoid WAW hazards), and that no reads can be reordered before or after any write (WAR, RAW hazards). +In particular, `moveAfterTopologicallyValid()` (and its `moveBefore` variant) will reorder nodes in a way that preserves data dependencies and avoids any data hazards. The rules for this are that all mutable *writes* to a given memory location must occur in the same order (avoid WAW hazards), and that no reads can be reordered before or after any write (WAR, RAW hazards). However, reordering of reads across writes *is allowed* if we can prove that the read cannot alias the thing being written. This happens whenever we have tensors that come from functions that produce fresh results (common) inside of the function. It also happens whenever the creation of the mutable tensor is seen in the function (so it gets assigned a fresh variable), and all of its writes occur in that function. diff --git a/torch/csrc/jit/api/compilation_unit.h b/torch/csrc/jit/api/compilation_unit.h index 4e2fa336e589f..96a7f0a463bce 100644 --- a/torch/csrc/jit/api/compilation_unit.h +++ b/torch/csrc/jit/api/compilation_unit.h @@ -34,15 +34,15 @@ struct Resolver; using ResolverPtr = std::shared_ptr; struct Self { - virtual ~Self() {} + virtual ~Self() = default; virtual std::shared_ptr makeSugared(Value* v) const = 0; virtual ClassTypePtr getClassType() const = 0; }; // A CompilationUnit is a list of named Functions -// with helper methods to iterate the list, or invoke the function. -// Classes have a CompilationUnit holding the class methods -// and Modules also have a CompilationUnit holding the Functions that +// with helper methods to iterate the list or invoke the function. +// Classes have a CompilationUnit holding the class methods, +// and Modules have a CompilationUnit holding the Functions that // are used to implement their Methods struct TORCH_API CompilationUnit { @@ -85,7 +85,7 @@ struct TORCH_API CompilationUnit { } // for historic reasons, these are defined in ir_emitter.cpp - // Returns the list of Function's just defined. + // Returns the list of Functions just defined. std::vector define( const c10::optional& prefix, const std::vector& properties, @@ -100,7 +100,7 @@ struct TORCH_API CompilationUnit { bool shouldMangle = false); // same as above but parse the definitions from source - // Returns the list of Function's just defined. + // Returns the list of Functions just defined. std::vector define( // prefix namespace to put all the defined functions into const c10::optional& prefix, @@ -210,7 +210,7 @@ struct TORCH_API CompilationUnit { // have isolation. void _clear_python_cu() { // Delete all the associated class methods - for (auto type : classes_) { + for (const auto& type : classes_) { if (auto cls = type->cast()) { for (auto method : cls->methods()) { // Tombstone the method in the compilation unit. diff --git a/torch/csrc/jit/api/function_impl.cpp b/torch/csrc/jit/api/function_impl.cpp index 5914be1645176..b6600bee0820d 100644 --- a/torch/csrc/jit/api/function_impl.cpp +++ b/torch/csrc/jit/api/function_impl.cpp @@ -39,8 +39,10 @@ void GraphFunction::run(Stack&& stack) { run(stack); } -c10::intrusive_ptr GraphFunction::runAsync(Stack& stack) { - return get_executor().runAsync(stack); +c10::intrusive_ptr GraphFunction::runAsync( + Stack& stack, + TaskLauncher taskLauncher) { + return get_executor().runAsync(stack, std::move(taskLauncher)); } IValue GraphFunction::operator()( diff --git a/torch/csrc/jit/api/function_impl.h b/torch/csrc/jit/api/function_impl.h index aacdc7525bceb..c99ce9a7a4d94 100644 --- a/torch/csrc/jit/api/function_impl.h +++ b/torch/csrc/jit/api/function_impl.h @@ -25,7 +25,9 @@ struct TORCH_API GraphFunction : public Function { void run(Stack&& stack) override; - c10::intrusive_ptr runAsync(Stack& stack) override; + c10::intrusive_ptr runAsync( + Stack& stack, + TaskLauncher taskLauncher = at::launch) override; IValue operator()(std::vector stack, const Kwargs& kwargs = Kwargs()) override; diff --git a/torch/csrc/jit/api/method.h b/torch/csrc/jit/api/method.h index 1d0ea9bce2c82..96b632b6b1119 100644 --- a/torch/csrc/jit/api/method.h +++ b/torch/csrc/jit/api/method.h @@ -31,6 +31,15 @@ struct TORCH_API Method { std::vector stack, const Kwargs& kwargs = Kwargs()); + // Run method async. Invocation on this function would invokes a JIT + // interpreter that executes ops inline, one by one, on caller's thread. A + // model can utilize async op, i.e. `fork`, to launch an asynchronous task + // which will be launched on provided `taskLauncher`. + c10::intrusive_ptr run_async( + std::vector stack, + const Kwargs& kwargs = Kwargs(), + TaskLauncher taskLauncher = at::launch); + std::shared_ptr graph() const { return function_->graph(); } diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp index f645f73c7a95d..3c53e543a49da 100644 --- a/torch/csrc/jit/api/module.cpp +++ b/torch/csrc/jit/api/module.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -75,7 +76,7 @@ void Module::to(at::Device device, bool non_blocking) { } void module_state_to( - autograd::Variable variable, + const autograd::Variable& variable, const c10::optional& device, const c10::optional& dtype, bool non_blocking) { @@ -118,6 +119,17 @@ IValue Method::operator()(std::vector stack, const Kwargs& kwargs) { return (*function_)(std::move(stack), kwargs); } +c10::intrusive_ptr Method::run_async( + std::vector stack, + const Kwargs& kwargs, + TaskLauncher taskLauncher) { + stack.insert(stack.begin(), owner()._ivalue()); + RECORD_TORCHSCRIPT_FUNCTION(name(), stack); + + function_->getSchema().checkAndNormalizeInputs(stack, kwargs); + return function_->runAsync(stack, std::move(taskLauncher)); +} + void Module::clone_method( const Module& orig, const Function& method, @@ -243,6 +255,14 @@ Module Module::clone_impl( for (auto& fn : type()->methods()) { r.clone_method(*this, *fn, type_remap); } + + // Execute __setstate__(__getstate__()) to initialize custom class members. + if (auto setstate_method = r.find_method("__setstate__")) { + auto getstate_method = r.find_method("__getstate__"); + TORCH_INTERNAL_ASSERT(getstate_method, "expect __getstate__"); + auto state = (*getstate_method)(Stack{}); + (*setstate_method)(Stack{state}); + } } return r; } diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h index 09fab9d161155..4b68a85c6696d 100644 --- a/torch/csrc/jit/api/module.h +++ b/torch/csrc/jit/api/module.h @@ -25,6 +25,7 @@ #include #include #include +#include #include // This file contains classes which assist in desugaring Python style @@ -87,13 +88,13 @@ using ModuleLookup = std::function&)>; struct TORCH_API Module : public Object { explicit Module(c10::QualifiedName class_name); Module(std::shared_ptr cu, const c10::ClassTypePtr& type); - Module() {} + Module() = default; Module( c10::QualifiedName, std::shared_ptr cu, bool shouldMangle = false); Module(ModulePtr module_value) : Object(std::move(module_value)) {} - ~Module() {} + ~Module() = default; void set_optimized(bool o) { TORCH_WARN( @@ -133,7 +134,7 @@ struct TORCH_API Module : public Object { void register_attribute( const std::string& name, - const TypePtr t, + const TypePtr& t, IValue v, bool is_param = false, bool is_buffer = false) { @@ -264,7 +265,7 @@ struct TORCH_API Module : public Object { const std::unordered_map& type_remap); c10::QualifiedName getNameForMethod(std::string basename) const { - return QualifiedName(*type()->name(), basename); + return QualifiedName(*type()->name(), std::move(basename)); } void to_impl( @@ -440,7 +441,7 @@ struct slot_list_impl { } slot_list_impl(Module module, bool recurse, bool return_module) - : module_(std::move(module)), + : module_(module), recurse_(recurse), return_module_(return_module), size_(c10::nullopt) { @@ -547,7 +548,7 @@ struct NamedPolicy { } name = ss.str(); } - return value_type{std::move(name), Policy::create(cursors, v)}; + return value_type{std::move(name), Policy::create(cursors, std::move(v))}; } static bool valid(const ClassTypePtr& t, size_t i, const IValue& v) { return Policy::valid(t, i, v); diff --git a/torch/csrc/jit/api/object.h b/torch/csrc/jit/api/object.h index 305c254ad1c06..13a2b8298a4aa 100644 --- a/torch/csrc/jit/api/object.h +++ b/torch/csrc/jit/api/object.h @@ -12,8 +12,15 @@ using ResolverPtr = std::shared_ptr; using ObjectPtr = c10::intrusive_ptr; +// Throw this in C++ land if `attr` fails. This will be converted to a Python +// AttributeError by the Python binding code +class ObjectAttributeError : public std::runtime_error { + public: + ObjectAttributeError(const std::string& what) : std::runtime_error(what) {} +}; + struct TORCH_API Object { - Object() {} + Object() = default; Object(ObjectPtr _ivalue) : _ivalue_(std::move(_ivalue)) {} Object(std::shared_ptr cu, const c10::ClassTypePtr& type); Object( @@ -59,12 +66,10 @@ struct TORCH_API Object { if (auto r = _ivalue()->type()->findConstantSlot(name)) { return _ivalue()->type()->getConstant(*r); } - TORCH_CHECK( - false, - _ivalue()->type()->repr_str(), - " does not have a field with name '", - name, - "'"); + std::stringstream err; + err << _ivalue()->type()->repr_str() << " does not have a field with name '" + << name.c_str() << "'"; + throw ObjectAttributeError(err.str()); } c10::IValue attr(const std::string& name, c10::IValue or_else) const { diff --git a/torch/csrc/jit/backends/backend_detail.cpp b/torch/csrc/jit/backends/backend_detail.cpp index 9ec21297c0c10..32a18179d8746 100644 --- a/torch/csrc/jit/backends/backend_detail.cpp +++ b/torch/csrc/jit/backends/backend_detail.cpp @@ -1,4 +1,5 @@ #include + #include namespace torch { diff --git a/torch/csrc/jit/backends/backend_detail.h b/torch/csrc/jit/backends/backend_detail.h index 2d19f2ed8950f..00f0f2f9eb440 100644 --- a/torch/csrc/jit/backends/backend_detail.h +++ b/torch/csrc/jit/backends/backend_detail.h @@ -1,5 +1,6 @@ #pragma once +#include #include namespace torch { diff --git a/torch/csrc/jit/backends/backend_init.cpp b/torch/csrc/jit/backends/backend_init.cpp index b01cb62dc3a27..2f6ceeb60d72a 100644 --- a/torch/csrc/jit/backends/backend_init.cpp +++ b/torch/csrc/jit/backends/backend_init.cpp @@ -1,12 +1,123 @@ #include + #include #include #include +#include #include namespace torch { namespace jit { +// Get all types that are shared in the module hierarchy rooted at \p mod. +std::unordered_set getSharedModuleTypes(Module& mod) { + // Maintain a set of all TypePtrs. + std::unordered_set types; + // Maintain another set of TypePtrs that have been encountered more than once. + std::unordered_set duplicate_types; + + // Iterate over all modules in the hierarchy, including the root. + for (auto module : mod.modules()) { + auto module_type = module.type(); + if (types.count(module_type) > 0) { + duplicate_types.insert(module_type); + } + + types.insert(module_type); + } + + return duplicate_types; +} + +// Selectively lower \p mod to a backend. \p to_backend +// is called to lower modules. \p modules_to_lower contains +// qualified names of submodules of \p mod that should be lowered. +void toBackendSelectiveImpl( + Module& mod, + const py::function& to_backend, + const std::vector& modules_to_lower, + const std::unordered_set& duplicate_types) { + // This map will be used later to remap types in ancestor module graphs for + // all lowered submodules. + std::unordered_map type_remap; + + // For each module that should be lowered: + for (const auto& module_to_lower : modules_to_lower) { + // Use QualifiedName to parse the qualified module names. + c10::QualifiedName qual_module_name(module_to_lower); + auto& atoms = qual_module_name.atoms(); + + // Search through the module hierarchy using the atoms of + // qual_module_name until current points to the module to + // be lowered and parent points to its parent. + Module current = mod; + Module parent; + + for (size_t i = 0, e = atoms.size(); i < e; ++i) { + IValue submodule = current.attr(atoms[i]); + if (submodule.isModule()) { + if (i == e - 1) { + parent = current; + } + current = submodule.toModule(); + } else { + std::stringstream err; + err << "Attribute named " << atoms[i] << " is not a Module"; + throw std::runtime_error(err.str()); + } + } + + // Check that the parent type is not shared and therefore can be edited. + if (duplicate_types.count(parent.type()) > 0) { + throw py::cast_error(c10::str( + "Selective lowering is only supported for module hierarchies with unique types for selected modules; ", + parent.type()->repr_str(), + " is shared")); + } + + // Call to_backend on the module that needs to be lowered. It needs to be + // wrapped before doing so because _to_jit_backend accepts wrapped modules. + // The result needs to be unwrapped in order to access its type below. + auto lowered_submodule = + py::cast(to_backend(py::module::import("torch.jit._recursive") + .attr("wrap_cpp_module")(current)) + .attr("_c")); + + // Adjust the parent's type so that the type of the submodule matches + // the type of lowered_submodule. + auto parent_type = parent.type(); + + parent_type->unsafeChangeAttributeType( + atoms.back(), lowered_submodule.type()); + parent.setattr(atoms.back(), lowered_submodule._ivalue()); + + // Record the type mapping from old type -> lowered type. + type_remap[current.type()] = lowered_submodule.type(); + } + + // Having lowered all of the modules that needed to be lowered, remap types in + // all graphs in the hierarchy so that the graphs all use the new lowered + // type. + auto type_remap_fn = [&type_remap](TypePtr in) { + auto it = type_remap.find(in); + if (it == type_remap.end()) + return in; + return it->second; + }; + + // modules() iterates over all modules in the hierarchy including the root. + for (auto module : mod.modules()) { + auto module_type = module.type(); + for (auto& fn : module_type->methods()) { + auto method = module.get_method(fn->name()); + auto graph = method.graph(); + graph->remapTypes(type_remap_fn); + auto new_schema = fn->getSchema().cloneWithRemappedTypes(type_remap_fn); + fn->setSchema(new_schema); + } + } +} + void initJitBackendBindings(PyObject* module) { // Bind a function for lowering to each JIT backend. The name of the backend // must be the first argument. For example, to lower a Module to @@ -20,11 +131,12 @@ void initJitBackendBindings(PyObject* module) { auto codegen_lambda = [=](const std::string& backend_name, const Module& orig_module, const py::dict& method_compile_spec) { - const c10::QualifiedName qual_backend_name({"__torch__", - "torch", - "classes", - detail::kBackendsNamespace, - backend_name}); + const c10::QualifiedName qual_backend_name( + {"__torch__", + "torch", + "classes", + detail::kBackendsNamespace, + backend_name}); // TODO: Validate method_compile_spec. // Clone orig_module to make sure backend transformation is @@ -124,7 +236,7 @@ void initJitBackendBindings(PyObject* module) { static const auto method_ct = CodeTemplate(R"( def $method(self${,def_inputs}): typed_inputs: List[Any] = [${fwd_inputs,}] - $ret, = self.__backend.execute(self.__handles["$method"], typed_inputs) + $unpack, = self.__backend.execute(self.__handles["$method"], typed_inputs) ${refine,} return $ret )"); @@ -181,7 +293,9 @@ void initJitBackendBindings(PyObject* module) { out_ss << "_0"; type_check_ss << "assert isinstance(_0, "; - if (auto out_tuple_ty = out_ty->cast()) { + auto out_tuple_ty = out_ty->cast(); + + if (out_tuple_ty) { auto tuple_elements = out_tuple_ty->elements(); type_check_ss << tuple_elements[0]->str() << ")"; type_checks.emplace_back(type_check_ss.str()); @@ -201,6 +315,14 @@ void initJitBackendBindings(PyObject* module) { method_te.v("def_inputs", def_inputs); method_te.v("fwd_inputs", fwd_inputs); method_te.v("refine", type_checks); + method_te.s("unpack", out_ss.str()); + + // If the output type is a single element tuple then add an extra comma + // to ensure the final output maintains this type. + if (out_tuple_ty && out_tuple_ty->elements().size() == 1) { + out_ss << ","; + } + method_te.s("ret", out_ss.str()); loweredModule.define( @@ -226,11 +348,39 @@ void initJitBackendBindings(PyObject* module) { m.def( "_jit_to_backend", [=](const std::string& backend_name, - const Module& orig_module, + py::handle orig_module, const py::dict& method_compile_spec) { return py::module::import("torch.jit._recursive") - .attr("wrap_cpp_module")( - codegen_lambda(backend_name, orig_module, method_compile_spec)); + .attr("wrap_cpp_module")(codegen_lambda( + backend_name, + py::cast(orig_module.attr("_c")), + method_compile_spec)); + }); + + m.def( + "_jit_to_backend_selective", + [=](py::handle orig_module, + const py::function& to_backend, + const std::vector& modules_to_lower) { + if (auto original_module = + as_module(py::cast(orig_module))) { + // Clone the Module to avoid editing types that are shared with + // Modules in other instances outside this hierarchy. + Module& mod = original_module.value(); + auto cloned_mod = mod.clone(); + // Get all shared module types. Type sharing is only a problem if the + // parent modules of the ones to lower are in this set. + auto shared_types = getSharedModuleTypes(cloned_mod); + toBackendSelectiveImpl( + cloned_mod, to_backend, modules_to_lower, shared_types); + // Wrap the result in a RecursiveScriptModule because that's what + // the caller passed in. + return py::module::import("torch.jit._recursive") + .attr("wrap_cpp_module")(cloned_mod); + } + + throw py::cast_error(c10::str( + "Object ", py::str(orig_module), " is not a ScriptModule")); }); } } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp index b1e5afb2299ae..a9abd22251007 100644 --- a/torch/csrc/jit/codegen/cuda/arith.cpp +++ b/torch/csrc/jit/codegen/cuda/arith.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -6,6 +7,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace { @@ -410,11 +412,12 @@ static TensorView* newForReduction( (*(axes_set.rbegin())) < orig_domain.size(), "Error setting up reduction, reduction axis is outside nDims. Keep in mind reductions are relative to root domains, not modified views."); + auto axis_iter = axes_set.begin(); for (size_t dim = 0; dim < orig_domain.size(); dim++) { bool isReduction = false; - if (*axes_set.begin() == dim) { + if (axis_iter != axes_set.end() && *axis_iter == dim) { isReduction = true; - axes_set.erase(axes_set.begin()); + axis_iter++; } const IterDomain* id = orig_domain[dim]; @@ -730,6 +733,7 @@ TensorView* clamp(TensorView* in, Val* min_val, Val* max_val) { return clamp(in->as(), min_val, max_val)->as(); } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/arith.h b/torch/csrc/jit/codegen/cuda/arith.h index b81c589a87080..3db5d4c4b70e4 100644 --- a/torch/csrc/jit/codegen/cuda/arith.h +++ b/torch/csrc/jit/codegen/cuda/arith.h @@ -17,6 +17,7 @@ class Val; namespace torch { namespace jit { namespace fuser { +namespace cuda { // Insertion of casting op to dtype, returns new resulting val TORCH_CUDA_API Val* castOp(DataType dtype, Val* v1); @@ -183,6 +184,7 @@ TORCH_CUDA_API TensorView* threshold(TensorView* in, Val* thresh, Val* value); TORCH_CUDA_API Val* clamp(Val* in, Val* min_val, Val* max_val); TORCH_CUDA_API TensorView* clamp(TensorView* in, Val* min_val, Val* max_val); +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index f6e791f0edba9..c787404e24356 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -1,4 +1,3 @@ - #include #include #include @@ -12,6 +11,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace codegen { namespace { @@ -113,7 +113,11 @@ class CudaKernelGenerator : private OptInConstDispatch { // Shared memory if (has_dynamic_smem || has_reductions) { indent() << "alignas(" +#ifndef __HIP_PLATFORM_HCC__ << dataTypeSize(kernel_summary.largest_smem_data_type) +#else + << 8 // for HIP, we want 8-aligned even for smaller datatypes +#endif << ") extern __shared__ char array[];\n"; if (has_dynamic_smem) { @@ -289,9 +293,9 @@ class CudaKernelGenerator : private OptInConstDispatch { code_ << *op << gen(node->in()); } else { if (node->getUnaryOpType() == UnaryOpType::Cast) { - const auto cast_str = - cast_func_str({node->in()->getDataType().value(), - node->out()->getDataType().value()}); + const auto cast_str = cast_func_str( + {node->in()->getDataType().value(), + node->out()->getDataType().value()}); code_ << cast_str.value(); } else { code_ << node->getUnaryOpType(); @@ -581,34 +585,43 @@ class CudaKernelGenerator : private OptInConstDispatch { TORCH_INTERNAL_ASSERT(tv->domain()->nDims() > 0); TORCH_INTERNAL_ASSERT(node->size() != nullptr); - switch (tv->memoryType()) { - case MemoryType::Global: - indent() << "// Allocate global tensor " << gen(tv) << "\n"; - break; - case MemoryType::Shared: - if (node->size()->isConstScalar()) { - // Static shared memory - indent() << "__shared__ " << node->buffer_type() << " " << gen(tv) - << "[" << genInline(node->size()) << "];\n"; - } else { - // Align Offset Position - indent() << "offset = alignBufferSize(offset," - << dataTypeSize(node->buffer_type()) << ");\n"; - // Shared Memory Pointer - indent() << node->buffer_type() << "* " << gen(tv) - << " = reinterpret_cast<" << node->buffer_type() << "*>" - << "(array + offset);\n"; - // Increment Offset Position - indent() << "offset += (" << genInline(node->size()) << " * sizeof(" - << node->buffer_type() << "));\n"; - } - break; - case MemoryType::Local: - indent() << node->buffer_type() << " " << gen(tv) << "[" - << genInline(node->size()) << "];\n"; - break; - default: - TORCH_INTERNAL_ASSERT(false, "Unexpected memory type"); + if (node->alias() != nullptr) { + // Allocate alias another Allocate node + const auto alias_tv = node->alias()->buffer()->as(); + indent() << "// Alias Allocation - " << node->getMemoryType() << "\n"; + indent() << node->buffer_type() << "* " << gen(tv) << " = " + << gen(alias_tv) << ";\n"; + } else { + // Standard Memory Allocation + switch (tv->memoryType()) { + case MemoryType::Global: + indent() << "// Allocate global tensor " << gen(tv) << "\n"; + break; + case MemoryType::Shared: + if (node->size()->isConstScalar()) { + // Static shared memory + indent() << "__shared__ " << node->buffer_type() << " " << gen(tv) + << "[" << genInline(node->size()) << "];\n"; + } else { + // Align Offset Position + indent() << "offset = alignBufferSize(offset," + << dataTypeSize(node->buffer_type()) << ");\n"; + // Shared Memory Pointer + indent() << node->buffer_type() << "* " << gen(tv) + << " = reinterpret_cast<" << node->buffer_type() << "*>" + << "(array + offset);\n"; + // Increment Offset Position + indent() << "offset += (" << genInline(node->size()) << " * sizeof(" + << node->buffer_type() << "));\n"; + } + break; + case MemoryType::Local: + indent() << node->buffer_type() << " " << gen(tv) << "[" + << genInline(node->size()) << "];\n"; + break; + default: + TORCH_INTERNAL_ASSERT(false, "Unexpected memory type"); + } } } @@ -635,6 +648,7 @@ std::string generateCudaKernel( } } // namespace codegen +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/codegen.h b/torch/csrc/jit/codegen/cuda/codegen.h index 562aa1554eb2f..0304a61f8e7e0 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.h +++ b/torch/csrc/jit/codegen/cuda/codegen.h @@ -1,4 +1,3 @@ - #pragma once #include @@ -9,6 +8,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace codegen { //! Generates a CUDA kernel definition for the given kernel @@ -17,6 +17,7 @@ TORCH_CUDA_API std::string generateCudaKernel( const std::string& kernel_name = "CUDAGeneratedKernel"); } // namespace codegen +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 9f8f7aba1cf41..974e993739bc7 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -9,6 +9,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { ComputeAtData::ComputeAtData(TensorView* tv) : tv_ref_(tv), @@ -477,6 +478,7 @@ ComputeAt::ComputeAt( setCommonConsumer(); } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/compute_at.h b/torch/csrc/jit/codegen/cuda/compute_at.h index a9112a6225ca6..0ceac0e5c9daf 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.h +++ b/torch/csrc/jit/codegen/cuda/compute_at.h @@ -10,6 +10,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { class TensorDomain; class TensorView; @@ -158,6 +159,7 @@ class ComputeAt { ComputeAt& operator=(const ComputeAt& other) = delete; }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index 597215821b6ad..f3a8837478cc6 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -7,6 +7,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { template T* ptr(T& obj) { @@ -545,6 +546,7 @@ Statement* OptOutMutator::mutate(Val* v) { return Val::mutatorDispatch(this, v); } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h index 647d3fa4458f7..2cade85ba06d6 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.h +++ b/torch/csrc/jit/codegen/cuda/dispatch.h @@ -47,6 +47,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { class Fusion; @@ -649,6 +650,7 @@ class TORCH_CUDA_API OptInMutator { } }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index f33079bcbab5b..76ba1faf66419 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -1,4 +1,3 @@ - #include #include #include @@ -11,10 +10,13 @@ #include #include #include +#include #include #include #include +#include + namespace torch { namespace jit { namespace fuser { @@ -24,11 +26,15 @@ int FusionExecutor::fusion_id_counter_ = 0; std::string FusionExecutor::getStructuredCode(const std::string& kernel) { // generating cuda code; - std::string code = std::string("namespace ") + - FusionExecutor::kernelNamespace() + " {\n" + - executor_utils::kernelPreamble() + kernel + "}\n"; - - const char* debug_env = getenv("PYTORCH_CUDA_FUSER_DEBUG"); + std::string code = ""; +#ifdef __HIP_PLATFORM_HCC__ + code += std::string("#include \n") + + std::string("#include \n"); +#endif + code += std::string("namespace ") + FusionExecutor::kernelNamespace() + + " {\n" + executor_utils::kernelPreamble() + kernel + "}\n"; + + const char* debug_env = std::getenv("PYTORCH_CUDA_FUSER_DEBUG"); if (debug_env && atoi(debug_env)) { std::cout << "\n==== codegen output for kernel: " << kernelName() << " ====" << std::endl @@ -50,7 +56,7 @@ void FusionExecutor::debugCompileFusionFromStr( FusionGuard fg(&fusion_); options_ = options; - const char* debug_env = getenv("PYTORCH_CUDA_FUSER_DEBUG"); + const char* debug_env = std::getenv("PYTORCH_CUDA_FUSER_DEBUG"); if (debug_env && atoi(debug_env)) { std::cout << "\n==== codegen output for kernel: " << kernelName() << " ====" << std::endl @@ -59,8 +65,30 @@ void FusionExecutor::debugCompileFusionFromStr( << std::endl; } + setUsedTVs(); + fusion_id_ = id; lowered_ = GpuLower(&fusion_); + const auto kernel = lowered_.kernel(); + + const char* dump_kir_env = std::getenv("PYTORCH_CUDA_FUSER_DUMP_KIR"); + if (dump_kir_env && atoi(dump_kir_env)) { + kernel->print(); + } + + const auto& kernel_summary = kernel->summary(); + has_block_reductions = kernel_summary.has_block_reductions; + has_grid_reductions = kernel_summary.has_grid_reductions; + has_block_broadcasts = kernel_summary.has_block_broadcasts; + + if (!kernel_summary.static_smem_allocations.empty()) { + StatefulExpressionEvaluator static_evaluator(&fusion_); + unsigned static_smem_size = computeSharedMemory( + static_evaluator, kernel_summary.static_smem_allocations); + TORCH_INTERNAL_ASSERT( + static_smem_size < max_device_smem, + "The static shared memory allocation is larger than available memory."); + } compiled_kernel_ = executor_utils::nvrtcCompile(code, name, fusion_id_); TORCH_INTERNAL_ASSERT( @@ -94,6 +122,12 @@ void FusionExecutor::compileFusion(Fusion* fusion, CompileOptions options) { fusion_id_ = ++fusion_id_counter_; lowered_ = GpuLower(&fusion_); const auto kernel = lowered_.kernel(); + + const char* dump_kir_env = std::getenv("PYTORCH_CUDA_FUSER_DUMP_KIR"); + if (dump_kir_env && atoi(dump_kir_env)) { + kernel->print(); + } + const auto kernel_code = codegen::generateCudaKernel(kernel, kernelName()); const auto structured_code = getStructuredCode(kernel_code); @@ -141,17 +175,18 @@ at::Tensor inferAndAlloc( } auto at_type = data_type_to_aten(tv->getDataType().value()); - auto tensor_options = - at::TensorOptions().dtype(at_type).device(options.device); if (zero_init) { + auto tensor_options = + at::TensorOptions().dtype(at_type).device(options.device); c10::IntArrayRef isizes(sizes); return at::zeros(isizes, tensor_options); } else { c10::IntArrayRef isizes(sizes); // Non Variable type guard for empty_cuda call at::AutoNonVariableTypeMode non_variable_type_mode; - return at::native::empty_cuda(isizes, tensor_options); + return at::native::empty_cuda( + isizes, at_type, c10::nullopt, options.device, c10::nullopt); } } @@ -164,21 +199,25 @@ uint64_t FusionExecutor::computeSharedMemory( uint64_t total) { FUSER_PERF_SCOPE("computeSharedMemory"); for (auto smem_alloc : buffers) { - auto inferred_val = see.inferValue(smem_alloc->size()); - if (inferred_val.has_value()) { - const uint64_t data_size = dataTypeSize(smem_alloc->buffer_type()); - // Add padding to align dynamic shared memory - if (align_padding) { - total = ceilDiv(total, data_size) * data_size; + // If this buffer aliases another buffer, + // then do not allocate memory for this buffer. + if (smem_alloc->alias() == nullptr) { + auto inferred_val = see.inferValue(smem_alloc->size()); + if (inferred_val.has_value()) { + const uint64_t data_size = dataTypeSize(smem_alloc->buffer_type()); + // Add padding to align dynamic shared memory + if (align_padding) { + total = ceilDiv(total, data_size) * data_size; + } + total += inferred_val.value() * data_size; + } else { + TORCH_INTERNAL_ASSERT( + false, + "Failed to evaluate the size ", + smem_alloc->size(), + " of shared memory buffer - T", + smem_alloc->buffer()->name()); } - total += inferred_val.value() * data_size; - } else { - TORCH_INTERNAL_ASSERT( - false, - "Failed to evaluate the size ", - smem_alloc->size(), - " of shared memory buffer - T", - smem_alloc->buffer()->name()); } } return total; @@ -373,18 +412,20 @@ std::vector FusionExecutor::runFusion( // take the short-cut for launch if we see a recorded input set again; launch_params = executor_entry->launch_params; for (size_t i = 0; i < executor_entry->output_sizes.size(); i++) { - auto tensor_options = at::TensorOptions() - .dtype(executor_entry->output_types[i]) - .device(options_.device); alloced_outputs.push_back(at::native::empty_cuda( - executor_entry->output_sizes[i], tensor_options)); + executor_entry->output_sizes[i], + executor_entry->output_types[i], + c10::nullopt, + options_.device, + c10::nullopt)); } for (size_t i = 0; i < executor_entry->empty_buffer_sizes.size(); i++) { - auto tensor_options = at::TensorOptions() - .dtype(executor_entry->empty_buffer_types[i]) - .device(options_.device); global_buffers.empty_buffers.push_back(at::native::empty_cuda( - executor_entry->empty_buffer_sizes[i], tensor_options)); + executor_entry->empty_buffer_sizes[i], + executor_entry->empty_buffer_types[i], + c10::nullopt, + options_.device, + c10::nullopt)); } } for (size_t i = 0; i < executor_entry->zero_buffer_sizes.size(); i++) { diff --git a/torch/csrc/jit/codegen/cuda/executor_launch_params.h b/torch/csrc/jit/codegen/cuda/executor_launch_params.h index 981352e4839bf..2794dc6822884 100644 --- a/torch/csrc/jit/codegen/cuda/executor_launch_params.h +++ b/torch/csrc/jit/codegen/cuda/executor_launch_params.h @@ -72,6 +72,8 @@ class TORCH_CUDA_API LaunchParams { class_val == UNINITIALIZED_VAL || incoming_val == class_val, "Tried to set ", val, + " from ", + class_val, " to ", incoming_val, ", but it was already set and new value does not match.", diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.cpp b/torch/csrc/jit/codegen/cuda/executor_utils.cpp index 9670968b8fe18..61ca4ef3db892 100644 --- a/torch/csrc/jit/codegen/cuda/executor_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/executor_utils.cpp @@ -1,14 +1,23 @@ #include #include +#include #include #include #include #include -#include +#include #include +#include +#include +#include +#include +#include +#include +#include + #include namespace torch { @@ -19,13 +28,18 @@ namespace executor_utils { std::string kernelPreamble() { std::stringstream ss; - ss << code_template_tensor_struct << "\n" - << code_fp16_support << "\n" - << code_random_number_gen << "\n" - << code_helper_funcs << "\n" - << code_template_block_reduction << "\n" - << code_template_grid_reduction << "\n" - << code_template_block_broadcast << "\n"; + +#ifndef __HIP_PLATFORM_HCC__ + ss << nvfuser_resources::fp16_support_cu; +#endif + + ss << nvfuser_resources::tensor_cu; + ss << nvfuser_resources::random_numbers_cu; + ss << nvfuser_resources::helpers_cu; + ss << nvfuser_resources::block_reduction_cu; + ss << nvfuser_resources::grid_reduction_cu; + ss << nvfuser_resources::broadcast_cu; + return ss.str(); } @@ -246,18 +260,11 @@ NvrtcFunction nvrtcCompile( } const auto prop = at::cuda::getCurrentDeviceProperties(); - int nvrtc_major, nvrtc_minor; - AT_CUDA_NVRTC_CHECK( - at::globalContext().getNVRTC().nvrtcVersion(&nvrtc_major, &nvrtc_minor)); - - // Short-circuits if NVRTC version too low - TORCH_INTERNAL_ASSERT(nvrtc_major >= 6); - // Major and minor is determined by device properties and - // possibly "downcompiled" to a lower (compatible) compute architecture - // based on the NVRTC version - const int major = prop->major; - const int minor = prop->minor; - nvrtcProgram program; + + int major = 0, minor = 0; + getMajorMinor(prop, major, minor); + + nvrtcProgram program; // NOLINT(cppcoreguidelines-init-variables) { FUSER_PERF_SCOPE("nvrtcCreateProgram"); @@ -271,10 +278,14 @@ NvrtcFunction nvrtcCompile( at::globalContext().getNVRTC().nvrtcDestroyProgram(&program)); }); +#ifdef __HIP_PLATFORM_HCC__ + std::vector args = {"--std=c++14"}; +#else const std::string compute = "--gpu-architecture=compute_" + std::to_string(major) + std::to_string(minor); std::vector args = { "--std=c++14", compute.c_str(), "-default-device"}; +#endif const char* disable_fma = getenv("PYTORCH_CUDA_FUSER_DISABLE_FMA"); // int disable_fma_flag = disable_fma ? atoi(disable_fma) : 0; @@ -345,6 +356,7 @@ NvrtcFunction nvrtcCompile( // TODO: We do go through different code path, should investigate whether this // has an impact on generated binary. const char* prefix_env = getenv("PYTORCH_CUDA_FUSER_CUBIN"); +#ifndef __HIP_PLATFORM_HCC__ if (prefix_env) { FUSER_PERF_SCOPE("load CUBIN"); @@ -402,6 +414,12 @@ NvrtcFunction nvrtcCompile( options.data(), option_vals.data())); } +#else + // load ptx directly + AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadData( + &(compiled_kernel_.module), ptx.data())); + +#endif AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleGetFunction( &(compiled_kernel_.function), compiled_kernel_.module, diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.h b/torch/csrc/jit/codegen/cuda/executor_utils.h index 76b8a9a145f19..28a702b98d735 100644 --- a/torch/csrc/jit/codegen/cuda/executor_utils.h +++ b/torch/csrc/jit/codegen/cuda/executor_utils.h @@ -1,11 +1,12 @@ #pragma once #include -#include #include #include +#include + #include #include @@ -13,6 +14,9 @@ #include #include +#include +#include + namespace torch { namespace jit { namespace fuser { diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp index 17fb81ceaf6a4..21e018e9382f5 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp @@ -9,6 +9,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { void StatefulExpressionEvaluator::safeBind( Val* value, @@ -82,7 +83,7 @@ c10::optional StatefulExpressionEvaluator::getValue( Val* value) { TORCH_INTERNAL_ASSERT( value->isAnInt(), - "Expressoin Evaluation does not support values other than integers at this time."); + "Expression Evaluation does not support values other than integers at this time."); switch (value->getValType().value()) { case ValType::Scalar: @@ -219,6 +220,7 @@ void StatefulExpressionEvaluator::handle(kir::BinaryOp* bop) { } } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.h b/torch/csrc/jit/codegen/cuda/expr_evaluator.h index 40ba53380fae0..33716ff80e5e2 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.h +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.h @@ -1,4 +1,3 @@ - #pragma once #include @@ -13,6 +12,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { class TORCH_CUDA_API StatefulExpressionEvaluator : private OptOutDispatch { public: @@ -77,6 +77,7 @@ class TORCH_CUDA_API StatefulExpressionEvaluator : private OptOutDispatch { Fusion* fusion_ = nullptr; }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp index fcb12a978d2a8..4a6fc5848e558 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion.cpp @@ -1,5 +1,5 @@ - #include + #include #include #include @@ -15,6 +15,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { static thread_local Fusion* ACTIVE_FUSION = nullptr; @@ -625,6 +626,7 @@ std::vector Fusion::getTerminatingOutputs() { return terminating_outputs; } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/fusion.h b/torch/csrc/jit/codegen/cuda/fusion.h index 99c97cc919435..e54e99c1386b4 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.h +++ b/torch/csrc/jit/codegen/cuda/fusion.h @@ -12,6 +12,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { /* * Usage: FusionGuard and Fusion are required user interfaces for any operation @@ -229,6 +230,7 @@ class TORCH_CUDA_API Fusion final { std::unordered_map lowered_origin_; }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/graph_fuser.cpp b/torch/csrc/jit/codegen/cuda/graph_fuser.cpp index 1dfdc7b1edcd8..e8299bd21450f 100644 --- a/torch/csrc/jit/codegen/cuda/graph_fuser.cpp +++ b/torch/csrc/jit/codegen/cuda/graph_fuser.cpp @@ -15,6 +15,8 @@ #include #include +#include + #include #include @@ -677,7 +679,6 @@ struct CudaGraphFuser { // Builds up expressions that compute shapes of all intermediates (and // outputs) of the fusion group, based on the sizes of inputs. You should run // DCE to remove those that you end up not using. - /* std::unordered_map buildShapeExpressions(Node* fusion_group) { WithInsertPoint insert_guard{fusion_group->next()}; std::unordered_map shape_of; @@ -736,6 +737,38 @@ struct CudaGraphFuser { shape_of.emplace(outputs.at(outputs.size() - 1), last_size); continue; } + // extended shape expression support to reduction operations + // TODO: `aten::sum` is too flexible, we should restrict for a better + // match + if (n->kind() == aten::sum) { + // TODO: expand support to wire non-constant inputs, this is currently + // blocked by profiling executor not capable of profiling scalar inputs. + TORCH_INTERNAL_ASSERT( + n->input(1)->node()->kind() == prim::Constant && + n->input(2)->node()->kind() == prim::Constant, + "only supports reduction axes and keepdim being constant"); + + // hmmm, do I need to setInsertPoint... + Node* in1_const = + graph->createClone(n->input(1)->node(), [](Value*) -> Value* { + throw std::runtime_error("unexpected input"); + }); + graph->insertNode(in1_const); + Node* in2_const = + graph->createClone(n->input(2)->node(), [](Value*) -> Value* { + throw std::runtime_error("unexpected input"); + }); + graph->insertNode(in2_const); + + std::vector inputs = { + shape_of.at(n->input(0)), in1_const->output(), in2_const->output()}; + Node* size_node = + graph->insertNode(graph->create(prim::ReductionSizes, inputs, 1)); + Value* size = size_node->output(0); + size->setType(ListType::ofInts()); + shape_of.emplace(n->output(), size); + continue; + } auto tensor_inputs = filter(n->inputs(), [](Value* v) { return v->type()->isSubtypeOf(TensorType::get()); }); @@ -753,6 +786,8 @@ struct CudaGraphFuser { return; auto subgraph = fusion_group->g(attr::Subgraph); + // TODO: failure in buildShapeExpressions should not break fusion execution, + // we can add a try/catch here to bailout from removeOutputsUsedOnlyInSize. auto shape_of = buildShapeExpressions(fusion_group); auto outputs = fusion_group->outputs().vec(); auto soutputs = subgraph->outputs().vec(); @@ -774,7 +809,6 @@ struct CudaGraphFuser { } } } - */ void refreshAliasDb() { aliasDb_ = torch::make_unique(graph_); @@ -835,9 +869,9 @@ struct CudaGraphFuser { //} // Remove outputs that have been added only because we need their size - // for (Node* n : block_->nodes()) { - // removeOutputsUsedOnlyInSize(n); - //} + for (Node* n : block_->nodes()) { + removeOutputsUsedOnlyInSize(n); + } for (Node* node : block_->nodes()) { for (Block* sub_block : node->blocks()) { @@ -914,12 +948,139 @@ void PeepholeOptimizeShapeExpressions(Block* block) { } } +//! [ Note -- CudaFusionGuard implementation ] +//! +//! shamelessly copying code from NNC (tensorexpr_fuser) with very little +//! modification, original code at: +//! `../../passes/tensorexpr_fuser.cpp:guardFusionGroup` +//! +//! Add prim::CudaFusionGuard node to ensure that accepted profiling information +//! is not violated at runtime. +//! +//! We replace a single +//! +//! outputs = prim::CudaFusionGroup[cache_id](inputs) +//! +//! with the following pattern: +//! +//! %1 : bool = prim::CudaFusionGuard[types=[...]](inputs) +//! outputs = prim::If(%1) +//! block0(): +//! outputs = prim::CudaFusionGroup[cache_id](inputs) +//! -> (outputs) +//! block1(): +//! %2 : Function = prim::Constant[name="fallback_function", fallback=1]() +//! otuputs = prim::CallFunction(%2, inputs) +//! -> (outputs) +//! +//! `prim::CudaFusionGuard` stores all profiled data type in attribute +//! `attr::types`. +//! At runtime, we check input tensors against our profiled data type and return +//! an output holds the result of the check (bool). +//! See [ Note -- type guard logic in CudaFusionGuard ] +//! +//! This ensures that `prim::CudaFusionGroup` only execute compatible inputs. +//! In case of check failure, execution goes through false block, which +//! recursively goes along another profiling / optimization iteration. (could be +//! tuned by `bailout_depth`) +//! +//! TODO: we also need to assert/check reduction axes and replace it with +//! constants in `CudaFusionGroup` +void guardFusionGroup(Node* fusion) { + // Fixup types of the subgraph inputs + std::vector guard_types; + std::vector inputs_to_check; + for (Value* input : fusion->inputs()) { + // We only check inputs of the fusion group and expect NNC to infer + // intermediates and outputs shapes + if (!input->type()->cast()) { + continue; + } + + // note: modified from original implementation, we are guarding fusion + // outputs + if (input->node()->kind() == prim::Constant) { + continue; + } + inputs_to_check.push_back(input); + guard_types.push_back(input->type()); + } + if (!inputs_to_check.size()) { + return; + } + + Node* typecheck_node = fusion->owningGraph() + ->create(prim::CudaFusionGuard, inputs_to_check, 1) + ->insertBefore(fusion); + // fix output to BoolType + typecheck_node->output()->setType(BoolType::get()); + Value* typecheck_result = typecheck_node->output(); + typecheck_node->tys_(attr::types, guard_types); + + std::unordered_map typechecked_inputs; + + // Insert if block + auto versioning_if = + fusion->owningGraph() + ->create(prim::If, {typecheck_result}, fusion->outputs().size()) + ->insertAfter(typecheck_node); + for (size_t idx = 0; idx < fusion->outputs().size(); ++idx) { + versioning_if->output(idx)->setType(fusion->output(idx)->type()); + fusion->output(idx)->replaceAllUsesWith(versioning_if->output(idx)); + } + auto true_block = versioning_if->addBlock(); + auto false_block = versioning_if->addBlock(); + + // Fill in the false block. It should contain the unoptimized + // copy of the fused subgraph. + auto& subgraph = *fusion->g(attr::Subgraph); + WithInsertPoint guard(false_block->return_node()); + const auto subgraph_outputs = + insertGraph(*fusion->owningGraph(), subgraph, fusion->inputs()); + for (Value* output : subgraph_outputs) { + false_block->registerOutput(output); + } + + // types get copied to the fallback graph, so remove specializations before + // replacing + // TODO: this is not exposed here, I need to remove that before inserting the + // graph + // removeTensorTypeSpecializations(false_block); + replaceBlockWithFallbackGraph(false_block, fusion->inputs()); + + // Fill in the true block. It has all inputs type-checked and its + // body should be the fusion group node. + fusion->moveBefore(true_block->return_node()); + for (Value* output : fusion->outputs()) { + true_block->registerOutput(output); + } +} + +void guardFusionGroups(Block* block) { + std::vector fusions; + for (Node* n : block->nodes()) { + for (Block* b : n->blocks()) { + guardFusionGroups(b); + } + if (n->kind() == prim::CudaFusionGroup) { + fusions.push_back(n); + } + } + for (Node* fusion : fusions) { + guardFusionGroup(fusion); + } +} + } // anonymous namespace void CudaFuseGraph(std::shared_ptr& graph) { FUSER_PERF_SCOPE("CudaFuseGraph"); + // TODO: we need to properly restore shape information after fusion. + // shamelessly use tool from NNC. + RemoveProfileNodesAndSpecializeTypes(graph); CudaGraphFuser(graph->block(), graph).run(); + guardFusionGroups(graph->block()); // After FuseGraph some common subexpressions may come back EliminateCommonSubexpression(graph); // We might have emitted a fair amount of useless shape propagating code, so @@ -927,6 +1088,11 @@ void CudaFuseGraph(std::shared_ptr& graph) { EliminateDeadCode(graph); // Improve the quality of shape propagation code that was left PeepholeOptimizeShapeExpressions(graph->block()); + + // TODO: we need to properly restore shape information after fusion. + // shamelessly use tool from NNC. + RemoveTensorTypeSpecializations(graph); + // Compile CudaFusionGroup compileFusionRecursive(graph->block()); } diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 9b757661e12d7..eebf7333eae9c 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -1,11 +1,12 @@ - #include + #include #include #include #include #include #include +#include #include #include #include @@ -14,6 +15,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace { @@ -810,7 +812,7 @@ kir::TensorIndex* Index::getGlobalProducerIndex( " dim: ", i, " id: ", - kir_root_dom_i); + kir::toString(kir_root_dom_i)); auto root_ind = index_map.at(kir_root_dom_i); TORCH_INTERNAL_ASSERT(kir::isLoweredScalar(root_ind)); @@ -928,7 +930,7 @@ kir::TensorIndex* Index::getProducerIndex_impl( " dim: ", i, " id: ", - kir_root_dom_i); + kir::toString(kir_root_dom_i)); auto root_ind_i = index_map.at(kir_root_dom_i); TORCH_INTERNAL_ASSERT(kir::isLoweredScalar(root_ind_i)); @@ -1037,7 +1039,7 @@ kir::TensorIndex* Index::getGlobalConsumerIndex( " dim: ", i, " id: ", - kir_root_dom_i); + kir::toString(kir_root_dom_i)); auto ind = index_map.at(kir_root_dom_i); if (i == root_dom.size() - 1 && inner_most_dim_contig) { @@ -1099,7 +1101,7 @@ kir::TensorIndex* Index::getConsumerIndex_impl( " dim: ", i, " id: ", - kir_root_dom_i); + kir::toString(kir_root_dom_i)); auto root_ind_i = index_map.at(kir_root_dom_i); TORCH_INTERNAL_ASSERT(kir::isLoweredScalar(root_ind_i)); @@ -1278,6 +1280,7 @@ std::pair, bool> Index::getConsumerRootPredIndices( return std::make_pair(root_inds, use_rfactor); } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/index_compute.h b/torch/csrc/jit/codegen/cuda/index_compute.h index f227560e5a132..7b4b67df00924 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.h +++ b/torch/csrc/jit/codegen/cuda/index_compute.h @@ -56,6 +56,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { class IndexCompute : public BackwardVisitor { private: @@ -188,6 +189,7 @@ class Index { bool unroll = false); }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/instrumentation.cpp b/torch/csrc/jit/codegen/cuda/instrumentation.cpp index 80a0c66075f03..49196bee40b7c 100644 --- a/torch/csrc/jit/codegen/cuda/instrumentation.cpp +++ b/torch/csrc/jit/codegen/cuda/instrumentation.cpp @@ -1,10 +1,9 @@ - #include #include #ifdef _WIN32 -#include +#include #else #include #include @@ -13,6 +12,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace inst { Trace::Trace() { @@ -51,7 +51,7 @@ void Trace::logEvent(char ph, const char* name, char sep) { const unsigned int tid = GetCurrentThreadId(); #else const unsigned int pid = getpid(); - const unsigned int tid = pthread_self(); + const unsigned int tid = std::hash{}(pthread_self()); #endif // _WIN32 fprintf( @@ -66,6 +66,7 @@ void Trace::logEvent(char ph, const char* name, char sep) { } } // namespace inst +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/instrumentation.h b/torch/csrc/jit/codegen/cuda/instrumentation.h index b3c2454570eea..63204d770872f 100644 --- a/torch/csrc/jit/codegen/cuda/instrumentation.h +++ b/torch/csrc/jit/codegen/cuda/instrumentation.h @@ -1,4 +1,3 @@ - #pragma once #include @@ -9,6 +8,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace inst { //! An optional record of selected timestamped operations, events and counters @@ -85,9 +85,10 @@ class TraceScope : public NonCopyable { //! \param name The name of the scope, normally a simple string literal //! #define FUSER_PERF_SCOPE(name) \ - fuser::inst::TraceScope FUSER_ANONYMOUS(_perf_scope_)(name) + torch::jit::fuser::cuda::inst::TraceScope FUSER_ANONYMOUS(_perf_scope_)(name) } // namespace inst +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/interface.cpp b/torch/csrc/jit/codegen/cuda/interface.cpp index 42dfed02b1149..b9a92ec87f488 100644 --- a/torch/csrc/jit/codegen/cuda/interface.cpp +++ b/torch/csrc/jit/codegen/cuda/interface.cpp @@ -1,13 +1,21 @@ - #include + #include #include +#include namespace torch { namespace jit { namespace fuser { namespace cuda { +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +static std::atomic cuda_fusion_guard_mode{true}; + +std::atomic& getCudaFusionGuardMode() { + return cuda_fusion_guard_mode; +} + CudaFuserInterface* getFuserInterface() { static CudaFuserInterface fuser_interface_; return &fuser_interface_; @@ -29,16 +37,149 @@ void runFusionGroup(const Node* fusion_node, Stack& stack) { void fuseGraph(std::shared_ptr& graph) { TORCH_CHECK( - getFuserInterface()->fn_fuse_graph != nullptr, + getFuserInterface()->fn_fuse_graph_ != nullptr, "Running the CUDA fuser requires a CUDA build."); - getFuserInterface()->fn_fuse_graph(graph); + getFuserInterface()->fn_fuse_graph_(graph); +} + +bool canFuseNode(const Node* node) { + return getFuserInterface()->fn_can_fuse_n_ != nullptr && + getFuserInterface()->fn_can_fuse_n_(node); +} + +//! [ Note -- type guard logic in CudaFusionGuard ] +//! +//! CudaFusionGuard is used to Guard input tensor to `CudaFusionGroup` so that +//! we would not feed inputs that violates the graph defined in `GraphCache`. +//! +//! see [ Note -- 2 level cache implementation ] for definition of unique +//! computational graph. +//! see [ Note -- CudaFusionGuard implementation] for details on how guard works +//! in profiling executor +//! +//! Type guard logic is used to query whether a runtime input `tensor` compiles +//! with profiled `guard_tensor_type`. `guard_tensor_type` is the observed +//! tensor type during profiling runs. +//! +//! At this moment, we only do single profiling run, so `guard_tensor_type` has +//! static shape / stride / scalarType. *This might be a little confusing as our +//! implementation is actually more relaxed. +//! +//! Things that we check: +//! a. identical rank & scalar type +//! b. stride check: +//! b.1. identical stride order +//! b.2. identical contiguity +//! note that contiguity here is used for tensor collapsing. So +//! extra attention should be paid to contiguity across size-1 +//! dimensions. +//! c. size check: +//! making sure that broadcast semantics are identical. So we want to +//! make sure a given dimension either are both size-1 for `tensor` & +//! `guard_tensor_type`, or are both non-size-1. +//! This is due to the fact that we specialize size-1 dimension as +//! broadcasted dimension while translating PyTorch tensor to Fusion IR. +//! +bool complyWith( + const at::Tensor& tensor, + const c10::TensorTypePtr& guard_tensor_type) { + // guard broadcast semantics, contiguity & stride order; + TORCH_INTERNAL_ASSERT( + guard_tensor_type && guard_tensor_type->dim().has_value()); + + // check a. if num_dimension check fails or scalar type check fails + if (*guard_tensor_type->dim() != static_cast(tensor.ndimension()) || + (guard_tensor_type->scalarType().has_value() && + (guard_tensor_type->scalarType().value() != tensor.scalar_type())) || + tensor.requires_grad()) { + return false; + } + + // TODO: should we get symbolic_size instead and check for size + // consistency across tensors as well? + const auto& sizes = guard_tensor_type->sizes(); + const auto& stride_properties = guard_tensor_type->stride_properties(); + + const auto& t_sizes = tensor.sizes(); + const auto& t_strides = tensor.strides(); + int inner_dim = -1; + for (size_t j = 0; j < *guard_tensor_type->dim(); j++) { + // check b. for stride check, we go along dimensions from fastest stride to + // slowest stride + int sorted_index = stride_properties[j]->stride_index_ + ? static_cast(*stride_properties[j]->stride_index_) + : -1; + + // only apply stride check when we have stride_properties + if (sorted_index != -1) { + // check b.1. stride order [current dimension has stride larger + // than its inner dimension(s)], check only applies when both: + // i. already encountered an inner dimension + // ii. not at the fastest dimension + if (j != 0 && inner_dim != -1) { + // we are not looking at dim-j, but dim-sorted_index, which + // is the j-th fastest dim; + // TODO: merge this with above and put a long comment there + if (t_strides[sorted_index] < t_strides[inner_dim]) { + return false; + } + } + + // check b.2. contiguity, we only check when it's marked as + // contiguous. + if (stride_properties[j]->contiguous_ && + *stride_properties[j]->contiguous_) { + if (j != 0) { + // we use contiguity to collapse dimension, if size == 1, it is + // always collapsible + if (t_sizes[sorted_index] != 1) { + TORCH_INTERNAL_ASSERT( + stride_properties[j - 1]->stride_index_.has_value(), + "Counknown index is meaningless"); + // TODO: merge this check up + if (t_strides[sorted_index] != + t_strides[inner_dim] * t_sizes[inner_dim]) { + return false; + } + } + } else { + // TODO: merge this check up + if (t_strides[sorted_index] != 1) { + return false; + } + } + } + + // update inner_dim to be current dim. Note that we try to skip update + // when current `t_size[sorted_index] == 1`, because: + // 1. stride comparison on a size-1 dimension is meaningless + // [check b.1] + // 2. contiguity on a size-1 dimension is misleading. For collapsing, + // we should actually look at the next non-size-1 dimension + // [check b.2] + if (inner_dim == -1 || t_sizes[sorted_index] != 1) { + inner_dim = sorted_index; + } + } + + // check c, we go along semantic ordered dimensions + // check broadcast / size-1: + bool guard_bcast = sizes[j].has_value() && sizes[j].value() == 1; + if (guard_bcast != (t_sizes[j] == 1)) { + return false; + } + } + + return true; } } // namespace cuda } // namespace fuser namespace { -RegisterOperators reg({ + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_fusion({ Operator( prim::CudaFusionGroup, [](const Node* node) -> Operation { @@ -46,9 +187,52 @@ RegisterOperators reg({ fuser::cuda::runFusionGroup(node, *stack); }; }, - c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE), + aliasAnalysisSpecialCase()), }); -} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_guard({ + Operator( + "prim::CudaFusionGuard(...) -> bool", + // prim::CudaFusionGuard returns a fresh Boolean type without aliasing. + // if we would ever return refined tensor, which would change aliasing + // analysis, we should update aliasdb pass. + [](const Node* node) -> Operation { + return [node](Stack* stack) { + // TODO: check latency here!!!! + std::vector types = node->tys(attr::types); + const auto num_inputs = types.size(); + at::ArrayRef inputs = last(stack, num_inputs); + drop(stack, num_inputs); + + if (!fuser::cuda::getCudaFusionGuardMode()) { + push(stack, IValue(true)); + return; + } + + for (size_t i = 0; i < num_inputs; i++) { + const c10::TensorTypePtr& guard_tensor_type = + types[i]->cast(); + + // TODO: maybe we should just push false and fallback + TORCH_INTERNAL_ASSERT(inputs[i].isTensor()); + const at::Tensor& tensor = inputs[i].toTensor(); + + if (!fuser::cuda::complyWith(tensor, guard_tensor_type)) { + push(stack, IValue(false)); + return; + } + } + + // TODO: check type and return the right flag + // naively return true; + push(stack, IValue(true)); + return; + }; + }, + aliasAnalysisFromSchema()), +}); +} // namespace } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/interface.h b/torch/csrc/jit/codegen/cuda/interface.h index 0479a124705ac..00d94a9f12e01 100644 --- a/torch/csrc/jit/codegen/cuda/interface.h +++ b/torch/csrc/jit/codegen/cuda/interface.h @@ -2,6 +2,7 @@ #include #include +#include /* * This file contains APIs for cuda fuser; @@ -16,11 +17,14 @@ namespace jit { namespace fuser { namespace cuda { +TORCH_API std::atomic& getCudaFusionGuardMode(); + // dummy struct to allow API registration struct CudaFuserInterface { void (*fn_compile_n_)(Node*) = nullptr; void (*fn_run_n_s_)(const Node*, Stack&) = nullptr; - void (*fn_fuse_graph)(std::shared_ptr&) = nullptr; + void (*fn_fuse_graph_)(std::shared_ptr&) = nullptr; + bool (*fn_can_fuse_n_)(const Node*) = nullptr; }; // Get interface, this is used by registration and user facing API internally @@ -29,6 +33,11 @@ C10_EXPORT CudaFuserInterface* getFuserInterface(); C10_EXPORT void compileFusionGroup(Node* fusion_node); C10_EXPORT void runFusionGroup(const Node* fusion_node, Stack& stack); C10_EXPORT void fuseGraph(std::shared_ptr&); +C10_EXPORT bool canFuseNode(const Node* node); + +C10_EXPORT bool complyWith( + const at::Tensor& tensor, + const c10::TensorTypePtr& guard_tensor_type); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index 9f6b3fdb50b65..9d625b3c1a628 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -1,4 +1,3 @@ - #include #include #include @@ -18,6 +17,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { Statement::Statement(const Statement* src, IrCloner* ir_cloner) { name_ = src->name_; @@ -234,6 +234,7 @@ bool Expr::sameAs(const Expr* const other) const { return true; } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index 2719cd056f95c..29d284d9b5ba0 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -34,6 +34,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { using StmtNameType = unsigned int; @@ -352,6 +353,7 @@ class TORCH_CUDA_API Expr : public Statement { std::vector outputs_; }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp index 17efc3e692e7a..72ae3d51c567e 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp @@ -1,11 +1,12 @@ - #include + #include #include namespace torch { namespace jit { namespace fuser { +namespace cuda { Statement* IrCloner::clone(const Statement* statement) { if (statement == nullptr) { @@ -114,6 +115,7 @@ void IrCloner::handle(const Merge* merge) { clone_ = new Merge(merge, this); } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.h b/torch/csrc/jit/codegen/cuda/ir_cloner.h index 39435aab4e657..213154c810597 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.h +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.h @@ -1,4 +1,3 @@ - #pragma once #include @@ -10,6 +9,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { class Fusion; @@ -81,6 +81,7 @@ class TORCH_CUDA_API IrCloner : private OptInConstDispatch { std::unordered_map clones_map_; }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp index 488e626299ad4..ea3d31b8a53d3 100644 --- a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp @@ -1,5 +1,5 @@ - #include + #include #include #include @@ -9,6 +9,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace { @@ -372,8 +373,8 @@ void IrGraphGenerator::handle(const TensorView* tv) { const bool is_output = outputs_.find(tv) != outputs_.end(); const char* style = is_input ? "style=filled, fillcolor=palegreen" - : is_output ? "style=filled, fillcolor=lightblue" - : "style=filled, fillcolor=beige"; + : is_output ? "style=filled, fillcolor=lightblue" + : "style=filled, fillcolor=beige"; graph_def_ << " " << getid(tv) << " [label=\"" << label.str() << "\", shape=Mrecord, color=brown, " << style << "];\n"; @@ -456,6 +457,7 @@ void IrGraphGenerator::handle(const Merge* merge) { addArc(merge, merge->out()); } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.h b/torch/csrc/jit/codegen/cuda/ir_graphviz.h index e3c41fb525ff0..4c8e0bf0e4678 100644 --- a/torch/csrc/jit/codegen/cuda/ir_graphviz.h +++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.h @@ -1,4 +1,3 @@ - #pragma once #include @@ -13,6 +12,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { // Generates a DOT (https://www.graphviz.org) graph // representation of a fuser IR @@ -110,6 +110,7 @@ class TORCH_CUDA_API IrGraphGenerator : private OptInConstDispatch { int next_id_ = 1; }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 4186f7dfcd885..ff0c709001f03 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -16,6 +16,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { /* * A Bool value. @@ -413,6 +414,7 @@ class TORCH_CUDA_API TensorView : public Val { MemoryType memory_type_ = MemoryType::Local; }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index ca71fd6c2d623..d5e573344ca75 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -15,6 +15,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { // Returns true if both v1 and v2 are scalars, are the same type of scalars, and // dispatches to the inherited Val type's `->sameAs` call. e.g. if both vals are @@ -677,6 +678,7 @@ class TORCH_CUDA_API NamedScalar : public Val { std::string name_; }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index e82e3fd5baa46..8108a3c04e92a 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -1,5 +1,5 @@ - #include + #include #include #include @@ -8,6 +8,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { // Make sure we can inline something, before we attempt to. static void checkInlineable(const Expr* expr) { @@ -176,39 +177,39 @@ void IrPrinter::handle(const NamedScalar* i) { } void IrPrinter::handle(const kir::Bool* b) { - os_ << "kir::Bool"; + os_ << "kir::Bool (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const kir::Float* f) { - os_ << "kir::Float"; + os_ << "kir::Float (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const kir::Half* h) { - os_ << "kir::Half"; + os_ << "kir::Half (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const kir::Int* i) { - os_ << "kir::Int"; + os_ << "kir::Int (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const kir::NamedScalar*) { - os_ << "kir::NamedScalar"; + os_ << "kir::NamedScalar (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const kir::TensorIndex*) { - os_ << "kir::TensorIndex"; + os_ << "kir::TensorIndex (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const kir::IterDomain*) { - os_ << "kir::IterDomain"; + os_ << "kir::IterDomain (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const kir::TensorDomain*) { - os_ << "kir::TensorDomain"; + os_ << "kir::TensorDomain (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const kir::TensorView*) { - os_ << "kir::TensorView"; + os_ << "kir::TensorView (use kir::toString() to print Kernel IR nodes)"; } static bool isTV(const Val* val) { @@ -350,15 +351,15 @@ void IrPrinter::handle(const TernaryOp* top) { } void IrPrinter::handle(const kir::UnaryOp* uop) { - os_ << "kir::UnaryOp"; + os_ << "kir::UnaryOp (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const kir::BinaryOp* bop) { - os_ << "kir::BinaryOp"; + os_ << "kir::BinaryOp (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const kir::TernaryOp* top) { - os_ << "kir::TernaryOp"; + os_ << "kir::TernaryOp (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const ReductionOp* rop) { @@ -370,11 +371,11 @@ void IrPrinter::handle(const ReductionOp* rop) { } void IrPrinter::handle(const kir::ReductionOp* rop) { - os_ << "kir::ReductionOp"; + os_ << "kir::ReductionOp (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const kir::GridReduction* gr) { - os_ << "kir::GridReduction"; + os_ << "kir::GridReduction (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const BroadcastOp* bop) { @@ -384,23 +385,23 @@ void IrPrinter::handle(const BroadcastOp* bop) { } void IrPrinter::handle(const kir::BroadcastOp*) { - os_ << "kir::BroadcastOp"; + os_ << "kir::BroadcastOp (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const kir::ForLoop* fl) { - os_ << "kir::ForLoop"; + os_ << "kir::ForLoop (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const kir::IfThenElse* ite) { - os_ << "kir::IfThenElse"; + os_ << "kir::IfThenElse (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const kir::Allocate* a) { - os_ << "kir::Allocate"; + os_ << "kir::Allocate (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const kir::Sync* a) { - os_ << "kir::Sync"; + os_ << "kir::Sync (use kir::toString() to print Kernel IR nodes)"; } void IrPrinter::handle(const Split* s) { @@ -440,6 +441,7 @@ std::ostream& operator<<(std::ostream& os, Fusion& f) { return os << &f; } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index 01e8bdaa09dcb..9a2323e727995 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -1,4 +1,3 @@ - #pragma once #include @@ -10,6 +9,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { //! Define pretty printing functions for IR nodes //! @@ -110,9 +110,33 @@ class TORCH_CUDA_API IrPrinter : public OptInConstDispatch { TORCH_CUDA_API std::ostream& operator<<( std::ostream& os, const Statement* stmt); + TORCH_CUDA_API std::ostream& operator<<(std::ostream& os, Fusion* f); TORCH_CUDA_API std::ostream& operator<<(std::ostream& os, Fusion& f); +// TODO(kir): catch accidental << printing of Kernel IR nodes +// (use kir::toString(node) instead) +std::ostream& operator<<(std::ostream& os, const kir::Bool*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::Float*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::Half*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::Int*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::NamedScalar*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::TensorIndex*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::IterDomain*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::TensorDomain*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::TensorView*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::UnaryOp*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::BinaryOp*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::TernaryOp*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::ReductionOp*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::BroadcastOp*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::GridReduction*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::ForLoop*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::IfThenElse*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::Allocate*) = delete; +std::ostream& operator<<(std::ostream& os, const kir::Sync*) = delete; + +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 2e1e34de6871e..3d5ac416c93a3 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -12,6 +12,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace { @@ -1197,8 +1198,8 @@ class ConcretizeDomain : private BackwardVisitor { bcast_domain_map_[id] = concretized(To); } -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Woverloaded-virtual" + using BackwardVisitor::handle; + void handle(ReductionOp* rop) override { concretizePwOp(rop); } @@ -1214,7 +1215,6 @@ class ConcretizeDomain : private BackwardVisitor { void handle(TernaryOp* top) override { concretizePwOp(top); }; -#pragma clang diagnostic pop private: using MapType = std::unordered_map; @@ -1222,7 +1222,12 @@ class ConcretizeDomain : private BackwardVisitor { }; void ConcretizeDomain::concretizePwOp(Expr* e) { - TensorView* tv = *ir_utils::filterByType(e->outputs()).begin(); + if (e->output(0)->getValType() != ValType::TensorView) { + return; + } + + TORCH_INTERNAL_ASSERT(e->outputs().size() == 1); + TensorView* tv = e->output(0)->as(); std::vector io = tv->getRootDomain(); @@ -1316,8 +1321,13 @@ class ProveValEqual : private IterVisitor { // Inspect a pointwise op and record the identified equality void provePwOp(Expr* e) { - TensorView* tv = *ir_utils::filterByType(e->outputs()).begin(); - std::vector io = tv->getRootDomain(); + if (e->output(0)->getValType() != ValType::TensorView) { + return; + } + + TORCH_INTERNAL_ASSERT(e->outputs().size() == 1); + TensorView* tv = e->output(0)->as(); + const std::vector& io = tv->getRootDomain(); // Record equalities from output to all the inputs // ignores un-concretizable broadcasts @@ -1331,8 +1341,8 @@ class ProveValEqual : private IterVisitor { } } -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Woverloaded-virtual" + using IterVisitor::handle; + void handle(ReductionOp* rop) override { provePwOp(rop); } @@ -1348,7 +1358,6 @@ class ProveValEqual : private IterVisitor { void handle(TernaryOp* top) override { provePwOp(top); } -#pragma clang diagnostic pop private: ConcretizeDomain cd_; @@ -1470,6 +1479,7 @@ c10::optional NamedScalar::getParallelIndex() const { return c10::nullopt; } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_printer.h b/torch/csrc/jit/codegen/cuda/ir_printer.h index 57ca00076afca..67aab7aaa8007 100644 --- a/torch/csrc/jit/codegen/cuda/ir_printer.h +++ b/torch/csrc/jit/codegen/cuda/ir_printer.h @@ -1,4 +1,3 @@ - #pragma once #include @@ -10,6 +9,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { //! Prints computation Fusion IR nodes //! @@ -62,6 +62,7 @@ class TORCH_CUDA_API IrTransformPrinter : public IrPrinter { } }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.h b/torch/csrc/jit/codegen/cuda/ir_utils.h index 1b51212e500ef..e5402dafb71d5 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.h +++ b/torch/csrc/jit/codegen/cuda/ir_utils.h @@ -7,7 +7,7 @@ namespace torch { namespace jit { namespace fuser { - +namespace cuda { namespace ir_utils { template @@ -110,6 +110,7 @@ auto filterByType(const ContainerType& inputs) { } } // namespace ir_utils +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp index 1a846fa96a725..d1efc6d163eb0 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -7,6 +8,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { /* ITER VISITOR */ @@ -540,6 +542,7 @@ std::unordered_set InputsOf::output(Fusion* fusion, Val* output_) { return io.inputs; } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.h b/torch/csrc/jit/codegen/cuda/iter_visitor.h index cf01e903f3a14..d70c4f9805449 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.h +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.h @@ -16,6 +16,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { /* * IterVisitor starts from leaf nodes, fusion outputs, or the provided values. @@ -67,8 +68,8 @@ class TORCH_CUDA_API IterVisitor : public OptOutDispatch { virtual std::vector next(Expr* expr) { FusionGuard::getCurFusion()->assertInFusion(expr, "Cannot traverse expr, "); - std::vector next_stmts{expr->inputs().begin(), - expr->inputs().end()}; + std::vector next_stmts{ + expr->inputs().begin(), expr->inputs().end()}; return next_stmts; } @@ -263,6 +264,7 @@ class InputsOf : public IterVisitor { static std::unordered_set output(Fusion* fusion, Val* output_); }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel.cpp b/torch/csrc/jit/codegen/cuda/kernel.cpp index c6c0a39ccb793..1ad8d0699f40a 100644 --- a/torch/csrc/jit/codegen/cuda/kernel.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel.cpp @@ -1,13 +1,16 @@ - #include + #include #include +#include +#include #include namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace { @@ -152,6 +155,12 @@ void Kernel::analyze() { } } +void Kernel::print() const { + kir::IrPrinter ir_printer(std::cout); + ir_printer.printKernel(this); +} + +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel.h b/torch/csrc/jit/codegen/cuda/kernel.h index 1d7b1834c39f4..41485bb8c39d5 100644 --- a/torch/csrc/jit/codegen/cuda/kernel.h +++ b/torch/csrc/jit/codegen/cuda/kernel.h @@ -1,4 +1,3 @@ - #pragma once #include @@ -13,6 +12,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { //! Summary of interesting facts about the kernel //! @@ -105,6 +105,9 @@ class TORCH_CUDA_API Kernel final : public NonCopyable { ir_nodes_.push_back(std::move(node)); } + //! Debug dump of the Kernel IR + void print() const; + private: // Analyze the kernel IR and caches the summary of interesting data void analyze(); @@ -131,6 +134,7 @@ class TORCH_CUDA_API Kernel final : public NonCopyable { std::unique_ptr predicate_map_; }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp index e8300970eb59f..6b5585d044141 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp @@ -1,5 +1,5 @@ - #include + #include #include #include @@ -13,6 +13,26 @@ namespace cuda { namespace { +// Check device of TensorType in all inputs ensure all tensors are on cuda +// devices. +// return common device index (or -1 if device differs). +int getCommonDeviceCUDA(const at::ArrayRef& inputs) { + int index = -1; + for (const auto& input : inputs) { + if (!input.isTensor()) { + continue; + } + const auto& device = input.toTensor().device(); + TORCH_CHECK(device.is_cuda(), "nvfuser only supports cuda device"); + auto cur_index = device.index(); + if (index != -1 && index != cur_index) { + return -1; + } + index = cur_index; + } + return index; +} + // TODO: temporary hack to resolve my is_constructible issue; std::vector toVector(const at::DimVector& small_vec) { return std::vector(small_vec.begin(), small_vec.end()); @@ -21,52 +41,49 @@ std::vector toVector(const at::DimVector& small_vec) { #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wunused-function" void debugPrint(const TensorTypePtr& type) { - printf("\nsizes:"); + std::stringstream sizes_s; if (auto sizes = type->symbolic_sizes().sizes()) { - // for (const auto& shape_symbol : sizes.value()) { - int rank = static_cast(sizes->size()); - for (int i = 0; i < rank; i++) { - const auto& shape_symbol = sizes.value()[i]; + for (const auto& shape_symbol : *sizes) { if (shape_symbol.is_static()) { - printf("%ld, ", shape_symbol.static_size()); + sizes_s << shape_symbol.static_size() << ", "; } else { - printf("s(%ld), ", *reinterpret_cast(&shape_symbol)); + sizes_s << "s(" << *reinterpret_cast(&shape_symbol) + << "), "; } } } else { - printf("no size available\n"); + sizes_s << "no size available"; } + std::cout << "sizes:" << sizes_s.str() << std::endl; if (const auto& stride_properties = type->stride_properties().sizes()) { - int rank = static_cast(stride_properties->size()); - printf("\nstride: "); - for (int i = 0; i < rank; i++) { - if ((*stride_properties)[i].has_value() && - (*stride_properties)[i]->stride_.has_value()) { - printf("%ld, ", (*stride_properties)[i]->stride_.value()); + std::stringstream stride_s; + std::stringstream index_s; + std::stringstream contig_s; + + for (const auto& stride_property : *stride_properties) { + if (stride_property.has_value() && stride_property->stride_.has_value()) { + stride_s << *stride_property->stride_ << ", "; } else { - printf("?, "); + stride_s << "?, "; } - } - printf("\nstride index: "); - for (int i = 0; i < rank; i++) { - if ((*stride_properties)[i].has_value() && - (*stride_properties)[i]->stride_index_.has_value()) { - printf("%ld, ", (*stride_properties)[i]->stride_index_.value()); + if (stride_property.has_value() && + stride_property->stride_index_.has_value()) { + index_s << *stride_property->stride_index_ << ", "; } else { - printf("?, "); + index_s << "?, "; } - } - printf("\ncontiguous: "); - for (int i = 0; i < rank; i++) { - if ((*stride_properties)[i].has_value() && - (*stride_properties)[i]->contiguous_.has_value()) { - printf("%d, ", (*stride_properties)[i]->contiguous_.value()); + if (stride_property.has_value() && + stride_property->contiguous_.has_value()) { + contig_s << *stride_property->contiguous_ << ", "; } else { - printf("?, "); + contig_s << "?, "; } } + std::cout << "stride: " << stride_s.str() << std::endl; + std::cout << "stride index: " << index_s.str() << std::endl; + std::cout << "contiguous: " << contig_s.str() << std::endl; } else { - printf("no stride properties available\n"); + std::cout << "no stride properties available" << std::endl; } } #pragma clang diagnostic pop @@ -95,6 +112,7 @@ at::DimVector graphReductionAxes(const std::shared_ptr& graph) { return reduction_axes; } +// TODO(CONTIGUITY) at::DimVector getPermutationPerSortedStride(const TensorTypePtr& type) { FUSER_PERF_SCOPE("getPermutationPerSortedStride"); @@ -192,7 +210,7 @@ InputsIdLookup::IdLookupReturn InputsIdLookup::lookupId( std::stringstream encoded_inputs; for (const auto& input : inputs) { if (input.isTensor()) { - auto input_tensor = input.toTensor(); + auto& input_tensor = input.toTensor(); encoded_inputs << ";"; auto sep = ""; @@ -206,6 +224,7 @@ InputsIdLookup::IdLookupReturn InputsIdLookup::lookupId( encoded_inputs << sep << stride; sep = ","; } + encoded_inputs << "@" << input_tensor.device().str(); } else { // encode s for scalar; encoded_inputs << ";s"; @@ -240,19 +259,27 @@ InputsIdLookup::IdLookupReturn InputsIdLookup::lookupId( return ret; } -FusionExecutorCache::FusionExecutorCache( - std::unique_ptr&& fusion, - at::Device device) - : device_(device), fusion_(std::move(fusion)) { +FusionExecutorCache::FusionExecutorCache(std::unique_ptr&& fusion) + : fusion_(std::move(fusion)) { FUSER_PERF_SCOPE("FusionExecutorCache::FusionExecutorCache"); // avoid putting `has_reduction_` in the initializer list has_reduction_ = fusion_->hasReduction(); } std::vector FusionExecutorCache::runFusionWithInputs( - const at::ArrayRef& inputs, - size_t unique_id) { + const at::ArrayRef& inputs) { FUSER_PERF_SCOPE("runFusionWithInputs"); + + // get unique id `unique_id` for given input set `inputs`; + auto id_lookup_ret = inputs_id_lookup_.lookupId(inputs); + if (id_lookup_ret.eviction) { + evictCache(id_lookup_ret.evict_id); + } + + const size_t unique_id = id_lookup_ret.id; + const int device_index = getCommonDeviceCUDA(inputs); + TORCH_CHECK(device_index >= 0, "device is not coherent for fusion inputs"); + LaunchParams launch_params; if (code_to_fe_lookup_.count(unique_id) == 0) { // enter when we get a new input set. We need to search for compatible @@ -300,7 +327,7 @@ std::vector FusionExecutorCache::runFusionWithInputs( launch_params = reduction_params.value().lparams; auto fusion_executor = - &red_fusion_executor_cache_[reduction_params.value()]; + &red_fusion_executor_cache_[device_index][reduction_params.value()]; if (!fusion_executor->compiled()) { // HEURISTIC NOT COMPILED, COMPILE A KERNEL @@ -349,7 +376,7 @@ std::vector FusionExecutorCache::runFusionWithInputs( // This means we have not found a previously generated kernel that's // compatible with the new reduction params. We need to finish codegen. CompileOptions options; - options.device = device_; + options.device = c10::Device(DeviceType::CUDA, device_index); fusion_executor->compileFusion(&fusion, options); } // record new short cut to `FusionExecutor` @@ -357,17 +384,20 @@ std::vector FusionExecutorCache::runFusionWithInputs( } else { // Handle pointwise operations - if (!pw_fusion_executor_cache_) { - pw_fusion_executor_cache_ = std::make_unique(); + if (pw_fusion_executor_cache_.count(device_index) == 0) { + pw_fusion_executor_cache_[device_index] = + std::make_unique(); CompileOptions options; - options.device = device_; + options.device = c10::Device(DeviceType::CUDA, device_index); // no need to copy fusion_, as we are not generating more than 1 kernel // for PW. scheduleFusion(fusion_.get(), inputs); - pw_fusion_executor_cache_->compileFusion(fusion_.get(), options); + pw_fusion_executor_cache_[device_index]->compileFusion( + fusion_.get(), options); } // record new short cut to `FusionExecutor` - code_to_fe_lookup_[unique_id] = pw_fusion_executor_cache_.get(); + code_to_fe_lookup_[unique_id] = + pw_fusion_executor_cache_[device_index].get(); } } @@ -375,62 +405,7 @@ std::vector FusionExecutorCache::runFusionWithInputs( inputs, launch_params, unique_id); } -GraphCache::InputsRequirement::InputsRequirement( - const std::shared_ptr& graph, - const std::vector& reduction_axes) { - FUSER_PERF_SCOPE("InputsRequirement::InputsRequirement"); - - // run over inputs to extract common types; - TensorTypePtr acc_type = TensorType::get(); - for (const auto& input : graph->inputs()) { - // only check tensor types; - if (auto input_type = input->type()->cast()) { - vec_optional_ttp.emplace_back(input_type); - if (acc_type->dim().has_value()) { - // TODO: I think merge cannot handle broadcast - Go verify it later; - // TODO: Since we are only handling permutation here, we should just - // merge the stride_index_; - acc_type = acc_type->merge(input_type); - } else { - acc_type = input_type; - } - } else { - vec_optional_ttp.emplace_back(c10::nullopt); - } - } - extractPermutation(acc_type, reduction_axes); -} - -GraphCache::InputsRequirement::InputsRequirement( - const at::ArrayRef& inputs, - const std::vector& reduction_axes) { - FUSER_PERF_SCOPE("InputsRequirement::InputsRequirement"); - - // run over inputs to extract common types; - TensorTypePtr acc_type = TensorType::get(); - for (const auto& input : inputs) { - // only check tensor types; - if (input.isTensor()) { - // TensorType::create populates stride properties; - // auto input_type = TensorType::create(input.toTensor()); - // vec_optional_ttp.emplace_back(input_type); - vec_optional_ttp.emplace_back(TensorType::create(input.toTensor())); - if (acc_type->dim().has_value()) { - // TODO: I think merge cannot handle broadcast - Go verify it later; - // TODO: Since we are only handling permutation here, we should just - // merge the stride_index_; - acc_type = acc_type->merge(vec_optional_ttp.back().value()); - } else { - acc_type = vec_optional_ttp.back().value(); - } - } else { - vec_optional_ttp.emplace_back(c10::nullopt); - } - } - extractPermutation(acc_type, reduction_axes); -} - -bool GraphCache::InputsRequirement::requiresPermutation() { +bool GraphCache::requiresPermutation() { const size_t input_rank = input_permutation_.size(); for (size_t i = 0; i < input_rank; i++) { if (input_permutation_[i] != (long)i) { @@ -453,116 +428,22 @@ bool GraphCache::InputsRequirement::requiresPermutation() { return false; } -// TODO: tests! -bool GraphCache::InputsRequirement::complyWith( - const InputsRequirement& expect) { - FUSER_PERF_SCOPE("InputsRequirement::complyWith"); - - if (device_ != expect.device_ || - input_permutation_ != expect.input_permutation_ || - pw_output_permutation_ != expect.pw_output_permutation_ || - reduction_output_permutation_ != expect.reduction_output_permutation_ || - vec_optional_ttp.size() != expect.vec_optional_ttp.size()) { - return false; - } - - // trick here is, `this` is always well defined while `expect` could has - // missing options; - for (size_t i = 0; i < vec_optional_ttp.size(); i++) { - // TensorType has to match, otherwise it's not compatible to our graph. - auto expect_vec_optional_ttp_i = expect.vec_optional_ttp[i]; - TORCH_INTERNAL_ASSERT( - vec_optional_ttp[i].has_value() == - expect_vec_optional_ttp_i.has_value()); - if (expect_vec_optional_ttp_i.has_value()) { - // We assume that dimensionality should always match. - TORCH_INTERNAL_ASSERT( - (*expect_vec_optional_ttp_i)->symbolic_sizes().sizes().has_value() && - (*expect_vec_optional_ttp_i) - ->stride_properties() - .sizes() - .has_value() && - (*expect_vec_optional_ttp_i)->dim().has_value() && - (*vec_optional_ttp[i])->dim().value() && - (*expect_vec_optional_ttp_i)->dim().value() == - (*vec_optional_ttp[i])->dim().value(), - "expect fixed rank of tensors"); - - int rank = static_cast((*expect_vec_optional_ttp_i)->dim().value()); - auto vec_shape_symbol_ex = - (*expect_vec_optional_ttp_i)->symbolic_sizes().sizes().value(); - auto vec_optional_stride_ex = - (*expect_vec_optional_ttp_i)->stride_properties().sizes().value(); - auto vec_shape_symbol = - (*vec_optional_ttp[i])->symbolic_sizes().sizes().value(); - auto vec_optional_stride = - (*vec_optional_ttp[i])->stride_properties().sizes().value(); - for (int j = 0; j < rank; j++) { - // if broadcast rule differs, compliance is broken; - if ((vec_shape_symbol_ex[j].is_static() && - vec_shape_symbol_ex[j].static_size() == 1) ^ - (vec_shape_symbol[j].is_static() && - vec_shape_symbol[j].static_size() == 1)) { - return false; - } - - const auto& vec_optional_stride_ex_j = vec_optional_stride_ex[j]; - const auto& vec_optional_stride_j = vec_optional_stride[j]; - // if contiguity / stride index differ, compliance is broken; - if (vec_optional_stride_ex_j.has_value() != - vec_optional_stride_j.has_value()) { - return false; - } - if (vec_optional_stride_ex_j.has_value() && - (vec_optional_stride_ex_j->stride_index_ != - vec_optional_stride_j->stride_index_ || - vec_optional_stride_ex_j->contiguous_ != - vec_optional_stride_j->contiguous_)) { - return false; - } - } - } - } - return true; -} - -void GraphCache::InputsRequirement::extractPermutation( - const TensorTypePtr& acc_type, - const std::vector& reduction_axes) { +void GraphCache::extractPermutation(const TensorTypePtr& acc_type) { input_permutation_ = getPermutationPerSortedStride(acc_type); reduction_output_permutation_ = - inversePermutation(input_permutation_, reduction_axes); + inversePermutation(input_permutation_, toVector(reduction_axes_)); pw_output_permutation_ = inversePermutation(input_permutation_, {}); - TORCH_CHECK( - acc_type->device().has_value(), "requires fixed device for all inputs"); - device_ = acc_type->device(); } -FusionExecutorCache* GraphCache::appendFusionExecutorCache( - const InputsRequirement& input_stack) { - FUSER_PERF_SCOPE("createFusionExecutorCache"); - - input_stacks_.emplace_back(input_stack); - std::shared_ptr parsing_graph = graph_->copy(); - // assign inputs on parsing_graph to accommodate legacy executor, where input - // type might be missing/incomplete; - // This is purely overhead for profiling executor; - for (size_t i = 0; i < input_stack.vec_optional_ttp.size(); i++) { - // skip scalar inputs; - if (input_stack.vec_optional_ttp[i].has_value()) { - parsing_graph->inputs()[i]->setType( - input_stack.vec_optional_ttp[i].value()); - } - } +void GraphCache::createFusion(const std::shared_ptr& graph) { + FUSER_PERF_SCOPE("GraphCache::createFusion"); // permute inputs on `Graph` to sort dimensions on common stride order; - if (input_stacks_.back().requiresPermutation()) { - auto input_permutation = input_stacks_.back().input_permutation_; - + if (requiresPermutation()) { // TODO: lambda is a bad idea, the logic in this function is too tricky and // should be properly tested to ensure correctness. - // lambda to permute `TensorType` axes per `input_permutation` - auto type_permute_fn = [&input_permutation](const TensorTypePtr& type) { + // lambda to permute `TensorType` axes per `input_permutation_` + auto type_permute_fn = [this](const TensorTypePtr& type) { // std::vector vec_shape_symbol = // type->symbolic_sizes().sizes().value(); auto vec_shape_symbol = type->symbolic_sizes().sizes().value(); @@ -575,7 +456,8 @@ FusionExecutorCache* GraphCache::appendFusionExecutorCache( std::vector permuted_vec_ss; std::vector> permuted_vec_optional_stride; for (int i = 0; i < rank; i++) { - permuted_vec_ss.emplace_back(vec_shape_symbol[input_permutation[i]]); + permuted_vec_ss.emplace_back( + vec_shape_symbol[this->input_permutation_[i]]); // permutation doesn't change contiguity info, nor does it change // stride; The only thing affected is stride_index_; if (vec_optional_stride[i].has_value()) { @@ -583,7 +465,7 @@ FusionExecutorCache* GraphCache::appendFusionExecutorCache( if (index.has_value()) { for (int j = 0; j < rank; j++) { // follow the permutation to resolve the new stride_index; - if (input_permutation[j] == (long)index.value()) { + if (this->input_permutation_[j] == (long)index.value()) { index = j; break; } @@ -606,7 +488,7 @@ FusionExecutorCache* GraphCache::appendFusionExecutorCache( type->requires_grad()); }; // closing lambda - for (auto input : parsing_graph->inputs()) { + for (auto input : graph->inputs()) { if (auto input_type = input->type()->cast()) { input->setType(type_permute_fn(input_type)); } @@ -614,7 +496,7 @@ FusionExecutorCache* GraphCache::appendFusionExecutorCache( if (!reduction_axes_.empty()) { // see [ NOTE - reduction in graph ] part 2. - for (auto n : parsing_graph->nodes()) { + for (auto n : graph->nodes()) { if (isReductionNode(n)) { auto dims_list = constant_as>(n->input(1)); TORCH_INTERNAL_ASSERT( @@ -622,34 +504,31 @@ FusionExecutorCache* GraphCache::appendFusionExecutorCache( std::vector adjusted_reduction_axes; for (const auto dim : dims_list->vec()) { // adjust reduction axis to be the permuted axis; - for (size_t j = 0; j < input_permutation.size(); j++) { + for (size_t j = 0; j < input_permutation_.size(); j++) { // follow the permutation to resolve the new reduction axes; - if (input_permutation[j] == dim) { + if (input_permutation_[j] == dim) { adjusted_reduction_axes.emplace_back(j); break; } } } - parsing_graph->setInsertPoint(n); + graph->setInsertPoint(n); auto const_ival_axes = - parsing_graph->insertConstant(IValue(adjusted_reduction_axes)); + graph->insertConstant(IValue(adjusted_reduction_axes)); n->replaceInput(1, const_ival_axes); } } } } - TORCH_INTERNAL_ASSERT( - input_stacks_.back().device_.has_value(), - "device is not set for fusion executor, something went wrong in NvFuser"); - fe_cache_.emplace_back(std::make_unique( - parseJitIR(parsing_graph), input_stacks_.back().device_.value())); - return fe_cache_.back().get(); + fusion_executor_cache_ = + std::make_unique(parseJitIR(graph)); } -GraphCache::GraphCache(std::shared_ptr graph) - : graph_(std::move(graph)) { +GraphCache::GraphCache(const std::shared_ptr& graph) { FUSER_PERF_SCOPE("GraphCache::GraphCache"); + TORCH_INTERNAL_ASSERT( + IsNewExecutorEnabled(), "legacy executor is not supported by nvfuser"); // [ NOTE - reduction in graph ] // @@ -661,104 +540,66 @@ GraphCache::GraphCache(std::shared_ptr graph) // 2. adjust reduction axes for the permutation; // permute changes the semantics of axes, we need to update the reduction // axes in the graph in order to match the behavior; - reduction_axes_ = graphReductionAxes(graph_); + reduction_axes_ = graphReductionAxes(graph); - // compile a kernel if we have enough information from graph (profiling - // record) - if (IsNewExecutorEnabled()) { - appendFusionExecutorCache( - InputsRequirement(graph_, toVector(reduction_axes_))); + // run over inputs to extract common types; + TensorTypePtr acc_type = TensorType::get(); + for (const auto& input : graph->inputs()) { + // only check tensor types; + if (auto input_type = input->type()->cast()) { + if (acc_type->dim().has_value()) { + // TODO: I think merge cannot handle broadcast - Go verify it later; + // TODO: Since we are only handling permutation here, we should just + // merge the stride_index_; + acc_type = acc_type->merge(*input_type); + } else { + acc_type = input_type; + } + } } + extractPermutation(acc_type); + createFusion(graph); } std::vector GraphCache::runGraphWithInputs( const at::ArrayRef& inputs) { FUSER_PERF_SCOPE("runGraphWithInputs"); - // get unique id `unique_id` for given input set `inputs`; - auto id_lookup_ret = inputs_id_lookup_.lookupId(inputs); - const size_t unique_id = id_lookup_ret.id; - - // if we went over the cache size for short-cut, we evict entries using LRU; - if (id_lookup_ret.eviction) { - auto index_lookup_iter = code_to_index_lookup_.find(id_lookup_ret.evict_id); - TORCH_INTERNAL_ASSERT( - index_lookup_iter != code_to_index_lookup_.end(), - "evicting cache entry not found in lookup table"); - // evict nested cache in FusionExecutorCache - fe_cache_[index_lookup_iter->second]->evictCache(index_lookup_iter->first); - code_to_index_lookup_.erase(index_lookup_iter); - } - - FusionExecutorCache* fusion_executor_cache = nullptr; - - if (code_to_index_lookup_.count(unique_id) == 0) { - InputsRequirement input_stack(inputs, toVector(reduction_axes_)); - for (size_t i = 0; i < fe_cache_.size(); i++) { - if (input_stack.complyWith(input_stacks_[i])) { - // found compliable fe_cache_ entry - fusion_executor_cache = fe_cache_[i].get(); - // record short cut to designated fusion executor - code_to_index_lookup_[unique_id] = i; - break; - } - } - if (!fusion_executor_cache) { - // This is the ugly bit, each level of cache has their own entry. At this - // point, we are creating an instance of FusionExecutorCache as well as a - // cache entry for GraphCache; - // But we are not creating any cache entry for nested structures. We only - // create cache entry below when we later call - // `fusion_executor_cache->runFusionWithInputs` - fusion_executor_cache = appendFusionExecutorCache(input_stack); - // record short cut to designated fusion executor - code_to_index_lookup_[unique_id] = fe_cache_.size() - 1; - } - } else { - // take short cut to designated fusion executor - fusion_executor_cache = fe_cache_[code_to_index_lookup_[unique_id]].get(); - } - InputsRequirement* input_requirement = - &input_stacks_[code_to_index_lookup_[unique_id]]; // GraphCache need to permute inputs/outputs to accommodate dimension // coalescing - if (input_requirement->requiresPermutation()) { + if (requiresPermutation()) { std::vector permuted_inputs; permuted_inputs.reserve(inputs.size()); for (const auto& input : inputs) { if (input.isTensor()) { permuted_inputs.emplace_back( - input.toTensor().permute(input_requirement->input_permutation_)); + input.toTensor().permute(input_permutation_)); } else { permuted_inputs.emplace_back(input); } } - auto outputs = - fusion_executor_cache->runFusionWithInputs(permuted_inputs, unique_id); + auto outputs = fusion_executor_cache_->runFusionWithInputs(permuted_inputs); std::vector permuted_outputs; permuted_outputs.reserve(outputs.size()); for (const auto& output : outputs) { // This is to address the issue that not all outputs from a reduction // fusion are reduced tensor; We support intermediate tensors to be output - if (static_cast(output.dim()) == - input_requirement->pw_output_permutation_.size()) { - permuted_outputs.emplace_back( - output.permute(input_requirement->pw_output_permutation_)); + if (static_cast(output.dim()) == pw_output_permutation_.size()) { + permuted_outputs.emplace_back(output.permute(pw_output_permutation_)); } else if ( static_cast(output.dim()) == - input_requirement->reduction_output_permutation_.size()) { + reduction_output_permutation_.size()) { permuted_outputs.emplace_back( - output.permute(input_requirement->reduction_output_permutation_)); + output.permute(reduction_output_permutation_)); } else { TORCH_INTERNAL_ASSERT( false, - "Something went wrong with integration permutation, can't find a consistent permutation for output in fusion", - *graph_); + "Something went wrong with integration permutation, can't find a consistent permutation for output in fusion"); } } return permuted_outputs; } else { - return fusion_executor_cache->runFusionWithInputs(inputs, unique_id); + return fusion_executor_cache_->runFusionWithInputs(inputs); } } diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.h b/torch/csrc/jit/codegen/cuda/kernel_cache.h index e0e8a75ea5cdd..8ceda77453d7b 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.h +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.h @@ -28,124 +28,126 @@ namespace cuda { //! class TORCH_CUDA_API InputsIdLookup { public: - // constructor where maximum cache size is fixed during init + //! constructor where maximum cache size is fixed during init explicit InputsIdLookup(size_t max_cache_size = 10) : max_cache_size_(max_cache_size){}; - // struct to hold return value for lookupId. + //! struct to hold return value for lookupId. struct IdLookupReturn { size_t id = 0; size_t evict_id = 0; bool eviction = false; }; - // encode each input sets to with an unique id; - // Returned data structure also indicates whether eviction has happened within - // the lookup cache. This is needed because lookup shortcut is also cached in - // nested `GraphCache`, `FusionExecutorCache` and `FusionExecutor`. - // see [ Note -- 2 level cache implementation ] + //! encode each input sets to with an unique id; + //! Returned data structure also indicates whether eviction has happened + //! within the lookup cache. This is needed because lookup shortcut is also + //! cached in nested `GraphCache`, `FusionExecutorCache` and `FusionExecutor`. + //! see [ Note -- 2 level cache implementation ] IdLookupReturn lookupId(const at::ArrayRef& inputs); - // debugging API + //! debugging API that returns the size of lookup table size_t size() const { return encoding_lookup_.size(); } private: - // entry stored in `encoding_lookup_` to implement LRU + //! entry stored in `encoding_lookup_` to implement LRU struct EncodingEntry { size_t id; std::list::iterator lru_iter; }; - // maximum cache size for LRU + //! maximum cache size for LRU const size_t max_cache_size_; - // next available unique id, we monotonically increase `current_id_` avoid - // conflicts + //! next available unique id, we monotonically increase `current_id_` avoid + //! conflicts size_t current_id_ = 1; - // entry in the cache, This is used to implement LRU cache, where entries in - // the list is ordered by their recent usage (freshly used entry is placed at - // the beginning) + //! entry in the cache, This is used to implement LRU cache, where entries in + //! the list is ordered by their recent usage (freshly used entry is placed at + //! the beginning) std::list used_entry_; - // map from `std::string` to a unique id `size_t` (packaged in `EncodingEntry` - // ). We store an iterator to `used_entry_` to implement LRU + //! map from `std::string` to a unique id `size_t` (packaged in + //! `EncodingEntry` + //! ). We store an iterator to `used_entry_` to implement LRU std::unordered_map encoding_lookup_; }; -// [ Note -- 2 level cache implementation ] -// -// 2 level hierarchically nested cache is to handle the code generation and -// execution of a given PyTorch IR graph that is unique in its computational -// graph (see note computational graph down). -// -// The nested cache structures are: -// a. GraphCache -// - holds a vector of `InputsRequirement` & `FusionExecutorCache`, where -// each entry is constructed to handle a set of inputs with unique -// contiguity info, stride order & broadcasting semantics, on a given -// device; -// - `InputsRequirement::complyWith` demonstrates the meta information -// that remains unchanged for a given `FusionExecutorCache` -// - At run-time (or compile-time with Profiling Executor), we extract -// `InputsRequirement` from given inputs to the fused operation. We -// iterate through existing entries within GraphCache (that is the -// `input_stacks_`) looking for a suitable entry to execute the -// computation. -// - In the case of cache miss, we generate a new entry and put it in -// the GraphCache instance (We push back to both `input_stacks_` and -// `fe_cache_`, fusion executor cache. -// b. FusionExecutorCache -// - holds a group of `FusionExecutor` to handle dynamic shape (varying -// tensor sizes) -// - currently this is a dummy implementation and has branching to handle -// different scheduler for point-wise fusion and reduction fusion; -// -// * note computational graph -// In theory, computational graph should refer to only the computational nodes -// in a subgraph and should remain agnostic to input meta info, like -// shape, strides, type e.t.c.. However, the contract right here is fuzzy. -// Different executor applies their own protocol of what is a unique -// computational graph. e.g. Legacy Executor embeds tensor type & dimensionality -// in the graph, while Profiling Executor keeps symbolic shape as well as stride -// order in the graph as well. -// Our definition of computational graph is relaxed to support Legacy Executor, -// so the `GraphCache` could handle varying memory layout of strided tensor -// (different stride order & contiguity information). We utilize the profiling -// information now by generating an entry in GraphCache with the given profiling -// record. +//! [ Note -- 2 level cache implementation ] +//! +//! We have 2 level cache for a separation in function to keep them simpler. +//! +//! 2 level hierarchically nested cache is to handle the code generation and +//! execution of a given PyTorch IR graph that is unique in its computational +//! graph (see note on unique computational graph down). +//! +//! The nested cache structures are: +//! a. GraphCache +//! - GraphCache translates PyTorch IR into Fusion IR and pass it to a +//! `FusionExecutorCache`; +//! - GraphCache assumes all inputs to comply with profiling information, +//! mostly tensor size & contiguity (see note on unique computational +//! graph). The assumption is assured at runtime by +//! `prim::CudaFusionGuard`; +//! - GraphCache handles permutation for I/O tensors, when they share +//! global stride order. This permutation facilitates dimension +//! collapsing, which gives simpler indexing. +//! b. FusionExecutorCache +//! - has a single `Fusion`, FusionExecutorCache handles kernel schedule +//! and passed scheduled tensor to `FusionExecutor` to generate code; +//! - create `FusionExecutor` instances to handle heuristics from dynamic +//! shape (varying tensor sizes); +//! - create `FusionExecutor` instances to handle different devices; +//! - holds input cache `InputsIdLookup`, which allow cache on heuristics +//! and launch parameters to reduce latency. +//! +//! * note on unique computational graph +//! In theory, computational graph should refer to only the computational nodes +//! in a subgraph and should remain agnostic to input meta info, like +//! shape, strides, type e.t.c.. However, the contract right here is fuzzy. +//! Different executor applies their own protocol of what is a unique +//! computational graph. e.g. Legacy Executor embeds tensor type & +//! dimensionality in the graph, while Profiling Executor keeps symbolic shape +//! as well as stride order in the graph as well. +//! +//! Our definition of a "unique" computational graph is aligned with `Fusion` +//! IR, hence the requirement extends to meta information on input tensors. +//! Which means, for each input tensor, following properties are fixed: +//! a) stride order; +//! b) contiguity information; +//! c) broadcasting semantics (size-1 or not); +//! d) rank; +//! e) scalar type; class FusionExecutorCache { public: - // create new fusion executor cache at a given device to handle kernel - // generation of dynamic sizes; - // fusion executor is taking the ownership of `fusion`; - FusionExecutorCache(std::unique_ptr&& fusion, at::Device device); + //! create new fusion executor cache at a given device to handle kernel + //! generation of dynamic sizes; + //! fusion executor is taking the ownership of `fusion`; + explicit FusionExecutorCache(std::unique_ptr&& fusion); - // Execute fusion graph with given inputs, create `FusionExecutor` as needed; + //! Execute fusion graph with given inputs, create `FusionExecutor` as needed; std::vector runFusionWithInputs( - const at::ArrayRef& inputs, - size_t unique_id); + const at::ArrayRef& inputs); - // evict cached short cut entry in `code_to_fe_lookup_`; - inline void evictCache(size_t cache_id) { + private: + //! evict cached short cut entry in `code_to_fe_lookup_` as well as cached + //! entry in `FusionExecutor` + void evictCache(size_t cache_id) { auto iter = code_to_fe_lookup_.find(cache_id); TORCH_INTERNAL_ASSERT( iter != code_to_fe_lookup_.end(), "evict cache failed to find an entry"); - // evict nested lookup entry in nested FusionExecutor + // evict nested lookup entry in nested `FusionExecutor` (iter->second)->evictCache(cache_id); code_to_fe_lookup_.erase(iter); }; private: - // device_ where compiled binaries are loaded on & inputs are expected to - // reside; - at::Device device_; - - // original un-scheduled `Fusion`; + //! original un-scheduled `Fusion`; std::unique_ptr fusion_; // I'm trading the const model in favor of assigning `has_reduction_` in the @@ -155,104 +157,73 @@ class FusionExecutorCache { // initizlize it in the initializer list, where the order of initialization // is controled by the order of declaration instead of their order in the list // - // cache fusion->hasReduction() because it's expensive; + //! cache fusion->hasReduction() because it's expensive; bool has_reduction_; - // TODO: ugly logic for now. We should integrate the hashing of cache for - // different kernels. (alternatively we could do so in scheduler). - // ugly bits now: - // The fact that we have heuristics only for reduction, but use a general - // kernel for all point-wise fusion ended up with this: - // 1. For point-wise fusion, we have a single `FusionExecutor` in - // `pw_fusion_executor_cache_` - // 2. For reduction fusion we have a hash table with ReductionParams as entry - // pointing to the actual `FusionExecutor` in `red_fusion_executor_cache_` - std::unique_ptr pw_fusion_executor_cache_; - std::unordered_map + //! TODO: ugly logic for now. We should integrate the hashing of cache for + //! different kernels. (alternatively we could do so in scheduler). + //! ugly bits now: + //! The fact that we have heuristics only for reduction, but use a general + //! kernel for all point-wise fusion ended up with this: + //! 1. For point-wise fusion, we have a single `FusionExecutor` in + //! `pw_fusion_executor_cache_` + //! 2. For reduction fusion we have a hash table with ReductionParams as entry + //! pointing to the actual `FusionExecutor` in `red_fusion_executor_cache_` + //! + //! Both cache_ key on device_index, because `FusionExecutor` is designated to + //! a single device + std::unordered_map> + pw_fusion_executor_cache_; + std::unordered_map< + int, + std::unordered_map> red_fusion_executor_cache_; - // short cut to FusionExecutor for input set encoded with id; + //! short cut to FusionExecutor for input set encoded with id; std::unordered_map code_to_fe_lookup_; + + //! inputs to unique_id lookup table; + InputsIdLookup inputs_id_lookup_; }; class GraphCache { public: - // TODO: we should probably change shared_ptr to unique_ptr, as we want to - // claim the ownership of the computational graph. - // create GraphCache on a given graph; - // Note: if run with profiling executor, we'll try to generete a kernel with - // profiling information at this moment. - GraphCache(std::shared_ptr graph); - - // execute graph with given inputs. + //! TODO: we should probably change shared_ptr to unique_ptr, as we want to + //! claim the ownership of the computational graph. + //! create GraphCache on a given graph; + //! We extract global stride index order and translate PyTorch JIT IR to + //! Fusion IR. + explicit GraphCache(const std::shared_ptr& graph); + + //! execute graph with given inputs, permutation on I/O tensors are performed. std::vector runGraphWithInputs( const at::ArrayRef& inputs); private: - // TODO: place holder with naive implementation for now. - // structure use to mark the compatibility of each FusionExecutorCache; - // We also have `input_permutation_` & `output_permutation_` used to - // facilitate dimension coalescing per stride order. - struct InputsRequirement { - // target device - c10::optional device_; - // TODO: TensorTypePtr is not very easy to work with. - // c10::nullopt to take place of non-tensor type; - std::vector> vec_optional_ttp; - - // common permutation order used for dimension coalescing; - at::DimVector input_permutation_; - at::DimVector pw_output_permutation_; - at::DimVector reduction_output_permutation_; - - // construct InputsRequirement from `Graph`, this is used for constructing - // `GraphCache` entry using profiling record - InputsRequirement( - const std::shared_ptr& graph, - const std::vector& reduction_axes); - - // construct InputsRequirement from live input feeds, this is used to handle - // run-time inputs to: 1. search for compatible entry; 2. insert new entry - // in case of a cache miss. - InputsRequirement( - const at::ArrayRef& inputs, - const std::vector& reduction_axes); - - bool complyWith(const InputsRequirement& expect); + //! Computation graph; + std::shared_ptr graph_; + //! TODO: poor name, we should use `eliminated_axes_` instead; + at::DimVector reduction_axes_; - // helper function used at run-time to check whether a common permutation is - // present, this is used to take the short-cut to skip permutation logic. - bool requiresPermutation(); + //! helper function used at run-time to check whether a common permutation is + //! present, this is used to take the short-cut to skip permutation logic. + bool requiresPermutation(); - // extract permutation for input output tensor from accumulcated tensor type - // pointer on all inputs; - void extractPermutation( - const TensorTypePtr& acc_type, - const std::vector& reduction_axes); - }; + //! construct FusionExecutorCache + void createFusion(const std::shared_ptr& graph); - // construct FusionExecutorCache per InputsRequirement. - // This function makes sure that we properly insert both `input_stacks_` and - // `fe_cache_` at the same time. - FusionExecutorCache* appendFusionExecutorCache( - const InputsRequirement& input_stack); + //! extract permutation for I/O tensor from accumulcated tensor type pointer + //! on all inputs; + void extractPermutation(const TensorTypePtr& acc_type); private: - // Computation graph; - std::shared_ptr graph_; - // TODO: poor name, we should use `eliminated_axes_` instead; - at::DimVector reduction_axes_; - - // short cut to index of stack for input set encoded with id; - std::unordered_map code_to_index_lookup_; + // common permutation order used to facilitate dimension coalescing; + at::DimVector input_permutation_; + at::DimVector pw_output_permutation_; + at::DimVector reduction_output_permutation_; - // TODO: we should really hash instead of iterative check. Optimize later... - // unordered_map; - std::vector input_stacks_; - std::vector> fe_cache_; - - // inputs to unique_id lookup table; - InputsIdLookup inputs_id_lookup_; + //! FusionExecutorCache that performs schedule and kernel execution; + std::unique_ptr fusion_executor_cache_; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index 7941f369d4ff8..50d5c0caf05da 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -1,4 +1,3 @@ - #include #include #include @@ -8,6 +7,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace kir { NamedScalar* NamedScalar::getParallelDim(ParallelType p_type) { @@ -61,7 +61,7 @@ IterDomain::IterDomain(Passkey, Val* start, Val* extent) start_(start), extent_(extent) {} -IterDomain::IterDomain(Passkey, const fuser::IterDomain* iter_domain) +IterDomain::IterDomain(Passkey, const fuser::cuda::IterDomain* iter_domain) : Val(iter_domain), start_(GpuLower::lowerValue(iter_domain->start())), extent_(GpuLower::lowerValue(iter_domain->rawExtent())), @@ -88,10 +88,12 @@ TensorDomain::TensorDomain(Passkey, std::vector domain) resetDomains(); } -TensorDomain::TensorDomain(Passkey, const fuser::TensorDomain* tensor_domain) +TensorDomain::TensorDomain( + Passkey, + const fuser::cuda::TensorDomain* tensor_domain) : Val(tensor_domain), contiguity_(tensor_domain->contiguity()) { const auto lowerIterDomains = - [](const std::vector& domains) { + [](const std::vector& domains) { std::vector lowered_domains; lowered_domains.reserve(domains.size()); for (const auto iter_domain : domains) { @@ -165,7 +167,7 @@ std::vector TensorDomain::noBroadcasts( return no_broadcast_domains; } -TensorView::TensorView(Passkey, const fuser::TensorView* tv) +TensorView::TensorView(Passkey, const fuser::cuda::TensorView* tv) : Val(tv), fuser_tv_(tv) { domain_ = GpuLower::lowerValue(tv->domain())->as(); memory_type_ = tv->getMemoryType(); @@ -265,7 +267,7 @@ BroadcastOp::BroadcastOp(Passkey, Val* out, Val* in) TensorIndex::TensorIndex( Passkey, - const fuser::TensorView* view, + const fuser::cuda::TensorView* view, std::vector indices) : Val(ValType::TensorIndex, view->getDataType().value(), true, true), view_(GpuLower::lowerValue(view)->as()), @@ -446,13 +448,15 @@ std::string GridReduction::getPredicateFlagName(const TensorView* val) { } // TODO(kir): remove this -std::string GridReduction::getPredicateFlagName(const fuser::TensorView* val) { +std::string GridReduction::getPredicateFlagName( + const fuser::cuda::TensorView* val) { std::stringstream ss; ss << "T" << val->name() << "_pred"; return ss.str(); } } // namespace kir +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index e51bde37d285c..da49d4369324c 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -1,4 +1,3 @@ - #pragma once #include @@ -19,6 +18,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace kir { class IrBuilder; @@ -39,7 +39,7 @@ class TORCH_CUDA_API NamedScalar : public Val { NamedScalar(Passkey, std::string name, DataType dtype) : Val(ValType::KirNamedScalar, dtype, true, true), name_(name) {} - explicit NamedScalar(Passkey, const fuser::NamedScalar* node) + explicit NamedScalar(Passkey, const fuser::cuda::NamedScalar* node) : Val(node), name_(node->name()) {} const std::string& name() const { @@ -70,7 +70,7 @@ class TORCH_CUDA_API Bool : public Val { : Val(ValType::KirScalar, DataType::Bool, true, true), maybe_value_(value) {} - explicit Bool(Passkey, const fuser::Bool* node) + explicit Bool(Passkey, const fuser::cuda::Bool* node) : Val(node), maybe_value_(node->value()) {} bool isSymbolic() const { @@ -95,7 +95,7 @@ class TORCH_CUDA_API Float : public Val { : Val(ValType::KirScalar, DataType::Float, true, true), maybe_value_(value) {} - explicit Float(Passkey, const fuser::Float* node) + explicit Float(Passkey, const fuser::cuda::Float* node) : Val(node), maybe_value_(node->value()) {} bool isSymbolic() const { @@ -118,7 +118,7 @@ class TORCH_CUDA_API Half : public Val { : Val(ValType::KirScalar, DataType::Half, true, true), maybe_value_(value) {} - explicit Half(Passkey, const fuser::Half* node) + explicit Half(Passkey, const fuser::cuda::Half* node) : Val(node), maybe_value_(node->value()) {} bool isSymbolic() const { @@ -143,7 +143,10 @@ class TORCH_CUDA_API Int : public Val { : Val(ValType::KirScalar, DataType::Int, true, true), maybe_value_(value) {} - explicit Int(Passkey, const fuser::Int* node, bool /*avoid_zero_ambiguity*/) + explicit Int( + Passkey, + const fuser::cuda::Int* node, + bool /*avoid_zero_ambiguity*/) : Val(node), maybe_value_(node->value()) {} bool isSymbolic() const { @@ -164,7 +167,7 @@ class TORCH_CUDA_API IterDomain : public Val { public: IterDomain(Passkey, Val* start, Val* extent); - explicit IterDomain(Passkey, const fuser::IterDomain* iter_domain); + explicit IterDomain(Passkey, const fuser::cuda::IterDomain* iter_domain); bool isReduction() const { return getIterType() == IterType::Reduction; @@ -218,6 +221,10 @@ class TORCH_CUDA_API IterDomain : public Val { Val* extent() const; + Val* rawExtent() const { + return extent_; + } + private: Val* const start_ = nullptr; Val* const extent_ = nullptr; @@ -230,7 +237,9 @@ class TORCH_CUDA_API TensorDomain : public Val { public: explicit TensorDomain(Passkey, std::vector domain); - explicit TensorDomain(Passkey, const fuser::TensorDomain* tensor_domain); + explicit TensorDomain( + Passkey, + const fuser::cuda::TensorDomain* tensor_domain); std::vector::size_type nDims() const { return domain_.size(); @@ -297,7 +306,7 @@ class TORCH_CUDA_API TensorDomain : public Val { class TORCH_CUDA_API TensorView : public Val { public: - explicit TensorView(Passkey, const fuser::TensorView* tv); + explicit TensorView(Passkey, const fuser::cuda::TensorView* tv); TensorDomain* domain() const { return domain_; @@ -307,7 +316,7 @@ class TORCH_CUDA_API TensorView : public Val { return memory_type_; } - const fuser::TensorView* fuserTv() const { + const fuser::cuda::TensorView* fuserTv() const { TORCH_INTERNAL_ASSERT(fuser_tv_ != nullptr); return fuser_tv_; } @@ -317,7 +326,7 @@ class TORCH_CUDA_API TensorView : public Val { MemoryType memory_type_ = MemoryType::Local; // TODO(kir): remove temporary hack - const fuser::TensorView* fuser_tv_ = nullptr; + const fuser::cuda::TensorView* fuser_tv_ = nullptr; }; class TORCH_CUDA_API UnaryOp : public Expr { @@ -455,7 +464,7 @@ class TORCH_CUDA_API TensorIndex : public Val { public: TensorIndex( Passkey, - const fuser::TensorView* view, + const fuser::cuda::TensorView* view, std::vector indices); std::vector::size_type nDims() const { @@ -530,11 +539,24 @@ class TORCH_CUDA_API Allocate : public Expr { return buffer_->getDataType().value(); } + Allocate* alias() const { + return alias_; + } + + void setAlias(Allocate* alias) { + TORCH_INTERNAL_ASSERT(alias->getMemoryType() == memory_type_); + alias_ = alias; + } + private: Val* buffer_ = nullptr; MemoryType memory_type_ = MemoryType::Local; Val* size_ = nullptr; bool zero_init_ = false; + + // This alias tracks the next Allocate node in a linked chain of aliases + // If the alias is nullptr, then the Allocate node uses memory in the kernel + Allocate* alias_ = nullptr; }; // Sync represents __syncthreads barrier for block level coordination. @@ -723,7 +745,7 @@ class TORCH_CUDA_API GridReduction : public Expr { } static std::string getPredicateFlagName(const TensorView* val); - static std::string getPredicateFlagName(const fuser::TensorView* val); + static std::string getPredicateFlagName(const fuser::cuda::TensorView* val); private: ReductionOp* reduction_op_ = nullptr; @@ -733,6 +755,7 @@ class TORCH_CUDA_API GridReduction : public Expr { }; } // namespace kir +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_builder.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir_builder.cpp index 84fb818891f6e..afd6e2a4919c6 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir_builder.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir_builder.cpp @@ -1,9 +1,9 @@ - #include namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace kir { bool isLoweredScalar(const Val* val) { @@ -99,6 +99,7 @@ Val* IrBuilder::modExpr(Val* lhs, Val* rhs) { } } // namespace kir +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_builder.h b/torch/csrc/jit/codegen/cuda/kernel_ir_builder.h index bed780edcc65c..0af37c8c410bd 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir_builder.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir_builder.h @@ -1,4 +1,3 @@ - #pragma once #include @@ -10,6 +9,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace kir { // Simple classification helpers @@ -76,6 +76,7 @@ class IrBuilder { }; } // namespace kir +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_printer.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir_printer.cpp new file mode 100644 index 0000000000000..4a20543e53f2d --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/kernel_ir_printer.cpp @@ -0,0 +1,271 @@ +#include + +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { +namespace kir { + +static std::string boolLiteral(bool value) { + return value ? "true" : "false"; +} + +void IrPrinter::printNode(const Statement* stmt) { + handle(stmt); +} + +void IrPrinter::printKernel(const Kernel* kernel) { + TORCH_CHECK(kernel != nullptr); + + // kernel declaration + os_ << "\nKERNEL ("; + for (auto in : kernel->inputs()) { + os_ << gen(in); + if (in != kernel->inputs().back()) { + os_ << ", "; + } + } + os_ << ") -> ("; + for (auto out : kernel->outputs()) { + os_ << gen(out); + if (out != kernel->outputs().back()) { + os_ << ", "; + } + } + os_ << ") :\n"; + + // kernel body + startBlock(); + for (auto expr : kernel->topLevelExprs()) { + handle(expr); + } + endBlock(); + os_ << "END.\n\n"; +} + +std::ostream& IrPrinter::indent() { + for (int i = 0; i < indent_level_; ++i) { + os_ << kTab; + } + return os_; +} + +std::string IrPrinter::gen(const Statement* stmt) { + std::stringstream ss; + IrPrinter ir_printer(ss); + ir_printer.handle(stmt); + return ss.str(); +} + +void IrPrinter::startBlock() { + ++indent_level_; +} + +void IrPrinter::endBlock() { + TORCH_CHECK(indent_level_ > 0); + --indent_level_; +} + +void IrPrinter::handleBlock(const kir::Scope& scope) { + startBlock(); + for (auto expr : scope.exprs()) { + handle(expr); + } + endBlock(); +} + +void IrPrinter::handle(const Statement* s) { + OptInConstDispatch::handle(s); +} + +void IrPrinter::handle(const Val* v) { + OptInConstDispatch::handle(v); +} + +void IrPrinter::handle(const Expr* e) { + OptInConstDispatch::handle(e); +} + +void IrPrinter::handle(const kir::Bool* node) { + if (node->isSymbolic()) { + os_ << "b" << node->name(); + } else { + os_ << boolLiteral(*node->value()); + } +} + +void IrPrinter::handle(const kir::Float* node) { + if (node->isSymbolic()) { + os_ << "f" << node->name(); + } else { + const int digits = std::numeric_limits::max_digits10; + os_ << "float(" << std::setprecision(digits) << *node->value() << ")"; + } +} + +void IrPrinter::handle(const kir::Half* node) { + if (node->isSymbolic()) { + os_ << "h" << node->name(); + } else { + os_ << "half(" << *node->value() << ")"; + } +} + +void IrPrinter::handle(const kir::Int* node) { + if (node->isSymbolic()) { + os_ << "i" << node->name(); + } else { + os_ << *node->value(); + } +} + +void IrPrinter::handle(const kir::NamedScalar* node) { + os_ << node->name(); +} + +void IrPrinter::handle(const kir::TensorIndex* node) { + os_ << gen(node->view()) << "["; + for (auto index : node->indices()) { + os_ << gen(index); + if (index != node->indices().back()) { + os_ << ", "; + } + } + os_ << "]"; +} + +void IrPrinter::handle(const kir::IterDomain* node) { + if (node->isRFactorProduct()) { + os_ << "rfactor."; + } + os_ << node->getParallelType() << "." << node->getIterType() << "(" + << gen(node->start()) << " .. " << gen(node->rawExtent()) << ")"; +} + +void IrPrinter::handle(const kir::TensorDomain*) { + // TODO(kir): print Tensor shapes? + os_ << "kir::TensorDomain"; +} + +void IrPrinter::handle(const kir::TensorView* node) { + // TODO(KIR): print memory type too? + os_ << "T" << node->name(); +} + +void IrPrinter::handle(const kir::UnaryOp* node) { + indent() << gen(node->out()) << " = "; + + if (auto op = inline_op_str(node->getUnaryOpType())) { + os_ << *op << gen(node->in()); + } else { + if (node->getUnaryOpType() == UnaryOpType::Cast) { + const auto cast_str = cast_func_str( + {node->in()->getDataType().value(), + node->out()->getDataType().value()}); + os_ << cast_str.value(); + } else { + os_ << node->getUnaryOpType(); + } + + os_ << "("; + if (node->getUnaryOpType() == UnaryOpType::RandLike) { + os_ << "RND"; + } else { + os_ << gen(node->in()); + } + os_ << ")"; + } + + os_ << "\n"; +} + +void IrPrinter::handle(const kir::BinaryOp* node) { + indent() << gen(node->out()) << " = "; + + const auto op_type = node->getBinaryOpType(); + const auto lhs = gen(node->lhs()); + const auto rhs = gen(node->rhs()); + + if (auto op = inline_op_str(op_type)) { + os_ << lhs << " " << *op << " " << rhs; + } else { + os_ << op_type << "(" << lhs << ", " << rhs << ")"; + } + + os_ << "\n"; +} + +void IrPrinter::handle(const kir::TernaryOp* node) { + indent() << gen(node->out()) << " = " << node->getTernaryOpType() << "(" + << gen(node->in1()) << ", " << gen(node->in2()) << ", " + << gen(node->in3()) << ")\n"; +} + +void IrPrinter::handle(const kir::ReductionOp* node) { + indent() << gen(node->out()) << " = " + << "REDUCTION(op='" << node->getReductionOpType() << "'" + << ", in=" << gen(node->in()) << ", init=" << gen(node->init()) + << ", pred=" << gen(node->pred()) << ")\n"; +} + +void IrPrinter::handle(const kir::GridReduction* node) { + const auto* reduction_op = node->reduction_op(); + indent() << gen(reduction_op->out()) << " = " + << "GRID_REDUCTION(op='" << reduction_op->getReductionOpType() << "'" + << ", in=" << gen(reduction_op->in()) + << ", init=" << gen(reduction_op->init()) + << ", pred=" << gen(reduction_op->pred()) << ")\n"; + indent() << kTab << ".reduction_buffer=" << gen(node->reduction_buffer()) + << "\n"; + indent() << kTab << ".sync_buffer=" << gen(node->sync_buffer()) << "\n"; + indent() << kTab << ".grid_pred=" << gen(node->pred()) << "\n"; +} + +void IrPrinter::handle(const kir::BroadcastOp* node) { + indent() << gen(node->out()) << " = BROADCAST(" << gen(node->in()) << ")\n"; +} + +void IrPrinter::handle(const kir::ForLoop* node) { + indent() << "FOR " << gen(node->index()) << " in " << gen(node->iter_domain()) + << ":\n"; + handleBlock(node->body()); +} + +void IrPrinter::handle(const kir::IfThenElse* node) { + indent() << "IF " << gen(node->cond()) << ":\n"; + handleBlock(node->thenBody()); + if (node->hasElse()) { + indent() << "ELSE:\n"; + handleBlock(node->elseBody()); + } +} + +void IrPrinter::handle(const kir::Allocate* node) { + indent() << gen(node->buffer()) << " = ALLOCATE(" + << "mem_type=" << node->getMemoryType() << ", " + << "size=" << gen(node->size()) << ", " + << "zero_init=" << boolLiteral(node->zeroInit()) << ")\n"; +} + +void IrPrinter::handle(const kir::Sync* node) { + indent() << "SYNC(war_hazard=" << boolLiteral(node->isWarHazardSync()) + << ")\n"; +} + +std::string toString(const Statement* stmt) { + std::stringstream ss; + IrPrinter ir_printer(ss); + ir_printer.printNode(stmt); + return ss.str(); +} + +} // namespace kir +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_printer.h b/torch/csrc/jit/codegen/cuda/kernel_ir_printer.h new file mode 100644 index 0000000000000..4cabd4beda789 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/kernel_ir_printer.h @@ -0,0 +1,84 @@ +#pragma once + +#include + +#include +#include +#include + +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { +namespace kir { + +//! Define pretty printing functions for Kernel IR nodes +//! +//! This class is intended for debug printing, so it attempts +//! to handle invalid IR states as much as possible. +//! +class TORCH_CUDA_API IrPrinter : private OptInConstDispatch { + static constexpr char* kTab = " "; + + public: + //! Constructs a new IrPrinter which outputs to the specified stream + explicit IrPrinter(std::ostream& os) : os_(os) {} + + //! Print a single Kernel IR node + void printNode(const Statement* stmt); + + //! Print a complete Kernel definition + void printKernel(const Kernel* kernel); + + private: + static std::string gen(const Statement* stmt); + + std::ostream& indent(); + + void startBlock(); + void endBlock(); + void handleBlock(const kir::Scope& scope); + + void handle(const Statement*) final; + void handle(const Val*) final; + void handle(const Expr*) final; + + void handle(const kir::Bool*) final; + void handle(const kir::Float*) final; + void handle(const kir::Half*) final; + void handle(const kir::Int*) final; + void handle(const kir::NamedScalar*) final; + + void handle(const kir::TensorIndex*) final; + void handle(const kir::IterDomain*) final; + void handle(const kir::TensorDomain*) final; + void handle(const kir::TensorView*) final; + + void handle(const kir::UnaryOp*) final; + void handle(const kir::BinaryOp*) final; + void handle(const kir::TernaryOp*) final; + void handle(const kir::ReductionOp*) final; + void handle(const kir::BroadcastOp*) final; + + void handle(const kir::GridReduction*) final; + void handle(const kir::ForLoop*) final; + void handle(const kir::IfThenElse*) final; + void handle(const kir::Allocate*) final; + void handle(const kir::Sync*) final; + + private: + std::ostream& os_; + int indent_level_ = 0; +}; + +//! Returns the string representation of a Kernel IR node +std::string toString(const Statement* stmt); + +} // namespace kir +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h b/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h deleted file mode 100644 index d30eb3fcda522..0000000000000 --- a/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h +++ /dev/null @@ -1,659 +0,0 @@ -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { - -// IO data structure for kernel code; -static auto code_template_tensor_struct = R"( -typedef unsigned char uint8_t; -typedef signed char int8_t; -typedef short int int16_t; -typedef long long int int64_t; - -template -struct Tensor { - T& operator[](int64_t ind) { - return data[ind]; - }; - - T* data; - int64_t size[N]; - int64_t stride[N]; -}; - -// Specialization for 0-dim case as it does not need size and stride arrays. -// They will be an error as well since zero-length arrays are not allowed. -template -struct Tensor { - T& operator[](int64_t) { - return *data; - }; - - T* data; -}; -)"; - -// Code support for FP16 __half type and intrinsics -static auto code_fp16_support = R"( -#define __HALF_TO_US(var) *(reinterpret_cast(&(var))) -#define __HALF_TO_CUS(var) *(reinterpret_cast(&(var))) -struct __align__(2) __half { - __host__ __device__ __half() { } -protected: - unsigned short __x; -}; - -/* Definitions of intrinsics */ -__device__ __half __float2half(const float f) { - __half val; - asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(__HALF_TO_US(val)) : "f"(f)); - return val; -} -__device__ float __half2float(const __half h) { - float val; - asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__HALF_TO_CUS(h))); - return val; -} -)"; - -// struct and code for functions that need random number generation -static auto code_random_number_gen = R"( -class Philox { -public: - __device__ inline Philox(unsigned long long seed, - unsigned long long subsequence, - unsigned long long offset) { - key.x = (unsigned int)seed; - key.y = (unsigned int)(seed >> 32); - counter = make_uint4(0, 0, 0, 0); - counter.z = (unsigned int)(subsequence); - counter.w = (unsigned int)(subsequence >> 32); - STATE = 0; - incr_n(offset / 4); - } - __device__ inline unsigned long operator()() { - if(STATE == 0) { - uint4 counter_ = counter; - uint2 key_ = key; - for(int i = 0; i < 9; i++) { - counter_ = single_round(counter_, key_); - key_.x += (kPhilox10A); key_.y += (kPhilox10B); - } - output = single_round(counter_, key_); - incr(); - } - unsigned long ret; - switch(STATE) { - case 0: ret = output.x; break; - case 1: ret = output.y; break; - case 2: ret = output.z; break; - case 3: ret = output.w; break; - } - STATE = (STATE + 1) % 4; - return ret; - } -private: - uint4 counter; - uint4 output; - uint2 key; - unsigned int STATE; - __device__ inline void incr_n(unsigned long long n) { - unsigned int nlo = (unsigned int)(n); - unsigned int nhi = (unsigned int)(n >> 32); - counter.x += nlo; - if (counter.x < nlo) - nhi++; - counter.y += nhi; - if (nhi <= counter.y) - return; - if (++counter.z) - return; - ++counter.w; - } - __device__ inline void incr() { - if (++counter.x) - return; - if (++counter.y) - return; - if (++counter.z) - return; - ++counter.w; - } - __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, - unsigned int *result_high) { - *result_high = __umulhi(a, b); - return a*b; - } - __device__ inline uint4 single_round(uint4 ctr, uint2 key) { - unsigned int hi0; - unsigned int hi1; - unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); - unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1); - uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; - return ret; - } - static const unsigned long kPhilox10A = 0x9E3779B9; - static const unsigned long kPhilox10B = 0xBB67AE85; - static const unsigned long kPhiloxSA = 0xD2511F53; - static const unsigned long kPhiloxSB = 0xCD9E8D57; -}; -// Inverse of 2^32. -#define M_RAN_INVM32 2.3283064e-10f -__device__ __inline__ float uniform(unsigned int x) { - return x * M_RAN_INVM32; -} -)"; - -// Helper functions for Operations -static auto code_helper_funcs = R"( -__device__ constexpr int ceilDiv(const int a, const int b) { - return (a + b - 1) / b; -} -__device__ constexpr int alignBufferSize(const int buffer, const int size) { - return (buffer + (size-1)) & ~(size-1); -} -__device__ float clamp(const float x, const float minv, const float maxv) { - return x < minv ? minv : (x > maxv ? maxv : x); -} -__device__ float frac(const float x) { - return x - truncf(x); -} -__device__ float gelu(const float x) { - return x * normcdf(x); -} -__device__ float reciprocal(const float x) { - return 1.f / x; -} -__device__ float relu(const float x) { - return x <= 0.f ? 0.f : x; -} -__device__ float remainder(const float a, const float b) { - return a - b * floorf(a / b); -} -__device__ float sigmoid(const float x) { - return 1.f / (1.f + expf(-x)); -} -__device__ float threshold(const float x, const float t, const float v) { - return x <= t ? v : x; -} -__device__ float where(const bool c, const float a, const float b) { - return c ? a : b; -} -__device__ float randLike(Philox rnd) { - return uniform(rnd()); -}; -)"; - -/* - * EXAMPLE USAGE: - * blockReduceSum - * (output[output_index], inputs[input_index], [] __device__ (T& a, const T - * b) { a += b; } ); - */ -static auto code_template_block_reduction = R"( -// [Z,Y,X]_THREADS is the number of participating threads in the z, y, x -// dimension of the block. If set to 0 it means that dimension doesn't -// participate, otherwise it is the number of threads. We could start with warp -// reductions, then reduce the warps, this could save some shared memory, but -// may actually be slower. -template -__inline__ __device__ -void blockReduce( - T& out, - const T inp_val, - Func reduction_op, - const dim3& thread_idx, - const dim3& block_dim, - T* shared_mem, - bool read_write_pred, - T init_val) { - - unsigned int reduction_size - = (X_REDUCE ? block_dim.x : 1) - * (Y_REDUCE ? block_dim.y : 1) - * (Z_REDUCE ? block_dim.z : 1); - - // If this thread will output a final result - bool should_write = true; - - if (X_REDUCE) - should_write = should_write && thread_idx.x == 0; - if (Y_REDUCE) - should_write = should_write && thread_idx.y == 0; - if (Z_REDUCE) - should_write = should_write && thread_idx.z == 0; - - unsigned int reduction_stride; - unsigned int reduction_tid; - unsigned int linear_tid; - - if(X_REDUCE && !Y_REDUCE && Z_REDUCE){ - // Transpose Z and Y in the shared memory so Z and X dims are contiguous in smem - reduction_stride = 1; - linear_tid = threadIdx.y * blockDim.z * blockDim.x + threadIdx.z * blockDim.x + threadIdx.x; - reduction_tid = threadIdx.z * blockDim.x + threadIdx.x; - } else { - // Normal reduction in order - reduction_stride - = (X_REDUCE ? 1 - : (Y_REDUCE ? block_dim.x - : (Z_REDUCE ? block_dim.x * block_dim.y : 0))); - - linear_tid = thread_idx.z * block_dim.y * block_dim.x + thread_idx.y * block_dim.x + thread_idx.x; - - reduction_tid - = ( Z_REDUCE ? thread_idx.z : 0 ) * ( Y_REDUCE ? block_dim.y : 1 ) * ( X_REDUCE ? block_dim.x : 1 ) - + ( Y_REDUCE ? thread_idx.y : 0 ) * ( X_REDUCE ? block_dim.x : 1 ) - + ( X_REDUCE ? thread_idx.x : 0 ); - } - - assert( reduction_stride != 0 ); - - if(read_write_pred){ - shared_mem[linear_tid] = inp_val; - } else { - shared_mem[linear_tid] = init_val; - } - __syncthreads(); - // Reduce down to nearest power of 2: - int np2 = 1 << (31 - __clz(reduction_size)); - - if( reduction_tid < np2 ){ - if( reduction_tid + np2 < reduction_size){ - reduction_op( shared_mem[linear_tid], shared_mem[linear_tid + np2 * reduction_stride] ); - } - } - __syncthreads(); - //for (int factor = np2/2; factor > contig_threads / 2; factor>>=1) { - for (int factor = np2/2; factor > 0; factor>>=1) { - if (reduction_tid < factor) { - reduction_op( shared_mem[linear_tid], shared_mem[linear_tid + factor * reduction_stride] ); - } - __syncthreads(); - } - - if(should_write && read_write_pred) - out = shared_mem[linear_tid]; - -} -)"; - -/** - Inter-block reduction. - - Function gridReduce performs point-wise reductions of scalars across thread - blocks. Thread blocks are disjointly partitioned into groups of thread blocks, - "reduction segments," that are collectively defined by boolean template - parameters, X_BLOCK, Y_BLOCK and Z_BLOCK. Each of X/Y/Z_BLOCK determines - whether thread blocks along the dimension should be grouped into the same - reduction segment. Cross-block reducitons are independently done within each - segment and generates distinctive results per segment. For instance, if all of - X/Y/Z_BLOCK are true, reductions will be done across all thread blocks since - there will be just a single segment consisting of all thread blocks. If none - of them are true, each thread block will become a segment by itself, so no - reduction will be performed. - - The input scalars to reduce within each segment are a certain subset of - thread-private scalars provided as part of the gridReduce function parameters. - Boolean template parameters, X_THREAD, Y_THREAD and Z_THREAD, determine which - subset of the scalars should be used for inter-block reductions. Specifically, - all the input scalars of threads along each dimension will be used when - X/Y/Z_THREAD are true. Otherwise, only the value held at offset 0 of each - dimension will be used. Thus, for example, if all of X/Y/Z_THREAD are true, - the scalars of all threads in each block will participate in inter-block - reductions. If all of them are false, only one scalar of the thread at - threadIdx.x == threadIdx.y == threadIdx.z == 0 will be used. In the code - below, we call the subset of threads a "reduction block." - - Inter-block reductions perform point-wise reductions of scalars of reduction - blocks within each reduction segment. More specifically, let rb be a reduction - block and rs be a reduction segment. Let IN(thread_idx, block_idx) denote the - input scalar of thread at thread_idx and block_idx. The result of each - reduction segment, OUT(thread_idx, block_idx_out), is defined only for each - thread_idx in thread block block_idx_out in the segment as follows: - - OUT(thread_idx, block_idx_out) = Reduction of IN(thread_idx, block_idx) for - all block_idx in a reduction segment - - OUT is not given for all threads that are not in block_idx_out and the - reduction block. - - See also the function comment of gridReduce. -*/ -static auto code_template_grid_reduction = R"( -namespace reduction { - -// Utility functions -__host__ __device__ __forceinline__ size_t size(const dim3& d) { - return (size_t)d.x * (size_t)d.y * (size_t)d.z; -} - -__host__ __device__ __forceinline__ int isize(const dim3& d) { - return d.x * d.y * d.z; -} - -__host__ __device__ __forceinline__ size_t offset(const dim3& pos, const dim3& dim) { - return (size_t)pos.x + (size_t)pos.y * (size_t)dim.x + - (size_t)pos.z * (size_t)dim.x * (size_t)dim.y; -} - -__host__ __device__ __forceinline__ size_t ioffset(const dim3& pos, const dim3& dim) { - return pos.x + pos.y * dim.x + pos.z * dim.x * dim.y; -} - -// Returns dim3 of each reduction segment. -template -__host__ __device__ dim3 dimension_of_reduction_segment(const dim3& grid_dim) { - return dim3{X_BLOCK ? grid_dim.x : 1, - Y_BLOCK ? grid_dim.y : 1, - Z_BLOCK ? grid_dim.z : 1}; -} - -// Returns the number of blocks in each reduction segment. -template -__host__ __device__ size_t size_of_reduction_segment(const dim3& grid_dim) { - return size(dimension_of_reduction_segment(grid_dim)); -} - -// Returns the total number of reduction segments. -template -__host__ __device__ size_t number_of_reduction_segments(const dim3& grid_dim) { - return (X_BLOCK ? 1: grid_dim.x) * - (Y_BLOCK ? 1 : grid_dim.y) * - (Z_BLOCK ? 1 : grid_dim.z); -} - -// Returns the 1-D index of the segment of thread block of block_idx. -template -__host__ __device__ size_t index_of_reduction_segment(const dim3& block_idx, - const dim3& grid_dim) { - size_t seg_idx = 0; - if (!Z_BLOCK) - seg_idx += block_idx.z; - if (!Y_BLOCK) - seg_idx = seg_idx * grid_dim.y + block_idx.y; - if (!X_BLOCK) - seg_idx = seg_idx * grid_dim.x + block_idx.x; - return seg_idx; -} - -// Returns the offset of thread block in its reduction segment. -template -__host__ __device__ size_t offset_in_reduction_segment(const dim3& block_idx, - const dim3& grid_dim) { - size_t offset = 0; - if (Z_BLOCK) - offset = offset * grid_dim.z + block_idx.z; - if (Y_BLOCK) - offset = offset * grid_dim.y + block_idx.y; - if (X_BLOCK) - offset = offset * grid_dim.x + block_idx.x; - return offset; -} - -// Returns dim3 of each reduction block. -template -__host__ __device__ dim3 dimension_of_reduction_block(const dim3& block_dim) { - return dim3{X_THREAD ? block_dim.x : 1, - Y_THREAD ? block_dim.y : 1, - Z_THREAD ? block_dim.z : 1}; -} - -// Returns the number of threads of each reduction block. -template -__host__ __device__ int size_of_reduction_block(const dim3& block_dim) { - return isize(dimension_of_reduction_block(block_dim)); -} - -// Returns the linear offset of a thread in a reduction block. -template -__host__ __device__ int offset_in_reduction_block(const dim3& thread_idx, - const dim3& block_dim) { - int offset = 0; - if (Z_THREAD) - offset += thread_idx.z; - if (Y_THREAD) - offset = offset * block_dim.y + thread_idx.y; - if (X_THREAD) - offset = offset * block_dim.x + thread_idx.x; - return offset; -} - -/** Reduces all the reduction blocks in each reduction segment. - - This is only used by one thread block per reduction segment. The input - reduction blocks of the segment are stored in an intermediate buffer pointed - by parameter in. Template parameters X/Y/Z_THREAD denote how the reduction - block is formed. - - The size of a reduction block is by definition smaller or equal to the size of - a thread block. We use the remaining threads to parallelize reductions across - reduction blocks. For example, when X/Y/Z_THREAD = {true, false, false}, we - use blockDim.y*blockDim.z threads for each output value. This is done first by - loading the input values in parallel and then by reducing across threads of - dimensions whose XYZ_THREAD are false. - - Note that what is done here after the loading from global memory is similar to - what the existing blockReduce function does. The main difference is that the - logical block to reduce is a 2D domain where the leading dimension is the size - of a reduction block and the second dimension is the remaining factor in each - thread block. For example, when X/Y/Z_THREAD = {false, true, false}, the - threads are arranged as (blockDim.y, blockDim.x*blockDim.z). We do not reduce - along the first dimension but only the second dimension. So, it is possible to - reuse the existing blockReduce with dim3{blockDim.y, blockDim.x*blockDim.z} - instead of blockDim and with X_THREAD and Y_THREAD being false and true, - respectively. Also, it still need to shuffle the final output values to their - actual corresponding threads. In the case of when X/Y/Z_THREAD = {false, true, - false}, after the intra-block reduction, the final results will still be held - by the first blockDim.y threads, which need to be transferred to threads at - threadIdx.x == 0 and threadIdx.z == 0. -*/ -template -__device__ void gridReduceLastBlock( - T& out, - const T *in, - const size_t in_size, - Func reduction_op, - T* shared_buf, - bool read_write_pred, - T init_val) { - - const int tid = ioffset(threadIdx, blockDim); - const int block_size = isize(blockDim); - const int rblock_size = size_of_reduction_block(blockDim); - - T inp = init_val; - if (tid < in_size) { - inp = in[tid]; - } - for (size_t i = tid + block_size; i < in_size; i += block_size) { - reduction_op(inp, in[i]); - } - - const auto should_write = (X_THREAD || threadIdx.x == 0) && - (Y_THREAD || threadIdx.y == 0) && - (Z_THREAD || threadIdx.z == 0); - - auto rem_size = block_size / rblock_size; - - if (rem_size > 1) { - const int rblock_offset = tid % rblock_size; - const int rblock_idx = tid / rblock_size; - blockReduce( - inp, inp, reduction_op, - dim3{(unsigned)rblock_offset, (unsigned)rblock_idx, 0}, - dim3{(unsigned)rblock_size, (unsigned)rem_size}, - shared_buf, true, init_val); - __syncthreads(); - if (tid < rblock_size) { - shared_buf[tid] = inp; - } - __syncthreads(); - if (should_write) { - inp = shared_buf[offset_in_reduction_block( - threadIdx, blockDim)]; - } - } - - if (should_write && read_write_pred) { - out = inp; - } -} - -/** Reduces per-thread values across thread blocks. - -Function parameters: -- out: Per-thread output location -- inp_val: Per-thread input value -- reduction_op: Scalar reduction function -- work_buf: Temporary buffer for cross-block reductions -- sync_flags: A vector of integers for synchronizations -- shared_buf: Shared memory buffer for intra-block reduction - -Return true when the thread block has the valid result. - -Template parameters: -- X/Y/Z_BLOCK: When true, reduces across thread blocks along the X/Y/Z - dimensions -- X/Y/Z_THREAD: When true, all threads along the X/Y/Z dimensions participate in - the cross-block reduction. Otherwise, only threads at offset 0 do. -- T: Scalar data type of input/output data -- Func: Type of scalara reduction function - -Template parameters X/Y/Z_BLOCK define a group of thread blocks that are reduced together. We call -it a reduction segment. Some examples are: - -Case 1: X/Y/Z_BLOCK == true/true/true -> There is only one segment, which includes all - thread blocks. It is effecively the same as the grid. -Case 2: X/Y/Z_BLOCK == false/false/false -> Each thread block comprises an individual - segment by itself. -Case 3: X/Y/Z_BLOCK == true/false/false -> Each segment contains thread blocks that have - the same blockDim.x. There will be blockDim.y*blockDim.z such segments. - -X/Y/Z_THREAD defines a sub region of a thread block that should be reduced with -the sub regions of other thread blocks. We call it a reduction block. E.g., - -Case 1: X/Y/Z_THREAD == false/false/false -> Only thread 0 participates in the - cross-block reductions. The reduction block is 1x1x1 with thread 0. -Case 2: X/Y/Z_THREAD == true/true/true-> All threads in a thread block participate in - the cross-block reductions. The reduction block in this case is equivalent to - the thread block. - -After the function completes, only one thread block per reduction segment gets -valid reduction results. There is no guarantee which particular block gets the -final results. -*/ -template -__device__ bool gridReduce(T& out, T inp_val, Func reduction_op, - volatile T* work_buf, - Tensor sync_flags, - T* shared_buf, bool read_write_pred, T init_val) { - - // Number of values to reduce in the grid dimensions - const auto seg_size = - size_of_reduction_segment(gridDim); - - // Index of the reduction we're performing out of the seg_size - const auto seg_idx = - index_of_reduction_segment(blockIdx, gridDim); - - // Number of threads we can use in final reduction, Seems to assume all threads in the block participate - const auto rblock_size = - size_of_reduction_block(blockDim); - - // advance to the offset for this segment - // index of reduction * size of the reduction * size of threads - work_buf += seg_idx * seg_size * rblock_size; - - if ((X_THREAD || threadIdx.x == 0) && - (Y_THREAD || threadIdx.y == 0) && - (Z_THREAD || threadIdx.z == 0)) { - auto rblock_offset = - offset_in_reduction_segment(blockIdx, gridDim); - auto thread_offset = - offset_in_reduction_block(threadIdx, blockDim); - auto work_buf_offset = rblock_size * rblock_offset + thread_offset; - if(read_write_pred){ - work_buf[work_buf_offset] = inp_val; - } else { - work_buf[work_buf_offset] = init_val; - } - } - __syncthreads(); - - __shared__ bool last_block; - if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { - __threadfence(); - // printf("%ld\n", sync_flags[seg_idx]); - auto old = (int64_t) atomicAdd( (unsigned long long*) &sync_flags[seg_idx], 1); - last_block = old + 1 == seg_size; - // printf("Last_block = %d + 1 == %d\n", (int)old, (int)seg_size); - } - __syncthreads(); - - if (last_block) { - // printf("Last block %d %d %d %d\n", blockIdx.x, blockIdx.y, blockIdx.z); - // final reduction - gridReduceLastBlock( - out, (T*)work_buf, seg_size * rblock_size, - reduction_op, shared_buf, read_write_pred, init_val); - return true; - } else { - // printf("Not last block %d %d %d\n", blockIdx.x, blockIdx.y, blockIdx.z); - return false; - } -} -} // namespace reduction -)"; - -static auto code_template_block_broadcast = R"( -namespace broadcast { - -template -__host__ __device__ unsigned offset_of_source(const dim3& block_dim, const dim3& thread_idx) { - unsigned offset = 0; - if (!Z_THREAD) - offset = offset * block_dim.z + thread_idx.z; - if (!Y_THREAD) - offset = offset * block_dim.y + thread_idx.y; - if (!X_THREAD) - offset = offset * block_dim.x + thread_idx.x; - return offset; -} - -/** Broadcasts within partitioned groups of threads. - - X_THREAD: Broadcast from threadIdx.x == 0 if true - Y_THREAD: Broadcast from threadIdx.y == 0 if true - Z_THREAD: Broadcast from threadIdx.z == 0 if true - inp_val: Per-thread source value. Only valid when the thread is a source. - out: Per-thread output location - */ -template - __device__ void blockBroadcast(T& out, T inp_val, T* shared_mem) { - - const bool has_valid_data = - (!X_THREAD || threadIdx.x == 0) && - (!Y_THREAD || threadIdx.y == 0) && - (!Z_THREAD || threadIdx.z == 0); - - const auto shared_offset = offset_of_source(blockDim, threadIdx); - - if (has_valid_data) - shared_mem[shared_offset] = inp_val; - - __syncthreads(); - - out = shared_mem[shared_offset]; -} - -} // namespace broadcast -)"; - -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 4e9d2ec499bfa..fe36afbce338f 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -1,8 +1,9 @@ - #include + #include #include #include +#include #include #include #include @@ -14,6 +15,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { // TODO(kir): revisit this thread_local GpuLower* active_gpu_lower = nullptr; @@ -113,8 +115,14 @@ void GpuLower::lower() { const auto unrolled_loops = UnrollPass::runPass(fusion_, lowered_exprs, preds); + // Reuse memory locations if: + // TensorView is dynamic shared memory + // TensorViews have the same size + // Output TensorView is modified using Input TensorView + const auto reuse_mem_exprs = reuseMemoryAllocations(fusion_, unrolled_loops); + // Insert SyncThreads at end of for-loop to avoid WAR race condition - const auto sync_exprs = insertThreadSynchronization(fusion_, unrolled_loops); + const auto sync_exprs = insertThreadSynchronization(fusion_, reuse_mem_exprs); const auto indexed_loops = IndexLowering::getIndexedExprs(fusion_, sync_exprs); @@ -171,13 +179,13 @@ class TORCH_CUDA_API GpuLower::KernelIrMapper : private OptInConstDispatch { void lowerDefinition(Val* lowered_value, const Expr* def) { switch (def->type()) { case ExprType::UnaryOp: { - const auto op = def->as(); + const auto op = def->as(); ir_builder_.create( op->getUnaryOpType(), lowered_value, lower(op->in())); break; } case ExprType::BinaryOp: { - const auto op = def->as(); + const auto op = def->as(); ir_builder_.create( op->getBinaryOpType(), lowered_value, @@ -186,7 +194,7 @@ class TORCH_CUDA_API GpuLower::KernelIrMapper : private OptInConstDispatch { break; } case ExprType::TernaryOp: { - const auto op = def->as(); + const auto op = def->as(); ir_builder_.create( op->getTernaryOpType(), lowered_value, @@ -275,6 +283,7 @@ GpuLower* GpuLower::current() { return active_gpu_lower; } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower2device.h b/torch/csrc/jit/codegen/cuda/lower2device.h index 1cc50fa20ab4d..3958a1350cf18 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.h +++ b/torch/csrc/jit/codegen/cuda/lower2device.h @@ -1,4 +1,3 @@ - #pragma once #include @@ -13,6 +12,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { class TORCH_CUDA_API GpuLower { class KernelIrMapper; @@ -61,6 +61,7 @@ class TORCH_CUDA_API GpuLower { Fusion* fusion_ = nullptr; }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp new file mode 100644 index 0000000000000..12c5254b3a525 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp @@ -0,0 +1,269 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +namespace { + +//! Get string representation of Allocate size for symbolic comparison +//! +class SymbolicSizePrinter final : private OptOutConstDispatch { + public: + static std::string print_size(const kir::Allocate* alloc) { + SymbolicSizePrinter printer; + printer.handle(alloc->size()); + return printer.os_.str(); + } + + private: + void handle(const Val* v) final { + OptOutConstDispatch::handle(v); + } + + void handle(const Expr* e) final { + OptOutConstDispatch::handle(e); + } + + void handle(const kir::Int* node) final { + if (auto def = FusionGuard::getCurFusion()->origin(node)) { + os_ << "( "; + handle(def); + os_ << " )"; + return; + } else if (node->isSymbolic()) { + os_ << "i" << node->name(); + } else { + os_ << *node->value(); + } + } + + void handle(const kir::NamedScalar* node) final { + os_ << node->name(); + } + + void handle(const kir::BinaryOp* node) final { + if (auto inline_bop = inline_op_str(node->getBinaryOpType())) { + handle(node->lhs()); + os_ << " " << inline_bop.value() << " "; + handle(node->rhs()); + } else { + os_ << node->getBinaryOpType() << "("; + handle(node->lhs()); + os_ << ", "; + handle(node->rhs()); + os_ << ")"; + } + } + + private: + std::stringstream os_; +}; + +//! Reuse Allocation nodes via pointer aliasing +//! +class AllocateReuseModifier final : private OptOutDispatch { + public: + explicit AllocateReuseModifier(Fusion* fusion, size_t register_size_threshold) + : eval_evaluator_(fusion), + register_size_threshold_(register_size_threshold) {} + + void modify(const std::vector& exprs) { + // Find candidate TensorViews and collect analysis information + for (auto expr : exprs) { + handle(expr); + } + + // Iterate over candidates to find match + for (auto tv : candidate_alias_tv_) { + TORCH_INTERNAL_ASSERT( + map_tv_to_origin_expr_.find(tv) != map_tv_to_origin_expr_.end()); + + const auto& expr = map_tv_to_origin_expr_[tv]; + const auto output = expr->output(0)->as(); + + TORCH_INTERNAL_ASSERT( + map_tv_to_allocations_.find(output->name()) != + map_tv_to_allocations_.end()); + + auto output_alloc = map_tv_to_allocations_[output->name()]; + + auto input_alloc = findCompatibleInputAllocate( + SymbolicSizePrinter::print_size(output_alloc), expr); + if (input_alloc != nullptr) { + // std::cout << "Alias Match\t" << output->getMemoryType() << std::endl; + output_alloc->setAlias(input_alloc); + } + } + } + + private: + // Check if we are a Pointwise TensorView op. + bool isPwiseTVOp(const Expr* expr) { + // Ignore set operations + if (expr->outputs().size() == 1 && ir_utils::isTV(expr->output(0)) && + ((expr->getExprType().value() == ExprType::UnaryOp && + expr->as()->getUnaryOpType() != UnaryOpType::Set) || + expr->getExprType().value() == ExprType::BinaryOp || + expr->getExprType().value() == ExprType::TernaryOp)) + return true; + return false; + } + + // Find an Input Allocate that is compatible with the Output Allocate + kir::Allocate* findCompatibleInputAllocate( + const std::string& output_size_str, + Expr* expr) { + // Stop searching if current op is not point-wise + if (!isPwiseTVOp(expr)) { + return nullptr; + } + + const auto& expr_inputs_iter = + ir_utils::filterByType(expr->inputs()); + + std::vector expr_inputs( + expr_inputs_iter.begin(), expr_inputs_iter.end()); + + for (const auto input : expr_inputs) { + auto input_alloc = map_tv_to_allocations_[input->name()]; + + // input_allocation == nullptr implies that input_tv is a fusion input. + if (input_alloc != nullptr) { + if (candidate_alias_tv_.find(input) != candidate_alias_tv_.end() && + output_size_str == SymbolicSizePrinter::print_size(input_alloc) && + map_tv_to_last_usage_[input] <= map_expr_to_pos_[expr]) { + return input_alloc; + } + } + } + + // Assume the first argument contains the primary variable + // Follow path along point-wise operations + if (!expr_inputs.empty()) { + auto first_input_argument_tv = expr_inputs.front()->getOrigin(); + if (first_input_argument_tv != nullptr) { + return findCompatibleInputAllocate( + output_size_str, first_input_argument_tv); + } + } + return nullptr; + } + + void handle(Expr* expr) final { + size_t expr_index = map_expr_to_pos_.size(); + map_expr_to_pos_[expr] = expr_index; + + if (ir_utils::isTVOp(expr)) { + const auto output = expr->output(0)->as(); + map_tv_to_origin_expr_[output] = expr; + + bool has_allocation = map_tv_to_allocations_.find(output->name()) != + map_tv_to_allocations_.end(); + + if (has_allocation) { + bool smem_valid = output->getMemoryType() == MemoryType::Shared; + + bool local_valid = false; + if (output->getMemoryType() == MemoryType::Local) { + auto allocation = map_tv_to_allocations_[output->name()]; + auto inferred_register_size = + eval_evaluator_.inferValue(allocation->size()); + if (inferred_register_size.has_value()) { + local_valid = inferred_register_size.value() > + static_cast(register_size_threshold_); + } + } + + // For the output TV to be an alias candidate, + // its allocation size must exceed the threshold + // OR be in shared memory + if (smem_valid || local_valid) { + candidate_alias_tv_.insert(output); + } + } + + const auto& expr_inputs = + ir_utils::filterByType(expr->inputs()); + for (const auto input : expr_inputs) { + map_tv_to_last_usage_[input] = expr_index; + } + } else { + OptOutDispatch::handle(expr); + } + } + + void handle(kir::Allocate* a) final { + if (a->buffer()->getValType().value() == ValType::KirTensorView) { + auto tv = a->buffer()->as()->fuserTv(); + map_tv_to_allocations_[tv->name()] = a; + } + } + + void handle(kir::ForLoop* fl) final { + for (auto expr : fl->body().exprs()) { + handle(expr); + } + } + + void handle(kir::IfThenElse* ite) final { + for (auto expr : ite->thenBody().exprs()) { + handle(expr); + } + for (auto expr : ite->elseBody().exprs()) { + handle(expr); + } + } + + private: + // Expression Evaluator to infer size of register allocation + StatefulExpressionEvaluator eval_evaluator_; + + // Alias local memory if it exceeds this threshold + const size_t register_size_threshold_; + + // Map expression to unique position + std::unordered_map map_expr_to_pos_; + + // Map TensorView to origin expression + std::unordered_map map_tv_to_origin_expr_; + + // Map TensorView to last usage expression position + std::unordered_map map_tv_to_last_usage_; + + // Map TensorView name to Allocate node + std::unordered_map map_tv_to_allocations_; + + // Track candidate TensorViews whose Allocate nodes + // could potentially alias another Allocate node + std::unordered_set candidate_alias_tv_; +}; + +} // namespace + +std::vector reuseMemoryAllocations( + Fusion* fusion, + const std::vector& exprs) { + FUSER_PERF_SCOPE("reuseMemoryAllocations"); + FusionGuard fg(fusion); + + // Alias local memory if it exceeds this threshold + const size_t register_size_threshold = 1; + AllocateReuseModifier arm(fusion, register_size_threshold); + arm.modify(exprs); + return exprs; +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_alias_memory.h b/torch/csrc/jit/codegen/cuda/lower_alias_memory.h new file mode 100644 index 0000000000000..128fa39398f58 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_alias_memory.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +//! Reuse Allocation nodes via pointer aliasing +//! +//! First pass finds candidate TensorViews +//! A candidate TensorView is anything in shared memory OR +//! in local memory with a static size larger than register_size_threshold +//! +//! Second pass finds appropriate input Allocate Node +//! among candidate TensorViews +//! +//! Alias Criteria: +//! If input is a candidate TensorView, +//! input allocation has the same size as output allocation, +//! thread bindings match, +//! is not used after this op: +//! then alias output Allocate to input Allocate. +//! +std::vector reuseMemoryAllocations( + Fusion* fusion, + const std::vector& exprs); + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index 5dcefda05f484..8205abb4fa875 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -1,4 +1,3 @@ - #include #include #include @@ -12,6 +11,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { IndexLowering::IndexLowering() : ir_builder_(GpuLower::current()->kernel()) {} @@ -305,6 +305,7 @@ void IndexLowering::generate(const std::vector& exprs) { } } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_index.h b/torch/csrc/jit/codegen/cuda/lower_index.h index 7e553f8013dc5..6dbf50d65ff3c 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.h +++ b/torch/csrc/jit/codegen/cuda/lower_index.h @@ -12,6 +12,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { class TORCH_CUDA_API IndexLowering : public OptInDispatch { public: @@ -67,6 +68,7 @@ class TORCH_CUDA_API IndexLowering : public OptInDispatch { kir::IrBuilder ir_builder_; }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp index 71bf2a282feca..4326b83fc4edc 100644 --- a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp @@ -1,7 +1,9 @@ - #include + +#include #include #include +#include #include #include #include @@ -9,27 +11,22 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace { +//! Scan through Kernel IR to insert Sync nodes to avoid +//! Write-After-Read (WAR) race condition +//! class LocalSyncInserter final : private OptOutDispatch { public: - static void InsertSyncs(Expr* expr) { + // Write-After-Read race conditions are only found within for-loops. + // Sync nodes are inserted directly into the for-loops. + // The expressions are modified in-place and exprs is const. + static void InsertSyncs(const std::vector& exprs) { LocalSyncInserter sync_inserter; - sync_inserter.handle(expr); - } - - void handle(Expr* expr) final { - if (ir_utils::isTVOp(expr)) { - // For this SyncInserter - (!initial_sync_) ? hasOutputSmemExpr(expr, initial_) - : hasInputSmemExpr(expr, final_); - - // For parent SyncInserter - hasOutputSmemExpr(expr, all_smem_outputs_); - hasInputSmemExpr(expr, all_smem_inputs_); - } else { - OptOutDispatch::handle(expr); + for (auto expr : exprs) { + sync_inserter.handle(expr); } } @@ -49,7 +46,42 @@ class LocalSyncInserter final : private OptOutDispatch { return all_smem_outputs_; } + const std::unordered_set& all_aliased_allocations() const { + return all_alias_allocations_; + } + private: + explicit LocalSyncInserter( + const std::unordered_set* parent_alias_allocations = + nullptr) { + if (parent_alias_allocations != nullptr) { + all_alias_allocations_.insert( + parent_alias_allocations->begin(), parent_alias_allocations->end()); + } + } + + void handle(Expr* expr) final { + if (ir_utils::isTVOp(expr)) { + // For this SyncInserter + (!initial_sync_) ? hasOutputSmemExpr(expr, initial_) + : hasInputSmemExpr(expr, final_); + + // For parent SyncInserter + hasOutputSmemExpr(expr, all_smem_outputs_); + hasInputSmemExpr(expr, all_smem_inputs_); + } else { + OptOutDispatch::handle(expr); + } + } + + void handle(kir::Allocate* a) final { + if (a->buffer()->getValType().value() == ValType::KirTensorView && + a->alias() != nullptr && a->getMemoryType() == MemoryType::Shared) { + auto tv = a->buffer()->as()->fuserTv(); + all_alias_allocations_.insert(tv->name()); + } + } + void handle(kir::IfThenElse* ite) final { for (auto expr : ite->thenBody().exprs()) { handle(expr); @@ -69,14 +101,18 @@ class LocalSyncInserter final : private OptOutDispatch { final_.clear(); } else if (expr->getExprType().value() == ExprType::ForLoop) { // Recursively handle nested for-loop - LocalSyncInserter child_sync_inserter; + LocalSyncInserter child_sync_inserter(&all_alias_allocations_); child_sync_inserter.handle(expr); const auto& child_inputs = child_sync_inserter.all_smem_inputs(); const auto& child_outputs = child_sync_inserter.all_smem_outputs(); + const auto& child_alias_allocations = + child_sync_inserter.all_aliased_allocations(); // Default - Track all smem inputs / outputs all_smem_inputs_.insert(child_inputs.begin(), child_inputs.end()); all_smem_outputs_.insert(child_outputs.begin(), child_outputs.end()); + all_alias_allocations_.insert( + child_alias_allocations.begin(), child_alias_allocations.end()); if (!initial_sync_) { // Parent - None @@ -137,6 +173,7 @@ class LocalSyncInserter final : private OptOutDispatch { // Determine if any smem TV is written to at beginning of the for-loop // and whether that smem TV is read from at the end of the for-loop // Insert new SyncThreads at end of for-loop to prevent WAR race condition + // TODO: replace __syncthreads with __threadfence for alias ops if (detect_intersection(initial_, final_) && fl->body().exprs().back()->getExprType().value() != ExprType::Sync && !is_last_op_sync_) { @@ -186,6 +223,9 @@ class LocalSyncInserter final : private OptOutDispatch { } private: + // Track TensorViews for Allocate nodes that alias another memory location + std::unordered_set all_alias_allocations_; + // Track Shared Memory Inputs (Reads) for parent for-loop std::unordered_set all_smem_inputs_; @@ -214,14 +254,11 @@ std::vector insertThreadSynchronization( const std::vector& exprs) { FUSER_PERF_SCOPE("insertThreadSynchronization"); FusionGuard fg(fusion); - std::vector mutated_exprs; - for (auto expr : exprs) { - LocalSyncInserter::InsertSyncs(expr); - mutated_exprs.push_back(expr); - } - return mutated_exprs; + LocalSyncInserter::InsertSyncs(exprs); + return exprs; } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.h b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.h index e17d536de5754..82fab236db80a 100644 --- a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.h +++ b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.h @@ -10,42 +10,44 @@ namespace torch { namespace jit { namespace fuser { - -// Insert sync at end of for-loops to prevent write-after-read race condition. -// WAR race condition occurs when the next iteration of the loop overwrites -// shared memory value before a previous operation has finished reading it. - -// WAR Race Check: -// Track all output shared memory TVs before first sync -// Track all input shared memory TVs after last sync -// If the intersection is non-empty, then there is a WAR race condition. -// Recursively check each nested for-loop - -// Parent-Child For-Loop Recursive Relationship -// Notation: -// None - Zero Syncs -// 1+ - One or more Syncs -// End - Sync is last op in for-loop to prevent WAR race condition - -// Default: Track all shared memory inputs and outputs - -// Parent - None -// Child - None => Append All Child Outputs to Parent Initial -// Child - 1+ => Parent first sync => Inherit Child Initial + Final -// Child - End => Parent first sync => Keep Child Initial / Clear Parent Final - -// Parent - 1+ -// Child - None => Append All Child to Parent Last -// Child - 1+ => Child Final to Parent Final / Discard Child Initial -// Child - End => Clear Parent Last / Discard Child Initial - -// If Child - End and Parent has zero remaining operations, then -// Parent inherits Child End. - +namespace cuda { + +//! Insert sync at end of for-loops to prevent write-after-read race condition. +//! WAR race condition occurs when the next iteration of the loop overwrites +//! shared memory value before a previous operation has finished reading it. +//! +//! WAR Race Check: +//! Track all output shared memory TVs before first sync +//! Track all input shared memory TVs after last sync +//! If the intersection is non-empty, then there is a WAR race condition. +//! Recursively check each nested for-loop +//! +//! Parent-Child For-Loop Recursive Relationship +//! Notation: +//! None - Zero Syncs +//! 1+ - One or more Syncs +//! End - Sync is last op in for-loop to prevent WAR race condition +//! +//! Default: Track all shared memory inputs and outputs +//! +//! Parent - None +//! Child - None => Append All Child Outputs to Parent Initial +//! Child - 1+ => Parent first sync => Inherit Child Initial + Final +//! Child - End => Parent first sync => Keep Child Initial / Clear Parent Final +//! +//! Parent - 1+ +//! Child - None => Append All Child to Parent Last +//! Child - 1+ => Child Final to Parent Final / Discard Child Initial +//! Child - End => Clear Parent Last / Discard Child Initial +//! +//! If Child - End and Parent has zero remaining operations, then +//! Parent inherits Child End. +//! std::vector insertThreadSynchronization( Fusion* fusion, const std::vector& exprs); +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index 97c3feb507232..a2919163c3dfd 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -1,5 +1,5 @@ - #include + #include #include #include @@ -14,6 +14,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { LoopNestGenerator::LoopNestGenerator( Fusion* fusion, @@ -75,7 +76,7 @@ Expr* LoopNestGenerator::pushAlloc(TensorView* tv) { const auto alloc = ir_builder_.create( lowered_tv, lowered_tv->memoryType(), size); - // Track Shared Memory Allocation Nodes + // Track Dynamic Shared Memory Allocation Nodes if (tv->getMemoryType() == MemoryType::Shared) { if (!size->isConstScalar()) { dynamic_smem_.push_front(alloc); @@ -85,7 +86,8 @@ Expr* LoopNestGenerator::pushAlloc(TensorView* tv) { // Place the allocation if (alloc_loop != nullptr) { - alloc_loop->body().insert(0, alloc); + alloc_loop->body().insert(for_loop_allocations_[alloc_loop], alloc); + ++for_loop_allocations_[alloc_loop]; } else { lowered_exprs.insert(lowered_exprs.begin(), alloc); } @@ -98,6 +100,7 @@ void LoopNestGenerator::openFor(std::pair id_pair) { IterDomain* id = id_pair.first; if (for_loops.size() > 0) { kir::ForLoop* new_scope = scope_utils::openFor(for_loops.back(), id); + for_loop_allocations_.insert({new_scope, 0}); for_loops.push_back(new_scope); } else { for_loops.push_back(scope_utils::openFor(nullptr, id)); @@ -180,6 +183,7 @@ void LoopNestGenerator::initReduction( new_fl = ir_builder_.create( ir_builder_.create(c10::nullopt), id, inner_fl); } + for_loop_allocations_.insert({new_fl, 0}); if (init_loop_nest == nullptr) { // If this is our first generated loop, then it will be our outer most @@ -203,7 +207,7 @@ void LoopNestGenerator::initReduction( } // If we don't have an alloc_loop defined it means it needs to go in - // lowered_exprs Make sure to place after the allocation of what we're + // lowered_exprs. Make sure to place after the allocation of what we're // initializing if there is one. if (alloc_loop == nullptr) { if (alloc_expr != nullptr) { @@ -219,9 +223,10 @@ void LoopNestGenerator::initReduction( } } else { if (alloc_expr != nullptr) { - // If there is an allocation for this tensor view place this loop nest - // after it + // If there is an allocation for this TensorView + // place this loop nest after it alloc_loop->body().insert_after(alloc_expr, init_loop_nest); + ++for_loop_allocations_[alloc_loop]; } else { // Otherwise we're allocating a global value alloc_loop->body().insert(0, init_loop_nest); @@ -256,6 +261,7 @@ void LoopNestGenerator::handle(Expr* expr) { shared_memory_sync |= isModifiedSharedMemory(in); } if (shared_memory_sync) { + TORCH_INTERNAL_ASSERT(!for_loops.empty(), "Attempted to add SyncThreads"); // push Sync to the back of the last for loop scope_utils::pushBack(for_loops.back(), ir_builder_.create()); cleanSharedMemory(); @@ -488,7 +494,7 @@ void groupExpressions( } // Sort each loop-nest group based on axis (i.e., score) -void sortGroup(TensorView* target, ExprListT& exprs, ExprScoreMapT& scores) { +void sortGroup(ExprListT& exprs, ExprScoreMapT& scores) { std::stable_sort( exprs.begin(), exprs.end(), @@ -497,6 +503,61 @@ void sortGroup(TensorView* target, ExprListT& exprs, ExprScoreMapT& scores) { }); } +// If an expression is missing from expr_status, search for all ancestors +// that are necessary for the expression +void mapMissingInputsToAncestors( + const TensorView* tv, + const std::unordered_map& expr_status, + std::vector& ancestors) { + const Expr* expr = tv->getOrigin(); + const auto& expr_inputs = ir_utils::filterByType(expr->inputs()); + for (auto input : expr_inputs) { + const Expr* input_origin = input->getOrigin(); + if (input_origin != nullptr) { + if (expr_status.find(input_origin) == expr_status.end()) { + mapMissingInputsToAncestors(input, expr_status, ancestors); + } else { + ancestors.push_back(input); + } + } + } +} + +// For each expression, find all TensorView inputs. +// If an input TensorView is missing from expr_status, +// find that input's ancestors that are present in expr_status. +std::unordered_map> findExprTvInputs( + const std::unordered_map& expr_status) { + std::unordered_map> + map_expr_to_tv_inputs; + + // Iterate over all exprs and filter missing expr + for (auto item : expr_status) { + const auto expr = item.first; + const auto& expr_inputs = + ir_utils::filterByType(expr->inputs()); + + map_expr_to_tv_inputs.insert({expr, std::vector()}); + auto& tv_inputs = map_expr_to_tv_inputs[expr]; + + for (auto input : expr_inputs) { + const Expr* input_origin = input->getOrigin(); + bool missing_input = input_origin != nullptr && + expr_status.find(input_origin) == expr_status.end(); + + if (missing_input) { + // Map missing input to ancestor that is present in exprs_status + std::vector ancestors; + mapMissingInputsToAncestors(input, expr_status, ancestors); + tv_inputs.insert(tv_inputs.begin(), ancestors.begin(), ancestors.end()); + } else { + tv_inputs.push_back(input); + } + } + } + return map_expr_to_tv_inputs; +} + // Reorder expressions that are computed at the same position in a // breadth-first order. void reorderSegmentBreadthFirst( @@ -509,23 +570,25 @@ void reorderSegmentBreadthFirst( expr_status.insert({*it, false}); } + // Holds all input TVs necessary for every expression. + const auto map_expr_to_tv_inputs = findExprTvInputs(expr_status); + while (seg_begin != seg_end) { std::vector visited_exprs; for (auto it = seg_begin; it != seg_end; ++it) { const auto expr = *it; - const auto& expr_inputs = - ir_utils::filterByType(expr->inputs()); - // expr can be visited if all input expressions are already - // visited. If an input expression is not found in expr_status, - // that should be safe to ignore. + const auto& expr_inputs = map_expr_to_tv_inputs.at(expr); + + // if all input expressions are visited + // then expr can be visited const bool ready_to_visit = std::all_of( expr_inputs.begin(), expr_inputs.end(), [&expr_status](const TensorView* input) { const Expr* input_origin = input->getOrigin(); return input_origin == nullptr || - expr_status.find(input_origin) == expr_status.end() || - expr_status.at(input_origin); + (expr_status.find(input_origin) != expr_status.end() && + expr_status.at(input_origin)); }); if (ready_to_visit) { std::iter_swap(seg_begin, it); @@ -561,7 +624,7 @@ void reorderGroupBreadthFirst(ExprListT& exprs, const ExprScoreMapT& scores) { seg_begin = seg_end; seg_score = cur_score; } else { - // expre list is assumed to be sorted in the order of scores, so + // exprs list is assumed to be sorted in the order of scores, so // this should never be reachable TORCH_INTERNAL_ASSERT( false, "Unexpected expression: ", expr, ", score: ", cur_score); @@ -583,7 +646,7 @@ void mergeNonRootGroupsIntoRootGroups( std::find(target_group.begin(), target_group.end(), target_expr); TORCH_INTERNAL_ASSERT(pos != target_group.end()); target_group.insert(pos, it->second.begin(), it->second.end()); - // Upate the target map + // Update the target map for (auto& inserted_expr : it->second) { TORCH_INTERNAL_ASSERT(target_map.at(inserted_expr) == target); target_map.at(inserted_expr) = target_of_target; @@ -626,10 +689,13 @@ void mergeGroupsIntoSortedList( // outer loops need to be located earlier. void reorderExprsForComputeAt(std::vector& exprs) { ExprListT reordered_exprs; + // expr -> target ExprTargetMapT target_map; + // target -> [computed at expressions] TargetGroupMapT computed_at_exprs; + // score of each expression that is calculated based on the // computeAt axis. A lower score of an expression means it should be // placed earlier in the expression list. This is a requirement for @@ -652,7 +718,8 @@ void reorderExprsForComputeAt(std::vector& exprs) { // 2. Sort each loop-nest group based on axis (i.e., score) for (auto& group : computed_at_exprs) { - sortGroup(group.first, group.second, scores); + sortGroup(group.second, scores); + // Reorder expressions in a breadth-first order reorderGroupBreadthFirst(group.second, scores); } @@ -663,7 +730,7 @@ void reorderExprsForComputeAt(std::vector& exprs) { // At this point, only root loop-nests (i.e., no computeAt'ed) // should exist. for (auto& group : computed_at_exprs) { - // Make usre only root loop-nests exist. + // Guarantee only root loop-nests exist. TensorView* target = group.first; TORCH_INTERNAL_ASSERT(!target->hasComputeAt()); } @@ -731,6 +798,7 @@ bool LoopNestGenerator::isModifiedSharedMemory(Val* key) const { return false; } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.h b/torch/csrc/jit/codegen/cuda/lower_loops.h index efe056ae9fe81..fb0ffb2f3c7c7 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.h +++ b/torch/csrc/jit/codegen/cuda/lower_loops.h @@ -11,6 +11,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { /* * Loop nest generator pass will get IR that looks something like: @@ -53,7 +54,7 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch { // Tracks if shared memory is modified std::unordered_map smem_; - // Track dynamic shared memory buffer + // Track dynamic shared memory buffers // Insert allocation at the beginning of the kernel std::deque dynamic_smem_; @@ -90,6 +91,10 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch { void generate(const std::vector& exprs); private: + // Track number of allocations in each for loop. It is used to insert + // allocations in the correct order, which is necessary for memory aliasing + std::unordered_map for_loop_allocations_; + // Lowered exprs to return std::vector lowered_exprs; @@ -111,6 +116,7 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch { kir::IrBuilder ir_builder_; }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp index 03311dc43ebfe..673e790d302db 100644 --- a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp @@ -1,4 +1,3 @@ - #include #include @@ -11,6 +10,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace { @@ -278,6 +278,7 @@ kir::Bool* ThreadPredicateMap::getExpr(const TensorView* out_tv) const { return getPredicate(at(out_tv).first, at(out_tv).second); } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h index ab321dc530c87..8c139dbce1ff6 100644 --- a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h +++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h @@ -9,6 +9,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { //! Maps TensorViews to std::pair> //! @@ -60,6 +61,7 @@ class TORCH_CUDA_API ThreadPredicateMap { MapType thread_predicates_; }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp index 51fd7f0b1b825..57e4ad5614d5a 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp @@ -1,4 +1,3 @@ - #include #include @@ -13,6 +12,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { kir::Bool* UnrollPass::getThreadPredicate(TensorView* tv) { // No thread predicate is needed predicate when tv is output of a @@ -137,6 +137,7 @@ std::vector UnrollPass::runPass( return mutated_exprs; } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.h b/torch/csrc/jit/codegen/cuda/lower_unroll.h index f77b8f37c8108..69f35ad17385c 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.h +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.h @@ -10,6 +10,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { /* * A bit deceptively: UnrollPass adds all predicates, so it needs to be run even @@ -107,6 +108,7 @@ class TORCH_CUDA_API UnrollPass : public OptOutDispatch { const ThreadPredicateMap& thread_predicates); }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 262cb5a7d4c0d..cfb2fceed3c53 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -1,10 +1,11 @@ - #include + #include #include #include #include #include +#include #include #include @@ -13,7 +14,7 @@ namespace torch { namespace jit { namespace fuser { - +namespace cuda { namespace scope_utils { // START SCOPE HELPER SYSTEMS @@ -488,12 +489,13 @@ bool isUnrolledFor(const Expr* expr) { } const std::unordered_map - ParallelTypeBitmap::pt_to_offset_{{ParallelType::BIDx, 0}, - {ParallelType::BIDy, 1}, - {ParallelType::BIDz, 2}, - {ParallelType::TIDx, 3}, - {ParallelType::TIDy, 4}, - {ParallelType::TIDz, 5}}; + ParallelTypeBitmap::pt_to_offset_{ + {ParallelType::BIDx, 0}, + {ParallelType::BIDy, 1}, + {ParallelType::BIDz, 2}, + {ParallelType::TIDx, 3}, + {ParallelType::TIDy, 4}, + {ParallelType::TIDz, 5}}; const std::unordered_map ParallelTypeBitmap::offset_to_pt_ = {{0, ParallelType::BIDx}, @@ -659,7 +661,7 @@ std::pair getAllocPoint( if (loops_it == loops.end()) { for (auto loop : loops) { - std::cout << loop->iter_domain() << " "; + std::cout << kir::toString(loop->iter_domain()) << " "; } std::cout << std::endl; } @@ -717,7 +719,7 @@ IterDomain* getTermIDInMap( } } // namespace loop_utils - +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h index 92c7c438b870f..1a2c16ab7c183 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -12,6 +12,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { class ThreadPredicateMap; @@ -179,7 +180,7 @@ IterDomain* getTermIDInMap( std::unordered_map p2c_root_map); } // namespace loop_utils - +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 5e1715c51b898..1967a66ffdd17 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -8,6 +9,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { void validateIr(Fusion* fusion) { FUSER_PERF_SCOPE("validateIr"); @@ -69,6 +71,7 @@ void validateIr(Fusion* fusion) { } } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.h b/torch/csrc/jit/codegen/cuda/lower_validation.h index daec9cb10d48f..eddee4f8350e6 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.h +++ b/torch/csrc/jit/codegen/cuda/lower_validation.h @@ -1,4 +1,3 @@ - #pragma once #include @@ -8,9 +7,11 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { void validateIr(Fusion* fusion); +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/manager.cpp b/torch/csrc/jit/codegen/cuda/manager.cpp index ddddce75ad9ee..552d68329d58c 100644 --- a/torch/csrc/jit/codegen/cuda/manager.cpp +++ b/torch/csrc/jit/codegen/cuda/manager.cpp @@ -23,28 +23,29 @@ namespace jit { namespace fuser { namespace cuda { -// [ Note -- cache entry indexing ] -// -// CudaFusionManager holds the cache and handles interfacing to CudaFusionGroup -// node, including selection, construction and execution of FusionExecutors. -// -// CudaFusionManager bridges PyTorch IR node CudaFusionGroup to GraphCache. -// Therefore, we want to cache on stringified graph. But it is expensive to -// stringify and hash on a computational graph, we cache the hash of a -// stringified graph on node via cache_id. -// -// CudaFusionGroup node stores: -// i. a PyTorch IR in `attr::Subgraph` -// ii. an int in `attr::cache_id`, (a cached hash value of `attr::Subgraph`) -// -// We have 2 unordered_map at CudaFusionGroup: -// std::unordered_map graph_cache_ids_; -// std::unordered_map> graph_cache_; -// -// Mapping from std::string to graph_cache_id ensures that we assign the same -// cache_id to CudaFusionGroup with identical computational grah, allowing -// kernel reuse; Direct mapping from cache_id to GraphCache allows efficient -// graph_cache indexing; +//! [ Note -- cache entry indexing ] +//! +//! CudaFusionManager holds the cache and handles interfacing to CudaFusionGroup +//! node, including selection, construction and execution of FusionExecutors. +//! +//! CudaFusionManager bridges PyTorch IR node CudaFusionGroup to GraphCache. +//! Therefore, we want to cache on stringified graph. But it is expensive to +//! stringify and hash on a computational graph, we cache the hash of a +//! stringified graph on node via cache_id. +//! +//! CudaFusionGroup node stores: +//! i. a PyTorch IR in `attr::Subgraph` +//! ii. an int in `attr::cache_id`, (a cached hash value of +//! `attr::Subgraph`) +//! +//! We have 2 unordered_map at CudaFusionGroup: +//! std::unordered_map graph_cache_ids_; +//! std::unordered_map> graph_cache_; +//! +//! Mapping from std::string to graph_cache_id ensures that we assign the same +//! cache_id to CudaFusionGroup with identical computational grah, allowing +//! kernel reuse; Direct mapping from cache_id to GraphCache allows efficient +//! graph_cache indexing; namespace { @@ -119,7 +120,7 @@ class CudaFusionManager { // TODO: I think merge cannot handle broadcast - Go verify it later; // TODO: Since we are only handling permutation here, we should just // merge the stride_index_; - acc_type = acc_type->merge(input_type); + acc_type = acc_type->merge(*input_type); } else { acc_type = input_type; } diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index 382883188dbe0..7e7dcf7740f10 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -1,4 +1,5 @@ #include + #include #include @@ -7,6 +8,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { void OptOutMutator::mutate(Fusion* fusion) { std::vector orig_exprs = fusion->exprs(); @@ -198,6 +200,7 @@ Statement* OptOutMutator::mutate(kir::IfThenElse* ite) { return ite; } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/mutator.h b/torch/csrc/jit/codegen/cuda/mutator.h index 40e325765a3e8..9451ea6c47da7 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.h +++ b/torch/csrc/jit/codegen/cuda/mutator.h @@ -10,6 +10,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { /* * Mutators are the mechanism used to modify IR nodes. Since most nodes are @@ -21,6 +22,7 @@ namespace fuser { * specialize those nodes which they want to have a particular transformation. */ +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp index ea963332fa6d2..df250e061a899 100644 --- a/torch/csrc/jit/codegen/cuda/parser.cpp +++ b/torch/csrc/jit/codegen/cuda/parser.cpp @@ -474,7 +474,7 @@ class IrParser { value_map.emplace(node->output()->unique(), out); }, [](const Node* node) -> bool { - // we don't support cast of output types yet; + // TODO: support cast of output types yet; if (!node->inputs()[3]->type()->isSubtypeOf( static_cast(NoneType::get()))) { // We can only handle output as half and float; @@ -638,7 +638,7 @@ bool isNodeParsible(const Node* node) { return IrParser::canParseNode(node); } -std::unique_ptr parseJitIR(std::shared_ptr& graph) { +std::unique_ptr parseJitIR(const std::shared_ptr& graph) { FUSER_PERF_SCOPE("parseJitIR"); IrParser parser(graph); diff --git a/torch/csrc/jit/codegen/cuda/parser.h b/torch/csrc/jit/codegen/cuda/parser.h index f83cb8e8808b8..69dfab8f631c6 100644 --- a/torch/csrc/jit/codegen/cuda/parser.h +++ b/torch/csrc/jit/codegen/cuda/parser.h @@ -39,7 +39,7 @@ TORCH_CUDA_API bool isNodeParsible(const Node* node); // lowers PyTorch jit graph to `Fusion`. TORCH_CUDA_API std::unique_ptr parseJitIR( - std::shared_ptr& graph); + const std::shared_ptr& graph); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/partition.cpp b/torch/csrc/jit/codegen/cuda/partition.cpp index 5c839864665b1..63ec915c053af 100644 --- a/torch/csrc/jit/codegen/cuda/partition.cpp +++ b/torch/csrc/jit/codegen/cuda/partition.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -19,7 +20,7 @@ static c10::optional getDevice(const Value* value) { // not tensor type, return false as the op is not outputing scalar. return c10::nullopt; } - return value->type()->expect()->device(); + return value->type()->expectRef().device(); } static c10::optional getDevice(const Node* node) { @@ -111,83 +112,82 @@ bool maybeBroadcastOnShape( return false; }; -// [ Note - tricky broadcasting ] -// -// github issue # 190 -// -// To extend the issue further, we consider two difficult broadcasting cases -// that is difficult to naively schedule: -// scenario 1: single tensor with multiple broadcasting semantics; -// ``` -// %t = op(...) -// %t0_o = op0(%t, %t0) -// %t1_o = op1(%t, %t1) -// ``` -// It's hard to check/validate whether `%t0` and `%t1` implies -// identical broadcasting for `%t` so that we can simply broadcast -// it to their common shape and use the broadcasted tensor view in -// both `op0` and `op1`; or, if `%t0` and `%t1` has different -// shapes, we would need differently broadcasted `%t` for the two -// ops. -// Even with this condition sorted out, scheduling is challenging. -// As we cannot inline the computation of `%t` to the downstream -// consumer of `%t0_o` and `%t1_o` easily, because `computeAt` -// could propagate contradicting transformations on the common -// ancestor `%t`. -// See footnote*; -// scenario 2: output tensor_view which is broadcasted later; -// ``` -// %t = op(...) -// %t0_o = op0(%t, %t0) -// return (%t, %t0_o) -// ``` -// Similarly, if we need to broadcast `%t` to `%t0` for `op0`, and -// use it as output, it also complicates schedule. -// -// Currently we just avoid the two cases in our graph partitioning. -// -// We bake the implementation along with our partition, where we merge nodes -// from producer to consumer. In the example down, we list all "type"s of edges -// among producer/consumer and the out side world. -// -// %input_t0, %input_t1, %input_t2 # inputs from outside world feeding -// # producer/consumer pair -// %p_out_t0, %p_out_t1 = producer(%input_t0, %input_t1) -// %c_out_t, ... = consumer(%input_t0, %input_t2, %p_out_t0) -// -// producer/consumer : the nodes that we are trying to merge, each node could be -// a parsible real operation or a `CudaFusionGroup`. -// %input_t0 : inputs shared by both producer & consumer -// %input_t1 : inputs feed only to producer, but not to consumer -// %input_t2 : inputs feed only to consumer, but not to producer -// %p_put_t0 : outputs of producer that is fed to consumer -// %p_put_t1 : outputs of producer that is not fed to consumer -// %c_put_t0 : outputs of consumer -// -// We can see that after merging consumer & producer, we will have: -// %input_t0, %input_t1, %input_t2 # inputs from outside world feeding -// # producer/consumer pair -// %p_out_t, %c_out_t = group(%input_t0, %input_t1, %input_t2) -// -// Under the assumption that any existing `CudaFusionGroup` does not have -// violating broadcasting semantics mentioned above. -// -// If we examine the `group`, new cases of scenario 1 (multiple broadcast) could -// only be created by merging new edges in the new `group`, that is: -// case 1. `%input_t0`, shared by `producer` and `consumer` -// case 2. `%p_out_t0`, produced by `producer` and fed to `consumer` -// -// new cases of scenario 2 (output was broadcasted later) could only be added -// via: -// case 3. `%p_out_t0`, produced by `producer` and fed to `consumer`, which -// could be broadcasted in the consumer subgraph. -// -// footnote*: -// We are only disabling multiple broadcast right on the tensor, instead of -// tracing all the broadcast further down. -// I don't think we need to worry about broadcasting further down the dependency -// chain, as those would create new IterDomain, which doesn't have th problem of -// conflicting broadcasting. +//! [ Note - tricky broadcasting ] +//! +//! github issue # 190 +//! +//! To extend the issue further, we consider two difficult broadcasting cases +//! that is difficult to naively schedule: +//! scenario 1: single tensor with multiple broadcasting semantics; +//! ``` +//! %t = op(...) +//! %t0_o = op0(%t, %t0) +//! %t1_o = op1(%t, %t1) +//! ``` +//! It's hard to check/validate whether `%t0` and `%t1` implies +//! identical broadcasting for `%t` so that we can simply +//! broadcast it to their common shape and use the broadcasted +//! tensor view in both `op0` and `op1`; or, if `%t0` and `%t1` +//! has different shapes, we would need differently broadcasted +//! `%t` for the two ops. Even with this condition sorted out, +//! scheduling is challenging. As we cannot inline the computation +//! of `%t` to the downstream consumer of `%t0_o` and `%t1_o` +//! easily, because `computeAt` could propagate contradicting +//! transformations on the common ancestor `%t`. See footnote*; +//! scenario 2: output tensor_view which is broadcasted later; +//! ``` +//! %t = op(...) +//! %t0_o = op0(%t, %t0) +//! return (%t, %t0_o) +//! ``` +//! Similarly, if we need to broadcast `%t` to `%t0` for `op0`, +//! and use it as output, it also complicates schedule. +//! +//! Currently we just avoid the two cases in our graph partitioning. +//! +//! We bake the implementation along with our partition, where we merge nodes +//! from producer to consumer. In the example down, we list all "type"s of edges +//! among producer/consumer and the out side world. +//! +//! %input_t0, %input_t1, %input_t2 # inputs from outside world feeding +//! # producer/consumer pair +//! %p_out_t0, %p_out_t1 = producer(%input_t0, %input_t1) +//! %c_out_t, ... = consumer(%input_t0, %input_t2, %p_out_t0) +//! +//! producer/consumer : the nodes that we are trying to merge, each node could +//! be +//! a parsible real operation or a `CudaFusionGroup`. +//! %input_t0 : inputs shared by both producer & consumer +//! %input_t1 : inputs feed only to producer, but not to consumer +//! %input_t2 : inputs feed only to consumer, but not to producer +//! %p_put_t0 : outputs of producer that is fed to consumer +//! %p_put_t1 : outputs of producer that is not fed to consumer +//! %c_put_t0 : outputs of consumer +//! +//! We can see that after merging consumer & producer, we will have: +//! %input_t0, %input_t1, %input_t2 # inputs from outside world feeding +//! # producer/consumer pair +//! %p_out_t, %c_out_t = group(%input_t0, %input_t1, %input_t2) +//! +//! Under the assumption that any existing `CudaFusionGroup` does not have +//! violating broadcasting semantics mentioned above. +//! +//! If we examine the `group`, new cases of scenario 1 (multiple broadcast) +//! could only be created by merging new edges in the new `group`, that is: +//! case 1. `%input_t0`, shared by `producer` and `consumer` +//! case 2. `%p_out_t0`, produced by `producer` and fed to `consumer` +//! +//! new cases of scenario 2 (output was broadcasted later) could only be added +//! via: +//! case 3. `%p_out_t0`, produced by `producer` and fed to `consumer`, which +//! could be broadcasted in the consumer subgraph. +//! +//! footnote*: +//! We are only disabling multiple broadcast right on the tensor, instead of +//! tracing all the broadcast further down. +//! I don't think we need to worry about broadcasting further down the +//! dependency chain, as those would create new IterDomain, which doesn't have +//! th problem of conflicting broadcasting. bool createTrickyBroadcast(const Node* consumer, const Node* producer) { auto count_broadcasting_in_node = [](const Node* node, diff --git a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp index 5a0eb3fcf8f4b..54a70121af4eb 100644 --- a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp @@ -1,4 +1,3 @@ - #include #include @@ -14,6 +13,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { std::vector PredicateCompute::computePredicates( const TensorView* tv, @@ -287,6 +287,7 @@ UnrollPredicate::UnrollPredicate( openLoop(unrolled_loop); } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/predicate_compute.h b/torch/csrc/jit/codegen/cuda/predicate_compute.h index 3c6d86106fe4b..d2fb8534a84e7 100644 --- a/torch/csrc/jit/codegen/cuda/predicate_compute.h +++ b/torch/csrc/jit/codegen/cuda/predicate_compute.h @@ -32,6 +32,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { class PredicateCompute { public: @@ -73,6 +74,7 @@ class TORCH_CUDA_API UnrollPredicate { const std::unordered_map& p2c_root_map_; }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/register_interface.cpp b/torch/csrc/jit/codegen/cuda/register_interface.cpp index cedf44c0ba1b8..284ee05420a1d 100644 --- a/torch/csrc/jit/codegen/cuda/register_interface.cpp +++ b/torch/csrc/jit/codegen/cuda/register_interface.cpp @@ -1,7 +1,10 @@ #include #include +#include #include +#include + /* * Registers function pointers in interface.h */ @@ -18,7 +21,10 @@ class RegisterInterface { auto ptr = getFuserInterface(); ptr->fn_compile_n_ = &compileCudaFusionGroup; ptr->fn_run_n_s_ = &runCudaFusionGroup; - ptr->fn_fuse_graph = &CudaFuseGraph; + ptr->fn_fuse_graph_ = &CudaFuseGraph; + ptr->fn_can_fuse_n_ = &isFusableCudaFusionGroup; + + RegisterProfilingNode(canFuseNode); } }; diff --git a/torch/csrc/jit/codegen/cuda/runtime/block_reduction.cu b/torch/csrc/jit/codegen/cuda/runtime/block_reduction.cu new file mode 100644 index 0000000000000..480a99efdc426 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/runtime/block_reduction.cu @@ -0,0 +1,104 @@ +// [Z,Y,X]_THREADS is the number of participating threads in the z, y, x +// dimension of the block. If set to 0 it means that dimension doesn't +// participate, otherwise it is the number of threads. We could start with warp +// reductions, then reduce the warps, this could save some shared memory, but +// may actually be slower. +// +// EXAMPLE USAGE: +// blockReduceSum +// (output[output_index], inputs[input_index], +// [] __device__ (T& a, const T b) { a += b; }); +// +// Note: We agressively template functions taking dim3 in the functions below +// because ROCM uses different types for the various dim3 and maps them +// directly to intrinsics, but they're dim3 when used after modification. +// +template < + bool X_REDUCE, + bool Y_REDUCE, + bool Z_REDUCE, + typename T, + typename Func, + typename _dim3ti, + typename _dim3bd> +__device__ void blockReduce( + T& out, + const T inp_val, + Func reduction_op, + const _dim3ti& thread_idx, + const _dim3bd& block_dim, + T* shared_mem, + bool read_write_pred, + T init_val) { + unsigned int reduction_size = (X_REDUCE ? block_dim.x : 1) * + (Y_REDUCE ? block_dim.y : 1) * (Z_REDUCE ? block_dim.z : 1); + + // If this thread will output a final result + bool should_write = true; + + if (X_REDUCE) + should_write = should_write && thread_idx.x == 0; + if (Y_REDUCE) + should_write = should_write && thread_idx.y == 0; + if (Z_REDUCE) + should_write = should_write && thread_idx.z == 0; + + unsigned int reduction_stride; + unsigned int reduction_tid; + unsigned int linear_tid; + + if (X_REDUCE && !Y_REDUCE && Z_REDUCE) { + // Transpose Z and Y in the shared memory so Z and X dims are contiguous in + // smem + reduction_stride = 1; + linear_tid = threadIdx.y * blockDim.z * blockDim.x + + threadIdx.z * blockDim.x + threadIdx.x; + reduction_tid = threadIdx.z * blockDim.x + threadIdx.x; + } else { + // Normal reduction in order + reduction_stride = + (X_REDUCE ? 1 + : (Y_REDUCE ? block_dim.x + : (Z_REDUCE ? block_dim.x * block_dim.y : 0))); + + linear_tid = thread_idx.z * block_dim.y * block_dim.x + + thread_idx.y * block_dim.x + thread_idx.x; + + reduction_tid = (Z_REDUCE ? thread_idx.z : 0) * + (Y_REDUCE ? block_dim.y : 1) * (X_REDUCE ? block_dim.x : 1) + + (Y_REDUCE ? thread_idx.y : 0) * (X_REDUCE ? block_dim.x : 1) + + (X_REDUCE ? thread_idx.x : 0); + } + + assert(reduction_stride != 0); + + if (read_write_pred) { + shared_mem[linear_tid] = inp_val; + } else { + shared_mem[linear_tid] = init_val; + } + __syncthreads(); + // Reduce down to nearest power of 2: + int np2 = 1 << (31 - __clz(reduction_size)); + + if (reduction_tid < np2) { + if (reduction_tid + np2 < reduction_size) { + reduction_op( + shared_mem[linear_tid], + shared_mem[linear_tid + np2 * reduction_stride]); + } + } + __syncthreads(); + // for (int factor = np2/2; factor > contig_threads / 2; factor>>=1) { + for (int factor = np2 / 2; factor > 0; factor >>= 1) { + if (reduction_tid < factor) { + reduction_op( + shared_mem[linear_tid], + shared_mem[linear_tid + factor * reduction_stride]); + } + __syncthreads(); + } + + if (should_write && read_write_pred) + out = shared_mem[linear_tid]; +} diff --git a/torch/csrc/jit/codegen/cuda/runtime/broadcast.cu b/torch/csrc/jit/codegen/cuda/runtime/broadcast.cu new file mode 100644 index 0000000000000..9a13b021f1012 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/runtime/broadcast.cu @@ -0,0 +1,41 @@ +namespace broadcast { + +template +__host__ __device__ unsigned offset_of_source( + const dim3& block_dim, + const dim3& thread_idx) { + unsigned offset = 0; + if (!Z_THREAD) + offset = offset * block_dim.z + thread_idx.z; + if (!Y_THREAD) + offset = offset * block_dim.y + thread_idx.y; + if (!X_THREAD) + offset = offset * block_dim.x + thread_idx.x; + return offset; +} + +// Broadcasts within partitioned groups of threads. +// +// X_THREAD: Broadcast from threadIdx.x == 0 if true +// Y_THREAD: Broadcast from threadIdx.y == 0 if true +// Z_THREAD: Broadcast from threadIdx.z == 0 if true +// inp_val: Per-thread source value. Only valid when the thread is a source. +// out: Per-thread output location +// +template +__device__ void blockBroadcast(T& out, T inp_val, T* shared_mem) { + const bool has_valid_data = (!X_THREAD || threadIdx.x == 0) && + (!Y_THREAD || threadIdx.y == 0) && (!Z_THREAD || threadIdx.z == 0); + + const auto shared_offset = + offset_of_source(blockDim, threadIdx); + + if (has_valid_data) + shared_mem[shared_offset] = inp_val; + + __syncthreads(); + + out = shared_mem[shared_offset]; +} + +} // namespace broadcast diff --git a/torch/csrc/jit/codegen/cuda/runtime/fp16_support.cu b/torch/csrc/jit/codegen/cuda/runtime/fp16_support.cu new file mode 100644 index 0000000000000..ba236784ed740 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/runtime/fp16_support.cu @@ -0,0 +1,21 @@ +#define __HALF_TO_US(var) *(reinterpret_cast(&(var))) +#define __HALF_TO_CUS(var) *(reinterpret_cast(&(var))) + +struct __align__(2) __half { + __host__ __device__ __half() {} + + protected: + unsigned short __x; +}; + +__device__ __half __float2half(const float f) { + __half val; + asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(__HALF_TO_US(val)) : "f"(f)); + return val; +} + +__device__ float __half2float(const __half h) { + float val; + asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__HALF_TO_CUS(h))); + return val; +} diff --git a/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu b/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu new file mode 100644 index 0000000000000..8900ab8c5b902 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu @@ -0,0 +1,376 @@ +// Inter-block reduction. +// +// Function gridReduce performs point-wise reductions of scalars across thread +// blocks. Thread blocks are disjointly partitioned into groups of thread +// blocks, "reduction segments," that are collectively defined by boolean +// template parameters, X_BLOCK, Y_BLOCK and Z_BLOCK. Each of X/Y/Z_BLOCK +// determines whether thread blocks along the dimension should be grouped into +// the same reduction segment. Cross-block reducitons are independently done +// within each segment and generates distinctive results per segment. For +// instance, if all of X/Y/Z_BLOCK are true, reductions will be done across all +// thread blocks since there will be just a single segment consisting of all +// thread blocks. If none of them are true, each thread block will become a +// segment by itself, so no reduction will be performed. +// +// The input scalars to reduce within each segment are a certain subset of +// thread-private scalars provided as part of the gridReduce function +// parameters. Boolean template parameters, X_THREAD, Y_THREAD and Z_THREAD, +// determine which subset of the scalars should be used for inter-block +// reductions. Specifically, all the input scalars of threads along each +// dimension will be used when X/Y/Z_THREAD are true. Otherwise, only the value +// held at offset 0 of each dimension will be used. Thus, for example, if all of +// X/Y/Z_THREAD are true, the scalars of all threads in each block will +// participate in inter-block reductions. If all of them are false, only one +// scalar of the thread at threadIdx.x == threadIdx.y == threadIdx.z == 0 will +// be used. In the code below, we call the subset of threads a "reduction +// block." +// +// Inter-block reductions perform point-wise reductions of scalars of reduction +// blocks within each reduction segment. More specifically, let rb be a +// reduction block and rs be a reduction segment. Let IN(thread_idx, block_idx) +// denote the input scalar of thread at thread_idx and block_idx. The result of +// each reduction segment, OUT(thread_idx, block_idx_out), is defined only for +// each thread_idx in thread block block_idx_out in the segment as follows: +// +// OUT(thread_idx, block_idx_out) = +// Reduction of IN(thread_idx, block_idx) for +// all block_idx in a reduction segment +// +// OUT is not given for all threads that are not in block_idx_out and the +// reduction block. +// +// See also the function comment of gridReduce. + +namespace reduction { + +// Utility functions +template +__device__ __forceinline__ size_t size(const _dim3& d) { + return (size_t)d.x * (size_t)d.y * (size_t)d.z; +} + +#define isize(d) d.x* d.y* d.z + +template +__device__ __forceinline__ size_t +offset(const _dim3pos& pos, const _dim3dim& dim) { + return (size_t)pos.x + (size_t)pos.y * (size_t)dim.x + + (size_t)pos.z * (size_t)dim.x * (size_t)dim.y; +} + +#define ioffset(pos, dim) pos.x + pos.y* dim.x + pos.z* dim.x* dim.y + +// Returns dim3 of each reduction segment. +template +__device__ dim3 dimension_of_reduction_segment(const _dim3& grid_dim) { + return dim3{ + X_BLOCK ? grid_dim.x : 1, + Y_BLOCK ? grid_dim.y : 1, + Z_BLOCK ? grid_dim.z : 1}; +} + +// Returns the number of blocks in each reduction segment. +template +__device__ size_t size_of_reduction_segment(const _dim3& grid_dim) { + return size( + dimension_of_reduction_segment(grid_dim)); +} + +// Returns the total number of reduction segments. +template +__device__ size_t number_of_reduction_segments(const _dim3& grid_dim) { + return (X_BLOCK ? 1 : grid_dim.x) * (Y_BLOCK ? 1 : grid_dim.y) * + (Z_BLOCK ? 1 : grid_dim.z); +} + +// Returns the 1-D index of the segment of thread block of block_idx. +template < + bool X_BLOCK, + bool Y_BLOCK, + bool Z_BLOCK, + typename _dim3bi, + typename _dim3gd> +__device__ size_t +index_of_reduction_segment(const _dim3bi& block_idx, const _dim3gd& grid_dim) { + size_t seg_idx = 0; + if (!Z_BLOCK) + seg_idx += block_idx.z; + if (!Y_BLOCK) + seg_idx = seg_idx * grid_dim.y + block_idx.y; + if (!X_BLOCK) + seg_idx = seg_idx * grid_dim.x + block_idx.x; + return seg_idx; +} + +// Returns the offset of thread block in its reduction segment. +template < + bool X_BLOCK, + bool Y_BLOCK, + bool Z_BLOCK, + typename _dim3bi, + typename _dim3gd> +__device__ size_t +offset_in_reduction_segment(const _dim3bi& block_idx, const _dim3gd& grid_dim) { + size_t offset = 0; + if (Z_BLOCK) + offset = offset * grid_dim.z + block_idx.z; + if (Y_BLOCK) + offset = offset * grid_dim.y + block_idx.y; + if (X_BLOCK) + offset = offset * grid_dim.x + block_idx.x; + return offset; +} + +// Returns dim3 of each reduction block. +template +__device__ dim3 dimension_of_reduction_block(const _dim3& block_dim) { + return dim3{ + X_THREAD ? block_dim.x : 1, + Y_THREAD ? block_dim.y : 1, + Z_THREAD ? block_dim.z : 1}; +} + +// Returns the number of threads of each reduction block. +template +__device__ int size_of_reduction_block(const _dim3& block_dim) { + auto tmp_dim = + dimension_of_reduction_block(block_dim); + return isize(tmp_dim); +} + +// Returns the linear offset of a thread in a reduction block. +template < + bool X_THREAD, + bool Y_THREAD, + bool Z_THREAD, + typename _dim3ti, + typename _dim3bd> +__device__ int offset_in_reduction_block( + const _dim3ti& thread_idx, + const _dim3bd& block_dim) { + int offset = 0; + if (Z_THREAD) + offset += thread_idx.z; + if (Y_THREAD) + offset = offset * block_dim.y + thread_idx.y; + if (X_THREAD) + offset = offset * block_dim.x + thread_idx.x; + return offset; +} + +// Reduces all the reduction blocks in each reduction segment. +// +// This is only used by one thread block per reduction segment. The input +// reduction blocks of the segment are stored in an intermediate buffer pointed +// by parameter in. Template parameters X/Y/Z_THREAD denote how the reduction +// block is formed. +// +// The size of a reduction block is by definition smaller or equal to the size +// of a thread block. We use the remaining threads to parallelize reductions +// across reduction blocks. For example, when X/Y/Z_THREAD = {true, false, +// false}, we use blockDim.y*blockDim.z threads for each output value. This is +// done first by loading the input values in parallel and then by reducing +// across threads of dimensions whose XYZ_THREAD are false. +// +// Note that what is done here after the loading from global memory is similar +// to what the existing blockReduce function does. The main difference is that +// the logical block to reduce is a 2D domain where the leading dimension is the +// size of a reduction block and the second dimension is the remaining factor in +// each thread block. For example, when X/Y/Z_THREAD = {false, true, false}, the +// threads are arranged as (blockDim.y, blockDim.x*blockDim.z). We do not reduce +// along the first dimension but only the second dimension. So, it is possible +// to reuse the existing blockReduce with dim3{blockDim.y, +// blockDim.x*blockDim.z} instead of blockDim and with X_THREAD and Y_THREAD +// being false and true, respectively. Also, it still need to shuffle the final +// output values to their actual corresponding threads. In the case of when +// X/Y/Z_THREAD = {false, true, false}, after the intra-block reduction, the +// final results will still be held by the first blockDim.y threads, which need +// to be transferred to threads at threadIdx.x == 0 and threadIdx.z == 0. +template < + bool X_THREAD, + bool Y_THREAD, + bool Z_THREAD, + typename T, + typename Func> +__device__ void gridReduceLastBlock( + T& out, + const T* in, + const size_t in_size, + Func reduction_op, + T* shared_buf, + bool read_write_pred, + T init_val) { + const int tid = ioffset(threadIdx, blockDim); + const int block_size = isize(blockDim); + const int rblock_size = + size_of_reduction_block(blockDim); + + T inp = init_val; + if (tid < in_size) { + inp = in[tid]; + } + for (size_t i = tid + block_size; i < in_size; i += block_size) { + reduction_op(inp, in[i]); + } + + const auto should_write = (X_THREAD || threadIdx.x == 0) && + (Y_THREAD || threadIdx.y == 0) && (Z_THREAD || threadIdx.z == 0); + + auto rem_size = block_size / rblock_size; + + if (rem_size > 1) { + const int rblock_offset = tid % rblock_size; + const int rblock_idx = tid / rblock_size; + blockReduce( + inp, + inp, + reduction_op, + dim3{(unsigned)rblock_offset, (unsigned)rblock_idx, 0}, + dim3{(unsigned)rblock_size, (unsigned)rem_size}, + shared_buf, + true, + init_val); + __syncthreads(); + if (tid < rblock_size) { + shared_buf[tid] = inp; + } + __syncthreads(); + if (should_write) { + inp = shared_buf[offset_in_reduction_block( + threadIdx, blockDim)]; + } + } + + if (should_write && read_write_pred) { + out = inp; + } +} + +// Reduces per-thread values across thread blocks. +// +// Function parameters: +// - out: Per-thread output location +// - inp_val: Per-thread input value +// - reduction_op: Scalar reduction function +// - work_buf: Temporary buffer for cross-block reductions +// - sync_flags: A vector of integers for synchronizations +// - shared_buf: Shared memory buffer for intra-block reduction +// +// Return true when the thread block has the valid result. +// +// Template parameters: +// - X/Y/Z_BLOCK: When true, reduces across thread blocks along the X/Y/Z +// dimensions +// - X/Y/Z_THREAD: When true, all threads along the X/Y/Z dimensions participate +// in the cross-block reduction. Otherwise, only threads at offset 0 do. +// - T: Scalar data type of input/output data +// - Func: Type of scalara reduction function +// +// Template parameters X/Y/Z_BLOCK define a group of thread blocks that are +// reduced together. We call it a reduction segment. Some examples are: +// +// Case 1: X/Y/Z_BLOCK == true/true/true -> There is only one segment, which +// includes all thread blocks. It is effecively the same as the grid. +// +// Case 2: X/Y/Z_BLOCK == false/false/false -> Each thread block comprises an +// individual segment by itself. +// +// Case 3: X/Y/Z_BLOCK == true/false/false -> Each segment contains thread +// blocks that have the same blockDim.x. There will be blockDim.y*blockDim.z +// such segments. +// +// X/Y/Z_THREAD defines a sub region of a thread block that should be reduced +// with the sub regions of other thread blocks. We call it a reduction block. +// E.g., +// +// Case 1: X/Y/Z_THREAD == false/false/false -> Only thread 0 participates in +// the cross-block reductions. The reduction block is 1x1x1 with thread 0. +// +// Case 2: X/Y/Z_THREAD == true/true/true-> All threads in a thread block +// participate in the cross-block reductions. The reduction block in this case +// is equivalent to the thread block. +// +// After the function completes, only one thread block per reduction segment +// gets valid reduction results. There is no guarantee which particular block +// gets the final results. +// +template < + bool X_BLOCK, + bool Y_BLOCK, + bool Z_BLOCK, + bool X_THREAD, + bool Y_THREAD, + bool Z_THREAD, + typename T, + typename Func> +__device__ bool gridReduce( + T& out, + T inp_val, + Func reduction_op, + volatile T* work_buf, + Tensor sync_flags, + T* shared_buf, + bool read_write_pred, + T init_val) { + // Number of values to reduce in the grid dimensions + const auto seg_size = + size_of_reduction_segment(gridDim); + + // Index of the reduction we're performing out of the seg_size + const auto seg_idx = + index_of_reduction_segment(blockIdx, gridDim); + + // Number of threads we can use in final reduction, Seems to assume all + // threads in the block participate + const auto rblock_size = + size_of_reduction_block(blockDim); + + // advance to the offset for this segment + // index of reduction * size of the reduction * size of threads + work_buf += seg_idx * seg_size * rblock_size; + + if ((X_THREAD || threadIdx.x == 0) && (Y_THREAD || threadIdx.y == 0) && + (Z_THREAD || threadIdx.z == 0)) { + auto rblock_offset = offset_in_reduction_segment( + blockIdx, gridDim); + auto thread_offset = + offset_in_reduction_block( + threadIdx, blockDim); + auto work_buf_offset = rblock_size * rblock_offset + thread_offset; + if (read_write_pred) { + work_buf[work_buf_offset] = inp_val; + } else { + work_buf[work_buf_offset] = init_val; + } + } + __syncthreads(); + + __shared__ bool last_block; + if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { + __threadfence(); + // printf("%ld\n", sync_flags[seg_idx]); + auto old = (int64_t)atomicAdd((unsigned long long*)&sync_flags[seg_idx], 1); + last_block = old + 1 == seg_size; + // printf("Last_block = %d + 1 == %d\n", (int)old, (int)seg_size); + } + __syncthreads(); + + if (last_block) { + // printf("Last block %d %d %d %d\n", blockIdx.x, blockIdx.y, blockIdx.z); + // final reduction + gridReduceLastBlock( + out, + (T*)work_buf, + seg_size * rblock_size, + reduction_op, + shared_buf, + read_write_pred, + init_val); + return true; + } else { + // printf("Not last block %d %d %d\n", blockIdx.x, blockIdx.y, blockIdx.z); + return false; + } +} + +} // namespace reduction diff --git a/torch/csrc/jit/codegen/cuda/runtime/helpers.cu b/torch/csrc/jit/codegen/cuda/runtime/helpers.cu new file mode 100644 index 0000000000000..15b33b25634dc --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/runtime/helpers.cu @@ -0,0 +1,47 @@ +__device__ constexpr int ceilDiv(int a, int b) { + return (a + b - 1) / b; +} + +__device__ constexpr int alignBufferSize(int buffer, int size) { + return (buffer + (size - 1)) & ~(size - 1); +} + +__device__ float clamp(float x, float minv, float maxv) { + return x < minv ? minv : (x > maxv ? maxv : x); +} + +__device__ float frac(float x) { + return x - truncf(x); +} + +__device__ float gelu(float x) { + return x * normcdf(x); +} + +__device__ float reciprocal(float x) { + return 1.f / x; +} + +__device__ float relu(float x) { + return x <= 0.f ? 0.f : x; +} + +__device__ float remainder(float a, float b) { + return a - b * floorf(a / b); +} + +__device__ float sigmoid(float x) { + return 1.f / (1.f + expf(-x)); +} + +__device__ float threshold(float x, float t, float v) { + return x <= t ? v : x; +} + +__device__ float where(bool c, float a, float b) { + return c ? a : b; +} + +__device__ float randLike(Philox rnd) { + return uniform(rnd()); +} diff --git a/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu b/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu new file mode 100644 index 0000000000000..d690145e61bdc --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu @@ -0,0 +1,104 @@ +class Philox { + public: + __device__ Philox( + unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) { + key.x = (unsigned int)seed; + key.y = (unsigned int)(seed >> 32); + counter = make_uint4(0, 0, 0, 0); + counter.z = (unsigned int)(subsequence); + counter.w = (unsigned int)(subsequence >> 32); + STATE = 0; + incr_n(offset / 4); + } + + __device__ unsigned long operator()() { + if (STATE == 0) { + uint4 counter_ = counter; + uint2 key_ = key; + for (int i = 0; i < 9; i++) { + counter_ = single_round(counter_, key_); + key_.x += (kPhilox10A); + key_.y += (kPhilox10B); + } + output = single_round(counter_, key_); + incr(); + } + unsigned long ret = 0; + switch (STATE) { + case 0: + ret = output.x; + break; + case 1: + ret = output.y; + break; + case 2: + ret = output.z; + break; + case 3: + ret = output.w; + break; + } + STATE = (STATE + 1) % 4; + return ret; + } + + private: + __device__ void incr_n(unsigned long long n) { + unsigned int nlo = (unsigned int)(n); + unsigned int nhi = (unsigned int)(n >> 32); + counter.x += nlo; + if (counter.x < nlo) + nhi++; + counter.y += nhi; + if (nhi <= counter.y) + return; + if (++counter.z) + return; + ++counter.w; + } + + __device__ void incr() { + if (++counter.x) + return; + if (++counter.y) + return; + if (++counter.z) + return; + ++counter.w; + } + + __device__ unsigned int mulhilo32( + unsigned int a, + unsigned int b, + unsigned int* result_high) { + *result_high = __umulhi(a, b); + return a * b; + } + + __device__ uint4 single_round(uint4 ctr, uint2 key) { + unsigned int hi0; + unsigned int hi1; + unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); + unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1); + uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; + return ret; + } + + private: + static constexpr unsigned long kPhilox10A = 0x9E3779B9; + static constexpr unsigned long kPhilox10B = 0xBB67AE85; + static constexpr unsigned long kPhiloxSA = 0xD2511F53; + static constexpr unsigned long kPhiloxSB = 0xCD9E8D57; + + uint4 counter = {}; + uint4 output = {}; + uint2 key = {}; + unsigned int STATE = 0; +}; + +__device__ float uniform(unsigned int x) { + constexpr float kRanInvM32 = 2.3283064e-10f; // Inverse of 2^32. + return x * kRanInvM32; +} diff --git a/torch/csrc/jit/codegen/cuda/runtime/tensor.cu b/torch/csrc/jit/codegen/cuda/runtime/tensor.cu new file mode 100644 index 0000000000000..76731c8c44824 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/runtime/tensor.cu @@ -0,0 +1,26 @@ +typedef unsigned char uint8_t; +typedef signed char int8_t; +typedef short int int16_t; +typedef long long int int64_t; + +template +struct Tensor { + __device__ T& operator[](int64_t ind) { + return data[ind]; + }; + + T* data; + int64_t size[N]; + int64_t stride[N]; +}; + +// Specialization for 0-dim case as it does not need size and stride arrays. +// They will be an error as well since zero-length arrays are not allowed. +template +struct Tensor { + __device__ T& operator[](int64_t) { + return *data; + }; + + T* data; +}; diff --git a/torch/csrc/jit/codegen/cuda/scheduler.cpp b/torch/csrc/jit/codegen/cuda/scheduler.cpp index f9bc25ca711e3..e4d4f3478a834 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler.cpp @@ -112,9 +112,7 @@ bool scheduleFusion(Fusion* fusion, const at::ArrayRef inputs) { out_tv->split(0, kPwThreadX); // Split by another 4 which will be our unroll factor auto ur_factor = disable_unroll ? 1 : kUnrollFactor; - if (!disable_unroll) { - out_tv->split(0, ur_factor); - } + out_tv->split(0, ur_factor); } for (auto output : fusion->outputs()) { diff --git a/torch/csrc/jit/codegen/cuda/shape_inference.cpp b/torch/csrc/jit/codegen/cuda/shape_inference.cpp index b06d586ec1288..5430d0edf0b7c 100644 --- a/torch/csrc/jit/codegen/cuda/shape_inference.cpp +++ b/torch/csrc/jit/codegen/cuda/shape_inference.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -14,9 +15,8 @@ namespace cuda { namespace { -bool hasTypeDeviceAndDim(const TensorTypePtr& op) { - return op->sizes().size().has_value() && op->scalarType().has_value() && - op->device().has_value(); +bool hasTypeAndDim(const TensorTypePtr& op) { + return op->sizes().size().has_value() && op->scalarType().has_value(); } /* NaiveTypePropagator @@ -84,7 +84,7 @@ class NaiveTypePropagator { case aten::gelu: case aten::tanh: { TORCH_CHECK( - hasTypeDeviceAndDim(node->input(0)->type()->cast()), + hasTypeAndDim(node->input(0)->type()->cast()), "Type, device, and dimensionality propagation has failed, or was not provided enough information."); node->output()->setType(node->input(0)->type()->cast()); break; @@ -92,7 +92,7 @@ class NaiveTypePropagator { // TODO: rand_like should support cast. case aten::rand_like: { TORCH_CHECK( - hasTypeDeviceAndDim(node->input(0)->type()->cast()), + hasTypeAndDim(node->input(0)->type()->cast()), "Type, device, and dimensionality propagation has failed, or was not provided enough information."); node->output()->setType(node->input(0)->type()->cast()); break; @@ -186,7 +186,7 @@ class NaiveTypePropagator { const TensorTypePtr& op, const std::vector& dims, bool keepdim) { - TORCH_CHECK(hasTypeDeviceAndDim(op), "requires complete shape on input"); + TORCH_CHECK(hasTypeAndDim(op), "requires complete shape on input"); auto input_size = op->sizes(); int64_t ndims = keepdim ? input_size.size().value() : 0; if (!keepdim) { @@ -226,7 +226,7 @@ class NaiveTypePropagator { } else { auto ptr = (op0 != nullptr) ? op0 : op1; TORCH_CHECK( - hasTypeDeviceAndDim(ptr), + hasTypeAndDim(ptr), "Type, device, and dimensionality propagation has failed, or was not provided enough information."); return TensorType::create( scalar_type.has_value() ? *scalar_type : *ptr->scalarType(), diff --git a/torch/csrc/jit/codegen/cuda/shape_inference.h b/torch/csrc/jit/codegen/cuda/shape_inference.h index da2a2ed3f3a92..ede73d97afc18 100644 --- a/torch/csrc/jit/codegen/cuda/shape_inference.h +++ b/torch/csrc/jit/codegen/cuda/shape_inference.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include namespace torch { diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 86ff7263af248..351d7048234a0 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -3,19 +3,18 @@ #include #include #include -#include -// #include #include +#include #include // Cleanup -// #include #include #include namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace { DataType aten_opt_type_map(const c10::optional& scalar_type) { @@ -507,17 +506,8 @@ TensorView* TensorView::cache_after() { // After: This TV -> [Set Op] -> New CA TV -> [Use Op] -> Next TV // Expr* consumer_uses = - size_t count = 0; for (auto expr : fusion()->unordered_uses(this)) { createExprProducer(expr, this, consumer); - ++count; - } - - if (count > 1) { - std::cout - << "WARNING: Cache_After with multiple consumers can create incorrect " - "kernels depending on computeAt configuration." - << std::endl; } // Expr* consumer_origin = @@ -713,6 +703,7 @@ void TensorView::createExprProducer( CreateExprProducer::create(expr, current, producer); } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/tools/stringify_file.py b/torch/csrc/jit/codegen/cuda/tools/stringify_file.py new file mode 100644 index 0000000000000..9f4e74e9c1e62 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/tools/stringify_file.py @@ -0,0 +1,29 @@ + +# Generates a C++ header files embedding the original input as a string literal + +import argparse +import pathlib +from datetime import datetime + +arg_parser = argparse.ArgumentParser( + description='Converts source files to C++ string literals', allow_abbrev=False) + +arg_parser.add_argument('-i', '--input', required=True, + help='Input source file') + +arg_parser.add_argument('-o', '--output', required=True, + help='Name of the generated header file') + +args = arg_parser.parse_args() + +with open(args.input, 'r') as fin: + with open(args.output, 'w') as fout: + literal_name = f'{pathlib.Path(args.input).stem}_cu' + fout.write(f'// Generated from "{args.input}"\n') + fout.write(f'// {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}\n\n') + fout.write('namespace nvfuser_resources {\n\n') + fout.write(f'constexpr const char* {literal_name} = R"(\n') + for line in fin: + fout.write(line) + fout.write(')";\n') + fout.write('\n} // namespace nvfuser_resources\n') diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.cpp b/torch/csrc/jit/codegen/cuda/transform_iter.cpp index 30e41296258f7..c7c356ba1126f 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_iter.cpp @@ -1,9 +1,11 @@ #include + #include namespace torch { namespace jit { namespace fuser { +namespace cuda { // Transform dispatch void ReplayTransformations::handle(Expr* e) { @@ -431,6 +433,7 @@ int BestEffortReplay::findFirstMismatchedID( return td1->nDims(); } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.h b/torch/csrc/jit/codegen/cuda/transform_iter.h index 161fa547680e4..987b0274868db 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.h +++ b/torch/csrc/jit/codegen/cuda/transform_iter.h @@ -11,6 +11,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace { @@ -196,6 +197,7 @@ class TORCH_CUDA_API BestEffortReplay { const TensorDomain* td2); }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 8ea00bd28c56c..120904ab454d7 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -11,6 +12,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { using id_map = std::unordered_map; @@ -317,12 +319,16 @@ std::pair TransformReplay::replayPasC( unsigned int producer_compute_at_axis = new_IDs.size(); // Add axes in (2) - std::unordered_set consumer_CA_ids_set( - consumer_CA_ids.begin(), consumer_CA_ids.end()); for (auto c_id : consumer->domain()) { auto it = replay_PasC.getReplay().find(c_id); if (it != replay_PasC.getReplay().end()) { auto id = it->second; + // If the leaf id from ReplayTransformations is used to move + // forward in BestEffortReplay, it is not a final ID. + if (producer_replayed_leaves.getUnorderedLeafIDs().find(id) == + producer_replayed_leaves.getUnorderedLeafIDs().end()) { + continue; + } if (used_IDs.find(id) == used_IDs.end()) { new_IDs.push_back(id); used_IDs.emplace(id); @@ -491,12 +497,16 @@ std::pair TransformReplay::replayCasP( } // Add axes in (2) - std::unordered_set consumer_CA_ids_set( - producer_CA_ids.begin(), producer_CA_ids.end()); for (auto p_id : producer->domain()) { auto it = replay_CasP.getReplay().find(p_id); if (it != replay_CasP.getReplay().end()) { auto id = it->second; + // If the leaf id from ReplayTransformations is used to move + // forward in BestEffortReplay, it is not a final ID. + if (consumer_replayed_leaves.getUnorderedLeafIDs().find(id) == + consumer_replayed_leaves.getUnorderedLeafIDs().end()) { + continue; + } if (used_IDs.find(id) == used_IDs.end()) { new_IDs.push_back(id); used_IDs.emplace(id); @@ -560,6 +570,7 @@ std::pair TransformReplay::replayCasP( return {consumer, replay.second}; } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index d43d71faf12c4..e4168f8316a62 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -9,6 +9,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { /* * compute_at is a relative property between two TensorViews which marks at what @@ -151,6 +152,7 @@ class TORCH_CUDA_API TransformReplay { const TensorDomain* self); }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp index 27a44a73d7ae4..b43ec54284326 100644 --- a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp @@ -9,6 +9,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { namespace { @@ -393,6 +394,7 @@ TensorDomain* TransformRFactor::runReplay2( new_root, new_domain, std::vector(new_root.size(), true)); } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.h b/torch/csrc/jit/codegen/cuda/transform_rfactor.h index 6d0977fd8acc8..9bf0926477bc9 100644 --- a/torch/csrc/jit/codegen/cuda/transform_rfactor.h +++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.h @@ -11,6 +11,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { // TODO: Only replay dispatch is really borrowed from TransformIter, we should // reevaluate the reuse of dispatch for classes that inherit TransformIter. @@ -23,6 +24,7 @@ class TORCH_CUDA_API TransformRFactor { static TensorDomain* runReplay2(TensorDomain*, std::vector axes); }; +} // namespace cuda } // namespace fuser } // namespace jit -} // namespace torch \ No newline at end of file +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index 3e0cd569c19e5..9d8d10f8475a6 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -6,6 +6,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { // Return highest on list (smallest enum val) DataType promote_type(const DataType& t1, const DataType& t2) { @@ -535,6 +536,7 @@ size_t dataTypeSize(DataType type) { } } +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index bb60fb2e0d15d..63a98ca1968d5 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -12,6 +12,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { // https://stackoverflow.com/questions/18837857/cant-use-enum-class-as-unordered-map-key struct TypeHash { @@ -200,6 +201,7 @@ enum class LaunchConfigType { TIDx }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/utils.h b/torch/csrc/jit/codegen/cuda/utils.h index fdc1e7c3d2fdb..f47c9440c259c 100644 --- a/torch/csrc/jit/codegen/cuda/utils.h +++ b/torch/csrc/jit/codegen/cuda/utils.h @@ -1,4 +1,3 @@ - #pragma once #include @@ -6,6 +5,7 @@ namespace torch { namespace jit { namespace fuser { +namespace cuda { // Common Functions constexpr int64_t ceilDiv(int64_t a, int64_t b) { @@ -74,6 +74,7 @@ class PolymorphicBase { } }; +} // namespace cuda } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/fuser/codegen.cpp b/torch/csrc/jit/codegen/fuser/codegen.cpp index eadc52d0170d4..522d819816651 100644 --- a/torch/csrc/jit/codegen/fuser/codegen.cpp +++ b/torch/csrc/jit/codegen/fuser/codegen.cpp @@ -91,7 +91,7 @@ static std::string variableType(const std::shared_ptr& t) { return "double"; } else if (t->kind() == TypeKind::BoolType) { return "bool"; - } else if (auto scalar_type = t->expect()->scalarType()) { + } else if (auto scalar_type = t->expectRef().scalarType()) { return calcScalarTypeName(*scalar_type); } // something went wrong with the type analysis during shape propagation @@ -118,7 +118,7 @@ static std::string typeCastedValueName( } else if (t->kind() == TypeKind::NoneType) { // Support None value for optional arguments like memory format return vn; - } else if (auto scalar_type = t->expect()->scalarType()) { + } else if (auto scalar_type = t->expectRef().scalarType()) { if (*scalar_type != outtype) { return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")"; } @@ -261,7 +261,7 @@ static std::string encodeRHS(const Node* n) { } else { size_t i = 0; - auto outtype = n->output()->type()->expect()->scalarType(); + auto outtype = n->output()->type()->expectRef().scalarType(); TORCH_INTERNAL_ASSERT(outtype); for (auto in : n->inputs()) { diff --git a/torch/csrc/jit/codegen/fuser/compiler.cpp b/torch/csrc/jit/codegen/fuser/compiler.cpp index e49a6a6923457..e3e2382f58946 100644 --- a/torch/csrc/jit/codegen/fuser/compiler.cpp +++ b/torch/csrc/jit/codegen/fuser/compiler.cpp @@ -260,7 +260,7 @@ std::shared_ptr compileKernel( sizes.at(o->node()->i(attr::dim)) *= o->node()->inputs().size(); } - auto scalar_type = o->type()->expect()->scalarType(); + auto scalar_type = o->type()->expectRef().scalarType(); TORCH_INTERNAL_ASSERT(scalar_type); auto type = TensorType::createContiguous(*scalar_type, device, sizes); output_desc.emplace_back(type); diff --git a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp index 4e76dc23e55d7..e671afb540baf 100644 --- a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp +++ b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -19,22 +20,22 @@ namespace cpu { #ifdef _MSC_VER static const std::string getTempPath() { - char lpTempPathBuffer[MAX_PATH]; + wchar_t lpTempPathBuffer[MAX_PATH]; - DWORD dwRetVal = GetTempPath( + DWORD dwRetVal = GetTempPathW( MAX_PATH, // length of the buffer lpTempPathBuffer); // buffer for path TORCH_CHECK(dwRetVal < MAX_PATH && dwRetVal != 0, "GetTempPath failed."); - return std::string(lpTempPathBuffer); + return std::string(c10::u16u8(lpTempPathBuffer)); } static const std::string temp_dir = getTempPath(); static const std::string so_template = temp_dir + "pytorch_fuserXXXXXX.dll"; static const std::string cpp_template = temp_dir + "pytorch_fuserXXXXXX.cpp"; static const std::string check_exists_string = "where \"${program}\" > nul 2> nul"; -static std::vector env_list; +static std::vector env_list; constexpr int so_suffix_len = 4; constexpr int cpp_suffix_len = 4; #else @@ -45,60 +46,66 @@ constexpr int so_suffix_len = 3; constexpr int cpp_suffix_len = 4; #endif +intptr_t run(const std::string& cmd); + static bool programExists(const std::string& program) { TemplateEnv env; env.s("program", program); std::string cmd = format(check_exists_string, env); +#ifdef _MSC_VER + return (run(cmd.c_str()) == 0); +#else return (system(cmd.c_str()) == 0); +#endif } #ifdef _MSC_VER -c10::optional exec(const std::string& cmd) { - std::array buffer; - std::string result; +c10::optional exec(const std::wstring& cmd) { + std::array buffer; + std::wstring result; std::unique_ptr pipe( - _popen(cmd.c_str(), "r"), _pclose); + _wpopen(cmd.c_str(), L"r"), _pclose); if (!pipe) { return c10::nullopt; } - while (fgets(buffer.data(), static_cast(buffer.size()), pipe.get()) != + while (fgetws(buffer.data(), static_cast(buffer.size()), pipe.get()) != nullptr) { result += buffer.data(); } return result; } -inline std::string& rtrim(std::string& s, const char* t = " \t\n\r\f\v") { +inline std::wstring& rtrim(std::wstring& s, const wchar_t* t = L" \t\n\r\f\v") { s.erase(s.find_last_not_of(t) + 1); return s; } void activate() { - char* root = nullptr; - std::string cmd; - c10::optional exec_out; - std::string path; - std::string vcruntime_plat; - std::string envvars; + wchar_t* root = nullptr; + std::wstring cmd; + c10::optional exec_out; + std::wstring path; + std::wstring vcruntime_plat; + std::wstring envvars; // Checking whether the environment is already activated - if (getenv("VSCMD_ARG_TGT_ARCH")) { + if (_wgetenv(L"VSCMD_ARG_TGT_ARCH")) { return; } // Getting `ProgramFiles` through environment variable queries - root = getenv("ProgramFiles(x86)"); + root = _wgetenv(L"ProgramFiles(x86)"); if (!root) { - root = getenv("ProgramFiles"); + root = _wgetenv(L"ProgramFiles"); } if (!root) { return; } // Getting VS 2017 installation path through `vswhere` - cmd = "\"" + std::string(root) + - "\\Microsoft Visual Studio\\Installer\\vswhere.exe\"" - " -latest -prerelease -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath"; + cmd = L"\"" + std::wstring(root) + + L"\\Microsoft Visual Studio\\Installer\\vswhere.exe\"" + L" -latest -prerelease -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath"; exec_out = exec(cmd); if (!exec_out) { return; @@ -107,25 +114,25 @@ void activate() { rtrim(path); // Checking whether the activation script `vcvarsall.bat` exists - path += "\\VC\\Auxiliary\\Build"; - struct stat st; - if (stat(path.c_str(), &st) == -1 || !(st.st_mode & _S_IFDIR)) { + path += L"\\VC\\Auxiliary\\Build"; + struct _stati64 st; + if (_wstati64(path.c_str(), &st) == -1 || !(st.st_mode & _S_IFDIR)) { return; } - path += "\\vcvarsall.bat"; - if (_access(path.c_str(), 0) == -1) { + path += L"\\vcvarsall.bat"; + if (_waccess(path.c_str(), 0) == -1) { return; } // Determining current platform if (sizeof(void*) == 8) { - vcruntime_plat = "x64"; + vcruntime_plat = L"x64"; } else { - vcruntime_plat = "x86"; + vcruntime_plat = L"x86"; } // Getting environment variables after activating VS development shell - cmd = "\"" + path + "\" " + vcruntime_plat + ">NUL && set"; + cmd = L"\"" + path + L"\" " + vcruntime_plat + L">NUL && set"; exec_out = exec(cmd); if (!exec_out) { return; @@ -133,25 +140,26 @@ void activate() { envvars = *exec_out; // Setting environment variables to the current environment - std::istringstream f(envvars); - std::string envvar; - while (getline(f, envvar, '\n')) { + std::wistringstream f(envvars); + std::wstring envvar; + while (getline(f, envvar, L'\n')) { env_list.push_back(envvar); } } intptr_t run(const std::string& cmd) { // Getting the path of `cmd.exe` - char* comspec = getenv("COMSPEC"); + wchar_t* comspec = _wgetenv(L"COMSPEC"); if (!comspec) { - comspec = "C:\\Windows\\System32\\cmd.exe"; + comspec = L"C:\\Windows\\System32\\cmd.exe"; } // Constructing the command line - const char* a[] = {"/c", cmd.c_str(), nullptr}; + auto wCmd = c10::u8u16(cmd); + const wchar_t* a[] = {L"/c", wCmd.c_str(), nullptr}; // Constructing the env array // If `env_list` is not empty, then add char pointers ending with nullptr. // Otherwise, it will be nullptr, which implies the default env. - std::vector e; + std::vector e; if (!env_list.empty()) { for (auto& s : env_list) { e.push_back(s.c_str()); @@ -159,7 +167,7 @@ intptr_t run(const std::string& cmd) { e.push_back(nullptr); } // Running the command - intptr_t r = _spawnve(_P_WAIT, comspec, a, e.data()); + intptr_t r = _wspawnve(_P_WAIT, comspec, a, e.data()); return r; } #endif diff --git a/torch/csrc/jit/codegen/fuser/cpu/temp_file.h b/torch/csrc/jit/codegen/fuser/cpu/temp_file.h index b83f56a56a182..f0dcc06a3e491 100644 --- a/torch/csrc/jit/codegen/fuser/cpu/temp_file.h +++ b/torch/csrc/jit/codegen/fuser/cpu/temp_file.h @@ -7,7 +7,8 @@ #ifdef _WIN32 #include -#include +#include +#include #include #include #include @@ -27,15 +28,15 @@ namespace fuser { namespace cpu { #ifdef _MSC_VER -int mkstemps(char* tmpl, int suffix_len) { +int wmkstemps(wchar_t* tmpl, int suffix_len) { int len; - char* name; + wchar_t* name; int fd = -1; int save_errno = errno; - len = strlen(tmpl); + len = wcslen(tmpl); if (len < 6 + suffix_len || - strncmp(&tmpl[len - 6 - suffix_len], "XXXXXX", 6)) { + wcsncmp(&tmpl[len - 6 - suffix_len], L"XXXXXX", 6)) { return -1; } @@ -47,7 +48,7 @@ int mkstemps(char* tmpl, int suffix_len) { name[i] = "abcdefghijklmnopqrstuvwxyz0123456789"[rd() % 36]; } - fd = _open(tmpl, _O_RDWR | _O_CREAT | _O_EXCL, _S_IWRITE | _S_IREAD); + fd = _wopen(tmpl, _O_RDWR | _O_CREAT | _O_EXCL, _S_IWRITE | _S_IREAD); } while (errno == EEXIST); if (fd >= 0) { @@ -63,20 +64,25 @@ struct TempFile { TH_DISALLOW_COPY_AND_ASSIGN(TempFile); TempFile(const std::string& t, int suffix) { +#ifdef _MSC_VER + auto wt = c10::u8u16(t); + std::vector tt(wt.c_str(), wt.c_str() + wt.size() + 1); + int fd = wmkstemps(tt.data(), suffix); + AT_ASSERT(fd != -1); + file_ = _wfdopen(fd, L"r+"); + auto wname = std::wstring(tt.begin(), tt.end() - 1); + name_ = c10::u16u8(wname); +#else // mkstemps edits its first argument in places // so we make a copy of the string here, including null terminator std::vector tt(t.c_str(), t.c_str() + t.size() + 1); int fd = mkstemps(tt.data(), suffix); AT_ASSERT(fd != -1); -#ifdef _MSC_VER - file_ = _fdopen(fd, "r+"); -#else file_ = fdopen(fd, "r+"); -#endif - // - 1 because tt.size() includes the null terminator, // but std::string does not expect one name_ = std::string(tt.begin(), tt.end() - 1); +#endif } const std::string& name() const { @@ -110,8 +116,9 @@ struct TempFile { if (file_ != nullptr) { fclose(file_); } - if (!name_.empty() && _access(name_.c_str(), 0) != -1) { - _unlink(name_.c_str()); + auto wname = c10::u8u16(name_); + if (!wname.empty() && _waccess(wname.c_str(), 0) != -1) { + _wunlink(wname.c_str()); } #else if (file_ != nullptr) { diff --git a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp index 27315ee475277..03ae998384138 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp +++ b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp @@ -1,4 +1,5 @@ #include + #include #include @@ -28,11 +29,8 @@ const at::cuda::NVRTC& nvrtc() { return at::globalContext().getNVRTC(); } -static void getMajorMinor( - const cudaDeviceProp* const prop, - int& major, - int& minor) { - int nvrtc_major, nvrtc_minor; +void getMajorMinor(const cudaDeviceProp* const prop, int& major, int& minor) { + int nvrtc_major = 0, nvrtc_minor = 0; AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor)); // Short-circuits if NVRTC version too low @@ -51,22 +49,27 @@ static void getMajorMinor( minor = 0; } else if (nvrtc_major <= 9 && prop->major >= 7) { // 9 supports 3-7.2 major = 7; - if (prop->major == 7 && prop->minor <= 2) - minor = prop->minor; - else - minor = 0; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + minor = (prop->major == 7 && prop->minor <= 2) ? prop->minor : 0; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) } else if (nvrtc_major <= 10 && prop->major >= 7) { // 10 supports 3-7.5 major = 7; - if (prop->major == 7 && prop->minor <= 5) - minor = prop->minor; - else - minor = 0; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + minor = (prop->major == 7 && prop->minor <= 5) ? prop->minor : 0; + } else if ( + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + nvrtc_major == 11 && nvrtc_minor == 0 && + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + prop->major >= 8) { // 11.0 supports 3.5-8.0 + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + major = 8; + minor = 0; } } // Compiles the specified kernel and stores the metadata required to run it FusedKernelCUDA::FusedKernelCUDA( - int16_t device, + at::DeviceIndex device, std::string name, std::string code, std::vector input_desc, @@ -222,7 +225,7 @@ static std::shared_ptr createFusionKernel( std::vector concat_desc, bool has_random) { return std::make_shared( - device, + static_cast(device), std::move(name), std::move(code), std::move(input_desc), diff --git a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h index 5dba48dabfc7b..fde27d767405a 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h +++ b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h @@ -17,12 +17,17 @@ namespace jit { namespace fuser { namespace cuda { +TORCH_CUDA_API void getMajorMinor( + const cudaDeviceProp* const prop, + int& major, + int& minor); + // A class holding metadata for an actual CUDA function. // Note: CUDA functions are per device. struct TORCH_CUDA_API FusedKernelCUDA : public ::torch::jit::fuser::FusedKernel { FusedKernelCUDA( - int16_t device, + at::DeviceIndex device, std::string name, std::string code, std::vector input_desc, @@ -45,7 +50,7 @@ struct TORCH_CUDA_API FusedKernelCUDA // Note: per device to store device properties and compute launch heuristics // Acquiring these values at launch time would be too slow - int16_t device_; + at::DeviceIndex device_; int maxBlocks_; cudaDeviceProp* prop_; std::vector ptx_; diff --git a/torch/csrc/jit/codegen/fuser/interface.cpp b/torch/csrc/jit/codegen/fuser/interface.cpp index 8f35cf104b437..48003c81fcfa4 100644 --- a/torch/csrc/jit/codegen/fuser/interface.cpp +++ b/torch/csrc/jit/codegen/fuser/interface.cpp @@ -5,8 +5,11 @@ #include #include +#include #include +C10_DEFINE_bool(torch_jit_enable_cpu_fusion, false, "enable cpu fusion"); + namespace torch { namespace jit { @@ -30,7 +33,8 @@ void runFusion(const int64_t key, Stack& stack) { } bool canFuseOnCPU() { - return fuser::hasFusionBackend(DeviceType::CPU) && detail::cpu_fuser_enabled; + return fuser::hasFusionBackend(DeviceType::CPU) && + (detail::cpu_fuser_enabled || FLAGS_torch_jit_enable_cpu_fusion); } bool canFuseOnGPU() { diff --git a/torch/csrc/jit/codegen/fuser/interface.h b/torch/csrc/jit/codegen/fuser/interface.h index f3272cbc38cd8..4d6220dc9ed6f 100644 --- a/torch/csrc/jit/codegen/fuser/interface.h +++ b/torch/csrc/jit/codegen/fuser/interface.h @@ -32,6 +32,9 @@ TORCH_API bool canFuseOnGPU(); // flakiness) TORCH_API void overrideCanFuseOnCPU(bool value); +// Sets whether fusion on CPU must use LLVM Codegen and not SimplieIREval +TORCH_API void overrideMustUseLLVMOnCPU(bool value); + // Sets whether fusion on the GPU is allowed (enabled by default) TORCH_API void overrideCanFuseOnGPU(bool value); diff --git a/torch/csrc/jit/cuda/cuda.h b/torch/csrc/jit/cuda/cuda.h new file mode 100644 index 0000000000000..fa92ce22d6e4c --- /dev/null +++ b/torch/csrc/jit/cuda/cuda.h @@ -0,0 +1,179 @@ +#include +#include +#include +#include + +namespace torch { +namespace jit { + +class CUDAEvent; +// This class is a wrapper around c10::cuda::CUDAStream. +// It is needed because TorchBind does not support all of the argument types +// for c10::cuda::CUDAStream. For more details, please refer to +// c10/cuda/CUDAStream.h. +class CUDAStream final : public CustomClassHolder { + public: + CUDAStream(int64_t device = -1, int64_t priority = 0) { + constexpr int64_t PRIORITY_INDEX = 0; + stream_ = std::make_unique( + c10::cuda::getStreamFromPool(priority < PRIORITY_INDEX, device)); + } + + CUDAStream(c10::cuda::CUDAStream s) { + stream_ = std::make_unique(s); + } + + bool query() { + return stream_->query(); + } + + c10::intrusive_ptr recordEvent( + c10::intrusive_ptr event); + + void synchronize() { + stream_->synchronize(); + } + + void waitEvent(c10::intrusive_ptr event); + + void waitStream(c10::intrusive_ptr stream); + + /// Get the CUDA device index that this stream is associated with. + int64_t device_index() const { + return stream_->device_index(); + } + + /// Get the full Device that this stream is associated with. The Device + /// is guaranteed to be a CUDA device. + c10::Device device() const { + return stream_->device(); + } + + /// Return the stream ID corresponding to this particular stream. + int64_t id() const { + return stream_->id(); + } + + /// Pack a CUDAStream to uint64_t representation. + /// The CUDAStream can be unpacked using unpack(). The format of + /// the uint64_t is unspecified and may be changed. + int64_t pack() const { + return stream_->pack(); + } + + private: + std::unique_ptr stream_; + friend class CUDAEvent; +}; + +// This class is a wrapper around at::cuda::CUDAStream. +// It is needed because TorchBind does not support all of the argument types +// for at::cuda::CUDAEvent. For more details, please refer to +// aten/src/ATen/cuda/CUDAEvent.h. +class CUDAEvent final : public CustomClassHolder { + public: + CUDAEvent( + bool enable_timing = false, + bool blocking = false, + bool interprocess = false) { + int flags = cudaEventDisableTiming; + if (enable_timing) { + flags = cudaEventDefault; + } + if (blocking) { + flags |= cudaEventBlockingSync; + } + if (interprocess) { + TORCH_CHECK(!enable_timing); + flags |= cudaEventInterprocess; + } + + event_ = std::make_unique(flags); + } + + double elapsedTime(c10::intrusive_ptr end) { + return event_->elapsed_time(*end->event_); + } + + std::string ipcHandle() { + cudaIpcEventHandle_t handle; + event_->ipc_handle(&handle); + std::string str_handle((const char*)&handle, sizeof(handle)); + return str_handle; + } + + bool query() { + return event_->query(); + } + + void record(c10::intrusive_ptr stream); + + void synchronize() { + event_->synchronize(); + } + void wait(c10::intrusive_ptr stream); + + private: + void recordInternal(CUDAStream* stream); + std::unique_ptr event_; + + friend class CUDAStream; +}; + +c10::intrusive_ptr CUDAStream::recordEvent( + c10::intrusive_ptr event) { + if (!event) { + event = c10::make_intrusive(); + } + + event->recordInternal(this); + return event; +} + +void CUDAStream::waitEvent(c10::intrusive_ptr event) { + event->event_->block(*stream_); +} + +void CUDAStream::waitStream(c10::intrusive_ptr stream) { + auto ev = c10::make_intrusive(); + stream->recordEvent(ev); + waitEvent(ev); +} + +void CUDAEvent::record(c10::intrusive_ptr stream) { + event_->record(*stream->stream_); +} + +void CUDAEvent::recordInternal(CUDAStream* stream) { + event_->record(*stream->stream_); +} + +void CUDAEvent::wait(c10::intrusive_ptr stream) { + event_->block(*stream->stream_); +} + +TORCH_LIBRARY(cuda, m) { + auto stream_class = m.class_("Stream").def( + torch::init()); + auto event_class = m.class_("Event").def( + torch::init()); + + stream_class.def("query", &CUDAStream::query) + .def("record_event", &CUDAStream::recordEvent) + .def("synchronize", &CUDAStream::synchronize) + .def("wait_event", &CUDAStream::waitEvent) + .def("wait_stream", &CUDAStream::waitStream) + .def("device_index", &CUDAStream::device_index) + .def("device", &CUDAStream::device) + .def("pack", &CUDAStream::pack) + .def("id", &CUDAStream::id); + + event_class.def("elapsed_time", &CUDAEvent::elapsedTime) + .def("query", &CUDAEvent::query) + .def("record", &CUDAEvent::record) + .def("synchronize", &CUDAEvent::synchronize) + .def("wait", &CUDAEvent::wait); +}; + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/docs/serialization.md b/torch/csrc/jit/docs/serialization.md index 9f40618e0e4df..2fc3c9abb4290 100644 --- a/torch/csrc/jit/docs/serialization.md +++ b/torch/csrc/jit/docs/serialization.md @@ -112,7 +112,7 @@ serialized as well. `PythonPrint` works by walking a `Graph` (the IR representation of either a `ClassType`'s method or raw `Function`) and emitting Python code that corresponds to it. The rules for emitting Python code are mostly -straightforward uninteresting. There are some extra pieces of information +straightforward and uninteresting. There are some extra pieces of information that `PythonPrint` tracks, however: **Class dependencies**. While walking the graph, `PythonPrint` keeps track of diff --git a/torch/csrc/jit/frontend/builtin_functions.cpp b/torch/csrc/jit/frontend/builtin_functions.cpp index 336c97704c966..a8d0795e8c608 100644 --- a/torch/csrc/jit/frontend/builtin_functions.cpp +++ b/torch/csrc/jit/frontend/builtin_functions.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -63,6 +64,17 @@ def _assert_int_or_pair(vals: List[int], name: str, message: str): def list_with_default(out_size: List[int], defaults: List[int]): assert len(defaults) > len(out_size) return out_size +def _assert(condition : bool, message : str): + assert condition, message +)SCRIPT"; + +// an additional overload for Tensor variant of _assert +const auto aten_ops_additional = + R"SCRIPT( +def _assert(condition : Tensor, message : str): + assert bool(condition), message +def __contains__(self: str, key: str): + return self.find(key, 0, len(self)) != -1 )SCRIPT"; // Implementations of historic symbol behaviors are defined here @@ -215,6 +227,7 @@ struct BuiltinFunctionRegistry { } loadSource(aten_ops, "aten"); + loadSource(aten_ops_additional, "aten"); // Loads functions implementing historic behavior, see note [Versioned // Symbols] diff --git a/torch/csrc/jit/frontend/concrete_module_type.cpp b/torch/csrc/jit/frontend/concrete_module_type.cpp index 169b589cbfff1..b9ef3ec4c974e 100644 --- a/torch/csrc/jit/frontend/concrete_module_type.cpp +++ b/torch/csrc/jit/frontend/concrete_module_type.cpp @@ -106,6 +106,7 @@ bool ConcreteModuleTypeBuilder::equals( bool equal = pyClass_.is(other.pyClass_) && iterableModuleKind_ == other.iterableModuleKind_ && + ignoredAttributes_ == other.ignoredAttributes_ && constants_ == other.constants_ && attributes_ == other.attributes_ && overloads_ == other.overloads_ && @@ -186,6 +187,10 @@ c10::optional ConcreteModuleType::findFailedAttribute( return c10::nullopt; } +bool ConcreteModuleType::isIgnoredAttribute(const std::string& name) const { + return data_.ignoredAttributes_.count(name) > 0; +} + std::shared_ptr ConcreteModuleType:: findSubmoduleConcreteType(const std::string& name) const { const auto it = std::find_if( @@ -223,7 +228,7 @@ void ConcreteModuleTypeBuilder::addConstant( "\n:", match.reason()); } - constants_.emplace(std::move(name), toIValue(value, match.type())); + constants_.emplace(std::move(name), toIValue(std::move(value), match.type())); } void ConcreteModuleTypeBuilder::addConstant(std::string name, IValue value) { @@ -232,7 +237,7 @@ void ConcreteModuleTypeBuilder::addConstant(std::string name, IValue value) { void ConcreteModuleTypeBuilder::addAttribute( std::string name, - TypePtr type, + const TypePtr& type, bool isParameter, bool isBuffer) { TORCH_INTERNAL_ASSERT(type); @@ -251,13 +256,13 @@ void ConcreteModuleTypeBuilder::addFunctionAttribute( TORCH_INTERNAL_ASSERT(type); functionAttributes_.emplace( std::move(name), - ConcreteModuleTypeBuilder::FunctionAttribute{type->expect(), - std::move(pyFunction)}); + ConcreteModuleTypeBuilder::FunctionAttribute{ + type->expect(), std::move(pyFunction)}); } void ConcreteModuleTypeBuilder::addBuiltinFunction( std::string name, - std::string symbol_name) { + const std::string& symbol_name) { builtinFunctions_.emplace( std::move(name), c10::Symbol::fromQualString(symbol_name)); } @@ -281,6 +286,10 @@ void ConcreteModuleTypeBuilder::addFailedAttribute( failedAttributes_.emplace(std::move(name), std::move(failureReason)); } +void ConcreteModuleTypeBuilder::addIgnoredAttribute(std::string name) { + ignoredAttributes_.emplace(std::move(name)); +} + void ConcreteModuleType::dump() const { std::cout << "ConcreteModuleType for: " << py::getattr(data_.pyClass_, "__name__") << "\n"; diff --git a/torch/csrc/jit/frontend/concrete_module_type.h b/torch/csrc/jit/frontend/concrete_module_type.h index 0410693d439cc..ff829d101fc1c 100644 --- a/torch/csrc/jit/frontend/concrete_module_type.h +++ b/torch/csrc/jit/frontend/concrete_module_type.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include @@ -66,7 +66,7 @@ class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder { void addConstant(std::string name, IValue value); void addAttribute( std::string name, - TypePtr type, + const TypePtr& type, bool isParameter, bool isBuffer); void addFunctionAttribute( @@ -79,8 +79,9 @@ class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder { void addOverload( std::string methodName, std::vector overloadedMethodNames); - void addBuiltinFunction(std::string name, std::string symbol_name); + void addBuiltinFunction(std::string name, const std::string& symbol_name); void addFailedAttribute(std::string name, std::string failureReason); + void addIgnoredAttribute(std::string name); void setIterableModuleKind(IterableModuleKind kind); // If a ConcreteModuleType is poisoned, it will never compare equal to any @@ -133,7 +134,7 @@ class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder { }; private: - ConcreteModuleTypeBuilder() {} + ConcreteModuleTypeBuilder() = default; ClassTypePtr createTypeFromThis() const; // If true, this type will never compare equally to anything else. This is @@ -150,6 +151,9 @@ class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder { // Any attributes we failed to convert to TorchScript, along with a hint as to // why std::unordered_map failedAttributes_; + // Any attributes that were marked as ignored. They cannot be used in + // TorchScript but can still be used in ignored function in Python. + std::unordered_set ignoredAttributes_; // Any function attributes. These are special right now because functions are // not first-class in the type system. std::unordered_map functionAttributes_; @@ -191,6 +195,7 @@ class VISIBILITY_HIDDEN ConcreteModuleType { std::shared_ptr findSubmoduleConcreteType( const std::string& name) const; c10::optional findFailedAttribute(const std::string& name) const; + bool isIgnoredAttribute(const std::string& name) const; // These getters are only here to return things as types that can be // automatically converted by pybind. @@ -216,7 +221,7 @@ class VISIBILITY_HIDDEN ConcreteModuleType { void dump() const; private: - ConcreteModuleType() {} + ConcreteModuleType() = default; // The JIT type derived from this ConcreteModuleType. ConcreteModuleTypeBuilder data_; diff --git a/torch/csrc/jit/frontend/convert_to_ssa.cpp b/torch/csrc/jit/frontend/convert_to_ssa.cpp index 10109aa558245..9d97e1067e745 100644 --- a/torch/csrc/jit/frontend/convert_to_ssa.cpp +++ b/torch/csrc/jit/frontend/convert_to_ssa.cpp @@ -5,19 +5,20 @@ #include #include #include -#include namespace torch { namespace jit { // At the beginning of the pass the Graph has already undergone type checking, // and writes or reads to a variable are emitted as Loads and Stores in the -// graph. a = 1 print(a) is represented as: -// -// %a.1 : int = prim::Constant[value=1]() -// prim::Store[name="a"](%a.1) -// %a : int = prim::Load[name="a"]() -// prim::Print(%a) +// graph. +// a = 1 +// print(a) +// is represented as: +// %a.1 : int = prim::Constant[value=1]() +// prim::Store[name="a"](%a.1) +// %a : int = prim::Load[name="a"]() +// prim::Print(%a) // // First, this pass recursively adds the Loads & Stores to control flow nodes // Then the graph is converted to SSA form. @@ -149,7 +150,7 @@ struct ControlFlowLoadStores { case prim::Loop: { addLoopLoadStores(n); } break; - case prim::Function: { + case prim::Closure: { for (auto b : n->blocks()) { addControlFlowLoadStores(b); } @@ -157,7 +158,7 @@ struct ControlFlowLoadStores { case prim::Store: { environment_stack->setVar(n->s(attr::name), n->input()->type()); } break; - case prim::LocalVariableScope: { + case prim::ComprehensionScope: { addControlFlowLoadStores(n->blocks().at(0)); } break; } @@ -182,8 +183,8 @@ struct ControlFlowLoadStores { std::shared_ptr environment_stack = nullptr; }; -// Given a graph where outputs have been added to control flow nodes, and -// loads and stores are represented in the graph, erases the Loads & Stores. +// Given a graph where 1) outputs have been added to control flow nodes and +// 2) loads and stores are represented in the graph, erase the Loads & Stores. struct EraseLoadStores { void eraseBlockLoadStores(Block* block) { pushFrame(block); @@ -204,7 +205,7 @@ struct EraseLoadStores { n->output()->replaceAllUsesWith(var); n->destroy(); } break; - case prim::LocalVariableScope: { + case prim::ComprehensionScope: { // writes within a local variable scope do not leak into // the rest of the graph auto body = n->blocks().at(0); @@ -279,7 +280,7 @@ struct LoopContinuations { assignExitContinuations(n->blocks().at(0)); assignExitContinuations(n->blocks().at(1)); } break; - case prim::Function: { + case prim::Closure: { LoopContinuations closure_block; closure_block.run(n->blocks().at(0)); } break; diff --git a/torch/csrc/jit/frontend/error_report.cpp b/torch/csrc/jit/frontend/error_report.cpp index d148214c73e2d..0c20014cc12f6 100644 --- a/torch/csrc/jit/frontend/error_report.cpp +++ b/torch/csrc/jit/frontend/error_report.cpp @@ -1,4 +1,5 @@ #include + #include #include #include diff --git a/torch/csrc/jit/frontend/error_report.h b/torch/csrc/jit/frontend/error_report.h index 1d7f4fcedb9af..a07f5e4370ea8 100644 --- a/torch/csrc/jit/frontend/error_report.h +++ b/torch/csrc/jit/frontend/error_report.h @@ -11,7 +11,7 @@ struct Call { SourceRange caller_range; }; -struct CAFFE2_API ErrorReport : public std::exception { +struct TORCH_API ErrorReport : public std::exception { ErrorReport(const ErrorReport& e); explicit ErrorReport(SourceRange r); @@ -20,7 +20,7 @@ struct CAFFE2_API ErrorReport : public std::exception { const char* what() const noexcept override; - struct CAFFE2_API CallStack { + struct TORCH_API CallStack { // These functions are used to report why a function was being compiled // (i.e. what was the call stack of user functions at compilation time that // led to this error) diff --git a/torch/csrc/jit/frontend/exit_transforms.cpp b/torch/csrc/jit/frontend/exit_transforms.cpp index 3126d78c3bd23..bb94f4d784e77 100644 --- a/torch/csrc/jit/frontend/exit_transforms.cpp +++ b/torch/csrc/jit/frontend/exit_transforms.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -119,7 +120,7 @@ struct ExitTransformer { static bool isGraphOrClosureBlock(Block* block) { return block->owningNode() == nullptr || - owningNodeKind(block) == prim::Function; + owningNodeKind(block) == prim::Closure; } static void removeOutputs(Block* b) { @@ -273,8 +274,8 @@ struct ExitTransformer { return constructWontExitPair(); } - // for the block that is not exitting, its' exit values will not get - // used so we create uninitialized values of the same type as the other + // The exit values of the block that is not exiting will not get + // used, so we create uninitialized values of the same type as the other // block. if (then_status == ExitStatus::WONT || then_status == ExitStatus::THROWS) { std::vector exit_vals = @@ -425,7 +426,7 @@ struct ExitTransformer { case prim::With: { exit_pair = transformWith(node); } break; - case prim::Function: { + case prim::Closure: { // exits of closure declaration stay local to the closure transformExits(node->blocks().at(0)); } break; @@ -674,7 +675,7 @@ class DepthFirstGraphNodeIterator { // If either of the then or else blocks have nodes, the current block // and iterator position need to be saved on the stack to resume // processing later. - block_stack_.push_back({current_.first, current_.second}); + block_stack_.emplace_back(current_.first, current_.second); } if (!then_block_empty && else_block_empty) { @@ -690,7 +691,7 @@ class DepthFirstGraphNodeIterator { } else if (!then_block_empty && !else_block_empty) { // Set current_ to {then_block, then_block.begin()} and push the // else_block to the stack so that it will be processed after. - block_stack_.push_back({else_block, else_block->nodes().begin()}); + block_stack_.emplace_back(else_block, else_block->nodes().begin()); current_.first = then_block; current_.second = then_block->nodes().begin(); } @@ -704,7 +705,7 @@ class DepthFirstGraphNodeIterator { // If body_block is not empty, push the current block onto the stack // to resume processing it later and set current_ to {body_block, // body_block.begin()}. - block_stack_.push_back({current_.first, current_.second}); + block_stack_.emplace_back(current_.first, current_.second); current_.first = body_block; current_.second = body_block->nodes().begin(); diff --git a/torch/csrc/jit/frontend/function_schema_parser.cpp b/torch/csrc/jit/frontend/function_schema_parser.cpp index 91608babd5827..2daeabc0133b9 100644 --- a/torch/csrc/jit/frontend/function_schema_parser.cpp +++ b/torch/csrc/jit/frontend/function_schema_parser.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -202,14 +203,14 @@ struct SchemaParser { IValue convertToList( TypeKind kind, const SourceRange& range, - std::vector vs) { + const std::vector& vs) { switch (kind) { case TypeKind::FloatType: - return fmap(vs, [](IValue v) { return v.toDouble(); }); + return fmap(vs, [](const IValue& v) { return v.toDouble(); }); case TypeKind::IntType: - return fmap(vs, [](IValue v) { return v.toInt(); }); + return fmap(vs, [](const IValue& v) { return v.toInt(); }); case TypeKind::BoolType: - return fmap(vs, [](IValue v) { return v.toBool(); }); + return fmap(vs, [](const IValue& v) { return v.toBool(); }); default: throw ErrorReport(range) << "lists are only supported for float or int types"; @@ -224,7 +225,7 @@ struct SchemaParser { } while (L.nextIf(',')); } L.expect(']'); - return convertToList(kind, tok.range, std::move(vs)); + return convertToList(kind, tok.range, vs); } IValue parseTensorDefault(const SourceRange& range) { diff --git a/torch/csrc/jit/frontend/function_schema_parser.h b/torch/csrc/jit/frontend/function_schema_parser.h index e4fcf1e7c0b4a..bdfaec640ac40 100644 --- a/torch/csrc/jit/frontend/function_schema_parser.h +++ b/torch/csrc/jit/frontend/function_schema_parser.h @@ -8,10 +8,10 @@ namespace torch { namespace jit { -CAFFE2_API c10::either parseSchemaOrName( +TORCH_API c10::either parseSchemaOrName( const std::string& schemaOrName); -CAFFE2_API c10::FunctionSchema parseSchema(const std::string& schema); -CAFFE2_API c10::OperatorName parseName(const std::string& name); +TORCH_API c10::FunctionSchema parseSchema(const std::string& schema); +TORCH_API c10::OperatorName parseName(const std::string& name); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 99ce4140c58ac..954648fc453ee 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -1,13 +1,16 @@ #include + #include #include #include #include #include +#include #include #include #include #include +#include #include #include #include @@ -19,6 +22,7 @@ #include #include #include +#include #include #include @@ -41,7 +45,7 @@ using ListAttributeMap = std::unordered_map>; struct Refinement { Refinement(std::string identifier, TypePtr type) - : identifier_(std::move(identifier)), type_(type) {} + : identifier_(std::move(identifier)), type_(std::move(type)) {} const std::string& identifier() const { return identifier_; } @@ -69,7 +73,7 @@ struct RefinementSet { : RefinementSet( Refinements({std::move(single_true)}), Refinements({std::move(single_false)})) {} - RefinementSet() {} // empty + RefinementSet() = default; // empty RefinementSet And(const RefinementSet& rhs) const { // if the result of an AND is true, both a & b had to be true, // so we take the union of a.true_refinements and b.true_refinements. @@ -243,7 +247,7 @@ struct Environment { while (runner->next) { runner = runner->next.get(); } - runner->error_messages[name] = msg; + runner->error_messages[name] = std::move(msg); } // see if type error has been set for a variable @@ -279,7 +283,7 @@ struct Environment { TypePtr type) { auto g = b->owningGraph(); g->insertNode(g->createStore(name, v))->setSourceRange(loc); - type_table[name] = type; + type_table[name] = std::move(type); } SugaredValuePtr findInThisFrame(const std::string& name) { @@ -402,7 +406,7 @@ struct Environment { << " but is being assigned to a value of type " << as_simple_value->type()->repr_str(); } - insertStore(name, loc, std::move(as_simple_value), annotated_type); + insertStore(name, loc, as_simple_value, annotated_type); } else { value_table[name] = std::move(value); } @@ -483,7 +487,7 @@ struct Environment { {"all", std::make_shared(aten::all, at::nullopt)}, {"divmod", std::make_shared(aten::divmod, at::nullopt)}, - {"list", std::make_shared(aten::list, at::nullopt)}, + {"list", SpecialFormValue::create(prim::list)}, {"ord", std::make_shared(aten::ord, at::nullopt)}, {"chr", std::make_shared(aten::chr, at::nullopt)}, {"bin", std::make_shared(aten::bin, at::nullopt)}, @@ -858,9 +862,12 @@ struct to_ir { return emitStatements(statements.begin(), statements.end()); } - // XXX - right now closures are used _only_ for defining gradients internally + // XXX: Right now closures are not generically implemented and are only used + // as an intermediate form for special tasks, like defining gradients or + // forked functions. + // // There are several unfinished aspects that make them unusable generally - // 1. We do not have a type, ivalue, operator to represent prim::Function, so + // 1. We do not have a type, ivalue, operator to represent prim::Closure, so // closure_node has type None // 2. There is no export logic for it yet, so it cannot be // exported/python_printed @@ -869,9 +876,19 @@ struct to_ir { // the changes to those variables will just get forgotten. // 4. There is no parsing support in frontend.py, this is intentional since it // prevents people from accidentally using this feature. + // + // This function leaves in the graph something like: + // + // %2 : None = prim::Closure() + // block0(): + // %1 : Tensor = prim::DoSomething(%0) + // -> (%1) + // + // A separate pass is required to erase this closure and replace it with + // something actually executable (see liftClosure and inlineForkedClosure). std::shared_ptr emitClosure( const std::function& emit_body) { - Node* closure_node = graph->insertNode(graph->create(prim::Function, 1)); + Node* closure_node = graph->insertNode(graph->create(prim::Closure, 1)); // it is not a real thing yet, so just say the type is None closure_node->output()->setType(NoneType::get()); Block* block = closure_node->addBlock(); @@ -930,49 +947,51 @@ struct to_ir { } void emitDelete(const Delete& stmt) { - if (stmt.expr().kind() == TK_SUBSCRIPT) { - Subscript subscript(stmt.expr()); - const List& subscript_exprs = subscript.subscript_exprs(); - if (subscript_exprs[0].kind() == TK_SLICE_EXPR) { - throw ErrorReport(stmt.range()) - << "del statements only support deletion at a single index, " - "slicing is not supported" - " (see https://github.com/pytorch/pytorch/issues/31430)"; - } - const SugaredValuePtr sv = emitSugaredExpr(subscript.value(), 1); - const SourceRange& val_range = subscript.value().range(); - Value* idx = emitExpr(subscript_exprs[0]); - Value* val = sv->asValue(val_range, method); - - // If val is a class instance, this is a method call to a type-specific - // implementation of del defined in a __delitem__ method. - if (auto cls = val->type()->cast()) { - if (!cls->findMethod("__delitem__")) { - throw ErrorReport(stmt.range()) - << "Class does not define __delitem__"; + for (const auto& target : stmt.targets()) { + if (target.kind() == TK_SUBSCRIPT) { + Subscript subscript(target); + const List& subscript_exprs = subscript.subscript_exprs(); + if (subscript_exprs[0].kind() == TK_SLICE_EXPR) { + throw ErrorReport(target.range()) + << "del statements only support deletion at a single index, " + "slicing is not supported" + " (see https://github.com/pytorch/pytorch/issues/31430)"; } + const SugaredValuePtr sv = emitSugaredExpr(subscript.value(), 1); + const SourceRange& val_range = subscript.value().range(); + Value* idx = emitExpr(subscript_exprs[0]); + Value* val = sv->asValue(val_range, method); + + // If val is a class instance, this is a method call to a type-specific + // implementation of del defined in a __delitem__ method. + if (auto cls = val->type()->cast()) { + if (!cls->findMethod("__delitem__")) { + throw ErrorReport(target.range()) + << "Class does not define __delitem__"; + } - // Use MethodValue to call the method to handle recursion. - MethodValue(val, "__delitem__") - .call(stmt.range(), method, {idx}, {}, 0); + // Use MethodValue to call the method to handle recursion. + MethodValue(val, "__delitem__") + .call(stmt.range(), method, {idx}, {}, 0); + } else { + auto node = graph->create(aten::Delete, {val, idx}, 0) + ->setSourceRange(target.range()); + graph->insertNode(node); + } + } else if (target.kind() == TK_VAR) { + Var var(target); + environment_stack->removeVar(var.name(), /*check_if_removed=*/true); } else { - auto node = graph->create(aten::Delete, {val, idx}, 0) - ->setSourceRange(stmt.range()); - graph->insertNode(node); + throw ErrorReport(target.range()) + << "del statements are only supported for deleting" + " list and dict items and variables"; } - } else if (stmt.expr().kind() == TK_VAR) { - Var var(stmt.expr()); - environment_stack->removeVar(var.name(), /*check_if_removed=*/true); - } else { - throw ErrorReport(stmt.range()) - << "del statements are only supported for deleting" - " list and dict items and variables"; } } void emitReturn(const Return& stmt) { - Value* result = emitExpr(stmt.expr()); TypePtr result_type = def_stack_.back().declared_return_type_; + Value* result = emitExpr(stmt.expr(), result_type); // result type is annotated, every return must convert to that type if (result_type) { // this guard skips implicit conversion from None -> Tensor for the return @@ -1090,9 +1109,9 @@ struct to_ir { } RefinementSet findIsNoneRefinements( - Expr lhs, + const Expr& lhs, Value* lhs_value, - Expr rhs, + const Expr& rhs, Value* rhs_value, int tok) { if (rhs.kind() != TK_NONE && lhs.kind() == TK_NONE) { @@ -1110,8 +1129,7 @@ struct to_ir { // propagate further in all loaded models. The handling of // unwrap_optional will fail in these cases since export did // not expect that the input would be none and an unannotated None. - // cannot be passed to unwrapoptional To enable this, - // we need to (1) implement a real casting operator + // To enable this, we need to (1) implement a real casting operator // annotated(T, X) that stays in the graph and does the cast // and (2) only enable this OPTIONAL_NONE when loading newer // graphs because it is incompatible with older graphs. @@ -1195,11 +1213,26 @@ struct to_ir { return emitHasAttr(apply.inputs()[0], apply.inputs()[1]); } } + auto sv = emitSugaredExpr(apply.callee(), 1); + auto loc = apply.callee().range(); + if (auto special_form = dynamic_cast(sv.get())) { + if (special_form->form() == prim::isinstance) { + checkApplyNumInputs(apply, 2); + return emitIsInstance(apply.inputs()[0], apply.inputs()[1]); + } + } } auto expr_out = emitToBool(expr.range(), emitExpr(expr)); c10::optional static_if = c10::nullopt; - if (expr_out->node()->kind() == aten::is_scripting) { + auto kind = expr_out->node()->kind(); + if (kind == aten::is_scripting) { static_if = true; + } else if (kind == aten::has_torch_function) { + static_if = false; + } + // MetaCompile on boolean literals and constants + if (auto maybe_ivalue = toIValue(expr_out)) { + static_if = maybe_ivalue->toBool(); } return CondValue(expr_out, RefinementSet({}), static_if); } break; @@ -1221,10 +1254,12 @@ struct to_ir { return graph->create(kind, n_outputs)->setSourceRange(loc); } - Value* emitTernaryIf(const TernaryIf& expr) { + Value* emitTernaryIf( + const TernaryIf& expr, + const TypePtr& type_hint = nullptr) { CondValue cond_value = emitCondExpr(expr.cond()); - auto true_expr = [&] { return emitExpr(expr.true_expr()); }; - auto false_expr = [&] { return emitExpr(expr.false_expr()); }; + auto true_expr = [&] { return emitExpr(expr.true_expr(), type_hint); }; + auto false_expr = [&] { return emitExpr(expr.false_expr(), type_hint); }; return emitIfExpr(expr.range(), cond_value, true_expr, false_expr); } @@ -1250,10 +1285,10 @@ struct to_ir { type_set = true; } - // comprehension introduces it's own scope. no variable assigned + // comprehension introduces its own scope. no variable assigned // leaks into the rest of the graph Node* n = - graph->insertNode(create(prim::LocalVariableScope, lc.range(), 0)); + graph->insertNode(create(prim::ComprehensionScope, lc.range(), 0)); auto* comprehension_block = n->addBlock(); pushFrame(comprehension_block); WithInsertPoint guard(comprehension_block); @@ -1272,6 +1307,52 @@ struct to_ir { return list_value; } + Value* emitDictComprehension(const DictComp& dc, const TypePtr& type_hint) { + const auto loc = dc.range(); + const auto targets_list = List::create(dc.range(), {dc.target()}); + const auto itrs = List::create(dc.range(), {dc.iter()}); + + Value* dict_value = + graph->insertNode(graph->create(prim::DictConstruct, 1))->output(); + // Set the default type to be Dict[Str, Tensor] + dict_value->setType(DictType::create(StringType::get(), TensorType::get())); + bool type_set = false; + if (type_hint) { + if (!type_hint->cast()) { + throw ErrorReport(loc) + << "Expected Dict type annotation for dict comprehension" + ", found " + << type_hint->repr_str(); + } + dict_value->setType(type_hint); + type_set = true; + } + + // A dict comprehension introduces its own scope. No variable assigned + // may leak into the rest of the graph + Node* n = + graph->insertNode(create(prim::ComprehensionScope, dc.range(), 0)); + auto* comprehension_block = n->addBlock(); + pushFrame(comprehension_block); + WithInsertPoint guard(comprehension_block); + auto emit_body = [&]() { + auto k = emitExpr(dc.key()); + auto v = emitExpr(dc.value()); + if (!type_set) { + dict_value->setType(DictType::create(k->type(), v->type())); + type_set = true; + } + NamedValue self = NamedValue(loc, "self", dict_value); + NamedValue input_k = NamedValue(loc, "", k); + NamedValue input_v = NamedValue(loc, "", v); + emitBuiltinCall( + loc, *graph, aten::_set_item, {self, input_k, input_v}, {}); + }; + emitFor(targets_list, itrs, loc, emit_body); + popFrame(); + return dict_value; + } + // Insert subtyping refinements void insertRefinements(const SourceRange& loc, const RefinementSet& ref) { for (const Refinement& r : ref.activeRefinements()) { @@ -1334,8 +1415,8 @@ struct to_ir { Value* emitIfExpr( const SourceRange& range, const CondValue& cond_value, - std::function true_expr, - std::function false_expr) { + const std::function& true_expr, + const std::function& false_expr) { Node* n = graph->insertNode(create(prim::If, range, 0)); n->addInput(cond_value.value()); auto* true_block = n->addBlock(); @@ -1344,7 +1425,7 @@ struct to_ir { auto emit_if_expr = [this, &range]( Block* b, const RefinementSet& refinements, - std::function expr_value) { + const std::function& expr_value) { pushFrame(b); WithInsertPoint guard(b); insertRefinements(range, refinements); @@ -1353,9 +1434,8 @@ struct to_ir { popFrame(); }; - emit_if_expr(true_block, cond_value.refinements(), std::move(true_expr)); - emit_if_expr( - false_block, cond_value.refinements().Not(), std::move(false_expr)); + emit_if_expr(true_block, cond_value.refinements(), true_expr); + emit_if_expr(false_block, cond_value.refinements().Not(), false_expr); auto true_type = true_block->outputs().at(0)->type(); auto false_type = false_block->outputs().at(0)->type(); @@ -1575,7 +1655,7 @@ struct to_ir { // category checks: tuple_check = true, types = {float, int} struct GatheredTypes { GatheredTypes(ScriptTypeParser parser) : typeParser_(std::move(parser)) {} - void gather(Expr classinfo) { + void gather(const Expr& classinfo) { if (classinfo.kind() == TK_TUPLE_LITERAL) { for (Expr e : TupleLiteral(classinfo).inputs()) { gather(e); @@ -1668,7 +1748,7 @@ struct to_ir { // semantics specified at // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Loop void emitLoopCommon( - SourceRange range, + const SourceRange& range, const std::function& emit_body, const SugaredValuePtr& iter_val, c10::optional> targets, @@ -1733,7 +1813,7 @@ struct to_ir { void emitUnrolledLoop( const SourceRange& loc, const std::function& emit_body, - SugaredValuePtr iterable, + const SugaredValuePtr& iterable, const List& targets) { auto static_len = iterable->staticLen(); TORCH_INTERNAL_ASSERT( @@ -1974,6 +2054,18 @@ struct to_ir { return use_inplace_op ? aten::mul_ : aten::mul; case '%': return use_inplace_op ? aten::fmod_ : aten::fmod; + case '|': + return use_inplace_op ? aten::bitwise_or : aten::__or__; + case '&': + return use_inplace_op ? aten::bitwise_and : aten::__and__; + case '^': + return use_inplace_op ? aten::bitwise_xor : aten::__xor__; + case TK_LSHIFT: + return use_inplace_op ? aten::__lshift__ : aten::__lshift__; + case TK_RSHIFT: + return use_inplace_op ? aten::__irshift__ : aten::__rshift__; + case TK_POW: + return aten::pow; default: throw ErrorReport(stmt) << "Unknown augmented assignment: " << kindToString(stmt.aug_op()); @@ -2085,8 +2177,8 @@ struct to_ir { stmt.range(), *method.graph(), getAugOp(stmt, lhs->type()), - /*inputs=*/{lhs, rhs}, - /*attributes=*/{}, + /*args=*/{lhs, rhs}, + /*kwargs=*/{}, /*self=*/c10::nullopt); } } @@ -2656,9 +2748,9 @@ struct to_ir { if (auto special_form = dynamic_cast(sv.get())) { return emitApplySpecialForm(special_form->form(), apply, type_hint); } - auto inputs = getNamedValues(apply.inputs(), true); - auto attributes = emitAttributes(apply.attributes()); - return sv->call(loc, method, inputs, attributes, n_binders); + auto args = getNamedValues(apply.inputs(), true); + auto kwargs = emitAttributes(apply.attributes()); + return sv->call(loc, method, args, kwargs, n_binders); } // this function handles expressions that look like apply statements @@ -2679,9 +2771,9 @@ struct to_ir { } auto forked = emitSugaredExpr(Expr(trees[0]), 1); TreeList sliced_trees(trees.begin() + 1, trees.end()); - auto inputs = getNamedValues(sliced_trees, true); - auto attributes = emitAttributes(apply.attributes()); - return emitForkExpr(apply.range(), forked, inputs, attributes); + auto args = getNamedValues(sliced_trees, true); + auto kwargs = emitAttributes(apply.attributes()); + return emitForkExpr(apply.range(), forked, args, kwargs); } case prim::annotate: { checkApplyNumInputs(apply, 2); @@ -2701,7 +2793,7 @@ struct to_ir { << why_not.str(); } - // None is a subtype of Optional[T], but we want to remember what T is, + // None is a subtype of Optional[T], but we want to remember what T is // after annotation so that variables assigned to this None will still // get the right type. To do this, we make a None constant that // has the type Optional[T] @@ -2873,6 +2965,49 @@ struct to_ir { } return iterable_tree; } + case prim::list: { + if (apply.inputs().size() == 0) { + TypePtr type = type_hint ? type_hint : ListType::ofTensors(); + if (!type->cast()) { + throw ErrorReport(apply.range()) + << "Expected list type annotation for list(), found " + << type_hint->repr_str(); + } + return std::make_shared( + graph + ->insertNode(graph->createList( + type->expectRef().getElementType(), {})) + ->output()); + } + // list(iter) desugars to [_elem for _elem in iter] + checkApplyNumInputs(apply, 1); + auto iter_input = emitSugaredExpr(apply.inputs()[0], 1); + + // aten::list builtin op is registered for List and Str input + // dispatch to the builtin op to avoid perf slowdown on existing uses + if (auto simple = asSimple(iter_input)) { + if (simple->type()->cast() || + simple->type()->cast()) { + return std::make_shared(emitBuiltinCall( + apply.range(), *method.graph(), aten::list, {simple}, {})); + } + } + const std::string& iter_name = createTempName("$_iter"); + environment_stack->setSugaredVar( + apply.range(), + iter_name, + iter_input, + /*annotated_type=*/nullptr); + + const std::string& elem_name = createTempName("$_elem"); + auto ident = + Var::create(apply.range(), Ident::create(apply.range(), elem_name)); + auto iter = + Var::create(apply.range(), Ident::create(apply.range(), iter_name)); + auto lc = ListComp::create(apply.range(), ident, ident, iter); + return std::make_shared( + emitListComprehension(lc, type_hint)); + } default: TORCH_INTERNAL_ASSERT(false, "unknown special form: ", form); } @@ -2882,7 +3017,15 @@ struct to_ir { // Push the source range of a call in case compiling this function // triggers an error ErrorReport::CallStack::update_pending_range(tree.range()); - return emitSugaredExpr(tree, 1, type_hint)->asValue(tree.range(), method); + Value* out_val = + emitSugaredExpr(tree, 1, type_hint)->asValue(tree.range(), method); + // AnyType is the only user-exposed type which we don't unify to from + // its subtypes, so we add a cast for use cases like + // x : Any = 1 if cond else "str" + if (type_hint == AnyType::get() && out_val->type() != AnyType::get()) { + out_val = graph->insertUncheckedCast(out_val, type_hint); + } + return out_val; } NodeKind reverseComparision(NodeKind kind) { @@ -2923,7 +3066,7 @@ struct to_ir { return emitApplyExpr(apply, n_binders, type_hint); } break; case TK_SUBSCRIPT: { - return emitSubscript(Subscript(tree)); + return emitSubscript(Subscript(tree), type_hint); } break; default: return std::make_shared(emitSimpleExpr(tree, type_hint)); @@ -2956,11 +3099,15 @@ struct to_ir { return graph->insertConstant(maybe_out_stack->at(0), tree->range()); } + /** + * Emit a fork expression, of the form: + * torch.jit.fork(forked, *args, **kwargs) + */ std::shared_ptr emitForkExpr( SourceRange loc, const std::shared_ptr& forked, - at::ArrayRef inputs, - at::ArrayRef attributes) { + at::ArrayRef args, + at::ArrayRef kwargs) { auto g = method.graph(); Node* fork_node; TypePtr out_type; @@ -2980,8 +3127,7 @@ struct to_ir { fork_node->addInput(closure_output); } else { auto emit_closure_body = [&](Block* closure_block) { - auto fn_sugared_output = - forked->call(loc, method, inputs, attributes, 1); + auto fn_sugared_output = forked->call(loc, method, args, kwargs, 1); auto fn_simple_output = fn_sugared_output->asValue(loc, method); closure_block->registerOutput(fn_simple_output); out_type = fn_simple_output->type(); @@ -3000,8 +3146,8 @@ struct to_ir { // through RPC in TorchScript, // Ideally, function value in JIT IR is first-class citizen and // The RPC C++ entry API can take c10::Function directly. - auto rpcMinInputs = 2; - auto rpcMaxInputs = 5; // NOLINT + size_t rpcMinInputs = 2; + size_t rpcMaxInputs = 5; // NOLINT std::string op_name = rpc_op.toUnqualString(); if (apply.inputs().size() < rpcMinInputs || apply.inputs().size() > rpcMaxInputs) { @@ -3123,6 +3269,22 @@ struct to_ir { return std::make_shared(rpc_node_output); } + Value* emitBinaryOp(const TreeRef& tree) { + const auto& inputs = tree->trees(); + auto kind = getNodeKind(tree->kind(), inputs.size()); + auto overload = getOperatorOverload(tree->kind(), inputs.size()); + auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false); + if (tree->kind() == TK_IN) { + // For `in` the arguments are in reverse order (the object being + // checked is second) + std::iter_swap(named_values.begin() + 0, named_values.begin() + 1); + } + return asSimple( + makeMagic( + overload, std::make_shared(kind, at::nullopt)) + ->call(tree->range(), method, named_values, {}, 0)); + } + Value* emitSimpleExpr( const TreeRef& tree, const TypePtr& type_hint = nullptr) { @@ -3135,6 +3297,21 @@ struct to_ir { return emitBuiltinCall( tree->range(), *method.graph(), kind, named_values, {}); } + case '%': { + auto lhs = emitSugaredExpr(Expr(tree->tree(0)), 0) + ->asValue(tree->tree(0)->range(), method); + auto const& lhs_type = lhs->type(); + if (lhs_type == StringType::get()) { + auto values = getValues(tree->trees(), /*maybe_unpack=*/false); + auto node = graph->create(aten::percentFormat, values, 1) + ->setSourceRange(tree->range()); + Value* output = graph->insertNode(node)->output(); + output->setType(StringType::get()); + return output; + } else { + return emitBinaryOp(tree); + } + } case TK_IN: case TK_POW: case TK_NE: @@ -3147,28 +3324,12 @@ struct to_ir { case '/': case '+': case '-': - case '%': case '&': case '|': case '^': case TK_LSHIFT: - case TK_RSHIFT: { - const auto& inputs = tree->trees(); - auto kind = getNodeKind(tree->kind(), inputs.size()); - auto overload = getOperatorOverload(tree->kind(), inputs.size()); - auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false); - - if (tree->kind() == TK_IN) { - // For `in` the arguments are in reverse order (the object being - // checked is second) - std::iter_swap(named_values.begin() + 0, named_values.begin() + 1); - } - - return asSimple( - makeMagic( - overload, std::make_shared(kind, at::nullopt)) - ->call(tree->range(), method, named_values, {}, 0)); - } + case TK_RSHIFT: + return emitBinaryOp(tree); case TK_IS: case TK_ISNOT: case TK_AND: @@ -3199,7 +3360,7 @@ struct to_ir { return graph->insertConstant(IValue(), tree->range()); } break; case TK_IF_EXPR: { - return emitTernaryIf(TernaryIf(tree)); + return emitTernaryIf(TernaryIf(tree), type_hint); } break; case TK_STRINGLITERAL: { return emitStringLiteral(StringLiteral(tree)); @@ -3215,7 +3376,7 @@ struct to_ir { TypePtr elem_type = TensorType::get(); if (type_hint) { if (type_hint->kind() == TypeKind::ListType) { - elem_type = type_hint->expect()->getElementType(); + elem_type = type_hint->expectRef().getElementType(); } else { // If the type hint was not a List[T] throw an error throw ErrorReport(tree) @@ -3307,6 +3468,10 @@ struct to_ir { auto lc = ListComp(tree); return emitListComprehension(lc, type_hint); } break; + case TK_DICT_COMP: { + auto dc = DictComp(tree); + return emitDictComprehension(dc, type_hint); + } break; default: throw ErrorReport(tree) << "Cannot emit expr for: " << tree; } @@ -3354,7 +3519,25 @@ struct to_ir { } else { AT_ASSERT(!sliceable->type()->isSubtypeOf(TensorType::get())); } + // TODO for now let's deal with TupleType first. Ideally all list, tensor, + // string, and tuple slicing should be same (tugsbayasgalan) + if (sliceable->type()->cast()) { + std::vector> tuple_args; + // since we are only dealing with tuple slicing for now, we try to keep + // tuple args seperate for now + tuple_args.reserve(3); + + start ? tuple_args.emplace_back(start) + : tuple_args.emplace_back(c10::nullopt); + end ? tuple_args.emplace_back(end) + : tuple_args.emplace_back(c10::nullopt); + step ? tuple_args.emplace_back(step) + : tuple_args.emplace_back(c10::nullopt); + return emitTupleSlice(loc, args[0], tuple_args); + } + + // TODO this needs to be cleaned for list slicing // Default value for start is 0. if (!start) { start = graph->insertConstant(0, loc); @@ -3364,19 +3547,6 @@ struct to_ir { if (end) { args.emplace_back(loc, "end", end); } - if (sliceable->type()->cast()) { - if (step) { - // TODO: add support for slicing tuples with a step - throw ErrorReport(loc) - << "Unsupported operation: slicing tuples with a step isn't supported"; - } - - if (end) { - return emitTupleSlice(loc, args[0], args[1], /*end*/ args[2]); - } else { - return emitTupleSlice(loc, args[0], args[1], c10::nullopt); - } - } if (!step) { step = graph->insertConstant(1, loc); @@ -3739,32 +3909,43 @@ struct to_ir { Value* emitTupleSlice( const SourceRange& loc, const NamedValue& tuple_val, - const NamedValue& beg_val, - const at::optional& end_val) { + const std::vector>& tuple_args) { auto tuple_type = tuple_val.value(*graph)->type()->expect(); - int64_t beg = getAdjTupleIndex( - loc, - tuple_type, - getSliceInd(beg_val.value(*graph), loc), - /*allow_out_of_bounds*/ true); - int64_t end; int64_t tuple_len = tuple_type->elements().size(); + auto beg_val = tuple_args[0]; + auto end_val = tuple_args[1]; + auto step = tuple_args[2]; + + int64_t step_size = 1; + if (step) { + auto val = toIValue(step->value(*graph)); + TORCH_CHECK(val->isInt(), "Step size should always be an integer"); + step_size = val->to(); + } + + int64_t beg = std::numeric_limits::max(); + if (beg_val) { + beg = getAdjTupleIndex( + loc, tuple_type, getSliceInd(beg_val->value(*graph), loc), true); + } + + int64_t end = std::numeric_limits::max(); if (end_val) { end = getAdjTupleIndex( loc, tuple_type, getSliceInd(end_val->value(*graph), loc), true); - } else { - end = tuple_len; } - // slicing does not throw out of bounds errors - end = std::min(std::max((int64_t)0, end), tuple_len); - beg = std::min(std::max((int64_t)0, beg), tuple_len); + + int64_t num_values = slice_indices_adjust(tuple_len, &beg, &end, step_size); return graph - ->insertNode(graph->createTupleSlice(tuple_val.value(*graph), beg, end)) + ->insertNode(graph->createTupleSlice( + tuple_val.value(*graph), beg, step_size, num_values)) ->output(); } - std::shared_ptr emitSubscript(const Subscript& subscript) { + std::shared_ptr emitSubscript( + const Subscript& subscript, + TypePtr type_hint = nullptr) { const SugaredValuePtr sv = emitSugaredExpr(subscript.value(), 1); const List& subscript_exprs = subscript.subscript_exprs(); const SourceRange& range = subscript.range(); @@ -3782,19 +3963,25 @@ struct to_ir { auto s_tuple_val = sv->asTupleValue(val_range, method)->asValue(val_range, method); const SliceExpr& slice = SliceExpr(subscript_exprs[0]); + std::vector> tuple_args; + tuple_args.reserve(3); auto begin = NamedValue(val_range, "begin", emitExpr(Expr(slice.startOr(0)))); + tuple_args.emplace_back(begin); if (slice.end().present()) { auto end = NamedValue(val_range, "end", emitExpr(Expr(slice.end().get()))); - auto tupleSliceValue = - emitTupleSlice(val_range, s_tuple_val, begin, end); - return std::make_shared(tupleSliceValue); + tuple_args.emplace_back(end); + } else { - auto tupleSliceValue = - emitTupleSlice(val_range, s_tuple_val, begin, c10::nullopt); - return std::make_shared(tupleSliceValue); + tuple_args.emplace_back(c10::nullopt); } + // pushing step_size to match the tuple_args + tuple_args.emplace_back(c10::nullopt); + + auto tupleSliceValue = + emitTupleSlice(val_range, s_tuple_val, tuple_args); + return std::make_shared(tupleSliceValue); } else { return std::make_shared(emitBasicSlice( range, sv->asValue(val_range, method), subscript_exprs)); @@ -3834,7 +4021,7 @@ struct to_ir { return std::make_shared( emitMultidimSlicing(range, sliceable, subscript_exprs)); } else { - return sv->getitem(range, method, idx); + return sv->getitem(range, method, idx, std::move(type_hint)); } } } @@ -4090,6 +4277,10 @@ void runCleanupPasses(std::shared_ptr& to_clean) { // For jitter CanonicalizeOutputs(to_clean); + + // Annotate aten::warns so that each has its unique ID. This enables us to + // mimic Python behavior of only emitting each warning only once. + AnnotateWarns(to_clean); } // we consider _N where N is a number, to be a non-meaningful name @@ -4136,10 +4327,28 @@ void CompilationUnit::define_interface( arguments.insert( arguments.end(), schema.arguments().begin(), schema.arguments().end()); iface->addMethod(schema.cloneWithArguments(std::move(arguments))); - if (method_def.statements().size() != 1 || - method_def.statements()[0].kind() != TK_PASS) { + // we need to make sure everything but the last element is just string + // literals (aka comments) unless there is "pass" in between + auto stmts_size = method_def.statements().size(); + for (size_t i = 0; i < stmts_size - 1; i++) { + auto cur_statement = method_def.statements()[i]; + if (cur_statement.kind() == TK_EXPR_STMT) { + auto expr = ExprStmt(cur_statement).expr(); + if (expr.kind() != TK_STRINGLITERAL) { + throw ErrorReport(method_def.range()) + << "interfaces declarations should only contain a single 'pass' statement."; + } + } + // if we see a "pass", we just stop there + if (cur_statement.kind() == TK_PASS) { + this->register_type(iface); + return; + } + } + + if (method_def.statements()[stmts_size - 1].kind() != TK_PASS) { throw ErrorReport(method_def.range()) - << "interfaces declarations should only contain a single 'pass' statement."; + << "interfaces declarations should contain 'pass' statement."; } } this->register_type(iface); diff --git a/torch/csrc/jit/frontend/lexer.h b/torch/csrc/jit/frontend/lexer.h index 3a83d8b9a87fd..ad897ab20a9c4 100644 --- a/torch/csrc/jit/frontend/lexer.h +++ b/torch/csrc/jit/frontend/lexer.h @@ -80,6 +80,12 @@ namespace jit { _(TK_TIMES_EQ, "*=", "*=") \ _(TK_DIV_EQ, "/=", "/=") \ _(TK_MOD_EQ, "%=", "%=") \ + _(TK_BIT_OR_EQ, "|=", "|=") \ + _(TK_BIT_AND_EQ, "&=", "&=") \ + _(TK_BIT_XOR_EQ, "^=", "^=") \ + _(TK_LSHIFT_EQ, "<<=", "<<=") \ + _(TK_RSHIFT_EQ, ">>=", ">>=") \ + _(TK_POW_EQ, "**=", "**=") \ _(TK_GLOBAL, "global", "global") \ _(TK_BUILT_IN, "built-in", "") \ _(TK_SUBSCRIPT, "subscript", "") \ @@ -102,6 +108,7 @@ namespace jit { _(TK_ASSERT, "assert", "assert") \ _(TK_DOTS, "dots", "...") \ _(TK_LIST_COMP, "list comprehension", "") \ + _(TK_DICT_COMP, "dict comprehension", "") \ _(TK_BREAK, "break", "break") \ _(TK_CONTINUE, "continue", "continue") \ _(TK_DELETE, "del", "del") \ @@ -124,8 +131,8 @@ enum TokenKind { #undef DEFINE_TOKEN }; -CAFFE2_API std::string kindToString(int kind); -CAFFE2_API int stringToKind(const std::string& str); +TORCH_API std::string kindToString(int kind); +TORCH_API int stringToKind(const std::string& str); // nested hash tables that indicate char-by-char what is a valid token. struct TokenTrie; @@ -158,7 +165,7 @@ struct TokenTrie { // stuff that is shared against all TC lexers/parsers and is initialized only // once. -struct CAFFE2_API SharedParserData { +struct TORCH_API SharedParserData { SharedParserData() : head(new TokenTrie()) { std::stringstream ss; for (const char* c = valid_single_char_tokens; *c; c++) { @@ -362,7 +369,7 @@ struct CAFFE2_API SharedParserData { TokenTrieRef head; }; -CAFFE2_API SharedParserData& sharedParserData(); +TORCH_API SharedParserData& sharedParserData(); struct Token { int kind; @@ -377,8 +384,8 @@ struct Token { }; struct Lexer { - explicit Lexer(const std::shared_ptr& source) - : source(source), + explicit Lexer(std::shared_ptr source) + : source(std::move(source)), pos(0), nesting(0), indent_stack(), diff --git a/torch/csrc/jit/frontend/parse_string_literal.h b/torch/csrc/jit/frontend/parse_string_literal.h index 4b00647741986..ab6c6607bb4a6 100644 --- a/torch/csrc/jit/frontend/parse_string_literal.h +++ b/torch/csrc/jit/frontend/parse_string_literal.h @@ -12,10 +12,6 @@ inline bool isCharCount(char c, const std::string& str, size_t start, int len) { std::count(str.begin() + start, str.begin() + start + len, c) == len; } -inline static bool isOctal(char c) { - return c >= '0' && c < '8'; -} - inline c10::optional parseOctal(const std::string& str, size_t pos) { //\xxx where x are 0-7 if (pos + 3 >= str.size()) diff --git a/torch/csrc/jit/frontend/parser.cpp b/torch/csrc/jit/frontend/parser.cpp index c9f4aac038ccd..87ae08a041eee 100644 --- a/torch/csrc/jit/frontend/parser.cpp +++ b/torch/csrc/jit/frontend/parser.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -74,6 +75,12 @@ struct ParserImpl { case TK_TIMES_EQ: case TK_DIV_EQ: case TK_MOD_EQ: + case TK_BIT_OR_EQ: + case TK_BIT_AND_EQ: + case TK_BIT_XOR_EQ: + case TK_LSHIFT_EQ: + case TK_RSHIFT_EQ: + case TK_POW_EQ: case TK_NEWLINE: case '=': case ')': @@ -144,6 +151,16 @@ struct ParserImpl { } break; case '{': { L.next(); + // If we have a dict literal, `keys` and `values` will store the keys + // and values used in the object's construction. EDGE CASE: We have a + // dict comprehension, so we'll get the first element of the dict + // comprehension in `keys` and a list comprehension in `values`. + // For example, `{i : chr(i + 65) for i in range(4)}` would give us + // `i` in `keys` and `chr(i + 65) for i in range(4)` in `values`. + // The optimal way of handling this case is to simply splice the new + // dict comprehension together from the existing list comprehension. + // Splicing prevents breaking changes to our API and does not require + // the use of global variables. std::vector keys; std::vector values; auto range = L.cur().range; @@ -155,22 +172,25 @@ struct ParserImpl { } while (L.nextIf(',')); } L.expect('}'); - prefix = DictLiteral::create( - range, - List::create(range, keys), - List::create(range, values)); + if (keys.size() == 1 && (*values.begin()).kind() == TK_LIST_COMP) { + ListComp lc(*values.begin()); + prefix = DictComp::create( + range, *keys.begin(), lc.elt(), lc.target(), lc.iter()); + } else { + prefix = DictLiteral::create( + range, + List::create(range, keys), + List::create(range, values)); + } } break; case TK_STRINGLITERAL: { prefix = parseConcatenatedStringLiterals(); } break; + case TK_ELLIPSIS: case TK_DOTS: { prefix = Dots::create(L.cur().range); L.next(); } break; - case TK_ELLIPSIS: { - prefix = Dots::create(L.cur().range); - L.next(); - } break; default: { Ident name = parseIdent(); prefix = Var::create(name.range(), name); @@ -197,10 +217,25 @@ struct ParserImpl { case TK_MINUS_EQ: case TK_TIMES_EQ: case TK_DIV_EQ: + case TK_BIT_OR_EQ: + case TK_BIT_AND_EQ: + case TK_BIT_XOR_EQ: case TK_MOD_EQ: { int modifier = L.next().text()[0]; return create_compound(modifier, r, {}); } break; + case TK_LSHIFT_EQ: { + L.next(); + return create_compound(TK_LSHIFT, r, {}); + } break; + case TK_RSHIFT_EQ: { + L.next(); + return create_compound(TK_RSHIFT, r, {}); + } break; + case TK_POW_EQ: { + L.next(); + return create_compound(TK_POW, r, {}); + } break; case '=': { L.next(); return create_compound('=', r, {}); // no reduction @@ -233,8 +268,9 @@ struct ParserImpl { auto kind = L.cur().kind; auto pos = L.cur().range; L.next(); - auto unary_kind = - kind == '*' ? TK_STARRED : kind == '-' ? TK_UNARY_MINUS : kind; + auto unary_kind = kind == '*' ? TK_STARRED + : kind == '-' ? TK_UNARY_MINUS + : kind; auto subexp = parseExp(unary_prec); // fold '-' into constant numbers, so that attributes can accept // things like -1 @@ -561,10 +597,11 @@ struct ParserImpl { return parseFunction(/*is_method=*/in_class); } case TK_DELETE: { - L.expect(TK_DELETE); - auto expr = parseExp(); + auto range = L.next().range; + auto targets = + parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseExp); L.expect(TK_NEWLINE); - return Delete::create(expr); + return Delete::create(range, targets); } case TK_WITH: { return parseWith(); diff --git a/torch/csrc/jit/frontend/resolver.h b/torch/csrc/jit/frontend/resolver.h index f4c0938023a6e..898c37839c240 100644 --- a/torch/csrc/jit/frontend/resolver.h +++ b/torch/csrc/jit/frontend/resolver.h @@ -25,7 +25,7 @@ using ResolverPtr = std::shared_ptr; * handle the method. */ struct Resolver { - virtual ~Resolver() {} + virtual ~Resolver() = default; // Resolve a given name to a SugaredValue. This takes the method `m` that the // caller is currently constructing, since we may need to insert nodes into diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index fb2e0f20f380f..977cc3daaa40d 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -71,7 +72,7 @@ Value* tryConvertToType( if (convertibleToList(value->type(), unwrapOptional(concrete_type))) { auto unpacked = createTupleUnpack(value); auto elem_type = - unwrapOptional(concrete_type)->expect()->getElementType(); + unwrapOptional(concrete_type)->expectRef().getElementType(); value = graph.insertNode(graph.createList(elem_type, unpacked))->output(); } @@ -284,6 +285,19 @@ static bool varargsCanBeUsedAsList( !typevar_list; } +// Note (@zasdfgbnm): +// This is a workaround for https://github.com/pytorch/pytorch/issues/47964 +// Currently JIT does not distinguish ScalarType vs int, so there is really +// no way to distinguish x.view(1) vs x.view(torch.int8). So we have to hardcode +// the aten::view.dtype here to block this overload. This blocklist should be +// removed when JIT fully suports ScalarType as its own type. +bool isBlockListedSchema(const FunctionSchema& schema) { + if (schema.name() == "aten::view" && schema.overload_name() == "dtype") { + return true; + } + return false; +} + static c10::optional tryMatchSchema( const FunctionSchema& schema, const SourceRange& loc, @@ -293,6 +307,10 @@ static c10::optional tryMatchSchema( c10::optional self, std::ostream* failure_messages, bool allow_conversions) { + if (isBlockListedSchema(schema)) { + return c10::nullopt; + } + auto err = [&]() -> std::ostream& { *failure_messages << "\n" << schema << ":\n"; return *failure_messages; @@ -322,8 +340,9 @@ static c10::optional tryMatchSchema( // The actual cannot already be a list if (actual_type->kind() != TypeKind::ListType && !convertibleToList(actual_type, unwrapOptional(arg.type()))) { - auto formal_type = - unwrapOptional(arg.type())->expect()->getElementType(); + auto formal_type = unwrapOptional(arg.type()) + ->expectRef() + .getElementType(); Value* list = tryCreateList( formal_type, @@ -437,9 +456,10 @@ static c10::optional tryMatchSchema( return_field_names = fmap(returns, [&](const Argument& r) { return r.name(); }); } - return MatchedSchema{std::move(positional_inputs), - std::move(return_types), - std::move(return_field_names)}; + return MatchedSchema{ + std::move(positional_inputs), + std::move(return_types), + std::move(return_field_names)}; } MatchedSchema matchSchema( @@ -519,7 +539,7 @@ std::pair matchSchemas( render_errors ? &failure_messages : nullptr, allow_conversions); if (matched_schema) { - return std::make_pair(i, std::move(*matched_schema)); + return std::make_pair(i, *matched_schema); } } } @@ -551,8 +571,8 @@ static Value* packOutputs( TupleTypePtr named_tuple = nullptr; if (field_names) { auto types = fmap(values, [](Value* v) { return v->type(); }); - named_tuple = TupleType::createNamed( - c10::nullopt, field_names.value(), std::move(types)); + named_tuple = + TupleType::createNamed(c10::nullopt, field_names.value(), types); } return g.insertNode(g.createTuple(values, named_tuple))->output(); } @@ -584,14 +604,15 @@ Value* emitBuiltinCall( const SourceRange& loc, Graph& graph, Symbol name, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, const c10::optional& self) { const auto& variants = getAllOperatorsFor(name); const auto& builtin_functions = getAllBuiltinFunctionsFor(name); std::stringstream failure_messages; std::vector schemas; + schemas.reserve(variants.size()); for (const std::shared_ptr& op : variants) { schemas.push_back(&op->schema()); } @@ -620,7 +641,7 @@ Value* emitBuiltinCall( throw error; } - auto matched = matchSchemas(schemas, loc, graph, inputs, attributes, self); + auto matched = matchSchemas(schemas, loc, graph, args, kwargs, self); if (matched.first < variants.size()) { return emitBuiltinNode(matched.second, loc, graph, name); diff --git a/torch/csrc/jit/frontend/schema_matching.h b/torch/csrc/jit/frontend/schema_matching.h index 88fe23a9682de..83e34bb33ae55 100644 --- a/torch/csrc/jit/frontend/schema_matching.h +++ b/torch/csrc/jit/frontend/schema_matching.h @@ -23,7 +23,7 @@ TORCH_API MatchedSchema matchSchema( const SourceRange& loc, Graph& graph, at::ArrayRef args, - at::ArrayRef kwarg, + at::ArrayRef kwargs, const c10::optional& self = c10::nullopt); TORCH_API std::pair matchSchemas( @@ -31,7 +31,7 @@ TORCH_API std::pair matchSchemas( const SourceRange& loc, Graph& graph, at::ArrayRef args, - at::ArrayRef kwarg, + at::ArrayRef kwargs, const c10::optional& self = c10::nullopt, bool render_errors = false); @@ -43,8 +43,8 @@ TORCH_API Value* emitBuiltinCall( const SourceRange& loc, Graph& graph, Symbol name, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, const c10::optional& self = c10::nullopt); TORCH_API c10::optional findInputWithName( diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp index a089abf7fb2c6..1785a2486b42d 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.cpp +++ b/torch/csrc/jit/frontend/schema_type_parser.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -24,6 +25,8 @@ using c10::OptionalType; using c10::QSchemeType; using c10::QuantizerType; using c10::RRefType; +using c10::StorageType; +using c10::StreamObjType; using c10::StringType; using c10::Symbol; using c10::TensorType; @@ -40,7 +43,7 @@ TypePtr SchemaTypeParser::parseBaseType() { {"ScalarType", IntType::get()}, {"Layout", IntType::get()}, {"MemoryFormat", IntType::get()}, - {"Storage", IntType::get()}, + {"Storage", StorageType::get()}, {"QScheme", QSchemeType::get()}, {"Quantizer", QuantizerType::get()}, {"ConstQuantizerPtr", @@ -48,6 +51,7 @@ TypePtr SchemaTypeParser::parseBaseType() { // parser, it should use the custom class mechanism // instead. @jerryzh {"Device", DeviceObjType::get()}, + {"Stream", StreamObjType::get()}, {"Scalar", NumberType::get()}, {"str", StringType::get()}, {"float", FloatType::get()}, @@ -148,7 +152,6 @@ c10::optional SchemaTypeParser::parseTensorDType( } c10::optional SchemaTypeParser::tryToParseDeviceType() { - c10::optional device; L.expect('='); const std::string& dev = L.expect(TK_IDENT).text(); @@ -190,7 +193,7 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { // unknown sizes, a mix of ranks with known and unknown sizes, or ranks with // known sizes and strides. The type might also have requires_grad and/or // device option. Examples of types we're handling here: - // Long(10:48,8:6,6:1, requires_grad=0, device=cuda:1) + // Long(10, 8, 6, strides=[48, 6, 1], requires_grad=0, device=cuda:1) // Float(10, *, 20, device=cuda:1) // Float(requires_grad=1) std::vector> dims; @@ -220,6 +223,17 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { } return; } + if (field == "strides") { + seen_strides = true; + L.expect('='); + parseList('[', ',', ']', [&] { + const std::string& num = L.expect(TK_NUMBER).text(); + std::string::size_type num_len; + size_t stride = c10::stoi(num, &num_len); + strides.push_back(stride); + }); + return; + } throw ErrorReport(L.cur()) << "Unexpected specifier '" << field << "'"; } if (device.has_value() || requires_grad.has_value()) { @@ -241,14 +255,6 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { std::string::size_type num_len; size_t dim = c10::stoi(num, &num_len); dims.emplace_back(dim); - if (seen_strides || L.cur().kind == ':') { - L.expect(':'); - seen_strides = true; - const std::string& num = L.expect(TK_NUMBER).text(); - std::string::size_type num_len; - size_t stride = c10::stoi(num, &num_len); - strides.push_back(stride); - } }); if (seen_strides) { at::IntArrayRef strides_ref(strides); diff --git a/torch/csrc/jit/frontend/schema_type_parser.h b/torch/csrc/jit/frontend/schema_type_parser.h index fe6089d505396..17782473bd65b 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.h +++ b/torch/csrc/jit/frontend/schema_type_parser.h @@ -10,7 +10,7 @@ namespace jit { using TypePtr = c10::TypePtr; -struct CAFFE2_API SchemaTypeParser { +struct TORCH_API SchemaTypeParser { TypePtr parseBaseType(); c10::optional parseAliasAnnotation(); std::pair> parseType(); diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp index bec9e879a3978..64837e68e8810 100644 --- a/torch/csrc/jit/frontend/script_type_parser.cpp +++ b/torch/csrc/jit/frontend/script_type_parser.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -86,6 +87,22 @@ TypePtr ScriptTypeParser::subscriptToType( c10::optional> ScriptTypeParser::parseBroadcastList( const Expr& expr) const { + // Alias torch.nn._common_types._size_?_t to BroadcastingList?[int] + if (expr.kind() == TK_VAR) { + auto var = Var(expr); + auto& name = var.name().name(); + constexpr auto _size_prefix = "_size_"; + constexpr auto _size_suffix = "_t"; + constexpr auto _size_n_len = 9; // strlen("_size_X_t") + constexpr auto _size_prefix_len = 6; // strlen("_size_"); + if (name.find(_size_prefix) == 0 && name.length() == _size_n_len && + name.find(_size_suffix) == _size_prefix_len + 1 && + ::isdigit(name[_size_prefix_len])) { + int n = name[_size_prefix_len] - '0'; + return std::pair(ListType::create(IntType::get()), n); + } + } + if (expr.kind() != TK_SUBSCRIPT) return c10::nullopt; auto subscript = Subscript(expr); @@ -149,9 +166,6 @@ c10::optional ScriptTypeParser::parseBaseTypeName( case TK_NONE: { return "None"; } - case TK_STRINGLITERAL: { - return StringLiteral(expr).text(); - } case '.': { auto select = Select(expr); const std::string& name = select.selector().name(); @@ -190,6 +204,22 @@ TypePtr ScriptTypeParser::parseTypeFromExprImpl(const Expr& expr) const { } return subscriptToType(*value_name, subscript); + } else if (expr.kind() == TK_STRINGLITERAL) { + const auto& type_name = StringLiteral(expr).text(); + if (resolver_) { + if (auto typePtr = resolver_->resolveType(type_name, expr.range())) { + return typePtr; + } + } + + // Check if the type is a custom class. This is done by checking + // if type_name starts with "torch.classes." + if (type_name.find("torch.classes.") == 0) { + auto custom_class_type = getCustomClass("__torch__." + type_name); + return custom_class_type; + } + + throw ErrorReport(expr) << "Unknown type name '" << type_name << "'"; } else if (auto name = parseBaseTypeName(expr)) { auto itr = string_to_type_lut().find(*name); if (itr != string_to_type_lut().end()) { diff --git a/torch/csrc/jit/frontend/script_type_parser.h b/torch/csrc/jit/frontend/script_type_parser.h index 490054c881129..18758ff014c83 100644 --- a/torch/csrc/jit/frontend/script_type_parser.h +++ b/torch/csrc/jit/frontend/script_type_parser.h @@ -15,7 +15,7 @@ namespace jit { */ class TORCH_API ScriptTypeParser { public: - explicit ScriptTypeParser() {} + explicit ScriptTypeParser() = default; explicit ScriptTypeParser(ResolverPtr resolver) : resolver_(std::move(resolver)) {} diff --git a/torch/csrc/jit/frontend/source_range.h b/torch/csrc/jit/frontend/source_range.h index efa297c0440a2..36772807ca8b6 100644 --- a/torch/csrc/jit/frontend/source_range.h +++ b/torch/csrc/jit/frontend/source_range.h @@ -14,7 +14,7 @@ struct SourceRange; // Source represents a code segment. It keeps track of: // - text : the text of the code segment // - filename (optional) : if present, represents the name of the file from -// which the code semgemnt originated. +// which the code segment originated. // - starting_line_no : represents the line in the original file where the // code segment started. struct Source { @@ -106,7 +106,7 @@ struct Source { // A SourceRange is a view into a Source, that points to a subset of the source, // specified by `start` and `end` byte offsets into the source text. -struct CAFFE2_API SourceRange { +struct TORCH_API SourceRange { SourceRange(std::shared_ptr source_, size_t start_, size_t end_) : source_(std::move(source_)), start_(start_), end_(end_) {} SourceRange() : source_(nullptr), start_(0), end_(0) {} diff --git a/torch/csrc/jit/frontend/string_to_type.cpp b/torch/csrc/jit/frontend/string_to_type.cpp index 2674011abc312..3aaf0aa192bd0 100644 --- a/torch/csrc/jit/frontend/string_to_type.cpp +++ b/torch/csrc/jit/frontend/string_to_type.cpp @@ -11,6 +11,7 @@ const std::unordered_map& string_to_type_lut() { {"bool", BoolType::get()}, {"str", StringType::get()}, {"Device", DeviceObjType::get()}, + {"Stream", StreamObjType::get()}, // technically this is not a python type but we need it when // parsing serialized methods that use implicit conversions to Scalar {"number", NumberType::get()}, diff --git a/torch/csrc/jit/frontend/strtod.h b/torch/csrc/jit/frontend/strtod.h index f257a36132244..c333ed045a1e6 100644 --- a/torch/csrc/jit/frontend/strtod.h +++ b/torch/csrc/jit/frontend/strtod.h @@ -5,8 +5,8 @@ namespace torch { namespace jit { -CAFFE2_API double strtod_c(const char* nptr, char** endptr); -CAFFE2_API float strtof_c(const char* nptr, char** endptr); +TORCH_API double strtod_c(const char* nptr, char** endptr); +TORCH_API float strtof_c(const char* nptr, char** endptr); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/frontend/sugared_value.cpp b/torch/csrc/jit/frontend/sugared_value.cpp index f4aed768fbf26..6cd4b7181cd16 100644 --- a/torch/csrc/jit/frontend/sugared_value.cpp +++ b/torch/csrc/jit/frontend/sugared_value.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -17,14 +18,14 @@ struct NoneValue : SugaredValue { std::shared_ptr PrintValue::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { auto& g = *m.graph(); - if (!attributes.empty()) + if (!kwargs.empty()) throw ErrorReport(loc) << "print doesn't accept any keyword arguments"; - std::vector lowered_inputs = toValues(*m.graph(), inputs); + std::vector lowered_inputs = toValues(*m.graph(), args); g.insertNode(g.create(prim::Print, lowered_inputs, 0)->setSourceRange(loc)); return std::make_shared(); } @@ -46,11 +47,11 @@ builtin_cast_method_to_scalar_type() { std::shared_ptr BuiltinFunction::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { return std::make_shared( - emitBuiltinCall(loc, *m.graph(), symbol, inputs, attributes, self)); + emitBuiltinCall(loc, *m.graph(), symbol, args, kwargs, self)); } // older versions of gcc/clang have a bug where enums can't be used as keys @@ -109,6 +110,7 @@ std::shared_ptr SimpleValue::attr( {"is_sparse", "prim"}, {"is_mkldnn", "prim"}, {"is_quantized", "prim"}, + {"is_vulkan", "prim"}, {"is_meta", "prim"}, {"is_leaf", "aten"}, {"requires_grad", "prim"}, @@ -321,14 +323,14 @@ void SimpleValue::setAttr( std::shared_ptr SimpleValue::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { // allow our 'fake' closures to be called, used for fork serialization // at the moment, but can be expanded later Node* self = getValue()->node(); if (self->kind() == prim::TupleConstruct && self->inputs().size() == 2 && - self->inputs().at(0)->node()->kind() == prim::Function) { + self->inputs().at(0)->node()->kind() == prim::Closure) { std::shared_ptr graph = self->inputs().at(0)->node()->g(attr::Subgraph); Value* context = self->inputs().at(1); @@ -347,16 +349,15 @@ std::shared_ptr SimpleValue::call( auto ret = StrongFunctionPtr(std::move(cu), fn); std::vector ctx_inputs = {close_context}; - ctx_inputs.insert(ctx_inputs.end(), inputs.begin(), inputs.end()); - return FunctionValue(ret).call(loc, m, ctx_inputs, attributes, n_binders); + ctx_inputs.insert(ctx_inputs.end(), args.begin(), args.end()); + return FunctionValue(ret).call(loc, m, ctx_inputs, kwargs, n_binders); } if (auto class_type = getValue()->type()->cast()) { - return attr(loc, m, "__call__") - ->call(loc, m, inputs, attributes, n_binders); + return attr(loc, m, "__call__")->call(loc, m, args, kwargs, n_binders); } - return SugaredValue::call(loc, m, inputs, attributes, n_binders); + return SugaredValue::call(loc, m, args, kwargs, n_binders); } Value* SimpleValue::len(const SourceRange& loc, Function& m) { @@ -376,7 +377,8 @@ Value* SimpleValue::len(const SourceRange& loc, Function& m) { SugaredValuePtr SimpleValue::getitem( const SourceRange& loc, Function& m, - Value* idx) { + Value* idx, + TypePtr type_hint) { Value* val = getValue(); TypePtr val_type = val->type(); Graph& g = *m.graph(); @@ -392,6 +394,17 @@ SugaredValuePtr SimpleValue::getitem( return std::make_shared( g.insert(aten::select, {val, 0, idx}, {}, loc)); } else if (auto class_type = val_type->cast()) { + // Check if this is an indexing operation enabled by a type hint. + // The ModuleDict has already been checked during IR generation to make + // sure its contents implement the module interface referred to by + // type_hint. + if (class_type->is_module() && type_hint) { + auto res = g.insert(prim::ModuleDictIndex, {val, idx}, {}, loc); + res->setType(type_hint); + return std::make_shared(res); + } + + // Defer to the __getitem__ attr on the class. return attr(loc, m, "__getitem__")->call(loc, m, {idx}, {}, 1); } else { throw ErrorReport(loc) << "'" << val_type->repr_str() << "'" @@ -484,7 +497,8 @@ Value* RangeValue::len(const SourceRange& loc, Function& m) { SugaredValuePtr RangeValue::getitem( const SourceRange& loc, Function& m, - Value* idx) { + Value* idx, + TypePtr type_hint) { if (has_only_end_) { return std::make_shared(idx); } else { @@ -534,7 +548,8 @@ Value* IterableTree::len(const SourceRange& loc, Function& m) { SugaredValuePtr IterableTree::getitem( const SourceRange& loc, Function& m, - Value* idx) { + Value* idx, + TypePtr type_hint) { std::vector child_items; for (const SugaredValuePtr& child : children_) { child_items.emplace_back(child->getitem(loc, m, idx)); @@ -545,7 +560,7 @@ SugaredValuePtr IterableTree::getitem( void IterableTree::addChild( const SourceRange& range, Function& m, - const SugaredValuePtr iter_value) { + const SugaredValuePtr& iter_value) { c10::optional child_len = iter_value->staticLen(); if (children_.size() == 0) { unroll_length_ = child_len; @@ -568,27 +583,27 @@ void IterableTree::addChild( std::shared_ptr MagicMethod::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { - if (inputs.size() > 0) { - Value* self = inputs[0].value(*m.graph()); + if (args.size() > 0) { + Value* self = args[0].value(*m.graph()); if (auto class_ptr = self->type()->cast()) { return SimpleValue(self) .attr(loc, m, desugared_name_) - ->call(loc, m, inputs.slice(1), attributes, n_binders); + ->call(loc, m, args.slice(1), kwargs, n_binders); } } TORCH_INTERNAL_ASSERT(base_value_); - return base_value_->call(loc, m, inputs, attributes, n_binders); + return base_value_->call(loc, m, args, kwargs, n_binders); } std::shared_ptr ClassValue::call( const SourceRange& loc, Function& m, // note: names for args will be 'argument 0', 'argument 1', etc.. - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { AT_ASSERT(n_binders <= 1); @@ -601,7 +616,7 @@ std::shared_ptr ClassValue::call( } // Call the init function - MethodValue(self, "__init__").call(loc, m, inputs, attributes, n_binders); + MethodValue(self, "__init__").call(loc, m, args, kwargs, n_binders); return std::make_shared(self); } @@ -620,15 +635,15 @@ std::shared_ptr ClassValue::attr( std::shared_ptr NamedTupleConstructor::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { auto& g = *m.graph(); auto schema = type_->schema(); TORCH_INTERNAL_ASSERT(schema); auto qualname = type_->name(); - auto matched_schema = matchSchema(*schema, loc, g, inputs, attributes); + auto matched_schema = matchSchema(*schema, loc, g, args, kwargs); auto self = g.insertNode( diff --git a/torch/csrc/jit/frontend/sugared_value.h b/torch/csrc/jit/frontend/sugared_value.h index 3523523f5c23b..85bf1d2020a04 100644 --- a/torch/csrc/jit/frontend/sugared_value.h +++ b/torch/csrc/jit/frontend/sugared_value.h @@ -2,6 +2,7 @@ #include #include #include +#include #include #include @@ -84,8 +85,8 @@ struct TORCH_API SugaredValue const SourceRange& loc, Function& m, // note: names for args will be 'argument 0', 'argument 1', etc.. - at::ArrayRef inputs_, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { // n_binders is always set to the number of variables an expression is // syntactically bound to: @@ -139,7 +140,8 @@ struct TORCH_API SugaredValue virtual std::shared_ptr getitem( const SourceRange& loc, Function& m, - Value* idx) { + Value* idx, + TypePtr type_hint = nullptr) { throw ErrorReport(loc) << "'" << kind() << "'" << " object is not subscriptable"; } @@ -181,8 +183,8 @@ struct TORCH_API SimpleValue : public SugaredValue { const SourceRange& loc, Function& m, // note: names for args will be 'argument 0', 'argument 1', etc.. - at::ArrayRef inputs_, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; std::shared_ptr iter(const SourceRange& loc, Function& m) @@ -193,8 +195,11 @@ struct TORCH_API SimpleValue : public SugaredValue { } Value* len(const SourceRange& loc, Function& m) override; - SugaredValuePtr getitem(const SourceRange& loc, Function& m, Value* idx) - override; + SugaredValuePtr getitem( + const SourceRange& loc, + Function& m, + Value* idx, + TypePtr type_hint = nullptr) override; private: Value* value_; @@ -215,8 +220,8 @@ struct TORCH_API BuiltinFunction : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef attributes, - at::ArrayRef inputs, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; // try to create this builtin but if it doesn't exist or the self argument @@ -229,7 +234,7 @@ struct TORCH_API BuiltinFunction : public SugaredValue { struct TORCH_API SugaredTupleValue : public SugaredValue { explicit SugaredTupleValue(std::vector> tup) - : tup_(tup){}; + : tup_(std::move(tup)){}; std::vector> asTuple( const SourceRange& loc, @@ -251,8 +256,11 @@ struct TORCH_API SugaredTupleValue : public SugaredValue { return "Tuple"; } - SugaredValuePtr getitem(const SourceRange& loc, Function& m, Value* idx) - override { + SugaredValuePtr getitem( + const SourceRange& loc, + Function& m, + Value* idx, + TypePtr type_hint = nullptr) override { if (!(idx->type()->cast() && toIValue(idx))) { throw ErrorReport(loc) << "Expected integer literal for index. " @@ -289,7 +297,7 @@ struct TORCH_API SugaredTupleValue : public SugaredValue { struct TORCH_API BuiltinModule : public SugaredValue { BuiltinModule(std::string name, c10::optional version = at::nullopt) - : name(std::move(name)), version(std::move(version)) {} + : name(std::move(name)), version(version) {} std::string kind() const override { return "builtin module"; @@ -332,8 +340,8 @@ struct TORCH_API ClassValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; std::shared_ptr attr( @@ -354,8 +362,8 @@ struct TORCH_API NamedTupleConstructor : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; std::string kind() const override { @@ -366,7 +374,7 @@ struct TORCH_API NamedTupleConstructor : public SugaredValue { }; struct FunctionValue : public SugaredValue { - FunctionValue(Function* callee) : callees_({std::move(callee)}) {} + FunctionValue(Function* callee) : callees_({callee}) {} FunctionValue(const StrongFunctionPtr& p) : callees_({p.function_}), cu_(p.cu_) {} FunctionValue(const std::vector& callees) { @@ -384,8 +392,8 @@ struct FunctionValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& f, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { std::vector schemas; for (Function* callee : callees_) { @@ -398,7 +406,7 @@ struct FunctionValue : public SugaredValue { } schemas.push_back(&callee->getSchema()); } - auto match = matchSchemas(schemas, loc, *f.graph(), inputs, attributes); + auto match = matchSchemas(schemas, loc, *f.graph(), args, kwargs); Value* output = f.graph()->insertFunctionCall(callees_[match.first], match.second); output->node()->setSourceRange(loc); @@ -417,7 +425,7 @@ struct FunctionValue : public SugaredValue { struct TORCH_API ClosureValue : public SugaredValue { ClosureValue(Value* value) : value_(value) { - TORCH_INTERNAL_ASSERT(value_->node()->kind() == prim::Function); + TORCH_INTERNAL_ASSERT(value_->node()->kind() == prim::Closure); } std::string kind() const override { return "closure"; @@ -431,9 +439,9 @@ struct TORCH_API ClosureValue : public SugaredValue { // defines how a method obtained from a module/class/interface behaves in script struct MethodValue : public SugaredValue { MethodValue(Value* self, std::vector method_names) - : self_(std::move(self)), method_names_(std::move(method_names)) {} + : self_(self), method_names_(std::move(method_names)) {} MethodValue(Value* self, std::string method_name) - : MethodValue(self, std::vector({method_name})) {} + : MethodValue(self, std::vector({std::move(method_name)})) {} std::string kind() const override { return "method"; @@ -442,11 +450,11 @@ struct MethodValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& f, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { - std::vector inputsWithSelf = {self_}; - inputsWithSelf.insert(inputsWithSelf.end(), inputs.begin(), inputs.end()); + std::vector argsWithSelf = {self_}; + argsWithSelf.insert(argsWithSelf.end(), args.begin(), args.end()); std::vector schemas; for (const std::string& method_name : method_names_) { if (auto class_type = self_->type()->cast()) { @@ -466,8 +474,7 @@ struct MethodValue : public SugaredValue { false, "method constructed that is not a class or interface"); } } - auto match = - matchSchemas(schemas, loc, *f.graph(), inputsWithSelf, attributes); + auto match = matchSchemas(schemas, loc, *f.graph(), argsWithSelf, kwargs); Value* output = f.graph()->insertMethodCall(method_names_[match.first], match.second); output->node()->setSourceRange(loc); @@ -486,8 +493,8 @@ struct TORCH_API PrintValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; }; @@ -500,16 +507,16 @@ struct TORCH_API CastValue : public BuiltinFunction { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { - if (inputs.size() == 1 && attributes.size() == 0) { - auto v = inputs[0].value(*m.graph()); + if (args.size() == 1 && kwargs.size() == 0) { + auto v = args[0].value(*m.graph()); if (v->type()->isSubtypeOf(type_)) { return std::make_shared(v); } } - return BuiltinFunction::call(loc, m, inputs, attributes, n_binders); + return BuiltinFunction::call(loc, m, args, kwargs, n_binders); } private: @@ -527,17 +534,17 @@ struct TORCH_API TensorCastValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { - TORCH_INTERNAL_ASSERT(inputs.size() == 0 && attributes.size() == 0); + TORCH_INTERNAL_ASSERT(args.size() == 0 && kwargs.size() == 0); Value* dtype_const = m.graph()->insertConstant(dtype_, loc); - std::vector kwargs{self_, - NamedValue(loc, "dtype", dtype_const)}; + std::vector kwargs_{ + self_, NamedValue(loc, "dtype", dtype_const)}; Value* casted_val = m.graph()->insert( /*opname=*/Symbol::fromQualString("aten::to"), - /*args=*/inputs, - /*kwargs=*/kwargs, + /*args=*/args, + /*kwargs=*/kwargs_, /*range=*/loc); return std::make_shared(casted_val); } @@ -560,8 +567,8 @@ struct TORCH_API MagicMethod : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; private: @@ -604,8 +611,11 @@ struct TORCH_API RangeValue : SugaredValue { return "range"; } Value* len(const SourceRange& loc, Function& m) override; - SugaredValuePtr getitem(const SourceRange& loc, Function& m, Value* idx) - override; + SugaredValuePtr getitem( + const SourceRange& loc, + Function& m, + Value* idx, + TypePtr type_hint = nullptr) override; std::shared_ptr iter(const SourceRange& loc, Function& m) override; @@ -661,7 +671,7 @@ struct TORCH_API IterableTree : SugaredValue { void addChild( const SourceRange& range, Function& m, - const SugaredValuePtr iter_value); + const SugaredValuePtr& iter_value); std::vector get_children() { return children_; @@ -680,8 +690,11 @@ struct TORCH_API IterableTree : SugaredValue { std::vector get_base_iterables(); Value* len(const SourceRange& loc, Function& m) override; - SugaredValuePtr getitem(const SourceRange& loc, Function& m, Value* idx) - override; + SugaredValuePtr getitem( + const SourceRange& loc, + Function& m, + Value* idx, + TypePtr type_hint = nullptr) override; private: c10::optional unroll_length_ = c10::nullopt; @@ -726,7 +739,7 @@ struct TORCH_API ExceptionMessageValue : public SugaredValue { }; struct TORCH_API ExceptionValue : public SugaredValue { - explicit ExceptionValue(const std::string& message) : message_(message) {} + explicit ExceptionValue(std::string message) : message_(std::move(message)) {} std::string kind() const override { return "exception"; @@ -735,11 +748,11 @@ struct TORCH_API ExceptionValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, + at::ArrayRef args, at::ArrayRef /*attributes*/, size_t /*n_binders*/) override { auto exception_message = insertConstant(*m.graph(), message_ + ": ", loc); - for (auto& input : inputs) { + for (auto& input : args) { auto input_str = input.value(*m.graph()); if (!input_str->type()->isSubtypeOf(StringType::get())) { input_str = diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp index c4f749271a535..f5b1d62f075b9 100644 --- a/torch/csrc/jit/frontend/tracer.cpp +++ b/torch/csrc/jit/frontend/tracer.cpp @@ -103,6 +103,9 @@ void TracingState::delValue(const IValue& var) { Value* getValueTrace(const IValue& var) { return getTracingState()->getValue(var); } +Value* getOptTensorValueTrace(const c10::optional& var) { + return getValueTrace(IValue(var)); +} Value* TracingState::getValue(const IValue& var) { // allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...] // arguments @@ -134,7 +137,7 @@ Value* TracingState::getValue(const IValue& var) { return graph->insertNode(dict_node)->output(); } if (var.isTensor()) { - auto ten = var.toTensor(); + auto& ten = var.toTensor(); if (!ten.defined()) { Node* n = graph->createNone(); return graph->insertNode(n)->output(); @@ -234,7 +237,7 @@ bool TracingState::hasValue(const IValue& var) const { Value* TracingState::getOutput(const IValue& iv, size_t i) { bool tracing_mode_strict = getTracingState()->strict; if (iv.isTensor()) { - at::Tensor var = iv.toTensor(); + const at::Tensor& var = iv.toTensor(); if (!var.defined()) { Node* n = graph->createNone(); return graph->insertNode(n)->output(); @@ -284,11 +287,23 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) { key_type->isSubtypeOf(TensorType::get()); bool value_type_valid = value_type->isSubtypeOf(TensorType::get()); + // Support tuple values that contain only tensors + if (value_type->isSubtypeOf(AnyTupleType::get())) { + value_type_valid = true; + for (const auto& type : value_type->containedTypes()) { + if (!type->isSubtypeOf(TensorType::get())) { + value_type_valid = false; + break; + } + } + } + if (!key_type_valid || !value_type_valid) { std::ostringstream os; os << "output " << i << " (" << dict << ") of traced region " - << "cannot be understood by the tracer, only dict[str, Tensor] " - << "or dict[Tensor, Tensor] can be a dictionary output of a traced function"; + << "cannot be understood by the tracer, only outputs matching" + << "dict[Union[str, Tensor], Union[Tensor, Tuple[Tensor, ...]]] " + << "can be a dictionary output of a traced function"; throw std::runtime_error(os.str()); } std::vector keys; @@ -491,7 +506,7 @@ void setValueTrace(const IValue& v, Value* value) { } void TracingState::setValue(const IValue& v, Value* value) { if (v.isTensor()) { - auto var = v.toTensor(); + auto& var = v.toTensor(); AT_ASSERT(var.defined()); env_stack.back()[v] = value; } else if (v.isTensorList()) { @@ -547,7 +562,11 @@ void addInputs(Node* n, const char* name, int64_t value) { } void addInputs(Node* n, const char* name, c10::optional value) { - if (value) { + using ArgumentStash = jit::tracer::ArgumentStash; + if (ArgumentStash::hasValue(name)) { + Value* v = ArgumentStash::popValue(name); + n->addInput(v); + } else if (value) { detail::genericAddInput(n, *value); } else { Graph* g = n->owningGraph(); @@ -614,6 +633,9 @@ void addInputs( void addInputs(Node* n, const char* name, at::Device value) { detail::genericAddInput(n, value); } +void addInputs(Node* n, const char* name, c10::Stream stream) { + detail::genericAddInput(n, static_cast(stream.pack())); +} void addInputs(Node* n, const char* name, at::Layout value) { detail::genericAddInput(n, static_cast(value)); } @@ -671,6 +693,16 @@ void addInputs( } n->addInput(list_node->output()); } +TORCH_API void addInputs( + Node* n, + const char* name, + const List>& value) { + Graph* g = n->owningGraph(); + Node* list_node = nullptr; + list_node = g->insertNode(g->createList( + OptionalType::ofTensor(), fmap(value, getOptTensorValueTrace))); + n->addInput(list_node->output()); +} void addInputs( Node* n, @@ -696,15 +728,6 @@ void addInputs( } } -void addInputs(Node* n, const char* name, const at::TensorOptions& options) { - // [TensorOptions in script] - update this when you change how we schematize - // TensorOptions - addInputs(n, name, options.dtype_opt()); - addInputs(n, name, options.layout()); - addInputs(n, name, options.device()); - addInputs(n, name, options.pinned_memory()); -} - void addInputs(Node* n, const char* name, at::IntArrayRef value) { using ArgumentStash = jit::tracer::ArgumentStash; std::vector info = ArgumentStash::hasIntArrayRef(name) @@ -921,6 +944,18 @@ void setRecordSourceLocation(void (*v)(Node*)) { record_source_location.store(v); } +std::vector defaultPythonCallstack() { + return std::vector(); +} +std::atomic python_callstack_fn( + defaultPythonCallstack); +std::vector pythonCallstack() { + return python_callstack_fn.load()(); +} +void setPythonCallstack(std::vector (*v)()) { + python_callstack_fn.store(v); +} + void defaultWarn(const std::string& str) { TORCH_WARN(str); } diff --git a/torch/csrc/jit/frontend/tracer.h b/torch/csrc/jit/frontend/tracer.h index 82ce500c532c1..f5cbd821bda48 100644 --- a/torch/csrc/jit/frontend/tracer.h +++ b/torch/csrc/jit/frontend/tracer.h @@ -7,6 +7,7 @@ #include #include +#include #include #include @@ -193,6 +194,9 @@ struct WithNestedTracingFrame { TORCH_API void recordSourceLocation(Node* n); TORCH_API void setRecordSourceLocation(void (*v)(Node*)); +TORCH_API std::vector pythonCallstack(); +TORCH_API void setPythonCallstack(std::vector (*v)()); + // Having finished adding a new 'node' to the graph IR 'setValueTrace' // associates this node with an output variable, so that further operations // involving this variable know which node in the IR to reference. @@ -251,6 +255,10 @@ TORCH_API void addInputs( const char* name, ArrayRef value, bool allow_undefined = false); +TORCH_API void addInputs( + Node* n, + const char* name, + const List>& value); TORCH_API void addInputs( Node* n, const char* name, @@ -266,11 +274,8 @@ TORCH_API void addInputs( Node* n, const char* name, const c10::optional& value); -TORCH_API void addInputs( - Node* n, - const char* name, - const at::TensorOptions& value); TORCH_API void addInputs(Node* n, const char* name, at::Device value); +TORCH_API void addInputs(Node* n, const char* name, c10::Stream stream); TORCH_API void addInputs(Node* n, const char* name, at::Layout value); TORCH_API void addInputs(Node* n, const char* name, at::ScalarType value); TORCH_API void addInputs( diff --git a/torch/csrc/jit/frontend/tree.h b/torch/csrc/jit/frontend/tree.h index c4a0a606037c0..f9d86f05d9eb3 100644 --- a/torch/csrc/jit/frontend/tree.h +++ b/torch/csrc/jit/frontend/tree.h @@ -12,20 +12,17 @@ namespace torch { namespace jit { -// Tree's are used to represent all forms of TC IR, pre- and post- typechecking. -// Rather than have a full class hierarchy for all TC statements, -// Trees are a slight variation of Lisp S-expressions. -// for instance the expression a*b+1 is represented as: +// Trees are used to represent all forms of TC IR, pre- and post-typechecking. +// Rather than have a full class hierarchy for all TC statements, trees are a +// slight variation of Lisp s-expressions. For instance, the expression a*b+1 +// is represented as: // (+ (* (ident a) (ident b)) (const 1)) // Atoms like 'a', 'b', and '1' are represented by subclasses of Tree which -// define stringValue(). -// Everything else is a Compound object, which has a 'kind' that is a token from -// Lexer.h's TokenKind enum, and contains a list of subtrees. -// Like TokenKind single-character operators like '+' are representing using the -// character itself, so add.kind() == '+'. -// Compound objects are also always associated with a SourceRange for -// reporting error message. - +// define stringValue(). Everything else is a Compound object, which has a +// 'kind' that is a token from lexer.h's TokenKind enum. Single-character +// operators like '+' are represented using the character itself (so, add.kind() +// would be '+'). Each Compound object also contains a list of subtrees and is +// associated with a SourceRange for error reporting. // Memory management of trees is done using intrusive_ptr. struct Tree; diff --git a/torch/csrc/jit/frontend/tree_views.h b/torch/csrc/jit/frontend/tree_views.h index 1026a1c17d84f..879638de556c9 100644 --- a/torch/csrc/jit/frontend/tree_views.h +++ b/torch/csrc/jit/frontend/tree_views.h @@ -226,8 +226,9 @@ struct Ident : public TreeView { const std::string& name() const { return subtree(0)->stringValue(); } - static Ident create(const SourceRange& range, const std::string& name) { - return Ident(Compound::create(TK_IDENT, range, {String::create(name)})); + static Ident create(const SourceRange& range, std::string name) { + return Ident( + Compound::create(TK_IDENT, range, {String::create(std::move(name))})); } }; @@ -308,6 +309,7 @@ struct Expr : public TreeView { case '^': case '|': case TK_LIST_COMP: + case TK_DICT_COMP: case TK_DOTS: case TK_IN: case TK_WITH_ITEM: @@ -403,7 +405,7 @@ struct Def : public TreeView { auto new_ident = Ident::create(name().range(), std::move(new_name)); return create(range(), new_ident, decl(), statements()); } - Def withDecl(Decl decl) const { + Def withDecl(const Decl& decl) const { return create(range(), name(), decl, statements()); } Ident name() const { @@ -578,6 +580,35 @@ struct ListComp : public Expr { } }; +// TODO: supports only single comprehension for now +struct DictComp : public Expr { + explicit DictComp(const TreeRef& tree) : Expr(tree) { + tree->match(TK_DICT_COMP); + } + Expr key() const { + return Expr(subtree(0)); + } + Expr value() const { + return Expr(subtree(1)); + } + Expr target() const { + return Expr(subtree(2)); + } + Expr iter() const { + return Expr(subtree(3)); + } + // TODO: no ifs for now + static DictComp create( + const SourceRange& range, + const Expr& key, + const Expr& value, + const Expr& target, + const Expr& iter) { + return DictComp( + Compound::create(TK_DICT_COMP, range, {key, value, target, iter})); + } +}; + struct Global : public Stmt { explicit Global(const TreeRef& tree) : Stmt(tree) { tree_->match(TK_GLOBAL); @@ -598,6 +629,12 @@ struct AugAssignKind : public TreeView { case '*': case '/': case '%': + case '|': + case '&': + case '^': + case TK_POW: + case TK_LSHIFT: + case TK_RSHIFT: return; default: throw ErrorReport(tree) << "is not a valid AugAssignKind"; @@ -839,7 +876,7 @@ struct Const : public Expr { } int64_t asIntegral() const { try { - return c10::stoll(subtree(0)->stringValue(), /*pos=*/0, /*base=*/0); + return c10::stoll(subtree(0)->stringValue(), /*__idx=*/0, /*base=*/0); } catch (const std::out_of_range& e) { throw ErrorReport(range()) << "Integral constant out of range " "(must fit in a signed 64 bit integer)"; @@ -928,15 +965,15 @@ struct SliceExpr : public Expr { Maybe step() const { return Maybe(subtree(2)); } - Expr startOr(int alternative) const { + Expr startOr(int64_t alternative) const { const auto startOption = start(); return startOption.present() ? startOption.get() : createInt(alternative); } - Expr endOr(int alternative) const { + Expr endOr(int64_t alternative) const { const auto endOption = end(); return endOption.present() ? endOption.get() : createInt(alternative); } - Expr stepOr(int alternative) const { + Expr stepOr(int64_t alternative) const { const auto stepOption = step(); return stepOption.present() ? stepOption.get() : createInt(alternative); } @@ -950,7 +987,7 @@ struct SliceExpr : public Expr { } private: - Expr createInt(int value) const { + Expr createInt(int64_t value) const { return Expr(Const::create(range(), c10::to_string(value))); } }; @@ -1119,11 +1156,11 @@ struct Delete : public Stmt { explicit Delete(const TreeRef& tree) : Stmt(tree) { tree_->match(TK_DELETE); } - Expr expr() const { - return Expr(subtree(0)); + List targets() const { + return subtree(0); } - static Delete create(const Expr& value) { - return Delete(Compound::create(TK_DELETE, value.range(), {value})); + static Delete create(const SourceRange& range, const List& targets) { + return Delete(Compound::create(TK_DELETE, range, {targets})); } }; diff --git a/torch/csrc/jit/frontend/versioned_symbols.cpp b/torch/csrc/jit/frontend/versioned_symbols.cpp index 8e39e6f4247fc..d7ee0d3393ad9 100644 --- a/torch/csrc/jit/frontend/versioned_symbols.cpp +++ b/torch/csrc/jit/frontend/versioned_symbols.cpp @@ -1,4 +1,5 @@ #include + #include #include diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index bb5872f35f4f2..c1abd8cc016a0 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -63,7 +63,7 @@ class MutableTypePtrHelper { } case TypeKind::TupleType: { std::vector mutable_types; - for (const auto& elem : type->expect()->elements()) { + for (const auto& elem : type->expectRef().elements()) { if (auto mut_elem = getMutableType(elem)) { mutable_types.push_back(*mut_elem); } @@ -486,6 +486,7 @@ void AliasDb::analyzeImpl(Node* node) { return analyzeGradOf(node); // TODO: think more about TensorExpr alias correctness case prim::TensorExprGroup: + case prim::StaticSubgraph: case prim::Constant: case prim::AutogradZero: case prim::AutogradAdd: @@ -494,7 +495,7 @@ void AliasDb::analyzeImpl(Node* node) { case prim::MMBatchSide: case prim::BroadcastSizes: case prim::ChunkSizes: - case prim::Function: + case prim::Closure: case prim::CreateObject: case prim::tolist: return analyzeCreator(node); @@ -510,7 +511,7 @@ void AliasDb::analyzeImpl(Node* node) { case prim::GetAttr: if (isFrozen_ && node->kind() == prim::GetAttr) { auto& ty = node->input()->type(); - if (ty->expect()->is_module()) { + if (ty->expectRef().is_module()) { return analyzeCreator(node); } } @@ -524,10 +525,12 @@ void AliasDb::analyzeImpl(Node* node) { case prim::SetAttr: return analyzeSetAttr(node); case prim::profile_optional: + case prim::profile_ivalue: case prim::profile: makePointerTo(node->output(), node->inputs().at(0)); return; - case prim::TypeCheck: { + case prim::TypeCheck: + case prim::RequiresGradCheck: { auto num_inputs = node->inputs().size(); for (size_t i = 0; i < num_inputs; i++) { makePointerTo(node->outputs().at(i), node->inputs().at(i)); @@ -570,7 +573,8 @@ void AliasDb::analyzeImpl(Node* node) { !aliasAnalysisHasSpecialCaseFor(node->kind()), "Special cases should be handled already if we're here."); - if (node->kind().is_aten() || node->kind().is_prim()) { + if (node->kind().is_aten() || node->kind().is_prim() || + node->kind().is_cuda()) { // TODO There is nothing in the system that relies on aten:: and prim:: // ops using AliasAnalysisKind::FROM_SCHEMA or // AliasAnalysisKind::INTERNAL_SPECIAL_CASE, but this is the intended @@ -992,7 +996,7 @@ void AliasDb::makePointerTo(const Value* from, const Value* to) { // the contained types of immutable type containers (optional, tuple, future) // are unified, so these types can be mutable or immutable // and point to a type which is mutable or immutable. - // Any is mutable but can point to a immutable type through refinement + // Any is mutable but can point to an immutable type through refinement if (isMutableTypeInternal(from) != isMutableTypeInternal(to)) { bool expected_kind = false; for (auto kind : {from->type()->kind(), to->type()->kind()}) { diff --git a/torch/csrc/jit/ir/attributes.h b/torch/csrc/jit/ir/attributes.h index 21c4a0e96b8d1..16789a18e2bf4 100644 --- a/torch/csrc/jit/ir/attributes.h +++ b/torch/csrc/jit/ir/attributes.h @@ -21,19 +21,20 @@ constexpr int max_tensor_display_size = 10; enum class AttributeKind { f, fs, i, is, s, ss, t, ts, g, gs, ty, tys, ival }; static inline const char* toString(AttributeKind kind) { - static const char* names[] = {"f", - "fs", - "i", - "is", - "s", - "ss", - "t", - "ts", - "g", - "gs", - "ty", - "tys", - "ival"}; + static const char* names[] = { + "f", + "fs", + "i", + "is", + "s", + "ss", + "t", + "ts", + "g", + "gs", + "ty", + "tys", + "ival"}; AT_ASSERT(size_t(kind) < sizeof(names) / sizeof(*names)); return names[int(kind)]; } @@ -108,7 +109,7 @@ struct TORCH_API GraphAttr : public AttributeValue { using ConstructorType = std::shared_ptr; using ValueType = std::shared_ptr; GraphAttr(Symbol name, ConstructorType value_) - : AttributeValue(name), value_(value_) {} + : AttributeValue(name), value_(std::move(value_)) {} ValueType& value() { return value_; } @@ -138,8 +139,8 @@ struct TORCH_API GraphsAttr : public AttributeValue { ValueType value_; }; -struct AttributeError : public std::exception { - AttributeError(Symbol name, bool defined) { +struct IRAttributeError : public std::exception { + IRAttributeError(Symbol name, bool defined) { std::stringstream ss; if (!defined) { ss << "required keyword attribute '" << name.toUnqualString() diff --git a/torch/csrc/jit/ir/constants.cpp b/torch/csrc/jit/ir/constants.cpp index 8b2c0a0da56b5..512533e7d89d0 100644 --- a/torch/csrc/jit/ir/constants.cpp +++ b/torch/csrc/jit/ir/constants.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -8,12 +9,6 @@ namespace torch { namespace jit { -namespace { -c10::AliasAnalysisKind aliasAnalysisInternalSpecialCase() { - return AliasAnalysisKind::INTERNAL_SPECIAL_CASE; -} -} // namespace - bool insertableTensor(const at::Tensor& ten) { return !ten.requires_grad(); } @@ -53,7 +48,7 @@ Value* insertConstant( const IValue& val, c10::optional loc, c10::optional scope) { - auto value = tryInsertConstant(g, val, loc, scope); + auto value = tryInsertConstant(g, val, std::move(loc), std::move(scope)); if (value) { return *value; } @@ -110,6 +105,10 @@ c10::optional tryInsertConstant( ss << val.toDevice(); n->s_(attr::value, ss.str()); n->output()->setType(DeviceObjType::get()); + } else if (val.isStream()) { + auto stream = val.toStream(); + n->i_(attr::value, stream.pack()); + n->output()->setType(StreamObjType::get()); } else if (val.isNone()) { n->output()->setType(NoneType::get()); } else if (val.isTuple()) { @@ -120,10 +119,9 @@ c10::optional tryInsertConstant( n->destroy(); return c10::nullopt; }; - } else if (val.isGenericDict() && insertableIValue(val)) { - n->ival_(attr::value, val); - n->output()->setType(val.type()); - } else if (val.isEnum()) { + } else if ( + (val.isGenericDict() && insertableIValue(val)) || (val.isEnum()) || + (val.isObject() && !val.toObjectRef().type()->is_module())) { n->ival_(attr::value, val); n->output()->setType(val.type()); } else { @@ -179,11 +177,17 @@ c10::optional toIValue(const Value* v) { } else if (type == DeviceObjType::get()) { auto d = c10::Device(node->s(attr::value)); return d; + } else if (type == StreamObjType::get()) { + auto s = c10::Stream::unpack(node->i(attr::value)); + return s; } else if (node->mustBeNone()) { return IValue(); } else if (type->cast()) { const auto& enum_val = node->ival(attr::value); return enum_val; + } else if (type->cast() && !type->is_module()) { + const auto& class_val = node->ival(attr::value); + return class_val; } else { std::stringstream ss; ss << "constant literal not supported for: " << type->str(); diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index 6543f36d6ac2d..7f58aca252713 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -104,7 +104,7 @@ std::ostream& operator<<( static void printAttribute(std::ostream& out, const at::Tensor& tensor) { // 1-elem tensors are usually boxed scalars, so print them like it if (tensor.numel() == 1) { - auto scalar_tensor = tensor.view({}).item(); + auto scalar_tensor = tensor.view(std::vector{}).item(); out << "{"; if (scalar_tensor.isFloatingPoint()) { out << scalar_tensor.toDouble(); @@ -133,6 +133,9 @@ static void printAttribute(std::ostream& out, const IValue& ival) { } else if (input.isTensorList()) { ss << "[]"; return true; + } else if (input.isObject() && !input.type()->is_module()) { + ss << "object(" << &input.toObjectRef() << ")"; + return true; } return false; }; @@ -326,6 +329,8 @@ std::ostream& Graph::print(std::ostream& out, bool print_source_locations) out << "with " << fg->kind().toQualString() << "_" << i++ << " = " << *fg->g(attr::Subgraph); } + out.flush(); + /* // Uncomment this to debug all_nodes issues { @@ -844,22 +849,40 @@ void Value::replaceAllUsesAfterNodeWith(const Node* node, Value* newValue) { uses_.end()); } -size_t findArgument(const FunctionSchema& the_schema, Symbol name) { - auto name_str = name.toUnqualString(); +size_t findArgument( + const FunctionSchema& the_schema, + const std::string& unqualName) { for (size_t i = 0; i < the_schema.arguments().size(); ++i) { const Argument* arg = &the_schema.arguments()[i]; - if (arg->name() == name_str) { + if (arg->name() == unqualName) { return i; } } throw std::runtime_error( - std::string("Couldn't find an argument called ") + name.toQualString()); + std::string("Couldn't find an argument called ") + unqualName); +} + +size_t findArgument(const FunctionSchema& the_schema, Symbol name) { + const auto unqualName = name.toUnqualString(); + return findArgument(the_schema, unqualName); } c10::optional Node::get(Symbol name) const { return toIValue(namedInput(name)); } +bool Node::hasNamedInput(const std::string& name) const { + for (const auto& argument : schema().arguments()) { + if (argument.name() == name) { + return true; + } + } + return false; +} + +Value* Node::namedInput(const std::string& unqualName) const { + return input(findArgument(schema(), unqualName)); +} Value* Node::namedInput(Symbol name) const { return input(findArgument(schema(), name)); } @@ -1056,6 +1079,11 @@ bool Node::hasSideEffects() const { case prim::rpc_sync: // It represents RPC message sent. case prim::rpc_remote: // It represents RPC message sent. case aten::wait: // It can represent RPC message received. +#ifndef __HIP_PLATFORM_HCC__ + case cuda::set_stream: + case cuda::_set_device: + case cuda::_current_device: +#endif case prim::Enter: case prim::Exit: return true; @@ -1071,7 +1099,7 @@ bool Node::hasSideEffects() const { return false; } - if (kind_.is_prim() || kind_.is_aten()) { + if (kind_.is_prim() || kind_.is_aten() || kind_.is_cuda()) { // TODO There is nothing in the system that relies on aten:: and prim:: // ops using AliasAnalysisKind::FROM_SCHEMA, // AliasAnalysisKind::INTERNAL_SPECIAL_CASE, or @@ -1091,9 +1119,7 @@ bool Node::hasSideEffects() const { switch (op->aliasAnalysisKind()) { case AliasAnalysisKind::PURE_FUNCTION: - return false; case AliasAnalysisKind::FROM_SCHEMA: - return false; case AliasAnalysisKind::INTERNAL_SPECIAL_CASE: return false; case AliasAnalysisKind::CONSERVATIVE: @@ -1585,17 +1611,25 @@ Node* Graph::createTupleIndex( return n; } -Node* Graph::createTupleSlice(Value* tup, int64_t beg, int64_t end) { - auto n = create(prim::TupleSlice, {tup}); - auto tuple_type = tup->type()->expect(); - n->i_(attr::beg, beg); - n->i_(attr::end, end); - std::vector output_types; - for (auto i = beg; i < end; ++i) { - output_types.push_back(tuple_type->elements().at(i)); +Node* Graph::createTupleSlice( + Value* tup, + int64_t beg, + int64_t step_size, + int64_t num_values) { + std::vector new_vals; + TupleTypePtr tt = tup->type()->expect(); + new_vals.reserve(num_values); + + int64_t i = beg; + for (int64_t j = 0; j < num_values; ++j) { + auto idx = insertConstant(IValue(static_cast(i))); + auto tupleIndex = insertNode(createTupleIndex(tup, idx, tt->elements()[i])); + + new_vals.push_back(tupleIndex->output()); + i += step_size; } - auto tt = TupleType::create(std::move(output_types)); - n->output()->setType(tt); + + auto n = createTuple(new_vals); return n; } @@ -1877,10 +1911,50 @@ std::vector inlineCallTo( std::unordered_map new_callstack_entries; + c10::optional module_instance_info = c10::nullopt; + if (to_replace->kind() == prim::CallMethod) { + auto class_type_ptr = to_replace->input(0)->type()->cast(); + if (to_replace->input(0)->node()->kind() == prim::GetAttr) { + module_instance_info = c10::make_optional(ModuleInstanceInfo( + class_type_ptr, to_replace->input(0)->node()->s(attr::name))); + } else { + std::string instance_name_unknown("INSTANCE_NAME_UNKNOWN"); + module_instance_info = c10::make_optional( + ModuleInstanceInfo(class_type_ptr, instance_name_unknown)); + } + } + // TODO: We might need to use nodes_map instead of value_map. Otherwise, we // are missing nodes without outputs (e.g. prim::Print). std::unordered_set updated_nodes; for (const auto& kv : value_map) { + /* Skip the old value if it is the graph input. + * The reason is that, value_map contains values not all for the nodes of + * the graph but primary inputs as well, and it will create duplicates when + * the first inlined graph is input to the next one. To avoid this issue, + * skip the old value when it is one of the + * callee->optimized_graph()->inputs() or callee->graph()->inputs(), depends + * on if it is inlined_optimized_graph + */ + + if (inline_optimized_graph) { + auto is_graph_input = std::find( + callee->optimized_graph()->inputs().begin(), + callee->optimized_graph()->inputs().end(), + kv.first); + if (is_graph_input != callee->optimized_graph()->inputs().end()) { + continue; + } + } else { + auto is_graph_input = std::find( + callee->graph()->inputs().begin(), + callee->graph()->inputs().end(), + kv.first); + if (is_graph_input != callee->graph()->inputs().end()) { + continue; + } + } + Node* new_node = kv.second->node(); if (!updated_nodes.insert(new_node).second) { continue; @@ -1895,16 +1969,18 @@ std::vector inlineCallTo( if (new_node_cs) { new_callstack_entries[raw_callstack_ptr] = c10::make_intrusive( - *new_node_cs, callee, to_replace->sourceRange()); + *new_node_cs, + callee, + to_replace->sourceRange(), + module_instance_info); } else { new_callstack_entries[raw_callstack_ptr] = c10::make_intrusive( - callee, to_replace->sourceRange()); + callee, to_replace->sourceRange(), module_instance_info); } } new_node->setCallStack(new_callstack_entries.at(raw_callstack_ptr)); } - const auto& old_outputs = to_replace->outputs(); AT_ASSERT(new_outputs.size() == old_outputs.size()); @@ -1990,6 +2066,16 @@ Node* ProfileOptionalOp::allocNewInstance(Graph* g) { return new ProfileOptionalOp(g, {nullptr}); } +void ProfileIValueOp::cloneFrom(Node* other_) { + Node::cloneFrom(other_); + auto other = other_->cast(); + this->callback_ = other->getCallback(); +} + +Node* ProfileIValueOp::allocNewInstance(Graph* g) { + return new ProfileIValueOp(g, {nullptr}); +} + TypePtr NamedValue::type() const { if (value_) { return value_->type(); @@ -1998,8 +2084,9 @@ TypePtr NamedValue::type() const { } } -constexpr Symbol ProfileOp::Kind; -constexpr Symbol ProfileOptionalOp::Kind; +const Symbol ProfileOp::Kind = ::c10::prim::profile; +const Symbol ProfileOptionalOp::Kind = ::c10::prim::profile_optional; +const Symbol ProfileIValueOp::Kind = ::c10::prim::profile_ivalue; OperatorSet::OperatorSet(std::initializer_list sig_literals) { for (const char* sig : sig_literals) { diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index 665bd9797b26b..02867b8639cdd 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -72,6 +72,11 @@ using namespace ::c10::attr; namespace aten { using namespace ::c10::aten; } +namespace cuda { +#ifndef __HIP_PLATFORM_HCC__ +using namespace ::c10::cuda; +#endif +} // namespace cuda struct Function; struct MatchedSchema; @@ -414,6 +419,8 @@ struct TORCH_API Node { return inputs_.at(i); } + bool hasNamedInput(const std::string& unqualName) const; + Value* namedInput(const std::string& unqualName) const; Value* namedInput(Symbol name) const; c10::optional get(Symbol name) const; @@ -438,7 +445,7 @@ struct TORCH_API Node { // instructions lowered by the interpreter and not run in the optimized graph bool notExecutedOp() const { return kind_ == prim::Constant || kind_ == prim::profile || - kind_ == prim::profile_optional; + kind_ == prim::profile_optional || kind_ == prim::profile_ivalue; } // Graphs @@ -808,7 +815,7 @@ struct TORCH_API Node { auto it = findAttr(name, true); auto* child = dynamic_cast(it->get()); if (child == nullptr) { - throw AttributeError(name, true); + throw IRAttributeError(name, true); } return child->value(); } @@ -823,7 +830,7 @@ struct TORCH_API Node { return v->name == name; }); if (required && it == values_.end()) { - throw AttributeError(name, false); + throw IRAttributeError(name, false); } AT_ASSERT(!required || it != values_.end()); return it; @@ -835,7 +842,7 @@ struct TORCH_API Node { return v->name == name; }); if (required && it == values_.end()) { - throw AttributeError(name, false); + throw IRAttributeError(name, false); } AT_ASSERT(!required || it != values_.end()); return it; @@ -934,14 +941,14 @@ struct Block { return owning_node_; } - Value* addInput(std::string name = "") { + Value* addInput(const std::string& name = "") { Value* v = input_->addOutput(); - v->setDebugName(std::move(name)); + v->setDebugName(name); return v; } - Value* insertInput(size_t i, std::string name = "") { + Value* insertInput(size_t i, const std::string& name = "") { Value* v = input_->insertOutput(i); - v->setDebugName(std::move(name)); + v->setDebugName(name); return v; } void eraseInput(size_t i) { @@ -1085,11 +1092,11 @@ struct Graph { current_scope_ = std::move(scope); } - Value* addInput(std::string name = "") { - return block_->addInput(std::move(name)); + Value* addInput(const std::string& name = "") { + return block_->addInput(name); } - Value* insertInput(size_t i, std::string name = "") { - return block_->insertInput(i, std::move(name)); + Value* insertInput(size_t i, const std::string& name = "") { + return block_->insertInput(i, name); } void eraseInput(size_t i) { block_->eraseInput(i); @@ -1120,7 +1127,11 @@ struct Graph { Value* tup, Value* idx, const TypePtr& output_type); - TORCH_API Node* createTupleSlice(Value* tup, int64_t beg, int64_t end); + TORCH_API Node* createTupleSlice( + Value* tup, + int64_t beg, + int64_t step_size, + int64_t num_values); TORCH_API Node* createEnumName(Value* e); TORCH_API Node* createEnumValue(Value* e); TORCH_API Node* createList( @@ -1165,7 +1176,7 @@ struct Graph { const std::string& cconv, pyobj_list&& scalar_args); // clone n, making a new node in _this_ graph. - // use node_map to translate inputs of n to inputs of the cloned node + // use value_map to translate inputs of n to inputs of the cloned node // if copy_blocks is false, it will not recursively clone the nested blocks // this node contains. TORCH_API Node* createClone( @@ -1261,7 +1272,7 @@ struct Graph { /** \brief An utility class for setting temporary insertion points. * * When an object of this class is created, it stores the current insertion - * point, sets the new one, and restores the original insertion point when the + * point, sets the new one, and restores the original insertion point when the * object is destroyed. */ struct WithInsertPoint { @@ -1324,9 +1335,9 @@ inline const Graph* Value::owningGraph() const { /************* All nodes not required to be defined before Graph **************/ struct ProfileOp : public Node { - static constexpr Symbol Kind = ::c10::prim::profile; + static const Symbol Kind; ProfileOp(Graph* graph, std::function&)> callback) - : Node(graph, ::c10::prim::profile), callback_(callback) {} + : Node(graph, ::c10::prim::profile), callback_(std::move(callback)) {} void cloneFrom(Node* other_) override; Node* allocNewInstance(Graph* g) override; @@ -1336,7 +1347,7 @@ struct ProfileOp : public Node { } void setCallback(std::function&)> callback) { - callback_ = callback; + callback_ = std::move(callback); } private: @@ -1344,11 +1355,34 @@ struct ProfileOp : public Node { }; struct TORCH_API ProfileOptionalOp : public Node { - static constexpr Symbol Kind = ::c10::prim::profile_optional; + static const Symbol Kind; ProfileOptionalOp( Graph* graph, std::function&)> callback) - : Node(graph, ::c10::prim::profile_optional), callback_(callback) {} + : Node(graph, ::c10::prim::profile_optional), + callback_(std::move(callback)) {} + + void cloneFrom(Node* other_) override; + Node* allocNewInstance(Graph* g) override; + + const std::function&)>& getCallback() const { + return callback_; + } + + void setCallback(std::function&)> callback) { + callback_ = std::move(callback); + } + + private: + std::function&)> callback_; +}; + +struct TORCH_API ProfileIValueOp : public Node { + static const Symbol Kind; + ProfileIValueOp( + Graph* graph, + std::function&)> callback) + : Node(graph, ::c10::prim::profile_ivalue), callback_(callback) {} void cloneFrom(Node* other_) override; Node* allocNewInstance(Graph* g) override; diff --git a/torch/csrc/jit/ir/irparser.cpp b/torch/csrc/jit/ir/irparser.cpp index 4c4ce31d3b97e..ce4307c85748b 100644 --- a/torch/csrc/jit/ir/irparser.cpp +++ b/torch/csrc/jit/ir/irparser.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -70,9 +71,11 @@ struct ParsedLiteral { int64_t i = 0; std::string s = ""; double f = 0.0; + TypePtr ty; std::vector is; std::vector ss; std::vector fs; + std::vector tys; }; struct VarWithType { @@ -139,6 +142,7 @@ void IRParser::parseOperatorOutputs(std::vector* outs) { ParsedLiteral IRParser::parseScalarLiteral(Node* n) { auto token = L.cur(); std::string str; + std::pair> type_alias; ParsedLiteral r; switch (token.kind) { case TK_STRINGLITERAL: @@ -167,6 +171,13 @@ ParsedLiteral IRParser::parseScalarLiteral(Node* n) { } L.next(); return r; + case TK_IDENT: + // Type literal + r.k = AttributeKind::ty; + type_alias = type_parser.parseType(); + AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled"); + r.ty = type_alias.first; + return r; default: throw ErrorReport(token.range) << "Could not parse literal" << token.text(); @@ -194,6 +205,7 @@ void IRParser::parseAttr(Node* n) { c10::List is; c10::List ss; c10::List fs; + std::vector tys; int elem_num = 0; parseList('[', ',', ']', [&] { ParsedLiteral r = parseScalarLiteral(n); @@ -213,6 +225,11 @@ void IRParser::parseAttr(Node* n) { AT_ASSERT(!elem_num++ || k == AttributeKind::fs); k = AttributeKind::fs; break; + case AttributeKind::ty: + tys.push_back(r.ty); + AT_ASSERT(!elem_num++ || k == AttributeKind::tys); + k = AttributeKind::tys; + break; default: throw ErrorReport(L.cur().range) << "Unexpected attr type"; } @@ -230,6 +247,9 @@ void IRParser::parseAttr(Node* n) { case AttributeKind::is: n->ival_(Symbol::attr(attrname), IValue(is)); break; + case AttributeKind::tys: + n->tys_(Symbol::attr(attrname), tys); + break; default: throw ErrorReport(L.cur().range) << "Unexpected attr type"; } @@ -246,6 +266,9 @@ void IRParser::parseAttr(Node* n) { case AttributeKind::f: n->f_(Symbol::attr(attrname), r.f); break; + case AttributeKind::ty: + n->ty_(Symbol::attr(attrname), r.ty); + break; default: throw ErrorReport(L.cur().range) << "Unexpected attr type"; } diff --git a/torch/csrc/jit/ir/node_hashing.cpp b/torch/csrc/jit/ir/node_hashing.cpp index 52cace15075ff..8777b4265b3bc 100644 --- a/torch/csrc/jit/ir/node_hashing.cpp +++ b/torch/csrc/jit/ir/node_hashing.cpp @@ -114,7 +114,7 @@ bool ivaluesEqual(const IValue& a1, const IValue& a2) { const auto& e_a1 = *it_a1; const auto& e_a2 = *it_a2; - if (!ivaluesEqual(e_a1.key(), e_a2.key()) && + if (!ivaluesEqual(e_a1.key(), e_a2.key()) || !ivaluesEqual(e_a1.value(), e_a2.value())) { return false; } @@ -126,6 +126,9 @@ bool ivaluesEqual(const IValue& a1, const IValue& a2) { if (a1.isEnum()) { return a1.toEnumHolder() == a2.toEnumHolder(); } + if (a1.isObject()) { + return &a1.toObjectRef() == &a2.toObjectRef(); + } TORCH_INTERNAL_ASSERT(false); } diff --git a/torch/csrc/jit/ir/scope.cpp b/torch/csrc/jit/ir/scope.cpp index 9007224272255..ba1bceda104ba 100644 --- a/torch/csrc/jit/ir/scope.cpp +++ b/torch/csrc/jit/ir/scope.cpp @@ -1,4 +1,5 @@ #include + #include namespace torch { @@ -89,6 +90,14 @@ InlinedCallStackPtr InlinedCallStack::intrusive_from_this() { InlinedCallStack::InlinedCallStack(Function* fn, SourceRange source_range) : fn_(fn), source_range_(std::move(source_range)) {} +InlinedCallStack::InlinedCallStack( + Function* fn, + SourceRange source_range, + c10::optional module_instance_info) + : fn_(fn), + source_range_(std::move(source_range)), + module_instance_info_(std::move(module_instance_info)) {} + InlinedCallStack::InlinedCallStack( InlinedCallStackPtr callee, Function* fn, @@ -97,6 +106,16 @@ InlinedCallStack::InlinedCallStack( fn_(fn), source_range_(std::move(source_range)) {} +InlinedCallStack::InlinedCallStack( + InlinedCallStackPtr callee, + Function* fn, + SourceRange source_range, + c10::optional module_instance_info) + : callee_(std::move(callee)), + fn_(fn), + source_range_(std::move(source_range)), + module_instance_info_(std::move(module_instance_info)) {} + c10::optional InlinedCallStack::callee() const { return callee_; } @@ -105,10 +124,19 @@ std::vector InlinedCallStack::vec() { std::vector r; c10::optional current = intrusive_from_this(); while (current) { - r.emplace_back(std::make_pair((*current)->fn_, (*current)->source_range_)); + r.emplace_back(std::make_tuple( + (*current)->fn_, + (*current)->source_range_, + (*current)->module_instance_info_)); current = (*current)->callee_; } return r; } + +ModuleInstanceInfo::ModuleInstanceInfo( + c10::ClassTypePtr module_type, + std::string instance_name) + : module_type_(std::move(module_type)), + instance_name_(std::move(instance_name)) {} } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/ir/scope.h b/torch/csrc/jit/ir/scope.h index d75f3e060f363..e742caba495bc 100644 --- a/torch/csrc/jit/ir/scope.h +++ b/torch/csrc/jit/ir/scope.h @@ -1,5 +1,6 @@ #pragma once #include +#include #include #include #include @@ -51,6 +52,32 @@ struct TORCH_API Scope : public c10::intrusive_ptr_target { struct Function; struct InlinedCallStack; +/** + * ModuleInstanceInfo is a structure to include the module type and instance + * name. It also provide public methods to get the pointer to module type and + * instance name. + * + * This structure is mainly used as a private member in InlinedCallStack, such + * that one can follow the callstack to find the relevant module hierarchy. + */ +struct ModuleInstanceInfo { + private: + c10::ClassTypePtr module_type_{nullptr}; + std::string instance_name_; + + public: + ModuleInstanceInfo(c10::ClassTypePtr module_type, std::string instance_name); + c10::ClassTypePtr class_type() { + return module_type_; + } + c10::ClassTypePtr class_type() const { + return module_type_; + } + std::string instance_name() const { + return instance_name_; + } +}; + /** * InlinedCallStack is an element in a list representing callstack of functions * that have been inlined. @@ -80,7 +107,8 @@ struct InlinedCallStack; * [ham, source_range4] -- */ using InlinedCallStackPtr = c10::intrusive_ptr; -using InlinedCallStackEntry = std::pair; +using InlinedCallStackEntry = + std::tuple>; struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target { private: @@ -88,17 +116,30 @@ struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target { Function* fn_; SourceRange source_range_; InlinedCallStackPtr intrusive_from_this(); + c10::optional module_instance_info_; public: // Constructor for a leaf callstack node. InlinedCallStack(Function* fn, SourceRange source_range); + // Constructor for a leaf callstack node. + InlinedCallStack( + Function* fn, + SourceRange source_range, + c10::optional module_instance_info); + // Constructor for an inner callstack node. InlinedCallStack( InlinedCallStackPtr callee, Function* fn, SourceRange source_range); + InlinedCallStack( + InlinedCallStackPtr callee, + Function* fn, + SourceRange source_range, + c10::optional module_instance_info); + // Return next element in the callstack list. c10::optional callee() const; diff --git a/torch/csrc/jit/ir/subgraph_matcher.cpp b/torch/csrc/jit/ir/subgraph_matcher.cpp index 8319543d0bac4..63bb6626010c6 100644 --- a/torch/csrc/jit/ir/subgraph_matcher.cpp +++ b/torch/csrc/jit/ir/subgraph_matcher.cpp @@ -40,6 +40,9 @@ class SubgraphMatcher { bool matchNodes(const Node* n1, Node* n2); bool matchAttributes(const Node* n1, Node* n2); + static bool isInput(const Value* v); + static bool isOutput(const Value* v); + std::unordered_map nodes_map_; std::unordered_map values_map_; @@ -59,16 +62,23 @@ bool patternGraphIsValid(const Graph& pattern) { } } - // Verify that pattern graph returns only one value. - const Node* bottom_node = *(pattern.nodes().end()); - if (bottom_node->inputs().size() != 1) { - return false; - } - // TODO: Verify that nodes in the pattern don't alias. return true; } +bool SubgraphMatcher::isInput(const Value* v) { + return v->node()->kind() == prim::Param; +} + +bool SubgraphMatcher::isOutput(const Value* v) { + for (const Value* output : v->owningGraph()->outputs()) { + if (v == output) { + return true; + } + } + return false; +} + /** * Compare two Values. V1 is from pattern, V2 is from the actual graph. * @@ -98,8 +108,7 @@ bool SubgraphMatcher::matchValues(const Value* v1, Value* v2) { // When V2 is ANCHOR, we're comparing exiting values, and when V1->node is // PARAM, we're comparing entering values - in these two cases the number of // uses don't need to be the same. - if (v1->uses().size() != v2->uses().size() && v2->node() != anchor_ && - v1->node()->kind() != prim::Param) { + if (v1->uses().size() != v2->uses().size() && !isOutput(v1) && !isInput(v1)) { GRAPH_DEBUG( "Values %", v1->debugName(), @@ -297,13 +306,16 @@ bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) { anchor_ = anchor; const Node* bottom_node = *(pattern_.nodes().end()); - AT_ASSERT(bottom_node->inputs().size() == 1); - bottom_node = bottom_node->input()->node(); + bottom_node = bottom_node->input(0)->node(); if (!matchNodes(bottom_node, anchor)) { return false; } + for (const Value* output : pattern_.outputs()) { + AT_ASSERT(values_map_.count(output)); + } + GRAPH_UPDATE("Pattern matched!\n"); return true; } diff --git a/torch/csrc/jit/ir/subgraph_matcher.h b/torch/csrc/jit/ir/subgraph_matcher.h index afb816948a63e..cd8d733da7554 100644 --- a/torch/csrc/jit/ir/subgraph_matcher.h +++ b/torch/csrc/jit/ir/subgraph_matcher.h @@ -12,9 +12,9 @@ namespace jit { * \brief A structure describing a match of a pattern in a graph. * * The structure contains an anchor node, from which the match was found, and - * match-maps for nodes and values. A match-map specifies correspondance between - * nodes in the pattern graph (match-map keys) with nodes in the actual graph - * (match-map values). We keep such maps for both nodes and values. + * match-maps for nodes and values. A match-map specifies the correspondance + * between nodes in the pattern graph (match-map keys) with nodes in the actual + * graph (match-map values). We keep such maps for both nodes and values. */ struct Match { Node* anchor; @@ -33,18 +33,39 @@ struct Match { * - Matched subgraphs do not span across different blocks. * - No uses outside the match are allowed, except for Param and Return nodes. * Basically, we're matching hammocks, not arbitrary subgraphs. - * - Pattern graph must return only one value (i.e. it must have a single + * - The pattern graph must return only one value (i.e. it must have a single * node leading to return). * - Nodes that are not used in computation of the return value in the pattern * graph are ignored during matching (IOW, we're essentially performing DCE on * the pattern). * - Pattern graph nodes cannot alias. TODO: the check not implemented yet. - * - Aliasing nodes in the graph can not consitute a match (i.e. in all found - * matches no nodes in the subgraph alias with each other). TODO: the check not - * implemented yet. - * - The matcher will not mutate either the pattern graph or the matched graph, - * but the latter is taken as non-const so that Match may contain non-const + * - Aliasing nodes in the graph cannot consitute a match (i.e. through all + * found matches, no nodes in the subgraph alias with each other). TODO: check + * not implemented yet. + * - The matcher will not mutate either the pattern graph or the matched graph. + * The matched graph is taken as non-const so that Match may contain non-const * pointers. This enables clients of this API to use Match to drive mutations. + * + * Note [Multi-output Patterns] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Subgraph matcher provides limited support for multi-output patterns. With a + * single output pattern, a single scan through the graph is sufficient to + * find all the matches: given a starting node (an "anchor"), we can + * deterministically check whether a pattern matches a subgraph corresponding to + * this anchor node. For a general case of multi-output patterns, we would have + * N anchors, which would result in M^N comparisons (M is the size of the + * graph). Clearly this is computationally prohibitive. + * + * To overcome this, we impose some constraints on the multi-output patterns + * that we accept. We require that checking whether the pattern matches a + * subgraph would still be fully determined by a single node in the graph. To + * achieve this, we designate the first output in the pattern as the "main" + * output and assume that we can traverse up from this node to match the + * entire pattern. + * + * Corrolary 1: the order of outputs in the pattern matters! + * Corollary 2: patterns cannot contain any nodes not participating in the main + * output computation. */ std::vector TORCH_API findPatternMatches(const Graph& pattern, Graph& graph); diff --git a/torch/csrc/jit/ir/type_hashing.cpp b/torch/csrc/jit/ir/type_hashing.cpp index a03f6508216fd..c447cebefc4f6 100644 --- a/torch/csrc/jit/ir/type_hashing.cpp +++ b/torch/csrc/jit/ir/type_hashing.cpp @@ -1,4 +1,5 @@ #include + #include #include #include diff --git a/torch/csrc/jit/jit_log.h b/torch/csrc/jit/jit_log.h index 85e05edeb8a64..3abbdb9052f84 100644 --- a/torch/csrc/jit/jit_log.h +++ b/torch/csrc/jit/jit_log.h @@ -25,15 +25,16 @@ // * `GRAPH_DEBUG` should be used for providing information useful for debugging // the internals of a particular optimization pass or analysis -// The current logging level is `GRAPH_UPDATE` meaning that both `GRAPH_DUMP` -// and `GRAPH_UPDATE` will be enabled when -// one specifies a file(s) in `PYTORCH_JIT_LOG_LEVEL`. +// The default logging level is `GRAPH_DUMP` meaning that only `GRAPH_DUMP` +// statements will be enabled when one specifies a file(s) in +// `PYTORCH_JIT_LOG_LEVEL`. -// `GRAPH_DEBUG` can be enabled by prefixing a file name with an `>` as in +// `GRAPH_UPDATE` can be enabled by prefixing a file name with an `>` as in // `>alias_analysis`. -// `>>` and `>>>` are also valid and **currently** are equivalent to -// `GRAPH_DEBUG` as there is no logging level that is -// higher than `GRAPH_DEBUG`. +// `GRAPH_DEBUG` can be enabled by prefixing a file name with an `>>` as in +// `>>alias_analysis`. +// `>>>` is also valid and **currently** is equivalent to `GRAPH_DEBUG` as there +// is no logging level that is higher than `GRAPH_DEBUG`. namespace torch { namespace jit { @@ -97,5 +98,12 @@ TORCH_API std::ostream& operator<<( // pass #define GRAPH_DEBUG(...) \ JIT_LOG(::torch::jit::JitLoggingLevels::GRAPH_DEBUG, __VA_ARGS__); + +#define GRAPH_DUMP_ENABLED \ + (is_enabled(__FILE__, ::torch::jit::JitLoggingLevels::GRAPH_DUMP)) +#define GRAPH_UPDATE_ENABLED \ + (is_enabled(__FILE__, ::torch::jit::JitLoggingLevels::GRAPH_UPDATE)) +#define GRAPH_DEBUG_ENABLED \ + (is_enabled(__FILE__, ::torch::jit::JitLoggingLevels::GRAPH_DEBUG)) } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/jit_opt_limit.cpp b/torch/csrc/jit/jit_opt_limit.cpp new file mode 100644 index 0000000000000..749f12197a0f9 --- /dev/null +++ b/torch/csrc/jit/jit_opt_limit.cpp @@ -0,0 +1,86 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +std::unordered_map& passes_to_current_counter() { + static std::unordered_map passes_to_current_counter; + return passes_to_current_counter; +} + +static int parseOptLimit(const std::string& opt_limit) { + try { + int64_t n = c10::stoi(opt_limit); + return n; + } catch (...) { + return -1; + } +} + +static std::unordered_map parseJITOptLimitOption( + const char* option) { + std::stringstream in_ss; + if (option) { + in_ss << option; + } + std::unordered_map passes_to_opt_limits; + std::string line; + while (std::getline(in_ss, line, ':')) { + if (line.size() == 0) { + continue; + } + auto index_at = line.find_last_of('='); + auto pass_name = line.substr(0, index_at); + pass_name = c10::detail::ExcludeFileExtension(pass_name); + auto opt_limit = parseOptLimit(line.substr(index_at + 1)); + passes_to_opt_limits.insert({pass_name, opt_limit}); + } + + return passes_to_opt_limits; +} + +bool opt_limit(const char* pass_name) { + static const char* opt_limit = std::getenv("PYTORCH_JIT_OPT_LIMIT"); + // if nothing is provided, let's allow everything + if (!opt_limit) { + return true; + } + + static const std::unordered_map passes_to_opt_limits = + parseJITOptLimitOption(opt_limit); + std::string pass{pass_name}; + pass = c10::detail::StripBasename(pass); + pass = c10::detail::ExcludeFileExtension(pass); + + auto opt_limit_it = passes_to_opt_limits.find(pass); + if (opt_limit_it == passes_to_opt_limits.end()) { + return true; + } + + auto current_count_it = passes_to_current_counter().find(pass); + if (current_count_it == passes_to_current_counter().end()) { + passes_to_current_counter().insert({pass, 0}); + } + + current_count_it = passes_to_current_counter().find(pass); + if (current_count_it->second >= opt_limit_it->second) { + return false; + } + + current_count_it->second++; + return true; +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/jit_opt_limit.h b/torch/csrc/jit/jit_opt_limit.h new file mode 100644 index 0000000000000..b5dee8dcb76cd --- /dev/null +++ b/torch/csrc/jit/jit_opt_limit.h @@ -0,0 +1,40 @@ +#pragma once +#include +#include +#include + +// `TorchScript` offers a simple optimization limit checker +// that can be configured through environment variable `PYTORCH_JIT_OPT_LIMIT`. +// The purpose is to limit how many optimization you can make per pass. +// This is useful for debugging any passes. + +// Opt limit checker is enabled on a per file basis (hence per pass). For +// example, in `constant_propagation.cpp`, `PYTORCH_JIT_OPT_LIMIT` should be set +// to `constant_propagation=` or, simply, to +// `constant_propagation=` where is the number of +// optimizations you want to make for the pass. (i.e. +// `PYTORCH_JIT_OPT_LIMIT="constant_propagation="`). + +// Multiple files can be configured by separating each file name with a colon +// `:` as in the following example, +// `PYTORCH_JIT_OPT_LIMIT="constant_propagation=:dead_code_elimination="` + +// You can call opt limiter by calling JIT_OPT_LIMIT(). It will return true if +// we haven't reached the optimization limit yet. Otherwise, it will return +// false. Typical usage: + +// auto allowed = JIT_OPT_LIMIT(); +// if (!allowed) { +// GRAPH_DUMP(...); //supplied from jit_log +// return; +// } + +namespace torch { +namespace jit { + +TORCH_API bool opt_limit(const char* pass_name); + +#define JIT_OPT_LIMIT() opt_limit(__FILE__); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/mobile/export_data.cpp b/torch/csrc/jit/mobile/export_data.cpp index 3bd28fbbac5cc..a6eff3192bfdd 100644 --- a/torch/csrc/jit/mobile/export_data.cpp +++ b/torch/csrc/jit/mobile/export_data.cpp @@ -36,7 +36,7 @@ class ScriptModuleSerializer { void writeArchive(const std::string& archive_name, const IValue& value) { std::vector data; // Vector to capture the run-time class types during pickling the IValues - std::vector memorizedClassTypes; + std::vector memoizedClassTypes; Pickler data_pickle( [&](const char* buf, size_t size) { data.insert(data.end(), buf, buf + size); @@ -45,7 +45,7 @@ class ScriptModuleSerializer { [&](const c10::ClassTypePtr& t) { return type_name_uniquer_.getUniqueName(t); }, - &memorizedClassTypes); + &memoizedClassTypes); data_pickle.protocol(); data_pickle.pushIValue(value); data_pickle.stop(); diff --git a/torch/csrc/jit/mobile/function.cpp b/torch/csrc/jit/mobile/function.cpp index 07dd6b47ab7d2..47d492ccf8625 100644 --- a/torch/csrc/jit/mobile/function.cpp +++ b/torch/csrc/jit/mobile/function.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -118,6 +119,10 @@ c10::IValue Function::operator()(Stack& stack) { return stack.front(); } +const std::shared_ptr Function::get_code() const { + return code_; +} + } // namespace mobile } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/mobile/function.h b/torch/csrc/jit/mobile/function.h index c5ec851e60b54..749633d4916c3 100644 --- a/torch/csrc/jit/mobile/function.h +++ b/torch/csrc/jit/mobile/function.h @@ -31,6 +31,7 @@ class Function { void set_register_size(size_t size); std::string get_module_debug_info(size_t pc) const; + const std::shared_ptr get_code() const; private: c10::QualifiedName name_; diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index e26177605674c..fd3bb9496da9b 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -86,8 +87,7 @@ void print_unsupported_ops_and_throw( TORCH_CHECK( false, "Following ops cannot be found. ", - "May need to add them explicitly to the selective build operator whitelist, ", - "or re-run the export_opnames to update the whitelist:", + "Check fburl.com/missing_ops for the fix.", error_message); } @@ -123,6 +123,7 @@ void parseMethods( "The numbers of bytecode values and debug info values do not match."); } + // Process all methods in this mobile module. for (size_t i = method_i_start; i < vals.size(); ++i) { const auto& element = vals[i]; const auto& m_tuple = element.toTuple()->elements(); @@ -191,6 +192,8 @@ void parseMethods( } std::unordered_set unsupported_op_names; + // ops_list is the list of operator names that were read in from + // bytecode.plk for the method that is currently being processed. for (const auto& op : ops_list) { auto op_item = op.toTuple()->elements(); TORCH_CHECK( @@ -227,7 +230,9 @@ void parseMethods( class BytecodeDeserializer final { public: explicit BytecodeDeserializer(std::unique_ptr reader); - mobile::Module deserialize(c10::optional device); + mobile::Module deserialize( + c10::optional device, + ExtraFilesMap& extra_files); std::unordered_map deserializeMetadata( c10::optional device); @@ -256,9 +261,30 @@ std::unordered_map BytecodeDeserializer:: } mobile::Module BytecodeDeserializer::deserialize( - c10::optional device) { + c10::optional device, + ExtraFilesMap& extra_files) { device_ = device; + for (const auto& kv : extra_files) { + const std::string& key = "extra/" + kv.first; + if (reader_->hasRecord(key)) { + at::DataPtr meta_ptr; + size_t meta_size = 0; + std::tie(meta_ptr, meta_size) = reader_->getRecord(key); + extra_files[kv.first] = + std::string(static_cast(meta_ptr.get()), meta_size); + } + } auto mcu = std::make_shared(); + + // bvals can have 2 possible formats: + // + // 1. Old format: bvals is an array (Tuple) of N elements, each element being + // itself a Tuple(method_name, method_table). + // + // 2. New format: bvals is an array (Tuple) of 1+N elements. The first element + // being a Tuple (int, table), and the integer stands for the bytecode version + // number. The rest of the elements are the same as before. + // auto bvals = readArchive("bytecode", mcu).toTuple()->elements(); c10::optional> debug_info_bvals; @@ -385,44 +411,50 @@ c10::IValue BytecodeDeserializer::readArchive( mobile::Module _load_for_mobile( std::istream& in, - c10::optional device) { + c10::optional device, + ExtraFilesMap& extra_files) { std::unique_ptr rai = std::make_unique(&in); - auto module = _load_for_mobile(std::move(rai), device); + auto module = _load_for_mobile(std::move(rai), device, extra_files); return module; } mobile::Module _load_for_mobile( const std::string& filename, - c10::optional device) { + c10::optional device, + ExtraFilesMap& extra_files) { std::unique_ptr rai = std::make_unique(filename); - auto module = _load_for_mobile(std::move(rai), device); + auto module = _load_for_mobile(std::move(rai), device, extra_files); return module; } mobile::Module _load_for_mobile( std::unique_ptr rai, - c10::optional device) { + c10::optional device, + ExtraFilesMap& extra_files) { auto observer = torch::observerConfig().getModuleObserver(); + auto instance_key = std::rand(); if (observer) { - observer->onEnterLoadModel(); + observer->onEnterLoadModel(instance_key); } auto reader = torch::make_unique(std::move(rai)); BytecodeDeserializer deserializer(std::move(reader)); try { - mobile::Module result = deserializer.deserialize(std::move(device)); + mobile::Module result = deserializer.deserialize(device, extra_files); std::unordered_map copied_metadata = result.metadata(); if (result.metadata().find("model_name") == result.metadata().end()) { copied_metadata["model_name"] = result.name(); } if (observer) { - observer->onExitLoadModel(copied_metadata); + observer->onExitLoadModel(instance_key, copied_metadata); } return result; } catch (c10::Error& error) { if (observer) { observer->onFailLoadModel( - error.what(), deserializer.deserializeMetadata(std::move(device))); + instance_key, + error.what(), + deserializer.deserializeMetadata(std::move(device))); } TORCH_RETHROW(error); } catch (...) { @@ -440,7 +472,9 @@ mobile::Module _load_for_mobile( } catch (c10::Error& error) { if (observer) { observer->onFailLoadModel( - error.what(), deserializer.deserializeMetadata(std::move(device))); + instance_key, + error.what(), + deserializer.deserializeMetadata(std::move(device))); } TORCH_RETHROW(error); } diff --git a/torch/csrc/jit/mobile/import.h b/torch/csrc/jit/mobile/import.h index f6376c19cf895..ebfd0f3d9557e 100644 --- a/torch/csrc/jit/mobile/import.h +++ b/torch/csrc/jit/mobile/import.h @@ -11,17 +11,24 @@ namespace jit { using caffe2::serialize::FileAdapter; using caffe2::serialize::IStreamAdapter; using caffe2::serialize::ReadAdapterInterface; +using ExtraFilesMap = std::unordered_map; +static ExtraFilesMap default_extra_files_mobile; +// The family of methods below convery a serialized Mobile Module +// into a mobile::Module object. TORCH_API mobile::Module _load_for_mobile( std::istream& in, - c10::optional device = c10::nullopt); + c10::optional device = c10::nullopt, + ExtraFilesMap& extra_files = default_extra_files_mobile); TORCH_API mobile::Module _load_for_mobile( const std::string& filename, - c10::optional device = c10::nullopt); + c10::optional device = c10::nullopt, + ExtraFilesMap& extra_files = default_extra_files_mobile); TORCH_API mobile::Module _load_for_mobile( std::unique_ptr rai, - c10::optional device = c10::nullopt); + c10::optional device = c10::nullopt, + ExtraFilesMap& extra_files = default_extra_files_mobile); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/mobile/import_data.cpp b/torch/csrc/jit/mobile/import_data.cpp index a1dacef293981..6ded78b1f56d1 100644 --- a/torch/csrc/jit/mobile/import_data.cpp +++ b/torch/csrc/jit/mobile/import_data.cpp @@ -171,8 +171,9 @@ mobile::Module _load_data( std::unique_ptr rai, c10::optional device) { auto observer = torch::observerConfig().getModuleObserver(); + auto instance_key = std::rand(); if (observer) { - observer->onEnterLoadModel(); + observer->onEnterLoadModel(instance_key); } try { auto reader = torch::make_unique(std::move(rai)); @@ -186,12 +187,12 @@ mobile::Module _load_data( copied_metadata["model_name"] = result.name(); } if (observer) { - observer->onExitLoadModel(copied_metadata); + observer->onExitLoadModel(instance_key, copied_metadata); } return result; } catch (c10::Error& error) { if (observer) { - observer->onFailLoadModel(error.what()); + observer->onFailLoadModel(instance_key, error.what()); } TORCH_RETHROW(error); } catch (...) { @@ -208,7 +209,7 @@ mobile::Module _load_data( } } catch (c10::Error& error) { if (observer) { - observer->onFailLoadModel(error.what()); + observer->onFailLoadModel(instance_key, error.what()); } TORCH_RETHROW(error); } diff --git a/torch/csrc/jit/mobile/interpreter.cpp b/torch/csrc/jit/mobile/interpreter.cpp index 51f9e17e522eb..008c42b5c7ef4 100644 --- a/torch/csrc/jit/mobile/interpreter.cpp +++ b/torch/csrc/jit/mobile/interpreter.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -36,12 +37,10 @@ bool InterpreterState::run(Stack& stack) { switch (inst.op) { case OP: { if (at::hasGlobalCallbacks()) { - if (auto debug_info = c10::ThreadLocalDebugInfo::get( - c10::DebugInfoKind::MOBILE_RUNTIME_INFO)) { - if (auto* mobile_debug_info = - dynamic_cast(debug_info.get())) { - mobile_debug_info->setOpIdx(pc); - } + if (auto* mobile_debug_info = + static_cast(c10::ThreadLocalDebugInfo::get( + c10::DebugInfoKind::MOBILE_RUNTIME_INFO))) { + mobile_debug_info->setOpIdx(pc); } } @@ -52,7 +51,7 @@ bool InterpreterState::run(Stack& stack) { // enable only for the RecordFunction enableRecordFunction(true); } - RECORD_FUNCTION(code_->op_names_[inst.X].name, stack); + RECORD_USER_SCOPE_WITH_INPUTS(code_->op_names_[inst.X].name, stack); if (!prev_value) { enableRecordFunction(false); } @@ -150,7 +149,7 @@ bool InterpreterState::run(Stack& stack) { case RET: return false; case LIST_CONSTRUCT: { - auto type = code_->types_[inst.X]->expect(); + const auto& type = code_->types_[inst.X]->expectRef(); listConstruct(stack, type, inst.N); ++pc; } break; diff --git a/torch/csrc/jit/mobile/interpreter.h b/torch/csrc/jit/mobile/interpreter.h index 86eac344b25a0..6380c2edbe3ff 100644 --- a/torch/csrc/jit/mobile/interpreter.h +++ b/torch/csrc/jit/mobile/interpreter.h @@ -1,8 +1,8 @@ #pragma once -#include #include #include #include +#include namespace torch { namespace jit { diff --git a/torch/csrc/jit/mobile/method.h b/torch/csrc/jit/mobile/method.h index 89d333e555aae..00e820c3c1861 100644 --- a/torch/csrc/jit/mobile/method.h +++ b/torch/csrc/jit/mobile/method.h @@ -1,3 +1,5 @@ +#pragma once + #include #include diff --git a/torch/csrc/jit/mobile/module.cpp b/torch/csrc/jit/mobile/module.cpp index fc8cde35aabfc..6a9ab44ed4f63 100644 --- a/torch/csrc/jit/mobile/module.cpp +++ b/torch/csrc/jit/mobile/module.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -116,20 +117,30 @@ bool Module::is_training() const { return true; } +const std::vector Module::get_methods() const { + std::vector methods; + for (std::unique_ptr& fn : cu_->methods()) { + methods.emplace_back(this, fn.get()); + } + return methods; +} + Method::Method(const Module* owner, Function* function) : owner_(owner), function_(function) {} void Method::run(Stack& stack) { auto observer = torch::observerConfig().getModuleObserver(); + auto instance_key = std::rand(); /* if the metadata dict doesn't contain "model_name", copy the metadata and set the value of "model_name" as name() */ std::unordered_map copied_metadata = owner_->metadata(); if (owner_->metadata().find("model_name") == owner_->metadata().end()) { - copied_metadata["model_name"] = name(); + copied_metadata["model_name"] = owner_->name(); } if (observer) { - observer->onEnterRunMethod(copied_metadata, function_->name()); + observer->onEnterRunMethod( + copied_metadata, instance_key, function_->name()); } auto debug_info = std::make_shared(); @@ -142,11 +153,11 @@ void Method::run(Stack& stack) { stack.insert(stack.begin(), owner_->_ivalue()); function_->run(stack); if (observer) { - observer->onExitRunMethod(); + observer->onExitRunMethod(instance_key); } } catch (c10::Error& error) { if (observer) { - observer->onFailRunMethod(error.what()); + observer->onFailRunMethod(instance_key, error.what()); } TORCH_RETHROW(error); } catch (...) { @@ -163,14 +174,14 @@ void Method::run(Stack& stack) { } } catch (c10::Error& error) { if (observer) { - observer->onFailRunMethod(error.what()); + observer->onFailRunMethod(instance_key, error.what()); } TORCH_RETHROW(error); } } } -c10::IValue Method::operator()(std::vector stack) { +c10::IValue Method::operator()(std::vector stack) { run(stack); TORCH_INTERNAL_ASSERT(!stack.empty()); return stack.front(); diff --git a/torch/csrc/jit/mobile/module.h b/torch/csrc/jit/mobile/module.h index be66bd84ef75e..2be75c61b6b5f 100644 --- a/torch/csrc/jit/mobile/module.h +++ b/torch/csrc/jit/mobile/module.h @@ -1,5 +1,6 @@ #pragma once //#include +#include #include #include @@ -8,6 +9,27 @@ namespace jit { namespace mobile { using Stack = std::vector; +// A CompilationUnit object is the one that gets executed by the lite +// interpreter. +// +// A CompilationUnit object contains a list of Method Objects. These are methods +// that appear in the original PyTorch Model. These method correspond to Python +// member functions of the Model class. +// +// Methods in turn contain a Function, and a back-pointer to the Module that +// owns this Method instance. +// +// A Function contains a Code Object (code_) which is defined in interpreter.h +// +// A Code object contains the following: +// +// std::vector instructions_; +// std::vector op_names_; +// std::vector> operators_; +// std::vector constants_; +// std::vector types_; +// size_t register_size_; // Aggregated output size. +// class CompilationUnit { public: void register_function(std::unique_ptr fn); @@ -20,6 +42,14 @@ class CompilationUnit { std::vector> methods_; }; +// A Torch Mobile Module is a representation of the model (trained in case +// of inference). A Mobile Module contains +// +// 1. data (object_) +// 2. metadata (optional) about the model (metadata_ from the metadata.pkl +// file added after training) +// 3. Compilation Unit (cu_) +// class TORCH_API Module { public: Module( @@ -43,7 +73,7 @@ class TORCH_API Module { return get_method("forward")(std::move(inputs)); } c10::optional find_method(const std::string& basename) const; - std::string name() { + const std::string name() const { return object_->name(); } const std::vector& slots() const { @@ -66,6 +96,17 @@ class TORCH_API Module { const std::unordered_map metadata() const { return metadata_; } + const std::vector get_methods() const; + + c10::IValue attr(const std::string& name, c10::IValue or_else) const { + if (auto r = object_->type()->findAttributeSlot(name)) { + return object_->getSlot(*r); + } + if (auto r = object_->type()->findConstantSlot(name)) { + return object_->type()->getConstant(*r); + } + return or_else; + } private: c10::intrusive_ptr object_; diff --git a/torch/csrc/jit/mobile/observer.h b/torch/csrc/jit/mobile/observer.h index 2935fa078fc74..6ec2806d83b4d 100644 --- a/torch/csrc/jit/mobile/observer.h +++ b/torch/csrc/jit/mobile/observer.h @@ -70,15 +70,17 @@ class MobileModuleObserver { virtual void onEnterRunMethod( const std::unordered_map&, + const int32_t, const std::string&) {} - virtual void onExitRunMethod() {} - virtual void onCancelRunMethod(const std::string&) {} - virtual void onFailRunMethod(const char*) {} - virtual void onEnterLoadModel() {} + virtual void onExitRunMethod(const int32_t) {} + virtual void onFailRunMethod(const int32_t, const char*) {} + virtual void onEnterLoadModel(const int32_t) {} virtual void onExitLoadModel( + const int32_t, const std::unordered_map&) {} - virtual void onFailLoadModel(const char*) {} + virtual void onFailLoadModel(const int32_t, const char*) {} virtual void onFailLoadModel( + const int32_t, const char*, const std::unordered_map&) {} }; diff --git a/torch/csrc/jit/mobile/type_parser.cpp b/torch/csrc/jit/mobile/type_parser.cpp index 546aabd97094e..42814e5fe5aad 100644 --- a/torch/csrc/jit/mobile/type_parser.cpp +++ b/torch/csrc/jit/mobile/type_parser.cpp @@ -1,4 +1,5 @@ #include + #include #include #include diff --git a/torch/csrc/jit/passes/annotate_warns.cpp b/torch/csrc/jit/passes/annotate_warns.cpp new file mode 100644 index 0000000000000..3e0dc9faa1c16 --- /dev/null +++ b/torch/csrc/jit/passes/annotate_warns.cpp @@ -0,0 +1,29 @@ +#include + +#include + +namespace torch { +namespace jit { + +void AnnotateWarns(Block* b) { + static std::atomic idx(0); + for (Node* n : b->nodes()) { + for (Block* child_b : n->blocks()) { + AnnotateWarns(child_b); + } + + if (n->kind() != aten::warn) { + continue; + } + + n->i_(attr::warn_id, idx); + idx++; + } +} + +void AnnotateWarns(const std::shared_ptr& graph) { + AnnotateWarns(graph->block()); +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/annotate_warns.h b/torch/csrc/jit/passes/annotate_warns.h new file mode 100644 index 0000000000000..18e9f67641e04 --- /dev/null +++ b/torch/csrc/jit/passes/annotate_warns.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + +TORCH_API void AnnotateWarns(const std::shared_ptr& graph); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/bailout_graph.cpp b/torch/csrc/jit/passes/bailout_graph.cpp index 04b1962a72a01..7e721e185f83e 100644 --- a/torch/csrc/jit/passes/bailout_graph.cpp +++ b/torch/csrc/jit/passes/bailout_graph.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -120,7 +121,6 @@ struct BailOutGraphBuilderForNode { auto old_max_count = getOrAddInputForValue(lv.maxTripCount()); auto cur_iter = getInputForValue(lv.currentTripCount()); auto block_outputs = lv.bodyBlock()->outputs(); - auto carried_deps = lv.carriedInputsWithCond(); auto* block = copy_graph_->block(); // subtract the number of iterations diff --git a/torch/csrc/jit/passes/canonicalize.cpp b/torch/csrc/jit/passes/canonicalize.cpp index 5180ce6abff6d..4a41e8ff54ff0 100644 --- a/torch/csrc/jit/passes/canonicalize.cpp +++ b/torch/csrc/jit/passes/canonicalize.cpp @@ -1,4 +1,5 @@ #include + #include namespace torch { diff --git a/torch/csrc/jit/passes/clear_profiling.cpp b/torch/csrc/jit/passes/clear_profiling.cpp index 809d17767a46c..9acb9fbc31291 100644 --- a/torch/csrc/jit/passes/clear_profiling.cpp +++ b/torch/csrc/jit/passes/clear_profiling.cpp @@ -1,5 +1,6 @@ #include + #include namespace torch { diff --git a/torch/csrc/jit/passes/clear_undefinedness.cpp b/torch/csrc/jit/passes/clear_undefinedness.cpp index 591c82464a95e..568441dad1e6b 100644 --- a/torch/csrc/jit/passes/clear_undefinedness.cpp +++ b/torch/csrc/jit/passes/clear_undefinedness.cpp @@ -1,4 +1,5 @@ #include + #include namespace torch { @@ -9,7 +10,7 @@ void clearUndefinedness(Value* o) { o->setType(TensorType::get()); } else if ( o->type()->kind() == ListType::Kind && - o->type()->expect()->getElementType()->kind() == + o->type()->expectRef().getElementType()->kind() == TensorType::Kind) { o->setType(ListType::create(TensorType::get())); } diff --git a/torch/csrc/jit/passes/constant_pooling.cpp b/torch/csrc/jit/passes/constant_pooling.cpp index 2ee9dd13c56cc..06a5d618b9c54 100644 --- a/torch/csrc/jit/passes/constant_pooling.cpp +++ b/torch/csrc/jit/passes/constant_pooling.cpp @@ -1,4 +1,5 @@ #include + #include #include #include diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp index c3285f2e24265..3db17e5658daa 100644 --- a/torch/csrc/jit/passes/constant_propagation.cpp +++ b/torch/csrc/jit/passes/constant_propagation.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -16,7 +17,9 @@ namespace torch { namespace jit { -c10::optional> runNodeIfInputsAreConstant(const Node* n) { +c10::optional> runNodeIfInputsAreConstant( + const Node* n, + bool ignore_custom_classes) { Stack stack; for (auto input : n->inputs()) { if (auto ival = toIValue(input)) { @@ -25,6 +28,7 @@ c10::optional> runNodeIfInputsAreConstant(const Node* n) { return c10::nullopt; } } + switch (n->kind()) { case prim::ListUnpack: { if (stack.back().toList().size() != n->outputs().size()) { @@ -42,7 +46,9 @@ c10::optional> runNodeIfInputsAreConstant(const Node* n) { } break; case prim::ListConstruct: { listConstruct( - stack, n->output()->type()->expect(), n->inputs().size()); + stack, + n->output()->type()->expectRef(), + n->inputs().size()); } break; case prim::DictConstruct: { dictConstruct( @@ -51,20 +57,24 @@ c10::optional> runNodeIfInputsAreConstant(const Node* n) { case prim::CreateObject: { createObject(stack, n->output()->type()->expect()); } break; + case prim::GetAttr: { + auto attr = pop(stack).toObject()->getAttr(n->s(attr::name)); + push(stack, attr); + } break; case prim::isinstance: { isinstance(stack, n->tys(attr::types)); } break; default: { - const auto& the_operator = n->getOperator(); - if (the_operator.schema().is_vararg()) { + const auto maybe_schema = n->maybeSchema(); + if (maybe_schema && maybe_schema->is_vararg()) { // vararg schemas require the number of inputs at the top of the stack // but this is broken in other places in constant prop, so disable it // for now return c10::nullopt; } - auto op = n->getOperation(); try { + auto op = n->getOperation(); op(&stack); } catch (...) { return c10::nullopt; @@ -80,6 +90,12 @@ c10::optional> runNodeIfInputsAreConstant(const Node* n) { return c10::nullopt; } } + // Weak form of const propagation + if (ignore_custom_classes) { + if (v.isCustomClass()) { + return c10::nullopt; + } + } } return stack; } @@ -89,7 +105,7 @@ namespace { std::unordered_set skip_list = { prim::If, prim::Loop, - prim::Function, + prim::Closure, prim::Constant, prim::AutogradZero, prim::Uninitialized, @@ -104,14 +120,16 @@ std::unordered_set skip_list = { struct ConstantPropagator { // Runs constant propagation with an aliasing db and checks if inputs or // outputs might be mutated in the graph - static ConstantPropagator WithAliasDb(std::shared_ptr graph) { - return ConstantPropagator(graph, true); + static ConstantPropagator WithAliasDb( + std::shared_ptr graph, + bool ignore_custom_classes) { + return ConstantPropagator(std::move(graph), true, ignore_custom_classes); } // Runs constant propagation only on ops that clearly do not have aliased // inputs or outputs without computing aliasing information static ConstantPropagator NoAliasDb(std::shared_ptr graph) { - return ConstantPropagator(graph, false); + return ConstantPropagator(std::move(graph), false, false); } void run() { @@ -119,18 +137,23 @@ struct ConstantPropagator { } private: - ConstantPropagator(std::shared_ptr graph, bool aliasing_types) + ConstantPropagator( + std::shared_ptr graph, + bool aliasing_types, + bool ignore_custom_classes) : graph_(std::move(graph)) { if (aliasing_types) { aliasDb_ = torch::make_unique(graph_); } else { aliasDb_ = nullptr; } + ignore_custom_classes_ = ignore_custom_classes; } void propagateNode(Node* n) { std::vector outputs; - if (auto outputs_opt = runNodeIfInputsAreConstant(n)) { + if (auto outputs_opt = + runNodeIfInputsAreConstant(n, ignore_custom_classes_)) { outputs = std::move(outputs_opt.value()); } else { // The op failed to run, so we cannot continue constant-prop for it. @@ -353,11 +376,15 @@ struct ConstantPropagator { std::shared_ptr graph_; std::unique_ptr aliasDb_; + bool ignore_custom_classes_; }; } // anonymous namespace -void ConstantPropagation(std::shared_ptr& graph) { - ConstantPropagator cp = ConstantPropagator::WithAliasDb(graph); +void ConstantPropagation( + std::shared_ptr& graph, + bool ignore_custom_classes) { + ConstantPropagator cp = + ConstantPropagator::WithAliasDb(graph, ignore_custom_classes); cp.run(); EliminateDeadCode(graph); GRAPH_DUMP("After ConstantPropagation: ", graph); diff --git a/torch/csrc/jit/passes/constant_propagation.h b/torch/csrc/jit/passes/constant_propagation.h index 3164c06927296..a02fc0fd994c5 100644 --- a/torch/csrc/jit/passes/constant_propagation.h +++ b/torch/csrc/jit/passes/constant_propagation.h @@ -5,15 +5,25 @@ namespace torch { namespace jit { -TORCH_API void ConstantPropagation(std::shared_ptr& graph); +// Runs constant propagation on all objects unless ignore_custom_classes is +// specified as true, in which case user defined classes are skipped. This is +// useful to prevent early fusion of packing operations, which end up lowering +// away information about their constructors (e.g. packed::linear_clamp_prepack +// and prepacked::conv2d_clamp_prepack) +TORCH_API void ConstantPropagation( + std::shared_ptr& graph, + bool ignore_custom_classes = false); // runs constant propagation only on ops that have non-aliasing inputs & outputs TORCH_API void ConstantPropagationImmutableTypes(std::shared_ptr& graph); // Runs the node if its inputs are constants. Callers of this function must // make their own determination if constant prop is appropriate - for example -// non-deterministic ops or ops with side effects -TORCH_API c10::optional runNodeIfInputsAreConstant(const Node* node); +// non-deterministic ops or ops with side effects. If ignore_custom_classes is +// specified, nodes that output user defined classes are not run. +TORCH_API c10::optional runNodeIfInputsAreConstant( + const Node* node, + bool ignore_custom_classes = false); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp index 6ac510b137775..78bffb34dc68d 100644 --- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -270,7 +271,9 @@ std::vector CreateAutodiffSubgraphs( size_t threshold) { std::vector diff_nodes; AliasDb db(graph); + GRAPH_DEBUG("Before creating autodiff subgraphs", *graph); SubgraphSlicer(graph->block(), graph, threshold, db, diff_nodes).run(); + GRAPH_DEBUG("After creating autodiff subgraphs", *graph); return diff_nodes; } } // namespace jit diff --git a/torch/csrc/jit/passes/create_functional_graphs.cpp b/torch/csrc/jit/passes/create_functional_graphs.cpp index a5c2c6f3c956f..d5d85f6f5b2ac 100644 --- a/torch/csrc/jit/passes/create_functional_graphs.cpp +++ b/torch/csrc/jit/passes/create_functional_graphs.cpp @@ -1,4 +1,5 @@ #include + #include #include #include diff --git a/torch/csrc/jit/passes/cuda_graph_fuser.h b/torch/csrc/jit/passes/cuda_graph_fuser.h index 0f821845613b2..104a437104aa4 100644 --- a/torch/csrc/jit/passes/cuda_graph_fuser.h +++ b/torch/csrc/jit/passes/cuda_graph_fuser.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include diff --git a/torch/csrc/jit/passes/decompose_ops.cpp b/torch/csrc/jit/passes/decompose_ops.cpp index 2d0f70a0119cc..d7ca56973129a 100644 --- a/torch/csrc/jit/passes/decompose_ops.cpp +++ b/torch/csrc/jit/passes/decompose_ops.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -38,7 +39,7 @@ bool isDecomposableNorm(Node* normalize_op) { if (!input->type()->isSubtypeOf(TensorType::get())) { return false; } - auto device = input->type()->expect()->device(); + auto device = input->type()->expectRef().device(); // As of now, we do the decomposition for batchnorm/layernorm on GPU device // only if (!device || (*device).is_cpu()) { @@ -125,12 +126,13 @@ bool DecomposeOps(Block* block, CompilationUnit& decompose_funcs) { Graph* graph = it->owningGraph(); Value* input = it->namedInput(attr::input); Value* input_dim = graph->insert(aten::dim, {input}); - std::vector inputs{input, - it->namedInput(attr::running_mean), - it->namedInput(attr::running_var), - it->namedInput(attr::training), - it->namedInput(attr::momentum), - it->namedInput(attr::eps)}; + std::vector inputs{ + input, + it->namedInput(attr::running_mean), + it->namedInput(attr::running_var), + it->namedInput(attr::training), + it->namedInput(attr::momentum), + it->namedInput(attr::eps)}; // inline the compiled decomposed batchnorm std::shared_ptr d_graph = @@ -161,10 +163,11 @@ bool DecomposeOps(Block* block, CompilationUnit& decompose_funcs) { decomposed = true; WithInsertPoint insert_guard{*it}; Graph* graph = it->owningGraph(); - std::vector inputs{it->namedInput(attr::input), - it->namedInput(attr::normalized_shape), - it->namedInput(attr::eps), - it->namedInput(attr::cudnn_enable)}; + std::vector inputs{ + it->namedInput(attr::input), + it->namedInput(attr::normalized_shape), + it->namedInput(attr::eps), + it->namedInput(attr::cudnn_enable)}; // inline the compiled decomposed layernorm std::shared_ptr d_graph = diff --git a/torch/csrc/jit/passes/erase_number_types.cpp b/torch/csrc/jit/passes/erase_number_types.cpp index 7e6a45351319e..738bb48862fd6 100644 --- a/torch/csrc/jit/passes/erase_number_types.cpp +++ b/torch/csrc/jit/passes/erase_number_types.cpp @@ -1,4 +1,5 @@ #include + #include #include diff --git a/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp b/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp index 8a00e91eb9c4e..0bffc35f7bdfd 100644 --- a/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp +++ b/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp @@ -144,7 +144,6 @@ struct ConvertTracedAttrReferences { // the correctness of CSE over GetAttr Nodes (i think) std::unordered_map local_remaps; - auto prefix_atoms = prefix.atoms(); for (Node* n : b->nodes()) { // The only difference between these two branches is for // TracedModuleForward we advance the scope, but for other @@ -242,7 +241,7 @@ struct ConvertTracedAttrReferences { // add block and Node outputs to lift it into a scope in which // it dominates the Use. struct MakeDefsDominateUses { - MakeDefsDominateUses() {} + MakeDefsDominateUses() = default; void run(Block* b) { processNode(b->param_node(), b); diff --git a/torch/csrc/jit/passes/fold_conv_bn.cpp b/torch/csrc/jit/passes/fold_conv_bn.cpp index 7d344632838af..cc772c3d5f481 100644 --- a/torch/csrc/jit/passes/fold_conv_bn.cpp +++ b/torch/csrc/jit/passes/fold_conv_bn.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -9,19 +10,20 @@ namespace torch { namespace jit { +std::tuple computeUpdatedConvWeightAndBias( + const ConvBNParameters& p) { + at::Tensor bn_var_rsqrt = at::rsqrt(p.bn_rv + p.bn_eps); + const int64_t ndim = p.conv_w.dim(); + at::DimVector sizes(ndim, 1); + sizes.at(0) = -1; + at::Tensor new_w = p.conv_w * (p.bn_w * bn_var_rsqrt).reshape(sizes); + at::Tensor new_b = (p.conv_b - p.bn_rm) * bn_var_rsqrt * p.bn_w + p.bn_b; + return std::make_tuple(new_w, new_b); +} + namespace { using graph_rewrite_helper::PatternInfo; -struct ConvBNParameters { - at::Tensor conv_w; - at::Tensor conv_b; - at::Tensor bn_rm; - at::Tensor bn_rv; - double bn_eps = 0.0; - at::Tensor bn_w; - at::Tensor bn_b; -}; - static bool hastensor(Module& m, const char* name) { return m.hasattr(name) && m.attr(name).isTensor(); } @@ -115,16 +117,6 @@ class FoldConvBatchNormHelper { Module& bn, ConvBNParameters& r); - /** - * Given the current weight and bias tensors of a Conv module and parameters - * of the BatchNorm module we're folding with, compute the updated values - * for the weight and bias. - * - * The function is basically copied from torch/nn/utils/fusion.py - */ - std::tuple computeUpdatedConvWeightAndBias( - const ConvBNParameters& p); - std::unordered_map> conv_module_and_params_; @@ -149,17 +141,6 @@ class FoldConvBatchNormHelper { std::unordered_set nodes_to_delete_; }; -std::tuple FoldConvBatchNormHelper:: - computeUpdatedConvWeightAndBias(const ConvBNParameters& p) { - at::Tensor bn_var_rsqrt = at::rsqrt(p.bn_rv + p.bn_eps); - const int64_t ndim = p.conv_w.dim(); - at::DimVector sizes(ndim, 1); - sizes.at(0) = -1; - at::Tensor new_w = p.conv_w * (p.bn_w * bn_var_rsqrt).reshape(sizes); - at::Tensor new_b = (p.conv_b - p.bn_rm) * bn_var_rsqrt * p.bn_w + p.bn_b; - return std::make_tuple(new_w, new_b); -} - bool extractOptionalBNParams(const script::Module& bn, ConvBNParameters& r) { auto bn_forward = bn.get_method("forward"); auto graph = bn_forward.graph(); diff --git a/torch/csrc/jit/passes/fold_conv_bn.h b/torch/csrc/jit/passes/fold_conv_bn.h index 54b27c6eec6f2..fbedc4ea030f6 100644 --- a/torch/csrc/jit/passes/fold_conv_bn.h +++ b/torch/csrc/jit/passes/fold_conv_bn.h @@ -13,5 +13,25 @@ namespace jit { */ TORCH_API Module FoldConvBatchNorm(const Module& module); +struct TORCH_API ConvBNParameters { + at::Tensor conv_w; + at::Tensor conv_b; + at::Tensor bn_rm; + at::Tensor bn_rv; + double bn_eps = 0.0; + at::Tensor bn_w; + at::Tensor bn_b; +}; + +/** + * Given the current weight and bias tensors of a Conv module and parameters + * of the BatchNorm module we're folding with, compute the updated values + * for the weight and bias. + * + * The function is basically copied from torch/nn/utils/fusion.py + */ +TORCH_API std::tuple computeUpdatedConvWeightAndBias( + const ConvBNParameters& p); + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/freeze_module.cpp b/torch/csrc/jit/passes/freeze_module.cpp index bec7bf144201f..18e90a2ba4ed3 100644 --- a/torch/csrc/jit/passes/freeze_module.cpp +++ b/torch/csrc/jit/passes/freeze_module.cpp @@ -1,7 +1,9 @@ #include + #include #include +#include #include #include @@ -11,38 +13,29 @@ namespace torch { namespace jit { namespace { -ModulePtr getModulePtrForGetAttrNode( - const Node* node, - const std::shared_ptr& graph, - const Module& graph_input_module) { - std::vector names; - names.clear(); - while (!(node->outputs()[0]->type() == graph->inputs()[0]->type())) { - TORCH_INTERNAL_ASSERT( - node->kind() == prim::GetAttr, "Expected prim::GetAttr nodes"); - names.insert(names.begin(), node->s(attr::name)); - node = node->inputs()[0]->node(); - } - // Copy/paste from quantization/helper.h - Module m = graph_input_module; - for (const auto& p : names) { - m = m.attr(p).toModule(); - } - return m._ivalue(); -} class AttributePropagator { public: AttributePropagator( Module& module, std::vector& preservedAttrs, - bool freezeInterfaces) - : module_(module), freezeInterfaces_(freezeInterfaces) { + bool freezeInterfaces, + bool preserveParameters) + : module_(module), + freezeInterfaces_(freezeInterfaces), + preserveParameters_(preserveParameters) { // Currently only top level attributes and functions can be preserved // explicitly. auto checkName = [this](std::string& name) { if (module_.hasattr(name)) { - insertMutableAttr(name, module_.attr(name), module_._ivalue()); + auto attr = module_.attr(name); + + // Freezing client wants to presever this submodule. When cleaning + // the frozen module, make sure it will be preserved entirely. + if (attr.isModule()) { + preservedSubModule_.insert(attr.toModule()._ivalue()); + } + insertMutableAttr(name, attr, module_._ivalue()); return true; } @@ -87,9 +80,11 @@ class AttributePropagator { void run() { auto applyInline = [](std::shared_ptr& subgraph) { Inline(*subgraph); + ClearProfilingInformation(subgraph); }; auto applyOptimizations = [](std::shared_ptr& subgraph) { - runOptimization(subgraph, /* unroll? */ false); + runOptimization( + subgraph, /* unroll? */ false, /* const_prop_user_classes? */ false); }; for (auto function : preservedMethods_) { @@ -161,7 +156,7 @@ class AttributePropagator { Module& attrModule, std::shared_ptr& graph) { if (!input->type()->cast() && - !input->type()->expect()->is_module()) { + !input->type()->expectRef().is_module()) { return false; } @@ -224,6 +219,13 @@ class AttributePropagator { for (Block* sub_block : n->blocks()) { blocks.push(sub_block); } + + // Modules with prim::ModuleDictIndex cannot be frozen because they + // return InterfaceTypes. + TORCH_CHECK( + n->kind() != prim::ModuleDictIndex, + "Freezing modules containing prim::ModuleDictIndex is not supported"); + if (n->kind() == prim::SetAttr || n->kind() == prim::GetAttr) { // By default if interface attributes are present then fail freezing. // If freezingInterfaces is on then Interfaces are folded similarly @@ -288,11 +290,11 @@ class AttributePropagator { IValue overrideGradient(IValue attr) { if (attr.isTensor()) { - auto t = attr.toTensor(); + auto& t = attr.toTensor(); if (t.requires_grad()) { - t = t.detach(); - t.set_requires_grad(false); - attr = IValue(t); + auto detached = t.detach(); + detached.set_requires_grad(false); + attr = IValue(std::move(detached)); } } else if (attr.isTuple()) { auto tuple = std::move(attr).toTuple(); @@ -315,6 +317,15 @@ class AttributePropagator { val = overrideGradient(val); } attr = std::move(dict); + } else if (attr.isObject() && !attr.toObjectRef().type()->is_module()) { + auto obj_type = attr.type()->expect(); + auto obj_value = std::move(attr).toObject(); + auto sub_attributes = obj_type->getAttributes(); + for (const auto& sub_attr : sub_attributes) { + auto sub_attr_val = obj_value->getAttr(sub_attr.getName()); + sub_attr_val = overrideGradient(sub_attr_val); + } + return obj_value; } return attr; @@ -414,7 +425,7 @@ class AttributePropagator { if (!findConstantAttr(input, name, attrModule, graph)) { GRAPH_DEBUG( input->type()->cast() || - input->type()->expect()->is_module() + input->type()->expectRef().is_module() ? "attribute: " + name + " is mutable." : ""); continue; @@ -429,8 +440,20 @@ class AttributePropagator { } if (!paramConst) { auto attr = attrModule.attr(name); - - if (isEval) { + if (!isEval || preserveParameters_) { + auto type = attrModule.type(); + auto slot = *type->findAttributeSlot(name); + if (type->is_parameter(slot) || type->is_buffer(slot) || + (attr.isObject() && + !attr.toObjectRef().type()->is_module())) { + continue; + } else { + attr = overrideGradient(attr); + } + if (!isEval && name == "training") { + continue; + } + } else { attr = overrideGradient(attr); } if (auto attrVal = tryInsertConstant(*graph, attr)) { @@ -503,7 +526,32 @@ class AttributePropagator { return true; } } - return false; + return preservedSubModule_.count(subModule._ivalue()); + } + + void removeExtraWaitCalls(Block* b) { + auto nodes = b->nodes(); + for (auto it = nodes.begin(); it != nodes.end(); it++) { + auto node = *it; + if (node->kind() != aten::wait) { + continue; + } + TORCH_INTERNAL_ASSERT(node->inputs().size() == 1); + TORCH_INTERNAL_ASSERT(node->outputs().size() == 1); + // If input type is not a from aten::fork call then the + // aten::wait operator can be deleted. + if (node->input()->type()->kind() != TypeKind::FutureType) { + node->output()->replaceAllUsesWith(node->input()); + it.destroyCurrent(); + } + } + // For the remaining nodes, recurse. + for (auto it = nodes.begin(); it != nodes.end(); it++) { + auto node = *it; + for (auto sub_b : node->blocks()) { + removeExtraWaitCalls(sub_b); + } + } } // cleanupFrozenModule function cleans up the Frozen module. It performs the @@ -516,11 +564,12 @@ class AttributePropagator { auto graph = function->graph(); recordReferencedAttrs(graph); handleSharedClassType(module_, graph); + removeExtraWaitCalls(graph->block()); } removeUnusedAttrs(); } - // Prepraring for clean up phase. At this point, record all subModules that + // Prepraring for clean up phase. At this point, record all subModules that // contains mutable attributes. void recordReferencedAttrs(std::shared_ptr& graph) { std::stack blocks({graph->block()}); @@ -534,12 +583,27 @@ class AttributePropagator { } if (n->kind() == prim::GetAttr) { auto& name = n->s(attr::name); - auto mptr = - getModulePtrForGetAttrNode(n->input(0)->node(), graph, module_); - auto module = Module(mptr); - if (module.type() == n->inputs()[0]->type() && module.hasattr(name)) { - auto attr = module.attr(name); - insertMutableAttr(name, attr, mptr); + // For now, use all module ivalues which are the same type + // and could be the module that this GetAttr resolves to + // TODO: we could attempt to follow the GetAttr chain and + // find the exact ivalue, we would have to be careful + // that the chain does not contain any attributes which + // get written to (setAttr calls) + for (auto& mptr : modules) { + auto module = Module(mptr); + if (module.type() == n->inputs()[0]->type()) { + TORCH_INTERNAL_ASSERT(module.hasattr(name)); + auto module = Module(mptr); + auto attr = module.attr(name); + // TODO: this could be insertReferencedAttr to be more clear, + // these are attributes we could not inline, which include + // other reasons besides mutation (unsupported constant, + // getAttr resolving to non-getAttr node, etc) + insertMutableAttr(name, attr, mptr); + if (attr.isModule()) { + modules.insert(attr.toModule()._ivalue()); + } + } } } else if (n->kind() == prim::fork) { applyToForkSubgraph( @@ -657,6 +721,9 @@ class AttributePropagator { // Contains user specified methods to be preserved in frozen module. std::unordered_set preservedMethods_; + // Contains user specified sub module to be preserve in frozen module. + std::unordered_set preservedSubModule_; + // Track all used attributes ivalues that can be aliased. IValue::HashAliasedIValues usedAttrs_; @@ -672,6 +739,9 @@ class AttributePropagator { // Allow to freeze modules containing interfaces. bool freezeInterfaces_; + // Preserve module parameters + bool preserveParameters_; + // Contains the attributes names (e.g. {"self", "subModule", "a"} std::deque names_; }; // class AttributePropagator @@ -680,18 +750,8 @@ class AttributePropagator { Module freeze_module( const Module& module, std::vector preservedAttrs, - bool freezeInterfaces) { - // Currently freezing module is supported only in eval mode. - // If assertion below is commented and module is in training mode then this - // implementation folds attributes correctly. Tensor attributes with - // required_grad set are not folded and 'training' attribute is also not - // folded. - // TODO: Determine if freezing in training mode is useful and further clarify - // its semantics. - TORCH_CHECK( - !module.hasattr("training") || !module.is_training(), - "Freezing module in training mode is not yet supported"); - + bool freezeInterfaces, + bool preserveParameters) { Method method = module.get_method("forward"); // Check that module does not return itself. for (auto& output : method.graph()->outputs()) { @@ -702,7 +762,7 @@ Module freeze_module( auto moduleClone = module.clone(true); AttributePropagator attrPropagator( - moduleClone, preservedAttrs, freezeInterfaces); + moduleClone, preservedAttrs, freezeInterfaces, preserveParameters); attrPropagator.run(); return moduleClone; } diff --git a/torch/csrc/jit/passes/freeze_module.h b/torch/csrc/jit/passes/freeze_module.h index 270cbca0c4139..d8bb61043ef36 100644 --- a/torch/csrc/jit/passes/freeze_module.h +++ b/torch/csrc/jit/passes/freeze_module.h @@ -11,7 +11,7 @@ /** \brief Freeze Module, i.e., Assume all atrributes are constants. * * Freezing module is a functionality that allows the JIT to internalize - * imutable attributes. Combined with inlinig, the module is aggressively + * immutable attributes. Combined with inlining, the module is aggressively * optimized and significant overhead is optimized away. The freezeModule API * produces a cloned frozen module. */ @@ -22,7 +22,8 @@ namespace jit { TORCH_API Module freeze_module( const Module& module, std::vector preservedAttrs = std::vector(), - bool freezeInterfaces = true); + bool freezeInterfaces = true, + bool preserveParameters = false); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/frozen_conv_folding.cpp b/torch/csrc/jit/passes/frozen_conv_folding.cpp new file mode 100644 index 0000000000000..128094e09b78f --- /dev/null +++ b/torch/csrc/jit/passes/frozen_conv_folding.cpp @@ -0,0 +1,375 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using Tensor = at::Tensor; + +bool nonConstantParameters(Node* n) { + for (size_t i = 1; i < n->inputs().size(); i++) { + if (n->inputs().at(i)->node()->kind() != prim::Constant) { + return true; + } + } + return false; +} + +bool supportedConvNode(Node* n) { + switch (n->kind()) { + case aten::conv1d: + case aten::conv2d: + case aten::conv3d: + return true; + case aten::_convolution: { + auto transposed_conv = + constant_as(n->namedInput("transposed")).value_or(true); + // dont handle transposed conv yet or not-constant transpose parameter + return !transposed_conv; + } + default: + return false; + } +} + +void FoldFrozenConvBatchnorm(Block* b) { + for (Node* n : b->nodes()) { + for (Block* block : n->blocks()) { + FoldFrozenConvBatchnorm(block); + } + + if (n->kind() == aten::batch_norm && + supportedConvNode(n->inputs().at(0)->node())) { + auto conv = n->inputs().at(0)->node(); + auto bn = n; + if (nonConstantParameters(conv) || nonConstantParameters(bn)) { + continue; + } + if (conv->output()->uses().size() > 1) { + continue; + } + + auto bn_rm = constant_as(bn->namedInput("running_mean")).value(); + auto bn_rv = constant_as(bn->namedInput("running_var")).value(); + auto bn_eps = constant_as(bn->namedInput("eps")).value(); + auto conv_w = constant_as(conv->namedInput("weight")).value(); + + // implementation taken from torch/nn/utils/fusion.py + Tensor conv_b; + if (conv->namedInput("bias")->type() == NoneType::get()) { + conv_b = at::zeros_like(bn_rm); + } else { + conv_b = constant_as(conv->namedInput("bias")).value(); + } + Tensor bn_w; + if (bn->namedInput("weight")->type() == NoneType::get()) { + bn_w = at::ones_like(bn_rm); + } else { + bn_w = constant_as(bn->namedInput("weight")).value(); + } + Tensor bn_b; + if (n->namedInput("bias")->type() == NoneType::get()) { + bn_b = at::zeros_like(bn_rm); + } else { + bn_b = constant_as(bn->namedInput("bias")).value(); + } + + ConvBNParameters params; + params.conv_w = conv_w; + params.conv_b = conv_b; + params.bn_rm = bn_rm; + params.bn_rv = bn_rv; + params.bn_eps = bn_eps; + params.bn_w = bn_w; + params.bn_b = bn_b; + std::tuple out = computeUpdatedConvWeightAndBias(params); + WithInsertPoint guard(conv); + auto fused_conv_w = b->owningGraph()->insertConstant(std::get<0>(out)); + auto fused_conv_b = b->owningGraph()->insertConstant(std::get<1>(out)); + auto conv_w_value = conv->namedInput("weight"); + auto conv_b_value = conv->namedInput("bias"); + + fused_conv_w->setDebugName(conv_w_value->debugName() + "_fused_bn"); + fused_conv_b->setDebugName(conv_b_value->debugName() + "_fused_bn"); + + conv->replaceInputWith(conv_w_value, fused_conv_w); + conv->replaceInputWith(conv_b_value, fused_conv_b); + + bn->output()->replaceAllUsesWith(conv->output()); + } + } +} + +bool supportedAddOrSub(Node* n) { + static const OperatorSet add_set{ + "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", + "aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", + // sub is equivalent to add + "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", + "aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", + }; + return n->isMemberOf(add_set); +} + +// In order to fuse add/sub/mul/div with conv, the dimensions of its +// constant tensor must satisfy the following: +// - with resizing, broadcast to w/ weight/bias tensor shape +// - broadcast to the conv output shape +// It needs to have a shape that can resize to weight/bias +// tensor shape because we need to run the op with the conv +// weights/bias without changing their sizes. +// It needs to broadcast to the conv output shape so that we do +// accidentally change the shape of op output by pre-fusing it +// compared to eager. +// The only dimension value shared by weight/bias/conv output +// is they all contain a dim with value = channels-out. In the +// conv output tensor, this is in the second dimension, +// so the pointwise op tensor may have a second dimension of +// value == channels-out, but all the other dimensions have to be 1 +bool opDoesNotBroadCastWithConv(Tensor& op_tensor, Tensor& weight_tensor) { + if (op_tensor.ndimension() > weight_tensor.ndimension()) { + return false; + } + for (int64_t i = op_tensor.ndimension() - 1; i >= 0; i--) { + // channels-out dimension == weight_tensor.size(0) + if (i == 1 && op_tensor.size(i) == weight_tensor.size(0)) { + continue; + } + if (op_tensor.size(i) != 1) { + return false; + } + } + return true; +} + +bool checkConvAndBroadcastingOpPreConditions(Node* conv, Node* op) { + if (nonConstantParameters(conv) || nonConstantParameters(op)) { + return false; + } + + if (conv->output()->uses().size() > 1) { + return false; + } + + auto conv_w = constant_as(conv->namedInput("weight")).value(); + Tensor weight_tensor = + constant_as(conv->namedInput("weight")).value(); + + // avoid fusing op that causes type promotion + // resticting to float avoids int/float difficulties with scalar overload + if (!weight_tensor.is_floating_point()) { + return false; + } + + if (op->inputs().at(1)->type()->cast()) { + auto op_tensor = constant_as(op->inputs().at(1)).value(); + if (!opDoesNotBroadCastWithConv(op_tensor, weight_tensor)) { + return false; + } + if (!op_tensor.is_floating_point()) { + return false; + } + if (c10::promoteTypes( + op_tensor.scalar_type(), weight_tensor.scalar_type()) != + weight_tensor.scalar_type()) { + return false; + } + } + return true; +} + +Tensor resizeConstantScalarOrTensorToShape( + Value* v, + const std::vector& shape, + at::TensorOptions options) { + Tensor ret_tensor; + if (v->type()->cast()) { + ret_tensor = constant_as(v).value(); + } else { + ret_tensor = at::zeros(shape, options); + if (v->type()->cast()) { + ret_tensor.fill_(constant_as(v).value()); + } else { + ret_tensor.fill_(constant_as(v).value()); + } + } + + if (ret_tensor.numel() == 1) { + // expand errors if the shape input has less # dims than the tensor input + ret_tensor = ret_tensor.reshape({1}); + ret_tensor = ret_tensor.expand(shape); + } else { + TORCH_INTERNAL_ASSERT(ret_tensor.numel() == at::prod_intlist(shape)); + ret_tensor = ret_tensor.view(shape); + } + return ret_tensor; +} + +void FoldFrozenConvAddOrSub(Block* b) { + for (Node* n : b->nodes()) { + for (Block* block : n->blocks()) { + FoldFrozenConvAddOrSub(block); + } + + if (supportedAddOrSub(n) && supportedConvNode(n->inputs().at(0)->node())) { + auto conv = n->inputs().at(0)->node(); + auto add_or_div = n; + + if (!checkConvAndBroadcastingOpPreConditions(conv, add_or_div)) { + continue; + } + + Tensor weight_tensor = + constant_as(conv->namedInput("weight")).value(); + + Tensor add_or_sub_tensor = resizeConstantScalarOrTensorToShape( + add_or_div->inputs().at(1), + {weight_tensor.size(0)}, + weight_tensor.options()); + Tensor bias; + if (conv->namedInput("bias")->type() == NoneType::get()) { + bias = at::zeros_like(add_or_sub_tensor); + } else { + bias = constant_as(conv->namedInput("bias")).value(); + } + + WithInsertPoint guard(conv); + + add_or_div->replaceInputWith( + conv->output(), b->owningGraph()->insertConstant(bias)); + add_or_div->replaceInput( + 1, b->owningGraph()->insertConstant(add_or_sub_tensor)); + + auto stack_out = runNodeIfInputsAreConstant(add_or_div); + TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1); + Tensor fuse_bias = (*stack_out)[0].toTensor(); + + auto fused_conv_b = b->owningGraph()->insertConstant(fuse_bias); + auto conv_b_value = conv->namedInput("bias"); + + fused_conv_b->setDebugName( + conv_b_value->debugName() + "_fused_" + + add_or_div->kind().toUnqualString()); + conv->replaceInputWith(conv_b_value, fused_conv_b); + add_or_div->output()->replaceAllUsesWith(conv->output()); + // DCE run after cleans up nodes + } + } +} + +bool supportedMulOrDiv(Node* n) { + static const OperatorSet add_set{ + "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor", + "aten::mul.Scalar(Tensor self, Scalar other) -> Tensor", + // div is equivalent to mul + "aten::div.Tensor(Tensor self, Tensor other) -> Tensor", + "aten::div.Scalar(Tensor self, Scalar other) -> Tensor", + }; + return n->isMemberOf(add_set); +} + +void FoldFrozenConvMulOrDiv(Block* b) { + for (Node* n : b->nodes()) { + for (Block* block : n->blocks()) { + FoldFrozenConvMulOrDiv(block); + } + + if (supportedMulOrDiv(n) && supportedConvNode(n->inputs().at(0)->node())) { + auto conv = n->inputs().at(0)->node(); + auto mul_or_div = n; + + if (!checkConvAndBroadcastingOpPreConditions(conv, mul_or_div)) { + continue; + } + + Tensor weight_tensor = + constant_as(conv->namedInput("weight")).value(); + int64_t out_channels = weight_tensor.size(0); + + // We've already verified that the second input has numel == 1 or + // channels-out resize it to the shape that will broadcast to + // weight_tensor when the op is run so we dont change weight size + std::vector weight_compatible_size = {out_channels}; + for (int64_t i = 1; i < weight_tensor.ndimension(); ++i) { + weight_compatible_size.push_back(1); + } + + WithInsertPoint guard(conv); + + Tensor mul_tensor = resizeConstantScalarOrTensorToShape( + mul_or_div->inputs().at(1), + weight_compatible_size, + weight_tensor.options()); + + // First fold with weight tensor + mul_or_div->replaceInputWith( + conv->output(), b->owningGraph()->insertConstant(weight_tensor)); + mul_or_div->replaceInput(1, b->owningGraph()->insertConstant(mul_tensor)); + + auto stack_out = runNodeIfInputsAreConstant(mul_or_div); + TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1); + Tensor fuse_weight = (*stack_out)[0].toTensor(); + + auto fused_conv_weight = b->owningGraph()->insertConstant(fuse_weight); + auto conv_weight_value = conv->namedInput("weight"); + + fused_conv_weight->setDebugName( + conv_weight_value->debugName() + "_fused_" + + mul_or_div->kind().toUnqualString()); + conv->replaceInputWith(conv_weight_value, fused_conv_weight); + mul_or_div->output()->replaceAllUsesWith(conv->output()); + + // now fold with bias tensor + if (conv->namedInput("bias")->type() != NoneType::get()) { + Tensor bias = constant_as(conv->namedInput("bias")).value(); + // bias is of shape {channels_out} + auto mul_tensor = resizeConstantScalarOrTensorToShape( + mul_or_div->inputs().at(1), {out_channels}, bias.options()); + + mul_or_div->replaceInput(0, b->owningGraph()->insertConstant(bias)); + mul_or_div->replaceInput( + 1, b->owningGraph()->insertConstant(mul_tensor)); + + auto stack_out = runNodeIfInputsAreConstant(mul_or_div); + TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1); + Tensor fuse_bias = (*stack_out)[0].toTensor(); + + auto fused_conv_bias = b->owningGraph()->insertConstant(fuse_bias); + auto conv_b_value = conv->namedInput("bias"); + + fused_conv_weight->setDebugName( + conv_b_value->debugName() + "_fused_" + + mul_or_div->kind().toUnqualString()); + conv->replaceInputWith(conv_b_value, fused_conv_bias); + } + // DCE run after cleans up nodes + } + } +} + +void FoldFrozenConvBatchnorm(std::shared_ptr& graph) { + FoldFrozenConvBatchnorm(graph->block()); + EliminateDeadCode(graph); +} + +void FoldFrozenConvAddOrSub(std::shared_ptr& graph) { + FoldFrozenConvAddOrSub(graph->block()); + EliminateDeadCode(graph); +} + +void FoldFrozenConvMulOrDiv(std::shared_ptr& graph) { + FoldFrozenConvMulOrDiv(graph->block()); + EliminateDeadCode(graph); +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/frozen_conv_folding.h b/torch/csrc/jit/passes/frozen_conv_folding.h new file mode 100644 index 0000000000000..ef12fd8b5fa9e --- /dev/null +++ b/torch/csrc/jit/passes/frozen_conv_folding.h @@ -0,0 +1,24 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + +// Fuses Convolution -> Batchnorm into a single Convolution by +// folding batchnorm weights into conv weights. +// This pass only works on Frozen Graphs; otherwise it is a No-Op. +TORCH_API void FoldFrozenConvBatchnorm(std::shared_ptr& graph); + +// Fuses Convolution -> Add/Sub into a single Convolution by +// folding add constant tensor into conv weights. +// This pass only works on Frozen Graphs; otherwise it is a No-Op. +TORCH_API void FoldFrozenConvAddOrSub(std::shared_ptr& graph); + +// Fuses Convolution -> Mul/Div into a single Convolution by +// folding add constant tensor into conv weights. +// This pass only works on Frozen Graphs; otherwise it is a No-Op. +TORCH_API void FoldFrozenConvMulOrDiv(std::shared_ptr& graph); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/frozen_graph_optimizations.cpp b/torch/csrc/jit/passes/frozen_graph_optimizations.cpp new file mode 100644 index 0000000000000..458c022f71f5e --- /dev/null +++ b/torch/csrc/jit/passes/frozen_graph_optimizations.cpp @@ -0,0 +1,21 @@ +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +void OptimizeFrozenGraph(std::shared_ptr& graph) { + // run a couple times to capture Conv -> Mul -> Add etc + for (size_t i = 0; i < 2; i++) { + FoldFrozenConvBatchnorm(graph); + FoldFrozenConvAddOrSub(graph); + FoldFrozenConvMulOrDiv(graph); + } +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/frozen_graph_optimizations.h b/torch/csrc/jit/passes/frozen_graph_optimizations.h new file mode 100644 index 0000000000000..a808791d35c12 --- /dev/null +++ b/torch/csrc/jit/passes/frozen_graph_optimizations.h @@ -0,0 +1,19 @@ +#pragma once + +#include + +/** \brief Runs a set of Optimizations that Optimize Frozen Graphs + * + * Currently this set of optimizations is: + * - FoldFrozenConvBatchnorm + * - FoldFrozenConvAddOrSub + * - FoldFrozenConvMulOrDiv + */ + +namespace torch { +namespace jit { + +TORCH_API void OptimizeFrozenGraph(std::shared_ptr& graph); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/fuse_relu.cpp b/torch/csrc/jit/passes/fuse_relu.cpp index 8c3fe1e6e7180..479b4b29a9c52 100644 --- a/torch/csrc/jit/passes/fuse_relu.cpp +++ b/torch/csrc/jit/passes/fuse_relu.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -17,7 +18,7 @@ void fuseAddReluImpl(std::shared_ptr& graph) { return (%res))"; std::string add_relu_fused = R"( graph(%a, %b, %alpha): - %res = aten::add_relu(%a, %b, %alpha) + %res = aten::_add_relu(%a, %b, %alpha) return (%res))"; rewriter.RegisterRewritePattern(add_relu_0, add_relu_fused); @@ -35,7 +36,7 @@ void fuseAddReluImpl(std::shared_ptr& graph) { return (%res))"; std::string add_inplace_relu_fused = R"( graph(%a, %b, %alpha): - %res = aten::add_relu_(%a, %b, %alpha) + %res = aten::_add_relu_(%a, %b, %alpha) return (%res))"; rewriter.RegisterRewritePattern(add_inplace_relu_1, add_inplace_relu_fused); @@ -46,7 +47,7 @@ void fuseAddReluImpl(std::shared_ptr& graph) { return (%res))"; std::string add_out_relu_fused = R"( graph(%a, %b, %alpha, %out): - %res = aten::add_relu(%a, %b, %alpha, %out) + %res = aten::_add_relu(%a, %b, %alpha, %out) return (%res))"; rewriter.RegisterRewritePattern(add_out_relu, add_out_relu_fused); diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index 07634bfc52009..8243a2cb624f1 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -120,16 +121,6 @@ bool isSimpleMap(Node* node) { return true; } -Value* broadcastSizes(at::ArrayRef sizes, AliasDb* db) { - AT_ASSERT(!sizes.empty()); - Graph* graph = sizes[0]->owningGraph(); - Node* broadcast_n = - graph->insertNode(graph->create(prim::BroadcastSizes, sizes)); - broadcast_n->output()->setType(ListType::ofInts()); - db->createValue(broadcast_n->output()); - return broadcast_n->output(); -} - struct GraphFuser { using FusionCallback = std::function; @@ -160,12 +151,13 @@ struct GraphFuser { AliasDb* aliasDb, Block* block, FusionCallback callback, - Symbol kind) + Symbol kind, + bool strict_fuser_check = false) : block_(block), aliasDb_(aliasDb), callback_(std::move(callback)), kind_(kind), - strict_fuser_check_(false) {} + strict_fuser_check_(strict_fuser_check) {} void setInputArgLimit(size_t limit) { subgraph_arg_limit_ = limit; @@ -185,7 +177,7 @@ struct GraphFuser { if (!v->type()->isSubtypeOf(TensorType::get())) { return true; } - auto device = v->type()->expect()->device(); + auto device = v->type()->expectRef().device(); if (!device) { return !strict_fuser_check; } @@ -926,13 +918,6 @@ struct GraphFuser { } } - bool usedOnlyInSize(Value* v) { - const auto& uses = v->uses(); - return std::all_of(uses.begin(), uses.end(), [](const Use& u) { - return u.user->matches("aten::size(Tensor self) -> int[]"); - }); - } - // Builds up expressions that compute shapes of all intermediates (and // outputs) of the fusion group, based on the sizes of inputs. You should run // DCE to remove those that you end up not using. @@ -1139,6 +1124,13 @@ struct GraphFuser { } void run() { +// TODO: old fuser is not maintained internally, somewhere it is being turned on +// inadvertently for certain workflows. make this a no-op until we identify +// location +#if defined(FBCODE_CAFFE2) + return; +#endif + // Run the pass until no changes are made. // This is necessary, because the algorithm can miss out on certain fusion // opportunities if ran only once. Consider this graph: @@ -1185,7 +1177,8 @@ struct GraphFuser { for (Node* node : block_->nodes()) { for (Block* sub_block : node->blocks()) { - GraphFuser(aliasDb_, sub_block, callback_, kind_).run(); + GraphFuser(aliasDb_, sub_block, callback_, kind_, strict_fuser_check_) + .run(); } } } @@ -1260,7 +1253,7 @@ void FuseGraph(std::shared_ptr& graph, bool strict_fuser_check) { void CustomFuseGraph( std::shared_ptr& graph, - std::function fn, + const std::function& fn, Symbol kind, size_t arg_limit) { AliasDb db(graph); diff --git a/torch/csrc/jit/passes/graph_fuser.h b/torch/csrc/jit/passes/graph_fuser.h index eee89200e2bd9..0cdcc2e20f469 100644 --- a/torch/csrc/jit/passes/graph_fuser.h +++ b/torch/csrc/jit/passes/graph_fuser.h @@ -26,7 +26,7 @@ TORCH_API void FuseGraph( // post condition on the fused subgraph. TORCH_API void CustomFuseGraph( std::shared_ptr& graph, - std::function is_fusable, + const std::function& is_fusable, Symbol kind, size_t arg_limit = std::numeric_limits::max()); diff --git a/torch/csrc/jit/passes/graph_rewrite_helper.cpp b/torch/csrc/jit/passes/graph_rewrite_helper.cpp index 6a6c0a2c355c3..34d7fd6c121a1 100644 --- a/torch/csrc/jit/passes/graph_rewrite_helper.cpp +++ b/torch/csrc/jit/passes/graph_rewrite_helper.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -8,7 +9,7 @@ namespace jit { namespace graph_rewrite_helper { std::string getFuncName(Value* func_value) { - auto func = func_value->type()->expect()->function(); + auto func = func_value->type()->expectRef().function(); const auto& qname = func->qualname(); const auto& name = qname.qualifiedName(); auto rdot_idx = name.rfind('.'); @@ -84,43 +85,51 @@ void replaceConvolutionWithAtenConv(std::shared_ptr& graph) { %r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups) return (%r) )"; - std::string conv_transpose2d_for_deprecated_conv = R"( + std::string conv1d_for_deprecated_conv = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool): - %r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) + %r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups) return (%r) )"; - std::string conv_transpose2d = R"( + std::string conv1d = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): - %r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) + %r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups) return (%r) )"; - std::string conv1d_for_deprecated_conv = R"( + std::string conv3d_for_deprecated_conv = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool): - %r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups) + %r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups) return (%r) )"; - std::string conv1d = R"( + std::string conv3d = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): - %r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups) + %r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups) return (%r) )"; - std::string conv3d_for_deprecated_conv = R"( + std::string conv_transpose1d = R"( + graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], + %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, + %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): + %r = aten::conv_transpose1d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) + return (%r) )"; + + std::string conv_transpose2d_for_deprecated_conv = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool): - %r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups) + %r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) return (%r) )"; - std::string conv3d = R"( + + std::string conv_transpose2d = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): - %r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups) + %r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) return (%r) )"; // Filter the unsupported case @@ -146,6 +155,29 @@ void replaceConvolutionWithAtenConv(std::shared_ptr& graph) { } return !calc_value_map["transposed"].toBool(); }; + auto filter_conv3d = [](const Match& match, + const std::unordered_map& vmap) { + auto calc_value_map = getConvParams(match, vmap); + if (calc_value_map["output_padding"].toIntList().size() != 3 || + calc_value_map["stride"].toIntList().size() != 3 || + calc_value_map["padding"].toIntList().size() != 3 || + calc_value_map["dilation"].toIntList().size() != 3) { + return false; + } + return !calc_value_map["transposed"].toBool(); + }; + auto filter_conv_transpose1d = + [](const Match& match, + const std::unordered_map& vmap) { + auto calc_value_map = getConvParams(match, vmap); + if (calc_value_map["output_padding"].toIntList().size() != 1 || + calc_value_map["stride"].toIntList().size() != 1 || + calc_value_map["padding"].toIntList().size() != 1 || + calc_value_map["dilation"].toIntList().size() != 1) { + return false; + } + return calc_value_map["transposed"].toBool(); + }; auto filter_conv_transpose2d = [](const Match& match, const std::unordered_map& vmap) { @@ -158,39 +190,36 @@ void replaceConvolutionWithAtenConv(std::shared_ptr& graph) { } return calc_value_map["transposed"].toBool(); }; - auto filter_conv3d = [](const Match& match, - const std::unordered_map& vmap) { - auto calc_value_map = getConvParams(match, vmap); - if (calc_value_map["output_padding"].toIntList().size() != 3 || - calc_value_map["stride"].toIntList().size() != 3 || - calc_value_map["padding"].toIntList().size() != 3 || - calc_value_map["dilation"].toIntList().size() != 3) { - return false; - } - return !calc_value_map["transposed"].toBool(); - }; SubgraphRewriter rewriter_conv1d; rewriter_conv1d.RegisterRewritePattern(convolution, conv1d); rewriter_conv1d.RegisterRewritePattern( convolution_deprecated, conv1d_for_deprecated_conv); rewriter_conv1d.runOnGraph(graph, filter_conv1d); + SubgraphRewriter rewriter_conv2d; rewriter_conv2d.RegisterRewritePattern(convolution, conv2d); rewriter_conv2d.RegisterRewritePattern( convolution_deprecated, conv2d_for_deprecated_conv); rewriter_conv2d.runOnGraph(graph, filter_conv2d); + + SubgraphRewriter rewriter_conv3d; + rewriter_conv3d.RegisterRewritePattern(convolution, conv3d); + rewriter_conv3d.RegisterRewritePattern( + convolution_deprecated, conv3d_for_deprecated_conv); + rewriter_conv3d.runOnGraph(graph, filter_conv3d); + + SubgraphRewriter rewriter_conv_transpose1d; + rewriter_conv_transpose1d.RegisterRewritePattern( + convolution, conv_transpose1d); + rewriter_conv_transpose1d.runOnGraph(graph, filter_conv_transpose1d); + SubgraphRewriter rewriter_conv_transpose2d; rewriter_conv_transpose2d.RegisterRewritePattern( convolution, conv_transpose2d); rewriter_conv_transpose2d.RegisterRewritePattern( convolution_deprecated, conv_transpose2d_for_deprecated_conv); rewriter_conv_transpose2d.runOnGraph(graph, filter_conv_transpose2d); - SubgraphRewriter rewriter_conv3d; - rewriter_conv3d.RegisterRewritePattern(convolution, conv3d); - rewriter_conv3d.RegisterRewritePattern( - convolution_deprecated, conv3d_for_deprecated_conv); - rewriter_conv3d.runOnGraph(graph, filter_conv3d); } bool isClampFusable( diff --git a/torch/csrc/jit/passes/graph_rewrite_helper.h b/torch/csrc/jit/passes/graph_rewrite_helper.h index f21859df1f471..57b91be92edd0 100644 --- a/torch/csrc/jit/passes/graph_rewrite_helper.h +++ b/torch/csrc/jit/passes/graph_rewrite_helper.h @@ -39,10 +39,11 @@ struct PatternInfo { static PatternInfo parse_from_str( std::string pattern_string, const std::vector& filters = {}) { - PatternInfo rv{std::move(pattern_string), - std::make_unique(), - decltype(vmap){}, - filters}; + PatternInfo rv{ + std::move(pattern_string), + std::make_unique(), + decltype(vmap){}, + filters}; parseIR(rv.pattern_string, rv.pattern_graph.get(), rv.vmap); return rv; } diff --git a/torch/csrc/jit/passes/guard_elimination.cpp b/torch/csrc/jit/passes/guard_elimination.cpp index 9ae808e418c39..4f3a96bf90179 100644 --- a/torch/csrc/jit/passes/guard_elimination.cpp +++ b/torch/csrc/jit/passes/guard_elimination.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -241,7 +242,7 @@ struct GuardElimination { size_t i = 0; for (auto input : n->inputs()) { if ((input->node()->kind() == prim::Guard && - !input->type()->expect()->isSummarized()) || + !input->type()->expectRef().isSummarized()) || input->node()->kind() == prim::Constant || (allow_numbers && input->type()->isSubtypeOf(NumberType::get())) || except.count(i) != 0) { @@ -376,7 +377,7 @@ struct GuardElimination { case aten::conv3d: return checkInputs(n, std::unordered_set{2, 6}, false); case aten::slice: - return !n->input(0)->type()->expect()->isSummarized() && + return !n->input(0)->type()->expectRef().isSummarized() && // check that the dimension argument is constant n->input(1)->node()->kind() == prim::Constant && // the start offset is constant @@ -388,7 +389,7 @@ struct GuardElimination { case aten::max_pool1d: case aten::max_pool2d: case aten::max_pool3d: - return !n->input(0)->type()->expect()->isSummarized() && + return !n->input(0)->type()->expectRef().isSummarized() && // check that the kernel size is constant n->input(1)->node()->kind() == prim::Constant && // check that the stride is constant @@ -401,7 +402,7 @@ struct GuardElimination { n->input(5)->node()->kind() == prim::Constant; case aten::unsqueeze: // check that the dimension argument is constant - return !n->input(0)->type()->expect()->isSummarized() && + return !n->input(0)->type()->expectRef().isSummarized() && n->input(1)->node()->kind() == prim::Constant; case aten::cat: // check that the dimension argument is constant @@ -426,8 +427,8 @@ struct GuardElimination { // aten::size is effectively a constant if (asize->input() ->type() - ->expect() - ->sizes() + ->expectRef() + .sizes() .concrete_sizes()) { return true; } diff --git a/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp index 45c0283345233..c5d91391f43e7 100644 --- a/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp @@ -21,7 +21,7 @@ bool canRunWithAutograd(Node* node) { } return kind != prim::FusionGroup && kind != prim::CudaFusionGroup && kind != prim::TypeCheck && kind != prim::TensorExprGroup && - (kind.is_aten() || kind.is_prim()); + kind != prim::CudaFusionGuard && (kind.is_aten() || kind.is_prim()); } namespace { diff --git a/torch/csrc/jit/passes/inline_forked_closures.cpp b/torch/csrc/jit/passes/inline_forked_closures.cpp index ea5a977e40912..771050030c976 100644 --- a/torch/csrc/jit/passes/inline_forked_closures.cpp +++ b/torch/csrc/jit/passes/inline_forked_closures.cpp @@ -1,4 +1,5 @@ #include + #include namespace torch { @@ -19,7 +20,7 @@ void inlineForkedClosure(Node* fork_closure) { Node* function_context_node = fork_closure->input()->node(); if (function_context_node->inputs().size() != 2 || - function_context_node->inputs().at(0)->node()->kind() != prim::Function || + function_context_node->inputs().at(0)->node()->kind() != prim::Closure || function_context_node->inputs().at(1)->node()->kind() != prim::TupleConstruct) { throw ErrorReport(fork_closure->sourceRange()) << "Cannot fork this value"; diff --git a/torch/csrc/jit/passes/inliner.cpp b/torch/csrc/jit/passes/inliner.cpp index 02edaa7745163..2d510bd181e6c 100644 --- a/torch/csrc/jit/passes/inliner.cpp +++ b/torch/csrc/jit/passes/inliner.cpp @@ -1,4 +1,5 @@ #include + #include #include #include diff --git a/torch/csrc/jit/passes/insert_guards.cpp b/torch/csrc/jit/passes/insert_guards.cpp index 20621317cc70e..8269d4e4deb89 100644 --- a/torch/csrc/jit/passes/insert_guards.cpp +++ b/torch/csrc/jit/passes/insert_guards.cpp @@ -60,7 +60,7 @@ void InsertGuards(std::shared_ptr graph) { gi.run(); } -void RemoveProfilingNodes(std::shared_ptr graph) { +void RemoveProfilingNodes(const std::shared_ptr& graph) { removeProfilingNodes(graph->block()); } diff --git a/torch/csrc/jit/passes/insert_guards.h b/torch/csrc/jit/passes/insert_guards.h index db1b06daa457b..a9a5035d05db9 100644 --- a/torch/csrc/jit/passes/insert_guards.h +++ b/torch/csrc/jit/passes/insert_guards.h @@ -15,7 +15,7 @@ namespace jit { TORCH_API void InsertGuards(std::shared_ptr graph); -TORCH_API void RemoveProfilingNodes(std::shared_ptr graph); +TORCH_API void RemoveProfilingNodes(const std::shared_ptr& graph); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/lift_closures.cpp b/torch/csrc/jit/passes/lift_closures.cpp index 82e6f22166812..09ac524978fd6 100644 --- a/torch/csrc/jit/passes/lift_closures.cpp +++ b/torch/csrc/jit/passes/lift_closures.cpp @@ -1,11 +1,12 @@ #include + #include #include namespace torch { namespace jit { -// Closures are initially emitted as prim::Function nodes with a single block. +// Closures are initially emitted as prim::Closure nodes with a single block. // Here, we convert the block to a subgraph, adding all closed over variables // as a context tuple input to the closure node. // At this point the closure has already undergone conversion to SSA, @@ -58,7 +59,7 @@ void liftClosures(Block* block) { Node* n = *it; it++; switch (n->kind()) { - case prim::Function: { + case prim::Closure: { liftClosure(n); } break; default: { diff --git a/torch/csrc/jit/passes/liveness.cpp b/torch/csrc/jit/passes/liveness.cpp index e580eb425507b..3b2cf54461f86 100644 --- a/torch/csrc/jit/passes/liveness.cpp +++ b/torch/csrc/jit/passes/liveness.cpp @@ -1,4 +1,5 @@ #include + #include #include #include diff --git a/torch/csrc/jit/passes/lower_grad_of.cpp b/torch/csrc/jit/passes/lower_grad_of.cpp index d9bdd9141346f..3f3de5ff779e2 100644 --- a/torch/csrc/jit/passes/lower_grad_of.cpp +++ b/torch/csrc/jit/passes/lower_grad_of.cpp @@ -1,4 +1,5 @@ #include + #include namespace torch { diff --git a/torch/csrc/jit/passes/lower_graph.cpp b/torch/csrc/jit/passes/lower_graph.cpp index 581f4d14d8b42..b4da8216b5afa 100644 --- a/torch/csrc/jit/passes/lower_graph.cpp +++ b/torch/csrc/jit/passes/lower_graph.cpp @@ -1,4 +1,5 @@ #include + #include #include #include diff --git a/torch/csrc/jit/passes/lower_tuples.cpp b/torch/csrc/jit/passes/lower_tuples.cpp index d241e0d7560d1..2fc61d388e775 100644 --- a/torch/csrc/jit/passes/lower_tuples.cpp +++ b/torch/csrc/jit/passes/lower_tuples.cpp @@ -1,4 +1,5 @@ #include + #include #include #include diff --git a/torch/csrc/jit/passes/metal_rewrite.cpp b/torch/csrc/jit/passes/metal_rewrite.cpp new file mode 100644 index 0000000000000..e07f4a384088c --- /dev/null +++ b/torch/csrc/jit/passes/metal_rewrite.cpp @@ -0,0 +1,203 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +namespace { + +void insertPrePackedConv2dOp(std::shared_ptr& graph) { + graph_rewrite_helper::replaceConvolutionWithAtenConv(graph); + + std::string conv_2d_pattern = R"( + graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int): + %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups) + return (%r) )"; + + std::string prepacked_ops_conv2d_pattern = R"( + graph(%input, %weight, %bias, %stride:int[], %padding:int[], + %dilation:int[], %groups:int): + %output_min_max : None = prim::Constant() + %packed_weight_bias = metal_prepack::conv2d_prepack( + %weight, %bias, %stride, %padding, %dilation, %groups, + %output_min_max, %output_min_max) + %r = metal_prepack::conv2d_run(%input, %packed_weight_bias) + return (%r) )"; + + SubgraphRewriter rewriter; + rewriter.RegisterRewritePattern( + conv_2d_pattern, prepacked_ops_conv2d_pattern); + rewriter.runOnGraph(graph); +} + +void fuseReluWithPackedOps(std::shared_ptr& graph) { + SubgraphRewriter rewriter; + + std::string conv2d_prepack_run_relu = R"( + graph(%input, %weight, %bias, %stride:int[], %padding:int[], + %dilation:int[], %groups:int, %dummy_min_max): + %packed_weight_bias = metal_prepack::conv2d_prepack( + %weight, %bias, %stride, %padding, %dilation, %groups, + %dummy_min_max, %dummy_min_max) + %r = metal_prepack::conv2d_run(%input, %packed_weight_bias) + %r = aten::relu(%r) + return (%r) )"; + + std::string conv2d_prepack_run_relu_fused = R"( + graph(%input, %weight, %bias, %stride:int[], %padding:int[], + %dilation:int[], %groups:int, %dummy_min_max): + %output_min: float = prim::Constant[value=0.0]() + %output_max: None = prim::Constant() + %packed_weight_bias: __torch__.torch.classes.metal.Conv2dOpContext = metal_prepack::conv2d_prepack( + %weight, %bias, %stride, %padding, %dilation, %groups, + %output_min, %output_max) + %r = metal_prepack::conv2d_run(%input, %packed_weight_bias) + return (%r) )"; + + rewriter.RegisterRewritePattern( + conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused); + + std::string conv2d_prepack_run_relu_inplace = R"( + graph(%input, %weight, %bias, %stride:int[], %padding:int[], + %dilation:int[], %groups:int, %dummy_min_max): + %packed_weight_bias = metal_prepack::conv2d_prepack( + %weight, %bias, %stride, %padding, %dilation, %groups, + %dummy_min_max, %dummy_min_max) + %r = metal_prepack::conv2d_run(%input, %packed_weight_bias) + %r = aten::relu_(%r) + return (%r) )"; + + rewriter.RegisterRewritePattern( + conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused); + + rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable); +} + +void fuseHardtanhWithPackedOps(std::shared_ptr& graph) { + SubgraphRewriter rewriter; + + std::string conv2d_prepack_run_hardtanh_fused = R"( + graph(%input, %weight, %bias, %stride:int[], %padding:int[], + %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max): + %packed_weight_bias: __torch__.torch.classes.metal.Conv2dOpContext = metal_prepack::conv2d_prepack( + %weight, %bias, %stride, %padding, %dilation, %groups, + %output_min, %output_max) + %r = metal_prepack::conv2d_run(%input, %packed_weight_bias) + return (%r) )"; + + std::string conv2d_prepack_run_hardtanh = R"( + graph(%input, %weight, %bias, %stride:int[], %padding:int[], + %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max): + %packed_weight_bias = metal_prepack::conv2d_prepack( + %weight, %bias, %stride, %padding, %dilation, %groups, + %dummy_min_max, %dummy_min_max) + %r = metal_prepack::conv2d_run(%input, %packed_weight_bias) + %r = aten::hardtanh(%r, %output_min, %output_max) + return (%r) )"; + + rewriter.RegisterRewritePattern( + conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused); + + std::string conv2d_prepack_run_hardtanh_inplace = R"( + graph(%input, %weight, %bias, %stride:int[], %padding:int[], + %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max): + %packed_weight_bias = metal_prepack::conv2d_prepack( + %weight, %bias, %stride, %padding, %dilation, %groups, + %dummy_min_max, %dummy_min_max) + %r = metal_prepack::conv2d_run(%input, %packed_weight_bias) + %r = aten::hardtanh_(%r, %output_min, %output_max) + return (%r) )"; + + rewriter.RegisterRewritePattern( + conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused); + + rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable); +} + +} // namespace + +void metalInsertPrePackedOps(std::shared_ptr& graph) { + insertPrePackedConv2dOp(graph); +} + +void metalInsertPrePackedOps(script::Module& module) { + for (auto& method : module.get_methods()) { + auto graph = method.graph(); + metalInsertPrePackedOps(graph); + } + for (script::Module m : module.children()) { + metalInsertPrePackedOps(m); + } +} + +void metalFoldPrePackingOps(script::Module& m) { + PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool { + return ( + n->kind() == Symbol::fromQualString("metal_prepack::conv2d_prepack")); + }; + PrePackingOpsFolder(m, filter_fn, "prepack_folding"); +} + +void metalFusePrePackedConvWithClamp(script::Module& module) { + auto graph = module.get_method("forward").graph(); + fuseReluWithPackedOps(graph); + fuseHardtanhWithPackedOps(graph); +} + +void metalInsertCopyOps(script::Module& module) { + auto graph = module.get_method("forward").graph(); + auto&& outputs = graph->outputs(); + for (size_t i = 0; i < outputs.size(); ++i) { + Value* output = outputs[i]; + auto namedValue = NamedValue("", output); + if (namedValue.type()->kind() == TypeKind::TensorType) { + // find the insertion point + WithInsertPoint ip(output->node()->next()); + Value* replaced_output = graph->insert( + Symbol::fromQualString("metal::copy_to_host"), {namedValue}); + // replaced the output + graph->block()->replaceOutput(i, replaced_output); + } + } + SubgraphRewriter rewriter; + rewriter.runOnGraph(graph); +} + +void runCanonicalOptimizations(script::Module& module) { + auto graph = module.get_method("forward").graph(); + runOptimization(graph, false /* no loop unrolling */); +} + +script::Module metalOptimizeForMobile( + const script::Module& m, + const std::vector& preserved_methods) { + auto cloned_module = m.clone(); + cloned_module.eval(); + cloned_module = FoldConvBatchNorm(cloned_module); + metalInsertPrePackedOps(cloned_module); + cloned_module = freeze_module(cloned_module, preserved_methods); + metalFusePrePackedConvWithClamp(cloned_module); + metalFoldPrePackingOps(cloned_module); + metalInsertCopyOps(cloned_module); + removeDropout(cloned_module); + // remove duplicated constants + runCanonicalOptimizations(cloned_module); + cloned_module.register_attribute( + "optimized_for_metal", BoolType::get(), true); + return cloned_module; +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/metal_rewrite.h b/torch/csrc/jit/passes/metal_rewrite.h new file mode 100644 index 0000000000000..30e4825cedd17 --- /dev/null +++ b/torch/csrc/jit/passes/metal_rewrite.h @@ -0,0 +1,17 @@ +#pragma once +#include +#include +#include +#include + +namespace torch { +namespace jit { +TORCH_API void metalInsertPrePackedOps(std::shared_ptr& graph); +TORCH_API void metalInsertPrePackedOps(script::Module& module); +TORCH_API void metalFusePrePackedConvWithClamp(script::Module& module); +TORCH_API void metalFoldPrePackingOps(script::Module& module); +TORCH_API script::Module metalOptimizeForMobile( + const script::Module& module, + const std::vector& preserved_methods); +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/normalize_ops.cpp b/torch/csrc/jit/passes/normalize_ops.cpp index 75ad2a4499ce2..a2ea0a7c5e7c0 100644 --- a/torch/csrc/jit/passes/normalize_ops.cpp +++ b/torch/csrc/jit/passes/normalize_ops.cpp @@ -1,4 +1,5 @@ #include + #include namespace torch { @@ -6,30 +7,6 @@ namespace jit { namespace { -// map from op alias -> normalized op -static const std::unordered_map alias_map = { - {aten::absolute, aten::abs}, {aten::absolute_, aten::abs_}, - {aten::clip, aten::clamp}, {aten::clip_, aten::clamp_}, - {aten::linalg_det, aten::det}, {aten::ger, aten::outer}, - {aten::arccos, aten::acos}, {aten::arccos_, aten::acos_}, - {aten::arcsin, aten::asin}, {aten::arcsin_, aten::asin_}, - {aten::arctan, aten::atan}, {aten::arctan_, aten::atan_}, - {aten::arccosh, aten::acosh}, {aten::arccosh_, aten::acosh_}, - {aten::arcsinh, aten::asinh}, {aten::arcsinh_, aten::asinh_}, - {aten::arctanh, aten::atanh}, {aten::arctanh_, aten::atanh_}, - {aten::fix, aten::trunc}, {aten::fix_, aten::trunc_}, - {aten::negative, aten::neg}, {aten::negative_, aten::neg_}, - {aten::subtract, aten::sub}, {aten::subtract_, aten::sub_}, - {aten::greater_equal, aten::ge}, {aten::greater_equal_, aten::ge_}, - {aten::greater, aten::gt}, {aten::greater_, aten::gt_}, - {aten::less_equal, aten::le}, {aten::less_equal_, aten::le_}, - {aten::less, aten::lt}, {aten::less_, aten::lt_}, - {aten::not_equal, aten::ne}, {aten::not_equal_, aten::ne_}, - {aten::divide, aten::div}, {aten::divide_, aten::div_}, - {aten::multiply, aten::mul}, {aten::multiply_, aten::mul_}, - {aten::true_divide, aten::div}, {aten::true_divide_, aten::div_}, -}; - void replaceNodeWithNewSymbol(Node* node, Symbol new_symbol) { WithInsertPoint insert_guard{node}; auto graph = node->owningGraph(); @@ -53,8 +30,8 @@ void replaceNodeWithNewSymbol(Node* node, Symbol new_symbol) { // difficult to consumer for downstream user of the IR, such as our own // optimization passes here, we convert op aliases into a standard form bool normalizeOpAliases(graph_node_list_iterator& iter) { - auto alias = alias_map.find(iter->kind()); - if (alias != alias_map.end()) { + auto alias = getOperatorAliasMap().find(iter->kind()); + if (alias != getOperatorAliasMap().end()) { replaceNodeWithNewSymbol(*iter, alias->second); iter.destroyCurrent(); return true; @@ -79,6 +56,59 @@ void NormalizeOps(Block* block) { } // namespace +const std::unordered_map& getOperatorAliasMap() { + // map from op alias -> normalized op + static const std::unordered_map alias_map = { + {aten::absolute, aten::abs}, + {aten::absolute_, aten::abs_}, + {aten::clip, aten::clamp}, + {aten::clip_, aten::clamp_}, + {aten::linalg_det, aten::det}, + {aten::ger, aten::outer}, + {aten::arccos, aten::acos}, + {aten::arccos_, aten::acos_}, + {aten::arcsin, aten::asin}, + {aten::arcsin_, aten::asin_}, + {aten::arctan, aten::atan}, + {aten::arctan_, aten::atan_}, + {aten::arccosh, aten::acosh}, + {aten::arccosh_, aten::acosh_}, + {aten::arcsinh, aten::asinh}, + {aten::arcsinh_, aten::asinh_}, + {aten::arctanh, aten::atanh}, + {aten::arctanh_, aten::atanh_}, + {aten::fix, aten::trunc}, + {aten::fix_, aten::trunc_}, + {aten::negative, aten::neg}, + {aten::negative_, aten::neg_}, + {aten::subtract, aten::sub}, + {aten::subtract_, aten::sub_}, + {aten::greater_equal, aten::ge}, + {aten::greater_equal_, aten::ge_}, + {aten::greater, aten::gt}, + {aten::greater_, aten::gt_}, + {aten::less_equal, aten::le}, + {aten::less_equal_, aten::le_}, + {aten::less, aten::lt}, + {aten::less_, aten::lt_}, + {aten::not_equal, aten::ne}, + {aten::not_equal_, aten::ne_}, + {aten::divide, aten::div}, + {aten::divide_, aten::div_}, + {aten::multiply, aten::mul}, + {aten::multiply_, aten::mul_}, + {aten::true_divide, aten::div}, + {aten::true_divide_, aten::div_}, + {aten::row_stack, aten::vstack}, + {aten::swapdims, aten::transpose}, + {aten::swapdims_, aten::transpose_}, + {aten::swapaxes, aten::transpose}, + {aten::swapaxes_, aten::transpose_}, + {aten::moveaxis, aten::movedim}, + }; + return alias_map; +} + void NormalizeOps(const std::shared_ptr& graph) { NormalizeOps(graph->block()); } diff --git a/torch/csrc/jit/passes/normalize_ops.h b/torch/csrc/jit/passes/normalize_ops.h index 03963cd26a8be..4d630392ca47d 100644 --- a/torch/csrc/jit/passes/normalize_ops.h +++ b/torch/csrc/jit/passes/normalize_ops.h @@ -12,5 +12,7 @@ namespace jit { // Currently only handles normalization of op aliases. TORCH_API void NormalizeOps(const std::shared_ptr& graph); +const std::unordered_map& getOperatorAliasMap(); + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp index 5bc0a2e50e948..51476dc4c667c 100644 --- a/torch/csrc/jit/passes/onnx.cpp +++ b/torch/csrc/jit/passes/onnx.cpp @@ -1,4 +1,5 @@ #include + #include #include #include diff --git a/torch/csrc/jit/passes/onnx/constant_fold.cpp b/torch/csrc/jit/passes/onnx/constant_fold.cpp index ffb98d926b85e..66c8398d42c09 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.cpp +++ b/torch/csrc/jit/passes/onnx/constant_fold.cpp @@ -1,4 +1,5 @@ #include + #include #include diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp index ce9ce5fb37c4e..b826db926e132 100644 --- a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp @@ -1,4 +1,6 @@ #include + +#include #include #include #include @@ -82,33 +84,6 @@ bool IsErasableSequence(const Node* loop_node, size_t i) { return true; } -void FixupONNXLoopNodeInputs(Node* node) { - if (node->kind() != ::c10::onnx::Loop) { - return; - } - - auto* graph = node->owningGraph(); - - // add cast to condition input outside the loop. - Value* cond_val = node->inputs()[1]; - if (IsCondCastRequired(cond_val)) - InsertCastForCond(cond_val, graph, node); - - // Setup Loop input cond and i. - TORCH_INTERNAL_ASSERT(node->blocks().size() == 1); - auto* sub_block = node->blocks()[0]; - Value* cond = sub_block->insertInput(1, "cond"); - cond->setType(BoolType::create()); - - Value* i = sub_block->inputs()[0]; - i->setType(TensorType::fromNumberType(IntType::get())); - - // add cast to condition input inside the loop. - Value* next_cond_val = sub_block->outputs()[0]; - if (IsCondCastRequired(next_cond_val)) - InsertCastForCond(next_cond_val, graph, sub_block->return_node()); -} - // ONNX::Loop does not support Sequence type as loop-carried dependencies. Only // tensors are supported. This pass converts Sequence loop-carried dependencies // to scan_outputs. In opset 11, only the below pattern is supported. @@ -218,6 +193,33 @@ void ConvertSequenceDependencies(Block* block, int opset_version) { } } // anonymous namespace +void FixupONNXLoopNodeInputs(Node* node) { + if (node->kind() != ::c10::onnx::Loop) { + return; + } + + auto* graph = node->owningGraph(); + + // add cast to condition input outside the loop. + Value* cond_val = node->inputs()[1]; + if (IsCondCastRequired(cond_val)) + InsertCastForCond(cond_val, graph, node); + + // Setup Loop input cond and i. + TORCH_INTERNAL_ASSERT(node->blocks().size() == 1); + auto* sub_block = node->blocks()[0]; + Value* cond = sub_block->insertInput(1, "cond"); + cond->setType(BoolType::create()); + + Value* i = sub_block->inputs()[0]; + i->setType(TensorType::fromNumberType(IntType::get())); + + // add cast to condition input inside the loop. + Value* next_cond_val = sub_block->outputs()[0]; + if (IsCondCastRequired(next_cond_val)) + InsertCastForCond(next_cond_val, graph, sub_block->return_node()); +} + std::vector FixupONNXLoopNode(Node* node, int opset_version) { auto output_size = node->outputs().size(); FixupONNXLoopNodeInputs(node); @@ -226,6 +228,111 @@ std::vector FixupONNXLoopNode(Node* node, int opset_version) { return new_outputs; } +// Check if node is prim::Uninitialized, +// or output of prim::Uninitialized->onnx::Identity +bool IsUninitializedNode(Node* n) { + if (n->kind() == ::c10::onnx::Identity && + n->inputs()[0]->node()->kind() == prim::Uninitialized) + return true; + if (n->kind() == prim::Uninitialized) + return true; + return false; +} + +// Infer shape and type of the uninitialized_output from the corresponding +// output of the other subblock. prim::Uninitialized node is proven to be +// unused. So replace this node with a constant of the inferred shape and type. +void InferShapeTypeForUninitializedOutput( + Graph* graph, + Block* block, + Value* uninitialized_output, + Value* other_output) { + auto output_type = other_output->type()->expect(); + auto elem_type = at::initialTensorOptions().dtype(output_type->scalarType()); + Node* const_node = graph->create(::c10::onnx::Constant, 1); + + if (output_type->sizes().concrete_sizes().has_value()) { + auto size = output_type->sizes().concrete_sizes().value(); + const_node->t_(attr::value, at::zeros(size, elem_type)); + const_node->output()->setType(other_output->type()); + const_node->output()->copyMetadata(other_output); + } else { + const_node->t_(attr::value, at::zeros({}, elem_type)); + const_node->output()->setType( + TensorType::create(*(output_type->scalarType()), at::kCPU, {}, {})); + } + const_node->insertBefore(block->return_node()); + uninitialized_output->replaceAllUsesWith(const_node->output()); + uninitialized_output->node()->destroy(); +} + +// Corresponding outputs for ONNX If then and else subblocks should have +// same shape and type. This pass detects if prim::Uninitialized node +// appears as part of outputs of either of the subblocks, and infers +// shape and type from the corresponding output of the other subblock +// In the example graph below, shape and type of the subblock output %7 +// for subblock 1 is inferred from %y.1. Shape and type of Subblock +// output %7 is inferred from %y.5. +// +// graph(%y.1 : Int(3:4, 4:1, requires_grad=0, device=cpu)): +// ... +// %7 : Tensor = prim::Uninitialized() +// %16 : bool, %17 : Tensor, %y.14 : Tensor = prim::If(%15) # +// test/onnx/test_pytorch_onnx_onnxruntime.py:614:20 +// block0(): +// %y.5 : Tensor = aten::add(%y.1, %3, %6) # +// test/onnx/test_pytorch_onnx_onnxruntime.py:615:28 +// -> (%2, %7, %y.5) +// block1(): +// -> (%1, %y.1, %7) +// ... + +void ONNXFixupUninitializedOutput(Node* node) { + if (node->kind() != ::c10::onnx::If) { + return; + } + + GRAPH_DUMP("Graph before fixing If shape type: ", node->owningGraph()); + auto* if_node = node; + auto* graph = if_node->owningGraph(); + + // Check if the input to ONNX If node is node Bool, and insert + // cast to Bool if needed. + if (!if_node->input()->type()->isSubtypeOf(BoolType::get())) { + Node* cast_node = CreateCastToBoolNode(if_node->input(), graph); + cast_node->insertBefore(if_node); + if_node->replaceInputWith(if_node->input(), cast_node->output()); + } + + Block* then_block = if_node->blocks()[0]; + Block* else_block = if_node->blocks()[1]; + + // Infer shape and type for subblock outputs + TORCH_INTERNAL_ASSERT( + then_block->outputs().size() == else_block->outputs().size()) + for (size_t i = 0; i < else_block->outputs().size(); i++) { + Value* then_block_output = then_block->outputs()[i]; + Value* else_block_output = else_block->outputs()[i]; + + // If both subblocks have an uninitialized output, shape and type cannot + // be inferred. + TORCH_CHECK( + !(IsUninitializedNode(then_block_output->node()) && + IsUninitializedNode(else_block_output->node())), + "Cannot infer shape and type for ONNX If with uninitialized output in both subblocks. Please check the model graph."); + + if (IsUninitializedNode(then_block_output->node())) { + InferShapeTypeForUninitializedOutput( + graph, then_block, then_block_output, else_block_output); + if_node->outputs()[i]->setType(then_block->outputs()[i]->type()); + } else if (IsUninitializedNode(else_block_output->node())) { + InferShapeTypeForUninitializedOutput( + graph, else_block, else_block_output, then_block_output); + if_node->outputs()[i]->setType(else_block->outputs()[i]->type()); + } + } +} + std::vector FixupONNXIfNode(Node* node, int opset_version) { if (node->kind() != ::c10::onnx::If) { return node->outputs().vec(); @@ -234,17 +341,17 @@ std::vector FixupONNXIfNode(Node* node, int opset_version) { auto* if_node = node; auto* graph = if_node->owningGraph(); for (Block* block : node->blocks()) { - if (block->nodes().begin() == block->nodes().end()) { - // ONNX does not support empty blocks, must use some op which does - // nothing - Value* output = block->outputs()[0]; - Node* id_node = graph->create(onnx::Identity); - id_node->insertBefore(block->return_node()); - id_node->addInput(output); - id_node->output()->copyMetadata(output); - block->return_node()->replaceInputWith(output, id_node->output()); + for (Value* output : block->outputs()) { + if (output->node()->owningBlock() != block) { + Node* id_node = graph->create(onnx::Identity); + id_node->insertBefore(block->return_node()); + id_node->addInput(output); + id_node->output()->copyMetadata(output); + block->return_node()->replaceInputWith(output, id_node->output()); + } } } + ONNXFixupUninitializedOutput(if_node); GRAPH_DUMP("Graph after fixing controlflow: ", node->owningGraph()); return if_node->outputs().vec(); } diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h index 3487946d721b7..fad7611085223 100644 --- a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h @@ -7,5 +7,5 @@ namespace jit { std::vector FixupONNXControlflowNode(Node* n, int opset_version); -} +} // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/onnx/fold_if_node.cpp b/torch/csrc/jit/passes/onnx/fold_if_node.cpp new file mode 100644 index 0000000000000..dbf24bcd6e387 --- /dev/null +++ b/torch/csrc/jit/passes/onnx/fold_if_node.cpp @@ -0,0 +1,269 @@ +#include +#include +#include +#include + +#include +#include + +namespace torch { +namespace jit { + +namespace onnx { +using namespace ::c10::onnx; +} + +// This function determines wheather If Node can be folded. +static bool isStaticCondition(Node* node) { + TORCH_INTERNAL_ASSERT( + node->kind() == onnx::If || node->kind() == onnx::Not || + node->kind() == onnx::Identity); + auto cast_node = node->input()->node(); + if (cast_node->kind() != onnx::Cast) + cast_node = node; + auto prev_node = cast_node->input()->node(); + + if (prev_node->kind() == onnx::Not || prev_node->kind() == onnx::Identity || + prev_node->kind() == onnx::If) + return isStaticCondition(prev_node); + + auto compare_node = prev_node; + if (compare_node->kind() == onnx::Equal || + compare_node->kind() == onnx::Greater || + compare_node->kind() == onnx::Less || + compare_node->kind() == onnx::GreaterOrEqual || + compare_node->kind() == onnx::LessOrEqual) { + for (size_t i = 0; i < compare_node->inputs().size(); i++) { + auto sym = compare_node->inputs()[i] + ->type() + ->cast() + ->symbolic_sizes(); + if (!(compare_node->inputs()[i]->node()->kind() == onnx::Constant || + compare_node->inputs()[i]->node()->kind() == onnx::Size || + compare_node->inputs()[i]->node()->kind() == onnx::ReduceProd)) + return false; + if (compare_node->inputs()[i]->node()->kind() != onnx::Constant) { + auto shape_node = compare_node->inputs()[i]->node()->input()->node(); + auto shape = + shape_node->input()->type()->cast()->symbolic_sizes(); + + // ONNX shape and type inference cannot determine the shape of the input + if (!shape.rank()) + return false; + + // If dynamic_axes are used on inputs to ReduceProd node, don't fold If + // node + auto dynamic_axes = shape.isComplete(); + if (!dynamic_axes && + compare_node->inputs()[i]->node()->kind() == onnx::ReduceProd) + return false; + } + } + return true; + } else if (compare_node->kind() == onnx::Constant) { + return true; + } + return false; +} + +// find index of the block output +static c10::optional findIndex( + c10::ArrayRef outputs, + Value* input) { + c10::optional idx = c10::nullopt; + for (size_t i = 0; i < outputs.size(); i++) { + if (input == outputs[i]) { + idx = i; + break; + } + } + return idx; +} + +// This function returns the value of the constant-folded subblock +// that is input to the If node. +static bool constantFoldedConditionValue(Node* node) { + TORCH_INTERNAL_ASSERT(node->kind() == onnx::If); + // usually Cast node precedes If node in the graph, but + // there are some rare scenarios when that is not the case. + auto cast_node = node->input()->node(); + if (cast_node->kind() != onnx::Cast) + cast_node = node; + auto prev_node = cast_node->input()->node(); + if (prev_node->kind() == onnx::If) { + int cond = 1 - (int)constantFoldedConditionValue(prev_node); + Block* block = prev_node->blocks()[cond]; + auto outputs = cast_node->input()->node()->outputs(); + auto cast_input = cast_node->input(); + int idx = findIndex(outputs, cast_input).value(); + prev_node = block->outputs()[idx]->node(); + } + + if (prev_node->kind() == onnx::Constant) { + const at::Tensor& val = prev_node->t(attr::value); + return at::is_nonzero(val); + } + + if (prev_node->kind() == onnx::Identity && + prev_node->input()->node()->kind() == onnx::Constant) { + auto val = prev_node->input()->node()->t(attr::value); + return at::is_nonzero(val); + } + + Node* compare_node = nullptr; + if (prev_node->kind() == onnx::Not) { + compare_node = prev_node->input()->node(); + } else if (cast_node->inputs().size() > 0) { + compare_node = cast_node->input()->node(); + } + TORCH_INTERNAL_ASSERT(compare_node != nullptr); + ScalarTypeAnalysisNodeForONNX(compare_node); + std::vector inputs; + for (size_t i = 0; i < compare_node->inputs().size(); i++) { + auto input_node = compare_node->inputs()[i]->node(); + if (input_node->kind() == onnx::Constant) { + const at::Tensor& val = input_node->t(attr::value); + inputs.push_back(val); + } else { // input_node is either onnx::Size or onnx::ReduceProd + auto shape_node = input_node->input()->node(); + auto shape = + shape_node->input()->type()->cast()->symbolic_sizes(); + + at::Tensor val; + if (input_node->kind() == onnx::Size) { + auto rank = shape.rank(); + val = c10::scalar_to_tensor((int64_t)*rank); + } else if (input_node->kind() == onnx::ReduceProd) { + auto sizes = shape.sizes(); + int64_t prod = 1; + for (int64_t i = 0; i < (int64_t)*shape.rank(); i++) { + auto dim = sizes.value()[i].static_size(); + prod *= dim; + } + val = c10::scalar_to_tensor(prod); + } + + inputs.push_back(val); + } + } + + at::Tensor res; + if (compare_node->kind() == onnx::Equal) { + res = at::eq(inputs[0], inputs[1]); + if (prev_node->kind() == onnx::Not) + res = at::not_equal(inputs[0], inputs[1]); + } else if ( + compare_node->kind() == onnx::Greater && prev_node->kind() != onnx::Not) { + res = at::greater(inputs[0], inputs[1]); + } else if ( + (prev_node->kind() == onnx::Not && compare_node->kind() == onnx::Less) || + compare_node->kind() == onnx::GreaterOrEqual) { + res = at::greater_equal(inputs[0], inputs[1]); + } else if ( + compare_node->kind() == onnx::Less && prev_node->kind() != onnx::Not) { + res = at::less(inputs[0], inputs[1]); + } else if ( + (prev_node->kind() == onnx::Not && + compare_node->kind() == onnx::Greater) || + compare_node->kind() == onnx::LessOrEqual) { + res = at::less_equal(inputs[0], inputs[1]); + } else { + TORCH_INTERNAL_ASSERT( + false, "Condition value of the If node could not be constant-folded!"); + } + + return at::is_nonzero(res); +} + +// This pass return then or else branch of the If node depending on the +// value of the constant-folded sublock that is input to the If node +// +// Example: +// before post pass +// graph(%y.2 : Int(3, 4, strides=[4, 1], requires_grad=0, device=cpu)): +// %4 : Long(2, strides=[1], device=cpu) = onnx::Shape(%y.2) +// %5 : Long(device=cpu) = onnx::Size(%4) +// %12 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={2}]() +// %6 : Bool(device=cpu) = onnx::Equal(%5, %12) +// %11 : bool = onnx::Cast[to=9](%6) +// %7 : Int(3, 4, strides=[4, 1], device=cpu) = onnx::If(%11) +// block0(): +// %13 : Int(requires_grad=0, device=cpu) = onnx::Constant[value={4}]() +// %8 : Int(3, 4, strides=[4, 1], device=cpu) = onnx::Add(%y.2, %13) +// %14 : Int(requires_grad=0, device=cpu) = onnx::Constant[value={2}]() +// %9 : Int(3, 4, strides=[4, 1], device=cpu) = onnx::Add(%8, %14) +// -> (%9) +// block1(): +// %y.1 : Int(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = +// onnx::Identity(%y.2) +// -> (%y.1) +// return (%7) + +// after post pass +// graph(%y.2 : Int(3, 4, strides=[4, 1], requires_grad=0, device=cpu)): +// %4 : Long(2, strides=[1], device=cpu) = onnx::Shape(%y.2) +// %5 : Long(device=cpu) = onnx::Size(%4) +// %12 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={2}]() +// %6 : Bool(device=cpu) = onnx::Equal(%5, %12) +// %11 : bool = onnx::Cast[to=9](%6) +// %13 : Int(requires_grad=0, device=cpu) = onnx::Constant[value={4}]() +// %8 : Int(3, 4, strides=[4, 1], device=cpu) = onnx::Add(%y.2, %13) +// %14 : Int(requires_grad=0, device=cpu) = onnx::Constant[value={2}]() +// %9 : Int(3, 4, strides=[4, 1], device=cpu) = onnx::Add(%8, %14) +// return (%9) + +static void foldIfNode(Block* b) { + for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) { + for (auto* child_block : it->blocks()) { + foldIfNode(child_block); + } + if (it->kind() == onnx::If) { + auto if_node = *it; + if (isStaticCondition(if_node)) { + Block* then_block = it->blocks()[0]; + Block* else_block = it->blocks()[1]; + Block* block = else_block; + if (constantFoldedConditionValue(if_node)) + block = then_block; + + std::vector nodes_in_valid_path; + for (auto* valid_node : block->nodes()) { + nodes_in_valid_path.push_back(valid_node); + } + Node* cur = if_node; + for (auto* valid_node : nodes_in_valid_path) { + valid_node->moveAfter(cur); + cur = valid_node; + } + for (size_t i = 0; i < block->return_node()->inputs().size(); ++i) { + if_node->outputs()[i]->replaceAllUsesWith( + block->return_node()->inputs()[i]); + } + it->removeAllInputs(); + it.destroyCurrent(); + } + } + } +} + +// This pass is folding If node when the condition (subblock) can be +// constant-folded. Currently ONNX Runtime is doing Shape and Type Inference on +// both branches of the If operator, regardless of which branch is executing in +// Runtime. This can cause runtime errors in some cases: +// 1. Condition of the If node is based on shape / size of the input +// 2. then and else branch have different return types +// Folding If node can prevent Runtime errors in ONNXRuntime. +void FoldIfNodeONNX(Block* b) { + foldIfNode(b); +} + +bool ConditionValueONNX(Node* n) { + return constantFoldedConditionValue(n); +} + +bool IsStaticConditionONNX(Node* n) { + return isStaticCondition(n); +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/onnx/fold_if_node.h b/torch/csrc/jit/passes/onnx/fold_if_node.h new file mode 100644 index 0000000000000..9180142de5c53 --- /dev/null +++ b/torch/csrc/jit/passes/onnx/fold_if_node.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + +void FoldIfNodeONNX(Block* b); +bool ConditionValueONNX(Node* n); +bool IsStaticConditionONNX(Node* n); + +} // namespace jit + +} // namespace torch diff --git a/torch/csrc/jit/passes/onnx/function_substitution.cpp b/torch/csrc/jit/passes/onnx/function_substitution.cpp index 6c51797e74339..460deae6dbd49 100644 --- a/torch/csrc/jit/passes/onnx/function_substitution.cpp +++ b/torch/csrc/jit/passes/onnx/function_substitution.cpp @@ -1,4 +1,5 @@ #include + #include #include diff --git a/torch/csrc/jit/passes/onnx/helper.cpp b/torch/csrc/jit/passes/onnx/helper.cpp index e4965f692a232..a14dcd611dd89 100644 --- a/torch/csrc/jit/passes/onnx/helper.cpp +++ b/torch/csrc/jit/passes/onnx/helper.cpp @@ -1,4 +1,5 @@ #include + #include namespace torch { @@ -50,16 +51,6 @@ void buildParamsMapFromValueToParamsMap( } } -Node* addNodeToBlock(Block* block, Value* input, Symbol kind) { - auto new_node = block->appendNode(block->owningGraph()->create(kind)); - auto new_input = new_node->addInput(input); - for (size_t i = 0; i < new_node->outputs().size(); i++) { - auto output = new_node->outputs()[i]; - block->registerOutput(output); - } - return new_node; -} - c10::optional ONNXTypeToATenType(int32_t onnx_type) { switch (onnx_type) { case ::ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED: @@ -94,5 +85,17 @@ c10::optional ONNXTypeToATenType(int32_t onnx_type) { return c10::optional{}; } +Node* addNodeToBlock(Block* block, Symbol kind, ArrayRef inputs) { + auto new_node = block->appendNode(block->owningGraph()->create(kind)); + for (auto input : inputs) { + auto new_input = new_node->addInput(input); + } + return new_node; +} + +Value* addInputToBlock(Block* block) { + return block->addInput(); +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/onnx/helper.h b/torch/csrc/jit/passes/onnx/helper.h index b3ab64fe759a3..e27909ff63620 100644 --- a/torch/csrc/jit/passes/onnx/helper.h +++ b/torch/csrc/jit/passes/onnx/helper.h @@ -27,7 +27,10 @@ void eraseUnusedBlockInputs(Block* b); void buildParamsMapFromValueToParamsMap( const ValueToParamPairMap& valsToParamsMap, ParamMap& paramsDict); -Node* addNodeToBlock(Block* block, Value* input, Symbol kind); + +Node* addNodeToBlock(Block* block, Symbol kind, ArrayRef inputs); + +Value* addInputToBlock(Block* block); TORCH_API c10::optional ONNXTypeToATenType(int32_t onnx_type); } // namespace jit diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp new file mode 100644 index 0000000000000..78cc7aedc680b --- /dev/null +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -0,0 +1,187 @@ +#include + +#include +#include +#include + +namespace torch { +namespace jit { + +// findSubModuleAttr function chases getAttr chains backwards to locate the +// submodules. For example: module M { +// attributes { +// A = +// } +// ... +// %A = prim::GetAttr[name="A"](%self) +// ... +// %B = prim::GetAttr[name="B"](%A) +// ... +// %weight = prim::GetAttr[name="scale"](%B) +// ... + +std::deque findSubModuleAttr( + Value* input, + std::string& name, + Module& attrModule, + std::shared_ptr& graph) { + Node* node = input->node(); + std::deque moduleNames; + + // Loop starts from inner submodule and follows the chain until reaches the + // top module. + while (node->outputs().at(0)->type() != graph->inputs().at(0)->type()) { + if (node->kind() == prim::GetAttr) { + moduleNames.push_front(node->s(attr::name)); + node = node->inputs()[0]->node(); + } + } + + // Assign the inner module to attrModule. + for (auto& moduleName : moduleNames) { + attrModule = attrModule.attr(moduleName).toModule(); + } + return moduleNames; +} + +Value* addParamAsArgument(Function* function, std::string& name, IValue& attr) { + auto schema = function->getSchema(); + auto args = schema.arguments(); + args.emplace_back(Argument(name, nullptr, c10::nullopt, attr)); + auto new_schema = FunctionSchema( + schema.name(), + schema.overload_name(), + args, + schema.returns(), + schema.is_vararg(), + schema.is_varret()); + function->setSchema(new_schema); + return function->graph()->addInput(name)->setType(attr.type()); +} + +std::vector getParamAttributes( + std::shared_ptr& graph, + const Module& module_, + Function* function_) { + std::vector attrValues; + auto isEval = !module_.hasattr("training") || !module_.is_training(); + auto block = graph->block(); + std::vector blocks({block}); + + Node* m = *block->nodes().begin(); + WithInsertPoint guard(m); + + while (!blocks.empty()) { + Block* block = blocks.back(); + blocks.pop_back(); + for (auto it = block->nodes().begin(); it != block->nodes().end();) { + Node* n = *it; + it++; // node n can be destroyed + + for (Block* sub_block : n->blocks()) { + blocks.emplace_back(sub_block); + } + if (n->kind() == prim::SetAttr && + n->s(attr::name) == "num_batches_tracked") { + n->destroy(); + } else if (n->kind() == prim::GetAttr) { + for (auto use : n->output()->uses()) { + if (use.user->kind() == prim::PythonOp) + throw ErrorReport(n->sourceRange()) + << "Couldn't export Python method."; + } + + auto name = n->s(attr::name); + auto attrModule = module_; + auto input = n->inputs()[0]; + + auto moduleNames = findSubModuleAttr(input, name, attrModule, graph); + if (!attrModule.hasattr(name)) { + continue; + } + Value* paramConst = nullptr; + + auto attr = attrModule.attr(name); + + std::string fullName(""); + for (auto& name : moduleNames) { + fullName += name + '.'; + } + fullName += name; + + auto type = attrModule.type(); + auto slot = *type->findAttributeSlot(name); + + if (type->is_parameter(slot) || type->is_buffer(slot) || + (attr.isObject() && !attr.toObjectRef().type()->is_module()) || + name == "training") { + if (attr.isTensor()) { + TORCH_INTERNAL_ASSERT(attr.isTensor()); + auto tensor_ = attr.toTensor(); + if (isEval && tensor_.requires_grad()) { + tensor_ = tensor_.detach(); + tensor_.set_requires_grad(false); + attr = IValue(tensor_); + } + attrValues.emplace_back(attr.toTensor()); + paramConst = addParamAsArgument(function_, fullName, attr); + } else if ( + attr.isObject() && !attr.toObjectRef().type()->is_module()) { + // Only below registered torch classes are supported. + auto type = attr.type(); + TORCH_CHECK( + (type == + getCustomClass( + "__torch__.torch.classes.quantized.Conv2dPackedParamsBase")) || + (type == + getCustomClass( + "__torch__.torch.classes.quantized.Conv3dPackedParamsBase")) || + (type == + getCustomClass( + "__torch__.torch.classes.quantized.LinearPackedParamsBase")), + "Unknown type ", + type->repr_str(), + " encountered in handling model params. This type is not supported in ONNX export."); + attrValues.emplace_back( + script::Object(attr.toObject()).run_method("__getstate__")); + paramConst = addParamAsArgument(function_, fullName, attr); + } else if (attr.isNone() || name == "training") { + auto attrVal = tryInsertConstant(*graph, attr); + paramConst = *attrVal; + } + n->output()->replaceAllUsesWith(paramConst); + n->removeAllInputs(); + + GRAPH_UPDATE("Folding GetAttr %", n->outputs()[0]->debugName()); + } + } + } + } + return attrValues; +} + +std::pair> list_module_parameters( + const Module& module) { + Module moduleClone = module.clone(true); + Method method = moduleClone.get_method("forward"); + auto function = &method.function(); + std::vector modelParams; + + GRAPH_DEBUG("List attributes for function: " + function->name()); + auto graph = function->graph(); + // Add model_parameters and model_buffers as model inputs. Order is based on + // the appearance in the graph. + auto attributes = getParamAttributes(graph, moduleClone, function); + + modelParams.reserve(attributes.size()); + for (auto& attr_ : attributes) { + modelParams.push_back(attr_); + } + GRAPH_DEBUG("Cleaning up module"); + EliminateDeadCode(graph->block()); + + return std::make_pair(moduleClone, modelParams); +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.h b/torch/csrc/jit/passes/onnx/list_model_parameters.h new file mode 100644 index 0000000000000..50d1cea2b8fe0 --- /dev/null +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include + +namespace torch { +namespace jit { + +TORCH_API std::pair> list_module_parameters( + const Module& module); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index 6177fbbcc6432..d488201a8f802 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -1,4 +1,5 @@ #include + #include #include @@ -137,14 +138,14 @@ void fuseBroadcast(Block* b) { // Not all broadcasts are supported by ONNX broadcast. c10::optional axis = fusibleExpandTo( unexpanded_input->type() - ->expect() - ->sizes() + ->expectRef() + .sizes() .concrete_sizes() .value(), // from n->output() ->type() - ->expect() - ->sizes() + ->expectRef() + .sizes() .concrete_sizes() .value()); // to if (axis == c10::nullopt) @@ -310,7 +311,13 @@ void pushPackingPastRnn(Block* b) { std::vector new_sizes; new_sizes.push_back(*oldType->sizes()[0]); new_sizes.push_back(*oldType->sizes()[1]); - new_sizes.push_back(rnn->i(attr::hidden_size)); + if (next->kind() == onnx::Reshape) { + // bidirection + new_sizes.push_back(rnn->i(attr::hidden_size) * 2); + } else { + // unidirection + new_sizes.push_back(rnn->i(attr::hidden_size)); + } TensorTypePtr newType = TensorType::createContiguous( *oldType->scalarType(), *oldType->device(), new_sizes); next->outputs().at(0)->setType(newType); @@ -690,7 +697,8 @@ static void fuseLogSoftmaxNllLoss(Block* b) { // (%10) origLogSoftmaxNode = prev->input(0)->node(); auto transpose = origLogSoftmaxNode->input(0)->node(); - origLogSoftmaxNode->replaceInput(0, transpose->inputs().at(0)); + if (transpose->inputs().size() > 0) + origLogSoftmaxNode->replaceInput(0, transpose->inputs().at(0)); } else if ( prev->kind() == onnx::Reshape && prev->input(0)->node()->kind() == onnx::Transpose && @@ -742,11 +750,9 @@ static void fuseLogSoftmaxNllLoss(Block* b) { // onnx::Reshape(%35, %36) return (%37) auto nllloss_output = origNllLossNode->output(0)->uses()[0].user; TORCH_INTERNAL_ASSERT(nllloss_output->kind() == onnx::Reshape); - TORCH_INTERNAL_ASSERT( - nllloss_output->inputs()[1]->node()->kind() == - prim::ListConstruct); // make output of reshape the output of nllloss nllloss_output->replaceAllUsesWith(origNllLossNode); + origNllLossNode->output(0)->copyMetadata(nllloss_output->output(0)); } } else { continue; diff --git a/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp b/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp index edbb159d481c0..709b174e94285 100644 --- a/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp @@ -1,4 +1,5 @@ #include + #include namespace torch { diff --git a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp index 3c62b2877fa5a..ec98c01eaefa2 100644 --- a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp @@ -1,4 +1,5 @@ #include + #include #include @@ -37,22 +38,23 @@ at::optional FindFusibleListUnpack(Node* n) { // split.Tensor(Tensor(a) self, int split_size, int dim=0) -> Tensor(a)[] // split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[] // -// graph(%input : Float(5:12, 4:3, 3:1)): +// graph(%input : Float(5, 4, 3, strides=[12, 3, 1])): // %13 : int[] = prim::Constant[value=[2, 1, 2]]() // %7 : int = prim::Constant[value=0]() // %8 : Tensor[] = aten::split_with_sizes(%input, %13, %7) -// %9 : Float(2:12, 4:3, 3:1), %10 : Float(1:12, 4:3, 3:1), %11 : Float(2:12, -// 4:3, 3:1) = prim::ListUnpack(%8) return (%9, %10, %11) +// %9 : Float(2, 4, 3, strides=[12, 3, 1]), %10 : Float(1, 4, 3, strides=[12, +// 3, 1]), %11 : Float(2, 4, 3, strides=[12, 3, 1]) = prim::ListUnpack(%8) +// return (%9, %10, %11) // // After fusion -// graph(%input : Float(5:12, 4:3, 3:1)): +// graph(%input : Float(5, 4, 3, strides=[12, 3, 1])): // %13 : int[] = prim::Constant[value=[2, 1, 2]]() // %7 : int = prim::Constant[value=0]() // %8 : int = prim::Constant[value=3]() # Adding addtional input of value 3 // representing the number of outputs. -// %14 : Float(2:12, 4:3, 3:1), %15 : Float(1:12, 4:3, 3:1), %16 : Float(2:12, -// 4:3, 3:1) = aten::split_with_sizes(%input, %13, %7, %8) -// return (%14, %15, %16) +// %14 : Float(2, 4, 3, strides=[12, 3, 1]), %15 : Float(1, 4, 3, strides=[12, +// 3, 1]), %16 : Float(2, 4, 3, strides=[12, 3, 1] = +// aten::split_with_sizes(%input, %13, %7, %8) return (%14, %15, %16) void FuseWithListUnpack(Node* n) { auto found_listUnpack = FindFusibleListUnpack(n); if (!found_listUnpack) { @@ -71,7 +73,7 @@ void FuseWithListUnpack(Node* n) { Symbol::fromQualString("attr::_outputs"), static_cast(listUnpack_node->outputs().size())); - for (auto i = 0; i < listUnpack_node->outputs().size(); ++i) { + for (size_t i = 0; i < listUnpack_node->outputs().size(); ++i) { auto new_output = n->addOutput(); new_output->copyMetadata(listUnpack_node->output(i)); } @@ -96,6 +98,7 @@ static void FuseWithListUnpack(Block* b) { case aten::unbind: case aten::unsafe_chunk: case aten::where: + case aten::nonzero_numpy: FuseWithListUnpack(*it); break; default: @@ -108,8 +111,8 @@ static void FuseWithListUnpack(Block* b) { // when inputs to the add node are two int lists // // before the pass: -// graph(%x.1 : Float(2:12, 3:4, 4:1, requires_grad=0, device=cpu), -// %y.1 : Float(1:6, 2:3, 3:1, requires_grad=0, device=cpu)): +// graph(%x.1 : Float(2, 3, 4, strides=[12, 4, 1], requires_grad=0, device=cpu), +// %y.1 : Float(1, 2, 3, strides=[6, 3, 1], requires_grad=0, device=cpu)): // %2 : None = prim::Constant() // %3 : int[] = aten::size(%x.1) // %l1.1 : int[] = aten::list(%3 @@ -120,8 +123,8 @@ static void FuseWithListUnpack(Block* b) { // return (%8) // // after the pass: -// graph(%x.1 : Float(2:12, 3:4, 4:1, requires_grad=0, device=cpu), -// %y.1 : Float(1:6, 2:3, 3:1, requires_grad=0, device=cpu)): +// graph(%x.1 : Float(2, 3, 4, strides=[12, 4, 1], requires_grad=0, device=cpu), +// %y.1 : Float(1, 2, 3, strides=[6, 3, 1], requires_grad=0, device=cpu)): // %2 : None = prim::Constant() // %3 : int[] = aten::size(%x.1) // %l1.1 : int[] = aten::list(%3) @@ -158,11 +161,67 @@ static void ReplaceAddWithConcat(Block* b) { } } +// This pass also covers the case when the input to ListUnpack +// is int[] comming from some other op than ListConstruct (like Slice or Shape) +// +// before the pass +// graph(%x.1 : Float(2, 3, strides=[3, 1], requires_grad=0, device=cpu)): +// %1 : None = prim::Constant() +// %2 : int[] = aten::size(%x.1) # :7:9 +// %a.1 : int, %b.1 : int = prim::ListUnpack(%2) +// %5 : int[] = prim::ListConstruct(%a.1, %b.1) +// %6 : Tensor = aten::new_zeros(%x.1, %5, %1, %1, %1, %1) # +// test/onnx/test_pytorch_onnx_onnxruntime.py:1757:23 return (%6) +// +// after the pass: +// graph(%x.1 : Float(2, 3, strides=[3, 1], requires_grad=0, device=cpu)): +// %1 : None = prim::Constant() +// %2 : int[] = aten::size(%x.1) # :7:9 +// %7 : Tensor = onnx::Constant[value={0}]() +// %8 : Tensor = onnx::Gather(%2, %7) +// %9 : Tensor = onnx::Constant[value={1}]() +// %10 : Tensor = onnx::Gather(%2, %9) +// %a.1 : int, %b.1 : int = prim::ListUnpack(%2) +// %5 : int[] = prim::ListConstruct(%8, %10) +// %6 : Tensor = aten::new_zeros(%x.1, %5, %1, %1, %1, %1) # +// test/onnx/test_pytorch_onnx_onnxruntime.py:1757:23 return (%6) +static void fuseListAndListUnpack(Block* b) { + for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) { + for (auto* child_block : it->blocks()) { + fuseListAndListUnpack(child_block); + } + if (it->kind() == prim::ListUnpack) { + for (size_t i = 0; i < it->outputs().size(); i++) { + auto output = it->outputs().at(i); + if (it->inputs().size() == 1 && + it->input()->node()->kind() != prim::ListConstruct && + it->input()->type()->cast() && + it->input() + ->type() + ->cast() + ->getElementType() + ->cast()) { + Node* gather_indices = b->owningGraph()->create(onnx::Constant, 1); + gather_indices->insertBefore(*it); + gather_indices->t_( + attr::value, at::scalar_to_tensor(at::Scalar(int(i)))); + Node* gather_node = b->owningGraph()->create(onnx::Gather, 1); + gather_node->insertBefore(*it); + gather_node->addInput(it->input()); + gather_node->addInput(gather_indices->output()); + output->replaceAllUsesWith(gather_node->output()); + } + } + } + } +} + } // namespace void PreprocessForONNX(std::shared_ptr& graph) { FuseWithListUnpack(graph->block()); ReplaceAddWithConcat(graph->block()); + fuseListAndListUnpack(graph->block()); } } // namespace jit diff --git a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp index d9ae3ca244fd8..bc26183a25bb7 100644 --- a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp @@ -105,15 +105,29 @@ std::unordered_map MergeSliceAndSelectToIndices( // Loop over fetched slice and select nodes and convert them to index tensors. // keep track of which dimension the current slice/select node is applying to. int64_t cur_dim = 0; - // select does not keep dims, - // this creates offset for latter slice and select nodes. int64_t dim_offset = 0; const auto orig_tensor_indices = index_put_node->input(1)->node()->inputs(); for (auto it = slice_and_select_nodes.rbegin(); it != slice_and_select_nodes.rend(); ++it) { auto node = *it; - auto dim = node->get(attr::dim)->toInt() + dim_offset; + // select does not keep dims, + // this creates offset for latter slice and select nodes. + auto dim = node->get(attr::dim)->toInt(); + if (dim < 0) { + auto input_type = orig_data->type()->expect(); + if (input_type->dim().has_value()) { + auto rank = input_type->dim().value(); + // Rank of original tensor to index on. + // Minus the offset created by select operators. + dim = dim + rank - dim_offset; + } else { + std::cerr + << "Error: ONNX Remove Inplace Ops - Cannot export ellipsis indexing for input " + << "of unknown rank."; + } + } + dim = dim + dim_offset; while (cur_dim < dim) { // Handle skipped dims, these are created from ..., or tensor indices @@ -257,6 +271,93 @@ std::vector ReshapeToAdvancedIndexingFormat( return indices; } +// Register index_put inputs/outputs through the blocks. +// Eg. The IR before updating: +// = prim::Loop(%10, %27) +// block0(%stream_idx.1 : int): +// = prim::Loop(%9, %27) +// block0(%i.1 : int): +// %36 : Tensor = aten::select(%bias.1, %26, %stream_idx.1) +// %41 : Tensor = aten::copy_(%37, %40, %25) +// -> (%27) +// -> (%27) +// After updating: +// %62 : Tensor = prim::Loop(%10, %27, %bias.2) +// block0(%stream_idx.1 : int, %bias.3 : Tensor): +// %61 : Tensor = prim::Loop(%9, %27, %bias.3) +// block0(%i.1 : int, %bias.1 : Tensor): +// %36 : Tensor = aten::select(%bias.1, %26, %stream_idx.1) +// %59 : Tensor?[] = prim::ListConstruct(%55, %58) +// %60 : Tensor = aten::index_put(%bias.1, %59, %45, %25) +// -> (%27, %60) +// -> (%27, %61) +void RegisterIndexPutInBlocks( + Value* orig_data, + Value* new_index_put, + Node* block_node, + Block* outer_block, + Node* next_node) { + auto cur_node = next_node; + while (nullptr != cur_node) { + if (cur_node->kind() != prim::Loop) + return; + cur_node = cur_node->owningBlock()->owningNode(); + } + + for (auto block_input : outer_block->inputs()) { + if (block_input->debugName() == orig_data->debugName()) { + AT_ERROR( + "More than one aten::index_put in a subblock are not supported."); + } + } + + // Register index_put outputs through the blocks. + for (auto block_output : outer_block->outputs()) { + if (block_output->debugName() == new_index_put->debugName()) + return; + } + outer_block->registerOutput(new_index_put); + std::vector> node_list = { + std::make_pair(outer_block, next_node)}; + next_node->addOutput()->copyMetadata(new_index_put); + auto next_block = next_node->owningBlock(); + while (nullptr != next_block->owningNode()) { + outer_block = next_block; + outer_block->registerOutput(next_node->output(0)); + next_node = outer_block->owningNode(); + next_node->addOutput()->copyMetadata(new_index_put); + next_block = next_node->owningBlock(); + node_list.emplace_back(std::make_pair(outer_block, next_node)); + } + + // Register index_put inputs through the blocks. + auto next_data = orig_data; + while (!node_list.empty()) { + auto cur_pair = node_list.back(); + // Add input to current node. + cur_pair.second->addInput(next_data); + // Add input to current block. + auto cur_input = cur_pair.first->addInput(); + cur_input->copyMetadata(next_data); + next_data = cur_input; + node_list.pop_back(); + } + // Update index_put inputs inside the inner most block. + auto prev_data = block_node->input(0); + for (auto node : block_node->owningBlock()->nodes()) { + size_t idx = 0; + for (auto inputs_ : node->inputs()) { + if (inputs_ == prev_data) { + node->replaceInput(idx, next_data); + idx++; + break; + } + } + } + orig_data->replaceAllUsesAfterNodeWith( + next_node->output(0)->node(), next_node->output(0)); +} + // Trace back all the slice & select nodes associated with the index_put node, // and convert them to associated indices. // E.g. The IR for x[1:3, 0] = update @@ -322,7 +423,16 @@ void SquashSliceAndSelect(Node* index_put_node) { new_index_put->copyMetadata(index_put_node->output()); index_put_node->output()->replaceAllUsesWith(new_index_put); - orig_data->replaceAllUsesAfterNodeWith(new_index_put->node(), new_index_put); + auto block_node = new_index_put->node(); + auto outer_block = block_node->owningBlock(); + auto next_node = outer_block->owningNode(); + if (nullptr == next_node) { + orig_data->replaceAllUsesAfterNodeWith( + new_index_put->node(), new_index_put); + return; + } + RegisterIndexPutInBlocks( + orig_data, new_index_put, block_node, outer_block, next_node); } void PrepareCopyForONNX(Block* block) { @@ -340,14 +450,23 @@ void PrepareCopyForONNX(Block* block) { // Remove aten::copy_, and replace it with index_put. // 1. create an empty listConstruct node as indices input for index_put. // 2. create index_put node. + + // Tracing aten::copy_ broadcasts the rhs values. + // 3. Apply broadcasting for scripting. WithInsertPoint guard(node); auto graph = node->owningGraph(); auto dummy_list = graph->insertNode(graph->createList(OptionalType::ofTensor(), {})) ->output(); + + auto expanded_value = + graph->insert(aten::expand_as, {node->input(1), node->input(0)}); + expanded_value->node()->setSourceRange(node->sourceRange()); + expanded_value->copyMetadata(node->input(1)); + auto index_put = graph->insert( aten::index_put, - {node->input(0), dummy_list, node->input(1), node->input(2)}); + {node->input(0), dummy_list, expanded_value, node->input(2)}); index_put->node()->setSourceRange(node->sourceRange()); index_put->copyMetadata(node->output()); node->output()->replaceAllUsesWith(index_put); @@ -452,18 +571,29 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) { << "Warning: ONNX Preprocess - Removing mutation on block inputs. " << "This changes graph semantics." << std::endl; - auto newNode = node->owningGraph()->create(aten::clone, 1); - newNode->output()->copyMetadata(input); - newNode->addInput(input); - - auto* noneNode = node->owningGraph()->create(prim::Constant); - noneNode->output()->setType(NoneType::get()); - newNode->addInput(noneNode->output()); - - newNode->insertBefore(node); - noneNode->insertBefore(newNode); - node->replaceInput(index, newNode->output()); - input->replaceAllUsesAfterNodeWith(node, newNode->output()); + if (input->type()->kind() == TypeKind::ListType) { + // Create an aten::list to clone the list in graph inputs + auto newNode = node->owningGraph()->create(aten::list, 1); + newNode->output()->copyMetadata(input); + newNode->addInput(input); + newNode->insertBefore(node); + node->replaceInput(index, newNode->output()); + input->replaceAllUsesAfterNodeWith(node, newNode->output()); + } else { + // Create an aten::clone to clone the tensor in graph inputs + auto newNode = node->owningGraph()->create(aten::clone, 1); + newNode->output()->copyMetadata(input); + newNode->addInput(input); + + auto* noneNode = node->owningGraph()->create(prim::Constant); + noneNode->output()->setType(NoneType::get()); + newNode->addInput(noneNode->output()); + + newNode->insertBefore(node); + noneNode->insertBefore(newNode); + node->replaceInput(index, newNode->output()); + input->replaceAllUsesAfterNodeWith(node, newNode->output()); + } } } } diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp index 3ba377df491e9..cffa3a709e4ce 100644 --- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp @@ -1,4 +1,5 @@ #include + #include namespace torch { @@ -50,17 +51,15 @@ static const std::unordered_set standardOps = { onnx::Mod, }; -static bool IsStandardOp(const NodeKind& nkind) { - return standardOps.find(nkind) != standardOps.end(); -} - // For these operators, all inputs share the same scalar type. // The output scalar type is always Bool. -static const std::unordered_set comparisonOps = {onnx::Greater, - onnx::Less, - onnx::Equal, - onnx::GreaterOrEqual, - onnx::LessOrEqual}; +static const std::unordered_set comparisonOps = { + onnx::Greater, + onnx::Less, + onnx::Equal, + onnx::GreaterOrEqual, + onnx::LessOrEqual, +}; static bool IsComparisonOp(const NodeKind& nkind) { return comparisonOps.find(nkind) != comparisonOps.end(); @@ -69,6 +68,7 @@ static bool IsComparisonOp(const NodeKind& nkind) { static TensorTypePtr CreateProfiledTensorTypeWithScalarType( const TensorTypePtr& typePtr, const c10::ScalarType& scalar_type) { + AT_ASSERT(typePtr != nullptr); return typePtr->withScalarType({scalar_type}); } @@ -90,9 +90,55 @@ static c10::optional PromoteScalarTypes( return st; } +// Type promotion between scalars and tensors +// per logic here +// https://pytorch.org/docs/master/tensor_attributes.html#tensor-attributes +static c10::optional PromoteScalarTypesWithCategory( + const std::vector& typesFromTensors, + const std::vector& typesFromScalars) { + auto typeFromTensor = PromoteScalarTypes(typesFromTensors); + auto typeFromScalar = PromoteScalarTypes(typesFromScalars); + + auto getTypeCategory = [](c10::ScalarType t) { + if (c10::kBool == t) { + return 1; + } + if (c10::isIntegralType(t, /*includeBool=*/false)) { + return 2; + } + if (c10::isFloatingType(t)) { + return 3; + } + return 0; + }; + + if (c10::nullopt == typeFromScalar) { + return typeFromTensor; + } else if (c10::nullopt == typeFromTensor) { + return typeFromScalar; + } + + auto typeCategoryFromTensor = getTypeCategory(typeFromTensor.value()); + auto typeCategoryFromScalar = getTypeCategory(typeFromScalar.value()); + + if (typeCategoryFromScalar > typeCategoryFromTensor) { + return typeFromScalar; + } + return typeFromTensor; +} + static c10::optional InferExpectedScalarType(const Node* n) { std::vector typesFromTensors; std::vector typesFromScalars; + + auto get_scalar_type = + [](const Value* input) -> c10::optional { + if (auto tensor_type = input->type()->cast()) { + return tensor_type->scalarType(); + } + return c10::nullopt; + }; + std::for_each( n->inputs().begin(), n->inputs().end(), [&](const Value* input) { auto nkind = input->node()->kind(); @@ -108,18 +154,46 @@ static c10::optional InferExpectedScalarType(const Node* n) { // which is by default considered as a tensor. typesFromScalars.emplace_back(c10::kLong); } else if (nkind == onnx::Constant) { - typesFromScalars.emplace_back( - input->node()->t(attr::value).scalar_type()); - } else if ( - auto scalar_type = - input->type()->cast()->scalarType()) { + auto tensor = input->node()->t(attr::value); + auto rank = tensor.dim(); + auto scalar_type = tensor.scalar_type(); + // Mimic PyTorch scalar type promotion logic + // from https://github.com/pytorch/pytorch/issues/9515 + // Quoting: + // A Tensor is a considered a "wrapped number" if it is + // auto-wrapped from a C++ or Python number type. Integer types are + // wrapped as 0-dim int64 tensors and floating-point types are + // wrapped as 0-dim double tensors. + if (rank == 0) { + auto default_scalar_type = + at::typeMetaToScalarType(at::get_default_dtype()); + switch (scalar_type) { + case at::kDouble: + // floating-point numbers wrapped as double tensors are + // considered to have default type, instead of double. + typesFromScalars.emplace_back(default_scalar_type); + break; + case at::kLong: + case at::kBool: + // bool and integer numbers remain the same type. + typesFromScalars.emplace_back(scalar_type); + break; + default: + // other types are not from wrapped numbers, + // track them as types from tensors. + typesFromTensors.emplace_back(scalar_type); + break; + } + } else { + typesFromTensors.emplace_back(scalar_type); + } + } else if (auto scalar_type = get_scalar_type(input)) { typesFromTensors.emplace_back(*scalar_type); } }); c10::optional st = c10::nullopt; - const c10::optional output_st = - n->output()->type()->cast()->scalarType(); + const auto output_st = get_scalar_type(n->output()); if (IsComparisonOp(n->kind())) { // For comparison ops, always promote scalar type to highest among inputs, @@ -130,31 +204,16 @@ static c10::optional InferExpectedScalarType(const Node* n) { typesFromTensors.end()); st = PromoteScalarTypes(typesFromScalars); } else { - if (typesFromScalars.size() == n->inputs().size()) { - // If all inputs are scalars, infer scalar_type by calling - // c10::promoteTypes. - st = PromoteScalarTypes(typesFromScalars); - } else if (output_st) { + if (output_st) { // If output scalar type is available, use that. st = output_st; - } else if (!typesFromTensors.empty()) { - // When inputs consist of tensors and scalars. In PyTorch, scalars are - // implicitly casted to have the same scalar type as input tensors. - st = typesFromTensors[0]; - if (std::any_of( - typesFromTensors.begin(), - typesFromTensors.end(), - [&st](const c10::ScalarType& type) { return type != st; })) { - std::cerr - << "Warning: ONNX Scalar Type Analysis - Scalar types mismatch for tensor inputs of operator " - << n->kind().toDisplayString() - << ". Please report a bug to PyTorch. " - << "The scalar type " << c10::toString(*st) - << " of the first tensor is chosen." << std::endl; - } } else { - // When inputs consist of only scalars. - st = PromoteScalarTypes(typesFromScalars); + // PyTorch now does implicit type promotion regardless whether the inputs + // are tensors or scalars. (Previously only scalars support implicit + // casting). + // Per logic here + // https://pytorch.org/docs/master/tensor_attributes.html#tensor-attributes + st = PromoteScalarTypesWithCategory(typesFromTensors, typesFromScalars); } } @@ -175,7 +234,8 @@ static void UpdateScalarTypeForInputs( for (auto input : n->inputs()) { auto input_tensor_type = input->type()->cast(); - auto input_scalar_type = input_tensor_type->scalarType(); + auto input_scalar_type = + input_tensor_type ? input_tensor_type->scalarType() : c10::nullopt; if ((input->node()->kind() == onnx::Constant) || (input_scalar_type && (*input_scalar_type != scalar_type))) { @@ -211,42 +271,38 @@ static void UpdateScalarTypeForOutput( CreateProfiledTensorTypeWithScalarType(output_tensor_type, scalar_type)); } +static void ImplicitCastNodeForONNX(Node* n) { + if (IsImplicitCastSupported(n->kind())) { + auto expected_scalar_type = InferExpectedScalarType(n); + if (expected_scalar_type) { + UpdateScalarTypeForInputs(n, *expected_scalar_type); + if (!IsComparisonOp(n->kind())) { + UpdateScalarTypeForOutput(n, *expected_scalar_type); + } + } + } +} + static void ImplicitCastForONNX(Block* block) { for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) { for (auto sub : it->blocks()) { ImplicitCastForONNX(sub); } - auto* subgraph = it->owningGraph(); - - if (IsImplicitCastSupported(it->kind())) { - auto expected_scalar_type = InferExpectedScalarType(*it); - if (expected_scalar_type) { - UpdateScalarTypeForInputs(*it, *expected_scalar_type); - if (!IsComparisonOp(it->kind())) { - UpdateScalarTypeForOutput(*it, *expected_scalar_type); - } - } - } + + ImplicitCastNodeForONNX(*it); } EliminateDeadCode( block, true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); } - -// This pass tries to resolve scalar type mismatch issues between input tensors -// introduced by the implicit type conversions on scalars. -// TODO: Note that currently this pass handles traced graph only. -// More specifically, graphs that have scalar type information recorded. -// For scripted graphs we need something like scalar type propagation, -// otherwise we do not have enough information to perform the check, let alone -// fixes. -void ImplicitCastForONNX(const std::shared_ptr& graph) { - ImplicitCastForONNX(graph->block()); -} } // anonymous namespace void ScalarTypeAnalysisForONNX(const std::shared_ptr& graph) { ImplicitCastForONNX(graph->block()); } +void ScalarTypeAnalysisNodeForONNX(Node* n) { + ImplicitCastNodeForONNX(n); +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.h b/torch/csrc/jit/passes/onnx/scalar_type_analysis.h index 65fc2278980b8..8a39085cc8406 100644 --- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.h +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.h @@ -6,6 +6,7 @@ namespace torch { namespace jit { TORCH_API void ScalarTypeAnalysisForONNX(const std::shared_ptr& graph); +void ScalarTypeAnalysisNodeForONNX(Node* n); } // namespace jit -} // namespace torch \ No newline at end of file +} // namespace torch diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index 60725bf0cf32a..89ef8f5afae87 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -1,6 +1,9 @@ #include + #include +#include #include +#include #include #include @@ -40,8 +43,8 @@ TypePtr MergeInferredType(TypePtr existing_type, TypePtr inferred_type) { return new_tensor_type; } auto type = old_tensor_type; - if (new_tensor_type->sizes().isComplete()) { - type = type->withSizes(new_tensor_type->sizes().concrete_sizes().value()); + if (new_tensor_type->dim()) { + type = type->withSymbolicShapes(new_tensor_type->symbolic_sizes()); } if (new_tensor_type->scalarType().has_value()) { type = type->withScalarType(new_tensor_type->scalarType()); @@ -69,7 +72,8 @@ namespace onnx_torch = ::torch::onnx; namespace onnx = ::ONNX_NAMESPACE; TensorTypePtr TorchTensorTypeFromONNX( - const onnx::TypeProto_Tensor& onnx_tensor_type) { + const onnx::TypeProto_Tensor& onnx_tensor_type, + SymbolDimMap& symbol_map) { c10::optional scalar_type; if (onnx_tensor_type.has_elem_type()) { scalar_type = ONNXTypeToATenType(onnx_tensor_type.elem_type()); @@ -82,33 +86,62 @@ TensorTypePtr TorchTensorTypeFromONNX( c10::VaryingShape{}, {}); if (onnx_tensor_type.has_shape()) { - std::vector sizes; + std::vector sizes; auto onnx_shape = onnx_tensor_type.shape(); for (int i = 0; i < onnx_shape.dim_size(); ++i) { auto& dim = onnx_shape.dim(i); if (dim.has_dim_value()) { - sizes.push_back(dim.dim_value()); + sizes.emplace_back(c10::ShapeSymbol::fromStaticSize(dim.dim_value())); } else { - // TODO: handle dim_param? - return v_type; + c10::optional sym = c10::nullopt; + if (dim.has_dim_param()) { + // A specific dim param is produced. + // Search if this is already known, + // and assign the same Symbol. + GRAPH_UPDATE("Got dim_param:", dim.dim_param()); + for (auto pair : symbol_map) { + if (pair.second == dim.dim_param()) { + sym = pair.first; + break; + } + } + if (!sym) { + sym = c10::ShapeSymbol::newSymbol(); + symbol_map[sym.value()] = dim.dim_param(); + } + } else { + // A None dim param is produced. + // Assign a new Symbol, no need to keep track + // of it because there won't be duplicates. + sym = c10::ShapeSymbol::newSymbol(); + } + sizes.emplace_back(sym.value()); } } v_type = TensorType::create(scalar_type, at::kCPU, sizes.size(), {}); - v_type = v_type->withSizes(sizes); + v_type = v_type->withSymbolicShapes(c10::SymbolicShape(sizes)); + + if (v_type->sizes().concrete_sizes().has_value()) { + // Populate strides based on sizes info, if sizes are all static. + // Creating strides ensures yielding True for isCompleteTensor. + v_type = v_type->contiguous(); + } } return v_type; } ListTypePtr TorchListTypeFromONNX( - const onnx::TypeProto_Sequence& onnx_sequence_type) { + const onnx::TypeProto_Sequence& onnx_sequence_type, + SymbolDimMap& symbol_map) { c10::optional scalar_type; if (onnx_sequence_type.has_elem_type()) { auto onnx_seq_elem_type = onnx_sequence_type.elem_type(); if (onnx_seq_elem_type.has_tensor_type()) { auto onnx_tensor_type = onnx_seq_elem_type.tensor_type(); - auto v_tensor_type = TorchTensorTypeFromONNX(onnx_tensor_type); + auto v_tensor_type = + TorchTensorTypeFromONNX(onnx_tensor_type, symbol_map); auto v_type = ListType::create(v_tensor_type); return v_type; } @@ -118,21 +151,24 @@ ListTypePtr TorchListTypeFromONNX( void UpdateTorchValueByOnnxValueInfo( Value* v, - const onnx::ValueInfoProto& p_info) { + const onnx::ValueInfoProto& p_info, + SymbolDimMap& symbol_map) { if (!p_info.has_type()) { return; } auto p_type = p_info.type(); if (p_type.has_tensor_type()) { - auto torch_tensor_type = TorchTensorTypeFromONNX(p_type.tensor_type()); + auto torch_tensor_type = + TorchTensorTypeFromONNX(p_type.tensor_type(), symbol_map); if (torch_tensor_type) { - v->setType(torch_tensor_type); + v->setType(MergeInferredType(v->type(), torch_tensor_type)); } } else if (p_type.has_sequence_type()) { - auto torch_list_type = TorchListTypeFromONNX(p_type.sequence_type()); + auto torch_list_type = + TorchListTypeFromONNX(p_type.sequence_type(), symbol_map); if (torch_list_type) { - v->setType(torch_list_type); + v->setType(MergeInferredType(v->type(), torch_list_type)); } } } @@ -148,59 +184,105 @@ bool IsSupportedNode(const Node* n) { // Skip when block size is zero. This is when the node is first created, // doesn't have subblocks attached yet. Run shape inference for these nodes // when the subgraph has already completed shape inferencing. - if ((node_kind == ::c10::onnx::Loop || node_kind == ::c10::onnx::If) && - n->blocks().size() == 0) { - return false; + if (node_kind == ::c10::onnx::Loop || node_kind == ::c10::onnx::If) { + if (n->blocks().size() == 0) { + return false; + } + for (auto b : n->blocks()) { + for (auto b_n : b->nodes()) { + if (!IsSupportedNode(b_n)) { + return false; + } + } + } } return true; } +Value* CloneValueFromListConstruct(Value* v, std::shared_ptr n_graph) { + auto lc_node = v->node(); + TORCH_INTERNAL_ASSERT(lc_node->kind() == ::c10::prim::ListConstruct); + // In jit/passes/onnx/peephole.cpp::eraseListConstruct, + // prim::ListConstruct is converted to onnx::Concat. The conversion should + // eventually be moved to symbolic. For now, treat this operator as + // special case, and change from list type to tensor type. The scalar type + // is preserved. If the elemtype is Int, insert a onnx::Concat node into + // the graph. + TypePtr elem = v->type()->cast()->getElementType(); + c10::optional scalar_type = c10::nullopt; + if (elem->cast()) { + scalar_type = at::kLong; + + auto lc_node = v->node(); + // ListConstruct Int[] output case, we need to transform to ONNX + // Concat to ensure the output is a single tensor(dynamic) type in + // order to be consumed as inputs + std::vector unsqueezed; + for (auto* input : lc_node->inputs()) { + Node* unsqueezed_node = + n_graph->insertNode(n_graph->create(::c10::onnx::Unsqueeze, 1)); + auto new_input = n_graph->addInput(); + new_input->copyMetadata(input); + unsqueezed_node->addInput(new_input); + unsqueezed_node->is_(attr::axes, {0}); + unsqueezed.emplace_back(unsqueezed_node->output()); + } + Node* concat_node = + n_graph->insertNode(n_graph->create(::c10::onnx::Concat, 1)); + concat_node->i_(attr::axis, 0); + for (auto v : unsqueezed) { + concat_node->addInput(v); + } + return concat_node->output(); + } else if (elem->cast()) { + scalar_type = at::kFloat; + } else if (elem->cast()) { + scalar_type = at::kBool; + } else if (auto t_type = elem->cast()) { + scalar_type = t_type->scalarType(); + } + + auto input = n_graph->addInput(); + if (scalar_type) { + auto v_type = TensorType::create( + scalar_type.value(), + at::kCPU, + c10::SymbolicShape(), + c10::VaryingShape{}, + {}); + input->setType(v_type); + } + return input; +} + // Clone the node n for the new graph. Node* CloneNodeToGraph(Node* n, std::shared_ptr n_graph) { auto clone_node = n_graph->createClone(n, [&n_graph](Value* v) { auto v_n = v->node(); - if (v_n->kind() == ::c10::onnx::Constant) { - // Clone the input if it is constant. - auto constant_n = n_graph->insertNode( - n_graph->createClone(v_n, [](Value* v) { return v; })); - return constant_n->output(); - } else if (v_n->kind() == ::c10::prim::ListConstruct) { - // In jit/passes/onnx/peephole.cpp::eraseListConstruct, - // prim::ListConstruct is converted to onnx::Concat. The conversion should - // eventually be moved to symbolic. For now, treat this operator as - // special case, and change from list type to tensor type. The scalar type - // is preserved. - TypePtr elem = v->type()->cast()->getElementType(); - c10::optional scalar_type = c10::nullopt; - if (elem->cast()) { - scalar_type = at::kLong; - } else if (elem->cast()) { - scalar_type = at::kFloat; - } else if (elem->cast()) { - scalar_type = at::kBool; - } else if (auto t_type = elem->cast()) { - scalar_type = t_type->scalarType(); + switch (v_n->kind()) { + case ::c10::onnx::Constant: { + // Clone the input if it is constant. + auto constant_n = n_graph->insertNode( + n_graph->createClone(v_n, [](Value* v) { return v; })); + return constant_n->output(); } - - auto input = n_graph->addInput(); - if (scalar_type) { - auto v_type = TensorType::create( - scalar_type.value(), - at::kCPU, - c10::SymbolicShape(), - c10::VaryingShape{}, - {}); - input->setType(v_type); + case ::c10::prim::ListConstruct: { + return CloneValueFromListConstruct(v, n_graph); + } + case ::c10::prim::PackPadded: { + auto input = n_graph->addInput(); + input->copyMetadata(v_n->input(0)); + return input; + } + default: { + // If the input is not constant, we cannot depend on its value + // in shape inference. Set it to graph input in the new graph, + // and copy over metadata, such as datatype and shape. + auto input = n_graph->addInput(); + input->copyMetadata(v); + return input; } - return input; - } else { - // If the input is not constant, we cannot depend on its value - // in shape inference. Set it to graph input in the new graph, - // and copy over metadata, such as datatype and shape. - auto input = n_graph->addInput(); - input->copyMetadata(v); - return input; } }); return clone_node; @@ -233,11 +315,11 @@ bool IsGraphValidForInference(std::shared_ptr graph) { void ConvertGraphToONNXProto( std::shared_ptr graph, - onnx::ModelProto& model_proto, + std::shared_ptr& model_proto, + SymbolDimMap& symbol_map, int opset_version) { - std::string model_str; RawDataExportMap export_map; - std::tie(model_str, export_map) = export_onnx( + std::tie(model_proto, export_map, symbol_map) = export_onnx( graph, {}, opset_version, @@ -250,15 +332,46 @@ void ConvertGraphToONNXProto( true, false, std::string()); - model_proto.ParseFromString(model_str); - for (int i = 0; i < model_proto.graph().output_size(); ++i) { - model_proto.mutable_graph()->mutable_output(i)->clear_type(); + for (int i = 0; i < model_proto->graph().output_size(); ++i) { + model_proto->mutable_graph()->mutable_output(i)->clear_type(); } } +// this function checks wheather the blocks of If node have the same return +// type. +bool IsBlockReturnTypeSame(Node* n) { + TORCH_INTERNAL_ASSERT(n->kind() == ::c10::onnx::If); + auto then_block = n->blocks()[0]; + auto else_block = n->blocks()[1]; + for (size_t i = 0; i < n->outputs().size(); i++) { + // check the type + auto then_block_type = then_block->outputs()[i]->type(); + auto else_block_type = else_block->outputs()[i]->type(); + if (then_block_type->cast() && + else_block_type->cast()) { + if (then_block_type->cast()->scalarType() != + else_block_type->cast()->scalarType()) { + return false; + } + } + } + return true; +} + // Any additional post process that are specific to individual node kind. void SpecialPostProcess(Node* n) { switch (n->kind()) { + case ::c10::onnx::If: { + if (!IsBlockReturnTypeSame(n) && IsStaticConditionONNX(n)) { + auto cond = ConditionValueONNX(n); + auto block_idx = cond ? 0 : 1; + for (size_t i = 0; i < n->outputs().size(); i++) { + n->outputs()[i]->setType( + n->blocks()[block_idx]->outputs()[i]->type()); + } + } + break; + } case ::c10::onnx::SequenceInsert: { // Special case when input sequence to SequenceInsert is empty. // onnx Sequence type requires element type to be set. @@ -284,20 +397,49 @@ void SpecialPostProcess(Node* n) { void UpdateOutputTypeByONNXProto( Node* n, Node* clone_node, - const onnx::ModelProto& model_proto) { + const onnx::ModelProto& model_proto, + SymbolDimMap& symbol_map) { auto graph_proto = model_proto.graph(); - // inferred shapes are stored in value_info. + + // get data from value_info and updated original graph. + auto updateNodeOutputsByONNXValueInfo = + [&](const onnx::ValueInfoProto& v_info) { + for (size_t i = 0; i < n->outputs().size(); ++i) { + if (clone_node->output(i)->debugName() == v_info.name()) { + UpdateTorchValueByOnnxValueInfo(n->output(i), v_info, symbol_map); + } + } + }; + + // Check graph outputs for inferred shapes. + for (size_t i = 0; i < graph_proto.output_size(); ++i) { + updateNodeOutputsByONNXValueInfo(graph_proto.output(i)); + } + + // Check value_infos for inferred shapes. for (size_t i = 0; i < graph_proto.value_info_size(); ++i) { - auto v_info = graph_proto.value_info(i); - // get data from value_info and updated original graph. - for (size_t j = 0; j < clone_node->outputs().size(); ++j) { - if (clone_node->output(j)->debugName() == v_info.name()) { - UpdateTorchValueByOnnxValueInfo(n->output(j), v_info); - } + updateNodeOutputsByONNXValueInfo(graph_proto.value_info(i)); + } +} + +void FetchBlockInputMetadataFromParent(Block* b) { + auto n = b->owningNode(); + if (nullptr != n && n->kind() == ::c10::onnx::Loop) { + // Copy node input metadata to subgraph input. + for (size_t i = 0; i < n->inputs().size(); ++i) { + b->inputs().at(i)->copyMetadata(n->inputs().at(i)); } } +} - SpecialPostProcess(n); +void ONNXShapeTypeInference(Block* b, int opset_version) { + FetchBlockInputMetadataFromParent(b); + for (auto n : b->nodes()) { + for (auto subblock : n->blocks()) { + ONNXShapeTypeInference(subblock, opset_version); + } + ONNXShapeTypeInference(n, opset_version); + } } } // namespace @@ -313,35 +455,121 @@ void ONNXShapeTypeInference(Node* n, int opset_version) { auto n_graph = std::make_shared(); auto clone_node = CloneNodeToGraph(n, n_graph); n_graph->insertNode(clone_node); + // Register all node outputs as graph outputs. for (auto output : clone_node->outputs()) { n_graph->registerOutput(output); } + ScalarTypeAnalysisForONNX(n_graph); + GRAPH_DEBUG("Original torch graph: ", n->owningGraph()->toString()); GRAPH_DEBUG( "Cloned torch graph to run shape inference: ", n_graph->toString()); - if (!IsGraphValidForInference(n_graph)) { - GRAPH_UPDATE("Skipping ONNX shape inference for this node."); - return; + if (IsGraphValidForInference(n_graph)) { + // TODO: Some ops have conversion happen at Peephole pass. + // The conversion here is incomplete for these ops. + // e.g: ListConstruct, ListUnpack, etc. + std::shared_ptr model_proto; + SymbolDimMap symbol_map; + ConvertGraphToONNXProto(n_graph, model_proto, symbol_map, opset_version); + GRAPH_DEBUG( + "ONNX graph to run shape inference: ", prettyPrint(*model_proto)); + + // infer shape + try { + onnx::shape_inference::InferShapes(*model_proto); + UpdateOutputTypeByONNXProto(n, clone_node, *model_proto, symbol_map); + } catch (std::runtime_error& ex) { + // TODO: include this as warning once we have a more consolidated warning + // system. + const char shape_err[] = "ShapeInferenceError"; + const char type_err[] = "TypeInferenceError"; + if ((strstr(ex.what(), shape_err) == NULL) && + (strstr(ex.what(), type_err) == NULL)) + throw; + GRAPH_DEBUG("ONNX shape inference fails with: ", ex.what()); + } + GRAPH_DEBUG( + "ONNX graph after shape inference: ", prettyPrint(*model_proto)); } - // TODO: Some ops have conversion happen at Peephole pass. - // The conversion here is incomplete for these ops. - // e.g: ListConstruct, ListUnpack, etc. - onnx::ModelProto model_proto; - ConvertGraphToONNXProto(n_graph, model_proto, opset_version); - GRAPH_DEBUG("ONNX graph to run shape inference: ", prettyPrint(model_proto)); - - // infer shape - onnx::shape_inference::InferShapes(model_proto); - GRAPH_DEBUG("ONNX graph after shape inference: ", prettyPrint(model_proto)); - - UpdateOutputTypeByONNXProto(n, clone_node, model_proto); + SpecialPostProcess(n); GRAPH_DEBUG( "Torch graph after shape inference:", n->owningGraph()->toString()); } +void ONNXSetDynamicInputShape( + std::shared_ptr& graph, + const std::unordered_map< + std::string, + std::unordered_map>& dynamic_axes, + const std::vector& input_names) { + GRAPH_UPDATE("ONNX set dynamic input shape."); + GRAPH_UPDATE("dynamic axes tensor names:", [&]() { + std::vector res(dynamic_axes.size()); + std::transform( + dynamic_axes.begin(), dynamic_axes.end(), res.begin(), [](auto pair) { + return pair.first; + }); + return res; + }()); + + std::map name_to_sym; + + for (int i = 0; i < input_names.size(); ++i) { + auto input_name = input_names[i]; + if (dynamic_axes.find(input_name) != dynamic_axes.end()) { + auto axes_names = dynamic_axes.find(input_name)->second; + TORCH_INTERNAL_ASSERT(i < graph->inputs().size()); + auto input_tensor_type = graph->inputs()[i]->type()->cast(); + if (!input_tensor_type) { + continue; + } + + auto shape_ref = input_tensor_type->symbolic_sizes().sizes(); + TORCH_CHECK( + shape_ref.has_value(), "Input tensor shape should have value."); + auto shape = shape_ref.value(); + + for (auto pair : axes_names) { + auto axis = pair.first; + auto name = pair.second; + if (name_to_sym.find(name) == name_to_sym.end()) { + name_to_sym[name] = ::c10::ShapeSymbol::newSymbol(); + } + TORCH_CHECK( + axis < shape.size(), + "Dynamic shape axis should be no more than the shape dimension for ", + name); + shape[axis] = name_to_sym[name]; + } + + graph->inputs()[i]->setType( + input_tensor_type->withSymbolicShapes(::c10::SymbolicShape(shape))); + } + } +} + +void ONNXAssignOutputShape( + std::shared_ptr& graph, + at::ArrayRef outputs, + bool onnx_shape_inference) { + TORCH_INTERNAL_ASSERT(graph->outputs().size() == outputs.size()); + for (size_t i = 0; i < outputs.size(); ++i) { + if (onnx_shape_inference) { + graph->outputs()[i]->setType(MergeInferredType( + TensorType::create(outputs[i]), graph->outputs()[i]->type())); + } else { + graph->outputs()[i]->inferTypeFrom(outputs[i]); + } + } +} + +void ONNXShapeTypeInference(std::shared_ptr& graph, int opset_version) { + ONNXShapeTypeInference(graph->block(), opset_version); +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.h b/torch/csrc/jit/passes/onnx/shape_type_inference.h index d373f3a06d92a..79e7c06045ea9 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.h +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.h @@ -8,11 +8,39 @@ namespace jit { TORCH_API TypePtr MergeInferredType(TypePtr existing_type, TypePtr inferred_type); +// Update graph input types with dynamic axes info. +// Axes that are marked as dynamic will be assigned as dynamic ShapeSymbol. +// Note it is possible for multiple axes to share the same ShapeSymbol, +// if they are defined as such in dynamic_axes. +TORCH_API void ONNXSetDynamicInputShape( + std::shared_ptr& graph, + const std::unordered_map< + std::string, + std::unordered_map>& dynamic_axes, + const std::vector& input_names); + +// Update graph output with types of output Tensors. +// If onnx_shape_inference is true, types of output Tensors will be compared and +// merged with inferred types. It is possible that inferred types contain +// dynamic axes, hence it takes precedence over types of output Tensors. +TORCH_API void ONNXAssignOutputShape( + std::shared_ptr& graph, + at::ArrayRef outputs, + bool onnx_shape_inference); + // Utilize ONNX Shape Inference for node. // The node must have ONNX namespace, and is valid ONNX node accroding to spec. // On successful ONNX shape inference runs, the function updates output types of // n with inferred shape and type. Otherwise n is unchanged. TORCH_API void ONNXShapeTypeInference(Node* n, int opset_version); +// Utilize ONNX Shape Inference for graph. +// Internally calls ONNXShapeTypeInference for each node, to achieve more +// coverage that skips only individual nodes if illegal, instead of skipping for +// the entire graph. +TORCH_API void ONNXShapeTypeInference( + std::shared_ptr& g, + int opset_version); + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp index 0de7f365e407a..b9b9fe1d43843 100644 --- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp +++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -25,20 +26,21 @@ using namespace ::c10::onnx; double getScaleFromInput(Node* input_node) { c10::optional scale; std::string input_name = input_node->kind().toQualString(); - std::unordered_set noscale_ops = {"quantized::max_pool2d", - "aten::max_pool2d", - "aten::relu", - "prim::ListUnpack", - "aten::split_with_sizes", - "quantized::nchw2nhwc", - "quantized::nhwc2nchw", - "aten::slice", - "aten::avg_pool2d", - "quantized::cat", - "prim::ListConstruct", - "aten::upsample_nearest2d", - "aten::sigmoid", - "aten::reshape"}; + std::unordered_set noscale_ops = { + "quantized::max_pool2d", + "aten::max_pool2d", + "aten::relu", + "prim::ListUnpack", + "aten::split_with_sizes", + "quantized::nchw2nhwc", + "quantized::nhwc2nchw", + "aten::slice", + "aten::avg_pool2d", + "quantized::cat", + "prim::ListConstruct", + "aten::upsample_nearest2d", + "aten::sigmoid", + "aten::reshape"}; if (input_name == "aten::quantize_per_tensor") { TORCH_CHECK( input_node->inputs().size() > 1, @@ -273,10 +275,11 @@ void unpackQuantizedWeightsHelper( std::vector wt_sizes = unpacked_weight.sizes().vec(); if (unpacked_weight.ndimension() == 4) { unpacked_weight.permute({0, 2, 3, 1}); - wt_sizes = {unpacked_weight.size(0), - unpacked_weight.size(2), - unpacked_weight.size(3), - unpacked_weight.size(1)}; + wt_sizes = { + unpacked_weight.size(0), + unpacked_weight.size(2), + unpacked_weight.size(3), + unpacked_weight.size(1)}; } // Remove packed_params diff --git a/torch/csrc/jit/passes/pass_manager.cpp b/torch/csrc/jit/passes/pass_manager.cpp index 1addeb6eab6de..cbd6daa34205d 100644 --- a/torch/csrc/jit/passes/pass_manager.cpp +++ b/torch/csrc/jit/passes/pass_manager.cpp @@ -61,7 +61,7 @@ void clearAllPrePasses() { // LEGACY CALL RegisterPostPass::RegisterPostPass(GraphPass p) { - registerPass(p); + registerPass(std::move(p)); } } // namespace jit diff --git a/torch/csrc/jit/passes/pass_manager.h b/torch/csrc/jit/passes/pass_manager.h index 73109f3d25460..e9aeb09e8c0cc 100644 --- a/torch/csrc/jit/passes/pass_manager.h +++ b/torch/csrc/jit/passes/pass_manager.h @@ -129,6 +129,9 @@ struct C10_EXPORT PassManager { isRegistered(true); } } + + // clang-tidy requires virtual destructor; + virtual ~PassManager() = default; }; } // namespace jit diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp index 9035a81b74908..f3786c5524bb7 100644 --- a/torch/csrc/jit/passes/peephole.cpp +++ b/torch/csrc/jit/passes/peephole.cpp @@ -1,9 +1,11 @@ #include + #include #include #include #include #include +#include #include #include #include @@ -25,6 +27,7 @@ struct PeepholeOptimizeImpl { : graph_(graph), shape_peepholes_(!disable_shape_peepholes) { run(graph->block()); PeepholeOptimizeListIdioms(graph); + PeepholeOptimizeAliasSensitive(graph); } // The intent for this optimization pass is to catch all of the small, easy to @@ -87,8 +90,9 @@ struct PeepholeOptimizeImpl { "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor", /*const_inputs=*/attr::size)) { // x.expand(x.size()) == x - if (auto input_type = - node->namedInput(attr::self)->type()->cast()) { + auto input_type = + node->namedInput(attr::self)->type()->cast(); + if (input_type && shape_peepholes_) { auto expanded_sizes = node->get>(attr::size); auto input_type_sizes = input_type->sizes().concrete_sizes(); if (expanded_sizes.has_value() && input_type_sizes && @@ -110,8 +114,9 @@ struct PeepholeOptimizeImpl { input_node->input()->debugName()); node->output()->replaceAllUsesWith(input_node->input()); } - } else if (node->matches( - "aten::type_as(Tensor self, Tensor other) -> Tensor")) { + } else if ( + node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor") && + shape_peepholes_) { // x.type_as(y) == x iff x.type() == y.type() auto self_type = node->input(0)->type()->expect(); auto other_type = node->input(1)->type()->expect(); @@ -420,11 +425,11 @@ void FuseAddMM(Block* block) { // Attempts to find a matrix with a defined scalar type to type as auto* type_as_mat = mat1; - if (!type_as_mat->type()->expect()->scalarType()) { + if (!type_as_mat->type()->expectRef().scalarType()) { type_as_mat = mat2; } auto mat_scalar_type = - type_as_mat->type()->expect()->scalarType(); + type_as_mat->type()->expectRef().scalarType(); // we can't use type_as if we don't know the target type (mm), the // bias needs to be coerced to diff --git a/torch/csrc/jit/passes/peephole_alias_sensitive.cpp b/torch/csrc/jit/passes/peephole_alias_sensitive.cpp new file mode 100644 index 0000000000000..13a51914d9526 --- /dev/null +++ b/torch/csrc/jit/passes/peephole_alias_sensitive.cpp @@ -0,0 +1,79 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +// This pass only does optimizations which requires Alias Analysis +// It is seprated out from Peephole Pass so that Peephole does not have +// maintain alias db correctness throughout the pass. +// In the future `runAliasingSensitivePeepholeTransformations` +// in peephole.cpp can be incorporated and keep the alias-db +// correct throughout transformations so we only need to build it once +struct PeepholeOptimizeAliasSensitiveImpl { + PeepholeOptimizeAliasSensitiveImpl(std::shared_ptr graph) + : graph_(std::move(graph)), + aliasDb_(torch::make_unique(graph_)) { + run(graph_->block()); + } + + private: + void replaceWithIValue(Value* v, IValue val) { + WithInsertPoint guard(v->node()); + v->replaceAllUsesWith(v->owningGraph()->insertConstant(val)); + } + + void run(Block* block) { + for (Node* node : block->nodes()) { + for (Block* b : node->blocks()) { + run(b); + } + + // dim(conv(x)) extremely common and prevents Conv->BN fusion + if (node->kind() == aten::conv1d || node->kind() == aten::conv2d || + node->kind() == aten::conv3d) { + auto dim_uses = c10::filter(node->output()->uses(), [](const Use& use) { + return use.user->kind() == aten::dim; + }); + if (dim_uses.size() == 0) { + continue; + } + auto kind = node->kind(); + int64_t output_size = + kind == aten::conv1d ? 3 : (kind == aten::conv2d ? 4 : 5); + // this is to handle potential resize_ calls, however unlikely + // if we add more checks related to resize_ in the graph, + // factor this out like collectResizeSet in shape_analysis + if (!aliasDb_->hasWriters(node->output())) { + for (const Use& dim_use : dim_uses) { + replaceWithIValue(dim_use.user->output(), output_size); + } + } else { + for (const Use& dim_use : dim_uses) { + if (aliasDb_->moveAfterTopologicallyValid(node, dim_use.user)) { + replaceWithIValue(dim_use.user->output(), output_size); + } + } + } + continue; + } + } + } + + std::shared_ptr graph_; + std::unique_ptr aliasDb_; +}; + +void PeepholeOptimizeAliasSensitive(const std::shared_ptr& graph) { + PeepholeOptimizeAliasSensitiveImpl opt(graph); +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/peephole_alias_sensitive.h b/torch/csrc/jit/passes/peephole_alias_sensitive.h new file mode 100644 index 0000000000000..8148b7451c283 --- /dev/null +++ b/torch/csrc/jit/passes/peephole_alias_sensitive.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + +// Peephole Optimizes alias sensitive peepholes +// Currently this is invoked as part of PeepholeOptimize +TORCH_API void PeepholeOptimizeAliasSensitive( + const std::shared_ptr& graph); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/peephole_list_idioms.cpp b/torch/csrc/jit/passes/peephole_list_idioms.cpp index 2007b1b603daf..12f710ccbc808 100644 --- a/torch/csrc/jit/passes/peephole_list_idioms.cpp +++ b/torch/csrc/jit/passes/peephole_list_idioms.cpp @@ -25,8 +25,9 @@ c10::optional normalizeIndex(int64_t index, size_t len) { // so we first use the Alias Db to collect the set of list values // which we shouldn't optimize. struct PeepholeOptimizeListIdiomsImpl { - PeepholeOptimizeListIdiomsImpl(const std::shared_ptr& graph) - : graph_(graph), aliasDb_(torch::make_unique(graph_)) { + PeepholeOptimizeListIdiomsImpl(std::shared_ptr graph) + : graph_(std::move(graph)), + aliasDb_(torch::make_unique(graph_)) { collectMutatedLists(graph_->block()); run(graph_->block()); } @@ -88,6 +89,45 @@ struct PeepholeOptimizeListIdiomsImpl { } } } + } else if (node->kind() == prim::ListUnpack) { + auto list_creation_node = first_input->node(); + if (list_creation_node->kind() == prim::ListConstruct) { + // if sizes are unequal it's a runtime error + if (list_creation_node->inputs().size() != node->outputs().size()) { + continue; + } + for (size_t i = 0; i < node->outputs().size(); ++i) { + node->output(i)->replaceAllUsesWith( + list_creation_node->inputs().at(i)); + } + } + } else if (node->kind() == aten::add) { + if (node->inputs().size() != 2) { + continue; + } + auto second_input = node->inputs().at(1); + // already checked first, need to check second + if (mutated_lists_.count(second_input)) { + continue; + } + if (first_input->node()->kind() != prim::ListConstruct || + second_input->node()->kind() != prim::ListConstruct) { + continue; + } + WithInsertPoint guard(node); + auto list_construct = + graph_->insertNode(graph_->create(prim::ListConstruct)); + list_construct->output()->setType(node->output()->type()); + for (Value* v : first_input->node()->inputs()) { + list_construct->addInput(v); + } + for (Value* v : second_input->node()->inputs()) { + list_construct->addInput(v); + } + node->output()->replaceAllUsesWith(list_construct->output()); + if (mutated_lists_.count(node->output())) { + mutated_lists_.insert(list_construct->output()); + } } } } diff --git a/torch/csrc/jit/passes/quantization/dedup_module_uses.cpp b/torch/csrc/jit/passes/quantization/dedup_module_uses.cpp index c3b78a0d6123c..2f20dc8df62d2 100644 --- a/torch/csrc/jit/passes/quantization/dedup_module_uses.cpp +++ b/torch/csrc/jit/passes/quantization/dedup_module_uses.cpp @@ -1,4 +1,5 @@ #include + #include #include diff --git a/torch/csrc/jit/passes/quantization/finalize.cpp b/torch/csrc/jit/passes/quantization/finalize.cpp index 635c02728f6b8..af4a7bbc332bc 100644 --- a/torch/csrc/jit/passes/quantization/finalize.cpp +++ b/torch/csrc/jit/passes/quantization/finalize.cpp @@ -1,5 +1,7 @@ #include + #include +#include #include #include #include @@ -65,10 +67,14 @@ void InsertPrepackUnpack(Module& module) { void FoldQuantizedPrepackingOps(Module& module) { auto filter_fn = [](const Node* n) -> bool { return ( - (n->kind() == Symbol::fromQualString("quantized::linear_prepack")) || + n->kind() == Symbol::fromQualString("quantized::linear_prepack") || n->kind() == Symbol::fromQualString("quantized::conv1d_prepack") || n->kind() == Symbol::fromQualString("quantized::conv2d_prepack") || - n->kind() == Symbol::fromQualString("quantized::conv3d_prepack")); + n->kind() == Symbol::fromQualString("quantized::conv3d_prepack") || + n->kind() == + Symbol::fromQualString("quantized::conv_transpose1d_prepack") || + n->kind() == + Symbol::fromQualString("quantized::conv_transpose2d_prepack")); }; PrePackingOpsFolder(module, filter_fn, "quantized"); } @@ -77,6 +83,15 @@ Module Finalize( Module& module, QuantType quant_type, const std::vector& preserved_attrs) { + // Tracing annotates the resulting graph with shape information. In many case, + // user applies different input shapes to traced graph. It is on the user to + // know it is correct to do so. The quantized module needs to be clean up and + // To prevent the JIT optimizations from leveraging the annotated shape info, + // clear shape information in the graph. + for (auto func : module.type()->methods()) { + ClearProfilingInformation(func->graph()); + } + auto graph = module.get_method("forward").graph(); InsertPrepackUnpack(graph); GRAPH_DUMP("Before QuantFusion:", graph); diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp index ddaf150803fee..0692d4c653a19 100644 --- a/torch/csrc/jit/passes/quantization/helper.cpp +++ b/torch/csrc/jit/passes/quantization/helper.cpp @@ -1,4 +1,5 @@ #include + #include namespace torch { @@ -32,6 +33,8 @@ std::vector _static_quantizable_aten_funcs = { "conv1d", "conv2d", "conv3d", + "conv_transpose1d", + "conv_transpose2d", "linear", "hardswish", "hardswish_", @@ -174,17 +177,19 @@ const int _sym_zero_point = 128; std::tuple _per_tensor_asym_qparam = std::make_tuple( c10::kPerTensorAffine, - QParamVector({std::make_pair(".scale", IValue(_asym_scale)), - std::make_pair(".zero_point", IValue(_asym_zero_point)), - std::make_pair(".scalar_type", IValue(c10::kQUInt8))})); + QParamVector( + {std::make_pair(".scale", IValue(_asym_scale)), + std::make_pair(".zero_point", IValue(_asym_zero_point)), + std::make_pair(".scalar_type", IValue(c10::kQUInt8))})); // quantization parrameters for ops with range -1 to 1 // for example: aten/src/ATen/native/quantized/cpu/qtanh.cpp std::tuple _per_tensor_sym_qparam = std::make_tuple( c10::kPerTensorAffine, - QParamVector({std::make_pair(".scale", IValue(_sym_scale)), - std::make_pair(".zero_point", IValue(_sym_zero_point)), - std::make_pair(".scalar_type", IValue(c10::kQUInt8))})); + QParamVector( + {std::make_pair(".scale", IValue(_sym_scale)), + std::make_pair(".zero_point", IValue(_sym_zero_point)), + std::make_pair(".scalar_type", IValue(c10::kQUInt8))})); // Map from aten op symbol to the quantization parameters // for the ops with fixed quantization parameters @@ -216,10 +221,11 @@ std::vector _propagate_quant_single_input_ops = {"cat"}; // the inputs are quantized // if the second input is a Scalar, we'll only look at the first input to decide // if we need to quantize the output -std::vector _propagate_quant_binary_ops = {"add", - "add_", - "mul", - "mul_"}; +std::vector _propagate_quant_binary_ops = { + "add", + "add_", + "mul", + "mul_"}; // Check if `use` is an aten function of name `func_name` and if value // `v` is the nth argument (if provided) of the function. @@ -270,11 +276,14 @@ bool isWeight(Value* v) { v, // ate::embedding_bag(%weight, %input, %offsets, %scale_grad_by_freq, // %mode_enum, %sparse, %per_sample_weights, %include_last_offset) - AtenFuncArgs({{"conv1d", 1}, - {"conv2d", 1}, - {"conv3d", 1}, - {"linear", 1}, - {"embedding_bag", 0}}), + AtenFuncArgs( + {{"conv1d", 1}, + {"conv2d", 1}, + {"conv3d", 1}, + {"conv_transpose1d", 1}, + {"conv_transpose2d", 1}, + {"linear", 1}, + {"embedding_bag", 0}}), // embedding_bag - prim::CallFunction(%func, %input.1, %weight, // %offsets.1, %max_norm, %norm_type, %scale_grad_by_freq, %mode, %sparse, // %per_sample_weights.1, %include_last_offset) @@ -286,7 +295,12 @@ bool isBiasOfConvOrLinear(Value* v) { bool result = matchArgPattern( v, AtenFuncArgs( - {{"conv1d", 2}, {"conv2d", 2}, {"conv3d", 2}, {"linear", 2}}), + {{"conv1d", 2}, + {"conv2d", 2}, + {"conv3d", 2}, + {"conv_transpose1d", 2}, + {"conv_transpose2d", 2}, + {"linear", 2}}), CallFuncArgs({{"linear", 3}})); return result; } @@ -518,7 +532,7 @@ bool useQuantizable(const Use& use, QuantType quant_type) { std::shared_ptr getCallFunctionGraph(Node* n) { auto* func_node = n->input(0)->node(); - auto func = func_node->output()->type()->expect()->function(); + auto func = func_node->output()->type()->expectRef().function(); TORCH_CHECK( func->isGraphFunction(), "Quantization only works for graph function"); return func->graph(); @@ -728,6 +742,20 @@ bool is_conv3d_module( match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv3d"); } +bool is_conv_transpose1d_module( + const Match& match, + const std::unordered_map& vmap) { + return is_module( + match, vmap, "conv", "__torch__.torch.nn.modules.conv.ConvTranspose1d"); +} + +bool is_conv_transpose2d_module( + const Match& match, + const std::unordered_map& vmap) { + return is_module( + match, vmap, "conv", "__torch__.torch.nn.modules.conv.ConvTranspose2d"); +} + bool is_batchnorm2d_module( const Match& match, const std::unordered_map& vmap) { diff --git a/torch/csrc/jit/passes/quantization/helper.h b/torch/csrc/jit/passes/quantization/helper.h index 440134ccbd3f0..f473b4b7caa86 100644 --- a/torch/csrc/jit/passes/quantization/helper.h +++ b/torch/csrc/jit/passes/quantization/helper.h @@ -194,6 +194,14 @@ bool is_conv3d_module( const Match& match, const std::unordered_map& vmap); +bool is_conv_transpose1d_module( + const Match& match, + const std::unordered_map& vmap); + +bool is_conv_transpose2d_module( + const Match& match, + const std::unordered_map& vmap); + bool is_batchnorm2d_module( const Match& match, const std::unordered_map& vmap); diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp index f637a681211e2..649869c24736d 100644 --- a/torch/csrc/jit/passes/quantization/insert_observers.cpp +++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -164,6 +165,14 @@ class ModuleCloneHelper { for (auto& fn : type->methods()) { clone_method(module, r, *fn, module_qconfig_map, type_remap); } + // Execute __setstate__(__getstate__()) to initialize custom class + // members. + if (auto setstate_method = r.find_method("__setstate__")) { + auto getstate_method = r.find_method("__getstate__"); + TORCH_INTERNAL_ASSERT(getstate_method, "expect __getstate__"); + auto state = (*getstate_method)(Stack{}); + (*setstate_method)(Stack{state}); + } } return r; } @@ -188,7 +197,7 @@ class ModuleCloneHelper { } for (Node* node : block->nodes()) { // remapping type for module instance - if (node->kind() == prim::CallMethod) { + if (node->kind() == prim::CallMethod || node->kind() == prim::GetAttr) { Value* instance = node->inputs()[0]; auto child_opt = getInvokedModuleOpt(source, node, self); if (child_opt.has_value()) { @@ -386,7 +395,17 @@ class InsertObserversHelper { // are observed bool shouldObserve( Node* n, - const std::unordered_set& block_observed_values) { + const std::unordered_set& block_observed_values, + QuantType quant_type) { + // Check whether node output uses can be quantized, eg cat followed by + // linear op + for (Value* v : n->outputs()) { + for (const auto& use : v->uses()) { + if (useQuantizable(use, quant_type)) { + return true; + } + } + } if (isPropagateQuantSingleInputOp(n)) { return isObserved(n->input(0), block_observed_values); } else if (isPropagateQuantBinaryOp(n)) { @@ -1520,7 +1539,8 @@ InsertObserversHelper::insertObserversFor( // If the node is one of the propagate quant node, e.g. // aten::cat, we should observe its output only // if the input of the node is observed - if (observer_opt && shouldObserve(n, block_observed_values)) { + if (observer_opt && + shouldObserve(n, block_observed_values, quant_type_)) { recordObserved( v, *observer_opt, values_to_observe, block_observed_values); } diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp index 5c6851ce4fab7..53a13b6cf183d 100644 --- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp +++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -330,6 +331,8 @@ Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) { // Create and insert quantized embedding op. Value* none = g->insertConstant(IValue()); Value* zero = g->insertConstant(IValue(0)); + bool pruned_wt = false; + auto pruned_const = g->insertConstant(pruned_wt); if (is_aten_op) { TORCH_CHECK( @@ -340,6 +343,10 @@ Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) { for (auto i = 1; i < inputs_size - 1; ++i) { qembedding_bag_inputs.push_back(embedding_bag_inputs[i]); } + // The sparse field in the float operator denotes sparse gradients. + // For inference this stands for pruned weights. We currently don't support + // pruning in graph mode API so we set the field to 0 for inference. + qembedding_bag_inputs[5] = pruned_const; } else { TORCH_CHECK( inputs_size == 11, @@ -348,16 +355,13 @@ Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) { qembedding_bag_inputs.push_back(embedding_bag_inputs[3]); // offsets qembedding_bag_inputs.push_back( embedding_bag_inputs[6]); // scale_grad_by_freq - qembedding_bag_inputs.push_back(zero); // zero - qembedding_bag_inputs.push_back(embedding_bag_inputs[8]); // sparse + qembedding_bag_inputs.push_back(zero); // mode + qembedding_bag_inputs.push_back(pruned_const); // pruned_weights qembedding_bag_inputs.push_back( embedding_bag_inputs[9]); // per_sample_weights } - if (op_name == "embedding_bag_4bit") { - // 4-bit op has an extra input compressed_indices_mapping - qembedding_bag_inputs.push_back(none); - } + qembedding_bag_inputs.push_back(none); // compressed_indices_mapping qembedding_bag_inputs.push_back(embedding_bag_inputs[inputs_size - 1]); Node* qembedding_bag = @@ -475,7 +479,7 @@ void ReplicateChooseQParamsQuantDequant(std::shared_ptr& graph) { matched_choose_qparam, matched_quantize, matched_dequantize)); } } - for (const auto nodes : nodes_to_rewrite) { + for (const auto& nodes : nodes_to_rewrite) { auto quant_node = std::get<1>(nodes); auto dequant_node = std::get<2>(nodes); // get input of quantize call. @@ -984,12 +988,18 @@ std::tuple InsertQuantDeQuantHelper:: v->debugName(), " exists."); QParamVector qparams; - c10::QScheme qscheme; + c10::QScheme qscheme = c10::kPerTensorAffine; auto observer_module = module.attr(observer_name.value()).toModule(); auto scalar_type = observer_module.attr("dtype"); - if (isPlaceholderObserver(n->input(0)) || - scalar_type == at::ScalarType::Half) { + if (isPlaceholderObserver(n->input(0))) { + // get compute_dtype for dynamic quantization + if (observer_module.hasattr("compute_dtype")) { + qparams.push_back(std::make_pair( + "_scalar_type", observer_module.attr("compute_dtype"))); + } + return std::make_tuple(qscheme, qparams); + } else if (scalar_type == at::ScalarType::Half) { return std::make_tuple(qscheme, qparams); } auto calculate_qparams = observer_module.get_method("calculate_qparams"); @@ -1113,10 +1123,11 @@ void InsertQuantDeQuantHelper::propagateQParams( "q_zero_point"); Node* dtype = insertQParam( graph, quantized_input, prim::dtype, IntType::get(), "dtype"); - quant_inputs = {original_output, - scale->output(), - zero_point->output(), - dtype->output()}; + quant_inputs = { + original_output, + scale->output(), + zero_point->output(), + dtype->output()}; } Node* quant = insertQuant( graph, quant_inputs, quant_kind, original_output->debugName() + ".quant"); diff --git a/torch/csrc/jit/passes/quantization/quantization_patterns.h b/torch/csrc/jit/passes/quantization/quantization_patterns.h index ba692d88c18c4..8248357d986e8 100644 --- a/torch/csrc/jit/passes/quantization/quantization_patterns.h +++ b/torch/csrc/jit/passes/quantization/quantization_patterns.h @@ -407,6 +407,38 @@ graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %pad %r_quant = quantized::conv3d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point) return (%r_quant) )"; + // aten::conv_transpose1d + std::string conv_transpose1d = R"( +graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation): + %a_dequant = aten::dequantize(%a_quant) + %w_quant : Tensor, %b : Tensor? = quantized::conv_transpose1d_unpack(%packed_params) + %w_dequant = aten::dequantize(%w_quant) + %r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation) + %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) + return (%r_quant) )"; + + // quantized::conv_transpose1d + std::string quantized_conv_transpose1d = R"( +graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation): + %r_quant = quantized::conv_transpose1d(%a_quant, %packed_params, %r_scale, %r_zero_point) + return (%r_quant) )"; + + // aten::conv_transpose2d + std::string conv_transpose2d = R"( +graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation): + %a_dequant = aten::dequantize(%a_quant) + %w_quant : Tensor, %b : Tensor? = quantized::conv_transpose2d_unpack(%packed_params) + %w_dequant = aten::dequantize(%w_quant) + %r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation) + %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) + return (%r_quant) )"; + + // quantized::conv_transpose1d + std::string quantized_conv_transpose2d = R"( +graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation): + %r_quant = quantized::conv_transpose2d(%a_quant, %packed_params, %r_scale, %r_zero_point) + return (%r_quant) )"; + std::string add_relu = R"( graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype): %a_dequant = aten::dequantize(%a_quant) @@ -907,6 +939,12 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype) {"quantized::conv3d", conv3d, quantized_conv3d}, {"quantized::conv3d_relu", conv3d_relu, quantized_conv3d_relu}, {"quantized::conv3d_relu", conv3d_inplace_relu, quantized_conv3d_relu}, + {"quantized::conv_transpose1d", + conv_transpose1d, + quantized_conv_transpose1d}, + {"quantized::conv_transpose2d", + conv_transpose2d, + quantized_conv_transpose2d}, {"quantized::linear", linear, quantized_linear}, {"quantized::linear_relu", linear_relu, quantized_linear_relu}, {"quantized::linear_relu", linear_inplace_relu, quantized_linear_relu}, @@ -1128,12 +1166,44 @@ graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups): %r = aten::conv3d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %dilation, %groups) return (%r) )"; + std::string conv_transpose1d_with_quant = R"( +graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation): + %w_dequant = aten::dequantize(%w_quant) + %r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation) + return (%r) )"; + + std::string conv_transpose1d_with_quant_prepack = R"( +graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation): + %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv_transpose1d_prepack(%w_quant, %b, %stride, %padding, %output_padding, %dilation, %groups) + %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv_transpose1d_unpack(%packed_params) + %w_dequant = aten::dequantize(%w_quant_unpacked) + %r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %output_padding, %groups, %dilation) + return (%r) )"; + + std::string conv_transpose2d_with_quant = R"( +graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation): + %w_dequant = aten::dequantize(%w_quant) + %r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation) + return (%r) )"; + + std::string conv_transpose2d_with_quant_prepack = R"( +graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation): + %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv_transpose2d_prepack(%w_quant, %b, %stride, %padding, %output_padding, %dilation, %groups) + %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv_transpose2d_unpack(%packed_params) + %w_dequant = aten::dequantize(%w_quant_unpacked) + %r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %output_padding, %groups, %dilation) + return (%r) )"; + return { {"conv1d_prepack_unpack", conv1d_with_quant, conv1d_with_quant_prepack}, {"conv2d_prepack_unpack", conv2d_with_quant, conv2d_with_quant_prepack}, - {"conv3d_prepack_unpack", conv3d_with_quant, conv3d_with_quant_prepack} - - }; + {"conv3d_prepack_unpack", conv3d_with_quant, conv3d_with_quant_prepack}, + {"conv_transpose1d_prepack_unpack", + conv_transpose1d_with_quant, + conv_transpose1d_with_quant_prepack}, + {"conv_transpose2d_prepack_unpack", + conv_transpose2d_with_quant, + conv_transpose2d_with_quant_prepack}}; } } // namespace jit diff --git a/torch/csrc/jit/passes/reconstruct_scopes.cpp b/torch/csrc/jit/passes/reconstruct_scopes.cpp deleted file mode 100644 index a4787cd84c05d..0000000000000 --- a/torch/csrc/jit/passes/reconstruct_scopes.cpp +++ /dev/null @@ -1,206 +0,0 @@ -#include -#include - -namespace torch { -namespace jit { - -class ReconstructScopesPass { - public: - ReconstructScopesPass(const Module& m, Graph& g, std::string p) - : root_module_(m), - graph_(g), - prefix_(std::move(p)), - class_types_are_not_unique_(false){}; - void run(); - - private: - const Module& root_module_; - Graph& graph_; - std::string prefix_; - - // This boolean indicates whether there are two submodules of the same - // class type. This issue may occur in a scripted module and make it - // difficult to exactly track module information corresponding to each - // Node* after inlining the graph. Consider the following example: - - // class A(nn.Module): - // def __init__(self): - // super(A, self).__init__() - - // def forward(self, x): - // return x + 1 - - // class B(nn.Module): - // def __init__(self): - // super(B, self).__init__() - // self.A0 = A() - // self.A1 = A() - - // def forward(self, x): - // return self.A0(x) + self.A1(x) - - // m_traced = torch.jit.trace(B(), torch.Tensor([1])) - // m_scripted = torch.jit.script(B()) - - // In m_traced, self.A0 and self.A1 have different class types, but in - // m_scripted, self.A0 and self.A1 have the same class types. Therefore, - // it is difficult to distinguish 'A0' and 'A1' in the module hierarchy - // after the graph is inlined. In this case, we add a warning to let - // users know that the debugging information may be incomplete. - bool class_types_are_not_unique_; - - std::unordered_map func_to_module_; - std::unordered_map module_names_; - - void visitBlock(Block* b, const std::string& root_scope_string); - void visitNode(Node* n, const std::string& root_scope_string); - - std::string getModuleTypeName( - const Module& module, - const std::string& prefix); - void constructFunctionToModuleMap(const Module& module); - void constructRelativeNamesForModules( - const Module& module, - const std::string& prefix); - - std::string getScopeString(const InlinedCallStackEntry& frame) const; - - void appendSourceRangeInfo( - std::string& scopeString, - const InlinedCallStackEntry& frame) const; -}; - -void ReconstructScopesPass::constructFunctionToModuleMap(const Module& module) { - for (const auto& method : module.get_methods()) { - Function* func_ptr = &method.function(); - if (!class_types_are_not_unique_ && - func_to_module_.find(func_ptr) != func_to_module_.end()) { - class_types_are_not_unique_ = true; - } - func_to_module_[func_ptr] = module._ivalue(); - } - for (const Module& m : module.children()) { - constructFunctionToModuleMap(m); - } -} - -std::string ReconstructScopesPass::getModuleTypeName( - const Module& module, - const std::string& prefix) { - std::string moduleType = module.type()->str(); - size_t lastDotIndex = moduleType.rfind('.'); - if (lastDotIndex != std::string::npos) { - moduleType = moduleType.substr(lastDotIndex + 1); - } - return prefix + "(" + moduleType + ")"; -} - -void ReconstructScopesPass::constructRelativeNamesForModules( - const Module& module, - const std::string& prefix) { - module_names_[module._ivalue()] = getModuleTypeName(module, prefix); - for (const NameModule& s : module.named_children()) { - constructRelativeNamesForModules( - s.value, module_names_[module._ivalue()] + "." + s.name); - } -} - -void ReconstructScopesPass::appendSourceRangeInfo( - std::string& scopeString, - const InlinedCallStackEntry& frame) const { - SourceRange r = frame.second; - if (r.source()) { - if (auto orig = r.source()->findSourceRangeThatGenerated(r)) { - r = *orig; - } - } - if (auto file_line_col = r.file_line_col()) { - std::string filename; - size_t line, col; - std::tie(filename, line, col) = *file_line_col; - scopeString += "<" + filename + ":" + c10::to_string(line) + ":" + - c10::to_string(col) + ">"; - } -} - -std::string ReconstructScopesPass::getScopeString( - const InlinedCallStackEntry& frame) const { - Function* f = frame.first; - if (!func_to_module_.count(f)) { - return ""; - } - auto m = func_to_module_.at(f); - if (!module_names_.count(m)) { - return ""; - } - std::string scopeString = module_names_.at(m) + "." + f->name(); - - // When class types are not unique, the module information may be - // incomplele. In this case, we add source range information, - // which can be helpful for deugging purposes. - if (class_types_are_not_unique_) { - appendSourceRangeInfo(scopeString, frame); - } - return scopeString; -} - -void ReconstructScopesPass::visitNode( - Node* n, - const std::string& root_scope_string) { - for (Block* b : n->blocks()) { - visitBlock(b, root_scope_string); - } - ScopePtr sc = c10::make_intrusive(); - if (!n->callstack()) { - sc = sc->push(Symbol::scope(root_scope_string)); - } else { - for (const auto& frame : (*n->callstack())->vec()) { - auto name = getScopeString(frame); - GRAPH_UPDATE("Adding a scope ", name, " for node ", *n); - sc = sc->push(Symbol::scope(name)); - } - } - n->setScope(sc); - GRAPH_UPDATE("Updated node: ", *n); -} - -void ReconstructScopesPass::visitBlock( - Block* b, - const std::string& root_scope_string) { - for (Node* n : b->nodes()) { - visitNode(n, root_scope_string); - } -} - -void ReconstructScopesPass::run() { - GRAPH_DUMP("Graph before reconstructing scope", &graph_); - func_to_module_.clear(); - module_names_.clear(); - - constructFunctionToModuleMap(root_module_); - constructRelativeNamesForModules(root_module_, prefix_); - - if (class_types_are_not_unique_) { - TORCH_WARN( - "It seems that the module contain two instances of the same class type.\n", - "The current debugging program has not provided support for distinguishing ", - "the two instances of the same class type.\n", - "The module debugging information may be incomplete."); - } - - std::string root_scope_string = - getModuleTypeName(root_module_, prefix_) + ".forward"; - visitBlock(graph_.block(), root_scope_string); - GRAPH_DUMP("Graph after reconstructing scope", &graph_); -} - -void ReconstructScopes( - const Module& module, - Graph& g, - const std::string& prefix = "top") { - ReconstructScopesPass p(module, g, prefix); - p.run(); -} - -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/passes/reconstruct_scopes.h b/torch/csrc/jit/passes/reconstruct_scopes.h deleted file mode 100644 index b08655cb37410..0000000000000 --- a/torch/csrc/jit/passes/reconstruct_scopes.h +++ /dev/null @@ -1,37 +0,0 @@ -/** \brief A pass to reconstruct scopes of nodes from their inline callstacks. - * - * The pass takes the root module and a graph and for every graph node with - * non-empty inline call-stack it computes the scope from this callstack. - * - * Callstack can be thought of as a stack of pointers to Function, and Function - * in a general case may not be a part of any module. That's why this pass - * requires a root module to be passed in - we can traverse all methods of the - * module and its submodules and then recognize these methods in callstacks. - * - * Scope can be thought of as a stack of strings, so we basically converting a - * pointer to Function to a string, or in other words trying to find a name for - * a function in this module hierarchy. - * - * The produced scopes look like: - * top.submod1.function1/top.submod1.subsubmod1.function2 - * - * 'top' is the name we use for the root module itself, and it can be customized - * with an optional third argument of the pass. - * - * The pass would not change anything if inlining has not been run on the graph. - */ -#pragma once - -#include -#include - -namespace torch { -namespace jit { - -TORCH_API void ReconstructScopes( - const Module& module, - Graph& g, - const std::string& prefix); - -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/passes/remove_inplace_ops.cpp b/torch/csrc/jit/passes/remove_inplace_ops.cpp index 5d7ee3a2c5cc3..fb4445371c74f 100644 --- a/torch/csrc/jit/passes/remove_inplace_ops.cpp +++ b/torch/csrc/jit/passes/remove_inplace_ops.cpp @@ -8,6 +8,7 @@ static const std::unordered_map inPlaceToOutOfPlace = { {aten::sub_, aten::sub}, {aten::div_, aten::div}, {aten::mul_, aten::mul}, + {aten::masked_fill_, aten::masked_fill}, {aten::zero_, aten::zeros_like}, {aten::fill_, aten::full_like}}; diff --git a/torch/csrc/jit/passes/remove_mutation.h b/torch/csrc/jit/passes/remove_mutation.h index 6e04801f41fd7..5a53a2aae169e 100644 --- a/torch/csrc/jit/passes/remove_mutation.h +++ b/torch/csrc/jit/passes/remove_mutation.h @@ -10,8 +10,8 @@ namespace torch { namespace jit { struct MutationRemover { - MutationRemover(const std::shared_ptr& graph) - : aliasDb_(nullptr), graph_(graph) { + MutationRemover(std::shared_ptr graph) + : aliasDb_(nullptr), graph_(std::move(graph)) { aliasDb_ = torch::make_unique(graph_); } diff --git a/torch/csrc/jit/passes/requires_grad_analysis.cpp b/torch/csrc/jit/passes/requires_grad_analysis.cpp index deb7837e27d12..12c1b7b4658e0 100644 --- a/torch/csrc/jit/passes/requires_grad_analysis.cpp +++ b/torch/csrc/jit/passes/requires_grad_analysis.cpp @@ -1,4 +1,5 @@ #include + #include #include #include diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index cf9a93b7f5211..bdb283c777814 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -67,8 +67,9 @@ bool containsTensorType(const TypePtr& t) { class ShapePropagator { public: - explicit ShapePropagator(std::shared_ptr graph) : aliasDb_(graph) { - collectResizeSet(std::move(graph)->block()); + explicit ShapePropagator(const std::shared_ptr& graph) + : aliasDb_(graph) { + collectResizeSet(graph->block()); } void PropagateShapeOnBlock(Block* block, bool insert_expands = true) { @@ -882,7 +883,7 @@ class ShapePropagator { "aten::trunc(Tensor self) -> Tensor", "aten::rot90(Tensor self, int k, int[] dims) -> Tensor", "aten::narrow(Tensor self, int dim, int start, int length) -> Tensor", - "aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor", + "aten::slice(Tensor self, int dim, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor", "aten::alias(Tensor self) -> Tensor", }, [](Node* node) -> type_vec_t { @@ -1208,6 +1209,7 @@ class ShapePropagator { "aten::max(Tensor self) -> Tensor", "aten::min(Tensor self) -> Tensor", "aten::median(Tensor self) -> Tensor", + "aten::nanmedian(Tensor self) -> Tensor", "aten::norm(Tensor self, Scalar p) -> Tensor", "aten::std(Tensor self, bool unbiased) -> Tensor", "aten::trace(Tensor self) -> Tensor", @@ -1268,9 +1270,10 @@ class ShapePropagator { type->withScalarType(maybe_dtype_option->toScalarType())}; } if (type->scalarType()) { - return {at::isFloatingType(*type->scalarType()) - ? type - : type->withScalarType(at::kLong)}; + return { + at::isFloatingType(*type->scalarType()) + ? type + : type->withScalarType(at::kLong)}; } else { return {type}; } @@ -1354,6 +1357,7 @@ class ShapePropagator { "aten::max(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", "aten::min(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", "aten::median(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", + "aten::nanmedian(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", "aten::mode(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", }, [](Node* node) -> type_vec_t { diff --git a/torch/csrc/jit/passes/specialize_autogradzero.cpp b/torch/csrc/jit/passes/specialize_autogradzero.cpp index 2fc95ae723397..6e660a32bd199 100644 --- a/torch/csrc/jit/passes/specialize_autogradzero.cpp +++ b/torch/csrc/jit/passes/specialize_autogradzero.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -358,7 +359,7 @@ struct AutogradZeroSpecializer { // AutogradAdds when possible. Outputs of other nodes are conservatively // marked Unknown and not optimized. void specializeAutogradZero(std::shared_ptr g) { - AutogradZeroSpecializer azs(g); + AutogradZeroSpecializer azs(std::move(g)); azs.run(); } diff --git a/torch/csrc/jit/passes/subgraph_rewrite.cpp b/torch/csrc/jit/passes/subgraph_rewrite.cpp index 89f03cb50b584..45d07fe75fed0 100644 --- a/torch/csrc/jit/passes/subgraph_rewrite.cpp +++ b/torch/csrc/jit/passes/subgraph_rewrite.cpp @@ -1,4 +1,5 @@ #include + #include #include @@ -72,22 +73,44 @@ void SubgraphRewriter::rewriteSinglePatternOnGraph( } // Figure out what values we need to use as inputs and outputs for the - // replacement subgraph. These would be inputs and outputs of the subgraph - // we matched. + // replacement subgraph and where the replacement subgraph needs to be + // inserted. + Node* ins_point = nullptr; std::vector inputs, outputs; for (Value* v : pattern_graph.inputs()) { - inputs.push_back(match.values_map.at(v)); + Value* input = match.values_map.at(v); + if (!ins_point || ins_point->isBefore(input->node())) { + ins_point = input->node(); + } + inputs.push_back(input); } + AT_ASSERT(ins_point); + + // Check that the insertion point we've chosen precedes all the uses of the + // outputs - otherwise the replacement is incorrect and we have to skip it. + bool ins_point_before_uses = true; for (Value* v : pattern_graph.outputs()) { + Value* output = match.values_map.at(v); outputs.push_back(match.values_map.at(v)); + + for (const Use& u : output->uses()) { + if (u.user->isBefore(ins_point)) { + ins_point_before_uses = false; + break; + } + } + } + + if (!ins_point_before_uses) { + continue; } - // Insert a clone of replacement subgraph after the matched subgraph. + // Insert a clone of replacement subgraph. // `inputs` vector holds values that we would use as incoming values to the // new subgraph, and we will get `new_outputs` vector containing values // produced by this new subgraph - we will then rewrite old outputs with the // new ones. - WithInsertPoint insert_point(match.anchor); + WithInsertPoint insert_point(ins_point->next()); std::vector new_outputs = insertGraph(*graph, replacement_graph, inputs); diff --git a/torch/csrc/jit/passes/subgraph_rewrite.h b/torch/csrc/jit/passes/subgraph_rewrite.h index a108d245195d9..ee84f58ad087f 100644 --- a/torch/csrc/jit/passes/subgraph_rewrite.h +++ b/torch/csrc/jit/passes/subgraph_rewrite.h @@ -4,8 +4,8 @@ * the corresponding subgraphs with another subgraph. A special case of such * rewrites is fusion, where the new subgraph consists of just a single node. * - * There is a default set of most-common patterns that everyone could use, or - * alternatively an arbitrary pattern can be registered. + * There is a default set of the most common patterns that everyone could use. + * Alternatively, an arbitrary pattern can be registered. */ #pragma once @@ -37,13 +37,13 @@ TORCH_API Module PatternBasedRewrite(const Module& module); /** A class implementing API for pattern-based subgraph rewrites. * * To perform pattern-based subgraph rewrites on a module using this API, one - * needs to crete an object of such class, register rewrite patterns and run the - * transformation pass (`runOnModule`). + * needs to create an object of such class, register rewrite patterns and run + * the transformation pass (`runOnModule`). * * To use standard patterns, one could use `RegisterDefaultPatterns`. * - * To enable rewrites of custom patterns, they must be registered with - * `RegisterRewritePattern`. + * To enable rewrites of custom patterns, the custom patterns must be registered + * with `RegisterRewritePattern`. */ class TORCH_API SubgraphRewriter { public: @@ -51,12 +51,12 @@ class TORCH_API SubgraphRewriter { Module runOnModule(const Module& module); // Run pattern-based subgraph rewrite pass on the graph (used in testing). - // filter is a function that does extra filtering on the match, if it returns - // false for a given Match, we'll skip the match - // filter function takes a `Match` and a value map from parsing the pattern - // graph since we need to do extra filtering on the matched result but we need - // to refer to the values in the matched result through the values in pattern - // graph. + // `filter` is a function that does extra filtering on the match. If it + // returns false for a given Match, we'll skip the Match. The filter + // function's arguments consist of a Match and a value map from parsing the + // pattern graph. Both the Match and the value map are necessary because we + // need to 1) do extra filtering on the matched result as well as 2) refer to + // the values in the matched result through the values in the pattern graph. void runOnGraph( std::shared_ptr& graph, const std::vector& filters); @@ -77,7 +77,7 @@ class TORCH_API SubgraphRewriter { * * The method takes two parameters specifying the pattern: * \p PATTERN - IR string representing the pattern subgraph. - * \p REPLACEMENT - IR stringn representing the replacement subgraph. + * \p REPLACEMENT - IR string representing the replacement subgraph. * * See examples of pattern registering in `RegisterDefaultPatterns`. */ @@ -99,8 +99,8 @@ class TORCH_API SubgraphRewriter { /** Rewrite pattern descriptor. * - * This structure is used in implementation of `SubgraphRewriter` and not - * supposed to be used externally. + * This structure is used in the implementation of `SubgraphRewriter` and + * is not supposed to be used externally. */ struct RewritePatternDescr { std::string pattern; diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index 3782c2af4f339..8bbc0b5d9a0e7 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -1,9 +1,13 @@ #include + #include +#include #include #include #include +#include #include +#include #include #include #include @@ -14,6 +18,11 @@ #include #include +// NOLINTNEXTLINE +C10_DEFINE_bool( + torch_jit_disable_cat, + false, + "disable aten::cat in TE fusion groups"); namespace torch { namespace jit { @@ -29,6 +38,23 @@ bool isSupportedForBlock(Node* node) { } } +bool usedOnlyInSize(Value* v) { + const auto& uses = v->uses(); + return std::all_of(uses.begin(), uses.end(), [](const Use& u) { + return u.user->matches("aten::size(Tensor self) -> int[]"); + }); +} + +Value* broadcastSizes(at::ArrayRef sizes, AliasDb* db) { + AT_ASSERT(!sizes.empty()); + Graph* graph = sizes[0]->owningGraph(); + Node* broadcast_n = + graph->insertNode(graph->create(prim::BroadcastSizes, sizes)); + broadcast_n->output()->setType(ListType::ofInts()); + db->createValue(broadcast_n->output()); + return broadcast_n->output(); +} + namespace tensorexpr { bool isSupported(Node* node) { // For Block codegen we allow limited ops. @@ -62,17 +88,25 @@ bool isSupported(Node* node) { "aten::lt.Tensor(Tensor self, Tensor other) -> Tensor", "aten::lt.Scalar(Tensor self, Scalar other) -> Tensor", "aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor", - "aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor", - // TODO : do we support pow.Scalar ? - "aten::pow.Scalar(Scalar self, Tensor exponent) -> Tensor", + // TODO: uncomment when we properly support pow + // "aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor", + // "aten::pow.Scalar(Scalar self, Tensor exponent) -> Tensor", // TODO: support clamp_min, clamp_max "aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor", "aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor", "aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", + "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor", + "aten::to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor", + "aten::to.dtype_layout(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None" + ", bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor", + "aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)", + "aten::to.prim_dtype(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)", + "aten::isnan(Tensor self) -> Tensor", + "aten::lgamma(Tensor self) -> Tensor", "aten::log10(Tensor self) -> Tensor", "aten::log(Tensor self) -> Tensor", "aten::log2(Tensor self) -> Tensor", - // TODO: log1p + "aten::log1p(Tensor self) -> Tensor", "aten::exp(Tensor self) -> Tensor", "aten::erf(Tensor self) -> Tensor", "aten::erfc(Tensor self) -> Tensor", @@ -96,6 +130,8 @@ bool isSupported(Node* node) { "aten::round(Tensor self) -> Tensor", "aten::trunc(Tensor self) -> Tensor", "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor", + // "aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor", + // "aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor", TODO: requires 0-dim Tensor "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor", "aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor", "aten::cat(Tensor[] tensors, int dim=0) -> Tensor", @@ -123,16 +159,20 @@ bool isSupported(Node* node) { "aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor", "aten::where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor", - "aten::where(Tensor condition) -> Tensor[]", // TODO: enable other min/max variants, operators that can be both // elementwise or reductions: "aten::min.other(Tensor self, Tensor other) -> Tensor", "aten::max.other(Tensor self, Tensor other) -> Tensor", // TODO: enable slice, shape inference is not implemented for this op yet }; + static const OperatorSet cuda_only_operator_set{ + "aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor", + }; static const OperatorSet supported_reduction_set{ "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", + "aten::softmax.int(Tensor self, int dim , ScalarType? dtype=None) -> Tensor", + "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", }; // clang-format on @@ -150,6 +190,17 @@ bool isSupported(Node* node) { } } + // Operator is only supported on CUDA. + if (node->isMemberOf(cuda_only_operator_set)) { + auto device = tensorexpr::pickDeviceType(node->inputs()); + if (!device) { + device = tensorexpr::pickDeviceType(node->outputs()); + } + if (!device || device->is_cpu()) { + return false; + } + } + // non-const dtype / device for (auto arg_name : {"dtype", "device"}) { if (auto index = node->schema().argumentIndexWithName(arg_name)) { @@ -159,6 +210,10 @@ bool isSupported(Node* node) { } } + if (FLAGS_torch_jit_disable_cat && node->kind() == aten::cat) { + return false; + } + return true; } @@ -175,7 +230,7 @@ bool isSupported(Node* node) { } // namespace tensorexpr -static bool texpr_fuser_enabled_ = false; +static bool texpr_fuser_enabled_ = true; void setTensorExprFuserEnabled(bool val) { texpr_fuser_enabled_ = val; @@ -202,36 +257,35 @@ bool texprReductionsEnabled() { return texpr_reductions_enabled; } -// TODO: if a value has differently typed uses, temporarrily insert a node -// specializing the type for each use and later remove, instead of bailing -bool profiledWithDifferentTypes(Value* v) { - std::vector types; - for (const auto& use : v->uses()) { - if (use.user->kind() == prim::profile) { - types.push_back(use.user->ty(attr::profiled_type)); - } - } - for (size_t i = 1; i < types.size(); ++i) { - if (types.at(i - 1) != types.at(i)) { - return true; - } - } - return false; -} - void removeProfileNodesAndSpecializeTypes(Block* b) { for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) { if (it->kind() == prim::profile) { GRAPH_DEBUG("Removing prim::profile: %", it->output()->debugName()); it->output()->replaceAllUsesWith(it->input()); - if (!profiledWithDifferentTypes(it->input())) { - it->input()->setType(it->ty(attr::profiled_type)); - } else { - GRAPH_DEBUG( - "Ignoring value with differently typed profiles :%", - it->output()->debugName()); + auto profiled_type = it->ty(attr::profiled_type)->expect(); + + // A value can be profiled with differently typed uses. + // This can occur from: + // - having a use which is not executed, so the type will be + // TensorType::get() + // - control-flow that depends on tensor type: + // if x.size() == 2 op(x) else op(x) + // - mutation of the value on a field represented in the tensor type + // op(x); x.resize_([...]); op(x) + + // The most common case today with num_profiles = 1 is from the first + // case. Here we can just ignore non-profiled uses, and choose any of the + // profiled uses. Because we guard all tensor types in the runtime, even + // if we set a Value to have a profiled type from one use and then execute + // a use with a different profiled type, we will still be correct. + // In the future we could consider unifying the types of uses, or adding a + // type refinement node so uses can have the correct corresponding type. + if (profiled_type == TensorType::get()) { + continue; } + it->input()->setType(it->ty(attr::profiled_type)); it.destroyCurrent(); + } else { for (Block* ib : it->blocks()) { removeProfileNodesAndSpecializeTypes(ib); @@ -241,7 +295,9 @@ void removeProfileNodesAndSpecializeTypes(Block* b) { } void RemoveProfileNodesAndSpecializeTypes(std::shared_ptr& graph) { + GRAPH_DEBUG("Before removeProfileNodesAndSpecializeTypes:\n", *graph); removeProfileNodesAndSpecializeTypes(graph->block()); + GRAPH_DEBUG("After removeProfileNodesAndSpecializeTypes:\n", *graph); } void removeTensorTypeSpecialization(Value* v) { @@ -277,6 +333,103 @@ void RemoveTensorTypeSpecializations(std::shared_ptr& graph) { removeTensorTypeSpecializations(graph->block()); } +void insertTypeGuard( + Node* guarded_node, + tensor_type_converter_t type_converter, + Symbol kind) { + GRAPH_DEBUG("Inserting a typecheck guard for a node", *guarded_node); + auto subgraph = SubgraphUtils::getSubgraph(guarded_node); + + // Fixup types of the subgraph inputs + std::vector inputs_to_check; + std::vector guard_types; + for (Value* input : guarded_node->inputs()) { + // We only check inputs of the guarded nodes and expect user to infer + // intermediates and outputs shapes + if (!input->type()->cast()) { + continue; + } + + // fusion outputs are already guarded + if (input->node()->kind() == prim::Constant || + input->node()->kind() == prim::FusionGroup) { + continue; + } + inputs_to_check.push_back(input); + guard_types.push_back(type_converter(input->type()->expect())); + } + if (!inputs_to_check.size()) { + return; + } + + // Add prim::TypeCheck node + // + // TypeCheck nodes look like the following: + // %out1 : Float(2, 3), %out2 : Int(10, 30), %types_match : bool = + // prim::TypeCheck(%inp1 : Tensor, %inp2 : Tensor) + // + // They have N inputs whose types we are going to check and N+1 outputs. The + // first N outputs specify expected types and N+1-th output holds the result + // of the check (bool). + Node* typecheck_node = + guarded_node->owningGraph() + ->create(kind, inputs_to_check, inputs_to_check.size() + 1) + ->insertBefore(guarded_node); + typecheck_node->tys_(attr::types, guard_types); + Value* typecheck_result = typecheck_node->output(inputs_to_check.size()); + + std::unordered_map typechecked_inputs; + for (size_t i = 0; i < typecheck_node->inputs().size(); ++i) { + typechecked_inputs[typecheck_node->input(i)] = typecheck_node->output(i); + } + + // Fixup types of the typecheck node outputs, which are used by the op in + // execution + typecheck_node->output(inputs_to_check.size())->setType(BoolType::get()); + for (size_t i = 0; i < typecheck_node->inputs().size(); ++i) { + typecheck_node->output(i)->setType(typecheck_node->input(i)->type()); + } + + // Insert if + auto versioning_if = + guarded_node->owningGraph() + ->create(prim::If, {typecheck_result}, guarded_node->outputs().size()) + ->insertAfter(typecheck_node); + for (size_t idx = 0; idx < guarded_node->outputs().size(); ++idx) { + versioning_if->output(idx)->setType(guarded_node->output(idx)->type()); + guarded_node->output(idx)->replaceAllUsesWith(versioning_if->output(idx)); + } + auto true_block = versioning_if->addBlock(); + auto false_block = versioning_if->addBlock(); + + // Fill in the false block. It should contain the unoptimized + // copy of the fused subgraph. + WithInsertPoint guard(false_block->return_node()); + const auto subgraph_outputs = insertGraph( + *guarded_node->owningGraph(), *subgraph, guarded_node->inputs()); + for (Value* output : subgraph_outputs) { + false_block->registerOutput(output); + } + + // types get copied to the fallback graph, so remove specializations before + // replacing + removeTensorTypeSpecializations(false_block); + replaceBlockWithFallbackGraph(false_block, guarded_node->inputs()); + + // Fill in the true block. It has all inputs type-checked and its + // body should be the fusion group node. + guarded_node->moveBefore(true_block->return_node()); + for (size_t idx = 0; idx < guarded_node->inputs().size(); ++idx) { + if (typechecked_inputs.count(guarded_node->input(idx))) { + guarded_node->replaceInput( + idx, typechecked_inputs.at(guarded_node->input(idx))); + } + } + for (Value* output : guarded_node->outputs()) { + true_block->registerOutput(output); + } +} + class TensorExprFuser { public: TensorExprFuser( @@ -287,6 +440,132 @@ class TensorExprFuser { min_group_size_(min_group_size), disable_shape_checks_(disable_shape_checks) {} + // Builds up expressions that compute shapes of all intermediates (and + // outputs) of the fusion group, based on the sizes of inputs. You should run + // DCE to remove those that you end up not using. + std::unordered_map buildShapeExpressions(Node* fusion_group) { + GRAPH_DUMP("buildShapeExpressions for ", fusion_group->g(attr::Subgraph)); + WithInsertPoint insert_guard{fusion_group->next()}; + std::unordered_map shape_of; + + Graph* graph = fusion_group->owningGraph(); + auto subgraph = fusion_group->g(attr::Subgraph); + + auto inputs = fusion_group->inputs(); + auto sinputs = subgraph->inputs(); + AT_ASSERT(inputs.size() == sinputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + if (inputs[i]->type()->isSubtypeOf(TensorType::get())) { + Value* soutput = graph->insert(aten::size, {inputs[i]}); + aliasDb_->createValue(soutput); + GRAPH_DEBUG( + "Adding a mapping for %", + sinputs[i]->debugName(), + " ", + getHeader(soutput->node())); + shape_of[sinputs[i]] = soutput; + } + } + + // When we have a guarantee that an output won't be removed, because it's + // used in expressions that don't involve size checks, we can use its size + // instead of computing a long chain of broadcasts, starting from the + // beginning of the kernel. + auto outputs = fusion_group->outputs(); + auto soutputs = subgraph->outputs(); + AT_ASSERT(outputs.size() == soutputs.size()); + for (size_t i = 0; i < outputs.size(); ++i) { + if (usedOnlyInSize(outputs[i])) + continue; + Value* soutput = graph->insert(aten::size, {outputs[i]}); + aliasDb_->createValue(soutput); + shape_of[soutputs[i]] = soutput; + } + + for (Node* n : subgraph->nodes()) { + // XXX: Use of shape_of.emplace is crucial to the output shape + // optimization! + if (n->kind() == aten::cat) { + // This is a bit more involved, because we have to account for the case + // when inputs have different shapes, but fortunately those tensors are + // always outputs, and so we can simply avoid replacing their queries, + // because it won't help us. + continue; + } + if (n->kind() == prim::Constant) { + continue; + } + if (n->kind() == prim::ConstantChunk) { + Node* sizes_node = graph->insertNode( + graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2)); + sizes_node->i_(attr::dim, n->i(attr::dim)); + sizes_node->i_(attr::chunks, n->i(attr::chunks)); + for (Value* output : sizes_node->outputs()) { + aliasDb_->createValue(output); + } + Value* regular_size = sizes_node->outputs().at(0); + Value* last_size = sizes_node->outputs().at(1); + regular_size->setType(ListType::ofInts()); + last_size->setType(ListType::ofInts()); + auto outputs = n->outputs(); + for (Value* o : outputs.slice(0, outputs.size() - 1)) { + shape_of.emplace(o, regular_size); + } + shape_of.emplace(outputs.at(outputs.size() - 1), last_size); + continue; + } + auto tensor_inputs = filter(n->inputs(), [](Value* v) { + return v->type()->isSubtypeOf(TensorType::get()); + }); + GRAPH_DEBUG("Building sizes for ", getHeader(n)); + bool all_inputs_have_sizes = true; + auto shapes = fmap(tensor_inputs, [&](Value* v) { + GRAPH_DEBUG("Getting aten::size for %", v->debugName()); + all_inputs_have_sizes &= shape_of.count(v); + return shape_of.count(v) != 0 ? shape_of.at(v) : nullptr; + }); + + if (!all_inputs_have_sizes) { + GRAPH_DEBUG( + "Not all tensor arguments have sizes available to compute the broadcasted size", + getHeader(n)); + continue; + } + shape_of.emplace( + n->output(), + shapes.size() == 1 ? shapes[0] + : broadcastSizes(shapes, aliasDb_.get())); + } + return shape_of; + } + + void removeOutputsUsedOnlyInSize(Node* fusion_group) { + if (fusion_group->kind() != prim::TensorExprGroup) + return; + auto subgraph = fusion_group->g(attr::Subgraph); + + auto shape_of = buildShapeExpressions(fusion_group); + auto outputs = fusion_group->outputs().vec(); + auto soutputs = subgraph->outputs().vec(); + // XXX: Iterating in this order is not only good for performance reasons! + // It is also crucial for correctness (i has to reflect the current true + // index of outputs[i])! + for (int64_t i = static_cast(outputs.size()) - 1; i >= 0; --i) { + auto output = outputs[i]; + auto soutput = soutputs[i]; + if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) { + auto uses = output->uses(); + for (Use u : uses) { + AT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]")); + u.user->output()->replaceAllUsesWith(shape_of.at(soutput)); + u.user->destroy(); + } + fusion_group->eraseOutput(i); + subgraph->eraseOutput(i); + } + } + } + void run() { aliasDb_ = torch::make_unique(graph_); RemoveRedundantProfiles(graph_); @@ -298,7 +577,7 @@ class TensorExprFuser { // fusion is done. inlineSmallFusionGroups(graph_->block()); GRAPH_DUMP("After inlining small fusion groups: ", graph_); - guardFusionGroups(graph_->block()); + prepareFusionGroupAndGuardOutputs(graph_->block()); GRAPH_DUMP("After guarding fusion groups: ", graph_); removeTensorTypeSpecializations(graph_->block()); GRAPH_DUMP("After removing tensor type specializations: ", graph_); @@ -361,6 +640,17 @@ class TensorExprFuser { } } + // No Ops in eager shouldn't be outputs of Fusion Groups because it + // will degrade perf and change aliasing relationships + static bool unexecutedEagerOp(Node* n) { + if (n->kind() != aten::to) { + return false; + } + + return *n->input(0)->type()->expect() == + *n->output()->type()->expect(); + } + std::pair scanNode(Node* n) { GRAPH_DEBUG("Considering node:", *n) @@ -371,7 +661,7 @@ class TensorExprFuser { // fusion group from - skip them. if (n->kind() == prim::ListConstruct || n->kind() == aten::slice || n->kind() == aten::unsqueeze || n->kind() == prim::ConstantChunk || - n->kind() == prim::Constant) { + n->kind() == prim::Constant || unexecutedEagerOp(n)) { return std::make_pair(++n->reverseIterator(), false); } return createFusionGroup(n); @@ -462,6 +752,8 @@ class TensorExprFuser { SubgraphUtils::unmergeSubgraph(n); return true; } + // Cleanup the subgraph from duplicated constants while we're at it. + ConstantPooling(subgraph); return false; } @@ -513,16 +805,27 @@ class TensorExprFuser { return fusion_group; } + bool shapeIsKnown(Value* v) { + if (v->type()->cast()) { + if (!v->isCompleteTensor()) { + return false; + } + if (*v->type()->cast()->dim() == 0) { + return false; + } + } + return true; + } bool allShapesAreKnown(Node* node) { // TODO: Relax the checks to support dynamic shapes for (Value* input : node->inputs()) { - if (input->type()->cast()) { - if (!input->isCompleteTensor()) { - return false; - } - if (*input->type()->cast()->dim() == 0) { - return false; - } + if (!shapeIsKnown(input)) { + return false; + } + } + for (Value* output : node->outputs()) { + if (!shapeIsKnown(output)) { + return false; } } return true; @@ -554,6 +857,91 @@ class TensorExprFuser { return true; } + bool typesAreSupported(Node* node) { + // clang-format off + // breaks up the schema strings so they are no longer discoverable with ctrl-F + static const OperatorSet float_only_operator_set{ + "aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor", + "aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor", + "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor", + "aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor", + }; + static const OperatorSet int_only_operator_set{ + "aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor", + "aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor", + "aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor", + "aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor", + }; + // clang-format on + + for (const Value* v : node->inputs()) { + if (auto const& tt = v->type()->cast()) { + auto const& st = tt->scalarType(); + + // All tensors must be typed. + if (!st) { + return false; + } + + // Byte tensors introduce too many corner cases in type promotion. + // Better not to try to handle them. + if (*st == c10::ScalarType::Byte) { + return false; + } + + // These operators only support floats, because integer divisors need to + // raise ZeroDivisionError. + if (node->isMemberOf(float_only_operator_set) && !isFloatingType(*st)) { + return false; + } + + // These operators have complicated casting rules for floats. + if (node->isMemberOf(int_only_operator_set) && isFloatingType(*st)) { + return false; + } + } else if (node->isMemberOf(float_only_operator_set)) { + // Check scalar operands of float-only ops. + if (!v->type()->cast()) { + return false; + } + } else if (node->isMemberOf(int_only_operator_set)) { + if (!v->type()->cast()) { + return false; + } + } + } + if (node->kind() == aten::to) { + // only support same-device conversion + auto device = tensorexpr::pickDeviceType(node->inputs()); + auto output_device = tensorexpr::pickDeviceType(node->outputs()); + if (!device || !output_device || *device != *output_device) { + return false; + } + // non_blocking only applies in cross-device conversion, which we bail on + // copy arg only applies if op is a no-op, which we dont start fusion + // group from memory format is separately handled in NNC output + + // all non-Tensor arguments must be constant + for (size_t i = 1; i < node->inputs().size(); i++) { + if (node->inputs().at(i)->node()->kind() != prim::Constant) { + return false; + } + } + // cant support non-constant pin_memory or pin_memory = True + if (auto maybe_index = + node->schema().argumentIndexWithName("pin_memory")) { + int index = *maybe_index; + auto inp = node->input(index); + if (inp->type() != NoneType::get() && + constant_as(inp).value_or(true)) { + return false; + } + } + } + + return true; + } + #define REQ(cond) \ if (!(cond)) { \ GRAPH_DEBUG("Failed cond " #cond "\n"); \ @@ -561,17 +949,10 @@ class TensorExprFuser { } bool canHandle(Node* node) { - REQ(node->kind() != prim::Constant); REQ(disable_shape_checks_ || allShapesAreKnown(node)); REQ(isFusableOnDevice(node)); - // Don't include nodes whose inputs are tensor constants - we cannot handle - // them at the moment. - // TODO: actually support tensor constants and remove this. for (Value* input : node->inputs()) { - if (input->node()->kind() == prim::Constant) { - REQ(!input->type()->cast()) - } if (auto const& tt = input->type()->cast()) { auto st = tt->scalarType(); if (!st) { @@ -595,6 +976,14 @@ class TensorExprFuser { } REQ(tensorexpr::isSupported(node)); + REQ(typesAreSupported(node)); + + // A hook to optimizations limitter to allow bisecting the pass + auto allowed = JIT_OPT_LIMIT(); + if (!allowed) { + return false; + } + return true; } @@ -679,111 +1068,44 @@ class TensorExprFuser { } #undef REQ - void guardFusionGroup(Node* fusion_group) { - GRAPH_DEBUG("Inserting a typecheck guard for a node", *fusion_group); + // TODO: support constant tensors instead of setting them as input + void liftTensorConstantsFromFusionGroups(Node* fusion_group) { auto subgraph = SubgraphUtils::getSubgraph(fusion_group); - - // Fixup types of the subgraph inputs - std::vector inputs_to_check; - for (Value* input : fusion_group->inputs()) { - // We only check inputs of the fusion group and expect NNC to infer - // intermediates and outputs shapes - if (!input->type()->cast()) { - continue; + WithInsertPoint guard(fusion_group); + for (auto it = subgraph->block()->nodes().begin(); + it != subgraph->block()->nodes().end(); + ++it) { + auto n = *it; + if (n->kind() == prim::Constant && + n->output()->type()->cast()) { + auto constant = + fusion_group->owningGraph()->insertConstant(*toIValue(n->output())); + fusion_group->addInput(constant); + auto inputToGraph = subgraph->addInput(); + inputToGraph->setType(n->output()->type()); + n->output()->replaceAllUsesWith(inputToGraph); + it.destroyCurrent(); } - - // fusion outputs are already guarded - if (input->node()->kind() == prim::Constant || - input->node()->kind() == prim::FusionGroup) { - continue; - } - inputs_to_check.push_back(input); - } - if (!inputs_to_check.size()) { - return; } + } - // Add prim::TypeCheck node - // - // TypeCheck nodes look like the following: - // %out1 : Float(2, 3), %out2 : Int(10, 30), %types_match : bool = - // prim::TypeCheck(%inp1 : Tensor, %inp2 : Tensor) - // - // They have N inputs whose types we are going to check and N+1 outputs. The - // first N outputs specify expected types and N+1-th output holds the result - // of the check (bool). - Node* typecheck_node = - fusion_group->owningGraph() - ->create( - prim::TypeCheck, inputs_to_check, inputs_to_check.size() + 1) - ->insertBefore(fusion_group); - Value* typecheck_result = typecheck_node->output(inputs_to_check.size()); - - std::unordered_map typechecked_inputs; - for (size_t i = 0; i < typecheck_node->inputs().size(); ++i) { - typechecked_inputs[typecheck_node->input(i)] = typecheck_node->output(i); - } - - // Fixup types of the typecheck node outputs, which are used by the op in - // execution - typecheck_node->output(inputs_to_check.size())->setType(BoolType::get()); - for (size_t i = 0; i < typecheck_node->inputs().size(); ++i) { - typecheck_node->output(i)->setType(typecheck_node->input(i)->type()); - } - - // Insert if - auto versioning_if = - fusion_group->owningGraph() - ->create( - prim::If, {typecheck_result}, fusion_group->outputs().size()) - ->insertAfter(typecheck_node); - for (size_t idx = 0; idx < fusion_group->outputs().size(); ++idx) { - versioning_if->output(idx)->setType(fusion_group->output(idx)->type()); - fusion_group->output(idx)->replaceAllUsesWith(versioning_if->output(idx)); - } - auto true_block = versioning_if->addBlock(); - auto false_block = versioning_if->addBlock(); - - // Fill in the false block. It should contain the unoptimized - // copy of the fused subgraph. - WithInsertPoint guard(false_block->return_node()); - const auto subgraph_outputs = insertGraph( - *fusion_group->owningGraph(), *subgraph, fusion_group->inputs()); - for (Value* output : subgraph_outputs) { - false_block->registerOutput(output); - } - - // types get copied to the fallback graph, so remove specializations before - // replacing - removeTensorTypeSpecializations(false_block); - replaceBlockWithFallbackGraph(false_block, fusion_group->inputs()); - - // Fill in the true block. It has all inputs type-checked and its - // body should be the fusion group node. - fusion_group->moveBefore(true_block->return_node()); - for (size_t idx = 0; idx < fusion_group->inputs().size(); ++idx) { - if (typechecked_inputs.count(fusion_group->input(idx))) { - fusion_group->replaceInput( - idx, typechecked_inputs.at(fusion_group->input(idx))); - } - } - for (Value* output : fusion_group->outputs()) { - true_block->registerOutput(output); - } - } - - void guardFusionGroups(Block* block) { + void prepareFusionGroupAndGuardOutputs(Block* block) { std::vector fusion_groups; for (Node* n : block->nodes()) { for (Block* b : n->blocks()) { - guardFusionGroups(b); + prepareFusionGroupAndGuardOutputs(b); } if (n->kind() == prim::TensorExprGroup) { fusion_groups.push_back(n); } } for (Node* fusion_group : fusion_groups) { - guardFusionGroup(fusion_group); + removeOutputsUsedOnlyInSize(fusion_group); + liftTensorConstantsFromFusionGroups(fusion_group); + insertTypeGuard( + fusion_group, + [](const TensorTypePtr& t) { return t; }, + prim::TypeCheck); } } @@ -824,16 +1146,7 @@ Operation createTensorExprOp(const Node* node) { std::make_shared(node->g(attr::Subgraph)); return [kernel](Stack* stack) { RECORD_FUNCTION("TensorExpr", std::vector()); - if (!tensorexpr::fallbackAllowed()) { - kernel->run(*stack); - return 0; - } - - try { - kernel->run(*stack); - } catch (const std::runtime_error& e) { - kernel->fallback(*stack); - } + kernel->run(*stack); return 0; }; } diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.h b/torch/csrc/jit/passes/tensorexpr_fuser.h index a99cc88ef439e..992d03a6915eb 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.h +++ b/torch/csrc/jit/passes/tensorexpr_fuser.h @@ -28,6 +28,34 @@ TORCH_API bool texprReductionsEnabled(); TORCH_API void RemoveProfileNodesAndSpecializeTypes( std::shared_ptr& graph); TORCH_API void RemoveTensorTypeSpecializations(std::shared_ptr& graph); +TORCH_API void removeTensorTypeSpecializations(Block* block); + +using tensor_type_converter_t = + c10::function_ref; + +// inserts a TypeCheck pattern +// +// around the guarded node that has a Subgraph attribute, this inserts a pattern +// +// if TypeCheck(...): +// guarded_node +// else: +// FallbackGraph(...) +// +// The TypeCheck includes the types of all Tensor inputs to the guarded_node, +// as processed by the type_converter, a lambda +// TensorTypePtr(const TensorTypePtr& t). This allows to erase irrelevant +// aspects of the type. +// +// The Fallback graph will have the same subgraph as the guarded node (with the +// expectation that the guarded_node's subgraph will then be optimized. +TORCH_API void insertTypeGuard( + Node* guarded_node, + tensor_type_converter_t type_converter, + c10::Symbol kind); + +TORCH_API bool usedOnlyInSize(Value* v); +TORCH_API Value* broadcastSizes(at::ArrayRef sizes, AliasDb* db); namespace tensorexpr { TORCH_API bool isSupported(Node* node); diff --git a/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp b/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp index 15aba5d147956..c9aca34cc64b1 100644 --- a/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp +++ b/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp @@ -1,4 +1,5 @@ #include + #include #include @@ -9,12 +10,18 @@ void UpdateDifferentiableGraphRequiresGrad( Block* block, c10::optional new_requires_grad) { for (Node* n : block->nodes()) { + for (Value* v : n->inputs()) { + auto ty = v->type()->cast(); + if (ty) { + v->setType(ty->withRequiresGrad(new_requires_grad)); + } + } if (n->kind() == prim::profile) { n->ty_( attr::profiled_type, n->ty(attr::profiled_type) - ->expect() - ->withRequiresGrad(new_requires_grad)); + ->expectRef() + .withRequiresGrad(new_requires_grad)); } for (Block* b : n->blocks()) { UpdateDifferentiableGraphRequiresGrad(b, new_requires_grad); diff --git a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp index 20e4c6874e68f..78ab4a6728425 100644 --- a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp +++ b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp @@ -1,5 +1,7 @@ #include + #include +#include #include namespace torch { @@ -61,19 +63,11 @@ Stack deepCopy(const Stack& stack) { } bool deepEquals(const IValue& lhs, const IValue& rhs) { - if (lhs.isInt() && rhs.isInt()) { - return lhs.toInt() == rhs.toInt(); - } else if (lhs.isDouble() && rhs.isDouble()) { - return lhs.toDouble() == rhs.toDouble(); - } else if (lhs.isNone() && rhs.isNone()) { - return true; - } else if (lhs.isIntList() && rhs.isIntList()) { - return lhs.toIntVector() == rhs.toIntVector(); - } else if (lhs.isTensor() && rhs.isTensor()) { + if (lhs.isTensor() && rhs.isTensor()) { return lhs.toTensor().equal(rhs.toTensor()); } - throw std::runtime_error("Deep equals not implemented for type"); + return lhs == rhs; } struct AliasAndIValue { @@ -146,6 +140,16 @@ const Node* findNodeForOp( return node; } } + + // Check for alias-ed operator names + const auto aliasOp = torch::jit::getOperatorAliasMap().find(opName); + AT_ASSERT(aliasOp != torch::jit::getOperatorAliasMap().end()); + for (const auto node : g.nodes()) { + if (node->kind() == aliasOp->second) { + return node; + } + } + AT_ASSERT(false); } diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.cpp b/torch/csrc/jit/passes/utils/subgraph_utils.cpp index 73976cb66bc84..bee1d50b75fba 100644 --- a/torch/csrc/jit/passes/utils/subgraph_utils.cpp +++ b/torch/csrc/jit/passes/utils/subgraph_utils.cpp @@ -1,4 +1,5 @@ #include + #include namespace torch { @@ -394,6 +395,38 @@ Node* createSingletonSubgraphAndUpdateAliasing( }); } +std::string truncateStrWithHash(const std::string& s, size_t maxlen) { + if (s.size() <= maxlen) { + return s; + } + std::string hash_str = c10::to_string(c10::hash{}(s)); + // If hash-string plus '_' can fit into maxlen, then truncate the original + // string correspondingly so that the final string with the hash included fits + // into maxlen. If that's not possible, at least truncate the original string + // to maxlen (and appen the hash to it). + size_t trunc_len = + (maxlen > hash_str.size() + 1) ? (maxlen - hash_str.size() - 1) : maxlen; + std::stringstream truncated; + truncated << s.substr(0, trunc_len); + truncated << "_" << hash_str; + return truncated.str(); +} + +std::string generateNameForGraph( + const std::shared_ptr& graph, + size_t maxlen, + const std::string& prefix) { + std::stringstream graph_name; + graph_name << prefix; + for (Node* node : graph->nodes()) { + if (!node->kind().is_aten()) { + continue; + } + graph_name << "_" << node->kind().toUnqualString(); + } + return truncateStrWithHash(graph_name.str(), maxlen); +} + } // namespace SubgraphUtils } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.h b/torch/csrc/jit/passes/utils/subgraph_utils.h index c0ffc3635031d..f111d495c63dd 100644 --- a/torch/csrc/jit/passes/utils/subgraph_utils.h +++ b/torch/csrc/jit/passes/utils/subgraph_utils.h @@ -68,6 +68,11 @@ TORCH_API void unmergeSubgraph( // Convenience function std::shared_ptr getSubgraph(Node* n); +TORCH_API std::string generateNameForGraph( + const std::shared_ptr& graph, + size_t maxlen = 40, + const std::string& prefix = "fused"); + } // namespace SubgraphUtils } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/vulkan_rewrite.cpp b/torch/csrc/jit/passes/vulkan_rewrite.cpp index 0b4e90f3e1aab..4e381c47dae0a 100644 --- a/torch/csrc/jit/passes/vulkan_rewrite.cpp +++ b/torch/csrc/jit/passes/vulkan_rewrite.cpp @@ -22,6 +22,51 @@ namespace jit { namespace { +void insertPrePackedLinearOp(std::shared_ptr& graph) { + // fuse decomposed linear into aten::linear + FuseLinear(graph); + + std::string linear_before_inline = R"( + graph(%linear, %input, %weight, %bias): + %r = prim::CallFunction(%linear, %input, %weight, %bias) + return (%r))"; + std::string prepacked_ops_pattern_before_inline = R"( + graph(%linear, %input, %weight, %bias): + %weight_t = aten::t(%weight) + %packed_weight_bias = vulkan_prepack::linear_prepack( + %weight_t, %bias) + %res = vulkan_prepack::linear_run(%input, %packed_weight_bias) + return (%res))"; + std::string linear_pattern = R"( + graph(%input, %weight, %bias): + %r = aten::linear(%input, %weight, %bias) + return (%r))"; + std::string prepacked_ops_pattern = R"( + graph(%input, %weight, %bias): + %weight_t = aten::t(%weight) + %packed_weight_bias = vulkan_prepack::linear_prepack( + %weight_t, %bias) + %res = vulkan_prepack::linear_run(%input, %packed_weight_bias) + return (%res))"; + + const auto filter = [](const Match& match, + const std::unordered_map& vmap) { + const auto& match_vmap = match.values_map; + const auto linear_value = match_vmap.at(vmap.at("linear")); + const auto func_name = graph_rewrite_helper::getFuncName(linear_value); + return (func_name == "linear"); + }; + + SubgraphRewriter linear_call_fn_rewriter; + linear_call_fn_rewriter.RegisterRewritePattern( + linear_before_inline, prepacked_ops_pattern_before_inline); + linear_call_fn_rewriter.runOnGraph(graph, filter); + + SubgraphRewriter linear_rewriter; + linear_rewriter.RegisterRewritePattern(linear_pattern, prepacked_ops_pattern); + linear_rewriter.runOnGraph(graph); +} + void insertPrePackedConv2dOp(std::shared_ptr& graph) { graph_rewrite_helper::replaceConvolutionWithAtenConv(graph); @@ -131,6 +176,7 @@ void fuseReluWithPackedOps(std::shared_ptr& graph) { } // namespace void vulkanInsertPrePackedOps(std::shared_ptr& graph) { + insertPrePackedLinearOp(graph); insertPrePackedConv2dOp(graph); } @@ -153,8 +199,10 @@ void vulkanFusePrePackedConvWithClamp(script::Module& module) { void vulkanFoldPrePackingOps(script::Module& m) { PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool { return ( - n->kind() == - Symbol::fromQualString("vulkan_prepack::conv2d_clamp_prepack")); + (n->kind() == + Symbol::fromQualString("vulkan_prepack::conv2d_clamp_prepack")) || + (n->kind() == + Symbol::fromQualString("vulkan_prepack::linear_prepack"))); }; PrePackingOpsFolder(m, filter_fn, "prepack_folding"); } diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp index 3ebfab1d32646..3be480068c40e 100644 --- a/torch/csrc/jit/passes/xnnpack_rewrite.cpp +++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -334,6 +335,9 @@ void fusePrePackedLinearConvWithClamp(script::Module& module) { auto graph = module.get_method("forward").graph(); fuseReluWithPackedOps(graph); fuseHardtanhWithPackedOps(graph); + + // Ignore user defined classes for later passes + ConstantPropagation(graph, true); } void FoldPrePackingOps(script::Module& m) { @@ -348,6 +352,9 @@ void FoldPrePackingOps(script::Module& m) { "prepacked::conv2d_transpose_clamp_prepack")); }; PrePackingOpsFolder(m, filter_fn, "prepack_folding"); + auto graph = m.get_method("forward").graph(); + // Folding requires a const propagation through user defined classes + ConstantPropagation(graph, false); } script::Module optimizeForMobile( @@ -390,7 +397,7 @@ script::Module optimizeForMobile( if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU)) { FuseAddRelu(cloned_module); } - + cloned_module.register_attribute("mobile_optimized", BoolType::get(), true); return cloned_module; } diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index db866704aa97a..c02399cbb7a70 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -21,6 +22,8 @@ #include #include #include +#include +#include #include #include #include @@ -29,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -36,7 +40,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -51,7 +57,6 @@ #include #include #include -#include #include #include #include @@ -81,6 +86,7 @@ #include #include #include +#include #include #include @@ -89,6 +95,7 @@ #include #include +#include #include #include @@ -139,8 +146,20 @@ void initJITBindings(PyObject* module) { .def("_jit_pass_onnx_remove_print", RemovePrintOps) .def("_jit_pass_onnx_preprocess_caffe2", PreprocessCaffe2Ops) .def("_jit_pass_onnx", ToONNX) + .def( + "_jit_pass_onnx_assign_output_shape", + [](std::shared_ptr& graph, + const std::vector& tensors, + bool onnx_shape_inference = false) { + ONNXAssignOutputShape(graph, tensors, onnx_shape_inference); + }) .def("_jit_pass_lower_all_tuples", LowerAllTuples) .def("_jit_pass_onnx_function_substitution", ONNXFunctionCallSubstitution) + .def( + "_jit_pass_onnx_fold_if", + [](std::shared_ptr& graph) { + return FoldIfNodeONNX(graph->block()); + }) .def( "_jit_pass_onnx_peephole", [](std::shared_ptr& graph, @@ -188,7 +207,17 @@ void initJITBindings(PyObject* module) { .def( "_jit_pass_onnx_prepare_inplace_ops_for_onnx", PrepareInplaceOpsForONNX) - .def("_jit_pass_onnx_node_shape_type_inference", ONNXShapeTypeInference) + .def( + "_jit_pass_onnx_node_shape_type_inference", + [](Node* n, int opset_version) { + ONNXShapeTypeInference(n, opset_version); + }) + .def( + "_jit_pass_onnx_graph_shape_type_inference", + [](std::shared_ptr& graph, int opset_version) { + ONNXShapeTypeInference(graph, opset_version); + }) + .def("_jit_pass_onnx_set_dynamic_input_shape", ONNXSetDynamicInputShape) .def("_jit_pass_fuse", FuseGraph) .def( "_jit_pass_dce", @@ -260,16 +289,26 @@ void initJITBindings(PyObject* module) { "_jit_pass_quant_fusion", [](std::shared_ptr& g) { return QuantFusion(g); }) .def("_jit_pass_fold_convbn", &FoldConvBatchNorm) + .def( + "_jit_onnx_list_model_parameters", + [](Module& module) { return list_module_parameters(module); }) .def( "_freeze_module", [](Module& module, std::vector& preservedAttrs, - bool freezeInterfaces) { - return freeze_module(module, preservedAttrs, freezeInterfaces); + bool freezeInterfaces, + bool preserveParameters) { + return freeze_module( + module, preservedAttrs, freezeInterfaces, preserveParameters); }, py::arg("module"), py::arg("preservedAttrs") = std::vector(), - py::arg("freezeInterfaces") = true) + py::arg("freezeInterfaces") = true, + py::arg("preserveParameters") = false) + .def("_jit_pass_fold_frozen_conv_bn", &FoldFrozenConvBatchnorm) + .def("_jit_pass_fold_frozen_conv_add_or_sub", &FoldFrozenConvAddOrSub) + .def("_jit_pass_fold_frozen_conv_mul_or_div", &FoldFrozenConvMulOrDiv) + .def("_jit_pass_optimize_frozen_graph", &OptimizeFrozenGraph) .def("_jit_pass_fuse_linear", &FuseLinear) .def( "_jit_pass_fuse_add_relu", @@ -314,19 +353,9 @@ void initJITBindings(PyObject* module) { subgraph_rewriter.RegisterRewritePattern(pattern, fused_node_name); subgraph_rewriter.runOnGraph(g); }) - .def( - "_jit_pass_reconstruct_scopes", - [](script::Module& module, - std::shared_ptr& g, - const std::string& prefix) { - ReconstructScopes(module, *g, prefix); - }, - py::arg("module"), - py::arg("graph"), - py::arg("prefix") = "top") .def( "_jit_pass_remove_inplace_ops", - [](std::shared_ptr g) { return RemoveInplaceOps(g); }) + [](const std::shared_ptr& g) { return RemoveInplaceOps(g); }) .def("_jit_pass_constant_pooling", ConstantPooling) .def( "_jit_pass_create_functional_graphs", @@ -356,7 +385,9 @@ void initJITBindings(PyObject* module) { .def("_jit_pass_lint", LintGraph) .def( "_jit_pass_complete_shape_analysis", - [](std::shared_ptr graph, py::tuple inputs, bool with_grad) { + [](const std::shared_ptr& graph, + const py::tuple& inputs, + bool with_grad) { ArgumentSpecCreator arg_spec_creator(*graph); Stack stack; stack.reserve(inputs.size()); // captures? @@ -378,7 +409,7 @@ void initJITBindings(PyObject* module) { }) .def( "_jit_interpret_graph", - [](std::shared_ptr& graph, py::tuple inputs) { + [](std::shared_ptr& graph, const py::tuple& inputs) { Stack stack; stack.reserve(inputs.size()); // captures? for (auto& obj : inputs) { @@ -414,11 +445,14 @@ void initJITBindings(PyObject* module) { }) .def( "_jit_pass_constant_propagation", - [](std::shared_ptr& g) { return ConstantPropagation(g); }) + [](std::shared_ptr& g) { return ConstantPropagation(g); }, + py::arg("graph")) .def("_jit_pass_erase_shape_information", EraseShapeInformation) .def( "_jit_pass_create_autodiff_subgraphs", - [](std::shared_ptr graph) { CreateAutodiffSubgraphs(graph); }) + [](const std::shared_ptr& graph) { + CreateAutodiffSubgraphs(graph); + }) #if defined(BUILDING_TESTS) && !defined(__HIP_PLATFORM_HCC__) .def( "_jit_run_cpp_tests", @@ -446,7 +480,7 @@ void initJITBindings(PyObject* module) { }) .def( "_jit_unflatten", - [](autograd::variable_list vars, python::IODescriptor& desc) { + [](const autograd::variable_list& vars, python::IODescriptor& desc) { return py::reinterpret_steal( python::unflatten(vars, desc)); }) @@ -470,13 +504,20 @@ void initJITBindings(PyObject* module) { }) .def( "_jit_check_alias_annotation", - [](std::shared_ptr g, - py::tuple args, + [](const std::shared_ptr& g, + const py::tuple& args, const std::string& unqualified_op_name) { auto stack = toTraceableStack(args); checkAliasAnnotation(g, std::move(stack), unqualified_op_name); }) .def("_jit_set_nvfuser_enabled", &RegisterCudaFuseGraph::registerPass) + .def( + "_jit_set_nvfuser_guard_mode", + [](bool profiling_flag) { + bool oldState = fuser::cuda::getCudaFusionGuardMode(); + fuser::cuda::getCudaFusionGuardMode() = profiling_flag; + return oldState; + }) .def("_jit_nvfuser_enabled", &RegisterCudaFuseGraph::isRegistered) .def( "_jit_set_profiling_mode", @@ -499,6 +540,13 @@ void initJITBindings(PyObject* module) { getNumProfiledRuns() = num; return old_num; }) + .def( + "_jit_get_num_profiled_runs", + [] { + // pybind can't automatically bind to atomic size_t + size_t num_runs = getNumProfiledRuns(); + return num_runs; + }) .def( "_jit_set_bailout_depth", [](size_t depth) { @@ -515,7 +563,7 @@ void initJITBindings(PyObject* module) { .def( "_jit_try_infer_type", [](py::object obj) -> TypePtr { - auto match = tryToInferType(obj); + auto match = tryToInferType(std::move(obj)); if (match.success()) { return match.type(); } @@ -583,12 +631,33 @@ void initJITBindings(PyObject* module) { using namespace torch::jit::tensorexpr; return getTEGenerateBlockCode(); }) + .def( + "_jit_get_te_must_use_llvm_cpu", + []() -> bool { + using namespace torch::jit::tensorexpr; + return getTEMustUseLLVMOnCPU(); + }) + .def( + "_jit_set_te_must_use_llvm_cpu", + [](bool use_llvm) { + using namespace torch::jit::tensorexpr; + getTEMustUseLLVMOnCPU() = use_llvm; + }) + .def( + "_llvm_enabled", + []() { +#ifdef TORCH_ENABLE_LLVM + return true; +#else + return false; +#endif + }) .def( "_jit_pass_fuse_tensorexprs", [](std::shared_ptr& g) { return FuseTensorExprs(g); }) .def( "_jit_fuser_get_fused_kernel_code", - [](Graph& g, std::vector inps) { + [](Graph& g, const std::vector& inps) { return debugGetFusedKernelCode(g, inps); }) .def( @@ -654,6 +723,30 @@ void initJITBindings(PyObject* module) { std::vector& preserved_methods) { return vulkanOptimizeForMobile(module, preserved_methods); }) + .def( + "_jit_pass_metal_insert_prepacked_ops", + [](std::shared_ptr& graph) { + return metalInsertPrePackedOps(graph); + }) + .def( + "_jit_pass_metal_insert_prepacked_ops", + [](script::Module& module) { + return metalInsertPrePackedOps(module); + }) + .def( + "_jit_pass_metal_fuse_clamp_w_prepacked_conv", + [](script::Module& module) { + return metalFusePrePackedConvWithClamp(module); + }) + .def( + "_jit_pass_metal_fold_prepacking_ops", + [](script::Module& module) { return metalFoldPrePackingOps(module); }) + .def( + "_jit_pass_metal_optimize_for_mobile", + [](script::Module& module, + std::vector& preserved_methods) { + return metalOptimizeForMobile(module, preserved_methods); + }) .def( "_jit_pass_onnx_unpack_quantized_weights", [](std::shared_ptr& graph, @@ -862,14 +955,14 @@ void initJITBindings(PyObject* module) { py::class_(m, "PyTorchFileReader") .def(py::init()) .def(py::init([](const py::object& buffer) { - auto adapter = std::make_unique(std::move(buffer)); + auto adapter = std::make_unique(buffer); return std::make_unique(std::move(adapter)); })) .def( "get_record", [](PyTorchStreamReader& self, const std::string& key) { at::DataPtr data; - size_t size; + size_t size = 0; std::tie(data, size) = self.getRecord(key); return py::bytes(reinterpret_cast(data.get()), size); }) @@ -1052,7 +1145,7 @@ void initJITBindings(PyObject* module) { m.def("_jit_get_custom_class_schemas", customClassSchemasForBCCheck); m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) { auto symbol = Symbol::fromQualString(qualified_name); - auto operations = getAllOperatorsFor(symbol); + const auto& operations = getAllOperatorsFor(symbol); return fmap(operations, [](const std::shared_ptr& op) { return op->schema(); }); @@ -1081,6 +1174,10 @@ void initJITBindings(PyObject* module) { "then", &PythonFutureWrapper::then, py::call_guard()) + .def( + "add_done_callback", + &PythonFutureWrapper::add_done_callback, + py::call_guard()) .def( "set_result", // Intentionally not releasing GIL @@ -1120,7 +1217,7 @@ void initJITBindings(PyObject* module) { auto fork_node = graph->insertNode(graph->create(prim::TracedFork, 1)); auto body_block = fork_node->addBlock(); - Value* node_output; + Value* node_output = nullptr; py::object py_func_output; // Insert new trace ops into the fork op's sub-block WithInsertPoint guard(body_block); @@ -1194,8 +1291,8 @@ void initJITBindings(PyObject* module) { }); }); - m.def("_jit_assert_is_instance", [](py::object obj, TypePtr type) { - toIValue(obj, type); + m.def("_jit_assert_is_instance", [](py::object obj, const TypePtr& type) { + toIValue(std::move(obj), type); }); initPythonCustomClassBindings(module); @@ -1205,6 +1302,7 @@ void initJITBindings(PyObject* module) { initJitScriptBindings(module); initJitBackendBindings(module); initStaticRuntimeBindings(module); + initTensorExprBindings(module); setPrintHandler([](const std::string& str) { py::gil_scoped_acquire acquire; diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp new file mode 100644 index 0000000000000..da16a678752d9 --- /dev/null +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -0,0 +1,280 @@ +#include + +#include + +namespace torch { +namespace jit { + +IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { + switch (type->kind()) { + case TypeKind::TensorType: { + auto var = py::cast(obj); + if (var.is_sparse()) { + TORCH_WARN_ONCE( + "Using sparse tensors in TorchScript is experimental. Many optimization " + "pathways have not been thoroughly tested with sparse tensors. Please " + "include the fact that the network is running sparse tensors in any bug " + "reports submitted."); + } + guardAgainstNamedTensor(var); + return var; + } + case TypeKind::FloatType: + return py::cast(obj); + case TypeKind::IntType: + // TODO(xintchen): Handling LayoutType and ScalarTypeType correctly. + case TypeKind::LayoutType: + case TypeKind::ScalarTypeType: + if (THPDtype_Check(obj.ptr())) { + auto dtype = reinterpret_cast(obj.ptr()); + return static_cast(dtype->scalar_type); + } + if (THPQScheme_Check(obj.ptr())) { + auto qscheme = reinterpret_cast(obj.ptr()); + return static_cast(qscheme->qscheme); + } + if (THPLayout_Check(obj.ptr())) { + auto layout = reinterpret_cast(obj.ptr()); + return static_cast(layout->layout); + } + return py::cast(obj); + case TypeKind::NoneType: + if (!obj.is_none()) { + throw py::cast_error( + c10::str("Cannot cast ", py::str(obj), " to None")); + } + return {}; + case TypeKind::BoolType: + return py::cast(obj); + case TypeKind::TupleType: { + py::tuple tuple = py::cast(obj); + size_t tuple_size = tuple.size(); + auto tuple_type = type->cast(); + const auto& elem_types = tuple_type->elements(); + if (elem_types.size() != tuple_size) { + throw py::cast_error(c10::str( + "Object ", + py::str(obj), + " had a different number of elements than type ", + type->repr_str())); + } + std::vector values; + values.reserve(tuple_size); + for (size_t i = 0; i < tuple_size; ++i) { + values.push_back(toIValue(tuple[i], elem_types[i])); + } + return tuple_type->name() + ? c10::ivalue::Tuple::createNamed(std::move(values), tuple_type) + : c10::ivalue::Tuple::create(std::move(values)); + } + case TypeKind::StringType: + return ConstantString::create(py::cast(obj)); + case TypeKind::DeviceObjType: { + auto device = reinterpret_cast(obj.ptr()); + return device->device; + } + case TypeKind::StreamObjType: { + auto stream = reinterpret_cast(obj.ptr()); + return static_cast(stream->cdata); + } + case TypeKind::ListType: { + const auto& elem_type = type->expectRef().getElementType(); + switch (elem_type->kind()) { + // allows single int/float to be broadcasted to a fixed size list + case TypeKind::IntType: + if (!N || !py::isinstance(obj)) { + return IValue(py::cast>(obj)); + } else { + int64_t value = py::cast(obj); + c10::List repeated; + repeated.reserve(*N); + for (int i = 0; i < *N; ++i) { + repeated.push_back(value); + } + return repeated; + } + case TypeKind::FloatType: + if (!N || !py::isinstance(obj)) { + return IValue(py::cast>(obj)); + } else { + double value = py::cast(obj); + c10::List repeated; + repeated.reserve(*N); + for (int i = 0; i < *N; ++i) { + repeated.push_back(value); + } + return repeated; + } + case TypeKind::BoolType: + return IValue(py::cast>(obj)); + case TypeKind::TensorType: + return IValue(py::cast>(obj)); + default: + return createGenericList(obj, elem_type); + } + } + case TypeKind::DictType: { + const auto& dict_type = type->expect(); + return createGenericDict( + py::cast(obj), + dict_type->getKeyType(), + dict_type->getValueType()); + } + case TypeKind::OptionalType: { + // check if it's a none obj since optional accepts NoneType + if (obj.is_none()) { + // check if it's a none obj since optional accepts NoneType + // return an IValue() to denote a NoneType + return {}; + } + return toIValue(obj, type->expectRef().getElementType()); + } + case TypeKind::ClassType: { + auto classType = type->expect(); + if (auto mod = as_module(py::cast(obj))) { + // if obj is already a ScriptModule, just return its ivalue + return mod.value()._ivalue(); + } + // otherwise is a normal class object, we create a fresh + // ivalue::Object to use from the py object. + // 1. create a bare ivalue + const size_t numAttrs = classType->numAttributes(); + auto cu = classType->compilation_unit(); + auto userObj = c10::ivalue::Object::create( + c10::StrongTypePtr(cu, classType), numAttrs); + + // 2. copy all the contained types + for (size_t slot = 0; slot < numAttrs; slot++) { + const auto& attrType = classType->getAttribute(slot); + const auto& attrName = classType->getAttributeName(slot); + + if (!py::hasattr(obj, attrName.c_str())) { + throw py::cast_error(c10::str( + "Tried to cast object to type ", + type->repr_str(), + " but object", + " was missing attribute ", + attrName)); + } + + try { + const auto& contained = py::getattr(obj, attrName.c_str()); + userObj->setSlot(slot, toIValue(contained, attrType)); + } catch (std::exception& e) { + throw py::cast_error(c10::str( + "Could not cast attribute '", + attrName, + "' to type ", + attrType->repr_str(), + ": ", + e.what())); + } + } + return userObj; + } + case TypeKind::InterfaceType: { + auto interfaceType = type->expect(); + // When converting an pyobj to an interface, we check if rhs + // is module or normal torchscript class, get the type and ivalue + // from them correspondingly. + c10::ClassTypePtr classType = nullptr; + IValue res; + if (auto mod = as_module(py::cast(obj))) { + classType = mod.value().type(); + res = mod.value()._ivalue(); + } else { + // We inspect the value to found the compiled TorchScript class + // and then create a ivalue::Object from that class type. + py::str qualified_name = py::module::import("torch._jit_internal") + .attr("_qualified_name")(obj.get_type()); + auto pyCu = get_python_cu(); + classType = pyCu->get_class(c10::QualifiedName(qualified_name)); + if (!classType) { + throw std::runtime_error(c10::str( + "Assigning the object ", + py::str(obj), + " to an interface fails because the value is not " + "a TorchScript compatible type, did you forget to", + "turn it into a user defined TorchScript class?")); + } + res = toIValue(obj, classType); + } + // check if the classType conform with the interface or not + std::stringstream why_not; + if (!classType->isSubtypeOfExt(interfaceType, &why_not)) { + throw py::cast_error(c10::str( + "Object ", + py::str(obj), + " is not compatible with interface ", + interfaceType->repr_str(), + "\n", + why_not.str())); + } + return res; + } + case TypeKind::NumberType: { + if (THPDtype_Check(obj.ptr())) { + auto dtype = reinterpret_cast(obj.ptr()); + return static_cast(dtype->scalar_type); + } + if (THPQScheme_Check(obj.ptr())) { + auto qscheme = reinterpret_cast(obj.ptr()); + return static_cast(qscheme->qscheme); + } + if (THPLayout_Check(obj.ptr())) { + auto layout = reinterpret_cast(obj.ptr()); + return static_cast(layout->layout); + } + if (py::isinstance(obj)) { + return py::cast(obj); + } else if (py::isinstance(obj)) { + return py::cast(obj); + } else { + throw py::cast_error( + c10::str("Cannot cast ", py::str(obj), " to ", type->repr_str())); + } + } + case TypeKind::RRefType: { +#ifdef USE_RPC + return obj.cast().toIValue(); +#else + AT_ERROR("RRef is only supported with the distributed package"); +#endif + } break; + case TypeKind::PyObjectType: { + return c10::ivalue::ConcretePyObjectHolder::create(obj); + } + case TypeKind::CapsuleType: { + return IValue::make_capsule(py::cast(obj).obj_ptr); + } + case TypeKind::FutureType: { + return obj.cast>()->fut; + } + case TypeKind::AnyType: + return toTypeInferredIValue(obj); + case TypeKind::FunctionType: + case TypeKind::GeneratorType: + case TypeKind::StorageType: + case TypeKind::QuantizerType: + case TypeKind::VarType: + case TypeKind::QSchemeType: + case TypeKind::AnyListType: + case TypeKind::AnyTupleType: + case TypeKind::AnyClassType: + case TypeKind::AnyEnumType: + break; + case TypeKind::EnumType: + EnumTypePtr enum_type = type->expect(); + py::object py_obj = py::reinterpret_borrow(obj); + std::string name = py::cast(obj.attr("name")); + IValue value = toIValue(obj.attr("value"), enum_type->getValueType(), {}); + auto enum_holder = + c10::make_intrusive(enum_type, name, value); + return IValue(enum_holder); + } + throw py::cast_error(c10::str( + "toIValue() cannot handle converting to type: ", type->repr_str())); +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 65f5a49145c8c..b14f4ddc37fdb 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -10,13 +10,13 @@ #include #include #include +#include #include #include #include #include #include #include -#include #include #include #include @@ -29,6 +29,11 @@ #endif #include +#include +#ifdef USE_C10D_NCCL +#include +#include +#endif #include #include @@ -48,7 +53,7 @@ namespace torch { namespace jit { -inline IValue toIValue( +IValue toIValue( py::handle obj, const TypePtr& type, c10::optional N = c10::nullopt); @@ -112,11 +117,12 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper // Future owns a reference to the py::function in its callback // vector, but Future does not acquire GIL on destruction. auto pf = std::make_shared(std::move(cb)); + return std::make_shared(fut->then( // Capture a copy of the ivalue::Future instead of the `this` pointer // because the PythonFutureWrapper object could have been deleted // when the callbacks are fired. For example, RPC only captures the - // ivalue::Future instead of PythonFutureWrapper in FutureMessage's + // ivalue::Future instead of PythonFutureWrapper in JitFuture's // callback functions. Hence, if user code does not hold a reference to // this PythonFutureWrapper object, there is no guarantee that the // PythonFutureWrapper is still valid when running the callback. @@ -144,6 +150,36 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper PyObjectType::get())); } + void add_done_callback(py::function cb) { + auto pf = std::make_shared(std::move(cb)); + fut->addCallback(std::bind( + [pyFut(this->getPtr())](std::shared_ptr pf) { + try { + pybind11::gil_scoped_acquire ag; + pf->func_(pyFut); + } catch (py::error_already_set& e) { + { + pybind11::gil_scoped_acquire ag; + // Release ownership on py::objects and also restore Python + // Error Indicator. + e.restore(); + // Clear the Python Error Indicator as we has recorded the + // exception in the response message. + PyErr_Clear(); + } + // Log and ignore exceptions raised through the callback + VLOG(1) << "Got the following error when running the callback: " + << e.what(); + + } catch (std::exception& e) { + // Log and ignore exceptions raised through the callback + VLOG(1) << "Got the following error when running the callback: " + << e.what(); + } + }, + std::move(pf))); + } + void markCompleted(const py::object& pyValue) { DCHECK(PyGILState_Check()); IValue value = toIValue(pyValue, PyObjectType::get()); @@ -216,34 +252,15 @@ inline TypedIValue toDictKeyIValue(py::handle key) { } inline c10::optional unifyOrInitializeType( - TypePtr accum, - TypePtr unify) { + const TypePtr& accum, + const TypePtr& unify) { if (!accum) { return unify; } return unifyTypes(accum, unify); } -struct InferredType { - InferredType(TypePtr type) : type_(std::move(type)) {} - InferredType(std::string reason) - : type_(nullptr), reason_(std::move(reason)) {} - TypePtr type() const { - TORCH_INTERNAL_ASSERT(type_); - return type_; - } - bool success() const { - return type_ != nullptr; - } - const std::string& reason() const { - TORCH_INTERNAL_ASSERT(!type_); - return reason_; - } - - private: - TypePtr type_; - std::string reason_; -}; +using InferredType = c10::InferredType; InferredType tryToInferContainerType(py::handle input); @@ -256,8 +273,7 @@ InferredType tryToInferContainerType(py::handle input); inline InferredType tryToInferType(py::handle input) { // Try tensor types if (THPVariable_Check(input.ptr())) { - auto tensor = py::cast(input); - return InferredType(TensorType::create(tensor)); + return InferredType(TensorType::get()); } if (input.is(py::none())) { @@ -282,6 +298,8 @@ inline InferredType tryToInferType(py::handle input) { return InferredType(IntType::get()); } else if (THPDevice_Check(input.ptr())) { return InferredType(DeviceObjType::get()); + } else if (THPStream_Check(input.ptr())) { + return InferredType(StreamObjType::get()); } else if (THPDtype_Check(input.ptr())) { return InferredType(IntType::get()); } else if (THPQScheme_Check(input.ptr())) { @@ -320,7 +338,7 @@ inline InferredType tryToInferType(py::handle input) { if (py::isinstance(input)) { auto object = py::cast(input); return InferredType(object.type()); -#ifdef USE_DISTRIBUTED +#ifdef USE_RPC } else if (py::isinstance(input)) { auto rref_ivalue = input.cast().toIValue(); return InferredType(rref_ivalue.type()); @@ -434,7 +452,7 @@ inline InferredType tryToInferContainerType(py::handle input) { } } -inline bool isTraceableType(TypePtr type) { +inline bool isTraceableType(const TypePtr& type) { if (type->isSubtypeOf(TensorType::get())) { return true; } @@ -447,7 +465,9 @@ inline bool isTraceableType(TypePtr type) { return std::all_of( tuple_type->elements().begin(), tuple_type->elements().end(), - [](TypePtr element_type) { return isTraceableType(element_type); }); + [](const TypePtr& element_type) { + return isTraceableType(element_type); + }); } if (auto dict_type = type->cast()) { @@ -480,13 +500,13 @@ inline Stack toTraceableStack(const py::tuple& inputs) { inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) { auto elems = c10::impl::GenericList(elem_type); for (auto elem : obj) { - elems.push_back(toIValue(std::move(elem), elem_type)); + elems.push_back(toIValue(elem, elem_type)); } return IValue(std::move(elems)); } inline IValue createGenericDict( - py::dict obj, + const py::dict& obj, const TypePtr& key_type, const TypePtr& value_type) { c10::impl::GenericDict elems(key_type, value_type); @@ -506,256 +526,9 @@ inline void guardAgainstNamedTensor(const T& var) { "workaround please drop names via `tensor = tensor.rename(None)`."); } -inline IValue toIValue( - py::handle obj, - const TypePtr& type, - c10::optional N) { - switch (type->kind()) { - case TypeKind::TensorType: { - auto var = py::cast(obj); - if (var.is_sparse()) { - TORCH_WARN_ONCE( - "Using sparse tensors in TorchScript is experimental. Many optimization " - "pathways have not been thoroughly tested with sparse tensors. Please " - "include the fact that the network is running sparse tensors in any bug " - "reports submitted."); - } - guardAgainstNamedTensor(var); - return var; - } - case TypeKind::FloatType: - return py::cast(obj); - case TypeKind::IntType: - // TODO(xintchen): Handling LayoutType and ScalarTypeType correctly. - case TypeKind::LayoutType: - case TypeKind::ScalarTypeType: - if (THPDtype_Check(obj.ptr())) { - auto dtype = reinterpret_cast(obj.ptr()); - return static_cast(dtype->scalar_type); - } - if (THPQScheme_Check(obj.ptr())) { - auto qscheme = reinterpret_cast(obj.ptr()); - return static_cast(qscheme->qscheme); - } - if (THPLayout_Check(obj.ptr())) { - auto layout = reinterpret_cast(obj.ptr()); - return static_cast(layout->layout); - } - return py::cast(obj); - case TypeKind::NoneType: - if (!obj.is_none()) { - throw py::cast_error( - c10::str("Cannot cast ", py::str(obj), " to None")); - } - return {}; - case TypeKind::BoolType: - return py::cast(obj); - case TypeKind::TupleType: { - py::tuple tuple = py::cast(obj); - size_t tuple_size = tuple.size(); - auto tuple_type = type->cast(); - const auto& elem_types = tuple_type->elements(); - if (elem_types.size() != tuple_size) { - throw py::cast_error(c10::str( - "Object ", - py::str(obj), - " had a different number of elements than type ", - type->repr_str())); - } - std::vector values; - values.reserve(tuple_size); - for (size_t i = 0; i < tuple_size; ++i) { - values.push_back(toIValue(tuple[i], elem_types[i])); - } - return tuple_type->name() - ? c10::ivalue::Tuple::createNamed(std::move(values), tuple_type) - : c10::ivalue::Tuple::create(std::move(values)); - } - case TypeKind::StringType: - return ConstantString::create(py::cast(obj)); - case TypeKind::DeviceObjType: { - auto device = reinterpret_cast(obj.ptr()); - return device->device; - } - case TypeKind::ListType: { - const auto& elem_type = type->expect()->getElementType(); - switch (elem_type->kind()) { - // allows single int/float to be broadcasted to a fixed size list - case TypeKind::IntType: - if (!N || !py::isinstance(obj)) { - return IValue(py::cast>(obj)); - } else { - int64_t value = py::cast(obj); - c10::List repeated; - repeated.reserve(*N); - for (int i = 0; i < *N; ++i) { - repeated.push_back(value); - } - return repeated; - } - case TypeKind::FloatType: - if (!N || !py::isinstance(obj)) { - return IValue(py::cast>(obj)); - } else { - double value = py::cast(obj); - c10::List repeated; - repeated.reserve(*N); - for (int i = 0; i < *N; ++i) { - repeated.push_back(value); - } - return repeated; - } - case TypeKind::BoolType: - return IValue(py::cast>(obj)); - case TypeKind::TensorType: - return IValue(py::cast>(obj)); - default: - return createGenericList(obj, elem_type); - } - } - case TypeKind::DictType: { - const auto& dict_type = type->expect(); - return createGenericDict( - py::cast(obj), - dict_type->getKeyType(), - dict_type->getValueType()); - } - case TypeKind::OptionalType: { - // check if it's a none obj since optional accepts NoneType - if (obj.is_none()) { - // check if it's a none obj since optional accepts NoneType - // return an IValue() to denote a NoneType - return {}; - } - return toIValue(obj, type->expect()->getElementType()); - } - case TypeKind::ClassType: { - auto classType = type->expect(); - if (auto mod = as_module(py::cast(obj))) { - // if obj is already a ScriptModule, just return its ivalue - return mod.value()._ivalue(); - } - // otherwise is a normal class object, we create a fresh - // ivalue::Object to use from the py object. - // 1. create a bare ivalue - const size_t numAttrs = classType->numAttributes(); - auto cu = classType->compilation_unit(); - auto userObj = c10::ivalue::Object::create( - c10::StrongTypePtr(cu, classType), numAttrs); - - // 2. copy all the contained types - for (size_t slot = 0; slot < numAttrs; slot++) { - const auto& attrType = classType->getAttribute(slot); - const auto& attrName = classType->getAttributeName(slot); - - const auto& contained = py::getattr(obj, attrName.c_str()); - userObj->setSlot(slot, toIValue(contained, attrType)); - } - return userObj; - } - case TypeKind::InterfaceType: { - auto interfaceType = type->expect(); - // When converting an pyobj to an interface, we check if rhs - // is module or normal torchscript class, get the type and ivalue - // from them correspondingly. - c10::ClassTypePtr classType = nullptr; - IValue res; - if (auto mod = as_module(py::cast(obj))) { - classType = mod.value().type(); - res = mod.value()._ivalue(); - } else { - // We inspect the value to found the compiled TorchScript class - // and then create a ivalue::Object from that class type. - py::str qualified_name = py::module::import("torch._jit_internal") - .attr("_qualified_name")(obj.get_type()); - auto pyCu = get_python_cu(); - classType = pyCu->get_class(c10::QualifiedName(qualified_name)); - if (!classType) { - throw std::runtime_error(c10::str( - "Assigning the object ", - py::str(obj), - " to an interface fails because the value is not " - "a TorchScript compatible type, did you forget to", - "turn it into a user defined TorchScript class?")); - } - res = toIValue(std::move(obj), classType); - } - // check if the classType conform with the interface or not - std::stringstream why_not; - if (!classType->isSubtypeOfExt(interfaceType, &why_not)) { - throw py::cast_error(c10::str( - "Object ", - py::str(obj), - " is not compatible with interface ", - interfaceType->repr_str(), - "\n", - why_not.str())); - } - return res; - } - case TypeKind::NumberType: { - if (THPDtype_Check(obj.ptr())) { - auto dtype = reinterpret_cast(obj.ptr()); - return static_cast(dtype->scalar_type); - } - if (THPQScheme_Check(obj.ptr())) { - auto qscheme = reinterpret_cast(obj.ptr()); - return static_cast(qscheme->qscheme); - } - if (THPLayout_Check(obj.ptr())) { - auto layout = reinterpret_cast(obj.ptr()); - return static_cast(layout->layout); - } - if (py::isinstance(obj)) { - return py::cast(obj); - } else if (py::isinstance(obj)) { - return py::cast(obj); - } else { - throw py::cast_error( - c10::str("Cannot cast ", py::str(obj), " to ", type->repr_str())); - } - } - case TypeKind::RRefType: { -#ifdef USE_DISTRIBUTED - return obj.cast().toIValue(); -#else - AT_ERROR("RRef is only supported with the distributed package"); -#endif - } break; - case TypeKind::PyObjectType: { - return c10::ivalue::ConcretePyObjectHolder::create(obj); - } - case TypeKind::CapsuleType: { - return IValue::make_capsule( - py::cast>(obj)); - } - case TypeKind::FutureType: { - return obj.cast>()->fut; - } - case TypeKind::AnyType: - return toTypeInferredIValue(obj); - case TypeKind::FunctionType: - case TypeKind::GeneratorType: - case TypeKind::QuantizerType: - case TypeKind::VarType: - case TypeKind::QSchemeType: - case TypeKind::AnyListType: - case TypeKind::AnyTupleType: - case TypeKind::AnyClassType: - case TypeKind::AnyEnumType: - break; - case TypeKind::EnumType: - EnumTypePtr enum_type = type->expect(); - py::object py_obj = py::reinterpret_borrow(obj); - std::string name = py::cast(obj.attr("name")); - IValue value = toIValue(obj.attr("value"), enum_type->getValueType(), {}); - auto enum_holder = - c10::make_intrusive(enum_type, name, value); - return IValue(enum_holder); - } - throw py::cast_error(c10::str( - "toIValue() cannot handle converting to type: ", type->repr_str())); -} +// Defined in pybind_utils.cpp to break a circular dependency with +// python_ivalue.h +IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N); // Small wrapper around getting the type name string from Python to make // types easier to interpret, e.g. give the structural type for a NamedTuple @@ -805,6 +578,15 @@ inline IValue argumentToIValue( py::repr(object)), "\nCast error details: ", error.what())); + } catch (const py::error_already_set& error) { + throw schema_match_error(c10::str( + schema.formatTypeMismatchMsg( + argument, + friendlyTypeName(object), + argumentPosition, + py::repr(object)), + "\n Python error details: ", + error.what())); } } @@ -896,7 +678,7 @@ inline py::object toPyObject(IValue ivalue) { } return std::move(py_dict); } else if (ivalue.isRRef()) { -#ifdef USE_DISTRIBUTED +#ifdef USE_RPC auto RRefPtr = c10::dynamic_intrusive_pointer_cast( std::move(ivalue).toRRef()); @@ -932,7 +714,7 @@ inline py::object toPyObject(IValue ivalue) { // PyObject return py::reinterpret_borrow(ivalue.toPyObject()); } else if (ivalue.isCapsule()) { - return py::cast(ivalue.toCapsule()); + return py::cast(c10::Capsule(ivalue.toCapsule())); } else if (ivalue.isFuture()) { return py::cast(std::make_shared(ivalue.toFuture())); } else if (ivalue.isEnum()) { @@ -942,7 +724,7 @@ inline py::object toPyObject(IValue ivalue) { auto py_class = getScriptedClassOrError(qualified_class_name); return py_class.attr(enum_holder->name().c_str()); } else if (ivalue.isRRef()) { -#ifdef USE_DISTRIBUTED +#ifdef USE_RPC return py::cast(torch::distributed::rpc::PyRRef( c10::static_intrusive_pointer_cast( ivalue.toRRef()))); @@ -1006,9 +788,9 @@ inline Stack createStackForSchema( push(stack, std::move(*self)); } // First push all positional args. - for (size_t i = 0; i < args.size(); ++i) { + for (const auto& arg : args) { // Use the type information from the schema to convert the PyObject. - push(stack, argumentToIValue(schema, stack.size(), args[i])); + push(stack, argumentToIValue(schema, stack.size(), arg)); } // Now for every remaining non-positional argument in the schema, look for it @@ -1085,15 +867,16 @@ inline Stack evilDeprecatedBadCreateStackDoNotUse( // tracing graph. inline py::object runAndInsertCall( Function& callee, - tuple_slice args, - py::kwargs kwargs, + const tuple_slice& args, + const py::kwargs& kwargs, c10::optional self, // Lambda that tells this function how to insert `callee` into the graph if // we're tracing. - std::function callInserter) { - auto stack = createStackForSchema( - callee.getSchema(), std::move(args), std::move(kwargs), std::move(self)); - auto tracing_state = tracer::getTracingState(); + const std::function& + callInserter) { + auto stack = + createStackForSchema(callee.getSchema(), args, kwargs, std::move(self)); + const auto& tracing_state = tracer::getTracingState(); if (!tracing_state) { pybind11::gil_scoped_release no_gil_guard; // If we're not tracing, just run the callee as normal. @@ -1143,8 +926,8 @@ inline py::object runAndInsertCall( inline py::object invokeScriptFunctionFromPython( Function& callee, - tuple_slice args, - py::kwargs kwargs) { + const tuple_slice& args, + const py::kwargs& kwargs) { return runAndInsertCall( callee, args, @@ -1157,8 +940,8 @@ inline py::object invokeScriptFunctionFromPython( inline py::object invokeScriptMethodFromPython( Method& callee, - tuple_slice args, - py::kwargs kwargs) { + const tuple_slice& args, + const py::kwargs& kwargs) { auto self = callee.owner()._ivalue(); return runAndInsertCall( callee.function(), @@ -1170,20 +953,18 @@ inline py::object invokeScriptMethodFromPython( }); } -inline py::object invokeOperatorFromPython( +inline std::pair, Stack> getOpWithStack( const std::vector>& operations, py::args args, - py::kwargs kwargs) { + const py::kwargs& kwargs) { Stack stack; - if (operations.size() == 1) { - const Operator& op = *operations.at(0); + std::shared_ptr op = operations.at(0); // Create a stack full of the arguments and keyword arguments. stack = createStackForSchema( - op.schema(), std::move(args), std::move(kwargs), c10::nullopt); + op->schema(), std::move(args), kwargs, c10::nullopt); - pybind11::gil_scoped_release no_gil_guard; - op.getOperation()(&stack); + return std::make_pair(op, stack); } else { std::vector errors; std::shared_ptr found_op = nullptr; @@ -1205,6 +986,17 @@ inline py::object invokeOperatorFromPython( throw std::runtime_error(ss.str()); } + return std::make_pair(found_op, stack); + } +} +inline py::object invokeOperatorFromPython( + const std::vector>& operations, + py::args args, + const py::kwargs& kwargs) { + auto opWithStack = getOpWithStack(operations, args, kwargs); + std::shared_ptr found_op = std::get<0>(opWithStack); + Stack stack = std::get<1>(opWithStack); + { pybind11::gil_scoped_release no_gil_guard; found_op->getOperation()(&stack); } diff --git a/torch/csrc/jit/python/python_arg_flatten.cpp b/torch/csrc/jit/python/python_arg_flatten.cpp index 41cd3cd2b8afe..ea7fac0626c02 100644 --- a/torch/csrc/jit/python/python_arg_flatten.cpp +++ b/torch/csrc/jit/python/python_arg_flatten.cpp @@ -21,6 +21,7 @@ static constexpr char TupleOpen = '('; static constexpr char TupleClose = ')'; static constexpr char Variable = 'v'; static constexpr char String = 's'; +static constexpr char NoneType = 'n'; } // namespace D namespace { @@ -62,11 +63,13 @@ void flatten_rec(PyObject* obj, ParsedArgs& args) { args.vars.push_back(var); args.desc.metadata.emplace_back(var); args.desc.structure.push_back(D::Variable); + } else if (strcmp(THPUtils_typename(obj), "NoneType") == 0) { + args.desc.structure.push_back(D::NoneType); } else { std::string msg = - "Only tuples, lists and Variables supported as JIT inputs/outputs. " - "Dictionaries and strings are also accepted but their usage is not " - "recommended. But got unsupported type "; + "Only tuples, lists and Variables are supported as JIT inputs/outputs. " + "Dictionaries and strings are also accepted, but their usage is not " + "recommended. Here, received an input of unsupported type: "; msg += THPUtils_typename(obj); throw std::runtime_error(msg); } @@ -97,7 +100,7 @@ py::object cast_dict(std::vector objs) { py::dict sequence = {}; for (size_t i = 0; i < num_objs; ++i) { py::tuple obj = py::reinterpret_borrow(objs[i]); - sequence[obj[0]] = std::move(obj[1]); + sequence[obj[0]] = obj[1]; } return std::move(sequence); } @@ -136,6 +139,8 @@ py::object unflatten_rec( throw std::runtime_error("Not enough Variables given to unflatten"); auto str = *str_it++; return py::reinterpret_borrow(THPUtils_packString(str)); + } else if (type == D::NoneType) { + return py::reinterpret_borrow(py::none()); } else { if (var_it == var_it_end) throw std::runtime_error("Not enough Variables given to unflatten"); diff --git a/torch/csrc/jit/python/python_custom_class.cpp b/torch/csrc/jit/python/python_custom_class.cpp index 49c85c8c3c7f6..521b93f11179f 100644 --- a/torch/csrc/jit/python/python_custom_class.cpp +++ b/torch/csrc/jit/python/python_custom_class.cpp @@ -1,4 +1,5 @@ #include + #include #include @@ -28,7 +29,10 @@ void initPythonCustomClassBindings(PyObject* module) { auto m = py::handle(module).cast(); py::class_(m, "ScriptClass") - .def("__call__", &ScriptClass::__call__); + .def("__call__", &ScriptClass::__call__) + .def_property_readonly("__doc__", [](const ScriptClass& self) { + return self.class_type_.type_->expectRef().doc_string(); + }); // This function returns a ScriptClass that wraps the constructor // of the given class, specified by the qualified name passed in. diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index 46a6448ee3106..ab8545d7ff842 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -16,6 +16,7 @@ #include #include #include +#include namespace torch { namespace jit { @@ -216,12 +217,12 @@ void initPythonIRBindings(PyObject* module_) { .def( "dump_alias_db", [](std::shared_ptr g) { - AliasDb db(g); + AliasDb db(std::move(g)); db.dump(); }) .def( "_export_onnx", - [](const std::shared_ptr g, + [](const std::shared_ptr& g, const std::map& initializers, int64_t onnx_opset_version, const std::unordered_map< @@ -236,8 +237,10 @@ void initPythonIRBindings(PyObject* module_) { bool use_external_data_format, const std::string& onnx_file_path) { std::string graph; + std::shared_ptr<::ONNX_NAMESPACE::ModelProto> model_proto; RawDataExportMap export_map; - std::tie(graph, export_map) = export_onnx( + SymbolDimMap symbol_map; + std::tie(model_proto, export_map, symbol_map) = export_onnx( g, initializers, onnx_opset_version, @@ -250,6 +253,7 @@ void initPythonIRBindings(PyObject* module_) { add_node_names, use_external_data_format, onnx_file_path); + graph = serialize_model_proto_to_string(model_proto); std::unordered_map python_serialized_export_map; for (auto& kv : export_map) { @@ -261,6 +265,7 @@ void initPythonIRBindings(PyObject* module_) { python_serialized_export_map[kv.first] = py::bytes(static_cast(t.data_ptr()), copy_bytes); } + graph = serialize_model_proto_to_string(model_proto); return std::make_tuple( py::bytes(graph), python_serialized_export_map); }, @@ -278,7 +283,7 @@ void initPythonIRBindings(PyObject* module_) { py::arg("onnx_file_path") = std::string()) .def( "_pretty_print_onnx", - [](const std::shared_ptr g, + [](const std::shared_ptr& g, const std::map& initializers, int64_t onnx_opset_version, bool defer_weight_export, @@ -385,7 +390,7 @@ void initPythonIRBindings(PyObject* module_) { .GS(prependNode) .def( "insertConstant", - [](Graph& g, IValue ival) { return g.insertConstant(ival); }) + [](Graph& g, const IValue& ival) { return g.insertConstant(ival); }) .GS(lint) .GS(insertNode); #undef GS @@ -429,7 +434,7 @@ void initPythonIRBindings(PyObject* module_) { .VS(requires_grad) .def( "requiresGrad", - [](Value& n) { n.type()->expect()->requiresGrad(); }) + [](Value& n) { n.type()->expectRef().requiresGrad(); }) .def("toIValue", [](Value& n) { return toIValue(&n); }) .def("type", [](Value& v) { return v.type(); }); #undef VS @@ -468,8 +473,14 @@ void initPythonIRBindings(PyObject* module_) { }) .def("returnNode", [](Block& b) { return b.return_node(); }) .def("paramNode", [](Block& b) { return b.param_node(); }) - .def("addNode", [](Block& b, Value& input, const char* str) { - return addNodeToBlock(&b, &input, Symbol::fromQualString(str)); + .def( + "addNode", + [](Block& b, const char* str, const std::vector& inputs) { + return addNodeToBlock(&b, Symbol::fromQualString(str), inputs); + }) + .def("addInputToBlock", [](Block& b) { return addInputToBlock(&b); }) + .def("registerOutput", [](Block& b, Value* value) { + return b.registerOutput(value); }); #define NS(name) def(#name, &Node ::name) @@ -483,6 +494,7 @@ void initPythonIRBindings(PyObject* module_) { }) .def("sourceRange", [](Node& n) { return n.sourceRange().str(); }) .def("hasMultipleOutputs", [](Node& n) { return n.outputs().size() > 1; }) + .def("inputsSize", [](Node& n) { return n.inputs().size(); }) .def("outputsSize", [](Node& n) { return n.outputs().size(); }) .NS(kind) .def("inputsAt", [](Node& n, size_t i) { return n.inputs().at(i); }) @@ -491,6 +503,17 @@ void initPythonIRBindings(PyObject* module_) { [](Node& n) { return py::make_iterator(n.inputs().begin(), n.inputs().end()); }) + .def( + "schema", + [](Node& n) { + std::stringstream ss; + if (auto sch = n.maybeSchema()) { + ss << n.schema(); + } else { + ss << "(no schema)"; + } + return ss.str(); + }) .def( "outputs", [](Node& n) { @@ -556,14 +579,12 @@ void initPythonIRBindings(PyObject* module_) { .AS(removeAttribute) .AS(attributeNames) #undef AS -#define CREATE_ACCESSOR(Kind, method) \ - def(#method "_", \ - [](Node& n, const char* name, Kind##Attr::ValueType v) { \ - return n.method##_(Symbol::attr(name), std::move(v)); \ - }) \ - .def(#method, [](Node& n, const char* name) { \ - return n.method(Symbol::attr(name)); \ - }) +#define CREATE_ACCESSOR(Kind, method) \ + def(#method "_", [](Node& n, const char* name, Kind##Attr::ValueType v) { \ + return n.method##_(Symbol::attr(name), std::move(v)); \ + }).def(#method, [](Node& n, const char* name) { \ + return n.method(Symbol::attr(name)); \ + }) .CREATE_ACCESSOR(Float, f) .CREATE_ACCESSOR(Floats, fs) .CREATE_ACCESSOR(String, s) @@ -576,7 +597,7 @@ void initPythonIRBindings(PyObject* module_) { // Tensor (t_) -- manually written to unwrap the variable into a tensor. .def( "t_", - [](Node& n, const char* name, torch::autograd::Variable v) { + [](Node& n, const char* name, const torch::autograd::Variable& v) { AT_ASSERT(!v.requires_grad()); return n.t_(Symbol::attr(name), v); }) @@ -588,7 +609,7 @@ void initPythonIRBindings(PyObject* module_) { "ts_", [](Node& n, const char* name, - std::vector vs) { + const std::vector& vs) { std::vector tensors; tensors.reserve(vs.size()); for (auto& variable : vs) { @@ -610,10 +631,11 @@ void initPythonIRBindings(PyObject* module_) { }) .def( "z_", - [](Node& n, const char* name, at::Tensor v) { + [](Node& n, const char* name, const at::Tensor& v) { return n.t_( Symbol::attr(name), - autograd::Variable(v.view({})).set_requires_grad(false)); + autograd::Variable(v.view(std::vector{})) + .set_requires_grad(false)); }) .def( "z", @@ -622,7 +644,8 @@ void initPythonIRBindings(PyObject* module_) { "zs_", [](Node& n, const char* name, TensorsAttr::ValueType v) { for (auto& i : v) { - i = autograd::Variable(i.view({})).set_requires_grad(false); + i = autograd::Variable(i.view(std::vector{})) + .set_requires_grad(false); } return n.ts_(Symbol::attr(name), std::move(v)); }) @@ -663,7 +686,7 @@ void initPythonIRBindings(PyObject* module_) { .def( "dim", [](Type& t) { - auto vshape = t.shared_from_this()->expect()->sizes(); + auto vshape = t.shared_from_this()->expectRef().sizes(); return vshape.size() ? py::cast(*vshape.size()) : py::cast(Py_None); }) @@ -671,7 +694,7 @@ void initPythonIRBindings(PyObject* module_) { "undefined", [](Type& t) { auto undef = - t.shared_from_this()->expect()->undefined(); + t.shared_from_this()->expectRef().undefined(); return undef.has_value() ? py::cast(*undef) : py::cast(Py_None); }) @@ -685,6 +708,16 @@ void initPythonIRBindings(PyObject* module_) { } return py::none(); }) + .def( + "varyingSizes", + [](Type& t) -> py::object { + if (auto ptt = t.expect()) { + if (auto s = ptt->sizes().sizes()) { + return py::cast(s.value()); + } + } + return py::none(); + }) .def( "strides", [](Type& t) -> py::object { @@ -699,13 +732,13 @@ void initPythonIRBindings(PyObject* module_) { "contiguous", [](Type& t) { return std::static_pointer_cast( - t.expect()->contiguous()); + t.expectRef().contiguous()); }) .def( "scalarType", [](Type& t) { auto scalar_type = - t.shared_from_this()->expect()->scalarType(); + t.shared_from_this()->expectRef().scalarType(); return (scalar_type) ? toString(*scalar_type) : nullptr; }) .def( @@ -718,7 +751,7 @@ void initPythonIRBindings(PyObject* module_) { }) .def( "isSubtypeOf", - [](std::shared_ptr& self, std::shared_ptr other) { + [](std::shared_ptr& self, std::shared_ptr& other) { if (!other) { return false; } @@ -737,7 +770,8 @@ void initPythonIRBindings(PyObject* module_) { py::class_>(m, "FloatType") .def_static("get", &FloatType::get); py::class_>(m, "TensorType") - .def_static("get", &TensorType::get); + .def_static("get", &TensorType::get) + .def_static("getInferred", &TensorType::getInferred); py::class_>(m, "BoolType") .def_static("get", &BoolType::get); py::class_>(m, "StringType") @@ -745,6 +779,9 @@ void initPythonIRBindings(PyObject* module_) { py::class_>( m, "DeviceObjType") .def_static("get", &DeviceObjType::get); + py::class_>( + m, "StreamObjType") + .def_static("get", &StreamObjType::get); py::class_>( m, "PyObjectType") .def_static("get", &PyObjectType::get); @@ -752,8 +789,9 @@ void initPythonIRBindings(PyObject* module_) { .def_static("get", &NoneType::get); py::class_>(m, "TupleType") - .def( - py::init([](std::vector a) { return TupleType::create(a); })) + .def(py::init([](std::vector a) { + return TupleType::create(std::move(a)); + })) .def("elements", [](TupleType& self) { std::vector types; for (const auto& type : self.elements()) { @@ -770,21 +808,22 @@ void initPythonIRBindings(PyObject* module_) { .def("getElementType", &ListType::getElementType); py::class_>(m, "DictType") .def(py::init([](TypePtr key, TypePtr value) { - return DictType::create(key, value); + return DictType::create(std::move(key), std::move(value)); })) .def("getKeyType", &DictType::getKeyType) .def("getValueType", &DictType::getValueType); py::class_>( m, "OptionalType") - .def(py::init([](TypePtr a) { return OptionalType::create(a); })) + .def(py::init( + [](TypePtr a) { return OptionalType::create(std::move(a)); })) .def_static("ofTensor", &OptionalType::ofTensor) .def("getElementType", &OptionalType::getElementType); py::class_>(m, "RRefType") - .def(py::init([](TypePtr a) { return RRefType::create(a); })) + .def(py::init([](TypePtr a) { return RRefType::create(std::move(a)); })) .def("getElementType", &RRefType::getElementType); py::class_>(m, "FutureType") - .def(py::init([](TypePtr a) { return FutureType::create(a); })) + .def(py::init([](TypePtr a) { return FutureType::create(std::move(a)); })) .def("getElementType", &FutureType::getElementType); py::class_>(m, "ClassType") diff --git a/torch/csrc/jit/python/python_ivalue.h b/torch/csrc/jit/python/python_ivalue.h index d02f2a7c554aa..ecd4dd2585bf5 100644 --- a/torch/csrc/jit/python/python_ivalue.h +++ b/torch/csrc/jit/python/python_ivalue.h @@ -1,5 +1,7 @@ #pragma once +#include #include +#include #include namespace py = pybind11; @@ -24,6 +26,22 @@ struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder { return py_obj_.ptr(); } + InferredType tryToInferType() override { + pybind11::gil_scoped_acquire ag; + return torch::jit::tryToInferType(py_obj_); + } + + IValue toIValue(const TypePtr& type, c10::optional N = c10::nullopt) + override { + pybind11::gil_scoped_acquire ag; + return torch::jit::toIValue(py_obj_, type, N); + } + + std::string toStr() override { + pybind11::gil_scoped_acquire ag; + return py::str(py_obj_); + } + // Note [Destructing py::object] // ~~~~~~~~~~~~~~~~~~~~~~~~~~ // diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index ba94d33f37b39..92a06110f0924 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -91,12 +92,12 @@ FunctionSchema PythonValue::getSchema( auto types_it = arg_types.begin(); for (; types_it != arg_types.end(); ++types_it, ++names_it) { - args.push_back(Argument( + args.emplace_back( /*name=*/*names_it, /*type=*/std::move(*types_it), /*N=*/c10::nullopt, /*default_value=*/c10::nullopt, - /*kwarg_only=*/false)); + /*kwarg_only=*/false); } rets.push_back(Argument("0", std::move(ret_type), {}, {}, false)); } @@ -114,21 +115,20 @@ FunctionSchema PythonValue::getSchema( std::shared_ptr PythonValue::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs_, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { - std::vector inputsWithSelf; + std::vector argsWithSelf; if (moduleSelf_) { - inputsWithSelf.emplace_back(NamedValue("self", moduleSelf_)); + argsWithSelf.emplace_back(NamedValue("self", moduleSelf_)); } - inputsWithSelf.insert(inputsWithSelf.end(), inputs_.begin(), inputs_.end()); - inputs_ = inputsWithSelf; + argsWithSelf.insert(argsWithSelf.end(), args.begin(), args.end()); - auto schema = getSchema(inputs_.size(), n_binders, loc); - auto inputs = toValues(*m.graph(), inputs_); + auto schema = getSchema(argsWithSelf.size(), n_binders, loc); + auto inputs = toValues(*m.graph(), argsWithSelf); MatchedSchema matched_schema = - matchSchema(schema, loc, *m.graph(), inputs_, attributes); + matchSchema(schema, loc, *m.graph(), argsWithSelf, kwargs); // If if a function is marked as dropped, // we throw an exception if it is invoked. @@ -217,6 +217,33 @@ std::shared_ptr PythonModuleValue::attr( return toSugaredValue(member, m, loc, /*is_constant=*/true); } +#ifndef __HIP_PLATFORM_HCC__ +std::shared_ptr CUDAPythonModuleValue::attr( + const SourceRange& loc, + Function& m, + const std::string& field) { + // List of all the cuda operators which are supported in JIT + const std::unordered_set cuda_ops = { + "current_stream", + "default_stream", + "_current_device", + "_set_device", + "device_index", + "device_count", + "set_stream"}; + + if (cuda_ops.find(field) != cuda_ops.end()) { + return std::make_shared(Symbol::cuda(field), c10::nullopt); + } + + py::object member = getattr(loc, field); + // note: is_constant = true because we consider that global properties + // on modules like math.pi or torch.float to be constants + // even though it is possible, though rare, for someone to mutate them + return toSugaredValue(member, m, loc, /*is_constant=*/true); +} +#endif + Value* ModuleValue::asValue(const SourceRange& loc, Function& m) { return self_; } @@ -234,9 +261,11 @@ SugaredValuePtr ModuleValue::asTupleValue(const SourceRange& loc, Function& m) { SugaredValuePtr ModuleValue::getitem( const SourceRange& loc, Function& m, - Value* idx) { + Value* idx, + TypePtr type_hint) { if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) { - return getSugaredDict(loc, m)->getModules()->getitem(loc, m, idx); + return getSugaredDict(loc, m)->getModules()->getitem( + loc, m, idx, type_hint); } else if ( concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) { if (auto ivalue = toIValue(idx)) { @@ -252,6 +281,32 @@ SugaredValuePtr ModuleValue::getitem( } } throw ErrorReport(loc) << "Key Error, " << idx_str; + } else if (type_hint) { + // Check that all submodules comply with the type hint. + const auto& self_type = concreteType_->getJitType()->expect(); + for (size_t i = 0; i < self_type->numAttributes(); ++i) { + const auto& attr_type = self_type->getAttribute(i); + if (attr_type->is_module()) { + std::stringstream ss; + if (!attr_type->isSubtypeOfExt(type_hint, &ss)) { + auto loc = self_->node()->sourceRange(); + throw ErrorReport(loc) + << "Attribute " << self_type->getAttributeName(i) + << " is not of annotated type " << type_hint->annotation_str() + << ": " << ss.str(); + } + } + } + + // Emit a prim::ModuleDictIndex operator. This is needed because it's + // difficult to construct a dict in the graph representing the ModuleDict + // and use aten::__getitem__ ops to index into it because any call to + // ModuleDict.setAttr would invalidate that emitted dict. + auto graph = m.graph(); + auto* getitem_node = + graph->insertNode(graph->create(prim::ModuleDictIndex, {self_, idx})); + getitem_node->output(0)->setType(type_hint); + return std::make_shared(getitem_node->output(0)); } throw ErrorReport(loc) << "Unable to extract string literal index. " @@ -265,7 +320,7 @@ SugaredValuePtr ModuleValue::getitem( void checkInterface( const SourceRange& loc, Function& m, - std::shared_ptr self, + const std::shared_ptr& self, const std::string& field) { if (self->asValue(loc, m)->type()->cast()) { throw ErrorReport(loc) @@ -279,7 +334,7 @@ void recurseThroughNestedModules( Function& m, std::vector& keys, std::vector& values, - std::shared_ptr self, + std::shared_ptr& self, const std::string& prefix, const std::string& field) { auto prefix_value = @@ -576,7 +631,7 @@ std::shared_ptr ModuleValue::attr( // Check if it's a property. auto prop = - concreteType_->getJitType()->expect()->getProperty(field); + concreteType_->getJitType()->expectRef().getProperty(field); if (prop) { return MethodValue(self_, prop->getter->name()) .call(loc, m, {}, {}, /*n_binders=*/1); @@ -586,11 +641,14 @@ std::shared_ptr ModuleValue::attr( std::string hint; if (auto failureReason = concreteType_->findFailedAttribute(field)) { hint = *failureReason; + } else if (concreteType_->isIgnoredAttribute(field)) { + hint = "attribute was ignored during compilation"; } throw ErrorReport(loc) << "Module '" - << concreteType_->getJitType()->expect()->name()->name() << "'" + << concreteType_->getJitType()->expectRef().name()->name() + << "'" << " has no attribute '" << field << "' " << hint; } @@ -650,8 +708,8 @@ void ModuleValue::setAttr( std::shared_ptr BooleanDispatchValue::call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { c10::optional result; Graph& graph = *(caller.graph()); @@ -660,14 +718,14 @@ std::shared_ptr BooleanDispatchValue::call( auto arg_name = py::str(dispatched_fn_["arg_name"]); ErrorReport error(loc); - if (index < inputs.size()) { + if (index < args.size()) { // Dispatch flag is in arg list - result = constant_as(inputs.at(index).value(graph)); + result = constant_as(args.at(index).value(graph)); error << "Argument for boolean dispatch at position " << index << " was not constant"; - } else if (auto i = findInputWithName(arg_name, attributes)) { + } else if (auto i = findInputWithName(arg_name, kwargs)) { // Dispatch flag is in kwargs - result = constant_as(attributes[*i].value(graph)); + result = constant_as(kwargs[*i].value(graph)); error << "Keyword argument '" << arg_name << "' for boolean dispatch at position was not constant"; } else { @@ -686,28 +744,28 @@ std::shared_ptr BooleanDispatchValue::call( } else { value = toSugaredValue(dispatched_fn_["if_false"], caller, loc); } - return value->call(loc, caller, inputs, attributes, n_binders); + return value->call(loc, caller, args, kwargs, n_binders); } std::shared_ptr PythonExceptionValue::call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t /*n_binders*/) { Value* error_message = nullptr; - if (inputs.size() == 0) { + if (args.size() == 0) { error_message = insertConstant(*caller.graph(), "", loc); - } else if (inputs.size() == 1) { - error_message = inputs.at(0).value(*caller.graph()); + } else if (args.size() == 1) { + error_message = args.at(0).value(*caller.graph()); } else { std::vector message_values; - message_values.reserve(inputs.size() + attributes.size()); + message_values.reserve(args.size() + kwargs.size()); - for (auto inp : inputs) { + for (const auto& inp : args) { message_values.push_back(inp.value(*caller.graph())); } - for (auto kwarg_inp : attributes) { + for (const auto& kwarg_inp : kwargs) { message_values.push_back(kwarg_inp.value(*caller.graph())); } error_message = @@ -800,10 +858,10 @@ std::shared_ptr createSimpleEnumValue( std::shared_ptr PythonSliceClass::call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t /*n_binders*/) { - if (!attributes.empty()) { + if (!kwargs.empty()) { throw ErrorReport(loc) << "Slice does not accept any keyword arguments"; } @@ -822,23 +880,23 @@ std::shared_ptr PythonSliceClass::call( Value* start; Value* stop; Value* step; - size_t n = inputs.size(); + size_t n = args.size(); // Slice's constructor signature is Slice(start=None, stop, step=None) if (n == 1) { // Case where only `stop` is specified. start = ValOr(nullptr, default_start); - stop = ValOr(inputs[0].value(graph), default_stop); + stop = ValOr(args[0].value(graph), default_stop); step = ValOr(nullptr, default_step); } else if (n == 2) { // Case where `start` and `stop` are specified. - start = ValOr(inputs[0].value(graph), default_start); - stop = ValOr(inputs[1].value(graph), default_stop); + start = ValOr(args[0].value(graph), default_start); + stop = ValOr(args[1].value(graph), default_stop); step = ValOr(nullptr, default_step); } else if (n == 3) { // Case where `start`, `stop` and `step` are all specified. - start = ValOr(inputs[0].value(graph), default_start); - stop = ValOr(inputs[1].value(graph), default_stop); - step = ValOr(inputs[2].value(graph), default_step); + start = ValOr(args[0].value(graph), default_start); + stop = ValOr(args[1].value(graph), default_stop); + step = ValOr(args[2].value(graph), default_step); } else { throw ErrorReport(loc) << "slice accepts exactly 1, 2 or 3 arguments, got: " << n; @@ -908,6 +966,12 @@ std::shared_ptr toSugaredValue( if (auto callee = as_function(obj)) { return std::make_shared(callee->function_); } else if (py::isinstance(obj)) { +#ifndef USE_ROCM + std::string obj_name = py::cast(py::getattr(obj, "__name__")); + if (obj_name.compare("torch.cuda") == 0) { + return std::make_shared(obj); + } +#endif return std::make_shared(obj); } else if ( obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr() || @@ -916,7 +980,10 @@ std::shared_ptr toSugaredValue( } else if ( obj.ptr() == py::module::import("torch.jit").attr("annotate").ptr()) { return SpecialFormValue::create(prim::annotate); -#ifdef USE_DISTRIBUTED + } else if ( + obj.ptr() == py::module::import("torch.jit").attr("isinstance").ptr()) { + return SpecialFormValue::create(prim::isinstance); +#ifdef USE_RPC // RPC module is only avaialble when build flag "USE_DISTRIBUTED" is on. } else if ( obj.ptr() == diff --git a/torch/csrc/jit/python/python_sugared_value.h b/torch/csrc/jit/python/python_sugared_value.h index ecb3c6da4ff48..1edbc6c15cada 100644 --- a/torch/csrc/jit/python/python_sugared_value.h +++ b/torch/csrc/jit/python/python_sugared_value.h @@ -7,6 +7,7 @@ #include #include #include +#include #include namespace torch { @@ -47,8 +48,8 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs_, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; std::string kind() const override; @@ -90,6 +91,20 @@ struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue { const std::string& field) override; }; +// Used for desugaring uses of the torch.cuda module. All the CUDA APIs with +// torch.cuda.* are resolved using CUDAPythonModuleValue. +#ifndef __HIP_PLATFORM_HCC__ +struct VISIBILITY_HIDDEN CUDAPythonModuleValue : public PythonValue { + explicit CUDAPythonModuleValue(py::object mod) + : PythonValue(std::move(mod)) {} + + std::shared_ptr attr( + const SourceRange& loc, + Function& m, + const std::string& field) override; +}; +#endif + // Represents all the parameters of a module as a List[Tensor] struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { ConstantParameterList(Value* the_list) : the_list_(the_list) {} @@ -99,8 +114,8 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { return toSimple(the_list_); } @@ -110,8 +125,8 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { }; struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue { - explicit ModuleDictMethod(SugaredValuePtr iterable, const std::string& name) - : iterable_(iterable), name_(name){}; + explicit ModuleDictMethod(SugaredValuePtr iterable, std::string name) + : iterable_(std::move(iterable)), name_(std::move(name)){}; std::string kind() const override { return name_; @@ -120,10 +135,10 @@ struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& f, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { - if (inputs.size() || attributes.size()) { + if (args.size() || kwargs.size()) { throw ErrorReport(loc) << name_ << " method does not accept any arguments"; } @@ -175,11 +190,11 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { return attr(loc, caller, "forward") - ->call(loc, caller, inputs, attributes, n_binders); + ->call(loc, caller, args, kwargs, n_binders); } std::shared_ptr getSugaredDict( @@ -201,7 +216,8 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue { std::shared_ptr getitem( const SourceRange& loc, Function& m, - Value* idx) override; + Value* idx, + TypePtr type_hint) override; private: Value* self_; @@ -216,7 +232,7 @@ void recurseThroughNestedModules( Function& m, std::vector& keys, std::vector& values, - std::shared_ptr self, + std::shared_ptr& self, const std::string& prefix, const std::string& field); @@ -248,7 +264,7 @@ struct VISIBILITY_HIDDEN SugaredDict : public SugaredValue { Function& m, const std::string& field) override; - SugaredValuePtr iter(const SourceRange& loc, Function& m) { + SugaredValuePtr iter(const SourceRange& loc, Function& m) override { return keys_; }; @@ -268,8 +284,8 @@ struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; private: @@ -308,14 +324,14 @@ struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue { std::shared_ptr call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; }; // Python Slice class. struct VISIBILITY_HIDDEN PythonSliceClass : public SugaredValue { - explicit PythonSliceClass() {} + explicit PythonSliceClass() = default; std::string kind() const override { return "Python slice class"; @@ -324,8 +340,8 @@ struct VISIBILITY_HIDDEN PythonSliceClass : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; }; diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index 0aa46a0139ba7..550ba12a46c08 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -23,25 +23,44 @@ namespace tracer { // Python interpreter retrieval routine adapted from // https://stackoverflow.com/a/8706144 -SourceRange getPythonInterpreterSourceRange() { - c10::optional source_filename; - size_t source_line = 0; - std::stringstream stack_trace; - +std::vector _pythonCallstack() { pybind11::gil_scoped_acquire gil; PyFrameObject* frame = PyEval_GetFrame(); + std::vector entries; while (nullptr != frame) { - int line = PyCode_Addr2Line(frame->f_code, frame->f_lasti); + size_t line = PyCode_Addr2Line(frame->f_code, frame->f_lasti); std::string filename = THPUtils_unpackString(frame->f_code->co_filename); std::string funcname = THPUtils_unpackString(frame->f_code->co_name); - stack_trace << filename << "(" << line << "): " << funcname << "\n"; - if (!source_filename) { - source_filename = filename; - source_line = line; - } + auto source = std::make_shared(funcname, filename, line); + entries.emplace_back( + StackEntry{funcname, SourceRange(source, 0, funcname.size())}); frame = frame->f_back; } + return entries; +} + +SourceRange getPythonInterpreterSourceRange() { + auto cs = pythonCallstack(); + c10::optional source_filename; + size_t source_line = 0; + std::stringstream stack_trace; + for (const auto& entry : cs) { + auto& range = entry.range; + if (range.source()) { + auto& src = range.source(); + if (src && src->filename()) { + auto line = + src->starting_line_no() + src->lineno_for_offset(range.start()); + stack_trace << *(src->filename()) << "(" << line + << "): " << entry.filename << "\n"; + if (!source_filename) { + source_filename = *(src->filename()); + source_line = line; + } + } + } + } auto stack_trace_text = stack_trace.str(); auto source = @@ -123,6 +142,7 @@ void pythonWarn(const std::string& reason) { } void initPythonTracerBindings(PyObject* module) { + setPythonCallstack(_pythonCallstack); setRecordSourceLocation(pythonRecordSourceLocation); auto m = py::handle(module).cast(); @@ -156,7 +176,9 @@ void initPythonTracerBindings(PyObject* module) { }) .def( "set_graph", - [](TracingState& s, std::shared_ptr g) { s.graph = g; }) + [](TracingState& s, std::shared_ptr g) { + s.graph = std::move(g); + }) .def("graph", [](TracingState& s) { return s.graph; }); m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); }); @@ -171,7 +193,7 @@ void initPythonTracerBindings(PyObject* module) { py::arg("self") = nullptr); m.def("_get_tracing_state", []() { return getTracingState(); }); m.def("_set_tracing_state", [](std::shared_ptr state) { - return setTracingState(state); + return setTracingState(std::move(state)); }); m.def("_get_value_trace", [](const Variable& var) { return getValueTrace(var); @@ -179,7 +201,7 @@ void initPythonTracerBindings(PyObject* module) { m.def("_set_value_trace", [](const Variable& var, Value* value) { return setValueTrace(var, value); }); - m.def("_tracer_set_get_unique_name_fn", [](py::function func) { + m.def("_tracer_set_get_unique_name_fn", [](const py::function& func) { const auto& tracing_state = getTracingState(); AT_ASSERT(tracing_state); tracing_state->lookup_var_name_fn = diff --git a/torch/csrc/jit/python/python_tracer.h b/torch/csrc/jit/python/python_tracer.h index 9797a1e32e502..5d8e3a9a52eaa 100644 --- a/torch/csrc/jit/python/python_tracer.h +++ b/torch/csrc/jit/python/python_tracer.h @@ -16,7 +16,6 @@ struct Module; namespace tracer { void initPythonTracerBindings(PyObject* module); -std::string getPythonInterpreterStackTrace(); SourceRange getPythonInterpreterSourceRange(); Node* preRecordPythonTrace( diff --git a/torch/csrc/jit/python/python_tree_views.cpp b/torch/csrc/jit/python/python_tree_views.cpp index 49b8f7d4f4afe..cdfb28baa2966 100644 --- a/torch/csrc/jit/python/python_tree_views.cpp +++ b/torch/csrc/jit/python/python_tree_views.cpp @@ -25,7 +25,7 @@ c10::optional maybeConvertToString(const py::object& obj) { struct SourceRangeFactory { SourceRangeFactory( std::string text, - py::object filename, + const py::object& filename, size_t file_lineno, size_t leading_whitespace_chars) : source_(std::make_shared( @@ -200,9 +200,10 @@ void initTreeViewBindings(PyObject* module) { r, wrap_list(r, std::move(params)), wrap_maybe(r, return_type)); })); - py::class_(m, "Delete").def(py::init([](Expr expr) { - return Delete::create(expr); - })); + py::class_(m, "Delete") + .def(py::init([](const SourceRange& range, std::vector targets) { + return Delete::create(range, wrap_list(range, std::move(targets))); + })); py::class_(m, "WithItem") .def(py::init([](const SourceRange& range, const Expr& target, Var* var) { @@ -227,12 +228,13 @@ void initTreeViewBindings(PyObject* module) { wrap_maybe(li.range(), type)); })); py::class_(m, "AugAssign") - .def(py::init([](const Expr& lhs, std::string kind_str, const Expr& rhs) { - const auto& r = lhs.range(); - auto kind = - AugAssignKind(Compound::create(stringToKind(kind_str), r, {})); - return AugAssign::create(r, lhs, kind, rhs); - })); + .def(py::init( + [](const Expr& lhs, const std::string& kind_str, const Expr& rhs) { + const auto& r = lhs.range(); + auto kind = + AugAssignKind(Compound::create(stringToKind(kind_str), r, {})); + return AugAssign::create(r, lhs, kind, rhs); + })); py::class_(m, "Return") .def(py::init([](const SourceRange& range, Expr* value) { return Return::create( @@ -282,7 +284,7 @@ void initTreeViewBindings(PyObject* module) { wrap_list(range, std::move(targets)), wrap_list(range, std::move(body))); })); - py::class_(m, "For").def(py::init([](const SourceRange range, + py::class_(m, "For").def(py::init([](const SourceRange& range, std::vector& targets, std::vector& itrs, std::vector body) { @@ -301,25 +303,26 @@ void initTreeViewBindings(PyObject* module) { [](const Ident& name) { return Var::create(name.range(), name); })) .def_property_readonly("name", [](const Var& var) { return var.name(); }); py::class_(m, "BinOp") - .def(py::init([](std::string kind, const Expr& lhs, const Expr& rhs) { - return BinOp::create(lhs.range(), stringToKind(kind), lhs, rhs); - })); + .def(py::init( + [](const std::string& kind, const Expr& lhs, const Expr& rhs) { + return BinOp::create(lhs.range(), stringToKind(kind), lhs, rhs); + })); // NB: we take range here, because unary ops precede their exprs, so we need // to include them py::class_(m, "UnaryOp") - .def(py::init( - [](const SourceRange& range, std::string kind, const Expr& expr) { - auto resolved_kind = stringToKind(kind); - resolved_kind = - resolved_kind == '-' ? TK_UNARY_MINUS : resolved_kind; - return UnaryOp::create(range, resolved_kind, expr); - })); + .def(py::init([](const SourceRange& range, + const std::string& kind, + const Expr& expr) { + auto resolved_kind = stringToKind(kind); + resolved_kind = resolved_kind == '-' ? TK_UNARY_MINUS : resolved_kind; + return UnaryOp::create(range, resolved_kind, expr); + })); py::class_(m, "Const") - .def(py::init([](const SourceRange& range, std::string value) { + .def(py::init([](const SourceRange& range, const std::string& value) { return Const::create(range, value); })); py::class_(m, "StringLiteral") - .def(py::init([](const SourceRange& range, std::string value) { + .def(py::init([](const SourceRange& range, const std::string& value) { return StringLiteral::create(range, value); })); py::class_(m, "Apply") @@ -349,6 +352,14 @@ void initTreeViewBindings(PyObject* module) { const Expr& iter) { return ListComp::create(range, elt, target, iter); })); + py::class_(m, "DictComp") + .def(py::init([](const SourceRange& range, + const Expr& key, + const Expr& value, + const Expr& target, + const Expr& iter) { + return DictComp::create(range, key, value, target, iter); + })); py::class_(m, "ListLiteral") .def(py::init([](const SourceRange& range, std::vector args) { return ListLiteral::create(range, wrap_list(range, std::move(args))); @@ -383,7 +394,7 @@ void initTreeViewBindings(PyObject* module) { wrap_maybe(range, step)); })); py::class_(m, "Starred") - .def(py::init([](const SourceRange& range, Expr expr) { + .def(py::init([](const SourceRange& range, const Expr& expr) { return Starred::create(range, expr); })); py::class_, TreeView>(m, "EmptyTypeAnnotation") diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 95d041fe315b1..426707e303d3f 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -147,11 +147,11 @@ struct PythonResolver : public Resolver { ClassTypePtr classType_; }; -std::shared_ptr pythonResolver(ResolutionCallback rcb) { +std::shared_ptr pythonResolver(const ResolutionCallback& rcb) { return std::make_shared(rcb); } std::shared_ptr pythonResolver( - ResolutionCallback rcb, + const ResolutionCallback& rcb, std::string classname, ClassTypePtr classType) { return std::make_shared( @@ -339,7 +339,7 @@ static StrongFunctionPtr script_compile_overloaded_function( const c10::QualifiedName& name, const Decl& overload_decl, const Def& implementation_def, - ResolutionCallback rcb, + const ResolutionCallback& rcb, const FunctionDefaults& implementation_defaults, const py::object& signature) { if (signature.is(py::none())) { @@ -356,7 +356,7 @@ static StrongFunctionPtr script_compile_overloaded_function( /*properties=*/{}, /*propResolvers=*/{}, {new_def}, - {pythonResolver(std::move(rcb))}, + {pythonResolver(rcb)}, nullptr, true); TORCH_INTERNAL_ASSERT(defined_functions.size() == 1); @@ -377,14 +377,14 @@ static StrongFunctionPtr script_compile_function( const c10::QualifiedName& name, const Def& def, const FunctionDefaults& defaults, - ResolutionCallback rcb) { + const ResolutionCallback& rcb) { auto cu = get_python_cu(); auto defined_functions = cu->define( QualifiedName(name.prefix()), /*properties=*/{}, /*propResolvers=*/{}, {def}, - {pythonResolver(std::move(rcb))}, + {pythonResolver(rcb)}, nullptr, true); TORCH_INTERNAL_ASSERT(defined_functions.size() == 1); @@ -491,21 +491,6 @@ static std::shared_ptr _propagate_and_assign_input_shapes( return retval; } -static std::shared_ptr _assign_output_shapes( - Graph& graph, - std::vector outputs) { - auto retval = graph.copy(); - AT_ASSERT(retval->outputs().size() == outputs.size()); - for (size_t i = 0; i < outputs.size(); ++i) { - auto scalar_type = outputs[i].scalar_type(); - auto sizes = outputs[i].sizes(); - auto type = - torch::jit::TensorType::createContiguous(scalar_type, at::kCPU, sizes); - retval->outputs()[i]->setType(type); - } - return retval; -} - void addFunctionToModule(Module& module, const StrongFunctionPtr& func) { // Make a graph with a fake self argument auto graph = func.function_->graph()->copy(); @@ -641,7 +626,7 @@ struct slot_dict_impl { template py::list debugMakeList(const T& list) { py::list result; - for (auto elem : list) { + for (const auto& elem : list) { result.append(py::cast(elem)); } return result; @@ -681,7 +666,7 @@ static py::dict _jit_debug_module_iterators(Module& module) { return result; } -static constexpr const char* magic_method_names[] = { +static constexpr std::array magic_method_names = { "__lt__", "__le__", "__eq__", "__ne__", "__ge__", "__gt__", "__not__", "__abs__", "__add__", "__and__", "__floordiv__", "__index__", @@ -729,7 +714,7 @@ void initJitScriptBindings(PyObject* module) { auto m = py::handle(module).cast(); // NOLINTNEXTLINE(bugprone-unused-raii) - py::class_>(m, "Capsule"); + py::class_(m, "Capsule"); auto object_class = py::class_(m, "ScriptObject") @@ -758,18 +743,26 @@ void initJitScriptBindings(PyObject* module) { .def( "getattr", [](Object& self, const std::string& name) { - return toPyObject(self.attr(name)); + try { + return toPyObject(self.attr(name)); + } catch (const ObjectAttributeError& err) { + throw AttributeError("%s", err.what()); + } }) .def( "__getattr__", [](Object& self, const std::string& name) -> py::object { - if (name == "__qualname__") { - return py::cast(self.type()->name()->name()); - } - if (auto method = self.find_method(name)) { - return py::cast(*method); + try { + if (name == "__qualname__") { + return py::cast(self.type()->name()->name()); + } + if (auto method = self.find_method(name)) { + return py::cast(*method); + } + return toPyObject(self.attr(name)); + } catch (const ObjectAttributeError& err) { + throw AttributeError("%s", err.what()); } - return toPyObject(self.attr(name)); }) .def( "hasattr", @@ -789,6 +782,12 @@ void initJitScriptBindings(PyObject* module) { }); }) .def("__copy__", &Object::copy) + .def( + "__hash__", + [](const Object& self) { + // Similar to Tensor's `__hash__`, which is `id()`. + return std::hash{}(self._ivalue().get()); + }) .def(py::pickle( [](const Object& self) -> std::tuple { // __getstate__ @@ -806,7 +805,8 @@ void initJitScriptBindings(PyObject* module) { err << "which does not have a __getstate__ method defined!"; throw std::runtime_error(err.str()); }, - [](std::tuple state_tup) -> Object { + [](const std::tuple& state_tup) + -> Object { py::object state; std::string qualname; std::tie(state, qualname) = state_tup; @@ -970,7 +970,7 @@ void initJitScriptBindings(PyObject* module) { [](Module& m, std::shared_ptr concreteType, const std::string& script, - ResolutionCallback rcb) { + const ResolutionCallback& rcb) { const auto self = ModuleSelf(std::move(concreteType)); m._ivalue()->compilation_unit()->define( *m.type()->name(), script, pythonResolver(rcb), &self); @@ -980,7 +980,7 @@ void initJitScriptBindings(PyObject* module) { "_register_attribute", [](Module& m, const std::string& name, - TypePtr type, + const TypePtr& type, py::handle value) { m.register_attribute(name, type, toIValue(value, type)); }) @@ -988,9 +988,9 @@ void initJitScriptBindings(PyObject* module) { "_create_method_from_trace", [](Module& self, const std::string& name, - py::function func, - py::tuple input_tuple, - py::function var_lookup_fn, + const py::function& func, + const py::tuple& input_tuple, + const py::function& var_lookup_fn, bool strict, bool force_outplace) { // prereq: Module's buffers and parameters are unique @@ -1106,7 +1106,7 @@ void initJitScriptBindings(PyObject* module) { "define", [](CompilationUnit& cu, const std::string& src, - ResolutionCallback rcb) { + const ResolutionCallback& rcb) { cu.define(c10::nullopt, src, pythonResolver(rcb), nullptr); }) .def( @@ -1193,9 +1193,13 @@ void initJitScriptBindings(PyObject* module) { "name", [](const StrongFunctionPtr& self) { return self.function_->name(); }) .def_property_readonly( - "qualified_name", [](const StrongFunctionPtr& self) { + "qualified_name", + [](const StrongFunctionPtr& self) { return self.function_->qualname().qualifiedName(); - }); + }) + .def_property_readonly("__doc__", [](const StrongFunctionPtr& self) { + return self.function_->doc_string(); + }); py::class_(m, "ScriptMethod", py::dynamic_attr()) .def( @@ -1245,19 +1249,19 @@ void initJitScriptBindings(PyObject* module) { "_jit_script_compile", [](const std::string& qualname, const Def& def, - ResolutionCallback rcb, + const ResolutionCallback& rcb, const FunctionDefaults& defaults) { C10_LOG_API_USAGE_ONCE("torch.script.compile"); const auto name = c10::QualifiedName(qualname); TORCH_INTERNAL_ASSERT(name.name() == def.name().name()); - return script_compile_function(name, def, defaults, std::move(rcb)); + return script_compile_function(name, def, defaults, rcb); }); m.def( "_jit_script_compile_overload", [](const std::string& qualname, const Decl& overload_decl, const Def& implementation_def, - ResolutionCallback rcb, + const ResolutionCallback& rcb, const FunctionDefaults& implementation_defaults, const py::object& signature) { const auto name = c10::QualifiedName(qualname); @@ -1265,7 +1269,7 @@ void initJitScriptBindings(PyObject* module) { name, overload_decl, implementation_def, - std::move(rcb), + rcb, implementation_defaults, signature); }); @@ -1279,10 +1283,10 @@ void initJitScriptBindings(PyObject* module) { }); m.def( "_create_function_from_trace", - [](std::string qualname, - py::function func, - py::tuple input_tuple, - py::function var_lookup_fn, + [](const std::string& qualname, + const py::function& func, + const py::tuple& input_tuple, + const py::function& var_lookup_fn, bool strict, bool force_outplace) { auto typed_inputs = toTraceableStack(input_tuple); @@ -1303,7 +1307,7 @@ void initJitScriptBindings(PyObject* module) { [](const std::string& qualifiedName, const ClassDef& classDef, const ClassMethodDefaults& defaults, - ResolutionCallback rcb) { + const ResolutionCallback& rcb) { C10_LOG_API_USAGE_ONCE("torch.script.class"); if (classDef.superclass().present()) { throw ErrorReport(classDef.range()) @@ -1370,12 +1374,12 @@ void initJitScriptBindings(PyObject* module) { "_jit_script_interface_compile", [](const std::string& qualifiedName, const ClassDef& classDef, - ResolutionCallback rcb, + const ResolutionCallback& rcb, bool is_module) { get_python_cu()->define_interface( c10::QualifiedName(qualifiedName), classDef, - pythonResolver(std::move(rcb)), + pythonResolver(rcb), is_module); }); @@ -1462,10 +1466,11 @@ void initJitScriptBindings(PyObject* module) { m.def( "_debug_set_autodiff_subgraph_inlining", debugSetAutodiffSubgraphInlining); + m.def("_debug_set_fusion_group_inlining", debugSetFusionGroupInlining); + m.def("_debug_get_fusion_group_inlining", getFusionGroupInlining); m.def("_propagate_shapes", _propagate_shapes); m.def( "_propagate_and_assign_input_shapes", _propagate_and_assign_input_shapes); - m.def("_assign_output_shapes", _assign_output_shapes); m.def( "_last_executed_optimized_graph", []() { return lastExecutedOptimizedGraph(); }, @@ -1476,7 +1481,7 @@ void initJitScriptBindings(PyObject* module) { // TODO this should go in the global Python CU auto cu = std::make_shared(); c10::QualifiedName name(qualname); - auto fn = cu->create_function(std::move(name), graph); + auto fn = cu->create_function(std::move(name), std::move(graph)); return StrongFunctionPtr(std::move(cu), fn); }); m.def("_ivalue_tags_match", ivalue_tags_match); @@ -1579,6 +1584,17 @@ void initJitScriptBindings(PyObject* module) { .def( "add_failed_attribute", &ConcreteModuleTypeBuilder::addFailedAttribute) + .def( + "add_ignored_attribute", + &ConcreteModuleTypeBuilder::addIgnoredAttribute) + .def( + "add_ignored_attributes", + [](ConcreteModuleTypeBuilder& self, + const std::vector& names) { + for (auto& name : names) { + self.addIgnoredAttribute(name); + } + }) .def( "set_module_dict", [](ConcreteModuleTypeBuilder& self) { @@ -1604,6 +1620,7 @@ void initJitScriptBindings(PyObject* module) { .def("get_attributes", &ConcreteModuleType::getAttributesPy) .def("get_modules", &ConcreteModuleType::getModulesPy) .def("dump", &ConcreteModuleType::dump) + .def("is_ignored_attribute", &ConcreteModuleType::isIgnoredAttribute) .def( "equals", [](const ConcreteModuleType& self, const ConcreteModuleType& other) { @@ -1668,18 +1685,23 @@ void initJitScriptBindings(PyObject* module) { m.def( "_resolve_type", - [](const std::string& name, SourceRange range, ResolutionCallback rcb) { + [](const std::string& name, + const SourceRange& range, + const ResolutionCallback& rcb) { return pythonResolver(rcb)->resolveType(name, range); }); m.def( "_resolve_type_from_object", - [](const py::object& obj, SourceRange range, ResolutionCallback rcb) { + [](const py::object& obj, + const SourceRange& range, + const ResolutionCallback& rcb) { return pythonResolver(rcb)->resolveTypeFromObject(obj, range); }); m.def( "_run_emit_module_hook", [](const Module& m) { didFinishEmitModule(m); }); + // NOLINTNEXTLINE(bugprone-unused-raii) py::class_>( m, "LoggerBase"); py::enum_(m, "AggregationType") diff --git a/torch/csrc/jit/runtime/argument_spec.h b/torch/csrc/jit/runtime/argument_spec.h index 401933c6d67e7..a0e60e879146a 100644 --- a/torch/csrc/jit/runtime/argument_spec.h +++ b/torch/csrc/jit/runtime/argument_spec.h @@ -237,7 +237,7 @@ struct CompleteArgumentSpec { for (int32_t i = 0; i < num_inputs; i++) { if (!inputs[i].isTensor()) continue; - auto tensor = inputs[i].toTensor(); + auto& tensor = inputs[i].toTensor(); all_dims += tensor.defined() ? tensor.ndimension() : 0; } // allocate enough room for all TensorPODs and dimensions diff --git a/torch/csrc/jit/runtime/autodiff.cpp b/torch/csrc/jit/runtime/autodiff.cpp index 6b5a33c3b478e..c3eebdfda1293 100644 --- a/torch/csrc/jit/runtime/autodiff.cpp +++ b/torch/csrc/jit/runtime/autodiff.cpp @@ -45,7 +45,7 @@ bool needTrimGrad(Node* n) { return false; } -bool isDifferentiable(Node* n) { +bool isDifferentiable(const Node* n) { // TODO: scalar-tensor ops should be canonicalized static OperatorSet differentiable_ops = { "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)", @@ -89,7 +89,7 @@ bool isDifferentiable(Node* n) { return std::all_of( body->nodes().begin(), body->nodes().end(), - static_cast(isDifferentiable)); + static_cast(isDifferentiable)); } // formulas are only defined with floating point scalars, @@ -107,7 +107,7 @@ bool isDifferentiable(Graph& g) { return std::all_of( g.nodes().begin(), g.nodes().end(), - static_cast(isDifferentiable)); + static_cast(isDifferentiable)); } // NB: Write gradient using torchscript @@ -253,12 +253,13 @@ class GradientHelper { graph->insertNode(graph->createTupleUnpack(backward_value)); auto tuple_outputs = tuple_unpack_node->outputs(); AT_ASSERT(tuple_outputs.size() == size_t(3)); - return {tuple_outputs[0], - tuple_outputs[1], - nullptr, - tuple_outputs[2], - nullptr, - nullptr}; + return { + tuple_outputs[0], + tuple_outputs[1], + nullptr, + tuple_outputs[2], + nullptr, + nullptr}; } else if ( node->matches( @@ -282,14 +283,15 @@ class GradientHelper { graph->insertNode(graph->createTupleUnpack(backward_value)); auto tuple_outputs = tuple_unpack_node->outputs(); AT_ASSERT(tuple_outputs.size() == size_t(3)); - return {tuple_outputs[0], - tuple_outputs[1], - tuple_outputs[2], - nullptr, - nullptr, - nullptr, - nullptr, - nullptr}; + return { + tuple_outputs[0], + tuple_outputs[1], + tuple_outputs[2], + nullptr, + nullptr, + nullptr, + nullptr, + nullptr}; } throw std::runtime_error( diff --git a/torch/csrc/jit/runtime/autodiff.h b/torch/csrc/jit/runtime/autodiff.h index 2b7c9a7f2f669..769b79da5ed00 100644 --- a/torch/csrc/jit/runtime/autodiff.h +++ b/torch/csrc/jit/runtime/autodiff.h @@ -90,7 +90,7 @@ struct Gradient { TORCH_API Gradient differentiate(std::shared_ptr& graph); // can we take a derivative of this node symbolically? -TORCH_API bool isDifferentiable(Node* n); +TORCH_API bool isDifferentiable(const Node* n); TORCH_API bool isDifferentiable(Graph& g); TORCH_API bool isZero(Value* v); diff --git a/torch/csrc/jit/runtime/graph_executor.cpp b/torch/csrc/jit/runtime/graph_executor.cpp index fbb515cea1041..cf9bb1bc6931f 100644 --- a/torch/csrc/jit/runtime/graph_executor.cpp +++ b/torch/csrc/jit/runtime/graph_executor.cpp @@ -85,6 +85,17 @@ bool getAutodiffSubgraphInlining() { return autodiff_subgraph_inlining; } +// for debugging it is helpful to be able to force fusion groups +// to be created +static std::atomic fusion_group_inlining(true); +void debugSetFusionGroupInlining(bool state) { + fusion_group_inlining = state; +} + +bool getFusionGroupInlining() { + return fusion_group_inlining; +} + thread_local std::weak_ptr last_executed_optimized_graph; std::shared_ptr lastExecutedOptimizedGraph() { return last_executed_optimized_graph.lock(); @@ -120,7 +131,7 @@ struct CaptureList { auto tensors = val.toTensorList(); sizes_.push_back(tensors.size()); - for (const at::Tensor& tensor : tensors) { + for (const at::Tensor tensor : tensors) { captureTensor(tensor, is_output); } } else { @@ -258,12 +269,14 @@ struct DifferentiableGraphBackward : public autograd::Node { size_t output_index = 0; for (IValue& v : stack) { if (v.isTensorList()) { - for (const at::Tensor& tensor : v.toTensorList()) { + for (at::Tensor tensor : v.toTensorList()) { produceOutput(output_index++, std::move(tensor), outputs); } } else if (v.isTensor()) { produceOutput(output_index++, std::move(v).toTensor(), outputs); } else { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(v.isNone()); + output_index++; // Input grad can also be None even if it requires grad // Example: `other` in expand_as(self, other) outputs.emplace_back(); @@ -284,11 +297,14 @@ struct DifferentiableGraphBackward : public autograd::Node { } void addOutputForIValue(const IValue& value) { if (value.isTensorList()) { - for (const at::Tensor& tensor : value.toTensorList()) { + for (const at::Tensor tensor : value.toTensorList()) { addOutputForTensor(tensor); } - } else { + } else if (value.isTensor()) { addOutputForTensor(value.toTensor()); + } else { + // We could have None passed here via `Optional[Tensor]` + add_next_edge(autograd::Edge{}); } } @@ -308,7 +324,7 @@ struct DifferentiableGraphBackward : public autograd::Node { if (v.isTensorList()) { auto tensors = v.toTensorList(); input_instructions_.pushTensorList(tensors.size()); - for (const at::Tensor& tensor : tensors) { + for (const at::Tensor tensor : tensors) { addInputVariable(tensor); } } else if (v.isTensor()) { @@ -499,13 +515,15 @@ void GraphExecutorImplBase::run(Stack& stack) { logging::getLogger()->addStatValue( logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0); - ExecutionPlan plan = + const ExecutionPlan& plan = getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()); InterpreterState(plan.code).run(stack); last_executed_optimized_graph = plan.graph; } -c10::intrusive_ptr GraphExecutorImplBase::runAsync(Stack& stack) { +c10::intrusive_ptr GraphExecutorImplBase::runAsync( + Stack& stack, + TaskLauncher taskLauncher) { TORCH_CHECK( stack.size() >= num_inputs, "expected ", @@ -518,13 +536,14 @@ c10::intrusive_ptr GraphExecutorImplBase::runAsync(Stack& stack) { logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0); struct Frame { - explicit Frame(ExecutionPlan eplan) - : plan(std::move(eplan)), state(plan.code) {} + explicit Frame(ExecutionPlan eplan, TaskLauncher taskLauncher) + : plan(std::move(eplan)), state(plan.code, std::move(taskLauncher)) {} ExecutionPlan plan; InterpreterState state; }; auto frame = std::make_shared( - getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())); + getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()), + std::move(taskLauncher)); auto res = frame->state.runAsync(stack); last_executed_optimized_graph = frame->plan.graph; if (!res->completed()) { @@ -549,7 +568,7 @@ struct GraphExecutorImpl : public GraphExecutorImplBase { logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0); } - ExecutionPlan getPlanFor(Stack& stack, size_t remaining_bailout_depth) + const ExecutionPlan& getPlanFor(Stack& stack, size_t remaining_bailout_depth) override { return getGraphExecutorOptimize() ? getOrCompile(stack) : getOrCompileFallback(); @@ -705,7 +724,7 @@ struct GraphExecutorImpl : public GraphExecutorImplBase { }; GraphExecutor::GraphExecutor( - std::shared_ptr graph, + const std::shared_ptr& graph, std::string function_name) : pImpl( IsNewExecutorEnabled() @@ -720,15 +739,17 @@ void GraphExecutor::run(Stack& inputs) { return pImpl->run(inputs); } -c10::intrusive_ptr GraphExecutor::runAsync(Stack& stack) { - return pImpl->runAsync(stack); +c10::intrusive_ptr GraphExecutor::runAsync( + Stack& stack, + TaskLauncher taskLauncher) { + return pImpl->runAsync(stack, std::move(taskLauncher)); } size_t GraphExecutor::getDefaultNumBailOuts() { return getProfilingMode() ? getBailoutDepth().load() : 0; } -ExecutionPlan GraphExecutor::getPlanFor( +const ExecutionPlan& GraphExecutor::getPlanFor( Stack& inputs, size_t remaining_bailout_depth) { return pImpl->getPlanFor(inputs, remaining_bailout_depth); @@ -849,7 +870,10 @@ void runNondiffOptimization( "After customPostPassses (end of runNondiffOptimization)\n", *graph); } -void runOptimization(std::shared_ptr& graph, bool unroll) { +void runOptimization( + std::shared_ptr& graph, + bool unroll, + bool const_prop_user_classes) { // Basic graph preprocessing to eliminate noise. GRAPH_DEBUG( "Before EliminateDeadCode (beginning of runOptimization)\n", *graph); @@ -862,8 +886,14 @@ void runOptimization(std::shared_ptr& graph, bool unroll) { PeepholeOptimize(graph); GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph); - ConstantPropagation(graph); + + if (const_prop_user_classes) { + ConstantPropagation(graph); + } else { + ConstantPropagation(graph, true); + } GRAPH_DEBUG("After ConstantPropagation, before ConstantPooling\n", *graph); + ConstantPooling(graph); GRAPH_DEBUG("After ConstantPooling\n", *graph); diff --git a/torch/csrc/jit/runtime/graph_executor.h b/torch/csrc/jit/runtime/graph_executor.h index 6fee30834f1ed..1b938c1876484 100644 --- a/torch/csrc/jit/runtime/graph_executor.h +++ b/torch/csrc/jit/runtime/graph_executor.h @@ -55,10 +55,12 @@ struct TORCH_API EnableProfilingGuard { struct GraphExecutorImplBase; struct TORCH_API GraphExecutor { GraphExecutor() = default; - GraphExecutor(std::shared_ptr graph, std::string function_name); + GraphExecutor(const std::shared_ptr& graph, std::string function_name); void run(Stack& inputs); - c10::intrusive_ptr runAsync(Stack& stack); + c10::intrusive_ptr runAsync( + Stack& stack, + TaskLauncher taskLauncher = at::launch); // `remaining_bailout_depth` stands for the maximum number of profiled and // specialized recompilations allowed for the current `GraphExecutor`. if @@ -69,7 +71,9 @@ struct TORCH_API GraphExecutor { // profiled information whenever a bailout check is failed/triggered, a new // `GraphExecutor` will be created. This new `GraphExecutor`'s // remaining_bailout_depth will be reduced by 1. - ExecutionPlan getPlanFor(Stack& inputs, size_t remaining_bailout_depth); + const ExecutionPlan& getPlanFor( + Stack& inputs, + size_t remaining_bailout_depth); explicit operator bool() const { return pImpl != nullptr; } @@ -93,6 +97,9 @@ TORCH_API Node* replaceBlockWithFallbackGraph( // regardless of whether sizes have been specialized or not. TORCH_API void runRequiredPasses(const std::shared_ptr& g); +TORCH_API void debugSetFusionGroupInlining(bool state); +TORCH_API bool getFusionGroupInlining(); + TORCH_API void debugSetAutodiffSubgraphInlining(bool state); TORCH_API std::shared_ptr lastExecutedOptimizedGraph(); diff --git a/torch/csrc/jit/runtime/graph_executor_impl.h b/torch/csrc/jit/runtime/graph_executor_impl.h index eed7e1f57f1f2..b762e7a950893 100644 --- a/torch/csrc/jit/runtime/graph_executor_impl.h +++ b/torch/csrc/jit/runtime/graph_executor_impl.h @@ -31,13 +31,19 @@ namespace jit { void packGradient(const Gradient& gradient, Node* dnode); bool needsGradient(const std::shared_ptr& graph); -void runOptimization(std::shared_ptr& graph, bool unroll = true); +void runOptimization( + std::shared_ptr& graph, + bool unroll = true, + bool const_prop_user_classes = true); void runNondiffOptimization( std::shared_ptr& graph, bool strict_fuser_check = false); void debugSetAutodiffSubgraphInlining(bool state); bool getAutodiffSubgraphInlining(); +void debugSetFusionGroupInlining(bool state); +bool getFusionGroupInlining(); + // Tunable parameters for deciding when to create/keep subgraphs of // differentiable code const size_t autodiffSubgraphNodeThreshold = 2; @@ -66,9 +72,11 @@ struct GraphExecutorImplBase { // entry point where execution begins void run(Stack& stack); - c10::intrusive_ptr runAsync(Stack& stack); + c10::intrusive_ptr runAsync( + Stack& stack, + TaskLauncher taskLauncher = at::launch); - virtual ExecutionPlan getPlanFor( + virtual const ExecutionPlan& getPlanFor( Stack& stack, size_t remaining_bailout_depth) = 0; virtual GraphExecutorState getDebugState() = 0; diff --git a/torch/csrc/jit/runtime/instruction.h b/torch/csrc/jit/runtime/instruction.h index 8cfbb17e76855..dae7a0bcad3f1 100644 --- a/torch/csrc/jit/runtime/instruction.h +++ b/torch/csrc/jit/runtime/instruction.h @@ -52,7 +52,7 @@ namespace jit { _(ISINSTANCE, "TI") /* check object is one of types[X:X+N] */ \ _(TUPLE_SLICE, "II") /* slice tup[X:(X+N)] */ \ _(FORK, "CN") /* launch a thread to run code entry x with N inputs */ \ - _(WARN, "") /* emit a warning with line information */ \ + _(WARN, "I") /* emit a warning with line information */ \ _(ENTER, "EN") /* enter scope of a contextmanager */ \ _(EXIT, "EX") /* exit the last entered contextmanager */ diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index 337fe66c07897..5e6c9a96aca6a 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -23,7 +24,7 @@ #include #include -#ifdef USE_DISTRIBUTED +#ifdef USE_RPC #include using torch::distributed::autograd::DistAutogradContainer; #endif @@ -267,7 +268,7 @@ void insertLastUses(Graph& g) { } inline int64_t getDistAutogradContextId() { -#ifdef USE_DISTRIBUTED +#ifdef USE_RPC return DistAutogradContainer::currentContextId(); #else return 0; @@ -319,6 +320,7 @@ struct CanEmitInline { // by the later BailOut in createBailoutBlock and its jf_index // will become invalid. v->node()->kind() != prim::TensorExprGroup && + v->node()->kind() != prim::StaticSubgraph && v->node()->kind() != prim::CudaFusionGroup && v->node()->kind() != prim::FusionGroup && v->node()->kind() != prim::BailOut && v->uses().size() == 1 && @@ -412,6 +414,33 @@ struct BailoutBlock { std::vector instructions; // ends in a TAIL_CALL }; +thread_local InterpreterStateImpl* tls_int_state_ptr_ = nullptr; +struct TLSCurrentInterpreterGuard { + TLSCurrentInterpreterGuard(InterpreterStateImpl* state) { + prev_state_ = tls_int_state_ptr_; + tls_int_state_ptr_ = state; + } + + ~TLSCurrentInterpreterGuard() { + tls_int_state_ptr_ = prev_state_; + } + + private: + InterpreterStateImpl* prev_state_; +}; + +template +Ttarget safe_narrow_cast(Tsource v) { + Ttarget res = static_cast(v); + // Casting it back to check whether it overflew. + if (static_cast(res) != v) { + TORCH_WARN( + "ATTENTION: your model computation is overflowing, safe_narrow_cast<>() failed"); + return v; + } + return res; +} + struct CodeImpl { friend struct InterpreterState; std::vector instructions_; @@ -519,7 +548,10 @@ struct CodeImpl { } void insertInstruction(OpCode op, int64_t X = 0, uint64_t N = 0) { - instructions_.emplace_back(op, X, N); + instructions_.emplace_back( + op, + safe_narrow_cast(X), + safe_narrow_cast(N)); instructions_source_.emplace_back(current_node_); // check that we didn't accidentally emit nodes out of topological order @@ -675,7 +707,7 @@ struct CodeImpl { void emitCall(Function* func, at::ArrayRef inputs) { emitLoadInputs(inputs); insertInstruction(CALL, function_table_.size()); - function_table_.emplace_back(std::move(func)); + function_table_.emplace_back(func); } void emitNodeAtBlockLevel(Node* node) { @@ -712,8 +744,9 @@ struct CodeImpl { // Emit the expected type. size_t types_start = type_table_.size(); + auto types = node->tys(attr::types); for (size_t i = 0; i < num_inputs; i++) { - emitType(node->outputs()[i]->type()); + emitType(types[i]); } insertInstruction(TYPECHECK, types_start, num_inputs); } @@ -760,6 +793,9 @@ struct CodeImpl { } else if (node->cast()) { profile_function_table_.push_back( node->cast()->getCallback()); + } else if (node->cast()) { + profile_function_table_.push_back( + node->cast()->getCallback()); } else { TORCH_INTERNAL_ASSERT(false); } @@ -811,7 +847,7 @@ struct CodeImpl { void emitTupleConstruct(Node* node) { bool named = - node->output()->type()->expect()->name().has_value(); + node->output()->type()->expectRef().name().has_value(); if (named) { emitContainerConstruct(NAMED_TUPLE_CONSTRUCT, node); } else { @@ -856,8 +892,16 @@ struct CodeImpl { } void emitWarn(Node* node) { + if (FLAGS_torch_jit_disable_warning_prints) { + return; + } + emitLoadInputs(node->inputs()); - insertInstruction(WARN); + int32_t idx = -1; + if (node->hasAttribute(attr::warn_id)) { + idx = static_cast(node->i(attr::warn_id)); + } + insertInstruction(WARN, idx); } void emitEnter(Node* node) { @@ -894,7 +938,7 @@ struct CodeImpl { break; case prim::CallFunction: emitCall( - node->inputs().at(0)->type()->expect()->function(), + node->inputs().at(0)->type()->expectRef().function(), node->inputs().slice(1)); break; case prim::CallMethod: @@ -910,6 +954,7 @@ struct CodeImpl { case prim::BailOut: emitBailOut(node); break; + case prim::profile_ivalue: case prim::profile_optional: case prim::profile: emitProfile(node); @@ -996,16 +1041,34 @@ struct CodeImpl { // InterpreterState state that and used to compute a Code struct InterpreterStateImpl : c10::intrusive_ptr_target { - InterpreterStateImpl(const Code& code) { + InterpreterStateImpl(const Code& code, TaskLauncher taskLauncher) + : taskLauncher_(std::move(taskLauncher)) { enterFrame(code, 0); } private: + struct WarnedNodes { + public: + // Inserts idx into warned_nodes_, returns a boolean indicates whether + // insertion actually happened (idx wasn't originally in the set). + bool insert(int32_t idx) { + std::unique_lock lock(mutex_); + return warned_nodes_.insert(idx).second; + } + + private: + std::mutex mutex_; + std::unordered_set warned_nodes_; + }; + + WarnedNodes warned_nodes_; + // if we need to suspend, where do we reset the stack? // answer: to where it was when we were called, not // including any inputs to this function int64_t stack_start_ = -1; c10::intrusive_ptr future_; + TaskLauncher taskLauncher_; // this holds all the tensors for this interpreter run // we don't bother minimizing the size of this vector, since the extra @@ -1044,30 +1107,11 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { // RecordFunction object associated with this frame std::unique_ptr record_function; + // symbol table for a frame ShapeSymbolTable symbols2dims; }; - // saved-by-value stuff that can exist on the stack inside runInterpreter - struct ActiveFrame { - size_t pc; - Instruction* instructions; - IValue* constants; - Operation* operators; - Function** functions; - std::function&)>* profile_functions; - TypePtr* types; - - ActiveFrame(const Frame& frame) - : pc(frame.pc), - instructions(frame.function->instructions_.data()), - constants(frame.function->constant_table_.data()), - operators(frame.function->operator_table_.data()), - functions(frame.function->function_table_.data()), - profile_functions(frame.function->profile_function_table_.data()), - types(frame.function->type_table_.data()) {} - }; - std::vector frames; c10::intrusive_ptr intrusive_from_this() { @@ -1078,7 +1122,6 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { void enterFrame(const Code& code, size_t base_pointer) { frames.emplace_back(Frame{code.pImpl, 0, base_pointer, c10::nullopt}); registers.resize(registers.size() + code.pImpl->register_size_); - // frames.back().function->dump(std::cout); } void leaveFrame() { @@ -1101,16 +1144,16 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } } - void runBuiltinFunction(Stack& stack, Function* fn, ActiveFrame* af) { + void runBuiltinFunction(Stack& stack, Function* fn) { // BuiltinOpFunction directly invokes a void(Stack&) to implement // custom C++ classes. Call run() here with the stack, and we will // get the results from that C++ method back in the stack. Advance // the PC by 1 without adding any new frame. fn->run(stack); - ++af->pc; + ++frames.back().pc; } - void runGraphFunction(Stack& stack, Function* fn, ActiveFrame* af) { + void runGraphFunction(Stack& stack, Function* fn) { const Code& code = // consider passing // `frames.back().function->remaining_bailout_depth_` into @@ -1122,21 +1165,9 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { fn->get_executor() .getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()) .code; - frames.back().pc = af->pc + 1; + ++frames.back().pc; enterFrame(code, stack.size() - code.num_inputs()); - if (at::hasCallbacks() && at::isRecordFunctionEnabled()) { - auto rec_fn = std::make_unique( - at::RecordScope::TORCHSCRIPT_FUNCTION); - if (rec_fn->active) { - if (rec_fn->needs_inputs) { - rec_fn->before(fn->name(), last(stack, code.num_inputs())); - } else { - rec_fn->before(fn->name()); - } - frames.back().record_function = std::move(rec_fn); - } - } - *af = ActiveFrame(frames.back()); + checkAndStartRecordFunction(frames.back(), stack); } bool runImpl(Stack& stack) { @@ -1152,18 +1183,22 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { stack_start_ = 0; } - ActiveFrame af(frames.back()); + TLSCurrentInterpreterGuard g(this); + if (frames.back().pc == 0 && stack_start_ == 0) { + checkAndStartRecordFunction(frames.back(), stack); + } try { while (true) { + Frame& frame = frames.back(); // std::cout << "RUNNING "; - // frames.back().function->dump(std::cout, af.pc); - Instruction inst = af.instructions[af.pc]; + // frames.back().function->dump(std::cout, frame.pc); + Instruction inst = frame.function->instructions_[frame.pc]; switch (inst.op) { case ENTER: { auto obj = peek(stack, 0, 1); TORCH_INTERNAL_ASSERT(obj.isObject()); entered_objects.push_back(obj); - ++af.pc; + ++frame.pc; } break; case EXIT: { auto obj = entered_objects.back().toObject(); @@ -1173,90 +1208,90 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { push(stack, IValue()); push(stack, IValue()); push(stack, IValue()); - runGraphFunction(stack, &f, &af); + runGraphFunction(stack, &f); } break; case OP: - af.operators[inst.X](&stack); - ++af.pc; + frame.function->operator_table_[inst.X](&stack); + ++frame.pc; break; case OPN: stack.push_back(inst.N); - af.operators[inst.X](&stack); - ++af.pc; + frame.function->operator_table_[inst.X](&stack); + ++frame.pc; break; case LOAD: stack.emplace_back(reg(inst.X)); - ++af.pc; + ++frame.pc; break; case MOVE: stack.emplace_back(std::move(reg(inst.X))); - ++af.pc; + ++frame.pc; break; case STORE: reg(inst.X) = pop(stack); - ++af.pc; + ++frame.pc; break; case STOREN: for (size_t i = inst.N; i > 0; --i) { reg(inst.X + i - 1) = pop(stack); } - ++af.pc; + ++frame.pc; break; case DROP: pop(stack); - ++af.pc; + ++frame.pc; break; case DROPR: reg(inst.X) = IValue(); - ++af.pc; + ++frame.pc; break; case LOADC: - stack.emplace_back(af.constants[inst.X]); - ++af.pc; + stack.emplace_back(frame.function->constant_table_[inst.X]); + ++frame.pc; break; case GET_ATTR: { auto userObj = pop(stack).toObject(); auto value = userObj->getSlot(inst.X); push(stack, std::move(value)); - ++af.pc; + ++frame.pc; } break; case SET_ATTR: { auto v = pop(stack); auto userObj = pop(stack).toObject(); userObj->setSlot(inst.X, std::move(v)); - ++af.pc; + ++frame.pc; } break; case JF: - af.pc += (pop(stack).toBool()) ? 1 : inst.X; + frame.pc += (pop(stack).toBool()) ? 1 : inst.X; break; case JMP: - af.pc += inst.X; + frame.pc += inst.X; break; case LOOP: { // stack: iteration_count, max_iter, cond, loop_carried_deps... - auto frame = stack.end() - (inst.N + 1); - int64_t trip_count = frame[0].toInt(); - int64_t max_trip_count = frame[1].toInt(); - bool cond = frame[2].toBool(); + auto fr = stack.end() - (inst.N + 1); + int64_t trip_count = fr[0].toInt(); + int64_t max_trip_count = fr[1].toInt(); + bool cond = fr[2].toBool(); if (trip_count < max_trip_count && cond) { - frame[2] = trip_count; - frame[0] = trip_count + 1; - ++af.pc; + fr[2] = trip_count; + fr[0] = trip_count + 1; + ++frame.pc; } else { size_t n_loop_carried = inst.N - 2; for (size_t i = 0; i < n_loop_carried; ++i) { - frame[i] = std::move(frame[i + 3]); + fr[i] = std::move(fr[i + 3]); } drop(stack, 3); // iteration_count, max_iter, cond - af.pc += inst.X; + frame.pc += inst.X; } } break; case CALL: { - Function* fn = af.functions[inst.X]; + Function* fn = frame.function->function_table_[inst.X]; if (!fn->isGraphFunction()) { - runBuiltinFunction(stack, fn, &af); + runBuiltinFunction(stack, fn); } else { - runGraphFunction(stack, fn, &af); + runGraphFunction(stack, fn); } } break; case INTERFACE_CALL: { @@ -1276,17 +1311,17 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { peek(stack, 0, inst.N) .toObject() ->type() - ->getMethod(af.constants[inst.X].toStringRef()); + ->getMethod( + frame.function->constant_table_[inst.X].toStringRef()); if (!function.isGraphFunction()) { - runBuiltinFunction(stack, &function, &af); + runBuiltinFunction(stack, &function); } else { - runGraphFunction(stack, &function, &af); + runGraphFunction(stack, &function); } } break; case RET: if (frames.size() > 1) { leaveFrame(); - af = ActiveFrame(frames.back()); break; } if (future_) { @@ -1298,6 +1333,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { jit::last(stack, num_outputs).vec())); } } + // destroy the last frame and call RecordFunction's end callbacks + leaveFrame(); return false; case WAIT: { auto future = stack.back().toFuture(); @@ -1310,11 +1347,14 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { Callback( c10::intrusive_ptr state, Stack stack) - : state_(std::move(state)), stack_(std::move(stack)) { + : stateImpl_(std::move(state)), + state_(stateImpl_), + stack_(std::move(stack)) { dist_autograd_context_id_ = getDistAutogradContextId(); + state_ = InterpreterState(stateImpl_); } void operator()() { - at::launch(InterpreterContinuation( + stateImpl_->taskLauncher_(InterpreterContinuation( state_, std::move(stack_), dist_autograd_context_id_, @@ -1322,6 +1362,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } private: + c10::intrusive_ptr stateImpl_; InterpreterState state_; Stack stack_; int64_t dist_autograd_context_id_; @@ -1343,7 +1384,6 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { stack.resize(stack_start_); } // save pc into the frame so we continue here when restored - frames.back().pc = af.pc; future->addCallback( Callback(intrusive_from_this(), std::move(copied))); @@ -1351,26 +1391,26 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } stack.pop_back(); stack.emplace_back(future->value()); - ++af.pc; + ++frame.pc; } break; case PROFILE_OP: { - auto& frame_id_ref = frames.back().id; + auto& frame_id_ref = frame.id; if (!frame_id_ref.has_value()) { frame_id_ref = Frame::num_frames++; } - auto callback = af.profile_functions[inst.X]; + auto callback = frame.function->profile_function_table_[inst.X]; push(stack, c10::IValue{static_cast(*frame_id_ref)}); callback(stack); - ++af.pc; + ++frame.pc; break; } case FAIL_GUARD: { // patch FAIL_GUARD back to GUARD GRAPH_DEBUG( "Bailout ", inst.X, " triggered via bailout_requests_!"); - af.instructions[af.pc].op = GUARD; + frame.function->instructions_[frame.pc].op = GUARD; push(stack, false); - ++af.pc; + ++frame.pc; break; } case TYPECHECK: { @@ -1379,13 +1419,10 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { // Check every input's shape against profiled (expected) shape. for (i = 0; i < num_inputs; i++) { auto& input = peek(stack, i, num_inputs); - auto t = input.toTensor(); - const TypePtr& expected = af.types[inst.X + i]; + auto& t = input.toTensor(); + const TypePtr& expected = frame.function->type_table_[inst.X + i]; auto expected_type = expected->cast(); - if (t.defined() && - (!frames.back().symbols2dims.bindSymbolicShapes( - t.sizes(), expected_type->symbolic_sizes()) || - !expected_type->matchTensor(t))) { + if (t.defined() && !expected_type->matchTensor(t)) { push(stack, false); break; } @@ -1393,7 +1430,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { if (i == num_inputs) { push(stack, true); } - ++af.pc; + ++frame.pc; break; } case GUARD: { @@ -1403,8 +1440,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { // so it's safe to pass this guard check push(stack, true); } else { - auto t = stack.back().toTensor(); - const TypePtr& expected = af.types[inst.X]; + auto& t = stack.back().toTensor(); + const TypePtr& expected = frame.function->type_table_[inst.X]; auto expected_type = expected->cast(); if (t.defined() && !frames.back().symbols2dims.bindSymbolicShapes( @@ -1414,21 +1451,21 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { push(stack, expected_type->matchTensor(t)); } } - ++af.pc; + ++frame.pc; } break; case TAIL_CALL: { GRAPH_DEBUG("running TAIL_CALL for ", inst.X); - af.functions[inst.X]->ensure_defined(); + frame.function->function_table_[inst.X]->ensure_defined(); size_t remaining_bailout_depth = - frames.back().function->remaining_bailout_depth_ > 0 - ? frames.back().function->remaining_bailout_depth_ - 1 + frame.function->remaining_bailout_depth_ > 0 + ? frame.function->remaining_bailout_depth_ - 1 : 0; - const Code& code = af.functions[inst.X] + const Code& code = frame.function->function_table_[inst.X] ->get_executor() .getPlanFor(stack, remaining_bailout_depth) .code; size_t num_inputs = code.num_inputs(); - size_t base_pointer = frames.back().base_pointer; + size_t base_pointer = frame.base_pointer; TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs); size_t inputs_start = stack.size() - num_inputs; for (size_t i = 0; i < num_inputs; ++i) { @@ -1438,85 +1475,103 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { stack.resize(base_pointer + num_inputs); leaveFrame(); enterFrame(code, base_pointer); - af = ActiveFrame(frames.back()); + checkAndStartRecordFunction(frames.back(), stack); } break; case LIST_UNPACK: { listUnpack(stack, inst.X); - ++af.pc; + ++frame.pc; } break; case TUPLE_CONSTRUCT: { tupleConstruct(stack, inst.X); - ++af.pc; + ++frame.pc; } break; case TUPLE_SLICE: { tupleSlice(stack, inst.X, inst.X + inst.N); - ++af.pc; + ++frame.pc; } break; case NAMED_TUPLE_CONSTRUCT: { - auto type = af.types[inst.X]->expect(); + auto type = + frame.function->type_table_[inst.X]->expect(); namedTupleConstruct(stack, type, inst.N); - ++af.pc; + ++frame.pc; } break; case LIST_CONSTRUCT: { - auto type = af.types[inst.X]->expect(); + const auto& type = + frame.function->type_table_[inst.X]->expectRef(); listConstruct(stack, type, inst.N); - ++af.pc; + ++frame.pc; } break; case DICT_CONSTRUCT: { - auto type = af.types[inst.X]->expect(); + auto type = frame.function->type_table_[inst.X]->expect(); dictConstruct(stack, type, inst.N); - ++af.pc; + ++frame.pc; } break; case CREATE_OBJECT: { - auto type = af.types[inst.X]->expect(); + auto type = + frame.function->type_table_[inst.X]->expect(); createObject(stack, type); - ++af.pc; + ++frame.pc; } break; case ISINSTANCE: { at::ArrayRef types( - af.types + inst.X, af.types + inst.X + inst.N); + &(frame.function->type_table_[inst.X]), + &(frame.function->type_table_[inst.X + inst.N])); isinstance(stack, types); - ++af.pc; + ++frame.pc; } break; case FORK: { // Move inputs to a separate stack - Function* forked_fn = af.functions[inst.X]; + Function* forked_fn = frame.function->function_table_[inst.X]; InterpreterState forked_interpreter( forked_fn->get_executor() .getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()) - .code); + .code, + taskLauncher_); InterpreterContinuation continuation( forked_interpreter, Stack(stack.end() - inst.N, stack.end()), getDistAutogradContextId()); drop(stack, inst.N); push(stack, forked_interpreter.getFuture()); - at::launch(std::move(continuation)); - ++af.pc; + taskLauncher_(std::move(continuation)); + ++frame.pc; } break; case WARN: { - Node* node = frames.back().function->instructions_source_.at(af.pc); + // Keeps track of which WARN instruction has been executed before, + // we only want to execute each WARN once to match default Python + // warning behavior. + bool need_warn = true; + if (inst.X != -1) { + need_warn = warned_nodes_.insert(inst.X); + } + + Node* node = + frames.back().function->instructions_source_.at(frame.pc); auto range = node->sourceRange().source(); if (range->filename()) { - auto line = range->starting_line_no() + - range->lineno_for_offset(node->sourceRange().start()); drop(stack, 1); - c10::SourceLocation location{ - "", range->filename()->c_str(), uint32_t(line)}; - // Sends the warning to the warning handler with the - // "verbatim" flag. This flag ensures the warning handler - // will print the exception as configured. - c10::Warning::warn( - location, pop(stack).toStringRef(), /*verbatim=*/true); + const auto msg = pop(stack).toStringRef(); + if (need_warn) { + auto line = range->starting_line_no() + + range->lineno_for_offset(node->sourceRange().start()); + c10::SourceLocation location{ + "", range->filename()->c_str(), uint32_t(line)}; + // Sends the warning to the warning handler with the + // "verbatim" flag. This flag ensures the warning handler + // will print the exception as configured. + c10::Warning::warn(location, msg, /*verbatim=*/true); + } } else { - TORCH_WARN(pop(stack).toStringRef()); + const auto msg = pop(stack).toStringRef(); + if (need_warn) { + TORCH_WARN(msg); + } } - ++af.pc; + ++frame.pc; } break; } } } catch (std::exception& e) { - frames.back().pc = af.pc; for (auto it = entered_objects.rbegin(), end = entered_objects.rend(); it != end; ++it) { @@ -1542,6 +1597,44 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } void formatStackTrace(std::ostream& out) { + format_stack_trace(out, callstack()); + } + + void handleError(const ExceptionMessage& msg, bool is_jit_exception) { + std::ostringstream ss; + ss << "The following operation failed in the TorchScript interpreter.\n"; + formatStackTrace(ss); + ss << "RuntimeError: " << msg << "\n"; + if (future_) { + future_->setError(std::make_exception_ptr(Future::FutureError(ss.str()))); + } else if (is_jit_exception) { + throw JITException(ss.str()); + } else { + throw std::runtime_error(ss.str()); + } + } + + static void checkAndStartRecordFunction(Frame& frame, Stack& stack) { + bool pre_sampled = false; + if (!frame.record_function && at::hasCallbacks() && + at::shouldRunRecordFunction(&pre_sampled)) { + auto rec_fn = std::make_unique( + at::RecordScope::TORCHSCRIPT_FUNCTION, pre_sampled); + if (rec_fn->isActive()) { + if (rec_fn->needsInputs()) { + rec_fn->before( + frame.function->function_name_, + last(stack, frame.function->n_inputs)); + } else { + rec_fn->before(frame.function->function_name_); + } + frame.record_function = std::move(rec_fn); + } + } + } + + public: + std::vector callstack() const { std::vector entries; for (size_t i = 0; i < frames.size(); ++i) { const Frame& frame = frames[i]; @@ -1556,30 +1649,15 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { Node* node = frame.function->instructions_source_[pc]; if (node->callstack()) { for (const auto& p : (*node->callstack())->vec()) { - entries.emplace_back(StackEntry{previous_fn_name, p.second}); - previous_fn_name = p.first->name(); + entries.emplace_back(StackEntry{previous_fn_name, std::get<1>(p)}); + previous_fn_name = std::get<0>(p)->name(); } } entries.emplace_back(StackEntry{previous_fn_name, node->sourceRange()}); } - format_stack_trace(out, entries); + return entries; } - void handleError(const ExceptionMessage& msg, bool is_jit_exception) { - std::ostringstream ss; - ss << "The following operation failed in the TorchScript interpreter.\n"; - formatStackTrace(ss); - ss << "RuntimeError: " << msg << "\n"; - if (future_) { - future_->setError(std::make_exception_ptr(Future::FutureError(ss.str()))); - } else if (is_jit_exception) { - throw JITException(ss.str()); - } else { - throw std::runtime_error(ss.str()); - } - } - - public: c10::intrusive_ptr getOrCreateFuture() { if (!future_) { future_ = @@ -1611,6 +1689,15 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } }; +std::vector currentCallstack() { + if (tls_int_state_ptr_) { + auto cs = tls_int_state_ptr_->callstack(); + std::reverse(cs.begin(), cs.end()); + return cs; + } + return std::vector(); +} + std::atomic InterpreterStateImpl::Frame::num_frames; std::ostream& operator<<(std::ostream& out, const Code& code) { @@ -1669,8 +1756,10 @@ size_t Code::register_size() const { return pImpl->register_size_; } -InterpreterState::InterpreterState(const Code& code) - : pImpl(c10::make_intrusive(code)) {} +InterpreterState::InterpreterState(const Code& code, TaskLauncher taskLauncher) + : pImpl(c10::make_intrusive( + code, + std::move(taskLauncher))) {} InterpreterState::~InterpreterState() = default; void InterpreterState::run(Stack& stack) { @@ -1690,7 +1779,7 @@ InterpreterState::InterpreterState( : pImpl(std::move(pImpl_)) {} void InterpreterContinuation::operator()() { -#ifdef USE_DISTRIBUTED +#ifdef USE_RPC auto prev_dist_id = DistAutogradContainer::currentContextId(); DistAutogradContainer::forceCurrentContextId(dist_autograd_context_id_); #endif @@ -1700,9 +1789,10 @@ void InterpreterContinuation::operator()() { } else { state.runAsync(stack); } -#ifdef USE_DISTRIBUTED +#ifdef USE_RPC DistAutogradContainer::forceCurrentContextId(prev_dist_id); #endif } + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h index 3325e1213e919..a4bb209cd17ee 100644 --- a/torch/csrc/jit/runtime/interpreter.h +++ b/torch/csrc/jit/runtime/interpreter.h @@ -5,11 +5,16 @@ #include #include +#include #include +#include + +C10_DECLARE_bool(torch_jit_disable_warning_prints); namespace at { class Tensor; -} +TORCH_API void launch(std::function func); +} // namespace at namespace c10 { struct IValue; struct OperatorName; @@ -31,6 +36,7 @@ struct Node; struct Instruction; using Stack = std::vector; using c10::ivalue::Future; +using TaskLauncher = std::function)>; struct TORCH_API Code { Code() : pImpl(nullptr) {} @@ -65,9 +71,11 @@ struct TORCH_API Code { }; struct InterpreterState { - TORCH_API InterpreterState(const Code& code); + TORCH_API InterpreterState( + const Code& code, + TaskLauncher taskLauncher = at::launch); TORCH_API void run(Stack& stack); - c10::intrusive_ptr runAsync(Stack& stack); + TORCH_API c10::intrusive_ptr runAsync(Stack& stack); c10::intrusive_ptr getFuture(); TORCH_API ~InterpreterState(); @@ -97,7 +105,7 @@ struct Suspend : public std::exception { // thread local settings are propagated with ThreadLocalState struct InterpreterContinuation { InterpreterContinuation( - InterpreterState state_, + const InterpreterState& state_, Stack stack_, int64_t dist_autograd_context_id = 0, c10::optional tls_state = c10::nullopt) @@ -126,5 +134,8 @@ struct InterpreterContinuation { TORCH_API at::TensorTypePtr tensorTypeInCurrentExecutionContext( const at::Tensor& t); +// current (TLS) TorchScript interpreter callstack +TORCH_API std::vector currentCallstack(); + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/logging.h b/torch/csrc/jit/runtime/logging.h index f5f4559e65f46..a59e809719260 100644 --- a/torch/csrc/jit/runtime/logging.h +++ b/torch/csrc/jit/runtime/logging.h @@ -16,7 +16,7 @@ class LoggerBase { TORCH_API virtual void addStatValue( const std::string& stat_name, int64_t val) = 0; - virtual ~LoggerBase() {} + virtual ~LoggerBase() = default; }; TORCH_API LoggerBase* getLogger(); @@ -28,7 +28,7 @@ TORCH_API LoggerBase* setLogger(LoggerBase* logger); class NoopLogger : public LoggerBase { public: void addStatValue(const std::string& stat_name, int64_t val) override {} - ~NoopLogger() {} + ~NoopLogger() = default; }; // Trivial locking logger. Pass in an instance of this to setLogger() to use it. @@ -42,7 +42,7 @@ class TORCH_API LockingLogger : public LoggerBase { virtual int64_t getCounterValue(const std::string& name) const; enum class AggregationType { SUM = 0, AVG = 1 }; void setAggregationType(const std::string& stat_name, AggregationType type); - ~LockingLogger() {} + ~LockingLogger() = default; private: mutable std::mutex m; @@ -74,10 +74,11 @@ constexpr const char* EXECUTION_PLAN_CACHE_MISS = "pytorch_runtime.execution_plan_cache_miss"; inline std::vector allRuntimeCounters() { - return {GRAPH_EXECUTORS_CONSTRUCTED, - GRAPH_EXECUTOR_INVOCATIONS, - EXECUTION_PLAN_CACHE_HIT, - EXECUTION_PLAN_CACHE_MISS}; + return { + GRAPH_EXECUTORS_CONSTRUCTED, + GRAPH_EXECUTOR_INVOCATIONS, + EXECUTION_PLAN_CACHE_HIT, + EXECUTION_PLAN_CACHE_MISS}; } } // namespace runtime_counters diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp index 2bd6a2b47ec9c..46f8a618a7832 100644 --- a/torch/csrc/jit/runtime/operator.cpp +++ b/torch/csrc/jit/runtime/operator.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -230,20 +231,25 @@ bool printerHasSpecialCaseFor(Symbol sym) { prim::ConstantChunk, // optimization pass adds it prim::DifferentiableGraph, // optimization pass adds it, prim::FunctionalGraph, // optimization pass adds it, + prim::ReductionSizes, // optimization pass (fuser) adds it prim::BroadcastSizes, // optimization pass (fuser) adds it prim::ChunkSizes, // optimization pass (fuser) adds it prim::Drop, // used in interpreter only prim::FusedConcat, // optimization pass adds it prim::FusionGroup, // optimization pass adds it prim::CudaFusionGroup, // optimization pass adds it + prim::CudaFusionGuard, // optimization pass adds it prim::TensorExprGroup, // optimization pass adds it + prim::StaticSubgraph, // optimization pass adds it prim::Load, // used in interpreter only prim::MMTreeReduce, // used as an optimization prim::MMBatchSide, // used as an optimization prim::Store, // used in interpreter only prim::profile, // used in interpreter only prim::profile_optional, // used in interpreter only + prim::profile_ivalue, // used in interpreter only prim::TypeCheck, // used in interpreter only + prim::RequiresGradCheck, // used in interpreter only prim::FallbackGraph, // converted into prim::CallFunction }; @@ -273,6 +279,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { prim::CudaFusionGroup, prim::DifferentiableGraph, prim::TensorExprGroup, + prim::StaticSubgraph, prim::FunctionalGraph, prim::Constant, prim::Uninitialized, @@ -286,7 +293,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { prim::MMBatchSide, prim::BroadcastSizes, prim::ChunkSizes, - prim::Function, + prim::Closure, prim::TupleUnpack, prim::TupleIndex, prim::TupleSlice, @@ -301,7 +308,9 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { prim::SetAttr, prim::profile, prim::profile_optional, + prim::profile_ivalue, prim::TypeCheck, + prim::RequiresGradCheck, prim::Print, prim::CallFunction, prim::CallMethod, diff --git a/torch/csrc/jit/runtime/operator.h b/torch/csrc/jit/runtime/operator.h index 05305c71f27c5..ec3aa3d4fb99c 100644 --- a/torch/csrc/jit/runtime/operator.h +++ b/torch/csrc/jit/runtime/operator.h @@ -36,7 +36,7 @@ using OperationCreator = Operation (*)(const Node*); /* * Note: JIT relies on Operator instances having static lifetime, because * it for example stores a non-owning FunctionSchema* pointer in the Node class, - * which points to the function shema stored in the Operator instance. + * which points to the function schema stored in the Operator instance. * Also, jit::Operator is meant to store more operator related information like * symbolic derivatives, which also requires them to have static lifetime * so that changes to symbolic derivatives are remembered. @@ -73,7 +73,7 @@ struct TORCH_API Operator { public: Operator(c10::OperatorHandle opHandle, Operation operation) : op_(c10::make_left( - C10Operator{std::move(opHandle), std::move(operation)})) {} + C10Operator{opHandle, std::move(operation)})) {} Operator( std::string schema, @@ -102,8 +102,7 @@ struct TORCH_API Operator { : op_(c10::make_right(JitOnlyOperator{ c10::make_right( UnparsedFunctionSchema{std::move(schema), alias_analysis}), - c10::make_right( - std::move(op_creator))})) {} + c10::make_right(op_creator)})) {} // Helper constructor to register `op` to run // run for _every_ IR Node where n.kind() == name, regardless of arguments. @@ -116,8 +115,7 @@ struct TORCH_API Operator { : op_(c10::make_right(JitOnlyOperator{ c10::make_left( varArgSchemaWithName(name, alias_analysis)), - c10::make_right( - std::move(op_creator))})) {} + c10::make_right(op_creator)})) {} Operation getOperation(const Node* node = nullptr) const { return op_.fold( diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp index 5d63d78d47656..0d18f5230478c 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -9,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -28,41 +30,68 @@ #include #include #include - -C10_DECLARE_bool(); +#include +#include C10_DEFINE_bool( torch_jit_enable_new_executor, true, "If this flag is set to false TorchScript will be using the legacy/original executor"); +C10_DEFINE_bool( + torch_jit_disable_warning_prints, + false, + "Disables warning.warn prints in TorchScript graph"); + +constexpr size_t kDefaultNumProfiledRuns = 1; +constexpr size_t kDefaultBailoutDepth = 20; + +C10_DEFINE_int64( + torch_jit_num_profiled_runs, + kDefaultNumProfiledRuns, + "Number of profiling runs"); +C10_DEFINE_int64( + torch_jit_bailout_depth, + kDefaultBailoutDepth, + "Number of re-specializations"); + namespace torch { namespace jit { -// TODO: keep the else clause for trial runs -#if defined(FBCODE_CAFFE2) || defined(C10_MOBILE) +#if defined(C10_MOBILE) static std::atomic executor_mode{true}; static std::atomic profiling_mode{false}; #else static std::atomic executor_mode{true}; -static std::atomic profiling_mode{false}; +static std::atomic profiling_mode{true}; #endif -static std::atomic num_profiled_runs{1}; -static std::atomic bailout_depth{1}; +static std::atomic num_profiled_runs{kDefaultNumProfiledRuns}; +static std::atomic bailout_depth{kDefaultBailoutDepth}; std::atomic& getProfilingMode() { return profiling_mode; } + std::atomic& getExecutorMode() { return executor_mode; } std::atomic& getNumProfiledRuns() { + // Initialize num_profiled_runs from command-line flag. + static const size_t init = []() { + return num_profiled_runs = FLAGS_torch_jit_num_profiled_runs; + }(); + (void)init; // Silence clang-tidy. return num_profiled_runs; } std::atomic& getBailoutDepth() { + // Initialize bailout_depth from command-line flag. + static const size_t init = []() { + return bailout_depth = FLAGS_torch_jit_bailout_depth; + }(); + (void)init; // Silence clang-tidy. return bailout_depth; } @@ -90,6 +119,62 @@ static bool needsGradientInProfilingMode(Block* b) { return false; } +bool guardDifferentiableGraph(Node* dnode) { + auto gi = dnode->g(attr::Subgraph)->inputs(); + bool all_inputs_seen = true; + for (size_t i = 0; i < gi.size(); i++) { + auto ty = gi[i]->type()->cast(); + if (ty) { + auto n = gi[i]->uses().at(0).user; + auto dni = dnode->inputs().at(i); + GRAPH_DEBUG("found first user of ", i, " as ", *n); + if (n->kind() == prim::profile) { + GRAPH_DEBUG( + "setting input ", i, " to type ", *n->ty(attr::profiled_type)); + dni->setType(n->ty(attr::profiled_type)); + } else if (dni->node()->kind() == prim::DifferentiableGraph) { + // The profiling node might have been absorbed in a preceding + // differentiable graph and thus not (not ideal for fusing either), + // see TestAutodiffSubgraphSlicing.test_does_not_create_cycles. + // Alternatives to this special casing could be specializing the types + // before autodiff or duplicating profile nodes for autodiff outputs + // but that should be done while creating subgraphs and would be + // a mess. + // XXX TODO: revisit the alternatives + Value* o = dni->node()->g(attr::Subgraph)->outputs().at(dni->offset()); + if (o->node()->kind() == prim::profile) { + dni->setType(o->node()->ty(attr::profiled_type)); + } + } + + // we check if the optional is defined + all_inputs_seen &= (dni->type()->cast() != TensorType::get()); + } + } + if (all_inputs_seen) { + // we may have seen both true and false for requires_grad. In this case + // we guard with true here and the other case is in the fallback. This + // will give us trouble when we get "alternating patterns" of gradients + // of two inputs, but so it is. An alternative could be to look into + // the individual requires_grad seen in the profiling record. + insertTypeGuard( + dnode, + [](const TensorTypePtr& t) { + return TensorType::get()->withRequiresGrad( + t->requiresGrad().value_or(true)); + }, + prim::RequiresGradCheck); + return true; + } else { + // we inline the differentiable graph as a fallback + // ideally we would set this up for re-profiling + UpdateDifferentiableGraphRequiresGrad( + dnode->g(attr::Subgraph), c10::nullopt); + SubgraphUtils::unmergeSubgraph(dnode); + return false; + } +} + void runNooptPassPipeline(std::shared_ptr& graph) { GRAPH_DEBUG( "Before LowerGradOf (beginning of runNooptPassPipeline)\n", *graph); @@ -109,8 +194,8 @@ void runPreAutodiffPassPipeline(std::shared_ptr& graph) { "Before InsertGuards (beginning of runPreAutodiffPassPipeline)\n", *graph); - if (tensorExprFuserEnabled()) { - // With TE fuser we don't generate bailouts + if (tensorExprFuserEnabled() || RegisterCudaFuseGraph::isRegistered()) { + // With TE fuser or nvfuser, we don't generate bailouts LowerGradOf(*graph); GRAPH_DEBUG("After LowerGradOf, before specializeAutogradZero\n", *graph); } else { @@ -253,7 +338,7 @@ void runDiffGraphPasses(std::shared_ptr& graph) { BatchMM(graph); GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph); - FuseTensorExprs(graph); + FuseTensorExprs(graph, getFusionGroupInlining() ? 2 : 1); GRAPH_DEBUG( "After Fusion, before RemoveTensorTypeSpecializations\n", *graph); @@ -312,7 +397,7 @@ void runNoGradOptimizations(std::shared_ptr& graph) { BatchMM(graph); GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph); - FuseTensorExprs(graph); + FuseTensorExprs(graph, getFusionGroupInlining() ? 2 : 1); GRAPH_DEBUG( "After Fusion, before RemoveTensorTypeSpecializations\n", *graph); @@ -356,11 +441,23 @@ void ProfilingGraphExecutorImpl::runProfilingOptimizations( GRAPH_DEBUG("After CreateAutodiffSubgraphs\n", *copy); size_t idx = 0; for (Node* dnode : diff_nodes) { - GRAPH_DEBUG("Optimizing diff node ", idx); + GRAPH_DEBUG("Optimizing diff node ", idx, " in ", *copy); + if (!guardDifferentiableGraph(dnode)) { + // if we cannot guard (because of inputs without profiling information), + // we re-inline the subgraph and remove the differentiable node + GRAPH_DEBUG("Could not guardDifferentiableGraph ", idx, " in ", *copy); + idx++; + continue; + } + GRAPH_DEBUG("After guardDifferentiableGraph:\n", *copy); auto diff_graph = std::move(dnode->g(attr::Subgraph)); Gradient gradient = differentiate(diff_graph); GRAPH_DEBUG("Forward graph:\n", *(gradient.f)); GRAPH_DEBUG("Backward graph:\n", *(gradient.df)); + // just like inside autograd.Functions, the forward of a differentiable + // graph is essentially in a torch.no_grad context. + UpdateDifferentiableGraphRequiresGrad(gradient.f, false); + GRAPH_DEBUG("After UpdateDifferentiableGraphRequiresGrad ", *gradient.f); runDiffGraphPasses(gradient.f); // replaces fallback graphs inserted by TE Fuser replaceFallbackGraphWithFallbackFunction(gradient.f->block()); @@ -370,6 +467,7 @@ void ProfilingGraphExecutorImpl::runProfilingOptimizations( InlineAutodiffSubgraphs( copy, getAutodiffSubgraphInlining() ? autodiffSubgraphInlineThreshold : 1); + replaceFallbackGraphWithFallbackFunction(copy->block()); RemoveProfilingNodes(copy); GRAPH_DEBUG( "After InlineAutodiffSubgraphs and Removing Profiling Nodes\n", *copy); @@ -446,12 +544,26 @@ ProfilingGraphExecutorImpl::ProfilingGraphExecutorImpl( std::string function_name) : GraphExecutorImplBase(graph, std::move(function_name)) {} -ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor( +const ExecutionPlan& ProfilingGraphExecutorImpl::getOptimizedPlanFor( Stack& stack, size_t remaining_bailout_depth) { - std::lock_guard lock(compile_mutex); GRAPH_DEBUG("Running ProfilingGraphExecutorImpl ", this); + // no opt mode + if (!getGraphExecutorOptimize()) { + if (!fallback_plan_) { + auto copy = graph->copy(); + GRAPH_DEBUG( + "Before LowerGradOf (beginning of runNooptPassPipeline)\n", *graph); + LowerGradOf(*copy); + GRAPH_DEBUG("After LowerGradOf, before RemoveExpands\n", *graph); + RemoveExpands(copy); + fallback_plan_ = ExecutionPlan(copy, function_name_); + GRAPH_DUMP("NoOpt Graph: ", copy); + } + return *fallback_plan_; + } + // if tensorExprFuserEnabled() returns true we need to persist the very first // time ProfilingGraphExecutorImpl is called, so we can update it correctly // for fallback functions in ProfilingGraphExecutorImpl Else, @@ -461,11 +573,6 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor( remaining_bailout_depth_ = remaining_bailout_depth; } - if (optimized_plan_) { - GRAPH_DEBUG("plan already optimized:", (*optimized_plan_).graph); - return *optimized_plan_; - } - // simple executor if (*remaining_bailout_depth_ == 0) { auto copy = graph->copy(); @@ -502,6 +609,20 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor( return *optimized_plan_; } +const ExecutionPlan& ProfilingGraphExecutorImpl::getPlanFor( + Stack& stack, + size_t remaining_bailout_depth) { + std::lock_guard lock(compile_mutex); + + // IMPORTANT: This is a hot path of calling a torchscript function. Try not to + // add any code above this. + if (optimized_plan_) { + return *optimized_plan_; + } + + return getOptimizedPlanFor(stack, remaining_bailout_depth); +} + GraphExecutorState ProfilingGraphExecutorImpl::getDebugState() { GraphExecutorState state; TORCH_INTERNAL_ASSERT(optimized_plan_); diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.h b/torch/csrc/jit/runtime/profiling_graph_executor_impl.h index 39f58069015bf..8b4553c31ffb8 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.h +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.h @@ -9,12 +9,15 @@ struct ProfilingGraphExecutorImpl : public GraphExecutorImplBase { const std::shared_ptr& graph, std::string function_name); - ExecutionPlan getPlanFor(Stack& stack, size_t remaining_bailout_depth) + const ExecutionPlan& getPlanFor(Stack& stack, size_t remaining_bailout_depth) override; GraphExecutorState getDebugState() override; ~ProfilingGraphExecutorImpl() override = default; private: + const ExecutionPlan& getOptimizedPlanFor( + Stack& stack, + size_t remaining_bailout_depth); void runProfilingInsensitiveOptimizations(std::shared_ptr& graph); void runProfilingOptimizations(std::shared_ptr& graph); void replaceFallbackGraphWithFallbackFunction(Block* b); @@ -22,6 +25,8 @@ struct ProfilingGraphExecutorImpl : public GraphExecutorImplBase { c10::optional profiling_plan_; // plan to run in order to profiling the code c10::optional optimized_plan_; + // this plan is used if getGraphExecutorOptimize is unset + c10::optional fallback_plan_; // fallback functions are inserted for tensorexpr fusion groups // and by specialize_autogradzero. Whenever, at runtime, input // tensor don't match profiled properties, fallback functions are called diff --git a/torch/csrc/jit/runtime/profiling_record.cpp b/torch/csrc/jit/runtime/profiling_record.cpp index 98c0736681700..273b2427046c5 100644 --- a/torch/csrc/jit/runtime/profiling_record.cpp +++ b/torch/csrc/jit/runtime/profiling_record.cpp @@ -1,15 +1,57 @@ #include + #include #include #include #include #include +#include #include #include namespace torch { namespace jit { +namespace { + +class ProfileRegistry { + public: + static ProfileRegistry* getRegistry() { + static ProfileRegistry profile_registry_; + return &profile_registry_; + } + + void registerProfileNode(const std::function& func) { + std::lock_guard guard(mutex_); + registry_funcs_.push_back(func); + } + + bool shouldProfileNode(const Node* node) { + std::lock_guard guard(mutex_); + // to guard differentiable graphs, we want profiling information + // (in particular requires_grad) for nodes handled by autodiff + if (isDifferentiable(node)) { + return true; + } + for (const auto& func : registry_funcs_) { + if (func(node)) { + return true; + } + } + return false; + } + + private: + std::vector> registry_funcs_; + std::mutex mutex_; +}; + +} // namespace + +void RegisterProfilingNode(const std::function& func) { + ProfileRegistry::getRegistry()->registerProfileNode(func); +} + bool ShapeSymbolTable::bindSymbolicShapes( at::IntArrayRef new_sizes, const c10::SymbolicShape& sym_shapes) { @@ -130,7 +172,7 @@ void ProfilingRecord::insertShapeProfile(Node* n, size_t offset) { if (v.isTensor()) { std::lock_guard lock(this->mutex_); auto& profiled_types = profiled_types_per_frame_[frame_id]; - auto t = v.toTensor(); + auto& t = v.toTensor(); if (t.defined()) { auto pttp = tensorTypeInCurrentExecutionContext(t); GRAPH_DEBUG( @@ -145,7 +187,7 @@ void ProfilingRecord::insertShapeProfile(Node* n, size_t offset) { } else { auto type = profiled_types.at(pno); GRAPH_DEBUG("Existing type for %", pno->debugName(), " ", *type); - pttp = type->merge(pttp); + pttp = type->merge(*pttp); GRAPH_DEBUG("Result for %", pno->debugName(), " ", *pttp); profiled_types[pno] = pttp; } @@ -189,7 +231,7 @@ bool needsProfiledInputs(Node* n) { case aten::mm: return true; default: - return false; + return ProfileRegistry::getRegistry()->shouldProfileNode(n); } } @@ -203,7 +245,7 @@ bool needsProfiledOutput(Node* n) { case prim::AutogradZero: return true; default: - return false; + return ProfileRegistry::getRegistry()->shouldProfileNode(n); } } @@ -324,7 +366,7 @@ std::unique_ptr ProfilingRecord::instrumentGraph( " records for run ", frame_id); - if (raw_pr->profiled_types_per_frame_.size() == 0) { + if (raw_pr->profiled_types_per_frame_.empty()) { return; } @@ -335,18 +377,17 @@ std::unique_ptr ProfilingRecord::instrumentGraph( // and use it for building the symbol sets auto profiled_types_iter = raw_pr->profiled_types_per_frame_.begin(); auto merged_profiled_types = profiled_types_iter->second; - profiled_types_iter++; + ++profiled_types_iter; // merge profiling information from next runs into the first one for (; profiled_types_iter != raw_pr->profiled_types_per_frame_.end(); - profiled_types_iter++) { + ++profiled_types_iter) { SetPartitioningHelper partition_helper; for (const auto& val_type_pair : profiled_types_iter->second) { - if (merged_profiled_types.count(val_type_pair.first) == 0) { - merged_profiled_types[val_type_pair.first] = val_type_pair.second; - } else { - auto type = merged_profiled_types[val_type_pair.first]; - auto merged_type = type->merge(val_type_pair.second); + auto insertion_result = merged_profiled_types.insert(val_type_pair); + if (!insertion_result.second) { // Already existed + const TensorType* type = insertion_result.first->second.get(); + auto merged_type = type->merge(*val_type_pair.second); if (merged_type->sizes().size().has_value()) { auto new_shape = raw_pr->mergeSymbolicShapes( val_type_pair.second->symbolic_sizes(), @@ -359,13 +400,12 @@ std::unique_ptr ProfilingRecord::instrumentGraph( profiled_types_iter->first, " into ", *type); - merged_type = type->withSymbolicShapes(new_shape); + merged_type = type->withSymbolicShapes(std::move(new_shape)); GRAPH_DEBUG("Result : ", *merged_type); - merged_profiled_types[val_type_pair.first] = merged_type; + insertion_result.first->second = std::move(merged_type); } else { // reset symbolic shapes when ranks are different - type = type->merge(val_type_pair.second); - merged_profiled_types[val_type_pair.first] = type; + insertion_result.first->second = std::move(merged_type); } } } diff --git a/torch/csrc/jit/runtime/profiling_record.h b/torch/csrc/jit/runtime/profiling_record.h index bb135d14a9531..851d0d5be4f24 100644 --- a/torch/csrc/jit/runtime/profiling_record.h +++ b/torch/csrc/jit/runtime/profiling_record.h @@ -82,6 +82,8 @@ namespace jit { using ::c10::TensorTypePtr; using Dimension = int64_t; +TORCH_API void RegisterProfilingNode(const std::function&); + struct ProfilingRecord; // `SetPartitioningHelper` is used to maintain the following invariant: diff --git a/torch/csrc/jit/runtime/register_c10_ops.cpp b/torch/csrc/jit/runtime/register_c10_ops.cpp index 4e1a4fb1f2113..e31c13ac6dc8a 100644 --- a/torch/csrc/jit/runtime/register_c10_ops.cpp +++ b/torch/csrc/jit/runtime/register_c10_ops.cpp @@ -46,7 +46,7 @@ Operator createOperatorFromC10_withTracingHandledHere( node->addInput(none); continue; } else { - type = type->expect()->getElementType(); + type = type->expectRef().getElementType(); } } if (type->isSubtypeOf(TensorType::get())) { @@ -67,7 +67,7 @@ Operator createOperatorFromC10_withTracingHandledHere( } else if (type->kind() == TypeKind::NumberType) { tracer::addInputs(node, args[i].name().c_str(), iter->toScalar()); } else if (type->kind() == TypeKind::ListType) { - const auto& elem_type = type->expect()->getElementType(); + const auto& elem_type = type->expectRef().getElementType(); if (elem_type->isSubtypeOf(TensorType::get())) { AT_ASSERT(iter->isTensorList()); auto list = iter->toTensorVector(); @@ -134,7 +134,7 @@ Operator createOperatorFromC10_withTracingHandledHere( AT_ASSERT(iter->isTensor()); tracer::addOutput(node, iter->toTensor()); } else if (type->kind() == TypeKind::ListType) { - const auto& elem_type = type->expect()->getElementType(); + const auto& elem_type = type->expectRef().getElementType(); if (elem_type->isSubtypeOf(TensorType::get())) { AT_ASSERT(iter->isTensorList()); tracer::addOutput(node, iter->toTensorList()); diff --git a/torch/csrc/jit/runtime/register_cuda_ops.cpp b/torch/csrc/jit/runtime/register_cuda_ops.cpp new file mode 100644 index 0000000000000..5cf31d626dd03 --- /dev/null +++ b/torch/csrc/jit/runtime/register_cuda_ops.cpp @@ -0,0 +1,87 @@ +// This file registers special JIT operators used to implement the PyTorch CUDA +// API in TorchScript. +#ifndef __HIP_PLATFORM_HCC__ +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +namespace { + +c10::AliasAnalysisKind aliasAnalysisFromSchema() { + return c10::AliasAnalysisKind::FROM_SCHEMA; +} + +RegisterOperators const reg({ + Operator( + "cuda::current_stream(int64_t val) -> __torch__.torch.classes.cuda.Stream", + [](Stack* stack) { + auto idx = uint16_t(pop(stack).toInt()); + auto s = c10::cuda::getCurrentCUDAStream(idx); + auto st = make_custom_class(s); + push(stack, IValue(st)); + }, + aliasAnalysisFromSchema()), + Operator( + "cuda::default_stream(int64_t val) -> __torch__.torch.classes.cuda.Stream", + [](Stack* stack) { + auto idx = uint16_t(pop(stack).toInt()); + auto s = c10::cuda::getDefaultCUDAStream(idx); + auto st = make_custom_class(s); + push(stack, IValue(st)); + }, + aliasAnalysisFromSchema()), + Operator( + "cuda::_current_device() -> int", + [](Stack* stack) { + auto v = c10::cuda::current_device(); + push(stack, static_cast(v)); + }, + aliasAnalysisFromSchema()), + Operator( + "cuda::_set_device(int64_t val) -> ()", + [](Stack* stack) { + int64_t idx = -1; + pop(stack, idx); + c10::cuda::set_device(static_cast(idx)); + }, + aliasAnalysisFromSchema()), + Operator( + "cuda::device_index(Device device) -> int", + [](Stack* stack) { + auto device = pop(stack); + auto idx = device.toDevice().index(); + push(stack, idx); + }, + aliasAnalysisFromSchema()), + Operator( + "cuda::device_count() -> int", + [](Stack* stack) { push(stack, at::cuda::device_count()); }, + aliasAnalysisFromSchema()), + Operator( + "cuda::set_stream(__torch__.torch.classes.cuda.Stream stream) -> ()", + [](Stack* stack) { + auto v = pop(stack); + auto s = v.toCustomClass(); + // To set the current CUDA stream using + // c10::cuda::setCurrentCUDAStream, the jit::CUDAStream object needs + // to be converted to c10::cuda::CUDAStream. Since the latter cannot + // be returned from a class registered via TorchBind, this can only be + // achieved by packing the c10::cuda::CUDAStream instance contained + // inside the jit::CUDAStream object to a uint64_t representation, and + // unpacking it inside this operator. The unpacked stream is then used + // to set the current CUDA stream. + auto packed = s->pack(); + auto unpacked = c10::cuda::CUDAStream::unpack(packed); + c10::cuda::setCurrentCUDAStream(unpacked); + }, + aliasAnalysisFromSchema()), +}); +} // namespace +} // namespace jit +} // namespace torch +#endif diff --git a/torch/csrc/jit/runtime/register_ops_utils.cpp b/torch/csrc/jit/runtime/register_ops_utils.cpp index a6125bb9f202a..537716e1ad1f1 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.cpp +++ b/torch/csrc/jit/runtime/register_ops_utils.cpp @@ -1,4 +1,5 @@ #include +#include namespace torch { namespace jit { @@ -205,7 +206,7 @@ int64_t partProduct(int n, int m) { return (int64_t)n; if (m == (n + 2)) return (int64_t)n * m; - int k = (n + m) / 2; + auto k = n + (m - n) / 2; // Overflow-safe midpoint if ((k & 1) != 1) k = k - 1; return partProduct(n, k) * partProduct(k + 2, m); @@ -434,22 +435,13 @@ void listSlice(Stack* stack) { const int64_t list_size = list.size(); - // clamp start and end to the bounds of the list - const auto normalized_start = - std::max((int64_t)0, normalizeIndex(start, list_size)); - const auto normalized_end = - std::min(list_size, normalizeIndex(end, list_size)); - c10::List sliced_list = make_result_list(list.elementType()); - if (normalized_end <= normalized_start) { - // early exit if the slice is trivially empty - push(stack, std::move(sliced_list)); - return; - } - - sliced_list.reserve(normalized_end - normalized_start); + const int64_t num_values = + slice_indices_adjust(list_size, &start, &end, step); + sliced_list.reserve(num_values); - for (auto i = normalized_start; i < normalized_end;) { + int i = start; + for (int j = 0; j < num_values; ++j) { sliced_list.push_back(list.get(i)); i += step; } diff --git a/torch/csrc/jit/runtime/register_ops_utils.h b/torch/csrc/jit/runtime/register_ops_utils.h index ae974c063ef35..d7c9ede9294fb 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.h +++ b/torch/csrc/jit/runtime/register_ops_utils.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -56,6 +57,19 @@ c10::impl::GenericList make_result_list(const TypePtr& elemType); inline void noop(Stack* n) {} +// As described in https://docs.python.org/3/library/functions.html#round +// When a number is exactly halfway between two integers, python builtin round +// function will round to even number. We use round(x/2)*2 to handle the +// special halfway case. For positive 'x', round(x/2)*2 = +// round((x_e + x_r)/2)*2 = x_e + round(x_r/2)*2, where x_e is an even integer, +// x_r is either 0.5 of 1.5, round(x_r/2)*2 results a 0 or 2, so the final +// result will always be a even number. Due to symmetricity, it also applies to +// negative cases. +inline double round_to_even(double a) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + return a - std::floor(a) == 0.5 ? (std::round(a * 0.5) * 2.0) : std::round(a); +} + // using the rules from python_arg_parser FunctionParameter::check // tensor cannot have grad set, tensor must be 0 dim, // and if the dest is an int the source must be integral type @@ -165,7 +179,7 @@ void setItem(const c10::List& list, int64_t idx, T&& value) { if (normalized_idx < 0 || normalized_idx >= list_size) { throw std::out_of_range("list index out of range"); } - list.set(normalized_idx, std::move(value)); + list.set(normalized_idx, std::forward(value)); } void listAppend(Stack* stack); diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index bf2ffa421ee94..847afeb1152b4 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -1,6 +1,8 @@ +#include #include #include #include +#include #include #include @@ -29,23 +31,22 @@ namespace { std::string stringSlice( std::string string, - int64_t start, - int64_t end, + c10::optional start, + c10::optional end, int64_t step) { - TORCH_CHECK(step == 1, "Slicing a string only supports step=1"); - - const int64_t size = string.size(); + int64_t start_val = start.has_value() ? start.value() : INT64_MAX; + int64_t end_val = end.has_value() ? end.value() : INT64_MAX; - // Clamp start and end to the bounds of the list - start = std::max(int64_t(0), normalizeIndex(start, size)); - end = std::min(size, normalizeIndex(end, size)); + const int64_t num_vals = + slice_indices_adjust(string.size(), &start_val, &end_val, step); - if (end <= start) { - // Slice is empty - return std::string(""); + int64_t i = start_val; + std::string result = ""; + for (int64_t j = 0; j < num_vals; j++) { + result += string[i]; + i += step; } - std::string result(string.begin() + start, string.begin() + end); return result; } @@ -98,6 +99,88 @@ RegisterOperators reg( return 0; }, aliasAnalysisFromSchema()), + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::cpu(Tensor(a) self) -> Tensor(a|b)"), + [](Stack* stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.cpu()); + }, + aliasAnalysisFromSchema()), + Operator( + prim::tolist, + // This operator has to be unschematized because the return type + // depends on the type hint and input. The implementation of this + // operator below is intended to be as close to the Python + // implementation in torch/csrc/utils/tensor_list.cpp as possible. + [](const Node* /*node*/) -> Operation { + return [](Stack* stack) { + int elem_ty_val; + int dim_val; + at::Tensor t; + + pop(stack, elem_ty_val); + pop(stack, dim_val); + pop(stack, t); + + // If the Tensor is not on the CPU, transfer it. + if (!t.device().is_cpu()) { + t = t.cpu(); + } + + // Rebuild the output type using elem_ty_val and dim_val. Start + // with the element type corresponding to elem_ty_val. + TypePtr out_ty; + if (elem_ty_val == 0) { + out_ty = IntType::get(); + } else if (elem_ty_val == 1) { + out_ty = FloatType::get(); + } else if (elem_ty_val == 2) { + out_ty = BoolType::get(); + } else { + TORCH_CHECK( + false, + "Unsupported element type for tolist; only int, float and bool are supported"); + } + + // Check that type of the Tensor matches that of the annotation. + // Make an exception for the case in which the annotated type is + // float and the Tensor data type is also float; the elements will + // be casted to double later. + TORCH_CHECK( + (out_ty == FloatType::get() && t.is_floating_point()) || + tryScalarTypeFromJitType(out_ty) == t.scalar_type(), + "Output annotation element type and runtime tensor element type must match for tolist()"); + + // Check that the dimension of the Tensor matches that of the + // annotation. + TORCH_CHECK( + dim_val == t.dim(), + "Output annotation list dimension and runtime tensor dimension must match for tolist()"); + + // Wrap out_ty in a ListType dim times. + for (int i = 0; i < dim_val; ++i) { + out_ty = ListType::create(out_ty); + } + + int64_t dim = t.dim(); + auto sizes = t.sizes(); + auto strides = t.strides(); + size_t element_size = t.element_size(); + char* data = static_cast(t.data_ptr()); + auto result = tensorToListRecursive( + data, + 0, + dim, + out_ty, + t.scalar_type(), + sizes, + strides, + element_size); + push(stack, std::move(result)); + }; + }, + aliasAnalysisSpecialCase()), // only used internally in range() translation OperatorGenerator( TORCH_SELECTIVE_SCHEMA( @@ -521,7 +604,7 @@ RegisterOperators reg( aliasAnalysisFromSchema()), OperatorGenerator( TORCH_SELECTIVE_SCHEMA( - "aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> t[]"), + "aten::slice.t(t[] l, int? start=0, int? end=9223372036854775807, int step=1) -> t[]"), listSlice, aliasAnalysisFromSchema()), OperatorGenerator( @@ -548,6 +631,14 @@ RegisterOperators reg( TORCH_SELECTIVE_SCHEMA("aten::eq.int_list(int[] a, int[] b) -> bool"), listEq, aliasAnalysisFromSchema()), + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::eq.device(Device a, Device b) -> bool"), + [](Stack* stack) { + auto a = pop(stack).toDevice(); + auto b = pop(stack).toDevice(); + push(stack, a == b); + }, + aliasAnalysisFromSchema()), OperatorGenerator( TORCH_SELECTIVE_SCHEMA("prim::Uninitialized() -> Any"), [](Stack* stack) { push(stack, IValue::uninitialized()); }, @@ -589,6 +680,12 @@ RegisterOperators reg( push(stack, x != y); }, aliasAnalysisFromSchema()), + // We define aten::dequantize in both native_functions.yaml and here, + // however, aten::dequantize.any defined here overrides + // aten::dequantize.tensors in native_functions.yaml. The variants here + // are only for graph mode quantization, and they should be removed once + // we deprecate graph mode quantization, and use the variants in + // native_functions.yaml. OperatorGenerator( TORCH_SELECTIVE_SCHEMA( "aten::dequantize.tensor(Tensor qtensor) -> Tensor"), @@ -598,10 +695,19 @@ RegisterOperators reg( push(stack, at::dequantize(qtensor)); }, aliasAnalysisFromSchema()), + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::dequantize.list(Tensor[] qtensors) -> Tensor[]"), + [](Stack* stack) { + auto qtensors = pop(stack).toTensorVector(); + push(stack, at::dequantize(qtensors)); + }, + aliasAnalysisFromSchema()), OperatorGenerator( TORCH_SELECTIVE_SCHEMA("aten::dequantize.any(Any tensors) -> Any"), [](Stack* stack) { dequantize(*stack); }, aliasAnalysisFromSchema()), + DEFINE_UNARY_OP(aten::log, std::log(a), float, float), DEFINE_STRING_OP(aten::add, a + b, str), DEFINE_COMPARISON_OP(aten::eq, a == b), DEFINE_COMPARISON_OP(aten::ne, a != b), @@ -615,6 +721,7 @@ RegisterOperators reg( DEFINE_BOOL_OP(aten::__and__, a&& b), DEFINE_BOOL_OP(aten::__or__, a || b), DEFINE_BOOL_OP(aten::__xor__, a != b), + DEFINE_UNARY_OP(aten::round, round_to_even(a), float, float), DEFINE_UNARY_OP(aten::floor, floor(a), int, int), DEFINE_UNARY_OP(aten::ceil, ceil(a), int, int), DEFINE_UNARY_OP(aten::neg, -a, int, float), @@ -730,6 +837,11 @@ RegisterOperators reg( return 0; }, aliasAnalysisFromSchema()), + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::__contains__.int_list(int[] l, int item) -> bool"), + listContains, + aliasAnalysisFromSchema()), OperatorGenerator( TORCH_SELECTIVE_SCHEMA( "aten::__contains__.str_list(str[] l, str item) -> bool"), @@ -797,7 +909,7 @@ RegisterOperators reg( TORCH_SELECTIVE_SCHEMA( "aten::index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor"), [](Stack* stack) { - auto indices = pop(stack).toTensorVector(); + auto indices = pop(stack).to>>(); auto self = pop(stack).toTensor(); auto result = at::index(self, indices); push(stack, std::move(result)); @@ -810,7 +922,7 @@ RegisterOperators reg( auto unsafe = pop(stack).toBool(); auto accumulate = pop(stack).toBool(); auto values = pop(stack).toTensor(); - auto indices = pop(stack).toTensorVector(); + auto indices = pop(stack).to>>(); auto self = pop(stack).toTensor(); auto result = at::_index_put_impl_(self, indices, values, accumulate, unsafe); @@ -823,7 +935,7 @@ RegisterOperators reg( [](Stack* stack) { auto accumulate = pop(stack).toBool(); auto values = pop(stack).toTensor(); - auto indices = pop(stack).toTensorVector(); + auto indices = pop(stack).to>>(); auto self = pop(stack).toTensor(); auto result = at::index_put_(self, indices, values, accumulate); push(stack, std::move(result)); @@ -835,7 +947,7 @@ RegisterOperators reg( [](Stack* stack) { auto accumulate = pop(stack).toBool(); auto values = pop(stack).toTensor(); - auto indices = pop(stack).toTensorVector(); + auto indices = pop(stack).to>>(); auto self = pop(stack).toTensor(); auto result = at::index_put_(self, indices, values, accumulate); push(stack, std::move(result)); @@ -1094,7 +1206,7 @@ void dictUpdate(Stack* stack) { auto dict = pop(stack).toGenericDict(); for (const auto& item : to_add) { - dict.insert(item.key(), item.value()); + dict.insert_or_assign(item.key(), item.value()); } } @@ -1255,7 +1367,7 @@ int64_t normalizeIndex(int64_t idx, int64_t list_size) { int64_t stringFindImpl( std::string string, - std::string substr, + const std::string& substr, int64_t start, int64_t end, bool reverse = false) { diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp index dc075ce141661..9b731609243f5 100644 --- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp @@ -29,7 +29,15 @@ RegisterOperators reg( {Operator( prim::profile, [](const Node* node) -> Operation { - auto callback = node->cast()->getCallback(); + return [](Stack* stack) { + AT_ERROR( + "Must be lowered to Interpreter's PROFILE instruction"); // NOLINT + }; + }, + aliasAnalysisSpecialCase()), + Operator( + prim::profile_ivalue, + [](const Node* node) -> Operation { return [](Stack* stack) { AT_ERROR( "Must be lowered to Interpreter's PROFILE instruction"); // NOLINT @@ -56,9 +64,36 @@ RegisterOperators reg( }; }, aliasAnalysisSpecialCase()), + Operator( + prim::RequiresGradCheck /* (...) -> (..., bool) */, + [](const Node* node) -> Operation { + std::vector rg_props = + fmap(node->tys(attr::types), [](const TypePtr& t) { + // if an rg property changes we assume a tensor does require + // gradients which is set in `guardDifferentiableGraph` + TORCH_INTERNAL_ASSERT( + t->cast()->requiresGrad().has_value()); + return *t->cast()->requiresGrad(); + }); + return [rg_props](Stack* stack) { + auto num_inputs = rg_props.size(); + // Check every input's shape against profiled (expected) shape. + for (size_t i = 0; i < num_inputs; i++) { + auto& input = peek(stack, i, num_inputs); + const auto& t = input.toTensor(); + if (rg_props[i] != t.requires_grad()) { + push(stack, false); + return; + } + } + + push(stack, true); + }; + }, + aliasAnalysisSpecialCase()), Operator( prim::TypeCheck /* (...) -> (..., bool) */, - [](const Node * /* node */) -> Operation { + [](const Node* /* node */) -> Operation { return [](Stack* /* stack */) { AT_ERROR("prim::TypeCheck not yet implemented"); // NOLINT }; @@ -211,6 +246,13 @@ RegisterOperators reg( push(stack, c10::Device(pop(stack).toStringRef())); }, aliasAnalysisFromSchema()), + Operator( + "aten::percentFormat(str self, ...) -> str", + [](Stack* stack) { + size_t num_inputs = pop(stack).toInt(); + percentFormat(*stack, num_inputs); + }, + aliasAnalysisFromSchema()), Operator( "aten::to.prim_other(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a|b)", [](Stack* stack) { @@ -225,14 +267,6 @@ RegisterOperators reg( to_dispatch(self, device, scalarType, non_blocking, copy)); }, aliasAnalysisFromSchema()), - Operator( - "aten::eq.device(Device a, Device b) -> bool", - [](Stack* stack) { - auto a = pop(stack).toDevice(); - auto b = pop(stack).toDevice(); - push(stack, a == b); - }, - aliasAnalysisFromSchema()), Operator( "prim::requires_grad(Tensor a) -> bool", [](Stack* stack) { @@ -265,6 +299,14 @@ RegisterOperators reg( push(stack, a.is_mkldnn()); }, aliasAnalysisFromSchema()), + Operator( + "prim::is_vulkan(Tensor a) -> bool", + [](Stack* stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.is_vulkan()); + }, + aliasAnalysisFromSchema()), Operator( "prim::is_quantized(Tensor a) -> bool", [](Stack* stack) { @@ -301,14 +343,6 @@ RegisterOperators reg( push(stack, a.layout()); }, aliasAnalysisFromSchema()), - Operator( - "aten::cpu(Tensor(a) self) -> Tensor(a|b)", - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.cpu()); - }, - aliasAnalysisFromSchema()), Operator( "prim::index(Device self) -> int?", [](Stack* stack) { @@ -338,6 +372,34 @@ RegisterOperators reg( "prim::AutogradZero() -> Tensor", [](Stack* stack) { stack->emplace_back(at::Tensor()); }, aliasAnalysisSpecialCase()), + Operator( + "prim::ReductionSizes(int[] size, int[] red_axes, bool keepdim = False) -> int[]", + [](Stack* stack) { + bool keepdim = pop(stack).toBool(); + c10::List axes = pop(stack).toIntList(); + c10::List size = pop(stack).toIntList(); + if (keepdim) { + for (const auto& axis : axes) { + size.set(axis, 1); + } + } else { + int64_t index = 0; + auto iter = size.begin(); + std::sort(axes.begin(), axes.end()); + for (const auto& axis : axes) { + // move iter to the next axis + iter += axis - index; + + // input iter points to axis and is updated to axis + 1 + iter = size.erase(iter); + + // update current index for iter + index = axis + 1; + } + } + push(stack, IValue(std::move(size))); + }, + aliasAnalysisFromSchema()), Operator( "prim::BroadcastSizes(...) -> int[]", [](Stack* stack) { @@ -349,7 +411,7 @@ RegisterOperators reg( at::infer_size(size, peek(stack, i, num_inputs).toIntVector()); } drop(stack, num_inputs); - push(stack, IValue(std::move(size))); + push(stack, IValue(size)); }, aliasAnalysisSpecialCase()), Operator( @@ -517,80 +579,6 @@ RegisterOperators reg( } }, aliasAnalysisFromSchema()), - Operator( - prim::tolist, - // This operator has to be unschematized because the return type - // depends on the type hint and input. The implementation of this - // operator below is intended to be as close to the Python - // implementation in torch/csrc/utils/tensor_list.cpp as possible. - [](const Node* node) -> Operation { - return [](Stack* stack) { - int elem_ty_val; - int dim_val; - at::Tensor t; - - pop(stack, elem_ty_val); - pop(stack, dim_val); - pop(stack, t); - - // If the Tensor is not on the CPU, transfer it. - if (!t.device().is_cpu()) { - t = t.cpu(); - } - - // Rebuild the output type using elem_ty_val and dim_val. Start - // with the element type corresponding to elem_ty_val. - TypePtr out_ty; - if (elem_ty_val == 0) { - out_ty = IntType::get(); - } else if (elem_ty_val == 1) { - out_ty = FloatType::get(); - } else if (elem_ty_val == 2) { - out_ty = BoolType::get(); - } else { - TORCH_CHECK( - false, - "Unsupported element type for tolist; only int, float and bool are supported"); - } - - // Check that type of the Tensor matches that of the annotation. - // Make an exception for the case in which the annotated type is - // float and the Tensor data type is also float; the elements will - // be casted to double later. - TORCH_CHECK( - (out_ty == FloatType::get() && t.is_floating_point()) || - tryScalarTypeFromJitType(out_ty) == t.scalar_type(), - "Output annotation element type and runtime tensor element type must match for tolist()"); - - // Check that the dimension of the Tensor matches that of the - // annotation. - TORCH_CHECK( - dim_val == t.dim(), - "Output annotation list dimension and runtime tensor dimension must match for tolist()"); - - // Wrap out_ty in a ListType dim times. - for (int i = 0; i < dim_val; ++i) { - out_ty = ListType::create(out_ty); - } - - int64_t dim = t.dim(); - auto sizes = t.sizes(); - auto strides = t.strides(); - size_t element_size = t.element_size(); - char* data = static_cast(t.data_ptr()); - auto result = tensorToListRecursive( - data, - 0, - dim, - out_ty, - t.scalar_type(), - sizes, - strides, - element_size); - push(stack, std::move(result)); - }; - }, - aliasAnalysisSpecialCase()), Operator( prim::ConstantChunk, [](const Node* node) -> Operation { @@ -635,6 +623,18 @@ RegisterOperators reg( }; }, aliasAnalysisSpecialCase()), + // This operator is generated inside the compiler for indexing into + // ModuleDict without a statically determinable key. Accordingly, + // self must be a ModuleType and the output must be an InterfaceType. + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "prim::ModuleDictIndex(Any self, str ind) -> Any"), + [](Stack* stack) { + IValue ind = pop(stack); + IValue module_dict = pop(stack); + push(stack, module_dict.toModule().attr(ind.toStringRef())); + }, + aliasAnalysisFromSchema()), Operator( "aten::dict() -> Dict(str, Tensor)", [](Stack* stack) { @@ -702,24 +702,9 @@ RegisterOperators logging_operators( }, aliasAnalysisFromSchema())}); -template void hashValue(Stack* stack) { auto value = pop(stack); - auto hash = std::hash()(value.to()); - push(stack, int64_t(hash)); -} - -// As described in https://docs.python.org/3/library/functions.html#round -// When a number is exactly halfway between two integers, python builtin round -// function will round to even number. We use round(x/2)*2 to handle the -// special halfway case. For positive 'x', round(x/2)*2 = -// round((x_e + x_r)/2)*2 = x_e + round(x_r/2)*2, where x_e is an even integer, -// x_r is either 0.5 of 1.5, round(x_r/2)*2 results a 0 or 2, so the final -// result will always be a even number. Due to symmetricity, it also applies to -// negative cases. -double round_to_even(double a) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - return a - std::floor(a) == 0.5 ? (std::round(a * 0.5) * 2.0) : std::round(a); + push(stack, value.hash()); } RegisterOperators reg2({ @@ -761,10 +746,6 @@ RegisterOperators reg2({ // `listContains` is not implemented for non-primitive types // TODO: Add List[bool] once .to> doesn't throw an error - Operator( - "aten::__contains__.int_list(int[] l, int item) -> bool", - listContains, - aliasAnalysisFromSchema()), Operator( "aten::__contains__.float_list(float[] l, float item) -> bool", listContains, @@ -946,8 +927,6 @@ RegisterOperators reg2({ DEFINE_INT_OP(aten::__lshift__, a << b), DEFINE_INT_OP(aten::__rshift__, a >> b), - DEFINE_UNARY_OP(aten::round, round_to_even(a), float, float), - DEFINE_UNARY_OP(aten::log, std::log(a), float, float), DEFINE_GENERIC_BINARY_OP(aten::log, std::log(a) / std::log(b), float), DEFINE_INT_FLOAT_OP(aten::log, std::log(a) / std::log(b), float), DEFINE_SCALAR_SCALAR_BINARY_OP( @@ -1156,16 +1135,8 @@ RegisterOperators reg2({ #undef DEFINE_DIVMOD_MIXED_OP Operator( - "aten::hash.str(str t) -> int", - hashValue, - aliasAnalysisFromSchema()), - Operator( - "aten::hash.int(int t) -> int", - hashValue, - aliasAnalysisFromSchema()), - Operator( - "aten::hash.float(float t) -> int", - hashValue, + "aten::hash.generic(t value) -> int", + hashValue, aliasAnalysisFromSchema()), }); diff --git a/torch/csrc/jit/runtime/register_special_ops.cpp b/torch/csrc/jit/runtime/register_special_ops.cpp index c8f62870f789c..2cd5a13d3f4b1 100644 --- a/torch/csrc/jit/runtime/register_special_ops.cpp +++ b/torch/csrc/jit/runtime/register_special_ops.cpp @@ -372,6 +372,10 @@ RegisterOperators reg({ TORCH_SELECTIVE_SCHEMA("aten::is_scripting() -> bool"), [](Stack* stack) { push(stack, true); }, aliasAnalysisFromSchema()), + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::has_torch_function(...) -> bool"), + [](Stack* stack) { push(stack, false); }, + aliasAnalysisFromSchema()), OperatorGenerator( TORCH_SELECTIVE_SCHEMA( "aten::_no_grad_uniform_(Tensor(a!) tensor, float a, float b) -> Tensor(a!)"), @@ -431,10 +435,7 @@ RegisterOperators reg({ aliasAnalysisConservative()), Operator( "aten::set_grad_enabled(bool val) -> ()", - [](Stack* stack) { - torch::GradMode::set_enabled(pop(stack).toBool()); - push(stack, IValue()); - }, + [](Stack* stack) { torch::GradMode::set_enabled(pop(stack).toBool()); }, aliasAnalysisConservative()), }); } // namespace diff --git a/torch/csrc/jit/runtime/slice_indices_adjust.cpp b/torch/csrc/jit/runtime/slice_indices_adjust.cpp new file mode 100644 index 0000000000000..8c0d39156d29e --- /dev/null +++ b/torch/csrc/jit/runtime/slice_indices_adjust.cpp @@ -0,0 +1,57 @@ +#include + +#include +#include + +namespace torch { +namespace jit { + +int64_t slice_indices_adjust( + int64_t length, + int64_t* start, + int64_t* stop, + int64_t step) { + TORCH_CHECK(step != 0, "List slice should have non-zero step") + TORCH_CHECK(step >= -INT64_MAX, "List slice step is out of bounds") + + // Comes from PySlice_Unpack. + if (*start == INT64_MAX) { + *start = (step < 0) ? INT64_MAX : 0; + } + if (*stop == INT64_MAX) { + *stop = (step < 0) ? INT64_MIN : INT64_MAX; + } + + // Comes from PySlice_AdjustIndices. + if (*start < 0) { + *start += length; + if (*start < 0) { + *start = (step < 0) ? -1 : 0; + } + } else if (*start >= length) { + *start = (step < 0) ? length - 1 : length; + } + + if (*stop < 0) { + *stop += length; + if (*stop < 0) { + *stop = (step < 0) ? -1 : 0; + } + } else if (*stop >= length) { + *stop = (step < 0) ? length - 1 : length; + } + + if (step < 0) { + if (*stop < *start) { + return (*start - *stop - 1) / (-step) + 1; + } + } else { + if (*start < *stop) { + return (*stop - *start - 1) / step + 1; + } + } + return 0; +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/slice_indices_adjust.h b/torch/csrc/jit/runtime/slice_indices_adjust.h new file mode 100644 index 0000000000000..ea1e9511769db --- /dev/null +++ b/torch/csrc/jit/runtime/slice_indices_adjust.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace jit { + +// Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +// 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020 Python Software +// Foundation; All Rights Reserved +// +// Stolen (with appropriate modifications) by @agolynski +// (https://github.com/pytorch/pytorch/pull/33019) from cpython repo +// Objects/sliceobject.c with comment: this is harder to get right than you +// might think +// +// This adjusts indexes according to python list semantics and returns number +// of elements in the resulting list. +TORCH_API int64_t slice_indices_adjust( + int64_t length, + int64_t* start, + int64_t* stop, + int64_t step); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/static/README.md b/torch/csrc/jit/runtime/static/README.md index e2cba9dd22ac0..0ffb946aaf588 100644 --- a/torch/csrc/jit/runtime/static/README.md +++ b/torch/csrc/jit/runtime/static/README.md @@ -2,7 +2,7 @@ # Static Runtime -The premise of this approach is that a small subset of neural networks are well represented by a +The premise of this approach is that a small subset of neural networks are well represented by a completely flattened dataflow graph. TorchScript supports a far more feature programming paradigm, so many models will not work out of the box. @@ -13,7 +13,6 @@ This is a list of current assumptions for use with this feature. - Inference only execution -- Single CPU device After `torch.jit.freeze` and inlining/constant propagation is run on the model: @@ -22,6 +21,46 @@ After `torch.jit.freeze` and inlining/constant propagation is run on the model: - No references to `self` - Inlined weights (i.e. no calls to `GetAttr`) +## Threading model +Static runtime supports two execution modes. + +Mode 1: single-threaded with no parallelism except for intra-op parallelism. +For this mode, you can do either: +``` + // m is the TorchScript module + auto runtime = StaticRuntime(m, opts); + auto output = runtime.run(args, kwargs); +``` +or +``` + auto mod = PrepareForStaticRuntime(m); + auto runtime = StaticRuntime(mod, opts); + auto output = runtime.run(args, kwargs); +``` +Mode 2: similar to data parallelism, run the same model for different inputs +on different threads at the same time. In this case, run +`PrepareForStaticRuntime` to prepare the graph for Static Runtime. You +should have one InferenceModule instance per model, and one Static Runtime instance +per running thread. To avoiding creating StaticRuntime on the fly, use a +synchronized stack (i.e. `boost::lockfree::stack`) to cache all the Static +Runtime instances in your code. +``` + // initialization + auto mod = PrepareForStaticRuntime(m); + // 128 is good for most cases. Pick a number that works for you + boost::lockfree::stack, + boost::lockfree::fixed_sized> pool(128); + + // inference + std::shared_ptr runtime = nullptr; + pool.pop(runtime); + if (!runtime) { + runtime = std::make_shared(mod, opts); + } + auto output = runtime->run(args, kwargs); + pool.push(runtime); +``` + ## Planned features - Memory planning diff --git a/torch/csrc/jit/runtime/static/fusion.cpp b/torch/csrc/jit/runtime/static/fusion.cpp new file mode 100644 index 0000000000000..b4ef21e8fa37f --- /dev/null +++ b/torch/csrc/jit/runtime/static/fusion.cpp @@ -0,0 +1,278 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +void createFusionGroups(Block* block, AliasDb* aliasDb); + +void fuseStaticSubgraphs(std::shared_ptr graph) { + PrepareGraphForStaticRuntime(graph); + auto aliasDb = torch::make_unique(graph); + createFusionGroups(graph->block(), aliasDb.get()); + torch::jit::EliminateDeadCode(graph); +} + +Operation createStaticSubgraphRuntime(const Node* node) { + auto g = torch::jit::PrepareForStaticRuntime(node->g(attr::Subgraph)); + auto runtime = std::make_shared(g); + auto num_inputs = runtime->get_inference_module()->input_regs.size(); + return [runtime, num_inputs](Stack* stack) { + RECORD_FUNCTION("Static Runtime", std::vector()); + auto inps = torch::jit::last(stack, num_inputs); + // TODO maybe avoid call to vec + auto outputs = runtime->run(inps.vec(), {}); + torch::jit::drop(stack, num_inputs); + + if (runtime->num_outputs() > 1) { + for (auto& o : outputs.toTuple()->elements()) { + push_one(*stack, std::move(o)); + } + } else { + push_one(*stack, std::move(outputs)); + } + return 0; + }; +} + +RegisterOperators StaticSubgraphOps({torch::jit::Operator( + prim::StaticSubgraph, + createStaticSubgraphRuntime, + AliasAnalysisKind::INTERNAL_SPECIAL_CASE)}); + +#define REQ(cond) \ + if (!(cond)) { \ + GRAPH_DEBUG("Failed cond " #cond "\n"); \ + return false; \ + } + +bool canHandle(Node* node) { + for (Value* input : node->inputs()) { + bool is_tensor = !!input->type()->cast(); + auto list_type = input->type()->cast(); + bool is_list = list_type && list_type->getElementType()->cast(); + auto tuple_type = input->type()->cast(); + bool is_tuple = [&]() -> bool { + if (!tuple_type) { + return false; + } + for (auto& t : tuple_type->elements()) { + if (!t->cast()) { + return false; + } + } + return true; + }(); + if (!(is_tensor || is_list || is_tuple)) { + if (input->node()->kind() != prim::Constant) { + return false; + } + } + } + + auto kind = node->kind(); + if (kind.is_prim()) { + REQ(kind == prim::TupleConstruct || kind == prim::ListConstruct || + kind == prim::StaticSubgraph); + if (kind == prim::TupleConstruct || kind == prim::ListConstruct) { + for (Value* input : node->inputs()) { + if (!input->type()->cast()) { + return false; + } + } + } + return true; + } + + // TODO add "canRunNatively" once memory management is audited + return canRunOutOfPlace(node); +} + +bool canMerge(Node* consumer, Node* producer, AliasDb* aliasDb) { + // Only fuse within a block + REQ(consumer->owningBlock() == producer->owningBlock()); + + // Symbolic checks + REQ(canHandle(producer) || producer->kind() == prim::StaticSubgraph); + TORCH_INTERNAL_ASSERT( + consumer->kind() == prim::StaticSubgraph || canHandle(consumer)); + + // Alias checks + REQ(aliasDb->couldMoveBeforeTopologically(producer, consumer)); + + // Ops that return aliases can only be folded if this is the only use. + if (producer->kind() == aten::slice || producer->kind() == aten::unsqueeze || + producer->kind() == prim::ConstantChunk) { + for (auto& use : producer->output(0)->uses()) { + REQ(use.user == consumer); + } + } + + return true; +} + +Node* getOrCreateStaticSubgraph(Node* n, AliasDb* aliasDb) { + if (n->hasAttribute(attr::Subgraph) && n->kind() == prim::StaticSubgraph) { + return n; + } + GRAPH_UPDATE("Creating a static subgraph::Group node from: ", *n); + return SubgraphUtils::createSingletonSubgraphAndUpdateAliasing( + n, prim::StaticSubgraph, *aliasDb); +} + +value_list sortReverseTopological(ArrayRef inputs, Block* b) { + value_list result; + for (auto i : inputs) { + if (i->node()->owningBlock() == b) { + result.push_back(i); + } + } + // Sort in reverse topological order + std::sort(result.begin(), result.end(), [&](Value* a, Value* b) { + return a->node()->isAfter(b->node()); + }); + return result; +} + +static void debugDumpFusionGroup(const std::string& msg, Node* n) { + GRAPH_DEBUG(msg, *n); + if (n->kind() == prim::StaticSubgraph) { + GRAPH_DEBUG(*n->g(attr::Subgraph)); + } +} + +c10::optional tryMerge( + Node* fusion_group, + Node* to_merge, + AliasDb* aliasDb) { + if (!canMerge(fusion_group, to_merge, aliasDb)) { + return c10::nullopt; + } + + std::vector nodes_to_merge = {to_merge}; + + if (to_merge->kind() == aten::cat) { + Node* listconstruct = to_merge->input(0)->node(); + nodes_to_merge.push_back(listconstruct); + } + + // First, try to move all the nodes we want to fuse next to the fusion + // group. + Node* move_point = fusion_group; + for (auto n : nodes_to_merge) { + GRAPH_UPDATE("Trying to move node next to fusion group: ", getHeader(n)); + if (!aliasDb->moveBeforeTopologicallyValid(n, move_point)) { + GRAPH_UPDATE("Failed to move because of AliasDb checks!"); + return c10::nullopt; + } + move_point = n; + } + + // Now all the nodes that we're going to fuse are moved next to the fusion + // group, so we can safely merge them into the fusion group subgraph. + fusion_group = getOrCreateStaticSubgraph(fusion_group, aliasDb); + + for (auto n : nodes_to_merge) { + GRAPH_UPDATE("Merging ", getHeader(n)); + SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing( + n, fusion_group, *aliasDb); + } + return fusion_group; +} + +std::pair createFusionGroup( + Node* fusion_node, + AliasDb* aliasDb) { + fusion_node = getOrCreateStaticSubgraph(fusion_node, aliasDb); + + GRAPH_DEBUG("Iteratively pull input nodes into the fusion group...\n"); + auto inputs = + sortReverseTopological(fusion_node->inputs(), fusion_node->owningBlock()); + for (auto input : inputs) { + debugDumpFusionGroup("Current fusion group: ", fusion_node); + GRAPH_DEBUG("Trying to merge: ", *input->node()); + if (auto maybe_fusion_group = + tryMerge(fusion_node, input->node(), aliasDb)) { + // we successfully merged, so the new group's `inputs` may have + // changed. So rescan the new group for more merging opportunities. + return std::make_pair( + maybe_fusion_group.value()->reverseIterator(), true); + } + } + + return std::make_pair(++fusion_node->reverseIterator(), false); +} + +std::pair scanNode(Node* n, AliasDb* aliasDb) { + GRAPH_DEBUG("Considering node:", *n); + + if (!canHandle(n)) { + return std::make_pair(++n->reverseIterator(), false); + } + + return createFusionGroup(n, aliasDb); +} + +void createFusionGroups(Block* block, AliasDb* aliasDb) { + bool any_changed = true; + while (any_changed) { + any_changed = false; + for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) { + bool changed; + std::tie(it, changed) = scanNode(*it, aliasDb); + any_changed |= changed; + } + } + + for (Node* n : block->nodes()) { + for (Block* b : n->blocks()) { + createFusionGroups(b, aliasDb); + } + } + + // Try to merge adjacent fusion groups together. Because we have only merged + // by looking at graph inputs, without this we would not attempt to merge + // adjacent fusion groups that don't have a depdency on each other + + std::vector initial_fusion_groups; + for (Node* n : block->nodes()) { + if (n->kind() == prim::StaticSubgraph) { + initial_fusion_groups.push_back(n); + } + } + + Node* prev_fusion_group = + initial_fusion_groups.size() ? initial_fusion_groups[0] : nullptr; + + for (size_t i = 1; i < initial_fusion_groups.size(); ++i) { + // Try merging the just created fusion group into the previous one. + // If it did not work, then put the previous fusion group into + // fusion_groups vector - we will not touch it anymore in this loop. + // If merging suceeded, save the merged group as the "previous" fusion + // group so that we can try to merge the next one into it. + + Node* fusion_group = initial_fusion_groups[i]; + debugDumpFusionGroup( + "Trying to merge into the previous fusion group: ", prev_fusion_group); + if (auto merged_fusion_group = + tryMerge(prev_fusion_group, fusion_group, aliasDb)) { + prev_fusion_group = *merged_fusion_group; + debugDumpFusionGroup( + "Successfully merged into the previous fusion group: ", + prev_fusion_group); + } else { + GRAPH_DEBUG("Cannot merge into the previous fusion group"); + prev_fusion_group = fusion_group; + } + } +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/static/fusion.h b/torch/csrc/jit/runtime/static/fusion.h new file mode 100644 index 0000000000000..5f0e30b8505b9 --- /dev/null +++ b/torch/csrc/jit/runtime/static/fusion.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + +TORCH_API void fuseStaticSubgraphs(std::shared_ptr graph); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 5b3c1e029a903..8160e3af83694 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -1,207 +1,844 @@ #include -#include + +#include +#include +#include +#include +#include #include +#include #include #include #include +#include +#include #include namespace torch { namespace jit { -using c10::DispatchKey; -using c10::RegisterOperators; - -static auto reg = - RegisterOperators() - .op("static::add(Tensor a, Tensor b) -> Tensor", - RegisterOperators::options().kernel( - DispatchKey::CPU, - [](at::Tensor a, at::Tensor b) -> at::Tensor { return a + b; })) - .op("static::mul.a(Tensor a, Tensor b) -> Tensor", - RegisterOperators::options().kernel( - DispatchKey::CPU, - [](at::Tensor a, at::Tensor b) -> at::Tensor { return a * b; })) - .op("static::mul.b(Tensor a, int b) -> Tensor", - RegisterOperators::options().kernel( - DispatchKey::CPU, - [](at::Tensor a, int64_t b) -> at::Tensor { return a * b; })); - -#define SUPPORTED_OPS(F) \ - F(aten::__getitem__) \ - F(aten::add) \ - F(aten::addmm) \ - F(aten::bmm) \ - F(aten::cat) \ - F(aten::clamp) \ - F(aten::contiguous) \ - F(aten::div) \ - F(aten::flatten) \ - F(aten::index_put_) \ - F(aten::isnan) \ - F(aten::matmul) \ - F(aten::mul) \ - F(aten::permute) \ - F(aten::relu) \ - F(aten::sigmoid) \ - F(aten::size) \ - F(aten::softmax) \ - F(aten::t) \ - F(aten::to) \ - F(aten::transpose) \ - F(aten::view) \ - F(prim::Constant) \ - F(prim::ListConstruct) \ - F(prim::TupleConstruct) - -StaticRuntime::StaticRuntime(const torch::jit::Module& m) - : module_(m.copy()), graph_(nullptr) { - module_.eval(); - module_ = freeze_module(module_); - graph_ = module_.get_method("forward").graph(); - - Inline(*graph_); - ConstantPropagation(graph_); - Canonicalize(graph_); - ConstantPropagation(graph_); - RemoveTensorMutation(graph_); - ConstantPropagation(graph_); - - for (auto n : graph_->nodes()) { +void PrepareGraphForStaticRuntime(std::shared_ptr graph) { + Inline(*graph); + ConstantPropagation(graph); + Canonicalize(graph); + ConstantPropagation(graph); + RemoveTensorMutation(graph); + ConstantPropagation(graph); + EliminateDeadCode(graph); +} + +namespace { +void OptimizeGraph(std::shared_ptr& graph) { + PrepareGraphForStaticRuntime(graph); + FuseInferenceOpsForSparseNN(graph); + ConstantPropagation(graph); +} + +void CheckGraphEligibility(const std::shared_ptr& graph) { + for (auto n : graph->nodes()) { if (n->kind() == c10::Symbol::fromQualString("prim::GetAttr")) { throw std::runtime_error("Cannot accelerate unfrozen graphs"); } - bool supported = false; -#define X(_) \ - if (n->kind() == c10::Symbol::fromQualString(#_)) { \ - supported = true; \ } - SUPPORTED_OPS(X) -#undef X - if (!supported) { - throw std::runtime_error( - std::string("Unsupported operation: ") + n->kind().toQualString()); + // check output types + // Static Runtime supports output types include None, Tensor and List/Tuple + // of Tensor + for (Value* output : graph->outputs()) { + VLOG(1) << "output: %" << output->debugName() + << " has type: " << output->type()->repr_str(); + auto kind = output->node()->kind(); + if (kind == prim::TupleConstruct || kind == prim::ListConstruct) { + for (Value* input : output->node()->inputs()) { + const auto& type = input->type(); + TORCH_CHECK( + type->cast() != nullptr, + "Static Runtime expects output type as List or Tuple of Tensor, but got List or Tuple of ", + type->repr_str()); + } + } else { + const auto& type = output->type(); + TORCH_CHECK( + type->cast() != nullptr || + type->cast() != nullptr, + "Static Runtime expects output type as None or Tensor, but got ", + type->repr_str()); + } + } +} + +// remove unused input 0 from graph +void RemoveSelfFromGraphInput(std::shared_ptr& graph) { + if (graph->inputs().at(0)->type()->is_module()) { + TORCH_CHECK(!graph->inputs().at(0)->hasUses()); + graph->eraseInput(0); + } +} + +// remove "self" from function schema +std::unique_ptr RemoveSelfFromSchema( + const c10::FunctionSchema& s) { + TORCH_CHECK(s.arguments().size() >= 1 && s.arguments()[0].name() == "self"); + std::vector args({s.arguments().begin() + 1, s.arguments().end()}); + return std::make_unique(s.cloneWithArguments(args)); +} + +// Returns two useful constructs: +// first: map each value to all values that are alive +// at the same time. +// second: set of all inputs/outputs/constants (always alive) +std::pair>, std::set> +LivenessMap(const std::shared_ptr& graph) { + std::unordered_map> liveness_map; + std::set always_alive; + + std::vector frontier; + // map live values to their deps, invariant: set.size() > 0 + std::unordered_map> live_values; + for (const auto& input : graph->inputs()) { + frontier.emplace_back(input); + always_alive.insert(input); + } + for (const auto& output : graph->outputs()) { + always_alive.insert(output); + } + + auto add_live_value = [&](Value* v) { + liveness_map[v] = {}; + + for (const auto& live_v : live_values) { + liveness_map.at(v).insert(live_v.first); + liveness_map.at(live_v.first).insert(v); + } + + // only add values to the live set if they + // have deps, otherwise they die immediately + if (v->uses().size()) { + live_values[v] = {}; + } + + for (const auto& u : v->uses()) { + const auto& node = u.user; + // track deps of this value + live_values.at(v).insert(node); + } + }; + + auto traverse_node = [&](Node* node, std::vector& dead) { + for (const auto& input : node->inputs()) { + // ignore constant values + if (input->node()->kind() == prim::Constant) { + always_alive.insert(input); + continue; + } + if (live_values.count(input)) { + live_values.at(input).erase(node); + if (!live_values.at(input).size()) { + dead.emplace_back(input); + } + } + } + }; + + for (const auto& node : graph->nodes()) { + for (const auto& v : node->outputs()) { + add_live_value(v); + } + + std::vector dead; + traverse_node(node, dead); + for (const auto& dead_value : dead) { + live_values.erase(dead_value); + } + } + + for (const auto& v : live_values) { + TORCH_CHECK(always_alive.count(v.first)); + } + + for (const auto& node : graph->nodes()) { + for (const auto& input : node->inputs()) { + for (const auto& output : node->outputs()) { + if (liveness_map.count(input) && liveness_map.count(output)) { + liveness_map.at(input).insert(output); + liveness_map.at(output).insert(input); + } + } + } + } + + return std::make_pair(liveness_map, always_alive); +} + +std::unordered_set GetOptimizableValues( + const std::shared_ptr& graph) { + std::unordered_set can_reuse; + // values used by unsupported ops (as either inputs or outputs) + // these need to be removed from "can_reuse" after analyzing all nodes + std::unordered_set cannot_reuse; + for (const auto& n : graph->nodes()) { + for (const auto& v : n->inputs()) { + if (canRunOutOfPlace(n) && canReuseInputs(n)) { + can_reuse.insert(v); + } else { + cannot_reuse.insert(v); + } + } + for (const auto& v : n->outputs()) { + if (canRunOutOfPlace(n) && canReuseOutputs(n)) { + can_reuse.insert(v); + } else { + cannot_reuse.insert(v); + } + } + } + for (auto v : cannot_reuse) { + can_reuse.erase(v); + } + return can_reuse; +} + +size_t AssignRegisters( + const std::shared_ptr& graph, + std::unordered_map& value_to_reg, + std::vector& values, + std::vector& input_regs, + std::vector& output_regs, + bool optimize_memory) { + auto lm = LivenessMap(graph); + auto optimizable_values = GetOptimizableValues(graph); + + size_t num_regs = 0; + size_t reused_regs = 0; + std::unordered_map> reg_to_val; + auto getReg = [&](Value* v) -> size_t { + if (!optimize_memory) { + return num_regs++; + } + TORCH_CHECK(!value_to_reg.count(v)); + auto iter = lm.first.find(v); + if (iter == lm.first.end()) { + return num_regs++; + } + if (!optimizable_values.count(v)) { + return num_regs++; + } + if (lm.second.count(v)) { + return num_regs++; + } + const auto& live_values = iter->second; + // iterate through all the allocated registers + // and check for potential re-use, greedily + for (const auto& v2r : value_to_reg) { + auto candidate_v = v2r.first; + + if (!optimizable_values.count(candidate_v)) { + continue; + } + if (lm.second.count(candidate_v)) { + continue; + } + + // Only re-use float* tensors + auto t = candidate_v->type()->cast(); + if (!t) { + continue; + } + // TODO audit this assumption (passes tests, but is scary) + if (t->scalarType() && *(t->scalarType()) != at::kFloat) { + continue; + } + // TODO + // if (*(t->scalarType()) != at::kFloat) { + // continue; + //} + if (!live_values.count(candidate_v)) { + bool already_used = false; + for (auto use : reg_to_val.at(v2r.second)) { + if (live_values.count(use)) { + already_used = true; + } + } + if (already_used) { + continue; + } + reused_regs++; + return v2r.second; + } + } + return num_regs++; + }; + + // assign register to Value* + for (Value* input : graph->inputs()) { + TORCH_CHECK(value_to_reg.count(input) == 0); + auto reg = getReg(input); + value_to_reg[input] = reg; + reg_to_val[reg].insert(input); + input_regs.push_back(reg); + } + for (Node* node : graph->nodes()) { + for (Value* input : node->inputs()) { + TORCH_CHECK(value_to_reg.count(input) > 0); + } + for (Value* output : node->outputs()) { + TORCH_CHECK( + value_to_reg.count(output) == 0, "the graph needs to be in SSA form"); + auto reg = getReg(output); + value_to_reg[output] = reg; + reg_to_val[reg].insert(output); } } + TORCH_CHECK(graph->outputs().size() > 0); + for (Value* output : graph->outputs()) { + TORCH_CHECK(value_to_reg.count(output) > 0); + output_regs.push_back(value_to_reg[output]); + } - SubgraphRewriter sr; - sr.RegisterRewritePattern( - R"IR( - graph(%x, %w, %s): - %r = aten::add(%x, %w, %s) - return (%r))IR", - R"IR( - graph(%x, %w, %s): - %y = static::add(%x, %w) - %r = static::mul(%y, %s) - return (%r))IR"); - sr.runOnGraph(graph_); + values.resize(value_to_reg.size()); + for (const auto& p : value_to_reg) { + values[p.second] = p.first; + } + return reused_regs; +} - // remove unused input 0 from graph - if (graph_->inputs().at(0)->type()->is_module()) { - if (!graph_->inputs().at(0)->hasUses()) { - graph_->eraseInput(0); +// Internal values are discarded after run if +// opts_.cleanup_activations is true. +void DeduceInternalValues( + const std::shared_ptr& graph, + const std::unordered_map& value_to_reg, + std::vector& internals) { + std::unordered_set outputs{ + graph->outputs().begin(), graph->outputs().end()}; + for (Node* node : graph->nodes()) { + if (node->kind() != prim::Constant) { + for (Value* output : node->outputs()) { + if (outputs.count(output) == 0) { + internals.push_back(value_to_reg.at(output)); + } + } } } +} +} // namespace - // fill workspace_ with constants - for (Node* node : graph_->nodes()) { +void InferenceModule::init() { + OptimizeGraph(graph); + CheckGraphEligibility(graph); + RemoveSelfFromGraphInput(graph); + reused_regs = AssignRegisters( + graph, + value_to_reg, + values, + input_regs, + output_regs, + opts.optimize_memory); + DeduceInternalValues(graph, value_to_reg, internals); +} + +InferenceModule::InferenceModule( + const torch::jit::Module& m, + InferenceModuleOptions opts_) + : module(m.copy()), graph(nullptr), schema(nullptr), opts(opts_) { + module.eval(); + module = freeze_module(module); + + Method method = module.get_method("forward"); + graph = method.graph(); + + const c10::FunctionSchema& s = method.function().getSchema(); + schema = RemoveSelfFromSchema(s); + + init(); +} + +InferenceModule::InferenceModule( + std::shared_ptr g, + InferenceModuleOptions opts_) + : module(), graph(std::move(g)), schema(nullptr), opts(opts_) { + init(); +} + +StaticRuntime::StaticRuntime( + const torch::jit::Module& m, + const StaticRuntimeOptions& opts) + : StaticRuntime(PrepareForStaticRuntime(m), opts) {} + +StaticRuntime::StaticRuntime( + std::shared_ptr m, + const StaticRuntimeOptions& opts) + : module_(m), opts_(opts) { + TORCH_CHECK( + module_ != nullptr, + "std::shared_ptr module_ cannot be nullptr") + + Graph* graph = module_->graph.get(); + std::unordered_map val_to_ival; + + // NB: create an unchanging std::vector we can reference + for (auto input : graph->inputs()) { + inputs_.emplace_back(); + } + for (auto i = 0; i < graph->inputs().size(); ++i) { + Value* input = graph->inputs()[i]; + val_to_ival[input] = &(inputs_[i]); + } + + // fill workspace_ with constants and create ProcessedNodes + // NB: before optimizing the order of execution, ensure that the + // memory optimization pass (LivenessMap + AssignRegisters) is + // aware of the new order! + + // Fill constants first, so we have a std::vector we can reference + // later + for (Node* node : graph->nodes()) { + if (node->kind() != prim::Constant) { + continue; + } + auto* v = node->output(); + TORCH_CHECK(v->type()->kind() != FunctionType::Kind); + constants_.emplace_back(toIValue(v).value()); + } + { + int i = 0; + for (Node* node : graph->nodes()) { + if (node->kind() != prim::Constant) { + continue; + } + auto* v = node->output(); + val_to_ival[v] = &(constants_[i++]); + } + } + for (Node* node : graph->nodes()) { if (node->kind() == prim::Constant) { - CHECK(node->output()->type()->kind() != FunctionType::Kind); - workspace_[node->output()] = toIValue(node->output()).value(); - } else { - nodes_.emplace_back(node); + continue; } + std::vector inputs; + for (Value* input : node->inputs()) { + inputs.emplace_back(val_to_ival.at(input)); + } + nodes_.emplace_back( + ProcessedNode(node, std::move(inputs), opts.enable_out_variant)); + for (auto i = 0; i < node->outputs().size(); ++i) { + val_to_ival[node->outputs()[i]] = &nodes_.back().Output(i); + } + } + for (auto output : graph->outputs()) { + outputs_.emplace_back(val_to_ival.at(output)); } } +size_t StaticRuntime::num_outputs() const { + return module_->output_regs.size(); +} + std::vector StaticRuntime::run( const std::vector& inps) { - // Container for inputs, outputs, and activations (excluding parameters) + std::vector stack; + stack.resize(inps.size()); + for (size_t i = 0; i < inps.size(); i++) { + stack[i] = inps[i]; + } + + c10::IValue v = run(stack, std::unordered_map()); + + std::vector out; - int start = 0; - if (graph_->inputs().size() != inps.size()) { - start = 1; - CHECK_EQ(graph_->inputs().size(), inps.size() + 1); - CHECK((graph_->inputs().at(0)->type()->is_module())); - workspace_[graph_->inputs()[0]] = module_._ivalue(); + if (v.isTuple()) { + auto t = v.toTuple(); + for (const auto& el : t->elements()) { + out.emplace_back(el.toTensor()); + } + } else { + out.emplace_back(v.toTensor()); } + return out; +} - for (size_t i = 0; i < inps.size(); i++) { - workspace_[graph_->inputs()[i + start]] = inps[i]; +c10::IValue StaticRuntime::run( + const std::vector& args, + const std::unordered_map& kwargs) { + // We assume inference workloads, so we do not need + // autograd. Enabling this is a significant win on dispatcher + // overhead because it saves a round of dispatch for at least some + // functions, such as resize_ and resize_as_. + at::AutoNonVariableTypeMode non_var_type_mode(true); + + if (planner_) { + planner_->allocate(); } - for (const auto& n : nodes_) { - n.run(workspace_); + if (!kwargs.empty()) { + // This is not ideal + TORCH_CHECK( + module_->schema != nullptr, + "Schema is not available. Consider creating the Static Runtime " + "with StaticRuntime(const torch::jit::Module& m) instead."); + std::vector s = args; + module_->schema->checkAndNormalizeInputs(s, kwargs); + for (size_t i = 0; i < s.size(); i++) { + Input(i) = s[i]; + } + } else { + for (size_t i = 0; i < args.size(); i++) { + Input(i) = args[i]; + } } - std::vector out; - for (Value* output : graph_->outputs()) { - const IValue& v = workspace_[output]; - if (v.isTuple()) { - auto t = v.toTuple(); - for (const auto& el : t->elements()) { - out.emplace_back(el.toTensor()); + // NB: before optimizing the order of execution, ensure that the + // memory optimization pass (LivenessMap + AssignRegisters) is + // aware of the new order! + for (auto& n : nodes_) { + n.run(); + } + + if (opts_.cleanup_activations) { + if (!planner_) { + std::unordered_map> shared; + planner_ = std::make_unique(this, shared); + } + planner_->deallocate(); + } + + // no need to keep references of outputs in static runtime anymore + if (num_outputs() > 1) { + std::vector outputs; + outputs.reserve(num_outputs()); + for (auto i = 0; i < num_outputs(); ++i) { + outputs.emplace_back(Output(i)); + } + return c10::ivalue::Tuple::create(outputs); + } + return Output(0); +} + +void StaticRuntime::benchmark( + const std::vector& args, + const std::unordered_map& kwargs, + const int warmup_runs, + const int main_runs) { + float time_per_iter = benchmark_model(args, kwargs, warmup_runs, main_runs); + std::cout << "Static runtime ms per iter: " << time_per_iter + << ". Iters per second: " << 1000.0 / time_per_iter << std::endl; + + IndividualMetrics results = + benchmark_individual_ops(args, kwargs, warmup_runs, main_runs); + std::cout << "Setting up took " << results.setup_time << " ms" << std::endl; + + for (size_t i = 0; i < nodes_.size(); i++) { + const Node* node = nodes_[i].get_node(); + std::cout << "Node #" << i << ": " << results.time_per_node[i] + << " ms/iter, "; + node->print(std::cout, 0, nullptr, false); + } + + std::vector> time_per_node_type_vec{ + results.time_per_node_type.begin(), results.time_per_node_type.end()}; + std::sort( + time_per_node_type_vec.begin(), + time_per_node_type_vec.end(), + [](auto& left, auto& right) { return left.second > right.second; }); + + std::cout << "Time per node type:" << std::endl; + for (const auto& p : time_per_node_type_vec) { + const std::string& kind = p.first; + const double ms = p.second; + std::cout << std::setw(15) << ms << " ms. " << std::setw(10) + << results.percent_per_node_type[kind] << "%. " << kind << " (" + << results.instances_per_node_type[kind] << " nodes)" + << std::endl; + } + std::cout << std::setw(15) << results.total_time << " ms. in Total" + << std::endl; + + if (planner_) { + std::cout << "Total memory managed: " << planner_->total_managed() + << " bytes" << std::endl; + } + if (module_->opts.optimize_memory) { + std::cout << "Total number of reused registers: " << module_->reused_regs + << std::endl; + } +} + +float StaticRuntime::benchmark_model( + const std::vector& args, + const std::unordered_map& kwargs, + const int warmup_runs, + const int main_runs) { + TORCH_CHECK(warmup_runs >= 0 && main_runs >= 1); + + for (int i = 0; i < warmup_runs; i++) { + run(args, kwargs); + } + caffe2::Timer timer; + for (int i = 0; i < main_runs; i++) { + run(args, kwargs); + } + float millis = timer.MilliSeconds(); + return millis / static_cast(main_runs); +} + +StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops( + const std::vector& args, + const std::unordered_map& kwargs, + const int warmup_runs, + const int main_runs) { + TORCH_CHECK(warmup_runs >= 0 && main_runs >= 1); + + // See comment on above use of AutoNonVariableTypeMode for + // explanation. + at::AutoNonVariableTypeMode non_var_type_mode(true); + + IndividualMetrics results; + results.total_time = 0.0; + results.time_per_node.resize(nodes_.size(), 0); + + // setup time + caffe2::Timer timer; + std::vector stack(args); + if (!kwargs.empty()) { + // This is not ideal + TORCH_CHECK( + module_->schema != nullptr, + "Schema is not available. Consider creating the Static Runtime " + "with StaticRuntime(const torch::jit::Module& m) instead."); + module_->schema->checkAndNormalizeInputs(stack, kwargs); + } + for (size_t i = 0; i < stack.size(); i++) { + Input(i) = stack[i]; + } + results.setup_time = timer.MilliSeconds(); + + // warmup runs + for (int i = 0; i < warmup_runs; i++) { + run(args, kwargs); + } + + // main runs + for (int i = 0; i < main_runs; i++) { + if (planner_) { + planner_->allocate(); + } + for (size_t j = 0; j < nodes_.size(); j++) { + timer.Start(); + nodes_[j].run(); + float millis = timer.MilliSeconds(); + results.time_per_node[j] += millis; + } + if (opts_.cleanup_activations) { + if (!planner_) { + std::unordered_map> shared; + planner_ = std::make_unique(this, shared); + } + planner_->deallocate(); + } + } + + // post processing + for (size_t i = 0; i < nodes_.size(); i++) { + const Node* node = nodes_[i].get_node(); + std::string kind = std::string(node->kind().toQualString()); + results.time_per_node[i] /= static_cast(main_runs); + results.time_per_node_type[kind] += results.time_per_node[i]; + results.instances_per_node_type[kind]++; + results.total_time += results.time_per_node[i]; + } + for (const auto& p : results.time_per_node_type) { + const std::string& kind = p.first; + results.percent_per_node_type[kind] = p.second / results.total_time * 100; + } + return results; +} + +MemoryPlanner::MemoryPlanner( + StaticRuntime* runtime, + std::unordered_map> should_share) { + // collect register indices of outputs of ops with out variant + std::unordered_set managed_values; + std::unordered_set unmanaged_value_set; + for (ProcessedNode& pnode : runtime->get_nodes()) { + if (pnode.has_out_variant()) { + // Types are stored in the underlying TorchScript IR + for (Value* out : pnode.get_node()->outputs()) { + if (out->type()->cast()) { + managed_values.insert(out); + } } } else { - out.emplace_back(v.toTensor()); + for (auto i = 0; i < pnode.outputs().size(); ++i) { + unmanaged_value_set.insert(&pnode.Output(i)); + } + } + } + + const InferenceModule* module = runtime->get_inference_module(); + + // remove model outputs from managed_values + for (Value* output : module->graph->outputs()) { + managed_values.erase(output); + } + for (IValue* output : runtime->outputs()) { + unmanaged_value_set.erase(output); + } + for (IValue* out : unmanaged_value_set) { + unmanaged_values_.emplace_back(out); + } + + // remove tensors in output List/Tuple from managed_values + for (Value* output : module->graph->outputs()) { + Node* output_node = output->node(); + if (output_node->kind() == prim::TupleConstruct || + output_node->kind() == prim::ListConstruct) { + for (Value* input : output_node->inputs()) { + managed_values.erase(input); + } + } + } + + // some Values should share storage, this map will + // keep track of the index into managed_storage_ + std::unordered_map shared; + + // Snapshot of the current memory state + for (const auto& pnode : runtime->get_nodes()) { + for (auto i = 0; i < pnode.outputs().size(); ++i) { + const auto& ival = pnode.outputs()[i]; + const auto& val = pnode.get_node()->outputs()[i]; + if (managed_values.count(val)) { + TORCH_CHECK(ival.isTensor()); + auto* impl = ival.toTensor().storage().unsafeGetStorageImpl(); + if (shared.count(val)) { + managed_storage_[shared.at(val)].second.emplace_back(impl); + } else { + auto p = + std::make_pair>(0, {impl}); + managed_storage_.emplace_back(std::move(p)); + // first of a group, update the shared map with the index + if (should_share.count(val)) { + for (auto v : should_share.at(val)) { + shared[v] = managed_storage_.size() - 1; + } + } + } + } } } - return out; } -ProcessedNode::ProcessedNode(Node* node) : node_(node) { +// Don't change the size if it is already aligned, otherwise increase the size +// to make it aligned. +size_t MemoryPlanner::compute_aligned_tensor_size(size_t nbytes) { + // Note: everything below is size_t + return (nbytes + c10::gAlignment - 1) & (~(c10::gAlignment - 1)); +} + +at::DataPtr MemoryPlanner::allocate_buffer(size_t size) { + at::Allocator* allocator = c10::GetCPUCachingAllocator(); + return allocator->allocate(size); +} + +void MemoryPlanner::allocate() { + if (managed_bytes_ == 0) { + return; + } + + buffer_ = allocate_buffer(managed_bytes_); + + size_t offset = 0; + uint8_t* start = static_cast(buffer_.get()); + + for (const auto& ms : managed_storage_) { + auto tensor_size = ms.first; + if (tensor_size == 0) { + continue; + } + const auto& impls = ms.second; + DCHECK_LE(offset + tensor_size, managed_bytes_); + void* src = static_cast(start + offset); + + for (auto& impl : impls) { + impl->set_data_ptr(at::DataPtr(src, src, nullptr, impl->device())); + impl->set_nbytes(tensor_size); + } + + offset += tensor_size; + } + DCHECK_EQ(offset, managed_bytes_); +} + +void MemoryPlanner::deallocate() { + managed_bytes_ = 0; + + // free memory used by outputs of ops in out variants + // but keep the TensorImpl and StorageImpl around + for (auto& ms : managed_storage_) { + const auto& impls = ms.second; + size_t max = 0; + for (auto& impl : impls) { + size_t current_size = compute_aligned_tensor_size(impl->nbytes()); + impl->reset(); + max = std::max(max, current_size); + } + ms.first = max; + managed_bytes_ += max; + } + for (auto& iv : unmanaged_values_) { + *iv = IValue(); + } + buffer_ = {}; +} + +ProcessedNode::ProcessedNode( + Node* node, + std::vector&& inputs, + bool enable_out_variants) + : node_(node), inputs_(std::move(inputs)) { + // TODO leverage type information + outputs_.resize(node->outputs().size()); if (node->kind() != prim::ListConstruct && - node->kind() != prim::TupleConstruct) { + node->kind() != prim::TupleConstruct && + node->kind() != prim::ListUnpack) { const Operator& op = node->getOperator(); - CHECK(op.hasOperation()); + TORCH_CHECK(op.hasOperation()); op_ = op.getOperation(node); } + if (enable_out_variants && canRunOutOfPlace(node)) { + fn_ = getOutOfPlaceOperation(node); + std::ostringstream ss; + node->print(ss, 0, nullptr, false); + VLOG(1) << "Switch to out variant for node: " << ss.str(); + } else if (canRunNatively(node)) { + native_fn_ = getNativeOperation(node); + std::ostringstream ss; + node->print(ss, 0, nullptr, false); + VLOG(1) << "Switch to native impl for node: " << ss.str(); + } else { + std::ostringstream ss; + node->print(ss, 0, nullptr, false); + VLOG(1) << "Fallback interpreter for node: " << ss.str(); + } } -void ProcessedNode::run(StaticRuntime::ConstantMap& workspace) const { - if (use_stack_) { +void ProcessedNode::run() { + if (fn_) { + fn_(this); + } else if (native_fn_) { + native_fn_(this); + } else { std::vector stack; const size_t size = node_->inputs().size(); stack.reserve(size); for (size_t i = 0; i < size; i++) { - Value* v = node_->inputs()[i]; - auto f = workspace.find(v); - TORCH_CHECK( - f != workspace.end(), - "Workspace does not contain Value ", - v->debugName()); - stack.emplace_back(f->second); - } - if (op_) { - (*op_)(&stack); - } else { - if (node_->kind() == prim::ListConstruct) { - listConstruct( - stack, - node_->output()->type()->expect(), - node_->inputs().size()); - } else if (node_->kind() == prim::TupleConstruct) { - bool named = - node_->output()->type()->expect()->name().has_value(); - if (named) { - namedTupleConstruct( - stack, - node_->output()->type()->expect(), - node_->inputs().size()); - } else { - tupleConstruct(stack, node_->inputs().size()); - } - } else { - TORCH_CHECK(0, "Unhandled operation!", node_->kind().toQualString()); - } + stack.emplace_back(Input(i)); } + + DCHECK(op_); + op_->operator()(&stack); + DCHECK_EQ(stack.size(), node_->outputs().size()); for (auto i = 0; i < node_->outputs().size(); i++) { - workspace[node_->outputs()[i]] = stack[i]; + Output(i) = std::move(stack[i]); } - } else { - TORCH_CHECK(0, "Non-stack execution not yet implemented"); } } diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index 2274d2883fb57..a8e2cb3668d26 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -2,59 +2,298 @@ #include #include +#include #include #include #include #include -#ifdef FBCODE_CAFFE2 -#include -#endif - namespace torch { namespace jit { +struct TORCH_API InferenceModuleOptions { + bool optimize_memory{true}; // TODO remove when logic moves to runtime +}; + +struct TORCH_API StaticRuntimeOptions { + bool cleanup_activations{true}; + bool enable_out_variant{true}; +}; + +/// Static runime supports two execution modes. +/// +/// Mode 1: single-threaded with no parallelism except for intra-op parallelism +/// For this mode, you can do either: +/// @code +/// // m is the TorchScript module +/// auto runtime = StaticRuntime(m, opts); +/// auto output = runtime.run(args, kwargs); +/// @endcode +/// or +/// @code +/// auto mod = PrepareForStaticRuntime(m); +/// auto runtime = StaticRuntime(mod, opts); +/// auto output = runtime.run(args, kwargs); +/// @endcode +/// Mode 2: similar to data parallelism, run the same model for different inputs +/// on different threads at the same time. In this case, run +/// PrepareForStaticRuntime to prepare the graph for Static Runtime. You +/// should have one InferenceModule instance per model, and one Static Runtime +/// instance per running thread. To avoiding creating StaticRuntime on the fly, +/// use a synchronized stack (i.e. boost::lockfree::stack) to cache all the +/// Static Runtime instances in your code. +/// @code +/// // initialization +/// auto mod = PrepareForStaticRuntime(m); +/// // 128 is good for most cases. Pick a number that works for you +/// boost::lockfree::stack, +/// boost::lockfree::fixed_sized> pool(128); +/// +/// // inference +/// std::shared_ptr runtime = nullptr; +/// pool.pop(runtime); +/// if (!runtime) { +/// runtime = std::make_shared(mod, opts); +/// } +/// auto output = runtime->run(args, kwargs); +/// pool.push(runtime); +/// @endcode +/// + +// Group readonly data structures into InferenceModule +struct TORCH_API InferenceModule { + public: + explicit InferenceModule(const torch::jit::Module& m, InferenceModuleOptions); + explicit InferenceModule( + std::shared_ptr g, + InferenceModuleOptions); + torch::jit::Module module; + std::shared_ptr graph; + std::unique_ptr schema; + + std::unordered_map value_to_reg; + std::vector values; // useful for debugging + std::vector input_regs; // inputs to the graph + std::vector output_regs; // outputs of the graph + std::vector internals; + size_t reused_regs = 0; + InferenceModuleOptions opts; + + private: + void init(); +}; + +TORCH_API void PrepareGraphForStaticRuntime( + std::shared_ptr g); + +inline TORCH_API std::shared_ptr PrepareForStaticRuntime( + const torch::jit::Module& m, + InferenceModuleOptions opts = InferenceModuleOptions()) { + return std::make_shared(m, opts); +} + +inline TORCH_API std::shared_ptr PrepareForStaticRuntime( + std::shared_ptr g, + InferenceModuleOptions opts = InferenceModuleOptions()) { + return std::make_shared(g, opts); +} + +class MemoryPlanner; class ProcessedNode; class TORCH_API StaticRuntime { public: - explicit StaticRuntime(std::shared_ptr g) - : graph_(std::move(g)) {} + // InferenceModule m is created by PrepareForStaticRuntime + explicit StaticRuntime( + std::shared_ptr m, + const StaticRuntimeOptions& opts = StaticRuntimeOptions()); - explicit StaticRuntime(const torch::jit::Module& m); + // m is unoptimized + explicit StaticRuntime( + const torch::jit::Module& m, + const StaticRuntimeOptions& opts = StaticRuntimeOptions()); std::vector run(const std::vector& inps); -#ifdef FBCODE_CAFFE2 - using ConstantMap = folly::F14FastMap; -#else - using ConstantMap = std::unordered_map; -#endif + // This interface only works module_ that has a non-empty TorchScript module + // member; otherwise use the above interface + c10::IValue run( + const std::vector& args, + const std::unordered_map& kwargs); - private: - torch::jit::Module module_; - std::shared_ptr graph_; + void benchmark( + const std::vector& args, + const std::unordered_map& kwargs, + const int warmup_runs, + const int main_runs); - // Static runtime states - // Value table (including weights) - ConstantMap workspace_; + float benchmark_model( + const std::vector& args, + const std::unordered_map& kwargs, + const int warmup_runs, + const int main_runs); + + struct IndividualMetrics { + float setup_time; + float total_time; + std::vector time_per_node; + std::unordered_map time_per_node_type; + std::unordered_map percent_per_node_type; + std::unordered_map instances_per_node_type; + }; + + IndividualMetrics benchmark_individual_ops( + const std::vector& args, + const std::unordered_map& kwargs, + const int warmup_runs, + const int main_runs); + + const InferenceModule* get_inference_module() { + return module_.get(); + } + + const std::vector& get_nodes() const { + return nodes_; + } + + std::vector& get_nodes() { + return nodes_; + } + + const std::vector& get_registers() { + return reg_; + } + + size_t num_outputs() const; + + inline const std::vector& outputs() const { + return outputs_; + } + private: + // Static runtime states + std::shared_ptr module_; + StaticRuntimeOptions opts_; + // IValue table (including inputs, outputs, intermediates, and weights) + std::vector reg_; + std::vector constants_; + std::vector inputs_; + std::vector outputs_; // The nodes we need to run std::vector nodes_; + + // Memory planning is only enabled if opts_.cleanup_activations is true. + // Otherwise, the memory used by activations is cached inside the static + // runtime. + std::unique_ptr planner_; + + // Input is readwrite + IValue& Input(size_t i) { + DCHECK(i < inputs_.size()); + return inputs_[i]; + } + + // Output is readonly. The writing process happens inside ProcessedNodes + const IValue& Output(size_t i) const { + DCHECK(i < outputs_.size()); + return *outputs_[i]; + } +}; + +/// There are three types of ops in a processed graph in Static Runtime: +/// 1. op with _out variant +/// 2. view producing op +/// 3. tensor producing op (could be replaced with type 1 by adding the _out +/// variant to Static Runtime) +/// The memory planner only manages tensors that are outputs of type 1 ops, +/// because type 2 ops don't incur memory allocation and for type 3, the output +/// tensors are allocated inside the operator and can't be directly managed by +/// memory planner. +/// +/// Memory planner tries to minimize the number of memory allocations by +/// tracking the unique StorageImpls of the output tensors of ops with _out +/// variants. It tries to do this in several steps: +/// 1. record the max memory usage for each StorageImpl at the end of each +/// iteration +/// 2. in the next iteration, allocate the buffer for the max total usage and +/// compute the offset of each allocation with regard to the single memory +/// buffer, optionally reusing memory. In the first iteration, we rely on +/// the default allocator for memory allocation. +/// 3. free the buffer at the end of each iteration +/// Steps 1 and 3 are handled by `deallocate()`, and step 2 by `allocate()`. +/// Only models with simple output types are supported, i.e. None, Tensor or +/// List/Tuple of Tensors. Complex output types such as List of Lists are not +/// supported. + +class MemoryPlanner { + public: + explicit MemoryPlanner( + StaticRuntime* runtime, + std::unordered_map> should_share); + + void allocate(); + void deallocate(); + size_t total_managed() const { + return managed_bytes_; + } + + private: + std::vector unmanaged_values_; + // each pair contains the size (in bytes) of data to be allocated + // and a vector of StorageImpl's that should be backed by that same data + // Thus, if memonger is disabled, all vectors are of size 1. + std::vector>> + managed_storage_; + size_t managed_bytes_{0}; + at::DataPtr buffer_; // allocated each time we call Run() + + static size_t compute_aligned_tensor_size(size_t nbytes); + static at::DataPtr allocate_buffer(size_t size); }; class ProcessedNode { public: - ProcessedNode(Node* n); - void run(StaticRuntime::ConstantMap& workspace) const; + ProcessedNode( + Node* n, + std::vector&& inputs, + bool enable_out_variant); + + void run(); + Node* get_node() const { return node_; } + // Input is readonly + const IValue& Input(size_t i) const { + DCHECK(i < inputs_.size()); + return *inputs_[i]; + } + + // Output is readwrite + IValue& Output(size_t i) { + DCHECK(i < outputs_.size()); + return outputs_[i]; + } + + const std::vector& outputs() const { + return outputs_; + } + + const std::vector& inputs() const { + return inputs_; + } + + bool has_out_variant() const { + return static_cast(fn_); + } + private: Node* node_; c10::optional op_; - // if false, we have an optimized version - bool use_stack_ = true; + std::function fn_; + std::function native_fn_; + std::vector inputs_; // unowned + std::vector outputs_; // TODO make list for safety }; } // namespace jit diff --git a/torch/csrc/jit/runtime/static/init.cpp b/torch/csrc/jit/runtime/static/init.cpp index d57242d6b68c7..9027549aedab2 100644 --- a/torch/csrc/jit/runtime/static/init.cpp +++ b/torch/csrc/jit/runtime/static/init.cpp @@ -1,4 +1,7 @@ #include + +#include +#include #include namespace torch { @@ -6,14 +9,85 @@ namespace jit { void initStaticRuntimeBindings(PyObject* module) { auto m = py::handle(module).cast(); - py::class_(m, "StaticRuntime").def("run", &StaticRuntime::run); + py::class_ static_runtime(m, "StaticRuntime"); + py::class_( + static_runtime, "IndividualMetrics") + .def_readonly("setup_time", &StaticRuntime::IndividualMetrics::setup_time) + .def_readonly("total_time", &StaticRuntime::IndividualMetrics::total_time) + .def_readonly( + "time_per_node", &StaticRuntime::IndividualMetrics::time_per_node) + .def_readonly( + "time_per_node_type", + &StaticRuntime::IndividualMetrics::time_per_node_type) + .def_readonly( + "percent_per_node_type", + &StaticRuntime::IndividualMetrics::percent_per_node_type) + .def_readonly( + "instances_per_node_type", + &StaticRuntime::IndividualMetrics::instances_per_node_type); + static_runtime + .def( + "run", + py::overload_cast&>( + &StaticRuntime::run)) + .def( + "run", + [](StaticRuntime& self, + const std::vector& args, + const std::unordered_map& kwargs) { + std::vector arg_ivalues{args.begin(), args.end()}; + std::unordered_map kwarg_ivalues{ + kwargs.begin(), kwargs.end()}; + c10::IValue ret = self.run(arg_ivalues, kwarg_ivalues); + return toPyObject(ret); + }) + .def( + "benchmark", + [](StaticRuntime& self, + const std::vector& args, + const std::unordered_map& kwargs, + const int warmup_runs, + const int main_runs) { + std::vector arg_ivalues{args.begin(), args.end()}; + std::unordered_map kwarg_ivalues{ + kwargs.begin(), kwargs.end()}; + self.benchmark(arg_ivalues, kwarg_ivalues, warmup_runs, main_runs); + }) + .def( + "benchmark_individual_ops", + [](StaticRuntime& self, + const std::vector& args, + const std::unordered_map& kwargs, + const int warmup_runs, + const int main_runs) { + std::vector arg_ivalues{args.begin(), args.end()}; + std::unordered_map kwarg_ivalues{ + kwargs.begin(), kwargs.end()}; + return self.benchmark_individual_ops( + arg_ivalues, kwarg_ivalues, warmup_runs, main_runs); + }); m.def( "_jit_to_static_runtime", - [](const std::shared_ptr& g) { - return StaticRuntime(g); + [](std::shared_ptr g) { + return StaticRuntime(PrepareForStaticRuntime(g)); }) - .def("_jit_to_static_runtime", [](const torch::jit::Module& m) { - return StaticRuntime(m); + .def( + "_jit_to_static_runtime", + [](const torch::jit::Module& m) { + return StaticRuntime(PrepareForStaticRuntime(m)); + }) + .def( + "_fuse_to_static_runtime", + [](torch::jit::Module& module) { + module.eval(); + module = freeze_module(module); + + Method method = module.get_method("forward"); + auto graph = method.graph(); + fuseStaticSubgraphs(graph); + }) + .def("_fuse_to_static_runtime", [](std::shared_ptr g) { + fuseStaticSubgraphs(g); }); } diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp new file mode 100644 index 0000000000000..6fd7c7bb28598 --- /dev/null +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -0,0 +1,514 @@ +#include + +#include +#include +#include +#include + +namespace torch { +namespace jit { + +C10_DEFINE_REGISTRY(SROperatorRegistry, SROperatorFunctor); + +bool canRunOutOfPlace(Node* n) { + auto op_name = std::string(n->kind().toQualString()); + return SROperatorRegistry()->Has(op_name); +} + +bool canReuseInputs(Node* n) { + auto op_name = std::string(n->kind().toQualString()); + DCHECK(SROperatorRegistry()->Has(op_name)); + return SROperatorRegistry()->Create(op_name)->CanReuseInput(); +} + +bool canReuseOutputs(Node* n) { + auto op_name = std::string(n->kind().toQualString()); + DCHECK(SROperatorRegistry()->Has(op_name)); + return SROperatorRegistry()->Create(op_name)->CanReuseOutput(); +} + +// TODO: expand to include all view producing ops, mostly in +// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp +bool canRunNatively(Node* n) { + // In alphabetical order + const static std::unordered_set native_nodes{ + "aten::flatten", + "aten::reshape", + "aten::slice", + "aten::transpose", + "aten::to", + "prim::ListConstruct", + "prim::ListUnpack", + "prim::TupleConstruct"}; + auto str = std::string(n->kind().toQualString()); + if (!native_nodes.count(str)) { + return false; + } + if (str == "aten::to") { + return n->inputs().size() == 5; + } + return true; +} + +// TODO: PLEASE DON'T COPY PASTE THIS, this is copy pasted +// generated code to unblock, need to make this nicer +struct static_add final : public at::native::structured_add_out { + static_add(at::Tensor& output) : output_(output) {} + void set_output( + int64_t output_idx, + at::IntArrayRef sizes, + at::IntArrayRef strides, + at::TensorOptions options, + at::DimnameList names) override { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx == 0); + // NB: do NOT use resize_output as it will complain if not zero sized. + at::native::resize_(output_, sizes); + if (!strides.empty()) { + TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value()); + output_.as_strided_(sizes, strides); + } else if (options.memory_format_opt().has_value()) { + output_.unsafeGetTensorImpl()->empty_tensor_restride( + *options.memory_format_opt()); + } + } + const at::Tensor& maybe_get_output(int64_t output_idx) override { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx == 0); + return output_; + } + at::Tensor& output_; +}; + +REGISTER_OPERATOR_FUNCTOR(aten::add, aten_add, [](Node* n) -> SROperator { + return [](ProcessedNode* p_node) { + auto& in0_t = p_node->Input(0).toTensor(); + auto& in1_t = p_node->Input(1).toTensor(); + auto in2_s = p_node->Input(2).toScalar(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = create_empty_from(in0_t); + } + auto& out_t = p_node->Output(0).toTensor(); + static_add op{out_t}; + op.meta(in0_t, in1_t, in2_s); + op.impl(in0_t, in1_t, in2_s, out_t); + }; +}); + +REGISTER_OPERATOR_FUNCTOR(aten::mul, aten_mul, [](Node* n) -> SROperator { + return [](ProcessedNode* p_node) { + auto& in0_t = p_node->Input(0).toTensor(); + auto& in1_t = p_node->Input(1).toTensor(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = create_empty_from(in0_t); + } + auto& out_t = p_node->Output(0).toTensor(); + out_t.resize_({0}); + at::native::mul_out(out_t, in0_t, in1_t); + }; +}); + +REGISTER_OPERATOR_FUNCTOR(aten::addmm, aten_addmm, [](Node* n) -> SROperator { + return [](ProcessedNode* p_node) { + auto& in0_t = p_node->Input(0).toTensor(); + auto& in1_t = p_node->Input(1).toTensor(); + auto& in2_t = p_node->Input(2).toTensor(); + auto in3_s = p_node->Input(3).toScalar(); + auto in4_s = p_node->Input(4).toScalar(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = create_empty_from(in0_t); + } + auto& out_t = p_node->Output(0).toTensor(); + out_t.resize_({0}); + at::native::addmm_cpu_out(out_t, in0_t, in1_t, in2_t, in3_s, in4_s); + }; +}); + +REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator { + return [](ProcessedNode* p_node) { + auto& in0_t = p_node->Input(0).toTensor(); + auto in1_s = p_node->Input(1).toScalar(); + auto in2_s = p_node->Input(2).toScalar(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = create_empty_from(in0_t); + } + auto& out_t = p_node->Output(0).toTensor(); + out_t.resize_({0}); + at::native::clamp_out(out_t, in0_t, in1_s, in2_s); + }; +}); + +REGISTER_OPERATOR_FUNCTOR(aten::bmm, aten_bmm, [](Node* n) -> SROperator { + return [](ProcessedNode* p_node) { + auto& in0_t = p_node->Input(0).toTensor(); + auto& in1_t = p_node->Input(1).toTensor(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = create_empty_from(in0_t); + } + auto& out_t = p_node->Output(0).toTensor(); + out_t.resize_({0}); + at::native::bmm_out_cpu(out_t, in0_t, in1_t); + }; +}); + +REGISTER_OPERATOR_FUNCTOR( + aten::nan_to_num, + aten_nan_to_num, + [](Node* n) -> SROperator { + return [](ProcessedNode* p_node) { + auto input_size = p_node->inputs().size(); + auto& in0_t = p_node->Input(0).toTensor(); + double in1_d = input_size > 1 ? p_node->Input(1).toDouble() : 0; + double in2_d = input_size > 2 ? p_node->Input(2).toDouble() + : std::numeric_limits::infinity(); + double in3_d = input_size > 3 + ? p_node->Input(3).toDouble() + : -std::numeric_limits::infinity(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = create_empty_from(in0_t); + } + auto& out_t = p_node->Output(0).toTensor(); + out_t.resize_({0}); + at::native::nan_to_num_out(out_t, in0_t, in1_d, in2_d, in3_d); + }; + }); +REGISTER_OPERATOR_FUNCTOR(aten::cat, aten_cat, [](Node* n) -> SROperator { + return [](ProcessedNode* p_node) { + auto in0_tl = p_node->Input(0).toTensorVector(); + auto in1_i = p_node->Input(1).toInt(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = create_empty_from(in0_tl[0]); + } + auto& out_t = p_node->Output(0).toTensor(); + out_t.resize_({0}); + at::native::_cat_out_cpu(out_t, in0_tl, in1_i); + }; +}); +REGISTER_OPERATOR_FUNCTOR(aten::tanh, aten_tanh, [](Node* n) -> SROperator { + return [](ProcessedNode* p_node) { + auto& in0_t = p_node->Input(0).toTensor(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = create_empty_from(in0_t); + } + auto& out_t = p_node->Output(0).toTensor(); + out_t.resize_({0}); + at::native::tanh_out(out_t, in0_t); + }; +}); + +// Split out into a function to appease MSVC's pre-processor +SROperator aten_stack(Node* n) { + return [](ProcessedNode* p_node) { + auto inputs = p_node->Input(0).toTensorVector(); + auto dim = p_node->Input(1).toInt(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = create_empty_from(inputs[0]); + } +#ifndef NDEBUG + at::IntArrayRef entry_shape = inputs[0].sizes(); + for (auto i = 1; i < inputs.size(); i++) { + TORCH_CHECK( + inputs[i].sizes() == entry_shape, + "stack expects each tensor to be equal size, but got ", + entry_shape, + " at entry 0 and ", + inputs[i].sizes(), + " at entry ", + i); + } +#endif + for (auto i = 0; i < inputs.size(); i++) { + inputs[i] = inputs[i].unsqueeze(dim); + } + auto& out_t = p_node->Output(0).toTensor(); + out_t.resize_({0}); + at::native::_cat_out_cpu(out_t, inputs, dim); + }; +} + +REGISTER_OPERATOR_FUNCTOR(aten::stack, aten_stack, aten_stack); + +REGISTER_OPERATOR_FUNCTOR( + aten::sigmoid, + aten_sigmoid, + [](Node* n) -> SROperator { + return [](ProcessedNode* p_node) { + auto& in0_t = p_node->Input(0).toTensor(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = create_empty_from(in0_t); + } + auto& out_t = p_node->Output(0).toTensor(); + out_t.resize_({0}); + at::native::sigmoid_out(out_t, in0_t); + }; + }); +REGISTER_OPERATOR_FUNCTOR( + aten::leaky_relu, + aten_leaky_relu, + [](Node* n) -> SROperator { + auto in1 = toIValue(n->inputs()[1]); + if (in1) { + auto in1_s = in1->toScalar(); + return [=](ProcessedNode* p_node) { + auto& in0_t = p_node->Input(0).toTensor(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = create_empty_from(in0_t); + } + auto& out_t = p_node->Output(0).toTensor(); + at::native::leaky_relu_out(out_t, in0_t, in1_s); + }; + } else { + return [](ProcessedNode* p_node) { + auto& in0_t = p_node->Input(0).toTensor(); + auto in1_s = p_node->Input(1).toScalar(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = create_empty_from(in0_t); + } + auto& out_t = p_node->Output(0).toTensor(); + at::native::leaky_relu_out(out_t, in0_t, in1_s); + }; + } + }); +REGISTER_OPERATOR_FUNCTOR(aten::relu, aten_relu, [](Node* n) -> SROperator { + return [](ProcessedNode* p_node) { + auto& in0_t = p_node->Input(0).toTensor(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = create_empty_from(in0_t); + } + auto& out_t = p_node->Output(0).toTensor(); + out_t.resize_({0}); + at::native::threshold_out(out_t, in0_t, 0, 0); + }; +}); +REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator { + return [](ProcessedNode* p_node) { + auto& in0_t = p_node->Input(0).toTensor(); + double in1_d = + p_node->inputs().size() > 1 ? p_node->Input(1).toDouble() : -1.0; + if (p_node->Output(0).isNone()) { + p_node->Output(0) = create_empty_from(in0_t); + } + auto& out_t = p_node->Output(0).toTensor(); + out_t.resize_({0}); + at::native::logit_out(out_t, in0_t, in1_d); + }; +}); +REGISTER_OPERATOR_FUNCTOR(aten::clone, aten_clone, [](Node* n) -> SROperator { + return [](ProcessedNode* p_node) { + auto& in0_t = p_node->Input(0).toTensor(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = create_empty_from(in0_t); + } + auto& out_t = p_node->Output(0).toTensor(); + at::native::resize_as_(out_t, in0_t, c10::nullopt); + at::native::copy_(out_t, in0_t, false); + }; +}); +REGISTER_OPERATOR_FUNCTOR_OPT( + quantized::embedding_bag_byte_rowwise_offsets, + quantized_embedding_bag_byte_rowwise_offsets, + false, // don't reuse byte inputs + true, + [](Node* n) -> SROperator { + return [](ProcessedNode* p_node) { + auto weight = p_node->Input(0).toTensor(); + auto indices = p_node->Input(1).toTensor(); + auto offsets = p_node->Input(2).toOptional(); + auto pruned_weights = p_node->Input(5).toBool(); + auto per_sample_weights = p_node->Input(6).toOptional(); + auto compressed_indices_mapping = + p_node->Input(7).toOptional(); + auto include_last_offset = p_node->Input(8).toBool(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = + at::empty({0}, weight.options().dtype(at::kFloat)); + } + auto out_t = p_node->Output(0).toTensor(); + out_t.resize_({0}); + return at::native::embedding_bag_byte_rowwise_offsets_out( + out_t, + weight, + indices, + offsets, + false, // unused scale_grad_by_freq + 0, // unused mode + pruned_weights, + per_sample_weights, + compressed_indices_mapping, + include_last_offset); + }; + }); + +// The out variant takes precedence over native +REGISTER_OPERATOR_FUNCTOR(aten::narrow, aten_narrow, [](Node* n) -> SROperator { + return [](ProcessedNode* p_node) { + auto self = p_node->Input(0).toTensor(); // self + auto dim = p_node->Input(1).toInt(); // dim + int64_t start = 0; + if (p_node->Input(2).isScalar()) { + start = p_node->Input(2).toInt(); + } else { + auto t = p_node->Input(2).toTensor(); + start = t.item(); + } + auto length = p_node->Input(3).toInt(); // length + + if (p_node->Output(0).isNone()) { + p_node->Output(0) = create_empty_from(self); + } + auto output = p_node->Output(0).toTensor(); + output.resize_({0}); + at::native::narrow_copy_dense_cpu_out(self, dim, start, length, output); + }; +}); + +std::function getOutOfPlaceOperation(Node* n) { + auto op_name = n->kind().toQualString(); + if (SROperatorRegistry()->Has(op_name)) { + return SROperatorRegistry()->Create(op_name)->Generate(n); + } + + return [](ProcessedNode*) { TORCH_CHECK(0); }; +} + +std::function getNativeOperation(Node* n) { + if (n->kind() == c10::Symbol::fromQualString("aten::transpose")) { + return [](ProcessedNode* p_node) { + auto& in0_t = p_node->Input(0).toTensor(); + auto in1_i = p_node->Input(1).toInt(); + auto in2_i = p_node->Input(2).toInt(); + p_node->Output(0) = at::native::transpose(in0_t, in1_i, in2_i); + }; + } else if (n->kind() == c10::Symbol::fromQualString("aten::flatten")) { + return [](ProcessedNode* p_node) { + auto& in0_t = p_node->Input(0).toTensor(); + auto in1_i = p_node->Input(1).toInt(); + auto in2_i = p_node->Input(2).toInt(); + p_node->Output(0) = at::native::flatten(in0_t, in1_i, in2_i); + }; + } else if (n->kind() == prim::TupleConstruct) { + return [](ProcessedNode* p_node) { + // prepare inputs + std::vector stack; + const size_t size = p_node->inputs().size(); + stack.reserve(size); + for (size_t i = 0; i < size; i++) { + stack.emplace_back(p_node->Input(i)); + } + // run op + auto* node = p_node->get_node(); + const auto& type = node->output()->type()->expect(); + if (type->name().has_value()) { + namedTupleConstruct(stack, type, node->inputs().size()); + } else { + tupleConstruct(stack, node->inputs().size()); + } + // put output back + p_node->Output(0) = std::move(stack[0]); + }; + } else if (n->kind() == prim::ListConstruct) { + return [](ProcessedNode* p_node) { + // prepare inputs + std::vector stack; + const size_t size = p_node->inputs().size(); + stack.reserve(size); + for (size_t i = 0; i < size; i++) { + stack.emplace_back(p_node->Input(i)); + } + // run op + listConstruct( + stack, + p_node->get_node()->output()->type()->expectRef(), + p_node->inputs().size()); + // put output back + p_node->Output(0) = std::move(stack[0]); + }; + } else if (n->kind() == prim::ListUnpack) { + return [](ProcessedNode* p_node) { + // prepare inputs + std::vector stack; + const size_t size = p_node->inputs().size(); + stack.reserve(size); + for (size_t i = 0; i < size; i++) { + stack.emplace_back(p_node->Input(i)); + } + // run op + size_t num_outputs = p_node->outputs().size(); + listUnpack(stack, num_outputs); + // put output back + DCHECK_EQ(stack.size(), num_outputs); + for (auto i = 0; i < num_outputs; i++) { + p_node->Output(i) = std::move(stack[i]); + } + }; + } else if (n->kind() == c10::Symbol::fromQualString("aten::permute")) { + return [](ProcessedNode* p_node) { + auto& in0_t = p_node->Input(0).toTensor(); + auto in1_iv = p_node->Input(1).toIntVector(); + p_node->Output(0) = at::native::permute(in0_t, in1_iv); + }; + } else if (n->kind() == c10::Symbol::fromQualString("aten::reshape")) { + return [](ProcessedNode* p_node) { + auto& in0_t = p_node->Input(0).toTensor(); + auto in1_iv = p_node->Input(1).toIntVector(); + p_node->Output(0) = at::native::reshape(in0_t, in1_iv); + }; + } else if (n->kind() == c10::Symbol::fromQualString("aten::slice")) { + return [](ProcessedNode* p_node) { + auto& in0_t = p_node->Input(0).toTensor(); + auto in1_i = p_node->Input(1).toInt(); + auto in2_i = p_node->Input(2).toInt(); + auto in3_i = p_node->Input(3).toInt(); + auto in4_i = p_node->Input(4).toInt(); + p_node->Output(0) = at::native::slice(in0_t, in1_i, in2_i, in3_i, in4_i); + }; + } else if (n->kind() == c10::Symbol::fromQualString("aten::narrow")) { + return [](ProcessedNode* p_node) { + auto& self = p_node->Input(0).toTensor(); // self + auto dim = p_node->Input(1).toInt(); // dim + int64_t start = 0; + if (p_node->Input(2).isScalar()) { + start = p_node->Input(2).toInt(); + } else { + auto& t = p_node->Input(2).toTensor(); + start = t.item(); + } + auto length = p_node->Input(3).toInt(); // length + TORCH_CHECK( + self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor."); + auto cur_size = self.size(dim); + if (start != cur_size && start < 0) { // start being the end is valid, but + // not a valid dim specification. + start = at::maybe_wrap_dim(start, cur_size); + } + TORCH_CHECK( + length >= 0 && start <= cur_size - length, + "start (", + start, + ") + length (", + length, + ") exceeds dimension size (", + cur_size, + ")."); + p_node->Output(0) = + at::native::slice(self, dim, start, start + length, 1); + }; + } else if (n->kind() == c10::Symbol::fromQualString("aten::to")) { + return [](ProcessedNode* p_node) { + DCHECK(p_node->inputs().size() == 5); + auto& in0_t = p_node->Input(0).toTensor(); + auto in1_i = p_node->Input(1).toScalarType(); + auto in2_i = p_node->Input(2).toBool(); + auto in3_i = p_node->Input(3).toBool(); + if (p_node->Input(4).isNone()) { + p_node->Output(0) = + at::native::to(in0_t, in1_i, in2_i, in3_i, c10::nullopt); + } else { + auto in4_o = p_node->Input(4).toMemoryFormat(); + p_node->Output(0) = at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o); + } + }; + } + return [](ProcessedNode*) { TORCH_CHECK(0); }; +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/static/ops.h b/torch/csrc/jit/runtime/static/ops.h new file mode 100644 index 0000000000000..efbb80ce97e2f --- /dev/null +++ b/torch/csrc/jit/runtime/static/ops.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include + +namespace torch { +namespace jit { + +using SROperator = std::function; +using SROpFunctor = SROperator (*)(Node* n); +struct SROperatorFunctor { + virtual SROperator Generate(Node*) { + std::function out; + return out; + } + virtual bool CanReuseInput() { + return false; + } + virtual bool CanReuseOutput() { + return false; + } + virtual ~SROperatorFunctor() = default; +}; + +C10_DECLARE_REGISTRY(SROperatorRegistry, SROperatorFunctor); + +// TODO: reuse_inp reuse_out can be deprecated with further analysis +// try to avoid this API. +#define REGISTER_OPERATOR_FUNCTOR_OPT(name, id, reuse_inp, reuse_out, ...) \ + struct SROperatorFunctor_##id : public SROperatorFunctor { \ + const SROpFunctor fn = __VA_ARGS__; \ + bool CanReuseInput() override { \ + return reuse_inp; \ + } \ + bool CanReuseOutput() override { \ + return reuse_out; \ + } \ + SROperator Generate(Node* n) override { \ + return fn(n); \ + } \ + }; \ + C10_REGISTER_CLASS(SROperatorRegistry, name, SROperatorFunctor_##id); + +#define REGISTER_OPERATOR_FUNCTOR(name, id, ...) \ + REGISTER_OPERATOR_FUNCTOR_OPT(name, id, true, true, __VA_ARGS__) + +inline at::Tensor create_empty_from(const at::Tensor& t) { + return at::empty({0}, t.options()); +} + +bool canRunOutOfPlace(Node* n); +bool canReuseInputs(Node* n); +bool canReuseOutputs(Node* n); + +std::function getOutOfPlaceOperation(Node* n); + +bool canRunNatively(Node* n); +std::function getNativeOperation(Node* n); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp new file mode 100644 index 0000000000000..c70213881cfc5 --- /dev/null +++ b/torch/csrc/jit/runtime/static/passes.cpp @@ -0,0 +1,160 @@ +#include + +#include + +namespace torch { +namespace jit { + +void ConcatAddMulReplaceNaNClip(std::shared_ptr& graph) { + // TODO:: check restrictions for inputs; outputs not used elsewhere + std::string pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j): + %y0 = aten::cat(%a, %b) + %y1 = aten::add(%y0, %c, %d) + %y2 = aten::mul(%y1, %e) + %y3 = aten::nan_to_num(%y2, %f, %g, %h) + %res = aten::clamp(%y3, %i, %j) + return (%res))IR"; + std::string pattern2 = R"IR( + graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j): + %y0 = aten::cat(%a, %b) + %y1 = aten::add(%y0, %c, %d) + %y2 = aten::mul(%y1, %e) + %y3 = aten::nan_to_num_(%y2, %f, %g, %h) + %res = aten::clamp(%y3, %i, %j) + return (%res))IR"; + std::string pattern3 = R"IR( + graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j): + %y0 = aten::cat(%a, %b) + %y1 = aten::add(%y0, %c, %d) + %y2 = aten::mul(%y1, %e) + %y3 = aten::nan_to_num_(%y2, %f, %g, %h) + %res = aten::clamp_(%y3, %i, %j) + return (%res))IR"; + std::string pattern4 = R"IR( + graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j): + %y0 = aten::cat(%a, %b) + %y1 = aten::add(%y0, %c, %d) + %y2 = aten::mul(%y1, %e) + %y3 = aten::nan_to_num(%y2, %f, %g, %h) + %res = aten::clamp_(%y3, %i, %j) + return (%res))IR"; + std::string fused_pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j): + %res = fb::concat_add_mul_replacenan_clip(%c, %e, %a, %i, %j) + return (%res))IR"; + + SubgraphRewriter fuse; + fuse.RegisterRewritePattern(pattern, fused_pattern); + fuse.runOnGraph(graph); + + fuse.RegisterRewritePattern(pattern2, fused_pattern); + fuse.runOnGraph(graph); + + fuse.RegisterRewritePattern(pattern3, fused_pattern); + fuse.runOnGraph(graph); + + fuse.RegisterRewritePattern(pattern4, fused_pattern); + fuse.runOnGraph(graph); +} + +void CastedBatchOneHotLengths(std::shared_ptr& graph) { + // TODO:: check restrictions for inputs; outputs not used elsewhere + std::string pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f, %g): + %y0 : Tensor = aten::to(%a, %b, %c, %c, %d) + %y1 : Tensor = fb::batch_one_hot_lengths(%y0, %e, %f) + %res : Tensor = aten::to(%y1, %g, %c, %c, %d) + return (%res))IR"; + std::string fused_pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f, %g): + %res : Tensor = fb::casted_batch_one_hot_lengths(%a, %e, %f) + return (%res))IR"; + SubgraphRewriter fuse; + fuse.RegisterRewritePattern(pattern, fused_pattern); + fuse.runOnGraph(graph); +} + +void ConcatBatchMatMulBatchGather(std::shared_ptr& graph) { + // TODO:: check restrictions for inputs; outputs not used elsewhere + std::string pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f): + %y0 : Tensor = aten::stack(%a, %b) + %y1 : Tensor = aten::transpose(%y0, %b, %c) + %y2 : Tensor = aten::bmm(%y0, %y1) + %y3 : Tensor = aten::flatten(%y2, %d, %e) + %res : Tensor = aten::index_select(%y3, %b, %f) + return (%res))IR"; + std::string fused_pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f): + %res : Tensor = fb::concat_batch_matmul_batch_gather(%f, %a) + return (%res))IR"; + SubgraphRewriter fuse; + fuse.RegisterRewritePattern(pattern, fused_pattern); + fuse.runOnGraph(graph); +} + +void ClipRangesGatherRangesLengthsToOffsets( + std::shared_ptr& graph) { + // TODO:: check restrictions for inputs; outputs not used elsewhere + std::string pattern = R"IR( + graph(%a, %b, %c, %d): + %y0 : Tensor = fb::clip_ranges(%b, %c) + %y1 : Tensor, %y2 : Tensor = fb::gather_ranges(%a, %y0) + %y3 : Tensor = fb::lengths_to_offsets(%y2, %d) + return (%y3, %y1))IR"; + std::string fused_pattern = R"IR( + graph(%a, %b, %c, %d): + %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_lengths_to_offsets(%a, %b, %c, %d) + return (%y1, %y0))IR"; + SubgraphRewriter fuse; + fuse.RegisterRewritePattern(pattern, fused_pattern); + fuse.runOnGraph(graph); +} + +void ClipRangesGatherSigridHash(std::shared_ptr& graph) { + // TODO:: check restrictions for inputs; outputs not used elsewhere + std::string pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f, %g): + %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_lengths_to_offsets(%a, %b, %c, %d) + %y2 : Tensor = fb::sigrid_hash(%y0, %e, %f, %g) + return (%y2, %y1))IR"; + std::string fused_pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f, %g): + %off : Tensor, %out : Tensor = fb::clip_ranges_gather_sigrid_hash_offsets(%b, %a, %c, %e, %f, %g, %d) + return (%out, %off))IR"; + SubgraphRewriter fuse; + fuse.RegisterRewritePattern(pattern, fused_pattern); + fuse.runOnGraph(graph); +} + +void ClipRangesGatherRangesSigridHash( + std::shared_ptr& graph) { + std::string pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f): + %y0 : Tensor = fb::clip_ranges(%b, %c) + %y1 : Tensor, %y2 : Tensor = fb::gather_ranges(%a, %y0) + %y3 : Tensor = fb::sigrid_hash(%y1, %d, %e, %f) + return (%y3, %y2))IR"; + std::string fused_pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f): + %off : Tensor, %out : Tensor = fb::clip_ranges_gather_sigrid_hash_v3(%b, %a, %c, %d, %e, %f) + return (%out, %off))IR"; + SubgraphRewriter fuse; + fuse.RegisterRewritePattern(pattern, fused_pattern); + fuse.runOnGraph(graph); +} + +void FuseInferenceOpsForSparseNN(std::shared_ptr& graph) { +#ifdef FBCODE_CAFFE2 + ConcatAddMulReplaceNaNClip(graph); + CastedBatchOneHotLengths(graph); + ConcatBatchMatMulBatchGather(graph); + ClipRangesGatherRangesLengthsToOffsets(graph); + ClipRangesGatherSigridHash(graph); + ClipRangesGatherRangesSigridHash(graph); +#endif +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/static/passes.h b/torch/csrc/jit/runtime/static/passes.h new file mode 100644 index 0000000000000..7cc9c52f7696f --- /dev/null +++ b/torch/csrc/jit/runtime/static/passes.h @@ -0,0 +1,9 @@ +#include + +namespace torch { +namespace jit { + +void FuseInferenceOpsForSparseNN(std::shared_ptr& graph); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/symbolic_script.cpp b/torch/csrc/jit/runtime/symbolic_script.cpp index 578586e9e9ffe..1113d451176b0 100644 --- a/torch/csrc/jit/runtime/symbolic_script.cpp +++ b/torch/csrc/jit/runtime/symbolic_script.cpp @@ -1,4 +1,5 @@ #include + #include #include @@ -1369,8 +1370,8 @@ std::pair, Value*> extractClosure(Value* closure) { Value* context = closure->node()->inputs().at(1); TORCH_CHECK( - fn->node()->kind() == prim::Function, - "closure tuple must contain a prim::Function"); + fn->node()->kind() == prim::Closure, + "closure tuple must contain a prim::Closure"); return std::make_pair(fn->node()->g(attr::Subgraph), context); } diff --git a/torch/csrc/jit/runtime/vararg_functions.cpp b/torch/csrc/jit/runtime/vararg_functions.cpp index 2654e5f477389..2bb884068f867 100644 --- a/torch/csrc/jit/runtime/vararg_functions.cpp +++ b/torch/csrc/jit/runtime/vararg_functions.cpp @@ -1,9 +1,103 @@ #include + #include namespace torch { namespace jit { +namespace { +static constexpr int defaultPrecision = 6; + +// IValue tags are intentionally private, so we need additional logic to cast +// the IValue type to the specified format. +void addFormattedArg( + char key, + const IValue& ival, + std::stringstream& ss, + int precision = defaultPrecision) { + // TODO: Implement precison-based formatting + std::stringstream tmp; + switch (key) { + case 'd': + case 'i': + TORCH_CHECK( + ival.isScalar(), + "%", + key, + " requires a number for formatting, but got ", + ival.tagKind()); + if (ival.isInt()) { + ss << ival.toInt(); + } else { + ss << static_cast(ival.toDouble()); + } + break; + case 'e': + case 'E': + TORCH_CHECK( + ival.isScalar(), + "%", + key, + " requires a number for formatting, but got ", + ival.tagKind()); + tmp << std::setprecision(precision) << std::scientific; + if (key == 'E') { + tmp << std::uppercase; + } + if (ival.isInt()) { + tmp << static_cast(ival.toInt()); + } else { + tmp << static_cast(ival.toDouble()); + } + ss << tmp.str(); + break; + case 'f': + case 'F': + TORCH_CHECK( + ival.isScalar(), + "%", + key, + " requires a number for formatting, but got ", + ival.tagKind()); + tmp << std::setprecision(precision) << std::fixed; + if (ival.isInt()) { + tmp << static_cast(ival.toInt()); + } else { + tmp << static_cast(ival.toDouble()); + } + ss << tmp.str(); + break; + case 'c': + TORCH_CHECK( + ival.isInt() || (ival.isString() && ival.toStringRef().length() == 1), + "%", + key, + " requires an int or char for formatting, but got ", + ival.tagKind()); + if (ival.isInt()) { + ss << static_cast(ival.toInt()); + } else { + ss << ival.toStringRef(); + } + break; + case 's': + if (ival.isString()) { + ss << ival.toStringRef(); + } else { + ss << ival; + } + break; + default: + TORCH_CHECK( + false, + "The specifier %", + key, + " is not supported in TorchScript format strings"); + } +} + +} // namespace + void tupleUnpack(Stack& stack) { auto tuple = pop(stack).toTuple(); stack.insert(stack.end(), tuple->elements().begin(), tuple->elements().end()); @@ -39,6 +133,48 @@ void format(Stack& stack, size_t num_inputs) { push(stack, ss.str()); } +void percentFormat(Stack& stack, size_t num_inputs) { + auto format_str = peek(stack, 0, num_inputs).toStringRef(); + auto args = last(stack, num_inputs - 1)[0]; + auto args_size = 1; // assumed size + if (args.isTuple()) { + args_size = args.toTuple()->elements().size(); + } + std::stringstream ss; + size_t used_args = 0; + size_t begin = 0; + while (true) { + size_t percent_idx = format_str.find('%', begin); + if (percent_idx == std::string::npos) { + ss << format_str.substr(begin); + break; + } + size_t format_idx = percent_idx + 1; + TORCH_CHECK( + percent_idx < format_str.length() - 1, "Incomplete format specifier"); + ss << format_str.substr(begin, percent_idx - begin); + if (format_str.at(format_idx) == '%') { + ss << '%'; + begin = percent_idx + 2; // skip the `%` and the format specifier + continue; + } + TORCH_CHECK(used_args < args_size, "Too few arguments for format string"); + char key = format_str.at(format_idx); + IValue arg; + if (args.isTuple()) { + arg = args.toTuple()->elements()[used_args]; + } else { + arg = args; + } + addFormattedArg(key, arg, ss); + begin = percent_idx + 2; + ++used_args; + } + TORCH_CHECK(used_args == args_size, "Too many arguments for format string"); + drop(stack, num_inputs); + push(stack, ss.str()); +} + void listUnpack(Stack& stack, size_t num_outputs) { auto list = pop(stack).toList(); TORCH_CHECK( @@ -51,8 +187,9 @@ void listUnpack(Stack& stack, size_t num_outputs) { } void tupleConstruct(Stack& stack, size_t num_inputs) { - std::vector elems{std::make_move_iterator(stack.end() - num_inputs), - std::make_move_iterator(stack.end())}; + std::vector elems{ + std::make_move_iterator(stack.end() - num_inputs), + std::make_move_iterator(stack.end())}; drop(stack, num_inputs); push(stack, c10::ivalue::Tuple::create(std::move(elems))); } @@ -61,25 +198,36 @@ void namedTupleConstruct( Stack& stack, at::TupleTypePtr type, size_t num_inputs) { - std::vector elems{std::make_move_iterator(stack.end() - num_inputs), - std::make_move_iterator(stack.end())}; + std::vector elems{ + std::make_move_iterator(stack.end() - num_inputs), + std::make_move_iterator(stack.end())}; drop(stack, num_inputs); push( stack, c10::ivalue::Tuple::createNamed(std::move(elems), std::move(type))); } -void listConstruct(Stack& stack, at::ListTypePtr type, size_t num_inputs) { - c10::List vals(type->getElementType()); - vals.reserve(num_inputs); - for (size_t i = stack.size() - num_inputs; i < stack.size(); ++i) { - vals.emplace_back(std::move(stack[i])); - } - drop(stack, num_inputs); - push(stack, std::move(vals)); +void listConstruct(Stack& stack, const at::ListType& type, size_t num_inputs) { + // Structuring the implementation this way allows NRVO to avoid + // move-constructing vals on its way onto the stack. Moving a List + // isn't free. + auto makeList = + [](Stack& stack, const at::ListType& type, size_t num_inputs) { + c10::List vals(type.getElementType()); + vals.reserve(num_inputs); + for (size_t i = stack.size() - num_inputs; i < stack.size(); ++i) { + vals.push_back(std::move(stack[i])); + } + drop(stack, num_inputs); + return vals; + }; + stack.push_back(makeList(stack, type, num_inputs)); } -void dictConstruct(Stack& stack, at::DictTypePtr type, size_t num_inputs) { +void dictConstruct( + Stack& stack, + const at::DictTypePtr& type, + size_t num_inputs) { at::TypePtr key_type = type->getKeyType(); at::TypePtr value_type = type->getValueType(); auto vals = c10::impl::GenericDict(key_type, value_type); @@ -96,7 +244,7 @@ void dictConstruct(Stack& stack, at::DictTypePtr type, size_t num_inputs) { push(stack, std::move(vals)); } -void createObject(Stack& stack, at::ClassTypePtr type) { +void createObject(Stack& stack, const at::ClassTypePtr& type) { auto userObj = c10::ivalue::Object::create( c10::StrongTypePtr(type->compilation_unit(), type), type->numAttributes()); @@ -132,19 +280,19 @@ void dequantize(Stack& stack) { auto elems = tuple->elements(); std::vector output_elems; output_elems.reserve(elems.size()); - for (size_t i = 0; i < elems.size(); ++i) { - if (elems[i].isTensor()) { - output_elems.emplace_back(at::dequantize(elems[i].toTensor())); + for (const auto& elem : elems) { + if (elem.isTensor()) { + output_elems.emplace_back(at::dequantize(elem.toTensor())); } else { - output_elems.emplace_back(elems[i]); + output_elems.emplace_back(elem); } } push(stack, c10::ivalue::Tuple::create(std::move(output_elems))); } else if (iv.isTensorList()) { auto elems = iv.toTensorList(); auto output_list = c10::impl::GenericList(elems.elementType()); - for (size_t i = 0; i < elems.size(); ++i) { - output_list.emplace_back(at::dequantize(elems[i])); + for (auto&& elem : elems) { + output_list.emplace_back(at::dequantize(elem)); } push(stack, std::move(output_list)); } else { diff --git a/torch/csrc/jit/runtime/vararg_functions.h b/torch/csrc/jit/runtime/vararg_functions.h index 4941172e60bf8..e9580411212a1 100644 --- a/torch/csrc/jit/runtime/vararg_functions.h +++ b/torch/csrc/jit/runtime/vararg_functions.h @@ -2,6 +2,7 @@ #include #include #include +#include #include namespace torch { @@ -11,6 +12,8 @@ void tupleUnpack(Stack& stack); void format(Stack& stack, size_t num_inputs); +void percentFormat(Stack& stack, size_t num_inputs); + void listUnpack(Stack& stack, size_t num_outputs); void tupleConstruct(Stack& stack, size_t num_inputs); @@ -20,11 +23,17 @@ void namedTupleConstruct( at::TupleTypePtr type, size_t num_inputs); -void listConstruct(Stack& stack, at::ListTypePtr list_type, size_t num_inputs); +void listConstruct( + Stack& stack, + const at::ListType& list_type, + size_t num_inputs); -void dictConstruct(Stack& stack, at::DictTypePtr type, size_t num_inputs); +void dictConstruct( + Stack& stack, + const at::DictTypePtr& type, + size_t num_inputs); -void createObject(Stack& stack, at::ClassTypePtr type); +void createObject(Stack& stack, const at::ClassTypePtr& type); void isinstance(Stack& stack, at::ArrayRef types); diff --git a/torch/csrc/jit/serialization/export.cpp b/torch/csrc/jit/serialization/export.cpp index c44c00a88727d..ae0f41f5c41d1 100644 --- a/torch/csrc/jit/serialization/export.cpp +++ b/torch/csrc/jit/serialization/export.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -137,12 +138,13 @@ std::string GetFileRootPath(const std::string& rootPath) { return folder; } -std::string GetExternalFileName(const c10::optional external_ref) { +std::string GetExternalFileName( + const c10::optional& external_ref) { auto tensorName = external_ref.value(); const std::string illegalChars = "\\/:?\"<>|"; - for (int i = 0; i < tensorName.size(); i++) { - if (illegalChars.find(tensorName[i]) != std::string::npos) { - tensorName[i] = '_'; + for (char& i : tensorName) { + if (illegalChars.find(i) != std::string::npos) { + i = '_'; } } return tensorName; @@ -160,7 +162,7 @@ void CreateExternalFile( std::string fullFilePath = folder + "/" + tensorName; std::unique_ptr fp( fopen(fullFilePath.c_str(), "wb"), &CloseFile); - if (fp == NULL) { + if (fp == nullptr) { throw std::runtime_error( std::string("ONNX export failed. Could not open file or directory: ") + fullFilePath); @@ -178,6 +180,10 @@ class EncoderBase { return model_proto_; } + SymbolDimMap get_symbol_dim_param_map() { + return symbol_dim_map_; + } + protected: // Using std::map instead of std::unordered_map for initializers // in EncodeGraph constructor so that the order in which initializers @@ -215,6 +221,14 @@ class EncoderBase { bool use_external_data_format = false, const std::string& onnx_file_path = std::string()); + void AddInitializersIntoGraphProto( + onnx::GraphProto* graph_proto, + const Block* block, + const std::map& initializers = + std::map(), + bool use_external_data_format = false, + const std::string& onnx_file_path = std::string()); + virtual void EncodeTensor( onnx::TensorProto* tensor_proto, const at::Tensor& tensor, @@ -243,6 +257,7 @@ class EncoderBase { const bool use_external_data_format = false, const std::string& onnx_file_path = std::string()); + SymbolDimMap symbol_dim_map_; onnx::ModelProto model_proto_; size_t num_blocks_; size_t num_op_nodes_; @@ -316,33 +331,38 @@ void EncoderBase::EncodeValueInfo( std::unordered_map>& dynamic_axes) { std::string name = n->debugName(); v->set_name(name); - auto tensorTypeToONNXType = [&dynamic_axes, &name]( - TensorTypePtr t, + auto tensorTypeToONNXType = [&dynamic_axes, &name, this]( + const TensorTypePtr& t, onnx::TypeProto_Tensor* tensor_type) { - if (t->sizes().isComplete()) { - // onnx::TypeProto* onnx_type = v->mutable_type(); - // onnx::TypeProto_Tensor* tensor_type = onnx_type->mutable_tensor_type(); + if (t->dim()) { onnx::TensorShapeProto* shape = tensor_type->mutable_shape(); - std::vector sizes = t->sizes().concrete_sizes().value(); + auto sizes = t->symbolic_sizes().sizes().value(); for (size_t i = 0; i < sizes.size(); i++) { shape->add_dim(); if ((dynamic_axes.find(name) != dynamic_axes.end()) && (dynamic_axes.at(name).find(i) != dynamic_axes.at(name).end())) { shape->mutable_dim(i)->set_dim_param(dynamic_axes.at(name).at(i)); + if (!sizes[i].is_static()) { + symbol_dim_map_[sizes[i]] = dynamic_axes.at(name).at(i); + } + } else if (sizes[i].is_static()) { + shape->mutable_dim(i)->set_dim_value(sizes[i].static_size()); } else { - shape->mutable_dim(i)->set_dim_value(sizes[i]); + if (symbol_dim_map_.find(sizes[i]) == symbol_dim_map_.end()) { + symbol_dim_map_[sizes[i]] = name + "_" + std::to_string(i); + } + shape->mutable_dim(i)->set_dim_param(symbol_dim_map_[sizes[i]]); } } } if (t->scalarType()) { - // onnx::TypeProto* onnx_type = v->mutable_type(); - // onnx::TypeProto_Tensor* tensor_type = onnx_type->mutable_tensor_type(); tensor_type->set_elem_type(ATenTypeToOnnxType(t->scalarType().value())); } }; if (TensorTypePtr node_type = n->type()->cast()) { - if (node_type->sizes().isComplete() || node_type->scalarType()) { + if (node_type->dim() || node_type->scalarType()) { + // Encode type if either shape or dtype exists. onnx::TypeProto* onnx_type = v->mutable_type(); onnx::TypeProto_Tensor* tensor_type = onnx_type->mutable_tensor_type(); tensorTypeToONNXType(node_type, tensor_type); @@ -541,14 +561,33 @@ void EncoderBase::EncodeBlock( onnx_file_path); } } + AddInitializersIntoGraphProto( + graph_proto, + block, + initializers, + use_external_data_format, + onnx_file_path); +} + +void EncoderBase::AddInitializersIntoGraphProto( + onnx::GraphProto* graph_proto, + const Block* block, + const std::map& initializers, + bool use_external_data_format, + const std::string& onnx_file_path) { AT_ASSERT(block->inputs().size() >= initializers.size()); - for (auto& name_tensor_pair : initializers) { + + for (auto input : block->inputs()) { + auto name_tensor_pair = initializers.find(input->debugName()); + if (name_tensor_pair == initializers.end()) { + continue; + } auto p = graph_proto->add_initializer(); - p->set_name(name_tensor_pair.first); + p->set_name(name_tensor_pair->first); EncodeTensor( p, - name_tensor_pair.second, - name_tensor_pair.first, + name_tensor_pair->second, + name_tensor_pair->first, use_external_data_format, onnx_file_path); } @@ -854,7 +893,11 @@ std::string pretty_print_onnx( // conform to the ONNX op specification. Thus, the output will not // be interpretable by a ONNX-compatible framework. However, PyTorch or // libtorch will be able to import the IR and play it back. -std::tuple export_onnx( +std::tuple< + std::shared_ptr<::ONNX_NAMESPACE::ModelProto>, + RawDataExportMap, + SymbolDimMap> +export_onnx( const std::shared_ptr& graph, const std::map& initializers, int64_t onnx_opset_version, @@ -887,10 +930,17 @@ std::tuple export_onnx( proto_size <= INT_MAX, "Exporting model exceed maximum protobuf size of 2GB. " "Please call torch.onnx.export with use_external_data_format=True."); - GRAPH_UPDATE("onnx proto:", prettyPrint(graph_encoder.get_model_proto())); + GRAPH_DEBUG("onnx proto:", prettyPrint(graph_encoder.get_model_proto())); return std::make_tuple( - graph_encoder.get_model_proto().SerializeAsString(), - graph_encoder.get_raw_data_export_map()); + std::make_shared<::ONNX_NAMESPACE::ModelProto>( + graph_encoder.get_model_proto()), + graph_encoder.get_raw_data_export_map(), + graph_encoder.get_symbol_dim_param_map()); +} + +std::string serialize_model_proto_to_string( + const std::shared_ptr<::ONNX_NAMESPACE::ModelProto>& model_proto) { + return model_proto->SerializeAsString(); } void check_onnx_proto(const std::string& proto_string) { diff --git a/torch/csrc/jit/serialization/export.h b/torch/csrc/jit/serialization/export.h index 212ed65207fe8..4b3b71a6a3ee0 100644 --- a/torch/csrc/jit/serialization/export.h +++ b/torch/csrc/jit/serialization/export.h @@ -8,6 +8,10 @@ #include +namespace ONNX_NAMESPACE { +class ModelProto; +} + namespace torch { namespace jit { @@ -21,7 +25,13 @@ namespace jit { // file contents being the raw tensor data. using RawDataExportMap = std::unordered_map; -TORCH_API std::tuple export_onnx( +using SymbolDimMap = std::map; + +TORCH_API std::tuple< + std::shared_ptr<::ONNX_NAMESPACE::ModelProto>, + RawDataExportMap, + SymbolDimMap> +export_onnx( const std::shared_ptr& graph, const std::map& initializers, int64_t onnx_opset_version, @@ -38,6 +48,9 @@ TORCH_API std::tuple export_onnx( bool use_external_data_format = false, const std::string& onnx_file_path = std::string()); +TORCH_API std::string serialize_model_proto_to_string( + const std::shared_ptr<::ONNX_NAMESPACE::ModelProto>& model_proto); + TORCH_API void check_onnx_proto(const std::string& proto_string); // For testing purposes @@ -98,5 +111,22 @@ TORCH_API void SetExportModuleMobileInfoConverter( // Returns a list of names of all operators in the module and its submodules. TORCH_API std::vector export_opnames(const Module& m); +namespace mobile { + +class Module; +/** + * Given a torch::jit::mobile::Module, return a set of operator names + * (with overload name) that are used by any method in this mobile + * Mobile. This method runs through the bytecode for all methods + * in the specified model (module), and extracts all the root + * operator names. Root operators are operators that are called + * directly by the model (as opposed to non-root operators, which + * may be called transitively by the root operators). + * + */ +TORCH_API std::set _export_operator_list( + torch::jit::mobile::Module& module); + +} // namespace mobile } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index 7c42bf0525835..08e1685b35eda 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -4,8 +4,11 @@ #include #include #include +#include +#include +#include +#include #include -#include #include #include #include @@ -47,27 +50,60 @@ static IValue Tup(std::vector ivalues) { static IValue Table( const std::vector>& entries) { std::vector ivalue_entries; + ivalue_entries.reserve(entries.size()); for (const auto& e : entries) { ivalue_entries.push_back(Tup({e.first, e.second})); } return Tup(std::move(ivalue_entries)); } -std::string getModulePath(Node* node) { - std::string modulePath = node->scopeName(); - size_t end = modulePath.size(); - // Here we remove the source range information to make the - // module debugging information shorter and cleaner. - if (modulePath[end - 1] == '>') { - end = modulePath.rfind('<'); - if (end > 0 && modulePath[end - 1] == '<') { - --end; +std::string getModulePath(Node* node, const std::string& root_scope_string) { + constexpr size_t kFunction = 0; + constexpr size_t kModuleInstanceInfo = 2; + + if (!node->callstack()) { + return root_scope_string + ".forward"; + } else { + std::string module_info = root_scope_string; + auto callstack_ptr = *(node->callstack()); + const auto& vec = callstack_ptr->vec(); + + for (const auto& element : vec) { + const auto& opt_module_instance_info = + std::get(element); + if (opt_module_instance_info.has_value()) { + const auto& module_instance_info = opt_module_instance_info.value(); + if (module_instance_info.class_type()) { + const auto& class_type = module_instance_info.class_type(); + const auto& instance_name = module_instance_info.instance_name(); + auto type_name = class_type->name()->qualifiedName(); + type_name = type_name.substr(type_name.find_last_of('.') + 1); + module_info.append(".") + .append(instance_name) + .append("(") + .append(type_name) + .append(")") + .append(".") + .append(std::get(element)->name()); + } else { + module_info += ".(UNKNOWN_INSTANCE(UNKNOWN_TYPE)"; + } + } else { + module_info += ".(UNKNOWN_INSTANCE(UNKNOWN_TYPE)"; + } } + + return module_info; + } +} + +std::string getModuleTypeName(const Module& module, const std::string& prefix) { + std::string moduleType = module.type()->str(); + size_t lastDotIndex = moduleType.rfind('.'); + if (lastDotIndex != std::string::npos) { + moduleType = moduleType.substr(lastDotIndex + 1); } - // We only keep the last function in a callstack. - size_t start = modulePath.rfind('/', end); - start = (start != std::string::npos) ? start + 1 : 0; - return modulePath.substr(start, end - start); + return prefix + "(" + moduleType + ")"; } std::pair> getFunctionTuple( @@ -77,9 +113,6 @@ std::pair> getFunctionTuple( auto graph = func.graph()->copy(); Inline(*graph); - if (save_mobile_debug_info) { - ReconstructScopes(module, *graph, "top"); - } torch::jit::Code code(graph, func.name()); auto instructions_copy = code.instructions(); @@ -94,7 +127,8 @@ std::pair> getFunctionTuple( auto node = code.instructions_source()[i]; opnames.emplace_back(node->schema().operator_name()); if (save_mobile_debug_info) { - op_module_paths.emplace_back(getModulePath(node)); + std::string root_scope_string = getModuleTypeName(module, "top"); + op_module_paths.emplace_back(getModulePath(node, root_scope_string)); } } // CALL nodes at this point represent built-in (i.e. non-Graph) @@ -109,14 +143,43 @@ std::pair> getFunctionTuple( auto method_name_idx = code.constant_table().size() + method_names.size(); method_names.emplace_back(node->s(attr::name)); - Instruction new_instr{INTERFACE_CALL, - static_cast(method_name_idx), - static_cast(node->inputs().size())}; - instructions_copy[i] = std::move(new_instr); + Instruction new_instr{ + INTERFACE_CALL, + static_cast(method_name_idx), + static_cast(node->inputs().size())}; + instructions_copy[i] = new_instr; } else { TORCH_INTERNAL_ASSERT( false, "Unsupported node kind on CALL opcode for mobile"); } + } else if (ins.op == RET) { + auto node = code.instructions_source()[i]; + for (const auto& input : node->inputs()) { + const auto& input_type = input->type(); + if (input_type->kind() == TypeKind::TupleType) { + if (const auto& name_typed_input = + input_type->cast()) { + TORCH_CHECK( + !name_typed_input->name(), + "A named tuple type is not supported in mobile module. ", + "Workaround: instead of using a named tuple type's fields, ", + "use a dictionary type's key-value pair itmes or ", + "a pytorch class (class Foo(torch.nn.Module))'s attributes.'"); + } + } else if ( + input_type->kind() == TypeKind::ListType || + input_type->kind() == TypeKind::DictType) { + for (const TypePtr& element_type : input_type->containedTypes()) { + TORCH_CHECK( + element_type->kind() != TypeKind::ClassType, + "Returining a list or dictionary with pytorch class type ", + "is not supported in mobile module " + "(List[Foo] or Dict[int, Foo] for class Foo(torch.nn.Module)). " + "Workaround: instead of using pytorch class as their element type, ", + "use a combination of list, dictionary, and single types."); + } + } + } } else { TORCH_CHECK( ins.op != CREATE_OBJECT, @@ -164,11 +227,12 @@ std::pair> getFunctionTuple( // register size auto register_size = static_cast(code.register_size()); - auto table = Table({{"instructions", Tup(instructions)}, - {"operators", Tup(operators)}, - {"constants", Tup(constants)}, - {"types", Tup(types)}, - {"register_size", register_size}}); + auto table = Table( + {{"instructions", Tup(instructions)}, + {"operators", Tup(operators)}, + {"constants", Tup(constants)}, + {"types", Tup(types)}, + {"register_size", register_size}}); auto bytecode_vals = Tup({func.qualname().qualifiedName(), table}); c10::optional debug_info_vals; @@ -244,12 +308,12 @@ void moduleMethodsTuple( } void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) { - GetExtraFilesHook() = hook; + GetExtraFilesHook() = std::move(hook); } void SetExportModuleMobileInfoConverter( ExportModuleMobileInfoConverter converter) { - GetMobileInfoConverter() = converter; + GetMobileInfoConverter() = std::move(converter); } class ScriptModuleSerializer { @@ -292,7 +356,7 @@ class ScriptModuleSerializer { void writeArchive(const std::string& archive_name, const IValue& value) { std::vector data; // Vector to capture the run-time class types during pickling the IValues - std::vector memorizedClassTypes; + std::vector memoizedClassTypes; Pickler data_pickle( [&](const char* buf, size_t size) { data.insert(data.end(), buf, buf + size); @@ -301,7 +365,7 @@ class ScriptModuleSerializer { [&](const c10::ClassTypePtr& t) { return type_name_uniquer_.getUniqueName(t); }, - &memorizedClassTypes); + &memoizedClassTypes); data_pickle.protocol(); data_pickle.pushIValue(value); data_pickle.stop(); @@ -316,7 +380,7 @@ class ScriptModuleSerializer { writer_.writeRecord(fname, data.data(), data.size()); // serialize all the captured run-time class types - for (const c10::ClassTypePtr& wroteType : memorizedClassTypes) { + for (const c10::ClassTypePtr& wroteType : memoizedClassTypes) { convertNamedType(wroteType); } } @@ -446,7 +510,7 @@ class ScriptModuleSerializer { }; if (!pp) { pp = &file_streams_.insert( - qualifier, + std::move(qualifier), PythonPrint( constant_table_, class_deps_, @@ -538,5 +602,26 @@ std::vector export_opnames(const script::Module& m) { return std::vector(names.begin(), names.end()); } +namespace mobile { + +std::set _export_operator_list( + torch::jit::mobile::Module& module) { + std::set operator_list; + for (Method func : module.get_methods()) { + const Function& function = func.function(); + const std::shared_ptr cptr = function.get_code(); + // op_names below isn't a list of unique operator names. In fact + // it can contain the same operator name many many times, so we need + // to de-dup the list by adding all the operator names into + // an std::set. + std::vector const& op_names = cptr->op_names_; + for (auto& op_name : op_names) { + operator_list.insert(toString(op_name)); + } + } + return operator_list; +} + +} // namespace mobile } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/serialization/import.cpp b/torch/csrc/jit/serialization/import.cpp index a5810f23505d6..159ded3f01ee1 100644 --- a/torch/csrc/jit/serialization/import.cpp +++ b/torch/csrc/jit/serialization/import.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -108,8 +109,8 @@ class ScriptModuleDeserializer final { public: ScriptModuleDeserializer( std::shared_ptr cu, - std::unique_ptr reader) - : compilation_unit_(cu), + std::shared_ptr reader) + : compilation_unit_(std::move(cu)), reader_(std::move(reader)), source_importer_( compilation_unit_, @@ -128,7 +129,7 @@ class ScriptModuleDeserializer final { IValue readArchive(const std::string& archive_name); std::shared_ptr compilation_unit_; - std::unique_ptr reader_; + std::shared_ptr reader_; c10::optional device_; std::vector constants_table_; SourceImporter source_importer_; @@ -144,7 +145,7 @@ IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) { // Decouple how to get obj from type. In this file it's dependent on // Method.run() and graph executor, etc. // For bytecode import we need to decouple these dependencies. - auto obj_loader = [&](at::StrongTypePtr type, IValue input) { + auto obj_loader = [&](const at::StrongTypePtr& type, IValue input) { auto cls = type.type_->expect(); auto qn = cls->name(); size_t n = cls->numAttributes(); @@ -175,7 +176,6 @@ IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) { return obj; } }; - return readArchiveAndTensors( archive_name, type_resolver, obj_loader, device_, *reader_.get()); } @@ -257,8 +257,7 @@ Module ScriptModuleDeserializer::deserialize( } if (reader_->hasRecord("model.json")) { #if !defined(C10_MOBILE) && !defined(C10_DISABLE_LEGACY_IMPORT) - return torch::jit::LEGACY_deserialize( - compilation_unit_, std::move(reader_), device_); + return torch::jit::LEGACY_deserialize(compilation_unit_, reader_, device_); #else AT_ERROR("Legacy model format is not supported on mobile."); #endif @@ -271,7 +270,6 @@ Module ScriptModuleDeserializer::deserialize( rewriteQuantizedConvForBC(m); return m; } - } // namespace Module import_ir_module( @@ -323,7 +321,7 @@ Module load( } Module load( - std::unique_ptr rai, + std::shared_ptr rai, c10::optional device, ExtraFilesMap& extra_files) { // Verify that we're loading a zip archive and not a torch.save pickle archive @@ -347,7 +345,7 @@ Module load( " produced by `torch.jit.save()`"); } - auto reader = torch::make_unique(std::move(rai)); + auto reader = std::make_shared(std::move(rai)); auto cu = std::make_shared(); ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader)); diff --git a/torch/csrc/jit/serialization/import.h b/torch/csrc/jit/serialization/import.h index 543a1ca32aaf2..cbfb765a6350e 100644 --- a/torch/csrc/jit/serialization/import.h +++ b/torch/csrc/jit/serialization/import.h @@ -55,13 +55,13 @@ TORCH_API Module load( c10::optional device = c10::nullopt, ExtraFilesMap& extra_files = default_extra_files); -/// Loads a serialized `Module` from the given `rai`. +/// Loads a serialized `Module` from the given shared_ptr `rai`. /// /// The reader adapter, which is for customized input stream, must contain a /// serialized `Module`, exported either via `ScriptModule.save()` in /// Python or `torch::jit::ExportModule` in C++. TORCH_API Module load( - std::unique_ptr rai, + std::shared_ptr rai, c10::optional device = c10::nullopt, ExtraFilesMap& extra_files = default_extra_files); diff --git a/torch/csrc/jit/serialization/import_export_helpers.cpp b/torch/csrc/jit/serialization/import_export_helpers.cpp index 3f09d1f425f5f..e12e937b9268d 100644 --- a/torch/csrc/jit/serialization/import_export_helpers.cpp +++ b/torch/csrc/jit/serialization/import_export_helpers.cpp @@ -1,4 +1,5 @@ #include + #include #include #include diff --git a/torch/csrc/jit/serialization/import_legacy.cpp b/torch/csrc/jit/serialization/import_legacy.cpp index 67db88e6d871c..40e035b820907 100644 --- a/torch/csrc/jit/serialization/import_legacy.cpp +++ b/torch/csrc/jit/serialization/import_legacy.cpp @@ -25,8 +25,8 @@ void postSetStateValidate(const IValue& v); namespace { struct ClassResolver : public Resolver { - explicit ClassResolver(SourceImporter source_importer) - : source_importer_(std::move(source_importer)) {} + explicit ClassResolver(const SourceImporter& source_importer) + : source_importer_(source_importer) {} TypePtr resolveType(const std::string& name, const SourceRange& loc) override { return source_importer_.loadType(c10::QualifiedName(name)); @@ -40,9 +40,9 @@ class ScriptModuleDeserializer final { public: ScriptModuleDeserializer( std::shared_ptr cu, - std::unique_ptr reader, + std::shared_ptr reader, const c10::optional& device) - : compilation_unit_(cu), + : compilation_unit_(std::move(cu)), reader_(std::move(reader)), device_(device), source_importer_( @@ -76,7 +76,7 @@ class ScriptModuleDeserializer final { std::shared_ptr sourceLoader(const std::string& qualifier); std::shared_ptr compilation_unit_; - std::unique_ptr reader_; + std::shared_ptr reader_; c10::optional device_; // Legacy only tensor can be a constant. std::vector constant_table_; @@ -130,7 +130,7 @@ Module ScriptModuleDeserializer::LEGACY_deserialize() { LEGACY_pickled_ivalues_ = LEGACY_loadPickleArchive("attributes.pkl").toTuple()->elements(); } - LEGACY_moduleStack_.push_back("__torch__"); + LEGACY_moduleStack_.emplace_back("__torch__"); const auto& module_def = model_def.main_module(); // Move tensors in constant table. @@ -267,7 +267,7 @@ void ScriptModuleDeserializer::LEGACY_moduleSetState( if (setstate->num_inputs() == 1) { setstate->run({module._ivalue()}); } else if (setstate->num_inputs() == 2) { - setstate->run({module._ivalue(), state}); + setstate->run({module._ivalue(), std::move(state)}); } else { AT_ERROR("Unexpected schema on '__setstate__'"); } @@ -383,9 +383,10 @@ Module ScriptModuleDeserializer::LEGACY_convertModule( Module LEGACY_deserialize( std::shared_ptr cu, - std::unique_ptr reader, + std::shared_ptr reader, const c10::optional& device) { - ScriptModuleDeserializer deserializer(cu, std::move(reader), device); + ScriptModuleDeserializer deserializer( + std::move(cu), std::move(reader), device); return deserializer.LEGACY_deserialize(); } diff --git a/torch/csrc/jit/serialization/import_legacy.h b/torch/csrc/jit/serialization/import_legacy.h index 64f8a7da1968d..a261828109596 100644 --- a/torch/csrc/jit/serialization/import_legacy.h +++ b/torch/csrc/jit/serialization/import_legacy.h @@ -16,7 +16,7 @@ struct CompilationUnit; // Deserializes a model in legacy format. Module LEGACY_deserialize( std::shared_ptr cu, - std::unique_ptr reader, + std::shared_ptr reader, const c10::optional& device); } // namespace jit diff --git a/torch/csrc/jit/serialization/import_source.cpp b/torch/csrc/jit/serialization/import_source.cpp index b5a18cf05710b..79236ccb72f1e 100644 --- a/torch/csrc/jit/serialization/import_source.cpp +++ b/torch/csrc/jit/serialization/import_source.cpp @@ -93,11 +93,11 @@ struct ConstantTableValue : public SugaredValue { struct SourceImporterImpl : public Resolver, std::enable_shared_from_this { SourceImporterImpl( - const std::shared_ptr cu, + std::shared_ptr cu, const std::vector* constant_table, SourceLoader source_loader, size_t version) - : cu_(cu), source_loader_(std::move(source_loader)) { + : cu_(std::move(cu)), source_loader_(std::move(source_loader)) { env_ = { {"torch", std::make_shared("aten", version)}, {"ops", std::make_shared(version)}, diff --git a/torch/csrc/jit/serialization/pickle.cpp b/torch/csrc/jit/serialization/pickle.cpp index f1bd700a918da..27b5fe2cef68c 100644 --- a/torch/csrc/jit/serialization/pickle.cpp +++ b/torch/csrc/jit/serialization/pickle.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -75,7 +76,7 @@ std::vector pickle_save(const at::IValue& ivalue) { #ifndef C10_MOBILE class VectorReader : public caffe2::serialize::ReadAdapterInterface { public: - VectorReader(const std::vector& data) : data_(std::move(data)) {} + VectorReader(std::vector data) : data_(std::move(data)) {} size_t size() const override { return data_.size(); @@ -103,7 +104,7 @@ IValue pickle_load(const std::vector& data) { return readArchiveAndTensors( "data", - /*class_resolver=*/c10::nullopt, + /*type_resolver=*/c10::nullopt, /*obj_loader=*/c10::nullopt, /*device=*/c10::nullopt, reader); diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp index 6f911f4246ccb..8115694858889 100644 --- a/torch/csrc/jit/serialization/pickler.cpp +++ b/torch/csrc/jit/serialization/pickler.cpp @@ -1,6 +1,6 @@ #include #include -#ifdef USE_DISTRIBUTED +#ifdef USE_RPC #include #endif #include @@ -90,11 +90,11 @@ void Pickler::pushIValueImpl(const IValue& ivalue) { } else if (ivalue.isObject()) { auto obj = ivalue.toObject(); auto type = obj->type(); - if (memorized_class_types_ != nullptr) { - // Memorize every class type the Pickler encountered + if (memoized_class_types_ != nullptr) { + // memoize every class type the Pickler encountered // This is used to make sure we capture all the run-time types // and serialize them properly for class/interface polymorphism - memorized_class_types_->emplace_back(type); + memoized_class_types_->emplace_back(type); } auto type_name = type->name().value(); if (type_renamer_) { @@ -121,8 +121,8 @@ void Pickler::pushIValueImpl(const IValue& ivalue) { } else if (ivalue.isCapsule()) { std::stringstream err; err << "Cannot serialize custom bound C++ class"; - if (memorized_class_types_ && memorized_class_types_->size()) { - if (auto qualname = memorized_class_types_->back()->name()) { + if (memoized_class_types_ && memoized_class_types_->size()) { + if (auto qualname = memoized_class_types_->back()->name()) { err << " " << qualname->qualifiedName(); } } @@ -130,7 +130,7 @@ void Pickler::pushIValueImpl(const IValue& ivalue) { "this class."; AT_ERROR(err.str()); } else if (ivalue.isRRef()) { -#ifdef USE_DISTRIBUTED +#ifdef USE_RPC TORCH_CHECK( torch::distributed::rpc::getAllowJitRRefPickle() == true, "RRef jit pickling is only allowed inside RPC calls."); @@ -166,7 +166,7 @@ void Pickler::pushDevice(const IValue& ivalue) { } } -#ifdef USE_DISTRIBUTED +#ifdef USE_RPC void Pickler::pushRRef(const IValue& ivalue) { // It is the same as how rref is pickled in python, see PyRRef::pickle auto rrefInterface = ivalue.toRRef(); @@ -354,7 +354,7 @@ void Pickler::pushLiteralTensor(const IValue& ivalue) { // // The format here is the same one used by `torch.save()`. The code for the // format can be found in `torch/serialization.py`. - auto tensor = ivalue.toTensor(); + auto& tensor = ivalue.toTensor(); bool quantized = tensor.is_quantized(); // The arguments to this function are: // storage, storage_offset, size, stride, requires_grad, backward_hooks @@ -393,11 +393,9 @@ void Pickler::pushLiteralTensor(const IValue& ivalue) { break; case at::kPerChannelAffineFloatQParams: case at::kPerChannelAffine: { - const auto* quantizer = static_cast( - tensor.quantizer().get()); - pushTensor(quantizer->scales()); - pushTensor(quantizer->zero_points()); - pushInt(quantizer->axis()); + pushTensor(tensor.q_per_channel_scales()); + pushTensor(tensor.q_per_channel_zero_points()); + pushInt(tensor.q_per_channel_axis()); } break; default: TORCH_CHECK( @@ -598,12 +596,14 @@ void Pickler::pushTuple(const IValue& ivalue) { } } -WriteableTensorData getWriteableTensorData(const at::Tensor& tensor) { +WriteableTensorData getWriteableTensorData( + const at::Tensor& tensor, + bool to_cpu) { WriteableTensorData result; result.tensor_ = tensor; result.size_ = tensor.storage().nbytes(); // TODO HIP support - if (tensor.storage().device_type() == DeviceType::CUDA) { + if (tensor.storage().device_type() == DeviceType::CUDA && to_cpu) { // NB: This new tensor is created to support cuda tensors. // Storages can be mutated when converting tensors from cuda to cpu, // and we need a cpu tensor to copy data from. diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index 2ff9fe68b1bf3..21d0f61a18eb7 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -107,7 +108,7 @@ struct WriteableTensorData { private: friend TORCH_API WriteableTensorData - getWriteableTensorData(const at::Tensor& tensor); + getWriteableTensorData(const at::Tensor& tensor, bool to_cpu); at::Tensor tensor_; uint64_t size_; }; @@ -120,17 +121,17 @@ class TORCH_API Pickler { public: Pickler(std::function writer) - : Pickler(writer, nullptr, nullptr, nullptr) {} + : Pickler(std::move(writer), nullptr, nullptr, nullptr) {} Pickler( std::function writer, std::vector* tensor_table, std::function type_renamer, - std::vector* memorized_class_types) - : writer_(writer), + std::vector* memoized_class_types) + : writer_(std::move(writer)), tensor_table_(tensor_table), - type_renamer_(type_renamer), - memorized_class_types_(memorized_class_types) {} + type_renamer_(std::move(type_renamer)), + memoized_class_types_(memoized_class_types) {} ~Pickler(); // Push protocol onto the stack @@ -208,7 +209,7 @@ class TORCH_API Pickler { // the left of a '::', its type cannot be deduced by the compiler so one must // explicitly instantiate the template, i.e. push(int) works, push(int) // does not) - static constexpr size_t kBufferSize = 256; + static CONSTEXPR_EXCEPT_WIN_CUDA size_t kBufferSize = 256; template void push(typename std::common_type::type value) { const char* begin = reinterpret_cast(&value); @@ -252,7 +253,7 @@ class TORCH_API Pickler { std::function type_renamer_; // List of all the types that it wrote, inspect from the IValues it wrote. - std::vector* memorized_class_types_; + std::vector* memoized_class_types_; // List of tensor storages to serialize in the same binary as the pickle data // similar to ivalues, they are memoized using BINPUT @@ -265,8 +266,9 @@ class TORCH_API Pickler { }; // returns a (tensor, record_size) for a tensor, converting it to a CPU tensor -// if necessary -TORCH_API WriteableTensorData getWriteableTensorData(const at::Tensor& tensor); +// if it was CUDA and to_cpu is True. +TORCH_API WriteableTensorData +getWriteableTensorData(const at::Tensor& tensor, bool to_cpu = true); // return the value of the tensor's storage pointer uint64_t getStorageKey(const at::Tensor& tensor); diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp index e04339dacc226..6b61b7cb51776 100644 --- a/torch/csrc/jit/serialization/python_print.cpp +++ b/torch/csrc/jit/serialization/python_print.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -309,12 +310,12 @@ struct PythonPrintImpl { // because it doesn't hash any information about the tensors. // We will probably need to optimize this at some point using hashing. if (val.isTensor()) { - auto t = val.toTensor(); + auto& t = val.toTensor(); for (size_t i = 0; i < constant_table_.size(); ++i) { if (!constant_table_[i].isTensor()) { continue; } - auto t2 = constant_table_[i].toTensor(); + auto& t2 = constant_table_[i].toTensor(); if (t.options().type_equal(t2.options()) && t.equal(t2)) { return i; } @@ -801,7 +802,7 @@ struct PythonPrintImpl { } level--; } break; - case prim::Function: { + case prim::Closure: { if (enforce_importable_) { throw ErrorReport(node->sourceRange()) << "closures are not exportable"; @@ -822,6 +823,15 @@ struct PythonPrintImpl { body_ << "):\n"; printBody(graph->block()); } break; + case prim::ModuleDictIndex: { + const auto dict = node->inputs().at(0); + const auto key = node->inputs().at(1); + const auto out = node->outputs().at(0); + assignValuesToTheirUniqueNames(out); + indent(); + body_ << useOf(out) << " : " << out->type()->annotation_str() << " = " + << useOf(dict) << "[" << useOf(key) << "]\n"; + } break; default: auto ss = std::make_shared(&source_range_stack_); printRHS(*ss, node); @@ -864,15 +874,16 @@ struct PythonPrintImpl { void printConstant(TaggedStringStream& stmt, const IValue& v) { const auto customFormatter = [&](std::ostream& ss, const IValue& v) { - if (v.isTensor() || containsNonASCIIString(v)) { + if (v.isTensor() || containsNonASCIIString(v) || v.isObject()) { + TORCH_INTERNAL_ASSERT(!v.type()->is_module()); ss << "CONSTANTS.c" << getOrAddConstant(v); return true; } - if (v.isTuple() && v.type()->expect()->schema()) { + if (v.isTuple() && v.type()->expectRef().schema()) { // print the namedtuple constructor and let rest of tuple printing // continue - ss << v.type()->expect()->annotation_str(type_printer_); + ss << v.type()->expectRef().annotation_str(type_printer_); } return false; }; @@ -970,7 +981,7 @@ struct PythonPrintImpl { } break; case prim::TupleConstruct: { if (auto qualname = - node->output()->type()->expect()->name()) { + node->output()->type()->expectRef().name()) { stmt << node->output()->type()->annotation_str(type_printer_); } printValueList( @@ -1246,6 +1257,7 @@ struct PythonPrintImpl { body_ << "def " << func.name() << "("; auto param_it = graph.inputs().begin(); for (const Argument& arg : schema.arguments()) { + registerClassDependencies(arg.type()); std::string arg_name = genName(arg.name()); if (param_it == graph.inputs().begin()) { // the first argument may omit its type when it is implied by context @@ -1264,9 +1276,10 @@ struct PythonPrintImpl { assignValue(*param_it++, arg_name); } - body_ << ") -> " - << schema.returns().at(0).type()->annotation_str(type_printer_) - << ":\n"; + const auto& returnType = schema.returns().at(0).type(); + body_ << ") -> " << returnType->annotation_str(type_printer_) << ":\n"; + registerClassDependencies(returnType); + printBody(graph.block()); } @@ -1327,15 +1340,13 @@ struct PythonPrintImpl { body_ << "\"" << param << "\", "; } body_ << "]\n"; -#ifndef FBCODE_CAFFE2 - // Note: Forward compat gated. TODO: @voznesenskym to remove when ready. + indent(); body_ << "__buffers__ = ["; for (const auto& buffer : buffers) { body_ << "\"" << buffer << "\", "; } body_ << "]\n"; -#endif } for (size_t i = 0; i < numAttrs; i++) { @@ -1454,7 +1465,7 @@ struct PythonPrintImpl { } } - ~PythonPrintImpl() {} + ~PythonPrintImpl() = default; TaggedStringStream body_; // When printing this node, is it safe to write it inline (i.e. without diff --git a/torch/csrc/jit/serialization/source_range_serialization.cpp b/torch/csrc/jit/serialization/source_range_serialization.cpp index 16d3de491965f..9f158e48f0e3d 100644 --- a/torch/csrc/jit/serialization/source_range_serialization.cpp +++ b/torch/csrc/jit/serialization/source_range_serialization.cpp @@ -87,8 +87,8 @@ SourceRangePickler::SourceRangePickler() : srs(new SourceRangeSerializer()) {} std::vector SourceRangePickler::pickle(const SourceRangeRecords& ranges) { std::vector ivalues; for (const auto& range : ranges) { - std::vector row_elems{(int64_t)range.bytes, - srs->serialize(range.range)}; + std::vector row_elems{ + (int64_t)range.bytes, srs->serialize(range.range)}; ivalues.emplace_back(c10::ivalue::Tuple::create(std::move(row_elems))); } std::vector table; diff --git a/torch/csrc/jit/serialization/source_range_serialization.h b/torch/csrc/jit/serialization/source_range_serialization.h index aaf3684597806..3e47b14cd9979 100644 --- a/torch/csrc/jit/serialization/source_range_serialization.h +++ b/torch/csrc/jit/serialization/source_range_serialization.h @@ -32,7 +32,7 @@ class SourceRangeUnpickler { virtual c10::optional findSourceRangeThatGenerated( const SourceRange& range) = 0; - virtual ~SourceRangeUnpickler() {} + virtual ~SourceRangeUnpickler() = default; }; } // namespace jit diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index c416f96410234..841e87592be92 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -1,6 +1,6 @@ #include #include -#ifdef USE_DISTRIBUTED +#ifdef USE_RPC #include #endif #include @@ -29,7 +29,7 @@ static void restoreAccurateTypeTagsIfPossible(const IValue& root) { // objects it contains as attributes. // `IfPossible` - we can only do this recovery when we have an object as // the top-level unpickled thing (which is guaranteed for Modules, but -// not for torch.load/torch,save). Otherwise we do not know the types +// not for torch.load/torch.save). Otherwise we do not know the types // of the contained objects and cannot restore the tags. void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) { struct Work { @@ -54,6 +54,7 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) { } switch (w.static_type->kind()) { case TensorType::Kind: + case StorageType::Kind: case NumberType::Kind: case FloatType::Kind: case IntType::Kind: @@ -67,6 +68,7 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) { case StringType::Kind: case FunctionType::Kind: case DeviceObjType::Kind: + case StreamObjType::Kind: case QSchemeType::Kind: case LayoutType::Kind: case ScalarTypeType::Kind: @@ -113,7 +115,7 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) { auto elem_type = w.static_type->cast()->getElementType(); auto lst = w.value.toList(); lst.unsafeSetElementType(elem_type); - for (const IValue& item : lst) { + for (const IValue item : lst) { Work elem = {elem_type, item}; to_process.emplace_back(std::move(elem)); } @@ -146,7 +148,7 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) { } } -void restoreContainerTypeTags(IValue& ivalue, TypePtr type) { +void restoreContainerTypeTags(IValue& ivalue, const TypePtr& type) { if (auto dict_type = type->cast()) { auto dict = ivalue.toGenericDict(); dict.unsafeSetKeyType(dict_type->getKeyType()); @@ -222,7 +224,7 @@ void Unpickler::setInput(size_t memo_id) { // avoid it by calling push_back for bool template inline void append(std::vector& a, T&& e) { - a.emplace_back(std::move(e)); + a.emplace_back(std::forward(e)); } template <> inline void append(std::vector& a, bool&& e) { @@ -352,7 +354,7 @@ PickleOpCode Unpickler::readInstruction() { dict.insert_or_assign(stack_[i], stack_[i + 1]); } stack_.erase(stack_.begin() + start, stack_.end()); - stack_.push_back(std::move(dict)); + stack_.emplace_back(std::move(dict)); } break; case PickleOpCode::SETITEMS: { size_t start = marks_.back(); @@ -416,6 +418,12 @@ PickleOpCode Unpickler::readInstruction() { /*resizable=*/false); // NB: we didn't set any allocator for the // tensor auto options = at::CPU(type).options(); + + if (use_storage_device_) { + options = options.device(storage.device()); + device = storage.device(); + } + at::Tensor tensor; if (options.backend() == c10::Backend::QuantizedCPU) { tensor = at::_empty_affine_quantized({}, options, 0, 0) @@ -431,7 +439,7 @@ PickleOpCode Unpickler::readInstruction() { "supported devices include CPU and CUDA, however got ", DeviceTypeName(device.type(), false)); } - stack_.push_back(std::move(tensor)); + stack_.emplace_back(std::move(tensor)); } break; default: { AT_ERROR( @@ -549,7 +557,7 @@ void Unpickler::readGlobal( stack_.emplace_back(int64_t(globals_.size() - 1)); return; } else if (module_name == "torch.distributed.rpc" && class_name == "rref") { -#ifdef USE_DISTRIBUTED +#ifdef USE_RPC return rebuildRRef(); #else TORCH_INTERNAL_ASSERT( @@ -624,7 +632,7 @@ void Unpickler::rebuildTensor(bool quantized) { auto tup = pop(stack_).toTuple(); const auto& elements = tup->elements(); size_t idx = 0; - auto storage_tensor = elements.at(idx++).toTensor(); + auto& storage_tensor = elements.at(idx++).toTensor(); int64_t storage_offset = elements.at(idx++).toInt(); std::vector size = tupleToIntList(elements.at(idx++)); std::vector stride = tupleToIntList(elements.at(idx++)); @@ -665,11 +673,11 @@ void Unpickler::rebuildTensor(bool quantized) { impl->set_storage_offset(storage_offset); impl->set_sizes_and_strides(size, stride); result = autograd::make_variable(result, requires_grad); - stack_.push_back(std::move(result)); + stack_.emplace_back(std::move(result)); }); } -#ifdef USE_DISTRIBUTED +#ifdef USE_RPC void Unpickler::rebuildRRef() { globals_.emplace_back([this] { // It is the same as how rref is unpickled in python, diff --git a/torch/csrc/jit/serialization/unpickler.h b/torch/csrc/jit/serialization/unpickler.h index 1138b8d42e040..acf6ca9965b7b 100644 --- a/torch/csrc/jit/serialization/unpickler.h +++ b/torch/csrc/jit/serialization/unpickler.h @@ -31,9 +31,10 @@ class TORCH_API Unpickler { std::function reader, TypeResolver type_resolver, const std::vector* tensor_table) - : reader_(reader), + : reader_(std::move(reader)), tensor_table_(tensor_table), type_resolver_(std::move(type_resolver)), + use_storage_device_(false), version_(caffe2::serialize::kProducedFileFormatVersion) {} // tensors inside the pickle contain meta-data, the raw tensor @@ -43,13 +44,15 @@ class TORCH_API Unpickler { TypeResolver type_resolver, ObjLoader obj_loader, std::function read_record, - c10::optional device) - : reader_(reader), + c10::optional device, + bool use_storage_device = false) + : reader_(std::move(reader)), tensor_table_(nullptr), type_resolver_(std::move(type_resolver)), obj_loader_(std::move(obj_loader)), read_record_(std::move(read_record)), device_(std::move(device)), + use_storage_device_(use_storage_device), version_(caffe2::serialize::kProducedFileFormatVersion) {} // consume the pickle stream, producing an IValue from the contents. @@ -139,6 +142,10 @@ class TORCH_API Unpickler { std::function read_record_; c10::optional device_; + // When set to true, Unpickler will ignore the pickled device and use the + // device of the DataPtr returned by the read_record_ function. The default + // value of this flag is false. + const bool use_storage_device_; // See [type tag serialization] uint64_t version_; diff --git a/torch/csrc/jit/tensorexpr/analysis.h b/torch/csrc/jit/tensorexpr/analysis.h index b9f957594bb39..f4ead7fcbfa80 100644 --- a/torch/csrc/jit/tensorexpr/analysis.h +++ b/torch/csrc/jit/tensorexpr/analysis.h @@ -38,12 +38,18 @@ class NodeFinder : public IRVisitor { IRVisitor::visit(v); } - static std::vector find(Stmt* s) { + static std::vector find(const Stmt* s) { NodeFinder nf; s->accept(&nf); return nf.nodes; } + static std::vector find(const Expr* e) { + NodeFinder nf; + e->accept(&nf); + return nf.nodes; + } + std::vector nodes; }; @@ -169,15 +175,15 @@ class CreateBufferMap : public IRVisitor { private: void visit(const Store* v) override { auto load_node = dynamic_cast(v->value()); - auto call_node = dynamic_cast(v->value()); - if (load_node || call_node) { - TORCH_INTERNAL_ASSERT(!(load_node && call_node)); - auto t_buf = load_node ? load_node->buf() : call_node->tensor()->buf(); - if (load_node) { - map_input_to_tensor_bufs_.emplace(t_buf->name_hint(), v->buf()); - } else { - map_input_to_tensor_bufs_.emplace(v->buf()->name_hint(), t_buf); - } + if (load_node) { + auto t_buf = load_node->buf(); + map_input_to_tensor_bufs_.emplace(t_buf->name_hint(), v->buf()); + } else { + auto add_node = dynamic_cast(v->value()); + auto mul_node = dynamic_cast(v->value()); + // This means for now, v->value() can be Add or Mul + TORCH_INTERNAL_ASSERT((add_node || mul_node)); + map_input_to_tensor_bufs_.emplace(v->buf()->name_hint(), v->buf()); } v->value()->accept(this); } diff --git a/torch/csrc/jit/tensorexpr/block_codegen.h b/torch/csrc/jit/tensorexpr/block_codegen.h index fcd88e040e176..21a4aecb5bcc3 100644 --- a/torch/csrc/jit/tensorexpr/block_codegen.h +++ b/torch/csrc/jit/tensorexpr/block_codegen.h @@ -111,8 +111,9 @@ class TORCH_API BlockCodeGen : public CodeGen { BlockCodeGen( Stmt* stmt, const std::vector& buffer_args, - at::Device device = at::Device(at::kCPU)) - : CodeGen(stmt, buffer_args, device) { + at::Device device = at::Device(at::kCPU), + const std::string& kernel_func_name = "func") + : CodeGen(stmt, buffer_args, device, kernel_func_name) { Initialize(); } diff --git a/torch/csrc/jit/tensorexpr/bounds_inference.cpp b/torch/csrc/jit/tensorexpr/bounds_inference.cpp index e9a211e5f9595..2424c2dfc45ac 100644 --- a/torch/csrc/jit/tensorexpr/bounds_inference.cpp +++ b/torch/csrc/jit/tensorexpr/bounds_inference.cpp @@ -10,70 +10,93 @@ namespace torch { namespace jit { namespace tensorexpr { -class BoundsInference : public IRVisitor { - public: - void visit(const FunctionCall* v) override; - void visit(const Load* v) override; - void visit(const Store* v) override; - void visit(const For* v) override; - void visit(const Block* v) override; - - BoundsInfo accesses() const { - return accesses_; - } +using namespace analysis; - private: - BoundsInfo accesses_; -}; +template +BoundsInfo mergeTensorAccesses( + const Container& accesses, + const std::unordered_map& varToBuf, + bool distinctAccessKinds) { + BoundsInfo ret; + for (auto& access : accesses) { + if (access->type() == AccessType::Input || + access->type() == AccessType::Output) { + continue; + } -void BoundsInference::visit(const Load* v) { - accesses_[v->buf()].push_back({kLoad, v->indices(), v->indices()}); -} + auto vtbIt = varToBuf.find(access->var()); + TORCH_INTERNAL_ASSERT(vtbIt != varToBuf.end()); + const Buf* buf = vtbIt->second; + std::vector& infos = ret[buf]; -void BoundsInference::visit(const FunctionCall* v) { - accesses_[v->tensor()->func_var()].push_back( - {kLoad, v->params(), v->params()}); -} + bool added = false; + // This loop should be small, max of 2 (kLoad, kStore). + for (auto& TABI : infos) { + TensorAccessKind kind = access->isWrite() ? kStore : kLoad; + if (!distinctAccessKinds || kind == TABI.kind) { + TORCH_INTERNAL_ASSERT(TABI.start.size() == access->bounds().size()); + TORCH_INTERNAL_ASSERT(TABI.stop.size() == access->bounds().size()); + for (size_t i = 0; i < TABI.start.size(); ++i) { + TABI.start[i] = IRSimplifier::simplify( + new Min(TABI.start[i], access->bounds()[i].start, true)); + TABI.stop[i] = IRSimplifier::simplify( + new Max(TABI.stop[i], access->bounds()[i].end, true)); + added = true; -void BoundsInference::visit(const Store* v) { - accesses_[v->buf()].push_back({kStore, v->indices(), v->indices()}); - IRVisitor::visit(v); -} + if (kind != TABI.kind) { + TABI.kind = kMutate; + } + } + } + } + + if (!added) { + TensorAccessBoundsInfo info; + info.kind = access->isWrite() ? kStore : kLoad; -void BoundsInference::visit(const For* v) { - v->body()->accept(this); - for (auto& pair : accesses_) { - for (TensorAccessBoundsInfo& access : pair.second) { - for (size_t j = 0; j < access.start.size(); j++) { - // TODO: This function assumes that all indices grow monotonically and - // thus for the loop: - // for i in A..B: - // buf[i] = i - // the range for i is [A, B). It should be generalized to correctly - // handle all cases. - const Expr* old_start = access.start[j]; - const Expr* old_stop = access.stop[j]; - const Expr* new_start = Substitute(old_start, {{v->var(), v->start()}}); - const Expr* new_stop = Substitute( - old_stop, {{v->var(), new Sub(v->stop(), new IntImm(1))}}); - - access.start[j] = IRSimplifier::simplify(new_start); - access.stop[j] = IRSimplifier::simplify(new_stop); + for (auto& b : access->bounds()) { + info.start.push_back(b.start); + info.stop.push_back(b.end); } + + infos.push_back(info); } } + + return ret; } -void BoundsInference::visit(const Block* v) { - BoundsInfo res; - for (auto s : *v) { - s->accept(this); - for (auto& pair : accesses_) { - res[pair.first].insert( - res[pair.first].end(), pair.second.begin(), pair.second.end()); - } +std::unordered_map getAllBufs(Stmt* s) { + std::unordered_map varToBuf; + + auto bufs = NodeFinder::find(s); + auto calls = NodeFinder::find(s); + for (auto* c : calls) { + bufs.push_back(c->tensor()->buf()); + } + + for (auto* b : bufs) { + varToBuf[b->base_handle()] = b; } - accesses_ = res; + return varToBuf; +} + +BoundsInfo inferBounds(Stmt* s, bool distinctAccessKinds) { + auto varToBuf = getAllBufs(s); + + MemDependencyChecker checker; + s->accept(&checker); + + return mergeTensorAccesses( + checker.getHistory(), varToBuf, distinctAccessKinds); +} + +BoundsInfo getInferredBounds( + MemDependencyChecker& analyzer, + Stmt* s, + bool distinctAccessKinds) { + return mergeTensorAccesses( + analyzer.accessesWithin(s), getAllBufs(s), distinctAccessKinds); } void printBoundsInfo(const BoundsInfo& v) { @@ -117,146 +140,120 @@ void printBoundsInfo(const BoundsInfo& v) { std::cerr << "}\n"; } -bool equalExprs(const Expr* A, const Expr* B) { - const Expr* diff = IRSimplifier::simplify(new Sub(B, A)); - return diff->isConstant() && immediateEquals(diff, 0); -} +std::vector getBoundExtents( + const std::vector& infos) { + std::vector starts; + std::vector stops; -// returns the bounds of an overlapping range, or {nullptr, nullptr} if the -// ranges don't overlap. -std::pair rangeOverlap( - const Expr* s1, - const Expr* e1, - const Expr* s2, - const Expr* e2) { - // If they're equal they're equal. - if (equalExprs(s1, s2) && equalExprs(e1, e2)) { - return {s1, e1}; - } + // Find the safe size of the temprorary buffer by determining the outer + // extents of a union of all bounds. + for (const TensorAccessBoundsInfo& p : infos) { + for (size_t i = 0; i < p.start.size(); i++) { + if (starts.size() <= i) { + starts.push_back(p.start[i]); + } else { + starts[i] = + IRSimplifier::simplify(new Min(starts[i], p.start[i], true)); + } - std::pair noOverlap = {nullptr, nullptr}; - std::pair overlap = { - IRSimplifier::simplify(new Min(s1, s2, true)), - IRSimplifier::simplify(new Max(e1, e2, true))}; - - const Expr* lowDiff = IRSimplifier::simplify(new Sub(s1, e2)); - const Expr* highDiff = IRSimplifier::simplify(new Sub(s2, e1)); - if (lowDiff->isConstant() && highDiff->isConstant()) { - // No overlap. - if (!(immediateAs(lowDiff) <= 1 || immediateAs(highDiff) >= 1)) { - return noOverlap; + if (stops.size() <= i) { + stops.push_back(p.stop[i]); + } else { + stops[i] = IRSimplifier::simplify(new Max(stops[i], p.stop[i], true)); + } } - - return overlap; } - // Can still merge if we can infer adjacency without knowing static values: - // If we know one side, we can use the fact that each eX >= sX. - if (highDiff->isConstant() && abs(immediateAs(highDiff)) <= 1) { - return {s1, e2}; - } + std::vector extents; + for (size_t i = 0; i < starts.size(); ++i) { + const Expr* dim = IRSimplifier::simplify( + new Add(new Sub(stops[i], starts[i]), new IntImm(1))); - if (lowDiff->isConstant() && abs(immediateAs(lowDiff)) <= 1) { - return {s2, e1}; + extents.push_back(dim); } - const Expr* diffs = IRSimplifier::simplify(new Sub(s2, s1)); - const Expr* diffe = IRSimplifier::simplify(new Sub(e2, e1)); + return extents; +} - // If one side fully encloses the other, they're adjacent. - if (diffs->isConstant() && diffe->isConstant()) { - int ds_i = immediateAs(diffs); - int de_i = immediateAs(diffe); - if ((ds_i <= 0 && de_i >= 0) || (ds_i >= 0 && de_i <= 0)) { - return overlap; - } - } +using BoundSet = std::unordered_set; - // If either the start or end is 1 element apart from it's pair, they must - // be adjacent. - if (diffs->isConstant() && abs(immediateAs(diffs)) <= 1) { - return overlap; +BoundSet convertBounds( + const std::vector& bounds, + TensorAccessKind filter = kMutate) { + BoundSet ret; + for (auto& TABI : bounds) { + if (filter == kMutate || TABI.kind == filter) { + for (size_t i = 0; i < TABI.start.size(); ++i) { + ret.insert(Bound(TABI.start[i], TABI.stop[i])); + } + } } + return ret; +} - if (diffe->isConstant() && abs(immediateAs(diffe)) <= 1) { - return overlap; +BoundSet convertBounds( + BoundsInfo& bounds, + const Buf* buf, + TensorAccessKind filter = kMutate) { + auto it = bounds.find(buf); + if (it == bounds.end()) { + return BoundSet(); } - return noOverlap; + return convertBounds(it->second, filter); } -/* - * Go through the given BoundsInfo vector and merge entries corresponding to - * the same buf. E.g. given - * [{a, kLoad, 0, 100}, {b, kStore, 0, 100}, {a, kLoad, 10, 110}] - * produce: - * [{a, kLoad, 0, 110}, {b, kStore, 0, 100}] - */ -BoundsInfo mergeTensorAccesses(const BoundsInfo& unmerged) { - BoundsInfo res; - // For each buf in the BoundsInfo: - for (auto& pair : unmerged) { - const std::vector& new_vec = pair.second; - std::vector& existing_vec = res[pair.first]; - - // For each bound pair in the unmerged set: - for (const auto& new_bound : new_vec) { - bool found = false; - // For each already merged bound pair: - for (auto& existing_bound : existing_vec) { - // Only merge the same kind of access. - if (existing_bound.kind != new_bound.kind) { - continue; - } +HazardKind getPotentialHazards( + MemDependencyChecker& analyzer, + Stmt* A, + Stmt* B) { + BoundsInfo aBounds = getInferredBounds(analyzer, A, true); + BoundsInfo bBounds = getInferredBounds(analyzer, B, true); - // Sanity check the buf indices have the same dimensionality. - TORCH_INTERNAL_ASSERT(new_bound.start.size() == new_bound.stop.size()); - TORCH_INTERNAL_ASSERT( - existing_bound.start.size() == existing_bound.stop.size()); - TORCH_INTERNAL_ASSERT( - new_bound.start.size() == existing_bound.start.size()); - - std::vector start; - std::vector stop; - bool fail = false; - // For each dimension: - for (size_t i = 0; i < new_bound.start.size(); ++i) { - // The range of the new bound must overlap the existing bound. - // TODO(nickg): we allow all dimensions to partially overlap, - // which will overstate the bounds. - auto pair = rangeOverlap( - new_bound.start[i], - new_bound.stop[i], - existing_bound.start[i], - existing_bound.stop[i]); - if (pair.first == nullptr) { - fail = true; - break; - } - start.push_back(pair.first); - stop.push_back(pair.second); + BoundSet aWrites; + BoundSet aReads; + + for (auto& pair : bBounds) { + const Buf* buf = pair.first; + if (aBounds.find(buf) == aBounds.end()) { + continue; + } + + auto aWrites = convertBounds(aBounds, buf, kStore); + auto aReads = convertBounds(aBounds, buf, kLoad); + + auto bWrites = convertBounds(pair.second, kStore); + auto bReads = convertBounds(pair.second, kLoad); + + // First, RAW. + for (auto& bR : bReads) { + for (auto& aW : aWrites) { + if (boundOverlap(bR, aW) != NoOverlap) { + return HazardKind::ReadAfterWrite; } - if (fail) { - continue; + } + } + + // Then WAR. + for (auto& bW : bWrites) { + for (auto& aR : aReads) { + if (boundOverlap(bW, aR) != NoOverlap) { + return HazardKind::WriteAfterRead; } - found = true; - // Update the existing bound. - existing_bound.start = start; - existing_bound.stop = stop; } - if (!found) { - existing_vec.push_back(new_bound); + } + + // Then WAW. + for (auto& bW : bWrites) { + for (auto& aW : aWrites) { + if (boundOverlap(bW, aW) != NoOverlap) { + return HazardKind::WriteAfterWrite; + } } } } - return res; -} - -BoundsInfo inferBounds(Stmt* s) { - BoundsInference ac; - s->accept(&ac); - return mergeTensorAccesses(ac.accesses()); + return HazardKind::NoDependency; } } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/bounds_inference.h b/torch/csrc/jit/tensorexpr/bounds_inference.h index 732e80358de2b..b5b58d09f0b6a 100644 --- a/torch/csrc/jit/tensorexpr/bounds_inference.h +++ b/torch/csrc/jit/tensorexpr/bounds_inference.h @@ -5,6 +5,7 @@ #include #include +#include namespace torch { namespace jit { @@ -14,7 +15,7 @@ class Expr; class Buf; class Stmt; -enum C10_API_ENUM TensorAccessKind { kLoad, kStore }; +enum C10_API_ENUM TensorAccessKind { kLoad, kStore, kMutate }; struct TORCH_API TensorAccessBoundsInfo { TensorAccessKind kind; @@ -25,11 +26,29 @@ struct TORCH_API TensorAccessBoundsInfo { using BoundsInfo = std::unordered_map>; -TORCH_API BoundsInfo inferBounds(Stmt* s); +TORCH_API BoundsInfo inferBounds(Stmt* s, bool distinctAccessKinds = true); + +// Bounds inference caching the analysis. The MemDependencyChecker must already +// have been run. +TORCH_API BoundsInfo getInferredBounds( + analysis::MemDependencyChecker& analyzer, + Stmt* s, + bool distinctAccessKinds = true); TORCH_API void printBoundsInfo(const BoundsInfo& v); -TORCH_API BoundsInfo mergeTensorAccesses(const BoundsInfo& unmerged); +TORCH_API std::vector getBoundExtents( + const std::vector& infos); + +// The kind of dependency found, in increasing order of exclusivity. +enum class HazardKind { + ReadAfterWrite, + WriteAfterRead, + WriteAfterWrite, + NoDependency, +}; +TORCH_API HazardKind +getPotentialHazards(analysis::MemDependencyChecker& analyzer, Stmt* A, Stmt* B); } // namespace tensorexpr } // namespace jit diff --git a/torch/csrc/jit/tensorexpr/bounds_overlap.cpp b/torch/csrc/jit/tensorexpr/bounds_overlap.cpp new file mode 100644 index 0000000000000..05aacf5cf3842 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/bounds_overlap.cpp @@ -0,0 +1,245 @@ +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { +namespace analysis { + +OverlapKind boundOverlap(Bound a, Bound b) { + // If they're equal they're equal. + bool startEqual = exprEquals(a.start, b.start); + bool endEqual = exprEquals(a.end, b.end); + if (startEqual && endEqual) { + return ContainedOrEqual; + } + + const Expr* lowDiff = IRSimplifier::simplify(new Sub(a.start, b.end)); + const Expr* highDiff = IRSimplifier::simplify(new Sub(b.start, a.end)); + + if (lowDiff->isConstant() && highDiff->isConstant()) { + int low = immediateAs(lowDiff); + int high = immediateAs(highDiff); + // No overlap. + if (low > 0 || high > 0) { + return NoOverlap; + } + } + + const Expr* diff_start = IRSimplifier::simplify(new Sub(b.start, a.start)); + const Expr* diff_end = IRSimplifier::simplify(new Sub(b.end, a.end)); + + // If one side fully encloses the other, they're adjacent. + if (diff_start->isConstant() && diff_end->isConstant()) { + int start = immediateAs(diff_start); + int end = immediateAs(diff_end); + // If diff_start and diff_end have different signs they are enclosing. + if (start <= 0 && end >= 0) { + return ContainedOrEqual; + } + + if (start >= 0 && end <= 0) { + return Contains; + } + } + + // We can't be sure there's no overlap so the conservative answer is + // partial. + return PartialOverlap; +} + +bool indexBoundsEquals(const IndexBounds& A, const IndexBounds& B) { + if (A.size() != B.size()) { + return false; + } + + for (size_t i = 0; i != A.size(); ++i) { + if (!A[i].equals(B[i])) { + return false; + } + } + return true; +} + +Bound flattenBounds(const IndexBounds& a) { + if (a.empty()) { + return Bound(); + } + Bound ret = a[0]; + + for (size_t i = 1; i < a.size(); ++i) { + ret.start = new Mul(ret.start, a[i].start); + ret.end = new Mul(ret.end, a[i].end); + } + + ret.start = IRSimplifier::simplify(ret.start); + ret.end = IRSimplifier::simplify(ret.end); + return ret; +} + +OverlapKind overlaps(const IndexBounds& a, const IndexBounds& b) { + if (a.empty() && b.empty()) { + return ContainedOrEqual; + } + + // All accesses to a buf must have the same dimensionality. + + if (a.size() != b.size()) { + return boundOverlap(flattenBounds(a), flattenBounds(b)); + } + TORCH_INTERNAL_ASSERT(a.size() == b.size()); + + OverlapKind overlap = boundOverlap(a[0], b[0]); + for (size_t i = 1; i < a.size(); ++i) { + OverlapKind bOverlap = boundOverlap(a[i], b[i]); + if (bOverlap == NoOverlap) { + return NoOverlap; + } + + if (overlap == ContainedOrEqual && bOverlap == Contains) { + overlap = Contains; + } + + if (overlap == Contains && bOverlap == ContainedOrEqual) { + continue; + } + + if (bOverlap != overlap) { + overlap = PartialOverlap; + break; + } + } + + return overlap; +} + +std::vector subtractBound(Bound a, Bound b, OverlapKind overlap) { + // The bounds must overlap. + std::vector res; + + if (a.start->isConstant() != b.start->isConstant() || + a.end->isConstant() != b.end->isConstant()) { + return {a}; + } + + const Expr* lowDiff = IRSimplifier::simplify(new Sub(b.start, a.start)); + const Expr* highDiff = IRSimplifier::simplify(new Sub(b.end, a.end)); + + // If the diff has only a single var, we can try to guess sign. + if (!lowDiff->isConstant()) { + auto vars = VarFinder::find(lowDiff); + if (vars.size() == 1) { + lowDiff = IRSimplifier::simplify(new Sub( + Substitute(b.start, {{*vars.begin(), new IntImm(1)}}), + Substitute(a.start, {{*vars.begin(), new IntImm(1)}}))); + } + } + + if (!highDiff->isConstant()) { + auto vars = VarFinder::find(highDiff); + if (vars.size() == 1) { + highDiff = IRSimplifier::simplify(new Sub( + Substitute(b.end, {{*vars.begin(), new IntImm(1)}}), + Substitute(a.end, {{*vars.begin(), new IntImm(1)}}))); + } + } + + bool hasHead = lowDiff->isConstant() && immediateAs(lowDiff) > 0; + bool hasTail = highDiff->isConstant() && immediateAs(highDiff) < 0; + + bool constantExtents = lowDiff->isConstant() && highDiff->isConstant(); + + if (!constantExtents) { + // If we can't infer the bound lengths, there's no way to create a safe + // subset. Just bail out. + return {a}; + } + + if (hasHead) { + res.emplace_back( + a.start, IRSimplifier::simplify(new Sub(b.start, new IntImm(1)))); + } + + if (hasTail) { + const Expr* tailStart = + IRSimplifier::simplify(new Add(b.end, new IntImm(1))); + res.emplace_back(tailStart, a.end); + } + + return res; +} + +std::vector subtractBound(Bound a, Bound b) { + OverlapKind overlap = boundOverlap(a, b); + if (overlap == NoOverlap) { + return {a}; + } + if (overlap == ContainedOrEqual) { + return {}; + } + + return subtractBound(a, b, overlap); +} + +std::vector subtractIndicesBounds( + const IndexBounds& A, + const IndexBounds& B, + OverlapKind overlap) { + if (overlap == NoOverlap) { + return {A}; + } + + if (overlap == ContainedOrEqual) { + return {}; + } + // All accesses to a buf must have the same dimensionality. + TORCH_INTERNAL_ASSERT(A.size() == B.size()); + + // Each dimension can be sliced into multiple bound segments. + std::vector boundSlices; + std::vector remainingOuterBounds; + + for (size_t i = 0; i < A.size(); ++i) { + auto slices = subtractBound(A[i], B[i]); + + Bound remaining = A[i]; + + for (auto slice : slices) { + IndexBounds newRegion; + newRegion.reserve(A.size()); + TORCH_INTERNAL_ASSERT(remainingOuterBounds.size() == i); + + for (size_t j = 0; j < i; ++j) { + newRegion.push_back(remainingOuterBounds[j]); + } + newRegion.push_back(slice); + for (size_t j = i + 1; j < A.size(); ++j) { + newRegion.push_back(A[j]); + } + + boundSlices.push_back(newRegion); + + if (slice.equals(A[i])) { + remaining = A[i]; + } else { + auto remainingSlices = subtractBound(remaining, slice); + TORCH_INTERNAL_ASSERT(remainingSlices.size() == 1); + remaining = remainingSlices[0]; + } + } + + remainingOuterBounds.push_back(remaining); + } + + return boundSlices; +} + +std::vector TORCH_API +subtractIndicesBounds(const IndexBounds& A, const IndexBounds& B) { + return subtractIndicesBounds(A, B, overlaps(A, B)); +} + +} // namespace analysis +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/bounds_overlap.h b/torch/csrc/jit/tensorexpr/bounds_overlap.h new file mode 100644 index 0000000000000..1ae4278c4b2cb --- /dev/null +++ b/torch/csrc/jit/tensorexpr/bounds_overlap.h @@ -0,0 +1,98 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { +namespace analysis { + +// A simple class containing the start and end of a range in a single dimension. +struct TORCH_API Bound { + const Expr* start{nullptr}; + const Expr* end{nullptr}; + + // This stores whether or not the start and end of this Bound have previously + // been swapped. This occurs when the bound is in a loop with a negative + // stride. + bool swapped{false}; + + Bound() = default; + Bound(const Expr* s, const Expr* e) : start(s), end(e) {} + + void print() const { + std::cout << "(" << *start << ", " << *end << ")"; + } + + bool equals(const Bound& other) const { + return exprEquals(start, other.start) && exprEquals(end, other.end); + } + + bool operator==(const Bound& other) const { + return exprEquals(start, other.start) && exprEquals(end, other.end); + } + + void swap() { + std::swap(start, end); + swapped = !swapped; + } +}; + +struct BoundHash { + size_t operator()(const Bound& b) const { + return std::hash()(b.start) ^ std::hash()(b.end); + } +}; + +// The type of overlap found. Each condition is true only if none of the +// previous conditions hold. +// ContainedOrEqual: All elements in the Bound A are in the Bound B (this +// includes the case where the bounds are equal). +// Contains: All elements in the Bound B are in the Bound B. +// PartialOverlap: Any elements in the Bound B are in the Bound A. +// NoOverlap: No elements in the Bound A are in the bound B. +enum OverlapKind { ContainedOrEqual, Contains, PartialOverlap, NoOverlap }; + +// Returns the kind of overlap between Bound A and Bound A in a single +// dimension. +OverlapKind TORCH_API boundOverlap(Bound A, Bound B); + +// A multi dimensional bound representing the bound of a set of indices. +using IndexBounds = std::vector; + +// Returns true if two IndexBounds are equivalent. +bool TORCH_API indexBoundsEquals(const IndexBounds& A, const IndexBounds& B); + +// Flattens a multi dimensional bound to a single dimension. The IndexBounds "a" +// *must* encapsulate the entire range of the buffer. +Bound TORCH_API flattenBounds(const IndexBounds& a); + +// Determines the kind of overlap in X dimensions. +OverlapKind TORCH_API overlaps(const IndexBounds& a, const IndexBounds& b); + +// Returns the Bound slices created by subtracing bound B from bound A. +// Multiple Bounds can be returned in the case where B slices A into two +// distinct regions with no overlap. +// +// Note: this doesn't use IndexBounds because the Bounds returned do not +// represent multiple different dimensions. +std::vector TORCH_API subtractBound(Bound a, Bound b); +std::vector TORCH_API +subtractBound(Bound a, Bound b, OverlapKind overlap); + +// Returns the bound slices created by subtracting the IndexBounds B from A. +std::vector TORCH_API subtractIndicesBounds( + const IndexBounds& A, + const IndexBounds& B, + OverlapKind overlap); +std::vector TORCH_API +subtractIndicesBounds(const IndexBounds& A, const IndexBounds& B); + +} // namespace analysis +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/buffer.h b/torch/csrc/jit/tensorexpr/buffer.h deleted file mode 100644 index 26bba143deaf8..0000000000000 --- a/torch/csrc/jit/tensorexpr/buffer.h +++ /dev/null @@ -1,88 +0,0 @@ -#pragma once - -#include - -namespace torch { -namespace jit { -namespace tensorexpr { - -// TODO: Merge this class with 'BufHandle' -class Buffer { - public: - Buffer(const BufHandle& data) : data_(data.node()) { - if (data_->base_handle()->dtype() != kHandle) { - throw malformed_input("Buffer dtype must be Handle"); - } - - std::vector stride_handles(ndim()); - for (int i = (int)ndim() - 1; i >= 0; i--) { - if (i == ndim() - 1) { - stride_handles[i] = 1; - } else { - stride_handles[i] = stride_handles[i + 1] * ExprHandle(dim(i + 1)); - } - } - strides_ = ExprHandleVectorToExprVector(stride_handles); - } - Buffer( - const std::string& name, - const Dtype& dtype, - const std::vector& dims) - : Buffer(BufHandle(name, dims, dtype)) {} - - const Buf* data() const { - return data_; - } - Dtype dtype() const { - return data_->dtype(); - } - int ndim() const { - return data_->ndim(); - } - const Expr* dim(int index) const { - return data_->dim(index); - } - std::vector dims() const { - return data_->dims(); - } - - // TODO: consider defer the storage flatten to a later stage. - template - ExprHandle operator()(Args... args) const { - return LoadValue(std::forward(args)...); - } - - ExprHandle LoadValue( - const ExprHandle& x, - const ExprHandle& y, - const ExprHandle& z) const { - return Load::make(*this, {x, y, z}, ExprHandle(1)); - } - ExprHandle LoadValue(const ExprHandle& x, const ExprHandle& y) const { - return Load::make(*this, {x, y}, ExprHandle(1)); - } - ExprHandle LoadValue(const ExprHandle& x) const { - return Load::make(*this, {x}, ExprHandle(1)); - } - - template - ExprHandle call(const std::vector& args) const { - std::vector params(args.begin(), args.end()); - return LoadValue(params); - } - - private: - ExprHandle LoadValue(const std::vector& indices) const; - - const Buf* data_; - std::vector strides_; -}; - -inline ExprHandle Buffer::LoadValue( - const std::vector& indices) const { - return Load::make(*this, indices, ExprHandle(1)); -} - -} // namespace tensorexpr -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/codegen.cpp b/torch/csrc/jit/tensorexpr/codegen.cpp index 845214afed62b..7f1f09032555a 100644 --- a/torch/csrc/jit/tensorexpr/codegen.cpp +++ b/torch/csrc/jit/tensorexpr/codegen.cpp @@ -30,28 +30,27 @@ RegisterCodeGenList::StmtFactoryMethod RegisterCodeGenList:: void RegisterCodeGenList::AddStmtFactoryMethod( const std::string& name, const StmtFactoryMethod& stmt_factory_method) { - auto insert_ret = - stmt_factory_methods_.insert(std::make_pair(name, stmt_factory_method)); - if (!insert_ret.second) { - throw std::runtime_error("Duplicated CodeGen names: " + name); - } + stmt_factory_methods_[name] = stmt_factory_method; } std::unique_ptr CreateCodeGen( const std::string& name, Stmt* stmt, const std::vector& params, - at::Device device) { + at::Device device, + const std::string& kernel_func_name) { RegisterCodeGenList::StmtFactoryMethod method = RegisterCodeGenList::GetInstance().FindStmtFactoryMethod(name); - return method(stmt, params, device); + return method(stmt, params, device, kernel_func_name); } const Expr* GenericIntrinsicsExpander::mutate(const Intrinsics* v) { if (v->op_type() == kSigmoid) { auto x = v->param(0)->accept_mutator(this); - auto one = ExprHandle(getImmediateByType(v->dtype(), 1.0)); - auto zero = ExprHandle(getImmediateByType(v->dtype(), 0.0)); + auto one = expr_to_vec( + ExprHandle(getImmediateByType(v->dtype(), 1.0)), v->dtype().lanes()); + auto zero = expr_to_vec( + ExprHandle(getImmediateByType(v->dtype(), 0.0)), v->dtype().lanes()); ExprHandle y = one / (one + exp(zero - ExprHandle(x))); return y.node(); } diff --git a/torch/csrc/jit/tensorexpr/codegen.h b/torch/csrc/jit/tensorexpr/codegen.h index 271e831a7768f..4a63db9783620 100644 --- a/torch/csrc/jit/tensorexpr/codegen.h +++ b/torch/csrc/jit/tensorexpr/codegen.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include @@ -24,8 +23,12 @@ class TORCH_API CodeGen { CodeGen( Stmt* stmt, const std::vector& buffer_args, - at::Device device = at::kCPU) - : stmt_(stmt), buffer_args_(buffer_args), device_(device) {} + at::Device device = at::kCPU, + const std::string& kernel_func_name = "func") + : stmt_(stmt), + buffer_args_(buffer_args), + device_(device), + kernel_func_name_(kernel_func_name) {} virtual ~CodeGen() {} @@ -63,30 +66,38 @@ class TORCH_API CodeGen { virtual void call(const std::vector& args) = 0; + virtual at::Tensor empty_strided( + c10::IntArrayRef size, + c10::IntArrayRef stride, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt) { + return at::empty_strided( + size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt); + } + + const std::string& kernel_func_name() const { + return kernel_func_name_; + } + private: Stmt* stmt_; std::vector buffer_args_; at::Device device_ = at::kCPU; + std::string kernel_func_name_ = "func"; }; class CodeGen::BufferArg { public: - BufferArg(const Buffer& buffer) + BufferArg(const Placeholder& buffer) : var_(buffer.data()->base_handle()), dtype_(buffer.dtype()) {} BufferArg(Tensor* tensor) - : var_(tensor->function() - ->func_var(tensor->output_index()) - ->base_handle()), - dtype_(tensor->function()->body(tensor->output_index())->dtype()) {} - BufferArg(const Function& func) - : var_(func.func_var(0)->base_handle()), dtype_(func.body(0)->dtype()) { - // TODO: Support multiple-output functions - if (func.func_vars().size() != 1) { - throw unimplemented_lowering(); - } - } + : var_(tensor->buf()->base_handle()), dtype_(tensor->buf()->dtype()) {} BufferArg(const VarHandle& var) : var_(var.node()), dtype_(var.dtype()), isVar_(true) {} + BufferArg(const BufHandle& buf) + : var_(buf.node()->base_handle()), dtype_(buf.node()->dtype()) {} const Var* var() const { return var_; @@ -158,7 +169,8 @@ class RegisterCodeGenList { using StmtFactoryMethod = std::function( Stmt* stmt, const std::vector&, - at::Device device)>; + at::Device device, + const std::string& kernel_func_name)>; TORCH_API StmtFactoryMethod FindStmtFactoryMethod(const std::string& name); @@ -184,9 +196,10 @@ class RegisterCodeGen { name, [](Stmt* stmt, const std::vector& params, - at::Device device) { + at::Device device, + const std::string& kernel_func_name) { std::unique_ptr method( - new CodeGenType(stmt, params, device)); + new CodeGenType(stmt, params, device, kernel_func_name)); return method; }); } @@ -196,7 +209,8 @@ TORCH_API std::unique_ptr CreateCodeGen( const std::string& name, Stmt* stmt, const std::vector& params, - at::Device device = at::kCPU); + at::Device device = at::kCPU, + const std::string& kernel_func_name = "func"); class TORCH_API GenericIntrinsicsExpander : public IRMutator { protected: diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp index 06e6703d494ac..1364ea710282e 100644 --- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp @@ -1,8 +1,9 @@ #include -#include +#include #include #include +#include #include #include #include @@ -89,6 +90,9 @@ static void getMajorMinor( max_dev_version = CudaVersion(7, 2); } else if (nvrtc_version.first <= 10) { // 10 supports 3-7.5 max_dev_version = CudaVersion(7, 5); + } else if (nvrtc_version.first == 11 && nvrtc_version.second == 0) { + // 11.0 supports 3-8.0 + max_dev_version = CudaVersion(8, 0); } if (dev_version > max_dev_version) { dev_version = max_dev_version; @@ -283,8 +287,12 @@ void CudaPrinter::visit(const Intrinsics* v) { if (returnType == ScalarType::Half || returnType == ScalarType::Float) { func_name = func_name + "f"; } - if (v->op_type() == IntrinsicsOp::kFabs && is_integral(returnType)) { - func_name = "abs"; + if (v->op_type() == IntrinsicsOp::kAbs && !is_integral(returnType)) { + // since kAbs's func_name is `abs`, prefix `f` for floating point + func_name = "f" + func_name; + } + if (v->op_type() == IntrinsicsOp::kIsNan) { + func_name = "isnan"; } os() << func_name << "("; @@ -546,6 +554,7 @@ class PrioritizeLoad : public IRMutator { const Var* load_new_var = new Var("v", v->dtype()); const Expr* new_value = IRMutator::mutate(v); load_list.push_back(std::make_pair(load_new_var, new_value)); + return load_new_var; } @@ -657,12 +666,14 @@ class PrioritizeLoad : public IRMutator { }; std::string CudaCodeGen::GetUniqueFuncName(const std::string& func_prefix) { - // We are using a global counter here to make sure difference instances - // within CudaCodeGen have different names. - static int64_t counter = 0; - ++counter; - int64_t value = counter; - return func_prefix + "_" + std::to_string(value); + int64_t counter = 0; + std::string name = func_prefix; + while (taken_func_names.count(name)) { + name = func_prefix + "_" + std::to_string(counter++); + } + + taken_func_names.insert(name); + return name; } bool GPUMetaVarRewriter::isFullExtent() { @@ -865,18 +876,30 @@ static std::ostream& operator<<( return out; } -static const char* resource_string = R"( +#ifdef USE_ROCM +static const char* device_resource_string = R"( +#include +#define POS_INFINITY INFINITY +#define NEG_INFINITY -INFINITY + +)"; +#else +static const char* device_resource_string = R"( #define NAN __int_as_float(0x7fffffff) #define POS_INFINITY __int_as_float(0x7f800000) #define NEG_INFINITY __int_as_float(0xff800000) +)"; +#endif + +static const char* shared_resource_string = R"( template -T maximum(T a, T b) { +__device__ T maximum(T a, T b) { return isnan(a) ? a : (a > b ? a : b); } template -T minimum(T a, T b) { +__device__ T minimum(T a, T b) { return isnan(a) ? a : (a < b ? a : b); } @@ -898,7 +921,7 @@ void CudaCodeGen::Initialize() { metavar_rewriter_ = std::make_unique(cuda_analysis_.get()); - os() << resource_string; + os() << device_resource_string << shared_resource_string; if (has_random_) { os() << philox_random_string << std::endl; @@ -907,14 +930,26 @@ void CudaCodeGen::Initialize() { // Check whether the statement uses the Half type, if so add the // half_support_literal. Stmt* stmt_v = stmt(); - CudaHalfChecker halfChecker; - stmt_v = stmt_v->accept_mutator(&halfChecker); + HalfChecker halfChecker(buffer_args()); + stmt_v->accept(&halfChecker); if (halfChecker.hasHalf()) { os() << fuser::cuda::half_support_literal << std::endl; } - std::string func_name = GetUniqueFuncName("func"); - os() << "extern \"C\" __global__" << std::endl << "void " << func_name << "("; + std::string func_name = GetUniqueFuncName(kernel_func_name()); + os() << "extern \"C\" __global__" << std::endl; +#ifdef USE_ROCM + // CUDA has a default limit of threads per block (=flat work group size) + // of 1024, but ROCm uses 256 by default. At the time of writing + // (#45506), I am unaware of a stricter limit that TensorExpr imposes + // (maybe for perf),so I use 1024 as maximum flat work group size. + // We put a minimum value of 1, this is also used by hip (ROCm 3.8) in + // the __launch_bound__ implementation. The arguments for the attribute + // are (min, max), for details see the documentation at + // https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size + os() << "__attribute__((amdgpu_flat_work_group_size(1, 1024)))" << std::endl; +#endif + os() << "void " << func_name << "("; const std::vector buffer_args = this->buffer_args(); for (size_t i = 0; i < buffer_args.size(); i++) { if (i > 0) { @@ -962,6 +997,11 @@ void CudaCodeGen::Initialize() { PrioritizeLoad prioritize_load; stmt_v = stmt_v->accept_mutator(&prioritize_load); + + // The registerizer might insert half-type scalars, we don't want this. + HalfRewriter hsFix; + stmt_v = stmt_v->accept_mutator(&hsFix); + stmt_v = IRSimplifier::simplify(stmt_v); set_stmt(stmt_v); @@ -1110,6 +1150,18 @@ void CudaCodeGen::call(const std::vector& args) { } } +at::Tensor CudaCodeGen::empty_strided( + c10::IntArrayRef size, + c10::IntArrayRef stride, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt) { + c10::DeviceGuard device_guard(device_opt.value()); + return at::native::empty_strided_cuda( + size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt); +} + void CudaCodeGen::CompileToNVRTC( const std::string& code, const std::string& func_name) { diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.h b/torch/csrc/jit/tensorexpr/cuda_codegen.h index b23f33f7ea337..a4d9a199314a0 100644 --- a/torch/csrc/jit/tensorexpr/cuda_codegen.h +++ b/torch/csrc/jit/tensorexpr/cuda_codegen.h @@ -73,12 +73,14 @@ class GPUMetaVarRewriter : public IRMutator { public: explicit GPUMetaVarRewriter(const CudaAnalysis* cuda_analysis) : cuda_analysis_(cuda_analysis) { - gpu_block_vars_ = {new Var("blockIdx.x", kInt), - new Var("blockIdx.y", kInt), - new Var("blockIdx.z", kInt)}; - gpu_thread_vars_ = {new Var("threadIdx.x", kInt), - new Var("threadIdx.y", kInt), - new Var("threadIdx.z", kInt)}; + gpu_block_vars_ = { + new Var("blockIdx.x", kInt), + new Var("blockIdx.y", kInt), + new Var("blockIdx.z", kInt)}; + gpu_thread_vars_ = { + new Var("threadIdx.x", kInt), + new Var("threadIdx.y", kInt), + new Var("threadIdx.z", kInt)}; current_block_reach_ = {new IntImm(1), new IntImm(1), new IntImm(1)}; current_thread_reach_ = {new IntImm(1), new IntImm(1), new IntImm(1)}; @@ -197,8 +199,9 @@ class TORCH_CUDA_API CudaCodeGen : public CodeGen { CudaCodeGen( Stmt* stmt, const std::vector& buffer_args, - at::Device device = at::Device(at::kCUDA, at::cuda::current_device())) - : CodeGen(stmt, buffer_args, device) { + at::Device device = at::Device(at::kCUDA, at::cuda::current_device()), + const std::string& kernel_func_name = "func") + : CodeGen(stmt, buffer_args, device, kernel_func_name) { Initialize(); } @@ -211,6 +214,14 @@ class TORCH_CUDA_API CudaCodeGen : public CodeGen { call(std::vector({CallArg(ts)...})); } + at::Tensor empty_strided( + c10::IntArrayRef size, + c10::IntArrayRef stride, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt) override; + const std::vector& gpu_block_extents() const { return cuda_analysis_->gpu_block_extents(); } @@ -239,6 +250,7 @@ class TORCH_CUDA_API CudaCodeGen : public CodeGen { std::unique_ptr printer_; std::unique_ptr cuda_analysis_; std::unique_ptr metavar_rewriter_; + std::unordered_set taken_func_names; CUfunction function_; bool has_random_ = false; diff --git a/torch/csrc/jit/tensorexpr/cuda_half_support.h b/torch/csrc/jit/tensorexpr/cuda_half_support.h deleted file mode 100644 index 62e3ff21fb720..0000000000000 --- a/torch/csrc/jit/tensorexpr/cuda_half_support.h +++ /dev/null @@ -1,47 +0,0 @@ -#pragma once - -#include -#include - -namespace torch { -namespace jit { -namespace tensorexpr { - -// Walk the Statment looking for Half size loads/stores. -class CudaHalfChecker : public IRMutator { - public: - bool hasHalf() { - return hasHalf_; - } - - const Expr* mutate(const Load* v) override { - const Expr* child = IRMutator::mutate(v); - if (child->dtype().scalar_type() != ScalarType::Half) { - return child; - } - - hasHalf_ = true; - - // TODO discards lanes. - return new Cast(kFloat, child); - } - - Stmt* mutate(const Store* v) override { - const Expr* new_val = v->value()->accept_mutator(this); - - if (v->value()->dtype().scalar_type() == ScalarType::Half) { - // TODO discards lanes. - new_val = new Cast(kHalf, new_val); - hasHalf_ = true; - } - - return new Store(v->buf(), v->indices(), new_val, v->mask()); - } - - private: - bool hasHalf_{false}; -}; - -} // namespace tensorexpr -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/eval.cpp b/torch/csrc/jit/tensorexpr/eval.cpp index 43fe492fcabb5..186af3ca822f5 100644 --- a/torch/csrc/jit/tensorexpr/eval.cpp +++ b/torch/csrc/jit/tensorexpr/eval.cpp @@ -8,6 +8,962 @@ DEFINE_TRIGGER(simple_ir_eval_executed); RegisterCodeGen ir_eval_codegen_reg("simple_ir_eval"); +template +inline typename std::enable_if::value, T>::type mod_value( + T lhs, + T rhs) { + return lhs % rhs; +} + +template +inline typename std::enable_if::value, T>::type +mod_value(T lhs, T rhs) { + return std::fmod(lhs, rhs); +} + +inline bool mod_value(bool lhs, bool rhs) { + throw std::runtime_error("Attempted modulus of bool"); +} + +template +inline typename std::enable_if::value, T>::type div_value( + T lhs, + T rhs) { + TORCH_CHECK(rhs != 0, "Division by zero"); + return lhs / rhs; +} + +template +inline typename std::enable_if::value, T>:: + type __ubsan_ignore_float_divide_by_zero__ + div_value(T lhs, T rhs) { + return lhs / rhs; +} + +inline bool div_value(bool lhs, bool rhs) { + LOG(FATAL) << "Attempted division of bool"; + return false; +} + +inline c10::Half div_value(c10::Half lhs, c10::Half rhs) { + return lhs / rhs; +} + +class SimpleIREvaluatorImpl : public IRVisitor { + public: + SimpleIREvaluatorImpl() = default; + + ~SimpleIREvaluatorImpl() override = default; + + void bindBuf(const Var* var, void* ptr) { + buffer_mapping_[var] = ptr; + } + void bindVar(const Var* var, const Value& val) { + eval_context_[var] = val; + } + + Value evaluateExpr(const Expr* e) { + e->accept(this); + return value_; + } + + Value value() const { + return value_; + } + + void clear() { + eval_context_.clear(); + buffer_mapping_.clear(); + internal_buffers_.clear(); + } + + TORCH_API void visit(const Add* v) override { + visit_binary_op(v); + } + TORCH_API void visit(const Sub* v) override { + visit_binary_op(v); + } + TORCH_API void visit(const Mul* v) override { + visit_binary_op(v); + } + TORCH_API void visit(const Div* v) override { + visit_binary_op(v); + } + TORCH_API void visit(const Mod* v) override { + visit_binary_op(v); + } + TORCH_API void visit(const Max* v) override { + visit_binary_op(v, v->propagate_nans()); + } + TORCH_API void visit(const Min* v) override { + visit_binary_op(v, v->propagate_nans()); + } + + TORCH_API void visit(const And* v) override { + visit_binary_op(v); + } + TORCH_API void visit(const Or* v) override { + visit_binary_op(v); + } + TORCH_API void visit(const Xor* v) override { + visit_binary_op(v); + } + TORCH_API void visit(const Lshift* v) override { + visit_binary_op(v); + } + TORCH_API void visit(const Rshift* v) override { + visit_binary_op(v); + } + + void visit(const CompareSelect* v) override { + visit_compare_select_op(v, v->compare_select_op()); + } + + template + typename std::enable_if_t::value, T> max_value( + T a, + T b) { + return std::isnan(a) ? a : (std::isnan(b) ? b : (a < b ? b : a)); + } + + template + typename std::enable_if_t::value, T> max_value( + T a, + T b) { + return a < b ? b : a; + } + + template + typename std::enable_if_t::value, T> min_value( + T a, + T b) { + return std::isnan(a) ? a : (std::isnan(b) ? b : (a < b ? a : b)); + } + + template + typename std::enable_if_t::value, T> min_value( + T a, + T b) { + return a < b ? a : b; + } + + template + Value binary_op(const Value& lhs, const Value& rhs, IRNodeType op_type) { + std::vector lhs_v = lhs.as_vec(); + std::vector rhs_v = rhs.as_vec(); + std::vector result_v(lhs_v.size()); + for (size_t i = 0; i < lhs_v.size(); i++) { + switch (op_type) { + case IRNodeType::kAdd: + result_v[i] = lhs_v[i] + rhs_v[i]; + break; + case IRNodeType::kSub: + result_v[i] = lhs_v[i] - rhs_v[i]; + break; + case IRNodeType::kMul: + result_v[i] = lhs_v[i] * rhs_v[i]; + break; + case IRNodeType::kDiv: + result_v[i] = div_value(lhs_v[i], rhs_v[i]); + break; + case IRNodeType::kMod: + result_v[i] = mod_value(lhs_v[i], rhs_v[i]); + break; + case IRNodeType::kMax: + result_v[i] = max_value(lhs_v[i], rhs_v[i]); + break; + case IRNodeType::kMin: + result_v[i] = min_value(lhs_v[i], rhs_v[i]); + break; + default: + // TODO: change to a proper error report + throw std::runtime_error("invalid operator type"); + } + } + return Value(result_v); + } + + template + Value bitwise_binary_op( + const Value& lhs, + const Value& rhs, + IRNodeType op_type) { + std::vector lhs_v = lhs.as_vec(); + std::vector rhs_v = rhs.as_vec(); + std::vector result_v(lhs_v.size()); + for (size_t i = 0; i < lhs_v.size(); i++) { + switch (op_type) { + case IRNodeType::kAnd: + result_v[i] = lhs_v[i] & rhs_v[i]; + break; + case IRNodeType::kOr: + result_v[i] = lhs_v[i] | rhs_v[i]; + break; + case IRNodeType::kXor: + result_v[i] = lhs_v[i] ^ rhs_v[i]; + break; + default: + // TODO: change to a proper error report + throw std::runtime_error("invalid operator type"); + } + } + return Value(result_v); + } + + template + Value shift_binary_op( + const Value& lhs, + const Value& rhs, + IRNodeType op_type) { + std::vector lhs_v = lhs.as_vec(); + std::vector rhs_v = rhs.as_vec(); + std::vector result_v(lhs_v.size()); + for (size_t i = 0; i < lhs_v.size(); i++) { + switch (op_type) { + case IRNodeType::kLshift: { + typename std::make_unsigned::type a = + static_cast::type>(lhs_v[i]); + result_v[i] = a << rhs_v[i]; + break; + } + case IRNodeType::kRshift: + result_v[i] = lhs_v[i] >> rhs_v[i]; + break; + default: + // TODO: change to a proper error report + throw std::runtime_error("invalid operator type"); + } + } + return Value(result_v); + } + + template + Value compare_select_op( + const Value& lhs, + const Value& rhs, + const Value& retval1, + const Value& retval2, + CompareSelectOperation cmp_op) { + std::vector lhs_v = lhs.as_vec(); + std::vector rhs_v = rhs.as_vec(); + std::vector ret_val1_v = retval1.as_vec(); + std::vector ret_val2_v = retval2.as_vec(); + std::vector result_v(lhs_v.size()); + for (size_t i = 0; i < lhs_v.size(); i++) { + switch (cmp_op) { + case CompareSelectOperation::kEQ: + result_v[i] = (lhs_v[i] == rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i]; + break; + case CompareSelectOperation::kNE: + result_v[i] = (lhs_v[i] != rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i]; + break; + case CompareSelectOperation::kGT: + result_v[i] = (lhs_v[i] > rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i]; + break; + case CompareSelectOperation::kGE: + result_v[i] = (lhs_v[i] >= rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i]; + break; + case CompareSelectOperation::kLT: + result_v[i] = (lhs_v[i] < rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i]; + break; + case CompareSelectOperation::kLE: + result_v[i] = (lhs_v[i] <= rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i]; + break; + default: + // TODO: change to a proper error report + throw std::runtime_error("invalid operator type"); + } + } + return Value(result_v); + } + + template + void visit_binary_op(const BinaryOpNode* v, bool option = false) { + v->lhs()->accept(this); + Value lhs_v = value_; + v->rhs()->accept(this); + Value rhs_v = value_; + if (lhs_v.dtype() != rhs_v.dtype()) { + throw malformed_input("bad dtype in binary op", v); + } + + IRNodeType expr_type = v->expr_type(); + if (expr_type == IRNodeType::kAnd || expr_type == IRNodeType::kOr || + expr_type == IRNodeType::kXor) { + switch (lhs_v.dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + value_ = bitwise_binary_op(lhs_v, rhs_v, expr_type); \ + break; + AT_FORALL_INT_TYPES(TYPE_CASE); +#undef TYPE_CASE + case ScalarType::Bool: + value_ = bitwise_binary_op(lhs_v, rhs_v, expr_type); + break; + default: + throw unsupported_dtype(); + } + return; + } + + if (expr_type == IRNodeType::kLshift || expr_type == IRNodeType::kRshift) { + switch (lhs_v.dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + value_ = shift_binary_op(lhs_v, rhs_v, expr_type); \ + break; + AT_FORALL_INT_TYPES(TYPE_CASE); +#undef TYPE_CASE + case ScalarType::Bool: + value_ = shift_binary_op(lhs_v, rhs_v, expr_type); + break; + default: + throw unsupported_dtype(); + } + return; + } + + switch (lhs_v.dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + value_ = binary_op(lhs_v, rhs_v, expr_type); \ + break; + AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); +#undef TYPE_CASE + case ScalarType::Bool: + value_ = binary_op(lhs_v, rhs_v, expr_type); + break; + default: + throw unsupported_dtype(); + } + } + + template + Value compare_select_op_helper( + const Value& lhs, + const Value& rhs, + const Value& retval1, + const Value& retval2, + CompareSelectOperation cmp_op) { + Value value; + switch (retval1.dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + value = compare_select_op(lhs, rhs, retval1, retval2, cmp_op); \ + break; + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); +#undef TYPE_CASE + default: + throw unsupported_dtype(); + } + + return value; + } + + void visit_compare_select_op( + const CompareSelect* v, + CompareSelectOperation cmp_op) { + v->lhs()->accept(this); + Value lhs_v = value_; + v->rhs()->accept(this); + Value rhs_v = value_; + v->ret_val1()->accept(this); + Value ret_val1_v = value_; + v->ret_val2()->accept(this); + Value ret_val2_v = value_; + + if (lhs_v.dtype() != rhs_v.dtype() || + ret_val1_v.dtype() != ret_val2_v.dtype()) { + throw malformed_input("bad dtype in CompareSelect", v); + } + + switch (lhs_v.dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + value_ = compare_select_op_helper( \ + lhs_v, rhs_v, ret_val1_v, ret_val2_v, cmp_op); \ + break; + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); +#undef TYPE_CASE + default: + throw unsupported_dtype(); + } + } + +#define IMM_VISIT(Type, Name) \ + TORCH_API void visit(const Name##Imm* v) override { \ + value_ = Value(v->value()); \ + } + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT); +#undef IMM_VISIT + + TORCH_API void visit(const Block* v) override { + const Block* last = scope_; + scope_ = v; + for (Stmt* s : v->stmts()) { + s->accept(this); + } + + auto it = var_by_scope_.find(v); + if (it != var_by_scope_.end()) { + for (const Expr* v : it->second) { + eval_context_.erase(v); + } + var_by_scope_.erase(it); + } + + scope_ = last; + } + + TORCH_API void visit(const Var* v) override { + auto iter = eval_context_.find(v); + if (iter == eval_context_.end()) { + throw malformed_input("could not find Var in context", v); + } + + value_ = iter->second; + } + + template + std::vector castValues(const Dtype& src_dtype, const Value& v) { + const std::vector& src_values = v.as_vec(); + std::vector dst_values(src_values.size()); + for (int i = 0; i < src_dtype.lanes(); ++i) { + dst_values[i] = static_cast(src_values[i]); + } + return dst_values; + } + + template + void doCastFromSrc( + const Dtype& src_dtype, + const Dtype& dst_dtype, + const Value& v) { + switch (dst_dtype.scalar_type()) { +#define DST_TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + this->value_ = Value(castValues(src_dtype, v)); \ + break; + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, DST_TYPE_CASE); +#undef DST_TYPE_CASE + default: + throw unsupported_dtype(); + } + } + + TORCH_API void visit(const Cast* v) override { + const Expr* src_value = v->src_value(); + src_value->accept(this); + Dtype dst_dtype = v->dtype(); + Dtype src_dtype = src_value->dtype(); + if (src_dtype.lanes() != dst_dtype.lanes()) { + throw malformed_input("lane mismatch in Cast", v); + } + + if (src_dtype != dst_dtype) { + switch (src_dtype.scalar_type()) { +#define SRC_TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + doCastFromSrc(src_dtype, dst_dtype, value_); \ + break; + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, SRC_TYPE_CASE); +#undef SRC_TYPE_CASE + default: + throw unsupported_dtype(); + } + } + } + + template + std::vector bitcastValues(const Dtype& src_dtype, const Value& v) { + const std::vector& src_values = v.as_vec(); + std::vector dst_values(src_values.size()); + for (int i = 0; i < src_dtype.lanes(); ++i) { + dst_values[i] = raw_bitcast(src_values[i]); + } + return dst_values; + } + + template + void doBitCastFromSrc( + const Dtype& src_dtype, + const Dtype& dst_dtype, + const Value& v) { + switch (dst_dtype.scalar_type()) { +#define DST_TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + this->value_ = Value(bitcastValues(src_dtype, v)); \ + break; + // bool/half not supported + AT_FORALL_SCALAR_TYPES(DST_TYPE_CASE); +#undef DST_TYPE_CASE + default: + throw unsupported_dtype(); + } + } + + TORCH_API void visit(const BitCast* v) override { + const Expr* src_value = v->src_value(); + src_value->accept(this); + Dtype dst_dtype = v->dtype(); + Dtype src_dtype = src_value->dtype(); + if (src_dtype.byte_size() != dst_dtype.byte_size()) { + throw malformed_input("lane mismatch in Cast", v); + } + if (src_dtype != dst_dtype) { + switch (src_dtype.scalar_type()) { +#define SRC_TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + doBitCastFromSrc(src_dtype, dst_dtype, value_); \ + break; + // bool/half not supported + AT_FORALL_SCALAR_TYPES(SRC_TYPE_CASE); +#undef SRC_TYPE_CASE + default: + throw unsupported_dtype(); + } + } + } + + TORCH_API void visit(const For* v) override { + const Expr* var_node = v->var(); + v->start()->accept(this); + int start = value_.as(); + v->stop()->accept(this); + int stop = value_.as(); + if (eval_context_.count(var_node)) { + throw malformed_input("could not find var_node in For context", v); + } + + for (int i = start; i < stop; i++) { + eval_context_[var_node] = Value(i); + if (v->body()) { + v->body()->accept(this); + } + } + eval_context_.erase(var_node); + } + + TORCH_API void visit(const Ramp* v) override { + v->base()->accept(this); + int base = value().as(); + v->stride()->accept(this); + int stride = value().as(); + int lanes = v->lanes(); + + std::vector values(lanes); + for (int i = 0; i < lanes; i++) { + values[i] = base + i * stride; + } + + value_ = Value(values); + } + + TORCH_API void visit(const Broadcast* v) override { + v->value()->accept(this); + Value value = this->value(); + int lanes = v->lanes(); + switch (value.dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: { \ + std::vector v(lanes, value.as()); \ + value_ = Value(v); \ + } break; + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); +#undef TYPE_CASE + default: + throw unsupported_dtype(); + } + } + + TORCH_API void visit(const IfThenElse* v) override { + v->condition()->accept(this); + bool cond_v; + switch (value_.dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: { \ + cond_v = value_.as(); \ + } break; + AT_FORALL_SCALAR_TYPES_AND(Bool, TYPE_CASE); +#undef TYPE_CASE + case ScalarType::Half: + throw unsupported_dtype("IfThenElse condition can't have Half dtype"); + default: + throw unsupported_dtype(); + } + + if (cond_v) { + v->true_value()->accept(this); + } else { + v->false_value()->accept(this); + } + } + + TORCH_API void visit(const Load* v) override { + const Var* base_node = v->base_handle(); + auto iter = buffer_mapping_.find(base_node); + if (iter == buffer_mapping_.end()) { + throw malformed_input("could not find base node in Load", v); + } + void* ptr = iter->second; + + const Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices()); + flat_idx->accept(this); + std::vector index = value().as_vec(); + v->mask()->accept(this); + std::vector mask = value().as_vec(); + ScalarType v_sdtype = v->dtype().scalar_type(); + switch (v_sdtype) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: { \ + Type* ptr##Name = static_cast(ptr); \ + std::vector v(index.size()); \ + for (size_t i = 0; i < index.size(); i++) { \ + if (mask[i]) { \ + v[i] = ptr##Name[index[i]]; \ + } \ + } \ + value_ = Value(v); \ + } break; + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); +#undef TYPE_CASE + default: + throw unsupported_dtype(); + } + } + + TORCH_API void visit(const Store* v) override { + const Var* base_node = v->base_handle(); + auto iter = buffer_mapping_.find(base_node); + if (iter == buffer_mapping_.end()) { + throw malformed_input("could not find base node in Store", v); + } + + void* ptr = iter->second; + + const Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices()); + flat_idx->accept(this); + std::vector index = value().as_vec(); + v->mask()->accept(this); + std::vector mask = value().as_vec(); + if (index.size() != mask.size()) { + throw malformed_input("mask size mismatch in Store", v); + } + + ScalarType v_sdtype = v->value()->dtype().scalar_type(); + + switch (v_sdtype) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: { \ + v->value()->accept(this); \ + std::vector value = this->value().as_vec(); \ + if (index.size() != value.size()) { \ + throw malformed_input("value size mismatch in Store", v); \ + } \ + Type* ptr##Name = static_cast(ptr); \ + for (size_t i = 0; i < index.size(); i++) { \ + if (mask[i]) { \ + ptr##Name[index[i]] = value[i]; \ + } \ + } \ + } break; + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); +#undef TYPE_CASE + default: + throw unsupported_dtype(); + } + } + + TORCH_API void visit(const BaseCallNode* v) override { + throw unimplemented_lowering(v); + } + + template + void visit_intrinsics_helper(const Intrinsics* v) { + std::vector values(v->nparams()); + for (int i = 0; i < v->nparams(); i++) { + v->param(i)->accept(this); + values[i] = this->value(); + } + std::vector v1; + if (values.size() >= 1ULL) { + v1 = values[0].as_vec(); + } + std::vector v2; + if (values.size() >= 2ULL) { + v2 = values[1].as_vec(); + if (v1.size() != v2.size()) { + throw malformed_input("value size mismatch in Intrinsics", v); + } + } + + if (values.size() > 2) { + throw unimplemented_lowering(v); + } + + std::vector result(v1.size(), -1); + if (values.size() == 1ULL) { + for (size_t i = 0; i < v1.size(); i++) { + result[i] = compute_intrinsics(v->op_type(), v1[i]); + } + } else { + for (size_t i = 0; i < v1.size(); i++) { + result[i] = compute_intrinsics(v->op_type(), v1[i], v2[i]); + } + } + value_ = Value(result); + } + + TORCH_API void visit(const Intrinsics* v) override { + auto ty = v->dtype().scalar_type(); + if (v->op_type() == kIsNan) { + auto inp_dtype = v->params().at(0)->dtype().scalar_type(); + if (inp_dtype == ScalarType::Float) { + visit_intrinsics_helper(v); + } else if (inp_dtype == ScalarType::Double) { + visit_intrinsics_helper(v); + } else if (inp_dtype == ScalarType::Half) { + throw unsupported_dtype(); // TODO + } + } else { + switch (ty) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + visit_intrinsics_helper(v); \ + break; + AT_FORALL_SCALAR_TYPES(TYPE_CASE); +#undef TYPE_CASE + default: + throw unsupported_dtype(); + } + } + } + + void visit(const Allocate* v) override { + const Var* buffer_var = v->buffer_var(); + std::vector dims = v->dims(); + int total_byte_size = v->dtype().byte_size(); + for (size_t i = 0; i < dims.size(); i++) { + dims[i]->accept(this); + total_byte_size *= value_.as(); + } + int int_count = (total_byte_size + sizeof(int) - 1) / sizeof(int); + std::unique_ptr> buffer(new std::vector(int_count)); + auto iter = buffer_mapping_.find(buffer_var); + if (iter != buffer_mapping_.end() && iter->second != nullptr) { + throw std::runtime_error( + "Allocate a buffer that has already been allocated: " + + buffer_var->name_hint()); + } + buffer_mapping_[buffer_var] = buffer->data(); + internal_buffers_.insert(std::make_pair(buffer_var, std::move(buffer))); + } + + void visit(const Free* v) override { + const Var* buffer_var = v->buffer_var(); + int count = internal_buffers_.erase(buffer_var); + if (count == 0) { + throw std::runtime_error( + "Free a buffer that is not currently bound: " + + buffer_var->name_hint()); + } + buffer_mapping_.erase(buffer_var); + } + + void visit(const Let* v) override { + var_by_scope_[scope_].push_back(v->var()); + bindVar(v->var(), evaluateExpr(v->value())); + } + + void visit(const Cond* v) override { + v->condition()->accept(this); + if (value().as()) { + if (v->true_stmt()) { + v->true_stmt()->accept(this); + } + } else { + if (v->false_stmt()) { + v->false_stmt()->accept(this); + } + } + } + + private: + template < + typename TReturn, + typename TInput, + typename std::enable_if::value, int>:: + type = 0> + static TReturn compute_intrinsics(IntrinsicsOp op_type, TInput v) { + switch (op_type) { + case kSin: + return std::sin(v); + case kCos: + return std::cos(v); + case kTan: + return std::tan(v); + case kAsin: + return std::asin(v); + case kAcos: + return std::acos(v); + case kAtan: + return std::atan(v); + case kSinh: + return std::sinh(v); + case kCosh: + return std::cosh(v); + case kTanh: + return std::tanh(v); + case kExp: + return std::exp(v); + case kAbs: + return std::abs(v); + case kExpm1: + return std::expm1(v); + case kLog: + return std::log(v); + case kLog2: + return std::log2(v); + case kLog10: + return std::log10(v); + case kLog1p: + return std::log1p(v); + case kErf: + return std::erf(v); + case kErfc: + return std::erfc(v); + case kSqrt: + return std::sqrt(v); + case kRsqrt: { + auto rsqrt = [](TInput v) __ubsan_ignore_float_divide_by_zero__ { + return 1.0f / std::sqrt(v); + }; + return rsqrt(v); + } + case kCeil: + return std::ceil(v); + case kFloor: + return std::floor(v); + case kRound: + return std::round(v); + case kTrunc: + return std::trunc(v); + case kLgamma: + return std::lgamma(v); + case kFrac: + TInput intpart; + return std::modf(v, &intpart); + case kIsNan: + return std::isnan(v); + default: + throw std::runtime_error("Invalid op_type: " + c10::to_string(op_type)); + } + } + + template < + typename TReturn, + typename TInput, + typename std::enable_if::value, int>::type = 0> + static TReturn compute_intrinsics(IntrinsicsOp op_type, TInput v) { + switch (op_type) { + case kAbs: { + // internal tool complains about calling `abs` on unsigned, the + // following makes the tool happy + using X = + std::conditional_t::value, int, TInput>; + return std::is_unsigned::value ? v + : std::abs(static_cast(v)); + } + default: + throw std::runtime_error( + "Invalid integral op_type: " + c10::to_string(op_type)); + } + } + + // specialization for float -> int ops (just kIsNan currently) + int compute_intrinsics(IntrinsicsOp op_type, float v) { + switch (op_type) { + case kIsNan: + return std::isnan(v); + default: + throw std::runtime_error("Invalid op_type: " + c10::to_string(op_type)); + } + } + + template + TReturn compute_intrinsics(IntrinsicsOp op_type, TInput v1, TInput v2) { + switch (op_type) { + case kPow: + return std::pow(v1, v2); + case kFmod: + return std::fmod(v1, v2); + case kRemainder: + return std::remainder(v1, v2); + case kAtan2: + return std::atan2(v1, v2); + default: + throw std::runtime_error("Invalid op_type: " + c10::to_string(op_type)); + } + } + + Value value_; + const Block* scope_; + std::unordered_map eval_context_; + std::unordered_map> var_by_scope_; + std::unordered_map + buffer_mapping_; // TODO: change Var* to Buf* + std::unordered_map>> + internal_buffers_; +}; + +SimpleIREvaluator::SimpleIREvaluator( + Stmt* stmt, + const std::vector& buffer_args, + at::Device device, + const std::string& kernel_func_name) + : CodeGen(stmt, buffer_args, device, kernel_func_name) { + impl_ = std::make_unique(); + expand_intrinsics(); +} + +SimpleIREvaluator::~SimpleIREvaluator() {} + +void SimpleIREvaluator::call(const std::vector& args) { + if (args.size() != buffer_args().size()) { + throw malformed_input("bad args in IREvaluator call"); + } + for (size_t i = 0; i < args.size(); i++) { + bindArg(buffer_args()[i], args[i]); + } + stmt()->accept(&*impl_); + impl_->clear(); + USE_TRIGGER(simple_ir_eval_executed); +} + +void SimpleIREvaluator::bindArg(const BufferArg& buf, const CallArg& data) { + if (!buf.isVar()) { + impl_->bindBuf(buf.var(), data.data()); + return; + } + + switch (buf.dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + impl_->bindVar(buf.var(), data.Name##Data()); \ + break; + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); +#undef TYPE_CASE + default: + throw unsupported_dtype(); + } +} +void SimpleIREvaluator::bindVar(const Var* v, const Expr* e) { + impl_->bindVar(v, impl_->evaluateExpr(e)); +} + +Value SimpleIREvaluator::value() const { + return impl_->value(); +} } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/eval.h b/torch/csrc/jit/tensorexpr/eval.h index 7d15a001301fb..28f47c2c050e9 100644 --- a/torch/csrc/jit/tensorexpr/eval.h +++ b/torch/csrc/jit/tensorexpr/eval.h @@ -1,6 +1,8 @@ #pragma once #include +#include +#include #include #include @@ -8,11 +10,9 @@ #include #include #include -#include #include #include #include -#include #include #include #include @@ -85,97 +85,26 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_DISPATCH); AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_VEC_DISPATCH); #undef VALUE_AS_VEC_DISPATCH -template -inline typename std::enable_if::value, T>::type mod_value( - T lhs, - T rhs) { - return lhs % rhs; +template +To raw_bitcast(const From& src) { + TORCH_CHECK(sizeof(To) == sizeof(From), "Invalid bitcast invocation"); + To storage; + std::memcpy(&storage, &src, sizeof(From)); + return reinterpret_cast(storage); } -template -inline typename std::enable_if::value, T>::type -mod_value(T lhs, T rhs) { - return std::fmod(lhs, rhs); -} - -inline bool mod_value(bool lhs, bool rhs) { - throw std::runtime_error("Attempted modulus of bool"); -} - -template -inline typename std::enable_if::value, T>::type div_value( - T lhs, - T rhs) { - TORCH_CHECK(rhs != 0, "Division by zero"); - return lhs / rhs; -} - -template -inline typename std::enable_if::value, T>:: - type __ubsan_ignore_float_divide_by_zero__ - div_value(T lhs, T rhs) { - return lhs / rhs; -} - -inline bool div_value(bool lhs, bool rhs) { - LOG(FATAL) << "Attempted division of bool"; - return false; -} - -class SimpleIREvaluator : public CodeGen, public IRVisitor { +class SimpleIREvaluatorImpl; +class TORCH_API SimpleIREvaluator : public CodeGen { public: - template - SimpleIREvaluator(Stmt* stmt, Ts... ts) : CodeGen(stmt, ts...) { - expand_intrinsics(); - } - SimpleIREvaluator( Stmt* stmt, const std::vector& buffer_args, - at::Device device = at::kCPU) - : CodeGen(stmt, buffer_args, device) { - expand_intrinsics(); - } + at::Device device = at::kCPU, + const std::string& kernel_func_name = "func"); - ~SimpleIREvaluator() override {} + ~SimpleIREvaluator() override; - TORCH_API void call(const std::vector& args) override { - if (args.size() != buffer_args().size()) { - throw malformed_input("bad args in IREvaluator call"); - } - for (size_t i = 0; i < args.size(); i++) { - bind(buffer_args()[i], args[i]); - } - stmt()->accept(this); - eval_context_.clear(); - buffer_mapping_.clear(); - internal_buffers_.clear(); - USE_TRIGGER(simple_ir_eval_executed); - } - - void bind(const BufferArg& buf, const CallArg& data) { - if (!buf.isVar()) { - buffer_mapping_[buf.var()] = data.data(); - return; - } - - switch (buf.dtype().scalar_type()) { -#define TYPE_CASE(Type, Name) \ - case ScalarType::Name: \ - eval_context_[buf.var()] = data.Name##Data(); \ - break; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); -#undef TYPE_CASE - default: - throw unsupported_dtype(); - } - } - - void bindVar(const Var* v, const Expr* e) { - e->accept(this); - Value value = value_; - eval_context_[v] = value_; - } + void call(const std::vector& args) override; template void operator()(const Ts&... ts) { @@ -183,695 +112,17 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor { call(args); } - TORCH_API void visit(const Add* v) override { - visit_binary_op(v); - } - TORCH_API void visit(const Sub* v) override { - visit_binary_op(v); - } - TORCH_API void visit(const Mul* v) override { - visit_binary_op(v); - } - TORCH_API void visit(const Div* v) override { - visit_binary_op(v); - } - TORCH_API void visit(const Mod* v) override { - visit_binary_op(v); - } - TORCH_API void visit(const Max* v) override { - visit_binary_op(v, v->propagate_nans()); - } - TORCH_API void visit(const Min* v) override { - visit_binary_op(v, v->propagate_nans()); - } - - TORCH_API void visit(const And* v) override { - visit_binary_op(v); - } - TORCH_API void visit(const Or* v) override { - visit_binary_op(v); - } - TORCH_API void visit(const Xor* v) override { - visit_binary_op(v); - } - TORCH_API void visit(const Lshift* v) override { - visit_binary_op(v); - } - TORCH_API void visit(const Rshift* v) override { - visit_binary_op(v); - } - - void visit(const CompareSelect* v) override { - visit_compare_select_op(v, v->compare_select_op()); - } - - template - typename std::enable_if_t::value, T> max_value( - T a, - T b) { - return std::isnan(a) ? a : (std::isnan(b) ? b : (a < b ? b : a)); - } - - template - typename std::enable_if_t::value, T> max_value( - T a, - T b) { - return a < b ? b : a; - } - - template - typename std::enable_if_t::value, T> min_value( - T a, - T b) { - return std::isnan(a) ? a : (std::isnan(b) ? b : (a < b ? a : b)); - } - - template - typename std::enable_if_t::value, T> min_value( - T a, - T b) { - return a < b ? a : b; - } - - template - Value binary_op(const Value& lhs, const Value& rhs, IRNodeType op_type) { - std::vector lhs_v = lhs.as_vec(); - std::vector rhs_v = rhs.as_vec(); - std::vector result_v(lhs_v.size()); - for (size_t i = 0; i < lhs_v.size(); i++) { - switch (op_type) { - case IRNodeType::kAdd: - result_v[i] = lhs_v[i] + rhs_v[i]; - break; - case IRNodeType::kSub: - result_v[i] = lhs_v[i] - rhs_v[i]; - break; - case IRNodeType::kMul: - result_v[i] = lhs_v[i] * rhs_v[i]; - break; - case IRNodeType::kDiv: - result_v[i] = div_value(lhs_v[i], rhs_v[i]); - break; - case IRNodeType::kMod: - result_v[i] = mod_value(lhs_v[i], rhs_v[i]); - break; - case IRNodeType::kMax: - result_v[i] = max_value(lhs_v[i], rhs_v[i]); - break; - case IRNodeType::kMin: - result_v[i] = min_value(lhs_v[i], rhs_v[i]); - break; - default: - // TODO: change to a proper error report - throw std::runtime_error("invalid operator type"); - } - } - return Value(result_v); - } - - Value bitwise_binary_op( - const Value& lhs, - const Value& rhs, - IRNodeType op_type) { - std::vector lhs_v = lhs.as_vec(); - std::vector rhs_v = rhs.as_vec(); - std::vector result_v(lhs_v.size()); - for (size_t i = 0; i < lhs_v.size(); i++) { - switch (op_type) { - case IRNodeType::kAnd: - result_v[i] = lhs_v[i] & rhs_v[i]; - break; - case IRNodeType::kOr: - result_v[i] = lhs_v[i] | rhs_v[i]; - break; - case IRNodeType::kXor: - result_v[i] = lhs_v[i] ^ rhs_v[i]; - break; - case IRNodeType::kLshift: - result_v[i] = lhs_v[i] << rhs_v[i]; - break; - case IRNodeType::kRshift: - result_v[i] = lhs_v[i] >> rhs_v[i]; - break; - default: - // TODO: change to a proper error report - throw std::runtime_error("invalid operator type"); - } - } - return Value(result_v); - } - - template - Value compare_select_op( - const Value& lhs, - const Value& rhs, - const Value& retval1, - const Value& retval2, - CompareSelectOperation cmp_op) { - std::vector lhs_v = lhs.as_vec(); - std::vector rhs_v = rhs.as_vec(); - std::vector ret_val1_v = retval1.as_vec(); - std::vector ret_val2_v = retval2.as_vec(); - std::vector result_v(lhs_v.size()); - for (size_t i = 0; i < lhs_v.size(); i++) { - switch (cmp_op) { - case CompareSelectOperation::kEQ: - result_v[i] = (lhs_v[i] == rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i]; - break; - case CompareSelectOperation::kNE: - result_v[i] = (lhs_v[i] != rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i]; - break; - case CompareSelectOperation::kGT: - result_v[i] = (lhs_v[i] > rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i]; - break; - case CompareSelectOperation::kGE: - result_v[i] = (lhs_v[i] >= rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i]; - break; - case CompareSelectOperation::kLT: - result_v[i] = (lhs_v[i] < rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i]; - break; - case CompareSelectOperation::kLE: - result_v[i] = (lhs_v[i] <= rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i]; - break; - default: - // TODO: change to a proper error report - throw std::runtime_error("invalid operator type"); - } - } - return Value(result_v); - } - - template - void visit_binary_op(const BinaryOpNode* v, bool option = false) { - v->lhs()->accept(this); - Value lhs_v = value_; - v->rhs()->accept(this); - Value rhs_v = value_; - if (lhs_v.dtype() != rhs_v.dtype()) { - throw malformed_input("bad dtype in binary op", v); - } - IRNodeType expr_type = v->expr_type(); - if (expr_type == IRNodeType::kAnd || expr_type == IRNodeType::kOr || - expr_type == IRNodeType::kXor || expr_type == IRNodeType::kLshift || - expr_type == IRNodeType::kRshift) { - value_ = bitwise_binary_op(lhs_v, rhs_v, expr_type); - return; - } - - switch (lhs_v.dtype().scalar_type()) { -#define TYPE_CASE(Type, Name) \ - case ScalarType::Name: \ - value_ = binary_op(lhs_v, rhs_v, expr_type); \ - break; - AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); -#undef TYPE_CASE - case ScalarType::Bool: - value_ = binary_op(lhs_v, rhs_v, expr_type); - break; - default: - throw unsupported_dtype(); - } - } - - template - Value compare_select_op_helper( - const Value& lhs, - const Value& rhs, - const Value& retval1, - const Value& retval2, - CompareSelectOperation cmp_op) { - Value value; - switch (retval1.dtype().scalar_type()) { -#define TYPE_CASE(Type, Name) \ - case ScalarType::Name: \ - value = compare_select_op(lhs, rhs, retval1, retval2, cmp_op); \ - break; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); -#undef TYPE_CASE - default: - throw unsupported_dtype(); - } - - return value; - } - - void visit_compare_select_op( - const CompareSelect* v, - CompareSelectOperation cmp_op) { - v->lhs()->accept(this); - Value lhs_v = value_; - v->rhs()->accept(this); - Value rhs_v = value_; - v->ret_val1()->accept(this); - Value ret_val1_v = value_; - v->ret_val2()->accept(this); - Value ret_val2_v = value_; - - if (lhs_v.dtype() != rhs_v.dtype() || - ret_val1_v.dtype() != ret_val2_v.dtype()) { - throw malformed_input("bad dtype in CompareSelect", v); - } - - switch (lhs_v.dtype().scalar_type()) { -#define TYPE_CASE(Type, Name) \ - case ScalarType::Name: \ - value_ = compare_select_op_helper( \ - lhs_v, rhs_v, ret_val1_v, ret_val2_v, cmp_op); \ - break; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); -#undef TYPE_CASE - default: - throw unsupported_dtype(); - } - } - -#define IMM_VISIT(Type, Name) \ - TORCH_API void visit(const Name##Imm* v) override { \ - value_ = Value(v->value()); \ - } - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT); -#undef IMM_VISIT - - TORCH_API void visit(const Block* v) override { - const Block* last = scope_; - scope_ = v; - for (Stmt* s : v->stmts()) { - s->accept(this); - } - - auto it = var_by_scope_.find(v); - if (it != var_by_scope_.end()) { - for (const Expr* v : it->second) { - eval_context_.erase(v); - } - var_by_scope_.erase(it); - } - - scope_ = last; - } - - TORCH_API void visit(const Var* v) override { - auto iter = eval_context_.find(v); - if (iter == eval_context_.end()) { - throw malformed_input("could not find Var in context", v); - } - - value_ = iter->second; - } - - template - std::vector castValues(const Dtype& src_dtype, const Value& v) { - const std::vector& src_values = v.as_vec(); - std::vector dst_values(src_values.size()); - for (int i = 0; i < src_dtype.lanes(); ++i) { - dst_values[i] = static_cast(src_values[i]); - } - return dst_values; - } - - template - void doCastFromSrc( - const Dtype& src_dtype, - const Dtype& dst_dtype, - const Value& v) { - switch (dst_dtype.scalar_type()) { -#define DST_TYPE_CASE(Type, Name) \ - case ScalarType::Name: \ - this->value_ = Value(castValues(src_dtype, v)); \ - break; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, DST_TYPE_CASE); -#undef DST_TYPE_CASE - default: - throw unsupported_dtype(); - } - } - - TORCH_API void visit(const Cast* v) override { - const Expr* src_value = v->src_value(); - src_value->accept(this); - Dtype dst_dtype = v->dtype(); - Dtype src_dtype = src_value->dtype(); - if (src_dtype.lanes() != dst_dtype.lanes()) { - throw malformed_input("lane mismatch in Cast", v); - } - - if (src_dtype != dst_dtype) { - switch (src_dtype.scalar_type()) { -#define SRC_TYPE_CASE(Type, Name) \ - case ScalarType::Name: \ - doCastFromSrc(src_dtype, dst_dtype, value_); \ - break; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, SRC_TYPE_CASE); -#undef SRC_TYPE_CASE - default: - throw unsupported_dtype(); - } - } - } - - TORCH_API void visit(const For* v) override { - const Expr* var_node = v->var(); - v->start()->accept(this); - int start = value_.as(); - v->stop()->accept(this); - int stop = value_.as(); - if (eval_context_.count(var_node)) { - throw malformed_input("could not find var_node in For context", v); - } - - for (int i = start; i < stop; i++) { - eval_context_[var_node] = Value(i); - if (v->body()) { - v->body()->accept(this); - } - } - eval_context_.erase(var_node); - } - - TORCH_API void visit(const Ramp* v) override { - v->base()->accept(this); - int base = value().as(); - v->stride()->accept(this); - int stride = value().as(); - int lanes = v->lanes(); - - std::vector values(lanes); - for (int i = 0; i < lanes; i++) { - values[i] = base + i * stride; - } - - value_ = Value(values); - } - - TORCH_API void visit(const Broadcast* v) override { - v->value()->accept(this); - Value value = this->value(); - int lanes = v->lanes(); - switch (value.dtype().scalar_type()) { -#define TYPE_CASE(Type, Name) \ - case ScalarType::Name: { \ - std::vector v(lanes, value.as()); \ - value_ = Value(v); \ - } break; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); -#undef TYPE_CASE - default: - throw unsupported_dtype(); - } - } - - TORCH_API void visit(const IfThenElse* v) override { - v->condition()->accept(this); - bool cond_v; - switch (value_.dtype().scalar_type()) { -#define TYPE_CASE(Type, Name) \ - case ScalarType::Name: { \ - cond_v = value_.as(); \ - } break; - AT_FORALL_SCALAR_TYPES_AND(Bool, TYPE_CASE); -#undef TYPE_CASE - case ScalarType::Half: - throw unsupported_dtype("IfThenElse condition can't have Half dtype"); - default: - throw unsupported_dtype(); - } - - if (cond_v) { - v->true_value()->accept(this); - } else { - v->false_value()->accept(this); - } - } - - TORCH_API void visit(const Load* v) override { - const Var* base_node = v->base_handle(); - auto iter = buffer_mapping_.find(base_node); - if (iter == buffer_mapping_.end()) { - throw malformed_input("could not find base node in Load", v); - } - void* ptr = iter->second; - - const Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices()); - flat_idx->accept(this); - std::vector index = value().as_vec(); - v->mask()->accept(this); - std::vector mask = value().as_vec(); - ScalarType v_sdtype = v->dtype().scalar_type(); - switch (v_sdtype) { -#define TYPE_CASE(Type, Name) \ - case ScalarType::Name: { \ - Type* ptr##Name = static_cast(ptr); \ - std::vector v(index.size()); \ - for (size_t i = 0; i < index.size(); i++) { \ - if (mask[i]) { \ - v[i] = ptr##Name[index[i]]; \ - } \ - } \ - value_ = Value(v); \ - } break; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); -#undef TYPE_CASE - default: - throw unsupported_dtype(); - } - } - - TORCH_API void visit(const Store* v) override { - const Var* base_node = v->base_handle(); - auto iter = buffer_mapping_.find(base_node); - if (iter == buffer_mapping_.end()) { - throw malformed_input("could not find base node in Store", v); - } - - void* ptr = iter->second; - - const Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices()); - flat_idx->accept(this); - std::vector index = value().as_vec(); - v->mask()->accept(this); - std::vector mask = value().as_vec(); - if (index.size() != mask.size()) { - throw malformed_input("mask size mismatch in Store", v); - } - - ScalarType v_sdtype = v->value()->dtype().scalar_type(); - - switch (v_sdtype) { -#define TYPE_CASE(Type, Name) \ - case ScalarType::Name: { \ - v->value()->accept(this); \ - std::vector value = this->value().as_vec(); \ - if (index.size() != value.size()) { \ - throw malformed_input("value size mismatch in Store", v); \ - } \ - Type* ptr##Name = static_cast(ptr); \ - for (size_t i = 0; i < index.size(); i++) { \ - if (mask[i]) { \ - ptr##Name[index[i]] = value[i]; \ - } \ - } \ - } break; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); -#undef TYPE_CASE - default: - throw unsupported_dtype(); - } - } - - TORCH_API void visit(const BaseCallNode* v) override { - throw unimplemented_lowering(v); - } - - template - void visit_intrinsics_helper(const Intrinsics* v) { - std::vector values(v->nparams()); - for (int i = 0; i < v->nparams(); i++) { - v->param(i)->accept(this); - values[i] = this->value(); - } - std::vector v1; - if (values.size() >= 1ULL) { - v1 = values[0].as_vec(); - } - std::vector v2; - if (values.size() >= 2ULL) { - v2 = values[1].as_vec(); - if (v1.size() != v2.size()) { - throw malformed_input("value size mismatch in Intrinsics", v); - } - } - - if (values.size() > 2) { - throw unimplemented_lowering(v); - } - - std::vector result(v1.size(), -1); - if (values.size() == 1ULL) { - for (size_t i = 0; i < v1.size(); i++) { - result[i] = compute_intrinsics(v->op_type(), v1[i]); - } - } else { - for (size_t i = 0; i < v1.size(); i++) { - result[i] = compute_intrinsics(v->op_type(), v1[i], v2[i]); - } - } - value_ = Value(result); - } - - TORCH_API void visit(const Intrinsics* v) override { - auto ty = v->dtype().scalar_type(); - if (ty == ScalarType::Float) { - visit_intrinsics_helper(v); - } else if (ty == ScalarType::Double) { - visit_intrinsics_helper(v); - } else { - throw unsupported_dtype(); - } - } - - void visit(const Allocate* v) override { - const Var* buffer_var = v->buffer_var(); - std::vector dims = v->dims(); - int total_byte_size = v->dtype().byte_size(); - for (size_t i = 0; i < dims.size(); i++) { - dims[i]->accept(this); - total_byte_size *= value_.as(); - } - int int_count = (total_byte_size + sizeof(int) - 1) / sizeof(int); - std::unique_ptr> buffer(new std::vector(int_count)); - auto iter = buffer_mapping_.find(buffer_var); - if (iter != buffer_mapping_.end() && iter->second != nullptr) { - throw std::runtime_error( - "Allocate a buffer that has already been allocated: " + - buffer_var->name_hint()); - } - buffer_mapping_[buffer_var] = buffer->data(); - internal_buffers_.insert(std::make_pair(buffer_var, std::move(buffer))); - } - - void visit(const Free* v) override { - const Var* buffer_var = v->buffer_var(); - int count = internal_buffers_.erase(buffer_var); - if (count == 0) { - throw std::runtime_error( - "Free a buffer that is not currently bound: " + - buffer_var->name_hint()); - } - buffer_mapping_.erase(buffer_var); - } - - void visit(const Let* v) override { - var_by_scope_[scope_].push_back(v->var()); - bindVar(v->var(), v->value()); - } - - void visit(const Cond* v) override { - v->condition()->accept(this); - if (value().as()) { - if (v->true_stmt()) { - v->true_stmt()->accept(this); - } - } else { - if (v->false_stmt()) { - v->false_stmt()->accept(this); - } - } - } - - Value value() const { - return value_; - } + void bindVar(const Var* v, const Expr* e); + Value value() const; private: + void bindArg(const BufferArg& buf, const CallArg& data); void expand_intrinsics() { GenericIntrinsicsExpander intrinsics_expander; apply_mutator(&intrinsics_expander); } - template - static T compute_intrinsics(IntrinsicsOp op_type, T v) { - switch (op_type) { - case kSin: - return std::sin(v); - case kCos: - return std::cos(v); - case kTan: - return std::tan(v); - case kAsin: - return std::asin(v); - case kAcos: - return std::acos(v); - case kAtan: - return std::atan(v); - case kSinh: - return std::sinh(v); - case kCosh: - return std::cosh(v); - case kTanh: - return std::tanh(v); - case kExp: - return std::exp(v); - case kFabs: - return std::fabs(v); - case kExpm1: - return std::expm1(v); - case kLog: - return std::log(v); - case kLog2: - return std::log2(v); - case kLog10: - return std::log10(v); - case kLog1p: - return std::log1p(v); - case kErf: - return std::erf(v); - case kErfc: - return std::erfc(v); - case kSqrt: - return std::sqrt(v); - case kRsqrt: - return 1.0f / std::sqrt(v); - case kCeil: - return std::ceil(v); - case kFloor: - return std::floor(v); - case kRound: - return std::round(v); - case kTrunc: - return std::trunc(v); - case kLgamma: - return std::lgamma(v); - case kFrac: - T intpart; - return std::modf(v, &intpart); - default: - throw std::runtime_error("Invalid op_type: " + c10::to_string(op_type)); - } - } - - template - static T compute_intrinsics(IntrinsicsOp op_type, T v1, T v2) { - switch (op_type) { - case kPow: - return std::pow(v1, v2); - case kFmod: - return std::fmod(v1, v2); - case kRemainder: - return std::remainder(v1, v2); - case kAtan2: - return std::atan2(v1, v2); - default: - throw std::runtime_error("Invalid op_type: " + c10::to_string(op_type)); - } - } - - Value value_; - const Block* scope_; - std::unordered_map eval_context_; - std::unordered_map> var_by_scope_; - std::unordered_map buffer_mapping_; - std::unordered_map>> - internal_buffers_; + std::unique_ptr impl_; }; template @@ -887,7 +138,7 @@ class ExprEval { ExprEval(const ExprHandle& expr, const std::vector& buffer_args) : dtype_(expr.dtype()) { std::vector buffer_args_extended = buffer_args; - Buffer ret_buf("ret_val", dtype_, {1}); + Placeholder ret_buf("ret_val", dtype_, {1}); std::vector indices; const Expr* zero = new IntImm(0); for (size_t i = 0; i < ret_buf.data()->ndim(); i++) { diff --git a/torch/csrc/jit/tensorexpr/expr.cpp b/torch/csrc/jit/tensorexpr/expr.cpp index acbe0879e896d..19195676ff8c2 100644 --- a/torch/csrc/jit/tensorexpr/expr.cpp +++ b/torch/csrc/jit/tensorexpr/expr.cpp @@ -124,8 +124,48 @@ ExprHandle expm1(const ExprHandle& v) { return Intrinsics::make(kExpm1, v); } -ExprHandle fabs(const ExprHandle& v) { - return Intrinsics::make(kFabs, v); +ExprHandle abs(const ExprHandle& v) { + return Intrinsics::make(kAbs, v); +} + +ExprHandle fast_log(const ExprHandle& v) { + // this implementation is taken from sleef: + // https://github.com/shibatch/sleef/blob/master/src/libm/sleefsp.c#L1131 + // to generate coefficients, this tool is provided + // https://github.com/shibatch/sleef/blob/master/src/gencoef/gencoef.txt + auto ilogb2kf = [](ExprHandle x) { + auto y = (bitcast(x) >> IntImm::make(23)) & IntImm::make(0xff); + return y - IntImm::make(0x7f); + }; + + auto ldexp3kf = [](ExprHandle x, ExprHandle e) { + return bitcast(bitcast(x) + (e << IntImm::make(23))); + }; + auto e = ilogb2kf(v * FloatImm::make(1.0 / 0.75)); + auto m = ldexp3kf(v, IntImm::make(-1) * e); + auto one = FloatImm::make(1.0f); + auto x = (m - one) / (m + one); + auto x2 = x * x; + + auto mlaf = [](ExprHandle x, ExprHandle y, float z) { + return x * y + FloatImm::make(z); + }; + + auto t = FloatImm::make(0.2392828464508056640625); + t = mlaf(t, x2, 0.28518211841583251953125); + t = mlaf(t, x2, 0.400005877017974853515625); + t = mlaf(t, x2, 0.666666686534881591796875); + t = mlaf(t, x2, 2.0); + x = x * t + FloatImm::make(0.693147180559945286226764) * e; + x = IfThenElse::make( + v < FloatImm::make(0), + FloatImm::make(std::numeric_limits::quiet_NaN()), + x); + x = IfThenElse::make( + v == FloatImm::make(0), + FloatImm::make(-std::numeric_limits::infinity()), + x); + return x; } ExprHandle log(const ExprHandle& v) { @@ -200,6 +240,10 @@ ExprHandle remainder(const ExprHandle& v1, const ExprHandle& v2) { return Intrinsics::make(kRemainder, v1, v2); } +ExprHandle isnan(const ExprHandle& v1) { + return Intrinsics::make(kIsNan, v1); +} + ExprHandle ifThenElse( const ExprHandle& c, const ExprHandle& t, @@ -219,6 +263,14 @@ ExprHandle Buf::make(const std::vector& dims, Dtype dtype) { return Buf::make("", dims, dtype); } +ExprHandle expr_to_vec(ExprHandle v, int lanes) { + if (lanes == 1) { + return v; + } else { + return Broadcast::make(v, lanes); + } +} + } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h index c93fe67f19372..ca3b0b421928c 100644 --- a/torch/csrc/jit/tensorexpr/expr.h +++ b/torch/csrc/jit/tensorexpr/expr.h @@ -31,6 +31,7 @@ enum IRNodeType { kCompareSelect, kLet, kCast, + kBitCast, kBroadcast, kRamp, kPolynomial, @@ -194,6 +195,9 @@ class TORCH_API Buf : public ExprNode { return dims_.size(); } const Expr* dim(size_t index) const { + if (index >= ndim()) { + throw out_of_range_index(); + } return dims_[index]; } std::vector dims() const { @@ -208,7 +212,6 @@ class TORCH_API Buf : public ExprNode { std::vector dims_; }; -// TODO: Merge this class with 'Buffer' class TORCH_API BufHandle : public ExprHandle { public: BufHandle( @@ -285,8 +288,9 @@ TORCH_API ExprHandle tanh(const ExprHandle& v); TORCH_API ExprHandle sigmoid(const ExprHandle& v); TORCH_API ExprHandle exp(const ExprHandle& v); TORCH_API ExprHandle expm1(const ExprHandle& v); -TORCH_API ExprHandle fabs(const ExprHandle& v); +TORCH_API ExprHandle abs(const ExprHandle& v); TORCH_API ExprHandle log(const ExprHandle& v); +TORCH_API ExprHandle fast_log(const ExprHandle& v); TORCH_API ExprHandle log2(const ExprHandle& v); TORCH_API ExprHandle log10(const ExprHandle& v); TORCH_API ExprHandle log1p(const ExprHandle& v); @@ -304,10 +308,13 @@ TORCH_API ExprHandle atan2(const ExprHandle& v1, const ExprHandle& v2); TORCH_API ExprHandle pow(const ExprHandle& v1, const ExprHandle& v2); TORCH_API ExprHandle fmod(const ExprHandle& v1, const ExprHandle& v2); TORCH_API ExprHandle remainder(const ExprHandle& v1, const ExprHandle& v2); +TORCH_API ExprHandle isnan(const ExprHandle& v1); TORCH_API ExprHandle ifThenElse(const ExprHandle& c, const ExprHandle& t, const ExprHandle& f); +TORCH_API ExprHandle expr_to_vec(ExprHandle v, int lanes); + } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/function.cpp b/torch/csrc/jit/tensorexpr/function.cpp index c86ec70dab21f..e69de29bb2d1d 100644 --- a/torch/csrc/jit/tensorexpr/function.cpp +++ b/torch/csrc/jit/tensorexpr/function.cpp @@ -1,144 +0,0 @@ -#include - -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { -namespace tensorexpr { - -Tensor* Compute( - const std::string& func_name, - const std::vector& dim_args, - const std::function&)>& body_func) { - std::vector dims; - std::vector args; - unpack_dim_args(dim_args, &dims, &args); - const Expr* body = body_func(VarVectorToVarHandleVector(args)).node(); - Function* func = new Function(func_name, dims, args, body); - const Buf* buf = func->func_var(0); - return new Tensor(buf, func, 0); -} - -Tensor* Compute( - const std::string& func_name, - const std::vector& dim_args, - const std::function& body_func) { - if (dim_args.size() != 1) { - throw malformed_input("mismatch between body and arg size (1)"); - } - - std::vector dims; - std::vector args; - unpack_dim_args(dim_args, &dims, &args); - const Expr* body = body_func(VarHandle(args[0])).node(); - Function* func = new Function(func_name, dims, args, body); - const Buf* buf = func->func_var(0); - return new Tensor(buf, func, 0); -} - -Tensor* Compute( - const std::string& func_name, - const std::vector& dim_args, - const std::function& - body_func) { - if (dim_args.size() != 2) { - throw malformed_input("mismatch between body and arg size (2)"); - } - std::vector dims; - std::vector args; - unpack_dim_args(dim_args, &dims, &args); - const Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1])).node(); - Function* func = new Function(func_name, dims, args, body); - const Buf* buf = func->func_var(0); - return new Tensor(buf, func, 0); -} - -Tensor* Compute( - const std::string& func_name, - const std::vector& dim_args, - const std::function< - ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>& - body_func) { - if (dim_args.size() != 3) { - throw malformed_input("mismatch between body and arg size (3)"); - } - std::vector dims; - std::vector args; - unpack_dim_args(dim_args, &dims, &args); - const Expr* body = - body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2])) - .node(); - Function* func = new Function(func_name, dims, args, body); - const Buf* buf = func->func_var(0); - return new Tensor(buf, func, 0); -} - -Tensor* Compute( - const std::string& func_name, - const std::vector& dim_args, - const std::function& body_func) { - if (dim_args.size() != 4) { - throw malformed_input("mismatch between body and arg size (4)"); - } - std::vector dims; - std::vector args_nodes; - unpack_dim_args(dim_args, &dims, &args_nodes); - auto args = VarVectorToVarHandleVector(args_nodes); - const Expr* body = body_func(args[0], args[1], args[2], args[3]).node(); - Function* func = new Function(func_name, dims, args_nodes, body); - const Buf* buf = func->func_var(0); - return new Tensor(buf, func, 0); -} - -Stmt* Function::ElementStmt(size_t index) { - const Buf* buf = func_var(index); - std::vector indices; - for (size_t i = 0; i < buf->ndim(); i++) { - indices.push_back(this->args_[i]); - } - - const Expr* mask = new IntImm(1); - - Stmt* update_stmt = new Store(buf, indices, body(index), mask); - return update_stmt; -} - -Tensor* Reduce( - const std::string& func_name, - const std::vector& dim_args, - const Reducer& reducer, - const Buffer& buffer, - const std::vector& reduce_args) { - return Reduce( - func_name, - dim_args, - reducer, - [&](ParameterList& p) { return buffer.call(p); }, - reduce_args); -} - -Tensor* Reduce( - const std::string& func_name, - const std::vector& dim_args, - const Reducer& reducer, - Tensor* tensor, - const std::vector& reduce_args) { - return Reduce( - func_name, - dim_args, - reducer, - [&](ParameterList& p) { return tensor->call(p); }, - reduce_args); -} - -} // namespace tensorexpr -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/function.h b/torch/csrc/jit/tensorexpr/function.h deleted file mode 100644 index 128253d95ff02..0000000000000 --- a/torch/csrc/jit/tensorexpr/function.h +++ /dev/null @@ -1,105 +0,0 @@ -#pragma once - -#include -#include - -#include -#include - -namespace torch { -namespace jit { -namespace tensorexpr { - -class Function : public KernelScopedObject { - public: - Function( - const std::string& func_name, - const std::vector& dims, - const std::vector& args, - const Expr* body) - // TODO: Function should not create buffers, they should be created - // manually before constructing a function. - : func_vars_({new Buf(func_name, dims, body->dtype())}), - dims_(dims), - args_(args), - bodies_({body}) {} - Function( - const std::vector& func_names, - const std::vector& dims, - const std::vector& args, - const std::vector& bodies) - : func_vars_(func_names.size()), - dims_(dims), - args_(args), - bodies_(bodies) { - for (size_t i = 0; i < func_names.size(); i++) { - func_vars_[i] = new Buf(func_names[i], dims, bodies[i]->dtype()); - } - } - Function( - const std::string& func_name, - Buf* func_var, - const std::vector& dims, - const std::vector& args, - const Expr* body) - : func_vars_({func_var}), dims_(dims), args_(args), bodies_({body}) {} - - size_t ndim() const { - return dims_.size(); - } - - const Expr* dim(size_t index) const { - if (index < 0 || index >= dims_.size()) { - throw out_of_range_index(); - } - - return dims_[index]; - } - const std::vector& dims() const { - return dims_; - } - - const Var* arg(size_t index) const { - if (index < 0 || index >= args_.size()) { - throw out_of_range_index(); - } - - return args_[index]; - } - const std::vector& args() const { - return args_; - } - - std::vector bodies() const { - return bodies_; - } - const Expr* body(size_t index) const { - if (index >= bodies_.size()) { - throw out_of_range_index(); - } - - return bodies_[index]; - } - - std::vector func_vars() const { - return func_vars_; - } - const Buf* func_var(size_t index) const { - if (index >= func_vars_.size()) { - throw out_of_range_index(); - } - return func_vars_[index]; - } - - Stmt* ElementStmt(size_t index); - - private: - std::vector func_vars_; - std::vector dims_; - std::vector args_; - std::vector bodies_; -}; - -} // namespace tensorexpr -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/half_support.h b/torch/csrc/jit/tensorexpr/half_support.h new file mode 100644 index 0000000000000..571f8092cb4ba --- /dev/null +++ b/torch/csrc/jit/tensorexpr/half_support.h @@ -0,0 +1,134 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +// Walk the Statment looking for Half size loads/stores. +class HalfChecker : public IRVisitor { + public: + HalfChecker(const std::vector& args) { + for (const auto& BA : args) { + hasHalf_ |= BA.dtype().scalar_type() == ScalarType::Half; + } + } + + bool hasHalf() { + return hasHalf_; + } + + void visit(const Load* v) override { + hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half; + IRVisitor::visit(v); + } + + void visit(const Store* v) override { + hasHalf_ |= v->buf()->dtype().scalar_type() == ScalarType::Half; + IRVisitor::visit(v); + } + + void visit(const HalfImm* v) override { + hasHalf_ = true; + } + + void visit(const Cast* v) override { + hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half; + IRVisitor::visit(v); + } + + private: + bool hasHalf_{false}; +}; + +class HalfRewriter : public IRMutator { + const Expr* mutate(const Load* v) override { + const Expr* child = IRMutator::mutate(v); + if (child->dtype().scalar_type() != ScalarType::Half) { + return child; + } + + const Expr* ret = + new Cast(child->dtype().cloneWithScalarType(ScalarType::Float), child); + + inserted_half_casts_.insert(ret); + return ret; + } + + Stmt* mutate(const Store* v) override { + const Expr* new_val = v->value()->accept_mutator(this); + + Dtype newType = v->value()->dtype(); + if (newType.scalar_type() == ScalarType::Half) { + new_val = + new Cast(newType.cloneWithScalarType(ScalarType::Half), new_val); + inserted_half_casts_.insert(new_val); + } + + return new Store(v->buf(), v->indices(), new_val, v->mask()); + } + + const Expr* mutate(const HalfImm* v) override { + return new Cast(kFloat, v); + } + + const Expr* mutate(const Cast* v) override { + const Expr* child = v->src_value()->accept_mutator(this); + + // just don't allow half casts we didn't insert. + if (v->dtype().scalar_type() == ScalarType::Half) { + if (inserted_half_casts_.count(v) < 1) { + return child; + } + } + + // Remove Half(Float()) and friends. + const Cast* cast_child = dynamic_cast(child); + if (cast_child) { + if (v->dtype().is_floating_point() && + cast_child->dtype().is_floating_point()) { + return new Cast(v->dtype(), cast_child->src_value()); + } + } + + if (child == v->src_value()) { + return v; + } + + return new Cast(v->dtype(), child); + } + Stmt* mutate(const Let* v) override { + if (v->dtype().scalar_type() == ScalarType::Half) { + const Var* load_new_var = new Var(v->var()->name_hint(), kFloat); + const Expr* new_value = new Cast( + v->dtype().cloneWithScalarType(ScalarType::Float), + v->value()->accept_mutator(this)); + var_map[v->var()] = load_new_var; + + return new Let(load_new_var, new_value); + } + + return IRMutator::mutate(v); + } + + const Expr* mutate(const Var* v) override { + auto it = var_map.find(v); + if (it != var_map.end()) { + return it->second; + } + + return v; + } + + private: + std::unordered_set inserted_half_casts_; + std::unordered_map var_map; +}; + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/ir.cpp b/torch/csrc/jit/tensorexpr/ir.cpp index 69144bb9b7a0a..93fc15f7549c3 100644 --- a/torch/csrc/jit/tensorexpr/ir.cpp +++ b/torch/csrc/jit/tensorexpr/ir.cpp @@ -1,6 +1,6 @@ #include -#include +#include namespace torch { namespace jit { @@ -39,43 +39,40 @@ static bool indicesValid(const std::vector& indices) { return true; } -Load::Load( - const Buffer& buffer, - const std::vector& indices, - const Expr* mask) - : Load( - ChooseDtype(buffer.dtype(), dtypeOfIndices(indices)), - buffer.data(), - indices, - mask) {} - -Load::Load( - Dtype dtype, - const Buf* buf, - const std::vector& indices, - const Expr* mask) - : ExprNodeBase(dtype), buf_(buf), indices_(indices), mask_(mask) { - if (indices_.size() > 0 && buf->base_handle()->dtype() != kHandle) { +void Load::verify_dtypes() const { + if (indices_.size() > 0 && buf_->base_handle()->dtype() != kHandle) { throw malformed_input( - "Load base handle dtype must be Handle", buf->base_handle()); + "Load base handle dtype must be Handle", buf_->base_handle()); } - if (!indicesValid(indices)) { + if (!indicesValid(indices_)) { throw malformed_input("invalid indices in Load"); } - Dtype index_dtype = dtypeOfIndices(indices); - if (index_dtype.lanes() != mask->dtype().lanes()) { + Dtype index_dtype = dtypeOfIndices(indices_); + if (index_dtype.lanes() != mask_->dtype().lanes()) { throw malformed_input("lane mismatch in Load mask"); } } -ExprHandle Load::make( - const Buffer& buffer, - const std::vector& indices, - const ExprHandle& mask) { - return ExprHandle( - new Load(buffer, ExprHandleVectorToExprVector(indices), mask.node())); +Load::Load( + Dtype dtype, + const Buf* buf, + const std::vector& indices, + const Expr* mask) + : ExprNodeBase(dtype), buf_(buf), indices_(indices), mask_(mask) { + verify_dtypes(); } + +Load::Load( + const Buf* buf, + const std::vector& indices, + const Expr* mask) + : Load( + ChooseDtype(buf->dtype(), dtypeOfIndices(indices)), + buf, + indices, + mask) {} + ExprHandle Load::make( Dtype dtype, const BufHandle& buf, @@ -85,15 +82,11 @@ ExprHandle Load::make( dtype, buf.node(), ExprHandleVectorToExprVector(indices), mask.node())); } -Store::Store( - const Buffer& buffer, - const std::vector& indices, - const Expr* value, - const Expr* mask) - : Store(buffer.data(), indices, value, mask) { - if (buffer.dtype().scalar_type() != value->dtype().scalar_type()) { - throw malformed_input("invalid dtype in Store"); - } +ExprHandle Load::make( + const BufHandle& buf, + const std::vector& indices, + const ExprHandle& mask) { + return Load::make(buf.dtype(), buf, indices, mask); } Store::Store( @@ -128,15 +121,6 @@ Store::Store( */ } -Store* Store::make( - const Buffer& buffer, - const std::vector& indices, - const ExprHandle& value, - const ExprHandle& mask) { - return new Store( - buffer, ExprHandleVectorToExprVector(indices), value.node(), mask.node()); -} - Store* Store::make( const BufHandle& buf, const std::vector& indices, @@ -191,6 +175,9 @@ const Expr* flatten_index( } Dtype Intrinsics::IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1) { + if (op_type == kIsNan) { + return dt1.cloneWithScalarType(ScalarType::Int); + } // TODO: check the op_type and make a real decision return dt1; } @@ -203,11 +190,15 @@ Dtype Intrinsics::IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2) { Dtype Intrinsics::IntrinsicsDtype( IntrinsicsOp op_type, const std::vector& params) { - // TODO: check the op_type an dmake a real decision + // TODO: check the op_type and make a real decision + // Doesnt this fail with kRand? if (params.size() == 0) { throw malformed_input("invalid params in Intrinsics"); + } else if (params.size() == 1) { + return IntrinsicsDtype(op_type, params[0]->dtype()); + } else if (params.size() == 2) { + return IntrinsicsDtype(op_type, params[0]->dtype(), params[1]->dtype()); } - return params[0]->dtype(); } @@ -225,7 +216,7 @@ int Intrinsics::OpArgCount(IntrinsicsOp op_type) { case kSigmoid: case kExp: case kExpm1: - case kFabs: + case kAbs: case kLog: case kLog2: case kLog10: @@ -240,6 +231,7 @@ int Intrinsics::OpArgCount(IntrinsicsOp op_type) { case kTrunc: case kFrac: case kLgamma: + case kIsNan: return 1; case kRand: return 0; diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h index d75b611145f99..299655997c4d1 100644 --- a/torch/csrc/jit/tensorexpr/ir.h +++ b/torch/csrc/jit/tensorexpr/ir.h @@ -28,6 +28,7 @@ inline int getPrecedence(IRNodeType ty) { case kPrimitive: return 0; case kCast: + case kBitCast: return 2; case kAdd: case kSub: @@ -55,7 +56,7 @@ inline int getPrecedence(IRNodeType ty) { } } -class Buffer; +class Placeholder; class Cast : public ExprNode { public: @@ -81,6 +82,34 @@ ExprHandle cast(const ExprHandle& src_value) { return Cast::make(Dtype(ToDtype(), src_value.dtype().lanes()), src_value); } +// This is a bitwise cast, akin to bitcast in LLVM +class BitCast : public ExprNode { + public: + const Expr* src_value() const { + return src_value_; + } + static ExprHandle make(Dtype dtype, const ExprHandle& src_value) { + return ExprHandle(new BitCast(dtype, src_value.node())); + } + BitCast(Dtype dtype, const Expr* src_value) + : ExprNodeBase(dtype, kBitCast), src_value_(src_value) { + TORCH_CHECK(src_value_->dtype().byte_size() == dtype.byte_size()); + } + + bool isConstant() const override { + return src_value_->isConstant(); + } + + private: + const Expr* src_value_; +}; + +template +ExprHandle bitcast(const ExprHandle& src_value) { + return BitCast::make( + Dtype(ToDtype(), src_value.dtype().lanes()), src_value); +} + // Represent the expression node for binary operators. // A CRTP pattern to share common code among the operators. template @@ -150,69 +179,51 @@ class Mod : public BinaryOpNode { : BinaryOpNode(lhs, rhs, IRNodeType::kMod) {} }; -class And : public BinaryOpNode { +template +class BitwiseOpNode : public BinaryOpNode { public: - And(const Expr* lhs, const Expr* rhs) - : BinaryOpNode(lhs, rhs, IRNodeType::kAnd) { - if (!lhs->dtype().is_integral()) { + BitwiseOpNode(const Expr* lhs, const Expr* rhs, IRNodeType type) + : BinaryOpNode(lhs, rhs, type) {} + + static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) { + if (!lhs.dtype().is_integral()) { throw unsupported_dtype(); } - if (lhs->dtype() != rhs->dtype()) { - throw malformed_input("bad dtype in And"); + if (lhs.dtype() != rhs.dtype()) { + throw malformed_input("lhs/rhs dtype mismatch"); } + return BinaryOpNode::make(lhs, rhs); } }; -class Or : public BinaryOpNode { +class And : public BitwiseOpNode { + public: + And(const Expr* lhs, const Expr* rhs) + : BitwiseOpNode(lhs, rhs, IRNodeType::kAnd) {} +}; + +class Or : public BitwiseOpNode { public: Or(const Expr* lhs, const Expr* rhs) - : BinaryOpNode(lhs, rhs, IRNodeType::kOr) { - if (!lhs->dtype().is_integral()) { - throw unsupported_dtype(); - } - if (lhs->dtype() != rhs->dtype()) { - throw malformed_input("bad dtype in Or"); - } - } + : BitwiseOpNode(lhs, rhs, IRNodeType::kOr) {} }; -class Xor : public BinaryOpNode { +class Xor : public BitwiseOpNode { public: Xor(const Expr* lhs, const Expr* rhs) - : BinaryOpNode(lhs, rhs, IRNodeType::kXor) { - if (!lhs->dtype().is_integral()) { - throw unsupported_dtype(); - } - if (lhs->dtype() != rhs->dtype()) { - throw malformed_input("bad dtype in Xor"); - } - } + : BitwiseOpNode(lhs, rhs, IRNodeType::kXor) {} }; -class Lshift : public BinaryOpNode { +class Lshift : public BitwiseOpNode { public: Lshift(const Expr* lhs, const Expr* rhs) - : BinaryOpNode(lhs, rhs, IRNodeType::kLshift) { - if (lhs->dtype().scalar_type() != ScalarType::Int) { - throw unsupported_dtype(); - } - if (lhs->dtype() != rhs->dtype()) { - throw malformed_input("bad dtype in Lshift"); - } - } + : BitwiseOpNode(lhs, rhs, IRNodeType::kLshift) {} }; -class Rshift : public BinaryOpNode { +class Rshift : public BitwiseOpNode { public: Rshift(const Expr* lhs, const Expr* rhs) - : BinaryOpNode(lhs, rhs, IRNodeType::kRshift) { - if (lhs->dtype().scalar_type() != ScalarType::Int) { - throw unsupported_dtype(); - } - if (lhs->dtype() != rhs->dtype()) { - throw malformed_input("bad dtype in Rshift"); - } - } + : BitwiseOpNode(lhs, rhs, IRNodeType::kRshift) {} }; class Max : public BinaryOpNode { @@ -391,26 +402,28 @@ class TORCH_API Load : public ExprNode { return buf_; } static ExprHandle make( - const Buffer& buffer, + Dtype dtype, + const BufHandle& buf, const std::vector& indices, const ExprHandle& mask); static ExprHandle make( - Dtype dtype, const BufHandle& buf, const std::vector& indices, const ExprHandle& mask); Load( - const Buffer& buffer, + Dtype dtype, + const Buf* base_handle, const std::vector& indices, const Expr* mask); Load( - Dtype dtype, const Buf* base_handle, const std::vector& indices, const Expr* mask); private: + void verify_dtypes() const; + const Buf* buf_; std::vector indices_; const Expr* mask_; @@ -624,7 +637,7 @@ enum IntrinsicsOp { kSigmoid, kExp, kExpm1, - kFabs, + kAbs, kLog, kLog2, kLog10, @@ -642,6 +655,7 @@ enum IntrinsicsOp { kRemainder, kLgamma, kFrac, + kIsNan, kRand, // We need more discussions on this. Should we consider stateful? }; @@ -702,8 +716,8 @@ class Intrinsics : public CallNode { return "sigmoid"; case kExp: return "exp"; - case kFabs: - return "fabs"; + case kAbs: + return "abs"; case kLog: return "log"; case kLog2: @@ -742,6 +756,8 @@ class Intrinsics : public CallNode { return "erfc"; case kFrac: return "frac"; + case kIsNan: + return "isnan"; default: throw std::runtime_error( "invalid op_type: " + c10::to_string(op_type())); diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.cpp b/torch/csrc/jit/tensorexpr/ir_mutator.cpp index d3a0cc45d27fd..ddbe88bb2c8f4 100644 --- a/torch/csrc/jit/tensorexpr/ir_mutator.cpp +++ b/torch/csrc/jit/tensorexpr/ir_mutator.cpp @@ -139,6 +139,15 @@ const Expr* IRMutator::mutate(const Cast* v) { return new Cast(v->dtype(), src_value_new); } +const Expr* IRMutator::mutate(const BitCast* v) { + const Expr* src_value = v->src_value(); + const Expr* src_value_new = src_value->accept_mutator(this); + if (src_value_new == v->src_value()) { + return v; + } + return new BitCast(v->dtype(), src_value_new); +} + const Expr* IRMutator::mutate(const Var* v) { return v; } @@ -270,7 +279,7 @@ const Expr* IRMutator::mutate(const MinTerm* v) { const Expr* IRMutator::mutate(const ReduceOp* v) { const Expr* buf_new_expr = v->accumulator()->accept_mutator(this); const Buf* buf_new = dynamic_cast(buf_new_expr); - auto body = v->body().node()->accept_mutator(this); + const Expr* body_new = v->body()->accept_mutator(this); std::vector new_output_args; std::vector new_reduce_args; @@ -282,11 +291,7 @@ const Expr* IRMutator::mutate(const ReduceOp* v) { } return new ReduceOp( - buf_new, - ExprHandle(body), - v->interaction(), - new_output_args, - new_reduce_args); + buf_new, body_new, new_output_args, new_reduce_args, v->reducer()); } const Expr* IRMutator::mutate(const BaseCallNode* v) { diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.h b/torch/csrc/jit/tensorexpr/ir_mutator.h index 0913da0e972df..773920cb52faf 100644 --- a/torch/csrc/jit/tensorexpr/ir_mutator.h +++ b/torch/csrc/jit/tensorexpr/ir_mutator.h @@ -26,6 +26,7 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE); #undef IMM_DECLARE class Cast; +class BitCast; class Var; class Buf; class Ramp; @@ -75,6 +76,7 @@ class TORCH_API IRMutator { AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DECLARE); #undef IMM_MUTATE_DECLARE virtual const Expr* mutate(const Cast* v); + virtual const Expr* mutate(const BitCast* v); virtual const Expr* mutate(const Var* v); virtual const Expr* mutate(const Buf* v); virtual const Expr* mutate(const Ramp* v); diff --git a/torch/csrc/jit/tensorexpr/ir_printer.cpp b/torch/csrc/jit/tensorexpr/ir_printer.cpp index 4fc7336f8eac7..1df2f96671dfb 100644 --- a/torch/csrc/jit/tensorexpr/ir_printer.cpp +++ b/torch/csrc/jit/tensorexpr/ir_printer.cpp @@ -1,6 +1,5 @@ #include -#include #include #include #include @@ -176,7 +175,7 @@ void IRPrinter::visit(const CompareSelect* v) { } static void formatFPSuffix(std::ostream& os, double v) { - // No suffix for doubles. + os << (v == std::ceil(v) ? ".0" : ""); } template @@ -350,8 +349,7 @@ void IRPrinter::visit(const MinTerm* v) { void IRPrinter::visit(const ReduceOp* v) { os() << "ReduceOp("; - os() << *v->accumulator() << ", "; - os() << v->complete() << ", "; + os() << *v->body() << ", "; bool first = true; os() << "out_args={"; @@ -553,11 +551,6 @@ std::ostream& operator<<(std::ostream& stream, const Tensor& t) { return stream; } -std::ostream& operator<<(std::ostream& stream, const Function& f) { - stream << std::to_string(&f); - return stream; -} - void print(const Expr* expr) { if (expr) { IRPrinter p(std::cout); @@ -565,6 +558,7 @@ void print(const Expr* expr) { } else { std::cout << "(null expr)"; } + std::cout << "\n"; } void print(const Stmt* stmt) { @@ -580,10 +574,6 @@ void print(const Tensor* t) { std::cout << std::to_string(t); } -void print(const Function* f) { - std::cout << std::to_string(f); -} - } // namespace tensorexpr } // namespace jit } // namespace torch @@ -606,6 +596,11 @@ std::string to_string(const Tensor* t) { return "(null tensor)\n"; } std::ostringstream oss; + if (!t->body()) { + oss << "Tensor " << t->buf()->name_hint() << " = " << *t->ElementStmt() + << "\n"; + return oss.str(); + } oss << "Tensor " << t->buf()->name_hint() << "("; for (size_t i = 0; i < t->ndim(); i++) { if (i != 0) { @@ -616,24 +611,4 @@ std::string to_string(const Tensor* t) { oss << ") = " << *t->body() << "\n"; return oss.str(); } - -std::string to_string(const Function* f) { - if (!f) { - return "(null function)\n"; - } - std::ostringstream oss; - oss << "Function F("; - for (size_t i = 0; i < f->ndim(); i++) { - if (i != 0) { - oss << ", "; - } - oss << *f->arg(i) << "[" << *f->dim(i) << "]"; - } - oss << ") {\n"; - for (size_t i = 0; i < f->bodies().size(); i++) { - oss << " " << *f->func_var(i) << " = " << *f->body(i) << "\n"; - } - oss << "}\n"; - return oss.str(); -} } // namespace std diff --git a/torch/csrc/jit/tensorexpr/ir_printer.h b/torch/csrc/jit/tensorexpr/ir_printer.h index 64ba35280371c..d9079d7fb7177 100644 --- a/torch/csrc/jit/tensorexpr/ir_printer.h +++ b/torch/csrc/jit/tensorexpr/ir_printer.h @@ -11,7 +11,6 @@ namespace jit { namespace tensorexpr { class Tensor; -class Function; class TORCH_API IRPrinter : public IRVisitor { public: @@ -95,12 +94,10 @@ TORCH_API std::ostream& operator<<(std::ostream& stream, const Expr&); TORCH_API std::ostream& operator<<(std::ostream& stream, const ExprHandle&); TORCH_API std::ostream& operator<<(std::ostream& stream, const Stmt&); TORCH_API std::ostream& operator<<(std::ostream& stream, const Tensor&); -TORCH_API std::ostream& operator<<(std::ostream& stream, const Function&); TORCH_API void print(const Expr* expr); TORCH_API void print(const Stmt* stmt); TORCH_API void print(const Tensor* t); -TORCH_API void print(const Function* f); } // namespace tensorexpr } // namespace jit @@ -109,12 +106,10 @@ TORCH_API void print(const Function* f); namespace std { using torch::jit::tensorexpr::Expr; -using torch::jit::tensorexpr::Function; using torch::jit::tensorexpr::Stmt; using torch::jit::tensorexpr::Tensor; TORCH_API std::string to_string(const Expr* expr); TORCH_API std::string to_string(const Stmt* stmt); TORCH_API std::string to_string(const Tensor* t); -TORCH_API std::string to_string(const Function* f); } // namespace std diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp index 37c856a2e618e..685bdc4aad823 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp @@ -516,10 +516,10 @@ const Expr* PolynomialTransformer::mutate(const Sub* v) { if (rhsPoly && lhsTerm) { // Negate every part of the Polynomial. const Expr* minusOne = getImmediateByType(lhsTerm->dtype(), -1); - const Expr* negateScalar = evaluateOp(new Mul(minusOne, lhsTerm->scalar())); + const Expr* negateScalar = evaluateOp(new Mul(minusOne, rhsPoly->scalar())); std::vector variables; - for (auto* t : lhsPoly->variables()) { + for (auto* t : rhsPoly->variables()) { const Expr* negate = evaluateOp(new Mul(minusOne, t->scalar())); variables.push_back(new Term(hasher_, negate, t->variables())); } @@ -984,17 +984,102 @@ const Expr* PolynomialTransformer::mutate(const Div* v) { } // If numberator and denominator are equal the result is 1. - if (hasher_.hash(lhs_new) == hasher_.hash(rhs_new)) { - return getImmediateByType(v->dtype(), 1); - } + // Unless the demoninator could be zero. + // if (hasher_.hash(lhs_new) == hasher_.hash(rhs_new)) { + // return getImmediateByType(v->dtype(), 1); + // } if (auto ret = factorizeDivision(lhs_new, rhs_new)) { - return ret; + return ret->accept_mutator(this); } return new Div(lhs_new, rhs_new); } +const Expr* PolynomialTransformer::mutate(const Mod* v) { + const Expr* lhs_new = v->lhs()->accept_mutator(this); + const Expr* rhs_new = v->rhs()->accept_mutator(this); + + // Constant Folding. + if (lhs_new->isConstant() && rhs_new->isConstant()) { + return evaluateOp(new Mod(lhs_new, rhs_new)); + } + + // 0 % x => 0. + if (lhs_new->isConstant() && immediateEquals(lhs_new, 0)) { + return lhs_new; + } + + // x % 1 == 0. + if (rhs_new->isConstant() && immediateEquals(rhs_new, 1)) { + return getImmediateByType(v->dtype(), 0); + } + + // x % x => 0. + if (hasher_.hash(lhs_new) == hasher_.hash(rhs_new)) { + return getImmediateByType(v->dtype(), 0); + } + + const Term* lhsTerm = dynamic_cast(lhs_new); + if (!lhsTerm) { + const Polynomial* lhsPoly = dynamic_cast(lhs_new); + if (lhsPoly) { + // Can still optimize this out if we can factorize the polynomial. + lhsTerm = factorizePolynomial(lhsPoly); + } + } + + if (lhsTerm) { + // ((C1 * C2) * x) % C1 => 0. + if (rhs_new->isConstant() && + immediateEquals(evaluateOp(new Mod(lhsTerm->scalar(), rhs_new)), 0)) { + return getImmediateByType(v->dtype(), 0); + } + + // (x * y * z) % x => 0. + for (auto* component : lhsTerm->variables()) { + if (hasher_.hash(component) == hasher_.hash(rhs_new)) { + return getImmediateByType(v->dtype(), 0); + } + } + + // (6 * x * y) % (3 * x * y) => 0. + // also, (x * y * z) % (z * y) => 0. + // This requires all variable terms found in the RHS to be present in the + // LHS. + const Term* rhsTerm = dynamic_cast(rhs_new); + if (rhsTerm) { + auto& lVars = lhsTerm->variables(); + auto& rVars = rhsTerm->variables(); + size_t rLeft = rVars.size(); + + auto rIt = rVars.begin(); + + for (auto lIt = lVars.begin(); lIt != lVars.end() && !rVars.empty(); + ++lIt) { + auto lHash = hasher_.hash(*lIt); + for (; rIt != rVars.end(); ++rIt) { + auto rHash = hasher_.hash(*rIt); + if (lHash == rHash) { + --rLeft; + break; + } else if (lHash < rHash) { + break; + } + } + } + + if (rLeft == 0 && + immediateEquals( + evaluateOp(new Mod(lhsTerm->scalar(), rhsTerm->scalar())), 0)) { + return getImmediateByType(v->dtype(), 0); + } + } + } + + return new Mod(lhs_new, rhs_new); +} + namespace { // Combines two MinTerm / MaxTerm expressions into one. @@ -1202,17 +1287,54 @@ const Expr* PolynomialTransformer::mutate(const Min* v) { const Expr* PolynomialTransformer::mutate(const CompareSelect* v) { const Expr* lhs_new = v->lhs()->accept_mutator(this); const Expr* rhs_new = v->rhs()->accept_mutator(this); - const Expr* retval1_new = v->ret_val1()->accept_mutator(this); - const Expr* retval2_new = v->ret_val2()->accept_mutator(this); - const Expr* v_new = new CompareSelect( - lhs_new, rhs_new, retval1_new, retval2_new, v->compare_select_op()); + const Expr* true_branch = v->ret_val1()->accept_mutator(this); + const Expr* false_branch = v->ret_val2()->accept_mutator(this); // Constant Folding. if (lhs_new->isConstant() && rhs_new->isConstant()) { + const Expr* v_new = new CompareSelect( + lhs_new, rhs_new, true_branch, false_branch, v->compare_select_op()); return evaluateOp(v_new); } - return v_new; + // If the comparison is done in float, don't attempt diff simplification, + // since we can't correctly handle NaN. + if (lhs_new->dtype().is_floating_point() || + rhs_new->dtype().is_floating_point()) { + return new CompareSelect( + lhs_new, rhs_new, true_branch, false_branch, v->compare_select_op()); + } + + // If diff is constant, we can determine it. + const Expr* diff = new Sub(rhs_new, lhs_new); + diff = diff->accept_mutator(this); + + if (!diff->isConstant()) { + return new CompareSelect( + lhs_new, rhs_new, true_branch, false_branch, v->compare_select_op()); + } + + bool equal = immediateEquals(diff, 0); + bool lhsSmaller = !equal && !immediateIsNegative(diff); + + switch (v->compare_select_op()) { + case CompareSelectOperation::kEQ: + return equal ? true_branch : false_branch; + case CompareSelectOperation::kGT: + return (lhsSmaller || equal) ? false_branch : true_branch; + case CompareSelectOperation::kGE: + return lhsSmaller ? false_branch : true_branch; + case CompareSelectOperation::kLT: + return lhsSmaller ? true_branch : false_branch; + case CompareSelectOperation::kLE: + return (lhsSmaller || equal) ? true_branch : false_branch; + case CompareSelectOperation::kNE: + return equal ? false_branch : true_branch; + } + + // should not be possible but just in case. + return new CompareSelect( + lhs_new, rhs_new, true_branch, false_branch, v->compare_select_op()); } const Expr* PolynomialTransformer::mutate(const Intrinsics* v) { @@ -1541,9 +1663,9 @@ const Expr* polyGCD(const Polynomial* poly) { // We ony want to factorize if we're saving complete operations, i.e. no // value in factorizing 6x + 4y into 2 * (3x + 2y) since we don't save work. int opsSaved = 1; // default to saving the scalar. - long GCD = immediateAs(scalar); + long GCD = std::abs(immediateAs(scalar)); for (auto* t : variables) { - long termScalar = immediateAs(t->scalar()); + long termScalar = std::abs(immediateAs(t->scalar())); long newGCD = gcd(std::max(GCD, termScalar), std::min(GCD, termScalar)); if (newGCD == 1) { return nullptr; @@ -1655,7 +1777,7 @@ const Expr* simplifyRoundModPattern(const Polynomial* poly) { } // Trivially factorize terms by GCD of scalar components. -const Expr* TermExpander::factorizePolynomial(const Polynomial* poly) { +const Term* IRSimplifierBase::factorizePolynomial(const Polynomial* poly) { const Expr* scalar = poly->scalar(); const std::vector& variables = poly->variables(); @@ -1909,6 +2031,7 @@ Block* TermExpander::fuseConditions(Block* v) { stmts.push_back(s); continue; } + // Fuse the two Conds by appending the bodies of the second Cond to the // first. Block* true_block = new Block({}); @@ -1939,11 +2062,13 @@ Block* TermExpander::fuseConditions(Block* v) { false_block = nullptr; } - prev_cond = prev_cond->cloneWithNewBodies(true_block, false_block); + Stmt* new_cond = prev_cond->cloneWithNewBodies(true_block, false_block) + ->accept_mutator(this); + prev_cond = dynamic_cast(new_cond); // erase, which shortens the list. stmts.pop_back(); - stmts.push_back(prev_cond); + stmts.push_back(new_cond); did_anything = true; } diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.h b/torch/csrc/jit/tensorexpr/ir_simplifier.h index 696c5001b8437..2bb1287bb4f86 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.h +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.h @@ -414,6 +414,9 @@ class TORCH_API IRSimplifierBase : public IRMutator { Stmt* mutate(const For* v) override; + // Trivially factorize terms by GCD of scalar components. + const Term* factorizePolynomial(const Polynomial* poly); + HashProvider& hasher() { return hasher_; } @@ -471,9 +474,7 @@ class TORCH_API PolynomialTransformer : public IRSimplifierBase { const Expr* mutate(const Div* v) override; - const Expr* mutate(const Mod* v) override { - return mutateBinaryOp(v, this); - } + const Expr* mutate(const Mod* v) override; const Expr* mutate(const And* v) override { return mutateBinaryOp(v, this); @@ -548,9 +549,6 @@ class TORCH_API TermExpander : public IRSimplifierBase { // Expand Terms out to a series of Muls. const Expr* mutate(const Term* v) override; - // Trivially factorize terms by GCD of scalar components. - const Expr* factorizePolynomial(const Polynomial* poly); - // Expand Polynomials out to a series of Adds. const Expr* mutate(const Polynomial* v) override; diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.cpp b/torch/csrc/jit/tensorexpr/ir_visitor.cpp index 6d2b2140d5b3a..772a28c77addb 100644 --- a/torch/csrc/jit/tensorexpr/ir_visitor.cpp +++ b/torch/csrc/jit/tensorexpr/ir_visitor.cpp @@ -79,6 +79,9 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT); void IRVisitor::visit(const Cast* v) { v->src_value()->accept(this); } +void IRVisitor::visit(const BitCast* v) { + v->src_value()->accept(this); +} void IRVisitor::visit(const Var* v) {} void IRVisitor::visit(const Ramp* v) { @@ -227,7 +230,7 @@ void IRVisitor::visit(const MinTerm* v) { void IRVisitor::visit(const ReduceOp* v) { v->accumulator()->accept(this); - v->body().node()->accept(this); + v->body()->accept(this); for (auto* e : v->output_args()) { e->accept(this); diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.h b/torch/csrc/jit/tensorexpr/ir_visitor.h index 3f5f05229c167..8353da680edb9 100644 --- a/torch/csrc/jit/tensorexpr/ir_visitor.h +++ b/torch/csrc/jit/tensorexpr/ir_visitor.h @@ -26,6 +26,7 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE) #undef IMM_DECLARE class Cast; +class BitCast; class Var; class Buf; class Ramp; @@ -74,6 +75,7 @@ class TORCH_API IRVisitor { #undef IMM_PRINT_VISIT virtual void visit(const Cast* v); + virtual void visit(const BitCast* v); virtual void visit(const Var* v); virtual void visit(const Buf* v); virtual void visit(const Ramp* v); diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 539abe941e19e..90cd449b54909 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1,7 +1,10 @@ #include +#include +#include #include #include +#include #include #include #include @@ -19,6 +22,7 @@ static int te_cuda_pointwise_block_count = -1; static int te_cuda_pointwise_block_size = -1; static bool fallback_allowed = false; static bool te_generate_block_code = false; +static bool te_must_use_llvm_on_cpu = false; bool setFallbackAllowed(bool value) { bool old_value = fallback_allowed; @@ -39,6 +43,9 @@ bool fallbackAllowed() { bool fallbackEnforced() { static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR_FALLBACK"); + if (tensorexpr::getTEGenerateBlockCode()) { + return false; + } if (!enable_c_str) { return fallback_allowed; } @@ -48,6 +55,15 @@ bool fallbackEnforced() { return false; } +bool dontUseLLVMFlag() { + static const char* enable_c_str = + std::getenv("PYTORCH_TENSOREXPR_DONT_USE_LLVM"); + if (!enable_c_str) { + return false; + } + return std::string(enable_c_str) == "1"; +} + int& getTECudaPointwiseLoopLevels() { return te_cuda_pointwise_loop_levels; } @@ -67,6 +83,10 @@ bool& getTEGenerateBlockCode() { return te_generate_block_code; } +bool& getTEMustUseLLVMOnCPU() { + return te_must_use_llvm_on_cpu; +} + c10::optional pickDeviceType( const at::ArrayRef& inputs) { c10::optional device = c10::nullopt; @@ -86,6 +106,18 @@ c10::optional pickDeviceType( } // namespace jit } // namespace torch +size_t normalizeAndCheckIndex(int64_t idx, int64_t list_size) { + if (idx < 0) { + // Handle negative indexing + idx = list_size + idx; + } + + if (idx < 0 || idx >= list_size) { + AT_ERROR("Invalid index ", idx, " for list_size", list_size); + } + return static_cast(idx); +} + static at::ScalarType tensorType(Tensor* t) { return static_cast(t->body()->dtype().scalar_type()); } @@ -123,15 +155,16 @@ ExprHandle TensorExprKernel::broadcast( ExprHandle TensorExprKernel::chunk( Tensor* t, size_t chunkIdx, - size_t dim, - size_t chunks, + int64_t dim, + int64_t chunks, const std::vector& axes) { + auto norm_dim = normalizeAndCheckIndex(dim, axes.size()); auto sizes = bufferSizes(t); - size_t step = sizes[dim] / chunks; + size_t step = sizes[norm_dim] / chunks; std::vector indices; for (size_t i = 0; i < axes.size(); ++i) { - if (i == dim) { + if (i == norm_dim) { indices.push_back(axes[i] + IntImm::make((int)chunkIdx * (int)step)); } else { indices.push_back(axes[i]); @@ -141,6 +174,25 @@ ExprHandle TensorExprKernel::chunk( return t->call(indices); } +ExprHandle promoteToDtype(ExprHandle e, ScalarType dt) { + if (e.dtype().scalar_type() == dt) { + return e; + } + + switch (dt) { +// NOLINTNEXTLINE +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + e = cast(e); \ + break; + AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, TYPE_CASE); +#undef TYPE_CASE + default: + throw unsupported_dtype(); + } + return e; +} + ExprHandle TensorExprKernel::tensorOrConstant( const torch::jit::Value* v, const std::vector& axes) { @@ -200,12 +252,15 @@ std::vector TensorExprKernel::inferSizesForValue( const torch::jit::Value* v) { switch (v->node()->kind()) { case aten::_cast_Float: + case aten::to: case aten::sigmoid: case aten::reciprocal: case aten::neg: case aten::relu: + case aten::isnan: case aten::log: case aten::log10: + case aten::log1p: case aten::log2: case aten::exp: case aten::expm1: @@ -230,7 +285,9 @@ std::vector TensorExprKernel::inferSizesForValue( case aten::trunc: case aten::frac: case aten::lgamma: - return sizesForValue(v->node()->input()); + case aten::type_as: + case aten::masked_fill: + return sizesForValue(v->node()->input(0)); case aten::sub: case aten::add: @@ -249,7 +306,6 @@ std::vector TensorExprKernel::inferSizesForValue( case aten::lt: case aten::min: case aten::max: - case aten::type_as: case aten::pow: case aten::fmod: case aten::remainder: @@ -261,7 +317,6 @@ std::vector TensorExprKernel::inferSizesForValue( } return broadcastShapes(shapes); } - case aten::lerp: case aten::clamp: case aten::threshold: @@ -323,20 +378,31 @@ std::vector TensorExprKernel::inferSizesForValue( // The sizes of the output tensor on that dimension is a sum of the // corresponding sizes of the input tensors, the other dimension have the // same sizes. + // Negative dim will correspond to dim = dim + input.dim(). auto const& n = v->node(); auto inputs = n->input(0)->node()->inputs(); + if (inputs.size() == 0) { + throw std::runtime_error("Empty input list is passed to aten::cat"); + } + TORCH_INTERNAL_ASSERT(n->input(1)->node()->kind() == prim::Constant); int64_t dim = n->input(1)->node()->i(attr::value); - - ExprHandle concat_size = IntImm::make(0); + auto shape = sizesForValue(inputs[0]); + size_t norm_dim = normalizeAndCheckIndex(dim, shape.size()); + ExprHandle concat_dim_size = 0; for (auto input : inputs) { - concat_size = concat_size + sizesForValue(input)[dim]; + concat_dim_size = concat_dim_size + sizesForValue(input)[norm_dim]; } - concat_size = IRSimplifier::simplify(concat_size); - auto shape = sizesForValue(inputs[0]); - shape[dim] = concat_size; + concat_dim_size = IRSimplifier::simplify(concat_dim_size); + shape[norm_dim] = concat_dim_size; return shape; } + + case aten::softmax: + case aten::log_softmax: + // Output of softmax / log_softmax has the same shape as input 0. + return sizesForValue(v->node()->input(0)); + case aten::slice: throw std::runtime_error( "Shape info is not implemented for this kind of node"); @@ -375,7 +441,59 @@ ExprHandle TensorExprKernel::constant(const torch::jit::Value* v) { return scalars_.at(v->unique()); } -void TensorExprKernel::promoteInputs(std::vector& inputs) { +ExprHandle promoteIntegerToDefaultType(const ExprHandle& e) { + auto scalarType = static_cast(e.dtype().scalar_type()); + if (!c10::isIntegralType(scalarType, /*includeBool*/ true)) { + return e; + } + + auto defaultType = c10::typeMetaToScalarType(c10::get_default_dtype()); + + // We intend to promote Integers to floating-point types + TORCH_INTERNAL_ASSERT( + !c10::isIntegralType(defaultType, /*includeBool*/ true)); + + return Cast::make( + Dtype( + static_cast(defaultType), e.dtype().lanes()), + e); +} + +ExprHandle promoteHalfToFloat(const ExprHandle& e) { + auto scalarType = static_cast(e.dtype().scalar_type()); + auto floatType = static_cast(tensorexpr::ScalarType::Float); + if (c10::isFloatingType(scalarType) && + (c10::elementSize(scalarType) < c10::elementSize(floatType))) { + return Cast::make( + Dtype(tensorexpr::ScalarType::Float, e.dtype().lanes()), e); + } else { + return e; + } +} + +bool TensorExprKernel::checkTypes( + const ScalarType highType, + const int typeConstraints) { + if (typeConstraints == kAllTypes) { + return true; + } + + if (is_integral(highType)) { + return (typeConstraints & kIntegralTypes) != 0; + } else if (is_floating_point(highType)) { + return (typeConstraints & kFloatingPointTypes) != 0; + } else if (highType == ScalarType::Bool) { + return (typeConstraints & kBoolType) != 0; + } + + // assume JIT not supporting complex and qint yet + TORCH_INTERNAL_ASSERT((typeConstraints & (kQintTypes | kComplexTypes)) == 0); + return false; +} + +void TensorExprKernel::promoteInputs( + std::vector& inputs, + const int typeConstraints) { if (inputs.empty()) { return; } @@ -386,22 +504,12 @@ void TensorExprKernel::promoteInputs(std::vector& inputs) { highType = promoteTypes(highType, input.dtype().scalar_type()); } - for (ExprHandle& e : inputs) { - if (e.dtype().scalar_type() == highType) { - continue; - } + if (!checkTypes(highType, typeConstraints)) { + throw unsupported_dtype(); + } - switch (highType) { -// NOLINTNEXTLINE -#define TYPE_CASE(Type, Name) \ - case ScalarType::Name: \ - e = cast(e); \ - break; - AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, TYPE_CASE); -#undef TYPE_CASE - default: - throw unsupported_dtype(); - } + for (ExprHandle& e : inputs) { + e = promoteToDtype(e, highType); } } @@ -458,6 +566,7 @@ std::vector TensorExprKernel::broadcastShapes( auto res2 = broadcastShapes(shapes); return res2; } + std::vector TensorExprKernel::broadcastShapes( const std::vector& a, const std::vector& b) { @@ -505,19 +614,20 @@ std::vector TensorExprKernel::valueShape( Tensor* TensorExprKernel::computeOneOperand( const std::string& name, const torch::jit::Value* v, - const std::function& innerExpr) { + const std::function& innerExpr, + const int checkParamTypes) { auto const& n = v->node(); auto const& shape = valueShape(n->inputs()[0]); return Compute( name, c10::fmap(shape), - [this, v, innerExpr](const std::vector& axes) { + [this, v, innerExpr, checkParamTypes]( + const std::vector& axes) { auto const& n = v->node(); std::vector indices(axes.begin(), axes.end()); std::vector inputs = { tensorOrConstant(n->inputs()[0], indices)}; - - promoteInputs(inputs); + promoteInputs(inputs, checkParamTypes); ExprHandle compute = innerExpr(inputs[0]); return demoteOutput(compute, n->output()); }); @@ -612,7 +722,8 @@ Tensor* TensorExprKernel::computeThreeOperand( const torch::jit::Value* v, const std::function< ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>& - innerExpr) { + innerExpr, + bool promote_inputs) { auto const& n = v->node(); std::vector> shapes; for (size_t idx = 0; idx < 3; idx++) { @@ -623,7 +734,7 @@ Tensor* TensorExprKernel::computeThreeOperand( return Compute( name, c10::fmap(shape), - [this, v, innerExpr](const std::vector& axes) { + [this, v, innerExpr, promote_inputs](const std::vector& axes) { auto const& n = v->node(); std::vector indices(axes.begin(), axes.end()); std::vector inputs = { @@ -632,7 +743,9 @@ Tensor* TensorExprKernel::computeThreeOperand( tensorOrConstant(n->inputs()[2], indices), }; - promoteInputs(inputs); + if (promote_inputs) { + promoteInputs(inputs); + } ExprHandle compute = innerExpr(inputs[0], inputs[1], inputs[2]); return demoteOutput(compute, n->output()); }); @@ -682,6 +795,16 @@ ExprHandle boolToInteger(const ExprHandle& x) { } // namespace +c10::optional findDtypeForValue(const torch::jit::Value* v) { + if (v->type()->kind() == TypeKind::TensorType) { + auto tt = v->type()->cast(); + if (tt->scalarType()) { + return static_cast(*tt->scalarType()); + } + } + return c10::nullopt; +} + Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { switch (v->node()->kind()) { case aten::add: { @@ -701,6 +824,17 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { }); } break; + case aten::to: { + // see handling of aten::to in tensorexpr_fuser.cpp for why we only + // need to handle the first input + auto node = v->node(); + return computeOneOperand("aten_to", v, [node](const ExprHandle& a) { + auto output_dtype = findDtypeForValue(node->output()); + TORCH_INTERNAL_ASSERT(output_dtype); + return Cast::make(ToDtype(*output_dtype), a); + }); + } break; + case aten::sub: { auto sub_lambda = [](const ExprHandle& lhs, const ExprHandle& rhs) { // NB: sub isn't supported on boolean, no need to promote to integer. @@ -723,7 +857,8 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { case aten::div: { return computeTwoOperand( "aten_div", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { - return boolToInteger(lhs) / boolToInteger(rhs); + return promoteIntegerToDefaultType(lhs) / + promoteIntegerToDefaultType(rhs); }); } break; @@ -827,6 +962,20 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { }); } break; + case aten::masked_fill: { + return computeThreeOperand( + "aten_masked_fill", + v, + [](const ExprHandle& input, + const ExprHandle& mask, + const ExprHandle& value) { + // value needs to promote to input, not vice versa + auto val = promoteToDtype(value, input.dtype().scalar_type()); + return ifThenElse(mask, val, input); + }, + /*promote_inputs*/ false); + } + case aten::clamp: { bool noMin = false; bool noMax = false; @@ -851,26 +1000,32 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { const ExprHandle& in, const ExprHandle& min, const ExprHandle& max) { + auto cast = [&](const ExprHandle& e) { + return Cast::make(in.dtype(), e); + }; + if (noMin && noMax) { return in; } else if (noMin) { - return CompareSelect::make(in, max, max, in, kGT); + auto cmax = cast(max); + return CompareSelect::make(in, cmax, cmax, in, kGT); } else if (noMax) { - return CompareSelect::make(in, min, min, in, kLT); + auto cmin = cast(min); + return CompareSelect::make(in, cmin, cmin, in, kLT); } else { - return CompareSelect::make( - in, - min, - min, - CompareSelect::make(in, max, max, in, kGT), - kLT); + auto cmax = cast(max); + auto cmin = cast(min); + auto mm = CompareSelect::make(in, cmin, cmin, in, kLT); + return CompareSelect::make(mm, cmax, cmax, mm, kGT); } - }); + }, + false /* promote_inputs */); } break; case aten::sigmoid: { - return computeOneOperand( - "aten_sigmoid", v, [](const ExprHandle& a) { return sigmoid(a); }); + return computeOneOperand("aten_sigmoid", v, [](const ExprHandle& a) { + return sigmoid(promoteIntegerToDefaultType(a)); + }); } break; case aten::reciprocal: { @@ -885,6 +1040,15 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { }); } break; + case aten::isnan: { + return computeOneOperand("aten_isnan", v, [](const ExprHandle& a) { + if (!a.dtype().is_floating_point()) { + return IntImm::make(0); + } + return isnan(a); + }); + } break; + case aten::relu: { return computeOneOperand("aten_relu", v, [](const ExprHandle& a) { auto zero = Cast::make(a.dtype(), 0); @@ -893,59 +1057,78 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { } break; case aten::log: { - return computeOneOperand( - "aten_log", v, [](const ExprHandle& a) { return log(a); }); + return computeOneOperand("aten_log", v, [](const ExprHandle& a) { + return log(promoteIntegerToDefaultType(a)); + }); } break; case aten::log10: { - return computeOneOperand( - "aten_log10", v, [](const ExprHandle& a) { return log10(a); }); + return computeOneOperand("aten_log10", v, [](const ExprHandle& a) { + return log10(promoteIntegerToDefaultType(a)); + }); + } break; + + case aten::log1p: { + return computeOneOperand("aten_log1p", v, [](const ExprHandle& a) { + return log1p(promoteIntegerToDefaultType(a)); + }); } break; case aten::log2: { - return computeOneOperand( - "aten_log2", v, [](const ExprHandle& a) { return log2(a); }); + return computeOneOperand("aten_log2", v, [](const ExprHandle& a) { + return log2(promoteIntegerToDefaultType(a)); + }); } break; case aten::exp: { - return computeOneOperand( - "aten_exp", v, [](const ExprHandle& a) { return exp(a); }); + return computeOneOperand("aten_exp", v, [](const ExprHandle& a) { + return exp(promoteIntegerToDefaultType(a)); + }); } break; case aten::expm1: { - return computeOneOperand( - "aten_expm1", v, [](const ExprHandle& a) { return expm1(a); }); + return computeOneOperand("aten_expm1", v, [](const ExprHandle& a) { + return expm1(promoteIntegerToDefaultType(a)); + }); } break; case aten::erf: { - return computeOneOperand( - "aten_erf", v, [](const ExprHandle& a) { return erf(a); }); + return computeOneOperand("aten_erf", v, [](const ExprHandle& a) { + return erf(promoteIntegerToDefaultType(a)); + }); } break; case aten::erfc: { - return computeOneOperand( - "aten_erfc", v, [](const ExprHandle& a) { return erfc(a); }); + return computeOneOperand("aten_erfc", v, [](const ExprHandle& a) { + return erfc(promoteIntegerToDefaultType(a)); + }); } break; case aten::cos: { - return computeOneOperand( - "aten_cos", v, [](const ExprHandle& a) { return cos(a); }); + return computeOneOperand("aten_cos", v, [](const ExprHandle& a) { + return cos(promoteIntegerToDefaultType(a)); + }); } break; case aten::sin: { - return computeOneOperand( - "aten_sin", v, [](const ExprHandle& a) { return sin(a); }); + return computeOneOperand("aten_sin", v, [](const ExprHandle& a) { + return sin(promoteIntegerToDefaultType(a)); + }); } break; case aten::tan: { - return computeOneOperand( - "aten_tan", v, [](const ExprHandle& a) { return tan(a); }); + return computeOneOperand("aten_tan", v, [](const ExprHandle& a) { + return tan(promoteIntegerToDefaultType(a)); + }); } break; case aten::type_as: { - return computeTwoOperand( - "aten_type_as", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { - return Cast::make(rhs.dtype(), lhs); + auto const& n = v->node(); + Tensor* rhs = tensors_.at(n->inputs()[1]->unique()); + auto dtype = rhs->body()->dtype(); + return computeOneOperand( + "aten_type_as", v, [dtype](const ExprHandle& lhs) { + return Cast::make(dtype, lhs); }); } break; @@ -959,54 +1142,31 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { case aten::pow: { return computeTwoOperand( "aten_pow", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { - const FloatImm* floatImm = rhs.AsNode(); - if (floatImm) { - float imm = floatImm->value(); - if (imm == 1.0f) { - return lhs; - } else if (imm == 2.0f) { // NOLINT - return lhs * lhs; - } else if (imm == 3.0f) { // NOLINT - return (lhs * lhs) * lhs; - } else if (imm == 4.0f) { // NOLINT - ExprHandle tmp = lhs * lhs; - return tmp * tmp; - } else if (imm == 0.5f) { // NOLINT - return sqrt(lhs); - } else if (imm == 0.0f) { - return ExprHandle(1.0f); - } else if (imm == -0.5f) { // NOLINT - return rsqrt(lhs); - } else if (imm == -1.0f) { - return ExprHandle(1.0f) / lhs; - } else if (imm == -2.0f) { // NOLINT - return ExprHandle(1.0f) / (lhs * lhs); - } + if (!rhs.node()->isConstant()) { + return pow(lhs, rhs); } - - const Cast* floatCast = rhs.AsNode(); - if (floatCast) { - const IntImm* intImm = - dynamic_cast(floatCast->src_value()); - if (intImm) { - float imm = static_cast(intImm->value()); - if (imm == 1) { - return lhs; - } else if (imm == 2) { - return lhs * lhs; - } else if (imm == 3) { - return (lhs * lhs) * lhs; - } else if (imm == 4) { - ExprHandle tmp = lhs * lhs; - return tmp * tmp; - } else if (imm == 0) { - return ExprHandle(1.0f); - } else if (imm == -1) { - return ExprHandle(1.0f) / lhs; - } else if (imm == -2) { - return ExprHandle(1.0f) / (lhs * lhs); - } - } + double val = + immediateAs(IRSimplifier::simplify(rhs.node())); + + if (val == 1.0f) { + return lhs; + } else if (val == 2.0f) { // NOLINT + return lhs * lhs; + } else if (val == 3.0f) { // NOLINT + return (lhs * lhs) * lhs; + } else if (val == 4.0f) { // NOLINT + ExprHandle tmp = lhs * lhs; + return tmp * tmp; + } else if (val == 0.5f) { // NOLINT + return sqrt(lhs); + } else if (val == 0.0f) { + return ExprHandle(1.0f); + } else if (val == -0.5f) { // NOLINT + return rsqrt(lhs); + } else if (val == -1.0f) { + return ExprHandle(1.0f) / lhs; + } else if (val == -2.0f) { // NOLINT + return ExprHandle(1.0f) / (lhs * lhs); } return pow(lhs, rhs); }); @@ -1015,7 +1175,7 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { case aten::fmod: { return computeTwoOperand( "aten_fmod", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { - return fmod(lhs, rhs); + return fmod(promoteHalfToFloat(lhs), promoteHalfToFloat(rhs)); }); } break; @@ -1028,65 +1188,114 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { const ExprHandle& weight) { return a + weight * (end - a); }); } break; case aten::remainder: { - return computeTwoOperand( - "aten_remainder", - v, - [](const ExprHandle& lhs, const ExprHandle& rhs) { - return fmod((rhs + fmod(lhs, rhs)), rhs); - }); + auto imodImpl = [](const ExprHandle& lhs, const ExprHandle& rhs) { + return Mod::make(lhs, rhs); + }; + auto fmodImpl = [](const ExprHandle& lhs, const ExprHandle& rhs) { + auto lhs_t = promoteHalfToFloat(lhs); + auto rhs_t = promoteHalfToFloat(rhs); + return fmod((rhs_t + fmod(lhs_t, rhs_t)), rhs_t); + }; + { + auto const& n = v->node(); + auto const& shape = broadcastShapes( + valueShape(n->inputs()[0]), valueShape(n->inputs()[1])); + return Compute( + "aten_remainder", + c10::fmap(shape), + [&](const std::vector& axes) { + auto const& n = v->node(); + std::vector indices(axes.begin(), axes.end()); + std::vector inputs = { + tensorOrConstant(n->inputs()[0], indices), + tensorOrConstant(n->inputs()[1], indices), + }; + + promoteInputs(inputs); + bool allInt = true; + for (auto& e : inputs) { + if (e.dtype().is_floating_point()) { + allInt = false; + break; + } + } + if (allInt) { + return demoteOutput( + imodImpl(inputs[0], inputs[1]), n->output()); + } else { + return demoteOutput( + fmodImpl(inputs[0], inputs[1]), n->output()); + } + }); + } } break; case aten::acos: { - return computeOneOperand( - "aten_acos", v, [](const ExprHandle& a) { return acos(a); }); + return computeOneOperand("aten_acos", v, [](const ExprHandle& a) { + return acos(promoteIntegerToDefaultType(a)); + }); } break; case aten::asin: { - return computeOneOperand( - "aten_asin", v, [](const ExprHandle& a) { return asin(a); }); + return computeOneOperand("aten_asin", v, [](const ExprHandle& a) { + return asin(promoteIntegerToDefaultType(a)); + }); } break; case aten::cosh: { - return computeOneOperand( - "aten_cosh", v, [](const ExprHandle& a) { return cosh(a); }); + return computeOneOperand("aten_cosh", v, [](const ExprHandle& a) { + return cosh(promoteIntegerToDefaultType(a)); + }); } break; case aten::sinh: { - return computeOneOperand( - "aten_sinh", v, [](const ExprHandle& a) { return sinh(a); }); + return computeOneOperand("aten_sinh", v, [](const ExprHandle& a) { + return sinh(promoteIntegerToDefaultType(a)); + }); } break; case aten::atan: { - return computeOneOperand( - "aten_atan", v, [](const ExprHandle& a) { return atan(a); }); + return computeOneOperand("aten_atan", v, [](const ExprHandle& a) { + return atan(promoteIntegerToDefaultType(a)); + }); } break; case aten::atan2: { return computeTwoOperand( "aten_atan2", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { - return atan2(lhs, rhs); + return atan2( + promoteIntegerToDefaultType(lhs), + promoteIntegerToDefaultType(rhs)); }); } break; case aten::tanh: { - return computeOneOperand( - "aten_tanh", v, [](const ExprHandle& a) { return tanh(a); }); + return computeOneOperand("aten_tanh", v, [](const ExprHandle& a) { + return tanh(promoteIntegerToDefaultType(a)); + }); } break; case aten::sqrt: { - return computeOneOperand( - "aten_sqrt", v, [](const ExprHandle& a) { return sqrt(a); }); + return computeOneOperand("aten_sqrt", v, [](const ExprHandle& a) { + return tensorexpr::sqrt(promoteIntegerToDefaultType(a)); + }); } break; case aten::rsqrt: { - return computeOneOperand( - "aten_rsqrt", v, [](const ExprHandle& a) { return rsqrt(a); }); + return computeOneOperand("aten_rsqrt", v, [](const ExprHandle& a) { + return rsqrt(promoteIntegerToDefaultType(a)); + }); } break; case aten::abs: { return computeOneOperand( - "aten_abs", v, [](const ExprHandle& a) { return fabs(a); }); + "aten_abs", + v, + [](const ExprHandle& a) { + return tensorexpr::abs(promoteHalfToFloat(a)); + }, + kIntegralTypes | kFloatingPointTypes | kBoolType); } break; case aten::ceil: { @@ -1131,7 +1340,13 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { case aten::frac: { return computeOneOperand( - "aten_frac", v, [](const ExprHandle& a) { return a - floor(a); }); + "aten_frac", + v, + [](const ExprHandle& a) { + auto aa = promoteHalfToFloat(a); + return aa - floor(aa); + }, + kFloatingPointTypes); } break; case aten::lgamma: { @@ -1164,19 +1379,75 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { [this, v](const std::vector& axes) { auto const& n = v->node(); auto inputs = n->inputs()[0]->node()->inputs(); - size_t dim = n->inputs()[1]->node()->i(attr::value); + if (inputs.size() == 0) { + throw std::runtime_error( + "Empty input list is passed to aten::cat"); + } + + // Some of the inputs can be empty tensors, we need to skip them + // when we construct the expression, but we need to take them into + // account in dtype promotion. + std::vector nonempty_inputs; + for (auto input : inputs) { + if (input->type()->kind() == TypeKind::TensorType) { + auto tt = input->type()->cast(); + if (tt->isComplete() && tt->sizes().size() && tt->sizes()[0] && + *tt->sizes()[0]) { + nonempty_inputs.push_back(input); + } + } + } + + // When all inputs are empty tensors, the tensor we create for this + // computation would contain no elements, so it doesn't really + // matter what we return here, so just return 0. + if (!nonempty_inputs.size()) { + return ExprHandle(0); + } + + int64_t dim_ = n->inputs()[1]->node()->i(attr::value); + size_t dim = normalizeAndCheckIndex(dim_, axes.size()); + // Promote input types. + // Note that we need to consider all inputs, including empty - they + // also affect the resultant dtype. + auto maybe_dtype = findDtypeForValue(inputs[0]); + TORCH_INTERNAL_ASSERT( + maybe_dtype, "Cannot find dtype for one of aten::cat inputs"); + ScalarType highType = *maybe_dtype; + for (const auto input : inputs) { + auto maybe_dtype = findDtypeForValue(input); + TORCH_INTERNAL_ASSERT( + maybe_dtype, "Cannot find dtype for one of aten::cat inputs"); + highType = promoteTypes(highType, *maybe_dtype); + } + // Now we know the final dtype, we know what inputs are non-empty, + // and we know that there is at least one such an input. With all + // that we construct a tensor expression performing the + // concatenation. + // The expression we build here is a cascading if-then-else that + // essentially represents: + // + // inp1[i, j, k] if 0 < i < l1, + // out[i,j,k] = inp2[i, j-l1, k] if l1 =< i < l1 + l2, + // ... + // inpN[i, j-l_N_1, k] if l1+l2+...l_N_1 < i + // where l_i is the corresponding size of the i-th input. std::vector newAxes(axes.begin(), axes.end()); - ExprHandle load = tensorOrConstant(inputs[0], newAxes); - size_t offset = bufferSizes(tensors_.at(inputs[0]->unique()))[dim]; + ExprHandle load = promoteToDtype( + tensorOrConstant(nonempty_inputs[0], newAxes), highType); + size_t offset = + bufferSizes(tensors_.at(nonempty_inputs[0]->unique()))[dim]; newAxes[dim] = newAxes[dim] - IntImm::make(offset); - for (size_t ii = 1; ii < inputs.size(); ++ii) { + for (size_t ii = 1; ii < nonempty_inputs.size(); ++ii) { + auto input = nonempty_inputs[ii]; load = ifThenElse( CompareSelect::make(axes[dim], IntImm::make(offset), kLT), load, - tensorOrConstant(inputs[ii], newAxes)); - offset += bufferSizes(tensors_.at(inputs[ii]->unique()))[dim]; + promoteToDtype(tensorOrConstant(input, newAxes), highType)); + + offset += bufferSizes(tensors_.at(input->unique()))[dim]; newAxes[dim] = axes[dim] - IntImm::make(offset); } @@ -1232,80 +1503,50 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { return computeSum(v); } - default: { - throw std::runtime_error("Unhandled node kind"); + case aten::softmax: { + return computeSoftmax(v, false); } - } -} -void TensorExprKernel::flattenTensors(BackendType backendType) { - if (backendType != BackendType::kCudaCodeGen && - backendType != BackendType::kBlockCodeGen) { - // We only need to flatten for GPU, for other backends just use the same - // tensors. - flatTensorOutputs_ = tensorOutputs_; - return; - } + case aten::log_softmax: { + return computeSoftmax(v, true); + } - flatTensorOutputs_.resize(tensorOutputs_.size()); - for (size_t tensorIdx = 0; tensorIdx < tensorOutputs_.size(); tensorIdx++) { - Tensor* tensor = tensorOutputs_[tensorIdx]; - ExprHandle totalCount = ExprHandle(tensor->dim(0)); - for (int i = 1; i < tensor->ndim(); i++) { - const IntImm* totalCountImm = totalCount.AsNode(); - const IntImm* tensorDimImm = dynamic_cast(tensor->dim(i)); - if (totalCountImm && tensorDimImm) { - // TODO: switch to real constant folding when it is available. - totalCount = ExprHandle(totalCountImm->value() * tensorDimImm->value()); - } else { - totalCount = totalCount * ExprHandle(tensor->dim(i)); - } + default: { + throw std::runtime_error("Unhandled node kind"); } - // Flatten the index for GPU kernels. - // TODO: move this to fusing axis when it is ready. - Tensor* newOut = Compute( - tensor->func_var()->name_hint() + "_flat", - {totalCount}, - [tensor](const VarHandle& index) -> ExprHandle { - std::vector dims; - ExprHandle value = index; - for (int i = tensor->ndim() - 1; i >= 0; i--) { - ExprHandle idx = value; - if (i > 0) { - idx = Mod::make(value, ExprHandle(tensor->dim(i))); - } - dims.push_back(idx); - value = value / ExprHandle(tensor->dim(i)); - } - std::reverse(dims.begin(), dims.end()); - return tensor->call(dims); - }); - flatTensorOutputs_[tensorIdx] = newOut; } } Stmt* TensorExprKernel::generateStmt(BackendType backendType) { - flattenTensors(backendType); - - torch::jit::tensorexpr::LoopNest l(flatTensorOutputs_); + torch::jit::tensorexpr::LoopNest l(tensorOutputs_); GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n"); bool hasReduction = NodeFinder::find(l.root_stmt()).size() != 0; - // Compute non-output tensors_ inline - for (auto& p : tensors_) { - if (!l.hasLoopBodyFor(p.second) || hasReduction) { - continue; - } - l.computeInline(p.second->buf()); + // For Block codegen we create a map of tensor dims before + // inlining. Like GPU codegen we need to inline. But the order + // where this analysis is run matters. + auto block_analysis = std::make_unique(); + if (backendType == kBlockCodeGen) { + // Run Block analysis to get multi dim buffer info + auto root_stmt = l.root_stmt(); + root_stmt->accept(block_analysis.get()); } - if (backendType == kCudaCodeGen) { - for (size_t i = 0; i < flatTensorOutputs_.size(); i++) { - Tensor* tensor = flatTensorOutputs_[i]; - // For every output tensor we've created a flattened 1D tensor - let's - // mark the original output tensor with computeInline - l.computeInline(tensorOutputs_[i]->buf()); + // inlining output & intermediate buffers can duplicate computation. + // it slows down cpu code generation but is enabled on gpu because it avoids + // difficult synchronization logic across blocks. + bool allow_duplicated_work = + (backendType == kCudaCodeGen || backendType == kBlockCodeGen); + l.inlineIntermediateBufs(allow_duplicated_work); + + if (backendType == kCudaCodeGen) { + for (auto tensor : tensorOutputs_) { + std::vector loops = l.getLoopStmtsFor(tensor); + TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty"); + For* flattened = nullptr; + LoopNest::flatten(loops, &flattened); + assert(flattened); int loopLevels = getTECudaPointwiseLoopLevels(); const int kDefaultLoopLevels = 2; @@ -1320,8 +1561,7 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) { if (blockSize < 0) { blockSize = kDefaultBlockSize; } - std::vector loops = l.getLoopStmtsFor(tensor); - l.splitWithMask(loops[0], blockSize, &outer, &inner); + l.splitWithMask(flattened, blockSize, &outer, &inner); l.setGPUBlockIndex(outer, 0); l.setGPUThreadIndex(inner, 0); } else if (loopLevels == 3) { @@ -1334,8 +1574,7 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) { const int kDefaultBlockSize = 256; blockCount = (blockCount > 0) ? blockCount : kDefaultBlockCount; blockSize = (blockSize > 0) ? blockSize : kDefaultBlockSize; - std::vector loops = l.getLoopStmtsFor(tensor); - l.splitWithMask(loops[0], blockCount * blockSize, &outer, &inner); + l.splitWithMask(flattened, blockCount * blockSize, &outer, &inner); l.splitWithMask(inner, blockSize, &inner1, &inner2); l.setGPUBlockIndex(inner1, 0); l.setGPUThreadIndex(inner2, 0); @@ -1347,26 +1586,23 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) { } if (backendType == kBlockCodeGen) { - auto block_analysis = std::make_unique(); - for (size_t i = 0; i < flatTensorOutputs_.size(); i++) { + for (auto tensor : tensorOutputs_) { const int default_fp16_blocksize = 16; const int default_uint8_blocksize = 32; int blockSize = default_fp16_blocksize; // We only handle looplevels == 2 for now - Tensor* tensor = flatTensorOutputs_[i]; - // Run Block analysis to get multi dim buffer info - auto root_stmt = l.root_stmt(); - root_stmt->accept(block_analysis.get()); - if (tensor->buf()->dtype().scalar_type() == ScalarType::Byte) { blockSize = default_uint8_blocksize; } - l.computeInline(l.getLoopBodyFor(tensorOutputs_[i])); - For* outer; - For* inner; std::vector loops = l.getLoopStmtsFor(tensor); - TORCH_INTERNAL_ASSERT(loops.size() > 0, "loops should not be empty"); - l.splitWithMask(loops[0], blockSize, &outer, &inner); + TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty"); + For* flattened = nullptr; + LoopNest::flatten(loops, &flattened); + assert(flattened); + + For* outer = nullptr; + For* inner = nullptr; + l.splitWithMask(flattened, blockSize, &outer, &inner); l.setGPUBlockIndex(outer, 0); l.setGPUThreadIndex(inner, 0); l.setBufferMap(outer, block_analysis->getBufferMap()); @@ -1414,7 +1650,7 @@ std::vector TensorExprKernel::prepareBufferArgs() { params.emplace_back(stride.var); } } - for (auto& o : flatTensorOutputs_) { + for (auto& o : tensorOutputs_) { params.emplace_back(o); } return params; @@ -1434,10 +1670,13 @@ TensorExprKernel::BackendType TensorExprKernel::inferBackendTypeFromDevice( backendType = kBlockCodeGen; } else if (device.type() == at::kCPU) { #ifdef TORCH_ENABLE_LLVM - backendType = kLLVMCodeGen; + backendType = dontUseLLVMFlag() ? kSimpleIREval : kLLVMCodeGen; #else backendType = kSimpleIREval; #endif + if (getTEMustUseLLVMOnCPU() && backendType == kSimpleIREval) { + throw std::runtime_error("LLVM Backend not found"); + } } else { throw std::runtime_error("Invalid device type"); } @@ -1449,7 +1688,7 @@ void TensorExprKernel::bindInput(const torch::jit::Value* input) { switch (t->kind()) { case TypeKind::TensorType: { auto tt = input->type()->cast(); - Buffer inBuffer( + Placeholder inBuffer( "t" + input->debugName(), ToDtype(static_cast(*tt->scalarType())), {0}); @@ -1470,7 +1709,7 @@ void TensorExprKernel::bindInput(const torch::jit::Value* input) { for (size_t i = 0; i < axes.size(); i++) { idx = idx + axes[i] * IntImm::make(*strides[i]); } - return inBuffer(idx); + return inBuffer.load(idx); })); kernelArgs_.emplace_back( inBuffer, std::vector(), std::vector()); @@ -1554,6 +1793,141 @@ Tensor* TensorExprKernel::computeSum(const torch::jit::Value* v) { reduction_info.reductionDims); } +Tensor* TensorExprKernel::computeSoftmax( + const torch::jit::Value* v, + bool log_softmax) { + // Softmax is computed as follows: + // softmax(vi) = exp(vi) / sum(exp(vi)) + // + // In order to avoid overflow issues due to exp of a large number, we + // subtract the max of that dim before computing exp. + // softmax(vi) = exp(vi - max(vi)) / sum(exp(vi - max(vi))) + // + // This is implemented as 4 loopnests: + // - First loop computes the max over the softmax dim. + // - Second loop computes exp for every element in v after subtracting + // the max of the softmax dim it belongs to. + // - Third loop computes the sum over the softmax dim. + // - Final loop computes softmax for every element in v. + + // LogSoftmax is computed as follows: + // log_softmax(vi) = log(softmax(vi)) + // = vi - log(sum(exp(vi))) + // + // Using the same max trick as above: + // log_softmax(vi) = vi - max(vi) - log(sum(exp(vi - max(vi)))) + // + // This is implemented as 5 loopnests: + // - First loop computes the max over the softmax dim. + // - Second loop computes exp for every element in v after subtracting + // the max of the softmax dim it belongs to. + // - Third loop computes the sum over the softmax dim. + // - Fourth loop computes log for every element in the sum. + // - Final loop computes the log_softmax for every element in v. + + TORCH_INTERNAL_ASSERT(v->node()->inputs().size() == 3); + auto output_dims = dimsFromSizes(sizesForValue(v)); + + // We do not handle None for dims (input 1) because that is supposed to + // be deprecated. + TORCH_INTERNAL_ASSERT(v->node()->input(1)->node()->kind() == prim::Constant); + int64_t rank = + *v->node()->input(0)->type()->cast()->sizes().size(); + size_t softmax_dim = + normalizeAndCheckIndex(v->node()->input(1)->node()->i(attr::value), rank); + std::vector non_softmax_dims; + for (size_t i = 0; i < output_dims.size(); ++i) { + if (i != softmax_dim) { + non_softmax_dims.push_back(output_dims[i]); + } + } + + // Softmax implementation includes two reductions, one to find the max and + // the other to calculate the sum along the softmax dim. These reductions + // will have the softmax dimension as the inner most loop. So, the innermost + // index in the indices will refer to the softmax dimension. + + // Update the indices by moving the softmax dimension index to the + // appropriate position. + auto move_softmax_dim_index_to_pos = [&](const ParameterList& indices) { + std::vector new_indices; + for (auto ind : indices) { + new_indices.push_back(ind); + } + for (size_t i = softmax_dim; i < indices.size() - 1; ++i) { + new_indices[i + 1] = indices[i]; + } + new_indices[softmax_dim] = indices[indices.size() - 1]; + return new_indices; + }; + + // Remove the index corresponding to the softmax dimension. + auto remove_softmax_dim_index = [&](const ParameterList& indices) { + std::vector new_indices; + for (size_t i = 0; i < indices.size(); ++i) { + if (i != softmax_dim) { + new_indices.push_back(indices[i]); + } + } + return new_indices; + }; + + auto convert_indices_to_expr_handle = [&](const ParameterList& indices) { + std::vector new_indices(indices.size()); + for (size_t i = 0; i < indices.size(); ++i) { + new_indices[i] = indices[i]; + } + return new_indices; + }; + + c10::optional dtype = ToDtype(ScalarType::None); + auto maybe_dtype = v->node()->get(attr::dtype); + if (maybe_dtype && !maybe_dtype->isNone()) { + dtype = ToDtype(static_cast(maybe_dtype->toInt())); + } + + auto max = Reduce( + "aten_softmax_max", + non_softmax_dims, + Maximum(dtype.value()), + [&](ParameterList& indices) { + return tensorOrConstant( + v->node()->inputs()[0], move_softmax_dim_index_to_pos(indices)); + }, + {output_dims[softmax_dim]}); + auto e = + Compute("aten_softmax_exp", output_dims, [&](ParameterList& indices) { + auto inp = tensorOrConstant( + v->node()->inputs()[0], convert_indices_to_expr_handle(indices)); + return exp(inp - max->call(remove_softmax_dim_index(indices))); + }); + auto sum = Reduce( + "aten_softmax_sum", + non_softmax_dims, + Sum(), + [&](ParameterList& indices) { + return e->call(move_softmax_dim_index_to_pos(indices)); + }, + {output_dims[softmax_dim]}); + if (!log_softmax) { + return Compute("aten_softmax", output_dims, [&](ParameterList& indices) { + return e->call(indices) / sum->call(remove_softmax_dim_index(indices)); + }); + } + + auto log_sum = Compute( + "aten_softmax_log_sum", non_softmax_dims, [&](ParameterList& indices) { + return log(sum->call(indices)); + }); + return Compute("aten_log_softmax", output_dims, [&](ParameterList& indices) { + auto inp = tensorOrConstant( + v->node()->inputs()[0], convert_indices_to_expr_handle(indices)); + auto non_softmax_indices = remove_softmax_dim_index(indices); + return inp - max->call(non_softmax_indices) - + log_sum->call(non_softmax_indices); + }); +} + TensorExprKernel::ReductionInfo TensorExprKernel::getReductionInfo( const torch::jit::Node* node) { std::vector axes; @@ -1624,8 +1998,91 @@ std::vector TensorExprKernel::getReductionAxes( return axes; } +template +std::vector reverse_sort_indices(const std::vector& v) { + // initialize original index locations + std::vector idx(v.size()); + iota(idx.begin(), idx.end(), 0); + + std::sort(idx.begin(), idx.end(), [&v](size_t i1, size_t i2) { + return v[i1] > v[i2]; + }); + return idx; +} + +bool denseAndNonOverlapping( + at::ArrayRef sizes, + at::ArrayRef strides) { + return (strides == at::infer_dense_strides(sizes, strides)); +} + +Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) { + const TensorTypePtr& tt = v->type()->expect(); + TORCH_INTERNAL_ASSERT(tensors_.count(v->unique())); + Tensor* tensor = tensors_[v->unique()]; + + TORCH_INTERNAL_ASSERT(tt->sizes().concrete_sizes()); + const auto sizes = *tt->sizes().concrete_sizes(); + std::vector default_strides = TensorType::contiguousStridesOf(sizes); + TORCH_INTERNAL_ASSERT(tt->strides().concrete_sizes()); + const std::vector strides = *tt->strides().concrete_sizes(); + // All Tensors in NNC are layed out in default, contiguous layout. + // If the output is also default contiguous we don't need to do anything + if (strides == default_strides) { + return tensor; + } + // If the tensor is not dense or overlaps, we have + // no way of matching the profiled striding + if (!denseAndNonOverlapping(sizes, strides)) { + return tensor; + } + + auto dims = dimsFromSizes(sizesForValue(v)); + // We need to convert the output tensor so that its values are layed + // so that whene viewed from the output strides the values are correct. + // A contiguous Tensor of size(2, 3) with values 0-5 is layed out as: + // [0] [1] [2] [3] [4] [5] + // The same valued tensor with strides (2, 1) would be layed out like + // [0] [3] [1] [4] [2] [5] + // When we are doing the re-ordering of values into the output tensor, + // we are iterating per-element of the input, ad we are fixed + // in indexing in to the output tensor at [i, j] = val + // `val` we want here is equal to the indices for the output + // tensor that would have given the same position as the output + // The position is equal to the sum of stride[i] * index[i], + // and we can can calculate the equivalent indices in the + // output tensor strides by iteratively computing the index of + // the biggest stride: + // absolute = ... + // for stride in strides_from_largest_to_smallest: + // cur_idx = absolute // stride + // absolute = absolute % stride + + return Compute( + "output_1", dims, [&](const std::vector& axes_input) { + std::vector axes(axes_input.begin(), axes_input.end()); + auto absolute_position = IntImm::make(0); + for (size_t i = 0; i < axes.size(); ++i) { + absolute_position = + absolute_position + (IntImm::make(default_strides[i]) * axes[i]); + } + std::vector sorted_stride_indices = + reverse_sort_indices(strides); + std::vector new_axes(sorted_stride_indices.size()); + for (size_t stride_index : sorted_stride_indices) { + auto stride = strides[stride_index]; + auto index = Div::make(absolute_position, IntImm::make(stride)); + absolute_position = + Mod::make(absolute_position, IntImm::make(stride)); + new_axes[stride_index] = index; + } + return tensor->call(new_axes); + }); +} + void TensorExprKernel::compile() { KernelScope kernelScope(&kernelArena_); + GRAPH_DUMP("TensorExprKernel graph:", graph_); // Bind inputs to buffers. nInputs_ = graph_->inputs().size(); for (auto const& input : graph_->inputs()) { @@ -1650,57 +2107,82 @@ void TensorExprKernel::compile() { } } + device_ = *pickDeviceType(graph_->inputs()); + // Move output operands from `tensors_` to `tensorOutputs_` for (const auto& output : graph_->outputs()) { if (!tensors_.count(output->unique())) { throw malformed_input("cannot find output Tensor"); } + // The "strided" tensor will be incorrect if used in NNC, + // since NNC views it as contiguous. Only convert it to the right + // strides at the end of the kernel (if already contiguous it's a no-op) + Tensor* properly_strided_output = convertOutputToCorrectStrides(output); + tensors_[output->unique()] = properly_strided_output; + const auto& tt = output->type()->expect(); + auto sizes = *tt->sizes().concrete_sizes(); + tensorOutputSizes_.push_back(sizes); + auto strides = *tt->strides().concrete_sizes(); + + // If the tensor is not dense or overlaps, we have + // no way of matching the profiled striding + if (denseAndNonOverlapping(sizes, strides)) { + tensorOutputStrides_.push_back(*tt->strides().concrete_sizes()); + } else { + tensorOutputStrides_.push_back(TensorType::contiguousStridesOf(sizes)); + } + tensorOutputs_.emplace_back(tensors_.at(output->unique())); + tensorOutputTensorOptions_.emplace_back( + c10::TensorOptions(tensorType(tensors_[output->unique()])) + .device(device_)); tensors_.erase(output->unique()); } - device_ = *pickDeviceType(graph_->inputs()); BackendType backendType = inferBackendTypeFromDevice(device_); Stmt* stmt = generateStmt(backendType); // Set up formal params (inputs, then outputs) for kernel. std::vector params = prepareBufferArgs(); // Generate code. - codegen_ = CreateCodeGen(getCodeGenName(backendType), stmt, params, device_); + codegen_ = CreateCodeGen( + getCodeGenName(backendType), + stmt, + params, + device_, + SubgraphUtils::generateNameForGraph(graph_)); } TensorExprKernel::TensorExprKernel(const std::shared_ptr& subgraph) : graph_(subgraph), code_(subgraph, "") { - if (!fallbackAllowed()) { + allow_fallback_ = fallbackAllowed(); + if (!allow_fallback_) { compile(); return; } + use_fallback_ = fallbackEnforced(); + if (use_fallback_) { + return; + } + try { compile(); } catch (...) { - fallback_ = true; + use_fallback_ = true; } } void TensorExprKernel::run(Stack& stack) { - if (fallbackEnforced()) { - fallback(stack); - return; - } - if (!fallbackAllowed()) { + if (!use_fallback_ && !allow_fallback_) { runKernel(stack); - return; - } - - if (fallback_) { - fallback(stack); - return; - } - try { - runKernel(stack); - } catch (...) { - fallback_ = true; + } else if (!use_fallback_ && allow_fallback_) { + try { + runKernel(stack); + } catch (...) { + fallback(stack); + } + } else { fallback(stack); } } @@ -1708,47 +2190,29 @@ void TensorExprKernel::run(Stack& stack) { std::vector TensorExprKernel::prepareRunArgs( const at::ArrayRef& inputs, std::vector& outputs) { - std::map varToSize; - std::vector runArgs; - for (size_t i = 0; i < inputs.size(); i++) { + runArgs.reserve(inputs.size() + tensorOutputs_.size()); + + for (size_t i = 0, e = inputs.size(); i < e; i++) { auto const& input = inputs[i]; if (input.isInt()) { runArgs.emplace_back((int32_t)input.toInt()); } else if (input.isDouble()) { runArgs.emplace_back((float)input.toDouble()); } else if (input.isTensor()) { - auto const& tensor = input.toTensor(); - runArgs.emplace_back(tensor.data_ptr()); - for (auto const& size : kernelArgs_[i].sizes()) { - int32_t s = tensor.sizes()[size.idx]; - runArgs.emplace_back(s); - varToSize[size.var.node()] = s; - } - for (auto const& stride : kernelArgs_[i].strides()) { - int32_t s = tensor.strides()[stride.idx]; - runArgs.emplace_back(s); - } + runArgs.emplace_back(input.toTensor().data_ptr()); } } - for (auto& o : tensorOutputs_) { - std::vector tensorSize; - for (const Expr* dim : o->dims()) { - auto it = varToSize.find(dim); - if (it != varToSize.end()) { - tensorSize.push_back(it->second); - } else { - const IntImm* s = dynamic_cast(dim); - if (!s) { - throw malformed_input("output expected Int", dim); - } - tensorSize.push_back(s->value()); - } - } - - outputs.push_back(at::empty( - tensorSize, c10::TensorOptions(tensorType(o)).device(device_))); + for (size_t i = 0, e = tensorOutputs_.size(); i < e; ++i) { + auto const& opts = tensorOutputTensorOptions_[i]; + outputs.emplace_back(codegen_->empty_strided( + tensorOutputSizes_[i], + tensorOutputStrides_[i], + opts.dtype, + opts.layout, + opts.device, + opts.pinned_memory)); runArgs.emplace_back(outputs.back().data_ptr()); } return runArgs; diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index bcc5682f68a57..4e817f98d9b04 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -36,6 +36,16 @@ class TORCH_API TensorExprKernel { } private: + enum ElementType { + kAllTypes = 0, + kIntegralTypes = 1 << 0, + kFloatingPointTypes = 1 << 1, + kBoolType = 1 << 2, + kComplexTypes = 1 << 3, + kQintTypes = 1 << 4, + kNonComplexOrQintTypes = kIntegralTypes | kBoolType | kFloatingPointTypes, + }; + enum BackendType { kUninitialized, kSimpleIREval, @@ -65,13 +75,17 @@ class TORCH_API TensorExprKernel { ExprHandle chunk( Tensor* t, size_t chunkIdx, - size_t dim, - size_t chunks, + int64_t dim, + int64_t chunks, const std::vector& axes); std::vector valueShape(const torch::jit::Value* v); - void promoteInputs(std::vector& inputs); + bool checkTypes(const ScalarType highType, const int typeConstraints); + + void promoteInputs( + std::vector& inputs, + int typeConstraints = kAllTypes); ExprHandle demoteOutput(const ExprHandle& e, const torch::jit::Value* v); @@ -82,7 +96,8 @@ class TORCH_API TensorExprKernel { Tensor* computeOneOperand( const std::string& name, const torch::jit::Value* v, - const std::function& innerExpr); + const std::function& innerExpr, + const int checkParamTypes = kAllTypes); Tensor* computeTwoOperand( const std::string& name, @@ -101,7 +116,8 @@ class TORCH_API TensorExprKernel { const torch::jit::Value* v, const std::function< ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>& - innerExpr); + innerExpr, + bool promote_inputs = true); Tensor* computeConditionWithTwoOperand( const std::string& name, @@ -121,9 +137,10 @@ class TORCH_API TensorExprKernel { Tensor* computeSum(const torch::jit::Value* v); + Tensor* computeSoftmax(const torch::jit::Value* v, bool log_softmax); + Tensor* computeValue(const torch::jit::Value* v); - void flattenTensors(BackendType backendType); Stmt* generateStmt(BackendType backendType); std::vector prepareBufferArgs(); @@ -136,6 +153,8 @@ class TORCH_API TensorExprKernel { void bindInput(const torch::jit::Value* input); + Tensor* convertOutputToCorrectStrides(torch::jit::Value* v); + // Captures the information for reduction operation nodes. struct ReductionInfo { std::vector reductionDims; @@ -186,10 +205,25 @@ class TORCH_API TensorExprKernel { std::vector strideArgs_; }; + struct UnpackedTensorOptions { + c10::optional dtype; + c10::optional layout; + c10::optional device; + c10::optional pinned_memory; + + UnpackedTensorOptions(const c10::TensorOptions& opts) + : dtype(optTypeMetaToScalarType(opts.dtype_opt())), + layout(opts.layout_opt()), + device(opts.device_opt()), + pinned_memory(opts.pinned_memory_opt()) {} + }; + int64_t nInputs_ = 0; std::vector kernelArgs_; + std::vector> tensorOutputSizes_; + std::vector> tensorOutputStrides_; + std::vector tensorOutputTensorOptions_; std::vector tensorOutputs_; - std::vector flatTensorOutputs_; std::unordered_map tensors_; std::unordered_map scalars_; std::unique_ptr codegen_; @@ -198,7 +232,8 @@ class TORCH_API TensorExprKernel { std::vector inputTypes_; std::shared_ptr graph_; Code code_; - bool fallback_{false}; + bool allow_fallback_{false}; + bool use_fallback_{false}; bool hasRandom_{false}; bool hasBroadcast_{false}; std::unordered_map> @@ -209,6 +244,7 @@ TORCH_API int& getTECudaPointwiseLoopLevels(); TORCH_API int& getTECudaPointwiseBlockCount(); TORCH_API int& getTECudaPointwiseBlockSize(); TORCH_API bool& getTEGenerateBlockCode(); +TORCH_API bool& getTEMustUseLLVMOnCPU(); TORCH_API bool fallbackAllowed(); TORCH_API bool setFallbackAllowed(bool value); diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index e4331bc2d824b..6257afc084208 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -1,6 +1,8 @@ #ifdef TORCH_ENABLE_LLVM #include + +#include #include #include @@ -11,20 +13,39 @@ #include #include #include +#include #include + +#if LLVM_VERSION_MAJOR >= 10 +#include +#else #include +#endif + +#include #include +#include + +#if LLVM_VERSION_MAJOR >= 11 +#include +#endif -#include #include +#include #include #include +#include #include -#define DEBUG_PRINT 0 +#include using namespace torch::jit::tensorexpr; +C10_DEFINE_bool( + torch_jit_llvm_use_fast_intrinsics, + false, + "Use fast (but slightly less accurate) implementations of tanh and sigmoid"); + DEFINE_TRIGGER(llvm_codegen_created); DEFINE_TRIGGER(llvm_codegen_executed); @@ -33,18 +54,6 @@ namespace jit { namespace tensorexpr { namespace { -bool is_unsigned_integral(const ScalarType& type) { - switch (type) { - case ScalarType::Bool: - case ScalarType::Byte: - return true; - default: - return false; - } - - return false; -} - llvm::CmpInst::Predicate llvm_comparison_predicate( CompareSelectOperation compare_op, const ScalarType& type) { @@ -54,30 +63,45 @@ llvm::CmpInst::Predicate llvm_comparison_predicate( case CompareSelectOperation::kNE: return llvm::ICmpInst::ICMP_NE; case CompareSelectOperation::kGT: - return is_unsigned_integral(type) ? llvm::ICmpInst::ICMP_UGT - : llvm::ICmpInst::ICMP_SGT; + return is_signed(type) ? llvm::ICmpInst::ICMP_SGT + : llvm::ICmpInst::ICMP_UGT; case CompareSelectOperation::kGE: - return is_unsigned_integral(type) ? llvm::ICmpInst::ICMP_UGE - : llvm::ICmpInst::ICMP_SGE; + return is_signed(type) ? llvm::ICmpInst::ICMP_SGE + : llvm::ICmpInst::ICMP_UGE; case CompareSelectOperation::kLT: - return is_unsigned_integral(type) ? llvm::ICmpInst::ICMP_ULT - : llvm::ICmpInst::ICMP_SLT; + return is_signed(type) ? llvm::ICmpInst::ICMP_SLT + : llvm::ICmpInst::ICMP_ULT; case CompareSelectOperation::kLE: - return is_unsigned_integral(type) ? llvm::ICmpInst::ICMP_ULE - : llvm::ICmpInst::ICMP_SLE; + return is_signed(type) ? llvm::ICmpInst::ICMP_SLE + : llvm::ICmpInst::ICMP_ULE; default: // TODO: change to a proper error report throw std::runtime_error("invalid operator type"); } } +#if LLVM_VERSION_MAJOR <= 9 +int ElementCount(int lanes) { + return lanes; +} +#else +llvm::ElementCount ElementCount(int lanes) { +#if LLVM_VERSION_MAJOR <= 11 + return llvm::ElementCount(static_cast(lanes), false); +#elif LLVM_VERSION_MAJOR == 12 + return llvm::ElementCount::getFixed(lanes); +#else +#error Only LLVM versions 8 through 12 are supported. +#endif +} +#endif + } // namespace class LLVMCodeGenImpl : public IRVisitor { private: - llvm::orc::ThreadSafeContext context_; + std::unique_ptr context_; llvm::IRBuilder<> irb_; - std::unique_ptr TM_; std::unique_ptr jit_; std::unique_ptr module_; llvm::Function* fn_; @@ -101,6 +125,7 @@ class LLVMCodeGenImpl : public IRVisitor { llvm::Type* dtypeToLLVMPtr(Dtype dtype); void emitWrapper(const std::vector& params); void emitKernel(Stmt* stmt, const std::vector& params); + llvm::Value* toVec(llvm::Value* v, int lanes); public: LLVMCodeGenImpl( @@ -132,6 +157,7 @@ class LLVMCodeGenImpl : public IRVisitor { #undef IMM_VISIT_DECLARE void visit(const Cast* v) override; + void visit(const BitCast* v) override; void visit(const Var* v) override; void visit(const Ramp* v) override; void visit(const Load* v) override; @@ -148,6 +174,8 @@ class LLVMCodeGenImpl : public IRVisitor { void visit(const Let* v) override; void visit(const Cond* v) override; + void emitIsNan(const Intrinsics* v); + llvm::Value* emitUnmaskedLoad(llvm::Value* addr, llvm::Value* idx); llvm::Value* emitMaskedLoad( llvm::Value* addr, @@ -166,33 +194,6 @@ class LLVMCodeGenImpl : public IRVisitor { } // namespace jit } // namespace torch -static llvm::orc::JITTargetMachineBuilder makeTargetMachineBuilder() { -#if 0 - // FIXME: Switch to using detectHost() rather than setting up the JTMB manually - // once LLVM 10 is available. - return llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost()); -#else - llvm::orc::JITTargetMachineBuilder JTMB( - (llvm::Triple(llvm::sys::getProcessTriple()))); - - // Retrieve host CPU name and sub-target features and add them to builder. - // Relocation model, code model and codegen opt level are kept to default - // values. - llvm::SubtargetFeatures SubtargetFeatures; - llvm::StringMap FeatureMap; - llvm::sys::getHostCPUFeatures(FeatureMap); - for (auto& Feature : FeatureMap) { - SubtargetFeatures.AddFeature(Feature.first(), Feature.second); - } - - JTMB.setCodeGenOptLevel(llvm::CodeGenOpt::Default); - JTMB.setCPU(llvm::sys::getHostCPUName()); - JTMB.addFeatures(SubtargetFeatures.getFeatures()); - - return JTMB; -#endif -} - LLVMCodeGen::~LLVMCodeGen() = default; LLVMCodeGen::LLVMCodeGen(Stmt* stmt) @@ -202,8 +203,9 @@ LLVMCodeGen::LLVMCodeGen( Stmt* stmt, const std::vector& args, at::Device device, + const std::string& kernel_func_name, Dtype dtype) - : CodeGen(stmt, args, device), + : CodeGen(stmt, args, device, kernel_func_name), impl_(std::make_unique(stmt, args, device, dtype)) {} static void* argToPtr( @@ -244,6 +246,17 @@ void LLVMCodeGen::call(const std::vector& args) { USE_TRIGGER(llvm_codegen_executed); } +at::Tensor LLVMCodeGen::empty_strided( + c10::IntArrayRef size, + c10::IntArrayRef stride, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt) { + return at::native::empty_strided_cpu( + size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt); +} + void* LLVMCodeGen::getKernelAddress(LLVMCodeGenImpl* impl) { return (void*)impl->getKernelAddress(); } @@ -276,13 +289,15 @@ LLVMCodeGenImpl::LLVMCodeGenImpl( llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); - auto JTMB = makeTargetMachineBuilder(); - TM_ = llvm::cantFail(JTMB.createTargetMachine()); - jit_ = std::make_unique(); module_ = std::make_unique("pytorch", getContext()); - module_->setDataLayout(cantFail(JTMB.getDefaultDataLayoutForTarget())); - module_->setTargetTriple(JTMB.getTargetTriple().str()); + module_->setDataLayout(jit_->getDataLayout()); + module_->setTargetTriple(jit_->getTargetMachine().getTargetTriple().str()); + + // We support float16 ops by casting expr inputs to float32 + // and then casting the result back to float16 + HalfRewriter hsFix; + stmt = stmt->accept_mutator(&hsFix); // Emit prototype and bind argument Vars to parameter indices. llvm::Type* retTy = dtypeToLLVM(dtype); @@ -299,6 +314,9 @@ LLVMCodeGenImpl::LLVMCodeGenImpl( llvm::FunctionType* fntype = llvm::FunctionType::get(retTy, params, false); fn_ = llvm::Function::Create( fntype, llvm::Function::PrivateLinkage, "pytorch", module_.get()); + fn_->addAttribute( + llvm::AttributeList::AttrIndex::FunctionIndex, + llvm::Attribute::AlwaysInline); for (size_t i = 0; i < args.size(); i++) { if (!args[i].isVar()) { fn_->addParamAttr(i, llvm::Attribute::NoAlias); @@ -308,17 +326,16 @@ LLVMCodeGenImpl::LLVMCodeGenImpl( emitWrapper(params); emitKernel(stmt, params); - cantFail(jit_->addModule( - llvm::orc::ThreadSafeModule(std::move(module_), context_))); + assertSuccess(jit_->addModule(std::move(module_), std::move(context_))); auto sym = jit_->findSymbol("wrapper"); - kernelAddress_ = cantFail(sym.getAddress()); + kernelAddress_ = assertSuccess(sym.getAddress()); argv_ = std::make_unique(params.size()); USE_TRIGGER(llvm_codegen_created); } llvm::LLVMContext& LLVMCodeGenImpl::getContext() { - return *context_.getContext(); + return *context_; } llvm::Type* LLVMCodeGenImpl::dtypeToLLVM(Dtype dtype) { @@ -373,12 +390,12 @@ class LLVMIntrinsicsExpander : public GenericIntrinsicsExpander { if (v->op_type() == kTanh) { ScalarType stype = v->dtype().scalar_type(); if (stype == ScalarType::Float) { - return fast_tanh(v->param(0)); + return fast_tanh(v->param(0)->accept_mutator(this)); } } else if (v->op_type() == kSigmoid) { ScalarType stype = v->dtype().scalar_type(); if (stype == ScalarType::Float) { - return fast_sigmoid(v->param(0)); + return fast_sigmoid(v->param(0)->accept_mutator(this)); } } // TODO: fast exp @@ -396,25 +413,25 @@ class LLVMIntrinsicsExpander : public GenericIntrinsicsExpander { int lanes = dtype.lanes(); // TODO: use a dedicated bind-var to make sure v is not evalualted multiple // times. Clamp the input expression to [-9, 9] - ExprHandle plus_9 = to_vec(9.0f, lanes); - ExprHandle minus_9 = to_vec(-9.0f, lanes); + ExprHandle plus_9 = float_to_vec(9.0f, lanes); + ExprHandle minus_9 = float_to_vec(-9.0f, lanes); ExprHandle v1 = Min::make(v, plus_9, false); v1 = Max::make(v1, minus_9, false); // The coefficients for the numerator - ExprHandle alpha_1 = to_vec(4.89352455891786e-03f, lanes); - ExprHandle alpha_3 = to_vec(6.37261928875436e-04f, lanes); - ExprHandle alpha_5 = to_vec(1.48572235717979e-05f, lanes); - ExprHandle alpha_7 = to_vec(5.12229709037114e-08f, lanes); - ExprHandle alpha_9 = to_vec(-8.60467152213735e-11f, lanes); - ExprHandle alpha_11 = to_vec(2.00018790482477e-13f, lanes); - ExprHandle alpha_13 = to_vec(-2.76076847742355e-16f, lanes); + ExprHandle alpha_1 = float_to_vec(4.89352455891786e-03f, lanes); + ExprHandle alpha_3 = float_to_vec(6.37261928875436e-04f, lanes); + ExprHandle alpha_5 = float_to_vec(1.48572235717979e-05f, lanes); + ExprHandle alpha_7 = float_to_vec(5.12229709037114e-08f, lanes); + ExprHandle alpha_9 = float_to_vec(-8.60467152213735e-11f, lanes); + ExprHandle alpha_11 = float_to_vec(2.00018790482477e-13f, lanes); + ExprHandle alpha_13 = float_to_vec(-2.76076847742355e-16f, lanes); // The coeffecients for the denominator - ExprHandle beta_0 = to_vec(4.89352518554385e-03f, lanes); - ExprHandle beta_2 = to_vec(2.26843463243900e-03f, lanes); - ExprHandle beta_4 = to_vec(1.18534705686654e-04f, lanes); - ExprHandle beta_6 = to_vec(1.19825839466702e-06f, lanes); + ExprHandle beta_0 = float_to_vec(4.89352518554385e-03f, lanes); + ExprHandle beta_2 = float_to_vec(2.26843463243900e-03f, lanes); + ExprHandle beta_4 = float_to_vec(1.18534705686654e-04f, lanes); + ExprHandle beta_6 = float_to_vec(1.19825839466702e-06f, lanes); // numerator ExprHandle v2 = v1 * v1; @@ -439,20 +456,16 @@ class LLVMIntrinsicsExpander : public GenericIntrinsicsExpander { // sigmoid(x) = (tanh(x / 2) + 1) / 2 ExprHandle x{v_ptr}; int lanes = x.dtype().lanes(); - ExprHandle one_v = to_vec(1.f, lanes); - ExprHandle half_v = to_vec(0.5f, lanes); + ExprHandle one_v = float_to_vec(1.f, lanes); + ExprHandle half_v = float_to_vec(0.5f, lanes); ExprHandle x2 = x * half_v; ExprHandle y{fast_tanh(x2.node())}; ExprHandle z = (y + one_v) * half_v; return z.node(); } - ExprHandle to_vec(float v, int lanes) { - if (lanes == 1) { - return v; - } else { - return Broadcast::make(v, lanes); - } + ExprHandle float_to_vec(float v, int lanes) { + return expr_to_vec(FloatImm::make(v), lanes); } }; @@ -464,8 +477,13 @@ void LLVMCodeGenImpl::emitKernel( irb_.SetInsertPoint(bb_); // Maybe expand some of the intrinsics. - LLVMIntrinsicsExpander intrinsics_expander; - stmt = stmt->accept_mutator(&intrinsics_expander); + if (FLAGS_torch_jit_llvm_use_fast_intrinsics) { + LLVMIntrinsicsExpander intrinsics_expander; + stmt = stmt->accept_mutator(&intrinsics_expander); + } else { + GenericIntrinsicsExpander intrinsics_expander; + stmt = stmt->accept_mutator(&intrinsics_expander); + } // Compile the kernel. stmt->accept(this); @@ -477,27 +495,39 @@ void LLVMCodeGenImpl::emitKernel( irb_.CreateRet(value_); -#if DEBUG_PRINT - llvm::errs() << *module_; -#endif if (llvm::verifyFunction(*fn_, &llvm::outs())) { throw std::runtime_error("Function verification failed"); } - optimize(*module_); -#if DEBUG_PRINT - llvm::errs() << *module_; + // print graph debug info before optimization llvm::SmallVector asmBuffer; llvm::raw_svector_ostream asmStream(asmBuffer); - llvm::legacy::PassManager PM; - TM_->addPassesToEmitFile( - PM, - asmStream, - nullptr, - llvm::TargetMachine::CodeGenFileType::CGFT_AssemblyFile); - PM.run(*module_); - llvm::errs() << asmStream.str(); + if (GRAPH_DEBUG_ENABLED) { + module_->print(asmStream, nullptr); + } + GRAPH_DEBUG( + "\nLLVM module before optimizations\n\n", asmStream.str().str(), "\n"); + + optimize(*module_); + + // print graph debug info after optimization + asmBuffer.set_size(0); + if (GRAPH_DEBUG_ENABLED) { + module_->print(asmStream, nullptr); + llvm::legacy::PassManager PM; + jit_->getTargetMachine().addPassesToEmitFile( + PM, + asmStream, + nullptr, +#if LLVM_VERSION_MAJOR >= 10 + llvm::CodeGenFileType::CGFT_AssemblyFile); +#else + llvm::TargetMachine::CodeGenFileType::CGFT_AssemblyFile); #endif + PM.run(*module_); + } + GRAPH_DEBUG( + "\nLLVM module after optimizations\n\n", asmStream.str().str(), "\n"); } // TODO: The binary ops are copypasta. @@ -671,7 +701,8 @@ void LLVMCodeGenImpl::visit(const Max* v) { auto rhs = this->value_; if (v->dtype().is_integral()) { - auto icmp = irb_.CreateICmpSGT(lhs, rhs); + auto icmp = v->dtype().is_signed() ? irb_.CreateICmpSGT(lhs, rhs) + : irb_.CreateICmpUGT(lhs, rhs); value_ = irb_.CreateSelect(icmp, lhs, rhs); return; } @@ -691,9 +722,9 @@ void LLVMCodeGenImpl::visit(const Min* v) { auto lhs = this->value_; v->rhs()->accept(this); auto rhs = this->value_; - - if (v->dtype() == kInt) { - auto icmp = irb_.CreateICmpSLT(lhs, rhs); + if (v->dtype().is_integral()) { + auto icmp = v->dtype().is_signed() ? irb_.CreateICmpSLT(lhs, rhs) + : irb_.CreateICmpULT(lhs, rhs); value_ = irb_.CreateSelect(icmp, lhs, rhs); return; } @@ -785,13 +816,19 @@ void LLVMCodeGenImpl::visit(const BoolImm* v) { value_ = llvm::ConstantInt::get(BoolTy_, v->value()); } +llvm::Type* llvmTypeToVec(llvm::Type* type, int lanes) { + if (lanes > 1) { + return llvm::VectorType::get(type, ElementCount(lanes)); + } else { + return type; + } +} + void LLVMCodeGenImpl::visit(const Cast* v) { v->src_value()->accept(this); - llvm::Type* dstType = dtypeToLLVM(v->dtype()); - if (v->dtype().lanes() > 1) { - dstType = llvm::VectorType::get(dstType, v->dtype().lanes()); - } + llvm::Type* dstType = + llvmTypeToVec(dtypeToLLVM(v->dtype()), v->dtype().lanes()); llvm::Type* srcType = dtypeToLLVM(v->src_value()->dtype()); if (srcType == dstType) { @@ -799,13 +836,33 @@ void LLVMCodeGenImpl::visit(const Cast* v) { return; } - bool destUnsigned = v->dtype().scalar_type() == ScalarType::Byte; + bool destUnsigned = v->dtype().scalar_type() == ScalarType::Byte || + v->dtype().scalar_type() == ScalarType::Bool; // Scalar casts if (srcType->isFPOrFPVectorTy()) { if (dstType->isFPOrFPVectorTy()) { + // as with eager, convert from Double -> Half by Converting to Float then + // Half. TODO: __truncdfhf2 + if (v->dtype().scalar_type() == ScalarType::Half && + v->src_value()->dtype().scalar_type() == ScalarType::Double) { + value_ = irb_.CreateFPCast( + value_, llvmTypeToVec(FloatTy_, v->dtype().lanes())); + } value_ = irb_.CreateFPCast(value_, dstType); } else if (dstType->isIntOrIntVectorTy()) { + // Strictly casting from Float -> i8 doesnt give correct results + // set one bit true if the input float is not 0 + if (v->dtype().scalar_type() == ScalarType::Bool) { + llvm::Value* zero = + toVec(llvm::ConstantFP::get(srcType, 0.), v->dtype().lanes()); + value_ = irb_.CreateFCmp(llvm::FCmpInst::FCMP_UNO, value_, zero); + value_ = irb_.CreateICmpEQ( + value_, llvm::ConstantInt::get(value_->getType(), 0)); + value_ = irb_.CreateIntCast(value_, dstType, !destUnsigned); + return; + } + if (destUnsigned) { value_ = irb_.CreateFPToUI(value_, dstType); } else { @@ -814,19 +871,48 @@ void LLVMCodeGenImpl::visit(const Cast* v) { } else { throw unimplemented_lowering(v); } - } else if (srcType->isIntOrIntVectorTy()) { - if (dstType->isFPOrFPVectorTy()) { - if (destUnsigned) { - value_ = irb_.CreateUIToFP(value_, dstType); - } else { - value_ = irb_.CreateSIToFP(value_, dstType); - } - } else if (dstType->isIntOrIntVectorTy()) { - value_ = irb_.CreateIntCast(value_, dstType, !destUnsigned); + return; + } + if (!srcType->isIntOrIntVectorTy()) { + throw unimplemented_lowering(v); + } + if (dstType->isFPOrFPVectorTy()) { + if (destUnsigned) { + value_ = irb_.CreateUIToFP(value_, dstType); } else { - throw unimplemented_lowering(v); + value_ = irb_.CreateSIToFP(value_, dstType); + } + } else if (dstType->isIntOrIntVectorTy()) { + // Ensure bool true value is exactly one, since we convert to int + // from bool by zero extending the int8 + if (v->dtype().scalar_type() == ScalarType::Bool) { + llvm::Value* zero = + toVec(llvm::ConstantInt::get(srcType, 0), v->dtype().lanes()); + value_ = irb_.CreateICmpNE(value_, zero); } + value_ = irb_.CreateIntCast(value_, dstType, !destUnsigned); + } else { + throw unimplemented_lowering(v); + } +} + +void LLVMCodeGenImpl::visit(const BitCast* v) { + v->src_value()->accept(this); + + llvm::Type* dstType = dtypeToLLVM(v->dtype()); + if (v->dtype().lanes() > 1) { + dstType = llvm::VectorType::get(dstType, ElementCount(v->dtype().lanes())); + } + llvm::Type* srcType = dtypeToLLVM(v->src_value()->dtype()); + + if (srcType == dstType) { + // do nothing. + return; } + + TORCH_CHECK(llvm::CastInst::isBitCastable( + srcType->getScalarType(), dstType->getScalarType())); + value_ = irb_.CreateBitOrPointerCast(value_, dstType); } void LLVMCodeGenImpl::visit(const Var* v) { @@ -861,10 +947,11 @@ void LLVMCodeGenImpl::visit(const Ramp* v) { } llvm::Type* vecType = nullptr; + auto element_count = ElementCount(lanes); switch (v->dtype().scalar_type()) { -#define TYPE_CASE(_1, Name) \ - case ScalarType::Name: \ - vecType = llvm::VectorType::get(Name##Ty_, lanes); \ +#define TYPE_CASE(_1, Name) \ + case ScalarType::Name: \ + vecType = llvm::VectorType::get(Name##Ty_, element_count); \ break; AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); #undef TYPE_CASE @@ -934,10 +1021,11 @@ void LLVMCodeGenImpl::visit(const Load* v) { llvm::Type* loadType = nullptr; + auto element_count = ElementCount(v->dtype().lanes()); switch (v->dtype().scalar_type()) { -#define TYPE_CASE(_1, Name) \ - case ScalarType::Name: \ - loadType = llvm::VectorType::get(Name##Ty_, v->dtype().lanes()); \ +#define TYPE_CASE(_1, Name) \ + case ScalarType::Name: \ + loadType = llvm::VectorType::get(Name##Ty_, element_count); \ break; AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); #undef TYPE_CASE @@ -968,7 +1056,7 @@ void LLVMCodeGenImpl::visit(const Load* v) { auto addr = irb_.CreateGEP(base, first_idx); auto vaddr = irb_.CreateBitOrPointerCast( addr, llvm::PointerType::get(loadType, 0)); - value_ = irb_.CreateAlignedLoad(loadType, vaddr, 4); + value_ = irb_.CreateAlignedLoad(vaddr, 4); return; } } @@ -1216,12 +1304,71 @@ void LLVMCodeGenImpl::visit(const BaseCallNode* v) { static void applyMathFunctionAttributes(llvm::Function* f) { f->addFnAttr(llvm::Attribute::ReadNone); - f->addFnAttr(llvm::Attribute::NoFree); f->addFnAttr(llvm::Attribute::NoUnwind); // TODO: Adding this attr should be correct, but as of LLVM 9.0.1 adding it // causes some math functions to incorrectly be turned into tail calls. // f->addFnAttr(llvm::Attribute::Speculatable); +#if LLVM_VERSION_MAJOR >= 9 + f->addFnAttr(llvm::Attribute::NoFree); f->addFnAttr(llvm::Attribute::WillReturn); +#endif +} + +namespace { +#if LLVM_VERSION_MAJOR >= 9 + +using FunctionCallee = llvm::FunctionCallee; + +#elif LLVM_VERSION_MAJOR == 8 && LLVM_VERSION_PATCH == 20181009 + +struct FunctionCallee { + FunctionCallee() {} + + FunctionCallee(llvm::Constant* fn) + : v_(fn), ft_(cast(v_)->getFunctionType()) {} + + llvm::FunctionType* getFunctionType() { + return ft_; + } + + llvm::Value* getCallee() { + return v_; + } + + private: + llvm::Value* v_{nullptr}; + llvm::FunctionType* ft_{nullptr}; +}; + +#else +#error Only LLVM versions 8 through 12 are supported. +#endif + +} // namespace + +llvm::Value* LLVMCodeGenImpl::toVec(llvm::Value* v, int lanes) { + if (lanes > 1) { + return irb_.CreateVectorSplat(lanes, v); + } else { + return v; + } +} + +void LLVMCodeGenImpl::emitIsNan(const Intrinsics* v) { + v->param(0)->accept(this); + llvm::Type* dstType = dtypeToLLVM(v->dtype()); + if (!v->param(0)->dtype().is_floating_point()) { + value_ = toVec(llvm::ConstantInt::get(dstType, 0), v->dtype().lanes()); + } else { + TORCH_INTERNAL_ASSERT(v->dtype().scalar_type() == ScalarType::Int); + auto is_nan = irb_.CreateFCmpUNO( + value_, llvm::ConstantFP::get(value_->getType(), 0.)); + if (v->dtype().lanes() > 1) { + dstType = + llvm::VectorType::get(dstType, ElementCount(v->dtype().lanes())); + } + value_ = irb_.CreateIntCast(is_nan, dstType, /*isSigned*/ false); + } } void LLVMCodeGenImpl::visit(const Intrinsics* v) { @@ -1229,62 +1376,66 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) { llvm::Value* call_fn = nullptr; bool call_simd_sleef = false; + if (v->op_type() == kIsNan) { + return emitIsNan(v); + } + if (v->dtype().scalar_type() == ScalarType::Float) { switch (v->op_type()) { case kRsqrt: { v->params().front()->accept(this); value_ = irb_.CreateUnaryIntrinsic(llvm::Intrinsic::sqrt, value_); - llvm::Value* constant = llvm::ConstantFP::get(FloatTy_, 1.0); - if (v->dtype().lanes() > 1) { - constant = irb_.CreateVectorSplat(v->dtype().lanes(), constant); - } + llvm::Value* constant = + toVec(llvm::ConstantFP::get(FloatTy_, 1.0), v->dtype().lanes()); value_ = irb_.CreateFDiv(constant, value_); return; } break; #if defined(__AVX__) && !defined(_MSC_VER) -#define SIMD_UNARY_MATH_CASE(enum, name, type) \ - case enum: { \ - llvm::FunctionCallee callee; \ - std::string fname; \ - if (v->dtype().lanes() == 8) { \ - fname = "Sleef_" + std::string(name) + "8"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ - call_simd_sleef = true; \ - } else if (v->dtype().lanes() == 4) { \ - fname = "Sleef_" + std::string(name) + "4"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ - call_simd_sleef = true; \ - } else { \ - callee = module_->getOrInsertFunction( \ - name, llvm::FunctionType::get(type, {type}, false), {}); \ - } \ - call_ty = callee.getFunctionType(); \ - call_fn = callee.getCallee(); \ - applyMathFunctionAttributes(llvm::cast(call_fn)); \ +#define SIMD_UNARY_MATH_CASE(enum, name, type) \ + case enum: { \ + FunctionCallee callee; \ + std::string fname; \ + auto element_count = ElementCount(v->dtype().lanes()); \ + if (v->dtype().lanes() == 8) { \ + fname = "Sleef_" + std::string(name) + "8"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ + call_simd_sleef = true; \ + } else if (v->dtype().lanes() == 4) { \ + fname = "Sleef_" + std::string(name) + "4"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ + call_simd_sleef = true; \ + } else { \ + callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type}, false), {}); \ + } \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ } break; #else -#define SIMD_UNARY_MATH_CASE(enum, name, type) \ - case enum: { \ - llvm::FunctionCallee callee; \ - std::string fname; \ - if (v->dtype().lanes() == 4) { \ - fname = "Sleef_" + std::string(name) + "4"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ - call_simd_sleef = true; \ - } else { \ - callee = module_->getOrInsertFunction( \ - name, llvm::FunctionType::get(type, {type}, false), {}); \ - } \ - call_ty = callee.getFunctionType(); \ - call_fn = callee.getCallee(); \ - applyMathFunctionAttributes(llvm::cast(call_fn)); \ +#define SIMD_UNARY_MATH_CASE(enum, name, type) \ + case enum: { \ + FunctionCallee callee; \ + std::string fname; \ + auto element_count = ElementCount(v->dtype().lanes()); \ + if (v->dtype().lanes() == 4) { \ + fname = "Sleef_" + std::string(name) + "4"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ + call_simd_sleef = true; \ + } else { \ + callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type}, false), {}); \ + } \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ } break; #endif SIMD_UNARY_MATH_CASE(kLog10, "log10f", FloatTy_) @@ -1295,7 +1446,7 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) { SIMD_UNARY_MATH_CASE(kCos, "cosf", FloatTy_) SIMD_UNARY_MATH_CASE(kSin, "sinf", FloatTy_) SIMD_UNARY_MATH_CASE(kSqrt, "sqrtf", FloatTy_) - SIMD_UNARY_MATH_CASE(kFabs, "fabsf", FloatTy_) + SIMD_UNARY_MATH_CASE(kAbs, "fabsf", FloatTy_) SIMD_UNARY_MATH_CASE(kFloor, "floorf", FloatTy_) SIMD_UNARY_MATH_CASE(kCeil, "ceilf", FloatTy_) SIMD_UNARY_MATH_CASE(kTrunc, "truncf", FloatTy_) @@ -1314,54 +1465,56 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) { #undef SIMD_UNARY_MATH_CASE #if defined(__AVX__) && !defined(_MSC_VER) -#define SIMD_BINARY_MATH_CASE(enum, name, type) \ - case enum: { \ - llvm::FunctionCallee callee; \ - std::string fname; \ - if (v->dtype().lanes() == 8) { \ - fname = "Sleef_" + std::string(name) + "8"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, \ - llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ - {}); \ - call_simd_sleef = true; \ - } else if (v->dtype().lanes() == 4) { \ - fname = "Sleef_" + std::string(name) + "4"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, \ - llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ - {}); \ - call_simd_sleef = true; \ - } else { \ - callee = module_->getOrInsertFunction( \ - name, llvm::FunctionType::get(type, {type, type}, false), {}); \ - } \ - call_ty = callee.getFunctionType(); \ - call_fn = callee.getCallee(); \ - applyMathFunctionAttributes(llvm::cast(call_fn)); \ +#define SIMD_BINARY_MATH_CASE(enum, name, type) \ + case enum: { \ + FunctionCallee callee; \ + std::string fname; \ + auto element_count = ElementCount(v->dtype().lanes()); \ + if (v->dtype().lanes() == 8) { \ + fname = "Sleef_" + std::string(name) + "8"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, \ + llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ + {}); \ + call_simd_sleef = true; \ + } else if (v->dtype().lanes() == 4) { \ + fname = "Sleef_" + std::string(name) + "4"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, \ + llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ + {}); \ + call_simd_sleef = true; \ + } else { \ + callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type, type}, false), {}); \ + } \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ } break; #else -#define SIMD_BINARY_MATH_CASE(enum, name, type) \ - case enum: { \ - llvm::FunctionCallee callee; \ - std::string fname; \ - if (v->dtype().lanes() == 4) { \ - fname = "Sleef_" + std::string(name) + "4"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, \ - llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ - {}); \ - call_simd_sleef = true; \ - } else { \ - callee = module_->getOrInsertFunction( \ - name, llvm::FunctionType::get(type, {type, type}, false), {}); \ - } \ - call_ty = callee.getFunctionType(); \ - call_fn = callee.getCallee(); \ - applyMathFunctionAttributes(llvm::cast(call_fn)); \ +#define SIMD_BINARY_MATH_CASE(enum, name, type) \ + case enum: { \ + FunctionCallee callee; \ + std::string fname; \ + auto element_count = ElementCount(v->dtype().lanes()); \ + if (v->dtype().lanes() == 4) { \ + fname = "Sleef_" + std::string(name) + "4"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, \ + llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ + {}); \ + call_simd_sleef = true; \ + } else { \ + callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type, type}, false), {}); \ + } \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ } break; #endif SIMD_BINARY_MATH_CASE(kAtan2, "atan2f", FloatTy_) @@ -1369,16 +1522,15 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) { SIMD_BINARY_MATH_CASE(kFmod, "fmodf", FloatTy_) #undef SIMD_BINARY_MATH_CASE -#define BINARY_MATH_CASE(enum, name, type) \ - case enum: { \ - auto callee = module_->getOrInsertFunction( \ - name, llvm::FunctionType::get(type, {type, type}, false), {}); \ - call_ty = callee.getFunctionType(); \ - call_fn = callee.getCallee(); \ - applyMathFunctionAttributes(llvm::cast(call_fn)); \ - } break; - BINARY_MATH_CASE(kRemainder, "remainderf", FloatTy_) -#undef BINARY_MATH_CASE + case kRemainder: { + FunctionCallee callee = module_->getOrInsertFunction( + "remainderf", + llvm::FunctionType::get(FloatTy_, {FloatTy_, FloatTy_}, false), + {}); + call_ty = callee.getFunctionType(); + call_fn = callee.getCallee(); + applyMathFunctionAttributes(llvm::cast(call_fn)); + } break; default: { throw unimplemented_lowering(v); @@ -1388,48 +1540,50 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) { } else if (v->dtype().scalar_type() == ScalarType::Double) { switch (v->op_type()) { #if defined(__AVX__) && !defined(_MSC_VER) -#define SIMD_UNARY_MATH_CASE(enum, name, type) \ - case enum: { \ - llvm::FunctionCallee callee; \ - std::string fname; \ - if (v->dtype().lanes() == 4) { \ - fname = "Sleef_" + std::string(name) + "d4"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ - call_simd_sleef = true; \ - } else if (v->dtype().lanes() == 2) { \ - fname = "Sleef_" + std::string(name) + "d2"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ - call_simd_sleef = true; \ - } else { \ - callee = module_->getOrInsertFunction( \ - name, llvm::FunctionType::get(type, {type}, false), {}); \ - } \ - call_ty = callee.getFunctionType(); \ - call_fn = callee.getCallee(); \ - applyMathFunctionAttributes(llvm::cast(call_fn)); \ +#define SIMD_UNARY_MATH_CASE(enum, name, type) \ + case enum: { \ + FunctionCallee callee; \ + std::string fname; \ + auto element_count = ElementCount(v->dtype().lanes()); \ + if (v->dtype().lanes() == 4) { \ + fname = "Sleef_" + std::string(name) + "d4"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ + call_simd_sleef = true; \ + } else if (v->dtype().lanes() == 2) { \ + fname = "Sleef_" + std::string(name) + "d2"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ + call_simd_sleef = true; \ + } else { \ + callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type}, false), {}); \ + } \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ } break; #else -#define SIMD_UNARY_MATH_CASE(enum, name, type) \ - case enum: { \ - llvm::FunctionCallee callee; \ - std::string fname; \ - if (v->dtype().lanes() == 2) { \ - fname = "Sleef_" + std::string(name) + "d2"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ - call_simd_sleef = true; \ - } else { \ - callee = module_->getOrInsertFunction( \ - name, llvm::FunctionType::get(type, {type}, false), {}); \ - } \ - call_ty = callee.getFunctionType(); \ - call_fn = callee.getCallee(); \ - applyMathFunctionAttributes(llvm::cast(call_fn)); \ +#define SIMD_UNARY_MATH_CASE(enum, name, type) \ + case enum: { \ + FunctionCallee callee; \ + std::string fname; \ + if (v->dtype().lanes() == 2) { \ + fname = "Sleef_" + std::string(name) + "d2"; \ + llvm::Type* vecType = \ + llvm::VectorType::get(type, ElementCount(v->dtype().lanes())); \ + callee = module_->getOrInsertFunction( \ + fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ + call_simd_sleef = true; \ + } else { \ + callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type}, false), {}); \ + } \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ } break; #endif SIMD_UNARY_MATH_CASE(kLog10, "log10", DoubleTy_) @@ -1440,7 +1594,7 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) { SIMD_UNARY_MATH_CASE(kCos, "cos", DoubleTy_) SIMD_UNARY_MATH_CASE(kSin, "sin", DoubleTy_) SIMD_UNARY_MATH_CASE(kSqrt, "sqrt", DoubleTy_) - SIMD_UNARY_MATH_CASE(kFabs, "fabs", DoubleTy_) + SIMD_UNARY_MATH_CASE(kAbs, "fabs", DoubleTy_) SIMD_UNARY_MATH_CASE(kFloor, "floor", DoubleTy_) SIMD_UNARY_MATH_CASE(kCeil, "ceil", DoubleTy_) SIMD_UNARY_MATH_CASE(kTrunc, "trunc", DoubleTy_) @@ -1470,54 +1624,56 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) { } break; #if defined(__AVX__) && !defined(_MSC_VER) -#define SIMD_BINARY_MATH_CASE(enum, name, type) \ - case enum: { \ - llvm::FunctionCallee callee; \ - std::string fname; \ - if (v->dtype().lanes() == 4) { \ - fname = "Sleef_" + std::string(name) + "d4"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, \ - llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ - {}); \ - call_simd_sleef = true; \ - } else if (v->dtype().lanes() == 2) { \ - fname = "Sleef_" + std::string(name) + "d2"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, \ - llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ - {}); \ - call_simd_sleef = true; \ - } else { \ - callee = module_->getOrInsertFunction( \ - name, llvm::FunctionType::get(type, {type, type}, false), {}); \ - } \ - call_ty = callee.getFunctionType(); \ - call_fn = callee.getCallee(); \ - applyMathFunctionAttributes(llvm::cast(call_fn)); \ +#define SIMD_BINARY_MATH_CASE(enum, name, type) \ + case enum: { \ + FunctionCallee callee; \ + std::string fname; \ + auto element_count = ElementCount(v->dtype().lanes()); \ + if (v->dtype().lanes() == 4) { \ + fname = "Sleef_" + std::string(name) + "d4"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, \ + llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ + {}); \ + call_simd_sleef = true; \ + } else if (v->dtype().lanes() == 2) { \ + fname = "Sleef_" + std::string(name) + "d2"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, \ + llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ + {}); \ + call_simd_sleef = true; \ + } else { \ + callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type, type}, false), {}); \ + } \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ } break; #else -#define SIMD_BINARY_MATH_CASE(enum, name, type) \ - case enum: { \ - llvm::FunctionCallee callee; \ - std::string fname; \ - if (v->dtype().lanes() == 2) { \ - fname = "Sleef_" + std::string(name) + "d2"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, \ - llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ - {}); \ - call_simd_sleef = true; \ - } else { \ - callee = module_->getOrInsertFunction( \ - name, llvm::FunctionType::get(type, {type, type}, false), {}); \ - } \ - call_ty = callee.getFunctionType(); \ - call_fn = callee.getCallee(); \ - applyMathFunctionAttributes(llvm::cast(call_fn)); \ +#define SIMD_BINARY_MATH_CASE(enum, name, type) \ + case enum: { \ + FunctionCallee callee; \ + std::string fname; \ + auto element_count = ElementCount(v->dtype().lanes()); \ + if (v->dtype().lanes() == 2) { \ + fname = "Sleef_" + std::string(name) + "d2"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, \ + llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ + {}); \ + call_simd_sleef = true; \ + } else { \ + callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type, type}, false), {}); \ + } \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ } break; #endif SIMD_BINARY_MATH_CASE(kAtan2, "atan2", DoubleTy_) @@ -1527,7 +1683,7 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) { #define BINARY_MATH_CASE(enum, name, type) \ case enum: { \ - auto callee = module_->getOrInsertFunction( \ + FunctionCallee callee = module_->getOrInsertFunction( \ name, llvm::FunctionType::get(type, {type, type}, false), {}); \ call_ty = callee.getFunctionType(); \ call_fn = callee.getCallee(); \ @@ -1540,6 +1696,26 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) { throw unimplemented_lowering(v); } break; } + } else if (v->dtype().is_integral() && v->op_type() == kAbs) { + // abs is only intrinsic defined for integer inputs in pytorch eager + v->params().front()->accept(this); + if (!v->dtype().is_signed()) { + return; + } + // TODO: use llvm.abs intrinsic for LLVM 12 + auto zero = llvm::ConstantInt::get(value_->getType(), 0); + auto neg_value = irb_.CreateSub(zero, value_); + auto icmp = irb_.CreateICmpSGT(value_, zero); + value_ = irb_.CreateSelect(icmp, value_, neg_value); + return; + } else { + TORCH_INTERNAL_ASSERT( + false, + v, + "Unimplemented lowering:", + v->op_type(), + " for input of dtype", + v->dtype().scalar_dtype()); } std::vector params; @@ -1620,7 +1796,45 @@ void LLVMCodeGenImpl::visit(const Let* v) { } void LLVMCodeGenImpl::visit(const Cond* v) { - throw unimplemented_lowering(v); + // Even if true_stmt and false_stmt are nullptr, + // in case condition is a function call with side effect, + // we still evaluate it. + v->condition()->accept(this); + + if (!v->true_stmt() && !v->false_stmt()) { + return; + } + assert(v->true_stmt()); + + llvm::Value* condition = value_; + llvm::Value* c = irb_.CreateICmpNE( + condition, llvm::ConstantInt::get(condition->getType(), 0)); + llvm::BasicBlock* then_block = + llvm::BasicBlock::Create(getContext(), "then", fn_); + llvm::BasicBlock* else_block = nullptr; + if (v->false_stmt()) { + else_block = llvm::BasicBlock::Create(getContext(), "else", fn_); + } + llvm::BasicBlock* end_block = + llvm::BasicBlock::Create(getContext(), "end", fn_); + + if (else_block) { + irb_.CreateCondBr(c, then_block, else_block); + } else { + irb_.CreateCondBr(c, then_block, end_block); + } + + irb_.SetInsertPoint(then_block); + v->true_stmt()->accept(this); + irb_.CreateBr(end_block); + + if (else_block) { + irb_.SetInsertPoint(else_block); + v->false_stmt()->accept(this); + irb_.CreateBr(end_block); + } + + irb_.SetInsertPoint(end_block); } void LLVMCodeGenImpl::optimize(llvm::Module& M) { @@ -1628,20 +1842,21 @@ void LLVMCodeGenImpl::optimize(llvm::Module& M) { llvm::legacy::PassManager PM; // Add internal analysis passes from the target machine. - PM.add( - llvm::createTargetTransformInfoWrapperPass(TM_->getTargetIRAnalysis())); - FPM.add( - llvm::createTargetTransformInfoWrapperPass(TM_->getTargetIRAnalysis())); + auto& TM = jit_->getTargetMachine(); + PM.add(llvm::createTargetTransformInfoWrapperPass(TM.getTargetIRAnalysis())); + FPM.add(llvm::createTargetTransformInfoWrapperPass(TM.getTargetIRAnalysis())); llvm::PassManagerBuilder PMB; PMB.OptLevel = 3; PMB.LoopVectorize = true; PMB.SLPVectorize = true; - TM_->adjustPassManager(PMB); + TM.adjustPassManager(PMB); PMB.populateFunctionPassManager(FPM); PMB.populateModulePassManager(PM); FPM.doInitialization(); + PM.add(llvm::createDeadCodeEliminationPass()); + PM.add(llvm::createAlwaysInlinerLegacyPass()); PM.run(M); for (auto& FF : M) { FPM.run(FF); diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.h b/torch/csrc/jit/tensorexpr/llvm_codegen.h index ad5df73912693..f12be00a3c0ae 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.h +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.h @@ -22,6 +22,7 @@ class TORCH_API LLVMCodeGen : public CodeGen { Stmt* stmt, const std::vector& args, at::Device device = at::kCPU, + const std::string& kernel_func_name = "func", Dtype dtype = kInt); explicit LLVMCodeGen(Stmt* stmt); @@ -30,6 +31,14 @@ class TORCH_API LLVMCodeGen : public CodeGen { TORCH_API void call(const std::vector& args) override; + at::Tensor empty_strided( + c10::IntArrayRef size, + c10::IntArrayRef stride, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt) override; + template T value() { return value(nullptr); diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.cpp b/torch/csrc/jit/tensorexpr/llvm_jit.cpp index c4cc9337ce165..d9b726b902e64 100644 --- a/torch/csrc/jit/tensorexpr/llvm_jit.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_jit.cpp @@ -2,433 +2,263 @@ #include +#include +#include +#include +#include +#include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + #include #include #include #include #include +using namespace torch::jit::tensorexpr; + +template +static llvm::JITTargetAddress toAddress(T* Ptr) { + return static_cast(reinterpret_cast(Ptr)); +} + +static llvm::orc::JITTargetMachineBuilder makeTargetMachineBuilder() { +#if 0 + // FIXME: Switch to using detectHost() rather than setting up the JTMB manually + // once LLVM 10 is available. + return assertSuccess(llvm::orc::JITTargetMachineBuilder::detectHost()); +#else + llvm::orc::JITTargetMachineBuilder JTMB( + (llvm::Triple(llvm::sys::getProcessTriple()))); + + // Retrieve host CPU name and sub-target features and add them to builder. + // Relocation model, code model and codegen opt level are kept to default + // values. + llvm::SubtargetFeatures SubtargetFeatures; + llvm::StringMap FeatureMap; + llvm::sys::getHostCPUFeatures(FeatureMap); + for (auto& Feature : FeatureMap) { + SubtargetFeatures.AddFeature(Feature.first(), Feature.second); + } + + JTMB.setCodeGenOptLevel(llvm::CodeGenOpt::Default); + JTMB.setCPU(llvm::sys::getHostCPUName().str()); + JTMB.addFeatures(SubtargetFeatures.getFeatures()); + JTMB.getOptions().AllowFPOpFusion = llvm::FPOpFusion::Fast; + + return JTMB; +#endif +} + +static void registerIntrinsics( + llvm::orc::JITDylib& JD, + llvm::orc::MangleAndInterner& Mangle) { + using namespace llvm; + using namespace llvm::orc; + + auto entry = [&](const char* name, auto ptr) -> SymbolMap::value_type { + return {Mangle(name), {toAddress(ptr), JITSymbolFlags::None}}; + }; + + assertSuccess(JD.define(absoluteSymbols({ + entry("log10f", &log10f), entry("log1pf", &log1pf), entry("logf", &logf), + entry("log2f", &log2f), entry("expf", &expf), entry("erff", &erff), + entry("cosf", &cosf), entry("sinf", &sinf), entry("tanf", &tanf), + entry("acosf", &acosf), entry("asinf", &asinf), entry("atanf", &atanf), + entry("coshf", &coshf), entry("sinhf", &sinhf), entry("tanhf", &tanhf), + entry("sqrtf", &sqrtf), entry("fabsf", &fabsf), + entry("floorf", &floorf), entry("ceilf", &ceilf), + entry("roundf", &roundf), entry("truncf", &truncf), + entry("atan2f", &atan2f), entry("fmodf", &fmodf), + entry("remainderf", &remainderf), + + // float -> half & half -> float conversions + entry("__gnu_h2f_ieee", &c10::detail::fp16_ieee_to_fp32_value), + entry("__gnu_f2h_ieee", &c10::detail::fp16_ieee_from_fp32_value), + + // FP32 Sleef functions -- SSE + entry("Sleef_acosf4", &Sleef_acosf4_u10), + entry("Sleef_asinf4", &Sleef_asinf4_u10), + entry("Sleef_atanf4", &Sleef_atanf4_u10), + entry("Sleef_cosf4", &Sleef_cosf4_u10), + entry("Sleef_sinf4", &Sleef_sinf4_u10), + entry("Sleef_tanf4", &Sleef_tanf4_u10), + entry("Sleef_coshf4", &Sleef_coshf4_u10), + entry("Sleef_sinhf4", &Sleef_sinhf4_u10), + entry("Sleef_tanhf4", &Sleef_tanhf4_u10), + entry("Sleef_erff4", &Sleef_erff4_u10), + entry("Sleef_erfcf4", &Sleef_erfcf4_u15), + entry("Sleef_expf4", &Sleef_expf4_u10), + entry("Sleef_expm1f4", &Sleef_expm1f4_u10), + entry("Sleef_logf4", &Sleef_logf4_u10), + entry("Sleef_log2f4", &Sleef_log2f4_u10), + entry("Sleef_log10f4", &Sleef_log10f4_u10), + entry("Sleef_log1pf4", &Sleef_log1pf4_u10), + entry("Sleef_sqrtf4", &Sleef_sqrtf4_u05), + entry("Sleef_fabsf4", &Sleef_fabsf4), + entry("Sleef_floorf4", &Sleef_floorf4), + entry("Sleef_ceilf4", &Sleef_ceilf4), + entry("Sleef_truncf4", &Sleef_truncf4), + entry("Sleef_roundf4", &Sleef_roundf4), + entry("Sleef_lgammaf4", &Sleef_lgammaf4_u10), + entry("Sleef_atan2f4", &Sleef_atan2f4_u10), + entry("Sleef_powf4", &Sleef_powf4_u10), + entry("Sleef_fmodf4", &Sleef_fmodf4), + + // FP32 Sleef functions -- AVX2 +#if defined(__AVX__) && !defined(_MSC_VER) + entry("Sleef_acosf8", &Sleef_acosf8_u10), + entry("Sleef_asinf8", &Sleef_asinf8_u10), + entry("Sleef_atanf8", &Sleef_atanf8_u10), + entry("Sleef_cosf8", &Sleef_cosf8_u10), + entry("Sleef_sinf8", &Sleef_sinf8_u10), + entry("Sleef_tanf8", &Sleef_tanf8_u10), + entry("Sleef_coshf8", &Sleef_coshf8_u10), + entry("Sleef_sinhf8", &Sleef_sinhf8_u10), + entry("Sleef_tanhf8", &Sleef_tanhf8_u10), + entry("Sleef_erff8", &Sleef_erff8_u10), + entry("Sleef_erfcf8", &Sleef_erfcf8_u15), + entry("Sleef_expf8", &Sleef_expf8_u10), + entry("Sleef_expm1f8", &Sleef_expm1f8_u10), + entry("Sleef_logf8", &Sleef_logf8_u10), + entry("Sleef_log2f8", &Sleef_log2f8_u10), + entry("Sleef_log10f8", &Sleef_log10f8_u10), + entry("Sleef_log1pf8", &Sleef_log1pf8_u10), + entry("Sleef_sqrtf8", &Sleef_sqrtf8_u05), + entry("Sleef_fabsf8", &Sleef_fabsf8), + entry("Sleef_floorf8", &Sleef_floorf8), + entry("Sleef_ceilf8", &Sleef_ceilf8), + entry("Sleef_truncf8", &Sleef_truncf8), + entry("Sleef_roundf8", &Sleef_roundf8), + entry("Sleef_lgammaf8", &Sleef_lgammaf8_u10), + entry("Sleef_atan2f8", &Sleef_atan2f8_u10), + entry("Sleef_powf8", &Sleef_powf8_u10), + entry("Sleef_fmodf8", &Sleef_fmodf8), +#endif + + // FP64 Sleef functions -- SSE + entry("Sleef_acosd2", &Sleef_acosd2_u10), + entry("Sleef_asind2", &Sleef_asind2_u10), + entry("Sleef_atand2", &Sleef_atand2_u10), + entry("Sleef_cosd2", &Sleef_cosd2_u10), + entry("Sleef_sind2", &Sleef_sind2_u10), + entry("Sleef_tand2", &Sleef_tand2_u10), + entry("Sleef_coshd2", &Sleef_coshd2_u10), + entry("Sleef_sinhd2", &Sleef_sinhd2_u10), + entry("Sleef_tanhd2", &Sleef_tanhd2_u10), + entry("Sleef_erfd2", &Sleef_erfd2_u10), + entry("Sleef_erfcd2", &Sleef_erfcd2_u15), + entry("Sleef_expd2", &Sleef_expd2_u10), + entry("Sleef_expm1d2", &Sleef_expm1d2_u10), + entry("Sleef_logd2", &Sleef_logd2_u10), + entry("Sleef_log2d2", &Sleef_log2d2_u10), + entry("Sleef_log10d2", &Sleef_log10d2_u10), + entry("Sleef_log1pd2", &Sleef_log1pd2_u10), + entry("Sleef_sqrtd2", &Sleef_sqrtd2_u05), + entry("Sleef_fabsd2", &Sleef_fabsd2), + entry("Sleef_floord2", &Sleef_floord2), + entry("Sleef_ceild2", &Sleef_ceild2), + entry("Sleef_truncd2", &Sleef_truncd2), + entry("Sleef_roundd2", &Sleef_roundd2), + entry("Sleef_lgammad2", &Sleef_lgammad2_u10), + entry("Sleef_atan2d2", &Sleef_atan2d2_u10), + entry("Sleef_powd2", &Sleef_powd2_u10), + entry("Sleef_fmodd2", &Sleef_fmodd2), + + // FP64 Sleef functions -- AVX2 +#if defined(__AVX__) && !defined(_MSC_VER) + entry("Sleef_acosd4", &Sleef_acosd4_u10), + entry("Sleef_asind4", &Sleef_asind4_u10), + entry("Sleef_atand4", &Sleef_atand4_u10), + entry("Sleef_cosd4", &Sleef_cosd4_u10), + entry("Sleef_sind4", &Sleef_sind4_u10), + entry("Sleef_tand4", &Sleef_tand4_u10), + entry("Sleef_coshd4", &Sleef_coshd4_u10), + entry("Sleef_sinhd4", &Sleef_sinhd4_u10), + entry("Sleef_tanhd4", &Sleef_tanhd4_u10), + entry("Sleef_erfd4", &Sleef_erfd4_u10), + entry("Sleef_erfcd4", &Sleef_erfcd4_u15), + entry("Sleef_expd4", &Sleef_expd4_u10), + entry("Sleef_expm1d4", &Sleef_expm1d4_u10), + entry("Sleef_logd4", &Sleef_logd4_u10), + entry("Sleef_log2d4", &Sleef_log2d4_u10), + entry("Sleef_log10d4", &Sleef_log10d4_u10), + entry("Sleef_log1pd4", &Sleef_log1pd4_u10), + entry("Sleef_sqrtd4", &Sleef_sqrtd4_u05), + entry("Sleef_fabsd4", &Sleef_fabsd4), + entry("Sleef_floord4", &Sleef_floord4), + entry("Sleef_ceild4", &Sleef_ceild4), + entry("Sleef_truncd4", &Sleef_truncd4), + entry("Sleef_roundd4", &Sleef_roundd4), + entry("Sleef_lgammad4", &Sleef_lgammad4_u10), + entry("Sleef_atan2d4", &Sleef_atan2d4_u10), + entry("Sleef_powd4", &Sleef_powd4_u10), + entry("Sleef_fmodd4", &Sleef_fmodd4), +#endif + }))); +} + namespace llvm { namespace orc { // Lightly modified implementation from LLVM's Kaleidoscope JIT tutorial: // https://llvm.org/docs/tutorial/BuildingAJIT1.html +#if LLVM_VERSION_MAJOR >= 9 && LLVM_VERSION_MAJOR <= 12 class TORCH_API PytorchLLVMJITImpl { private: + std::unique_ptr TM; std::unique_ptr LLJ; public: - PytorchLLVMJITImpl() : LLJ(cantFail(LLJITBuilder().create())) { + PytorchLLVMJITImpl() + : TM(assertSuccess(makeTargetMachineBuilder().createTargetMachine())), + LLJ(assertSuccess( + LLJITBuilder() + .setJITTargetMachineBuilder(makeTargetMachineBuilder()) + .create())) { auto ProcSymbolsGenerator = - cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess( + assertSuccess(DynamicLibrarySearchGenerator::GetForCurrentProcess( LLJ->getDataLayout().getGlobalPrefix())); - LLJ->getMainJITDylib().setGenerator(std::move(ProcSymbolsGenerator)); + auto& JD = LLJ->getMainJITDylib(); +#if LLVM_VERSION_MAJOR == 9 + JD.setGenerator(std::move(ProcSymbolsGenerator)); +#else + JD.addGenerator(std::move(ProcSymbolsGenerator)); +#endif // Handle platform-specific symbol mangling MangleAndInterner Mangle(LLJ->getExecutionSession(), LLJ->getDataLayout()); // Register implementations of intrinsics - cantFail(LLJ->defineAbsolute( - *Mangle("log10f"), {llvm::pointerToJITTargetAddress(&log10f), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("log1pf"), {llvm::pointerToJITTargetAddress(&log1pf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("logf"), {llvm::pointerToJITTargetAddress(&logf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("log2f"), {llvm::pointerToJITTargetAddress(&log2f), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("expf"), {llvm::pointerToJITTargetAddress(&expf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("erff"), {llvm::pointerToJITTargetAddress(&erff), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("cosf"), {llvm::pointerToJITTargetAddress(&cosf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("sinf"), {llvm::pointerToJITTargetAddress(&sinf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("tanf"), {llvm::pointerToJITTargetAddress(&tanf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("acosf"), {llvm::pointerToJITTargetAddress(&acosf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("asinf"), {llvm::pointerToJITTargetAddress(&asinf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("atanf"), {llvm::pointerToJITTargetAddress(&atanf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("coshf"), {llvm::pointerToJITTargetAddress(&coshf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("sinhf"), {llvm::pointerToJITTargetAddress(&sinhf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("tanhf"), {llvm::pointerToJITTargetAddress(&tanhf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("sqrtf"), {llvm::pointerToJITTargetAddress(&sqrtf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("fabsf"), {llvm::pointerToJITTargetAddress(&fabsf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("floorf"), {llvm::pointerToJITTargetAddress(&floorf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("ceilf"), {llvm::pointerToJITTargetAddress(&ceilf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("roundf"), {llvm::pointerToJITTargetAddress(&roundf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("truncf"), {llvm::pointerToJITTargetAddress(&truncf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("atan2f"), {llvm::pointerToJITTargetAddress(&atan2f), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("fmodf"), {llvm::pointerToJITTargetAddress(&fmodf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("remainderf"), - {llvm::pointerToJITTargetAddress(&remainderf), {}})); - - // FP32 Sleef functions -- SSE - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_acosf4"), - {llvm::pointerToJITTargetAddress(&Sleef_acosf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_asinf4"), - {llvm::pointerToJITTargetAddress(&Sleef_asinf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_atanf4"), - {llvm::pointerToJITTargetAddress(&Sleef_atanf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_cosf4"), - {llvm::pointerToJITTargetAddress(&Sleef_cosf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sinf4"), - {llvm::pointerToJITTargetAddress(&Sleef_sinf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_tanf4"), - {llvm::pointerToJITTargetAddress(&Sleef_tanf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_coshf4"), - {llvm::pointerToJITTargetAddress(&Sleef_coshf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sinhf4"), - {llvm::pointerToJITTargetAddress(&Sleef_sinhf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_tanhf4"), - {llvm::pointerToJITTargetAddress(&Sleef_tanhf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_erff4"), - {llvm::pointerToJITTargetAddress(&Sleef_erff4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_erfcf4"), - {llvm::pointerToJITTargetAddress(&Sleef_erfcf4_u15), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_expf4"), - {llvm::pointerToJITTargetAddress(&Sleef_expf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_expm1f4"), - {llvm::pointerToJITTargetAddress(&Sleef_expm1f4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_logf4"), - {llvm::pointerToJITTargetAddress(&Sleef_logf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log2f4"), - {llvm::pointerToJITTargetAddress(&Sleef_log2f4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log10f4"), - {llvm::pointerToJITTargetAddress(&Sleef_log10f4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log1pf4"), - {llvm::pointerToJITTargetAddress(&Sleef_log1pf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sqrtf4"), - {llvm::pointerToJITTargetAddress(&Sleef_sqrtf4_u05), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_fabsf4"), - {llvm::pointerToJITTargetAddress(&Sleef_fabsf4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_floorf4"), - {llvm::pointerToJITTargetAddress(&Sleef_floorf4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_ceilf4"), - {llvm::pointerToJITTargetAddress(&Sleef_ceilf4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_truncf4"), - {llvm::pointerToJITTargetAddress(&Sleef_truncf4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_roundf4"), - {llvm::pointerToJITTargetAddress(&Sleef_roundf4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_lgammaf4"), - {llvm::pointerToJITTargetAddress(&Sleef_lgammaf4_u10), {}})); - - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_atan2f4"), - {llvm::pointerToJITTargetAddress(&Sleef_atan2f4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_powf4"), - {llvm::pointerToJITTargetAddress(&Sleef_powf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_fmodf4"), - {llvm::pointerToJITTargetAddress(&Sleef_fmodf4), {}})); - - // FP32 Sleef functions -- AVX2 -#if defined(__AVX__) && !defined(_MSC_VER) - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_acosf8"), - {llvm::pointerToJITTargetAddress(&Sleef_acosf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_asinf8"), - {llvm::pointerToJITTargetAddress(&Sleef_asinf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_atanf8"), - {llvm::pointerToJITTargetAddress(&Sleef_atanf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_cosf8"), - {llvm::pointerToJITTargetAddress(&Sleef_cosf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sinf8"), - {llvm::pointerToJITTargetAddress(&Sleef_sinf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_tanf8"), - {llvm::pointerToJITTargetAddress(&Sleef_tanf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_coshf8"), - {llvm::pointerToJITTargetAddress(&Sleef_coshf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sinhf8"), - {llvm::pointerToJITTargetAddress(&Sleef_sinhf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_tanhf8"), - {llvm::pointerToJITTargetAddress(&Sleef_tanhf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_erff8"), - {llvm::pointerToJITTargetAddress(&Sleef_erff8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_erfcf8"), - {llvm::pointerToJITTargetAddress(&Sleef_erfcf8_u15), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_expf8"), - {llvm::pointerToJITTargetAddress(&Sleef_expf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_expm1f8"), - {llvm::pointerToJITTargetAddress(&Sleef_expm1f8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_logf8"), - {llvm::pointerToJITTargetAddress(&Sleef_logf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log2f8"), - {llvm::pointerToJITTargetAddress(&Sleef_log2f8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log10f8"), - {llvm::pointerToJITTargetAddress(&Sleef_log10f8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log1pf8"), - {llvm::pointerToJITTargetAddress(&Sleef_log1pf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sqrtf8"), - {llvm::pointerToJITTargetAddress(&Sleef_sqrtf8_u05), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_fabsf8"), - {llvm::pointerToJITTargetAddress(&Sleef_fabsf8), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_floorf8"), - {llvm::pointerToJITTargetAddress(&Sleef_floorf8), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_ceilf8"), - {llvm::pointerToJITTargetAddress(&Sleef_ceilf8), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_truncf8"), - {llvm::pointerToJITTargetAddress(&Sleef_truncf8), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_roundf8"), - {llvm::pointerToJITTargetAddress(&Sleef_roundf8), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_lgammaf8"), - {llvm::pointerToJITTargetAddress(&Sleef_lgammaf8_u10), {}})); - - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_atan2f8"), - {llvm::pointerToJITTargetAddress(&Sleef_atan2f8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_powf8"), - {llvm::pointerToJITTargetAddress(&Sleef_powf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_fmodf8"), - {llvm::pointerToJITTargetAddress(&Sleef_fmodf8), {}})); -#endif - - // FP64 Sleef functions -- SSE - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_acosd2"), - {llvm::pointerToJITTargetAddress(&Sleef_acosd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_asind2"), - {llvm::pointerToJITTargetAddress(&Sleef_asind2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_atand2"), - {llvm::pointerToJITTargetAddress(&Sleef_atand2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_cosd2"), - {llvm::pointerToJITTargetAddress(&Sleef_cosd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sind2"), - {llvm::pointerToJITTargetAddress(&Sleef_sind2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_tand2"), - {llvm::pointerToJITTargetAddress(&Sleef_tand2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_coshd2"), - {llvm::pointerToJITTargetAddress(&Sleef_coshd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sinhd2"), - {llvm::pointerToJITTargetAddress(&Sleef_sinhd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_tanhd2"), - {llvm::pointerToJITTargetAddress(&Sleef_tanhd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_erfd2"), - {llvm::pointerToJITTargetAddress(&Sleef_erfd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_erfcd2"), - {llvm::pointerToJITTargetAddress(&Sleef_erfcd2_u15), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_expd2"), - {llvm::pointerToJITTargetAddress(&Sleef_expd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_expm1d2"), - {llvm::pointerToJITTargetAddress(&Sleef_expm1d2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_logd2"), - {llvm::pointerToJITTargetAddress(&Sleef_logd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log2d2"), - {llvm::pointerToJITTargetAddress(&Sleef_log2d2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log10d2"), - {llvm::pointerToJITTargetAddress(&Sleef_log10d2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log1pd2"), - {llvm::pointerToJITTargetAddress(&Sleef_log1pd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sqrtd2"), - {llvm::pointerToJITTargetAddress(&Sleef_sqrtd2_u05), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_fabsd2"), - {llvm::pointerToJITTargetAddress(&Sleef_fabsd2), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_floord2"), - {llvm::pointerToJITTargetAddress(&Sleef_floord2), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_ceild2"), - {llvm::pointerToJITTargetAddress(&Sleef_ceild2), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_truncd2"), - {llvm::pointerToJITTargetAddress(&Sleef_truncd2), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_roundd2"), - {llvm::pointerToJITTargetAddress(&Sleef_roundd2), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_lgammad2"), - {llvm::pointerToJITTargetAddress(&Sleef_lgammad2_u10), {}})); - - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_atan2d2"), - {llvm::pointerToJITTargetAddress(&Sleef_atan2d2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_powd2"), - {llvm::pointerToJITTargetAddress(&Sleef_powd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_fmodd2"), - {llvm::pointerToJITTargetAddress(&Sleef_fmodd2), {}})); - - // FP64 Sleef functions -- AVX2 -#if defined(__AVX__) && !defined(_MSC_VER) - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_acosd4"), - {llvm::pointerToJITTargetAddress(&Sleef_acosd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_asind4"), - {llvm::pointerToJITTargetAddress(&Sleef_asind4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_atand4"), - {llvm::pointerToJITTargetAddress(&Sleef_atand4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_cosd4"), - {llvm::pointerToJITTargetAddress(&Sleef_cosd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sind4"), - {llvm::pointerToJITTargetAddress(&Sleef_sind4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_tand4"), - {llvm::pointerToJITTargetAddress(&Sleef_tand4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_coshd4"), - {llvm::pointerToJITTargetAddress(&Sleef_coshd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sinhd4"), - {llvm::pointerToJITTargetAddress(&Sleef_sinhd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_tanhd4"), - {llvm::pointerToJITTargetAddress(&Sleef_tanhd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_erfd4"), - {llvm::pointerToJITTargetAddress(&Sleef_erfd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_erfcd4"), - {llvm::pointerToJITTargetAddress(&Sleef_erfcd4_u15), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_expd4"), - {llvm::pointerToJITTargetAddress(&Sleef_expd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_expm1d4"), - {llvm::pointerToJITTargetAddress(&Sleef_expm1d4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_logd4"), - {llvm::pointerToJITTargetAddress(&Sleef_logd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log2d4"), - {llvm::pointerToJITTargetAddress(&Sleef_log2d4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log10d4"), - {llvm::pointerToJITTargetAddress(&Sleef_log10d4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log1pd4"), - {llvm::pointerToJITTargetAddress(&Sleef_log1pd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sqrtd4"), - {llvm::pointerToJITTargetAddress(&Sleef_sqrtd4_u05), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_fabsd4"), - {llvm::pointerToJITTargetAddress(&Sleef_fabsd4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_floord4"), - {llvm::pointerToJITTargetAddress(&Sleef_floord4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_ceild4"), - {llvm::pointerToJITTargetAddress(&Sleef_ceild4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_truncd4"), - {llvm::pointerToJITTargetAddress(&Sleef_truncd4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_roundd4"), - {llvm::pointerToJITTargetAddress(&Sleef_roundd4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_lgammad4"), - {llvm::pointerToJITTargetAddress(&Sleef_lgammad4_u10), {}})); - - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_atan2d4"), - {llvm::pointerToJITTargetAddress(&Sleef_atan2d4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_powd4"), - {llvm::pointerToJITTargetAddress(&Sleef_powd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_fmodd4"), - {llvm::pointerToJITTargetAddress(&Sleef_fmodd4), {}})); -#endif + registerIntrinsics(JD, Mangle); } - Error addModule(ThreadSafeModule M) { - if (auto Err = LLJ->addIRModule(std::move(M))) { + Error addModule(std::unique_ptr M, std::unique_ptr C) { + if (auto Err = + LLJ->addIRModule(ThreadSafeModule(std::move(M), std::move(C)))) { return Err; } return Error::success(); } JITSymbol findSymbol(const std::string Name) { - return cantFail(LLJ->lookup(Name)); + return assertSuccess(LLJ->lookup(Name)); + } + + TargetMachine& getTargetMachine() { + return *TM; } const DataLayout& getDataLayout() { @@ -441,18 +271,132 @@ PytorchLLVMJIT::PytorchLLVMJIT() PytorchLLVMJIT::~PytorchLLVMJIT() = default; -Error PytorchLLVMJIT::addModule(ThreadSafeModule M) { - return impl_->addModule(std::move(M)); +Error PytorchLLVMJIT::addModule( + std::unique_ptr M, + std::unique_ptr C) { + return impl_->addModule(std::move(M), std::move(C)); +} + +JITSymbol PytorchLLVMJIT::findSymbol(const std::string Name) { + return impl_->findSymbol(std::move(Name)); +} + +TargetMachine& PytorchLLVMJIT::getTargetMachine() { + return impl_->getTargetMachine(); +} + +const DataLayout& PytorchLLVMJIT::getDataLayout() { + return impl_->getDataLayout(); +} + +#elif LLVM_VERSION_MAJOR == 8 && LLVM_VERSION_PATCH == 20181009 + +class TORCH_API PytorchLLVMJITImpl { + private: + ExecutionSession ES; + std::shared_ptr Resolver; + std::unique_ptr TM; + const DataLayout DL; + RTDyldObjectLinkingLayer ObjectLayer; + IRCompileLayer CompileLayer; + + public: + PytorchLLVMJITImpl() + : Resolver(createLegacyLookupResolver( + ES, + [this](const std::string& Name) -> JITSymbol { + if (auto Sym = CompileLayer.findSymbol(Name, false)) { + return Sym; + } else if (auto Err = Sym.takeError()) { + return std::move(Err); + } + if (auto SymAddr = + RTDyldMemoryManager::getSymbolAddressInProcess(Name)) { + return JITSymbol(SymAddr, JITSymbolFlags::Exported); + } + MangleAndInterner Mangle(ES, DL); + return assertSuccess( + lookup({&ES.getMainJITDylib()}, Mangle(Name))); + }, + [](Error Err) { + assertSuccess(std::move(Err), "lookupFlags failed"); + })), + TM(assertSuccess(makeTargetMachineBuilder().createTargetMachine())), + DL(TM->createDataLayout()), + ObjectLayer( + ES, + [this](VModuleKey) { + return RTDyldObjectLinkingLayer::Resources{ + std::make_shared(), Resolver}; + }), + CompileLayer(ObjectLayer, SimpleCompiler(*TM)) { + auto& JD = ES.getMainJITDylib(); + MangleAndInterner Mangle(ES, DL); + registerIntrinsics(JD, Mangle); + llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr); + } + + TargetMachine& getTargetMachine() { + return *TM; + } + + VModuleKey addModule(std::unique_ptr M) { + // Add the module to the JIT with a new VModuleKey. + auto K = ES.allocateVModule(); + assertSuccess( + CompileLayer.addModule(K, std::move(M)), + "Failed to add module to compile layer"); + return K; + } + + JITSymbol findSymbol(const std::string Name) { + std::string MangledName; + raw_string_ostream MangledNameStream(MangledName); + Mangler::getNameWithPrefix(MangledNameStream, Name, DL); + return CompileLayer.findSymbol(MangledNameStream.str(), true); + } + + JITTargetAddress getSymbolAddress(const std::string Name) { + return assertSuccess(findSymbol(Name).getAddress()); + } + + void removeModule(VModuleKey K) { + assertSuccess(CompileLayer.removeModule(K)); + } + + const DataLayout& getDataLayout() { + return DL; + } +}; + +PytorchLLVMJIT::PytorchLLVMJIT() + : impl_(std::make_unique()) {} + +PytorchLLVMJIT::~PytorchLLVMJIT() = default; + +Error PytorchLLVMJIT::addModule( + std::unique_ptr M, + std::unique_ptr C) { + impl_->addModule(std::move(M)); + return Error::success(); } JITSymbol PytorchLLVMJIT::findSymbol(const std::string Name) { return impl_->findSymbol(std::move(Name)); } +TargetMachine& PytorchLLVMJIT::getTargetMachine() { + return impl_->getTargetMachine(); +} + const DataLayout& PytorchLLVMJIT::getDataLayout() { return impl_->getDataLayout(); } +#else // LLVM_VERSION_MAJOR +#error Only LLVM versions 8 through 12 are supported. +#endif + } // end namespace orc } // end namespace llvm diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.h b/torch/csrc/jit/tensorexpr/llvm_jit.h index 0a96efd1298af..af6caee880a9d 100644 --- a/torch/csrc/jit/tensorexpr/llvm_jit.h +++ b/torch/csrc/jit/tensorexpr/llvm_jit.h @@ -1,6 +1,7 @@ #pragma once #ifdef TORCH_ENABLE_LLVM +#include #include #include @@ -11,6 +12,32 @@ #include #include +namespace torch { +namespace jit { +namespace tensorexpr { + +inline std::string formatError(llvm::Error&& err, const char* msg) { + static constexpr char* defaultErrorMsg = "Unexpected failure in LLVM JIT"; + std::string errorMsg(msg ? msg : defaultErrorMsg); + llvm::raw_string_ostream ss(errorMsg); + ss << ": " << err; + return ss.str(); +} + +template +T assertSuccess(llvm::Expected valOrErr, const char* msg = nullptr) { + TORCH_INTERNAL_ASSERT(valOrErr, formatError(valOrErr.takeError(), msg)); + return std::move(*valOrErr); +} + +inline void assertSuccess(llvm::Error err, const char* msg = nullptr) { + TORCH_INTERNAL_ASSERT(!err, formatError(std::move(err), msg)); +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch + namespace llvm { namespace orc { @@ -21,11 +48,12 @@ class TORCH_API PytorchLLVMJIT { PytorchLLVMJIT(); ~PytorchLLVMJIT(); - Error addModule(ThreadSafeModule M); + Error addModule(std::unique_ptr M, std::unique_ptr C); JITSymbol findSymbol(const std::string Name); TargetMachine& getTargetMachine(); + const DataLayout& getDataLayout(); private: diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 091b931bb8094..f4f346fdd36a6 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -23,6 +24,28 @@ namespace torch { namespace jit { namespace tensorexpr { +class FunctionCallUseCount : public IRVisitor { + public: + std::unordered_map findUses(Stmt* s) { + s->accept(this); + return uses_; + } + + private: + void visit(const FunctionCall* v) override { + if (function_calls_[v->tensor()->buf()].insert(v).second) { + uses_[v->tensor()->buf()] = uses_[v->tensor()->buf()] + 1; + } + IRVisitor::visit(v); + } + + std::unordered_map uses_; + + // Sets of FunctionCalls in order to keep the results unique + std::unordered_map> + function_calls_; +}; + class IndexFlattener : public IRMutator { public: Stmt* flatten(Stmt* s) { @@ -38,6 +61,19 @@ class IndexFlattener : public IRMutator { {flatten_index(v->buf()->dims(), v->indices())}, v->mask()); } + + const Expr* mutate(const ReduceOp* v) override { + const Expr* new_body = v->body()->accept_mutator(this); + + auto* out = new ReduceOp( + v->accumulator(), + new_body, + {flatten_index(v->accumulator()->dims(), v->output_args())}, + v->reduce_args(), + v->reducer()); + return out; + } + Stmt* mutate(const Store* v) override { const Expr* value = v->value(); const Expr* new_value = value->accept_mutator(this); @@ -141,6 +177,14 @@ class Vectorizer : public IRMutator { }); } + const Expr* mutate(const BitCast* v) override { + std::vector inputs = {v->src_value()}; + return try_vectorize(v, inputs, [&]() { + return BitCast::make( + Dtype(v->dtype().scalar_type(), lanes_), ExprHandle(inputs[0])); + }); + } + const Expr* mutate(const Cast* v) override { std::vector inputs = {v->src_value()}; return try_vectorize(v, inputs, [&]() { @@ -184,6 +228,26 @@ class Vectorizer : public IRMutator { }); } + const Expr* mutate(const ReduceOp* v) override { + Dtype dtype(v->dtype().scalar_type(), lanes_); + + auto inputs = v->output_args(); + // should already be flattened. + TORCH_INTERNAL_ASSERT(inputs.size() == 1); + + inputs.push_back(v->body()); + + auto* out = try_vectorize(v, inputs, [&]() { + return ExprHandle(new ReduceOp( + v->accumulator(), + inputs[1], + {inputs[0]}, + v->reduce_args(), + v->reducer())); + }); + return out; + } + const Expr* mutate(const Broadcast* v) override { const Expr* val = v->value(); const Expr* new_val = val->accept_mutator(this); @@ -312,17 +376,21 @@ class Vectorizer : public IRMutator { const Expr* start_ = nullptr; }; -void LoopNest::vectorize(Stmt* stmt) { - For* f = dynamic_cast(stmt); - if (!f) { - return; - } - +void LoopNest::vectorize(For* f) { Block* b = dynamic_cast(f->get_parent()); if (!b) { return; } + // Can't vectorize reduction axes. + auto reductions = NodeFinder::find(f); + for (auto* r : reductions) { + if (std::find(r->reduce_args().begin(), r->reduce_args().end(), f->var()) != + r->reduce_args().end()) { + throw std::logic_error("Cannot vectorize reduction axis - rfactor first"); + } + } + Vectorizer v; Stmt* old_f = Stmt::clone(f); Stmt* new_f = nullptr; @@ -342,13 +410,13 @@ class Flattener : public IRMutator { Expr* mutate(const FunctionCall* v) override { const Tensor* t = v->tensor(); const Buf* b = t->buf(); - Buffer buffer = Buffer(BufHandle(b)); + Placeholder buffer = Placeholder(BufHandle(b)); const std::vector& params = v->params(); std::vector params_expr(params.size()); for (size_t i = 0; i < params.size(); i++) { params_expr[i] = ExprHandle(params[i]); } - return buffer(params_expr).node(); + return buffer.load(params_expr).node(); } }; @@ -356,7 +424,11 @@ class DepTracker : public IRVisitor { public: std::vector findUsedTensors(Tensor* tensor) { used_tensors.clear(); - tensor->body()->accept(this); + if (tensor->body()) { + tensor->body()->accept(this); + } else { + tensor->ElementStmt()->accept(this); + } return used_tensors; } @@ -414,20 +486,23 @@ std::vector LoopNest::findAllNeededTensors( return result; } -LoopNest::LoopNest(const std::vector& output_tensors) - : output_tensors_(output_tensors.begin(), output_tensors.end()) { +LoopNest::LoopNest(const std::vector& output_tensors) { // Find all tensors we need to compute (including dependencies) and put them // in a topological order std::vector tensors_to_compute = findAllNeededTensors(output_tensors); + for (auto t : output_tensors) { + output_bufs_.insert(t->buf()); + } + // Find all intermediate tensors, we'll need that for inserting alloc/free // statements std::unordered_set tensors_to_compute_set( tensors_to_compute.begin(), tensors_to_compute.end()); for (Tensor* t : tensors_to_compute) { - if (!output_tensors_.count(t)) { - intermediate_tensors_.insert(t); + if (!output_bufs_.count(t->buf())) { + intermediate_bufs_.insert(t->buf()); } } @@ -446,40 +521,66 @@ LoopNest::LoopNest(const std::vector& output_tensors) } root_stmt_ = new Block(loops); + + // If it's referenced in the root_stmt, but it's not in output_bufs_ or + // intermediate_bufs_ then it must be an input. + auto bufs = NodeFinder::find(root_stmt_); + for (auto* buf : bufs) { + if (!output_bufs_.count(buf) && !intermediate_bufs_.count(buf)) { + input_bufs_.insert(buf); + } + } } Stmt* LoopNest::lowerToStmt(Tensor* t) { - Function* f = t->function(); - // TODO: Support multiple-output functions - Stmt* body = f->ElementStmt(0); + Stmt* body = t->ElementStmt(); - if (f->ndim() == 0) { + // If this Tensor has no functional body, it already has its axes expanded. + if (nullptr == t->body()) { + return body; + } + + if (t->ndim() == 0 && t->reduce_ndim() == 0) { return body; } const Expr* initializer = t->initializer(); if (initializer) { - buf_initializers_[t->func_var()] = initializer; + buf_initializers_[t->buf()] = initializer; } + std::vector indices(t->args().begin(), t->args().end()); - for (size_t i = 0; i < f->ndim(); i++) { - // Going in reverse order: from innermost loop to the outermost - size_t dim_index = f->ndim() - i - 1; - body = new For(f->arg(dim_index), new IntImm(0), f->dim(dim_index), body); - indices.pop_back(); - if (initializer && indices.size() == t->ndim()) { + if (t->reduce_ndim() > 0) { + for (size_t i = 0; i < t->reduce_ndim(); i++) { + // Going in reverse order: from innermost loop to the outermost + size_t dim_index = t->reduce_ndim() - i - 1; + body = new For( + t->reduce_arg(dim_index), + new IntImm(0), + t->reduce_dim(dim_index), + body); + } + if (initializer) { Store* init = new Store(t->buf(), indices, initializer, new IntImm(1)); body = new Block({init, body}); } } + + for (size_t i = 0; i < t->ndim(); i++) { + // Going in reverse order: from innermost loop to the outermost + size_t dim_index = t->ndim() - i - 1; + body = new For(t->arg(dim_index), new IntImm(0), t->dim(dim_index), body); + } return body; } class FunctionInliner : public IRMutator { public: - FunctionInliner(Store* producer) - : buf_(producer->buf()), producer_(producer) { + FunctionInliner(Store* producer, std::unordered_set outputs) + : buf_(producer->buf()), + producer_(producer), + outputs_(std::move(outputs)) { for (auto* i : producer->indices()) { const Var* index_var = dynamic_cast(i); if (index_var == nullptr) { @@ -493,26 +594,21 @@ class FunctionInliner : public IRMutator { // For the target function, insert the caller/callee pair into the replacement // mapping. const Expr* mutate(const FunctionCall* v) override { - Function* func = v->tensor()->function(); - const Buf* buf = v->tensor()->buf(); + const Tensor* t = v->tensor(); + const Buf* buf = t->buf(); if (buf != buf_) { return IRMutator::mutate(v); } - // TODO: Support multiple-output functions - if (func->func_vars().size() != 1) { - throw unimplemented_lowering(); - } - if (v->nparams() != buf->ndim()) { throw malformed_input( - "Buffer indexed access is inconsistent with its rank", v); + "Placeholder indexed access is inconsistent with its rank", v); } std::vector index_vars; - TORCH_INTERNAL_ASSERT(buf->ndim() == func->args().size()); + TORCH_INTERNAL_ASSERT(buf->ndim() == t->args().size()); for (size_t i = 0; i < buf->ndim(); i++) { - const Var* func_callee_arg = dynamic_cast(func->arg(i)); + const Var* func_callee_arg = dynamic_cast(t->arg(i)); const Expr* func_caller_param = v->param(i); auto iter = inline_mapping_.find(func_callee_arg); if (iter != inline_mapping_.end()) { @@ -572,7 +668,9 @@ class FunctionInliner : public IRMutator { // Remove the buffer write from the inlined function. Stmt* mutate(const Store* v) override { - if (v == producer_) { + // If the buf_ is in the outputs set, keep its statement intact. Otherwise, + // remove it. + if (v == producer_ && !outputs_.count(buf_)) { in_producer_ = true; producer_ = dynamic_cast(IRMutator::mutate(v)); TORCH_INTERNAL_ASSERT(producer_ != nullptr); @@ -637,23 +735,18 @@ class FunctionInliner : public IRMutator { // In the producer's scope - we need to bind any calls to rand(). bool in_producer_ = false; std::unordered_map> random_bindings_; + std::unordered_set outputs_; }; -void LoopNest::computeInline(Stmt* s) { +bool LoopNest::computeInline(Stmt* s) { auto* s_store = dynamic_cast(s); if (s_store == nullptr) { throw std::logic_error("Could not find buffer producer to inline"); } - computeInline(s_store->buf()); + return computeInline(s_store->buf()); } -void LoopNest::computeInline(const Buf* b) { - for (auto* t : output_tensors_) { - if (b == t->buf()) { - throw std::logic_error("Can't inline producers of output Tensors"); - } - } - +bool LoopNest::computeInline(const Buf* b) { // Find producers. Store* relevant_store{nullptr}; auto stores = NodeFinder::find(root_stmt_); @@ -661,38 +754,87 @@ void LoopNest::computeInline(const Buf* b) { if (s->buf() == b) { auto reductions = NodeFinder::find(s); if (!reductions.empty()) { - throw std::logic_error("cannot inline a reduction computation"); + // Cannot inline a reduction computation + return false; } if (relevant_store != nullptr) { - throw std::logic_error("cannot inline Buf with multiple Tensors"); + // Cannot inline Buf with multiple Tensors + return false; } relevant_store = s; } } + TORCH_INTERNAL_ASSERT(relevant_store); - FunctionInliner inliner(relevant_store); + FunctionInliner inliner(relevant_store, output_bufs_); root_stmt_ = root_stmt_->accept_mutator(&inliner); // No longer computing this intermediate tensor, so don't alloc it. - for (auto* t : intermediate_tensors_) { - if (b == t->buf()) { - intermediate_tensors_.erase(t); - break; + intermediate_bufs_.erase(b); + return true; +} + +// inlining buffers with multiple uses can create duplicated work, which can +// slow down cpu code generation but is enabled on gpu because it avoids +// difficult synchronization logic across blocks. Inlining trivial reads does +// not duplicate work +void LoopNest::inlineIntermediateBufs(bool allow_duplicated_work) { + // We need to collect all intermediate buffers as the buffers to be inlined + // before calling 'computeInline' since the buffers that are inlined are + // erased from the set 'intermediate_bufs_' in that function. + std::unordered_set bufs_to_inline; + + if (allow_duplicated_work) { + bufs_to_inline.insert(intermediate_bufs_.begin(), intermediate_bufs_.end()); + } else { + FunctionCallUseCount fcu; + auto function_call_uses = fcu.findUses(root_stmt_); + auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_); + auto input_bufs = getInputBufs(); + + for (auto buf : intermediate_bufs_) { + TORCH_INTERNAL_ASSERT(buf_load_store_uses.count(buf)); + std::vector& uses = buf_load_store_uses[buf]; + auto stores = c10::filter( + uses, [](const BufLoadOrStoreUse& use) { return use.isStore; }); + + // if the intermediate is the buffer formed from reading in the input + // tensors, always inline, bc we are not duplicating any work + // and avoiding an intermediary buffer + if (stores.size() == 1) { + auto store = dynamic_cast(stores[0].s); + auto input_as_load = dynamic_cast(store->value()); + if (input_as_load && input_bufs.count(input_as_load->buf())) { + bufs_to_inline.insert(buf); + continue; + } + } + + // all bufs will have at least one store (if they have > 1 they cant be + // inlined anyway) + size_t reads = uses.size() - 1; + size_t function_call_reads = function_call_uses[buf]; + // if only one read, we can inline it without duplicating work + if ((reads + function_call_reads) <= 1) { + bufs_to_inline.insert(buf); + } } } - for (auto it = temp_bufs_.begin(); it != temp_bufs_.end(); ++it) { - if (b == *it) { - temp_bufs_.erase(it); - break; - } + if (allow_duplicated_work) { + bufs_to_inline.insert(output_bufs_.begin(), output_bufs_.end()); + } + + for (auto b : bufs_to_inline) { + computeInline(b); } } // TODO: Unify with DepTracker -class UseFinder : public IRVisitor { +class LoadOrStoreUseFinder : public IRVisitor { public: - std::unordered_map> findUses(Stmt* s) { + std::unordered_map> findUses( + Stmt* s) { uses_.clear(); s->accept(this); return uses_; @@ -714,15 +856,16 @@ class UseFinder : public IRVisitor { } Stmt* last_stmt_ = nullptr; - std::unordered_map> uses_; + std::unordered_map> uses_; // Sets of loads and stores in order to keep the results unique std::unordered_map> loads_; std::unordered_map> stores_; }; -std::unordered_map> findUses(Stmt* s) { - UseFinder uf; +std::unordered_map> +findLoadOrStoreUses(Stmt* s) { + LoadOrStoreUseFinder uf; return uf.findUses(s); } @@ -748,7 +891,7 @@ class ContainedStmtsFinder : public IRVisitor { std::unordered_set contained_; }; -bool containsAll(const std::vector& uses, Block* b) { +bool containsAll(const std::vector& uses, Block* b) { std::unordered_set not_found; for (auto use : uses) { not_found.insert(use.s); @@ -772,7 +915,7 @@ Block* findParentBlock(Stmt* s) { return nullptr; } -Block* findLowestContainingBlock(const std::vector& uses) { +Block* findLowestContainingBlock(const std::vector& uses) { // TODO: we're not using the most efficient algorithm here for simplicity. // Replace with something more performant in case it becomes a bottleneck. Block* b = findParentBlock(uses[0].s); @@ -783,9 +926,7 @@ Block* findLowestContainingBlock(const std::vector& uses) { } Stmt* LoopNest::insertAllocFree(Stmt* stmt) { - // Add allocs and frees for intermediate buffers at the global level. - // TODO: move allocs and frees to the imemediate areas to reuse buffers. - if (intermediate_tensors_.size() == 0ULL && temp_bufs_.size() == 0ULL) { + if (intermediate_bufs_.size() == 0ULL) { return stmt; } @@ -794,34 +935,83 @@ Stmt* LoopNest::insertAllocFree(Stmt* stmt) { b = new Block({stmt}); } - // TODO: Fix the traversal, currently the order is non-deterministic - for (Tensor* tensor : intermediate_tensors_) { - if (output_tensors_.count(tensor) > 0) { - // No need to allocate memory if the tensors are given as input/output. - continue; + std::unordered_map> uses = + findLoadOrStoreUses(stmt); + // Insert allocations and frees for temporary buffers in the innermost + // possible scope. + for (const Buf* buf : intermediate_bufs_) { + const Expr* flat_size = new IntImm(1); + for (auto& d : buf->dims()) { + flat_size = new Mul(flat_size, d); } - Stmt* alloc = new Allocate( - tensor->buf()->base_handle(), tensor->body()->dtype(), tensor->dims()); - Stmt* free = new Free(tensor->buf()->base_handle()); - b->prepend_stmt(alloc); - b->append_stmt(free); - } - - // Now insert allocations and frees for temporary buffers. Do that in the - // innermost possible scope. - std::unordered_map> uses = findUses(stmt); - - for (const auto& buf : temp_bufs_) { - Stmt* alloc = new Allocate(buf->base_handle(), buf->dtype(), buf->dims()); + flat_size = IRSimplifier::simplify(flat_size); + Stmt* alloc = new Allocate(buf->base_handle(), buf->dtype(), {flat_size}); Stmt* free = new Free(buf->base_handle()); - Block* alloc_block = findLowestContainingBlock(uses.at(buf)); alloc_block->prepend_stmt(alloc); alloc_block->append_stmt(free); } + return b; } +class StmtDeleter : public IRMutator { + public: + StmtDeleter(const std::unordered_set& targets) + : targets_(targets) {} + + private: + Stmt* mutate(const Block* v) override { + std::vector stmts; + + for (auto* s : v->stmts()) { + if (targets_.count(s) == 0) { + Stmt* ns = s->accept_mutator(this); + if (ns) { + stmts.push_back(Stmt::clone(ns)); + } + } + } + + return Block::make(stmts); + } + + const std::unordered_set& targets_; +}; + +void LoopNest::eliminateDeadStores() { + using namespace analysis; + MemDependencyChecker checker(getInputBufs(), getOutputBufs()); + root_stmt_->accept(&checker); + + std::unordered_set deadStores; + std::vector> outputAccesses; + for (auto* o : getOutputBufs()) { + outputAccesses.push_back(checker.output(o)); + } + + for (auto& info : checker.getHistory()) { + if (!info->isWrite()) { + continue; + } + bool found = false; + + for (auto& output : outputAccesses) { + if (checker.dependsIndirectly(output, info)) { + found = true; + break; + } + } + + if (!found) { + deadStores.insert(info->stmt()); + } + } + + StmtDeleter deleter(deadStores); + root_stmt_ = root_stmt_->accept_mutator(&deleter); +} + void LoopNest::prepareForCodegen() { // Expand reduction ops. ReductionExpander reduceExpander; @@ -984,6 +1174,11 @@ void LoopNest::sliceTail(For* f, int factor, For** head, For** tail) { // TODO: record history of transformations } +void LoopNest::splitWithTail(For* f, int factor) { + For *outer, *inner, *tail; + splitWithTail(f, factor, &outer, &inner, &tail); +} + void LoopNest::splitWithTail( For* f, int factor, @@ -1049,8 +1244,11 @@ void LoopNest::splitWithTail( } else { *tail = nullptr; } +} - // TODO: record history of transformations +void LoopNest::splitWithMask(For* f, int factor) { + For *outer, *inner; + splitWithMask(f, factor, &outer, &inner); } void LoopNest::splitWithMask(For* f, int factor, For** outer, For** inner) { @@ -1110,8 +1308,6 @@ void LoopNest::splitWithMask(For* f, int factor, For** outer, For** inner) { // TODO: cleanup API for adding/removing statements p->replace_stmt(f, *outer); - - // TODO: record history of transformations } For* findOuterFor(For* a, For* b) { @@ -1312,6 +1508,89 @@ void LoopNest::normalize(For* f, For** normalized) { p->replace_stmt(f, *normalized); } +// This function expects that there are 'num' loops perfectly nested within +// and including 'f'. +std::vector LoopNest::getLoopStmtsInLoopNest(For* f, size_t num) { + std::vector loops(num); + For* curr_for = f; + loops[0] = curr_for; + for (size_t i = 1; i < num; ++i) { + TORCH_INTERNAL_ASSERT(curr_for->body()->nstmts() == 1); + curr_for = dynamic_cast(curr_for->body()->front()); + TORCH_INTERNAL_ASSERT(curr_for); + loops[i] = curr_for; + } + return loops; +} + +bool LoopNest::flatten(const std::vector& loops, For** flattened) { + if (loops.empty()) { + throw malformed_input("flatten attempted on empty set of loops"); + } + Block* p = dynamic_cast(loops[0]->get_parent()); + if (!p) { + throw malformed_input("flatten attempted on loops with no parent"); + } + + if (loops.size() == 1) { + // This loop nest is already flattened. + *flattened = loops[0]; + return false; + } + + // Check if all the loops correspond to a perfect loopnest: + // * every loop except the inner-most should have only one stmt, the For. + // Do not flatten, otherwise. + // This check also ensures we do not flatten reduction loops. + for (size_t i = 0; i < loops.size() - 1; ++i) { + if ((loops[i]->body()->nstmts() != 1) || + (loops[i]->body()->front() != loops[i + 1])) { + *flattened = loops[0]; + return false; + } + } + + // Normalize the loops before flattening. + // We need to normalize them from inner-most to outer because once the outer + // loop is normalized, the given pointers to inner loops point to old code. + // For the same reason, we can't store the normalized inner loops until after + // the outer-most loop is normalized. + For* normalized; + for (size_t i = 0; i < loops.size(); ++i) { + size_t idx = loops.size() - i - 1; + LoopNest::normalize(loops[idx], &normalized); + } + + // 'normalized' points to the outer-most loop in the normalized loopnest. + // Collect all the normalized loops. + auto normalized_loops = getLoopStmtsInLoopNest(normalized, loops.size()); + + auto flat_var = new Var( + normalized_loops[0]->var()->name_hint() + "_flat", + normalized_loops[0]->var()->dtype()); + VarMapping var_mapping; + Expr* stop = new IntImm(1); + for (size_t i = 0; i < normalized_loops.size(); ++i) { + size_t idx = normalized_loops.size() - i - 1; + auto curr_loop = normalized_loops[idx]; + Expr* div = new Div(flat_var, stop); + Expr* sub_expr = idx == 0 ? div : new Mod(div, curr_loop->stop()); + var_mapping.push_back(std::make_pair(curr_loop->var(), sub_expr)); + stop = new Mul(curr_loop->stop(), stop); + } + auto flattened_body = + Substitute(Stmt::clone(normalized_loops.back()->body()), var_mapping); + + *flattened = new For( + flat_var, + new IntImm(0), + stop, + flattened_body, + normalized_loops[0]->loop_options()); + p->replace_stmt(normalized_loops[0], *flattened); + return true; +} + std::vector LoopNest::getLoopStmtsFor(Tensor* t) const { Stmt* cur_stmt = getLoopBodyFor(t); return getLoopStmtsFor(cur_stmt); @@ -1406,7 +1685,7 @@ class LoopComputeAtRewriter : public IRMutator { return new Load(v->dtype(), new_buf_, new_indices, v->mask()); } const Expr* mutate(const FunctionCall* v) override { - if (v->tensor()->func_var() != buf_) { + if (v->tensor()->buf() != buf_) { return v; } std::vector new_indices; @@ -1444,6 +1723,257 @@ static std::vector getOuterLoopIndexes(Stmt* s) { return res; } +class CacheReplacer : public IRMutator { + public: + CacheReplacer( + const Buf* buffer, + const Buf* cache, + std::vector& offsets) + : buf_(buffer), cache_(cache), offsets_(offsets) {} + + private: + const Expr* mutate(const FunctionCall* v) override { + const Buf* buf = v->tensor()->buf(); + if (buf != buf_) { + return IRMutator::mutate(v); + } + + // for reductions the size of tensor->args() is not equal to the size of the + // output buffer, but they should be ordered so that the output args are at + // the beginning even if the loops are reordered later. + // Map indices to call-parameters. + std::vector newIndices; + for (size_t i = 0; i < offsets_.size(); ++i) { + const Expr* index = v->param(i)->accept_mutator(this); + const Expr* offset = offsets_[i]; + const Expr* sub = IRSimplifier::simplify(new Sub(index, offset)); + newIndices.push_back(sub); + } + + return new Load(cache_, newIndices, new IntImm(1)); + } + + const Expr* mutate(const Load* v) override { + const Buf* buf = v->buf(); + if (buf != buf_) { + return IRMutator::mutate(v); + } + + // Map indices to call-parameters. + std::vector newIndices; + TORCH_INTERNAL_ASSERT(offsets_.size() == v->indices().size()); + for (size_t i = 0; i < v->indices().size(); ++i) { + const Expr* index = v->indices()[i]->accept_mutator(this); + const Expr* offset = offsets_[i]; + const Expr* sub = IRSimplifier::simplify(new Sub(index, offset)); + newIndices.push_back(sub); + } + + return new Load(cache_, newIndices, v->mask()); + } + + Stmt* mutate(const Store* v) override { + const Buf* buf = v->buf(); + if (buf != buf_) { + return IRMutator::mutate(v); + } + + const Expr* newValue = v->value()->accept_mutator(this); + + // Map indices to call-parameters. + std::vector newIndices; + TORCH_INTERNAL_ASSERT(offsets_.size() == v->indices().size()); + for (size_t i = 0; i < v->indices().size(); ++i) { + const Expr* index = v->indices()[i]->accept_mutator(this); + const Expr* offset = offsets_[i]; + const Expr* sub = IRSimplifier::simplify(new Sub(index, offset)); + newIndices.push_back(sub); + } + + return new Store(cache_, newIndices, newValue, v->mask()); + } + + const Expr* mutate(const ReduceOp* v) override { + const Buf* buf = v->accumulator(); + if (buf != buf_) { + return IRMutator::mutate(v); + } + + const Expr* newBody = v->body()->accept_mutator(this); + + // Map indices to call-parameters. + std::vector newIndices; + TORCH_INTERNAL_ASSERT(offsets_.size() == v->output_args().size()); + for (size_t i = 0; i < v->output_args().size(); ++i) { + const Expr* index = v->output_args()[i]->accept_mutator(this); + const Expr* offset = offsets_[i]; + const Expr* sub = IRSimplifier::simplify(new Sub(index, offset)); + newIndices.push_back(sub); + } + + return new ReduceOp( + cache_, newBody, newIndices, v->reduce_args(), v->reducer()); + } + + const Buf* buf_; + const Buf* cache_; + std::vector& offsets_; +}; + +LoopNest::AccessResult LoopNest::cacheAccesses( + const Buf* producer, + const std::string& name, + Stmt* consumer) { + ReduceOp* reduceOp{nullptr}; + auto reductions = NodeFinder::find(consumer); + for (auto* ro : reductions) { + if (ro->accumulator() != producer) { + continue; + } + + if (reduceOp) { + throw std::runtime_error( + "can only cache accesses used by at most a single reduceOp"); + return {nullptr, nullptr}; + } + + reduceOp = ro; + } + + // Check bounds but don't care about AccessKind. + auto consumer_bounds_info = inferBounds(consumer, false); + auto bounds_it = consumer_bounds_info.find(producer); + if (bounds_it == consumer_bounds_info.end()) { + throw std::runtime_error("consumer does not use the Tensor produced"); + return {nullptr, nullptr}; + } + + TORCH_INTERNAL_ASSERT(bounds_it->second.size() == 1); + TensorAccessBoundsInfo& info = bounds_it->second[0]; + bool hasReads = info.kind == kLoad || info.kind == kMutate; + bool hasWrites = info.kind == kStore || info.kind == kMutate; + + std::vector var_names = {"i", "j", "k", "l", "m", "n", "o", "p"}; + std::vector tmp_dims; + std::vector new_loop_vars; + std::vector new_loop_vars_expr; + + // Determine the size of the cache, and create a loop var for each dimension. + for (size_t i = 0; i < info.start.size(); ++i) { + const Expr* dim = IRSimplifier::simplify( + new Add(new Sub(info.stop[i], info.start[i]), new IntImm(1))); + + tmp_dims.push_back(dim); + + new_loop_vars.push_back(new Var(var_names[i % var_names.size()], kInt)); + new_loop_vars_expr.push_back(new_loop_vars[i]); + } + + // Create the var. + Buf* tmp_buf = new Buf(new Var(name, kHandle), tmp_dims, producer->dtype()); + + // determine the offsets for calls into the cache based off the loop start of + // each axis. + std::vector tmp_params; + for (size_t i = 0; i < new_loop_vars.size(); ++i) { + tmp_params.push_back(new Add(new_loop_vars[i], info.start[i])); + } + + // Replace acceses to the producer in the consumer with the cache. + CacheReplacer replacer(producer, tmp_buf, info.start); + Stmt* new_consumer = + IRSimplifier::simplify(consumer->accept_mutator(&replacer)); + + intermediate_bufs_.insert(tmp_buf); + + // replace the old consumer with the replaced consumer. + Block* consumer_block = nullptr; + // if the consumer is a block, we should mutate it in place. + if ((consumer_block = dynamic_cast(consumer))) { + consumer_block->clear(); + consumer_block->append_stmt(new_consumer); + } else { + consumer_block = dynamic_cast(consumer->get_parent()); + assert(consumer_block); + consumer_block->replace_stmt(consumer, new_consumer); + } + + // If there's a reduction we can't just write the result straight back to the + // original buffer, since after parallelism the writes will race. Instead we + // need to create a new ReduceOp. + if (reduceOp) { + // reduceOp means we had both loads and stores. + + // Init cache to 0. + Stmt* tmp_init = new Store( + tmp_buf, + new_loop_vars_expr, + getImmediateByType(tmp_buf->dtype(), 0), + new IntImm(1)); + + for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) { + tmp_init = + new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_init); + } + + consumer_block->insert_stmt_before(tmp_init, new_consumer); + + // Reduce back to the original buffer: + Stmt* tmp_store = new Store( + producer, + tmp_params, + reduceOp->reducer()( + producer, + ExprHandle(new Load(tmp_buf, new_loop_vars_expr, new IntImm(1))), + tmp_params, + {}), + new IntImm(1)); + + for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) { + tmp_store = + new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_store); + } + + consumer_block->insert_stmt_after(tmp_store, new_consumer); + + return std::make_pair(tmp_buf, new_consumer); + } + + if (hasReads) { + // Fill the cache with values from the consumer. + Stmt* tmp_store = new Store( + tmp_buf, + new_loop_vars_expr, + new Load(producer, tmp_params, new IntImm(1)), + new IntImm(1)); + + for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) { + tmp_store = + new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_store); + } + + consumer_block->insert_stmt_before(tmp_store, new_consumer); + } + + if (hasWrites) { + // sync the cache back to the producer buf. + Stmt* tmp_store = new Store( + producer, + tmp_params, + new Load(tmp_buf, new_loop_vars_expr, new IntImm(1)), + new IntImm(1)); + + for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) { + tmp_store = + new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_store); + } + + consumer_block->insert_stmt_after(tmp_store, new_consumer); + } + + return std::make_pair(tmp_buf, new_consumer); +} + /* * WHAT COMPUTE_AT DOES * ==================== @@ -1561,32 +2091,16 @@ void LoopNest::computeAt(Stmt* s, For* f) { // Infer bounds info for all accesses that we make in the loop auto loop_bounds_info = inferBounds(f->body()); - // store_bounds_info holds bounds info for the store we're trying to move to + // bounds_it holds bounds info for the store we're trying to move to // the loop. If its result isn't accessed in the loop at all - do nothing and // exit early. - TensorAccessBoundsInfo store_bounds_info; - bool found = false; - for (const auto& pair : loop_bounds_info) { - const Buf* buf = pair.first; - for (const TensorAccessBoundsInfo& p : pair.second) { - if (buf == st->buf()) { - store_bounds_info = p; - found = true; - } - } - } - if (!found) { + auto bounds_it = loop_bounds_info.find(st->buf()); + if (bounds_it == loop_bounds_info.end()) { return; } // Compute dimensions of the temp buffer we would need to allocate - std::vector dims; - for (size_t i = 0; i < store_bounds_info.start.size(); i++) { - const Expr* dim = IRSimplifier::simplify(new Add( - new Sub(store_bounds_info.stop[i], store_bounds_info.start[i]), - new IntImm(1))); - dims.push_back(dim); - } + std::vector dims = getBoundExtents(bounds_it->second); // TODO: Use name-hint of the producer instead of "temp" const Buf* temp_buf = new Buf("temp", dims, st->value()->dtype()); @@ -1607,11 +2121,23 @@ void LoopNest::computeAt(Stmt* s, For* f) { // might be different. In that case, the loop below would crash. std::vector prod_indices = getOuterLoopIndexes(s); std::vector> rewrite_indices_map; + std::vector offsets; + for (const TensorAccessBoundsInfo& p : bounds_it->second) { + for (size_t i = 0; i < p.start.size(); i++) { + if (offsets.size() <= i) { + offsets.push_back(p.start[i]); + } else { + offsets[i] = + IRSimplifier::simplify(new Min(offsets[i], p.start[i], true)); + } + } + } + for (size_t i = 0; i < prod_indices.size(); i++) { - const Expr* offset = store_bounds_info.start[i]; rewrite_indices_map.push_back( - {prod_indices[i], new Add(temp_indices[i], offset)}); + {prod_indices[i], new Add(temp_indices[i], offsets[i])}); } + // Construct the temp statement Stmt* bd = new Store( temp_buf, @@ -1635,7 +2161,7 @@ void LoopNest::computeAt(Stmt* s, For* f) { f->body()->prepend_stmt(bd); // Rewrite accesses to producer in consumer with accesses to temp - LoopComputeAtRewriter lr(st->buf(), temp_buf, store_bounds_info.start); + LoopComputeAtRewriter lr(st->buf(), temp_buf, offsets); Stmt* new_f = f->accept_mutator(&lr); if (f != new_f) { Block* bb = dynamic_cast(f->get_parent()); @@ -1644,7 +2170,7 @@ void LoopNest::computeAt(Stmt* s, For* f) { // Mark the new temp buffer as requiring an alloc (it will be inserted as a // part of prepareForCodegen). - temp_bufs_.emplace_back(temp_buf); + intermediate_bufs_.insert(temp_buf); } class SwapReduce : public IRMutator { @@ -1686,6 +2212,70 @@ class StoreFinder : public IRVisitor { const Store* store_; }; +class BufReplacer : public IRMutator { + public: + BufReplacer( + const Buf* old_buf, + const std::vector& old_indices, + const Buf* new_buf, + const std::vector& new_indices) + : old_buf_(old_buf), + old_indices_(old_indices), + new_buf_(new_buf), + new_indices_(new_indices) {} + + const Expr* mutate(const Load* v) override { + if (v->buf() != old_buf_) { + return IRMutator::mutate(v); + } + + TORCH_INTERNAL_ASSERT(old_indices_.size() == v->indices().size()); + + bool equal_indices = true; + for (size_t i = 0; i < v->indices().size(); ++i) { + if (!exprEquals(v->indices()[i], old_indices_[i])) { + equal_indices = false; + break; + } + } + if (!equal_indices) { + return IRMutator::mutate(v); + } + + const Expr* mask_new = v->mask()->accept_mutator(this); + return new Load(new_buf_, new_indices_, mask_new); + } + + Stmt* mutate(const Store* v) override { + if (v->buf() != old_buf_) { + return IRMutator::mutate(v); + } + + TORCH_INTERNAL_ASSERT(old_indices_.size() == v->indices().size()); + + bool equal_indices = true; + for (size_t i = 0; i < v->indices().size(); ++i) { + if (!exprEquals(v->indices()[i], old_indices_[i])) { + equal_indices = false; + break; + } + } + if (!equal_indices) { + return IRMutator::mutate(v); + } + + const Expr* new_value = v->value()->accept_mutator(this); + const Expr* mask_new = v->mask()->accept_mutator(this); + return new Store(new_buf_, new_indices_, new_value, mask_new); + } + + private: + const Buf* old_buf_; + const std::vector& old_indices_; + const Buf* new_buf_; + const std::vector& new_indices_; +}; + void LoopNest::rfactor( const Expr* r, const Var* reduction_var, @@ -1706,8 +2296,8 @@ void LoopNest::rfactor( For* root_for = nullptr; For* target_for = nullptr; - std::set reduce_args = {reduce_op->reduce_args().begin(), - reduce_op->reduce_args().end()}; + std::set reduce_args = { + reduce_op->reduce_args().begin(), reduce_op->reduce_args().end()}; // Store loops below the target point. std::vector output_loops; @@ -1758,7 +2348,7 @@ void LoopNest::rfactor( std::vector new_dims = {}; Buf* tmp_buf = - new Buf(new Var("tmp_buf", kHandle), new_dims, reduce_op->body().dtype()); + new Buf(new Var("tmp_buf", kHandle), new_dims, reduce_op->dtype()); auto old_acc = reduce_op->accumulator(); auto new_inner = reduce_op->reduce_args(); @@ -1786,26 +2376,19 @@ void LoopNest::rfactor( } new_outer.emplace_back(reduction_var); + BufReplacer bufReplacer( + reduce_op->accumulator(), reduce_op->output_args(), tmp_buf, new_outer); + const Expr* new_body = reduce_op->body()->accept_mutator(&bufReplacer); + auto first_reduce = new ReduceOp( - tmp_buf, - reduce_op->body(), - reduce_op->interaction(), - new_outer, - new_inner); + tmp_buf, new_body, new_outer, new_inner, reduce_op->reducer()); auto second_reduce_load_indices = reduce_op->output_args(); second_reduce_load_indices.emplace_back(reduction_var); - auto second_reduce_load = ExprHandle(new Load( - reduce_op->body().dtype(), - tmp_buf, - second_reduce_load_indices, - new IntImm(1))); - auto second_reduce = new ReduceOp( - old_acc, - second_reduce_load, - reduce_op->interaction(), - reduce_op->output_args(), - {reduction_var}); + auto second_reduce_load = new Load( + reduce_op->dtype(), tmp_buf, second_reduce_load_indices, new IntImm(1)); + auto second_reduce = reduce_op->reducer()( + old_acc, second_reduce_load, reduce_op->output_args(), {reduction_var}); // 1) replace target for loop (which is a reduction loop) // with an iterative for loop by removing the reduction var from the @@ -1891,38 +2474,9 @@ void LoopNest::rfactor( "Hit undefined behavior in rfactor -- couldn't infer bounds."); } - std::vector starts; - std::vector stops; - - // Find the safe size of the temprorary buffer by determining the outer - // extents of a union of all bounds. - for (const TensorAccessBoundsInfo& p : bounds_it->second) { - for (size_t i = 0; i < p.start.size(); i++) { - if (starts.size() <= i) { - starts.push_back(p.start[i]); - } else { - starts[i] = - IRSimplifier::simplify(new Min(starts[i], p.start[i], true)); - } - - if (stops.size() <= i) { - stops.push_back(p.stop[i]); - } else { - stops[i] = IRSimplifier::simplify(new Max(stops[i], p.stop[i], true)); - } - } - } - - std::vector tmp_dims; - for (size_t i = 0; i < starts.size(); ++i) { - const Expr* dim = IRSimplifier::simplify( - new Add(new Sub(stops[i], starts[i]), new IntImm(1))); - - tmp_dims.push_back(dim); - } - + std::vector tmp_dims = getBoundExtents(bounds_it->second); tmp_buf->set_dims(tmp_dims); - temp_bufs_.emplace_back(tmp_buf); + intermediate_bufs_.insert(tmp_buf); } } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h index 391bdbeb1c371..962d69f0458d6 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.h +++ b/torch/csrc/jit/tensorexpr/loopnest.h @@ -24,7 +24,22 @@ class Dtype; class TORCH_API LoopNest { public: + // A constructor for building a LoopNest from a list of Tensors LoopNest(const std::vector& output_tensors); + + // A constructor for building a LoopNest from a pre-baked Stmt and meta-info + // TODO: Nuke intermediate_bufs_ and possibly buf_initializers from here if + // they can be deduced. + LoopNest( + Stmt* stmt, + const std::unordered_set& output_bufs, + const std::unordered_set& intermediate_bufs, + const std::unordered_map& buf_initializers) + : root_stmt_(stmt), + output_bufs_(output_bufs), + intermediate_bufs_(intermediate_bufs), + buf_initializers_(buf_initializers) {} + Stmt* root_stmt() const { return root_stmt_; } @@ -34,18 +49,31 @@ class TORCH_API LoopNest { Stmt* getLoopBodyFor(Tensor*) const; bool hasLoopBodyFor(Tensor*) const; - void vectorize(Stmt*); + static void vectorize(For*); - void computeInline(Stmt* s); - void computeInline(const Buf* b); + bool computeInline(Stmt* s); + bool computeInline(const Buf* b); + void inlineIntermediateBufs(bool allow_duplicated_work); - void splitWithTail(For* f, int factor, For** outer, For** inner, For** tail); - void splitWithMask(For* f, int factor, For** outer, For** inner); + static void splitWithTail(For* f, int factor); + static void splitWithTail( + For* f, + int factor, + For** outer, + For** inner, + For** tail); + + static void splitWithMask(For* f, int factor); + static void splitWithMask(For* f, int factor, For** outer, For** inner); void reorderAxis(For* a, For* b); static void unroll(For* f, Stmt** unrolled); static void normalize(For* f, For** normalized); + static bool flatten(const std::vector& f, For** flattened); + + // Get 'num' loops from the loopnest starting at 'f'. + static std::vector getLoopStmtsInLoopNest(For* f, size_t num); // LoopOptions are propagated to tail. void sliceHead(For* f, int factor, For** head, For** tail); @@ -55,11 +83,21 @@ class TORCH_API LoopNest { void setGPUBlockIndex(For* f, int idx); void setGPUThreadIndex(For* f, int idx); + using AccessResult = std::pair; + // Insert a cache for the consumer's usages of the buffer produced in + // consumer, and redirect reads and writes in the consumer to that cache. + // Returns a pair of the new cache buffer, and the new rewritten consumer. + AccessResult cacheAccesses( + const Buf* producer, + const std::string& name, + Stmt* consumer); + // Insert a temporary computation of statement S in the scope of loop AT. // S is assumed to be a Store or a Block containing a Store. Along with the // computation itself, this transformation inserts Alloc/Free statements for // the temporary buffer used in the computation. void computeAt(Stmt* s, For* at); + void rfactor( const Expr* f, const Var* reduction_var, @@ -69,12 +107,20 @@ class TORCH_API LoopNest { For* f, const std::unordered_map& map); + void eliminateDeadStores(); void prepareForCodegen(); // Find the inner-most loops and vectorize them. Currently, this only works // for the LLVM backend, when no reductions are involved. void vectorizeInnerLoops(); + const std::unordered_set getInputBufs() { + return input_bufs_; + } + const std::unordered_set getOutputBufs() { + return output_bufs_; + } + private: std::vector findAllNeededTensors( const std::vector& tensors); @@ -83,9 +129,9 @@ class TORCH_API LoopNest { Stmt* root_stmt_; - std::unordered_set output_tensors_; - std::unordered_set intermediate_tensors_; - std::vector temp_bufs_; + std::unordered_set input_bufs_; + std::unordered_set output_bufs_; + std::unordered_set intermediate_bufs_; // Holds the initializer Expr of buffers that have been initialized. std::unordered_map buf_initializers_; }; @@ -95,7 +141,7 @@ TORCH_API Stmt* FlattenIndexes(Stmt* s); // TODO: Revisit this once we decide on how dependencies analysis should look // like. Maybe we would choose to use a different API and BufUse would be // removed, or if we decide to keep it we need to properly document its API. -struct BufUse { +struct BufLoadOrStoreUse { Stmt* s; bool isStore; }; @@ -106,7 +152,8 @@ struct BufUse { * in the vectors reflects the order in which the uses appear in the given * statement. */ -std::unordered_map> findUses(Stmt* s); +std::unordered_map> +findLoadOrStoreUses(Stmt* s); } // namespace tensorexpr } // namespace jit diff --git a/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp b/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp new file mode 100644 index 0000000000000..938af81a450db --- /dev/null +++ b/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp @@ -0,0 +1,1346 @@ +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { +namespace analysis { + +const char* AccessToString(AccessType a) { + switch (a) { + case AccessType::Input: + return "Input"; + case AccessType::Output: + return "Output"; + case AccessType::Load: + return "Load"; + case AccessType::Store: + return "Store"; + case AccessType::Call: + return "Call"; + case AccessType::AtomicAdd: + return "AtomicAdd"; + case AccessType::Alloc: + return "Alloc"; + case AccessType::Free: + return "Free"; + default: + break; + } + return "Unknown"; +} + +void getDependencyChain( + const std::shared_ptr& info, + DependencySet& dependencies) { + if (!dependencies.insert(info).second) { + return; + } + + for (auto& dep : info->dependencies()) { + getDependencyChain(dep.second, dependencies); + } +} + +void getDependentsChain( + const std::shared_ptr& info, + DependencySet& dependents) { + if (!dependents.insert(info).second) { + return; + } + + for (auto& dep : info->dependents()) { + getDependencyChain(dep.second, dependents); + } +} + +// AccessInfo + +std::vector AccessInfo::getIndices() const { + std::vector indices; + + if (expr_) { + if (auto* load = dynamic_cast(expr_)) { + indices = load->indices(); + } else if (auto* call = dynamic_cast(expr_)) { + indices = call->params(); + } else if (auto* reduce = dynamic_cast(expr_)) { + indices = reduce->output_args(); + } + } else { + if (auto* store = dynamic_cast(stmt_)) { + indices = store->indices(); + } + } + return indices; +} + +void AccessInfo::addDependency(const std::shared_ptr& write) { + auto res = dependencies_.emplace(write->id(), write); + TORCH_INTERNAL_ASSERT(res.second); +} + +void AccessInfo::addDependent(const std::shared_ptr& read) { + auto res = dependents_.emplace(read->id(), read); + TORCH_INTERNAL_ASSERT(res.second); +} + +bool AccessInfo::hasDependency(const std::shared_ptr& info) const { + return dependencies_.count(info->id()) != 0; +} + +DependencySet AccessInfo::getDirectDependencies() { + DependencySet res; + for (auto& depPair : dependencies_) { + res.insert(depPair.second); + } + return res; +} + +DependencySet AccessInfo::getIndirectDependencies() { + DependencySet res; + for (auto& depPair : dependencies_) { + getDependencyChain(depPair.second, res); + } + return res; +} + +DependencySet AccessInfo::getDirectDependents() { + DependencySet res; + for (auto& depPair : dependents_) { + res.insert(depPair.second); + } + return res; +} + +DependencySet AccessInfo::getIndirectDependents() { + DependencySet res; + for (auto& depPair : dependencies_) { + getDependentsChain(depPair.second, res); + } + return res; +} + +bool AccessInfo::isRead() const { + switch (type_) { + case AccessType::Output: + case AccessType::Load: + case AccessType::Call: + case AccessType::AtomicAdd: + return true; + default: + break; + } + return false; +} + +bool AccessInfo::isWrite() const { + switch (type_) { + case AccessType::Input: + case AccessType::Store: + case AccessType::AtomicAdd: + case AccessType::Alloc: + case AccessType::Free: + return true; + default: + break; + } + return false; +} + +void AccessInfo::print() const { + std::cout << id_ << ". " << AccessToString(type_) << ": " << *var_ << "["; + if (bounds_.size() > 0) { + for (size_t i = 0; i < bounds_.size() - 1; ++i) { + bounds_[i].print(); + std::cout << ", "; + } + + size_t i = bounds_.size() - 1; + bounds_[i].print(); + } + std::cout << "]"; + + if (!dependencies_.empty()) { + std::cout << " - depends on: "; + for (auto& pair : dependencies_) { + std::cout << pair.second->id() << " "; + } + } + + if (!dependents_.empty()) { + std::cout << " - dependents: "; + for (auto& pair : dependents_) { + std::cout << pair.second->id() << " "; + } + } + + std::cout << "\n"; +} + +void AccessInfo::dumpDOT(std::ostream& os) const { + if (type_ == AccessType::Input || type_ == AccessType::Output || + type_ == AccessType::Alloc) { + os << "n" << id_ << " [\n"; + os << "label = \"" << AccessToString(type_) << "\\n " << *var_ << "["; + if (bounds_.size() > 0) { + for (size_t i = 0; i < bounds_.size() - 1; ++i) { + os << *IRSimplifier::simplify(new Add(bounds_[i].end, new IntImm(1))) + << ", "; + } + + size_t i = bounds_.size() - 1; + os << *IRSimplifier::simplify(new Add(bounds_[i].end, new IntImm(1))); + os << "]\"\n "; + } + if (isWrite()) { + os << "\tshape = \"invhouse\"\n"; + } else { + os << "\tshape = \"house\"\n"; + } + } else { + os << "n" << id_ << " [\n"; + os << "label = \"" << AccessToString(type_) << " (#" << id_ << ")\\n"; + os << "buf : " << *var_ << "\\n"; + os << "bounds : \["; + if (bounds_.size() > 0) { + for (size_t i = 0; i < bounds_.size() - 1; ++i) { + os << "(" << *bounds_[i].start << ", " << *bounds_[i].end << "), "; + } + + size_t i = bounds_.size() - 1; + os << "(" << *bounds_[i].start << ", " << *bounds_[i].end << ")]"; + } + os << "\"\n"; + os << "\tshape = \"box\"\n"; + } + os << "\tstyle=\"filled\"\n"; + os << "\tcolor=\"" << AccessTypeColour() << "\"\n"; + std::string edgeColour; + if (isWrite()) { + edgeColour = "cornflowerblue"; + } else { + edgeColour = "goldenrod"; + } + os << "]\n"; + for (auto& pair : dependencies_) { + os << "n" << pair.second->id() << " -> " + << "n" << id_ << " [color=\"" << edgeColour << "\"]\n"; + } +} + +const char* AccessInfo::AccessTypeColour() const { + switch (type_) { + case AccessType::Input: + case AccessType::Output: + return "palegreen"; + case AccessType::Load: + return "peachpuff"; + case AccessType::Store: + return "dodgerblue"; + case AccessType::Call: + return "violet"; + case AccessType::Alloc: + case AccessType::Free: + return "sandybrown"; + default: + break; + } + return "white"; +} + +// MemDependencyChecker +// +MemDependencyChecker::MemDependencyChecker() { + currentScope_ = std::make_shared(nullptr, nullptr); +} + +MemDependencyChecker::MemDependencyChecker( + const std::unordered_set& inputs, + const std::unordered_set& outputs) { + for (auto* s : inputs) { + inputs_[s] = nullptr; + } + for (auto* s : outputs) { + outputs_[s] = nullptr; + } + + currentScope_ = std::make_shared(nullptr, nullptr); +} + +MemDependencyChecker::MemDependencyChecker( + const std::vector& inputs, + const std::vector& outputs) { + for (auto& s : inputs) { + inputs_[s.node()] = nullptr; + } + for (auto& s : outputs) { + outputs_[s.node()] = nullptr; + } + + currentScope_ = std::make_shared(nullptr, nullptr); +} + +bool MemDependencyChecker::allowLoopExecutionOrderAnalysis(bool allow) { + std::swap(allowExecutionOrderAnalysis_, allow); + return allow; +} + +const std::vector>& MemDependencyChecker:: + getHistory() const { + return currentScope_->accesses_; +} + +void MemDependencyChecker::dumpDAG(const std::string& filename) const { + std::ofstream dotfile(filename); + + dotfile << "digraph {\n"; + for (auto& wi : getHistory()) { + wi->dumpDOT(dotfile); + } + dotfile << "}\n"; + dotfile.close(); +} + +// dependsDirectly, dependsIndirectly and friends: + +DependencySet MemDependencyChecker::getAllWriteDependencies( + const DependencySet& products) { + DependencySet writes; + + for (auto& info : products) { + DependencySet dependencies; + getDependencyChain(info, dependencies); + for (auto& other : dependencies) { + if (other->isWrite()) { + writes.insert(other); + } + } + } + + return writes; +} + +bool MemDependencyChecker::dependsDirectly(const Expr* A, const Stmt* B) { + return dependsDirectlyHelper(A, B); +} + +bool MemDependencyChecker::dependsDirectly(const Stmt* A, const Stmt* B) { + return dependsDirectlyHelper(A, B); +} + +bool MemDependencyChecker::dependsDirectly(const Buf* O, const Stmt* B) { + auto outputAccess = output(O); + auto bWrites = getAllWritesWithin(B); + + for (auto& depPair : outputAccess->dependencies()) { + if (bWrites.count(depPair.second) != 0) { + return true; + } + } + + return false; +} + +bool MemDependencyChecker::dependsDirectly(const Stmt* A, const Buf* I) { + auto aReads = getAllReadsWithin(A); + auto inputAccess = input(I); + + for (auto& depPair : inputAccess->dependents()) { + if (aReads.count(depPair.second) != 0) { + return true; + } + } + + return false; +} + +bool MemDependencyChecker::dependsDirectly(const Expr* A, const Buf* I) { + auto aReads = getAllReadsWithin(A); + auto inputAccess = input(I); + + for (auto& depPair : inputAccess->dependents()) { + if (aReads.count(depPair.second) != 0) { + return true; + } + } + + return false; +} + +bool MemDependencyChecker::dependsDirectly( + const std::shared_ptr& A, + const std::shared_ptr& B) { + return A->hasDependency(B) && B->isWrite(); +} + +bool MemDependencyChecker::dependsIndirectly(const Expr* A, const Stmt* B) { + return dependsIndirectlyHelper(A, B); +} + +bool MemDependencyChecker::dependsIndirectly(const Stmt* A, const Stmt* B) { + return dependsIndirectlyHelper(A, B); +} + +bool MemDependencyChecker::dependsIndirectly(const Buf* O, const Stmt* B) { + auto outputAccess = output(O); + + DependencySet dependencies; + getDependencyChain(outputAccess, dependencies); + + auto bWrites = getAllWritesWithin(B); + for (auto& dep : dependencies) { + if (bWrites.count(dep) != 0) { + return true; + } + } + + return false; +} + +bool MemDependencyChecker::dependsIndirectly(const Stmt* A, const Buf* I) { + auto aReads = getAllReadsWithin(A); + auto inputAccess = input(I); + + auto aDeps = getAllWriteDependencies(aReads); + + return aDeps.count(inputAccess) != 0; +} + +bool MemDependencyChecker::dependsIndirectly(const Expr* A, const Buf* I) { + auto aReads = getAllReadsWithin(A); + auto inputAccess = input(I); + + auto aDeps = getAllWriteDependencies(aReads); + + return aDeps.count(inputAccess) != 0; +} + +bool MemDependencyChecker::dependsIndirectly(const Buf* O, const Buf* I) { + auto outputAccess = output(O); + auto inputAccess = input(I); + + return dependsIndirectly(outputAccess, inputAccess); +} + +bool MemDependencyChecker::dependsIndirectly( + const std::shared_ptr& A, + const std::shared_ptr& B) { + if (!B->isWrite()) { + return false; + } + + DependencySet dependencies; + getDependencyChain(A, dependencies); + if (dependencies.count(B) == 0) { + return false; + } + + return true; +} + +std::shared_ptr MemDependencyChecker::accessFor( + const Stmt* A) const { + auto bound = stmtToAccess_.equal_range(A); + for (auto it = bound.first; it != bound.second; ++it) { + if (it->second->expr() == nullptr) { + return it->second; + } + } + return nullptr; +} + +std::shared_ptr MemDependencyChecker::accessFor( + const Expr* A) const { + // TODO exprs can have multiple accesses... we're returning the first but that + // isn't great. Can't do much here. + auto bound = exprToAccess_.equal_range(A); + if (bound.first != exprToAccess_.end()) { + return bound.first->second; + } + + return nullptr; +} + +std::unordered_set> MemDependencyChecker:: + accessesWithin(const Stmt* A) const { + auto it = scopeToAccesses_.find(A); + if (it != scopeToAccesses_.end()) { + return std::unordered_set>( + it->second.begin(), it->second.end()); + } + + std::unordered_set> ret; + auto bound = stmtToAccess_.equal_range(A); + for (auto it = bound.first; it != bound.second; ++it) { + ret.insert(it->second); + } + return ret; +} + +std::unordered_set> MemDependencyChecker:: + accessesWithin(const Expr* A) const { + return {accessFor(A)}; +} + +std::shared_ptr MemDependencyChecker::input(const Buf* b) const { + auto it = inputs_.find(b); + if (it == inputs_.end()) { + return nullptr; + } + return it->second; +} + +std::shared_ptr MemDependencyChecker::output(const Buf* b) const { + auto it = outputs_.find(b); + if (it == outputs_.end()) { + return nullptr; + } + return it->second; +} + +// Node visitors: + +void MemDependencyChecker::visit(const Store* v) { + const Stmt* last = lastStmt_; + lastStmt_ = v; + v->value()->accept(this); + + for (const Expr* ind : v->indices()) { + ind->accept(this); + } + lastStmt_ = last; + + // Create a new AccessInfo for the store. + const Var* var = v->buf()->base_handle(); + auto info = std::make_shared( + nextAccess_++, AccessType::Store, v, var, getIndicesBounds(v->indices())); + + // Add a dependency to any accesses that are within the scope of this store + // (ie. the RHS). + auto bound = stmtToAccess_.equal_range(v); + for (auto it = bound.first; it != bound.second; ++it) { + info->addDependency(it->second); + it->second->addDependent(info); + } + + stmtToAccess_.emplace(v, info); + + // This write is open, and will close any open writes that it totally + // overlaps. + auto& history = currentScope_->openWrites_[var]; + updateWriteHistory(history, info, info->id()); + currentScope_->accesses_.push_back(info); +} + +void MemDependencyChecker::visit(const Load* v) { + // Create a temporary scope to hold any loads that occur within the indices of + // this load. + auto indicesScope = + std::make_shared(currentScope_->block, currentScope_); + currentScope_ = indicesScope; + + for (const Expr* ind : v->indices()) { + ind->accept(this); + } + + // Create a new AccessInfo for the load. + const Var* var = v->buf()->base_handle(); + auto load = std::make_shared( + nextAccess_++, + AccessType::Load, + v, + lastStmt_, + var, + getIndicesBounds(v->indices())); + + // If there were loads in the indices, this load depends on them, and merge + // them in. + if (!indicesScope->accesses_.empty()) { + for (auto& access : indicesScope->accesses_) { + load->addDependency(access); + access->addDependent(load); + } + mergeScope(indicesScope, indicesScope->parent, false); + } + + currentScope_ = indicesScope->parent; + + stmtToAccess_.emplace(lastStmt_, load); + exprToAccess_.emplace(v, load); + + // This is a read, and does not close any accesses - but we need to establish + // dependencies on accesses in the same scope. + // Intentionally using operator[], we want it to be created if it does not + // exist. + auto& writeHistory = currentScope_->openWrites_[var]; + updateWriteHistory(writeHistory, load, load->id()); + currentScope_->accesses_.push_back(load); +} + +void MemDependencyChecker::visit(const FunctionCall* v) { + // This is essentially the same as Load. + auto paramScope = + std::make_shared(currentScope_->block, currentScope_); + currentScope_ = paramScope; + + for (const Expr* param : v->params()) { + param->accept(this); + } + + const Var* var = v->tensor()->buf()->base_handle(); + auto call = std::make_shared( + nextAccess_++, + AccessType::Call, + v, + lastStmt_, + var, + getIndicesBounds(v->params())); + + // If there were loads in the parameters, this call depends on them, also + // merge. + if (!paramScope->accesses_.empty()) { + for (auto& access : paramScope->accesses_) { + call->addDependency(access); + access->addDependent(call); + } + mergeScope(paramScope, paramScope->parent, false); + } + + currentScope_ = paramScope->parent; + + stmtToAccess_.emplace(lastStmt_, call); + exprToAccess_.emplace(v, call); + + // Intentionally using operator[], we want it to be created if it does not + // exist. + auto& writeHistory = currentScope_->openWrites_[var]; + updateWriteHistory(writeHistory, call, call->id()); + currentScope_->accesses_.push_back(call); +} + +// This check determines if two accesses within a loop are "safe" from loop-self +// dependence. This function does not consider overlap in bound range, but +// rather the stride of the bound relative to the loop variable. This is the +// section of the code which considers iteration order, if allowed. +bool executionSafetyCheck( + const std::shared_ptr& info, + const std::shared_ptr& other, + const std::vector& aStrides, + const std::vector& oStrides, + bool parallelized) { + if (aStrides.empty() || oStrides.empty()) { + return false; + } + TORCH_INTERNAL_ASSERT(info->bounds().size() == other->bounds().size()); + for (size_t b = 0; b < info->bounds().size(); ++b) { + const Expr* aIndexStride = aStrides[b]; + const Expr* oIndexStride = oStrides[b]; + // can't be safe on this index if we can't determine stride. + if (!aIndexStride->isConstant() || !oIndexStride->isConstant()) { + continue; + } + + const Expr* minStride = + IRSimplifier::simplify(new Min(aIndexStride, oIndexStride, true)); + const Expr* maxStride = + IRSimplifier::simplify(new Max(aIndexStride, oIndexStride, true)); + + // If the first access has no stride don't apply safety). + if (immediateEquals(minStride, 0)) { + continue; + } + + const Expr* modCheck = + IRSimplifier::simplify(new Mod(maxStride, minStride)); + + // if the strides can't have easily inferable distinct offsets, they're not + // safe. + if (!immediateEquals(modCheck, 0)) { + continue; + } + + // If the loop has a defined execution order (ie. sequential for) then + // the order of execution can provide safety from overlaps. + // Specifically if the difference in first access position for any + // axis is the same sign as the common stride, then they will not + // overlap. + + const Expr* startDiff = IRSimplifier::simplify( + new Sub(info->bounds()[b].start, other->bounds()[b].start)); + + bool diffNegative = immediateIsNegative(startDiff); + bool strideNegative = immediateIsNegative(minStride); + + // Invert the startDiff so mod works. + if (diffNegative != strideNegative) { + startDiff = IRSimplifier::simplify(new Sub(new IntImm(0), startDiff)); + } + + // If both accesses have the same stride, and the difference in start + // element is smaller than this stride then the entire range is distinct. + if (exprEquals(minStride, maxStride)) { + const Expr* check1 = + IRSimplifier::simplify(new CompareSelect(startDiff, minStride, kLT)); + if (check1->isConstant() && immediateEquals(check1, 1)) { + return true; + } + } + + startDiff = IRSimplifier::simplify(new Mod(startDiff, minStride)); + + CompareSelectOperation op = strideNegative ? kLT : kGT; + + const Expr* check = + IRSimplifier::simplify(new CompareSelect(startDiff, new IntImm(0), op)); + + // If the start difference modulo the minimum stride is offset from that + // stride, then the ranges have distinct strides. + if (check->isConstant() && immediateEquals(check, 1)) { + return true; + } + + // If we can consider execution order and the difference in offset is + // opposite signed to the stride then the read occurs in the past and we can + // infer safety. + if (!parallelized && diffNegative == strideNegative && + immediateEquals(startDiff, 0)) { + return true; + } + } + + return false; +} + +void MemDependencyChecker::visit(const For* v) { + const Var* var = v->var(); + + const Stmt* last = lastStmt_; + lastStmt_ = v; + + v->var()->accept(this); + + // Loads inside the For's start and stop expression are special. + // They exist in the enclosing scope, but accesses within the loop body may + // depend on them via usage of the loop variable. + // The way we handle this is to create a new scope so we have an easily + // accessible list of the acceses within the extents. + auto extentsScope = + std::make_shared(currentScope_->block, currentScope_); + currentScope_ = extentsScope; + + v->start()->accept(this); + v->stop()->accept(this); + + currentScope_ = currentScope_->parent; + + auto newScope = std::make_shared(v->body(), currentScope_); + currentScope_ = newScope; + + v->body()->accept(this); + + lastStmt_ = last; + + // Ok now we need to determine whether accesses in the loop depend on + // other loop iterations. + // + // This is the real challenge here, it depends on both the fully expanded + // bounds and the symbolic bounds. + + // The indices must change monotonically to avoid intersection. This is + // hard to determine, so here's our heuristic I hope it's conservative + // enough. + + // the size of at least one dependent index must be >= the size of the + // loop. + + // First step is to infer the stride relative to each dimension of each + // access, which we do via substituting the loop var with (var+1) into the + // indices expr. + + std::vector> loopStrides; + loopStrides.resize(currentScope_->accesses_.size()); + + for (size_t a = 0; a < currentScope_->accesses_.size(); ++a) { + auto& info = currentScope_->accesses_[a]; + + std::vector indices = info->getIndices(); + + std::vector& loopIndicesStride = loopStrides[a]; + loopIndicesStride.resize(indices.size()); + + // index expr must depend on the loop var in some way to have a stride. + for (size_t i = 0; i < indices.size(); i++) { + VarFinder vf; + if (vf.find(indices[i]).count(var) == 0) { + loopIndicesStride[i] = new IntImm(0); + } else { + // If we've previously swapped the start and end of this bound, we + // should apply the substitution to the reverse of the bounds. + if (info->bounds()[i].swapped) { + info->bounds()[i].end = IRSimplifier::simplify( + Substitute(info->bounds()[i].end, {{var, v->start()}})); + info->bounds()[i].start = IRSimplifier::simplify(Substitute( + info->bounds()[i].start, + {{var, new Sub(v->stop(), new IntImm(1))}})); + + } else { + info->bounds()[i].start = IRSimplifier::simplify( + Substitute(info->bounds()[i].start, {{var, v->start()}})); + info->bounds()[i].end = IRSimplifier::simplify(Substitute( + info->bounds()[i].end, + {{var, new Sub(v->stop(), new IntImm(1))}})); + } + + const Expr* zeroStep = indices[i]; + const Expr* oneStep = + Substitute(indices[i], {{var, new Add(var, new IntImm(1))}}); + loopIndicesStride[i] = + IRSimplifier::simplify(new Sub(oneStep, zeroStep)); + + // If the start < end then swap the order of the bound. + const Expr* diff = IRSimplifier::simplify( + new Sub(info->bounds()[i].end, info->bounds()[i].start)); + if (diff->isConstant() && immediateIsNegative(diff)) { + info->bounds()[i].swap(); + } + + // If this access uses the loop var, it depends on loads used to compute + // the loop var. + for (auto& extentLoad : extentsScope->accesses_) { + info->addDependency(extentLoad); + extentLoad->addDependent(info); + } + } + } + } + + // Now we need to update the bounds in openWrites since that is what we use to + // merge. + for (auto& openWritePair : currentScope_->openWrites_) { + for (auto& pair : openWritePair.second) { + IndexBounds& bounds = pair.first; + + // The bounds may not contain the loop var, but in that case Substitute + // does nothing. + for (auto& bound : bounds) { + bound.start = IRSimplifier::simplify( + Substitute(bound.start, {{var, v->start()}})); + bound.end = IRSimplifier::simplify( + Substitute(bound.end, {{var, new Sub(v->stop(), new IntImm(1))}})); + + // If the start < end then swap the order of the bound. + const Expr* diff = + IRSimplifier::simplify(new Sub(bound.end, bound.start)); + if (diff->isConstant() && immediateIsNegative(diff)) { + bound.swap(); + } + } + } + } + + // TODO this isn't a scalable way to determine parallelism. + bool parallelized = v->loop_options().is_gpu_block_index() || + v->loop_options().is_gpu_thread_index(); + + // Store buffers allocated at this scope. + std::unordered_set local_intermediates; + + // Scanning from the top of the loop, we look for accesses which may depend + // on a previous or parallel loop iteration. + for (size_t a = 0; a < currentScope_->accesses_.size(); ++a) { + auto& info = currentScope_->accesses_[a]; + if (info->type() == AccessType::Alloc) { + local_intermediates.insert(info->var()); + continue; + } + + if (!info->isRead()) { + continue; + } + + // Vars that don't carry outside this scope can't have loop self dependence. + if (local_intermediates.count(info->var())) { + continue; + } + + // Copy the bounds so we can keep track of open bounds internally without + // affecting the merge into the enclosing scope. The open portion of the + // bounds may be cut into multiple independent slices. + std::vector openBounds({info->bounds()}); + + // Scan from the bottom of the loop. + for (size_t j = currentScope_->accesses_.size() - 1; j > a; --j) { + std::shared_ptr other = currentScope_->accesses_[j]; + if (!other->isWrite()) { + continue; + } + + if (info->var() != other->var()) { + continue; + } + + if (info->hasDependency(other)) { + continue; + } + + // Whether or not the accesses within the loop are dependent on other + // iterations depends whether the loop could be parallelized, the + // difference in their strides and their start offset. + bool iterationsDistinct = executionSafetyCheck( + info, + other, + loopStrides[a], + loopStrides[j], + !allowExecutionOrderAnalysis_ || parallelized); + + if (iterationsDistinct) { + continue; + } + + std::vector newBoundSlices; + for (auto& b : openBounds) { + OverlapKind overlap = overlaps(b, other->bounds()); + if (overlap == NoOverlap) { + newBoundSlices.push_back(b); + continue; + } + + // It's dependent, link it to other. + info->addDependency(other); + other->addDependent(info); + + if (overlap == Contains) { + continue; + } + + // Otherwise update openBounds. + auto slices = subtractIndicesBounds(b, other->bounds(), overlap); + std::move( + slices.begin(), slices.end(), std::back_inserter(newBoundSlices)); + } + + if (newBoundSlices.empty()) { + break; + } + openBounds.swap(newBoundSlices); + } + } + + std::vector> mergedAccesses; + mergedAccesses.reserve( + extentsScope->accesses_.size() + currentScope_->accesses_.size()); + std::copy( + extentsScope->accesses_.begin(), + extentsScope->accesses_.end(), + std::back_inserter(mergedAccesses)); + std::copy( + currentScope_->accesses_.begin(), + currentScope_->accesses_.end(), + std::back_inserter(mergedAccesses)); + scopeToAccesses_.emplace(v, mergedAccesses); + + // it's a little faster to merge without closing, and since no writes can + // occur within the start and stop exprs we'll do that. + mergeScope(extentsScope, extentsScope->parent, false); + mergeScope(currentScope_, currentScope_->parent, true); + currentScope_ = currentScope_->parent; +} + +void MemDependencyChecker::visit(const Cond* v) { + const Stmt* last = lastStmt_; + lastStmt_ = v; + + auto enclosingScope = + std::make_shared(currentScope_->block, currentScope_); + + // condition is in enclosing scope. + v->condition()->accept(this); + + Block* true_stmt = v->true_stmt(); + Block* false_stmt = v->false_stmt(); + + // Create scopes so the Block visitor doesn't create and merge a new scope. + auto trueScope = std::make_shared(true_stmt, enclosingScope); + auto falseScope = std::make_shared(false_stmt, enclosingScope); + + if (true_stmt) { + currentScope_ = trueScope; + true_stmt->accept(this); + } + + if (false_stmt) { + currentScope_ = falseScope; + false_stmt->accept(this); + } + + // TODO(nickg): this logic isn't quite correct, if a write's Bound range is + // present in both the true and false branches then we can close overlapping + // accesses in the enclosing scope. Without that analysis future accesses + // may be dependent on a write of a common range in all three of the + // enclosing, true and false scope. This is a false positve so not too bad + // in the short term, I think. + + // Merge both true and false branches into the parent, but don't close any + // accesses. + mergeScope(trueScope, enclosingScope, false); + mergeScope(falseScope, enclosingScope, false); + + // Merge the enclosing scope into it's parent. + mergeScope(enclosingScope, enclosingScope->parent, false); + + currentScope_ = enclosingScope; + scopeToAccesses_.emplace(v, enclosingScope->accesses_); + + currentScope_ = enclosingScope->parent; + lastStmt_ = last; +} + +void MemDependencyChecker::visit(const IfThenElse* v) { + // condition is in enclosing scope. + v->condition()->accept(this); + + const Expr* true_value = v->true_value(); + const Expr* false_value = v->false_value(); + + auto enclosingScope = currentScope_; + + // Create scopes to hold downstream Loads. It's safe to put nullptr for the + // Scope's Block as it is only used by Stmts, not Exprs. + auto trueScope = std::make_shared(nullptr, enclosingScope); + auto falseScope = std::make_shared(nullptr, enclosingScope); + + if (true_value) { + currentScope_ = trueScope; + true_value->accept(this); + } + + if (false_value) { + currentScope_ = falseScope; + false_value->accept(this); + } + + // This doesn't have the same issue as Cond where there could be false + // positives from the enclosing scope since there are no Exprs which are + // writes. + + // Merge both true and false branches into the parent, but don't close any + // accesses. + mergeScope(trueScope, enclosingScope, false); + mergeScope(falseScope, enclosingScope, false); + + currentScope_ = enclosingScope; +} + +void MemDependencyChecker::visit(const CompareSelect* v) { + // condition is in enclosing scope. + v->lhs()->accept(this); + v->rhs()->accept(this); + + const Expr* true_value = v->ret_val1(); + const Expr* false_value = v->ret_val2(); + + auto enclosingScope = currentScope_; + + // Create scopes to hold downstream Loads. It's safe to put nullptr for the + // Scope's Block as it is only used by Stmts, not Exprs. + auto trueScope = std::make_shared(nullptr, enclosingScope); + auto falseScope = std::make_shared(nullptr, enclosingScope); + + if (true_value) { + currentScope_ = trueScope; + true_value->accept(this); + } + + if (false_value) { + currentScope_ = falseScope; + false_value->accept(this); + } + + // This doesn't have the same issue as Cond where there could be false + // positives from the enclosing scope since there are no Exprs which are + // writes. + + // Merge both true and false branches into the parent, but don't close any + // accesses. + mergeScope(trueScope, enclosingScope, false); + mergeScope(falseScope, enclosingScope, false); + + currentScope_ = enclosingScope; +} + +// Inserts accesses for a map of buffers (ie. for inputs and outputs). +void MemDependencyChecker::insertBuffers( + std::unordered_map>& bufs, + AccessType type) { + for (auto& pair : bufs) { + const Buf* b = pair.first; + const Var* var = b->base_handle(); + IndexBounds bounds; + for (auto* d : b->dims()) { + bounds.push_back( + {new IntImm(0), IRSimplifier::simplify(new Sub(d, new IntImm(1)))}); + } + auto info = + std::make_shared(nextAccess_++, type, nullptr, var, bounds); + + bufs[b] = info; + + auto& history = currentScope_->openWrites_[var]; + updateWriteHistory(history, info, info->id()); + currentScope_->accesses_.push_back(info); + } +} + +void MemDependencyChecker::visit(const Block* v) { + auto prev_scope = currentScope_; + + // handle kernel inputs. + if (prev_scope->block == nullptr) { + insertBuffers(inputs_, AccessType::Input); + } + + if (currentScope_->block != v) { + currentScope_ = std::make_shared((Block*)v, prev_scope); + } + + for (auto* s : *v) { + s->accept(this); + } + + for (auto* v : currentScope_->localVars) { + knownVarBounds_.erase(v); + } + for (auto& pair : currentScope_->shadowedVarBounds) { + knownVarBounds_[pair.first] = pair.second; + } + + scopeToAccesses_.emplace(v, currentScope_->accesses_); + + if (currentScope_ != prev_scope) { + mergeScope(currentScope_, prev_scope, true); + currentScope_ = prev_scope; + } + + // handle kernel outputs. + if (prev_scope->block == nullptr) { + insertBuffers(outputs_, AccessType::Output); + } +} + +void MemDependencyChecker::visit(const Let* v) { + const Stmt* last = lastStmt_; + lastStmt_ = v; + + IRVisitor::visit(v); + + lastStmt_ = last; + + const Var* var = v->var(); + if (knownVarBounds_.count(var) != 0) { + currentScope_->shadowedVarBounds[var] = knownVarBounds_[var]; + } + + currentScope_->localVars.insert(var); + knownVarBounds_[var] = {v->value(), v->value()}; +} + +// Don't support AtomicAdd yet, it's a bit more complex since it's both a read +// and a write. It's only inserted during Cuda codegen so this should be okay. +void MemDependencyChecker::visit(const AtomicAdd* v) { + throw std::runtime_error("MemDependencyChecker AtomicAdd unimplemented"); +} + +void MemDependencyChecker::visit(const Allocate* v) { + const Stmt* last = lastStmt_; + lastStmt_ = v; + + IRVisitor::visit(v); + + const Var* var = v->buffer_var(); + IndexBounds bounds; + for (auto* d : v->dims()) { + bounds.push_back( + {new IntImm(0), IRSimplifier::simplify(new Sub(d, new IntImm(1)))}); + } + auto info = std::make_shared( + nextAccess_++, AccessType::Alloc, nullptr, var, bounds); + + intermediates_[var] = info; + + auto& history = currentScope_->openWrites_[var]; + history.emplace_back(std::make_pair(info->bounds(), info)); + currentScope_->accesses_.push_back(info); + + lastStmt_ = last; +} + +void MemDependencyChecker::visit(const Free* v) { + const Stmt* last = lastStmt_; + lastStmt_ = v; + + IRVisitor::visit(v); + + const Var* var = v->buffer_var(); + auto it = intermediates_.find(var); + TORCH_INTERNAL_ASSERT(it != intermediates_.end()); + + IndexBounds bounds = it->second->bounds(); + auto info = std::make_shared( + nextAccess_++, AccessType::Free, nullptr, var, bounds); + + auto& history = currentScope_->openWrites_[var]; + updateWriteHistory(history, info, info->id()); + currentScope_->accesses_.push_back(info); + + lastStmt_ = last; +} + +void MemDependencyChecker::updateWriteHistory( + std::list& writeHistory, + const std::shared_ptr& info, + size_t latestAccessToClose, + bool closeOverlapped, + bool insert) { + bool isWrite = info->isWrite(); + + for (auto it = writeHistory.begin(); it != writeHistory.end();) { + auto& indexBounds = it->first; + std::shared_ptr other = it->second; + if (info->hasDependency(other)) { + ++it; + continue; + } + + OverlapKind overlap = overlaps(indexBounds, info->bounds()); + + if (overlap == NoOverlap) { + ++it; + continue; + } + + // Only writes can close open accesses. + if (!isWrite) { + info->addDependency(other); + other->addDependent(info); + ++it; + continue; + } + + // If we're not closing accesses we can stop here. + if (!closeOverlapped || other->id() > latestAccessToClose) { + ++it; + continue; + } + + if (overlap == ContainedOrEqual) { + // Total overlap is easy - the new access totally replaces the old. + it = writeHistory.erase(it); + } else { + // The new write partially overlaps a previous write. We want to keep + // both, but only track the unconvered part of the earlier write. + + // Determine the slices of the earlier bound not covered by info. + auto newBounds = + subtractIndicesBounds(indexBounds, info->bounds(), overlap); + + // Erase the old slice. + it = writeHistory.erase(it); + + // Add all new slices. + for (auto& b : newBounds) { + it = writeHistory.insert(it, std::make_pair(b, other)); + } + it++; + } + } + + if (insert && isWrite) { + writeHistory.emplace_back(std::make_pair(info->bounds(), info)); + } +} + +void MemDependencyChecker::mergeScope( + const std::shared_ptr& child, + const std::shared_ptr& parent, + bool closeOverlapped) { + if (child->accesses_.empty()) { + return; + } + + // Update dependencies, but don't add new open writes yet. + for (auto& info : child->accesses_) { + // Intentionally using operator[], we want it to be created if it does not + // exist. + auto& writeHistory = parent->openWrites_[info->var()]; + + size_t latestAccessToClose = child->accesses_.front()->id(); + updateWriteHistory( + writeHistory, info, latestAccessToClose, closeOverlapped, false); + } + + // Copy open writes up. + for (auto& pair : child->openWrites_) { + const Var* var = pair.first; + + // Intentionally using operator[], we want it to be created if it does not + // exist. + auto& writeHistory = parent->openWrites_[var]; + + for (auto& rel : pair.second) { + writeHistory.push_back(rel); + } + } + + // the parent scope is responsible for holding all accesses now. + parent->accesses_.insert( + parent->accesses_.end(), + std::make_move_iterator(child->accesses_.begin()), + std::make_move_iterator(child->accesses_.end())); +} + +// A visitor which applies known Bounds to symbolic expressions. +class VarBoundBinder : public IRVisitor { + public: + VarBoundBinder(const VarBoundMap& vars) : vars_(vars) {} + + Bound getBounds(const Expr* e) { + min_ = e; + max_ = e; + e->accept(this); + min_ = IRSimplifier::simplify(min_); + max_ = IRSimplifier::simplify(max_); + return {min_, max_}; + } + + private: + void visit(const Var* v) override { + auto it = vars_.find(v); + if (it == vars_.end()) { + return; + } + + min_ = Substitute(min_, {{v, it->second.start}}); + max_ = Substitute(max_, {{v, it->second.end}}); + } + + const Expr* min_{nullptr}; + const Expr* max_{nullptr}; + const VarBoundMap& vars_; +}; + +std::vector MemDependencyChecker::getIndicesBounds( + const std::vector& indices) { + std::vector bounds; + bounds.reserve(indices.size()); + VarBoundBinder binder(knownVarBounds_); + for (auto* s : indices) { + bounds.push_back(binder.getBounds(s)); + } + return bounds; +} + +} // namespace analysis +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/mem_dependency_checker.h b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h new file mode 100644 index 0000000000000..745e6c4109284 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h @@ -0,0 +1,412 @@ +#pragma once +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { +namespace analysis { + +enum class AccessType { + Input, + Output, + Load, + Store, + Call, + AtomicAdd, + Alloc, + Free +}; +const char* AccessToString(AccessType a); + +class AccessInfo; +using DependencySet = std::unordered_set>; + +/* AccessInfo + * + * Represents a single bounded memory access to a buffer, for instance a Load or + * a Store. Holds infomation relating to the specific access and links to + * connected accesses in the dependency graph. + */ +class TORCH_API AccessInfo { + public: + AccessInfo( + size_t id, + AccessType type, + const Stmt* stmt, + const Var* var, + IndexBounds bounds) + : id_(id), + type_(type), + stmt_(stmt), + expr_(nullptr), + var_(var), + bounds_(bounds) {} + + AccessInfo( + size_t id, + AccessType type, + const Expr* expr, + const Stmt* stmt, + const Var* var, + IndexBounds bounds) + : id_(id), + type_(type), + stmt_(stmt), + expr_(expr), + var_(var), + bounds_(bounds) {} + + // Id is a unique int representing the order this access occured in the graph. + size_t id() const { + return id_; + } + + // The type of the access (Load, Store, etc). + AccessType type() const { + return type_; + } + + // The enclosing Stmt this access represents. E.g. if this is a Store then + // Stmt is the Store itself, while if the access is caused by an Expr, this is + // the most immediate parent Stmt. + const Stmt* stmt() const { + return stmt_; + } + + // If the access is represented by an Expr (such as Load or Call) then this is + // it, otherwise it's nullptr. + const Expr* expr() const { + return expr_; + } + + // The Var representing the underlying Buffer. + const Var* var() const { + return var_; + } + + // A vector of Bounds representing the start and end expression for each + // dimension. + IndexBounds& bounds() { + return bounds_; + } + + // Each access that this depends upon, + // eg. if this is a Load, then it contains every Store that immediately + // contributes to a load of the bounds. + // or: if this is a Store, it contains all reads on the RHS of the Store. + const std::map>& dependencies() const { + return dependencies_; + } + + // Each access that depends on this one. + // ie. this access is present in the dependencies map of all accesses that are + // dependent. + const std::map>& dependents() const { + return dependents_; + } + + // Returns the symbolic expression of the indices of this access. + std::vector getIndices() const; + + // Establishes a dependency or dependent relationship with another access. + void addDependency(const std::shared_ptr& write); + void addDependent(const std::shared_ptr& read); + + // helper for checking dependencies. + bool hasDependency(const std::shared_ptr& info) const; + + // Returns the set of all nodes that are direct (immediate) dependencies of + // this access. + DependencySet getDirectDependencies(); + // likewise, returns all nodes that directly depend on this one. + DependencySet getDirectDependents(); + + // Returns the full list of all nodes in the graph that this access depends + // on, and all nodes they depend on, and so forth, back to the inputs. + DependencySet getIndirectDependencies(); + // likewise, returns the full list of all nodes that depend on this node, and + // all nodes that depend on those nodes and so on down to the outputs. + DependencySet getIndirectDependents(); + + // Does this access represent a read of memory (Load, ReduceOp, Call, etc). + bool isRead() const; + // Does this access represent a write of memory (Store, etc). + bool isWrite() const; + + // Helpers for dumping accesses in various formats. + void print() const; + void dumpDOT(std::ostream& os) const; + const char* AccessTypeColour() const; + + private: + size_t id_; + AccessType type_; + const Stmt* stmt_; + const Expr* expr_; + const Var* var_; + IndexBounds bounds_; + + // Yes these should be sorted. + std::map> dependencies_; + std::map> dependents_; +}; + +using VarBoundMap = std::unordered_map; + +/* MemDepedencyChecker analyses a IR fragment and builds a dependency graph of + * accesses contained within. + * + * It's possible to retrieve the entire graph in node-object form, or can be + * used as an oracle for answering dependency questions. e.g: + * + * analyzer.hasIndirectDependency(BufA, BufB); or, + * analyzer.hasDirectDependency(LoadA, StoreB); + */ +class TORCH_API MemDependencyChecker : public IRVisitor { + struct Scope; + + public: + MemDependencyChecker(); + MemDependencyChecker( + const std::unordered_set& inputs, + const std::unordered_set& outputs); + MemDependencyChecker( + const std::vector& inputs, + const std::vector& outputs); + + virtual ~MemDependencyChecker() {} + + // Whether or not to allow loop execution order to influence dependency + // calculation. If the loop may later be parallelized you don't want this. + bool allowLoopExecutionOrderAnalysis(bool allow = true); + + // Dependency Checking API. + // The goal is to have enough overloads here so you don't really have to think + // about it. + + // Returns true if any read in A has a direct dependence on a write in B. + bool dependsDirectly(const Stmt* A, const Stmt* B); + bool dependsDirectly(const Expr* A, const Stmt* B); + + // Returns true of the output depends directly on a write contained in B. + bool dependsDirectly(const Buf* output, const Stmt* B); + + // Returns true if a read in A depends directly on the provided input. + bool dependsDirectly(const Stmt* A, const Buf* input); + bool dependsDirectly(const Expr* A, const Buf* input); + + // Outputs/inputs cannot depend directly. + + // Returns true if the access A has B as an immediate dependency. + bool dependsDirectly( + const std::shared_ptr& A, + const std::shared_ptr& B); + + // Returns true if any read in A has an ancestor write contained in B. + bool dependsIndirectly(const Stmt* A, const Stmt* B); + bool dependsIndirectly(const Expr* A, const Stmt* B); + + // Returns true of the output depends indirectly on a write contained in B. + bool dependsIndirectly(const Buf* output, const Stmt* B); + + // Returns true if a read in A depends indirectly on the provided input. + bool dependsIndirectly(const Stmt* A, const Buf* input); + bool dependsIndirectly(const Expr* A, const Buf* input); + + // returns true if the output uses any load of the input. + bool dependsIndirectly(const Buf* output, const Buf* input); + + // Returns true if the access A has a dependency chain to access B. + bool dependsIndirectly( + const std::shared_ptr& A, + const std::shared_ptr& B); + + // Returns the AccessInfo + std::shared_ptr accessFor(const Stmt* A) const; + std::shared_ptr accessFor(const Expr* A) const; + + // Returns all AccessInfos. + std::unordered_set> accessesWithin( + const Stmt* A) const; + // TODO: this will return only the AccessInfo for A. It's included for + // completeness but be aware it wont return accesses used in the computation + // of A. + std::unordered_set> accessesWithin( + const Expr* A) const; + + // Accesses relating to input and output buffers. + std::shared_ptr input(const Buf* B) const; + std::shared_ptr output(const Buf* B) const; + + // Returns the full history of reads and writes. + const std::vector>& getHistory() const; + + // Dumps the dependency graph in DOT format. + void dumpDAG(const std::string& filename) const; + + private: + // Node visitors. + void visit(const Store* v) override; + void visit(const Load* v) override; + void visit(const FunctionCall* v) override; + void visit(const For* v) override; + void visit(const Cond* v) override; + void visit(const IfThenElse* v) override; + void visit(const CompareSelect* v) override; + void visit(const Block* v) override; + void visit(const Let* v) override; + void visit(const AtomicAdd* v) override; + void visit(const Allocate* v) override; + void visit(const Free* v) override; + + using BoundRelationship = std::pair>; + + // An internal struct holding the accesses found within a scope Block. + struct Scope { + Scope(Block* b, std::shared_ptr p) : block(b), parent(p) {} + + Block* block; + std::shared_ptr parent; + + std::unordered_map shadowedVarBounds; + std::unordered_set localVars; + + std::vector> accesses_; + + std::unordered_map> openWrites_; + }; + std::shared_ptr currentScope_; + + bool allowExecutionOrderAnalysis_{false}; + + std::unordered_multimap> + stmtToAccess_; + std::unordered_multimap> + exprToAccess_; + std::unordered_map>> + scopeToAccesses_; + + VarBoundMap knownVarBounds_; + + // Finds all accesses that are reads within the scope of v. + template + DependencySet getAllReadsWithin(const StmtOrExpr* v) { + DependencySet reads; + auto insertAllReads = [&](const auto& nodes) { + for (auto* l : nodes) { + auto bound = exprToAccess_.equal_range(l); + for (auto it = bound.first; it != bound.second; ++it) { + if (it->second->isRead()) { + reads.insert(it->second); + } + } + } + }; + + // Look for and insert accesses belonging to all nodes that act like + // reads. + insertAllReads(NodeFinder::find(v)); + insertAllReads(NodeFinder::find(v)); + insertAllReads(NodeFinder::find(v)); + + return reads; + } + + // Finds all accesses that are writes within the scope of v. + // Writes cannot occur in Exprs, so this is a little simpler. + DependencySet getAllWritesWithin(const Stmt* v) { + DependencySet writes; + + // writes just Store currently. + auto stores = NodeFinder::find(v); + for (auto* s : stores) { + auto bound = stmtToAccess_.equal_range(s); + for (auto it = bound.first; it != bound.second; ++it) { + if (it->second->isWrite()) { + writes.insert(it->second); + } + } + } + return writes; + } + + // Templated helpers to work on either Exprs or Stmts. + template + bool dependsDirectlyHelper(const StmtOrExpr* A, const Stmt* B) { + auto aReads = getAllReadsWithin(A); + auto bWrites = getAllWritesWithin(B); + + for (auto& read : aReads) { + for (auto& depPair : read->dependencies()) { + if (bWrites.count(depPair.second) != 0) { + return true; + } + } + } + + return false; + } + + template + bool dependsIndirectlyHelper(const StmtOrExpr* A, const Stmt* B) { + auto aReads = getAllReadsWithin(A); + auto bWrites = getAllWritesWithin(B); + + auto aDeps = getAllWriteDependencies(aReads); + + for (auto& dependency : aDeps) { + if (bWrites.count(dependency) != 0) { + return true; + } + } + + return false; + } + + DependencySet getAllWriteDependencies(const DependencySet& products); + + // Maps for inputs and outputs, since they aren't present directly in the IR. + std::unordered_map> inputs_; + std::unordered_map> outputs_; + std::unordered_map> intermediates_; + + // Inserts accesses for Buf's: specifically for inputs and outputs. + void insertBuffers( + std::unordered_map>& bufs, + AccessType type); + + // Update the write history with a new write, adding dependencies and closing + // any overlapped writes (if possible). + void updateWriteHistory( + std::list& writeHistory, + const std::shared_ptr& info, + size_t latestAccessToClose, + bool closeOverlapped = true, + bool insert = true); + + // Merge a child scope into a parent scope, adding dependencies for open + // writes in the parent to accesses in the child. + void mergeScope( + const std::shared_ptr& child, + const std::shared_ptr& parent, + bool closeOverlapped = true); + + // Binds symbolic vars in indices with the low and high bound for those vars. + std::vector getIndicesBounds(const std::vector& indices); + + size_t nextAccess_{0}; + const Stmt* lastStmt_{nullptr}; +}; + +} // namespace analysis +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/reduction.cpp b/torch/csrc/jit/tensorexpr/reduction.cpp new file mode 100644 index 0000000000000..a3daeaa808a31 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/reduction.cpp @@ -0,0 +1,37 @@ + +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +ReduceOp* Reducer::operator()( + const Buf* result_buf, + ExprHandle body, + const std::vector& output, + const std::vector& inner) const { + return new ReduceOp( + result_buf, + complete(result_buf, interaction_, body, output, inner), + output, + inner, + *this); +} + +ReduceOp* Reducer::operator()( + const Buf* result_buf, + const Expr* body, + const std::vector& output, + const std::vector& inner) const { + return new ReduceOp( + result_buf, + complete(result_buf, interaction_, ExprHandle(body), output, inner), + output, + inner, + *this); +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/reduction.h b/torch/csrc/jit/tensorexpr/reduction.h index 1f2358d203ed6..40fd58b0cd18f 100644 --- a/torch/csrc/jit/tensorexpr/reduction.h +++ b/torch/csrc/jit/tensorexpr/reduction.h @@ -1,6 +1,5 @@ #pragma once -#include #include #include #include @@ -17,99 +16,40 @@ namespace tensorexpr { using ParameterList = const std::vector; using ReduceInteraction = std::function; -// An expression representing a Reduction operation (e.g. Sum, Max) broken into -// it's component parts: initialization, accumulation var, acquisition of value -// to be reduced and interaction. -// -// This is intended to be expanded in the loopnest and not make it to codegen. -class ReduceOp : public ExprNode { - public: - ReduceOp( - const Buf* accum, - ExprHandle body, - ReduceInteraction c, - const std::vector& output_args, - const std::vector& reduce_args) - : ExprNodeBase(body.dtype()), - accumulator_(accum), - body_(body), - interaction_(c), - output_args_(output_args), - reduce_args_(reduce_args) {} - - // return the accumulation load expression. - const Buf* accumulator() const { - return accumulator_; - } - - // return the body expression which obtains the value to be reduced. - ExprHandle body() const { - return body_; - } - - // returns a function encoding the interaction between accumulator and the - // reduction value. - ReduceInteraction interaction() const { - return interaction_; - } - - // returns variables associated with the output Tensor. - const std::vector& output_args() const { - return output_args_; - } - - // returns variables associated with the axes of reduction. - const std::vector& reduce_args() const { - return reduce_args_; - } - - // Completes the reduction operator by applying the interaction function to - // the accumulation and the body expression. - ExprHandle complete() const { - std::vector indices(output_args_.begin(), output_args_.end()); - ExprHandle accum = ExprHandle( - new Load(body_.dtype(), accumulator_, indices, new IntImm(1))); - auto e = interaction_(accum, body_); - return e; - } - - private: - const Buf* accumulator_; - ExprHandle body_; - ReduceInteraction interaction_; - std::vector output_args_; - std::vector reduce_args_; -}; - // A Reducer is a user interface describing a particular reduction // operation. It has three components: An initialization value, a way of // interacting each value with the accumulation, and a method for obtaining the // current value to be reduced. It is materialized into a ReduceOp when loop // variables are known. -class Reducer { +class TORCH_API Reducer { public: Reducer(ExprHandle init, ReduceInteraction& interaction) : init_(init.node()), interaction_(interaction) {} - Reducer(ExprHandle init, ReduceInteraction& interaction, Buffer& buf) + Reducer(ExprHandle init, ReduceInteraction& interaction, Placeholder& buf) : init_(init.node()), interaction_(interaction) {} template Reducer(ExprHandle init, RI interaction) : init_(init.node()) { interaction_ = interaction; } + virtual ~Reducer() {} const Expr* initializer() const { return init_; } ReduceOp* operator()( - Buf* result_buf, + const Buf* result_buf, ExprHandle body, - std::vector output, - std::vector inner) const { - return new ReduceOp(result_buf, body, interaction_, output, inner); - } + const std::vector& output, + const std::vector& inner) const; + + ReduceOp* operator()( + const Buf* result_buf, + const Expr* body, + const std::vector& output, + const std::vector& inner) const; // Polymorphic handling of Body functions with a variety of parameters. static ExprHandle getReduceBody( @@ -161,11 +101,78 @@ class Reducer { return func(vars[0], vars[1], vars[2], vars[3]); } + // Completes the reduction operator by applying the interaction function to + // the accumulation and the body expression. + static Expr* complete( + const Buf* accumulator, + ReduceInteraction interaction, + ExprHandle body, + const std::vector& output_args, + const std::vector& reduce_args) { + ExprHandle accum = ExprHandle( + new Load(body.dtype(), accumulator, output_args, new IntImm(1))); + auto e = interaction(accum, body); + return e.node(); + } + private: const Expr* init_; ReduceInteraction interaction_; }; +// An expression representing a Reduction operation (e.g. Sum, Max) broken into +// it's component parts: initialization, accumulation var, acquisition of value +// to be reduced and interaction. +// +// This is intended to be expanded in the loopnest and not make it to codegen. +class ReduceOp : public ExprNode { + public: + ReduceOp( + const Buf* accum, + const Expr* body, + const std::vector& output_args, + const std::vector& reduce_args, + const Reducer& reducer) + : ExprNodeBase(body->dtype()), + accumulator_(accum), + body_(body), + output_args_(output_args), + reduce_args_(reduce_args), + reducer_(reducer) {} + + // return the accumulation load expression. + const Buf* accumulator() const { + return accumulator_; + } + + // return the body expression which obtains the value to be reduced. + const Expr* body() const { + return body_; + } + + // Returns the original Reducer factory that can create ReduceOps. + const Reducer& reducer() const { + return reducer_; + } + + // returns variables associated with the output Tensor. + const std::vector& output_args() const { + return output_args_; + } + + // returns variables associated with the axes of reduction. + const std::vector& reduce_args() const { + return reduce_args_; + } + + private: + const Buf* accumulator_; + const Expr* body_; + std::vector output_args_; + std::vector reduce_args_; + const Reducer reducer_; +}; + class Sum : public Reducer { public: Sum() @@ -201,8 +208,8 @@ inline ExprHandle minimumVal(ScalarType type) { class Maximum : public Reducer { public: - // TODO possible to remove this arg by deferring the init value until we know - // the dtype of the body. + // TODO possible to remove this arg by deferring the init value until we + // know the dtype of the body. Maximum(Dtype dtype) : Reducer( minimumVal(dtype.scalar_type()), @@ -232,7 +239,7 @@ class ReductionExpander : public IRMutator { } const Expr* mutate(const ReduceOp* v) override { - return v->complete().node(); + return v->body(); } }; diff --git a/torch/csrc/jit/tensorexpr/registerizer.cpp b/torch/csrc/jit/tensorexpr/registerizer.cpp index 7181e9ec134a8..6e4383b09196d 100644 --- a/torch/csrc/jit/tensorexpr/registerizer.cpp +++ b/torch/csrc/jit/tensorexpr/registerizer.cpp @@ -3,6 +3,181 @@ namespace torch { namespace jit { namespace tensorexpr { +namespace registerizer { + +// AccessInfo + +void AccessInfo::addStore( + const Store* store, + const std::shared_ptr& scope) { + block_ = + block_ ? Block::getSharedParent(block_, scope->block()) : scope->block(); + + // If there is already a usage and it's this store, that means the same + // access is present in the RHS. + firstUsageOverlapped_ |= first_usage_ == store; + first_usage_ = first_usage_ ? block_->getEnclosedRoot(first_usage_) : store; + last_usage_ = store; + + store_cost_ = IRSimplifier::simplify(new Add(store_cost_, new IntImm(1))); + stores_.push_back(store); + + conditionId_ = scope->conditionId(); + hiddenAccess_.reset(); +} + +void AccessInfo::addLoad( + const Load* load, + const std::shared_ptr& scope, + const Stmt* usage) { + block_ = + block_ ? Block::getSharedParent(block_, scope->block()) : scope->block(); + first_usage_ = first_usage_ ? block_->getEnclosedRoot(first_usage_) : usage; + last_usage_ = usage; + + load_cost_ = IRSimplifier::simplify(new Add(load_cost_, new IntImm(1))); + loads_.push_back(load); + + conditionId_ = scope->conditionId(); + hiddenAccess_.reset(); +} + +void AccessInfo::merge(const std::shared_ptr& other) { + TORCH_INTERNAL_ASSERT(hash_ == other->hash()); + TORCH_INTERNAL_ASSERT(indices_.size() == other->indices().size()); + + last_usage_ = other->last_usage(); + for (auto* s : other->stores()) { + stores_.push_back(s); + } + for (auto* l : other->loads()) { + loads_.push_back(l); + } + + store_cost_ = + IRSimplifier::simplify(new Add(store_cost_, other->store_cost())); + load_cost_ = IRSimplifier::simplify(new Add(load_cost_, other->load_cost())); + + block_ = Block::getSharedParent(block_, other->block()); + // update first and last usage to be in the parent Block. + first_usage_ = block_->getEnclosedRoot(first_usage_); + last_usage_ = block_->getEnclosedRoot(last_usage_); + hiddenAccess_.reset(); +} + +bool AccessInfo::overlaps(const std::shared_ptr& other) { + // All accesses to a buf must have the same dimensionality. + TORCH_INTERNAL_ASSERT(indices_.size() == other->indices().size()); + + const auto& other_indices = other->indices(); + + // They don't overlap if there is a guaranteed difference in any + // dimension. + bool overlap = true; + for (size_t i = 0; i < indices_.size(); ++i) { + const Expr* diff = new Sub(indices_[i], other_indices[i]); + diff = IRSimplifier::simplify(diff); + + if (diff->isConstant() && !immediateEquals(diff, 0)) { + overlap = false; + break; + } + } + + return overlap; +} + +bool AccessInfo::dependsOnVar(const Var* v) { + VarFinder vf; + for (auto* i : indices_) { + i->accept(&vf); + } + + return vf.vars().count(v); +} + +std::shared_ptr AccessInfo::cloneWithHiddenInfo( + const std::shared_ptr& orig) { + std::shared_ptr newInfo = std::make_shared( + orig->hash(), orig->buf(), orig->indices(), orig->accessOrder()); + + newInfo->block_ = orig->block_; + newInfo->first_usage_ = orig->first_usage_; + newInfo->last_usage_ = orig->last_usage_; + newInfo->firstUsageOverlapped_ = orig->firstUsageOverlapped_; + newInfo->store_cost_ = orig->store_cost_; + newInfo->load_cost_ = orig->load_cost_; + for (auto* s : orig->stores_) { + newInfo->stores_.push_back(s); + } + for (auto* s : orig->loads_) { + newInfo->loads_.push_back(s); + } + + newInfo->conditionId_ = orig->conditionId_; + newInfo->hiddenAccess_ = orig; + return newInfo; +} + +void AccessInfo::print() const { + std::cout << "Access: " << *buf_ << "{"; + for (auto* i : indices_) { + std::cout << *i << " "; + } + std::cout << "} stores: " << stores_.size() << " (" << *store_cost_ << ") -"; + std::cout << " loads: " << loads_.size() << " (" << *load_cost_ << ")"; + if (conditionId_) { + std::cout << " cond: " << conditionId_; + } + + std::cout << "\n"; +} + +// Scope + +void Scope::closeAccess(const std::shared_ptr& info) { + closedAccesses_.push_back(info); +} + +AccessHashMap& Scope::getAccessMapByBuf(const Buf* b) { + auto it = openAccesses_.find(b); + if (it == openAccesses_.end()) { + // create and return + return openAccesses_[b]; + } + + return it->second; +} + +void Scope::filterClosed() { + closedAccesses_.erase( + std::remove_if( + closedAccesses_.begin(), + closedAccesses_.end(), + [](auto info) { + return info->store_cost()->isConstant() && + immediateAs(info->store_cost()) <= 1 && + info->load_cost()->isConstant() && + immediateAs(info->load_cost()) <= 1; + }), + closedAccesses_.end()); +} + +// RegisterizerAnalysis + +void RegisterizerAnalysis::closeAccessIntoScope( + const std::shared_ptr& info, + const std::shared_ptr& scope) { + if (exprConditionals_.count(info->conditionId()) != 0) { + return; + } + + if (info->hiddenAccess()) { + closeAccessIntoScope(info->hiddenAccess(), scope); + return; + } + scope->closeAccess(info); +} void RegisterizerAnalysis::visit(const For* v) { if (v->loop_options().is_gpu_block_index() || @@ -11,28 +186,196 @@ void RegisterizerAnalysis::visit(const For* v) { "Registerization must occur after parallelism flattening"); } - const Expr* old_loopCost = loopCost_; - loopCost_ = IRSimplifier::simplify( - new Mul(loopCost_, new Sub(v->stop(), v->start()))); + auto parent = currentScope_; + currentScope_ = std::make_shared(v->body(), parent); + + currentScope_->addLocalVar(v->var()); + stmtStack_.push_front(v); v->body()->accept(this); stmtStack_.pop_front(); - loopCost_ = old_loopCost; + const Expr* loopExtent = + IRSimplifier::simplify(new Sub(v->stop(), v->start())); + + // now we need to see which accesses we can hoist out of the for loop, their + // costs should be multiplied by the loop extent. + for (auto& pair : currentScope_->openAccesses()) { + const Buf* buf = pair.first; + if (pair.second.empty()) { + continue; + } + + auto& childAccesses = pair.second; + + for (auto it = childAccesses.begin(); it != childAccesses.end();) { + std::shared_ptr& candidate = it->second; + + // If the access is open, but conditional, then we have a problem. It's + // possible that an access at a higher scope could "unhide" the + // conditional access, in which case we need to hoist. If there is no + // access to this element at a higher scope then we cannot safely hoist. + // We cannot know at this level whether that will or wont occur. + // + // The solution we take here is to split the space-time continuum, and + // keep both versions of the access handy. If the hoisted access is not + // used above, we'll fall back to using the hidden, conditional + // AccessInfo - if it is, we'll delete the copy. + if (candidate->conditionId() != 0) { + candidate = AccessInfo::cloneWithHiddenInfo(candidate); + } + + bool closed = false; + // If this access depends on a locally scoped variable, it cannot be + // hosted out of the loop. + for (auto* v : currentScope_->localVars()) { + if (candidate->dependsOnVar(v)) { + closeAccessIntoScope(candidate, currentScope_); + closed = true; + break; + } + } + if (closed) { + it = childAccesses.erase(it); + continue; + } + + // hoist! + // By hoisting we pull the reads and writes out of the loop, and so the + // benefit of registerizing this access is multiplied by the loop extent. + candidate->setEnclosingBlock(parent->block()); + candidate->hoistCosts(loopExtent); + + // in the parent block, this loop Stmt is the insertion point for the + // initializer and finalizer. + candidate->setUsageMarks(v, v); + + ++it; + } + } + + // If an access is closed within a loop then it cannot be merged into an + // existing open access, but will still close that existing access. This is + // somewhat different from the regular merge so we need to handle closed + // accesses first. + mergeHiddenScope(true); + + // having hoisted, now we can merge normally. + mergeCurrentScopeIntoParent(); }; +void RegisterizerAnalysis::visit(const Cond* v) { + const Expr* condition = v->condition(); + Block* true_stmt = v->true_stmt(); + Block* false_stmt = v->false_stmt(); + + stmtStack_.push_front(v); + + // condition is in the enclosing scope. + condition->accept(this); + + auto prev_scope = currentScope_; + auto true_scope = + std::make_shared(true_stmt, prev_scope, ++conditionId_); + auto false_scope = + std::make_shared(false_stmt, prev_scope, ++conditionId_); + + if (true_stmt) { + currentScope_ = true_scope; + true_stmt->accept(this); + mergeHiddenScope(true); + mergeCurrentScopeIntoParent(); + } + if (false_stmt) { + currentScope_ = false_scope; + false_stmt->accept(this); + mergeHiddenScope(true); + mergeCurrentScopeIntoParent(); + } + + // TODO: even though both scopes are conditional, we can merge accesses if + // they totally overlap in both branches, since we can guarantee one + // definition will be hit. We might need a 3-way merge? Not as simple as + // merging the true and false scopes together first. + + stmtStack_.pop_front(); +} + +// IfThenElses are just like Conds except they are not Stmts, which means no +// registerization can occur internally. However, the first reference to an +// access can occur within one if its visible outside the condition. +void RegisterizerAnalysis::visit(const IfThenElse* v) { + const Expr* condition = v->condition(); + const Expr* true_value = v->true_value(); + const Expr* false_value = v->false_value(); + + // condition is in enclosing scope. + condition->accept(this); + + auto prev_scope = currentScope_; + auto true_scope = + std::make_shared(prev_scope->block(), prev_scope, ++conditionId_); + auto false_scope = + std::make_shared(prev_scope->block(), prev_scope, ++conditionId_); + + // We store IfThenElse scopes in a global map, which we use to prevent closing + // any access that would require inserting statements in the values, which + // cannot enclose Stmts. + exprConditionals_.insert(true_scope->conditionId()); + exprConditionals_.insert(false_scope->conditionId()); + + if (true_value) { + currentScope_ = true_scope; + true_value->accept(this); + mergeHiddenScope(false); + mergeCurrentScopeIntoParent(); + } + + if (false_value) { + currentScope_ = false_scope; + false_value->accept(this); + mergeHiddenScope(false); + mergeCurrentScopeIntoParent(); + } +} + +void RegisterizerAnalysis::visit(const Let* v) { + currentScope_->addLocalVar(v->var()); + + stmtStack_.push_front(v); + v->value()->accept(this); + stmtStack_.pop_front(); +} + void RegisterizerAnalysis::visit(const Block* v) { - const Block* last = enclosingBlock_; - enclosingBlock_ = v; + auto prev_scope = currentScope_; + if (currentScope_->block() != v) { + currentScope_ = std::make_shared(v, prev_scope); + } + stmtStack_.push_front(v); - costByBlock_[v] = loopCost_; - IRVisitor::visit(v); + + for (auto* s : *v) { + s->accept(this); + if (currentScope_->block() != v) { + // merge the inner block's accesses into this Block's accesses. + mergeCurrentScopeIntoParent(); + } + } + stmtStack_.pop_front(); - enclosingBlock_ = last; + + if (prev_scope->block() == nullptr) { + // close any open candidates. + for (auto& p1 : currentScope_->openAccesses()) { + for (auto& p2 : p1.second) { + closeAccessIntoScope(p2.second, currentScope_); + } + } + } } void RegisterizerAnalysis::visit(const Store* v) { - // path into value first. stmtStack_.push_front(v); v->value()->accept(this); stmtStack_.pop_front(); @@ -42,26 +385,49 @@ void RegisterizerAnalysis::visit(const Store* v) { return; } + // hash the Store: SimplifierHashType accessHash = hasher_.hash(v->buf()); for (auto* i : v->indices()) { accessHash = hasher_.hash_combine(accessHash, i); } accessHash = hasher_.hash_combine(accessHash, v->mask()); - std::shared_ptr info; - auto candidateIt = candidates_.find(accessHash); - if (candidateIt != candidates_.end()) { - info = candidateIt->second; - } else { - info = std::make_shared(v->buf(), v->indices()); - candidates_[accessHash] = info; - encounterOrder_.push_back(info); + auto& bufAccesses = currentScope_->getAccessMapByBuf(v->buf()); + auto candidateIt = bufAccesses.find(accessHash); + + // If an identical access already exists, add this Store to it. + if (candidateIt != bufAccesses.end()) { + candidateIt->second->addStore(v, currentScope_); + return; } - if (nested_conditions_ > 0) { - info->invalid = true; + // Otherwise make a new AccessInfo and add this store. + auto info = std::make_shared( + accessHash, v->buf(), v->indices(), accessOrder_++); + info->addStore(v, currentScope_); + + // This new access may overlap an existing open access, in which case we need + // to close the older of the two. + bool alreadyOverlapped = false; + for (auto it = bufAccesses.begin(); it != bufAccesses.end();) { + auto other = it->second; + if (info->overlaps(other)) { + if (other->last_usage() == v) { + // we are already overlapped by an access in the RHS. + alreadyOverlapped = true; + } + closeAccessIntoScope(other, currentScope_); + it = bufAccesses.erase(it); + } else { + ++it; + } + } + + if (alreadyOverlapped) { + closeAccessIntoScope(info, currentScope_); + } else { + bufAccesses.emplace(accessHash, info); } - info->addStore(v, enclosingBlock_, loopCost_); } void RegisterizerAnalysis::visit(const Load* v) { @@ -69,281 +435,360 @@ void RegisterizerAnalysis::visit(const Load* v) { // already a scalar. return; } - + // hash the Load: SimplifierHashType accessHash = hasher_.hash(v->buf()); for (auto* i : v->indices()) { accessHash = hasher_.hash_combine(accessHash, i); } accessHash = hasher_.hash_combine(accessHash, v->mask()); - std::shared_ptr info; - auto candidateIt = candidates_.find(accessHash); - if (candidateIt != candidates_.end()) { - info = candidateIt->second; - } else { - info = std::make_shared(v->buf(), v->indices()); - candidates_[accessHash] = info; - encounterOrder_.push_back(info); + auto& bufAccesses = currentScope_->getAccessMapByBuf(v->buf()); + auto candidateIt = bufAccesses.find(accessHash); + if (candidateIt != bufAccesses.end()) { + // found the right access, can just insert. + candidateIt->second->addLoad(v, currentScope_, stmtStack_.front()); + return; } - if (nested_conditions_ > 0) { - info->invalid = true; + std::shared_ptr info = std::make_shared( + accessHash, v->buf(), v->indices(), accessOrder_++); + info->addLoad(v, currentScope_, stmtStack_.front()); + + bool alreadyOverlapped = false; + // This new access may overlap an existing open access, in which case we need + // to finalize the older of the two. + for (auto it = bufAccesses.begin(); it != bufAccesses.end();) { + auto other = it->second; + if (info->overlaps(other)) { + if (info->last_usage() == other->last_usage()) { + // if these two accesses are from the same Stmt, they already overlap + // each other. + alreadyOverlapped = true; + } + closeAccessIntoScope(other, currentScope_); + it = bufAccesses.erase(it); + } else { + ++it; + } } - info->addLoad(v, enclosingBlock_, loopCost_, stmtStack_.front()); + if (alreadyOverlapped) { + closeAccessIntoScope(info, currentScope_); + } else { + bufAccesses.emplace(accessHash, info); + } } -void RegisterizerAnalysis::visit(const IfThenElse* v) { - v->condition()->accept(this); - nested_conditions_++; - v->true_value()->accept(this); - v->false_value()->accept(this); - nested_conditions_--; -} +// Loop and Conditional scopes are different in that it may or may not be +// possible to hoist the intializer of a scalar variable outside the block +// depending on if we can tell that the Buffer access is valid outside. This is +// tricky because the access that demonstrates this may be later in the tree and +// we haven't encountered it yet. +// The allowClosed flag indicates whether we want to keep the closed accesses +// (For and Cond), or not (IfThenElse). +void RegisterizerAnalysis::mergeHiddenScope(bool allowClosed) { + // The rule is that if any access is closed within the conditional block, any + // accesses which overlap it must also be closed - since their initializer + // cannot be hoisted out of the block. + std::list> newClosed; + for (auto& info : currentScope_->closedAccesses()) { + auto& candidates = currentScope_->getAccessMapByBuf(info->buf()); + for (auto it = candidates.begin(); it != candidates.end();) { + std::shared_ptr candidate = it->second; + + if (info->hash() == candidate->hash() || info->overlaps(candidate)) { + newClosed.push_back(candidate); + it = candidates.erase(it); + } else { + ++it; + } + } + } -void RegisterizerAnalysis::visit(const Cond* v) { - const Expr* condition = v->condition(); - Stmt* true_stmt = v->true_stmt(); - Stmt* false_stmt = v->false_stmt(); - condition->accept(this); + if (allowClosed) { + for (auto& info : newClosed) { + closeAccessIntoScope(info, currentScope_); + } + } else { + currentScope_->closedAccesses().clear(); + } +} - stmtStack_.push_front(v); - nested_conditions_++; +// Merge currentScope_ into it's parent, and make parent the new currentScope_. +void RegisterizerAnalysis::mergeCurrentScopeIntoParent() { + auto parent = currentScope_->parent(); - if (true_stmt) { - true_stmt->accept(this); - } - if (false_stmt) { - false_stmt->accept(this); - } + // copy across current closed accceses, merging / closing as necessary + for (auto& candidate : currentScope_->closedAccesses()) { + auto& parentAccesses = parent->getAccessMapByBuf(candidate->buf()); - nested_conditions_--; - stmtStack_.pop_front(); -} + auto parentIt = parentAccesses.find(candidate->hash()); + if (parentIt != parentAccesses.end()) { + std::shared_ptr pCandidate = parentIt->second; -std::vector> RegisterizerAnalysis::getCandidates() { - std::vector> ret; - - // Group accesses by the base buffer they refer to, so it's easier to - // determine which accesses may overlap. - std::unordered_map>> - access_by_buf; - for (const auto& pair : candidates_) { - std::shared_ptr info = pair.second; - - // We can "hoist" an access up the syntax tree if it's indices do not - // depend on any loop vars. - VarFinder vf; - for (auto* i : info->indices) { - i->accept(&vf); - } + // if the access is closed inside a condition, it can only be merged if + // the parent is in the same condition. + if (candidate->conditionId() && + pCandidate->conditionId() != candidate->conditionId()) { + // the parent's access must be closed. + closeAccessIntoScope(pCandidate, parent); + parentAccesses.erase(parentIt); - const Stmt* ancestor = info->parent; - const Stmt* target = nullptr; - while (ancestor) { - if (const For* f = dynamic_cast(ancestor)) { - if (vf.vars().count(f->var()) != 0) { - break; - } - target = f->get_parent(); + // the childs access inserted into the parent scope. + closeAccessIntoScope(candidate, parent); + continue; } - ancestor = ancestor->get_parent(); + // merge totally overlapping accesses. + parentIt->second->merge(candidate); + closeAccessIntoScope(parentIt->second, parent); + parentAccesses.erase(parentIt); + continue; } - if (info->parent != target) { - if (const Block* new_parent = dynamic_cast(target)) { - info->parent = new_parent; + // we didn't find a perfect match, but we need to check all open accesses of + // this buf for partial overlap. + for (auto it = parentAccesses.begin(); it != parentAccesses.end();) { + std::shared_ptr pCandidate = it->second; + // Partial overlap of parent access: close parent access. + if (candidate->overlaps(pCandidate)) { + closeAccessIntoScope(pCandidate, parent); + it = parentAccesses.erase(it); + continue; } + ++it; } - // Now that analysis is complete we must normalize the costs by the - // parent Block we plan to insert the scalar var into. - info->store_cost = IRSimplifier::simplify( - new Div(info->store_cost, costByBlock_[info->parent])); - - if (!info->loads.empty()) { - info->load_cost = IRSimplifier::simplify( - new Div(info->load_cost, costByBlock_[info->parent])); - } - - access_by_buf[info->buf].push_back(info); + // Insert the childs closed access into the parent scope. + closeAccessIntoScope(candidate, parent); } - // For each buffer, for each access, determine if another access to the - // buffer could possibly write to the same region. - for (const auto& pair : access_by_buf) { + // copy across current open accesses, merging as necessary. + // for each Buf with an open access: + for (auto& pair : currentScope_->openAccesses()) { const Buf* buf = pair.first; - const std::vector>& accesses = pair.second; - for (const auto& info : accesses) { - // Filter out low cost accesses. - if (info->store_cost->isConstant() && - immediateAs(info->store_cost) <= 1 && - info->load_cost->isConstant() && - immediateAs(info->load_cost) <= 1) { - info->invalid = true; - continue; - } + if (pair.second.empty()) { + continue; + } - // TODO: this is n^2 by the number of accesses to a single buffer - // program wide, may be an issue in large programs. - for (const auto& i2 : accesses) { - if (info == i2) { + auto& parentAccesses = parent->getAccessMapByBuf(buf); + + // for each open access in the child scope for this Buf: + for (auto& hpair : pair.second) { + bool handled{false}; + std::shared_ptr candidate = hpair.second; + + for (auto it = parentAccesses.begin(); it != parentAccesses.end();) { + std::shared_ptr pCandidate = it->second; + + // If it completely overlaps then merge. + if (candidate->hash() == pCandidate->hash()) { + // if both accesses are found in conditional blocks, they cannot be + // merged, but the earlier must be closed. + if (pCandidate->conditionId() != parent->conditionId() && + pCandidate->conditionId() != candidate->conditionId()) { + closeAccessIntoScope(pCandidate, parent); + it = parentAccesses.erase(it); + continue; + } + pCandidate->merge(candidate); + handled = true; + ++it; continue; } - // All accesses to a buf must have the same dimensionality. - assert(info->indices.size() == i2->indices.size()); - - // They don't overlap if there is a guaranteed difference in any - // dimension. - bool overlap = true; - for (size_t i = 0; i < info->indices.size(); ++i) { - const Expr* diff = new Sub(info->indices[i], i2->indices[i]); - diff = IRSimplifier::simplify(diff); - if (diff->isConstant() && !immediateEquals(diff, 0)) { - overlap = false; - break; - } + // It can overlap an access in the parent: close the parent access. + // The child access may still be open. + if (candidate->overlaps(pCandidate)) { + closeAccessIntoScope(pCandidate, parent); + it = parentAccesses.erase(it); + continue; } - if (overlap) { - info->invalid = true; + ++it; + } + + // If this access depends on a locally scoped variable, it cannot be + // lifted out of the loop. + for (auto* v : currentScope_->localVars()) { + if (candidate->dependsOnVar(v)) { + closeAccessIntoScope(candidate, parent); + handled = true; break; } } - } - } - // Return valid access candidates in the order they were first seen. - for (const auto& info : encounterOrder_) { - if (!info->invalid) { - ret.push_back(info); + if (!handled) { + // If the inner scope was not conditional, but the outer scope is: all + // current accesses are now conditional in the parent scope. + if (candidate->conditionId() == 0) { + candidate->setConditionId(parent->conditionId()); + } + parentAccesses[candidate->hash()] = candidate; + } } } - return ret; + currentScope_ = parent; } -const Expr* RegisterizerReplacer::mutate(const Load* v) { - if (v->buf() != info_->buf) { - return IRMutator::mutate(v); - } +std::vector> RegisterizerAnalysis::getCandidates() { + currentScope_->filterClosed(); + std::sort( + currentScope_->closedAccesses().begin(), + currentScope_->closedAccesses().end(), + [](auto i1, auto i2) { return i1->accessOrder() < i2->accessOrder(); }); + return currentScope_->closedAccesses(); +} - initializerReady_ = false; +// RegisterizerReplacer - // sanity check indices for the same buf must have the same dimensionality. - assert(v->indices().size() == info_->indices.size()); - for (size_t i = 0; i < info_->indices.size(); ++i) { - if (!exprEquals(v->indices()[i], info_->indices[i])) { - return IRMutator::mutate(v); - } +const Expr* RegisterizerReplacer::mutate(const Load* v) { + auto it = loadToAccess_.find(v); + if (it == loadToAccess_.end()) { + // This access cannot be registerized. + return v; } - return var_; + auto& info = it->second; + + return info->replacement().var; } Stmt* RegisterizerReplacer::mutate(const Store* v) { - if (v->buf() != info_->buf) { - return IRMutator::mutate(v); + if (eliminatedIntializers_.count(v) != 0) { + // This store is the intializer for a scalar var that is already inserted. + return nullptr; } - if (initializerReady_ && info_->parent == v->get_parent()) { - initializer_ = v; - initializerReady_ = false; - // This is the easiest way to return an empty statement; - return new Block({}); + auto it = storeToAccess_.find(v); + if (it == storeToAccess_.end()) { + // This access cannot be registerized. + return IRMutator::mutate(v); } - initializerReady_ = false; + auto& info = it->second; - // sanity check indices for the same buf must have the same dimensionality. - assert(v->indices().size() == info_->indices.size()); - for (size_t i = 0; i < info_->indices.size(); ++i) { - if (!exprEquals(v->indices()[i], info_->indices[i])) { - return IRMutator::mutate(v); - } - } const Expr* new_val = v->value()->accept_mutator(this); - Store* s = new Store(var_wrapper_, {}, new_val, v->mask()); - return s; + return new Store(info->replacement().var_wrapper, {}, new_val, v->mask()); } -// Finds the Stmt in parent which contains stmt. -const Stmt* RegisterizerReplacer::findInsertionPoint( - const Stmt* stmt, - const Block* parent) { - while (stmt) { - if (stmt->get_parent() == parent) { - return stmt; +Stmt* RegisterizerReplacer::mutate(const Block* v) { + auto& scope = parentToAccesses_[v]; + + std::vector stmts; + for (Stmt* stmt : v->stmts()) { + { + // Insert the initializer for any Scalars scoped to this block. + auto it = scope.initializerPoints_.find(stmt); + if (it != scope.initializerPoints_.end()) { + for (auto& info : it->second) { + Stmt* initializer = + info->replacement().initializer->accept_mutator(this); + stmts.push_back(initializer); + } + scope.initializerPoints_.erase(it); + } } - stmt = stmt->get_parent(); - } - return nullptr; -} -Stmt* RegisterizerReplacer::mutate(const Block* v) { - // We need to mutate this block in place, rather than clone - since other - // AccessInfo objects may hold a pointer to it. - Block* v1 = const_cast(v); // NOLINT - assert(v1); - - Stmt* first_changed{nullptr}; - Stmt* last_changed{nullptr}; - std::list stmts = v1->stmts(); - for (Stmt* stmt : stmts) { - dirty_ = false; Stmt* stmt_new = stmt->accept_mutator(this); - if (dirty_) { - first_changed = first_changed ? first_changed : stmt_new; - last_changed = stmt_new; + if (stmt_new) { + if (stmt_new->get_parent()) { + stmt_new = Stmt::clone(stmt_new); + } + stmts.push_back(stmt_new); } - if (stmt_new == stmt) { - continue; + { + // Insert the finalizer for any Scalars scoped to this block. + auto it = scope.finalizePoints_.find(stmt); + if (it != scope.finalizePoints_.end()) { + for (auto& info : it->second) { + Store* finalizer = new Store( + info->buf(), + info->indices(), + info->replacement().var, + new IntImm(1)); + stmts.push_back(finalizer); + } + scope.finalizePoints_.erase(it); + } } - v1->replace_stmt(stmt, stmt_new); - first_changed = first_changed ? first_changed : stmt_new; - last_changed = stmt_new; } - dirty_ = first_changed != nullptr; + return new Block(stmts); +} - if (v != info_->parent) { - return v1; - } +void RegisterizerReplacer::buildReplacements() { + // Traverse the list of replacements, creating vars and updating our local + // maps. + for (auto& info : infoSet_) { + Var* v = new Var( + info->buf()->name_hint() + "_" + + c10::to_string(getBufferAccessCount(info->buf())), + info->buf()->dtype()); + + info->replacement().var = v; + + // we need to wrap the Var in a Buf so we can Load or Store it. + info->replacement().var_wrapper = new Buf(v, {}, info->buf()->dtype()); + + bool first = true; + for (auto* s : info->stores()) { + if (first && info->first_usage() == s && !info->firstUsageOverlapped()) { + info->replacement().initializer = new Let(v, s->value()); + eliminatedIntializers_.insert(s); + } else { + storeToAccess_[s] = info; + } - Stmt* let; - // If we didn't find an initial store: intialize with the original buffer. - if (!initializer_) { - let = new Let( - var_, - new Load( - info_->buf->dtype(), info_->buf, info_->indices, new IntImm(1))); - } else { - let = new Let(var_, initializer_->value()); - } - v1->insert_stmt_before(let, first_changed); + first = false; + } + + for (auto* s : info->loads()) { + loadToAccess_[s] = info; + } + + auto& scope = parentToAccesses_[info->block()]; + scope.initializerPoints_[info->first_usage()].push_back(info); + + // Only finalize if the scalar is written. + if (!info->stores().empty()) { + // push front to finalize in reverse order of encounter. + scope.finalizePoints_[info->last_usage()].push_front(info); + } - // If it was written to the buffer, make sure we write it out. - if (info_->stores.size() > 0) { - v1->insert_stmt_after( - new Store(info_->buf, info_->indices, var_, new IntImm(1)), - last_changed); + // create a default initializer by reading the access. + if (info->replacement().initializer == nullptr) { + info->replacement().initializer = new Let( + v, + new Load( + info->buf()->dtype(), + info->buf(), + info->indices(), + new IntImm(1))); + } } - return v1; } +} // namespace registerizer + // Apply scalar replacement to all accesses in s. Stmt* registerize(Stmt* s) { - RegisterizerAnalysis analysis; + s = IRSimplifier::simplify(s); + + // The outermost node must be a Block so we have somewhere to put outer scope + // scalars. + if (!dynamic_cast(s)) { + s = new Block({s}); + } + registerizer::RegisterizerAnalysis analysis; s->accept(&analysis); auto candidates = analysis.getCandidates(); - for (const auto& info : candidates) { - RegisterizerReplacer replacer(info); - s = s->accept_mutator(&replacer); - } + + registerizer::RegisterizerReplacer replacer(candidates); + s = s->accept_mutator(&replacer); return s; } diff --git a/torch/csrc/jit/tensorexpr/registerizer.h b/torch/csrc/jit/tensorexpr/registerizer.h index 118686a3e4e18..551a9fbb32770 100644 --- a/torch/csrc/jit/tensorexpr/registerizer.h +++ b/torch/csrc/jit/tensorexpr/registerizer.h @@ -11,6 +11,7 @@ namespace torch { namespace jit { namespace tensorexpr { +namespace registerizer { /* The Registerizer performs scalar replacement by looking for common Stores and Loads to a single item in a buffer and replacing them with a local temporary @@ -38,58 +39,289 @@ For example it can replace: This is particularly useful on GPUs when parallelizing, since after replacing loops with metavars we have a lot of accesses like this. */ -// Holds analysis information about accesses to a specific range of a -// buffer, including the number of loads and stores and the lowest common -// parent Block. -struct AccessInfo { +class Scope; + +/* Holds analysis information about accesses to a specific range of a + buffer, including the number of loads and stores and the lowest common parent + Block. + */ +class AccessInfo { + public: AccessInfo() = default; - AccessInfo(const Buf* b, const std::vector& i) - : buf(b), - indices(i), - store_cost(new IntImm(0)), - load_cost(new IntImm(0)) {} + AccessInfo( + SimplifierHashType h, + const Buf* b, + const std::vector& i, + size_t accessOrder) + : hash_(h), + buf_(b), + indices_(i), + store_cost_(new IntImm(0)), + load_cost_(new IntImm(0)), + accessOrder_(accessOrder) {} + + // Adds a Store to this access, which is in the provided scope. + void addStore(const Store* store, const std::shared_ptr& scope); + + // Adds a Load to this access, which occurs in the usage Stmt in the provided + // scope. + void addLoad( + const Load* load, + const std::shared_ptr& scope, + const Stmt* usage); + + // Merge another AccessInfo into this one. + void merge(const std::shared_ptr& other); + + // Returns true if the other AccessInfo's bounds may overlap this one. + bool overlaps(const std::shared_ptr& other); - void addStore(const Store* s, const Block* p, const Expr* cost) { - store_cost = IRSimplifier::simplify(new Add(store_cost, cost)); - stores.push_back(s); - parent = parent ? Block::getSharedParent(parent, p) : p; - first_usage = first_usage ? first_usage : s; + // Returns true if the indices of this access depend on the provided Var. + bool dependsOnVar(const Var* v); + + // Clone this AccessInfo, and set this as the new accesses' hiddenAccess. + static std::shared_ptr cloneWithHiddenInfo( + const std::shared_ptr& orig); + + // print for debugging. + void print() const; + + SimplifierHashType hash() const { + return hash_; } - void addLoad( - const Load* l, - const Block* p, - const Expr* cost, - const Stmt* usage) { - load_cost = IRSimplifier::simplify(new Add(load_cost, cost)); - loads.push_back(l); - parent = parent ? Block::getSharedParent(parent, p) : p; - first_usage = first_usage ? first_usage : usage; + const Buf* buf() const { + return buf_; + } + + const std::vector& indices() const { + return indices_; + } + + const Block* block() const { + return block_; + } + + void setEnclosingBlock(const Block* b) { + block_ = b; + } + + const Stmt* first_usage() const { + return first_usage_; + } + const Stmt* last_usage() const { + return last_usage_; + } + + void setUsageMarks(const Stmt* first, const Stmt* last) { + first_usage_ = first; + last_usage_ = last; + } + + bool firstUsageOverlapped() const { + return firstUsageOverlapped_; + } + + const Expr* store_cost() const { + return store_cost_; + } + + const Expr* load_cost() const { + return load_cost_; + } + + const std::vector& stores() const { + return stores_; + } + + const std::vector& loads() const { + return loads_; + } + + void hoistCosts(const Expr* extent) { + store_cost_ = IRSimplifier::simplify(new Mul(store_cost_, extent)); + load_cost_ = IRSimplifier::simplify(new Mul(load_cost_, extent)); + } + + size_t conditionId() const { + return conditionId_; + } + + void setConditionId(size_t c) { + conditionId_ = c; + } + + size_t accessOrder() const { + return accessOrder_; + } + + std::shared_ptr hiddenAccess() const { + return hiddenAccess_; + } + + // Holds state relating to the scalar variable we will insert to replace some + // number of loads and stores. + struct ScalarReplacement { + Var* var{nullptr}; + Buf* var_wrapper{nullptr}; + Let* initializer{nullptr}; + }; + + ScalarReplacement& replacement() { + return replacement_; } - const Buf* buf; - std::vector indices; - const Block* parent{nullptr}; + private: + SimplifierHashType hash_; + const Buf* buf_; + std::vector indices_; + const Block* block_{nullptr}; + + const Stmt* first_usage_{nullptr}; + const Stmt* last_usage_{nullptr}; + + // Whether or not this access is overlapped in the first Stmt it appears. This + // means we cannot use it's first Store as the initializer. + bool firstUsageOverlapped_{false}; + + // The cost in real ops that this access represents, to enable + // filtering accesses that wont save any loads or stores. + const Expr* store_cost_; + const Expr* load_cost_; + + // The actual Stores and Loads which represent this access. + // Be careful with these, any mutator will invalidate these pointers. + std::vector stores_; + std::vector loads_; + + // An identifier representing the conditional block, if any, this access + // depends on. + size_t conditionId_{0}; + + // An identifier representing the order this access was first encountered, for + // sorting returned results. + size_t accessOrder_{0}; + + // Sometimes when traversing the tree we need to record what would happen if + // we hoisted an access, but sometimes it doesn't work out. This lets us + // "undo" some mutation and return to the internal hidden AccessInfo. + // It will be removed after any further additions to this AccessInfo. + std::shared_ptr hiddenAccess_; + + ScalarReplacement replacement_; +}; - const Stmt* first_usage{nullptr}; +using AccessHashMap = + std::unordered_map>; + +// Represents a scope block and holds all accesses contained within it. +class Scope { + public: + Scope(const Block* b, std::shared_ptr parent, size_t conditionId = 0) + : block_(b), parent_(parent), conditionId_(conditionId) {} - const Expr* store_cost; - const Expr* load_cost; + AccessHashMap& getAccessMapByBuf(const Buf* b); - std::vector stores; - std::vector loads; + std::unordered_map& openAccesses() { + return openAccesses_; + } + + std::vector>& closedAccesses() { + return closedAccesses_; + } + + const Block* block() const { + return block_; + } + + std::shared_ptr parent() const { + return parent_; + } + + size_t conditionId() const { + return conditionId_; + } - bool invalid{false}; + const std::unordered_set& localVars() const { + return localVars_; + } + void addLocalVar(const Var* v) { + localVars_.insert(v); + } + + void closeAccess(const std::shared_ptr& info); + + void filterClosed(); + + private: + // Map of map to access, narrowing by Buf then by hash(Buf+Indices). + // This allows us to find a candidate access easily, and also check for + // overlap with other accesses to the same buf. Buf -> + // Hash -> + // Access + std::unordered_map openAccesses_; + std::vector> closedAccesses_; + + // The Block object this scope represents. + const Block* block_; + + // The enclosing scope object. + std::shared_ptr parent_; + + // An identifier representing the condition block this scope depends on. + size_t conditionId_; + + // A set of variables local to this scope (e.g. loop vars). + std::unordered_set localVars_; }; -// Walks the IR generating AccessInfo for each access. +/* Analyzes the graph and collects accesses to the same symbolic tensor element + * which can be replaced by a single local scalar. + * + * This works by recursively walking the tree in postfix order, building sets of + * accesses to the same symbolic element by scope and then merging lower scopes + * into their enclosing scope. + * + * It is safe to move two accesses of the same Tensor element to a local scalar + * Var if between all usages of the element there are no other Loads or Stores + * that may refer to it. In the comments I refer to this as overlapping the + * access, or "cutting" the existing AccessInfo. In the case where a candidate + * for registerization is cut, it may be possible to finalize the access early + * by writing it back to the Tensor and then create a new scalar variable after + * the overlapping access is complete. We will attempt to do this when it saves + * memory accesses. + * + * There are a few cases that make this more challenging: + * + * - For: Loops change the number of real usages of a buffer by the loop + * extent, but only if we can pull the definition and finalization of the scalar + * variable out of the loop block. + * + * - Cond: Conditions complicate lifting scalars out of internal scopes. + * Generally we cannot lift an access outside of a conditional scope unless + * there is already a reference to that same access at the higher scope, since + * we don't know if the condition was guarding an array access not safe at the + * higher scope. In the comments I refer to this as the condition "hiding" the + * access, and the outer access "unhiding" it. + * + * - IfThenElse: Same situation as Cond, except since IfThenElse is an Expr + * rather than a Stmt we cannot insert the scalar definition or finalizer + * within the conditional scope. Acccesses inside an IfThenElse can be safely + * combined with external accesses but cannot exist completely within. + * + * - Let: Accesses dependent on local variables via Let Stmts, or loop vars, + * cannot be raised outside of the scope of the dependent var. + */ class TORCH_API RegisterizerAnalysis : public IRVisitor { public: - RegisterizerAnalysis() : loopCost_(new IntImm(1)) {} + RegisterizerAnalysis() + : currentScope_(std::make_shared(nullptr, nullptr, 0)) {} virtual ~RegisterizerAnalysis() {} void visit(const For* v) override; + void visit(const Cond* v) override; + void visit(const Block* v) override; void visit(const Store* v) override; @@ -98,7 +330,7 @@ class TORCH_API RegisterizerAnalysis : public IRVisitor { void visit(const IfThenElse* v) override; - void visit(const Cond* v) override; + void visit(const Let* v) override; #define STMT_ON_STACK(Op) \ virtual void visit(const Op* v) override { \ @@ -110,54 +342,77 @@ class TORCH_API RegisterizerAnalysis : public IRVisitor { STMT_ON_STACK(AtomicAdd); STMT_ON_STACK(Allocate); STMT_ON_STACK(Free); - STMT_ON_STACK(Let); #undef STMT_ON_STACK std::vector> getCandidates(); private: - std::unordered_map> - candidates_; - std::unordered_map costByBlock_; - std::vector> encounterOrder_; + void mergeCurrentScopeIntoParent(); + void mergeHiddenScope(bool allowClosed); + void closeAccessIntoScope( + const std::shared_ptr& info, + const std::shared_ptr& scope); - const Expr* loopCost_; + std::unordered_set exprConditionals_; + // A stack of enclosing Stmts for tracking the usage Stmt of Loads. std::deque stmtStack_; - const Block* enclosingBlock_; + + // The current scope being analyzed. + std::shared_ptr currentScope_; + HashProvider hasher_; - size_t nested_conditions_{0}; + size_t conditionId_{0}; + size_t accessOrder_{0}; }; -// Walks the IR an replaces a single Acccess with a local scalar Var. +/* Replaces each registerizable access with a Scalar variable, including + * definition, initializer and finalizer. + */ class TORCH_API RegisterizerReplacer : public IRMutator { public: - RegisterizerReplacer(std::shared_ptr i) : info_(i) { - var_ = new Var(info_->buf->name_hint() + "_", info_->buf->dtype()); - var_wrapper_ = new Buf(var_, {}, info_->buf->dtype()); - - initializer_ = nullptr; + RegisterizerReplacer(std::vector>& vec) + : infoSet_(vec) { + buildReplacements(); } const Expr* mutate(const Load* v) override; Stmt* mutate(const Store* v) override; - // Finds the Stmt in parent which contains stmt. - const Stmt* findInsertionPoint(const Stmt* stmt, const Block* parent); - Stmt* mutate(const Block* v) override; private: - std::shared_ptr info_; - Var* var_; - Buf* var_wrapper_; - const Store* initializer_; - bool dirty_{false}; - bool initializerReady_{true}; + struct ReplacerScope { + std::unordered_map>> + initializerPoints_; + std::unordered_map>> + finalizePoints_; + }; + + // Creates the various ReplacerScope objects and builds internal maps. + void buildReplacements(); + + // State relating to the accesses yet to be replaced. + std::vector>& infoSet_; + std::unordered_map> storeToAccess_; + std::unordered_map> loadToAccess_; + std::unordered_map parentToAccesses_; + + // Holds the set of Stores that should be pulled into an initializer, so they + // can be eliminated. + std::set eliminatedIntializers_; + + // Tracks the number of times we've seen each buffer, so we can name the + // scalar Vars appropriately. + std::unordered_map bufferAccessCounts_; + unsigned int getBufferAccessCount(const Buf* b) { + return ++bufferAccessCounts_[b]; + } }; +} // namespace registerizer // Apply scalar replacement to all accesses in s. // To produce safe code, this must occur after handling parallelized axes and diff --git a/torch/csrc/jit/tensorexpr/stmt.h b/torch/csrc/jit/tensorexpr/stmt.h index eb0b7837c5c61..55c1926b3541c 100644 --- a/torch/csrc/jit/tensorexpr/stmt.h +++ b/torch/csrc/jit/tensorexpr/stmt.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -10,7 +11,7 @@ namespace torch { namespace jit { namespace tensorexpr { -class Buffer; +class Placeholder; // The common base between all statement node. class TORCH_API Stmt : public KernelScopedObject { @@ -163,6 +164,13 @@ class TORCH_API Block : public StmtNode { return stmts_; } + void clear() { + for (auto* s : stmts_) { + set_parent(s, nullptr); + } + stmts_.clear(); + } + explicit Block(const std::vector& stmts) { for (Stmt* s : stmts) { if (s->get_parent()) { @@ -244,6 +252,14 @@ class TORCH_API Block : public StmtNode { return nullptr; } + // returns the immediate child containing statement s. + const Stmt* getEnclosedRoot(const Stmt* s) const { + while (s && s->get_parent() != this) { + s = s->get_parent(); + } + return s; + } + private: std::list stmts_; }; @@ -270,12 +286,6 @@ class TORCH_API Store : public StmtNode { return buf_; } - static Store* make( - const Buffer& buffer, - const std::vector& indices, - const ExprHandle& value, - const ExprHandle& mask); - static Store* make( const BufHandle& buf, const std::vector& indices, @@ -287,13 +297,6 @@ class TORCH_API Store : public StmtNode { const std::vector& indices, const ExprHandle& value); - // TODO: merge this with Load. - Store( - const Buffer& buffer, - const std::vector& indices, - const Expr* value, - const Expr* mask); - Store( const Buf* buf, std::vector indices, diff --git a/torch/csrc/jit/tensorexpr/tensor.cpp b/torch/csrc/jit/tensorexpr/tensor.cpp index f986cd663dc83..d12f6999c8d51 100644 --- a/torch/csrc/jit/tensorexpr/tensor.cpp +++ b/torch/csrc/jit/tensorexpr/tensor.cpp @@ -1,7 +1,130 @@ #include +#include +#include +#include + namespace torch { namespace jit { -namespace tensorexpr {} // namespace tensorexpr +namespace tensorexpr { + +Tensor* Compute( + const std::string& func_name, + const std::vector& dim_args, + const std::function&)>& body_func) { + std::vector dims; + std::vector args; + unpack_dim_args(dim_args, &dims, &args); + const Expr* body = body_func(VarVectorToVarHandleVector(args)).node(); + return new Tensor(func_name, dims, args, body); +} + +Tensor* Compute( + const std::string& func_name, + const std::vector& dim_args, + const std::function& body_func) { + if (dim_args.size() != 1) { + throw malformed_input("mismatch between body and arg size (1)"); + } + + std::vector dims; + std::vector args; + unpack_dim_args(dim_args, &dims, &args); + const Expr* body = body_func(VarHandle(args[0])).node(); + return new Tensor(func_name, dims, args, body); +} + +Tensor* Compute( + const std::string& func_name, + const std::vector& dim_args, + const std::function& + body_func) { + if (dim_args.size() != 2) { + throw malformed_input("mismatch between body and arg size (2)"); + } + std::vector dims; + std::vector args; + unpack_dim_args(dim_args, &dims, &args); + const Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1])).node(); + return new Tensor(func_name, dims, args, body); +} + +Tensor* Compute( + const std::string& func_name, + const std::vector& dim_args, + const std::function< + ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>& + body_func) { + if (dim_args.size() != 3) { + throw malformed_input("mismatch between body and arg size (3)"); + } + std::vector dims; + std::vector args; + unpack_dim_args(dim_args, &dims, &args); + const Expr* body = + body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2])) + .node(); + return new Tensor(func_name, dims, args, body); +} + +Tensor* Compute( + const std::string& func_name, + const std::vector& dim_args, + const std::function& body_func) { + if (dim_args.size() != 4) { + throw malformed_input("mismatch between body and arg size (4)"); + } + std::vector dims; + std::vector args_nodes; + unpack_dim_args(dim_args, &dims, &args_nodes); + auto args = VarVectorToVarHandleVector(args_nodes); + const Expr* body = body_func(args[0], args[1], args[2], args[3]).node(); + return new Tensor(func_name, dims, args_nodes, body); +} + +Stmt* Tensor::ElementStmt() const { + std::vector indices; + for (size_t i = 0; i < buf_->ndim(); i++) { + indices.push_back(args_[i]); + } + + const Expr* mask = new IntImm(1); + Stmt* update_stmt = new Store(buf_, indices, body_, mask); + return update_stmt; +} + +Tensor* Reduce( + const std::string& func_name, + const std::vector& dim_args, + const Reducer& reducer, + const Placeholder& buffer, + const std::vector& reduce_args) { + return Reduce( + func_name, + dim_args, + reducer, + [&](ParameterList& p) { return buffer.load(p); }, + reduce_args); +} + +Tensor* Reduce( + const std::string& func_name, + const std::vector& dim_args, + const Reducer& reducer, + Tensor* tensor, + const std::vector& reduce_args) { + return Reduce( + func_name, + dim_args, + reducer, + [&](ParameterList& p) { return tensor->call(p); }, + reduce_args); +} + +} // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h index 7ed32a905eadf..e5e399db348b9 100644 --- a/torch/csrc/jit/tensorexpr/tensor.h +++ b/torch/csrc/jit/tensorexpr/tensor.h @@ -1,51 +1,93 @@ #pragma once #include +#include #include #include #include -#include #include namespace torch { namespace jit { namespace tensorexpr { -class Tensor : KernelScopedObject { +class TORCH_API Tensor : KernelScopedObject { public: - Function* function() const { - return function_; - } - int output_index() const { - return output_index_; - } + Tensor( + const std::string& name, + const std::vector& dims, + const std::vector& args, + const Expr* body) + // TODO: Function should not create buffers, they should be created + // manually before constructing a function. + : buf_(new Buf(name, dims, body->dtype())), args_(args), body_(body) {} + + Tensor(Buf* buf, const std::vector& args, const Expr* body) + : buf_(buf), args_(args), body_(body) {} + + Tensor( + const Buf* buf, + const std::vector& args, + const std::vector& reduce_dims, + const std::vector& reduce_args, + const Expr* body) + : buf_(buf), + args_(args), + body_(body), + reduce_dims_(reduce_dims), + reduce_args_(reduce_args) {} + + virtual ~Tensor() {} // Wrappers over accessors to fields of the underlying function const Expr* body() const { - return function()->body(output_index()); + return body_; } - const Buf* func_var() const { - return function()->func_var(output_index()); + const Buf* buf() const { + return buf_; } - int ndim() const { - return buf_->dims().size(); + size_t ndim() const { + return buf()->ndim(); } - const Expr* dim(int index) const { - return buf_->dim(index); + const Expr* dim(size_t index) const { + if (index >= ndim()) { + throw out_of_range_index(); + } + return buf()->dim(index); } std::vector dims() const { - return buf_->dims(); + return buf()->dims(); } - const Var* arg(int index) const { - return function()->arg(index); + const Var* arg(size_t index) const { + if (index >= ndim()) { + throw out_of_range_index(); + } + return args_[index]; } const std::vector& args() const { - return function()->args(); + return args_; } - - const Buf* buf() const { - return buf_; + size_t reduce_ndim() const { + return reduce_dims_.size(); + } + std::vector reduce_dims() const { + return reduce_dims_; + } + std::vector reduce_args() const { + return reduce_args_; + } + const Expr* reduce_dim(size_t index) const { + if (index >= reduce_ndim()) { + throw out_of_range_index(); + } + return reduce_dims_[index]; + } + const Var* reduce_arg(size_t index) const { + if (index >= reduce_ndim()) { + throw out_of_range_index(); + } + return reduce_args_[index]; } void initializeTo(const Expr* initializer) { @@ -54,9 +96,8 @@ class Tensor : KernelScopedObject { const Expr* initializer() const { return initializer_; } + virtual Stmt* ElementStmt() const; - Tensor(const Buf* buf, Function* function, int output_index) - : buf_(buf), function_(function), output_index_(output_index) {} template inline ExprHandle operator()(const Ts&... ts); template @@ -66,11 +107,104 @@ class Tensor : KernelScopedObject { private: const Buf* buf_; - Function* function_; - int output_index_; + std::vector args_; + const Expr* body_; + std::vector reduce_dims_; + std::vector reduce_args_; + const Expr* initializer_{nullptr}; }; +class TORCH_API CompoundTensor : public Tensor { + public: + CompoundTensor( + const Buf* buf, + const std::vector& args, + Stmt* stmt) + : Tensor(buf, args, {}, {}, nullptr), stmt_(stmt) {} + + virtual ~CompoundTensor() {} + + Stmt* ElementStmt() const override { + return stmt_; + } + + private: + Stmt* stmt_; +}; + +class Placeholder { + public: + Placeholder(const BufHandle& data) : data_(data.node()) { + if (data_->base_handle()->dtype() != kHandle) { + throw malformed_input("Placeholder dtype must be Handle"); + } + + std::vector stride_handles(ndim()); + for (int i = (int)ndim() - 1; i >= 0; i--) { + if (i == ndim() - 1) { + stride_handles[i] = 1; + } else { + stride_handles[i] = stride_handles[i + 1] * ExprHandle(dim(i + 1)); + } + } + strides_ = ExprHandleVectorToExprVector(stride_handles); + } + Placeholder( + const std::string& name, + const Dtype& dtype, + const std::vector& dims) + : Placeholder(BufHandle(name, dims, dtype)) {} + + const Buf* data() const { + return data_; + } + Dtype dtype() const { + return data_->dtype(); + } + int ndim() const { + return data_->ndim(); + } + const Expr* dim(int index) const { + return data_->dim(index); + } + std::vector dims() const { + return data_->dims(); + } + + template + inline ExprHandle load(const Ts&... ts) const; + + template + inline ExprHandle load(const std::vector& args) const; + + inline ExprHandle loadWithMask( + const std::vector& args, + const ExprHandle& mask) const { + return ExprHandle( + new Load(data(), ExprHandleVectorToExprVector(args), mask.node())); + } + + inline Store* store( + const std::vector& args, + const ExprHandle& val) const { + return new Store( + data(), ExprHandleVectorToExprVector(args), val.node(), new IntImm(1)); + } + + inline Store* storeWithMask( + const std::vector& args, + const ExprHandle& val, + const ExprHandle& mask) const { + return new Store( + data(), ExprHandleVectorToExprVector(args), val.node(), mask.node()); + } + + private: + const Buf* data_; + std::vector strides_; +}; + TORCH_API Tensor* Compute( const std::string& func_name, const std::vector& dim_args, @@ -137,10 +271,8 @@ Tensor* Reduce( Buf* func_result = new Buf(func_name, dims, body.dtype()); const ReduceOp* reduce_op = reducer(func_result, body, output_args, reduce_vars); - dims.insert(dims.end(), reduce_dims.begin(), reduce_dims.end()); - Function* func = - new Function(func_name, func_result, dims, all_vars, reduce_op); - Tensor* t = new Tensor(func_result, func, 0); + Tensor* t = + new Tensor(func_result, vars, reduce_dims, reduce_vars, reduce_op); t->initializeTo(new Cast(body.dtype(), reducer.initializer())); return t; } @@ -156,12 +288,12 @@ Tensor* Reduce( return Reduce(func_name, dim_args, reducer, body_func, reduce_args); } -// Overload for the common case of all dimensions of a Buffer. +// Overload for the common case of all dimensions of a Placeholder. TORCH_API Tensor* Reduce( const std::string& func_name, const std::vector& dim_args, const Reducer& reducer, - const Buffer& buffer, + const Placeholder& buffer, const std::vector& reduce_args); // Overload for the common case of all dimensions of a prevously Computed @@ -194,10 +326,7 @@ class FunctionCall : public CallNode { } FunctionCall(Tensor* tensor, const std::vector& params) - : BaseClass( - tensor->function()->body(tensor->output_index())->dtype(), - kFunctionCall, - params), + : BaseClass(tensor->buf()->dtype(), kFunctionCall, params), tensor_(tensor) {} private: @@ -207,7 +336,7 @@ class FunctionCall : public CallNode { } std::string func_name() const override { - return tensor_->func_var()->name_hint(); + return tensor_->buf()->name_hint(); } Tensor* tensor_; @@ -229,6 +358,21 @@ inline ExprHandle Tensor::call(const std::vector& args) { std::vector params(args.begin(), args.end()); return FunctionCall::make(this, params); } + +template +inline ExprHandle Placeholder::load(const Ts&... ts) const { + std::vector params({ExprHandle(ts)...}); + return ExprHandle( + new Load(data(), ExprHandleVectorToExprVector(params), new IntImm(1))); +} + +template +inline ExprHandle Placeholder::load(const std::vector& args) const { + std::vector params(args.begin(), args.end()); + return ExprHandle( + new Load(data(), ExprHandleVectorToExprVector(params), new IntImm(1))); +} + } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp new file mode 100644 index 0000000000000..3c941240b5db2 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp @@ -0,0 +1,474 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +void initTensorExprBindings(PyObject* module) { + auto m = py::handle(module).cast(); + + // Tensor Expr Classes + auto te = m.def_submodule("te"); + py::class_(te, "KernelScope").def(py::init<>()); + + auto dtype_class = py::class_(te, "Dtype"); + +#define DTYPE_SINGLETON_ACCESSOR(ctype, name) \ + dtype_class.def_property_readonly_static( \ + #name, [](py::object) { return tensorexpr::k##name; }); + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, DTYPE_SINGLETON_ACCESSOR) +#undef DTYPE_SINGLETON_ACCESSOR + + auto expr_handle_class = + py::class_(te, "ExprHandle") + .def(py::self + py::self) + .def(py::self * py::self) + .def(py::self - py::self) + .def(py::self / py::self) + .def(py::self % py::self) + .def(py::self == py::self) + .def(py::self != py::self) + .def(py::self > py::self) + .def(py::self >= py::self) + .def(py::self < py::self) + .def(py::self <= py::self) + .def(py::self & py::self) + .def(py::self | py::self) + .def(py::self ^ py::self) + .def(py::self << py::self) + .def(py::self >> py::self) + .def( + "sin", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::sin(self); + }) + .def( + "cos", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::cos(self); + }) + .def( + "tan", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::tan(self); + }) + .def( + "asin", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::asin(self); + }) + .def( + "acos", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::acos(self); + }) + .def( + "atan", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::atan(self); + }) + .def( + "sinh", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::sinh(self); + }) + .def( + "cosh", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::cosh(self); + }) + .def( + "tanh", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::tanh(self); + }) + .def( + "sigmoid", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::sigmoid(self); + }) + .def( + "exp", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::exp(self); + }) + .def( + "expm1", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::expm1(self); + }) + .def( + "abs", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::abs(self); + }) + .def( + "log", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::log(self); + }) + .def( + "fast_log", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::fast_log(self); + }) + .def( + "log2", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::log2(self); + }) + .def( + "log10", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::log10(self); + }) + .def( + "log1p", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::log1p(self); + }) + .def( + "erf", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::erf(self); + }) + .def( + "erfc", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::erfc(self); + }) + .def( + "sqrt", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::sqrt(self); + }) + .def( + "rsqrt", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::rsqrt(self); + }) + .def( + "ceil", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::ceil(self); + }) + .def( + "floor", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::floor(self); + }) + .def( + "round", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::round(self); + }) + .def( + "trunc", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::trunc(self); + }) + .def( + "frac", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::frac(self); + }) + .def( + "lgamma", + [](const tensorexpr::ExprHandle& self) { + return tensorexpr::lgamma(self); + }) + .def("isnan", [](const tensorexpr::ExprHandle& self) { + return tensorexpr::isnan(self); + }); + te.def( + "ifThenElse", + [](const tensorexpr::ExprHandle& c, + const tensorexpr::ExprHandle& t, + const tensorexpr::ExprHandle& f) { + return tensorexpr::ifThenElse(c, t, f); + }); + te.def( + "atan2", + [](const tensorexpr::ExprHandle& v1, const tensorexpr::ExprHandle& v2) { + return tensorexpr::atan2(v1, v2); + }); + te.def( + "pow", + [](const tensorexpr::ExprHandle& v1, const tensorexpr::ExprHandle& v2) { + return tensorexpr::pow(v1, v2); + }); + te.def( + "fmod", + [](const tensorexpr::ExprHandle& v1, const tensorexpr::ExprHandle& v2) { + return tensorexpr::fmod(v1, v2); + }); + te.def( + "remainder", + [](const tensorexpr::ExprHandle& v1, const tensorexpr::ExprHandle& v2) { + return tensorexpr::remainder(v1, v2); + }); + +#define EXPRHANDLE_CTOR(ctype, name) \ + expr_handle_class.def_static( \ + #ctype, [](ctype v) { return tensorexpr::ExprHandle(v); }); + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, EXPRHANDLE_CTOR) +#undef EXPRHANDLE_CTOR + + py::class_(te, "VarHandle"); + py::class_(te, "BufHandle"); + + py::class_(te, "Placeholder") + .def(py::init< + const std::string&, + const tensorexpr::Dtype&, + std::vector&>()) + .def( + "load", + [](tensorexpr::Placeholder& self, + const std::vector& v) { + return self.load(v); + }); + py::class_(te, "Tensor") + .def( + "load", + [](tensorexpr::Tensor& self, + const std::vector& v) { + return self.call(v); + }); + py::class_(te, "DimArg") + .def(py::init()) + .def(py::init()); + + te.def( + "Compute", + [](const std::string& func_name, + const std::vector& dim_args, + py::function func) { + if (dim_args.size() == 1) { + return tensorexpr::Compute( + func_name, dim_args, [&func](const tensorexpr::VarHandle& a) { + return py::cast(func(a)); + }); + } else if (dim_args.size() == 2) { + return tensorexpr::Compute( + func_name, + dim_args, + [&func]( + const tensorexpr::VarHandle& a, + const tensorexpr::VarHandle& b) { + return py::cast(func(a, b)); + }); + } else if (dim_args.size() == 3) { + return tensorexpr::Compute( + func_name, + dim_args, + [&func]( + const tensorexpr::VarHandle& a, + const tensorexpr::VarHandle& b, + const tensorexpr::VarHandle& c) { + return py::cast(func(a, b, c)); + }); + } else if (dim_args.size() == 4) { + return tensorexpr::Compute( + func_name, + dim_args, + [&func]( + const tensorexpr::VarHandle& a, + const tensorexpr::VarHandle& b, + const tensorexpr::VarHandle& c, + const tensorexpr::VarHandle& d) { + return py::cast(func(a, b, c, d)); + }); + } else { + throw std::runtime_error("Too many args"); + } + }, + py::return_value_policy::reference); + py::class_(te, "Reducer"); + + te.def( + "SumReduce", + [](const std::string& func_name, + const std::vector& dim_args, + tensorexpr::Tensor* buffer, + const std::vector& reduce_args) { + return tensorexpr::Reduce( + func_name, dim_args, tensorexpr::Sum(), buffer, reduce_args); + }, + py::return_value_policy::reference); + + py::class_(te, "Stmt") + .def("__str__", [](const tensorexpr::Stmt& self) { + std::stringstream ss; + ss << self; + return ss.str(); + }); + py::class_(te, "For") + .def( + "index_var", + [](const tensorexpr::For& self) { + return tensorexpr::VarHandle(self.var()); + }, + py::return_value_policy::reference) + .def("body", &tensorexpr::For::body, py::return_value_policy::reference); + + py::class_(te, "Block") + .def( + "stmts", + &tensorexpr::Block::stmts, + py::return_value_policy::reference); + + py::class_(te, "LoopNest") + .def(py::init&>()) + .def("vectorize_inner_loops", &tensorexpr::LoopNest::vectorizeInnerLoops) + .def("prepare_for_codegen", &tensorexpr::LoopNest::prepareForCodegen) + .def( + "get_loop_body_for", + &tensorexpr::LoopNest::getLoopBodyFor, + py::return_value_policy::reference) + .def( + "get_loops_for", + [](const tensorexpr::LoopNest& self, tensorexpr::Tensor* t) { + return self.getLoopStmtsFor(t); + }, + py::return_value_policy::reference) + .def( + "split_with_tail", + [](const tensorexpr::LoopNest& self, tensorexpr::For* f, int factor) { + tensorexpr::For *outer = nullptr, *inner = nullptr, *tail = nullptr; + self.splitWithTail(f, factor, &outer, &inner, &tail); + return std::make_tuple(outer, inner, tail); + }, + py::return_value_policy::reference) + .def( + "split_with_mask", + [](const tensorexpr::LoopNest& self, tensorexpr::For* f, int factor) { + tensorexpr::For *outer = nullptr, *inner = nullptr; + self.splitWithMask(f, factor, &outer, &inner); + return std::make_tuple(outer, inner); + }, + py::return_value_policy::reference) + .def( + "unroll", + [](const tensorexpr::LoopNest& self, tensorexpr::For* f) { + tensorexpr::Stmt* unrolled = nullptr; + self.unroll(f, &unrolled); + return unrolled; + }, + py::return_value_policy::reference) + .def( + "vectorize", + [](const tensorexpr::LoopNest& self, tensorexpr::For* f) { + self.vectorize(f); + }, + py::return_value_policy::reference) + .def( + "compute_inline", + [](tensorexpr::LoopNest& self, tensorexpr::Stmt* s) { + self.computeInline(s); + }, + py::return_value_policy::reference) + .def( + "compute_inline", + [](tensorexpr::LoopNest& self, const tensorexpr::BufHandle& b) { + self.computeInline(b.node()); + }, + py::return_value_policy::reference) + .def( + "rfactor", + [](tensorexpr::LoopNest& self, + const tensorexpr::Stmt& s, + const tensorexpr::VarHandle& v) { + auto st = dynamic_cast(&s); + if (!st) { + return; + } + auto r = st->value(); + self.rfactor(r, v.node()); + }, + py::return_value_policy::reference) + .def( + "rfactor", + [](tensorexpr::LoopNest& self, + const tensorexpr::Stmt& s, + const tensorexpr::VarHandle& v, + tensorexpr::Block& ins_point) { + auto st = dynamic_cast(&s); + if (!st) { + return; + } + auto r = st->value(); + self.rfactor(r, v.node(), &ins_point); + }, + py::return_value_policy::reference) + .def( + "reorder", + &tensorexpr::LoopNest::reorderAxis, + py::return_value_policy::reference) + .def( + "__str__", + [](const tensorexpr::LoopNest& self) { + std::stringstream ss; + ss << *self.root_stmt(); + return ss.str(); + }) + .def( + "root_stmt", + &tensorexpr::LoopNest::root_stmt, + py::return_value_policy::reference); + + te.def( + "simplify", + [](tensorexpr::Stmt* stmt) { + return tensorexpr::IRSimplifier::simplify(stmt); + }, + py::return_value_policy::reference); + + py::class_(te, "CodeGen") + .def( + "call", + [](tensorexpr::CodeGen& self, const std::vector& values) { + std::vector value_ptrs; + for (const auto& value : values) { + value_ptrs.emplace_back( + tensorexpr::CodeGen::CallArg(value.data_ptr())); + } + self.call(value_ptrs); + }); + py::class_( + te, "SimpleIREvaluator"); +#ifdef TORCH_ENABLE_LLVM + py::class_(te, "LLVMCodeGen"); +#endif + + py::class_(te, "BufferArg") + .def(py::init()) + .def(py::init()) + .def(py::init()); + + te.def( + "construct_codegen", + [](const std::string& name, + tensorexpr::Stmt* stmt, + const std::vector& args) { + tensorexpr::CodeGen* cg = nullptr; + if (name == "llvm") { +#ifdef TORCH_ENABLE_LLVM + cg = new tensorexpr::LLVMCodeGen(stmt, args); +#else + throw std::runtime_error("PyTorch not compiled with LLVM support!"); +#endif + } else { + cg = new tensorexpr::SimpleIREvaluator(stmt, args); + } + return cg; + }); +} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/tensorexpr_init.h b/torch/csrc/jit/tensorexpr/tensorexpr_init.h new file mode 100644 index 0000000000000..d3893da99554e --- /dev/null +++ b/torch/csrc/jit/tensorexpr/tensorexpr_init.h @@ -0,0 +1,11 @@ +#pragma once + +#include +#include + +namespace torch { +namespace jit { +// Initialize Python bindings for Tensor Expressions +void initTensorExprBindings(PyObject* module); +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/types.cpp b/torch/csrc/jit/tensorexpr/types.cpp index f7aa96be4c451..ae9bdcf1986cd 100644 --- a/torch/csrc/jit/tensorexpr/types.cpp +++ b/torch/csrc/jit/tensorexpr/types.cpp @@ -9,33 +9,26 @@ namespace torch { namespace jit { namespace tensorexpr { -bool is_integral(const ScalarType& type) { - switch (type) { - case ScalarType::Bool: - case ScalarType::Byte: - case ScalarType::Char: - case ScalarType::Short: - case ScalarType::Int: - case ScalarType::Long: - return true; - default: - return false; - } +static bool is_c10_type(const ScalarType& type) { + return type < ScalarType::Undefined; +} - return false; +bool is_integral(const ScalarType& type) { + return is_c10_type(type) + ? c10::isIntegralType(static_cast(type), true) + : false; } bool is_floating_point(const ScalarType& type) { - switch (type) { - case ScalarType::Half: - case ScalarType::Float: - case ScalarType::Double: - return true; - default: - return false; - } + return is_c10_type(type) + ? c10::isFloatingType(static_cast(type)) + : false; +} - return false; +bool is_signed(const ScalarType& type) { + return is_c10_type(type) + ? c10::isSignedType(static_cast(type)) + : false; } Dtype Dtype::scalar_dtype() const { diff --git a/torch/csrc/jit/tensorexpr/types.h b/torch/csrc/jit/tensorexpr/types.h index 8e39ad2315453..29ccf06ef0350 100644 --- a/torch/csrc/jit/tensorexpr/types.h +++ b/torch/csrc/jit/tensorexpr/types.h @@ -37,6 +37,7 @@ TORCH_API std::ostream& operator<<( TORCH_API bool is_integral(const ScalarType& type); TORCH_API bool is_floating_point(const ScalarType& type); +TORCH_API bool is_signed(const ScalarType& type); // Data types for scalar and vector elements. class TORCH_API Dtype { @@ -75,6 +76,13 @@ class TORCH_API Dtype { bool is_floating_point() const { return tensorexpr::is_floating_point(scalar_type_); } + bool is_signed() const { + return tensorexpr::is_signed(scalar_type_); + } + + Dtype cloneWithScalarType(ScalarType nt) const { + return Dtype(nt, lanes_); + } private: friend std::ostream& operator<<(std::ostream& stream, const Dtype& dtype); @@ -124,10 +132,6 @@ inline Dtype BinaryOpDtype( Dtype op1_dtype, Dtype op2_dtype, ScalarType ret_type = ScalarType::None) { - if (op1_dtype.scalar_type() == ScalarType::Bool || - op2_dtype.scalar_type() == ScalarType::Bool) { - throw malformed_input("arithmetic binary operations on Bool not supported"); - } if (op1_dtype == op2_dtype) { if (ret_type == ScalarType::None) { return op1_dtype; diff --git a/torch/csrc/jit/tensorexpr/var_substitutor.h b/torch/csrc/jit/tensorexpr/var_substitutor.h index 3a02507c6dca4..29e0f8de2a012 100644 --- a/torch/csrc/jit/tensorexpr/var_substitutor.h +++ b/torch/csrc/jit/tensorexpr/var_substitutor.h @@ -37,7 +37,7 @@ class VarSubMutator : public IRMutator { } const Expr* mutate(const ReduceOp* var) override { - auto body = var->body().node()->accept_mutator(this); + auto body = var->body()->accept_mutator(this); std::vector new_outer; std::vector new_inner; @@ -59,10 +59,10 @@ class VarSubMutator : public IRMutator { return new ReduceOp( const_cast(var->accumulator()), - ExprHandle(body), - var->interaction(), + body, new_outer, - new_inner); + new_inner, + var->reducer()); } private: diff --git a/torch/csrc/jit/testing/file_check.cpp b/torch/csrc/jit/testing/file_check.cpp index a6250688ebb79..f5e96a501bfd1 100644 --- a/torch/csrc/jit/testing/file_check.cpp +++ b/torch/csrc/jit/testing/file_check.cpp @@ -43,7 +43,7 @@ struct Check { std::string str, c10::optional count = c10::nullopt) : type_(type), search_str_(std::move(str)) { - count_ = std::move(count); + count_ = count; }; CheckType type_; @@ -86,7 +86,7 @@ namespace { size_t assertFind( const SourceRange& search_range, const std::string& sub, - std::function extra_msg = nullptr) { + const std::function& extra_msg = nullptr) { auto pos = search_range.source()->text().find(sub, search_range.start()); if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) { auto found_range = @@ -166,7 +166,7 @@ struct FileCheckImpl { run(test_file); } - TORCH_API void addCheck(Check check) { + TORCH_API void addCheck(const Check& check) { // consecutive CHECK_DAGs & CHECK_NOTs need to be evaluated as a group if (groups.size() == 0 || (check.type_ != CHECK_NOT && check.type_ != CHECK_DAG)) { @@ -186,7 +186,7 @@ struct FileCheckImpl { CheckType type, const std::string& s, c10::optional count = c10::nullopt) { - addCheck(Check(type, s, std::move(count))); + addCheck(Check(type, s, count)); } bool has_run = false; @@ -548,7 +548,11 @@ FileCheck* FileCheck::check_count( const std::string& str, size_t count, bool exactly) { - fcImpl->addCheck(CHECK_COUNT, str, count); + TORCH_INTERNAL_ASSERT( + count != 0 || exactly, "Count == 0 && !exactly doesn't do anything"); + if (count) { + fcImpl->addCheck(CHECK_COUNT, str, count); + } if (exactly) { fcImpl->addCheck(CHECK_NOT, str); } diff --git a/torch/csrc/jit/testing/hooks_for_testing.cpp b/torch/csrc/jit/testing/hooks_for_testing.cpp index 29dbc690d9fb2..553938afd77c3 100644 --- a/torch/csrc/jit/testing/hooks_for_testing.cpp +++ b/torch/csrc/jit/testing/hooks_for_testing.cpp @@ -1,4 +1,5 @@ #include + #include namespace torch { diff --git a/torch/csrc/multiprocessing/init.cpp b/torch/csrc/multiprocessing/init.cpp index 8255cf3d6e774..7479003c0f748 100644 --- a/torch/csrc/multiprocessing/init.cpp +++ b/torch/csrc/multiprocessing/init.cpp @@ -43,7 +43,7 @@ PyObject* multiprocessing_init(PyObject* _unused, PyObject *noargs) { static PyMethodDef methods[] = { { "_multiprocessing_init", - (PyCFunction)multiprocessing_init, + multiprocessing_init, METH_NOARGS, nullptr, }, diff --git a/torch/csrc/onnx/onnx.h b/torch/csrc/onnx/onnx.h index 900de0ea57c23..9daced8a3e4de 100644 --- a/torch/csrc/onnx/onnx.h +++ b/torch/csrc/onnx/onnx.h @@ -20,5 +20,5 @@ enum class TrainingMode { // onnx::IR_VERSION. with this change, the test_operators.py will be more // stable. only bump it when it's necessary static const size_t IR_VERSION = 6; -static const char* PRODUCER_VERSION = "1.7"; +static const char* PRODUCER_VERSION = "1.8"; }} // namespace torch::onnx diff --git a/torch/csrc/python_headers.h b/torch/csrc/python_headers.h index 9c670b5202b0c..2a64bdd5c6ee0 100644 --- a/torch/csrc/python_headers.h +++ b/torch/csrc/python_headers.h @@ -1,5 +1,5 @@ #pragma once - +#include // workaround for Python 2 issue: https://bugs.python.org/issue17120 // NOTE: It looks like this affects Python 3 as well. #pragma push_macro("_XOPEN_SOURCE") diff --git a/torch/csrc/tensor/python_tensor.cpp b/torch/csrc/tensor/python_tensor.cpp index 46fd6d595e6a6..9b2b522c51d7f 100644 --- a/torch/csrc/tensor/python_tensor.cpp +++ b/torch/csrc/tensor/python_tensor.cpp @@ -76,8 +76,9 @@ static PyObject* Tensor_new(PyTypeObject *type, PyObject *args, PyObject *kwargs // instanceof(t, torch.FloatTensor) work, but we are not going to keep // adding torch.QuantizedIntTensor classes for every new tensor type // we add... -static PyObject* Tensor_instancecheck(PyTensorType* self, PyObject* arg) { +static PyObject* Tensor_instancecheck(PyObject* _self, PyObject* arg) { HANDLE_TH_ERRORS + auto self = (PyTensorType*)_self; if (THPVariable_Check(arg)) { auto& var = ((THPVariable*)arg)->cdata; // NB: This is a little unfortunate, in that if I do an isinstance check @@ -123,7 +124,7 @@ PyObject *Tensor_is_sparse(PyTensorType *self, void *unused) { } static struct PyMethodDef metaclass_methods[] = { - {"__instancecheck__", (PyCFunction)Tensor_instancecheck, METH_O, nullptr}, + {"__instancecheck__", Tensor_instancecheck, METH_O, nullptr}, {nullptr} }; diff --git a/torch/csrc/utils.h b/torch/csrc/utils.h index c6ef331de2b99..010ea498c3e78 100644 --- a/torch/csrc/utils.h +++ b/torch/csrc/utils.h @@ -115,6 +115,9 @@ #define THPQInt32Utils_checkReal(object) THPUtils_checkReal_INT(object) #define THPQInt32Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) #define THPQInt32Utils_newReal(value) THPUtils_newReal_INT(value) +#define THPQUInt4x2Utils_checkReal(object) THPUtils_checkReal_INT(object) +#define THPQUInt4x2Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) +#define THPQUInt4x2Utils_newReal(value) THPUtils_newReal_INT(value) #define THPUtils_assert(cond, ...) THPUtils_assertRet(nullptr, cond, __VA_ARGS__) diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp index ff7478ac5f39d..28414dc911d19 100644 --- a/torch/csrc/utils/disable_torch_function.cpp +++ b/torch/csrc/utils/disable_torch_function.cpp @@ -1,6 +1,7 @@ #include #include #include +#include namespace torch { static thread_local bool enable_torch_function = true; @@ -46,8 +47,8 @@ PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject *unused) { } static PyMethodDef DisableTorchFunction_methods[] = { // NOLINT - {"__enter__", (PyCFunction)DisableTorchFunction__enter, METH_NOARGS, nullptr}, - {"__exit__", (PyCFunction)DisableTorchFunction__exit, METH_VARARGS, nullptr}, + {"__enter__", DisableTorchFunction__enter, METH_NOARGS, nullptr}, + {"__exit__", DisableTorchFunction__exit, METH_VARARGS, nullptr}, {nullptr, nullptr, 0, nullptr} }; @@ -125,3 +126,111 @@ PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *a) { return result; END_HANDLE_TH_ERRORS } + +// Makes sure that we don't check for __torch_function__ on basic Python types +static bool is_basic_python_type(PyTypeObject *tp) +{ + return ( + /* Basic number types */ + tp == &PyBool_Type || + + tp == &PyLong_Type || + tp == &PyFloat_Type || + tp == &PyComplex_Type || + + /* Basic sequence types */ + tp == &PyList_Type || + tp == &PyTuple_Type || + tp == &PyDict_Type || + tp == &PySet_Type || + tp == &PyFrozenSet_Type || + tp == &PyUnicode_Type || + tp == &PyBytes_Type || + + /* other builtins */ + tp == &PySlice_Type || + tp == Py_TYPE(Py_None) || + tp == Py_TYPE(Py_Ellipsis) || + tp == Py_TYPE(Py_NotImplemented) || + + PyModule_Check(tp) || + /* sentinel to swallow trailing || */ + false + ); +} + +inline bool has_torch_function_attr(PyObject* obj) { + auto attr = PyObject_FastGetAttrString(obj, "__torch_function__"); + return ( + attr.ptr() != nullptr && + attr.ptr() != torch::disabled_torch_function); +} + +namespace torch { +auto check_has_torch_function(PyObject* obj) -> bool +{ + PyTypeObject *tp = Py_TYPE(obj); + return ( + !THPVariable_CheckTypeExact(tp) && + !is_basic_python_type(tp) && + torch::torch_function_enabled() && + has_torch_function_attr(obj) + ); +} +} // namespace torch + +inline bool sequence_has_torch_function(PyObject* args) { + Py_ssize_t nargs = PySequence_Fast_GET_SIZE(args); + for (Py_ssize_t i = 0; i < nargs; i++) { + PyObject* obj = PySequence_Fast_GET_ITEM(args, i); + if (torch::check_has_torch_function(obj)) { + return true; + } + } + return false; +} + +inline bool array_has_torch_function(PyObject *const *args, Py_ssize_t nargs) { + for (Py_ssize_t i = 0; i < nargs; i++) { + if (torch::check_has_torch_function(args[i])) { + return true; + } + } + return false; +} + +PyObject* THPModule_has_torch_function(PyObject*, PyObject *arg) { + bool result; // NOLINT(cppcoreguidelines-init-variables) + if (PyTuple_CheckExact(arg) || PyList_CheckExact(arg)) { + // Fast path: + // If we know that we have a tuple or list, we can skip an INCREF and + // DECREF from PySequence_Fast. Core functions will always follow this + // convention (almost always tuples), and it shaves ~3.5% off the cost of + // the check. + result = sequence_has_torch_function(arg); + } else { + auto args = py::reinterpret_steal( + PySequence_Fast(arg, "expected a sequence")); + result = sequence_has_torch_function(args.ptr()); + } + + if (result) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} + +PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject *obj) { + // Special case `THPModule_has_torch_function` for the single arg case. + if (torch::check_has_torch_function(obj)) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} + +PyObject* THPModule_has_torch_function_variadic(PyObject*, PyObject *const *args, Py_ssize_t nargs) { + if (array_has_torch_function(args, nargs)) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} diff --git a/torch/csrc/utils/disable_torch_function.h b/torch/csrc/utils/disable_torch_function.h index 12166607f79c8..c4f9651d5342f 100644 --- a/torch/csrc/utils/disable_torch_function.h +++ b/torch/csrc/utils/disable_torch_function.h @@ -9,8 +9,12 @@ namespace torch { bool torch_function_enabled(); PyObject* disabled_torch_function_impl(); void set_disabled_torch_function_impl(PyObject* value); + bool check_has_torch_function(PyObject* obj); } PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject *unused); PyObject* THPModule_DisableTorchFunctionType(); -PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *args); \ No newline at end of file +PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *args); +PyObject* THPModule_has_torch_function(PyObject*, PyObject *arg); +PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject *obj); +PyObject* THPModule_has_torch_function_variadic(PyObject*, PyObject *const *args, Py_ssize_t nargs); diff --git a/torch/csrc/utils/future.h b/torch/csrc/utils/future.h index 6d672ee86cd56..093d043ecf7d4 100644 --- a/torch/csrc/utils/future.h +++ b/torch/csrc/utils/future.h @@ -26,7 +26,7 @@ class TORCH_API FutureError final : public std::exception { // Most implementation is copied from FutureMessage and // c10::ivalue::Future template -class TORCH_API Future final { +class TORCH_PYTHON_API Future final { public: Future() = default; diff --git a/torch/csrc/utils/out_types.cpp b/torch/csrc/utils/out_types.cpp new file mode 100644 index 0000000000000..0ceeb43bd1f81 --- /dev/null +++ b/torch/csrc/utils/out_types.cpp @@ -0,0 +1,39 @@ +#include + +namespace torch { +namespace utils { + +// Used by python binding codegen to ensure any TensorOptions arguments are consistent +// with the out tensor's options +void check_out_type_matches(const at::Tensor& result, + at::ScalarType scalarType, bool scalarType_is_none, + c10::optional layout, + const at::Device& device, bool device_is_none) { + if (scalarType_is_none && !layout && device_is_none) { // common case + return; + } + if (!scalarType_is_none && result.scalar_type() != scalarType) { + AT_ERROR( + "dtype ", scalarType, + " does not match dtype of out parameter (", result.scalar_type(), ")"); + } + auto scalarType_arg = scalarType_is_none ? result.scalar_type() : scalarType; + auto device_type_arg = device_is_none ? result.device().type() : device.type(); + if (result.scalar_type() != scalarType_arg) { + AT_ERROR( + "scalar type ", scalarType_arg, + " does not match scalar type of out parameter (", result.scalar_type(), ")"); + } + if (layout && result.layout() != *layout) { + AT_ERROR( + "layout ", *layout, + " does not match layout of out parameter (", result.layout(), ")"); + } + if (result.device().type() != device_type_arg) { + AT_ERROR( + "device type ", device_type_arg, + " does not match device type of out parameter (", result.device().type(), ")"); + } +} + +}} diff --git a/torch/csrc/utils/out_types.h b/torch/csrc/utils/out_types.h new file mode 100644 index 0000000000000..adc3686a6b974 --- /dev/null +++ b/torch/csrc/utils/out_types.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace torch { +namespace utils { + +TORCH_API void check_out_type_matches( + const at::Tensor& result, + at::ScalarType scalarType, bool scalarType_is_none, + c10::optional layout, + const at::Device& device, bool device_is_none); + +}} diff --git a/torch/csrc/utils/pybind.h b/torch/csrc/utils/pybind.h index 7df518f404c59..1447508535e58 100644 --- a/torch/csrc/utils/pybind.h +++ b/torch/csrc/utils/pybind.h @@ -17,6 +17,11 @@ namespace py = pybind11; +// This makes intrusive_ptr to be available as a custom pybind11 holder type, +// see +// https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers +PYBIND11_DECLARE_HOLDER_TYPE(T, c10::intrusive_ptr, true); + namespace pybind11 { namespace detail { // torch.autograd.Variable <-> at::Tensor conversions (without unwrapping) diff --git a/torch/csrc/utils/pycfunction_helpers.h b/torch/csrc/utils/pycfunction_helpers.h new file mode 100644 index 0000000000000..099ba0820a86e --- /dev/null +++ b/torch/csrc/utils/pycfunction_helpers.h @@ -0,0 +1,7 @@ +#pragma once + +#include + +inline PyCFunction castPyCFunctionWithKeywords(PyCFunctionWithKeywords func) { + return (PyCFunction)(void(*)(void))func; +} diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index e954bef398e9c..4208f653e05d7 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -24,6 +24,7 @@ static std::unordered_map type_map = { {"double", ParameterType::DOUBLE}, {"complex", ParameterType::COMPLEX}, {"TensorList", ParameterType::TENSOR_LIST}, + {"c10::List>", ParameterType::TENSOR_LIST}, {"IntArrayRef", ParameterType::INT_LIST}, {"ArrayRef", ParameterType::FLOAT_LIST}, {"Generator", ParameterType::GENERATOR}, @@ -35,9 +36,11 @@ static std::unordered_map type_map = { {"MemoryFormat", ParameterType::MEMORY_FORMAT}, {"QScheme", ParameterType::QSCHEME}, {"Device", ParameterType::DEVICE}, + {"Stream", ParameterType::STREAM}, {"std::string", ParameterType::STRING}, {"Dimname", ParameterType::DIMNAME}, {"DimnameList", ParameterType::DIMNAME_LIST}, + {"ScalarList", ParameterType::SCALAR_LIST}, }; // Default arg name translations for compatibility with NumPy. @@ -138,7 +141,7 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only) auto handle_torch_function_getter(THPVariable* self, const std::string& property_name) -> PyObject* { py::object torch_api = PyObject_FastGetAttrString(THPVariableClass, (char*)property_name.c_str()); std::string module_name = "torch.Tensor." + property_name; - return handle_torch_function((PyObject *)self, "__get__", nullptr, torch_api.ptr(), module_name); + return handle_torch_function((PyObject *)self, "__get__", nullptr, nullptr, torch_api.ptr(), module_name); } auto handle_torch_function_setter(THPVariable* self, const std::string& property_name, PyObject* value) -> int { @@ -147,10 +150,10 @@ auto handle_torch_function_setter(THPVariable* self, const std::string& property if (value != nullptr) { py::tuple args_ = py::make_tuple(py::handle(value)); - handle_torch_function((PyObject *)self, "__set__", args_.ptr(), torch_api.ptr(), module_name); + handle_torch_function((PyObject *)self, "__set__", args_.ptr(), nullptr, torch_api.ptr(), module_name); } else { - handle_torch_function((PyObject *)self, "__delete__", nullptr, torch_api.ptr(), module_name); + handle_torch_function((PyObject *)self, "__delete__", nullptr, nullptr, torch_api.ptr(), module_name); } return 0; } @@ -174,13 +177,13 @@ auto combine_self_args(PyObject *self, PyObject *args) -> py::tuple { return args_; } -auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args, PyObject* torch_api, const std::string& module_name) -> PyObject* { +auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args, PyObject* kwargs, PyObject* torch_api, const std::string& module_name) -> PyObject* { py::object torch_api_function = PyObject_FastGetAttrString(torch_api, (char*)func_name.c_str()); TORCH_INTERNAL_ASSERT(torch_api_function.ptr() != nullptr, "torch API function must exist"); py::tuple args_ = combine_self_args(self, args); py::tuple py_types = py::make_tuple(py::handle(PyObject_Type(self))); py::object torch_function = PyObject_FastGetAttrString(self, "__torch_function__"); - py::object ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), py_types.ptr(), args_.ptr(), NULL)); + py::object ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), py_types.ptr(), args_.ptr(), kwargs)); if (ret.ptr() == nullptr) { // if an exception occurred in a user's implementation of // __torch_function__, throw it @@ -331,7 +334,7 @@ void append_overloaded_arg(std::vector* overloaded_args, PyObject* o bool is_tensor_and_append_overloaded(PyObject* obj, std::vector* overloaded_args) { if (THPVariable_CheckExact(obj)) { - // torch.Tensor instances (not subclasses) + // torch.Tensor instances (not subclasses, except for Parameter) return true; } @@ -347,13 +350,28 @@ bool is_tensor_and_append_overloaded(PyObject* obj, std::vector* ove return false; } -bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector* overloaded_args, int argnum, bool throw_error) { +bool is_scalar_list(PyObject* obj) { auto tuple = six::isTuple(obj); if (!(tuple || PyList_Check(obj))) { return false; } auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj); for (size_t idx = 0; idx < size; idx++) { + PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx); + if (!THPUtils_checkScalar(iobj)) { + return false; + } + } + return true; +} + +bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector* overloaded_args, int argnum, bool throw_error) { + auto tuple = six::isTuple(obj); + if (!(tuple || PyList_Check(obj))) { + return false; + } + auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj); + for (long idx = 0; idx < size; idx++) { PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx); if (!is_tensor_and_append_overloaded(iobj, overloaded_args)) { if (throw_error) { @@ -366,6 +384,23 @@ bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector return true; } +bool is_float_or_complex_list(PyObject* obj) { + auto tuple = six::isTuple(obj); + if (!(tuple || PyList_Check(obj))) { + return false; + } + + auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj); + if (size > 0) { + PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0); + if (!THPUtils_checkDouble(iobj) && !PyComplex_Check(iobj)) { + return false; + } + } + + return true; +} + // argnum is needed for raising the TypeError, it's used in the error message. auto FunctionParameter::check(PyObject* obj, std::vector &overloaded_args, int argnum) -> bool { @@ -420,7 +455,7 @@ auto FunctionParameter::check(PyObject* obj, std::vector &overloaded // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single int return size > 0 && THPUtils_checkLong(obj); } - case ParameterType::FLOAT_LIST: return (PyTuple_Check(obj) || PyList_Check(obj)); + case ParameterType::FLOAT_LIST: return is_float_or_complex_list(obj); case ParameterType::GENERATOR: return THPGenerator_Check(obj); case ParameterType::BOOL: return PyBool_Check(obj); case ParameterType::STORAGE: return isStorage(obj); @@ -431,8 +466,13 @@ auto FunctionParameter::check(PyObject* obj, std::vector &overloaded case ParameterType::QSCHEME: return THPQScheme_Check(obj); case ParameterType::DEVICE: return THPUtils_checkLong(obj) || THPUtils_checkString(obj) || THPDevice_Check(obj); + case ParameterType::STREAM: + return THPStream_Check(obj); case ParameterType::STRING: return THPUtils_checkString(obj); default: throw std::runtime_error("unknown parameter type"); + case ParameterType::SCALAR_LIST: { + return is_scalar_list(obj); + } } } @@ -458,6 +498,7 @@ std::string FunctionParameter::type_name() const { case ParameterType::STRING: return "str"; case ParameterType::DIMNAME: return "name"; case ParameterType::DIMNAME_LIST: return "tuple of names"; + case ParameterType::SCALAR_LIST: return "tuple of Scalars"; default: throw std::runtime_error("unknown parameter type"); } } @@ -503,6 +544,62 @@ static inline std::vector parse_intlist_args(const std::string& s, int6 return args; } +// Parse a string literal to remove quotes and escape sequences +static std::string parse_string_literal(c10::string_view str) { + TORCH_CHECK(str.length() >= 2, "String defaults must be quoted"); + + if (str.front() == '"') { + TORCH_CHECK(str.back() == '"', + "Mismatched quotes in string default: ", str); + } else { + TORCH_CHECK(str.front() == '\'' && str.back() == '\'', + "Invalid quotes in string default: ", str) + } + + std::string parsed; + parsed.reserve(str.size()); + for (size_t i = 1; i < str.size() - 1;) { + if (str[i] != '\\') { + parsed.push_back(str[i]); + ++i; + continue; + } + + // Handle escape sequences + TORCH_CHECK(i < str.size() - 2, "String ends with escaped final quote: ", str) + char c = str[i + 1]; + switch (c) { + case '\\': + case '\'': + case '\"': + break; + case 'a': + c = '\a'; + break; + case 'b': + c = '\b'; + break; + case 'f': + c = '\f'; + break; + case 'n': + c = '\n'; + break; + case 'v': + c = '\v'; + break; + case 't': + c = '\t'; + break; + default: + TORCH_CHECK(false, "Unsupported escape sequence in string default: \\", str[i + 1]); + } + parsed.push_back(c); + i += 2; + } + return parsed; +} + void FunctionParameter::set_default_str(const std::string& str) { if (str == "None") { allow_none = true; @@ -557,9 +654,13 @@ void FunctionParameter::set_default_str(const std::string& str) { if (str != "None") { throw std::runtime_error("invalid device: " + str); } + } else if (type_ == ParameterType::STREAM) { + if (str != "None") { + throw std::runtime_error("invalid stream: " + str); + } } else if (type_ == ParameterType::STRING) { - if (str != "None" && str != "") { - throw std::runtime_error("invalid default string: " + str); + if (str != "None") { + default_string = parse_string_literal(str); } } } @@ -761,7 +862,7 @@ bool FunctionSignature::parse(PyObject* self, PyObject* args, PyObject* kwargs, } int i = 0; - if (self != nullptr && !THPVariable_CheckExact(self) && check_has_torch_function(self)) { + if (self != nullptr && check_has_torch_function(self)) { append_overloaded_arg(&this->overloaded_args, self); } for (auto& param : params) { @@ -901,6 +1002,7 @@ PythonArgs PythonArgParser::raw_parse(PyObject* self, PyObject* args, PyObject* print_error(self, args, kwargs, parsed_args); } + void PythonArgParser::print_error(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) { // NOLINT auto num_args = PyTuple_GET_SIZE(args) + (kwargs ? PyDict_Size(kwargs) : 0); std::vector plausible_idxs; @@ -974,24 +1076,28 @@ at::Scalar PythonArgs::scalar_slow(int i) { signature.params[i].name, idx, var, jit::NumberType::get()); } + return scalar_slow(args[i]); +} + +at::Scalar PythonArgs::scalar_slow(PyObject* arg) { // Zero-dim tensors are converted to Scalars as-is. Note this doesn't currently // handle most NumPy scalar types except np.float64. - if (THPVariable_Check(args[i])) { - return ((THPVariable*)args[i])->cdata.item(); + if (THPVariable_Check(arg)) { + return ((THPVariable*)arg)->cdata.item(); } - if (THPUtils_checkLong(args[i])) { - return at::Scalar(static_cast(THPUtils_unpackLong(args[i]))); + if (THPUtils_checkLong(arg)) { + return at::Scalar(static_cast(THPUtils_unpackLong(arg))); } - if (PyBool_Check(args[i])) { - return at::Scalar(THPUtils_unpackBool(args[i])); + if (PyBool_Check(arg)) { + return at::Scalar(THPUtils_unpackBool(arg)); } - if (PyComplex_Check(args[i])) { - return at::Scalar(THPUtils_unpackComplexDouble(args[i])); + if (PyComplex_Check(arg)) { + return at::Scalar(THPUtils_unpackComplexDouble(arg)); } - return at::Scalar(THPUtils_unpackDouble(args[i])); + return at::Scalar(THPUtils_unpackDouble(arg)); } } // namespace torch diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 78efb6cf2db3d..0f7f595f57f94 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -42,6 +42,7 @@ #include +#include #include #include #include @@ -78,8 +79,8 @@ namespace torch { enum class ParameterType { TENSOR, SCALAR, INT64, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR, - BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STRING, - DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST + BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STREAM, STRING, + DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST, SCALAR_LIST }; struct FunctionParameter; @@ -157,7 +158,9 @@ struct PythonArgs { inline c10::optional optionalTensor(int i); inline at::Scalar scalar(int i); inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar); + inline std::vector scalarlist(int i); inline std::vector tensorlist(int i); + inline torch::List> list_of_optional_tensors(int i); template inline std::array tensorlist_n(int i); inline std::vector intlist(int i); @@ -165,6 +168,7 @@ struct PythonArgs { inline std::vector intlistWithDefault(int i, std::vector default_intlist); inline c10::optional generator(int i); inline at::Storage storage(int i); + inline c10::Stream stream(int i); inline at::ScalarType scalartype(int i); inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype); inline c10::optional scalartypeOptional(int i); @@ -173,6 +177,8 @@ struct PythonArgs { inline c10::optional toBoolOptional(int i); inline c10::optional toDoubleOptional(int i); inline c10::OptionalArray doublelistOptional(int i); + inline std::vector doublelist(int i); + inline std::vector getDoublelist(int i); inline at::Layout layout(int i); inline at::Layout layoutWithDefault(int i, at::Layout default_layout); inline c10::optional layoutOptional(int i); @@ -186,6 +192,7 @@ struct PythonArgs { inline c10::optional memoryformatOptional(int i); inline at::QScheme toQScheme(int i); inline std::string string(int i); + inline std::string stringWithDefault(int i, const std::string& default_str); inline c10::optional stringOptional(int i); inline PyObject* pyobject(int i); inline int64_t toInt64(int i); @@ -201,6 +208,7 @@ struct PythonArgs { private: at::Tensor tensor_slow(int i); at::Scalar scalar_slow(int i); + at::Scalar scalar_slow(PyObject* arg); }; struct FunctionParameter { @@ -224,6 +232,7 @@ struct FunctionParameter { at::SmallVector numpy_python_names; at::Scalar default_scalar; std::vector default_intlist; + std::string default_string; union { bool default_bool; int64_t default_int; @@ -281,6 +290,19 @@ inline at::Scalar PythonArgs::scalar(int i) { return scalar_slow(i); } +inline std::vector PythonArgs::scalarlist(int i) { + if (!args[i]) return std::vector(); + auto tuple = six::isTuple(args[i]); + THPObjectPtr arg = six::maybeAsTuple(args[i]); + auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get()); + std::vector res(size); + for (int idx = 0; idx < size; idx++) { + PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx); + res[idx] = scalar_slow(obj); + } + return res; +} + inline at::Scalar PythonArgs::scalarWithDefault(int i, at::Scalar default_scalar) { if (!args[i]) return default_scalar; return scalar_slow(i); @@ -306,6 +328,22 @@ inline std::vector PythonArgs::tensorlist(int i) { return res; } +inline torch::List> PythonArgs::list_of_optional_tensors(int i) { + if (!args[i]) return torch::List>(); + auto tuple = six::isTuple(args[i]); + THPObjectPtr arg = six::maybeAsTuple(args[i]); + auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get()); + torch::List> res; + res.reserve(size); + for (int idx = 0; idx < size; idx++) { + PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx); + // This is checked by the argument parser so it's safe to cast without checking + // if this is a tensor first + res.push_back(reinterpret_cast(obj)->cdata); + } + return res; +} + template inline std::array PythonArgs::tensorlist_n(int i) { auto res = std::array(); @@ -369,10 +407,7 @@ inline c10::OptionalArray PythonArgs::intlistOptional(int i) { return intlist(i); } -inline c10::OptionalArray PythonArgs::doublelistOptional(int i) { - if (!args[i]) { - return {}; - } +inline std::vector PythonArgs::getDoublelist(int i) { PyObject* arg = args[i]; auto tuple = PyTuple_Check(arg); auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg); @@ -390,6 +425,20 @@ inline c10::OptionalArray PythonArgs::doublelistOptional(int i) { return res; } +inline c10::OptionalArray PythonArgs::doublelistOptional(int i) { + if (!args[i]) { + return {}; + } + return this->getDoublelist(i); +} + +inline std::vector PythonArgs::doublelist(int i) { + if (!args[i]) { + return {}; + } + return this->getDoublelist(i); +} + inline at::ScalarType PythonArgs::scalartypeWithDefault(int i, at::ScalarType default_scalartype) { if (!args[i]) return default_scalartype; return scalartype(i); @@ -517,7 +566,11 @@ inline at::QScheme PythonArgs::toQScheme(int i) { } inline std::string PythonArgs::string(int i) { - if (!args[i]) return ""; + return stringWithDefault(i, signature.params[i].default_string); +} + +inline std::string PythonArgs::stringWithDefault(int i, const std::string& default_str) { + if (!args[i]) return default_str; return THPUtils_unpackString(args[i]); } @@ -580,7 +633,7 @@ inline c10::complex PythonArgs::toComplex(int i) { inline c10::complex PythonArgs::toComplexWithDefault(int i, c10::complex default_value) { if (!args[i]) return default_value; - return toDouble(i); + return toComplex(i); } inline bool PythonArgs::toBool(int i) { @@ -607,130 +660,17 @@ inline at::Storage PythonArgs::storage(int i) { return createStorage(args[i]); } -inline PyObject* PythonArgs::pyobject(int i) { - if (!args[i]) return Py_None; - return args[i]; -} - -/* - * Reference: https://github.com/numpy/numpy/blob/f4c497c768e0646df740b647782df463825bfd27/numpy/core/src/common/get_attr_string.h#L42 - * - * Stripped down version of PyObject_GetAttrString, - * avoids lookups for None, tuple, and List objects, - * and doesn't create a PyErr since this code ignores it. - * - * This can be much faster then PyObject_GetAttrString where - * exceptions are not used by caller. - * - * 'obj' is the object to search for attribute. - * - * 'name' is the attribute to search for. - * - * Returns a py::object wrapping the return value. If the attribute lookup failed - * the value will be NULL. - * - */ - -static py::object PyObject_FastGetAttrString(PyObject *obj, char *name) -{ - PyTypeObject *tp = Py_TYPE(obj); - PyObject *res = (PyObject *)NULL; - - /* Attribute referenced by (char *)name */ - if (tp->tp_getattr != NULL) { - res = (*tp->tp_getattr)(obj, name); - if (res == NULL) { - PyErr_Clear(); - } - } - /* Attribute referenced by (PyObject *)name */ - else if (tp->tp_getattro != NULL) { - PyObject *w = THPUtils_internString(name); - if (w == NULL) { - return py::object(); - } - res = (*tp->tp_getattro)(obj, w); - Py_DECREF(w); - if (res == NULL) { - PyErr_Clear(); - } - } - return py::reinterpret_steal(res); -} - -// Makes sure that we don't check for __torch_function__ on basic Python types -static bool _is_basic_python_type(PyTypeObject *tp) -{ - return ( - /* Basic number types */ - tp == &PyBool_Type || - - tp == &PyLong_Type || - tp == &PyFloat_Type || - tp == &PyComplex_Type || - - /* Basic sequence types */ - tp == &PyList_Type || - tp == &PyTuple_Type || - tp == &PyDict_Type || - tp == &PySet_Type || - tp == &PyFrozenSet_Type || - tp == &PyUnicode_Type || - tp == &PyBytes_Type || - - /* other builtins */ - tp == &PySlice_Type || - tp == Py_TYPE(Py_None) || - tp == Py_TYPE(Py_Ellipsis) || - tp == Py_TYPE(Py_NotImplemented) || - - PyModule_Check(tp) || - /* sentinel to swallow trailing || */ - false - ); -} - -/* - * Lookup a special method, following the python approach of looking up - * on the type object, rather than on the instance itself. - * - * Assumes that the special method is a torch-specific one, so does not - * look at builtin types, nor does it look at a base Tensor. - * - * If no special method is found, return NULL, otherwise returns a new - * reference to the function object - * - * In future, could be made more like _Py_LookupSpecial - */ - -static py::object PyTorch_LookupSpecial(PyObject *obj, char* name) -{ - if (THPVariable_CheckExact(obj)) { - return py::object(); - } - PyTypeObject *tp = Py_TYPE(obj); - if (_is_basic_python_type(tp)) { - return py::object(); +inline c10::Stream PythonArgs::stream(int i) { + if (!args[i]) return c10::Stream(c10::Stream::Default::DEFAULT, c10::Device(DeviceType::CPU, -1)); + if (!THPStream_Check(args[i])) { + throw TypeError("expected Stream object. Got '%s'", Py_TYPE(args[i])->tp_name); } - return PyObject_FastGetAttrString((PyObject *)tp, name); + return c10::Stream::unpack(((THPStream*)args[i])->cdata); } -/* - * Checks if obj has a __torch_function__ implementation - * - * Returns true if an implementation is found and false otherwise - * - */ -static auto check_has_torch_function(PyObject* obj) -> bool -{ - if (!torch_function_enabled()) { - return false; - } - py::object method = PyTorch_LookupSpecial(obj, "__torch_function__"); - if(method.ptr() != nullptr && method.ptr() != disabled_torch_function_impl()){ - return true; - } - return false; +inline PyObject* PythonArgs::pyobject(int i) { + if (!args[i]) return Py_None; + return args[i]; } /* @@ -791,8 +731,8 @@ auto handle_torch_function(PythonArgs &r, PyObject* self, PyObject* args, PyObje // Used for functions which needs to parse python args. auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name) -> PyObject*; -// Used for functions that accept no keyword arguments and have no argument parsing -auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args=nullptr, PyObject* torch_api=THPVariableClass, const std::string& module_name="torch.Tensor") -> PyObject*; +// Used for functions that have no argument parsing. +auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args=nullptr, PyObject* kwargs=nullptr, PyObject* torch_api=THPVariableClass, const std::string& module_name="torch.Tensor") -> PyObject*; // Used for functions created in C++, e.g., C++ custom op, which doesn't use PythonArgParser to get overloaded_args. auto handle_torch_function_no_python_arg_parser(const std::vector &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name) -> PyObject*; diff --git a/torch/csrc/utils/python_compat.h b/torch/csrc/utils/python_compat.h index 28d990c64c421..48ce9c1bdbcc4 100644 --- a/torch/csrc/utils/python_compat.h +++ b/torch/csrc/utils/python_compat.h @@ -2,6 +2,33 @@ #include +#if PY_VERSION_HEX < 0x03070000 +// METH_FASTCALL was introduced in Python 3.7, so we wrap _PyCFunctionFast +// signatures for earlier versions. + +template +PyObject* maybe_wrap_fastcall(PyObject *module, PyObject *args) { + return f( + module, + + // _PyTuple_ITEMS + // Because this is only a compat shim for Python 3.6, we don't have + // to worry about the representation changing. + ((PyTupleObject *)args)->ob_item, + PySequence_Fast_GET_SIZE(args) + ); +} + +#define MAYBE_METH_FASTCALL METH_VARARGS +#define MAYBE_WRAP_FASTCALL(f) maybe_wrap_fastcall + +#else + +#define MAYBE_METH_FASTCALL METH_FASTCALL +#define MAYBE_WRAP_FASTCALL(f) (PyCFunction)(void(*)(void))f + +#endif + // PyPy 3.6 does not yet have PySlice_Unpack #if PY_VERSION_HEX < 0x03060100 || defined(PYPY_VERSION) @@ -63,20 +90,5 @@ __PySlice_Unpack(PyObject *_r, (PySlice_Unpack(SLICE, START, STOP, STEP) == 0) #endif -// https://bugsfiles.kde.org/attachment.cgi?id=61186 -#if PY_VERSION_HEX >= 0x03020000 #define THPUtils_parseSlice(SLICE, LEN, START, STOP, LENGTH, STEP) \ (PySlice_GetIndicesEx(SLICE, LEN, START, STOP, LENGTH, STEP) == 0) -#else -#define THPUtils_parseSlice(SLICE, LEN, START, STOP, LENGTH, STEP) \ - (PySlice_GetIndicesEx((PySliceObject*)SLICE, LEN, START, STOP, LENGTH, STEP) == 0) -#endif - -// This function was introduced in Python 3.4 -#if PY_VERSION_HEX < 0x03040000 -inline int -PyGILState_Check() { - PyThreadState * tstate = _PyThreadState_Current; - return tstate && (tstate == PyGILState_GetThisThreadState()); -} -#endif diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index f0f63bf7a2f08..0ccb62dd52f0d 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -33,6 +33,7 @@ c10::optional parseDispatchKey(const std::string& k) { {"QuantizedCPU", c10::DispatchKey::QuantizedCPU}, {"Math", c10::DispatchKey::Math}, {"Autograd", c10::DispatchKey::Autograd}, + {"DefaultBackend", c10::DispatchKey::DefaultBackend}, {"AutogradCPU", c10::DispatchKey::AutogradCPU}, {"", c10::DispatchKey::Undefined}, }; diff --git a/torch/csrc/utils/python_numbers.h b/torch/csrc/utils/python_numbers.h index 9cb4028bb2a92..391dddf14c5d4 100644 --- a/torch/csrc/utils/python_numbers.h +++ b/torch/csrc/utils/python_numbers.h @@ -6,15 +6,24 @@ #include #include #include +#include #include // largest integer that can be represented consecutively in a double const int64_t DOUBLE_INT_MAX = 9007199254740992; +inline PyObject* THPUtils_packInt32(int32_t value) { + return PyLong_FromLong(value); +} + inline PyObject* THPUtils_packInt64(int64_t value) { return PyLong_FromLongLong(value); } +inline PyObject* THPUtils_packUInt32(uint32_t value) { + return PyLong_FromUnsignedLong(value); +} + inline PyObject* THPUtils_packUInt64(uint64_t value) { return PyLong_FromUnsignedLongLong(value); } @@ -33,6 +42,22 @@ inline bool THPUtils_checkLong(PyObject* obj) { return PyLong_Check(obj) && !PyBool_Check(obj); } +inline int32_t THPUtils_unpackInt(PyObject* obj) { + int overflow; + long value = PyLong_AsLongAndOverflow(obj, &overflow); + if (value == -1 && PyErr_Occurred()) { + throw python_error(); + } + if (overflow != 0) { + throw std::runtime_error("Overflow when unpacking long"); + } + if (value > std::numeric_limits::max() || + value < std::numeric_limits::min()) { + throw std::runtime_error("Overflow when unpacking long"); + } + return (int32_t)value; +} + inline int64_t THPUtils_unpackLong(PyObject* obj) { int overflow; long long value = PyLong_AsLongLongAndOverflow(obj, &overflow); @@ -45,6 +70,17 @@ inline int64_t THPUtils_unpackLong(PyObject* obj) { return (int64_t)value; } +inline uint32_t THPUtils_unpackUInt32(PyObject* obj) { + unsigned long value = PyLong_AsUnsignedLong(obj); + if (PyErr_Occurred()) { + throw python_error(); + } + if (value > std::numeric_limits::max()) { + throw std::runtime_error("Overflow when unpacking unsigned long"); + } + return (uint32_t)value; +} + inline uint64_t THPUtils_unpackUInt64(PyObject* obj) { unsigned long long value = PyLong_AsUnsignedLongLong(obj); if (PyErr_Occurred()) { diff --git a/torch/csrc/utils/python_strings.h b/torch/csrc/utils/python_strings.h index 55cf0c3df7c42..66e5bf154834e 100644 --- a/torch/csrc/utils/python_strings.h +++ b/torch/csrc/utils/python_strings.h @@ -4,6 +4,7 @@ #include #include #include +#include // Utilities for handling Python strings. Note that PyString, when defined, is // the same as PyBytes. @@ -54,3 +55,49 @@ inline bool THPUtils_isInterned(PyObject* obj) { inline void THPUtils_internStringInPlace(PyObject** obj) { PyUnicode_InternInPlace(obj); } + +/* + * Reference: https://github.com/numpy/numpy/blob/f4c497c768e0646df740b647782df463825bfd27/numpy/core/src/common/get_attr_string.h#L42 + * + * Stripped down version of PyObject_GetAttrString, + * avoids lookups for None, tuple, and List objects, + * and doesn't create a PyErr since this code ignores it. + * + * This can be much faster then PyObject_GetAttrString where + * exceptions are not used by caller. + * + * 'obj' is the object to search for attribute. + * + * 'name' is the attribute to search for. + * + * Returns a py::object wrapping the return value. If the attribute lookup failed + * the value will be NULL. + * + */ + +static py::object PyObject_FastGetAttrString(PyObject *obj, char *name) +{ + PyTypeObject *tp = Py_TYPE(obj); + PyObject *res = (PyObject *)nullptr; + + /* Attribute referenced by (char *)name */ + if (tp->tp_getattr != nullptr) { + res = (*tp->tp_getattr)(obj, name); + if (res == nullptr) { + PyErr_Clear(); + } + } + /* Attribute referenced by (PyObject *)name */ + else if (tp->tp_getattro != nullptr) { + auto w = py::reinterpret_steal( + THPUtils_internString(name)); + if (w.ptr() == nullptr) { + return py::object(); + } + res = (*tp->tp_getattro)(obj, w.ptr()); + if (res == nullptr) { + PyErr_Clear(); + } + } + return py::reinterpret_steal(res); +} diff --git a/torch/csrc/utils/six.h b/torch/csrc/utils/six.h index 932f0bf61a29d..b83e60c77cf3c 100644 --- a/torch/csrc/utils/six.h +++ b/torch/csrc/utils/six.h @@ -23,11 +23,7 @@ inline bool isTuple(pybind11::handle input) { if (PyTuple_Check(input.ptr())) { return true; } -#if PY_MAJOR_VERSION == 2 - return isStructSeq(input); -#else return false; -#endif } inline bool isTuple(PyObject* obj) { @@ -40,12 +36,8 @@ inline bool isTuple(PyObject* obj) { // But on Python 2, structseq is not a subtype of tuple, so we need to manually create a // new tuple object from structseq. inline THPObjectPtr maybeAsTuple(PyStructSequence *obj) { -#if PY_MAJOR_VERSION == 2 - return THPObjectPtr(torch::utils::structseq_slice(obj, 0, Py_SIZE(obj))); -#else Py_INCREF(obj); return THPObjectPtr((PyObject *)obj); -#endif } inline THPObjectPtr maybeAsTuple(PyObject *obj) { diff --git a/torch/csrc/utils/tensor_dtypes.cpp b/torch/csrc/utils/tensor_dtypes.cpp index d15e75fe8f38c..724b7e35d8a15 100644 --- a/torch/csrc/utils/tensor_dtypes.cpp +++ b/torch/csrc/utils/tensor_dtypes.cpp @@ -49,6 +49,8 @@ std::pair getDtypeNames( return std::make_pair("qint32", ""); case at::ScalarType::BFloat16: return std::make_pair("bfloat16", ""); + case at::ScalarType::QUInt4x2: + return std::make_pair("quint4x2", ""); default: throw std::runtime_error("Unimplemented scalar type"); } diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index ee86239aa9d53..87472f4cd81fc 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -270,7 +270,7 @@ Tensor internal_new_from_data( if (PyArray_Check(data)) { TORCH_CHECK(!pin_memory, "Can't pin tensor constructed from numpy"); - auto tensor = tensor_from_numpy(data); + auto tensor = tensor_from_numpy(data, /*warn_if_not_writeable=*/!copy_numpy); const auto& inferred_scalar_type = type_inference ? tensor.scalar_type() : scalar_type; auto device = device_opt.has_value() ? *device_opt : at::Device(computeDeviceType(dispatch_key)); pybind11::gil_scoped_release no_gil; @@ -360,9 +360,8 @@ void check_base_legacy_new(c10::DispatchKey dispatch_key, at::Layout expected_la void check_legacy_ctor_device(c10::DispatchKey dispatch_key, c10::optional device) { if (device.has_value()) { TORCH_CHECK(computeDeviceType(dispatch_key) == device.value().type(), - "legacy constructor for device type: ", computeDeviceType(dispatch_key), - " was passed device type: ", device.value().type(), - ", but device type must be: ", computeDeviceType(dispatch_key)); + "legacy constructor expects device type: ", computeDeviceType(dispatch_key), + "but device type: ", device.value().type(), " was passed"); } } diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index 8c17c2ac74924..c2a67f8df06bf 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -7,7 +7,7 @@ namespace torch { namespace utils { PyObject* tensor_to_numpy(const at::Tensor& tensor) { throw std::runtime_error("PyTorch was compiled without NumPy support"); } -at::Tensor tensor_from_numpy(PyObject* obj) { +at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable/*=true*/) { throw std::runtime_error("PyTorch was compiled without NumPy support"); } bool is_numpy_int(PyObject* obj) { @@ -125,13 +125,15 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) { return array.release(); } -at::Tensor tensor_from_numpy(PyObject* obj) { +at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable/*=true*/) { if (!PyArray_Check(obj)) { throw TypeError("expected np.ndarray (got %s)", Py_TYPE(obj)->tp_name); } auto array = (PyArrayObject*)obj; - if (!PyArray_ISWRITEABLE(array)) { + // warn_if_not_writable is true when a copy of numpy variable is created. + // the warning is suppressed when a copy is being created. + if (!PyArray_ISWRITEABLE(array) && warn_if_not_writeable) { TORCH_WARN_ONCE( "The given NumPy array is not writeable, and PyTorch does " "not support non-writeable tensors. This means you can write to the " diff --git a/torch/csrc/utils/tensor_numpy.h b/torch/csrc/utils/tensor_numpy.h index f984d6b93a905..c4c93637db54d 100644 --- a/torch/csrc/utils/tensor_numpy.h +++ b/torch/csrc/utils/tensor_numpy.h @@ -6,7 +6,7 @@ namespace torch { namespace utils { PyObject* tensor_to_numpy(const at::Tensor& tensor); -at::Tensor tensor_from_numpy(PyObject* obj); +at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable=true); int aten_to_numpy_dtype(const at::ScalarType scalar_type); at::ScalarType numpy_dtype_to_aten(int dtype); diff --git a/torch/csrc/utils/tensor_types.cpp b/torch/csrc/utils/tensor_types.cpp index e6b851a3a74c6..5d60f8b07c644 100644 --- a/torch/csrc/utils/tensor_types.cpp +++ b/torch/csrc/utils/tensor_types.cpp @@ -21,6 +21,7 @@ static const char* backend_to_string(const at::Backend& backend) { case at::Backend::CUDA: return "torch.cuda"; case at::Backend::SparseCPU: return "torch.sparse"; case at::Backend::SparseCUDA: return "torch.cuda.sparse"; + case at::Backend::QuantizedCPU: return "torch.quantized"; default: AT_ERROR("Unimplemented backend ", backend); } } diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 1176c6ee30601..7286387644adf 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -16,12 +16,12 @@ import threading from typing import List, Optional, Tuple, Union from ._utils import _get_device_index, _dummy_type -from .streams import Stream, Event +from .streams import Stream, Event, _Graph from .. import device as _device import torch._C try: - from torch._C import _cudart + from torch._C import _cudart # type: ignore except ImportError: _cudart = None @@ -30,18 +30,18 @@ _initialization_lock = threading.Lock() _queued_calls = [] # don't invoke these until initialization occurs _is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False) -_device_t = Union[_device, str, int] +_device_t = Union[_device, str, int, None] # Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA if hasattr(torch._C, '_CudaDeviceProperties'): _CudaDeviceProperties = torch._C._CudaDeviceProperties else: - _CudaDeviceProperties = _dummy_type('_CudaDeviceProperties') + _CudaDeviceProperties = _dummy_type('_CudaDeviceProperties') # type: ignore # Global variables dynamically populated by native code has_magma: bool = False has_half: bool = False -default_generators: Tuple[torch._C.Generator] = () +default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment] def is_available() -> bool: r"""Returns a bool indicating if CUDA is currently available.""" @@ -153,15 +153,9 @@ def _lazy_init(): # immediately, while we are still guaranteed to have the GIL, because some # of the C calls we make below will release the GIL if _is_in_bad_fork(): - from sys import version_info - if version_info < (3, 4): - msg = ("To use CUDA with multiprocessing, you must use Python " - "3.4+ and the 'spawn' start method") - else: - msg = ("To use CUDA with multiprocessing, you must use the " - "'spawn' start method") raise RuntimeError( - "Cannot re-initialize CUDA in forked subprocess. " + msg) + "Cannot re-initialize CUDA in forked subprocess. To use CUDA with " + "multiprocessing, you must use the 'spawn' start method") if not hasattr(torch._C, '_cuda_getDeviceCount'): raise AssertionError("Torch not compiled with CUDA enabled") if _cudart is None: @@ -210,7 +204,7 @@ def check_error(res: int) -> None: class device(object): r"""Context-manager that changes the selected device. - Arguments: + Args: device (torch.device or int): device index to select. It's a no-op if this argument is a negative integer or ``None``. """ @@ -239,7 +233,7 @@ class device_of(device): You can use both tensors and storages as arguments. If a given object is not allocated on a GPU, this is a no-op. - Arguments: + Args: obj (Tensor or Storage): object allocated on the selected device. """ @@ -254,7 +248,7 @@ def set_device(device: _device_t) -> None: Usage of this function is discouraged in favor of :any:`device`. In most cases it's better to use ``CUDA_VISIBLE_DEVICES`` environmental variable. - Arguments: + Args: device (torch.device or int): selected device. This function is a no-op if this argument is negative. """ @@ -266,11 +260,14 @@ def set_device(device: _device_t) -> None: def get_device_name(device: Optional[_device_t] = None) -> str: r"""Gets the name of a device. - Arguments: + Args: device (torch.device or int, optional): device for which to return the name. This function is a no-op if this argument is a negative integer. It uses the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). + + Returns: + str: the name of the device """ return get_device_properties(device).name @@ -278,7 +275,7 @@ def get_device_name(device: Optional[_device_t] = None) -> str: def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]: r"""Gets the cuda capability of a device. - Arguments: + Args: device (torch.device or int, optional): device for which to return the device capability. This function is a no-op if this argument is a negative integer. It uses the current device, given by @@ -293,11 +290,32 @@ def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int] def get_device_properties(device: _device_t) -> _CudaDeviceProperties: + r"""Gets the properties of a device. + + Args: + device (torch.device or int or str): device for which to return the + properties of the device. + + Returns: + _CudaDeviceProperties: the properties of the device + """ _lazy_init() # will define _get_device_properties device = _get_device_index(device, optional=True) if device < 0 or device >= device_count(): raise AssertionError("Invalid device id") - return _get_device_properties(device) + return _get_device_properties(device) # type: ignore[name-defined] + +def can_device_access_peer(device: _device_t, peer_device: _device_t) -> bool: + r"""Checks if peer access between two devices is possible. + """ + _lazy_init() + device = _get_device_index(device, optional=True) + peer_device = _get_device_index(peer_device) + if device < 0 or device >= device_count(): + raise AssertionError("Invalid device id") + if peer_device < 0 or peer_device >= device_count(): + raise AssertionError("Invalid peer device id") + return torch._C._cuda_canDeviceAccessPeer(device, peer_device) @contextlib.contextmanager @@ -307,7 +325,7 @@ def stream(stream): All CUDA kernels queued within its context will be enqueued on a selected stream. - Arguments: + Args: stream (Stream): selected stream. This manager is a no-op if it's ``None``. @@ -356,8 +374,8 @@ def get_gencode_flags() -> str: arch_list = get_arch_list() if len(arch_list) == 0: return "" - arch_list = [arch.split("_") for arch in arch_list] - return " ".join([f"-gencode compute=compute_{arch},code={kind}_{arch}" for (kind, arch) in arch_list]) + arch_list_ = [arch.split("_") for arch in arch_list] + return " ".join([f"-gencode compute=compute_{arch},code={kind}_{arch}" for (kind, arch) in arch_list_]) @@ -370,7 +388,7 @@ def current_device() -> int: def synchronize(device: _device_t = None) -> None: r"""Waits for all kernels in all streams on a CUDA device to complete. - Arguments: + Args: device (torch.device or int, optional): device for which to synchronize. It uses the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). @@ -396,7 +414,7 @@ def ipc_collect(): def current_stream(device: Optional[_device_t] = None) -> Stream: r"""Returns the currently selected :class:`Stream` for a given device. - Arguments: + Args: device (torch.device or int, optional): selected device. Returns the currently selected :class:`Stream` for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` @@ -410,7 +428,7 @@ def current_stream(device: Optional[_device_t] = None) -> Stream: def default_stream(device: Optional[_device_t] = None) -> Stream: r"""Returns the default :class:`Stream` for a given device. - Arguments: + Args: device (torch.device or int, optional): selected device. Returns the default :class:`Stream` for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` @@ -454,7 +472,7 @@ def current_blas_handle(): torch._C.__dict__['_CudaEventBase'] = _dummy_type('CudaEventBase') -@staticmethod +@staticmethod # type: ignore[misc] def _lazy_new(cls, *args, **kwargs): _lazy_init() # We may need to call lazy init again if we are a forked child @@ -467,8 +485,11 @@ class _CudaBase(object): is_sparse = False def type(self, *args, **kwargs): - with device(self.get_device()): - return super(_CudaBase, self).type(*args, **kwargs) + # We could use a Protocol here to tell mypy that self has `get_device` method + # but it is only available in the typing module on Python >= 3.8 + # or on typing_extensions module on Python >= 3.6 + with device(self.get_device()): # type: ignore + return super(_CudaBase, self).type(*args, **kwargs) # type: ignore[misc] __new__ = _lazy_new diff --git a/torch/cuda/_utils.py b/torch/cuda/_utils.py index 1239dff4588fe..8f4105623a98f 100644 --- a/torch/cuda/_utils.py +++ b/torch/cuda/_utils.py @@ -5,7 +5,7 @@ from torch._utils import _get_device_index as _torch_get_device_index -def _get_device_index(device: Union[Device, str, int], optional: bool = False, +def _get_device_index(device: Union[Device, str, int, None], optional: bool = False, allow_cpu: bool = False) -> int: r"""Gets the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``. diff --git a/torch/cuda/amp/autocast_mode.py b/torch/cuda/amp/autocast_mode.py index 8bac02fc39f09..99fdf6e03e838 100644 --- a/torch/cuda/amp/autocast_mode.py +++ b/torch/cuda/amp/autocast_mode.py @@ -1,7 +1,10 @@ import torch import functools import warnings -import numpy as np +try: + import numpy as np +except ModuleNotFoundError: + np = None from torch._six import container_abcs, string_classes @@ -106,7 +109,7 @@ def forward(self, input): :class:`torch.nn.parallel.DistributedDataParallel` when used with more than one GPU per process (see :ref:`Working with Multiple GPUs`). - Arguments: + Args: enabled(bool, optional, default=True): Whether autocasting should be enabled in the region. """ def __init__(self, enabled=True): @@ -144,12 +147,16 @@ def _cast(value, dtype): return value.to(dtype) if is_eligible else value elif isinstance(value, string_classes): return value - elif isinstance(value, np.ndarray): + elif np is not None and isinstance(value, np.ndarray): return value elif isinstance(value, container_abcs.Mapping): return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()} elif isinstance(value, container_abcs.Iterable): - return type(value)(_cast(v, dtype) for v in value) + iterable = map(lambda v: _cast(v, dtype), value) + if isinstance(value, list) or isinstance(value, tuple): + return type(value)(iterable) + else: + return iterable else: return value @@ -169,7 +176,7 @@ def custom_fwd(fwd=None, **kwargs): Helper decorator for ``forward`` methods of custom autograd functions (subclasses of :class:`torch.autograd.Function`). See the :ref:`example page` for more detail. - Arguments: + Args: cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``, when ``forward`` runs in an autocast-enabled region, casts incoming floating-point CUDA Tensors to the target dtype (non-floating-point Tensors are not affected), diff --git a/torch/cuda/amp/grad_scaler.py b/torch/cuda/amp/grad_scaler.py index 066ff1a0d3117..522b44cbf2466 100644 --- a/torch/cuda/amp/grad_scaler.py +++ b/torch/cuda/amp/grad_scaler.py @@ -3,18 +3,19 @@ from torch._six import container_abcs import warnings from enum import Enum +from typing import Any, Dict, List, Optional, Tuple class _MultiDeviceReplicator(object): """ Lazily serves copies of a tensor to requested devices. Copies are cached per-device. """ - def __init__(self, master_tensor): + def __init__(self, master_tensor: torch.Tensor) -> None: assert master_tensor.is_cuda self.master = master_tensor - self._per_device_tensors = {} + self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} - def get(self, device): + def get(self, device) -> torch.Tensor: retval = self._per_device_tensors.get(device, None) if retval is None: retval = self.master.to(device=device, non_blocking=True, copy=True) @@ -38,6 +39,9 @@ def _refresh_per_optimizer_state(): class GradScaler(object): + _scale: Optional[torch.Tensor] + _grows_tracker: Optional[torch.Tensor] + _per_optimizer_states: Dict[int, Dict[str, Any]] """ An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling conveniently. @@ -90,7 +94,7 @@ class GradScaler(object): value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). - Arguments: + Args: init_scale (float, optional, default=2.**16): Initial scale factor. growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. @@ -128,10 +132,11 @@ def __init__(self, self._growth_tracker = None self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - def _check_scale_growth_tracker(self, funcname): + def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]: fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix + return (self._scale, self._growth_tracker) def _lazy_init_scale_growth_tracker(self, dev): assert self._growth_tracker is None, "_growth_tracker initialized before _scale" @@ -145,7 +150,7 @@ def scale(self, outputs): Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned unmodified. - Arguments: + Args: outputs (Tensor or iterable of Tensors): Outputs to scale. """ if not self._enabled: @@ -156,21 +161,27 @@ def scale(self, outputs): assert outputs.is_cuda if self._scale is None: self._lazy_init_scale_growth_tracker(outputs.device) + assert self._scale is not None return outputs * self._scale.to(device=outputs.device, non_blocking=True) # Invoke the more complex machinery only if we're treating multiple outputs. - stash = [None] # trick to hold a reference that can be overwritten at any level of the recursion below. + stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale def apply_scale(val): if isinstance(val, torch.Tensor): assert val.is_cuda - if self._scale is None: - self._lazy_init_scale_growth_tracker(val.device) - if stash[0] is None: - stash[0] = _MultiDeviceReplicator(self._scale) + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + assert self._scale is not None + stash.append(_MultiDeviceReplicator(self._scale)) return val * stash[0].get(val.device) elif isinstance(val, container_abcs.Iterable): - return type(val)(apply_scale(v) for v in val) + iterable = map(apply_scale, val) + if isinstance(val, list) or isinstance(val, tuple): + return type(val)(iterable) + else: + return iterable else: raise ValueError("outputs must be a Tensor or an iterable of Tensors") @@ -180,27 +191,39 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): per_device_inv_scale = _MultiDeviceReplicator(inv_scale) per_device_found_inf = _MultiDeviceReplicator(found_inf) - for group in optimizer.param_groups: - for param in group["params"]: - if param.grad is not None: + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be hundreds of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] + with torch.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is None: + continue if (not allow_fp16) and param.grad.dtype == torch.float16: raise ValueError("Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is torch.float16: + param.grad = param.grad.coalesce() + to_unscale = param.grad._values() else: - with torch.no_grad(): - if param.grad.is_sparse: - # is_coalesced() == False means the sparse grad has values with duplicate indices. - # coalesce() deduplicates indices and adds all values that have the same index. - # For scaled fp16 values, there's a good chance coalescing will cause overflow, - # so we should check the coalesced _values(). - if param.grad.dtype is torch.float16: - param.grad = param.grad.coalesce() - to_unscale = param.grad._values() - else: - to_unscale = param.grad - - torch._amp_non_finite_check_and_unscale_(to_unscale, - per_device_found_inf.get(param.grad.device), - per_device_inv_scale.get(param.grad.device)) + to_unscale = param.grad + + # TODO: is there a way to split by device and dtype without appending in the inner loop? + per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale) + + for device, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._amp_foreach_non_finite_check_and_unscale_(grads, + per_device_found_inf.get(device), + per_device_inv_scale.get(device)) return per_device_found_inf._per_device_tensors @@ -222,7 +245,7 @@ def unscale_(self, optimizer): scaler.step(optimizer) scaler.update() - Arguments: + Args: optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. .. note:: @@ -249,6 +272,7 @@ def unscale_(self, optimizer): raise RuntimeError("unscale_() is being called after step().") # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None inv_scale = self._scale.double().reciprocal().float() found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device) @@ -268,7 +292,7 @@ def step(self, optimizer, *args, **kwargs): Returns the return value of ``optimizer.step(*args, **kwargs)``. - Arguments: + Args: optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. args: Any arguments. kwargs: Any keyword arguments. @@ -322,7 +346,7 @@ def update(self, new_scale=None): Passing ``new_scale`` sets the scale directly. - Arguments: + Args: new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. .. warning:: @@ -332,22 +356,22 @@ def update(self, new_scale=None): if not self._enabled: return - self._check_scale_growth_tracker("update") + _scale, _growth_tracker = self._check_scale_growth_tracker("update") if new_scale is not None: # Accept a new user-defined scale. if isinstance(new_scale, float): - self._scale = torch.full((1,), new_scale, dtype=torch.float32, device=self._scale.device) + self._scale = torch.full((1,), new_scale, dtype=torch.float32, device=_scale.device) else: reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." - assert isinstance(new_scale, torch.cuda.FloatTensor), reason + assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined] assert new_scale.numel() == 1, reason assert new_scale.requires_grad is False, reason self._scale = new_scale else: # Consume shared inf/nan data collected from optimizers to update the scale. # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. - found_infs = [found_inf.to(device=self._scale.device, non_blocking=True) + found_infs = [found_inf.to(device=_scale.device, non_blocking=True) for state in self._per_optimizer_states.values() for found_inf in state["found_inf_per_device"].values()] @@ -358,8 +382,8 @@ def update(self, new_scale=None): for i in range(1, len(found_infs)): found_inf_combined += found_infs[i] - self._scale = torch._amp_update_scale(self._growth_tracker, - self._scale, + self._scale = torch._amp_update_scale(_growth_tracker, + _scale, found_inf_combined, self._growth_factor, self._backoff_factor, @@ -391,7 +415,7 @@ def get_growth_factor(self): def set_growth_factor(self, new_factor): r""" - Arguments: + Args: new_scale (float): Value to use as the new scale growth factor. """ self._growth_factor = new_factor @@ -404,7 +428,7 @@ def get_backoff_factor(self): def set_backoff_factor(self, new_factor): r""" - Arguments: + Args: new_scale (float): Value to use as the new scale backoff factor. """ self._backoff_factor = new_factor @@ -417,7 +441,7 @@ def get_growth_interval(self): def set_growth_interval(self, new_interval): r""" - Arguments: + Args: new_interval (int): Value to use as the new growth interval. """ self._growth_interval = new_interval @@ -460,7 +484,7 @@ def load_state_dict(self, state_dict): r""" Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. - Arguments: + Args: state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. """ if not self._enabled: @@ -498,10 +522,10 @@ def __setstate__(self, state): self.__dict__.update(state) def _check_inf_per_device(self, optimizer): - self._check_scale_growth_tracker("_check_inf_per_device") + _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") - dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=self._scale.device) - found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device) + dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device) + found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device) self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) diff --git a/torch/cuda/comm.py b/torch/cuda/comm.py index f9856eda380fe..557ffb0c0de44 100644 --- a/torch/cuda/comm.py +++ b/torch/cuda/comm.py @@ -2,4 +2,4 @@ from torch.nn.parallel.comm import broadcast, broadcast_coalesced, reduce_add, \ reduce_add_coalesced, scatter, gather -__all__ = [broadcast, broadcast_coalesced, reduce_add, reduce_add_coalesced, scatter, gather] +__all__ = ['broadcast', 'broadcast_coalesced', 'reduce_add', 'reduce_add_coalesced', 'scatter', 'gather'] diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 6c2b1b867862d..c0bde95de7410 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -29,7 +29,7 @@ def caching_allocator_alloc(size, device: Union[Device, int] = None, stream=None frameworks. Allocated memory is released through :func:`~torch.cuda.caching_allocator_delete`. - Arguments: + Args: size (int): number of bytes to be allocated. device (torch.device or int, optional): selected device. If it is ``None`` the default CUDA device is used. @@ -62,7 +62,7 @@ def caching_allocator_delete(mem_ptr): is freed here. The associated device and stream are tracked inside the allocator. - Arguments: + Args: mem_ptr (int): memory address to be freed by the allocator. .. note:: @@ -72,6 +72,33 @@ def caching_allocator_delete(mem_ptr): torch._C._cuda_cudaCachingAllocator_raw_delete(mem_ptr) +def set_per_process_memory_fraction(fraction, device: Union[Device, int] = None) -> None: + r"""Set memory fraction for a process. + The fraction is used to limit an caching allocator to allocated memory on a CUDA device. + The allowed value equals the total visible memory multiplied fraction. + If trying to allocate more than the allowed value in a process, will raise an out of + memory error in allocator. + + Args: + fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction. + device (torch.device or int, optional): selected device. If it is + ``None`` the default CUDA device is used. + .. note:: + In general, the total available free memory is less than the total capacity. + """ + _lazy_init() + if device is None: + device = torch.cuda.current_device() + device = _get_device_index(device) + if not isinstance(fraction, float): + raise TypeError('Invalid type for fraction argument, must be `float`') + if fraction < 0 or fraction > 1: + raise ValueError('Invalid fraction value: {}. ' + 'Allowed range: 0~1'.format(fraction)) + + torch._C._cuda_setMemoryFraction(fraction, device) + + def empty_cache() -> None: r"""Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in @@ -137,7 +164,7 @@ def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]: result in a cache flush and retry. - ``"num_ooms"``: number of out-of-memory errors thrown. - Arguments: + Args: device (torch.device or int, optional): selected device. Returns statistics for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). @@ -177,7 +204,7 @@ def reset_accumulated_memory_stats(device: Union[Device, int] = None) -> None: the `"allocated"` and `"freed"` keys in each individual stat dict, as well as `"num_alloc_retries"` and `"num_ooms"`. - Arguments: + Args: device (torch.device or int, optional): selected device. Returns statistic for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). @@ -196,7 +223,7 @@ def reset_peak_memory_stats(device: Union[Device, int] = None) -> None: See :func:`~torch.cuda.memory_stats` for details. Peak stats correspond to the `"peak"` key in each individual stat dict. - Arguments: + Args: device (torch.device or int, optional): selected device. Returns statistic for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). @@ -215,7 +242,7 @@ def reset_max_memory_allocated(device: Union[Device, int] = None) -> None: See :func:`~torch.cuda.max_memory_allocated` for details. - Arguments: + Args: device (torch.device or int, optional): selected device. Returns statistic for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). @@ -241,7 +268,7 @@ def reset_max_memory_cached(device: Union[Device, int] = None) -> None: See :func:`~torch.cuda.max_memory_cached` for details. - Arguments: + Args: device (torch.device or int, optional): selected device. Returns statistic for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). @@ -265,7 +292,7 @@ def memory_allocated(device: Union[Device, int] = None) -> int: r"""Returns the current GPU memory occupied by tensors in bytes for a given device. - Arguments: + Args: device (torch.device or int, optional): selected device. Returns statistic for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). @@ -289,7 +316,7 @@ def max_memory_allocated(device: Union[Device, int] = None) -> int: functions can measure the peak allocated memory usage of each iteration in a training loop. - Arguments: + Args: device (torch.device or int, optional): selected device. Returns statistic for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). @@ -305,7 +332,7 @@ def memory_reserved(device: Union[Device, int] = None) -> int: r"""Returns the current GPU memory managed by the caching allocator in bytes for a given device. - Arguments: + Args: device (torch.device or int, optional): selected device. Returns statistic for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). @@ -327,7 +354,7 @@ def max_memory_reserved(device: Union[Device, int] = None) -> int: can measure the peak cached memory amount of each iteration in a training loop. - Arguments: + Args: device (torch.device or int, optional): selected device. Returns statistic for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). @@ -375,7 +402,7 @@ def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False) This can be useful to display periodically during training, or when handling out-of-memory exceptions. - Arguments: + Args: device (torch.device or int, optional): selected device. Returns printout for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). @@ -476,7 +503,7 @@ def list_gpu_processes(device: Union[Device, int] = None) -> str: This can be useful to display periodically during training, or when handling out-of-memory exceptions. - Arguments: + Args: device (torch.device or int, optional): selected device. Returns printout for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). diff --git a/torch/cuda/nccl.py b/torch/cuda/nccl.py index 9ce7a51f0dd1e..94108a3dadad8 100644 --- a/torch/cuda/nccl.py +++ b/torch/cuda/nccl.py @@ -3,6 +3,7 @@ import torch._six import torch.cuda +from typing import Optional, Sequence, Union __all__ = ['all_reduce', 'reduce', 'broadcast', 'all_gather', 'reduce_scatter'] @@ -43,7 +44,7 @@ def init_rank(num_ranks, uid, rank): return torch._C._nccl_init_rank(num_ranks, uid, rank) -def _check_sequence_type(inputs): +def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None: if not isinstance(inputs, collections.Container) or isinstance(inputs, torch.Tensor): raise TypeError("Inputs should be a collection of tensors") @@ -58,8 +59,15 @@ def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None): # `output` used to be `outputs`, taking in a list of tensors. So we have two # arguments for BC reasons. -def reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None, *, outputs=None): +def reduce(inputs: Sequence[torch.Tensor], + output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None, + root: int = 0, + op: int = SUM, + streams: Optional[Sequence[torch.cuda.Stream]] = None, + comms=None, *, + outputs: Optional[Sequence[torch.Tensor]] = None) -> None: _check_sequence_type(inputs) + _output: torch.Tensor if outputs is not None: if output is not None: raise ValueError( @@ -70,30 +78,33 @@ def reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None, *, out warnings.warn( "nccl.reduce with an output tensor list is deprecated. " "Please specify a single output tensor with argument 'output' instead instead.") - output = outputs[root] + _output = outputs[root] elif not isinstance(output, torch.Tensor) and isinstance(output, torch._six.container_abcs.Sequence): # User called old API with positional arguments of list of output tensors. warnings.warn( "nccl.reduce with an output tensor list is deprecated. " "Please specify a single output tensor.") - output = output[root] - elif output is None: - output = inputs[root] - torch._C._nccl_reduce(inputs, output, root, op, streams, comms) + _output = output[root] + else: + _output = inputs[root] if output is None else output + torch._C._nccl_reduce(inputs, _output, root, op, streams, comms) -def broadcast(inputs, root=0, streams=None, comms=None): +def broadcast(inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None) -> None: _check_sequence_type(inputs) torch._C._nccl_broadcast(inputs, root, streams, comms) -def all_gather(inputs, outputs, streams=None, comms=None): +def all_gather(inputs: Sequence[torch.Tensor], outputs: Sequence[torch.Tensor], streams=None, comms=None) -> None: _check_sequence_type(inputs) _check_sequence_type(outputs) torch._C._nccl_all_gather(inputs, outputs, streams, comms) -def reduce_scatter(inputs, outputs, op=SUM, streams=None, comms=None): +def reduce_scatter(inputs: Sequence[torch.Tensor], + outputs: Sequence[torch.Tensor], + op: int = SUM, + streams=None, comms=None) -> None: _check_sequence_type(inputs) _check_sequence_type(outputs) torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms) diff --git a/torch/cuda/nvtx.py b/torch/cuda/nvtx.py index e04c99f0124f2..4265eac9707d1 100644 --- a/torch/cuda/nvtx.py +++ b/torch/cuda/nvtx.py @@ -1,3 +1,5 @@ +from contextlib import contextmanager + try: from torch._C import _nvtx except ImportError: @@ -12,7 +14,7 @@ def _fail(*args, **kwargs): _nvtx = _NVTXStub() # type: ignore[assignment] -__all__ = ['range_push', 'range_pop', 'mark'] +__all__ = ['range_push', 'range_pop', 'mark', 'range'] def range_push(msg): @@ -20,7 +22,7 @@ def range_push(msg): Pushes a range onto a stack of nested range span. Returns zero-based depth of the range that is started. - Arguments: + Args: msg (string): ASCII message to associate with range """ return _nvtx.rangePushA(msg) @@ -38,7 +40,22 @@ def mark(msg): """ Describe an instantaneous event that occurred at some point. - Arguments: + Args: msg (string): ASCII message to associate with the event. """ return _nvtx.markA(msg) + + +@contextmanager +def range(msg, *args, **kwargs): + """ + Context manager / decorator that pushes an NVTX range at the beginning + of its scope, and pops it at the end. If extra arguments are given, + they are passed as arguments to msg.format(). + + Args: + msg (string): message to associate with the range + """ + range_push(msg.format(*args, **kwargs)) + yield + range_pop() diff --git a/torch/cuda/profiler.py b/torch/cuda/profiler.py index 25136acae12a0..95ae19d733b1e 100644 --- a/torch/cuda/profiler.py +++ b/torch/cuda/profiler.py @@ -26,7 +26,7 @@ def init(output_file, flags=None, output_mode='key_value'): else: raise RuntimeError("supported CUDA profiler output modes are: key_value and csv") with tempfile.NamedTemporaryFile(delete=True) as f: - f.write(b'\n'.join(map(lambda f: f.encode('ascii'), flags))) + f.write(b'\n'.join(f.encode('ascii') for f in flags)) f.flush() check_error(rt.cudaProfilerInitialize(f.name, output_file, output_mode_enum)) diff --git a/torch/cuda/streams.py b/torch/cuda/streams.py index 14345baf6abdd..5371f9ca34637 100644 --- a/torch/cuda/streams.py +++ b/torch/cuda/streams.py @@ -8,6 +8,7 @@ # Define dummy base classes torch._C.__dict__['_CudaStreamBase'] = _dummy_type('_CudaStreamBase') torch._C.__dict__['_CudaEventBase'] = _dummy_type('_CudaEventBase') + torch._C.__dict__['_CudaGraphBase'] = _dummy_type('_CudaGraphBase') class Stream(torch._C._CudaStreamBase): r"""Wrapper around a CUDA stream. @@ -16,11 +17,11 @@ class Stream(torch._C._CudaStreamBase): device, independent from other streams. See :ref:`cuda-semantics` for details. - Arguments: + Args: device(torch.device or int, optional): a device on which to allocate the stream. If :attr:`device` is ``None`` (default) or a negative integer, this will use the current device. - priority(int, optional): priority of the stream. Can be either + priority(int, optional): priority of the stream. Can be either -1 (high priority) or 0 (low priority). By default, streams have priority 0. @@ -35,7 +36,7 @@ def __new__(cls, device=None, priority=0, **kwargs): def wait_event(self, event): r"""Makes all future work submitted to the stream wait for an event. - Arguments: + Args: event (Event): an event to wait for. .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see @@ -55,7 +56,7 @@ def wait_stream(self, stream): All future work submitted to this stream will wait until all kernels submitted to a given stream at the time of call complete. - Arguments: + Args: stream (Stream): a stream to synchronize. .. note:: This function returns without waiting for currently enqueued @@ -66,7 +67,7 @@ def wait_stream(self, stream): def record_event(self, event=None): r"""Records an event. - Arguments: + Args: event (Event, optional): event to record. If not given, a new one will be allocated. @@ -122,7 +123,7 @@ class Event(torch._C._CudaEventBase): same device may record the event. However, streams on any device can wait on the event. - Arguments: + Args: enable_timing (bool, optional): indicates if the event should measure time (default: ``False``) blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``) @@ -201,3 +202,5 @@ def __repr__(self): return ''.format(self._as_parameter_.value) else: return '' + +_Graph = torch._C._CudaGraphBase diff --git a/torch/custom_class.h b/torch/custom_class.h index 3805cfafc91ab..080d9d9d3c95a 100644 --- a/torch/custom_class.h +++ b/torch/custom_class.h @@ -27,6 +27,21 @@ detail::types init() { return detail::types{}; } +template +struct InitLambda { + Func f; +}; + +template +decltype(auto) init(Func&& f) { + using InitTraits = + c10::guts::infer_function_traits_t>; + using ParameterTypeList = typename InitTraits::parameter_types; + + InitLambda init{std::forward(f)}; + return init; +} + /// Entry point for custom C++ class registration. To register a C++ class /// in PyTorch, instantiate `torch::class_` with the desired class as the /// template parameter. Typically, this instantiation should be done in @@ -58,14 +73,16 @@ class class_ { /// see this class exposed as in Python and TorchScript. For example, if /// you pass `foo` as the namespace name and `Bar` as the className, the /// class will appear as `torch.classes.foo.Bar` in Python and TorchScript - explicit class_(const std::string& namespaceName, const std::string& className) { + explicit class_(const std::string& namespaceName, const std::string& className, std::string doc_string = "") { detail::checkValidIdent(namespaceName, "Namespace name"); detail::checkValidIdent(className, "Class name"); qualClassName = std::string("__torch__.torch.classes.") + namespaceName + "." + className; classTypePtr = at::ClassType::create( c10::QualifiedName(qualClassName), - std::weak_ptr()); + std::weak_ptr(), + /*is_module=*/false, + std::move(doc_string)); classTypePtr->addAttribute("capsule", at::CapsuleType::get()); c10::getCustomClassTypeMap().insert( @@ -81,7 +98,7 @@ class class_ { /// `torch::init()` would register a two-argument constructor /// taking an `int` and a `std::string` as argument. template - class_& def(detail::types) { // Used in combination with + class_& def(detail::types, std::string doc_string = "") { // Used in combination with // torch::init<...>() auto func = [](c10::tagged_capsule self, Types... args) { auto classObj = c10::make_intrusive(args...); @@ -89,7 +106,25 @@ class class_ { object->setSlot(0, c10::IValue::make_capsule(std::move(classObj))); }; - defineMethod("__init__", std::move(func)); + defineMethod("__init__", std::move(func), std::move(doc_string)); + return *this; + } + + // Used in combination with torch::init([]lambda(){......}) + template + class_& def( + InitLambda> init, + std::string doc_string = "") { + auto init_lambda_wrapper = [func = std::move(init.f)]( + c10::tagged_capsule self, + ParameterTypes... arg) { + c10::intrusive_ptr classObj = + at::guts::invoke(func, std::forward(arg)...); + auto object = self.ivalue.toObject(); + object->setSlot(0, c10::IValue::make_capsule(classObj)); + }; + defineMethod("__init__", std::move(init_lambda_wrapper), std::move(doc_string)); + return *this; } @@ -112,18 +147,18 @@ class class_ { /// // do something /// }) template - class_& def(std::string name, Func f) { + class_& def(std::string name, Func f, std::string doc_string = "") { auto wrapped_f = detail::wrap_func(std::move(f)); - defineMethod(std::move(name), std::move(wrapped_f)); + defineMethod(std::move(name), std::move(wrapped_f), std::move(doc_string)); return *this; } /// This is an unsafe method registration API added for adding custom JIT backend support via custom /// C++ classes. It is not for general purpose use. - class_& _def_unboxed(std::string name, std::function func, c10::FunctionSchema schema) { + class_& _def_unboxed(std::string name, std::function func, c10::FunctionSchema schema, std::string doc_string = "") { auto qualMethodName = qualClassName + "." + name; auto method = std::make_unique( - qualMethodName, std::move(schema), std::move(func)); + qualMethodName, std::move(schema), std::move(func), std::move(doc_string)); classTypePtr->addMethod(method.get()); registerCustomClassMethod(std::move(method)); return *this; @@ -228,7 +263,7 @@ class class_ { private: template - void defineMethod(std::string name, Func func) { + void defineMethod(std::string name, Func func, std::string doc_string = "") { auto qualMethodName = qualClassName + "." + name; auto schema = c10::inferFunctionSchemaSingleReturn(std::move(name), ""); @@ -241,7 +276,7 @@ class class_ { detail::BoxedProxy()(stack, func); }; auto method = std::make_unique( - qualMethodName, std::move(schema), std::move(wrapped_func)); + qualMethodName, std::move(schema), std::move(wrapped_func), std::move(doc_string)); // Register the method here to keep the Method alive. // ClassTypes do not hold ownership of their methods (normally it diff --git a/torch/distributed/CONTRIBUTING.md b/torch/distributed/CONTRIBUTING.md index 0f180307c94c3..5f7d43afd8ec1 100644 --- a/torch/distributed/CONTRIBUTING.md +++ b/torch/distributed/CONTRIBUTING.md @@ -28,7 +28,7 @@ Processes discover each other through a rendezvous process on a common Store (Se ### Distributed Data Parallel -DDP is implemented as a module in [distributed.py](../nn/parallel/distributed.py) with some of the core functions implemented in [reducer.cpp](../csrc/distributed/c10d/reducer.cpp) and [comm.cpp](../csrc/distributed/c10d/reducer.cpp). Gradients synchronizations occur in backward pass, triggered as autograd hooks. +DDP is implemented as a module in [distributed.py](../nn/parallel/distributed.py) with some of the core functions implemented in [reducer.cpp](../lib/c10d/reducer.cpp) and [comm.cpp](../lib/c10d/reducer.cpp). Gradients synchronizations occur in backward pass, triggered as autograd hooks. ### Onboarding Tasks @@ -66,7 +66,7 @@ The distributed optimizer is completely written in Python and can be found at [o ### Onboarding Tasks -A list of onboarding tasks can be found [here](https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3A%22module%3A+rpc%22+label%3A%22topic%3A+bootcamp%22+) and [here](https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3A%22module%3A+rpc%22+label%3Apt_distributed_rampup) +A list of onboarding tasks can be found [here](https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3Apt_distributed_rampup+). ## Running unit tests @@ -76,6 +76,10 @@ All the unit tests can be found under the [test/distributed](../../test/distribu # Run the c10d unit test. python test/distributed/test_c10d.py +# Run distributed tests, including tests for Distributed Data Parallel +python test/run_test.py --verbose -i distributed/test_distributed_fork +python test/run_test.py --verbose -i distributed/test_distributed_spawn + # Run the RPC test suite for the TensorPipeAgent. python test/distributed/rpc/test_tensorpipe_agent.py diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index adfd2201a0469..1335fe9d1d6df 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -1,14 +1,15 @@ import torch +import sys def is_available(): """ Returns ``True`` if the distributed package is available. Otherwise, ``torch.distributed`` does not expose any other APIs. Currently, - ``torch.distributed`` is available on Linux and MacOS. Set + ``torch.distributed`` is available on Linux, MacOS and Windows. Set ``USE_DISTRIBUTED=1`` to enable it when building PyTorch from source. - Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and + Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows, ``USE_DISTRIBUTED=0`` for MacOS. """ return hasattr(torch._C, "_c10d_init") @@ -19,6 +20,27 @@ def is_available(): if is_available(): + from torch._C._distributed_c10d import ( + Store, + FileStore, + TCPStore, + ProcessGroup, + Reducer, + BuiltinCommHookType, + _DEFAULT_FIRST_BUCKET_BYTES, + _GradBucket, + _register_comm_hook, + _register_builtin_comm_hook, + _broadcast_coalesced, + _compute_bucket_assignment_by_size, + _test_python_store, + ) + if sys.platform != 'win32': + from torch._C._distributed_c10d import ( + HashStore, + _round_robin_process_groups, + ) + from .distributed_c10d import * # Variables prefixed with underscore are not auto imported # See the comment in `distributed_c10d.py` above `_backend` on why we expose diff --git a/torch/utils/_benchmark/utils/__init__.py b/torch/distributed/algorithms/__init__.py similarity index 100% rename from torch/utils/_benchmark/utils/__init__.py rename to torch/distributed/algorithms/__init__.py diff --git a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py index 6b07e23c94761..44c77bad426f1 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py @@ -1,13 +1,39 @@ from enum import Enum from functools import partial -import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default -import torch.distributed.algorithms.ddp_comm_hooks.quantization_hooks as quantization +import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel +from . import ( + default_hooks as default, + powerSGD_hook as powerSGD, + quantization_hooks as quantization, +) + def _ddp_comm_hook_wrapper(comm_hook, model, state): - model._register_comm_hook(state, comm_hook) + model.register_comm_hook(state, comm_hook) + + +def _powerSGD_comm_hook_wrapper( + comm_hook, + model, + state, + matrix_approximation_rank, + use_error_feedback=True, + random_seed=0, +): + """ + To be consistent with the wrappers of other DDP comm hooks, the input state only needs to be a process group, + which will be wrapped up with other state info. + """ + powerSGD_state = powerSGD.PowerSGDState( + process_group=state, + matrix_approximation_rank=matrix_approximation_rank, + use_error_feedback=use_error_feedback, + random_seed=random_seed, + ) + model.register_comm_hook(powerSGD_state, comm_hook) class DDPCommHookType(Enum): @@ -28,18 +54,42 @@ class DDPCommHookType(Enum): QUANTIZE_PER_CHANNEL = partial( _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_perchannel_hook ) + POWER_SGD = partial( + _powerSGD_comm_hook_wrapper, + comm_hook=powerSGD.powerSGD_hook, + matrix_approximation_rank=1, + ) + # Rank-2 PowerSGD can give a higher accuracy than the default rank-1 version, + # but it runs slower and consumes more memory. + POWER_SGD_RANK2 = partial( + _powerSGD_comm_hook_wrapper, + comm_hook=powerSGD.powerSGD_hook, + matrix_approximation_rank=2, + ) + # Batching can lead to a faster training at the cost of accuracy. + BATCHED_POWER_SGD = partial( + _powerSGD_comm_hook_wrapper, + comm_hook=powerSGD.batched_powerSGD_hook, + matrix_approximation_rank=1, + ) + BATCHED_POWER_SGD_RANK2 = partial( + _powerSGD_comm_hook_wrapper, + comm_hook=powerSGD.batched_powerSGD_hook, + matrix_approximation_rank=2, + ) def register_ddp_comm_hook( comm_hook_type: DDPCommHookType, model: DistributedDataParallel, state=None ): """ - Registers the hooks of ``torch.distributed.algorithms.ddp_comm_hooks`` - to the DDP model. User can specify the type of hook as an enum - ``DDPCommHookType`` type using ``comm_hook_type`` input. State input will - be passed to the model. + Registers the hooks of ``torch.distributed.algorithms.ddp_comm_hooks`` + to the DDP model. User can specify the type of hook as an enum + ``DDPCommHookType`` type using ``comm_hook_type`` input. State input will + be passed to the model. + Uses Python comm hook implementations. - Example:: - >>> register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, model, state) + Example:: + >>> register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, model, state) """ comm_hook_type.value(model=model, state=state) diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py index 16638a915f706..59491a868be44 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -3,24 +3,22 @@ def allreduce_hook( - process_group: object, bucket: dist._GradBucket + process_group: dist.ProcessGroup, bucket: dist._GradBucket ) -> torch.futures.Future: """ - This DDP communication hook just calls ``allreduce`` using ``GradBucket`` - tensors. Once gradient tensors are aggregated across all workers, its ``then`` - callback takes the mean and returns the result. If user registers this hook, - DDP results is expected to be same as the case where no hook was registered. - Hence, this won't change behavior of DDP and user can use this as a reference - or modify this hook to log useful information or any other purposes while - unaffecting DDP behavior. - - Example:: - >>> ddp_model._register_comm_hook(process_group, allreduce_hook) + This DDP communication hook just calls ``allreduce`` using ``GradBucket`` + tensors. Once gradient tensors are aggregated across all workers, its ``then`` + callback takes the mean and returns the result. If user registers this hook, + DDP results is expected to be same as the case where no hook was registered. + Hence, this won't change behavior of DDP and user can use this as a reference + or modify this hook to log useful information or any other purposes while + unaffecting DDP behavior. + + Example:: + >>> ddp_model.register_comm_hook(process_group, allreduce_hook) """ group_to_use = process_group if process_group is not None else dist.group.WORLD - world_size = ( - process_group.size() if process_group is not None else dist.get_world_size() - ) + world_size = group_to_use.size() tensor = bucket.get_tensors()[0] fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future() @@ -31,22 +29,22 @@ def then_callback(fut): return fut.then(then_callback) -def fp16_compress_hook(process_group: object, bucket: dist._GradBucket): +def fp16_compress_hook( + process_group: dist.ProcessGroup, bucket: dist._GradBucket +) -> torch.futures.Future: """ - This DDP communication hook implements a simple gradient compression - approach that converts ``GradBucket`` tensors whose type is assumed to be - ``torch.float32`` to half-precision floating point format (``torch.float16``). - It allreduces those ``float16`` gradient tensors. Once compressed gradient - tensors are allreduced, its then callback called ``decompress`` converts the - aggregated result back to ``float32`` and takes the mean. - - Example:: - >>> ddp_model._register_comm_hook(process_group, fp16_compress_hook) + This DDP communication hook implements a simple gradient compression + approach that converts ``GradBucket`` tensors whose type is assumed to be + ``torch.float32`` to half-precision floating point format (``torch.float16``). + It allreduces those ``float16`` gradient tensors. Once compressed gradient + tensors are allreduced, its then callback called ``decompress`` converts the + aggregated result back to ``float32`` and takes the mean. + + Example:: + >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook) """ group_to_use = process_group if process_group is not None else dist.group.WORLD - world_size = ( - process_group.size() if process_group is not None else dist.get_world_size() - ) + world_size = group_to_use.size() compressed_tensor = bucket.get_tensors()[0].to(torch.float16) @@ -55,7 +53,11 @@ def fp16_compress_hook(process_group: object, bucket: dist._GradBucket): ).get_future() def decompress(fut): - return [fut.value()[0].to(torch.float32).div_(world_size)] + decompressed_tensor = bucket.get_tensors()[0] + # Decompress in place to reduce the peak memory. + # See: https://github.com/pytorch/pytorch/issues/45968 + decompressed_tensor.copy_(fut.value()[0].div_(world_size)) + return [decompressed_tensor] return fut.then(decompress) @@ -73,30 +75,28 @@ def _get_allgather_out_list(all_gather_in_list, world_size): def _allgather_then_aggregate_hook( - process_group: object, bucket: dist._GradBucket + process_group: dist.ProcessGroup, bucket: dist._GradBucket ) -> torch.futures.Future: """ - Similar to ``allreduce_hook``, this hook first gathers ``GradBucket`` tensors - and its ``then`` callback aggregates the gathered gradient tensors and takes - mean. Instead of ``allreduce`` this hook uses ``allgather``. Note that with - W workers, both the computation and communication time scale as O(W) for - allgather compared to O(logW) for allreduce. Therefore, this hook is expected - to be much slower than ``allreduce_hook`` although both essentially do the - same thing with the gradients. - - .. warning :: - This is for test and experiments. User is suggested to use a faster - alternative called ``allreduce_hook`` that uses ``allreduce`` protocol - instead of ``allgather`` protocol. - - Example:: - >>> ddp_model._register_comm_hook(process_group, allreduce_hook) + Similar to ``allreduce_hook``, this hook first gathers ``GradBucket`` tensors + and its ``then`` callback aggregates the gathered gradient tensors and takes + mean. Instead of ``allreduce`` this hook uses ``allgather``. Note that with + W workers, both the computation and communication time scale as O(W) for + allgather compared to O(logW) for allreduce. Therefore, this hook is expected + to be much slower than ``allreduce_hook`` although both essentially do the + same thing with the gradients. + + .. warning :: + This is for test and experiments. User is suggested to use a faster + alternative called ``allreduce_hook`` that uses ``allreduce`` protocol + instead of ``allgather`` protocol. + + Example:: + >>> ddp_model.register_comm_hook(process_group, allreduce_hook) """ group_to_use = process_group if process_group is not None else dist.group.WORLD rank = process_group.rank() if process_group is not None else dist.get_rank() - world_size = ( - process_group.size() if process_group is not None else dist.get_world_size() - ) + world_size = group_to_use.size() tensor = bucket.get_tensors()[0] fut = dist.all_gather( diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py new file mode 100644 index 0000000000000..7183aa1a82a3b --- /dev/null +++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -0,0 +1,494 @@ +import logging +import math + +import numpy as np +import torch +import torch.distributed as dist + + +def _orthogonalize(matrix, epsilon=1e-8): + """ + Applies Gram-Schmidt procedure to orthogonalize a given 2D tensor. + If epsilon is 0, this is equivalent to `torch.qr(matrix, out=(matrix, _))`, + but `torch.qr` is very slow, probably because it is not optimized for a matrix that has a small number of columns. + """ + num_cols = matrix.shape[1] + for i in range(num_cols): + # Normalize the i'th column. + col = matrix[:, i : i + 1] + # If no epsilon is added here, division by zero may be caused by vanishing gradients. + # This epsilon is not needed if the input matrix covers the gradients of at least one entire layer in the neural network. + if epsilon == 0: + # Note that col ** 2 can underflow/overflow if we use FP16. + # May need to consder multiplying a scaling factor and divding it later, or using bfloat16 isntead. + col /= torch.sqrt(torch.sum(col ** 2)) + else: + col /= torch.sqrt(torch.sum(col ** 2)) + epsilon + # Project it on the rest and remove it. + if i + 1 < num_cols: + rest = matrix[:, i + 1 :] + rest -= torch.sum(col * rest, dim=0) * col + + +class PowerSGDState(object): + __slots__ = [ + "process_group", + "matrix_approximation_rank", + "use_error_feedback", + "warm_start", + "rng", + "error_dict", + "p_memory_dict", + "q_memory_dict", + ] + + def __init__( + self, + process_group, + matrix_approximation_rank=1, + use_error_feedback=True, + warm_start=True, + random_seed=0, + ): + self.process_group = process_group + # The low rank for matrix approximation. + # Typically only 1 or 2 is used. See https://arxiv.org/pdf/1905.13727.pdf. + self.matrix_approximation_rank = matrix_approximation_rank + # Error feedback is usually crucial for both for convergence and generalization, + # because PowerSGD is a biased compressor, + # i.e., compressing and decompressing a random gradient does not yield the original in expectation. + # This mechanism requires a temporary copy of the input gradients, + # so it increases the peak memory consumption by the size of gradient tensor. + # However, if the target matrices are known to be exactly low-ranked (instead of just low stable rank), + # sometimes it is possible to converge to the optima without error feedback. + # See: http://proceedings.mlr.press/v54/yurtsever17a/yurtsever17a.pdf + self.use_error_feedback = use_error_feedback + # Warm-start reuses P(s) and Q(s) from the previous iteration. + # This can improve the approximation quality and hence improve the accuracy. + # Additionally, by avoiding the initialization of these low-rank tensors at every step, + # this can also accelerate training. + # However, this is at the cost of extra memory. + self.warm_start = warm_start + # The purpose of this RNG is to generate different random seeds for initializing Q across iterations, + # but in the same order for all the DDP replicas. + # Different random seeds across iterations indicate different 'projections' of the gradients at different SGD steps. + # If the same random projection is used, + # there will be differences between the gradients that are never synchronized. + self.rng = np.random.RandomState(random_seed) + # Since there is only a single state instance for all the input buckets, + # need to maintain a dictionary that maps each bucket index to the local error. + self.error_dict = {} + self.p_memory_dict = {} + self.q_memory_dict = {} + + +def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: + """ + This DDP communication hook implements the original PowerSGD gradient compression + algorithm described in https://arxiv.org/abs/1905.13727. + Once gradient tensors are aggregated across all workers, this hook applies + compression as follows: + 1) Views the input flattened 1D gradient tensor as two groups of per-parameter tensors: + high-rank tensors and vector-like rank-1 tensors (for biases). + 2) Handles rank-1 tensors by allreducing them without compression: + 2.1) Allocate contiguous memory for those rank-1 tensors, + and allreduces all the rank-1 tensors as a batch, without compression; + 2.2) Copies the indvidual rank-1 tensors from the contiguous memory back to the input tensor. + 3) Handles high-rank tensors by PowerSGD compression: + 3.1) For each high-rank tensor M, creates two low-rank tensors P and Q for decomposing M, + such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized; + 3.2) Computes each P in Ps, which is equal to MQ; + 3.3) Allreduces Ps as a batch; + 3.4) Orthogonizes each P in Ps; + 3.5) Computes each Q in Qs, which is approximately equal to M^TP; + 3.6) Allreduces Qs as a batch; + 3.7) Computes each M among all the high-rank tensors, which is approximately equal to PQ^T. + + TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration -- + one left multiplication and one right multiplication. + For warm start, can take one such step at a time, and alternate between them. + + Args: + state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc. + bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. + Note that since DDP comm hook only supports single process single device mode at this time, + only exactly one tensor is stored in this bucket. + + Returns: + Future handler of the communication, which updates the gradients in place. + + Example:: + state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1) + >>> ddp_model.register_comm_hook(state, powerSGD_hook) + """ + process_group = state.process_group + group_to_use = process_group if process_group is not None else dist.group.WORLD + world_size = group_to_use.size() + + # The input tensor is a flattened 1D tensor. + input_tensor = bucket.get_tensors()[0] + device = input_tensor.device + dtype = input_tensor.dtype + + # Incorporate the error from the previous state into the gradients. + bucket_index = bucket.get_index() + input_tensor_cp = None + total_length = input_tensor.shape[0] + if state.use_error_feedback: + # The buckets can be rebuilt during training. + # In this case, the error tensor shape will not be aligned with the input tensor, + # and the error will be re-initialized as zeros. + if ( + bucket_index in state.error_dict + and state.error_dict[bucket_index].shape[0] == total_length + ): + input_tensor.add_(state.error_dict[bucket_index]) + else: + logging.info( + "A zero tensor of length {} that represents local error is created.".format( + total_length + ) + ) + state.error_dict[bucket_index] = torch.zeros(total_length, device=device) + + # Keep a copy of the input tensor, + # so that we can compute the local error caused by compression later, + # by comparing this copy and the input tensor updated after decompression. + input_tensor_cp = torch.clone(input_tensor).detach() + + # Unflatten the input tensor into per-parameter tensors, for layer-wise compression. + tensors = [ + input_tensor[offset : offset + length].view(sizes) + for offset, length, sizes in zip( + bucket.get_offsets(), bucket.get_lengths(), bucket.get_sizes_list() + ) + ] + + # Step I: Handle rank-1 tensors. + # Allocate contiguous memory for rank-1 tensors to allreduce them without compression efficiently. + rank1_tensors = [tensor for tensor in tensors if tensor.ndimension() <= 1] + rank1_tensors_memory = ( + torch.cat([tensor.view(-1) for tensor in rank1_tensors]) + if rank1_tensors + else torch.tensor([], device=device) + ) + + # Step II: Handle high-rank tensors. + # Allocate contiguous memory for Ps and Qs to allreduce compressed high-rank tensors efficiently. + high_rank_tensors = [ + tensor.view(tensor.shape[0], -1) + for tensor in tensors + if tensor.ndimension() > 1 + ] + total_Ps_size = 0 + total_Qs_size = 0 + for tensor in high_rank_tensors: + n, m = tensor.shape + matrix_approximation_rank = min(n, m, state.matrix_approximation_rank) + total_Ps_size += n * matrix_approximation_rank + total_Qs_size += m * matrix_approximation_rank + # Reuse Ps and Qs from the previous iteration if possible. + # The memory spaces of Ps and Qs need to be (re)allocated at the beginning, + # as well as later whenever the buckets are rebuilt during training. + if ( + not state.warm_start + or bucket_index not in state.p_memory_dict + or state.p_memory_dict[bucket_index].shape[0] != total_Ps_size + or state.q_memory_dict[bucket_index].shape[0] != total_Qs_size + ): + # If warm-start is disabled, low-rank tensors will be initialized at every step. + # Only log this if warm-start to avoid spamming. + if state.warm_start: + logging.info( + "Allocating contiguous memory of length {} for Ps, and of length {} for Qs, respectively.".format( + total_Ps_size, total_Qs_size + ) + ) + state.p_memory_dict[bucket_index] = torch.empty( + total_Ps_size, device=device, dtype=dtype + ) + state.q_memory_dict[bucket_index] = torch.empty( + total_Qs_size, device=device, dtype=dtype + ) + + # Create Ps and Qs that point to the allocated memory. + ps = [] + qs = [] + p_idx = 0 + q_idx = 0 + for tensor in high_rank_tensors: + n, m = tensor.shape + matrix_approximation_rank = min(n, m, state.matrix_approximation_rank) + ps.append( + state.p_memory_dict[bucket_index][ + p_idx : p_idx + n * matrix_approximation_rank + ].view(n, matrix_approximation_rank) + ) + qs.append( + state.q_memory_dict[bucket_index][ + q_idx : q_idx + m * matrix_approximation_rank + ].view(m, matrix_approximation_rank) + ) + p_idx += n * matrix_approximation_rank + q_idx += m * matrix_approximation_rank + + # Initialize and then orthogonalize Qs. + with torch.random.fork_rng(devices=[]): + # Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training. + # The seed makes sure that the initial random values are the same across all the DDP replicas. + # Such seed should differ at every step. + # Since it is very slow to fork RNG state across all the CUDA devices, + # only fork on CPU and then move the generated tensor to the CUDA device. + torch.manual_seed(state.rng.randint(1_000_000_000)) + for q in qs: + q.data = torch.randn( + *q.shape, + device="cpu", + dtype=dtype, + ).to(device) + _orthogonalize(q) + + # Compute Ps. + for tensor, q, p in zip(high_rank_tensors, qs, ps): + torch.matmul(tensor, q, out=p) + + # This allreduce is only applied to rank-1 tensors, + # so it should have been kicked off before the above computation on the high-rank tensors to hide more communication costs. + # However, this somehow requires a separate future chain at this time. + allreduce_contiguous_rank1_tensors_fut = dist.all_reduce( + rank1_tensors_memory, group=group_to_use, async_op=True + ).get_future() + + def unpack_rank1_tensors_and_allreduce_ps(fut): + rank1_tensors_memory = fut.value()[0].div_(world_size) + idx = 0 + for tensor in rank1_tensors: + tensor.copy_(rank1_tensors_memory[idx : idx + tensor.shape[0]]) + idx += tensor.shape[0] + + # Since these Ps will be orthogonized later, no need to divide them by world size. + return [ + dist.all_reduce( + state.p_memory_dict[bucket_index], group=group_to_use, async_op=True + ) + .get_future() + .wait()[0] + ] + + def compute_qs(fut): + state.p_memory_dict[bucket_index] = fut.value()[0] + for p in ps: + _orthogonalize(p) + + # Compute Qs. + for tensor, p, q in zip(high_rank_tensors, ps, qs): + torch.matmul(tensor.t(), p, out=q) + + # Allreduce Qs. + return [ + dist.all_reduce( + state.q_memory_dict[bucket_index], group=group_to_use, async_op=True + ) + .get_future() + .wait()[0] + ] + + def decompress(fut): + state.q_memory_dict[bucket_index] = fut.value()[0].div_(world_size) + + for p, q, tensor in zip(ps, qs, high_rank_tensors): + torch.matmul(p, q.t(), out=tensor) + if torch.cuda.is_available(): + torch.cuda.synchronize(device) + + if state.use_error_feedback: + # memoize the local errors. + state.error_dict[bucket_index] = input_tensor_cp - input_tensor + if not state.warm_start: + state.p_memory_dict.clear() + state.q_memory_dict.clear() + + return [input_tensor] + + return ( + allreduce_contiguous_rank1_tensors_fut.then( + unpack_rank1_tensors_and_allreduce_ps + ) + .then(compute_qs) + .then(decompress) + ) + + +def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: + """ + This DDP communication hook implements a simplified PowerSGD gradient compression + algorithm described in https://arxiv.org/abs/1905.13727. + Once gradient tensors are aggregated across all workers, this hook applies + compression to the flattened input tensor that batches per-parameter tensors as follows: + 1) Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings; + 2) Creates two low-rank tensors P and Q for decomposing M, + such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized; + 2) Computes P, which is equal to MQ; + 3) Allreduces P; + 4) Orthogonizes P; + 5) Computes Q, which is approximately equal to M^TP; + 6) Allreduces Q; + 7) Computes M, which is approximately equal to PQ^T. + 8) Truncates the input tensor to the original length. + + TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration -- + one left multiplication and one right multiplication. + For warm start, can take one such step at a time, and alternate between them. + + Args: + state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc. + bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. + Note that since DDP comm hook only supports single process single device mode at this time, + only exactly one tensor is stored in this bucket. + + Returns: + Future handler of the communication, which updates the gradients in place. + + Example:: + state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1) + >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook) + """ + process_group = state.process_group + group_to_use = process_group if process_group is not None else dist.group.WORLD + world_size = group_to_use.size() + + # The input tensor is a flattened 1D tensor. + input_tensor = bucket.get_tensors()[0] + device = input_tensor.device + total_length = input_tensor.shape[0] + + # View the input tensor as a 2D square-shape tensor, and pad 0s if necessary. + square_side_length = math.ceil(math.sqrt(total_length)) + padded_total_length = square_side_length ** 2 + input_tensor.resize_(padded_total_length) + input_tensor[total_length:padded_total_length].fill_(0) + + # Incorporate the error from the previous state into the gradients. + bucket_index = bucket.get_index() + input_tensor_cp = None + if state.use_error_feedback: + # The buckets can be rebuilt during training. + # In this case, the error tensor shape will not be aligned with the input tensor, + # and the error will be re-initialized as zeros. + if ( + bucket_index in state.error_dict + and state.error_dict[bucket_index].shape[0] == padded_total_length + ): + input_tensor.add_(state.error_dict[bucket_index]) + else: + logging.info( + "A zero tensor of length {} that represents local error is created.".format( + padded_total_length + ) + ) + state.error_dict[bucket_index] = torch.zeros( + padded_total_length, device=device + ) + + # Keep a copy of the input tensor, + # so that we can compute the local error caused by compression later, + # by comparing this copy and the input tensor updated after decompression. + input_tensor_cp = torch.clone(input_tensor).detach() + matrix = input_tensor.view(square_side_length, square_side_length) + + # Reuse P and Q from the previous iteration if possible. + # The memory spaces of P and Q need to be (re)allocated at the beginning, + # as well as later whenever the buckets are rebuilt during training. + if ( + not state.warm_start + or bucket_index not in state.p_memory_dict + or state.p_memory_dict[bucket_index].shape + != (square_side_length, state.matrix_approximation_rank) + ): + # If warm-start is disabled, low-rank tensors will be initialized at every step. + # Only log this if warm-start to avoid spamming. + if state.warm_start: + logging.info( + "Initializing low-rank tensors P and Q, each of which has a shape of {} x {}.".format( + square_side_length, state.matrix_approximation_rank + ) + ) + + def create_low_rank_tensor(fill_random_values, rng): + "Returns a low-rank 2D tensor of square_side_length * matrix_approximation_rank." + if fill_random_values: + with torch.random.fork_rng(devices=[]): + # Fork this RNG to avoid changing the seed globally and affecting the random sampling + # anywhere else in the training. + # The seed makes sure that the initial random values are the same across all the DDP replicas. + # Such seed should differ at every step. + # Since it is very slow to fork RNG state across all the CUDA devices, + # only fork on CPU and then move the generated tensor to the CUDA device. + torch.manual_seed(rng.randint(1_000_000_000)) + return torch.randn( + square_side_length, + state.matrix_approximation_rank, + device="cpu", + dtype=input_tensor.dtype, + ).to(device) + else: + return torch.empty( + square_side_length, + state.matrix_approximation_rank, + device=device, + dtype=input_tensor.dtype, + ) + + state.p_memory_dict[bucket_index] = create_low_rank_tensor( + fill_random_values=False, rng=state.rng + ) + state.q_memory_dict[bucket_index] = create_low_rank_tensor( + fill_random_values=True, rng=state.rng + ) + _orthogonalize(state.q_memory_dict[bucket_index], 0) + + torch.matmul( + matrix, state.q_memory_dict[bucket_index], out=state.p_memory_dict[bucket_index] + ) + allreduce_p_fut = dist.all_reduce( + state.p_memory_dict[bucket_index], group=group_to_use, async_op=True + ).get_future() + + def compute_q(fut): + state.p_memory_dict[bucket_index] = fut.value()[0] + _orthogonalize(state.p_memory_dict[bucket_index], 0) + + torch.matmul( + matrix.t(), + state.p_memory_dict[bucket_index], + out=state.q_memory_dict[bucket_index], + ) + + return [ + dist.all_reduce( + state.q_memory_dict[bucket_index], group=group_to_use, async_op=True + ) + .get_future() + .wait()[0] + ] + + def decompress(fut): + state.q_memory_dict[bucket_index] = fut.value()[0].div_(world_size) + torch.matmul( + state.p_memory_dict[bucket_index], + state.q_memory_dict[bucket_index].t(), + out=matrix, + ) + + if state.use_error_feedback: + # memoize the local errors. + state.error_dict[bucket_index] = input_tensor_cp - input_tensor + if torch.cuda.is_available(): + torch.cuda.synchronize(device) + if not state.warm_start: + state.p_memory_dict.clear() + state.q_memory_dict.clear() + ret = input_tensor.resize_(total_length) + return [ret] + + return allreduce_p_fut.then(compute_q).then(decompress) diff --git a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py index afac1ee66873b..87ee4145bdeef 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py @@ -43,7 +43,7 @@ def _get_allgather_out_list(all_gather_in_list, world_size): def quantization_pertensor_hook( - process_group: object, bucket: dist._GradBucket + process_group: dist.ProcessGroup, bucket: dist._GradBucket ) -> torch.futures.Future: """ Applies the ``torch.quantize_per_tensor`` logic to DDP using ``allgather`` @@ -59,13 +59,11 @@ def quantization_pertensor_hook( ``allreduce`` protocol. It works only with flattened grads. Example:: - >>> ddp_model._register_comm_hook(process_group, quantization_pertensor_hook) + >>> ddp_model.register_comm_hook(process_group, quantization_pertensor_hook) """ group_to_use = process_group if process_group is not None else dist.group.WORLD rank = process_group.rank() if process_group is not None else dist.get_rank() - world_size = ( - process_group.size() if process_group is not None else dist.get_world_size() - ) + world_size = group_to_use.size() tensor = bucket.get_tensors()[0] @@ -118,7 +116,7 @@ def dequantize_and_aggregate(fut): def quantization_perchannel_hook( - process_group: object, bucket: dist._GradBucket, bucket_size=512 + process_group: dist.ProcessGroup, bucket: dist._GradBucket, bucket_size=512 ) -> torch.futures.Future: """ Applies the ``torch.quantize_per_channel`` logic to DDP using ``allgather`` @@ -140,13 +138,11 @@ def quantization_perchannel_hook( ``allreduce`` protocol. It works only with flattened grads. Example:: - >>> ddp_model._register_comm_hook(process_group, quantization_perchannel_hook) + >>> ddp_model.register_comm_hook(process_group, quantization_perchannel_hook) """ group_to_use = process_group if process_group is not None else dist.group.WORLD rank = process_group.rank() if process_group is not None else dist.get_rank() - world_size = ( - process_group.size() if process_group is not None else dist.get_world_size() - ) + world_size = group_to_use.size() tensor = bucket.get_tensors()[0] diff --git a/torch/distributed/autograd/__init__.py b/torch/distributed/autograd/__init__.py index a56b41bce8c67..c8d4366e44294 100644 --- a/torch/distributed/autograd/__init__.py +++ b/torch/distributed/autograd/__init__.py @@ -10,6 +10,20 @@ def is_available(): if is_available() and not torch._C._dist_autograd_init(): raise RuntimeError("Failed to initialize torch.distributed.autograd") +if is_available(): + from torch._C._distributed_autograd import ( + get_gradients, + backward, + _init, + _new_context, + _release_context, + _get_max_id, + _is_valid_context, + _retrieve_context, + _current_context, + _get_debug_info, + DistAutogradContext, + ) class context(object): ''' diff --git a/torch/distributed/benchmarks/README.md b/torch/distributed/benchmarks/README.md new file mode 100644 index 0000000000000..082ab87af623c --- /dev/null +++ b/torch/distributed/benchmarks/README.md @@ -0,0 +1,68 @@ +# Benchmark combining Distributed Data Parallel and Distributed RPC + +This Benchmark is used to measure distributed training iteration time. It combines Distributed Data Parallelism with Distributed Model Parallelism leveraging PyTorch DDP and the Distributed RPC Framework. The number of trainer nodes and parameter servers are configurable. The default is 8 trainers, 1 master node and 8 parameter servers. + +## Background + +There are different training paradigms where combining these two techniques might be useful. For example: +1) If we have a model with a sparse part (large embedding table) and a dense + part (FC layers), we might want to set the embedding table on a parameter + server and replicate the FC layer across multiple trainers using [DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel). The [Distributed RPC framework](https://pytorch.org/docs/master/rpc.html) comes handy to perform embedding lookups on the parameter servers. +2) Enable hybrid parallelism as described in the [PipeDream](https://arxiv.org/abs/1806.03377) paper. We can use the [Distributed RPC framework](https://pytorch.org/docs/master/rpc.html) to pipeline stages of the model across multiple workers and replicate each stage (if needed) using [DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel). + +## Training Process +This benchmark focuses on the first paradime above. The training process is executed as follows: + +1) The master creates embedding tables on each of the 8 Parameter Servers and holds an [RRef](https://pytorch.org/docs/master/rpc.html#rref) to it. +2) The master, then kicks off the training loop on the 8 trainers and passes the embedding table RRef to the trainers. +3) The trainers create a `HybridModel` which performs embedding lookups in all 8 Parameter Servers using the embedding table RRef provided by the master and then executes the FC layer which is wrapped and replicated via DDP (DistributedDataParallel). +4) The trainer executes the forward pass of the model and uses the loss to + execute the backward pass using [Distributed Autograd](https://pytorch.org/docs/master/rpc.html#distributed-autograd-framework). +5) As part of the backward pass, the gradients for the FC layer are computed + first and synced to all trainers via allreduce in DDP. +6) Next, Distributed Autograd propagates the gradients to the parameter servers, + where the gradients for the embedding table are updated. +7) Finally, the [Distributed Optimizer](https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim) is used to update all parameters. + + +## Example Benchmark output: + +---------- Info --------- + +* PyTorch version: 1.7.0 +* CUDA version: 9.2.0 + +---------- nvidia-smi topo -m --------- + + GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 CPU Affinity + GPU0 X NV2 NV1 NV2 NV1 NODE NODE NODE 0-19,40-59 + GPU1 NV2 X NV2 NV1 NODE NV1 NODE NODE 0-19,40-59 + GPU2 NV1 NV2 X NV1 NODE NODE NV2 NODE 0-19,40-59 + GPU3 NV2 NV1 NV1 X NODE NODE NODE NV2 0-19,40-59 + GPU4 NV1 NODE NODE NODE X NV2 NV1 NV2 0-19,40-59 + GPU5 NODE NV1 NODE NODE NV2 X NV2 NV1 0-19,40-59 + GPU6 NODE NODE NV2 NODE NV1 NV2 X NV1 0-19,40-59 + GPU7 NODE NODE NODE NV2 NV2 NV1 NV1 X 0-19,40-59 + +Legend: + + X = Self + SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI) + NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node + PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU) + PXB = Connection traversing multiple PCIe switches (without traversing the PCIe Host Bridge) + PIX = Connection traversing a single PCIe switch + NV# = Connection traversing a bonded set of # NVLinks + +------------------ PyTorch Distributed Benchmark (DDP and RPC) --------------------- + + sec/epoch epoch/sec sec/epoch epoch/sec sec/epoch epoch/sec sec/epoch epoch/sec + Trainer0: p50: 0.376s 185/s p75: 0.384s 182/s p90: 0.390s 179/s p95: 0.396s 176/s + Trainer1: p50: 0.377s 204/s p75: 0.384s 200/s p90: 0.389s 197/s p95: 0.393s 195/s + Trainer2: p50: 0.377s 175/s p75: 0.384s 172/s p90: 0.390s 169/s p95: 0.395s 166/s + Trainer3: p50: 0.377s 161/s p75: 0.384s 158/s p90: 0.390s 156/s p95: 0.393s 155/s + Trainer4: p50: 0.377s 172/s p75: 0.383s 169/s p90: 0.389s 166/s p95: 0.395s 164/s + Trainer5: p50: 0.377s 180/s p75: 0.383s 177/s p90: 0.389s 174/s p95: 0.395s 172/s + Trainer6: p50: 0.377s 204/s p75: 0.384s 200/s p90: 0.390s 197/s p95: 0.394s 195/s + Trainer7: p50: 0.377s 185/s p75: 0.384s 182/s p90: 0.389s 179/s p95: 0.394s 177/s + All: p50: 0.377s 1470/s p75: 0.384s 1443/s p90: 0.390s 1421/s p95: 0.396s 1398/s diff --git a/torch/distributed/benchmarks/benchmark_ddp_rpc.py b/torch/distributed/benchmarks/benchmark_ddp_rpc.py new file mode 100644 index 0000000000000..a137832cdbdeb --- /dev/null +++ b/torch/distributed/benchmarks/benchmark_ddp_rpc.py @@ -0,0 +1,362 @@ +import argparse +import io +import os +import random +import shlex +import subprocess +import time + +import numpy as np +import torch +import torch.nn as nn +import torch.distributed as dist +import torch.distributed.autograd as dist_autograd +import torch.distributed.rpc as rpc +import torch.multiprocessing as mp +import torch.optim as optim +from torch.distributed.optim import DistributedOptimizer +from torch.distributed.rpc import RRef, TensorPipeRpcBackendOptions +from torch.distributed.rpc.backend_registry import BackendType +from torch.nn.parallel import DistributedDataParallel as DDP + + +# Config +NUM_TRAINERS = 8 +NUM_PS = 8 + +NUM_EMBEDDINGS = 300 +EMBEDDING_DIM = 64 + +WARMUP_CYCLES = 5 + + +class HybridModel(torch.nn.Module): + r""" + The model consists of a sparse part and a dense part. The dense part is an + nn.Linear module that is replicated across all trainers using + DistributedDataParallel. The sparse part has nn.EmbeddingBags stored on multiple + parameter servers. + + The model holds a Remote Reference to the embedding tables on the parameter + servers. + """ + + def __init__(self, emb_rref_list, device): + super(HybridModel, self).__init__() + self.emb_rref_list = emb_rref_list + fc1 = torch.nn.Linear(512, 256) + fc2 = torch.nn.Linear(256, 128) + relu = torch.nn.ReLU() + fc3 = torch.nn.Linear(128, 64) + fc4 = torch.nn.Linear(64, 32) + fc5 = torch.nn.Linear(32, 8) + sec = nn.Sequential(fc1, fc2, relu, fc3, fc4, fc5) + self.ddp = DDP(sec.to(device), device_ids=[device]) + self.device = device + + def forward(self, indices, offsets): + emb_lookups = [] + + for emb_rref in self.emb_rref_list: + emb_lookups.append( + emb_rref.rpc_sync().forward( + indices, offsets + ) # embedding_sum(input, offsets) + ) + emb_lookups_cat = torch.cat(emb_lookups, dim=1) + + # Make sure combined PS dimension is always bigger or equal than the FC input + assert NUM_PS * EMBEDDING_DIM >= 512 + dim_normalizer = int(NUM_PS * EMBEDDING_DIM / 512) + emb_lookups_reshaped = emb_lookups_cat.reshape( + [emb_lookups_cat.shape[0] * dim_normalizer, 512] + ) + + return self.ddp(emb_lookups_reshaped) + + +def _retrieve_embedding_parameters(emb_rref): + return [RRef(p) for p in emb_rref.local_value().parameters()] + + +def _print_header(): + _print_cont("\n") + _print_cont("%10s" % "") + for p in [50, 75, 90, 95]: + _print_cont("%14s%10s" % ("sec/epoch", "epoch/sec")) + _print_cont("\n") + + +def _print_benchmark(prefix, nelem, measurements): + measurements = sorted(measurements) + _print_cont("%8s:" % prefix) + for p in [50, 75, 90, 95]: + v = np.percentile(measurements, p) + _print_cont(" p%02d: %1.3fs %6d/s" % (p, v, nelem / v)) + _print_cont("\n") + + +def _print_cont(msg): + print(msg, end="", flush=True) + + +def _run_printable(cmd): + proc = subprocess.run(shlex.split(cmd), capture_output=True) + assert proc.returncode == 0 + + buffer = io.BytesIO() + torch.save(proc.stdout.decode("utf-8"), buffer) + input_tensor = torch.ByteTensor(list(buffer.getvalue())) + input_length = torch.IntTensor([input_tensor.size(0)]) + + output = [] + buffer = io.BytesIO(np.asarray(input_tensor).tobytes()) + output.append(torch.load(buffer)) + return output + + +def _run_trainer(emb_rref_list, rank): + r""" + Each trainer runs a forward pass which involves an embedding lookup on the + 8 parameter servers and running nn.Linear locally. During the backward pass, + DDP is responsible for aggregating the gradients for the dense part + (nn.Linear) and distributed autograd ensures gradients updates are + propagated to the parameter servers. + """ + + # Setup the model. + model = HybridModel(emb_rref_list, rank) + + # Retrieve all model parameters as rrefs for DistributedOptimizer. + + # Retrieve parameters from all embedding tables for the current trainer. + model_parameter_rrefs = [] + for ind, emb_rref in enumerate(emb_rref_list): + ps_name = "ps{}".format(ind) + model_parameter_rrefs.extend( + rpc.rpc_sync(ps_name, _retrieve_embedding_parameters, args=(emb_rref,)) + ) + + # model.parameters() only includes local parameters. + for param in model.parameters(): + model_parameter_rrefs.append(RRef(param)) + + # Setup distributed optimizer + opt = DistributedOptimizer(optim.SGD, model_parameter_rrefs, lr=0.05) + + criterion = torch.nn.CrossEntropyLoss() + + def get_next_batch(rank): + for _ in range(10): + num_indices = random.randint(20, 50) + indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS) + + # Generate offsets. + offsets = [] + start = 0 + batch_size = 0 + + while start < num_indices: + offsets.append(start) + start += random.randint(1, 10) + batch_size += 1 + + offsets_tensor = torch.LongTensor(offsets) + target = torch.LongTensor(batch_size).random_(8).cuda(rank) + + yield indices, offsets_tensor, target + + measurements = [] + # Include warm-up cycles during training + for epoch in range(100 + WARMUP_CYCLES): + start = time.time() + batch_size = 0 + + # create distributed autograd context + for indices, offsets, target in get_next_batch(rank): + batch_size += len(target) + + with dist_autograd.context() as context_id: + output = model(indices, offsets) + loss = criterion(output, target) + + # Run distributed backward pass + dist_autograd.backward(context_id, [loss]) + + # Run distributed optimizer. Gradients propagated all the way to the parameter servers + opt.step(context_id) + + # Not necessary to zero grads as each iteration creates a different + # distributed autograd context which hosts different grads + + measurements.append(time.time() - start) + # print("Training done for epoch {}".format(epoch)) + + # Throw away warm-up measurements + measurements = measurements[WARMUP_CYCLES:] + return rank, measurements, batch_size + + +def run_worker(rank, world_size): + r""" + A wrapper function that initializes RPC, calls the function, and shuts down + RPC. + """ + + # Using different port numbers in TCP init_method for init_rpc and + # init_process_group to avoid port conflicts. + rpc_backend_options = TensorPipeRpcBackendOptions() + rpc_backend_options.init_method = "tcp://localhost:29500" + + # Rank 16. Master + if rank == (NUM_TRAINERS + NUM_PS): + + rpc.init_rpc( + "master", rank=rank, backend=BackendType.TENSORPIPE, world_size=world_size + ) + + # Build the Embedding tables on the Parameter Servers. + emb_rref_list = [] + index = 0 + while index < NUM_PS: + ps_name = "ps{}".format(index) + emb_rref = rpc.remote( + ps_name, + torch.nn.EmbeddingBag, + args=(NUM_EMBEDDINGS, EMBEDDING_DIM), + kwargs={"mode": "sum"}, + ) + emb_rref_list.append(emb_rref) + index += 1 + + # Run training loop on the trainers. + futs = [] + for trainer_rank in range(NUM_TRAINERS): + trainer_name = "trainer{}".format(trainer_rank) + fut = rpc.rpc_async( + trainer_name, _run_trainer, args=(emb_rref_list, trainer_rank) + ) + futs.append(fut) + + _print_header() + + measurements_all_trainers = [] + batch_size_all_trainers = 0 + # Wait for all training to finish. + for fut in futs: + rank, measurements, batch_size = fut.wait() + _print_benchmark("Trainer{}".format(rank), batch_size, measurements) + batch_size_all_trainers += batch_size + measurements_all_trainers.append(measurements) + + _print_benchmark("All", batch_size_all_trainers, measurements_all_trainers) + + # Rank 0-7. Trainers + elif rank >= 0 and rank < NUM_PS: + + # Initialize process group for Distributed DataParallel on trainers. + dist.init_process_group( + backend=dist.Backend.GLOO, + rank=rank, + world_size=NUM_TRAINERS, + init_method="tcp://localhost:29501", + ) + + # Initialize RPC. Trainer just waits for RPCs from master. + trainer_name = "trainer{}".format(rank) + rpc.init_rpc( + trainer_name, + rank=rank, + world_size=world_size, + rpc_backend_options=rpc_backend_options, + ) + + # Rank 8-15. Parameter Servers + elif rank >= NUM_TRAINERS and rank < NUM_TRAINERS + NUM_PS: + ps_name = "ps{}".format(rank - NUM_TRAINERS) + rpc.init_rpc( + ps_name, + rank=rank, + world_size=world_size, + backend=BackendType.TENSORPIPE, + rpc_backend_options=rpc_backend_options, + ) + # parameter server do nothing + pass + + # block until all rpcs finish + rpc.shutdown() + + +if __name__ == "__main__": + """ Initializing the distributed environment. """ + + output = _run_printable("nvidia-smi topo -m") + print("-------------------------------------------") + print(" Info ") + print("-------------------------------------------") + print("") + print("* PyTorch version: {}".format(torch.__version__)) + print("* CUDA version: {}".format(torch.version.cuda)) + print("") + print("------------ nvidia-smi topo -m -----------") + print("") + print(output[0]) + print("-------------------------------------------") + print("PyTorch Distributed Benchmark (DDP and RPC)") + print("-------------------------------------------") + + # Cmd arguments to enable automated runs (e.g. Chronos, SSH, etc). + parser = argparse.ArgumentParser(description="PyTorch DDP and RPC Benchmark") + parser.add_argument( + "--master-addr", type=str, default="localhost", help="Address of master node." + ) + parser.add_argument("--master-port", type=str, default="29500", help="Master port.") + + parser.add_argument( + "--number-trainers", + type=int, + default=NUM_TRAINERS, + help="Number of Trainer Nodes.", + ) + parser.add_argument( + "--number-ps", type=int, default=NUM_PS, help="Number of Parameter Servers." + ) + parser.add_argument( + "--number-embeddings", + type=int, + default=NUM_EMBEDDINGS, + help="Number of test embeddings to be generated.", + ) + parser.add_argument( + "--embedding-dim", + type=int, + default=EMBEDDING_DIM, + help="Number of embedding dimentions.", + ) + parser.add_argument( + "--warmup-cycles", + type=int, + default=WARMUP_CYCLES, + help="Number of cycles to warm-up each process before running the benchmark.", + ) + + args = parser.parse_args() + + os.environ["MASTER_ADDR"] = args.master_addr + os.environ["MASTER_PORT"] = args.master_port + + NUM_TRAINERS = args.number_trainers + NUM_PS = args.number_ps + + NUM_EMBEDDINGS = args.number_embeddings + EMBEDDING_DIM = args.embedding_dim + + WARMUP_CYCLES = args.warmup_cycles + + # Defaults: + # 8 trainers (rank 0-7), + # 8 parameter servers (rank 8-15), + # 1 master (rank 16). + world_size = NUM_TRAINERS + NUM_PS + 1 # Trainers + PS + Master + mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True) diff --git a/torch/distributed/constants.py b/torch/distributed/constants.py index f7718c4a20e07..dc541c932d114 100644 --- a/torch/distributed/constants.py +++ b/torch/distributed/constants.py @@ -2,7 +2,7 @@ # Default process group wide timeout, if applicable. # This only applies to the gloo and nccl backends -# (only if NCCL_BLOCKING_WAIT is set to 1). To make an attempt at -# backwards compatibility with THD, we use an extraordinarily high default -# timeout, given that THD did not have timeouts. +# (only if NCCL_BLOCKING_WAIT or NCCL_ASYNC_ERROR_HANDLING is set to 1). +# To make an attempt at backwards compatibility with THD, we use an +# extraordinarily high default timeout, given that THD did not have timeouts. default_pg_timeout = timedelta(minutes=30) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index c7d66f322bb19..5b300452f6d30 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1,27 +1,33 @@ +import contextlib +import logging import pickle import torch import warnings +import time from torch._six import string_classes from datetime import timedelta +from typing import Dict, Optional, Tuple, Union # This module is wildcard imported from torch.distributed. # TODO: specify __all__ from .constants import default_pg_timeout from .rendezvous import rendezvous, register_rendezvous_handler # noqa: F401 -from . import ( +from torch._C._distributed_c10d import ( AllreduceOptions, AllreduceCoalescedOptions, AllToAllOptions, + BarrierOptions, BroadcastOptions, GatherOptions, + PrefixStore, + ProcessGroup, ReduceOptions, + ReduceOp, ReduceScatterOptions, ScatterOptions, + Store, ) -from . import ReduceOp -from . import PrefixStore - _MPI_AVAILABLE = True _NCCL_AVAILABLE = True @@ -29,20 +35,32 @@ try: - from. import ProcessGroupMPI + from torch._C._distributed_c10d import ProcessGroupMPI except ImportError: _MPI_AVAILABLE = False try: - from. import ProcessGroupNCCL + from torch._C._distributed_c10d import ProcessGroupNCCL except ImportError: _NCCL_AVAILABLE = False try: - from. import ProcessGroupGloo + from torch._C._distributed_c10d import ProcessGroupGloo except ImportError: _GLOO_AVAILABLE = False +# Some reduce ops are not supported by complex numbers and will result in an error. +# We currently provide complex support to the distributed API by viewing +# complex tensors as real (torch.view_as_real), meaning that calling +# these unsupported ops will return garbage values rather than error out. +# (e.g. max(2+3i, 3+2i) = 3+3i) +# We'd like calls to unsupported ops to error out accordingly, +# rather than returning garbage values. +def supports_complex(reduceOp: ReduceOp) -> bool: + denyList = [ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PRODUCT, + ReduceOp.BAND, ReduceOp.BOR, ReduceOp.BXOR] + return reduceOp not in denyList + class Backend(object): """ @@ -67,7 +85,7 @@ class Backend(object): MPI = "mpi" TCP = "tcp" - def __new__(cls, name): + def __new__(cls, name: str): if not isinstance(name, string_classes): raise ValueError("Backend name must be a string, but got: {}".format(name)) value = getattr(Backend, name.upper(), Backend.UNDEFINED) @@ -89,7 +107,7 @@ def register_backend(cls, name, func): This class method is used by 3rd party cpp extension to register new backend. - Arguments: + Args: name (str): Backend name matching with the one in `init_process_group()`. func (function): Function handler that instantiates the backend. The function should be implemented in the backend cpp extension @@ -104,11 +122,11 @@ def register_backend(cls, name, func): # `_backend`, `dist_backend`, and `reduce_op` are here to maintain backward # compatibility with pre-c10d distributed package. # TODO: remove them when users are ready to take a hard dependency on PyTorch 1. -_backend = Backend.UNDEFINED +_backend: str = Backend.UNDEFINED dist_backend = Backend -class reduce_op(object): +class _reduce_op(object): r""" Deprecated enum-like class for reduction operations: ``SUM``, ``PRODUCT``, ``MIN``, and ``MAX``. @@ -127,11 +145,12 @@ def __getattribute__(self, key): "torch.distributed.ReduceOp instead") return object.__getattribute__(self, key) -reduce_op = reduce_op() +reduce_op = _reduce_op() class group(object): - WORLD = object() + # Points to the default PG once initialized. + WORLD: Optional[ProcessGroup] = None class GroupMember(object): @@ -143,35 +162,71 @@ class GroupMember(object): # Cached process groups # For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store) # For MPI pg, it is a map from ProcessGroup to (Backend, None) -_pg_map = {} +_pg_map: Dict[ProcessGroup, Tuple[str, Optional[Store]]] = {} # Process group's names, map from ProcessGroup to str -_pg_names = {} +_pg_names: Dict[ProcessGroup, str] = {} # Process group's global rank to local rank mapping -_pg_group_ranks = {} +_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {} # Default process group state -_default_pg = None _default_pg_init_method = None # Process group count for default naming _group_count = 0 +STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key" + +def _store_based_barrier(rank, store, timeout): + """ + Barrier based on store which is used for synchronizing processes after + ``init_process_group`` or ``new_group``. Intended to be used only with + those two methods and is not a generic alternative to ``barrier()``. + """ + store_key = "{}:{}".format(STORE_BASED_BARRIER_PREFIX, _group_count) + store.add(store_key, 1) + logging.info('Added key: {} to store for rank: {}'.format(store_key, rank)) + + # Now wait for all workers to check in with the store. + world_size = get_world_size() + # Use 'add' instead of 'get' since for some store implementations 'add' + # doesn't work well with 'get'. Ideally the store implementations should + # be fixed, but for backward compatiblity reasons it is risky to change + # the store implementations. Once, we completely migrate away from these + # legacy stores, we can use 'get' here instead. + worker_count = store.add(store_key, 0) + start = time.time() + log_time = time.time() + while worker_count != world_size: + time.sleep(0.01) + worker_count = store.add(store_key, 0) + + # Print status periodically to keep track. + if timedelta(seconds=(time.time() - log_time)) > timedelta(seconds=10): + logging.info( + "Waiting in store based barrier to initialize process group for " + "rank: {}, key: {} (world_size={}, worker_count={}, timeout={})".format( + rank, store_key, world_size, worker_count, timeout)) + log_time = time.time() + + if timedelta(seconds=(time.time() - start)) > timeout: + raise RuntimeError( + "Timed out initializing process group in store based barrier on " + "rank: {}, for key: {} (world_size={}, worker_count={}, timeout={})".format( + rank, store_key, world_size, worker_count, timeout)) -def _rank_not_in_group(group): +def _rank_not_in_group(group: ProcessGroup): """ - Helper that checks if the current process's rank is not in a given group - + Helper that checks if the current process's rank is not in a given group. """ - if group == GroupMember.WORLD: + if group is None: return False return group == GroupMember.NON_GROUP_MEMBER -def _get_group_rank(group, rank): +def _get_group_rank(group: ProcessGroup, rank): """ Helper that gets a given group's local rank in the group from a given global - rank - + rank. """ if group is GroupMember.WORLD: raise RuntimeError("group.WORLD does not have local rank to global " @@ -188,8 +243,7 @@ def _get_group_rank(group, rank): def _get_global_rank(group, group_rank): """ Helper that gets a given group's global rank from a given local rank in the - group - + group. """ if group is GroupMember.WORLD: raise RuntimeError("group.WORLD does not have local rank to global " @@ -201,24 +255,13 @@ def _get_global_rank(group, group_rank): raise RuntimeError("The group rank is not part of the group") -def _check_default_pg(): - """ - Helper that checks if the default ProcessGroup has been initialized, with - assertion - - """ - assert _default_pg is not None, \ - "Default process group is not initialized" - - def _get_group_size(group): """ - Helper that gets a given group's world size - + Helper that gets a given group's world size. """ - if group is GroupMember.WORLD: - _check_default_pg() - return _default_pg.size() + if group is GroupMember.WORLD or group is None: + default_pg = _get_default_group() + return default_pg.size() if group not in _pg_group_ranks: raise RuntimeError("The given group does not exist") return len(_pg_group_ranks[group]) @@ -227,7 +270,6 @@ def _get_group_size(group): def _check_single_tensor(param, param_name): """ Helper to check that the parameter ``param_name`` is a single tensor. - """ if not isinstance(param, torch.Tensor): raise RuntimeError("Invalid function argument. Expected parameter `{}` " @@ -237,7 +279,6 @@ def _check_single_tensor(param, param_name): def _check_tensor_list(param, param_name): """ Helper to check that the parameter ``param_name`` is a list of tensors. - """ if not isinstance(param, list) or \ not all(isinstance(p, torch.Tensor) for p in param): @@ -245,10 +286,34 @@ def _check_tensor_list(param, param_name): "to be of type List[torch.Tensor].".format(param_name)) +def _check_op(op): + """ + Helper to check that the ``op`` is either isend or irecv. + """ + if op not in [isend, irecv]: + raise RuntimeError("Invalid ``op``. Expected ``op`` " + "to be of type ``torch.distributed.isend`` or " + "``torch.distributed.irecv``.") + +def _check_p2p_op_list(p2p_op_list): + """ + Helper to check that the ``p2p_op_list`` is a list of P2POp instances and + all ops use the same backend. + """ + if not isinstance(p2p_op_list, list) or \ + not all(isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list): + raise RuntimeError("Invalid ``p2p_op_list``. Each op is expected to " + "to be of type ``torch.distributed.P2POp``.") + + + backend = get_backend(p2p_op_list[0].group) + if not all(backend == get_backend(p2p_op.group) for p2p_op in p2p_op_list): + raise RuntimeError("All groups need to use the same backend.") + + def is_mpi_available(): """ Checks if the MPI backend is available. - """ return _MPI_AVAILABLE @@ -256,7 +321,6 @@ def is_mpi_available(): def is_nccl_available(): """ Checks if the NCCL backend is available. - """ return _NCCL_AVAILABLE @@ -264,7 +328,6 @@ def is_nccl_available(): def is_gloo_available(): """ Checks if the Gloo backend is available. - """ return _GLOO_AVAILABLE @@ -272,39 +335,40 @@ def is_gloo_available(): def is_initialized(): """ Checking if the default process group has been initialized - """ - return _default_pg is not None + return GroupMember.WORLD is not None def _get_default_group(): """ Getting the default process group created by init_process_group - """ if not is_initialized(): raise RuntimeError("Default process group has not been initialized, " "please make sure to call init_process_group.") - return _default_pg + return GroupMember.WORLD def _get_default_store(): """ Getting the default store created by init_process_group - """ if not is_initialized(): raise RuntimeError("Default process group has not been initialized, " "please make sure to call init_process_group.") - _, default_store = _pg_map[_default_pg] + default_pg = _get_default_group() + _, default_store = _pg_map[default_pg] return default_store +def _update_default_pg(pg): + GroupMember.WORLD = group.WORLD = pg -def get_backend(group=group.WORLD): + +def get_backend(group=None): """ Returns the backend of the given process group. - Arguments: + Args: group (ProcessGroup, optional): The process group to work on. The default is the general main process group. If another specific group is specified, the calling process must be part of :attr:`group`. @@ -313,15 +377,15 @@ def get_backend(group=group.WORLD): The backend of the given process group as a lower case string. """ - _check_default_pg() - - if group == GroupMember.WORLD: - pg = _default_pg + if group is None: + pg = _get_default_group() else: pg = group if _rank_not_in_group(pg): raise RuntimeError("Invalid process group specified") - return _pg_map.get(pg, None)[0] + pg_store = _pg_map.get(pg, None) + assert pg_store is not None + return pg_store[0] def init_process_group(backend, @@ -344,7 +408,7 @@ def init_process_group(backend, If neither is specified, ``init_method`` is assumed to be "env://". - Arguments: + Args: backend (str or Backend): The backend to use. Depending on build-time configurations, valid values include ``mpi``, ``gloo``, and ``nccl``. This field should be given as a lowercase string @@ -359,7 +423,8 @@ def init_process_group(backend, Mutually exclusive with ``store``. world_size (int, optional): Number of processes participating in the job. Required if ``store`` is specified. - rank (int, optional): Rank of the current process. + rank (int, optional): Rank of the current process (it should be a + number between 0 and ``world_size``-1). Required if ``store`` is specified. store(Store, optional): Key/value store accessible to all workers, used to exchange connection/address information. @@ -368,7 +433,20 @@ def init_process_group(backend, the process group. Default value equals 30 minutes. This is applicable for the ``gloo`` backend. For ``nccl``, this is applicable only if the environment variable ``NCCL_BLOCKING_WAIT`` - is set to 1. + or ``NCCL_ASYNC_ERROR_HANDLING`` is set to 1. When + ``NCCL_BLOCKING_WAIT`` is set, this is the duration for which the + process will block and wait for collectives to complete before + throwing an exception. When ``NCCL_ASYNC_ERROR_HANDLING`` is set, + this is the duration after which collectives will be aborted + asynchronously and the process will crash. ``NCCL_BLOCKING_WAIT`` + will provide errors to the user which can be caught and handled, + but due to its blocking nature, it has a performance overhead. On + the other hand, ``NCCL_ASYNC_ERROR_HANDLING`` has very little + performance overhead, but crashes the process on errors. This is + done since CUDA execution is async and it is no longer safe to + continue executing user code since failed async NCCL operations + might result in subsequent CUDA operations running on corrupted + data. Only one of these two environment variables should be set. group_name (str, optional, deprecated): Group name. To enable ``backend == Backend.MPI``, PyTorch needs to be built from source @@ -377,14 +455,13 @@ def init_process_group(backend, """ global _pg_group_ranks global _backend - global _default_pg global _default_pg_init_method if not isinstance(timeout, timedelta): raise RuntimeError("Expected timeout argument to be of type" "datetime.timedelta") - if _default_pg is not None: + if GroupMember.WORLD is not None: raise RuntimeError("trying to initialize the default process group " "twice!") @@ -406,14 +483,14 @@ def init_process_group(backend, "are ignored since they are assigned by the " "MPI runtime.".format(world_size, rank)) - _default_pg = _new_process_group_helper( + _update_default_pg(_new_process_group_helper( -1, -1, [], Backend.MPI, None, group_name=group_name, - timeout=timeout) + timeout=timeout)) else: # backward compatible API if store is None: @@ -423,19 +500,29 @@ def init_process_group(backend, store, rank, world_size = next(rendezvous_iterator) store.set_timeout(timeout) - _default_pg = _new_process_group_helper( + _update_default_pg(_new_process_group_helper( world_size, rank, [], backend, store, group_name=group_name, - timeout=timeout) + timeout=timeout)) - _pg_group_ranks[_default_pg] = {i: i for i in range(_default_pg.size())} - _backend = _pg_map[_default_pg][0] + _pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())} # type: ignore + _backend = _pg_map[GroupMember.WORLD][0] # type: ignore _default_pg_init_method = init_method + # barrier at the end to ensure that once we return from this method, all + # process groups including global variables are updated correctly on all + # ranks. + if backend == Backend.MPI: + # MPI backend doesn't use store. + barrier() + else: + # Use store based barrier here since barrier() used a bunch of + # default devices and messes up NCCL internal state. + _store_based_barrier(rank, store, timeout) def _new_process_group_helper(world_size, rank, @@ -473,6 +560,7 @@ def _new_process_group_helper(world_size, is_default_group = (len(group_ranks) == 0) backend = Backend(backend) + pg: Union[ProcessGroupGloo, ProcessGroupMPI, ProcessGroupNCCL] if backend == Backend.MPI: if not is_mpi_available(): raise RuntimeError( @@ -488,7 +576,7 @@ def _new_process_group_helper(world_size, # If this is a subgroup (which means group_ranks is specified), # we check if the current process is a member of the new group. if not is_default_group: - global_rank = _default_pg.rank() + global_rank = _get_default_group().rank() if global_rank not in group_ranks: return GroupMember.NON_GROUP_MEMBER @@ -527,11 +615,11 @@ def _new_process_group_helper(world_size, return pg -def destroy_process_group(group=group.WORLD): +def destroy_process_group(group=None): """ Destroy a given process group, and deinitialize the distributed package - Arguments: + Args: group (ProcessGroup, optional): The process group to be destroyed, if group.WORLD is given, all process groups including the default one will @@ -540,23 +628,23 @@ def destroy_process_group(group=group.WORLD): global _pg_map global _pg_names global _pg_group_ranks - global _default_pg global _default_pg_init_method global _group_count if group == GroupMember.NON_GROUP_MEMBER: return - if group == GroupMember.WORLD: - pg = _default_pg + if group is None: + pg = GroupMember.WORLD else: pg = group + assert pg is not None if _pg_map.get(pg, None) is None: raise RuntimeError("Invalid process group specified") - if group == GroupMember.WORLD: - _default_pg = None + if group is None or group == GroupMember.WORLD: + _update_default_pg(None) _default_pg_init_method = None _pg_map.clear() _pg_names.clear() @@ -577,7 +665,7 @@ def destroy_process_group(group=group.WORLD): del _pg_group_ranks[pg] -def get_rank(group=group.WORLD): +def get_rank(group=None): """ Returns the rank of current process group @@ -585,8 +673,9 @@ def get_rank(group=group.WORLD): process group. They are always consecutive integers ranging from 0 to ``world_size``. - Arguments: - group (ProcessGroup, optional): The process group to work on + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Returns: The rank of the process group @@ -596,19 +685,20 @@ def get_rank(group=group.WORLD): if _rank_not_in_group(group): return -1 - _check_default_pg() - if group == GroupMember.WORLD: - return _default_pg.rank() + default_pg = _get_default_group() + if group is None or group is GroupMember.WORLD: + return default_pg.rank() - return _get_group_rank(group, _default_pg.rank()) + return _get_group_rank(group, default_pg.rank()) -def get_world_size(group=group.WORLD): +def get_world_size(group=None): """ Returns the number of processes in the current process group - Arguments: - group (ProcessGroup, optional): The process group to work on + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Returns: The world size of the process group @@ -623,15 +713,16 @@ def get_world_size(group=group.WORLD): def isend(tensor, dst, - group=group.WORLD, + group=None, tag=0): """ Sends a tensor asynchronously. - Arguments: + Args: tensor (Tensor): Tensor to send. dst (int): Destination rank. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. tag (int, optional): Tag to match send with remote recv Returns: @@ -643,25 +734,27 @@ def isend(tensor, if _rank_not_in_group(group): return - if group == GroupMember.WORLD: - _check_default_pg() - return _default_pg.send([tensor], dst, tag) + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() + return default_pg.send([tensor], dst, tag) else: group_dst_rank = _get_group_rank(group, dst) return group.send([tensor], group_dst_rank, tag) def irecv(tensor, - src, - group=group.WORLD, + src=None, + group=None, tag=0): """ Receives a tensor asynchronously. - Arguments: + Args: tensor (Tensor): Tensor to fill with received data. - src (int): Source rank. - group (ProcessGroup, optional): The process group to work on + src (int, optional): Source rank. Will receive from any + process if unspecified. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. tag (int, optional): Tag to match recv with remote send Returns: @@ -673,25 +766,33 @@ def irecv(tensor, if _rank_not_in_group(group): return - if group == GroupMember.WORLD: - _check_default_pg() - return _default_pg.recv([tensor], src, tag) + if group is None or group is GroupMember.WORLD: + pg = _get_default_group() else: - group_src_rank = _get_group_rank(group, src) - return group.recv([tensor], group_src_rank, tag) + pg = group + + if src is None: + return pg.recv_anysource([tensor], tag) + else: + if pg is GroupMember.WORLD: + return pg.recv([tensor], src, tag) + else: + group_src_rank = _get_group_rank(pg, src) + return pg.recv([tensor], group_src_rank, tag) def send(tensor, dst, - group=group.WORLD, + group=None, tag=0): """ Sends a tensor synchronously. - Arguments: + Args: tensor (Tensor): Tensor to send. dst (int): Destination rank. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. tag (int, optional): Tag to match send with remote recv """ @@ -699,9 +800,9 @@ def send(tensor, if _rank_not_in_group(group): return - if group == GroupMember.WORLD: - _check_default_pg() - _default_pg.send([tensor], dst, tag).wait() + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() + default_pg.send([tensor], dst, tag).wait() else: group_dst_rank = _get_group_rank(group, dst) group.send([tensor], group_dst_rank, tag).wait() @@ -709,16 +810,17 @@ def send(tensor, def recv(tensor, src=None, - group=group.WORLD, + group=None, tag=0): """ Receives a tensor synchronously. - Arguments: + Args: tensor (Tensor): Tensor to fill with received data. src (int, optional): Source rank. Will receive from any process if unspecified. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. tag (int, optional): Tag to match recv with remote send Returns: @@ -730,22 +832,21 @@ def recv(tensor, if _rank_not_in_group(group): return -1 - if group == GroupMember.WORLD: - _check_default_pg() - pg = _default_pg + if group is None: + pg = _get_default_group() else: pg = group if src is None: work = pg.recv_anysource([tensor], tag) work.wait() - src_rank = work.source_rank() - if group == GroupMember.WORLD: + src_rank = work._source_rank() + if group is None or group is GroupMember.WORLD: return src_rank else: return _get_global_rank(pg, src_rank) else: - if group == GroupMember.WORLD: + if group is None or group is GroupMember.WORLD: pg.recv([tensor], src, tag).wait() else: group_src_rank = _get_group_rank(pg, src) @@ -753,9 +854,102 @@ def recv(tensor, return src +class P2POp(object): + """ + A class to build point-to-point operations for ``batch_isend_irecv``. + + This class builds the type of P2P operation, communication buffer, peer rank, + Process Group group, and tag. Instances of this class will be passed to + ``batch_isend_irecv`` for point-to-point communications. + + Args: + op (callable): A function to send data to or receive data from a peer process. + The type of ``op`` is either ``torch.distributed.isend`` or + ``torch.distributed.irecv``. + tensor (Tensor): Tensor to send or receive. + peer (int): Destination or source rank. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + tag (int, optional): Tag to match send with recv. + """ + def __init__(self, op, tensor, peer, group=None, tag=0): + self.op = op + self.tensor = tensor + self.peer = peer + self.group = group + self.tag = tag + + def __new__(cls, op, tensor, peer, group=None, tag=0): + _check_op(op) + _check_single_tensor(tensor, "tensor") + return object.__new__(cls) + + +@contextlib.contextmanager +def _batch_p2p_manager(backend): + if backend == Backend.NCCL: + ProcessGroupNCCL._group_start() + try: + yield + finally: + if backend == Backend.NCCL: + ProcessGroupNCCL._group_end() + + +def batch_isend_irecv(p2p_op_list): + """ + Send or Receive a batch of tensors asynchronously and return a list of requests. + + Process each of the operations in p2p_op_list and return the corresponding + requests. NCCL and Gloo backend are currently supported. + + Args: + p2p_op_list: A list of point-to-point operations(type of each operator is + ``torch.distributed.P2POp``). The order of the isend/irecv in the list + matters and it needs to match with corresponding isend/irecv on the + remote end. + + Returns: + A list of distributed request objects returned by calling the corresponding + op in the op_list. + + Examples: + >>> send_tensor = torch.arange(2) + 2 * rank + >>> recv_tensor = torch.randn(2) + >>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1)%world_size) + >>> recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank + 1)%world_size) + >>> reqs = batch_isend_irecv([send_op, recv_op]) + >>> for req in reqs: + >>> req.wait() + >>> recv_tensor + tensor([2, 3]) # Rank 0 + tensor([0, 1]) # Rank 1 + + .. note:: Note that when this API is used with the NCCL PG backend, users must set + the current GPU device with `torch.cuda.set_device`, otherwise it will + lead to unexpected hang issues. + """ + _check_p2p_op_list(p2p_op_list) + backend = get_backend(p2p_op_list[0].group) + reqs = [] + with _batch_p2p_manager(backend): + for p2p_op in p2p_op_list: + op = p2p_op.op + tensor = p2p_op.tensor + peer = p2p_op.peer + curr_group = p2p_op.group + tag = p2p_op.tag + + ret = op(tensor, peer, curr_group, tag) + + if ret is not None: + reqs.append(ret) + return reqs + + def broadcast_multigpu(tensor_list, src, - group=group.WORLD, + group=None, async_op=False, src_tensor=0): """ @@ -769,7 +963,7 @@ def broadcast_multigpu(tensor_list, Only nccl and gloo backend are currently supported tensors should only be GPU tensors - Arguments: + Args: tensor_list (List[Tensor]): Tensors that participate in the collective operation. If ``src`` is the rank, then the specified ``src_tensor`` element of ``tensor_list`` (``tensor_list[src_tensor]``) will be @@ -779,7 +973,8 @@ def broadcast_multigpu(tensor_list, for all the distributed processes calling this function. src (int): Source rank. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op src_tensor (int, optional): Source tensor rank within ``tensor_list`` @@ -795,9 +990,9 @@ def broadcast_multigpu(tensor_list, opts.rootRank = src opts.rootTensor = src_tensor - if group == GroupMember.WORLD: - _check_default_pg() - work = _default_pg.broadcast(tensor_list, opts) + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() + work = default_pg.broadcast(tensor_list, opts) else: group_src_rank = _get_group_rank(group, src) opts.rootRank = group_src_rank @@ -810,7 +1005,7 @@ def broadcast_multigpu(tensor_list, def broadcast(tensor, src, - group=group.WORLD, + group=None, async_op=False): """ Broadcasts the tensor to the whole group. @@ -818,11 +1013,12 @@ def broadcast(tensor, ``tensor`` must have the same number of elements in all processes participating in the collective. - Arguments: + Args: tensor (Tensor): Data to be sent if ``src`` is the rank of current process, and tensor to be used to save received data otherwise. src (int): Source rank. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -838,9 +1034,9 @@ def broadcast(tensor, opts.rootRank = src opts.rootTensor = 0 - if group == GroupMember.WORLD: - _check_default_pg() - work = _default_pg.broadcast([tensor], opts) + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() + work = default_pg.broadcast([tensor], opts) else: group_src_rank = _get_group_rank(group, src) opts.rootRank = group_src_rank @@ -853,7 +1049,7 @@ def broadcast(tensor, def all_reduce_multigpu(tensor_list, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): r""" Reduces the tensor data across all machines in such a way that all get @@ -865,10 +1061,12 @@ def all_reduce_multigpu(tensor_list, After the call, all ``tensor`` in ``tensor_list`` is going to be bitwise identical in all processes. + Complex tensors are supported. + Only nccl and gloo backend is currently supported tensors should only be GPU tensors - Arguments: + Args: tensor list (List[Tensor]): List of input and output tensors of the collective. The function operates in-place and requires that each tensor to be a GPU tensor on different GPUs. @@ -877,7 +1075,8 @@ def all_reduce_multigpu(tensor_list, op (optional): One of the values from ``torch.distributed.ReduceOp`` enum. Specifies an operation used for element-wise reductions. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -888,11 +1087,13 @@ def all_reduce_multigpu(tensor_list, if _rank_not_in_group(group): return + tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list] + opts = AllreduceOptions() opts.reduceOp = op - if group == GroupMember.WORLD: - _check_default_pg() - work = _default_pg.allreduce(tensor_list, opts) + if group is None: + default_pg = _get_default_group() + work = default_pg.allreduce(tensor_list, opts) else: work = group.allreduce(tensor_list, opts) @@ -904,7 +1105,7 @@ def all_reduce_multigpu(tensor_list, def all_reduce(tensor, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): """ Reduces the tensor data across all machines in such a way that all get @@ -912,29 +1113,60 @@ def all_reduce(tensor, After the call ``tensor`` is going to be bitwise identical in all processes. - Arguments: + Complex tensors are supported. + + Args: tensor (Tensor): Input and output of the collective. The function operates in-place. op (optional): One of the values from ``torch.distributed.ReduceOp`` enum. Specifies an operation used for element-wise reductions. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: Async work handle, if async_op is set to True. None, if not async_op or if not part of the group + Examples: + >>> # All tensors below are of torch.int64 type. + >>> # We have 2 process groups, 2 ranks. + >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank + >>> tensor + tensor([1, 2]) # Rank 0 + tensor([3, 4]) # Rank 1 + >>> dist.all_reduce(tensor, op=ReduceOp.SUM) + >>> tensor + tensor([4, 6]) # Rank 0 + tensor([4, 6]) # Rank 1 + + >>> # All tensors below are of torch.cfloat type. + >>> # We have 2 process groups, 2 ranks. + >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j) + >>> tensor + tensor([1.+1.j, 2.+2.j]) # Rank 0 + tensor([3.+3.j, 4.+4.j]) # Rank 1 + >>> dist.all_reduce(tensor, op=ReduceOp.SUM) + >>> tensor + tensor([4.+4.j, 6.+6.j]) # Rank 0 + tensor([4.+4.j, 6.+6.j]) # Rank 1 + """ _check_single_tensor(tensor, "tensor") if _rank_not_in_group(group): return + if tensor.is_complex(): + if not supports_complex(op): + raise RuntimeError(f"all_reduce does not support {op} on complex tensors") + tensor = torch.view_as_real(tensor) + opts = AllreduceOptions() opts.reduceOp = op - if group == GroupMember.WORLD: - _check_default_pg() - work = _default_pg.allreduce([tensor], opts) + if group is None: + default_pg = _get_default_group() + work = default_pg.allreduce([tensor], opts) else: work = group.allreduce([tensor], opts) @@ -946,7 +1178,7 @@ def all_reduce(tensor, def all_reduce_coalesced(tensors, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): """ WARNING: at this time individual shape checking is not implemented across nodes. @@ -963,13 +1195,16 @@ def all_reduce_coalesced(tensors, After the call each tensor in tensors is going to bitwise identical in all processes. - Arguments: + Complex tensors are supported. + + Args: tensors (List[Tensor]): Input and output of the collective. The function operates in-place. op (Optional[ReduceOp]): One of the values from ``torch.distributed.ReduceOp`` enum. Specifies an operation used for element-wise reductions. - group (Optional[ProcessGroup]): The process group to work on. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (Optional[bool]): Whether this op should be an async op. Returns: @@ -981,11 +1216,16 @@ def all_reduce_coalesced(tensors, if _rank_not_in_group(group): return + if any([t.is_complex() for t in tensors]) and not supports_complex(op): + raise RuntimeError(f"all_reduce does not support {op} on complex tensors") + + tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors] + opts = AllreduceCoalescedOptions() opts.reduceOp = op - if group == GroupMember.WORLD: - _check_default_pg() - work = _default_pg.allreduce_coalesced(tensors, opts) + if group is None: + default_pg = _get_default_group() + work = default_pg.allreduce_coalesced(tensors, opts) else: work = group.allreduce_coalesced(tensors, opts) @@ -998,7 +1238,7 @@ def all_reduce_coalesced(tensors, def reduce_multigpu(tensor_list, dst, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False, dst_tensor=0): """ @@ -1011,7 +1251,7 @@ def reduce_multigpu(tensor_list, Only nccl backend is currently supported tensors should only be GPU tensors - Arguments: + Args: tensor_list (List[Tensor]): Input and output GPU tensors of the collective. The function operates in-place. You also need to make sure that ``len(tensor_list)`` is the same for @@ -1020,7 +1260,8 @@ def reduce_multigpu(tensor_list, op (optional): One of the values from ``torch.distributed.ReduceOp`` enum. Specifies an operation used for element-wise reductions. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op dst_tensor (int, optional): Destination tensor rank within ``tensor_list`` @@ -1038,9 +1279,9 @@ def reduce_multigpu(tensor_list, opts.rootRank = dst opts.rootTensor = dst_tensor - if group == GroupMember.WORLD: - _check_default_pg() - work = _default_pg.reduce(tensor_list, opts) + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() + work = default_pg.reduce(tensor_list, opts) else: group_dst_rank = _get_group_rank(group, dst) opts.rootRank = group_dst_rank @@ -1055,21 +1296,22 @@ def reduce_multigpu(tensor_list, def reduce(tensor, dst, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): """ Reduces the tensor data across all machines. Only the process with rank ``dst`` is going to receive the final result. - Arguments: + Args: tensor (Tensor): Input and output of the collective. The function operates in-place. dst (int): Destination rank op (optional): One of the values from ``torch.distributed.ReduceOp`` enum. Specifies an operation used for element-wise reductions. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1085,9 +1327,9 @@ def reduce(tensor, opts.reduceOp = op opts.rootRank = dst - if group == GroupMember.WORLD: - _check_default_pg() - work = _default_pg.reduce([tensor], opts) + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() + work = default_pg.reduce([tensor], opts) else: group_dst_rank = _get_group_rank(group, dst) opts.rootRank = group_dst_rank @@ -1101,7 +1343,7 @@ def reduce(tensor, def all_gather_multigpu(output_tensor_lists, input_tensor_list, - group=group.WORLD, + group=None, async_op=False): """ Gathers tensors from the whole group in a list. @@ -1110,7 +1352,9 @@ def all_gather_multigpu(output_tensor_lists, Only nccl backend is currently supported tensors should only be GPU tensors - Arguments: + Complex tensors are supported. + + Args: output_tensor_lists (List[List[Tensor]]): Output lists. It should contain correctly-sized tensors on each GPU to be used for output of the collective, e.g. ``output_tensor_lists[i]`` contains the @@ -1134,7 +1378,8 @@ def all_gather_multigpu(output_tensor_lists, Note that ``len(input_tensor_list)`` needs to be the same for all the distributed processes calling this function. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1145,9 +1390,12 @@ def all_gather_multigpu(output_tensor_lists, if _rank_not_in_group(group): return - if group == GroupMember.WORLD: - _check_default_pg() - work = _default_pg.allgather(output_tensor_lists, input_tensor_list) + output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists] + input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list] + + if group is None: + default_pg = _get_default_group() + work = default_pg.allgather(output_tensor_lists, input_tensor_list) else: work = group.allgather(output_tensor_lists, input_tensor_list) @@ -1159,7 +1407,7 @@ def all_gather_multigpu(output_tensor_lists, def _object_to_tensor(obj): buffer = pickle.dumps(obj) - byte_storage = torch.ByteStorage.from_buffer(buffer) + byte_storage = torch.ByteStorage.from_buffer(buffer) # type: ignore[attr-defined] byte_tensor = torch.ByteTensor(byte_storage) local_size = torch.LongTensor([byte_tensor.numel()]) return byte_tensor, local_size @@ -1171,17 +1419,18 @@ def _tensor_to_object(tensor, tensor_size): return out -def all_gather_object(object_list, obj, group=group.WORLD): +def all_gather_object(object_list, obj, group=None): """ Gathers picklable objects from the whole group into a list. Similar to :func:`all_gather`, but Python objects can be passed in. Note that the object must be picklable in order to be gathered. - Arguments: + Args: object_list (list[Any]): Output list. It should be correctly sized as the size of the group for this collective and will contain the output. object (Any): Pickable Python object to be broadcast from current process. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Returns: None. If the calling rank is part of this group, the output of the @@ -1193,38 +1442,58 @@ def all_gather_object(object_list, obj, group=group.WORLD): collective since it does not provide an ``async_op`` handle and thus will be a blocking call. + .. note:: For NCCL-based processed groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsiblity to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + .. warning:: :func:`all_gather_object` uses ``pickle`` module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. + + Example:: + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) + >>> output + ['foo', 12, {1: 2}] """ if _rank_not_in_group(group): return input_tensor, local_size = _object_to_tensor(obj) group_backend = get_backend(group) - my_rank = get_rank() is_nccl_backend = group_backend == Backend.NCCL + current_device = torch.device("cpu") if is_nccl_backend: - input_tensor, local_size = input_tensor.to(my_rank), local_size.to(my_rank) + # See note about using torch.cuda.current_device() here in docstring. + # We cannot simply use my_rank since rank == device is not necessarily + # true. + current_device = torch.device('cuda', torch.cuda.current_device()) + input_tensor = input_tensor.to(current_device) + local_size = local_size.to(current_device) # Gather all local sizes. This is so that we can find the max size, and index # until the correct size when deserializing the tensors. group_size = get_world_size(group=group) - object_sizes_tensor = torch.zeros(group_size, dtype=int).to( - my_rank if is_nccl_backend else "cpu" - ) + object_sizes_tensor = torch.zeros(group_size, dtype=torch.long, device=current_device) object_size_list = [ object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) ] # Allgather tensor sizes all_gather(object_size_list, local_size, group=group) - max_object_size = max(object_size_list) + max_object_size = int(max(object_size_list).item()) # type: ignore # Resize tensor to max size across all ranks. input_tensor.resize_(max_object_size) coalesced_output_tensor = torch.empty( - max_object_size * group_size, dtype=torch.uint8 - ).to(my_rank if is_nccl_backend else "cpu") + max_object_size * group_size, dtype=torch.uint8, device=current_device + ) # Output tensors are nonoverlapping views of coalesced_output_tensor output_tensors = [ coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] @@ -1233,25 +1502,26 @@ def all_gather_object(object_list, obj, group=group.WORLD): all_gather(output_tensors, input_tensor, group=group) # Deserialize outputs back to object. for i, tensor in enumerate(output_tensors): - tensor = tensor.type(torch.ByteTensor) + tensor = tensor.type(torch.ByteTensor) # type:ignore[call-overload] tensor_size = object_size_list[i] object_list[i] = _tensor_to_object(tensor, tensor_size) -def gather_object(obj, object_gather_list=None, dst=0, group=group.WORLD): +def gather_object(obj, object_gather_list=None, dst=0, group=None): """ Gathers picklable objects from the whole group in a single process. Similar to :func:`gather`, but Python objects can be passed in. Note that the object must be picklable in order to be gathered. - Arguments: + Args: obj (Any): Input object. Must be picklable. object_gather_list (list[Any]): Output list. On the ``dst`` rank, it should be correctly sized as the size of the group for this collective and will contain the output. Must be ``None`` on non-dst ranks. (default is ``None``) dst (int, optional): Destination rank. (default is 0) - group: (ProcessGroup, optional): The process group to work on. + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Returns: None. On the ``dst`` rank, ``object_gather_list`` will contain the @@ -1268,6 +1538,21 @@ def gather_object(obj, object_gather_list=None, dst=0, group=group.WORLD): known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. + + Example:: + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> dist.gather_object( + gather_objects[dist.get_rank()], + output if dist.get_rank() == 0 else None, + dst=0 + ) + >>> # On rank 0 + >>> output + ['foo', 12, {1: 2}] """ if _rank_not_in_group(group): return @@ -1277,29 +1562,31 @@ def gather_object(obj, object_gather_list=None, dst=0, group=group.WORLD): _validate_output_list_for_rank(my_rank, dst, object_gather_list) input_tensor, local_size = _object_to_tensor(obj) group_backend = get_backend(group) + current_device = torch.device("cpu") is_nccl_backend = group_backend == Backend.NCCL if is_nccl_backend: - input_tensor, local_size = input_tensor.to(my_rank), local_size.to(my_rank) + current_device = torch.device('cuda', torch.cuda.current_device()) + input_tensor = input_tensor.to(current_device) + local_size = local_size.to(current_device) # Gather all local sizes. This is so that we can find the max size, and index # until the correct size when deserializing the tensors. group_size = get_world_size(group=group) - object_sizes_tensor = torch.zeros(group_size, dtype=int).to( - my_rank if is_nccl_backend else "cpu" - ) + object_sizes_tensor = torch.zeros(group_size, dtype=torch.long, device=current_device) object_size_list = [ object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) ] - # Allgather tensor sizes. An all-gather is needed here despite this being a gather, - # since each rank needs to broadcast a tensor of the same (maximal) size. + # Allgather tensor sizes. An all-gather is needed here despite this being a + # gather, since each rank needs to broadcast a tensor of the same (maximal) + # size. all_gather(object_size_list, local_size, group=group) - max_object_size = max(object_size_list) + max_object_size = int(max(object_size_list).item()) # type: ignore # Resize tensor to max size across all ranks. input_tensor.resize_(max_object_size) # Avoid populating output tensors if the result won't be gathered on this rank. if my_rank == dst: coalesced_output_tensor = torch.empty( - max_object_size * group_size, dtype=torch.uint8 - ).to(my_rank if is_nccl_backend else "cpu") + max_object_size * group_size, dtype=torch.uint8, device=current_device + ) # Output tensors are nonoverlapping views of coalesced_output_tensor output_tensors = [ coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] @@ -1315,38 +1602,58 @@ def gather_object(obj, object_gather_list=None, dst=0, group=group.WORLD): if my_rank != dst: return for i, tensor in enumerate(output_tensors): - tensor = tensor.type(torch.ByteTensor) + tensor = tensor.type(torch.ByteTensor) # type: ignore[call-overload] tensor_size = object_size_list[i] object_gather_list[i] = _tensor_to_object(tensor, tensor_size) -def broadcast_object_list(object_list, src, group=group.WORLD): +def broadcast_object_list(object_list, src, group=None): """ Broadcasts picklable objects in ``object_list`` to the whole group. Similar to :func:`broadcast`, but Python objects can be passed in. Note that all objects in ``object_list`` must be picklable in order to be broadcasted. - Arguments: + Args: object_list (List[Any]): List of input objects to broadcast. Each object must be picklable. Only objects on the ``src`` rank will be broadcast, but each rank must provide lists of equal sizes. src (int): Source rank from which to broadcast ``object_list``. - group: (ProcessGroup, optional): The process group to work on. + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Returns: ``None``. If rank is part of the group, ``object_list`` will contain the broadcasted objects from ``src`` rank. - .. note:: Note that this API differs slightly from the broadcast collective - since it does not provide an ``async_op`` handle and thus will be a - blocking call. + .. note:: For NCCL-based processed groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsiblity to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. note:: Note that this API differs slightly from the :func:`all_gather` + collective since it does not provide an ``async_op`` handle and thus + will be a blocking call. .. warning:: :func:`broadcast_object_list` uses ``pickle`` module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust. + + Example:: + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 3. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> else: + >>> objects = [None, None, None] + >>> dist.broadcast_object_list(objects, src=0) + >>> broadcast_objects + ['foo', 12, {1: 2}] """ if _rank_not_in_group(group): return @@ -1361,8 +1668,14 @@ def broadcast_object_list(object_list, src, group=group.WORLD): group_backend = get_backend(group) is_nccl_backend = group_backend == Backend.NCCL + current_device = torch.device("cpu") if is_nccl_backend: - object_sizes_tensor = object_sizes_tensor.to(my_rank) + # See note about using torch.cuda.current_device() here in docstring. + # We cannot simply use my_rank since rank == device is not necessarily + # true. + current_device = torch.device('cuda', torch.cuda.current_device()) + object_sizes_tensor = object_sizes_tensor.to(current_device) + object_sizes_tensor = object_sizes_tensor.to(current_device) # Broadcast object sizes broadcast(object_sizes_tensor, src=src, group=group) @@ -1374,45 +1687,179 @@ def broadcast_object_list(object_list, src, group=group.WORLD): object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item()) if is_nccl_backend: - object_tensor = object_tensor.to(my_rank) + object_tensor = object_tensor.to(current_device) broadcast(object_tensor, src=src, group=group) # Deserialize objects using their stored sizes. offset = 0 if my_rank != src: for i, obj_size in enumerate(object_sizes_tensor): obj_view = object_tensor[offset : offset + obj_size] - obj_view = obj_view.type(torch.ByteTensor) + obj_view = obj_view.type(torch.ByteTensor) # type: ignore[call-overload] offset += obj_size object_list[i] = _tensor_to_object(obj_view, obj_size) +def scatter_object_list( + scatter_object_output_list, scatter_object_input_list, src=0, group=group.WORLD +): + """ + Scatters picklable objects in ``scatter_object_input_list`` to the whole + group. Similar to :func:`scatter`, but Python objects can be passed in. On + each rank, the scattered object will be stored as the first element of + ``scatter_object_output_list``. Note that all objects in + ``scatter_object_input_list`` must be picklable in order to be scattered. + + Args: + scatter_object_output_list (List[Any]): Non-empty list whose first + element will store the object scattered to this rank. + scatter_object_input_list (List[Any]): List of input objects to scatter. + Each object must be picklable. Only objects on the ``src`` rank will + be scattered, and the argument can be ``None`` for non-src ranks. + src (int): Source rank from which to scatter + ``scatter_object_input_list``. + group: (ProcessGroup, optional): The process group to work on. + + Returns: + ``None``. If rank is part of the group, ``scatter_object_output_list`` + will have its first element set to the scattered object for this rank. + + .. note:: Note that this API differs slightly from the scatter collective + since it does not provide an ``async_op`` handle and thus will be a + blocking call. + + .. warning:: + :func:`scatter_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + Example:: + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 3. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> else: + >>> # Can be any list on non-src ranks, elements are not used. + >>> objects = [None, None, None] + >>> output_list = [None] + >>> dist.scatter_object_list(output_list, objects, src=0) + >>> # Rank i gets objects[i]. For example, on rank 2: + >>> output_list + [{1: 2}] + """ + if _rank_not_in_group(group): + return + + if ( + not isinstance(scatter_object_output_list, list) + or len(scatter_object_output_list) < 1 + ): + raise RuntimeError( + "Expected argument scatter_object_output_list to be a list of size at least 1." + ) + + my_rank = get_rank(group) + if my_rank == src: + tensor_list, tensor_sizes = zip( + *[_object_to_tensor(obj) for obj in scatter_object_input_list] + ) + tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes) + + obj_tensor_size = torch.LongTensor([0]) + # Src rank broadcasts the maximum tensor size. This is because all ranks are + # expected to call into scatter() with equal-sized tensors. + if my_rank == src: + max_tensor_size = max(tensor_sizes) + for tensor in tensor_list: + tensor.resize_(max_tensor_size) + else: + max_tensor_size = torch.LongTensor([0]) + broadcast(max_tensor_size, src=src, group=group) + + # Scatter actual serialized objects + output_tensor = torch.ByteTensor(max_tensor_size.item()) + scatter( + output_tensor, + scatter_list=None if my_rank != src else tensor_list, + src=src, + group=group, + ) + + # Scatter per-object sizes to trim tensors when deserializing back to object + scatter( + obj_tensor_size, + scatter_list=None if my_rank != src else tensor_sizes, + src=src, + group=group, + ) + + # Deserialize back to object + scatter_object_output_list[0] = _tensor_to_object(output_tensor, obj_tensor_size) + + def all_gather(tensor_list, tensor, - group=group.WORLD, + group=None, async_op=False): """ Gathers tensors from the whole group in a list. - Arguments: + Complex tensors are supported. + + Args: tensor_list (list[Tensor]): Output list. It should contain correctly-sized tensors to be used for output of the collective. tensor (Tensor): Tensor to be broadcast from current process. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: Async work handle, if async_op is set to True. None, if not async_op or if not part of the group + Examples: + >>> # All tensors below are of torch.int64 dtype. + >>> # We have 2 process groups, 2 ranks. + >>> tensor_list = [torch.zero(2, dtype=torch.int64) for _ in range(2)] + >>> tensor_list + [tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1 + >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank + >>> tensor + tensor([1, 2]) # Rank 0 + tensor([3, 4]) # Rank 1 + >>> dist.all_gather(tensor_list, tensor) + >>> tensor_list + [tensor([1, 2]), tensor([3, 4])] # Rank 0 + [tensor([1, 2]), tensor([3, 4])] # Rank 1 + + >>> # All tensors below are of torch.cfloat dtype. + >>> # We have 2 process groups, 2 ranks. + >>> tensor_list = [torch.zero(2, dtype=torch.cfloat) for _ in range(2)] + >>> tensor_list + [tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1 + >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j) + >>> tensor + tensor([1.+1.j, 2.+2.j]) # Rank 0 + tensor([3.+3.j, 4.+4.j]) # Rank 1 + >>> dist.all_gather(tensor_list, tensor) + >>> tensor_list + [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0 + [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1 + """ _check_tensor_list(tensor_list, "tensor_list") _check_single_tensor(tensor, "tensor") if _rank_not_in_group(group): return - if group == GroupMember.WORLD: - _check_default_pg() - work = _default_pg.allgather([tensor_list], [tensor]) + tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list] + tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) + + if group is None: + default_pg = _get_default_group() + work = default_pg.allgather([tensor_list], [tensor]) else: work = group.allgather([tensor_list], [tensor]) @@ -1423,17 +1870,20 @@ def all_gather(tensor_list, def all_gather_coalesced(output_tensor_lists, input_tensor_list, - group=group.WORLD, + group=None, async_op=False): """ Gathers input tensors from the whole group in a list in a coalesced manner. - Arguments: + Complex tensors are supported. + + Args: output_tensor_lists (list[list[Tensor]]): Output list. It should contain correctly-sized tensors to be used for output of the collective. input_tensor_list (list[Tensor]): Tensors to be broadcast from current process. At least one tensor has to be non empty. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op. Returns: @@ -1476,9 +1926,12 @@ def all_gather_coalesced(output_tensor_lists, for output_tensor_list in output_tensor_lists: _check_tensor_list(output_tensor_list, "output_tensor_lists") - if group == GroupMember.WORLD: - _check_default_pg() - work = _default_pg.allgather_coalesced( + output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists] + input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list] + + if group is None: + default_pg = _get_default_group() + work = default_pg.allgather_coalesced( output_tensor_lists, input_tensor_list) else: work = group.allgather_coalesced(output_tensor_lists, input_tensor_list) @@ -1504,18 +1957,19 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): def gather(tensor, gather_list=None, dst=0, - group=group.WORLD, + group=None, async_op=False): """ Gathers a list of tensors in a single process. - Arguments: + Args: tensor (Tensor): Input tensor. gather_list (list[Tensor], optional): List of appropriately-sized tensors to use for gathered data (default is None, must be specified on the destination rank) dst (int, optional): Destination rank (default is 0) - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1542,9 +1996,9 @@ def gather(tensor, opts = GatherOptions() opts.rootRank = dst - if group == GroupMember.WORLD: - _check_default_pg() - work = _default_pg.gather(output_tensors, input_tensors, opts) + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() + work = default_pg.gather(output_tensors, input_tensors, opts) else: group_dst_rank = _get_group_rank(group, dst) opts.rootRank = group_dst_rank @@ -1559,7 +2013,7 @@ def gather(tensor, def scatter(tensor, scatter_list=None, src=0, - group=group.WORLD, + group=None, async_op=False): """ Scatters a list of tensors to all processes in a group. @@ -1567,12 +2021,13 @@ def scatter(tensor, Each process will receive exactly one tensor and store its data in the ``tensor`` argument. - Arguments: + Args: tensor (Tensor): Output tensor. scatter_list (list[Tensor]): List of tensors to scatter (default is None, must be specified on the source rank) src (int): Source rank (default is 0) - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op Returns: @@ -1608,9 +2063,9 @@ def scatter(tensor, opts = ScatterOptions() opts.rootRank = src - if group == GroupMember.WORLD: - _check_default_pg() - work = _default_pg.scatter(output_tensors, input_tensors, opts) + if group is None or group is GroupMember.WORLD: + default_pg = _get_default_group() + work = default_pg.scatter(output_tensors, input_tensors, opts) else: group_src_rank = _get_group_rank(group, src) opts.rootRank = group_src_rank @@ -1625,7 +2080,7 @@ def scatter(tensor, def reduce_scatter_multigpu(output_tensor_list, input_tensor_lists, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): """ Reduce and scatter a list of tensors to the whole group. Only nccl backend @@ -1634,7 +2089,7 @@ def reduce_scatter_multigpu(output_tensor_list, Each tensor in ``output_tensor_list`` should reside on a separate GPU, as should each list of tensors in ``input_tensor_lists``. - Arguments: + Args: output_tensor_list (List[Tensor]): Output tensors (on different GPUs) to receive the result of the operation. @@ -1659,7 +2114,8 @@ def reduce_scatter_multigpu(output_tensor_list, therefore ``len(input_tensor_lists[i])``) need to be the same for all the distributed processes calling this function. - group (ProcessGroup, optional): The process group to work on. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op. Returns: @@ -1673,9 +2129,9 @@ def reduce_scatter_multigpu(output_tensor_list, opts = ReduceScatterOptions() opts.reduceOp = op - if group == GroupMember.WORLD: - _check_default_pg() - work = _default_pg.reduce_scatter( + if group is None: + default_pg = _get_default_group() + work = default_pg.reduce_scatter( output_tensor_list, input_tensor_lists, opts @@ -1696,15 +2152,16 @@ def reduce_scatter_multigpu(output_tensor_list, def reduce_scatter(output, input_list, op=ReduceOp.SUM, - group=group.WORLD, + group=None, async_op=False): """ Reduces, then scatters a list of tensors to all processes in a group. - Arguments: + Args: output (Tensor): Output tensor. input_list (list[Tensor]): List of tensors to reduce and scatter. - group (ProcessGroup, optional): The process group to work on. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op. Returns: @@ -1720,9 +2177,9 @@ def reduce_scatter(output, opts = ReduceScatterOptions() opts.reduceOp = op - if group == GroupMember.WORLD: - _check_default_pg() - work = _default_pg.reduce_scatter([output], [input_list], opts) + if group is None: + default_pg = _get_default_group() + work = default_pg.reduce_scatter([output], [input_list], opts) else: work = group.reduce_scatter([output], [input_list], opts) @@ -1736,14 +2193,14 @@ def all_to_all_single(output, input, output_split_sizes=None, input_split_sizes=None, - group=group.WORLD, + group=None, async_op=False): """ Each process splits input tensor and then scatters the split list to all processes in a group. Then concatenate the received tensors from all the processes in the group and return single output tensor. - Arguments: + Args: output (Tensor): Gathered cancatenated output tensor. input (Tensor): Input tensor to scatter. output_split_sizes: (list[Int], optional): Output split sizes for dim 0 @@ -1752,7 +2209,8 @@ def all_to_all_single(output, input_split_sizes: (list[Int], optional): Input split sizes for dim 0 if specified None or empty, dim 0 of ``input`` tensor must divide equally by ``world_size``. - group (ProcessGroup, optional): The process group to work on. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op. Returns: @@ -1816,9 +2274,9 @@ def all_to_all_single(output, output_split_sizes = [] if output_split_sizes is None else output_split_sizes input_split_sizes = [] if input_split_sizes is None else input_split_sizes - if group == GroupMember.WORLD: - _check_default_pg() - work = _default_pg.alltoall_base(output, input, output_split_sizes, input_split_sizes, opts) + if group is None: + default_pg = _get_default_group() + work = default_pg.alltoall_base(output, input, output_split_sizes, input_split_sizes, opts) else: work = group.alltoall_base(output, input, output_split_sizes, input_split_sizes, opts) @@ -1829,17 +2287,18 @@ def all_to_all_single(output, def all_to_all(output_tensor_list, input_tensor_list, - group=group.WORLD, + group=None, async_op=False): """ Each process scatters list of input tensors to all processes in a group and return gathered list of tensors in output list. - Arguments: + Args: output_tensor_list (list[Tensor]): List of tensors to be gathered one per rank. input_tensor_list (list[Tensor]): List of tensors to scatter one per rank. - group (ProcessGroup, optional): The process group to work on. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op. Returns: @@ -1907,9 +2366,9 @@ def all_to_all(output_tensor_list, _check_tensor_list(output_tensor_list, "output_tensor_list") _check_tensor_list(input_tensor_list, "input_tensor_list") - if group == GroupMember.WORLD: - _check_default_pg() - work = _default_pg.alltoall(output_tensor_list, input_tensor_list, opts) + if group is None: + default_pg = _get_default_group() + work = default_pg.alltoall(output_tensor_list, input_tensor_list, opts) else: work = group.alltoall(output_tensor_list, input_tensor_list, opts) @@ -1919,17 +2378,23 @@ def all_to_all(output_tensor_list, work.wait() -def barrier(group=group.WORLD, - async_op=False): + +def barrier(group=GroupMember.WORLD, + async_op=False, + device_ids=None): + """ Synchronizes all processes. This collective blocks processes until the whole group enters this function, if async_op is False, or if async work handle is called on wait(). - Arguments: - group (ProcessGroup, optional): The process group to work on + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. async_op (bool, optional): Whether this op should be an async op + device_ids ([int], optional): List of device/GPU ids. + Valid only for NCCL backend. Returns: Async work handle, if async_op is set to True. @@ -1938,11 +2403,22 @@ def barrier(group=group.WORLD, if _rank_not_in_group(group): return - if group == GroupMember.WORLD: - _check_default_pg() - work = _default_pg.barrier() + opts = BarrierOptions() + if device_ids is not None: + if get_backend(group) != Backend.NCCL: + raise RuntimeError("Function argument device_ids not supported " + "for the selected backend {}".format(get_backend(group))) + if isinstance(device_ids, list): + opts.device_ids = device_ids + else: + raise RuntimeError("Invalid function argument: " + "device_ids type should be List[int]") + + if group is None: + default_pg = _get_default_group() + work = default_pg.barrier(opts=opts) else: - work = group.barrier() + work = group.barrier(opts=opts) if async_op: return work @@ -1959,7 +2435,18 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None): if they are not going to be members of the group. Additionally, groups should be created in the same order in all processes. - Arguments: + .. warning:: + Using multiple process groups with the ``NCCL`` backend concurrently + is not safe and the user should perform explicit synchronization in + their application to ensure only one process group is used at a time. + This means collectives from one process group should have completed + execution on the device (not just enqueued since CUDA execution is + async) before collectives from another process group are enqueued. + See `Using multiple NCCL communicators concurrently `_ for more details. + + Args: ranks (list[int]): List of ranks of group members. If ``None``, will be set to all ranks. Default is ``None``. timeout (timedelta, optional): Timeout for operations executed against @@ -1976,13 +2463,13 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None): A handle of distributed group that can be given to collective calls. """ - _check_default_pg() global _pg_group_ranks - default_backend, default_store = _pg_map[_default_pg] - global_rank = _default_pg.rank() - global_world_size = _default_pg.size() + default_pg = _get_default_group() + default_backend, default_store = _pg_map[default_pg] + global_rank = default_pg.rank() + global_world_size = default_pg.size() # Default to the same backend as the global process group # if the backend is not specified. @@ -2025,4 +2512,15 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None): for group_rank, global_rank in enumerate(ranks) } + # barrier at the end to ensure that once we return from this method, all + # process groups including global variables are updated correctly on all + # ranks. + if backend == Backend.MPI: + # MPI doesn't have store. + barrier() + else: + # Use store based barrier here since barrier() used a bunch of + # default devices and messes up NCCL internal state. + _store_based_barrier(global_rank, default_store, timeout) + return pg diff --git a/torch/distributed/launch.py b/torch/distributed/launch.py index 2a5aebd9d97a9..a6a5b26e6d40f 100644 --- a/torch/distributed/launch.py +++ b/torch/distributed/launch.py @@ -83,13 +83,13 @@ :: - >>> torch.cuda.set_device(arg.local_rank) # before your code runs + >>> torch.cuda.set_device(args.local_rank) # before your code runs or :: - >>> with torch.cuda.device(arg.local_rank): + >>> with torch.cuda.device(args.local_rank): >>> # your code to run 3. In your training program, you are supposed to call the following function @@ -111,8 +111,8 @@ :: model = torch.nn.parallel.DistributedDataParallel(model, - device_ids=[arg.local_rank], - output_device=arg.local_rank) + device_ids=[args.local_rank], + output_device=args.local_rank) Please ensure that ``device_ids`` argument is set to be the only GPU device id that your code will be operating on. This is generally the local rank of the @@ -137,11 +137,16 @@ """ +import time +import signal import sys import subprocess import os from argparse import ArgumentParser, REMAINDER +from typing import Optional, IO, List, Any +node_local_rank_stdout_filename = "node_{}_local_rank_{}_stdout" +node_local_rank_stderr_filename = "node_{}_local_rank_{}_stderr" def parse_args(): """ @@ -185,6 +190,16 @@ def parse_args(): parser.add_argument("--no_python", default=False, action="store_true", help="Do not prepend the training script with \"python\" - just exec " "it directly. Useful when the script is not a Python script.") + parser.add_argument( + "--logdir", + default=None, + type=str, + help=f"""Relative path to write subprocess logs to. Passing in a relative + path will create a directory if needed, and write the stdout and stderr to files + {node_local_rank_stdout_filename} and {node_local_rank_stderr_filename}. Note that + successive runs with the same path to write logs to will overwrite existing logs, + so be sure to save logs as needed.""", + ) # positional parser.add_argument("training_script", type=str, @@ -209,7 +224,7 @@ def main(): current_env["MASTER_PORT"] = str(args.master_port) current_env["WORLD_SIZE"] = str(dist_world_size) - processes = [] + processes: List[Any] = [] if 'OMP_NUM_THREADS' not in os.environ and args.nproc_per_node > 1: current_env["OMP_NUM_THREADS"] = str(1) @@ -220,6 +235,17 @@ def main(): "your application as needed. \n" "*****************************************".format(current_env["OMP_NUM_THREADS"])) + if args.logdir: + # Possibly create the directory to write subprocess log output to. + if os.path.exists(args.logdir): + if not os.path.isdir(args.logdir): + raise ValueError("argument --logdir must be a path to a directory.") + else: + # create the relative directory + os.mkdir(os.path.join(os.getcwd(), args.logdir)) + + subprocess_file_handles = [] + for local_rank in range(0, args.nproc_per_node): # each process's rank dist_rank = args.nproc_per_node * args.node_rank + local_rank @@ -246,15 +272,69 @@ def main(): cmd.extend(args.training_script_args) - process = subprocess.Popen(cmd, env=current_env) + stdout_handle: Optional[IO] + stderr_handle: Optional[IO] + if args.logdir: + directory_path = os.path.join(os.getcwd(), args.logdir) + node_rank = args.node_rank + stdout_file_name = node_local_rank_stdout_filename.format(node_rank, local_rank) + stderr_file_name = node_local_rank_stderr_filename.format(node_rank, local_rank) + stdout_handle = open(os.path.join(directory_path, stdout_file_name), "w") + stderr_handle = open(os.path.join(directory_path, stderr_file_name), "w") + subprocess_file_handles.append((stdout_handle, stderr_handle)) + stdout_name = stdout_handle.name + stderr_name = stderr_handle.name + print(f"""Note: Stdout and stderr for node {node_rank} rank {local_rank} will + be written to {stdout_name}, {stderr_name} respectively.""") + + sig_names = {2: "SIGINT", 15: "SIGTERM"} + last_return_code = None + + def sigkill_handler(signum, frame): + for process in processes: + print(f"Killing subprocess {process.pid}") + try: + process.kill() + except Exception: + pass + if last_return_code is not None: + raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd) + if signum in sig_names: + print(f"Main process received {sig_names[signum]}, exiting") + sys.exit(1) + + # pass SIGINT/SIGTERM to children if the parent is being terminated + signal.signal(signal.SIGINT, sigkill_handler) + signal.signal(signal.SIGTERM, sigkill_handler) + + stdout_handle = None if not subprocess_file_handles else subprocess_file_handles[local_rank][0] + stderr_handle = None if not subprocess_file_handles else subprocess_file_handles[local_rank][1] + process = subprocess.Popen(cmd, env=current_env, stdout=stdout_handle, stderr=stderr_handle) processes.append(process) - for process in processes: - process.wait() - if process.returncode != 0: - raise subprocess.CalledProcessError(returncode=process.returncode, - cmd=cmd) - + try: + alive_processes = set(processes) + while len(alive_processes): + finished_processes = [] + for process in alive_processes: + if process.poll() is None: + # the process is still running + continue + else: + if process.returncode != 0: + last_return_code = process.returncode # for sigkill_handler + sigkill_handler(signal.SIGTERM, None) # not coming back + else: + # exited cleanly + finished_processes.append(process) + alive_processes = set(alive_processes) - set(finished_processes) + + time.sleep(1) + finally: + # close open file descriptors + for (stdout_handle, stderr_handle) in subprocess_file_handles: + stdout_handle.close() + stderr_handle.close() if __name__ == "__main__": main() diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index 225cb4842bd1f..1aa54c693ee59 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -17,6 +17,8 @@ import torch.distributed.rpc as rpc from torch import Tensor, device, dtype, nn from torch.distributed.nn.jit import instantiator +from torch.distributed.rpc.utils import _parse_remote_device +from torch.nn import Module from torch.nn.parameter import Parameter from torch.utils.hooks import RemovableHandle @@ -51,7 +53,7 @@ def _create_module(module_cls, args, kwargs, device="cpu", module_interface_cls= def _param_rrefs(module_rref, recurse): - ret = [] + ret: List[rpc.RRef[Parameter]] = [] for param in module_rref.local_value().parameters(recurse): ret.append(rpc.RRef(param)) return ret @@ -64,8 +66,7 @@ def _raise_not_supported(name): class _RemoteModule(nn.Module): def __init__( self, - on: str, - device: torch.device, + remote_device: str, module_cls: nn.Module, args: Tuple = None, kwargs: Dict[str, Any] = None, @@ -99,9 +100,11 @@ def __init__( ``def forward(input: Tensor) -> Tensor:`` and ``def forward_async(input: Tensor) -> Future[Tensor]:``. - Arguments: - on (str or WorkerInfo): id or name of the destination worker. - device (torch.device): Device on the destination worker where we‘d like to place this module. + Args: + remote_device (str): Device on the destination worker where we‘d like to place this module. + The format should be "/", where the device field can be parsed as torch.device type. + E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + In addition, the device field can be optional and the default value is "cpu". module_cls (nn.Module): For example, >>> class MyModule(nn.Module): >>> def forward(input): @@ -132,7 +135,7 @@ def __init__( >>> >>> rpc.init_rpc("worker0", rank=0, world_size=2) >>> remote_linear_module = RemoteModule( - >>> "worker1", "cpu", nn.Linear, args=(20, 30), + >>> "worker1/cpu", nn.Linear, args=(20, 30), >>> ) >>> input = torch.randn(128, 20) >>> ret_fut = remote_linear_module.forward_async(input) @@ -155,18 +158,22 @@ def __init__( args = args if args is not None else () kwargs = kwargs if kwargs is not None else {} - self.on = on + self.on, self.device = _parse_remote_device(remote_device) if _module_interface_cls is not None: # Users reply on this field to know if this generated RemoteModule is TorchScript-able. self.is_scriptable = True # Instantiate template on remote side. - fut = rpc.rpc_async(on, _instantiate_template, (_module_interface_cls,)) + fut = rpc.rpc_async( + self.on, _instantiate_template, (_module_interface_cls,) + ) # Instantiate template on local side. - generated_module = instantiator.instantiate_scriptable_remote_module_template( - _module_interface_cls + generated_module = ( + instantiator.instantiate_scriptable_remote_module_template( + _module_interface_cls + ) ) generated_methods = generated_module._generated_methods @@ -178,9 +185,9 @@ def __init__( # Create the module on the remote side. self.module_rref = rpc.rpc_sync( - on, + self.on, _create_module, - (module_cls, args, kwargs, device, _module_interface_cls), + (module_cls, args, kwargs, self.device, _module_interface_cls), ) # Install generated methods. @@ -202,6 +209,10 @@ def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]: """ return rpc.rpc_sync(self.on, _param_rrefs, args=(self.module_rref, recurse)) + def get_module_rref(self) -> rpc.RRef[nn.Module]: + """Returns the RRef to remote module.""" + return self.module_rref + def register_buffer( self, name: str, tensor: Optional[Tensor], persistent: bool = True ) -> None: @@ -210,45 +221,45 @@ def register_buffer( def register_parameter(self, name: str, param: Optional[Parameter]) -> None: _raise_not_supported(self.register_parameter.__name__) - def add_module(self, name: str, module: Optional["Module"]) -> None: + def add_module(self, name: str, module: Optional[Module]) -> None: _raise_not_supported(self.add_module.__name__) - def apply(self: T, fn: Callable[["Module"], None]) -> T: + def apply(self: T, fn: Callable[[Module], None]) -> T: # type: ignore[return] _raise_not_supported(self.apply.__name__) - def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: + def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return] _raise_not_supported(self.cuda.__name__) - def cpu(self: T) -> T: + def cpu(self: T) -> T: # type: ignore[return] _raise_not_supported(self.cpu.__name__) - def type(self: T, dst_type: Union[dtype, str]) -> T: + def type(self: T, dst_type: Union[dtype, str]) -> T: # type: ignore[return] _raise_not_supported(self.type.__name__) - def float(self: T) -> T: + def float(self: T) -> T: # type: ignore[return] _raise_not_supported(self.float.__name__) - def double(self: T) -> T: + def double(self: T) -> T: # type: ignore[return] _raise_not_supported(self.double.__name__) - def half(self: T) -> T: + def half(self: T) -> T: # type: ignore[return] _raise_not_supported(self.half.__name__) - def bfloat16(self: T) -> T: + def bfloat16(self: T) -> T: # type: ignore[return] _raise_not_supported(self.bfloat16.__name__) - def to(self, *args, **kwargs): + def to(self, *args, **kwargs) -> T: # type: ignore[return] _raise_not_supported(self.to.__name__) - def register_backward_hook( - self, hook: Callable[["Module", _grad_t, _grad_t], Union[None, Tensor]] + def register_backward_hook( # type: ignore[return] + self, hook: Callable[[Module, _grad_t, _grad_t], Union[None, Tensor]] ) -> RemovableHandle: _raise_not_supported(self.register_backward_hook.__name__) - def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle: + def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle: # type: ignore[return] _raise_not_supported(self.register_forward_pre_hook.__name__) - def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle: + def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle: # type: ignore[return] _raise_not_supported(self.register_forward_hook.__name__) def state_dict(self, destination=None, prefix="", keep_vars=False): @@ -266,47 +277,47 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]: "Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead." ) - def named_parameters( + def named_parameters( # type: ignore[return] self, prefix: str = "", recurse: bool = True - ) -> Iterator[Tuple[str, Tensor]]: + ) -> Iterator[Tuple[str, Parameter]]: _raise_not_supported(self.named_parameters.__name__) - def buffers(self, recurse: bool = True) -> Iterator[Tensor]: + def buffers(self, recurse: bool = True) -> Iterator[Tensor]: # type: ignore[return] _raise_not_supported(self.buffers.__name__) - def named_buffers( + def named_buffers( # type: ignore[return] self, prefix: str = "", recurse: bool = True ) -> Iterator[Tuple[str, Tensor]]: _raise_not_supported(self.named_buffers.__name__) - def children(self) -> Iterator["Module"]: + def children(self) -> Iterator[Module]: # type: ignore[return] _raise_not_supported(self.children.__name__) - def named_children(self) -> Iterator[Tuple[str, "Module"]]: + def named_children(self) -> Iterator[Tuple[str, Module]]: # type: ignore[return] _raise_not_supported(self.named_children.__name__) - def modules(self) -> Iterator["Module"]: + def modules(self) -> Iterator[Module]: # type: ignore[return] _raise_not_supported(self.modules.__name__) - def named_modules(self, memo: Optional[Set["Module"]] = None, prefix: str = ""): + def named_modules(self, memo: Optional[Set[Module]] = None, prefix: str = ""): _raise_not_supported(self.named_modules.__name__) - def train(self: T, mode: bool = True) -> T: + def train(self: T, mode: bool = True) -> T: # type: ignore[return] _raise_not_supported(self.train.__name__) - def eval(self: T) -> T: + def eval(self: T) -> T: # type: ignore[return] _raise_not_supported(self.eval.__name__) - def requires_grad_(self: T, requires_grad: bool = True) -> T: + def requires_grad_(self: T, requires_grad: bool = True) -> T: # type: ignore[return] _raise_not_supported(self.requires_grad_.__name__) - def zero_grad(self) -> None: + def zero_grad(self, set_to_none: bool = False) -> None: _raise_not_supported(self.zero_grad.__name__) - def share_memory(self: T) -> T: + def share_memory(self: T) -> T: # type: ignore[return] _raise_not_supported(self.share_memory.__name__) - def extra_repr(self) -> str: + def extra_repr(self) -> str: # type: ignore[return] _raise_not_supported(self.extra_repr.__name__) @@ -328,9 +339,11 @@ class RemoteModule(_RemoteModule): ``def forward(input: Tensor) -> Tensor:`` and ``def forward_async(input: Tensor) -> Future[Tensor]:``. - Arguments: - to (str or WorkerInfo): id or name of the destination worker. - device (torch.device): Device on the destination worker where we‘d like to place this module. + Args: + remote_device (str): Device on the destination worker where we‘d like to place this module. + The format should be "/", where the device field can be parsed as torch.device type. + E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + In addition, the device field can be optional and the default value is "cpu". module_cls (nn.Module): For example, >>> class MyModule(nn.Module): >>> def forward(input): @@ -357,7 +370,7 @@ class RemoteModule(_RemoteModule): >>> >>> rpc.init_rpc("worker0", rank=0, world_size=2) >>> remote_linear_module = RemoteModule( - >>> "worker1", nn.Linear, args=(20, 30), + >>> "worker1/cpu", nn.Linear, args=(20, 30), >>> ) >>> input = torch.randn(128, 20) >>> ret_fut = remote_linear_module.forward_async(input) @@ -374,10 +387,9 @@ class RemoteModule(_RemoteModule): def __init__( self, - on: str, - device: torch.device, + remote_device: str, module_cls: nn.Module, args: Tuple = None, kwargs: Dict[str, Any] = None, ): - super().__init__(on, device, module_cls, args, kwargs) + super().__init__(remote_device, module_cls, args, kwargs) diff --git a/torch/distributed/nn/jit/instantiator.py b/torch/distributed/nn/jit/instantiator.py index 346984c90dad3..950343f093303 100644 --- a/torch/distributed/nn/jit/instantiator.py +++ b/torch/distributed/nn/jit/instantiator.py @@ -6,6 +6,7 @@ import tempfile import torch +from typing import Optional from torch.distributed.nn.jit.templates.remote_module_template import ( REMOTE_MODULE_TEMPLATE, ) @@ -37,11 +38,12 @@ def get_arg_return_types_from_interface(module_interface): arg_str_list = [] arg_type_str_list = [] + assert method_schema is not None for argument in method_schema.arguments: arg_str_list.append(argument.name) if argument.has_default_value(): - default_value_str = " = {}".format(argument.default) + default_value_str = " = {}".format(argument.default_value) else: default_value_str = "" arg_type_str = "{name}: {type}{default_value}".format( @@ -63,6 +65,7 @@ def get_arg_return_types_from_interface(module_interface): def _write(out_path, text): + old_text: Optional[str] try: with open(out_path, "r") as f: old_text = f.read() diff --git a/torch/distributed/optim/functional_adagrad.py b/torch/distributed/optim/functional_adagrad.py new file mode 100644 index 0000000000000..cafce79a8c8ef --- /dev/null +++ b/torch/distributed/optim/functional_adagrad.py @@ -0,0 +1,90 @@ +from typing import List, Dict, Optional +import torch +import torch.optim.functional as F + +from torch import Tensor + +# Define a TorchScript compatible Functional Adagrad Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly let the user pass gradients to the `step` function +# this is so that we could separate the gradients and parameters +# and allow multithreaded trainer to update the parameters +# without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdagrad(object): + def __init__( + self, + params: List[Tensor], + lr: float = 1e-2, + lr_decay: float = 0.0, + weight_decay: float = 0.0, + initial_accumulator_value: float = 0.0, + warmup_lr_multiplier: float = 1.0, + warmup_num_iters: float = 0.0, + eps: float = 1e-10, + coalesce_grad: bool = True, + ): + self.defaults = { + "lr": lr, + "lr_decay": lr_decay, + "eps": eps, + "weight_decay": weight_decay, + "initial_accumulator_value": initial_accumulator_value, + "warmup_lr_multiplier": warmup_lr_multiplier, + "warmup_num_iters": warmup_num_iters, + } + self.coalesce_grad = coalesce_grad + self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {}) + + if len(params) == 0: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + # TODO: no union or any types in TorchScript, make step a scalar tensor instead + # This is also needed by if we want to share_memory on the step across processes + for p in self.param_group["params"]: + self.state[p] = { + "sum": torch.full_like(p.data, initial_accumulator_value), + "step": torch.tensor(0.0), + } + + def step(self, gradients: List[Optional[Tensor]]): + params = self.param_group['params'] + params_with_grad = [] + grads = [] + state_sums = [] + state_steps: List[int] = [] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + for param, gradient in zip(self.param_group['params'], gradients): + if gradient is not None: + params_with_grad.append(param) + grads.append(gradient) + state = self.state[param] + state_sums.append(state['sum']) + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step'].item()) + + with torch.no_grad(): + F.adagrad(params, + grads, + state_sums, + state_steps, + self.defaults['lr'], + self.defaults['weight_decay'], + self.defaults['lr_decay'], + self.defaults['eps']) diff --git a/torch/distributed/optim/optimizer.py b/torch/distributed/optim/optimizer.py index 153d229e3bf36..c7f8e3236776b 100644 --- a/torch/distributed/optim/optimizer.py +++ b/torch/distributed/optim/optimizer.py @@ -1,11 +1,63 @@ +from typing import List, Optional + import torch.distributed.rpc as rpc +import torch.optim as optim +import torch.jit as jit +import torch.nn as nn +from torch import Tensor +from torch.distributed.rpc import RRef +from .functional_adagrad import _FunctionalAdagrad import torch.distributed.autograd as dist_autograd + from collections import defaultdict from threading import Lock -class _LocalOptimizer: +# XXX: we define a _ScriptModuleOptimizer here to explicitly +# compile the FunctionalOptimizer class into TorchScript +# This is because ScriptClass instance still lives in +# python unless you explictly compile it as an attribute +# in ScriptModule or pass it to a ScriptFunction +# _ScriptLocalOptimizerInterface serves as a common +# interface type for Optimizer ScriptModules. +# +# TODO (wanchaol): remove this once we added TorchScript +# class reference semantics +@jit.interface +class _ScriptLocalOptimizerInterface(object): + def step(self, autograd_ctx_id: int) -> None: + pass + +class _ScriptLocalOptimizer(nn.Module): + # TorchScript does not support multithread concurrent compiling. + # request_callback might invoke concurrent compiling, so we + # serialize the compiling with a lock + compile_lock = Lock() + + def __init__(self, optim_cls, local_params_rref, *args, **kwargs): + super().__init__() + self._local_params = [rref.local_value() for rref in local_params_rref] + self.optim = optim_cls( + self._local_params, + *args, + **kwargs) + + @jit.export + def step(self, autograd_ctx_id: int): + all_local_grads = dist_autograd.get_gradients(autograd_ctx_id) + # apply functional optimizer step with a list of gradients + grads: List[Optional[Tensor]] = [ + all_local_grads[p] if p in all_local_grads else None + for p in self._local_params + ] + + self.optim.step(grads) + + +# TODO (wanchaol): remove/merge this with ScriptLocalOptimizer once +# we have converted all to functional optimizer in distributed.optim +class _LocalOptimizer(object): # Ideally we would only need to share a lock for instances of # _LocalOptimizer that deal with the same parameters. We are # making a simplifying assumption here that if there is more @@ -16,8 +68,9 @@ class _LocalOptimizer: global_lock = Lock() def __init__(self, optim_cls, local_params_rref, *args, **kwargs): + self._local_params = [rref.local_value() for rref in local_params_rref] self.optim = optim_cls( - [rref.local_value() for rref in local_params_rref], + self._local_params, *args, **kwargs) @@ -40,6 +93,23 @@ def _local_optimizer_step(local_optim_rref, autograd_ctx_id): local_optim.step(autograd_ctx_id) +# new/step functions combined with _ScriptLocalOptimizer to provide GIL-free optimizer +def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs): + optim = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs) + + with _ScriptLocalOptimizer.compile_lock: + script_optim = jit.script(optim) + return rpc.RRef( + script_optim, _ScriptLocalOptimizerInterface) + +@jit.script +def _script_local_optimizer_step( + local_optim_rref: RRef[_ScriptLocalOptimizerInterface], + autograd_ctx_id: int +) -> None: + local_optim = local_optim_rref.local_value() + local_optim.step(autograd_ctx_id) + def _wait_for_all(rpc_futs): # TODO: improve error propagation exception = None @@ -104,17 +174,34 @@ class DistributedOptimizer: >>> ) >>> dist_optim.step(context_id) """ + + # dict to map a user passed in optimizer_class to a functional + # optimizer class if we have already defined inside the + # distributed.optim package, this is so that we hide the + # functional optimizer to user and still provide the same API. + functional_optim_map = { + optim.Adagrad: _FunctionalAdagrad, + } + def __init__(self, optimizer_class, params_rref, *args, **kwargs): per_worker_params_rref = defaultdict(list) for param in params_rref: per_worker_params_rref[param.owner()].append(param) + optim_ctor = DistributedOptimizer.functional_optim_map.get(optimizer_class, optimizer_class) + self.is_functional_optim = (optim_ctor != optimizer_class) + + if self.is_functional_optim: + optimizer_new_func = _new_script_local_optimizer + else: + optimizer_new_func = _new_local_optimizer + remote_optim_futs = [] for worker, param_rrefs in per_worker_params_rref.items(): remote_optim_rref_fut = rpc.rpc_async( worker, - _new_local_optimizer, - args=(optimizer_class, param_rrefs) + args, + optimizer_new_func, + args=(optim_ctor, param_rrefs) + args, kwargs=kwargs, ) remote_optim_futs.append(remote_optim_rref_fut) @@ -136,11 +223,17 @@ def step(self, context_id): optimizer step. """ dist_autograd._is_valid_context(context_id) + + if self.is_functional_optim: + optimizer_step_func = _script_local_optimizer_step + else: + optimizer_step_func = _local_optimizer_step + rpc_futs = [] - for optim in self.remote_optimizers: + for optimizer in self.remote_optimizers: rpc_futs.append(rpc.rpc_async( - optim.owner(), - _local_optimizer_step, - args=(optim, context_id), + optimizer.owner(), + optimizer_step_func, + args=(optimizer, context_id), )) _wait_for_all(rpc_futs) diff --git a/torch/distributed/pipeline/__init__.py b/torch/distributed/pipeline/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/distributed/pipeline/sync/LICENSE b/torch/distributed/pipeline/sync/LICENSE new file mode 100644 index 0000000000000..e52be240fdc98 --- /dev/null +++ b/torch/distributed/pipeline/sync/LICENSE @@ -0,0 +1,27 @@ +Copyright 2019-2020 Kakao Brain + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/torch/distributed/pipeline/sync/__init__.py b/torch/distributed/pipeline/sync/__init__.py new file mode 100644 index 0000000000000..ca3c2a8823ade --- /dev/null +++ b/torch/distributed/pipeline/sync/__init__.py @@ -0,0 +1,11 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""A Pipe implementation in PyTorch.""" +from .checkpoint import is_checkpointing, is_recomputing +from .pipe import Pipe + +__all__ = ["Pipe", "is_checkpointing", "is_recomputing"] diff --git a/torch/distributed/pipeline/sync/_balance/__init__.py b/torch/distributed/pipeline/sync/_balance/__init__.py new file mode 100644 index 0000000000000..a177ad6bd0226 --- /dev/null +++ b/torch/distributed/pipeline/sync/_balance/__init__.py @@ -0,0 +1,164 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""A helper to roughly balance a sequential module. + +Usage:: + + import torch + from torch.distributed.pipeline.sync import Pipe + from torch.distributed.pipeline.sync.balance import balance_by_time + + sample = torch.empty(128, 3, 224, 224) + balance = balance_by_time(torch.cuda.device_count(), model, sample) + + pipe = Pipe(model, balance, chunks=8) + +""" +from typing import List, Union, Sequence + +import torch +from torch import Tensor +import torch.nn as nn + +from . import blockpartition +from .profile import profile_sizes, profile_times + +__all__ = ["balance_by_time", "balance_by_size"] + + +Device = Union[torch.device, int, str] + +Tensors = Sequence[Tensor] +TensorOrTensors = Union[Tensor, Tensors] + + +def balance_cost(cost: List[int], partitions: int) -> List[int]: + partitioned = blockpartition.solve(cost, partitions) + return [len(p) for p in partitioned] + + +def balance_by_time( + partitions: int, + module: nn.Sequential, + sample: TensorOrTensors, + *, + timeout: float = 1.0, + device: Device = torch.device("cuda"), +) -> List[int]: + """Naive automatic balancing by elapsed time per layer. + :: + + sample = torch.empty(128, 3, 224, 224) + balance = balance_by_time(torch.cuda.device_count(), model, sample) + pipe = Pipe(model, balance, chunks=8) + + Args: + partitions (int): + intended number of partitions + module (torch.nn.Sequential): + sequential module to be partitioned + sample (torch.Tensor): + example input with arbitrary batch size + + Keyword Args: + timeout (float): + profiling iterates again if the timeout (in second) is not exceeded + (default: ``1.0``) + device ('cpu' or 'cuda' device): + CPU or CUDA device where each layer is profiled (default: the + current CUDA device) + + Returns: + A list of number of layers in each partition. Use it for the `balance` + parameter of :class:`~torchpipe.Pipe`. + + .. note:: + `module` and `sample` must be placed on the same device. + + """ + times = profile_times(module, sample, timeout, torch.device(device)) + return balance_cost(times, partitions) + + +def balance_by_size( + partitions: int, + module: nn.Sequential, + input: TensorOrTensors, + *, + chunks: int = 1, + param_scale: float = 2.0, + device: Device = torch.device("cuda"), +) -> List[int]: + """Naive automatic balancing by CUDA memory usage per layer. + + During training, required memory for parameters depends on which optimizer + is used. Optimizers may use buffers for each parameter to track + optimization statistics internally, such as momentum buffer in SGD. + + To get more reliable size based balance, you should specify `param_scale` + with regard to your optimizer. The default `param_scale` is 2 instead of 1 + due to gradient accumulation which is necessary for every optimizer. + + Follow this guide to choose correct `param_scale` for typical optimizers: + + ========= ============= ========================================= + Optimizer `param_scale` Internal State + ========= ============= ========================================= + SGD 2--3 (momentum_buffer) + Adam 4--5 exp_avg, exp_avg_sq, (max_exp_avg_sq) + Adadelta 4 square_avg, acc_delta + Adagrad 3 sum + RMSprop 3--5 square_avg, (momentum_buffer), (grad_avg) + ========= ============= ========================================= + + Here's a simple example with the Adam optimizer:: + + balance = balance_by_size( + torch.cuda.device_count(), + model, + + # Same size with mini-batch to train + torch.empty(1024, 3, 224, 224), + + # Number of micro-batches to train with Pipe + chunks=8, + + # 4 for Adam + param_scale=4.0, + ) + + pipe = Pipe(model, balance, chunks=8) + adam = Adam(pipe.parameters()) + + Args: + partitions (int): + intended number of partitions + module (torch.nn.Sequential): + sequential module to be partitioned + input (torch.Tensor): + example mini-batch with the same size to train + + Keyword Args: + chunks (int): + number of micro-batches will be used to train (default: ``1``) + param_scale (float): + how many copies of parameters would be allocated for training. It + depends on optimizer. See the above guide. (default: ``2.0``) + device ('cuda' device): + CUDA device where each layer is profiled (default: the current CUDA + device) + + Returns: + A list of number of layers in each partition. Use it for the `balance` + parameter of :class:`~torchpipe.Pipe`. + + .. note:: + `module` and `input` must be placed on the same CUDA device. + + """ + sizes = profile_sizes(module, input, chunks, param_scale, torch.device(device)) + return balance_cost(sizes, partitions) diff --git a/torch/distributed/pipeline/sync/_balance/blockpartition.py b/torch/distributed/pipeline/sync/_balance/blockpartition.py new file mode 100644 index 0000000000000..7afe782f6ac8c --- /dev/null +++ b/torch/distributed/pipeline/sync/_balance/blockpartition.py @@ -0,0 +1,95 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Implements "Block Partitions of Sequences" by Imre Bárány et al. + +Paper: https://arxiv.org/pdf/1308.2452.pdf + +""" +from typing import Iterator, List, Tuple + +__all__ = ["solve"] + + +def solve(sequence: List[int], partitions: int = 1) -> List[List[int]]: + """Splits a sequence into several partitions to minimize variance for each + partition. + + The result might not be optimal. However, it can be done only in O(kn³), + where k is the number of partitions and n is the length of the sequence. + + """ + if partitions < 1: + raise ValueError(f"partitions must be a positive integer ({partitions} < 1)") + + n = len(sequence) + if n < partitions: + raise ValueError(f"sequence is shorter than intended partitions ({n} < {partitions})") + + # Normalize the sequence in [0, 1]. + minimum = min(sequence) + maximum = max(sequence) - minimum + + normal_sequence: List[float] + if maximum == 0: + normal_sequence = [0 for _ in sequence] + else: + normal_sequence = [(x - minimum) / maximum for x in sequence] + + splits = [n // partitions * (x + 1) for x in range(partitions - 1)] + [n] + + def block_size(i: int) -> float: + start = splits[i - 1] if i > 0 else 0 + stop = splits[i] + return sum(normal_sequence[start:stop]) + + def leaderboard() -> Iterator[Tuple[float, int]]: + return ((block_size(i), i) for i in range(partitions)) + + while True: + """ + (1) Fix p ∈ [k] with M(P) = bp. So Bp is a maximal block of P. + """ + # max_size: M(P) + max_size, p = max(leaderboard()) + + while True: + """ + (2) If M(P) ≤ m(P) + 1, then stop. + """ + # min_size: m(P) + min_size, q = min(leaderboard()) + + if max_size <= min_size + 1: + return [sequence[i:j] for i, j in zip([0] + splits[:-1], splits)] + + """ + (3) If M(P) > m(P) + 1, then let m(P) = bq for the q ∈ [k] which is + closest to p (ties broken arbitrarily). Thus Bq is a minimal block + of P. Let Bh be the block next to Bq between Bp and Bq. (Note that + Bh is a non-empty block: if it were, then m(P) = 0 and we should + have chosen Bh instead of Bq.) + """ + if p < q: + """ + So either p < q and then h = q−1 and we define P ∗ by moving + the last element from Bh = Bq−1 to Bq, + """ + h = q - 1 + splits[h] -= 1 + else: + """ + or q < p, and then h = q + 1 and P ∗ is obtained by moving the + first element of Bh = Bq+1 to Bq. + """ + h = q + 1 + splits[q] += 1 + + """ + Set P = P ∗ . If p = h, then go to (1), else go to (2). + """ + if p == h: + break diff --git a/torch/distributed/pipeline/sync/_balance/profile.py b/torch/distributed/pipeline/sync/_balance/profile.py new file mode 100644 index 0000000000000..382da988e808f --- /dev/null +++ b/torch/distributed/pipeline/sync/_balance/profile.py @@ -0,0 +1,114 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Per-layer profilers.""" +import copy +import time +from typing import Generator, List, Union, Sequence + +import torch +from torch import Tensor +import torch.nn as nn + +from ..microbatch import Batch + +__all__: List[str] = [] + + +Device = Union[torch.device, int, str] + +Tensors = Sequence[Tensor] +TensorOrTensors = Union[Tensor, Tensors] + + +def layerwise_sandbox(module: nn.Sequential, device: torch.device,) -> Generator[nn.Module, None, None]: + """Copies layers for ease to profile. It doesn't modify the given + module. + """ + for layer in module: + layer_copy = copy.deepcopy(layer) + layer_copy.to(device) + layer_copy.train() + yield layer_copy + + +def detach(batch: Batch) -> None: + """Detaches from autograd graph.""" + for i, x in enumerate(batch): + batch[i] = x.detach().requires_grad_(x.requires_grad) + + +def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float, device: torch.device,) -> List[int]: + """Profiles elapsed times per layer.""" + if any(p.grad is not None for p in module.parameters()): + raise ValueError("some parameter already has gradient") + + _batch = Batch(sample) + for i, x in enumerate(_batch): + _batch[i] = x.detach().to(device).requires_grad_(x.requires_grad) + + time_bufs: List[List[float]] = [[] for _ in module] + begun_at = time.time() + + while time.time() - begun_at < timeout: + batch = _batch + + for i, layer in enumerate(layerwise_sandbox(module, device)): + detach(batch) + + if device.type == "cuda": + torch.cuda.synchronize(device) + tick = time.time() + + # Forward + batch = batch.call(layer) + + # Backward + backward_tensors = tuple(y for y in batch if y.requires_grad) + if backward_tensors: + torch.autograd.backward(backward_tensors, backward_tensors) + + if device.type == "cuda": + torch.cuda.synchronize(device) + tock = time.time() + + time_bufs[i].append(tock - tick) + + us = 1_000_000 + return [sum(int(t * us) for t in buf) for buf in time_bufs] + + +def profile_sizes( + module: nn.Sequential, input: TensorOrTensors, chunks: int, param_scale: float, device: torch.device, +) -> List[int]: + """Profiles CUDA memory usage per layer.""" + if device.type != "cuda": + raise ValueError("size profiler supports only CUDA device") + + batch = Batch(input) + sizes: List[int] = [] + + latent_scale = batch[0].size(0) / chunks + for i, x in enumerate(batch): + batch[i] = x[:1].detach().to(device).requires_grad_(x.requires_grad) + + for layer in layerwise_sandbox(module, device): + detach(batch) + + # Detect memory usage at forward. + memory_before = torch.cuda.memory_allocated(device) + batch = batch.call(layer) + memory_after = torch.cuda.memory_allocated(device) + latent_size = memory_after - memory_before + + # Analyze size of parameters. + param_size = sum(p.storage().size() * p.storage().element_size() for p in layer.parameters()) + + # Combine size of parameters and activations with normalize scales. + size = latent_size * latent_scale + param_size * param_scale + sizes.append(int(size)) + + return sizes diff --git a/torch/distributed/pipeline/sync/_balance/py.typed b/torch/distributed/pipeline/sync/_balance/py.typed new file mode 100644 index 0000000000000..ab03724cafbf5 --- /dev/null +++ b/torch/distributed/pipeline/sync/_balance/py.typed @@ -0,0 +1,6 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torch/distributed/pipeline/sync/batchnorm.py b/torch/distributed/pipeline/sync/batchnorm.py new file mode 100644 index 0000000000000..983f10f3ff369 --- /dev/null +++ b/torch/distributed/pipeline/sync/batchnorm.py @@ -0,0 +1,163 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Tracks the running statistics per mini-batch instead of micro-batch.""" +from typing import Optional, TypeVar, cast + +import torch +from torch import Tensor, nn +from torch.nn.functional import batch_norm +from torch.nn.modules.batchnorm import _BatchNorm + +from .checkpoint import is_recomputing + +__all__ = ["DeferredBatchNorm"] + + +TModule = TypeVar("TModule", bound=nn.Module) + + +class DeferredBatchNorm(_BatchNorm): + """A BatchNorm layer tracks multiple micro-batches to update running + statistics per mini-batch. + """ + + sum: Tensor + sum_squares: Tensor + running_mean: Tensor + running_var: Tensor + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: Optional[float] = 0.1, + affine: bool = True, + chunks: int = 1, + ) -> None: + super().__init__(num_features, eps, momentum, affine, track_running_stats=True) + + self.register_buffer("sum", torch.zeros_like(self.running_mean)) + self.register_buffer("sum_squares", torch.zeros_like(self.running_var)) + + self.counter = 0 + self.tracked = 0 + self.chunks = chunks + + def _check_input_dim(self, input: Tensor) -> None: + # It's the typical _check_input_dim() implementation in PyTorch. + if input.dim() <= 2: + raise ValueError("expected at least 3D input (got %dD input)" % input.dim()) + + def _track(self, input: Tensor) -> bool: + """Tracks statistics of a micro-batch.""" + # Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d. + dim = [0] + dim.extend(range(2, input.dim())) + + with torch.no_grad(): + self.sum += input.sum(dim) + self.sum_squares += (input ** 2).sum(dim) + + size = input.size().numel() // input.size(1) + self.counter += size + self.tracked += 1 + + return self.tracked == self.chunks + + def _commit(self) -> None: + """Updates the running statistics of a mini-batch.""" + exponential_average_factor = 0.0 + self.num_batches_tracked += 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + mean = self.sum / self.counter + var = self.sum_squares / self.counter - mean ** 2 + + # Calculate the exponential moving average here. + m = exponential_average_factor + + self.running_mean *= 1 - m + self.running_mean += mean * m + + self.running_var *= 1 - m + self.running_var += var * m + + self.sum.zero_() + self.sum_squares.zero_() + self.counter = 0 + self.tracked = 0 + + def forward(self, input: Tensor) -> Tensor: # type: ignore + if not self.training: + # Don't train parameters on the evaluation mode. + return batch_norm( + input, + running_mean=self.running_mean, + running_var=self.running_var, + weight=self.weight, + bias=self.bias, + training=False, + momentum=0.0, + eps=self.eps, + ) + + if not is_recomputing(): + # Track a micro-batch on the training mode + # but not under a recomputation. + tracked_enough = self._track(input) + + # Update the running statistics for a mini-batch + # if it has tracked enough micro-batches. + if tracked_enough: + self._commit() + + # Normalize a micro-batch and train the parameters. + return batch_norm( + input, + running_mean=None, + running_var=None, + weight=self.weight, + bias=self.bias, + training=True, + momentum=0.0, + eps=self.eps, + ) + + @classmethod + def convert_deferred_batch_norm(cls, module: TModule, chunks: int = 1) -> TModule: + """Converts a :class:`nn.BatchNorm` or underlying + :class:`nn.BatchNorm`s into :class:`DeferredBatchNorm`:: + + from torchvision.models.resnet import resnet101 + from torchpipe.batchnorm import DeferredBatchNorm + model = resnet101() + model = DeferredBatchNorm.convert_deferred_batch_norm(model) + + """ + if isinstance(module, DeferredBatchNorm) and module.chunks is chunks: + return cast(TModule, module) + + module_output: nn.Module = module + + if isinstance(module, _BatchNorm) and module.track_running_stats: + module_output = DeferredBatchNorm(module.num_features, module.eps, module.momentum, module.affine, chunks) + if module.affine: + module_output.register_parameter("weight", module.weight) + module_output.register_parameter("bias", module.bias) + assert isinstance(module.running_mean, Tensor) + assert isinstance(module.running_var, Tensor) + module_output.register_buffer("running_mean", module.running_mean) + module_output.register_buffer("running_var", module.running_var) + module_output.register_buffer("num_batches_tracked", module.num_batches_tracked) + + for name, child in module.named_children(): + module_output.add_module(name, cls.convert_deferred_batch_norm(child, chunks)) + + return cast(TModule, module_output) diff --git a/torch/distributed/pipeline/sync/checkpoint.py b/torch/distributed/pipeline/sync/checkpoint.py new file mode 100644 index 0000000000000..3f9240793183f --- /dev/null +++ b/torch/distributed/pipeline/sync/checkpoint.py @@ -0,0 +1,326 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Checkpointing with preceding recomputation. + +PyTorch already provides the official checkpointing utilities in +:mod:`torch.utils.checkpoint`. The official checkpointing combines +recomputation and recursive backpropagation into one autograd function named +``CheckpointFunction``. Hence, the recomputation can be started only when the +gradients arrive to the function. In Pipe, the recomputation needs to precede +the gradient arrival to minimize the GPU idle time. + +We solve this problem by introducing separate autograd functions named +:class:`Recompute` and :class:`Checkpoint`. Each function represents +recomputation and recursive backpropagation, respectively. We can manipulate +the control flow in aspect of both the autograd engine and CUDA with a pair of +the functions. + +Specifically, we place CUDA stream synchronization between :class:`Recompute` +and :class:`Checkpoint` to delay only :class:`Checkpoint` until the gradient is +copied entirely. + +""" +from collections import deque +from contextlib import contextmanager +import threading +from typing import ( + TYPE_CHECKING, + Deque, + Generator, + List, + Optional, + Union, + Sequence, + Tuple +) + +import torch +from torch import Tensor +import torch.autograd + +from .dependency import fork, join +from .microbatch import Batch +from .phony import get_phony + +__all__ = ["is_checkpointing", "is_recomputing"] + + +Tensors = Sequence[Tensor] +TensorOrTensors = Union[Tensor, Tensors] + +# Types for shared memory between Checkpoint and Recompute. +Recomputed = Tuple[TensorOrTensors, Tensors] # (output, input_leaf) +RNGStates = Tuple[Tensor, Optional[Tensor]] # (cpu_rng_state, gpu_rng_state) + + +if TYPE_CHECKING: + from typing_extensions import Protocol +else: + Protocol = object + + +# Protocol with __call__ instead of Callable can be used as an attribute type. +# See: https://github.com/python/mypy/issues/708#issuecomment-561735949 +class Function(Protocol): + def __call__(self, input: TensorOrTensors) -> TensorOrTensors: + ... + + +def checkpoint(function: Function, input: TensorOrTensors) -> TensorOrTensors: + """Makes a checkpoint with a simple interface like + :func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug + :class:`Checkpoint` and :class:`Recompute` without boilerplate. + """ + batch = Batch(input) + + chk = Checkpointing(function, batch) + batch = chk.checkpoint() + chk.recompute(batch) + + return batch.tensor_or_tensors + + +class Checkpointing: + """Generates a pair of :class:`Checkpoint` and :class:`Recompute`.""" + + def __init__(self, function: Function, batch: Batch) -> None: + self.function = function + self.batch = batch + + # Shared memory between Checkpoint and Recompute. 1-length deque is + # used for mutability and length limitation. + self.recomputed: Deque[Recomputed] = deque(maxlen=1) + self.rng_states: Deque[RNGStates] = deque(maxlen=1) + + def checkpoint(self) -> Batch: + """Returns a batch applied by :class:`Checkpoint`.""" + input_atomic = self.batch.atomic + input = tuple(self.batch) + + # Use a phony which requires grad to ensure that Checkpoint can be + # tracked by the autograd engine even when none of the input tensors + # require grad. + phony = get_phony(self.batch[0].device, requires_grad=True) + + output = Checkpoint.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *input) + + # Gradients are only supported for float Tensors. + if isinstance(output, tuple): + output = tuple([x if x.is_floating_point() else x.detach() for x in output]) + + return Batch(output) + + def recompute(self, batch: Batch) -> None: + """Applies :class:`Recompute` to the batch in place.""" + input_atomic = self.batch.atomic + input = tuple(self.batch) + + # batch[0] is always requiring grad, because it has been passed + # checkpoint with a phony requiring grad. + batch[0], phony = fork(batch[0]) + phony = Recompute.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *input) + batch[0] = join(batch[0], phony) + + +class ThreadLocal(threading.local): + def __init__(self) -> None: + self.is_checkpointing = False + self.is_recomputing = False + + +thread_local = ThreadLocal() + + +@contextmanager +def enable_checkpointing() -> Generator[None, None, None]: + """Makes :func:`is_checkpointing` return :data:`True` within a context.""" + orig = thread_local.is_checkpointing + thread_local.is_checkpointing = True + try: + yield + finally: + thread_local.is_checkpointing = orig + + +@contextmanager +def enable_recomputing() -> Generator[None, None, None]: + """Makes :func:`is_recomputing` return :data:`True` within a context.""" + orig = thread_local.is_recomputing + thread_local.is_recomputing = True + try: + yield + finally: + thread_local.is_recomputing = orig + + +def is_checkpointing() -> bool: + """Whether the current forward propagation is under checkpointing. + + Returns: + bool: :data:`True` if it's under checkpointing. + + """ + return thread_local.is_checkpointing + + +def is_recomputing() -> bool: + """Whether the current forward propagation is under checkpoint + recomputation. Use this to prevent duplicated side-effects at forward + propagation:: + + class Counter(nn.Module): + def __init__(self): + super().__init__() + self.counter = 0 + + def forward(self, input): + if not is_recomputing(): + self.counter += 1 + return input + + Returns: + bool: :data:`True` if it's under checkpoint recomputation. + + .. seealso:: :ref:`Detecting Recomputation` + + """ + return thread_local.is_recomputing + + +class Context: + """The common interface between the :class:`Checkpoint` and + :class:`Recompute` context. + """ + + recomputed: Deque[Recomputed] + rng_states: Deque[RNGStates] + function: Function + input_atomic: bool + + saved_tensors: Tuple[Tensor, ...] + + def save_for_backward(self, *tensors: Tensor) -> None: # pragma: no cover + pass + + +def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None: + """:meth:`Checkpoint.forward` captures the current PyTorch's random number + generator states at CPU and GPU to reuse in :meth:`Recompute.backward`. + + .. seealso:: :ref:`Referential Transparency` + + """ + cpu_rng_state = torch.get_rng_state() + + gpu_rng_state: Optional[Tensor] + if device.type == "cuda": + gpu_rng_state = torch.cuda.get_rng_state(device) + else: + gpu_rng_state = None + + rng_states.append((cpu_rng_state, gpu_rng_state)) + + +@contextmanager +def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]: + """:meth:`Recompute.backward` restores the random number generator states + captured by :func:`save_rng_states` within its context. + + .. seealso:: :ref:`Referential Transparency` + + """ + cpu_rng_state, gpu_rng_state = rng_states.pop() + + gpu_devices: List[torch.device] = [] + if device.type == "cuda": + gpu_devices.append(device) + + with torch.random.fork_rng(gpu_devices): + torch.set_rng_state(cpu_rng_state) + if gpu_rng_state is not None: + torch.cuda.set_rng_state(gpu_rng_state, device) + yield + + +class Checkpoint(torch.autograd.Function): + @staticmethod + # type: ignore + def forward( + ctx: Context, + phony: Tensor, + recomputed: Deque[Recomputed], + rng_states: Deque[RNGStates], + function: Function, + input_atomic: bool, + *input: Tensor, + ) -> TensorOrTensors: + ctx.recomputed = recomputed + ctx.rng_states = rng_states + + save_rng_states(input[0].device, ctx.rng_states) + + ctx.function = function + ctx.input_atomic = input_atomic + ctx.save_for_backward(*input) + + with torch.no_grad(), enable_checkpointing(): + output = function(input[0] if input_atomic else input) + + return output + + @staticmethod + def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: # pragma: no cover + output, input_leaf = ctx.recomputed.pop() + + if isinstance(output, tuple): + tensors = output + else: + tensors = (output,) + if any(y.requires_grad for y in tensors): + tensors = tuple([x for x in tensors if x.requires_grad]) + torch.autograd.backward(tensors, grad_output) + + grad_input: List[Optional[Tensor]] = [None, None, None, None, None] + grad_input.extend(x.grad for x in input_leaf) + return tuple(grad_input) + + +class Recompute(torch.autograd.Function): + @staticmethod + # type: ignore + def forward( + ctx: Context, + phony: Tensor, + recomputed: Deque[Recomputed], + rng_states: Deque[RNGStates], + function: Function, + input_atomic: bool, + *input: Tensor, + ) -> Tensor: + ctx.recomputed = recomputed + ctx.rng_states = rng_states + + ctx.function = function + ctx.input_atomic = input_atomic + ctx.save_for_backward(*input) + + return phony + + @staticmethod + def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]: # pragma: no cover + input = ctx.saved_tensors + input_leaf = tuple(x.detach().requires_grad_(x.requires_grad) for x in input) + + with restore_rng_states(input[0].device, ctx.rng_states): + with torch.enable_grad(), enable_recomputing(): + output = ctx.function(input_leaf[0] if ctx.input_atomic else input_leaf) + + ctx.recomputed.append((output, input_leaf)) + + grad_input: List[None] = [None, None, None, None, None] + grad_input.extend(None for _ in ctx.saved_tensors) + return tuple(grad_input) diff --git a/torch/distributed/pipeline/sync/copy.py b/torch/distributed/pipeline/sync/copy.py new file mode 100644 index 0000000000000..07e71a87ce08a --- /dev/null +++ b/torch/distributed/pipeline/sync/copy.py @@ -0,0 +1,104 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Autograd functions for stream-aware CUDA copy. It is used to overlap copy +and computation on the same GPU. +""" +from collections import deque +from typing import Deque, List, Optional, Tuple, Sequence + +import torch +from torch import Tensor + +from .stream import AbstractStream, current_stream, get_device, record_stream, use_stream, wait_stream + +__all__: List[str] = [] + + +Tensors = Sequence[Tensor] + + +# Common interface between :class:`Copy` and :class:`Wait`. +class Context: + prev_stream: AbstractStream + next_stream: AbstractStream + + +class Copy(torch.autograd.Function): + """Copies tensors on specific streams.""" + + @staticmethod + # type: ignore + def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors: + ctx.prev_stream = prev_stream + ctx.next_stream = next_stream + + output = [] + output_stream = current_stream(get_device(next_stream)) + + with use_stream(prev_stream), use_stream(next_stream): + for x in input: + y = x.to(get_device(next_stream), non_blocking=True) + output.append(y) + + # 'prev_stream' is not where 'x' has been allocated. + record_stream(x, prev_stream) + # 'y' has been allocated on 'next_stream'. + # It might be used on the current stream captured as 'output_stream'. + record_stream(y, output_stream) + + return tuple(output) + + @staticmethod + def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: + prev_stream = ctx.prev_stream + next_stream = ctx.next_stream + + grad_input: Deque[Tensor] = deque(maxlen=len(grad_output)) + input_stream = current_stream(get_device(prev_stream)) + + with use_stream(prev_stream), use_stream(next_stream): + for x in reversed(grad_output): + y = x.to(get_device(prev_stream), non_blocking=True) + grad_input.appendleft(y) + + # 'next_stream' is not where 'x' has been allocated. + record_stream(x, next_stream) + # 'y' has been allocated on 'prev_stream'. + # It might be used on the current stream captured as 'input_stream'. + record_stream(y, input_stream) + + grad_streams: Tuple[Optional[Tensor], ...] = (None, None) + return grad_streams + tuple(grad_input) + + +class Wait(torch.autograd.Function): + """Synchronizes a stream to another stream. + + Place it just before you want to start an operation on the next stream, + provided that all operations on the previous stream are done. + + """ + + @staticmethod + # type: ignore + def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors: + ctx.prev_stream = prev_stream + ctx.next_stream = next_stream + + wait_stream(next_stream, prev_stream) + + return tuple(x.detach() for x in input) + + @staticmethod + def backward(ctx: Context, *grad_input: Tensor,) -> Tuple[Optional[Tensor], ...]: + prev_stream = ctx.prev_stream + next_stream = ctx.next_stream + + wait_stream(prev_stream, next_stream) + + grad_streams: Tuple[Optional[Tensor], ...] = (None, None) + return grad_streams + grad_input diff --git a/torch/distributed/pipeline/sync/dependency.py b/torch/distributed/pipeline/sync/dependency.py new file mode 100644 index 0000000000000..aeebc11aeeba3 --- /dev/null +++ b/torch/distributed/pipeline/sync/dependency.py @@ -0,0 +1,54 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Arbitrary dependency between two autograd lanes.""" +from typing import List, Tuple + +import torch +from torch import Tensor + +from .phony import get_phony + +__all__: List[str] = [] + + +def fork(input: Tensor) -> Tuple[Tensor, Tensor]: + """Branches out from an autograd lane of the given tensor.""" + if torch.is_grad_enabled() and input.requires_grad: + input, phony = Fork.apply(input) + else: + phony = get_phony(input.device, requires_grad=False) + + return input, phony + + +class Fork(torch.autograd.Function): + @staticmethod + def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore + phony = get_phony(input.device, requires_grad=False) + return input.detach(), phony.detach() + + @staticmethod + def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore + return grad_input + + +def join(input: Tensor, phony: Tensor) -> Tensor: + """Merges two autograd lanes.""" + if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad): + input = Join.apply(input, phony) + + return input + + +class Join(torch.autograd.Function): + @staticmethod + def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore + return input.detach() + + @staticmethod + def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore + return grad_input, None diff --git a/torch/distributed/pipeline/sync/microbatch.py b/torch/distributed/pipeline/sync/microbatch.py new file mode 100644 index 0000000000000..fc4daf7a9b426 --- /dev/null +++ b/torch/distributed/pipeline/sync/microbatch.py @@ -0,0 +1,186 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Manipulation of micro-batches.""" +import typing +from typing import Callable, Iterable, Iterator, List, Union, cast, Sequence + +import torch +from torch import Tensor +import torch.cuda.comm + +__all__: List[str] = [] + + +Tensors = Sequence[Tensor] +TensorOrTensors = Union[Tensor, Tensors] +Function = Callable[[TensorOrTensors], TensorOrTensors] + + +class Batch: + """An abstraction of an atomic tensor or a tuple of tensors. This + eliminates every boilerplate code to classify an atomic tensor or a tuple + of tensors. + :: + + x = generate_tensor_or_tensors() + x = Batch(x) + + # in-place update + x[0] = F.apply(x[0]) + x[:] = F.apply(*x) + + # f(x) if x is a tensor. + # f(*x) if x is a tuple of tensors. + # y is also a batch. + y = x.call(f) + + """ + + def __init__(self, value: TensorOrTensors) -> None: + self.value = value + self.atomic = torch.is_tensor(value) + + @property + def tensor(self) -> Tensor: + """Retrieves the underlying tensor.""" + if not self.atomic: + raise AttributeError("not atomic batch") + return cast(Tensor, self.value) + + @property + def tensors(self) -> Tensors: + """Retrieves the underlying tensors.""" + if self.atomic: + raise AttributeError("batch is atomic") + return cast(Tensors, self.value) + + @property + def tensor_or_tensors(self) -> TensorOrTensors: + """Retrieves the underlying tensor or tensors regardless of type.""" + return self.value + + def call(self, function: Function) -> "Batch": + """Calls a function by the underlying tensor or tensors. It also wraps + the output with :class:`Batch`. + """ + return Batch(function(self.value)) + + def __repr__(self) -> str: + return f"Batch[atomic={self.atomic!r}]({self.value!r})" + + def __iter__(self) -> Iterator[Tensor]: + if self.atomic: + yield self.tensor + else: + yield from self.tensors + + def __len__(self) -> int: + return 1 if self.atomic else len(self.tensors) + + def __getitem__(self, index: int) -> Tensor: + if not self.atomic: + return self.tensors[index] + + if index != 0: + raise IndexError("atomic batch allows index 0 only") + + return self.tensor + + # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload". + @typing.overload + def __setitem__(self, index: int, value: Tensor) -> None: + ... + + @typing.overload + def __setitem__(self, index: slice, value: Tensors) -> None: + ... + + def __setitem__(self, index: Union[int, slice], value: TensorOrTensors) -> None: + if isinstance(index, int): + value = cast(Tensor, value) + self._setitem_by_index(index, value) + else: + value = cast(Tensors, value) + self._setitem_by_slice(index, value) + + def _setitem_by_index(self, index: int, value: Tensor) -> None: + if not self.atomic: + i = index + self.value = self.value[:i] + (value,) + self.value[i + 1 :] # type: ignore + return + + if index != 0: + raise IndexError("atomic batch allows index 0 only") + + self.value = value + + def _setitem_by_slice(self, index: slice, value: Tensors) -> None: + if not (index.start is index.stop is index.step is None): + raise NotImplementedError("only slice [:] supported") + + if not self.atomic: + self.value = value + return + + if len(value) != 1: + raise IndexError("atomic batch cannot be replaced with multiple tensors") + + self.value = value[0] + + +def check(input: TensorOrTensors) -> None: + """Checks whether the input is a tensor or tensors. + + Raises: + TypeError: input is not a tensor or tensors. + + """ + if isinstance(input, Sequence): + for x in input: + if not isinstance(x, Tensor): + raise TypeError(f"expected Tensor, but got {input.__class__.__name__}") + return + + if not isinstance(input, Tensor): + raise TypeError(f"expected Tensor, but got {input.__class__.__name__}") + + +def scatter(input: TensorOrTensors, chunks: int) -> List[Batch]: + """Splits an input mini-batch into multiple micro-batches.""" + inputs: Iterable[TensorOrTensors] + + if isinstance(input, Tensor): + inputs = input.chunk(chunks) + else: + rotated: List[Tensors] = [] + + for tensor in input: + tensors = tensor.chunk(chunks) + rotated.append(cast(Tensors, tensors)) + + inputs = zip(*rotated) + + return [Batch(x) for x in inputs] + + +def gather(outputs: List[Batch]) -> TensorOrTensors: + """Concatenates output micro-batches into a mini-batch.""" + output: TensorOrTensors + + if outputs[0].atomic: + tensors = tuple(b.tensor for b in outputs) + output = torch.cat(tensors) + else: + rotated = [b.tensors for b in outputs] + output_buf = [] + + for tensors in zip(*rotated): + output_buf.append(torch.cat(tensors)) + + output = tuple(output_buf) + + return output diff --git a/torch/distributed/pipeline/sync/phony.py b/torch/distributed/pipeline/sync/phony.py new file mode 100644 index 0000000000000..5e89ff0efd270 --- /dev/null +++ b/torch/distributed/pipeline/sync/phony.py @@ -0,0 +1,49 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Provides phony for arbitrary dependency in a autograd graph.""" +from typing import Dict, List, Tuple + +import torch +from torch import Tensor + +from .stream import default_stream, use_stream + +__all__: List[str] = [] + + +_phonies: Dict[Tuple[torch.device, bool], Tensor] = {} + + +def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor: + """Gets a phony. Phony is tensor without space. It is useful to make + arbitrary dependency in a autograd graph because it doesn't require any + gradient accumulation. + + .. note:: + + Phonies for each device are cached. If an autograd function gets a phony + internally, the phony must be detached to be returned. Otherwise, the + autograd engine will mutate the cached phony in-place:: + + class Phonify(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + phony = get_phony(input.device, requires_grad=False) + return phony.detach() # detach() is necessary. + + """ + key = (device, requires_grad) + + try: + phony = _phonies[key] + except KeyError: + with use_stream(default_stream(device)): + phony = torch.empty(0, device=device, requires_grad=requires_grad) + + _phonies[key] = phony + + return phony diff --git a/torch/distributed/pipeline/sync/pipe.py b/torch/distributed/pipeline/sync/pipe.py new file mode 100644 index 0000000000000..d191d880d299e --- /dev/null +++ b/torch/distributed/pipeline/sync/pipe.py @@ -0,0 +1,366 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""The Pipe interface.""" +from collections import OrderedDict +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast, Sequence + +import torch +from torch import Tensor, nn +from torch.distributed.rpc import RRef +import torch.autograd +import torch.cuda + +from . import microbatch +from .batchnorm import DeferredBatchNorm +from .pipeline import Pipeline +from .skip.layout import inspect_skip_layout +from .skip.skippable import verify_skippables +from .stream import AbstractStream, new_stream + +__all__ = ["Pipe"] + + +Device = Union[torch.device, int, str] +Devices = Union[Iterable[Device], List[Device]] + +Tensors = Sequence[Tensor] +TensorOrTensors = Union[Tensor, Tensors] + +if TYPE_CHECKING: + # Typechecking: nn.Module is not a Generic + Module = nn.Module[TensorOrTensors] # type: ignore[type-arg] + NamedModules = OrderedDict[str, Module] +else: + Module = nn.Module + NamedModules = OrderedDict + + +def _recommend_auto_balance(message: str) -> str: + """Expands a message with recommendation to :mod:`torchpipe.balance`.""" + return f"""{message} + +If your model is still under development, its optimal balance would change +frequently. In this case, we highly recommend 'torch.distributed.pipeline.sync.balance' for +naive automatic balancing: + + from torch.distributed.pipeline.sync import Pipe + from torch.distributed.pipeline.sync.balance import balance_by_time + + partitions = torch.cuda.device_count() + sample = torch.empty(...) + balance = balance_by_time(partitions, model, sample) + + model = Pipe(model, balance, ...) +""" + + +def _verify_module(module: nn.Sequential) -> None: + if not isinstance(module, nn.Sequential): + raise TypeError("module must be nn.Sequential to be partitioned") + + named_children = list(module.named_children()) + if len(named_children) != len(module): + raise ValueError("module with duplicate children is not supported") + + +def _verify_splitting( + module: nn.Sequential, partitions: List[nn.Sequential], devices: List[torch.device] +) -> None: + num_parameters = len(list(module.parameters())) + num_child_parameters = sum(len(list(child.parameters())) for child in module.children()) + if num_parameters == num_child_parameters: + return + + for i in range(len(partitions)): + for j in range(i + 1, len(partitions)): + parti = partitions[i] + partj = partitions[j] + if devices[i] == devices[j]: + continue + for p in parti.parameters(): + for q in partj.parameters(): + if p is q: + raise ValueError("module with duplicate parameters on distinct devices is not supported") + + +class BalanceError(ValueError): + pass + + +def _retrieve_device(module: nn.Module) -> torch.device: + """Validates all parameters in the Module have the same device and returns + the appropriate device. + + Args: + An ``nn.Module`` to process. + + Returns: + ``torch.Device`` for the entire module. + + Raises: + ValueError: + If devices for ``nn.Module`` parameters are not all same. + """ + + device = None + for parameter in module.parameters(): + if device is None: + device = parameter.device + elif device != parameter.device: + raise ValueError( + 'nn.Module: {}, should have all parameters on a single device,' + ' please use .to() to place the module on a single device'.format(module)) + + return device if device is not None else torch.device("cpu") + +def _split_module(modules: nn.Sequential) -> Tuple[List[nn.Sequential], List[torch.device]]: + partitions = [] + devices = [] + for name, module in modules.named_children(): + devices.append(_retrieve_device(module)) + if isinstance(module, nn.Sequential): + partition = module + else: + partition = nn.Sequential(OrderedDict([(name, module)])) + partitions.append(partition) + + partitions = cast(List[nn.Sequential], nn.ModuleList(partitions)) + + return partitions, devices + +MOVING_DENIED = TypeError("denied to move parameters and buffers, " "because Pipe should manage device placement") + + +class Pipe(Module): + """Wraps an arbitrary :class:`nn.Sequential ` module + to train on using synchronous pipeline parallelism. If the module requires + lots of memory and doesn't fit on a single GPU, pipeline parallelism is a + useful technique to employ for training. + + The implementation is based on the torchgpipe_ paper. + + .. _torchgpipe: https://arxiv.org/abs/2004.09910 + + Pipe combines pipeline parallelism with checkpointing to reduce peak + memory required to train while minimizing device under-utilization. + + You should place all the modules on the appropriate devices and wrap them + into an :class:`nn.Sequential ` module defining the + desired order of execution. + + Args: + module (:class:`nn.Sequential `): + sequential module to be parallelized using pipelining. Each module + in the sequence has to have all of its parameters on a single + device. Each module in the sequence has to either be an nn.Module + or :class:`nn.Sequential ` (to combine multiple + sequential modules on a single device) + chunks (int): + number of micro-batches (default: ``1``) + checkpoint (str): + when to enable checkpointing, one of ``'always'``, + ``'except_last'``, or ``'never'`` (default: ``'except_last'``). + ``'never'`` disables checkpointing completely, ``'except_last'`` + enables checkpointing for all micro-batches except the last one + and ``'always'`` enables checkpointing for all micro-batches. + deferred_batch_norm (bool): + whether to use deferred ``BatchNorm`` moving statistics (default: + :data:`False`). If set to :data:`True`, we track statistics across + multiple micro-batches to update the running statistics per + mini-batch. + + Raises: + TypeError: + the module is not a :class:`nn.Sequential `. + ValueError: + invalid arguments + + Example:: + Pipeline of two FC layers across GPUs 0 and 1. + + >>> fc1 = nn.Linear(16, 8).cuda(0) + >>> fc2 = nn.Linear(8, 4).cuda(1) + >>> model = nn.Sequential(fc1, fc2) + >>> model = Pipe(model, chunks=8) + >>> input = torch.rand(16, 16).cuda(0) + >>> output_rref = model(input) + + .. note:: + You can wrap a :class:`Pipe` model with + :class:`torch.nn.parallel.DistributedDataParallel` only when the + checkpoint parameter of :class:`Pipe` is ``'never'``. + + .. note:: + :class:`Pipe` only supports intra-node pipelining currently, but + will be expanded to support inter-node pipelining in the future. + The forward function returns an :class:`~torch.distributed.rpc.RRef` + to allow for inter-node pipelining in the future, where the output + might be on a remote host. For intra-node pipelinining you can use + :meth:`~torch.distributed.rpc.RRef.local_value` to retrieve the + output locally. + + .. warning:: + :class:`Pipe` is experimental and subject to change. + """ + + def __init__( + self, + module: nn.Sequential, + chunks: int = 1, + checkpoint: str = "except_last", + deferred_batch_norm: bool = False, + ) -> None: + super().__init__() + + chunks = int(chunks) + checkpoint = str(checkpoint) + + if chunks <= 0: + raise ValueError("number of chunks must be positive integer") + if checkpoint not in ["always", "except_last", "never"]: + raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'") + + _verify_module(module) + + # Verify if the underlying skippable modules satisfy integrity. The + # integrity can be verified before forward() because it is static. + verify_skippables(module) + + self.chunks = chunks + self.checkpoint = checkpoint + + if deferred_batch_norm: + module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks) + + self.partitions, self.devices = _split_module(module) + _verify_splitting(module, self.partitions, self.devices) + + self._copy_streams: List[List[AbstractStream]] = [] + self._skip_layout = inspect_skip_layout(self.partitions) + + # Separate CUDA streams for copy. + copy_streams = self._ensure_copy_streams() + + # The micro-batch index where the checkpointing stops. + checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint] + + self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop) + + def __len__(self) -> int: + """Counts the length of the underlying sequential module.""" + return sum(len(p) for p in self.partitions) + + def __getitem__(self, index: int) -> nn.Module: + """Gets a layer in the underlying sequential module.""" + partitions = self.partitions + if index < 0: + partitions = partitions[::-1] + + for partition in partitions: + try: + return partition[index] + except IndexError: + pass + + shift = len(partition) + + if index < 0: + index += shift + else: + index -= shift + + raise IndexError + + def __iter__(self) -> Iterable[nn.Module]: + """Iterates over children of the underlying sequential module.""" + for partition in self.partitions: + yield from partition + + # Pipe should manage the device of each partition. + # Deny cuda(), cpu(), and to() with device, by TypeError. + def cuda(self, device: Optional[Device] = None) -> "Pipe": + raise MOVING_DENIED + + def cpu(self) -> "Pipe": + raise MOVING_DENIED + + def to(self, *args: Any, **kwargs: Any) -> "Pipe": + # Deny these usages: + # + # - to(device[, dtype, non_blocking]) + # - to(tensor[, non_blocking]) + # + # But allow this: + # + # - to(dtype[, non_blocking]) + # + if "device" in kwargs or "tensor" in kwargs: + raise MOVING_DENIED + + if args: + if isinstance(args[0], (torch.device, int, str)): + raise MOVING_DENIED + if torch.is_tensor(args[0]): + raise MOVING_DENIED + + return super().to(*args, **kwargs) + + def _ensure_copy_streams(self) -> List[List[AbstractStream]]: + """Ensures that :class:`Pipe` caches CUDA streams for copy. + + It's worth to cache CUDA streams although PyTorch already manages a + pool of pre-allocated CUDA streams, because it may reduce GPU memory + fragementation when the number of micro-batches is small. + + """ + if not self._copy_streams: + for device in self.devices: + self._copy_streams.append([new_stream(device) for _ in range(self.chunks)]) + + return self._copy_streams + + def forward(self, input) -> RRef: # type: ignore + """ + Processes a single input mini-batch through the pipe and returns an + :class:`~torch.distributed.rpc.RRef` pointing to the output. + :class:`Pipe` is a fairly transparent module wrapper. It doesn't + modify the input and output signature of the underlying module. But + there's type restriction. Input and output have to be a + :class:`~torch.Tensor` or a sequence of tensors. This restriction is + applied at partition boundaries too. + + The input tensor is split into multiple micro-batches based on the + ``chunks`` parameter used to initialize :class:`Pipe`. The batch size + is assumed to be the first dimension of the tensor and if the batch + size is less than ``chunks``, the number of micro-batches is equal to + the batch size. + + Args: + input (torch.Tensor or sequence of :class:`~torch.Tensor`): input mini-batch + + Returns: + :class:`~torch.distributed.rpc.RRef` to the output of the mini-batch + + Raises: + TypeError: input is not a tensor or sequence of tensors. + + """ + microbatch.check(input) + + if not self.devices: + # Empty sequential module is not illegal. + return RRef(input) + + # Divide a mini-batch into micro-batches. + batches = microbatch.scatter(input, self.chunks) + + # Run pipeline parallelism. + self.pipeline.run(batches) + + # Merge the micro-batches into one mini-batch. + output = microbatch.gather(batches) + return RRef(output) diff --git a/torch/distributed/pipeline/sync/pipeline.py b/torch/distributed/pipeline/sync/pipeline.py new file mode 100644 index 0000000000000..439a45c0f82e2 --- /dev/null +++ b/torch/distributed/pipeline/sync/pipeline.py @@ -0,0 +1,257 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""The pipeline parallelism of Pipe.""" +from queue import Queue +from types import TracebackType +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence + +import torch +from torch import Tensor, nn +from torch.autograd.profiler import record_function + +from .checkpoint import Checkpointing +from .copy import Copy, Wait +from .dependency import fork, join +from .microbatch import Batch +from .skip.layout import SkipLayout +from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker +from .stream import AbstractStream, current_stream, use_device +from .worker import Task, create_workers, join_workers + +__all__: List[str] = [] + + +Tensors = Sequence[Tensor] +TensorOrTensors = Union[Tensor, Tensors] + +ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] + +# Queue is generic only in stubs. +# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime +if TYPE_CHECKING: + InQueue = Queue[Optional["Task"]] + OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]] +else: + InQueue = Queue + OutQueue = Queue + + +def _depend(fork_from: Batch, join_to: Batch) -> None: + fork_from[0], phony = fork(fork_from[0]) + join_to[0] = join(join_to[0], phony) + + +def _copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None: + batch[:] = Copy.apply(prev_stream, next_stream, *batch) + # Gradients are only supported for float Tensors. + batch[:] = tuple([x if x.is_floating_point() else x.detach() for x in batch]) + + +def _wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None: + batch[:] = Wait.apply(prev_stream, next_stream, *batch) + # Gradients are only supported for float Tensors. + batch[:] = tuple([x if x.is_floating_point() else x.detach() for x in batch]) + + +def _clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]: + """Generates schedules for each clock cycle.""" + # m: number of micro-batches + # n: number of partitions + # i: index of micro-batch + # j: index of partition + # k: clock number + # + # k (i,j) (i,j) (i,j) + # - ----- ----- ----- + # 0 (0,0) + # 1 (1,0) (0,1) + # 2 (2,0) (1,1) (0,2) + # 3 (2,1) (1,2) + # 4 (2,2) + for k in range(m + n - 1): + yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))] + + +class Pipeline: + """The pipeline parallelism for Pipe.""" + + def __init__( + self, + partitions: List[nn.Sequential], + devices: List[torch.device], + copy_streams: List[List[AbstractStream]], + skip_layout: SkipLayout, + checkpoint_stop: int, + ) -> None: + self.partitions = partitions + self.devices = devices + self.copy_streams = copy_streams + self.skip_layout = skip_layout + self.checkpoint_stop = checkpoint_stop + (self.in_queues, self.out_queues) = create_workers(devices) + + def __del__(self) -> None: + join_workers(self.in_queues, self.out_queues) + + def run(self, batches: List[Batch]) -> None: + """Runs pipeline parallelism. + + It modifies the given batches in place. + + """ + partitions = self.partitions + devices = self.devices + skip_layout = self.skip_layout + + m = len(batches) + n = len(partitions) + + skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches] + + for schedule in _clock_cycles(m, n): + self.fence(batches, schedule, skip_trackers) + self.compute(batches, schedule, skip_trackers) + + def fence( + self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], + ) -> None: + """Copies micro-batches after computation for the previous + micro-batches. + """ + copy_streams = self.copy_streams + skip_layout = self.skip_layout + + for i, j in schedule: + # Ensure that batches[i-1] is executed after batches[i] in + # backpropagation by an explicit dependency. + if i != 0 and j != 0: + _depend(batches[i - 1], batches[i]) + + next_stream = copy_streams[j][i] + + for prev_j, ns, name in skip_layout.copy_policy(j): + prev_stream = copy_streams[prev_j][i] + skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name) + + if j != 0: + prev_stream = copy_streams[j - 1][i] + _copy(batches[i], prev_stream, next_stream) + + def compute( + self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], + ) -> None: + """Runs tasks with synchronization to copy streams.""" + partitions = self.partitions + devices = self.devices + copy_streams = self.copy_streams + checkpoint_stop = self.checkpoint_stop + + # Disable checkpointing if in eval mode. + if not self.partitions[0].training: + checkpoint_stop = 0 + + n = len(partitions) + streams = [current_stream(d) for d in devices] + exc_info: Optional[ExcInfo] = None + + # With checkpointing, the autograd graph looks like this diagram: + # ┌─────┸──────┐ + # │ Copy │ + # └─────┰──────┘ (fence) + # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─ + # ┃ (compute) + # ┌─────┸──────┐ + # │ Wait │ [1] Synchronize the current stream with the copy stream. + # └─────┰──────┘ + # ┌─────┸──────┐ + # │ Checkpoint │ [2] Compute a partition within checkpointing. + # └─────┰──────┘ + # ┌─────┸──────┐ + # │ Wait │ [3] Synchronize the copy stream with the current stream. + # └─────┰──────┘ + # ┠ ─ ─ ─ ┐ + # ┃ ┌─────┴─────┐ + # ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation. + # ┃ └─────┬─────┘ + # ┠ ─ ─ ─ ┘ + # ┃ + # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─ + # ┌─────┸──────┐ (fence) + # │ Copy │ + # └─────┰──────┘ + for i, j in schedule: + batch = batches[i] + partition = partitions[j] + + # Synchronize with the copied input. ([1] in the diagram) + if j != 0: + _wait(batch, copy_streams[j][i], streams[j]) + + # Determine whether checkpointing or not. + checkpoint = i < checkpoint_stop + if checkpoint: + + def function( + input: TensorOrTensors, + partition: nn.Sequential = partition, + skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], + chunk_id: int = i, + part_id: int = j, + ) -> TensorOrTensors: + with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): + return partition(input) + + chk = Checkpointing(function, batch) + task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute) + del function, chk + + else: + + def compute( + batch: Batch = batch, + partition: nn.Sequential = partition, + skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], + chunk_id: int = i, + part_id: int = j, + ) -> Batch: + with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): + return batch.call(partition) + + task = Task(streams[j], compute=compute, finalize=None) + del compute + + # Compute tasks in parallel. ([2] in the diagram) + self.in_queues[j].put(task) + + for i, j in schedule: + ok, payload = self.out_queues[j].get() + + # Hold the first exception. + if exc_info is not None: + continue + elif not ok: + exc_info = cast(ExcInfo, payload) + continue + + task, batch = cast(Tuple[Task, Batch], payload) + + # The copy stream synchronizes to copy the output. ([3] in the + # diagram) + if j != n - 1: + _wait(batch, streams[j], copy_streams[j][i]) + + # Finalize tasks. If checkpointing is enabled, here the + # recomputation is scheduled at backpropagation. ([4] in the + # diagram) + with use_device(devices[j]): + task.finalize(batch) + + batches[i] = batch + + # Fail at the first exception. + if exc_info is not None: + raise exc_info[0].with_traceback(exc_info[1], exc_info[2]) diff --git a/torch/distributed/pipeline/sync/py.typed b/torch/distributed/pipeline/sync/py.typed new file mode 100644 index 0000000000000..ab03724cafbf5 --- /dev/null +++ b/torch/distributed/pipeline/sync/py.typed @@ -0,0 +1,6 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torch/distributed/pipeline/sync/skip/__init__.py b/torch/distributed/pipeline/sync/skip/__init__.py new file mode 100644 index 0000000000000..bdcb913867a73 --- /dev/null +++ b/torch/distributed/pipeline/sync/skip/__init__.py @@ -0,0 +1,11 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Supports efficiency with skip connections.""" +from .namespace import Namespace +from .skippable import pop, skippable, stash, verify_skippables + +__all__ = ["skippable", "stash", "pop", "verify_skippables", "Namespace"] diff --git a/torch/distributed/pipeline/sync/skip/layout.py b/torch/distributed/pipeline/sync/skip/layout.py new file mode 100644 index 0000000000000..bff417bfbd65b --- /dev/null +++ b/torch/distributed/pipeline/sync/skip/layout.py @@ -0,0 +1,86 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Static skip connection layout of ``@skippable`` modules.""" +from typing import Dict, Iterable, List, Tuple + +from torch import nn + +from .namespace import Namespace + +__all__: List[str] = [] + + +class SkipLayout: + """Represents a skip connection layout across partitions.""" + + # Skip routes indexed by 'ns, name': {(ns, name): (prev_j, next_j), ...} + by_ns_name: Dict[Tuple[Namespace, str], Tuple[int, int]] + + # Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...] + by_partition: List[List[Tuple[int, Namespace, str]]] + + def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None: + # The skip routes are already indexed by 'ns, name'. + self.by_ns_name = skip_routes + + # Index skip routes by partition number 'j'. + self.by_partition = [[] for _ in range(num_partitions)] + + for (ns, name), (prev_j, next_j) in skip_routes.items(): + self.by_partition[next_j].append((prev_j, ns, name)) + + for p in self.by_partition: + p.sort() + + def copy_policy(self, next_j: int) -> Iterable[Tuple[int, Namespace, str]]: + """Generates skip routes for the given destination partition number. + The skip routes are sorted by source partition number in ascending + order. + + Yields: + Each tuple of (source partition number, namespace, name). + + """ + for prev_j, ns, name in self.by_partition[next_j]: + if prev_j == next_j: + # This skip tensor will be popped at the same partition where + # it is stashed. In this case, copy is not required. + continue + + yield (prev_j, ns, name) + + def requires_copy(self, ns: Namespace, name: str) -> bool: + """Whether the given namespace and name requires partition-to-partition + copy or not. + """ + prev_j, next_j = self.by_ns_name.get((ns, name), (-1, -1)) + return prev_j != next_j + + +def inspect_skip_layout(partitions: List[nn.Sequential]) -> SkipLayout: + """Inspects the skip connection layout in the given partitions.""" + # NOTE(sublee): Hide circular import inside this subroutine. Circular + # import is not ideal but placing this logic near to SkipLayout may + # increase cohesion of code. + from .skippable import Skippable + + skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]] = {} + stashed_at: Dict[Tuple[Namespace, str], int] = {} + + for j, partition in enumerate(partitions): + for layer in partition: + if not isinstance(layer, Skippable): + continue + + for ns, name in layer.stashable(): + stashed_at[(ns, name)] = j + + for ns, name in layer.poppable(): + prev_j = stashed_at.pop((ns, name)) + skip_routes[(ns, name)] = (prev_j, j) + + return SkipLayout(len(partitions), skip_routes) diff --git a/torch/distributed/pipeline/sync/skip/namespace.py b/torch/distributed/pipeline/sync/skip/namespace.py new file mode 100644 index 0000000000000..d2a8de92588ec --- /dev/null +++ b/torch/distributed/pipeline/sync/skip/namespace.py @@ -0,0 +1,50 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Provides isolated namespace of skip tensors.""" +import abc +from functools import total_ordering +from typing import Any +import uuid + +__all__ = ["Namespace"] + + +@total_ordering +class Namespace(metaclass=abc.ABCMeta): + """Namespace for isolating skip tensors used by :meth:`isolate() + `. + """ + + __slots__ = ("id",) + + def __init__(self) -> None: + self.id = uuid.uuid4() + + def __repr__(self) -> str: + return f"" + + def __hash__(self) -> int: + return hash(self.id) + + # Namespaces should support ordering, since SkipLayout will sort tuples + # including a namespace. But actual order between namespaces is not + # important. That's why they are ordered by version 4 UUID which generates + # random numbers. + def __lt__(self, other: Any) -> bool: + if isinstance(other, Namespace): + return self.id < other.id + return False + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Namespace): + return self.id == other.id + return False + + +# 'None' is the default namespace, +# which means that 'isinstance(None, Namespace)' is 'True'. +Namespace.register(type(None)) diff --git a/torch/distributed/pipeline/sync/skip/portal.py b/torch/distributed/pipeline/sync/skip/portal.py new file mode 100644 index 0000000000000..6b3bbb3fb761d --- /dev/null +++ b/torch/distributed/pipeline/sync/skip/portal.py @@ -0,0 +1,231 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Portal keeps a tensor in the pocket plane. The tensor becomes hidden to the +autograd engine. The shared context of three functions (:class:`PortalBlue`, +:class:`PortalOrange`, and :class:`PortalCopy`) out of the computation graph is +one of the most important feature of :mod:`torchpipe.skip`. + +The metaphor is inspired by Portal™ from Valve. + +""" +from typing import List, Optional, Tuple + +import torch +from torch import Tensor + +from ..copy import Context as CopyContext +from ..copy import Copy +from ..phony import get_phony +from ..stream import AbstractStream, get_device + +__all__: List[str] = [] + + +class Portal: + """A portal for a tensor.""" + + def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None: + self.put_tensor(tensor, tensor_life) + self.grad: Optional[Tensor] = None + + def blue(self) -> Tensor: + """Creates a :class:`PortalBlue` which hides the underlying tensor from + the autograd engine. + + Join the returning phony to the main lane of the autograd graph to + assure the correct backpropagation:: + + PortalBlue --+ + | + ---------- Join -- + + """ + tensor = self.use_tensor() + + if tensor is None: + return get_phony(torch.device("cpu"), requires_grad=False) + + return PortalBlue.apply(self, tensor) + + def orange(self, phony: Tensor) -> Optional[Tensor]: + """Creates a :class:`PortalOrange` which retrieves the hidden tensor + without losing ability of backpropagation. + + Give a phony forked from the main lane of an autograd graph:: + + +-- PortalOrange --+ + | | + -- Fork --------- f(a, b) -- + + """ + self.check_tensor_life() + + if self.tensor is None: + return self.use_tensor() + + return PortalOrange.apply(self, phony) + + def copy(self, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,) -> Tensor: + """Copies the hidden tensor by a :class:`PortalCopy`. + + Give a phony and use the returning phony to keep backpropagation:: + + +-- PortalCopy --+ + | | + -- Fork ---------- Join -- + + """ + if self.tensor is None: + return get_phony(torch.device("cpu"), requires_grad=False) + + return PortalCopy.apply(self, prev_stream, next_stream, phony) + + def check_tensor_life(self) -> None: + if self.tensor_life <= 0: + raise RuntimeError("tensor in portal has been removed") + + def put_tensor(self, tensor: Optional[Tensor], tensor_life: int) -> None: + """Stores a tensor into this portal.""" + # [Life of Tensor through Portal] + # + # The tensor can be retrieved by use_tensor() up to 'tensor_life' + # times. When the life becomes 0, the tensor will be deleted for + # deallocation in CUDA memory. + # + # The below events participate in a tensor through a portal. + # Note that [x] denotes the events which call use_tensor(): + # + # 1. [x] blue() + # 2. [ ] PortalBlue.forward + # 3. [ ] copy() + # 4. [ ] PortalCopy.forward + # 5. [ ] orange() + # 6. [x] PortalOrange.forward + # - - - - - - - - - - - - - - - - - - - - - - - - - - - + # 7. [ ] orange() (recomputed) + # 8. [x] PortalOrange.forward (recomputed) + # 9. [ ] PortalOrange.backward + # 10. [ ] PortalCopy.backward + # 11. [x] blue() (recomputed) + # 12. [ ] PortalBlue.forward (recomputed) + # 13. [ ] PortalBlue.backward + # + self.tensor_life = tensor_life + + if tensor_life > 0: + self.tensor = tensor + else: + self.tensor = None + + def use_tensor(self) -> Optional[Tensor]: + """Retrieves the underlying tensor and decreases the tensor life. When + the life becomes 0, it the tensor will be removed. + """ + self.check_tensor_life() + + tensor = self.tensor + + self.tensor_life -= 1 + + if self.tensor_life <= 0: + self.tensor = None + + return tensor + + def put_grad(self, grad: Tensor) -> None: + """Stores a gradient into this portal.""" + self.grad = grad + + def use_grad(self) -> Tensor: + """Retrieves and removes the underlying gradient. The gradient is + always ephemeral. + """ + if self.grad is None: + raise RuntimeError("grad in portal has been removed or never set") + + grad = self.grad + self.grad = None + return grad + + +# Common interface between :class:`PortalBlue`, :class:`PortalOrange`, and +# :class:`PortalCopy`. +class Context(CopyContext): + portal: Portal + + +class PortalBlue(torch.autograd.Function): + """Hides a tensor from the autograd engine by a :class:`Portal`.""" + + @staticmethod + # type: ignore + def forward( + ctx: Context, + portal: Portal, + # This tensor must be retrieved by portal.use_tensor(). + tensor: Tensor, + ) -> Tensor: + ctx.portal = portal + + phony = get_phony(tensor.device, requires_grad=False) + return phony.detach() + + @staticmethod + # type: ignore + def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, Tensor]: + # The paired PortalOrange should keep the gradient. + grad = ctx.portal.use_grad() + return None, grad + + +class PortalOrange(torch.autograd.Function): + """Retrieves the hidden tensor from a :class:`Portal`.""" + + @staticmethod + # type: ignore + def forward(ctx: Context, portal: Portal, phony: Tensor) -> Tensor: + ctx.portal = portal + + tensor = portal.use_tensor() + assert tensor is not None + + return tensor.detach() + + @staticmethod + def backward(ctx: Context, grad: Tensor) -> Tuple[None, None]: # type: ignore + # The paired PortalBlue will use the gradient. + ctx.portal.put_grad(grad) + return None, None + + +class PortalCopy(torch.autograd.Function): + """Copies the hidden tensor in a :class:`Portal`. It replaces the hidden + tensor with copied one. + """ + + @staticmethod + # type: ignore + def forward( + ctx: Context, portal: Portal, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor, + ) -> Tensor: + ctx.portal = portal + + assert portal.tensor is not None + (portal.tensor,) = Copy.forward(ctx, prev_stream, next_stream, portal.tensor) + + phony = get_phony(get_device(next_stream), requires_grad=False) + return phony.detach() + + @staticmethod + # type: ignore + def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, None, None, None]: + portal = ctx.portal + + assert portal.grad is not None + _, _, portal.grad = Copy.backward(ctx, portal.grad) + + return None, None, None, None diff --git a/torch/distributed/pipeline/sync/skip/skippable.py b/torch/distributed/pipeline/sync/skip/skippable.py new file mode 100644 index 0000000000000..e0b0dae584a23 --- /dev/null +++ b/torch/distributed/pipeline/sync/skip/skippable.py @@ -0,0 +1,441 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""The user interface to define skip connections.""" +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + FrozenSet, + Generator, + Iterable, + List, + Optional, + Set, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +from torch import Tensor, nn + +from ..microbatch import Batch +from .namespace import Namespace +from .tracker import current_skip_tracker + +__all__ = ["skippable", "stash", "pop", "verify_skippables"] + + +Tensors = Sequence[Tensor] +TensorOrTensors = Union[Tensor, Tensors] + +StashPop = Union["stash", "pop"] +StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors] +if TYPE_CHECKING: + # Typechecking: nn.Module is not a Generic + SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]] # type: ignore[type-arg] +else: + SkippableModule = nn.Module + +T = TypeVar("T", bound="Skippable") + + +class Skippable(nn.Module): + """The base class for skippable modules. + + Do not use this class directly. Define a subclass by :func:`skippable` + instead. + + """ + + module_cls: ClassVar[Type[SkippableModule]] + stashable_names: ClassVar[FrozenSet[str]] + poppable_names: ClassVar[FrozenSet[str]] + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() + self.module = self.module_cls(*args, **kwargs) # type: ignore + self.namespaces: Dict[str, Namespace] = {} + + def __repr__(self) -> str: + return f"@skippable({self.module})" + + def namespaced(self, name: str) -> Tuple[Namespace, str]: + """Prepends namespace for the given skip name.""" + ns = self.namespaces.get(name) + ns = cast(Namespace, ns) + return (ns, name) + + def stashable(self) -> Iterable[Tuple[Namespace, str]]: + """Iterates over namespaced skip names to be stashed.""" + for name in self.stashable_names: + yield self.namespaced(name) + + def poppable(self) -> Iterable[Tuple[Namespace, str]]: + """Iterates over namespaced skip names to be popped.""" + for name in self.poppable_names: + yield self.namespaced(name) + + def isolate(self: T, ns: Namespace, *, only: Optional[Iterable[str]] = None) -> T: + r"""Isolates a specified subset or the whole set of skip tensors into a + namespace. In a single sequential module, skip tensors with the same + name are not allowed unless they are isolated by different namespaces. + + Here's an example using the same name for skip tensors twice. Each pair + of ``Layer1`` and ``Layer2`` is isolated with its own namespace ``ns1`` + and ``ns2``. There is no conflict anymore:: + + ns1 = Namespace() + ns2 = Namespace() + + model = nn.Sequential( + Layer1().isolate(ns1), + Layer1().isolate(ns2), + Layer2(), + Layer3().isolate(ns2), + Layer3().isolate(ns1), + ) + + When `only` parameter is omitted, all skip tensors are isolated. You + can isolate a subset of skip tensors by passing `only` parameter:: + + ns_alice = Namespace() + ns_bob = Namespace() + + model = nn.Sequential( + ... + StashStashPop().isolate(ns_alice, only=['alice']) \ + .isolate(ns_bob, only=['bob']), + ... + ) + + Args: + ns (Namespace): + namespace for isolation + + Keyword Args: + only (iterable of strs): + names of specific skip tensors to be isolated (omit this option + to isolate all skip tensors declared in this module) + + Returns: + this module itself + + """ + names: Iterable[str] + + if only is None: + names = self.stashable_names | self.poppable_names + else: + names = set(only) + + for name in names: + self.namespaces[name] = ns + + return self + + def dispatch( + self, + input: TensorOrTensors, + handle_stash: Callable[[str, Optional[Tensor]], None], + handle_pop: Callable[[str], Optional[Tensor]], + ) -> TensorOrTensors: + """Dispatches :class:`stash` or :class:`pop` commands generated by the + module's ``forward()``. + """ + generator = self.module(input) + + if not isinstance(generator, Generator): + # The underlying module returned output without any yield. + output = generator + return output + + try: + op = next(generator) + + while True: + if isinstance(op, stash): + handle_stash(op.name, op.tensor) + op = next(generator) + continue + + if isinstance(op, pop): + tensor = handle_pop(op.name) + op = generator.send(tensor) + continue + + raise TypeError("%r is not a command from @skippable" % op) + + except StopIteration as stop: + output = stop.args[0] + return output + + def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore + """Performs the forward propagation. :class:`stash` or :class:`pop` + commands will be handled by portals silently. The portals won't be + exposed to users. + + Raises: + RuntimeError: + illegal 'stash' or 'pop' is found. + + """ + skip_tracker = current_skip_tracker() + stashed_tensors: Dict[str, Optional[Tensor]] = {} + + # Load skip tensors that might be popped. + poppable_tensors = {} + batch = Batch(input) + for ns, name in self.poppable(): + try: + poppable_tensors[name] = skip_tracker.load(batch, ns, name) + except KeyError: + raise RuntimeError(f"'{name}' has not been stashed") + input = batch.tensor_or_tensors + + # Handle skip commands. + def handle_stash(name: str, tensor: Optional[Tensor]) -> None: + if name not in self.stashable_names: + raise RuntimeError(f"'{name}' has not been declared as stashable") + stashed_tensors[name] = tensor + + def handle_pop(name: str) -> Optional[Tensor]: + if name not in self.poppable_names: + raise RuntimeError(f"'{name}' has not been declared as poppable") + return poppable_tensors.pop(name) + + output = self.dispatch(input, handle_stash, handle_pop) + + # All declared skips must be stashed or popped. + not_stashed = self.stashable_names - stashed_tensors.keys() + if not_stashed: + comma_names = ", ".join("'%s'" % n for n in not_stashed) + raise RuntimeError(f"{comma_names} must be stashed but have not") + + not_popped = poppable_tensors.keys() + if not_popped: + comma_names = ", ".join("'%s'" % n for n in not_popped) + raise RuntimeError(f"{comma_names} must be popped but have not") + + # Save stashed skip tensors. + batch = Batch(output) + for ns, name in self.stashable(): + tensor = stashed_tensors[name] + skip_tracker.save(batch, ns, name, tensor) + output = batch.tensor_or_tensors + + return output + + +# TODO(sublee): Move to above of Skippable class for better read flow. +def skippable( + stash: Iterable[str] = (), pop: Iterable[str] = (), +) -> Callable[[Type[SkippableModule]], Type[Skippable]]: + """The decorator to define a :class:`nn.Module ` with skip + connections. Decorated modules are called "skippable". This functionality + works perfectly fine even when the module is not wrapped by + :class:`~torchpipe.Pipe`. + + Each skip tensor is managed by its name. Before manipulating skip tensors, + a skippable module must statically declare the names for skip tensors by + `stash` and/or `pop` parameters. Skip tensors with pre-declared name can be + stashed by ``yield stash(name, tensor)`` or popped by ``tensor = yield + pop(name)``. + + Here is an example with three layers. A skip tensor named "1to3" is stashed + and popped at the first and last layer, respectively:: + + @skippable(stash=['1to3']) + class Layer1(nn.Module): + def forward(self, input): + yield stash('1to3', input) + return f1(input) + + class Layer2(nn.Module): + def forward(self, input): + return f2(input) + + @skippable(pop=['1to3']) + class Layer3(nn.Module): + def forward(self, input): + skip_1to3 = yield pop('1to3') + return f3(input) + skip_1to3 + + model = nn.Sequential(Layer1(), Layer2(), Layer3()) + + One skippable module can stash or pop multiple skip tensors:: + + @skippable(stash=['alice', 'bob'], pop=['carol']) + class StashStashPop(nn.Module): + def forward(self, input): + yield stash('alice', f_alice(input)) + yield stash('bob', f_bob(input)) + carol = yield pop('carol') + return input + carol + + Every skip tensor must be associated with exactly one pair of `stash` and + `pop`. :class:`~torchpipe.Pipe` checks this restriction automatically + when wrapping a module. You can also check the restriction by + :func:`~torchpipe.skip.verify_skippables` without + :class:`~torchpipe.Pipe`. + + .. note:: + + :func:`@skippable ` changes the type of the wrapped class. + But currently (mypy v0.740), mypy could not understand class decorators + yet (`#3135 `_). + + There are two workarounds: + + 1. Naively ignore type errors by ``# type: ignore``. + 2. Use ``skippable()()`` as a function instead of a decorator. + + .. seealso:: :ref:`Long Skip Connections` + + """ + stashable_names = frozenset(stash) + poppable_names = frozenset(pop) + + def extend_skippable(module_cls: Type[SkippableModule]) -> Type[Skippable]: + name = module_cls.__name__ + bases = (Skippable,) + attrs = {"module_cls": module_cls, "stashable_names": stashable_names, "poppable_names": poppable_names} + return type(name, bases, attrs) + + return extend_skippable + + +class stash: + """The command to stash a skip tensor. + + :: + + def forward(self, input): + yield stash('name', input) + return f(input) + + Args: + name (str): name of skip tensor + input (torch.Tensor or None): tensor to pass to the skip connection + + """ + + __slots__ = ("name", "tensor") + + def __init__(self, name: str, tensor: Optional[Tensor]) -> None: + self.name = name + self.tensor = tensor + + +class pop: + """The command to pop a skip tensor. + + :: + + def forward(self, input): + skip = yield pop('name') + return f(input) + skip + + Args: + name (str): name of skip tensor + + Returns: + the skip tensor previously stashed by another layer under the same name + + """ + + __slots__ = ("name",) + + def __init__(self, name: str) -> None: + self.name = name + + +def verify_skippables(module: nn.Sequential) -> None: + """Verifies if the underlying skippable modules satisfy integrity. + + Every skip tensor must have only one pair of `stash` and `pop`. If there + are one or more unmatched pairs, it will raise :exc:`TypeError` with the + detailed messages. + + Here are a few failure cases. :func:`verify_skippables` will report failure + for these cases:: + + # Layer1 stashes "1to3". + # Layer3 pops "1to3". + + nn.Sequential(Layer1(), Layer2()) + # └──── ? + + nn.Sequential(Layer2(), Layer3()) + # ? ────┘ + + nn.Sequential(Layer1(), Layer2(), Layer3(), Layer3()) + # └───────────────────┘ ^^^^^^ + + nn.Sequential(Layer1(), Layer1(), Layer2(), Layer3()) + # ^^^^^^ └───────────────────┘ + + To use the same name for multiple skip tensors, they must be isolated by + different namespaces. See :meth:`isolate() + `. + + Raises: + TypeError: + one or more pairs of `stash` and `pop` are not matched. + + """ + stashed: Set[Tuple[Namespace, str]] = set() + popped: Set[Tuple[Namespace, str]] = set() + msgs: List[str] = [] + + for layer_name, layer in module.named_children(): + if not isinstance(layer, Skippable): + continue + + for name in layer.stashable_names & layer.poppable_names: + msg = f"'{layer_name}' declared '{name}' both as stashable and as poppable" + msgs.append(msg) + + for ns, name in layer.stashable(): + if name in layer.poppable_names: + continue + + if (ns, name) in stashed: + msg = f"'{layer_name}' redeclared '{name}' as stashable " "but not isolated by namespace" + msgs.append(msg) + continue + + stashed.add((ns, name)) + + for ns, name in layer.poppable(): + if name in layer.stashable_names: + continue + + if (ns, name) in popped: + msg = f"'{layer_name}' redeclared '{name}' as poppable " "but not isolated by namespace" + msgs.append(msg) + continue + + if (ns, name) not in stashed: + msg = f"'{layer_name}' declared '{name}' as poppable but it was not stashed" + msgs.append(msg) + continue + + popped.add((ns, name)) + + for (_, name) in stashed - popped: + msg = f"no module declared '{name}' as poppable but stashed" + msgs.append(msg) + + if msgs: + raise TypeError( + "one or more pairs of stash and pop do not match:\n\n%s" "" % "\n".join("* %s" % x for x in msgs) + ) diff --git a/torch/distributed/pipeline/sync/skip/tracker.py b/torch/distributed/pipeline/sync/skip/tracker.py new file mode 100644 index 0000000000000..397158c21dbf0 --- /dev/null +++ b/torch/distributed/pipeline/sync/skip/tracker.py @@ -0,0 +1,177 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Tracks skip tensors on a thread.""" +from contextlib import contextmanager +import threading +from typing import Dict, Generator, List, Optional, Tuple + +from torch import Tensor + +from ..checkpoint import is_checkpointing +from ..dependency import fork, join +from ..microbatch import Batch +from ..stream import AbstractStream +from .layout import SkipLayout +from .namespace import Namespace +from .portal import Portal + +__all__: List[str] = [] + + +class SkipTracker: + """Tracks saved skip tensors. + + It will update the given micro-batch in place. This is because when it + manipulates the underlying skip tensors, the current micro-batch also has + to be connected with the skip tensors. + + One thread has one skip tracker. Call :func:`current_skip_tracker` to get + the skip tracker on the current thread. + + """ + + def __init__(self) -> None: + self.tensors: Dict[Tuple[Namespace, str], Optional[Tensor]] = {} + + def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None: + self.tensors[(ns, name)] = tensor + + def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]: + return self.tensors.pop((ns, name)) + + def copy( + self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str, + ) -> None: + raise TypeError("copy is not supported for non-portal skip tensors") + + +class SkipTrackerThroughPotals(SkipTracker): + """Tracks saved skip tensors through portals. The skip tensors will be + hidden in portals so that the autograd engine does not need to track them. + + This tracker is only used when the training or evaluating module is wrapped + with :class:`torchpipe.Pipe`. + + """ + + def __init__(self, skip_layout: SkipLayout) -> None: + super().__init__() + self.skip_layout = skip_layout + self.portals: Dict[Tuple[Namespace, str], Portal] = {} + + def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None: + """Saves the stashed skip tensor in a portal. The portal is then + connected to the given micro-batch with :class:`Join`. + """ + if not self.skip_layout.requires_copy(ns, name): + super().save(batch, ns, name, tensor) + return + + # See [Tensor Life of Portal] at Portal.put_tensor() to understand the + # below tensor_life values. Here are the selected events which retrieve + # the tensor in portal: + # + # 1. [x] blue() + # ... + # 6. [x] PortalOrange.forward + # ... + # 8. [x] PortalOrange.forward (recomputed) + # ... + # 11. [x] blue() (recomputed) + # + if (ns, name) not in self.portals: + if is_checkpointing(): + # Under checkpointing, the tensor used by the first + # PortalOrange should be alive in the portal. This tensor will + # be used again by the second PortalOrange during the + # recomputation. + tensor_life = 3 # Delete at [8. PortalOrange.forward (recomputed)] + else: + tensor_life = 2 # Delete at [6. PortalOrange.forward] + + portal = Portal(tensor, tensor_life) + self.portals[(ns, name)] = portal + + else: + # Under recomputation, the portal already exists. + portal = self.portals[(ns, name)] + + # The existing tensor life already became 0. It should be reset as + # 1 to delete the tensor after the second PortalBlue immediately. + tensor_life = 1 # Delete at [11. blue() (recomputed)] + + portal.put_tensor(tensor, tensor_life) + + phony = portal.blue() + batch[0] = join(batch[0], phony) + + def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]: + """Loads a skip tensor from the corresponding portal to pop. The given + micro-batch is connected to the portal with :class:`Fork`. + """ + if not self.skip_layout.requires_copy(ns, name): + tensor = super().load(batch, ns, name) + return tensor + + portal = self.portals[(ns, name)] + batch[0], phony = fork(batch[0]) + tensor = portal.orange(phony) + return tensor + + def copy( + self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str, + ) -> None: + """Copies the skip tensor in the corresponding portal. The given + micro-batch and the portal will be tied with :class:`Fork` and + :class:`Join`. + """ + assert self.skip_layout.requires_copy(ns, name) + + batch[0], phony = fork(batch[0]) + + portal = self.portals[(ns, name)] + phony = portal.copy(prev_stream, next_stream, phony) + + batch[0] = join(batch[0], phony) + + +class ThreadLocal(threading.local): + def __init__(self) -> None: + self.skip_tracker: Optional[SkipTracker] = None + + +thread_local = ThreadLocal() + + +@contextmanager +def use_skip_tracker(skip_tracker: SkipTracker) -> Generator[None, None, None]: + """Registers the given skip tracker on the current thread within a + context:: + + with use_skip_tracker(my_skip_tracker): + ... + + """ + orig = thread_local.skip_tracker + + thread_local.skip_tracker = skip_tracker + + try: + yield + finally: + thread_local.skip_tracker = orig + + +def current_skip_tracker() -> SkipTracker: + """Gets the skip tracker on the current thread.""" + skip_tracker = thread_local.skip_tracker + + if skip_tracker is None: + skip_tracker = SkipTracker() + thread_local.skip_tracker = skip_tracker + + return skip_tracker diff --git a/torch/distributed/pipeline/sync/stream.py b/torch/distributed/pipeline/sync/stream.py new file mode 100644 index 0000000000000..41e1591793b6c --- /dev/null +++ b/torch/distributed/pipeline/sync/stream.py @@ -0,0 +1,118 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Utilities for eliminating boilerplate code to handle abstract streams with +CPU device. +""" +from contextlib import contextmanager +from typing import Generator, List, Union, cast + +import torch + +__all__: List[str] = [] + + +class CPUStreamType: + pass + + +# The placeholder on place of streams for the CPU device instead of CUDA. +CPUStream = CPUStreamType() + +# It represents both CUDA streams and the CPU stream. +AbstractStream = Union[torch.cuda.Stream, CPUStreamType] + + +def new_stream(device: torch.device) -> AbstractStream: + """Creates a new stream for either CPU or CUDA device.""" + if device.type != "cuda": + return CPUStream + return torch.cuda.Stream(device) + + +def current_stream(device: torch.device) -> AbstractStream: + """:func:`torch.cuda.current_stream` for either CPU or CUDA device.""" + if device.type != "cuda": + return CPUStream + return torch.cuda.current_stream(device) + + +def default_stream(device: torch.device) -> AbstractStream: + """:func:`torch.cuda.default_stream` for either CPU or CUDA device.""" + if device.type != "cuda": + return CPUStream + return torch.cuda.default_stream(device) + + +@contextmanager +def use_device(device: torch.device) -> Generator[None, None, None]: + """:func:`torch.cuda.device` for either CPU or CUDA device.""" + if device.type != "cuda": + yield + return + + with torch.cuda.device(device): + yield + + +@contextmanager +def use_stream(stream: AbstractStream) -> Generator[None, None, None]: + """:func:`torch.cuda.stream` for either CPU or CUDA stream.""" + if not is_cuda(stream): + yield + return + + with torch.cuda.stream(as_cuda(stream)): + yield + + +def get_device(stream: AbstractStream) -> torch.device: + """Gets the device from CPU or CUDA stream.""" + if is_cuda(stream): + return as_cuda(stream).device + return torch.device("cpu") + + +def wait_stream(source: AbstractStream, target: AbstractStream) -> None: + """:meth:`torch.cuda.Stream.wait_stream` for either CPU or CUDA stream. It + makes the source stream wait until the target stream completes work queued. + """ + if is_cuda(target): + if is_cuda(source): + # A CUDA stream waits another CUDA stream. + as_cuda(source).wait_stream(as_cuda(target)) + else: + # CPU waits a CUDA stream. + as_cuda(target).synchronize() + + # If the target is CPU, synchronization is not required. + + +def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None: + """:meth:`torch.Tensor.record_stream` for either CPU or CUDA stream.""" + if is_cuda(stream): + # NOTE(sublee): record_stream() on a shifted view tensor throws + # RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely + # protect the tensor against unexpected reallocation, here we use a + # temporal tensor associated with the same storage without shifting as + # a workaround. + # + # Issue: https://github.com/pytorch/pytorch/issues/27366 + # + tensor = tensor.new_empty([0]).set_(tensor.storage()) + + # Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream + tensor.record_stream(as_cuda(stream)) # type: ignore[arg-type] + + +def is_cuda(stream: AbstractStream) -> bool: + """Returns ``True`` if the given stream is a valid CUDA stream.""" + return stream is not CPUStream + + +def as_cuda(stream: AbstractStream) -> torch.cuda.Stream: + """Casts the given stream as :class:`torch.cuda.Stream`.""" + return cast(torch.cuda.Stream, stream) diff --git a/torch/distributed/pipeline/sync/worker.py b/torch/distributed/pipeline/sync/worker.py new file mode 100644 index 0000000000000..81a588071c2e3 --- /dev/null +++ b/torch/distributed/pipeline/sync/worker.py @@ -0,0 +1,151 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Multithreading in pipeline parallelism.""" +from contextlib import contextmanager +from queue import Queue +import sys +from threading import Thread +from types import TracebackType +from typing import TYPE_CHECKING, Callable, Dict, Generator, List, Optional, Tuple, Type, Union, cast + +import torch + +from .microbatch import Batch +from .stream import AbstractStream, use_device, use_stream + +__all__: List[str] = [] + + +ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] + +# Queue is generic only in stubs. +# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime +if TYPE_CHECKING: + InQueue = Queue[Optional["Task"]] + OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]] +else: + InQueue = Queue + OutQueue = Queue + + +class Task: + """A task represents how to compute a micro-batch on a partition. + + It consists of two parts: :meth:`compute` and :meth:`finalize`. + :meth:`compute` should be executed in worker threads concurrently. + :meth:`finalize` should be executed after when worker threads complete to + execute :meth:`compute`. + + :meth:`compute` might be boosted by worker threads. Because it produces + several CUDA API calls by user code. In PyTorch, parallel CUDA API calls + are not serialized through GIL. So more than one CUDA API call can be + produced at the same time. + + """ + + def __init__( + self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]], + ) -> None: + self.stream = stream + self._compute = compute + self._finalize = finalize + self._grad_enabled = torch.is_grad_enabled() + + def compute(self) -> Batch: + with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled): + return self._compute() + + def finalize(self, batch: Batch) -> None: + if self._finalize is None: + return + with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled): + self._finalize(batch) + + +def worker(in_queue: InQueue, out_queue: OutQueue, device: torch.device) -> None: + """The main loop of a worker thread.""" + with use_device(device): + while True: + task = in_queue.get() + + if task is None: + break + + try: + batch = task.compute() + except Exception: + exc_info = cast(ExcInfo, sys.exc_info()) + out_queue.put((False, exc_info)) + continue + + out_queue.put((True, (task, batch))) + + done = (False, None) + out_queue.put(done) + + +def create_workers(devices: List[torch.device],) -> Tuple[List[InQueue], List[OutQueue]]: + """Spawns worker threads. A worker thread is bound to a device.""" + in_queues: List[InQueue] = [] + out_queues: List[OutQueue] = [] + + # Spawn workers. + workers: Dict[torch.device, Tuple[InQueue, OutQueue]] = {} + + def normalize_device(device: torch.device) -> torch.device: + if device.type == "cuda" and device.index is None: + return torch.device("cuda", index=torch.cuda.current_device()) + + if device.type == "cpu" and device.index is not None: + return torch.device("cpu") + + return device + + for device in devices: + device = normalize_device(device) + + try: + in_queue, out_queue = workers[device] + except KeyError: + in_queue = Queue() + out_queue = Queue() + workers[device] = (in_queue, out_queue) + + t = Thread(target=worker, args=(in_queue, out_queue, device), daemon=True,) + t.start() + + in_queues.append(in_queue) + out_queues.append(out_queue) + + return (in_queues, out_queues) + + +def join_workers(in_queues: List[InQueue], out_queues: List[OutQueue]) -> None: + # Close workers. + for in_queue in set(in_queues): + in_queue.put(None) + + # Join running workers. + running = set(out_queues) + while running: + out_queue = running.pop() + ok, payload = out_queue.get() + + done = (False, None) + if (ok, payload) == done: + continue + + running.add(out_queue) + + +@contextmanager +def spawn_workers(devices: List[torch.device],) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]: + try: + (in_queues, out_queues) = create_workers(devices) + yield (in_queues, out_queues) + finally: + join_workers(in_queues, out_queues) diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index 292634580aab2..3f04fa142a4b5 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -1,15 +1,17 @@ try: from urllib.parse import urlparse, urlunparse except ImportError: - from urlparse import urlparse, urlunparse + raise ImportError("urllib cannot be found, urlparse from python2 is no longer supported.") import torch._six as six import numbers import os -from . import FileStore, TCPStore +import sys +from datetime import timedelta +from typing import Optional, Dict, Union +from torch._C._distributed_c10d import FileStore, TCPStore from .constants import default_pg_timeout - _rendezvous_handlers = {} @@ -29,7 +31,7 @@ def register_rendezvous_handler(scheme, handler): Pick a unique name and use the URL scheme to identify it when calling the `rendezvous()` function. - Arguments: + Args: scheme (str): URL scheme to identify your rendezvous handler. handler (function): Handler that is invoked when the `rendezvous()` function is called with a URL that uses @@ -44,7 +46,7 @@ def register_rendezvous_handler(scheme, handler): _rendezvous_handlers[scheme] = handler -def rendezvous(url, rank=-1, world_size=-1, **kwargs): +def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs): if not isinstance(url, six.string_classes): raise RuntimeError("`url` must be a string. {}: {}".format(type(url), url)) @@ -57,8 +59,9 @@ def rendezvous(url, rank=-1, world_size=-1, **kwargs): # Append node-specific arguments. result = urlparse(url) if rank != -1 or world_size != -1: - query_dict = dict( - pair.split("=") for pair in filter(None, result.query.split("&")) + query_dict: Dict[str, Union[int, str]] = dict( + # mypy doesn't allow dict() to accept List of values (#257) + pair.split("=") for pair in filter(None, result.query.split("&")) # type: ignore[arg-type, misc] ) assert ( "rank" not in query_dict and "world_size" not in query_dict @@ -84,15 +87,21 @@ def _rendezvous_error(msg): return ValueError("Error initializing torch.distributed using " + msg) -def _file_rendezvous_handler(url, **kwargs): +def _file_rendezvous_handler(url: str, **kwargs): def _error(msg): return _rendezvous_error("file:// rendezvous: " + msg) result = urlparse(url) path = result.path + if sys.platform == 'win32': + import urllib.request + path = urllib.request.url2pathname(result.path) + if not path: raise _error("path missing") - query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) + query: Dict[str, str] + # mypy doesn't allow dict() to accept List of values (#257) + query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) # type: ignore[misc, arg-type] if "rank" not in query: raise _error("rank parameter missing") if "world_size" not in query: @@ -107,14 +116,16 @@ def _error(msg): raise RuntimeError("Unable to perform rerendezvous using file:// method") -def _tcp_rendezvous_handler(url, timeout=default_pg_timeout, **kwargs): +def _tcp_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs): def _error(msg): return _rendezvous_error("tcp:// rendezvous: " + msg) result = urlparse(url) if not result.port: raise _error("port number missing") - query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) + query: Dict[str, Union[int, str]] + # mypy doesn't allow dict() to accept List of values (#257) + query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) # type: ignore[misc, arg-type] if "rank" not in query: raise _error("rank parameter missing") if "world_size" not in query: @@ -123,6 +134,7 @@ def _error(msg): rank = int(query["rank"]) world_size = int(query["world_size"]) start_daemon = rank == 0 + assert result.hostname is not None store = TCPStore(result.hostname, result.port, world_size, start_daemon, timeout) yield (store, rank, world_size) @@ -130,7 +142,7 @@ def _error(msg): raise RuntimeError("Unable to perform rerendezvous using tcp:// method") -def _env_rendezvous_handler(url, timeout=default_pg_timeout, **kwargs): +def _env_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs): def _error(msg): return _rendezvous_error("env:// rendezvous: " + msg) @@ -138,7 +150,13 @@ def _env_error(var): return _error("environment variable %s expected, but not set" % var) result = urlparse(url) - query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) + query: Dict[str, Union[int, str]] + # mypy doesn't allow dict() to accept List of values (#257) + query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) # type: ignore[misc, arg-type] + + rank: Optional[Union[str, int]] + world_size: Optional[Union[str, int]] + master_port: Optional[Union[str, int]] if "rank" in query: rank = int(query["rank"]) @@ -175,7 +193,6 @@ def _env_error(var): # If this configuration is invalidated, there is nothing we can do about it raise RuntimeError("Unable to perform rerendezvous using env:// method") - -register_rendezvous_handler("file", _file_rendezvous_handler) register_rendezvous_handler("tcp", _tcp_rendezvous_handler) register_rendezvous_handler("env", _env_rendezvous_handler) +register_rendezvous_handler("file", _file_rendezvous_handler) diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py index 4598c78e72fe9..77bed89e33b55 100644 --- a/torch/distributed/rpc/__init__.py +++ b/torch/distributed/rpc/__init__.py @@ -1,6 +1,7 @@ import logging import threading +from typing import Generator, Tuple import torch import torch.distributed as dist @@ -19,15 +20,47 @@ def is_available(): raise RuntimeError("Failed to initialize torch.distributed.rpc") - - if is_available(): - from . import api, backend_registry, functions, _set_profiler_node_id - from . import ( + from . import api, backend_registry, functions + from torch._C._distributed_rpc import ( _disable_jit_rref_pickle, _enable_jit_rref_pickle, + _disable_server_process_global_profiler, + _enable_server_process_global_profiler, _set_and_start_rpc_agent, + _reset_current_rpc_agent, + _delete_all_user_and_unforked_owner_rrefs, + _destroy_rref_context, + _set_profiler_node_id, + _is_current_rpc_agent_set, + _rref_context_get_debug_info, + _cleanup_python_rpc_handler, + _invoke_rpc_builtin, + _invoke_rpc_python_udf, + _invoke_rpc_torchscript, + _invoke_remote_builtin, + _invoke_remote_python_udf, + _invoke_remote_torchscript, + _set_rpc_timeout, + _get_current_rpc_agent, + get_rpc_timeout, + enable_gil_profiling, + RpcBackendOptions, + _TensorPipeRpcBackendOptionsBase, + ProcessGroupRpcBackendOptions, + RpcAgent, + PyRRef, + ProcessGroupAgent, + TensorPipeAgent, + RemoteProfilerManager, + WorkerInfo, + _DEFAULT_INIT_METHOD, + _DEFAULT_NUM_SEND_RECV_THREADS, + _DEFAULT_NUM_WORKER_THREADS, + _UNSET_RPC_TIMEOUT, + _DEFAULT_RPC_TIMEOUT_SEC, ) # noqa: F401 + from torch._C._distributed_c10d import Store from .api import * # noqa: F401 from .options import TensorPipeRpcBackendOptions # noqa: F401 from .backend_registry import BackendType @@ -38,6 +71,7 @@ def is_available(): import numbers + rendezvous_iterator: Generator[Tuple[Store, int, int], None, None] def init_rpc( name, @@ -51,16 +85,16 @@ def init_rpc( and distributed autograd, which immediately makes the current process ready to send and receive RPCs. - Arguments: + Args: + name (str): a globally unique name of this node. (e.g., + ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``) + Name can only contain number, alphabet, underscore, colon, + and/or dash, and must be shorter than 128 characters. backend (BackendType, optional): The type of RPC backend implementation. Supported values include ``BackendType.TENSORPIPE`` (the default) and ``BackendType.PROCESS_GROUP``. See :ref:`rpc-backends` for more information. - name (str): a globally unique name of this node. (e.g., - ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``) - Name can only contain number, alphabet, underscore, colon, - and/or dash, and must be shorter than 128 characters. rank (int): a globally unique id/rank of this node. world_size (int): The number of workers in the group. rpc_backend_options (RpcBackendOptions, optional): The options @@ -106,18 +140,19 @@ def init_rpc( raise TypeError( f"Could not infer backend for options {rpc_backend_options}" ) - if backend != BackendType.TENSORPIPE: + # Ignore type error because mypy doesn't handle dynamically generated type objects (#4865) + if backend != BackendType.TENSORPIPE: # type: ignore[attr-defined] logger.warning( - f"RPC was initialized with no explicit backend but with options " + f"RPC was initialized with no explicit backend but with options " # type: ignore[attr-defined] f"corresponding to {backend}, hence that backend will be used " f"instead of the default {BackendType.TENSORPIPE}. To silence this " f"warning pass `backend={backend}` explicitly." ) if backend is None: - backend = BackendType.TENSORPIPE + backend = BackendType.TENSORPIPE # type: ignore[attr-defined] - if backend == BackendType.PROCESS_GROUP: + if backend == BackendType.PROCESS_GROUP: # type: ignore[attr-defined] logger.warning( "RPC was initialized with the PROCESS_GROUP backend which is " "deprecated and slated to be removed and superseded by the TENSORPIPE " @@ -178,7 +213,7 @@ def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_optio def _init_rpc_backend( - backend=backend_registry.BackendType.TENSORPIPE, + backend=BackendType.TENSORPIPE, # type: ignore[attr-defined] store=None, name=None, rank=-1, @@ -206,7 +241,6 @@ def _init_rpc_backend( @api._require_initialized def _get_debug_info(): - from . import _rref_context_get_debug_info info = _rref_context_get_debug_info() info.update(api._get_current_rpc_agent().get_debug_info()) info.update(dist_autograd._get_debug_info()) diff --git a/torch/distributed/rpc/_testing/__init__.py b/torch/distributed/rpc/_testing/__init__.py index 25c5195e45882..88a62fd0e1a58 100644 --- a/torch/distributed/rpc/_testing/__init__.py +++ b/torch/distributed/rpc/_testing/__init__.py @@ -12,3 +12,7 @@ def is_available(): if is_available(): # Registers FAULTY_PROCESS_GROUP RPC backend. from . import faulty_agent_backend_registry + from torch._C._distributed_rpc_testing import ( + FaultyProcessGroupRpcBackendOptions, + FaultyProcessGroupAgent, + ) diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py index d1b62a5b0ab40..bc93271810779 100644 --- a/torch/distributed/rpc/api.py +++ b/torch/distributed/rpc/api.py @@ -4,11 +4,11 @@ import inspect import logging import threading -from typing import Generic, TypeVar +from typing import Generic, TypeVar, Set, Any import torch -from . import ( +from torch._C._distributed_rpc import ( PyRRef, RemoteProfilerManager, WorkerInfo, @@ -99,10 +99,10 @@ def __init__(self): # States used by `def _all_gather()`. # `_ALL_WORKER_NAMES` is initialized on initiaizing RPC layer. -_ALL_WORKER_NAMES = None +_ALL_WORKER_NAMES: Set[Any] = set() _all_gather_dict_lock = threading.RLock() _all_gather_sequence_id = 0 -_all_gather_sequence_id_to_states = collections.defaultdict(AllGatherStates) +_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates) def _init_rpc_states(agent): @@ -141,6 +141,34 @@ def _broadcast_to_followers(sequence_id, objects_map): states.gathered_objects = objects_map states.proceed_signal.set() +_thread_local_var = threading.local() + +@contextlib.contextmanager +def _wait_all(): + r""" + A context manager that collects all futures returned by ``rpc_async`` and + waits them on the context manager's exit; relieving the user of needing + to explicitly call wait. + + + Example:: + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> with rpc._wait_all(): + >>> fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) + >>> fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) + >>> #fut_1 and fut_2 are waited on + """ + _thread_local_var.future_list = [] + try: + yield + finally: + try: + torch.futures.wait_all(_thread_local_var.future_list) + finally: + del _thread_local_var.future_list @_require_initialized def _all_gather(obj, timeout=UNSET_RPC_TIMEOUT): @@ -245,7 +273,7 @@ def shutdown(graceful=True): :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not be called after ``shutdown()``. - Arguments: + Args: graceful (bool): Whether to do a graceful shutdown or not. If True, this will 1) wait until there is no pending system messages for ``UserRRefs`` and delete them; 2) block @@ -309,7 +337,7 @@ def get_worker_info(worker_name=None): Use this :class:`~torch.distributed.rpc.WorkerInfo` to avoid passing an expensive string on every invocation. - Arguments: + Args: worker_name (str): the string name of a worker. If ``None``, return the the id of the current worker. (default ``None``) @@ -324,24 +352,25 @@ def get_worker_info(worker_name=None): return _get_current_rpc_agent().get_worker_info() -def _to_worker_info(name_or_info): - if isinstance(name_or_info, WorkerInfo): - return name_or_info - elif isinstance(name_or_info, str): - return get_worker_info(name_or_info) +def _to_worker_info(to): + if isinstance(to, WorkerInfo): + return to + elif isinstance(to, str) or isinstance(to, int): + return get_worker_info(to) else: - raise ValueError("Cannot get WorkerInfo from name {}".format(name_or_info)) + raise ValueError("Cannot get WorkerInfo from name {}".format(to)) def _rref_typeof_on_owner(rref): return type(rref.local_value()) -def _rref_typeof_on_user(rref): +def _rref_typeof_on_user(rref, timeout=UNSET_RPC_TIMEOUT): return rpc_sync( rref.owner(), _rref_typeof_on_owner, - args=(rref,) + args=(rref,), + timeout=timeout ) @@ -351,16 +380,18 @@ def _rref_typeof_on_user(rref): try: # Combine the implementation class and the type class. - class RRef(PyRRef, GenericWithOneTypeVar): + class RRef(PyRRef, Generic[T]): pass -except TypeError as exc: +except TypeError: # TypeError: metaclass conflict: the metaclass of a derived class # must be a (non-strict) subclass of the metaclasses of all its bases - class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__): + # Mypy doesn't understand __class__ (mypy bug #4177) + class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__): # type: ignore pass # Combine the implementation class and the type class. - class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta): + # Types for classes expecting a certain generic parameter (mypy bug #7791) + class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta): # type: ignore pass @@ -416,8 +447,8 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): :class:`~torch.distributed.rpc.RRef` is only destructed when globally there are no living references to it. - Arguments: - to (str or WorkerInfo): id or name of the destination worker. + Args: + to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. func (callable): a callable function, such as Python callables, builtin operators (e.g. :meth:`~torch.add`) and annotated TorchScript functions. @@ -536,7 +567,8 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): dst_worker_info.name, ) RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key) - ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) + # Mypy doesn't support re-def of a variable not in the same block (#1174) + ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) # type: ignore[assignment] with ctx_manager as rf: args = args if args else () @@ -611,7 +643,8 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RP dst_worker_info.name, ) RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key) - ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) + # Mypy doesn't support re-def of a variable not in the same block (#1174) + ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) # type: ignore[assignment] with ctx_manager as rf: args = args if args else () @@ -671,8 +704,8 @@ def rpc_sync(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): messages are sent and received in parallel to execution of Python code. This method is thread-safe. - Arguments: - to (str or WorkerInfo): id or name of the destination worker. + Args: + to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. func (callable): a callable function, such as Python callables, builtin operators (e.g. :meth:`~torch.add`) and annotated TorchScript functions. @@ -750,8 +783,8 @@ def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): method is thread-safe. This method will immediately return a :class:`~torch.futures.Future` that can be awaited on. - Arguments: - to (str or WorkerInfo): id or name of the destination worker. + Args: + to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. func (callable): a callable function, such as Python callables, builtin operators (e.g. :meth:`~torch.add`) and annotated TorchScript functions. @@ -830,4 +863,7 @@ def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): >>> rpc.init_rpc("worker1", rank=1, world_size=2) >>> rpc.shutdown() """ - return _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout) + fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout) + if hasattr(_thread_local_var, "future_list"): + _thread_local_var.future_list.append(fut) + return fut diff --git a/torch/distributed/rpc/backend_registry.py b/torch/distributed/rpc/backend_registry.py index 6dac7cb0863af..3e3001c519cec 100644 --- a/torch/distributed/rpc/backend_registry.py +++ b/torch/distributed/rpc/backend_registry.py @@ -28,15 +28,17 @@ def _backend_type_repr(self): """ # Create an enum type, `BackendType`, with empty members. -BackendType = enum.Enum(value="BackendType", names={}) -BackendType.__repr__ = _backend_type_repr +# Can't handle Function Enum API (mypy bug #9079) +BackendType = enum.Enum(value="BackendType", names=dict()) # type: ignore[misc] +# Unable to assign a function a method (mypy bug #2427) +BackendType.__repr__ = _backend_type_repr # type: ignore[assignment] BackendType.__doc__ = _backend_type_doc def backend_registered(backend_name): """ Checks if backend_name is registered as an RPC backend. - Arguments: + Args: backend_name (str): string to identify the RPC backend. Returns: True if the backend has been registered with ``register_backend``, else @@ -50,7 +52,7 @@ def register_backend( ): """Registers a new RPC backend. - Arguments: + Args: backend_name (str): backend string to identify the handler. construct_rpc_backend_options_handler (function): Handler that is invoked when @@ -73,8 +75,10 @@ def register_backend( }, **existing_enum_dict ) - BackendType = enum.Enum(value="BackendType", names=extended_enum_dict) - BackendType.__repr__ = _backend_type_repr + # Can't handle Function Enum API (mypy bug #9079) + BackendType = enum.Enum(value="BackendType", names=extended_enum_dict) # type: ignore[misc] + # Unable to assign a function a method (mypy bug #2427) + BackendType.__repr__ = _backend_type_repr # type: ignore[assignment] BackendType.__doc__ = _backend_type_doc return BackendType[backend_name] @@ -252,6 +256,18 @@ def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_ ) ) + if torch.cuda.is_available(): + # It's necessary to initialize PyTorch CUDA states here (e.g., + # CUDACachingAllocator). If this is missing, we could hit errors like + # "allocator not initialized", because other processes might send + # CUDA-related RPC request to this process before user code in this + # process initializes its PyTorch CUDA states. + torch.cuda.init() + # FIXME: this is needed for now because TensorPipe calls + # cudaPointerGetAttributes() on the default device. + # This error was also reported in https://github.com/pytorch/pytorch/issues/36594 + torch.zeros([1], device="cuda:0") + # The agent's join method is required to behave like a barrier and perform # collective operations, for which it relies on a process group, instead of # re-implementing this on top of RPCs. diff --git a/torch/distributed/rpc/constants.py b/torch/distributed/rpc/constants.py index c2dd804e4c811..e6d79e6e5981c 100644 --- a/torch/distributed/rpc/constants.py +++ b/torch/distributed/rpc/constants.py @@ -1,6 +1,6 @@ from datetime import timedelta -from . import ( +from torch._C._distributed_rpc import ( _DEFAULT_INIT_METHOD, _DEFAULT_NUM_SEND_RECV_THREADS, _DEFAULT_NUM_WORKER_THREADS, @@ -10,16 +10,16 @@ # For any RpcAgent. -DEFAULT_RPC_TIMEOUT_SEC = _DEFAULT_RPC_TIMEOUT_SEC -DEFAULT_INIT_METHOD = _DEFAULT_INIT_METHOD -DEFAULT_SHUTDOWN_TIMEOUT = 5.0 +DEFAULT_RPC_TIMEOUT_SEC: float = _DEFAULT_RPC_TIMEOUT_SEC +DEFAULT_INIT_METHOD: str = _DEFAULT_INIT_METHOD +DEFAULT_SHUTDOWN_TIMEOUT: float = 5.0 # For ProcessGroupAgent. -DEFAULT_NUM_SEND_RECV_THREADS = _DEFAULT_NUM_SEND_RECV_THREADS +DEFAULT_NUM_SEND_RECV_THREADS: int = _DEFAULT_NUM_SEND_RECV_THREADS # For TensorPipeAgent. -DEFAULT_NUM_WORKER_THREADS = _DEFAULT_NUM_WORKER_THREADS +DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS # Ensure that we don't time out when there are long periods of time without # any operations against the underlying ProcessGroup. -DEFAULT_PROCESS_GROUP_TIMEOUT = timedelta(milliseconds=2 ** 31 - 1) +DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2 ** 31 - 1) # Value indicating that timeout is not set for RPC call, and the default should be used. -UNSET_RPC_TIMEOUT = _UNSET_RPC_TIMEOUT +UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT diff --git a/torch/distributed/rpc/functions.py b/torch/distributed/rpc/functions.py index 5c807741f5f92..f0d106c53844a 100644 --- a/torch/distributed/rpc/functions.py +++ b/torch/distributed/rpc/functions.py @@ -17,8 +17,9 @@ def async_execution(fn): :meth:`~torch.distributed.rpc.rpc_async` or waiting for other signals. .. note:: To enable asynchronous execution, applications must pass the - function object returned by this decorator to RPC APIs. Otherwise, RPC - will not be able to detect the attributes installed by this decorator. + function object returned by this decorator to RPC APIs. If RPC detected + attributes installed by this decorator, it knows that this function + returns a ``Future`` object and will handle that accordingly. However, this does not mean this decorator has to be outmost one when defining a function. For example, when combined with ``@staticmethod`` or ``@classmethod``, ``@rpc.functions.async_execution`` needs to be the @@ -27,14 +28,14 @@ def async_execution(fn): because, when accessed, the static or class method preserves attributes installed by ``@rpc.functions.async_execution``. - .. warning:: `autograd profiler `_ - does not work with ``async_execution`` functions. Example:: The returned :class:`~torch.futures.Future` object can come from - ``rpc.rpc_async``, ``Future.then(cb)``, or :class:`~torch.futures.Future` + :meth:`~torch.distributed.rpc.rpc_async`, + :meth:`~torch.futures.Future.then`, or :class:`~torch.futures.Future` constructor. The example below shows directly using the - :class:`~torch.futures.Future` returned by ``Future.then(cb)``. + :class:`~torch.futures.Future` returned by + :meth:`~torch.futures.Future.then`. >>> from torch.distributed import rpc >>> @@ -159,5 +160,6 @@ def async_execution(fn): @functools.wraps(fn) def wrapper(*args, **kwargs): return fn(*args, **kwargs) - wrapper._wrapped_async_rpc_function = fn + # Can't declare and use attributes of function objects (mypy#2087) + wrapper._wrapped_async_rpc_function = fn # type: ignore[attr-defined] return wrapper diff --git a/torch/distributed/rpc/internal.py b/torch/distributed/rpc/internal.py index c0fbb2dc1c296..3c1c913b38f89 100644 --- a/torch/distributed/rpc/internal.py +++ b/torch/distributed/rpc/internal.py @@ -2,6 +2,7 @@ import copyreg import io import pickle +import sys import threading import traceback from enum import Enum @@ -9,7 +10,7 @@ import torch import torch.distributed as dist -from . import _get_current_rpc_agent +from torch._C._distributed_rpc import _get_current_rpc_agent # Thread local tensor tables to store tensors while pickling torch.Tensor @@ -36,7 +37,8 @@ class _InternalRPCPickler: """ def __init__(self): - self._dispatch_table = copyreg.dispatch_table.copy() + # Ignore type error because dispatch_table is defined in third-party package + self._dispatch_table = copyreg.dispatch_table.copy() # type: ignore[attr-defined] self._dispatch_table[torch.Tensor] = self._tensor_reducer @classmethod @@ -61,6 +63,24 @@ def _py_rref_reducer(self, py_rref): def _rref_reducer(self, rref): return self._py_rref_reducer(rref) + @classmethod + def _script_module_receiver(cls, script_module_serialized): + """ + Given a serialized representation of a ScriptModule created with torch.jit.save, + loads and returns the ScriptModule. + """ + f = io.BytesIO(script_module_serialized) + m = torch.jit.load(f) + return m + + def _script_module_reducer(self, script_module): + """ + Serializes a ScriptModule. + """ + f = io.BytesIO() + torch.jit.save(script_module, f) + return (_InternalRPCPickler._script_module_receiver, (f.getvalue(),)) + def serialize(self, obj): r""" Serialize non tensor data into binary string, tensor data into @@ -79,9 +99,15 @@ def serialize(self, obj): # # The return value of a `rpc.remote(..)` call is type of `rpc.PyRRef`. # The deserialized RRef object on an RPC receiver side is type of `rpc.PyRRef`. - p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer + # Ignore type error because dispatch_table is defined in third-party package + p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer # type: ignore[index] # An RRef created locally by RRef Python constructor is type of `rpc.RRef`. - p.dispatch_table[dist.rpc.RRef] = self._rref_reducer + # Ignore type error because dispatch_table is defined in third-party package + p.dispatch_table[dist.rpc.RRef] = self._rref_reducer # type: ignore[index] + # Add dispatch pickling for ScriptModule if needed. + if isinstance(obj, torch.jit.ScriptModule): + # Ignore type error because dispatch_table is defined in third-party package + p.dispatch_table[obj.__class__] = self._script_module_reducer # type: ignore[index] # save _thread_local_tensor_tables.send_tables if it is in nested call global _thread_local_tensor_tables @@ -168,13 +194,14 @@ def _run_function(python_udf): f"On {_get_current_rpc_agent().get_worker_info()}:\n" f"{repr(e)}\n{traceback.format_exc()}" ) + print(except_str, file=sys.stderr) result = RemoteException(except_str, type(e)) return result def _handle_exception(result): if isinstance(result, RemoteException): - raise result.exception_type(result.msg) + raise result.exception_type(result.msg.encode("utf-8").decode("unicode_escape")) def _build_rpc_profiling_key( @@ -184,7 +211,7 @@ def _build_rpc_profiling_key( Builds the key that RPC calls are profiled with using the autograd profiler. This will be the name of the corresponding Event recorded in the profiler. - Arguments: + Args: exec_type (RPCExecMode): Type of RPC/RRef call func_name (str): Name of function being profiled. current_worker_name (str): Name of current worker. @@ -209,7 +236,7 @@ def _start_record_function(exec_type, func_name, current_worker_name, dest_worke callbacks that start the profiling, though the user is responsible for running the appropriate callbacks when the function to be profiled finishes. - Arguments: + Args: exec_type (RPCExecMode): Type of RPC/RRef call func_name (str): Name of function being profiled. current_worker_name (str): Name of current worker. @@ -222,8 +249,8 @@ def _start_record_function(exec_type, func_name, current_worker_name, dest_worke profile_key = "rpc_{}#{}({} -> {})".format( exec_type.value, str(func_name), current_worker_name, dest_worker_name ) - rf = torch.autograd._RecordFunction() - torch.autograd._run_before_callbacks(rf, profile_key) + rf = torch.autograd._RecordFunction() # type: ignore[attr-defined] + torch.autograd._run_before_callbacks(rf, profile_key) # type: ignore[attr-defined] return rf diff --git a/torch/distributed/rpc/options.py b/torch/distributed/rpc/options.py index 149a2544d2170..19092c283efb0 100644 --- a/torch/distributed/rpc/options.py +++ b/torch/distributed/rpc/options.py @@ -1,4 +1,4 @@ -from . import _TensorPipeRpcBackendOptionsBase +from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase from . import constants as rpc_contants import torch @@ -12,7 +12,7 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase): :class:`~torch.distributed.rpc.TensorPipeAgent`, derived from :class:`~torch.distributed.rpc.RpcBackendOptions`. - Arguments: + Args: num_worker_threads (int, optional): The number of threads in the thread-pool used by :class:`~torch.distributed.rpc.TensorPipeAgent` to execute @@ -58,7 +58,7 @@ def set_device_map(self, to: str, device_map: Dict): function can be called multiple times to incrementally add device placement configurations. - Arguments: + Args: worker_name (str): Callee name. device_map (Dict of int, str, or torch.device): Device placement mappings from this worker to the callee. This map must be diff --git a/torch/distributed/rpc/rref_proxy.py b/torch/distributed/rpc/rref_proxy.py index 17ce9da643b96..9c01052ae246b 100644 --- a/torch/distributed/rpc/rref_proxy.py +++ b/torch/distributed/rpc/rref_proxy.py @@ -3,6 +3,7 @@ from . import functions import torch +from .constants import UNSET_RPC_TIMEOUT def _local_invoke(rref, func_name, args, kwargs): return getattr(rref.local_value(), func_name)(*args, **kwargs) @@ -11,17 +12,17 @@ def _local_invoke(rref, func_name, args, kwargs): def _local_invoke_async_execution(rref, func_name, args, kwargs): return getattr(rref.local_value(), func_name)(*args, **kwargs) -def _invoke_rpc(rref, rpc_api, func_name, *args, **kwargs): - rref_type = rref._get_type() +def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs): + # Since rref._get_type can potentially issue an RPC, it should respect the + # passed in timeout here. + rref_type = rref._get_type(timeout=timeout) _invoke_func = _local_invoke - if rref_type is not torch._C.ScriptModule: - if not hasattr(rref_type, func_name): - raise ValueError( - f"Function {func_name} is not an attribute of type {rref_type} " - f"referenced by RRef {rref}." - ) - + # Bypass ScriptModules when checking for async function attribute. + bypass_type = issubclass(rref_type, torch.jit.ScriptModule) or issubclass( + rref_type, torch._C.ScriptModule + ) + if not bypass_type: func = getattr(rref_type, func_name) if hasattr(func, "_wrapped_async_rpc_function"): _invoke_func = _local_invoke_async_execution @@ -29,14 +30,17 @@ def _invoke_rpc(rref, rpc_api, func_name, *args, **kwargs): return rpc_api( rref.owner(), _invoke_func, - args=(rref, func_name, args, kwargs) + args=(rref, func_name, args, kwargs), + timeout=timeout ) - +# This class manages proxied RPC API calls for RRefs. It is entirely used from +# C++ (see python_rpc_handler.cpp). class RRefProxy: - def __init__(self, rref, rpc_api): + def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT): self.rref = rref self.rpc_api = rpc_api + self.rpc_timeout = timeout def __getattr__(self, func_name): - return partial(_invoke_rpc, self.rref, self.rpc_api, func_name) + return partial(_invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout) diff --git a/torch/distributed/rpc/server_process_global_profiler.py b/torch/distributed/rpc/server_process_global_profiler.py index 6114f66ce6cb1..d8de89bfc9371 100644 --- a/torch/distributed/rpc/server_process_global_profiler.py +++ b/torch/distributed/rpc/server_process_global_profiler.py @@ -22,7 +22,7 @@ class _server_process_global_profile(profile): only report runtime of PyTorch functions. Note: profiler is thread local and is automatically propagated into the async tasks - Arguments: + Args: enabled (bool, optional): Setting this to False makes this context manager a no-op. Default: ``True``. @@ -103,7 +103,7 @@ def __enter__(self): if not self.enabled: return - if self.entered: + if self.entered: # type: ignore[has-type] raise RuntimeError("autograd profiler traces are not reentrant") self.entered = True @@ -113,8 +113,11 @@ def __enter__(self): else torch.autograd.ProfilerState.CPU ) profiler_config = torch.autograd.ProfilerConfig( - profiler_kind, self.record_shapes, self.profile_memory - ) + profiler_kind, + self.record_shapes, + self.profile_memory, + False, + False) _enable_server_process_global_profiler(profiler_config) return self @@ -143,13 +146,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): process_global_function_events = [] for thread_local_events in process_global_events: # Parse from ``Event``s to ``FunctionEvent``s. - thread_local_function_events = torch.autograd.profiler.parse_cpu_trace( + thread_local_function_events = torch.autograd.profiler.parse_legacy_records( thread_local_events ) thread_local_function_events.sort( key=lambda function_event: [ - function_event.cpu_interval.start, - -(function_event.cpu_interval.end), + function_event.time_range.start, + -(function_event.time_range.end), ] ) process_global_function_events.append(thread_local_function_events) @@ -162,6 +165,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): use_cuda=self.use_cuda, profile_memory=self.profile_memory, ) + self.function_events._build_tree() self.process_global_function_events = process_global_function_events diff --git a/torch/distributed/rpc/utils.py b/torch/distributed/rpc/utils.py new file mode 100644 index 0000000000000..40585a73521d5 --- /dev/null +++ b/torch/distributed/rpc/utils.py @@ -0,0 +1,37 @@ +def _parse_remote_device(remote_device: str): + r""" + Parses the remote device. + + Args: + remote_device (str): Device on the destination worker where we‘d like to place this module. + The format should be "/", where the device field can be parsed as torch.device type. + E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + In addition, the device field can be optional and the default value is "cpu". + + Returns: + A workername and a device. + """ + fields = remote_device.split("/") + if len(fields) == 2: + [on, device] = fields + elif len(fields) == 1: + on = fields[0] + device = "cpu" + else: + raise RuntimeError( + "Could not parse remote_device: {}. The valid format is '/'".format( + remote_device + ) + ) + + # Since the workername in the input remote device won't be validated until the created remote module is executed, + # only do some very basic sanity check on workername at the module creation time. + # As currently there is no regex to describe the format of workername, just check whether the workername is empty. + if not on: + raise RuntimeError( + "The workername in remote_device '{}' cannot be empty. The valid format is '/'".format( + remote_device + ) + ) + + return on, device diff --git a/torch/distributions/__init__.py b/torch/distributions/__init__.py index 4d7a4bff96afe..dd963ab6f7a4e 100644 --- a/torch/distributions/__init__.py +++ b/torch/distributions/__init__.py @@ -91,7 +91,9 @@ from .half_normal import HalfNormal from .independent import Independent from .kl import kl_divergence, register_kl +from .kumaraswamy import Kumaraswamy from .laplace import Laplace +from .lkj_cholesky import LKJCholesky from .log_normal import LogNormal from .logistic_normal import LogisticNormal from .lowrank_multivariate_normal import LowRankMultivariateNormal @@ -100,7 +102,7 @@ from .multivariate_normal import MultivariateNormal from .negative_binomial import NegativeBinomial from .normal import Normal -from .one_hot_categorical import OneHotCategorical +from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough from .pareto import Pareto from .poisson import Poisson from .relaxed_bernoulli import RelaxedBernoulli @@ -111,6 +113,7 @@ from .uniform import Uniform from .von_mises import VonMises from .weibull import Weibull +from . import transforms __all__ = [ 'Bernoulli', @@ -131,6 +134,7 @@ 'HalfCauchy', 'HalfNormal', 'Independent', + 'Kumaraswamy', 'Laplace', 'LogNormal', 'LogisticNormal', @@ -141,6 +145,7 @@ 'NegativeBinomial', 'Normal', 'OneHotCategorical', + 'OneHotCategoricalStraightThrough', 'Pareto', 'RelaxedBernoulli', 'RelaxedOneHotCategorical', diff --git a/torch/distributions/beta.py b/torch/distributions/beta.py index 7c017c133b32f..76cb6ae7029a2 100644 --- a/torch/distributions/beta.py +++ b/torch/distributions/beta.py @@ -1,4 +1,4 @@ -from numbers import Number +from numbers import Real, Number import torch from torch.distributions import constraints @@ -28,7 +28,7 @@ class Beta(ExponentialFamily): has_rsample = True def __init__(self, concentration1, concentration0, validate_args=None): - if isinstance(concentration1, Number) and isinstance(concentration0, Number): + if isinstance(concentration1, Real) and isinstance(concentration0, Real): concentration1_concentration0 = torch.tensor([float(concentration1), float(concentration0)]) else: concentration1, concentration0 = broadcast_all(concentration1, concentration0) diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py index af132b47c516c..dc2e8fc5bad60 100644 --- a/torch/distributions/binomial.py +++ b/torch/distributions/binomial.py @@ -1,4 +1,3 @@ -from numbers import Number import torch from torch.distributions import constraints from torch.distributions.distribution import Distribution @@ -42,18 +41,13 @@ def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): raise ValueError("Either `probs` or `logits` must be specified, but not both.") if probs is not None: self.total_count, self.probs, = broadcast_all(total_count, probs) - self.total_count = self.total_count.type_as(self.logits) - is_scalar = isinstance(self.probs, Number) + self.total_count = self.total_count.type_as(self.probs) else: self.total_count, self.logits, = broadcast_all(total_count, logits) self.total_count = self.total_count.type_as(self.logits) - is_scalar = isinstance(self.logits, Number) self._param = self.probs if probs is not None else self.logits - if is_scalar: - batch_shape = torch.Size() - else: - batch_shape = self._param.size() + batch_shape = self._param.size() super(Binomial, self).__init__(batch_shape, validate_args=validate_args) def expand(self, batch_shape, _instance=None): diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py index 01f3dd5201740..319d2dd01b669 100644 --- a/torch/distributions/categorical.py +++ b/torch/distributions/categorical.py @@ -16,14 +16,14 @@ class Categorical(Distribution): Samples are integers from :math:`\{0, \ldots, K-1\}` where `K` is ``probs.size(-1)``. - If :attr:`probs` is 1D with length-`K`, each element is the relative + If :attr:`probs` is 1-dimensional with length-`K`, each element is the relative probability of sampling the class at that index. - If :attr:`probs` is 2D, it is treated as a batch of relative probability - vectors. + If :attr:`probs` is N-dimensional, the first N-1 dimensions are treated as a batch of + relative probability vectors. .. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum, - and it will be normalized to sum to 1. + and it will be normalized to sum to 1 along the last dimension. See also: :func:`torch.multinomial` diff --git a/torch/distributions/cauchy.py b/torch/distributions/cauchy.py index 50be941e073aa..63181a2a6733d 100644 --- a/torch/distributions/cauchy.py +++ b/torch/distributions/cauchy.py @@ -69,8 +69,6 @@ def cdf(self, value): return torch.atan((value - self.loc) / self.scale) / math.pi + 0.5 def icdf(self, value): - if self._validate_args: - self._validate_sample(value) return torch.tan(math.pi * (value - 0.5)) * self.scale + self.loc def entropy(self): diff --git a/torch/distributions/constraint_registry.py b/torch/distributions/constraint_registry.py index 6587631c4cfec..4675b8ceaca8a 100644 --- a/torch/distributions/constraint_registry.py +++ b/torch/distributions/constraint_registry.py @@ -215,6 +215,12 @@ def _transform_to_lower_cholesky(constraint): return transforms.LowerCholeskyTransform() +@biject_to.register(constraints.corr_cholesky) +@transform_to.register(constraints.corr_cholesky) +def _transform_to_corr_cholesky(constraint): + return transforms.CorrCholeskyTransform() + + @biject_to.register(constraints.cat) def _biject_to_cat(constraint): return transforms.CatTransform([biject_to(c) diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index 7bcbc586434dd..87d72d52d26b3 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -3,13 +3,17 @@ - ``constraints.boolean`` - ``constraints.cat`` +- ``constraints.corr_cholesky`` - ``constraints.dependent`` - ``constraints.greater_than(lower_bound)`` +- ``constraints.greater_than_eq(lower_bound)`` - ``constraints.integer_interval(lower_bound, upper_bound)`` - ``constraints.interval(lower_bound, upper_bound)`` +- ``constraints.less_than(upper_bound)`` - ``constraints.lower_cholesky`` - ``constraints.lower_triangular`` - ``constraints.nonnegative_integer`` +- ``constraints.one_hot`` - ``constraints.positive`` - ``constraints.positive_definite`` - ``constraints.positive_integer`` @@ -26,6 +30,7 @@ 'Constraint', 'boolean', 'cat', + 'corr_cholesky', 'dependent', 'dependent_property', 'greater_than', @@ -56,6 +61,8 @@ class Constraint(object): A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized. """ + is_discrete = False + def check(self, value): """ Returns a byte tensor of `sample_shape + batch_shape` indicating @@ -102,14 +109,30 @@ class _Boolean(Constraint): """ Constrain to the two values `{0, 1}`. """ + is_discrete = True + def check(self, value): return (value == 0) | (value == 1) +class _OneHot(Constraint): + """ + Constrain to one-hot vectors. + """ + is_discrete = True + + def check(self, value): + is_boolean = (value == 0) | (value == 1) + is_normalized = value.sum(-1).eq(1) + return is_boolean.all(-1) & is_normalized + + class _IntegerInterval(Constraint): """ Constrain to an integer interval `[lower_bound, upper_bound]`. """ + is_discrete = True + def __init__(self, lower_bound, upper_bound): self.lower_bound = lower_bound self.upper_bound = upper_bound @@ -127,6 +150,8 @@ class _IntegerLessThan(Constraint): """ Constrain to an integer interval `(-inf, upper_bound]`. """ + is_discrete = True + def __init__(self, upper_bound): self.upper_bound = upper_bound @@ -143,6 +168,8 @@ class _IntegerGreaterThan(Constraint): """ Constrain to an integer interval `[lower_bound, inf)`. """ + is_discrete = True + def __init__(self, lower_bound): self.lower_bound = lower_bound @@ -275,6 +302,18 @@ def check(self, value): return lower_triangular & positive_diagonal +class _CorrCholesky(Constraint): + """ + Constrain to lower-triangular square matrices with positive diagonals and each + row vector being of unit length. + """ + def check(self, value): + tol = torch.finfo(value.dtype).eps * value.size(-1) * 10 # 10 is an adjustable fudge factor + row_norm = torch.linalg.norm(value.detach(), dim=-1) + unit_row_norm = (row_norm - 1.).abs().le(tol).all(dim=-1) + return _LowerCholesky().check(value) & unit_row_norm + + class _PositiveDefinite(Constraint): """ Constrain to positive-definite matrices. @@ -345,6 +384,7 @@ def check(self, value): dependent = _Dependent() dependent_property = _DependentProperty boolean = _Boolean() +one_hot = _OneHot() nonnegative_integer = _IntegerGreaterThan(0) positive_integer = _IntegerGreaterThan(1) integer_interval = _IntegerInterval @@ -360,6 +400,7 @@ def check(self, value): simplex = _Simplex() lower_triangular = _LowerTriangular() lower_cholesky = _LowerCholesky() +corr_cholesky = _CorrCholesky() positive_definite = _PositiveDefinite() cat = _Cat stack = _Stack diff --git a/torch/distributions/continuous_bernoulli.py b/torch/distributions/continuous_bernoulli.py index 180fbd8187ee3..5d3d488402030 100644 --- a/torch/distributions/continuous_bernoulli.py +++ b/torch/distributions/continuous_bernoulli.py @@ -168,8 +168,6 @@ def cdf(self, value): torch.where(torch.ge(value, 1.0), torch.ones_like(value), unbounded_cdfs)) def icdf(self, value): - if self._validate_args: - self._validate_sample(value) cut_probs = self._cut_probs() return torch.where( self._outside_unstable_region(), diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py index d1e3a39247128..a6c5cc994f9cc 100644 --- a/torch/distributions/distribution.py +++ b/torch/distributions/distribution.py @@ -2,6 +2,7 @@ import warnings from torch.distributions import constraints from torch.distributions.utils import lazy_property +from typing import Dict, Optional, Any class Distribution(object): @@ -11,12 +12,21 @@ class Distribution(object): has_rsample = False has_enumerate_support = False - _validate_args = False - support = None - arg_constraints = {} + _validate_args = __debug__ @staticmethod def set_default_validate_args(value): + """ + Sets whether validation is enabled or disabled. + + The default behavior mimics Python's ``assert`` statement: validation + is on by default, but is disabled if Python is run in optimized mode + (via ``python -O``). Validation may be expensive, so you may want to + disable it once a model is working. + + Args: + value (bool): Whether to enable validation. + """ if value not in [True, False]: raise ValueError Distribution._validate_args = value @@ -27,7 +37,14 @@ def __init__(self, batch_shape=torch.Size(), event_shape=torch.Size(), validate_ if validate_args is not None: self._validate_args = validate_args if self._validate_args: - for param, constraint in self.arg_constraints.items(): + try: + arg_constraints = self.arg_constraints + except NotImplementedError: + arg_constraints = {} + warnings.warn(f'{self.__class__} does not define `arg_constraints`. ' + + 'Please set `arg_constraints = {}` or initialize the distribution ' + + 'with `validate_args=False` to turn off validation.') + for param, constraint in arg_constraints.items(): if constraints.is_dependent(constraint): continue # skip constraints that cannot be checked if param not in self.__dict__ and isinstance(getattr(type(self), param), lazy_property): @@ -72,7 +89,7 @@ def event_shape(self): return self._event_shape @property - def arg_constraints(self): + def arg_constraints(self) -> Dict[str, constraints.Constraint]: """ Returns a dictionary from argument names to :class:`~torch.distributions.constraints.Constraint` objects that @@ -82,7 +99,7 @@ def arg_constraints(self): raise NotImplementedError @property - def support(self): + def support(self) -> Optional[Any]: """ Returns a :class:`~torch.distributions.constraints.Constraint` object representing this distribution's support. @@ -248,8 +265,15 @@ def _validate_sample(self, value): if i != 1 and j != 1 and i != j: raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'. format(actual_shape, expected_shape)) - - if not self.support.check(value).all(): + try: + support = self.support + except NotImplementedError: + warnings.warn(f'{self.__class__} does not define `support` to enable ' + + 'sample validation. Please initialize the distribution with ' + + '`validate_args=False` to turn off validation.') + return + assert support is not None + if not support.check(value).all(): raise ValueError('The value argument must be within the support') def _get_checked_instance(self, cls, _instance=None): diff --git a/torch/distributions/exponential.py b/torch/distributions/exponential.py index 41d7cd9f9787a..ac18980c778be 100644 --- a/torch/distributions/exponential.py +++ b/torch/distributions/exponential.py @@ -68,8 +68,6 @@ def cdf(self, value): return 1 - torch.exp(-self.rate * value) def icdf(self, value): - if self._validate_args: - self._validate_sample(value) return -torch.log(1 - value) / self.rate def entropy(self): diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py index 5bd3a2d3bd1ea..a569af34ebdcb 100644 --- a/torch/distributions/gumbel.py +++ b/torch/distributions/gumbel.py @@ -5,9 +5,7 @@ from torch.distributions.uniform import Uniform from torch.distributions.transformed_distribution import TransformedDistribution from torch.distributions.transforms import AffineTransform, ExpTransform -from torch.distributions.utils import broadcast_all - -euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant +from torch.distributions.utils import broadcast_all, euler_constant class Gumbel(TransformedDistribution): diff --git a/torch/distributions/half_cauchy.py b/torch/distributions/half_cauchy.py index 73efe4c9e3dcf..5f49507f03383 100644 --- a/torch/distributions/half_cauchy.py +++ b/torch/distributions/half_cauchy.py @@ -29,7 +29,7 @@ class HalfCauchy(TransformedDistribution): has_rsample = True def __init__(self, scale, validate_args=None): - base_dist = Cauchy(0, scale) + base_dist = Cauchy(0, scale, validate_args=False) super(HalfCauchy, self).__init__(base_dist, AbsTransform(), validate_args=validate_args) @@ -50,6 +50,8 @@ def variance(self): return self.base_dist.variance def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) value = torch.as_tensor(value, dtype=self.base_dist.scale.dtype, device=self.base_dist.scale.device) log_prob = self.base_dist.log_prob(value) + math.log(2) @@ -57,6 +59,8 @@ def log_prob(self, value): return log_prob def cdf(self, value): + if self._validate_args: + self._validate_sample(value) return 2 * self.base_dist.cdf(value) - 1 def icdf(self, prob): diff --git a/torch/distributions/half_normal.py b/torch/distributions/half_normal.py index 048703b30f7de..b528a8dbc1c7f 100644 --- a/torch/distributions/half_normal.py +++ b/torch/distributions/half_normal.py @@ -28,7 +28,7 @@ class HalfNormal(TransformedDistribution): has_rsample = True def __init__(self, scale, validate_args=None): - base_dist = Normal(0, scale) + base_dist = Normal(0, scale, validate_args=False) super(HalfNormal, self).__init__(base_dist, AbsTransform(), validate_args=validate_args) @@ -49,11 +49,15 @@ def variance(self): return self.scale.pow(2) * (1 - 2 / math.pi) def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) log_prob = self.base_dist.log_prob(value) + math.log(2) log_prob[value.expand(log_prob.shape) < 0] = -inf return log_prob def cdf(self, value): + if self._validate_args: + self._validate_sample(value) return 2 * self.base_dist.cdf(value) - 1 def icdf(self, prob): diff --git a/torch/distributions/independent.py b/torch/distributions/independent.py index cbec92dfd9c69..de34bb6047747 100644 --- a/torch/distributions/independent.py +++ b/torch/distributions/independent.py @@ -2,7 +2,7 @@ from torch.distributions import constraints from torch.distributions.distribution import Distribution from torch.distributions.utils import _sum_rightmost - +from typing import Dict class Independent(Distribution): r""" @@ -31,7 +31,7 @@ class Independent(Distribution): reinterpreted_batch_ndims (int): the number of batch dims to reinterpret as event dims """ - arg_constraints = {} + arg_constraints: Dict[str, constraints.Constraint] = {} def __init__(self, base_distribution, reinterpreted_batch_ndims, validate_args=None): if reinterpreted_batch_ndims > len(base_distribution.batch_shape): diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py index a30054d370cc8..ba7ba73d60635 100644 --- a/torch/distributions/kl.py +++ b/torch/distributions/kl.py @@ -1,6 +1,7 @@ import math import warnings from functools import total_ordering +from typing import Type, Dict, Callable, Tuple import torch from torch._six import inf @@ -30,10 +31,10 @@ from .poisson import Poisson from .transformed_distribution import TransformedDistribution from .uniform import Uniform -from .utils import _sum_rightmost +from .utils import _sum_rightmost, euler_constant as _euler_gamma _KL_REGISTRY = {} # Source of truth mapping a few general (type, type) pairs to functions. -_KL_MEMOIZE = {} # Memoized version mapping many specific (type, type) pairs to functions. +_KL_MEMOIZE: Dict[Tuple[Type, Type], Callable] = {} # Memoized version mapping many specific (type, type) pairs to functions. def register_kl(type_p, type_q): @@ -103,8 +104,10 @@ def _dispatch_kl(type_p, type_q): if not matches: return NotImplemented # Check that the left- and right- lexicographic orders agree. - left_p, left_q = min(_Match(*m) for m in matches).types - right_q, right_p = min(_Match(*reversed(m)) for m in matches).types + # mypy isn't smart enough to know that _Match implements __lt__ + # see: https://github.com/python/typing/issues/760#issuecomment-710670503 + left_p, left_q = min(_Match(*m) for m in matches).types # type: ignore + right_q, right_p = min(_Match(*reversed(m)) for m in matches).types # type: ignore left_fun = _KL_REGISTRY[left_p, left_q] right_fun = _KL_REGISTRY[right_p, right_q] if left_fun is not right_fun: @@ -171,8 +174,6 @@ def kl_divergence(p, q): # KL Divergence Implementations ################################################################################ -_euler_gamma = 0.57721566490153286060 - # Same distributions diff --git a/torch/distributions/kumaraswamy.py b/torch/distributions/kumaraswamy.py new file mode 100644 index 0000000000000..4fb2e177e7be7 --- /dev/null +++ b/torch/distributions/kumaraswamy.py @@ -0,0 +1,66 @@ +import torch +from torch.distributions import constraints +from torch.distributions.uniform import Uniform +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import AffineTransform, PowerTransform +from torch.distributions.utils import broadcast_all, euler_constant + + +def _moments(a, b, n): + """ + Computes nth moment of Kumaraswamy using using torch.lgamma + """ + arg1 = 1 + n / a + log_value = torch.lgamma(arg1) + torch.lgamma(b) - torch.lgamma(arg1 + b) + return b * torch.exp(log_value) + + +class Kumaraswamy(TransformedDistribution): + r""" + Samples from a Kumaraswamy distribution. + + Example:: + + >>> m = Kumaraswamy(torch.Tensor([1.0]), torch.Tensor([1.0])) + >>> m.sample() # sample from a Kumaraswamy distribution with concentration alpha=1 and beta=1 + tensor([ 0.1729]) + + Args: + concentration1 (float or Tensor): 1st concentration parameter of the distribution + (often referred to as alpha) + concentration0 (float or Tensor): 2nd concentration parameter of the distribution + (often referred to as beta) + """ + arg_constraints = {'concentration1': constraints.positive, 'concentration0': constraints.positive} + support = constraints.unit_interval + has_rsample = True + + def __init__(self, concentration1, concentration0, validate_args=None): + self.concentration1, self.concentration0 = broadcast_all(concentration1, concentration0) + finfo = torch.finfo(self.concentration0.dtype) + base_dist = Uniform(torch.full_like(self.concentration0, 0), + torch.full_like(self.concentration0, 1)) + transforms = [PowerTransform(exponent=self.concentration0.reciprocal()), + AffineTransform(loc=1., scale=-1.), + PowerTransform(exponent=self.concentration1.reciprocal())] + super(Kumaraswamy, self).__init__(base_dist, transforms, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Kumaraswamy, _instance) + new.concentration1 = self.concentration1.expand(batch_shape) + new.concentration0 = self.concentration0.expand(batch_shape) + return super(Kumaraswamy, self).expand(batch_shape, _instance=new) + + @property + def mean(self): + return _moments(self.concentration1, self.concentration0, 1) + + @property + def variance(self): + return _moments(self.concentration1, self.concentration0, 2) - torch.pow(self.mean, 2) + + def entropy(self): + t1 = (1 - self.concentration1.reciprocal()) + t0 = (1 - self.concentration0.reciprocal()) + H0 = torch.digamma(self.concentration0 + 1) + euler_constant + return t0 + t1 * H0 - torch.log(self.concentration1) - torch.log(self.concentration0) diff --git a/torch/distributions/laplace.py b/torch/distributions/laplace.py index d7ec01c65b35f..a505d60c8f381 100644 --- a/torch/distributions/laplace.py +++ b/torch/distributions/laplace.py @@ -75,8 +75,6 @@ def cdf(self, value): return 0.5 - 0.5 * (value - self.loc).sign() * torch.expm1(-(value - self.loc).abs() / self.scale) def icdf(self, value): - if self._validate_args: - self._validate_sample(value) term = value - 0.5 return self.loc - self.scale * (term).sign() * torch.log1p(-2 * term.abs()) diff --git a/torch/distributions/lkj_cholesky.py b/torch/distributions/lkj_cholesky.py new file mode 100644 index 0000000000000..cdbfe5be55bbf --- /dev/null +++ b/torch/distributions/lkj_cholesky.py @@ -0,0 +1,126 @@ +""" +This closely follows the implementation in NumPyro (https://github.com/pyro-ppl/numpyro). + +Original copyright notice: + +# Copyright: Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 +""" + +import math + +import torch +from torch.distributions import constraints, Beta +from torch.distributions.distribution import Distribution +from torch.distributions.utils import broadcast_all + + +class LKJCholesky(Distribution): + r""" + LKJ distribution for lower Cholesky factor of correlation matrices. + The distribution is controlled by ``concentration`` parameter :math:`\eta` + to make the probability of the correlation matrix :math:`M` generated from + a Cholesky factor propotional to :math:`\det(M)^{\eta - 1}`. Because of that, + when ``concentration == 1``, we have a uniform distribution over Cholesky + factors of correlation matrices. Note that this distribution samples the + Cholesky factor of correlation matrices and not the correlation matrices + themselves and thereby differs slightly from the derivations in [1] for + the `LKJCorr` distribution. For sampling, this uses the Onion method from + [1] Section 3. + + L ~ LKJCholesky(dim, concentration) + X = L @ L' ~ LKJCorr(dim, concentration) + + Example:: + + >>> l = LKJCholesky(3, 0.5) + >>> l.sample() # l @ l.T is a sample of a correlation 3x3 matrix + tensor([[ 1.0000, 0.0000, 0.0000], + [ 0.3516, 0.9361, 0.0000], + [-0.1899, 0.4748, 0.8593]]) + + Args: + dimension (dim): dimension of the matrices + concentration (float or Tensor): concentration/shape parameter of the + distribution (often referred to as eta) + + **References** + + [1] `Generating random correlation matrices based on vines and extended onion method`, + Daniel Lewandowski, Dorota Kurowicka, Harry Joe. + """ + arg_constraints = {'concentration': constraints.positive} + support = constraints.corr_cholesky + + def __init__(self, dim, concentration=1., validate_args=None): + if dim < 2: + raise ValueError(f'Expected dim to be an integer greater than or equal to 2. Found dim={dim}.') + self.dim = dim + self.concentration, = broadcast_all(concentration) + batch_shape = self.concentration.size() + event_shape = torch.Size((dim, dim)) + # This is used to draw vectorized samples from the beta distribution in Sec. 3.2 of [1]. + marginal_conc = self.concentration + 0.5 * (self.dim - 2) + offset = torch.arange(self.dim - 1, dtype=self.concentration.dtype, device=self.concentration.device) + offset = torch.cat([offset.new_zeros((1,)), offset]) + beta_conc1 = offset + 0.5 + beta_conc0 = marginal_conc.unsqueeze(-1) - 0.5 * offset + self._beta = Beta(beta_conc1, beta_conc0) + super(LKJCholesky, self).__init__(batch_shape, event_shape, validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(LKJCholesky, _instance) + batch_shape = torch.Size(batch_shape) + new.dim = self.dim + new.concentration = self.concentration.expand(batch_shape) + new._beta = self._beta.expand(batch_shape + (self.dim,)) + super(LKJCholesky, new).__init__(batch_shape, self.event_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def sample(self, sample_shape=torch.Size()): + # This uses the Onion method, but there are a few differences from [1] Sec. 3.2: + # - This vectorizes the for loop and also works for heterogeneous eta. + # - Same algorithm generalizes to n=1. + # - The procedure is simplified since we are sampling the cholesky factor of + # the correlation matrix instead of the correlation matrix itself. As such, + # we only need to generate `w`. + y = self._beta.sample(sample_shape).unsqueeze(-1) + u_normal = torch.randn(self._extended_shape(sample_shape), + dtype=y.dtype, + device=y.device).tril(-1) + u_hypersphere = u_normal / u_normal.norm(dim=-1, keepdim=True) + # Replace NaNs in first row + u_hypersphere[..., 0, :].fill_(0.) + w = torch.sqrt(y) * u_hypersphere + # Fill diagonal elements; clamp for numerical stability + eps = torch.finfo(w.dtype).tiny + diag_elems = torch.clamp(1 - torch.sum(w**2, dim=-1), min=eps).sqrt() + w += torch.diag_embed(diag_elems) + return w + + def log_prob(self, value): + # See: https://mc-stan.org/docs/2_25/functions-reference/cholesky-lkj-correlation-distribution.html + # The probability of a correlation matrix is proportional to + # determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1)) + # Additionally, the Jacobian of the transformation from Cholesky factor to + # correlation matrix is: + # prod(L_ii ^ (D - i)) + # So the probability of a Cholesky factor is propotional to + # prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i) + # with order_i = 2 * concentration - 2 + D - i + diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:] + order = torch.arange(2, self.dim + 1) + order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order + unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1) + # Compute normalization constant (page 1999 of [1]) + dm1 = self.dim - 1 + alpha = self.concentration + 0.5 * dm1 + denominator = torch.lgamma(alpha) * dm1 + numerator = torch.mvlgamma(alpha - 0.5, dm1) + # pi_constant in [1] is D * (D - 1) / 4 * log(pi) + # pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi) + # hence, we need to add a pi_constant = (D - 1) * log(pi) / 2 + pi_constant = 0.5 * dm1 * math.log(math.pi) + normalize_term = pi_constant + numerator - denominator + return unnormalized_log_pdf - normalize_term diff --git a/torch/distributions/mixture_same_family.py b/torch/distributions/mixture_same_family.py index c0f778c5ed53a..716bfbd8c7a38 100644 --- a/torch/distributions/mixture_same_family.py +++ b/torch/distributions/mixture_same_family.py @@ -2,6 +2,7 @@ from torch.distributions.distribution import Distribution from torch.distributions import Categorical from torch.distributions import constraints +from typing import Dict class MixtureSameFamily(Distribution): @@ -45,7 +46,7 @@ class MixtureSameFamily(Distribution): component_distribution: `torch.distributions.Distribution`-like instance. Right-most batch dimension indexes component. """ - arg_constraints = {} + arg_constraints: Dict[str, constraints.Constraint] = {} has_rsample = False def __init__(self, diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py index 6d61578237cdf..9162dd4713d48 100644 --- a/torch/distributions/multinomial.py +++ b/torch/distributions/multinomial.py @@ -2,7 +2,6 @@ from torch._six import inf from torch.distributions.distribution import Distribution from torch.distributions import Categorical -from numbers import Number from torch.distributions import constraints from torch.distributions.utils import broadcast_all @@ -40,6 +39,7 @@ class Multinomial(Distribution): """ arg_constraints = {'probs': constraints.simplex, 'logits': constraints.real} + total_count: int @property def mean(self): @@ -50,7 +50,7 @@ def variance(self): return self.total_count * self.probs * (1 - self.probs) def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): - if not isinstance(total_count, Number): + if not isinstance(total_count, int): raise NotImplementedError('inhomogeneous total_count is not supported') self.total_count = total_count self._categorical = Categorical(probs=probs, logits=logits) diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py index de997f49a94f7..4845d4742dfc5 100644 --- a/torch/distributions/multivariate_normal.py +++ b/torch/distributions/multivariate_normal.py @@ -122,25 +122,27 @@ def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tri if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1: raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.") - loc_ = loc.unsqueeze(-1) # temporarily add dim on right if scale_tril is not None: if scale_tril.dim() < 2: raise ValueError("scale_tril matrix must be at least two-dimensional, " "with optional leading batch dimensions") - self.scale_tril, loc_ = torch.broadcast_tensors(scale_tril, loc_) + batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1]) + self.scale_tril = scale_tril.expand(batch_shape + (-1, -1)) elif covariance_matrix is not None: if covariance_matrix.dim() < 2: raise ValueError("covariance_matrix must be at least two-dimensional, " "with optional leading batch dimensions") - self.covariance_matrix, loc_ = torch.broadcast_tensors(covariance_matrix, loc_) + batch_shape = torch.broadcast_shapes(covariance_matrix.shape[:-2], loc.shape[:-1]) + self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1)) else: if precision_matrix.dim() < 2: raise ValueError("precision_matrix must be at least two-dimensional, " "with optional leading batch dimensions") - self.precision_matrix, loc_ = torch.broadcast_tensors(precision_matrix, loc_) - self.loc = loc_[..., 0] # drop rightmost dim + batch_shape = torch.broadcast_shapes(precision_matrix.shape[:-2], loc.shape[:-1]) + self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1)) + self.loc = loc.expand(batch_shape + (-1,)) - batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:] + event_shape = self.loc.shape[-1:] super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args) if scale_tril is not None: diff --git a/torch/distributions/negative_binomial.py b/torch/distributions/negative_binomial.py index 7395635971dee..4a8babb34a7c0 100644 --- a/torch/distributions/negative_binomial.py +++ b/torch/distributions/negative_binomial.py @@ -10,14 +10,14 @@ class NegativeBinomial(Distribution): Creates a Negative Binomial distribution, i.e. distribution of the number of successful independent and identical Bernoulli trials before :attr:`total_count` failures are achieved. The probability - of success of each Bernoulli trial is :attr:`probs`. + of failure of each Bernoulli trial is :attr:`probs`. Args: total_count (float or Tensor): non-negative number of negative Bernoulli trials to stop, although the distribution is still valid for real valued count - probs (Tensor): Event probabilities of success in the half open interval [0, 1) - logits (Tensor): Event log-odds for probabilities of success + probs (Tensor): Event probabilities of failure in the half open interval [0, 1) + logits (Tensor): Event log-odds for probabilities of failure """ arg_constraints = {'total_count': constraints.greater_than_eq(0), 'probs': constraints.half_open_interval(0., 1.), @@ -77,8 +77,10 @@ def param_shape(self): @lazy_property def _gamma(self): + # Note we avoid validating because self.total_count can be zero. return torch.distributions.Gamma(concentration=self.total_count, - rate=torch.exp(-self.logits)) + rate=torch.exp(-self.logits), + validate_args=False) def sample(self, sample_shape=torch.Size()): with torch.no_grad(): diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py index a125806108e8c..1f14f0ae015f2 100644 --- a/torch/distributions/normal.py +++ b/torch/distributions/normal.py @@ -1,4 +1,5 @@ import math +from numbers import Real from numbers import Number import torch @@ -72,7 +73,7 @@ def log_prob(self, value): self._validate_sample(value) # compute the variance var = (self.scale ** 2) - log_scale = math.log(self.scale) if isinstance(self.scale, Number) else self.scale.log() + log_scale = math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log() return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi)) def cdf(self, value): @@ -81,8 +82,6 @@ def cdf(self, value): return 0.5 * (1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2))) def icdf(self, value): - if self._validate_args: - self._validate_sample(value) return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2) def entropy(self): diff --git a/torch/distributions/one_hot_categorical.py b/torch/distributions/one_hot_categorical.py index bd23f2344df50..64f696802d769 100644 --- a/torch/distributions/one_hot_categorical.py +++ b/torch/distributions/one_hot_categorical.py @@ -29,7 +29,7 @@ class OneHotCategorical(Distribution): """ arg_constraints = {'probs': constraints.simplex, 'logits': constraints.real} - support = constraints.simplex + support = constraints.one_hot has_enumerate_support = True def __init__(self, probs=None, logits=None, validate_args=None): @@ -96,3 +96,18 @@ def enumerate_support(self, expand=True): if expand: values = values.expand((n,) + self.batch_shape + (n,)) return values + +class OneHotCategoricalStraightThrough(OneHotCategorical): + r""" + Creates a reparameterizable :class:`OneHotCategorical` distribution based on the straight- + through gradient estimator from [1]. + + [1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation + (Bengio et al, 2013) + """ + has_rsample = True + + def rsample(self, sample_shape=torch.Size()): + samples = self.sample(sample_shape) + probs = self._categorical.probs # cached via @lazy_property + return samples + (probs - probs.detach()) diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py index 46c4fbccb43fd..d6bb4de75c6b5 100644 --- a/torch/distributions/transformed_distribution.py +++ b/torch/distributions/transformed_distribution.py @@ -3,6 +3,7 @@ from torch.distributions.distribution import Distribution from torch.distributions.transforms import Transform from torch.distributions.utils import _sum_rightmost +from typing import Dict class TransformedDistribution(Distribution): @@ -38,7 +39,7 @@ class TransformedDistribution(Distribution): :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical` """ - arg_constraints = {} + arg_constraints: Dict[str, constraints.Constraint] = {} def __init__(self, base_distribution, transforms, validate_args=None): self.base_dist = base_distribution diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index 09e00d55e8d9e..4181db799b28e 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -6,15 +6,18 @@ import torch.nn.functional as F from torch.distributions import constraints from torch.distributions.utils import (_sum_rightmost, broadcast_all, - lazy_property) + lazy_property, tril_matrix_to_vec, + vec_to_tril_matrix) from torch.nn.functional import pad from torch.nn.functional import softplus +from typing import List __all__ = [ 'AbsTransform', 'AffineTransform', 'CatTransform', 'ComposeTransform', + 'CorrCholeskyTransform', 'ExpTransform', 'LowerCholeskyTransform', 'PowerTransform', @@ -77,6 +80,7 @@ class Transform(object): transforms that act jointly on matrices, etc. """ bijective = False + codomain: constraints.Constraint event_dim = 0 def __init__(self, cache_size=0): @@ -90,6 +94,14 @@ def __init__(self, cache_size=0): raise ValueError('cache_size must be 0 or 1') super(Transform, self).__init__() + @property + def input_event_dim(self): + return self.event_dim + + @property + def output_event_dim(self): + return self.event_dim + @property def inv(self): """ @@ -185,22 +197,37 @@ def __init__(self, transform): @constraints.dependent_property def domain(self): + assert self._inv is not None return self._inv.codomain @constraints.dependent_property def codomain(self): + assert self._inv is not None return self._inv.domain + @property + def input_event_dim(self): + assert self._inv is not None + return self._inv.output_event_dim + + @property + def output_event_dim(self): + assert self._inv is not None + return self._inv.input_event_dim + @property def bijective(self): + assert self._inv is not None return self._inv.bijective @property def sign(self): + assert self._inv is not None return self._inv.sign @property def event_dim(self): + assert self._inv is not None return self._inv.event_dim @property @@ -208,17 +235,21 @@ def inv(self): return self._inv def with_cache(self, cache_size=1): + assert self._inv is not None return self.inv.with_cache(cache_size).inv def __eq__(self, other): if not isinstance(other, _InverseTransform): return False + assert self._inv is not None return self._inv == other._inv def __call__(self, x): + assert self._inv is not None return self._inv._inv_call(x) def log_abs_det_jacobian(self, x, y): + assert self._inv is not None return -self._inv.log_abs_det_jacobian(y, x) @@ -500,8 +531,8 @@ def __eq__(self, other): @property def sign(self): - if isinstance(self.scale, numbers.Number): - return 1 if self.scale > 0 else -1 if self.scale < 0 else 0 + if isinstance(self.scale, numbers.Real): + return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0 return self.scale.sign() def _call(self, x): @@ -513,7 +544,7 @@ def _inverse(self, y): def log_abs_det_jacobian(self, x, y): shape = x.shape scale = self.scale - if isinstance(scale, numbers.Number): + if isinstance(scale, numbers.Real): result = torch.full_like(x, math.log(abs(scale))) else: result = torch.abs(scale).log() @@ -524,6 +555,74 @@ def log_abs_det_jacobian(self, x, y): return result.expand(shape) +class CorrCholeskyTransform(Transform): + r""" + Transforms an uncontrained real vector :math:`x` with length :math:`D*(D-1)/2` into the + Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower + triangular matrix with positive diagonals and unit Euclidean norm for each row. + The transform is processed as follows: + + 1. First we convert x into a lower triangular matrix in row order. + 2. For each row :math:`X_i` of the lower triangular part, we apply a *signed* version of + class :class:`StickBreakingTransform` to transform :math:`X_i` into a + unit Euclidean length vector using the following steps: + - Scales into the interval :math:`(-1, 1)` domain: :math:`r_i = \tanh(X_i)`. + - Transforms into an unsigned domain: :math:`z_i = r_i^2`. + - Applies :math:`s_i = StickBreakingTransform(z_i)`. + - Transforms back into signed domain: :math:`y_i = sign(r_i) * \sqrt{s_i}`. + """ + domain = constraints.real_vector + codomain = constraints.corr_cholesky + input_event_dim = 1 + output_event_dim = 2 + bijective = True + + @property + def event_dim(self): + raise ValueError("Please use `.input_event_dim` or `.output_event_dim` instead.") + + def _call(self, x): + x = torch.tanh(x) + eps = torch.finfo(x.dtype).eps + x = x.clamp(min=-1 + eps, max=1 - eps) + r = vec_to_tril_matrix(x, diag=-1) + # apply stick-breaking on the squared values + # Note that y = sign(r) * sqrt(z * z1m_cumprod) + # = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod) + z = r ** 2 + z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1) + # Diagonal elements must be 1. + r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device) + y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1) + return y + + def _inverse(self, y): + # inverse stick-breaking + # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html + y_cumsum = 1 - torch.cumsum(y * y, dim=-1) + y_cumsum_shifted = pad(y_cumsum[..., :-1], [1, 0], value=1) + y_vec = tril_matrix_to_vec(y, diag=-1) + y_cumsum_vec = tril_matrix_to_vec(y_cumsum_shifted, diag=-1) + t = y_vec / (y_cumsum_vec).sqrt() + # inverse of tanh + x = ((1 + t) / (1 - t)).log() / 2 + return x + + def log_abs_det_jacobian(self, x, y, intermediates=None): + # Because domain and codomain are two spaces with different dimensions, determinant of + # Jacobian is not well-defined. We return `log_abs_det_jacobian` of `x` and the + # flattened lower triangular part of `y`. + + # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html + y1m_cumsum = 1 - (y * y).cumsum(dim=-1) + # by taking diagonal=-2, we don't need to shift z_cumprod to the right + # also works for 2 x 2 matrix + y1m_cumsum_tril = tril_matrix_to_vec(y1m_cumsum, diag=-2) + stick_breaking_logdet = 0.5 * (y1m_cumsum_tril).log().sum(-1) + tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.)).sum(dim=-1) + return stick_breaking_logdet + tanh_logdet + + class SoftmaxTransform(Transform): r""" Transform from unconstrained space to the simplex via :math:`y = \exp(x)` then @@ -575,7 +674,7 @@ def _call(self, x): offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1) z = _clipped_sigmoid(x - offset.log()) z_cumprod = (1 - z).cumprod(-1) - y = pad(z, (0, 1), value=1) * pad(z_cumprod, (1, 0), value=1) + y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1) return y def _inverse(self, y): @@ -619,6 +718,7 @@ def _inverse(self, y): class CatTransform(Transform): + tseq: List[numbers.Number] """ Transform functor that applies a sequence of transforms `tseq` component-wise to each submatrix at `dim`, of length `lengths[dim]`, @@ -633,6 +733,7 @@ class CatTransform(Transform): """ def __init__(self, tseq, dim=0, lengths=None, cache_size=0): assert all(isinstance(t, Transform) for t in tseq) + self.event_dim = max(t.event_dim for t in tseq) if cache_size: tseq = [t.with_cache(cache_size) for t in tseq] super(CatTransform, self).__init__(cache_size=cache_size) @@ -684,9 +785,20 @@ def log_abs_det_jacobian(self, x, y): for trans, length in zip(self.transforms, self.lengths): xslice = x.narrow(self.dim, start, length) yslice = y.narrow(self.dim, start, length) - logdetjacs.append(trans.log_abs_det_jacobian(xslice, yslice)) + logdetjac = trans.log_abs_det_jacobian(xslice, yslice) + if trans.event_dim < self.event_dim: + logdetjac = _sum_rightmost(logdetjac, self.event_dim - trans.event_dim) + logdetjacs.append(logdetjac) start = start + length # avoid += for jit compat - return torch.cat(logdetjacs, dim=self.dim) + # Decide whether to concatenate or sum. + dim = self.dim + if dim >= 0: + dim = dim - x.dim() + dim = dim + self.event_dim + if dim < 0: + return torch.cat(logdetjacs, dim=dim) + else: + return sum(logdetjacs) @property def bijective(self): diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py index b212c52695c24..edaf5abf77a5a 100644 --- a/torch/distributions/uniform.py +++ b/torch/distributions/uniform.py @@ -81,8 +81,6 @@ def cdf(self, value): return result.clamp(min=0, max=1) def icdf(self, value): - if self._validate_args: - self._validate_sample(value) result = value * (self.high - self.low) + self.low return result diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py index 65636ab3f30a2..84f45f1d33cf4 100644 --- a/torch/distributions/utils.py +++ b/torch/distributions/utils.py @@ -2,6 +2,10 @@ from numbers import Number import torch import torch.nn.functional as F +from typing import Dict, Any +from torch.overrides import has_torch_function + +euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant def broadcast_all(*values): @@ -14,22 +18,25 @@ def broadcast_all(*values): values are scalars, then they are upcasted to scalar Tensors. Args: - values (list of `numbers.Number` or `torch.*Tensor`) + values (list of `numbers.Number`, `torch.*Tensor` or objects implementing __torch_function__) Raises: - ValueError: if any of the values is not a `numbers.Number` or - `torch.*Tensor` instance + ValueError: if any of the values is not a `numbers.Number` instance, + a `torch.*Tensor` instance, or an instance implementing __torch_function__ """ - if not all(isinstance(v, torch.Tensor) or isinstance(v, Number) for v in values): - raise ValueError('Input arguments must all be instances of numbers.Number or torch.tensor.') - if not all([isinstance(v, torch.Tensor) for v in values]): - options = dict(dtype=torch.get_default_dtype()) + if not all(isinstance(v, torch.Tensor) or has_torch_function((v,)) or isinstance(v, Number) + for v in values): + raise ValueError('Input arguments must all be instances of numbers.Number, ' + 'torch.Tensor or objects implementing __torch_function__.') + if not all([isinstance(v, torch.Tensor) or has_torch_function((v,)) for v in values]): + options: Dict[str, Any] = dict(dtype=torch.get_default_dtype()) for value in values: if isinstance(value, torch.Tensor): options = dict(dtype=value.dtype, device=value.device) break - values = [v if isinstance(v, torch.Tensor) else torch.tensor(v, **options) - for v in values] + new_values = [v if isinstance(v, torch.Tensor) or has_torch_function((v,)) else torch.tensor(v, **options) + for v in values] + return torch.broadcast_tensors(*new_values) return torch.broadcast_tensors(*values) @@ -94,7 +101,7 @@ class lazy_property(object): """ def __init__(self, wrapped): self.wrapped = wrapped - update_wrapper(self, wrapped) + update_wrapper(self, wrapped) # type: ignore[arg-type] def __get__(self, instance, obj_type=None): if instance is None: @@ -103,3 +110,36 @@ def __get__(self, instance, obj_type=None): value = self.wrapped(instance) setattr(instance, self.wrapped.__name__, value) return value + + +def tril_matrix_to_vec(mat, diag=0): + r""" + Convert a `D x D` matrix or a batch of matrices into a (batched) vector + which comprises of lower triangular elements from the matrix in row order. + """ + n = mat.shape[-1] + if not torch._C._get_tracing_state() and (diag < -n or diag >= n): + raise ValueError(f'diag ({diag}) provided is outside [{-n}, {n-1}].') + arange = torch.arange(n, device=mat.device) + tril_mask = arange < arange.view(-1, 1) + (diag + 1) + vec = mat[..., tril_mask] + return vec + + +def vec_to_tril_matrix(vec, diag=0): + r""" + Convert a vector or a batch of vectors into a batched `D x D` + lower triangular matrix containing elements from the vector in row order. + """ + # +ve root of D**2 + (1+2*diag)*D - |diag| * (diag+1) - 2*vec.shape[-1] = 0 + n = (-(1 + 2 * diag) + ((1 + 2 * diag)**2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1))**0.5) / 2 + eps = torch.finfo(vec.dtype).eps + if not torch._C._get_tracing_state() and (round(n) - n > eps): + raise ValueError(f'The size of last dimension is {vec.shape[-1]} which cannot be expressed as ' + + 'the lower triangular part of a square D x D matrix.') + n = torch.round(n).long() if isinstance(n, torch.Tensor) else round(n) + mat = vec.new_zeros(vec.shape[:-1] + torch.Size((n, n))) + arange = torch.arange(n, device=vec.device) + tril_mask = arange < arange.view(-1, 1) + (diag + 1) + mat[..., tril_mask] = vec + return mat diff --git a/torch/distributions/von_mises.py b/torch/distributions/von_mises.py index 72ee01c881bf4..31721960af2e9 100644 --- a/torch/distributions/von_mises.py +++ b/torch/distributions/von_mises.py @@ -49,7 +49,7 @@ def _log_modified_bessel_fn(x, order=0): return result -@torch.jit._script_if_tracing +@torch.jit.script_if_tracing def _rejection_sample(loc, concentration, proposal_r, x): done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device) while not done.all(): diff --git a/torch/fft/__init__.py b/torch/fft/__init__.py index 3e4bcc35464b2..7efdb04a52d36 100644 --- a/torch/fft/__init__.py +++ b/torch/fft/__init__.py @@ -2,6 +2,12 @@ import torch from torch._C import _add_docstr, _fft # type: ignore +from torch._torch_docs import factory_common_args + +__all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn', + 'rfft', 'irfft', 'rfft2', 'irfft2', 'rfftn', 'irfftn', + 'hfft', 'ihfft', 'fftfreq', 'rfftfreq', 'fftshift', 'ifftshift', + 'Tensor'] Tensor = torch.Tensor @@ -43,7 +49,6 @@ Example: - >>> import torch.fft >>> t = torch.arange(4) >>> t tensor([0, 1, 2, 3]) @@ -81,12 +86,105 @@ Example: - >>> import torch.fft >>> t = torch.tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]) >>> torch.fft.ifft(t) tensor([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j]) """) +fft2 = _add_docstr(_fft.fft_fft2, r""" +fft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor + +Computes the 2 dimensional discrete Fourier transform of :attr:`input`. +Equivalent to :func:`~torch.fft.fftn` but FFTs only the last two dimensions by default. + +Note: + The Fourier domain representation of any real signal satisfies the + Hermitian property: ``X[i, j] = conj(X[-i, -j])``. This + function always returns all positive and negative frequency terms even + though, for real inputs, half of these values are redundant. + :func:`~torch.fft.rfft2` returns the more compact one-sided representation + where only the positive frequencies of the last dimension are returned. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.fft2`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`~torch.fft.ifft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` + between the two transforms. This is required to make + :func:`~torch.fft.ifft2` the exact inverse. + + Default is ``"backward"`` (no normalization). + +Example: + + >>> x = torch.rand(10, 10, dtype=torch.complex64) + >>> fft2 = torch.fft.fft2(t) + + The discrete Fourier transform is separable, so :func:`~torch.fft.fft2` + here is equivalent to two one-dimensional :func:`~torch.fft.fft` calls: + + >>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1) + >>> torch.allclose(fft2, two_ffts) + +""") + +ifft2 = _add_docstr(_fft.fft_ifft2, r""" +ifft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor + +Computes the 2 dimensional inverse discrete Fourier transform of :attr:`input`. +Equivalent to :func:`~torch.fft.ifftn` but IFFTs only the last two dimensions by default. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the IFFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.ifft2`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`~torch.fft.fft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ifft2` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Example: + + >>> x = torch.rand(10, 10, dtype=torch.complex64) + >>> ifft2 = torch.fft.ifft2(t) + + The discrete Fourier transform is separable, so :func:`~torch.fft.ifft2` + here is equivalent to two one-dimensional :func:`~torch.fft.ifft` calls: + + >>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1) + >>> torch.allclose(ifft2, two_iffts) + +""") + fftn = _add_docstr(_fft.fft_fftn, r""" fftn(input, s=None, dim=None, norm=None) -> Tensor @@ -127,7 +225,6 @@ Example: - >>> import torch.fft >>> x = torch.rand(10, 10, dtype=torch.complex64) >>> fftn = torch.fft.fftn(t) @@ -170,7 +267,6 @@ Example: - >>> import torch.fft >>> x = torch.rand(10, 10, dtype=torch.complex64) >>> ifftn = torch.fft.ifftn(t) @@ -212,7 +308,6 @@ Example: - >>> import torch.fft >>> t = torch.arange(4) >>> t tensor([0, 1, 2, 3]) @@ -274,7 +369,6 @@ Example: - >>> import torch.fft >>> t = torch.arange(5) >>> t tensor([0, 1, 2, 3, 4]) @@ -294,6 +388,135 @@ tensor([0.0000, 1.0000, 2.0000, 3.0000, 4.0000]) """) +rfft2 = _add_docstr(_fft.fft_rfft2, r""" +rfft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor + +Computes the 2-dimensional discrete Fourier transform of real :attr:`input`. +Equivalent to :func:`~torch.fft.rfftn` but FFTs only the last two dimensions by default. + +The FFT of a real signal is Hermitian-symmetric, ``X[i, j] = conj(X[-i, -j])``, +so the full :func:`~torch.fft.fft2` output contains redundant information. +:func:`~torch.fft.rfft2` instead omits the negative frequencies in the last +dimension. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the real FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.rfft2`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`~torch.fft.irfft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.irfft2` + the exact inverse. + + Default is ``"backward"`` (no normalization). + +Example: + + >>> t = torch.rand(10, 10) + >>> rfft2 = torch.fft.rfft2(t) + >>> rfft2.size() + torch.Size([10, 6]) + + Compared against the full output from :func:`~torch.fft.fft2`, we have all + elements up to the Nyquist frequency. + + >>> fft2 = torch.fft.fft2(t) + >>> torch.allclose(fft2[..., :6], rfft2) + True + + The discrete Fourier transform is separable, so :func:`~torch.fft.rfft2` + here is equivalent to a combination of :func:`~torch.fft.fft` and + :func:`~torch.fft.rfft`: + + >>> two_ffts = torch.fft.fft(torch.fft.rfft(x, dim=1), dim=0) + >>> torch.allclose(rfft2, two_ffts) + +""") + +irfft2 = _add_docstr(_fft.fft_irfft2, r""" +irfft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor + +Computes the inverse of :func:`~torch.fft.rfft2`. +Equivalent to :func:`~torch.fft.irfftn` but IFFTs only the last two dimensions by default. + +:attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier +domain, as produced by :func:`~torch.fft.rfft2`. By the Hermitian property, the +output will be real-valued. + +Note: + Some input frequencies must be real-valued to satisfy the Hermitian + property. In these cases the imaginary component will be ignored. + For example, any imaginary component in the zero-frequency term cannot + be represented in a real output and so will always be ignored. + +Note: + The correct interpretation of the Hermitian input depends on the length of + the original data, as given by :attr:`s`. This is because each input shape + could correspond to either an odd or even length signal. By default, the + signal is assumed to be even length and odd signals will not round-trip + properly. So, it is recommended to always pass the signal shape :attr:`s`. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the real FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Defaults to even output in the last dimension: + ``s[-1] = 2*(input.size(dim[-1]) - 1)``. + dim (Tuple[int], optional): Dimensions to be transformed. + The last dimension must be the half-Hermitian compressed dimension. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.irfft2`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`~torch.fft.rfft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.irfft2` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Example: + + >>> t = torch.rand(10, 9) + >>> T = torch.fft.rfft2(t) + + Without specifying the output length to :func:`~torch.fft.irfft2`, the output + will not round-trip properly because the input is odd-length in the last + dimension: + + >>> torch.fft.irfft2(T).size() + torch.Size([10, 10]) + + So, it is recommended to always pass the signal shape :attr:`s`. + + >>> roundtrip = torch.fft.irfft2(T, t.size()) + >>> roundtrip.size() + torch.Size([10, 9]) + >>> torch.allclose(roundtrip, t) + True + +""") + rfftn = _add_docstr(_fft.fft_rfftn, r""" rfftn(input, s=None, dim=None, norm=None) -> Tensor @@ -331,7 +554,6 @@ Example: - >>> import torch.fft >>> t = torch.rand(10, 10) >>> rfftn = torch.fft.rfftn(t) >>> rfftn.size() @@ -403,7 +625,6 @@ Example: - >>> import torch.fft >>> t = torch.rand(10, 9) >>> T = torch.fft.rfftn(t) @@ -481,7 +702,6 @@ Taking a real-valued frequency signal and bringing it into the time domain gives Hermitian symmetric output: - >>> import torch.fft >>> t = torch.arange(5) >>> t tensor([0, 1, 2, 3, 4]) @@ -536,7 +756,6 @@ Example: - >>> import torch.fft >>> t = torch.arange(5) >>> t tensor([0, 1, 2, 3, 4]) @@ -549,3 +768,177 @@ tensor([ 2.0000+-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j, -0.5000+0.1625j, -0.5000+0.6882j]) """) + +fftfreq = _add_docstr(_fft.fft_fftfreq, r""" +fftfreq(n, d=1.0, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Computes the discrete Fourier Transform sample frequencies for a signal of size :attr:`n`. + +Note: + By convention, :func:`~torch.fft.fft` returns positive frequency terms + first, followed by the negative frequencies in reverse order, so that + ``f[-i]`` for all :math:`0 < i \leq n/2`` in Python gives the negative + frequency terms. For an FFT of length :attr:`n` and with inputs spaced in + length unit :attr:`d`, the frequencies are:: + + f = [0, 1, ..., (n - 1) // 2, -(n // 2), ..., -1] / (d * n) + +Note: + For even lengths, the Nyquist frequency at ``f[n/2]`` can be thought of as + either negative or positive. :func:`~torch.fft.fftfreq` follows NumPy's + convention of taking it to be negative. + +Args: + n (int): the FFT length + d (float, optional): The sampling length scale. + The spacing between individual samples of the FFT input. + The default assumes unit spacing, dividing that result by the actual + spacing gives the result in physical frequency units. + +Keyword Args: + {dtype} + {layout} + {device} + {requires_grad} + +Example: + + >>> torch.fft.fftfreq(5) + tensor([ 0.0000, 0.2000, 0.4000, -0.4000, -0.2000]) + + For even input, we can see the Nyquist frequency at ``f[2]`` is given as + negative: + + >>> torch.fft.fftfreq(4) + tensor([ 0.0000, 0.2500, -0.5000, -0.2500]) + +""".format(**factory_common_args)) + +rfftfreq = _add_docstr(_fft.fft_rfftfreq, r""" +rfftfreq(n, d=1.0, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Computes the sample frequencies for :func:`~torch.fft.rfft` with a signal of size :attr:`n`. + +Note: + :func:`~torch.fft.rfft` returns Hermitian one-sided output, so only the + positive frequency terms are returned. For a real FFT of length :attr:`n` + and with inputs spaced in length unit :attr:`d`, the frequencies are:: + + f = torch.arange((n + 1) // 2) / (d * n) + +Note: + For even lengths, the Nyquist frequency at ``f[n/2]`` can be thought of as + either negative or positive. Unlike :func:`~torch.fft.fftfreq`, + :func:`~torch.fft.rfftfreq` always returns it as positive. + +Args: + n (int): the real FFT length + d (float, optional): The sampling length scale. + The spacing between individual samples of the FFT input. + The default assumes unit spacing, dividing that result by the actual + spacing gives the result in physical frequency units. + +Keyword Args: + {dtype} + {layout} + {device} + {requires_grad} + +Example: + + >>> torch.fft.rfftfreq(5) + tensor([ 0.0000, 0.2000, 0.4000]) + + >>> torch.fft.rfftfreq(4) + tensor([ 0.0000, 0.2500, 0.5000]) + + Compared to the output from :func:`~torch.fft.fftfreq`, we see that the + Nyquist frequency at ``f[2]`` has changed sign: + >>> torch.fft.fftfreq(4) + tensor([ 0.0000, 0.2500, -0.5000, -0.2500]) + +""".format(**factory_common_args)) + +fftshift = _add_docstr(_fft.fft_fftshift, r""" +fftshift(input, dim=None) -> Tensor + +Reorders n-dimensional FFT data, as provided by :func:`~torch.fft.fftn`, to have +negative frequency terms first. + +Note: + By convention, the FFT returns positive frequency terms first, followed by + the negative frequencies in reverse order, so that ``f[-i]`` for all + :math:`0 < i \leq n/2` in Python gives the negative frequency terms. + :func:`~torch.fft.fftshift` rearranges all frequencies into ascending order + from negative to positive with the zero-frequency term in the center. + +Note: + For even lengths, the Nyquist frequency at ``f[n/2]`` can be thought of as + either negative or positive. :func:`~torch.fft.fftshift` always puts the + Nyquist term at the 0-index. This is the same convention used by + :func:`~torch.fft.fftfreq`. + +Args: + input (Tensor): the tensor in FFT order + dim (int, Tuple[int], optional): The dimensions to rearrange. + Only dimensions specified here will be rearranged, any other dimensions + will be left in their original order. + Default: All dimensions of :attr:`input`. + +Example: + + >>> f = torch.fft.fftfreq(4) + >>> f + tensor([ 0.0000, 0.2500, -0.5000, -0.2500]) + + >>> torch.fftshift(f) + tensor([-0.5000, -0.2500, 0.0000, 0.2500]) + + Also notice that the Nyquist frequency term at ``f[2]`` was moved to the + beginning of the tensor. + + This also works for multi-dimensional transforms: + + >>> x = torch.fft.fftfreq(5, d=1/5) + 0.1 * torch.fft.fftfreq(5, d=1/5).unsqueeze(1) + >>> x + tensor([[ 0.0000, 1.0000, 2.0000, -2.0000, -1.0000], + [ 0.1000, 1.1000, 2.1000, -1.9000, -0.9000], + [ 0.2000, 1.2000, 2.2000, -1.8000, -0.8000], + [-0.2000, 0.8000, 1.8000, -2.2000, -1.2000], + [-0.1000, 0.9000, 1.9000, -2.1000, -1.1000]]) + + >>> torch.fft.fftshift(x) + tensor([[-2.2000, -1.2000, -0.2000, 0.8000, 1.8000], + [-2.1000, -1.1000, -0.1000, 0.9000, 1.9000], + [-2.0000, -1.0000, 0.0000, 1.0000, 2.0000], + [-1.9000, -0.9000, 0.1000, 1.1000, 2.1000], + [-1.8000, -0.8000, 0.2000, 1.2000, 2.2000]]) + +""") + +ifftshift = _add_docstr(_fft.fft_ifftshift, r""" +ifftshift(input, dim=None) -> Tensor + +Inverse of :func:`~torch.fft.fftshift`. + +Args: + input (Tensor): the tensor in FFT order + dim (int, Tuple[int], optional): The dimensions to rearrange. + Only dimensions specified here will be rearranged, any other dimensions + will be left in their original order. + Default: All dimensions of :attr:`input`. + +Example: + + >>> f = torch.fft.fftfreq(5) + >>> f + tensor([ 0.0000, 0.2000, 0.4000, -0.4000, -0.2000]) + + A round-trip through :func:`~torch.fft.fftshift` and + :func:`~torch.fft.ifftshift` gives the same result: + + >>> shifted = torch.fftshift(f) + >>> torch.ifftshift(shifted) + tensor([ 0.0000, 0.2000, 0.4000, -0.4000, -0.2000]) + +""") diff --git a/torch/functional.py b/torch/functional.py index 84dbc2c5a4b76..1d3403e653042 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -6,10 +6,12 @@ import torch.nn.functional as F from torch.types import _size from ._lowrank import svd_lowrank, pca_lowrank -from .overrides import has_torch_function, handle_torch_function +from .overrides import ( + has_torch_function, has_torch_function_unary, has_torch_function_variadic, + handle_torch_function) from ._jit_internal import boolean_dispatch, List from ._jit_internal import _overload as overload -import warnings +from torch._autograd_functions import _LU Tensor = torch.Tensor from torch import _VF @@ -19,6 +21,7 @@ 'atleast_2d', 'atleast_3d', 'align_tensors', + 'broadcast_shapes', 'broadcast_tensors', 'cartesian_prod', 'block_diag', @@ -66,12 +69,44 @@ def broadcast_tensors(*tensors): tensor([[0, 1, 2], [0, 1, 2]]) """ - if not torch.jit.is_scripting(): - if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): - return handle_torch_function(broadcast_tensors, tensors, *tensors) + if has_torch_function(tensors): + return handle_torch_function(broadcast_tensors, tensors, *tensors) return _VF.broadcast_tensors(tensors) # type: ignore +def broadcast_shapes(*shapes): + r"""broadcast_shapes(*shapes) -> Size + + Similar to :func:`broadcast_tensors` but for shapes. + + This is equivalent to + ``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape`` + but avoids the need create to intermediate tensors. This is useful for + broadcasting tensors of common batch shape but different rightmost shape, + e.g. to broadcast mean vectors with covariance matrices. + + Example:: + + >>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1)) + torch.Size([1, 3, 2]) + + Args: + \*shapes (torch.Size): Shapes of tensors. + + Returns: + shape (torch.Size): A shape compatible with all input shapes. + + Raises: + RuntimeError: If shapes are incompatible. + """ + # TODO Movie this to C++ once the jit has better support for torch.Size. + with torch.no_grad(): + scalar = torch.zeros((), device="cpu") + tensors = [scalar.expand(shape) for shape in shapes] + tensors = broadcast_tensors(*tensors) + return tensors[0].shape + + def split(tensor, split_size_or_sections, dim=0): r"""Splits the tensor into chunks. Each chunk is a view of the original tensor. @@ -84,7 +119,7 @@ def split(tensor, split_size_or_sections, dim=0): into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according to :attr:`split_size_or_sections`. - Arguments: + Args: tensor (Tensor): tensor to split. split_size_or_sections (int) or (list(int)): size of a single chunk or list of sizes for each chunk @@ -111,10 +146,9 @@ def split(tensor, split_size_or_sections, dim=0): [6, 7], [8, 9]])) """ - if not torch.jit.is_scripting(): - if type(tensor) is not Tensor and has_torch_function((tensor,)): - return handle_torch_function(split, (tensor,), tensor, split_size_or_sections, - dim=dim) + if has_torch_function_unary(tensor): + return handle_torch_function( + split, (tensor,), tensor, split_size_or_sections, dim=dim) # Overwriting reason: # This dispatches to two ATen functions depending on the type of # split_size_or_sections. The branching code is in tensor.py, which we @@ -155,7 +189,7 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``. - Arguments: + Args: LU_data (Tensor): the packed LU factorization data LU_pivots (Tensor): the packed LU factorization pivots unpack_data (bool): flag indicating if the data should be unpacked @@ -200,12 +234,11 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): >>> torch.norm(A_ - A) tensor(2.9802e-08) """ - if not torch.jit.is_scripting(): - tens_ops = (LU_data, LU_pivots) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - lu_unpack, tens_ops, LU_data, LU_pivots, unpack_data=unpack_data, - unpack_pivots=unpack_pivots) + if has_torch_function_variadic(LU_data, LU_pivots): + return handle_torch_function( + lu_unpack, (LU_data, LU_pivots), LU_data, LU_pivots, + unpack_data=unpack_data, + unpack_pivots=unpack_pivots) shape = LU_data.shape # In generalized LU factorization, the following shape relations hold: # A.shape[-2:] == (m, n) @@ -241,7 +274,7 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): # product(*map(lambda x: list(range(x)), shape[:-2])) when issue 33781 is fixed indices = _indices_product(shape[:-2]) for idx in indices: - final_order = [i for i in range(m)] # noqa: C416 TODO: rewrite as list(range(m)) + final_order = list(range(m)) for k, j in enumerate(_index_tensor_with_indices_list(LU_pivots_zero_idx, idx)): final_order[k], final_order[j] = final_order[j], final_order[k] # TODO: remove _index_tensor_with_indices_list when TorchScript supports indexing Tensor with list @@ -249,7 +282,7 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): p_idx.copy_(p_idx.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))) else: P = torch.eye(m, device=LU_data.device, dtype=LU_data.dtype) - final_order = [i for i in range(m)] # noqa: C416 TODO: rewrite as list(range(m)) + final_order = list(range(m)) for k, j, in enumerate(LU_pivots_zero_idx): final_order[k], final_order[j] = final_order[j], final_order[k] P = P.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device)) @@ -262,79 +295,109 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): def einsum(equation, *operands): r"""einsum(equation, *operands) -> Tensor -This function provides a way of computing multilinear expressions (i.e. sums of products) using the -Einstein summation convention. - -Args: - equation (string): The equation is given in terms of lower case letters (indices) to be associated - with each dimension of the operands and result. The left hand side lists the operands - dimensions, separated by commas. There should be one index letter per tensor dimension. - The right hand side follows after `->` and gives the indices for the output. - If the `->` and right hand side are omitted, it implicitly defined as the alphabetically - sorted list of all indices appearing exactly once in the left hand side. - The indices not apprearing in the output are summed over after multiplying the operands - entries. - If an index appears several times for the same operand, a diagonal is taken. - Ellipses `...` represent a fixed number of dimensions. If the right hand side is inferred, - the ellipsis dimensions are at the beginning of the output. - operands (Tensor): The operands to compute the Einstein sum of. - -.. note:: - - This function does not optimize the given expression, so a different formula for the same computation may - run faster or consume less memory. Projects like opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) - can optimize the formula for you. - -Examples:: - - >>> x = torch.randn(5) - >>> y = torch.randn(4) - >>> torch.einsum('i,j->ij', x, y) # outer product - tensor([[-0.0570, -0.0286, -0.0231, 0.0197], - [ 1.2616, 0.6335, 0.5113, -0.4351], - [ 1.4452, 0.7257, 0.5857, -0.4984], - [-0.4647, -0.2333, -0.1883, 0.1603], - [-1.1130, -0.5588, -0.4510, 0.3838]]) - - - >>> A = torch.randn(3,5,4) - >>> l = torch.randn(2,5) - >>> r = torch.randn(2,4) - >>> torch.einsum('bn,anm,bm->ba', l, A, r) # compare torch.nn.functional.bilinear - tensor([[-0.3430, -5.2405, 0.4494], - [ 0.3311, 5.5201, -3.0356]]) - - - >>> As = torch.randn(3,2,5) - >>> Bs = torch.randn(3,5,4) - >>> torch.einsum('bij,bjk->bik', As, Bs) # batch matrix multiplication - tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], - [-1.6706, -0.8097, -0.8025, -2.1183]], - - [[ 4.2239, 0.3107, -0.5756, -0.2354], - [-1.4558, -0.3460, 1.5087, -0.8530]], - - [[ 2.8153, 1.8787, -4.3839, -1.2112], - [ 0.3728, -2.1131, 0.0921, 0.8305]]]) - - >>> A = torch.randn(3, 3) - >>> torch.einsum('ii->i', A) # diagonal - tensor([-0.7825, 0.8291, -0.1936]) - - >>> A = torch.randn(4, 3, 3) - >>> torch.einsum('...ii->...i', A) # batch diagonal - tensor([[-1.0864, 0.7292, 0.0569], - [-0.9725, -1.0270, 0.6493], - [ 0.5832, -1.1716, -1.5084], - [ 0.4041, -1.1690, 0.8570]]) - - >>> A = torch.randn(2, 3, 4, 5) - >>> torch.einsum('...ij->...ji', A).shape # batch permute - torch.Size([2, 3, 5, 4]) -""" - if not torch.jit.is_scripting(): - if any(type(t) is not Tensor for t in operands) and has_torch_function(operands): - return handle_torch_function(einsum, operands, equation, *operands) + Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation + based on the Einstein summation convention. + + Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them + in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of + this format are described below, but the general idea is to label every dimension of the input :attr:`operands` + with some subscript and define which subscripts are part of the output. The output is then computed by summing + the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the + output. For example, matrix multiplication can be computed using einsum as `torch.einsum("ij,jk->ik", A, B)`. + Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why). + + Equation: + + The :attr:`equation` string specifies the subscripts (lower case letters `['a', 'z']`) for each dimension of + the input :attr:`operands` in the same order as the dimensions, separating subcripts for each operand by a + comma (','), e.g. `'ij,jk'` specify subscripts for two 2D operands. The dimensions labeled with the same subscript + must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is + repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand + must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that + appear exactly once in the :attr:`equation` will be part of the output, sorted in increasing alphabetical order. + The output is computed by multiplying the input :attr:`operands` element-wise, with their dimensions aligned based + on the subscripts, and then summing out the dimensions whose subscripts are not part of the output. + + Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation + followed by the subscripts for the output. For instance, the following equation computes the transpose of a + matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and + at most once for the output. + + Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis. + Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts, + e.g. for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` cover the third and fourth + dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the + 'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not + explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions), + before the subscript labels that appear exactly once for the input operands. e.g. the following equation implements + batch matrix multiplication `'...ij,...jk'`. + + A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis, + arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands. + + .. note:: + + ``torch.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions + covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output. + + .. note:: + + This function does not optimize the given expression, so a different formula for the same computation may + run faster or consume less memory. Projects like opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) + can optimize the formula for you. + + Args: + equation (string): The subscripts for the Einstein summation. + operands (Tensor): The operands to compute the Einstein sum of. + + Examples:: + + # trace + >>> torch.einsum('ii', torch.randn(4, 4)) + tensor(-1.2104) + + # diagonal + >>> torch.einsum('ii->i', torch.randn(4, 4)) + tensor([-0.1034, 0.7952, -0.2433, 0.4545]) + + # outer product + >>> x = torch.randn(5) + >>> y = torch.randn(4) + >>> torch.einsum('i,j->ij', x, y) + tensor([[ 0.1156, -0.2897, -0.3918, 0.4963], + [-0.3744, 0.9381, 1.2685, -1.6070], + [ 0.7208, -1.8058, -2.4419, 3.0936], + [ 0.1713, -0.4291, -0.5802, 0.7350], + [ 0.5704, -1.4290, -1.9323, 2.4480]]) + + # batch matrix multiplication + >>> As = torch.randn(3,2,5) + >>> Bs = torch.randn(3,5,4) + >>> torch.einsum('bij,bjk->bik', As, Bs) + tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], + [-1.6706, -0.8097, -0.8025, -2.1183]], + + [[ 4.2239, 0.3107, -0.5756, -0.2354], + [-1.4558, -0.3460, 1.5087, -0.8530]], + + [[ 2.8153, 1.8787, -4.3839, -1.2112], + [ 0.3728, -2.1131, 0.0921, 0.8305]]]) + + # batch permute + >>> A = torch.randn(2, 3, 4, 5) + >>> torch.einsum('...ij->...ji', A).shape + torch.Size([2, 3, 5, 4]) + + # equivalent to torch.nn.functional.bilinear + >>> A = torch.randn(3,5,4) + >>> l = torch.randn(2,5) + >>> r = torch.randn(2,4) + >>> torch.einsum('bn,anm,bm->ba', l, A, r) + tensor([[-0.3430, -5.2405, 0.4494], + [ 0.3311, 5.5201, -3.0356]]) + """ + if has_torch_function(operands): + return handle_torch_function(einsum, operands, equation, *operands) if len(operands) == 1 and isinstance(operands[0], (list, tuple)): # the old interface of passing the operands as one list argument _operands = operands[0] @@ -351,41 +414,39 @@ def meshgrid(*tensors: Union[Tensor, List[Tensor]]) -> Tuple[Tensor, ...]: return _meshgrid(*tensors) else: def meshgrid(*tensors): + r"""Take :math:`N` tensors, each of which can be either scalar or 1-dimensional + vector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by + expanding the :math:`i` :sup:`th` input over dimensions defined by other inputs. + + Args: + tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be + treated as tensors of size :math:`(1,)` automatically + + Returns: + seq (sequence of Tensors): If the input has :math:`k` tensors of size + :math:`(N_1,), (N_2,), \ldots , (N_k,)`, then the output would also have :math:`k` tensors, + where all tensors are of size :math:`(N_1, N_2, \ldots , N_k)`. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([4, 5, 6]) + >>> grid_x, grid_y = torch.meshgrid(x, y) + >>> grid_x + tensor([[1, 1, 1], + [2, 2, 2], + [3, 3, 3]]) + >>> grid_y + tensor([[4, 5, 6], + [4, 5, 6], + [4, 5, 6]]) + """ return _meshgrid(*tensors) def _meshgrid(*tensors): - r"""Take :math:`N` tensors, each of which can be either scalar or 1-dimensional -vector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by -expanding the :math:`i` :sup:`th` input over dimensions defined by other inputs. - - - Args: - tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be - treated as tensors of size :math:`(1,)` automatically - - Returns: - seq (sequence of Tensors): If the input has :math:`k` tensors of size - :math:`(N_1,), (N_2,), \ldots , (N_k,)`, then the output would also have :math:`k` tensors, - where all tensors are of size :math:`(N_1, N_2, \ldots , N_k)`. - - Example:: - - >>> x = torch.tensor([1, 2, 3]) - >>> y = torch.tensor([4, 5, 6]) - >>> grid_x, grid_y = torch.meshgrid(x, y) - >>> grid_x - tensor([[1, 1, 1], - [2, 2, 2], - [3, 3, 3]]) - >>> grid_y - tensor([[4, 5, 6], - [4, 5, 6], - [4, 5, 6]]) - """ - if not torch.jit.is_scripting(): - if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): - return handle_torch_function(meshgrid, tensors, *tensors) + if has_torch_function(tensors): + return handle_torch_function(meshgrid, tensors, *tensors) if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)): # the old interface of passing the operands as one list argument tensors = tensors[0] # type: ignore @@ -399,9 +460,18 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, return_complex: Optional[bool] = None) -> Tensor: r"""Short-time Fourier transform (STFT). + .. warning:: + From version 1.8.0, :attr:`return_complex` must always be given + explicitly for real inputs and `return_complex=False` has been + deprecated. Strongly prefer `return_complex=True` as in a future + pytorch release, this function will only return complex tensors. + + Note that :func:`torch.view_as_real` can be used to recover a real + tensor with an extra last dimension for real and imaginary components. + The STFT computes the Fourier transform of short overlapping windows of the input. This giving frequency components of the signal as they change over - time. The interface of this function is modeled after librosa_. + time. The interface of this function is modeled after the librosa_ stft function. .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html @@ -457,10 +527,6 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, the output is a ``input.dim() + 2`` dimensional real tensor where the last dimension represents the real and imaginary components. - .. warning:: - From pytorch 1.8.0, :attr:`return_complex` will default to ``True`` - for all input types. - Returns either a complex tensor of size :math:`(* \times N \times T)` if :attr:`return_complex` is true, or a real tensor of size :math:`(* \times N \times T \times 2)`. Where :math:`*` is the optional batch size of @@ -471,7 +537,7 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, This function changed signature at version 0.4.1. Calling with the previous signature may cause error or return incorrect result. - Arguments: + Args: input (Tensor): the input tensor n_fft (int): size of Fourier transform hop_length (int, optional): the distance between neighboring sliding window @@ -498,19 +564,18 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, Tensor: A tensor containing the STFT result with shape described above """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length, - window=window, center=center, pad_mode=pad_mode, normalized=normalized, - onesided=onesided, return_complex=return_complex) + if has_torch_function_unary(input): + return handle_torch_function( + stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length, + window=window, center=center, pad_mode=pad_mode, normalized=normalized, + onesided=onesided, return_complex=return_complex) # TODO: after having proper ways to map Python strings to ATen Enum, move # this and F.pad to ATen. if center: signal_dim = input.dim() extended_shape = [1] * (3 - signal_dim) + list(input.size()) pad = int(n_fft // 2) - input = F.pad(input.view(extended_shape), (pad, pad), pad_mode) + input = F.pad(input.view(extended_shape), [pad, pad], pad_mode) input = input.view(input.shape[-signal_dim:]) return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore normalized, onesided, return_complex) @@ -548,11 +613,15 @@ def istft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, [1] D. W. Griffin and J. S. Lim, "Signal estimation from modified short-time Fourier transform," IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984. - Arguments: + Args: input (Tensor): The input tensor. Expected to be output of :func:`~torch.stft`, can either be complex (``channel``, ``fft_size``, ``n_frame``), or real (``channel``, ``fft_size``, ``n_frame``, 2) where the ``channel`` dimension is optional. + + .. deprecated:: 1.8.0 + Real input is deprecated, use complex inputs as returned by + ``stft(..., return_complex=True)`` instead. n_fft (int): Size of Fourier transform hop_length (Optional[int]): The distance between neighboring sliding window frames. (Default: ``n_fft // 4``) @@ -576,12 +645,11 @@ def istft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, Returns: Tensor: Least squares estimation of the original signal of size (..., signal_length) """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - istft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length, - window=window, center=center, normalized=normalized, onesided=onesided, - length=length, return_complex=return_complex) + if has_torch_function_unary(input): + return handle_torch_function( + istft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length, + window=window, center=center, normalized=normalized, onesided=onesided, + length=length, return_complex=return_complex) return _VF.istft(input, n_fft, hop_length, win_length, window, center, # type: ignore normalized, onesided, length, return_complex) @@ -612,7 +680,7 @@ def _unique_impl(input: Tensor, sorted: bool = True, Sorting could be slow, so if your input tensor is already sorted, it is recommended to use :func:`torch.unique_consecutive` which avoids the sorting. - Arguments: + Args: input (Tensor): the input tensor sorted (bool): Whether to sort the unique elements in ascending order before returning as output. @@ -660,11 +728,10 @@ def _unique_impl(input: Tensor, sorted: bool = True, [ 1, 2]]) """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - unique, (input,), input, sorted=sorted, return_inverse=return_inverse, - return_counts=return_counts, dim=dim) + if has_torch_function_unary(input): + return handle_torch_function( + unique, (input,), input, sorted=sorted, return_inverse=return_inverse, + return_counts=return_counts, dim=dim) if dim is not None: output, inverse_indices, counts = _VF.unique_dim( # type: ignore @@ -693,7 +760,7 @@ def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False, only eliminates consecutive duplicate values. This semantics is similar to `std::unique` in C++. - Arguments: + Args: input (Tensor): the input tensor return_inverse (bool): Whether to also return the indices for where elements in the original input ended up in the returned unique list. @@ -736,11 +803,10 @@ def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False, >>> counts tensor([2, 2, 1, 2, 1]) """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - unique_consecutive, (input,), input, return_inverse=return_inverse, - return_counts=return_counts, dim=dim) + if has_torch_function_unary(input): + return handle_torch_function( + unique_consecutive, (input,), input, return_inverse=return_inverse, + return_counts=return_counts, dim=dim) output, inverse_indices, counts = _VF.unique_consecutive( # type: ignore input, return_inverse=return_inverse, return_counts=return_counts, dim=dim) return output, inverse_indices, counts @@ -749,9 +815,8 @@ def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False, def _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return _unique_impl(input, sorted, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_impl(input, sorted, return_inverse, return_counts, dim) output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim) return output, counts @@ -760,9 +825,8 @@ def _return_counts(input, sorted=True, return_inverse=False, return_counts=False def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return _unique_impl(input, sorted, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_impl(input, sorted, return_inverse, return_counts, dim) output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim) return output @@ -771,9 +835,8 @@ def _return_output(input, sorted=True, return_inverse=False, return_counts=False def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return _unique_impl(input, sorted, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_impl(input, sorted, return_inverse, return_counts, dim) output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim) return output, inverse_indices @@ -814,9 +877,8 @@ def _return_inverse(input, sorted=True, return_inverse=False, return_counts=Fals def _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return _unique_consecutive_impl(input, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_consecutive_impl(input, return_inverse, return_counts, dim) output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim) return output, counts @@ -825,9 +887,8 @@ def _consecutive_return_counts(input, return_inverse=False, return_counts=False, def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, Optional[int]) -> Tensor - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return _unique_consecutive_impl(input, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_consecutive_impl(input, return_inverse, return_counts, dim) output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim) return output @@ -836,9 +897,8 @@ def _consecutive_return_output(input, return_inverse=False, return_counts=False, def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None): # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return _unique_consecutive_impl(input, return_inverse, return_counts, dim) + if has_torch_function_unary(input): + return _unique_consecutive_impl(input, return_inverse, return_counts, dim) output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim) return output, inverse_indices @@ -876,7 +936,7 @@ def _consecutive_return_inverse(input, return_inverse=False, return_counts=False unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__ -def tensordot(a, b, dims=2): +def tensordot(a, b, dims=2, out=None): r"""Returns a contraction of a and b over multiple dimensions. :attr:`tensordot` implements a generalized matrix product. @@ -884,7 +944,7 @@ def tensordot(a, b, dims=2): Args: a (Tensor): Left tensor to contract b (Tensor): Right tensor to contract - dims (int or tuple of two lists of integers): number of dimensions to + dims (int or Tuple[List[int]] containing two lists): number of dimensions to contract or explicit lists of dimensions for :attr:`a` and :attr:`b` respectively @@ -919,10 +979,15 @@ def tensordot(a, b, dims=2): [ 3.3161, 0.0704, 5.0187, -0.4079, -4.3126, 4.8744], [ 0.8223, 3.9445, 3.2168, -0.2400, 3.4117, 1.7780]]) + >>> a = torch.randn(3, 5, 4, 6) + >>> b = torch.randn(6, 4, 5, 3) + >>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0])) + tensor([[ 7.7193, -2.4867, -10.3204], + [ 1.5513, -14.4737, -6.5113], + [ -0.2850, 4.2573, -3.5997]]) """ - if not torch.jit.is_scripting(): - if (type(a) is not Tensor or type(b) is not Tensor) and has_torch_function((a, b)): - return handle_torch_function(tensordot, (a, b), a, b, dims=dims) + if has_torch_function_variadic(a, b): + return handle_torch_function(tensordot, (a, b), a, b, dims=dims) if isinstance(dims, (list, tuple)) or \ (isinstance(dims, torch.Tensor) and dims.numel() > 1): dims_a, dims_b = dims @@ -933,13 +998,16 @@ def tensordot(a, b, dims=2): raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}") dims_a = list(range(-dims, 0)) dims_b = list(range(dims)) - return _VF.tensordot(a, b, dims_a, dims_b) # type: ignore + if out is None: + return _VF.tensordot(a, b, dims_a, dims_b) # type: ignore + else: + return _VF.tensordot(a, b, dims_a, dims_b, out=out) # type: ignore def cartesian_prod(*tensors): """Do cartesian product of the given sequence of tensors. The behavior is similar to python's `itertools.product`. - Arguments: + Args: *tensors: any number of 1 dimensional tensors. Returns: @@ -963,15 +1031,14 @@ def cartesian_prod(*tensors): [3, 4], [3, 5]]) """ - if not torch.jit.is_scripting(): - if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): - return handle_torch_function(cartesian_prod, tensors, *tensors) + if has_torch_function(tensors): + return handle_torch_function(cartesian_prod, tensors, *tensors) return _VF.cartesian_prod(tensors) # type: ignore def block_diag(*tensors): """Create a block diagonal matrix from provided tensors. - Arguments: + Args: *tensors: One or more tensors with 0, 1, or 2 dimensions. Returns: @@ -998,7 +1065,7 @@ def block_diag(*tensors): [0, 0, 0, 0, 0, 0, 0, 0, 0, 5], [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]]) """ - if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): + if has_torch_function(tensors): return handle_torch_function(block_diag, tensors, *tensors) return torch._C._VariableFunctions.block_diag(tensors) # type: ignore @@ -1045,10 +1112,9 @@ def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'): [2.7138, 3.8322], [2.2830, 0.3791]]) """ - if not torch.jit.is_scripting(): - if (type(x1) is not Tensor or type(x2) is not Tensor) and has_torch_function((x1, x2)): - return handle_torch_function( - cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode) + if has_torch_function_variadic(x1, x2): + return handle_torch_function( + cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode) if compute_mode == 'use_mm_for_euclid_dist_if_necessary': return _VF.cdist(x1, x2, p, None) # type: ignore elif compute_mode == 'use_mm_for_euclid_dist': @@ -1085,9 +1151,8 @@ def atleast_1d(*tensors): >>> torch.atleast_1d((x,y)) (tensor([0.5000]), tensor([1.])) """ - if not torch.jit.is_scripting(): - if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): - return handle_torch_function(atleast_1d, tensors, *tensors) + if has_torch_function(tensors): + return handle_torch_function(atleast_1d, tensors, *tensors) if len(tensors) == 1: tensors = tensors[0] return _VF.atleast_1d(tensors) # type: ignore @@ -1120,9 +1185,8 @@ def atleast_2d(*tensors): >>> torch.atleast_2d((x,y)) (tensor([[0.5000]]), tensor([[1.]])) """ - if not torch.jit.is_scripting(): - if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): - return handle_torch_function(atleast_2d, tensors, *tensors) + if has_torch_function(tensors): + return handle_torch_function(atleast_2d, tensors, *tensors) if len(tensors) == 1: tensors = tensors[0] return _VF.atleast_2d(tensors) # type: ignore @@ -1164,9 +1228,8 @@ def atleast_3d(*tensors): >>> torch.atleast_3d((x,y)) (tensor([[[0.5000]]]), tensor([[[1.]]])) """ - if not torch.jit.is_scripting(): - if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): - return handle_torch_function(atleast_3d, tensors, *tensors) + if has_torch_function(tensors): + return handle_torch_function(atleast_3d, tensors, *tensors) if len(tensors) == 1: tensors = tensors[0] return _VF.atleast_3d(tensors) # type: ignore @@ -1214,29 +1277,46 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa .. warning:: torch.norm is deprecated and may be removed in a future PyTorch release. - Use :func:`torch.linalg.norm` instead. + Use :func:`torch.linalg.norm` instead, but note that :func:`torch.linalg.norm` + has a different signature and slightly different behavior that is + more consistent with NumPy's numpy.linalg.norm. Args: - input (Tensor): the input tensor + input (Tensor): The input tensor. Its data type must be either a floating + point or complex type. For complex inputs, the norm is calculated using the + absolute value of each element. If the input is complex and neither + :attr:`dtype` nor :attr:`out` is specified, the result's data type will + be the corresponding floating point type (e.g. float if :attr:`input` is + complexfloat). + p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'`` The following norms can be calculated: - ===== ============================ ========================== - ord matrix norm vector norm - ===== ============================ ========================== - None Frobenius norm 2-norm - 'fro' Frobenius norm -- - 'nuc' nuclear norm -- - Other as vec norm when dim is None sum(abs(x)**ord)**(1./ord) - ===== ============================ ========================== - - dim (int, 2-tuple of ints, 2-list of ints, optional): If it is an int, - vector norm will be calculated, if it is 2-tuple of ints, matrix norm - will be calculated. If the value is None, matrix norm will be calculated - when the input tensor only has two dimensions, vector norm will be - calculated when the input tensor only has one dimension. If the input - tensor has more than two dimensions, the vector norm will be applied to - last dimension. + ====== ============== ========================== + ord matrix norm vector norm + ====== ============== ========================== + 'fro' Frobenius norm -- + 'nuc' nuclear norm -- + Number -- sum(abs(x)**ord)**(1./ord) + ====== ============== ========================== + + The vector norm can be calculated across any number of dimensions. + The corresponding dimensions of :attr:`input` are flattened into + one dimension, and the norm is calculated on the flattened + dimension. + + Frobenius norm produces the same result as ``p=2`` in all cases + except when :attr:`dim` is a list of three or more dims, in which + case Frobenius norm throws an error. + + Nuclear norm can only be calculated across exactly two dimensions. + + dim (int, tuple of ints, list of ints, optional): + Specifies which dimension or dimensions of :attr:`input` to + calculate the norm across. If :attr:`dim` is ``None``, the norm will + be calculated across all dimensions of :attr:`input`. If the norm + type indicated by :attr:`p` does not support the specified number of + dimensions, an error will occur. keepdim (bool, optional): whether the output tensors have :attr:`dim` retained or not. Ignored if :attr:`dim` = ``None`` and :attr:`out` = ``None``. Default: ``False`` @@ -1246,6 +1326,12 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa returned tensor. If specified, the input tensor is casted to :attr:'dtype' while performing the operation. Default: None. + .. note:: + Even though ``p='fro'`` supports any number of dimensions, the true + mathematical definition of Frobenius norm only applies to tensors with + exactly two dimensions. :func:`torch.linalg.norm` with ``ord='fro'`` aligns + with the mathematical definition, since it can only be applied across + exactly two dimensions. Example:: @@ -1273,14 +1359,10 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :]) (tensor(3.7417), tensor(11.2250)) """ - warnings.warn(( - "torch.norm is deprecated and may be removed in a future PyTorch release. " - "Use torch.linalg.norm instead.")) - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype) + if has_torch_function_unary(input): + return handle_torch_function( + norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype) ndim = input.dim() @@ -1310,7 +1392,7 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa raise ValueError("dtype argument is not supported in frobenius norm") if _dim is None: - _dim = [i for i in range(ndim)] # noqa: C416 TODO: rewrite as list(range(m)) + _dim = list(range(ndim)) if out is None: return _VF.frobenius_norm(input, _dim, keepdim=keepdim) # type: ignore else: @@ -1331,7 +1413,7 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa raise RuntimeError(f"only valid string values are 'fro' and 'nuc', found {p}") else: if _dim is None: - _dim = [i for i in range(ndim)] # noqa: C416 TODO: rewrite as list(range(m)) + _dim = list(range(ndim)) if out is None: if dtype is None: @@ -1373,9 +1455,8 @@ def chain_matmul(*matrices): .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition """ - if not torch.jit.is_scripting(): - if any(type(t) is not Tensor for t in matrices) and has_torch_function(matrices): - return handle_torch_function(chain_matmul, matrices, *matrices) + if has_torch_function(matrices): + return handle_torch_function(chain_matmul, matrices, *matrices) return _VF.chain_matmul(matrices) # type: ignore @@ -1409,7 +1490,11 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None): .. note:: ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`. - Arguments: + .. warning:: + The LU factorization does have backward support, + but only for square inputs of full rank. + + Args: A (Tensor): the tensor to factor of size :math:`(*, m, n)` pivot (bool, optional): controls whether pivoting is done. Default: ``True`` get_infos (bool, optional): if set to ``True``, returns an info IntTensor. @@ -1424,7 +1509,12 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None): - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)` - - **pivots** (*IntTensor*): the pivots of size :math:`(*, m)` + - **pivots** (*IntTensor*): the pivots of size :math:`(*, \text{min}(m, n))`. + ``pivots`` stores all the intermediate transpositions of rows. + The final permutation ``perm`` could be reconstructed by + applying ``swap(perm[i], perm[pivots[i] - 1])`` for ``i = 0, ..., pivots.size(-1) - 1``, + where ``perm`` is initially the identity permutation of :math:`m` elements + (essentially this is what :func:`torch.lu_unpack` is doing). - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of size :math:`(*)` where non-zero values indicate whether factorization for the matrix or @@ -1450,10 +1540,26 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None): ... print('LU factorization succeeded for all samples!') LU factorization succeeded for all samples! """ + if not torch._jit_internal.is_scripting(): + if A.requires_grad: + if not (A.size(-2) == A.size(-1) and A.dtype.is_floating_point): + raise ValueError( + 'lu.backward works only with batches of squared full-rank matrices' + ' of floating types.' + ) + + return _LU.apply(A, pivot, get_infos) + else: + if A.requires_grad: + raise RuntimeError( + 'Script and require gradients is not supported at the moment.' + 'If you just want to do the forward, use .detach()' + 'on the input before calling the function.' + ) + # If get_infos is True, then we don't need to check for errors and vice versa return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos)) - if TYPE_CHECKING: _ListOrSeq = Sequence[Tensor] else: @@ -1468,10 +1574,9 @@ def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None: def _lu_with_infos(A, pivot=True, get_infos=False, out=None): # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor] - if not torch.jit.is_scripting(): - if type(A) is not Tensor and has_torch_function((A,)): - return handle_torch_function( - lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out) + if has_torch_function_unary(A): + return handle_torch_function( + lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out) result = _lu_impl(A, pivot, get_infos, out) if out is not None: _check_list_size(len(out), get_infos, out) @@ -1484,10 +1589,9 @@ def _lu_with_infos(A, pivot=True, get_infos=False, out=None): def _lu_no_infos(A, pivot=True, get_infos=False, out=None): # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor] # need to check for torch_function here so that we exit if - if not torch.jit.is_scripting(): - if type(A) is not Tensor and has_torch_function((A,)): - return handle_torch_function( - lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out) + if has_torch_function_unary(A): + return handle_torch_function( + lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out) result = _lu_impl(A, pivot, get_infos, out) if out is not None: _check_list_size(len(out), get_infos, out) diff --git a/torch/futures/__init__.py b/torch/futures/__init__.py index 64b0b9671edc8..f1da04803b072 100644 --- a/torch/futures/__init__.py +++ b/torch/futures/__init__.py @@ -51,7 +51,7 @@ def then(self, callback): # type: (Callable[[Future[T]], S]) -> Future[S] this ``Future``. The callback function can use the ``Future.wait()`` API to get the value. - Arguments: + Args: callback(``Callable``): a ``Callable`` that takes this ``Future`` as the only argument. @@ -81,13 +81,47 @@ def then(self, callback): # type: (Callable[[Future[T]], S]) -> Future[S] """ return cast(Future[S], super().then(callback)) + # Have to use string annotations because PEP-0563 is not available in 3.6 + def _add_done_callback(self, callback): # type: (Callable[[Future[T]], None]) -> None + r""" + Append the given callback function to this ``Future``, which will be run + when the ``Future`` is completed. Multiple callbacks can be added to + the same ``Future``, and will be invoked in the same order as they were + added. The callback must take one argument, which is the reference to + this ``Future``. The callback function can use the ``Future.wait()`` API + to get the value. + + We recommend that you use the ``then`` API as it provides a way to synchronize + after your callback has completed. ``add_done_callback`` can be cheaper if your + callback does not return anything. But both ``then`` and ``add_done_callback`` + use the same callback registration API under the hood, and thus the order of + their callbacks will be maintained even if their calls are interleaved. + + Args: + callback(``None``): a ``Callable`` that takes in no arguments + + Example:: + >>> import torch + >>> + >>> def callback(): + >>> print(f"This will run after the future has finished.") + >>> + >>> fut = torch.futures.Future() + >>> fut.add_done_callback(callback) + >>> fut.set_result(5) + >>> + >>> # Outputs are: + >>> # This will run after the future has finished. + """ + super().add_done_callback(callback) + def set_result(self, result: T) -> None: r""" Set the result for this ``Future``, which will mark this ``Future`` as completed and trigger all attached callbacks. Note that a ``Future`` cannot be marked completed twice. - Arguments: + Args: result (object): the result object of this ``Future``. Example:: @@ -118,7 +152,7 @@ def collect_all(futures: List[Future]) -> Future[List[Future]]: combined :class:`~torch.futures.Future` that is completed when all of the sub-futures are completed. - Arguments: + Args: futures (list): a list of :class:`~torch.futures.Future` objects. Returns: @@ -151,7 +185,7 @@ def wait_all(futures: List[Future]) -> List: Waits for all provided futures to be complete, and returns the list of completed values. - Arguments: + Args: futures (list): a list of :class:`~torch.futures.Future` object. Returns: diff --git a/torch/fx/OVERVIEW.md b/torch/fx/OVERVIEW.md new file mode 100644 index 0000000000000..7247ae78459d6 --- /dev/null +++ b/torch/fx/OVERVIEW.md @@ -0,0 +1,133 @@ +# FX Technical Overview (WIP) + +FX is a toolkit for pass writers to facilitate Python-to-Python transformation of `nn.Module` instances. This toolkit aims to support a subset of Python language semantics—rather than the whole Python language—to facilitate ease of implementation of transforms. Currently, this feature is under a Beta release and its API may change. + +## Table of Contents + +- [FX Technical Overview](#fx-technical-overview) + - [Table of Contents](#table-of-contents) +- [Introduction](#introduction) + - [Motivation](#motivation) + - [Use Cases](#use-cases) + - [Technical Details](#technical-details) +- [Internal Structure](#internal-structure) + - [Graph](#graph) + - [Graph Module](#graph-module) +- [Symbolic Tracing](#symbolic-tracing) + - [About](#about) + - [Tracer](#tracer) + - [Proxy](#proxy) +- [The FX IR](#ir) +- [Transformation and Codegen](#codegen) + +# Introduction + +## Motivation ## + +TODO + +## Use Cases ## + +FX should be used by pass writers to provide functionality for capturing and constructing nn.Module code in a structured way. We do not expect end users to utilize FX directly. A useful property of framing FX in this way is that passes can be seen as functions of the form `pass(in_mod : nn.Module) -> nn.Module`. This means we can create composable pipelines of transformations. + +![An image of a sample nn.Module transformation pipeline that starts with a Quantize transformation, which is then composed with a Split transformation, then a Lower to Accelerator transformation](https://i.imgur.com/TzFIYMi.png "nn.Module transformation pipeline") + +In this example pipeline, we have a Quantize transformation, which is then composed with a Split transformation, then a Lower to Accelerator transformation. Finally, the transformed Modules are compiled with TorchScript for deployment. This last point emphasizes that not only should FX transforms be composable with each other, but their products are composable with other systems like TorchScript compilation or tracing. + +By using `nn.Module` as the interface between passes, FX transforms are interoperable with each other, and the resulting model can be used anywhere an `nn.Module` can be used. + +## Technical Details ## + +The following sections will walk us through the components that transform from original `torch.nn.Module` to FX IR and finally to generated Python code and a GraphModule instance: + +FX’s front-end makes use of the dynamic nature of Python to intercept call-sites for various entities (PyTorch operators, Module invocations, and Tensor method invocations). This functionality is exposed through an API called `torch.fx.symbolic_trace`. We can see how this works by way of an example: + +```python +import torch + +class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter( + torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x): + return self.linear(x + self.param).clamp(min=0.0, max=1.0) + +from torch.fx import symbolic_trace +module = MyModule() +symbolic_traced : torch.fx.GraphModule = symbolic_trace(module) + +input = torch.rand(3, 4) +torch.testing.assert_allclose(symbolic_traced(input), module(input)) +``` + +Here, we set up a simple Module that exercises different language features: fetching a parameter, applying an arithmetic operator, applying a submodule (linear), and applying a Tensor method. `symbolic_trace` returns an instance of GraphModule, which is in itself a subclass of `nn.Module`. We can see that the `symbolic_traced` instance runs and returns the same result as the original module instance module. + +# Internal Structure + +## [Graph](https://pytorch.org/docs/master/fx.html#torch.fx.Graph) ## +TODO + +## [GraphModule](https://pytorch.org/docs/master/fx.html#torch.fx.GraphModule) ## +TODO + +# Symbolic Tracing + +## [Tracer](https://pytorch.org/docs/master/fx.html#torch.fx.Tracer) ## + +`Tracer` is the class that implements the symbolic tracing functionality of `torch.fx.symbolic_trace`. A call to `symbolic_trace(m)` is equivalent to `Tracer().trace(m)`. Tracer can be subclassed to override various behaviors of the tracing process. The different behaviors that can be overridden are described in the docstrings of the methods on the class. + +In the default implementation of `Tracer().trace`, the tracer first creates Proxy objects for all arguments in the `forward` function. (This happens in the call to `create_args_for_root`.) Next, the `forward` function is called with the new Proxy arguments. As the Proxies flow through the program, they record all the operations (`torch` function calls, method calls, and operators) that they touch into the growing FX Graph as Nodes. + +## Proxy ## + +Proxy objects are Node wrappers used by the Tracer to record operations seen during symbolic tracing. The mechanism through which Proxy objects record computation is [`__torch_function__`](https://pytorch.org/docs/stable/notes/extending.html#extending-torch). If any custom Python type defines a method named `__torch_function__`, PyTorch will invoke that `__torch_function__` implementation when an instance of that custom type is passed to a function in the `torch` namespace. In FX, when operations on Proxy are dispatched to the `__torch_function__` handler, the `__torch_function__` handler records the operation in the Graph as a Node. The Node that was recorded in the Graph is then itself wrapped in a Proxy, facilitating further application of ops on that value. + +Consider the following example: + +```python + class M(torch.nn.Module): + def forward(self, x): + return torch.relu(x) + + m = M() + traced = symbolic_trace(m) +``` + +During the call to `symbolic_trace`, the parameter `x` is transformed into a Proxy object and the corresponding Node (a Node with op = “placeholder” and target = “x”) is added to the Graph. Then, the Module is run with Proxies as inputs, and recording happens via the `__torch_function__` dispatch path. + +If you're doing graph transforms, you can wrap your own Proxy method around a raw Node so that you can use the overloaded operators to add additional things to a Graph. + +# The FX IR + +Symbolic tracing captures an intermediate representation (IR), which is represented as a doubly-linked list of Nodes. + +Node is the data structure that represents individual operations within a Graph. For the most part, Nodes represent callsites to various entities, such as operators, methods, and Modules (some exceptions include Nodes that specify function inputs and outputs). Each Node has a function specified by its `op` property. The Node semantics for each value of `op` are as follows: + +- `placeholder` represents a function input. The `name` attribute specifies the name this value will take on. `target` is similarly the name of the argument. `args` holds either: 1) nothing, or 2) a single argument denoting the default parameter of the function input. `kwargs` is don't-care. Placeholders correspond to the function parameters (e.g. `x`) in the graph printout. +- `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy. `args` and `kwargs` are don't-care +- `call_function` applies a free function to some values. `name` is similarly the name of the value to assign to. `target` is the function to be applied. `args` and `kwargs` represent the arguments to the function, following the Python calling convention +- `call_module` applies a module in the module hierarchy's `forward()` method to given arguments. `name` is as previous. `target` is the fully-qualified name of the module in the module hierarchy to call. `args` and `kwargs` represent the arguments to invoke the module on, *including the self argument*. +- `call_method` calls a method on a value. `name` is as similar. `target` is the string name of the method to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the module on, *including the self argument* +- `output` contains the output of the traced function in its `args[0]` attribute. This corresponds to the "return" statement in the Graph printout. + +To facilitate easier analysis of data dependencies, Nodes have read-only properties `input_nodes` and `users`, which specify which Nodes in the Graph are used by this Node and which Nodes use this Node, respectively. Although Nodes are represented as a doubly-linked list, the use-def relationships form an acyclic graph and can be traversed as such. + +# Transformation and Codegen + +An invocation of `symbolic_traced` above requires a valid `forward()` method to be defined on the Module instance. How does this work? GraphModule actually generates valid Python source code based on the IR it is instantiated with. This can be seen by accessing the code attribute on the GraphModule: `print(symbolic_traced.code)`. + +After symbolic tracing, the code given under [Technical Details](#technical-details) is represented as follows: + +```python +def forward(self, x): + param = self.param + add_1 = x + param; x = param = None + linear_1 = self.linear(add_1); add_1 = None + clamp_1 = linear_1.clamp(min = 0.0, max = 1.0); linear_1 = None + return clamp_1 +``` + +This is the core of why FX is a Python-to-Python translation toolkit. Outside users can treat the results of FX transformations as they would any other `nn.Module` instance. diff --git a/torch/fx/__init__.py b/torch/fx/__init__.py index 1855114607409..84c05a00d4b8c 100644 --- a/torch/fx/__init__.py +++ b/torch/fx/__init__.py @@ -1,92 +1,87 @@ # type: ignore r''' -**This feature is experimental and its stability is not currently guaranteed. Proceed at your own risk** +**This feature is under a Beta release and its API may change.** -FX (Functional Transformations) is a toolkit for capturing and transforming functional PyTorch programs. It +FX is a toolkit for capturing and transforming functional PyTorch programs. It consists of GraphModule and a corresponding intermediate representation (IR). When GraphModule is constructed -with an `nn.Module` instance as its argument, GraphModule will trace through the computation of that Module's -`forward` method symbolically and record those operations in the FX intermediate representation. +with an ``nn.Module`` instance as its argument, GraphModule will trace through the computation of that Module's +``forward`` method symbolically and record those operations in the FX intermediate representation. -``` -import torch -from torch.fx import GraphModule +.. code-block:: python -class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) + import torch + import torch.fx - def forward(self, x): - return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) - -m = MyModule() -gm = symbolic_trace(m) -``` - -The Intermediate Representation centers around a 5-opcode format: - -``` -from tabulate import tabulate -node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in gm.graph.nodes] -print(tabulate(node_specs, headers=['opcode', 'name', 'target', 'args', 'kwargs'])) -``` - -``` -opcode name target args kwargs -------------- ------------- ------------------------------------------------------- ------------------ ----------- -placeholder x x () {} -get_attr linear_weight linear.weight () {} -call_function add_1 (x, linear_weight) {} -call_module linear_1 linear (add_1,) {} -call_method relu_2 relu [linear_1] {} -call_function sum_1 (relu_2,) {'dim': -1} -call_function topk_1 (sum_1, 3) {} -``` - -The semantics are as follows: - -- `placeholder` represents a function input. The `name` attribute specifies the name this value will take on. - `target` is similarly the name of the argument. `args` and `kwargs` are don't-care -- `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the - fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy. - `args` and `kwargs` are don't-care -- `call_function` applies a free function to some values. `name` is similarly the name of the value to assign - to. `target` is the function to be applied. `args` and `kwargs` represent the arguments to the function, + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x): + return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) + + m = MyModule() + gm = torch.fx.symbolic_trace(m) + +The Intermediate Representation centers around a 5-opcode format:: + + print(gm.graph) + +.. code-block:: text + + graph(x): + %linear_weight : [#users=1] = self.linear.weight + %add_1 : [#users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) + %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) + %relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) + %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) + %topk_1 : [#users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) + return topk_1 + +The Node semantics are as follows: + +- ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on. + ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument + denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to + the function parameters (e.g. ``x``) in the graph printout. +- ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the + fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy. + ``args`` and ``kwargs`` are don't-care +- ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign + to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function, following the Python calling convention -- `call_module` applies a module in the module hierarchy's `forward()` method to given arguments. `name` is - as previous. `target` is the fully-qualified name of the module in the module hierarchy to call. - `args` and `kwargs` represent the arguments to invoke the module on, _including the self argument_. -- `call_method` calls a method on a value. `name` is as similar. `target` is the string name of the method - to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the module on, - _including the self argument_. - -GraphModule automatically generates Python code for the operations it symbolically observed: - -``` -print(gm.code) -``` - -``` -def forward(self, x): - self = self.root - linear_weight = self.linear.weight - add_1 = x + linear_weight - linear_1 = self.linear(add_1) - relu_2 = linear_1.relu() - sum_1 = torch.sum(relu_2, dim = -1) - topk_1 = torch.topk(sum_1, 3) - - - return topk_1 -``` - -Because this code is valid PyTorch code, the resulting `GraphModule` can be used in any context another -`nn.Module` can be used, including in TorchScript tracing/compilation. +- ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is + as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call. + ``args`` and ``kwargs`` represent the arguments to invoke the module on, *including the self argument*. +- ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method + to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on, + *including the self argument* +- ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement + in the Graph printout. + +GraphModule automatically generates Python code for the operations it symbolically observed:: + + print(gm.code) + +.. code-block:: python + + import torch + def forward(self, x): + linear_weight = self.linear.weight + add_1 = x + linear_weight; x = linear_weight = None + linear_1 = self.linear(add_1); add_1 = None + relu_1 = linear_1.relu(); linear_1 = None + sum_1 = torch.sum(relu_1, dim = -1); relu_1 = None + topk_1 = torch.topk(sum_1, 3); sum_1 = None + return topk_1 + +Because this code is valid PyTorch code, the resulting ``GraphModule`` can be used in any context another +``nn.Module`` can be used, including in TorchScript tracing/compilation. ''' from .graph_module import GraphModule -from .symbolic_trace import symbolic_trace, Tracer -from .graph import Graph, map_arg -from .node import Node +from .symbolic_trace import symbolic_trace, Tracer, wrap +from .graph import Graph +from .node import Node, map_arg from .proxy import Proxy diff --git a/torch/fx/__init__.pyi b/torch/fx/__init__.pyi new file mode 100644 index 0000000000000..9939f25f4c147 --- /dev/null +++ b/torch/fx/__init__.pyi @@ -0,0 +1,5 @@ +from .graph import Graph as Graph +from .graph_module import GraphModule as GraphModule +from .node import Node as Node, map_arg as map_arg +from .proxy import Proxy as Proxy +from .symbolic_trace import Tracer as Tracer, symbolic_trace as symbolic_trace, wrap as wrap diff --git a/torch/fx/examples/inspect_utils.py b/torch/fx/examples/inspect_utils.py new file mode 100644 index 0000000000000..26833c1c6e14c --- /dev/null +++ b/torch/fx/examples/inspect_utils.py @@ -0,0 +1,14 @@ +from tabulate import tabulate + +""" +The methods in this file may be used to examine the state of the code +and how the Graph evolves at any time during execution. If you're +unsure of what's happening in an example in this folder, try adding one +of these methods before and after a key line. +""" + +def print_IR(graph): + node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] + for n in graph.nodes] + print(tabulate(node_specs, + headers=['opcode', 'name', 'target', 'args', 'kwargs'])) diff --git a/torch/fx/examples/replace_op.py b/torch/fx/examples/replace_op.py new file mode 100644 index 0000000000000..f938ecc0f56bf --- /dev/null +++ b/torch/fx/examples/replace_op.py @@ -0,0 +1,63 @@ +import torch +from torch.fx import symbolic_trace +import operator + +""" +How to replace one op with another +1. Iterate through all Nodes in your GraphModule's Graph. +2. Determine if the current Node should be replaced. (Suggested: match +on the Node's ``target`` attribute). +3. Create a replacement Node and add it to the Graph. +4. Use the FX built-in ``replace_all_uses_with`` to replace all uses of +the current Node with the replacement. +5. Delete the old Node from the graph. +6. Call ``recompile`` on the GraphModule. This updates the generated +Python code to reflect the new Graph state. + +Currently, FX does not provide any way to guarantee that replaced +operators are syntactically valid. It's up to the user to confirm that +any new operators will work with the existing operands. + +The following code demonstrates an example of replacing any instance of +addition with a bitwise AND. + +To examine how the Graph evolves during op replacement, add the +statement `print(traced.graph)` after the line you want to inspect. +Alternatively, see the Nodes in a tabular format by adding +`from inspect_utils import print_IR` to the top of this file and calling +`print_IR(traced.graph)`. +""" + +# Sample module +class M(torch.nn.Module): + def forward(self, x, y): + return x + y, torch.add(x, y), x.add(y) + +# Symbolically trace an instance of the module +traced = symbolic_trace(M()) + +# As demonstrated in the above example, there are several different ways +# to denote addition. The possible cases are: +# 1. `x + y` - A `call_function` Node with target `operator.add`. +# We can match for equality on that `operator.add` directly. +# 2. `torch.add(x, y)` - A `call_function` Node with target +# `torch.add`. Similarly, we can match this function directly. +# 3. `x.add(y)` - The Tensor method call, whose target we can match +# as a string. + +patterns = set([operator.add, torch.add, "add"]) + +# Go through all the nodes in the Graph +for n in traced.graph.nodes: + # If the target matches one of the patterns + if any(n.target == pattern for pattern in patterns): + # Set the insert point, add the new node, and replace all uses + # of `n` with the new node + with traced.graph.inserting_after(n): + new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs) + n.replace_all_uses_with(new_node) + # Remove the old node from the graph + traced.graph.erase_node(n) + +# Don't forget to recompile! +traced.recompile() diff --git a/torch/fx/experimental/__init__.py b/torch/fx/experimental/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py new file mode 100644 index 0000000000000..a995a58c5774c --- /dev/null +++ b/torch/fx/experimental/accelerator_partitioner.py @@ -0,0 +1,879 @@ +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node, map_arg +from typing import Dict, List, Set, NamedTuple, Tuple +import torch +from torch.fx.experimental.subgraph_creation_example import split_module +import operator +from torch.fx.experimental.partitioner_utils import Partition, \ + Device, PartitionerConfig, get_partition_to_latency_mapping,\ + get_latency_of_partitioned_graph, NodeLatency, get_extra_size_of, \ + PartitionMode + +class DAGNode(): + """DAGNode class maintains useful information for a partition (submodule), + and its input submodules and output submodules. + """ + def __init__( + self, + submodule_node: Node, + input_nodes: List[Node], + output_nodes: List[Node], + logical_device_ids: List[int], + size_bytes: int + ) -> None: + self.submodule_node: Node = submodule_node + self.input_nodes: List[Node] = input_nodes + self.output_nodes: List[Node] = output_nodes + self.logical_device_ids: List[int] = logical_device_ids + self.size_bytes = size_bytes + + def __str__(self) -> str: + return str(self.submodule_node) + +class DAG: + """DAG class contains all the DAG nodes""" + def __init__(self) -> None: + self.nodes: List[DAGNode] = [] + + def create_node( + self, + submodule_node: Node, + input_nodes: List[Node], + output_nodes: List[Node], + logical_devices: List[int], + size_bytes: int + ) -> None: + node = DAGNode(submodule_node, input_nodes, output_nodes, logical_devices, size_bytes) + self.nodes.append(node) + +class PartitionResult(NamedTuple): + """NameTuple used for returning DAG and a new fx module + """ + dag: DAG + module_with_submodules: GraphModule + +"""Followings are some helper functions for partition manipulation""" +def reset_partition_device(partitions): + for partition in partitions: + partition.logical_device_ids = [] + +def combine_two_partitions( + partition_0: Partition, + partition_1: Partition, + partitions: List[Partition] +) -> None: + """Given a list of partitions and its two partitions, + combine these two partitions into a new one appending to the partitions + and remove the previous two partitions from the list of partitions + """ + partition = Partition(len(partitions)) + partition.nodes = partition_0.nodes.union(partition_1.nodes) + partition.recalculate_mem_size() + partitions.append(partition) + partitions.remove(partition_0) + partitions.remove(partition_1) + reorganize_partitions(partitions) + return + +def set_parents_and_children(partitions: List[Partition]) -> None: + """Given a list of partitions, mark parents and children for each partition + """ + # Go through all nodes in a partition. + # If a node's user is in other partition, + # then the other partition is this partition's children. + # This partition is the other partition's parent + for partition in partitions: + partition.children = set() + partition.parents = set() + for partition in partitions: + for node in partition.nodes: + # For each node in the current partition, find its users + users = node.users + for n in users: + # Find which the partition the user node belongs to. + # Note that if the node itself is also belongs to that partition, + # that partition is not the child of the current partition + for p in partitions: + if p != partition and n in p.nodes and node not in p.nodes: + partition.children.add(p) + p.parents.add(partition) + return + +def reorganize_partitions(partitions: List[Partition]) -> None: + """Given a list of partitions, reorganzie partiton id, + its parents and its children for each partition + """ + # Rearrange partition ids + for i, partition in enumerate(partitions): + partition.partition_id = i + set_parents_and_children(partitions) + return + +def get_bfs_level_partition(partitions: List[Partition]) -> None: + """Given a list of partitions, + mark the bfs level for each partition + """ + current_level: Set[Partition] = set() + visited: Set[Partition] = set() + for partition in partitions: + # If a partition has no parent, it should be in root level + if len(partition.parents) == 0: + current_level.add(partition) + next_level: Set[Partition] = set() + level = 0 + # bfs + while current_level: + partition = current_level.pop() + partition.bfs_level = level + visited.add(partition) + children = partition.children + for child in children: + if child not in next_level: + next_level.add(child) + if not current_level: + current_level = next_level.copy() + next_level = set() + level += 1 + return + +def get_node_to_partition_mapping(partitions: List[Partition]) -> Dict[Node, int]: + """Given a list of partitions,return node to partition mapping + """ + node_to_partition: Dict[Node, int] = {} + for partition in partitions: + for node in partition.nodes: + node_to_partition[node] = partition.partition_id + return node_to_partition + +def get_device_to_partitions_mapping(partitions: List[Partition], devices: List[Device]): + """Given a list of partitions and a list of devices, + map each partition into a device. + """ + def calculate_extra_mem_bytes_needed_for(partition: Partition, partitions: List[Partition]): + all_nodes: Set[Node] = set() + for p in partitions: + all_nodes = all_nodes.union(p.nodes) + if len(all_nodes) == 0: + return partition.used_mem_bytes + all_nodes = all_nodes.union(partition.nodes) + extra_size_needed = 0 + for node in partition.nodes: + extra_size_needed += get_extra_size_of(node, all_nodes) + return extra_size_needed + + def find_device_for(partition: Partition): + """Given a partition, find a logical device for the partition + The algorithm is to put the partition on the device + that has just enough mem left for that partition. + device_to_left_mem_bytes is a dictionary between device and its left mem size + sorted by its left mem size + """ + for d in device_to_left_mem_bytes: + extra_size_needed = calculate_extra_mem_bytes_needed_for(partition, device_to_partitions[d]) + if extra_size_needed < device_to_left_mem_bytes[d]: + device_to_partitions[d].append(partition) + partition.logical_device_ids.append(d.logical_id) + device_to_left_mem_bytes[d] -= extra_size_needed + return True + return False + # logical id to device + logical_id_to_device: Dict[int, Device] = {} + # Track partitions on device + device_to_partitions: Dict[Device, List[Partition]] = {} + # Track device's left mem size + device_to_left_mem_bytes: Dict[Device, int] = {} + for d in devices: + logical_id_to_device[d.logical_id] = d + device_to_partitions[d] = [] + device_to_left_mem_bytes[d] = d.available_mem_bytes + # Deal with the partitions that already have a device + # and also collect all partitions without a device (no_device_partitions) + no_device_partitions = [] + for partition in partitions: + if partition.logical_device_ids != []: + logical_id = partition.logical_device_ids[0] + device = logical_id_to_device[logical_id] + device_to_partitions[device] = [partition] + device_to_left_mem_bytes[device] = d.available_mem_bytes - partition.used_mem_bytes + else: + no_device_partitions.append(partition) + # Find devices for all the partitions without a device + found_device = True + for partition in no_device_partitions: + device_to_left_mem_bytes = { + d: left_mem_bytes for d, left_mem_bytes + in sorted(device_to_left_mem_bytes.items(), key=lambda item: item[1]) + } + found_device = find_device_for(partition) + if not found_device: + break + return found_device + +def check_dependency(partition): + """Given a partition,check if there is a circular dependency on + this partition using bfs + """ + visited: Set[Partition] = set([partition]) + queue: List[Partition] = [partition] + while queue: + p = queue.pop(0) + for child in p.children: + if child == partition: + return True + else: + if child not in visited: + visited.add(child) + queue.append(child) + return False + +class Partitioner: + """A fx module may not fit into one device. + Partitioner class helps partition one fx module into submodules (partitions), + so that the submodules can be executed crossing different accelerators. + The main function of this class is self.partition_graph. + It partitions the fx module based on the scheme specified in partition_config + A DAG structure is returned + along with a new fx module with submodule nodes. + """ + def __init__(self) -> None: + self.partitions: List[Partition] = [] + self.node_to_partition: Dict[Node, int] = {} + self.devices: List[Device] = [] + + def partition_graph( + self, + fx_module: GraphModule, + torch_module: torch.nn.Module, + partitioner_config: PartitionerConfig + ) -> PartitionResult: + """Given the fx module, torch module and partitioner_config, + find the partitions, do the partitions, + and then return a DAG and a new fx module with submodule nodes (partitions) + """ + self.graph_module = fx_module + self.torch_module = torch_module + self.devices = partitioner_config.devices + if len(self.devices) == 0: + raise RuntimeError('No devices') + # Check if there are op nodes in the fx module + nodes = self.graph_module.graph.nodes + if all(node.op in {'placeholder', 'get_attr', 'output'} for node in nodes): + raise RuntimeError('No Partition since no operations in the module') + # Calculate total size of the fx module + total_size_of_graph = 0 + for node in nodes: + if node.op == 'output': + break + total_size_of_graph += node.size_bytes.total_size + # Find the device with the max mem size + device_with_max_mem = max(self.devices, key=lambda d: d.available_mem_bytes) + # AOT based partition + if partitioner_config.mode == PartitionMode.aot_based: + self.aot_based_partition( + partitioner_config.node_to_partition_mapping, + partitioner_config.partition_to_logical_device_mapping + ) + # Single partition if the whole module can be fit into one device + elif total_size_of_graph <= device_with_max_mem.available_mem_bytes: + self.find_single_partition(total_size_of_graph) + elif total_size_of_graph > sum([d.available_mem_bytes for d in self.devices]): + raise RuntimeError('Devices have no enough memory for the module') + else: + # Sparse nn based partition + if partitioner_config.mode == PartitionMode.sparse_nn: + available_mem_bytes = self.devices[0].available_mem_bytes + if not all(device.available_mem_bytes == available_mem_bytes for device in self.devices): + raise RuntimeError('All devices must have same memory size!') + # sparse_nn_partition only support same memory size + # TODO: add different size support for sparse_nn_partition + self.sparse_nn_partition(available_mem_bytes) + # Cost aware partition + elif partitioner_config.mode == PartitionMode.cost_aware: + self.cost_aware_partition( + partitioner_config.transfer_rate_bytes_per_sec, + partitioner_config.node_to_latency_mapping + ) + # KL based partition + elif partitioner_config.mode == PartitionMode.kl_based: + self.kl_based_partition( + partitioner_config.transfer_rate_bytes_per_sec, + partitioner_config.node_to_latency_mapping + ) + else: + self.size_based_partition() + module_with_submodules = self.do_partition() + # The DAG contains DAGNodes with info of each partition's input nodes, output nodes + # and how partitions are connected. + dag = self.dump_dag(module_with_submodules) + ret = PartitionResult(dag, module_with_submodules) + return ret + + def find_single_partition(self, total_size_of_graph) -> None: + """Fit the whole fx module into one device + """ + partition_0 = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op == 'output': + break + partition_0.nodes.add(node) + partition_0.used_mem_bytes = total_size_of_graph + partition_0.logical_device_ids = [0] + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def size_based_partition(self) -> None: + """This method is to partition the fx module based on memory size. + It uses greedy approach. The result may not be the best. + The basic idea is: + Step 1: + Find a device which has enough memory to fit the current node, create a empty partition + with the size of that device. + Then keep adding the following nodes into the partition until the partition is full. + Step 2: + Repeat Step 1 until no device left + Step 3: + If some nodes are left, create a partition for each left node (single node partition). + and then try to map those partitions into logical devices with enough mem left. + """ + def find_device_based_on_size(node) -> Device: + """Given a node, this function is to find a logical device + that could fit the node. + """ + mem_size_needed = get_extra_size_of(node, set()) + device = Device('', -1, -1) + for d in self.devices: + if d not in occupied_devices and d.available_mem_bytes >= mem_size_needed: + device = d + break + if device.available_mem_bytes < 0: + raise RuntimeError(str(node) + 'is too large to fit any device') + occupied_devices.append(device) + return device + + # Track partition and its left mem size + partition_to_left_mem_bytes: Dict[Partition, int] = {} + # Track all the devices that have been used + occupied_devices: List[Device] = [] + partition = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op in {'call_module', 'call_method', 'call_function'}: + # Check if there are devices left + if len(self.partitions) <= len(self.devices): + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + # Check if the current partition is the very first partition + if partition.used_mem_bytes == 0: + # Find a device to fit the first node, return available mem size + device = find_device_based_on_size(node) + occupied_devices.append(device) + # Update partition and its left mem size + partition_to_left_mem_bytes[partition] = device.available_mem_bytes + # Update available mem for the current partitio + partition.logical_device_ids.append(device.logical_id) + else: + # The current partition is not the first partition + # Check if the current node can fit into current partition + if partition_to_left_mem_bytes[partition] < total_size_of_input_nodes: + # Check if no device is left + if len(self.partitions) == len(self.devices): + # No device is left + # Put the previous partitions into a list (non_single_node_partitions) + non_single_node_partitions = self.partitions[:] + # Create the first single node partition for the current node + self.create_single_node_partition(node) + continue + # Some devices are still left + # Create a new partition with a mem size that is enough for the current node + device = find_device_based_on_size(node) + partition = self.create_partition() + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + partition_to_left_mem_bytes[partition] = device.available_mem_bytes + partition.logical_device_ids.append(device.logical_id) + partition.add_node(node) + partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes + # Create single node partitions if no device is left + else: + self.create_single_node_partition(node) + reorganize_partitions(self.partitions) + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + # Mapping all partitions into device + found_partition_to_device_mapping = get_device_to_partitions_mapping(self.partitions, self.devices) + if not found_partition_to_device_mapping: + raise RuntimeError("Cannot Get a Valid Partition to Logical Device Mapping") + return + + def do_partition(self) -> GraphModule: + """Return a new fx module with submodule nodes (partitions).""" + module_with_submodules = split_module( + self.graph_module, + self.torch_module, + lambda node: self.node_to_partition[node] + ) + return module_with_submodules + + def dump_dag(self, module_with_submodules: GraphModule) -> DAG: + """Return the dag structure and the new fx module with submodules""" + dag = DAG() + for node in module_with_submodules.graph.nodes: + if node.op == 'output': + break + if node.op in {'placeholder', 'get_attr'}: + continue + if node.target == operator.__getitem__: + continue + input_nodes : Dict[Node, None] = {} + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + # When a node has two or more output nodes, + # it outputs its result to 'getitem' nodes. + # Those 'getitem' nodes are the output node for this node. + # Otherwise, the output node is this node itself. + if len(node.users) > 1: + output_nodes = list(node.users) + else: + output_nodes = [node] + partition_id = int(node.name.rsplit('_', 1)[-1]) + device_ids = self.partitions[partition_id].logical_device_ids + size_bytes = self.partitions[partition_id].used_mem_bytes + dag.create_node(node, list(input_nodes), output_nodes, device_ids, size_bytes) + return dag + + def create_partition(self) -> Partition: + """Create a partition and append it to self.partitions.""" + partition_id = len(self.partitions) + partition = Partition(partition_id) + self.partitions.append(partition) + return partition + + def create_single_node_partition(self, node): + """Create a partition for a single node""" + partition = self.create_partition() + partition.add_node(node) + return + + def sparse_nn_partition(self, available_mem_bytes: int) -> None: + """This method partition a sparse nn module. + It is size based partition but different from size_based_partition, + it only works when all the devices have same memory size (available_mem_bytes). + In the future, devices with different mem sizes will be supported like size_based_partition. + It first traverse all the nodes and do the partitions based on the same memory size. + If the current partition has no enough memory left for a new op node + (call_module, call_method, call_function), a new partition is created. + When crossing the boundary between non-embedding nodes and embedding nodes, + a new partition is created regardlessly. + For example, if the current node is a non-embedding node but the next node is an + embedding node, a new partition is created for the next node. + After the partition, the partitions are combined as much as possible. + The rule is that a non-embedding partition only + combines with another non-embedding one. + So as the embedding partitions. + """ + def combine_partitions_based_on_size(partitions: List[Partition], available_mem_bytes: int) -> None: + """Combining small partitions together to keep as less partitions as possible. + Here is an example of the algorithm to do this: + Assume some partitions, we first sort them based on partiiton used memory size. + [(partition_4, 1), (partition_3, 1), (partition_2, 2), (partition_1, 7), (partition_0, 9)] + The available memory is 10. + step 1: self.find_partition_to_combine_based_on_size() + First, mark bfs level for each partition + Second, look the smallest partition, partition_4: 10 - 1 = 9 + It means any partition has a used memory equal or less than 9 could combine this partition + We go from the largest and selection partition_0. + Check the bfs level for two partitions, if the level difference is less than 2, + it can be combined. + step 2: repeat step 1 until no partitions can be combined + """ + find_combination = True + while find_combination: + # Sort partitions based on memory size + sorted_partitions = sorted(partitions, key=lambda p: p.used_mem_bytes) + # Mark bfs level + get_bfs_level_partition(self.partitions) + find_combination, partitions = \ + find_partition_to_combine_based_on_size( + sorted_partitions, + available_mem_bytes, + partitions + ) + return + + def calculate_mem_bytes_needed(p1, p2): + """Given two partitions, calculate how many mem bytes + are needed if two partitions are combined + """ + nodes = p1.nodes.union(p2.nodes) + mem_bytes_needed = 0 + for node in nodes: + mem_bytes_needed += get_extra_size_of(node, nodes) + return mem_bytes_needed + + def find_partition_to_combine_based_on_size( + sorted_partitions: List[Partition], + available_mem_bytes: int, + partitions: List[Partition] + ) -> Tuple[bool, List[Partition]]: + """step 1 in combine_partition_based_on_size()""" + find_combination = False + smallest_partition = sorted_partitions.pop(0) + for p in sorted_partitions[::-1]: + if abs(smallest_partition.bfs_level - p.bfs_level) <= 1: + # Calculate how many bytes needed if combined + mem_bytes_needed = calculate_mem_bytes_needed(p, smallest_partition) + if mem_bytes_needed <= available_mem_bytes: + combine_two_partitions(p, smallest_partition, self.partitions) + partitions.remove(smallest_partition) + partitions.remove(p) + partitions.append(self.partitions[-1]) + find_combination = True + break + return find_combination, partitions + + def reset_partition_in_sparse_nn(partition, new_partition=True): + """If crossing the boudary between non-embedding nodes and + embedding nodes, create a new partition + """ + if in_embedding_region: + embedding_partitions.append(partition) + else: + non_embedding_partitions.append(partition) + if new_partition: + partition = self.create_partition() + partition.left_mem_bytes = available_mem_bytes + return partition + return None + + def is_embedding_node(node: Node) -> bool: + """Check if a node is an embedding node""" + if node.op == 'call_module': + submodule = self.graph_module + for atom in str(node.target).split('.'): + if not hasattr(submodule, atom): + raise RuntimeError(f'Module {submodule} has no attribute {atom}') + submodule = getattr(submodule, atom) + if 'Embedding' in str(submodule): + return True + return False + + # Track embedding partitons and non-embedding partitions separately + embedding_partitions: List[Partition] = [] + non_embedding_partitions: List[Partition] = [] + # A Flag to check the boundary + in_embedding_region: bool = False + partition = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op in {'call_module', 'call_method', 'call_function'}: + # Check if crossing the boundary between embedding nodes and non embedding nodes + if is_embedding_node(node) != in_embedding_region: + # Crossing the boundary + # Check if the current partition is an empty partition + if partition.used_mem_bytes != 0: + # The current partition isn't an empty partition. Create a new one. + partition = reset_partition_in_sparse_nn(partition) + in_embedding_region = not in_embedding_region + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + if total_size_of_input_nodes + partition.used_mem_bytes > available_mem_bytes: + partition = reset_partition_in_sparse_nn(partition) + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + if total_size_of_input_nodes > available_mem_bytes: + raise RuntimeError(node.target + 'is too large to fit into a device') + partition.add_node(node) + reset_partition_in_sparse_nn(partition, new_partition=False) + # Set parents and children for partitions + set_parents_and_children(self.partitions) + # Combining non-embedding partitions + combine_partitions_based_on_size(non_embedding_partitions, available_mem_bytes) + # Combining embedding partitions + combine_partitions_based_on_size(embedding_partitions, available_mem_bytes) + total_size_of_non_embedding_partitions = 0 + for partition in non_embedding_partitions: + total_size_of_non_embedding_partitions += partition.used_mem_bytes + # Check if devices are enough for all partitions + if len(embedding_partitions) > len(self.devices): + msg = 'Need ' + str(len(embedding_partitions)) + ' devices, but only ' \ + + str(len(self.devices)) + ' provided' + raise RuntimeError(msg) + occupied_devices = [] + for i, partition in enumerate(embedding_partitions): + # Check if all non-embedding partitions can fit into embedding partition devices + if total_size_of_non_embedding_partitions + partition.used_mem_bytes > available_mem_bytes: + raise RuntimeError( + 'partition_' + + str(partition.partition_id) + + '(embedding partition) and non embedding partitions can not fit into one device' + ) + else: + # Add logical device to the partition + partition.logical_device_ids = [self.devices[i].logical_id] + occupied_devices.append(self.devices[i].logical_id) + # Add logical devices to the non_embedding_partitions + for partition in non_embedding_partitions: + partition.logical_device_ids = occupied_devices + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def cost_aware_partition( + self, + transfer_rate_bytes_per_sec: float, + node_to_latency_mapping: Dict[Node, NodeLatency] + ) -> None: + """This method is to partition the fx module based on the cost. + The cost is the total latency of running the whole fx module. + In partitioner_utils.py, the cost model is built. + The cost aware partition algorithm is: + #1. At every begining, each node is a partition. + Then we map all the partitions to the devices + and calculate the cost + #2. Then try to pre-combine any two of the partitions if the two + partitions can be combined. + (the bfs level is less than 2 or two partitions are connected and + can find partition to device mapping) + See if any partition pair could reduce the current cost. + Choose the pair that shows the minimum cost and then combine them + #3. Repeat #2 until the cost cannot be reduced. + """ + def try_combining_partitions( + p0_index, + p1_index, + partitions + ) -> float: + """Given two partitions and a list of partitions, combine these two partitions + and see what is the cost of the modified partition list + """ + p0 = partitions[p0_index] + p1 = partitions[p1_index] + """If two partitions' bfs level are less than 2 or two partitions are connected to each other, + then they can be combined + """ + if (abs(p0.bfs_level - p1.bfs_level) <= 1) or (p0 in p1.parents) or p0 in (p1.children): + combine_two_partitions(p0, p1, partitions) + # Check if a circular dependency exists after combining + if check_dependency(partitions[-1]): + return float('inf') + # Check if the modified partition list can be mapped to devices after combination + reset_partition_device(partitions) + found_deivce = get_device_to_partitions_mapping(partitions, self.devices) + if not found_deivce: + return float('inf') + # Calculate the new cost + partition_to_latency_mapping = get_partition_to_latency_mapping(partitions, node_to_latency_mapping) + cost = get_latency_of_partitioned_graph(partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec) + return cost + # If two partition can not be combined, the cost is inf + return float('inf') + + def search_combination( + transfer_rate_bytes_per_sec, + node_to_latency_mapping + ) -> bool: + """Given transfer rate between partitions and each node's latency, + find two partitions to combine so the cost of the partitions can + be reduced. + The algorithm is : + 1. Go through all the partition pairs and see + if any pair of partitions can be combined. + 2. Calculate the cost after the combination. + 3. Select the minimum cost and combine its cooresponding partition pair. + """ + partition_to_latency_mapping = get_partition_to_latency_mapping(self.partitions, node_to_latency_mapping) + cost = get_latency_of_partitioned_graph(self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec) + if len(self.partitions) == 1: + return False + partition_pair: List[int] = [] + for i in range(len(self.partitions) - 1): + for j in range(i + 1, len(self.partitions)): + # Try to combine the partition pair + # and see the new cost after combination + new_cost = try_combining_partitions( + i, + j, + self.partitions[:] + ) + if new_cost <= cost: + partition_pair = [i, j] + cost = new_cost + reorganize_partitions(self.partitions) + # If a partition pair is found, combine them + if len(partition_pair) != 0: + p0 = self.partitions[partition_pair[0]] + p1 = self.partitions[partition_pair[1]] + combine_two_partitions(p0, p1, self.partitions) + get_bfs_level_partition(self.partitions) + reset_partition_device(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + return len(partition_pair) != 0 + + for node in self.graph_module.graph.nodes: + if node.op not in {'placeholder', 'get_attr', 'output'}: + self.create_single_node_partition(node) + # Set up parent partitions and children partitions for each partition + set_parents_and_children(self.partitions) + # Get bfs level for each partition + get_bfs_level_partition(self.partitions) + find_combination = True + while find_combination: + # Search for a pair partition to generate the minimum new cost, + # then combine them + find_combination = search_combination( + transfer_rate_bytes_per_sec, + node_to_latency_mapping + ) + # Make sure all partitions are set up correctly + reorganize_partitions(self.partitions) + # Set up node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def kl_based_partition( + self, + transfer_rate_bytes_per_sec: float, + node_to_latency_mapping: Dict[Node, NodeLatency] + ) -> None: + """This function is a cost aware partition based + on Kernighan-Lin algorithm. + First, the graph is partitioned using size_based_partition. + Then, each node is swapped with any other node in a different + partition, and at the same time, the cost is estimated after + the swapping. + For example, we have nodes n0, n1, n2, n3 and n4. + Using size_based_partition, n0 and n1 are in Partition p0. + n2, n3 and n4 in Partition p1. The current cost is esimated. + We first tried using n0 to swap with n2 from the other partiton. + Then we see that swapping n0 and n2 shows a lower cost + than the current cost and it is the minimum among other pairs like + (n0, None)(This means moving n0 to Partition without swapping other nodes), + (n0, n3) and (n0, n4). We swap n0 and n2 and set the new cost + as the current cost. + Then We repeat this process for all the other nodes until all swapping pairs + are tried. + """ + def swap_nodes(n0, n1, p0, p1): + # Either n0 or n1 could be None + # That means we simply move the node + # to another partition + if n0 is not None: + p0.remove_node(n0) + p1.add_node(n0) + if n1 is not None: + p0.add_node(n1) + p1.remove_node(n1) + + def try_swap_nodes(n0, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec): + cost = float('inf') + swap_nodes(n0, n1, p0, p1) + # Reorganize partitions after swapping + reorganize_partitions(self.partitions) + # Check if there is a circular dependency after swapping + if (not check_dependency(p0)) and (not check_dependency(p1)): + reset_partition_device(self.partitions) + partition_to_latency_mapping = get_partition_to_latency_mapping( + self.partitions, + node_to_latency_mapping + ) + # Check if all partitions can be mapped to logical devices after swapping + found_device = get_device_to_partitions_mapping(self.partitions, self.devices) + if not found_device: + cost = float('inf') + else: + cost = get_latency_of_partitioned_graph( + self.partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec + ) + # Swap back and reset all partitions back to original + swap_nodes(n1, n0, p0, p1) + reorganize_partitions(self.partitions) + reset_partition_device(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + return cost + + def swap_node_to_partition(node, p0, p1, node_to_latency_mapping, transfer_rate_per_sec): + """This function helps to swap one node from partition p0 + with all the nodes in another partition p1 + """ + p1_nodes = list(p1.nodes) + [None] + min_cost = float('inf') + node_pair: List[Node] = [] + for n1 in p1_nodes: + # Ignore the node if it is not a op node + if n1 is not None and n1.op in {'placeholder', 'get_attr'}: + continue + # Try swapping node in p0 with n1 in p1 + cost = try_swap_nodes(node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec) + if cost < min_cost: + node_pair = [node, n1] + min_cost = cost + return cost, node_pair + + # First use size_base_partition + self.size_based_partition() + partition_to_latency_mapping = get_partition_to_latency_mapping( + self.partitions, + node_to_latency_mapping + ) + # Calculate the cost of the partitions + cost = get_latency_of_partitioned_graph( + self.partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec + ) + # Keep tracking the node pair that shows the better cost + node_pair: List[Node] = [] + # Keep tracking the partition pair of node pair + partition_pair: List[Partition] = [] + # Collect all the op nodes from the graph + op_nodes = [] + for n in self.graph_module.graph.nodes: + if n.op not in {'placeholder', 'get_attr', 'output'}: + op_nodes.append(n) + for node in op_nodes: + # Find which partition the current node belongs + p0_index = self.node_to_partition[node] + p0 = self.partitions[p0_index] + # Go through all the other partitions to swap + # with other nodes from those partitions + for p1_index, _ in enumerate(self.partitions): + if p0_index != p1_index: + p1 = self.partitions[p1_index] + new_cost, new_node_pair = swap_node_to_partition( + node, + p0, + p1, + node_to_latency_mapping, + transfer_rate_bytes_per_sec + ) + # Update the cost + # Track the swapped node pair and their partitions + if new_cost < cost: + cost = new_cost + node_pair = new_node_pair + partition_pair = [p0, p1] + # Do the swapping after trying all the nodes from a partition + if len(node_pair) != 0: + swap_nodes(node_pair[0], node_pair[1], partition_pair[0], partition_pair[1]) + reorganize_partitions(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + reorganize_partitions(self.partitions) + # Mapping the device to the partition + get_device_to_partitions_mapping(self.partitions, self.devices) + return + + def aot_based_partition(self, node_to_partition_mapping, partition_to_logical_device_mapping): + """This function helps to rebuild the partitions given the nodes and its + corresponding partition id + """ + partition_id_to_partition_mapping: Dict[int, Partition] = {} + self.node_to_partition = node_to_partition_mapping + for node in self.node_to_partition: + partition_id = self.node_to_partition[node] + # If the requested partition has not been created, create the partition + if partition_id not in partition_id_to_partition_mapping: + partition = Partition(partition_id) + self.partitions.append(partition) + partition_id_to_partition_mapping[partition_id] = partition + partition.logical_device_ids = partition_to_logical_device_mapping[partition_id] + else: + partition = partition_id_to_partition_mapping[self.node_to_partition[node]] + # Add the current node into the partition + partition.add_node(node) diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py new file mode 100644 index 0000000000000..4b5ea52c5c1fe --- /dev/null +++ b/torch/fx/experimental/const_fold.py @@ -0,0 +1,269 @@ +import operator +from typing import Dict, Set, List, Optional + +import torch.fx +from torch.fx.experimental.subgraph_creation_example import split_module +import re + + +def _make_tuple(x): + """ + Helper to convert x into a one item tuple if it's not a tuple already. + """ + return x if isinstance(x, tuple) else (x,) + + +class FoldedGraphModule(torch.fx.GraphModule): + """ + FoldedGraphModule is a GraphModule which also contains another + `const_subgraph_module` representing a subgraph which has all const attr + inputs and which can be run once before running the main standard + `graph`. The `const_output_names` are the ordered list names of attrs which + represent what each respective output from the const_subgraph should be set + on which attrs. + """ + + def __init__( + self, + root: torch.nn.Module, + graph: torch.fx.Graph, + const_subgraph: Optional[torch.fx.Graph] = None, + const_output_names: Optional[List[str]] = None, + ): + super().__init__(root, graph) + self.const_subgraph_module = ( + None + if const_subgraph is None + else torch.fx.GraphModule(root, const_subgraph) + ) + self.const_output_names = const_output_names + self.has_folding_been_run = False + + def __call__(self, *args, **kwargs): + if not self.has_folding_been_run: + self.run_folding() + return super().__call__(*args) + + def run_folding(self): + # If there's no const subgraph module or attr output names to use, return + # early as there is no const folding to perform. + if self.const_subgraph_module is None or self.const_output_names is None: + return + + assert not self.has_folding_been_run + self.has_folding_been_run = True + + # Actually run const folding subgraph. We _make_tuple here because + # single attr const fold subgraphs output a single Tensor while + # multiple outputs are returned as Tuple[Tensor,]. + folded_attrs = _make_tuple(self.const_subgraph_module()) + + # Look for output node from const folding subgraph and set attrs on the + # module with the results. + for i in range(len(folded_attrs)): + setattr( + self, self.const_output_names[i], torch.nn.Parameter(folded_attrs[i]) + ) + + +def split_const_subgraphs( + module: torch.nn.Module, +) -> FoldedGraphModule: + """ + Looks through `module` for any nodes that have all constant attribute inputs + and separates them out into their own constant subgraph, and returns a + FoldedGraphModule which runs that constant subgraph on the first run to set + attributes on the module prior to running the non-constant portion of the + graph. + """ + mod_traced = torch.fx.symbolic_trace(module) + + # Build up a list of const_nodes, defined as nodes that are themselves + # get_attrs, or have all get_attr or other constant node inputs. + const_nodes: Set[torch.fx.Node] = set() + found_const_folding = False + for node in mod_traced.graph.nodes: + # Skip over placeholders/outputs because they can't be const folded and + # we don't want to add tags to them. + if node.op in {"placeholder", "output"}: + continue + + # If the node itself is constant, or all of its inputs are constant, + # then tag it as constant. + if node.op == "get_attr" or set(node.all_input_nodes).issubset(const_nodes): + const_nodes.add(node) + if node.op != "get_attr": + found_const_folding = True + + # If we did not find any const folding then return early without a const fold subgraph. + if not found_const_folding: + return FoldedGraphModule(mod_traced, mod_traced.graph) + + # Partition the module into two: submod_0 for constant folding subgraph, and + # submod_1 for the rest. + def mod_partition(node: torch.fx.Node): + return 0 if node in const_nodes else 1 + + split = split_module(mod_traced, module, mod_partition) + + # Gather all names that are output from the const folding subgraph, which we + # will need to set dummy params on the module. + const_output_names: List[str] = [] + for node in split.submod_0.graph.nodes: + if node.op == "output": + # Note: we _make_tuple here because the output Node either contains + # a single output Node, or Tuple[Node], so this simplifies things. + const_output_names = [o.name for o in _make_tuple(node.args[0])] + break + + # Make sure the attr name we want to use is uniquely named in the module. + for i in range(len(const_output_names)): + # Add a suffix to make it easier to tell these were the result of const folding. + name = const_output_names[i] + "__CF" + # Delete all characters that are illegal in a Python identifier. + name = re.sub("[^0-9a-zA-Z_]+", "_", name) + if name[0].isdigit(): + name = f"_{name}" + # Now make sure it is in fact unique to the module by incrementing suffix value. + while hasattr(mod_traced, name): + match = re.match(r"(.*)_(\d+)$", name) + if match is None: + name = name + "_1" + else: + base, num = match.group(1, 2) + name = f"{base}_{int(num) + 1}" + const_output_names[i] = name + + # Now track the const_output_names to what name is used in the parent graph + # from the split via call_function getitem, to see what order it is passed + # into the non-const subgraph submod_1. First look to the parent module + # containing/calling into the const/non-const submodules to determine what + # the inputs are to each. Note if submod_0 had a single output then there is + # no getitem, and we can simply use the output from the call to submoid_0. + call_submod_0_args, call_submod_1_args = None, None + orig_ph_targets: List[str] = [] + for node in split.graph.nodes: + if node.op == "placeholder": + orig_ph_targets.append(node.target) + + if node.op == "call_module": + if node.target == "submod_0": + call_submod_0_args = node.args + continue + elif node.target == "submod_1": + call_submod_1_args = node.args + continue + assert call_submod_0_args is not None and call_submod_1_args is not None + + # Look through the args for the call into submod_1, and find the args that + # come from submod_0. Also look for get_attrs fed directly from the parent + # split into submod_1, i.e. those attrs that are not constant folded. + submod_1_input_idx_to_folded_attr_name: Dict[int, str] = {} + submod_1_input_idx_to_unfolded_attr_name: Dict[int, str] = {} + for i, node in enumerate(call_submod_1_args): + const_output_name = None + # If we only had a single output from submod_0 then we simply look for + # the call_module into it. + if len(const_output_names) == 1: + if node.op == "call_module" and node.target == "submod_0": + const_output_name = const_output_names[0] + + # Else we had multiple outputs from submod_0, so we need to look for all + # getitems from the call to it. + else: + if ( + node.op == "call_function" + and node.target == operator.__getitem__ + and node.args[0].target == "submod_0" + ): + const_output_name = const_output_names[node.args[1]] + + # Now map from the index of the constant into calling submod_1 and map + # to the constant output name, which we use for swapping in getattrs + # instead of placeholders in submod_1. + if const_output_name is not None: + submod_1_input_idx_to_folded_attr_name[i] = const_output_name + elif node.op == "get_attr": + submod_1_input_idx_to_unfolded_attr_name[i] = node.target + + assert len(submod_1_input_idx_to_folded_attr_name) == len(const_output_names) + + # Now we have a mapping from const output names to the index they are passed + # into submod_1, so swap in getattrs for placeholders. + ph_idx = 0 + for node in split.submod_1.graph.nodes: + if node.op != "placeholder": + continue + is_folded_attr = ph_idx in submod_1_input_idx_to_folded_attr_name.keys() + is_unfolded_attr = ph_idx in submod_1_input_idx_to_unfolded_attr_name.keys() + if not is_folded_attr and not is_unfolded_attr: + ph_idx += 1 + continue + + const_output_name = ( + submod_1_input_idx_to_folded_attr_name[ph_idx] + if is_folded_attr + else submod_1_input_idx_to_unfolded_attr_name[ph_idx] + ) + if is_folded_attr: + assert not hasattr(mod_traced, const_output_name) + # Use a dummy param, which will be overwritten when we run const folding. + setattr( + mod_traced, + const_output_name, + torch.nn.Parameter(torch.randn(1)), + ) + with split.submod_1.graph.inserting_before(node): + node.replace_all_uses_with(split.submod_1.graph.get_attr(const_output_name)) + split.submod_1.graph.erase_node(node) + ph_idx += 1 + + # We may need to reorder placeholders to ensure they have the same order as + # they do in the original split. + ph_idx = 0 + node = next(iter(split.submod_1.graph.nodes)) + while node.op != "root": + if node.op != "placeholder": + node = node.next + continue + + curr_orig_ph_target = orig_ph_targets[ph_idx] + ph_idx += 1 + # If this ph is in the correct position, nothing to do. + if curr_orig_ph_target == node.target: + node = node.next + continue + + # This ph is not in the correct order, so search the rest of the graph + # for the ph we expected and prepend it before the current ph. + later_node = node.next + while later_node.op != "root": + if ( + later_node.op == "placeholder" + and curr_orig_ph_target == later_node.target + ): + break + later_node = later_node.next + assert later_node.op != "root" + node.prepend(later_node) + # Note we do not increment node here, as it still may be in the wrong + # place (we just prepended the ph that should have come before it). + + # split_module currently does not use get_attrs for attrs. Instead it passes + # them in as args from the parent module, which used get_attrs. Here we set + # them as get_attrs inside submod_0, allowing for running folding without + # somehow a priori knowing the attrs that should be passed as args. We can + # unconditionally do this for all placeholders because we know all + # placeholders to submod_0 must be constants accessible via get_attr. + for node in split.submod_0.graph.nodes: + if node.op != "placeholder": + continue + in_node = next(n for n in call_submod_0_args if n.name == node.target) + assert in_node.op == "get_attr" + with split.submod_0.graph.inserting_before(node): + node.replace_all_uses_with(split.submod_0.graph.get_attr(in_node.target)) + split.submod_0.graph.erase_node(node) + + return FoldedGraphModule( + mod_traced, split.submod_1.graph, split.submod_0.graph, const_output_names + ) diff --git a/torch/fx/experimental/fuser.py b/torch/fx/experimental/fuser.py new file mode 100644 index 0000000000000..7e75f1c32ffc2 --- /dev/null +++ b/torch/fx/experimental/fuser.py @@ -0,0 +1,64 @@ +import torch.fx as fx +from torch.nn.utils.fusion import fuse_conv_bn_eval +from typing import Type, Dict, Any, Tuple, Iterable +import torch +import copy + +def _parent_name(target : str) -> Tuple[str, str]: + """ + Splits a qualname into parent path and last atom. + For example, `foo.bar.baz` -> (`foo.bar`, `baz`) + """ + *parent, name = target.rsplit('.', 1) + return parent[0] if parent else '', name + +# Works for length 2 patterns with 2 modules +def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]): + if len(node.args) == 0: + return False + nodes: Tuple[Any, fx.Node] = (node.args[0], node) + for expected_type, current_node in zip(pattern, nodes): + if not isinstance(current_node, fx.Node): + return False + if current_node.op != 'call_module': + return False + if not isinstance(current_node.target, str): + return False + if current_node.target not in modules: + return False + if type(modules[current_node.target]) is not expected_type: + return False + return True + + +def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module): + assert(isinstance(node.target, str)) + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, new_module) + +def fuse(model: torch.nn.Module, inplace=True) -> torch.nn.Module: + """ + Fuses convolution/BN layers for inference purposes. Will deepcopy your + model by default, but can modify the model inplace as well. + """ + patterns = [(torch.nn.Conv1d, torch.nn.BatchNorm1d), + (torch.nn.Conv2d, torch.nn.BatchNorm2d), + (torch.nn.Conv3d, torch.nn.BatchNorm3d)] + if not inplace: + model = copy.deepcopy(model) + fx_model = fx.symbolic_trace(model) + modules = dict(fx_model.named_modules()) + new_graph = copy.deepcopy(fx_model.graph) + + for pattern in patterns: + for node in new_graph.nodes: + if matches_module_pattern(pattern, node, modules): + if len(node.args[0].users) > 1: # Output of conv is used by other nodes + continue + conv = modules[node.args[0].target] + bn = modules[node.target] + fused_conv = fuse_conv_bn_eval(conv, bn) + replace_node_module(node.args[0], modules, fused_conv) + node.replace_all_uses_with(node.args[0]) + new_graph.erase_node(node) + return fx.GraphModule(fx_model, new_graph) diff --git a/torch/fx/experimental/graph_manipulation.py b/torch/fx/experimental/graph_manipulation.py new file mode 100644 index 0000000000000..4e6c23cbad9f9 --- /dev/null +++ b/torch/fx/experimental/graph_manipulation.py @@ -0,0 +1,258 @@ +import json +from typing import Dict, List, NamedTuple, Any + +import torch +from torch.fx.experimental.shape_prop import ShapeProp +from torch.fx.experimental.param_fetch import lift_lowering_attrs_to_nodes +from torch.fx.graph import Graph, get_qualified_name +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node, Target, map_arg + + +def replace_target_nodes_with( + fx_module: GraphModule, + old_op: str, + old_target: Target, + new_op: str, + new_target: Target, +): + """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target, + and updates them to match the new op code and target""" + new_graph = Graph() + val_map: Dict[Node, Node] = {} + for node in fx_module.graph.nodes: + if node.op == old_op and node.target == old_target: + args = map_arg(node.args, lambda n: val_map[n]) + kwargs = map_arg(node.kwargs, lambda n: val_map[n]) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + val_map[node] = new_graph.create_node( + new_op, new_target, args, kwargs, node.name + ) + else: + val_map[node] = new_graph.node_copy(node, lambda n: val_map[n]) + fx_module.graph = new_graph + + +class size_bytes(NamedTuple): + output_size: int + total_size: int + + +def get_size_of_all_nodes(fx_module: GraphModule, args: List[torch.Tensor]) -> None: + """Given a fx graph module, update each node with its total size (weights + bias + output) + and its output_size(output). For a non-module node, the total size is the output size. + return total size""" + # Mark shape and dtype for each node (node.shape and node.dtype) + ShapeProp(fx_module).propagate(*args) + # Calculate the total size of the whole fx graph + total_size_of_graph = 0.0 + for node in fx_module.graph.nodes: + if node.op == "output": + break + node.size_bytes = get_size_of_node(fx_module, node) + return + + +def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes: + """Given a node with node.dtype and node.shape, return its total size and its output size. + total_size = weights + bias + output_size + """ + # Total num of elements + total_num_of_elems = 0 + # For a module, conside all parameters + if node.op == "call_module": + submodule_dict = dict(fx_module.named_modules()) + submodule = submodule_dict[node.target] + parameters = submodule.named_parameters() + # Parameters are named tuples + for name, p in parameters: + total_num_of_elems += p.numel() + # Don't forget the output size + # node.shape is the shape of this node's output + shape = getattr(node, "shape", None) + if shape: + output_elem = shape.numel() + else: + raise RuntimeError("Node has no shape attr") + total_num_of_elems += output_elem + size_per_elem_bytes = 0 + dtype = getattr(node, "dtype", None) + if dtype: + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + else: + raise RuntimeError("Node has no dtype attr") + total_size = size_per_elem_bytes * total_num_of_elems + output_size = size_per_elem_bytes * output_elem + return size_bytes(output_size, total_size) + + +def serialize_shape(shape: torch.Size) -> str: + return str(list(shape)) + + +def serialize_tensor_quantization(tensor: torch.Tensor) -> Dict[str, Any]: + scheme: Dict[str, Any] = {} + if tensor.is_quantized: + scheme["q_scheme"] = str(tensor.qscheme()) + if tensor.qscheme() in {torch.per_tensor_affine, torch.per_tensor_symmetric}: + scheme["q_scale"] = tensor.q_scale() + scheme["q_zero_pont"] = tensor.q_zero_point() + if tensor.qscheme() in { + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + torch.per_channel_symmetric, + }: + scheme["q_per_channel_scales"] = tensor.q_per_channel_scales().tolist() + scheme[ + "q_per_channel_zero_points" + ] = tensor.q_per_channel_zero_points().tolist() + scheme["q_per_channel_axis"] = tensor.q_per_channel_axis() + + return scheme + + +def serialize_weight(tensor: torch.Tensor) -> Dict: + weight: Dict[str, Any] = {} + weight["dtype"] = str(tensor.dtype) + weight["is_quantized"] = tensor.is_quantized + if tensor.is_quantized: + weight["quantized_type"] = serialize_tensor_quantization(tensor) + weight["shape"] = serialize_shape(tensor.shape) + return weight + + +def serialize_leaf_module( + node: Node, weights_metadata: Dict, weights: Dict, name_prefix: str +) -> Dict: + parameters: Dict[str, Any] = {} + + for p_name, p_value in node.attrs_for_lowering.items(): # type: ignore + if isinstance(p_value, torch.Tensor): + weights_metadata[f"{name_prefix}.{p_name}"] = serialize_weight(p_value) + weights[f"{name_prefix}.{p_name}"] = p_value + else: + parameters[p_name] = str(p_value) + + return parameters + + +def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict: + """Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON. + It also adds all weights the provided weights dictionary by qualified_name. + Dictionary Schema: + MODULE + { + modules: {module_name: MODULE], + nodes: [NODE], + weights {qualified_name: WEIGHT}, + } + NODE + { + shape: [], + dtype: dtype, + target: target, + op_code: op_code, + name: name, + args: [], + kwargs: {} + } + WEIGHT + { + dtype: dtype, + is_quantized: bool, + shape: [], + quantization_info: QUANTIZATION + } + QUANTIZATION + { + qscheme: qscheme, + q_scale: float, + q_zero_point: float, + q_per_channel_scales, [], + q_per_channel_zero_points: [], + q_per_channel_axis, int + } + """ + serialized_dict: Dict[str, Any] = {} + serialized_dict["modules"] = {} + serialized_dict["weights"] = {} + serialized_dict["nodes"] = [] + parameters = fx_module.named_parameters() + prefix = f"{name_prefix}." if name_prefix else "" + submodules = dict(fx_module.named_modules()) + for name, p in parameters: + if isinstance(p, torch.Tensor): + weight = serialize_weight(p) + serialized_dict["weights"][prefix + name] = weight + weights[prefix + name] = p + lift_lowering_attrs_to_nodes(fx_module) + for node in fx_module.graph.nodes: + node_rep: Dict[str, Any] = {} + # Get shape/type info, currently not needed for call_module. + if node.op != "call_module" or not isinstance( + submodules[node.target], GraphModule + ): + shape = getattr(node, "shape", None) + if shape: + node_rep["shape"] = serialize_shape(shape) + else: + raise RuntimeError( + "Node has no shape attr, this is likely because shape propagation has not been run on this Graph." + ) + dtype = getattr(node, "dtype", None) + if dtype: + node_rep["dtype"] = str(dtype) + else: + raise RuntimeError( + "Node has no dtype attr, this is likely because shape propagation has not been run on this Graph." + ) + + # Recurse down into any submodules we are calling. + if node.op == "call_module": + if isinstance(submodules[node.target], GraphModule): + serialized_module = serialize_module( + getattr(fx_module, node.target), weights, node.target + ) + serialized_dict["modules"][node.target] = serialized_module + else: + node_rep["parameters"] = serialize_leaf_module( + node, + serialized_dict["weights"], + weights, + prefix + node.target, + ) + + if node.op == "call_function": + node_rep["target"] = get_qualified_name(node.target) + else: + node_rep["target"] = str(node.target) + + # Make sure we capture all constants. + if node.op == "get_attr": + target = getattr(fx_module, node.target) + qualname = prefix + node.target + if isinstance(target, torch.Tensor) and qualname not in weights: + weight = serialize_weight(target) + serialized_dict["weights"][prefix + node.target] = weight + weights[prefix + node.target] = target + + node_rep["op_code"] = node.op + node_rep["name"] = node.name + node_rep["args"] = map_arg( + node.args, lambda arg: {"is_node": True, "name": str(arg)} + ) + node_rep["kwargs"] = map_arg( + node.kwargs, lambda arg: {"is_node": True, "name": str(arg)} + ) + serialized_dict["nodes"] += [node_rep] + + return serialized_dict + + +class AcceleratedGraphModule: + def __init__(self, fx_module: GraphModule): + """Creates the needed data structures to pass to the glow runtime""" + self.weights: Dict[str, Any] = {} + self.serialized_graph = serialize_module(fx_module, self.weights) + self.serialized_graph_json = json.dumps(self.serialized_graph, indent=4) diff --git a/torch/fx/experimental/merge_matmul.py b/torch/fx/experimental/merge_matmul.py new file mode 100644 index 0000000000000..b72bbe633dd98 --- /dev/null +++ b/torch/fx/experimental/merge_matmul.py @@ -0,0 +1,220 @@ +import torch + +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node +from torch.fx.symbolic_trace import symbolic_trace + +import itertools +import operator + +from typing import Dict, List + + +def get_first_dim(t: torch.Tensor) -> int: + """ + A free function primarily for use in the merge_matmul graph transformation below + that returns the first dimension of a Tensor. This is necessary because torch.Tensor.shape + is an attribute (and cannot be the target of a call_function node) and also helps save + a getitem op in the graph. + + Arguments: + t: The tensor to get the first dimension of. + + Returns: + The first dimension of t. + """ + return t.shape[0] + + +def legalize_graph(gm: GraphModule): + """ + Replace the graph of the given GraphModule with one that contains the same nodes as the + original, but in topologically sorted order. + + This is used by the merge_matmul transformation below, which disturbs the topologically sorted + order of its input GraphModule, so that this order is restored before further transformation. + + Arguments: + gm: The graph module to topologically sort. It is modified in-place. + + """ + # Build an adjacency list representation of node dependencies in the graph. This also + # serves as a list of nodes that still need to be inserted into the new, topologically + # sorted graph. + dependencies = {node: node.all_input_nodes.copy() for node in gm.graph.nodes} + + # Construct a new graph that will contain all nodes in topologically sorted order. + new_graph = Graph() + value_remap: Dict[Node, Node] = {} + + # Copy over all nodes with no dependencies. + for node, deps in dependencies.items(): + if not deps: + value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n]) + + # Remove the copied over nodes from the adjacency list. + for copied_node in value_remap.keys(): + del dependencies[copied_node] + + # While there are still nodes to insert into the new graph: + while dependencies: + copied_this_round = [] + + # Copy over all nodes whose dependencies already exist in the new graph. + for node, deps in dependencies.items(): + all_deps_copied = True + for dep in deps: + if dep not in value_remap: + all_deps_copied = False + + if all_deps_copied: + value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n]) + copied_this_round.append(node) + + # Delete all nodes copied over in this iteration from dependencies. + for copied_node in copied_this_round: + del dependencies[copied_node] + + # Replace the old graph with the new, topologically sorted one. + gm.graph = new_graph + + +def may_depend_on(a: Node, b: Node, search_depth: int = 6): + """ + Determine if one node depends on another in a torch.fx.Graph. + + Arguments: + a: The node that may have a dependency on b. + b: The node that a may have a dependency on. + search_depth: In the case of an indirect dependency, this function + searches upto this many nodes away in search of a + data dependency. If none is found, the function + makes the conservative assumption that there is a + dependency. + + Returns: + True if a may depend on b, False if it definitely does not. + """ + # Equivalence is defined as dependence. + if a == b: + return True + + # If a has no inputs, it cannot depend on b. + if len(a.all_input_nodes) == 0: + return False + + # If the search depth has been exhausted and no conclusion has been + # reached, assume that there is a data dependency. + if search_depth == 0: + return True + + # Recursively check all inputs of a. + for inp in a.all_input_nodes: + if may_depend_on(inp, b, search_depth - 1): + return True + + return False + + +def are_nodes_independent(nodes: List[Node]): + """ + Check if all of the given nodes are pairwise-data independent. + + Arguments: + nodes: The nodes to check for data dependencies. + + Returns: + True if any pair in nodes has a data dependency. + """ + # For each pair in nodes: + for i, j in itertools.combinations(nodes, 2): + if may_depend_on(i, j) or may_depend_on(j, i): + return False + + return True + + +def merge_matmul(in_mod: torch.nn.Module): + """ + A graph transformation that merges matrix multiplication operations that share the same right-hand + side operand into one large matrix multiplication. + ____ _________ _________ + ---- | | | | M| A * C | + M| A | T| B | * K| C | = |---------| + ---- , | | | | T| B * C | + K ---- --------- --------- + K R R + """ + gm = symbolic_trace(in_mod) + + rhs_users: Dict[Node, List[Node]] = {} + lhs_users: Dict[Node, List[Node]] = {} + + # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to + # the matmul of which they are the LHS/RHS. + for node in gm.graph.nodes: + if node.op != "call_function" or node.target is not torch.matmul: + continue + + lhs, rhs = node.args + + # TODO: Properly handle aliasing caused by get_attr. For now, + # use the attribute name as the operand if the node is a + # get_attr. + lhs = lhs.target if lhs.op == "get_attr" else lhs + rhs = rhs.target if rhs.op == "get_attr" else rhs + + lhs_users.setdefault(lhs, []).append(node) + rhs_users.setdefault(rhs, []).append(node) + + for rhs, mms in rhs_users.items(): + # There must be at least matmuls for a merge to make sense. + if len(mms) < 2: + continue + + # All matmuls must not depend on each other directly or indirectly + # in order for the merge to be possible. + if not are_nodes_independent(mms): + continue + + lhs_vals = [mm.args[0] for mm in mms] + + # Merge the matmul. + # Collect a list of LHS operands and the single RHS operand. + lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals] + rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs + + # Concatenate all the LHS operands. + merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {}) + + # Multiply the concatenated LHS operands with the one RHS. This will produce + # the same results as all the individual matmuls involving rhs in the original graph, + # but they will all be concatenated together. + merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {}) + + # Split the result of the merged matmul using the shapes of the LHS operands + # to ascertain how large each chunk should be. + merge_mm_sizes = [ + gm.graph.call_function(get_first_dim, (l,), {}) for l in lhs + ] + merge_mm_split = gm.graph.call_function( + torch.split, (merge_mm, merge_mm_sizes), {} + ) + merge_mm_res = [ + gm.graph.call_function(operator.getitem, (merge_mm_split, out), {}) + for out in range(len(lhs)) + ] + + # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul. + for old, new in zip(mms, merge_mm_res): + old.replace_all_uses_with(new) + gm.graph.erase_node(old) + + # All of the new nodes created above were inserted at the end, so we need to sort + # the nodes topologically to make sure all definitions precede uses. + legalize_graph(gm) + + gm.recompile() + gm.graph.lint(in_mod) + return gm diff --git a/torch/fx/experimental/param_fetch.py b/torch/fx/experimental/param_fetch.py new file mode 100644 index 0000000000000..6bce29b97e787 --- /dev/null +++ b/torch/fx/experimental/param_fetch.py @@ -0,0 +1,60 @@ +from torch.fx.graph_module import GraphModule +from typing import Any, Callable, Dict, List, Tuple, Type +import torch +import torch.nn as nn + + +# Matching method matches the attribute name of current version to the attribute name of `target_version` +def default_matching(name: str, target_version: int) -> str: + """Default matching method + """ + return name + +# This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering. +# The first integer in the tuple is the version number of the nn.Module class when we create the parameter list. +# If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module. +module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = { + torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching), + torch.nn.modules.conv.Conv2d: ( + 1, ["weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode"], default_matching + ), + torch.nn.modules.batchnorm.BatchNorm2d: (2, ["weight", "bias", "running_mean", "running_var", "eps"], default_matching), + torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching), + torch.nn.modules.pooling.MaxPool2d: ( + 1, ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], default_matching + ), + torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching), +} + +def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]: + """If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book` + after checking module's version is compatible with the `module_fetch_book`. + """ + attrs_for_lowering: Dict[str, Any] = {} + attrs_for_lowering["name"] = torch.typename(mod) + + if type(mod) in module_fetch_book: + version, param_to_fetch, matching_method = module_fetch_book[type(mod)] + if version < mod._version: + raise RuntimeError(f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, " + "please upgrade the module_fetch_book, open an issue and @842974287 " + "or report a bug to AIACC team directly.") + for attr in param_to_fetch: + attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version)) + else: + raise RuntimeError(f"{torch.typename(mod)} is not in the module_fetch_book yet, " + "please add it to the module_fetch_book, open an issue and @842974287 " + "or report a bug to AIACC team directly.") + return attrs_for_lowering + +def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None: + """Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module. + """ + submodules = dict(fx_module.named_modules()) + + for node in fx_module.graph.nodes: + if node.op == "call_module": + if isinstance(submodules[node.target], GraphModule): + lift_lowering_attrs_to_nodes(submodules[node.target]) + else: + node.attrs_for_lowering = extract_attrs_for_lowering(submodules[node.target]) diff --git a/torch/fx/experimental/partitioner_utils.py b/torch/fx/experimental/partitioner_utils.py new file mode 100644 index 0000000000000..7d54924bda463 --- /dev/null +++ b/torch/fx/experimental/partitioner_utils.py @@ -0,0 +1,255 @@ +from typing import NamedTuple, Dict, List, Set +from torch.fx.node import Node, map_arg +from enum import Enum +class Partition: + """Partition class contains all the information about an individual partition. + It also provides necessary methods for manipulation the partition. + """ + def __init__(self, partition_id: int) -> None: + self.nodes: Set[Node] = set() + self.partition_id = partition_id + self.parents: Set['Partition'] = set() + self.children: Set['Partition'] = set() + self.bfs_level: int = -1 + self.used_mem_bytes: int = 0 + self.logical_device_ids: List[int] = [] + + def __str__(self): + return str(self.partition_id) + + def recalculate_mem_size(self): + self.used_mem_bytes = 0 + for node in self.nodes: + self.used_mem_bytes += get_extra_size_of(node, self.nodes) + + def add_node(self, node): + input_nodes: Dict[Node, None] = {} + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + # Add current node's input nodes if they are placeholder or constants + for n in input_nodes: + if n.op in {'placeholder', 'get_attr'}: + self.nodes.add(n) + self.nodes.add(node) + self.recalculate_mem_size() + + def remove_node(self, node): + # Remove a node only if the node is in the partition + if node in self.nodes: + self.nodes.remove(node) + # Collect the node's input nodes + input_nodes: Dict[Node, None] = {} + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + # Check if an input node is a placeholder or get_attr, + # and this input node is not used by some other nodes in this partition, + # the remove this input node + for input_node in input_nodes: + if all([n not in self.nodes for n in input_node.users]) and input_node.op in {'placeholder', 'get_attr'}: + self.nodes.remove(input_node) + self.recalculate_mem_size() + +class Device(NamedTuple): + name: str + available_mem_bytes: int + logical_id: int + +class NodeLatency(NamedTuple): + # Latency due to the memory bandwidth + mem_latency_sec: float + # Latency due to the computation + computer_latency_sec: float + +class PartitionLatency(NamedTuple): + # Sum of all nodes' memory latency on the critical path + mem_latency_sec: float + # Sum of all nodes' compute latency on the critical path + computer_latency_sec: float + # Latency of the critical path + overall_latency_sec: float + +class PartitionMode(Enum): + size_based = 0 + sparse_nn = 1 + cost_aware = 2 + kl_based = 3 + aot_based = 4 + +class PartitionerConfig(NamedTuple): + devices: List[Device] + mode: PartitionMode = PartitionMode.size_based + transfer_rate_bytes_per_sec: float = 0. + node_to_latency_mapping: Dict[Node, NodeLatency] = {} + node_to_partition_mapping: Dict[Node, int] = {} + partition_to_logical_device_mapping: Dict[int, List[int]] = {} + +def get_extra_size_of(node: Node, nodes: Set[Node]) -> int: + """Given a node and a set of nodes, + this function return the extra size that needed + if this node is included in this set. + """ + # Find all its input nodes + input_nodes: Dict[Node, None] = {} + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + # Calculate total size of related nodes + total_size_of_input_nodes = 0 + for n in input_nodes: + # Make sure this node hasn't been in this set yet + if n not in nodes: + size_bytes = getattr(n, 'size_bytes', None) + if size_bytes: + total_size_of_input_nodes += size_bytes.output_size + else: + raise RuntimeError('node has no size_bytes attr') + # Don't forget the op node itself + size_bytes = getattr(node, 'size_bytes', None) + if size_bytes: + total_size_of_input_nodes += size_bytes.total_size + else: + raise RuntimeError('node has no size_bytes attr') + return total_size_of_input_nodes + +def get_latency_of_one_partition( + partition: Partition, + node_to_latency_mapping: Dict[Node, NodeLatency] +) -> PartitionLatency: + """Given a partiton and its nodes' latency, return a PartitionLatency for this partition""" + + def get_top_nodes(partition: Partition) -> List[Node]: + """Given a partition, return a list of nodes on the top bfs level""" + top_nodes: List[Node] = [] + for node in partition.nodes: + # Skip placeholder and get_attr nodes + if node.op in {'placeholder', 'get_attr'}: + continue + input_nodes: Dict[Node, None] = {} + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + # If a node has no input nodes in this partition, + # or its input nodes in this partition are placeholders and get_attrs + # this node is on the top bfs level in this partition + if not any([n in partition.nodes and n.op not in {'placeholder', 'get_attr'} for n in input_nodes]): + top_nodes.append(node) + return top_nodes + + def dfs_helper(node: Node, partition_latency) -> PartitionLatency: + """Given a top node of a partition, this function returns + the latency of the critical path in the partition + """ + node_latency = node_to_latency_mapping[node] + # Calculate the current overall latency of the partition + overall_latency_sec = partition_latency.overall_latency_sec + \ + max(node_latency.computer_latency_sec, node_latency.mem_latency_sec) + # Update the mem latency of this path + mem_latency_sec = partition_latency.mem_latency_sec + node_latency.mem_latency_sec + # Update the compute latency of this path + computer_latency_sec = partition_latency.computer_latency_sec + node_latency.computer_latency_sec + # Get all users of this node that are in this partition + users = set(node.users).intersection(partition.nodes) + if users: + max_latency = PartitionLatency(mem_latency_sec=0., computer_latency_sec=0., overall_latency_sec=0.) + for n in users: + # Get new partition latency recursively + new_partition_latency = dfs_helper(n, PartitionLatency(mem_latency_sec, computer_latency_sec, overall_latency_sec)) + if new_partition_latency.overall_latency_sec > max_latency.overall_latency_sec: + max_latency = new_partition_latency + return max_latency + # If there is no user, the node is at bottom of the partition + return PartitionLatency(mem_latency_sec, computer_latency_sec, overall_latency_sec) + # Main part starts + # Get all top level nodes of this partition + top_nodes = get_top_nodes(partition) + critical_path_latency = PartitionLatency(mem_latency_sec=0., computer_latency_sec=0., overall_latency_sec=0.) + # Go through all top nodes and find the largest latency (critical pass latency) + for node in top_nodes: + partition_latency = dfs_helper(node, PartitionLatency(mem_latency_sec=0., computer_latency_sec=0., overall_latency_sec=0.)) + if partition_latency.overall_latency_sec > critical_path_latency.overall_latency_sec: + critical_path_latency = partition_latency + return critical_path_latency + +def get_partition_to_latency_mapping( + partitions: List[Partition], + node_to_latency_mapping: Dict[Node, NodeLatency] +) -> Dict[Partition, PartitionLatency]: + """Given all the partitions and node_to_latency_mapping dictionary, + return a mapping dictionary of each partition to its overall latency + """ + partition_to_latency_mapping: Dict[Partition, PartitionLatency] = {} + # Go through each partition and get its latency + for partition in partitions: + partition_latency = get_latency_of_one_partition(partition, node_to_latency_mapping) + partition_to_latency_mapping[partition] = partition_latency + return partition_to_latency_mapping + +def get_comm_latency_between(parent_partition: Partition, child_partition: Partition, transfer_rate_bytes_per_sec: float): + """Given two partitions (parent and child), + calculate the communication latency between the two. + """ + # If two partitions are on the same device, the comm latency is 0. + if parent_partition.logical_device_ids != [] and child_partition.logical_device_ids != [] \ + and parent_partition.logical_device_ids == child_partition.logical_device_ids: + return 0. + # Keep tracking the communication size between parent and child + comm_size = 0 + # Keep tracking all the counted node + visited_nodes = set() + # Go through all nodes in the child partition + # If a node has input nodes from the parent partition, + # the output size of those input nodes will be counted + # and added to comm_size + for node in child_partition.nodes: + input_nodes: Dict[Node, None] = {} + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + for n in input_nodes: + if n in parent_partition.nodes and n not in visited_nodes: + size_bytes = getattr(n, "size_bytes", None) + if size_bytes is not None: + comm_size += size_bytes.output_size + visited_nodes.add(n) + return comm_size / transfer_rate_bytes_per_sec + +def get_latency_of_partitioned_graph( + partitions: List[Partition], + partition_to_latency_mapping: Dict[Partition, PartitionLatency], + transfer_rate_bytes_per_sec: float +): + """Given all paritions in a graph, find the critical path among all partitions + and return its latency as the latency of the whole graph + """ + def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float: + """This function helps to recursively get the latency of a path of partitions + """ + # Update latency by adding current partition's latency + latency_so_far_sec += partition_to_latency_mapping[partition].overall_latency_sec + children = partition.children + if partition.children: + max_latency_sec = 0. + for child in partition.children: + # Calculate latency between + comm_latency_sec = get_comm_latency_between(partition, child, transfer_rate_bytes_per_sec) + new_latency_sec = dfs_helper(child, latency_so_far_sec + comm_latency_sec) + if new_latency_sec > max_latency_sec: + max_latency_sec = new_latency_sec + return max_latency_sec + return latency_so_far_sec + + def get_top_partitions(partitions: List[Partition]) -> List[Partition]: + """This function is to return all the partitions without parents + as the starting points of all the paths + """ + top_partitions = [] + for partition in partitions: + # If a partition has no parents, then it is a top partition + if len(partition.parents) == 0: + top_partitions.append(partition) + return top_partitions + + top_partitions = get_top_partitions(partitions) + critical_path_latency_sec = 0. + for partition in top_partitions: + latency_sec = dfs_helper(partition, 0.) + if latency_sec > critical_path_latency_sec: + critical_path_latency_sec = latency_sec + return critical_path_latency_sec diff --git a/torch/fx/experimental/rewriter.py b/torch/fx/experimental/rewriter.py new file mode 100644 index 0000000000000..93c10d2687d16 --- /dev/null +++ b/torch/fx/experimental/rewriter.py @@ -0,0 +1,90 @@ +import ast +import inspect +import textwrap +import copy +from types import FunctionType +from typing import cast, Union, Callable +from torch.fx.symbolic_trace import Tracer +from torch.fx.graph import Graph +from torch.jit.frontend import normalize_source_lines +import torch + +class AST_Rewriter(ast.NodeTransformer): + """ + Take a FunctionType object representing a `forward` method, then + perform an AST rewrite to swap out nodes that are not symbolically + traceable with a callsite to the FX alternative. + + To support swapping out an AST node, define a new `visit` method on + that node. For more details, see: + https://docs.python.org/3/library/ast.html#ast.NodeTransformer + """ + + def rewrite(self, fn: FunctionType): + + # Normalize the source lines + sourcelines, _ = inspect.getsourcelines(fn) + sourcelines = normalize_source_lines(sourcelines) + source = ''.join(sourcelines) + normalized_str = textwrap.dedent(source) + + # Rewrite the original AST + source_ast = ast.parse(normalized_str) + dest_ast = ast.fix_missing_locations(self.visit(source_ast)) + + # Pull out the compiled fucntion from the newly-created Module + code = compile(dest_ast, "", "exec") + globals_dict = copy.copy(fn.__globals__) + keys_before = set(globals_dict.keys()) + exec(code, globals_dict) + new_keys = list(set(globals_dict.keys()) - keys_before) + assert len(new_keys) == 1 + fn_compiled = globals_dict[new_keys[0]] + + # Return the correct FunctionType object + return fn_compiled + + def visit_Assert(self, node): + """ + Swap out the Assert node (Python's `assert`) with a callsite to the + symbolically-traceable torch._assert function + """ + # Create the Call node + n = ast.parse('torch._assert()', mode='eval') + assert isinstance(n, ast.Expression) + call_node = n.body + assert isinstance(call_node, ast.Call) + msg = node.msg if node.msg else ast.Constant(value="", kind=None) + call_node.args = [node.test, msg] + + # Ensure that the new node conforms to the Python AST grammar + expr_wrapper = ast.Expr(value=call_node) + + # Return the new Call node to signify that we want to use it as + # a replacement for the original _assert node + return ast.copy_location(expr_wrapper, node) + + +class RewritingTracer(Tracer): + def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph: + return super().trace(_rewrite(root)) + + +def _rewrite(fn : Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]: + if isinstance(fn, torch.nn.Module): + # Rewrite this module's forward() and all of its recursive children's + # forward. Return the new rewritten module hierarchy. + def rewrite_module(m : torch.nn.Module): + class RewrittenModule(torch.nn.Module): + def __init__(self, orig): + super().__init__() + self.__dict__ = copy.copy(orig.__dict__) + RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward)) + new_m = RewrittenModule(m) + for name, child in new_m.named_children(): + new_m[name] = rewrite_module(child) # type: ignore + return new_m + return rewrite_module(fn) + else: + # Rewrite this single free function + return AST_Rewriter().rewrite(cast(FunctionType, fn)) diff --git a/torch/fx/experimental/shape_prop.py b/torch/fx/experimental/shape_prop.py new file mode 100644 index 0000000000000..52264796c7d4c --- /dev/null +++ b/torch/fx/experimental/shape_prop.py @@ -0,0 +1,51 @@ +import torch +import torch.fx +from torch.fx.node import Node + +from typing import Dict + +class ShapeProp: + def __init__(self, mod): + self.mod = mod + self.graph = mod.graph + self.modules = dict(self.mod.named_modules()) + + def propagate(self, *args): + args_iter = iter(args) + env : Dict[str, Node] = {} + + def load_arg(a): + return torch.fx.node.map_arg(a, lambda n: env[n.name]) + + def fetch_attr(target : str): + target_atoms = target.split('.') + attr_itr = self.mod + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}") + attr_itr = getattr(attr_itr, atom) + return attr_itr + + for node in self.graph.nodes: + if node.op == 'placeholder': + result = next(args_iter) + elif node.op == 'get_attr': + result = fetch_attr(node.target) + elif node.op == 'call_function': + result = node.target(*load_arg(node.args), **load_arg(node.kwargs)) + elif node.op == 'call_method': + self_obj, *args = load_arg(node.args) + kwargs = load_arg(node.kwargs) + result = getattr(self_obj, node.target)(*args, **kwargs) + elif node.op == 'call_module': + result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs)) + elif node.op == 'output': + return load_arg(node.args[0]) + + if isinstance(result, torch.Tensor): + node.shape = result.shape + node.dtype = result.dtype + + env[node.name] = result + + return None diff --git a/torch/fx/experimental/subgraph_creation_example.py b/torch/fx/experimental/subgraph_creation_example.py new file mode 100644 index 0000000000000..daca84a86307a --- /dev/null +++ b/torch/fx/experimental/subgraph_creation_example.py @@ -0,0 +1,177 @@ +import torch +from torch.fx.graph_module import GraphModule +from typing import Callable, List, Dict, Set, Any, Optional + +class Partition: + def __init__(self, name: str): + self.name: str = name + self.node_names: List[str] = [] + self.inputs: Set[str] = set() + self.outputs: Set[str] = set() + self.partitions_dependent_on: Set[str] = set() + self.partition_dependents: Set[str] = set() + self.graph : torch.fx.graph.Graph = torch.fx.graph.Graph() # type: ignore + self.environment : Dict[torch.fx.node.Node, torch.fx.node.Node] = {} # type: ignore + self.targets : Dict[str, Any] = {} + + def __repr__(self) -> str: + return f"name: {self.name},\n" \ + f" nodes: {self.node_names},\n" \ + f" inputs: {self.inputs},\n" \ + f" outputs: {self.outputs},\n" \ + f" partitions depenent on: {self.partitions_dependent_on},\n" \ + f" parition dependents: {self.partition_dependents}" + +# Creates subgraphs out of main graph +def split_module( + m: GraphModule, + root_m: torch.nn.Module, + split_callback: Callable[[torch.fx.node.Node], int], # type: ignore +): + partitions: Dict[str, Partition] = {} + orig_nodes: Dict[str, torch.fx.node.Node] = {} # type: ignore + + def record_cross_partition_use(def_node : torch.fx.node.Node, use_node : Optional[torch.fx.node.Node]): # type: ignore + def_partition_name = getattr(def_node, '_fx_partition', None) + use_partition_name = getattr(use_node, '_fx_partition', None) + if def_partition_name != use_partition_name: + if def_partition_name is not None: + def_partition = partitions[def_partition_name] + def_partition.outputs.add(def_node.name) + if use_partition_name is not None: + def_partition.partition_dependents.add(use_partition_name) + + if use_partition_name is not None: + use_partition = partitions[use_partition_name] + use_partition.inputs.add(def_node.name) + if def_partition_name is not None: + use_partition.partitions_dependent_on.add(def_partition_name) + + # split nodes into parititons + for node in m.graph.nodes: + orig_nodes[node.name] = node + + # TODO currently placeholders/parameters aren't put into random partitions, + # rather they're added to the graphs where they are used down below + if node.op in ["placeholder", "get_attr"]: + continue + if node.op == 'output': + torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None)) # type: ignore + continue + partition_name = str(split_callback(node)) + + # add node to partitions + partition = partitions.get(partition_name) + if partition is None: + partitions[partition_name] = partition = Partition(partition_name) + + partition.node_names.append(node.name) + node._fx_partition = partition_name + + torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node)) # type: ignore + torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # type: ignore + + # find partitions with no dependencies + root_partitions : List[str] = [] + for partition_name, partition in partitions.items(): + if not len(partition.partitions_dependent_on): + root_partitions.append(partition_name) + + # check partitions for circular dependencies and create topological partition ordering + sorted_partitions : List[str] = [] + while root_partitions: + root_partition = root_partitions.pop() + sorted_partitions.append(root_partition) + for dependent in partitions[root_partition].partition_dependents: + partitions[dependent].partitions_dependent_on.remove(root_partition) + if not partitions[dependent].partitions_dependent_on: + root_partitions.append(dependent) + if len(sorted_partitions) != len(partitions): + raise RuntimeError("cycle exists between partitions!") + + # add placeholders to parititons + for partition_name in sorted_partitions: + partition = partitions[partition_name] + for input in partition.inputs: + placeholder = partition.graph.placeholder(input) + partition.environment[orig_nodes[input]] = placeholder + + # Transform nodes and collect targets for partition's submodule + for node in m.graph.nodes: + if hasattr(node, '_fx_partition'): + partition = partitions[node._fx_partition] + + # swap out old graph nodes in kw/args with references to new nodes in this submodule + environment = partition.environment + gathered_args = torch.fx.graph.map_arg(node.args, lambda n : environment[n]) # type: ignore + gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n : environment[n]) # type: ignore + + if node.op not in ['call_module', 'get_attr']: + target = node.target + else: + target_atoms = node.target.split('.') + target_attr = m + for atom in target_atoms: + if not hasattr(target_attr, atom): + raise RuntimeError(f'Operator target {node.target} not found!') + target_attr = getattr(target_attr, atom) + # target = target_atoms[-1] + target = '_'.join(target_atoms) + partition.targets[target] = target_attr + + assert isinstance(gathered_args, tuple) + assert isinstance(gathered_kwargs, dict) + new_node = partition.graph.create_node(op=node.op, target=target, args=gathered_args, + kwargs=gathered_kwargs) + partition.environment[node] = new_node + + # Set up values to construct base module + base_mod_env : Dict[str, torch.fx.node.Node] = {} # type: ignore + base_mod_graph : torch.fx.graph.Graph = torch.fx.graph.Graph() # type: ignore + base_mod_attrs : Dict[str, torch.fx.graph_module.GraphModule] = {} # type: ignore + for node in m.graph.nodes: + if node.op == 'placeholder': + base_mod_env[node.name] = base_mod_graph.placeholder(node.name) + elif node.op == 'get_attr': + base_mod_env[node.name] = base_mod_graph.get_attr(node.target) + attr_val = m + for atom in node.target.split('.'): + if not hasattr(attr_val, atom): + raise RuntimeError(f'Node target {node.target} not found!') + attr_val = getattr(attr_val, atom) + base_mod_attrs[node.target] = attr_val + + # Do some things iterating over the partitions in topological order again: + # 1) Finish off submodule Graphs by setting corresponding outputs + # 2) Construct GraphModules for each submodule + # 3) Construct the base graph by emitting calls to those submodules in + # topological order + + for partition_name in sorted_partitions: + partition = partitions[partition_name] + + # Set correct output values + output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs) + output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore + partition.graph.output(output_vals) + + # Construct GraphModule for this partition + submod_name = f'submod_{partition_name}' + base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets, partition.graph) # type: ignore + + # Emit call in base graph to this submodule + + output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs)) # type: ignore + if len(partition.outputs) > 1: + # Unpack multiple return values from submodule + output_val_proxy = torch.fx.proxy.Proxy(output_val) # type: ignore + for i, output_name in enumerate(partition.outputs): + base_mod_env[output_name] = output_val_proxy[i].node # type: ignore + else: + base_mod_env[list(partition.outputs)[0]] = output_val + + for node in m.graph.nodes: + if node.op == 'output': + base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n : base_mod_env[n.name])) # type: ignore + + return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) # type: ignore diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 6ca60f6211aa1..34bbc98cf9e00 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -1,21 +1,22 @@ -from .node import Node, Argument, Target +from .node import Node, Argument, Target, map_arg -from typing import Callable, Any, List, Dict, Optional, Tuple +from typing import Callable, Any, List, Dict, Optional, Tuple, Set import builtins import torch +import types import keyword import re def _shadows_builtin_name(name: str) -> bool: - return name in builtins.__dict__ or name in keyword.kwlist + return name in builtins.__dict__ or name in keyword.kwlist or name in {'inf', 'nan'} def _is_magic(x: str) -> bool: return x.startswith('__') and x.endswith('__') -def snake_case(s: str) -> str: +def _snake_case(s: str) -> str: return ''.join(['_' + i.lower() if i.isupper() else i for i in s]).lstrip('_') -def _qualified_name(func: Callable[..., Any]) -> str: +def get_qualified_name(func: Callable[..., Any]) -> str: # things like getattr just appear in builtins if getattr(builtins, func.__name__, None) is func: return func.__name__ @@ -52,99 +53,481 @@ def _format_target(base: str, target: str) -> str: r = f'{r}.{e}' return r -def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: - """ apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """ - if isinstance(a, (tuple, list)): - return type(a)(map_arg(elem, fn) for elem in a) - elif isinstance(a, dict): - return {k: map_arg(v, fn) for k, v in a.items()} - elif isinstance(a, slice): - return slice(map_arg(a.start, fn), map_arg(a.stop, fn), map_arg(a.step, fn)) - elif isinstance(a, Node): - return fn(a) - else: - return a +# Borrowed from CPython typing module +# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156 +def _type_repr(obj): + """Return the repr() of an object, special-casing types (internal helper). + If obj is a type, we return a shorter version than the default + type.__repr__, based on the module and qualified name, which is + typically enough to uniquely identify a type. For everything + else, we fall back on repr(obj). + """ + # HACK: In Python 3.6, type aliases from ``typing`` are instances of ``type``, but in + # later Python versions, type aliases are not instances of ``type``!! We want + # all type aliases to fall through to ``repr``, so if we have a type that is + # in the module typing, don't go down this path. + if isinstance(obj, type) and obj.__module__ != 'typing': + if obj.__module__ == 'builtins': + return obj.__qualname__ + return f'{obj.__module__}.{obj.__qualname__}' + if obj is ...: + return('...') + if isinstance(obj, types.FunctionType): + return obj.__name__ + return repr(obj) + +class _InsertPoint: + def __init__(self, graph, new_insert): + self.graph = graph + self.orig_insert, graph._insert = graph._insert, new_insert + + def __enter__(self): + pass + + def __exit__(self, type, value, tb): + self.graph._insert = self.orig_insert + +class _node_list: + def __init__(self, graph: 'Graph', direction: str = '_next'): + assert direction in ['_next', '_prev'] + self.graph = graph + self.direction = direction + + def __len__(self): + return self.graph._len + + def __iter__(self): + root, direction = self.graph._root, self.direction + cur = getattr(root, direction) + while cur is not root: + if not cur._erased: + yield cur + cur = getattr(cur, direction) + + def __reversed__(self): + return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev') class Graph: + """ + ``Graph`` is the main data structure used in the FX Intermediate Representation. + It consists of a series of ``Node`` s, each representing callsites (or other + syntactic constructs). The list of ``Node`` s, taken together, constitute a + valid Python function. + + For example, the following code + + .. code-block:: python + + import torch + import torch.fx + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x): + return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) + + m = MyModule() + gm = torch.fx.symbolic_trace(m) + + Will produce the following Graph:: + + print(gm.graph) + + .. code-block:: text + + graph(x): + %linear_weight : [#users=1] = self.linear.weight + %add_1 : [#users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) + %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) + %relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) + %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) + %topk_1 : [#users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) + return topk_1 + + For the semantics of operations represented in the ``Graph``, please see :class:`Node`. + """ def __init__(self): - self._nodes : List[Node] = [] + """ + Construct an empty Graph. + """ + self._root : Node = Node(self, '', 'root', '', (), {}) self._used_names : Dict[str, int] = {} # base name -> number + self._insert = self._root.prepend + self._len = 0 @property - def nodes(self): - return tuple(self._nodes) + def nodes(self) -> _node_list: + """ + Get the list of Nodes that constitute this Graph. + + Note that this ``Node`` list representation is a doubly-linked list. Mutations + during iteration (e.g. delete a Node, add a Node) are safe. + + Returns: + + A doubly-linked list of Nodes. Note that ``reversed`` can be called on + this list to switch iteration order. + """ + return _node_list(self) - def graph_copy(self, g : 'Graph'): + def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node]) -> 'Optional[Argument]': """ - Append all nodes from graph `g` to this graph + Copy all nodes from a given graph into ``self``. + + Args: + + g (Graph): The source graph from which to copy Nodes. + + val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping + from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed + in with values in it already to override copying of certain values. + + Returns: + + The value in ``self`` that is now equivalent to the output value in ``g``, + if ``g`` had an ``output`` node. ``None`` otherwise. """ - val_map : Dict[Node, Node] = {} - for node in g._nodes: + for node in g.nodes: + if node in val_map: + continue + if node.op == 'output': + rv = map_arg(node.args[0], lambda n: val_map[n]) + return rv val_map[node] = self.node_copy(node, lambda n : val_map[n]) + return None + + def __deepcopy__(self, memo=None) -> 'Graph': + """ + Explicitly implement __deepcopy__ to prevent excessive recursion depth + from the default implementation. This uses graph_copy to copy the nodes + in an iterative way, rather than recursive. It also populates the + memoization table to prevent unnecessary copies (e.g. references to + nodes or other parts of the Graph from a custom GraphModule implementation + """ + memo = memo if memo else {} + g = Graph() + output_val = g.graph_copy(self, val_map=memo) + g.output(output_val) + return g + + def create_node(self, op: str, target: 'Target', + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, + name: Optional[str] = None, + type_expr: Optional[Any] = None) -> Node: + """ + Create a ``Node`` and add it to the ``Graph`` at the current insert-point. + Note that the current insert-point can be set via :meth:`Graph.inserting_before` + and :meth:`Graph.inserting_after`. + + Args: + op (str): the opcode for this Node. One of 'call_function', 'call_method', 'get_attr', + 'call_module', 'placeholder', or 'output'. The semantics of these opcodes are + described in the ``Graph`` docstring. + + args (Optional[Tuple[Argument, ...]]): is a tuple of arguments to this node. + + kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node + + name (Optional[str]): an optional string name for the ``Node``. + This will influence the name of the value assigned to in the + Python generated code. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: - def _mark_uses(self, a: Argument): - def add_use(n: Node): - n.uses += 1 - return n - map_arg(a, add_use) - - def create_node(self, op: str, target: Target, - args: Optional[Tuple[Argument, ...]] = None, - kwargs: Optional[Dict[str, Argument]] = None, - name: Optional[str] = None) -> Node: - assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder') + The newly-created and inserted node. + """ + assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output') args = () if args is None else args kwargs = {} if kwargs is None else kwargs - self._mark_uses(args) - self._mark_uses(kwargs) - n = Node(self, name if name is not None else self._name(target), op, target, args, kwargs) - self._nodes.append(n) + assert isinstance(args, tuple), "args must be a tuple" + assert isinstance(kwargs, dict), "kwargs must be a dict" + unique_name = self._create_unique_name(name if name is not None else self._target_to_str(target)) + n = Node(self, unique_name, op, target, args, kwargs, type_expr) + self._insert(n) + self._len += 1 return n - # sugar for above when you know the op - def placeholder(self, name: str) -> Node: - return self.create_node('placeholder', name) + def erase_node(self, to_erase : Node) -> None: + """ + Erases a ``Node`` from the ``Graph``. Throws an exception if + there are still users of that node in the ``Graph``. + + Args: + + to_erase (Node): The ``Node`` to erase from the ``Graph``. + """ + if len(to_erase.users) > 0: + raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} ' + f'users in the graph: {to_erase.users}!') + + to_erase._remove_from_list() + to_erase._erased = True # iterators may retain handles to erased nodes + self._len -= 1 + + # Null out this Node's argument nodes so that the Nodes referred to + # can update their ``users`` accordingly + new_args = map_arg(to_erase.args, lambda n: None) + assert isinstance(new_args, tuple) + to_erase.args = new_args + new_kwargs = map_arg(to_erase.kwargs, lambda n: None) + assert isinstance(new_kwargs, dict) + to_erase.kwargs = new_kwargs + + def inserting_before(self, n: Optional[Node] = None): + """Set the point at which create_node and companion methods will insert into the graph. + When used within a 'with' statement, this will temporary set the insert point and + then restore it when the with statement exits:: + + with g.inserting_before(n): + ... # inserting before node n + ... # insert point restored to what it was previously + g.inserting_before(n) # set the insert point permanently + + Args: + n (Optional[Node]): The node before which to insert. If None this will insert before + the beginning of the entire graph. + + Returns: + A resource manager that will restore the insert point on ``__exit__``. + """ + if n is None: + return self.inserting_after(self._root) + assert n.graph == self, "Node to insert before is not in graph." + return _InsertPoint(self, n.prepend) + + def inserting_after(self, n: Optional[Node] = None): + """Set the point at which create_node and companion methods will insert into the graph. + When used within a 'with' statement, this will temporary set the insert point and + then restore it when the with statement exits:: + + with g.inserting_after(n): + ... # inserting after node n + ... # insert point restored to what it was previously + g.inserting_after(n) # set the insert point permanently + + Args: + n (Optional[Node]): The node before which to insert. If None this will insert after + the beginning of the entire graph. + + Returns: + A resource manager that will restore the insert point on ``__exit__``. + """ + if n is None: + return self.inserting_before(self._root) + assert n.graph == self, "Node to insert after is not in graph." + return _InsertPoint(self, n.append) + + # sugar for create_node when you know the op + def placeholder(self, name: str, type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents + a function input. + + Args: + + name (str): A name for the input value. This corresponds to the name + of the positional argument to the function this ``Graph`` represents. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. This is needed in some + cases for proper code generation (e.g. when the function is used + subsequently in TorchScript compilation). + + .. note:: + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. + """ + return self.create_node('placeholder', name, type_expr=type_expr) + + def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the + fetch of an attribute from the ``Module`` hierarchy. + + Args: - def get_attr(self, name: str) -> Node: - return self.create_node('get_attr', name) + qualified_name (str): the fully-qualified name of the attribute to be retrieved. + For example, if the traced Module has a submodule named ``foo``, which has a + submodule named ``bar``, which has an attribute named ``baz``, the qualified + name ``foo.bar.baz`` should be passed as ``qualified_name``. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + + Returns: + + The newly-created and inserted ``get_attr`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. + """ + return self.create_node('get_attr', qualified_name, type_expr=type_expr) def call_module(self, module_name: str, - args: Optional[Tuple[Argument, ...]] = None, - kwargs: Optional[Dict[str, Argument]] = None) -> Node: - return self.create_node('call_module', module_name, args, kwargs) + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, + type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node + represents a call to the forward() function of a ``Module`` in the ``Module`` + hierarchy. + + Args: + + module_name (str): The qualified name of the ``Module`` in the ``Module`` + hierarchy to be called. For example, if the traced ``Module`` has a + submodule named ``foo``, which has a submodule named ``bar``, the + qualified name ``foo.bar`` should be passed as ``module_name`` to + call that module. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called method. Note that this should *not* include a ``self`` argument. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called method + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly-created and inserted ``call_module`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr) def call_method(self, method_name: str, - args: Optional[Tuple[Argument, ...]] = None, - kwargs: Optional[Dict[str, Argument]] = None) -> Node: - return self.create_node('call_method', method_name, args, kwargs) + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, + type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node + represents a call to a given method on the 0th element of ``args``. + + Args: + + method_name (str): The name of the method to apply to the self argument. + For example, if args[0] is a ``Node`` representing a ``Tensor``, + then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called method. Note that this *should* include a ``self`` argument. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called method + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly created and inserted ``call_method`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr) def call_function(self, the_function: Callable[..., Any], - args: Optional[Tuple[Argument, ...]] = None, - kwargs: Optional[Dict[str, Argument]] = None) -> Node: - return self.create_node('call_function', the_function, args, kwargs) + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, + type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node + represents a call to a Python callable, specified by ``the_function``. ``the_function`` + can be + + Args: + + the_function (Callable[..., Any]): The function to be called. Can be any PyTorch + operator, Python function, or member of the ``builtins`` or ``operator`` + namespaces. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called function. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called function + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns + + The newly created and inserted ``call_function`` node. - def node_copy(self, node: Node, arg_transform: Callable[[Node], Argument] = lambda x: x) -> Node: - """ copy a node from one graph into another. arg_transform needs to transform arguments from the graph of node - to the graph of self""" + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr) + + def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node: + """ + Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from + the graph of node to the graph of self. Example:: + + # Copying all the nodes in `g` into `new_graph` + g : torch.fx.Graph = ... + new_graph = torch.fx.graph() + value_remap = {} + for node in g.nodes: + value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n]) + + Args: + + node (Node): The node to copy into ``self``. + + arg_transform (Callable[[Node], Argument]): A function that transforms + ``Node`` arguments in node's ``args`` and ``kwargs`` into the + equivalent argument in ``self``. In the simplest case, this should + retrieve a value out of a table mapping Nodes in the original + graph to ``self``. + """ args = map_arg(node.args, arg_transform) kwargs = map_arg(node.kwargs, arg_transform) assert isinstance(args, tuple) assert isinstance(kwargs, dict) - if node.op == "placeholder": - # Placeholder names are user-visible, so they should be copied as-is without normalizing them. - name = node.name - else: - name = self._name(node.name) - return self.create_node(node.op, node.target, args, kwargs, name) + return self.create_node(node.op, node.target, args, kwargs, node.name, node.type) - def output(self, result: Argument): - self.result = result - self._mark_uses(result) + def output(self, result: 'Argument', type_expr: Optional[Any] = None): + """ + Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents + a ``return`` statement in Python code. ``result`` is the value that should + be returned. + + Args: + + result (Argument): The value to be returned. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. - def _name(self, target: Target) -> str: + .. note:: + + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. + """ + return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr) + + def _target_to_str(self, target : Target) -> str: if callable(target): op = target.__name__ else: @@ -152,73 +535,188 @@ def _name(self, target: Target) -> str: op = target if _is_magic(op): op = op[2:-2] - op = op.replace('.', '_') + op = _snake_case(op) + return op + + def _create_unique_name(self, candidate : str) -> str: # delete all characters that are illegal in a Python identifier - op = re.sub('[^0-9a-zA-Z_]+', '_', op) - op = snake_case(op) - if op[0].isdigit(): - op = f'_{op}' - - if op not in self._used_names: - self._used_names[op] = 0 - # Avoid shadowing PyTorch and Python builtins. - if not hasattr(torch, op) and \ - not hasattr(torch.nn.functional, op) and \ - not hasattr(torch.nn, op) and \ - not _shadows_builtin_name(op): - return op - i = self._used_names[op] = self._used_names[op] + 1 - return f'{op}_{i}' - - def python_code(self, root_module: str) -> Tuple[str, str, List[str]]: + candidate = re.sub('[^0-9a-zA-Z_]+', '_', candidate) + if candidate[0].isdigit(): + candidate = f'_{candidate}' + + def illegal_shadowing_name(name : str) -> bool: + return hasattr(torch, name) or \ + hasattr(torch.nn.functional, name) or \ + hasattr(torch.nn, name) or \ + _shadows_builtin_name(name) + + while candidate in self._used_names or illegal_shadowing_name(candidate): + match = re.match(r"(.*)_(\d+)$", candidate) + if match is None: + candidate = candidate + '_1' + else: + base, num = match.group(1, 2) + candidate = f'{base}_{int(num) + 1}' + + self._used_names.setdefault(candidate) + return candidate + + def python_code(self, root_module: str) -> str: + """ + Turn this ``Graph`` into valid Python code. + + Args: + + root_module (str): The name of the root module on which to look-up + qualified name targets. This is usually 'self'. + + Returns: + + The string source code generated from this ``Graph``. + """ free_vars: List[str] = [] + modules_used : Set[str] = set() body: List[str] = [] - for node in self._nodes: + + # Wrap string in list to pass by reference + maybe_return_annotation : List[str] = [''] + + def register_modules_used(qualified_name : str): + if '.' in qualified_name: + module_name = qualified_name.split('.', maxsplit=1)[0] + modules_used.add(module_name) + + def type_repr(o : Any): + typename = _type_repr(o) + if all(x.isidentifier() for x in typename.split('.')): + register_modules_used(typename) + else: + # this is a constructor type, e.g. typing.List[torch.Tensor] + modules_used.add(o.__module__) + for sub_type in o.__args__: + # make sure we have torch.Tensor + type_repr(sub_type) + return typename + + + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use : Dict[Node, Node] = {} + user_to_last_uses : Dict[Node, List[Node]] = {} + + def register_last_uses(n : Node, user : Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(self.nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + def delete_unused_values(user : Node): + """ + Delete values after their last use. This ensures that values that are + not used in the remainder of the code are freed and the memory usage + of the code is optimal. + """ + if user.op == 'placeholder': + return + if user.op == 'output': + body.append('\n') + return + nodes_to_delete = user_to_last_uses.get(user, []) + if len(nodes_to_delete): + to_delete_str = ' = '.join([n.name for n in nodes_to_delete] + ['None']) + body.append(f'; {to_delete_str}\n') + else: + body.append('\n') + + def emit_node(node : Node): if node.op == 'placeholder': assert isinstance(node.target, str) - free_vars.append(node.target) + maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' + maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' + free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') raw_name = node.target.replace('*', '') if raw_name != node.name: body.append(f'{node.name} = {raw_name}\n') - continue + return elif node.op == 'call_method': assert isinstance(node.target, str) body.append( f'{node.name} = {_format_target(repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})\n') - continue + f'({_format_args(node.args[1:], node.kwargs)})') + return elif node.op == 'call_function': assert callable(node.target) # pretty print operators if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: assert isinstance(node.args, tuple) - body.append(f'{node.name} = {magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}\n') - continue - qualified_name = _qualified_name(node.target) + body.append(f'{node.name} = {magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') + return + qualified_name = get_qualified_name(node.target) + register_modules_used(qualified_name) if qualified_name == 'getattr' and \ isinstance(node.args, tuple) and \ isinstance(node.args[1], str) and \ node.args[1].isidentifier(): # pretty print attribute access - body.append(f'{node.name} = {_format_target(repr(node.args[0]), node.args[1])}\n') - continue - body.append(f'{node.name} = {qualified_name}({_format_args(node.args, node.kwargs)})\n') - continue + body.append(f'{node.name} = {_format_target(repr(node.args[0]), node.args[1])}') + return + body.append(f'{node.name} = {qualified_name}({_format_args(node.args, node.kwargs)})') + return elif node.op == 'call_module': assert isinstance(node.target, str) - body.append(f'{node.name} = {_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})\n') - continue + body.append(f'{node.name} = {_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + return elif node.op == 'get_attr': assert isinstance(node.target, str) - body.append(f'{node.name} = {_format_target(root_module, node.target)}\n') - continue + body.append(f'{node.name} = {_format_target(root_module, node.target)}') + return + elif node.op == 'output': + if node.type is not None: + maybe_return_annotation[0] = f" -> {type_repr(node.type)}" + body.append(f'return {repr(node.args[0])}') + return raise NotImplementedError(f'node: {node.op} {node.target}') - src = ''.join(body) - return src, str(self.result), free_vars + for node in self.nodes: + # NOTE: emit_node does not emit a string with newline. It depends + # on delete_unused_values to append one + emit_node(node) + delete_unused_values(node) + + # repr() for inf and nan floating point values aren't parseable by + # python as literals. Explicitly import the names from the ``math`` module. + import_strs = [f'import {name}' for name in sorted(modules_used)] + import_block = '\n'.join(import_strs) + + if len(body) == 0: + # If the Graph has no non-placeholder nodes, no lines for the body + # have been emitted. To continue to have valid Python code, emit a + # single pass statement + body.append('pass\n') + + code = ''.join(body) + code = '\n'.join(' ' + line for line in code.split('\n')) + fn_code = f"""\ +{import_block} +def forward(self, {', '.join(free_vars)}){maybe_return_annotation[0]}: +{code}""" + + return fn_code def __str__(self) -> str: + """ + Print a human-readable (not machine-readable) string representation + of this Graph + """ placeholder_names : List[str] = [] + # This is a one-element array just so ``format_node`` can modify the closed + # over value + maybe_return_typename : List[str] = [''] def format_arg(arg) -> str: if isinstance(arg, list): @@ -237,28 +735,112 @@ def format_arg(arg) -> str: else: return str(arg) + def pretty_print_target(target): + """ + Make target printouts more user-friendly. + 1) builtins will be printed as `builtins.xyz` + 2) operators will be printed as `operator.xyz` + 3) other callables will be printed with qualfied name, e.g. torch.add + """ + if isinstance(target, str): + return target + if hasattr(target, '__module__'): + if not hasattr(target, '__name__'): + # Just to be defensive, if we don't have `__name__`, get the + # qualname. Not sure if this happens for any members of `operator` + # or `builtins`. This fallback path is not as good, since e.g. + # things in `operator` have `_operator` as their __module__. + return get_qualified_name(target) + if target.__module__ == 'builtins': + return f'builtins.{target.__name__}' + elif target.__module__ == '_operator': + return f'operator.{target.__name__}' + return get_qualified_name(target) + def format_node(n : Node) -> Optional[str]: if n.op == 'placeholder': assert isinstance(n.target, str) - placeholder_names.append(n.target) + arg_str = n.target + arg_str += arg_str + f': {_type_repr(n.type)}' if n.type is not None else '' + placeholder_names.append(arg_str) return None elif n.op == 'get_attr': - return f'%{n.name} : [uses={n.uses}] = self.{n.target}' + maybe_typename = f'{_type_repr(n.type)} ' if n.type is not None else '' + return f'%{n.name} : {maybe_typename}[#users={len(n.users)}] = self.{n.target}' + elif n.op == 'output': + if n.type is not None: + maybe_return_typename[0] = f' -> {_type_repr(n.type)}' + return f'return {n.args[0]}' else: - return f'%{n.name} : [uses={n.uses}] = {n.op}[target={n.target}](' \ + maybe_typename = f'{_type_repr(n.type)} ' if n.type is not None else '' + return f'%{n.name} : {maybe_typename}[#users={len(n.users)}] = {n.op}[target={pretty_print_target(n.target)}](' \ f'args = {format_arg(n.args)}, kwargs = {format_arg(n.kwargs)})' - node_strs = [format_node(node) for node in self._nodes] + node_strs = [format_node(node) for node in self.nodes] param_str = ', '.join(placeholder_names) - s = f'graph({param_str}):' + s = f'graph({param_str}){maybe_return_typename[0]}:' for node_str in node_strs: if node_str: s += '\n ' + node_str - if self.result: - s += f'\n return {format_arg(self.result)}' return s + def lint(self, root : Optional[torch.nn.Module] = None): + """ + Runs various checks on this Graph to make sure it is well-formed. In + particular: + - Checks Nodes have correct ownership (owned by this graph) + - Checks Nodes appear in topological order + - If ``root`` is provided, checks that targets exist in ``root`` + + Args: + + root (Optional[torch.nn.Module]): The root module with which to check + for targets. This is equivalent to the ``root`` argument that is + passed when constructing a ``GraphModule``. + """ + + # Check topo order + def check_arg(arg : Node, n : Optional[Node] = None) -> None: + context_str = f' of Node \'{n}\' ' if n else ' ' + if arg.graph is not self: + raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, ' + f'but was used as an argument! If you are copying nodes from another graph, make ' + f'sure to use ``arg_transform`` on node_copy() to remap values\n{self}') + if arg not in seen_values: + raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been ' + f'defined! Please check that Nodes in the graph are topologically ordered\n{self}') + + seen_names : Set[str] = set() + seen_values : Set[Node] = set() + for node in self.nodes: + if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']: + raise RuntimeError(f'Node {node} had unknown opcode {node.op}!') + if node.graph is not self: + raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!') + map_arg(node.args, lambda arg: check_arg(arg, node)) + map_arg(node.kwargs, lambda arg: check_arg(arg, node)) + seen_values.add(node) + + if node.name in seen_names: + raise RuntimeError(f'Node redefined name {node.name}!') + seen_names.add(node.name) + + # Check targets are legit + if root: + for node in self.nodes: + if node.op in ['get_attr', 'call_module']: + assert isinstance(node.target, str) + target_atoms = node.target.split('.') + m_itr = root + for i, atom in enumerate(target_atoms): + m_itr = getattr(m_itr, atom, None) + if m_itr is None: + seen_qualname = '.'.join(target_atoms[:i]) + raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute ' + f'{atom} of {seen_qualname}') + + reflectable_magic_methods = { 'add': '{} + {}', 'sub': '{} - {}', diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index e635819550ad1..fc68cdab56779 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -1,9 +1,17 @@ import torch +import torch.nn as nn import torch.overrides +from torch.nn.modules.module import _addindent import linecache -from typing import Type, Dict, List, Any, Union +from typing import Type, Dict, List, Any, Union, Optional from .graph import Graph import copy +import sys +import traceback +import math +from pathlib import Path +import os +import warnings # normal exec loses the source code, however we can patch # the linecache module to still recover it. @@ -28,9 +36,7 @@ def patched_getline(*args, **kwargs): linecache.getlines = patched_getline def _forward_from_src(src : str): - gbls: Dict[str, Any] = { - 'torch': torch - } + gbls: Dict[str, Any] = {'inf': math.inf, 'nan': math.nan} exec_with_source(src, gbls) return gbls['forward'] @@ -49,7 +55,12 @@ def __init__(self, body): super().__init__() self.__dict__ = body - CodeOnlyModule.forward = _forward_from_src(body['code']) + try: + CodeOnlyModule.forward = _forward_from_src(body['_code']) + except KeyError: + # BC: attribute name was changed from `code` to `_code` to facilitate + # making `code` into a property and adding a docstring to it + CodeOnlyModule.forward = _forward_from_src(body['code']) from .symbolic_trace import Tracer @@ -59,7 +70,8 @@ class KeepModules(Tracer): def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool: return True - return KeepModules().trace(CodeOnlyModule(body)) + com = CodeOnlyModule(body) + return GraphModule(com, KeepModules().trace(com)) # copy an attribute value with qualified name 'target' from 'from_module' to 'to_module' # This installs empty Modules where none exist yet if they are subpaths of target @@ -80,7 +92,14 @@ def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: setattr(to_module, item, t) from_module, to_module = f, t - setattr(to_module, field, getattr(from_module, field)) + orig = getattr(from_module, field) + # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. + # So, we register it as a named buffer in the target module. + if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter): + to_module.register_buffer(field, orig) + else: + setattr(to_module, field, orig) + # Assign attribute 'from_obj' to the qualified name 'target' on 'to_module # This installs empty Modules where none exist yet if they are subpaths of target @@ -98,15 +117,17 @@ def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str): class GraphModule(torch.nn.Module): """ - GraphModule is an nn.Module generated from an fx.Graph. GraphModule has - important attributes: + GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a + ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated + from that ``graph``. + + .. warning:: - graph : The graph from which this GraphModule was generated - code : The Python source code for the function generated from `graph` - forward : The Python method generated from `graph` + When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically + regenerated. However, if you edit the contents of the ``graph`` without reassigning + the ``graph`` attribute itself, you must call ``recompile()`` to update the generated + code. - Note that when `graph` is reassigned, `code` and `forward` will be automatically - regenerated. """ def __new__(cls: 'Type[GraphModule]', *args, **kwargs): # each instance of a graph module needs its own forward method @@ -121,14 +142,20 @@ class GraphModuleImpl(cls): # type: ignore def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph): """ Construct a GraphModule. - root - `root` can either be an nn.Module instance or a Dict mapping strings to any attribute type. - - In the case that `root` is a Module, any references to Module-based objects (via qualified - name) in the Graph's Nodes' `target` field will be copied over from the respective place - within `root`'s Module hierarchy into the GraphModule's module hierarchy. - - In the case that `root` is a dict, the qualified name found in a Node's `target` will be - looked up directly in the dict's keys. The object mapped to by the Dict will be copied - over into the appropriate place within the GraphModule's module hierarchy. - graph - `graph` contains the nodes this GraphModule should use for code generation + + Args: + + root (Union[torch.nn.Module, Dict[str, Any]): + ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type. + In the case that ``root`` is a Module, any references to Module-based objects (via qualified + name) in the Graph's Nodes' ``target`` field will be copied over from the respective place + within ``root``'s Module hierarchy into the GraphModule's module hierarchy. + In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be + looked up directly in the dict's keys. The object mapped to by the Dict will be copied + over into the appropriate place within the GraphModule's module hierarchy. + + graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation + """ super().__init__() if isinstance(root, torch.nn.Module): @@ -145,14 +172,14 @@ def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph): assert isinstance(node.target, str) if node.target not in root: raise RuntimeError('Node ' + str(node) + ' referenced target ' + node.target + - ' but that target was not provided in `root`!') + ' but that target was not provided in ``root``!') targets_to_copy.append(node.target) # Sort targets in ascending order of the # of atoms. # This will ensure that less deeply nested attributes are assigned # before more deeply nested attributes. For example, foo.bar # will be assigned before foo.bar.baz. Otherwise, we might assign - # the user-provided `foo.bar` and wipe out the previously-assigned - # `foo.bar.baz` + # the user-provided ``foo.bar`` and wipe out the previously-assigned + # ``foo.bar.baz`` targets_to_copy.sort(key=lambda t: t.count('.')) for target_to_copy in targets_to_copy: _assign_attr(root[target_to_copy], self, target_to_copy) @@ -164,26 +191,127 @@ def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph): # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842 # # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway - __ignored_properties__ = ['graph'] + __jit_unused_properties__ = ['graph'] @property - def graph(self): + def graph(self) -> Graph: + """ + Return the ``Graph`` underlying this ``GraphModule`` + """ return self._graph @graph.setter - def graph(self, val) -> None: - self._graph = val - body, result, free_variables = self._graph.python_code(root_module='self') - body = '\n'.join(' ' + line for line in body.split('\n')) + '\n' - self.code = f"""\ -def forward(self, {', '.join(free_variables)}): -{body} - return {result} + def graph(self, g) -> None: + """ + Set the underlying ``Graph`` for this ``GraphModule``. This will internally + recompile the ``GraphModule`` so that the generated ``forward()`` function + corresponds to ``g`` + """ + self._graph = g + self.recompile() + + def to_folder(self, folder: Union[str, os.PathLike], module_name : str = "FxModule"): + """Dumps out module to ``folder`` with ``module_name`` so that it can be + imported with ``from import `` + + Args: + + folder (Union[str, os.PathLike]): The folder to write the code out to + + module_name (str): Top-level name to use for the ``Module`` while + writing out the code + """ + folder = Path(folder) + Path(folder).mkdir(exist_ok=True) + torch.save(self.state_dict(), folder / 'state_dict.pt') + tab = " " * 4 + model_str = f""" +import torch +from torch.nn import * +class {module_name}(torch.nn.Module): + def __init__(self): + super().__init__() """ + + def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: + safe_reprs = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d] + if type(module) in safe_reprs: + return f"{module.__repr__()}" + else: + return None + + blobified_modules = [] + for module_name, module in self.named_children(): + module_str = _gen_model_repr(module_name, module) + if module_str is None: + module_file = folder / f'{module_name}.pt' + torch.save(module, module_file) + blobified_modules.append(module_name) + module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') + module_str = f"torch.load(r'{module_file}') # {module_repr}" + model_str += f"{tab*2}self.{module_name} = {module_str}\n" + + for buffer_name, buffer in self._buffers.items(): + model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}))\n" + + for param_name, param in self._parameters.items(): + model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(buffer.shape)}))\n" + + model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" + model_str += f"{_addindent(self.code, 4)}\n" + + module_file = folder / 'module.py' + module_file.write_text(model_str) + + init_file = folder / '__init__.py' + init_file.write_text('from .module import *') + + if len(blobified_modules) > 0: + warnings.warn("Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}") + + @property + def code(self) -> str: + """ + Return the Python code generated from the ``Graph`` underlying this + ``GraphModule``. + """ + if not hasattr(self, '_code'): + raise RuntimeError('Code has not been generated! Please report a bug to PyTorch') + return self._code + + def recompile(self) -> None: + """ + Recompile this GraphModule from its ``graph`` attribute. This should be + called after editing the contained ``graph``, otherwise the generated + code of this ``GraphModule`` will be out of date. + """ + self._code = self._graph.python_code(root_module='self') cls = type(self) - cls.forward = _forward_from_src(self.code) + cls.forward = _forward_from_src(self._code) + + cls_call = cls.__call__ + + def print_full_traceback(exctype, value, tb): + traceback.print_exception(exctype, value, tb) + + def wrapped_call(self, *args, **kwargs): + old_excepthook = sys.excepthook + try: + sys.excepthook = print_full_traceback + return cls_call(self, *args, **kwargs) + finally: + sys.excepthook = old_excepthook + cls.__call__ = wrapped_call def __reduce__(self): + """ + Serialization of GraphModule. We serialize only the generated code, not + the underlying ``Graph``. This is because ``Graph`` does not have on-disk + backward-compatibility guarantees, whereas Python source code does. + On the deserialization side, we symbolically trace through the generated + code to regenerate the underlying ``Graph`` + """ dict_without_graph = self.__dict__.copy() del dict_without_graph['_graph'] return (deserialize_graphmodule, (dict_without_graph,)) @@ -201,7 +329,7 @@ def __copy__(self): def __str__(self) -> str: orig_str = super().__str__() - return '\n'.join([orig_str, self.code]) + return '\n'.join([orig_str, self._code]) # workarounds for issues in __torch_function__ diff --git a/torch/fx/immutable_collections.py b/torch/fx/immutable_collections.py new file mode 100644 index 0000000000000..03cb329a81b22 --- /dev/null +++ b/torch/fx/immutable_collections.py @@ -0,0 +1,16 @@ +def _no_mutation(self, *args, **kwargs): + raise NotImplementedError(f"'{type(self).__name__}' object does not support mutation") + +def _create_immutable_container(base, mutable_functions): + container = type('immutable_' + base.__name__, (base,), {}) + for attr in mutable_functions: + setattr(container, attr, _no_mutation) + return container + +immutable_list = _create_immutable_container(list, + ['__delitem__', '__iadd__', '__imul__', '__setitem__', 'append', + 'clear', 'extend', 'insert', 'pop', 'remove']) +immutable_list.__reduce__ = lambda self: (immutable_list, (tuple(iter(self)),)) + +immutable_dict = _create_immutable_container(dict, ['__delitem__', '__setitem__', 'clear', 'pop', 'popitem', 'update']) +immutable_dict.__reduce__ = lambda self: (immutable_dict, (iter(self.items()),)) diff --git a/torch/fx/node.py b/torch/fx/node.py index 666c627ac3e64..dccb1a8ce8801 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -1,11 +1,11 @@ # Nodes represent a definition of a value in our graph of operators. from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict +from .immutable_collections import immutable_dict, immutable_list import torch if TYPE_CHECKING: from .graph import Graph - BaseArgumentTypes = Union[str, int, float, bool, torch.dtype, torch.Tensor] base_types = BaseArgumentTypes.__args__ # type: ignore @@ -21,16 +21,247 @@ ]] class Node: - def __init__(self, graph: 'Graph', name: str, op: str, target: Target, - args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> None: + """ + ``Node`` is the data structure that represents individual operations within + a ``Graph``. For the most part, Nodes represent callsites to various entities, + such as operators, methods, and Modules (some exceptions include nodes that + specify function inputs and outputs). Each ``Node`` has a function specified + by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows: + + - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on. + ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument + denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to + the function parameters (e.g. ``x``) in the graph printout. + - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the + fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy. + ``args`` and ``kwargs`` are don't-care + - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign + to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function, + following the Python calling convention + - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is + as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call. + ``args`` and ``kwargs`` represent the arguments to invoke the module on, *including the self argument*. + - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method + to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on, + *including the self argument* + - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement + in the Graph printout. + """ + def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', + args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'], + type : Optional[Any] = None) -> None: self.graph = graph self.name = name # unique name of value being created - self.op = op # the kind of operation = placeholder|call_method|call_module|call_function|getattr + assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root'] + self.op = op # the kind of operation = placeholder|call_method|call_module|call_function|get_attr + if op in ['call_method', 'call_module']: + assert isinstance(target, str) self.target = target # for method/module/function, the name of the method/module/function/attr # being invoked, e.g add, layer1, or torch.add - self.args = args - self.kwargs = kwargs - self.uses = 0 + + # All `Node`-valued inputs. Key is the Node, value is don't-care. + # The public API for this is `all_input_nodes`, this private attribute + # should not be accessed directly. + self._input_nodes : Dict[Node, None] = {} + self.__update_args_kwargs(map_arg(args, lambda x: x), map_arg(kwargs, lambda x: x)) # type: ignore + + # All of the nodes that use the value produced by this Node + # Note one user may correspond to several uses, e.g. the node fo ``x + x`` + # would appear once here, but represents two uses. + # + # Is a dict to act as an "ordered set". Keys are significant, value dont-care + self.users : Dict['Node', None] = {} + # Type expression representing the output value of this node. + # This should contain the same class of Type objects that would appear + # as type annotations for function inputs/outputs. + # + # For placeholder nodes, this value will be used to type-annotate the + # generated function parameters. + # For the return ndoe, this value will be used to type-annotate the + # generated function return type. (Note this is a special case. ``return`` + # does not produce a value, it's more of a notation. Thus, this value + # describes the type of args[0] in the ``return`` node. + self.type : Optional[Any] = type + self._prev = self + self._next = self + self._erased = False + + @property + def next(self) -> 'Node': + """ + Returns the next ``Node`` in the linked list of Nodes. + + Returns: + + The next ``Node`` in the linked list of Nodes. + """ + return self._next + + @property + def prev(self) -> 'Node': + """ + Returns the previous ``Node`` in the linked list of Nodes. + + Returns: + + The previous ``Node`` in the linked list of Nodes. + """ + return self._prev + + def prepend(self, x: 'Node') -> None: + """ + Insert x before this node in the list of nodes in the graph. Example:: + + Before: p -> self + bx -> x -> ax + After: p -> x -> self + bx -> ax + + Args: + x (Node): The node to put before this node. Must be a member of the same graph. + """ + assert self.graph == x.graph, "Attempting to move a Node into a different Graph" + x._remove_from_list() + p = self._prev + p._next, x._prev = x, p + x._next, self._prev = self, x + + def append(self, x: 'Node') -> None: + """ + Insert x after this node in the list of nodes in the graph. + Equvalent to ``self.next.prepend(x)`` + + Args: + x (Node): The node to put after this node. Must be a member of the same graph. + """ + self._next.prepend(x) + + def _remove_from_list(self): + p, n = self._prev, self._next + p._next, n._prev = n, p + + @property + def args(self) -> Tuple[Argument, ...]: + """ + The tuple of arguments to this ``Node``. The interpretation of arguments + depends on the node's opcode. See the :class:`Node` docstring for more + information. + + Assignment to this property is allowed. All accounting of uses and users + is updated automatically on assignment. + """ + return self._args + + @args.setter + def args(self, a : Tuple[Argument, ...]): + """ + Set the tuple of arguments to this Node. The interpretation of arguments + depends on the node's opcode. See the ``fx.Graph`` docstring for more + information. + """ + # DO NOT CALL `__update_args_kwargs` directly. The correct way to + # set `args` is via direct assignment, i.e. `node.args = new_args` + self.__update_args_kwargs(map_arg(a, lambda x: x), self._kwargs) # type: ignore + + @property + def kwargs(self) -> Dict[str, Argument]: + """ + The dict of keyword arguments to this ``Node``. The interpretation of arguments + depends on the node's opcode. See the :class:`Node` docstring for more + information. + + Assignment to this property is allowed. All accounting of uses and users + is updated automatically on assignment. + """ + return self._kwargs + + @kwargs.setter + def kwargs(self, k : Dict[str, Argument]): + """ + Set the dict of kwargs to this Node. The interpretation of arguments + depends on the node's opcode. See the ``fx.Graph`` docstring for more + information. + """ + # DO NOT CALL `__update_args_kwargs` directly. The correct way to + # set `args` is via direct assignment, i.e. `node.kwargs = new_kwargs` + self.__update_args_kwargs(self._args, map_arg(k, lambda x: x)) # type: ignore + + @property + def all_input_nodes(self) -> List['Node']: + """ + Return all Nodes that are inputs to this Node. This is equivalent to + iterating over ``args`` and ``kwargs`` and only collecting the values that + are Nodes. + + Returns: + + List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this + ``Node``, in that order. + """ + return list(self._input_nodes.keys()) + + def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : Dict[str, 'Argument']): + """ + This API is internal. Do *not* call it directly. + """ + self._args = new_args + self._kwargs = new_kwargs + + for old_use in self._input_nodes.keys(): + old_use.users.pop(self) + + self._input_nodes = {} + map_arg(self._args, lambda n: self._input_nodes.setdefault(n)) + map_arg(self._kwargs, lambda n: self._input_nodes.setdefault(n)) + + for new_use in self._input_nodes.keys(): + new_use.users.setdefault(self) def __repr__(self) -> str: return self.name + + def replace_all_uses_with(self, replace_with : 'Node') -> List['Node']: + """ + Replace all uses of ``self`` in the Graph with the Node ``replace_with``. + + Args: + + replace_with (Node): The node to replace all uses of ``self`` with. + + Returns: + + The list of Nodes on which this change was made. + """ + to_process = list(self.users) + for use_node in to_process: + def maybe_replace_node(n : Node) -> Node: + if n == self: + return replace_with + else: + return n + + new_args = map_arg(use_node.args, maybe_replace_node) + new_kwargs = map_arg(use_node.kwargs, maybe_replace_node) + assert isinstance(new_args, tuple) + assert isinstance(new_kwargs, dict) + use_node.__update_args_kwargs(new_args, new_kwargs) + + assert len(self.users) == 0 + return to_process + +def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: + """ Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """ + return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) + +def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument: + """ Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """ + if isinstance(a, tuple): + return tuple(map_aggregate(elem, fn) for elem in a) + elif isinstance(a, list): + return immutable_list(map_aggregate(elem, fn) for elem in a) + elif isinstance(a, dict): + return immutable_dict((k, map_aggregate(v, fn)) for k, v in a.items()) + elif isinstance(a, slice): + return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn)) + else: + return fn(a) diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index 90593c4b82f43..2cf6960fa9a7b 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -4,14 +4,15 @@ import operator from .graph import magic_methods, reflectable_magic_methods, Graph -from typing import Tuple, Dict, Optional, Iterable, NoReturn, Any, Union, Callable +from typing import Tuple, Dict, Optional, Iterable, Any, Iterator from .node import Target, Node, Argument, base_types class TracerBase: graph: Graph - def create_node(self, kind : str, target : Union[str, Callable], - args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None) -> Node: + def create_node(self, kind : str, target : Target, + args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, + type_expr : Optional[Any] = None) -> Node: """ Inserts a graph node given target, args, kwargs, and name. @@ -19,7 +20,27 @@ def create_node(self, kind : str, target : Union[str, Callable], modification of values used in node creation. For example, one might want to disallow in-place operations from being recorded. """ - return self.graph.create_node(kind, target, args, kwargs, name) + return self.graph.create_node(kind, target, args, kwargs, name, type_expr) + + def proxy(self, node: Node) -> 'Proxy': + return Proxy(node, self) + + def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], + name: Optional[str] = None, type_expr : Optional[Any] = None): + ''' + Create a Node from the given arguments, then return the Node + wrapped in a Proxy object. + + If kind = 'placeholder', then we're creating a Node that + represents the parameter of a function. If we need to encode + a default parameter, we use the ``args`` tuple. ``args`` is + otherwise empty for ``placeholder`` Nodes. + ''' + args_ = self.create_arg(args) + kwargs_ = self.create_arg(kwargs) + assert isinstance(args_, tuple) + assert isinstance(kwargs_, dict) + return self.proxy(self.create_node(kind, target, args_, kwargs_, name, type_expr)) def create_arg(self, a: Any) -> Argument: """ @@ -29,7 +50,13 @@ def create_arg(self, a: Any) -> Argument: Can be override to support more trace-specific types. """ # aggregates - if isinstance(a, (tuple, list)): + if isinstance(a, tuple) and hasattr(a, '_fields'): + # NamedTuple constructors don't seem to like getting a generator + # expression as an argument to their constructor, so build this + # intermediate tuple and unpack it into the NamedTuple constructor + args = tuple(self.create_arg(elem) for elem in a) + return type(a)(*args) # type: ignore + elif isinstance(a, (tuple, list)): return type(a)(self.create_arg(elem) for elem in a) elif isinstance(a, dict): r = {} @@ -49,6 +76,31 @@ def create_arg(self, a: Any) -> Argument: raise NotImplementedError(f"argument of type: {type(a)}") + def to_bool(self, obj: 'Proxy') -> bool: + """Called when a proxy object is being converted to a boolean, such as + when used in control flow. Normally we don't know what to do because + we don't know the value of the proxy, but a custom tracer can attach more + information to the graph node using create_node and can choose to return a value. + """ + raise TraceError('symbolically traced variables cannot be used as inputs to control flow') + + def iter(self, obj: 'Proxy') -> Iterator: + """Called when a proxy object is being iterated over, such as + when used in control flow. Normally we don't know what to do because + we don't know the value of the proxy, but a custom tracer can attach more + information to the graph node using create_node and can choose to return an iterator. + """ + raise TraceError('Proxy object cannot be iterated. ' + 'This can be attempted when used in a for loop or as a *args or **kwargs function argument.') + + def keys(self, obj: 'Proxy') -> Any: + """Called when a proxy object is has the keys() method called. + This is what happens when ** is called on a proxy. This should return an + iterator it ** is suppose to work in your custom tracer. + """ + return Attribute(obj, 'keys')() + + # used in Proxy object when just appending to the graph while not tracing. class GraphAppendingTracer(TracerBase): def __init__(self, graph: Graph): @@ -58,27 +110,21 @@ def __init__(self, graph: Graph): class TraceError(ValueError): pass -# Proxy objects are stand-in values for normal values in a PyTorch computation. -# Instead of performing compute they record computation into Graph. -# Each proxy wraps the Node instance that represents the expression that define the -# value. - -# Unwrap the proxies inside args, and kwargs, create the resulting node -# and then wrap the result in a proxy. -def _create_proxy(tracer: 'TracerBase', op: str, target: Target, args_: Tuple[Any, ...], kwargs_: Dict[str, Any], name=None): - args = tracer.create_arg(args_) - kwargs = tracer.create_arg(kwargs_) - assert isinstance(args, tuple) - assert isinstance(kwargs, dict) - rn = tracer.create_node(op, target, args, kwargs, name) - return Proxy(rn, tracer) class Proxy: + """ + ``Proxy`` objects are ``Node`` wrappers that flow through the + program during symbolic tracing and record all the operations + (``torch`` function calls, method calls, operators) that they touch + into the growing FX Graph. + + If you're doing graph transforms, you can wrap your own ``Proxy`` + method around a raw ``Node`` so that you can use the overloaded + operators to add additional things to a ``Graph``. + """ def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None): if tracer is None: - # this allows you to create a proxy object around a raw node - # so that if you are doing graph transforms you can use the overloaded operators - # to add additional things to a graph. + # This allows you to create a Proxy object around a raw Node tracer = GraphAppendingTracer(node.graph) self.tracer = tracer self.node = node @@ -92,7 +138,7 @@ def __getattr__(self, k) -> 'Attribute': return Attribute(self, k) def __call__(self, *args, **kwargs) -> 'Proxy': - return _create_proxy(self.tracer, 'call_method', '__call__', (self,) + args, kwargs) + return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs) def __iter__(self) -> Iterable['Proxy']: frame = inspect.currentframe() @@ -102,28 +148,28 @@ def __iter__(self) -> Iterable['Proxy']: inst = list(dis.get_instructions(calling_frame.f_code))[calling_frame.f_lasti // 2] if inst.opname == 'UNPACK_SEQUENCE': return (self[i] for i in range(inst.argval)) # type: ignore - if inst.opname == 'CALL_FUNCTION_EX': - self._no_arg_unpack() - else: - self._no_control_flow() - def _no_control_flow(self) -> NoReturn: - raise TraceError('symbolically traced variables cannot be used as inputs to control flow') + return self.tracer.iter(self) + + def __bool__(self) -> bool: + return self.tracer.to_bool(self) - def _no_arg_unpack(self) -> NoReturn: - raise TraceError('Proxy object cannot be unpacked as function argument') + def keys(self): + return self.tracer.keys(self) - def __bool__(self) -> NoReturn: - self._no_control_flow() + def __len__(self): + raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want " + "this call to be recorded, please call torch.fx.wrap('len') at " + "module scope") def __torch_function__(self, orig_method, types, args=None, kwargs=None): args = args if args else () kwargs = kwargs if kwargs else {} if torch.overrides.is_tensor_method_or_property(orig_method): - return _create_proxy(self.tracer, 'call_method', orig_method.__name__, args, kwargs) + return self.tracer.create_proxy('call_method', orig_method.__name__, args, kwargs) else: - return _create_proxy(self.tracer, 'call_function', orig_method, args, kwargs, - name=self.tracer.graph._name(orig_method.__name__)) + return self.tracer.create_proxy('call_function', orig_method, args, kwargs, + name=self.tracer.graph._target_to_str(orig_method.__name__)) class Attribute(Proxy): def __init__(self, root: Proxy, attr: str): @@ -137,18 +183,18 @@ def node(self): # the node for attributes is added lazily, since most will just be method calls # which do not rely on the getitem call if self._node is None: - self._node = _create_proxy(self.tracer, 'call_function', getattr, (self.root, self.attr), {}).node + self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node return self._node def __call__(self, *args, **kwargs): - return _create_proxy(self.tracer, 'call_method', self.attr, (self.root,) + args, kwargs) + return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) for method in magic_methods: def scope(method): def impl(*args, **kwargs): tracer = args[0].tracer target = getattr(operator, method) - return _create_proxy(tracer, 'call_function', target, args, kwargs) + return tracer.create_proxy('call_function', target, args, kwargs) impl.__name__ = method as_magic = f'__{method}__' setattr(Proxy, as_magic, impl) @@ -159,7 +205,7 @@ def _define_reflectable(orig_method_name): def impl(self, rhs): target = getattr(operator, orig_method_name) - return _create_proxy(self.tracer, 'call_function', target, (rhs, self), {}) + return self.tracer.create_proxy('call_function', target, (rhs, self), {}) impl.__name__ = method_name impl.__qualname__ = method_name setattr(Proxy, method_name, impl) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py new file mode 100644 index 0000000000000..d2ba17dcee2b3 --- /dev/null +++ b/torch/fx/subgraph_rewriter.py @@ -0,0 +1,285 @@ +from torch.fx import Graph, GraphModule, Node, symbolic_trace + +import copy +from typing import Callable, Dict, List, NamedTuple, Set + +class Match(NamedTuple): + # Node from which the match was found + anchor: Node + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: Dict[Node, Node] + +class SubgraphMatcher: + def __init__(self, pattern : Graph) -> None: + self.pattern = pattern + if len(pattern.nodes) == 0: + raise ValueError("SubgraphMatcher cannot be initialized with an " + "empty pattern") + self.pattern_anchor = next(iter(reversed(pattern.nodes))) + # Maps nodes in the pattern subgraph to nodes in the larger graph + self.nodes_map: Dict[Node, Node] = {} + + def matches_subgraph_from_anchor(self, anchor : Node) -> bool: + """ + Checks if the whole pattern can be matched starting from + ``anchor`` in the larger graph. + + Pattern matching is done by recursively comparing the pattern + node's use-def relationships against the graph node's. + """ + self.nodes_map = {} + return self._match_nodes(self.pattern_anchor, anchor) + + # Compare the pattern node `pn` against the graph node `gn` + def _match_nodes(self, pn : Node, gn : Node) -> bool: + # Check if we've already matched these nodes in the current + # traversal + if pn in self.nodes_map: + return self.nodes_map[pn] == gn + + def attributes_are_equal(pn : Node, gn : Node) -> bool: + # Use placeholder and output nodes as wildcards. The + # only exception is that an output node can't match + # a placeholder + if (pn.op == "placeholder" + or (pn.op == "output" and gn.op != "placeholder")): + return True + return pn.op == gn.op and pn.target == gn.target + + # Terminate early if the node attributes are not equal + if not attributes_are_equal(pn, gn): + return False + + # Optimistically mark `pn` as a match for `gn` + self.nodes_map[pn] = gn + + # Traverse the use-def relationships to ensure that `pn` is a true + # match for `gn` + if (pn.op != "output" + and len(pn.all_input_nodes) != len(gn.all_input_nodes)): + return False + match_found = all(self._match_nodes(pn_, gn_) for pn_, gn_ + in zip(pn.all_input_nodes, gn.all_input_nodes)) + if not match_found: + self.nodes_map.pop(pn) + return False + + return True + + +def replace_pattern(gm : GraphModule, pattern : Callable, replacement : Callable) -> None: + """ + Matches all possible non-overlapping sets of operators and their + data dependencies (``pattern``) in the Graph of a GraphModule + (``gm``), then replaces each of these matched subgraphs with another + subgraph (``replacement``). + + Args: + ``gm``: The GraphModule that wraps the Graph to operate on + ``pattern``: The subgraph to match in ``gm`` for replacement + ``replacement``: The subgraph to replace ``pattern`` with + + Examples: + + .. code-block:: python + + import torch + from torch.fx import symbolic_trace, subgraph_rewriter + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, w1, w2): + m1 = torch.cat([w1, w2]).sum() + m2 = torch.cat([w1, w2]).sum() + return x + torch.max(m1) + torch.max(m2) + + def pattern(w1, w2): + return torch.cat([w1, w2]).sum() + + def replacement(w1, w2): + return torch.stack([w1, w2]) + + traced_module = symbolic_trace(M()) + + subgraph_rewriter.replace_pattern(traced_module, pattern, replacement) + + The above code will first match ``pattern`` in the ``forward`` + method of ``traced_module``. Pattern-matching is done based on + use-def relationships, not node names. For example, if you had + ``p = torch.cat([a, b])`` in ``pattern``, you could match + ``m = torch.cat([a, b])`` in the original ``forward`` function, + despite the variable names being different (``p`` vs ``m``). + + The ``return`` statement in ``pattern`` is matched based on its + value only; it may or may not match to the ``return`` statement in + the larger graph. In other words, the pattern doesn't have to extend + to the end of the larger graph. + + When the pattern is matched, it will be removed from the larger + function and replaced by ``replacement``. If there are multiple + matches for ``pattern`` in the larger function, each non-overlapping + match will be replaced. In the case of a match overlap, the first + found match in the set of overlapping matches will be replaced. + ("First" here being defined as the first in a topological ordering + of the Nodes' use-def relationships. In most cases, the first Node + is the parameter that appears directly after ``self``, while the + last Node is whatever the function returns.) + + One important thing to note is that the parameters of the + ``pattern`` Callable must be used in the Callable itself, + and the parameters of the ``replacement`` Callable must match + the pattern. The first rule is why, in the above code block, the + ``forward`` function has parameters ``x, w1, w2``, but the + ``pattern`` function only has parameters ``w1, w2``. ``pattern`` + doesn't use ``x``, so it shouldn't specify ``x`` as a parameter. + As an example of the second rule, consider replacing + + .. code-block:: python + + def pattern(x, y): + return torch.neg(x) + torch.relu(y) + + with + + .. code-block:: python + + def replacement(x, y): + return torch.relu(x) + + In this case, ``replacement`` needs the same number of parameters + as ``pattern`` (both ``x`` and ``y``), even though the parameter + ``y`` isn't used in ``replacement``. + + After calling ``subgraph_rewriter.replace_pattern``, the generated + Python code looks like this: + + .. code-block:: python + + def forward(self, x, w1, w2): + stack_1 = torch.stack([w1, w2]) + sum_1 = stack_1.sum() + stack_2 = torch.stack([w1, w2]) + sum_2 = stack_2.sum() + max_1 = torch.max(sum_1) + add_1 = x + max_1 + max_2 = torch.max(sum_2) + add_2 = add_1 + max_2 + return add_2 + + """ + # Get the graphs for `gm`, `pattern`, `replacement` + original_graph = gm.graph + pattern_graph = symbolic_trace(pattern).graph + replacement_graph = symbolic_trace(replacement).graph + + # Find all possible pattern matches in original_graph. Note that + # pattern matches may overlap with each other. + matcher = SubgraphMatcher(pattern_graph) + matches: List[Match] = [] + + # Consider each node as an "anchor" (deepest matching graph node) + for anchor in original_graph.nodes: + + if matcher.matches_subgraph_from_anchor(anchor): + + def pattern_is_contained(nodes_map : Dict[Node, Node]) -> bool: + # `lookup` represents all the nodes in `original_graph` + # that are part of `pattern` + lookup: Dict[Node, Node] = {v : k for k, v + in nodes_map.items()} + for n in lookup.keys(): + if n.op == "placeholder" or lookup[n].op == "output": + continue + for user in n.users: + # If this node has users that were not in + # `lookup`, then it must leak out of the + # pattern subgraph + if user not in lookup: + return False + return True + + # It's not a match if the pattern leaks out into the rest + # of the graph + if pattern_is_contained(matcher.nodes_map): + for k, v in matcher.nodes_map.items(): + # Shallow copy nodes_map + matches.append(Match(anchor=anchor, + nodes_map=copy.copy(matcher.nodes_map))) + + # The set of all nodes in `original_graph` that we've seen thus far + # as part of a pattern match + replaced_nodes: Set[Node] = set() + + # Return TRUE if one of the nodes in the current match has already + # been used as part of another match + def overlaps_with_prev_match(match : Match) -> bool: + for n in match.nodes_map.values(): + if n in replaced_nodes and n.op != "placeholder": + return True + return False + + for match in matches: + + # Skip overlapping matches + if overlaps_with_prev_match(match): + continue + + # Map replacement graph nodes to their copy in `original_graph` + val_map: Dict[Node, Node] = {} + + pattern_placeholders = [n for n in pattern_graph.nodes + if n.op == "placeholder"] + assert len(pattern_placeholders) + replacement_placeholders = [n for n in replacement_graph.nodes + if n.op == "placeholder"] + assert len(pattern_placeholders) == len(replacement_placeholders) + placeholder_map = {r : p for r, p + in zip(replacement_placeholders, pattern_placeholders)} + + # node from `original_graph` that matched with the output node + # in `pattern` + subgraph_output: Node = match.anchor + + def mark_node_as_replaced(n : Node) -> None: + if n not in match.nodes_map.values(): + return + for n_ in n.all_input_nodes: + mark_node_as_replaced(n_) + replaced_nodes.add(n) + + mark_node_as_replaced(subgraph_output) + + # Intialize `val_map` with mappings from placeholder nodes in + # `replacement` to their corresponding node in `original_graph` + for replacement_node in replacement_placeholders: + # Get the `original_graph` placeholder node + # corresponding to the current `replacement_node` + pattern_node = placeholder_map[replacement_node] + original_graph_node = match.nodes_map[pattern_node] + # Populate `val_map` + val_map[replacement_node] = original_graph_node + + # Copy the replacement graph over + with original_graph.inserting_before(subgraph_output): + copied_output = original_graph.graph_copy(replacement_graph, + val_map) + assert isinstance(copied_output, Node) + + # We only want to copy in the output node from `pattern` if we + # have an output-output match. Otherwise, we leave out the + # `pattern` output node so we don't have two outputs in the + # resultant graph + if subgraph_output.op != "output": + subgraph_output = subgraph_output.args[0] # type: ignore + subgraph_output.replace_all_uses_with(copied_output) + + # Erase the `pattern` nodes + for node in reversed(original_graph.nodes): + if len(node.users) == 0 and node.op != "output": + original_graph.erase_node(node) + + # Update the passed-in GraphModule to reflect the new state of + # `original_graph` + gm.recompile() diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py index 9b192dd5501fd..c7b7c8fda55e2 100644 --- a/torch/fx/symbolic_trace.py +++ b/torch/fx/symbolic_trace.py @@ -1,21 +1,17 @@ +import builtins import inspect from types import CodeType, FunctionType -from typing import Any, Optional, List +from typing import Any, Dict, NamedTuple, Optional, Set, Tuple, List, Callable, Union import torch +from torch._C import ScriptObject # type: ignore -from .node import Argument +from .node import Argument, map_aggregate from .graph import Graph from .graph_module import GraphModule -from .proxy import Proxy, _create_proxy, TracerBase +from .proxy import TracerBase, Proxy HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS -def _find_module(root: torch.nn.Module, m: torch.nn.Module): - for n, p in root.named_modules(): - if m is p: - return n - raise NameError('module is not installed as a submodule') - def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: co = fn.__code__ co_flags = co.co_flags & ~HAS_VARSTUFF @@ -39,15 +35,50 @@ def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: new_code = CodeType(*co_args) # type: ignore return FunctionType(new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__) - # we need to insert placeholder nodes for *args, and **kwargs, - # so we can't call this function normally, otherwise it would try to unpack them - # instead, let's make python think that args and kwargs are normay variables + # we need to insert placeholder nodes for *args and **kwargs + # we can't call this function normally, otherwise it would try to unpack them + # instead, let's make python think that args and kwargs are normal variables class Tracer(TracerBase): + """ + ``Tracer`` is the class that implements the symbolic tracing functionality + of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is equivalent + to ``Tracer().trace(m)``. + + Tracer can be subclassed to override various behaviors of the tracing + process. The different behaviors that can be overridden are described + in the docstrings of the methods on this class. + """ def __init__(self): super().__init__() - def create_arg(self, a: Any) -> Argument: + def create_arg(self, a: Any) -> 'Argument': + """ + A method to specify the behavior of tracing when preparing values to + be used as arguments to nodes in the ``Graph``. + + By default, the behavior includes: + + #. Iterate through collection types (e.g. tuple, list, dict) and recursively + call ``create_args`` on the elements. + #. Given a Proxy object, return a reference to the underlying IR ``Node`` + #. Given a non-Proxy Tensor object, emit IR for various cases: + + * For a Parameter, emit a ``get_attr`` node referring to that Parameter + * For a non-Parameter Tensor, store the Tensor away in a special + attribute referring to that attribute. + + This method can be overridden to support more types. + + Args: + + a (Any): The value to be emitted as an ``Argument`` in the ``Graph``. + + + Returns: + + The value ``a`` converted into the appropriate ``Argument`` + """ # The base tracer is used to construct Graphs when there is no associated # module hierarchy, so it can never create parameter references. # The default tracer adds the ability to refer to parameters when @@ -57,6 +88,17 @@ def create_arg(self, a: Any) -> Argument: if a is p: return self.create_node('get_attr', n, (), {}) raise NameError('parameter is not a member of this module') + elif isinstance(a, torch.Tensor): + for n_, p_ in self.root.named_buffers(): + if a is p_: + return self.create_node('get_attr', n_, (), {}) + + # For NamedTuple instances that appear literally as args, we emit + # a node to construct the NamedTuple and use that Node as the argument. + if isinstance(a, tuple) and hasattr(a, '_fields'): + args = tuple(self.create_arg(elem) for elem in a) + return self.create_node('call_function', a.__class__, args, {}) + # Tensors do not have a reliable string repr() from which they can be # constructed (and we probably don't want to rely on that, either), so # for any constant Tensor values we encounter, first search for if they @@ -64,33 +106,15 @@ def create_arg(self, a: Any) -> Argument: # a get_attr to retrieve that tensor. Otherwise, we'll store away the # tensor value into a special attribute on the Module s.t. we can # retrieve it with a get_attr. - if isinstance(a, torch.Tensor): - # TODO: slow - def search_for_tensor(m : torch.nn.Module) -> Optional[List[str]]: - """ - Search for a tensor value in the module's attributes. If it's - found, return the qualified name of that attribute, given the - previous `qualname_atoms`. If it's not found, recurse down into - child submodules. If it's not found there, return None - """ - for n, p in m.__dict__.items(): - if a is p: - return [n] - for n, c in m.named_children(): - maybe_result : Optional[List[str]] = search_for_tensor(c) - if maybe_result: - return [n] + maybe_result - return None - # Retrieve the qualname for an existing Tensor attribute - qualname_atoms : Optional[List[str]] = search_for_tensor(self.root) - qualname = '.'.join(qualname_atoms) if qualname_atoms else None + if isinstance(a, (torch.Tensor, ScriptObject)): + qualname : Optional[str] = self.tensor_attrs.get(a) # Tensor was not found in the Module hierarchy, stow it away in a # special attribute and set the qualname to refer to that if not qualname: i = 0 while True: - qualname = f'__tensor_constant{i}' + qualname = f'_tensor_constant{i}' if not hasattr(self.root, qualname): break i += 1 @@ -101,68 +125,354 @@ def search_for_tensor(m : torch.nn.Module) -> Optional[List[str]]: def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: """ - A method to specify whether a given `nn.Module` is a "leaf" module. + A method to specify whether a given ``nn.Module`` is a "leaf" module. Leaf modules are the atomic units that appear in - the IR, referenced by `call_module` calls. By default, + the IR, referenced by ``call_module`` calls. By default, Modules in the PyTorch standard library namespace (torch.nn) are leaf modules. All other modules are traced through and their constituent ops are recorded, unless specified otherwise via this parameter. - Args - m - The module itself - module_qualified_name - The path to root of this module. For example, - if you have a module hierarchy where submodule `foo` contains - submodule `bar`, which contains submodule `baz`, that module will - appear with the qualified name `foo.bar.baz` here. + Args: + + m (Module): The module being queried about + module_qualified_name (str): The path to root of this module. For example, + if you have a module hierarchy where submodule ``foo`` contains + submodule ``bar``, which contains submodule ``baz``, that module will + appear with the qualified name ``foo.bar.baz`` here. """ return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential) - def trace(self, root: torch.nn.Module) -> GraphModule: - self.root = root - self.graph = Graph() + def path_of_module(self, mod : torch.nn.Module) -> str: + """ + Helper method to find the qualified name of ``mod`` in the Module hierarchy + of ``root``. For example, if ``root`` has a submodule named ``foo``, which has + a submodule named ``bar``, passing ``bar`` into this function will return + the string "foo.bar". - fn = type(root).forward - assert isinstance(fn, FunctionType) - co = fn.__code__ + Args: + + mod (str): The ``Module`` to retrieve the qualified name for. + """ + for n, p in self.root.named_modules(): + if mod is p: + return n + raise NameError('module is not installed as a submodule') + + def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any: + """ + Method that specifies the behavior of this ``Tracer`` when it encounters + a call to an ``nn.Module`` instance. + + By default, the behavior is to check if the called module is a leaf module + via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to + ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through + the operations in its ``forward`` function. + + This method can be overridden to--for example--create nested traced + GraphModules, or any other behavior you would want while tracing across + ``Module`` boundaries. + ``Module`` boundaries. + + Args: + + m (Module): The module for which a call is being emitted + forward (Callable): The forward() method of the ``Module`` to be invoked + args (Tuple): args of the module callsite + kwargs (Dict): kwargs of the module callsite + + Return: + + The return value from the Module call. In the case that a ``call_module`` + node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever + value was returned from the ``Module`` invocation. + """ + module_qualified_name = self.path_of_module(m) + if not self.is_leaf_module(m, module_qualified_name): + return forward(*args, **kwargs) + return self.create_proxy('call_module', module_qualified_name, args, kwargs) + + def create_args_for_root(self, root_fn, is_module): + """ + Create ``placeholder`` nodes corresponding to the signature of the ``root`` + Module. This method introspects root's signature and emits those + nodes accordingly, also supporting ``*args`` and ``**kwargs``. + """ + # In some cases, a function or method has been decorated with a wrapper + # defined via ``functools.wraps``. In this case, the outer code object + # will likely not contain the actual parameters we care about, so unwrap + # the function to get to the innermost callable. + fn_for_analysis = inspect.unwrap(root_fn) + co = fn_for_analysis.__code__ total_args = co.co_argcount + co.co_kwonlyargcount names_iter = iter(co.co_varnames) - next(names_iter) # skip self - args : List[Any] = [root] - args.extend(self._proxy_placeholder(next(names_iter)) for name in range(1, total_args)) + args : List[Any] = [] + skip_arg_idx = 0 + if is_module: + if total_args == 0: + raise RuntimeError('``self`` argument cannot be part of *args expansion!') + skip_arg_idx = 1 + next(names_iter) # skip self + args.append(self.root) + + sig = inspect.signature(fn_for_analysis) + + def proxy_placeholder(name: str): + if name[0] == '*': + default = () # type: ignore + else: + param = sig.parameters[name] + default = () if param.default is inspect.Parameter.empty else (param.default,) # type: ignore + return self.create_proxy('placeholder', name, default, {}, + type_expr=fn_for_analysis.__annotations__.get(name, None)) + + args.extend(proxy_placeholder(next(names_iter)) for _ in range(skip_arg_idx, total_args)) if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF: + # TODO: type annotations for *args and **kwargs if co.co_flags & inspect.CO_VARARGS: - args.append(self._proxy_placeholder('*' + next(names_iter))) + args.append(proxy_placeholder('*' + next(names_iter))) if co.co_flags & inspect.CO_VARKEYWORDS: - args.append(self._proxy_placeholder('**' + next(names_iter))) - fn = _patch_function(fn, len(args)) + args.append(proxy_placeholder('**' + next(names_iter))) + root_fn = _patch_function(root_fn, len(args)) - orig_call = torch.nn.Module.__call__ + return root_fn, args + + def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph: + """ + Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` + can either be an ``nn.Module`` instance or a Python callable. + + + Args: + + root (Union[Module, Callable]): Either a ``Module`` or a function to be + traced through. + + Returns: + + A ``Graph`` representing the semantics of the passed-in ``root``. + """ + if isinstance(root, torch.nn.Module): + self.root = root + fn = type(root).forward + else: + self.root = torch.nn.Module() + fn = root + self.graph = Graph() + + # When we encounter a Tensor value that's not a parameter, we look if it + # is some other attribute on the model. Construct a dict mapping Tensor + # values to the qualified name here for efficiency. This is used downstream + # in create_arg + self.tensor_attrs : Dict[torch.Tensor, str] = {} + + def collect_tensor_attrs(m : torch.nn.Module, prefix_atoms : List[str]): + for k, v in m.__dict__.items(): + if isinstance(v, (torch.Tensor, ScriptObject)): + self.tensor_attrs[v] = '.'.join(prefix_atoms + [k]) + for k, v in m.named_children(): + collect_tensor_attrs(v, prefix_atoms + [k]) + + collect_tensor_attrs(self.root, []) + + assert isinstance(fn, FunctionType) + + fn, args = self.create_args_for_root(fn, isinstance(root, torch.nn.Module)) + + parameter_proxy_cache : Dict[str, Proxy] = {} # Reduce number of get_attr calls + + # Method dispatch on parameters is not recorded unless it's directly used. + # Thus, we need to insert a proxy when __getattr__ requests a parameter. + def module_getattr_wrapper(mod, attr): + attr_val = orig_getattr(mod, attr) + if isinstance(attr_val, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if attr_val is p: + if n not in parameter_proxy_cache: + parameter_proxy_cache[n] = self.create_proxy('get_attr', n, (), {}) + return parameter_proxy_cache[n] + return attr_val def module_call_wrapper(mod, *args, **kwargs): - module_qualified_name = _find_module(root, mod) - if not self.is_leaf_module(mod, module_qualified_name): + def forward(*args, **kwargs): return orig_call(mod, *args, **kwargs) - else: - return _create_proxy(self, 'call_module', module_qualified_name, args, kwargs) + + return self.call_module(mod, forward, args, kwargs) + + orig_call = torch.nn.Module.__call__ + orig_getattr = torch.nn.Module.__getattr__ + orig_fns : List[PatchedFn] = [] + try: + # Seems to be a mypy limitation: https://github.com/python/mypy/issues/2427 + torch.nn.Module.__getattr__ = module_getattr_wrapper # type: ignore torch.nn.Module.__call__ = module_call_wrapper - self.graph.output(self.create_arg(fn(*args))) + + _patch_wrapped_functions(orig_fns) + + self.create_node('output', 'output', (self.create_arg(fn(*args)),), {}, + type_expr=fn.__annotations__.get('return', None)) finally: + _unpatch_wrapped_functions(orig_fns) torch.nn.Module.__call__ = orig_call - return GraphModule(root, self.graph) - - def _proxy_placeholder(self, name: str) -> Proxy: - return Proxy(self.create_node('placeholder', name, (), {}), self) - -# Symbolic tracing API -# -# Given an `nn.Module` instance `root`, this function will return a `GraphModule` -# constructed by recording operations seen while tracing through `root`. -# -# Args: -# - root - the `nn.Module` instance to trace -def symbolic_trace(root : torch.nn.Module) -> GraphModule: - return Tracer().trace(root) + torch.nn.Module.__getattr__ = orig_getattr # type: ignore + return self.graph + +# List of pairs of (global dict, function name) functions +# to patch for the purposes of the wrap() API. +_wrapped_fns_to_patch : List[Tuple[dict, str]] = [] + +def _create_wrapped_func(orig_fn): + def wrapped(*args, **kwargs): + """ + Given an closed-over ``orig_function`` to invoke, search the args and kwargs for + a Proxy object. If there is one, emit a ``call_function`` node to preserve the + call to this leaf function directly. Otherwise, just return the results of + this function call, as this function is not being traced. + """ + proxy = None + + def find_proxy(x): + nonlocal proxy + if isinstance(x, Proxy): + proxy = x + + map_aggregate(args, find_proxy) + map_aggregate(kwargs, find_proxy) + + if proxy is not None: + return proxy.tracer.create_proxy('call_function', orig_fn, args, kwargs) + else: + return orig_fn(*args, **kwargs) + + return wrapped + +class PatchedFn(NamedTuple): + frame_dict : Dict[str, Any] + fn_name : str + orig_fn : Any + +# isinstance(orig_fn, NoneSentinel) if the original global namespace +# did not contain this function at the time of patching. This can +# occur, for example, when patching a builtin function +class PatchedFnNoneSentinel: + pass + +def _patch_wrapped_functions(orig_fns : List[PatchedFn]): + """ + Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap + the listed global functions in the `_create_wrapped_func` wrapper. Returns + a list of PatchedFn, which is a record specifiying a single function + entry that was patched and contains the original function for unpatching + + Note orig_fns is taken by reference and updated as we go to facilitate + reverting patching if this function itself throws an exception. + """ + # Set to deduplicate entries. Wrapping a function multiple times would + # be an error, since it would cause a `call_function` node for the + # wrapper to be emitted rather than the actual underlying function + # + # Use id(frame_dict) as a hashable identity here since none of the + # frame dicts should be destroyed during symtracing + processed_entries : Set[Tuple[int, str]] = set() + + for frame_dict, name in _wrapped_fns_to_patch: + if (id(frame_dict), name) in processed_entries: + continue + if name not in frame_dict and hasattr(builtins, name): + orig_fn = getattr(builtins, name) + orig_fns.append(PatchedFn(frame_dict, name, PatchedFnNoneSentinel())) + else: + orig_fn = frame_dict[name] + orig_fns.append(PatchedFn(frame_dict, name, orig_fn)) + + frame_dict[name] = _create_wrapped_func(orig_fn) + + processed_entries.add((id(frame_dict), name)) + +def _unpatch_wrapped_functions(orig_fns : List[PatchedFn]): + """ + Given the ``orig_fns`` dict that ``_patch_wrapped_functions``, + replace all of the global functions with the original global functions + that were there before symbolic tracing. + """ + for frame_dict, fn_name, orig_fn in orig_fns: + if isinstance(orig_fn, PatchedFnNoneSentinel): + del frame_dict[fn_name] + else: + frame_dict[fn_name] = orig_fn + +def wrap(fn_or_name : Union[str, Callable]): + """ + This function can be called at module-level scope to register fn_or_name as a "leaf function". + A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being + traced through:: + + # foo/bar/baz.py + def my_custom_function(x, y): + return x * x + y * y + + torch.fx.wrap('my_custom_function') + + def fn_to_be_traced(x, y): + # When symbolic tracing, the below call to my_custom_function will be inserted into + # the graph rather than tracing it. + return my_custom_function(x, y) + + This function can also equivalently be used as a decorator:: + + # foo/bar/baz.py + @torch.fx.wrap + def my_custom_function(x, y): + return x * x + y * y + + A wrapped function can be thought of a "leaf function", analogous to the concept of + "leaf modules", that is, they are functions that are left as calls in the FX trace + rather than traced through. + + Args: + + fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the + graph when it's called + """ + if callable(fn_or_name): + fn_name = fn_or_name.__code__.co_name + elif isinstance(fn_or_name, str): + fn_name = fn_or_name + else: + raise RuntimeError('Unsupported type for global function! Must be either a callable or ' + 'string name') + + if hasattr(fn_or_name, '__code__'): + assert not isinstance(fn_or_name, str) # to make mypy happy + fn_name = fn_or_name.__code__.co_name + else: + assert isinstance(fn_or_name, str), "fn_or_name must be a global function or string name" + fn_name = fn_or_name + + currentframe = inspect.currentframe() + assert currentframe is not None + f = currentframe.f_back + assert f is not None + if f.f_code.co_name != '': + raise NotImplementedError('wrap must be called at the top level of a module') + + _wrapped_fns_to_patch.append((f.f_globals, fn_name)) + +def symbolic_trace(root : Union[torch.nn.Module, Callable]) -> GraphModule: + """Symbolic tracing API + + Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule`` + constructed by recording operations seen while tracing through ``root``. + + Args: + root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted + into a Graph representation. + + Returns: + GraphModule: a Module created from the recorded operations from ``root``. + + """ + return GraphModule(root if isinstance(root, torch.nn.Module) else torch.nn.Module(), Tracer().trace(root)) diff --git a/torch/hub.py b/torch/hub.py index 416401717c404..49b0aa612bde5 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -39,6 +39,9 @@ def update(self, n): sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total))) sys.stderr.flush() + def close(self): + self.disable = True + def __enter__(self): return self @@ -155,43 +158,9 @@ def _get_cache_or_reload(github, force_reload, verbose=True): def _check_module_exists(name): - if sys.version_info >= (3, 4): - import importlib.util - return importlib.util.find_spec(name) is not None - elif sys.version_info >= (3, 3): - # Special case for python3.3 - import importlib.find_loader - return importlib.find_loader(name) is not None - else: - # NB: Python2.7 imp.find_module() doesn't respect PEP 302, - # it cannot find a package installed as .egg(zip) file. - # Here we use workaround from: - # https://stackoverflow.com/questions/28962344/imp-find-module-which-supports-zipped-eggs?lq=1 - # Also imp doesn't handle hierarchical module names (names contains dots). - try: - # 1. Try imp.find_module(), which searches sys.path, but does - # not respect PEP 302 import hooks. - import imp - result = imp.find_module(name) - if result: - return True - except ImportError: - pass - path = sys.path - for item in path: - # 2. Scan path for import hooks. sys.path_importer_cache maps - # path items to optional "importer" objects, that implement - # find_module() etc. Note that path must be a subset of - # sys.path for this to work. - importer = sys.path_importer_cache.get(item) - if importer: - try: - result = importer.find_module(name, [item]) - if result: - return True - except ImportError: - pass - return False + import importlib.util + return importlib.util.find_spec(name) is not None + def _check_dependencies(m): dependencies = _load_attr_from_module(m, VAR_DEPENDENCY) diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index fd61228f33790..bf3643638d4fa 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -9,6 +9,7 @@ _overload, _overload_method, ignore, + _isinstance, is_scripting, export, unused, @@ -43,7 +44,8 @@ from torch.jit._serialization import save, load from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph -from torch.jit._freeze import freeze +from torch.jit.cuda import stream +from torch.jit._freeze import freeze, optimize_frozen_module # For backwards compatibility _fork = fork @@ -70,5 +72,69 @@ def annotate(the_type, the_value): return the_value +def script_if_tracing(fn): + """ + Compiles ``fn`` when it is first called during tracing. ``torch.jit.script`` + has a non-negligible start up time when it is first called due to + lazy-initializations of many compiler builtins. Therefore you should not use + it in library code. However, you may want to have parts of your library work + in tracing even if they use control flow. In these cases, you should use + ``@torch.jit.script_if_tracing`` to substitute for + ``torch.jit.script``. + + Args: + fn: A function to compile. + + Returns: + If called during tracing, a :class:`ScriptFunction` created by `torch.jit.script` is returned. + Otherwise, the original function `fn` is returned. + """ + + return _script_if_tracing(fn) + + +# for torch.jit.isinstance +def isinstance(obj, target_type): + """ + This function provides for conatiner type refinement in TorchScript. It can refine + parameterized containers of the List, Dict, Tuple, and Optional types. E.g. ``List[str]``, + ``Dict[str, List[torch.Tensor]]``, ``Optional[Tuple[int,str,int]]``. It can also + refine basic types such as bools and ints that are available in TorchScript. + + Args: + obj: object to refine the type of + target_type: type to try to refine obj to + Returns: + ``bool``: True if obj was successfully refined to the type of target_type, + False otherwise with no new type refinement + + + Example (using ``torch.jit.isinstance`` for type refinement): + .. testcode:: + + import torch + from typing import Any, Dict, List + + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + + def forward(self, input: Any): # note the Any type + if torch.jit.isinstance(input, List[torch.Tensor]): + for t in input: + y = t.clamp(0, 0.5) + elif torch.jit.isinstance(input, Dict[str, str]): + for val in input.values(): + print(val) + + m = torch.jit.script(MyModule()) + x = [torch.rand(3,3), torch.rand(4,3)] + m(x) + y = {"key1":"val1","key2":"val2"} + m(y) + """ + return _isinstance(obj, target_type) + + if not torch._C._jit_init(): raise RuntimeError("JIT initialization failed") diff --git a/torch/jit/_async.py b/torch/jit/_async.py index 5e67167bd41a5..ae9684a0e229c 100644 --- a/torch/jit/_async.py +++ b/torch/jit/_async.py @@ -17,7 +17,7 @@ def fork(func, *args, **kwargs): - """ + r""" Creates an asynchronous task executing `func` and a reference to the value of the result of this execution. `fork` will return immediately, so the return value of `func` may not have been computed yet. To force completion @@ -31,7 +31,7 @@ def fork(func, *args, **kwargs): `fork` tasks will execute non-deterministicly. We recommend only spawning parallel fork tasks for pure functions that do not modify their inputs, module attributes, or global state. - Arguments: + Args: func (callable or torch.nn.Module): A Python function or `torch.nn.Module` that will be invoked. If executed in TorchScript, it will execute asynchronously, otherwise it will not. Traced invocations of fork will be captured in the IR. @@ -42,7 +42,8 @@ def fork(func, *args, **kwargs): Example (fork a free function): - .. testcode:: + .. code-block:: python + import torch from torch import Tensor def foo(a : Tensor, b : int) -> Tensor: @@ -60,16 +61,17 @@ def bar(a): Example (fork a module method): - .. testcode:: + .. code-block:: python + import torch from torch import Tensor - class SubMod(torch.nn.Module): + class AddMod(torch.nn.Module): def forward(self, a: Tensor, b : int): return a + b class Mod(torch.nn.Module): def __init__(self): super(self).__init__() - self.mod = SubMod() + self.mod = AddMod() def forward(self, input): fut = torch.jit.fork(self.mod, a, b=2) return torch.jit.wait(fut) @@ -81,10 +83,10 @@ def forward(self, input): def wait(future): - """ + r""" Forces completion of a `torch.jit.Future[T]` asynchronous task, returning the result of the task. See :func:`~fork` for docs and examples. - Arguments: + Args: func (torch.jit.Future[T]): an asynchronous task reference, created through `torch.jit.fork` Returns: `T`: the return value of the the completed task diff --git a/torch/jit/_builtins.py b/torch/jit/_builtins.py index 127d2e4a032a6..d1ad79ae4d1a0 100644 --- a/torch/jit/_builtins.py +++ b/torch/jit/_builtins.py @@ -67,6 +67,7 @@ (math.degrees, "aten::degrees"), (math.radians, "aten::radians"), (math.ldexp, "aten::ldexp"), + (torch._assert, "aten::_assert"), (torch.autograd.grad, "aten::grad"), (torch.autograd.backward, "aten::backward"), (torch._C._infer_size, "aten::_infer_size"), diff --git a/torch/jit/_freeze.py b/torch/jit/_freeze.py index 5c217ea17c1f4..98e53c6aee41b 100644 --- a/torch/jit/_freeze.py +++ b/torch/jit/_freeze.py @@ -10,7 +10,7 @@ from torch.jit._script import RecursiveScriptModule, ScriptModule -def freeze(mod, preserved_attrs: Optional[List[str]] = None): +def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize: bool = True): r""" Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned module's submodules, parameters, and attributes as constants in the TorchScript IR Graph. @@ -20,12 +20,17 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None): Freezing currently only accepts ScriptModules that are in eval mode. - Arguments: + Args: mod (:class:`ScriptModule`): a module to be frozen preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method. Attributes modified in preserved methods will also be preserved. + optimize (bool): If ``True``, a set of optimization passes will be run to prepare the graph for inference, + in addition to the graph cleanup that already occurs. The details of the optimizations can be found in + `torch.jit.optimize_frozen_module.` + + Returns: Frozen :class:`ScriptModule`. @@ -97,5 +102,42 @@ def forward(self, input): out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs)) RecursiveScriptModule._finalize_scriptmodule(out) + if optimize: + optimize_frozen_module(out) return out + + +def optimize_frozen_module(mod): + r""" + Runs a series of optimizations looking for patterns that occur in frozen graphs. + The current set of optimizations is: + - Conv -> Batchnorm folding + - Conv -> Add/Sub folding + - Conv -> Mul/Div folding + + Args: + mod (:class:`ScriptModule`): a frozen module to be optimized + + Returns: + None + + Note: + In rare occassions, this can result in slower execution. + + Example (Freezing a module with Conv->Batchnorm) + .. code-block:: python + import torch + in_channels, out_channels = 3, 32 + conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True) + bn = torch.nn.BatchNorm2d(out_channels, eps=.001) + mod = torch.nn.Sequential(conv, bn) + # set optimize to False here, by default freezing runs optimize_frozen_module + frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False) + # inspect frozen mod + assert "batch_norm" in str(frozen_mod.graph) + torch.jit.optimize_frozen_module(frozen_mod) + assert "batch_norm" not in str(frozen_mod.graph) + + """ + torch._C._jit_pass_optimize_frozen_graph(mod.graph) diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index 0eb423516f6fd..2d8d0c31cbbf3 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -104,6 +104,12 @@ def infer_concrete_type_builder(nn_module, share_types=True): concrete_type_builder.set_module_list() class_annotations = getattr(nn_module, '__annotations__', {}) + if isinstance(nn_module, (torch.quantization.QuantWrapper)): + class_annotations = {} + + # Get user-annotated ignored attributes. + user_annotated_ignored_attributes = getattr(nn_module, "__jit_ignored_attributes__", list()) + concrete_type_builder.add_ignored_attributes(user_annotated_ignored_attributes) # try to infer the type from type annotation or from the object itself def infer_type(name, item): @@ -123,6 +129,9 @@ def infer_type(name, item): added_names = set() for name, item in nn_module._parameters.items(): + if name in user_annotated_ignored_attributes: + continue + assert item is None or isinstance(item, torch.Tensor) attr_type = infer_type(name, item) # We currently have the invariant in various places in our code @@ -134,12 +143,18 @@ def infer_type(name, item): added_names.add(name) for name, item in nn_module._buffers.items(): + if name in user_annotated_ignored_attributes: + continue + assert item is None or isinstance(item, torch.Tensor) attr_type = infer_type(name, item) concrete_type_builder.add_attribute(name, attr_type, False, True) added_names.add(name) for name, item in nn_module._modules.items(): + if name in user_annotated_ignored_attributes: + continue + attr_type = infer_type(name, item) if item is None: # Modules can be None. We don't have direct support for optional @@ -205,6 +220,9 @@ def infer_type(name, item): # PyTorch adds a few more. Prevent these from getting compiled. continue + if name in user_annotated_ignored_attributes: + continue + if name in added_names: # Don't re-add anything we already added continue @@ -310,7 +328,7 @@ def get_module_concrete_type(nn_module, share_types=True): type is fetched from concrete_type_store. If it is False, a new concrete type is created without first searching concrete_type_store. - Arguments: + Args: nn_module: The original Python nn.Module that we are creating a ScriptModule for. share_types = Whether to share underlying JIT types between modules (if possible). @@ -338,7 +356,7 @@ def create_script_module(nn_module, stubs_fn, share_types=True): """ Creates a new ScriptModule from an nn.Module - Arguments: + Args: nn_module: The original Python nn.Module that we are creating a ScriptModule for. stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile. share_types: Whether to share underlying JIT types between modules (if possible). @@ -355,7 +373,7 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn): """ Convert an nn.Module to a RecursiveScriptModule. - Arguments: + Args: nn_module: The original Python nn.Module that we are creating a ScriptModule for. concrete_type: The fully initialized ConcreteType of the module. stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile. @@ -390,7 +408,7 @@ def init_fn(script_module): cpp_module.setattr(name, scripted) script_module._modules[name] = scripted - # 3. Copy @ignored/@unused methods from the original `nn_module` to the new ScriptModule. + # 3. Copy @ignored/@unused methods and attrs from the original `nn_module` to the new ScriptModule. # This ensures we can access these Python methods on the ScriptModule. for name in dir(nn_module): item = getattr(nn_module, name, None) @@ -398,6 +416,8 @@ def init_fn(script_module): unbound_function = getattr(type(nn_module), name) bound_method = unbound_function.__get__(script_module) setattr(script_module, name, bound_method) + elif concrete_type.is_ignored_attribute(name): + setattr(script_module, name, item) # For convenience, attach the concrete type to the new ScriptModule script_module._concrete_type = concrete_type @@ -541,6 +561,13 @@ def check_module_initialized(mod): raise RuntimeError("'{}' has not been initialized, did you forget to call 'super()'?" .format(torch.typename(type(mod)))) + # This is to avoid importing torch.distributed.nn + if not hasattr(mod, 'remote_parameters'): + for name, param in mod._parameters.items(): + if isinstance(param, torch.nn.parameter.UninitializedParameter): + raise RuntimeError("'{}' has uninitialized parameters {}. Did you forget to run a forward pass?" + .format(torch.typename(type(mod)), name)) + def infer_methods_to_compile(nn_module): """ Implements the default rules for which methods should act as starting @@ -618,7 +645,7 @@ def interface_script(mod_interface, nn_module): Makes a ScriptModule from an nn.Module, using the interface methods rule for determining which methods to compile. - Arguments: + Args: mod_interface: the interface type that the module have nn_module: The original Python nn.Module that we are creating a ScriptModule for. """ diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 4d28a5f2ad130..bdf00e21c5153 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -32,6 +32,8 @@ _set_jit_function_cache, _set_jit_overload_cache, ) +from torch.overrides import ( + has_torch_function, has_torch_function_unary, has_torch_function_variadic) torch._C.ScriptMethod.graph_for = _graph_for # type: ignore torch._C.ScriptFunction.graph_for = _graph_for # type: ignore @@ -276,7 +278,7 @@ class ScriptModule(with_metaclass(ScriptMeta, Module)): # type: ignore contain methods, attributes, parameters, and constants. These can be accessed the same as on a normal ``nn.Module``. """ - __ignored_properties__ = ['code', 'code_with_constants', 'graph', 'inlined_graph', 'original_name'] + __jit_unused_properties__ = ['code', 'code_with_constants', 'graph', 'inlined_graph', 'original_name'] def __init__(self): super(ScriptModule, self).__init__() @@ -379,7 +381,7 @@ def _construct(cpp_module, init_fn): object is properly finalized (and in the future we may take control of how the RecursiveScriptModule instance is created). - Arguments: + Args: cpp_module: The C++ Module that will hold the actual state of this RecursiveScriptModule instance. init_fn: Lambda that initializes the RecursiveScriptModule passed to it. @@ -409,7 +411,7 @@ def _reconstruct(self, cpp_module): """ Re-construct an instance of RecursiveScriptModule using an instance of a C++ module. - Arguments: + Args: cpp_module: The C++ module that this RecursiveScriptModule will be rebuilt around. """ self.__init__(cpp_module) # type: ignore @@ -443,7 +445,7 @@ def graph(self): Returns a string representation of the internal graph for the ``forward`` method. See :ref:`interpreting-graphs` for details. """ - return self.forward.graph + return self._c._get_method("forward").graph @property def inlined_graph(self): @@ -478,13 +480,13 @@ def code_with_constants(self): r = self.forward.code_with_constants return (r[0], ConstMap(r[1])) - def save(self, *args, **kwargs): + def save(self, f, **kwargs): r""" save(f, _extra_files={}) See :func:`torch.jit.save ` for details. """ - return self._c.save(*args, **kwargs) + return self._c.save(str(f), **kwargs) def _save_for_lite_interpreter(self, *args, **kwargs): r""" @@ -493,7 +495,7 @@ def _save_for_lite_interpreter(self, *args, **kwargs): Add (or update) the bytecode session to the script model. The updated model is used in lite interpreter for mobile applications. - Arguments: + Args: f: a string containing a file name. _extra_files: Map from filename to contents which will be stored as part of 'f'. @@ -741,6 +743,43 @@ class RecursiveScriptModule(ScriptModule): # type: ignore def __init__(self, arg=None): super().__init__() +def call_prepare_scriptable_func_impl(obj, memo): + if not isinstance(obj, torch.nn.Module): + return obj + + obj_id = id(obj) + + # If obj_id is in memo, obj has already been prepared or is being + # prepared in another call up the stack. + if obj_id in memo: + return memo[id(obj)] + + obj = obj.__prepare_scriptable__() if hasattr(obj, '__prepare_scriptable__') else obj # type: ignore + # Record obj in memo to avoid infinite recursion in the case of cycles in the module + # hierarchy when recursing below. + memo[obj_id] = obj + + new_obj_dict = {} + + for name in obj.__dict__: + sub_module = obj.__dict__.get(name) + if name == '_modules': + for k, v in sub_module.items(): + sub_module[k] = call_prepare_scriptable_func_impl(v, memo) + new_obj_dict[name] = sub_module + elif isinstance(sub_module, torch.nn.Module) and not isinstance(sub_module, ScriptModule): + new_obj_dict[name] = call_prepare_scriptable_func_impl(sub_module, memo) + else: + new_obj_dict[name] = sub_module + + for k, v in new_obj_dict.items(): + obj.__dict__[name] = v + + return obj + +def call_prepare_scriptable_func(obj): + memo: Dict[int, torch.nn.Module] = {} + return call_prepare_scriptable_func_impl(obj, memo) def script(obj, optimize=None, _frames_up=0, _rcb=None): r""" @@ -754,7 +793,7 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None): ``torch.jit.script`` can be used as a function for modules and functions, and as a decorator ``@torch.jit.script`` for :ref:`torchscript-classes` and functions. - Arguments: + Args: obj (callable, class, or ``nn.Module``): The ``nn.Module``, function, or class type to compile. @@ -894,6 +933,7 @@ def forward(self, input): return obj if isinstance(obj, torch.nn.Module): + obj = call_prepare_scriptable_func(obj) return torch.jit._recursive.create_script_module( obj, torch.jit._recursive.infer_methods_to_compile ) @@ -1081,3 +1121,6 @@ def _unwrap_optional(x): _register_builtin(_unwrap_optional, "aten::_unwrap_optional") _register_builtin(_jit_internal.is_scripting, "aten::is_scripting") +_register_builtin(has_torch_function, "aten::has_torch_function") +_register_builtin(has_torch_function_unary, "aten::has_torch_function") +_register_builtin(has_torch_function_variadic, "aten::has_torch_function") diff --git a/torch/jit/_serialization.py b/torch/jit/_serialization.py index d828ec8a0f1c3..7a551bb7da76b 100644 --- a/torch/jit/_serialization.py +++ b/torch/jit/_serialization.py @@ -33,7 +33,7 @@ def save(m, f, _extra_files=None): during loading. This is different from :func:`torch.load`'s semantics and may change in the future. - Arguments: + Args: m: A :class:`ScriptModule` to save. f: A file-like object (has to implement write and flush) or a string containing a file name. @@ -94,7 +94,7 @@ def load(f, map_location=None, _extra_files=None): because the run time system doesn't have certain devices), an exception is raised. - Arguments: + Args: f: a file-like object (has to implement read, readline, tell, and seek), or a string containing a file name map_location (string or torch.device): A simplified version of @@ -158,7 +158,7 @@ def load(f, map_location=None, _extra_files=None): cu = torch._C.CompilationUnit() if isinstance(f, str) or isinstance(f, pathlib.Path): - cpp_module = torch._C.import_ir_module(cu, f, map_location, _extra_files) + cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files) else: cpp_module = torch._C.import_ir_module_from_buffer( cu, f.read(), map_location, _extra_files diff --git a/torch/jit/_state.py b/torch/jit/_state.py index a6baba60f7caf..eb81c1e463a85 100644 --- a/torch/jit/_state.py +++ b/torch/jit/_state.py @@ -61,7 +61,6 @@ def enable(): _script_classes = {} def _add_script_class(cls, name): - cls.__torch_script_class__ = True global _script_classes _script_classes[name] = cls diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index e73785e15aeae..17be5c1ffd401 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -194,7 +194,7 @@ def verify(model, args, loss_fn=torch.sum, devices=None): parameters), so don't expect the model to come out exactly the same as what you passed in. - Arguments: + Args: model (compiled torch.nn.Module or function): the module/function to be verified. The module/function definition MUST have been decorated with `@torch.jit.compile`. @@ -626,7 +626,7 @@ def trace( invocations of the model. The tracer will try to emit warnings when doing something that may cause an incorrect trace to be produced. - Arguments: + Args: func (callable or torch.nn.Module): A Python function or `torch.nn.Module` that will be run with `example_inputs`. `func` arguments and return values must be tensors or (possibly nested) tuples that contain @@ -830,7 +830,7 @@ def trace_module( See :func:`torch.jit.trace ` for more information on tracing. - Arguments: + Args: mod (torch.nn.Module): A ``torch.nn.Module`` containing methods whose names are specified in ``inputs``. The given methods will be compiled as a part of a single `ScriptModule`. @@ -1010,7 +1010,6 @@ def check_unique(param): "TracedModules don't support parameter sharing between modules" ) id_set.add(param) - tmp_module.training = orig.training for name, param in orig._parameters.items(): @@ -1036,6 +1035,8 @@ def check_unique(param): ) for name, submodule in orig._modules.items(): + if submodule is None: + continue tmp_module._modules[name] = make_module( submodule, TracedModule, _compilation_unit=None ) @@ -1046,7 +1047,7 @@ def check_unique(param): self.__dict__["_name"] = type(orig).__name__ self.__dict__["_actual_script_module"] = script_module - for name in ("_parameters", "_buffers", "_modules"): + for name in ("_parameters", "_buffers", "_modules", "training"): delattr(self, name) def forward(self, *args, **kwargs): @@ -1076,23 +1077,13 @@ def _reconstruct(self, cpp_module): """ Re-construct an instance of TopLevelTracedModule using an instance of a C++ module. - Arguments: + Args: cpp_module: The C++ module that this TopLevelTracedModule will be rebuilt around. """ self.__dict__["_actual_script_module"]._reconstruct(cpp_module) def _script_if_tracing(fn): - """ - Compiles ``fn`` when it is first called during tracing. ``torch.jit.script`` - has a non-negligible start up time when it is first called due to - lazy-initializations of many compiler builtins. Therefore you should not use - it in library code. However, you may want to have parts of your library work - in tracing even if they use control flow. In these cases, you should use - ``@torch.jit._script_if_tracing`` to substitute for - ``torch.jit.script``. - """ - @functools.wraps(fn) def wrapper(*args, **kwargs): if not is_tracing(): @@ -1126,7 +1117,7 @@ def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False Tracing is guaranteed not to change the semantics of the function/module that is traced. - Arguments: + Args: f (torch.nn.Module or function): the function or module to be traced. args (tuple or Tensor): the positional arguments to pass to the diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py index d9fce627e52da..1c9b74526028b 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -6,10 +6,11 @@ from .._jit_internal import List, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \ is_optional, _qualified_name, Any, Future, is_future, is_ignored_fn from .._jit_internal import BroadcastingList1, BroadcastingList2, BroadcastingList3 # type: ignore +from ._state import _get_script_class from torch._C import TensorType, TupleType, FloatType, IntType, \ ListType, StringType, DictType, BoolType, OptionalType, ClassType, InterfaceType, AnyType, NoneType, \ - DeviceObjType, FutureType, EnumType + DeviceObjType, StreamObjType, FutureType, EnumType from textwrap import dedent @@ -271,7 +272,7 @@ def get_enum_value_type(e: Type[enum.Enum], loc): def try_ann_to_type(ann, loc): if ann is None: - return TensorType.get() + return TensorType.getInferred() if inspect.isclass(ann) and issubclass(ann, torch.Tensor): return TensorType.get() if is_tuple(ann): @@ -283,6 +284,11 @@ def try_ann_to_type(ann, loc): if is_dict(ann): key = try_ann_to_type(ann.__args__[0], loc) value = try_ann_to_type(ann.__args__[1], loc) + # Raise error if key or value is None + if key is None: + raise ValueError(f"Unknown type annotation: '{ann.__args__[0]}' at {loc.highlight()}") + if value is None: + raise ValueError(f"Unknown type annotation: '{ann.__args__[1]}' at {loc.highlight()}") return DictType(key, value) if is_optional(ann): if issubclass(ann.__args__[1], type(None)): @@ -313,19 +319,23 @@ def try_ann_to_type(ann, loc): return InterfaceType(_qualified_name(ann)) if ann is torch.device: return DeviceObjType.get() + if ann is torch.Stream: + return StreamObjType.get() if ann is torch.dtype: return IntType.get() # dtype not yet bound in as its own type if inspect.isclass(ann) and issubclass(ann, enum.Enum): - if not hasattr(ann, "__torch_script_class__"): + qualified_name = _qualified_name(ann) + if _get_script_class(qualified_name) is None: torch.jit._script._recursive_compile_class(ann, loc) return EnumType(_qualified_name(ann), get_enum_value_type(ann, loc), list(ann)) if inspect.isclass(ann): - if hasattr(ann, "__torch_script_class__"): - return ClassType(_qualified_name(ann)) + qualified_name = _qualified_name(ann) + if _get_script_class(qualified_name) is not None: + return ClassType(qualified_name) ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception) if torch._jit_internal.can_compile_class(ann) and not issubclass(ann, ignored_builtin_classes): torch.jit._script._recursive_compile_class(ann, loc) - return ClassType(_qualified_name(ann)) + return ClassType(qualified_name) # Maybe resolve a NamedTuple to a Tuple Type def fake_rcb(key): @@ -337,7 +347,7 @@ def ann_to_type(ann, loc): the_type = try_ann_to_type(ann, loc) if the_type is not None: return the_type - raise ValueError(f"Unknown type annotation: '{ann}'") + raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}") __all__ = [ diff --git a/torch/jit/cuda.py b/torch/jit/cuda.py new file mode 100644 index 0000000000000..16805301600bd --- /dev/null +++ b/torch/jit/cuda.py @@ -0,0 +1,182 @@ +# mypy: ignore-errors + +r""" +This package adds support for JIT compilation for CUDA Streams and events, +This is similar to API's available in the eager mode +:ref:`cuda-semantics` has more details about working with CUDA. +""" + +import torch +from typing import Optional, Any +from torch import device as _device + +def get_current_device_index() -> int: + r"""Checks if there are CUDA devices available and + returns the device index of the current default CUDA device. + Returns -1 in case there are no CUDA devices available. + + Arguments: ``None`` + """ + if torch.cuda.device_count() > 0: + return torch.cuda._current_device() + return -1 + +def get_device_index(device: Optional[_device] = None, optional: bool = False, allow_cpu: bool = False) -> int: + r"""Gets the device index from :attr:`device`, which can be a torch.device + object, a Python integer, or ``None``. + + If :attr:`device` is a torch.device object, returns the device index if it + is a CUDA device. Note that for a CUDA device without a specified index, + , this will return the current default CUDA device if :attr:`optional` is ``True``. + If :attr:`allow_cpu` is ``True``,CPU devices will be accepted and ``-1`` will be + returned in this case. + + If :attr:`device` is a Python integer, it is returned as is. + + If :attr:`device` is ``None``, this will return the current default CUDA + device if :attr:`optional` is ``True``. + """ + if device is None: + if optional: + return get_current_device_index() + else: + raise ValueError('Expected a torch.device with a specified index ' + f'or an integer, but got: {device}') + device_index = -1 + if isinstance(device, str): + device = torch.device(device) + + if isinstance(device, torch.device): + if not allow_cpu and device.type == 'cpu': + raise ValueError(f'Expected a non cpu device, but got: {device}') + device_index = -1 if device.type == 'cpu' else torch.cuda.device_index(device) + + if isinstance(device, int): + device_index = device + + return device_index + +class device(object): + r"""Context-manager that changes the selected device. + This is similar to device (torch.device or int), but has been + introduced for JIT compatibility. + Arguments: + device (torch.device or int): device index to select. It's a no-op if + this argument is a negative integer or ``None``. + """ + def __init__(self, device: Optional[_device]): + self.idx = -1 + self.prev_idx = -1 + self.device = device + + def __enter__(self): + self.idx = get_device_index(self.device, optional=True) + + if self.idx == -1: + return + self.prev_idx = torch.cuda._current_device() + + if self.prev_idx != self.idx: + torch.cuda._set_device(self.idx) + + def __exit__(self, type: Any, value: Any, traceback: Any): + if self.prev_idx != self.idx: + torch.cuda._set_device(self.prev_idx) + +class StreamContext(object): + r"""Context-manager that selects a given stream. + All CUDA kernels queued within its context will be enqueued on a selected + stream. + Arguments: + StreamContext (Stream): selected stream. This manager is a no-op if it's + ``None``. + .. note:: Streams are per-device. If the selected stream is not on the + current device, this function will also change the current device to + match the stream. + """ + cur_stream : Optional['torch.classes.cuda.Stream'] + + def __init__(self, stream: Optional['torch.classes.cuda.Stream']): + self.idx = -1 + self.stream = stream + # Initialize the below streams to default stream on the current device + self.device_index = get_current_device_index() + self.src_prev_stream = torch.cuda.default_stream(self.device_index) + self.dst_prev_stream = torch.cuda.default_stream(self.device_index) + + def __enter__(self): + self.idx = get_device_index(device=None, optional=True) + # If there is no CUDA device available, return + if self.idx == -1: + return + + # Local cur_stream variable for type refinement + cur_stream = self.stream + # Return if stream is None + if cur_stream is None: + return + self.src_prev_stream = torch.cuda.current_stream(self.idx) + # If the stream is not on the current device, then change the device + # and set the current stream on the device + if self.src_prev_stream.device_index() != cur_stream.device_index(): + with device(cur_stream.device()): + self.dst_prev_stream = torch.cuda.current_stream(cur_stream.device_index()) + torch.cuda._set_device(cur_stream.device_index()) + torch.cuda.set_stream(cur_stream) + + def __exit__(self, type: Any, value: Any, traceback: Any): + # Local cur_stream variable for type refinement + cur_stream = self.stream + # If stream is None or no CUDA device available, return + if cur_stream is None or self.idx == -1: + return + # If the stream was not on the current device, restore the previous stream on + # the destination device and also reset the current device to the previous device. + # Set the current stream on the device to the src_prev_stream + if self.src_prev_stream.device_index() != cur_stream.device_index(): + torch.cuda.set_stream(self.dst_prev_stream) + torch.cuda._set_device(self.idx) + torch.cuda.set_stream(self.src_prev_stream) + +def stream(stream: Optional['torch.classes.cuda.Stream']) -> StreamContext: + r"""Wrapper around the Context-manager that selects a given stream. + All CUDA kernels queued within its context will be enqueued on a selected + stream. + Arguments: + stream (Stream): selected stream. This manager is a no-op if it's + ``None``. + """ + return StreamContext(stream) + +def Stream(device: int = -1, priority: int = 0) -> 'torch.classes.cuda.Stream': + r"""Wrapper around a CUDA stream. + A CUDA stream is a linear sequence of execution that belongs to a specific + device, independent from other streams. See :ref:`cuda-semantics` for + details. + Arguments: + device(int, optional): a device on which to allocate + the stream. If :attr:`device` is ``None`` (default) or a negative + integer, this will use the current device. + priority(int, optional): priority of the stream. Can be either + -1 (high priority) or 0 (low priority). By default, streams have + priority 0. + .. note:: Although CUDA versions >= 11 support more than two levels of + priorities, in PyTorch, we only support two levels of priorities. + """ + return torch.classes.cuda.Stream(device, priority) + +def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False) -> 'torch.classes.cuda.Event': + r"""Wrapper around a CUDA event. + CUDA events are synchronization markers that can be used to monitor the + device's progress, to accurately measure timing, and to synchronize CUDA + streams. + Arguments: + enable_timing (bool, optional): indicates if the event should measure time + (default: ``False``) + blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``) + interprocess (bool): if ``True``, the event can be shared between processes + (default: ``False``) + .. _CUDA Event Documentation: + https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html + """ + return torch.classes.cuda.Event(enable_timing, blocking, interprocess) diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 4cfba50d0466a..c6e5bd9a7870d 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -4,6 +4,7 @@ import inspect import string from textwrap import dedent +from typing import List from torch._C._jit_tree_views import ( ClassDef, Ident, Stmt, Decl, Def, Var, EmptyTypeAnnotation, Param, ExprStmt, Assign, @@ -13,6 +14,7 @@ ListLiteral, TupleLiteral, DictLiteral, Const, StringLiteral, ListComp, Attribute, BinOp, UnaryOp, SliceExpr, Subscript, TernaryIf, With, WithItem, Property, + DictComp, ) from torch._utils_internal import get_source_lines_and_file @@ -132,7 +134,7 @@ def get_class_properties(cls, self_name): """ Get a list of Property objects representing the properties of a class. - Arguments: + Args: cls: The class to get properties of. self_name: The name of the class that the properties should belong to. Returns: @@ -142,12 +144,12 @@ def get_class_properties(cls, self_name): props = inspect.getmembers( cls, predicate=lambda m: isinstance(m, property)) # Any property that should not compiled must be in this list on the Module. - ignored_properties = getattr(cls, "__ignored_properties__", []) + unused_properties = getattr(cls, "__jit_unused_properties__", []) # Create Property TreeView objects from inspected property objects. properties = [] for prop in props: - if prop[0] not in ignored_properties: + if prop[0] not in unused_properties and not should_drop(prop[1].fget): getter = get_jit_def(prop[1].fget, f"__{prop[0]}_getter", self_name=self_name) setter = get_jit_def(prop[1].fset, f"__{prop[0]}_setter", self_name=self_name) if prop[1].fset else None properties.append(Property(getter.range(), Ident(getter.range(), prop[0]), getter, setter)) @@ -164,9 +166,14 @@ def get_jit_class_def(cls, self_name): and not is_static_fn(cls, m.__name__) and m.__name__ in cls.__dict__ ) + + def is_classmethod(fn): + return inspect.ismethod(fn) and getattr(fn, "__self__", None) == cls + methods = [get_jit_def(method[1], method[0], - self_name=self_name) for method in methods] + self_name=self_name, + is_classmethod=is_classmethod(method[1])) for method in methods] properties = get_class_properties(cls, self_name) @@ -179,11 +186,47 @@ def get_jit_class_def(cls, self_name): return build_class_def(ctx, py_ast.body[0], methods, properties, self_name) -def get_jit_def(fn, def_name, self_name=None): +def normalize_source_lines(sourcelines: List[str]) -> List[str]: + """ + This helper function accepts a list of source lines. It finds the + indentation level of the function definition (`def`), then it indents + all lines in the function body to a point at or greater than that + level. This allows for comments and continued string literals that + are at a lower indentation than the rest of the code. + Args: + sourcelines: function source code, separated into lines by + the '\n' character + Returns: + A list of source lines that have been correctly aligned + """ + + def remove_prefix(text, prefix): + return text[text.startswith(prefix) and len(prefix):] + + # Find the line and line number containing the function definition + for i, l in enumerate(sourcelines): + if l.lstrip().startswith("def"): + idx = i + break + fn_def = sourcelines[idx] + + # Get a string representing the amount of leading whitespace + whitespace = fn_def.split("def")[0] + + # Add this leading whitespace to all lines before and after the `def` + aligned_prefix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]] + aligned_suffix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1:]] + + # Put it together again + aligned_prefix.append(fn_def) + return aligned_prefix + aligned_suffix + + +def get_jit_def(fn, def_name, self_name=None, is_classmethod=False): """ Build a JIT AST (TreeView) from the given function. - Arguments: + Args: fn: A function object to compile def_name: The name to give to the resulting AST object. This is not always the same as `fn.__name__`, for example: @@ -195,21 +238,28 @@ def _forward(self): self_name: If this function is a method, what the type name of `self` is. """ sourcelines, file_lineno, filename = get_source_lines_and_file(fn, torch._C.ErrorReport.call_stack()) + sourcelines = normalize_source_lines(sourcelines) source = ''.join(sourcelines) dedent_src = dedent(source) py_ast = ast.parse(dedent_src) if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef): - raise RuntimeError("Expected a single top-level function") + raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}") leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0]) type_line = torch.jit.annotations.get_type_line(source) ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True) fn_def = py_ast.body[0] + if is_classmethod: + arg_name = fn_def.args.args[0].arg + # Insert a statement that assigns the first argument to the class + assign_stmt = ast.parse(f"{arg_name} = {self_name}").body[0] + fn_def.body.insert(0, assign_stmt) + # Swap out the function signature and body if it is unused if should_drop(fn): unused_fn_def = ast.parse("def unused_fn(self: Any):\n\traise RuntimeError(\"Cannot call @unused methods\")") if len(unused_fn_def.body) != 1 or not isinstance(unused_fn_def.body[0], ast.FunctionDef): - raise RuntimeError("Expected a single top-level function") + raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}") unused_def = unused_fn_def.body[0] fn_def.body = unused_def.body # kwarg/vararg not supported by `build_def` @@ -350,6 +400,12 @@ class StmtBuilder(Builder): ast.Mult: '*', ast.Div: '/', ast.Mod: '%', + ast.BitOr: '|', + ast.BitAnd: '&', + ast.BitXor: '^', + ast.LShift: '<<', + ast.RShift: '>>', + ast.Pow: '**', } @staticmethod @@ -365,7 +421,7 @@ def build_Expr(ctx, stmt): @staticmethod def build_Assign(ctx, stmt): rhs = build_expr(ctx, stmt.value) - lhs = list(map(lambda x: build_expr(ctx, x), stmt.targets)) + lhs = [build_expr(ctx, x) for x in stmt.targets] return Assign(lhs, rhs) @staticmethod @@ -379,12 +435,9 @@ def build_AnnAssign(ctx, stmt): @staticmethod def build_Delete(ctx, stmt): - if len(stmt.targets) > 1: - source_range = ctx.make_range(stmt.lineno, stmt.col_offset, - stmt.col_offset + len("del")) - raise NotSupportedError( - source_range, 'del with more than one operand is not supported') - return Delete(build_expr(ctx, stmt.targets[0])) + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("del")) + + return Delete(r, [build_expr(ctx, target) for target in stmt.targets]) @staticmethod def build_Return(ctx, stmt): @@ -657,11 +710,10 @@ def build_SliceExpr(ctx, base, slice_expr): return SliceExpr(base.range(), lower, upper, step) def build_Index(ctx, base, index_expr): - if isinstance(index_expr.value, ast.Tuple) or \ - isinstance(index_expr.value, ast.List): + if isinstance(index_expr.value, ast.Tuple): raise NotSupportedError(base.range(), "slicing multiple dimensions with " - "sequences not supported yet") + "tuples not supported yet") return build_expr(ctx, index_expr.value) def build_ExtSlice(ctx, base, extslice): @@ -686,9 +738,7 @@ def build_ExtSlice(ctx, base, extslice): if isinstance(expr.slice.value, ast.Tuple): # N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k] # XXX: Indexing using a list is **different**! It triggers advanced indexing. - indices = [] - for index_expr in expr.slice.value.elts: - indices.append(build_expr(ctx, index_expr)) + indices = [build_expr(ctx, index_expr) for index_expr in expr.slice.value.elts] return Subscript(base, indices) else: return Subscript(base, [build_expr(ctx, expr.slice.value)]) @@ -696,6 +746,17 @@ def build_ExtSlice(ctx, base, extslice): return Subscript(base, [build_SliceExpr(ctx, base, expr.slice)]) elif sub_type is ast.ExtSlice: return Subscript(base, build_ExtSlice(ctx, base, expr.slice)) + elif sys.version_info >= (3, 9): # In Python3.9 array indicies are not wrapped in ast.Index + if sub_type is ast.Tuple: + # N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k] + indices = [] + for index_expr in expr.slice.elts: + if isinstance(index_expr, ast.Slice): + indices.append(build_SliceExpr(ctx, base, index_expr)) + else: + indices.append(build_expr(ctx, index_expr)) + return Subscript(base, indices) + return Subscript(base, [build_expr(ctx, expr.slice)]) else: # Ellipsis (can only happen in Python 2) raise NotSupportedError(base.range(), "ellipsis is not supported") @@ -767,18 +828,34 @@ def build_JoinedStr(ctx, expr): @staticmethod def build_ListComp(ctx, stmt): r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset) - if (len(stmt.generators) > 1): - raise NotSupportedError(r, "multiple comprehension generators not supported yet") + if (len(stmt.generators) != 1): + raise NotSupportedError(r, "Only a single generator is currently supported") if (len(stmt.generators[0].ifs) != 0): - raise NotSupportedError(r, "comprehension ifs not supported yet") + raise NotSupportedError(r, "Comprehension ifs are not supported yet") elt_expr = build_expr(ctx, stmt.elt) target_expr = build_expr(ctx, stmt.generators[0].target) - iter_expr = build_expr(ctx, stmt.generators[0].iter) + return ListComp(r, elt_expr, target_expr, iter_expr) + @staticmethod + def build_DictComp(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset) + if (len(stmt.generators) != 1): + raise NotSupportedError(r, "Only a single generator is currently supported") + + if (len(stmt.generators[0].ifs) != 0): + raise NotSupportedError(r, "Comprehension ifs are not supported yet") + + key_expr = build_expr(ctx, stmt.key) + value_expr = build_expr(ctx, stmt.value) + target_expr = build_expr(ctx, stmt.generators[0].target) + iter_expr = build_expr(ctx, stmt.generators[0].iter) + + return DictComp(r, key_expr, value_expr, target_expr, iter_expr) + @staticmethod def build_Starred(ctx, expr): r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1) diff --git a/torch/jit/mobile/__init__.py b/torch/jit/mobile/__init__.py index 6afbf6f2cdb6b..4356400cb4474 100644 --- a/torch/jit/mobile/__init__.py +++ b/torch/jit/mobile/__init__.py @@ -10,7 +10,7 @@ def _load_for_lite_interpreter(f, map_location=None): Load a :class:`LiteScriptModule` saved with :func:`torch.jit._save_for_lite_interpreter` - Arguments: + Args: f: a file-like object (has to implement read, readline, tell, and seek), or a string containing a file name map_location: a string or torch.device used to dynamically remap diff --git a/torch/jit/quantized.py b/torch/jit/quantized.py index 615741f38da7d..a3dfeba640c86 100644 --- a/torch/jit/quantized.py +++ b/torch/jit/quantized.py @@ -1,13 +1,12 @@ -import torch - -from typing import Tuple, Optional, List - from torch import Tensor, _VF # noqa: F401 - from torch.nn.utils.rnn import PackedSequence +import torch import warnings +from typing import List, Optional, Tuple + + class QuantizedLinear(torch.jit.ScriptModule): __constants__ = ['scale', 'zero_point'] @@ -130,8 +129,7 @@ def check_forward_input(self, input): input.size(1), self.input_size)) @torch.jit.script_method - def check_forward_hidden(self, input, hx, hidden_label=''): - # type: (Tensor, Tensor, str) -> None + def check_forward_hidden(self, input: Tensor, hx: Tensor, hidden_label: str = '') -> None: if input.size(0) != hx.size(0): raise RuntimeError( "Input batch size {} doesn't match hidden{} batch size {}".format( @@ -169,8 +167,7 @@ def __init__(self, other): self.nonlinearity = other.nonlinearity @torch.jit.script_method - def forward(self, input, hx=None): - # type: (Tensor, Optional[Tensor]) -> Tensor + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: self.check_forward_input(input) if hx is None: hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) @@ -201,8 +198,7 @@ def __init__(self, other): super(QuantizedLSTMCell, self).__init__(other) @torch.jit.script_method - def forward(self, input, hx=None): - # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor] + def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]: self.check_forward_input(input) if hx is None: zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) @@ -222,8 +218,7 @@ def __init__(self, other): super(QuantizedGRUCell, self).__init__(other) @torch.jit.script_method - def forward(self, input, hx=None): - # type: (Tensor, Optional[Tensor]) -> Tensor + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: self.check_forward_input(input) if hx is None: hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) @@ -236,8 +231,7 @@ def forward(self, input, hx=None): ) -def apply_permutation(tensor, permutation, dim=1): - # type: (Tensor, Tensor, int) -> Tensor +def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: return tensor.index_select(dim, permutation) @@ -303,8 +297,7 @@ def get_weight_bias(ihhh): self.all_weights.append(cell_params) @torch.jit.script_method - def check_input(self, input, batch_sizes): - # type: (Tensor, Optional[Tensor]) -> None + def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None: expected_input_dim = 2 if batch_sizes is not None else 3 if input.dim() != expected_input_dim: raise RuntimeError( @@ -316,8 +309,7 @@ def check_input(self, input, batch_sizes): self.input_size, input.size(-1))) @torch.jit.script_method - def get_expected_hidden_size(self, input, batch_sizes): - # type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int] + def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]: if batch_sizes is not None: mini_batch = int(batch_sizes[0]) else: @@ -328,21 +320,19 @@ def get_expected_hidden_size(self, input, batch_sizes): return expected_hidden_size @torch.jit.script_method - def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'): - # type: (Tensor, Tuple[int, int, int], str) -> None + def check_hidden_size(self, hx: Tensor, expected_hidden_size: Tuple[int, int, int], + msg: str = 'Expected hidden size {}, got {}') -> None: if hx.size() != expected_hidden_size: raise RuntimeError(msg.format(expected_hidden_size, list(hx.size()))) @torch.jit.script_method - def check_forward_args(self, input, hidden, batch_sizes): - # type: (Tensor, Tensor, Optional[Tensor]) -> None + def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]) -> None: self.check_input(input, batch_sizes) expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) self.check_hidden_size(hidden, expected_hidden_size, msg='Expected hidden size {}, got {}') @torch.jit.script_method - def permute_hidden(self, hx, permutation): - # type: (Tensor, Optional[Tensor]) -> Tensor + def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor: if permutation is None: return hx return apply_permutation(hx, permutation) @@ -355,8 +345,9 @@ def __init__(self, other, dtype): super(QuantizedLSTM, self).__init__(other, dtype) @torch.jit.script_method - def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices): - # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa + def forward_impl(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]], batch_sizes: Optional[Tensor], + max_batch_size: int, sorted_indices: Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + # noqa if hx is None: num_directions = 2 if self.bidirectional else 1 zeros = torch.zeros(self.num_layers * num_directions, @@ -379,8 +370,7 @@ def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices): return output, hidden @torch.jit.script_method - def forward_tensor(self, input, hx=None): - # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + def forward_tensor(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: batch_sizes = None max_batch_size = input.size(0) if self.batch_first else input.size(1) sorted_indices = None @@ -391,8 +381,8 @@ def forward_tensor(self, input, hx=None): return output, self.permute_hidden(hidden, unsorted_indices) @torch.jit.script_method - def forward_packed(self, input, hx=None): - # type: (PackedSequence, Optional[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]] # noqa + def forward_packed(self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None + ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]: input, batch_sizes, sorted_indices, unsorted_indices = input max_batch_size = batch_sizes[0] max_batch_size = int(max_batch_size) @@ -404,15 +394,13 @@ def forward_packed(self, input, hx=None): @torch.jit.script_method - def permute_hidden(self, hx, permutation): - # type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor] + def permute_hidden(self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]) -> Tuple[Tensor, Tensor]: if permutation is None: return hx return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation) @torch.jit.script_method - def check_forward_args(self, input, hidden, batch_sizes): - # type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None + def check_forward_args(self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]) -> None: self.check_input(input, batch_sizes) expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) @@ -432,8 +420,9 @@ class QuantizedGRU(QuantizedRNNBase): __overloads__ = {'forward': ['forward_packed', 'forward_tensor']} @torch.jit.script_method - def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices): - # type: (Tensor, Optional[Tensor], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tensor] # noqa + def forward_impl(self, input: Tensor, hx: Optional[Tensor], batch_sizes: Optional[Tensor], max_batch_size: int, + sorted_indices: Optional[Tensor]) -> Tuple[Tensor, Tensor]: + # noqa if hx is None: num_directions = 2 if self.bidirectional else 1 hx = torch.zeros(self.num_layers * num_directions, @@ -459,8 +448,7 @@ def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices): return output, hidden @torch.jit.script_method - def forward_tensor(self, input, hx=None): - # type: (Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor] + def forward_tensor(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: batch_sizes = None max_batch_size = input.size(0) if self.batch_first else input.size(1) sorted_indices = None @@ -470,8 +458,7 @@ def forward_tensor(self, input, hx=None): return output, self.permute_hidden(hidden, unsorted_indices) @torch.jit.script_method - def forward_packed(self, input, hx=None): - # type: (PackedSequence, Optional[Tensor]) -> Tuple[PackedSequence, Tensor] + def forward_packed(self, input: PackedSequence, hx: Optional[Tensor] = None) -> Tuple[PackedSequence, Tensor]: input, batch_sizes, sorted_indices, unsorted_indices = input max_batch_size = batch_sizes[0] max_batch_size = int(max_batch_size) diff --git a/torch/jit/unsupported_tensor_ops.py b/torch/jit/unsupported_tensor_ops.py index b8e0f48ad04ac..5babb405280fb 100644 --- a/torch/jit/unsupported_tensor_ops.py +++ b/torch/jit/unsupported_tensor_ops.py @@ -35,8 +35,8 @@ def func(x): else: properties.append(attr) - mapped_methods = map(lambda x: "\t* :meth:`~torch.Tensor." + x + r"`", methods) - mapped_properties = map(lambda x: "\t* :attr:`~torch.Tensor." + x + r"`", properties) + mapped_methods = ("\t* :meth:`~torch.Tensor." + x + r"`" for x in methods) + mapped_properties = ("\t* :attr:`~torch.Tensor." + x + r"`" for x in properties) return "\n".join(mapped_methods), "\n".join(mapped_properties) diff --git a/torch/lib/c10d/CMakeLists.txt b/torch/lib/c10d/CMakeLists.txt index 68fe49f411f5b..4e72e2e32fbf5 100644 --- a/torch/lib/c10d/CMakeLists.txt +++ b/torch/lib/c10d/CMakeLists.txt @@ -16,7 +16,7 @@ else() endif() if(USE_TBB) -include_directories(${TBB_ROOT_DIR}/include) + include_directories(${TBB_ROOT_DIR}/include) endif() if(USE_GLOO) @@ -45,15 +45,17 @@ endfunction() set(C10D_SRCS FileStore.cpp - HashStore.cpp + PrefixStore.cpp ProcessGroup.cpp - ProcessGroupRoundRobin.cpp Store.cpp - PrefixStore.cpp TCPStore.cpp Utils.cpp ) +if(NOT WIN32) + list(APPEND C10D_SRCS HashStore.cpp ProcessGroupRoundRobin.cpp) +endif() + set(C10D_LIBS torch) if(USE_C10D_NCCL) @@ -77,14 +79,17 @@ endif() add_library(c10d STATIC ${C10D_SRCS}) set_property(TARGET c10d PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET c10d PROPERTY CXX_STANDARD 14) -target_compile_options(c10d PUBLIC - -Wall - -Wextra - -Wno-unused-parameter - -Wno-missing-field-initializers - -Wno-write-strings - -Wno-unknown-pragmas - ) + +if(NOT MSVC) + target_compile_options(c10d PUBLIC + -Wall + -Wextra + -Wno-unused-parameter + -Wno-missing-field-initializers + -Wno-write-strings + -Wno-unknown-pragmas + ) +endif() add_dependencies(c10d torch) @@ -102,8 +107,6 @@ target_include_directories(c10d PUBLIC # For target_include_directories(c10d PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..) -# For torch/csrc/utils/hash.h -target_include_directories(c10d PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../..) if(USE_C10D_NCCL) target_compile_definitions(c10d INTERFACE USE_C10D_NCCL) @@ -118,7 +121,6 @@ if(USE_C10D_GLOO) endif() copy_header(FileStore.hpp) -copy_header(HashStore.hpp) copy_header(PrefixStore.hpp) copy_header(ProcessGroup.hpp) copy_header(Store.hpp) @@ -129,6 +131,12 @@ if(USE_GLOO) copy_header(ProcessGroupGloo.hpp) copy_header(GlooDeviceFactory.hpp) endif() +if(NOT WIN32) + copy_header(HashStore.hpp) + copy_header(UnixSockUtils.hpp) +else() + copy_header(WinSockUtils.hpp) +endif() if(USE_C10D_NCCL) copy_header(ProcessGroupNCCL.hpp) diff --git a/torch/lib/c10d/FileStore.cpp b/torch/lib/c10d/FileStore.cpp index 55346e0fa6358..335efec902460 100644 --- a/torch/lib/c10d/FileStore.cpp +++ b/torch/lib/c10d/FileStore.cpp @@ -3,9 +3,16 @@ #include #include #include -#include #include + +#ifdef _WIN32 +#include +#include +#include +#else +#include #include +#endif #include #include @@ -16,11 +23,47 @@ #include #include +#include + #define SYSASSERT(rv, ...) \ if ((rv) < 0) { \ throw std::system_error(errno, std::system_category(), ##__VA_ARGS__); \ } +#ifdef _WIN32 +#define LOCK_EX 0x00000001 +#define LOCK_SH 0x00000010 +#define LOCK_UN 0x00000100 + +int flock_(int fd, int op) { + HANDLE hdl = (HANDLE) _get_osfhandle(fd); + DWORD low = 1, high = 0; + OVERLAPPED offset = {0, 0, 0, 0, NULL}; + + if (hdl < 0) + return -1; + + switch (op) { + case LOCK_EX: + if (LockFileEx(hdl, LOCKFILE_EXCLUSIVE_LOCK, 0, low, high, &offset)) + return 0; + break; + case LOCK_SH: + if (LockFileEx(hdl, 0, 0, low, high, &offset)) + return 0; + break; + case LOCK_UN: + if(UnlockFileEx(hdl, 0, low, high, &offset) != 0) + return 0; + break; + default: + break; + } + errno = EINVAL; + return -1; +} +#endif + namespace c10d { namespace { @@ -79,7 +122,11 @@ class Lock { int fd_{-1}; void flock(int operation) { +#ifdef _WIN32 + auto rv = syscall(std::bind(::flock_, fd_, operation)); +#else auto rv = syscall(std::bind(::flock, fd_, operation)); +#endif SYSASSERT(rv, "flock"); } }; @@ -92,7 +139,11 @@ class File { std::chrono::milliseconds timeout) { const auto start = std::chrono::steady_clock::now(); while (true) { +#ifdef _WIN32 + fd_ = syscall(std::bind(::open, path.c_str(), flags | _O_BINARY, _S_IREAD | _S_IWRITE)); +#else fd_ = syscall(std::bind(::open, path.c_str(), flags, 0644)); +#endif // Only retry when the file doesn't exist, since we are waiting for the // file to be created in this case to address the following issue: // https://github.com/pytorch/pytorch/issues/13750 @@ -303,6 +354,18 @@ int64_t FileStore::add(const std::string& key, int64_t value) { return addHelper(regKey, value); } +int64_t FileStore::getNumKeys() { + std::unique_lock l(activeFileOpLock_); + File file(path_, O_RDONLY, timeout_); + auto lock = file.lockShared(); + pos_ = refresh(file, pos_, cache_); + return cache_.size(); +} + +bool FileStore::deleteKey(const std::string& /* unused */) { + TORCH_CHECK(false, "deleteKey not implemented for FileStore"); +} + bool FileStore::check(const std::vector& keys) { std::unique_lock l(activeFileOpLock_); File file(path_, O_RDONLY, timeout_); diff --git a/torch/lib/c10d/FileStore.hpp b/torch/lib/c10d/FileStore.hpp index dfca47ba7cc4e..aa5d9946e5b3c 100644 --- a/torch/lib/c10d/FileStore.hpp +++ b/torch/lib/c10d/FileStore.hpp @@ -21,6 +21,10 @@ class FileStore : public Store { int64_t add(const std::string& key, int64_t value) override; + int64_t getNumKeys() override; + + bool deleteKey(const std::string& key) override; + bool check(const std::vector& keys) override; void wait(const std::vector& keys) override; diff --git a/torch/lib/c10d/GlooDeviceFactory.cpp b/torch/lib/c10d/GlooDeviceFactory.cpp index 70c3c2bb7a31d..dca6b03eb9ddc 100644 --- a/torch/lib/c10d/GlooDeviceFactory.cpp +++ b/torch/lib/c10d/GlooDeviceFactory.cpp @@ -36,16 +36,16 @@ C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING( #if GLOO_HAVE_TRANSPORT_TCP static std::shared_ptr<::gloo::transport::Device> makeTCPDevice( - const std::string& interface, + const std::string& interfaceName, const std::string& hostname) { TORCH_CHECK( - !interface.empty() || !hostname.empty(), + !interfaceName.empty() || !hostname.empty(), "GlooDeviceFactory::makeTCPDevice(): interface or hostname " "can't be empty"); ::gloo::transport::tcp::attr attr; - if (!interface.empty()) { - attr.iface = interface; + if (!interfaceName.empty()) { + attr.iface = interfaceName; } else { attr.hostname = hostname; } @@ -61,16 +61,16 @@ C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP, makeTCPDevice); #if GLOO_HAVE_TRANSPORT_UV static std::shared_ptr<::gloo::transport::Device> makeUVDevice( - const std::string& interface, + const std::string& interfaceName, const std::string& hostname) { TORCH_CHECK( - !interface.empty() || !hostname.empty(), + !interfaceName.empty() || !hostname.empty(), "GlooDeviceFactory::makeUVDevice(): interface or hostname " "can't be empty"); ::gloo::transport::uv::attr attr; - if (!interface.empty()) { - attr.iface = interface; + if (!interfaceName.empty()) { + attr.iface = interfaceName; } else { attr.hostname = hostname; } @@ -81,23 +81,28 @@ static std::shared_ptr<::gloo::transport::Device> makeUVDevice( // the flexibility of other application to override by priority. Register // UV to `UV` for env "GLOO_DEVICE_TRANSPORT" override. C10_REGISTER_CREATOR(GlooDeviceRegistry, APPLE, makeUVDevice); +C10_REGISTER_CREATOR(GlooDeviceRegistry, WIN32, makeUVDevice); C10_REGISTER_CREATOR(GlooDeviceRegistry, UV, makeUVDevice); #endif static const char* glooDeviceTransport = getenv("GLOO_DEVICE_TRANSPORT"); std::shared_ptr<::gloo::transport::Device> GlooDeviceFactory:: - makeDeviceForInterface(const std::string& interface) { + makeDeviceForInterface(const std::string& interfaceName) { if (glooDeviceTransport) { - return GlooDeviceRegistry()->Create(glooDeviceTransport, interface, ""); + return GlooDeviceRegistry()->Create(glooDeviceTransport, interfaceName, ""); } #ifdef __linux__ - return GlooDeviceRegistry()->Create("LINUX", interface, ""); + return GlooDeviceRegistry()->Create("LINUX", interfaceName, ""); #endif #ifdef __APPLE__ - return GlooDeviceRegistry()->Create("APPLE", interface, ""); + return GlooDeviceRegistry()->Create("APPLE", interfaceName, ""); +#endif + +#ifdef _WIN32 + return GlooDeviceRegistry()->Create("WIN32", interfaceName, ""); #endif throw std::runtime_error("makeDeviceForInterface(): unsupported gloo device"); @@ -117,6 +122,10 @@ std::shared_ptr<::gloo::transport::Device> GlooDeviceFactory:: return GlooDeviceRegistry()->Create("APPLE", "", hostname); #endif +#ifdef _WIN32 + return GlooDeviceRegistry()->Create("WIN32", "", hostname); +#endif + throw std::runtime_error("makeDeviceForHostname(): unsupported gloo device"); } diff --git a/torch/lib/c10d/HashStore.cpp b/torch/lib/c10d/HashStore.cpp index 191560d5b0fcd..0cd89181ef1bb 100644 --- a/torch/lib/c10d/HashStore.cpp +++ b/torch/lib/c10d/HashStore.cpp @@ -8,6 +8,8 @@ #include #include +#include + namespace c10d { void HashStore::set(const std::string& key, const std::vector& data) { @@ -77,6 +79,17 @@ int64_t HashStore::add(const std::string& key, int64_t i) { return ti; } +int64_t HashStore::getNumKeys() { + std::unique_lock lock(m_); + return map_.size(); +} + +bool HashStore::deleteKey(const std::string& key) { + std::unique_lock lock(m_); + auto numDeleted = map_.erase(key); + return (numDeleted == 1); +} + bool HashStore::check(const std::vector& keys) { std::unique_lock lock(m_); for (const auto& key : keys) { diff --git a/torch/lib/c10d/HashStore.hpp b/torch/lib/c10d/HashStore.hpp index 0d55722efae94..1bdd67ca603ce 100644 --- a/torch/lib/c10d/HashStore.hpp +++ b/torch/lib/c10d/HashStore.hpp @@ -28,8 +28,12 @@ class HashStore : public Store { int64_t add(const std::string& key, int64_t value) override; + int64_t getNumKeys() override; + bool check(const std::vector& keys) override; + bool deleteKey(const std::string& key) override; + protected: std::unordered_map> map_; std::mutex m_; diff --git a/torch/lib/c10d/NCCLUtils.hpp b/torch/lib/c10d/NCCLUtils.hpp index 433a71ef92d79..de9484445b61b 100644 --- a/torch/lib/c10d/NCCLUtils.hpp +++ b/torch/lib/c10d/NCCLUtils.hpp @@ -8,6 +8,27 @@ #include +namespace { + // Provides additional detail into NCCL error codes based on when these are + // thrown in the NCCL codebase. +const char* errorMessage(ncclResult_t error) { + switch (error) { + case ncclUnhandledCudaError: + return "ncclUnhandledCudaError: Call to CUDA function failed."; + case ncclSystemError: + return "ncclSystemError: System call (socket, malloc, munmap, etc) failed."; + case ncclInternalError: + return "ncclInternalError: Internal check failed. This is either a bug in NCCL or due to memory corruption"; + case ncclInvalidArgument: + return "ncclInvalidArgument: Invalid value for an argument (such as invalid pointer, device count, ip:host pair, etc)."; + case ncclInvalidUsage: + return "ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc)."; + default: + break; + } + return "Unknown NCCL error"; +} +} // namespace // Error checking is enabled only for NCCL versions 2.4+ since ncclCommAbort() // and ncclCommGetAsyncError() are not supported in earlier versions. #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ @@ -17,8 +38,6 @@ #define ENABLE_NCCL_ERROR_CHECKING #endif -// Fix build issues with NCCL P2P - until then disable NCCL send/recv. -#if defined(ENABLE_NCCL_A2A) && (ENABLE_NCCL_A2A == 1) // P2P is enabled only for NCCL versions 2.7+ since ncclSend() // and ncclRecv() are not supported in earlier versions. #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ @@ -27,17 +46,17 @@ #elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3) #define ENABLE_NCCL_P2P_SUPPORT #endif -#endif // Macro to throw on a non-successful NCCL return value. -#define C10D_NCCL_CHECK(cmd) \ - do { \ - ncclResult_t result = cmd; \ - if (result != ncclSuccess) { \ - std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ - std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result); \ - throw std::runtime_error(err); \ - } \ +#define C10D_NCCL_CHECK(cmd) \ + do { \ + ncclResult_t result = cmd; \ + if (result != ncclSuccess) { \ + std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ + std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \ + "\n" + errorMessage(result); \ + throw std::runtime_error(err); \ + } \ } while (0) // Macro to print and abort on a non-successful NCCL return value. diff --git a/torch/lib/c10d/PrefixStore.cpp b/torch/lib/c10d/PrefixStore.cpp index a1bc174509424..6f71e422bd0e3 100644 --- a/torch/lib/c10d/PrefixStore.cpp +++ b/torch/lib/c10d/PrefixStore.cpp @@ -4,7 +4,7 @@ namespace c10d { PrefixStore::PrefixStore( const std::string& prefix, - std::shared_ptr store) + c10::intrusive_ptr store) : prefix_(prefix), store_(store) {} std::string PrefixStore::joinKey(const std::string& key) { @@ -35,6 +35,14 @@ int64_t PrefixStore::add(const std::string& key, int64_t value) { return store_->add(joinKey(key), value); } +bool PrefixStore::deleteKey(const std::string& key) { + return store_->deleteKey(joinKey(key)); +} + +int64_t PrefixStore::getNumKeys() { + return store_->getNumKeys(); +} + bool PrefixStore::check(const std::vector& keys) { auto joinedKeys = joinKeys(keys); return store_->check(joinedKeys); diff --git a/torch/lib/c10d/PrefixStore.hpp b/torch/lib/c10d/PrefixStore.hpp index 86dba598ed33f..ec50b3b719bff 100644 --- a/torch/lib/c10d/PrefixStore.hpp +++ b/torch/lib/c10d/PrefixStore.hpp @@ -7,7 +7,9 @@ namespace c10d { class PrefixStore : public Store { public: - explicit PrefixStore(const std::string& prefix, std::shared_ptr store); + explicit PrefixStore( + const std::string& prefix, + c10::intrusive_ptr store); virtual ~PrefixStore(){}; @@ -17,6 +19,10 @@ class PrefixStore : public Store { int64_t add(const std::string& key, int64_t value) override; + bool deleteKey(const std::string& key) override; + + int64_t getNumKeys() override; + bool check(const std::vector& keys) override; void wait(const std::vector& keys) override; @@ -27,7 +33,7 @@ class PrefixStore : public Store { protected: std::string prefix_; - std::shared_ptr store_; + c10::intrusive_ptr store_; std::string joinKey(const std::string& key); std::vector joinKeys(const std::vector& keys); diff --git a/torch/lib/c10d/ProcessGroup.cpp b/torch/lib/c10d/ProcessGroup.cpp index 5c362a42fcf5e..7909bfa7c96c7 100644 --- a/torch/lib/c10d/ProcessGroup.cpp +++ b/torch/lib/c10d/ProcessGroup.cpp @@ -1,9 +1,77 @@ #include +#include + #include namespace c10d { +std::string opTypeToString(OpType opType) { + switch (opType) { + case OpType::BROADCAST: + return "BROADCAST"; + case OpType::ALLREDUCE: + return "ALLREDUCE"; + case OpType::ALLREDUCE_COALESCED: + return "ALLREDUCE_COALESCED"; + case OpType::REDUCE: + return "REDUCE"; + case OpType::ALLGATHER: + return "ALLGATHER"; + case OpType::ALLGATHER_BASE: + return "ALLGATHER_BASE"; + case OpType::ALLGATHER_COALESCED: + return "ALLGATHER_COALESCED"; + case OpType::GATHER: + return "GATHER"; + case OpType::SCATTER: + return "SCATTER"; + case OpType::REDUCE_SCATTER: + return "REDUCE_SCATTER"; + case OpType::ALLTOALL_BASE: + return "ALLTOALL_BASE"; + case OpType::ALLTOALL: + return "ALLTOALL"; + case OpType::SEND: + return "SEND"; + case OpType::RECV: + return "RECV"; + case OpType::RECVANYSOURCE: + return "RECVANYSOURCE"; + case OpType::BARRIER: + return "BARRIER"; + case OpType::UNKNOWN: + return "UNKNOWN"; + default: + TORCH_INTERNAL_ASSERT("Unknown op type!"); + } + return "UNKNOWN"; +} + +bool isP2POp(OpType opType) { + return opType == OpType::SEND || opType == OpType::RECV || + opType == OpType::RECVANYSOURCE; +} + + +ProcessGroup::Work::Work(int rank, OpType opType, const char* profilingTitle) + : rank_(rank), opType_(opType) { + if (profilingTitle != nullptr) { + auto recordingFunction = std::make_shared(at::RecordScope::USER_SCOPE); + if (recordingFunction->isActive()) { + recordingFunction->before(profilingTitle, {}); + std::function end_handler = [this, recordingFunction]() { + recordingFunction->end(); + }; + recordFunctionEndCallback_ = at::wrapPropagateTLSState(end_handler); + } + } +} + +OpType ProcessGroup::Work::retrieveOpType() { + return opType_; +} + ProcessGroup::Work::~Work() {} bool ProcessGroup::Work::isCompleted() { @@ -67,6 +135,10 @@ void ProcessGroup::Work::finish(std::exception_ptr exception) { std::unique_lock lock(mutex_); completed_ = true; exception_ = exception; + if (recordFunctionEndCallback_) { + recordFunctionEndCallback_(); + recordFunctionEndCallback_ = nullptr; + } lock.unlock(); cv_.notify_all(); } @@ -75,6 +147,10 @@ void ProcessGroup::Work::finishAndThrow(std::exception_ptr exception) { std::unique_lock lock(mutex_); completed_ = true; exception_ = exception; + if (recordFunctionEndCallback_) { + recordFunctionEndCallback_(); + recordFunctionEndCallback_ = nullptr; + } if (exception_) { std::rethrow_exception(exception_); } @@ -88,7 +164,7 @@ ProcessGroup::~ProcessGroup() {} // This is introduced so that implementors of ProcessGroup would not need to // have this implmentation. -std::shared_ptr ProcessGroup::allgather_coalesced( +c10::intrusive_ptr ProcessGroup::allgather_coalesced( std::vector>& /* usused */, std::vector& /* usused */, const AllgatherOptions& /* usused */) { diff --git a/torch/lib/c10d/ProcessGroup.hpp b/torch/lib/c10d/ProcessGroup.hpp index 59d40d2427a83..ea4b1428038ad 100644 --- a/torch/lib/c10d/ProcessGroup.hpp +++ b/torch/lib/c10d/ProcessGroup.hpp @@ -11,10 +11,43 @@ #include +// ************************************************************************* +// PROCESS GROUP collective communication API IS BEING CHANGED BETWEEN +// versions 1.7 and 1.8. +// PLEASE DO NOT ADD ANY DEPENDENCIES. +// SEE RFC: https://github.com/pytorch/pytorch/issues/39662 +// ************************************************************************* + constexpr auto kNoTimeout = std::chrono::milliseconds(0); namespace c10d { +enum class OpType : std::uint8_t { + BROADCAST = 0, + ALLREDUCE = 1, + ALLREDUCE_COALESCED = 2, + REDUCE = 3, + ALLGATHER = 4, + ALLGATHER_BASE = 5, + ALLGATHER_COALESCED = 6, + GATHER = 7, + SCATTER = 8, + REDUCE_SCATTER = 9, + ALLTOALL_BASE = 10, + ALLTOALL = 11, + SEND = 12, + RECV = 13, + RECVANYSOURCE = 14, + BARRIER = 15, + UNKNOWN = 100, +}; + +// Converts OpType to human readable string. +std::string opTypeToString(OpType opType); + +// Whether or not an OP is an p2p op (SEND, RECV, RECVANYSOURCE) +bool isP2POp(OpType opType); + // ProcessGroup is a base class that captures collective and point to // point communication in a fixed set of processes. // @@ -35,10 +68,16 @@ namespace c10d { // process group to find each other (referred to as rendezvous from // hereon) // -class ProcessGroup { +class ProcessGroup : public torch::CustomClassHolder { public: - class Work { + // Please do not use ProcessGroup::Work API, it is going away, to be + // replaced by ivalue::Future. + // Python binding for this class might change, please do not assume + // this will be bound using pybind. + class Work : public torch::CustomClassHolder { public: + Work(int rank = -1, OpType opType = OpType::UNKNOWN, const char* profilingTitle = nullptr); + virtual ~Work(); // Checks if request has completed. Non-blocking operation. @@ -93,6 +132,8 @@ class ProcessGroup { // work. Only NCCL backend is currently supported. virtual c10::intrusive_ptr getFuture(); + OpType retrieveOpType(); + protected: // Completes the work object and optionally sets the exception in a // thread-safe manner. Notifies all waiting condition variables as well. @@ -106,6 +147,16 @@ class ProcessGroup { std::condition_variable cv_; bool completed_ = false; std::exception_ptr exception_; + + // Current rank of the node. + const int rank_; + + // Operation type that this work object refers to. + OpType opType_; + + // When profiling, the callback to record end of operation event. This + // callback needs to be called when collective operation is complete. + std::function recordFunctionEndCallback_; }; explicit ProcessGroup(int rank, int size); @@ -119,25 +170,25 @@ class ProcessGroup { return size_; } - virtual std::shared_ptr broadcast( + virtual c10::intrusive_ptr broadcast( std::vector& data, const BroadcastOptions& opts = BroadcastOptions()) = 0; - virtual std::shared_ptr allreduce( + virtual c10::intrusive_ptr allreduce( std::vector& data, const AllreduceOptions& opts = AllreduceOptions()) = 0; // This will be moved out of ProcessGroup, do not add dependencies on this // function. - virtual std::shared_ptr allreduce_coalesced( + virtual c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) = 0; - virtual std::shared_ptr reduce( + virtual c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) = 0; - virtual std::shared_ptr allgather( + virtual c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) = 0; @@ -145,7 +196,7 @@ class ProcessGroup { // Gathers a single tensor inputBuffer into a single buffer outputBuffer that // is interpreted as a contigious collection of size inputBuffer * WORLD_SIZE. // For implementers of ProcessGroup API and advanced users only. - virtual std::shared_ptr allgather_base( + virtual c10::intrusive_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) = 0; @@ -154,27 +205,27 @@ class ProcessGroup { // * do not add dependencies on this function, // * do not implement it in your ProcessGroup, implement allgather_base // instead. - virtual std::shared_ptr allgather_coalesced( + virtual c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()); - virtual std::shared_ptr gather( + virtual c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) = 0; - virtual std::shared_ptr scatter( + virtual c10::intrusive_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) = 0; - virtual std::shared_ptr reduce_scatter( + virtual c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) = 0; - virtual std::shared_ptr alltoall_base( + virtual c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -183,28 +234,28 @@ class ProcessGroup { throw std::runtime_error("ProcessGroup does not support alltoall"); } - virtual std::shared_ptr alltoall( + virtual c10::intrusive_ptr alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts = AllToAllOptions()) { throw std::runtime_error("ProcessGroup does not support alltoall"); } - virtual std::shared_ptr send( + virtual c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) = 0; - virtual std::shared_ptr recv( + virtual c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) = 0; - virtual std::shared_ptr recvAnysource( + virtual c10::intrusive_ptr recvAnysource( std::vector& tensors, int tag) = 0; - virtual std::shared_ptr barrier( + virtual c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) = 0; protected: diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp index 531fe751f1c9b..22da878cce43e 100644 --- a/torch/lib/c10d/ProcessGroupGloo.cpp +++ b/torch/lib/c10d/ProcessGroupGloo.cpp @@ -2,10 +2,16 @@ #include +#ifdef _WIN32 +#include +#include +#include +#else #include #include -#include #include +#endif +#include #include @@ -32,10 +38,41 @@ #endif #include +#include #include #include #include +#ifdef _WIN32 +#define GENERATE_ALL_TYPES(type, func, ...) \ + switch (type) { \ + case ::at::ScalarType::Float: \ + func(__VA_ARGS__); \ + break; \ + case ::at::ScalarType::Double: \ + func(__VA_ARGS__); \ + break; \ + case ::at::ScalarType::Half: \ + func(__VA_ARGS__); \ + break; \ + case ::at::ScalarType::Char: \ + func(__VA_ARGS__); \ + break; \ + case ::at::ScalarType::Byte: \ + func(__VA_ARGS__); \ + break; \ + case ::at::ScalarType::Int: \ + func(__VA_ARGS__); \ + break; \ + case ::at::ScalarType::Long: \ + func(__VA_ARGS__); \ + break; \ + default: \ + throw std::runtime_error("Invalid scalar type"); \ + } + +#define HOST_NAME_MAX 256 +#else #define GENERATE_ALL_TYPES(type, func, args...) \ switch (type) { \ case ::at::ScalarType::Float: \ @@ -62,6 +99,7 @@ default: \ throw std::runtime_error("Invalid scalar type"); \ } +#endif namespace c10d { @@ -70,7 +108,7 @@ namespace { // Wrap c10d store as Gloo store class GlooStore : public ::gloo::rendezvous::Store { public: - GlooStore(const std::shared_ptr<::c10d::Store>& store) : store_(store) {} + GlooStore(const c10::intrusive_ptr<::c10d::Store>& store) : store_(store) {} void set(const std::string& key, const std::vector& value) override { std::vector tmp(value.begin(), value.end()); @@ -93,7 +131,7 @@ class GlooStore : public ::gloo::rendezvous::Store { } protected: - std::shared_ptr<::c10d::Store> store_; + c10::intrusive_ptr<::c10d::Store> store_; }; typedef void (*ReduceFunc)(void*, const void*, const void*, size_t); @@ -409,12 +447,19 @@ ProcessGroupGloo::Options::Options() namespace { +void socketInitialize() { +#ifdef _WIN32 + ::gloo::init_winsock(); +#endif +} + // Gloo assumes that this machine's hostname can always be resolved // to an address. If it doesn't it throws a runtime error saying // that it can't be resolved. Instead of catching it, we choose // to proactively check if an address can be resolved, so we can // gracefully fall back to an alternative if it doesn't. bool doesHostnameResolveToUsableAddress(const std::string& hostname) { + socketInitialize(); struct addrinfo hints; memset(&hints, 0, sizeof(hints)); hints.ai_family = AF_UNSPEC; @@ -431,7 +476,11 @@ bool doesHostnameResolveToUsableAddress(const std::string& hostname) { continue; } rv = bind(fd, rp->ai_addr, rp->ai_addrlen); +#ifdef _WIN32 + closesocket(fd); +#else close(fd); +#endif if (rv == -1) { continue; } @@ -443,14 +492,11 @@ bool doesHostnameResolveToUsableAddress(const std::string& hostname) { } // namespace -#if defined(__linux__) || defined(__APPLE__) std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: - createDeviceForInterface(const std::string& interface) { - return ::c10d::GlooDeviceFactory::makeDeviceForInterface(interface); + createDeviceForInterface(const std::string& interface_name) { + return ::c10d::GlooDeviceFactory::makeDeviceForInterface(interface_name); } -#endif -#if defined(__linux__) || defined(__APPLE__) std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: createDeviceForHostname(const std::string& hostname) { TORCH_CHECK( @@ -460,14 +506,14 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: " to a (local) address"); return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname); } -#endif -#ifdef __linux__ +#if defined(__linux__) || defined(_WIN32) std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: createDefaultDevice() { // Use the hostname to resolve the network address to // use. Note: if the hostname does not resolve to an address (e.g. // because of misconfigured /etc/hosts file), this will not work. + socketInitialize(); std::array hostname{}; auto rv = gethostname(hostname.data(), HOST_NAME_MAX); if (rv != 0) { @@ -516,7 +562,7 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: #endif ProcessGroupGloo::ProcessGroupGloo( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, Options options) @@ -608,11 +654,11 @@ void ProcessGroupGloo::runLoop(int workerIndex) { AsyncWork::execute(std::move(work)); lock.lock(); - workInProgress_[workerIndex] = nullptr; + workInProgress_[workerIndex].reset(); } } -void ProcessGroupGloo::enqueue(std::shared_ptr work) { +void ProcessGroupGloo::enqueue(c10::intrusive_ptr work) { std::unique_lock lock(workMutex_); workQueue_.push_back(std::move(work)); lock.unlock(); @@ -632,7 +678,8 @@ class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork { int rootRank, int rootTensor, uint32_t tag) - : context(context), + : ProcessGroupGloo::AsyncWork("gloo:broadcast"), + context(context), inputs(inputs), rootRank(rootRank), rootTensor(rootTensor), @@ -727,7 +774,7 @@ class AsyncBroadcastCUDAWork : public AsyncBroadcastWork { } // namespace -std::shared_ptr ProcessGroupGloo::broadcast( +c10::intrusive_ptr ProcessGroupGloo::broadcast( std::vector& inputs, const BroadcastOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -750,15 +797,15 @@ std::shared_ptr ProcessGroupGloo::broadcast( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.rootRank, opts.rootTensor, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.rootRank, opts.rootTensor, tag); #endif } else { @@ -778,7 +825,8 @@ class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork { std::vector& inputs, ReduceOp reduceOp, uint32_t tag) - : context(context), inputs(inputs), reduceOp(reduceOp), tag(tag) {} + : ProcessGroupGloo::AsyncWork("gloo:all_reduce"), + context(context), inputs(inputs), reduceOp(reduceOp), tag(tag) {} std::shared_ptr context; std::vector inputs; @@ -1253,7 +1301,7 @@ class AsyncSparseAllreduceCUDAWork : public AsyncSparseAllreduceWork { } // namespace -std::shared_ptr ProcessGroupGloo::allreduce( +c10::intrusive_ptr ProcessGroupGloo::allreduce( std::vector& inputs, const AllreduceOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -1282,15 +1330,15 @@ std::shared_ptr ProcessGroupGloo::allreduce( "(allreduce of sparse tensors only works with ReduceOp.SUM)"); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { if (layout == c10::kStrided) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.reduceOp, tag); } else if (layout == c10::kSparse) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, tag); } else { invalidArgument("unsupported layout"); @@ -1298,10 +1346,10 @@ std::shared_ptr ProcessGroupGloo::allreduce( #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { if (layout == c10::kStrided) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.reduceOp, tag); } else if (layout == c10::kSparse) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, tag); } else { invalidArgument("unsupported layout"); @@ -1315,7 +1363,7 @@ std::shared_ptr ProcessGroupGloo::allreduce( return work; } -std::shared_ptr ProcessGroupGloo::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupGloo::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -1358,12 +1406,12 @@ std::shared_ptr ProcessGroupGloo::allreduce_coalesced( invalidArgument("unsupported layout"); } - std::shared_ptr work; + c10::intrusive_ptr work; const uint32_t tag = nextTag(); std::shared_ptr context = getContext(tag); if (device.type() == c10::kCPU) { if (layout == c10::kStrided) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), tensors, opts.reduceOp, tag); } else { invalidArgument("unsupported layout"); @@ -1386,7 +1434,8 @@ class AsyncReduceWork : public ProcessGroupGloo::AsyncWork { int rootTensor, ReduceOp reduceOp, uint32_t tag) - : context(context), + : ProcessGroupGloo::AsyncWork("gloo:reduce"), + context(context), inputs(inputs), rootRank(rootRank), rootTensor(rootTensor), @@ -1490,7 +1539,7 @@ class AsyncReduceCUDAWork : public AsyncReduceWork { } // namespace -std::shared_ptr ProcessGroupGloo::reduce( +c10::intrusive_ptr ProcessGroupGloo::reduce( std::vector& inputs, const ReduceOptions& opts) { static auto invalidArgument = [](const std::string& msg) { @@ -1513,11 +1562,11 @@ std::shared_ptr ProcessGroupGloo::reduce( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.rootRank, @@ -1526,7 +1575,7 @@ std::shared_ptr ProcessGroupGloo::reduce( tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), inputs, opts.rootRank, @@ -1550,7 +1599,8 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork { std::vector>& outputs, std::vector& inputs, uint32_t tag) - : context(context), outputs(outputs), inputs(inputs), tag(tag) {} + : ProcessGroupGloo::AsyncWork("gloo:all_gather"), + context(context), outputs(outputs), inputs(inputs), tag(tag) {} std::shared_ptr context; std::vector> outputs; @@ -1671,7 +1721,7 @@ class AsyncAllgatherCUDAWork : public AsyncAllgatherWork { // Note: current CUDA implementation holds the assumption that the // tensors in the nested output tensor vectors are on the same device. -std::shared_ptr ProcessGroupGloo::allgather( +c10::intrusive_ptr ProcessGroupGloo::allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts) { @@ -1720,15 +1770,15 @@ std::shared_ptr ProcessGroupGloo::allgather( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, tag); #endif } else { @@ -1747,7 +1797,8 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { std::vector>& output_lists, std::vector& input_list, uint32_t tag) - : context(context), + : ProcessGroupGloo::AsyncWork("gloo:all_gather"), + context(context), output_lists(output_lists), input_list(input_list), tag(tag) {} @@ -1802,7 +1853,7 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { } // namespace -std::shared_ptr ProcessGroupGloo::allgather_coalesced( +c10::intrusive_ptr ProcessGroupGloo::allgather_coalesced( std::vector>& output_lists, std::vector& input_list, const AllgatherOptions& /* unused */) { @@ -1852,13 +1903,13 @@ std::shared_ptr ProcessGroupGloo::allgather_coalesced( auto tag = nextTag(); auto context = getContext(tag); - auto work = std::make_shared( + auto work = c10::make_intrusive( std::move(context), output_lists, input_list, tag); enqueue(work); return work; } -std::shared_ptr ProcessGroupGloo::allgather_base( +c10::intrusive_ptr ProcessGroupGloo::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { @@ -1876,7 +1927,8 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork { std::vector& inputs, int root, uint32_t tag) - : context(context), + : ProcessGroupGloo::AsyncWork("gloo:gather"), + context(context), outputs(outputs), inputs(inputs), root(root), @@ -2006,7 +2058,7 @@ class AsyncGatherCUDAWork : public AsyncGatherWork { } // namespace -std::shared_ptr ProcessGroupGloo::gather( +c10::intrusive_ptr ProcessGroupGloo::gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts) { @@ -2052,15 +2104,15 @@ std::shared_ptr ProcessGroupGloo::gather( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, opts.rootRank, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, opts.rootRank, tag); #endif } else { @@ -2080,7 +2132,8 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork { std::vector>& inputs, int root, uint32_t tag) - : context(context), + : ProcessGroupGloo::AsyncWork("gloo:scatter"), + context(context), outputs(outputs), inputs(inputs), root(root), @@ -2193,7 +2246,7 @@ class AsyncScatterCUDAWork : public AsyncScatterWork { } // namespace -std::shared_ptr ProcessGroupGloo::scatter( +c10::intrusive_ptr ProcessGroupGloo::scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts) { @@ -2238,15 +2291,15 @@ std::shared_ptr ProcessGroupGloo::scatter( invalidArgument(c10::str("unsupported device type ", device.type())); } - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, opts.rootRank, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputs, inputs, opts.rootRank, tag); #endif } else { @@ -2256,7 +2309,7 @@ std::shared_ptr ProcessGroupGloo::scatter( return work; } -std::shared_ptr ProcessGroupGloo::reduce_scatter( +c10::intrusive_ptr ProcessGroupGloo::reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts) { @@ -2274,7 +2327,8 @@ class AsyncAlltoallWork : public ProcessGroupGloo::AsyncWork { std::vector& outputCounts, std::vector& inputCounts, uint32_t tag) - : context(context), + : ProcessGroupGloo::AsyncWork("gloo:all_to_all"), + context(context), outputTensor(outputTensor), inputTensor(inputTensor), outputCounts(std::move(outputCounts)), @@ -2390,7 +2444,7 @@ class AsyncAlltoallCUDAWork : public AsyncAlltoallWork { } // namespace -std::shared_ptr ProcessGroupGloo::alltoall_base( +c10::intrusive_ptr ProcessGroupGloo::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputCounts, @@ -2407,12 +2461,12 @@ std::shared_ptr ProcessGroupGloo::alltoall_base( assertDense(invalidArgument, {inputTensor}); const auto& device = outputTensor.device(); - std::shared_ptr work; + c10::intrusive_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputTensor, inputTensor, @@ -2421,7 +2475,7 @@ std::shared_ptr ProcessGroupGloo::alltoall_base( tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { - work = std::make_shared( + work = c10::make_intrusive( std::move(context), outputTensor, inputTensor, @@ -2457,7 +2511,7 @@ uint32_t checkTag(int32_t tag) { return (uint32_t)tag; } -std::shared_ptr ProcessGroupGloo::send( +c10::intrusive_ptr ProcessGroupGloo::send( std::vector& tensors, int dstRank, int tag) { @@ -2473,10 +2527,10 @@ std::shared_ptr ProcessGroupGloo::send( // The work captures the tensor to prevent it being deallocated and // the unbound buffer to synchronize on completion of the send. - return std::make_shared(tensor, std::move(buf)); + return c10::make_intrusive(tensor, std::move(buf)); } -std::shared_ptr ProcessGroupGloo::recv( +c10::intrusive_ptr ProcessGroupGloo::recv( std::vector& tensors, int srcRank, int tag) { @@ -2492,10 +2546,10 @@ std::shared_ptr ProcessGroupGloo::recv( // The work captures the tensor to prevent it being deallocated and // the unbound buffer to synchronize on completion of the recv. - return std::make_shared(tensor, std::move(buf)); + return c10::make_intrusive(tensor, std::move(buf)); } -std::shared_ptr ProcessGroupGloo::recvAnysource( +c10::intrusive_ptr ProcessGroupGloo::recvAnysource( std::vector& tensors, int tag) { auto& tensor = checkSingleTensor(tensors); @@ -2520,7 +2574,7 @@ std::shared_ptr ProcessGroupGloo::recvAnysource( // The work captures the tensor to prevent it being deallocated and // the unbound buffer to synchronize on completion of the recv. - return std::make_shared(tensor, std::move(buf)); + return c10::make_intrusive(tensor, std::move(buf)); } namespace { @@ -2529,12 +2583,13 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { public: AsyncBarrierWork( const std::shared_ptr& context, - std::vector> priorWork, + std::vector> priorWork, uint32_t tag) - : context(context), priorWork(std::move(priorWork)), tag(tag) {} + : ProcessGroupGloo::AsyncWork("gloo:barrier"), + context(context), priorWork(std::move(priorWork)), tag(tag) {} std::shared_ptr context; - std::vector> priorWork; + std::vector> priorWork; const uint32_t tag; void run() override { @@ -2554,9 +2609,9 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { } // namespace -std::shared_ptr ProcessGroupGloo::barrier( +c10::intrusive_ptr ProcessGroupGloo::barrier( const BarrierOptions& opts) { - std::vector> priorWork; + std::vector> priorWork; // Snapshot all in progress and pending work as weak_ptr. // When executing a barrier, we need to ensure that all prior work @@ -2570,7 +2625,7 @@ std::shared_ptr ProcessGroupGloo::barrier( auto tag = nextTag(); auto context = getContext(tag); - auto work = std::make_shared( + auto work = c10::make_intrusive( std::move(context), std::move(priorWork), tag); enqueue(work); return work; diff --git a/torch/lib/c10d/ProcessGroupGloo.hpp b/torch/lib/c10d/ProcessGroupGloo.hpp index dfae068de2440..0508b6f857a11 100644 --- a/torch/lib/c10d/ProcessGroupGloo.hpp +++ b/torch/lib/c10d/ProcessGroupGloo.hpp @@ -68,7 +68,9 @@ class ProcessGroupGloo : public ProcessGroup { // class AsyncWork : public ProcessGroup::Work { public: - static void execute(std::shared_ptr work) { + AsyncWork(const char* profilingTitle = nullptr): ProcessGroup::Work(-1, OpType::UNKNOWN, profilingTitle) {} + + static void execute(c10::intrusive_ptr work) { std::exception_ptr eptr; try { work->run(); @@ -150,82 +152,82 @@ class ProcessGroupGloo : public ProcessGroup { static std::shared_ptr<::gloo::transport::Device> createDefaultDevice(); explicit ProcessGroupGloo( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, Options options = Options()); virtual ~ProcessGroupGloo(); - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_coalesced( + c10::intrusive_ptr allgather_coalesced( std::vector>& output_lists, std::vector& input_list, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr alltoall_base( + c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputCounts, std::vector& inputCounts, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) override; - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensors, int tag) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; protected: @@ -256,7 +258,7 @@ class ProcessGroupGloo : public ProcessGroup { void runLoop(int workerIndex); // Queue work to run on worker thread. - void enqueue(std::shared_ptr work); + void enqueue(c10::intrusive_ptr work); // Keep both a queue of pending work, and a vector with in progress work. // Both of these can only be mutated when holding the queue lock. @@ -264,8 +266,8 @@ class ProcessGroupGloo : public ProcessGroup { // to all in progress and pending work when executing a barrier. // When executing a barrier, we need to ensure that all prior work // has completed before completing itself. - std::deque> workQueue_; - std::vector> workInProgress_; + std::deque> workQueue_; + std::vector> workInProgress_; std::mutex workMutex_; std::condition_variable workProduceCV_; std::condition_variable workConsumeCV_; diff --git a/torch/lib/c10d/ProcessGroupMPI.cpp b/torch/lib/c10d/ProcessGroupMPI.cpp index d3e79a1dd4245..250b635e8c69c 100644 --- a/torch/lib/c10d/ProcessGroupMPI.cpp +++ b/torch/lib/c10d/ProcessGroupMPI.cpp @@ -199,7 +199,7 @@ void ProcessGroupMPI::initMPIOnce() { }); } -std::shared_ptr ProcessGroupMPI::createProcessGroupMPI( +c10::intrusive_ptr ProcessGroupMPI::createProcessGroupMPI( std::vector ranks) { // Once initialization initMPIOnce(); @@ -238,10 +238,10 @@ std::shared_ptr ProcessGroupMPI::createProcessGroupMPI( // process group instance. This is in line with the semantics of the // other process group types. if (groupComm == MPI_COMM_NULL) { - return std::shared_ptr(); + return c10::intrusive_ptr(); } - return std::make_shared(rank, size, groupComm); + return c10::make_intrusive(rank, size, groupComm); } ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pgComm) @@ -308,9 +308,9 @@ void ProcessGroupMPI::runLoop() { } } -std::shared_ptr ProcessGroupMPI::enqueue( +c10::intrusive_ptr ProcessGroupMPI::enqueue( std::unique_ptr entry) { - auto work = std::make_shared(); + auto work = c10::make_intrusive(); std::unique_lock lock(pgMutex_); queue_.push_back(std::make_tuple(std::move(entry), work)); lock.unlock(); @@ -318,7 +318,7 @@ std::shared_ptr ProcessGroupMPI::enqueue( return work; } -std::shared_ptr ProcessGroupMPI::broadcast( +c10::intrusive_ptr ProcessGroupMPI::broadcast( std::vector& tensors, const BroadcastOptions& opts) { checkSingleTensor(tensors); @@ -339,7 +339,7 @@ std::shared_ptr ProcessGroupMPI::broadcast( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allreduce( +c10::intrusive_ptr ProcessGroupMPI::allreduce( std::vector& tensors, const AllreduceOptions& opts) { checkSingleTensor(tensors); @@ -362,14 +362,14 @@ std::shared_ptr ProcessGroupMPI::allreduce( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupMPI::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { throw std::runtime_error( "allreduce_coalesced is currently not supported with MPI"); } -std::shared_ptr ProcessGroupMPI::reduce( +c10::intrusive_ptr ProcessGroupMPI::reduce( std::vector& tensors, const ReduceOptions& opts) { checkSingleTensor(tensors); @@ -397,7 +397,7 @@ std::shared_ptr ProcessGroupMPI::reduce( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allgather( +c10::intrusive_ptr ProcessGroupMPI::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts) { @@ -441,7 +441,7 @@ std::shared_ptr ProcessGroupMPI::allgather( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allgather_coalesced( +c10::intrusive_ptr ProcessGroupMPI::allgather_coalesced( std::vector>& /* unused */, std::vector& /* unused */, const AllgatherOptions& /* unused */) { @@ -449,7 +449,7 @@ std::shared_ptr ProcessGroupMPI::allgather_coalesced( "ProcessGroupMPI does not support allgather_coalesced"); } -std::shared_ptr ProcessGroupMPI::gather( +c10::intrusive_ptr ProcessGroupMPI::gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts) { @@ -516,7 +516,7 @@ std::shared_ptr ProcessGroupMPI::gather( } } -std::shared_ptr ProcessGroupMPI::scatter( +c10::intrusive_ptr ProcessGroupMPI::scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts) { @@ -582,14 +582,14 @@ std::shared_ptr ProcessGroupMPI::scatter( } } -std::shared_ptr ProcessGroupMPI::reduce_scatter( +c10::intrusive_ptr ProcessGroupMPI::reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts) { throw std::runtime_error("ProcessGroupMPI does not support reduce_scatter"); } -std::shared_ptr ProcessGroupMPI::alltoall_base( +c10::intrusive_ptr ProcessGroupMPI::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -665,7 +665,7 @@ std::shared_ptr ProcessGroupMPI::alltoall_base( return enqueue(std::move(entry)); } } -std::shared_ptr ProcessGroupMPI::alltoall( +c10::intrusive_ptr ProcessGroupMPI::alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts) { @@ -722,7 +722,7 @@ std::shared_ptr ProcessGroupMPI::alltoall( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::send( +c10::intrusive_ptr ProcessGroupMPI::send( std::vector& tensors, int dstRank, int tag) { @@ -744,10 +744,10 @@ std::shared_ptr ProcessGroupMPI::send( &request)); } - return std::make_shared(tensor, request); + return c10::make_intrusive(tensor, request); } -std::shared_ptr ProcessGroupMPI::recv( +c10::intrusive_ptr ProcessGroupMPI::recv( std::vector& tensors, int srcRank, int tag) { @@ -769,10 +769,10 @@ std::shared_ptr ProcessGroupMPI::recv( &request)); } - return std::make_shared(tensor, request); + return c10::make_intrusive(tensor, request); } -std::shared_ptr ProcessGroupMPI::recvAnysource( +c10::intrusive_ptr ProcessGroupMPI::recvAnysource( std::vector& tensors, int tag) { checkSingleTensor(tensors); @@ -793,10 +793,10 @@ std::shared_ptr ProcessGroupMPI::recvAnysource( &request)); } - return std::make_shared(tensor, request); + return c10::make_intrusive(tensor, request); } -std::shared_ptr ProcessGroupMPI::barrier( +c10::intrusive_ptr ProcessGroupMPI::barrier( const BarrierOptions& opts) { std::function&)> runFunc = [this](std::unique_ptr& entry) { @@ -808,7 +808,7 @@ std::shared_ptr ProcessGroupMPI::barrier( return enqueue(std::move(entry)); } -std::shared_ptr ProcessGroupMPI::allgather_base( +c10::intrusive_ptr ProcessGroupMPI::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { diff --git a/torch/lib/c10d/ProcessGroupMPI.hpp b/torch/lib/c10d/ProcessGroupMPI.hpp index 342fe87001a02..420c78ef028ad 100644 --- a/torch/lib/c10d/ProcessGroupMPI.hpp +++ b/torch/lib/c10d/ProcessGroupMPI.hpp @@ -108,95 +108,95 @@ class ProcessGroupMPI : public ProcessGroup { // Abort the MPI program, needs to be called when exception is detected void abort(); - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& data, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputbuffer, at::Tensor& inputbuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_coalesced( + c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr alltoall_base( + c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr alltoall( + c10::intrusive_ptr alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, - int tag); + int tag) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, - int tag); + int tag) override; - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensor, - int tag); + int tag) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; // Creating a new ProcessGroupMPI, will initiialize MPI if not initialized - static std::shared_ptr createProcessGroupMPI( + static c10::intrusive_ptr createProcessGroupMPI( std::vector ranks = {}); protected: using WorkType = - std::tuple, std::shared_ptr>; + std::tuple, c10::intrusive_ptr>; // Worker thread loop void runLoop(); // Helper function that is called by the destructor void destroy(); - std::shared_ptr enqueue(std::unique_ptr entry); + c10::intrusive_ptr enqueue(std::unique_ptr entry); bool stop_; diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index a3765670c6b21..473fcc0d4419e 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -8,6 +8,8 @@ #include #include +#include +#include #include namespace c10d { @@ -107,6 +109,13 @@ std::string getKeyFromDevices(const std::vector& devices) { return deviceList; } +std::string getKeySendRecv(int myRank, int peer) { + int lowRank = myRank < peer ? myRank : peer; + int highRank = myRank < peer ? peer : myRank; + std::string sendRecvPair = std::to_string(lowRank) + ":" + std::to_string(highRank); + return sendRecvPair; +} + // Get the list of devices from list of tensors std::vector getDeviceList(const std::vector& tensors) { std::vector res; @@ -158,31 +167,6 @@ std::string getNcclAbortedCommStoreKey(const std::string ncclIdStr) { } #ifdef ENABLE_NCCL_P2P_SUPPORT -ncclResult_t ncclAlltoall( - void* sendbuff, - void* recvbuff, - size_t count, - size_t size, - ncclDataType_t type, - ncclComm_t comm, - cudaStream_t stream) { - int numranks; - size_t rankdiff = count * size; - C10D_NCCL_CHECK(ncclCommCount(comm, &numranks)); - C10D_NCCL_CHECK(ncclGroupStart()); - for (int r = 0; r < numranks; r++) { - // NCCL uses 0 byte message for synchronization - // Avoid send/recv when message size is zero - if (count != 0) { - C10D_NCCL_CHECK(ncclSend( - ((char*)sendbuff) + r * rankdiff, count, type, r, comm, stream)); - C10D_NCCL_CHECK(ncclRecv( - ((char*)recvbuff) + r * rankdiff, count, type, r, comm, stream)); - } - } - C10D_NCCL_CHECK(ncclGroupEnd()); - return ncclSuccess; -} ncclResult_t ncclAlltoallv( void* sendbuff, @@ -232,9 +216,35 @@ const int64_t ProcessGroupNCCL::kWorkCleanupThreadSleepMillis = 1000; constexpr int64_t kWaitForAbortCommStoreKey = 1000; constexpr int64_t kSynchronizeBusyWaitMillis = 10; const int64_t ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis = 10 * 1000; +thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0; + +std::ostream& operator<<( + std::ostream& output, + const ProcessGroupNCCL::WorkNCCL& workNCCL) { + std::string workInfo; + if (workNCCL.outputs_) { + workInfo = c10::str("WorkNCCL(", + "OpType=", opTypeToString(workNCCL.opType_), + ", TensorShape=", (*workNCCL.outputs_)[0].sizes(), + ", Timeout(ms)=", workNCCL.opTimeout_.count(), + ")"); + } else { + workInfo = c10::str("WorkNCCL(", + "OpType=", opTypeToString(workNCCL.opType_), + ", Timeout(ms)=", workNCCL.opTimeout_.count(), + ")"); + } + return output << workInfo; +} -ProcessGroupNCCL::WorkNCCL::WorkNCCL(const std::vector& devices) - : devices_(devices), workStartTime_(std::chrono::steady_clock::now()) { +ProcessGroupNCCL::WorkNCCL::WorkNCCL( + const std::vector& devices, + int rank, + OpType opType, + const char* profilingTitle) + : Work(rank, opType, profilingTitle), + devices_(devices), + workStartTime_(std::chrono::steady_clock::now()) { // Creates the CUDA event wrappers // Note: The actual events are lazily created when first recorded to with // DEFAULT_FLAGS = cudaEventDisableTiming. @@ -243,6 +253,19 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const std::vector& devices) ncclComms_.resize(devices.size()); } +ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) + : Work(w.rank_, w.opType_), + std::enable_shared_from_this(w), + devices_(w.devices_), + cudaEvents_(w.cudaEvents_), + ncclComms_(w.ncclComms_), + blockingWait_(w.blockingWait_), + opTimeout_(w.opTimeout_), + workStartTime_(w.workStartTime_) { + completed_ = w.completed_; + exception_ = w.exception_; +} + ProcessGroupNCCL::WorkNCCL::~WorkNCCL() {} bool ProcessGroupNCCL::WorkNCCL::isCompleted() { @@ -285,11 +308,7 @@ bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecution() { bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const { for (size_t i = 0; i < devices_.size(); ++i) { // Checking the work's corresponding CUDA events' status - auto ret = cudaEventQuery((*cudaEvents_)[i]); - if (ret != cudaSuccess && ret != cudaErrorNotReady) { - AT_CUDA_CHECK(ret); - } - if (ret == cudaErrorNotReady) { + if (!(*cudaEvents_)[i].query()) { return false; } } @@ -310,6 +329,13 @@ void ProcessGroupNCCL::WorkNCCL::handleNCCLGuard() { std::lock_guard lock(mutex_); completed_ = true; if (exception_) { + auto exceptionMsg = c10::str( + "Some NCCL operations have failed or timed out. Due to the ", + "asynchronous nature of CUDA kernels, subsequent GPU operations ", + "might run on corrupted/incomplete data. To avoid this inconsistency, ", + "we are taking the entire process down."); + LOG(ERROR) << exceptionMsg; + C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.WorkNCCL.handleNCCLGuard"); std::rethrow_exception(exception_); } } @@ -355,10 +381,26 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal( ncclComm->ncclCommAbort(); const auto& storeKey = getNcclAbortedCommStoreKey( buildNcclUniqueIdStr(ncclComm->getNcclId())); - store_->set(storeKey, {}); - LOG(INFO) << "Wrote aborted communicator id to store: " << storeKey; + auto rankStr = std::to_string(rank_); + store_->set( + storeKey, + std::vector( + reinterpret_cast(rankStr.data()), + reinterpret_cast(rankStr.data()) + + rankStr.size())); + LOG(INFO) << "[Rank " << rank_ + << "] Wrote aborted communicator id to store: " << storeKey; } - throw std::runtime_error("Operation timed out!"); + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - workStartTime_); + std::string exceptionMsg = c10::str("[Rank ", rank_, "] ", + "Caught collective operation timeout: ", + (*this), + " ran for ", + timeElapsed.count(), + " milliseconds before timing out."); + throw std::runtime_error(exceptionMsg); } // Check for errors and throw appropriate exception. checkAndThrowException(); @@ -389,36 +431,6 @@ void ProcessGroupNCCL::WorkNCCL::abort() { TORCH_CHECK(false, "ProcessGroupNCCL::WorkNCCL::abort not implemented."); } -void ProcessGroupNCCL::parseNcclBlockingWait() { - char* blockingWait = getenv(NCCL_BLOCKING_WAIT); - if (blockingWait != nullptr) { - auto val = std::stoi(blockingWait); - if (val == 1) { - // Make wait() and synchronize() a blocking call. - blockingWait_ = true; - } else if (val != 0) { - throw std::runtime_error( - "Invalid value for environment variable: " + - std::string(NCCL_BLOCKING_WAIT)); - } - } -} - -void ProcessGroupNCCL::parseNcclAsyncErrorHandling() { - char* errorHandle = getenv(NCCL_ASYNC_ERROR_HANDLING); - if (errorHandle != nullptr) { - auto val = std::stoi(errorHandle); - if (val == 1) { - asyncErrorHandling_ = true; - LOG(INFO) << "[Rank " << rank_ << "] NCCL Async Error Handling enabled."; - } else if (val != 0) { - throw std::runtime_error( - "Invalid value for environment variable: " + - std::string(NCCL_ASYNC_ERROR_HANDLING)); - } - } -} - bool ProcessGroupNCCL::WorkNCCL::timedOut() { auto currentTimepoint = std::chrono::steady_clock::now(); return ( @@ -427,30 +439,27 @@ bool ProcessGroupNCCL::WorkNCCL::timedOut() { } ProcessGroupNCCL::ProcessGroupNCCL( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, - Options options) + c10::intrusive_ptr options) : ProcessGroup(rank, size), store_(store), ncclCommCounter_(0), terminateProcessGroup_(false), - opTimeout_(options.opTimeout), - futureNCCLCallbackStreams_(c10::cuda::device_count()), - isHighPriorityStream_(options.isHighPriorityStream) { - try { - parseNcclBlockingWait(); - } catch (std::exception& e) { - throw std::runtime_error( - "Invalid value for environment variable: " + - std::string(NCCL_BLOCKING_WAIT)); - } - try { - parseNcclAsyncErrorHandling(); - } catch (std::exception& e) { - throw std::runtime_error( - "Invalid value for environment variable: " + - std::string(NCCL_ASYNC_ERROR_HANDLING)); + opTimeout_(options->opTimeout), + isHighPriorityStream_(options->isHighPriorityStream) { + TORCH_CHECK(at::cuda::getNumGPUs() != 0, + "ProcessGroupNCCL is only supported with GPUs, no GPUs found!"); + blockingWait_ = parseEnvVarFlag(NCCL_BLOCKING_WAIT); + asyncErrorHandling_ = parseEnvVarFlag(NCCL_ASYNC_ERROR_HANDLING); + + if (blockingWait_ && asyncErrorHandling_) { + LOG(INFO) << "[Rank " << rank_ + << "] NCCL_BLOCKING_WAIT and NCCL_ASYNC_ERROR_HANDLING " + << "should not both be enabled. " + << "Only NCCL_BLOCKING_WAIT is being used in this process."; + asyncErrorHandling_ = false; } #ifdef ENABLE_NCCL_ERROR_CHECKING @@ -461,34 +470,18 @@ ProcessGroupNCCL::ProcessGroupNCCL( if (asyncErrorHandling_) { workCleanupThread_ = std::thread(&ProcessGroupNCCL::workCleanupLoop, this); } + LOG(INFO) << "[Rank " << rank_ + << "] ProcessGroupNCCL initialized with following options:" + << "\nNCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_ + << "\nNCCL_BLOCKING_WAIT: " << blockingWait_ + << "\nTIMEOUT(ms): " << opTimeout_.count() + << "\nUSE_HIGH_PRIORITY_STREAM: " << isHighPriorityStream_; } ProcessGroupNCCL::~ProcessGroupNCCL() { terminateProcessGroup_.store(true); - watchdogCV_.notify_one(); - workListCV_.notify_one(); - - if (asyncErrorHandling_) { - std::unique_lock lock(workListMutex_); - // TODO: We can potentially merge this functionality into the workCleanup - // thread or just allow the destructor to free workList_. - // Clean up any remaining items in the workList_ instead of waiting for the - // workCleanup Thread to be scheduled again. - for (auto it = workList_.begin(); it != workList_.end(); - /* no increment*/) { - auto& work = *it; - if (work->isCompleted()) { - it = workList_.erase(it); - } else { - ++it; - } - } - // Wait for workList_ to become empty before proceeding with shutdown. - workListCV_.wait(lock, [&]() -> bool { return workList_.empty(); }); - lock.unlock(); - workCleanupThread_.join(); - } + watchdogCV_.notify_one(); #ifdef ENABLE_NCCL_ERROR_CHECKING ncclCommWatchdogThread_.join(); #endif @@ -504,16 +497,59 @@ ProcessGroupNCCL::~ProcessGroupNCCL() { } } } + + if (asyncErrorHandling_) { + workMetaListCV_.notify_one(); + workCleanupThread_.join(); + } +} + +void ProcessGroupNCCL::abortTimedOutCollectives(std::unordered_set& abortedCommIds) { + std::unique_lock lock(workMetaListMutex_); + for (auto& work : workMetaList_) { + work.checkAndSetException(); + // Aborting NCCL Communicators due to errors is already handled above. + if (work.exception()) { + continue; + } + + // Check for Timeouts in the WorkNCCL Operations, and abort all + // communicators accordingly. + if (work.timedOut()) { + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - work.workStartTime_); + std::string exceptionMsg = c10::str("[Rank ", rank_, "] ", + "Watchdog caught collective operation timeout: ", + work, + " ran for ", + timeElapsed.count(), + " milliseconds before timing out."); + LOG(ERROR) << exceptionMsg; + std::exception_ptr exception_ptr = std::make_exception_ptr( + std::runtime_error(exceptionMsg)); + work.setException(exception_ptr); + for (const auto& ncclComm : work.ncclComms_) { + ncclComm->ncclCommAbort(); + abortedCommIds.emplace(buildNcclUniqueIdStr(ncclComm->getNcclId())); + } + } + } } void ProcessGroupNCCL::ncclCommWatchdog() { try { + LOG(INFO) << "[Rank " << rank_ << "] NCCL watchdog thread started!"; ncclCommWatchdogInternal(); - LOG(INFO) << "NCCL watchdog thread terminated normally"; + LOG(INFO) << "[Rank " << rank_ + << "] NCCL watchdog thread terminated normally"; } catch (std::exception& e) { - LOG(INFO) << "NCCL watchdog thread terminated with exception: " << e.what(); + LOG(INFO) << "[Rank " << rank_ + << "] NCCL watchdog thread terminated with exception: " + << e.what(); } catch (...) { - LOG(INFO) << "NCCL watchdog thread terminated with unknown exception"; + LOG(INFO) << "[Rank " << rank_ + << "] NCCL watchdog thread terminated with unknown exception"; } } @@ -534,10 +570,12 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { } if (checkForNCCLErrors(ncclComms)) { - LOG(INFO) << "Received NCCL errors for communicators in the cache"; + LOG(INFO) << "[Rank " << rank_ + << "] Received NCCL errors for communicators in the cache"; if (blockingWait_ || asyncErrorHandling_) { - LOG(INFO) << "Aborting communicators that received errors"; + LOG(INFO) << "[Rank " << rank_ + << "] Aborting communicators that received errors"; // We abort NCCL communicators that have received errors from this // thread, and exceptions are set on the corresponding work objects. // The workCleanupThread will then loop through the unfinished @@ -554,7 +592,8 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { // a communicator the application receives an exception and its // their responsibility to destroy the process group and recreate // it to recover from errors. - abortedCommIds.emplace(buildNcclUniqueIdStr(ncclComm->getNcclId())); + abortedCommIds.emplace( + buildNcclUniqueIdStr(ncclComm->getNcclId())); } } } @@ -562,26 +601,7 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { } if (asyncErrorHandling_) { - std::unique_lock lock(workListMutex_); - for (auto& work : workList_) { - work->checkAndSetException(); - // Aborting NCCL Communicators due to errors is already handled above. - if (work->exception()) { - continue; - } - - // Check for Timeouts in the WorkNCCL Operations, and abort all - // communicators accordingly. - if (work->timedOut()) { - std::exception_ptr exception_ptr = std::make_exception_ptr( - std::runtime_error("NCCL Operation Timed Out")); - work->setException(exception_ptr); - for (const auto& ncclComm : work->ncclComms_) { - ncclComm->ncclCommAbort(); - abortedCommIds.emplace(buildNcclUniqueIdStr(ncclComm->getNcclId())); - } - } - } + abortTimedOutCollectives(abortedCommIds); } if (blockingWait_) { @@ -595,8 +615,15 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { for (const auto& abortedCommId : abortedCommIds) { abortedComms_.emplace(abortedCommId); const auto& storeKey = getNcclAbortedCommStoreKey(abortedCommId); - store_->set(storeKey, {}); - LOG(INFO) << "Watchdog wrote aborted communicator id to store: " + auto rankStr = std::to_string(rank_); + store_->set( + storeKey, + std::vector( + reinterpret_cast(rankStr.data()), + reinterpret_cast(rankStr.data()) + + rankStr.size())); + LOG(INFO) << "[Rank " << rank_ + << "] Watchdog wrote aborted communicator id to store: " << storeKey; } @@ -610,7 +637,11 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { store_->wait( {storeKey}, std::chrono::milliseconds(kWaitForAbortCommStoreKey)); - LOG(INFO) << "Found key in store: " << storeKey + auto val = store_->get(storeKey); + std::string rank(reinterpret_cast(val.data()), val.size()); + LOG(INFO) << "[Rank " << rank_ + << "] Found key in store: " << storeKey + << ", from rank: " << rank << ", aborting appropriate communicators"; // Now abort the appropriate communicators. @@ -621,7 +652,9 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { ncclComm->ncclCommAbort(); } abortedComms_.emplace(commId); - LOG(INFO) << "Aborted communicators for key in store: " << storeKey; + LOG(INFO) << "[Rank " << rank_ + << "] Aborted communicators for key in store: " + << storeKey; } catch (std::exception& e) { VLOG(1) << "Did not find key in store: " << storeKey << ", error: " << e.what(); @@ -639,36 +672,38 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { } void ProcessGroupNCCL::workCleanupLoop() { - while (!terminateProcessGroup_.load()) { - std::unique_lock lock(workListMutex_); - // We busy-poll the work vector every kWatchdogThreadSleepMillis - // milliseconds as long as the atomic is True. - workListCV_.wait_for( - lock, - std::chrono::milliseconds(kWorkCleanupThreadSleepMillis), - [&]() -> bool { return terminateProcessGroup_.load(); }); - - for (auto it = workList_.begin(); it != workList_.end(); - /* no increment*/) { - auto& work = *it; - if (work->isCompleted()) { - // Handle Exceptions on failed GPU operations and remove completed - // workNCCL objects from work vector. - work->handleNCCLGuard(); - it = workList_.erase(it); - } else { - // Increment the iterator if the current WorkNCCL object is not - // completed. - ++it; + bool done = false; + while (!terminateProcessGroup_.load() || !done) { + std::list doneWorks; + { + std::unique_lock lock(workMetaListMutex_); + // We busy-poll the work vector every kWatchdogThreadSleepMillis + // milliseconds as long as the atomic is True. + workMetaListCV_.wait_for( + lock, + std::chrono::milliseconds(kWorkCleanupThreadSleepMillis), + [&]() -> bool { return terminateProcessGroup_.load(); }); + + for (auto it = workMetaList_.begin(); it != workMetaList_.end(); + /* no increment*/) { + auto& work = *it; + if (work.isCompleted()) { + // Handle Exceptions on failed GPU operations and remove completed + // workNCCL objects from work vector. + if (!terminateProcessGroup_.load()) { + work.handleNCCLGuard(); + } + doneWorks.push_back(std::move(*it)); + it = workMetaList_.erase(it); + } else { + // Increment the iterator if the current WorkNCCL object is not + // completed. + ++it; + } } + done = workMetaList_.empty(); } - - if (workList_.empty()) { - // Notify the main thread if it is blocked in the shutdown sequence, - // waiting for the work vector to become empty. - lock.unlock(); - workListCV_.notify_one(); - } + doneWorks.clear(); } } @@ -695,15 +730,32 @@ std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal( return nullptr; } -void ProcessGroupNCCL::broadcastUniqueNCCLID(ncclUniqueId* ncclID) { +void ProcessGroupNCCL::broadcastUniqueNCCLID( + ncclUniqueId* ncclID, + OpType opType, + const std::string& p2pKey, + int p2pRank) { + // For collective operations: // For every NCCL communicator that we create we need to broadcast // a unique ID from rank 0 to all other ranks. This broadcast is // done by rank 0 setting a key in the store and all other ranks // retrieving the contents of that key. A single process group // may create multiple NCCL communicators, so we use a sequence // number to differentiate between them. - std::string storeKey = std::to_string(ncclCommCounter_++); - if (rank_ == 0) { + // For point-to-point operations: + // The sequence number will only be increased on 2 out of all the + // processes in a Process Group. So all following collective + // operations will see different sequence numbers which will cause + // runtime errors. To avoid that, use the src:target pair instead + // of sequence number for p2p communications. + + std::string storeKey; + if (!isP2POp(opType)) { + storeKey = std::to_string(ncclCommCounter_++); + } else { + storeKey = p2pKey; + } + if (rank_ == 0 || (isP2POp(opType) && p2pRank == 0)) { auto vec = std::vector( reinterpret_cast(ncclID), reinterpret_cast(ncclID) + NCCL_UNIQUE_ID_BYTES); @@ -717,7 +769,10 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID(ncclUniqueId* ncclID) { std::vector>& ProcessGroupNCCL::getNCCLComm( const std::string& devicesKey, - const std::vector& devices) { + const std::vector& devices, + OpType opType, + int p2pRank, + bool isSendRecvSelf) { // Sanity check if (devicesKey.empty()) { throw std::runtime_error( @@ -744,25 +799,61 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( // Create the unique NCCL ID and broadcast it ncclUniqueId ncclID; - if (rank_ == 0) { + // For point-to-point communication, lower rank of the two will get unique id. + if (rank_ == 0 || (isP2POp(opType) && p2pRank == 0)) { C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID)); } - // Broadcast so that each process can have a unique NCCL ID - broadcastUniqueNCCLID(&ncclID); + // For point-to-point communication on the same process, don't need broadcast. + if (!isSendRecvSelf) { + // Broadcast so that each process can have a unique NCCL ID + broadcastUniqueNCCLID(&ncclID, opType, devicesKey, p2pRank); + } at::cuda::OptionalCUDAGuard gpuGuard; std::vector streamVal; streamVal.reserve(devices.size()); - // Create the NCCL communicators for each GPU + // [Group Start/End Note] This is used to ensure that nccl communicator will be created + // before communication primitives are called. Let's look at this example: + // Using the batch_isend_irecv to send a tensor to a target process. On the sender side, + // the corresponding underlying NCCL calls will look like + // ncclGroupStart() // This is in batch_isend_irecv + // ncclGroupStart() // This is [Note 1] + // ncclCommInitRank() // Inside NCCLComm::create + // ncclSend() + // ncclGroupEnd() // This is [Note 2] + // ncclGroupEnd() // This is in batch_isend_irecv + // With this pattern, the nccl communicator will be created in the last ncclGroupEnd + // which means when ncclSend is processed, the passed communicator argument is NULL which will + // lead to runtime error. So we need to "close" all active nccl groups to ensure + // nccl communicator is actually created before encountering any communication calls. + // This is why we need the following for loop. + for (size_t i = 0; i < ncclActiveGroupCounter_; ++i) { + C10D_NCCL_CHECK(ncclGroupEnd()); + } + + // [Note 1] Create the NCCL communicators for each GPU C10D_NCCL_CHECK(ncclGroupStart()); for (size_t i = 0; i < devices.size(); ++i) { // GPU world size and GPU rank - int numRanks = getSize() * devices.size(); - int rank = getRank() * devices.size() + i; + int numRanks, rank; + + if (!isP2POp(opType)) { + numRanks = getSize() * devices.size(); + rank = getRank() * devices.size() + i; + } else if(isSendRecvSelf) { + // Same process send and recv. + numRanks = 1; + rank = 0; + } else { + // For point-to-point operation, there are only 2 processes involved so + // the GPU rank is either 0 or 1. + numRanks = 2; + rank = p2pRank; + } // Get the device index int deviceIndex = devices[i].index(); @@ -771,18 +862,16 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( // Creates the NCCL streams streamVal.push_back(at::cuda::getStreamFromPool(isHighPriorityStream_)); - - // If not set before, get a dedicated stream for the device to run - // FutureNCCL then callbacks. - std::lock_guard lock(mutex_); - if (futureNCCLCallbackStreams_[deviceIndex] == nullptr) { - futureNCCLCallbackStreams_[deviceIndex] = - std::make_shared(at::cuda::getStreamFromPool(isHighPriorityStream_)); - } } + // [Note 2 ] C10D_NCCL_CHECK(ncclGroupEnd()); + // See [Group Start/End Note] + for (size_t i = 0; i < ncclActiveGroupCounter_; ++i) { + C10D_NCCL_CHECK(ncclGroupStart()); + } + ncclStreams_.emplace(devicesKey, std::move(streamVal)); // Note: these events are created with the (default) cudaEventDisableTiming @@ -901,9 +990,12 @@ std::vector flatten_for_scatter_gather( } // namespace -std::shared_ptr ProcessGroupNCCL::initWork( - std::vector devices) { - return std::make_shared(devices); +c10::intrusive_ptr ProcessGroupNCCL::initWork( + std::vector devices, + int rank, + OpType opType, + const char* profilingTitle) { + return c10::make_intrusive(devices, rank, opType, profilingTitle); } std::vector ProcessGroupNCCL::WorkNCCL::result() { @@ -912,50 +1004,46 @@ std::vector ProcessGroupNCCL::WorkNCCL::result() { c10::intrusive_ptr ProcessGroupNCCL::WorkNCCL:: getFuture() { - TORCH_INTERNAL_ASSERT( - outputs_->size() == 1, - "WorkNCCL's getFuture API is only supported for single-process single-device mode."); - auto deviceIndex = (*outputs_)[0].device().index(); - // Create a new FutureNCCL object after checking for single-process - // single-device mode. - return c10::make_intrusive( - at::IValue(*outputs_), - deviceIndex, - cudaEvents_, - futureNCCLCallbackStreams_[deviceIndex]); + return future_; } void ProcessGroupNCCL::workEnqueue( - std::shared_ptr work) { + c10::intrusive_ptr work) { if (!terminateProcessGroup_.load()) { - std::lock_guard lock(workListMutex_); - workList_.emplace_back(std::move(work)); + std::lock_guard lock(workMetaListMutex_); + // Avoid view tensors to be processed in cleanup thread. + // View tensors' destruction invokes autograd_meta, which + // needs to be destructed in user thread. Otherwise will + // get deadlock. Here we enqueue work without outputs_. + workMetaList_.emplace_back(WorkNCCL(*work)); } } ProcessGroupNCCL::Options::Options() - : opTimeout(kProcessGroupNCCLOpTimeoutMillis), isHighPriorityStream(false) {} + : opTimeout(kProcessGroupNCCLOpTimeoutMillis), + isHighPriorityStream(false) {} template -std::shared_ptr ProcessGroupNCCL::collective( +c10::intrusive_ptr ProcessGroupNCCL::collective( std::vector& inputs, std::vector& outputs, Fn fn, PreProcess pre, - PostProcess post) { + PostProcess post, + OpType opType, + const char* profilingTitle) { const auto devices = getDeviceList(inputs); const auto key = getKeyFromDevices(devices); - auto& ncclComms = getNCCLComm(key, devices); + auto& ncclComms = getNCCLComm(key, devices, opType); // First let NCCL streams wait for input tensors allocation streams syncStreams(devices, ncclEvents_[key], ncclStreams_[key]); // Work itself will create the CUDA events on all GPUs of tensors - auto work = initWork(devices); + bool can_profile = outputs.size() == 1; + auto work = initWork(devices, rank_, opType, can_profile ? profilingTitle : nullptr); - // Store references to outputs and futureNCCLCallbackStream to be used by - // WorkNCCL::getFuture. + // Store references to outputs to be used by WorkNCCL::result and operator<<. work->outputs_ = std::make_shared>(outputs); - work->futureNCCLCallbackStreams_ = futureNCCLCallbackStreams_; at::cuda::OptionalCUDAGuard gpuGuard; @@ -994,32 +1082,148 @@ std::shared_ptr ProcessGroupNCCL::collective( at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; (*work->cudaEvents_)[i].record(ncclStream); work->ncclComms_[i] = ncclComms[i]; + } + + { + at::cuda::CUDAMultiStreamGuard streamGuard(ncclStreams_[key]); + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get())); + work->future_->markCompleted(at::IValue(*work->outputs_)); + } + + // Set appropriate work parameters. + work->blockingWait_ = blockingWait_; + work->opTimeout_ = opTimeout_; + work->store_ = store_; + + if (work->recordFunctionEndCallback_) { + // recordFunctionEndCallback_ is normally called in fininsh() function by + // base class, but since finish is not called by WorkNCCL, we schedule this + // function to be run when work is done. Note that addCallback() onto the + // Work's CUDAFuture is not useful here, as it would just run the callback + // inline. + // Note when can_profile is false, profilingTitle is not provided and so, + // recordFunctionEndCallback_ is not set. + work->recordFunctionEndCallback_(); + } + + if (asyncErrorHandling_) { + workEnqueue(work); + } + + return work; +} + +template +c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( + std::vector& tensors, + Fn fn, + int peer, + OpType opType, + PreProcess pre, + PostProcess post) { + const auto devices = getDeviceList(tensors); + const auto key = getKeySendRecv(rank_, peer); + int p2pRank = rank_ <= peer ? 0 : 1; + auto isSendRecvSelf = rank_ == peer; + auto& ncclComms = getNCCLComm(key, devices, opType, p2pRank, isSendRecvSelf); + + // First let NCCL streams wait for input tensors allocation streams + syncStreams(devices, ncclEvents_[key], ncclStreams_[key]); + + // Work itself will create the CUDA events on all GPUs of tensors + auto work = initWork(devices, rank_, opType); + + if (opType == OpType::RECV) { + // Store references to outputs to be used by WorkNCCL::result and operator<<. + work->outputs_ = std::make_shared>(tensors); + } + + at::cuda::OptionalCUDAGuard gpuGuard; + + pre(ncclStreams_[key]); + + for (size_t i = 0; i < tensors.size(); ++i) { + gpuGuard.set_index(devices[i].index()); + at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; + + // Both send tensor and recv tensor are created on a worker stream and used in + // different ncclStreams. Hence, both must record the ncclStream to + // prevent being freed before the collective finishes. + // + // See [Sync Streams]. + c10::cuda::CUDACachingAllocator::recordStream( + tensors[i].storage().data_ptr(), ncclStream); + } + + { + AutoNcclGroup nccl_group_guard; + for (size_t i = 0; i < tensors.size(); ++i) { + gpuGuard.set_index(devices[i].index()); + at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; + // For point-to-point communication, NCCL ranks can only + // be 0 or 1. + int p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank; + C10D_NCCL_CHECK( + fn(tensors[i], ncclComms[i]->getNcclComm(), ncclStream, p2pTargetRank)); + } + } + + post(ncclStreams_[key]); + + // Event should only be recorded after the ncclGroupEnd() + for (size_t i = 0; i < tensors.size(); ++i) { + at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; + (*work->cudaEvents_)[i].record(ncclStream); + work->ncclComms_[i] = ncclComms[i]; work->blockingWait_ = blockingWait_; work->opTimeout_ = opTimeout_; work->store_ = store_; } - if (asyncErrorHandling_) { - workEnqueue(work); + if (opType == OpType::RECV) { + at::cuda::CUDAMultiStreamGuard streamGuard(ncclStreams_[key]); + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get())); + work->future_->markCompleted(at::IValue(*work->outputs_)); } return work; } template -std::shared_ptr ProcessGroupNCCL::collective( +c10::intrusive_ptr ProcessGroupNCCL::collective( std::vector& inputs, std::vector& outputs, - Fn fn) { + Fn fn, + OpType opType, + const char* profilingTitle) { return collective( inputs, outputs, fn, [](std::vector&) {}, + [](std::vector&) {}, + opType, + profilingTitle); +} + +template +c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( + std::vector& tensor, + Fn fn, + int peer, + OpType opType) { + return pointToPoint( + tensor, + fn, + peer, + opType, + [](std::vector&) {}, [](std::vector&) {}); } -std::shared_ptr ProcessGroupNCCL::allreduce( +c10::intrusive_ptr ProcessGroupNCCL::allreduce( std::vector& tensors, const AllreduceOptions& opts) { check_gpu_tensors(tensors); @@ -1039,17 +1243,19 @@ std::shared_ptr ProcessGroupNCCL::allreduce( getNcclReduceOp(opts.reduceOp, input), comm, stream.stream()); - }); + }, + OpType::ALLREDUCE, + "nccl:all_reduce"); } -std::shared_ptr ProcessGroupNCCL::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupNCCL::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { throw std::runtime_error( "allreduce_coalesced is currently not supported with NCCL"); } -std::shared_ptr ProcessGroupNCCL::broadcast( +c10::intrusive_ptr ProcessGroupNCCL::broadcast( std::vector& tensors, const BroadcastOptions& opts) { check_gpu_tensors(tensors); @@ -1069,10 +1275,12 @@ std::shared_ptr ProcessGroupNCCL::broadcast( root, comm, stream.stream()); - }); + }, + OpType::BROADCAST, + "nccl:broadcast"); } -std::shared_ptr ProcessGroupNCCL::reduce( +c10::intrusive_ptr ProcessGroupNCCL::reduce( std::vector& tensors, const ReduceOptions& opts) { check_gpu_tensors(tensors); @@ -1094,10 +1302,12 @@ std::shared_ptr ProcessGroupNCCL::reduce( root, comm, stream.stream()); - }); + }, + OpType::REDUCE, + "nccl:reduce"); } -std::shared_ptr ProcessGroupNCCL::allgather( +c10::intrusive_ptr ProcessGroupNCCL::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts) { @@ -1137,10 +1347,12 @@ std::shared_ptr ProcessGroupNCCL::allgather( outputTensors[i][j].copy_(outputFlattened[i][j], true); } } - }); + }, + OpType::ALLGATHER, + "nccl:all_gather"); } -std::shared_ptr ProcessGroupNCCL::allgather_coalesced( +c10::intrusive_ptr ProcessGroupNCCL::allgather_coalesced( std::vector>& /* unused */, std::vector& /* unused */, const AllgatherOptions& /* unused */) { @@ -1148,7 +1360,7 @@ std::shared_ptr ProcessGroupNCCL::allgather_coalesced( "ProcessGroupNCCL does not support allgather_coalesced"); } -std::shared_ptr ProcessGroupNCCL::reduce_scatter( +c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts) { @@ -1189,13 +1401,21 @@ std::shared_ptr ProcessGroupNCCL::reduce_scatter( } } }, - [&](std::vector& ncclStreams) {}); + [&](std::vector& ncclStreams) {}, + OpType::REDUCE_SCATTER, + "nccl:reduce_scatter"); } -std::shared_ptr ProcessGroupNCCL::barrier( +c10::intrusive_ptr ProcessGroupNCCL::barrier( const BarrierOptions& opts) { std::vector devices; - if (usedDeviceIdxs_.empty()) { + + // Use user defined GPU device ids if provided + if (!opts.device_ids.empty()) { + for (auto device : opts.device_ids) { + devices.emplace_back(at::DeviceType::CUDA, device); + } + } else if (usedDeviceIdxs_.empty()) { // This means there is not yet a NCCL collective being called // Here we have to use the best guesses and will use a single GPU to call // allreduce to achieve barrier. @@ -1203,10 +1423,10 @@ std::shared_ptr ProcessGroupNCCL::barrier( // ensure that each process is on a different GPU auto numGPUs = at::cuda::getNumGPUs(); int16_t deviceIdx = static_cast(rank_ % numGPUs); - devices.push_back(at::Device(at::DeviceType::CUDA, deviceIdx)); + devices.emplace_back(at::DeviceType::CUDA, deviceIdx); } else { for (auto usedDeviceIdx : usedDeviceIdxs_) { - devices.push_back(at::Device(at::DeviceType::CUDA, usedDeviceIdx)); + devices.emplace_back(at::DeviceType::CUDA, usedDeviceIdx); } } @@ -1233,7 +1453,7 @@ std::shared_ptr ProcessGroupNCCL::barrier( } #ifdef ENABLE_NCCL_P2P_SUPPORT -std::shared_ptr ProcessGroupNCCL::alltoall_base( +c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -1251,15 +1471,19 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - return ncclAlltoall( - input.data_ptr(), - output.data_ptr(), - input.numel() / size_, - input.element_size(), - getNcclDataType(input.scalar_type()), + // See [Sync Streams]. + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + torch::cuda::nccl::all2all( + input, + output, + this->getSize(), comm, - stream.stream()); - }); + stream); + return ncclSuccess; + }, + OpType::ALLTOALL_BASE, + "nccl:all_to_all"); } else { c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); @@ -1280,6 +1504,9 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( inputSplitSizes, input, &send_lengths, &send_offsets); c10d::computeLengthsAndOffsets( outputSplitSizes, output, &recv_lengths, &recv_offsets); + // See [Sync Streams]. + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); return ncclAlltoallv( input.data_ptr(), send_lengths.data(), @@ -1291,11 +1518,51 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( getNcclDataType(input.scalar_type()), comm, stream.stream()); - }); + }, + OpType::ALLTOALL_BASE, + "nccl:all_to_all"); } } + +c10::intrusive_ptr ProcessGroupNCCL::send( + std::vector& tensors, + int dstRank, + int /* unused */) { + check_gpu_tensors(tensors); + auto ret = pointToPoint( + tensors, + [&](at::Tensor& input, + ncclComm_t comm, + at::cuda::CUDAStream& stream, + int dst) { + torch::cuda::nccl::send(input, comm, stream, dst); + return ncclSuccess; + }, + dstRank, + OpType::SEND); + return ret; +} + +c10::intrusive_ptr ProcessGroupNCCL::recv( + std::vector& tensors, + int srcRank, + int /* unused */) { + check_gpu_tensors(tensors); + auto ret = pointToPoint( + tensors, + [&](at::Tensor& output, + ncclComm_t comm, + at::cuda::CUDAStream& stream, + int src) { + torch::cuda::nccl::recv(output, comm, stream, src); + return ncclSuccess; + }, + srcRank, + OpType::RECV); + return ret; +} #else -std::shared_ptr ProcessGroupNCCL::alltoall_base( +c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& /* unused */, at::Tensor& /* unused */, std::vector& /* unused */, @@ -1304,50 +1571,66 @@ std::shared_ptr ProcessGroupNCCL::alltoall_base( throw std::runtime_error( "ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0"); } + +c10::intrusive_ptr ProcessGroupNCCL::send( + std::vector& /* unused */, + int /* unused */, + int /* unused */) { + throw std::runtime_error( + "ProcessGroupNCCL only supports send for NCCL lib version >= 2.7.0"); +} + +c10::intrusive_ptr ProcessGroupNCCL::recv( + std::vector& /* unused */, + int /* unused */, + int /* unused */) { + throw std::runtime_error( + "ProcessGroupNCCL only supports recv for NCCL lib version >= 2.7.0"); +} #endif -std::shared_ptr ProcessGroupNCCL::alltoall( +void ProcessGroupNCCL::groupStart() { +#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) + C10D_NCCL_CHECK(ncclGroupStart()); +#endif + ++ncclActiveGroupCounter_; +} + +void ProcessGroupNCCL::groupEnd() { +#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) + C10D_NCCL_CHECK(ncclGroupEnd()); +#endif + --ncclActiveGroupCounter_; +} + +c10::intrusive_ptr ProcessGroupNCCL::alltoall( std::vector& /* unused */, std::vector& /* unused */, const AllToAllOptions& /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support alltoall"); } -std::shared_ptr ProcessGroupNCCL::gather( +c10::intrusive_ptr ProcessGroupNCCL::gather( std::vector>& /* unused */, std::vector& /* unused */, const GatherOptions& /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support gather"); } -std::shared_ptr ProcessGroupNCCL::scatter( +c10::intrusive_ptr ProcessGroupNCCL::scatter( std::vector& /* unused */, std::vector>& /* unused */, const ScatterOptions& /* unused */) { throw std::runtime_error("ProcessGroupNCCL does not support scatter"); } -std::shared_ptr ProcessGroupNCCL::send( - std::vector& /* unused */, - int /* unused */, - int /* unused */) { - throw std::runtime_error("ProcessGroupNCCL does not support send"); -} - -std::shared_ptr ProcessGroupNCCL::recv( - std::vector& /* unused */, - int /* unused */, - int /* unused */) { - throw std::runtime_error("ProcessGroupNCCL does not support recv"); -} - -std::shared_ptr ProcessGroupNCCL::recvAnysource( +c10::intrusive_ptr ProcessGroupNCCL::recvAnysource( std::vector& /* unused */, int /* unused */) { - throw std::runtime_error("ProcessGroupNCCL does not support recv"); + throw std::runtime_error("ProcessGroupNCCL does not support recvAnysource"); } -std::shared_ptr ProcessGroupNCCL::allgather_base( +c10::intrusive_ptr ProcessGroupNCCL::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index 5a5e5a718ad81..4d9dc3bd1ae89 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -1,5 +1,7 @@ #pragma once +#include +#include #include #include #include @@ -11,7 +13,14 @@ #include #include +#include +#include +#include #include +#include +#include + +#include namespace c10d { @@ -61,10 +70,15 @@ constexpr const char* NCCL_ASYNC_ERROR_HANDLING = "NCCL_ASYNC_ERROR_HANDLING"; class ProcessGroupNCCL : public ProcessGroup { public: class WorkNCCL : public ProcessGroup::Work, - public std::enable_shared_from_this { + public std::enable_shared_from_this { public: // Constructor takes a list of CUDA devices - WorkNCCL(const std::vector& devices); + WorkNCCL(const std::vector& devices, int rank, OpType opType, const char* profilingTitle = nullptr); + // Copy constructor doing partial copy without outputs_. Cleanup thread + // monitors and removes finished works. However it will deadlock when + // destructs outputs_ tensors who are view tensors in autograd graph. + WorkNCCL(const WorkNCCL& w); + virtual ~WorkNCCL(); // Checks if request has completed. In this specific case of NCCL, it checks @@ -96,7 +110,7 @@ class ProcessGroupNCCL : public ProcessGroup { bool finishedGPUExecution(); // Get a Future object that will be marked as completed internally. - // It actually returns a FutureNCCL object which is a sub class Future. + // It actually returns a CUDAFuture object which is a sub class of Future. c10::intrusive_ptr getFuture() override; // Helper function that sets an exception_ptr on the WorkNCCL object. @@ -135,6 +149,10 @@ class ProcessGroupNCCL : public ProcessGroup { virtual std::exception_ptr checkForNCCLErrors( const std::vector>& ncclComms) const; + friend std::ostream& operator<<( + std::ostream& output, + const WorkNCCL& workNCCL); + private: // Helper function for synchronize void synchronizeInternal(std::chrono::milliseconds timeout); @@ -150,184 +168,30 @@ class ProcessGroupNCCL : public ProcessGroup { // Reference to the store so that we can write aborted communicators // to the store. - std::shared_ptr store_; + c10::intrusive_ptr store_; - // Store a reference to NCCL collective's outputs to be used by getFuture. + // Store a reference to NCCL collective's outputs, used by result and to + // give a more descriptive message when representing the Work as a string. std::shared_ptr> outputs_; - // Store streams that run FutureNCCL then callbacks. - std::vector> - futureNCCLCallbackStreams_; + + // The future returned by getFuture. + c10::intrusive_ptr future_; friend class ProcessGroupNCCL; }; - struct Options { + struct Options : torch::CustomClassHolder { explicit Options(); - std::chrono::milliseconds opTimeout; - bool isHighPriorityStream; - }; - - // FutureNCCL is a subclass of ivalue's Future. The goal is to use - // this class in getFuture API of WorkNCCL. This Future is mostly a - // wrapper to synchronize streams appropriately and it mostly enables - // the async programming model of CUDA while trying to adhere to the - // Future interface. FutureNCCL does not support NCCL_BLOCKING_WAIT flag - // or NCCL's barrier(). - // - // If created by WorkNCCL's getFuture API, FutureNCCL has a reference to - // WorkNCCL's cudaEvents, NCCL collective's outputs, device index of - // outputs' device, and the ProcesGroupNCCL's dedicated - // futureNCCLCallbackStream for outputs' device that runs all the then - // callbacks called from this FutureNCCL. Its value is NCCL collective's - // outputs. FutureNCCL only supports single-process single-device mode where - // the size of outputs is equal to 1. - // - // If created by FutureNCCL's then callback, its value becomes the value of - // callback() and its cudaEvents will record the NCCL stream that runs that - // callback. Before invoking the callback, FutureNCCL will synchronize its - // own cudaEvents with the stream that runs the callback. This design - // enables synchronizing the appropriate streams and avoids stalling PyTorch's - // default stream while running the callback. In case of multiple then - // callbacks, the design will work like a chain such that FutureNCCL n will - // wait on the cudaEvents from FutureNCCL n - 1. All callbacks are executed on - // outputs' device's dedicated futureNCCLCallbackStream. - struct FutureNCCL : at::ivalue::Future { - public: - explicit FutureNCCL( - at::IValue value, - c10::DeviceIndex deviceIndex, - std::shared_ptr> cudaEvents, - std::shared_ptr futureNCCLCallbackStream) - : at::ivalue::Future(c10::ListType::create(c10::TensorType::get())), - value_(std::move(value)), - deviceIndex_(deviceIndex), - cudaEvents_(cudaEvents), - futureNCCLCallbackStream_(futureNCCLCallbackStream) { - TORCH_INTERNAL_ASSERT( - cudaEvents_->size() == 1, - "FutureNCCL only supports single-process single-device mode."); - } - - // This constructor is used by then callback, it skips setting the value at - // the beginning. Later, the value will be set using markCompleted with the - // return value of callback. - explicit FutureNCCL( - c10::DeviceIndex deviceIndex, - std::shared_ptr> cudaEvents, - std::shared_ptr futureNCCLCallbackStream) - : at::ivalue::Future(c10::ListType::create(c10::TensorType::get())), - deviceIndex_(deviceIndex), - cudaEvents_(cudaEvents), - futureNCCLCallbackStream_(futureNCCLCallbackStream) { - TORCH_INTERNAL_ASSERT( - cudaEvents_->size() == 1, - "FutureNCCL only supports single-process single-device mode."); - } - - // Gets the current stream of the device and synchronizes recorded streams - // with that. It will return after synchronizing the correct GPU streams to - // ensure we can have async CUDA execution and it does not wait for the - // entire operation to complete on GPU. - void wait() override { - if (error_) { - throw *error_; - } - auto stream = at::cuda::getCurrentCUDAStream(deviceIndex_); - (*cudaEvents_)[0].block(stream); - } - - // If FutureNCCL was created by FutureNCCL::then, its value would be empty - // initially. FutureNCCL::then will later use this method to set its value - // to the return value of the callback. - void markCompleted(at::IValue value) override { - TORCH_INTERNAL_ASSERT( - value_.isNone(), - "Attempting to set value of a FutureNCCL which has a value." - "FutureNCCL's value was internally set to NCCL collective's " - "outputs or the return value of the callback."); - value_ = std::move(value); + // return intrusive_ptr of the object + static c10::intrusive_ptr create( + std::chrono::milliseconds timeout = kNoTimeout, + bool isHighStream = false) { + return c10::make_intrusive(); } - // Just returns FutureNCCL's value after wait returns. - at::IValue value() override { - TORCH_INTERNAL_ASSERT(hasValue(), "FutureNCCL's value is None.") - wait(); - return value_; - } - - const at::IValue& constValue() override { - TORCH_INTERNAL_ASSERT(hasValue(), "FutureNCCL's value is None.") - wait(); - return value_; - } - - // Adds a callback to FutureNCCL. It invokes the callback inline after - // synchronizing FutureNCCL's own cudaEvents with the stream that runs - // this callback. This new FutureNCCL's cudaEvents will record the - // callback's stream and will have the result value of the callback. - void addCallback(std::function callback) override { - (*cudaEvents_)[0].block(*futureNCCLCallbackStream_); - c10::OptionalStreamGuard streamGuard{ - c10::Stream(*futureNCCLCallbackStream_)}; - callback(); - } - - // Adds a callback to FutureNCCL, and returns another FutureNCCL to hold - // the return value of the callback and new cudaEvents that recorded the - // stream that runs this callback. - c10::intrusive_ptr then( - std::function callback, - at::TypePtr /* unused */) override { - // Create a new cudaEvents object of size 1 that will record - // futureNCCLCallbackStream_ after callback and will be passed to the new - // FutureNCCL. - auto thenFutCudaEvents = - std::make_shared>(1); - // Create a FutureNCCL without setting a value. - auto fut = c10::make_intrusive( - deviceIndex_, thenFutCudaEvents, futureNCCLCallbackStream_); - - // Use the dedicated callback stream to run callback. - // Cannot move capture std::function in lambda, because it cannot deduce - // the template type for std::function. Hence use std::bind to explicitly - // specify types. - addCallback(std::bind( - [&](std::function cb) { - try { - fut->markCompleted(at::IValue(cb())); - // In case of chained then callback calls, thenFutCudaEvents - // records callback's stream. - (*thenFutCudaEvents)[0].record(*futureNCCLCallbackStream_); - } catch (const std::exception& e) { - fut->setError(std::current_exception()); - } - }, - std::move(callback))); - return fut; - } - - // Checks cudaEventQuery with cudaEvents. Returns true if a FutureError was - // recorded or the entire operation is completed on the GPU. - bool completed() const override { - if (error_) { - return true; - } - // Checking the work's corresponding CUDA events' status - auto ret = cudaEventQuery((*cudaEvents_)[0]); - return ret != cudaErrorNotReady || ret == cudaSuccess; - } - - bool hasValue() const override { - return !value_.isNone(); - } - - private: - at::IValue value_; - c10::DeviceIndex deviceIndex_; - std::shared_ptr> cudaEvents_; - std::shared_ptr futureNCCLCallbackStream_; - c10::optional error_; + std::chrono::milliseconds opTimeout; + bool isHighPriorityStream; }; // If you wish to create multiple process groups, each with a potentially @@ -345,98 +209,102 @@ class ProcessGroupNCCL : public ProcessGroup { // communicator. These NCCL communicators are cached and reused if possible. // ProcessGroupNCCL( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, - Options options = Options()); + c10::intrusive_ptr options = Options::create()); // This constructor includes the deprecated `groupName` argument. // If you have existing code that uses the `groupName`, you can replace // it by specifying a `c10d::PrefixStore(groupName, store)` for store. C10_DEPRECATED ProcessGroupNCCL( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, const std::string& groupName, - Options options = Options()) + c10::intrusive_ptr options = Options::create()) : ProcessGroupNCCL(store, rank, size, options) {} virtual ~ProcessGroupNCCL(); - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputbuffer, at::Tensor& inputbuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_coalesced( + c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; - std::shared_ptr alltoall_base( + c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr alltoall( + c10::intrusive_ptr alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& opts = AllToAllOptions()) override; + c10::intrusive_ptr send( + std::vector& tensors, + int dstRank, + int tag) override; + + c10::intrusive_ptr recv( + std::vector& tensors, + int srcRank, + int tag) override; + + static void groupStart(); + + static void groupEnd(); + // Unsupported Ops - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr send( - std::vector& tensors, - int dstRank, - int tag) override; - - std::shared_ptr recv( - std::vector& tensors, - int srcRank, - int tag) override; - - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensors, int tag) override; @@ -444,20 +312,30 @@ class ProcessGroupNCCL : public ProcessGroup { protected: // Helper that broadcasts nccl unique ID to all ranks through the store - void broadcastUniqueNCCLID(ncclUniqueId* ncclID); + void broadcastUniqueNCCLID( + ncclUniqueId* ncclID, + OpType opType, + const std::string& devicesKey, + int p2pRank); // Helper that either looks up the cached NCCL communicators or creates // a new set of NCCL communicators as a cache entry std::vector>& getNCCLComm( const std::string& devicesKey, - const std::vector& devices); + const std::vector& devices, + OpType opType, + int p2pRank = 0, + bool isSendRecvSelf = false); // Wrapper method which can be overridden for tests. virtual std::exception_ptr checkForNCCLErrors( const std::vector>& ncclComms); - virtual std::shared_ptr initWork( - std::vector devices); + virtual c10::intrusive_ptr initWork( + std::vector devices, + int rank, + OpType opType, + const char* profilingTitle=nullptr); private: // Helper that encapsulates work shared across all collective communication @@ -467,16 +345,38 @@ class ProcessGroupNCCL : public ProcessGroup { // ncclComm_t, at::cuda::CUDAStream&); // void {pre,post}(std::vector); template - std::shared_ptr collective( + c10::intrusive_ptr collective( std::vector& input, std::vector& output, - Fn fn); + Fn fn, + OpType opType, + const char* profilingTitle = nullptr); template - std::shared_ptr collective( + c10::intrusive_ptr collective( std::vector& input, std::vector& output, Fn fn, PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle = nullptr); + + // Helper that encapsulates work shared across point-to-point communication + // primitives. It is the same structure as the helper used for collective + // communicaiton primitives. + template + c10::intrusive_ptr pointToPoint( + std::vector& tensor, + Fn fn, + int peer, + OpType opType); + template + c10::intrusive_ptr pointToPoint( + std::vector& tensor, + Fn fn, + int peer, + OpType opType, + PreProcess pre, PostProcess post); // Checks for NCCL errors on each of the communicators and returns an @@ -497,13 +397,11 @@ class ProcessGroupNCCL : public ProcessGroup { void ncclCommWatchdogInternal(); - // Reads the NCCL_BLOCKING_WAIT environment variable and sets blockingWait_ - // accordingly. - void parseNcclBlockingWait(); - - // Reads the NCCL_ASYNC_ERROR_HANDLING environment variable and sets asyncErrorHandling_ - // accordingly. - void parseNcclAsyncErrorHandling(); + // This function iterates through the list of WorkNCCL objects in the + // workList_ corresponding to incomplete collectives and then aborts NCCL + // communicators associated with timed out collectives. + void abortTimedOutCollectives( + std::unordered_set& abortedCommIds); void workCleanupLoop(); @@ -512,7 +410,7 @@ class ProcessGroupNCCL : public ProcessGroup { static const int64_t kWorkCleanupThreadSleepMillis; // The store is used to broadcast the NCCL unique ID of rank 0. - std::shared_ptr store_; + c10::intrusive_ptr store_; // The number of NCCL communicators that have been created during // the lifetime of this process group. This sequence number is @@ -520,6 +418,8 @@ class ProcessGroupNCCL : public ProcessGroup { uint64_t ncclCommCounter_{0}; // The NCCL communicator that the process group has cached. + // + // For collective operations: // The key is a list of GPU devices that an operation is operating on // The GPU devices are stored in a device sequence and the cache NCCL // communicator is associated with this GPU device sequence @@ -538,6 +438,13 @@ class ProcessGroupNCCL : public ProcessGroup { // "0,4,5,6,7,1,2,3" // // Note that the order of the device for the tensor list matters. + // + // For point-to-point operations: + // The key is a string of my current rank and the peer process rank. + // e.g. If process 1 and process 2 are involved in a point-to-point + // communication, the key will be "1:2" on both processes. Note: this is for + // the scenario where there is only 1 GPU per process. When it comes to + // multiple GPUs per process, this part may need to redesigned. std::unordered_map>> devNCCLCommMap_; @@ -563,17 +470,17 @@ class ProcessGroupNCCL : public ProcessGroup { // Thread that removes NCCL Work upon timeout std::thread workCleanupThread_; - // Mutex to Guard workList_ - std::mutex workListMutex_; + // Mutex to Guard workMetaList_ + std::mutex workMetaListMutex_; // Condition Variable for timeout thread sleep - std::condition_variable workListCV_; + std::condition_variable workMetaListCV_; // Vector to Store WorkNCCL pointers - std::list> workList_; + std::list workMetaList_; // Add Work Pointer to workVector - void workEnqueue(std::shared_ptr); + void workEnqueue(c10::intrusive_ptr); // The CUDA steams used by NCCL kernels std::unordered_map> @@ -622,18 +529,13 @@ class ProcessGroupNCCL : public ProcessGroup { // set contains the string representation of ncclUniqueId. std::unordered_set abortedComms_; - // In single-process single-device mode, WorkNCCL::getFuture is supported. - // Depending on the device index of collective outputs, WorkNCCL will pass - // the corresponding device's then callback stream to FutureNCCL. - // We just inititalize futureNCCLCallbackStreams_ inside the constructor and - // set its size to the total number of available devices and depending on the - // device of the NCCL collective's outputs, we later set the callback stream - // of the corresponding device inside ProcessGroupNCCL::getNCCLComm if not set - // before. - std::vector> futureNCCLCallbackStreams_; - // Schedule NCCL operations on high priority CUDA streams. bool isHighPriorityStream_ = false; + + // The number of active ncclGroupStart() calls. This counter will be increased + // by 1 when ncclGroupStart() is called and decreased by 1 when ncclGroupEnd() + // is called. + static thread_local uint64_t ncclActiveGroupCounter_; }; } // namespace c10d diff --git a/torch/lib/c10d/ProcessGroupRoundRobin.cpp b/torch/lib/c10d/ProcessGroupRoundRobin.cpp index 032f63c320f5c..455c1654f587e 100644 --- a/torch/lib/c10d/ProcessGroupRoundRobin.cpp +++ b/torch/lib/c10d/ProcessGroupRoundRobin.cpp @@ -5,7 +5,7 @@ namespace c10d { ProcessGroupRoundRobin::ProcessGroupRoundRobin( int rank, int size, - std::vector> processGroups) + std::vector> processGroups) : ProcessGroup(rank, size), processGroups_(std::move(processGroups)) { TORCH_CHECK(processGroups_.size() >= 1); for (const auto& processGroup : processGroups_) { @@ -17,66 +17,66 @@ ProcessGroupRoundRobin::ProcessGroupRoundRobin( ProcessGroupRoundRobin::~ProcessGroupRoundRobin() {} -std::shared_ptr ProcessGroupRoundRobin::broadcast( +c10::intrusive_ptr ProcessGroupRoundRobin::broadcast( std::vector& tensors, const BroadcastOptions& opts) { return next()->broadcast(tensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::allreduce( +c10::intrusive_ptr ProcessGroupRoundRobin::allreduce( std::vector& tensors, const AllreduceOptions& opts) { return next()->allreduce(tensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::allreduce_coalesced( +c10::intrusive_ptr ProcessGroupRoundRobin::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { return next()->allreduce_coalesced(tensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::reduce( +c10::intrusive_ptr ProcessGroupRoundRobin::reduce( std::vector& tensors, const ReduceOptions& opts) { return next()->reduce(tensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::allgather( +c10::intrusive_ptr ProcessGroupRoundRobin::allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts) { return next()->allgather(outputs, inputs, opts); }; -std::shared_ptr ProcessGroupRoundRobin::allgather_coalesced( +c10::intrusive_ptr ProcessGroupRoundRobin::allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts) { return next()->allgather(outputTensorLists, inputTensors, opts); } -std::shared_ptr ProcessGroupRoundRobin::gather( +c10::intrusive_ptr ProcessGroupRoundRobin::gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts) { return next()->gather(outputs, inputs, opts); }; -std::shared_ptr ProcessGroupRoundRobin::scatter( +c10::intrusive_ptr ProcessGroupRoundRobin::scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts) { return next()->scatter(outputs, inputs, opts); }; -std::shared_ptr ProcessGroupRoundRobin::reduce_scatter( +c10::intrusive_ptr ProcessGroupRoundRobin::reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts) { return next()->reduce_scatter(outputs, inputs, opts); }; -std::shared_ptr ProcessGroupRoundRobin::alltoall_base( +c10::intrusive_ptr ProcessGroupRoundRobin::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, @@ -86,32 +86,32 @@ std::shared_ptr ProcessGroupRoundRobin::alltoall_base( outputTensor, inputTensor, outputSplitSizes, inputSplitSizes, opts); }; -std::shared_ptr ProcessGroupRoundRobin::send( +c10::intrusive_ptr ProcessGroupRoundRobin::send( std::vector& /* unused */, int /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support send"); }; -std::shared_ptr ProcessGroupRoundRobin::recv( +c10::intrusive_ptr ProcessGroupRoundRobin::recv( std::vector& /* unused */, int /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support recv"); }; -std::shared_ptr ProcessGroupRoundRobin::recvAnysource( +c10::intrusive_ptr ProcessGroupRoundRobin::recvAnysource( std::vector& /* unused */, int /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support recv"); }; -std::shared_ptr ProcessGroupRoundRobin::barrier( +c10::intrusive_ptr ProcessGroupRoundRobin::barrier( const BarrierOptions& /* unused */) { throw std::runtime_error("ProcessGroupRoundRobin does not support barrier"); }; -const std::shared_ptr& ProcessGroupRoundRobin::next() { +const c10::intrusive_ptr& ProcessGroupRoundRobin::next() { auto& processGroup = *iterator_; iterator_++; if (iterator_ == processGroups_.end()) { @@ -120,7 +120,7 @@ const std::shared_ptr& ProcessGroupRoundRobin::next() { return processGroup; } -std::shared_ptr ProcessGroupRoundRobin::allgather_base( +c10::intrusive_ptr ProcessGroupRoundRobin::allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { diff --git a/torch/lib/c10d/ProcessGroupRoundRobin.hpp b/torch/lib/c10d/ProcessGroupRoundRobin.hpp index bbbd0a1c756be..a8c2eba115a6a 100644 --- a/torch/lib/c10d/ProcessGroupRoundRobin.hpp +++ b/torch/lib/c10d/ProcessGroupRoundRobin.hpp @@ -21,87 +21,87 @@ class ProcessGroupRoundRobin final : public ProcessGroup { explicit ProcessGroupRoundRobin( int rank, int size, - std::vector> processGroups); + std::vector> processGroups); ~ProcessGroupRoundRobin() override; - std::shared_ptr broadcast( + c10::intrusive_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override; - std::shared_ptr allreduce( + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - std::shared_ptr allreduce_coalesced( + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; - std::shared_ptr reduce( + c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; - std::shared_ptr allgather( + c10::intrusive_ptr allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_base( + c10::intrusive_ptr allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr allgather_coalesced( + c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; - std::shared_ptr gather( + c10::intrusive_ptr gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts = GatherOptions()) override; - std::shared_ptr scatter( + c10::intrusive_ptr scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts = ScatterOptions()) override; - std::shared_ptr reduce_scatter( + c10::intrusive_ptr reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; - std::shared_ptr alltoall_base( + c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, const AllToAllOptions& opts = AllToAllOptions()) override; - std::shared_ptr send( + c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) override; - std::shared_ptr recv( + c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) override; - std::shared_ptr recvAnysource( + c10::intrusive_ptr recvAnysource( std::vector& tensors, int tag) override; - std::shared_ptr barrier( + c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; private: - std::vector> processGroups_; - std::vector>::const_iterator iterator_; + std::vector> processGroups_; + std::vector>::const_iterator iterator_; // Returns the next ProcessGroup to use. - const std::shared_ptr& next(); + const c10::intrusive_ptr& next(); }; } // namespace c10d diff --git a/torch/lib/c10d/Store.hpp b/torch/lib/c10d/Store.hpp index 8e313fda97674..f97e80013cdb0 100644 --- a/torch/lib/c10d/Store.hpp +++ b/torch/lib/c10d/Store.hpp @@ -6,9 +6,11 @@ #include #include +#include + namespace c10d { -class Store { +class Store : public torch::CustomClassHolder { public: static constexpr std::chrono::milliseconds kDefaultTimeout = std::chrono::seconds(300); @@ -30,8 +32,12 @@ class Store { virtual int64_t add(const std::string& key, int64_t value) = 0; + virtual bool deleteKey(const std::string& key) = 0; + virtual bool check(const std::vector& keys) = 0; + virtual int64_t getNumKeys() = 0; + virtual void wait(const std::vector& keys) = 0; virtual void wait( diff --git a/torch/lib/c10d/TCPStore.cpp b/torch/lib/c10d/TCPStore.cpp index dfd33cfb77ca4..4151448e677a2 100644 --- a/torch/lib/c10d/TCPStore.cpp +++ b/torch/lib/c10d/TCPStore.cpp @@ -1,16 +1,22 @@ #include +#ifdef _WIN32 +#include +#include +#else #include - #include +#endif + #include +#include #include namespace c10d { namespace { -enum class QueryType : uint8_t { SET, GET, ADD, CHECK, WAIT }; +enum class QueryType : uint8_t { SET, GET, ADD, CHECK, WAIT, GETNUMKEYS, DELETE_KEY }; enum class CheckResponseType : uint8_t { READY, NOT_READY }; @@ -23,11 +29,7 @@ enum class WaitResponseType : uint8_t { STOP_WAITING }; TCPStoreDaemon::TCPStoreDaemon(int storeListenSocket) : storeListenSocket_(storeListenSocket) { // Use control pipe to signal instance destruction to the daemon thread. - if (pipe(controlPipeFd_.data()) == -1) { - throw std::runtime_error( - "Failed to create the control pipe to start the " - "TCPStoreDaemon run"); - } + initStopSignal(); daemonThread_ = std::thread(&TCPStoreDaemon::run, this); } @@ -39,123 +41,69 @@ TCPStoreDaemon::~TCPStoreDaemon() { // Close unclosed sockets for (auto socket : sockets_) { if (socket != -1) { - ::close(socket); + tcputil::closeSocket(socket); } } // Now close the rest control pipe - for (auto fd : controlPipeFd_) { - if (fd != -1) { - ::close(fd); - } - } + closeStopSignal(); } void TCPStoreDaemon::join() { daemonThread_.join(); } -void TCPStoreDaemon::run() { - std::vector fds; - fds.push_back({.fd = storeListenSocket_, .events = POLLIN}); - // Push the read end of the pipe to signal the stopping of the daemon run - fds.push_back({.fd = controlPipeFd_[0], .events = POLLHUP}); - - // receive the queries - bool finished = false; - while (!finished) { - for (size_t i = 0; i < sockets_.size(); i++) { - fds[i].revents = 0; - } - - SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1)); - - // TCPStore's listening socket has an event and it should now be able to - // accept new connections. - if (fds[0].revents != 0) { - if (fds[0].revents ^ POLLIN) { - throw std::system_error( - ECONNABORTED, - std::system_category(), - "Unexpected poll revent on the master's listening socket: " + - std::to_string(fds[0].revents)); - } - int sockFd = std::get<0>(tcputil::accept(storeListenSocket_)); - sockets_.push_back(sockFd); - fds.push_back({.fd = sockFd, .events = POLLIN}); - } - // The pipe receives an event which tells us to shutdown the daemon - if (fds[1].revents != 0) { - // Will be POLLUP when the pipe is closed - if (fds[1].revents ^ POLLHUP) { - throw std::system_error( - ECONNABORTED, - std::system_category(), - "Unexpected poll revent on the control pipe's reading fd: " + - std::to_string(fds[1].revents)); - } - finished = true; - break; +void TCPStoreDaemon::queryFds(std::vector& fds) { + // Skipping the fds[0] and fds[1], + // fds[0] is master's listening socket + // fds[1] is control pipe's reading fd, it is not for Windows platform + for (size_t fdIdx = CONNECT_SOCKET_OFFSET; fdIdx < fds.size(); ++fdIdx) { + if (fds[fdIdx].revents == 0) { + continue; } - // Skipping the fds[0] and fds[1], - // fds[0] is master's listening socket - // fds[1] is control pipe's reading fd - for (size_t fdIdx = 2; fdIdx < fds.size(); ++fdIdx) { - if (fds[fdIdx].revents == 0) { - continue; - } - // Now query the socket that has the event - try { - query(fds[fdIdx].fd); - } catch (...) { - // There was an error when processing query. Probably an exception - // occurred in recv/send what would indicate that socket on the other - // side has been closed. If the closing was due to normal exit, then - // the store should continue executing. Otherwise, if it was different - // exception, other connections will get an exception once they try to - // use the store. We will go ahead and close this connection whenever - // we hit an exception here. - ::close(fds[fdIdx].fd); - - // Remove all the tracking state of the close FD - for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) { - for (auto vecIt = it->second.begin(); vecIt != it->second.end();) { - if (*vecIt == fds[fdIdx].fd) { - vecIt = it->second.erase(vecIt); - } else { - ++vecIt; - } - } - if (it->second.size() == 0) { - it = waitingSockets_.erase(it); + // Now query the socket that has the event + try { + query(fds[fdIdx].fd); + } catch (...) { + // There was an error when processing query. Probably an exception + // occurred in recv/send what would indicate that socket on the other + // side has been closed. If the closing was due to normal exit, then + // the store should continue executing. Otherwise, if it was different + // exception, other connections will get an exception once they try to + // use the store. We will go ahead and close this connection whenever + // we hit an exception here. + tcputil::closeSocket(fds[fdIdx].fd); + + // Remove all the tracking state of the close FD + for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) { + for (auto vecIt = it->second.begin(); vecIt != it->second.end();) { + if (*vecIt == fds[fdIdx].fd) { + vecIt = it->second.erase(vecIt); } else { - ++it; + ++vecIt; } } - for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) { - if (it->first == fds[fdIdx].fd) { - it = keysAwaited_.erase(it); - } else { - ++it; - } + if (it->second.size() == 0) { + it = waitingSockets_.erase(it); + } else { + ++it; } - fds.erase(fds.begin() + fdIdx); - sockets_.erase(sockets_.begin() + fdIdx - 2); - --fdIdx; - continue; } + for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) { + if (it->first == fds[fdIdx].fd) { + it = keysAwaited_.erase(it); + } else { + ++it; + } + } + fds.erase(fds.begin() + fdIdx); + sockets_.erase(sockets_.begin() + fdIdx - CONNECT_SOCKET_OFFSET); + --fdIdx; + continue; } } } -void TCPStoreDaemon::stop() { - if (controlPipeFd_[1] != -1) { - // close the write end of the pipe - ::close(controlPipeFd_[1]); - controlPipeFd_[1] = -1; - } -} - // query communicates with the worker. The format // of the query is as follows: // type of query | size of arg1 | arg1 | size of arg2 | arg2 | ... @@ -180,6 +128,12 @@ void TCPStoreDaemon::query(int socket) { } else if (qt == QueryType::WAIT) { waitHandler(socket); + } else if (qt == QueryType::GETNUMKEYS) { + getNumKeysHandler(socket); + + } else if (qt == QueryType::DELETE_KEY) { + deleteHandler(socket); + } else { throw std::runtime_error("Unexpected query type"); } @@ -228,6 +182,16 @@ void TCPStoreDaemon::getHandler(int socket) const { tcputil::sendVector(socket, data); } +void TCPStoreDaemon::getNumKeysHandler(int socket) const { + tcputil::sendValue(socket, tcpStore_.size()); +} + +void TCPStoreDaemon::deleteHandler(int socket) { + std::string key = tcputil::recvString(socket); + auto numDeleted = tcpStore_.erase(key); + tcputil::sendValue(socket, numDeleted); +} + void TCPStoreDaemon::checkHandler(int socket) const { SizeType nargs; tcputil::recvBytes(socket, &nargs, 1); @@ -267,6 +231,137 @@ bool TCPStoreDaemon::checkKeys(const std::vector& keys) const { }); } +#ifdef _WIN32 +void TCPStoreDaemon::initStopSignal() { + ghStopEvent_ = CreateEvent(NULL, TRUE, FALSE, NULL); + if (ghStopEvent_ == NULL) { + throw std::runtime_error( + "Failed to create the control pipe to start the " + "TCPStoreDaemon run"); + } +} + +void TCPStoreDaemon::closeStopSignal() { + CloseHandle(ghStopEvent_); +} + +void TCPStoreDaemon::stop() { + SetEvent(ghStopEvent_); +} + +void TCPStoreDaemon::run() { + std::vector fds; + tcputil::addPollfd(fds, storeListenSocket_, POLLIN); + + // receive the queries + bool finished = false; + while (!finished) { + for (size_t i = 0; i < sockets_.size(); i++) { + fds[i].revents = 0; + } + + int res; + SYSCHECK_ERR_RETURN_NEG1( + res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count())) + if (res == 0) { + auto rv = WaitForSingleObject(ghStopEvent_, 0); + if (rv != WAIT_TIMEOUT) { + finished = true; + break; + } + continue; + } + + // TCPStore's listening socket has an event and it should now be able to + // accept new connections. + if (fds[0].revents != 0) { + if (!(fds[0].revents & POLLIN)) { + throw std::system_error( + ECONNABORTED, + std::system_category(), + "Unexpected poll revent on the master's listening socket: " + + std::to_string(fds[0].revents)); + } + int sockFd = std::get<0>(tcputil::accept(storeListenSocket_)); + sockets_.push_back(sockFd); + tcputil::addPollfd(fds, sockFd, POLLIN); + } + queryFds(fds); + } +} +#else +void TCPStoreDaemon::initStopSignal() { + if (pipe(controlPipeFd_.data()) == -1) { + throw std::runtime_error( + "Failed to create the control pipe to start the " + "TCPStoreDaemon run"); + } +} + +void TCPStoreDaemon::closeStopSignal() { + for (auto fd : controlPipeFd_) { + if (fd != -1) { + ::close(fd); + } + } +} + +void TCPStoreDaemon::stop() { + if (controlPipeFd_[1] != -1) { + // close the write end of the pipe + ::close(controlPipeFd_[1]); + controlPipeFd_[1] = -1; + } +} + +void TCPStoreDaemon::run() { + std::vector fds; + tcputil::addPollfd(fds, storeListenSocket_, POLLIN); + // Push the read end of the pipe to signal the stopping of the daemon run + tcputil::addPollfd(fds, controlPipeFd_[0], POLLHUP); + + // receive the queries + bool finished = false; + while (!finished) { + for (size_t i = 0; i < sockets_.size(); i++) { + fds[i].revents = 0; + } + + SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1)); + + // TCPStore's listening socket has an event and it should now be able to + // accept new connections. + if (fds[0].revents != 0) { + if (fds[0].revents ^ POLLIN) { + throw std::system_error( + ECONNABORTED, + std::system_category(), + "Unexpected poll revent on the master's listening socket: " + + std::to_string(fds[0].revents)); + } + int sockFd = std::get<0>(tcputil::accept(storeListenSocket_)); + sockets_.push_back(sockFd); + tcputil::addPollfd(fds, sockFd, POLLIN); + } + + // The pipe receives an event which tells us to shutdown the daemon + if (fds[1].revents != 0) { + // Will be POLLUP when the pipe is closed + if (fds[1].revents ^ POLLHUP) { + throw std::system_error( + ECONNABORTED, + std::system_category(), + "Unexpected poll revent on the control pipe's reading fd: " + + std::to_string(fds[1].revents)); + } + finished = true; + break; + } + queryFds(fds); + } +} +#endif + // TCPStore class methods TCPStore::TCPStore( const std::string& masterAddr, @@ -282,6 +377,7 @@ TCPStore::TCPStore( numWorkers_(numWorkers), initKey_("init/"), regularPrefix_("/") { + tcputil::socketInitialize(); if (isServer_) { // Opening up the listening socket std::tie(masterListenSocket_, tcpStorePort_) = tcputil::listen(masterPort); @@ -292,19 +388,18 @@ TCPStore::TCPStore( // Connect to the daemon storeSocket_ = tcputil::connect( tcpStoreAddr_, tcpStorePort_, /* wait= */ true, timeout_); - if (waitWorkers) { waitForWorkers(); } } TCPStore::~TCPStore() { - ::close(storeSocket_); + tcputil::closeSocket(storeSocket_); if (isServer_) { // Store daemon should end because of closed connection. // daemon destructor should join the thread tcpStoreDaemon_.reset(nullptr); - ::close(masterListenSocket_); + tcputil::closeSocket(masterListenSocket_); } } @@ -357,6 +452,14 @@ int64_t TCPStore::add(const std::string& key, int64_t value) { return addHelper_(regKey, value); } +bool TCPStore::deleteKey(const std::string& key) { + std::string regKey = regularPrefix_ + key; + tcputil::sendValue(storeSocket_, QueryType::DELETE_KEY); + tcputil::sendString(storeSocket_, regKey, true); + auto numDeleted = tcputil::recvValue(storeSocket_); + return (numDeleted == 1); +} + int64_t TCPStore::addHelper_(const std::string& key, int64_t value) { tcputil::sendValue(storeSocket_, QueryType::ADD); tcputil::sendString(storeSocket_, key, true); @@ -364,6 +467,11 @@ int64_t TCPStore::addHelper_(const std::string& key, int64_t value) { return tcputil::recvValue(storeSocket_); } +int64_t TCPStore::getNumKeys() { + tcputil::sendValue(storeSocket_, QueryType::GETNUMKEYS); + return tcputil::recvValue(storeSocket_); +} + bool TCPStore::check(const std::vector& keys) { tcputil::sendValue(storeSocket_, QueryType::CHECK); SizeType nkeys = keys.size(); @@ -402,8 +510,13 @@ void TCPStore::waitHelper_( const std::chrono::milliseconds& timeout) { // Set the socket timeout if there is a wait timeout if (timeout != kNoTimeout) { +#ifdef _WIN32 + struct timeval timeoutTV = {timeout.count() / 1000, + (timeout.count() % 1000) * 1000}; +#else struct timeval timeoutTV = {.tv_sec = timeout.count() / 1000, .tv_usec = (timeout.count() % 1000) * 1000}; +#endif SYSCHECK_ERR_RETURN_NEG1(::setsockopt( storeSocket_, SOL_SOCKET, diff --git a/torch/lib/c10d/TCPStore.hpp b/torch/lib/c10d/TCPStore.hpp index 29733639bd591..47c92b742520b 100644 --- a/torch/lib/c10d/TCPStore.hpp +++ b/torch/lib/c10d/TCPStore.hpp @@ -5,7 +5,12 @@ #include #include -#include + +#ifdef _WIN32 +#include +#else +#include +#endif namespace c10d { @@ -20,17 +25,23 @@ class TCPStoreDaemon { void run(); void stop(); + void queryFds(std::vector& fds); void query(int socket); void setHandler(int socket); void addHandler(int socket); void getHandler(int socket) const; void checkHandler(int socket) const; + void getNumKeysHandler(int socket) const; + void deleteHandler(int socket); void waitHandler(int socket); bool checkKeys(const std::vector& keys) const; void wakeupWaitingClients(const std::string& key); + void initStopSignal(); + void closeStopSignal(); + std::thread daemonThread_; std::unordered_map> tcpStore_; // From key -> the list of sockets waiting on it @@ -40,7 +51,13 @@ class TCPStoreDaemon { std::vector sockets_; int storeListenSocket_; +#ifdef _WIN32 + const std::chrono::milliseconds checkTimeout_ + = std::chrono::milliseconds(10); + HANDLE ghStopEvent_; +#else std::vector controlPipeFd_{-1, -1}; +#endif }; class TCPStore : public Store { @@ -61,8 +78,12 @@ class TCPStore : public Store { int64_t add(const std::string& key, int64_t value) override; + bool deleteKey(const std::string& key) override; + bool check(const std::vector& keys) override; + int64_t getNumKeys() override; + void wait(const std::vector& keys) override; void wait( diff --git a/torch/lib/c10d/Types.hpp b/torch/lib/c10d/Types.hpp index 03b2e59e42951..a5a0d5fa20dfd 100644 --- a/torch/lib/c10d/Types.hpp +++ b/torch/lib/c10d/Types.hpp @@ -62,6 +62,7 @@ struct AllToAllOptions { }; struct BarrierOptions { + std::vector device_ids; std::chrono::milliseconds timeout = kUnsetTimeout; }; diff --git a/torch/lib/c10d/UnixSockUtils.hpp b/torch/lib/c10d/UnixSockUtils.hpp new file mode 100644 index 0000000000000..fa74be27f889e --- /dev/null +++ b/torch/lib/c10d/UnixSockUtils.hpp @@ -0,0 +1,89 @@ +#pragma once + +#include + +namespace c10d { +namespace tcputil { + +#define AF_SELECTED AF_UNSPEC +#define CONNECT_SOCKET_OFFSET 2 + +inline void closeSocket(int socket) { ::close(socket); } + +inline int setSocketAddrReUse(int socket) { + int optval = 1; + return ::setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(int)); +} + +inline int poll(struct pollfd *fds, unsigned long nfds, int timeout) { + return ::poll(fds, nfds, timeout); +} + +inline void addPollfd(std::vector &fds, int socket, + short events) { + fds.push_back({.fd = socket, .events = events}); +} + +inline void waitSocketConnected( + int socket, + struct ::addrinfo *nextAddr, + std::chrono::milliseconds timeout, + std::chrono::time_point startTime) { + SYSCHECK_ERR_RETURN_NEG1(::fcntl(socket, F_SETFL, O_NONBLOCK)); + + int ret = ::connect(socket, nextAddr->ai_addr, nextAddr->ai_addrlen); + + if (ret != 0 && errno != EINPROGRESS) { + throw std::system_error(errno, std::system_category()); + } + + struct ::pollfd pfd; + pfd.fd = socket; + pfd.events = POLLOUT; + + int64_t pollTimeout = -1; + if (timeout != kNoTimeout) { + // calculate remaining time and use that as timeout for poll() + const auto elapsed = std::chrono::high_resolution_clock::now() - startTime; + const auto remaining = + std::chrono::duration_cast(timeout) - + std::chrono::duration_cast(elapsed); + pollTimeout = std::max(static_cast(0), + static_cast(remaining.count())); + } + int numReady = ::poll(&pfd, 1, pollTimeout); + if (numReady < 0) { + throw std::system_error(errno, std::system_category()); + } else if (numReady == 0) { + errno = 0; + throw std::runtime_error(kConnectTimeoutMsg); + } + + socklen_t errLen = sizeof(errno); + errno = 0; + ::getsockopt(socket, SOL_SOCKET, SO_ERROR, &errno, &errLen); + + // `errno` is set when: + // 1. `getsockopt` has failed + // 2. there is awaiting error in the socket + // (the error is saved to the `errno` variable) + if (errno != 0) { + throw std::system_error(errno, std::system_category()); + } + + // Disable non-blocking mode + int flags; + SYSCHECK_ERR_RETURN_NEG1(flags = ::fcntl(socket, F_GETFL)); + SYSCHECK_ERR_RETURN_NEG1(::fcntl(socket, F_SETFL, flags & (~O_NONBLOCK))); +} + +// Linux socket does not need init libs first +inline void socketInitialize() {} + +inline struct ::pollfd getPollfd(int socket, short events) { + struct ::pollfd res = {.fd = socket, .events = events}; + return res; +} + +} // namespace tcputil +} // namespace c10d diff --git a/torch/lib/c10d/Utils.cpp b/torch/lib/c10d/Utils.cpp index d975f6eb6bc5b..62e1e195ca458 100644 --- a/torch/lib/c10d/Utils.cpp +++ b/torch/lib/c10d/Utils.cpp @@ -1,17 +1,18 @@ -#include - +#ifdef _WIN32 +#include +#else +#include #include #include - #include #include #include - -#include #include +#endif #include #include +#include #include #include #include @@ -22,7 +23,6 @@ namespace tcputil { namespace { constexpr int LISTEN_QUEUE_SIZE = 2048; -const std::string kConnectTimeoutMsg = "connect() timed out."; void setSocketNoDelay(int socket) { int flag = 1; @@ -81,7 +81,7 @@ std::pair listen(PortType port) { struct ::addrinfo hints, *res = NULL; std::memset(&hints, 0x00, sizeof(hints)); hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; - hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6 + hints.ai_family = AF_SELECTED; // IPv4 on Windows, IPv4/6 on Linux hints.ai_socktype = SOCK_STREAM; // TCP // `getaddrinfo` will sort addresses according to RFC 3484 and can be tweeked @@ -105,18 +105,14 @@ std::pair listen(PortType port) { nextAddr->ai_family, nextAddr->ai_socktype, nextAddr->ai_protocol)) - - int optval = 1; - SYSCHECK_ERR_RETURN_NEG1( - ::setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(int))) - + SYSCHECK_ERR_RETURN_NEG1(tcputil::setSocketAddrReUse(socket)) SYSCHECK_ERR_RETURN_NEG1( ::bind(socket, nextAddr->ai_addr, nextAddr->ai_addrlen)) SYSCHECK_ERR_RETURN_NEG1(::listen(socket, LISTEN_QUEUE_SIZE)) break; } catch (const std::system_error& e) { - ::close(socket); + tcputil::closeSocket(socket); nextAddr = nextAddr->ai_next; // we have tried all addresses but could not start @@ -202,7 +198,7 @@ int connect( struct ::addrinfo hints, *res = NULL; std::memset(&hints, 0x00, sizeof(hints)); hints.ai_flags = AI_NUMERICSERV; // specifies that port (service) is numeric - hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6 + hints.ai_family = AF_SELECTED; // IPv4 on Windows, IPv4/6 on Linux hints.ai_socktype = SOCK_STREAM; // TCP // `getaddrinfo` will sort addresses according to RFC 3484 and can be tweeked @@ -235,55 +231,11 @@ int connect( nextAddr->ai_socktype, nextAddr->ai_protocol)) - ResourceGuard socketGuard([socket]() { ::close(socket); }); + ResourceGuard socketGuard([socket]() { tcputil::closeSocket(socket); }); // We need to connect in non-blocking mode, so we can use a timeout - SYSCHECK_ERR_RETURN_NEG1(::fcntl(socket, F_SETFL, O_NONBLOCK)); + waitSocketConnected(socket, nextAddr, timeout, start); - int ret = ::connect(socket, nextAddr->ai_addr, nextAddr->ai_addrlen); - - if (ret != 0 && errno != EINPROGRESS) { - throw std::system_error(errno, std::system_category()); - } - - struct ::pollfd pfd; - pfd.fd = socket; - pfd.events = POLLOUT; - - int64_t pollTimeout = -1; - if (timeout != kNoTimeout) { - // calculate remaining time and use that as timeout for poll() - const auto elapsed = std::chrono::high_resolution_clock::now() - start; - const auto remaining = - std::chrono::duration_cast(timeout) - - std::chrono::duration_cast(elapsed); - pollTimeout = std::max( - static_cast(0), static_cast(remaining.count())); - } - int numReady = ::poll(&pfd, 1, pollTimeout); - if (numReady < 0) { - throw std::system_error(errno, std::system_category()); - } else if (numReady == 0) { - errno = 0; - throw std::runtime_error(kConnectTimeoutMsg); - } - - socklen_t errLen = sizeof(errno); - errno = 0; - ::getsockopt(socket, SOL_SOCKET, SO_ERROR, &errno, &errLen); - - // `errno` is set when: - // 1. `getsockopt` has failed - // 2. there is awaiting error in the socket - // (the error is saved to the `errno` variable) - if (errno != 0) { - throw std::system_error(errno, std::system_category()); - } - - // Disable non-blocking mode - int flags; - SYSCHECK_ERR_RETURN_NEG1(flags = ::fcntl(socket, F_GETFL)); - SYSCHECK_ERR_RETURN_NEG1(::fcntl(socket, F_SETFL, flags & (~O_NONBLOCK))); socketGuard.release(); break; @@ -320,10 +272,10 @@ std::tuple accept( const std::chrono::milliseconds& timeout) { // poll on listen socket, it allows to make timeout std::unique_ptr events(new struct ::pollfd[1]); - events[0] = {.fd = listenSocket, .events = POLLIN}; + events[0] = tcputil::getPollfd(listenSocket, POLLIN); while (true) { - int res = ::poll(events.get(), 1, timeout.count()); + int res = tcputil::poll(events.get(), 1, timeout.count()); if (res == 0) { throw std::runtime_error( "waiting for processes to " @@ -354,6 +306,5 @@ std::tuple accept( return std::make_tuple( socket, sockaddrToString(reinterpret_cast(&addr))); } - } // namespace tcputil } // namespace c10d diff --git a/torch/lib/c10d/Utils.hpp b/torch/lib/c10d/Utils.hpp index 1bdaddde9f243..e7b0f1834441b 100644 --- a/torch/lib/c10d/Utils.hpp +++ b/torch/lib/c10d/Utils.hpp @@ -1,6 +1,5 @@ #pragma once -#include #include #include @@ -17,6 +16,19 @@ #include +#ifdef _WIN32 +#include +#include +typedef SSIZE_T ssize_t; +#pragma comment(lib, "Ws2_32.lib") +#else +#include +#include +#include +#include +#include +#endif + namespace c10d { // Turns at::IntArrayRef into "(1, 2, 3, 4)". @@ -52,6 +64,30 @@ inline void assertSameType( } } +inline bool parseEnvVarFlag(const char* envVarName) { + char* stringValue = std::getenv(envVarName); + if (stringValue != nullptr) { + int val; + try { + val = std::stoi(stringValue); + } catch (std::exception& e) { + throw std::runtime_error( + "Invalid value for environment variable: " + + std::string(envVarName)); + } + if (val == 1) { + return true; + } else if (val == 0) { + return false; + } else { + throw std::runtime_error( + "Invalid value for environment variable: " + + std::string(envVarName)); + } + } + return false; +} + inline void assertSameSizes( const at::IntArrayRef& sizes, const std::vector& tensors) { @@ -438,6 +474,25 @@ using SizeType = uint64_t; // `success_cond` is an expression used to check if an error has happend. So for // `fork()`, we can use `SYSCHECK(pid = fork(), pid != -1)`. The function output // is stored in variable `__output` and may be used in `success_cond`. +#ifdef _WIN32 +#define SYSCHECK(expr, success_cond) \ + while (true) { \ + auto __output = (expr); \ + auto errno_local = WSAGetLastError(); \ + (void)__output; \ + if (!(success_cond)) { \ + if (errno == EINTR) { \ + continue; \ + } else if (errno_local == WSAETIMEDOUT || errno_local == WSAEWOULDBLOCK ) { \ + throw std::runtime_error("Socket Timeout"); \ + } else { \ + throw std::system_error(errno_local, std::system_category()); \ + } \ + } else { \ + break; \ + } \ + } +#else #define SYSCHECK(expr, success_cond) \ while (true) { \ auto __output = (expr); \ @@ -454,9 +509,11 @@ using SizeType = uint64_t; break; \ } \ } +#endif // Most functions indicate error by returning `-1`. This is a helper macro for // this common case with `SYSCHECK`. +// Since SOCKET_ERROR = -1 in MSVC, so also leverage SYSCHECK_ERR_RETURN_NEG1 #define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1) // Helper resource guard class @@ -483,6 +540,7 @@ class ResourceGuard { namespace tcputil { constexpr std::chrono::milliseconds kNoTimeout = std::chrono::milliseconds(-1); +const std::string kConnectTimeoutMsg = "connect() timed out."; // Send and receive template @@ -510,7 +568,7 @@ void sendBytes( while (bytesToSend > 0) { ssize_t bytesSent; SYSCHECK_ERR_RETURN_NEG1( - bytesSent = ::send(socket, currentBytes, bytesToSend, flags)) + bytesSent = ::send(socket, (const char*)currentBytes, bytesToSend, flags)) if (bytesSent == 0) { throw std::system_error(ECONNRESET, std::system_category()); } @@ -533,7 +591,7 @@ void recvBytes(int socket, T* buffer, size_t length) { while (bytesToReceive > 0) { ssize_t bytesReceived; SYSCHECK_ERR_RETURN_NEG1( - bytesReceived = ::recv(socket, currentBytes, bytesToReceive, 0)) + bytesReceived = recv(socket, (char*)currentBytes, bytesToReceive, 0)) if (bytesReceived == 0) { throw std::system_error(ECONNRESET, std::system_category()); } diff --git a/torch/lib/c10d/WinSockUtils.hpp b/torch/lib/c10d/WinSockUtils.hpp new file mode 100644 index 0000000000000..cd37695845ab1 --- /dev/null +++ b/torch/lib/c10d/WinSockUtils.hpp @@ -0,0 +1,84 @@ +#pragma once + +#include + +namespace c10d { +namespace tcputil { + +#define AF_SELECTED AF_INET +#define CONNECT_SOCKET_OFFSET 1 + +inline void closeSocket(int socket) { ::closesocket(socket); } + +inline int setSocketAddrReUse(int socket) { + bool optval = false; + return ::setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (char *)&optval, + sizeof(bool)); +} + +inline int poll(struct pollfd *fdArray, unsigned long fds, int timeout) { + return WSAPoll(fdArray, fds, timeout); +} + +inline void addPollfd(std::vector &fds, int socket, + short events) { + fds.push_back({(SOCKET)socket, events}); +} + +inline void waitSocketConnected( + int socket, + struct ::addrinfo *nextAddr, + std::chrono::milliseconds timeout, + std::chrono::time_point startTime) { + unsigned long block_mode = 1; + SYSCHECK_ERR_RETURN_NEG1(ioctlsocket(socket, FIONBIO, &block_mode)); + + int ret; + do { + ret = connect(socket, nextAddr->ai_addr, nextAddr->ai_addrlen); + if (ret == SOCKET_ERROR) { + int err = WSAGetLastError(); + if (err == WSAEISCONN) { + break; + } else if (err == WSAEALREADY || err == WSAEWOULDBLOCK) { + if (timeout != kNoTimeout) { + const auto elapsed = + std::chrono::high_resolution_clock::now() - startTime; + if (elapsed > timeout) { + errno = 0; + throw std::runtime_error(kConnectTimeoutMsg); + } + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + continue; + } + throw std::system_error(err, std::system_category(), + "Socket connect failed"); + } + } while (ret == SOCKET_ERROR); + + block_mode = 0; + SYSCHECK_ERR_RETURN_NEG1(ioctlsocket(socket, FIONBIO, &block_mode)); +} + +// All processes (applications or DLLs) that call Winsock +// functions must initialize the use of the Windows Sockets +// DLL before making other Winsock function calls. +// This also makes certain that Winsock is supported on the system. +// Ref to +// https://docs.microsoft.com/en-us/windows/win32/winsock/initializing-winsock +inline void socketInitialize() { + static std::once_flag init_flag; + std::call_once(init_flag, []() { + WSADATA wsa_data; + SYSCHECK_ERR_RETURN_NEG1(WSAStartup(MAKEWORD(2, 2), &wsa_data)) + }); +} + +inline struct ::pollfd getPollfd(int socket, short events) { + struct ::pollfd res = {(SOCKET)socket, events}; + return res; +} + +} // namespace tcputil +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/comm.cpp b/torch/lib/c10d/comm.cpp similarity index 62% rename from torch/csrc/distributed/c10d/comm.cpp rename to torch/lib/c10d/comm.cpp index 2eb79283a0883..1db8901b2859f 100644 --- a/torch/csrc/distributed/c10d/comm.cpp +++ b/torch/lib/c10d/comm.cpp @@ -1,9 +1,9 @@ -#include +#include #include #include -#include +#include #include #include @@ -13,7 +13,7 @@ namespace { class BroadcastWork { public: BroadcastWork( - const std::shared_ptr& process_group, + const c10::intrusive_ptr& process_group, std::vector bucket_tensors, int root_rank = 0) : bucket_tensors_(std::move(bucket_tensors)), @@ -45,15 +45,17 @@ class BroadcastWork { // because c10d::ProcessGroup::broadcast takes a vector argument. std::vector flat_tensor_; + private: + // The broadcast work that is kicked off upon construction. - std::shared_ptr work_; + c10::intrusive_ptr work_; }; } // namespace // Broadcast many tensors to all processes in the process group. void broadcast_coalesced( - std::shared_ptr process_group, + c10::intrusive_ptr process_group, at::TensorList tensors, size_t buffer_size, int rank) { @@ -85,44 +87,4 @@ void broadcast_coalesced( } } -PythonCommHook::PythonCommHook(py::object state, py::object hook) - : state_(std::move(state)), hook_(std::move(hook)){}; - -c10::intrusive_ptr PythonCommHook::runHook( - const GradBucket& bucket) { - py::gil_scoped_acquire acquire; - - py::object py_fut = hook_(state_, bucket); - - try { - return py_fut.cast>()->fut; - } catch (const py::cast_error& e) { - auto type = py_fut.get_type(); - auto errMsg = c10::str( - e.what(), - ". DDP communication hook's callback must return a " - "torch.futures.Future or torch._C.Future object, but got ", - type.attr("__module__").cast(), - ".", - type.attr("__qualname__").cast()); - throw std::runtime_error(errMsg); - } -} - -std::vector PythonCommHook::processFuture( - c10::IValue future_value) { - // Since we have a Python hook, future_value can be a PyObject. - if (future_value.isPyObject()) { - // We first convert it to an IValue that contains a TensorVector. - py::gil_scoped_acquire ag; - py::object obj = torch::jit::toPyObject(future_value); - auto value = torch::jit::toIValue( - obj, c10::ListType::create(c10::TensorType::get())); - - return value.toTensorVector(); - } - - return future_value.toTensorVector(); -} - } // namespace c10d diff --git a/torch/lib/c10d/comm.hpp b/torch/lib/c10d/comm.hpp new file mode 100644 index 0000000000000..3a39baccc9532 --- /dev/null +++ b/torch/lib/c10d/comm.hpp @@ -0,0 +1,125 @@ +#pragma once + +#include +#include +#include + +namespace c10d { + +// Broadcast many tensors to all processes in the process group. +void broadcast_coalesced( + c10::intrusive_ptr process_group, + at::TensorList tensors, + size_t buffer_size, + int rank = 0); + +// This class passes bucket contents tensor (for multiple replicas) to +// DDP communication hook. +// Optionally in the future this can be enhanced with parameter to bucket +// mappings as well. +class GradBucket { + public: + explicit GradBucket( + size_t index, + const std::vector& tensors, + const std::vector& offsets = {}, + const std::vector& lengths = {}, + const std::vector& sizes_vec = {}) + : index_(index), + tensors_(tensors), + offsets_(offsets), + lengths_(lengths), + sizes_vec_(sizes_vec) {} + + // Returns the index of the bucket, which is unique across all the buckets. + size_t getIndex() const { + return index_; + } + + // Each tensor in the list that getTensors returns refers to the replica on + // each device. There will be multiple replicas only in the case of single + // process multiple device mode. In the single process single device mode, + // this list would consist of only a single tensor. + const std::vector& getTensors() const { + return tensors_; + } + + // Returns a mutable tensor vector compared with the above method. + std::vector& getTensorsRef() { + return tensors_; + } + + // Returns the start index of each variable in tensors_[0]. + const std::vector& getOffsets() const { + return offsets_; + } + + // Returns the total (i.e., flattened) length of each variable in + // tensors_[0]. + const std::vector& getLengths() const { + return lengths_; + } + + // Returns the multi-dimensional sizes/shape of each variable in tensors_[0]. + const std::vector& getSizesVec() const { + return sizes_vec_; + } + + private: + size_t index_; + std::vector tensors_; + + // Per-variable info in tensors_[0]. + std::vector offsets_; + std::vector lengths_; + std::vector sizes_vec_; +}; + +// Base class of both `PythonCommHook` and `CppCommHook`. +// Requires implementing 1) `runHook` method that communicates gradients +// asynchronously, and 2) `parseHookResult` method that converts the hook +// result into a tensor vector. +class TORCH_PYTHON_API CommHookInterface { + public: + virtual ~CommHookInterface() {} + + // Passes the input grad bucket to the registered communication hook. + // Once the tensors in the bucket are ready, kicks off the hook asynchronously + // and returns a future that holds the communication results. + virtual c10::intrusive_ptr runHook( + GradBucket& bucket) = 0; + + // Returns the resulting tensors once the communication hook result is + // ready. The resulting tensors will then be copied to the grads of + // individual parameters. + virtual std::vector parseHookResult( + const c10::IValue& result) = 0; +}; + +// This CppCommHook interface only requires implementing runHook method that +// potentially uses a state. +// Still need TORCH_PYTHON_API instead of TORCH_API to support Windows platform. +template +class TORCH_PYTHON_API CppCommHookInterface : public CommHookInterface { + public: + explicit CppCommHookInterface(T& state) : state_(state) {} + + virtual ~CppCommHookInterface() {} + + std::vector parseHookResult(const c10::IValue& result) override { + TORCH_INTERNAL_ASSERT( + result.isTensor() || result.isTensorList(), + "expected the hook result is either a Tensor or a TensorList"); + + if (result.isTensor()) { + return {result.toTensor()}; + } + + return result.toTensorVector(); + } + + protected: + T state_; // Not owned. +}; + +} // namespace c10d diff --git a/torch/lib/c10d/default_comm_hooks.cpp b/torch/lib/c10d/default_comm_hooks.cpp new file mode 100644 index 0000000000000..105f7e14d9594 --- /dev/null +++ b/torch/lib/c10d/default_comm_hooks.cpp @@ -0,0 +1,41 @@ +#include + +#include +#include +#include + +namespace c10d { + +c10::intrusive_ptr AllReduceCommHook::runHook( + GradBucket& bucket) { + auto allreduce_work = state_->allreduce(bucket.getTensorsRef()); + + auto div_by_process_group_size = [allreduce_work, this]() { + auto tensor = allreduce_work->result()[0] / state_->getSize(); + return c10::IValue(tensor); + }; + + auto fut = allreduce_work->getFuture(); + return fut->then(div_by_process_group_size, fut->elementType()); +} + +c10::intrusive_ptr FP16CompressCommHook::runHook( + GradBucket& bucket) { + auto& tensors = bucket.getTensorsRef(); + for (auto& tensor : tensors) { + tensor.copy_(tensor.to(torch::kFloat16)); + } + auto allreduce_work = state_->allreduce(tensors); + + auto decompress_and_div_by_process_group_size = [allreduce_work, this]() { + auto tensor = allreduce_work->result()[0]; + tensor.copy_(tensor.to(torch::kFloat) / state_->getSize()); + return c10::IValue(tensor); + }; + + auto fut = allreduce_work->getFuture(); + return fut->then( + decompress_and_div_by_process_group_size, fut->elementType()); +} + +} // namespace c10d diff --git a/torch/lib/c10d/default_comm_hooks.hpp b/torch/lib/c10d/default_comm_hooks.hpp new file mode 100644 index 0000000000000..077d29bd977de --- /dev/null +++ b/torch/lib/c10d/default_comm_hooks.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include +#include + +namespace c10d { + +enum class BuiltinCommHookType { + ALLREDUCE = 1, + FP16_COMPRESS = 2, +}; + +class AllReduceCommHook : public CppCommHookInterface { + public: + explicit AllReduceCommHook(ProcessGroup* state) + : CppCommHookInterface(state) {} + + ~AllReduceCommHook() override {} + + c10::intrusive_ptr runHook(GradBucket& bucket) override; +}; + +class FP16CompressCommHook : public CppCommHookInterface { + public: + explicit FP16CompressCommHook(ProcessGroup* state) + : CppCommHookInterface(state) {} + + ~FP16CompressCommHook() override {} + + c10::intrusive_ptr runHook(GradBucket& bucket) override; +}; + +} // namespace c10d diff --git a/torch/lib/c10d/example/allreduce.cpp b/torch/lib/c10d/example/allreduce.cpp index 76d6a5588f7ea..3de7447d092ae 100644 --- a/torch/lib/c10d/example/allreduce.cpp +++ b/torch/lib/c10d/example/allreduce.cpp @@ -19,7 +19,7 @@ int main(int argc, char** argv) { } // Kick off work - std::vector> pending; + std::vector> pending; for (auto i = 0; i < ntensors; i++) { std::vector tmp = {tensors[i]}; pending.push_back(pg.allreduce(tmp)); diff --git a/torch/lib/c10d/frontend.cpp b/torch/lib/c10d/frontend.cpp new file mode 100644 index 0000000000000..bb8eb2045a288 --- /dev/null +++ b/torch/lib/c10d/frontend.cpp @@ -0,0 +1,921 @@ +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#ifdef USE_C10D_GLOO +#include +#endif + +#ifdef USE_C10D_NCCL +#include +#endif + +#ifdef USE_C10D_MPI +#include +#endif + +namespace c10d { + +namespace { + +void maybePreprocessComplexTensor(at::Tensor& tensor) { + if(!tensor.is_complex()) { + return; + } + + tensor = at::view_as_real(tensor); +} + +void maybePreprocessComplexTensor(std::vector& tensors) { + for(at::Tensor& t : tensors) { + maybePreprocessComplexTensor(t); + } +} + +void maybePreprocessComplexTensor(std::vector>& tensors_lists) { + for(std::vector& t : tensors_lists) { + maybePreprocessComplexTensor(t); + } +} + +void assertReduceOpSupportsComplexTensor(ReduceOp op) { + switch (op) { + case ReduceOp::MAX: + case ReduceOp::MIN: + case ReduceOp::PRODUCT: + AT_ERROR( + "all_reduce does not support requested Reduce op on complex tensors"); + default: + return; + } +} + +} // namespace anonymous + +std::string Backend::get(const std::string& backend_type) { + return backend_type; +} + +void Backend::registerBackend() { + TORCH_CHECK(false, "Registering third-party backend is currently not supported by TorchScript-friendly c10d"); +} + +c10::intrusive_ptr DistributedC10d::get() { + static c10::intrusive_ptr singleton = + c10::make_intrusive(); + + return singleton; +} + +c10::intrusive_ptr DistributedC10d::getProcessGroupByName(const std::string& name) const { + auto it = std::find_if( + pg_names_.begin(), + pg_names_.end(), + [&](const std::pair, std::string>& + pg_name) { return pg_name.second == name; }); + + if (it == pg_names_.end()) { + std::stringstream error; + error << "Unable to find process group with name: "; + error << name; + error << " , instead we have "; + error << pg_names_.size() << " process groups: {"; + for (const auto& pg : pg_names_) { + error << static_cast(pg.first.get()); + error << " with name: "; + error << pg.second; + error << ", "; + } + error << "}"; + AT_ERROR(error.str()); + } + + TORCH_CHECK(it->first.defined(), "found a process group that's null"); + + return it->first; +} + +std::string DistributedC10d::getNameOfProcessGroup(const c10::intrusive_ptr& pg) const { + auto it = pg_names_.find(pg); + if (it == pg_names_.end()) { + std::stringstream error; + error << "Unable to find name of process group "; + error << static_cast(pg.get()); + error << "instead we have " << pg_names_.size() << " process groups: {"; + for (const auto& pg : pg_names_) { + error << static_cast(pg.first.get()); + error << " with name: "; + error << pg.second; + error << ", "; + } + error << "}"; + AT_ERROR(error.str()); + } + + return it->second; +} + +c10::intrusive_ptr DistributedC10d::newProcessGroupHelper( + const int64_t world_size, + const int64_t rank, + const std::vector& group_ranks, + const std::string& backend_str, + const c10::intrusive_ptr& store, + c10::optional group_name, + int64_t timeout_milisesonds) { + if (!group_name.has_value()) { + group_name = std::to_string(group_count_); + ++group_count_; + } + + auto it = std::find_if( + pg_names_.begin(), + pg_names_.end(), + [&](const std::pair, std::string>& + pg_name) { return pg_name.second == *group_name; }); + + if (it != pg_names_.end()) { + throw std::runtime_error( + "The specified group name has already been " + "created, please use a different group name"); + } + + bool is_default_group = (group_ranks.size() == 0); + + c10::intrusive_ptr pg; + + auto timeout = std::chrono::milliseconds(timeout_milisesonds); + + std::string backend = Backend::get(backend_str); + if (backend == "mpi") { +#ifdef USE_C10D_MPI + std::vector group_ranks_copy(group_ranks.begin(), group_ranks.end()); + pg = ProcessGroupMPI::createProcessGroupMPI(group_ranks_copy); +#else + AT_ERROR( + "Distributed package doesn't have MPI built in." + " MPI is only included if you build PyTorch from" + " source on a host that has MPI installed."); +#endif + } else { + if (!is_default_group) { + int64_t global_rank = default_pg_->getRank(); + if (std::find(group_ranks.begin(), group_ranks.end(), global_rank) == + group_ranks.end()) { + return pg; + } + } + + auto prefix_store = c10::make_intrusive(*group_name, store); + + if (backend == "gloo") { +#ifdef USE_C10D_GLOO + auto options = ProcessGroupGloo::Options(); + + // Use interfaces listed in "GLOO_SOCKET_IFNAME", if set. + char* ifnameEnv = getenv(GLOO_SOCKET_IFNAME_ENV); + if (ifnameEnv) { + for (const auto& iface : split(',', ifnameEnv)) { + options.devices.push_back( + ::c10d::ProcessGroupGloo::createDeviceForInterface(iface)); + } + } else { + // If no hostname is specified, this function looks up + // the machine's hostname and returns a device instance + // associated with the address that the hostname resolves to. + options.devices.push_back( + ::c10d::ProcessGroupGloo::createDefaultDevice()); + } + + options.timeout = timeout; + options.threads = options.devices.size() * 2; + pg = c10::make_intrusive( + prefix_store, rank, world_size, options); +#else + AT_ERROR( + "Attempting to create GLOO-based process group while GLOO is either not enabled or built"); +#endif // USE_C10D_GLOO + } else if (backend == "nccl") { +#ifdef USE_C10D_NCCL + auto options = c10::make_intrusive(); + + options->isHighPriorityStream = false; + options->opTimeout = timeout; + pg = c10::make_intrusive( + prefix_store, rank, world_size, options); +#else + AT_ERROR( + "Attempting to create NCCL-based process group while NCCL is either not enabled or built"); +#endif // USE_C10D_NCCL + } else { + // TODO: discuss to figure out how to extend this to third party backends? + AT_ERROR("Unsupported backend type: ", backend); + } + } + + // register to process group map + pg_map_[pg] = std::make_pair(backend, store); + pg_names_[pg] = *group_name; + return pg; +} + +// Note: We assume that group.WORLD equates default_pg_. Otherwise, +// we need many additional conditionals to check whether group is WORLD and +// then use default_pg_ explicitly. + +int64_t DistributedC10d::getRank( + const c10::intrusive_ptr& group) const { + if (rankNotInGroup(group)) { + return -1; + } + + return group->getRank(); +} + +int64_t DistributedC10d::getWorldSize( + const c10::intrusive_ptr& group) const { + if (rankNotInGroup(group)) { + return -1; + } + + return getGroupSize(group); +} + +int64_t DistributedC10d::getGroupSize( + const c10::intrusive_ptr& group) const { + if (group == default_pg_) { + default_pg_->getSize(); + } + + auto it = pg_group_ranks_.find(group); + TORCH_CHECK(it != pg_group_ranks_.end(), "The given group does not exist"); + + return it->second.size(); +} + +void DistributedC10d::checkDefaultPg() const { + TORCH_CHECK(default_pg_, "Default process group is not initialized"); +} + +c10::intrusive_ptr DistributedC10d::worldProcessGroup() { + checkDefaultPg(); + return default_pg_; +} + +bool DistributedC10d::rankNotInGroup( + const c10::intrusive_ptr& group) const { + if (group == default_pg_) { + return false; + } + return group; +} + +int64_t DistributedC10d::getGroupRank( + const c10::intrusive_ptr& group, + const int64_t rank) const { + TORCH_CHECK( + group != default_pg_, + "group.WORLD does not have local rank to global rank mapping"); + + auto it = pg_group_ranks_.find(group); + TORCH_CHECK(it != pg_group_ranks_.end(), "The given group does not exist"); + + auto& group_rank_map = it->second; + auto g_it = group_rank_map.find(rank); + if (g_it == group_rank_map.end()) { + std::string group_name = "Unknown"; + auto name_it = pg_names_.find(group); + if (name_it != pg_names_.end()) { + group_name = name_it->second; + } + + TORCH_CHECK( + false, + "The global rank ", + rank, + " is not part of the group ", + group_name); + } + + return g_it->second; +} + +int64_t DistributedC10d::getGlobalRank( + const c10::intrusive_ptr& group, + const int64_t group_rank) const { + TORCH_CHECK( + group != default_pg_, + "group.WORLD does not have local rank to global rank mapping"); + + auto it = pg_group_ranks_.find(group); + TORCH_CHECK(it != pg_group_ranks_.end(), "The given group does not exist"); + + auto& group_rank_map = it->second; + for (const auto& p : group_rank_map) { + if (p.second == group_rank) { + return p.first; + } + } + + AT_ERROR("The group rank is not part of the group"); +} + +std::string DistributedC10d::getBackend( + const c10::intrusive_ptr& group) { + TORCH_CHECK(!rankNotInGroup(group), "Invalid process group specified"); + + auto it = pg_map_.find(group); + TORCH_CHECK(it != pg_map_.end(), "The given group does not exist"); + + return it->second.first; +} + +c10::intrusive_ptr DistributedC10d::isend( + at::Tensor tensor, + int64_t dst, + const c10::intrusive_ptr& group, + c10::optional& tag) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + std::vector inputs = {std::move(tensor)}; + + checkDefaultPg(); + if (group == default_pg_) { + return default_pg_->send(inputs, dst, tag.value_or(0)); + } + + auto group_dst_rank = getGroupRank(group, dst); + return group->send(inputs, group_dst_rank, tag.value_or(0)); +} + +c10::intrusive_ptr DistributedC10d::irecv( + at::Tensor tensor, + int64_t src, + const c10::intrusive_ptr& group, + c10::optional& tag) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + std::vector inputs = {std::move(tensor)}; + + checkDefaultPg(); + if (group == default_pg_) { + return default_pg_->recv(inputs, src, tag.value_or(0)); + } + + auto group_dst_rank = getGroupRank(group, src); + return group->recv(inputs, group_dst_rank, tag.value_or(0)); +} + +void DistributedC10d::send( + at::Tensor tensor, + int64_t dst, + const c10::intrusive_ptr& group, + c10::optional& tag) { + auto work = isend(std::move(tensor), dst, group, tag); + if (work) { + work->wait(); + } +} + +int64_t DistributedC10d::recv( + at::Tensor tensor, + const c10::optional& src, + const c10::intrusive_ptr& group, + c10::optional& tag) { + if (rankNotInGroup(group)) { + return -1; + } + + std::vector outputs = {std::move(tensor)}; + if (!src.has_value()) { + auto work = group->recvAnysource(outputs, tag.value_or(0)); + work->wait(); + auto src_rank = work->sourceRank(); + if (group == default_pg_) { + return src_rank; + } + + return getGlobalRank(group, src_rank); + } + + if (group == default_pg_) { + group->recv(outputs, src.value(), tag.value_or(0))->wait(); + } else { + int64_t group_src_rank = getGroupRank(group, src.value()); + group->recv(outputs, group_src_rank, tag.value_or(0))->wait(); + } + + return src.value(); +} + +c10::intrusive_ptr DistributedC10d::broadcastMultiGPU( + std::vector& tensor_list, + int64_t src, + const c10::intrusive_ptr& group, + bool async_op, + int64_t src_tensor) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + BroadcastOptions opts; + opts.rootRank = src; + opts.rootTensor = src_tensor; + + checkDefaultPg(); + c10::intrusive_ptr work; + if (group == default_pg_) { + work = default_pg_->broadcast(tensor_list, opts); + } else { + int64_t group_src_rank = getGroupRank(group, src); + opts.rootRank = group_src_rank; + work = group->broadcast(tensor_list, opts); + } + + if (async_op) { + return work; + } + work->wait(); + return empty_work; +} + +c10::intrusive_ptr DistributedC10d::broadcast( + at::Tensor tensor, + int64_t src, + const c10::intrusive_ptr& group, + bool async_op) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + BroadcastOptions opts; + opts.rootRank = src; + opts.rootTensor = 0; + + std::vector tensors = {std::move(tensor)}; + c10::intrusive_ptr work; + checkDefaultPg(); + if (group == default_pg_) { + work = group->broadcast(tensors, opts); + } else { + int64_t group_src_rank = getGroupRank(group, src); + opts.rootRank = group_src_rank; + work = group->broadcast(tensors, opts); + } + + if (async_op) { + return work; + } + work->wait(); + return empty_work; +} + +c10::intrusive_ptr DistributedC10d::allReduceMultiGPU( + std::vector& tensor_list, + const c10::intrusive_ptr& group, + ReduceOp op, + bool async_op) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + AllreduceOptions opts; + opts.reduceOp = op; + + assertReduceOpSupportsComplexTensor(op); + maybePreprocessComplexTensor(tensor_list); + + auto work = group->allreduce(tensor_list, opts); + if (async_op) { + return work; + } + work->wait(); + return empty_work; +} + +c10::intrusive_ptr DistributedC10d::allReduce( + at::Tensor tensor, + const c10::intrusive_ptr& group, + ReduceOp op, + bool async_op) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + AllreduceOptions opts; + opts.reduceOp = op; + + assertReduceOpSupportsComplexTensor(op); + maybePreprocessComplexTensor(tensor); + + std::vector tensors = {std::move(tensor)}; + auto work = group->allreduce(tensors, opts); + if (async_op) { + return work; + } + work->wait(); + return empty_work; +} + +c10::intrusive_ptr DistributedC10d::allReduceCoalesced( + std::vector& tensors, + const c10::intrusive_ptr& group, + ReduceOp op, + bool async_op) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + AllreduceCoalescedOptions opts; + opts.reduceOp = op; + + assertReduceOpSupportsComplexTensor(op); + maybePreprocessComplexTensor(tensors); + + auto work = group->allreduce_coalesced(tensors, opts); + if (async_op) { + return work; + } + work->wait(); + return empty_work; +} + +c10::intrusive_ptr DistributedC10d::reduceMultiGPU( + std::vector& tensor_list, + int64_t dst, + const c10::intrusive_ptr& group, + ReduceOp op, + bool async_op, + int64_t dst_tensor) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + ReduceOptions opts; + opts.reduceOp = op; + opts.rootRank = dst; + opts.rootTensor = dst_tensor; + + checkDefaultPg(); + + c10::intrusive_ptr work; + if (group == default_pg_) { + work = group->reduce(tensor_list, opts); + } else { + int64_t group_dst_rank = getGroupRank(group, dst); + opts.rootRank = group_dst_rank; + work = group->reduce(tensor_list, opts); + } + + if (async_op) { + return work; + } + work->wait(); + return empty_work; +} + +c10::intrusive_ptr DistributedC10d::reduce( + at::Tensor tensor, + int64_t dst, + const c10::intrusive_ptr& group, + ReduceOp op, + bool async_op) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + ReduceOptions opts; + opts.reduceOp = op; + opts.rootRank = dst; + + checkDefaultPg(); + c10::intrusive_ptr work; + std::vector tensors = {std::move(tensor)}; + if (group == default_pg_) { + work = group->reduce(tensors, opts); + } else { + int64_t group_dst_rank = getGroupRank(group, dst); + opts.rootRank = group_dst_rank; + work = group->reduce(tensors, opts); + } + + if (async_op) { + return work; + } + work->wait(); + return empty_work; +} + +c10::intrusive_ptr DistributedC10d::allGatherMultiGPU( + std::vector>& output_tensor_lists, + std::vector& input_tensor_list, + const c10::intrusive_ptr& group, + bool async_op) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + maybePreprocessComplexTensor(output_tensor_lists); + maybePreprocessComplexTensor(input_tensor_list); + + auto work = group->allgather(output_tensor_lists, input_tensor_list); + + if (async_op) { + return work; + } + work->wait(); + return empty_work; +} + +c10::intrusive_ptr DistributedC10d::allGather( + std::vector& tensor_list, + at::Tensor tensor, + const c10::intrusive_ptr& group, + bool async_op) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + maybePreprocessComplexTensor(tensor_list); + maybePreprocessComplexTensor(tensor); + + std::vector> output_tensor_lists = {std::move(tensor_list)}; + std::vector input_tensor_list = {std::move(tensor)}; + auto work = group->allgather(output_tensor_lists, input_tensor_list); + + if (async_op) { + return work; + } + work->wait(); + return empty_work; +} + +c10::intrusive_ptr DistributedC10d::allGatherCoalesced( + std::vector>& output_tensor_lists, + std::vector& input_tensor_list, + const c10::intrusive_ptr& group, + bool async_op) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + maybePreprocessComplexTensor(output_tensor_lists); + maybePreprocessComplexTensor(input_tensor_list); + + auto work = + group->allgather_coalesced(output_tensor_lists, input_tensor_list); + + if (async_op) { + return work; + } + work->wait(); + return empty_work; +} + +c10::intrusive_ptr DistributedC10d::gather( + at::Tensor tensor, + const c10::optional>& gather_list, + const c10::intrusive_ptr& group, + int64_t dst, + bool async_op) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + auto my_rank = group->getRank(); + + std::vector> output_tensors; + + if (dst == my_rank) { + TORCH_CHECK( + gather_list.has_value(), + "Argument ``gather_list`` must be specified on destination rank"); + output_tensors.push_back(gather_list.value()); + } else { + TORCH_CHECK( + !gather_list.has_value(), + "Argument ``gather_list`` must NOT be specified on non-destination ranks."); + } + + std::vector input_tensors = {std::move(tensor)}; + + GatherOptions opts; + opts.rootRank = dst; + + c10::intrusive_ptr work; + if (group == default_pg_) { + work = group->gather(output_tensors, input_tensors, opts); + } else { + int64_t group_dst_rank = getGroupRank(group, dst); + opts.rootRank = group_dst_rank; + work = group->gather(output_tensors, input_tensors, opts); + } + + if (async_op) { + return work; + } + work->wait(); + return empty_work; +} + +c10::intrusive_ptr DistributedC10d::scatter( + at::Tensor tensor, + std::vector& scatter_list, + const c10::intrusive_ptr& group, + int64_t src, + bool async_op) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + auto my_rank = getRank(default_pg_); + + std::vector output_tensors = {std::move(tensor)}; + std::vector> input_tensors; + if (src == my_rank) { + input_tensors.push_back(scatter_list); + } + + ScatterOptions opts; + opts.rootRank = src; + + c10::intrusive_ptr work; + if (group == default_pg_) { + work = group->scatter(output_tensors, input_tensors, opts); + } else { + int64_t group_src_rank = getGroupRank(group, src); + opts.rootRank = group_src_rank; + work = group->scatter(output_tensors, input_tensors, opts); + } + + if (async_op) { + return work; + } + work->wait(); + return empty_work; +} + +c10::intrusive_ptr DistributedC10d::reduceScatterMultiGPU( + std::vector& output_tensor_list, + std::vector>& input_tensor_lists, + const c10::intrusive_ptr& group, + ReduceOp op, + bool async_op) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + ReduceScatterOptions opts; + opts.reduceOp = op; + + auto work = + group->reduce_scatter(output_tensor_list, input_tensor_lists, opts); + + if (async_op) { + return work; + } + work->wait(); + return empty_work; +} + +c10::intrusive_ptr DistributedC10d::reduceScatter( + at::Tensor output, + std::vector& input_tensor_list, + const c10::intrusive_ptr& group, + ReduceOp op, + bool async_op) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + ReduceScatterOptions opts; + opts.reduceOp = op; + + std::vector output_tensor_list = {std::move(output)}; + std::vector> input_tensor_lists = {std::move(input_tensor_list)}; + + auto work = + group->reduce_scatter(output_tensor_list, input_tensor_lists, opts); + + if (async_op) { + return work; + } + work->wait(); + return empty_work; +} + +c10::intrusive_ptr DistributedC10d::allToAllSingle( + at::Tensor output, + at::Tensor input, + std::vector& output_split_sizes, + std::vector& input_split_sizes, + const c10::intrusive_ptr& group, + bool async_op) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + AllToAllOptions opts; + auto work = group->alltoall_base( + output, input, output_split_sizes, input_split_sizes, opts); + + if (async_op) { + return work; + } + work->wait(); + return empty_work; +} + +c10::intrusive_ptr DistributedC10d::allToAll( + std::vector& output_tensor_list, + std::vector& input_tensor_list, + const c10::intrusive_ptr& group, + bool async_op) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + AllToAllOptions opts; + auto work = group->alltoall(output_tensor_list, input_tensor_list, opts); + + if (async_op) { + return work; + } + work->wait(); + return empty_work; +} + +c10::intrusive_ptr DistributedC10d::barrier( + const c10::intrusive_ptr& group, + bool async_op) { + c10::intrusive_ptr empty_work; + if (rankNotInGroup(group)) { + return empty_work; + } + + auto work = group->barrier(); + + if (async_op) { + return work; + } + work->wait(); + return empty_work; +} + +void DistributedC10d::registerProcessGroupName(const c10::intrusive_ptr& process_group, const std::string& name) { + auto it = std::find_if( + pg_names_.begin(), + pg_names_.end(), + [&](const std::pair, std::string>& + pg_name) { return pg_name.second == name; }); + + if (it != pg_names_.end()) { + TORCH_CHECK( + it->first == process_group, + "Requested name already exists: ", + name, + " and it is associated with a different process group"); + return; + } + + it = pg_names_.find(process_group); + TORCH_CHECK( + it == pg_names_.end(), + "Given process group has been registered before with a different name: ", + it->second); + + pg_names_[process_group] = name; +} + +} // namespace c10d diff --git a/torch/lib/c10d/frontend.hpp b/torch/lib/c10d/frontend.hpp new file mode 100644 index 0000000000000..642c59458f329 --- /dev/null +++ b/torch/lib/c10d/frontend.hpp @@ -0,0 +1,262 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace c10d { + +#ifdef USE_C10D_GLOO +constexpr char* GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME"; +#endif + +inline std::vector split( + char separator, + const std::string& string) { + std::vector pieces; + std::stringstream ss(string); + std::string item; + while (std::getline(ss, item, separator)) { + pieces.push_back(std::move(item)); + } + return pieces; +} + +class Backend { + public: + // Maps to Backend.__new__ in Python. + static std::string get(const std::string&); + + // TODO: How to support registering third_party backend? + static void registerBackend(); + + private: + // TODO: Should this be an enum list instead since this set doesn't + // change at all. + std::unordered_set registered_backends_; +}; + +class TORCH_PYTHON_API DistributedC10d : public torch::CustomClassHolder { + public: + static c10::intrusive_ptr get(); + + DistributedC10d() = default; + + void initProcessGroup( + const std::string& backend, + const std::string& init_method, + const std::chrono::milliseconds& timeout, + int64_t world_size, + int64_t rank, + c10::intrusive_ptr store, + const std::string& group_name); + + void destroyProcessGroup(c10::intrusive_ptr group); + int64_t getRank(const c10::intrusive_ptr& group) const; + int64_t getWorldSize(const c10::intrusive_ptr& group) const; + + c10::intrusive_ptr isend( + at::Tensor tensor, + int64_t dst, + const c10::intrusive_ptr& group, + c10::optional& tag); + + c10::intrusive_ptr irecv( + at::Tensor tensor, + int64_t src, + const c10::intrusive_ptr& group, + c10::optional& tag); + + void send( + at::Tensor tensor, + int64_t dst, + const c10::intrusive_ptr& group, + c10::optional& tag); + + int64_t recv( + at::Tensor tensor, + const c10::optional& src, + const c10::intrusive_ptr& group, + c10::optional& tag); + + c10::intrusive_ptr broadcastMultiGPU( + std::vector& tensor_list, + int64_t src, + const c10::intrusive_ptr& group, + bool async_op = false, + int64_t src_tensor = 0); + + c10::intrusive_ptr broadcast( + at::Tensor tensor, + int64_t src, + const c10::intrusive_ptr& group, + bool async_op = false); + + c10::intrusive_ptr allReduceMultiGPU( + std::vector& tensor_list, + const c10::intrusive_ptr& group, + ReduceOp op = ReduceOp::SUM, + bool async_op = false); + + c10::intrusive_ptr allReduce( + at::Tensor tensor, + const c10::intrusive_ptr& group, + ReduceOp op = ReduceOp::SUM, + bool async_op = false); + + c10::intrusive_ptr allReduceCoalesced( + std::vector& tensors, + const c10::intrusive_ptr& group, + ReduceOp op = ReduceOp::SUM, + bool async_op = false); + + c10::intrusive_ptr reduceMultiGPU( + std::vector& tensor_list, + int64_t dst, + const c10::intrusive_ptr& group, + ReduceOp op = ReduceOp::SUM, + bool async_op = false, + int64_t dst_tensor = 0); + + c10::intrusive_ptr reduce( + at::Tensor tensor, + int64_t dst, + const c10::intrusive_ptr& group, + ReduceOp op = ReduceOp::SUM, + bool async_op = false); + + c10::intrusive_ptr allGatherMultiGPU( + std::vector>& output_tensor_lists, + std::vector& input_tensor_list, + const c10::intrusive_ptr& group, + bool async_op = false); + + c10::intrusive_ptr allGather( + std::vector& tensor_list, + at::Tensor tensor, + const c10::intrusive_ptr& group, + bool async_op = false); + + c10::intrusive_ptr allGatherCoalesced( + std::vector>& output_tensor_lists, + std::vector& input_tensor_list, + const c10::intrusive_ptr& group, + bool async_op = false); + + c10::intrusive_ptr gather( + at::Tensor tensor, + const c10::optional>& gather_list, + const c10::intrusive_ptr& group, + int64_t dst = 0, + bool async_op = false); + + c10::intrusive_ptr scatter( + at::Tensor tensor, + std::vector& scatter_list, + const c10::intrusive_ptr& group, + int64_t src = 0, + bool async_op = false); + + c10::intrusive_ptr reduceScatterMultiGPU( + std::vector& output_tensor_list, + std::vector>& input_tensor_lists, + const c10::intrusive_ptr& group, + ReduceOp op = ReduceOp::SUM, + bool async_op = false); + + c10::intrusive_ptr reduceScatter( + at::Tensor output, + std::vector& input_tensor_list, + const c10::intrusive_ptr& group, + ReduceOp op = ReduceOp::SUM, + bool async_op = false); + + c10::intrusive_ptr allToAllSingle( + at::Tensor output, + at::Tensor input, + std::vector& output_split_sizes, + std::vector& input_split_sizes, + const c10::intrusive_ptr& group, + bool async_op = false); + + c10::intrusive_ptr allToAll( + std::vector& output_tensor_list, + std::vector& input_tensor_list, + const c10::intrusive_ptr& group, + bool async_op = false); + + c10::intrusive_ptr barrier( + const c10::intrusive_ptr& group, + bool async_op = false); + + c10::intrusive_ptr newGroup( + std::vector ranks, + std::chrono::milliseconds timeout, + Backend backend); + + c10::intrusive_ptr worldProcessGroup(); + + c10::intrusive_ptr newProcessGroupHelper( + const int64_t world_size, + const int64_t rank, + const std::vector& group_ranks, + const std::string& backend_str, + const c10::intrusive_ptr& store, + c10::optional group_name, + int64_t timeout_milisesonds); + + c10::intrusive_ptr getProcessGroupByName( + const std::string& name) const; + + std::string getNameOfProcessGroup( + const c10::intrusive_ptr& pg) const; + + void registerProcessGroupName(const c10::intrusive_ptr& process_group, const std::string& name); + + private: + + bool rankNotInGroup(const c10::intrusive_ptr& group) const; + int64_t getGroupRank( + const c10::intrusive_ptr& group, + const int64_t rank) const; + int64_t getGlobalRank( + const c10::intrusive_ptr& group, + const int64_t group_rank) const; + void checkDefaultPg() const; + int64_t getGroupSize(const c10::intrusive_ptr& group) const; + std::string getBackend(const c10::intrusive_ptr& group); + + std::string backend_; + // TODO: Ask Alex what kind of equality we need. It determine whether we + // need to use ProcessGroup or ProcesGroup* as key. + std::unordered_map< + c10::intrusive_ptr, + std::pair>> + pg_map_; + + // Note, this is different mapping relationship than original Python + // implementation. + std::unordered_map, std::string> pg_names_; + + // Process group's global rank to local rank mapping + std::unordered_map< + c10::intrusive_ptr, + std::unordered_map> + pg_group_ranks_; + + c10::intrusive_ptr default_pg_; + + // Default value should be "env://" + std::string default_pg_init_method_; + + int64_t group_count_; +}; + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/lib/c10d/reducer.cpp similarity index 92% rename from torch/csrc/distributed/c10d/reducer.cpp rename to torch/lib/c10d/reducer.cpp index 86916c7994ddd..ad0497724cfab 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/lib/c10d/reducer.cpp @@ -1,4 +1,4 @@ -#include +#include #include @@ -6,13 +6,13 @@ #include #include #include +#include #include #include #include #include #include #include -#include #include namespace c10d { @@ -29,7 +29,7 @@ constexpr int kUnsetDivFactor = -1; Reducer::Reducer( std::vector> replicas, std::vector> bucket_indices, - std::shared_ptr process_group, + c10::intrusive_ptr process_group, std::vector> expect_sparse_gradients, int64_t bucket_bytes_cap, bool find_unused_parameters, @@ -89,10 +89,7 @@ Reducer::Reducer( for (size_t variable_index = 0; variable_index < variable_count; variable_index++) { auto& variable = replicas_[replica_index][variable_index]; - const auto index = VariableIndex{ - .replica_index = replica_index, - .variable_index = variable_index, - }; + const auto index = VariableIndex(replica_index, variable_index); // The gradient accumulator function is lazily initialized once. // Therefore we can use its presence in the autograd graph as @@ -100,15 +97,19 @@ Reducer::Reducer( auto grad_accumulator = torch::autograd::impl::grad_accumulator(variable); +#ifndef _WIN32 using torch::distributed::autograd::ThreadLocalDistAutogradContext; +#endif // Hook to execute after the gradient accumulator has executed. hooks_.emplace_back( grad_accumulator->add_post_hook( torch::make_unique( [=](const torch::autograd::variable_list& outputs, const torch::autograd::variable_list& /* unused */) { +#ifndef _WIN32 this->rpc_context_.set( ThreadLocalDistAutogradContext::getContextPtr()); +#endif this->autograd_hook(index); return outputs; })), @@ -116,8 +117,22 @@ Reducer::Reducer( // Map raw function pointer to replica index and parameter index. // This is used later on when the autograd graph is traversed - // to check for parameters for which no gradient is computed. - func_[grad_accumulator.get()] = index; + // to check for parameters for which no gradient is computed, if + // find_unused_parameters=True. + // We maintain a mapping of gradient accumulator to vector of variables, + // since multiple parameters may share the same grad accumulator. + if (find_unused_parameters_) { + auto gradAcc = gradAccToVariablesMap_.find(grad_accumulator.get()); + if (gradAcc == gradAccToVariablesMap_.end()) { + std::vector indexVec{index}; + gradAccToVariablesMap_[grad_accumulator.get()] = + std::move(indexVec); + } else { + // Scenario where we have indices whose corresponding parameters + // share the same grad accumulator. + gradAcc->second.push_back(index); + } + } // The gradient accumulator is stored as weak_ptr in the autograd // metadata of the variable, so we have to keep it alive here for @@ -193,7 +208,8 @@ Reducer::Reducer( // used for algorithms like Gradient Compression/GossipGrad. This hook can be // registered from Python API using `register_comm_hook`. `PythonCommHook` // enables registering a Python hook and is a subclass of `CommHookInterface`. -// `CommHookInterface` can be used to implement CPP hooks in the future. +// Additionally, there are also some built-in C++ hook implementations that can +// be specified by calling `register_builtin_comm_hook` from Python API. Reducer::~Reducer() noexcept(false) { // Remove all hooks on variables registered by this Reducer. This is necessary @@ -349,7 +365,9 @@ void Reducer::check_grad_layout( } } -void Reducer::copy_grad_to_bucket(at::Tensor& grad, at::Tensor& bucket_view) { +void Reducer::copy_grad_to_bucket( + const at::Tensor& grad, + at::Tensor& bucket_view) { // See Note [DDP Communication Hook] if (comm_hook_ == nullptr) { // imitates wrapped_scalar_tensor in ATen/native/BinaryOps.cpp @@ -454,7 +472,7 @@ std::vector> Reducer::get_bucket_tensors() const { } void Reducer::set_forward_pass_work_handle( - std::shared_ptr forwardPassWorkHandle, + c10::intrusive_ptr forwardPassWorkHandle, bool useStaticWorldSize) { std::lock_guard lock(mutex_); forwardPassWorkHandle_.workHandle = std::move(forwardPassWorkHandle); @@ -477,10 +495,7 @@ void Reducer::push_rebuilt_params_for_all_indices() { const auto variable_count = replicas_[replica_index].size(); for (size_t variable_index = 0; variable_index < variable_count; ++variable_index) { - const auto index = VariableIndex{ - .replica_index = replica_index, - .variable_index = variable_index, - }; + const auto index = VariableIndex(replica_index, variable_index); push_rebuilt_params(index); } } @@ -697,7 +712,15 @@ void Reducer::mark_bucket_ready(size_t bucket_index) { if (comm_hook_ == nullptr) { bucket.work = process_group_->allreduce(tensors); } else { - bucket.future_work = comm_hook_->runHook(GradBucket(tensors)); + GradBucket grad_bucket( + next_bucket_, + tensors, + // Since currently we do not support single-process multiple-device + // mode, we can assume only one replica in the bucket. + bucket.replicas[0].offsets, + bucket.replicas[0].lengths, + bucket.replicas[0].sizes_vec); + bucket.future_work = comm_hook_->runHook(grad_bucket); } } } @@ -712,8 +735,10 @@ void Reducer::initialize_buckets( // bucket_view, then it needs to check rpc context ptr is nullptr or not, // If rpc context ptr is nullptr, mutate variable.grad(); otherwise, // mutate grad in rpc context. +#ifndef _WIN32 using torch::distributed::autograd::ThreadLocalDistAutogradContext; this->rpc_context_.set(ThreadLocalDistAutogradContext::getContextPtr()); +#endif // This shouldn't be called if we're expecting autograd hooks to fire. TORCH_CHECK( @@ -765,8 +790,17 @@ void Reducer::initialize_buckets( replica.variables = {variable}; } else { at::TensorOptions options; + // The start index of the variable in the flattened tensor. size_t offset = 0; + // Reserve enough space for the per-variable fields stored in bucket + // replica for efficiency. + const size_t num_variables = bucket_indices[bucket_index].size(); + replica.variables.reserve(num_variables); + replica.offsets.reserve(num_variables); + replica.lengths.reserve(num_variables); + replica.sizes_vec.reserve(num_variables); + // Iterate over bucket variables. for (const auto variable_index : bucket_indices[bucket_index]) { TORCH_CHECK( @@ -792,6 +826,7 @@ void Reducer::initialize_buckets( replica.variables.push_back(variable); replica.offsets.push_back(offset); replica.lengths.push_back(length); + replica.sizes_vec.push_back(variable.sizes()); offset += length; } @@ -850,10 +885,8 @@ void Reducer::initialize_buckets( TORCH_CHECK( variable_index < variable_locators_.size(), "Out of range variable index specified."); - variable_locators_[variable_index] = VariableLocator{ - .bucket_index = bucket_index, - .intra_bucket_index = intra_bucket_index++, - }; + variable_locators_[variable_index] = + VariableLocator(bucket_index, intra_bucket_index++); } bucket.variable_indices = std::move(bucket_indices[bucket_index]); @@ -948,31 +981,6 @@ void Reducer::prepare_for_backward( std::unordered_set seen; std::vector queue; - // Check that any prior reduction has finished. - // The variable `require_finalize_` is true until all gradients - // have been computed and reduction of all buckets has been kicked off. - if (require_finalize_) { - TORCH_CHECK( - false, - "Expected to have finished reduction in the prior iteration before ", - "starting a new one. ", - "", - "This error indicates that your module has parameters that were ", - "not used in producing loss. ", - "", - "You can enable unused parameter detection by (1) passing the keyword " - "argument `find_unused_parameters=True` to ", - "`torch.nn.parallel.DistributedDataParallel`; (2) making sure all ", - "`forward` function outputs participate in calculating loss. " - "", - "If you already have done the above two steps, then the distributed ", - "data parallel module wasn't able to locate the output tensors in the ", - "return value of your module's `forward` function. ", - "Please include the loss function and the structure of the return ", - "value of `forward` of your module when reporting this issue (e.g. ", - "list, dict, iterable)."); - } - // Reset accounting. expect_autograd_hooks_ = true; next_bucket_ = 0; @@ -1018,14 +1026,28 @@ void Reducer::prepare_for_backward( } // Find accumulator functions that don't show up in this graph. - for (const auto& it : func_) { + for (const auto& it : gradAccToVariablesMap_) { // If the accumulator function is present in the graph, we know // a gradient will be computed for the corresponding parameter. - if (seen.count(it.first) > 0) { - continue; + if (seen.count(it.first) == 0) { + auto& indices = it.second; + unused_parameters_.reserve(unused_parameters_.size() + indices.size()); + unused_parameters_.insert( + unused_parameters_.end(), indices.begin(), indices.end()); } + } - unused_parameters_.push_back(it.second); + // Warn user about unnecessary perf hit if all parameters were used. + if (unused_parameters_.empty()) { + TORCH_WARN_ONCE( + "find_unused_parameters=True was specified in DDP constructor, " + "but did not find any unused parameters. This flag results in an extra " + "traversal of the autograd graph every iteration, which can adversely " + "affect performance. If your model indeed never has any unused " + "parameters, consider turning this flag off. Note that this warning may " + "be a false positive your model has flow control causing later iterations " + "to have unused parameters." + ); } } @@ -1187,7 +1209,7 @@ void Reducer::finalize_backward() { bucket.future_work->wait(); auto future_result = - comm_hook_->processFuture(bucket.future_work->value()); + comm_hook_->parseHookResult(bucket.future_work->value()); for (size_t i = 0; i < future_result.size(); i++) { auto& replica = bucket.replicas[i]; @@ -1235,7 +1257,9 @@ void Reducer::runGradCallbackForVariable( cb(variable.mutable_grad()); } else { // Under distributed autograd +#ifndef _WIN32 context_ptr->runGradCallbackForVariable(variable, std::move(cb)); +#endif } } @@ -1325,6 +1349,11 @@ void Reducer::sync_bucket_indices( } bool Reducer::rebuild_buckets() { + // Ensure reduction for previous backwards pass is finished. If user's model + // has unused parameters for example, this will raise an error recommending to + // run with find_unused_parameters=True, instead of the size mismatch + // exception below. + ensure_prior_reduction_finished(); std::lock_guard lock(mutex_); if (!should_rebuild_buckets() || rebuilt_params_.empty()) { return false; @@ -1341,8 +1370,9 @@ bool Reducer::rebuild_buckets() { replicas_[0].size() == rebuilt_param_indices_.size(), c10::str( "rebuilt parameter indices size is not same as original model parameters size.", + "Original model param size is: ", replicas_[0].size(), - " versus ", + " versus rebuilt params size of: ", rebuilt_param_indices_.size())); std::vector> rebuilt_bucket_indices; std::vector bucket_size_limits; @@ -1370,7 +1400,8 @@ bool Reducer::rebuild_buckets() { // See Note [DDP Communication Hook] void Reducer::register_comm_hook(std::unique_ptr iface) { TORCH_CHECK( - comm_hook_ == nullptr, "register_comm_hook can only be called once."); + comm_hook_ == nullptr, + "register_comm_hook or register_builtin_comm_hook can only be called once."); // TODO(@sinannasir): Single-process multiple-device mode support for DDP // communication hook. Related to GH Issue #42542. TORCH_CHECK( @@ -1380,6 +1411,60 @@ void Reducer::register_comm_hook(std::unique_ptr iface) { comm_hook_ = std::move(iface); } +// See Note [DDP Communication Hook] +void Reducer::register_builtin_comm_hook( + c10d::BuiltinCommHookType comm_hook_type) { + TORCH_CHECK( + comm_hook_ == nullptr, + "register_builtin_comm_hook or register_comm_hook can only be called once."); + TORCH_CHECK( + replicas_.size() == 1, + "Communication hook does not support single-process multiple-device mode."); + + switch (comm_hook_type) { + case c10d::BuiltinCommHookType::ALLREDUCE: + comm_hook_ = + std::make_unique(process_group_.get()); + LOG(INFO) << "Built-in communication hook ALLREDUCE is registered."; + break; + case c10d::BuiltinCommHookType::FP16_COMPRESS: + comm_hook_ = + std::make_unique(process_group_.get()); + LOG(INFO) << "Built-in communication hook FP16_COMPRESS is registered."; + break; + default: + TORCH_WARN_ONCE( + "Unknown built-in DDP comm hook type is provided. No comm hook will be used."); + } +} + +void Reducer::ensure_prior_reduction_finished() { + // Check that any prior reduction has finished. + // The variable `require_finalize_` is true until all gradients + // have been computed and reduction of all buckets has been kicked off. + if (require_finalize_) { + TORCH_CHECK( + false, + "Expected to have finished reduction in the prior iteration before ", + "starting a new one. ", + "", + "This error indicates that your module has parameters that were ", + "not used in producing loss. ", + "", + "You can enable unused parameter detection by (1) passing the keyword " + "argument `find_unused_parameters=True` to ", + "`torch.nn.parallel.DistributedDataParallel`; (2) making sure all ", + "`forward` function outputs participate in calculating loss. " + "", + "If you already have done the above two steps, then the distributed ", + "data parallel module wasn't able to locate the output tensors in the ", + "return value of your module's `forward` function. ", + "Please include the loss function and the structure of the return ", + "value of `forward` of your module when reporting this issue (e.g. ", + "list, dict, iterable)."); + } +} + namespace { // Tensors may be coalesced into buckets. Buckets must contain tensors of diff --git a/torch/csrc/distributed/c10d/reducer.h b/torch/lib/c10d/reducer.hpp similarity index 88% rename from torch/csrc/distributed/c10d/reducer.h rename to torch/lib/c10d/reducer.hpp index 960a32356acf3..ada39844a9ca5 100644 --- a/torch/csrc/distributed/c10d/reducer.h +++ b/torch/lib/c10d/reducer.hpp @@ -7,11 +7,13 @@ #include #include +#include +#include #include +#include #include #include #include -#include namespace c10d { @@ -27,7 +29,7 @@ class Reducer { explicit Reducer( std::vector> replicas, std::vector> bucket_indices, - std::shared_ptr process_group, + c10::intrusive_ptr process_group, std::vector> expect_sparse_gradients, int64_t bucket_bytes_cap, bool find_unused_parameters, @@ -55,11 +57,17 @@ class Reducer { return backward_stats_; } - // Registeres a hook to the reducer. The hook is `CommHookInterface` + // Registers a hook to the reducer. The hook is `CommHookInterface` // type to allow both Python and CPP hooks. This function can only // be called once before calling backward. + // Cannot combine with the call of `register_builtin_comm_hook`. void register_comm_hook(std::unique_ptr iface); + // Registers a built-in C++ comm hook to the reducer. This function can only + // be called once before calling backward. + // Cannot combine with the call of `register_comm_hook`. + void register_builtin_comm_hook(c10d::BuiltinCommHookType comm_hook_type); + // Returns a vector of tensors in each bucket in sequential order. std::vector> get_bucket_tensors() const; @@ -89,7 +97,7 @@ class Reducer { // Creates and sets ForwardPassWorkHandle given a ProcessGroup::Work and the // corresponding tensor being reduced. void set_forward_pass_work_handle( - std::shared_ptr forwardPassWorkHandle, + c10::intrusive_ptr forwardPassWorkHandle, bool useStaticWorldSize); // Retrieve on-device tensors used to track locally unused parameters. For @@ -104,18 +112,26 @@ class Reducer { struct VariableIndex { size_t replica_index; size_t variable_index; + + VariableIndex() = default; + + VariableIndex(size_t replica_index_, size_t variable_index_) { + replica_index = replica_index_; + variable_index = variable_index_; + } }; void push_rebuilt_params(const VariableIndex& index); mutable std::mutex mutex_; std::vector> replicas_; - std::shared_ptr process_group_; + c10::intrusive_ptr<::c10d::ProcessGroup> process_group_; std::vector> expect_sparse_gradients_; std::vector>> grad_accumulators_; - std::unordered_map func_; + std::unordered_map> + gradAccToVariablesMap_; std::vector>> hooks_; @@ -143,7 +159,7 @@ class Reducer { bool local_used_maps_reduced_; // Work handle for allreduce on local_used_maps_ - std::shared_ptr local_used_work_; + c10::intrusive_ptr local_used_work_; void verify_replicas_within_process(); @@ -163,6 +179,10 @@ class Reducer { void finalize_backward(); + // Asserts that the reduction for the previous iteration has finished before + // rebuilding buckets or kicking off the next one. + void ensure_prior_reduction_finished(); + // Broadcast rebuilt buckets from rank 0 to other ranks before initializing // the buckets void sync_bucket_indices(std::vector>& bucket_indices); @@ -206,10 +226,13 @@ class Reducer { // participating variables after reduction has completed. std::vector variables; - // Per-variable offset/length into the flat bucket contents tensor. + // Per-variable offset/length into the flat bucket contents tensor and grad bucket. std::vector offsets; std::vector lengths; + // Per-variable sizes into the grad bucekt. + std::vector sizes_vec; + // Number of tensors to be added before this bucket is complete. // This is reset to `variables.size()` every iteration. size_t pending; @@ -243,7 +266,7 @@ class Reducer { void check_grad_layout(const at::Tensor& grad, const at::Tensor& bucket_view); // If gradient_as_bucket_view_ is false, before allreduce buckets, // copy grads to buckets. - void copy_grad_to_bucket(at::Tensor& grad, at::Tensor& bucket_view); + void copy_grad_to_bucket(const at::Tensor& grad, at::Tensor& bucket_view); // A bucket holds N bucket replicas (1 per model replica). // @@ -260,7 +283,7 @@ class Reducer { size_t pending; // Keep work handle around when this set of buckets is being reduced. - std::shared_ptr work; + c10::intrusive_ptr work; // Keep future work handle around if DDP comm hook is registered. c10::intrusive_ptr future_work; @@ -281,6 +304,13 @@ class Reducer { size_t bucket_index; // Index of parameter in single bucket replica. size_t intra_bucket_index; + + VariableLocator() = default; + + VariableLocator(size_t bucket_index_, size_t intra_bucket_index_) { + bucket_index = bucket_index_; + intra_bucket_index = intra_bucket_index_; + } }; // Map the index of a variable to its location in the bucket structure. @@ -311,7 +341,7 @@ class Reducer { // A struct containing work handle and tensor for allreduce scheduled in // forward pass, if applicable. struct ForwardPassAllreduceWork { - std::shared_ptr workHandle; + c10::intrusive_ptr workHandle; at::Tensor resultTensor; // whether we should divide by the initial world_size or the no. of // remaining DDP ranks. diff --git a/torch/lib/c10d/test/CMakeLists.txt b/torch/lib/c10d/test/CMakeLists.txt index 8429d1099b298..b74d4b65f70f7 100644 --- a/torch/lib/c10d/test/CMakeLists.txt +++ b/torch/lib/c10d/test/CMakeLists.txt @@ -8,14 +8,19 @@ function(c10d_add_test test_src) get_filename_component(test_name ${test_src} NAME_WE) add_executable(${test_name} "${test_src}") target_include_directories(${test_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..) - target_link_libraries(${test_name} pthread ${ARGN}) - target_compile_options(${test_name} PRIVATE -Wno-error) + target_link_libraries(${test_name} ${ARGN}) + if(NOT WIN32) + target_link_libraries(${test_name} pthread) + target_compile_options(${test_name} PRIVATE -Wno-error) + endif() add_test(NAME ${test_name} COMMAND $) endfunction() c10d_add_test(FileStoreTest.cpp c10d gtest_main) -c10d_add_test(HashStoreTest.cpp c10d gtest_main) c10d_add_test(TCPStoreTest.cpp c10d gtest_main) +if(NOT WIN32) + c10d_add_test(HashStoreTest.cpp c10d gtest_main) +endif() if(USE_CUDA) if(USE_C10D_GLOO) @@ -29,7 +34,7 @@ if(USE_CUDA) endif() else() if(USE_C10D_GLOO) - c10d_add_test(ProcessGroupGlooTest.cpp c10d c10d gtest_main) + c10d_add_test(ProcessGroupGlooTest.cpp c10d gtest_main) endif() endif() diff --git a/torch/lib/c10d/test/CUDATest.cu b/torch/lib/c10d/test/CUDATest.cu index c47b29ea536d1..88f87492206c7 100644 --- a/torch/lib/c10d/test/CUDATest.cu +++ b/torch/lib/c10d/test/CUDATest.cu @@ -17,6 +17,7 @@ __global__ void waitClocks(const uint64_t count) { void cudaSleep(at::cuda::CUDAStream& stream, uint64_t clocks) { waitClocks<<<1, 1, 0, stream.stream()>>>(clocks); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } int cudaNumDevices() { diff --git a/torch/lib/c10d/test/CUDATest.hpp b/torch/lib/c10d/test/CUDATest.hpp index defaff895a18f..328da2faf648f 100644 --- a/torch/lib/c10d/test/CUDATest.hpp +++ b/torch/lib/c10d/test/CUDATest.hpp @@ -5,9 +5,15 @@ namespace c10d { namespace test { -void cudaSleep(at::cuda::CUDAStream& stream, uint64_t clocks); +#ifdef _WIN32 +#define EXPORT_TEST_API __declspec(dllexport) +#else +#define EXPORT_TEST_API +#endif -int cudaNumDevices(); +EXPORT_TEST_API void cudaSleep(at::cuda::CUDAStream& stream, uint64_t clocks); + +EXPORT_TEST_API int cudaNumDevices(); } // namespace test } // namespace c10d diff --git a/torch/lib/c10d/test/FileStoreTest.cpp b/torch/lib/c10d/test/FileStoreTest.cpp index 77215f4521c27..b24725e64e13e 100644 --- a/torch/lib/c10d/test/FileStoreTest.cpp +++ b/torch/lib/c10d/test/FileStoreTest.cpp @@ -1,6 +1,8 @@ #include +#ifndef _WIN32 #include +#endif #include #include @@ -10,6 +12,11 @@ #include #include +#ifdef _WIN32 +std::string tmppath() { + return c10d::test::autoGenerateTmpFilePath(); +} +#else std::string tmppath() { const char* tmpdir = getenv("TMPDIR"); if (tmpdir == nullptr) { @@ -29,11 +36,12 @@ std::string tmppath() { close(fd); return std::string(tmp.data(), tmp.size()); } +#endif void testGetSet(std::string path, std::string prefix = "") { // Basic Set/Get on File Store { - auto fileStore = std::make_shared(path, 2); + auto fileStore = c10::make_intrusive(path, 2); c10d::PrefixStore store(prefix, fileStore); c10d::test::set(store, "key0", "value0"); c10d::test::set(store, "key1", "value1"); @@ -41,13 +49,19 @@ void testGetSet(std::string path, std::string prefix = "") { c10d::test::check(store, "key0", "value0"); c10d::test::check(store, "key1", "value1"); c10d::test::check(store, "key2", "value2"); + auto numKeys = fileStore->getNumKeys(); + EXPECT_EQ(numKeys, 3); } // Perform get on new instance { - auto fileStore = std::make_shared(path, 2); + auto fileStore = c10::make_intrusive(path, 2); c10d::PrefixStore store(prefix, fileStore); c10d::test::check(store, "key0", "value0"); + auto numKeys = fileStore->getNumKeys(); + // There will be 4 keys since we still use the same underlying file as the + // other store above. + EXPECT_EQ(numKeys, 4); } } @@ -61,7 +75,8 @@ void stressTestStore(std::string path, std::string prefix = "") { for (auto i = 0; i < numThreads; i++) { threads.push_back(std::thread([&] { - auto fileStore = std::make_shared(path, numThreads + 1); + auto fileStore = + c10::make_intrusive(path, numThreads + 1); c10d::PrefixStore store(prefix, fileStore); sem1.post(); sem2.wait(); @@ -79,7 +94,7 @@ void stressTestStore(std::string path, std::string prefix = "") { // Check that the counter has the expected value { - auto fileStore = std::make_shared(path, numThreads + 1); + auto fileStore = c10::make_intrusive(path, numThreads + 1); c10d::PrefixStore store(prefix, fileStore); std::string expected = std::to_string(numThreads * numIterations); c10d::test::check(store, "counter", expected); diff --git a/torch/lib/c10d/test/HashStoreTest.cpp b/torch/lib/c10d/test/HashStoreTest.cpp index f2197861c787d..24b7fc76a417d 100644 --- a/torch/lib/c10d/test/HashStoreTest.cpp +++ b/torch/lib/c10d/test/HashStoreTest.cpp @@ -11,7 +11,7 @@ void testGetSet(std::string prefix = "") { // Basic set/get { - auto hashStore = std::make_shared(); + auto hashStore = c10::make_intrusive(); c10d::PrefixStore store(prefix, hashStore); c10d::test::set(store, "key0", "value0"); c10d::test::set(store, "key1", "value1"); @@ -19,11 +19,20 @@ void testGetSet(std::string prefix = "") { c10d::test::check(store, "key0", "value0"); c10d::test::check(store, "key1", "value1"); c10d::test::check(store, "key2", "value2"); + auto numKeys = store.getNumKeys(); + EXPECT_EQ(numKeys, 3); + auto delSuccess = store.deleteKey("key0"); + EXPECT_TRUE(delSuccess); + numKeys = store.getNumKeys(); + EXPECT_EQ(numKeys, 2); + auto delFailure = store.deleteKey("badKeyName"); + EXPECT_FALSE(delFailure); + EXPECT_THROW(store.get("key0"), std::runtime_error); } // get() waits up to timeout_. { - auto hashStore = std::make_shared(); + auto hashStore = c10::make_intrusive(); c10d::PrefixStore store(prefix, hashStore); std::thread th([&]() { c10d::test::set(store, "key0", "value0"); }); c10d::test::check(store, "key0", "value0"); @@ -38,7 +47,7 @@ void stressTestStore(std::string prefix = "") { std::vector threads; c10d::test::Semaphore sem1, sem2; - auto hashStore = std::make_shared(); + auto hashStore = c10::make_intrusive(); c10d::PrefixStore store(prefix, hashStore); for (auto i = 0; i < numThreads; i++) { diff --git a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp index 92dede9a573e4..091ea9b2ad073 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp @@ -45,7 +45,7 @@ class AsyncTest { } void start(int rank, int size) { - auto store = std::make_shared<::c10d::FileStore>(path_, size); + auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); // Use tiny timeout to make this test run fast ::c10d::ProcessGroupGloo::Options options; @@ -93,7 +93,7 @@ class AsyncInputIsOutputTest : public AsyncTest { } } - void wait(std::shared_ptr& work) { + void wait(c10::intrusive_ptr& work) { at::cuda::CUDAMultiStreamGuard guard(streams_); work->wait(); } @@ -130,7 +130,7 @@ class AsyncAllreduceTest : public AsyncInputIsOutputTest { AsyncAllreduceTest(const std::string& path, int numTensors) : AsyncInputIsOutputTest(path, numTensors) {} - std::shared_ptr run() { + c10::intrusive_ptr run() { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -156,7 +156,7 @@ class AsyncBroadcastTest : public AsyncInputIsOutputTest { AsyncBroadcastTest(const std::string& path, int numTensors) : AsyncInputIsOutputTest(path, numTensors) {} - std::shared_ptr run(int rootRank, int rootTensor) { + c10::intrusive_ptr run(int rootRank, int rootTensor) { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -185,7 +185,7 @@ void runAsyncAllreduceTest( size_t numProcesses = 4, size_t numTensors = 2) { auto tests = initialize(path, numProcesses, numTensors); - std::vector> work(numProcesses); + std::vector> work(numProcesses); for (size_t i = 0; i < numProcesses; i++) { work[i] = tests[i].run(); } @@ -229,7 +229,7 @@ void runAsyncBroadcastTest( // Try every permutation of root rank and root tensor for (size_t rootRank = 0; rootRank < numProcesses; rootRank++) { for (size_t rootTensor = 0; rootTensor < numTensors; rootTensor++) { - std::vector> work(numProcesses); + std::vector> work(numProcesses); for (size_t i = 0; i < numProcesses; i++) { work[i] = tests[i].run(rootRank, rootTensor); } diff --git a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp index 6606e553e7330..469cf32a84429 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp @@ -1,7 +1,10 @@ +#ifndef _WIN32 #include -#include #include #include +#endif + +#include #include #include @@ -21,6 +24,7 @@ using namespace c10d::test; constexpr auto kSendDelay = std::chrono::milliseconds(100); constexpr auto kWaitTimeout = std::chrono::milliseconds(1); +#ifndef _WIN32 class SignalTest { public: SignalTest(const std::string& path) : path_(path) {} @@ -40,8 +44,8 @@ class SignalTest { }); } - std::shared_ptr<::c10d::ProcessGroup::Work> run(int rank, int size) { - auto store = std::make_shared<::c10d::FileStore>(path_, size); + c10::intrusive_ptr<::c10d::ProcessGroup::Work> run(int rank, int size) { + auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); ::c10d::ProcessGroupGloo::Options options; // Set a timeout that is small enough to make this test run fast, but also @@ -58,7 +62,7 @@ class SignalTest { }; // Loop until an exception happens - std::shared_ptr<::c10d::ProcessGroup::Work> work; + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work; while (true) { work = pg.allreduce(tensors); try { @@ -78,7 +82,7 @@ class SignalTest { Semaphore sem_; }; -std::shared_ptr<::c10d::ProcessGroup::Work> testSignal( +c10::intrusive_ptr<::c10d::ProcessGroup::Work> testSignal( const std::string& path, int signal) { Fork fork; @@ -92,17 +96,18 @@ std::shared_ptr<::c10d::ProcessGroup::Work> testSignal( test.arm(fork.pid, signal); return test.run(0, 2); } +#endif class ProcessGroupGlooDelayed : public ::c10d::ProcessGroupGloo { public: ProcessGroupGlooDelayed( - const std::shared_ptr<::c10d::Store>& store, + const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, Options options) : ProcessGroupGloo(store, rank, size, options) {} - std::shared_ptr<::c10d::ProcessGroup::Work> send( + c10::intrusive_ptr<::c10d::ProcessGroup::Work> send( std::vector& tensors, int dstRank, int tag) override { @@ -146,7 +151,7 @@ class CollectiveTest { } void start(int rank, int size, bool delayed) { - auto store = std::make_shared<::c10d::FileStore>(path_, size); + auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); // Set a timeout that is small enough to make this test run fast, but also // make sure that we don't get timeouts in the ProcessGroupGloo constructor. @@ -195,7 +200,7 @@ void testAllreduce(const std::string& path, const at::DeviceType b) { } // Kick off work - std::vector> work(size); + std::vector> work(size); for (auto i = 0; i < size; i++) { work[i] = tests[i].getProcessGroup().allreduce(inputs[i]); } @@ -245,7 +250,7 @@ void testBroadcast(const std::string& path, const at::DeviceType b) { options.rootTensor = j; // Kick off work - std::vector> work(size); + std::vector> work(size); for (auto i = 0; i < size; i++) { work[i] = tests[i].getProcessGroup().broadcast(inputs[i], options); } @@ -311,7 +316,7 @@ void testAlltoall(const std::string& path, const at::DeviceType b) { }; // Kick off work - std::vector> work(size); + std::vector> work(size); for (auto rank = 0; rank < size; rank++) { work[rank] = tests[rank].getProcessGroup().alltoall_base( outputs[rank], inputs[rank], outputSplits[rank], inputSplits[rank]); @@ -344,7 +349,7 @@ void testBarrier(const std::string& path) { auto tests = CollectiveTest::initialize(path, size); // Kick off work - std::vector> work(size); + std::vector> work(size); for (auto i = 0; i < size; i++) { work[i] = tests[i].getProcessGroup().barrier(); } @@ -456,6 +461,7 @@ void testRecv(const std::string& path) { EXPECT_TRUE(recvCompleted); } +#ifndef _WIN32 TEST(ProcessGroupGlooTest, testSIGSTOPException) { // test SIGSTOP // Fork() and TSAN don't play well together, so skip the test if we're testing @@ -485,6 +491,7 @@ TEST(ProcessGroupGlooTest, testSIGKILLException) { EXPECT_FALSE(work->isSuccess()); EXPECT_THROW(std::rethrow_exception(work->exception()), std::exception); } +#endif TEST(ProcessGroupGlooTest, testAllReduceCPU) { { diff --git a/torch/lib/c10d/test/ProcessGroupMPITest.cpp b/torch/lib/c10d/test/ProcessGroupMPITest.cpp index 3f5a9e4cf3314..5503b4cde866f 100644 --- a/torch/lib/c10d/test/ProcessGroupMPITest.cpp +++ b/torch/lib/c10d/test/ProcessGroupMPITest.cpp @@ -13,8 +13,8 @@ // Wait for work to complete void waitWork( - std::shared_ptr pg, - std::vector> works) { + c10::intrusive_ptr<::c10d::ProcessGroupMPI> pg, + std::vector> works) { for (auto& work : works) { try { work->wait(); @@ -34,10 +34,11 @@ void testAllreduce(int iter = 1000) { allTensors[i] = std::vector({tensor}); } - std::vector> works; + std::vector> works; for (auto& tensors : allTensors) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->allreduce(tensors); + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = + pg->allreduce(tensors); works.push_back(std::move(work)); } @@ -73,10 +74,11 @@ void testBroadcast(int iter = 10000) { } } - std::vector> works; + std::vector> works; for (auto& tensors : allTensors) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->broadcast(tensors); + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = + pg->broadcast(tensors); works.push_back(std::move(work)); } @@ -104,10 +106,10 @@ void testReduce(int iter = 10000) { allTensors[i] = std::vector({tensor}); } - std::vector> works; + std::vector> works; for (auto& tensors : allTensors) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = pg->reduce(tensors); + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->reduce(tensors); works.push_back(std::move(work)); } @@ -150,10 +152,10 @@ void testAllgather(int iter = 10000) { } } - std::vector> works; + std::vector> works; for (size_t i = 0; i < allTensors.size(); ++i) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->allgather(allOutputTensors[i], allTensors[i]); works.push_back(std::move(work)); } @@ -198,10 +200,10 @@ void testGather(int iter = 10000) { } } - std::vector> works; + std::vector> works; for (size_t i = 0; i < allTensors.size(); ++i) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->gather(allOutputTensors[i], allTensors[i]); works.push_back(std::move(work)); } @@ -249,10 +251,10 @@ void testScatter(int iter = 1) { } } - std::vector> works; + std::vector> works; for (size_t i = 0; i < allTensors.size(); ++i) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->scatter(allTensors[i], allInputTensors[i]); works.push_back(std::move(work)); } @@ -289,27 +291,27 @@ void testSendRecv(bool recvAnysource, int iter = 10000) { } if (rank == 0) { - std::vector> works; + std::vector> works; for (auto& tensors : allTensors) { // Kick off work - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->send(tensors, 1, 0); works.push_back(std::move(work)); } waitWork(pg, works); } if (rank == 1) { - std::vector> works; + std::vector> works; std::vector srcRanks(allTensors.size(), -1); size_t i = 0; for (auto& tensors : allTensors) { // Kick off work if (!recvAnysource) { - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->recv(tensors, 0, 0); works.push_back(std::move(work)); } else { - std::shared_ptr<::c10d::ProcessGroup::Work> work = + c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = pg->recvAnysource(tensors, 0); works.push_back(std::move(work)); } diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp index 93f633938e180..3dbd266553911 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp @@ -16,8 +16,10 @@ class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL { public: WorkNCCLSimulateErrors( const std::vector& devices, - bool simulate_error) - : WorkNCCL(devices), simulate_error_(simulate_error) {} + bool simulate_error, + int rank, + c10d::OpType opType) + : WorkNCCL(devices, rank, opType), simulate_error_(simulate_error) {} std::exception_ptr checkForNCCLErrors( const std::vector>& ncclComms) @@ -35,10 +37,10 @@ class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL { class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { public: ProcessGroupNCCLSimulateErrors( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, - c10d::ProcessGroupNCCL::Options opts) + c10::intrusive_ptr opts) : ProcessGroupNCCL(store, rank, size, opts), simulate_error_(false) {} std::exception_ptr checkForNCCLErrors( @@ -54,9 +56,13 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { ProcessGroupNCCLSimulateErrors::kWatchdogThreadSleepMillis); } - std::shared_ptr initWork( - std::vector devices) override { - return std::make_shared(devices, simulate_error_); + c10::intrusive_ptr initWork( + std::vector devices, + int rank, + c10d::OpType opType, + const char* profilingTitle) override { + return c10::make_intrusive( + devices, simulate_error_, rank, opType); } size_t getNCCLCommCacheSize() { @@ -79,8 +85,11 @@ class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL { public: WorkNCCLTimedoutErrors( const std::vector& devices, - bool set_timedout_error) - : WorkNCCL(devices), set_timedout_error_(set_timedout_error) {} + bool set_timedout_error, + int rank, + c10d::OpType opType) + : WorkNCCL(devices, rank, opType), + set_timedout_error_(set_timedout_error) {} private: bool isCompleted() override { @@ -97,17 +106,20 @@ class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL { class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { public: ProcessGroupNCCLTimedOutErrors( - const std::shared_ptr& store, + const c10::intrusive_ptr& store, int rank, int size, - c10d::ProcessGroupNCCL::Options opts) + c10::intrusive_ptr opts) : ProcessGroupNCCLSimulateErrors(store, rank, size, opts), set_timedout_error_(false) {} - std::shared_ptr initWork( - std::vector devices) override { - return std::make_shared( - devices, set_timedout_error_); + c10::intrusive_ptr initWork( + std::vector devices, + int rank, + c10d::OpType opType, + const char* profilingTitle) override { + return c10::make_intrusive( + devices, set_timedout_error_, rank, opType); } void set_timedout_error() { @@ -141,7 +153,7 @@ class ProcessGroupNCCLErrorsTest : public ::testing::Test { void SetUp() override { size_t numDevices = cudaNumDevices(); TemporaryFile file; - store_ = std::make_shared<::c10d::FileStore>(file.path, 1); + store_ = c10::make_intrusive<::c10d::FileStore>(file.path, 1); at::cuda::OptionalCUDAGuard deviceGuard; tensors_.resize(numDevices); @@ -156,7 +168,7 @@ class ProcessGroupNCCLErrorsTest : public ::testing::Test { } std::vector tensors_; - std::shared_ptr<::c10d::FileStore> store_; + c10::intrusive_ptr<::c10d::FileStore> store_; }; TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) { @@ -165,8 +177,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) { } ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0); - c10d::ProcessGroupNCCL::Options options; - options.opTimeout = std::chrono::milliseconds(1000); + auto options = c10d::ProcessGroupNCCL::Options::create(); + options->opTimeout = std::chrono::milliseconds(1000); ProcessGroupNCCLSimulateErrors pg( store_, 0, 1, options); @@ -194,8 +206,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) { } ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0); - c10d::ProcessGroupNCCL::Options options; - options.opTimeout = std::chrono::milliseconds(3000); + auto options = c10d::ProcessGroupNCCL::Options::create(); + options->opTimeout = std::chrono::milliseconds(3000); ProcessGroupNCCLTimedOutErrors pg( store_, 0, 1, options); @@ -217,8 +229,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) { return; } - c10d::ProcessGroupNCCL::Options options; - options.opTimeout = std::chrono::milliseconds(3000); + auto options = c10d::ProcessGroupNCCL::Options::create(); + options->opTimeout = std::chrono::milliseconds(3000); ProcessGroupNCCLSimulateErrors pg( store_, 0, 1, options); diff --git a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp index 16cc778325a18..6c8bec2d0f920 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp @@ -9,6 +9,7 @@ #include #include +#include #include using namespace c10d::test; @@ -30,7 +31,7 @@ class NCCLTestBase { } void initialize(int rank, int size) { - auto store = std::make_shared<::c10d::FileStore>(path_, size); + auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>( new ::c10d::ProcessGroupNCCL(store, rank, size)); @@ -79,7 +80,7 @@ class NCCLTest : public NCCLTestBase { } void wait( - std::shared_ptr& work, + c10::intrusive_ptr& work, std::chrono::milliseconds timeout = kNoTimeout) { at::cuda::CUDAMultiStreamGuard guard(streams_); work->wait(timeout); @@ -165,14 +166,21 @@ class AllreduceNCCLTest : public NCCLTest { AllreduceNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - std::shared_ptr run() { + c10::intrusive_ptr run() { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); launchDeviceSleep(); valueInitialization(); - return pg_->allreduce(tensors_); + using namespace torch::autograd::profiler; + // Make sure enabling profile does not make any issue. Note, in single + // process multi-device mode we do not expect any events be populated for + // collective operations, since profiling for that mode is not supported. + enableProfilerLegacy({ProfilerState::CPU}); + auto results = pg_->allreduce(tensors_); + disableProfilerLegacy(); + return results; } }; @@ -181,7 +189,7 @@ class BroadcastNCCLTest : public NCCLTest { BroadcastNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - std::shared_ptr run(int rootRank, int rootTensor) { + c10::intrusive_ptr run(int rootRank, int rootTensor) { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -200,7 +208,7 @@ class ReduceNCCLTest : public NCCLTest { ReduceNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - std::shared_ptr run(int rootRank, int rootTensor) { + c10::intrusive_ptr run(int rootRank, int rootTensor) { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -219,7 +227,7 @@ class AllgatherNCCLTest : public NCCLTest { AllgatherNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - std::shared_ptr run() { + c10::intrusive_ptr run() { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); @@ -234,7 +242,7 @@ struct ReduceScatterNCCLTest : NCCLTest { ReduceScatterNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - std::shared_ptr run() { + c10::intrusive_ptr run() { // For the duration of this function, make THC use our streams at::cuda::CUDAMultiStreamGuard guard(streams_); diff --git a/torch/lib/c10d/test/TCPStoreTest.cpp b/torch/lib/c10d/test/TCPStoreTest.cpp index 66176d3e7355a..30a123dc163f9 100644 --- a/torch/lib/c10d/test/TCPStoreTest.cpp +++ b/torch/lib/c10d/test/TCPStoreTest.cpp @@ -9,12 +9,14 @@ #include #include +constexpr int64_t kShortStoreTimeoutMillis = 100; + // Different ports for different tests. void testHelper(const std::string& prefix = "") { const auto numThreads = 16; const auto numWorkers = numThreads + 1; - auto serverTCPStore = std::make_shared( + auto serverTCPStore = c10::make_intrusive( "127.0.0.1", 0, numWorkers, @@ -23,7 +25,7 @@ void testHelper(const std::string& prefix = "") { /* wait */ false); auto serverStore = - std::make_unique(prefix, serverTCPStore); + c10::make_intrusive(prefix, serverTCPStore); // server store auto serverThread = std::thread([&serverStore, &serverTCPStore] { // Wait for all workers to join. @@ -36,6 +38,24 @@ void testHelper(const std::string& prefix = "") { c10d::test::check(*serverStore, "key0", "value0"); c10d::test::check(*serverStore, "key1", "value1"); c10d::test::check(*serverStore, "key2", "value2"); + serverStore->add("counter", 1); + auto numKeys = serverStore->getNumKeys(); + // We expect 5 keys since 3 are added above, 'counter' is added by the + // helper thread, and the init key to coordinate workers. + EXPECT_EQ(numKeys, 5); + + auto delSuccess = serverStore->deleteKey("key0"); + // Ensure that the key was successfully deleted + EXPECT_TRUE(delSuccess); + auto delFailure = serverStore->deleteKey("badKeyName"); + // The key was not in the store so the delete operation should have failed + // and returned false. + EXPECT_FALSE(delFailure); + numKeys = serverStore->getNumKeys(); + EXPECT_EQ(numKeys, 4); + auto timeout = std::chrono::milliseconds(kShortStoreTimeoutMillis); + serverStore->setTimeout(timeout); + EXPECT_THROW(serverStore->get("key0"), std::runtime_error); }); // Hammer on TCPStore @@ -44,20 +64,20 @@ void testHelper(const std::string& prefix = "") { c10d::test::Semaphore sem1, sem2; // Each thread will have a client store to send/recv data - std::vector> clientTCPStores; - std::vector> clientStores; + std::vector> clientTCPStores; + std::vector> clientStores; for (auto i = 0; i < numThreads; i++) { - clientTCPStores.push_back(std::make_unique( + clientTCPStores.push_back(c10::make_intrusive( "127.0.0.1", serverTCPStore->getPort(), numWorkers, false)); - clientStores.push_back(std::unique_ptr( - new c10d::PrefixStore(prefix, clientTCPStores[i]))); + clientStores.push_back( + c10::make_intrusive(prefix, clientTCPStores[i])); } - std::string expectedCounterRes = std::to_string(numThreads * numIterations); + std::string expectedCounterRes = std::to_string(numThreads * numIterations + 1); for (auto i = 0; i < numThreads; i++) { threads.push_back( - std::thread([&sem1, &sem2, &clientStores, i, &expectedCounterRes] { + std::thread([&sem1, &sem2, &clientStores, i, &expectedCounterRes, &numIterations, &numThreads] { for (auto j = 0; j < numIterations; j++) { clientStores[i]->add("counter", 1); } diff --git a/torch/lib/c10d/test/TestUtils.hpp b/torch/lib/c10d/test/TestUtils.hpp index c62695485573e..5f5dfca315cb0 100644 --- a/torch/lib/c10d/test/TestUtils.hpp +++ b/torch/lib/c10d/test/TestUtils.hpp @@ -1,9 +1,12 @@ #pragma once +#ifndef _WIN32 #include -#include #include #include +#endif + +#include #include #include @@ -37,6 +40,28 @@ class Semaphore { std::condition_variable cv_; }; +#ifdef _WIN32 +std::string autoGenerateTmpFilePath() { + char tmp[L_tmpnam_s]; + errno_t err; + err = tmpnam_s(tmp, L_tmpnam_s); + if (err != 0) + { + throw std::system_error(errno, std::system_category()); + } + return std::string(tmp); +} + +std::string tmppath() { + const char* tmpfile = getenv("TMPFILE"); + if (tmpfile) { + return std::string(tmpfile); + } + else { + return autoGenerateTmpFilePath(); + } +} +#else std::string tmppath() { // TMPFILE is for manual test execution during which the user will specify // the full temp file path using the environmental variable TMPFILE @@ -63,6 +88,7 @@ std::string tmppath() { close(fd); return std::string(tmp.data(), tmp.size()); } +#endif bool isTSANEnabled() { auto s = std::getenv("PYTORCH_TEST_WITH_TSAN"); @@ -80,6 +106,7 @@ struct TemporaryFile { } }; +#ifndef _WIN32 struct Fork { pid_t pid; @@ -101,6 +128,7 @@ struct Fork { return pid == 0; } }; +#endif } // namespace test } // namespace c10d diff --git a/torch/library.h b/torch/library.h index 19d04aeb73c49..fee98abb2b81f 100644 --- a/torch/library.h +++ b/torch/library.h @@ -81,7 +81,7 @@ class class_; /// /// This class erases the type of the passed in function, but durably records /// the type via an inferred schema for the function. -class CAFFE2_API CppFunction final { +class TORCH_API CppFunction final { // TODO: This is morally the same thing as KernelRegistrationConfig, but it's // opaque to the user. @@ -99,7 +99,7 @@ class CAFFE2_API CppFunction final { /// This overload accepts compile time function pointers, e.g., `CppFunction(TORCH_FN(add_impl))` template explicit CppFunction(FuncPtr f, std::enable_if_t::value, std::nullptr_t> = nullptr) - : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f.func_ptr())) + : func_(c10::KernelFunction::makeFromUnboxedFunction(f)) , cpp_signature_(c10::impl::CppSignature::make()) // TODO: Don't go through WrapRuntimeKernelFunctor , schema_(c10::detail::inferFunctionSchemaFromFunctor>>()) @@ -116,19 +116,6 @@ class CAFFE2_API CppFunction final { , debug_() {} - /// This static factory lets you create CppFunctions that (1) don't have boxing - /// wrappers (because we don't support it yet) and (2) don't have schema - /// inference (because some ops don't support it). - template - static CppFunction makeUnboxedOnly(Func* f) { - // TODO: Eliminate the necessity for this function entirely. - return CppFunction( - c10::KernelFunction::makeFromUnboxedOnlyRuntimeFunction(f), - /* cpp_signature */ c10::impl::CppSignature::make(), - /* schema */ nullptr - ); - } - /// This creates a fallthrough function. Fallthrough functions /// immediately redispatch to the next available dispatch key, /// but are implemented more efficiently than a hand written @@ -170,6 +157,22 @@ class CAFFE2_API CppFunction final { ); } + /// Create a function from an unboxed kernel function. + /// This is typically used to register common operators. + template::value, std::nullptr_t> = nullptr> + static CppFunction makeFromUnboxedFunction(FuncPtr* f) { + return CppFunction(f); + } + + /// Create a function from a compile time unboxed kernel function pointer. + /// This is typically used to register common operators. + /// Compile time function pointers can be used to allow the compiler + /// to optimize (e.g. inline) calls to it. + template::value, std::nullptr_t> = nullptr> + static CppFunction makeFromUnboxedFunction(FuncPtr f) { + return CppFunction(f); + } + CppFunction&& debug(std::string d) && { debug_ = std::move(d); return std::move(*this); @@ -367,7 +370,7 @@ namespace detail { /// } /// ``` /// -class CAFFE2_API Library final { +class TORCH_API Library final { public: /// \private /// @@ -496,20 +499,10 @@ class CAFFE2_API Library final { return impl(name, dispatch(std::forward(key), std::forward(raw_f))); } - /// \private - /// - /// Convenience overload for unboxed only kernels; kernels whose type - /// signatures are not supported by our template based metaprogramming - /// system. These are currently quite common but will be eventually - /// eliminated. - /// - /// This is equivalent to calling CppFunction::makeUnboxedOnly() on - /// the function, but this name for the function makes it easy to grep for. template Library& impl_UNBOXED(Name name, Func* raw_f) & { - // TODO: Remove this overload once the makeUnboxedOnly incidence rate - // goes way down - return impl(name, CppFunction::makeUnboxedOnly(raw_f)); + static_assert(c10::guts::false_t(), ".impl_UNBOXED(...) was removed. Please use .impl(...) instead."); + return *this; } // These overloads cover cases when a SelectiveStr (see Note [Selective build]) @@ -531,7 +524,10 @@ class CAFFE2_API Library final { template Library& impl(detail::SelectiveStr, Dispatch&& key, Func&& raw_f) & { return *this; } template - Library& impl_UNBOXED(detail::SelectiveStr name, Func* raw_f) & { return *this; } + Library& impl_UNBOXED(detail::SelectiveStr name, Func* raw_f) & { + static_assert(c10::guts::false_t(), ".impl_UNBOXED(...) was removed. Please use .impl(...) instead."); + return *this; + } template Library& impl(detail::SelectiveStr name, Func&& raw_f) & { @@ -543,7 +539,8 @@ class CAFFE2_API Library final { } template Library& impl_UNBOXED(detail::SelectiveStr name, Func* raw_f) & { - return impl(name.operator const char*(), CppFunction::makeUnboxedOnly(raw_f)); + static_assert(c10::guts::false_t(), ".impl_UNBOXED(...) was removed. Please use .impl(...) instead."); + return *this; } /// Register a fallback implementation for all operators which will be used @@ -643,7 +640,7 @@ class TorchLibraryInit final { /// for any given namespace. #define TORCH_LIBRARY(ns, m) \ static void TORCH_LIBRARY_init_ ## ns (torch::Library&); \ - static torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_ ## ns ( \ + static const torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_ ## ns ( \ torch::Library::DEF, \ &TORCH_LIBRARY_init_ ## ns, \ #ns, c10::nullopt, __FILE__, __LINE__ \ @@ -653,16 +650,28 @@ class TorchLibraryInit final { /// \private /// /// This macro is a version of TORCH_LIBRARY() that doesn't enforce that there -/// is only one library (it is a "fragment"). This should ONLY be used -/// inside the PerOpRegistration.cpp file (as its name suggests). -#define TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(ns, m) \ - static void TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k (torch::Library&); \ - static torch::detail::TorchLibraryInit TORCH_LIBRARY_FRAGMENT_static_init_ ## ns ## _ ## k ( \ +/// is only one library (it is a "fragment"). This is used inside the +/// PerOpRegistration.cpp file, as well as in places where all op registrations +/// within the same namespace cannot be easily put into one macro block +/// (this is mostly the case for custom ops in fbcode that were ported from +/// the old API) +#define TORCH_LIBRARY_FRAGMENT(ns, m) _TORCH_LIBRARY_FRAGMENT(ns, m, C10_UID) + +/// \private +/// +/// The above macro requires an extra unique identifier (uid) to prevent variable name collisions +/// This can happen if TORCH_LIBRARY_FRAGMENT is called multiple times with the same namespace +/// in the same translation unit. +/// Note that the TORCH_LIBRARY variant doesn't run into this problem, because it enforces +/// that it can only be called once for a given namespace. +#define _TORCH_LIBRARY_FRAGMENT(ns, m, uid) \ + static void C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _, uid) (torch::Library&); \ + static const torch::detail::TorchLibraryInit C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_static_init_ ## ns ## _, uid) ( \ torch::Library::FRAGMENT, \ - &TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k, \ + &C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _, uid), \ #ns, c10::nullopt, __FILE__, __LINE__ \ ); \ - void TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k (torch::Library& m) + void C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _, uid) (torch::Library& m) /// Macro for defining a function that will be run at static /// initialization time to define operator overrides for dispatch key @@ -704,18 +713,25 @@ class TorchLibraryInit final { /// // NB: if the dispatch key is not whitelisted, we simply omit the Library // call entirely -#define TORCH_LIBRARY_IMPL(ns, k, m) \ - static void TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k (torch::Library&); \ - static torch::detail::TorchLibraryInit TORCH_LIBRARY_IMPL_static_init_ ## ns ## _ ## k ( \ +#define TORCH_LIBRARY_IMPL(ns, k, m) _TORCH_LIBRARY_IMPL(ns, k, m, C10_UID) + +/// \private +/// +/// The above macro requires an extra unique identifier (uid) to prevent variable name collisions. +/// This can happen if TORCH_LIBRARY_IMPL is called multiple times with the same namespace +/// and dispatch key in the same translation unit. +#define _TORCH_LIBRARY_IMPL(ns, k, m, uid) \ + static void C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k ## _, uid) (torch::Library&); \ + static const torch::detail::TorchLibraryInit C10_CONCATENATE(TORCH_LIBRARY_IMPL_static_init_ ## ns ## _ ## k ## _, uid) ( \ torch::Library::IMPL, \ c10::guts::if_constexpr( \ - []() { return & TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k; }, \ + []() { return & C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k ## _, uid); }, \ []() { return [](torch::Library&) -> void {}; } \ ), \ #ns, c10::make_optional(c10::DispatchKey::k), \ __FILE__, __LINE__ \ ); \ - void TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k (torch::Library& m) + void C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k ## _, uid) (torch::Library& m) // These are variants of the macros above which are to be used for testing (they diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 5e2b59c45c807..32b5844ede536 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -8,12 +8,369 @@ # Note: This not only adds doc strings for functions in the linalg namespace, but # also connects the torch.linalg Python namespace to the torch._C._linalg builtins. +cholesky = _add_docstr(_linalg.linalg_cholesky, r""" +linalg.cholesky(input, *, out=None) -> Tensor + +Computes the Cholesky decomposition of a Hermitian (or symmetric for real-valued matrices) +positive-definite matrix or the Cholesky decompositions for a batch of such matrices. +Each decomposition has the form: + +.. math:: + + \text{input} = LL^H + +where :math:`L` is a lower-triangular matrix and :math:`L^H` is the conjugate transpose of :math:`L`, +which is just a transpose for the case of real-valued input matrices. +In code it translates to ``input = L @ L.t()` if :attr:`input` is real-valued and +``input = L @ L.conj().t()`` if :attr:`input` is complex-valued. +The batch of :math:`L` matrices is returned. + +Supports real-valued and complex-valued inputs. + +.. note:: If :attr:`input` is not a Hermitian positive-definite matrix, or if it's a batch of matrices + and one or more of them is not a Hermitian positive-definite matrix, then a RuntimeError will be thrown. + If :attr:`input` is a batch of matrices, then the error message will include the batch index + of the first matrix that is not Hermitian positive-definite. + +.. warning:: This function always checks whether :attr:`input` is a Hermitian positive-definite matrix + using `info` argument to LAPACK/MAGMA call. For CUDA this causes cross-device memory synchronization. + +Args: + input (Tensor): the input tensor of size :math:`(*, n, n)` consisting of Hermitian positive-definite + :math:`n \times n` matrices, where `*` is zero or more batch dimensions. + +Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + +Examples:: + + >>> a = torch.randn(2, 2, dtype=torch.complex128) + >>> a = torch.mm(a, a.t().conj()) # creates a Hermitian positive-definite matrix + >>> l = torch.linalg.cholesky(a) + >>> a + tensor([[2.5266+0.0000j, 1.9586-2.0626j], + [1.9586+2.0626j, 9.4160+0.0000j]], dtype=torch.complex128) + >>> l + tensor([[1.5895+0.0000j, 0.0000+0.0000j], + [1.2322+1.2976j, 2.4928+0.0000j]], dtype=torch.complex128) + >>> torch.mm(l, l.t().conj()) + tensor([[2.5266+0.0000j, 1.9586-2.0626j], + [1.9586+2.0626j, 9.4160+0.0000j]], dtype=torch.complex128) + + >>> a = torch.randn(3, 2, 2, dtype=torch.float64) + >>> a = torch.matmul(a, a.transpose(-2, -1)) # creates a symmetric positive-definite matrix + >>> l = torch.linalg.cholesky(a) + >>> a + tensor([[[ 1.1629, 2.0237], + [ 2.0237, 6.6593]], + + [[ 0.4187, 0.1830], + [ 0.1830, 0.1018]], + + [[ 1.9348, -2.5744], + [-2.5744, 4.6386]]], dtype=torch.float64) + >>> l + tensor([[[ 1.0784, 0.0000], + [ 1.8766, 1.7713]], + + [[ 0.6471, 0.0000], + [ 0.2829, 0.1477]], + + [[ 1.3910, 0.0000], + [-1.8509, 1.1014]]], dtype=torch.float64) + >>> torch.allclose(torch.matmul(l, l.transpose(-2, -1)), a) + True +""") + +inv = _add_docstr(_linalg.linalg_inv, r""" +linalg.inv(input, *, out=None) -> Tensor + +This function computes the "multiplicative inverse" matrix of a square matrix, or batch of such matrices, :attr:`input`. +The result satisfies the relation + +``matmul(inv(input), input) = matmul(input, inv(input)) = eye(input.shape[0]).expand_as(input)``. + +Supports input of float, double, cfloat and cdouble data types. + +.. note:: If :attr:`input` is a non-invertible matrix or non-square matrix, or batch with at least one such matrix, + then a RuntimeError will be thrown. + +.. note:: When given inputs on a CUDA device, this function synchronizes that device with the CPU. + +Args: + input (Tensor): the square :math:`n \times n` matrix or the batch + of such matrices of size :math:`(*, n, n)` where `*` is one or more batch dimensions. + +Keyword args: + out (Tensor, optional): The output tensor. Ignored if None. Default: None + +Examples:: + + >>> x = torch.rand(4, 4) + >>> y = torch.linalg.inv(x) + >>> z = torch.mm(x, y) + >>> z + tensor([[ 1.0000, -0.0000, -0.0000, 0.0000], + [ 0.0000, 1.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 1.0000, 0.0000], + [ 0.0000, -0.0000, -0.0000, 1.0000]]) + >>> torch.max(torch.abs(z - torch.eye(4))) # Max non-zero + tensor(1.1921e-07) + + >>> # Batched inverse example + >>> x = torch.randn(2, 3, 4, 4) + >>> y = torch.linalg.inv(x) + >>> z = torch.matmul(x, y) + >>> torch.max(torch.abs(z - torch.eye(4).expand_as(x))) # Max non-zero + tensor(1.9073e-06) + + >>> x = torch.rand(4, 4, dtype=torch.cdouble) + >>> y = torch.linalg.inv(x) + >>> z = torch.mm(x, y) + >>> z + tensor([[ 1.0000e+00+0.0000e+00j, -1.3878e-16+3.4694e-16j, + 5.5511e-17-1.1102e-16j, 0.0000e+00-1.6653e-16j], + [ 5.5511e-16-1.6653e-16j, 1.0000e+00+6.9389e-17j, + 2.2204e-16-1.1102e-16j, -2.2204e-16+1.1102e-16j], + [ 3.8858e-16-1.2490e-16j, 2.7756e-17+3.4694e-17j, + 1.0000e+00+0.0000e+00j, -4.4409e-16+5.5511e-17j], + [ 4.4409e-16+5.5511e-16j, -3.8858e-16+1.8041e-16j, + 2.2204e-16+0.0000e+00j, 1.0000e+00-3.4694e-16j]], + dtype=torch.complex128) + >>> torch.max(torch.abs(z - torch.eye(4, dtype=torch.cdouble))) # Max non-zero + tensor(7.5107e-16, dtype=torch.float64) +""") + det = _add_docstr(_linalg.linalg_det, r""" linalg.det(input) -> Tensor Alias of :func:`torch.det`. """) +slogdet = _add_docstr(_linalg.linalg_slogdet, r""" +linalg.slogdet(input) -> (Tensor, Tensor) + +Calculates the sign and natural logarithm of the absolute value of a square matrix's determinant, +or of the absolute values of the determinants of a batch of square matrices :attr`input`. +The determinant can be computed with ``sign * exp(logabsdet)``. + +Supports input of float, double, cfloat and cdouble datatypes. + +.. note:: When given inputs on a CUDA device, this function synchronizes that device with the CPU. + +.. note:: For matrices that have zero determinant, this returns ``(0, -inf)``. + If :attr:`input` is batched then the entries in the result tensors corresponding to matrices with + the zero determinant have sign 0 and the natural logarithm of the absolute value of the determinant -inf. + +Arguments: + input (Tensor): the input matrix of size :math:`(n, n)` or the batch of matrices of size :math:`(*, n, n)` + where `*` is one or more batch dimensions. + +Returns: + A namedtuple (sign, logabsdet) containing the sign of the determinant and the natural logarithm + of the absolute value of determinant, respectively. + +Example:: + + >>> A = torch.randn(3, 3) + >>> A + tensor([[ 0.0032, -0.2239, -1.1219], + [-0.6690, 0.1161, 0.4053], + [-1.6218, -0.9273, -0.0082]]) + >>> torch.linalg.det(A) + tensor(-0.7576) + >>> torch.linalg.logdet(A) + tensor(nan) + >>> torch.linalg.slogdet(A) + torch.return_types.linalg_slogdet(sign=tensor(-1.), logabsdet=tensor(-0.2776)) +""") + +eigh = _add_docstr(_linalg.linalg_eigh, r""" +linalg.eigh(input, UPLO='L') -> tuple(Tensor, Tensor) + +This function computes the eigenvalues and eigenvectors +of a complex Hermitian (or real symmetric) matrix, or batch of such matrices, :attr:`input`. +For a single matrix :attr:`input`, the tensor of eigenvalues :math:`w` and the tensor of eigenvectors :math:`V` +decompose the :attr:`input` such that :math:`\text{input} = V \text{diag}(w) V^H`, +where :math:`^H` is the conjugate transpose operation. + +Since the matrix or matrices in :attr:`input` are assumed to be Hermitian, the imaginary part of their diagonals +is always treated as zero. When :attr:`UPLO` is "L", its default value, only the lower triangular part of +each matrix is used in the computation. When :attr:`UPLO` is "U" only the upper triangular part of each matrix is used. + +Supports input of ``float``, ``double``, ``cfloat`` and ``cdouble`` data types. + +See :func:`torch.linalg.eigvalsh` for a related function that computes only eigenvalues, +however that function is not differentiable. + +.. note:: The eigenvalues of real symmetric or complex Hermitian matrices are always real. + +.. note:: The eigenvectors of matrices are not unique, so any eigenvector multiplied by a constant remains + a valid eigenvector. This function may compute different eigenvector representations on + different device types. Usually the difference is only in the sign of the eigenvector. + +.. note:: The eigenvalues/eigenvectors are computed using LAPACK/MAGMA routines ``_syevd`` and ``_heevd``. + This function always checks whether the call to LAPACK/MAGMA is successful + using ``info`` argument of ``_syevd``, ``_heevd`` and throws a RuntimeError if it isn't. + On CUDA this causes a cross-device memory synchronization. + +Args: + input (Tensor): the Hermitian :math:`n \times n` matrix or the batch + of such matrices of size :math:`(*, n, n)` where `*` is one or more batch dimensions. + UPLO ('L', 'U', optional): controls whether to use the upper-triangular or the lower-triangular part + of :attr:`input` in the computations. Default: ``'L'`` + +Returns: + (Tensor, Tensor): A namedtuple (eigenvalues, eigenvectors) containing + + - **eigenvalues** (*Tensor*): Shape :math:`(*, m)`. + The eigenvalues in ascending order. + - **eigenvectors** (*Tensor*): Shape :math:`(*, m, m)`. + The orthonormal eigenvectors of the ``input``. + +Examples:: + + >>> a = torch.randn(2, 2, dtype=torch.complex128) + >>> a = a + a.t().conj() # creates a Hermitian matrix + >>> a + tensor([[2.9228+0.0000j, 0.2029-0.0862j], + [0.2029+0.0862j, 0.3464+0.0000j]], dtype=torch.complex128) + >>> w, v = torch.linalg.eigh(a) + >>> w + tensor([0.3277, 2.9415], dtype=torch.float64) + >>> v + tensor([[-0.0846+-0.0000j, -0.9964+0.0000j], + [ 0.9170+0.3898j, -0.0779-0.0331j]], dtype=torch.complex128) + >>> torch.allclose(torch.matmul(v, torch.matmul(w.to(v.dtype).diag_embed(), v.t().conj())), a) + True + + >>> a = torch.randn(3, 2, 2, dtype=torch.float64) + >>> a = a + a.transpose(-2, -1) # creates a symmetric matrix + >>> w, v = torch.linalg.eigh(a) + >>> torch.allclose(torch.matmul(v, torch.matmul(w.diag_embed(), v.transpose(-2, -1))), a) + True +""") + +eigvalsh = _add_docstr(_linalg.linalg_eigvalsh, r""" +linalg.eigvalsh(input, UPLO='L') -> Tensor + +This function computes the eigenvalues of a complex Hermitian (or real symmetric) matrix, +or batch of such matrices, :attr:`input`. The eigenvalues are returned in ascending order. + +Since the matrix or matrices in :attr:`input` are assumed to be Hermitian, the imaginary part of their diagonals +is always treated as zero. When :attr:`UPLO` is "L", its default value, only the lower triangular part of +each matrix is used in the computation. When :attr:`UPLO` is "U" only the upper triangular part of each matrix is used. + +Supports input of ``float``, ``double``, ``cfloat`` and ``cdouble`` data types. + +See :func:`torch.linalg.eigh` for a related function that computes both eigenvalues and eigenvectors. + +.. note:: The eigenvalues of real symmetric or complex Hermitian matrices are always real. + +.. note:: The eigenvalues/eigenvectors are computed using LAPACK/MAGMA routines ``_syevd`` and ``_heevd``. + This function always checks whether the call to LAPACK/MAGMA is successful + using ``info`` argument of ``_syevd``, ``_heevd`` and throws a RuntimeError if it isn't. + On CUDA this causes a cross-device memory synchronization. + +.. note:: This function doesn't support backpropagation, please use :func:`torch.linalg.eigh` instead, + that also computes the eigenvectors. + +Args: + input (Tensor): the Hermitian :math:`n \times n` matrix or the batch + of such matrices of size :math:`(*, n, n)` where `*` is one or more batch dimensions. + UPLO ('L', 'U', optional): controls whether to use the upper-triangular or the lower-triangular part + of :attr:`input` in the computations. Default: ``'L'`` + +Examples:: + + >>> a = torch.randn(2, 2, dtype=torch.complex128) + >>> a = a + a.t().conj() # creates a Hermitian matrix + >>> a + tensor([[2.9228+0.0000j, 0.2029-0.0862j], + [0.2029+0.0862j, 0.3464+0.0000j]], dtype=torch.complex128) + >>> w = torch.linalg.eigvalsh(a) + >>> w + tensor([0.3277, 2.9415], dtype=torch.float64) + + >>> a = torch.randn(3, 2, 2, dtype=torch.float64) + >>> a = a + a.transpose(-2, -1) # creates a symmetric matrix + >>> a + tensor([[[ 2.8050, -0.3850], + [-0.3850, 3.2376]], + + [[-1.0307, -2.7457], + [-2.7457, -1.7517]], + + [[ 1.7166, 2.2207], + [ 2.2207, -2.0898]]], dtype=torch.float64) + >>> w = torch.linalg.eigvalsh(a) + >>> w + tensor([[ 2.5797, 3.4629], + [-4.1605, 1.3780], + [-3.1113, 2.7381]], dtype=torch.float64) +""") + +matrix_rank = _add_docstr(_linalg.linalg_matrix_rank, r""" +matrix_rank(input, tol=None, hermitian=False) -> Tensor + +Computes the numerical rank of a matrix :attr:`input`, or of each matrix in a batched :attr:`input`. +The matrix rank is computed as the number of singular values (or the absolute eigenvalues when :attr:`hermitian` is ``True``) +above the specified :attr:`tol` threshold. + +If :attr:`tol` is not specified, :attr:`tol` is set to +``S.max(dim=-1) * max(input.shape[-2:]) * eps`` where ``S`` is the singular values +(or the absolute eigenvalues when :attr:`hermitian` is ``True``), +and ``eps`` is the epsilon value for the datatype of :attr:`input`. +The epsilon value can be obtained using ``eps`` attribute of :class:`torch.finfo`. + +The method to compute the matrix rank is done using singular value decomposition (see :func:`torch.linalg.svd`) by default. +If :attr:`hermitian` is ``True``, then :attr:`input` is assumed to be Hermitian (symmetric if real-valued), +and the computation of the rank is done by obtaining the eigenvalues (see :func:`torch.linalg.eigvalsh`). + +Supports input of ``float``, ``double``, ``cfloat`` and ``cdouble`` datatypes. + +.. note:: When given inputs on a CUDA device, this function synchronizes that device with the CPU. + +Args: + input (Tensor): the input matrix of size :math:`(m, n)` or the batch of matrices of size :math:`(*, m, n)` + where `*` is one or more batch dimensions. + tol (float, optional): the tolerance value. Default: ``None`` + hermitian(bool, optional): indicates whether :attr:`input` is Hermitian. Default: ``False`` + +Examples:: + + >>> a = torch.eye(10) + >>> torch.linalg.matrix_rank(a) + tensor(10) + >>> b = torch.eye(10) + >>> b[0, 0] = 0 + >>> torch.linalg.matrix_rank(b) + tensor(9) + + >>> a = torch.randn(4, 3, 2) + >>> torch.linalg.matrix_rank(a) + tensor([2, 2, 2, 2]) + + >>> a = torch.randn(2, 4, 2, 3) + >>> torch.linalg.matrix_rank(a) + tensor([[2, 2, 2, 2], + [2, 2, 2, 2]]) + + >>> a = torch.randn(2, 4, 3, 3, dtype=torch.complex64) + >>> torch.linalg.matrix_rank(a) + tensor([[3, 3, 3, 3], + [3, 3, 3, 3]]) + >>> torch.linalg.matrix_rank(a, hermitian=True) + tensor([[3, 3, 3, 3], + [3, 3, 3, 3]]) + >>> torch.linalg.matrix_rank(a, tol=1.0) + tensor([[3, 2, 2, 2], + [1, 2, 1, 2]]) + >>> torch.linalg.matrix_rank(a, tol=1.0, hermitian=True) + tensor([[2, 2, 2, 1], + [1, 2, 2, 2]]) +""") + norm = _add_docstr(_linalg.linalg_norm, r""" linalg.norm(input, ord=None, dim=None, keepdim=False, *, out=None, dtype=None) -> Tensor @@ -26,7 +383,10 @@ Args: input (Tensor): The input tensor. If dim is None, x must be 1-D or 2-D, unless :attr:`ord` is None. If both :attr:`dim` and :attr:`ord` are None, the 2-norm of the input flattened to 1-D - will be returned. + will be returned. Its data type must be either a floating point or complex type. For complex + inputs, the norm is calculated on of the absolute values of each element. If the input is + complex and neither :attr:`dtype` nor :attr:`out` is specified, the result's data type will + be the corresponding floating point type (e.g. float if :attr:`input` is complexfloat). ord (int, float, inf, -inf, 'fro', 'nuc', optional): The order of norm. inf refers to :attr:`float('inf')`, numpy's :attr:`inf` object, or any equivalent object. @@ -68,8 +428,7 @@ :attr:`dtype` before performing the operation, and the returned tensor's type will be :attr:`dtype`. If this argument is used in conjunction with the :attr:`out` argument, the output tensor's type must match this argument or a - RuntimeError will be raised. This argument is not currently supported for - :attr:`ord='nuc'` or :attr:`ord='fro'`. Default: ``None`` + RuntimeError will be raised. Default: ``None`` Examples:: @@ -140,3 +499,474 @@ >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) (tensor(3.7417), tensor(11.2250)) """) + +svd = _add_docstr(_linalg.linalg_svd, r""" +linalg.svd(input, full_matrices=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) + +Computes the singular value decomposition of either a matrix or batch of +matrices :attr:`input`." The singular value decomposition is represented as a +namedtuple ``(U, S, Vh)``, such that :math:`input = U \mathbin{@} diag(S) \times +Vh`. If :attr:`input` is a batch of tensors, then ``U``, ``S``, and ``Vh`` are +also batched with the same batch dimensions as :attr:`input`. + +If :attr:`full_matrices` is ``False`` (default), the method returns the reduced singular +value decomposition i.e., if the last two dimensions of :attr:`input` are +``m`` and ``n``, then the returned `U` and `V` matrices will contain only +:math:`min(n, m)` orthonormal columns. + +If :attr:`compute_uv` is ``False``, the returned `U` and `Vh` will be empy +tensors with no elements and the same device as :attr:`input`. The +:attr:`full_matrices` argument has no effect when :attr:`compute_uv` is False. + +The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. ``S`` will +always be real-valued, even if :attr:`input` is complex. + +.. note:: Unlike NumPy's ``linalg.svd``, this always returns a namedtuple of + three tensors, even when :attr:`compute_uv=False`. + +.. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices, + then the singular values of each matrix in the batch is returned in descending order. + +.. note:: The implementation of SVD on CPU uses the LAPACK routine `?gesdd` (a divide-and-conquer + algorithm) instead of `?gesvd` for speed. Analogously, the SVD on GPU uses the MAGMA routine + `gesdd` as well. + +.. note:: The returned matrix `U` will be transposed, i.e. with strides + :code:`U.contiguous().transpose(-2, -1).stride()`. + +.. note:: Gradients computed using `U` and `Vh` may be unstable if + :attr:`input` is not full rank or has non-unique singular values. + +.. note:: When :attr:`full_matrices` = ``True``, the gradients on :code:`U[..., :, min(m, n):]` + and :code:`V[..., :, min(m, n):]` will be ignored in backward as those vectors + can be arbitrary bases of the subspaces. + +.. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is True. + + +Args: + input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more + batch dimensions consisting of :math:`m \times n` matrices. + full_matrices (bool, optional): controls whether to compute the full or reduced decomposition, and + consequently the shape of returned ``U`` and ``V``. Defaults to True. + compute_uv (bool, optional): whether to compute `U` and `V` or not. Defaults to True. + out (tuple, optional): a tuple of three tensors to use for the outputs. If compute_uv=False, + the 1st and 3rd arguments must be tensors, but they are ignored. E.g. you can + pass `(torch.Tensor(), out_S, torch.Tensor())` + +Example:: + + >>> import torch + >>> a = torch.randn(5, 3) + >>> a + tensor([[-0.3357, -0.2987, -1.1096], + [ 1.4894, 1.0016, -0.4572], + [-1.9401, 0.7437, 2.0968], + [ 0.1515, 1.3812, 1.5491], + [-1.8489, -0.5907, -2.5673]]) + >>> + >>> # reconstruction in the full_matrices=False case + >>> u, s, vh = torch.linalg.svd(a, full_matrices=False) + >>> u.shape, s.shape, vh.shape + (torch.Size([5, 3]), torch.Size([3]), torch.Size([3, 3])) + >>> torch.dist(a, u @ torch.diag(s) @ vh) + tensor(1.0486e-06) + >>> + >>> # reconstruction in the full_matrices=True case + >>> u, s, vh = torch.linalg.svd(a) + >>> u.shape, s.shape, vh.shape + (torch.Size([5, 5]), torch.Size([3]), torch.Size([3, 3])) + >>> torch.dist(a, u[:, :3] @ torch.diag(s) @ vh) + >>> torch.dist(a, u[:, :3] @ torch.diag(s) @ vh) + tensor(1.0486e-06) + >>> + >>> # extra dimensions + >>> a_big = torch.randn(7, 5, 3) + >>> u, s, vh = torch.linalg.svd(a_big, full_matrices=False) + >>> torch.dist(a_big, u @ torch.diag_embed(s) @ vh) + tensor(3.0957e-06) +""") + +cond = _add_docstr(_linalg.linalg_cond, r""" +linalg.cond(input, p=None, *, out=None) -> Tensor + +Computes the condition number of a matrix :attr:`input`, +or of each matrix in a batched :attr:`input`, using the matrix norm defined by :attr:`p`. +For norms ``p = {'fro', 'nuc', inf, -inf, 1, -1}`` this is defined as the matrix norm of :attr:`input` +times the matrix norm of the inverse of :attr:`input`. And for norms ``p = {None, 2, -2}`` this is defined as +the ratio between the largest and smallest singular values. + +This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`. +If the input is complex and neither :attr:`dtype` nor :attr:`out` is specified, the result's data type will +be the corresponding floating point type (e.g. float if :attr:`input` is complexfloat). + +.. note:: For ``p = {None, 2, -2}`` the condition number is computed as the ratio between the largest + and smallest singular values computed using :func:`torch.linalg.svd`. + For these norms :attr:`input` may be a non-square matrix or batch of non-square matrices. + For other norms, however, :attr:`input` must be a square matrix or a batch of square matrices, + and if this requirement is not satisfied a RuntimeError will be thrown. + +.. note:: For ``p = {'fro', 'nuc', inf, -inf, 1, -1}`` if :attr:`input` is a non-invertible matrix then + a tensor containing infinity will be returned. If :attr:`input` is a batch of matrices and one + or more of them is not invertible then a RuntimeError will be thrown. + +.. note:: When given inputs on a CUDA device, this function synchronizes that device with the CPU. + +Args: + input (Tensor): the input matrix of size :math:`(m, n)` or the batch of matrices of size :math:`(*, m, n)` + where `*` is one or more batch dimensions. + + p (int, float, inf, -inf, 'fro', 'nuc', optional): the type of the matrix norm to use in the computations. + The following norms are supported: + + ===== ============================ + p norm for matrices + ===== ============================ + None ratio of the largest singular value to the smallest singular value + 'fro' Frobenius norm + 'nuc' nuclear norm + inf max(sum(abs(x), dim=1)) + -inf min(sum(abs(x), dim=1)) + 1 max(sum(abs(x), dim=0)) + -1 min(sum(abs(x), dim=0)) + 2 ratio of the largest singular value to the smallest singular value + -2 ratio of the smallest singular value to the largest singular value + ===== ============================ + + Default: ``None`` + +Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + +Examples:: + + >>> from torch import linalg as LA + >>> a = torch.tensor([[1., 0, -1], [0, 1, 0], [1, 0, 1]]) + >>> LA.cond(a) + tensor(1.4142) + >>> LA.cond(a, 'fro') + tensor(3.1623) + >>> LA.cond(a, 'nuc') + tensor(9.2426) + >>> LA.cond(a, np.inf) + tensor(2.) + >>> LA.cond(a, -np.inf) + tensor(1.) + >>> LA.cond(a, 1) + tensor(2.) + >>> LA.cond(a, -1) + tensor(1.) + >>> LA.cond(a, 2) + tensor(1.4142) + >>> LA.cond(a, -2) + tensor(0.7071) + + >>> a = torch.randn(3, 4, 4) + >>> LA.cond(a) + tensor([ 4.4739, 76.5234, 10.8409]) + + >>> a = torch.randn(3, 4, 4, dtype=torch.complex64) + >>> LA.cond(a) + tensor([ 5.9175, 48.4590, 5.6443]) + >>> LA.cond(a, 1) + >>> tensor([ 11.6734+0.j, 105.1037+0.j, 10.1978+0.j]) +""") + +pinv = _add_docstr(_linalg.linalg_pinv, r""" +linalg.pinv(input, rcond=1e-15, hermitian=False) -> Tensor + +Computes the pseudo-inverse (also known as the Moore-Penrose inverse) of a matrix :attr:`input`, +or of each matrix in a batched :attr:`input`. +The pseudo-inverse is computed using singular value decomposition (see :func:`torch.svd`) by default. +If :attr:`hermitian` is ``True``, then :attr:`input` is assumed to be Hermitian (symmetric if real-valued), +and the computation of the pseudo-inverse is done by obtaining the eigenvalues and eigenvectors +(see :func:`torch.linalg.eigh`). +The singular values (or the absolute values of the eigenvalues when :attr:`hermitian` is ``True``) that are below +the specified :attr:`rcond` threshold are treated as zero and discarded in the computation. + +Supports input of ``float``, ``double``, ``cfloat`` and ``cdouble`` datatypes. + +.. note:: When given inputs on a CUDA device, this function synchronizes that device with the CPU. + +.. note:: If singular value decomposition or eigenvalue decomposition algorithms do not converge + then a RuntimeError will be thrown. + +Args: + input (Tensor): the input matrix of size :math:`(m, n)` or the batch of matrices of size :math:`(*, m, n)` + where `*` is one or more batch dimensions. + rcond (float, Tensor, optional): the tolerance value to determine the cutoff for small singular values. Default: 1e-15 + :attr:`rcond` must be broadcastable to the singular values of :attr:`input` + as returned by :func:`torch.svd`. + hermitian(bool, optional): indicates whether :attr:`input` is Hermitian. Default: ``False`` + +Examples:: + + >>> input = torch.randn(3, 5) + >>> input + tensor([[ 0.5495, 0.0979, -1.4092, -0.1128, 0.4132], + [-1.1143, -0.3662, 0.3042, 1.6374, -0.9294], + [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]]) + >>> torch.linalg.pinv(input) + tensor([[ 0.0600, -0.1933, -0.2090], + [-0.0903, -0.0817, -0.4752], + [-0.7124, -0.1631, -0.2272], + [ 0.1356, 0.3933, -0.5023], + [-0.0308, -0.1725, -0.5216]]) + + Batched linalg.pinv example + >>> a = torch.randn(2, 6, 3) + >>> b = torch.linalg.pinv(a) + >>> torch.matmul(b, a) + tensor([[[ 1.0000e+00, 1.6391e-07, -1.1548e-07], + [ 8.3121e-08, 1.0000e+00, -2.7567e-07], + [ 3.5390e-08, 1.4901e-08, 1.0000e+00]], + + [[ 1.0000e+00, -8.9407e-08, 2.9802e-08], + [-2.2352e-07, 1.0000e+00, 1.1921e-07], + [ 0.0000e+00, 8.9407e-08, 1.0000e+00]]]) + + Hermitian input example + >>> a = torch.randn(3, 3, dtype=torch.complex64) + >>> a = a + a.t().conj() # creates a Hermitian matrix + >>> b = torch.linalg.pinv(a, hermitian=True) + >>> torch.matmul(b, a) + tensor([[ 1.0000e+00+0.0000e+00j, -1.1921e-07-2.3842e-07j, + 5.9605e-08-2.3842e-07j], + [ 5.9605e-08+2.3842e-07j, 1.0000e+00+2.3842e-07j, + -4.7684e-07+1.1921e-07j], + [-1.1921e-07+0.0000e+00j, -2.3842e-07-2.9802e-07j, + 1.0000e+00-1.7897e-07j]]) + + Non-default rcond example + >>> rcond = 0.5 + >>> a = torch.randn(3, 3) + >>> torch.linalg.pinv(a) + tensor([[ 0.2971, -0.4280, -2.0111], + [-0.0090, 0.6426, -0.1116], + [-0.7832, -0.2465, 1.0994]]) + >>> torch.linalg.pinv(a, rcond) + tensor([[-0.2672, -0.2351, -0.0539], + [-0.0211, 0.6467, -0.0698], + [-0.4400, -0.3638, -0.0910]]) + + Matrix-wise rcond example + >>> a = torch.randn(5, 6, 2, 3, 3) + >>> rcond = torch.rand(2) # different rcond values for each matrix in a[:, :, 0] and a[:, :, 1] + >>> torch.linalg.pinv(a, rcond) + >>> rcond = torch.randn(5, 6, 2) # different rcond value for each matrix in 'a' + >>> torch.linalg.pinv(a, rcond) +""") + +solve = _add_docstr(_linalg.linalg_solve, r""" +linalg.solve(input, other, *, out=None) -> Tensor + +Computes the solution ``x`` to the matrix equation ``matmul(input, x) = other`` +with a square matrix, or batches of such matrices, :attr:`input` and one or more right-hand side vectors :attr:`other`. +If :attr:`input` is batched and :attr:`other` is not, then :attr:`other` is broadcast +to have the same batch dimensions as :attr:`input`. +The resulting tensor has the same shape as the (possibly broadcast) :attr:`other`. + +Supports input of ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes. + +.. note:: If :attr:`input` is a non-square or non-invertible matrix, or a batch containing non-square matrices + or one or more non-invertible matrices, then a RuntimeError will be thrown. +.. note:: When given inputs on a CUDA device, this function synchronizes that device with the CPU. + +Args: + input (Tensor): the square :math:`n \times n` matrix or the batch + of such matrices of size :math:`(*, n, n)` where ``*`` is one or more batch dimensions. + other (Tensor): right-hand side tensor of shape :math:`(*, n)` or :math:`(*, n, k)`, + where :math:`k` is the number of right-hand side vectors. + +Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + +Examples:: + + >>> A = torch.eye(3) + >>> b = torch.randn(3) + >>> x = torch.linalg.solve(A, b) + >>> torch.allclose(A @ x, b) + True + +Batched input:: + + >>> A = torch.randn(2, 3, 3) + >>> b = torch.randn(3, 1) + >>> x = torch.linalg.solve(A, b) + >>> torch.allclose(A @ x, b) + True + >>> b = torch.rand(3) # b is broadcast internally to (*A.shape[:-2], 3) + >>> x = torch.linalg.solve(A, b) + >>> x.shape + torch.Size([2, 3]) + >>> Ax = A @ x.unsqueeze(-1) + >>> torch.allclose(Ax, b.unsqueeze(-1).expand_as(Ax)) + True +""") + +tensorinv = _add_docstr(_linalg.linalg_tensorinv, r""" +linalg.tensorinv(input, ind=2, *, out=None) -> Tensor + +Computes a tensor ``input_inv`` such that ``tensordot(input_inv, input, ind) == I_n`` (inverse tensor equation), +where ``I_n`` is the n-dimensional identity tensor and ``n`` is equal to ``input.ndim``. +The resulting tensor ``input_inv`` has shape equal to ``input.shape[ind:] + input.shape[:ind]``. + +Supports input of ``float``, ``double``, ``cfloat`` and ``cdouble`` data types. + +.. note:: If :attr:`input` is not invertible or does not satisfy the requirement + ``prod(input.shape[ind:]) == prod(input.shape[:ind])``, + then a RuntimeError will be thrown. + +.. note:: When :attr:`input` is a 2-dimensional tensor and ``ind=1``, this function computes the + (multiplicative) inverse of :attr:`input`, equivalent to calling :func:`torch.inverse`. + +Args: + input (Tensor): A tensor to invert. Its shape must satisfy ``prod(input.shape[:ind]) == prod(input.shape[ind:])``. + ind (int): A positive integer that describes the inverse tensor equation. See :func:`torch.tensordot` for details. Default: 2. + +Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + +Examples:: + + >>> a = torch.eye(4 * 6).reshape((4, 6, 8, 3)) + >>> ainv = torch.linalg.tensorinv(a, ind=2) + >>> ainv.shape + torch.Size([8, 3, 4, 6]) + >>> b = torch.randn(4, 6) + >>> torch.allclose(torch.tensordot(ainv, b), torch.linalg.tensorsolve(a, b)) + True + + >>> a = torch.randn(4, 4) + >>> a_tensorinv = torch.linalg.tensorinv(a, ind=1) + >>> a_inv = torch.inverse(a) + >>> torch.allclose(a_tensorinv, a_inv) + True +""") + +tensorsolve = _add_docstr(_linalg.linalg_tensorsolve, r""" +linalg.tensorsolve(input, other, dims=None, *, out=None) -> Tensor + +Computes a tensor ``x`` such that ``tensordot(input, x, dims=x.ndim) = other``. +The resulting tensor ``x`` has the same shape as ``input[other.ndim:]``. + +Supports real-valued and complex-valued inputs. + +.. note:: If :attr:`input` does not satisfy the requirement + ``prod(input.shape[other.ndim:]) == prod(input.shape[:other.ndim])`` + after (optionally) moving the dimensions using :attr:`dims`, then a RuntimeError will be thrown. + +Args: + input (Tensor): "left-hand-side" tensor, it must satisfy the requirement + ``prod(input.shape[other.ndim:]) == prod(input.shape[:other.ndim])``. + other (Tensor): "right-hand-side" tensor of shape ``input.shape[other.ndim]``. + dims (Tuple[int]): dimensions of :attr:`input` to be moved before the computation. + Equivalent to calling ``input = movedim(input, dims, range(len(dims) - input.ndim, 0))``. + If None (default), no dimensions are moved. + +Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + +Examples:: + + >>> a = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4)) + >>> b = torch.randn(2 * 3, 4) + >>> x = torch.linalg.tensorsolve(a, b) + >>> x.shape + torch.Size([2, 3, 4]) + >>> torch.allclose(torch.tensordot(a, x, dims=x.ndim), b) + True + + >>> a = torch.randn(6, 4, 4, 3, 2) + >>> b = torch.randn(4, 3, 2) + >>> x = torch.linalg.tensorsolve(a, b, dims=(0, 2)) + >>> x.shape + torch.Size([6, 4]) + >>> a = a.permute(1, 3, 4, 0, 2) + >>> a.shape[b.ndim:] + torch.Size([6, 4]) + >>> torch.allclose(torch.tensordot(a, x, dims=x.ndim), b, atol=1e-6) + True +""") + + +qr = _add_docstr(_linalg.linalg_qr, r""" +qr(input, mode='reduced', *, out=None) -> (Tensor, Tensor) + +Computes the QR decomposition of a matrix or a batch of matrices :attr:`input`, +and returns a namedtuple (Q, R) of tensors such that :math:`\text{input} = Q R` +with :math:`Q` being an orthogonal matrix or batch of orthogonal matrices and +:math:`R` being an upper triangular matrix or batch of upper triangular matrices. + +Depending on the value of :attr:`mode` this function returns the reduced or +complete QR factorization. See below for a list of valid modes. + +.. note:: **Differences with** ``numpy.linalg.qr``: + + * ``mode='raw'`` is not implemented + + * unlike ``numpy.linalg.qr``, this function always returns a + tuple of two tensors. When ``mode='r'``, the `Q` tensor is an + empty tensor. + +.. note:: + Backpropagation is not supported for ``mode='r'``. Use ``mode='reduced'`` instead. + + Backpropagation is also not supported if the first + :math:`\min(input.size(-1), input.size(-2))` columns of any matrix + in :attr:`input` are not linearly independent. While no error will + be thrown when this occurs the values of the "gradient" produced may + be anything. This behavior may change in the future. + +.. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs, + and may produce different (valid) decompositions on different device types + or different platforms. + +Args: + input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more + batch dimensions consisting of matrices of dimension :math:`m \times n`. + mode (str, optional): if `k = min(m, n)` then: + + * ``'reduced'`` : returns `(Q, R)` with dimensions (m, k), (k, n) (default) + + * ``'complete'``: returns `(Q, R)` with dimensions (m, m), (m, n) + + * ``'r'``: computes only `R`; returns `(Q, R)` where `Q` is empty and `R` has dimensions (k, n) + +Keyword args: + out (tuple, optional): tuple of `Q` and `R` tensors. + The dimensions of `Q` and `R` are detailed in the description of :attr:`mode` above. + +Example:: + + >>> a = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]]) + >>> q, r = torch.linalg.qr(a) + >>> q + tensor([[-0.8571, 0.3943, 0.3314], + [-0.4286, -0.9029, -0.0343], + [ 0.2857, -0.1714, 0.9429]]) + >>> r + tensor([[ -14.0000, -21.0000, 14.0000], + [ 0.0000, -175.0000, 70.0000], + [ 0.0000, 0.0000, -35.0000]]) + >>> torch.mm(q, r).round() + tensor([[ 12., -51., 4.], + [ 6., 167., -68.], + [ -4., 24., -41.]]) + >>> torch.mm(q.t(), q).round() + tensor([[ 1., 0., 0.], + [ 0., 1., -0.], + [ 0., -0., 1.]]) + >>> q2, r2 = torch.linalg.qr(a, mode='r') + >>> q2 + tensor([]) + >>> torch.equal(r, r2) + True + >>> a = torch.randn(3, 4, 5) + >>> q, r = torch.linalg.qr(a, mode='complete') + >>> torch.allclose(torch.matmul(q, r), a) + True + >>> torch.allclose(torch.matmul(q.transpose(-2, -1), q), torch.eye(5)) + True +""") diff --git a/torch/multiprocessing/__init__.py b/torch/multiprocessing/__init__.py index 07a8e5c06e338..039ddf2a1b09c 100644 --- a/torch/multiprocessing/__init__.py +++ b/torch/multiprocessing/__init__.py @@ -25,7 +25,7 @@ from multiprocessing import * -__all__ += multiprocessing.__all__ +__all__ += multiprocessing.__all__ # type: ignore[attr-defined] # This call adds a Linux specific prctl(2) wrapper function to this module. @@ -33,16 +33,10 @@ torch._C._multiprocessing_init() -if sys.version_info < (3, 3): - """Override basic classes in Python 2.7 and Python 3.3 to use ForkingPickler - for serialization. Later versions of Python already use ForkingPickler.""" - from .queue import Queue, SimpleQueue - from .pool import Pool - - """Add helper function to spawn N processes and wait for completion of any of them. This depends `mp.get_context` which was added in Python 3.4.""" -from .spawn import spawn, SpawnContext, _supports_context, start_processes, ProcessContext +from .spawn import spawn, SpawnContext, start_processes, ProcessContext, \ + ProcessRaisedException, ProcessExitedException if sys.platform == 'darwin' or sys.platform == 'win32': @@ -56,7 +50,7 @@ def set_sharing_strategy(new_strategy): """Sets the strategy for sharing CPU tensors. - Arguments: + Args: new_strategy (str): Name of the selected strategy. Should be one of the values returned by :func:`get_all_sharing_strategies()`. """ diff --git a/torch/multiprocessing/_atfork.py b/torch/multiprocessing/_atfork.py index de7b77ef79018..b9d59bc306042 100644 --- a/torch/multiprocessing/_atfork.py +++ b/torch/multiprocessing/_atfork.py @@ -23,7 +23,7 @@ def register_after_fork(func): ``multiprocessing`` module. In python >= 3.7 it also works with ``os.fork()``. - Arguments: + Args: func (function): Function taking no arguments to be called in the child after fork """ diff --git a/torch/multiprocessing/pool.py b/torch/multiprocessing/pool.py index b768c05f9b3cd..85281e7e729fd 100644 --- a/torch/multiprocessing/pool.py +++ b/torch/multiprocessing/pool.py @@ -1,4 +1,3 @@ -import multiprocessing import multiprocessing.pool import multiprocessing.util as util diff --git a/torch/multiprocessing/queue.py b/torch/multiprocessing/queue.py index 9696ee420681a..9622cd8d3fb1a 100644 --- a/torch/multiprocessing/queue.py +++ b/torch/multiprocessing/queue.py @@ -1,5 +1,4 @@ import io -import multiprocessing import multiprocessing.queues from multiprocessing.reduction import ForkingPickler import pickle @@ -32,8 +31,8 @@ class Queue(multiprocessing.queues.Queue): def __init__(self, *args, **kwargs): super(Queue, self).__init__(*args, **kwargs) - self._reader = ConnectionWrapper(self._reader) - self._writer = ConnectionWrapper(self._writer) + self._reader: ConnectionWrapper = ConnectionWrapper(self._reader) + self._writer: ConnectionWrapper = ConnectionWrapper(self._writer) self._send = self._writer.send self._recv = self._reader.recv @@ -42,6 +41,6 @@ class SimpleQueue(multiprocessing.queues.SimpleQueue): def _make_methods(self): if not isinstance(self._reader, ConnectionWrapper): - self._reader = ConnectionWrapper(self._reader) - self._writer = ConnectionWrapper(self._writer) - super(SimpleQueue, self)._make_methods() + self._reader: ConnectionWrapper = ConnectionWrapper(self._reader) + self._writer: ConnectionWrapper = ConnectionWrapper(self._writer) + super(SimpleQueue, self)._make_methods() # type: ignore[misc] diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index a54b8b162d533..cc83c9cf3473c 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -27,10 +27,10 @@ def __init__(self, storage): self.cdata = storage._weak_ref() # Save a direct reference to _free_weak_ref because the `torch` module # might be cleared during Python shutdown before this module is cleared. - self._free_weak_ref = torch.Storage._free_weak_ref + self._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined] def expired(self): - return torch.Storage._expired(self.cdata) + return torch.Storage._expired(self.cdata) # type: ignore[attr-defined] def __del__(self): self._free_weak_ref(self.cdata) @@ -322,7 +322,7 @@ def reduce_storage(storage): df = multiprocessing.reduction.DupFd(fd) cache_key = fd_id(fd) metadata = (df, size) - rebuild = rebuild_storage_fd + rebuild = rebuild_storage_fd # type: ignore[assignment] shared_cache[cache_key] = StorageWeakRef(storage) return (rebuild, (type(storage),) + metadata) diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index c7f328c5d0781..9ad17c94ccf89 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -1,11 +1,51 @@ +from typing import Optional import multiprocessing import multiprocessing.connection import signal import sys import warnings -from . import _prctl_pr_set_pdeathsig +from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined] + + +class ProcessException(Exception): + __slots__ = ["error_index", "error_pid"] + + def __init__(self, msg: str, error_index: int, pid: int): + super().__init__(msg) + self.error_index = error_index + self.pid = pid + + +class ProcessRaisedException(ProcessException): + """ + Exception is thrown when the process failed due to exception + raised by the code. + """ + def __init__( + self, + msg: str, + error_index: int, + error_pid: int, + ): + super().__init__(msg, error_index, error_pid) + + +class ProcessExitedException(ProcessException): + """ + Exception is thrown when the process failed due to signal + or exited with a specific code. + """ + __slots__ = ["exit_code"] + + def __init__( + self, msg: str, error_index: int, error_pid: int, + exit_code: int, signal_name: Optional[str] = None + ): + super().__init__(msg, error_index, error_pid) + self.exit_code = exit_code + self.signal_name = signal_name def _wrap(fn, i, args, error_queue): @@ -26,24 +66,8 @@ def _wrap(fn, i, args, error_queue): sys.exit(1) -# Multiprocessing contexts are introduced at Python 3.4 -_supports_context = sys.version_info >= (3, 4) - - -def _python_version_check(): - if not _supports_context: - raise RuntimeError("Requires python 3.4 or higher to use " - "torch.multiprocessing.spawn and " - "torch.multiprocessing.ProcessContext helper " - "to launch multiple processes. If you are using " - "this for distributed training and have a lower " - "version of python, please use " - "torch.distributed.launch instead.") - - class ProcessContext: def __init__(self, processes, error_queues): - _python_version_check() self.error_queues = error_queues self.processes = processes self.sentinels = { @@ -64,7 +88,7 @@ def join(self, timeout=None): Returns ``True`` if all processes have been joined successfully, ``False`` if there are more processes that need to be joined. - Arguments: + Args: timeout (float): Wait this long before giving up on waiting. """ # Ensure this function can be called even when we're done. @@ -98,30 +122,38 @@ def join(self, timeout=None): process.join() # There won't be an error on the queue if the process crashed. + failed_process = self.processes[error_index] if self.error_queues[error_index].empty(): exitcode = self.processes[error_index].exitcode if exitcode < 0: name = signal.Signals(-exitcode).name - raise Exception( + raise ProcessExitedException( "process %d terminated with signal %s" % - (error_index, name) + (error_index, name), + error_index=error_index, + error_pid=failed_process.pid, + exit_code=exitcode, + signal_name=name ) else: - raise Exception( + raise ProcessExitedException( "process %d terminated with exit code %d" % - (error_index, exitcode) + (error_index, exitcode), + error_index=error_index, + error_pid=failed_process.pid, + exit_code=exitcode ) original_trace = self.error_queues[error_index].get() msg = "\n\n-- Process %d terminated with the following error:\n" % error_index msg += original_trace - raise Exception(msg) + raise ProcessRaisedException(msg, error_index, failed_process.pid) class SpawnContext(ProcessContext): def __init__(self, processes, error_queues): warnings.warn('SpawnContext is renamed to ProcessContext since 1.4 release.') - super(SpawnContext, self).__init__(self, processes, error_queues) + super(SpawnContext, self).__init__(processes, error_queues) pass @@ -134,7 +166,6 @@ def __init__(self, processes, error_queues): # Currently we only add this API first, we can consider adding it to documentation as # needed in the future. def start_processes(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'): - _python_version_check() mp = multiprocessing.get_context(start_method) error_queues = [] processes = [] @@ -167,7 +198,7 @@ def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'): child process, it is forwarded and its traceback is included in the exception raised in the parent process. - Arguments: + Args: fn (function): Function is called as the entrypoint of the spawned process. This function must be defined at the top level of a module so it can be pickled and spawned. This diff --git a/torch/nn/__init__.py b/torch/nn/__init__.py index b870a55c79ac9..82d7c4341d58a 100644 --- a/torch/nn/__init__.py +++ b/torch/nn/__init__.py @@ -1,5 +1,5 @@ from .modules import * -from .parameter import Parameter +from .parameter import Parameter, UninitializedParameter from .parallel import DataParallel from . import init from . import utils diff --git a/torch/nn/_reduction.py b/torch/nn/_reduction.py index 025ef157b958a..2fc9c58033a26 100644 --- a/torch/nn/_reduction.py +++ b/torch/nn/_reduction.py @@ -1,11 +1,10 @@ -import warnings from typing import Optional +import warnings # NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h -def get_enum(reduction): - # type: (str) -> int +def get_enum(reduction: str) -> int: if reduction == 'none': ret = 0 elif reduction == 'mean': @@ -25,8 +24,7 @@ def get_enum(reduction): # We use these functions in torch/legacy as well, in which case we'll silence the warning -def legacy_get_string(size_average, reduce, emit_warning=True): - # type: (Optional[bool], Optional[bool], bool) -> str +def legacy_get_string(size_average: Optional[bool], reduce: Optional[bool], emit_warning: bool = True) -> str: warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead." if size_average is None: @@ -45,6 +43,5 @@ def legacy_get_string(size_average, reduce, emit_warning=True): return ret -def legacy_get_enum(size_average, reduce, emit_warning=True): - # type: (Optional[bool], Optional[bool], bool) -> int +def legacy_get_enum(size_average: Optional[bool], reduce: Optional[bool], emit_warning: bool = True) -> int: return get_enum(legacy_get_string(size_average, reduce, emit_warning)) diff --git a/torch/nn/common_types.py b/torch/nn/common_types.py index fa9d5bb1eb001..884f739e27813 100644 --- a/torch/nn/common_types.py +++ b/torch/nn/common_types.py @@ -1,4 +1,4 @@ -from typing import TypeVar, Union, Tuple +from typing import TypeVar, Union, Tuple, Optional from .. import Tensor # Create some useful type aliases @@ -24,6 +24,11 @@ _size_5_t = _scalar_or_tuple_5_t[int] _size_6_t = _scalar_or_tuple_6_t[int] +# For arguments which represent optional size parameters (eg, adaptive pool parameters) +_size_any_opt_t = _scalar_or_tuple_any_t[Optional[int]] +_size_2_opt_t = _scalar_or_tuple_2_t[Optional[int]] +_size_3_opt_t = _scalar_or_tuple_3_t[Optional[int]] + # For arguments that represent a ratio to adjust each dimension of an input with (eg, upsampling parameters) _ratio_2_t = _scalar_or_tuple_2_t[float] _ratio_3_t = _scalar_or_tuple_3_t[float] diff --git a/torch/nn/cpp.py b/torch/nn/cpp.py index 194c17bd6b5a4..25a5bcc446aa9 100644 --- a/torch/nn/cpp.py +++ b/torch/nn/cpp.py @@ -57,9 +57,9 @@ def __init__(self, cpp_module): # assigned to in the super class constructor. self.cpp_module = cpp_module super(ModuleWrapper, self).__init__() - self._parameters = OrderedDictWrapper(cpp_module, "_parameters") - self._buffers = OrderedDictWrapper(cpp_module, "_buffers") - self._modules = OrderedDictWrapper(cpp_module, "_modules") + self._parameters = OrderedDictWrapper(cpp_module, "_parameters") # type: ignore[assignment] + self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers") # type: ignore[assignment] + self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules") # type: ignore[assignment] for attr in dir(cpp_module): # Skip magic methods and the three attributes above. if not attr.startswith("_"): @@ -78,7 +78,8 @@ def _apply(self, fn): return self - @property + # nn.Module defines training as a boolean + @property # type: ignore[override] def training(self): return self.cpp_module.training diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 2fdb40b2d93fa..ca2aaa5f9a409 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1,37 +1,43 @@ r"""Functional interface""" -import warnings +from typing import Callable, List, Optional, Tuple import math +import warnings import torch +from torch import _VF from torch._C import _infer_size, _add_docstr +from torch._torch_docs import reproducibility_notes, tf32_notes + +from .._jit_internal import boolean_dispatch, _overload +from ..overrides import ( + has_torch_function, has_torch_function_unary, has_torch_function_variadic, + handle_torch_function) from . import _reduction as _Reduction +from . import grad # noqa: F401 from .modules import utils from .modules.utils import _single, _pair, _triple, _list_with_default -from . import grad # noqa: F401 -from torch import _VF -from .._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple -from ..overrides import has_torch_function, handle_torch_function Tensor = torch.Tensor -conv1d = _add_docstr(torch.conv1d, r""" +conv1d = _add_docstr( + torch.conv1d, + r""" conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor Applies a 1D convolution over an input signal composed of several input planes. -This operator supports :ref:`TensorFloat32`. +{tf32_note} See :class:`~torch.nn.Conv1d` for details and output shape. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {cudnn_reproducibility_note} +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` @@ -51,27 +57,27 @@ >>> filters = torch.randn(33, 16, 3) >>> inputs = torch.randn(20, 16, 50) >>> F.conv1d(inputs, filters) -""") +""", +) -conv2d = _add_docstr(torch.conv2d, r""" +conv2d = _add_docstr( + torch.conv2d, + r""" conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor Applies a 2D convolution over an input image composed of several input planes. -This operator supports :ref:`TensorFloat32`. +{tf32_note} See :class:`~torch.nn.Conv2d` for details and output shape. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. - - + {cudnn_reproducibility_note} +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)` @@ -91,25 +97,27 @@ >>> filters = torch.randn(8,4,3,3) >>> inputs = torch.randn(1,4,5,5) >>> F.conv2d(inputs, filters, padding=1) -""") # noqa: E501 +""", +) # noqa: E501 -conv3d = _add_docstr(torch.conv3d, r""" +conv3d = _add_docstr( + torch.conv3d, + r""" conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor Applies a 3D convolution over an input image composed of several input planes. -This operator supports :ref:`TensorFloat32`. +{tf32_note} See :class:`~torch.nn.Conv3d` for details and output shape. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {cudnn_reproducibility_note} +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)` @@ -129,25 +137,27 @@ >>> filters = torch.randn(33, 16, 3, 3, 3) >>> inputs = torch.randn(20, 16, 50, 10, 20) >>> F.conv3d(inputs, filters) -""") # noqa: E501 +""", +) # noqa: E501 -conv_transpose1d = _add_docstr(torch.conv_transpose1d, r""" +conv_transpose1d = _add_docstr( + torch.conv_transpose1d, + r""" conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor Applies a 1D transposed convolution operator over an input signal composed of several input planes, sometimes also called "deconvolution". -This operator supports :ref:`TensorFloat32`. +{tf32_note} See :class:`~torch.nn.ConvTranspose1d` for details and output shape. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {cudnn_reproducibility_note} +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` @@ -170,25 +180,27 @@ >>> inputs = torch.randn(20, 16, 50) >>> weights = torch.randn(16, 33, 5) >>> F.conv_transpose1d(inputs, weights) -""") +""", +) -conv_transpose2d = _add_docstr(torch.conv_transpose2d, r""" +conv_transpose2d = _add_docstr( + torch.conv_transpose2d, + r""" conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor Applies a 2D transposed convolution operator over an input image composed of several input planes, sometimes also called "deconvolution". -This operator supports :ref:`TensorFloat32`. +{tf32_note} See :class:`~torch.nn.ConvTranspose2d` for details and output shape. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {cudnn_reproducibility_note} +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` @@ -213,25 +225,27 @@ >>> inputs = torch.randn(1, 4, 5, 5) >>> weights = torch.randn(4, 8, 3, 3) >>> F.conv_transpose2d(inputs, weights, padding=1) -""") # noqa: E501 +""", +) # noqa: E501 -conv_transpose3d = _add_docstr(torch.conv_transpose3d, r""" +conv_transpose3d = _add_docstr( + torch.conv_transpose3d, + r""" conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor Applies a 3D transposed convolution operator over an input image composed of several input planes, sometimes also called "deconvolution" -This operator supports :ref:`TensorFloat32`. +{tf32_note} See :class:`~torch.nn.ConvTranspose3d` for details and output shape. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {cudnn_reproducibility_note} +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" Args: input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)` @@ -255,9 +269,12 @@ >>> inputs = torch.randn(20, 16, 50, 10, 20) >>> weights = torch.randn(16, 33, 3, 3, 3) >>> F.conv_transpose3d(inputs, weights) -""") # noqa: E501 +""", +) # noqa: E501 -conv_tbc = _add_docstr(torch.conv_tbc, r""" +conv_tbc = _add_docstr( + torch.conv_tbc, + r""" Applies a 1-dimensional sequence convolution over an input sequence. Input and output dimensions are (Time, Batch, Channels) - hence TBC. @@ -266,11 +283,14 @@ weight: filter of shape (:math:`\text{kernel width} \times \text{in\_channels} \times \text{out\_channels}`) bias: bias of shape (:math:`\text{out\_channels}`) pad: number of timesteps to pad. Default: 0 -""") +""", +) # Pooling -avg_pool1d = _add_docstr(torch.avg_pool1d, r""" +avg_pool1d = _add_docstr( + torch.avg_pool1d, + r""" avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor Applies a 1D average pooling over an input signal composed of several @@ -298,10 +318,13 @@ >>> F.avg_pool1d(input, kernel_size=3, stride=2) tensor([[[ 2., 4., 6.]]]) -""") +""", +) -avg_pool2d = _add_docstr(torch._C._nn.avg_pool2d, r""" +avg_pool2d = _add_docstr( + torch._C._nn.avg_pool2d, + r""" avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size @@ -324,9 +347,12 @@ averaging calculation. Default: ``True`` divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used. Default: None -""") +""", +) -avg_pool3d = _add_docstr(torch._C._nn.avg_pool3d, r""" +avg_pool3d = _add_docstr( + torch._C._nn.avg_pool3d, + r""" avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor Applies 3D average-pooling operation in :math:`kT \times kH \times kW` regions by step @@ -349,12 +375,13 @@ averaging calculation divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used. Default: None -""") +""", +) -def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None, - output_ratio=None, return_indices=False, - _random_samples=None): +def fractional_max_pool2d_with_indices( + input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None +): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], Optional[BroadcastingList2[float]], bool, Optional[Tensor]) -> Tuple[Tensor, Tensor] # noqa r"""Applies 2D fractional max pooling over an input signal composed of several input planes. @@ -385,53 +412,63 @@ def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None, .. _Fractional MaxPooling: http://arxiv.org/abs/1412.6071 """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - fractional_max_pool2d_with_indices, (input,), input, kernel_size, - output_size=output_size, output_ratio=output_ratio, - return_indices=return_indices, _random_samples=_random_samples) + if has_torch_function_unary(input): + return handle_torch_function( + fractional_max_pool2d_with_indices, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) if output_size is None and output_ratio is None: - raise ValueError("fractional_max_pool2d requires specifying either " - "an output_size or an output_ratio") + raise ValueError("fractional_max_pool2d requires specifying either " "an output_size or an output_ratio") if output_size is None: assert output_ratio is not None _output_ratio = _pair(output_ratio) - output_size = [int(input.size(2) * _output_ratio[0]), - int(input.size(3) * _output_ratio[1])] + output_size = [int(input.size(2) * _output_ratio[0]), int(input.size(3) * _output_ratio[1])] if _random_samples is None: _random_samples = torch.rand(input.size(0), input.size(1), 2, dtype=input.dtype, device=input.device) return torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples) -def _fractional_max_pool2d(input, kernel_size, output_size=None, - output_ratio=None, return_indices=False, - _random_samples=None): +def _fractional_max_pool2d( + input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None +): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], Optional[BroadcastingList2[float]], bool, Optional[Tensor]) -> Tensor # noqa - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - fractional_max_pool2d, (input,), input, kernel_size, - output_size=output_size, output_ratio=output_ratio, - return_indices=return_indices, _random_samples=_random_samples) - return fractional_max_pool2d_with_indices(input, kernel_size, output_size, - output_ratio, return_indices, - _random_samples)[0] + if has_torch_function_unary(input): + return handle_torch_function( + fractional_max_pool2d, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + return fractional_max_pool2d_with_indices( + input, kernel_size, output_size, output_ratio, return_indices, _random_samples + )[0] + fractional_max_pool2d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=4, default=False, if_true=fractional_max_pool2d_with_indices, if_false=_fractional_max_pool2d, module_name=__name__, - func_name='fractional_max_pool2d') + func_name="fractional_max_pool2d", +) -def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None, - output_ratio=None, return_indices=False, - _random_samples=None): +def fractional_max_pool3d_with_indices( + input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None +): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], Optional[BroadcastingList3[float]], bool, Optional[Tensor]) -> Tuple[Tensor, Tensor] # noqa r"""Applies 3D fractional max pooling over an input signal composed of several input planes. @@ -463,283 +500,348 @@ def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None, .. _Fractional MaxPooling: http://arxiv.org/abs/1412.6071 """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - fractional_max_pool3d_with_indices, (input,), input, kernel_size, - output_size=output_size, output_ratio=output_ratio, - return_indices=return_indices, _random_samples=_random_samples) + if has_torch_function_unary(input): + return handle_torch_function( + fractional_max_pool3d_with_indices, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) if output_size is None and output_ratio is None: - raise ValueError("fractional_max_pool3d requires specifying either " - "an output_size or an output_ratio") + raise ValueError("fractional_max_pool3d requires specifying either " "an output_size or an output_ratio") if output_size is None: assert output_ratio is not None _output_ratio = _triple(output_ratio) - output_size = [int(input.size(2) * _output_ratio[0]), - int(input.size(3) * _output_ratio[1]), - int(input.size(4) * _output_ratio[2])] + output_size = [ + int(input.size(2) * _output_ratio[0]), + int(input.size(3) * _output_ratio[1]), + int(input.size(4) * _output_ratio[2]), + ] if _random_samples is None: _random_samples = torch.rand(input.size(0), input.size(1), 3, dtype=input.dtype, device=input.device) return torch._C._nn.fractional_max_pool3d(input, kernel_size, output_size, _random_samples) -def _fractional_max_pool3d(input, kernel_size, output_size=None, - output_ratio=None, return_indices=False, - _random_samples=None): +def _fractional_max_pool3d( + input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None +): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], Optional[BroadcastingList3[float]], bool, Optional[Tensor]) -> Tensor # noqa - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - fractional_max_pool3d, (input,), input, kernel_size, - output_size=output_size, output_ratio=output_ratio, - return_indices=return_indices, _random_samples=_random_samples) - return fractional_max_pool3d_with_indices(input, kernel_size, output_size, - output_ratio, return_indices, - _random_samples)[0] + if has_torch_function_unary(input): + return handle_torch_function( + fractional_max_pool3d, + (input,), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + return fractional_max_pool3d_with_indices( + input, kernel_size, output_size, output_ratio, return_indices, _random_samples + )[0] + fractional_max_pool3d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=4, default=False, if_true=fractional_max_pool3d_with_indices, if_false=_fractional_max_pool3d, module_name=__name__, - func_name='fractional_max_pool3d') + func_name="fractional_max_pool3d", +) -def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0, - dilation=1, ceil_mode=False, return_indices=False): +def max_pool1d_with_indices( + input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False +): # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa r"""Applies a 1D max pooling over an input signal composed of several input planes. See :class:`~torch.nn.MaxPool1d` for details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - max_pool1d_with_indices, (input,), input, kernel_size, - stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, - return_indices=return_indices) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool1d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch.max_pool1d_with_indices( - input, kernel_size, stride, padding, dilation, ceil_mode) + return torch.max_pool1d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) -def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, - ceil_mode=False, return_indices=False): +def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor # noqa - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - max_pool1d, (input,), input, kernel_size, - stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, - return_indices=return_indices) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool1d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch.max_pool1d( - input, kernel_size, stride, padding, dilation, ceil_mode) + return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode) + max_pool1d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=6, default=False, if_true=max_pool1d_with_indices, if_false=_max_pool1d, module_name=__name__, - func_name='max_pool1d') + func_name="max_pool1d", +) -def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1, - ceil_mode=False, return_indices=False): +def max_pool2d_with_indices( + input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False +): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa r"""Applies a 2D max pooling over an input signal composed of several input planes. See :class:`~torch.nn.MaxPool2d` for details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - max_pool2d_with_indices, (input,), input, kernel_size, - stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, - return_indices=return_indices) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool2d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) return torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) -def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, - ceil_mode=False, return_indices=False): +def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor # noqa - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - max_pool2d, (input,), input, kernel_size, - stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, - return_indices=return_indices) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool2d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch.max_pool2d( - input, kernel_size, stride, padding, dilation, ceil_mode) + return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode) + max_pool2d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=6, default=False, if_true=max_pool2d_with_indices, if_false=_max_pool2d, module_name=__name__, - func_name='max_pool2d') + func_name="max_pool2d", +) -def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0, - dilation=1, ceil_mode=False, return_indices=False): +def max_pool3d_with_indices( + input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False +): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa r"""Applies a 3D max pooling over an input signal composed of several input planes. See :class:`~torch.nn.MaxPool3d` for details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - max_pool3d_with_indices, (input,), input, kernel_size, - stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, - return_indices=return_indices) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool3d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch._C._nn.max_pool3d_with_indices( - input, kernel_size, stride, padding, dilation, ceil_mode) + return torch._C._nn.max_pool3d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) -def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, - ceil_mode=False, return_indices=False): +def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor # noqa - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - max_pool3d, (input,), input, kernel_size, stride=stride, padding=padding, - dilation=dilation, ceil_mode=ceil_mode, return_indices=return_indices) + if has_torch_function_unary(input): + return handle_torch_function( + max_pool3d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch.max_pool3d( - input, kernel_size, stride, padding, dilation, ceil_mode) + return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode) + max_pool3d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=6, default=False, if_true=max_pool3d_with_indices, if_false=_max_pool3d, module_name=__name__, - func_name='max_pool3d') + func_name="max_pool3d", +) -def _unpool_output_size(input, kernel_size, stride, padding, output_size): - # type: (Tensor, List[int], List[int], List[int], Optional[List[int]]) -> List[int] +def _unpool_output_size( + input: Tensor, kernel_size: List[int], stride: List[int], padding: List[int], output_size: Optional[List[int]] +) -> List[int]: input_size = input.size() default_size = torch.jit.annotate(List[int], []) for d in range(len(kernel_size)): - default_size.append((input_size[d + 2] - 1) * stride[d] + - kernel_size[d] - 2 * padding[d]) + default_size.append((input_size[d + 2] - 1) * stride[d] + kernel_size[d] - 2 * padding[d]) if output_size is None: ret = default_size else: if len(output_size) == len(kernel_size) + 2: output_size = output_size[2:] if len(output_size) != len(kernel_size): - raise ValueError("output_size should be a sequence containing " - "{} or {} elements, but it has a length of '{}'" - .format(len(kernel_size), len(kernel_size) + 2, - len(output_size))) + raise ValueError( + "output_size should be a sequence containing " + "{} or {} elements, but it has a length of '{}'".format( + len(kernel_size), len(kernel_size) + 2, len(output_size) + ) + ) for d in range(len(kernel_size)): min_size = default_size[d] - stride[d] max_size = default_size[d] + stride[d] if not (min_size < output_size[d] < max_size): raise ValueError( - 'invalid output_size "{}" (dim {} must be between {} and {})' - .format(output_size, d, min_size, max_size)) + 'invalid output_size "{}" (dim {} must be between {} and {})'.format( + output_size, d, min_size, max_size + ) + ) ret = output_size return ret -def max_unpool1d(input, indices, kernel_size, stride=None, padding=0, - output_size=None): +def max_unpool1d(input, indices, kernel_size, stride=None, padding=0, output_size=None): # type: (Tensor, Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], Optional[BroadcastingList1[int]]) -> Tensor # noqa r"""Computes a partial inverse of :class:`MaxPool1d`. See :class:`~torch.nn.MaxUnpool1d` for details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - max_unpool1d, (input,), input, indices, kernel_size, - stride=stride, padding=padding, output_size=output_size) + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool1d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) kernel_size = _single(kernel_size) if stride is not None: _stride = _single(stride) else: _stride = kernel_size padding = _single(padding) - output_size = _unpool_output_size(input, kernel_size, _stride, padding, - output_size) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) if isinstance(output_size, list): output_size = output_size + [1] else: output_size = output_size + (1,) - return torch._C._nn.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), - output_size).squeeze(3) + return torch._C._nn.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), output_size).squeeze(3) -def max_unpool2d(input, indices, kernel_size, stride=None, padding=0, - output_size=None): +def max_unpool2d(input, indices, kernel_size, stride=None, padding=0, output_size=None): # type: (Tensor, Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], Optional[BroadcastingList2[int]]) -> Tensor # noqa r"""Computes a partial inverse of :class:`MaxPool2d`. See :class:`~torch.nn.MaxUnpool2d` for details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - max_unpool2d, (input,), input, indices, kernel_size, - stride=stride, padding=padding, output_size=output_size) + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool2d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) kernel_size = _pair(kernel_size) if stride is not None: _stride = _pair(stride) else: _stride = kernel_size padding = _pair(padding) - output_size = _unpool_output_size(input, kernel_size, _stride, padding, - output_size) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) return torch._C._nn.max_unpool2d(input, indices, output_size) -def max_unpool3d(input, indices, kernel_size, stride=None, padding=0, - output_size=None): +def max_unpool3d(input, indices, kernel_size, stride=None, padding=0, output_size=None): # type: (Tensor, Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], Optional[BroadcastingList3[int]]) -> Tensor # noqa r"""Computes a partial inverse of :class:`MaxPool3d`. See :class:`~torch.nn.MaxUnpool3d` for details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - max_unpool3d, (input,), input, indices, kernel_size, - stride=stride, padding=padding, output_size=output_size) + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool3d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) kernel_size = _triple(kernel_size) if stride is not None: _stride = _triple(stride) else: _stride = kernel_size padding = _triple(padding) - output_size = _unpool_output_size(input, kernel_size, _stride, padding, - output_size) - return torch._C._nn.max_unpool3d( - input, indices, output_size, _stride, padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) + return torch._C._nn.max_unpool3d(input, indices, output_size, _stride, padding) def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False): @@ -750,18 +852,17 @@ def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False): See :class:`~torch.nn.LPPool2d` for details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - lp_pool2d, (input,), input, norm_type, kernel_size, stride=stride, - ceil_mode=ceil_mode) + if has_torch_function_unary(input): + return handle_torch_function( + lp_pool2d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode + ) kw, kh = utils._pair(kernel_size) if stride is not None: out = avg_pool2d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) else: out = avg_pool2d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode) - return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1. / norm_type) + return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1.0 / norm_type) def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False): @@ -772,17 +873,16 @@ def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False): See :class:`~torch.nn.LPPool1d` for details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - lp_pool1d, (input,), input, norm_type, kernel_size, stride=stride, - ceil_mode=ceil_mode) + if has_torch_function_unary(input): + return handle_torch_function( + lp_pool1d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode + ) if stride is not None: out = avg_pool1d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) else: out = avg_pool1d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode) - return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1. / norm_type) + return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1.0 / norm_type) def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False): @@ -796,31 +896,31 @@ def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False): output_size: the target output size (single integer) return_indices: whether to return pooling indices. Default: ``False`` """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - adaptive_max_pool1d_with_indices, (input,), input, output_size, - return_indices=return_indices) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool1d_with_indices, (input,), input, output_size, return_indices=return_indices + ) return torch.adaptive_max_pool1d(input, output_size) def _adaptive_max_pool1d(input, output_size, return_indices=False): # type: (Tensor, BroadcastingList1[int], bool) -> Tensor - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - adaptive_max_pool1d, (input,), input, output_size, - return_indices=return_indices) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool1d, (input,), input, output_size, return_indices=return_indices + ) return adaptive_max_pool1d_with_indices(input, output_size)[0] + adaptive_max_pool1d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=2, default=False, if_true=adaptive_max_pool1d_with_indices, if_false=_adaptive_max_pool1d, module_name=__name__, - func_name='adaptive_max_pool1d') + func_name="adaptive_max_pool1d", +) def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False): @@ -835,32 +935,32 @@ def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False): double-integer tuple) return_indices: whether to return pooling indices. Default: ``False`` """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - adaptive_max_pool2d_with_indices, (input,), input, output_size, - return_indices=return_indices) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool2d_with_indices, (input,), input, output_size, return_indices=return_indices + ) output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_max_pool2d(input, output_size) def _adaptive_max_pool2d(input, output_size, return_indices=False): # type: (Tensor, BroadcastingList2[int], bool) -> Tensor - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - adaptive_max_pool2d, (input,), input, output_size, - return_indices=return_indices) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool2d, (input,), input, output_size, return_indices=return_indices + ) return adaptive_max_pool2d_with_indices(input, output_size)[0] + adaptive_max_pool2d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=2, default=False, if_true=adaptive_max_pool2d_with_indices, if_false=_adaptive_max_pool2d, module_name=__name__, - func_name='adaptive_max_pool2d') + func_name="adaptive_max_pool2d", +) def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False): @@ -875,35 +975,37 @@ def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False): triple-integer tuple) return_indices: whether to return pooling indices. Default: ``False`` """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - adaptive_max_pool3d_with_indices, (input,), input, output_size, - return_indices=return_indices) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool3d_with_indices, (input,), input, output_size, return_indices=return_indices + ) output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_max_pool3d(input, output_size) def _adaptive_max_pool3d(input, output_size, return_indices=False): # type: (Tensor, BroadcastingList3[int], bool) -> Tensor - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - adaptive_max_pool3d, (input,), input, output_size, - return_indices=return_indices) + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool3d, (input,), input, output_size, return_indices=return_indices + ) return adaptive_max_pool3d_with_indices(input, output_size)[0] + adaptive_max_pool3d = boolean_dispatch( - arg_name='return_indices', + arg_name="return_indices", arg_index=2, default=False, if_true=adaptive_max_pool3d_with_indices, if_false=_adaptive_max_pool3d, module_name=__name__, - func_name='adaptive_max_pool3d') + func_name="adaptive_max_pool3d", +) -adaptive_avg_pool1d = _add_docstr(torch.adaptive_avg_pool1d, r""" +adaptive_avg_pool1d = _add_docstr( + torch.adaptive_avg_pool1d, + r""" adaptive_avg_pool1d(input, output_size) -> Tensor Applies a 1D adaptive average pooling over an input signal composed of @@ -913,7 +1015,8 @@ def _adaptive_max_pool3d(input, output_size, return_indices=False): Args: output_size: the target output size (single integer) -""") +""", +) def adaptive_avg_pool2d(input, output_size): @@ -928,10 +1031,8 @@ def adaptive_avg_pool2d(input, output_size): output_size: the target output size (single integer or double-integer tuple) """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - adaptive_avg_pool2d, (input,), input, output_size) + if has_torch_function_unary(input): + return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size) _output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_avg_pool2d(input, _output_size) @@ -948,17 +1049,14 @@ def adaptive_avg_pool3d(input, output_size): output_size: the target output size (single integer or triple-integer tuple) """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - adaptive_avg_pool3d, (input,), input, output_size) + if has_torch_function_unary(input): + return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size) _output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_avg_pool3d(input, _output_size) # Activation functions -def dropout(input, p=0.5, training=True, inplace=False): - # type: (Tensor, float, bool, bool) -> Tensor +def dropout(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor: r""" During training, randomly zeroes some of the elements of the input tensor with probability :attr:`p` using samples from a Bernoulli @@ -971,38 +1069,26 @@ def dropout(input, p=0.5, training=True, inplace=False): training: apply dropout if is ``True``. Default: ``True`` inplace: If set to ``True``, will do this operation in-place. Default: ``False`` """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - dropout, (input,), input, p=p, training=training, inplace=inplace) - if p < 0. or p > 1.: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - return (_VF.dropout_(input, p, training) - if inplace - else _VF.dropout(input, p, training)) - - -def alpha_dropout(input, p=0.5, training=False, inplace=False): - # type: (Tensor, float, bool, bool) -> Tensor + if has_torch_function_unary(input): + return handle_torch_function(dropout, (input,), input, p=p, training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training) + + +def alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, inplace: bool = False) -> Tensor: r"""Applies alpha dropout to the input. See :class:`~torch.nn.AlphaDropout` for details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - alpha_dropout, (input,), input, p=p, training=training, inplace=inplace) - if p < 0. or p > 1.: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - return (_VF.alpha_dropout_(input, p, training) - if inplace - else _VF.alpha_dropout(input, p, training)) - - -def dropout2d(input, p=0.5, training=True, inplace=False): - # type: (Tensor, float, bool, bool) -> Tensor + if has_torch_function_unary(input): + return handle_torch_function(alpha_dropout, (input,), input, p=p, training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.alpha_dropout_(input, p, training) if inplace else _VF.alpha_dropout(input, p, training) + + +def dropout2d(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor: r""" Randomly zero out entire channels (a channel is a 2D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample in the @@ -1017,20 +1103,14 @@ def dropout2d(input, p=0.5, training=True, inplace=False): training: apply dropout if is ``True``. Default: ``True`` inplace: If set to ``True``, will do this operation in-place. Default: ``False`` """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - dropout2d, (input,), input, p=p, training=training, inplace=inplace) - if p < 0. or p > 1.: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - return (_VF.feature_dropout_(input, p, training) - if inplace - else _VF.feature_dropout(input, p, training)) - - -def dropout3d(input, p=0.5, training=True, inplace=False): - # type: (Tensor, float, bool, bool) -> Tensor + if has_torch_function_unary(input): + return handle_torch_function(dropout2d, (input,), input, p=p, training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training) + + +def dropout3d(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor: r""" Randomly zero out entire channels (a channel is a 3D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample in the @@ -1047,20 +1127,14 @@ def dropout3d(input, p=0.5, training=True, inplace=False): """ # This is 100% the same code as dropout2d. We duplicate this code so that # stack traces are not confusing. - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - dropout3d, (input,), input, p=p, training=training, inplace=inplace) - if p < 0. or p > 1.: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - return (_VF.feature_dropout_(input, p, training) - if inplace - else _VF.feature_dropout(input, p, training)) - - -def feature_alpha_dropout(input, p=0.5, training=False, inplace=False): - # type: (Tensor, float, bool, bool) -> Tensor + if has_torch_function_unary(input): + return handle_torch_function(dropout3d, (input,), input, p=p, training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training) + + +def feature_alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, inplace: bool = False) -> Tensor: r""" Randomly masks out entire channels (a channel is a feature map, e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input @@ -1080,45 +1154,42 @@ def feature_alpha_dropout(input, p=0.5, training=False, inplace=False): training: apply dropout if is ``True``. Default: ``True`` inplace: If set to ``True``, will do this operation in-place. Default: ``False`` """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - feature_alpha_dropout, (input,), input, p=p, training=training, - inplace=inplace) - if p < 0. or p > 1.: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) - return (_VF.feature_alpha_dropout_(input, p, training) - if inplace - else _VF.feature_alpha_dropout(input, p, training)) - - -def _threshold(input, threshold, value, inplace=False): - # type: (Tensor, float, float, bool) -> Tensor + if has_torch_function_unary(input): + return handle_torch_function( + feature_alpha_dropout, (input,), input, p=p, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.feature_alpha_dropout_(input, p, training) if inplace else _VF.feature_alpha_dropout(input, p, training) + + +def _threshold(input: Tensor, threshold: float, value: float, inplace: bool = False) -> Tensor: r"""Thresholds each element of the input Tensor. See :class:`~torch.nn.Threshold` for more details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - _threshold, (input,), input, threshold, value, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(_threshold, (input,), input, threshold, value, inplace=inplace) if inplace: result = _VF.threshold_(input, threshold, value) else: result = _VF.threshold(input, threshold, value) return result + # We define this function as _threshold because it takes an argument # named threshold, which clobbers the recursive reference to the # function needed for __torch_function__ support threshold = _threshold -threshold_ = _add_docstr(_VF.threshold_, r""" +threshold_ = _add_docstr( + _VF.threshold_, + r""" threshold_(input, threshold, value) -> Tensor In-place version of :func:`~threshold`. -""") +""", +) def relu(input: Tensor, inplace: bool = False) -> Tensor: @@ -1127,9 +1198,8 @@ def relu(input: Tensor, inplace: bool = False) -> Tensor: Applies the rectified linear unit function element-wise. See :class:`~torch.nn.ReLU` for more details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(relu, (input,), input, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(relu, (input,), input, inplace=inplace) if inplace: result = torch.relu_(input) else: @@ -1137,11 +1207,14 @@ def relu(input: Tensor, inplace: bool = False) -> Tensor: return result -relu_ = _add_docstr(torch.relu_, r""" +relu_ = _add_docstr( + torch.relu_, + r""" relu_(input) -> Tensor In-place version of :func:`~relu`. -""") +""", +) def glu(input: Tensor, dim: int = -1) -> Tensor: @@ -1162,26 +1235,22 @@ def glu(input: Tensor, dim: int = -1) -> Tensor: input (Tensor): input tensor dim (int): dimension on which to split the input. Default: -1 """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(glu, (input,), input, dim=dim) + if has_torch_function_unary(input): + return handle_torch_function(glu, (input,), input, dim=dim) if input.dim() == 0: raise RuntimeError("glu does not support scalars because halving size must be even") return torch._C._nn.glu(input, dim) -def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: bool = False) -> Tensor: +def hardtanh(input: Tensor, min_val: float = -1.0, max_val: float = 1.0, inplace: bool = False) -> Tensor: r""" hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor Applies the HardTanh function element-wise. See :class:`~torch.nn.Hardtanh` for more details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - hardtanh, (input,), input, min_val=min_val, max_val=max_val, - inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace) if inplace: result = torch._C._nn.hardtanh_(input, min_val, max_val) else: @@ -1189,38 +1258,36 @@ def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: return result -hardtanh_ = _add_docstr(torch._C._nn.hardtanh_, r""" +hardtanh_ = _add_docstr( + torch._C._nn.hardtanh_, + r""" hardtanh_(input, min_val=-1., max_val=1.) -> Tensor In-place version of :func:`~hardtanh`. -""") +""", +) -def relu6(input, inplace=False): - # type: (Tensor, bool) -> Tensor +def relu6(input: Tensor, inplace: bool = False) -> Tensor: r"""relu6(input, inplace=False) -> Tensor Applies the element-wise function :math:`\text{ReLU6}(x) = \min(\max(0,x), 6)`. See :class:`~torch.nn.ReLU6` for more details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(relu6, (input,), input, inplace=inplace) - return hardtanh(input, 0., 6., inplace) + if has_torch_function_unary(input): + return handle_torch_function(relu6, (input,), input, inplace=inplace) + return hardtanh(input, 0.0, 6.0, inplace) -def elu(input, alpha=1., inplace=False): - # type: (Tensor, float, bool) -> Tensor +def elu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: r"""Applies element-wise, :math:`\text{ELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x) - 1))`. See :class:`~torch.nn.ELU` for more details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(elu, (input,), input, alpha=alpha, - inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(elu, (input,), input, alpha=alpha, inplace=inplace) if inplace: result = torch._C._nn.elu_(input, alpha) else: @@ -1228,15 +1295,17 @@ def elu(input, alpha=1., inplace=False): return result -elu_ = _add_docstr(torch._C._nn.elu_, r""" +elu_ = _add_docstr( + torch._C._nn.elu_, + r""" elu_(input, alpha=1.) -> Tensor In-place version of :func:`~elu`. -""") +""", +) -def selu(input, inplace=False): - # type: (Tensor, bool) -> Tensor +def selu(input: Tensor, inplace: bool = False) -> Tensor: r"""selu(input, inplace=False) -> Tensor Applies element-wise, @@ -1246,9 +1315,8 @@ def selu(input, inplace=False): See :class:`~torch.nn.SELU` for more details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(selu, (input,), input, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(selu, (input,), input, inplace=inplace) if inplace: result = torch.selu_(input) else: @@ -1256,15 +1324,17 @@ def selu(input, inplace=False): return result -selu_ = _add_docstr(torch.selu_, r""" +selu_ = _add_docstr( + torch.selu_, + r""" selu_(input) -> Tensor In-place version of :func:`~selu`. -""") +""", +) -def celu(input, alpha=1., inplace=False): - # type: (Tensor, float, bool) -> Tensor +def celu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: r"""celu(input, alpha=1., inplace=False) -> Tensor Applies element-wise, @@ -1272,25 +1342,26 @@ def celu(input, alpha=1., inplace=False): See :class:`~torch.nn.CELU` for more details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(celu, (input,), input, alpha=alpha, - inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(celu, (input,), input, alpha=alpha, inplace=inplace) if inplace: result = torch.celu_(input, alpha) else: result = torch.celu(input, alpha) return result -celu_ = _add_docstr(torch.celu_, r""" + +celu_ = _add_docstr( + torch.celu_, + r""" celu_(input, alpha=1.) -> Tensor In-place version of :func:`~celu`. -""") +""", +) -def leaky_relu(input, negative_slope=0.01, inplace=False): - # type: (Tensor, float, bool) -> Tensor +def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = False) -> Tensor: r""" leaky_relu(input, negative_slope=0.01, inplace=False) -> Tensor @@ -1299,11 +1370,8 @@ def leaky_relu(input, negative_slope=0.01, inplace=False): See :class:`~torch.nn.LeakyReLU` for more details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - leaky_relu, (input,), input, negative_slope=negative_slope, - inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace) if inplace: result = torch._C._nn.leaky_relu_(input, negative_slope) else: @@ -1311,15 +1379,17 @@ def leaky_relu(input, negative_slope=0.01, inplace=False): return result -leaky_relu_ = _add_docstr(torch._C._nn.leaky_relu_, r""" +leaky_relu_ = _add_docstr( + torch._C._nn.leaky_relu_, + r""" leaky_relu_(input, negative_slope=0.01) -> Tensor In-place version of :func:`~leaky_relu`. -""") +""", +) -def prelu(input, weight): - # type: (Tensor, Tensor) -> Tensor +def prelu(input: Tensor, weight: Tensor) -> Tensor: r"""prelu(input, weight) -> Tensor Applies element-wise the function @@ -1328,25 +1398,24 @@ def prelu(input, weight): See :class:`~torch.nn.PReLU` for more details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(prelu, (input,), input, weight) + if has_torch_function_unary(input): + return handle_torch_function(prelu, (input,), input, weight) return torch.prelu(input, weight) -def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False): - # type: (Tensor, float, float, bool, bool) -> Tensor +def rrelu( + input: Tensor, lower: float = 1.0 / 8, upper: float = 1.0 / 3, training: bool = False, inplace: bool = False +) -> Tensor: r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor Randomized leaky ReLU. See :class:`~torch.nn.RReLU` for more details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - rrelu, (input,), input, lower=lower, upper=upper, - training=training, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function( + rrelu, (input,), input, lower=lower, upper=upper, training=training, inplace=inplace + ) if inplace: result = torch.rrelu_(input, lower, upper, training) else: @@ -1354,19 +1423,26 @@ def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False): return result -rrelu_ = _add_docstr(torch.rrelu_, r""" +rrelu_ = _add_docstr( + torch.rrelu_, + r""" rrelu_(input, lower=1./8, upper=1./3, training=False) -> Tensor In-place version of :func:`~rrelu`. -""") +""", +) -logsigmoid = _add_docstr(torch._C._nn.log_sigmoid, r""" +logsigmoid = _add_docstr( + torch._C._nn.log_sigmoid, + r""" logsigmoid(input) -> Tensor Applies element-wise :math:`\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \exp(-x_i)}\right)` See :class:`~torch.nn.LogSigmoid` for more details. -""") +""", +) + def gelu(input): r"""gelu(input) -> Tensor @@ -1378,14 +1454,12 @@ def gelu(input): See `Gaussian Error Linear Units (GELUs) `_. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(gelu, (input,), input) + if has_torch_function_unary(input): + return handle_torch_function(gelu, (input,), input) return torch._C._nn.gelu(input) -def hardshrink(input, lambd=0.5): - # type: (Tensor, float) -> Tensor +def hardshrink(input: Tensor, lambd: float = 0.5) -> Tensor: r""" hardshrink(input, lambd=0.5) -> Tensor @@ -1393,9 +1467,8 @@ def hardshrink(input, lambd=0.5): See :class:`~torch.nn.Hardshrink` for more details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(hardshrink, (input,), input, lambd=lambd) + if has_torch_function_unary(input): + return handle_torch_function(hardshrink, (input,), input, lambd=lambd) return torch.hardshrink(input, lambd) @@ -1406,9 +1479,8 @@ def tanhshrink(input): See :class:`~torch.nn.Tanhshrink` for more details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(tanhshrink, (input,), input) + if has_torch_function_unary(input): + return handle_torch_function(tanhshrink, (input,), input) return input - input.tanh() @@ -1419,13 +1491,14 @@ def softsign(input): See :class:`~torch.nn.Softsign` for more details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(softsign, (input,), input) + if has_torch_function_unary(input): + return handle_torch_function(softsign, (input,), input) return input / (input.abs() + 1) -softplus = _add_docstr(torch._C._nn.softplus, r""" +softplus = _add_docstr( + torch._C._nn.softplus, + r""" softplus(input, beta=1, threshold=20) -> Tensor Applies element-wise, the function :math:`\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))`. @@ -1434,13 +1507,16 @@ def softsign(input): when :math:`input \times \beta > threshold`. See :class:`~torch.nn.Softplus` for more details. -""") +""", +) -def _get_softmax_dim(name, ndim, stacklevel): - # type: (str, int, int) -> int - warnings.warn("Implicit dimension choice for {} has been deprecated. " - "Change the call to include dim=X as an argument.".format(name), stacklevel=stacklevel) +def _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int: + warnings.warn( + "Implicit dimension choice for {} has been deprecated. " + "Change the call to include dim=X as an argument.".format(name), + stacklevel=stacklevel, + ) if ndim == 0 or ndim == 1 or ndim == 3: ret = 0 else: @@ -1448,15 +1524,14 @@ def _get_softmax_dim(name, ndim, stacklevel): return ret -def softmin(input, dim=None, _stacklevel=3, dtype=None): - # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor +def softmin(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> Tensor: r"""Applies a softmin function. Note that :math:`\text{Softmin}(x) = \text{Softmax}(-x)`. See softmax definition for mathematical formula. See :class:`~torch.nn.Softmin` for more details. - Arguments: + Args: input (Tensor): input dim (int): A dimension along which softmin will be computed (so every slice along dim will sum to 1). @@ -1464,12 +1539,10 @@ def softmin(input, dim=None, _stacklevel=3, dtype=None): If specified, the input tensor is casted to :attr:`dtype` before the operation is performed. This is useful for preventing data type overflows. Default: None. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + if has_torch_function_unary(input): + return handle_torch_function(softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) if dim is None: - dim = _get_softmax_dim('softmin', input.dim(), _stacklevel) + dim = _get_softmax_dim("softmin", input.dim(), _stacklevel) if dtype is None: ret = (-input).softmax(dim) else: @@ -1477,8 +1550,7 @@ def softmin(input, dim=None, _stacklevel=3, dtype=None): return ret -def softmax(input, dim=None, _stacklevel=3, dtype=None): - # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor +def softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> Tensor: r"""Applies a softmax function. Softmax is defined as: @@ -1490,7 +1562,7 @@ def softmax(input, dim=None, _stacklevel=3, dtype=None): See :class:`~torch.nn.Softmax` for more details. - Arguments: + Args: input (Tensor): input dim (int): A dimension along which softmax will be computed. dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. @@ -1503,12 +1575,10 @@ def softmax(input, dim=None, _stacklevel=3, dtype=None): Use log_softmax instead (it's faster and has better numerical properties). """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + if has_torch_function_unary(input): + return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) if dim is None: - dim = _get_softmax_dim('softmax', input.dim(), _stacklevel) + dim = _get_softmax_dim("softmax", input.dim(), _stacklevel) if dtype is None: ret = input.softmax(dim) else: @@ -1516,8 +1586,7 @@ def softmax(input, dim=None, _stacklevel=3, dtype=None): return ret -def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): - # type: (Tensor, float, bool, float, int) -> Tensor +def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> Tensor: r""" Samples from the Gumbel-Softmax distribution (`Link 1`_ `Link 2`_) and optionally discretizes. @@ -1557,14 +1626,14 @@ def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): .. _Link 2: https://arxiv.org/abs/1611.01144 """ - if not torch.jit.is_scripting(): - if type(logits) is not Tensor and has_torch_function((logits,)): - return handle_torch_function( - gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim) + if has_torch_function_unary(logits): + return handle_torch_function(gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim) if eps != 1e-10: warnings.warn("`eps` parameter is deprecated and has no effect.") - gumbels = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() # ~Gumbel(0,1) + gumbels = ( + -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() + ) # ~Gumbel(0,1) gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) y_soft = gumbels.softmax(dim) @@ -1579,8 +1648,7 @@ def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): return ret -def log_softmax(input, dim=None, _stacklevel=3, dtype=None): - # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor +def log_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> Tensor: r"""Applies a softmax followed by a logarithm. While mathematically equivalent to log(softmax(x)), doing these two @@ -1589,19 +1657,17 @@ def log_softmax(input, dim=None, _stacklevel=3, dtype=None): See :class:`~torch.nn.LogSoftmax` for more details. - Arguments: + Args: input (Tensor): input dim (int): A dimension along which log_softmax will be computed. dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. If specified, the input tensor is casted to :attr:`dtype` before the operation is performed. This is useful for preventing data type overflows. Default: None. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + if has_torch_function_unary(input): + return handle_torch_function(log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) if dim is None: - dim = _get_softmax_dim('log_softmax', input.dim(), _stacklevel) + dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel) if dtype is None: ret = input.log_softmax(dim) else: @@ -1609,13 +1675,16 @@ def log_softmax(input, dim=None, _stacklevel=3, dtype=None): return ret -softshrink = _add_docstr(torch._C._nn.softshrink, r""" +softshrink = _add_docstr( + torch._C._nn.softshrink, + r""" softshrink(input, lambd=0.5) -> Tensor Applies the soft shrinkage function elementwise See :class:`~torch.nn.Softshrink` for more details. -""") +""", +) def tanh(input): @@ -1641,8 +1710,7 @@ def sigmoid(input): return input.sigmoid() -def hardsigmoid(input, inplace=False): - # type: (Tensor, bool) -> Tensor +def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: r"""hardsigmoid(input) -> Tensor Applies the element-wise function @@ -1659,16 +1727,14 @@ def hardsigmoid(input, inplace=False): See :class:`~torch.nn.Hardsigmoid` for more details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace) if inplace: return torch._C._nn.hardsigmoid_(input) return torch._C._nn.hardsigmoid(input) -def linear(input, weight, bias=None): - # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor +def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: r""" Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. @@ -1682,10 +1748,8 @@ def linear(input, weight, bias=None): - Bias: :math:`(out\_features)` - Output: :math:`(N, *, out\_features)` """ - tens_ops = (input, weight) - if not torch.jit.is_scripting(): - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function(linear, tens_ops, input, weight, bias=bias) + if has_torch_function_variadic(input, weight): + return handle_torch_function(linear, (input, weight), input, weight, bias=bias) if input.dim() == 2 and bias is not None: # fused op is marginally faster ret = torch.addmm(bias, input, weight.t()) @@ -1697,8 +1761,7 @@ def linear(input, weight, bias=None): return ret -def bilinear(input1, input2, weight, bias=None): - # type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tensor +def bilinear(input1: Tensor, input2: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: r""" Applies a bilinear transformation to the incoming data: :math:`y = x_1^T A x_2 + b` @@ -1717,8 +1780,8 @@ def bilinear(input1, input2, weight, bias=None): """ return torch.bilinear(input1, input2, weight, bias) -def silu(input, inplace=False): - # type: (Tensor, bool) -> Tensor + +def silu(input: Tensor, inplace: bool = False) -> Tensor: r"""Applies the silu function, element-wise. .. math:: @@ -1734,13 +1797,13 @@ def silu(input, inplace=False): See :class:`~torch.nn.SiLU` for more details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(silu, (input,), input, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(silu, (input,), input, inplace=inplace) if inplace: return torch._C._nn.silu_(input) return torch._C._nn.silu(input) + def hardswish(input: Tensor, inplace: bool = False) -> Tensor: r"""Applies the hardswish function, element-wise, as described in the paper: @@ -1758,23 +1821,27 @@ def hardswish(input: Tensor, inplace: bool = False) -> Tensor: .. _`Searching for MobileNetV3`: https://arxiv.org/abs/1905.02244 """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function(hardswish, (input,), input, inplace=inplace) + if has_torch_function_unary(input): + return handle_torch_function(hardswish, (input,), input, inplace=inplace) if inplace: return torch._C._nn.hardswish_(input) return torch._C._nn.hardswish(input) -def _no_grad_embedding_renorm_(weight, input, max_norm, norm_type): - # type: (Tensor, Tensor, float, float) -> Tensor +def _no_grad_embedding_renorm_(weight: Tensor, input: Tensor, max_norm: float, norm_type: float) -> Tensor: with torch.no_grad(): torch.embedding_renorm_(weight, input, max_norm, norm_type) -def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2., - scale_grad_by_freq=False, sparse=False): - # type: (Tensor, Tensor, Optional[int], Optional[float], float, bool, bool) -> Tensor +def embedding( + input: Tensor, + weight: Tensor, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, +) -> Tensor: r"""A simple lookup table that looks up embeddings in a fixed dictionary and size. This module is often used to retrieve word embeddings using indices. @@ -1832,11 +1899,12 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2., [ 0.0000, 0.0000, 0.0000], [ 0.6262, 0.2438, 0.7471]]]) """ + if padding_idx is not None: if padding_idx > 0: - assert padding_idx < weight.size(0), 'Padding_idx must be within num_embeddings' + assert padding_idx < weight.size(0), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -weight.size(0), 'Padding_idx must be within num_embeddings' + assert padding_idx >= -weight.size(0), "Padding_idx must be within num_embeddings" padding_idx = weight.size(0) + padding_idx else: padding_idx = -1 @@ -1853,19 +1921,25 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2., return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) -def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, - scale_grad_by_freq=False, mode='mean', sparse=False, - per_sample_weights=None, include_last_offset=False): - # type: (Tensor, Tensor, Optional[Tensor], Optional[float], float, bool, str, bool, Optional[Tensor], bool) -> Tensor +def embedding_bag( + input: Tensor, + weight: Tensor, + offsets: Optional[Tensor] = None, + max_norm: Optional[float] = None, + norm_type: float = 2, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + per_sample_weights: Optional[Tensor] = None, + include_last_offset: bool = False, +) -> Tensor: r"""Computes sums, means or maxes of `bags` of embeddings, without instantiating the intermediate embeddings. See :class:`torch.nn.EmbeddingBag` for more details. Note: - When using the CUDA backend, this operation may induce nondeterministic - behaviour in its backward pass that is not easily switched off. - Please see the notes on :doc:`/notes/randomness` for background. + {backward_reproducibility_note} Args: input (LongTensor): Tensor containing bags of indices into the embedding matrix @@ -1933,27 +2007,37 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, tensor([[ 0.3397, 0.3552, 0.5545], [ 0.5893, 0.4386, 0.5882]]) """ - if not torch.jit.is_scripting(): - tens_ops = (input, weight) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - embedding_bag, tens_ops, input, weight, offsets=offsets, max_norm=max_norm, - norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, mode=mode, - sparse=sparse, per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset) + if has_torch_function_variadic(input, weight): + return handle_torch_function( + embedding_bag, + (input, weight), + input, + weight, + offsets=offsets, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + ) # Check for backward compatibility. # Used to be embedding_bag(weight, input, ...) # Now is embedding_bag(input, weight, ...) if weight.dtype == torch.long and input.is_floating_point(): - warnings.warn("Argument order of nn.functional.embedding_bag was changed. " - "Usage `embedding_bag(weight, input, ...)` is deprecated, " - "and should now be `embedding_bag(input, weight, ...)`.") + warnings.warn( + "Argument order of nn.functional.embedding_bag was changed. " + "Usage `embedding_bag(weight, input, ...)` is deprecated, " + "and should now be `embedding_bag(input, weight, ...)`." + ) weight, input = input, weight if per_sample_weights is not None and input.size() != per_sample_weights.size(): - raise ValueError("embedding_bag: If per_sample_weights ({}) is not None, " - "then it must have the same shape as the input ({})" - .format(per_sample_weights.shape, input.shape)) + raise ValueError( + "embedding_bag: If per_sample_weights ({}) is not None, " + "then it must have the same shape as the input ({})".format(per_sample_weights.shape, input.shape) + ) if input.dim() == 2: if offsets is not None: @@ -1961,12 +2045,13 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, # TODO: Remove this once script supports type() calls if not torch.jit.is_scripting(): type_str = str(type(offsets)) - raise ValueError("if input is 2D, then offsets has to be None" - ", as input is treated is a mini-batch of" - " fixed length sequences. However, found " - "offsets of type {}".format(type_str)) - offsets = torch.arange(0, input.numel(), input.size(1), - dtype=torch.long, device=input.device) + raise ValueError( + "if input is 2D, then offsets has to be None" + ", as input is treated is a mini-batch of" + " fixed length sequences. However, found " + "offsets of type {}".format(type_str) + ) + offsets = torch.arange(0, input.numel(), input.size(1), dtype=input.dtype, device=input.device) input = input.reshape(-1) if per_sample_weights is not None: @@ -1977,13 +2062,12 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, if offsets.dim() != 1: raise ValueError("offsets has to be a 1D Tensor") else: - raise ValueError("input has to be 1D or 2D Tensor," - " but got Tensor of dimension {}".format(input.dim())) - if mode == 'sum': + raise ValueError("input has to be 1D or 2D Tensor," " but got Tensor of dimension {}".format(input.dim())) + if mode == "sum": mode_enum = 0 - elif mode == 'mean': + elif mode == "mean": mode_enum = 1 - elif mode == 'max': + elif mode == "max": mode_enum = 2 if scale_grad_by_freq: @@ -2002,26 +2086,23 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, # remove once script supports set_grad_enabled _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) - if per_sample_weights is not None and mode != 'sum': - raise NotImplementedError("embedding_bag: per_sample_weights was not None. " - "per_sample_weights is only supported for mode='sum' " - "(got mode='{}'). Please open a feature request on GitHub." - .format(mode)) + if per_sample_weights is not None and mode != "sum": + raise NotImplementedError( + "embedding_bag: per_sample_weights was not None. " + "per_sample_weights is only supported for mode='sum' " + "(got mode='{}'). Please open a feature request on GitHub.".format(mode) + ) ret, _, _, _ = torch.embedding_bag( - weight, - input, - offsets, - scale_grad_by_freq, - mode_enum, - sparse, - per_sample_weights, - include_last_offset) + weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights, include_last_offset + ) return ret -def _verify_batch_size(size): - # type: (List[int]) -> None +embedding_bag.__doc__ = embedding_bag.__doc__.format(**reproducibility_notes) + + +def _verify_batch_size(size: List[int]) -> None: # XXX: JIT script does not support the reduce from functools, and mul op is a # builtin, which cannot be used as a value to a func yet, so rewrite this size # check to a simple equivalent for loop @@ -2035,100 +2116,130 @@ def _verify_batch_size(size): for i in range(len(size) - 2): size_prods *= size[i + 2] if size_prods == 1: - raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size)) - - -def batch_norm(input, running_mean, running_var, weight=None, bias=None, - training=False, momentum=0.1, eps=1e-5): - # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa + raise ValueError("Expected more than 1 value per channel when training, got input size {}".format(size)) + + +def batch_norm( + input: Tensor, + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + training: bool = False, + momentum: float = 0.1, + eps: float = 1e-5, +) -> Tensor: + # noqa r"""Applies Batch Normalization for each channel across a batch of data. See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`, :class:`~torch.nn.BatchNorm3d` for details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - batch_norm, (input,), input, running_mean, running_var, weight=weight, - bias=bias, training=training, momentum=momentum, eps=eps) + if has_torch_function_unary(input): + return handle_torch_function( + batch_norm, + (input,), + input, + running_mean, + running_var, + weight=weight, + bias=bias, + training=training, + momentum=momentum, + eps=eps, + ) if training: _verify_batch_size(input.size()) return torch.batch_norm( - input, weight, bias, running_mean, running_var, - training, momentum, eps, torch.backends.cudnn.enabled + input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled ) -def instance_norm(input, running_mean=None, running_var=None, weight=None, - bias=None, use_input_stats=True, momentum=0.1, eps=1e-5): - # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa +def instance_norm( + input: Tensor, + running_mean: Optional[Tensor] = None, + running_var: Optional[Tensor] = None, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + use_input_stats: bool = True, + momentum: float = 0.1, + eps: float = 1e-5, +) -> Tensor: + # noqa r"""Applies Instance Normalization for each channel in each data sample in a batch. See :class:`~torch.nn.InstanceNorm1d`, :class:`~torch.nn.InstanceNorm2d`, :class:`~torch.nn.InstanceNorm3d` for details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - instance_norm, (input,), input, running_mean=running_mean, - running_var=running_var, weight=weight, bias=bias, - use_input_stats=use_input_stats, momentum=momentum, eps=eps) + if has_torch_function_unary(input): + return handle_torch_function( + instance_norm, + (input,), + input, + running_mean=running_mean, + running_var=running_var, + weight=weight, + bias=bias, + use_input_stats=use_input_stats, + momentum=momentum, + eps=eps, + ) _verify_batch_size(input.size()) return torch.instance_norm( - input, weight, bias, running_mean, running_var, - use_input_stats, momentum, eps, torch.backends.cudnn.enabled + input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, torch.backends.cudnn.enabled ) -def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5): - # type: (Tensor, List[int], Optional[Tensor], Optional[Tensor], float) -> Tensor +def layer_norm( + input: Tensor, + normalized_shape: List[int], + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: r"""Applies Layer Normalization for last certain number of dimensions. See :class:`~torch.nn.LayerNorm` for details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - layer_norm, (input,), input, normalized_shape, weight=weight, bias=bias, eps=eps) - return torch.layer_norm(input, normalized_shape, weight, bias, eps, - torch.backends.cudnn.enabled) + if has_torch_function_unary(input): + return handle_torch_function( + layer_norm, (input,), input, normalized_shape, weight=weight, bias=bias, eps=eps + ) + return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled) -def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5): - # type: (Tensor, int, Optional[Tensor], Optional[Tensor], float) -> Tensor +def group_norm( + input: Tensor, num_groups: int, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, eps: float = 1e-5 +) -> Tensor: r"""Applies Group Normalization for last certain number of dimensions. See :class:`~torch.nn.GroupNorm` for details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - group_norm, (input,), input, num_groups, weight=weight, bias=bias, eps=eps) - _verify_batch_size([ - input.size(0) * input.size(1) // num_groups, num_groups] - + list(input.size()[2:])) - return torch.group_norm(input, num_groups, weight, bias, eps, - torch.backends.cudnn.enabled) - - -def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.): - # type: (Tensor, int, float, float, float) -> Tensor + if has_torch_function_unary(input): + return handle_torch_function(group_norm, (input,), input, num_groups, weight=weight, bias=bias, eps=eps) + _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:])) + return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled) + + +def local_response_norm(input: Tensor, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.0) -> Tensor: r"""Applies local response normalization over an input signal composed of several input planes, where channels occupy the second dimension. Applies normalization across channels. See :class:`~torch.nn.LocalResponseNorm` for details. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k) + if has_torch_function_unary(input): + return handle_torch_function(local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k) dim = input.dim() if dim < 3: - raise ValueError('Expected 3D or higher dimensionality \ - input (got {} dimensions)'.format(dim)) + raise ValueError( + "Expected 3D or higher dimensionality \ + input (got {} dimensions)".format( + dim + ) + ) div = input.mul(input).unsqueeze(1) if dim == 3: div = pad(div, (0, 0, size // 2, (size - 1) // 2)) @@ -2145,25 +2256,25 @@ def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.): # loss -def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, - reduction='mean', zero_infinity=False): - # type: (Tensor, Tensor, Tensor, Tensor, int, str, bool) -> Tensor + +def ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + blank: int = 0, + reduction: str = "mean", + zero_infinity: bool = False, +) -> Tensor: r"""The Connectionist Temporal Classification loss. See :class:`~torch.nn.CTCLoss` for details. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. + {cudnn_reproducibility_note} Note: - When using the CUDA backend, this operation may induce nondeterministic - behaviour in its backward pass that is not easily switched off. - Please see the notes on :doc:`/notes/randomness` for background. + {backward_reproducibility_note} Args: log_probs: :math:`(T, N, C)` where `C = number of characters in alphabet including blank`, @@ -2198,13 +2309,23 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths) >>> loss.backward() """ - return torch.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, _Reduction.get_enum(reduction), - zero_infinity) + return torch.ctc_loss( + log_probs, targets, input_lengths, target_lengths, blank, _Reduction.get_enum(reduction), zero_infinity + ) -def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor +ctc_loss.__doc__ = ctc_loss.__doc__.format(**reproducibility_notes) + + +def nll_loss( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""The negative log likelihood loss. See :class:`~torch.nn.NLLLoss` for details. @@ -2246,21 +2367,28 @@ def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, >>> output = F.nll_loss(F.log_softmax(input), target) >>> output.backward() """ - if not torch.jit.is_scripting(): - tens_ops = (input, target) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - nll_loss, tens_ops, input, target, weight=weight, size_average=size_average, - ignore_index=ignore_index, reduce=reduce, reduction=reduction) + if has_torch_function_variadic(input, target): + return handle_torch_function( + nll_loss, + (input, target), + input, + target, + weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) dim = input.dim() if dim < 2: - raise ValueError('Expected 2 or more dimensions (got {})'.format(dim)) + raise ValueError("Expected 2 or more dimensions (got {})".format(dim)) if input.size(0) != target.size(0): - raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).' - .format(input.size(0), target.size(0))) + raise ValueError( + "Expected input batch_size ({}) to match target batch_size ({}).".format(input.size(0), target.size(0)) + ) if dim == 2: ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index) elif dim == 4: @@ -2271,8 +2399,7 @@ def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, c = input.size(1) out_size = (n,) + input.size()[2:] if target.size()[1:] != input.size()[2:]: - raise ValueError('Expected target size {}, got {}'.format( - out_size, target.size())) + raise ValueError("Expected target size {}, got {}".format(out_size, target.size())) input = input.contiguous() target = target.contiguous() # support empty batches, see #15870 @@ -2285,19 +2412,24 @@ def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100, else: target = target.view(n, 0, 0) reduction_enum = _Reduction.get_enum(reduction) - if reduction != 'none': - ret = torch._C._nn.nll_loss2d( - input, target, weight, reduction_enum, ignore_index) + if reduction != "none": + ret = torch._C._nn.nll_loss2d(input, target, weight, reduction_enum, ignore_index) else: - out = torch._C._nn.nll_loss2d( - input, target, weight, reduction_enum, ignore_index) + out = torch._C._nn.nll_loss2d(input, target, weight, reduction_enum, ignore_index) ret = out.view(out_size) return ret -def poisson_nll_loss(input, target, log_input=True, full=False, size_average=None, eps=1e-8, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, bool, bool, Optional[bool], float, Optional[bool], str) -> Tensor +def poisson_nll_loss( + input: Tensor, + target: Tensor, + log_input: bool = True, + full: bool = False, + size_average: Optional[bool] = None, + eps: float = 1e-8, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""Poisson negative log likelihood loss. See :class:`~torch.nn.PoissonNLLLoss` for details. @@ -2330,15 +2462,22 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Non specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` """ - if not torch.jit.is_scripting(): - tens_ops = (input, target) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - poisson_nll_loss, tens_ops, input, target, log_input=log_input, full=full, - size_average=size_average, eps=eps, reduce=reduce, reduction=reduction) + if has_torch_function_variadic(input, target): + return handle_torch_function( + poisson_nll_loss, + (input, target), + input, + target, + log_input=log_input, + full=full, + size_average=size_average, + eps=eps, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) - if reduction != 'none' and reduction != 'mean' and reduction != 'sum': + if reduction != "none" and reduction != "mean" and reduction != "sum": ret = input raise ValueError(reduction + " is not valid") @@ -2346,8 +2485,14 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Non return ret -def kl_div(input, target, size_average=None, reduce=None, reduction='mean', log_target=False): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str, bool) -> Tensor +def kl_div( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + log_target: bool = False, +) -> Tensor: r"""The `Kullback-Leibler divergence Loss `__ @@ -2386,37 +2531,50 @@ def kl_div(input, target, size_average=None, reduce=None, reduction='mean', log_ :attr:``reduction`` = ``'batchmean'`` which aligns with KL math definition. In the next major release, ``'mean'`` will be changed to be the same as 'batchmean'. """ - if not torch.jit.is_scripting(): - tens_ops = (input, target) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - kl_div, tens_ops, input, target, size_average=size_average, - reduce=reduce, reduction=reduction, log_target=log_target) + if has_torch_function_variadic(input, target): + return handle_torch_function( + kl_div, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + log_target=log_target, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: - if reduction == 'mean': - warnings.warn("reduction: 'mean' divides the total loss by both the batch size and the support size." - "'batchmean' divides only by the batch size, and aligns with the KL div math definition." - "'mean' will be changed to behave the same as 'batchmean' in the next major release.") + if reduction == "mean": + warnings.warn( + "reduction: 'mean' divides the total loss by both the batch size and the support size." + "'batchmean' divides only by the batch size, and aligns with the KL div math definition." + "'mean' will be changed to behave the same as 'batchmean' in the next major release." + ) # special case for batchmean - if reduction == 'batchmean': - reduction_enum = _Reduction.get_enum('sum') + if reduction == "batchmean": + reduction_enum = _Reduction.get_enum("sum") else: reduction_enum = _Reduction.get_enum(reduction) reduced = torch.kl_div(input, target, reduction_enum, log_target=log_target) - if reduction == 'batchmean' and input.dim() != 0: + if reduction == "batchmean" and input.dim() != 0: reduced = reduced / input.size()[0] return reduced -def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor +def cross_entropy( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""This criterion combines `log_softmax` and `nll_loss` in a single function. @@ -2457,21 +2615,31 @@ def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-1 >>> loss = F.cross_entropy(input, target) >>> loss.backward() """ - if not torch.jit.is_scripting(): - tens_ops = (input, target) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - cross_entropy, tens_ops, input, target, weight=weight, - size_average=size_average, ignore_index=ignore_index, reduce=reduce, - reduction=reduction) + if has_torch_function_variadic(input, target): + return handle_torch_function( + cross_entropy, + (input, target), + input, + target, + weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction) -def binary_cross_entropy(input, target, weight=None, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor +def binary_cross_entropy( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""Function that measures the Binary Cross Entropy between the target and the output. @@ -2505,31 +2673,43 @@ def binary_cross_entropy(input, target, weight=None, size_average=None, >>> loss = F.binary_cross_entropy(F.sigmoid(input), target) >>> loss.backward() """ - if not torch.jit.is_scripting(): - tens_ops = (input, target) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - binary_cross_entropy, tens_ops, input, target, weight=weight, - size_average=size_average, reduce=reduce, reduction=reduction) + if has_torch_function_variadic(input, target): + return handle_torch_function( + binary_cross_entropy, + (input, target), + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) if target.size() != input.size(): - raise ValueError("Using a target size ({}) that is different to the input size ({}) is deprecated. " - "Please ensure they have the same size.".format(target.size(), input.size())) + raise ValueError( + "Using a target size ({}) that is different to the input size ({}) is deprecated. " + "Please ensure they have the same size.".format(target.size(), input.size()) + ) if weight is not None: new_size = _infer_size(target.size(), weight.size()) weight = weight.expand(new_size) - return torch._C._nn.binary_cross_entropy( - input, target, weight, reduction_enum) + return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum) -def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None, - reduce=None, reduction='mean', pos_weight=None): - # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str, Optional[Tensor]) -> Tensor +def binary_cross_entropy_with_logits( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + pos_weight: Optional[Tensor] = None, +) -> Tensor: r"""Function that measures Binary Cross Entropy between target and output logits. @@ -2565,13 +2745,18 @@ def binary_cross_entropy_with_logits(input, target, weight=None, size_average=No >>> loss = F.binary_cross_entropy_with_logits(input, target) >>> loss.backward() """ - if not torch.jit.is_scripting(): - tens_ops = (input, target) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - binary_cross_entropy_with_logits, tens_ops, input, target, weight=weight, - size_average=size_average, reduce=reduce, reduction=reduction, - pos_weight=pos_weight) + if has_torch_function_variadic(input, target): + return handle_torch_function( + binary_cross_entropy_with_logits, + (input, target), + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + pos_weight=pos_weight, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2583,83 +2768,99 @@ def binary_cross_entropy_with_logits(input, target, weight=None, size_average=No return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum) -def _smooth_l1_loss(input, target): - # type: (Tensor, Tensor) -> Tensor - t = torch.abs(input - target) - return torch.where(t < 1, 0.5 * t ** 2, t - 0.5) - - -def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor +def smooth_l1_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + beta: float = 1.0, +) -> Tensor: r"""Function that uses a squared term if the absolute - element-wise error falls below 1 and an L1 term otherwise. + element-wise error falls below beta and an L1 term otherwise. See :class:`~torch.nn.SmoothL1Loss` for details. """ - if not torch.jit.is_scripting(): - tens_ops = (input, target) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - smooth_l1_loss, tens_ops, input, target, size_average=size_average, - reduce=reduce, reduction=reduction) + if has_torch_function_variadic(input, target): + return handle_torch_function( + smooth_l1_loss, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + beta=beta, + ) if not (target.size() == input.size()): - warnings.warn("Using a target size ({}) that is different to the input size ({}). " - "This will likely lead to incorrect results due to broadcasting. " - "Please ensure they have the same size.".format(target.size(), input.size()), - stacklevel=2) + warnings.warn( + "Using a target size ({}) that is different to the input size ({}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.".format(target.size(), input.size()), + stacklevel=2, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) expanded_input, expanded_target = torch.broadcast_tensors(input, target) - return torch._C._nn.smooth_l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction)) + return torch._C._nn.smooth_l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction), beta) -def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor +def l1_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor Function that takes the mean element-wise absolute value difference. See :class:`~torch.nn.L1Loss` for details. """ - if not torch.jit.is_scripting(): - tens_ops = (input, target) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - l1_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, - reduction=reduction) + if has_torch_function_variadic(input, target): + return handle_torch_function( + l1_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) if not (target.size() == input.size()): - warnings.warn("Using a target size ({}) that is different to the input size ({}). " - "This will likely lead to incorrect results due to broadcasting. " - "Please ensure they have the same size.".format(target.size(), input.size()), - stacklevel=2) + warnings.warn( + "Using a target size ({}) that is different to the input size ({}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.".format(target.size(), input.size()), + stacklevel=2, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) - expanded_input, expanded_target = torch.broadcast_tensors(input, target) return torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction)) -def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor +def mse_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor Measures the element-wise mean squared error. See :class:`~torch.nn.MSELoss` for details. """ - if not torch.jit.is_scripting(): - tens_ops = (input, target) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - mse_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, - reduction=reduction) + if has_torch_function_variadic(input, target): + return handle_torch_function( + mse_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) if not (target.size() == input.size()): - warnings.warn("Using a target size ({}) that is different to the input size ({}). " - "This will likely lead to incorrect results due to broadcasting. " - "Please ensure they have the same size.".format(target.size(), input.size()), - stacklevel=2) + warnings.warn( + "Using a target size ({}) that is different to the input size ({}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.".format(target.size(), input.size()), + stacklevel=2, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) @@ -2667,42 +2868,68 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'): return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction)) -def margin_ranking_loss(input1, input2, target, margin=0, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor +def margin_ranking_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = 0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""margin_ranking_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor See :class:`~torch.nn.MarginRankingLoss` for details. """ # noqa - if not torch.jit.is_scripting(): - tens_ops = (input1, input2, target) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - margin_ranking_loss, tens_ops, input1, input2, target, margin=margin, - size_average=size_average, reduce=reduce, reduction=reduction) + if has_torch_function_variadic(input1, input2, target): + return handle_torch_function( + margin_ranking_loss, + (input1, input2, target), + input1, + input2, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) if input1.dim() == 0 or input2.dim() == 0 or target.dim() == 0: - raise RuntimeError(("margin_ranking_loss does not support scalars, got sizes: " - "input1: {}, input2: {}, target: {} ".format(input1.size(), input2.size(), target.size()))) + raise RuntimeError( + ( + "margin_ranking_loss does not support scalars, got sizes: " + "input1: {}, input2: {}, target: {} ".format(input1.size(), input2.size(), target.size()) + ) + ) return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum) -def hinge_embedding_loss(input, target, margin=1.0, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor +def hinge_embedding_loss( + input: Tensor, + target: Tensor, + margin: float = 1.0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""hinge_embedding_loss(input, target, margin=1.0, size_average=None, reduce=None, reduction='mean') -> Tensor See :class:`~torch.nn.HingeEmbeddingLoss` for details. """ # noqa - if not torch.jit.is_scripting(): - tens_ops = (input, target) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - hinge_embedding_loss, tens_ops, input, target, margin=margin, - size_average=size_average, reduce=reduce, reduction=reduction) + if has_torch_function_variadic(input, target): + return handle_torch_function( + hinge_embedding_loss, + (input, target), + input, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2710,18 +2937,27 @@ def hinge_embedding_loss(input, target, margin=1.0, size_average=None, return torch.hinge_embedding_loss(input, target, margin, reduction_enum) -def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor +def multilabel_margin_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor See :class:`~torch.nn.MultiLabelMarginLoss` for details. """ - if not torch.jit.is_scripting(): - tens_ops = (input, target) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - multilabel_margin_loss, tens_ops, input, target, size_average=size_average, - reduce=reduce, reduction=reduction) + if has_torch_function_variadic(input, target): + return handle_torch_function( + multilabel_margin_loss, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2729,18 +2965,21 @@ def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduct return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum) -def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor +def soft_margin_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor See :class:`~torch.nn.SoftMarginLoss` for details. """ - if not torch.jit.is_scripting(): - tens_ops = (input, target) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - soft_margin_loss, tens_ops, input, target, size_average=size_average, - reduce=reduce, reduction=reduction) + if has_torch_function_variadic(input, target): + return handle_torch_function( + soft_margin_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2748,19 +2987,29 @@ def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='m return torch._C._nn.soft_margin_loss(input, target, reduction_enum) -def multilabel_soft_margin_loss(input, target, weight=None, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor +def multilabel_soft_margin_loss( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""multilabel_soft_margin_loss(input, target, weight=None, size_average=None) -> Tensor See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details. """ - if not torch.jit.is_scripting(): - tens_ops = (input, target) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - multilabel_soft_margin_loss, tens_ops, input, target, weight=weight, - size_average=size_average, reduce=reduce, reduction=reduction) + if has_torch_function_variadic(input, target): + return handle_torch_function( + multilabel_soft_margin_loss, + (input, target), + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) @@ -2771,11 +3020,11 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None, loss = loss.sum(dim=1) / input.size(1) # only return N loss values - if reduction == 'none': + if reduction == "none": ret = loss - elif reduction == 'mean': + elif reduction == "mean": ret = loss.mean() - elif reduction == 'sum': + elif reduction == "sum": ret = loss.sum() else: ret = input @@ -2783,19 +3032,31 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None, return ret -def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor +def cosine_embedding_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = 0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor See :class:`~torch.nn.CosineEmbeddingLoss` for details. """ # noqa - if not torch.jit.is_scripting(): - tens_ops = (input1, input2, target) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - cosine_embedding_loss, tens_ops, input1, input2, target, margin=margin, - size_average=size_average, reduce=reduce, reduction=reduction) + if has_torch_function_variadic(input1, input2, target): + return handle_torch_function( + cosine_embedding_loss, + (input1, input2, target), + input1, + input2, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: @@ -2803,37 +3064,54 @@ def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum) -def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=None, - reduce=None, reduction='mean'): - # type: (Tensor, Tensor, int, float, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor +def multi_margin_loss( + input: Tensor, + target: Tensor, + p: int = 1, + margin: float = 1.0, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r"""multi_margin_loss(input, target, p=1, margin=1, weight=None, size_average=None, reduce=None, reduction='mean') -> Tensor See :class:`~torch.nn.MultiMarginLoss` for details. """ - if not torch.jit.is_scripting(): - tens_ops = (input, target) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - multi_margin_loss, tens_ops, input, target, p=p, margin=margin, - weight=weight, size_average=size_average, reduce=reduce, - reduction=reduction) + if has_torch_function_variadic(input, target): + return handle_torch_function( + multi_margin_loss, + (input, target), + input, + target, + p=p, + margin=margin, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) if p != 1 and p != 2: - raise ValueError('only p == 1 and p == 2 supported') + raise ValueError("only p == 1 and p == 2 supported") if weight is not None: if weight.dim() != 1: - raise ValueError('weight must be one-dimensional') + raise ValueError("weight must be one-dimensional") return torch._C._nn.multi_margin_loss(input, target, p, margin, weight, reduction_enum) -pixel_shuffle = _add_docstr(torch.pixel_shuffle, r""" +pixel_shuffle = _add_docstr( + torch.pixel_shuffle, + r""" +pixel_shuffle(input, upscale_factor) -> Tensor + Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a -tensor of shape :math:`(*, C, H \times r, W \times r)`. +tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is the :attr:`upscale_factor`. See :class:`~torch.nn.PixelShuffle` for details. @@ -2847,9 +3125,38 @@ def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=N >>> output = torch.nn.functional.pixel_shuffle(input, 3) >>> print(output.size()) torch.Size([1, 1, 12, 12]) -""") +""", +) + +pixel_unshuffle = _add_docstr( + torch.pixel_unshuffle, + r""" +pixel_unshuffle(input, downscale_factor) -> Tensor + +Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements in a +tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape +:math:`(*, C \times r^2, H, W)`, where r is the :attr:`downscale_factor`. + +See :class:`~torch.nn.PixelUnshuffle` for details. + +Args: + input (Tensor): the input tensor + downscale_factor (int): factor to increase spatial resolution by + +Examples:: + + >>> input = torch.randn(1, 1, 12, 12) + >>> output = torch.nn.functional.pixel_unshuffle(input, 3) + >>> print(output.size()) + torch.Size([1, 9, 4, 4]) +""", +) + +channel_shuffle = _add_docstr( + torch.channel_shuffle, + r""" +channel_shuffle(input, groups) -> Tensor -channel_shuffle = _add_docstr(torch.channel_shuffle, r""" Divide the channels in a tensor of shape :math:`(*, C , H, W)` into g groups and rearrange them as :math:`(*, C \frac g, g, H, W)`, while keeping the original tensor shape. @@ -2884,20 +3191,23 @@ def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=N [[13, 14], [15, 16]], ]] -""") +""", +) + @_overload # noqa: F811 -def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=None): # noqa: F811 +def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=None): # noqa: F811 # type: (Tensor, Optional[int], Optional[float], str, Optional[bool]) -> Tensor pass + @_overload # noqa: F811 -def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=None): # noqa: F811 +def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=None): # noqa: F811 # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor pass -def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=None): # noqa: F811 +def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=None): # noqa: F811 r"""Upsamples the input to either the given :attr:`size` or the given :attr:`scale_factor` @@ -2906,9 +3216,7 @@ def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners= This is equivalent with ``nn.functional.interpolate(...)``. Note: - When using the CUDA backend, this operation may induce nondeterministic - behaviour in its backward pass that is not easily switched off. - Please see the notes on :doc:`/notes/randomness` for background. + {backward_reproducibility_note} The algorithm used for upsampling is determined by :attr:`mode`. @@ -2959,24 +3267,37 @@ def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners= warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.") return interpolate(input, size, scale_factor, mode, align_corners) + +upsample.__doc__ = upsample.__doc__.format(**reproducibility_notes) + + @_overload # noqa: F811 def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811 # type: (Tensor, Optional[int], Optional[List[float]], str, Optional[bool], Optional[bool]) -> Tensor pass + @_overload # noqa: F811 def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811 # type: (Tensor, Optional[List[int]], Optional[List[float]], str, Optional[bool], Optional[bool]) -> Tensor pass + @_overload # noqa: F811 def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811 # type: (Tensor, Optional[int], Optional[float], str, Optional[bool], Optional[bool]) -> Tensor pass + @_overload # noqa: F811 -def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811 - # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool], Optional[bool]) -> Tensor +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[List[int]] = None, + scale_factor: Optional[float] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, +) -> Tensor: # noqa: F811 pass def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811 @@ -3015,7 +3336,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne Default: ``False`` recompute_scale_factor (bool, optional): recompute the scale_factor for use in the interpolation calculation. When `scale_factor` is passed as a parameter, it is used - to compute the `output_size`. If `recompute_scale_factor` is ```False`` or not specified, + to compute the `output_size`. If `recompute_scale_factor` is ``False`` or not specified, the passed-in `scale_factor` will be used in the interpolation computation. Otherwise, a new `scale_factor` will be computed based on the output and input sizes for use in the interpolation computation (i.e. the computation will be identical to if the computed @@ -3047,27 +3368,34 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne calculation. Note: - When using the CUDA backend, this operation may induce nondeterministic - behaviour in its backward pass that is not easily switched off. - Please see the notes on :doc:`/notes/randomness` for background. + {backward_reproducibility_note} """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - interpolate, (input,), input, size=size, scale_factor=scale_factor, - mode=mode, align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor) - - if mode in ('nearest', 'area'): + if has_torch_function_unary(input): + return handle_torch_function( + interpolate, + (input,), + input, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor, + ) + + if mode in ("nearest", "area"): if align_corners is not None: - raise ValueError("align_corners option can only be set with the " - "interpolating modes: linear | bilinear | bicubic | trilinear") + raise ValueError( + "align_corners option can only be set with the " + "interpolating modes: linear | bilinear | bicubic | trilinear" + ) else: if align_corners is None: - warnings.warn("Default upsampling behavior when mode={} is changed " - "to align_corners=False since 0.4.0. Please specify " - "align_corners=True if the old behavior is desired. " - "See the documentation of nn.Upsample for details.".format(mode)) + warnings.warn( + "Default upsampling behavior when mode={} is changed " + "to align_corners=False since 0.4.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of nn.Upsample for details.".format(mode) + ) align_corners = False dim = input.dim() - 2 # Number of spatial dimensions. @@ -3077,14 +3405,15 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne # After this block, exactly one of output_size and scale_factors will # be non-None, and it will be a list (or tuple). if size is not None and scale_factor is not None: - raise ValueError('only one of size or scale_factor should be defined') + raise ValueError("only one of size or scale_factor should be defined") elif size is not None: assert scale_factor is None scale_factors = None if isinstance(size, (list, tuple)): if len(size) != dim: - raise ValueError('size shape must match input shape. ' - 'Input is {}D, size is {}'.format(dim, len(size))) + raise ValueError( + "size shape must match input shape. " "Input is {}D, size is {}".format(dim, len(size)) + ) output_size = size else: output_size = [size for _ in range(dim)] @@ -3093,13 +3422,15 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne output_size = None if isinstance(scale_factor, (list, tuple)): if len(scale_factor) != dim: - raise ValueError('scale_factor shape must match input shape. ' - 'Input is {}D, scale_factor is {}'.format(dim, len(scale_factor))) + raise ValueError( + "scale_factor shape must match input shape. " + "Input is {}D, scale_factor is {}".format(dim, len(scale_factor)) + ) scale_factors = scale_factor else: scale_factors = [scale_factor for _ in range(dim)] else: - raise ValueError('either size or scale_factor should be defined') + raise ValueError("either size or scale_factor should be defined") if recompute_scale_factor is None: # only warn when the scales have floating values since @@ -3107,11 +3438,13 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne if scale_factors is not None: for scale in scale_factors: if math.floor(scale) != scale: - warnings.warn("The default behavior for interpolate/upsample with float scale_factor changed " - "in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, " - "instead of relying on the computed output size. " - "If you wish to restore the old behavior, please set recompute_scale_factor=True. " - "See the documentation of nn.Upsample for details. ") + warnings.warn( + "The default behavior for interpolate/upsample with float scale_factor changed " + "in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, " + "instead of relying on the computed output size. " + "If you wish to restore the old behavior, please set recompute_scale_factor=True. " + "See the documentation of nn.Upsample for details. " + ) break elif recompute_scale_factor and size is not None: raise ValueError("recompute_scale_factor is not meaningful with an explicit size.") @@ -3126,70 +3459,80 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne # The C++ code will recompute it based on the (integer) output size. if not torch.jit.is_scripting() and torch._C._get_tracing_state(): # make scale_factor a tensor in tracing so constant doesn't get baked in - output_size = [(torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], - dtype=torch.float32)).float())) for i in range(dim)] + output_size = [ + (torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32)).float())) + for i in range(dim) + ] else: assert scale_factors is not None output_size = [int(math.floor(float(input.size(i + 2)) * scale_factors[i])) for i in range(dim)] scale_factors = None - if input.dim() == 3 and mode == 'nearest': + if input.dim() == 3 and mode == "nearest": return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors) - if input.dim() == 4 and mode == 'nearest': + if input.dim() == 4 and mode == "nearest": return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors) - if input.dim() == 5 and mode == 'nearest': + if input.dim() == 5 and mode == "nearest": return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors) - if input.dim() == 3 and mode == 'area': + if input.dim() == 3 and mode == "area": assert output_size is not None return adaptive_avg_pool1d(input, output_size) - if input.dim() == 4 and mode == 'area': + if input.dim() == 4 and mode == "area": assert output_size is not None return adaptive_avg_pool2d(input, output_size) - if input.dim() == 5 and mode == 'area': + if input.dim() == 5 and mode == "area": assert output_size is not None return adaptive_avg_pool3d(input, output_size) - if input.dim() == 3 and mode == 'linear': + if input.dim() == 3 and mode == "linear": assert align_corners is not None return torch._C._nn.upsample_linear1d(input, output_size, align_corners, scale_factors) - if input.dim() == 4 and mode == 'bilinear': + if input.dim() == 4 and mode == "bilinear": assert align_corners is not None return torch._C._nn.upsample_bilinear2d(input, output_size, align_corners, scale_factors) - if input.dim() == 5 and mode == 'trilinear': + if input.dim() == 5 and mode == "trilinear": assert align_corners is not None return torch._C._nn.upsample_trilinear3d(input, output_size, align_corners, scale_factors) - if input.dim() == 4 and mode == 'bicubic': + if input.dim() == 4 and mode == "bicubic": assert align_corners is not None return torch._C._nn.upsample_bicubic2d(input, output_size, align_corners, scale_factors) - if input.dim() == 3 and mode == 'bilinear': + if input.dim() == 3 and mode == "bilinear": raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input") - if input.dim() == 3 and mode == 'trilinear': + if input.dim() == 3 and mode == "trilinear": raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input") - if input.dim() == 4 and mode == 'linear': + if input.dim() == 4 and mode == "linear": raise NotImplementedError("Got 4D input, but linear mode needs 3D input") - if input.dim() == 4 and mode == 'trilinear': + if input.dim() == 4 and mode == "trilinear": raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input") - if input.dim() == 5 and mode == 'linear': + if input.dim() == 5 and mode == "linear": raise NotImplementedError("Got 5D input, but linear mode needs 3D input") - if input.dim() == 5 and mode == 'bilinear': + if input.dim() == 5 and mode == "bilinear": raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input") - raise NotImplementedError("Input Error: Only 3D, 4D and 5D input Tensors supported" - " (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear" - " (got {})".format(input.dim(), mode)) + raise NotImplementedError( + "Input Error: Only 3D, 4D and 5D input Tensors supported" + " (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear" + " (got {})".format(input.dim(), mode) + ) + + +interpolate.__doc__ = interpolate.__doc__.format(**reproducibility_notes) + @_overload # noqa: F811 def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 # type: (Tensor, Optional[int], Optional[float]) -> Tensor pass + @_overload # noqa: F811 def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 # type: (Tensor, Optional[List[int]], Optional[float]) -> Tensor pass + def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 r"""Upsamples the input, using nearest neighbours' pixel values. @@ -3207,34 +3550,44 @@ def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 scale_factor (int): multiplier for spatial size. Has to be an integer. Note: - When using the CUDA backend, this operation may induce nondeterministic - behaviour in its backward pass that is not easily switched off. - Please see the notes on :doc:`/notes/randomness` for background. + {backward_reproducibility_note} """ # DeprecationWarning is ignored by default warnings.warn("nn.functional.upsample_nearest is deprecated. Use nn.functional.interpolate instead.") - return interpolate(input, size, scale_factor, mode='nearest') + return interpolate(input, size, scale_factor, mode="nearest") + + +upsample_nearest.__doc__ = upsample_nearest.__doc__.format(**reproducibility_notes) + @_overload # noqa: F811 -def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 - # type: (Tensor, Optional[int], Optional[float]) -> Tensor +def upsample_bilinear( + input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None +) -> Tensor: # noqa: F811 pass + @_overload # noqa: F811 -def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 - # type: (Tensor, Optional[List[int]], Optional[float]) -> Tensor +def upsample_bilinear( # noqa: F811 + input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[float] = None +) -> Tensor: # noqa: F811 pass + @_overload # noqa: F811 -def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 - # type: (Tensor, Optional[int], Optional[List[float]]) -> Tensor +def upsample_bilinear( # noqa: F811 + input: Tensor, size: Optional[int] = None, scale_factor: Optional[List[float]] = None +) -> Tensor: # noqa: F811 pass + @_overload # noqa: F811 -def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 - # type: (Tensor, Optional[List[int]], Optional[List[float]]) -> Tensor +def upsample_bilinear( # noqa: F811 + input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[List[float]] = None +) -> Tensor: # noqa: F811 pass + def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 r"""Upsamples the input, using bilinear upsampling. @@ -3252,29 +3605,35 @@ def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 scale_factor (int or Tuple[int, int]): multiplier for spatial size Note: - When using the CUDA backend, this operation may induce nondeterministic - behaviour in its backward pass that is not easily switched off. - Please see the notes on :doc:`/notes/randomness` for background. + {backward_reproducibility_note} """ # DeprecationWarning is ignored by default warnings.warn("nn.functional.upsample_bilinear is deprecated. Use nn.functional.interpolate instead.") - return interpolate(input, size, scale_factor, mode='bilinear', align_corners=True) + return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True) + +upsample_bilinear.__doc__ = upsample_bilinear.__doc__.format(**reproducibility_notes) GRID_SAMPLE_INTERPOLATION_MODES = { - 'bilinear': 0, - 'nearest': 1, + "bilinear": 0, + "nearest": 1, + "bicubic": 2, } GRID_SAMPLE_PADDING_MODES = { - 'zeros': 0, - 'border': 1, - 'reflection': 2, + "zeros": 0, + "border": 1, + "reflection": 2, } -def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None): - # type: (Tensor, Tensor, str, str, Optional[bool]) -> Tensor +def grid_sample( + input: Tensor, + grid: Tensor, + mode: str = "bilinear", + padding_mode: str = "zeros", + align_corners: Optional[bool] = None, +) -> Tensor: r"""Given an :attr:`input` and a flow-field :attr:`grid`, computes the ``output`` using :attr:`input` values and pixel locations from :attr:`grid`. @@ -3330,8 +3689,9 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner grid (Tensor): flow-field of shape :math:`(N, H_\text{out}, W_\text{out}, 2)` (4-D case) or :math:`(N, D_\text{out}, H_\text{out}, W_\text{out}, 3)` (5-D case) mode (str): interpolation mode to calculate output values - ``'bilinear'`` | ``'nearest'``. Default: ``'bilinear'`` - Note: When ``mode='bilinear'`` and the input is 5-D, the interpolation mode + ``'bilinear'`` | ``'nearest'`` | ``'bicubic'``. Default: ``'bilinear'`` + Note: ``mode='bicubic'`` supports only 4-D input. + When ``mode='bilinear'`` and the input is 5-D, the interpolation mode used internally will actually be trilinear. However, when the input is 4-D, the interpolation mode will legitimately be bilinear. padding_mode (str): padding mode for outside grid values @@ -3361,45 +3721,61 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner The default behavior up to version 1.2.0 was ``align_corners = True``. Since then, the default behavior has been changed to ``align_corners = False``, in order to bring it in line with the default for :func:`interpolate`. + + .. note:: + ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\alpha=-0.75`. + The constant :math:`\alpha` might be different from packages to packages. + For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively. + This algorithm may "overshoot" the range of values it's interpolating. + For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255]. + Clamp the results with :func: `torch.clamp` to ensure they are within the valid range. + .. _`cubic convolution algorithm`: https://en.wikipedia.org/wiki/Bicubic_interpolation + .. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51 + .. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908 """ - if not torch.jit.is_scripting(): - tens_ops = (input, grid) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - grid_sample, tens_ops, input, grid, mode=mode, padding_mode=padding_mode, - align_corners=align_corners) - if mode != 'bilinear' and mode != 'nearest': - raise ValueError("nn.functional.grid_sample(): expected mode to be " - "'bilinear' or 'nearest', but got: '{}'".format(mode)) - if padding_mode != 'zeros' and padding_mode != 'border' and padding_mode != 'reflection': - raise ValueError("nn.functional.grid_sample(): expected padding_mode " - "to be 'zeros', 'border', or 'reflection', " - "but got: '{}'".format(padding_mode)) - - if mode == 'bilinear': + if has_torch_function_variadic(input, grid): + return handle_torch_function( + grid_sample, (input, grid), input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners + ) + if mode != "bilinear" and mode != "nearest" and mode != "bicubic": + raise ValueError( + "nn.functional.grid_sample(): expected mode to be " + "'bilinear', 'nearest' or 'bicubic', but got: '{}'".format(mode) + ) + if padding_mode != "zeros" and padding_mode != "border" and padding_mode != "reflection": + raise ValueError( + "nn.functional.grid_sample(): expected padding_mode " + "to be 'zeros', 'border', or 'reflection', " + "but got: '{}'".format(padding_mode) + ) + + if mode == "bilinear": mode_enum = 0 - else: # mode == 'nearest' + elif mode == "nearest": mode_enum = 1 + else: # mode == 'bicubic' + mode_enum = 2 - if padding_mode == 'zeros': + if padding_mode == "zeros": padding_mode_enum = 0 - elif padding_mode == 'border': + elif padding_mode == "border": padding_mode_enum = 1 else: # padding_mode == 'reflection' padding_mode_enum = 2 if align_corners is None: - warnings.warn("Default grid_sample and affine_grid behavior has changed " - "to align_corners=False since 1.3.0. Please specify " - "align_corners=True if the old behavior is desired. " - "See the documentation of grid_sample for details.") + warnings.warn( + "Default grid_sample and affine_grid behavior has changed " + "to align_corners=False since 1.3.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of grid_sample for details." + ) align_corners = False return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners) -def affine_grid(theta, size, align_corners=None): - # type: (Tensor, List[int], Optional[bool]) -> Tensor +def affine_grid(theta: Tensor, size: List[int], align_corners: Optional[bool] = None) -> Tensor: r"""Generates a 2D or 3D flow field (sampling grid), given a batch of affine matrices :attr:`theta`. @@ -3447,51 +3823,56 @@ def affine_grid(theta, size, align_corners=None): along a unit dimension are considered to be at ```0`` (the center of the input image). """ - if not torch.jit.is_scripting(): - if type(theta) is not Tensor and has_torch_function((theta,)): - return handle_torch_function( - affine_grid, (theta,), theta, size, align_corners=align_corners) + if has_torch_function_unary(theta): + return handle_torch_function(affine_grid, (theta,), theta, size, align_corners=align_corners) if align_corners is None: - warnings.warn("Default grid_sample and affine_grid behavior has changed " - "to align_corners=False since 1.3.0. Please specify " - "align_corners=True if the old behavior is desired. " - "See the documentation of grid_sample for details.") + warnings.warn( + "Default grid_sample and affine_grid behavior has changed " + "to align_corners=False since 1.3.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of grid_sample for details." + ) align_corners = False # enforce floating point dtype on theta if not theta.is_floating_point(): - raise ValueError("Expected theta to have floating point type, but got {}" - .format(theta.dtype)) + raise ValueError("Expected theta to have floating point type, but got {}".format(theta.dtype)) # check that shapes and sizes match if len(size) == 4: if theta.dim() != 3 or theta.shape[-2] != 2 or theta.shape[-1] != 3: - raise ValueError("Expected a batch of 2D affine matrices of shape Nx2x3 " - "for size {}. Got {}.".format(size, theta.shape)) + raise ValueError( + "Expected a batch of 2D affine matrices of shape Nx2x3 " + "for size {}. Got {}.".format(size, theta.shape) + ) spatial_size = size[-2:] # spatial dimension sizes elif len(size) == 5: if theta.dim() != 3 or theta.shape[-2] != 3 or theta.shape[-1] != 4: - raise ValueError("Expected a batch of 3D affine matrices of shape Nx3x4 " - "for size {}. Got {}.".format(size, theta.shape)) + raise ValueError( + "Expected a batch of 3D affine matrices of shape Nx3x4 " + "for size {}. Got {}.".format(size, theta.shape) + ) spatial_size = size[-3:] # spatial dimension sizes else: - raise NotImplementedError("affine_grid only supports 4D and 5D sizes, " - "for 2D and 3D affine transforms, respectively. " - "Got size {}.".format(size)) + raise NotImplementedError( + "affine_grid only supports 4D and 5D sizes, " + "for 2D and 3D affine transforms, respectively. " + "Got size {}.".format(size) + ) # check for empty span if align_corners and min(spatial_size) == 1: - warnings.warn("Since version 1.3.0, affine_grid behavior has changed " - "for unit-size grids when align_corners=True. " - "This is not an intended use case of affine_grid. " - "See the documentation of affine_grid for details.") + warnings.warn( + "Since version 1.3.0, affine_grid behavior has changed " + "for unit-size grids when align_corners=True. " + "This is not an intended use case of affine_grid. " + "See the documentation of affine_grid for details." + ) elif min(size) <= 0: - raise ValueError("Expected non-zero, positive output size. Got {}" - .format(size)) + raise ValueError("Expected non-zero, positive output size. Got {}".format(size)) return torch.affine_grid_generator(theta, size, align_corners) -def _pad(input, pad, mode='constant', value=0): - # type: (Tensor, List[int], str, float) -> Tensor +def _pad(input: Tensor, pad: List[int], mode: str = "constant", value: float = 0) -> Tensor: r"""Pads tensor. Padding size: @@ -3550,51 +3931,50 @@ def _pad(input, pad, mode='constant', value=0): torch.Size([3, 9, 7, 3]) """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - _pad, (input,), input, pad, mode=mode, value=value) - assert len(pad) % 2 == 0, 'Padding length must be divisible by 2' - assert len(pad) // 2 <= input.dim(), 'Padding length too large' - if mode == 'constant': + if has_torch_function_unary(input): + return handle_torch_function(_pad, (input,), input, pad, mode=mode, value=value) + assert len(pad) % 2 == 0, "Padding length must be divisible by 2" + assert len(pad) // 2 <= input.dim(), "Padding length too large" + if mode == "constant": return _VF.constant_pad_nd(input, pad, value) else: assert value == 0, 'Padding mode "{}"" doesn\'t take in value argument'.format(mode) if input.dim() == 3: - assert len(pad) == 2, '3D tensors expect 2 values for padding' - if mode == 'reflect': + assert len(pad) == 2, "3D tensors expect 2 values for padding" + if mode == "reflect": return torch._C._nn.reflection_pad1d(input, pad) - elif mode == 'replicate': + elif mode == "replicate": return torch._C._nn.replication_pad1d(input, pad) - elif mode == 'circular': + elif mode == "circular": return _pad_circular(input, pad) else: raise NotImplementedError elif input.dim() == 4: - assert len(pad) == 4, '4D tensors expect 4 values for padding' - if mode == 'reflect': + assert len(pad) == 4, "4D tensors expect 4 values for padding" + if mode == "reflect": return torch._C._nn.reflection_pad2d(input, pad) - elif mode == 'replicate': + elif mode == "replicate": return torch._C._nn.replication_pad2d(input, pad) - elif mode == 'circular': + elif mode == "circular": return _pad_circular(input, pad) else: raise NotImplementedError elif input.dim() == 5: - assert len(pad) == 6, '5D tensors expect 6 values for padding' - if mode == 'reflect': + assert len(pad) == 6, "5D tensors expect 6 values for padding" + if mode == "reflect": raise NotImplementedError - elif mode == 'replicate': + elif mode == "replicate": return torch._C._nn.replication_pad3d(input, pad) - elif mode == 'circular': + elif mode == "circular": return _pad_circular(input, pad) else: raise NotImplementedError else: raise NotImplementedError("Only 3D, 4D, 5D padding with non-constant padding are supported for now") + # We define this function as _pad because it takes an argument # named pad, which clobbers the recursive reference to the pad # function needed for __torch_function__ support @@ -3603,15 +3983,16 @@ def _pad(input, pad, mode='constant', value=0): # distance -def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False): - # type: (Tensor, Tensor, float, float, bool) -> Tensor +def pairwise_distance(x1: Tensor, x2: Tensor, p: float = 2.0, eps: float = 1e-6, keepdim: bool = False) -> Tensor: r""" See :class:`torch.nn.PairwiseDistance` for details """ return torch.pairwise_distance(x1, x2, p, eps, keepdim) -pdist = _add_docstr(torch.pdist, r""" +pdist = _add_docstr( + torch.pdist, + r""" pdist(input, p=2) -> Tensor Computes the p-norm distance between every pair of row vectors in the input. @@ -3632,10 +4013,13 @@ def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False): input: input tensor of shape :math:`N \times M`. p: p value for the p-norm distance to calculate between each vector pair :math:`\in [0, \infty]`. -""") +""", +) -cosine_similarity = _add_docstr(torch.cosine_similarity, r""" +cosine_similarity = _add_docstr( + torch.cosine_similarity, + r""" cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor Returns cosine similarity between x1 and x2, computed along dim. @@ -3660,10 +4044,13 @@ def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False): >>> input2 = torch.randn(100, 128) >>> output = F.cosine_similarity(input1, input2) >>> print(output) -""") +""", +) -one_hot = _add_docstr(torch._C._nn.one_hot, r""" +one_hot = _add_docstr( + torch._C._nn.one_hot, + r""" one_hot(tensor, num_classes=-1) -> LongTensor Takes LongTensor with index values of shape ``(*)`` and returns a tensor @@ -3707,45 +4094,78 @@ def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False): [1, 0, 0]], [[0, 1, 0], [0, 0, 1]]]) -""") - - -def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None, - reduce=None, reduction="mean"): - # type: (Tensor, Tensor, Tensor, float, float, float, bool, Optional[bool], Optional[bool], str) -> Tensor +""", +) + + +def triplet_margin_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + margin: float = 1.0, + p: float = 2, + eps: float = 1e-6, + swap: bool = False, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: r""" See :class:`~torch.nn.TripletMarginLoss` for details """ - if not torch.jit.is_scripting(): - tens_ops = (anchor, positive, negative) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - triplet_margin_loss, tens_ops, anchor, positive, negative, margin=margin, - p=p, eps=eps, swap=swap, size_average=size_average, reduce=reduce, - reduction=reduction) + if has_torch_function_variadic(anchor, positive, negative): + return handle_torch_function( + triplet_margin_loss, + (anchor, positive, negative), + anchor, + positive, + negative, + margin=margin, + p=p, + eps=eps, + swap=swap, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) - return torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, - swap, reduction_enum) - - -def triplet_margin_with_distance_loss(anchor, positive, negative, *, distance_function=None, - margin=1.0, swap=False, reduction="mean"): - # type: (Tensor, Tensor, Tensor, Optional[Callable[[Tensor, Tensor], Tensor]], float, bool, str) -> Tensor + return torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction_enum) + + +def triplet_margin_with_distance_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + *, + distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, + margin: float = 1.0, + swap: bool = False, + reduction: str = "mean" +) -> Tensor: r""" See :class:`~torch.nn.TripletMarginWithDistanceLoss` for details. """ if torch.jit.is_scripting(): - raise NotImplementedError("F.triplet_margin_with_distance_loss does not support JIT scripting: " - "functions requiring Callables cannot be scripted.") + raise NotImplementedError( + "F.triplet_margin_with_distance_loss does not support JIT scripting: " + "functions requiring Callables cannot be scripted." + ) - tens_ops = (anchor, positive, negative) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + if has_torch_function_variadic(anchor, positive, negative): return handle_torch_function( - triplet_margin_with_distance_loss, tens_ops, anchor, positive, negative, - distance_function=distance_function, margin=margin, swap=swap, reduction=reduction) + triplet_margin_with_distance_loss, + (anchor, positive, negative), + anchor, + positive, + negative, + distance_function=distance_function, + margin=margin, + swap=swap, + reduction=reduction, + ) distance_function = distance_function if distance_function is not None else pairwise_distance @@ -3767,8 +4187,7 @@ def triplet_margin_with_distance_loss(anchor, positive, negative, *, distance_fu return output -def normalize(input, p=2, dim=1, eps=1e-12, out=None): - # type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor +def normalize(input: Tensor, p: float = 2, dim: int = 1, eps: float = 1e-12, out: Optional[Tensor] = None) -> Tensor: r"""Performs :math:`L_p` normalization of inputs over specified dimension. For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each @@ -3787,10 +4206,8 @@ def normalize(input, p=2, dim=1, eps=1e-12, out=None): out (Tensor, optional): the output tensor. If :attr:`out` is used, this operation won't be differentiable. """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - normalize, (input,), input, p=p, dim=dim, eps=eps, out=out) + if has_torch_function_unary(input): + return handle_torch_function(normalize, (input,), input, p=p, dim=dim, eps=eps, out=out) if out is None: denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input) return input / denom @@ -3799,14 +4216,13 @@ def normalize(input, p=2, dim=1, eps=1e-12, out=None): return torch.div(input, denom, out=out) -def assert_int_or_pair(arg, arg_name, message): - # type: (List[int], str, str) -> None +def assert_int_or_pair(arg: List[int], arg_name: str, message: str) -> None: assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name) def unfold(input, kernel_size, dilation=1, padding=0, stride=1): # type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa - r"""Extracts sliding local blocks from an batched input tensor. + r"""Extracts sliding local blocks from a batched input tensor. .. warning:: Currently, only 4-D input tensors (batched image-like tensors) are @@ -3822,20 +4238,18 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1): See :class:`torch.nn.Unfold` for details """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - unfold, (input,), input, kernel_size, dilation=dilation, - padding=padding, stride=stride) + if has_torch_function_unary(input): + return handle_torch_function( + unfold, (input,), input, kernel_size, dilation=dilation, padding=padding, stride=stride + ) if input.dim() == 4: - msg = '{} must be int or 2-tuple for 4D input' - assert_int_or_pair(kernel_size, 'kernel_size', msg) - assert_int_or_pair(dilation, 'dilation', msg) - assert_int_or_pair(padding, 'padding', msg) - assert_int_or_pair(stride, 'stride', msg) - - return torch._C._nn.im2col(input, _pair(kernel_size), - _pair(dilation), _pair(padding), _pair(stride)) + msg = "{} must be int or 2-tuple for 4D input" + assert_int_or_pair(kernel_size, "kernel_size", msg) + assert_int_or_pair(dilation, "dilation", msg) + assert_int_or_pair(padding, "padding", msg) + assert_int_or_pair(stride, "stride", msg) + + return torch._C._nn.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride)) else: raise NotImplementedError("Input Error: Only 4D input Tensors are supported (got {}D)".format(input.dim())) @@ -3851,27 +4265,26 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1): See :class:`torch.nn.Fold` for details """ - if not torch.jit.is_scripting(): - if type(input) is not Tensor and has_torch_function((input,)): - return handle_torch_function( - fold, (input,), input, output_size, kernel_size, dilation=dilation, - padding=padding, stride=stride) + if has_torch_function_unary(input): + return handle_torch_function( + fold, (input,), input, output_size, kernel_size, dilation=dilation, padding=padding, stride=stride + ) if input.dim() == 3: - msg = '{} must be int or 2-tuple for 3D input' - assert_int_or_pair(output_size, 'output_size', msg) - assert_int_or_pair(kernel_size, 'kernel_size', msg) - assert_int_or_pair(dilation, 'dilation', msg) - assert_int_or_pair(padding, 'padding', msg) - assert_int_or_pair(stride, 'stride', msg) - - return torch._C._nn.col2im(input, _pair(output_size), _pair(kernel_size), - _pair(dilation), _pair(padding), _pair(stride)) + msg = "{} must be int or 2-tuple for 3D input" + assert_int_or_pair(output_size, "output_size", msg) + assert_int_or_pair(kernel_size, "kernel_size", msg) + assert_int_or_pair(dilation, "dilation", msg) + assert_int_or_pair(padding, "padding", msg) + assert_int_or_pair(stride, "stride", msg) + + return torch._C._nn.col2im( + input, _pair(output_size), _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride) + ) else: raise NotImplementedError("Input Error: Only 3D input Tensors are supported (got {}D)".format(input.dim())) -def _pad_circular(input, padding): - # type: (Tensor, List[int]) -> Tensor +def _pad_circular(input: Tensor, padding: List[int]) -> Tensor: """Circularly pads tensor. Tensor values at the beginning are used to pad the end, and values at the @@ -3925,21 +4338,19 @@ def _pad_circular(input, padding): for idx, size in enumerate(paddable_shape): # Only supports wrapping around once - assert padding[-(idx * 2 + 1)] <= size, \ - "Padding value causes wrapping around more than once." - assert padding[-(idx * 2 + 2)] <= size, \ - "Padding value causes wrapping around more than once." + assert padding[-(idx * 2 + 1)] <= size, "Padding value causes wrapping around more than once." + assert padding[-(idx * 2 + 2)] <= size, "Padding value causes wrapping around more than once." # Negative padding should not result in negative sizes - assert padding[-(idx * 2 + 1)] + padding[-(idx * 2 + 2)] + size >= 0, \ - "Negative padding value is resulting in an empty dimension." + assert ( + padding[-(idx * 2 + 1)] + padding[-(idx * 2 + 2)] + size >= 0 + ), "Negative padding value is resulting in an empty dimension." # Get shape of padded tensor out_shape = in_shape[:2] for idx, size in enumerate(paddable_shape): out_shape += (size + padding[-(idx * 2 + 1)] + padding[-(idx * 2 + 2)],) - out = torch.empty(out_shape, dtype=input.dtype, layout=input.layout, - device=input.device) + out = torch.empty(out_shape, dtype=input.dtype, layout=input.layout, device=input.device) # Put original array in padded array if ndim == 1: @@ -3963,8 +4374,7 @@ def _pad_circular(input, padding): in_h0 = max(-padding[-4], 0) in_h1 = in_shape[3] - max(-padding[-3], 0) - out[..., out_d0:out_d1, out_h0:out_h1] = \ - input[..., in_d0:in_d1, in_h0:in_h1] + out[..., out_d0:out_d1, out_h0:out_h1] = input[..., in_d0:in_d1, in_h0:in_h1] elif ndim == 3: out_d0 = max(padding[-2], 0) out_d1 = out_shape[2] - max(padding[-1], 0) @@ -3984,8 +4394,7 @@ def _pad_circular(input, padding): in_w0 = max(-padding[-6], 0) in_w1 = in_shape[4] - max(-padding[-5], 0) - out[..., out_d0:out_d1, out_h0:out_h1, out_w0:out_w1] = \ - input[..., in_d0:in_d1, in_h0:in_h1, in_w0:in_w1] + out[..., out_d0:out_d1, out_h0:out_h1, out_w0:out_w1] = input[..., in_d0:in_d1, in_h0:in_h1, in_w0:in_w1] # The following steps first pad the beginning of the tensor (left side), # and then pad the end of the tensor (right side). @@ -4015,15 +4424,13 @@ def _pad_circular(input, padding): i1 = out_shape[3] - max(padding[-3], 0) o0 = 0 o1 = padding[-4] - out[:, :, :, o0:o1] = \ - out[:, :, :, i0:i1] + out[:, :, :, o0:o1] = out[:, :, :, i0:i1] if padding[-3] > 0: i0 = max(padding[-4], 0) i1 = max(padding[-4], 0) + padding[-3] o0 = out_shape[3] - padding[-3] o1 = out_shape[3] - out[:, :, :, o0:o1] = \ - out[:, :, :, i0:i1] + out[:, :, :, o0:o1] = out[:, :, :, i0:i1] # Pad third dimension (width) if len(padding) > 4: @@ -4032,43 +4439,42 @@ def _pad_circular(input, padding): i1 = out_shape[4] - max(padding[-5], 0) o0 = 0 o1 = padding[-6] - out[:, :, :, :, o0:o1] = \ - out[:, :, :, :, i0:i1] + out[:, :, :, :, o0:o1] = out[:, :, :, :, i0:i1] if padding[-5] > 0: i0 = max(padding[-6], 0) i1 = max(padding[-6], 0) + padding[-5] o0 = out_shape[4] - padding[-5] o1 = out_shape[4] - out[:, :, :, :, o0:o1] = \ - out[:, :, :, :, i0:i1] + out[:, :, :, :, o0:o1] = out[:, :, :, :, i0:i1] return out -def multi_head_attention_forward(query: Tensor, - key: Tensor, - value: Tensor, - embed_dim_to_check: int, - num_heads: int, - in_proj_weight: Tensor, - in_proj_bias: Tensor, - bias_k: Optional[Tensor], - bias_v: Optional[Tensor], - add_zero_attn: bool, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - use_separate_proj_weight: bool = False, - q_proj_weight: Optional[Tensor] = None, - k_proj_weight: Optional[Tensor] = None, - v_proj_weight: Optional[Tensor] = None, - static_k: Optional[Tensor] = None, - static_v: Optional[Tensor] = None - ) -> Tuple[Tensor, Optional[Tensor]]: +def multi_head_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, +) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: query, key, value: map a query and a set of key-value pairs to an output. @@ -4125,19 +4531,35 @@ def multi_head_attention_forward(query: Tensor, - attn_output_weights: :math:`(N, L, S)` where N is the batch size, L is the target sequence length, S is the source sequence length. """ - if not torch.jit.is_scripting(): - tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, - out_proj_weight, out_proj_bias) - if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): - return handle_torch_function( - multi_head_attention_forward, tens_ops, query, key, value, - embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, - bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, - out_proj_bias, training=training, key_padding_mask=key_padding_mask, - need_weights=need_weights, attn_mask=attn_mask, - use_separate_proj_weight=use_separate_proj_weight, - q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight, - v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v) + tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) + if has_torch_function(tens_ops): + return handle_torch_function( + multi_head_attention_forward, + tens_ops, + query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias, + training=training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, + k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, + static_k=static_k, + static_v=static_v, + ) tgt_len, bsz, embed_dim = query.size() assert embed_dim == embed_dim_to_check # allow MHA to have different sizes for the feature dimension @@ -4148,11 +4570,11 @@ def multi_head_attention_forward(query: Tensor, scaling = float(head_dim) ** -0.5 if not use_separate_proj_weight: - if torch.equal(query, key) and torch.equal(key, value): + if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)): # self-attention q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) - elif torch.equal(key, value): + elif key is value or torch.equal(key, value): # encoder-decoder attention # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias @@ -4220,8 +4642,8 @@ def multi_head_attention_forward(query: Tensor, if in_proj_bias is not None: q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) - k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) - v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) + k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)]) + v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :]) else: q = linear(query, q_proj_weight_non_opt, in_proj_bias) k = linear(key, k_proj_weight_non_opt, in_proj_bias) @@ -4229,9 +4651,13 @@ def multi_head_attention_forward(query: Tensor, q = q * scaling if attn_mask is not None: - assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \ - attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \ - 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype) + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(attn_mask.dtype) if attn_mask.dtype == torch.uint8: warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") attn_mask = attn_mask.to(torch.bool) @@ -4239,17 +4665,19 @@ def multi_head_attention_forward(query: Tensor, if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError('The size of the 2D attn_mask is not correct.') + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]: - raise RuntimeError('The size of the 3D attn_mask is not correct.') + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + warnings.warn( + "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." + ) key_padding_mask = key_padding_mask.to(torch.bool) if bias_k is not None and bias_v is not None: @@ -4303,21 +4731,19 @@ def multi_head_attention_forward(query: Tensor, if attn_mask is not None: if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float('-inf')) + attn_output_weights.masked_fill_(attn_mask, float("-inf")) else: attn_output_weights += attn_mask - if key_padding_mask is not None: attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) attn_output_weights = attn_output_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2), - float('-inf'), + float("-inf"), ) attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) - attn_output_weights = softmax( - attn_output_weights, dim=-1) + attn_output_weights = softmax(attn_output_weights, dim=-1) attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) attn_output = torch.bmm(attn_output_weights, v) diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in index 215fb0278dc6c..4a38dca9bdb87 100644 --- a/torch/nn/functional.pyi.in +++ b/torch/nn/functional.pyi.in @@ -1,7 +1,7 @@ from torch import Tensor from torch.types import _size -from typing import Any, Optional, Tuple, Dict, List, Callable, Sequence -from .common_types import _ratio_any_t +from typing import Any, Optional, Tuple, Dict, List, Callable, Sequence, Union +from .common_types import _ratio_any_t, _size_any_t, _size_1_t, _size_2_t, _size_3_t, _size_2_opt_t, _size_3_opt_t # 'TypedDict' is a new accepted type that represents a dictionary with a fixed set of allowed keys. # It is standards-track but not in `typing` yet. We leave this hear to be uncommented once the feature @@ -63,11 +63,11 @@ def max_unpool3d(input: Tensor, indices: Tensor, kernel_size: _size, stride: Opt padding: _size = ..., output_size: Optional[_size] = ...) -> Tensor: ... -def lp_pool2d(input: Tensor, norm_type: float, kernel_size: int, stride: Optional[_size] = ..., +def lp_pool1d(input: Tensor, norm_type: float, kernel_size: _size_1_t, stride: Union[Optional[_size], Optional[int]] = ..., ceil_mode: bool = ...) -> Tensor: ... -def lp_pool1d(input: Tensor, norm_type: float, kernel_size: int, stride: Optional[_size] = ..., +def lp_pool2d(input: Tensor, norm_type: float, kernel_size: _size_2_t, stride: Union[Optional[_size], Optional[int]] = ..., ceil_mode: bool = ...) -> Tensor: ... @@ -75,18 +75,21 @@ def adaptive_max_pool1d_with_indices(input: Tensor, output_size: _size, return_i Tensor, Tensor]: ... -def adaptive_max_pool2d_with_indices(input: Tensor, output_size: _size, return_indices: bool = ...) -> Tuple[ +def adaptive_max_pool2d_with_indices(input: Tensor, output_size: _size_2_opt_t, return_indices: bool = ...) -> Tuple[ Tensor, Tensor]: ... -def adaptive_max_pool3d_with_indices(input: Tensor, output_size: _size, return_indices: bool = ...) -> Tuple[ +def adaptive_max_pool3d_with_indices(input: Tensor, output_size: _size_3_opt_t, return_indices: bool = ...) -> Tuple[ Tensor, Tensor]: ... -def adaptive_avg_pool2d(input: Tensor, output_size: _size) -> Tensor: ... +def adaptive_avg_pool1d(input: Tensor, output_size: _size_1_t) -> Tensor: ... -def adaptive_avg_pool3d(input: Tensor, output_size: _size) -> Tensor: ... +def adaptive_avg_pool2d(input: Tensor, output_size: _size_2_opt_t) -> Tensor: ... + + +def adaptive_avg_pool3d(input: Tensor, output_size: _size_3_opt_t) -> Tensor: ... def dropout(input: Tensor, p: float = ..., training: bool = ..., inplace: bool = ...) -> Tensor: ... @@ -189,7 +192,8 @@ def embedding(input: Tensor, weight: Tensor, padding_idx: Optional[int] = ..., m def embedding_bag(input: Tensor, weight: Tensor, offsets: Optional[Tensor] = ..., max_norm: Optional[float] = ..., norm_type: float = ..., scale_grad_by_freq: bool = ..., mode: str = ..., - sparse: bool = ...) -> Tensor: ... + sparse: bool = ..., per_sample_weights: Optional[Tensor] = ..., + include_last_offset: bool = ...) -> Tensor: ... def batch_norm(input: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor], weight: Optional[Tensor] = ..., bias: Optional[Tensor] = ..., training: bool = ..., @@ -201,7 +205,7 @@ def instance_norm(input: Tensor, running_mean: Optional[Tensor] = ..., running_v momentum: float = ..., eps: float = ...) -> Tensor: ... -def layer_norm(input: Tensor, normalized_shape: List[int], weight: Optional[Tensor] = ..., bias: Optional[Tensor] = ..., +def layer_norm(input: Tensor, normalized_shape: Sequence[int], weight: Optional[Tensor] = ..., bias: Optional[Tensor] = ..., eps: float = ...) -> Tensor: ... @@ -244,7 +248,7 @@ def binary_cross_entropy_with_logits(input: Tensor, target: Tensor, weight: Opti def smooth_l1_loss(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., reduce: Optional[bool] = ..., - reduction: str = ...) -> Tensor: ... + reduction: str = ..., beta: float = ...) -> Tensor: ... def l1_loss(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., reduce: Optional[bool] = ..., @@ -331,12 +335,12 @@ def normalize(input: Tensor, p: float = ..., dim: int = ..., eps: float = ..., def assert_int_or_pair(arg: Any, arg_name: Any, message: Any) -> None: ... -def unfold(input: Tensor, kernel_size: _size, dilation: _size = ..., padding: _size = ..., - stride: _size = ...) -> Tensor: ... +def unfold(input: Tensor, kernel_size: _size_any_t, dilation: _size_any_t = ..., padding: _size_any_t = ..., + stride: _size_any_t = ...) -> Tensor: ... -def fold(input: Tensor, output_size: _size, kernel_size: _size, dilation: _size = ..., padding: _size = ..., - stride: _size = ...) -> Tensor: ... +def fold(input: Tensor, output_size: _size_any_t, kernel_size: _size_any_t, dilation: _size_any_t = ..., padding: _size_any_t = ..., + stride: _size_any_t = ...) -> Tensor: ... def multi_head_attention_forward(query: Tensor, diff --git a/torch/nn/init.py b/torch/nn/init.py index c11dba648c5a9..080cfbedfef52 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -1,8 +1,8 @@ import math import warnings -import torch from torch import Tensor +import torch # These no_grad_* functions are necessary as wrappers around the parts of these @@ -77,6 +77,7 @@ def calculate_gain(nonlinearity, param=None): Tanh :math:`\frac{5}{3}` ReLU :math:`\sqrt{2}` Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` + SELU :math:`\frac{3}{4}` ================= ==================================================== Args: @@ -102,12 +103,13 @@ def calculate_gain(nonlinearity, param=None): else: raise ValueError("negative_slope {} not a valid number".format(param)) return math.sqrt(2.0 / (1 + negative_slope ** 2)) + elif nonlinearity == 'selu': + return 3.0 / 4 # Value found empirically (https://github.com/pytorch/pytorch/pull/50664) else: raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) -def uniform_(tensor, a=0., b=1.): - # type: (Tensor, float, float) -> Tensor +def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor: r"""Fills the input Tensor with values drawn from the uniform distribution :math:`\mathcal{U}(a, b)`. @@ -123,8 +125,7 @@ def uniform_(tensor, a=0., b=1.): return _no_grad_uniform_(tensor, a, b) -def normal_(tensor, mean=0., std=1.): - # type: (Tensor, float, float) -> Tensor +def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor: r"""Fills the input Tensor with values drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. @@ -139,8 +140,7 @@ def normal_(tensor, mean=0., std=1.): """ return _no_grad_normal_(tensor, mean, std) -def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): - # type: (Tensor, float, float, float, float) -> Tensor +def trunc_normal_(tensor: Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.) -> Tensor: r"""Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` @@ -162,8 +162,7 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): return _no_grad_trunc_normal_(tensor, mean, std, a, b) -def constant_(tensor, val): - # type: (Tensor, float) -> Tensor +def constant_(tensor: Tensor, val: float) -> Tensor: r"""Fills the input Tensor with the value :math:`\text{val}`. Args: @@ -177,8 +176,7 @@ def constant_(tensor, val): return _no_grad_fill_(tensor, val) -def ones_(tensor): - # type: (Tensor) -> Tensor +def ones_(tensor: Tensor) -> Tensor: r"""Fills the input Tensor with the scalar value `1`. Args: @@ -191,8 +189,7 @@ def ones_(tensor): return _no_grad_fill_(tensor, 1.) -def zeros_(tensor): - # type: (Tensor) -> Tensor +def zeros_(tensor: Tensor) -> Tensor: r"""Fills the input Tensor with the scalar value `0`. Args: @@ -284,8 +281,7 @@ def _calculate_fan_in_and_fan_out(tensor): return fan_in, fan_out -def xavier_uniform_(tensor, gain=1.): - # type: (Tensor, float) -> Tensor +def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor: r"""Fills the input `Tensor` with values according to the method described in `Understanding the difficulty of training deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform @@ -312,8 +308,7 @@ def xavier_uniform_(tensor, gain=1.): return _no_grad_uniform_(tensor, -a, a) -def xavier_normal_(tensor, gain=1.): - # type: (Tensor, float) -> Tensor +def xavier_normal_(tensor: Tensor, gain: float = 1.) -> Tensor: r"""Fills the input `Tensor` with values according to the method described in `Understanding the difficulty of training deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal diff --git a/torch/nn/intrinsic/__init__.py b/torch/nn/intrinsic/__init__.py index ba4514ee52c4b..270dcebaa5f4e 100644 --- a/torch/nn/intrinsic/__init__.py +++ b/torch/nn/intrinsic/__init__.py @@ -1,28 +1 @@ - -from .modules import ConvBn1d -from .modules import ConvBn2d -from .modules import ConvBn3d -from .modules import ConvBnReLU1d -from .modules import ConvBnReLU2d -from .modules import ConvBnReLU3d -from .modules import ConvReLU1d -from .modules import ConvReLU2d -from .modules import ConvReLU3d -from .modules import LinearReLU -from .modules import BNReLU2d -from .modules import BNReLU3d - -__all__ = [ - 'ConvBn1d', - 'ConvBn2d', - 'ConvBn3d', - 'ConvBnReLU2d', - 'ConvBnReLU1d', - 'ConvBnReLU3d', - 'ConvReLU1d', - 'ConvReLU2d', - 'ConvReLU3d', - 'LinearReLU', - 'BNReLU2d', - 'BNReLU3d', -] +from .modules import * diff --git a/torch/nn/intrinsic/modules/__init__.py b/torch/nn/intrinsic/modules/__init__.py index 51b2c1052a954..21536df250179 100644 --- a/torch/nn/intrinsic/modules/__init__.py +++ b/torch/nn/intrinsic/modules/__init__.py @@ -1,4 +1,4 @@ - +from .fused import _FusedModule from .fused import ConvBn1d from .fused import ConvBn2d from .fused import ConvBn3d @@ -14,6 +14,7 @@ __all__ = [ + '_FusedModule', 'ConvBn1d', 'ConvBn2d', 'ConvBn3d', diff --git a/torch/nn/intrinsic/modules/fused.py b/torch/nn/intrinsic/modules/fused.py index 949e02dbf5661..17ab2c31eb315 100644 --- a/torch/nn/intrinsic/modules/fused.py +++ b/torch/nn/intrinsic/modules/fused.py @@ -1,52 +1,56 @@ import torch from torch.nn import Conv1d, Conv2d, Conv3d, ReLU, Linear, BatchNorm1d, BatchNorm2d, BatchNorm3d -class ConvReLU1d(torch.nn.Sequential): - r"""This is a sequential container which calls the Conv 1d and ReLU modules. +# Used for identifying intrinsic modules used in quantization +class _FusedModule(torch.nn.Sequential): + pass + +class ConvReLU1d(_FusedModule): + r"""This is a sequential container which calls the Conv1d and ReLU modules. During quantization this will be replaced with the corresponding fused module.""" def __init__(self, conv, relu): assert type(conv) == Conv1d and type(relu) == ReLU, \ 'Incorrect types for input modules{}{}'.format( type(conv), type(relu)) - super(ConvReLU1d, self).__init__(conv, relu) + super().__init__(conv, relu) -class ConvReLU2d(torch.nn.Sequential): - r"""This is a sequential container which calls the Conv 2d and ReLU modules. +class ConvReLU2d(_FusedModule): + r"""This is a sequential container which calls the Conv2d and ReLU modules. During quantization this will be replaced with the corresponding fused module.""" def __init__(self, conv, relu): assert type(conv) == Conv2d and type(relu) == ReLU, \ 'Incorrect types for input modules{}{}'.format( type(conv), type(relu)) - super(ConvReLU2d, self).__init__(conv, relu) + super().__init__(conv, relu) -class ConvReLU3d(torch.nn.Sequential): - r"""This is a sequential container which calls the Conv 3d and ReLU modules. +class ConvReLU3d(_FusedModule): + r"""This is a sequential container which calls the Conv3d and ReLU modules. During quantization this will be replaced with the corresponding fused module.""" def __init__(self, conv, relu): assert type(conv) == Conv3d and type(relu) == ReLU, \ 'Incorrect types for input modules{}{}'.format( type(conv), type(relu)) - super(ConvReLU3d, self).__init__(conv, relu) + super().__init__(conv, relu) -class LinearReLU(torch.nn.Sequential): +class LinearReLU(_FusedModule): r"""This is a sequential container which calls the Linear and ReLU modules. During quantization this will be replaced with the corresponding fused module.""" def __init__(self, linear, relu): assert type(linear) == Linear and type(relu) == ReLU, \ 'Incorrect types for input modules{}{}'.format( type(linear), type(relu)) - super(LinearReLU, self).__init__(linear, relu) + super().__init__(linear, relu) -class ConvBn1d(torch.nn.Sequential): +class ConvBn1d(_FusedModule): r"""This is a sequential container which calls the Conv 1d and Batch Norm 1d modules. During quantization this will be replaced with the corresponding fused module.""" def __init__(self, conv, bn): assert type(conv) == Conv1d and type(bn) == BatchNorm1d, \ 'Incorrect types for input modules{}{}'.format( type(conv), type(bn)) - super(ConvBn1d, self).__init__(conv, bn) + super().__init__(conv, bn) -class ConvBn2d(torch.nn.Sequential): +class ConvBn2d(_FusedModule): r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules. During quantization this will be replaced with the corresponding fused module.""" def __init__(self, conv, bn): @@ -55,57 +59,57 @@ def __init__(self, conv, bn): type(conv), type(bn)) super(ConvBn2d, self).__init__(conv, bn) -class ConvBnReLU1d(torch.nn.Sequential): +class ConvBnReLU1d(_FusedModule): r"""This is a sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules. During quantization this will be replaced with the corresponding fused module.""" def __init__(self, conv, bn, relu): assert type(conv) == Conv1d and type(bn) == BatchNorm1d and \ type(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \ .format(type(conv), type(bn), type(relu)) - super(ConvBnReLU1d, self).__init__(conv, bn, relu) + super().__init__(conv, bn, relu) -class ConvBnReLU2d(torch.nn.Sequential): +class ConvBnReLU2d(_FusedModule): r"""This is a sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules. During quantization this will be replaced with the corresponding fused module.""" def __init__(self, conv, bn, relu): assert type(conv) == Conv2d and type(bn) == BatchNorm2d and \ type(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \ .format(type(conv), type(bn), type(relu)) - super(ConvBnReLU2d, self).__init__(conv, bn, relu) + super().__init__(conv, bn, relu) -class ConvBn3d(torch.nn.Sequential): +class ConvBn3d(_FusedModule): r"""This is a sequential container which calls the Conv 3d and Batch Norm 3d modules. During quantization this will be replaced with the corresponding fused module.""" def __init__(self, conv, bn): assert type(conv) == Conv3d and type(bn) == BatchNorm3d, \ 'Incorrect types for input modules{}{}'.format( type(conv), type(bn)) - super(ConvBn3d, self).__init__(conv, bn) + super().__init__(conv, bn) -class ConvBnReLU3d(torch.nn.Sequential): +class ConvBnReLU3d(_FusedModule): r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules. During quantization this will be replaced with the corresponding fused module.""" def __init__(self, conv, bn, relu): assert type(conv) == Conv3d and type(bn) == BatchNorm3d and \ type(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \ .format(type(conv), type(bn), type(relu)) - super(ConvBnReLU3d, self).__init__(conv, bn, relu) + super().__init__(conv, bn, relu) -class BNReLU2d(torch.nn.Sequential): +class BNReLU2d(_FusedModule): r"""This is a sequential container which calls the BatchNorm 2d and ReLU modules. During quantization this will be replaced with the corresponding fused module.""" def __init__(self, batch_norm, relu): assert type(batch_norm) == BatchNorm2d and type(relu) == ReLU, \ 'Incorrect types for input modules{}{}'.format( type(batch_norm), type(relu)) - super(BNReLU2d, self).__init__(batch_norm, relu) + super().__init__(batch_norm, relu) -class BNReLU3d(torch.nn.Sequential): +class BNReLU3d(_FusedModule): r"""This is a sequential container which calls the BatchNorm 3d and ReLU modules. During quantization this will be replaced with the corresponding fused module.""" def __init__(self, batch_norm, relu): assert type(batch_norm) == BatchNorm3d and type(relu) == ReLU, \ 'Incorrect types for input modules{}{}'.format( type(batch_norm), type(relu)) - super(BNReLU3d, self).__init__(batch_norm, relu) + super().__init__(batch_norm, relu) diff --git a/torch/nn/intrinsic/qat/__init__.py b/torch/nn/intrinsic/qat/__init__.py index d46ca956685c8..270dcebaa5f4e 100644 --- a/torch/nn/intrinsic/qat/__init__.py +++ b/torch/nn/intrinsic/qat/__init__.py @@ -1,14 +1 @@ -from .modules import LinearReLU -from .modules import ConvReLU2d -from .modules import ConvBn2d -from .modules import ConvBnReLU2d -from .modules import update_bn_stats, freeze_bn_stats - -__all__ = [ - 'ConvBn2d', - 'ConvBnReLU2d', - 'ConvReLU2d', - 'LinearReLU', - 'update_bn_stats', - 'freeze_bn_stats' -] +from .modules import * diff --git a/torch/nn/intrinsic/qat/modules/__init__.py b/torch/nn/intrinsic/qat/modules/__init__.py index bcbb865a56499..f0876e8ded566 100644 --- a/torch/nn/intrinsic/qat/modules/__init__.py +++ b/torch/nn/intrinsic/qat/modules/__init__.py @@ -1,11 +1,13 @@ - from .linear_relu import LinearReLU -from .conv_fused import ConvBn2d, ConvBnReLU2d, ConvReLU2d, update_bn_stats, freeze_bn_stats +from .conv_fused import ConvBn1d, ConvBn2d, ConvBnReLU1d, ConvBnReLU2d, ConvReLU2d, \ + update_bn_stats, freeze_bn_stats __all__ = [ 'LinearReLU', 'ConvReLU2d', + 'ConvBn1d', 'ConvBn2d', + 'ConvBnReLU1d', 'ConvBnReLU2d', 'update_bn_stats', 'freeze_bn_stats' diff --git a/torch/nn/intrinsic/qat/modules/conv_fused.py b/torch/nn/intrinsic/qat/modules/conv_fused.py index 5a8b0f042db19..12018a34e23f2 100644 --- a/torch/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/nn/intrinsic/qat/modules/conv_fused.py @@ -1,14 +1,21 @@ import math import torch import torch.nn as nn -import torch.nn.intrinsic +import torch.nn.intrinsic as nni import torch.nn.qat as nnqat import torch.nn.functional as F from torch.nn import init -from torch.nn.modules.utils import _pair +from torch.nn.modules.utils import _single, _pair from torch.nn.parameter import Parameter -class _ConvBnNd(nn.modules.conv._ConvNd): +_BN_CLASS_MAP = { + 1: nn.BatchNorm1d, + 2: nn.BatchNorm2d, + 3: nn.BatchNorm3d, +} + + +class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule): _version = 2 @@ -26,14 +33,15 @@ def __init__(self, # track_running_stats: True # Args for this module freeze_bn=False, - qconfig=None): + qconfig=None, + dim=2): nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, False, padding_mode) assert qconfig, 'qconfig must be provided for QAT module' self.qconfig = qconfig self.freeze_bn = freeze_bn if self.training else True - self.bn = nn.BatchNorm2d(out_channels, eps, momentum, True, True) + self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True) self.weight_fake_quant = self.qconfig.weight() if bias: self.bias = Parameter(torch.Tensor(out_channels)) @@ -80,12 +88,21 @@ def freeze_bn_stats(self): def _forward(self, input): running_std = torch.sqrt(self.bn.running_var + self.bn.eps) scale_factor = self.bn.weight / running_std - scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape([-1, 1, 1, 1])) - # this does not include the conv bias - conv = self._conv_forward(input, scaled_weight) - conv_orig = conv / scale_factor.reshape([1, -1, 1, 1]) + weight_shape = [1] * len(self.weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(self.weight.shape) + bias_shape[1] = -1 + scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape)) + # using zero bias here since the bias for original conv + # will be added later + if self.bias is not None: + zero_bias = torch.zeros_like(self.bias) + else: + zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device) + conv = self._conv_forward(input, scaled_weight, zero_bias) + conv_orig = conv / scale_factor.reshape(bias_shape) if self.bias is not None: - conv_orig = conv_orig + self.bias.reshape([1, -1, 1, 1]) + conv_orig = conv_orig + self.bias.reshape(bias_shape) conv = self.bn(conv_orig) return conv @@ -190,6 +207,92 @@ def from_float(cls, mod): qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked return qat_convbn +class ConvBn1d(_ConvBnNd, nn.Conv1d): + r""" + A ConvBn1d module is a module fused from Conv1d and BatchNorm1d, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv1d` and + :class:`torch.nn.BatchNorm1d`. + + Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized + to default. + + Attributes: + freeze_bn: + weight_fake_quant: fake quant module for weight + + """ + _FLOAT_MODULE = nni.ConvBn1d + + def __init__(self, + # Conv1d args + in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, + bias=None, + padding_mode='zeros', + # BatchNorm1d args + # num_features: out_channels + eps=1e-05, momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None): + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = _single(padding) + dilation = _single(dilation) + _ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, False, _single(0), groups, bias, padding_mode, + eps, momentum, freeze_bn, qconfig, dim=1) + +class ConvBnReLU1d(ConvBn1d): + r""" + A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv1d` and + :class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`. + + Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to + default. + + Attributes: + weight_fake_quant: fake quant module for weight + + """ + _FLOAT_MODULE = nni.ConvBnReLU1d + + def __init__(self, + # Conv1d args + in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, + bias=None, + padding_mode='zeros', + # BatchNorm1d args + # num_features: out_channels + eps=1e-05, momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, + padding_mode, eps, momentum, + freeze_bn, + qconfig) + + def forward(self, input): + return F.relu(ConvBn1d._forward(self, input)) + + @classmethod + def from_float(cls, mod): + return super(ConvBnReLU1d, cls).from_float(mod) + class ConvBn2d(_ConvBnNd, nn.Conv2d): r""" A ConvBn2d module is a module fused from Conv2d and BatchNorm2d, @@ -199,8 +302,6 @@ class ConvBn2d(_ConvBnNd, nn.Conv2d): We combined the interface of :class:`torch.nn.Conv2d` and :class:`torch.nn.BatchNorm2d`. - Implementation details: https://arxiv.org/pdf/1806.08342.pdf section 3.2.2 - Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized to default. @@ -209,7 +310,7 @@ class ConvBn2d(_ConvBnNd, nn.Conv2d): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = torch.nn.intrinsic.ConvBn2d + _FLOAT_MODULE = nni.ConvBn2d def __init__(self, # ConvNd args @@ -231,7 +332,7 @@ def __init__(self, dilation = _pair(dilation) _ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, False, _pair(0), groups, bias, padding_mode, - eps, momentum, freeze_bn, qconfig) + eps, momentum, freeze_bn, qconfig, dim=2) class ConvBnReLU2d(ConvBn2d): r""" @@ -242,8 +343,6 @@ class ConvBnReLU2d(ConvBn2d): We combined the interface of :class:`torch.nn.Conv2d` and :class:`torch.nn.BatchNorm2d` and :class:`torch.nn.ReLU`. - Implementation details: https://arxiv.org/pdf/1806.08342.pdf - Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to default. @@ -251,7 +350,7 @@ class ConvBnReLU2d(ConvBn2d): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = torch.nn.intrinsic.ConvBnReLU2d + _FLOAT_MODULE = nni.ConvBnReLU2d def __init__(self, # Conv2d args @@ -280,9 +379,8 @@ def forward(self, input): def from_float(cls, mod): return super(ConvBnReLU2d, cls).from_float(mod) -class ConvReLU2d(nnqat.Conv2d): - r""" - A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with +class ConvReLU2d(nnqat.Conv2d, nni._FusedModule): + r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with FakeQuantize modules for weight for quantization aware training. @@ -293,7 +391,7 @@ class ConvReLU2d(nnqat.Conv2d): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d + _FLOAT_MODULE = nni.ConvReLU2d def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, @@ -309,16 +407,16 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, def forward(self, input): return F.relu( - self._conv_forward(input, self.weight_fake_quant(self.weight))) + self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)) @classmethod def from_float(cls, mod): return super(ConvReLU2d, cls).from_float(mod) def update_bn_stats(mod): - if type(mod) in set([ConvBnReLU2d, ConvBn2d]): + if type(mod) in set([ConvBnReLU1d, ConvBnReLU2d, ConvBn1d, ConvBn2d]): mod.update_bn_stats() def freeze_bn_stats(mod): - if type(mod) in set([ConvBnReLU2d, ConvBn2d]): + if type(mod) in set([ConvBnReLU1d, ConvBnReLU2d, ConvBn1d, ConvBn2d]): mod.freeze_bn_stats() diff --git a/torch/nn/intrinsic/qat/modules/linear_relu.py b/torch/nn/intrinsic/qat/modules/linear_relu.py index b11072ddb7bea..ae54c36c45b20 100644 --- a/torch/nn/intrinsic/qat/modules/linear_relu.py +++ b/torch/nn/intrinsic/qat/modules/linear_relu.py @@ -1,8 +1,8 @@ import torch.nn.qat as nnqat -import torch.nn.intrinsic +import torch.nn.intrinsic as nni import torch.nn.functional as F -class LinearReLU(nnqat.Linear): +class LinearReLU(nnqat.Linear, nni._FusedModule): r""" A LinearReLU module fused from Linear and ReLU modules, attached with FakeQuantize modules for weight, used in @@ -24,7 +24,7 @@ class LinearReLU(nnqat.Linear): >>> print(output.size()) torch.Size([128, 30]) """ - _FLOAT_MODULE = torch.nn.intrinsic.LinearReLU + _FLOAT_MODULE = nni.LinearReLU def __init__(self, in_features, out_features, bias=True, qconfig=None): diff --git a/torch/nn/intrinsic/quantized/__init__.py b/torch/nn/intrinsic/quantized/__init__.py index 67cb2098cfa97..270dcebaa5f4e 100644 --- a/torch/nn/intrinsic/quantized/__init__.py +++ b/torch/nn/intrinsic/quantized/__init__.py @@ -1,12 +1 @@ -from .modules import LinearReLU -from .modules import ConvReLU1d, ConvReLU2d, ConvReLU3d -from .modules import BNReLU2d, BNReLU3d - -__all__ = [ - 'LinearReLU', - 'ConvReLU1d', - 'ConvReLU2d', - 'ConvReLU3d', - 'BNReLU2d', - 'BNReLU3d', -] +from .modules import * diff --git a/torch/nn/intrinsic/quantized/_reference/__init__.py b/torch/nn/intrinsic/quantized/_reference/__init__.py new file mode 100644 index 0000000000000..270dcebaa5f4e --- /dev/null +++ b/torch/nn/intrinsic/quantized/_reference/__init__.py @@ -0,0 +1 @@ +from .modules import * diff --git a/torch/nn/intrinsic/quantized/_reference/modules/__init__.py b/torch/nn/intrinsic/quantized/_reference/modules/__init__.py new file mode 100644 index 0000000000000..ce571862b4275 --- /dev/null +++ b/torch/nn/intrinsic/quantized/_reference/modules/__init__.py @@ -0,0 +1,6 @@ +import torch +from .linear_relu import LinearReLU + +__all__ = [ + 'LinearReLU', +] diff --git a/torch/nn/intrinsic/quantized/_reference/modules/linear_relu.py b/torch/nn/intrinsic/quantized/_reference/modules/linear_relu.py new file mode 100644 index 0000000000000..f8dab5900d246 --- /dev/null +++ b/torch/nn/intrinsic/quantized/_reference/modules/linear_relu.py @@ -0,0 +1,25 @@ +import torch +import torch.nn.quantized._reference as nnqr +import torch.nn.functional as F + +class LinearReLU(nnqr.Linear): + def __init__( + self, + in_features, + out_features, + bias=True, + dtype=torch.qint8): + super().__init__(in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dequant = x.dequantize() + weight_dequant = self._qweight.dequantize() + float_result = F.linear(x_dequant, weight_dequant, self._bias) + float_result = F.relu(float_result, inplace=True) + # NEEDFIX: we don't have dtype in the Linear module APIs right now! + result = torch.quantize_per_tensor( + float_result, self.scale, self.zero_point, torch.quint8) + return result + + def _get_name(self): + return "QuantizedLinearReLU(Reference)" diff --git a/torch/nn/intrinsic/quantized/modules/__init__.py b/torch/nn/intrinsic/quantized/modules/__init__.py index 34d0691b25f44..521e409b2b642 100644 --- a/torch/nn/intrinsic/quantized/modules/__init__.py +++ b/torch/nn/intrinsic/quantized/modules/__init__.py @@ -1,4 +1,3 @@ - from .linear_relu import LinearReLU from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d from .bn_relu import BNReLU2d, BNReLU3d diff --git a/torch/nn/intrinsic/quantized/modules/conv_relu.py b/torch/nn/intrinsic/quantized/modules/conv_relu.py index b75caf08ad5de..8dd931ff05a83 100644 --- a/torch/nn/intrinsic/quantized/modules/conv_relu.py +++ b/torch/nn/intrinsic/quantized/modules/conv_relu.py @@ -16,7 +16,7 @@ class ConvReLU1d(nnq.Conv1d): Same as torch.nn.quantized.Conv1d """ - _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU1d + _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU1d # type: ignore[assignment] def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, @@ -39,6 +39,10 @@ def _get_name(self): @classmethod def from_float(cls, mod): + if type(mod) == torch.nn.intrinsic.qat.ConvBnReLU1d: + mod.weight, mod.bias = fuse_conv_bn_weights( + mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var, + mod.bn.eps, mod.bn.weight, mod.bn.bias) return super(ConvReLU1d, cls).from_float(mod) class ConvReLU2d(nnq.Conv2d): @@ -51,7 +55,7 @@ class ConvReLU2d(nnq.Conv2d): Same as torch.nn.quantized.Conv2d """ - _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d + _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d # type: ignore[assignment] def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, @@ -90,7 +94,7 @@ class ConvReLU3d(nnq.Conv3d): Attributes: Same as torch.nn.quantized.Conv3d """ - _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU3d + _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU3d # type: ignore[assignment] def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, diff --git a/torch/nn/intrinsic/quantized/modules/linear_relu.py b/torch/nn/intrinsic/quantized/modules/linear_relu.py index f646437114779..08e0aecceb30d 100644 --- a/torch/nn/intrinsic/quantized/modules/linear_relu.py +++ b/torch/nn/intrinsic/quantized/modules/linear_relu.py @@ -1,6 +1,6 @@ -import torch.nn.quantized as nnq -import torch.nn.intrinsic import torch +import torch.nn.quantized as nnq +import torch.nn.intrinsic as nni class LinearReLU(nnq.Linear): r""" @@ -19,17 +19,14 @@ class LinearReLU(nnq.Linear): >>> print(output.size()) torch.Size([128, 30]) """ - _FLOAT_MODULE = torch.nn.intrinsic.LinearReLU + _FLOAT_MODULE = nni.LinearReLU def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8): - super(LinearReLU, self).__init__(in_features, out_features, bias, dtype) - - def forward(self, input): - Y_q = torch.ops.quantized.linear_relu( - input, self._packed_params._packed_params, - float(self.scale), - int(self.zero_point)) - return Y_q + super().__init__(in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.linear_relu( + x, self._packed_params._packed_params, self.scale, self.zero_point) def _get_name(self): return 'QuantizedLinearReLU' diff --git a/torch/nn/modules/__init__.py b/torch/nn/modules/__init__.py index 06a565700550b..4911d4bef38f7 100644 --- a/torch/nn/modules/__init__.py +++ b/torch/nn/modules/__init__.py @@ -1,7 +1,8 @@ from .module import Module -from .linear import Identity, Linear, Bilinear +from .linear import Identity, Linear, Bilinear, LazyLinear from .conv import Conv1d, Conv2d, Conv3d, \ - ConvTranspose1d, ConvTranspose2d, ConvTranspose3d + ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, \ + LazyConv1d, LazyConv2d, LazyConv3d, LazyConvTranspose1d, LazyConvTranspose2d, LazyConvTranspose3d from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \ Softmax, Softmax2d, LogSoftmax, ELU, SELU, CELU, GELU, Hardshrink, LeakyReLU, LogSigmoid, \ Softplus, Softshrink, MultiheadAttention, PReLU, Softsign, Softmin, Tanhshrink, RReLU, GLU, \ @@ -23,7 +24,7 @@ from .sparse import Embedding, EmbeddingBag from .rnn import RNNBase, RNN, LSTM, GRU, \ RNNCellBase, RNNCell, LSTMCell, GRUCell -from .pixelshuffle import PixelShuffle +from .pixelshuffle import PixelShuffle, PixelUnshuffle from .upsampling import UpsamplingNearest2d, UpsamplingBilinear2d, Upsample from .distance import PairwiseDistance, CosineSimilarity from .fold import Fold, Unfold @@ -31,6 +32,7 @@ from .transformer import TransformerEncoder, TransformerDecoder, \ TransformerEncoderLayer, TransformerDecoderLayer, Transformer from .flatten import Flatten, Unflatten +from .channelshuffle import ChannelShuffle __all__ = [ 'Module', 'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', @@ -48,11 +50,13 @@ 'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout', 'ReflectionPad1d', 'ReflectionPad2d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d', 'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', - 'LSTMCell', 'GRUCell', 'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d', + 'LSTMCell', 'GRUCell', 'PixelShuffle', 'PixelUnshuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d', 'PairwiseDistance', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold', 'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer', - 'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginWithDistanceLoss' + 'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d', + 'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d', + 'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle' ] diff --git a/torch/nn/modules/_functions.py b/torch/nn/modules/_functions.py index 4df826692d0fe..e9424673dda17 100644 --- a/torch/nn/modules/_functions.py +++ b/torch/nn/modules/_functions.py @@ -9,13 +9,14 @@ class SyncBatchNorm(Function): def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size): input = input.contiguous() - count = torch.empty(1, - dtype=running_mean.dtype, - device=input.device).fill_(input.numel() // input.size(1)) - # calculate mean/invstd for input. mean, invstd = torch.batch_norm_stats(input, eps) + count = torch.full((1,), input.numel() // input.size(1), + dtype=mean.dtype, + device=mean.device) + + num_channels = input.shape[1] # C, C, 1 -> (2C + 1) combined = torch.cat([mean, invstd, count], dim=0) @@ -196,3 +197,13 @@ def backward(ctx, grad_output): accum_ratio.add_(paddded_ratio[c], alpha=-1) return grad_input, None, None, None, None + +class BackwardHookFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, *args): + ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad]) + return args + + @staticmethod + def backward(ctx, *args): + return args diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index a3aef5d7e41d1..d666c6d01c56a 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1,5 +1,5 @@ import warnings -from typing import Tuple, Optional +from typing import Optional, Tuple import torch from torch import Tensor @@ -365,11 +365,11 @@ class SiLU(Module): \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.} .. note:: - See `Gaussian Error Linear Units (GELUs) `_ - where the SiLU (Sigmoid Linear Unit) was originally coined, and see - `Sigmoid-Weighted Linear Units for Neural Network Function Approximation - in Reinforcement Learning `_ and `Swish: - a Self-Gated Activation Function `_ + See `Gaussian Error Linear Units (GELUs) `_ + where the SiLU (Sigmoid Linear Unit) was originally coined, and see + `Sigmoid-Weighted Linear Units for Neural Network Function Approximation + in Reinforcement Learning `_ and `Swish: + a Self-Gated Activation Function `_ where the SiLU was experimented with later. Shape: @@ -831,11 +831,12 @@ def extra_repr(self) -> str: class MultiheadAttention(Module): r"""Allows the model to jointly attend to information from different representation subspaces. - See reference: Attention Is All You Need + See `Attention Is All You Need `_ .. math:: \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O - \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. Args: embed_dim: total dimension of the model. @@ -848,8 +849,9 @@ class MultiheadAttention(Module): kdim: total number of features in key. Default: None. vdim: total number of features in value. Default: None. - Note: if kdim and vdim are None, they will be set to embed_dim such that - query, key, and value have the same number of features. + Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set + to :attr:`embed_dim` such that query, key, and value have the same + number of features. Examples:: @@ -921,9 +923,8 @@ def __setstate__(self, state): super(MultiheadAttention, self).__setstate__(state) - def forward(self, query, key, value, key_padding_mask=None, - need_weights=True, attn_mask=None): - # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] + def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: query, key, value: map a query and a set of key-value pairs to an output. @@ -937,8 +938,7 @@ def forward(self, query, key, value, key_padding_mask=None, attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. - Shape: - - Inputs: + Shapes for inputs: - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is @@ -949,15 +949,17 @@ def forward(self, query, key, value, key_padding_mask=None, If a ByteTensor is provided, the non-zero positions will be ignored while the position with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the + source sequence length. + + If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence + length, S is the source sequence length. ``attn_mask`` ensure that position i is allowed to attend + the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor is provided, it will be added to the attention weight. - - Outputs: + Shapes for outputs: - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. - attn_output_weights: :math:`(N, L, S)` where N is the batch size, @@ -1112,7 +1114,7 @@ class Softmin(Module): dimensions - Output: :math:`(*)`, same shape as the input - Arguments: + Args: dim (int): A dimension along which Softmin will be computed (so every slice along dim will sum to 1). @@ -1166,7 +1168,7 @@ class Softmax(Module): a Tensor of the same dimension and shape as the input with values in the range [0, 1] - Arguments: + Args: dim (int): A dimension along which Softmax will be computed (so every slice along dim will sum to 1). @@ -1240,7 +1242,7 @@ class LogSoftmax(Module): dimensions - Output: :math:`(*)`, same shape as the input - Arguments: + Args: dim (int): A dimension along which LogSoftmax will be computed. Returns: diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index f5ca6deb5b19c..64417069e2b71 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -54,9 +54,11 @@ def __init__( def reset_running_stats(self) -> None: if self.track_running_stats: - self.running_mean.zero_() - self.running_var.fill_(1) - self.num_batches_tracked.zero_() + # running_mean/running_var/num_batches... are registered at runtime depending + # if self.track_running_stats is on + self.running_mean.zero_() # type: ignore[operator] + self.running_var.fill_(1) # type: ignore[operator] + self.num_batches_tracked.zero_() # type: ignore[operator] def reset_parameters(self) -> None: self.reset_running_stats() @@ -107,8 +109,8 @@ def forward(self, input: Tensor) -> Tensor: if self.training and self.track_running_stats: # TODO: if statement only here to tell the jit to skip emitting this when it is None - if self.num_batches_tracked is not None: - self.num_batches_tracked = self.num_batches_tracked + 1 + if self.num_batches_tracked is not None: # type: ignore + self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore if self.momentum is None: # use cumulative moving average exponential_average_factor = 1.0 / float(self.num_batches_tracked) else: # use exponential moving average @@ -128,6 +130,8 @@ def forward(self, input: Tensor) -> Tensor: passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are used for normalization (i.e. in eval mode when buffers are not None). """ + assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor) + assert self.running_var is None or isinstance(self.running_var, torch.Tensor) return F.batch_norm( input, # If buffers are not to be tracked, ensure that they won't be updated @@ -430,8 +434,14 @@ class SyncBatchNorm(_BatchNorm): >>> # With Learnable Parameters >>> m = nn.SyncBatchNorm(100) >>> # creating process group (optional) - >>> # process_ids is a list of int identifying rank ids. - >>> process_group = torch.distributed.new_group(process_ids) + >>> # ranks is a list of int identifying rank ids. + >>> ranks = list(range(8)) + >>> r1, r2 = ranks[:4], ranks[4:] + >>> # Note: every rank calls into new_group for every + >>> # process group created, even if that rank is not + >>> # part of the group. + >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]] + >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1] >>> # Without Learnable Parameters >>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group) >>> input = torch.randn(20, 100, 35, 45, 10) @@ -487,6 +497,7 @@ def forward(self, input: Tensor) -> Tensor: exponential_average_factor = self.momentum if self.training and self.track_running_stats: + assert self.num_batches_tracked is not None self.num_batches_tracked = self.num_batches_tracked + 1 if self.momentum is None: # use cumulative moving average exponential_average_factor = 1.0 / self.num_batches_tracked.item() @@ -508,6 +519,8 @@ def forward(self, input: Tensor) -> Tensor: used for normalization (i.e. in eval mode when buffers are not None). """ # If buffers are not to be tracked, ensure that they won't be updated + assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor) + assert self.running_var is None or isinstance(self.running_var, torch.Tensor) running_mean = self.running_mean if not self.training or self.track_running_stats else None running_var = self.running_var if not self.training or self.track_running_stats else None @@ -557,8 +570,14 @@ def convert_sync_batchnorm(cls, module, process_group=None): >>> torch.nn.BatchNorm1d(100), >>> ).cuda() >>> # creating process group (optional) - >>> # process_ids is a list of int identifying rank ids. - >>> process_group = torch.distributed.new_group(process_ids) + >>> # ranks is a list of int identifying rank ids. + >>> ranks = list(range(8)) + >>> r1, r2 = ranks[:4], ranks[4:] + >>> # Note: every rank calls into new_group for every + >>> # process group created, even if that rank is not + >>> # part of the group. + >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]] + >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1] >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group) """ diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index f5d07ae4a69ce..2ead84fe73e8a 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -125,7 +125,7 @@ class ModuleList(Module): modules it contains are properly registered, and will be visible by all :class:`~torch.nn.Module` methods. - Arguments: + Args: modules (iterable, optional): an iterable of modules to add Example:: @@ -197,7 +197,7 @@ def __dir__(self): def insert(self, index: int, module: Module) -> None: r"""Insert a given module before a given index in the list. - Arguments: + Args: index (int): index to insert. module (nn.Module): module to insert """ @@ -208,7 +208,7 @@ def insert(self, index: int, module: Module) -> None: def append(self: T, module: Module) -> T: r"""Appends a given module to the end of the list. - Arguments: + Args: module (nn.Module): module to append """ self.add_module(str(len(self)), module) @@ -217,7 +217,7 @@ def append(self: T, module: Module) -> T: def extend(self: T, modules: Iterable[Module]) -> T: r"""Appends modules from a Python iterable to the end of the list. - Arguments: + Args: modules (iterable): iterable of modules to append """ if not isinstance(modules, container_abcs.Iterable): @@ -252,7 +252,7 @@ class ModuleDict(Module): types (e.g., Python's plain ``dict`` before Python version 3.6) does not preserve the order of the merged mapping. - Arguments: + Args: modules (iterable, optional): a mapping (dictionary) of (string: module) or an iterable of key-value pairs of type (string, module) @@ -311,7 +311,7 @@ def clear(self) -> None: def pop(self, key: str) -> Module: r"""Remove key from the ModuleDict and return its module. - Arguments: + Args: key (string): key to pop from the ModuleDict """ v = self[key] @@ -344,7 +344,7 @@ def update(self, modules: Mapping[str, Module]) -> None: If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or an iterable of key-value pairs, the order of new elements in it is preserved. - Arguments: + Args: modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`, or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`) """ @@ -379,7 +379,7 @@ class ParameterList(Module): list, but parameters it contains are properly registered, and will be visible by all :class:`~torch.nn.Module` methods. - Arguments: + Args: parameters (iterable, optional): an iterable of :class:`~torch.nn.Parameter` to add Example:: @@ -398,9 +398,15 @@ def forward(self, x): def __init__(self, parameters: Optional[Iterable['Parameter']] = None) -> None: super(ParameterList, self).__init__() + self._initialized = True if parameters is not None: self += parameters + def __setstate__(self, state): + state['_initialized'] = False + super(ParameterList, self).__setstate__(state) + self._initialized = True + def _get_abs_string_index(self, idx): """Get the absolute index for the list of modules""" idx = operator.index(idx) @@ -430,8 +436,9 @@ def __setitem__(self, idx: int, param: 'Parameter') -> None: return self.register_parameter(str(idx), param) def __setattr__(self, key: Any, value: Any) -> None: - if not isinstance(value, torch.nn.Parameter): - warnings.warn("Setting attributes on ParameterList is not supported.") + if getattr(self, "_initialized", False): + if not hasattr(self, key) and not isinstance(value, torch.nn.Parameter): + warnings.warn("Setting attributes on ParameterList is not supported.") super(ParameterList, self).__setattr__(key, value) def __len__(self) -> int: @@ -451,7 +458,7 @@ def __dir__(self): def append(self: T, parameter: 'Parameter') -> T: """Appends a given parameter at the end of the list. - Arguments: + Args: parameter (nn.Parameter): parameter to append """ self.register_parameter(str(len(self)), parameter) @@ -460,7 +467,7 @@ def append(self: T, parameter: 'Parameter') -> T: def extend(self: T, parameters: Iterable['Parameter']) -> T: """Appends parameters from a Python iterable to the end of the list. - Arguments: + Args: parameters (iterable): iterable of parameters to append """ if not isinstance(parameters, container_abcs.Iterable): @@ -511,7 +518,7 @@ class ParameterDict(Module): types (e.g., Python's plain ``dict``) does not preserve the order of the merged mapping. - Arguments: + Args: parameters (iterable, optional): a mapping (dictionary) of (string : :class:`~torch.nn.Parameter`) or an iterable of key-value pairs of type (string, :class:`~torch.nn.Parameter`) @@ -533,9 +540,15 @@ def forward(self, x, choice): def __init__(self, parameters: Optional[Mapping[str, 'Parameter']] = None) -> None: super(ParameterDict, self).__init__() + self._initialized = True if parameters is not None: self.update(parameters) + def __setstate__(self, state): + state['_initialized'] = False + super(ParameterDict, self).__setstate__(state) + self._initialized = True + def __getitem__(self, key: str) -> 'Parameter': return self._parameters[key] @@ -546,8 +559,9 @@ def __delitem__(self, key: str) -> None: del self._parameters[key] def __setattr__(self, key: Any, value: Any) -> None: - if not isinstance(value, torch.nn.Parameter): - warnings.warn("Setting attributes on ParameterDict is not supported.") + if getattr(self, "_initialized", False): + if not hasattr(self, key) and not isinstance(value, torch.nn.Parameter): + warnings.warn("Setting attributes on ParameterDict is not supported.") super(ParameterDict, self).__setattr__(key, value) def __len__(self) -> int: @@ -567,7 +581,7 @@ def clear(self) -> None: def pop(self, key: str) -> 'Parameter': r"""Remove key from the ParameterDict and return its parameter. - Arguments: + Args: key (string): key to pop from the ParameterDict """ v = self[key] @@ -597,7 +611,7 @@ def update(self, parameters: Mapping[str, 'Parameter']) -> None: If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or an iterable of key-value pairs, the order of new elements in it is preserved. - Arguments: + Args: parameters (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Parameter`, or an iterable of key-value pairs of type (string, :class:`~torch.nn.Parameter`) diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 3b9391d1061cb..90023b7a4346b 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -4,15 +4,41 @@ import torch from torch import Tensor -from torch.nn.parameter import Parameter +from torch.nn.parameter import Parameter, UninitializedParameter from .. import functional as F from .. import init +from .lazy import LazyModuleMixin from .module import Module from .utils import _single, _pair, _triple, _reverse_repeat_tuple +from torch._torch_docs import reproducibility_notes from ..common_types import _size_1_t, _size_2_t, _size_3_t from typing import Optional, List, Tuple +convolution_notes = \ + {"groups_note": r"""* :attr:`groups` controls the connections between inputs and outputs. + :attr:`in_channels` and :attr:`out_channels` must both be divisible by + :attr:`groups`. For example, + + * At groups=1, all inputs are convolved to all outputs. + * At groups=2, the operation becomes equivalent to having two conv + layers side by side, each seeing half the input channels + and producing half the output channels, and both subsequently + concatenated. + * At groups= :attr:`in_channels`, each input channel is convolved with + its own set of filters (of size + :math:`\frac{\text{out\_channels}}{\text{in\_channels}}`).""", + + "depthwise_separable_note": r"""When `groups == in_channels` and `out_channels == K * in_channels`, + where `K` is a positive integer, this operation is also known as a "depthwise convolution". + + In other words, for an input of size :math:`(N, C_{in}, L_{in})`, + a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments + :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`."""} # noqa: B950 + + + + class _ConvNd(Module): @@ -37,14 +63,14 @@ class _ConvNd(Module): def __init__(self, in_channels: int, out_channels: int, - kernel_size: _size_1_t, - stride: _size_1_t, - padding: _size_1_t, - dilation: _size_1_t, + kernel_size: Tuple[int, ...], + stride: Tuple[int, ...], + padding: Tuple[int, ...], + dilation: Tuple[int, ...], transposed: bool, - output_padding: _size_1_t, + output_padding: Tuple[int, ...], groups: int, - bias: Optional[Tensor], + bias: bool, padding_mode: str) -> None: super(_ConvNd, self).__init__() if in_channels % groups != 0: @@ -113,7 +139,7 @@ def __setstate__(self, state): class Conv1d(_ConvNd): - r"""Applies a 1D convolution over an input signal composed of several input + __doc__ = r"""Applies a 1D convolution over an input signal composed of several input planes. In the simplest case, the output value of the layer with input size @@ -128,58 +154,26 @@ class Conv1d(_ConvNd): where :math:`\star` is the valid `cross-correlation`_ operator, :math:`N` is a batch size, :math:`C` denotes a number of channels, :math:`L` is a length of signal sequence. + """ + r""" This module supports :ref:`TensorFloat32`. * :attr:`stride` controls the stride for the cross-correlation, a single number or a one-element tuple. - * :attr:`padding` controls the amount of implicit zero-paddings on both sides + * :attr:`padding` controls the amount of implicit padding on both sides for :attr:`padding` number of points. * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - * :attr:`groups` controls the connections between inputs and outputs. - :attr:`in_channels` and :attr:`out_channels` must both be divisible by - :attr:`groups`. For example, - - * At groups=1, all inputs are convolved to all outputs. - * At groups=2, the operation becomes equivalent to having two conv - layers side by side, each seeing half the input channels, - and producing half the output channels, and both subsequently - concatenated. - * At groups= :attr:`in_channels`, each input channel is convolved with - its own set of filters, - of size - :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`. - - Note: - - Depending of the size of your kernel, several (of the last) - columns of the input might be lost, because it is a valid - `cross-correlation`_, and not a full `cross-correlation`_. - It is up to the user to add proper padding. + {groups_note} Note: - - When `groups == in_channels` and `out_channels == K * in_channels`, - where `K` is a positive integer, this operation is also termed in - literature as depthwise convolution. - - In other words, for an input of size :math:`(N, C_{in}, L_{in})`, - a depthwise convolution with a depthwise multiplier `K`, can be constructed by arguments - :math:`(C_\text{in}=C_{in}, C_\text{out}=C_{in} \times K, ..., \text{groups}=C_{in})`. - + {depthwise_separable_note} Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. - + {cudnn_reproducibility_note} Args: in_channels (int): Number of channels in the input image @@ -197,6 +191,8 @@ class Conv1d(_ConvNd): bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + """.format(**reproducibility_notes, **convolution_notes) + r""" + Shape: - Input: :math:`(N, C_{in}, L_{in})` - Output: :math:`(N, C_{out}, L_{out})` where @@ -242,25 +238,30 @@ def __init__( bias: bool = True, padding_mode: str = 'zeros' # TODO: refine this type ): - kernel_size = _single(kernel_size) - stride = _single(stride) - padding = _single(padding) - dilation = _single(dilation) + # we create new variables below to make mypy happy since kernel_size has + # type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int] + kernel_size_ = _single(kernel_size) + stride_ = _single(stride) + padding_ = _single(padding) + dilation_ = _single(dilation) super(Conv1d, self).__init__( - in_channels, out_channels, kernel_size, stride, padding, dilation, + in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, False, _single(0), groups, bias, padding_mode) - def forward(self, input: Tensor) -> Tensor: + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): if self.padding_mode != 'zeros': return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - self.weight, self.bias, self.stride, + weight, bias, self.stride, _single(0), self.dilation, self.groups) - return F.conv1d(input, self.weight, self.bias, self.stride, + return F.conv1d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.weight, self.bias) + class Conv2d(_ConvNd): - r"""Applies a 2D convolution over an input signal composed of several input + __doc__ = r"""Applies a 2D convolution over an input signal composed of several input planes. In the simplest case, the output value of the layer with input size @@ -276,31 +277,21 @@ class Conv2d(_ConvNd): :math:`N` is a batch size, :math:`C` denotes a number of channels, :math:`H` is a height of input planes in pixels, and :math:`W` is width in pixels. + """ + r""" This module supports :ref:`TensorFloat32`. * :attr:`stride` controls the stride for the cross-correlation, a single number or a tuple. - * :attr:`padding` controls the amount of implicit zero-paddings on both + * :attr:`padding` controls the amount of implicit padding on both sides for :attr:`padding` number of points for each dimension. * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - * :attr:`groups` controls the connections between inputs and outputs. - :attr:`in_channels` and :attr:`out_channels` must both be divisible by - :attr:`groups`. For example, - - * At groups=1, all inputs are convolved to all outputs. - * At groups=2, the operation becomes equivalent to having two conv - layers side by side, each seeing half the input channels, - and producing half the output channels, and both subsequently - concatenated. - * At groups= :attr:`in_channels`, each input channel is convolved with - its own set of filters, of size: - :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`. + {groups_note} The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: @@ -309,30 +300,10 @@ class Conv2d(_ConvNd): and the second `int` for the width dimension Note: - - Depending of the size of your kernel, several (of the last) - columns of the input might be lost, because it is a valid `cross-correlation`_, - and not a full `cross-correlation`_. - It is up to the user to add proper padding. - - Note: - - When `groups == in_channels` and `out_channels == K * in_channels`, - where `K` is a positive integer, this operation is also termed in - literature as depthwise convolution. - - In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`, - a depthwise convolution with a depthwise multiplier `K`, can be constructed by arguments - :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`. + {depthwise_separable_note} Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. - + {cudnn_reproducibility_note} Args: in_channels (int): Number of channels in the input image @@ -348,6 +319,7 @@ class Conv2d(_ConvNd): channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + """.format(**reproducibility_notes, **convolution_notes) + r""" Shape: - Input: :math:`(N, C_{in}, H_{in}, W_{in})` @@ -391,6 +363,7 @@ class Conv2d(_ConvNd): .. _link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ + def __init__( self, in_channels: int, @@ -403,27 +376,27 @@ def __init__( bias: bool = True, padding_mode: str = 'zeros' # TODO: refine this type ): - kernel_size = _pair(kernel_size) - stride = _pair(stride) - padding = _pair(padding) - dilation = _pair(dilation) + kernel_size_ = _pair(kernel_size) + stride_ = _pair(stride) + padding_ = _pair(padding) + dilation_ = _pair(dilation) super(Conv2d, self).__init__( - in_channels, out_channels, kernel_size, stride, padding, dilation, + in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, False, _pair(0), groups, bias, padding_mode) - def _conv_forward(self, input, weight): + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): if self.padding_mode != 'zeros': return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, self.bias, self.stride, + weight, bias, self.stride, _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, self.bias, self.stride, + return F.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) def forward(self, input: Tensor) -> Tensor: - return self._conv_forward(input, self.weight) + return self._conv_forward(input, self.weight, self.bias) class Conv3d(_ConvNd): - r"""Applies a 3D convolution over an input signal composed of several input + __doc__ = r"""Applies a 3D convolution over an input signal composed of several input planes. In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)` @@ -434,29 +407,19 @@ class Conv3d(_ConvNd): \sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \star input(N_i, k) where :math:`\star` is the valid 3D `cross-correlation`_ operator + """ + r""" This module supports :ref:`TensorFloat32`. * :attr:`stride` controls the stride for the cross-correlation. - * :attr:`padding` controls the amount of implicit zero-paddings on both + * :attr:`padding` controls the amount of implicit padding on both sides for :attr:`padding` number of points for each dimension. * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - * :attr:`groups` controls the connections between inputs and outputs. - :attr:`in_channels` and :attr:`out_channels` must both be divisible by - :attr:`groups`. For example, - - * At groups=1, all inputs are convolved to all outputs. - * At groups=2, the operation becomes equivalent to having two conv - layers side by side, each seeing half the input channels, - and producing half the output channels, and both subsequently - concatenated. - * At groups= :attr:`in_channels`, each input channel is convolved with - its own set of filters, of size - :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`. + {groups_note} The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: @@ -465,30 +428,10 @@ class Conv3d(_ConvNd): the second `int` for the height dimension and the third `int` for the width dimension Note: - - Depending of the size of your kernel, several (of the last) - columns of the input might be lost, because it is a valid `cross-correlation`_, - and not a full `cross-correlation`_. - It is up to the user to add proper padding. - - Note: - - When `groups == in_channels` and `out_channels == K * in_channels`, - where `K` is a positive integer, this operation is also termed in - literature as depthwise convolution. - - In other words, for an input of size :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`, - a depthwise convolution with a depthwise multiplier `K`, can be constructed by arguments - :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`. + {depthwise_separable_note} Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. - + {cudnn_reproducibility_note} Args: in_channels (int): Number of channels in the input image @@ -500,6 +443,7 @@ class Conv3d(_ConvNd): dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + """.format(**reproducibility_notes, **convolution_notes) + r""" Shape: - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` @@ -544,6 +488,7 @@ class Conv3d(_ConvNd): .. _link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ + def __init__( self, in_channels: int, @@ -556,12 +501,12 @@ def __init__( bias: bool = True, padding_mode: str = 'zeros' ): - kernel_size = _triple(kernel_size) - stride = _triple(stride) - padding = _triple(padding) - dilation = _triple(dilation) + kernel_size_ = _triple(kernel_size) + stride_ = _triple(stride) + padding_ = _triple(padding) + dilation_ = _triple(dilation) super(Conv3d, self).__init__( - in_channels, out_channels, kernel_size, stride, padding, dilation, + in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, False, _triple(0), groups, bias, padding_mode) def forward(self, input: Tensor) -> Tensor: @@ -587,8 +532,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, # dilation being an optional parameter is for backwards # compatibility - def _output_padding(self, input, output_size, stride, padding, kernel_size, dilation=None): - # type: (Tensor, Optional[List[int]], List[int], List[int], List[int], Optional[List[int]]) -> List[int] + def _output_padding(self, input: Tensor, output_size: Optional[List[int]], + stride: List[int], padding: List[int], kernel_size: List[int], + dilation: Optional[List[int]] = None) -> List[int]: if output_size is None: ret = _single(self.output_padding) # converting to list if was not already else: @@ -628,7 +574,7 @@ def _output_padding(self, input, output_size, stride, padding, kernel_size, dila class ConvTranspose1d(_ConvTransposeNd): - r"""Applies a 1D transposed convolution operator over an input image + __doc__ = r"""Applies a 1D transposed convolution operator over an input image composed of several input planes. This module can be seen as the gradient of Conv1d with respect to its input. @@ -639,7 +585,7 @@ class ConvTranspose1d(_ConvTransposeNd): * :attr:`stride` controls the stride for the cross-correlation. - * :attr:`padding` controls the amount of implicit zero-paddings on both + * :attr:`padding` controls the amount of implicit zero padding on both sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note below for details. @@ -649,25 +595,7 @@ class ConvTranspose1d(_ConvTransposeNd): * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - * :attr:`groups` controls the connections between inputs and outputs. - :attr:`in_channels` and :attr:`out_channels` must both be divisible by - :attr:`groups`. For example, - - * At groups=1, all inputs are convolved to all outputs. - * At groups=2, the operation becomes equivalent to having two conv - layers side by side, each seeing half the input channels, - and producing half the output channels, and both subsequently - concatenated. - * At groups= :attr:`in_channels`, each input channel is convolved with - its own set of filters (of size - :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`). - - Note: - - Depending of the size of your kernel, several (of the last) - columns of the input might be lost, because it is a valid `cross-correlation`_, - and not a full `cross-correlation`_. - It is up to the user to add proper padding. + {groups_note} Note: The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` @@ -702,6 +630,7 @@ class ConvTranspose1d(_ConvTransposeNd): groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + """.format(**reproducibility_notes, **convolution_notes) + r""" Shape: - Input: :math:`(N, C_{in}, L_{in})` @@ -756,15 +685,17 @@ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Ten if self.padding_mode != 'zeros': raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d') + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. output_padding = self._output_padding( - input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) + input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) # type: ignore return F.conv_transpose1d( input, self.weight, self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation) class ConvTranspose2d(_ConvTransposeNd): - r"""Applies a 2D transposed convolution operator over an input image + __doc__ = r"""Applies a 2D transposed convolution operator over an input image composed of several input planes. This module can be seen as the gradient of Conv2d with respect to its input. @@ -775,7 +706,7 @@ class ConvTranspose2d(_ConvTransposeNd): * :attr:`stride` controls the stride for the cross-correlation. - * :attr:`padding` controls the amount of implicit zero-paddings on both + * :attr:`padding` controls the amount of implicit zero padding on both sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note below for details. @@ -785,18 +716,7 @@ class ConvTranspose2d(_ConvTransposeNd): * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - * :attr:`groups` controls the connections between inputs and outputs. - :attr:`in_channels` and :attr:`out_channels` must both be divisible by - :attr:`groups`. For example, - - * At groups=1, all inputs are convolved to all outputs. - * At groups=2, the operation becomes equivalent to having two conv - layers side by side, each seeing half the input channels, - and producing half the output channels, and both subsequently - concatenated. - * At groups= :attr:`in_channels`, each input channel is convolved with - its own set of filters (of size - :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`). + {groups_note} The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` can either be: @@ -805,13 +725,6 @@ class ConvTranspose2d(_ConvTransposeNd): - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, and the second `int` for the width dimension - .. note:: - - Depending of the size of your kernel, several (of the last) - columns of the input might be lost, because it is a valid `cross-correlation`_, - and not a full `cross-correlation`_. - It is up to the user to add proper padding. - Note: The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` amount of zero padding to both sizes of the input. This is set so that @@ -825,13 +738,7 @@ class ConvTranspose2d(_ConvTransposeNd): not actually add zero-padding to output. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. - + {cudnn_reproducibility_note} Args: in_channels (int): Number of channels in the input image @@ -845,6 +752,7 @@ class ConvTranspose2d(_ConvTransposeNd): groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + """.format(**reproducibility_notes, **convolution_notes) + r""" Shape: - Input: :math:`(N, C_{in}, H_{in}, W_{in})` @@ -921,8 +829,10 @@ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Ten if self.padding_mode != 'zeros': raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d') + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. output_padding = self._output_padding( - input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) + input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) # type: ignore return F.conv_transpose2d( input, self.weight, self.bias, self.stride, self.padding, @@ -930,7 +840,7 @@ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Ten class ConvTranspose3d(_ConvTransposeNd): - r"""Applies a 3D transposed convolution operator over an input image composed of several input + __doc__ = r"""Applies a 3D transposed convolution operator over an input image composed of several input planes. The transposed convolution operator multiplies each input value element-wise by a learnable kernel, and sums over the outputs from all input feature planes. @@ -943,7 +853,7 @@ class ConvTranspose3d(_ConvTransposeNd): * :attr:`stride` controls the stride for the cross-correlation. - * :attr:`padding` controls the amount of implicit zero-paddings on both + * :attr:`padding` controls the amount of implicit zero padding on both sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note below for details. @@ -953,18 +863,7 @@ class ConvTranspose3d(_ConvTransposeNd): * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - * :attr:`groups` controls the connections between inputs and outputs. - :attr:`in_channels` and :attr:`out_channels` must both be divisible by - :attr:`groups`. For example, - - * At groups=1, all inputs are convolved to all outputs. - * At groups=2, the operation becomes equivalent to having two conv - layers side by side, each seeing half the input channels, - and producing half the output channels, and both subsequently - concatenated. - * At groups= :attr:`in_channels`, each input channel is convolved with - its own set of filters (of size - :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`). + {groups_note} The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` can either be: @@ -973,13 +872,6 @@ class ConvTranspose3d(_ConvTransposeNd): - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, the second `int` for the height dimension and the third `int` for the width dimension - Note: - - Depending of the size of your kernel, several (of the last) - columns of the input might be lost, because it is a valid `cross-correlation`_, - and not a full `cross-correlation`_. - It is up to the user to add proper padding. - Note: The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` amount of zero padding to both sizes of the input. This is set so that @@ -993,13 +885,7 @@ class ConvTranspose3d(_ConvTransposeNd): not actually add zero-padding to output. Note: - In some circumstances when using the CUDA backend with CuDNN, this operator - may select a nondeterministic algorithm to increase performance. If this is - undesirable, you can try to make the operation deterministic (potentially at - a performance cost) by setting ``torch.backends.cudnn.deterministic = - True``. - Please see the notes on :doc:`/notes/randomness` for background. - + {cudnn_reproducibility_note} Args: in_channels (int): Number of channels in the input image @@ -1013,6 +899,7 @@ class ConvTranspose3d(_ConvTransposeNd): groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + """.format(**reproducibility_notes, **convolution_notes) + r""" Shape: - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` @@ -1083,8 +970,10 @@ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Ten if self.padding_mode != 'zeros': raise ValueError('Only `zeros` padding mode is supported for ConvTranspose3d') + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. output_padding = self._output_padding( - input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) + input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) # type: ignore return F.conv_transpose3d( input, self.weight, self.bias, self.stride, self.padding, @@ -1117,3 +1006,350 @@ def __init__(self, *args, **kwargs): # TODO: Conv2dLocal # TODO: Conv2dMap # TODO: ConvTranspose2dMap + + +class _LazyConvXdMixin(LazyModuleMixin): + groups: int + transposed: bool + in_channels: int + out_channels: int + kernel_size: Tuple[int, ...] + weight: UninitializedParameter + + def reset_parameters(self) -> None: + # has_uninitialized_params is defined in parent class and it is using a protocol on self + if not self.has_uninitialized_params() and self.in_channels != 0: # type: ignore[misc] + # "type:ignore[..]" is required because mypy thinks that "reset_parameters" is undefined + # super class. Turns out that it is defined in _ConvND which is inherited by any class + # that also inherits _LazyConvXdMixin + super().reset_parameters() # type: ignore[misc] + + # Signature of "initialize_parameters" is incompatible with the definition in supertype LazyModuleMixin + def initialize_parameters(self, input) -> None: # type: ignore[override] + # defined by parent class but using a protocol + if self.has_uninitialized_params(): # type: ignore[misc] + self.in_channels = input.shape[1] + if self.in_channels % self.groups != 0: + raise ValueError('in_channels must be divisible by groups') + if self.transposed: + self.weight.materialize(( + self.in_channels, self.out_channels // self.groups, *self.kernel_size)) + else: + self.weight.materialize(( + self.out_channels, self.in_channels // self.groups, *self.kernel_size)) + self.reset_parameters() + + +# LazyConv1d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc] + r"""A :class:`torch.nn.Conv1d` module with lazy initialization of + the ``in_channels`` argument of the :class:`Conv1d` that is inferred from + the ``input.size(1)``. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + padding_mode (string, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + + .. seealso:: :class:`torch.nn.Conv1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = Conv1d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros' + ) -> None: + super().__init__( + 0, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode + ) + self.weight = UninitializedParameter() + + +# LazyConv2d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc] + r"""A :class:`torch.nn.Conv2d` module with lazy initialization of + the ``in_channels`` argument of the :class:`Conv2d` that is inferred from + the ``input.size(1)``. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + padding_mode (string, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + + .. seealso:: :class:`torch.nn.Conv2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = Conv2d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: _size_2_t = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros' # TODO: refine this type + ) -> None: + super().__init__( + 0, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode + ) + self.weight = UninitializedParameter() + + +# LazyConv3d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc] + r"""A :class:`torch.nn.Conv3d` module with lazy initialization of + the ``in_channels`` argument of the :class:`Conv3d` that is inferred from + the ``input.size(1)``. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + padding_mode (string, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + + .. seealso:: :class:`torch.nn.Conv3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = Conv3d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: _size_3_t = 0, + dilation: _size_3_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros' + ) -> None: + super().__init__( + 0, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode + ) + self.weight = UninitializedParameter() + + +# LazyConvTranspose1d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[misc] + r"""A :class:`torch.nn.ConvTranspose1d` module with lazy initialization of + the ``in_channels`` argument of the :class:`ConvTranspose1d` that is inferred from + the ``input.size(1)``. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + + .. seealso:: :class:`torch.nn.ConvTranspose1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = ConvTranspose1d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + output_padding: _size_1_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_1_t = 1, + padding_mode: str = 'zeros' + ) -> None: + super().__init__( + 0, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode + ) + self.weight = UninitializedParameter() + + +# LazyConvTranspose2d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[misc] + r"""A :class:`torch.nn.ConvTranspose2d` module with lazy initialization of + the ``in_channels`` argument of the :class:`ConvTranspose2d` that is inferred from + the ``input.size(1)``. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of each dimension in the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of each dimension in the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + + .. seealso:: :class:`torch.nn.ConvTranspose2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = ConvTranspose2d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: _size_2_t = 0, + output_padding: _size_2_t = 0, + groups: int = 1, + bias: bool = True, + dilation: int = 1, + padding_mode: str = 'zeros' + ) -> None: + super().__init__( + 0, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode + ) + self.weight = UninitializedParameter() + + +# LazyConvTranspose3d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[misc] + r"""A :class:`torch.nn.ConvTranspose3d` module with lazy initialization of + the ``in_channels`` argument of the :class:`ConvTranspose3d` that is inferred from + the ``input.size(1)``. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of each dimension in the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of each dimension in the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + + .. seealso:: :class:`torch.nn.ConvTranspose3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = ConvTranspose3d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: _size_3_t = 0, + output_padding: _size_3_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_3_t = 1, + padding_mode: str = 'zeros' + ) -> None: + super().__init__( + 0, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode + ) + self.weight = UninitializedParameter() diff --git a/torch/nn/modules/flatten.py b/torch/nn/modules/flatten.py index c06b7a5534f61..dd491ba99620a 100644 --- a/torch/nn/modules/flatten.py +++ b/torch/nn/modules/flatten.py @@ -2,7 +2,7 @@ from typing import Tuple, Union from torch import Tensor -from torch import Size +from torch.types import _size class Flatten(Module): @@ -53,8 +53,8 @@ class Unflatten(Module): be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively. * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be - a `tuple` of ints or `torch.Size` for `Tensor` input or a `NamedShape` (tuple of `(name, size)` tuples) - for `NamedTensor` input. + a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape` + (tuple of `(name, size)` tuples) for `NamedTensor` input. Shape: - Input: :math:`(N, *dims)` @@ -62,7 +62,7 @@ class Unflatten(Module): Args: dim (Union[int, str]): Dimension to be unflattened - unflattened_size (Union[torch.Size, NamedShape]): New shape of the unflattened dimension + unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension Examples: >>> input = torch.randn(2, 50) @@ -71,7 +71,7 @@ class Unflatten(Module): >>> nn.Linear(50, 50), >>> nn.Unflatten(1, (2, 5, 5)) >>> ) - >>> output = m(output) + >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With torch.Size @@ -79,15 +79,13 @@ class Unflatten(Module): >>> nn.Linear(50, 50), >>> nn.Unflatten(1, torch.Size([2, 5, 5])) >>> ) - >>> output = m(output) + >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With namedshape (tuple of tuples) - >>> m = nn.Sequential( - >>> nn.Linear(50, 50), - >>> nn.Unflatten('features', (('C', 2), ('H', 50), ('W',50))) - >>> ) - >>> output = m(output) + >>> input = torch.randn(2, 50, names=('N', 'features')) + >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5))) + >>> output = unflatten(input) >>> output.size() torch.Size([2, 2, 5, 5]) """ @@ -95,9 +93,9 @@ class Unflatten(Module): __constants__ = ['dim', 'unflattened_size'] dim: Union[int, str] - unflattened_size: Union[Size, NamedShape] + unflattened_size: Union[_size, NamedShape] - def __init__(self, dim: Union[int, str], unflattened_size: Union[Size, NamedShape]) -> None: + def __init__(self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]) -> None: super(Unflatten, self).__init__() if isinstance(dim, int): @@ -121,7 +119,7 @@ def _require_tuple_tuple(self, input): "but found type {}".format(type(input).__name__)) def _require_tuple_int(self, input): - if (isinstance(input, tuple)): + if (isinstance(input, (tuple, list))): for idx, elem in enumerate(input): if not isinstance(elem, int): raise TypeError("unflattened_size must be tuple of ints, " + diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py index a0f9c9a19afa8..b27fd644993f6 100644 --- a/torch/nn/modules/instancenorm.py +++ b/torch/nn/modules/instancenorm.py @@ -52,6 +52,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, def forward(self, input: Tensor) -> Tensor: self._check_input_dim(input) + assert self.running_mean is None or isinstance(self.running_mean, Tensor) + assert self.running_var is None or isinstance(self.running_var, Tensor) return F.instance_norm( input, self.running_mean, self.running_var, self.weight, self.bias, self.training or not self.track_running_stats, self.momentum, self.eps) diff --git a/torch/nn/modules/lazy.py b/torch/nn/modules/lazy.py new file mode 100644 index 0000000000000..dbb4eb6c36e8c --- /dev/null +++ b/torch/nn/modules/lazy.py @@ -0,0 +1,257 @@ +import itertools +from typing_extensions import Protocol +import warnings + +import torch +from ..parameter import UninitializedParameter + + +class _LazyProtocol(Protocol): + """This is to avoid errors with mypy checks for + The attributes in a mixin: + https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes + """ + def _register_load_state_dict_pre_hook(self, hook): + ... + + def register_forward_pre_hook(self, hook): + ... + + def _lazy_load_hook( + self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + ... + + def _get_name(self): + ... + + def _infer_parameters(self, module, input): + ... + + @property + def _parameters(self): + ... + + @property + def _buffers(self): + ... + + @property + def _non_persistent_buffers_set(self): + ... + + @property + def _load_hook(self): + ... + + @property + def _initialize_hook(self): + ... + + +class LazyModuleMixin: + r"""A mixin for modules that lazily initialize parameters, also known as "lazy modules." + + .. warning: + Lazy modules are an experimental new feature under active development, + and their API is likely to change. + + Modules that lazily initialize parameters, or "lazy modules", + derive the shapes of their parameters from the first input(s) + to their forward method. Until that first forward they contain + :class:`torch.nn.UninitializedParameter`s that should not be accessed + or used, and afterward they contain regular :class:`torch.nn.Parameter`s. + Lazy modules are convenient since they don't require computing some + module arguments, like the `in_features` argument of a + typical :class:`torch.nn.Linear`. + + After construction, networks with lazy modules should first + be converted to the desired dtype and placed on the desired device. + The lazy modules should then be initialized with one or more "dry runs". + These "dry runs" send inputs of the correct size, dtype, and device through + the network and to each one of its lazy modules. After this the network can be used as usual. + + >>> class LazyMLP(torch.nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.fc1 = torch.nn.LazyLinear(10) + ... self.relu1 = torch.nn.ReLU() + ... self.fc2 = torch.nn.LazyLinear(1) + ... self.relu2 = torch.nn.ReLU() + ... + ... def forward(self, input): + ... x = self.relu1(self.fc1(input)) + ... y = self.relu2(self.fc2(x)) + ... return y + >>> # constructs a network with lazy modules + >>> lazy_mlp = LazyMLP() + >>> # transforms the network's device and dtype + >>> # NOTE: these transforms can and should be applied after construction and before any 'dry runs' + >>> lazy_mlp = mlp.cuda().double() + >>> lazy_mlp + LazyMLP( + (fc1): LazyLinear(in_features=0, out_features=10, bias=True) + (relu1): ReLU() + (fc2): LazyLinear(in_features=0, out_features=1, bias=True) + (relu2): ReLU() + ) + >>> # performs a dry run to initialize the network's lazy modules + >>> lazy_mlp(torch.ones(10,10).cuda()) + >>> # after initialization, LazyLinear modules become regular Linear modules + >>> lazy_mlp + LazyMLP( + (fc1): Linear(in_features=10, out_features=10, bias=True) + (relu1): ReLU() + (fc2): Linear(in_features=10, out_features=1, bias=True) + (relu2): ReLU() + ) + >>> # attaches an optimizer, since parameters can now be used as usual + >>> optim = torch.optim.SGD(mlp.parameters(), lr=0.01) + + A final caveat when using lazy modules is that the order of initialization of a network's + parameters may change, since the lazy modules are always initialized after other modules. + This can cause the parameters of a network using lazy modules to be initialized differently + than the parameters of a network without lazy modules. + For example, if the LazyMLP class defined above had a :class:`torch.nn.LazyLinear` module + first and then a regular :class:`torch.nn.Linear` second, the second module would be + initialized on construction and the first module would be initialized during the first dry run. + + Lazy modules can be serialized with a state dict like other modules. For example: + + >>> lazy_mlp = LazyMLP() + >>> # The state dict shows the uninitialized parameters + >>> lazy_mlp.state_dict() + OrderedDict([('fc1.weight', Uninitialized parameter), + ('fc1.bias', + tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30, + 4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])), + ('fc2.weight', Uninitialized parameter), + ('fc2.bias', tensor([0.0019]))]) + + + Lazy modules can also load regular :class:`torch.nn.Parameter` s, + which replace their :class:`torch.nn.UninitializedParameter` s: + + + >>> full_mlp = LazyMLP() + >>> # Dry run to initialize another module + >>> full_mlp.forward(torch.ones(10, 1)) + >>> # Load an initialized state into a lazy module + >>> lazy_mlp.load_state_dict(full_mlp.state_dict()) + >>> # The state dict now holds valid values + >>> lazy_mlp.state_dict() + OrderedDict([('fc1.weight', + tensor([[-0.3837], + [ 0.0907], + [ 0.6708], + [-0.5223], + [-0.9028], + [ 0.2851], + [-0.4537], + [ 0.6813], + [ 0.5766], + [-0.8678]])), + ('fc1.bias', + tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30, + 4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])), + ('fc2.weight', + tensor([[ 0.1320, 0.2938, 0.0679, 0.2793, 0.1088, -0.1795, -0.2301, 0.2807, + 0.2479, 0.1091]])), + ('fc2.bias', tensor([0.0019]))]) + + Note, however, that lazy modules cannot validate that the shape of parameters they load is correct. + + """ + + # modules inheriting from this will change their __class__ to the specified + # one after they are fully initialized + cls_to_become = None + + def __init__(self: _LazyProtocol, *args, **kwargs): + # Mypy doesnt like this super call in a mixin + super().__init__(*args, **kwargs) # type: ignore + self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook) + self._initialize_hook = self.register_forward_pre_hook(self._infer_parameters) + warnings.warn('Lazy modules are a new feature under heavy development ' + 'so changes to the API or functionality can happen at any moment.') + + def _save_to_state_dict(self: _LazyProtocol, destination, prefix, keep_vars): + # This should be ideally implemented as a hook, + # but we should override `detach` in the UninitializedParameter to return itself + # which is not clean + for name, param in self._parameters.items(): + if param is not None: + if isinstance(param, UninitializedParameter): + destination[prefix + name] = param + else: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in self._buffers.items(): + if buf is not None and name not in self._non_persistent_buffers_set: + destination[prefix + name] = buf if keep_vars else buf.detach() + + def _lazy_load_hook( + self: _LazyProtocol, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """load_state_dict pre-hook function for lazy buffers and parameters. + + The purpose of this hook is to adjust the current state and/or + ``state_dict`` being loaded so that a module instance serialized in + both un/initialized state can be deserialized onto both un/initialized + module instance. + See comment in ``torch.nn.Module._register_load_state_dict_pre_hook`` + for the details of the hook specification. + """ + local_state = {k: v for k, v in self._parameters.items() if v is not None} + for name, param in local_state.items(): + key = prefix + name + if key in state_dict: + input_param = state_dict[key] + if isinstance(param, UninitializedParameter): + # The current parameter is not initialized but the one being loaded one is + # create a new parameter based on the uninitialized one + if not isinstance(input_param, UninitializedParameter): + with torch.no_grad(): + param.materialize(input_param.shape) + + def initialize_parameters(self: _LazyProtocol, *args, **kwargs): + r"""Initialize parameters according to the input batch properties. + This adds an interface to isolate parameter initialization from the + forward pass when doing parameter shape inference. + """ + raise NotImplementedError('initialize_parameters is not implemented for {}'.format(self.__class__.__name__)) + + def has_uninitialized_params(self: _LazyProtocol): + r"""Check if a module has parameters that are not initialized + """ + # This is to avoid the JIT to track this parameter and force + # custom modules __setstate__ to add it + params = self._parameters.values() + for param in itertools.chain(params): + if isinstance(param, (UninitializedParameter)): + return True + return False + + def _infer_parameters(self: _LazyProtocol, module, input): + r"""Infers the size and initializes the parameters according to the + provided input batch. + Given a module that contains parameters that were declared inferrable + using :class:`torch.nn.parameter.ParameterMode.Infer`, runs a forward pass + in the complete module using the provided input to initialize all the parameters + as needed. + The module is set into evaluation mode before running the forward pass in order + to avoid saving statistics or calculating gradients + """ + module.initialize_parameters(*input) + if module.has_uninitialized_params(): + raise RuntimeError('module {} has not been fully initialized'.format(self._get_name())) + module._initialize_hook.remove() + module._load_hook.remove() + delattr(module, '_initialize_hook') + delattr(module, '_load_hook') + if module.cls_to_become is not None: + module.__class__ = module.cls_to_become + + + def _replicate_for_data_parallel(self: _LazyProtocol): + raise RuntimeError('Modules with uninitialized parameters can\'t be used with `DataParallel`. ' + 'Run a dummy forward pass to correctly initialize the modules') diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py index 76cec57f72057..ea2c3a8f453ba 100644 --- a/torch/nn/modules/linear.py +++ b/torch/nn/modules/linear.py @@ -2,10 +2,11 @@ import torch from torch import Tensor -from torch.nn.parameter import Parameter +from torch.nn.parameter import Parameter, UninitializedParameter from .. import functional as F from .. import init from .module import Module +from .lazy import LazyModuleMixin class Identity(Module): @@ -101,10 +102,10 @@ def extra_repr(self) -> str: # This class exists solely for Transformer; it has an annotation stating # that bias is never None, which appeases TorchScript class _LinearWithBias(Linear): - bias: Tensor + bias: Tensor # type: ignore def __init__(self, in_features: int, out_features: int) -> None: - super().__init__(in_features, out_features, bias=True) + super().__init__(in_features, out_features, bias=True) # type: ignore class Bilinear(Module): @@ -178,4 +179,50 @@ def extra_repr(self) -> str: self.in1_features, self.in2_features, self.out_features, self.bias is not None ) + +class LazyLinear(LazyModuleMixin, Linear): + r"""A :class:`torch.nn.Linear` module with lazy initialization. + + In this module, the `weight` and `bias` are of :class:`torch.nn.UninitializedParameter` + class. They will be initialized after the first call to ``forward`` is done and the + module will become a regular :class:`torch.nn.Linear` module. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_features: size of each output sample + bias: If set to ``False``, the layer will not learn an additive bias. + Default: ``True`` + + Attributes: + weight: the learnable weights of the module of shape + :math:`(\text{out\_features}, \text{in\_features})`. The values are + initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where + :math:`k = \frac{1}{\text{in\_features}}` + bias: the learnable bias of the module of shape :math:`(\text{out\_features})`. + If :attr:`bias` is ``True``, the values are initialized from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{1}{\text{in\_features}}` + + + """ + + cls_to_become = Linear # type: ignore[assignment] + weight: UninitializedParameter + + def __init__(self, out_features: int, bias: bool = True) -> None: + super().__init__(0, out_features, bias) + self.weight = UninitializedParameter() + + def reset_parameters(self) -> None: + if not self.has_uninitialized_params() and self.in_features != 0: + super().reset_parameters() + + def initialize_parameters(self, input) -> None: # type: ignore + if self.has_uninitialized_params(): + with torch.no_grad(): + self.in_features = input.shape[-1] + self.weight.materialize((self.out_features, self.in_features)) + self.reset_parameters() # TODO: PartialLinear - maybe in sparse? diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index 91a62a85771e1..a2642fb4f149a 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -42,8 +42,8 @@ class L1Loss(_Loss): .. math:: \ell(x, y) = \begin{cases} - \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\ - \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} \end{cases} :math:`x` and :math:`y` are tensors of arbitrary shapes with a total @@ -53,12 +53,14 @@ class L1Loss(_Loss): The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. + Supports real-valued and complex-valued inputs. + Args: size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` + when :attr:`reduce` is ``False``. Default: ``True`` reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per @@ -130,9 +132,9 @@ class NLLLoss(_WeightedLoss): .. math:: \ell(x, y) = \begin{cases} \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n}} l_n, & - \text{if reduction} = \text{'mean';}\\ + \text{if reduction} = \text{`mean';}\\ \sum_{n=1}^N l_n, & - \text{if reduction} = \text{'sum'.} + \text{if reduction} = \text{`sum'.} \end{cases} Can also be used for higher dimension inputs, such as 2D images, by providing @@ -148,7 +150,7 @@ class NLLLoss(_WeightedLoss): the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` + when :attr:`reduce` is ``False``. Default: ``True`` ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input gradient. When :attr:`size_average` is ``True``, the loss is averaged over @@ -210,6 +212,7 @@ def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_in self.ignore_index = ignore_index def forward(self, input: Tensor, target: Tensor) -> Tensor: + assert self.weight is None or isinstance(self.weight, Tensor) return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) @@ -250,7 +253,7 @@ class PoissonNLLLoss(_Loss): the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` + when :attr:`reduce` is ``False``. Default: ``True`` eps (float, optional): Small value to avoid evaluation of :math:`\log(0)` when :attr:`log_input = False`. Default: 1e-8 reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the @@ -322,8 +325,8 @@ class KLDivLoss(_Loss): .. math:: \ell(x, y) = \begin{cases} - \operatorname{mean}(L), & \text{if reduction} = \text{'mean';} \\ - \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';} \\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} \end{cases} In default :attr:`reduction` mode ``'mean'``, the losses are averaged for each minibatch over observations @@ -338,7 +341,7 @@ class KLDivLoss(_Loss): the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` + when :attr:`reduce` is ``False``. Default: ``True`` reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per @@ -396,8 +399,8 @@ class MSELoss(_Loss): .. math:: \ell(x, y) = \begin{cases} - \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\ - \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} \end{cases} :math:`x` and :math:`y` are tensors of arbitrary shapes with a total @@ -412,7 +415,7 @@ class MSELoss(_Loss): the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` + when :attr:`reduce` is ``False``. Default: ``True`` reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per @@ -461,8 +464,8 @@ class BCELoss(_WeightedLoss): .. math:: \ell(x, y) = \begin{cases} - \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\ - \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} \end{cases} This is used for measuring the error of a reconstruction in for example @@ -493,7 +496,7 @@ class BCELoss(_WeightedLoss): the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` + when :attr:`reduce` is ``False``. Default: ``True`` reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per @@ -527,6 +530,7 @@ def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=No super(BCELoss, self).__init__(weight, size_average, reduce, reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: + assert self.weight is None or isinstance(self.weight, Tensor) return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction) @@ -548,8 +552,8 @@ class BCEWithLogitsLoss(_Loss): .. math:: \ell(x, y) = \begin{cases} - \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\ - \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} \end{cases} This is used for measuring the error of a reconstruction in for example @@ -591,7 +595,7 @@ class BCEWithLogitsLoss(_Loss): the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` + when :attr:`reduce` is ``False``. Default: ``True`` reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per @@ -626,6 +630,8 @@ def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=No self.register_buffer('pos_weight', pos_weight) def forward(self, input: Tensor, target: Tensor) -> Tensor: + assert self.weight is None or isinstance(self.weight, Tensor) + assert self.pos_weight is None or isinstance(self.pos_weight, Tensor) return F.binary_cross_entropy_with_logits(input, target, self.weight, pos_weight=self.pos_weight, @@ -651,8 +657,8 @@ class HingeEmbeddingLoss(_Loss): .. math:: \ell(x, y) = \begin{cases} - \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\ - \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} \end{cases} where :math:`L = \{l_1,\dots,l_N\}^\top`. @@ -663,7 +669,7 @@ class HingeEmbeddingLoss(_Loss): the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` + when :attr:`reduce` is ``False``. Default: ``True`` reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per @@ -718,7 +724,7 @@ class MultiLabelMarginLoss(_Loss): the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` + when :attr:`reduce` is ``False``. Default: ``True`` reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per @@ -758,10 +764,10 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: class SmoothL1Loss(_Loss): r"""Creates a criterion that uses a squared term if the absolute - element-wise error falls below 1 and an L1 term otherwise. - It is less sensitive to outliers than the `MSELoss` and in some cases + element-wise error falls below beta and an L1 term otherwise. + It is less sensitive to outliers than the :class:`torch.nn.MSELoss` and in some cases prevents exploding gradients (e.g. see `Fast R-CNN` paper by Ross Girshick). - Also known as the Huber loss: + Omitting a scaling factor of :attr:`beta`, this loss is also known as the Huber loss: .. math:: \text{loss}(x, y) = \frac{1}{n} \sum_{i} z_{i} @@ -771,13 +777,18 @@ class SmoothL1Loss(_Loss): .. math:: z_{i} = \begin{cases} - 0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < 1 \\ - |x_i - y_i| - 0.5, & \text{otherwise } + 0.5 (x_i - y_i)^2 / beta, & \text{if } |x_i - y_i| < beta \\ + |x_i - y_i| - 0.5 * beta, & \text{otherwise } \end{cases} :math:`x` and :math:`y` arbitrary shapes with a total of :math:`n` elements each the sum operation still operates over all the elements, and divides by :math:`n`. + :attr:`beta` is an optional parameter that defaults to 1. + + Note: When :attr:`beta` is set to 0, this is equivalent to :class:`L1Loss`. + Passing a negative value in for :attr:`beta` will result in an exception. + The division by :math:`n` can be avoided if sets ``reduction = 'sum'``. Args: @@ -785,7 +796,7 @@ class SmoothL1Loss(_Loss): the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` + when :attr:`reduce` is ``False``. Default: ``True`` reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per @@ -796,6 +807,8 @@ class SmoothL1Loss(_Loss): elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + beta (float, optional): Specifies the threshold at which to change between L1 and L2 loss. + This value defaults to 1.0. Shape: - Input: :math:`(N, *)` where :math:`*` means, any number of additional @@ -807,11 +820,12 @@ class SmoothL1Loss(_Loss): """ __constants__ = ['reduction'] - def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None: + def __init__(self, size_average=None, reduce=None, reduction: str = 'mean', beta: float = 1.0) -> None: super(SmoothL1Loss, self).__init__(size_average, reduce, reduction) + self.beta = beta def forward(self, input: Tensor, target: Tensor) -> Tensor: - return F.smooth_l1_loss(input, target, reduction=self.reduction) + return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta) class SoftMarginLoss(_Loss): @@ -827,7 +841,7 @@ class SoftMarginLoss(_Loss): the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` + when :attr:`reduce` is ``False``. Default: ``True`` reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per @@ -856,7 +870,7 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: class CrossEntropyLoss(_WeightedLoss): - r"""This criterion combines :func:`nn.LogSoftmax` and :func:`nn.NLLLoss` in one single class. + r"""This criterion combines :class:`~torch.nn.LogSoftmax` and :class:`~torch.nn.NLLLoss` in one single class. It is useful when training a classification problem with `C` classes. If provided, the optional argument :attr:`weight` should be a 1D `Tensor` @@ -904,7 +918,7 @@ class CrossEntropyLoss(_WeightedLoss): the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` + when :attr:`reduce` is ``False``. Default: ``True`` ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input gradient. When :attr:`size_average` is ``True``, the loss is averaged over non-ignored targets. @@ -950,6 +964,7 @@ def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_in self.ignore_index = ignore_index def forward(self, input: Tensor, target: Tensor) -> Tensor: + assert self.weight is None or isinstance(self.weight, Tensor) return F.cross_entropy(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) @@ -975,7 +990,7 @@ class MultiLabelSoftMarginLoss(_WeightedLoss): the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` + when :attr:`reduce` is ``False``. Default: ``True`` reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per @@ -998,6 +1013,7 @@ def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=No super(MultiLabelSoftMarginLoss, self).__init__(weight, size_average, reduce, reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: + assert self.weight is None or isinstance(self.weight, Tensor) return F.multilabel_soft_margin_loss(input, target, weight=self.weight, reduction=self.reduction) @@ -1025,7 +1041,7 @@ class CosineEmbeddingLoss(_Loss): the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` + when :attr:`reduce` is ``False``. Default: ``True`` reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per @@ -1067,7 +1083,7 @@ class MarginRankingLoss(_Loss): the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` + when :attr:`reduce` is ``False``. Default: ``True`` reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per @@ -1139,7 +1155,7 @@ class MultiMarginLoss(_WeightedLoss): the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` + when :attr:`reduce` is ``False``. Default: ``True`` reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per @@ -1165,6 +1181,7 @@ def __init__(self, p: int = 1, margin: float = 1., weight: Optional[Tensor] = No self.margin = margin def forward(self, input: Tensor, target: Tensor) -> Tensor: + assert self.weight is None or isinstance(self.weight, Tensor) return F.multi_margin_loss(input, target, p=self.p, margin=self.margin, weight=self.weight, reduction=self.reduction) @@ -1205,7 +1222,7 @@ class TripletMarginLoss(_Loss): the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are instead summed for each minibatch. Ignored - when reduce is ``False``. Default: ``True`` + when :attr:`reduce` is ``False``. Default: ``True`` reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged or summed over observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per @@ -1268,7 +1285,7 @@ class TripletMarginWithDistanceLoss(_Loss): where :math:`N` is the batch size; :math:`d` is a nonnegative, real-valued function quantifying the closeness of two tensors, referred to as the :attr:`distance_function`; - and :math:`margin` is a non-negative margin representing the minimum difference + and :math:`margin` is a nonnegative margin representing the minimum difference between the positive and negative distances that is required for the loss to be 0. The input tensors have :math:`N` elements each and can be of any shape that the distance function can handle. @@ -1290,7 +1307,7 @@ class TripletMarginWithDistanceLoss(_Loss): distance_function (callable, optional): A nonnegative, real-valued function that quantifies the closeness of two tensors. If not specified, `nn.PairwiseDistance` will be used. Default: ``None`` - margin (float, optional): A non-negative margin representing the minimum difference + margin (float, optional): A nonnegative margin representing the minimum difference between the positive and negative distances required for the loss to be 0. Larger margins penalize cases where the negative examples are not distant enough from the anchors, relative to the positives. Default: :math:`1`. @@ -1315,9 +1332,9 @@ class TripletMarginWithDistanceLoss(_Loss): >>> # Initialize embeddings >>> embedding = nn.Embedding(1000, 128) - >>> anchor_ids = torch.randint(0, 1000, (1,), requires_grad=True) - >>> positive_ids = torch.randint(0, 1000, (1,), requires_grad=True) - >>> negative_ids = torch.randint(0, 1000, (1,), requires_grad=True) + >>> anchor_ids = torch.randint(0, 1000, (1,)) + >>> positive_ids = torch.randint(0, 1000, (1,)) + >>> negative_ids = torch.randint(0, 1000, (1,)) >>> anchor = embedding(anchor_ids) >>> positive = embedding(positive_ids) >>> negative = embedding(negative_ids) @@ -1355,7 +1372,8 @@ class TripletMarginWithDistanceLoss(_Loss): def __init__(self, *, distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, margin: float = 1.0, swap: bool = False, reduction: str = 'mean'): super(TripletMarginWithDistanceLoss, self).__init__(size_average=None, reduce=None, reduction=reduction) - self.distance_function = distance_function if distance_function is not None else PairwiseDistance() + self.distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = \ + distance_function if distance_function is not None else PairwiseDistance() self.margin = margin self.swap = swap diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 30e732e6d859f..a243c3b8006e2 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -1,14 +1,14 @@ from collections import OrderedDict, namedtuple -import functools import itertools import warnings +import functools import torch from ..parameter import Parameter import torch.utils.hooks as hooks from torch import Tensor, device, dtype -from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict +from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List from ...utils.hooks import RemovableHandle _grad_t = Union[Tuple[Tensor, ...], Tensor] @@ -26,14 +26,6 @@ def __repr__(self): __str__ = __repr__ -class ModuleAttributeError(AttributeError): - """ When `__getattr__` raises AttributeError inside a property, - AttributeError is raised with the property name instead of the - attribute that initially raised AttributeError, making the error - message uninformative. Using `ModuleAttributeError` instead - fixes this issue.""" - - def _addindent(s_, numSpaces): s = s_.split('\n') # don't do anything for single-line stuff @@ -49,9 +41,10 @@ def _addindent(s_, numSpaces): r"""This tracks hooks common to all modules that are executed before/after calling forward and backward. This is global state used for debugging/profiling purposes""" -_global_backward_hooks = OrderedDict() -_global_forward_pre_hooks = OrderedDict() -_global_forward_hooks = OrderedDict() +_global_backward_hooks: Dict[int, Callable] = OrderedDict() +_global_is_full_backward_hook: Optional[bool] = None +_global_forward_pre_hooks: Dict[int, Callable] = OrderedDict() +_global_forward_hooks: Dict[int, Callable] = OrderedDict() def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle: @@ -122,6 +115,31 @@ def register_module_backward_hook( ) -> RemovableHandle: r"""Registers a backward hook common to all the modules. + This function is deprecated in favor of :meth:`nn.module.register_module_full_backward_hook` + and the behavior of this function will change in future versions. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + global _global_is_full_backward_hook + if _global_is_full_backward_hook is True: + raise RuntimeError("Cannot use both regular backward hooks and full backward hooks as a " + "global Module hook. Please use only one of them.") + + _global_is_full_backward_hook = False + + handle = hooks.RemovableHandle(_global_backward_hooks) + _global_backward_hooks[handle.id] = hook + return handle + +def register_module_full_backward_hook( + hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]] +) -> RemovableHandle: + r"""Registers a backward hook common to all the modules. + .. warning :: This adds global state to the `nn.module` module and it is only intended for debugging/profiling purposes. @@ -138,12 +156,13 @@ def register_module_backward_hook( hook(module, grad_input, grad_output) -> Tensor or None - The :attr:`grad_input` and :attr:`grad_output` may be tuples if the - module has multiple inputs or outputs. The hook should not modify its - arguments, but it can optionally return a new gradient with respect to - input that will be used in place of :attr:`grad_input` in subsequent - computations. :attr:`grad_input` will only correspond to the inputs given - as positional arguments. + The :attr:`grad_input` and :attr:`grad_output` are tuples. The hook should + not modify its arguments, but it can optionally return a new gradient with + respect to the input that will be used in place of :attr:`grad_input` in + subsequent computations. :attr:`grad_input` will only correspond to the inputs given + as positional arguments and all kwarg arguments will not appear in the hook. Entries + in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor + arguments. Global hooks are called before hooks registered with `register_backward_hook` @@ -153,6 +172,13 @@ def register_module_backward_hook( ``handle.remove()`` """ + global _global_is_full_backward_hook + if _global_is_full_backward_hook is False: + raise RuntimeError("Cannot use both regular backward hooks and full backward hooks as a " + "global Module hook. Please use only one of them.") + + _global_is_full_backward_hook = True + handle = hooks.RemovableHandle(_global_backward_hooks) _global_backward_hooks[handle.id] = hook return handle @@ -219,6 +245,7 @@ def forward(self, x): _version: int = 1 training: bool + _is_full_backward_hook: Optional[bool] def __init__(self): """ @@ -231,6 +258,7 @@ def __init__(self): self._buffers = OrderedDict() self._non_persistent_buffers_set = set() self._backward_hooks = OrderedDict() + self._is_full_backward_hook = None self._forward_hooks = OrderedDict() self._forward_pre_hooks = OrderedDict() self._state_dict_hooks = OrderedDict() @@ -349,7 +377,7 @@ def add_module(self, name: str, module: Optional['Module']) -> None: elif hasattr(self, name) and name not in self._modules: raise KeyError("attribute '{}' already exists".format(name)) elif '.' in name: - raise KeyError("module name can't contain \".\"") + raise KeyError("module name can't contain \".\", got: {}".format(name)) elif name == '': raise KeyError("module name can't be empty string \"\"") self._modules[name] = module @@ -453,7 +481,7 @@ def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: it should be called before constructing optimizer if the module will live on GPU while being optimized. - Arguments: + Args: device (int, optional): if specified, all parameters will be copied to that device @@ -473,7 +501,7 @@ def cpu(self: T) -> T: def type(self: T, dst_type: Union[dtype, str]) -> T: r"""Casts all parameters and buffers to :attr:`dst_type`. - Arguments: + Args: dst_type (type or string): the desired type Returns: @@ -540,8 +568,8 @@ def to(self, *args, **kwargs): .. function:: to(memory_format=torch.channels_last) Its signature is similar to :meth:`torch.Tensor.to`, but only accepts - floating point desired :attr:`dtype` s. In addition, this method will - only cast the floating point parameters and buffers to :attr:`dtype` + floating point or complex :attr:`dtype`s. In addition, this method will + only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously @@ -556,8 +584,8 @@ def to(self, *args, **kwargs): Args: device (:class:`torch.device`): the desired device of the parameters and buffers in this module - dtype (:class:`torch.dtype`): the desired floating point type of - the floating point parameters and buffers in this module + dtype (:class:`torch.dtype`): the desired floating point or complex dtype of + the parameters and buffers in this module tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module memory_format (:class:`torch.memory_format`): the desired memory @@ -567,7 +595,7 @@ def to(self, *args, **kwargs): Returns: Module: self - Example:: + Examples:: >>> linear = nn.Linear(2, 2) >>> linear.weight @@ -595,19 +623,36 @@ def to(self, *args, **kwargs): tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) + >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) + >>> linear.weight + Parameter containing: + tensor([[ 0.3741+0.j, 0.2382+0.j], + [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) + >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) + tensor([[0.6122+0.j, 0.1150+0.j], + [0.6122+0.j, 0.1150+0.j], + [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) + """ device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) if dtype is not None: - if not dtype.is_floating_point: - raise TypeError('nn.Module.to only accepts floating point ' + if not (dtype.is_floating_point or dtype.is_complex): + raise TypeError('nn.Module.to only accepts floating point or complex ' 'dtypes, but got desired dtype={}'.format(dtype)) + if dtype.is_complex: + warnings.warn( + "Complex modules are a new feature under active development whose design may change, " + "and some modules might not work as expected when using complex tensors as parameters or buffers. " + "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.md " + "if a complex module does not work as expected.") def convert(t): if convert_to_format is not None and t.dim() == 4: - return t.to(device, dtype if t.is_floating_point() else None, non_blocking, memory_format=convert_to_format) - return t.to(device, dtype if t.is_floating_point() else None, non_blocking) + return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, + non_blocking, memory_format=convert_to_format) + return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking) return self._apply(convert) @@ -616,26 +661,47 @@ def register_backward_hook( ) -> RemovableHandle: r"""Registers a backward hook on the module. - .. warning :: + This function is deprecated in favor of :meth:`nn.Module.register_full_backward_hook` and + the behavior of this function will change in future versions. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + if self._is_full_backward_hook is True: + raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a " + "single Module. Please use only one of them.") - The current implementation will not have the presented behavior - for complex :class:`Module` that perform many operations. - In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only - contain the gradients for a subset of the inputs and outputs. - For such :class:`Module`, you should use :func:`torch.Tensor.register_hook` - directly on a specific input or output to get the required gradients. + self._is_full_backward_hook = False + + handle = hooks.RemovableHandle(self._backward_hooks) + self._backward_hooks[handle.id] = hook + return handle + + def register_full_backward_hook( + self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]] + ) -> RemovableHandle: + r"""Registers a backward hook on the module. The hook will be called every time the gradients with respect to module inputs are computed. The hook should have the following signature:: - hook(module, grad_input, grad_output) -> Tensor or None + hook(module, grad_input, grad_output) -> tuple(Tensor) or None - The :attr:`grad_input` and :attr:`grad_output` may be tuples if the - module has multiple inputs or outputs. The hook should not modify its - arguments, but it can optionally return a new gradient with respect to - input that will be used in place of :attr:`grad_input` in subsequent - computations. :attr:`grad_input` will only correspond to the inputs given - as positional arguments. + The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients + with respect to the inputs and outputs respectively. The hook should + not modify its arguments, but it can optionally return a new gradient with + respect to the input that will be used in place of :attr:`grad_input` in + subsequent computations. :attr:`grad_input` will only correspond to the inputs given + as positional arguments and all kwarg arguments are ignored. Entries + in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor + arguments. + + .. warning :: + Modifying inputs or outputs inplace is not allowed when using backward hooks and + will raise an error. Returns: :class:`torch.utils.hooks.RemovableHandle`: @@ -643,10 +709,78 @@ def register_backward_hook( ``handle.remove()`` """ + if self._is_full_backward_hook is False: + raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a " + "single Module. Please use only one of them.") + + self._is_full_backward_hook = True + handle = hooks.RemovableHandle(self._backward_hooks) self._backward_hooks[handle.id] = hook return handle + def _get_backward_hooks(self): + r"""Returns the backward hooks for use in the call function. + It returns two lists, one with the full backward hooks and one with the non-full + backward hooks. + """ + full_backward_hooks: List[Callable] = [] + if (_global_is_full_backward_hook is True): + full_backward_hooks += _global_backward_hooks.values() + if (self._is_full_backward_hook is True): + full_backward_hooks += self._backward_hooks.values() + + non_full_backward_hooks: List[Callable] = [] + if (_global_is_full_backward_hook is False): + non_full_backward_hooks += _global_backward_hooks.values() + if (self._is_full_backward_hook is False): + non_full_backward_hooks += self._backward_hooks.values() + + return full_backward_hooks, non_full_backward_hooks + + def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): + if not isinstance(result, torch.Tensor): + if not (isinstance(result, tuple) and all([isinstance(r, torch.Tensor) for r in result])): + warnings.warn("Using non-full backward hooks on a Module that does not return a " + "single Tensor or a tuple of Tensors is deprecated and will be removed " + "in future versions. This hook will be missing some of the grad_output. " + "Please use register_full_backward_hook to get the documented behavior.") + return + else: + result = (result,) + + if not isinstance(inputs, torch.Tensor): + if not (isinstance(inputs, tuple) and all([isinstance(i, torch.Tensor) for i in inputs])): + warnings.warn("Using non-full backward hooks on a Module that does not take as input a " + "single Tensor or a tuple of Tensors is deprecated and will be removed " + "in future versions. This hook will be missing some of the grad_input. " + "Please use register_full_backward_hook to get the documented behavior.") + return + else: + inputs = (inputs,) + + # At this point we are sure that inputs and result are tuple of Tensors + out_grad_fn = set([r.grad_fn for r in result if r.grad_fn is not None]) + if len(out_grad_fn) == 0 or (len(out_grad_fn) == 1 and grad_fn not in out_grad_fn): + warnings.warn("Using a non-full backward hook when outputs are nested in python data structure " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_output.") + elif len(out_grad_fn) > 1: + warnings.warn("Using a non-full backward hook when outputs are generated by different autograd Nodes " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_output. Please use register_full_backward_hook to get the documented behavior.") + else: + # At this point the grad_ouput part of the hook will most likely be correct + inputs_grad_fn = set([i.grad_fn for i in inputs if i.grad_fn is not None]) + + next_functions = set([n[0] for n in grad_fn.next_functions]) + + if inputs_grad_fn != next_functions: + warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_input. Please use register_full_backward_hook to get the documented " + "behavior.") + def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle: r"""Registers a forward pre-hook on the module. @@ -699,9 +833,10 @@ def _slow_forward(self, *input, **kwargs): return self.forward(*input, **kwargs) recording_scopes = torch.jit._trace._trace_module_map is not None if recording_scopes: - name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None + # type ignore was added because at this point one knows that + # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any] + name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None # type: ignore if name: - cur_scope_name = tracing_state.current_scope() tracing_state.push_scope(name) else: recording_scopes = False @@ -713,6 +848,11 @@ def _slow_forward(self, *input, **kwargs): return result def _call_impl(self, *input, **kwargs): + # Do not call functions when jit is used + full_backward_hooks, non_full_backward_hooks = [], [] + if len(self._backward_hooks) > 0 or len(_global_backward_hooks) > 0: + full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() + for hook in itertools.chain( _global_forward_pre_hooks.values(), self._forward_pre_hooks.values()): @@ -721,6 +861,12 @@ def _call_impl(self, *input, **kwargs): if not isinstance(result, tuple): result = (result,) input = result + + bw_hook = None + if len(full_backward_hooks) > 0: + bw_hook = hooks.BackwardHook(self, full_backward_hooks) + input = bw_hook.setup_input_hook(input) + if torch._C._get_tracing_state(): result = self._slow_forward(*input, **kwargs) else: @@ -731,7 +877,12 @@ def _call_impl(self, *input, **kwargs): hook_result = hook(self, input, result) if hook_result is not None: result = hook_result - if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0): + + if bw_hook: + result = bw_hook.setup_output_hook(result) + + # Handle the non-full backward hooks + if len(non_full_backward_hooks) > 0: var = result while not isinstance(var, torch.Tensor): if isinstance(var, dict): @@ -740,12 +891,12 @@ def _call_impl(self, *input, **kwargs): var = var[0] grad_fn = var.grad_fn if grad_fn is not None: - for hook in itertools.chain( - _global_backward_hooks.values(), - self._backward_hooks.values()): + for hook in non_full_backward_hooks: wrapper = functools.partial(hook, self) functools.update_wrapper(wrapper, hook) grad_fn.register_hook(wrapper) + self._maybe_warn_non_full_backward_hook(input, result, grad_fn) + return result __call__ : Callable[..., Any] = _call_impl @@ -761,6 +912,8 @@ def __setstate__(self, state): self._load_state_dict_pre_hooks = OrderedDict() if '_non_persistent_buffers_set' not in self.__dict__: self._non_persistent_buffers_set = set() + if '_is_full_backward_hook' not in self.__dict__: + self._is_full_backward_hook = None def __getattr__(self, name: str) -> Union[Tensor, 'Module']: if '_parameters' in self.__dict__: @@ -775,7 +928,7 @@ def __getattr__(self, name: str) -> Union[Tensor, 'Module']: modules = self.__dict__['_modules'] if name in modules: return modules[name] - raise ModuleAttributeError("'{}' object has no attribute '{}'".format( + raise AttributeError("'{}' object has no attribute '{}'".format( type(self).__name__, name)) def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None: @@ -855,7 +1008,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): In rare cases, subclasses can achieve class-specific behavior by overriding this method with custom logic. - Arguments: + Args: destination (dict): a dict where state will be stored prefix (str): the prefix for parameters and buffers used in this module @@ -936,7 +1089,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So it can be modified. - Arguments: + Args: state_dict (dict): a dict containing parameters and persistent buffers. prefix (str): the prefix for parameters and buffers used in this @@ -965,18 +1118,20 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, key = prefix + name if key in state_dict: input_param = state_dict[key] - + # This is used to avoid copying uninitialized parameters into + # non-lazy modules, since they dont have the hook to do the checks + # in such case, it will error when accessing the .shape attribute. + is_param_lazy = isinstance(param, torch.nn.parameter.UninitializedParameter) # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ - if len(param.shape) == 0 and len(input_param.shape) == 1: + if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: input_param = input_param[0] - if input_param.shape != param.shape: + if not is_param_lazy and input_param.shape != param.shape: # local shape should match the one in checkpoint error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' .format(key, input_param.shape, param.shape)) continue - try: with torch.no_grad(): param.copy_(input_param) @@ -997,14 +1152,14 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, if input_name not in self._modules and input_name not in local_state: unexpected_keys.append(key) - def load_state_dict(self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]], + def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', strict: bool = True): r"""Copies parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. - Arguments: + Args: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys @@ -1016,15 +1171,16 @@ def load_state_dict(self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor] * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys """ - missing_keys = [] - unexpected_keys = [] - error_msgs = [] + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + error_msgs: List[str] = [] # copy state_dict so _load_from_state_dict can modify it metadata = getattr(state_dict, '_metadata', None) state_dict = state_dict.copy() if metadata is not None: - state_dict._metadata = metadata + # mypy isn't aware that "_metadata" exists in state_dict + state_dict._metadata = metadata # type: ignore[attr-defined] def load(module, prefix=''): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) @@ -1035,7 +1191,7 @@ def load(module, prefix=''): load(child, prefix + name + '.') load(self) - load = None # break load->load reference cycle + del load if strict: if len(unexpected_keys) > 0: @@ -1089,7 +1245,7 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]: for name, param in self.named_parameters(recurse=recurse): yield param - def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]: + def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: r"""Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. @@ -1317,7 +1473,7 @@ def zero_grad(self, set_to_none: bool = False) -> None: r"""Sets gradients of all model parameters to zero. See similar function under :class:`torch.optim.Optimizer` for more context. - Arguments: + Args: set_to_none (bool): instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details. """ diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py index cd9613eb835f7..acdd2a8c9d419 100644 --- a/torch/nn/modules/normalization.py +++ b/torch/nn/modules/normalization.py @@ -7,7 +7,7 @@ from .. import init from torch import Tensor, Size -from typing import Union, List +from typing import Union, List, Tuple class LocalResponseNorm(Module): @@ -141,20 +141,21 @@ class LayerNorm(Module): >>> output = m(input) """ __constants__ = ['normalized_shape', 'eps', 'elementwise_affine'] - normalized_shape: _shape_t + normalized_shape: Tuple[int, ...] eps: float elementwise_affine: bool def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True) -> None: super(LayerNorm, self).__init__() if isinstance(normalized_shape, numbers.Integral): - normalized_shape = (normalized_shape,) - self.normalized_shape = tuple(normalized_shape) + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] self.eps = eps self.elementwise_affine = elementwise_affine if self.elementwise_affine: - self.weight = Parameter(torch.Tensor(*normalized_shape)) - self.bias = Parameter(torch.Tensor(*normalized_shape)) + self.weight = Parameter(torch.Tensor(*self.normalized_shape)) + self.bias = Parameter(torch.Tensor(*self.normalized_shape)) else: self.register_parameter('weight', None) self.register_parameter('bias', None) @@ -169,7 +170,7 @@ def forward(self, input: Tensor) -> Tensor: return F.layer_norm( input, self.normalized_shape, self.weight, self.bias, self.eps) - def extra_repr(self) -> Tensor: + def extra_repr(self) -> str: return '{normalized_shape}, eps={eps}, ' \ 'elementwise_affine={elementwise_affine}'.format(**self.__dict__) diff --git a/torch/nn/modules/padding.py b/torch/nn/modules/padding.py index 2ca9c19f79b47..186d89c6fd098 100644 --- a/torch/nn/modules/padding.py +++ b/torch/nn/modules/padding.py @@ -4,6 +4,7 @@ from torch import Tensor from ..common_types import _size_2_t, _size_4_t, _size_6_t +from typing import Sequence, Tuple # TODO: grad_output size asserts in THNN @@ -12,6 +13,7 @@ class _ConstantPadNd(Module): __constants__ = ['padding', 'value'] value: float + padding: Sequence[int] def __init__(self, value: float) -> None: super(_ConstantPadNd, self).__init__() @@ -67,7 +69,7 @@ class ConstantPad1d(_ConstantPadNd): [ 3.5000, 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000]]]) """ - padding: _size_2_t + padding: Tuple[int, int] def __init__(self, padding: _size_2_t, value: float): super(ConstantPad1d, self).__init__(value) @@ -117,7 +119,7 @@ class ConstantPad2d(_ConstantPadNd): """ __constants__ = ['padding', 'value'] - padding: _size_4_t + padding: Tuple[int, int, int, int] def __init__(self, padding: _size_4_t, value: float) -> None: super(ConstantPad2d, self).__init__(value) @@ -156,7 +158,7 @@ class ConstantPad3d(_ConstantPadNd): >>> output = m(input) """ - padding: _size_6_t + padding: Tuple[int, int, int, int, int, int] def __init__(self, padding: _size_6_t, value: float) -> None: super(ConstantPad3d, self).__init__(value) @@ -165,6 +167,7 @@ def __init__(self, padding: _size_6_t, value: float) -> None: class _ReflectionPadNd(Module): __constants__ = ['padding'] + padding: Sequence[int] def forward(self, input: Tensor) -> Tensor: return F.pad(input, self.padding, 'reflect') @@ -206,7 +209,7 @@ class ReflectionPad1d(_ReflectionPadNd): [7., 6., 5., 4., 5., 6., 7., 6.]]]) """ - padding: _size_2_t + padding: Tuple[int, int] def __init__(self, padding: _size_2_t) -> None: super(ReflectionPad1d, self).__init__() @@ -257,7 +260,7 @@ class ReflectionPad2d(_ReflectionPadNd): [7., 6., 7., 8., 7.]]]]) """ - padding: _size_4_t + padding: Tuple[int, int, int, int] def __init__(self, padding: _size_4_t) -> None: super(ReflectionPad2d, self).__init__() @@ -266,6 +269,7 @@ def __init__(self, padding: _size_4_t) -> None: class _ReplicationPadNd(Module): __constants__ = ['padding'] + padding: Sequence[int] def forward(self, input: Tensor) -> Tensor: return F.pad(input, self.padding, 'replicate') @@ -307,7 +311,7 @@ class ReplicationPad1d(_ReplicationPadNd): [4., 4., 4., 4., 5., 6., 7., 7.]]]) """ - padding: _size_2_t + padding: Tuple[int, int] def __init__(self, padding: _size_2_t) -> None: super(ReplicationPad1d, self).__init__() @@ -358,7 +362,7 @@ class ReplicationPad2d(_ReplicationPadNd): [6., 6., 7., 8., 8.]]]]) """ - padding: _size_4_t + padding: Tuple[int, int, int, int] def __init__(self, padding: _size_4_t) -> None: super(ReplicationPad2d, self).__init__() @@ -397,7 +401,7 @@ class ReplicationPad3d(_ReplicationPadNd): >>> output = m(input) """ - padding: _size_6_t + padding: Tuple[int, int, int, int, int, int] def __init__(self, padding: _size_6_t) -> None: super(ReplicationPad3d, self).__init__() @@ -448,7 +452,7 @@ class ZeroPad2d(ConstantPad2d): [ 0.0000, -0.9162, -0.5436, -0.6446, 0.0000]]]]) """ - padding: _size_4_t + padding: Tuple[int, int, int, int] def __init__(self, padding: _size_4_t) -> None: super(ZeroPad2d, self).__init__(padding, 0.) diff --git a/torch/nn/modules/pixelshuffle.py b/torch/nn/modules/pixelshuffle.py index 3c8c626047dcc..d17f5616c2e9d 100644 --- a/torch/nn/modules/pixelshuffle.py +++ b/torch/nn/modules/pixelshuffle.py @@ -6,12 +6,12 @@ class PixelShuffle(Module): r"""Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` - to a tensor of shape :math:`(*, C, H \times r, W \times r)`. + to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor. This is useful for implementing efficient sub-pixel convolution with a stride of :math:`1/r`. - Look at the paper: + See the paper: `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ by Shi et. al (2016) for more details. @@ -19,10 +19,17 @@ class PixelShuffle(Module): upscale_factor (int): factor to increase spatial resolution by Shape: - - Input: :math:`(N, L, H_{in}, W_{in})` where :math:`L=C \times \text{upscale\_factor}^2` - - Output: :math:`(N, C, H_{out}, W_{out})` where - :math:`H_{out} = H_{in} \times \text{upscale\_factor}` - and :math:`W_{out} = W_{in} \times \text{upscale\_factor}` + - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions + - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where + + .. math:: + C_{out} = C_{in} \div \text{upscale\_factor}^2 + + .. math:: + H_{out} = H_{in} \times \text{upscale\_factor} + + .. math:: + W_{out} = W_{in} \times \text{upscale\_factor} Examples:: @@ -47,3 +54,53 @@ def forward(self, input: Tensor) -> Tensor: def extra_repr(self) -> str: return 'upscale_factor={}'.format(self.upscale_factor) + + +class PixelUnshuffle(Module): + r"""Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements + in a tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape + :math:`(*, C \times r^2, H, W)`, where r is a downscale factor. + + See the paper: + `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ + by Shi et. al (2016) for more details. + + Args: + downscale_factor (int): factor to decrease spatial resolution by + + Shape: + - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions + - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where + + .. math:: + C_{out} = C_{in} \times \text{downscale\_factor}^2 + + .. math:: + H_{out} = H_{in} \div \text{downscale\_factor} + + .. math:: + W_{out} = W_{in} \div \text{downscale\_factor} + + Examples:: + + >>> pixel_unshuffle = nn.PixelUnshuffle(3) + >>> input = torch.randn(1, 1, 12, 12) + >>> output = pixel_unshuffle(input) + >>> print(output.size()) + torch.Size([1, 9, 4, 4]) + + .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network: + https://arxiv.org/abs/1609.05158 + """ + __constants__ = ['downscale_factor'] + downscale_factor: int + + def __init__(self, downscale_factor: int) -> None: + super(PixelUnshuffle, self).__init__() + self.downscale_factor = downscale_factor + + def forward(self, input: Tensor) -> Tensor: + return F.pixel_unshuffle(input, self.downscale_factor) + + def extra_repr(self) -> str: + return 'downscale_factor={}'.format(self.downscale_factor) diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index 734912684d8f9..78aae504083b0 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -5,7 +5,8 @@ from .utils import _single, _pair, _triple from .. import functional as F -from ..common_types import _size_any_t, _size_1_t, _size_2_t, _size_3_t, _ratio_3_t, _ratio_2_t +from ..common_types import (_size_any_t, _size_1_t, _size_2_t, _size_3_t, + _ratio_3_t, _ratio_2_t, _size_any_opt_t, _size_2_opt_t, _size_3_opt_t) class _MaxPoolNd(Module): @@ -45,6 +46,10 @@ class MaxPool1d(_MaxPoolNd): for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the sliding window. This `link`_ has a nice visualization of the pooling parameters. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + Args: kernel_size: The size of the sliding window, must be > 0. stride: The stride of the sliding window, must be > 0. Default value is :attr:`kernel_size`. @@ -104,6 +109,10 @@ class MaxPool2d(_MaxPoolNd): for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: - a single ``int`` -- in which case the same value is used for the height and width dimension @@ -174,6 +183,10 @@ class MaxPool3d(_MaxPoolNd): for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: - a single ``int`` -- in which case the same value is used for the depth, height and width dimension @@ -474,6 +487,10 @@ class AvgPool1d(_AvgPoolNd): If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides for :attr:`padding` number of points. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can each be an ``int`` or a one-element tuple. @@ -537,6 +554,10 @@ class AvgPool2d(_AvgPoolNd): If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides for :attr:`padding` number of points. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can either be: - a single ``int`` -- in which case the same value is used for the height and width dimension @@ -581,7 +602,7 @@ class AvgPool2d(_AvgPoolNd): count_include_pad: bool def __init__(self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, padding: _size_2_t = 0, - ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: bool = None) -> None: + ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> None: super(AvgPool2d, self).__init__() self.kernel_size = kernel_size self.stride = stride if (stride is not None) else kernel_size @@ -614,6 +635,10 @@ class AvgPool3d(_AvgPoolNd): If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides for :attr:`padding` number of points. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + The parameters :attr:`kernel_size`, :attr:`stride` can either be: - a single ``int`` -- in which case the same value is used for the depth, height and width dimension @@ -662,7 +687,7 @@ class AvgPool3d(_AvgPoolNd): count_include_pad: bool def __init__(self, kernel_size: _size_3_t, stride: Optional[_size_3_t] = None, padding: _size_3_t = 0, - ceil_mode: bool = False, count_include_pad: bool = True, divisor_override=None) -> None: + ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> None: super(AvgPool3d, self).__init__() self.kernel_size = kernel_size self.stride = stride if (stride is not None) else kernel_size @@ -929,7 +954,7 @@ class _AdaptiveMaxPoolNd(Module): __constants__ = ['output_size', 'return_indices'] return_indices: bool - def __init__(self, output_size: _size_any_t, return_indices: bool = False) -> None: + def __init__(self, output_size: _size_any_opt_t, return_indices: bool = False) -> None: super(_AdaptiveMaxPoolNd, self).__init__() self.output_size = output_size self.return_indices = return_indices @@ -996,7 +1021,7 @@ class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd): """ - output_size: _size_2_t + output_size: _size_2_opt_t def forward(self, input: Tensor) -> Tensor: return F.adaptive_max_pool2d(input, self.output_size, self.return_indices) @@ -1033,7 +1058,7 @@ class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd): """ - output_size: _size_3_t + output_size: _size_3_opt_t def forward(self, input: Tensor) -> Tensor: return F.adaptive_max_pool3d(input, self.output_size, self.return_indices) @@ -1042,7 +1067,7 @@ def forward(self, input: Tensor) -> Tensor: class _AdaptiveAvgPoolNd(Module): __constants__ = ['output_size'] - def __init__(self, output_size: _size_any_t) -> None: + def __init__(self, output_size: _size_any_opt_t) -> None: super(_AdaptiveAvgPoolNd, self).__init__() self.output_size = output_size @@ -1101,7 +1126,7 @@ class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd): """ - output_size: _size_2_t + output_size: _size_2_opt_t def forward(self, input: Tensor) -> Tensor: return F.adaptive_avg_pool2d(input, self.output_size) @@ -1135,7 +1160,7 @@ class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd): """ - output_size: _size_3_t + output_size: _size_3_opt_t def forward(self, input: Tensor) -> Tensor: return F.adaptive_avg_pool3d(input, self.output_size) diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index e6589b9ef1d9b..7b980c080dbe1 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -1,7 +1,7 @@ import math import warnings import numbers -from typing import List, Tuple, Optional, overload +from typing import List, Tuple, Optional, overload, Union import torch from torch import Tensor @@ -23,8 +23,8 @@ def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tens class RNNBase(Module): __constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias', - 'batch_first', 'dropout', 'bidirectional'] - __ignored_properties__ = ['all_weights'] + 'batch_first', 'dropout', 'bidirectional', 'proj_size'] + __jit_unused_properties__ = ['all_weights'] mode: str input_size: int @@ -34,10 +34,11 @@ class RNNBase(Module): batch_first: bool dropout: float bidirectional: bool + proj_size: int def __init__(self, mode: str, input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, batch_first: bool = False, - dropout: float = 0., bidirectional: bool = False) -> None: + dropout: float = 0., bidirectional: bool = False, proj_size: int = 0) -> None: super(RNNBase, self).__init__() self.mode = mode self.input_size = input_size @@ -47,6 +48,7 @@ def __init__(self, mode: str, input_size: int, hidden_size: int, self.batch_first = batch_first self.dropout = float(dropout) self.bidirectional = bidirectional + self.proj_size = proj_size num_directions = 2 if bidirectional else 1 if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \ @@ -59,6 +61,10 @@ def __init__(self, mode: str, input_size: int, hidden_size: int, "recurrent layer, so non-zero dropout expects " "num_layers greater than 1, but got dropout={} and " "num_layers={}".format(dropout, num_layers)) + if proj_size < 0: + raise ValueError("proj_size should be a positive integer or zero to disable projections") + if proj_size >= hidden_size: + raise ValueError("proj_size has to be smaller than hidden_size") if mode == 'LSTM': gate_size = 4 * hidden_size @@ -75,20 +81,34 @@ def __init__(self, mode: str, input_size: int, hidden_size: int, self._all_weights = [] for layer in range(num_layers): for direction in range(num_directions): - layer_input_size = input_size if layer == 0 else hidden_size * num_directions + real_hidden_size = proj_size if proj_size > 0 else hidden_size + layer_input_size = input_size if layer == 0 else real_hidden_size * num_directions w_ih = Parameter(torch.Tensor(gate_size, layer_input_size)) - w_hh = Parameter(torch.Tensor(gate_size, hidden_size)) + w_hh = Parameter(torch.Tensor(gate_size, real_hidden_size)) b_ih = Parameter(torch.Tensor(gate_size)) # Second bias vector included for CuDNN compatibility. Only one # bias vector is needed in standard definition. b_hh = Parameter(torch.Tensor(gate_size)) - layer_params = (w_ih, w_hh, b_ih, b_hh) + layer_params: Tuple[Tensor, ...] = () + if self.proj_size == 0: + if bias: + layer_params = (w_ih, w_hh, b_ih, b_hh) + else: + layer_params = (w_ih, w_hh) + else: + w_hr = Parameter(torch.Tensor(proj_size, hidden_size)) + if bias: + layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr) + else: + layer_params = (w_ih, w_hh, w_hr) suffix = '_reverse' if direction == 1 else '' param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}'] if bias: param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}'] + if self.proj_size > 0: + param_names += ['weight_hr_l{}{}'] param_names = [x.format(layer, suffix) for x in param_names] for name, param in zip(param_names, layer_params): @@ -146,10 +166,14 @@ def flatten_parameters(self) -> None: # an inplace operation on self._flat_weights with torch.no_grad(): if torch._use_cudnn_rnn_flatten_weight(): + num_weights = 4 if self.bias else 2 + if self.proj_size > 0: + num_weights += 1 torch._cudnn_rnn_flatten_weight( - self._flat_weights, (4 if self.bias else 2), - self.input_size, rnn.get_cudnn_mode(self.mode), self.hidden_size, self.num_layers, - self.batch_first, bool(self.bidirectional)) + self._flat_weights, num_weights, + self.input_size, rnn.get_cudnn_mode(self.mode), + self.hidden_size, self.proj_size, self.num_layers, # type: ignore + self.batch_first, bool(self.bidirectional)) # type: ignore def _apply(self, fn): ret = super(RNNBase, self)._apply(fn) @@ -181,13 +205,16 @@ def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None: def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]: if batch_sizes is not None: - mini_batch = batch_sizes[0] - mini_batch = int(mini_batch) + mini_batch = int(batch_sizes[0]) else: mini_batch = input.size(0) if self.batch_first else input.size(1) num_directions = 2 if self.bidirectional else 1 - expected_hidden_size = (self.num_layers * num_directions, - mini_batch, self.hidden_size) + if self.proj_size > 0: + expected_hidden_size = (self.num_layers * num_directions, + mini_batch, self.proj_size) + else: + expected_hidden_size = (self.num_layers * num_directions, + mini_batch, self.hidden_size) return expected_hidden_size def check_hidden_size(self, hx: Tensor, expected_hidden_size: Tuple[int, int, int], @@ -206,18 +233,21 @@ def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]): return hx return apply_permutation(hx, permutation) - def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: + def forward(self, + input: Union[Tensor, PackedSequence], + hx: Optional[Tensor] = None) -> Tuple[Union[Tensor, PackedSequence], Tensor]: is_packed = isinstance(input, PackedSequence) if is_packed: input, batch_sizes, sorted_indices, unsorted_indices = input - max_batch_size = batch_sizes[0] - max_batch_size = int(max_batch_size) + max_batch_size = int(batch_sizes[0]) else: + assert isinstance(input, Tensor) batch_sizes = None max_batch_size = input.size(0) if self.batch_first else input.size(1) sorted_indices = None unsorted_indices = None + assert isinstance(input, Tensor) if hx is None: num_directions = 2 if self.bidirectional else 1 hx = torch.zeros(self.num_layers * num_directions, @@ -228,6 +258,7 @@ def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, T # the user believes he/she is passing in. hx = self.permute_hidden(hx, sorted_indices) + assert hx is not None self.check_forward_args(input, hx, batch_sizes) _impl = _rnn_impls[self.mode] if batch_sizes is None: @@ -236,6 +267,8 @@ def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, T else: result = _impl(input, batch_sizes, hx, self._flat_weights, self.bias, self.num_layers, self.dropout, self.training, self.bidirectional) + + output: Union[Tensor, PackedSequence] output = result[0] hidden = result[1] @@ -245,6 +278,8 @@ def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, T def extra_repr(self) -> str: s = '{input_size}, {hidden_size}' + if self.proj_size != 0: + s += ', proj_size={proj_size}' if self.num_layers != 1: s += ', num_layers={num_layers}' if self.bias is not True: @@ -271,18 +306,27 @@ def __setstate__(self, d): for layer in range(num_layers): for direction in range(num_directions): suffix = '_reverse' if direction == 1 else '' - weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}'] + weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', + 'bias_hh_l{}{}', 'weight_hr_l{}{}'] weights = [x.format(layer, suffix) for x in weights] if self.bias: - self._all_weights += [weights] - self._flat_weights_names.extend(weights) + if self.proj_size > 0: + self._all_weights += [weights] + self._flat_weights_names.extend(weights) + else: + self._all_weights += [weights[:4]] + self._flat_weights_names.extend(weights[:4]) else: - self._all_weights += [weights[:2]] - self._flat_weights_names.extend(weights[:2]) + if self.proj_size > 0: + self._all_weights += [weights[:2]] + [weights[-1:]] + self._flat_weights_names.extend(weights[:2] + [weights[-1:]]) + else: + self._all_weights += [weights[:2]] + self._flat_weights_names.extend(weights[:2]) self._flat_weights = [(lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn) for wn in self._flat_weights_names] @property - def all_weights(self) -> List[Parameter]: + def all_weights(self) -> List[List[Parameter]]: return [[getattr(self, weight) for weight in weights] for weights in self._all_weights] def _replicate_for_data_parallel(self): @@ -394,6 +438,8 @@ class RNN(RNNBase): """ def __init__(self, *args, **kwargs): + if 'proj_size' in kwargs: + raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU") self.nonlinearity = kwargs.pop('nonlinearity', 'tanh') if self.nonlinearity == 'tanh': mode = 'RNN_TANH' @@ -446,6 +492,14 @@ class LSTM(RNNBase): dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random variable which is :math:`0` with probability :attr:`dropout`. + If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes + the LSTM cell in the following way. First, the dimension of :math:`h_t` will be changed from + ``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly). + Second, the output hidden state of each layer will be multiplied by a learnable projection + matrix: :math:`h_t = W_{hr}h_t`. Note that as a consequence of this, the output + of LSTM network will be of different shape as well. See Inputs/Outputs sections below for exact + dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128. + Args: input_size: The number of expected features in the input `x` hidden_size: The number of features in the hidden state `h` @@ -461,6 +515,7 @@ class LSTM(RNNBase): LSTM layer except the last layer, with dropout probability equal to :attr:`dropout`. Default: 0 bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False`` + proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0 Inputs: input, (h_0, c_0) - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features @@ -471,6 +526,8 @@ class LSTM(RNNBase): - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor containing the initial hidden state for each element in the batch. If the LSTM is bidirectional, num_directions should be 2, else it should be 1. + If ``proj_size > 0`` was specified, the shape has to be + `(num_layers * num_directions, batch, proj_size)`. - **c_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor containing the initial cell state for each element in the batch. @@ -481,14 +538,16 @@ class LSTM(RNNBase): - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor containing the output features `(h_t)` from the last layer of the LSTM, for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been - given as the input, the output will also be a packed sequence. + given as the input, the output will also be a packed sequence. If ``proj_size > 0`` + was specified, output shape will be `(seq_len, batch, num_directions * proj_size)`. For the unpacked case, the directions can be separated using ``output.view(seq_len, batch, num_directions, hidden_size)``, with forward and backward being direction `0` and `1` respectively. Similarly, the directions can be separated in the packed case. - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor - containing the hidden state for `t = seq_len`. + containing the hidden state for `t = seq_len`. If ``proj_size > 0`` + was specified, ``h_n`` shape will be `(num_layers * num_directions, batch, proj_size)`. Like *output*, the layers can be separated using ``h_n.view(num_layers, num_directions, batch, hidden_size)`` and similarly for *c_n*. @@ -500,11 +559,15 @@ class LSTM(RNNBase): `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`. Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)` weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer - `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)` + `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`. If ``proj_size > 0`` + was specified, the shape will be `(4*hidden_size, proj_size)`. bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)` bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)` + weight_hr_l[k] : the learnable projection weights of the :math:`\text{k}^{th}` layer + of shape `(proj_size, hidden_size)`. Only present when ``proj_size > 0`` was + specified. .. note:: All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` @@ -526,26 +589,39 @@ class LSTM(RNNBase): def __init__(self, *args, **kwargs): super(LSTM, self).__init__('LSTM', *args, **kwargs) - def check_forward_args(self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]): - self.check_input(input, batch_sizes) - expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) + def get_expected_cell_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]: + if batch_sizes is not None: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + num_directions = 2 if self.bidirectional else 1 + expected_hidden_size = (self.num_layers * num_directions, + mini_batch, self.hidden_size) + return expected_hidden_size - self.check_hidden_size(hidden[0], expected_hidden_size, + # In the future, we should prevent mypy from applying contravariance rules here. + # See torch/nn/modules/module.py::_forward_unimplemented + def check_forward_args(self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]): # type: ignore + self.check_input(input, batch_sizes) + self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes), 'Expected hidden[0] size {}, got {}') - self.check_hidden_size(hidden[1], expected_hidden_size, + self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes), 'Expected hidden[1] size {}, got {}') - def permute_hidden(self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]) -> Tuple[Tensor, Tensor]: + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + def permute_hidden(self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]) -> Tuple[Tensor, Tensor]: # type: ignore if permutation is None: return hx return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation) - @overload + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + @overload # type: ignore @torch._jit_internal._overload_method # noqa: F811 def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: # noqa: F811 pass + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented @overload @torch._jit_internal._overload_method # noqa: F811 def forward(self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None @@ -567,10 +643,14 @@ def forward(self, input, hx=None): # noqa: F811 if hx is None: num_directions = 2 if self.bidirectional else 1 - zeros = torch.zeros(self.num_layers * num_directions, - max_batch_size, self.hidden_size, - dtype=input.dtype, device=input.device) - hx = (zeros, zeros) + real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size + h_zeros = torch.zeros(self.num_layers * num_directions, + max_batch_size, real_hidden_size, + dtype=input.dtype, device=input.device) + c_zeros = torch.zeros(self.num_layers * num_directions, + max_batch_size, self.hidden_size, + dtype=input.dtype, device=input.device) + hx = (h_zeros, c_zeros) else: # Each batch of the hidden state should match the input sequence that # the user believes he/she is passing in. @@ -699,9 +779,11 @@ class GRU(RNNBase): """ def __init__(self, *args, **kwargs): + if 'proj_size' in kwargs: + raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU") super(GRU, self).__init__('GRU', *args, **kwargs) - @overload + @overload # type: ignore @torch._jit_internal._overload_method # noqa: F811 def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: # noqa: F811 pass diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py index b6997ca7701ab..1762297189fcf 100644 --- a/torch/nn/modules/sparse.py +++ b/torch/nn/modules/sparse.py @@ -34,7 +34,7 @@ class Embedding(Module): initialized from :math:`\mathcal{N}(0, 1)` Shape: - - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract + - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` .. note:: @@ -50,6 +50,23 @@ class Embedding(Module): output. The gradient for this vector from :class:`~torch.nn.Embedding` is always zero. + .. note:: + When :attr:`max_norm` is not ``None``, :class:`Embedding`'s forward method will modify the + :attr:`weight` tensor in-place. Since tensors needed for gradient computations cannot be + modified in-place, performing a differentiable operation on ``Embedding.weight`` before + calling :class:`Embedding`'s forward method requires cloning ``Embedding.weight`` when + :attr:`max_norm` is not ``None``. For example:: + + n, d, m = 3, 5, 7 + embedding = nn.Embedding(n, d, max_norm=True) + W = torch.randn((m, d), requires_grad=True) + idx = torch.tensor([1, 2]) + a = embedding.weight.clone() @ W.t() # weight must be cloned for this to be differentiable + b = embedding(idx) @ W.t() # modifies weight in-place + out = (a.unsqueeze(0) + b.unsqueeze(1)) + loss = out.sigmoid().prod() + loss.backward() + Examples:: >>> # an Embedding module containing 10 tensors of size 3 @@ -82,8 +99,8 @@ class Embedding(Module): num_embeddings: int embedding_dim: int - padding_idx: int - max_norm: float + padding_idx: Optional[int] + max_norm: Optional[float] norm_type: float scale_grad_by_freq: bool weight: Tensor @@ -112,10 +129,14 @@ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optiona assert list(_weight.shape) == [num_embeddings, embedding_dim], \ 'Shape of weight does not match num_embeddings and embedding_dim' self.weight = Parameter(_weight) + self._fill_padding_idx_with_zero() self.sparse = sparse def reset_parameters(self) -> None: init.normal_(self.weight) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: if self.padding_idx is not None: with torch.no_grad(): self.weight[self.padding_idx].fill_(0) @@ -186,11 +207,11 @@ class EmbeddingBag(Module): r"""Computes sums or means of 'bags' of embeddings, without instantiating the intermediate embeddings. - For bags of constant length and no :attr:`per_sample_weights`, this class + For bags of constant length and no :attr:`per_sample_weights` and 2D inputs, this class - * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=0)``, - * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=0)``, - * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=0)``. + * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=1)``, + * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=1)``, + * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=1)``. However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these operations. @@ -225,9 +246,11 @@ class EmbeddingBag(Module): weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)` initialized from :math:`\mathcal{N}(0, 1)`. - Inputs: :attr:`input` (LongTensor), :attr:`offsets` (LongTensor, optional), and + Inputs: :attr:`input` (IntTensor or LongTensor), :attr:`offsets` (IntTensor or LongTensor, optional), and :attr:`per_index_weights` (Tensor, optional) + - :attr:`input` and :attr:`offsets` have to be of the same type, either int or long + - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences) each of fixed length ``N``, and @@ -267,7 +290,7 @@ class EmbeddingBag(Module): num_embeddings: int embedding_dim: int - max_norm: float + max_norm: Optional[float] norm_type: float scale_grad_by_freq: bool weight: Tensor diff --git a/torch/nn/modules/utils.py b/torch/nn/modules/utils.py index 3e0b93c7afc02..0a07551000215 100644 --- a/torch/nn/modules/utils.py +++ b/torch/nn/modules/utils.py @@ -1,7 +1,8 @@ -from typing import List from torch._six import container_abcs + from itertools import repeat +from typing import List def _ntuple(n): @@ -26,8 +27,7 @@ def _reverse_repeat_tuple(t, n): return tuple(x for x in reversed(t) for _ in range(n)) -def _list_with_default(out_size, defaults): - # type: (List[int], List[int]) -> List[int] +def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]: if isinstance(out_size, int): return out_size if len(defaults) <= len(out_size): diff --git a/torch/nn/parallel/_functions.py b/torch/nn/parallel/_functions.py index 55a9d8c4db6ae..dd42d9a05dfbd 100644 --- a/torch/nn/parallel/_functions.py +++ b/torch/nn/parallel/_functions.py @@ -4,16 +4,17 @@ from . import comm from torch.autograd import Function from torch._utils import _get_device_index +from typing import List, Optional class Broadcast(Function): @staticmethod def forward(ctx, target_gpus, *inputs): - assert all(map(lambda i: i.device.type != 'cpu', inputs)), ( + assert all(i.device.type != 'cpu' for i in inputs), ( 'Broadcast function not implemented for CPU tensors' ) - target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus)) + target_gpus = [_get_device_index(x, True) for x in target_gpus] ctx.target_gpus = target_gpus if len(inputs) == 0: return tuple() @@ -39,9 +40,9 @@ class ReduceAddCoalesced(Function): def forward(ctx, destination, num_inputs, *grads): ctx.target_gpus = [grads[i].get_device() for i in range(0, len(grads), num_inputs)] - grads = [grads[i:i + num_inputs] - for i in range(0, len(grads), num_inputs)] - return comm.reduce_add_coalesced(grads, destination) + grads_ = [grads[i:i + num_inputs] + for i in range(0, len(grads), num_inputs)] + return comm.reduce_add_coalesced(grads_, destination) @staticmethod def backward(ctx, *grad_outputs): @@ -52,13 +53,13 @@ class Gather(Function): @staticmethod def forward(ctx, target_device, dim, *inputs): - assert all(map(lambda i: i.device.type != 'cpu', inputs)), ( + assert all(i.device.type != 'cpu' for i in inputs), ( 'Gather function not implemented for CPU tensors' ) target_device = _get_device_index(target_device, True) ctx.target_device = target_device ctx.dim = dim - ctx.input_gpus = tuple(map(lambda i: i.get_device(), inputs)) + ctx.input_gpus = tuple(i.get_device() for i in inputs) if all(t.dim() == 0 for t in inputs) and dim == 0: inputs = tuple(t.view(1) for t in inputs) warnings.warn('Was asked to gather along dimension 0, but all ' @@ -67,7 +68,7 @@ def forward(ctx, target_device, dim, *inputs): ctx.unsqueezed_scalar = True else: ctx.unsqueezed_scalar = False - ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs)) + ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs) return comm.gather(inputs, ctx.dim, ctx.target_device) @staticmethod @@ -82,7 +83,7 @@ class Scatter(Function): @staticmethod def forward(ctx, target_gpus, chunk_sizes, dim, input): - target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus)) + target_gpus = [_get_device_index(x, True) for x in target_gpus] ctx.dim = dim ctx.input_device = input.get_device() if input.device.type != "cpu" else -1 streams = None @@ -105,10 +106,10 @@ def backward(ctx, *grad_output): # background streams used for copying -_streams = None +_streams: Optional[List[Optional[torch.cuda.Stream]]] = None -def _get_stream(device): +def _get_stream(device: int): """Gets a background stream for copying between CPU and GPU""" global _streams if device == -1: diff --git a/torch/nn/parallel/comm.py b/torch/nn/parallel/comm.py index f38d7fcaafc48..587f88eb1b995 100644 --- a/torch/nn/parallel/comm.py +++ b/torch/nn/parallel/comm.py @@ -2,13 +2,13 @@ import torch from torch.cuda import nccl from torch._utils import _take_tensors, _flatten_dense_tensors, \ - _unflatten_dense_tensors, _reorder_tensors_as, _get_device_index - + _unflatten_dense_tensors, _reorder_tensors_as, _get_device_index, _handle_complex +from typing import List def broadcast(tensor, devices=None, *, out=None): r"""Broadcasts a tensor to specified GPU devices. - Arguments: + Args: tensor (Tensor): tensor to broadcast. Can be on CPU or GPU. devices (Iterable[torch.device, str or int], optional): an iterable of GPU devices, among which to broadcast. @@ -26,6 +26,7 @@ def broadcast(tensor, devices=None, *, out=None): a tuple containing :attr:`out` tensors, each containing a copy of :attr:`tensor`. """ + tensor = _handle_complex(tensor) if not ((devices is None) ^ (out is None)): raise RuntimeError( "Exactly one of 'devices' and 'out' must be specified, but got " @@ -42,7 +43,7 @@ def broadcast_coalesced(tensors, devices, buffer_size=10485760): Small tensors are first coalesced into a buffer to reduce the number of synchronizations. - Arguments: + Args: tensors (sequence): tensors to broadcast. Must be on the same device, either CPU or GPU. devices (Iterable[torch.device, str or int]): an iterable of GPU @@ -53,6 +54,7 @@ def broadcast_coalesced(tensors, devices, buffer_size=10485760): A tuple containing copies of :attr:`tensor`, placed on :attr:`devices`. """ devices = [_get_device_index(d) for d in devices] + tensors = [_handle_complex(t) for t in tensors] return torch._C._broadcast_coalesced(tensors, devices, buffer_size) @@ -62,7 +64,7 @@ def reduce_add(inputs, destination=None): All inputs should have matching shapes, dtype, and layout. The output tensor will be of the same shape, dtype, and layout. - Arguments: + Args: inputs (Iterable[Tensor]): an iterable of tensors to add. destination (int, optional): a device on which the output will be placed (default: current device). @@ -108,7 +110,7 @@ def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760): Small tensors are first coalesced into a buffer to reduce the number of synchronizations. - Arguments: + Args: inputs (Iterable[Iterable[Tensor]]): iterable of iterables that contain tensors from a single device. destination (int, optional): a device on which the output will be @@ -121,7 +123,7 @@ def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760): """ # TODO: When `len(inputs) == 1` and all inputs are on `destination`, just # return `inputs`. - dense_tensors = [[] for _ in inputs] # shape (num_gpus, num_tensors) + dense_tensors: List[List] = [[] for _ in inputs] # shape (num_gpus, num_tensors) output = [] ref_order = [] # process sparse ones first since they may have different sizes on different gpus @@ -150,7 +152,7 @@ def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760): def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=None): """Scatters tensor across multiple GPUs. - Arguments: + Args: tensor (Tensor): tensor to scatter. Can be on CPU or GPU. devices (Iterable[torch.device, str or int], optional): an iterable of GPU devices, among which to scatter. @@ -181,6 +183,7 @@ def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out= a tuple containing :attr:`out` tensors, each containing a chunk of :attr:`tensor`. """ + tensor = _handle_complex(tensor) if out is None: devices = [_get_device_index(d) for d in devices] return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams)) @@ -195,10 +198,11 @@ def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out= "but got chunk_sizes={}".format(chunk_sizes)) return tuple(torch._C._scatter_out(tensor, out, dim, streams)) + def gather(tensors, dim=0, destination=None, *, out=None): r"""Gathers tensors from multiple GPU devices. - Arguments: + Args: tensors (Iterable[Tensor]): an iterable of tensors to gather. Tensor sizes in all dimensions other than :attr:`dim` have to match. dim (int, optional): a dimension along which the tensors will be @@ -221,6 +225,7 @@ def gather(tensors, dim=0, destination=None, *, out=None): the :attr:`out` tensor, now containing results of concatenating :attr:`tensors` along :attr:`dim`. """ + tensors = [_handle_complex(t) for t in tensors] if out is None: if destination == -1: warnings.warn( diff --git a/torch/nn/parallel/data_parallel.py b/torch/nn/parallel/data_parallel.py index b66c1513ad86c..5dd438df53bfb 100644 --- a/torch/nn/parallel/data_parallel.py +++ b/torch/nn/parallel/data_parallel.py @@ -19,7 +19,7 @@ def _check_balance(device_ids): has less than 75% of the memory or cores of GPU {}. You can do so by setting the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES environment variable.""" - device_ids = list(map(lambda x: _get_device_index(x, True), device_ids)) + device_ids = [_get_device_index(x, True) for x in device_ids] dev_props = _get_devices_properties(device_ids) def warn_imbalance(get_prop): @@ -135,7 +135,7 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0): self.dim = dim self.module = module - self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids)) + self.device_ids = [_get_device_index(x, True) for x in device_ids] self.output_device = _get_device_index(output_device, True) self.src_device_obj = torch.device(device_type, self.device_ids[0]) @@ -155,6 +155,12 @@ def forward(self, *inputs, **kwargs): "them on device: {}".format(self.src_device_obj, t.device)) inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + # for forward function without any inputs, empty list and dict will be created + # so the module can be executed on one device which is the first one in device_ids + if not inputs and not kwargs: + inputs = ((),) + kwargs = ({},) + if len(self.device_ids) == 1: return self.module(*inputs[0], **kwargs[0]) replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) @@ -190,7 +196,7 @@ def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, mo output_device """ if not isinstance(inputs, tuple): - inputs = (inputs,) + inputs = (inputs,) if inputs is not None else () device_type = _get_available_device_type() @@ -200,7 +206,7 @@ def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, mo if output_device is None: output_device = device_ids[0] - device_ids = list(map(lambda x: _get_device_index(x, True), device_ids)) + device_ids = [_get_device_index(x, True) for x in device_ids] output_device = _get_device_index(output_device, True) src_device_obj = torch.device(device_type, device_ids[0]) @@ -211,6 +217,12 @@ def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, mo "them on device: {}".format(src_device_obj, t.device)) inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim) + # for module without any inputs, empty list and dict will be created + # so the module can be executed on one device which is the first one in device_ids + if not inputs and not module_kwargs: + inputs = ((),) + module_kwargs = ({},) + if len(device_ids) == 1: return module(*inputs[0], **module_kwargs[0]) used_device_ids = device_ids[:len(inputs)] diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 5ec2b0148a215..bb58105ead597 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -5,18 +5,23 @@ import inspect import logging import warnings +from typing import NamedTuple import torch from . import comm import torch.distributed as dist +RPC_AVAILABLE = False if dist.is_available(): from torch.distributed.distributed_c10d import _get_default_group from torch.distributed.distributed_c10d import ReduceOp +if torch.distributed.rpc.is_available(): + RPC_AVAILABLE = True + from torch.distributed.rpc import RRef from ..modules import Module from .replicate import replicate -from .scatter_gather import scatter_kwargs, gather +from .scatter_gather import scatter_kwargs, gather, is_namedtuple from .parallel_apply import parallel_apply from torch._utils import _get_device_index, _get_all_device_indices @@ -25,6 +30,12 @@ def _find_tensors(obj): r""" Recursively find all tensors contained in the specified object. """ + if RPC_AVAILABLE and isinstance(obj, RRef): + # If the current node is the owner of the RRef, unwrap it and try to + # find Tensors. + # TODO: Expand to remote RRefs. + if obj.is_owner(): + return _find_tensors(obj.local_value()) if isinstance(obj, torch.Tensor): return [obj] if isinstance(obj, (list, tuple)): @@ -90,6 +101,12 @@ def _dump_DDP_relevant_env_vars(): print(formatted_output) + +class _DDPUnevenInputsConfig(NamedTuple): + ddp_join_enabled: bool + ddp_join_divide_by_initial_world_size: bool + + class DistributedDataParallel(Module): r"""Implements distributed data parallelism that is based on ``torch.distributed`` package at the module level. @@ -112,37 +129,36 @@ class DistributedDataParallel(Module): :class:`torch.nn.DataParallel` for single-node multi-GPU data parallel training. - Here is how to use it: on each host with N GPUs, you should spawn up N - processes, while ensuring that each process individually works on a single GPU - from 0 to N-1. Therefore, it is your job to ensure that your training script - operates on a single given GPU by calling: + To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn + up ``N`` processes, ensuring that each process exclusively works on a single + GPU from 0 to N-1. This can be done by either setting + ``CUDA_VISIBLE_DEVICES`` for every process or by calling: >>> torch.cuda.set_device(i) where i is from 0 to N-1. In each process, you should refer the following to construct this module: - >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...') + >>> torch.distributed.init_process_group( + >>> backend='nccl', world_size=N, init_method='...' + >>> ) >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i) In order to spawn up multiple processes per node, you can use either - ``torch.distributed.launch`` or ``torch.multiprocessing.spawn`` + ``torch.distributed.launch`` or ``torch.multiprocessing.spawn``. .. note :: Please refer to `PyTorch Distributed Overview `__ for a brief introduction to all features related to distributed training. - .. note:: ``nccl`` backend is currently the fastest and - highly recommended backend to be used with Multi-Process Single-GPU - distributed training and this applies to both single-node and multi-node - distributed training + .. note:: ``nccl`` backend is currently the fastest and highly recommended + backend when using GPUs. This applies to both single-node and + multi-node distributed training. .. note:: This module also supports mixed-precision distributed training. This means that your model can have different types of parameters such - as mixed types of fp16 and fp32, the gradient reduction on these + as mixed types of ``fp16`` and ``fp32``, the gradient reduction on these mixed types of parameters will just work fine. - Also note that ``nccl`` backend is currently the fastest and highly - recommended backend for fp16/fp32 mixed-precision training. .. note:: If you use ``torch.save`` on one process to checkpoint the module, and ``torch.load`` on some other processes to recover it, make sure that @@ -152,19 +168,63 @@ class DistributedDataParallel(Module): .. note:: When a model is trained on ``M`` nodes with ``batch=N``, the gradient will be ``M`` times smaller when compared to the same model - trained on a single node with ``batch=M*N`` (because the gradients + trained on a single node with ``batch=M*N`` if the loss is summed (NOT + averaged as usual) across instances in a batch (because the gradients between different nodes are averaged). You should take this into consideration when you want to obtain a mathematically equivalent - training process compared to the non-DistributedDataParallel - counterpart. + training process compared to the local training counterpart. But in most + cases, you can just treat a DistributedDataParallel wrapped model, a + DataParallel wrapped model and an ordinary model on a single GPU as the + same (E.g. using the same learning rate for equivalent batch size). - .. warning:: - This module works only with the ``gloo`` and ``nccl`` backends. + .. note:: + Parameters are never broadcast between processes. The module performs + an all-reduce step on gradients and assumes that they will be modified + by the optimizer in all processes in the same way. Buffers + (e.g. BatchNorm stats) are broadcast from the module in process of rank + 0, to all other replicas in the system in every iteration. + + .. note:: + If you are using DistributedDataParallel in conjunction with the + :ref:`distributed-rpc-framework`, you should always use + :meth:`torch.distributed.autograd.backward` to compute gradients and + :class:`torch.distributed.optim.DistributedOptimizer` for optimizing + parameters. + + Example:: + + >>> import torch.distributed.autograd as dist_autograd + >>> from torch.nn.parallel import DistributedDataParallel as DDP + >>> from torch import optim + >>> from torch.distributed.optim import DistributedOptimizer + >>> from torch.distributed.rpc import RRef + >>> + >>> t1 = torch.rand((3, 3), requires_grad=True) + >>> t2 = torch.rand((3, 3), requires_grad=True) + >>> rref = rpc.remote("worker1", torch.add, args=(t1, t2)) + >>> ddp_model = DDP(my_model) + >>> + >>> # Setup optimizer + >>> optimizer_params = [rref] + >>> for param in ddp_model.parameters(): + >>> optimizer_params.append(RRef(param)) + >>> + >>> dist_optim = DistributedOptimizer( + >>> optim.SGD, + >>> optimizer_params, + >>> lr=0.05, + >>> ) + >>> + >>> with dist_autograd.context() as context_id: + >>> pred = ddp_model(rref.to_here()) + >>> loss = loss_func(pred, loss) + >>> dist_autograd.backward(context_id, loss) + >>> dist_optim.step() .. warning:: Constructor, forward method, and differentiation of the output (or a - function of the output of this module) is a distributed synchronization - point. Take that into account in case different processes might be + function of the output of this module) are distributed synchronization + points. Take that into account in case different processes might be executing different code. .. warning:: @@ -175,7 +235,7 @@ class DistributedDataParallel(Module): .. warning:: This module assumes all parameters are registered in the model of each distributed processes are in the same order. The module itself will - conduct gradient all-reduction following the reverse order of the + conduct gradient ``allreduce`` following the reverse order of the registered parameters of the model. In other words, it is users' responsibility to ensure that each distributed process has the exact same model and thus the exact same parameter registration order. @@ -208,144 +268,89 @@ class DistributedDataParallel(Module): .. warning:: You should never try to change your model's parameters after wrapping - up your model with DistributedDataParallel. In other words, when - wrapping up your model with DistributedDataParallel, the constructor of - DistributedDataParallel will register the additional gradient + up your model with ``DistributedDataParallel``. Because, when + wrapping up your model with ``DistributedDataParallel``, the constructor + of ``DistributedDataParallel`` will register the additional gradient reduction functions on all the parameters of the model itself at the - time of construction. If you change the model's parameters after - the DistributedDataParallel construction, this is not supported and - unexpected behaviors can happen, since some parameters' gradient - reduction functions might not get called. - - .. note:: - Parameters are never broadcast between processes. The module performs - an all-reduce step on gradients and assumes that they will be modified - by the optimizer in all processes in the same way. Buffers - (e.g. BatchNorm stats) are broadcast from the module in process of rank - 0, to all other replicas in the system in every iteration. - - .. note:: - If you are using DistributedDataParallel in conjunction with the - :ref:`distributed-rpc-framework`, you should always use - :meth:`torch.distributed.autograd.backward` to compute gradients and - :class:`torch.distributed.optim.DistributedOptimizer` for optimizing + time of construction. If you change the model's parameters afterwards, + gradient redunction functions no longer match the correct set of parameters. - Example:: - - >>> import torch.distributed.autograd as dist_autograd - >>> from torch.nn.parallel import DistributedDataParallel as DDP - >>> from torch import optim - >>> from torch.distributed.optim import DistributedOptimizer - >>> from torch.distributed.rpc import RRef - >>> - >>> t1 = torch.rand((3, 3), requires_grad=True) - >>> t2 = torch.rand((3, 3), requires_grad=True) - >>> rref = rpc.remote("worker1", torch.add, args=(t1, t2)) - >>> ddp_model = DDP(my_model) - >>> - >>> # Setup optimizer - >>> optimizer_params = [rref] - >>> for param in ddp_model.parameters(): - >>> optimizer_params.append(RRef(param)) - >>> - >>> dist_optim = DistributedOptimizer( - >>> optim.SGD, - >>> optimizer_params, - >>> lr=0.05, - >>> ) - >>> - >>> with dist_autograd.context() as context_id: - >>> pred = ddp_model(rref.to_here()) - >>> loss = loss_func(pred, loss) - >>> dist_autograd.backward(context_id, loss) - >>> dist_optim.step() - .. warning:: - Using DistributedDataParallel in conjuction with the + Using ``DistributedDataParallel`` in conjunction with the :ref:`distributed-rpc-framework` is experimental and subject to change. + .. warning:: + The ``gradient_as_bucket_view`` mode does not yet work with Automatic + Mixed Precision (AMP). AMP maintains stashed gradients that are used for + unscaling gradients. With ``gradient_as_bucket_view=True``, these + stashed gradients will point to communication buckets in the first + iteration. In the next iteration, the communication buckets are mutated + and thus these stashed gradients will be unexpectedly mutated as well, + which might lead to wrong results. + Args: module (Module): module to be parallelized device_ids (list of int or torch.device): CUDA devices. This should only be provided when the input module resides on a single - CUDA device. For single-device modules, the ``i``th + CUDA device. For single-device modules, the i'th :attr:`module` replica is placed on ``device_ids[i]``. For - multi-device modules and CPU modules, device_ids must be None - or an empty list, and input data for the forward pass must be - placed on the correct device. (default: all devices for - single-device modules) - output_device (int or torch.device): device location of output for + multi-device modules and CPU modules, ``device_ids`` must be + ``None`` or an empty list, and input data for the forward + pass must be placed on the correct device. (default: all + visible devices for single-device modules) + output_device (int or torch.device): Device location of output for single-device CUDA modules. For multi-device modules and - CPU modules, it must be None, and the module itself - dictates the output location. (default: device_ids[0] for - single-device modules) - broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of - the module at beginning of the forward function. - (default: ``True``) - process_group: the process group to be used for distributed data + CPU modules, it must be ``None``, and the module itself + dictates the output location. (default: ``device_ids[0]`` + for single-device modules) + broadcast_buffers (bool): Flag that enables syncing (broadcasting) + buffers of the module at beginning of the ``forward`` + function. (default: ``True``) + process_group: The process group to be used for distributed data all-reduction. If ``None``, the default process group, which - is created by ```torch.distributed.init_process_group```, + is created by :func:`torch.distributed.init_process_group`, will be used. (default: ``None``) - bucket_cap_mb: DistributedDataParallel will bucket parameters into + bucket_cap_mb: ``DistributedDataParallel`` will bucket parameters into multiple buckets so that gradient reduction of each bucket can potentially overlap with backward computation. - :attr:`bucket_cap_mb` controls the bucket size in MegaBytes (MB) - (default: 25) - find_unused_parameters (bool): Traverse the autograd graph of all tensors - contained in the return value of the wrapped - module's ``forward`` function. - Parameters that don't receive gradients as - part of this graph are preemptively marked - as being ready to be reduced. Note that all - ``forward`` outputs that are derived from - module parameters must participate in - calculating loss and later the gradient - computation. If they don't, this wrapper will - hang waiting for autograd to produce gradients - for those parameters. Any outputs derived from - module parameters that are otherwise unused can - be detached from the autograd graph using - ``torch.Tensor.detach``. (default: ``False``) - check_reduction: when setting to ``True``, it enables DistributedDataParallel - to automatically check if the previous iteration's - backward reductions were successfully issued at the - beginning of every iteration's forward function. - You normally don't need this option enabled unless you - are observing weird behaviors such as different ranks - are getting different gradients, which should not - happen if DistributedDataParallel is correctly used. - (default: ``False``) - gradient_as_bucket_view (bool): this is a prototype feature. When set to ``True``, - gradients will be views pointing to different offsets of - allreduce communication buckets. This can reduce peak memory - usage, where the saved memory size will be equal to the total - gradients size. Moreover, it avoids the overhead of copying - between gradients and allreduce communication buckets. - When gradients are views, ``detach_()`` cannot be called on the - gradients. If hitting such errors, please fix it by referring to - the :meth:`~torch.optim.Optimizer.zero_grad` function in - ``torch/optim/optimizer.py`` as the solution. - Warning! It is also found that ``gradient_as_bucket_view = true`` - does not work as expected when ``apex.amp`` is used for - mixed precision training. ``apex.amp`` maintained stashed gradients - that are used for unscaling gradients. These stashed gradients - are pointed to gradients (will be communication buckets when - ``gradient_as_bucket_view = true``) before starting new iteration. - In new iteration, the communication buckets are mutated and thus - these stashed gradients will be unexpectedly mutated as well, - the unexpectedly muated stashed gradients may result in wrong - results. To fix it, these stashed gradients should not be pointed - to gradients, instead they should be copied from gradients when - ``gradient_as_bucket_view = true``. + :attr:`bucket_cap_mb` controls the bucket size in + MegaBytes (MB). (default: 25) + find_unused_parameters (bool): Traverse the autograd graph from all + tensors contained in the return value of the + wrapped module's ``forward`` function. Parameters + that don't receive gradients as part of this + graph are preemptively marked as being ready to + be reduced. Note that all ``forward`` outputs + that are derived from module parameters must + participate in calculating loss and later the + gradient computation. If they don't, this wrapper + will hang waiting for autograd to produce + gradients for those parameters. Any outputs + derived from module parameters that are otherwise + unused can be detached from the autograd graph + using ``torch.Tensor.detach``. (default: ``False``) + check_reduction: This argument is deprecated. + gradient_as_bucket_view (bool): This is a prototype feature and subject + to changes. When set to ``True``, gradients will be views + pointing to different offsets of ``allreduce`` communication + buckets. This can reduce peak memory usage, where the + saved memory size will be equal to the total gradients + size. Moreover, it avoids the overhead of copying between + gradients and ``allreduce`` communication buckets. When + gradients are views, ``detach_()`` cannot be called on the + gradients. If hitting such errors, please fix it by + referring to the :meth:`~torch.optim.Optimizer.zero_grad` + function in ``torch/optim/optimizer.py`` as a solution. + Attributes: - module (Module): the module to be parallelized + module (Module): the module to be parallelized. Example:: >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...') - >>> net = torch.nn.DistributedDataParallel(model, pg) + >>> net = torch.nn.parallel.DistributedDataParallel(model, pg) """ def __init__(self, module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, @@ -384,7 +389,7 @@ def __init__(self, module, device_ids=None, if device_ids is None: device_ids = _get_all_device_indices() - self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids)) + self.device_ids = [_get_device_index(x, True) for x in device_ids] if output_device is None: output_device = device_ids[0] @@ -403,16 +408,32 @@ def __init__(self, module, device_ids=None, self.find_unused_parameters = find_unused_parameters self.require_backward_grad_sync = True self.require_forward_param_sync = True - self.ddp_join_enabled = False + self.ddp_uneven_inputs_config = _DDPUnevenInputsConfig( + ddp_join_enabled=False, ddp_join_divide_by_initial_world_size=False + ) self.gradient_as_bucket_view = gradient_as_bucket_view + if hasattr(module, '_ddp_params_and_buffers_to_ignore'): + self.parameters_to_ignore = module._ddp_params_and_buffers_to_ignore + else: + self.parameters_to_ignore = [] if check_reduction: # This argument is no longer used since the reducer # will ensure reduction completes even if some parameters # do not receive gradients. + warnings.warn( + "The `check_reduction` argument in `DistributedDataParallel` " + "module is deprecated. Please avoid using it." + ) pass - # used for intra-node param sync and inter-node sync as well + # Check that a module does not have Uninitialized parameters + for param in module.parameters(): + if isinstance(param, torch.nn.parameter.UninitializedParameter): + raise RuntimeError( + 'Modules with uninitialized parameters can\'t be used with `DistributedDataParallel`. ' + 'Run a dummy forward pass to correctly initialize the modules') + # used for intra-node param sync and inter-node sync as wel self.broadcast_bucket_size = int(250 * 1024 * 1024) # reduction bucket size @@ -424,7 +445,11 @@ def __init__(self, module, device_ids=None, self._ddp_init_helper() def _sync_params_and_buffers(self, authoritative_rank=0): - module_states = list(self.module.state_dict().values()) + module_states = [] + for name, param in self.module.state_dict().items(): + if name not in self.parameters_to_ignore: + module_states.append(param) + if len(module_states) > 0: self._distributed_broadcast_coalesced( module_states, @@ -490,17 +515,55 @@ def model_parameters(m): self._module_copies = [self.module] self.modules_params = [list(parameters(m)) for m in self._module_copies] - self.modules_buffers = [list(m.buffers()) for m in self._module_copies] - - # Build tuple of (module, parameter) for all parameters that require grads. - modules_and_parameters = [ + # Collect buffers for modules, filtering out buffers that should be ignored. + named_module_buffers = [ + [(buffer, buffer_name) for buffer_name, buffer in m.named_buffers()] + for m in self._module_copies + ] + self.modules_buffers = [ [ - (module, parameter) - for module in replica.modules() - for parameter in filter( - lambda parameter: parameter.requires_grad, - parameters(module, recurse=False)) - ] for replica in self._module_copies] + buffer + for (buffer, buffer_name) in module_buffers + if buffer_name not in self.parameters_to_ignore + ] + for module_buffers in named_module_buffers + ] + # Build tuple of (module, parameter) for all parameters that require grads. + if self.device_ids and len(self.device_ids) > 1: + # Single-process multi-device mode,does not support self.parameters_to_ignore. + if self.parameters_to_ignore: + raise ValueError( + "Single-Process multi-device mode does not " + "support ignoring parameters upfront. Please consider " + "using one DDP instance per device." + ) + + modules_and_parameters = [ + [ + (module, parameter) + for module in replica.modules() + for parameter in filter( + lambda parameter: parameter.requires_grad, + parameters(module, recurse=False)) + ] for replica in self._module_copies] + else: + modules_and_parameters = [ + [ + (module, parameter) + for module_name, module in replica.named_modules() + for parameter in [ + param + # Note that we access module.named_parameters instead of + # parameters(module). parameters(module) is only needed in the + # single-process multi device case, where it accesses replicated + # parameters through _former_parameters. + for param_name, param in module.named_parameters(recurse=False) + if param.requires_grad + and f"{module_name}.{param_name}" not in self.parameters_to_ignore + ] + ] + for replica in self._module_copies + ] # Build list of parameters. parameters = [ @@ -586,11 +649,11 @@ def no_sync(self): Example:: - >>> ddp = torch.nn.DistributedDataParallel(model, pg) + >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg) >>> with ddp.no_sync(): - ... for input in inputs: - ... ddp(input).backward() # no synchronization, accumulate grads - ... ddp(another_input).backward() # synchronize grads + >>> for input in inputs: + >>> ddp(input).backward() # no synchronization, accumulate grads + >>> ddp(another_input).backward() # synchronize grads """ old_require_backward_grad_sync = self.require_backward_grad_sync self.require_backward_grad_sync = False @@ -600,13 +663,13 @@ def no_sync(self): self.require_backward_grad_sync = old_require_backward_grad_sync def forward(self, *inputs, **kwargs): - if self.ddp_join_enabled: + if self.ddp_uneven_inputs_config.ddp_join_enabled: ones = torch.ones( 1, device=self.device ) work = dist.all_reduce(ones, group=self.process_group, async_op=True) self.reducer._set_forward_pass_work_handle( - work, self.ddp_join_divide_by_initial_world_size + work, self.ddp_uneven_inputs_config.ddp_join_divide_by_initial_world_size ) # Calling _rebuild_buckets before forward compuation, @@ -621,15 +684,16 @@ def forward(self, *inputs, **kwargs): if self.require_forward_param_sync: self._sync_params() - if self.ddp_join_enabled: + if self.ddp_uneven_inputs_config.ddp_join_enabled: # Notify joined ranks whether they should sync in backwards pass or not. self._check_global_requires_backward_grad_sync(is_joined_rank=False) if self.device_ids: - inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) if len(self.device_ids) == 1: + inputs, kwargs = self.to_kwargs(inputs, kwargs, self.device_ids[0]) output = self.module(*inputs[0], **kwargs[0]) else: + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs) output = self.gather(outputs, self.output_device) else: @@ -654,6 +718,41 @@ def forward(self, *inputs, **kwargs): def scatter(self, inputs, kwargs, device_ids): return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + def _recursive_to(self, inputs, target_gpu): + r""" + Recursively moves input to the target_gpu. + """ + def to_map(obj): + if isinstance(obj, torch.Tensor): + return (obj.to(target_gpu), ) + if is_namedtuple(obj): + return [type(obj)(*args) for args in zip(*map(to_map, obj))] + if isinstance(obj, tuple) and len(obj) > 0: + return list(zip(*map(to_map, obj))) + if isinstance(obj, list) and len(obj) > 0: + return [list(i) for i in zip(*map(to_map, obj))] + if isinstance(obj, dict) and len(obj) > 0: + return [type(obj)(i) for i in zip(*map(to_map, obj.items()))] + return [obj] + + # Avoid reference cycle + try: + res = to_map(inputs) + finally: + to_map = None + return res + + def to_kwargs(self, inputs, kwargs, device_id): + inputs = self._recursive_to(inputs, device_id) if inputs else [] + kwargs = self._recursive_to(kwargs, device_id) if kwargs else [] + if len(inputs) < len(kwargs): + inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) + elif len(kwargs) < len(inputs): + kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) + inputs = tuple(inputs) + kwargs = tuple(kwargs) + return inputs, kwargs + def parallel_apply(self, replicas, inputs, kwargs): return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) @@ -807,7 +906,9 @@ def join(self, divide_by_initial_world_size=True, enable=True): >>> dist.init_process_group("nccl", rank=rank, world_size=2) >>> torch.cuda.set_device(rank) >>> model = nn.Linear(1, 1, bias=False).to(rank) - >>> model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank) + >>> model = torch.nn.parallel.DistributedDataParallel( + >>> model, device_ids=[rank], output_device=rank + >>> ) >>> # Rank 1 gets one more input than rank 0. >>> inputs = [torch.tensor([1]).float() for _ in range(10 + rank)] >>> with model.join(): @@ -827,8 +928,10 @@ def join(self, divide_by_initial_world_size=True, enable=True): to spawn a single process that works on a single GPU.""" ) has_error = False - self.ddp_join_enabled = enable - self.ddp_join_divide_by_initial_world_size = divide_by_initial_world_size + self.ddp_uneven_inputs_config = _DDPUnevenInputsConfig( + ddp_join_enabled=enable, + ddp_join_divide_by_initial_world_size=divide_by_initial_world_size, + ) yield except Exception as e: # Set to skip any processing in the finally block. @@ -895,9 +998,9 @@ def join(self, divide_by_initial_world_size=True, enable=True): # All procs joined. Agree on authoritative rank and broadcast the model. self._sync_final_model(is_last_joiner) - def _register_comm_hook(self, state: object, hook: callable): + def register_comm_hook(self, state: object, hook: callable): r""" - Register a communication hook which is an enhancement that provides a + Registers a communication hook which is an enhancement that provides a flexible hook to users where they can specify how DDP aggregates gradients across multiple workers. @@ -906,7 +1009,7 @@ def _register_comm_hook(self, state: object, hook: callable): and gradient compression which involve different communication strategies for parameter syncs while running Distributed DataParallel training. - Arguments: + Args: state (object): state is passed to the hook and can be used to maintain and update any state information that users would like to maintain as part of the training process. Examples: error @@ -946,7 +1049,7 @@ def _register_comm_hook(self, state: object, hook: callable): .. warning :: ``get_future`` API supports only NCCL backend and will return a ``torch._C.Future`` which is an internal type and should be used with caution. It can still be used by - ``_register_comm_hook`` API, but it is subject to some subtle differences compared + ``register_comm_hook`` API, but it is subject to some subtle differences compared to ``torch.futures.Future``. .. warning :: @@ -960,7 +1063,7 @@ def _register_comm_hook(self, state: object, hook: callable): >>> fut.set_result(bucket.get_tensors()) >>> return fut - >>> ddp._register_comm_hook(state = None, hook = noop) + >>> ddp.register_comm_hook(state = None, hook = noop) Example:: Below is an example of a Parallel SGD algorithm where gradients are encoded before @@ -976,12 +1079,46 @@ def _register_comm_hook(self, state: object, hook: callable): >>> return decoded_tensors >>> return fut.then(decode) - >>> ddp._register_comm_hook(state = None, hook = encode_and_decode) + >>> ddp.register_comm_hook(state = None, hook = encode_and_decode) """ self._check_comm_hook(hook) dist._register_comm_hook(self.reducer, state, hook) + def _register_builtin_comm_hook( + self, comm_hook_type + ): + r""" + Registers a built-in communication hook that specifies how DDP + aggregates gradients across multiple workers. + The built-in hooks aim to provide efficient C++ implementations for certain hooks, + which might not be as efficient if implemented in Python using a Python communication hook. + + Args: + comm_hook_type (dist.BuiltinCommHookType): type of communication hook, such as + ALLREDUCE, FP16_COMPRESS, etc. + + .. warning :: + DDP communication hook can only be registered once and should be registered + before calling backward. + + .. warning :: + DDP communication hook does not support single-process multiple-device mode. + Gradbucket tensors should consist of only a single tensor. + + .. warning :: + DDP communication hook is experimental and subject to change. + + Example:: + Below is an example of a FP16 compression where gradients are + compressed into 16-bit floating-point numbers before allreduce, and + then decompressed after allreduce. + + >>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS) + + """ + dist._register_builtin_comm_hook(self.reducer, comm_hook_type) + def _distributed_broadcast_coalesced( self, tensors, buffer_size, authoritative_rank=0 ): @@ -1047,7 +1184,7 @@ def _sync_params(self): # If we are running DDP with the join manager, we have to agree # upon a rank to sync module buffers from, since rank 0 may # already have been joined and have stale module buffers. - if self.ddp_join_enabled: + if self.ddp_uneven_inputs_config.ddp_join_enabled: authoritative_rank = self._find_common_rank(dist.get_rank(), True) else: # The process with rank 0 is considered the authoritative copy. @@ -1098,3 +1235,12 @@ def _check_comm_hook(self, hook): raise ValueError( "Communication hook: return annotation should be torch.futures.Future or torch._C.Future." ) + + @staticmethod + def _set_params_and_buffers_to_ignore_for_model( + module, params_and_buffers_to_ignore + ): + # This is a workaround to set parameters and buffers DDP should ignore + # during synchronization. It will be removed when the API is finalized + # as part of addressing https://github.com/pytorch/pytorch/issues/43690. + module._ddp_params_and_buffers_to_ignore = params_and_buffers_to_ignore diff --git a/torch/nn/parallel/parallel_apply.py b/torch/nn/parallel/parallel_apply.py index 94d0f7c9ea1a2..06ab69332e16a 100644 --- a/torch/nn/parallel/parallel_apply.py +++ b/torch/nn/parallel/parallel_apply.py @@ -44,7 +44,7 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): assert len(modules) == len(devices) else: devices = [None] * len(modules) - devices = list(map(lambda x: _get_device_index(x, True), devices)) + devices = [_get_device_index(x, True) for x in devices] lock = threading.Lock() results = {} grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled() diff --git a/torch/nn/parallel/replicate.py b/torch/nn/parallel/replicate.py index bb104d7cc38d7..8effeece59081 100644 --- a/torch/nn/parallel/replicate.py +++ b/torch/nn/parallel/replicate.py @@ -44,7 +44,7 @@ def descendant_modules(module): if memo is None: memo = set() - # memorize visited modules + # memoize visited modules memo.add(module) if _is_script_module(module): memo.update(descendant_modules(module)) @@ -80,7 +80,10 @@ def replicate(network, devices, detach=False): raise RuntimeError("Cannot replicate network where python modules are " "childrens of ScriptModule") - devices = list(map(lambda x: _get_device_index(x, True), devices)) + if not devices: + return [] + + devices = [_get_device_index(x, True) for x in devices] num_replicas = len(devices) params = list(network.parameters()) @@ -105,7 +108,6 @@ def replicate(network, devices, detach=False): modules = list(network.modules()) module_copies = [[] for device in devices] module_indices = {} - scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules", "forward", "_c"} for i, module in enumerate(modules): module_indices[module] = i diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py index 1635d40e29e8d..771fbba68f02c 100644 --- a/torch/nn/parallel/scatter_gather.py +++ b/torch/nn/parallel/scatter_gather.py @@ -1,6 +1,12 @@ import torch from ._functions import Scatter, Gather +def is_namedtuple(obj): + # Check if type was created from collections.namedtuple or a typing.NamedTuple. + return ( + isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") + ) + def scatter(inputs, target_gpus, dim=0): r""" @@ -11,12 +17,14 @@ def scatter(inputs, target_gpus, dim=0): def scatter_map(obj): if isinstance(obj, torch.Tensor): return Scatter.apply(target_gpus, None, dim, obj) + if is_namedtuple(obj): + return [type(obj)(*args) for args in zip(*map(scatter_map, obj))] if isinstance(obj, tuple) and len(obj) > 0: return list(zip(*map(scatter_map, obj))) if isinstance(obj, list) and len(obj) > 0: - return list(map(list, zip(*map(scatter_map, obj)))) + return [list(i) for i in zip(*map(scatter_map, obj))] if isinstance(obj, dict) and len(obj) > 0: - return list(map(type(obj), zip(*map(scatter_map, obj.items())))) + return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))] return [obj for targets in target_gpus] # After scatter_map is called, a scatter_map cell will exist. This cell diff --git a/torch/nn/parameter.py b/torch/nn/parameter.py index 9749a70d024b2..c5d63708e9085 100644 --- a/torch/nn/parameter.py +++ b/torch/nn/parameter.py @@ -15,7 +15,7 @@ class Parameter(torch.Tensor): the model. If there was no such class as :class:`Parameter`, these temporaries would get registered too. - Arguments: + Args: data (Tensor): parameter tensor. requires_grad (bool, optional): if the parameter requires gradient. See :ref:`excluding-subgraphs` for more details. Default: `True` @@ -44,3 +44,97 @@ def __reduce_ex__(self, proto): ) __torch_function__ = _disabled_torch_function_impl + + +class UninitializedParameter(Parameter): + r"""A parameter that is not initialized. + + Unitialized Parameters are a a special case of :class:`torch.nn.Parameter` + where the shape of the data is still unknown. + + Unlikely a :class:`torch.nn.Parameter`, uninitialized parameters + hold no data and attempting to access some properties, like their shape, + will throw a runtime error. The only operations that can be performed on a uninitialized + parameter are changing its datatype, moving it to a different device and + converting it to a regular :class:`torch.nn.Parameter`. + """ + _allowed_methods = [ + torch.Tensor.__hash__, + torch.Tensor.size, + torch.Tensor.copy_, + torch.Tensor.is_floating_point, + torch.Tensor.half, + torch.Tensor.float, + torch.Tensor.double, + torch.Tensor.char, + torch.Tensor.short, + torch.Tensor.int, + torch.Tensor.long, + torch.Tensor.cuda, + torch.Tensor.cpu, + torch.Tensor.to, + torch.Tensor.get_device, + torch._has_compatible_shallow_copy_type] + + def __new__(cls, requires_grad=True): + data = torch.Tensor() + return torch.Tensor._make_subclass(cls, data, requires_grad) + + def materialize(self, shape, device=None, dtype=None): + r"""Create a Parameter with the same properties of the uninitialized one. + Given a shape, it materializes a parameter in the same device + and with the same `dtype` as the current one or the specified ones in the + arguments. + + Args: + shape : (tuple): the shape for the materialized tensor. + device (:class:`torch.device`): the desired device of the parameters + and buffers in this module. Optional. + dtype (:class:`torch.dtype`): the desired floating point type of + the floating point parameters and buffers in this module. Optional. + """ + if device is None: + device = self.data.device + if dtype is None: + dtype = self.data.dtype + self.data = torch.empty(shape, device=device, dtype=dtype) + self.__class__ = Parameter + + @property + def shape(self): + raise RuntimeError( + 'Can\'t access the shape of an uninitialized parameter. ' + 'This error usually happens in `load_state_dict` when trying to load ' + 'an uninitialized parameter into an initialized one. ' + 'Call `forward` to initialize the parameters before accessing their attributes.') + + def share_memory_(self): + raise RuntimeError( + 'Can\'t share memory on an uninitialized parameter. ' + 'Call `forward` to initialize the parameters before calling ' + '`module.share_memory()`.') + + def __repr__(self): + return 'Uninitialized parameter' + + def __reduce_ex__(self, proto): + # See Note [Don't serialize hooks] + return ( + UninitializedParameter, + (self.requires_grad,) + ) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + # method-wrapper is to detect access to Tensor properties that are + # wrapped in descriptors + if func in cls._allowed_methods or func.__class__.__name__ == 'method-wrapper': + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) + raise ValueError( + 'Attempted to use an uninitialized parameter in {}. ' + 'This error happens when you are using a `LazyModule` or ' + 'explicitly manipulating `torch.nn.parameter.UninitializedParameter` ' + 'objects. When using LazyModules Call `forward` with a dummy batch ' + 'to initialize the parameters before calling torch functions'.format(func)) diff --git a/torch/nn/parameter.pyi b/torch/nn/parameter.pyi index dcaf1715a1acc..7e9e17eebf83c 100644 --- a/torch/nn/parameter.pyi +++ b/torch/nn/parameter.pyi @@ -1,7 +1,15 @@ +import torch from .. import Tensor +from typing import Tuple, Optional import builtins class Parameter(Tensor): def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ... ... + +class UninitializedParameter(Tensor): + def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ... + + def materialize(self, shape: Tuple[int, ...], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ... + ... diff --git a/torch/nn/qat/modules/conv.py b/torch/nn/qat/modules/conv.py index 7daeecddd4e10..4b38149833470 100644 --- a/torch/nn/qat/modules/conv.py +++ b/torch/nn/qat/modules/conv.py @@ -21,15 +21,15 @@ class Conv2d(nn.Conv2d): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', qconfig=None): - super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, - stride=stride, padding=padding, dilation=dilation, - groups=groups, bias=bias, padding_mode=padding_mode) + super().__init__(in_channels, out_channels, kernel_size, + stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias, padding_mode=padding_mode) assert qconfig, 'qconfig must be provided for QAT module' self.qconfig = qconfig self.weight_fake_quant = qconfig.weight() def forward(self, input): - return self._conv_forward(input, self.weight_fake_quant(self.weight)) + return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) @classmethod def from_float(cls, mod): diff --git a/torch/nn/qat/modules/linear.py b/torch/nn/qat/modules/linear.py index 47fc40b9b6c0e..26849e42d4c39 100644 --- a/torch/nn/qat/modules/linear.py +++ b/torch/nn/qat/modules/linear.py @@ -21,7 +21,7 @@ class Linear(nn.Linear): def __init__(self, in_features, out_features, bias=True, qconfig=None): - super(Linear, self).__init__(in_features, out_features, bias) + super().__init__(in_features, out_features, bias) assert qconfig, 'qconfig must be provided for QAT module' self.qconfig = qconfig self.weight_fake_quant = qconfig.weight() diff --git a/torch/nn/quantizable/__init__.py b/torch/nn/quantizable/__init__.py new file mode 100644 index 0000000000000..270dcebaa5f4e --- /dev/null +++ b/torch/nn/quantizable/__init__.py @@ -0,0 +1 @@ +from .modules import * diff --git a/torch/nn/quantizable/modules/__init__.py b/torch/nn/quantizable/modules/__init__.py new file mode 100644 index 0000000000000..b3480b717a2d6 --- /dev/null +++ b/torch/nn/quantizable/modules/__init__.py @@ -0,0 +1,7 @@ +from .rnn import LSTM +from .rnn import LSTMCell + +__all__ = [ + 'LSTM', + 'LSTMCell', +] diff --git a/torch/nn/quantizable/modules/rnn.py b/torch/nn/quantizable/modules/rnn.py new file mode 100644 index 0000000000000..cfe076fac16cf --- /dev/null +++ b/torch/nn/quantizable/modules/rnn.py @@ -0,0 +1,403 @@ +import numbers +from typing import Optional, Tuple +import warnings + +import torch +from torch import Tensor + +""" +We will recreate all the RNN modules as we require the modules to be decomposed +into its building blocks to be able to observe. +""" + +class LSTMCell(torch.nn.Module): + r"""A quantizable long short-term memory (LSTM) cell. + + For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell` + + Examples:: + + >>> import torch.nn.quantizable as nnqa + >>> rnn = nnqa.LSTMCell(10, 20) + >>> input = torch.randn(3, 10) + >>> hx = torch.randn(3, 20) + >>> cx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + hx, cx = rnn(input[i], (hx, cx)) + output.append(hx) + """ + _FLOAT_MODULE = torch.nn.LSTMCell + + def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True): + super().__init__() + self.input_size = input_dim + self.hidden_size = hidden_dim + self.bias = bias + + self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias) + self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias) + self.gates = torch.nn.quantized.FloatFunctional() + + self.fgate_cx = torch.nn.quantized.FloatFunctional() + self.igate_cgate = torch.nn.quantized.FloatFunctional() + self.fgate_cx_igate_cgate = torch.nn.quantized.FloatFunctional() + + self.ogate_cy = torch.nn.quantized.FloatFunctional() + + def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]: + if hidden is None or hidden == (None, None): + hidden = self.initialize_hidden(x.shape[0], x.is_quantized) + hx, cx = hidden + + igates = self.igates(x) + hgates = self.hgates(hx) + gates = self.gates.add(igates, hgates) + + input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1) + + input_gate = torch.sigmoid(input_gate) + forget_gate = torch.sigmoid(forget_gate) + cell_gate = torch.tanh(cell_gate) + out_gate = torch.sigmoid(out_gate) + + fgate_cx = self.fgate_cx.mul(forget_gate, cx) + igate_cgate = self.igate_cgate.mul(input_gate, cell_gate) + fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate) + cy = fgate_cx_igate_cgate + + tanh_cy = torch.tanh(cy) + hy = self.ogate_cy.mul(out_gate, tanh_cy) + return hy, cy + + def initialize_hidden(self, batch_size: int, is_quantized: bool = False) -> Tuple[Tensor, Tensor]: + h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros((batch_size, self.hidden_size)) + if is_quantized: + h = torch.quantize_per_tensor(h, scale=1.0, zero_point=0, dtype=torch.quint8) + c = torch.quantize_per_tensor(c, scale=1.0, zero_point=0, dtype=torch.quint8) + return h, c + + def _get_name(self): + return 'QuantizableLSTMCell' + + @classmethod + def from_params(cls, wi, wh, bi=None, bh=None): + """Uses the weights and biases to create a new LSTM cell. + + Args: + wi, wh: Weights for the input and hidden layers + bi, bh: Biases for the input and hidden layers + """ + assert (bi is None) == (bh is None) # Either both None or both have values + input_size = wi.shape[1] + hidden_size = wh.shape[1] + cell = cls(input_dim=input_size, hidden_dim=hidden_size, + bias=(bi is not None)) + cell.igates.weight = torch.nn.Parameter(wi) + if bi is not None: + cell.igates.bias = torch.nn.Parameter(bi) + cell.hgates.weight = torch.nn.Parameter(wh) + if bh is not None: + cell.hgates.bias = torch.nn.Parameter(bh) + return cell + + @classmethod + def from_float(cls, other): + assert type(other) == cls._FLOAT_MODULE + assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'" + observed = cls.from_params(other.weight_ih, other.weight_hh, + other.bias_ih, other.bias_hh) + observed.qconfig = other.qconfig + observed.igates.qconfig = other.qconfig + observed.hgates.qconfig = other.qconfig + return observed + + +class _LSTMSingleLayer(torch.nn.Module): + r"""A single one-directional LSTM layer. + + The difference between a layer and a cell is that the layer can process a + sequence, while the cell only expects an instantaneous value. + """ + def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True): + super().__init__() + self.cell = LSTMCell(input_dim, hidden_dim, bias=bias) + + def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None): + result = [] + for xx in x: + hidden = self.cell(xx, hidden) + result.append(hidden[0]) # type: ignore + result_tensor = torch.stack(result, 0) + return result_tensor, hidden + + @classmethod + def from_params(cls, *args, **kwargs): + cell = LSTMCell.from_params(*args, **kwargs) + layer = cls(cell.input_size, cell.hidden_size, cell.bias) + layer.cell = cell + return layer + + +class _LSTMLayer(torch.nn.Module): + r"""A single bi-directional LSTM layer.""" + def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True, + batch_first: bool = False, bidirectional: bool = False): + super().__init__() + self.batch_first = batch_first + self.bidirectional = bidirectional + self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias) + if self.bidirectional: + self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias) + + def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None): + if self.batch_first: + x = x.transpose(0, 1) + if hidden is None: + hx_fw, cx_fw = (None, None) + else: + hx_fw, cx_fw = hidden + if self.bidirectional: + if hx_fw is None: + hx_bw = None + else: + hx_bw = hx_fw[1] + hx_fw = hx_fw[0] + if cx_fw is None: + cx_bw = None + else: + cx_bw = cx_fw[1] + cx_fw = cx_fw[0] + hidden_bw = hx_bw, cx_bw + hidden_fw = hx_fw, cx_fw + result_fw, hidden_fw = self.layer_fw(x, hidden_fw) + + if self.bidirectional: + x_reversed = x.flip(0) + result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw) + result_bw = result_bw.flip(0) + + result = torch.cat([result_fw, result_bw], result_fw.dim() - 1) + h = torch.stack([hidden_fw[0], hidden_bw[0]], 0) # type: ignore + c = torch.stack([hidden_fw[1], hidden_bw[1]], 0) # type: ignore + else: + result = result_fw + h, c = hidden_fw # type: ignore + + if self.batch_first: + result.transpose_(0, 1) + + return result, (h, c) + + @classmethod + def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs): + r""" + There is no FP equivalent of this class. This function is here just to + mimic the behavior of the `prepare` within the `torch.quantization` + flow. + """ + assert hasattr(other, 'qconfig') or (qconfig is not None) + + input_size = kwargs.get('input_size', other.input_size) + hidden_size = kwargs.get('hidden_size', other.hidden_size) + bias = kwargs.get('bias', other.bias) + batch_first = kwargs.get('batch_first', other.batch_first) + bidirectional = kwargs.get('bidirectional', other.bidirectional) + + layer = cls(input_size, hidden_size, bias, batch_first, bidirectional) + layer.qconfig = getattr(other, 'qconfig', qconfig) + wi = getattr(other, f'weight_ih_l{layer_idx}') + wh = getattr(other, f'weight_hh_l{layer_idx}') + bi = getattr(other, f'bias_ih_l{layer_idx}', None) + bh = getattr(other, f'bias_hh_l{layer_idx}', None) + + layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh) + + if other.bidirectional: + wi = getattr(other, f'weight_ih_l{layer_idx}_reverse') + wh = getattr(other, f'weight_hh_l{layer_idx}_reverse') + bi = getattr(other, f'bias_ih_l{layer_idx}_reverse', None) + bh = getattr(other, f'bias_hh_l{layer_idx}_reverse', None) + layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh) + return layer + + # Getters for the weights and biases + # Note that jit currently doesn't support the `porperty`, so if you need to + # access the weights/biases you would need to navigate manually to the + # `layer_fw.cell.igates.*`: https://github.com/pytorch/pytorch/issues/37883 + @property + def weight_ih(self): + return self.layer_fw.cell.igates.weight + + @property + def weight_hh(self): + return self.layer_fw.cell.hgates.weight + + @property + def bias_ih(self): + return self.layer_fw.cell.igates.bias + + @property + def bias_hh(self): + return self.layer_fw.cell.hgates.bias + + @property + def weight_ih_reverse(self): + assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer' + return self.layer_bw.cell.igates.weight + + @property + def weight_hh_reverse(self): + assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer' + return self.layer_bw.cell.hgates.weight + + @property + def bias_ih_reverse(self): + assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer' + return self.layer_bw.cell.igates.bias + + @property + def bias_hh_reverse(self): + assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer' + return self.layer_bw.cell.hgates.bias + + +class LSTM(torch.nn.Module): + r"""A quantizable long short-term memory (LSTM). + + For the description and the argument types, please, refer to :class:`~torch.nn.LSTM` + + Attributes: + layers : instances of the `_LSTMLayer` + + .. note:: + To access the weights and biases, you need to access them per layer. + See examples below. + + Examples:: + + >>> import torch.nn.quantizable as nnqa + >>> rnn = nnqa.LSTM(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> c0 = torch.randn(2, 3, 20) + >>> output, (hn, cn) = rnn(input, (h0, c0)) + >>> # To get the weights: + >>> print(rnn.layers[0].weight_ih) + tensor([[...]]) + >>> print(rnn.layers[0].weight_hh) + AssertionError: There is no reverse path in the non-bidirectional layer + """ + _FLOAT_MODULE = torch.nn.LSTM + + def __init__(self, input_size: int, hidden_size: int, + num_layers: int = 1, bias: bool = True, + batch_first: bool = False, dropout: float = 0., + bidirectional: bool = False): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = float(dropout) + self.bidirectional = bidirectional + self.training = False # We don't want to train using this module + num_directions = 2 if bidirectional else 1 + + if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \ + isinstance(dropout, bool): + raise ValueError("dropout should be a number in range [0, 1] " + "representing the probability of an element being " + "zeroed") + if dropout > 0: + warnings.warn("dropout option for quantizable LSTM is ignored. " + "If you are training, please, use nn.LSTM version " + "followed by `prepare` step.") + if num_layers == 1: + warnings.warn("dropout option adds dropout after all but last " + "recurrent layer, so non-zero dropout expects " + "num_layers greater than 1, but got dropout={} " + "and num_layers={}".format(dropout, num_layers)) + + layers = [_LSTMLayer(self.input_size, self.hidden_size, + self.bias, batch_first=False, + bidirectional=self.bidirectional)] + for layer in range(1, num_layers): + layers.append(_LSTMLayer(self.hidden_size, self.hidden_size, + self.bias, batch_first=False, + bidirectional=self.bidirectional)) + self.layers = torch.nn.ModuleList(layers) + + def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None): + if self.batch_first: + x = x.transpose(0, 1) + + max_batch_size = x.size(1) + num_directions = 2 if self.bidirectional else 1 + if hidden is None: + zeros = torch.zeros(num_directions, max_batch_size, + self.hidden_size, dtype=torch.float, + device=x.device) + zeros.squeeze_(0) + if x.is_quantized: + zeros = torch.quantize_per_tensor(zeros, scale=1.0, + zero_point=0, dtype=x.dtype) + hxcx = [(zeros, zeros) for _ in range(self.num_layers)] + else: + hidden_non_opt = torch.jit._unwrap_optional(hidden) + if isinstance(hidden_non_opt[0], Tensor): + hx = hidden_non_opt[0].reshape(self.num_layers, num_directions, + max_batch_size, + self.hidden_size).unbind(0) + cx = hidden_non_opt[1].reshape(self.num_layers, num_directions, + max_batch_size, + self.hidden_size).unbind(0) + hxcx = [] + for idx in range(self.num_layers): + hxcx.append((hx[idx].squeeze_(0), cx[idx].squeeze_(0))) + else: + hxcx = hidden_non_opt + + for idx in range(self.num_layers): + x, hxcx[idx] = self.layers[idx](x, hxcx[idx]) + + hx_list = [] + cx_list = [] + for idx in range(self.num_layers): + hx_list.append(hxcx[idx][0]) + cx_list.append(hxcx[idx][1]) + hx_tensor = torch.stack(hx_list) + cx_tensor = torch.stack(cx_list) + + # We are creating another dimension for bidirectional case + # need to collapse it + hx_tensor = hx_tensor.reshape(-1, *hx_tensor.shape[-2:]) + cx_tensor = cx_tensor.reshape(-1, *cx_tensor.shape[-2:]) + + if self.batch_first: + x = x.transpose(0, 1) + + return x, (hx_tensor, cx_tensor) + + def _get_name(self): + return 'QuantizableLSTM' + + @classmethod + def from_float(cls, other, qconfig=None): + assert isinstance(other, cls._FLOAT_MODULE) + assert (hasattr(other, 'qconfig') or qconfig) + observed = cls(other.input_size, other.hidden_size, other.num_layers, + other.bias, other.batch_first, other.dropout, + other.bidirectional) + observed.qconfig = getattr(other, 'qconfig', qconfig) + for idx in range(other.num_layers): + observed.layers[idx] = _LSTMLayer.from_float(other, idx, qconfig, + batch_first=False) + observed.eval() + observed = torch.quantization.prepare(observed, inplace=True) + return observed + + def from_observed(self, other): + return torch.quantization.convert(self, inplace=False, + remove_qconfig=True) diff --git a/torch/nn/quantized/_reference/__init__.py b/torch/nn/quantized/_reference/__init__.py new file mode 100644 index 0000000000000..270dcebaa5f4e --- /dev/null +++ b/torch/nn/quantized/_reference/__init__.py @@ -0,0 +1 @@ +from .modules import * diff --git a/torch/nn/quantized/_reference/modules/__init__.py b/torch/nn/quantized/_reference/modules/__init__.py new file mode 100644 index 0000000000000..0d2c201e5e493 --- /dev/null +++ b/torch/nn/quantized/_reference/modules/__init__.py @@ -0,0 +1,5 @@ +from .linear import Linear + +__all__ = [ + 'Linear', +] diff --git a/torch/nn/quantized/_reference/modules/linear.py b/torch/nn/quantized/_reference/modules/linear.py new file mode 100644 index 0000000000000..276dc0161ded8 --- /dev/null +++ b/torch/nn/quantized/_reference/modules/linear.py @@ -0,0 +1,51 @@ +import torch +import torch.nn.quantized as nnq +import torch.nn.functional as F +from typing import Optional + +class Linear(nnq.Linear): + """ A backend independent version of nn.quantized.Linear + we will not pack the parameters in this module, since weight packing is an + optimization for quantized backends supported in PyTorch (fbgemm/qnnpack), + this is useful when user want to use this module in other backends like Glow. + """ + def __init__(self, in_features, out_features, bias_=True, + dtype=torch.qint8): + super().__init__(in_features, out_features, bias_, dtype) + self._qweight, self._bias = self._packed_params._weight_bias() + del self._packed_params + + def _get_name(self): + return "QuantizedLinear(Reference)" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dequant = x.dequantize() + weight_dequant = self._qweight.dequantize() + float_result = F.linear(x_dequant, weight_dequant, self._bias) + # NEEDFIX: we don't have dtype in the Linear module APIs right now! + result = torch.quantize_per_tensor( + float_result, self.scale, self.zero_point, torch.quint8) + return result + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + '_qweight'] = self._qweight + destination[prefix + '_bias'] = self._bias + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + self._qweight = state_dict[prefix + '_qweight'] + self._bias = state_dict[prefix + '_bias'] + state_dict.pop(prefix + '_qweight') + state_dict.pop(prefix + '_bias') + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, False, + missing_keys, unexpected_keys, error_msgs) + + def _weight_bias(self): + return self._qweight, self._bias + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + self._qweight = w + self._bias = b diff --git a/torch/nn/quantized/dynamic/modules/__init__.py b/torch/nn/quantized/dynamic/modules/__init__.py index 624854b025c0d..83278dffe28c7 100644 --- a/torch/nn/quantized/dynamic/modules/__init__.py +++ b/torch/nn/quantized/dynamic/modules/__init__.py @@ -1,10 +1,11 @@ from .linear import Linear -from .rnn import LSTM, LSTMCell, RNNCell, GRUCell +from .rnn import LSTM, GRU, LSTMCell, RNNCell, GRUCell __all__ = [ 'Linear', 'LSTM', + 'GRU', 'LSTMCell', 'RNNCell', 'GRUCell', diff --git a/torch/nn/quantized/dynamic/modules/linear.py b/torch/nn/quantized/dynamic/modules/linear.py index 4a4a46bf780a8..527ee76fdc765 100644 --- a/torch/nn/quantized/dynamic/modules/linear.py +++ b/torch/nn/quantized/dynamic/modules/linear.py @@ -1,5 +1,4 @@ import torch -from ....modules.linear import Linear as NNLinear import torch.nn.quantized as nnq from torch.nn.quantized.modules.utils import _quantize_weight @@ -80,7 +79,10 @@ def from_float(cls, mod): mod (Module): a float module, either produced by torch.quantization utilities or provided by the user """ - assert type(mod) == NNLinear, 'nn.quantized.dynamic.Linear.from_float only works for nn.Linear' + float_modules = [torch.nn.Linear, torch.nn.modules.linear._LinearWithBias] + assert type(mod) in float_modules, \ + 'nn.quantized.dynamic.Linear.from_float only works for one of' + \ + str([float_mod.__name__ for float_mod in float_modules]) assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' if mod.qconfig is not None and mod.qconfig.weight is not None: weight_observer = mod.qconfig.weight() @@ -91,7 +93,8 @@ def from_float(cls, mod): from torch.quantization.qconfig import default_dynamic_qconfig weight_observer = default_dynamic_qconfig.weight() dtype = weight_observer.dtype - assert dtype in [torch.qint8, torch.float16], 'The only supported dtypes for dynamic quantized linear are qint8 and float16' + assert dtype in [torch.qint8, torch.float16], "The only supported dtypes for " \ + "dynamic quantized linear are qint8 and float16 got: {}".format(dtype) weight_observer(mod.weight) if dtype == torch.qint8: qweight = _quantize_weight(mod.weight.float(), weight_observer) diff --git a/torch/nn/quantized/dynamic/modules/rnn.py b/torch/nn/quantized/dynamic/modules/rnn.py index 53d7e1124916f..59c0195d7858e 100644 --- a/torch/nn/quantized/dynamic/modules/rnn.py +++ b/torch/nn/quantized/dynamic/modules/rnn.py @@ -65,6 +65,8 @@ def __init__(self, mode, input_size, hidden_size, if mode == 'LSTM': gate_size = 4 * hidden_size + elif mode == 'GRU': + gate_size = 3 * hidden_size else: raise ValueError("Unrecognized RNN mode: " + mode) @@ -219,12 +221,16 @@ def from_float(cls, mod): supported_scalar_types = [torch.qint8, torch.float16] if dtype not in supported_scalar_types: raise RuntimeError('Unsupported dtype for dynamic RNN quantization: {}'.format(dtype)) - + # RNNBase can be either LSTM or GRU + qRNNBase: Union[LSTM, GRU] if mod.mode == 'LSTM': qRNNBase = LSTM(mod.input_size, mod.hidden_size, mod.num_layers, mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, dtype) + elif mod.mode == 'GRU': + qRNNBase = GRU(mod.input_size, mod.hidden_size, mod.num_layers, + mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, dtype) else: - raise NotImplementedError('Only LSTM is supported for QuantizedRNN for now') + raise NotImplementedError('Only LSTM/GRU is supported for QuantizedRNN for now') num_directions = 2 if mod.bidirectional else 1 @@ -233,8 +239,6 @@ def from_float(cls, mod): _all_weight_values = [] for layer in range(qRNNBase.num_layers): for direction in range(num_directions): - layer_input_size = qRNNBase.input_size if layer == 0 else qRNNBase.hidden_size * num_directions - suffix = '_reverse' if direction == 1 else '' def retrieve_weight_bias(ihhh): @@ -426,6 +430,219 @@ def from_float(cls, mod): return super(LSTM, cls).from_float(mod) +class GRU(RNNBase): + r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence. + + + For each element in the input sequence, each layer computes the following + function: + + .. math:: + \begin{array}{ll} + r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ + z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ + n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ + h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} + \end{array} + + where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input + at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer + at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`, + :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively. + :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. + + In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer + (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by + dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random + variable which is :math:`0` with probability :attr:`dropout`. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` + would mean stacking two GRUs together to form a `stacked GRU`, + with the second GRU taking in outputs of the first GRU and + computing the final results. Default: 1 + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` + dropout: If non-zero, introduces a `Dropout` layer on the outputs of each + GRU layer except the last layer, with dropout probability equal to + :attr:`dropout`. Default: 0 + bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False`` + + Inputs: input, h_0 + - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features + of the input sequence. The input can also be a packed variable length + sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence` + for details. + - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor + containing the initial hidden state for each element in the batch. + Defaults to zero if not provided. If the RNN is bidirectional, + num_directions should be 2, else it should be 1. + + Outputs: output, h_n + - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor + containing the output features h_t from the last layer of the GRU, + for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been + given as the input, the output will also be a packed sequence. + For the unpacked case, the directions can be separated + using ``output.view(seq_len, batch, num_directions, hidden_size)``, + with forward and backward being direction `0` and `1` respectively. + + Similarly, the directions can be separated in the packed case. + - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor + containing the hidden state for `t = seq_len` + + Like *output*, the layers can be separated using + ``h_n.view(num_layers, num_directions, batch, hidden_size)``. + + Shape: + - Input1: :math:`(L, N, H_{in})` tensor containing input features where + :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length. + - Input2: :math:`(S, N, H_{out})` tensor + containing the initial hidden state for each element in the batch. + :math:`H_{out}=\text{hidden\_size}` + Defaults to zero if not provided. where :math:`S=\text{num\_layers} * \text{num\_directions}` + If the RNN is bidirectional, num_directions should be 2, else it should be 1. + - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}` + - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state + for each element in the batch + + Attributes: + weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer + (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`. + Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)` + weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer + (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)` + bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer + (b_ir|b_iz|b_in), of shape `(3*hidden_size)` + bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer + (b_hr|b_hz|b_hn), of shape `(3*hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + .. include:: cudnn_persistent_rnn.rst + + Examples:: + + >>> rnn = nn.GRU(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + """ + _FLOAT_MODULE = nn.GRU + + __overloads__ = {'forward': ['forward_packed', 'forward_tensor']} + + def __init__(self, *args, **kwargs): + super(GRU, self).__init__('GRU', *args, **kwargs) + + def _get_name(self): + return 'DynamicQuantizedGRU' + + def check_forward_args(self, input, hidden, batch_sizes): + # type: (Tensor, Tensor, Optional[Tensor])->None + self.check_input(input, batch_sizes) + expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) + + self.check_hidden_size(hidden, expected_hidden_size, + 'Expected hidden size {}, got {}') + + def forward_impl( + self, input: Tensor, hx: Optional[Tensor], + batch_sizes: Optional[Tensor], max_batch_size: int, + sorted_indices: Optional[Tensor] + ) -> Tuple[Tensor, Tensor]: + if hx is None: + num_directions = 2 if self.bidirectional else 1 + zeros = torch.zeros(self.num_layers * num_directions, + max_batch_size, self.hidden_size, + dtype=input.dtype, device=input.device) + hx = zeros + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + + _all_params = ([m.param for m in self._all_weight_values]) + if batch_sizes is None: + result = torch.quantized_gru(input, + hx, + _all_params, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first) + else: + result = torch.quantized_gru(input, + batch_sizes, + hx, + _all_params, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional) + output = result[0] + hidden = result[1] + + return output, hidden + + + @torch.jit.export + def forward_tensor( + self, input: Tensor, hx: Optional[Tensor] = None + ) -> Tuple[Tensor, Tensor]: + batch_sizes = None + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + + output, hidden = self.forward_impl( + input, hx, batch_sizes, max_batch_size, sorted_indices) + + return output, self.permute_hidden(hidden, unsorted_indices) + + @torch.jit.export + def forward_packed( + self, input: PackedSequence, hx: Optional[Tensor] = None + ) -> Tuple[PackedSequence, Tensor]: # noqa + input_, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = batch_sizes[0] + max_batch_size = int(max_batch_size) + output_, hidden = self.forward_impl( + input_, hx, batch_sizes, max_batch_size, sorted_indices) + + output = PackedSequence(output_, batch_sizes, + sorted_indices, unsorted_indices) + return output, self.permute_hidden(hidden, unsorted_indices) + + def permute_hidden( # type: ignore + self, hx: Tensor, permutation: Optional[Tensor] + ) -> Tensor: + if permutation is None: + return hx + return apply_permutation(hx, permutation) + + @torch.jit.ignore + def forward(self, input, hx=None): + if isinstance(input, PackedSequence): + return self.forward_packed(input, hx) + else: + return self.forward_tensor(input, hx) + + @classmethod + def from_float(cls, mod): + return super(GRU, cls).from_float(mod) + class RNNCellBase(torch.nn.Module): # _FLOAT_MODULE = nn.CellRNNBase diff --git a/torch/nn/quantized/functional.py b/torch/nn/quantized/functional.py index 5985104eaf074..7364b3166062a 100644 --- a/torch/nn/quantized/functional.py +++ b/torch/nn/quantized/functional.py @@ -1,10 +1,12 @@ r""" Functional interface (quantized).""" from typing import List, Optional +import warnings import torch from torch import Tensor from torch.nn.modules.utils import _pair, _triple from torch.nn.quantized.modules.utils import _pair_from_first +from torch.jit.annotations import BroadcastingList2 # Although some of the functions and docstrings are mirrored from the torch.nn, # we want to have them here for future changes. @@ -71,8 +73,7 @@ def avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, ceil_mode, count_include_pad, divisor_override) -def adaptive_avg_pool2d(input, output_size): - # type: (Tensor, BroadcastingList2[int]) -> Tensor +def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor: r""" Applies a 2D adaptive average pooling over a quantized input signal composed of several quantized input planes. @@ -89,8 +90,7 @@ def adaptive_avg_pool2d(input, output_size): raise ValueError("Input to 'quantized.functional.adaptive_avg_pool2d' must be quantized!") return torch.nn.functional.adaptive_avg_pool2d(input, output_size) -def adaptive_avg_pool3d(input, output_size): - # type: (Tensor, BroadcastingList2[int]) -> Tensor +def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor: r""" Applies a 3D adaptive average pooling over a quantized input signal composed of several quantized input planes. @@ -327,8 +327,10 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners) -def linear(input, weight, bias=None, scale=None, zero_point=None): - # type: (Tensor, Tensor, Optional[Tensor], Optional[float], Optional[int]) -> Tensor +def linear( + input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, + scale: Optional[float] = None, zero_point: Optional[int] = None +) -> Tensor: r""" Applies a linear transformation to the incoming quantized data: :math:`y = xA^T + b`. @@ -360,6 +362,22 @@ def linear(input, weight, bias=None, scale=None, zero_point=None): _packed_params = torch.ops.quantized.linear_prepack(weight, bias) return torch.ops.quantized.linear(input, _packed_params, scale, zero_point) +def max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, + ceil_mode=False, return_indices=False): + r"""Applies a 1D max pooling over a quantized input signal composed of + several quantized input planes. + + .. note:: The input quantization parameters are propagated to the output. + + See :class:`~torch.nn.quantized.MaxPool1d` for details. + """ + if return_indices: + raise NotImplementedError("return_indices is not yet implemented!") + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch.nn.functional.max_pool1d(input, kernel_size, stride, padding, + dilation, ceil_mode, return_indices) + def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): r"""Applies a 2D max pooling over a quantized input signal composed of @@ -376,8 +394,7 @@ def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, return torch.nn.functional.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode, return_indices) -def celu(input, scale, zero_point, alpha=1.): - # type: (Tensor, float, int, Optional[float]) -> Tensor +def celu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.) -> Tensor: r"""celu(input, scale, zero_point, alpha=1.) -> Tensor Applies the quantized CELU function element-wise. @@ -393,27 +410,8 @@ def celu(input, scale, zero_point, alpha=1.): return torch.ops.quantized.celu(input, scale, zero_point, alpha) -def relu(input, inplace=False): - # type: (Tensor, bool) -> Tensor - r"""relu(input, inplace=False) -> Tensor - - Applies the rectified linear unit function element-wise. - See :class:`~torch.nn.quantized.ReLU` for more details. - - Args: - input: quantized input - inplace: perform the computation inplace - """ - if not input.is_quantized: - raise ValueError("Input to 'quantized.relu' must be quantized!") - if inplace: - return torch.relu_(input) - else: - return torch.relu(input) - -def leaky_relu(input, negative_slope=0.01, inplace=False, - scale=None, zero_point=None): - # type: (Tensor, float, bool, float, int) -> Tensor +def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = False, + scale: Optional[float] = None, zero_point: Optional[int] = None): r""" Quantized version of the. leaky_relu(input, negative_slope=0.01, inplace=False, scale, zero_point) -> Tensor @@ -441,8 +439,7 @@ def leaky_relu(input, negative_slope=0.01, inplace=False, result = torch._C._nn.leaky_relu(input, negative_slope) return result -def hardtanh(input, min_val=-1., max_val=1., inplace=False): - # type: (Tensor, float, float, bool) -> Tensor +def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: bool = False) -> Tensor: r"""This is the quantized version of :func:`~torch.nn.functional.hardtanh`. """ if not input.is_quantized: @@ -451,8 +448,7 @@ def hardtanh(input, min_val=-1., max_val=1., inplace=False): return torch._C._nn.hardtanh_(input, min_val, max_val) return torch._C._nn.hardtanh(input, min_val, max_val) -def hardswish(input, scale, zero_point): - # type: (Tensor, float, int) -> Tensor +def hardswish(input: Tensor, scale: float, zero_point: int) -> Tensor: r"""This is the quantized version of :func:`~torch.nn.functional.hardswish`. Args: @@ -464,8 +460,7 @@ def hardswish(input, scale, zero_point): raise ValueError("Input to 'quantized.hardswish' must be quantized!") return torch._ops.ops.quantized.hardswish(input, scale, zero_point) -def threshold(input, threshold, value): - # type: (Tensor, float, float) -> Tensor +def threshold(input: Tensor, threshold: float, value: float) -> Tensor: r"""Applies the quantized version of the threshold function element-wise: .. math:: @@ -484,8 +479,7 @@ def threshold(input, threshold, value): raise ValueError("Input to 'value' must be specified!") return torch._ops.ops.quantized.threshold(input, threshold, value) -def elu(input, scale, zero_point, alpha=1.): - # type: (Tensor, float, int, float) -> Tensor +def elu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.) -> Tensor: r"""This is the quantized version of :func:`~torch.nn.functional.elu`. Args: @@ -498,16 +492,14 @@ def elu(input, scale, zero_point, alpha=1.): raise ValueError("Input to 'quantized.elu' must be quantized!") return torch.ops.quantized.elu(input, scale, zero_point, alpha) -def hardsigmoid(input): - # type: (Tensor) -> Tensor +def hardsigmoid(input: Tensor) -> Tensor: r"""This is the quantized version of :func:`~torch.nn.functional.hardsigmoid`. """ if not input.is_quantized: raise ValueError("Input to 'quantized.hardsigmoid' must be quantized!") return torch._C._nn.hardsigmoid(input) -def clamp(input, min_, max_): - # type: (Tensor, float, float) -> Tensor +def clamp(input: Tensor, min_: float, max_: float) -> Tensor: r"""float(input, min_, max_) -> Tensor Applies the clamp function element-wise. diff --git a/torch/nn/quantized/modules/__init__.py b/torch/nn/quantized/modules/__init__.py index 40b11c89ef902..aed3bb17207ab 100644 --- a/torch/nn/quantized/modules/__init__.py +++ b/torch/nn/quantized/modules/__init__.py @@ -1,17 +1,16 @@ - import torch from torch.nn.modules.pooling import MaxPool2d -from .activation import ReLU, ReLU6, Hardswish, ELU +from .activation import ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid from .batchnorm import BatchNorm2d, BatchNorm3d from .normalization import LayerNorm, GroupNorm, InstanceNorm1d, \ InstanceNorm2d, InstanceNorm3d from .conv import Conv1d, Conv2d, Conv3d -from .conv import ConvTranspose1d, ConvTranspose2d +from .conv import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d from .linear import Linear from .embedding_ops import Embedding, EmbeddingBag -from .functional_modules import FloatFunctional, QFunctional +from .functional_modules import FloatFunctional, FXFloatFunctional, QFunctional class Quantize(torch.nn.Module): @@ -35,6 +34,9 @@ class Quantize(torch.nn.Module): [ 1., -1.]], size=(2, 2), dtype=torch.qint8, scale=1.0, zero_point=2) """ + scale: torch.Tensor + zero_point: torch.Tensor + def __init__(self, scale, zero_point, dtype): super(Quantize, self).__init__() self.register_buffer('scale', torch.tensor([scale])) @@ -88,22 +90,25 @@ def from_float(mod): 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', + 'ConvTranspose3d', 'DeQuantize', - 'Linear', - 'MaxPool2d', - 'Quantize', - 'ReLU', - 'ReLU6', - 'Hardswish', 'ELU', - 'LayerNorm', + 'Embedding', + 'EmbeddingBag', 'GroupNorm', + 'Hardswish', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', - 'Embedding', - 'EmbeddingBag', + 'LayerNorm', + 'LeakyReLU', + 'Linear', + 'MaxPool2d', + 'Quantize', + 'ReLU6', + 'Sigmoid', # Wrapper modules 'FloatFunctional', + 'FXFloatFunctional', 'QFunctional', ] diff --git a/torch/nn/quantized/modules/activation.py b/torch/nn/quantized/modules/activation.py index 35353aaf58b47..b63fa335cc807 100644 --- a/torch/nn/quantized/modules/activation.py +++ b/torch/nn/quantized/modules/activation.py @@ -1,44 +1,6 @@ import torch import torch.nn.quantized.functional -class ReLU(torch.nn.ReLU): - r"""Applies quantized rectified linear unit function element-wise: - - :math:`\text{ReLU}(x)= \max(x_0, x)`, where :math:`x_0` is the zero point. - - Please see https://pytorch.org/docs/stable/nn.html#torch.nn.ReLU - for more documentation on ReLU. - - Args: - inplace: (Currently not supported) can optionally do the operation in-place. - - Shape: - - Input: :math:`(N, *)` where `*` means, any number of additional - dimensions - - Output: :math:`(N, *)`, same shape as the input - - Examples:: - - >>> m = nn.quantized.ReLU() - >>> input = torch.randn(2) - >>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32) - >>> output = m(input) - """ - def __init__(self, inplace=False): - super(ReLU, self).__init__(inplace) - self.inplace = inplace - - def forward(self, input): - return torch.nn.quantized.functional.relu(input, inplace=self.inplace) - - def _get_name(self): - return 'QuantizedReLU' - - @staticmethod - def from_float(mod): - return ReLU(mod.inplace) - - class ReLU6(torch.nn.ReLU): r"""Applies the element-wise function: @@ -124,3 +86,49 @@ def _get_name(self): def from_float(mod): scale, zero_point = mod.activation_post_process.calculate_qparams() return ELU(float(scale), int(zero_point), mod.alpha) + +class LeakyReLU(torch.nn.LeakyReLU): + r"""This is the quantized equivalent of :class:`~torch.nn.LeakyReLU`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + negative_slope: Controls the angle of the negative slope. Default: 1e-2 + """ + def __init__(self, scale: float, zero_point: int, negative_slope: float = 1e-2, inplace: bool = False): + super().__init__(negative_slope, inplace) + self.register_buffer('scale', torch.tensor(scale)) + self.register_buffer('zero_point', torch.tensor(zero_point)) + + def forward(self, input): + return torch.ops.quantized.leaky_relu( + input, self.negative_slope, self.inplace, self.scale, self.zero_point) + + def _get_name(self): + return 'QuantizedLeakyReLU' + + @classmethod + def from_float(cls, mod): + scale, zero_point = mod.activation_post_process.calculate_qparams() + return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace) + +class Sigmoid(torch.nn.Sigmoid): + r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + """ + + def __init__(self, output_scale: float, output_zero_point: int): + super().__init__() + self.output_scale = output_scale + self.output_zero_point = output_zero_point + + def forward(self, input): + return torch.ops.quantized.sigmoid(input, self.output_scale, self.output_zero_point) + + @classmethod + def from_float(cls, mod): + output_scale, output_zero_point = mod.activation_post_process.calculate_qparams() + return cls(float(output_scale), int(output_zero_point)) diff --git a/torch/nn/quantized/modules/batchnorm.py b/torch/nn/quantized/modules/batchnorm.py index c3e028b191b40..189d402ee2a5e 100644 --- a/torch/nn/quantized/modules/batchnorm.py +++ b/torch/nn/quantized/modules/batchnorm.py @@ -21,11 +21,9 @@ def _get_name(self): @classmethod def from_float(cls, mod): + activation_post_process = mod.activation_post_process if type(mod) == nni.BNReLU2d: - activation_post_process = mod[1].activation_post_process mod = mod[0] - else: - activation_post_process = mod.activation_post_process scale, zero_point = activation_post_process.calculate_qparams() new_mod = cls(mod.num_features, mod.eps) new_mod.weight = mod.weight @@ -36,6 +34,7 @@ def from_float(cls, mod): new_mod.zero_point = int(zero_point) return new_mod +# TODO: dedup with BatchNorm2d class BatchNorm3d(torch.nn.BatchNorm3d): r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`. """ @@ -55,12 +54,9 @@ def _get_name(self): @classmethod def from_float(cls, mod): + activation_post_process = mod.activation_post_process if type(mod) == nni.BNReLU3d: - activation_post_process = mod[1].activation_post_process mod = mod[0] - else: - activation_post_process = mod.activation_post_process - scale, zero_point = activation_post_process.calculate_qparams() new_mod = cls(mod.num_features, mod.eps) new_mod.weight = mod.weight diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py index 773a9a37fbb30..b3bc78ff69417 100644 --- a/torch/nn/quantized/modules/conv.py +++ b/torch/nn/quantized/modules/conv.py @@ -1,7 +1,7 @@ # coding=utf-8 r"""Quantized convolution modules.""" -from typing import Optional, List +from typing import Optional, List, TypeVar import torch import torch.nn as nn @@ -16,11 +16,17 @@ class _ConvNd(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride, - padding, dilation, - transposed, output_padding, - groups, bias, + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'): + # All subclasses have this signature - See PR #49702s + raise NotImplementedError + + def _init(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, + transposed, output_padding, + groups, bias, + padding_mode='zeros'): super(_ConvNd, self).__init__() if padding_mode != 'zeros': raise NotImplementedError( @@ -54,6 +60,15 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, self.scale = 1.0 self.zero_point = 0 + def set_weight_bias(self, qweight, bias_float): + raise NotImplementedError + + def bias(self): + raise NotImplementedError + + def _weight_bias(self): + raise NotImplementedError + def extra_repr(self): s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}, scale={scale}, zero_point={zero_point}') @@ -155,7 +170,8 @@ def get_qconv(cls, mod, activation_post_process, weight_post_process=None): assert weight_post_process.dtype == torch.qint8, \ 'Weight observer must have a dtype of qint8' qweight = _quantize_weight(mod.weight.float(), weight_post_process) - qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, + # the __init__ call used is the one from derived classes and not the one from _ConvNd + qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, # type: ignore[call-arg] mod.stride, mod.padding, mod.dilation, mod.groups, mod.bias is not None, mod.padding_mode) qconv.set_weight_bias(qweight, mod.bias) @@ -163,6 +179,30 @@ def get_qconv(cls, mod, activation_post_process, weight_post_process=None): qconv.zero_point = int(act_zp) return qconv + @staticmethod + def from_float(cls, mod): + if hasattr(mod, "weight_fake_quant"): + # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \ + # ".from_float only works for " + cls.__QAT_MODULE.__name__ + if type(mod) == cls._NNIQAT_CONV_BN_MODULE: + mod.weight, mod.bias = fuse_conv_bn_weights( + mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var, + mod.bn.eps, mod.bn.weight, mod.bn.bias) + assert hasattr(mod, "activation_post_process"), \ + "Input QAT module must have observer attached" + weight_post_process = mod.weight_fake_quant + activation_post_process = mod.activation_post_process + else: + assert type(mod) == cls._FLOAT_MODULE, \ + " nnq." + cls.__name__ + ".from_float only works for " + \ + cls._FLOAT_MODULE.__name__ + assert hasattr(mod, "qconfig"), \ + "Input float module must have qconfig defined." + activation_post_process = mod.activation_post_process + if type(mod) == cls._NNI_CONV_RELU_MODULE: + mod = mod[0] + weight_post_process = mod.qconfig.weight() + return cls.get_qconv(mod, activation_post_process, weight_post_process) class Conv1d(_ConvNd): r"""Applies a 1D convolution over a quantized input signal composed of @@ -198,6 +238,8 @@ class Conv1d(_ConvNd): """ _FLOAT_MODULE = nn.Conv1d + _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d + _NNI_CONV_RELU_MODULE = nni.ConvReLU1d def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, @@ -207,15 +249,16 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding = _pair_from_first(padding) dilation = _pair_from_first(dilation) - super(Conv1d, self).__init__( + # Subclasses of _ConvNd needs to call _init rather than __init__. See + # discussion on PR #49702 + super(Conv1d, self)._init( in_channels, out_channels, kernel_size, stride, padding, dilation, False, _single(0), groups, bias, padding_mode) def _get_name(self): return 'QuantizedConv1d' - def set_weight_bias(self, w, b): - # type: (torch.Tensor, Optional[torch.Tensor]) -> None + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: self._packed_params = torch.ops.quantized.conv1d_prepack( w, b, self.stride, self.padding, self.dilation, self.groups) @@ -244,17 +287,7 @@ def from_float(cls, mod): mod (Module): a float module, either produced by torch.quantization utilities or provided by the user """ - assert type(mod) == cls._FLOAT_MODULE, \ - ' nnq.' + cls.__name__ + '.from_float only works for ' + \ - cls._FLOAT_MODULE.__name__ - assert hasattr(mod, 'qconfig'), \ - 'Input float module must have qconfig defined.' - if type(mod) == nni.ConvReLU1d: - activation_post_process = mod[1].activation_post_process - mod = mod[0] - else: - activation_post_process = mod.activation_post_process - return cls.get_qconv(mod, activation_post_process) + return _ConvNd.from_float(cls, mod) class Conv2d(_ConvNd): @@ -294,6 +327,8 @@ class Conv2d(_ConvNd): """ _FLOAT_MODULE = nn.Conv2d + _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn2d + _NNI_CONV_RELU_MODULE = nni.ConvReLU2d def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, @@ -302,15 +337,16 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, stride = _pair(stride) padding = _pair(padding) dilation = _pair(dilation) - super(Conv2d, self).__init__( + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super(Conv2d, self)._init( in_channels, out_channels, kernel_size, stride, padding, dilation, False, _pair(0), groups, bias, padding_mode) def _get_name(self): return 'QuantizedConv2d' - def set_weight_bias(self, w, b): - # type: (torch.Tensor, Optional[torch.Tensor]) -> None + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: self._packed_params = torch.ops.quantized.conv2d_prepack( w, b, self.stride, self.padding, self.dilation, self.groups) @@ -339,33 +375,7 @@ def from_float(cls, mod): mod (Module): a float module, either produced by torch.quantization utilities or provided by the user """ - if hasattr(mod, 'weight_fake_quant'): - # assert type(mod) == cls.__QAT_MODULE, ' nnq.' + cls.__name__ + \ - # '.from_float only works for ' + cls.__QAT_MODULE.__name__ - if type(mod) == nniqat.ConvBn2d: - mod.weight, mod.bias = fuse_conv_bn_weights( - mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var, - mod.bn.eps, mod.bn.weight, mod.bn.bias) - assert hasattr(mod, 'activation_post_process'), \ - 'Input QAT module must have observer attached' - weight_post_process = mod.weight_fake_quant - activation_post_process = mod.activation_post_process - else: - assert type(mod) == cls._FLOAT_MODULE, \ - ' nnq.' + cls.__name__ + '.from_float only works for ' + \ - cls._FLOAT_MODULE.__name__ - assert hasattr(mod, 'qconfig'), \ - 'Input float module must have qconfig defined.' - # workaround for sequential, ConvReLU2d should probably - # inherit from Conv2d instead - if type(mod) == nni.ConvReLU2d: - activation_post_process = mod[1].activation_post_process - mod = mod[0] - else: - activation_post_process = mod.activation_post_process - weight_post_process = mod.qconfig.weight() - - return cls.get_qconv(mod, activation_post_process, weight_post_process) + return _ConvNd.from_float(cls, mod) class Conv3d(_ConvNd): @@ -413,15 +423,16 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, stride = _triple(stride) padding = _triple(padding) dilation = _triple(dilation) - super(Conv3d, self).__init__( + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super(Conv3d, self)._init( in_channels, out_channels, kernel_size, stride, padding, dilation, False, _triple(0), groups, bias, padding_mode) def _get_name(self): return 'QuantizedConv3d' - def set_weight_bias(self, w, b): - # type: (torch.Tensor, Optional[torch.Tensor]) -> None + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: self._packed_params = torch.ops.quantized.conv3d_prepack( w, b, self.stride, self.padding, self.dilation, self.groups) @@ -455,31 +466,31 @@ def from_float(cls, mod): cls._FLOAT_MODULE.__name__ assert hasattr(mod, 'qconfig'), \ 'Input float module must have qconfig defined.' - # Workaround for sequential, ConvReLU3d should probably inherit from - # Conv3d instead + activation_post_process = mod.activation_post_process if type(mod) == nni.ConvReLU3d: - activation_post_process = mod[1].activation_post_process mod = mod[0] - else: - activation_post_process = mod.activation_post_process return cls.get_qconv(mod, activation_post_process) # === Transposed Convolutions === +MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd) class _ConvTransposeNd(_ConvNd): + + _FLOAT_MODULE = MOD + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode): if padding_mode != 'zeros': raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__)) - - super(_ConvTransposeNd, self).__init__( + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super(_ConvTransposeNd, self)._init( in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode) - def _input_padding(self, kernel_size, dilation, padding): - # type: (List[int], List[int], List[int]) -> List[int] + def _input_padding(self, kernel_size: List[int], dilation: List[int], padding: List[int]) -> List[int]: res = torch.jit.annotate(List[int], []) for kdx in range(len(kernel_size)): pad = (dilation[kdx] * (kernel_size[kdx] - 1) - padding[kdx]) @@ -493,9 +504,10 @@ def from_float(cls, mod): mod (Module): a float module, either produced by torch.quantization utilities or provided by the user """ - assert type(mod) == cls._FLOAT_MODULE, \ - ' nnq.' + cls.__name__ + '.from_float only works for ' + \ - cls._FLOAT_MODULE.__name__ + # derived classes override cls._FLOAT_MODULE attribute + msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + \ + cls._FLOAT_MODULE.__name__ + assert type(mod) == cls._FLOAT_MODULE, msg assert hasattr(mod, 'qconfig'), \ 'Input float module must have qconfig defined.' weight_post_process = mod.qconfig.weight() @@ -504,7 +516,8 @@ def from_float(cls, mod): assert weight_post_process.dtype == torch.qint8, \ 'Weight observer must have a dtype of qint8' qweight = _quantize_weight(mod.weight.float(), weight_post_process) - qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, + # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd + qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, # type: ignore[call-arg] mod.stride, mod.padding, mod.output_padding, mod.groups, mod.bias is not None, mod.dilation, mod.padding_mode) qconv.set_weight_bias(qweight, mod.bias) @@ -519,15 +532,22 @@ class ConvTranspose1d(_ConvTransposeNd): composed of several input planes. For details on input arguments, parameters, and implementation see :class:`~torch.nn.ConvTranspose1d`. + .. note:: Currently only the QNNPACK engine is implemented. + Please, set the `torch.backends.quantized.engine = 'qnnpack'` + For special notes, please, see :class:`~torch.nn.quantized.Conv1d` + Attributes: weight (Tensor): packed tensor derived from the learnable weight parameter. scale (Tensor): scalar for the output scale zero_point (Tensor): scalar for the output zero point See :class:`~torch.nn.ConvTranspose2d` for other attributes. + Examples:: + + >>> torch.backends.quantized.engine = 'qnnpack' >>> # With square kernels and equal stride >>> m = nnq.ConvTranspose1d(16, 33, 3, stride=2) >>> # non-square kernels and unequal stride and with padding @@ -566,8 +586,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, def _get_name(self): return 'QuantizedConvTranpose1d' - def set_weight_bias(self, w, b): - # type: (torch.Tensor, Optional[torch.Tensor]) -> None + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: self._packed_params = torch.ops.quantized.conv_transpose1d_prepack( w, b, self.stride, self.padding, self.output_padding, self.dilation, self.groups) @@ -598,15 +617,20 @@ class ConvTranspose2d(_ConvTransposeNd): composed of several input planes. For details on input arguments, parameters, and implementation see :class:`~torch.nn.ConvTranspose2d`. - .. note:: Currently only the QNNPACK engine is implemented. + For special notes, please, see :class:`~torch.nn.quantized.Conv2d` + Attributes: weight (Tensor): packed tensor derived from the learnable weight parameter. scale (Tensor): scalar for the output scale zero_point (Tensor): scalar for the output zero point See :class:`~torch.nn.ConvTranspose2d` for other attributes. + Examples:: + + >>> # QNNPACK or FBGEMM as backend + >>> torch.backends.quantized.engine = 'qnnpack' >>> # With square kernels and equal stride >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2) >>> # non-square kernels and unequal stride and with padding @@ -645,8 +669,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, def _get_name(self): return 'QuantizedConvTranpose2d' - def set_weight_bias(self, w, b): - # type: (torch.Tensor, Optional[torch.Tensor]) -> None + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: self._packed_params = torch.ops.quantized.conv_transpose2d_prepack( w, b, self.stride, self.padding, self.output_padding, self.dilation, self.groups) @@ -670,3 +693,87 @@ def forward(self, input): raise ValueError("Input shape must be `(N, C, H, W)`!") return ops.quantized.conv_transpose2d( input, self._packed_params, self.scale, self.zero_point) + +class ConvTranspose3d(_ConvTransposeNd): + r"""Applies a 3D transposed convolution operator over an input image + composed of several input planes. + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose3d`. + + .. note:: Currently only the FBGEMM engine is implemented. + Please, set the `torch.backends.quantized.engine = 'fbgemm'` + + For special notes, please, see :class:`~torch.nn.quantized.Conv3d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose3d` for other attributes. + + Examples:: + + >>> torch.backends.quantized.engine = 'fbgemm' + >>> # With cubic kernels and equal stride + >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2) + >>> # non-cubic kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2)) + >>> input = torch.randn(20, 16, 50, 100, 100) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12, 12, 12) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(q_input) + >>> h.size() + torch.Size([1, 16, 6, 6, 6]) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12, 12]) + """ + + _FLOAT_MODULE = nn.ConvTranspose3d + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, output_padding=0, groups=1, bias=True, + dilation=1, padding_mode='zeros'): + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + output_padding = _pair(output_padding) + + super(ConvTranspose3d, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + True, output_padding, groups, bias, padding_mode) + + def _get_name(self): + return 'QuantizedConvTranpose3d' + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + self._packed_params = torch.ops.quantized.conv_transpose3d_prepack( + w, b, self.stride, self.padding, self.output_padding, self.dilation, + self.groups) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv3d_unpack(self._packed_params) + return w, b + + def weight(self): + (w, _) = self._weight_bias() + return w + + def bias(self): + (_, b) = self._weight_bias() + return b + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, T, H, W)`!") + return ops.quantized.conv_transpose3d( + input, self._packed_params, self.scale, self.zero_point) diff --git a/torch/nn/quantized/modules/embedding_ops.py b/torch/nn/quantized/modules/embedding_ops.py index 278eeed2ca9f0..b12591a6d1dec 100644 --- a/torch/nn/quantized/modules/embedding_ops.py +++ b/torch/nn/quantized/modules/embedding_ops.py @@ -11,31 +11,30 @@ class EmbeddingPackedParams(torch.nn.Module): def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8): super(EmbeddingPackedParams, self).__init__() self.dtype = dtype - if self.dtype == torch.quint8: + if self.dtype in [torch.quint8, torch.quint4x2]: scales = torch.ones(num_embeddings, dtype=torch.float) zero_points = torch.zeros(num_embeddings, dtype=torch.float) wq = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim], scales=scales, zero_points=zero_points, - axis=0, dtype=torch.quint8) + axis=0, dtype=self.dtype) self.set_weight(wq) else: - raise RuntimeError('Unsupported dtype on quantized embedding!') + raise NotImplementedError('Unsupported dtype on quantized embedding! Supports quint8 and quint4x2.') @torch.jit.export - def set_weight(self, weight): - # type: (torch.Tensor) -> None - if self.dtype == torch.quint8: + def set_weight(self, weight: torch.Tensor) -> None: + if self.dtype in [torch.quint8, torch.quint4x2]: self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight) else: - raise RuntimeError('Unsupported dtype on quantized embedding!') + raise NotImplementedError('Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2.') @torch.jit.export def _weight(self): - if self.dtype == torch.quint8: + if self.dtype in [torch.quint8, torch.quint4x2]: return torch.ops.quantized.embedding_bag_unpack(self._packed_weight) else: - raise RuntimeError('Unsupported dtype on quantized embedding!') + raise NotImplementedError('Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2.') def forward(self, x): return x @@ -52,7 +51,6 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - version = local_metadata.get('version', None) self.dtype = state_dict[prefix + 'dtype'] state_dict.pop(prefix + 'dtype') @@ -99,16 +97,16 @@ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optiona if _weight is None: scales = torch.ones(num_embeddings, dtype=torch.float) zero_points = torch.zeros(num_embeddings, dtype=torch.float) - self.qweight = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim], - scales=scales, zero_points=zero_points, - axis=0, dtype=torch.quint8) + qweight = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim], + scales=scales, zero_points=zero_points, + axis=0, dtype=torch.quint8) else: assert list(_weight.shape) == [num_embeddings, embedding_dim], \ 'Shape of weight does not match num_embeddings and embedding_dim' - self.qweight = _weight + qweight = _weight self._packed_params = EmbeddingPackedParams(num_embeddings, embedding_dim, dtype) - self._packed_params.set_weight(self.qweight) + self._packed_params.set_weight(qweight) def forward(self, indices: Tensor) -> Tensor: return torch.ops.quantized.embedding_byte(self._packed_params._packed_weight, indices) @@ -121,13 +119,12 @@ def __repr__(self): def extra_repr(self): extra_repr_str = 'num_embeddings={}, embedding_dim={}, dtype={}, qscheme={}'.format( - self.num_embeddings, self.embedding_dim, self._packed_params.dtype, self.qweight.qscheme() + self.num_embeddings, self.embedding_dim, self._packed_params.dtype, self.weight().qscheme() ) return extra_repr_str - def set_weight(self, w): - # type: (torch.Tensor) -> None + def set_weight(self, w: torch.Tensor) -> None: self._packed_params.set_weight(w) def weight(self): @@ -144,11 +141,11 @@ def from_float(cls, mod): assert type(mod) == nn.Embedding, 'nnq.' + cls.__name__ + '.from_float only works for ' + \ nn.Embedding.__name__ assert hasattr(mod, 'qconfig'), 'Embedding input float module must have qconfig defined' - from torch.quantization.qconfig import float_qparams_dynamic_qconfig + from torch.quantization import float_qparams_weight_only_qconfig if mod.qconfig is not None and mod.qconfig.weight is not None: weight_observer = mod.qconfig.weight() else: - weight_observer = float_qparams_dynamic_qconfig.weight() + weight_observer = float_qparams_weight_only_qconfig.weight() dtype = weight_observer.dtype @@ -192,17 +189,23 @@ def __init__(self, num_embeddings: int, embedding_dim: int, max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, mode: str = 'sum', sparse: bool = False, _weight: Optional[Tensor] = None, include_last_offset: bool = False, dtype=torch.quint8) -> None: - super(EmbeddingBag, self).__init__(num_embeddings, embedding_dim, _weight=_weight) + super(EmbeddingBag, self).__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype) self.mode = mode - self.sparse = sparse + self.pruned_weights = False self.include_last_offset = include_last_offset + self.dtype = dtype def forward(self, indices: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None, compressed_indices_mapping: Optional[Tensor] = None) -> Tensor: - return torch.ops.quantized.embedding_bag_byte(self._packed_params._packed_weight, indices, offsets, False, 0, - self.sparse, per_sample_weights, compressed_indices_mapping, - self.include_last_offset) + if self.dtype == torch.quint4x2: + return torch.ops.quantized.embedding_bag_4bit(self._packed_params._packed_weight, indices, offsets, False, 0, + self.pruned_weights, per_sample_weights, compressed_indices_mapping, + self.include_last_offset) + else: + return torch.ops.quantized.embedding_bag_byte(self._packed_params._packed_weight, indices, offsets, False, 0, + self.pruned_weights, per_sample_weights, compressed_indices_mapping, + self.include_last_offset) def _get_name(self): return 'QuantizedEmbeddingBag' @@ -218,21 +221,22 @@ def from_float(cls, mod): assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \ nn.EmbeddingBag.__name__ assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined' - from torch.quantization.qconfig import float_qparams_dynamic_qconfig + from torch.quantization.qconfig import float_qparams_weight_only_qconfig if mod.qconfig is not None and mod.qconfig.weight is not None: weight_observer = mod.qconfig.weight() else: - weight_observer = float_qparams_dynamic_qconfig.weight() + weight_observer = float_qparams_weight_only_qconfig.weight() dtype = weight_observer.dtype - assert dtype == torch.quint8, 'The only supported dtype for nnq.EmbeddingBag is torch.quint8' + assert dtype == torch.quint8 or dtype == torch.quint4x2, \ + 'The only supported dtype for nnq.EmbeddingBag is torch.quint8 and torch.quint4x2' # Run the observer to calculate qparams. weight_observer(mod.weight) qweight = _quantize_weight(mod.weight.float(), weight_observer) # Create quantized EmbeddingBag module and pass in the quantized weight - qembedding_bag = EmbeddingBag(mod.num_embeddings, mod.embedding_dim) + qembedding_bag = EmbeddingBag(mod.num_embeddings, mod.embedding_dim, dtype=dtype) qembedding_bag.set_weight(qweight) return qembedding_bag diff --git a/torch/nn/quantized/modules/functional_modules.py b/torch/nn/quantized/modules/functional_modules.py index d3fa7189e0566..4bac20dc28d34 100644 --- a/torch/nn/quantized/modules/functional_modules.py +++ b/torch/nn/quantized/modules/functional_modules.py @@ -4,7 +4,6 @@ from torch import Tensor from torch._ops import ops - class FloatFunctional(torch.nn.Module): r"""State collector class for float operations. @@ -40,48 +39,90 @@ def forward(self, x): "'forward'. Please use the underlying operation") r"""Operation equivalent to ``torch.add(Tensor, Tensor)``""" - def add(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def add(self, x: Tensor, y: Tensor) -> Tensor: r = torch.add(x, y) r = self.activation_post_process(r) return r r"""Operation equivalent to ``torch.add(Tensor, float)``""" - def add_scalar(self, x, y): - # type: (Tensor, float) -> Tensor + def add_scalar(self, x: Tensor, y: float) -> Tensor: r = torch.add(x, y) - r = self.activation_post_process(r) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. return r r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``""" - def mul(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def mul(self, x: Tensor, y: Tensor) -> Tensor: r = torch.mul(x, y) r = self.activation_post_process(r) return r r"""Operation equivalent to ``torch.mul(Tensor, float)``""" - def mul_scalar(self, x, y): - # type: (Tensor, float) -> Tensor + def mul_scalar(self, x: Tensor, y: float) -> Tensor: r = torch.mul(x, y) - r = self.activation_post_process(r) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. return r r"""Operation equivalent to ``torch.cat``""" - def cat(self, x, dim=0): - # type: (List[Tensor], int) -> Tensor + def cat(self, x: List[Tensor], dim: int = 0) -> Tensor: r = torch.cat(x, dim=dim) r = self.activation_post_process(r) return r r"""Operation equivalent to ``relu(torch.add(x,y))``""" - def add_relu(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def add_relu(self, x: Tensor, y: Tensor) -> Tensor: r = torch.add(x, y) r = torch.nn.functional.relu(r) r = self.activation_post_process(r) return r +class FXFloatFunctional(torch.nn.Module): + r""" module to replace FloatFunctional module before FX graph mode quantization, + since activation_post_process will be inserted in top level module directly + + Valid operation names: + - add + - cat + - mul + - add_relu + - add_scalar + - mul_scalar + """ + def forward(self, x): + raise RuntimeError("FloatFunctional is not intended to use the " + + "'forward'. Please use the underlying operation") + + r"""Operation equivalent to ``torch.add(Tensor, Tensor)``""" + def add(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + return r + + r"""Operation equivalent to ``torch.add(Tensor, float)``""" + def add_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.add(x, y) + return r + + r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``""" + def mul(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.mul(x, y) + return r + + r"""Operation equivalent to ``torch.mul(Tensor, float)``""" + def mul_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.mul(x, y) + return r + + r"""Operation equivalent to ``torch.cat``""" + def cat(self, x: List[Tensor], dim: int = 0) -> Tensor: + r = torch.cat(x, dim=dim) + return r + + r"""Operation equivalent to ``relu(torch.add(x,y))``""" + def add_relu(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + r = torch.nn.functional.relu(r) + return r class QFunctional(torch.nn.Module): r"""Wrapper class for quantized operations. @@ -141,43 +182,39 @@ def forward(self, x): "'forward'. Please use the underlying operation") r"""Operation equivalent to ``torch.ops.quantized.add``""" - def add(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def add(self, x: Tensor, y: Tensor) -> Tensor: r = ops.quantized.add(x, y, scale=self.scale, zero_point=self.zero_point) r = self.activation_post_process(r) return r r"""Operation equivalent to ``torch.ops.quantized.add(Tensor, float)``""" - def add_scalar(self, x, y): - # type: (Tensor, float) -> Tensor + def add_scalar(self, x: Tensor, y: float) -> Tensor: r = ops.quantized.add_scalar(x, y) - r = self.activation_post_process(r) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. return r r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, Tensor)``""" - def mul(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def mul(self, x: Tensor, y: Tensor) -> Tensor: r = ops.quantized.mul(x, y, scale=self.scale, zero_point=self.zero_point) r = self.activation_post_process(r) return r r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, float)``""" - def mul_scalar(self, x, y): - # type: (Tensor, float) -> Tensor + def mul_scalar(self, x: Tensor, y: float) -> Tensor: r = ops.quantized.mul_scalar(x, y) - r = self.activation_post_process(r) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. return r r"""Operation equivalent to ``torch.ops.quantized.cat``""" - def cat(self, x, dim=0): - # type: (List[Tensor], int) -> Tensor + def cat(self, x: List[Tensor], dim: int = 0) -> Tensor: r = ops.quantized.cat(x, scale=self.scale, zero_point=self.zero_point, dim=dim) r = self.activation_post_process(r) return r r"""Operation equivalent to ``torch.ops.quantized.add_relu``""" - def add_relu(self, x, y): - # type: (Tensor, Tensor) -> Tensor + def add_relu(self, x: Tensor, y: Tensor) -> Tensor: r = ops.quantized.add_relu(x, y, scale=self.scale, zero_point=self.zero_point) r = self.activation_post_process(r) return r diff --git a/torch/nn/quantized/modules/linear.py b/torch/nn/quantized/modules/linear.py index 4d27dad07bc16..d7f86ccb62161 100644 --- a/torch/nn/quantized/modules/linear.py +++ b/torch/nn/quantized/modules/linear.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable import torch import torch.nn as nn @@ -9,7 +10,7 @@ class LinearPackedParams(torch.nn.Module): _version = 3 def __init__(self, dtype=torch.qint8): - super(LinearPackedParams, self).__init__() + super().__init__() self.dtype = dtype if self.dtype == torch.qint8: wq = torch._empty_affine_quantized([1, 1], scale=1.0, zero_point=0, dtype=torch.qint8) @@ -124,10 +125,11 @@ class Linear(torch.nn.Module): torch.Size([128, 30]) """ _version = 3 - _FLOAT_MODULE = nn.Linear + _FLOAT_MODULE = (nn.Linear, nn.modules.linear._LinearWithBias) - def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8): - super(Linear, self).__init__() + def __init__(self, in_features, out_features, bias_=True, + dtype=torch.qint8): + super().__init__() # We don't muck around with buffers or attributes or anything here # to keep the module simple. *everything* is simply a Python attribute. # Serialization logic is explicitly handled in the below serialization and @@ -162,7 +164,7 @@ def extra_repr(self): def __repr__(self): return hide_packed_params_repr(self, LinearPackedParams) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.ops.quantized.linear( x, self._packed_params._packed_params, self.scale, self.zero_point) @@ -196,7 +198,7 @@ def forward(self, x): # of LinearPackedParams C++ struct # def _save_to_state_dict(self, destination, prefix, keep_vars): - super(Linear, self)._save_to_state_dict(destination, prefix, keep_vars) + super()._save_to_state_dict(destination, prefix, keep_vars) destination[prefix + 'scale'] = torch.tensor(self.scale) destination[prefix + 'zero_point'] = torch.tensor(self.zero_point) @@ -212,6 +214,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, state_dict.pop(prefix + 'zero_point') version = local_metadata.get('version', None) + if version is None or version == 1: # We moved the parameters into a LinearPackedParameters submodule weight = state_dict.pop(prefix + 'weight') @@ -219,8 +222,9 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, state_dict.update({prefix + '_packed_params.weight': weight, prefix + '_packed_params.bias': bias}) - super(Linear, self)._load_from_state_dict(state_dict, prefix, local_metadata, False, - missing_keys, unexpected_keys, error_msgs) + super()._load_from_state_dict( + state_dict, prefix, local_metadata, False, + missing_keys, unexpected_keys, error_msgs) # Function rather than property to make sure that JIT serialization doesn't # register this as an attribute @@ -249,21 +253,27 @@ def from_float(cls, mod): weight_post_process = mod.weight_fake_quant activation_post_process = mod.activation_post_process else: - assert type(mod) == cls._FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \ - cls._FLOAT_MODULE.__name__ + # This function does not participate in JIT, so it is OK to ignore + # the type mismatch in assignment. Also, mypy has an issue with + # iterables not being implemented, so we are ignoring those too. + if not isinstance(cls._FLOAT_MODULE, Iterable): + cls._FLOAT_MODULE = [cls._FLOAT_MODULE] # type: ignore + supported_modules = ', '.join([float_mod.__name__ for float_mod in cls._FLOAT_MODULE]) # type: ignore + error_msg = 'nnq.{}.from_float only works for {}'.format(cls.__name__, supported_modules) + assert type(mod) in cls._FLOAT_MODULE, error_msg.format() # type: ignore assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' + activation_post_process = mod.activation_post_process if type(mod) == nni.LinearReLU: - activation_post_process = mod[1].activation_post_process mod = mod[0] - else: - activation_post_process = mod.activation_post_process weight_post_process = mod.qconfig.weight() weight_post_process(mod.weight) dtype = weight_post_process.dtype act_scale, act_zp = activation_post_process.calculate_qparams() assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8' qweight = _quantize_weight(mod.weight.float(), weight_post_process) - qlinear = cls(mod.in_features, mod.out_features, dtype=dtype) + qlinear = cls(mod.in_features, + mod.out_features, + dtype=dtype) qlinear.set_weight_bias(qweight, mod.bias) qlinear.scale = float(act_scale) qlinear.zero_point = int(act_zp) diff --git a/torch/nn/quantized/modules/normalization.py b/torch/nn/quantized/modules/normalization.py index 4664120ec8b5a..c12f743748633 100644 --- a/torch/nn/quantized/modules/normalization.py +++ b/torch/nn/quantized/modules/normalization.py @@ -29,7 +29,6 @@ def _get_name(self): @classmethod def from_float(cls, mod): - activation_post_process = mod.activation_post_process scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.normalized_shape, mod.weight, mod.bias, float(scale), @@ -63,7 +62,6 @@ def _get_name(self): @classmethod def from_float(cls, mod): - activation_post_process = mod.activation_post_process scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_groups, mod.num_channels, mod.weight, mod.bias, float(scale), int(zero_point), @@ -98,7 +96,6 @@ def _get_name(self): @classmethod def from_float(cls, mod): - activation_post_process = mod.activation_post_process scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point), @@ -133,7 +130,6 @@ def _get_name(self): @classmethod def from_float(cls, mod): - activation_post_process = mod.activation_post_process scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point), @@ -168,7 +164,6 @@ def _get_name(self): @classmethod def from_float(cls, mod): - activation_post_process = mod.activation_post_process scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point), diff --git a/torch/nn/quantized/modules/utils.py b/torch/nn/quantized/modules/utils.py index e043db89c647a..9d6e93f9d2fed 100644 --- a/torch/nn/quantized/modules/utils.py +++ b/torch/nn/quantized/modules/utils.py @@ -1,6 +1,7 @@ import torch from torch._six import container_abcs from itertools import repeat +from torch.nn.modules.module import _addindent def _quantize_weight(float_wt, observer): wt_scale, wt_zp = observer.calculate_qparams() @@ -16,7 +17,7 @@ def _quantize_weight(float_wt, observer): elif observer.qscheme in [torch.per_channel_affine_float_qparams]: qweight = torch.quantize_per_channel( float_wt, - wt_scale.to(torch.float), wt_zp.to(torch.float), observer.ch_axis, torch.quint8) + wt_scale.to(torch.float), wt_zp.to(torch.float), observer.ch_axis, observer.dtype) else: raise ValueError("Unexpected qscheme " + observer.qscheme) return qweight @@ -25,7 +26,7 @@ def _ntuple_from_first(n): """Converts the argument to a tuple of size n with the first element repeated.""" def parse(x): - while isinstance(x, container_abcs.Iterable): + while isinstance(x, container_abcs.Sequence): if len(x) == n: break x = x[0] diff --git a/torch/nn/utils/clip_grad.py b/torch/nn/utils/clip_grad.py index d7caf83110953..8c7da10346694 100644 --- a/torch/nn/utils/clip_grad.py +++ b/torch/nn/utils/clip_grad.py @@ -12,7 +12,7 @@ def clip_grad_norm_(parameters: _tensor_or_tensors, max_norm: float, norm_type: The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place. - Arguments: + Args: parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will have gradients normalized max_norm (float or int): max norm of the gradients @@ -58,7 +58,7 @@ def clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float) -> None: Gradients are modified in-place. - Arguments: + Args: parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will have gradients normalized clip_value (float or int): maximum allowed value of the gradients. diff --git a/torch/nn/utils/convert_parameters.py b/torch/nn/utils/convert_parameters.py index c51893de400d8..f5b286263387e 100644 --- a/torch/nn/utils/convert_parameters.py +++ b/torch/nn/utils/convert_parameters.py @@ -5,7 +5,7 @@ def parameters_to_vector(parameters: Iterable[torch.Tensor]) -> torch.Tensor: r"""Convert parameters to one vector - Arguments: + Args: parameters (Iterable[Tensor]): an iterator of Tensors that are the parameters of a model. @@ -27,7 +27,7 @@ def parameters_to_vector(parameters: Iterable[torch.Tensor]) -> torch.Tensor: def vector_to_parameters(vec: torch.Tensor, parameters: Iterable[torch.Tensor]) -> None: r"""Convert one vector to the parameters - Arguments: + Args: vec (Tensor): a single vector represents the parameters of a model. parameters (Iterable[Tensor]): an iterator of Tensors that are the parameters of a model. @@ -60,7 +60,7 @@ def _check_param_device(param: torch.Tensor, old_param_device: Optional[int]) -> and single vector form is not supported for multiple allocations, e.g. parameters in different GPUs, or mixture of CPU/GPU. - Arguments: + Args: param ([Tensor]): a Tensor of a parameter of a model old_param_device (int): the device where the first parameter of a model is allocated. diff --git a/torch/nn/utils/fusion.py b/torch/nn/utils/fusion.py index 65b1b7eb00480..c4f164ee3b401 100644 --- a/torch/nn/utils/fusion.py +++ b/torch/nn/utils/fusion.py @@ -15,7 +15,11 @@ def fuse_conv_bn_eval(conv, bn): def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b): if conv_b is None: - conv_b = bn_rm.new_zeros(bn_rm.shape) + conv_b = torch.zeros_like(bn_rm) + if bn_w is None: + bn_w = torch.ones_like(bn_rm) + if bn_b is None: + bn_b = torch.zeros_like(bn_rm) bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1)) diff --git a/torch/nn/utils/prune.py b/torch/nn/utils/prune.py index 84fa30021ed11..c21940689a770 100644 --- a/torch/nn/utils/prune.py +++ b/torch/nn/utils/prune.py @@ -1,13 +1,14 @@ r""" Pruning methods """ -from abc import abstractmethod import numbers -import torch -from abc import ABC +from abc import ABC, abstractmethod from collections.abc import Iterable from typing import Tuple +import torch + + class BasePruningMethod(ABC): r"""Abstract base class for creation of new pruning techniques. @@ -40,7 +41,8 @@ def compute_mask(self, t, default_mask): method recipe. Args: - t (torch.Tensor): tensor representing the parameter to prune + t (torch.Tensor): tensor representing the importance scores of the + parameter to prune. default_mask (torch.Tensor): Base mask from previous pruning iterations, that need to be respected after the new mask is applied. Same dims as ``t``. @@ -64,9 +66,7 @@ def apply_mask(self, module): """ # to carry out the multiplication, the mask needs to have been computed, # so the pruning method must know what tensor it's operating on - assert ( - self._tensor_name is not None - ), "Module {} has to be pruned".format( + assert self._tensor_name is not None, "Module {} has to be pruned".format( module ) # this gets set in apply() mask = getattr(module, self._tensor_name + "_mask") @@ -75,7 +75,7 @@ def apply_mask(self, module): return pruned_tensor @classmethod - def apply(cls, module, name, *args, **kwargs): + def apply(cls, module, name, *args, importance_scores=None, **kwargs): r"""Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask. @@ -86,6 +86,11 @@ def apply(cls, module, name, *args, **kwargs): will act. args: arguments passed on to a subclass of :class:`BasePruningMethod` + importance_scores (torch.Tensor): tensor of importance scores (of + same shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the + corresponding elements in the parameter being pruned. + If unspecified or None, the parameter will be used in its place. kwargs: keyword arguments passed on to a subclass of a :class:`BasePruningMethod` """ @@ -101,10 +106,7 @@ def _get_composite_method(cls, module, name, *args, **kwargs): for k, hook in module._forward_pre_hooks.items(): # if it exists, take existing thing, remove hook, then # go through normal thing - if ( - isinstance(hook, BasePruningMethod) - and hook._tensor_name == name - ): + if isinstance(hook, BasePruningMethod) and hook._tensor_name == name: old_method = hook hooks_to_remove.append(k) found += 1 @@ -150,8 +152,18 @@ def _get_composite_method(cls, module, name, *args, **kwargs): # Pruning is to be applied to the module's tensor named `name`, # starting from the state it is found in prior to this iteration of - # pruning + # pruning. The pruning mask is calculated based on importances scores. + orig = getattr(module, name) + if importance_scores is not None: + assert ( + importance_scores.shape == orig.shape + ), "importance_scores should have the same shape as parameter \ + {} of {}".format( + name, module + ) + else: + importance_scores = orig # If this is the first time pruning is applied, take care of moving # the original tensor to a new parameter called name + '_orig' and @@ -166,13 +178,17 @@ def _get_composite_method(cls, module, name, *args, **kwargs): # has been done before in a previous pruning iteration, so we're good # to go else: - default_mask = getattr(module, name + "_mask").detach().clone(memory_format=torch.contiguous_format) + default_mask = ( + getattr(module, name + "_mask") + .detach() + .clone(memory_format=torch.contiguous_format) + ) # Use try/except because if anything goes wrong with the mask # computation etc., you'd want to roll back. try: # get the final mask, computed according to the specific method - mask = method.compute_mask(orig, default_mask=default_mask) + mask = method.compute_mask(importance_scores, default_mask=default_mask) # reparametrize by saving mask to `module[name + '_mask']`... module.register_buffer(name + "_mask", mask) # ... and the new pruned tensor to `module[name]` @@ -190,13 +206,18 @@ def _get_composite_method(cls, module, name, *args, **kwargs): return method - def prune(self, t, default_mask=None): + def prune(self, t, default_mask=None, importance_scores=None): r"""Computes and returns a pruned version of input tensor ``t`` according to the pruning rule specified in :meth:`compute_mask`. Args: t (torch.Tensor): tensor to prune (of same dimensions as ``default_mask``). + importance_scores (torch.Tensor): tensor of importance scores (of + same shape as ``t``) used to compute mask for pruning ``t``. + The values in this tensor indicate the importance of the + corresponding elements in the ``t`` that is being pruned. + If unspecified or None, the tensor ``t`` will be used in its place. default_mask (torch.Tensor, optional): mask from previous pruning iteration, if any. To be considered when determining what portion of the tensor that pruning should act on. If None, @@ -205,9 +226,14 @@ def prune(self, t, default_mask=None): Returns: pruned version of tensor ``t``. """ - if default_mask is None: - default_mask = torch.ones_like(t) - return t * self.compute_mask(t, default_mask=default_mask) + if importance_scores is not None: + assert ( + importance_scores.shape == t.shape + ), "importance_scores should have the same shape as tensor t" + else: + importance_scores = t + default_mask = default_mask if default_mask is not None else torch.ones_like(t) + return t * self.compute_mask(importance_scores, default_mask=default_mask) def remove(self, module): r"""Removes the pruning reparameterization from a module. The pruned @@ -249,7 +275,7 @@ class PruningContainer(BasePruningMethod): """ def __init__(self, *args): - self._pruning_methods: Tuple['BasePruningMethod', ...] = tuple() + self._pruning_methods: Tuple["BasePruningMethod", ...] = tuple() if not isinstance(args, Iterable): # only 1 item self._tensor_name = args._tensor_name self.add_pruning_method(args) @@ -319,6 +345,7 @@ def compute_mask(self, t, default_mask): pruning ``method`` (of same dimensions as ``default_mask`` and ``t``). """ + def _combine_masks(method, t, mask): r""" Args: @@ -360,13 +387,12 @@ def _combine_masks(method, t, mask): # if dim is still negative after subtracting it from n_dims if dim < 0: raise IndexError( - 'Index is out of bounds for tensor with dimensions {}' - .format(n_dims) + "Index is out of bounds for tensor with dimensions {}".format( + n_dims + ) ) # find channels along dim = dim that aren't already tots 0ed out - keep_channel = ( - mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0 - ) + keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0 # create slice to identify what to prune slc = [slice(None)] * n_dims slc[dim] = keep_channel @@ -470,9 +496,7 @@ def apply(cls, module, name, amount): fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. """ - return super(RandomUnstructured, cls).apply( - module, name, amount=amount - ) + return super(RandomUnstructured, cls).apply(module, name, amount=amount) class L1Unstructured(BasePruningMethod): @@ -509,16 +533,14 @@ def compute_mask(self, t, default_mask): if nparams_toprune != 0: # k=0 not supported by torch.kthvalue # largest=True --> top k; largest=False --> bottom k # Prune the smallest k - topk = torch.topk( - torch.abs(t).view(-1), k=nparams_toprune, largest=False - ) + topk = torch.topk(torch.abs(t).view(-1), k=nparams_toprune, largest=False) # topk will have .indices and .values mask.view(-1)[topk.indices] = 0 return mask @classmethod - def apply(cls, module, name, amount): + def apply(cls, module, name, amount, importance_scores=None): r"""Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask. @@ -531,8 +553,15 @@ def apply(cls, module, name, amount): If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. """ - return super(L1Unstructured, cls).apply(module, name, amount=amount) + return super(L1Unstructured, cls).apply( + module, name, amount=amount, importance_scores=importance_scores + ) class RandomStructured(BasePruningMethod): @@ -587,7 +616,6 @@ def compute_mask(self, t, default_mask): # Compute number of units to prune: amount if int, # else amount * tensor_size nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) - nparams_tokeep = tensor_size - nparams_toprune # This should raise an error if the number of units to prune is larger # than the number of units in the tensor _validate_pruning_amount(nparams_toprune, tensor_size) @@ -635,9 +663,7 @@ def apply(cls, module, name, amount, dim=-1): dim (int, optional): index of the dim along which we define channels to prune. Default: -1. """ - return super(RandomStructured, cls).apply( - module, name, amount=amount, dim=dim - ) + return super(RandomStructured, cls).apply(module, name, amount=amount, dim=dim) class LnStructured(BasePruningMethod): @@ -706,11 +732,7 @@ def compute_mask(self, t, default_mask): norm = _compute_norm(t, self.n, self.dim) # largest=True --> top k; largest=False --> bottom k # Keep the largest k channels along dim=self.dim - topk = torch.topk( - norm, - k=nparams_tokeep, - largest=True, - ) + topk = torch.topk(norm, k=nparams_tokeep, largest=True) # topk will have .indices and .values # Compute binary mask by initializing it to all 0s and then filling in @@ -738,7 +760,7 @@ def make_mask(t, dim, indices): return mask @classmethod - def apply(cls, module, name, amount, n, dim): + def apply(cls, module, name, amount, n, dim, importance_scores=None): r"""Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask. @@ -755,9 +777,19 @@ def apply(cls, module, name, amount, n, dim): entries for argument ``p`` in :func:`torch.norm`. dim (int): index of the dim along which we define channels to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. """ return super(LnStructured, cls).apply( - module, name, amount=amount, n=n, dim=dim + module, + name, + amount=amount, + n=n, + dim=dim, + importance_scores=importance_scores, ) @@ -784,9 +816,7 @@ def apply(cls, module, name, mask): name (str): parameter name within ``module`` on which pruning will act. """ - return super(CustomFromMask, cls).apply( - module, name, mask - ) + return super(CustomFromMask, cls).apply(module, name, mask=mask) def identity(module, name): @@ -853,7 +883,7 @@ def random_unstructured(module, name, amount): return module -def l1_unstructured(module, name, amount): +def l1_unstructured(module, name, amount, importance_scores=None): r"""Prunes tensor corresponding to parameter called ``name`` in ``module`` by removing the specified `amount` of (currently unpruned) units with the lowest L1-norm. @@ -873,6 +903,11 @@ def l1_unstructured(module, name, amount): If ``float``, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If ``int``, it represents the absolute number of parameters to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. Returns: module (nn.Module): modified (i.e. pruned) version of the input module @@ -882,7 +917,9 @@ def l1_unstructured(module, name, amount): >>> m.state_dict().keys() odict_keys(['bias', 'weight_orig', 'weight_mask']) """ - L1Unstructured.apply(module, name, amount) + L1Unstructured.apply( + module, name, amount=amount, importance_scores=importance_scores + ) return module @@ -923,7 +960,7 @@ def random_structured(module, name, amount, dim): return module -def ln_structured(module, name, amount, n, dim): +def ln_structured(module, name, amount, n, dim, importance_scores=None): r"""Prunes tensor corresponding to parameter called ``name`` in ``module`` by removing the specified ``amount`` of (currently unpruned) channels along the specified ``dim`` with the lowest L``n``-norm. @@ -946,6 +983,11 @@ def ln_structured(module, name, amount, n, dim): n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid entries for argument ``p`` in :func:`torch.norm`. dim (int): index of the dim along which we define channels to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. Returns: module (nn.Module): modified (i.e. pruned) version of the input module @@ -955,11 +997,13 @@ def ln_structured(module, name, amount, n, dim): nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf') ) """ - LnStructured.apply(module, name, amount, n, dim) + LnStructured.apply( + module, name, amount, n, dim, importance_scores=importance_scores + ) return module -def global_unstructured(parameters, pruning_method, **kwargs): +def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs): r""" Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``. @@ -978,6 +1022,12 @@ def global_unstructured(parameters, pruning_method, **kwargs): pruning_method (function): a valid pruning function from this module, or a custom one implemented by the user that satisfies the implementation guidelines and has ``PRUNING_TYPE='unstructured'``. + importance_scores (dict): a dictionary mapping (module, name) tuples to + the corresponding parameter's importance scores tensor. The tensor + should be the same shape as the parameter, and is used for computing + mask for pruning. + If unspecified or None, the parameter will be used in place of its + importance scores. kwargs: other keyword arguments such as: amount (int or float): quantity of parameters to prune across the specified parameters. @@ -1012,17 +1062,25 @@ def global_unstructured(parameters, pruning_method, **kwargs): """ # ensure parameters is a list or generator of tuples - assert isinstance(parameters, Iterable) + if not isinstance(parameters, Iterable): + raise TypeError("global_unstructured(): parameters is not an Iterable") - # flatten parameter values to consider them all at once in global pruning - t = torch.nn.utils.parameters_to_vector([getattr(*p) for p in parameters]) + importance_scores = importance_scores if importance_scores is not None else {} + if not isinstance(importance_scores, dict): + raise TypeError("global_unstructured(): importance_scores must be of type dict") + + # flatten importance scores to consider them all at once in global pruning + relevant_importance_scores = torch.nn.utils.parameters_to_vector( + [ + importance_scores.get((module, name), getattr(module, name)) + for (module, name) in parameters + ] + ) # similarly, flatten the masks (if they exist), or use a flattened vector # of 1s of the same dimensions as t default_mask = torch.nn.utils.parameters_to_vector( [ - getattr( - module, name + "_mask", torch.ones_like(getattr(module, name)) - ) + getattr(module, name + "_mask", torch.ones_like(getattr(module, name))) for (module, name) in parameters ] ) @@ -1045,7 +1103,7 @@ def global_unstructured(parameters, pruning_method, **kwargs): # use the `compute_mask` method from `PruningContainer` to combine the # mask computed by the new method with the pre-existing mask - final_mask = container.compute_mask(t, default_mask) + final_mask = container.compute_mask(relevant_importance_scores, default_mask) # Pointer for slicing the mask to match the shape of each parameter pointer = 0 @@ -1058,7 +1116,7 @@ def global_unstructured(parameters, pruning_method, **kwargs): param_mask = final_mask[pointer : pointer + num_param].view_as(param) # Assign the correct pre-computed mask to each parameter and add it # to the forward_pre_hooks like any other pruning method - custom_from_mask(module, name, param_mask) + custom_from_mask(module, name, mask=param_mask) # Increment the pointer to continue slicing the final_mask pointer += num_param @@ -1173,8 +1231,7 @@ def _validate_pruning_amount_init(amount): """ if not isinstance(amount, numbers.Real): raise TypeError( - "Invalid type for amount: {}. Must be int or float." - "".format(amount) + "Invalid type for amount: {}. Must be int or float." "".format(amount) ) if (isinstance(amount, numbers.Integral) and amount < 0) or ( @@ -1261,9 +1318,7 @@ def _validate_pruning_dim(t, dim): dim (int): index of the dim along which we define channels to prune """ if dim >= t.dim(): - raise IndexError( - "Invalid index {} for tensor of size {}".format(dim, t.shape) - ) + raise IndexError("Invalid index {} for tensor of size {}".format(dim, t.shape)) def _compute_norm(t, n, dim): diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py index f62c2e921c497..eab89d8189ca6 100644 --- a/torch/nn/utils/rnn.py +++ b/torch/nn/utils/rnn.py @@ -213,9 +213,10 @@ def pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True) them to compute the loss directly. A Tensor can be retrieved from a :class:`PackedSequence` object by accessing its ``.data`` attribute. - Arguments: + Args: input (Tensor): padded batch of variable length sequences. - lengths (Tensor): list of sequences lengths of each batch element. + lengths (Tensor or list(int)): list of sequence lengths of each batch + element (must be on the CPU if provided as a tensor). batch_first (bool, optional): if ``True``, the input is expected in ``B x T x *`` format. enforce_sorted (bool, optional): if ``True``, the input is expected to @@ -278,7 +279,7 @@ def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_le See :ref:`this FAQ section ` for details. - Arguments: + Args: sequence (PackedSequence): batch to pad batch_first (bool, optional): if ``True``, the output will be in ``B x T x *`` format. @@ -342,7 +343,7 @@ def pad_sequence(sequences, batch_first=False, padding_value=0.0): where `T` is the length of the longest sequence. This function assumes trailing dimensions and type of all the Tensors in sequences are same. - Arguments: + Args: sequences (list[Tensor]): list of variable length sequences. batch_first (bool, optional): output will be in ``B x T x *`` if True, or in ``T x B x *`` otherwise @@ -397,7 +398,7 @@ def pack_sequence(sequences, enforce_sorted=True): PackedSequence(data=tensor([ 1, 4, 6, 2, 5, 3]), batch_sizes=tensor([ 3, 2, 1])) - Arguments: + Args: sequences (list[Tensor]): A list of sequences of decreasing length. enforce_sorted (bool, optional): if ``True``, checks that the input contains sequences sorted by length in a decreasing order. If diff --git a/torch/nn/utils/spectral_norm.py b/torch/nn/utils/spectral_norm.py index 22e1f8fd7de68..2ae7af3d78bd1 100644 --- a/torch/nn/utils/spectral_norm.py +++ b/torch/nn/utils/spectral_norm.py @@ -120,6 +120,10 @@ def apply(module: Module, name: str, n_power_iterations: int, dim: int, eps: flo fn = SpectralNorm(name, n_power_iterations, dim, eps) weight = module._parameters[name] + if isinstance(weight, torch.nn.parameter.UninitializedParameter): + raise ValueError( + 'The module passed to `SpectralNorm` can\'t have uninitialized parameters. ' + 'Make sure to run the dummy forward before applying spectral normalization') with torch.no_grad(): weight_mat = fn.reshape_weight_to_matrix(weight) diff --git a/torch/nn/utils/weight_norm.py b/torch/nn/utils/weight_norm.py index 684b0b9f239f3..c10a5f917a713 100644 --- a/torch/nn/utils/weight_norm.py +++ b/torch/nn/utils/weight_norm.py @@ -1,7 +1,7 @@ r""" Weight Normalization from https://arxiv.org/abs/1602.07868 """ -from torch.nn.parameter import Parameter +from torch.nn.parameter import Parameter, UninitializedParameter from torch import _weight_norm, norm_except_dim from typing import Any, TypeVar from ..modules import Module @@ -36,7 +36,10 @@ def apply(module, name: str, dim: int) -> 'WeightNorm': fn = WeightNorm(name, dim) weight = getattr(module, name) - + if isinstance(weight, UninitializedParameter): + raise ValueError( + 'The module passed to `WeightNorm` can\'t have uninitialized parameters. ' + 'Make sure to run the dummy forward before applying weight normalization') # remove w from parameter list del module._parameters[name] diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 4ec89e4c9b0b8..5b3aaea115e75 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -40,17 +40,63 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM once in order to get a trace of its execution to be exported; at the moment, it supports a limited set of dynamic models (e.g., RNNs.) - Arguments: + Args: model (torch.nn.Module): the model to be exported. - args (tuple of arguments or torch.Tensor): the inputs to - the model, e.g., such that ``model(*args)`` is a valid - invocation of the model. Any non-Tensor arguments will - be hard-coded into the exported model; any Tensor arguments - will become inputs of the exported model, in the order they - occur in args. If args is a Tensor, this is equivalent - to having called it with a 1-ary tuple of that Tensor. - (Note: passing keyword arguments to the model is not currently - supported. Give us a shout if you need it.) + args (tuple of arguments or torch.Tensor, a dictionary consisting of named arguments (optional)): + a dictionary to specify the input to the corresponding named parameter: + - KEY: str, named parameter + - VALUE: corresponding input + args can be structured either as: + + 1. ONLY A TUPLE OF ARGUMENTS or torch.Tensor:: + + ‘’args = (x, y, z)’' + + The inputs to the model, e.g., such that ``model(*args)`` is a valid invocation + of the model. Any non-Tensor arguments will be hard-coded into the exported model; + any Tensor arguments will become inputs of the exported model, in the order they + occur in args. If args is a Tensor, this is equivalent to having + called it with a 1-ary tuple of that Tensor. + + 2. A TUPLE OF ARGUEMENTS WITH A DICTIONARY OF NAMED PARAMETERS:: + + ‘’args = (x, + { + ‘y’: input_y, + ‘z’: input_z + }) ‘’ + + The inputs to the model are structured as a tuple consisting of + non-keyword arguments and the last value of this tuple being a dictionary + consisting of named parameters and the corresponding inputs as key-value pairs. + If certain named argument is not present in the dictionary, it is assigned + the default value, or None if default value is not provided. + + Cases in which an dictionary input is the last input of the args tuple + would cause a conflict when a dictionary of named parameters is used. + The model below provides such an example. + + class Model(torch.nn.Module): + def forward(self, k, x): + ... + return x + + m = Model() + k = torch.randn(2, 3)   + x = {torch.tensor(1.): torch.randn(2, 3)} + + In the previous iteration, the call to export API would look like + + torch.onnx.export(model, (k, x), ‘test.onnx’) + + This would work as intended. However, the export function + would now assume that the ‘x’ input is intended to represent the optional + dictionary consisting of named arguments. In order to prevent this from being + an issue a constraint is placed to provide an empty dictionary as the last + input in the tuple args in such cases. The new call would look like this. + + torch.onnx.export(model, (k, x, {}), ‘test.onnx’) + f: a file-like object (has to implement fileno that returns a file descriptor) or a string containing a file name. A binary Protobuf will be written to this file. @@ -102,12 +148,12 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM exporter falls back on this op. OperatorExportTypes.RAW: Export raw ir. OperatorExportTypes.ONNX_FALLTHROUGH: If an op is not supported - in ONNX, fall through and export the operator as is, as a custom + in ONNX, fall through and export the operator as is, as a custom ONNX op. Using this mode, the op can be exported and implemented by the user for their runtime backend. Example graph:: - graph(%x.1 : Long(1:1)):: + graph(%x.1 : Long(1, strides=[1])):: %1 : None = prim::Constant() %2 : Tensor = aten::sum(%x.1, %1) %y.1 : Tensor[] = prim::ListConstruct(%2) @@ -115,7 +161,7 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM is exported as:: - graph(%x.1 : Long(1:1)):: + graph(%x.1 : Long(1, strides=[1])):: %1 : Tensor = onnx::ReduceSum[keepdims=0](%x.1) %y.1 : Long() = prim::ListConstruct(%1) return (%y.1) @@ -127,7 +173,7 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM opset version of the onnx submodule. Since ONNX's latest opset may evolve before next stable release, by default we export to one stable opset version. Right now, supported stable opset version is 9. - The opset_version must be _onnx_master_opset or in _onnx_stable_opsets + The opset_version must be _onnx_main_opset or in _onnx_stable_opsets which are defined in torch/onnx/symbolic_helper.py do_constant_folding (bool, default False): If True, the constant-folding optimization is applied to the model during export. Constant-folding @@ -175,7 +221,7 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM 2:'height'}, 'input_2':{0:'batch'}, 'output':{0:'batch', - 1:'detections'}`` + 1:'detections'}}`` where provided names will be applied to exported dynamic axes 3. MIXED MODE OF (1) and (2):: @@ -212,13 +258,13 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM external_data_format (bool, default False): If True, then the model is exported in ONNX external data format, in which case some of the model parameters are stored in external binary files and not in the ONNX model file itself. See link for format - details: + details: https://github.com/onnx/onnx/blob/8b3f7e2e7a0f2aba0e629e23d89f07c7fc0e6a5e/onnx/onnx.proto#L423 Also, in this case, argument 'f' must be a string specifying the location of the model. - The external binary files will be stored in the same location specified by the model + The external binary files will be stored in the same location specified by the model location 'f'. If False, then the model is stored in regular format, i.e. model and parameters are all in one file. This argument is ignored for all export types other - than ONNX. + than ONNX. """ from torch.onnx import utils diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 07462d3f21a5a..187dcfcb87e2b 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -2,6 +2,7 @@ import torch import warnings from sys import maxsize as maxsize +from typing import Set import torch.onnx # This import monkey-patches graph manipulation methods on Graph, used for the @@ -125,7 +126,7 @@ def decorator(fn): def wrapper(g, *args, **kwargs): # some args may be optional, so the length may be smaller assert len(arg_descriptors) >= len(args) - args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)] + args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)] # type: ignore # only support _outputs in kwargs assert len(kwargs) <= 1 if len(kwargs) == 1: @@ -171,10 +172,35 @@ def _is_none(x): def _is_value(x): return isinstance(x, torch._C.Value) +def _is_tensor(x): + return x.type().isSubtypeOf(torch._C.TensorType.get()) def _is_tensor_list(x): return isinstance(x.type(), torch._C.ListType) and isinstance(x.type().getElementType(), torch._C.TensorType) +def _get_tensor_rank(x): + if not _is_tensor(x) or x.type() is None: + return None + return x.type().dim() + +def _get_tensor_sizes(x, allow_nonstatic=True): + if not _is_tensor(x) or x.type() is None: + return None + if allow_nonstatic: + # Each individual symbol is returned as None. + # e.g. [1, 'a', 'b'] -> [1, None, None] + return x.type().varyingSizes() + # returns None, if exists any symbol in sizes. + # e.g. [1, 'a', 'b'] -> None + return x.type().sizes() + +def _get_tensor_dim_size(x, dim): + try: + sizes = _get_tensor_sizes(x) + return sizes[dim] + except Exception: + pass + return None def _unimplemented(op, msg): warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported") @@ -211,19 +237,49 @@ def _try_get_scalar_type(*args): return None +def _select_helper(g, self, dim, index, apply_reshape=True): + index_const = _maybe_get_scalar(index) + index_dim = _get_tensor_rank(index) + if not _is_value(index_const): + # Index is a constant scalar. Make it a size 1 constant tensor. + index = g.op("Constant", value_t=torch.LongTensor([index_const])) + elif index_dim is not None and apply_reshape: + if index_dim == 0: + # Index is a scalar. Reshape it to a size 1 tensor. + index = g.op("Reshape", index, g.op("Constant", value_t=torch.LongTensor([1]))) + + index_scalar_type = index.type().scalarType() + if index_scalar_type is None or index_scalar_type not in ['Long', 'Int']: + index = g.op("Cast", index, to_i=cast_pytorch_to_onnx["Long"]) + return g.op("Gather", self, index, axis_i=dim) + + def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False): if _export_onnx_opset_version <= 9: - from torch.onnx.symbolic_opset9 import _slice - return _slice(g, input, axes, starts, ends) + from torch.onnx.symbolic_opset9 import _slice as _slice9 + return _slice9(g, input, axes, starts, ends) else: - from torch.onnx.symbolic_opset10 import _slice - return _slice(g, input, axes, starts, ends, steps, dynamic_slice) + from torch.onnx.symbolic_opset10 import _slice as _slice10 + return _slice10(g, input, axes, starts, ends, steps, dynamic_slice) +def _hardtanh_helper(g, input, min_val, max_val): + if _export_onnx_opset_version <= 10: + from torch.onnx.symbolic_opset9 import hardtanh + return hardtanh(g, input, min_val, max_val) + else: + from torch.onnx.symbolic_opset11 import hardtanh # type: ignore[no-redef] + return hardtanh(g, input, min_val, max_val) def _is_fp(value): if value: - type = value.type().scalarType() - return (type == 'Float') or (type == 'Double') or (type == 'Half') + if isinstance(value, torch.Tensor): + type = value.dtype + return (type == 'torch.float32') or (type == 'torch.float64') or (type == 'torch.float16') + else: + type = value.type().scalarType() + if type is None: + warnings.warn("Type cannot be inferred, which might cause exported graph to produce incorrect results.") + return (type == 'Float') or (type == 'Double') or (type == 'Half') return False @@ -311,7 +367,8 @@ def _get_interpolate_attributes(g, mode, args): def _interpolate_get_scales(g, scale_factor, dim): offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32)) - if isinstance(scale_factor.type(), torch._C.ListType) or (scale_factor.isCompleteTensor() and scale_factor.type().dim() > 0): + scale_factor_rank = _get_tensor_rank(scale_factor) + if isinstance(scale_factor.type(), torch._C.ListType) or (scale_factor_rank is not None and scale_factor_rank > 0): return g.op("Concat", offsets, scale_factor, axis_i=0) else: scale_factor = _unsqueeze_helper(g, scale_factor, 0) @@ -348,15 +405,24 @@ def _interpolate_get_scales_and_mode(g, input, size, scale_factor, mode , align_ size = g.op("Concat", *size, axis_i=0) scale_factor = _interpolate_size_to_scales(g, input, size, dim) else: - return _unimplemented("Both size and scales are None in __interpolate") + return _unimplemented("interpolate", "Both size and scales are None in __interpolate") return scale_factor, mode +def _unbind_helper(g, self, dim, _outputs): + if _export_onnx_opset_version <= 9: + from torch.onnx.symbolic_opset9 import unbind + else: + from torch.onnx.symbolic_opset11 import unbind # type: ignore[no-redef] + return unbind(g, self, dim, _outputs) + + def _scatter_helper(g, self, dim, index, src): if _export_onnx_opset_version <= 10: from torch.onnx.symbolic_opset9 import scatter else: - from torch.onnx.symbolic_opset11 import scatter + # for mypy, scatter was imported two lines above + from torch.onnx.symbolic_opset11 import scatter # type: ignore return scatter(g, self, dim, index, src) @@ -404,7 +470,8 @@ def _index_fill_reshape_helper(g, self, dim, index): if _export_onnx_opset_version <= 10: from torch.onnx.symbolic_opset9 import scatter else: - from torch.onnx.symbolic_opset11 import scatter + # for mypy, scatter was imported two lines above + from torch.onnx.symbolic_opset11 import scatter # type: ignore if self.type().dim() is None: return _unimplemented("index_fill", "input rank not accesible") @@ -486,7 +553,7 @@ def _is_split_static(split_size_or_sizes, _outputs): _default_onnx_opset_version = 9 -_onnx_master_opset = 10 +_onnx_main_opset = 13 _onnx_stable_opsets = [7, 8, 9, 10, 11, 12] _export_onnx_opset_version = _default_onnx_opset_version @@ -496,7 +563,7 @@ def _set_opset_version(opset_version): if opset_version == _default_onnx_opset_version: _export_onnx_opset_version = opset_version return - if opset_version in _onnx_stable_opsets + [_onnx_master_opset]: + if opset_version in _onnx_stable_opsets + [_onnx_main_opset]: _export_onnx_opset_version = opset_version return raise ValueError("Unsupported ONNX opset version: " + str(opset_version)) @@ -592,4 +659,4 @@ def _cast_func_template(to_i, g, input, non_blocking): # Global set to store the list of quantized operators in the network. # This is currently only used in the conversion of quantized ops from PT -> C2 via ONNX. -_quantized_ops = set() +_quantized_ops: Set[int] = set() diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py index 88bc6d5fb8b5d..6558df6e3d4ce 100644 --- a/torch/onnx/symbolic_opset10.py +++ b/torch/onnx/symbolic_opset10.py @@ -5,12 +5,12 @@ # This import monkey-patches graph manipulation methods on Graph, used for the # ONNX symbolics import torch.onnx.utils -from sys import maxsize import torch.onnx.symbolic_helper as sym_help from torch.onnx.symbolic_helper import parse_args, _unimplemented import torch.onnx.symbolic_opset9 +from sys import maxsize # EDITING THIS FILE? READ THIS FIRST! # see Note [Edit Symbolic Files] in symbolic_helper.py @@ -205,44 +205,17 @@ def embedding_bag(g, include_last_offset): if scale_grad_by_freq and sym_help._training_mode: return sym_help._onnx_unsupported('embedding_bag with scale_grad_by_freq for training mode') - - from torch.onnx.symbolic_opset9 import size, div, select - - # Check if initial indices was 2D. In functional.py: - # offsets is set to torch.arange(0, indices.numel(), indices.size(1)) - # Then indices is reshaped to 1D: indices.reshape(-1) - if len(list(indices.node().inputs())) > 0 and indices.node().inputs().__next__().type().sizes() is not None \ - and len(indices.node().inputs().__next__().type().sizes()) == 2: - # Assert include_last_offset is False - assert not include_last_offset - embeddings = g.op("Gather", embedding_matrix, indices) - dim_0 = size(g, offsets, g.op("Constant", value_t=torch.LongTensor([0]))) - dim_1 = div(g, size(g, indices, g.op("Constant", value_t=torch.LongTensor([0]))), dim_0) - dim_2 = g.op("Constant", value_t=torch.LongTensor([-1])) - - shape = [dim_0, dim_1, dim_2] - shape = g.op("Concat", *shape, axis_i=0) - - if not sym_help._is_none(per_sample_weights): - per_sample_weights = g.op("Unsqueeze", per_sample_weights, axes_i=[1]) - embeddings = g.op("Mul", embeddings, per_sample_weights) - - embeddings = g.op("Reshape", embeddings, shape) - if mode == 0: - embeddings = g.op("ReduceSum", embeddings, axes_i=[1], keepdims_i=0) - elif mode == 1: - embeddings = g.op("ReduceMean", embeddings, axes_i=[1], keepdims_i=0) - else: - embeddings = g.op("ReduceMax", embeddings, axes_i=[1], keepdims_i=0) - # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices. - # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag. - return embeddings, None, None, None - elif offsets.type().sizes() is not None: + from torch.onnx.symbolic_opset9 import select + import warnings + warnings.warn("Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. " + "Please use opset 11 or higher to export model for dynamic input shape.'") + offsets_dim_0 = sym_help._get_tensor_dim_size(offsets, 0) + if offsets_dim_0 is not None: if include_last_offset: - offset_len = offsets.type().sizes()[0] - 1 + offset_len = offsets_dim_0 - 1 offsets_extended = offsets else: - offset_len = offsets.type().sizes()[0] + offset_len = offsets_dim_0 offsets_extended = [offsets, g.op("Constant", value_t=torch.tensor([maxsize]))] offsets_extended = g.op("Concat", *offsets_extended, axis_i=0) list_ = [] @@ -272,7 +245,8 @@ def embedding_bag(g, # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag. return output, None, None, None else: - return sym_help._onnx_unsupported('embedding_bag with unknown shape of indices') + return sym_help._onnx_unsupported('embedding_bag with unknown shape of offsets for opset 10 is not supported. ' + 'please use opset 11 or higher.') @parse_args('v', 't', 'i', 'i', 'i') diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 01cdd48906334..85c7bf97c8830 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -6,10 +6,10 @@ import warnings import numpy -from torch.onnx.symbolic_helper import parse_args, _unimplemented +from torch.onnx.symbolic_helper import parse_args, _unimplemented, _is_tensor_list from torch.onnx.symbolic_opset9 import expand, unused from torch.nn.modules.utils import _single, _pair, _triple - +from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block # EDITING THIS FILE? READ THIS FIRST! # see Note [Edit Symbolic Files] in symbolic_helper.py @@ -80,6 +80,105 @@ def index_put(g, self, indices_list_value, values, accumulate=False): ] index = g.op("Concat", *indices_list, axis_i=-1) else: + # Replace index_put node with masked_scatter or masked_fill + # when inputs to the index_put node contains boolean inputs + # + # index_put -> masked_fill + # + # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu), + # %some_const : Float(requires_grad=0, device=cpu)): + # %6 : None = prim::Constant() + # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) + # %8 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::ne(%mask, %some_const) + # %26 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]() + # %27 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %11 : Device = prim::Constant[value="cpu"]() + # %12 : None = prim::Constant() + # %28 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %29 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %15 : None = prim::Constant() + # %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = + # aten::to(%8, %26, %27, %11, %12, %28, %29, %15) + # %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]() + # %30 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %22 : int[] = prim::Constant[value=[-1]]() + # %23 : Tensor = aten::view(%16, %22) + # %24 : Tensor?[] = prim::ListConstruct(%23) + # %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = + # aten::index_put(%mask, %24, %18, %30) + # return (%25) + # + # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu), + # %some_const : Float(requires_grad=0, device=cpu)): + # %3 : Tensor = onnx::Equal(%0, %some_const) + # %4 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%3) + # %12 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%4) + # %19 : Tensor = onnx::Cast[to=9](%12) + # %20 : Tensor = onnx::Constant[value={1}]() + # %21 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = onnx::Where(%19, %20, %0) + # return (%21) + # + # index_put -> masked_scatter + # + # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu), + # %some_const : Float(requires_grad=0, device=cpu)): + # %6 : None = prim::Constant() + # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) + # %28 : Float(8, strides=[1], requires_grad=0, device=cpu) + # = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]() + # %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::ne(%mask, %some_const) + # %34 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]() + # %35 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %18 : Device = prim::Constant[value="cpu"]() + # %19 : None = prim::Constant() + # %36 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %37 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %22 : None = prim::Constant() + # %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::to(%15, %34, %35, %18, %19, %36, %37, %22) + # %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %30 : int[] = prim::Constant[value=[-1]]() + # %31 : Tensor = aten::view(%23, %30) + # %32 : Tensor?[] = prim::ListConstruct(%31) + # %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::index_put(%mask, %32, %28, %38) + # return (%33) + # + # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu), + # %some_const : Float(requires_grad=0, device=cpu)): + # %3 : Float(8, strides=[1], requires_grad=0, device=cpu) + # = onnx::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]() + # %4 : Tensor = onnx::Equal(%0, %some_const) + # %5 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%4) + # %13 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%5) + # %19 : Tensor = onnx::Shape(%0) + # %20 : Tensor = onnx::Expand(%13, %19) + # %21 : Tensor = onnx::NonZero(%20) + # %22 : Tensor = onnx::Transpose[perm=[1, 0]](%21) + # %23 : Tensor = onnx::Constant[value={-1}]() + # %24 : Tensor = onnx::Reshape(%3, %23) + # %25 : Tensor = onnx::Shape(%22) + # %27 : Tensor = onnx::Constant[value={0}]() + # %28 : Tensor = onnx::Gather[axis=0](%25, %27) + # %29 : Tensor = onnx::Constant[value={0}]() + # %30 : Tensor = onnx::Unsqueeze[axes=[0]](%29) + # %31 : Tensor = onnx::Unsqueeze[axes=[0]](%28) + # %32 : Tensor = onnx::Constant[value={0}]() + # %33 : Tensor = onnx::Unsqueeze[axes=[0]](%32) + # %34 : Tensor = onnx::Slice(%24, %30, %31, %33) + # %35 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = onnx::ScatterND(%0, %22, %34) + # return (%35) + + bool_inp = list(index.node().inputs())[0] + if bool_inp.type() is not None and bool_inp.type().scalarType() == 'Bool': + rank = sym_help._get_tensor_rank(values) + if rank is not None and rank == 0: + from torch.onnx.symbolic_opset9 import masked_fill + return masked_fill(g, self, bool_inp, values) + return masked_scatter(g, self, bool_inp, values) broadcast_index_shape = g.op("Shape", index) index = g.op("Unsqueeze", index, axes_i=[-1]) sub_data_shape = sym_help._slice_helper( @@ -102,8 +201,8 @@ def index_put(g, self, indices_list_value, values, accumulate=False): @parse_args('v', 'i') def pixel_shuffle(g, self, upscale_factor): - dims = self.type().sizes() - if len(dims) != 4: + rank = sym_help._get_tensor_rank(self) + if rank is not None and rank != 4: return _unimplemented("pixel_shuffle", "only support 4d input") return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD") @@ -181,11 +280,12 @@ def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_s "while exporting interpolate. Assuming that it is not a scalar.") if is_scalar: - if not input.type().dim(): + rank = sym_help._get_tensor_rank(input) + if rank is None: return sym_help._unimplemented("interpolate (with a scalar output_size)", "missing input shape (try giving an array of output_size values)") size = unsqueeze(g, size, 0) - size = [size for i in range(input.type().dim() - 2)] + size = [size for i in range(rank - 2)] size = g.op("Concat", *size, axis_i=0) size = g.op("Cast", size, to_i=sym_help.cast_pytorch_to_onnx['Long']) size = g.op("Concat", input_size, size, axis_i=0) @@ -200,9 +300,10 @@ def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_s mode_s=mode, # nearest, linear, or cubic nearest_mode_s="floor") else: # if not sym_help._is_none(scales) - if not input.type().dim(): + rank = sym_help._get_tensor_rank(input) + if rank is None: return sym_help._unimplemented("interpolate (with scales)", "missing input shape") - scales = sym_help._interpolate_get_scales(g, scale_factor, input.type().dim()) + scales = sym_help._interpolate_get_scales(g, scale_factor, rank) return g.op("Resize", input, roi, @@ -272,9 +373,10 @@ def masked_scatter(g, self, mask, source): def _len(g, self): - if self.type().isSubtypeOf(torch._C.ListType.ofTensors()): + if _is_tensor_list(self) or self.node().kind() == "onnx::SplitToSequence": return g.op("SequenceLength", self) - return g.op("Size", self) + sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) + return g.op('Squeeze', sz_0, axes_i=[0]) def __getitem_(g, self, i): @@ -416,7 +518,7 @@ def unbind(g, self, dim=0, _outputs=None): # Generate paddings in ONNX order based on pad in pytorch. -# Arguments: +# Args: # dim: the dimension of the tensor. # pad: the paddings in pytorch. # The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end, @@ -449,19 +551,19 @@ def constant_pad_nd(g, input, padding, value=None): mode = "constant" value = sym_help._maybe_get_scalar(value) value = sym_help._if_scalar_type_as(g, value, input) - pad = _prepare_onnx_paddings(g, input.type().dim(), padding) + pad = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding) return g.op("Pad", input, pad, value, mode_s=mode) def reflection_pad(g, input, padding): mode = "reflect" - paddings = _prepare_onnx_paddings(g, input.type().dim(), padding) + paddings = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding) return g.op("Pad", input, paddings, mode_s=mode) def replication_pad(g, input, padding): mode = "edge" - paddings = _prepare_onnx_paddings(g, input.type().dim(), padding) + paddings = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding) return g.op("Pad", input, paddings, mode_s=mode) @@ -539,17 +641,40 @@ def squeeze(g, self, dim=None): dim = sym_help._get_const(dim, 'i', 'dim') - # create 'cond' node (condition is shape[i]==1) - dim_constant = g.op("Constant", value_t=torch.tensor([dim])) - size = sym_help._size_helper(g, self, dim_constant) - const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64)) - cond = g.op("Equal", size, const_one) - # create the 'If' node and add the 'then' and 'else' blocks to it. - if_node_outputs = g.op("If", cond) - if_node = if_node_outputs.node() - torch.onnx.utils._add_block(if_node, self, "onnx::Squeeze", axes_i=[dim]) - torch.onnx.utils._add_block(if_node, self, "onnx::Identity") - return if_node_outputs + input_rank = sym_help._get_tensor_rank(self) + adjusted_dim = dim + if input_rank is not None and dim < 0: + adjusted_dim += input_rank + dim_size = sym_help._get_tensor_dim_size(self, adjusted_dim) + if (dim < 0 and input_rank is None) or dim_size is None: + # If onnx shape inference is not on, export always as dynamic. + # Because we cannot tell if observed static shape is also static at runtime. + # create 'cond' node (condition is shape[i]==1) + dim_constant = g.op("Constant", value_t=torch.tensor([dim])) + size = sym_help._size_helper(g, self, dim_constant) + const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64)) + cond = g.op("Equal", size, const_one) + # create the 'If' node and add the 'then' and 'else' blocks to it. + if_node_outputs = g.op("If", cond) + if_node = if_node_outputs.node() + if_block = torch.onnx.utils._add_block(if_node) + squeeze_ = if_block.op("Squeeze", self, axes_i=[dim]) + torch.onnx.utils._add_output_to_block(if_block, squeeze_) + else_block = torch.onnx.utils._add_block(if_node) + identity_ = else_block.op("Identity", self) + torch.onnx.utils._add_output_to_block(else_block, identity_) + return if_node_outputs + + # For static input shape + dim = adjusted_dim + if dim_size > 1: + warnings.warn("This model contains a squeeze operation on dimension " + str(dim) + ". The size of " + + "this dimension in the given input is " + str(dim_size) + ". The model will " + + "be exported without the squeeze node. If the model is intended to be used with dynamic " + + "input shapes, please export with dynamic_axes argument.") + return self + return g.op("Squeeze", self, axes_i=[dim]) + @parse_args('v', 'i') def unsqueeze(g, self, dim): @@ -560,6 +685,26 @@ def mm(g, self, other): return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0) +def index(g, self, index): + if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: + return g.op("ATen", self, index, operator_s="index") + + if sym_help._is_packed_list(index): + indices = sym_help._unpack_list(index) + else: + indices = [index] + + # Handle single mask index. + if len(indices) == 1: + index = indices[0] + if not sym_help._is_none(index) and (index.type().scalarType() == "Bool" or index.type().scalarType() == "Byte"): + from torch.onnx.symbolic_opset9 import nonzero + index = nonzero(g, index) + return g.op('GatherND', self, index) + from torch.onnx.symbolic_opset9 import index as index_opset9 + return index_opset9(g, self, index) + + def index_fill(g, self, dim, index, value): dim_value = sym_help._parse_arg(dim, 'i') if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: @@ -593,7 +738,7 @@ def __rshift_(g, self, other): if not sym_help._is_fp(self): other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx['Float']) two_pow = g.op('Pow', two, other) - + two_pow = g.op('Cast', two_pow, to_i=sym_help.cast_pytorch_to_onnx[self.type().scalarType()]) rshift = g.op('Div', self, two_pow) return rshift @@ -612,7 +757,7 @@ def __lshift_(g, self, other): if not sym_help._is_fp(self): other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx['Float']) two_pow = g.op('Pow', two, other) - + two_pow = g.op('Cast', two_pow, to_i=sym_help.cast_pytorch_to_onnx[self.type().scalarType()]) lshift = g.op('Mul', self, two_pow) return lshift @@ -712,23 +857,114 @@ def im2col(g, input, kernel_size, dilation, padding, stride): return g.op("Reshape", output, output_shape) +def narrow(g, input, dim, start, length): + from torch.onnx.symbolic_helper import _slice_helper + end = g.op("Add", start, length) + return _slice_helper(g, input, axes=dim, starts=start, ends=end, dynamic_slice=True) + + @parse_args('v', 'i', 'i') def flatten(g, input, start_dim, end_dim): - dim = input.type().dim() - if dim is None: - return _unimplemented("dim", - "ONNX and PyTorch use different strategies to split the input. " - "Input rank must be known at export time.") - + dim = sym_help._get_tensor_rank(input) # use ONNX's Flatten operator for cases where the output shape is 2D if start_dim == 1: - if (end_dim == -1 or end_dim == dim - 1): + if (end_dim == -1 or (dim is not None and end_dim == dim - 1)): return g.op("Flatten", input, axis_i=start_dim) elif start_dim == 0: - if (end_dim == -2 or end_dim == dim - 2): + if (end_dim == -2 or (dim is not None and end_dim == dim - 2)): return g.op("Flatten", input, axis_i=end_dim + 1) + if dim is None: + return _unimplemented("dim", + "ONNX and PyTorch use different strategies to split the input. " + "Input rank must be known at export time.") # if end_dim is negative add dim if end_dim < 0 : end_dim = dim + end_dim return sym_help._flatten_helper(g, input, start_dim, end_dim, dim) + + +@parse_args('v', 'v', 'v', 'i', 'i', 'i', 'v', 'i') +def embedding_bag(g, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset): + if scale_grad_by_freq and sym_help._training_mode: + return sym_help._onnx_unsupported('embedding_bag with scale_grad_by_freq for training mode') + + loop_condition = g.op("Constant", value_t=torch.tensor(1)) + loop_condition = g.op("Cast", loop_condition, to_i=9) + zero = g.op("Constant", value_t=torch.tensor([0])) + + indices_len = g.op("Unsqueeze", + sym_help._size_helper(g, indices, g.op("Constant", value_t=torch.tensor(0))), + axes_i=[0]) + if not include_last_offset: + offsets = [offsets, indices_len] + offsets = g.op("Concat", *offsets, axis_i=0) + + # Offsets holds the starting index position of each bag. So we create a list of the indices slices (determined by + # offsets) and gather those indices in indices_row. Then we use this subset of indices to gather from embeddings. + # The embeddings output is a loop scan output, so we can avoid creating a sequence and inserting elements in. + offsets_starts = sym_help._slice_helper(g, offsets, axes=[0], starts=[0], ends=[maxsize], steps=[1]) + offsets_ends = sym_help._slice_helper(g, offsets, axes=[0], starts=[1], ends=[maxsize], steps=[1]) + + loop_len = sym_help._size_helper(g, offsets_ends, g.op("Constant", value_t=torch.tensor(0))) + loop = g.op("Loop", loop_len, loop_condition) + + loop_block = _add_block(loop.node()) + block_input_iter = _add_input_to_block(loop_block) + cond = _add_input_to_block(loop_block) + + indices_start = loop_block.op("Gather", offsets_starts, block_input_iter, axis_i=0) + indices_end = loop_block.op("Gather", offsets_ends, block_input_iter, axis_i=0) + indices_start = loop_block.op("Unsqueeze", indices_start, axes_i=[0]) + indices_end = loop_block.op("Unsqueeze", indices_end, axes_i=[0]) + + indices_row = loop_block.op("Slice", indices, indices_start, indices_end, zero) + embeddings = loop_block.op("Gather", embedding_matrix, indices_row, axis_i=0) + if not sym_help._is_none(per_sample_weights): + per_sample_weights_row = loop_block.op("Slice", per_sample_weights, + indices_start, + indices_end, + zero) + per_sample_weights_row = loop_block.op("Unsqueeze", per_sample_weights_row, axes_i=[1]) + embeddings = loop_block.op("Mul", embeddings, per_sample_weights_row) + if mode == 0: + embeddings = loop_block.op("ReduceSum", embeddings, axes_i=[0], keepdims_i=0) + elif mode == 1: + embeddings = loop_block.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0) + else: + embeddings = loop_block.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0) + + cond_out = loop_block.op("Cast", loop_condition, to_i=9) + _add_output_to_block(loop_block, cond_out) + _add_output_to_block(loop_block, embeddings) + + # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices. + # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag. + return loop.node().output(), None, None, None + + +def prim_ConstantChunk(g, self, chunks, dim): + input_shape = g.op("Shape", self) + axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + axis_next = g.op("Constant", value_t=torch.tensor([dim + 1], dtype=torch.long)) + input_shape_dim = g.op("Slice", input_shape, axis, axis_next) + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long)) + chunk_size_minus_1 = g.op("Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long)) + input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1) + chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size) + res = [] + for i in range(chunks): + index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long)) + end = g.op("Mul", chunk_dim, index) + res.append(g.op("Slice", self, start, end, axis)) + start = end + return res diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py index 9a7fb9bc9bc28..cd67fd508fa2d 100644 --- a/torch/onnx/symbolic_opset12.py +++ b/torch/onnx/symbolic_opset12.py @@ -1,7 +1,8 @@ import torch import torch.onnx.symbolic_helper as sym_help -from torch.onnx.symbolic_helper import parse_args, _parse_arg +from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented +from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block # EDITING THIS FILE? READ THIS FIRST! @@ -36,15 +37,9 @@ def nll_loss(g, self, target, weight, reduction, ignore_index): reduction_vals = ['none', 'mean', 'sum'] reduction = reduction_vals[reduction] - # when ignore_index is not specified, ignore_index == onnx::Constant[value={-100}] + # in onnx NegativeLogLikelihoodLoss specification, ignore_index is optional without default value. + # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100). ignore_index = sym_help._maybe_get_const(ignore_index, 'i') - if ignore_index == -100: - if weight.node().mustBeNone(): - return g.op("NegativeLogLikelihoodLoss", self, target, reduction_s=reduction) - else: - return g.op("NegativeLogLikelihoodLoss", self, target, weight, reduction_s=reduction) - - # if ignore_index is specified, compute nllloss with no reduction and apply the reduction afterwards if weight.node().mustBeNone(): nllloss = g.op("NegativeLogLikelihoodLoss", self, target, reduction_s=reduction, ignore_index_i=ignore_index) else: @@ -71,7 +66,7 @@ def celu(g, self, alpha): def argmax(g, input, dim, keepdim): if sym_help._is_none(dim): from torch.onnx.symbolic_opset9 import reshape - flattened = reshape(g, input, (-1,)) + flattened = reshape(g, input, g.op("Constant", value_t=torch.tensor([-1]))) return g.op('ArgMax', flattened, axis_i=0, keepdims_i=False, select_last_index_i=False) else: dim = _parse_arg(dim, 'i') @@ -82,7 +77,7 @@ def argmax(g, input, dim, keepdim): def argmin(g, input, dim, keepdim): if sym_help._is_none(dim): from torch.onnx.symbolic_opset9 import reshape - flattened = reshape(g, input, (-1,)) + flattened = reshape(g, input, g.op("Constant", value_t=torch.tensor([-1]))) return g.op('ArgMin', flattened, axis_i=0, keepdims_i=False, select_last_index_i=False) else: dim = _parse_arg(dim, 'i') @@ -98,3 +93,63 @@ def ge(g, input, other): def le(g, input, other): return g.op('LessOrEqual', input, other) + +@parse_args('v', 'i', 'v', 'v') +def unfold(g, input, dimension, size, step): + size = sym_help._maybe_get_const(size, 'i') + step = sym_help._maybe_get_const(step, 'i') + if not sym_help._is_value(size) and not sym_help._is_value(step): + from torch.onnx.symbolic_opset9 import unfold as _unfold + return _unfold(g, input, dimension, size, step) + if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: + return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step) + + sizedim = sym_help._get_tensor_dim_size(input, dimension) + if sizedim is not None: + low_start = g.op("Constant", value_t=torch.tensor(0)) + low_end = g.op("Constant", value_t=torch.tensor(sizedim)) + hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1)) + low_indices = g.op("Range", low_start, low_end, step) + hi_indices = g.op("Range", size, hi_end, step) + + low_size = sym_help._size_helper(g, low_indices, g.op("Constant", value_t=torch.tensor(0))) + hi_size = sym_help._size_helper(g, hi_indices, g.op("Constant", value_t=torch.tensor(0))) + + ndim = sym_help._get_tensor_rank(input) + perm = list(range(0, ndim)) + perm.append(perm.pop(dimension)) + + unsqueeze_list = [] + loop_condition = g.op("Constant", value_t=torch.tensor(1)) + loop_condition = g.op("Cast", loop_condition, to_i=9) + loop_len = g.op("Min", low_size, hi_size) + loop = g.op("Loop", loop_len, loop_condition) + + loop_block = _add_block(loop.node()) + block_input_iter = _add_input_to_block(loop_block) + cond = _add_input_to_block(loop_block) + + starts = loop_block.op("Gather", low_indices, block_input_iter) + ends = loop_block.op("Gather", hi_indices, block_input_iter) + axes = loop_block.op("Constant", value_t=torch.tensor([2])) + starts = loop_block.op("Unsqueeze", starts, axes_i=[0]) + ends = loop_block.op("Unsqueeze", ends, axes_i=[0]) + stack = loop_block.op("Slice", input, starts, ends, axes) + + unsqueeze = loop_block.op("Unsqueeze", loop_block.op("Transpose", stack, perm_i=perm), axes_i=[dimension]) + unsqueeze_list.append(unsqueeze) + concat = loop_block.op("Concat", *unsqueeze_list, axis_i=0) + + cond_out = loop_block.op("Cast", loop_condition, to_i=9) + _add_output_to_block(loop_block, cond_out) + _add_output_to_block(loop_block, concat) + + loop_output = loop.node().output() + perm = [0, 1, 2, 3, 4] + perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0] + transpose = g.op("Transpose", loop_output, perm_i=perm) + squeeze = g.op("Squeeze", transpose, axes_i=[0]) + + return squeeze + else: + return _unimplemented("Unfold", "input size not accessible") diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py new file mode 100644 index 0000000000000..001a20147c4f5 --- /dev/null +++ b/torch/onnx/symbolic_opset13.py @@ -0,0 +1,110 @@ +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in symbolic_helper.py + +# This file exports ONNX ops for opset 13 +from torch.onnx.symbolic_helper import _block_list_in_opset +import torch +import torch.onnx.symbolic_helper as sym_help +from torch.onnx.symbolic_helper import parse_args + +block_listed_operators = ['embedding_bag'] + +for block_listed_op in block_listed_operators: + vars()[block_listed_op] = _block_list_in_opset(block_listed_op) + + +@parse_args('v', 'i', 'none') +def softmax(g, input, dim, dtype=None): + softmax = g.op('Softmax', input, axis_i=dim) + if dtype and dtype.node().kind() != 'prim::Constant': + parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype') + softmax = g.op("Cast", softmax, to_i=sym_help.scalar_type_to_onnx[parsed_dtype]) + + return softmax + + +@parse_args('v', 'i', 'none') +def log_softmax(g, input, dim, dtype=None): + return_op = g.op("LogSoftmax", input, axis_i=dim) + if dtype and dtype.node().kind() != 'prim::Constant': + parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype') + return_op = g.op("Cast", return_op, to_i=sym_help.scalar_type_to_onnx[parsed_dtype]) + return return_op + + +@parse_args('v', 'v', 'i') +def frobenius_norm(g, self, dim=None, keepdim=False): + dim_val = sym_help._maybe_get_const(dim, 'is') + if not sym_help._is_value(dim_val) and len(dim_val) == 0: + return g.op("ReduceL2", self, keepdims_i=0) + sqr = g.op('Mul', self, self) + sumsqr = g.op('ReduceSum', sqr, dim, keepdims_i=keepdim) + return g.op('Sqrt', sumsqr) + + +@parse_args('v', 'v', 'i', 'i') +def split(g, self, split_size_or_sizes, dim, _outputs=None): + if not sym_help._is_split_static(split_size_or_sizes, _outputs): + split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim) + if _outputs is None: + return split_out + # Convert to multiple slice nodes iff number of splits and number of outputs are statically known. + if sym_help._is_packed_list(split_size_or_sizes) and \ + len(sym_help._unpack_list(split_size_or_sizes)) == _outputs: + split_sizes = [g.op("Unsqueeze", v, g.op("Constant", value_t=torch.tensor([0]))) + for v in sym_help._unpack_list(split_size_or_sizes)] + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + res = [] + for i in range(_outputs): + end = g.op("Add", start, split_sizes[i]) # split_sizes is a list of same length as _outputs + res.append(g.op("Slice", self, start, end, axis)) + start = end + return res + return [g.op("SequenceAt", split_out, g.op("Constant", value_t=torch.tensor([i], dtype=torch.long))) + for i in range(_outputs)] + + split_val = split_size_or_sizes.node()['value'] + if split_val.dim() > 0: + return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs) + split_size = sym_help._get_const(split_size_or_sizes, 'i', 'split_size') + + size = self.type().sizes()[dim] + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + splits = g.op("Constant", value_t=torch.tensor(splits)) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + + +def split_with_sizes(g, self, split_sizes, dim, _outputs=None): + return split(g, self, split_sizes, dim, _outputs) + + +def unsafe_split(g, self, split_size_or_sizes, dim, _outputs=None): + return split(g, self, split_size_or_sizes, dim, _outputs) + + +def unsafe_split_with_sizes(g, self, split_sizes, dim, _outputs=None): + return split_with_sizes(g, self, split_sizes, dim, _outputs) + + +@parse_args('v', 'i', 'i') +def unbind(g, self, dim=0, _outputs=None): + if _outputs is None: + return g.op("SplitToSequence", + self, + g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), + axis_i=dim, keepdims_i=0) + + splits = g.op("Constant", value_t=torch.tensor([1] * _outputs)) + outputs = g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + outputs = [outputs] if _outputs == 1 else outputs + squeezed_outputs = [g.op("Squeeze", out, g.op("Constant", value_t=torch.tensor([dim]))) for out in outputs] + return squeezed_outputs + + +def glu(g, input, dim): + first, second = g.op('Split', input, dim, outputs=2) + return g.op('Mul', first, g.op('Sigmoid', second)) diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py index 0a1476d881111..1fa9fa5e985b9 100644 --- a/torch/onnx/symbolic_opset8.py +++ b/torch/onnx/symbolic_opset8.py @@ -4,7 +4,7 @@ import torch.onnx.symbolic_opset9 as sym_opset9 from torch.onnx.symbolic_helper import parse_args, _unimplemented, _block_list_in_opset, _try_get_scalar_type -from torch.onnx.symbolic_opset9 import _cast_Float +from torch.onnx.symbolic_opset9 import _cast_Float # type: ignore import warnings @@ -148,10 +148,9 @@ def matmul(g, self, other): def prelu(g, self, weight): - if self.isCompleteTensor(): - self_sizes = self.type().sizes() - if self_sizes and len(self_sizes) > 2: - weight = g.op("Unsqueeze", weight, axes_i=list(range(1, len(self_sizes) - 1))) + self_rank = sym_help._get_tensor_rank(self) + if self_rank is not None and self_rank > 2: + weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1))) if _try_get_scalar_type(self): old_type, self, weight = _try_cast_integer_to_float(g, self, weight) return _cast_to_type(g, g.op("PRelu", self, weight), old_type) @@ -182,20 +181,6 @@ def addmm(g, self, mat1, mat2, beta, alpha): return g.op("Gemm", mat1, mat2, self, beta_f=sym_help._scalar(beta), alpha_f=sym_help._scalar(alpha)) -def view(g, self, size): - size = sym_help._maybe_get_const(size, 'is') - if sym_help._is_value(size): - shape = size - else: - if self.isCompleteTensor(): - self_sizes = self.type().sizes() - if self_sizes and len(size) == 2 and self_sizes[0] == size[0]: - old_type, self = _try_cast_integer_to_float(g, self) - return _cast_to_type(g, g.op("Flatten", self, axis_i=1), old_type) - shape = g.op("Constant", value_t=torch.LongTensor(size)) - return g.op("Reshape", self, shape) - - def flatten(g, input, start_dim, end_dim): start_dim_i = sym_help._get_const(start_dim, 'i', 'start_dim') end_dim_i = sym_help._get_const(end_dim, 'i', 'end_dim') @@ -281,7 +266,7 @@ def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False, mem def repeat(g, self, repeats): if not sym_help._is_value(repeats): repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) - if sym_help._is_packed_list(repeats): + if sym_help._is_packed_list(repeats): repeat_size_len = len(sym_help._unpack_list(repeats)) else: const_repeats = sym_help._maybe_get_const(repeats, 'is') @@ -290,5 +275,5 @@ def repeat(g, self, repeats): sizes = self.type().sizes() diff_dims = repeat_size_len - len(sizes) if diff_dims > 0: - self = sym_opset9.view(g, self, [1] * diff_dims + sizes) + self = sym_opset9.view(g, self, g.op("Constant", value_t=torch.tensor([1] * diff_dims + sizes))) return g.op("Tile", self, repeats) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 0acb254327e35..ada731884f767 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -13,6 +13,8 @@ import torch.onnx.symbolic_helper as sym_help from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented +from typing import Optional + import numpy import math import warnings @@ -103,11 +105,11 @@ def mul(g, self, other): def div(g, self, other): - return g.op("Div", self, other) + return true_divide(g, self, other) def floor_divide(g, self, other): - out = div(g, self, other) + out = g.op('Div', self, other) # the correct operation is truncate, which is not supported in ONNX, # we cannot call floor since it will behave differently for negative numbers # (eg. -0.1 should become -0 ) @@ -121,6 +123,7 @@ def floor_divide(g, self, other): # - self is not fp and other is not fp, the output's type is self's output type # - the output type defaults to Float scalar_type = self.type().scalarType() + if scalar_type is not None: if not sym_help._is_fp(self) and \ other.type().scalarType() is not None and \ @@ -144,19 +147,19 @@ def true_divide(g, self, other): # Case 1: both values are floating # Performs div as usual if sym_help._is_fp(self) and sym_help._is_fp(other): - return div(g, self, other) + return g.op("Div", self, other) # Case 2: self is floating, other is not # Casts other to self's dtype if sym_help._is_fp(self): other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx[self.type().scalarType()]) - return div(g, self, other) + return g.op("Div", self, other) # Case 3: other is floating, self is not # Casts self to other's dtype if sym_help._is_fp(other): self = g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[other.type().scalarType()]) - return div(g, self, other) + return g.op("Div", self, other) # Case 4: neither is floating # Casts both inputs to the default scalar type @@ -168,7 +171,7 @@ def true_divide(g, self, other): self = g.op("Cast", self, to_i=onnx_scalar_type) other = g.op("Cast", other, to_i=onnx_scalar_type) - return div(g, self, other) + return g.op("Div", self, other) def reciprocal(g, self): @@ -186,9 +189,11 @@ def stack(g, tensor_list, dim): unsqueezed = [g.op("Unsqueeze", t, axes_i=[dim]) for t in sym_help._unpack_list(tensor_list)] return g.op("Concat", *unsqueezed, axis_i=dim) + def _list(g, self): return self + def mm(g, self, other): # Create a dummy C tensor. Only needed for API purposes, the value is # since beta = 0 @@ -206,6 +211,44 @@ def matmul(g, self, other): @parse_args('v', 'v', 'v', 't', 't') def addmm(g, self, mat1, mat2, beta, alpha): + dtype = None + self_dtype = sym_help._try_get_scalar_type(self) + mat1_dtype = sym_help._try_get_scalar_type(mat1) + mat2_dtype = sym_help._try_get_scalar_type(mat2) + if self_dtype is not None: + dtype = self_dtype + elif mat1_dtype is not None: + dtype = mat1_dtype + elif mat2_dtype is not None: + dtype = mat2_dtype + + mat1_rank = sym_help._get_tensor_rank(mat1) + mat2_rank = sym_help._get_tensor_rank(mat2) + + def isNotNoneAnd(v, u): + return v is not None and v != u + + if dtype is not None and (isNotNoneAnd(mat1_rank, 2) or isNotNoneAnd(mat2_rank, 2)): + dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype]) + dtype = sym_help.scalar_type_to_pytorch_type[dtype] + + res1 = g.op("MatMul", mat1, mat2) + res2 = self + + alpha = sym_help._scalar(alpha) + beta = sym_help._scalar(beta) + + if alpha != 1: + alpha = g.op("Constant", + value_t=torch.tensor(alpha, dtype=dtype)) + res1 = g.op("Mul", res1, alpha) + if beta != 1: + beta = g.op("Constant", + value_t=torch.tensor(sym_help._scalar(beta), dtype=dtype)) + res2 = g.op("Mul", res2, beta) + + return g.op("Add", res1, res2) + return g.op("Gemm", mat1, mat2, self, beta_f=sym_help._scalar(beta), alpha_f=sym_help._scalar(alpha)) @@ -218,7 +261,7 @@ def sqrt(g, self): def rsqrt(g, self): - return div(g, sym_help._if_scalar_type_as(g, torch.ones(1), self), sqrt(g, self)) + return g.op("Div", sym_help._if_scalar_type_as(g, torch.ones(1), self), sqrt(g, self)) def tanh(g, self): @@ -270,7 +313,7 @@ def _maybe_cast_reduce_op_input(g, self): if dtype is not None: # pytorch reduce-ops cast all other integral types to int64 if not sym_help._is_fp(self) and not (dtype == 'Long'): - self = _cast_Long(g, self, False) + self = _cast_Long(g, self, False) # type: ignore return self @@ -372,7 +415,7 @@ def expand(g, self, size, implicit): # Expand with -1 dim value means dim is unchanged. # Since onnx::expand supports two-way broadcasting, # -1 dim value can be exported to onnx as 1 - size = view(g, stack(g, size, 0), [-1]) + size = view(g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1]))) dtype = 4 # dim type is int64 ones = ones_like(g, size, dtype) neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) @@ -420,8 +463,8 @@ def size(g, self, dim=None): if dim is None: return g.op("Shape", self) if sym_help._maybe_get_const(dim, 'i') < 0: - rank = self.type().dim() - if rank: + rank = sym_help._get_tensor_rank(self) + if rank is not None: dim = sym_help._maybe_get_const(dim, 'i') + rank dim = g.op("Constant", value_t=torch.tensor(dim)) return sym_help._size_helper(g, self, dim) @@ -433,8 +476,9 @@ def transpose(g, self, dim0, dim1): return self # NB: Transpose in ONNX is actually a Permute - if self.isCompleteTensor(): - axes = list(range(self.type().dim())) + rank = sym_help._get_tensor_rank(self) + if rank is not None: + axes = list(range(rank)) axes[dim0], axes[dim1] = axes[dim1], axes[dim0] return g.op("Transpose", self, perm_i=axes) else: @@ -459,20 +503,19 @@ def view(g, self, size): if sym_help._is_value(size): shape = size else: - if self.isCompleteTensor(): - self_sizes = self.type().sizes() - if self_sizes and len(size) == 2 and self_sizes[0] == size[0]: - return g.op("Flatten", self, axis_i=1) shape = g.op("Constant", value_t=torch.LongTensor(size)) return g.op("Reshape", self, shape) + def view_as(g, self, other): shape = g.op("Shape", other) return g.op("Reshape", self, shape) def prim_ConstantSplit(g, self, split_size, dim): - size = self.type().sizes()[dim] + size = sym_help._get_tensor_dim_size(self, dim) + if size is None: + return _unimplemented('prim::ConstantSplit', 'unknown dimension size') splits = [split_size] * (size // split_size) leftover = size % split_size if leftover: @@ -485,7 +528,10 @@ def prim_ConstantSplit(g, self, split_size, dim): # TODO: Once we have proper scoping, stop reimplementing chunk, delete this # method, and use the desugared version def prim_ConstantChunk(g, self, chunks, dim): - split_size = (self.type().sizes()[dim] + chunks - 1) // chunks + dim_size = sym_help._get_tensor_dim_size(self, dim) + if dim_size is None: + return _unimplemented('prim::ConstantChunk', 'unknown dimension size') + split_size = (dim_size + chunks - 1) // chunks return prim_ConstantSplit(g, self, split_size, dim) @@ -493,8 +539,10 @@ def prim_ConstantChunk(g, self, chunks, dim): def unsafe_chunk(g, self, chunks, dim, _outputs=None): if _outputs is None: return sym_help._onnx_opset_unsupported_detailed('unsafe_chunk', 9, 11, 'Dynamic number of outputs not supported') - split_size = (self.type().sizes()[dim] + chunks - 1) // chunks - size = self.type().sizes()[dim] + size = sym_help._get_tensor_dim_size(self, dim) + if size is None: + return _unimplemented('unsafe_chunk', 'unknown dimension size') + split_size = (size + chunks - 1) // chunks splits = [split_size] * (size // split_size) leftover = size % split_size if leftover: @@ -512,7 +560,9 @@ def split(g, self, split_size_or_sizes, dim, _outputs=None): split_size = sym_help._get_const(split_size_or_sizes, 'i', 'split_size') dim = sym_help._get_const(dim, 'i', 'dim') - size = self.type().sizes()[dim] + size = sym_help._get_tensor_dim_size(self, dim) + if size is None: + return sym_help._onnx_opset_unsupported_detailed('split', 9, 11, 'Unknown dimension size not supported') splits = [split_size] * (size // split_size) leftover = size % split_size if leftover: @@ -560,6 +610,10 @@ def select(g, self, dim, index): return g.op("Gather", self, index, axis_i=dim) +def square(g, self): + return g.op("Mul", self, self) + + def squeeze(g, self, dim=None): if dim is None: return g.op("Squeeze", self) @@ -567,8 +621,8 @@ def squeeze(g, self, dim=None): squeeze_dim = sym_help._get_const(dim, 'i', 'dim') # Handle negative dims if squeeze_dim < 0: - rank = self.type().dim() - if rank: + rank = sym_help._get_tensor_rank(self) + if rank is not None: warnings.warn("ONNX export squeeze with negative axis " + str(squeeze_dim) + " might cause the onnx model to be incorrect. " + "Negative axis is not supported in ONNX. " + @@ -579,17 +633,17 @@ def squeeze(g, self, dim=None): else: return _unimplemented('squeeze', 'negative axis with unknown input rank') - input_shape = self.type().sizes() - if input_shape is None: + dim_size = sym_help._get_tensor_dim_size(self, squeeze_dim) + if dim_size is None: warnings.warn("This model contains a squeeze operation on dimension " + str(squeeze_dim) + " on an input " + "with unknown shape. Note that if the size of dimension " + str(squeeze_dim) + " of the input " + "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on " + "non-singleton dimensions, it is recommended to export this model using opset " + "version 11 or higher.") return g.op("Squeeze", self, axes_i=[squeeze_dim]) - if input_shape[squeeze_dim] > 1: + if dim_size > 1: warnings.warn("This model contains a squeeze operation on dimension " + str(squeeze_dim) + ". The size of " + - "this dimension in the given input is " + str(input_shape[squeeze_dim]) + ". The model will " + + "this dimension in the given input is " + str(dim_size) + ". The model will " + "be exported without the squeeze node. If the model is intended to be used with dynamic " + "input shapes, please use opset version 11 to " + "export the model.") @@ -600,10 +654,9 @@ def squeeze(g, self, dim=None): return g.op("Squeeze", self, axes_i=[squeeze_dim]) def prelu(g, self, weight): - if self.isCompleteTensor(): - self_sizes = self.type().sizes() - if self_sizes and len(self_sizes) > 2: - weight = g.op("Unsqueeze", weight, axes_i=list(range(1, len(self_sizes) - 1))) + self_rank = sym_help._get_tensor_rank(self) + if self_rank is not None and self_rank > 2: + weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1))) return g.op("PRelu", self, weight) @@ -620,7 +673,8 @@ def floor(g, input): def _len(g, self): - return g.op("Size", self) + sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) + return g.op('Squeeze', sz_0, axes_i=[0]) @parse_args('v', 't', 't') @@ -642,7 +696,9 @@ def leaky_relu(g, input, negative_slope, inplace=False): @parse_args('v', 'i') def glu(g, input, dim): - assert input.type().sizes()[dim] % 2 == 0 + dim_size = sym_help._get_tensor_dim_size(input, dim) + if dim_size is not None: + assert dim_size % 2 == 0 first, second = g.op('Split', input, axis_i=dim, outputs=2) return g.op('Mul', first, g.op('Sigmoid', second)) @@ -670,7 +726,7 @@ def softmax(g, input, dim, dtype=None): # otherwise transpose the input to put the vectors to be normalized to the last dimension. # When input rank is not known at export time we compute softmax using a subgraph # with other operators - input_dim = input.type().dim() + input_dim = sym_help._get_tensor_rank(input) if input_dim is not None: # TODO: remove this as onnx opset 11 spec allows negative axes if dim < 0: @@ -712,7 +768,10 @@ def softplus(g, self, beta, threshold): def get_pool_ceil_padding(input, kernel_size, stride, padding): - dim = input.type().sizes()[-len(padding):] + sizes = sym_help._get_tensor_sizes(input) + dim = sizes[-len(padding):] if sizes is not None else None + if dim is None or any([i is None for i in dim]): + return _unimplemented(name, "input size not accessible") ceiled_output_dim = [int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i]))) + 1 for i in range(0, len(padding))] # ensure last pooling starts inside @@ -737,8 +796,6 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding): def _max_pool(name, tuple_fn, ndims, return_indices): @parse_args('v', 'is', 'is', 'is', 'is', 'i') def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode): - if ceil_mode and not input.isCompleteTensor(): - return _unimplemented(name, "input size not accessible") if set(tuple_fn(dilation)) != {1}: return _unimplemented(name, "dilation") if not stride: @@ -795,8 +852,6 @@ def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode): def _avg_pool(name, tuple_fn): @parse_args('v', 'is', 'is', 'is', 'i', 'i', 'none') def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override=None): - if ceil_mode and not input.isCompleteTensor(): - return _unimplemented(name, "input size not accessible") if not stride: stride = kernel_size padding = sym_help._avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, name) @@ -826,7 +881,6 @@ def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include def _adaptive_pool(name, type, tuple_fn, fn=None): - @parse_args('v', 'is') def symbolic_fn(g, input, output_size): # _adaptive_pool is supported for cases where output_size is 1 for all dimensions, # by executing a GlobalPool. @@ -837,19 +891,30 @@ def symbolic_fn(g, input, output_size): # so we try using max_poolxd_with_indices, and if it is not possible # (input is not a complete tensor or output size not factor of input size) # then we call GlobalAveragePool and return None for the indices + try: + output_size = _parse_arg(output_size, 'is') + except Exception: + return sym_help._onnx_unsupported('adaptive pooling, since output_size is not constant.') if output_size == [1] * len(output_size) and type == "AveragePool": return g.op("GlobalAveragePool", input) - if not input.isCompleteTensor(): + sizes = sym_help._get_tensor_sizes(input) + try: + dim = sizes[2:] + except Exception: + dim = None + if dim is None or any([i is None for i in dim]): if output_size == [1] * len(output_size): return g.op("GlobalMaxPool", input), None return _unimplemented(name, 'input size not accessible') - dim = input.type().sizes()[2:] # verify if output size % input size = 0 for all dim mod = [dim[i] % output_size[i] for i in range(0, len(dim))] if mod != [0] * len(mod): if output_size == [1] * len(output_size): return g.op("GlobalMaxPool", input), None - return _unimplemented(name, 'output size that are not factor of input size') + if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: + return _unimplemented(name, 'output size that are not factor of input size') + else: + return sym_help._onnx_unsupported(name + ', since output size is not factor of input size') k = [int(dim[i] / output_size[i]) for i in range(0, len(dim))] # call max_poolxd_with_indices to get indices in the output if type == "MaxPool": @@ -871,7 +936,7 @@ def symbolic_fn(g, input, output_size): # Generate paddings in ONNX order based on pad in pytorch. -# Arguments: +# Args: # dim: the dimension of the tensor. # pad: the paddings in pytorch. # The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ... @@ -904,21 +969,21 @@ def constant_pad_nd(g, input, padding, value): return sym_help._onnx_opset_unsupported_detailed('Pad', 9, 11, 'The value for the padding must be constant') padding = _convert_padding_node(padding) - paddings = _prepare_onnx_paddings(input.type().dim(), padding) + paddings = _prepare_onnx_paddings(sym_help._get_tensor_rank(input), padding) return g.op("Pad", input, pads_i=paddings, mode_s=mode, value_f=value) def reflection_pad(g, input, padding): mode = "reflect" padding = _convert_padding_node(padding) - paddings = _prepare_onnx_paddings(input.type().dim(), padding) + paddings = _prepare_onnx_paddings(sym_help._get_tensor_rank(input), padding) return g.op("Pad", input, pads_i=paddings, mode_s=mode) def replication_pad(g, input, padding): mode = "edge" padding = _convert_padding_node(padding) - paddings = _prepare_onnx_paddings(input.type().dim(), padding) + paddings = _prepare_onnx_paddings(sym_help._get_tensor_rank(input), padding) return g.op("Pad", input, pads_i=paddings, mode_s=mode) @@ -1051,7 +1116,7 @@ def __rshift_(g, self, other): if not sym_help._is_fp(self): other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx['Float']) two_pow = g.op('Pow', two, other) - + two_pow = g.op('Cast', two_pow, to_i=sym_help.cast_pytorch_to_onnx[self.type().scalarType()]) rshift = g.op('Div', self, two_pow) return rshift @@ -1067,7 +1132,7 @@ def __lshift_(g, self, other): if not sym_help._is_fp(self): other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx['Float']) two_pow = g.op('Pow', two, other) - + two_pow = g.op('Cast', two_pow, to_i=sym_help.cast_pytorch_to_onnx[self.type().scalarType()]) lshift = g.op('Mul', self, two_pow) return lshift @@ -1079,7 +1144,7 @@ def where(g, condition, self=None, other=None, _outputs=None): condition = g.op("Cast", condition, to_i=sym_help.cast_pytorch_to_onnx['Bool']) if self is None: condition = torch.onnx.symbolic_opset9.nonzero(g, condition) - return unbind(g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs) + return sym_help._unbind_helper(g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs) return g.op("Where", condition, self, other) @@ -1088,7 +1153,7 @@ def log_softmax(g, input, dim, dtype=None): # PyTorch dim and ONNX axis have different meanings. # See Softmax comment for details. # TODO: remove this as onnx opset 11 spec allows negative axes - input_dim = input.type().dim() + input_dim = sym_help._get_tensor_rank(input) if input_dim is None: return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input. " @@ -1104,7 +1169,8 @@ def log_softmax(g, input, dim, dtype=None): dim = input_dim - 1 return_op = g.op("LogSoftmax", input, axis_i=dim) if dtype and dtype.node().kind() != 'prim::Constant': - return_op = g.op("Cast", return_op, to_i=sym_help.scalar_type_to_onnx[dtype]) + parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype') + return_op = g.op("Cast", return_op, to_i=sym_help.scalar_type_to_onnx[parsed_dtype]) if is_transpose_required: return_op = g.op("Transpose", return_op, perm_i=axes) return return_op @@ -1113,11 +1179,19 @@ def log_softmax(g, input, dim, dtype=None): @parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is', 'i', 'i', 'i', 'i', 'i') def _convolution(g, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32): - weight_size = weight.type().sizes() + weight_size = sym_help._get_tensor_sizes(weight) + try: + kernel_shape = weight_size[2:] + except Exception: + kernel_shape = None + + if kernel_shape is None or any([i is None for i in kernel_shape]): + raise RuntimeError('Unsupported: ONNX export of convolution for kernel ' + 'of unknown shape.') args = [input, weight] # ONNX only supports 1D bias - if not sym_help._is_none(bias) and bias.type().dim() == 1: + if not sym_help._is_none(bias) and sym_help._get_tensor_rank(bias) == 1: args.append(bias) kwargs = {"kernel_shape_i": weight_size[2:], @@ -1138,7 +1212,7 @@ def _convolution(g, input, weight, bias, stride, padding, dilation, n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs) - if not sym_help._is_none(bias) and bias.type().dim() != 1: + if not sym_help._is_none(bias) and sym_help._get_tensor_rank(bias) != 1: return g.op("Add", n, bias) else: return n @@ -1176,20 +1250,33 @@ def conv_transpose3d(g, input, weight, bias, stride, padding, output_padding, gr @parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i') def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled): - sym_help.assert_training_mode(training, "dropout") - input_sizes = input.type().sizes() + sym_help.assert_training_mode(training, "batch_norm") + batch_size = sym_help._get_tensor_dim_size(input, 0) + channel_size = sym_help._get_tensor_dim_size(input, 1) if weight is None or sym_help._is_none(weight): - assert len(input_sizes) > 1 - weight_value = torch.tensor([1.] * input_sizes[1]).type( + if channel_size is None: + raise RuntimeError('Unsupported: ONNX export of batch_norm for unknown ' + 'channel size.') + weight_value = torch.tensor([1.] * channel_size).type( 'torch.' + input.type().scalarType() + 'Tensor') weight = g.op("Constant", value_t=weight_value) if bias is None or sym_help._is_none(bias): - assert len(input_sizes) > 1 - bias_value = torch.tensor([0.] * input_sizes[1]).type( + if channel_size is None: + raise RuntimeError('Unsupported: ONNX export of batch_norm for unknown ' + 'channel size.') + bias_value = torch.tensor([0.] * channel_size).type( 'torch.' + input.type().scalarType() + 'Tensor') bias = g.op("Constant", value_t=bias_value) - + # If track_running_stats is set to False batch statistics are instead used during evaluation time + if running_mean is None or sym_help._is_none(running_mean) or running_var is None or sym_help._is_none(running_var): + assert batch_size is not None and channel_size is not None + reshape_in = g.op("Reshape", input, + g.op("Constant", value_t=torch.tensor([batch_size, channel_size, -1], dtype=torch.int64))) + trans_in = g.op('Transpose', reshape_in, perm_i=[0, 2, 1]) + running_var, running_mean = _var_mean(g, trans_in, + g.op("Constant", value_t=torch.tensor([0, 1], dtype=torch.int64)), + False, False) out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var, epsilon_f=eps, momentum_f=1 - momentum, @@ -1222,7 +1309,7 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes) denominator = sqrt(g, add(g, variance, eps_cst)) - layer_norm = div(g, numerator, denominator) + layer_norm = g.op("Div", numerator, denominator) if not (weight is None or sym_help._is_none(weight)): layer_norm = mul(g, layer_norm, weight) @@ -1234,15 +1321,19 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): @parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i') def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled): - input_sizes = input.type().sizes() + channel_size = sym_help._get_tensor_dim_size(input, 1) if weight is None or sym_help._is_none(weight): - assert len(input_sizes) > 1 - weight_value = torch.tensor([1.] * input_sizes[1]).type( + if channel_size is None: + raise RuntimeError('Unsupported: ONNX export of instance_norm for unknown ' + 'channel size.') + weight_value = torch.tensor([1.] * channel_size).type( 'torch.' + input.type().scalarType() + 'Tensor') weight = g.op("Constant", value_t=weight_value) if bias is None or sym_help._is_none(bias): - assert len(input_sizes) > 1 - bias_value = torch.tensor([0.] * input_sizes[1]).type( + if channel_size is None: + raise RuntimeError('Unsupported: ONNX export of instance_norm for unknown ' + 'channel size.') + bias_value = torch.tensor([0.] * channel_size).type( 'torch.' + input.type().scalarType() + 'Tensor') bias = g.op("Constant", value_t=bias_value) return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps) @@ -1252,13 +1343,17 @@ def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_s def unfold(g, input, dimension, size, step): if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step) - if input.isCompleteTensor(): - sizedim = input.type().sizes()[dimension] + sizes = sym_help._get_tensor_sizes(input) + try: + sizedim = sizes[dimension] + except Exception: + sizedim = None + if sizedim is not None: low_indices = range(0, sizedim, step) hi_indices = range(size, sizedim + 1, step) stack = [sym_help._slice_helper(g, input, axes=[dimension], starts=[low], ends=[hi]) for low, hi in zip(low_indices, hi_indices)] - ndim = input.type().dim() + ndim = len(sizes) perm = list(range(0, ndim)) perm.append(perm.pop(dimension)) unsqueeze = [g.op("Unsqueeze", g.op("Transpose", t, perm_i=perm), axes_i=[dimension]) for t in stack] @@ -1286,17 +1381,7 @@ def index_select(g, self, dim, index): # In case of a scalar index, index_select returns a tensor with the same rank as the input. # To match this behavior in ONNX, we make index a 1D tensor so that the following gather # also produces a tensor with the same rank as the input. - - index_const = sym_help._maybe_get_scalar(index) - index_dim = index.type().dim() - if not sym_help._is_value(index_const): - # Index is a constant scalar. Make it a size 1 constant tensor. - index = g.op("Constant", value_t=torch.LongTensor([index_const])) - elif index_dim is not None: - if index_dim == 0: - # Index is a scalar. Reshape it to a size 1 tensor. - index = g.op("Reshape", index, g.op("Constant", value_t=torch.LongTensor([1]))) - return g.op("Gather", self, index, axis_i=dim) + return sym_help._select_helper(g, self, dim, index) def index_put(g, self, indices_list_value, values, accumulate): @@ -1329,11 +1414,12 @@ def index_copy(g, self, dim, index, source): def type_as(g, self, other): - if self.isCompleteTensor() and other.isCompleteTensor() and self.type().scalarType() == other.type().scalarType(): + self_dtype = sym_help._try_get_scalar_type(self) + other_dtype = sym_help._try_get_scalar_type(other) + if self_dtype == other_dtype and self_dtype is not None: return self - if other.isCompleteTensor(): - other_type_name = other.type().scalarType() - return g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[other_type_name]) + if other_dtype is not None: + return g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[other_dtype]) else: if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: # We don't know the type of other, bail by emitting ATen @@ -1529,8 +1615,9 @@ def empty_like(g, input, dtype=None, layout=None, device=None, pin_memory=False, def new_empty(g, self, sizes, dtype, layout, device, pin_memory=False): - if dtype is None and self.isCompleteTensor(): - dtype = self.type().scalarType() + self_dtype = sym_help._try_get_scalar_type(self) + if dtype is None and self_dtype is not None: + dtype = self_dtype dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype]) return empty(g, sizes, dtype, layout, device, pin_memory) @@ -1558,7 +1645,7 @@ def tensor(g, data, dtype=None, device=None, requires_grad=False): return g.op("Concat", *input_list, axis_i=0) else: if dtype is None: - dtype = sym_help._maybe_get_const(data, 't').type().scalarType() + dtype = data.type().scalarType() dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype]) return g.op("Cast", data, to_i=sym_help.scalar_type_to_onnx[dtype]) @@ -1582,8 +1669,9 @@ def zeros_like(g, input, dtype=None, layout=None, device=None, pin_memory=False, def new_zeros(g, self, sizes, dtype, layout, device, pin_memory=False): - if dtype is None and self.isCompleteTensor(): - dtype = self.type().scalarType() + self_dtype = sym_help._try_get_scalar_type(self) + if dtype is None and self_dtype is not None: + dtype = self_dtype dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype]) return zeros(g, sizes, dtype, layout, device, pin_memory) @@ -1633,16 +1721,29 @@ def full_like(g, input, fill_value, dtype=None, layout=None, device=None, pin_me def new_full(g, self, size, fill_value, dtype, layout, device, pin_memory=False): - if dtype is None and self.isCompleteTensor(): - dtype = self.type().scalarType() + self_dtype = sym_help._try_get_scalar_type(self) + if dtype is None and self_dtype is not None: + dtype = self_dtype dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype]) return full(g, size, fill_value, dtype, layout, device, pin_memory) -def eye(g, n, m, dtype=None, layout=None, device=None, pin_memory=False): - shape = g.op("Concat", g.op("Unsqueeze", n, axes_i=[0]), g.op("Unsqueeze", m, axes_i=[0]), axis_i=0) - tensor = zeros(g, shape, dtype, layout, device) - return g.op("EyeLike", tensor) +def eye(g, *args): + if len(args) == 5: + # aten::eye(n, dtype, layout, device, pin_memory) + n, dtype, layout, device, pin_memory = args + dim_size = g.op("Unsqueeze", n, axes_i=[0]) + shape = g.op("Concat", dim_size, dim_size, axis_i=0) + tensor = zeros(g, shape, dtype, layout, device) + return g.op("EyeLike", tensor) + elif len(args) == 6: + # aten::eye(n, m, dtype, layout, device, pin_memory) + n, m, dtype, layout, device, pin_memory = args + shape = g.op("Concat", g.op("Unsqueeze", n, axes_i=[0]), g.op("Unsqueeze", m, axes_i=[0]), axis_i=0) + tensor = zeros(g, shape, dtype, layout, device) + return g.op("EyeLike", tensor) + else: + raise NotImplementedError("Unknown aten::eye signature") def slice(g, self, *args): @@ -1684,6 +1785,15 @@ def hardtanh(g, self, min_val, max_val): return g.op("Clip", self, min_f=min_val, max_f=max_val) +@parse_args('v') +def hardswish(g, self): + input = g.op("Add", self, g.op('Constant', value_t=torch.tensor(3, dtype=torch.float))) + hardtanh_ = sym_help._hardtanh_helper(g, input, + g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)), + g.op('Constant', value_t=torch.tensor(6, dtype=torch.float))) + hardtanh_ = g.op("Div", hardtanh_, g.op('Constant', value_t=torch.tensor(6, dtype=torch.float))) + return g.op("Mul", self, hardtanh_) + def alias(g, self): return self @@ -1692,8 +1802,8 @@ def alias(g, self): def unsqueeze(g, self, dim): # Handle negative dim if dim < 0: - rank = self.type().dim() - if rank: + rank = sym_help._get_tensor_rank(self) + if rank is not None: warnings.warn("ONNX export unsqueeze with negative axis " + str(dim) + " might cause the onnx model to be incorrect. " + "Negative axis is not supported in ONNX. " + @@ -1711,10 +1821,16 @@ def unsqueeze(g, self, dim): def sort(g, self, dim, decending, out=None): if out is not None: _unimplemented("Sort", "Out parameter is not supported for sort") - if not self.isCompleteTensor(): + self_sizes = sym_help._get_tensor_sizes(self) + try: + dim_size = self_sizes[dim] + except Exception: + dim_size = None + + if dim_size is None: return _unimplemented("Sort", "input size not accessible") - return g.op("TopK", self, k_i=self.type().sizes()[dim], axis_i=dim, outputs=2) + return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2) def numel(g, self): @@ -1777,16 +1893,18 @@ def repeat(g, self, repeats): @parse_args('v', 'i') def pixel_shuffle(g, self, upscale_factor): - dims = self.type().sizes() + dims = sym_help._get_tensor_sizes(self) if len(dims) != 4: return _unimplemented("pixel_shuffle", "only support 4d input") + if any([i is None for i in dims[1:]]): + return _unimplemented("pixel_shuffle", "only support static input shape, except for batch size") output_channel = dims[1] // upscale_factor // upscale_factor - after_view = view(g, self, [-1, output_channel, upscale_factor, upscale_factor, - dims[2], dims[3]]) + after_view = view(g, self, g.op("Constant", value_t=torch.tensor([-1, output_channel, upscale_factor, + upscale_factor, dims[2], dims[3]]))) after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) return view(g, after_transpose, - [-1, output_channel, dims[2] * upscale_factor, dims[3] * - upscale_factor]) + g.op("Constant", value_t=torch.tensor([-1, output_channel, dims[2] * upscale_factor, + dims[3] * upscale_factor]))) def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases, @@ -1802,6 +1920,9 @@ def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases, 'ScaledTanh', 'HardSigmoid', 'Elu', 'Softsign', 'Softplus'] variantToOnnxActivationMap = dict(zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations)) weights_per_layer = 4 if has_biases else 2 + # this means that projections are used inside LSTM, so need to tell user that it's not supported + if variant == 'LSTM' and len(all_weights) != num_layers * weights_per_layer * (1 + bidirectional): + return _unimplemented("LSTM", "LSTMs with projections") assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional) layer_weights = [all_weights[i:i + weights_per_layer] for i in range(0, len(all_weights), weights_per_layer)] if batch_first: @@ -1815,7 +1936,9 @@ def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases, variant = 'RNN' w_hh = all_weights[1] - hidden_size = w_hh.type().sizes()[1] + hidden_size = sym_help._get_tensor_dim_size(w_hh, 1) + if hidden_size is None: + return _unimplemented("RNN/GRU/LSTM", "unknown hidden size") unidirectional = not bidirectional @@ -2027,7 +2150,7 @@ def _pack_padded_sequence(g, input, lengths, batch_first): # It's really only necessary because those operators expand to something that # only works with int32 types in Caffe2... if lengths.type().scalarType() != 'Int': - lengths = _cast_Int(g, lengths, False) + lengths = _cast_Int(g, lengths, False) # type: ignore return g.op("prim::PackPadded", input, lengths, outputs=2) @@ -2046,11 +2169,11 @@ def randn(g, shapes, dtype, *options): dtype = sym_help._get_const(dtype, 'i', 'dtype') if dtype is None: dtype = 6 # float - if sym_help._is_packed_list(shapes): + shape = sym_help._maybe_get_const(shapes, "is") + if sym_help._is_value(shape): shape_const = g.op("ConstantOfShape", shapes, value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[6])) return g.op('RandomNormalLike', shape_const, dtype_i=sym_help.scalar_type_to_onnx[dtype]) - shape = sym_help._get_const(shapes, "is", "randn") return g.op('RandomNormal', shape_i=shape) @@ -2058,11 +2181,11 @@ def rand(g, shapes, dtype, *options): dtype = sym_help._get_const(dtype, 'i', 'dtype') if dtype is None: dtype = 6 # float - if sym_help._is_packed_list(shapes): + shape = sym_help._maybe_get_const(shapes, "is") + if sym_help._is_value(shape): shape_const = g.op("ConstantOfShape", shapes, value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[6])) return g.op('RandomUniformLike', shape_const, dtype_i=sym_help.scalar_type_to_onnx[dtype]) - shape = sym_help._get_const(shapes, "is", "rand") return g.op('RandomUniform', shape_i=shape) @@ -2099,7 +2222,7 @@ def erf(g, input): @parse_args('v', 'i', 'i') def flatten(g, input, start_dim, end_dim): - dim = input.type().dim() + dim = sym_help._get_tensor_rank(input) if dim is None: return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input. " @@ -2116,11 +2239,17 @@ def flatten(g, input, start_dim, end_dim): return sym_help._flatten_helper(g, input, start_dim, end_dim, dim) +# Emitted from `torch.nonzero(x, as_tuple=False)` @parse_args('v') def nonzero(g, input): return t(g, g.op('NonZero', input)) +# Emitted from `torch.nonzero(x, as_tuple=True)` +def nonzero_numpy(g, input, _outputs=None): + return unbind(g, nonzero(g, input), 1, _outputs=_outputs) + + @parse_args('v') def isnan(g, input): output = g.op('IsNaN', input) @@ -2134,7 +2263,7 @@ def narrow(g, input, dim, start, length): def argmax(g, input, dim, keepdim): if sym_help._is_none(dim): - flattened = reshape(g, input, (-1,)) + flattened = reshape(g, input, g.op("Constant", value_t=torch.tensor([-1]))) return g.op('ArgMax', flattened, axis_i=0, keepdims_i=False) else: dim = _parse_arg(dim, 'i') @@ -2144,7 +2273,7 @@ def argmax(g, input, dim, keepdim): def argmin(g, input, dim, keepdim): if sym_help._is_none(dim): - flattened = reshape(g, input, (-1,)) + flattened = reshape(g, input, g.op("Constant", value_t=torch.tensor([-1]))) return g.op('ArgMin', flattened, axis_i=0, keepdims_i=False) else: dim = _parse_arg(dim, 'i') @@ -2168,13 +2297,16 @@ def scatter(g, self, dim, index, src): @parse_args('v', 'i', 'v', 'v') def scatter_add(g, self, dim, index, src): - if not self.isCompleteTensor(): - return _unimplemented("scatter_add", "input size not accessible") - dtype = self.type().scalarType() + dtype = sym_help._try_get_scalar_type(self) + if dtype is None: + return _unimplemented("scatter_add", "input dtype not accessible") dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype]) dtype = sym_help.scalar_type_to_pytorch_type[dtype] - sizes = self.type().sizes() - to_add = g.op("Constant", value_t=torch.zeros(sizes, dtype=dtype)) + sizes = sym_help._get_tensor_sizes(self, allow_nonstatic=False) + if sizes: + to_add = g.op("Constant", value_t=torch.zeros(sizes, dtype=dtype)) + else: + to_add = zeros_like(self, dtype) to_add = sym_help._scatter_helper(g, to_add, dim, index, src) return add(g, self, to_add) @@ -2187,6 +2319,51 @@ def log2(g, self): def prim_shape(g, self): return g.op('Shape', self) +def prim_max(g, self, other): + return g.op('Max', self, other) + +def prim_data(g, self): + return self + +def is_floating_point(g, self): + if sym_help._is_fp(self): + return g.op("Constant", value_t=torch.BoolTensor([1])) + return g.op("Constant", value_t=torch.BoolTensor([0])) + + +def __isnot_(g, self, other): + if sym_help._is_none(other): + if sym_help._is_none(self): + return g.op("Constant", value_t=torch.BoolTensor([0])) + return g.op("Constant", value_t=torch.BoolTensor([1])) + return ne(g, self, other) + + +# exists to refine the type of the Value +# if x is an optional Tensor, unchecked_cast will cast +# x to Tensor, so the rest of the graph knows that x is a Tensor +# this doesn't do anything in runtime and is a noop in ONNX +def prim_unchecked_cast(g, self): + return self + + +def prim_dtype(g, self): + dtype = sym_help._try_get_scalar_type(self) + if dtype is None: + dtype = "Float" + dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype]) + return g.op("Constant", value_t=torch.tensor(dtype)) + + +# tolist is currently supported only for 1D input tensors. +# dim_val and elem_ty_val represent dimension and type annotations +# that need to match dimension and type of the input tensor. +def prim_tolist(g, input, dim_val, elem_ty_val): + dim = sym_help._maybe_get_const(dim_val, 'i') + if dim > 1: + return _unimplemented("prim_tolist", "dim_val > 1") + return input + @parse_args('v', 'i') def one_hot(g, self, num_classes): @@ -2210,39 +2387,77 @@ def gather(g, self, dim, index, sparse_grad=False): @parse_args('v', 'is', 'b', 'i') -def _std(g, input, dim, unbiased, keepdim): - if input.isCompleteTensor(): - sqrd = g.op("Mul", input, input) - if dim is None: - sqrdmean = g.op("ReduceMean", sqrd, keepdims_i=0) - mean = g.op("ReduceMean", input, keepdims_i=0) - redudced_dims = input.type().sizes() - else: - sqrdmean = g.op("ReduceMean", sqrd, axes_i=dim, keepdims_i=keepdim) - mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim) - redudced_dims = [input.type().sizes()[i] for i in dim] - meansqrd = g.op("Mul", mean, mean) - var = g.op("Abs", g.op("Sub", sqrdmean, meansqrd)) - # This is to correct bias in calculating variance, by dividing it over (N - 1) instead on N - if unbiased: - count = numpy.prod(redudced_dims) - mul = g.op("Mul", var, g.op("Constant", value_t=torch.tensor(count, dtype=torch.float))) - var = g.op("Div", mul, g.op("Constant", value_t=torch.tensor(count - 1, dtype=torch.float))) - std = g.op("Sqrt", var) - return std +def _var_mean(g, input, dim, unbiased, keepdim): + if dim is None: + mean = g.op("ReduceMean", input, keepdims_i=0) + t_mean = mean + num_elements = numel(g, input) else: - _unimplemented("std", "Unknown input rank. Cannot compute std along dimensions.") + mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim) + t_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1) + redudced_dims = g.op("Shape", input) + # dim could contain one or multiple dimensions + redudced_dims = g.op("Gather", redudced_dims, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0) + num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0) + sub_v = g.op("Sub", input, t_mean) + sqr_sub = g.op("Mul", sub_v, sub_v) + keepdim_mean = 0 if dim is None else keepdim + var = g.op("ReduceMean", sqr_sub, axes_i=dim, keepdims_i=keepdim_mean) + # Correct bias in calculating variance, by dividing it over (N - 1) instead on N + if unbiased: + num_elements = g.op("Cast", num_elements, to_i=sym_help.cast_pytorch_to_onnx['Float']) + one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.float)) + mul = g.op("Mul", var, num_elements) + var = g.op("Div", mul, g.op("Sub", num_elements, one)) + return var, mean # Since position of optional arguments can change for std, this is a hack to find if first argument # is 'dim' or 'unbiased'. As shown below, 'dim' argument could be listed before 'unbiased' : -# torch.std(input, unbiased=True) -# torch.std(input, dim, keepdim=False, unbiased=True) +# at::std(input, unbiased) +# at::std(input, dim, unbiased, keepdim) def std(g, input, *args): if len(args) == 3: - return _std(g, input, *args) + var, _ = _var_mean(g, input, *args) + else: + var, _ = _var_mean(g, input, None, args[0], None) + return g.op("Sqrt", var) + + +# Since position of optional arguments can change for var, this is a hack to find if first argument +# is 'dim' or 'unbiased'. As shown below, 'dim' argument could be listed before 'unbiased' : +# at::var(input, unbiased) +# at::var(input, dim, unbiased, keepdim) +def var(g, input, *args): + if len(args) == 3: + var, _ = _var_mean(g, input, *args) + else: + var, _ = _var_mean(g, input, None, args[0], None) + return var + + +# Since position of optional arguments can change for var_mean, this is a hack to find if first argument +# is 'dim' or 'unbiased'. As shown below, 'dim' argument could be listed before 'unbiased' : +# at::var_mean(input, unbiased) +# at::var_mean(input, dim, unbiased, keepdim) +def var_mean(g, input, *args): + if len(args) == 3: + var, mean = _var_mean(g, input, *args) + else: + var, mean = _var_mean(g, input, None, args[0], None) + return var, mean + + +# Since position of optional arguments can change for std_mean, this is a hack to find if first argument +# is 'dim' or 'unbiased'. As shown below, 'dim' argument could be listed before 'unbiased' : +# at::std_mean(input, unbiased) +# at::std_mean(input, dim, unbiased, keepdim) +def std_mean(g, input, *args): + if len(args) == 3: + var, mean = _var_mean(g, input, *args) else: - return _std(g, input, None, args[0], None) + var, mean = _var_mean(g, input, None, args[0], None) + return g.op("Sqrt", var), mean @parse_args('v', 'is', 'i') @@ -2305,7 +2520,7 @@ def _get_arange_dtype(dtype): def masked_fill(g, self, mask, value): - mask = _cast_Bool(g, mask, False) + mask = _cast_Bool(g, mask, False) # type: ignore value = sym_help._maybe_get_scalar(value) return g.op('Where', mask, sym_help._if_scalar_type_as(g, value, self), self) @@ -2331,7 +2546,7 @@ def try_mask_to_index(index): indices = [try_mask_to_index(idx) for idx in indices] if len(indices) == 1: - return index_select(g, self, 0, indices[0]) + return sym_help._select_helper(g, self, 0, indices[0], apply_reshape=False) else: # Multiple tensors as indices. Each tensor could either be # 1. prim::Constant() @@ -2358,7 +2573,7 @@ def try_mask_to_index(index): elif len(adv_idx_indices) == 1: return index_select(g, self, adv_idx_indices[0], indices[adv_idx_indices[0]]) else: - rank = self.type().dim() + rank = sym_help._get_tensor_rank(self) if rank is None: raise NotImplementedError("Unsupported aten::index operator of advanced indexing on tensor of unknown rank, " + "try turning on shape and type propagate during export: " + @@ -2370,7 +2585,6 @@ def try_mask_to_index(index): " is achieved by combination of multiple ONNX operators, " + "including Reshape, Transpose, Concat, and Gather. " + "If indices include negative values, the exported graph will produce incorrect results.") - rank = self.type().dim() adv_idx_count = len(adv_idx_indices) shape_tensor = _shape_as_tensor(g, self) dim_tensor_list = [ @@ -2451,7 +2665,7 @@ def baddbmm(g, self, batch1, batch2, beta, alpha): def meshgrid(g, tensor_list): - tensors = [view(g, t, torch.LongTensor([-1])) for t in sym_help._unpack_list(tensor_list)] + tensors = [view(g, t, g.op("Constant", value_t=torch.LongTensor([-1]))) for t in sym_help._unpack_list(tensor_list)] tensors_shape = [g.op("Shape", t) for t in tensors] out_shape = g.op("Concat", *tensors_shape, axis_i=0) out = [] @@ -2465,7 +2679,7 @@ def meshgrid(g, tensor_list): def remainder(g, input, other): div = g.op("Div", input, other) - if sym_help._is_fp(input): + if sym_help._is_fp(input) or sym_help._is_fp(other): div = g.op("Floor", div) quo = g.op("Mul", div, other) return g.op("Sub", input, quo) @@ -2473,7 +2687,7 @@ def remainder(g, input, other): def gelu(g, self): _sqrt2 = 1.4142135623730951 - erf = g.op('Erf', div(g, self, torch.tensor(_sqrt2))) + erf = g.op('Erf', g.op('Div', self, torch.tensor(_sqrt2))) erf_plusone = add(g, erf, g.op('Constant', value_t=torch.tensor(1, dtype=torch.float))) return mul(g, mul(g, self, erf_plusone), g.op('Constant', value_t=torch.tensor(0.5, dtype=torch.float))) @@ -2484,8 +2698,12 @@ def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled): return g.op("ATen", input, weight, bias, num_groups_i=num_groups, eps_f=eps, cudnn_enabled_i=cudnn_enabled, operator_s="group_norm") - input_sizes = input.type().sizes() - assert input_sizes[1] % num_groups == 0 + channel_size = sym_help._get_tensor_dim_size(input, 1) + if channel_size is not None: + assert channel_size % num_groups == 0 + input_rank = sym_help._get_tensor_rank(input) + if input_rank is None: + return _unimplemented("group_norm", "unknown input rank") # 0 in the shape list keeps dimension value unchanged. shape = [0, num_groups, -1] input_reshaped = g.op('Reshape', input, g.op('Constant', value_t=torch.LongTensor(shape))) @@ -2511,14 +2729,14 @@ def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled): bias = g.op("Constant", value_t=bias_value) # Norm has shape [N, C, *] so we reshape weight and bias to [C, *] - axes = list(range(1, len(input_sizes) - 1)) + axes = list(range(1, input_rank - 1)) return add(g, mul(g, norm, g.op("Unsqueeze", weight, axes_i=axes)), g.op("Unsqueeze", bias, axes_i=axes)) @parse_args('v', 'v', 'i') def _weight_norm(g, weight_v, weight_g, dim): - rank = weight_v.type().dim() - if rank: + rank = sym_help._get_tensor_rank(weight_v) + if rank is not None: # W = g * ((v) / ||v||) # Compute norm_except_dim for l2 norm. dim = None means over all dims # torch's weight_norm module sets dim = -1 if it's None. @@ -2598,6 +2816,7 @@ def as_strided(g, self, sizes, strides, offset=None): sizes = sym_help._maybe_get_const(sizes, 'is') rank = len(strides) self_1d = g.op("Reshape", self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))) + ind: Optional[torch.Tensor] if not sym_help._is_value(sizes): ind = torch.tensor([0], dtype=torch.long) for i, (size, stride) in enumerate(zip(sizes, strides)): @@ -2622,3 +2841,21 @@ def as_strided(g, self, sizes, strides, offset=None): if offset: ind = g.op("Add", ind, g.op("Constant", torch.tensor([offset]))) return g.op("Gather", self_1d, ind) + + +def __derive_index(g, index, start, step): + return g.op("Add", start, g.op("Mul", index, step)) + + +# Source code for aten op can be found here: pytorch/torch/csrc/jit/runtime/register_prim_ops.cpp +# if (step > 0 && lo < hi) { +# push(stack, 1 + (hi - 1 - lo) / step); +# } else if (step < 0 && lo > hi) { +# push(stack, 1 + (lo - 1 - hi) / (0 - step)); +# } else { +# push(stack, 0); +# } +def __range_length(g, lo, hi, step): + sub = g.op("Sub", hi, lo) + div = g.op("Ceil", true_divide(g, sub, step)) + return g.op("Cast", div, to_i=sym_help.cast_pytorch_to_onnx['Long']) diff --git a/torch/onnx/symbolic_registry.py b/torch/onnx/symbolic_registry.py index 48114d6c472b5..f748f63efc6bf 100644 --- a/torch/onnx/symbolic_registry.py +++ b/torch/onnx/symbolic_registry.py @@ -1,6 +1,7 @@ import warnings import importlib from inspect import getmembers, isfunction +from typing import Dict, Tuple, Any, Union # The symbolic registry "_registry" is a dictionary that maps operators # (for a specific domain and opset version) to their symbolic functions. @@ -8,11 +9,11 @@ # The keys are tuples (domain, version), (where domain is a string, and version is an int), # and the operator's name (string). # The map's entries are as follows : _registry[(domain, version)][op_name] = op_symbolic -_registry = {} +_registry: Dict[Tuple[str, int], Dict] = {} -_symbolic_versions = {} -from torch.onnx.symbolic_helper import _onnx_stable_opsets -for opset_version in _onnx_stable_opsets: +_symbolic_versions: Dict[Union[int, str], Any] = {} +from torch.onnx.symbolic_helper import _onnx_stable_opsets, _onnx_main_opset +for opset_version in _onnx_stable_opsets + [_onnx_main_opset]: module = importlib.import_module('torch.onnx.symbolic_opset{}'.format(opset_version)) _symbolic_versions[opset_version] = module @@ -90,7 +91,7 @@ def is_registered_op(opname, domain, version): def get_op_supported_version(opname, domain, version): iter_version = version - while iter_version <= _onnx_stable_opsets[-1]: + while iter_version <= _onnx_main_opset: ops = [op[0] for op in get_ops_in_version(iter_version)] if opname in ops: return iter_version diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index a3b05ff71c61b..4e98608e4ef45 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -17,7 +17,8 @@ from torch._six import string_classes from torch.jit import _unique_state_dict from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes, TrainingMode -from torch._C import ListType, OptionalType, _propagate_and_assign_input_shapes, _assign_output_shapes, _check_onnx_proto +from torch._C import ListType, OptionalType, _propagate_and_assign_input_shapes, _check_onnx_proto +from typing import Union, Tuple, List # the flag to tell the user whether it's in the middle of ONNX export or not @@ -76,7 +77,7 @@ def export(model, args, f, export_params=True, verbose=False, training=None, if aten or export_raw_ir: assert operator_export_type is None assert aten ^ export_raw_ir - operator_export_type = OperatorExportTypes.ATEN if aten else OperatorExportTypes.RAW + operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW elif operator_export_type is None: if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE: operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK @@ -121,7 +122,7 @@ def _split_tensor_list_constants(g, block): def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=False, fixed_batch_size=False, - params_dict=None, use_new_jit_passes=False): + params_dict=None, use_new_jit_passes=True, dynamic_axes=None, input_names=None): # Inline everything torch._C._jit_pass_inline(graph) @@ -195,12 +196,19 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa # onnx only supports tensors, so we turn all out number types into tensors torch._C._jit_pass_erase_number_types(graph) + from torch.onnx.symbolic_helper import _onnx_shape_inference + if _onnx_shape_inference: + input_names = [] if input_names is None else input_names + dynamic_axes = {} if dynamic_axes is None else dynamic_axes + torch._C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names) graph = torch._C._jit_pass_onnx(graph, operator_export_type) torch._C._jit_pass_lint(graph) torch._C._jit_pass_onnx_scalar_type_analysis(graph) torch._C._jit_pass_lint(graph) + torch._C._jit_pass_onnx_fold_if(graph) + from torch.onnx.symbolic_helper import _export_onnx_opset_version torch._C._jit_pass_onnx_peephole(graph, _export_onnx_opset_version, fixed_batch_size) torch._C._jit_pass_lint(graph) @@ -214,6 +222,9 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa torch._C._jit_pass_lint(graph) graph = torch._C._jit_pass_canonicalize(graph) torch._C._jit_pass_lint(graph) + from torch.onnx.symbolic_helper import _onnx_shape_inference, _export_onnx_opset_version + if _onnx_shape_inference: + torch._C._jit_pass_onnx_graph_shape_type_inference(graph, _export_onnx_opset_version) return graph @@ -308,6 +319,35 @@ def _decide_external_data_format(use_external_data_format, operator_export_type, model_file_location = f if val_use_external_data_format and isinstance(f, str) else str() return val_use_external_data_format, model_file_location +def _decide_input_format(model, args): + import inspect + try: + sig = inspect.signature(model.forward) + ordered_list_keys = list(sig.parameters.keys()) + if isinstance(args[-1], dict): + args_dict = args[-1] + args = list(args)[:-1] + n_nonkeyword = len(args) + for optional_arg in ordered_list_keys[n_nonkeyword:]: + if optional_arg in args_dict: + args.append(args_dict[optional_arg]) + # Check if this arg has a default value + else: + param = sig.parameters[optional_arg] + if param.default is param.empty: + args.append(None) + else: + args.append(param.default) + args = tuple(args) + return args + # Cases of models without forward functions and dict inputs + except AttributeError: + warnings.warn("Model has no forward function") + return args + # Cases of models with no input args + except IndexError: + warnings.warn("No input args") + return args def _trace(func, args, operator_export_type, return_outs=False): # Special case for common case of passing a single Tensor @@ -343,6 +383,7 @@ def _trace_and_get_graph_from_model(model, args): def _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes): torch_out = None + params: Union[List, Tuple] if isinstance(model, torch.jit.ScriptModule): try: graph = model.forward.graph @@ -350,10 +391,10 @@ def _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes): if not use_new_jit_passes: method_graph, params = torch._C._jit_pass_lower_graph(graph, model._c) else: - freezed_m = torch._C._freeze_module(model._c) + freezed_m = torch._C._freeze_module(model._c, preserveParameters=True) + freezed_m, params = torch._C._jit_onnx_list_model_parameters(freezed_m) method_graph = freezed_m._get_method('forward').graph method_graph.eraseInput(0) # Remove 'self' from model inputs - params = [] in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params)) graph = _propagate_and_assign_input_shapes( @@ -388,7 +429,8 @@ def _model_to_graph(model, args, verbose=False, example_outputs=None, _retain_param_name=False, do_constant_folding=True, _disable_torch_constant_prop=False, fixed_batch_size=False, - training=None, use_new_jit_passes=False): + training=None, use_new_jit_passes=True, + dynamic_axes=None): from torch.onnx.symbolic_helper import _export_onnx_opset_version # Special case for common case of passing a single Tensor if isinstance(args, torch.Tensor): @@ -408,19 +450,20 @@ def _model_to_graph(model, args, verbose=False, graph = _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=_disable_torch_constant_prop, fixed_batch_size=fixed_batch_size, params_dict=params_dict, - use_new_jit_passes=use_new_jit_passes) + use_new_jit_passes=use_new_jit_passes, + dynamic_axes=dynamic_axes, input_names=input_names) + from torch.onnx.symbolic_helper import _onnx_shape_inference if isinstance(model, torch.jit.ScriptModule) or isinstance(model, torch.jit.ScriptFunction): assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule or " \ "ScriptFunction." out_vars, _ = torch.jit._flatten(tuple(example_outputs)) - graph = _assign_output_shapes(graph, out_vars) + torch._C._jit_pass_onnx_assign_output_shape(graph, out_vars, _onnx_shape_inference) # NB: ONNX requires complete information about output types, which might be # erased by some optimizations, so we need to set it explicitly again. if torch_out is not None: output_tensors, _ = torch._C._jit_flatten(torch_out) - for output, tensor in zip(graph.outputs(), output_tensors): - output.inferTypeFrom(tensor) + torch._C._jit_pass_onnx_assign_output_shape(graph, output_tensors, _onnx_shape_inference) _set_input_and_output_names(graph, input_names, output_names) @@ -432,7 +475,7 @@ def _model_to_graph(model, args, verbose=False, param_names = input_and_param_names[len(input_and_param_names) - len(params):] params_dict = dict(zip(param_names, params)) - if training is None or training == TrainingMode.EVAL or (training == TrainingMode.PRESERVE and not is_originally_training): + if training is None or training == TrainingMode.EVAL: params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict) if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions: @@ -466,7 +509,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t if aten or export_raw_ir: assert operator_export_type is None assert aten ^ export_raw_ir - operator_export_type = OperatorExportTypes.ATEN if aten else OperatorExportTypes.RAW + operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW elif operator_export_type is None: operator_export_type = OperatorExportTypes.ONNX return _export_to_pretty_string(model, args, f, export_params, verbose, training, @@ -484,7 +527,8 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, google_printer=False, opset_version=None, _retain_param_name=False, do_constant_folding=True, keep_initializers_as_inputs=None, - fixed_batch_size=False, custom_opsets=None, add_node_names=True): + fixed_batch_size=False, custom_opsets=None, add_node_names=True, + onnx_shape_inference=True): from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version from torch.onnx.symbolic_helper import _set_operator_export_type if opset_version is None: @@ -493,12 +537,15 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, custom_opsets = {} _set_opset_version(opset_version) _set_operator_export_type(operator_export_type) + from torch.onnx.symbolic_helper import _set_onnx_shape_inference + _set_onnx_shape_inference(onnx_shape_inference) with select_model_mode_for_export(model, training): val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs, operator_export_type, opset_version) val_add_node_names = _decide_add_node_names(add_node_names, operator_export_type) val_do_constant_folding = _decide_constant_folding(do_constant_folding, operator_export_type, training) + args = _decide_input_format(model, args) graph, params_dict, torch_out = _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, example_outputs, _retain_param_name, @@ -513,31 +560,31 @@ def _find_missing_ops_onnx_export(model, args, f, verbose=False, training=Traini input_names=None, output_names=None, opset_version=None, dynamic_axes=None): r""" This diagnostic tool runs your model with operator_export_type set to - OperatorExportTypes.ONNX_FALLTHROUGH once in order to get a list of + OperatorExportTypes.ONNX_FALLTHROUGH once in order to get a list of all the ops that are not supported/implemented by the current exporter operator_export_type is set to OperatorExportTypes.ONNX_FALLTHROUGH by default OperatorExportTypes.ONNX_FALLTHROUGH: If an op is not supported - in ONNX, fall through and export the operator as is, as a custom + in ONNX, fall through and export the operator as is, as a custom ONNX op. Using this mode, the op can be exported and implemented by the user for their runtime backend. Example graph:: - graph(%0 : Float(2:12, 3:4, 4:1, requires_grad=0, device=cpu)): + graph(%0 : Float(2, 3, 4, strides=[12, 4, 1], requires_grad=0, device=cpu)): %6 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() %4 : None = prim::Constant() - %5 : Float(2:12, 3:4, 4:1, requires_grad=0, device=cpu) = aten::cumsum(%0, %6, %4) # main.py:6:0 + %5 : Float(2, 3, 4, strides=[12, 4, 1], requires_grad=0, device=cpu) = aten::cumsum(%0, %6, %4) # main.py:6:0 return (%5) is exported as:: - graph(%0 : Float(2:12, 3:4, 4:1, requires_grad=0, device=cpu)): + graph(%0 : Float(2, 3, 4, strides=[12, 4, 1], requires_grad=0, device=cpu)): %6 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() %4 : None = prim::Constant() - %5 : Float(2:12, 3:4, 4:1, requires_grad=0, device=cpu) = aten::cumsum(%0, %6, %4) # main.py:6:0 + %5 : Float(2, 3, 4, strides=[12, 4, 1], requires_grad=0, device=cpu) = aten::cumsum(%0, %6, %4) # main.py:6:0 return (%5) - In the above example, aten::cumsum in not implemented in opset 9, hence exporter falls + In the above example, aten::cumsum in not implemented in opset 9, hence exporter falls through and provides a list of unsupported ops, the result being: Unsupported ops : [aten:cumsum] """ @@ -549,6 +596,7 @@ def _find_missing_ops_onnx_export(model, args, f, verbose=False, training=Traini # in ONNX, fall through will occur and export the operator as is, as a custom ONNX op. operator_export_type = OperatorExportTypes.ONNX_FALLTHROUGH with select_model_mode_for_export(model, training): + args = _decide_input_format(model, args) graph, params_dict, torch_out = _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type) # The output 'unsupported_ops' will contain the names of all the ops that are not supported in ONNX @@ -573,7 +621,7 @@ def _export(model, args, f, export_params=True, verbose=False, training=None, strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=None, fixed_batch_size=False, custom_opsets=None, add_node_names=True, enable_onnx_checker=True, use_external_data_format=False, - onnx_shape_inference=False, use_new_jit_passes=False): + onnx_shape_inference=True, use_new_jit_passes=True): if isinstance(model, torch.nn.DataParallel): raise ValueError('torch.nn.DataParallel is not supported by ONNX ' @@ -614,6 +662,11 @@ def _export(model, args, f, export_params=True, verbose=False, training=None, val_use_external_data_format, model_file_location = _decide_external_data_format(use_external_data_format, operator_export_type, f) + args = _decide_input_format(model, args) + if dynamic_axes is None: + dynamic_axes = {} + _validate_dynamic_axes(dynamic_axes, model, input_names, output_names) + graph, params_dict, torch_out = \ _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, @@ -621,17 +674,14 @@ def _export(model, args, f, export_params=True, verbose=False, training=None, val_do_constant_folding, fixed_batch_size=fixed_batch_size, training=training, - use_new_jit_passes=use_new_jit_passes) + use_new_jit_passes=use_new_jit_passes, + dynamic_axes=dynamic_axes) # TODO: Don't allocate a in-memory string for the protobuf defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE - if dynamic_axes is None: - dynamic_axes = {} if custom_opsets is None: custom_opsets = {} - _validate_dynamic_axes(dynamic_axes, model, input_names, output_names) - if export_params: proto, export_map = graph._export_onnx( params_dict, opset_version, dynamic_axes, defer_weight_export, @@ -783,7 +833,7 @@ def _graph_op(g, opname, *raw_args, **kwargs): This function is monkey-patched onto Graph. - Arguments: + Args: opname (string): The ONNX operator name, e.g., `Abs` or `Add`. args (Node...): The inputs to the operator; usually provided as arguments to the `symbolic` definition. @@ -827,6 +877,38 @@ def const_if_tensor(arg): return tuple(o for o in n.outputs()) +def _block_op(b, opname, *args, **kwargs): + if "::" in opname: + aten = False + ns_opname = opname + else: + aten = kwargs.pop("aten", False) + ns = "aten" if aten else "onnx" + ns_opname = ns + "::" + opname + n = b.addNode(ns_opname, list(args)) + for k, v in sorted(kwargs.items()): + # TODO: enable inplace in aten exporting mode. + if k == "inplace": + continue + _add_attribute(n, k, v, aten=aten) + if len(list(n.outputs())) == 1: + return n.output() + return tuple(o for o in n.outputs()) + + +def _add_block(node): + return node.addBlock() + + +def _add_input_to_block(block): + return block.addInputToBlock() + + +def _add_output_to_block(block, value): + new_output = block.registerOutput(value) + return new_output + + # Note [Export inplace] # ~~~~~~~~~~~~~~~~~~~~~ # In abstract, it would be better for us to export inplace annotations, @@ -868,8 +950,9 @@ def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExpor ns_op_name = n.kind() ns, op_name = ns_op_name.split("::") if ns == "onnx": - # Use the original node directly - return None + # Clone node to trigger ONNX shape inference + attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()} + return g.op(op_name, *inputs, **attrs, outputs=n.outputsSize()) elif ns == "aten": is_exportable_aten_op = sym_registry.is_registered_op(op_name, '', opset_version) @@ -905,10 +988,11 @@ def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExpor else: raise RuntimeError("Unsupported prim::Constant kind: `{}`. Send a bug report.".format( n.kindOf("value"))) - elif n.mustBeNone() or op_name == "ListConstruct" or op_name == "ListUnpack": + elif n.mustBeNone() or op_name == "ListConstruct" or op_name == "ListUnpack" or op_name == "Uninitialized": # None is not an ONNX operator; keep it as None # Let the exporter handle and finally eliminate these ops # ListConstruct and ListUnpack will be erased in the ONNX peephole pass + # Uninitialized will be erased during shape/type inference return None elif op_name == "device" and n.output().type().kind() == "DeviceObjType": return None @@ -918,8 +1002,21 @@ def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExpor for b in n.blocks(): new_block = new_node.addBlock() # Copy input metadata to subblock - # This is for Loop only, since If only has a single input. + # + # If format: + # prim::If(cond) + # block0() + # block1() + # + # Loop format: + # prim::Loop(iter, cond, input_1, ..., input_n) + # block0(iter, input_1, ..., input_n) + # + # For `If` node, there is nothing to copy. + # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`. for i, b_in in enumerate(b.inputs()): + if i == 0 and i < len(inputs): + b_in.setType(inputs[i].type()) if i > 0 and (i + 1) < len(inputs): b_in.setType(inputs[i + 1].type()) torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env) @@ -960,8 +1057,7 @@ def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExpor else: raise RuntimeError("ONNX export failed on an operator with unrecognized namespace {}::{}. " "If you are trying to export a custom operator, make sure you registered " - "it with the right domain and version. " - "Otherwise, please report a bug.".format(ns, op_name)) + "it with the right domain and version.".format(ns, op_name)) except RuntimeError: if operator_export_type == OperatorExportTypes.ONNX_FALLTHROUGH: return None @@ -991,6 +1087,10 @@ def _graph_constant(g, value, dims, type, *args, **kwargs): dims = [1] isscalar = True type = type.lower() + tensor: Union[torch.CharTensor, torch.ShortTensor, + torch.IntTensor, torch.LongTensor, + torch.HalfTensor, torch.FloatTensor, + torch.DoubleTensor] if type == "char": tensor = torch.CharTensor(*dims) elif type == "short": @@ -1008,7 +1108,7 @@ def _graph_constant(g, value, dims, type, *args, **kwargs): else: raise ValueError("Unknown type, type should be one of the following strings: " "char, short, int, long, half, float, double") - tensor.fill_(value) + tensor.fill_(value) # type: ignore if isscalar: return g.op("Constant", *args, value_z=tensor, **kwargs) return g.op("Constant", *args, value_t=tensor, **kwargs) @@ -1038,9 +1138,9 @@ def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version): raise RuntimeError("Failed to register operator {}. The domain {} is already a used domain." .format(symbolic_name, ns)) import torch.onnx.symbolic_registry as sym_registry - from torch.onnx.symbolic_helper import _onnx_stable_opsets + from torch.onnx.symbolic_helper import _onnx_stable_opsets, _onnx_main_opset - for version in _onnx_stable_opsets: + for version in _onnx_stable_opsets + [_onnx_main_opset]: if version >= opset_version: sym_registry.register_op(op_name, symbolic_fn, ns, version) @@ -1080,13 +1180,9 @@ def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names): value_dict[x] = str(key) + '_dynamic_axes_' + str(i + 1) dynamic_axes[key] = value_dict -def _add_block(node, input_node, op_name, **kwargs): - new_block = node.addBlock() - new_node = new_block.addNode(input_node, op_name) - for k, v in kwargs.items(): - _add_attribute(new_node, k, v, False) -torch._C.Graph.op = _graph_op -torch._C.Graph.at = _graph_at -torch._C.Graph.constant = _graph_constant -torch._C.Node.__getitem__ = _node_getitem +torch._C.Graph.op = _graph_op # type: ignore +torch._C.Graph.at = _graph_at # type: ignore +torch._C.Block.op = _block_op # type: ignore +torch._C.Graph.constant = _graph_constant # type: ignore +torch._C.Node.__getitem__ = _node_getitem # type: ignore diff --git a/torch/optim/_multi_tensor/__init__.py b/torch/optim/_multi_tensor/__init__.py new file mode 100644 index 0000000000000..2f9b65bfbe3f3 --- /dev/null +++ b/torch/optim/_multi_tensor/__init__.py @@ -0,0 +1,24 @@ +""" +:mod:`torch.optim._multi_tensor` is a package implementing various optimization algorithms. +Most commonly used methods are already supported, and the interface is general +enough, so that more sophisticated ones can be also easily integrated in the +future. +""" + +from .adam import Adam +from .adamw import AdamW +from .sgd import SGD +from .rmsprop import RMSprop +from .rprop import Rprop +from .asgd import ASGD +from .adamax import Adamax +from .adadelta import Adadelta + +del adam +del adamw +del sgd +del rmsprop +del rprop +del asgd +del adamax +del adadelta diff --git a/torch/optim/_multi_tensor/__init__.pyi b/torch/optim/_multi_tensor/__init__.pyi new file mode 100644 index 0000000000000..952b969012b79 --- /dev/null +++ b/torch/optim/_multi_tensor/__init__.pyi @@ -0,0 +1,8 @@ +from .adam import Adam as Adam +from .adamw import AdamW as AdamW +from .sgd import SGD as SGD +from .rmsprop import RMSprop as RMSprop +from .rprop import Rprop as Rprop +from .asgd import ASGD as ASGD +from .adamax import Adamax as Adamax +from .adadelta import Adadelta as Adadelta \ No newline at end of file diff --git a/torch/optim/_multi_tensor/adadelta.py b/torch/optim/_multi_tensor/adadelta.py new file mode 100644 index 0000000000000..7c600fafd45e0 --- /dev/null +++ b/torch/optim/_multi_tensor/adadelta.py @@ -0,0 +1,123 @@ +import torch +from ..optimizer import Optimizer +from collections import defaultdict + +class Adadelta(Optimizer): + """Implements Adadelta algorithm. + + It has been proposed in `ADADELTA: An Adaptive Learning Rate Method`__. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + rho (float, optional): coefficient used for computing a running average + of squared gradients (default: 0.9) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-6) + lr (float, optional): coefficient that scale delta before it is applied + to the parameters (default: 1.0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + + __ https://arxiv.org/abs/1212.5701 + """ + + def __init__(self, params, lr=1.0, rho=0.9, eps=1e-6, weight_decay=0): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= rho <= 1.0: + raise ValueError("Invalid rho value: {}".format(rho)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict(lr=lr, rho=rho, eps=eps, weight_decay=weight_decay) + super(Adadelta, self).__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + grads = [] + params_with_grad = [] + states = [] + square_avgs = [] + acc_deltas = [] + + rho, eps = group['rho'], group['eps'] + + for p in group['params']: + if p.grad is not None: + if p.grad.is_sparse: + raise RuntimeError('Adadelta does not support sparse gradients') + + grads.append(p.grad) + params_with_grad.append(p) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['square_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state['acc_delta'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + square_avgs.append(state['square_avg']) + acc_deltas.append(state['acc_delta']) + + state['step'] += 1 + states.append(state) + + if group['weight_decay'] != 0: + torch._foreach_add_(grads, params_with_grad, alpha=group['weight_decay']) + + torch._foreach_mul_(square_avgs, rho) + torch._foreach_addcmul_(square_avgs, grads, grads, value=1 - rho) + + std = torch._foreach_add(square_avgs, eps) + torch._foreach_sqrt_(std) + + deltas = torch._foreach_add(acc_deltas, eps) + torch._foreach_sqrt_(deltas) + torch._foreach_div_(deltas, std) + torch._foreach_mul_(deltas, grads) + + torch._foreach_add_(params_with_grad, deltas, alpha=-group['lr']) + + torch._foreach_mul_(acc_deltas, rho) + torch._foreach_addcmul_(acc_deltas, deltas, deltas, value=1 - rho) + + return loss + + # TODO: refactor to a base class once foreach ops are in a good shape. + def zero_grad(self, set_to_none: bool = False): + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) + for group in self.param_groups: + for p in group['params']: + if p.grad is not None: + if set_to_none: + p.grad = None + else: + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + + if p.grad.is_sparse: + p.grad.zero_() + else: + per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad) + + for _, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._foreach_zero_(grads) diff --git a/torch/optim/_multi_tensor/adadelta.pyi b/torch/optim/_multi_tensor/adadelta.pyi new file mode 100644 index 0000000000000..0ca4478a16da6 --- /dev/null +++ b/torch/optim/_multi_tensor/adadelta.pyi @@ -0,0 +1,5 @@ +from typing import Tuple +from ..optimizer import _params_t, Optimizer + +class Adadelta(Optimizer): + def __init__(self, params: _params_t, lr: float=..., rho: float=..., eps: float=..., weight_decay: float=...) -> None: ... \ No newline at end of file diff --git a/torch/optim/_multi_tensor/adam.py b/torch/optim/_multi_tensor/adam.py new file mode 100644 index 0000000000000..d539e865d478f --- /dev/null +++ b/torch/optim/_multi_tensor/adam.py @@ -0,0 +1,165 @@ +import math +import torch +from ..optimizer import Optimizer +from collections import defaultdict + +class Adam(Optimizer): + r"""Implements Adam algorithm with multi tensor APIs. + + It has been proposed in `Adam: A Method for Stochastic Optimization`_. + The implementation of the L2 penalty follows changes proposed in + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad) + super(Adam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Adam, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + amsgrad = group['amsgrad'] + + grads = [] + states = [] + exp_avg = [] + exp_avg_sq = [] + max_exp_avg_sq = [] + params_with_grad = [] + + for p in group['params']: + if p.grad is not None: + if p.grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + params_with_grad.append(p) + grads.append(p.grad) + + for p in params_with_grad: + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avg.append(state['exp_avg']) + exp_avg_sq.append(state['exp_avg_sq']) + + if amsgrad: + max_exp_avg_sq.append(state['max_exp_avg_sq']) + + state['step'] += 1 + states.append(state) + + beta1, beta2 = group['betas'] + + bias_correction1 = [1 - beta1 ** state['step'] for state in states] + bias_correction2 = [1 - beta2 ** state['step'] for state in states] + if group['weight_decay'] != 0: + grads = torch._foreach_add(grads, params_with_grad, alpha=group['weight_decay']) + + # + # Decay the first and second moment running average coefficient + # + torch._foreach_mul_(exp_avg, beta1) + torch._foreach_add_(exp_avg, grads, alpha=1 - beta1) + + torch._foreach_mul_(exp_avg_sq, beta2) + torch._foreach_addcmul_(exp_avg_sq, grads, grads, 1 - beta2) + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + max_exp_avg_sq = torch._foreach_maximum(max_exp_avg_sq, exp_avg_sq) + + # Use the max. for normalizing running avg. of gradient + max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sq) + bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2] + torch._foreach_div_(max_exp_avg_sq_sqrt, bias_correction_sqrt) + denom = torch._foreach_add(max_exp_avg_sq_sqrt, group['eps']) + else: + exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sq) + bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2] + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt) + denom = torch._foreach_add(exp_avg_sq_sqrt, group['eps']) + + step_size = [(group['lr'] / bc) * -1 for bc in bias_correction1] + torch._foreach_addcdiv_(params_with_grad, exp_avg, denom, step_size) + + return loss + + # TODO: refactor to a base class once foreach ops are in a good shape. + def zero_grad(self, set_to_none: bool = False): + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) + for group in self.param_groups: + for p in group['params']: + if p.grad is not None: + if set_to_none: + p.grad = None + else: + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + + if p.grad.is_sparse: + p.grad.zero_() + else: + per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad) + + for _, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._foreach_zero_(grads) diff --git a/torch/optim/_multi_tensor/adam.pyi b/torch/optim/_multi_tensor/adam.pyi new file mode 100644 index 0000000000000..09f29597fd189 --- /dev/null +++ b/torch/optim/_multi_tensor/adam.pyi @@ -0,0 +1,5 @@ +from typing import Tuple +from ..optimizer import _params_t, Optimizer + +class Adam(Optimizer): + def __init__(self, params: _params_t, lr: float=..., betas: Tuple[float, float]=..., eps: float=..., weight_decay: float=..., amsgrad: bool = ...) -> None: ... diff --git a/torch/optim/_multi_tensor/adamax.py b/torch/optim/_multi_tensor/adamax.py new file mode 100644 index 0000000000000..a866a3e02f38a --- /dev/null +++ b/torch/optim/_multi_tensor/adamax.py @@ -0,0 +1,128 @@ +import torch +from ..optimizer import Optimizer +from collections import defaultdict + +class Adamax(Optimizer): + """Implements Adamax algorithm (a variant of Adam based on infinity norm). + + It has been proposed in `Adam: A Method for Stochastic Optimization`__. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 2e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + + __ https://arxiv.org/abs/1412.6980 + """ + + def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super(Adamax, self).__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + grads = [] + params_with_grad = [] + states = [] + exp_avgs = [] + exp_infs = [] + + beta1, beta2 = group['betas'] + eps = group['eps'] + + for p in group['params']: + if p.grad is not None: + if p.grad.is_sparse: + raise RuntimeError('Adamax does not support sparse gradients') + + grads.append(p.grad) + params_with_grad.append(p) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state['exp_inf'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_infs.append(state['exp_inf']) + + state['step'] += 1 + states.append(state) + + if group['weight_decay'] != 0: + torch._foreach_add_(grads, params_with_grad, alpha=group['weight_decay']) + + # Update biased first moment estimate. + torch._foreach_mul_(exp_avgs, beta1) + torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) + + # Update the exponentially weighted infinity norm. + torch._foreach_mul_(exp_infs, beta2) + + for exp_inf, grad in zip(exp_infs, grads): + norm_buf = torch.cat([ + exp_inf.unsqueeze(0), + grad.abs().add_(eps).unsqueeze_(0) + ], 0) + torch.max(norm_buf, 0, keepdim=False, out=(exp_inf, exp_inf.new().long())) + + bias_corrections = [1 - beta1 ** state['step'] for state in states] + clr = [-1 * (group['lr'] / bias_correction) for bias_correction in bias_corrections] + torch._foreach_addcdiv_(params_with_grad, exp_avgs, exp_infs, clr) + + return loss + + # TODO: refactor to a base class once foreach ops are in a good shape. + def zero_grad(self, set_to_none: bool = False): + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) + for group in self.param_groups: + for p in group['params']: + if p.grad is not None: + if set_to_none: + p.grad = None + else: + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + + if p.grad.is_sparse: + p.grad.zero_() + else: + per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad) + + for _, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._foreach_zero_(grads) diff --git a/torch/optim/_multi_tensor/adamax.pyi b/torch/optim/_multi_tensor/adamax.pyi new file mode 100644 index 0000000000000..4ac68f75ba995 --- /dev/null +++ b/torch/optim/_multi_tensor/adamax.pyi @@ -0,0 +1,5 @@ +from typing import Tuple +from ..optimizer import _params_t, Optimizer + +class Adamax(Optimizer): + def __init__(self, params: _params_t, lr: float=..., betas: Tuple[float, float]=..., eps: float=..., weight_decay: float=...) -> None: ... diff --git a/torch/optim/_multi_tensor/adamw.py b/torch/optim/_multi_tensor/adamw.py new file mode 100644 index 0000000000000..3670c786b68dc --- /dev/null +++ b/torch/optim/_multi_tensor/adamw.py @@ -0,0 +1,166 @@ +import math +import torch +from ..optimizer import Optimizer +from collections import defaultdict + +class AdamW(Optimizer): + r"""Implements AdamW algorithm. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=1e-2, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + amsgrad = group['amsgrad'] + + grads = [] + states = [] + exp_avg = [] + exp_avg_sq = [] + max_exp_avg_sq = [] + params_with_grad = [] + + for p in group['params']: + if p.grad is not None: + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + + # Perform stepweight decay + p.mul_(1 - group['lr'] * group['weight_decay']) + + params_with_grad.append(p) + grads.append(p.grad) + + for p in params_with_grad: + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avg.append(state['exp_avg']) + exp_avg_sq.append(state['exp_avg_sq']) + + if amsgrad: + max_exp_avg_sq.append(state['max_exp_avg_sq']) + + state['step'] += 1 + states.append(state) + + beta1, beta2 = group['betas'] + + bias_correction1 = [1 - beta1 ** state['step'] for state in states] + bias_correction2 = [1 - beta2 ** state['step'] for state in states] + + # + # Decay the first and second moment running average coefficient + # + torch._foreach_mul_(exp_avg, beta1) + torch._foreach_add_(exp_avg, grads, alpha=1 - beta1) + + torch._foreach_mul_(exp_avg_sq, beta2) + torch._foreach_addcmul_(exp_avg_sq, grads, grads, 1 - beta2) + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + max_exp_avg_sq = torch._foreach_maximum(max_exp_avg_sq, exp_avg_sq) + + # Use the max. for normalizing running avg. of gradient + max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sq) + bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2] + torch._foreach_div_(max_exp_avg_sq_sqrt, bias_correction_sqrt) + denom = torch._foreach_add(max_exp_avg_sq_sqrt, group['eps']) + else: + exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sq) + bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2] + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt) + denom = torch._foreach_add(exp_avg_sq_sqrt, group['eps']) + + step_size = [-1 * (group['lr'] / bc) for bc in bias_correction1] + torch._foreach_addcdiv_(params_with_grad, exp_avg, denom, step_size) + + return loss + + # TODO: refactor to a base class once foreach ops are in a good shape. + def zero_grad(self, set_to_none: bool = False): + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) + for group in self.param_groups: + for p in group['params']: + if p.grad is not None: + if set_to_none: + p.grad = None + else: + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + + if p.grad.is_sparse: + p.grad.zero_() + else: + per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad) + + for _, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._foreach_zero_(grads) diff --git a/torch/optim/_multi_tensor/adamw.pyi b/torch/optim/_multi_tensor/adamw.pyi new file mode 100644 index 0000000000000..dedd8de3f876a --- /dev/null +++ b/torch/optim/_multi_tensor/adamw.pyi @@ -0,0 +1,5 @@ +from typing import Tuple +from ..optimizer import _params_t, Optimizer + +class AdamW(Optimizer): + def __init__(self, params: _params_t, lr: float=..., betas: Tuple[float, float]=..., eps: float=..., weight_decay: float=..., amsgrad: bool = ...) -> None: ... diff --git a/torch/optim/_multi_tensor/asgd.py b/torch/optim/_multi_tensor/asgd.py new file mode 100644 index 0000000000000..351e4324908fc --- /dev/null +++ b/torch/optim/_multi_tensor/asgd.py @@ -0,0 +1,117 @@ +import math +import torch +from ..optimizer import Optimizer +from collections import defaultdict + +class ASGD(Optimizer): + """Implements Averaged Stochastic Gradient Descent. + + It has been proposed in `Acceleration of stochastic approximation by + averaging`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-2) + lambd (float, optional): decay term (default: 1e-4) + alpha (float, optional): power for eta update (default: 0.75) + t0 (float, optional): point at which to start averaging (default: 1e6) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + + .. _Acceleration of stochastic approximation by averaging: + https://dl.acm.org/citation.cfm?id=131098 + """ + + def __init__(self, params, lr=1e-2, lambd=1e-4, alpha=0.75, t0=1e6, weight_decay=0): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict(lr=lr, lambd=lambd, alpha=alpha, t0=t0, + weight_decay=weight_decay) + super(ASGD, self).__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + grads = [] + params_with_grad = [] + states = [] + + for group in self.param_groups: + for p in group['params']: + if p.grad is not None: + if p.grad.is_sparse: + raise RuntimeError('ASGD does not support sparse gradients') + + grads.append(p.grad) + params_with_grad.append(p) + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['eta'] = group['lr'] + state['mu'] = 1 + state['ax'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + state['step'] += 1 + states.append(state) + + if group['weight_decay'] != 0: + torch._foreach_add_(grads, params_with_grad, alpha=group['weight_decay']) + + # decay term + torch._foreach_mul_(params_with_grad, 1 - group['lambd'] * state['eta']) + + # update parameter + torch._foreach_add_(params_with_grad, grads, alpha=-state['eta']) + + # averaging + for i in range(len(states)): + if states[i]['mu'] != 1: + states[i]['ax'].add_(params_with_grad[i].sub(states[i]['ax']).mul(states[i]['mu'])) + else: + states[i]['ax'].copy_(params_with_grad[i]) + + # update eta and mu + for state in states: + state['eta'] = (group['lr'] / + math.pow((1 + group['lambd'] * group['lr'] * state['step']), group['alpha'])) + state['mu'] = 1 / max(1, state['step'] - group['t0']) + + return loss + + # TODO: refactor to a base class once foreach ops are in a good shape. + def zero_grad(self, set_to_none: bool = False): + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) + for group in self.param_groups: + for p in group['params']: + if p.grad is not None: + if set_to_none: + p.grad = None + else: + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + + if p.grad.is_sparse: + p.grad.zero_() + else: + per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad) + + for _, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._foreach_zero_(grads) diff --git a/torch/optim/_multi_tensor/asgd.pyi b/torch/optim/_multi_tensor/asgd.pyi new file mode 100644 index 0000000000000..06e9149b72f5c --- /dev/null +++ b/torch/optim/_multi_tensor/asgd.pyi @@ -0,0 +1,5 @@ +from typing import Tuple +from ..optimizer import _params_t, Optimizer + +class ASGD(Optimizer): + def __init__(self, params: _params_t, lr: float=..., lambd: float=..., alpha: float=..., t0: float=..., weight_decay: float=...) -> None: ... diff --git a/torch/optim/_multi_tensor/rmsprop.py b/torch/optim/_multi_tensor/rmsprop.py new file mode 100644 index 0000000000000..ac918307e7c05 --- /dev/null +++ b/torch/optim/_multi_tensor/rmsprop.py @@ -0,0 +1,146 @@ +import torch +from ..optimizer import Optimizer +from collections import defaultdict + +class RMSprop(Optimizer): + r"""Implements RMSprop algorithm. + + Proposed by G. Hinton in his + `course `_. + + The centered version first appears in `Generating Sequences + With Recurrent Neural Networks `_. + + The implementation here takes the square root of the gradient average before + adding epsilon (note that TensorFlow interchanges these two operations). The effective + learning rate is thus :math:`\alpha/(\sqrt{v} + \epsilon)` where :math:`\alpha` + is the scheduled learning rate and :math:`v` is the weighted moving average + of the squared gradient. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-2) + momentum (float, optional): momentum factor (default: 0) + alpha (float, optional): smoothing constant (default: 0.99) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + centered (bool, optional) : if ``True``, compute the centered RMSProp, + the gradient is normalized by an estimation of its variance + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + + """ + + def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= momentum: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= alpha: + raise ValueError("Invalid alpha value: {}".format(alpha)) + + defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay) + super(RMSprop, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RMSprop, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('momentum', 0) + group.setdefault('centered', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + grads = [] + params_with_grad = [] + states = [] + alpha = group['alpha'] + square_avg = [] + + for p in group['params']: + if p.grad is not None: + if p.grad.is_sparse: + raise RuntimeError('RMSprop does not support sparse gradients') + + grads.append(p.grad) + params_with_grad.append(p) + + state = self.state[p] + # State initialization + if len(state) == 0: + state['step'] = 0 + state['square_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if group['momentum'] > 0: + state['momentum_buffer'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if group['centered']: + state['grad_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + state['step'] += 1 + + states.append(state) + square_avg.append(state['square_avg']) + + if group['weight_decay'] != 0: + torch._foreach_add_(grads, params_with_grad, alpha=group['weight_decay']) + + torch._foreach_mul_(square_avg, alpha) + torch._foreach_addcmul_(square_avg, grads, grads, value=1 - alpha) + + if group['centered']: + grad_avgs = [s['grad_avg'] for s in states] + torch._foreach_mul_(grad_avgs, alpha) + torch._foreach_add_(grad_avgs, grads, alpha=1 - alpha) + avg = torch._foreach_addcmul(square_avg, grad_avgs, grad_avgs, value=-1) + torch._foreach_sqrt_(avg) + torch._foreach_add_(avg, group['eps']) + else: + avg = torch._foreach_sqrt(square_avg) + torch._foreach_add_(avg, group['eps']) + + if group['momentum'] > 0: + buf = [s['momentum_buffer'] for s in states] + torch._foreach_mul_(buf, group['momentum']) + torch._foreach_addcdiv_(buf, grads, avg) + torch._foreach_add_(params_with_grad, buf, alpha=-group['lr']) + else: + torch._foreach_addcdiv_(params_with_grad, grads, avg, value=-group['lr']) + + return loss + + # TODO: refactor to a base class once foreach ops are in a good shape. + def zero_grad(self, set_to_none: bool = False): + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) + for group in self.param_groups: + for p in group['params']: + if p.grad is not None: + if set_to_none: + p.grad = None + else: + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + + if p.grad.is_sparse: + p.grad.zero_() + else: + per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad) + + for _, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._foreach_zero_(grads) diff --git a/torch/optim/_multi_tensor/rmsprop.pyi b/torch/optim/_multi_tensor/rmsprop.pyi new file mode 100644 index 0000000000000..691f2188ebb12 --- /dev/null +++ b/torch/optim/_multi_tensor/rmsprop.pyi @@ -0,0 +1,5 @@ +from typing import Tuple +from ..optimizer import _params_t, Optimizer + +class RMSprop(Optimizer): + def __init__(self, params: _params_t, lr: float=..., alpha: float=..., eps: float=..., weight_decay: float=..., momentum: float=..., centered: bool=...) -> None: ... diff --git a/torch/optim/_multi_tensor/rprop.py b/torch/optim/_multi_tensor/rprop.py new file mode 100644 index 0000000000000..d2a3eca755db0 --- /dev/null +++ b/torch/optim/_multi_tensor/rprop.py @@ -0,0 +1,118 @@ +import torch +from ..optimizer import Optimizer +from collections import defaultdict + +class Rprop(Optimizer): + """Implements the resilient backpropagation algorithm. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-2) + etas (Tuple[float, float], optional): pair of (etaminus, etaplis), that + are multiplicative increase and decrease factors + (default: (0.5, 1.2)) + step_sizes (Tuple[float, float], optional): a pair of minimal and + maximal allowed step sizes (default: (1e-6, 50)) + """ + + def __init__(self, params, lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50)): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 < etas[0] < 1.0 < etas[1]: + raise ValueError("Invalid eta values: {}, {}".format(etas[0], etas[1])) + + defaults = dict(lr=lr, etas=etas, step_sizes=step_sizes) + super(Rprop, self).__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + grads = [] + states = [] + params_with_grad = [] + step_sizes = [] + + for group in self.param_groups: + for p in group['params']: + etaminus, etaplus = group['etas'] + step_size_min, step_size_max = group['step_sizes'] + + if p.grad is not None: + if p.grad.is_sparse: + raise RuntimeError('RMSprop does not support sparse gradients') + + grads.append(p.grad) + params_with_grad.append(p) + + state = self.state[p] + # State initialization + if len(state) == 0: + state['step'] = 0 + state['prev'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state['step_size'] = p.grad.new().resize_as_(p.grad).fill_(group['lr']) + + state['step'] += 1 + + states.append(state) + step_sizes.append(state['step_size']) + + signs = torch._foreach_mul(grads, [s['prev'] for s in states]) + signs = [s.sign() for s in signs] + for sign in signs: + sign[sign.gt(0)] = etaplus + sign[sign.lt(0)] = etaminus + sign[sign.eq(0)] = 1 + + # update stepsizes with step size updates + torch._foreach_mul_(step_sizes, signs) + for step_size in step_sizes: + step_size.clamp_(step_size_min, step_size_max) + + # for dir<0, dfdx=0 + # for dir>=0 dfdx=dfdx + for i in range(len(grads)): + grads[i] = grads[i].clone(memory_format=torch.preserve_format) + grads[i][signs[i].eq(etaminus)] = 0 + + # update parameters + grad_signs = [grad.sign() for grad in grads] + torch._foreach_addcmul_(params_with_grad, grad_signs, step_sizes, value=-1) + + for i in range(len(states)): + states[i]['prev'].copy_(grads[i]) + + return loss + + # TODO: refactor to a base class once foreach ops are in a good shape. + def zero_grad(self, set_to_none: bool = False): + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) + for group in self.param_groups: + for p in group['params']: + if p.grad is not None: + if set_to_none: + p.grad = None + else: + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + + if p.grad.is_sparse: + p.grad.zero_() + else: + per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad) + + for _, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._foreach_zero_(grads) diff --git a/torch/optim/_multi_tensor/rprop.pyi b/torch/optim/_multi_tensor/rprop.pyi new file mode 100644 index 0000000000000..0ea64c63d25e5 --- /dev/null +++ b/torch/optim/_multi_tensor/rprop.pyi @@ -0,0 +1,5 @@ +from typing import Tuple +from ..optimizer import _params_t, Optimizer + +class Rprop(Optimizer): + def __init__(self, params: _params_t, lr: float=..., etas: Tuple[float, float]=..., step_sizes: Tuple[float, float]=...) -> None: ... diff --git a/torch/optim/_multi_tensor/sgd.py b/torch/optim/_multi_tensor/sgd.py new file mode 100644 index 0000000000000..a1f5772871f5b --- /dev/null +++ b/torch/optim/_multi_tensor/sgd.py @@ -0,0 +1,177 @@ +import torch +from ..optimizer import Optimizer, required +from collections import defaultdict + +class SGD(Optimizer): + r"""Implements stochastic gradient descent (optionally with momentum). + + Nesterov momentum is based on the formula from + `On the importance of initialization and momentum in deep learning`__. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float): learning rate + momentum (float, optional): momentum factor (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + + Example: + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + + __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf + + .. note:: + The implementation of SGD with Momentum/Nesterov subtly differs from + Sutskever et. al. and implementations in some other frameworks. + + Considering the specific case of Momentum, the update can be written as + + .. math:: + \begin{aligned} + v_{t+1} & = \mu * v_{t} + g_{t+1}, \\ + p_{t+1} & = p_{t} - \text{lr} * v_{t+1}, + \end{aligned} + + where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the + parameters, gradient, velocity, and momentum respectively. + + This is in contrast to Sutskever et. al. and + other frameworks which employ an update of the form + + .. math:: + \begin{aligned} + v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\ + p_{t+1} & = p_{t} - v_{t+1}. + \end{aligned} + + The Nesterov version is analogously modified. + """ + + def __init__(self, params, lr=required, momentum=0, dampening=0, + weight_decay=0, nesterov=False): + if lr is not required and lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if weight_decay < 0.0: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict(lr=lr, momentum=momentum, dampening=dampening, + weight_decay=weight_decay, nesterov=nesterov) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super(SGD, self).__init__(params, defaults) + + def __setstate__(self, state): + super(SGD, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('nesterov', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + + grads = [] + params_with_grad = [] + states = [] + has_sparse_grad = False + + for p in group['params']: + if p.grad is not None: + grads.append(p.grad) + params_with_grad.append(p) + states.append(self.state[p]) + + if p.grad.is_sparse: + has_sparse_grad = True + + if momentum != 0: + raise RuntimeError('SGD does not support momentum for sparse gradients') + + if grads == []: + return loss + + if weight_decay != 0: + grads = torch._foreach_add(grads, params_with_grad, alpha=weight_decay) + + if momentum != 0: + bufs = [] + + all_states_with_momentum_buffer = True + for i in range(len(states)): + if 'momentum_buffer' not in states[i]: + all_states_with_momentum_buffer = False + break + else: + bufs.append(states[i]['momentum_buffer']) + + if all_states_with_momentum_buffer: + torch._foreach_mul_(bufs, momentum) + torch._foreach_add_(bufs, grads, alpha=1 - dampening) + else: + bufs = [] + for i in range(len(states)): + if 'momentum_buffer' not in states[i]: + buf = states[i]['momentum_buffer'] = torch.clone(grads[i]).detach() + else: + buf = states[i]['momentum_buffer'] + buf.mul_(momentum).add_(grads[i], alpha=1 - dampening) + + bufs.append(buf) + + if nesterov: + torch._foreach_add_(grads, bufs, alpha=momentum) + else: + grads = bufs + + if not has_sparse_grad: + torch._foreach_add_(params_with_grad, grads, alpha=-group['lr']) + else: + # foreach APIs dont support sparse + for i in range(len(params_with_grad)): + params_with_grad[i].add_(grads[i], alpha=-group['lr']) + + return loss + + # TODO: refactor to a base class once foreach ops are in a good shape. + def zero_grad(self, set_to_none: bool = False): + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) + for group in self.param_groups: + for p in group['params']: + if p.grad is not None: + if set_to_none: + p.grad = None + else: + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + + if p.grad.is_sparse: + p.grad.zero_() + else: + per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad) + + for _, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._foreach_zero_(grads) diff --git a/torch/optim/_multi_tensor/sgd.pyi b/torch/optim/_multi_tensor/sgd.pyi new file mode 100644 index 0000000000000..6082e230cd795 --- /dev/null +++ b/torch/optim/_multi_tensor/sgd.pyi @@ -0,0 +1,4 @@ +from ..optimizer import _params_t, Optimizer + +class SGD(Optimizer): + def __init__(self, params: _params_t, lr: float, momentum: float=..., dampening: float=..., weight_decay:float=..., nesterov:bool=...) -> None: ... diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py index 11cdf5e28b656..ae9286a5cb2f9 100644 --- a/torch/optim/adadelta.py +++ b/torch/optim/adadelta.py @@ -8,7 +8,7 @@ class Adadelta(Optimizer): It has been proposed in `ADADELTA: An Adaptive Learning Rate Method`__. - Arguments: + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups rho (float, optional): coefficient used for computing a running average @@ -39,7 +39,7 @@ def __init__(self, params, lr=1.0, rho=0.9, eps=1e-6, weight_decay=0): def step(self, closure=None): """Performs a single optimization step. - Arguments: + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index 761648537fc65..6299924c5aa38 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -1,4 +1,5 @@ import torch +from . import functional as F from .optimizer import Optimizer @@ -8,7 +9,7 @@ class Adagrad(Optimizer): It has been proposed in `Adaptive Subgradient Methods for Online Learning and Stochastic Optimization`_. - Arguments: + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-2) @@ -53,7 +54,7 @@ def share_memory(self): def step(self, closure=None): """Performs a single optimization step. - Arguments: + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ @@ -63,40 +64,29 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - - grad = p.grad - state = self.state[p] - - state['step'] += 1 + params_with_grad = [] + grads = [] + state_sums = [] + state_steps = [] - if group['weight_decay'] != 0: - if p.grad.is_sparse: - raise RuntimeError("weight_decay option is not compatible with sparse gradients") - grad = grad.add(p, alpha=group['weight_decay']) - - clr = group['lr'] / (1 + (state['step'] - 1) * group['lr_decay']) - - if grad.is_sparse: - grad = grad.coalesce() # the update is non-linear so indices must be unique - grad_indices = grad._indices() - grad_values = grad._values() - size = grad.size() - - def make_sparse(values): - constructor = grad.new - if grad_indices.dim() == 0 or values.dim() == 0: - return constructor().resize_as_(grad) - return constructor(grad_indices, values, size) - state['sum'].add_(make_sparse(grad_values.pow(2))) - std = state['sum'].sparse_mask(grad) - std_values = std._values().sqrt_().add_(group['eps']) - p.add_(make_sparse(grad_values / std_values), alpha=-clr) - else: - state['sum'].addcmul_(grad, grad, value=1) - std = state['sum'].sqrt().add_(group['eps']) - p.addcdiv_(grad, std, value=-clr) + for p in group['params']: + if p.grad is not None: + params_with_grad.append(p) + grads.append(p.grad) + state = self.state[p] + state_sums.append(state['sum']) + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + F.adagrad(params_with_grad, + grads, + state_sums, + state_steps, + group['lr'], + group['weight_decay'], + group['lr_decay'], + group['eps']) return loss diff --git a/torch/optim/adam.py b/torch/optim/adam.py index 22a9e3828a57f..2e0611dae6dc6 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -1,5 +1,5 @@ -import math import torch +from . import functional as F from .optimizer import Optimizer @@ -10,7 +10,7 @@ class Adam(Optimizer): The implementation of the L2 penalty follows changes proposed in `Decoupled Weight Decay Regularization`_. - Arguments: + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) @@ -56,7 +56,7 @@ def __setstate__(self, state): def step(self, closure=None): """Performs a single optimization step. - Arguments: + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ @@ -66,52 +66,56 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - grad = p.grad - if grad.is_sparse: - raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') - amsgrad = group['amsgrad'] - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = 0 - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) - if amsgrad: - # Maintains max of all exp. moving avg. of sq. grad. values - state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) - - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - if amsgrad: - max_exp_avg_sq = state['max_exp_avg_sq'] - beta1, beta2 = group['betas'] - - state['step'] += 1 - bias_correction1 = 1 - beta1 ** state['step'] - bias_correction2 = 1 - beta2 ** state['step'] - - if group['weight_decay'] != 0: - grad = grad.add(p, alpha=group['weight_decay']) - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - if amsgrad: - # Maintains the maximum of all 2nd moment running avg. till now - torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) - # Use the max. for normalizing running avg. of gradient - denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) - else: - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) - - step_size = group['lr'] / bias_correction1 - - p.addcdiv_(exp_avg, denom, value=-step_size) + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + for p in group['params']: + if p.grad is not None: + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + grads.append(p.grad) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if group['amsgrad']: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + + if group['amsgrad']: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + beta1, beta2 = group['betas'] + F.adam(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + group['amsgrad'], + beta1, + beta2, + group['lr'], + group['weight_decay'], + group['eps'] + ) return loss diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py index 1bb8423763497..fa028f469b432 100644 --- a/torch/optim/adamax.py +++ b/torch/optim/adamax.py @@ -7,7 +7,7 @@ class Adamax(Optimizer): It has been proposed in `Adam: A Method for Stochastic Optimization`__. - Arguments: + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 2e-3) @@ -40,7 +40,7 @@ def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, def step(self, closure=None): """Performs a single optimization step. - Arguments: + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 07d58313cfb84..e350566d7f5b8 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -9,7 +9,7 @@ class AdamW(Optimizer): The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. - Arguments: + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) @@ -55,7 +55,7 @@ def __setstate__(self, state): def step(self, closure=None): """Performs a single optimization step. - Arguments: + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index e9f7ca2b22556..887bf71e00290 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -9,7 +9,7 @@ class ASGD(Optimizer): It has been proposed in `Acceleration of stochastic approximation by averaging`_. - Arguments: + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-2) @@ -36,7 +36,7 @@ def __init__(self, params, lr=1e-2, lambd=1e-4, alpha=0.75, t0=1e6, weight_decay def step(self, closure=None): """Performs a single optimization step. - Arguments: + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ diff --git a/torch/optim/functional.py b/torch/optim/functional.py new file mode 100644 index 0000000000000..2984d7f417eac --- /dev/null +++ b/torch/optim/functional.py @@ -0,0 +1,98 @@ +r"""Functional interface""" +import math +import torch +from torch import Tensor +from typing import List + +# TODO: use foreach API in optim.functional to do all the computation + +def _make_sparse(grad, grad_indices, values): + size = grad.size() + if grad_indices.numel() == 0 or values.numel() == 0: + return torch.empty_like(grad) + return torch.sparse_coo_tensor(grad_indices, values, size) + + +def adagrad(params: List[Tensor], + grads: List[Tensor], + state_sums: List[Tensor], + state_steps: List[int], + lr: float, + weight_decay: float, + lr_decay: float, + eps: float): + r"""Functional API that performs Adagrad algorithm computation. + + See :class:`~torch.optim.Adagrad` for details. + """ + + for (param, grad, state_sum, step) in zip(params, grads, state_sums, state_steps): + if weight_decay != 0: + if grad.is_sparse: + raise RuntimeError("weight_decay option is not compatible with sparse gradients") + grad = grad.add(param, alpha=weight_decay) + + clr = lr / (1 + (step - 1) * lr_decay) + + if grad.is_sparse: + grad = grad.coalesce() # the update is non-linear so indices must be unique + grad_indices = grad._indices() + grad_values = grad._values() + size = grad.size() + + state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2))) + std = state_sum.sparse_mask(grad) + std_values = std._values().sqrt_().add_(eps) + param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr) + else: + state_sum.addcmul_(grad, grad, value=1) + std = state_sum.sqrt().add_(eps) + param.addcdiv_(grad, std, value=-clr) + + +def adam(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[int], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float): + r"""Functional API that performs Adam algorithm computation. + + See :class:`~torch.optim.Adam` for details. + """ + + for i, param in enumerate(params): + + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step = state_steps[i] + if amsgrad: + max_exp_avg_sq = max_exp_avg_sqs[i] + + bias_correction1 = 1 - beta1 ** step + bias_correction2 = 1 - beta2 ** step + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + else: + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + step_size = lr / bias_correction1 + + param.addcdiv_(exp_avg, denom, value=-step_size) diff --git a/torch/optim/lbfgs.py b/torch/optim/lbfgs.py index 3571975b1388a..4329b6fd8baea 100644 --- a/torch/optim/lbfgs.py +++ b/torch/optim/lbfgs.py @@ -197,7 +197,7 @@ class LBFGS(Optimizer): ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory try reducing the history size, or use a different algorithm. - Arguments: + Args: lr (float): learning rate (default: 1) max_iter (int): maximal number of iterations per optimization step (default: 20) @@ -283,7 +283,7 @@ def _directional_evaluate(self, closure, x, t, d): def step(self, closure): """Performs a single optimization step. - Arguments: + Args: closure (callable): A closure that reevaluates the model and returns the loss. """ diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 1fc112e54e168..043a8213b4c2d 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -19,8 +19,6 @@ "https://github.com/pytorch/pytorch/issues/new/choose." ) -SAVE_STATE_WARNING = "Please also save or load the state of the optimizer when saving or loading the scheduler." - class _LRScheduler(object): def __init__(self, optimizer, last_epoch=-1, verbose=False): @@ -40,7 +38,7 @@ def __init__(self, optimizer, last_epoch=-1, verbose=False): if 'initial_lr' not in group: raise KeyError("param 'initial_lr' is not specified " "in param_groups[{}] when resuming an optimizer".format(i)) - self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) + self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups] self.last_epoch = last_epoch # Following https://github.com/pytorch/pytorch/issues/20124 @@ -89,7 +87,7 @@ def state_dict(self): def load_state_dict(self, state_dict): """Loads the schedulers state. - Arguments: + Args: state_dict (dict): scheduler state. Should be an object returned from a call to :meth:`state_dict`. """ @@ -211,9 +209,10 @@ def state_dict(self): is not the optimizer. The learning rate lambda functions will only be saved if they are callable objects and not if they are functions or lambdas. + + When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. """ - warnings.warn(SAVE_STATE_WARNING, UserWarning) state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')} state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas) @@ -226,12 +225,13 @@ def state_dict(self): def load_state_dict(self, state_dict): """Loads the schedulers state. - Arguments: + When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. + + Args: state_dict (dict): scheduler state. Should be an object returned from a call to :meth:`state_dict`. """ - warnings.warn(SAVE_STATE_WARNING, UserWarning) lr_lambdas = state_dict.pop('lr_lambdas') self.__dict__.update(state_dict) # Restore state_dict keys in order to prevent side effects @@ -305,7 +305,7 @@ def state_dict(self): def load_state_dict(self, state_dict): """Loads the schedulers state. - Arguments: + Args: state_dict (dict): scheduler state. Should be an object returned from a call to :meth:`state_dict`. """ @@ -855,7 +855,7 @@ def __init__(self, if last_epoch == -1: for momentum, group in zip(base_momentums, optimizer.param_groups): group['momentum'] = momentum - self.base_momentums = list(map(lambda group: group['momentum'], optimizer.param_groups)) + self.base_momentums = [group['momentum'] for group in optimizer.param_groups] self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum) super(CyclicLR, self).__init__(optimizer, last_epoch, verbose) @@ -1072,6 +1072,10 @@ class OneCycleLR(_LRScheduler): You must either provide a value for total_steps or provide a value for both epochs and steps_per_epoch. + The default behaviour of this scheduler follows the fastai implementation of 1cycle, which + claims that "unpublished work has shown even better results by using only two phases". To + mimic the behaviour of the original paper instead, set ``three_phase=True``. + Args: optimizer (Optimizer): Wrapped optimizer. max_lr (float or list): Upper learning rate boundaries in the cycle @@ -1116,6 +1120,10 @@ class OneCycleLR(_LRScheduler): final_div_factor (float): Determines the minimum learning rate via min_lr = initial_lr/final_div_factor Default: 1e4 + three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the + learning rate according to 'final_div_factor' instead of modifying the second + phase (the first two phases will be symmetrical about the step indicated by + 'pct_start'). last_epoch (int): The index of the last batch. This parameter is used when resuming a training job. Since `step()` should be invoked after each batch instead of after each epoch, this number represents the total @@ -1151,6 +1159,7 @@ def __init__(self, max_momentum=0.95, div_factor=25., final_div_factor=1e4, + three_phase=False, last_epoch=-1, verbose=False): @@ -1173,8 +1182,48 @@ def __init__(self, if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int): raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch)) self.total_steps = epochs * steps_per_epoch - self.step_size_up = float(pct_start * self.total_steps) - 1 - self.step_size_down = float(self.total_steps - self.step_size_up) - 1 + + if three_phase: + self._schedule_phases = [ + { + 'end_step': float(pct_start * self.total_steps) - 1, + 'start_lr': 'initial_lr', + 'end_lr': 'max_lr', + 'start_momentum': 'max_momentum', + 'end_momentum': 'base_momentum', + }, + { + 'end_step': float(2 * pct_start * self.total_steps) - 2, + 'start_lr': 'max_lr', + 'end_lr': 'initial_lr', + 'start_momentum': 'base_momentum', + 'end_momentum': 'max_momentum', + }, + { + 'end_step': self.total_steps - 1, + 'start_lr': 'initial_lr', + 'end_lr': 'min_lr', + 'start_momentum': 'max_momentum', + 'end_momentum': 'max_momentum', + }, + ] + else: + self._schedule_phases = [ + { + 'end_step': float(pct_start * self.total_steps) - 1, + 'start_lr': 'initial_lr', + 'end_lr': 'max_lr', + 'start_momentum': 'max_momentum', + 'end_momentum': 'base_momentum', + }, + { + 'end_step': self.total_steps - 1, + 'start_lr': 'max_lr', + 'end_lr': 'min_lr', + 'start_momentum': 'base_momentum', + 'end_momentum': 'max_momentum', + }, + ] # Validate pct_start if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): @@ -1248,17 +1297,16 @@ def get_lr(self): .format(step_num + 1, self.total_steps)) for group in self.optimizer.param_groups: - if step_num <= self.step_size_up: - computed_lr = self.anneal_func(group['initial_lr'], group['max_lr'], step_num / self.step_size_up) - if self.cycle_momentum: - computed_momentum = self.anneal_func(group['max_momentum'], group['base_momentum'], - step_num / self.step_size_up) - else: - down_step_num = step_num - self.step_size_up - computed_lr = self.anneal_func(group['max_lr'], group['min_lr'], down_step_num / self.step_size_down) - if self.cycle_momentum: - computed_momentum = self.anneal_func(group['base_momentum'], group['max_momentum'], - down_step_num / self.step_size_down) + start_step = 0 + for i, phase in enumerate(self._schedule_phases): + end_step = phase['end_step'] + if step_num <= end_step or i == len(self._schedule_phases) - 1: + pct = (step_num - start_step) / (end_step - start_step) + computed_lr = self.anneal_func(group[phase['start_lr']], group[phase['end_lr']], pct) + if self.cycle_momentum: + computed_momentum = self.anneal_func(group[phase['start_momentum']], group[phase['end_momentum']], pct) + break + start_step = phase['end_step'] lrs.append(computed_lr) if self.cycle_momentum: diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index 7d413b9594158..b3e38c613fe0c 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -5,6 +5,7 @@ from copy import deepcopy from itertools import chain import warnings +import functools class _RequiredParameter(object): @@ -23,7 +24,7 @@ class Optimizer(object): ordering that is consistent between runs. Examples of objects that don't satisfy those properties are sets and iterators over values of dictionaries. - Arguments: + Args: params (iterable): an iterable of :class:`torch.Tensor` s or :class:`dict` s. Specifies what Tensors should be optimized. defaults: (dict): a dict containing default values of optimization @@ -34,6 +35,8 @@ def __init__(self, params, defaults): torch._C._log_api_usage_once("python.optimizer") self.defaults = defaults + self._hook_for_profile() + if isinstance(params, torch.Tensor): raise TypeError("params argument given to the optimizer should be " "an iterable of Tensors or dicts, but got " + @@ -60,6 +63,7 @@ def __getstate__(self): def __setstate__(self, state): self.__dict__.update(state) + self._hook_for_profile() # To support multiprocessing pickle/unpickle. def __repr__(self): format_string = self.__class__.__name__ + ' (' @@ -72,6 +76,24 @@ def __repr__(self): format_string += ')' return format_string + def _hook_for_profile(self): + self._zero_grad_profile_name = "Optimizer.zero_grad#{}.zero_grad".format(self.__class__.__name__) + + def profile_hook_step(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + obj, *_ = args + profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__) + with torch.autograd.profiler.record_function(profile_name): + return func(*args, **kwargs) + return wrapper + + hooked = getattr(self.__class__.step, "hooked", None) + if not hooked: + self.__class__.step = profile_hook_step(self.__class__.step) + self.__class__.step.hooked = True + def state_dict(self): r"""Returns the state of the optimizer as a :class:`dict`. @@ -105,7 +127,7 @@ def pack_group(group): def load_state_dict(self, state_dict): r"""Loads the optimizer state. - Arguments: + Args: state_dict (dict): optimizer state. Should be an object returned from a call to :meth:`state_dict`. """ @@ -167,7 +189,7 @@ def update_group(group, new_group): def zero_grad(self, set_to_none: bool = False): r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero. - Arguments: + Args: set_to_none (bool): instead of setting to zero, set the grads to None. This is will in general have lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example: @@ -179,22 +201,25 @@ def zero_grad(self, set_to_none: bool = False): (in one case it does the step with a gradient of 0 and in the other it skips the step altogether). """ - for group in self.param_groups: - for p in group['params']: - if p.grad is not None: - if set_to_none: - p.grad = None - else: - if p.grad.grad_fn is not None: - p.grad.detach_() + if not hasattr(self, "_zero_grad_profile_name"): + self._hook_for_profile() + with torch.autograd.profiler.record_function(self._zero_grad_profile_name): + for group in self.param_groups: + for p in group['params']: + if p.grad is not None: + if set_to_none: + p.grad = None else: - p.grad.requires_grad_(False) - p.grad.zero_() + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + p.grad.zero_() def step(self, closure): r"""Performs a single optimization step (parameter update). - Arguments: + Args: closure (callable): A closure that reevaluates the model and returns the loss. Optional for most optimizers. @@ -210,7 +235,7 @@ def add_param_group(self, param_group): This can be useful when fine tuning a pre-trained network as frozen layers can be made trainable and added to the :class:`Optimizer` as training progresses. - Arguments: + Args: param_group (dict): Specifies what Tensors should be optimized along with group specific optimization options. """ diff --git a/torch/optim/optimizer.pyi b/torch/optim/optimizer.pyi index aa50a6fd1027e..6202050f3493e 100644 --- a/torch/optim/optimizer.pyi +++ b/torch/optim/optimizer.pyi @@ -10,7 +10,7 @@ class Optimizer: param_groups: List[dict] def __init__(self, params: _params_t, default: dict) -> None: ... - def __setstate__(self, statue: dict) -> None: ... + def __setstate__(self, state: dict) -> None: ... def state_dict(self) -> dict: ... def load_state_dict(self, state_dict: dict) -> None: ... def zero_grad(self, set_to_none: Optional[bool]=...) -> None: ... diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index ea02096b54601..d19f87fda0b6b 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -17,7 +17,7 @@ class RMSprop(Optimizer): is the scheduled learning rate and :math:`v` is the weighted moving average of the squared gradient. - Arguments: + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-2) @@ -56,7 +56,7 @@ def __setstate__(self, state): def step(self, closure=None): """Performs a single optimization step. - Arguments: + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py index 071f104b958f8..ec2a5f1f222ac 100644 --- a/torch/optim/rprop.py +++ b/torch/optim/rprop.py @@ -5,7 +5,7 @@ class Rprop(Optimizer): """Implements the resilient backpropagation algorithm. - Arguments: + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-2) @@ -29,7 +29,7 @@ def __init__(self, params, lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50)): def step(self, closure=None): """Performs a single optimization step. - Arguments: + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index 9c6af5c3aa8bb..5b071f820ad6a 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -76,7 +76,7 @@ def __setstate__(self, state): def step(self, closure=None): """Performs a single optimization step. - Arguments: + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ diff --git a/torch/optim/sparse_adam.py b/torch/optim/sparse_adam.py index e1315e370269f..05a576d6544ca 100644 --- a/torch/optim/sparse_adam.py +++ b/torch/optim/sparse_adam.py @@ -9,7 +9,7 @@ class SparseAdam(Optimizer): In this variant, only moments that show up in the gradient get updated, and only those portions of the gradient get applied to the parameters. - Arguments: + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) @@ -32,6 +32,8 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8): if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + params = list(params) + sparse_params = [] for index, param in enumerate(params): if isinstance(param, dict): @@ -52,7 +54,7 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8): def step(self, closure=None): """Performs a single optimization step. - Arguments: + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index e9cd8d8cfc237..65b694a27cf25 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -17,7 +17,7 @@ class AveragedModel(Module): on the device :attr:`device` and allows to compute running averages of the parameters of the :attr:`model`. - Arguments: + Args: model (torch.nn.Module): model to use with SWA device (torch.device, optional): if provided, the averaged model will be stored on the :attr:`device` @@ -117,7 +117,7 @@ def update_bn(loader, model, device=None): It performs one pass over data in `loader` to estimate the activation statistics for BatchNorm layers in the model. - Arguments: + Args: loader (torch.utils.data.DataLoader): dataset loader to compute the activation statistics on. Each data batch should be either a tensor, or a list/tuple whose first element is a tensor @@ -172,7 +172,7 @@ class SWALR(_LRScheduler): This learning rate scheduler is meant to be used with Stochastic Weight Averaging (SWA) method (see `torch.optim.swa_utils.AveragedModel`). - Arguments: + Args: optimizer (torch.optim.Optimizer): wrapped optimizer swa_lrs (float or list): the learning rate value for all param groups together or separately for each group. @@ -219,8 +219,8 @@ def __init__(self, optimizer, swa_lr, anneal_epochs=10, anneal_strategy='cos', l self.anneal_func = self._cosine_anneal elif anneal_strategy == 'linear': self.anneal_func = self._linear_anneal - if not isinstance(anneal_epochs, int) or anneal_epochs < 1: - raise ValueError("anneal_epochs must be a positive integer, got {}".format( + if not isinstance(anneal_epochs, int) or anneal_epochs < 0: + raise ValueError("anneal_epochs must be equal or greater than 0, got {}".format( anneal_epochs)) self.anneal_epochs = anneal_epochs @@ -257,11 +257,13 @@ def get_lr(self): warnings.warn("To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning) step = self._step_count - 1 - prev_t = max(0, min(1, (step - 1) / self.anneal_epochs)) + if self.anneal_epochs == 0: + step = max(1, step) + prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs))) prev_alpha = self.anneal_func(prev_t) prev_lrs = [self._get_initial_lr(group['lr'], group['swa_lr'], prev_alpha) for group in self.optimizer.param_groups] - t = max(0, min(1, step / self.anneal_epochs)) + t = max(0, min(1, step / max(1, self.anneal_epochs))) alpha = self.anneal_func(t) return [group['swa_lr'] * alpha + lr * (1 - alpha) for group, lr in zip(self.optimizer.param_groups, prev_lrs)] diff --git a/torch/overrides.py b/torch/overrides.py index d5f247e5d51a4..cbb33c8d28484 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -1,14 +1,16 @@ """ -Python implementation of __torch_function__ +Python implementation of ``__torch_function__`` -While most of the torch API and handling for __torch_function__ happens +While most of the torch API and handling for ``__torch_function__`` happens at the C++ level, some of the torch API is written in Python so we need -python-level handling for __torch_function__ overrides as well. The main +python-level handling for ``__torch_function__`` overrides as well. The main developer-facing functionality in this file are handle_torch_function and has_torch_function. See torch/functional.py and test/test_overrides.py for usage examples. -NOTE: heavily inspired by NumPy's ``__array_function__`` (see: +Note +---- +heavily inspired by NumPy's ``__array_function__`` (see: https://github.com/pytorch/pytorch/issues/24015 and https://www.numpy.org/neps/nep-0018-array-function-protocol.html ) @@ -26,17 +28,39 @@ from typing import Dict, Set, List, Any, Callable, Iterable import torch -from torch._C import _is_torch_function_enabled, _disabled_torch_function_impl +from torch._C import ( + _has_torch_function, _has_torch_function_unary, + _has_torch_function_variadic, _add_docstr) + +__all__ = [ + "get_ignored_functions", + "get_overridable_functions", + "get_testing_overrides", + "handle_torch_function", + "has_torch_function", + "is_tensor_like", + "is_tensor_method_or_property", + "wrap_torch_function", +] @functools.lru_cache(None) def get_ignored_functions() -> Set[Callable]: - """Return public functions that cannot be overridden by __torch_function__ + """ + Return public functions that cannot be overridden by ``__torch_function__``. Returns ------- - A tuple of functions that are publicly available in the torch API but cannot - be overridden with __torch_function__. Mostly this is because none of the - arguments of these functions are tensors or tensor-likes. + Tuple[Callable] + A tuple of functions that are publicly available in the torch API but cannot + be overridden with ``__torch_function__``. Mostly this is because none of the + arguments of these functions are tensors or tensor-likes. + + Examples + -------- + >>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions() + True + >>> torch.add in torch.overrides.get_ignored_functions() + False """ Tensor = torch.Tensor return { @@ -97,6 +121,7 @@ def get_ignored_functions() -> Set[Callable]: torch.as_strided, torch.bartlett_window, torch.blackman_window, + torch.broadcast_shapes, torch.can_cast, torch.cudnn_affine_grid_generator, torch.cudnn_batch_norm, @@ -109,6 +134,8 @@ def get_ignored_functions() -> Set[Callable]: torch.empty_strided, torch.empty_quantized, torch.eye, + torch.fft.fftfreq, + torch.fft.rfftfreq, torch.from_file, torch.full, torch.hamming_window, @@ -142,6 +169,8 @@ def get_ignored_functions() -> Set[Callable]: torch.nn.functional.upsample_bilinear, torch.nn.functional.upsample_nearest, torch.nn.functional.has_torch_function, + torch.nn.functional.has_torch_function_unary, + torch.nn.functional.has_torch_function_variadic, torch.nn.functional.handle_torch_function, torch.nn.functional.sigmoid, torch.nn.functional.hardsigmoid, @@ -156,6 +185,8 @@ def get_ignored_functions() -> Set[Callable]: torch.is_deterministic, torch.set_deterministic, torch.unify_type_list, + torch.make_dual, + torch.unpack_dual, Tensor.__delitem__, Tensor.__dir__, Tensor.__getattribute__, @@ -172,6 +203,7 @@ def get_ignored_functions() -> Set[Callable]: Tensor.new, Tensor.new_tensor, Tensor.new_empty, + Tensor.new_empty_strided, Tensor.new_zeros, Tensor.new_ones, Tensor.new_full, @@ -187,12 +219,20 @@ def get_testing_overrides() -> Dict[Callable, Callable]: Returns ------- - A dictionary that maps overridable functions in the PyTorch API to - lambda functions that have the same signature as the real function - and unconditionally return -1. These lambda functions are useful - for testing API coverage for a type that defines __torch_function__. + Dict[Callable, Callable] + A dictionary that maps overridable functions in the PyTorch API to + lambda functions that have the same signature as the real function + and unconditionally return -1. These lambda functions are useful + for testing API coverage for a type that defines ``__torch_function__``. + + Examples + -------- + >>> import inspect + >>> my_add = torch.overrides.get_testing_overrides()[torch.add] + >>> inspect.signature(my_add) + """ - # Every function in the PyTorch API that can be overriden needs an entry + # Every function in the PyTorchAPI that can be overriden needs an entry # in this dict. # # Optimally we would use inspect to get the function signature and define @@ -209,7 +249,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.arccos: lambda input, out=None: -1, torch.acosh: lambda input, out=None: -1, torch.arccosh: lambda input, out=None: -1, - torch.add_relu: lambda input, other, out=None: -1, torch.add: lambda input, other, out=None: -1, torch.addbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1, torch.addcdiv: lambda input, tensor1, tensor2, value=1, out=None: -1, @@ -237,12 +276,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.atan2: lambda input, other, out=None: -1, torch.atanh: lambda input, out=None: -1, torch.arctanh: lambda input, out=None: -1, - torch.atleast_1d: lambda input: -1, - torch.atleast_2d: lambda input: -1, - torch.atleast_3d: lambda input: -1, - torch.atleast_1d: lambda *inputs: -1, - torch.atleast_2d: lambda *inputs: -1, - torch.atleast_3d: lambda *inputs: -1, + torch.atleast_1d: lambda *tensors: -1, + torch.atleast_2d: lambda *tensors: -1, + torch.atleast_3d: lambda *tensors: -1, torch.avg_pool1d: lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True: -1, torch.baddbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1, torch.batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled: -1, @@ -266,15 +302,17 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.block_diag: lambda *tensors: -1, torch.bmm: lambda input, mat2, out=None: -1, torch.broadcast_tensors: lambda *tensors: -1, + torch.broadcast_to: lambda self, size: -1, torch.bucketize: lambda input, boundaries, out_int32=False, right=False, out=None: -1, torch.cartesian_prod: lambda *tensors: -1, torch.cat: lambda tensors, dim=0, out=None: -1, - torch.cdist: lambda x1, c2, p=2, compute_mode=None: -1, + torch.cdist: lambda x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary': -1, torch.ceil: lambda input, out=None: -1, torch.celu: lambda input, alhpa=1., inplace=False: -1, torch.chain_matmul: lambda *matrices: -1, torch.channel_shuffle: lambda input, groups : -1, torch.cholesky: lambda input, upper=False, out=None: -1, + torch.linalg.cholesky: lambda input, out=None: -1, torch.cholesky_inverse: lambda input, upper=False, out=None: -1, torch.cholesky_solve: lambda input1, input2, upper=False, out=None: -1, torch.choose_qparams_optimized: lambda input, numel, n_bins, ratio, bit_width: -1, @@ -283,10 +321,13 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.clip: lambda input, min=None, max=None, out=None: -1, torch.clamp_min: lambda input, min, out=None: -1, torch.clamp_max: lambda input, max, out=None: -1, + torch.column_stack: lambda tensors, out=None: -1, torch.clone: lambda input: -1, torch.combinations: lambda input, r=2, with_replacement=False: -1, torch.complex: lambda real, imag: -1, + torch.copysign: lambda input, other, out=None: -1, torch.polar: lambda abs, ang: -1, + torch.linalg.cond: lambda input, ord=None: -1, torch.conj: lambda input, out=None: -1, torch.constant_pad_nd: lambda input, pad, value=0: -1, torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1, @@ -323,12 +364,14 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.dist: lambda input, other, p=2: -1, torch.div: lambda input, other, out=None: -1, torch.divide: lambda input, other, out=None: -1, - torch.dot: lambda mat1, mat2: -1, + torch.dot: lambda input, other, out=None: -1, torch.dropout: lambda input, p, train, inplace=False: -1, torch.dsmm: lambda input, mat2: -1, torch.hsmm: lambda mat1, mat2: -1, torch.dstack: lambda tensors, out=None: -1, torch.eig: lambda input, eigenvectors=False, out=None: -1, + torch.linalg.eigh: lambda input, UPLO="L", out=None: -1, + torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1, torch.einsum: lambda equation, *operands: -1, torch.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1), @@ -355,8 +398,24 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1, torch.feature_alpha_dropout: lambda input, p, train: -1, torch.feature_dropout: lambda input, p, train: -1, + torch.fft.fft: lambda input, n=None, dim=-1, norm=None: -1, + torch.fft.ifft: lambda input, n=None, dim=-1, norm=None: -1, + torch.fft.rfft: lambda input, n=None, dim=-1, norm=None: -1, + torch.fft.irfft: lambda input, n=None, dim=-1, norm=None: -1, + torch.fft.hfft: lambda input, n=None, dim=-1, norm=None: -1, + torch.fft.ihfft: lambda input, n=None, dim=-1, norm=None: -1, + torch.fft.fftn: lambda input, s=None, dim=None, norm=None: -1, + torch.fft.ifftn: lambda input, s=None, dim=None, norm=None: -1, + torch.fft.rfftn: lambda input, s=None, dim=None, norm=None: -1, + torch.fft.irfftn: lambda input, s=None, dim=None, norm=None: -1, + torch.fft.fft2: lambda input, s=None, dim=(-2, -1), norm=None: -1, + torch.fft.ifft2: lambda input, s=None, dim=(-2, -1), norm=None: -1, + torch.fft.rfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1, + torch.fft.irfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1, + torch.fft.fftshift: lambda input, dim=None: -1, + torch.fft.ifftshift: lambda input, dim=None: -1, + torch.fft.fft: lambda input, n=None, dim=-1, norm=None: -1, torch.fix: lambda input, out=None: -1, - torch.fft: lambda input, signal_ndim, normalized=False: -1, torch.flatten: lambda input, start_dim=0, end_dim=-1: -1, torch.flip: lambda input, dims: -1, torch.fliplr: lambda input: -1, @@ -364,6 +423,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.frobenius_norm: lambda input, dim=None, keepdim=False, out=None: -1, torch.floor: lambda input, out=None: -1, torch.floor_divide: lambda input, other: -1, + torch.float_power: lambda input, exponent, out=None: -1, torch.fmod: lambda input, other, out=None: -1, torch.frac: lambda input, out=None: -1, torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1, @@ -374,6 +434,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.greater_equal: lambda input, other, out=None: -1, torch.geqrf: lambda input, out=None: -1, torch.i0: lambda input, out=None: -1, + torch.inner: lambda input, other, out=None: -1, torch.outer: lambda input, vec2, out=None: -1, # alias for torch.ger torch.ger: lambda input, vec2, out=None: -1, torch.grid_sampler: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1, @@ -391,7 +452,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.hspmm: lambda mat1, mat2, out=None: -1, torch.hstack: lambda tensors, out=None: -1, torch.hypot: lambda input, other, out=None: -1, - torch.ifft: lambda input, signal_ndim, normalized=False: -1, + torch.igamma: lambda input, other, out=None: -1, + torch.igammac: lambda input, other, out=None: -1, torch.imag: lambda input, out=None: -1, torch.index_add: lambda input, dim, index, source: -1, torch.index_copy: lambda input, dim, index, source: -1, @@ -407,7 +469,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: cudnn_enabled: -1), torch.int_repr: lambda input: -1, torch.inverse: lambda input, out=None: -1, - torch.irfft: lambda input, signal_ndim, normalized=False, onesided=True, signal_sizes=None: -1, + torch.linalg.inv: lambda input, out=None: -1, torch.is_complex: lambda input: -1, torch.is_distributed: lambda input: -1, torch.is_floating_point: lambda input: -1, @@ -417,11 +479,13 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1, torch.isnan: lambda input: -1, torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, - normalized=False, onesided=True, length=None, return_complex=False: -1), + normalized=False, onesided=None, length=None, return_complex=False: -1), torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1, + torch.kron: lambda input, other: -1, torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1, torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, esp=1e-05, cudnn_enabled=True: -1, torch.lcm: lambda input, other, out=None: -1, + torch.ldexp: lambda input, other, out=None: -1, torch.le: lambda input, other, out=None: -1, torch.less_equal: lambda input, other, out=None: -1, torch.lerp: lambda input, end, weight, out=None: -1, @@ -436,6 +500,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.logaddexp: lambda input, other, out=None: -1, torch.logaddexp2: lambda input, other, out=None: -1, torch.logdet: lambda input: -1, + torch.xlogy: lambda x, y: -1, torch.logical_and: lambda input, other, out=None: -1, torch.logical_not: lambda input, out=None: -1, torch.logical_or: lambda input, other, out=None: -1, @@ -449,7 +514,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.lt: lambda input, other, out=None: -1, torch.less: lambda input, other, out=None: -1, torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1, - torch.lu_solve: lambda input, LU_data, LU_pivots, out=None: -1, + torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1, torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1, torch.masked_fill: lambda input, mask, value: -1, torch.masked_scatter: lambda input, mask, source: -1, @@ -457,16 +522,18 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.matmul: lambda input, other, out=None: -1, torch.matrix_power: lambda input, n: -1, torch.matrix_rank: lambda input, tol=None, symmetric=False: -1, + torch.linalg.matrix_rank: lambda input, tol=None, hermitian=False: -1, torch.matrix_exp: lambda input: -1, torch.max: lambda input, out=None: -1, torch.maximum: lambda input, other, out=None: -1, - torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1, - torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1, - torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1, + torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1, + torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1, + torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1, torch.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1), torch.mean: lambda input, dim=None: -1, torch.median: lambda input, dim=None: -1, + torch.nanmedian: lambda input, dim=None: -1, torch.meshgrid: lambda *tensors, **kwargs: -1, torch.min: lambda input, out=None: -1, torch.minimum: lambda input, other, out=None: -1, @@ -482,14 +549,18 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.mm: lambda input, mat2, out=None: -1, torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1, torch.movedim: lambda input, source, destination: -1, + torch.moveaxis: lambda input, source, destination: -1, + torch.msort: lambda input, descending=False, out=None: -1, torch.mul: lambda input, other, out=None: -1, torch.multiply: lambda input, other, out=None: -1, torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1, torch.mv: lambda input, vec, out=None: -1, torch.mvlgamma: lambda input, p: -1, torch.narrow: lambda input, dim, start, length: -1, + torch.narrow_copy: lambda input, dim, start, length: -1, + torch.nan_to_num: lambda input, nan=0.0, posinf=None, neginf=None, out=None: -1, torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1, - torch.native_layer_norm: lambda input, weight, bias, M, N, eps: -1, + torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1, torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1, torch.native_norm: lambda input, p=2: -1, torch.native_norm: lambda input, p=2: -1, @@ -574,11 +645,11 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.nn.functional.margin_ranking_loss: (lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1), torch.nn.functional.max_pool1d: (lambda input, kernel_size, stride=None, padding=0, dilation=1, - return_indices=False, ceil_mode=False: -1), + ceil_mode=False, return_indices=False: -1), torch.nn.functional.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1), torch.nn.functional.max_pool2d: (lambda input, kernel_size, stride=None, padding=0, dilation=1, - return_indices=False, ceil_mode=False: -1), + ceil_mode=False, return_indices=False: -1), torch.nn.functional.max_pool2d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1), torch.nn.functional.max_pool3d: (lambda input, kernel_size, stride=None, padding=0, dilation=1, @@ -614,7 +685,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1, torch.nn.functional.selu: lambda input, inplace=False: -1, torch.nn.functional.silu: lambda input, inplace=False: -1, - torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1, + torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean', beta=1.: -1, torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1, torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1, torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1, @@ -640,7 +711,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.pca_lowrank: lambda input, q=None, center=True, niter=2: -1, torch.pdist: lambda input, p=2: -1, torch.pinverse: lambda input, rcond=1e-15: -1, + torch.linalg.pinv: lambda input, rcond=1e-15, hermitian=False: -1, torch.pixel_shuffle: lambda input, upscale_factor: -1, + torch.pixel_unshuffle: lambda input, downscale_factor: -1, torch.poisson: lambda input, generator=None: -1, torch.poisson_nll_loss: lambda input, target, log_input, full, eps, reduction: -1, torch.polygamma: lambda input, n, out=None: -1, @@ -654,6 +727,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.q_scale: lambda input: -1, torch.q_zero_point: lambda input: -1, torch.qr: lambda input, some=True, out=None: -1, + torch.linalg.qr: lambda input, mode='reduced', out=None: -1, torch.quantile: lambda input, q, dim=None, keepdim=False, out=None: -1, torch.nanquantile: lambda input, q, dim=None, keepdim=False, out=None: -1, torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1, @@ -664,6 +738,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.quantized_lstm_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1), + torch.quantized_max_pool1d: (lambda input, kernel_size, stride=tuple(), padding=(0,), + dilation=(1,), ceil_mode=False: -1), torch.quantized_max_pool2d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0), dilation=(1, 1), ceil_mode=False: -1), torch.quantized_rnn_relu_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, @@ -674,8 +750,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1, torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, + torch.ravel: lambda input: -1, torch.real: lambda input, out=None: -1, - torch.vdot: lambda mat1, mat2: -1, + torch.vdot: lambda input, other, out=None: -1, torch.view_as_real: lambda input: -1, torch.view_as_complex: lambda input: -1, torch.reciprocal: lambda input, out=None: -1, @@ -684,7 +761,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.renorm: lambda input, p, dim, maxnorm, out=None: -1, torch.repeat_interleave: lambda input, dim=None: -1, torch.reshape: lambda input, shape: -1, - torch.rfft: lambda input, signal_ndim, normalized=False, onesided=True: -1, torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1, torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, @@ -692,6 +768,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.roll: lambda input, shifts, dims=None: -1, torch.rot90: lambda input, k=1, dims=(0, 1): -1, torch.round: lambda input, out=None: -1, + torch.row_stack: lambda tensors, out=None: -1, # alias for torch.vstack torch.rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1), torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1, torch.rsqrt: lambda input, out=None: -1, @@ -707,13 +784,16 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.signbit: lambda input, out=None: -1, torch.sgn: lambda input, out=None: -1, torch.sin: lambda input, out=None: -1, + torch.sinc: lambda input, out=None: -1, torch.sinh: lambda input, out=None: -1, torch.slogdet: lambda input: -1, + torch.linalg.slogdet: lambda input: -1, torch.smm: lambda input, mat2: -1, torch.spmm: lambda input, mat2: -1, torch.softmax: lambda input, dim, dtype=None: -1, torch.solve: lambda input, A, out=None: -1, - torch.sort: lambda input, dim=-1, descending=False, out=None: -1, + torch.linalg.solve: lambda input, other, out=None: -1, + torch.sort: lambda input, dim=-1, descending=False, stable=False, out=None: -1, torch.split: lambda tensor, split_size_or_sections, dim=0: -1, torch.split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1, torch.sqrt: lambda input, out=None: -1, @@ -731,13 +811,20 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.nansum: lambda input, dim=None: -1, torch.svd: lambda input, some=True, compute_uv=True, out=None: -1, torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1, + torch.linalg.svd: lambda input, full_matrices=True, compute_uv=True, out=None: -1, torch.symeig: lambda input, eigenvectors=False, upper=True, out=None: -1, + torch.swapaxes: lambda input, dim0, dim1: -1, + torch.swapdims: lambda input, axis0, axis1: -1, torch.t: lambda input: -1, torch.take: lambda input, index: -1, torch.tan: lambda input, out=None: -1, torch.tanh: lambda input, out=None: -1, - torch.tensordot: lambda a, b, dims=2: -1, + torch.linalg.tensorinv: lambda a, ind=2: -1, + torch.linalg.tensorsolve: lambda a, b, dims=None: -1, + torch.tensordot: lambda a, b, dims=2, out=None: -1, + torch.tensor_split: lambda input, indices_or_sections, dim=0: -1, torch.threshold: lambda input, threshold, value, inplace=False: -1, + torch.tile: lambda input, dims: -1, torch.topk: lambda input, k, dim=-1, descending=False, out=None: -1, torch.trace: lambda input: -1, torch.transpose: lambda input, dim0, dim1: -1, @@ -779,6 +866,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: Tensor.__neg__: lambda self: -1, Tensor.__invert__: lambda self: -1, Tensor.__mod__: lambda self, other: -1, + Tensor.__imod__: lambda self, other: -1, Tensor.__array_wrap__: lambda self, array: -1, Tensor.__getitem__: lambda self, idx: -1, Tensor.__deepcopy__: lambda self, memo: -1, @@ -812,6 +900,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: Tensor.is_mkldnn.__get__: lambda self: -1, Tensor.is_quantized.__get__: lambda self: -1, Tensor.is_sparse.__get__: lambda self: -1, + Tensor.is_vulkan.__get__: lambda self: -1, Tensor.layout.__get__: lambda self: -1, Tensor.name.__get__: lambda self: -1, Tensor.names.__get__: lambda self: -1, @@ -837,7 +926,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: Tensor.apply_: lambda self, callable: -1, Tensor.as_strided: lambda self, size, stride: -1, Tensor.as_strided_: lambda self, size, stride: -1, - Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False: -1, + Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False, inputs=None: -1, Tensor.bfloat16: lambda self, memory_format=torch.preserve_format: -1, Tensor.bool: lambda self, memory_format=torch.preserve_format: -1, Tensor.byte: lambda self, memory_format=torch.preserve_format: -1, @@ -912,6 +1001,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: Tensor.storage_offset: lambda self: -1, Tensor.storage_type: lambda self: -1, Tensor.sum_to_size: lambda self, size: -1, + Tensor.tile: lambda self, *reps: -1, Tensor.to: lambda self, dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format: -1, Tensor.to_dense: lambda self: -1, Tensor.to_sparse: lambda self: -1, @@ -924,6 +1014,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: Tensor.view: lambda self, shape: -1, Tensor.view_as: lambda self, other: -1, Tensor.zero_: lambda self: -1, + torch.linalg.norm: lambda self: -1 } ret2 = {} @@ -957,6 +1048,41 @@ def get_testing_overrides() -> Dict[Callable, Callable]: ret.update(ret2) return ret +def wrap_torch_function(dispatcher: Callable): + """Wraps a given function with ``__torch_function__`` -related functionality. + + Parameters + ---------- + dispatcher: Callable + A callable that returns an iterable of Tensor-likes passed into the function. + + Note + ---- + This decorator may reduce the performance of your code. Generally, it's enough to express + your code as a series of functions that, themselves, support __torch_function__. If you + find yourself in the rare situation where this is not the case, e.g. if you're wrapping a + low-level library and you also need it to work for Tensor-likes, then this function is available. + + Examples + -------- + >>> def dispatcher(a): # Must have the same signature as func + ... return (a,) + >>> @torch.overrides.wrap_torch_function(dispatcher) + >>> def func(a): # This will make func dispatchable by __torch_function__ + ... return a + 0 + """ + def inner(func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + relevant_args = dispatcher(*args, **kwargs) + if has_torch_function(relevant_args): + return handle_torch_function(func, relevant_args, *args, **kwargs) + + return func(*args, **kwargs) + + return wrapped + + return inner def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]: """Returns a list of arguments on which to call __torch_function__. @@ -982,18 +1108,15 @@ def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]: Returns ------- - overloaded_types : collection of types - Types of arguments from relevant_args with __torch_function__ methods. overloaded_args : list Arguments from relevant_args on which to call __torch_function__ methods, in the order in which they should be called. .. _NEP-0018: https://numpy.org/neps/nep-0018-array-function-protocol.html - """ # Runtime is O(num_arguments * num_unique_types) - overloaded_types = [] + overloaded_types = set() overloaded_args = [] for arg in relevant_args: arg_type = type(arg) @@ -1004,7 +1127,7 @@ def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]: # Create lists explicitly for the first type (usually the only one # done) to avoid setting up the iterator for overloaded_args. if overloaded_types: - overloaded_types.append(arg_type) + overloaded_types.add(arg_type) # By default, insert argument at the end, but if it is # subclass of another argument, insert it before that argument. # This ensures "subclasses before superclasses". @@ -1015,7 +1138,7 @@ def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]: break overloaded_args.insert(index, arg) else: - overloaded_types = [arg_type] + overloaded_types = {arg_type} overloaded_args = [arg] return overloaded_args @@ -1023,7 +1146,7 @@ def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]: def handle_torch_function( public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs) -> Any: - """Implement a function with checks for __torch_function__ overrides. + """Implement a function with checks for ``__torch_function__`` overrides. See torch::autograd::handle_torch_function for the equivalent of this function in the C++ implementation. @@ -1043,13 +1166,20 @@ def handle_torch_function( Returns ------- - Result from calling `implementation()` or an `__torch_function__` - method, as appropriate. + object + Result from calling ``implementation`` or an ``__torch_function__`` + method, as appropriate. Raises ------ TypeError : if no implementation is found. + Example + ------- + >>> def func(a): + ... if type(a) is not torch.Tensor: # This will make func dispatchable by __torch_function__ + ... return handle_torch_function(func, (a,), a) + ... return a + 0 """ # Check for __torch_function__ methods. overloaded_args = _get_overloaded_args(relevant_args) @@ -1068,27 +1198,54 @@ def handle_torch_function( func_name = '{}.{}'.format(public_api.__module__, public_api.__name__) raise TypeError("no implementation found for '{}' on types that implement " '__torch_function__: {}' - .format(func_name, list(map(type, overloaded_args)))) - -def has_torch_function(relevant_args: Iterable[Any]) -> bool: - """Check for __torch_function__ implementations in the elements of an iterable + .format(func_name, [type(arg) for arg in overloaded_args])) +has_torch_function = _add_docstr( + _has_torch_function, + r"""Check for __torch_function__ implementations in the elements of an iterable. + Considers exact ``Tensor`` s and ``Parameter`` s non-dispatchable. Arguments --------- relevant_args : iterable Iterable or aguments to check for __torch_function__ methods. - Returns ------- - True if any of the elements of relevant_args have __torch_function__ - implementations, False otherwise. + bool + True if any of the elements of relevant_args have __torch_function__ + implementations, False otherwise. + See Also + ________ + torch.is_tensor_like + Checks if something is a Tensor-like, including an exact ``Tensor``. """ - return _is_torch_function_enabled() and any( - type(a) is not torch.Tensor and - getattr(a, '__torch_function__', _disabled_torch_function_impl) - is not _disabled_torch_function_impl - for a in relevant_args - ) +) + +has_torch_function_unary = _add_docstr( + _has_torch_function_unary, + r"""Special case of `has_torch_function` for single inputs. + Instead of: + `has_torch_function((t,))` + call: + `has_torch_function_unary(t)` + which skips unnecessary packing and unpacking work. + """ +) + +has_torch_function_variadic = _add_docstr( + _has_torch_function_variadic, + r"""Special case of `has_torch_function` that skips tuple creation. + + This uses the METH_FASTCALL protocol introduced in Python 3.7; for 3.6 + and before it has roughly equivilent performance compared to + `has_torch_function`. + + Instead of: + `has_torch_function((a, b))` + call: + `has_torch_function_variadic(a, b)` + which skips unnecessary packing and unpacking work. + """ +) @functools.lru_cache(None) def get_overridable_functions() -> Dict[Any, List[Callable]]: @@ -1096,15 +1253,18 @@ def get_overridable_functions() -> Dict[Any, List[Callable]]: Returns ------- - A dictionary that maps namespaces that contain overridable functions - to functions in that namespace that can be overridden. + Dict[Any, List[Callable]] + A dictionary that maps namespaces that contain overridable functions + to functions in that namespace that can be overridden. """ overridable_funcs = collections.defaultdict(list) tested_namespaces = [ (torch, torch.__all__ + dir(torch._C._VariableFunctions)), (torch.functional, torch.functional.__all__), (torch.nn.functional, dir(torch.nn.functional)), - (torch.Tensor, dir(torch.Tensor)) + (torch.Tensor, dir(torch.Tensor)), + (torch.linalg, dir(torch.linalg)), + (torch.fft, dir(torch.fft)), ] for namespace, ns_funcs in tested_namespaces: for func_name in ns_funcs: @@ -1151,7 +1311,7 @@ def get_overridable_functions() -> Dict[Any, List[Callable]]: return overridable_funcs @functools.lru_cache(None) -def get_tensor_methods() -> Set[Callable]: +def _get_tensor_methods() -> Set[Callable]: """ Returns a set of the overridable methods on ``torch.Tensor`` """ overridable_funcs = get_overridable_functions() methods = set(overridable_funcs[torch.Tensor]) @@ -1171,14 +1331,47 @@ def is_tensor_method_or_property(func: Callable) -> bool: 1. Methods/properties sometimes don't contain a `__module__` slot. 2. They require that the first passed-in argument is an instance of ``torch.Tensor``. + + Examples + -------- + >>> is_tensor_method_or_property(torch.Tensor.add) + True + >>> is_tensor_method_or_property(torch.add) + False """ - return func in get_tensor_methods() or func.__name__ == "__get__" + return func in _get_tensor_methods() or func.__name__ == "__get__" def is_tensor_like(inp): """ - Returns ``True`` if the passed-in input is a tensor-like. + Returns ``True`` if the passed-in input is a Tensor-like. Currently, this occurs whenever there's a ``__torch_function__`` - attribute on the input. + attribute on the type of the input. + + Examples + -------- + A subclass of tensor is generally a Tensor-like. + + >>> class SubTensor(torch.Tensor): ... + >>> is_tensor_like(SubTensor([0])) + True + + Built-in or user types aren't usually Tensor-like. + + >>> is_tensor_like(6) + False + >>> is_tensor_like(None) + False + >>> class NotATensor: ... + >>> is_tensor_like(NotATensor()) + False + + But, they can be made Tensor-like by implementing __torch_function__. + + >>> class TensorLike: + ... def __torch_function__(self, func, types, args, kwargs): + ... return -1 + >>> is_tensor_like(TensorLike()) + True """ - return type(inp) is torch.Tensor or hasattr(inp, "__torch_function__") + return type(inp) is torch.Tensor or hasattr(type(inp), "__torch_function__") diff --git a/torch/package/_custom_import_pickler.py b/torch/package/_custom_import_pickler.py index fd5787b6b3e31..665eaaebe946f 100644 --- a/torch/package/_custom_import_pickler.py +++ b/torch/package/_custom_import_pickler.py @@ -1,8 +1,14 @@ -from pickle import _Pickler, _getattribute, whichmodule, _extension_registry, _compat_pickle # type: ignore +from pickle import Pickler, whichmodule, _Pickler, _getattribute, _extension_registry, _compat_pickle # type: ignore from pickle import GLOBAL, STACK_GLOBAL, EXT1, EXT2, EXT4, PicklingError +from types import FunctionType from struct import pack +from ._mangling import demangle, is_mangled, get_mangle_prefix +import importlib + class CustomImportPickler(_Pickler): + dispatch = _Pickler.dispatch.copy() + def __init__(self, import_module, *args, **kwargs): self.import_module = import_module super().__init__(*args, **kwargs) @@ -19,7 +25,11 @@ def save_global(self, obj, name=None): if name is None: name = obj.__name__ - module_name = whichmodule(obj, name) + orig_module_name = whichmodule(obj, name) + # CHANGED: demangle the module name before importing. If this obj came + # out of a PackageImporter, `__module__` will be mangled. See + # mangling.md for details. + module_name = demangle(orig_module_name) try: # CHANGED: self.import_module rather than # __import__ @@ -31,9 +41,44 @@ def save_global(self, obj, name=None): (obj, module_name, name)) from None else: if obj2 is not obj: - raise PicklingError( - "Can't pickle %r: it's not the same object as %s.%s" % - (obj, module_name, name)) + # CHANGED: More specific error message in the case of mangling. + obj_module_name = getattr(obj, "__module__", orig_module_name) + obj2_module_name = getattr(obj2, "__module__", orig_module_name) + + msg = f"Can't pickle {obj}: it's not the same object as {obj2_module_name}.{name}." + + is_obj_mangled = is_mangled(obj_module_name) + is_obj2_mangled = is_mangled(obj2_module_name) + + if is_obj_mangled or is_obj2_mangled: + obj_location = ( + get_mangle_prefix(obj_module_name) + if is_obj_mangled + else "the current Python environment" + ) + obj2_location = ( + get_mangle_prefix(obj2_module_name) + if is_obj2_mangled + else "the current Python environment" + ) + + obj_importer_name = ( + f"the importer for {get_mangle_prefix(obj_module_name)}" + if is_obj_mangled + else "'importlib.import_module'" + ) + obj2_importer_name = ( + f"the importer for {get_mangle_prefix(obj2_module_name)}" + if is_obj2_mangled + else "'importlib.import_module'" + ) + + msg += (f"\n\nThe object being pickled is from '{orig_module_name}', " + f"which is coming from {obj_location}." + f"\nHowever, when we import '{orig_module_name}', it's coming from {obj2_location}." + "\nTo fix this, make sure 'PackageExporter.importers' lists " + f"{obj_importer_name} before {obj2_importer_name}") + raise PicklingError(msg) if self.proto >= 2: code = _extension_registry.get((module_name, name)) @@ -76,3 +121,26 @@ def save_global(self, obj, name=None): "pickle protocol %i" % (module, name, self.proto)) from None self.memoize(obj) + dispatch[FunctionType] = save_global + +def import_module_from_importers(module_name, importers): + last_err = None + for import_module in importers: + try: + return import_module(module_name) + except ModuleNotFoundError as err: + last_err = err + + if last_err is not None: + raise last_err + else: + raise ModuleNotFoundError(module_name) + +def create_custom_import_pickler(data_buf, importers): + if importers == [importlib.import_module]: + # if we are using the normal import library system, then + # we can use the C implementation of pickle which is faster + return Pickler(data_buf, protocol=3) + else: + return CustomImportPickler(lambda mod: import_module_from_importers(mod, importers), + data_buf, protocol=3) diff --git a/torch/package/_mangling.py b/torch/package/_mangling.py new file mode 100644 index 0000000000000..74949b3c8f8be --- /dev/null +++ b/torch/package/_mangling.py @@ -0,0 +1,50 @@ +"""Import mangling. +See mangling.md for details. +""" +import re + +_mangle_index = 0 + + +class PackageMangler: + """ + Used on import, to ensure that all modules imported have a shared mangle parent. + """ + + def __init__(self): + global _mangle_index + self._mangle_index = _mangle_index + # Increment the global index + _mangle_index += 1 + # Angle brackets are used so that there is almost no chance of + # confusing this module for a real module. Plus, it is Python's + # preferred way of denoting special modules. + self._mangle_parent = f"" + + def mangle(self, name) -> str: + assert len(name) != 0 + return self._mangle_parent + "." + name + + def parent_name(self): + return self._mangle_parent + + +def is_mangled(name: str) -> bool: + return bool(re.match(r"", name)) + + +def demangle(name: str) -> str: + """ + Note: Unlike PackageMangler.demangle, this version works on any + mangled name, irrespective of which PackageMangler created it. + """ + if is_mangled(name): + first, sep, last = name.partition(".") + # If there is only a base mangle prefix, e.g. '', + # then return an empty string. + return last if len(sep) != 0 else "" + return name + + +def get_mangle_prefix(name: str) -> str: + return name.partition(".")[0] if is_mangled(name) else name diff --git a/torch/package/_mock.py b/torch/package/_mock.py index d291bb58ba5e2..0139db107cdeb 100644 --- a/torch/package/_mock.py +++ b/torch/package/_mock.py @@ -23,7 +23,23 @@ class MockedObject: _name: str - def __init__(self, name): + def __new__(cls, *args, **kwargs): + # _suppress_err is set by us in the mocked module impl, so that we can + # construct instances of MockedObject to hand out to people looking up + # module attributes. + + # Any other attempt to construct a MockedOject instance (say, in the + # unpickling process) should give an error. + if not kwargs.get("_suppress_err"): + raise NotImplementedError(f"Object '{cls._name}' was mocked out during packaging " + f"but it is being used in '__new__'. If this error is " + "happening during 'load_pickle', please ensure that your " + "pickled object doesn't contain any mocked objects.") + # Otherwise, this is just a regular object creation + # (e.g. `x = MockedObject("foo")`), so pass it through normally. + return super().__new__(cls) + + def __init__(self, name: str, _suppress_err: bool): self.__dict__['_name'] = name def __repr__(self): diff --git a/torch/package/exporter.py b/torch/package/exporter.py index 8530f6f68f3a8..de80f401be3c8 100644 --- a/torch/package/exporter.py +++ b/torch/package/exporter.py @@ -1,18 +1,21 @@ import torch -from torch.serialization import normalize_storage_type, location_tag, _should_read_directly +from torch.serialization import normalize_storage_type, location_tag import io -import pickle import pickletools from .find_file_dependencies import find_files_source_depends_on -from ._custom_import_pickler import CustomImportPickler +from ._custom_import_pickler import create_custom_import_pickler, import_module_from_importers from ._importlib import _normalize_path +from ._mangling import is_mangled import types import importlib -from typing import List, Any, Callable, Dict +from typing import List, Any, Callable, Dict, Tuple, Union, Iterable from distutils.sysconfig import get_python_lib from pathlib import Path import linecache import sys +from urllib.parse import quote +import re + class PackageExporter: """ Exporters allow you to write packages of code, pickled python data, and @@ -70,6 +73,8 @@ def __init__(self, filename: str, verbose: bool = True): self.provided : Dict[str, bool] = {} self.verbose = verbose self.importers = [importlib.import_module] + self.patterns : List[Tuple[Any, Callable[[str], None]]] = [] # 'any' is 're.Pattern' but breaks old mypy + self.debug_deps : List[Tuple[str, str]] = [] def save_source_file(self, module_name: str, file_or_directory: str, dependencies=True): """Adds the local file system `file_or_directory` to the source package to provide the code @@ -131,15 +136,9 @@ def save_source_string(self, module_name: str, src: str, is_package: bool = Fals self._write(filename, src) if dependencies: package = module_name if is_package else module_name.rsplit('.', maxsplit=1)[0] - dep_list = find_files_source_depends_on(src, package) - if self.verbose: - def fmt_dep(mod, obj): - return f'{mod}' if obj is None else f'{mod}.{obj}' - dep_str = ''.join(f' {fmt_dep(mod, obj)}\n' for mod, obj in dep_list) - file_info = f'(from file {orig_file_name}) ' if orig_file_name is not None else '' - print(f"{module_name} {file_info}depends on:\n{dep_str}\n") - - for dep_module_name, dep_module_obj in dep_list: + dep_pairs = find_files_source_depends_on(src, package) + dep_list = {} + for dep_module_name, dep_module_obj in dep_pairs: # handle the case where someone did something like `from pack import sub` # where `sub` is a submodule. In this case we don't have to save pack, just sub. # this ensures we don't pick up additional dependencies on pack. @@ -148,25 +147,63 @@ def fmt_dep(mod, obj): if dep_module_obj is not None: possible_submodule = f'{dep_module_name}.{dep_module_obj}' if self._module_exists(possible_submodule): - self.require_module_if_not_provided(possible_submodule) + dep_list[possible_submodule] = True # we don't need to save `pack` continue if self._module_exists(dep_module_name): - self.require_module_if_not_provided(dep_module_name) + dep_list[dep_module_name] = True + + for dep in dep_list.keys(): + self.debug_deps.append((module_name, dep)) + + if self.verbose: + dep_str = ''.join(f' {dep}\n' for dep in dep_list.keys()) + file_info = f'(from file {orig_file_name}) ' if orig_file_name is not None else '' + print(f"{module_name} {file_info}depends on:\n{dep_str}\n") + + for dep in dep_list.keys(): + self.require_module_if_not_provided(dep) + + def _import_module(self, module_name: str): + try: + return import_module_from_importers(module_name, self.importers) + except ModuleNotFoundError as e: + if not is_mangled(module_name): + raise + msg = (f"Module not found: '{module_name}'. Modules imported " + "from a torch.package cannot be re-exported directly.") + raise ModuleNotFoundError(msg) from None def _module_exists(self, module_name: str) -> bool: try: self._import_module(module_name) return True - except ModuleNotFoundError: + except Exception: return False + def _write_dep_graph(self, failing_module=None): + edges = '\n'.join(f'"{f}" -> "{t}";' for f, t in self.debug_deps) + failing = '' if failing_module is None else f'"{failing_module}" [color=red];' + template = f"""\ +digraph G {{ +rankdir = LR; +node [shape=box]; +{failing} +{edges} +}} +""" + arg = quote(template, safe='') + return f'https://dreampuf.github.io/GraphvizOnline/#{arg}' + def _get_source_of_module(self, module: types.ModuleType) -> str: filename = getattr(module, '__file__', None) - result = None if filename is None else linecache.getlines(filename, module.__dict__) + result = None if filename is None or not filename.endswith('.py') else linecache.getlines(filename, module.__dict__) if result is None: + extra = '' + if self.verbose: + extra = f' See the dependency graph for more info: \n{self._write_dep_graph(module.__name__)}' raise ValueError(f'cannot save source for module "{module.__name__}" because ' - f'its source file "{filename}" could not be found.') + f'its source file "{filename}" could not be found.{extra}') return ''.join(result) def require_module_if_not_provided(self, module_name: str, dependencies=True): @@ -187,9 +224,14 @@ def require_module(self, module_name: str, dependencies=True): if self.verbose: print(f'implicitly adding {root_name} to external modules ' f'since it is part of the standard library and is a dependency.') - self.extern_module(root_name) + self.save_extern_module(root_name) return + for pattern, action in self.patterns: + if pattern.matches(module_name): + action(module_name) + return + self.save_module(module_name, dependencies) def save_module(self, module_name: str, dependencies=True): @@ -203,27 +245,6 @@ def save_module(self, module_name: str, dependencies=True): source = self._get_source_of_module(module) self.save_source_string(module_name, source, hasattr(module, '__path__'), dependencies, module.__file__) - - def _import_module(self, module_name): - last_err = None - for import_module in self.importers: - try: - return import_module(module_name) - except ModuleNotFoundError as err: - last_err = err - if last_err is not None: - raise last_err - else: - raise ModuleNotFoundError(module_name) - - def _create_pickler(self, data_buf): - if self.importers == [importlib.import_module]: - # if we are using the normal import library system, then - # we can use the C implementation of pickle which is faster - return pickle.Pickler(data_buf, protocol=3) - else: - return CustomImportPickler(self._import_module, data_buf, protocol=3) - def save_pickle(self, package: str, resource: str, obj: Any, dependencies: bool = True): """Save a python object to the archive using pickle. Equivalent to :func:`torch.save` but saving into the archive rather than a stand-alone file. Stanard pickle does not save the code, only the objects. @@ -244,7 +265,7 @@ def save_pickle(self, package: str, resource: str, obj: Any, dependencies: bool filename = self._filename(package, resource) # Write the pickle data for `obj` data_buf = io.BytesIO() - pickler = self._create_pickler(data_buf) + pickler = create_custom_import_pickler(data_buf, self.importers) pickler.persistent_id = self._persistent_id pickler.dump(obj) data_value = data_buf.getvalue() @@ -258,6 +279,9 @@ def save_pickle(self, package: str, resource: str, obj: Any, dependencies: bool if module not in all_dependencies: all_dependencies.append(module) + for dep in all_dependencies: + self.debug_deps.append((package + '.' + resource, dep)) + if self.verbose: dep_string = ''.join(f' {dep}\n' for dep in all_dependencies) print(f"{resource} depends on:\n{dep_string}\n") @@ -288,52 +312,64 @@ def save_binary(self, package, resource, binary: bytes): filename = self._filename(package, resource) self._write(filename, binary) - def extern_module(self, module_name: str): + def mock(self, include: 'GlobPattern', *, exclude: 'GlobPattern' = ()): + """Replace some required modules with a mock implementation. Mocked modules will return a fake + object for any attribute accessed from it. Because we copy file-by-file, the dependency resolution will sometimes + find files that are imported by model files but whose functionality is never used + (e.g. custom serialization code or training helpers). + Use this function to mock this functionality out without having to modify the original code. + + Args: + include (Union[List[str], str]): A string e.g. "my_package.my_subpackage", or list of strings + for the names of the modules to be mocked out. Strings can also be a glob-style pattern + string that may match multiple modules. Any required dependencies that match this pattern + string will be mocked out automatically. + + Examples: + 'torch.**' -- matches torch and all submodules of torch, e.g. 'torch.nn' and torch.nn.functional' + 'torch.*' -- matches 'torch.nn' or 'torch.functional', but not 'torch.nn.functional' + + exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string. + e.g. include='torch.**', exclude='torch.foo' will mock all torch packages except 'torch.foo' Default: [] + + """ + self.patterns.append((_GlobGroup(include, exclude), self.save_mock_module)) + + def extern(self, include: 'GlobPattern', *, exclude: 'GlobPattern' = ()): """Include `module` in the list of external modules the package can import. This will prevent dependency discover from saving it in the package. The importer will load an external module directly from the standard import system. Code for extern modules must also exist in the process loading the package. Args: - module_name (str): e.g. "my_package.my_subpackage" the name of the external module + include (Union[List[str], str]): A string e.g. "my_package.my_subpackage", or list of strings + for the names of the modules to be externed. This can also be a glob-style pattern, as described in :meth:`mock` + + exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string. + """ - if module_name not in self.external: - self.external.append(module_name) + self.patterns.append((_GlobGroup(include, exclude), self.save_extern_module)) - def extern_modules(self, module_names: List[str]): - """Extern a list of modules. Convience wrapper for calling :meth:`extern_module` on many items. + def save_extern_module(self, module_name: str): + """Add `module_name` to the list of external modules, regardless of whether it is + required by other modules. - Args: - module_names (List[str]): List of module names + Prefer using `extern` to only mark modules extern if they are actually required by the packaged code. """ - for m in module_names: - self.extern_module(m) + if module_name not in self.external: + self.external.append(module_name) - def mock_module(self, module_name: str): - """Replace the code for `module_name` in the package with a fake implementation. This module will return a fake - object for any attribute accessed from it. Because we copy file-by-file, the dependency resolution will sometimes - find files that are imported by model files but whose functionality is never used - (e.g. custom serialization code or training helpers). - Use this function to mock this functionality out without having to modify the original code. + def save_mock_module(self, module_name: str): + """Add `module_name` to the package, implemented it with a mocked out version that + can be imported but does not include any implementations. - Args: - module_name (str): e.g. "my_package.my_subpackage" the name of the module to be mocked out. + Prefer using `mock` to only include this module if it is required by other modules. """ if '_mock' not in self.provided: self.save_source_file('_mock', str(Path(__file__).parent / '_mock.py'), dependencies=False) is_package = hasattr(self._import_module(module_name), '__path__') self.save_source_string(module_name, _MOCK_IMPL, is_package, dependencies=False) - - def mock_modules(self, module_names): - """Mock a list of modules. Convience wrapper for calling :meth:`mock_module` on many items. - - Args: - module_names (List[str]): List of module names - """ - for module_name in module_names: - self.mock_module(module_name) - def _module_is_already_provided(self, qualified_name: str) -> bool: for mod in self.external: if qualified_name == mod or qualified_name.startswith(mod + '.'): @@ -366,6 +402,9 @@ def __exit__(self, type, value, traceback): self.close() def _write(self, filename, str_or_bytes): + if is_mangled(filename): + raise RuntimeError(f"Tried to save a torch.package'd module as '{filename}'. " + "Directly saving torch.package'd modules is not allowed.") if isinstance(str_or_bytes, str): str_or_bytes = str_or_bytes.encode('utf-8') self.zip_file.write_record(filename, str_or_bytes, len(str_or_bytes)) @@ -377,20 +416,19 @@ def close(self): with PackageExporter("file.zip") as e: ... """ + if self.verbose: + print(f"Dependency graph for exported package: \n{self._write_dep_graph()}") + # Write each tensor to a file named tensor/the_tensor_key in the zip archive for key in sorted(self.serialized_storages.keys()): name = 'data/{}'.format(key) storage = self.serialized_storages[key] - if storage.device.type == 'cpu': - # If it's on the CPU we can directly copy it into the zip file - num_bytes = storage.size() * storage.element_size() - self.zip_file.write_record(name, storage.data_ptr(), num_bytes) - else: - # Copy to a buffer, then serialize that - buf = io.BytesIO() - storage._write_file(buf, _should_read_directly(buf)) - buf_value = buf.getvalue() - self._write(name, buf_value) + # location information is saved in python, but to actually + # get the data from non cpu tensors we need to move them over first + if storage.device.type != 'cpu': + storage = storage.cpu() + num_bytes = storage.size() * storage.element_size() + self.zip_file.write_record(name, storage.data_ptr(), num_bytes) contents = ('\n'.join(self.external) + '\n') self._write('extern_modules', contents) del self.zip_file @@ -404,7 +442,6 @@ def _can_implicitly_extern(self, module_name: str): return module_name == 'torch' or (module_name not in _DISALLOWED_MODULES and _is_builtin_or_stdlib_module(self._import_module(module_name))) - # even though these are in the standard library, we do not allow them to be # automatically externed since they offer a lot of system level access _DISALLOWED_MODULES = ['sys', 'io'] @@ -412,7 +449,7 @@ def _can_implicitly_extern(self, module_name: str): def _is_builtin_or_stdlib_module(module: types.ModuleType) -> bool: if module.__name__ in sys.builtin_module_names: return True - filename = module.__file__ + filename = getattr(module, '__file__', None) if filename is None: return False standard_lib = get_python_lib(standard_lib=True) @@ -426,10 +463,49 @@ def _is_builtin_or_stdlib_module(module: types.ModuleType) -> bool: _MOCK_IMPL = """\ from _mock import MockedObject def __getattr__(attr: str): - return MockedObject(__name__ + '.' + attr) + return MockedObject(__name__ + '.' + attr, _suppress_err=True) """ def _read_file(filename: str) -> str: with open(filename, 'rb') as f: b = f.read() return b.decode('utf-8') + +GlobPattern = Union[str, Iterable[str]] + + +class _GlobGroup: + def __init__(self, include: 'GlobPattern', exclude: 'GlobPattern'): + self._dbg = f'_GlobGroup(include={include}, exclude={exclude})' + self.include = _GlobGroup._glob_list(include) + self.exclude = _GlobGroup._glob_list(exclude) + + def __str__(self): + return self._dbg + + def matches(self, candidate: str) -> bool: + candidate = '.' + candidate + return any(p.fullmatch(candidate) for p in self.include) and all(not p.fullmatch(candidate) for p in self.exclude) + + @staticmethod + def _glob_list(elems: 'GlobPattern'): + if isinstance(elems, str): + return [_GlobGroup._glob_to_re(elems)] + else: + return [_GlobGroup._glob_to_re(e) for e in elems] + + @staticmethod + def _glob_to_re(pattern: str): + # to avoid corner cases for the first component, we prefix the candidate string + # with '.' so `import torch` will regex against `.torch` + def component_to_re(component): + if '**' in component: + if component == '**': + return '(\\.[^.]+)*' + else: + raise ValueError('** can only appear as an entire path segment') + else: + return '\\.' + '[^.]*'.join(re.escape(x) for x in component.split('*')) + + result = ''.join(component_to_re(c) for c in pattern.split('.')) + return re.compile(result) diff --git a/torch/package/importer.py b/torch/package/importer.py index 1a02e69436faf..58ec960995357 100644 --- a/torch/package/importer.py +++ b/torch/package/importer.py @@ -1,6 +1,7 @@ from typing import List, Callable, Dict, Optional, Any, Union import builtins import importlib +import linecache from torch.serialization import _load import pickle import torch @@ -11,6 +12,7 @@ from ._importlib import _normalize_line_endings, _resolve_name, _sanity_check, _calc___package__, \ _normalize_path from ._mock_zipreader import MockZipReader +from ._mangling import PackageMangler, demangle class PackageImporter: """Importers allow you to load code written to packages by PackageExporter. @@ -26,12 +28,13 @@ class PackageImporter: a locally-installed package, but then fails when the package is copied to another machine. """ - modules : Dict[str, Optional[types.ModuleType]] """The dictionary of already loaded modules from this package, equivalent to `sys.modules` but local to this importer. """ + modules : Dict[str, Optional[types.ModuleType]] - def __init__(self, filename: str, module_allowed: Callable[[str], bool] = lambda module_name: True): + def __init__(self, filename: Union[str, torch._C.PyTorchFileReader], + module_allowed: Callable[[str], bool] = lambda module_name: True): """Open `filename` for importing. This checks that the imported package only requires modules allowed by `module_allowed` @@ -45,12 +48,16 @@ def __init__(self, filename: str, module_allowed: Callable[[str], bool] = lambda Raises: ImportError: If the package will use a disallowed module. """ - self.filename = filename self.zip_reader : Any - if not os.path.isdir(self.filename): - self.zip_reader = torch._C.PyTorchFileReader(self.filename) + if isinstance(filename, torch._C.PyTorchFileReader): + self.filename = '' + self.zip_reader = filename else: - self.zip_reader = MockZipReader(self.filename) + self.filename = filename + if not os.path.isdir(self.filename): + self.zip_reader = torch._C.PyTorchFileReader(self.filename) + else: + self.zip_reader = MockZipReader(self.filename) self.root = _PackageNode(None) self.modules = {} @@ -62,14 +69,16 @@ def __init__(self, filename: str, module_allowed: Callable[[str], bool] = lambda f"but that module has been disallowed") self._add_extern(extern_module) - for filename in self.zip_reader.get_all_records(): - self._add_file(filename) + for fname in self.zip_reader.get_all_records(): + self._add_file(fname) self.patched_builtins = builtins.__dict__.copy() self.patched_builtins['__import__'] = self.__import__ # allow pickles from archive using `import resources` self.modules['resources'] = self # type: ignore + self._mangler = PackageMangler() + # used for torch.serialization._load self.Unpickler = lambda *args, **kwargs: _UnpicklerWrapper(self, *args, **kwargs) @@ -131,26 +140,47 @@ def load_pickle(self, package: str, resource: str, map_location=None) -> Any: pickle_file = self._zipfile_path(package, resource) return _load(self.zip_reader, map_location, self, pickle_file=pickle_file) + def id(self): + """ + Returns internal identifier that torch.package uses to distinguish PackageImporter instances. + Looks like: + + """ + return self._mangler.parent_name() def _read_extern(self): return self.zip_reader.get_record('extern_modules').decode('utf-8').splitlines(keepends=False) - def _make_module(self, name: str, filename: Optional[str], is_package: bool): + def _make_module(self, name: str, filename: Optional[str], is_package: bool, parent: str): + mangled_filename = self._mangler.mangle(filename) if filename else None spec = importlib.machinery.ModuleSpec(name, self, is_package=is_package) # type: ignore module = importlib.util.module_from_spec(spec) self.modules[name] = module + module.__name__ = self._mangler.mangle(name) ns = module.__dict__ ns['__spec__'] = spec ns['__loader__'] = self - ns['__file__'] = filename + ns['__file__'] = mangled_filename ns['__cached__'] = None ns['__builtins__'] = self.patched_builtins + + # pre-emptively install on the parent to prevent IMPORT_FROM from trying to + # access sys.modules + self._install_on_parent(parent, name, module) + if filename is not None: - code = self._compile_source(filename) + assert mangled_filename is not None + # pre-emptively install the source in `linecache` so that stack traces, + # `inspect`, etc. work. + assert filename not in linecache.cache # type: ignore + linecache.lazycache(mangled_filename, ns) + + code = self._compile_source(filename, mangled_filename) exec(code, ns) + return module - def _load_module(self, name: str): + def _load_module(self, name: str, parent: str): cur : _PathNode = self.root for atom in name.split('.'): if not isinstance(cur, _PackageNode) or atom not in cur.children: @@ -161,18 +191,27 @@ def _load_module(self, name: str): if isinstance(cur, _ExternNode): module = self.modules[name] = importlib.import_module(name) return module - return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode)) # type: ignore + return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent) # type: ignore - def _compile_source(self, fullpath): + def _compile_source(self, fullpath: str, mangled_filename: str): source = self.zip_reader.get_record(fullpath) source = _normalize_line_endings(source) - return compile(source, fullpath, 'exec', dont_inherit=True) + return compile(source, mangled_filename, 'exec', dont_inherit=True) # note: named `get_source` so that linecache can find the source # when this is the __loader__ of a module. def get_source(self, module_name) -> str: - module = self.import_module(module_name) - return self.zip_reader.get_record(module.__file__).decode('utf-8') + # linecache calls `get_source` with the `module.__name__` as the argument, so we must demangle it here. + module = self.import_module(demangle(module_name)) + return self.zip_reader.get_record(demangle(module.__file__)).decode('utf-8') + + def _install_on_parent(self, parent: str, name: str, module: types.ModuleType): + if not parent: + return + # Set the module as an attribute on its parent. + parent_module = self.modules[parent] + if parent_module.__loader__ is self: # type: ignore + setattr(parent_module, name.rpartition('.')[2], module) # note: copied from cpython's import code, with call to create module replaced with _make_module def _do_find_and_load(self, name): @@ -191,13 +230,10 @@ def _do_find_and_load(self, name): msg = (_ERR_MSG + '; {!r} is not a package').format(name, parent) raise ModuleNotFoundError(msg, name=name) from None - module = self._load_module(name) + module = self._load_module(name, parent) + + self._install_on_parent(parent, name, module) - if parent: - # Set the module as an attribute on its parent. - parent_module = self.modules[parent] - if parent_module.__loader__ is self: # type: ignore - setattr(parent_module, name.rpartition('.')[2], module) return module # note: copied from cpython's import code @@ -238,13 +274,14 @@ def _handle_fromlist(self, module, fromlist, *, recursive=False): import implementation is desired. """ + module_name = demangle(module.__name__) # The hell that is fromlist ... # If a package was imported, try to import stuff from fromlist. if hasattr(module, '__path__'): for x in fromlist: if not isinstance(x, str): if recursive: - where = module.__name__ + '.__all__' + where = module_name + '.__all__' else: where = "``from list''" raise TypeError(f"Item in {where} must be str, " @@ -254,7 +291,7 @@ def _handle_fromlist(self, module, fromlist, *, recursive=False): self._handle_fromlist(module, module.__all__, recursive=True) elif not hasattr(module, x): - from_name = '{}.{}'.format(module.__name__, x) + from_name = '{}.{}'.format(module_name, x) try: self._gcd_import(from_name) except ModuleNotFoundError as exc: @@ -287,7 +324,8 @@ def __import__(self, name, globals=None, locals=None, fromlist=(), level=0): cut_off = len(name) - len(name.partition('.')[0]) # Slice end needs to be positive to alleviate need to special-case # when ``'.' not in name``. - return self.modules[module.__name__[:len(module.__name__) - cut_off]] + module_name = demangle(module.__name__) + return self.modules[module_name[:len(module_name) - cut_off]] else: return self._handle_fromlist(module, fromlist) @@ -314,7 +352,8 @@ def _zipfile_path(self, package, resource): package = self._get_package(package) resource = _normalize_path(resource) assert package.__loader__ is self - return f"{package.__name__.replace('.', '/')}/{resource}" + name = demangle(package.__name__) + return f"{name.replace('.', '/')}/{resource}" def _get_or_create_package(self, atoms: List[str]) -> 'Union[_PackageNode, _ExternNode]': cur = self.root diff --git a/torch/package/mangling.md b/torch/package/mangling.md new file mode 100644 index 0000000000000..a648a5d2bd5ed --- /dev/null +++ b/torch/package/mangling.md @@ -0,0 +1,70 @@ +# Import mangling in `torch.package` + +## Mangling rules +These are the core invariants; if you are changing mangling code please preserve them. + +1. For every module imported by `PackageImporter`, two attributes are mangled: + - `__module__` + - `__file__` +2. Any `__module__` and `__file__` attribute accessed inside + `Package{Ex|Im}porter` should be demangled immediately. +3. No mangled names should be serialized by `PackageExporter`. + +## Why do we mangle imported names? +To avoid accidental name collisions with modules in `sys.modules`. Consider the following: + + from torchvision.models import resnet18 + local_resnet18 = resnet18() + + # a loaded resnet18, potentially with a different implementation than the local one! + i = torch.PackageImporter('my_resnet_18.pt') + loaded_resnet18 = i.load_pickle('model', 'model.pkl') + + print(type(local_resnet18).__module__) # 'torchvision.models.resnet18' + print(type(loaded_resnet18).__module__) # ALSO 'torchvision.models.resnet18' + +These two model types have the same originating `__module__` name set. +While this isn't facially incorrect, there are a number of places in +`cpython` and elsewhere that assume you can take any module name, look it +up `sys.modules`, and get the right module back, including: +- [`import_from`](https://github.com/python/cpython/blob/5977a7989d49c3e095c7659a58267d87a17b12b1/Python/ceval.c#L5500) +- `inspect`: used in TorchScript to retrieve source code to compile +- …probably more that we don't know about. + +In these cases, we may silently pick up the wrong module for `loaded_resnet18` +and e.g. TorchScript the wrong source code for our model. + +## How names are mangled +On import, all modules produced by a given `PackageImporter` are given a +new top-level module as their parent. This is called the `mangle parent`. For example: + + torchvision.models.resnet18 + +becomes + + .torchvision.models.resnet18 + +The mangle parent is made unique to a given `PackageImporter` instance by +bumping a process-global `mangle_index`, i.e. ``. + +The mangle parent intentionally uses angle brackets (`<` and `>`) to make it +very unlikely that mangled names will collide with any "real" user module. + +An imported module's `__file__` attribute is mangled in the same way, so: + + torchvision/modules/resnet18.py + +becomes + + .torchvision/modules/resnet18.py + +Similarly, the use of angle brackets makes it very unlikely that such a name +will exist in the user's file system. + +## Don't serialize mangled names +Mangling happens `on import`, and the results are never saved into a package. +Assigning mangle parents on import means that we can enforce that mangle +parents are unique within the environment doing the importing. + +It also allows us to avoid serializing (and maintaining backward +compatibility for) this detail. diff --git a/torch/profiler/__init__.py b/torch/profiler/__init__.py new file mode 100644 index 0000000000000..dabbf91dff908 --- /dev/null +++ b/torch/profiler/__init__.py @@ -0,0 +1,12 @@ +# type: ignore +r''' +PyTorch Profiler is a tool that allows the collecton of the performance metrics during the training and inference. +Profiler's context manager API can be used to better understand what model operators are the most expensive, +examine their input shapes and stack traces, study device kernel activity and visualize the execution trace. + +.. note:: + An earlier version of the API in ``torch.autograd`` module is considered legacy and will be deprecated. + +''' + +from .profiler import profile, schedule, ProfilerAction, ProfilerActivity diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py new file mode 100644 index 0000000000000..25bee1c2019f9 --- /dev/null +++ b/torch/profiler/profiler.py @@ -0,0 +1,312 @@ +import torch.autograd.profiler as prof +from torch.autograd import ProfilerActivity + +from enum import Enum +from typing import Any, Callable, Iterable, Optional +from warnings import warn + + +class ProfilerAction(Enum): + NONE = 0 + WARMUP = 1 + RECORD = 2 + RECORD_AND_SAVE = 3 + + +def schedule(*, wait: int, warmup: int, active: int): + """ + Represents profiler behavior: wait for ``wait`` steps, then + do the warmup for the next ``warmup`` steps, then + do the active recording for the next ``active`` steps and then + repeat the cycle staring with the next step. + """ + def schedule_fn(step: int) -> ProfilerAction: + assert step >= 0 + num_steps = wait + warmup + active + mod_step = step % num_steps + if mod_step < wait: + return ProfilerAction.NONE + elif mod_step < wait + warmup: + return ProfilerAction.WARMUP + else: + return ProfilerAction.RECORD if mod_step < num_steps - 1 \ + else ProfilerAction.RECORD_AND_SAVE + assert wait >= 0 and warmup >= 0 and active > 0, \ + "Invalid profiler schedule arguments" + if warmup == 0: + warn("Profiler won't be using warmup, this can skew profiler results") + return schedule_fn + + +def _default_schedule_fn(_: int) -> ProfilerAction: + """ + Default profiler behavior - immediately starts recording the events, + keeps doing it on every profiler step. + """ + return ProfilerAction.RECORD + + +class profile(object): + """ + Profiler context manager. + + Args: + + - ``activities`` - list of activity groups (CPU, CUDA) to use in profiling; + - ``schedule`` - callable that takes step (int) as a single parameter and returns + ``ProfilerAction`` value that specifies the profiler action on each step; + - ``on_trace_ready`` (optional) - callable, called each time the trace is ready + during the profiling; + - ``record_shapes`` - save information about operator's input shapes; + - ``profile_memory`` - track tensor memory allocation/deallocation; + - ``with_stack`` - save stack traces; + - ``use_gpu`` - (deprecated, use ``activities``). + + .. note:: + Use ``torch.profiler.schedule`` to generate the callable schedule. + Non-default schedules are useful when profiling long training jobs + and allow the user to obtain multiple traces at the different iterations + of the training process. + The default schedule simply records all the events continuously for the + duration of the context manager. + + .. note:: + Enabling shape and stack tracing results in additional overhead. + + Examples: + + .. code-block:: python + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA] + ) as p: + code_to_profile() + print(p.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + + Usimg the profiler's ``schedule``, ``on_trace_ready`` and ``next_step`` functions: + + .. code-block:: python + + # Non-default profiler schedule allows user to turn profiler on and off + # on different iterations of the training loop; + # trace_handler is called every time a new trace becomes available + def trace_handler(prof): + print(prof.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step()) + ".json") + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA], + + # In this example with wait=1, warmup=1, active=2, + # profiler will skip the first step/iteration, + # start warming up on the second, record + # the third and the forth iterations, + # after which the trace will become available + # and on_trace_ready (when set) is called; + # the cycle repeats starting with the next step + + schedule=torch.profiler.schedule( + wait=1, + warmup=1, + active=2), + on_trace_ready=trace_handler + ) as p: + for iter in range(N): + code_iteration_to_profile(iter) + # send a signal to the profiler that the next iteration has started + p.next_step() + """ + def __init__( + self, + *, + activities: Optional[Iterable[ProfilerActivity]] = None, + schedule: Optional[Callable[[int], ProfilerAction]] = None, + on_trace_ready: Optional[Callable[..., Any]] = None, + record_shapes: bool = False, + profile_memory: bool = False, + with_stack: bool = False, + # deprecated: + use_gpu: Optional[bool] = None): + if activities: + self.activities = activities + else: + if use_gpu is not None: + warn("use_gpu is deprecated, use activities argument instead") + self.activities = set([ProfilerActivity.CPU]) + if use_gpu: + self.activities.add(ProfilerActivity.CUDA) + else: + raise RuntimeError("Profiler activities are not specified") + + if schedule: + self.schedule = schedule + # add step markers into the trace and table view + self.record_steps = True + else: + self.schedule = _default_schedule_fn + self.record_steps = False + self.on_trace_ready = on_trace_ready + self.record_shapes = record_shapes + self.profile_memory = profile_memory + self.with_stack = with_stack + self.step_num = 0 + self.current_action = self.schedule(self.step_num) + self.profiler: Optional[prof.profile] = None + self.step_rec_fn: Optional[prof.record_function] = None + + def __enter__(self): + self._enter_actions() + if self.record_steps: + self.step_rec_fn = prof.record_function("ProfilerStep#" + str(self.step_num)) + self.step_rec_fn.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.record_steps and self.step_rec_fn: + self.step_rec_fn.__exit__(None, None, None) + self._exit_actions() + + def next_step(self): + """ + Signals the profiler that the next profiling step has started. + """ + if self.record_steps and self.step_rec_fn: + self.step_rec_fn.__exit__(None, None, None) + prev_action = self.current_action + self.step_num += 1 + self.current_action = self.schedule(self.step_num) + + if self.current_action == ProfilerAction.NONE: + if prev_action == ProfilerAction.NONE: + pass + elif prev_action == ProfilerAction.WARMUP: + warn("Incorrect schedule: WARMUP followed by NONE") + self._start_trace() + self._stop_trace() + elif prev_action == ProfilerAction.RECORD: + warn("Incorrect schedule: RECORD followed by NONE") + self._stop_trace() + else: + assert prev_action == ProfilerAction.RECORD_AND_SAVE + self._stop_trace() + if self.on_trace_ready: + self.on_trace_ready(self) + elif self.current_action == ProfilerAction.WARMUP: + if prev_action == ProfilerAction.NONE: + self._start_warmup() + elif prev_action == ProfilerAction.WARMUP: + pass + elif prev_action == ProfilerAction.RECORD: + warn("Incorrect schedule: RECORD followed by WARMUP") + self._stop_trace() + else: + assert prev_action == ProfilerAction.RECORD_AND_SAVE + self._stop_trace() + if self.on_trace_ready: + self.on_trace_ready(self) + self._start_warmup() + elif self.current_action in \ + [ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE]: + if prev_action == ProfilerAction.NONE: + self._start_warmup() + self._start_trace() + elif prev_action == ProfilerAction.WARMUP: + self._start_trace() + elif prev_action == ProfilerAction.RECORD: + pass + else: + assert prev_action == ProfilerAction.RECORD_AND_SAVE + self._stop_trace() + if self.on_trace_ready: + self.on_trace_ready(self) + self._start_warmup() + self._start_trace() + + if self.record_steps: + self.step_rec_fn = prof.record_function("ProfilerStep#" + str(self.step_num)) + self.step_rec_fn.__enter__() + + def step(self): + """ + Returns the current profiling step. + """ + return self.step_num + + def export_chrome_trace(self, path: str): + """ + Exports the collected trace in Chrome JSON format. + """ + assert self.profiler + return self.profiler.export_chrome_trace(path) + + def export_stacks(self, path: str, metric: str = "self_cpu_time_total"): + """ + Save stack traces in a file in a format suitable for visualization. + + Args: + + - ``path`` - save stacks file to this location; + - ``metric`` - metric to use: "self_cpu_time_total" or "self_cuda_time_total" + + .. note:: + Example of using FlameGraph tool: + + - git clone https://github.com/brendangregg/FlameGraph + - cd FlameGraph + - ./flamegraph.pl --title "CPU time" --countname "us." profiler.stacks > perf_viz.svg + """ + assert self.profiler + return self.profiler.export_stacks(path, metric) + + def key_averages(self, group_by_input_shape: bool = False, group_by_stack_n: int = 0): + """ + Averages events, grouping them by operator name and (optionally) input shapes and + stack. + Note: to use shape/stack functionality make sure to set record_shapes/with_stack + when creating profiler context manager. + """ + assert self.profiler + return self.profiler.key_averages(group_by_input_shape, group_by_stack_n) + + def _enter_actions(self): + if self.current_action == ProfilerAction.WARMUP: + self._start_warmup() + elif self.current_action in \ + [ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE]: + self._start_warmup() + self._start_trace() + + def _exit_actions(self): + if self.current_action == ProfilerAction.WARMUP: + self._start_trace() + self._stop_trace() + elif self.current_action in \ + [ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE]: + self._stop_trace() + if self.on_trace_ready: + self.on_trace_ready(self) + + def _start_warmup(self): + self.profiler = prof.profile( + use_cuda=(ProfilerActivity.CUDA in self.activities), + use_cpu=(ProfilerActivity.CPU in self.activities), + record_shapes=self.record_shapes, + profile_memory=self.profile_memory, + with_stack=self.with_stack, + use_kineto=True, + ) + self.profiler._prepare_kineto_trace() + + def _start_trace(self): + assert self.profiler is not None + self.profiler._start_kineto_trace() + + def _stop_trace(self): + assert self.profiler is not None + self.profiler.__exit__(None, None, None) diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py index 31943e56e6a3c..b2a8e542f93ac 100644 --- a/torch/quantization/__init__.py +++ b/torch/quantization/__init__.py @@ -6,10 +6,9 @@ from .stubs import * from .quant_type import * from .quantize_jit import * -from .quantize_fx import * +# from .quantize_fx import * from .quantization_mappings import * from .fuser_method_mappings import * -from .custom_module_class_mappings import * def default_eval_fn(model, calib_data): r""" @@ -27,35 +26,33 @@ def default_eval_fn(model, calib_data): # Top level API for graph mode quantization on TorchScript 'quantize_jit', 'quantize_dynamic_jit', # Top level API for graph mode quantization on GraphModule(torch.fx) - 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx - 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx', - 'QuantType', # quantization type + # 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx + # 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx', + 'QuantType', 'quant_type_to_str', # quantization type # custom module APIs - 'register_static_quant_module_mapping', - 'get_static_quant_module_mappings', 'get_static_quant_module_class', - 'register_dynamic_quant_module_mapping', - 'get_dynamic_quant_module_mappings', - 'register_qat_module_mapping', - 'get_qat_module_mappings', - 'get_qconfig_propagation_list', - 'get_compare_output_module_list', - 'register_quantized_operator_mapping', 'get_quantized_operator', - 'register_fuser_method', 'get_fuser_method', - 'register_observed_custom_module_mapping', - 'get_observed_custom_module_class', - 'register_quantized_custom_mdoule_mapping', - 'get_quantized_custom_module_class', - 'is_custom_module_class', - 'is_observed_custom_module', + 'get_default_static_quant_module_mappings', 'get_static_quant_module_class', + 'get_default_dynamic_quant_module_mappings', + 'get_default_qat_module_mappings', + 'get_default_qconfig_propagation_list', + 'get_default_compare_output_module_list', + 'get_quantized_operator', + 'get_fuser_method', # Sub functions for `prepare` and `swap_module` 'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module', 'default_eval_fn', 'get_observer_dict', 'register_activation_post_process_hook', # Observers 'ObserverBase', 'WeightObserver', 'observer', 'default_observer', - 'default_weight_observer', + 'default_weight_observer', 'default_placeholder_observer', + # FakeQuantize (for qat) + 'default_fake_quant', 'default_weight_fake_quant', + 'default_symmetric_fixed_qparams_fake_quant', + 'default_affine_fixed_qparams_fake_quant', + 'default_per_channel_weight_fake_quant', + 'default_histogram_fake_quant', # QConfig 'QConfig', 'default_qconfig', 'default_dynamic_qconfig', 'float16_dynamic_qconfig', + 'float_qparams_weight_only_qconfig', # QAT utilities 'default_qat_qconfig', 'prepare_qat', 'quantize_qat', # module transformations diff --git a/torch/quantization/_numeric_suite.py b/torch/quantization/_numeric_suite.py index 01703269604a6..100ff54d4436f 100644 --- a/torch/quantization/_numeric_suite.py +++ b/torch/quantization/_numeric_suite.py @@ -4,9 +4,10 @@ import torch.nn.quantized as nnq import torch.nn.quantized.dynamic as nnqd from torch.quantization import prepare +from typing import Dict from .quantization_mappings import ( - get_compare_output_module_list, + get_default_compare_output_module_list, ) NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = { @@ -66,7 +67,8 @@ def compare_weights(float_dict, quantized_dict): a dictionary with two keys 'float' and 'quantized', containing the float and quantized weights """ - weight_dict = {} + torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights") + weight_dict: Dict[str, Dict] = {} for key in quantized_dict: match_key = _find_match(float_dict, key, "weight") if match_key is not None: @@ -141,8 +143,9 @@ def get_logger_dict(mod, prefix=""): Return: target_dict: the dictionary used to save all logger stats """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict") - target_dict = {} + target_dict: Dict[str, Dict] = {} _get_logger_dict_helper(mod, target_dict, prefix) return target_dict @@ -166,23 +169,16 @@ class ShadowLogger(Logger): def __init__(self): super(ShadowLogger, self).__init__() - self.stats["float"] = None - self.stats["quantized"] = None + self.stats["float"] = [] + self.stats["quantized"] = [] def forward(self, x, y): if len(x) > 1: x = x[0] if len(y) > 1: y = y[0] - if self.stats["quantized"] is None: - self.stats["quantized"] = x.detach() - else: - self.stats["quantized"] = torch.cat((self.stats["quantized"], x.detach())) - - if self.stats["float"] is None: - self.stats["float"] = y.detach() - else: - self.stats["float"] = torch.cat((self.stats["float"], y.detach())) + self.stats["quantized"].append(x.detach()) + self.stats["float"].append(y.detach()) class OutputLogger(Logger): @@ -191,13 +187,11 @@ class OutputLogger(Logger): def __init__(self): super(OutputLogger, self).__init__() - self.stats["tensor_val"] = None + self.stats["tensor_val"] = [] + def forward(self, x): - if self.stats["tensor_val"] is None: - self.stats["tensor_val"] = x - else: - self.stats["tensor_val"] = torch.cat((self.stats["tensor_val"], x)) + self.stats["tensor_val"].append(x) return x @@ -304,6 +298,7 @@ def prepare_model_with_stubs(float_module, q_module, module_swap_list, Logger): Logger: type of logger to be used in shadow module to process the outputs of quantized module and its float shadow module """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_with_stubs") float_module_children = {} for name, mod in float_module.named_children(): @@ -359,6 +354,7 @@ def compare_model_stub( Logger: type of logger to be used in shadow module to process the outputs of quantized module and its float shadow module """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub") prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger) q_model(*data) ob_dict = get_logger_dict(q_model) @@ -377,9 +373,10 @@ def get_matching_activations(float_module, q_module): entry being a dictionary with two keys 'float' and 'quantized', containing the matching float and quantized activations """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.get_matching_activations") float_dict = get_logger_dict(float_module) quantized_dict = get_logger_dict(q_module) - act_dict = {} + act_dict: Dict[str, Dict] = {} for key in quantized_dict: match_key = _find_match(sorted(float_dict, reverse=True), key, "stats") if match_key is not None: @@ -404,8 +401,9 @@ def prepare_model_outputs( Logger: type of logger to be attached to float_module and q_module allow_list: list of module types to attach logger """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs") if allow_list is None: - allow_list = get_compare_output_module_list() + allow_list = get_default_compare_output_module_list() qconfig_debug = torch.quantization.QConfig(activation=Logger, weight=None) float_module.qconfig = qconfig_debug @@ -450,8 +448,9 @@ def compare_model_outputs( and each entry being a dictionary with two keys 'float' and 'quantized', containing the matching float and quantized activations """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_outputs") if allow_list is None: - allow_list = get_compare_output_module_list() + allow_list = get_default_compare_output_module_list() prepare_model_outputs(float_model, q_model, Logger, allow_list) float_model(*data) q_model(*data) diff --git a/torch/quantization/_numeric_suite_fx.py b/torch/quantization/_numeric_suite_fx.py new file mode 100644 index 0000000000000..aeba95bb4e8fc --- /dev/null +++ b/torch/quantization/_numeric_suite_fx.py @@ -0,0 +1,139 @@ +from typing import Any, Dict + +import torch +import torch.nn as nn +import torch.nn.quantized as nnq +import torch.nn.quantized.dynamic as nnqd +from torch.fx import GraphModule # type: ignore +from torch.fx import map_arg # type: ignore +from torch.fx.graph import Graph +from torch.quantization.fx.quantize import _remove_qconfig, is_activation_post_process + + +NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = { + nnqd.Linear, + nnq.Linear, + nnqd.LSTM, + nn.LSTM, +} + + +def remove_qconfig_observer_fx(model): + # remove activation post process + act_post_process_removed_graph = Graph() + env: Dict[str, Any] = {} + + modules = dict(model.named_modules()) + + def load_arg(a): + return map_arg(a, lambda node: env[node.name]) + + for node in model.graph.nodes: + if node.op == "output": + act_post_process_removed_graph.output(map_arg(node.args[0], load_arg)) + continue + if node.op == "call_module" and is_activation_post_process( + modules[node.target] + ): + # remove activation post process node + env[node.name] = env[node.args[0].name] + else: + env[node.name] = act_post_process_removed_graph.node_copy(node, load_arg) + + _remove_qconfig(model) + model = GraphModule(model, act_post_process_removed_graph) + return model + + +def _find_match(str_list, key_str, postfix): + split_str = key_str.split(".") + if split_str[-1] == postfix: + match_string = "".join(key_str.split(".")[0:-1]) + for s2 in str_list: + pattern1 = "".join(s2.split(".")[0:-1]) + pattern2 = "".join(s2.split(".")[0:-2]) + if match_string == pattern1: + return s2 + if match_string == pattern2: + return s2 + + # For matching "fc.weight" and "fc._packed_params._packed_params" + if postfix == "_packed_params": + match_string = "".join(key_str.split(".")[0:-2]) + if len(match_string) == 0: + return None + for s2 in str_list: + pattern1 = "".join(s2.split(".")[0:-1]) + pattern2 = "".join(s2.split(".")[0:-2]) + if match_string == pattern1: + return s2 + if match_string == pattern2: + return s2 + else: + return None + + +def compare_weights_fx(float_dict, quantized_dict): + r"""Compare the weights of the float module with its corresponding quantized + module. Return a dict with key corresponding to module names and each entry being + a dictionary with two keys 'float' and 'quantized', containing the float and + quantized weights. This dict can be used to compare and compute the quantization + error of the weights of float and quantized models. + + Example usage: + prepared_model = prepare_fx(float_model, qconfig_dict) + backup_prepared_model = copy.deepcopy(prepared_model) + quantized_model = convert_fx(prepared_model) + + qmodel = quantized_model + wt_compare_dict = compare_weights(backup_prepared_model.state_dict(), qmodel.state_dict()) + for key in wt_compare_dict: + print(key, compute_error(wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize())) + + Args: + float_dict: state dict of the float model (prepared model) + quantized_dict: state dict of the quantized model + + Return: + weight_dict: dict with key corresponding to module names and each entry being + a dictionary with two keys 'float' and 'quantized', containing the float and + quantized weights + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx.compare_weights_fx" + ) + weight_dict: Dict[str, Dict] = {} + for key in quantized_dict: + match_key = _find_match(float_dict, key, "weight") + if match_key is not None: + weight_dict[key] = {} + weight_dict[key]["float"] = float_dict[match_key] + weight_dict[key]["quantized"] = quantized_dict[key] + continue + + # For matching "fc.weight" and "fc._packed_params._packed_params" + match_key = _find_match(float_dict, key, "_packed_params") + if match_key is not None: + weight_dict[key] = {} + weight_dict[key]["float"] = float_dict[match_key] + weight_dict[key]["quantized"] = quantized_dict[key][0] + + # For LSTM + split_str = key.split(".") + if split_str[-1] == "param" and split_str[-3] == "_all_weight_values": + layer = split_str[-2] + module_name = ".".join(split_str[:-3]) + float_weight_ih_key = module_name + ".weight_ih_l" + layer + float_weight_hh_key = module_name + ".weight_hh_l" + layer + if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict: + weight_dict[key] = {} + weight_dict[key]["float"] = float_dict[float_weight_ih_key] + weight_dict[key]["quantized"] = ( + quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0] + ) + weight_dict[key]["float"] = float_dict[float_weight_hh_key] + weight_dict[key]["quantized"] = ( + quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0] + ) + + return weight_dict diff --git a/torch/quantization/custom_module_class_mappings.py b/torch/quantization/custom_module_class_mappings.py deleted file mode 100644 index c62290228c5b9..0000000000000 --- a/torch/quantization/custom_module_class_mappings.py +++ /dev/null @@ -1,75 +0,0 @@ -OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS = dict() - -def register_observed_custom_module_mapping(float_custom_module_class, observed_custom_module_class): - """ Register a mapping from `float_custom_module_class` to - `observed_custom_module_class` - `observed_custom_module_class` will have a `from_float` classmethod, - which will return an observed custom module instance given - a float custom module instance. - This will be used in prepare step of post training static quantization or - quantization aware training - """ - assert hasattr(observed_custom_module_class, 'from_float'), 'from_float must be' + \ - ' defined in observed custom module class' - OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS[float_custom_module_class] = \ - observed_custom_module_class - -def get_observed_custom_module_class(float_custom_module_class): - """ Get the corresponding observed module class for a given - float custom module. - """ - observed_custom_module_class = \ - OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS.get(float_custom_module_class, None) - assert observed_custom_module_class is not None, \ - 'Float Custom module class {}'.format(float_custom_module_class) + \ - ' does not have a corresponding observed module class' - return observed_custom_module_class - -QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS = dict() - -def register_quantized_custom_module_mapping(float_custom_module_class, quantized_custom_module_class): - """ Register a mapping from `float_custom_module_class` to `quantized_custom_module_class` - A quantized custom module class should accept quantized input and - return quantized output. (we can relax this condition in the - future if there is a need) - `quantized_custom_module_class` will have a `from_observed` classmethod, - which will return an quantized custom module instance given - a observed custom module instance. - This will be used in prepare step of post training static quantization or - quantization aware training - """ - assert hasattr(quantized_custom_module_class, 'from_observed'), 'from_observed' + \ - ' must be defined in quantized custom module class' - QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS[float_custom_module_class] = \ - quantized_custom_module_class - -def get_quantized_custom_module_class(float_custom_module_class): - """ Get the corresponding quantized module class for a given - float custom module. - """ - quantized_custom_module_class = \ - QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS.get(float_custom_module_class, None) - assert quantized_custom_module_class is not None, \ - 'Float Custom module class {}'.format(float_custom_module_class) + \ - ' does not have a corresponding quantized module class' - return quantized_custom_module_class - -def is_custom_module_class(module_class): - """ Check if a given module class is a custom module class - """ - return module_class in OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS and \ - module_class in QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS - -def mark_observed_custom_module(module, custom_module_class): - """ Mark a module as observed custom module, so that - it can be identified during convert step - """ - module._is_observed_custom_module = True - module._FLOAT_MODULE = custom_module_class - -def is_observed_custom_module(module): - """ Check if a module is marked as observed custom module - or not - """ - return hasattr(module, '_is_observed_custom_module') and \ - module._is_observed_custom_module diff --git a/torch/quantization/default_mappings.py b/torch/quantization/default_mappings.py deleted file mode 100644 index a1ddb2f629852..0000000000000 --- a/torch/quantization/default_mappings.py +++ /dev/null @@ -1,107 +0,0 @@ -import torch -from torch import nn - -import torch.nn.functional as F -import torch.nn.intrinsic as nni -import torch.nn.intrinsic.quantized as nniq -import torch.nn.intrinsic.qat as nniqat -import torch.nn.quantized as nnq -import torch.nn.quantized.dynamic as nnqd -import torch.nn.qat as nnqat - -from .stubs import QuantStub, DeQuantStub - -# Map for swapping float module to quantized ones -DEFAULT_MODULE_MAPPING = { - nn.Linear: nnq.Linear, - nn.ReLU: nnq.ReLU, - nn.ReLU6: nnq.ReLU6, - nn.Hardswish: nnq.Hardswish, - nn.ELU: nnq.ELU, - nn.Conv1d: nnq.Conv1d, - nn.Conv2d: nnq.Conv2d, - nn.Conv3d: nnq.Conv3d, - nn.ConvTranspose1d: nnq.ConvTranspose1d, - nn.ConvTranspose2d: nnq.ConvTranspose2d, - nn.BatchNorm2d: nnq.BatchNorm2d, - nn.BatchNorm3d: nnq.BatchNorm3d, - nn.LayerNorm: nnq.LayerNorm, - nn.GroupNorm: nnq.GroupNorm, - nn.InstanceNorm1d: nnq.InstanceNorm1d, - nn.InstanceNorm2d: nnq.InstanceNorm2d, - nn.InstanceNorm3d: nnq.InstanceNorm3d, - nn.Embedding: nnq.Embedding, - nn.EmbeddingBag: nnq.EmbeddingBag, - QuantStub: nnq.Quantize, - DeQuantStub: nnq.DeQuantize, - # Wrapper Modules: - nnq.FloatFunctional: nnq.QFunctional, - # Intrinsic modules: - nni.ConvReLU1d: nniq.ConvReLU1d, - nni.ConvReLU2d: nniq.ConvReLU2d, - nni.ConvReLU3d: nniq.ConvReLU3d, - nni.LinearReLU: nniq.LinearReLU, - nni.BNReLU2d: nniq.BNReLU2d, - nni.BNReLU3d: nniq.BNReLU3d, - nniqat.ConvReLU2d: nniq.ConvReLU2d, - nniqat.LinearReLU: nniq.LinearReLU, - nniqat.ConvBn2d: nnq.Conv2d, - nniqat.ConvBnReLU2d: nniq.ConvReLU2d, - # QAT modules: - nnqat.Linear: nnq.Linear, - nnqat.Conv2d: nnq.Conv2d, -} - -# mapping from floating point function or torch ops to quantized ops -DEFAULT_OPERATOR_MAPPING = { - F.elu: torch._ops.ops.quantized.elu, - F.hardswish: torch._ops.ops.quantized.hardswish, - F.instance_norm: torch._ops.ops.quantized.instance_norm, - F.layer_norm: torch._ops.ops.quantized.layer_norm, -} - -# Map for swapping float module to qat modules -DEFAULT_QAT_MODULE_MAPPING = { - nn.Linear: nnqat.Linear, - nn.Conv2d: nnqat.Conv2d, - # Intrinsic modules: - nni.ConvBn2d: nniqat.ConvBn2d, - nni.ConvBnReLU2d: nniqat.ConvBnReLU2d, - nni.ConvReLU2d: nniqat.ConvReLU2d, - nni.LinearReLU: nniqat.LinearReLU -} - -# Map for swapping dynamic modules -DEFAULT_DYNAMIC_MODULE_MAPPING = { - nn.Linear: nnqd.Linear, - nn.LSTM: nnqd.LSTM, - nn.LSTMCell: nnqd.LSTMCell, - nn.RNNCell: nnqd.RNNCell, - nn.GRUCell: nnqd.GRUCell, -} - -# Allowed list for propagating the qconfig -_EXCLUDE_QCONFIG_PROPAGATE_LIST = { - DeQuantStub, -} -_INCLUDE_QCONFIG_PROPAGATE_LIST = { - nn.Sequential, -} - -DEFAULT_QCONFIG_PROPAGATE_ALLOWED_LIST = ( - (set(DEFAULT_MODULE_MAPPING.keys()) | - set(DEFAULT_QAT_MODULE_MAPPING.keys()) | - set(DEFAULT_DYNAMIC_MODULE_MAPPING.keys()) | - _INCLUDE_QCONFIG_PROPAGATE_LIST) - - _EXCLUDE_QCONFIG_PROPAGATE_LIST -) - -DEFAULT_NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_ALLOWED_LIST = ( - set(DEFAULT_MODULE_MAPPING.values()) - | set(DEFAULT_QAT_MODULE_MAPPING.values()) - | set(DEFAULT_DYNAMIC_MODULE_MAPPING.values()) - | set(DEFAULT_MODULE_MAPPING.keys()) - | set(DEFAULT_QAT_MODULE_MAPPING.keys()) - | set(DEFAULT_DYNAMIC_MODULE_MAPPING.keys()) - | _INCLUDE_QCONFIG_PROPAGATE_LIST -) - _EXCLUDE_QCONFIG_PROPAGATE_LIST diff --git a/torch/quantization/fake_quantize.py b/torch/quantization/fake_quantize.py index 6cd06f8567d48..460b1c277a935 100644 --- a/torch/quantization/fake_quantize.py +++ b/torch/quantization/fake_quantize.py @@ -2,8 +2,63 @@ from torch.nn import Module from .observer import MovingAverageMinMaxObserver, HistogramObserver, MovingAveragePerChannelMinMaxObserver, _with_args import re +from abc import ABC, abstractmethod -class FakeQuantize(Module): +def _is_per_channel(qscheme: 'torch.qscheme') -> bool: + return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine] + +def _is_per_tensor(qscheme: 'torch.qscheme') -> bool: + return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine] + +class FakeQuantizeBase(ABC, Module): + r""" Base fake quantize module + Any fake quantize implementation should derive from this class. + + Concrete fake quantize module should follow the same API. In forward, they will update + the statistics of the observed Tensor and fake quantize the input. They should also provide a + `calculate_qparams` function that computes the quantization parameters given + the collected statistics. + + """ + + fake_quant_enabled: torch.Tensor + observer_enabled: torch.Tensor + + def __init__(self): + super().__init__() + # fake_quant_enabled and observer_enabled are buffers to support their + # replication in DDP. Data type is uint8 because NCCL does not support + # bool tensors. + self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8)) + self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8)) + + @abstractmethod + def forward(self, x): + pass + + @abstractmethod + def calculate_qparams(self, **kwargs): + pass + + @torch.jit.export + def enable_fake_quant(self, enabled: bool = True) -> None: + self.fake_quant_enabled[0] = 1 if enabled else 0 + + @torch.jit.export + def disable_fake_quant(self): + self.enable_fake_quant(False) + + @torch.jit.export + def enable_observer(self, enabled: bool = True) -> None: + self.observer_enabled[0] = 1 if enabled else 0 + + @torch.jit.export + def disable_observer(self): + self.enable_observer(False) + + with_args = classmethod(_with_args) + +class FakeQuantize(FakeQuantizeBase): r""" Simulate the quantize and dequantize operations in training time. The output of this module is given by @@ -41,17 +96,16 @@ class FakeQuantize(Module): provides a method to calculate scale and zero-point. """ + + scale: torch.Tensor + zero_point: torch.Tensor + def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs): - super(FakeQuantize, self).__init__() + super().__init__() assert quant_min <= quant_max, \ 'quant_min must be less than or equal to quant_max' self.quant_min = quant_min self.quant_max = quant_max - # fake_quant_enabled and observer_enabled are buffers to support their - # replication in DDP. Data type is uint8 because NCCL does not support - # bool tensors. - self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8)) - self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8)) self.activation_post_process = observer(**observer_kwargs) assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, 'quant_min out of bound' assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, 'quant_max out of bound' @@ -61,24 +115,11 @@ def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max= self.qscheme = self.activation_post_process.qscheme self.ch_axis = self.activation_post_process.ch_axis \ if hasattr(self.activation_post_process, 'ch_axis') else -1 - - @torch.jit.export - def enable_fake_quant(self, enabled=True): - # type: (bool) -> None - self.fake_quant_enabled[0] = 1 if enabled else 0 - - @torch.jit.export - def disable_fake_quant(self): - self.enable_fake_quant(False) - - @torch.jit.export - def enable_observer(self, enabled=True): - # type: (bool) -> None - self.observer_enabled[0] = 1 if enabled else 0 - - @torch.jit.export - def disable_observer(self): - self.enable_observer(False) + assert _is_per_channel(self.qscheme) or \ + _is_per_tensor(self.qscheme), \ + 'Only per channel and per tensor quantization are supported in fake quantize' + \ + ' got qscheme: ' + str(self.qscheme) + self.is_per_channel = _is_per_channel(self.qscheme) @torch.jit.export def calculate_qparams(self): @@ -95,7 +136,7 @@ def forward(self, X): self.zero_point.copy_(_zero_point) if self.fake_quant_enabled[0] == 1: - if self.qscheme == torch.per_channel_symmetric or self.qscheme == torch.per_channel_affine: + if self.is_per_channel: X = torch.fake_quantize_per_channel_affine(X, self.scale, self.zero_point, self.ch_axis, self.quant_min, self.quant_max) else: @@ -104,16 +145,14 @@ def forward(self, X): self.quant_max) return X - with_args = classmethod(_with_args) - @torch.jit.export def extra_repr(self): - return 'fake_quant_enabled={}, observer_enabled={},\ - quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, \ - scale={}, zero_point={}'.format( - self.fake_quant_enabled, self.observer_enabled, - self.quant_min, self.quant_max, - self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point) + return 'fake_quant_enabled={}, observer_enabled={}, ' \ + 'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \ + 'scale={}, zero_point={}'.format( + self.fake_quant_enabled, self.observer_enabled, + self.quant_min, self.quant_max, + self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point) def _save_to_state_dict(self, destination, prefix, keep_vars): # We cannot currently register scalar values as buffers, so need to manually @@ -137,11 +176,69 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, super(FakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) +class FixedQParamsFakeQuantize(FakeQuantizeBase): + """ Simulate quantize and dequantize with fixed quantization + parameters in training time. Only per tensor quantization + is supported. + Args: + `scale` (float): fixed scale for the fake quantize module + `zero_point` (int): fixed zero point for the fake quantize module + `dtype`, `qscheme`, `quant_min`, `quant_max` + """ + + scale: torch.Tensor + zero_point: torch.Tensor + + def __init__(self, + scale, + zero_point, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + quant_min=0, + quant_max=255): + super().__init__() + assert quant_min <= quant_max, 'quant_min should be less than or equal to quant_max' + self.quant_min = quant_min + self.quant_max = quant_max + self.register_buffer('scale', torch.tensor([scale])) + self.register_buffer('zero_point', torch.tensor([zero_point])) + self.dtype = dtype + self.qscheme = qscheme + assert _is_per_tensor(self.qscheme), 'Only per tensor quantization is supported' + \ + ' FixedQParamsFakeQuantize module, got qscheme:' + str(self.qscheme) + + def forward(self, X): + if self.fake_quant_enabled[0] == 1: + X = torch.fake_quantize_per_tensor_affine(X, float(self.scale), + int(self.zero_point), self.quant_min, + self.quant_max) + return X + + @torch.jit.export + def calculate_qparams(self): + return self.scale, self.zero_point + + @torch.jit.export + def extra_repr(self): + return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \ + 'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format( + self.fake_quant_enabled, self.observer_enabled, + self.scale, self.zero_point, self.dtype, + self.quant_min, self.quant_max, self.qscheme) + + default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True) default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False) +# TODO(future PR): remove these defaults and enforce activation functions +# to explicitly specify their output range +default_symmetric_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args( + scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255) +default_affine_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args( + scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255) + default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver, quant_min=-128, quant_max=127, @@ -167,17 +264,17 @@ def _is_fake_quant_script_module(mod): return False def disable_fake_quant(mod): - if type(mod) == FakeQuantize or _is_fake_quant_script_module(mod): + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): mod.disable_fake_quant() def enable_fake_quant(mod): - if type(mod) == FakeQuantize or _is_fake_quant_script_module(mod): + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): mod.enable_fake_quant() def disable_observer(mod): - if type(mod) == FakeQuantize or _is_fake_quant_script_module(mod): + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): mod.disable_observer() def enable_observer(mod): - if type(mod) == FakeQuantize or _is_fake_quant_script_module(mod): + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): mod.enable_observer() diff --git a/torch/quantization/fuse_modules.py b/torch/quantization/fuse_modules.py index d8f06e5387268..5ccebd973fd3a 100644 --- a/torch/quantization/fuse_modules.py +++ b/torch/quantization/fuse_modules.py @@ -28,7 +28,7 @@ def _set_module(model, submodule_key, module): setattr(cur_mod, tokens[-1], module) -def fuse_known_modules(mod_list): +def fuse_known_modules(mod_list, additional_fuser_method_mapping=None): r"""Returns a list of modules that fuses the operations specified in the input module list. @@ -41,7 +41,7 @@ def fuse_known_modules(mod_list): the fused operation. The rest of the elements are set to nn.Identity() """ types = tuple(type(m) for m in mod_list) - fuser_method = get_fuser_method(types) + fuser_method = get_fuser_method(types, additional_fuser_method_mapping) if fuser_method is None: raise NotImplementedError("Cannot fuse modules: {}".format(types)) new_mod : List[Optional[nn.Module]] = [None] * len(mod_list) @@ -64,20 +64,22 @@ def fuse_known_modules(mod_list): return new_mod -def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules): - +def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None): + if fuse_custom_config_dict is None: + fuse_custom_config_dict = {} + additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {}) mod_list = [] for item in modules_to_fuse: mod_list.append(_get_module(model, item)) # Fuse list of modules - new_mod_list = fuser_func(mod_list) + new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping) # Replace original module list with fused module list for i, item in enumerate(modules_to_fuse): _set_module(model, item, new_mod_list[i]) -def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules): +def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None): r"""Fuses a list of modules into a single module Fuses only the following sequence of modules: @@ -91,7 +93,7 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_mo with the fused module, replacing the rest of the modules with identity. - Arguments: + Args: model: Model containing the modules to be fused modules_to_fuse: list of list of module names to fuse. Can also be a list of strings if there is only a single list of modules to fuse. @@ -101,6 +103,18 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_mo of the same length. For example, fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()] Defaults to torch.quantization.fuse_known_modules + `fuse_custom_config_dict`: custom configuration for fusion + + .. code-block:: python + + # Example of fuse_custom_config_dict + fuse_custom_config_dict = { + # Additional fuser_method mapping + "additional_fuser_method_mapping": { + (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn + }, + } + Returns: model with fused modules. A new copy is created if inplace=True. @@ -124,9 +138,9 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_mo if all(isinstance(module_element, str) for module_element in modules_to_fuse): # Handle case of modules_to_fuse being a list - _fuse_modules(model, modules_to_fuse, fuser_func) + _fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict) else: # Handle case of modules_to_fuse being a list of lists for module_list in modules_to_fuse: - _fuse_modules(model, module_list, fuser_func) + _fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict) return model diff --git a/torch/quantization/fuser_method_mappings.py b/torch/quantization/fuser_method_mappings.py index 72ad5a7bcc718..d098620ae663f 100644 --- a/torch/quantization/fuser_method_mappings.py +++ b/torch/quantization/fuser_method_mappings.py @@ -3,6 +3,8 @@ from typing import Union, Callable, Tuple, Dict, Optional, Type +from .utils import get_combined_dict + def fuse_conv_bn(conv, bn): r"""Given the conv and bn modules, fuses them and returns the fused module @@ -19,14 +21,21 @@ def fuse_conv_bn(conv, bn): assert(conv.training == bn.training),\ "Conv and BN both must be in the same mode (train or eval)." - is_3d = isinstance(conv, nn.Conv3d) + fused_module_class_map = { + nn.Conv1d: nni.ConvBn1d, + nn.Conv2d: nni.ConvBn2d, + nn.Conv3d: nni.ConvBn3d, + } if conv.training: assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d' assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True' assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True' - return nni.ConvBn3d(conv, bn) if is_3d \ - else nni.ConvBn2d(conv, bn) + fused_module_class = fused_module_class_map.get((type(conv)), None) + if fused_module_class is not None: + return fused_module_class(conv, bn) + else: + raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn))) else: return nn.utils.fuse_conv_bn_eval(conv, bn) @@ -48,13 +57,14 @@ def fuse_conv_bn_relu(conv, bn, relu): fused_module : Optional[Type[nn.Sequential]] = None if conv.training: map_to_fused_module_train = { + nn.Conv1d: nni.ConvBnReLU1d, nn.Conv2d: nni.ConvBnReLU2d, nn.Conv3d: nni.ConvBnReLU3d, } assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm' assert bn.affine, 'Only support fusing BatchNorm with affine set to True' assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True' - fused_module = map_to_fused_module_train.get(type(conv)) + fused_module = map_to_fused_module_train.get(type(conv), None) if fused_module is not None: return fused_module(conv, bn, relu) else: @@ -65,14 +75,14 @@ def fuse_conv_bn_relu(conv, bn, relu): nn.Conv2d: nni.ConvReLU2d, nn.Conv3d: nni.ConvReLU3d, } - fused_module = map_to_fused_module_eval[type(conv)] + fused_module = map_to_fused_module_eval.get(type(conv), None) if fused_module is not None: fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) return fused_module(fused_conv, relu) else: raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu))) -OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = { +DEFAULT_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = { (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn, (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn, @@ -87,15 +97,14 @@ def fuse_conv_bn_relu(conv, bn, relu): (nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d, } -def register_fuser_method(op_list, fuser_method): - ''' Register a fuser method for a tuple of ops, will be called - during fusion step - ''' - assert isinstance(op_list, tuple), 'op list must be a tuple' - OP_LIST_TO_FUSER_METHOD[op_list] = fuser_method - -def get_fuser_method(op_list): +def get_fuser_method(op_list, additional_fuser_method_mapping=None): ''' Get fuser method for the given list of module types, return None if fuser method does not exist ''' - return OP_LIST_TO_FUSER_METHOD.get(op_list, None) + if additional_fuser_method_mapping is None: + additional_fuser_method_mapping = dict() + all_mappings = get_combined_dict(DEFAULT_OP_LIST_TO_FUSER_METHOD, + additional_fuser_method_mapping) + fuser_method = all_mappings.get(op_list, None) + assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list) + return fuser_method diff --git a/torch/quantization/fx/fuse.py b/torch/quantization/fx/fuse.py index 4e8103d710159..59e3851dcd57b 100644 --- a/torch/quantization/fx/fuse.py +++ b/torch/quantization/fx/fuse.py @@ -1,33 +1,48 @@ -from torch.fx import ( - GraphModule +from typing import Dict, Any + +from torch.fx import ( # type: ignore + GraphModule, + Node, + map_arg ) -from torch.fx.graph import ( - Graph, - map_arg, +from torch.fx.graph import Graph + +from ..utils import ( + get_combined_dict ) from .pattern_utils import ( is_match, - get_fusion_patterns, + get_default_fusion_patterns, ) from .fusion_patterns import * # noqa: F401 -import copy +from .quantization_types import Pattern + +from typing import Callable, Tuple + + class Fuser: - def fuse(self, model, inplace=False): - if not inplace: - model = copy.deepcopy(model) + def fuse(self, model: GraphModule, + fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule: + if fuse_custom_config_dict is None: + fuse_custom_config_dict = {} + input_root = model input_graph = model.graph self.modules = dict(input_root.named_modules()) - fusion_patterns = get_fusion_patterns() + additional_fusion_patterns = \ + fuse_custom_config_dict.get("additional_fusion_pattern", {}) + fusion_patterns = get_combined_dict( + get_default_fusion_patterns(), additional_fusion_patterns) # find fusion - fusion_pairs = self._find_matches(input_root, input_graph, fusion_patterns) + fusion_pairs = self._find_matches( + input_root, input_graph, fusion_patterns) self.fused_graph = Graph() - env = {} + env: Dict[Any, Any] = {} def load_arg(a): return map_arg(a, lambda node: env[node.name]) @@ -35,18 +50,21 @@ def load_arg(a): for node in input_graph.nodes: root_node, obj = fusion_pairs.get(node.name, (None, None)) if root_node is node: + assert obj is not None env[node.name] = obj.fuse(self, load_arg) elif root_node is None: env[node.name] = self.fused_graph.node_copy(node, load_arg) # node matched in patterns and is not root is removed here - self.fused_graph.output(load_arg(input_graph.result)) model = GraphModule(input_root, self.fused_graph) return model - def _find_matches(self, root, graph, patterns): + def _find_matches( + self, root: GraphModule, graph: Graph, + patterns: Dict[Pattern, Callable] + ) -> Dict[str, Tuple[Node, FuseHandler]]: modules = dict(root.named_modules()) - match_map = {} # node name -> (root_node, match_value?) + match_map : Dict[str, Tuple[Node, FuseHandler]] = {} # node name -> (root_node, match_value) def apply_match(pattern, node, match): if isinstance(pattern, tuple): diff --git a/torch/quantization/fx/fusion_patterns.py b/torch/quantization/fx/fusion_patterns.py index fe5631d854828..1749484fccec4 100644 --- a/torch/quantization/fx/fusion_patterns.py +++ b/torch/quantization/fx/fusion_patterns.py @@ -1,42 +1,70 @@ import torch +from torch.fx.graph import Node from .pattern_utils import ( register_fusion_pattern, ) from .utils import _parent_name +from .quantization_types import QuantizerCls from ..fuser_method_mappings import get_fuser_method +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict # --------------------- -# Fusion Patterns +# Fusion Pattern Registrations # --------------------- -@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d)) +# Base Pattern Handler +class FuseHandler(ABC): + """ Base handler class for the fusion patterns + """ + def __init__(self, quantizer: QuantizerCls, node: Node): + pass + + @abstractmethod + def fuse(self, quantizer: QuantizerCls, load_arg: Callable, + fuse_custom_config_dict: Dict[str, Any] = None) -> Node: + pass + @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv3d)) @register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv1d)) @register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv2d)) @register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv3d)) +@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Conv1d)) +@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d)) +@register_fusion_pattern((torch.nn.BatchNorm3d, torch.nn.Conv3d)) +@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm1d, torch.nn.Conv1d))) @register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))) +@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm3d, torch.nn.Conv3d))) +@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d))) @register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d))) -class ConvBNReLUFusion(): - def __init__(self, quantizer, node): - super().__init__() +@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d))) +class ConvBNReLUFusion(FuseHandler): + def __init__(self, quantizer: QuantizerCls, node: Node): + super().__init__(quantizer, node) self.relu_node = None self.bn_node = None if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \ (node.op == 'call_module' and type(quantizer.modules[node.target]) == torch.nn.ReLU): self.relu_node = node + assert isinstance(node.args[0], Node) node = node.args[0] assert node.op == 'call_module' if type(quantizer.modules[node.target]) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]: self.bn_node = node self.bn = quantizer.modules[self.bn_node.target] + assert isinstance(node.args[0], Node) node = node.args[0] assert node.op == 'call_module' self.conv_node = node self.conv = quantizer.modules[self.conv_node.target] - def fuse(self, quantizer, load_arg): + def fuse(self, quantizer: QuantizerCls, load_arg: Callable, + fuse_custom_config_dict: Dict[str, Any] = None) -> Node: + if fuse_custom_config_dict is None: + fuse_custom_config_dict = {} + additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {}) op_list = [] if self.relu_node is not None: # since relu can be used multiple times, we'll need to create a relu module for each match @@ -60,10 +88,11 @@ def fuse(self, quantizer, load_arg): op_list.reverse() op_type_list = tuple(type(m) for m in op_list) conv_parent_name, conv_name = _parent_name(self.conv_node.target) - fuser_method = get_fuser_method(op_type_list) + fuser_method = get_fuser_method(op_type_list, additional_fuser_method_mapping) if fuser_method is None: - raise NotImplementedError("Cannot fuse modules: {}".format(types)) - setattr(quantizer.modules[conv_parent_name], conv_name, fuser_method(*op_list)) + raise NotImplementedError("Cannot fuse modules: {}".format(op_type_list)) + fused = fuser_method(*op_list) + setattr(quantizer.modules[conv_parent_name], conv_name, fused) # TODO: do we need to make sure bn is only used once? if self.bn_node is not None: @@ -74,22 +103,25 @@ def fuse(self, quantizer, load_arg): @register_fusion_pattern((torch.nn.functional.relu, torch.nn.Linear)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.Linear)) -@register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm1d)) -@register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm1d)) @register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm2d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm2d)) @register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm3d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm3d)) -class ModuleReLUFusion(): - def __init__(self, quantizer, node): - super().__init__() +class ModuleReLUFusion(FuseHandler): + def __init__(self, quantizer: QuantizerCls, node: Node): + super().__init__(quantizer, node) self.relu_node = node + assert isinstance(node.args[0], Node) node = node.args[0] assert node.op == 'call_module' self.module_node = node self.module = quantizer.modules[self.module_node.target] - def fuse(self, quantizer, load_arg): + def fuse(self, quantizer: QuantizerCls, load_arg: Callable, + fuse_custom_config_dict: Dict[str, Any] = None) -> Node: + if fuse_custom_config_dict is None: + fuse_custom_config_dict = {} + additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {}) op_list = [] # since relu can be used multiple times, we'll need to create a relu module for each match if self.relu_node.op == 'call_module': @@ -104,6 +136,6 @@ def fuse(self, quantizer, load_arg): op_list.reverse() op_type_list = tuple(type(m) for m in op_list) module_parent_name, module_name = _parent_name(self.module_node.target) - fuser_method = get_fuser_method(op_type_list) + fuser_method = get_fuser_method(op_type_list, additional_fuser_method_mapping) setattr(quantizer.modules[module_parent_name], module_name, fuser_method(*op_list)) return quantizer.fused_graph.node_copy(self.module_node, load_arg) diff --git a/torch/quantization/fx/observed_module.py b/torch/quantization/fx/observed_module.py new file mode 100644 index 0000000000000..808a3b36fb4a7 --- /dev/null +++ b/torch/quantization/fx/observed_module.py @@ -0,0 +1,53 @@ +import torch +import copy +from torch.fx import GraphModule # type: ignore +from torch.fx.graph import Graph +from typing import Union, Dict, Any, List + +class ObservedGraphModule(GraphModule): + + def get_preserved_attr_names(self) -> List[str]: + return ['_activation_post_process_map', + '_patterns', + '_qconfig_map', + '_prepare_custom_config_dict'] + + def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph): + preserved_attrs = dict() + for attr in self.get_preserved_attr_names(): + preserved_attrs[attr] = getattr(root, attr) + super().__init__(root, graph) + for attr in preserved_attrs: + setattr(self, attr, preserved_attrs[attr]) + + # GraphModule does not copy attributes which are not in the __dict__ + # of vanilla nn.Module. So, we override __deepcopy__ in order + # to copy the quantization specific attributes correctly. + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return ObservedGraphModule(fake_mod, self.graph) + +def mark_observed_module(module: GraphModule) -> GraphModule: + return ObservedGraphModule(module, module.graph) + +def is_observed_module(module: Any) -> bool: + return isinstance(module, ObservedGraphModule) + +class ObservedStandaloneGraphModule(ObservedGraphModule): + def get_preserved_attr_names(self) -> List[str] : + return super().get_preserved_attr_names() + [ + "_standalone_module_input_quantized_idxs", + "_standalone_module_output_quantized_idxs" + ] + + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return ObservedStandaloneGraphModule(fake_mod, self.graph) + +def mark_observed_standalone_module(module: GraphModule) -> GraphModule: + return ObservedStandaloneGraphModule(module, module.graph) + +def is_observed_standalone_module(module: Any) -> bool: + return isinstance(module, ObservedStandaloneGraphModule) diff --git a/torch/quantization/fx/pattern_utils.py b/torch/quantization/fx/pattern_utils.py index ae9e92ccda268..fe13d0a3fed73 100644 --- a/torch/quantization/fx/pattern_utils.py +++ b/torch/quantization/fx/pattern_utils.py @@ -1,41 +1,66 @@ import torch import sys from collections import OrderedDict +from typing import Dict, Any + +from .quantization_types import Pattern + +# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency) +QuantizeHandler = Any # pattern for conv bn fusion -FUSION_PATTERNS = OrderedDict() +DEFAULT_FUSION_PATTERNS = OrderedDict() def register_fusion_pattern(pattern): def insert(fn): - FUSION_PATTERNS[pattern] = fn + DEFAULT_FUSION_PATTERNS[pattern] = fn return fn return insert -def get_fusion_patterns(): - return FUSION_PATTERNS +def get_default_fusion_patterns() -> Dict[Pattern, QuantizeHandler]: + return DEFAULT_FUSION_PATTERNS + +DEFAULT_QUANTIZATION_PATTERNS = OrderedDict() +# a map from pattern to activation_post_process(observer/fake_quant) consstructor for output activation +# e.g. pattern: torch.sigmoid, +# output_activation_post_process: default_affine_fixed_qparam_fake_quant +DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP = dict() -QUANTIZATION_PATTERNS = OrderedDict() # Register pattern for both static quantization and qat -def register_quant_pattern(pattern): +def register_quant_pattern(pattern, output_activation_post_process=None): def insert(fn): - QUANTIZATION_PATTERNS[pattern] = fn + DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn + if output_activation_post_process is not None: + DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP[pattern] = output_activation_post_process return fn return insert # Get patterns for both static quantization and qat -def get_quant_patterns(): - return QUANTIZATION_PATTERNS +def get_default_quant_patterns() -> Dict[Pattern, QuantizeHandler]: + return DEFAULT_QUANTIZATION_PATTERNS -DYNAMIC_QUANTIZATION_PATTERNS = OrderedDict() -# Register pattern for dynamic quantization -def register_dynamic_quant_pattern(pattern): +# a map from pattern to output activation post process constructor +# e.g. torch.sigmoid -> default_affine_fixed_qparam_fake_quant +def get_default_output_activation_post_process_map() -> Dict[Pattern, torch.quantization.observer.ObserverBase]: + return DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP + +# a set of QuantizeHandler classes that are not observed +# we'll skip inserting observers for input and output for these QuantizeHandlers +# used for ops that only supports dynamic/weight only quantization +DEFAULT_NOT_OBSERVED_QUANTIZE_HANDLER = set() +def mark_input_output_not_observed(): def insert(fn): - DYNAMIC_QUANTIZATION_PATTERNS[pattern] = fn + DEFAULT_NOT_OBSERVED_QUANTIZE_HANDLER.add(fn) return fn return insert -# Get patterns for dynamic quantization -def get_dynamic_quant_patterns(): - return DYNAMIC_QUANTIZATION_PATTERNS +def input_output_observed(qh): + return type(qh) not in DEFAULT_NOT_OBSERVED_QUANTIZE_HANDLER + + +class MatchAllNode: + """ A node pattern that matches all nodes + """ + pass # Example use of register pattern function: # @register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))) @@ -60,7 +85,10 @@ def is_match(modules, node, pattern, max_uses=sys.maxsize): self_match = pattern arg_matches = [] - if node.uses > max_uses: + if isinstance(self_match, type) and issubclass(self_match, MatchAllNode): + return True + + if len(node.users) > max_uses: return False if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module): @@ -74,6 +102,9 @@ def is_match(modules, node, pattern, max_uses=sys.maxsize): elif node.target is getattr: if node.args[1] != pattern[1]: return False + elif isinstance(self_match, str): + if node.op != 'call_method' or node.target != self_match: + return False elif node.target != self_match: return False diff --git a/torch/quantization/fx/qconfig_utils.py b/torch/quantization/fx/qconfig_utils.py new file mode 100644 index 0000000000000..42d4c1a3f3b73 --- /dev/null +++ b/torch/quantization/fx/qconfig_utils.py @@ -0,0 +1,98 @@ +import torch +from collections import OrderedDict +from typing import Union, Callable, Any +import re + +from .utils import _parent_name + +QConfigAny = Union[torch.quantization.QConfig, + torch.quantization.QConfigDynamic, None] + +def get_flattened_qconfig_dict(qconfig_dict): + """ flatten the global, object_type and module_name qconfig + to the same qconfig_dict so that it can be used by + propagate_qconfig_ function. + "module_name_regex" is ignored for now since it's not supported + in propagate_qconfig_, but it can be fixed later. + + For example: + Input: { + "": qconfig, + "object_type": [ + (torch.add, qconfig) + ], + "module_name": [ + ("conv", qconfig) + ] + } + + Output: { + "": qconfig, + torch.add: qconfig, + "conv": qconfig + } + """ + flattened = dict() + if '' in qconfig_dict: + flattened[''] = qconfig_dict[''] + + def flatten_key(key): + if key in qconfig_dict: + for obj, qconfig in qconfig_dict[key]: + flattened[obj] = qconfig + + flatten_key('object_type') + flatten_key('module_name') + return flattened + +def convert_dict_to_ordered_dict(qconfig_dict): + """ Convert dict in qconfig_dict to ordered dict + """ + # convert a qconfig list for a type to OrderedDict + def _convert_to_ordered_dict(key, qconfig_dict): + qconfig_dict[key] = OrderedDict(qconfig_dict.get(key, [])) + + _convert_to_ordered_dict('object_type', qconfig_dict) + _convert_to_ordered_dict('module_name_regex', qconfig_dict) + _convert_to_ordered_dict('module_name', qconfig_dict) + +def get_object_type_qconfig( + qconfig_dict: Any, + object_type: Union[Callable, str], + fallback_qconfig: QConfigAny) -> QConfigAny: + # object_type can be + # 1. module type (call_module) + # 2. function (call_function) + # 3. string (call_method) + return qconfig_dict['object_type'].get( + object_type, fallback_qconfig) + +def get_module_name_regex_qconfig(qconfig_dict, module_name, fallback_qconfig): + for regex_pattern, qconfig in \ + qconfig_dict['module_name_regex'].items(): + if re.match(regex_pattern, module_name): + # first match wins + return qconfig + return fallback_qconfig + +def get_module_name_qconfig(qconfig_dict, module_name, fallback_qconfig): + if module_name == '': + # module name qconfig not found + return fallback_qconfig + if module_name in qconfig_dict['module_name']: + return qconfig_dict['module_name'][module_name] + else: + parent, _ = _parent_name(module_name) + return get_module_name_qconfig(qconfig_dict, parent, fallback_qconfig) + +# get qconfig for module_name, +# fallback to module_name_regex_qconfig, module_type_qconfig, +# global_qconfig if necessary +def get_qconfig(qconfig_dict, module_type, module_name, global_qconfig): + module_type_qconfig = get_object_type_qconfig( + qconfig_dict, module_type, global_qconfig) + module_name_regex_qconfig = get_module_name_regex_qconfig( + qconfig_dict, module_name, module_type_qconfig) + module_name_qconfig = get_module_name_qconfig( + qconfig_dict, module_name, module_name_regex_qconfig) + return module_name_qconfig diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index ab85c9a9daffc..fb5bef0bd0add 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -2,25 +2,45 @@ from torch.fx.graph import ( Node, ) +import torch.nn.quantized as nnq +import torch.nn.quantized.dynamic as nnqd +from torch.quantization import ( + default_affine_fixed_qparams_fake_quant, + default_symmetric_fixed_qparams_fake_quant, +) + from ..quantization_mappings import ( get_static_quant_module_class, + get_dynamic_quant_module_class, get_quantized_operator, ) -from ..custom_module_class_mappings import ( - get_quantized_custom_module_class, +from ..utils import ( + get_swapped_custom_module_class, + activation_is_statically_quantized, + weight_is_statically_quantized, + weight_dtype, + get_qconfig_dtypes, ) + from .pattern_utils import ( register_quant_pattern, - register_dynamic_quant_pattern, + mark_input_output_not_observed, ) + from .utils import ( _parent_name, quantize_node, get_per_tensor_qparams, + get_linear_prepack_op_for_dtype, ) +from .quantization_types import QuantizerCls + from abc import ABC, abstractmethod import operator +import warnings + +from typing import Any, Callable, Dict # ------------------------- # Pattern Registrations @@ -32,47 +52,61 @@ class QuantizeHandler(ABC): """ Base handler class for the quantizer patterns """ - def __init__(self, quantizer, node): + def __init__(self, quantizer: QuantizerCls, node: Node): """ Records pattern information in __init__, which will be used in convert """ # this is an indicator of whether all the inputs are Node or not # since some op might be quantized differently depending on whether # all inputs are tensors or not, e.g. add/mul - self.all_nodes = True + self.num_node_args = len(node.args) + self.all_node_args = True @abstractmethod - def convert(self, quantizer, node, load_arg, debug=False): + def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: """ Convert the given node to a quantized node and insert it to the quantized graph """ return NotImplemented @register_quant_pattern(operator.add) +@register_quant_pattern(torch.add) @register_quant_pattern((torch.nn.ReLU, operator.add)) +@register_quant_pattern((torch.nn.ReLU, torch.add)) @register_quant_pattern((torch.nn.functional.relu, operator.add)) +@register_quant_pattern((torch.nn.functional.relu, torch.add)) class Add(QuantizeHandler): - def __init__(self, quantizer, node): + def __init__(self, quantizer: QuantizerCls, node: Node): super().__init__(quantizer, node) self.relu_node = None if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \ (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)): self.relu_node = node - node = node.args[0] - assert node.op == 'call_function' and node.target == operator.add + node = node.args[0] # type: ignore + assert node.op == 'call_function' and node.target in [operator.add, torch.add] self.add_node = node - self.all_nodes = all([isinstance(a, Node) for a in self.add_node.args[:2]]) + self.num_node_args = len([a for a in self.add_node.args[:2] if isinstance(a, Node)]) - def convert(self, quantizer, node, load_arg, debug=False): - if not self.all_nodes: + def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: + if self.num_node_args == 1: # add scalar if self.relu_node is not None: op = torch.ops.quantized.add_relu else: op = torch.ops.quantized.add + + if isinstance(self.add_node.args[0], Node): + quantized_index = 0 + else: + quantized_index = 1 + return quantizer.quantized_graph.create_node( 'call_function', op, - load_arg(quantized=[0])(self.add_node.args), self.add_node.kwargs) + load_arg(quantized=[quantized_index])(self.add_node.args), self.add_node.kwargs) else: activation_post_process = quantizer.activation_post_process_map[node.name] scale, zero_point = activation_post_process.calculate_qparams() @@ -82,35 +116,46 @@ def convert(self, quantizer, node, load_arg, debug=False): op = torch.ops.quantized.add_relu else: op = torch.ops.quantized.add - kwargs = self.add_node.kwargs - kwargs.update({'scale': scale, 'zero_point': zero_point}) + kwargs = {**self.add_node.kwargs, 'scale': scale, 'zero_point': zero_point} return quantizer.quantized_graph.create_node( 'call_function', op, load_arg(quantized=True)(self.add_node.args), kwargs) +# TODO: merge with Add @register_quant_pattern(operator.mul) +@register_quant_pattern(torch.mul) @register_quant_pattern((torch.nn.ReLU, operator.mul)) +@register_quant_pattern((torch.nn.ReLU, torch.mul)) @register_quant_pattern((torch.nn.functional.relu, operator.mul)) +@register_quant_pattern((torch.nn.functional.relu, torch.mul)) class Mul(QuantizeHandler): - def __init__(self, quantizer, node): + def __init__(self, quantizer: QuantizerCls, node: Node): super().__init__(quantizer, node) self.relu_node = None if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \ (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)): self.relu_node = node - node = node.args[0] - assert node.op == 'call_function' and node.target == operator.mul + node = node.args[0] # type: ignore + assert node.op == 'call_function' and node.target in [operator.mul, torch.mul] self.mul_node = node - self.all_nodes = all([isinstance(a, Node) for a in self.mul_node.args[:2]]) + self.num_node_args = len([a for a in self.mul_node.args[:2] if isinstance(a, Node)]) - def convert(self, quantizer, node, load_arg, debug=False): - if not self.all_nodes: + def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: + if self.num_node_args == 1: # mul scalar if self.relu_node is not None: op = torch.ops.quantized.mul_relu else: op = torch.ops.quantized.mul + + if isinstance(self.mul_node.args[0], Node): + quantized_index = 0 + else: + quantized_index = 1 + return quantizer.quantized_graph.create_node( - 'call_function', op, load_arg(quantized=[0])(self.mul_node.args), self.mul_node.kwargs) + 'call_function', op, load_arg(quantized=[quantized_index])(self.mul_node.args), self.mul_node.kwargs) else: activation_post_process = quantizer.activation_post_process_map[node.name] scale, zero_point = activation_post_process.calculate_qparams() @@ -120,21 +165,21 @@ def convert(self, quantizer, node, load_arg, debug=False): op = torch.ops.quantized.mul_relu else: op = torch.ops.quantized.mul - kwargs = self.mul_node.kwargs - kwargs.update({'scale': scale, 'zero_point': zero_point}) + kwargs = {**self.mul_node.kwargs, 'scale': scale, 'zero_point': zero_point} return quantizer.quantized_graph.create_node('call_function', op, load_arg(quantized=True)(self.mul_node.args), kwargs) @register_quant_pattern(torch.cat) class Cat(QuantizeHandler): - def convert(self, quantizer, node, load_arg, debug=False): - if not self.all_nodes: + def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: + if not self.all_node_args: return NotImplemented activation_post_process = quantizer.activation_post_process_map[node.name] scale, zero_point = activation_post_process.calculate_qparams() scale = float(scale) zero_point = int(zero_point) - kwargs = load_arg(quantized=False)(node.kwargs) - kwargs.update({'scale': scale, 'zero_point': zero_point}) + kwargs = {**load_arg(quantized=False)(node.kwargs), 'scale': scale, 'zero_point': zero_point} return quantizer.quantized_graph.create_node( 'call_function', torch.ops.quantized.cat, load_arg(quantized=[0])(node.args), kwargs) @@ -148,7 +193,9 @@ def convert(self, quantizer, node, load_arg, debug=False): @register_quant_pattern(torch.nn.intrinsic.ConvReLU1d) @register_quant_pattern(torch.nn.intrinsic.ConvReLU2d) @register_quant_pattern(torch.nn.intrinsic.ConvReLU3d) +@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn1d) @register_quant_pattern(torch.nn.intrinsic.qat.ConvBn2d) +@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU1d) @register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU2d) @register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU2d) @register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv2d)) @@ -157,34 +204,39 @@ def convert(self, quantizer, node, load_arg, debug=False): @register_quant_pattern((torch.nn.ReLU, torch.nn.Conv2d)) @register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv2d)) class ConvRelu(QuantizeHandler): - def __init__(self, quantizer, node): + def __init__(self, quantizer: QuantizerCls, node: Node): super().__init__(quantizer, node) self.relu_node = None if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \ (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)): self.relu_node = node - node = node.args[0] + node = node.args[0] # type: ignore self.conv_node = node if node.op == 'call_module': self.conv = quantizer.modules[self.conv_node.target] - def convert(self, quantizer, node, load_arg, debug=False): + def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: # TODO: debug option for conv module + qconfig = quantizer.qconfig_map[node.name] + activation_statically_quantized = activation_is_statically_quantized(qconfig) + # only static qunatization (for both ptq and qat) is supported for conv + if not activation_statically_quantized: + return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None)) + if self.conv_node.op == 'call_module': # note that relu should already be fused into conv module in the fusion step assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \ 'please make sure to run fusion before prepare' + if convert_custom_config_dict is None: + convert_custom_config_dict = {} + additional_static_quant_mapping = convert_custom_config_dict.get("static", {}) # 1. attach activation post process to module - if type(self.conv) in [ - torch.nn.intrinsic.ConvReLU1d, - torch.nn.intrinsic.ConvReLU2d, - torch.nn.intrinsic.ConvReLU3d - ]: - self.conv[1].activation_post_process = quantizer.activation_post_process_map[node.name] - else: - self.conv.activation_post_process = quantizer.activation_post_process_map[node.name] + self.conv.activation_post_process = quantizer.activation_post_process_map[node.name] # 2. select quantized class - qconv_cls = get_static_quant_module_class(type(self.conv)) + qconv_cls = get_static_quant_module_class( + type(self.conv), additional_static_quant_mapping) quantized = qconv_cls.from_float(self.conv) parent_name, name = _parent_name(self.conv_node.target) setattr(quantizer.modules[parent_name], name, quantized) @@ -193,7 +245,8 @@ def convert(self, quantizer, node, load_arg, debug=False): self.conv_node.target, (load_arg(quantized=True)(self.conv_node.args[0]),), {}) - elif self.conv_node.op == 'call_function': + else: # call_function + assert self.conv_node.op == 'call_function' if self.relu_node is not None: raise Exception("functional conv + relu is not supported yet") if debug: @@ -235,63 +288,107 @@ def convert(self, quantizer, node, load_arg, debug=False): # for error checks @register_quant_pattern((torch.nn.ReLU, torch.nn.Linear)) @register_quant_pattern((torch.nn.functional.relu, torch.nn.Linear)) -class LinearReLU(QuantizeHandler): - def __init__(self, quantizer, node): +class LinearReLUQuantizeHandler(QuantizeHandler): + def __init__(self, quantizer: QuantizerCls, node: Node): super().__init__(quantizer, node) self.relu_node = None if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \ (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)): self.relu_node = node - node = node.args[0] + node = node.args[0] # type: ignore self.linear_node = node if node.op == 'call_module': self.linear = quantizer.modules[self.linear_node.target] - def convert(self, quantizer, node, load_arg, debug=False): + def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: + # Supported combinations are: + # quant_type | activation (compute_type) | weight + # static quint8 qint8 + # dynamic float32 (quint8) qint8 + # weight_only float32 float16 + # tuple (activation_dtype, weight_dtype, compute_dtype) + supported_dtypes = [ + (torch.quint8, torch.qint8, None), + (torch.float32, torch.qint8, torch.quint8), + (torch.float16, torch.float16, None), + ] + qconfig = quantizer.qconfig_map[node.name] + dtypes = get_qconfig_dtypes(qconfig) + if dtypes not in supported_dtypes: + warnings.warn( + "dtype combination: {} is not " + "supported by Linear " + "supported dtype combinations are: {}".format(dtypes, supported_dtypes)) + return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None)) + + activation_statically_quantized = activation_is_statically_quantized(qconfig) # TODO: debug option for linear module if self.linear_node.op == 'call_module': # note that relu should already be fused into conv module in the fusion step assert self.relu_node is None, 'linear module and relu fusion is not executed, ' \ 'please make sure to run fusion before prepare' - # 1. attach activation post process to module - if type(self.linear) == torch.nn.intrinsic.LinearReLU: - self.linear[1].activation_post_process = quantizer.activation_post_process_map[node.name] + # 1. attach output activation post process to linear module + if node.name in quantizer.activation_post_process_map: + # this is the static quantization case + output_activation_post_process = quantizer.activation_post_process_map[node.name] else: - self.linear.activation_post_process = quantizer.activation_post_process_map[node.name] - # 2. select quantized class + output_activation_post_process = None + + if output_activation_post_process: + self.linear.activation_post_process = output_activation_post_process + + # 2. select corresponding quantized linear class for the float linear class if type(self.linear) in [torch.nn.Linear, torch.nn.qat.Linear]: - qlinear = torch.nn.quantized.Linear + qlinear = nnq.Linear if activation_statically_quantized else nnqd.Linear elif type(self.linear) in [torch.nn.intrinsic.LinearReLU, torch.nn.intrinsic.qat.LinearReLU]: + assert activation_statically_quantized, \ + 'Only static quantization is supported for LinearReLU' qlinear = torch.nn.intrinsic.quantized.LinearReLU else: raise Exception("unhandled linear type:", type(self.linear)) quantized = qlinear.from_float(self.linear) parent_name, name = _parent_name(self.linear_node.target) setattr(quantizer.modules[parent_name], name, quantized) + # activation needs to be quantized for static quantization return quantizer.quantized_graph.create_node( 'call_module', - self.linear_node.target, (load_arg(quantized=True)(self.linear_node.args[0]),), {}) - elif self.linear_node.op == 'call_function': + self.linear_node.target, + (load_arg(quantized=activation_statically_quantized)(self.linear_node.args[0]),), {}) + else: # call_function + assert self.linear_node.op == 'call_function' if debug: - args = load_arg(quantized=[0, 1])(self.linear_node.args) + quantized_input_idxs = [] + if activation_statically_quantized: + quantized_input_idxs.append(0) + if weight_is_statically_quantized(qconfig): + quantized_input_idxs.append(1) + args = load_arg(quantized=quantized_input_idxs)(self.linear_node.args) args = load_arg(quantized=False)(self.linear_node.args) kwargs = load_arg(quantized=False)(self.linear_node.kwargs) linear_out = quantizer.quantized_graph.create_node( 'call_function', torch.nn.functional.linear, args, kwargs) - root_module = quantizer.modules[''] - return quantize_node( - root_module, - quantizer.quantized_graph, - linear_out, - quantizer.activation_post_process_map[self.linear_node.name]) - else: - # TODO: this code can be merged with dynamic linear code + if activation_statically_quantized: + # quantize output for statically quantized linear op + root_module = quantizer.modules[''] + return quantize_node( + root_module, + quantizer.quantized_graph, + linear_out, + quantizer.activation_post_process_map[self.linear_node.name]) + else: + # output for dynamically quantized linear op is not quantized + return linear_out + else: # non-debug option # linear args # (x, weight, bias, ...) - args = load_arg(quantized=[0, 1])(self.linear_node.args) - kwargs = load_arg(quantized=False)(self.linear_node.kwargs) + weight_quantized = weight_is_statically_quantized(qconfig) + linear_weight = load_arg(quantized=weight_quantized)(self.linear_node.args[1]) + + # get other arguments + kwargs = {**load_arg(quantized=False)(self.linear_node.kwargs)} # pack weight - weight = load_arg(quantized=True)(self.linear_node.args[1]) bias = None # all args after bias, including bias other_args = load_arg(quantized=False)(self.linear_node.args[2:]) @@ -303,39 +400,45 @@ def convert(self, quantizer, node, load_arg, debug=False): 'expect bias provided as a keyword argument when it is not a positional argument' bias = kwargs['bias'] kwargs.pop('bias') - prepack_args = (weight, bias) + prepack_args = (linear_weight, bias) + prepack_op = get_linear_prepack_op_for_dtype(weight_dtype(qconfig)) packed_weight = quantizer.quantized_graph.create_node( - 'call_function', torch.ops.quantized.linear_prepack, prepack_args, {}) + 'call_function', prepack_op, prepack_args, {}) # construct linear input - linear_input = load_arg(quantized=True)(self.linear_node.args[0]) - activation_post_process = \ - quantizer.activation_post_process_map[self.linear_node.name] - scale, zero_point, _ = get_per_tensor_qparams(activation_post_process) - qlinear_args = (linear_input, packed_weight, scale, zero_point) - return quantizer.quantized_graph.create_node( - 'call_function', torch.ops.quantized.linear, qlinear_args, kwargs) + if activation_statically_quantized: + linear_input = load_arg(quantized=True)(self.linear_node.args[0]) + activation_post_process = \ + quantizer.activation_post_process_map[self.linear_node.name] + scale, zero_point, _ = get_per_tensor_qparams(activation_post_process) + qlinear_args = (linear_input, packed_weight, scale, zero_point) + return quantizer.quantized_graph.create_node( + 'call_function', torch.ops.quantized.linear, qlinear_args, kwargs) + else: + linear_input = load_arg(quantized=False)(self.linear_node.args[0]) + qlinear_args = (linear_input, packed_weight) # type: ignore + return quantizer.quantized_graph.create_node( + 'call_function', torch.ops.quantized.linear_dynamic, qlinear_args, kwargs) @register_quant_pattern(torch.nn.BatchNorm2d) @register_quant_pattern(torch.nn.BatchNorm3d) @register_quant_pattern(torch.nn.intrinsic.BNReLU2d) @register_quant_pattern(torch.nn.intrinsic.BNReLU3d) class BatchNorm(QuantizeHandler): - def __init__(self, quantizer, node): + def __init__(self, quantizer: QuantizerCls, node: Node): super().__init__(quantizer, node) assert node.op == 'call_module' self.bn_node = node self.bn = quantizer.modules[self.bn_node.target] - def convert(self, quantizer, node, load_arg, debug=False): + def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: + if convert_custom_config_dict is None: + convert_custom_config_dict = {} + additional_static_quant_mapping = convert_custom_config_dict.get("static", {}) # 1. attach activation post process to module - activation_post_process = quantizer.activation_post_process_map[node.name] - if type(self.bn) in \ - [torch.nn.intrinsic.BNReLU2d, - torch.nn.intrinsic.BNReLU3d]: - self.bn[1].activation_post_process = activation_post_process - else: - self.bn.activation_post_process = activation_post_process - qbn_cls = get_static_quant_module_class(type(self.bn)) + self.bn.activation_post_process = quantizer.activation_post_process_map[node.name] + qbn_cls = get_static_quant_module_class(type(self.bn), additional_static_quant_mapping) quantized = qbn_cls.from_float(self.bn) parent_name, name = _parent_name(self.bn_node.target) setattr(quantizer.modules[parent_name], name, quantized) @@ -345,12 +448,99 @@ def convert(self, quantizer, node, load_arg, debug=False): load_arg(quantized=[0])(self.bn_node.args), load_arg(quantized=False)(self.bn_node.kwargs)) +@register_quant_pattern(torch.nn.Embedding) +@register_quant_pattern(torch.nn.EmbeddingBag) +@mark_input_output_not_observed() +class Embedding(QuantizeHandler): + def __init__(self, quantizer: QuantizerCls, node: Node): + super().__init__(quantizer, node) + + def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: + # Supported combinations are: + # quant_type | activation | weight | activation_compute_type + # weight_only | float32 | quint8 | None + # weight_only | float32 | quint4x2 | None + # tuple (activation_dtype, weight_dtype, compute_dtype) + supported_dtypes = [ + (torch.float32, torch.quint8, None), + (torch.float32, torch.quint4x2, None), + ] + assert node.op == 'call_module' + emb_node = node + qconfig = quantizer.qconfig_map[node.name] + dtypes = get_qconfig_dtypes(qconfig) + if dtypes not in supported_dtypes: + warnings.warn( + "dtype combination: {} is not " + "supported by Embedding/EmbeddingBag, " + "supported dtype combinations are: {}".format(dtypes, supported_dtypes)) + return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None)) + + emb = quantizer.modules[emb_node.target] + qemb = get_static_quant_module_class(type(emb)) + quantized = qemb.from_float(emb) + parent_name, name = _parent_name(emb_node.target) + setattr(quantizer.modules[parent_name], name, quantized) + return quantizer.quantized_graph.create_node( + 'call_module', + emb_node.target, + load_arg(quantized=False)(emb_node.args), + load_arg(quantized=False)(emb_node.kwargs)) + +# TODO (maybe): merge with embedding quantize handler +@register_quant_pattern(torch.nn.GRUCell) +@register_quant_pattern(torch.nn.LSTMCell) +@register_quant_pattern(torch.nn.RNNCell) +@register_quant_pattern(torch.nn.LSTM) +@mark_input_output_not_observed() +class RNNDynamic(QuantizeHandler): + def __init__(self, quantizer: QuantizerCls, node: Node): + super().__init__(quantizer, node) + + def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: + # Supported combinations are: + # quant_type | activation | weight | activation_compute_type + # dynamic | float32 | qint8 | quint8 + # dynamic | float16 | float16 | None + # tuple (activation_dtype, weight_dtype, compute_dtype) + supported_dtypes = [ + (torch.float32, torch.qint8, torch.quint8), + (torch.float16, torch.float16, None), + ] + assert node.op == 'call_module' + qconfig = quantizer.qconfig_map[node.name] + dtypes = get_qconfig_dtypes(qconfig) + if dtypes not in supported_dtypes: + warnings.warn( + "dtype combination: {} is not " + "supported by Embedding/EmbeddingBag, " + "supported dtype combinations are: {}".format(dtypes, supported_dtypes)) + return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None)) + + module = quantizer.modules[node.target] + qmodule_cls = get_dynamic_quant_module_class(type(module)) + qmodule = qmodule_cls.from_float(module) + parent_name, name = _parent_name(node.target) + setattr(quantizer.modules[parent_name], name, qmodule) + return quantizer.quantized_graph.create_node( + 'call_module', + node.target, + load_arg(quantized=False)(node.args), + load_arg(quantized=False)(node.kwargs)) + ARGS_TO_SKIP = { torch._ops.ops.quantized.hardswish: ['inplace'], torch._ops.ops.quantized.instance_norm: ['running_mean', 'running_var', 'use_input_stats', 'momentum'], } +@register_quant_pattern(torch.nn.ConvTranspose1d) +@register_quant_pattern(torch.nn.ConvTranspose2d) @register_quant_pattern(torch.nn.ELU) +@register_quant_pattern(torch.nn.LeakyReLU) @register_quant_pattern(torch.nn.Hardswish) @register_quant_pattern(torch.nn.InstanceNorm1d) @register_quant_pattern(torch.nn.InstanceNorm2d) @@ -359,19 +549,26 @@ def convert(self, quantizer, node, load_arg, debug=False): @register_quant_pattern(torch.nn.functional.hardswish) @register_quant_pattern(torch.nn.functional.instance_norm) @register_quant_pattern(torch.nn.functional.layer_norm) +@register_quant_pattern(torch.nn.functional.leaky_relu) class DefaultNode(QuantizeHandler): ''' Common quantized op, first input and first output will be quantized ''' - def convert(self, quantizer, node, load_arg, debug=False): - if not self.all_nodes: + def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: + if not self.all_node_args: return NotImplemented assert node.op in ['call_module', 'call_function'], 'Only call_module and ' + \ 'call_function are handled in DefaultNode' + if convert_custom_config_dict is None: + convert_custom_config_dict = {} + additional_static_quant_mapping = convert_custom_config_dict.get("static", {}) activation_post_process = quantizer.activation_post_process_map[node.name] if node.op == 'call_module': module = quantizer.modules[node.target] module.activation_post_process = activation_post_process - quantized_module_cls = get_static_quant_module_class(type(module)) + quantized_module_cls = get_static_quant_module_class( + type(module), additional_static_quant_mapping) quantized_module = quantized_module_cls.from_float(module) parent_name, name = _parent_name(node.target) setattr(quantizer.modules[parent_name], name, quantized_module) @@ -381,39 +578,60 @@ def convert(self, quantizer, node, load_arg, debug=False): load_arg(quantized=[0])(node.args), load_arg(quantized=False)(node.kwargs)) else: + assert node.op == "call_function" # call_function scale, zero_point = activation_post_process.calculate_qparams() scale = float(scale) zero_point = int(zero_point) + assert not isinstance(node.target, str), "Expecting node.target for " + "call_function to be a function instead of a string" quantized_op = get_quantized_operator(node.target) args = load_arg(quantized=[0])(node.args) - kwargs = load_arg(quantized=False)(node.kwargs) - kwargs.update({'output_scale': scale, 'output_zero_point': zero_point}) + kwargs = {**load_arg(quantized=False)(node.kwargs), "output_scale": scale, "output_zero_point": zero_point} if quantized_op in ARGS_TO_SKIP: args_to_skip = ARGS_TO_SKIP[quantized_op] for arg in args_to_skip: if arg in kwargs: kwargs.pop(arg) return quantizer.quantized_graph.create_node( - 'call_function', quantized_op, args, kwargs) + "call_function", quantized_op, args, kwargs) # TODO: elu is using scale/zero_point instead of output_scale, output_zero_point @register_quant_pattern(torch.nn.functional.elu) class ELU(QuantizeHandler): - def convert(self, quantizer, node, load_arg, debug=False): + def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: activation_post_process = quantizer.activation_post_process_map[node.name] scale, zero_point = activation_post_process.calculate_qparams() scale = float(scale) zero_point = int(zero_point) quantized_op = get_quantized_operator(node.target) args = load_arg(quantized=[0])(node.args) - kwargs = load_arg(quantized=False)(node.kwargs) - kwargs.update({'output_scale': scale, 'output_zero_point': zero_point}) + kwargs = {**load_arg(quantized=False)(node.kwargs), 'output_scale': scale, 'output_zero_point': zero_point} kwargs.pop('inplace') return quantizer.quantized_graph.create_node( 'call_function', quantized_op, args, kwargs) +@register_quant_pattern(torch.nn.Hardsigmoid, default_affine_fixed_qparams_fake_quant) +@register_quant_pattern(torch.nn.functional.hardsigmoid, default_affine_fixed_qparams_fake_quant) +@register_quant_pattern('hardsigmoid', default_affine_fixed_qparams_fake_quant) +@register_quant_pattern('hardsigmoid_', default_affine_fixed_qparams_fake_quant) +@register_quant_pattern(torch.nn.Sigmoid, default_affine_fixed_qparams_fake_quant) +@register_quant_pattern(torch.sigmoid, default_affine_fixed_qparams_fake_quant) +@register_quant_pattern('sigmoid', default_affine_fixed_qparams_fake_quant) +@register_quant_pattern('sigmoid_', default_affine_fixed_qparams_fake_quant) +@register_quant_pattern(torch.nn.Tanh, default_symmetric_fixed_qparams_fake_quant) +@register_quant_pattern(torch.tanh, default_symmetric_fixed_qparams_fake_quant) +@register_quant_pattern('tanh', default_symmetric_fixed_qparams_fake_quant) +@register_quant_pattern('tanh_', default_symmetric_fixed_qparams_fake_quant) +class FixedQParamsOpQuantizeHandler(QuantizeHandler): + def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: + return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None)) + # these ops have quantized equivalents that do not need any extra information @register_quant_pattern(torch.nn.AdaptiveAvgPool1d) @register_quant_pattern(torch.nn.AdaptiveAvgPool2d) @@ -422,25 +640,19 @@ def convert(self, quantizer, node, load_arg, debug=False): @register_quant_pattern(torch.nn.AvgPool2d) @register_quant_pattern(torch.nn.AvgPool3d) @register_quant_pattern(torch.nn.Dropout) -@register_quant_pattern(torch.nn.Hardsigmoid) @register_quant_pattern(torch.nn.Hardtanh) -@register_quant_pattern(torch.nn.LeakyReLU) @register_quant_pattern(torch.nn.MaxPool1d) @register_quant_pattern(torch.nn.MaxPool2d) @register_quant_pattern(torch.nn.MaxPool3d) @register_quant_pattern(torch.nn.ReLU) @register_quant_pattern(torch.nn.ReLU6) -@register_quant_pattern(torch.nn.Sigmoid) -@register_quant_pattern(torch.nn.Tanh) @register_quant_pattern(torch.adaptive_avg_pool1d) @register_quant_pattern(torch.nn.functional.adaptive_avg_pool2d) @register_quant_pattern(torch.nn.functional.adaptive_avg_pool3d) @register_quant_pattern(torch.nn.functional.dropout) -@register_quant_pattern(torch.nn.functional.hardsigmoid) @register_quant_pattern(torch.nn.functional.hardtanh) @register_quant_pattern(torch.nn.functional.hardtanh_) @register_quant_pattern(torch.nn.functional.interpolate) -@register_quant_pattern(torch.nn.functional.leaky_relu) @register_quant_pattern(torch.nn.functional.max_pool1d) @register_quant_pattern(torch.nn.functional.max_pool2d) @register_quant_pattern(torch.nn.functional.max_pool3d) @@ -457,11 +669,9 @@ def convert(self, quantizer, node, load_arg, debug=False): @register_quant_pattern(torch.mean) @register_quant_pattern(torch.min) @register_quant_pattern(torch.repeat_interleave) -@register_quant_pattern(torch.sigmoid) @register_quant_pattern(torch.sort) @register_quant_pattern(torch.squeeze) @register_quant_pattern(torch.stack) -@register_quant_pattern(torch.tanh) @register_quant_pattern(torch.unsqueeze) @register_quant_pattern(operator.getitem) @register_quant_pattern(operator.floordiv) @@ -470,10 +680,6 @@ def convert(self, quantizer, node, load_arg, debug=False): @register_quant_pattern('contiguous') @register_quant_pattern('detach') @register_quant_pattern('detach_') -@register_quant_pattern('hardsigmoid') -@register_quant_pattern('hardsigmoid_') -@register_quant_pattern('leaky_relu') -@register_quant_pattern('leaky_relu_') @register_quant_pattern('mean') @register_quant_pattern('numel') @register_quant_pattern('permute') @@ -484,26 +690,26 @@ def convert(self, quantizer, node, load_arg, debug=False): @register_quant_pattern('reshape') @register_quant_pattern('resize_') @register_quant_pattern('shape') -@register_quant_pattern('sigmoid') -@register_quant_pattern('sigmoid_') @register_quant_pattern('size') @register_quant_pattern('squeeze') @register_quant_pattern('squeeze_') -@register_quant_pattern('tanh') -@register_quant_pattern('tanh_') @register_quant_pattern('transpose') @register_quant_pattern('unsqueeze') @register_quant_pattern('unsqueeze_') @register_quant_pattern('view') class CopyNode(QuantizeHandler): - def convert(self, quantizer, node, load_arg, debug=False): + def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None)) # Default quantization handler, used for quantization of input and output # of quantizable objects (e.g. modules and functionals) -class DefaultQuant(QuantizeHandler): - def convert(self, quantizer, node): - assert self.all_nodes +class DefaultQuantizeHandler(QuantizeHandler): + def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: + assert self.all_node_args root_module = quantizer.modules[''] return quantize_node( root_module, @@ -511,16 +717,23 @@ def convert(self, quantizer, node): node, quantizer.activation_post_process_map[node.name]) class CustomModuleQuantizeHandler(QuantizeHandler): - def convert(self, quantizer, node, load_arg, debug=False): + def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: """ Convert a float custom module to quantized custom module """ assert node.op == 'call_module' + assert convert_custom_config_dict is not None + custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", None) + assert custom_module_class_mapping is not None + qconfig = quantizer.qconfig_map[node.name] observed_custom_module = quantizer.modules[node.target] - if node.name in quantizer.activation_post_process_map: + if activation_is_statically_quantized(qconfig): + assert node.name in quantizer.activation_post_process_map observed_custom_module.activation_post_process = \ quantizer.activation_post_process_map[node.name] - quantized_custom_module_class = \ - get_quantized_custom_module_class(observed_custom_module._FLOAT_MODULE) + quantized_custom_module_class = get_swapped_custom_module_class( + observed_custom_module, custom_module_class_mapping, qconfig) quantized_custom_module = \ quantized_custom_module_class.from_observed(observed_custom_module) parent_name, name = _parent_name(node.target) @@ -531,66 +744,21 @@ def convert(self, quantizer, node, load_arg, debug=False): # module attribute like module._QUANTIZED_INPUT_INDEXES return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None)) - -# 2. Post Training Dynamic Quantizatoin Patterns -@register_dynamic_quant_pattern(torch.nn.Linear) -@register_dynamic_quant_pattern(torch.nn.functional.linear) -class DynamicLinear(QuantizeHandler): - def __init__(self, quantizer, node): - super().__init__(quantizer, node) - self.linear_node = node - if node.op == 'call_module': - assert isinstance(quantizer.modules[node.target], torch.nn.Linear) - self.linear = quantizer.modules[self.linear_node.target] - - def convert(self, quantizer, node, load_arg, debug=False): - if self.linear_node.op == 'call_module': - quantized = torch.nn.quantized.dynamic.Linear.from_float(self.linear) - parent_name, name = _parent_name(self.linear_node.target) - setattr(quantizer.modules[parent_name], name, quantized) - return quantizer.quantized_graph.create_node( - 'call_module', - self.linear_node.target, - (load_arg(quantized=False)(self.linear_node.args[0]),), - {}) - elif self.linear_node.op == 'call_function': - if debug: - # quantize and dequantize weight - args = load_arg(quantized=[1])(self.linear_node.args) - args = load_arg(quantized=False)(self.linear_node.args) - kwargs = load_arg(quantized=False)(self.linear_node.kwargs) - return quantizer.quantized_graph.create_node( - 'call_function', torch.nn.functional.linear, args, kwargs) - else: - # linear args: - # (x, observed_weight, bias) - # get observer for the weight - weight_observer = quantizer.activation_post_process_map[self.linear_node.args[1].args[0].name] - - if weight_observer.dtype == torch.float16: - linear_weight = load_arg(quantized=False)(self.linear_node.args[1]) - prepack_op = torch.ops.quantized.linear_prepack_fp16 - else: - linear_weight = load_arg(quantized=True)(self.linear_node.args[1]) - prepack_op = torch.ops.quantized.linear_prepack - bias = None - # all args after bias, including bias - other_args = load_arg(quantized=False)(self.linear_node.args[2:]) - kwargs = load_arg(quantized=False)(self.linear_node.kwargs) - if len(self.linear_node.args) > 2: - bias = load_arg(quantized=False)(self.linear_node.args[2]) - other_args = other_args[1:] # remove the bias argument - else: - assert 'bias' in kwargs, \ - 'expect bias provided as a keyword argument when it is not a positional argument' - bias = kwargs['bias'] - kwargs.pop('bias') - prepack_args = (linear_weight, bias) - # pack weight - packed_weight = quantizer.quantized_graph.create_node( - 'call_function', prepack_op, prepack_args, {}) - # construct dynamic linear input - non_quantized_input = load_arg(quantized=False)(self.linear_node.args[0]) - qdynamic_linear_args = (non_quantized_input, packed_weight) - return quantizer.quantized_graph.create_node( - 'call_function', torch.ops.quantized.linear_dynamic, qdynamic_linear_args, kwargs) +class StandaloneModuleQuantizeHandler(QuantizeHandler): + """ Converts an observed standalone module to quantized standalone module + by calling convert_fx on the observed standalone module. + """ + def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, + debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> Node: + assert node.op == 'call_module' + qconfig = quantizer.qconfig_map[node.name] + convert = torch.quantization.quantize_fx._convert_standalone_module_fx # type: ignore + observed_standalone_module = quantizer.modules[node.target] + input_quantized_idxs = observed_standalone_module._standalone_module_input_quantized_idxs.tolist() + quantized_standalone_module = convert(observed_standalone_module, debug=debug) + parent_name, name = _parent_name(node.target) + # update the modules dict + setattr(quantizer.modules[parent_name], name, quantized_standalone_module) + quantizer.modules[node.target] = quantized_standalone_module + return quantizer.quantized_graph.node_copy(node, load_arg(quantized=input_quantized_idxs)) diff --git a/torch/quantization/fx/quantization_types.py b/torch/quantization/fx/quantization_types.py new file mode 100644 index 0000000000000..f0c1608292afa --- /dev/null +++ b/torch/quantization/fx/quantization_types.py @@ -0,0 +1,8 @@ +from typing import Union, Callable, Tuple, Any + +Pattern = Union[Callable, Tuple[Callable, Callable], Tuple[Callable, Callable, Callable]] + +# This is the Quantizer class instance from torch/quantization/fx/quantize.py. +# Define separately to prevent circular imports. +# TODO(future PR): improve this. +QuantizerCls = Any diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index 67e538b404338..3329424551237 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -1,36 +1,50 @@ import torch -from torch.fx import ( +from torch.fx import ( # type: ignore GraphModule, Proxy, + map_arg ) from torch.fx.graph import ( Graph, Node, - map_arg, ) +from torch.fx.node import Argument + from torch.quantization import ( propagate_qconfig_, convert, ) from ..quantization_mappings import ( - get_qat_module_mappings, + get_default_qat_module_mappings, ) -from ..custom_module_class_mappings import ( - is_custom_module_class, - get_observed_custom_module_class, - mark_observed_custom_module, - is_observed_custom_module, + +from ..quantize import ( + _remove_qconfig, + is_activation_post_process ) -from ..quantize import _remove_qconfig +from ..utils import ( + get_combined_dict, + get_swapped_custom_module_class, + activation_is_statically_quantized, +) from .pattern_utils import ( is_match, - get_quant_patterns, - get_dynamic_quant_patterns, + get_default_quant_patterns, + get_default_output_activation_post_process_map, + input_output_observed, + Pattern, +) + +from .observed_module import ( + mark_observed_module, + is_observed_module, + mark_observed_standalone_module, + is_observed_standalone_module, ) from .quantization_patterns import * @@ -38,160 +52,249 @@ from .utils import ( _parent_name, quantize_node, + get_custom_module_class_keys, + get_new_attr_name_with_prefix, + collect_producer_nodes, + graph_module_from_producer_nodes, + assert_and_get_unique_device, ) -from collections import OrderedDict -import copy -import re +from .qconfig_utils import * + +from typing import Optional, Dict, Any, List, Tuple, Set, Callable + +# Define helper types +MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler, + QConfigAny] # ------------------------ # Helper Functions # ------------------------ -# Returns a function that can get a new attribute name for module with given prefix -# for example, -# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer') -# >> new_name = get_new_observer_name(module) -# new_name will be an unused attribute name on module, e.g. `_observer_1` -def get_new_attr_name_with_prefix(prefix): - def get_new_attr_name(module): - def get_attr_name(i): - return prefix + str(i) - i = 0 - attr_name = get_attr_name(i) - while hasattr(module, attr_name): - i += 1 - attr_name = get_attr_name(i) - return attr_name - return get_new_attr_name - -def collect_producer_nodes(node): - r''' Starting from a target node, trace back until we hit inpu or - getattr node. This is used to extract the chain of operators - starting from getattr to the target node, for example - def forward(self, x): - observed = self.observer(self.weight) - return F.linear(x, observed) - collect_producer_nodes(observed) will either return a list of nodes that produces - the observed node or None if we can't extract a self contained graph without - free variables(inputs of the forward function). - ''' - nodes = [node] - frontier = [node] - while frontier: - node = frontier.pop() - all_args = list(node.args) + list(node.kwargs.values()) - for arg in all_args: - if not isinstance(arg, Node): - continue - if arg.op == 'placeholder': - # hit input, can't fold in this case - return None - nodes.append(arg) - if not (arg.op == 'call_function' and arg.target == getattr): - frontier.append(arg) - return nodes - -def graph_module_from_producer_nodes(root, producer_nodes): - r''' Construct a graph module from extracted producer nodes - from `collect_producer_nodes` function - Args: - root: the root module for the original graph - producer_nodes: a list of nodes we use to construct the graph - Return: - A graph module constructed from the producer nodes - ''' - assert len(producer_nodes) > 0, 'list of producer nodes can not be empty' - # since we traced back from node to getattrr - producer_nodes.reverse() - graph = Graph() - env = {} - - def load_arg(a): - return map_arg(a, lambda node: env[node.name]) - for producer_node in producer_nodes: - env[producer_node.name] = graph.node_copy(producer_node, load_arg) - graph.output(load_arg(producer_nodes[-1].name)) - graph_module = GraphModule(root, graph) - return graph_module - - -def assert_and_get_unique_device(module): +def insert_observer( + node: Node, observer: torch.quantization.ObserverBase, + model: torch.nn.Module, + activation_post_process_map: Dict[str, torch.quantization.ObserverBase], + env: Dict[Any, Any], observed_graph: Graph, load_arg: Callable, + observed_node_names_set: Set[str]): + """Insert observer for node by modifying the observed_graph and + attach observer module to the model + Args: + node: Node + observer: observer/fake_quantize module instance """ - Returns the unique device for a module, or None if no device is found. - Throws an error if multiple devices are detected. + # respect device affinity when adding observers + model_device = assert_and_get_unique_device(model) + if model_device: + observer.to(model_device) + # add observer module as attribute + prefix = node.name + '_activation_post_process_' + get_new_observer_name = get_new_attr_name_with_prefix(prefix) + observer_name = get_new_observer_name(model) + setattr(model, observer_name, observer) + # put observer instance activation_post_process map + assert activation_post_process_map is not None + activation_post_process_map[node.name] = observer + # insert observer call + env[node.name] = observed_graph.create_node( + 'call_module', observer_name, (load_arg(node),), {}) + observed_node_names_set.add(node.name) + +def maybe_insert_observer_for_special_module( + quantize_handler: QuantizeHandler, modules: Dict[str, torch.nn.Module], + prepare_custom_config_dict: Any, qconfig: Any, node: Node) -> Optional[List[int]]: + """ Insert observer for custom module and standalone module + Returns: standalone_module_input_idxs: the indexs for inputs that + needs to be observed by parent module """ - devices = {p.device for p in module.parameters()} | \ - {p.device for p in module.buffers()} - assert len(devices) <= 1, ( - "prepare only works with cpu or single-device CUDA modules, " - "but got devices {}".format(devices) - ) - device = next(iter(devices)) if len(devices) > 0 else None - return device - -def is_activation_post_process(module): - return (isinstance(module, torch.quantization.ObserverBase) or - isinstance(module, torch.quantization.FakeQuantize)) - -def is_submodule_of_fake_quant(name, module, named_modules): - parent_name, _ = _parent_name(name) - return is_activation_post_process(named_modules[parent_name]) - -def get_flattened_qconfig_dict(qconfig_dict): - """ flatten the global, object_type and module_name qconfig - to the same qconfig_dict so that it can be used by - propagate_qconfig_ function. - "module_name_regex" is ignored for now since it's not supported - in propagate_qconfig_, but it can be fixed later. - - For example: - Input: { - "": qconfig, - "object_type": [ - (torch.add, qconfig) - ], - "module_name": [ - ("conv", qconfig) - ] - } - - Output: { - "": qconfig, - torch.add: qconfig, - "conv": qconfig - } - """ - flattened = dict() - if '' in qconfig_dict: - flattened[''] = qconfig_dict[''] - - def flatten_key(key): - if key in qconfig_dict: - for obj, qconfig in qconfig_dict[key]: - flattened[obj] = qconfig - - flatten_key('object_type') - flatten_key('module_name') - return flattened - -def convert_dict_to_ordered_dict(qconfig_dict): - """ Convert dict in qconfig_dict to ordered dict + assert modules is not None + standalone_module_input_idxs = None + if isinstance(quantize_handler, CustomModuleQuantizeHandler): + custom_module = modules[node.target] # type: ignore + custom_module_class_mapping = prepare_custom_config_dict.get( + "float_to_observed_custom_module_class", {}) + observed_custom_module_class = \ + get_swapped_custom_module_class( + custom_module, custom_module_class_mapping, qconfig) + observed_custom_module = \ + observed_custom_module_class.from_float(custom_module) + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, observed_custom_module) + elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler): + # observe standalone module + standalone_module = modules[node.target] # type: ignore + standalone_module_name_configs = prepare_custom_config_dict.get("standalone_module_name", []) + standalone_module_class_configs = prepare_custom_config_dict.get("standalone_module_class", []) + class_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_class_configs} + name_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_name_configs} + config = class_config_map.get(type(standalone_module), (None, None)) + config = name_config_map.get(node.target, config) + sm_qconfig_dict = {"": qconfig} if config[0] is None else config[0] + sm_prepare_config_dict = {} if config[1] is None else config[1] + prepare = \ + torch.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore + observed_standalone_module = \ + prepare(standalone_module, sm_qconfig_dict, sm_prepare_config_dict) + standalone_module_input_idxs = observed_standalone_module.\ + _standalone_module_input_quantized_idxs.int().tolist() + observed_standalone_module = mark_observed_standalone_module( + observed_standalone_module) + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, + observed_standalone_module) + modules[node.target] = observed_standalone_module # type: ignore + return standalone_module_input_idxs + +def insert_observer_for_output_of_the_node( + node: Node, + quantize_handler: QuantizeHandler, + qconfig: Any, + modules: Dict[str, torch.nn.Module], + model: torch.nn.Module, + pattern: Any, + activation_post_process_map: Dict[str, torch.quantization.ObserverBase], + env: Dict[Any, Any], + observed_graph: Graph, + load_arg: Callable, + observed_node_names_set: Set[str], + matched_nodes: Optional[List[Node]], + standalone_module_input_idxs: Optional[List[int]]): + """ Insert observer/fake_quantize module for output of the observed + module if needed """ - # convert a qconfig list for a type to OrderedDict - def _convert_to_ordered_dict(key, qconfig_dict): - qconfig_dict[key] = OrderedDict(qconfig_dict.get(key, [])) - - _convert_to_ordered_dict('object_type', qconfig_dict) - _convert_to_ordered_dict('module_name_regex', qconfig_dict) - _convert_to_ordered_dict('module_name', qconfig_dict) + # don't need to insert observer for output if activation does not + # need to be statically quantized + assert modules is not None + if activation_is_statically_quantized(qconfig): + if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) \ + and model.training: + # we only insert fake quantize module in qat + assert pattern is not None + activation_post_process_ctr = \ + get_default_output_activation_post_process_map().get( + pattern, None) + assert activation_post_process_ctr is not None, \ + "activation_post_process constructor not provided " + \ + "for pattern:" + str(pattern) + insert_observer( + node, activation_post_process_ctr(), + model, activation_post_process_map, env, observed_graph, + load_arg, observed_node_names_set) + elif (isinstance(quantize_handler, + FixedQParamsOpQuantizeHandler) and + not model.training) or \ + isinstance(quantize_handler, CopyNode): + # inserting observers for output of observed module, or + # mark the output as observed + assert node.op in [ + 'call_module', + 'call_function', + 'call_method'], \ + 'CopyNode of type ' + node.op + ' is not handled' + + def is_observed(input_arg): + if isinstance(input_arg, Node): + return input_arg.name in observed_node_names_set + elif isinstance(input_arg, list): + return all(map(is_observed, input_arg)) + # propagate observed property from input + if is_observed(node.args[0]): + observed_node_names_set.add(node.name) + elif ((isinstance(quantize_handler, Add) or + isinstance(quantize_handler, Mul)) and + quantize_handler.num_node_args == 1): + assert matched_nodes is not None + input_node = matched_nodes[-1] # first node in the sequence + + def input_is_observed(arg): + return (isinstance(arg, Node) and + arg.name in observed_node_names_set) + # This is checking if one of the argument of add/mul + # is an observed node + # If both of the inputs are number, + # we will not consider the output to be observed + if (input_is_observed(input_node.args[0]) or + input_is_observed(input_node.args[1])): + observed_node_names_set.add(node.name) + elif isinstance(quantize_handler, + StandaloneModuleQuantizeHandler): + assert node.op == "call_module" + assert isinstance(node.target, str) + sm_out_qidxs = modules[node.target]._standalone_module_output_quantized_idxs.tolist() # type: ignore + output_is_quantized = 0 in sm_out_qidxs + + if output_is_quantized: + observed_node_names_set.add(node.name) + elif (quantize_handler.all_node_args and + input_output_observed(quantize_handler)): + # observer for outputs + new_observer = qconfig.activation() + insert_observer( + node, new_observer, model, + activation_post_process_map, env, observed_graph, + load_arg, observed_node_names_set) + + # insert observer for input of standalone module + if standalone_module_input_idxs is not None: + for idx in standalone_module_input_idxs: + if node.args[idx].name not in observed_node_names_set: # type: ignore + new_observer = qconfig.activation() + insert_observer( + node, new_observer, model, + activation_post_process_map, env, observed_graph, + load_arg, observed_node_names_set) + +def insert_observer_for_input_arg_of_observed_node( + node: Node, observed_node_names_set: Set[str], + quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]], + model: torch.nn.Module, + activation_post_process_map: Dict[str, torch.quantization.ObserverBase], + env: Dict[str, str], observed_graph: Graph, + load_arg: Callable): + if node.name not in observed_node_names_set and node.name in quants: + _, activation_post_process_ctr = quants[node.name] + if activation_post_process_ctr is not None: + insert_observer( + node, activation_post_process_ctr(), + model, activation_post_process_map, + env, observed_graph, load_arg, observed_node_names_set) # A dictionary for querying the weight index for a given op WEIGHT_INDEX_DICT = { + torch.nn.functional.conv1d : [1], torch.nn.functional.conv2d : [1], + torch.nn.functional.conv3d : [1], torch.nn.functional.linear : [1], } +def node_arg_is_weight(node: Node, arg: Any) -> bool: + if isinstance(node, Node) and node.op == 'call_function' and \ + node.target in WEIGHT_INDEX_DICT: + for i, node_arg in enumerate(node.args): + if arg is node_arg and i in \ + WEIGHT_INDEX_DICT[node.target]: # type: ignore + return True + return False + +CONV_OPS_WITH_BIAS = { + torch.nn.functional.conv1d, + torch.nn.functional.conv2d, + torch.nn.functional.conv3d, +} +CONV_BIAS_ARG_INDEX = 2 + +def node_arg_is_bias(node: Node, arg: Any) -> bool: + if isinstance(node, Node) and node.op == 'call_function': + if node.target in CONV_OPS_WITH_BIAS: + for i, node_arg in enumerate(node.args): + if arg is node_arg and i == CONV_BIAS_ARG_INDEX: + return True + elif node.target is torch.nn.functional.linear: + for kwarg_name, kwarg_value in node.kwargs.items(): + if kwarg_name == 'bias' and arg is kwarg_value: + return True + return False + # weight prepacking ops WEIGHT_PREPACK_OPS = { torch._ops.ops.quantized.linear_prepack, @@ -203,10 +306,11 @@ class Quantizer: def __init__(self): # mapping from matched node to activation_post_process # must be filled before convert - self.activation_post_process_map = None + self.activation_post_process_map: Optional[ + Dict[str, torch.quantization.observer.ObserverBase]] = None # mapping from node name to qconfig that should be used for that node # filled out for a model during _generate_qconfig_map - self.qconfig_map = None + self.qconfig_map: Optional[Dict[str, QConfigAny]] = None # mapping from fully qualified module name to module instance # for example, # { @@ -214,7 +318,7 @@ def __init__(self): # 'linear': Linear(...), # 'linear.weight_fake_quant': PerChannelMinMaxObserver(...), # } - self.modules = None + self.modules: Optional[Dict[str, torch.nn.Module]] = None # mapping from a tuple of nodes in reverse order to uninitialized # QuantizeHandler subclass. For example, # { @@ -225,232 +329,267 @@ def __init__(self): # ((, ): # ), # } - self.patterns = None - - - def _qat_swap_modules(self, root): - convert(root, mapping=get_qat_module_mappings(), inplace=True, remove_qconfig=False) - - def _generate_qconfig_map(self, - root, - input_graph, - qconfig_dict): - global_qconfig = qconfig_dict.get('', None) - - def get_module_type_qconfig( - module_type, fallback_qconfig=global_qconfig): - return qconfig_dict['object_type'].get(module_type, fallback_qconfig) - - def get_function_qconfig( - function, fallback_qconfig=global_qconfig): - return qconfig_dict['object_type'].get(function, fallback_qconfig) - - def get_module_name_regex_qconfig( - module_name, fallback_qconfig=global_qconfig): - for regex_pattern, qconfig in qconfig_dict['module_name_regex'].items(): - if re.match(regex_pattern, module_name): - # first match wins - return qconfig - return fallback_qconfig - - def get_module_name_qconfig( - module_name, fallback_qconfig=global_qconfig): - if module_name == '': - # module name qconfig not found - return fallback_qconfig - if module_name in qconfig_dict['module_name']: - return qconfig_dict['module_name'][module_name] - else: - parent, _ = _parent_name(module_name) - return get_module_name_qconfig(parent, fallback_qconfig) - - # get qconfig for module_name, - # fallback to module_name_regex_qconfig, module_type_qconfig, global_qconfig - # if necessary - def get_qconfig(module_name): - module_type_qconfig = \ - get_module_type_qconfig(type(self.modules[module_name])) - module_name_regex_qconfig = \ - get_module_name_regex_qconfig(module_name, module_type_qconfig) - module_name_qconfig = \ - get_module_name_qconfig(module_name, module_name_regex_qconfig) - return module_name_qconfig + self.patterns: Optional[Dict[Pattern, QuantizeHandler]] = None + self.prepare_custom_config_dict: Dict[str, Any] = {} + + + def _qat_swap_modules( + self, root: torch.nn.Module, + additional_qat_module_mapping: Dict[Callable, Callable]) -> None: + all_mappings = get_combined_dict( + get_default_qat_module_mappings(), additional_qat_module_mapping) + convert(root, mapping=all_mappings, inplace=True, remove_qconfig=False) + + def _generate_qconfig_map( + self, + root: torch.nn.Module, + input_graph: Graph, + qconfig_dict: Any, + node_name_to_scope: Dict[str, Tuple[str, type]]) -> None: + global_qconfig = qconfig_dict.get("", None) self.qconfig_map = dict() for node in input_graph.nodes: - if node.op == 'get_attr': + if node.op == "get_attr": module_name, _ = _parent_name(node.target) - self.qconfig_map[node.name] = get_qconfig(module_name) - elif node.op == 'call_function': - # precedence: [TODO] module_name_qconfig (need scope support from fx) + assert self.modules is not None + self.qconfig_map[node.name] = get_qconfig( + qconfig_dict, type(self.modules[module_name]), module_name, global_qconfig) + elif node.op == "call_function": + # precedence: [TODO] module_name_qconfig (need scope support + # from fx) # > function_qconfig > global_qconfig - function_qconfig = get_function_qconfig(node.target) + function_qconfig = get_object_type_qconfig( + qconfig_dict, node.target, global_qconfig) self.qconfig_map[node.name] = function_qconfig - elif node.op == 'call_method': - self_obj = node.args[0] - # qconfig for call_method should be the same as the `self` object for the call - self.qconfig_map[node.name] = self.qconfig_map[self_obj.name] + elif node.op == "call_method": + module_path, module_type = node_name_to_scope[node.name] + # use the qconfig of the module that the node belongs to + qconfig = get_qconfig( + qconfig_dict, module_type, module_path, global_qconfig) + self.qconfig_map[node.name] = qconfig elif node.op == 'call_module': - module_qconfig = get_qconfig(node.target) - # regex is not supported eager mode propagate_qconfig_, we'll need to - # set the qconfig explicitly here in case regex + assert self.modules is not None + module_qconfig = get_qconfig( + qconfig_dict, type(self.modules[node.target]), node.target, global_qconfig) + # regex is not supported eager mode propagate_qconfig_, we'll + # need to set the qconfig explicitly here in case regex # is used self.modules[node.target].qconfig = module_qconfig self.qconfig_map[node.name] = module_qconfig - def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant): - if not inplace: - model = copy.deepcopy(model) - self.is_dynamic_quant = is_dynamic_quant - if self.is_dynamic_quant: - self.patterns = get_dynamic_quant_patterns() - else: - self.patterns = get_quant_patterns() + def _prepare( + self, + model: GraphModule, + qconfig_dict: Any, + node_name_to_scope: Dict[str, Tuple[str, type]], + prepare_custom_config_dict: Optional[Dict[str, Any]], + is_standalone_module: bool) -> GraphModule: + """ standalone_module means it a submodule that is not inlined in + parent module, and will be quantized separately as one unit. + + How the standalone module is observed is specified by `input_quantized_idxs` and + `output_quantized_idxs` in the prepare_custom_config for the standalone module + Returns: + model(GraphModule): prepared standalone module + attributes: + _standalone_module_input_quantized_idxs(List[Int]): a list of + indexes for the graph input that is expected to be quantized, + same as input_quantized_idxs configuration provided + for the standalone module + _standalone_module_output_quantized_idxs(List[Int]): a list of + indexs for the graph output that is quantized + same as input_quantized_idxs configuration provided + for the standalone module + """ + if prepare_custom_config_dict is None: + prepare_custom_config_dict = {} + self.prepare_custom_config_dict = prepare_custom_config_dict + + additional_quant_patterns = \ + prepare_custom_config_dict.get("additional_quant_pattern", {}) + self.patterns = get_combined_dict( + get_default_quant_patterns(), additional_quant_patterns) flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict) # TODO: support regex as well propagate_qconfig_(model, flattened_qconfig_dict) if model.training: - self._qat_swap_modules(model) + additional_qat_module_mapping = prepare_custom_config_dict.get( + "additional_qat_module_mapping", {}) + self._qat_swap_modules(model, additional_qat_module_mapping) self.modules = dict(model.named_modules()) convert_dict_to_ordered_dict(qconfig_dict) # map from node name to qconfig, used in _find_matches - self._generate_qconfig_map(model, model.graph, qconfig_dict) + self._generate_qconfig_map(model, model.graph, qconfig_dict, node_name_to_scope) # match the patterns that will get quantized - matches = self._find_matches(model.graph, self.modules, self.patterns) + standalone_module_name_configs = prepare_custom_config_dict.get( + "standalone_module_name", []) + standalone_module_class_configs = prepare_custom_config_dict.get( + "standalone_module_class", []) + + standalone_module_names = [config[0] for config in standalone_module_name_configs] + standalone_module_classes = [config[0] for config in standalone_module_class_configs] + custom_module_classes = get_custom_module_class_keys( + prepare_custom_config_dict, "float_to_observed_custom_module_class") + assert self.patterns is not None + matches = self._find_matches( + model.graph, self.modules, self.patterns, standalone_module_names, + standalone_module_classes, custom_module_classes) # find _inputs_ to matched nodes that are not quantized, these # have to be quantized, which requires measuring stats, - # initialize an DefaultQuant object for each - quants = self._find_quants(model.graph, matches) + # initialize an DefaultQuantizeHandler object for each + quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]] = \ + self._find_quants(model.graph, matches) self.activation_post_process_map = dict() - - env = {} + env: Dict[Any, Any] = {} observed_graph = Graph() - observed_node_names_set = set() + observed_node_names_set: Set[str] = set() def load_arg(a): return map_arg(a, lambda node: env[node.name]) + graph_inputs = [] for node in model.graph.nodes: + if node.op == 'placeholder': + graph_inputs.append(node.name) + + get_new_observer_name = get_new_attr_name_with_prefix( + 'activation_post_process_') + + placeholder_node_seen_cnt = 0 + output_node_seen_cnt = 0 + input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get( + "input_quantized_idxs", []) + output_quantized_idxs: List[int] = self.prepare_custom_config_dict.get( + "output_quantized_idxs", []) + + result_node : Optional[Node] = None + for node in model.graph.nodes: + if node.op == 'output': + # If this output is hardcoded to be quantized, insert an + # observer on the previous node if it does not already + # exist. + cur_output_node_idx = output_node_seen_cnt + output_node_seen_cnt += 1 + if cur_output_node_idx in output_quantized_idxs: + prev_node = node.args[0] + assert isinstance(prev_node, Node), \ + ('hardcoding list/dict outputs to be quantized is ' + + 'not supported') + if prev_node.name not in observed_node_names_set: + assert self.qconfig_map is not None + local_qconfig = self.qconfig_map[prev_node.name] + assert local_qconfig is not None, \ + 'qconfig of a node before a quantized output must exist' + insert_observer( + prev_node, local_qconfig.activation(), + model, self.activation_post_process_map, + env, observed_graph, load_arg, observed_node_names_set) + + observed_graph.output(load_arg(node.args[0])) + result_node = node + continue + if node.name in observed_node_names_set: continue - prefix = node.name + '_activation_post_process_' - root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None)) + root_node, matched_nodes, pattern, obj, qconfig = matches.get( + node.name, (None, None, None, None, None)) if root_node is None: env[node.name] = observed_graph.node_copy(node, load_arg) elif root_node is node: env[node.name] = observed_graph.node_copy(node, load_arg) - if qconfig is None: - continue + # index for input of custom module that needs to be observed in + # parent + if qconfig is not None: + assert obj is not None + standalone_module_input_idxs = \ + maybe_insert_observer_for_special_module( + obj, self.modules, prepare_custom_config_dict, qconfig, + node) + insert_observer_for_output_of_the_node( + node, obj, qconfig, self.modules, model, pattern, + self.activation_post_process_map, env, + observed_graph, load_arg, observed_node_names_set, + matched_nodes, standalone_module_input_idxs) + else: + env[node.name] = observed_graph.node_copy(node, load_arg) - def insert_observer(node, observer, device): - get_new_observer_name = get_new_attr_name_with_prefix(prefix) - observer_name = get_new_observer_name(model) - setattr(model, observer_name, observer) - self.activation_post_process_map[node.name] = observer - env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {}) + if node.op == 'placeholder': + # skip adding observers at the graph input if the input is + # overriden to be quantized + cur_placeholder_node_idx = placeholder_node_seen_cnt + placeholder_node_seen_cnt += 1 + if cur_placeholder_node_idx in input_quantized_idxs: observed_node_names_set.add(node.name) - if device: - getattr(model, observer_name).to(device) - - if isinstance(obj, CustomModuleQuantizeHandler): - custom_module = self.modules[node.target] - observed_custom_module_class = \ - get_observed_custom_module_class(type(custom_module)) - observed_custom_module = \ - observed_custom_module_class.from_float(custom_module) - mark_observed_custom_module(observed_custom_module, type(custom_module)) - parent_name, name = _parent_name(node.target) - setattr(self.modules[parent_name], name, observed_custom_module) - - # don't need to insert observer for output in dynamic quantization - if self.is_dynamic_quant: continue - # inserting observers for output of observed module, or mark the output - # as observed - if isinstance(obj, CopyNode): - assert node.op in [ - 'call_module', - 'call_function', - 'call_method'], \ - 'CopyNode of type ' + node.op + ' is not handled' - - def is_observed(input_arg): - if isinstance(input_arg, Node): - return input_arg.name in observed_node_names_set - elif isinstance(input_arg, list): - return all(map(is_observed, input_arg)) - # propagate observed property from input - if is_observed(node.args[0]): - observed_node_names_set.add(node.name) - elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes: - if node.args[0].name in observed_node_names_set: - observed_node_names_set.add(node.name) - elif qconfig is not None and obj.all_nodes: - # observer for outputs - new_observer = qconfig.activation() - # respect device affinity when adding observers - device = assert_and_get_unique_device(model) - insert_observer(node, new_observer, device) - else: - env[node.name] = observed_graph.node_copy(node, load_arg) + insert_observer_for_input_arg_of_observed_node( + node, observed_node_names_set, quants, + model, self.activation_post_process_map, env, + observed_graph, load_arg) - if node.name not in observed_node_names_set and node.name in quants: - get_new_observer_name = get_new_attr_name_with_prefix(prefix) - observer_name = get_new_observer_name(model) - _, qconfig, is_weight = quants[node.name] - if qconfig is not None: - new_observer = \ - qconfig.weight() if is_weight else qconfig.activation() - # respect device affinity when adding observers - device = assert_and_get_unique_device(model) - if device: - new_observer.to(device) - self.activation_post_process_map[node.name] = new_observer - setattr(model, observer_name, self.activation_post_process_map[node.name]) - env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {}) - observed_node_names_set.add(node.name) - observed_graph.output(load_arg(model.graph.result)) model = GraphModule(model, observed_graph) self.save_state(model) + model = mark_observed_module(model) + if is_standalone_module: + assert result_node is not None + assert isinstance(result_node.args[0], Node), \ + "standalone module only supports returning simple value currently"\ + "(not tuple, dict etc.)" + # indicator for whether output is observed or not. + # This used for correctly quantize standalone modules + output_is_observed = \ + result_node.args[0].name in observed_node_names_set + # these inputs are observed in parent + # converting List[int] to Tensor since module attribute is + # Union[Tensor, Module] + model._standalone_module_input_quantized_idxs = \ + torch.Tensor(input_quantized_idxs) + model._standalone_module_output_quantized_idxs = torch.Tensor(output_quantized_idxs) return model - def save_state(self, observed): - observed._activation_post_process_map = self.activation_post_process_map - observed._patterns = self.patterns - observed._qconfig_map = self.qconfig_map - - def restore_state(self, observed): - err_msg = 'please make sure the model is produced by prepare' - assert hasattr(observed, '_activation_post_process_map'), 'did not found ' + \ - '_activation_post_process attribute ' + err_msg - assert hasattr(observed, '_patterns'), 'did not found ' + \ - '_patterns attribute ' + err_msg - assert hasattr(observed, '_qconfig_map'), 'did not found ' + \ - '_qconfig_map attribute ' + err_msg - self.activation_post_process_map = observed._activation_post_process_map - self.patterns = observed._patterns - self.qconfig_map = observed._qconfig_map - - def prepare(self, model, qconfig_dict, inplace=False): - return self._prepare(model, qconfig_dict, inplace, is_dynamic_quant=False) - - def prepare_dynamic(self, model, qconfig_dict, inplace=False): - return self._prepare(model, qconfig_dict, inplace, is_dynamic_quant=True) - - def _run_weight_observers(self, observed): - r''' Extract the subgraph that produces the weight for dynamically quantized - node and run the subgraph to observe the weight. - Note that the observers of dynamically quantized modules are run during - the conversion step. + def save_state(self, observed: GraphModule) -> None: + observed._activation_post_process_map = \ + self.activation_post_process_map # type: ignore + observed._patterns = self.patterns # type: ignore + observed._qconfig_map = self.qconfig_map # type: ignore + observed._prepare_custom_config_dict = \ + self.prepare_custom_config_dict # type: ignore + + def restore_state(self, observed: GraphModule) -> None: + assert is_observed_module(observed), \ + 'incoming model must be produced by prepare_fx' + self.activation_post_process_map = \ + observed._activation_post_process_map # type: ignore + self.patterns = observed._patterns # type: ignore + self.qconfig_map = observed._qconfig_map # type: ignore + self.prepare_custom_config_dict = \ + observed._prepare_custom_config_dict # type: ignore + + def prepare( + self, + model: GraphModule, + qconfig_dict: Any, + node_name_to_scope: Dict[str, Tuple[str, type]], + prepare_custom_config_dict: Dict[str, Any] = None, + is_standalone_module: bool = False) -> GraphModule: + return self._prepare( + model, qconfig_dict, node_name_to_scope, prepare_custom_config_dict, + is_standalone_module) + + def _run_weight_observers(self, observed: GraphModule) -> None: + r''' Extract the subgraph that produces the weight for dynamic quant + or weight only quant node and run the subgraph to observe the weight. + Note that the observers of dynamic quant or weight only quant ops are + run during the convert step. ''' for node in observed.graph.nodes: if node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT: @@ -459,53 +598,72 @@ def _run_weight_observers(self, observed): # node_arg is weight weight_observer_nodes = collect_producer_nodes(node_arg) if weight_observer_nodes is not None: - weight_observer_module = graph_module_from_producer_nodes( - observed, weight_observer_nodes) + weight_observer_module = \ + graph_module_from_producer_nodes( + observed, weight_observer_nodes) # run the weight observer weight_observer_module() return - def _convert(self, model, inplace=False, debug=False, is_dynamic_quant=False): + def _convert(self, model: GraphModule, debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None, + is_standalone_module: bool = False) -> GraphModule: + """ standalone_module means it a submodule that is not inlined in + parent module, and will be quantized separately as one unit. + + Returns a quantized standalone module, whether input/output is quantized is + specified by prepare_custom_config_dict, with + input_quantized_idxs, output_quantized_idxs, please + see docs for prepare_fx for details + """ + if convert_custom_config_dict is None: + convert_custom_config_dict = {} self.restore_state(model) - if not inplace: - model = copy.deepcopy(model) - self.is_dynamic_quant = is_dynamic_quant - # run weight observers before inserting quant dequant nodes - # for dynamic quantization - if self.is_dynamic_quant: - self._run_weight_observers(model) + # always run weight observers in the top level forward method + # for dynamic quant ops or weight only quant ops + self._run_weight_observers(model) # move to cpu since we only have quantized cpu kernels model.eval().cpu() self.modules = dict(model.named_modules()) - matches = self._find_matches(model.graph, self.modules, self.patterns) + custom_module_classes = get_custom_module_class_keys( + convert_custom_config_dict, + "observed_to_quantized_custom_module_class") + assert self.patterns is not None + matches = self._find_matches( + model.graph, self.modules, self.patterns, + custom_module_classes=custom_module_classes) + + quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]] = \ + self._find_quants(model.graph, matches) - quants = self._find_quants(model.graph, matches) self.quantized_graph = Graph() - env = {} - quant_env = {} + env: Dict[str, Node] = {} + quant_env: Dict[str, Node] = {} + + graph_inputs: List[str] = [] + for node in model.graph.nodes: + if node.op == 'placeholder': + graph_inputs.append(node.name) - def load_non_quantized(n): + def load_non_quantized(n: Node) -> Node: if n.name not in env: assert n.name in quant_env, \ - 'trying to load float node but did not find node:' + n.name + \ - ' in quantized or non quantized environment, env: ' + str(env) + \ - ' quant_env:' + str(quant_env) + 'trying to load float node but did not find ' + \ + 'node:' + n.name + \ + ' in quantized or non quantized environment, env: ' + \ + str(env) + ' quant_env:' + str(quant_env) env[n.name] = Proxy(quant_env[n.name]).dequantize().node return env[n.name] - def load_quantized(n): - if n.name not in quant_env: - assert n.name in env, \ - 'trying to load quantized node but did not find node:' + n.name + \ - ' in float environment:' + str(env) - assert n.name in quants, 'did not find quant object for node:' + n.name - quant = quants[n.name][0] - quant_env[n.name] = quant.convert(self, env[n.name]) + def load_quantized(n: Node) -> Node: + assert n.name in quant_env, \ + 'trying to load quantized node but did not find node:' + \ + n.name + ' in quant environment:' + str(quant_env) return quant_env[n.name] - def load_x(n): + def load_x(n: Node) -> Node: assert n.name in env or n.name in quant_env, \ 'node ' + n.name + ' does not exist in either environment' if n.name in quant_env: @@ -513,76 +671,182 @@ def load_x(n): else: return env[n.name] - def load_arg(quantized): + def load_arg(quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] + ) -> Callable[[Node], Argument]: """ Input: quantized, which can be None, list, boolean or tuple - - if quantized is a list or tuple, then arg should be a list and the args with corresponding - indexes will be quantized - - if quantized is a boolean, then all args will be quantized/not quantized - - if quantized is None, then we'll load the node as long as it exists - - Output: fn which takes arg_or_args, and loads them from the corresponding - environment depending on the value of quantized. + - if quantized is None, then we'll load the node as long as it + exists + - if quantized is a boolean, then all args will be + quantized/not quantized + - if quantized is an empty list or tuple, then it is the same as load_arg(quantized=False) + - if quantized is a list or tuple, then arg should be a list and + the args with corresponding indexes will be quantized + + Output: fn which takes arg_or_args, and loads them from the + corresponding environment depending on the value of quantized. """ - assert quantized is None or isinstance(quantized, (tuple, list, bool)), type(quantized) + assert quantized is None or \ + isinstance(quantized, (tuple, list, bool)), type(quantized) + if isinstance(quantized, (tuple, list)) and len(quantized) == 0: + # empty tuple or list means nothing is quantized + quantized = False def load_arg_impl(arg_or_args): - if quantized is None: + # we'll update the format of `quantized` + # to better match arg_or_args + updated_quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] = quantized + + if isinstance(quantized, (tuple, list)) and \ + len(quantized) == 1 and isinstance(arg_or_args, Node): + # when argument is one Node instead of tuple, we just need to check + # 0 is in the quantized list + updated_quantized = 0 in quantized + + if updated_quantized is None: return map_arg(arg_or_args, load_x) - if isinstance(quantized, bool): - return map_arg(arg_or_args, load_quantized if quantized else load_non_quantized) - elif isinstance(quantized, (tuple, list)): + if isinstance(updated_quantized, bool): + return map_arg( + arg_or_args, + load_quantized if updated_quantized else load_non_quantized) + elif isinstance(updated_quantized, (tuple, list)): assert isinstance(arg_or_args, (tuple, list)), arg_or_args loaded_args = [] # for now, we only support quantizing positional arguments for i, a in enumerate(arg_or_args): - if i in quantized: + if i in updated_quantized: loaded_args.append(map_arg(a, load_quantized)) else: loaded_args.append(map_arg(a, load_non_quantized)) return type(arg_or_args)(loaded_args) return load_arg_impl - def is_quantized(node): - if isinstance(node, Node): - assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment' - # there might be nodes appearing in both environemnts, but quant_env will take - # precedence - if node.name in quant_env: + def node_arg_is_quantized(node_arg: Any) -> bool: + if isinstance(node_arg, Node): + assert node_arg.name in env or node_arg.name in quant_env, \ + 'Expecting node_arg to be in the environment' + # there might be nodes appearing in both environemnts, but + # quant_env will take precedence + if node_arg.name in quant_env: return True - elif node.name in env: + elif node_arg.name in env: + return False + else: return False - elif isinstance(node, list): - quantized = map(is_quantized, node) + elif isinstance(node_arg, list): + quantized = map(node_arg_is_quantized, node_arg) if all(quantized): return True elif not any(quantized): return False else: - raise Exception("partially quantized inputs in list not handled yet") + raise Exception( + "partially quantized inputs in list not handled yet") + else: + return False + + def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool: + """ Check if output node is quantized or not """ + assert self.modules is not None + # by default the output for a quantizable node is expected to be quantized + quantized = True + + # Need to get correct quantized/non-quantized state forn the output + # of CopyNode + if type(obj) in [ + CopyNode, + FixedQParamsOpQuantizeHandler + ]: + assert node.op in [ + 'call_module', + 'call_function', + 'call_method'], \ + 'CopyNode of type ' + node.op + ' is not handled' + quantized = node_arg_is_quantized(node.args[0]) + + if not activation_is_statically_quantized(qconfig) or \ + not input_output_observed(obj): + quantized = False + + return quantized + + def insert_quantize_node(node: Node) -> None: + """ Given a activation_post_process module call node, insert a + quantize node""" + assert self.modules is not None + assert isinstance(node.target, str) + observer_module = self.modules[node.target] + prev_node = node.args[0] + if observer_module.dtype == torch.float16: + # activations are not quantized for + # fp16 dynamic quantization + # copy the activaiton_post_process node here + # since we may need it when we insert prepack + # op for weight of linear, this will be removed + # later in a separate pass + env[node.name] = self.quantized_graph.node_copy( + node, load_non_quantized) + elif isinstance(prev_node, Node) and prev_node.name in quant_env: + # if previous node is already quantized, we'll just remove the + # activation_post_process + quant_env[node.name] = quant_env[prev_node.name] + else: + # replace activation post process with quantization ops + root_module = self.modules[""] + assert isinstance(node.args[0], Node) + quant_env[node.name] = quantize_node( + root_module, self.quantized_graph, + load_non_quantized(node.args[0]), observer_module) + + # additional state to override inputs to be quantized, if specified + # by the user + placeholder_node_seen_cnt = 0 + output_node_seen_cnt = 0 + input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get( + "input_quantized_idxs", []) + output_quantized_idxs: List[int] = self.prepare_custom_config_dict.get( + "output_quantized_idxs", []) for node in model.graph.nodes: - root_node, matched, obj, qconfig = matches.get(node.name, (None, None, None, None)) + if node.op == "output": + cur_output_node_idx = output_node_seen_cnt + output_node_seen_cnt += 1 + if cur_output_node_idx in output_quantized_idxs: + # Result are kept quantized if the user specified the + # output_quantized_idxs override. + graph_output = map_arg(node.args[0], load_x) + else: + graph_output = map_arg(node.args[0], load_non_quantized) + self.quantized_graph.output(graph_output) + continue + root_node, matched, matched_pattern, obj, qconfig = \ + matches.get(node.name, (None, None, None, None, None)) if root_node is node: - if qconfig is None: - result = self.quantized_graph.node_copy(node, load_non_quantized) + is_observed_standalone_module_node = ( + node.op == 'call_module' and + is_observed_standalone_module( + self.modules[node.target]) # type: ignore + ) + if qconfig is None and not is_observed_standalone_module_node: + result = self.quantized_graph.node_copy( + node, load_non_quantized) quantized = False else: - result = obj.convert(self, node, load_arg) - # Need to get correct quantized/non-quantized state for the output of CopyNode - if isinstance(obj, CopyNode): - assert node.op in [ - 'call_module', - 'call_function', - 'call_method'], \ - 'CopyNode of type ' + node.op + ' is not handled' - quantized = is_quantized(node.args[0]) - else: - quantized = True - - # output of dynamic quantization is not quantized - if self.is_dynamic_quant: - quantized = False + assert obj is not None + # We will get whether the output is quantized or not before + # convert for standalone module and after convert + # for non-standalone module, since _standalone_module_output_quantized_idxs + # is only available in observed standalone module + if is_observed_standalone_module_node: + out_quant_idxs = self.modules[node.target]._standalone_module_output_quantized_idxs.tolist() # type: ignore + assert len(out_quant_idxs) <= 1, "Currently standalone only support one output" + quantized = 0 in out_quant_idxs + + result = obj.convert( + self, node, load_arg, debug=debug, + convert_custom_config_dict=convert_custom_config_dict) + if not is_observed_standalone_module_node: + quantized = is_output_quantized(node, obj) if quantized: quant_env[node.name] = result @@ -593,63 +857,51 @@ def is_quantized(node): continue # handle activation post process calls - if node.op == 'call_module': - if is_activation_post_process(self.modules[node.target]): - observer_module = self.modules[node.target] - prev_node = node.args[0] - if observer_module.dtype == torch.float16: - # activations are not quantized for - # fp16 dynamic quantization - # copy the activaiton_post_process node here - # since we may need it when we insert prepack - # op for weight of linear, this will be removed - # later in a separate pass - env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized) - continue - if prev_node.name in quant_env: - # if previous node is already quantized, we'll just remove the activation_post_process - quant_env[node.name] = quant_env[prev_node.name] - continue - # replace activation post process with quantization ops - root_module = self.modules[''] - quant_env[node.name] = quantize_node( - root_module, self.quantized_graph, - load_non_quantized(node.args[0]), observer_module) - continue - # dequantize inputs for the node that are not quantized - env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized) - self.quantized_graph.output(map_arg(model.graph.result, load_non_quantized)) + if node.op == 'call_module' and \ + is_activation_post_process(self.modules[node.target]): + insert_quantize_node(node) + elif node.op == 'placeholder': + cur_placeholder_node_idx = placeholder_node_seen_cnt + placeholder_node_seen_cnt += 1 + if cur_placeholder_node_idx in input_quantized_idxs: + quant_env[node.name] = \ + self.quantized_graph.node_copy(node, load_non_quantized) + else: + env[node.name] = \ + self.quantized_graph.node_copy(node, load_non_quantized) + else: + # copy quantized or non-quantized node + env[node.name] = \ + self.quantized_graph.node_copy(node, load_non_quantized) # remove activation post process act_post_process_removed_graph = Graph() env = {} - def load_arg(a): + def load_arg_simple(a: Argument) -> Argument: return map_arg(a, lambda node: env[node.name]) for node in self.quantized_graph.nodes: + if node.op == 'output': + act_post_process_removed_graph.output( + map_arg(node.args[0], load_arg_simple)) + continue if node.op == 'call_module' and \ is_activation_post_process(self.modules[node.target]): - # remove activation post process + # remove activation post process node env[node.name] = env[node.args[0].name] else: - env[node.name] = act_post_process_removed_graph.node_copy(node, load_arg) - act_post_process_removed_graph.output(map_arg(self.quantized_graph.result, load_arg)) - - module_dict = dict(model.named_modules()) - to_be_removed = [] - for name, module in model.named_modules(): - if is_activation_post_process(module) and not is_submodule_of_fake_quant(name, module, module_dict): - to_be_removed.append(name) - for n in to_be_removed: - delattr(model, n) + env[node.name] = act_post_process_removed_graph.node_copy( + node, load_arg_simple) + + # removes qconfig and activation_post_process modules _remove_qconfig(model) model = GraphModule(model, act_post_process_removed_graph) return model - # Trace back from the weight node util we hit getattr, reconstruct the graph module - # with the traced nodes and run the graph module to pack the weight. then replace - # the original chain of ops with the packed weight. - def _fold_weight(self, quantized): + # Trace back from the weight node util we hit getattr, reconstruct the + # graph module with the traced nodes and run the graph module to pack the + # weight. then replace the original chain of ops with the packed weight. + def _fold_weight(self, quantized: GraphModule) -> GraphModule: packed_weights = dict() # map from folded node name to the prepacked weight name folded_nodes = dict() @@ -668,11 +920,12 @@ def _fold_weight(self, quantized): # remove folded nodes and replace the prepacking node with getattr folded_graph = Graph() - env = {} + env: Dict[Any, Any] = {} def load_arg(a): return map_arg(a, lambda node: env[node.name]) - get_new_packed_weight_name = get_new_attr_name_with_prefix('_fx_pass_packed_weight_') + get_new_packed_weight_name = \ + get_new_attr_name_with_prefix('_fx_pass_packed_weight_') quantized_root = quantized quantized_graph = quantized.graph for node in quantized_graph.nodes: @@ -691,17 +944,24 @@ def load_arg(a): else: # copy other nodes env[node.name] = folded_graph.node_copy(node, load_arg) - folded_graph.output(load_arg(quantized_graph.result)) quantized = GraphModule(quantized_root, folded_graph) return quantized - def convert(self, model, inplace=False, debug=False, is_dynamic=False): - quantized = self._convert(model, inplace, debug, is_dynamic) + def convert(self, model: GraphModule, debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None, + is_standalone_module: bool = False) -> GraphModule: + quantized = self._convert( + model, debug, convert_custom_config_dict, is_standalone_module) if not debug: quantized = self._fold_weight(quantized) return quantized - def _find_matches(self, graph, modules, patterns): + def _find_matches( + self, graph: Graph, modules: Dict[str, torch.nn.Module], + patterns: Dict[Pattern, QuantizeHandler], + standalone_module_names: List[str] = None, + standalone_module_classes: List[Callable] = None, + custom_module_classes: List[Any] = None) -> Dict[str, MatchResult]: """ Matches the nodes in the input graph to quantization patterns, and outputs the information needed to quantize them in future steps. @@ -715,15 +975,26 @@ def _find_matches(self, graph, modules, patterns): Outputs a map of node_name -> - (node, matched_values, QuantizeHandler instance, qconfig) + (node, matched_values, matched_pattern, QuantizeHandler instance, + qconfig) For example, { - 'relu_1': (relu_1, [relu_1], , QConfig(...)), + 'relu_1': (relu_1, [relu_1], torch.nn.functional.relu, + , QConfig(...)), ... } """ - match_map = {} - all_matched = set() + if custom_module_classes is None: + custom_module_classes = [] + + if standalone_module_classes is None: + standalone_module_classes = [] + + if standalone_module_names is None: + standalone_module_names = [] + + match_map: Dict[str, MatchResult] = {} + all_matched : Set[str] = set() def record_match(pattern, node, matched): if isinstance(pattern, tuple): @@ -735,30 +1006,54 @@ def record_match(pattern, node, matched): else: matched.append(node) + assert self.qconfig_map is not None for node in reversed(graph.nodes): if node.name not in match_map and node.name not in all_matched: for pattern, value in patterns.items(): if is_match(modules, node, pattern): - matched = [] + matched: List[Any] = [] record_match(pattern, node, matched) for n in matched: - match_map[n.name] = (node, matched, value(self, node), self.qconfig_map[n.name]) + match_map[n.name] = ( + node, matched, pattern, value(self, node), # type: ignore + self.qconfig_map[n.name]) all_matched.add(n.name) # break after finding the first match break # add custom module instances to the match result + assert self.modules is not None for node in graph.nodes: if node.op == 'call_module' and \ - (is_custom_module_class(type(self.modules[node.target])) or - is_observed_custom_module(self.modules[node.target])): + type(self.modules[node.target]) in custom_module_classes: custom_module_qconfig = self.qconfig_map[node.name] match_map[node.name] = ( - node, [node], CustomModuleQuantizeHandler(self, node), custom_module_qconfig) + node, [node], None, CustomModuleQuantizeHandler(self, node), + custom_module_qconfig) + + def is_standalone_module(node_target): + assert self.modules is not None + return ( + node_target in standalone_module_names or # type: ignore + type(self.modules[node_target]) in standalone_module_classes # type: ignore + ) + + # add standalone modules to the match + for node in graph.nodes: + if node.op == 'call_module' and \ + (is_standalone_module(node.target) or + is_observed_standalone_module(self.modules[node.target])): + # add node to matched nodes + custom_module_qconfig = self.qconfig_map[node.name] + match_map[node.name] = ( + node, [node], None, + StandaloneModuleQuantizeHandler(self, node), + custom_module_qconfig) return match_map - def _find_quants(self, graph, matches): + def _find_quants(self, graph: Graph, matches: Dict[str, MatchResult], + ) -> Dict[str, Tuple[DefaultQuantizeHandler, Callable]]: """ Takes the nodes in the input graph and pending matches, and finds and returns the input and output nodes which need to be quantized. @@ -768,38 +1063,64 @@ def _find_quants(self, graph, matches): - matches: output of self._find_matches function Outputs a map of - node_name -> (QuantizeHandler instance (always DefaultQuant), qconfig) + node_name -> (QuantizeHandler instance (always DefaultQuantizeHandler), + activation_post_process (observer/fake_quantize module) constructor) """ - quants = {} + quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]] = {} - def visit(node, qconfig): + def visit(node, matched_pattern, qconfig): def visit_arg(arg): - # note: we have to measure quantization information - # even for nodes where we might not use it because it is already - # quantized. This is because each match has the option to - # say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate) - is_weight = False - if isinstance(node, Node) and node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT: - for i, node_arg in enumerate(node.args): - if arg is node_arg and i in WEIGHT_INDEX_DICT[node.target]: - is_weight = True - if (not self.is_dynamic_quant) or is_weight: - # overwrite previous quant config - quants[arg.name] = (DefaultQuant(self, arg), qconfig, is_weight) + is_weight = node_arg_is_weight(node, arg) + is_bias = node_arg_is_bias(node, arg) + is_activation = not (is_weight or is_bias) + should_add_handler = qconfig is not None and ( + (is_activation and + activation_is_statically_quantized(qconfig)) or + (is_weight and weight_is_statically_quantized(qconfig)) + ) + + if should_add_handler: + act_post_process_ctr = qconfig.weight if is_weight else \ + qconfig.activation + # overwrite the constructor from qconfig + act_post_process_ctr = \ + get_default_output_activation_post_process_map().get( + matched_pattern, + act_post_process_ctr) + quants[arg.name] = ( + DefaultQuantizeHandler(self, arg), act_post_process_ctr) return visit_arg for node in graph.nodes: if node.name in matches: - root_node, matched, obj, qconfig = matches[node.name] + root_node, matched_nodes, matched_pattern, quantize_handler, \ + qconfig = matches[node.name] # don't attach observer/fake_quant for CopyNode - if isinstance(obj, CopyNode): + if isinstance(quantize_handler, CopyNode): qconfig = None - if root_node is node: - # matched[-1] is the first op in the sequence and - # matched[0] is the last op in the sequence + if root_node is node and \ + input_output_observed(quantize_handler): + # matched_nodes[-1] is the first op in the sequence and + # matched_nodes[0] is the last op in the sequence # inputs - map_arg(matched[-1].args, visit(matched[-1], qconfig)) - map_arg(matched[-1].kwargs, visit(matched[-1], qconfig)) + # matched_pattern is set to None for inputs because + # we only want to select QuantizeHandler object based + # on pattern for output, inputs will always use + # DefaultQuantizeHandler + map_arg(matched_nodes[-1].args, visit(matched_nodes[-1], + None, qconfig)) + map_arg(matched_nodes[-1].kwargs, visit(matched_nodes[-1], + None, qconfig)) + # output - map_arg(matched[0], visit(None, qconfig)) + # we don't insert observer for output of standalone module + if not isinstance( + quantize_handler, StandaloneModuleQuantizeHandler): + # passing in matched_pattern here so that we can + # customize activation_post_process constructor for + # output based on the pattern, e.g. + # for sigmoid op we'll use + # default_affine_fixed_qparam_fake_quant + map_arg(matched_nodes[0], + visit(None, matched_pattern, qconfig)) return quants diff --git a/torch/quantization/fx/utils.py b/torch/quantization/fx/utils.py index 5d5532dc48fc2..8285e204b1edb 100644 --- a/torch/quantization/fx/utils.py +++ b/torch/quantization/fx/utils.py @@ -1,5 +1,15 @@ import re import torch +from ..utils import is_per_tensor, is_per_channel + +from torch.fx import GraphModule, map_arg + +from torch.fx.graph import ( + Graph, + Node, +) + +from typing import Callable, Optional, List, Dict, Any, Set # turn foo.bar -> ['foo', 'bar'] def _parent_name(target): @@ -75,15 +85,6 @@ def graph_pretty_str(g, shorten=True) -> str: res_str += "*obs_{n} = activation_post_process_{n}\n" return res_str -def is_per_tensor(qscheme): - return qscheme == torch.per_tensor_affine or \ - qscheme == torch.per_tensor_symmetric - -def is_per_channel(qscheme): - return qscheme in [torch.per_channel_affine, - torch.per_channel_affine_float_qparams, - torch.per_channel_symmetric] - def get_per_tensor_qparams(activation_post_process): assert is_per_tensor(activation_post_process.qscheme), 'Only per tensor quantization is supported' scale, zero_point = activation_post_process.calculate_qparams() @@ -107,7 +108,7 @@ def get_quantize_op_and_qparams(activation_post_process): scale = float(scale) zero_point = int(zero_point) qparams = {'_scale_': scale, '_zero_point_': zero_point, '_dtype_': dtype} - quantize_op = torch.quantize_per_tensor + quantize_op = torch.quantize_per_tensor # type: ignore return quantize_op, qparams def quantize_node(root_module, graph, node, activation_post_process): @@ -138,3 +139,124 @@ def get_next_qparams_idx(module, qparams): qparam_full_path = key + str(idx) inputs.append(graph.create_node('get_attr', qparam_full_path)) return graph.create_node('call_function', quantize_op, tuple(inputs), {}) + +def get_custom_module_class_keys(custom_config_dict, custom_config_dict_key) -> List[Any]: + r""" Get all the unique custom module keys in the custom config dict + e.g. + Input: + custom_config_dict = { + "float_to_observed_custom_module_class": { + "static": { + CustomModule1: ObservedCustomModule + }, + "dynamic": { + CustomModule2: DynamicObservedCustomModule + }, + "weight_only": { + CustomModule3: WeightOnlyObservedCustomModule + }, + }, + } + + Output: + # extract all the keys in "static", "dynamic" and "weight_only" dict + [CustomModule1, CustomModule2, CustomModule3] + """ + # using set to dedup + float_custom_module_classes : Set[Any] = set() + custom_module_mapping = custom_config_dict.get(custom_config_dict_key, {}) + for quant_mode in ["static", "dynamic", "weight_only"]: + quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {}) + quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys()) + float_custom_module_classes |= quant_mode_custom_module_classes + return list(float_custom_module_classes) + +def get_linear_prepack_op_for_dtype(dtype): + if dtype == torch.float16: + return torch.ops.quantized.linear_prepack_fp16 + elif dtype == torch.qint8: + return torch.ops.quantized.linear_prepack + else: + raise Exception("can't get linear prepack op for dtype:", dtype) + +# Returns a function that can get a new attribute name for module with given +# prefix, for example, +# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer') +# >> new_name = get_new_observer_name(module) +# new_name will be an unused attribute name on module, e.g. `_observer_1` +def get_new_attr_name_with_prefix(prefix: str) -> Callable: + def get_new_attr_name(module: torch.nn.Module): + def get_attr_name(i: int): + return prefix + str(i) + i = 0 + attr_name = get_attr_name(i) + while hasattr(module, attr_name): + i += 1 + attr_name = get_attr_name(i) + return attr_name + return get_new_attr_name + +def collect_producer_nodes(node: Node) -> Optional[List[Node]]: + r''' Starting from a target node, trace back until we hit inpu or + getattr node. This is used to extract the chain of operators + starting from getattr to the target node, for example + def forward(self, x): + observed = self.observer(self.weight) + return F.linear(x, observed) + collect_producer_nodes(observed) will either return a list of nodes that + produces the observed node or None if we can't extract a self contained + graph without free variables(inputs of the forward function). + ''' + nodes = [node] + frontier = [node] + while frontier: + node = frontier.pop() + all_args = list(node.args) + list(node.kwargs.values()) + for arg in all_args: + if not isinstance(arg, Node): + continue + if arg.op == 'placeholder': + # hit input, can't fold in this case + return None + nodes.append(arg) + if not (arg.op == 'call_function' and arg.target == getattr): + frontier.append(arg) + return nodes + +def graph_module_from_producer_nodes( + root: GraphModule, producer_nodes: List[Node]) -> GraphModule: + r''' Construct a graph module from extracted producer nodes + from `collect_producer_nodes` function + Args: + root: the root module for the original graph + producer_nodes: a list of nodes we use to construct the graph + Return: + A graph module constructed from the producer nodes + ''' + assert len(producer_nodes) > 0, 'list of producer nodes can not be empty' + # since we traced back from node to getattrr + producer_nodes.reverse() + graph = Graph() + env: Dict[Any, Any] = {} + + def load_arg(a): + return map_arg(a, lambda node: env[node]) + for producer_node in producer_nodes: + env[producer_node] = graph.node_copy(producer_node, load_arg) + graph.output(load_arg(producer_nodes[-1])) + graph_module = GraphModule(root, graph) + return graph_module + +def assert_and_get_unique_device(module: torch.nn.Module) -> Any: + """ + Returns the unique device for a module, or None if no device is found. + Throws an error if multiple devices are detected. + """ + devices = {p.device for p in module.parameters()} | \ + {p.device for p in module.buffers()} + assert len(devices) <= 1, ( + "prepare only works with cpu or single-device CUDA modules, " + "but got devices {}".format(devices) + ) + device = next(iter(devices)) if len(devices) > 0 else None + return device diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py index 163bd037467e4..2cc579f66087a 100644 --- a/torch/quantization/observer.py +++ b/torch/quantization/observer.py @@ -2,10 +2,11 @@ import warnings from abc import ABCMeta, abstractmethod from functools import partial -from typing import List, Tuple, Optional - +from typing import Any, List, Tuple, Optional, Dict, Union +from collections import OrderedDict import torch import torch.nn as nn +import re def _with_args(cls_or_self, **kwargs): r"""Wrapper that allows creation of class factories. @@ -37,7 +38,7 @@ def __repr__(self): return r -ABC = ABCMeta(str("ABC"), (object,), {}) # compatible with Python 2 *and* 3: +ABC: Any = ABCMeta(str("ABC"), (object,), {}) # compatible with Python 2 *and* 3: class ObserverBase(ABC, nn.Module): @@ -110,6 +111,8 @@ class _ObserverBase(ObserverBase): # min_val and max_val buffers from torch.Size([0]) to torch.Size([]) _version = 2 + eps: torch.Tensor + def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=None, quant_max=None): super(_ObserverBase, self).__init__(dtype=dtype) @@ -133,7 +136,8 @@ def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, assert self.dtype in ( torch.qint8, torch.quint8, - ), "Default Observer only works for qint8 and quint8 data type" + torch.quint4x2, + ), "Default Observer only works for qint8, quint8 and quint4x2 data type" self.has_customized_qrange = (quant_min is not None) and (quant_max is not None) if self.has_customized_qrange: self._validate_qmin_qmax(quant_min, quant_max) @@ -154,8 +158,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) @torch.jit.export - def _validate_qmin_qmax(self, quant_min, quant_max): - # type: (int, int) -> None + def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None: r"""Validates that the user-specified quantization range is properly initialized and within the given bound supported by the observer dtype. @@ -175,8 +178,7 @@ def _validate_qmin_qmax(self, quant_min, quant_max): assert quant_min < quant_max, "qmin must be strictly less than qmax for user-specified quantization range." @torch.jit.export - def _calculate_qmin_qmax(self): - # type: () -> Tuple[int, int] + def _calculate_qmin_qmax(self) -> Tuple[int, int]: r"""Calculates actual qmin and qmax based on the quantization range, observer datatype and if range is reduced. """ @@ -207,16 +209,17 @@ def _calculate_qmin_qmax(self): quant_min, quant_max = -64, 63 else: quant_min, quant_max = -128, 127 - else: + elif self.dtype == torch.quint8: if self.reduce_range: quant_min, quant_max = 0, 127 else: quant_min, quant_max = 0, 255 + else: + quant_min, quant_max = 0, 15 return quant_min, quant_max @torch.jit.export - def _calculate_qparams(self, min_val, max_val): - # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor] + def _calculate_qparams(self, min_val: torch.Tensor, max_val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: r"""Calculates the quantization parameters, given min and max value tensors. Works for both per tensor and per channel cases @@ -255,9 +258,9 @@ def _calculate_qparams(self, min_val, max_val): min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) - scale = torch.ones(min_val_neg.size(), dtype=torch.float32) - zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64) - device = 'cuda' if min_val_neg.is_cuda else 'cpu' + device = min_val_neg.device + scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device) + zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) if self.qscheme == torch.per_tensor_symmetric or self.qscheme == torch.per_channel_symmetric: max_val_pos = torch.max(-min_val_neg, max_val_pos) @@ -294,7 +297,6 @@ def _calculate_qparams(self, min_val, max_val): if self.qscheme == torch.per_channel_affine_float_qparams: zero_point = torch.tensor([float(zero_point)], dtype=zero_point.dtype, device=device) - return scale, zero_point @@ -361,6 +363,8 @@ class MinMaxObserver(_ObserverBase): .. note:: If the running minimum equals to the running maximum, the scale and zero_point are set to 1.0 and 0. """ + min_val: torch.Tensor + max_val: torch.Tensor def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=None, quant_max=None): @@ -386,6 +390,8 @@ def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, def forward(self, x_orig): r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig x = x_orig.detach() # avoid keeping autograd tape x = x.to(self.min_val.dtype) min_val_cur, max_val_cur = torch._aminmax(x) @@ -459,6 +465,8 @@ def __init__(self, averaging_constant=0.01, dtype=torch.quint8, quant_max=quant_max) def forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig x = x_orig.detach() # avoid keeping autograd tape x = x.to(self.min_val.dtype) min_val = self.min_val @@ -475,82 +483,6 @@ def forward(self, x_orig): self.max_val.copy_(max_val) return x_orig - -class MinMaxDynamicQuantObserver(MinMaxObserver): - r"""Observer module for computing the quantization parameters based on the - tensor min and max values in dynamic quantization. - - This observer will mimic the quantization steps followed in the operator - to compute the activation tensor quantization parameters at run-time. - - Args: - dtype: Quantized data type - qscheme: Quantization scheme to be used - reduce_range: Reduces the range of the quantized data type by 1 bit - - .. warning:: Only works with ``torch.per_tensor_symmetric`` quantization scheme - - .. warning:: :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``. - - .. note:: If the running minimum equals to the running maximum, the scale - and zero_point are set to 0.1 and 0. - """ - - @torch.jit.export - def calculate_qparams(self): - r"""Calculates the quantization parameters.""" - - if self.max_val == float('-inf') and self.min_val == float('inf'): - return torch.tensor([1.0]), torch.tensor([0]) - - assert self.min_val <= self.max_val, "min {} should be less than max {}".format( - self.min_val, self.max_val - ) - - if self.dtype == torch.qint8: - if self.reduce_range: - qmin, qmax = -64, 63 - else: - qmin, qmax = -128, 127 - else: # dtype == torch.quint8 - if self.reduce_range: - qmin, qmax = 0, 127 - else: - qmin, qmax = 0, 255 - - max_val, min_val = self.max_val.to(dtype=torch.float), self.min_val.to(dtype=torch.float) - - # Extend the min_val and max_val to ensure that it contains 0. - min_val = torch.min(min_val, torch.tensor(0.).to(dtype=torch.float)) - max_val = torch.max(max_val, torch.tensor(0.).to(dtype=torch.float)) - - scale = (max_val.to(dtype=torch.double) - min_val) / float(qmax - qmin) - - if scale == 0.0 or torch.isinf(1.0 / scale): - scale = torch.tensor(0.1).to(dtype=torch.float) - zero_point = 0 - - zero_point_from_min = qmin - min_val / scale.to(dtype=torch.double) - zero_point_from_max = qmax - max_val / scale.to(dtype=torch.double) - zero_point_from_min_error = abs(qmin) - abs(min_val / scale.to(dtype=torch.double)) - zero_point_from_max_error = abs(qmax) - abs(max_val / scale.to(dtype=torch.double)) - - if zero_point_from_min_error < zero_point_from_max_error: - initial_zero_point = zero_point_from_min - else: - initial_zero_point = zero_point_from_max - - nudged_zero_point = 0 - - if initial_zero_point < qmin: - nudged_zero_point = qmin - elif initial_zero_point > qmax: - nudged_zero_point = qmax - else: - nudged_zero_point = int(initial_zero_point.round()) - - return scale.to(dtype=torch.float), torch.tensor([nudged_zero_point]) - class PerChannelMinMaxObserver(_ObserverBase): r"""Observer module for computing the quantization parameters based on the running per channel min and max values. @@ -576,6 +508,9 @@ class PerChannelMinMaxObserver(_ObserverBase): .. note:: If the running minimum equals to the running maximum, the scales and zero_points are set to 1.0 and 0. """ + min_vals: torch.Tensor + max_vals: torch.Tensor + def __init__(self, ch_axis=0, dtype=torch.quint8, qscheme=torch.per_channel_affine, reduce_range=False, @@ -601,6 +536,8 @@ def forward(self, x_orig): return self._forward(x_orig) def _forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig x = x_orig.detach() # avoid keeping autograd tape min_vals = self.min_vals max_vals = self.max_vals @@ -633,9 +570,10 @@ def calculate_qparams(self): def extra_repr(self): return "min_val={}, max_val={}".format(self.min_vals, self.max_vals) - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): - + @torch.jit.export + def _load_from_state_dict(self, state_dict: Union[Dict[str, torch.Tensor], Dict[str, torch.Tensor]], prefix: str, + local_metadata: Dict[str, torch.Tensor], strict: bool, + missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str]): local_state = ['min_vals', 'max_vals'] for name in local_state: key = prefix + name @@ -649,10 +587,26 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, self.min_vals.resize_(val.shape) else: self.max_vals.resize_(val.shape) + # For torchscript module we need to update the attributes here since we do not + # call the `_load_from_state_dict` function defined module.py + if torch.jit.is_scripting(): + if name == 'min_vals': + self.min_vals.copy_(val) + else: + self.max_vals.copy_(val) elif strict: missing_keys.append(key) - super(PerChannelMinMaxObserver, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs) + + if not torch.jit.is_scripting(): + super(PerChannelMinMaxObserver, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + @torch.jit.export + def _load_from_state_dict_script(self, state_dict: Union[Dict[str, torch.Tensor], Dict[str, torch.Tensor]], + prefix: str, local_metadata: Dict[str, torch.Tensor], strict: bool, + missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str]): + + self._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver): r"""Observer module for computing the quantization parameters based on the @@ -690,6 +644,8 @@ def __init__(self, averaging_constant=0.01, ch_axis=0, dtype=torch.quint8, self.averaging_constant = averaging_constant def forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig x = x_orig.detach() # avoid keeping autograd tape x = x.to(self.min_vals.dtype) min_vals = self.min_vals @@ -737,6 +693,9 @@ class HistogramObserver(_ObserverBase): 3. Compute the scale and zero point the same way as in the :class:`~torch.quantization.MinMaxObserver` """ + histogram: torch.Tensor + min_val: torch.Tensor + max_val: torch.Tensor def __init__(self, bins=2048, upsample_rate=128, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False): @@ -879,8 +838,10 @@ def _compute_quantization_error(next_start_bin, next_end_bin, norm_type): return new_min, new_max @torch.jit.ignore - def _adjust_min_max(self, combined_min, combined_max, upsample_rate): - # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor, int, int] + def _adjust_min_max(self, + combined_min: torch.Tensor, + combined_max: torch.Tensor, + upsample_rate: int) -> Tuple[torch.Tensor, torch.Tensor, int, int]: # We ensure that: # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins) # This allows us to have a common grid of resolution s, where we can align @@ -888,17 +849,22 @@ def _adjust_min_max(self, combined_min, combined_max, upsample_rate): # start_idx maps min_val to the histogram bin index. hist_bin_width = (self.max_val - self.min_val) / (self.bins * upsample_rate) - downsample_rate = torch.ceil((combined_max - combined_min) / (self.bins * hist_bin_width)).to(torch.int).item() + downsample_rate = int(torch.ceil((combined_max - combined_min) / (self.bins * hist_bin_width)).item()) e = downsample_rate * (self.bins * hist_bin_width) - (combined_max - combined_min) # Relax only the max, not the min, so that for one sided distributions, min stays at zero combined_max = combined_max + e combined_min = combined_min - start_idx = torch.round((self.min_val - combined_min) / hist_bin_width).to(torch.int).item() + start_idx = int(torch.round((self.min_val - combined_min) / hist_bin_width).item()) return combined_min, combined_max, downsample_rate, start_idx @torch.jit.ignore - def _combine_histograms(self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins): - # type: (Tensor, Tensor, int, int, int, int) -> Tensor + def _combine_histograms(self, + orig_hist: torch.Tensor, + new_hist: torch.Tensor, + upsample_rate: int, + downsample_rate: int, + start_idx: int, + Nbins: int) -> torch.Tensor: # First up-sample the histogram with new data by a factor of L # This creates an approximate probability density thats piecwise constant upsampled_histogram = new_hist.repeat_interleave(upsample_rate) @@ -919,8 +885,9 @@ def _combine_histograms(self, orig_hist, new_hist, upsample_rate, downsample_rat orig_hist = orig_hist + interpolated_histogram.to(torch.float) return orig_hist - def forward(self, x_orig): - # type: (Tensor) -> Tensor + def forward(self, x_orig: torch.Tensor) -> torch.Tensor: + if x_orig.numel() == 0: + return x_orig x = x_orig.detach() min_val = self.min_val max_val = self.max_val @@ -932,7 +899,10 @@ def forward(self, x_orig): self.min_val.copy_(min_val) self.max_val.resize_(max_val.shape) self.max_val.copy_(max_val) - torch.histc(x, self.bins, min=min_val, max=max_val, out=self.histogram) + assert min_val.numel() == 1 and max_val.numel() == 1, ( + "histogram min/max values must be scalar." + ) + torch.histc(x, self.bins, min=int(min_val), max=int(max_val), out=self.histogram) else: new_min, new_max = torch._aminmax(x) combined_min = torch.min(new_min, min_val) @@ -942,7 +912,10 @@ def forward(self, x_orig): # and then downsampling the histogram efficiently combined_min, combined_max, downsample_rate, start_idx = \ self._adjust_min_max(combined_min, combined_max, self.upsample_rate) - combined_histogram = torch.histc(x, self.bins, min=combined_min, max=combined_max) + assert combined_min.numel() == 1 and combined_max.numel() == 1, ( + "histogram min/max values must be scalar." + ) + combined_histogram = torch.histc(x, self.bins, min=int(combined_min), max=int(combined_max)) if combined_min == min_val and combined_max == max_val: combined_histogram += self.histogram else: @@ -1025,10 +998,15 @@ class PlaceholderObserver(ObserverBase): custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation (Can be used in Graph Mode Passes for special case ops). """ - def __init__(self, dtype=torch.float16, custom_op_name=""): + def __init__(self, dtype=torch.float32, custom_op_name="", compute_dtype=None): super(PlaceholderObserver, self).__init__(dtype=dtype) + # dtype of input of the target operator, e.g. for dynamic quantization + # ops, the dtype will be float32 self.dtype = dtype self.custom_op = custom_op_name + # used for configuration of computation type for dynamic quantization + if compute_dtype: + self.compute_dtype = compute_dtype def forward(self, x): return x @@ -1092,14 +1070,77 @@ def forward(self, x): def calculate_qparams(self): raise Exception("calculate_qparams should not be called for NoopObserver") +def _is_observer_script_module(mod, obs_type_name): + ''' Returns true if given mod is an instance of Observer script module. + ''' + if isinstance(mod, torch.jit.RecursiveScriptModule): + # qualified name looks like '__torch__.torch.quantization.observer.___torch_mangle_2.MinMaxObserver' + suffix = mod._c.qualified_name.split('.', 1)[1] + name = re.sub(r'\.___torch_mangle_\d+', '', suffix) + return obs_type_name in name + return False + +def _is_activation_post_process(module): + return (isinstance(module, torch.quantization.ObserverBase) or + isinstance(module, torch.quantization.FakeQuantize) or + _is_observer_script_module(module, 'torch.quantization.observer')) + +def _is_per_channel_script_obs_instance(module): + if isinstance(module, torch.jit.RecursiveScriptModule): + return _is_observer_script_module(module, "torch.quantization.observer.PerChannelMinMaxObserver") or\ + _is_observer_script_module(module, "torch.quantization.observer.MovingAveragePerChannelMinMaxObserver") + return False + +def get_observer_state_dict(mod): + r""" + Returns the state dict corresponding to the observer stats. + Traverse the model state_dict and extract out the stats. + """ + od = OrderedDict() + if isinstance(mod, torch.jit.RecursiveScriptModule): + for k, v in mod.state_dict().items(): + if 'observer' in k: + od[k] = v + else: + # path for GraphModule and nn.Module (eager mode) + for k, v in mod.state_dict().items(): + if 'activation_post_process' in k: + od[k] = v + od._metadata = mod.state_dict()._metadata # type: ignore[attr-defined] + return od + +def load_observer_state_dict(mod, obs_dict): + r""" + Given input model and a state_dict containing model observer stats, + load the stats back into the model. The observer state_dict can be saved + using torch.quantization.get_observer_state_dict + """ + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + for name, module in mod.named_modules(): + prefix = name + '.' + if _is_activation_post_process(module): + if _is_per_channel_script_obs_instance(module): + # For per-channel observers we need to call a custom load_from_state_dict to resize the tensor. + # However this is not called when the module is scripted and we end up calling the default one in module.py + module._load_from_state_dict_script(obs_dict, prefix, {}, True, missing_keys, unexpected_keys, []) + else: + module._load_from_state_dict(obs_dict, prefix, {}, False, missing_keys, unexpected_keys, []) + for k in missing_keys: + if 'observer' in k or 'activation_post_process' in k: + raise Exception("Missing keys for observer {} in state_dict".format(k)) + for k in unexpected_keys: + if 'observer' in k or 'activation_post_process' in k: + raise Exception("Unexpected keys for observer {} in state_dict".format(k)) # Restrict activations to be in the range (0,127) default_observer = MinMaxObserver.with_args(reduce_range=True) +default_placeholder_observer = PlaceholderObserver default_debug_observer = RecordingObserver default_weight_observer = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric) default_histogram_observer = HistogramObserver.with_args(reduce_range=True) default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric) -default_dynamic_quant_observer = MinMaxDynamicQuantObserver +default_dynamic_quant_observer = PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8) default_float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) diff --git a/torch/quantization/qconfig.py b/torch/quantization/qconfig.py index 253abbaf4445d..2d91d8ab6b3ec 100644 --- a/torch/quantization/qconfig.py +++ b/torch/quantization/qconfig.py @@ -3,6 +3,8 @@ from .fake_quantize import * import torch.nn as nn +from typing import Union + class QConfig(namedtuple('QConfig', ['activation', 'weight'])): """ Describes how to quantize a layer or a part of the network by providing @@ -67,8 +69,11 @@ def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity): per_channel_dynamic_qconfig = QConfigDynamic(activation=default_dynamic_quant_observer, weight=default_per_channel_weight_observer) -float_qparams_dynamic_qconfig = QConfigDynamic(activation=default_dynamic_quant_observer, - weight=default_float_qparams_observer) +# TODO: this is weight only quant, change this to QConfigWeightOnly +# or remove the QConfigDynamic later +float_qparams_weight_only_qconfig = QConfigDynamic( + activation=default_placeholder_observer, + weight=default_float_qparams_observer) default_qat_qconfig = QConfig(activation=default_fake_quant, weight=default_weight_fake_quant) @@ -106,3 +111,18 @@ def get_default_qat_qconfig(backend='fbgemm'): else: qconfig = default_qat_qconfig return qconfig + +def assert_valid_qconfig(qconfig: Union[QConfig, QConfigDynamic], + mod: torch.nn.Module) -> None: + is_conv_transpose_mod = ( + isinstance(mod, torch.nn.ConvTranspose1d) or + isinstance(mod, torch.nn.ConvTranspose2d) or + isinstance(mod, torch.nn.ConvTranspose3d)) + if is_conv_transpose_mod: + example_observer = qconfig.weight() + is_per_channel = ( + isinstance(example_observer, torch.quantization.PerChannelMinMaxObserver) or + isinstance(example_observer, torch.quantization.MovingAveragePerChannelMinMaxObserver) + ) + assert not is_per_channel, \ + 'Per channel weight observer is not supported yet for ConvTranspose{n}d.' diff --git a/torch/quantization/quant_type.py b/torch/quantization/quant_type.py index 212dec1fe28c7..463d086b39b62 100644 --- a/torch/quantization/quant_type.py +++ b/torch/quantization/quant_type.py @@ -1,4 +1,3 @@ - import enum # Quantization type (dynamic quantization, static quantization). @@ -7,3 +6,14 @@ class QuantType(enum.IntEnum): DYNAMIC = 0 STATIC = 1 QAT = 2 + WEIGHT_ONLY = 3 + + +def quant_type_to_str(quant_type): + m = { + QuantType.STATIC: "static", + QuantType.DYNAMIC: "dynamic", + QuantType.QAT: "qat", + QuantType.WEIGHT_ONLY: "weight_only", + } + return m[quant_type] diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py index 60d166ae4896d..802d81d3ca0cd 100644 --- a/torch/quantization/quantization_mappings.py +++ b/torch/quantization/quantization_mappings.py @@ -1,3 +1,5 @@ +import copy + import torch from torch import nn @@ -9,181 +11,195 @@ import torch.nn.quantized.dynamic as nnqd import torch.nn.qat as nnqat -from .stubs import QuantStub, DeQuantStub +from typing import Optional, Union, Dict, Set, Callable, Any -# Map for swapping float module to quantized ones -STATIC_QUANT_MODULE_MAPPINGS = { - nn.Linear: nnq.Linear, - nn.ReLU: nnq.ReLU, - nn.ReLU6: nnq.ReLU6, - nn.Hardswish: nnq.Hardswish, - nn.ELU: nnq.ELU, +from .stubs import QuantStub, DeQuantStub +from .fake_quantize import ( + default_affine_fixed_qparams_fake_quant, + default_symmetric_fixed_qparams_fake_quant, +) +from .utils import get_combined_dict + +# Default map for swapping float module to quantized ones +DEFAULT_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = { + QuantStub: nnq.Quantize, + DeQuantStub: nnq.DeQuantize, + nn.BatchNorm2d: nnq.BatchNorm2d, + nn.BatchNorm3d: nnq.BatchNorm3d, nn.Conv1d: nnq.Conv1d, nn.Conv2d: nnq.Conv2d, nn.Conv3d: nnq.Conv3d, nn.ConvTranspose1d: nnq.ConvTranspose1d, nn.ConvTranspose2d: nnq.ConvTranspose2d, - nn.BatchNorm2d: nnq.BatchNorm2d, - nn.BatchNorm3d: nnq.BatchNorm3d, - nn.LayerNorm: nnq.LayerNorm, + nn.ELU: nnq.ELU, + nn.Embedding: nnq.Embedding, + nn.EmbeddingBag: nnq.EmbeddingBag, nn.GroupNorm: nnq.GroupNorm, + nn.Hardswish: nnq.Hardswish, nn.InstanceNorm1d: nnq.InstanceNorm1d, nn.InstanceNorm2d: nnq.InstanceNorm2d, nn.InstanceNorm3d: nnq.InstanceNorm3d, - nn.Embedding: nnq.Embedding, - nn.EmbeddingBag: nnq.EmbeddingBag, - QuantStub: nnq.Quantize, - DeQuantStub: nnq.DeQuantize, + nn.LayerNorm: nnq.LayerNorm, + nn.LeakyReLU: nnq.LeakyReLU, + nn.modules.linear._LinearWithBias: nnq.Linear, + nn.Linear: nnq.Linear, + nn.ReLU6: nnq.ReLU6, # Wrapper Modules: nnq.FloatFunctional: nnq.QFunctional, # Intrinsic modules: + nni.BNReLU2d: nniq.BNReLU2d, + nni.BNReLU3d: nniq.BNReLU3d, nni.ConvReLU1d: nniq.ConvReLU1d, nni.ConvReLU2d: nniq.ConvReLU2d, nni.ConvReLU3d: nniq.ConvReLU3d, nni.LinearReLU: nniq.LinearReLU, - nni.BNReLU2d: nniq.BNReLU2d, - nni.BNReLU3d: nniq.BNReLU3d, - nniqat.ConvReLU2d: nniq.ConvReLU2d, - nniqat.LinearReLU: nniq.LinearReLU, + nniqat.ConvBn1d: nnq.Conv1d, nniqat.ConvBn2d: nnq.Conv2d, + nniqat.ConvBnReLU1d: nniq.ConvReLU1d, nniqat.ConvBnReLU2d: nniq.ConvReLU2d, + nniqat.ConvReLU2d: nniq.ConvReLU2d, + nniqat.LinearReLU: nniq.LinearReLU, # QAT modules: nnqat.Linear: nnq.Linear, nnqat.Conv2d: nnq.Conv2d, } -# Map for swapping float module to qat modules -QAT_MODULE_MAPPINGS = { - nn.Linear: nnqat.Linear, +# Default map for swapping float module to qat modules +DEFAULT_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = { nn.Conv2d: nnqat.Conv2d, + nn.Linear: nnqat.Linear, + nn.modules.linear._LinearWithBias: nnqat.Linear, # Intrinsic modules: + nni.ConvBn1d: nniqat.ConvBn1d, nni.ConvBn2d: nniqat.ConvBn2d, + nni.ConvBnReLU1d: nniqat.ConvBnReLU1d, nni.ConvBnReLU2d: nniqat.ConvBnReLU2d, nni.ConvReLU2d: nniqat.ConvReLU2d, nni.LinearReLU: nniqat.LinearReLU } -# Map for swapping dynamic modules -DYNAMIC_QUANT_MODULE_MAPPINGS = { +# Default map for swapping dynamic modules +DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = { + nn.GRUCell: nnqd.GRUCell, nn.Linear: nnqd.Linear, + nn.modules.linear._LinearWithBias: nnqd.Linear, nn.LSTM: nnqd.LSTM, + nn.GRU: nnqd.GRU, nn.LSTMCell: nnqd.LSTMCell, nn.RNNCell: nnqd.RNNCell, - nn.GRUCell: nnqd.GRUCell, } # Whitelist for propagating the qconfig -_EXCLUDE_QCONFIG_PROPAGATE_LIST = { - DeQuantStub, -} -_INCLUDE_QCONFIG_PROPAGATE_LIST = { +_INCLUDE_QCONFIG_PROPAGATE_LIST : Set[Callable] = { nn.Sequential, } -# mapping from floating point function or torch ops to quantized ops -FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS = { +# Default mapping from floating point function or torch ops to quantized ops +# TODO: merge with default static mapping +DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS : Dict[Union[Callable, str], Callable] = { F.elu: torch._ops.ops.quantized.elu, F.hardswish: torch._ops.ops.quantized.hardswish, F.instance_norm: torch._ops.ops.quantized.instance_norm, F.layer_norm: torch._ops.ops.quantized.layer_norm, + F.leaky_relu: torch._ops.ops.quantized.leaky_relu, } -def register_static_quant_module_mapping( - float_source_module_class, static_quant_target_module_class): - ''' Register a mapping from `float_source__module_class` to `static_quant_target_module_class` - `static_quant_target_module_class` must have from_float defined as a class method - The mapping is used in the convert step of post training static quantization to - convert a float module to a statically quantized module. - ''' - assert hasattr(static_quant_target_module_class, 'from_float'), 'from_float must be defined' + \ - ' in quantized module class' - STATIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = static_quant_target_module_class +# mapping from module to output activation post process class +DEFAULT_MODULE_TO_ACT_POST_PROCESS : Dict[Callable, Callable] = { + nn.Hardsigmoid: default_affine_fixed_qparams_fake_quant, + nn.Sigmoid: default_affine_fixed_qparams_fake_quant, + nn.Tanh: default_symmetric_fixed_qparams_fake_quant, +} -def get_static_quant_module_mappings(): +def get_default_static_quant_module_mappings() -> Dict[Callable, Any]: ''' Get module mapping for post training static quantization ''' - return STATIC_QUANT_MODULE_MAPPINGS + return copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS) -def get_static_quant_module_class(float_module_class): - ''' Get the statically quantized module class corresponding to +def get_static_quant_module_class( + float_module_class: Callable, + additional_static_quant_mapping: Optional[Dict[Callable, Any]] = None) -> Any: + r"""n Get the statically quantized module class corresponding to the floating point module class - ''' - static_quant_module_class = STATIC_QUANT_MODULE_MAPPINGS.get(float_module_class, None) + """ + if additional_static_quant_mapping is None: + additional_static_quant_mapping = {} + all_mappings = get_combined_dict(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, additional_static_quant_mapping) + static_quant_module_class = all_mappings.get(float_module_class, None) assert static_quant_module_class is not None, \ - 'Floating point module class {}'.format(float_module_class) + \ - ' does not have a corresponding quantized module class' - return static_quant_module_class - -def register_qat_module_mapping(float_source_module_class, qat_target_module_class): - '''Register a mapping from `float_source_module_class` to `qat_target_module_class`, - `qat_target_module_class` must have from_float defined as a class method - This mapping is used in prepare step of quantization aware training to swap - a float module to a qat module. - ''' - assert hasattr(qat_target_module_class, 'from_float'), 'from_float must be defined' + \ - ' in qat module class' - QAT_MODULE_MAPPINGS[float_source_module_class] = qat_target_module_class - -def get_qat_module_mappings(): - ''' Get module mapping for quantization aware training - ''' - return QAT_MODULE_MAPPINGS - -def register_dynamic_quant_module_class(float_source_module_class, dynamic_quant_target_module_class): - ''' Register a mapping from `float_source_module_class` to `dynamic_quant_target_module_class`, - `dynamic_quant_target_module_class` must have from_float defined as a class method - This mapping is used in convert step of post training dynamic - quantization to swap a float module to a dynamically quantized - module. + "Floating point module class {}".format(str(float_module_class)) + \ + " does not have a corresponding quantized module class" + return copy.deepcopy(static_quant_module_class) + +def get_dynamic_quant_module_class( + float_module_class: Callable, + additional_dynamic_quant_mapping: Optional[Dict[Callable, Any]] = None) -> Any: + r"""n Get the dynamically quantized module class corresponding to + the floating point module class + """ + if additional_dynamic_quant_mapping is None: + additional_dynamic_quant_mapping = {} + all_mappings = get_combined_dict(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping) + dynamic_quant_module_class = all_mappings.get(float_module_class, None) + assert dynamic_quant_module_class is not None, \ + "Floating point module class {}".format(str(float_module_class)) + \ + " does not have a corresponding quantized module class" + return copy.deepcopy(dynamic_quant_module_class) + +def get_default_qat_module_mappings() -> Dict[Callable, Any]: + ''' Get default module mapping for quantization aware training ''' - assert hasattr(dynamic_quant_target_module_class, 'from_float'), 'from_float must be defined' + \ - ' in dynamically quantized module type' - DYNAMIC_QUANT_MODULE_MAPPINGS[float_source_module_class] = dynamic_quant_target_module_class + return copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS) -def get_dynamic_quant_module_mappings(): +def get_default_dynamic_quant_module_mappings() -> Dict[Callable, Any]: ''' Get module mapping for post training dynamic quantization ''' - return DYNAMIC_QUANT_MODULE_MAPPINGS + return DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS -def get_qconfig_propagation_list(): - ''' Get the list of module types that we'll attach qconfig +def get_default_qconfig_propagation_list() -> Set[Callable]: + ''' Get the default list of module types that we'll attach qconfig attribute to in prepare ''' QCONFIG_PROPAGATE_MODULE_CLASS_LIST = ( - (set(STATIC_QUANT_MODULE_MAPPINGS.keys()) | - set(QAT_MODULE_MAPPINGS.keys()) | - set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) | - _INCLUDE_QCONFIG_PROPAGATE_LIST) - - _EXCLUDE_QCONFIG_PROPAGATE_LIST + (set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) | + set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) | + set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) | + _INCLUDE_QCONFIG_PROPAGATE_LIST) ) - return QCONFIG_PROPAGATE_MODULE_CLASS_LIST + return copy.deepcopy(QCONFIG_PROPAGATE_MODULE_CLASS_LIST) -def get_compare_output_module_list(): +def get_default_compare_output_module_list() -> Set[Callable]: ''' Get list of module class types that we will record output in numeric suite ''' NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = ( - set(STATIC_QUANT_MODULE_MAPPINGS.values()) - | set(QAT_MODULE_MAPPINGS.values()) - | set(DYNAMIC_QUANT_MODULE_MAPPINGS.values()) - | set(STATIC_QUANT_MODULE_MAPPINGS.keys()) - | set(QAT_MODULE_MAPPINGS.keys()) - | set(DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) + set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.values()) + | set(DEFAULT_QAT_MODULE_MAPPINGS.values()) + | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.values()) + | set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) | _INCLUDE_QCONFIG_PROPAGATE_LIST - ) - _EXCLUDE_QCONFIG_PROPAGATE_LIST - return NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST - -def register_quantized_operator_mapping(float_op, quantized_op): - ''' Register a mapping from `floating_point_op` (torch or functional) to `quantized_op` - This is used in convert step of fx based graph mode quantization - to convert a float op to quantized op. - ''' - FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS[float_op] = quantized_op + ) + return copy.deepcopy(NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST) -def get_quantized_operator(float_op): +# TODO: merge with get_static_quant_module_class +def get_quantized_operator(float_op: Union[Callable, str]) -> Callable: ''' Get the quantized operator corresponding to the float operator ''' - quantized_op = FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None) + quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None) assert quantized_op is not None, \ - 'Operator {} does not have corresponding quantized op'.format(float_op) + 'Operator {} does not have corresponding quantized op'.format(str(float_op)) return quantized_op + +def _get_special_act_post_process(module: torch.nn.Module) -> Optional[Callable]: + r""" Get the special activation post process for `module`, this has + higher priority than the activation post process in `qconfig` + e.g. + input: torch.nn.Sigmoid + output: default_affine_fixed_qparam_fake_quant + """ + return DEFAULT_MODULE_TO_ACT_POST_PROCESS.get(type(module), None) + +def _has_special_act_post_process(module: torch.nn.Module) -> bool: + return module.training and type(module) in DEFAULT_MODULE_TO_ACT_POST_PROCESS diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py index 19a27e62ac5ba..77752a8af9c9a 100644 --- a/torch/quantization/quantize.py +++ b/torch/quantization/quantize.py @@ -1,29 +1,28 @@ - import copy import itertools import warnings import torch import torch.nn as nn -import torch.nn.intrinsic as nni import torch.nn.quantized as nnq -import torch.nn.intrinsic.qat as nniqat - -from .quantization_mappings import (get_dynamic_quant_module_mappings, - get_static_quant_module_mappings, - get_qat_module_mappings, - get_qconfig_propagation_list) - -from .custom_module_class_mappings import ( - is_custom_module_class, - get_observed_custom_module_class, - get_quantized_custom_module_class, - mark_observed_custom_module, - is_observed_custom_module, +import torch.nn.quantizable as nnqa +from torch.nn.intrinsic import _FusedModule + +from .quantization_mappings import ( + get_default_dynamic_quant_module_mappings, + get_default_static_quant_module_mappings, + get_default_qat_module_mappings, + get_default_qconfig_propagation_list, + _has_special_act_post_process, + _get_special_act_post_process, ) from .stubs import DeQuantStub, QuantWrapper -from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_dynamic_qconfig +from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_weight_only_qconfig + +def is_activation_post_process(module): + return (isinstance(module, torch.quantization.ObserverBase) or + isinstance(module, torch.quantization.FakeQuantizeBase)) def _propagate_qconfig_helper(module, qconfig_dict, allow_list=None, qconfig_parent=None, prefix=''): @@ -45,12 +44,14 @@ def _propagate_qconfig_helper(module, qconfig_dict, allow_list=None, """ # TODO: Add test if allow_list is None: - allow_list = get_qconfig_propagation_list() + allow_list = get_default_qconfig_propagation_list() module_qconfig = qconfig_dict.get(type(module), qconfig_parent) module_qconfig = qconfig_dict.get(prefix, module_qconfig) module_qconfig = getattr(module, 'qconfig', module_qconfig) + torch.quantization.qconfig.assert_valid_qconfig(module_qconfig, module) + module.qconfig = module_qconfig for name, child in module.named_children(): module_prefix = prefix + '.' + name if prefix else name @@ -81,19 +82,12 @@ def _observer_forward_hook(self, input, output): """ return self.activation_post_process(output) -def _observer_forward_pre_hook(self, input): - ''' Forward pre hook that calls observer on the input (can be a tuple of values) - ''' - self.activation_pre_process(*input) - # Returning nothing is Ok, Module._call_impl will intrepret this - # as the pre_hook making no changes to the input, as desired - def register_activation_post_process_hook(module): assert hasattr(module, 'activation_post_process'), \ 'Expect activation_post_process attribut already attached to the module' return module.register_forward_hook(_observer_forward_hook) -def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, prehook=None): +def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None): r"""Add observer for the leaf child of the module. This function insert observer module to all leaf child module that @@ -108,7 +102,10 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No None, module is modified inplace with added observer modules and forward_hooks """ if qconfig_propagation_list is None: - qconfig_propagation_list = get_qconfig_propagation_list() + qconfig_propagation_list = get_default_qconfig_propagation_list() + + if custom_module_class_mapping is None: + custom_module_class_mapping = {} # respect device affinity when adding observers if device is None: @@ -119,8 +116,8 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No ) device = next(iter(devices)) if len(devices) > 0 else None - def get_activation_post_process(qconfig, device): - activation = qconfig.activation() + def get_activation_post_process(qconfig, device, special_act_post_process=None): + activation = qconfig.activation() if special_act_post_process is None else special_act_post_process() if device is not None: activation.to(device) return activation @@ -128,49 +125,48 @@ def get_activation_post_process(qconfig, device): def needs_observation(m): return hasattr(m, 'qconfig') and m.qconfig is not None - def insert_activation_post_process(m): + def insert_activation_post_process(m, special_act_post_process=None): """ Adds an activation post process module and register a post hook that calls the module """ - if needs_observation(m): + # We don't insert observer/fake_quantize for DeQuantStub + if needs_observation(m) and not isinstance(m, DeQuantStub): # observer and hook will be gone after we swap the module - m.add_module('activation_post_process', get_activation_post_process(m.qconfig, device)) + m.add_module('activation_post_process', get_activation_post_process(m.qconfig, device, special_act_post_process)) # Register observer as the first entry in the hook list # All post forward hooks are preserved and will be executed after the observer before convert handle = register_activation_post_process_hook(m) m._forward_hooks.move_to_end(handle.id, last=False) for name, child in module.named_children(): - if type(child) == nnq.FloatFunctional or type(child) == nnq.QFunctional: - if hasattr(child, 'qconfig') and child.qconfig is not None: + if type(child) in [nnq.FloatFunctional, nnq.QFunctional]: + if needs_observation(child): child.activation_post_process = get_activation_post_process(child.qconfig, device) + elif isinstance(child, _FusedModule): + # activation_post_process are now added directly to nn.Sequentail/_FusedModule + if needs_observation(child): + insert_activation_post_process(child) + elif _has_special_act_post_process(child): + special_act_post_process = _get_special_act_post_process(child) + insert_activation_post_process(child, special_act_post_process) elif non_leaf_module_list is not None and type(child) in non_leaf_module_list: - insert_activation_post_process(child) - # TODO: remove if needs_observation(child): - # Attaching prehook - if prehook is not None: - child.add_module('activation_pre_process', prehook()) - child.register_forward_pre_hook(_observer_forward_pre_hook) - elif needs_observation(child) and is_custom_module_class(type(child)): - observed_child = get_observed_custom_module_class(type(child)).from_float(child) - mark_observed_custom_module(observed_child, type(child)) + insert_activation_post_process(child) + elif needs_observation(child) and type(child) in custom_module_class_mapping: + observed_child = custom_module_class_mapping[type(child)].from_float(child) setattr(module, name, observed_child) - insert_activation_post_process(observed_child) + # TODO: These are the modules that cannot be observed + # Once there are more, we should move them to a separate list + if custom_module_class_mapping[type(child)] != nnqa.LSTM: + insert_activation_post_process(observed_child) else: - add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, prehook) + add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping) # Insert observers only for leaf nodes, note that this observer is for # the output of the module, for input QuantStub will observe them if len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential) \ and type(module) in qconfig_propagation_list: insert_activation_post_process(module) - # TOOD: remove - if needs_observation(module): - # Attaching prehook - if prehook is not None: - module.add_module('activation_pre_process', prehook()) - module.register_forward_pre_hook(_observer_forward_pre_hook) def get_unique_devices_(module): return {p.device for p in module.parameters()} | \ @@ -199,7 +195,8 @@ def add_quant_dequant(module): return module def prepare(model, inplace=False, allow_list=None, - observer_non_leaf_module_list=None, prehook=None): + observer_non_leaf_module_list=None, + prepare_custom_config_dict=None): r"""Prepares a copy of the model for quantization calibration or quantization-aware training. Quantization configuration should be assigned preemptively @@ -209,18 +206,37 @@ def prepare(model, inplace=False, allow_list=None, will be propagated. Args: - model: input model to be modified in-place - inplace: carry out model transformations in-place, the original module is mutated - allow_list: list of quantizable modules - observer_non_leaf_module_list: list of non-leaf modules we want to add observer - prehook: observer we want to add to forward_pre_hook + `model`: input model to be modified in-place + `inplace`: carry out model transformations in-place, the original module is mutated + `allow_list`: list of quantizable modules + `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer + `prepare_custom_config_dict`: customization configuration dictionary for prepare function + + .. code-block:: python + + # Example of prepare_custom_config_dict: + prepare_custom_config_dict = { + # user will manually define the corresponding observed + # module class which has a from_float class method that converts + # float custom module to observed custom module + "float_to_observed_custom_module_class": { + CustomModule: ObservedCustomModule + } + } + """ + torch._C._log_api_usage_once("quantization_api.quantize.prepare") + if prepare_custom_config_dict is None: + prepare_custom_config_dict = {} + custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {}) + if not inplace: model = copy.deepcopy(model) + # TODO: remove allow_list qconfig_propagation_list = allow_list if qconfig_propagation_list is None: - qconfig_propagation_list = get_qconfig_propagation_list() + qconfig_propagation_list = get_default_qconfig_propagation_list() propagate_qconfig_(model, qconfig_dict=None) # sanity check common API misusage @@ -229,9 +245,27 @@ def prepare(model, inplace=False, allow_list=None, "passed correct configuration through `qconfig_dict` or " "by assigning the `.qconfig` attribute directly on submodules") - add_observer_(model, qconfig_propagation_list, observer_non_leaf_module_list, prehook=prehook) + add_observer_( + model, qconfig_propagation_list, observer_non_leaf_module_list, + custom_module_class_mapping=custom_module_class_mapping) return model +def _remove_activation_post_process(module): + # TODO: maybe we should change activation_post_process to _activation_post_process + # to prevent it from being used by user + if hasattr(module, 'activation_post_process') and \ + is_activation_post_process(module.activation_post_process): + delattr(module, 'activation_post_process') + + # remove activation_post_proceess hook + handle_ids_to_remove = set() + for handle_id, hook_fn in module._forward_hooks.items(): + if hook_fn is _observer_forward_hook: + handle_ids_to_remove.add(handle_id) + for handle_id in handle_ids_to_remove: + module._forward_hooks.pop(handle_id) + +# TODO: rename to something more general def _remove_qconfig(module): r"""Clean up the qconfig left in the module so that new qconfig can be propagated. @@ -245,6 +279,8 @@ def _remove_qconfig(module): if hasattr(module, "qconfig"): del module.qconfig + _remove_activation_post_process(module) + def quantize(model, run_fn, run_args, mapping=None, inplace=False): r"""Quantize the input float model with post training static quantization. @@ -262,13 +298,14 @@ def quantize(model, run_fn, run_args, mapping=None, inplace=False): Return: Quantized model. """ + torch._C._log_api_usage_once("quantization_api.quantize.quantize") if mapping is None: - mapping = get_static_quant_module_mappings() + mapping = get_default_static_quant_module_mappings() if not inplace: model = copy.deepcopy(model) model.eval() prepare(model, inplace=True) - run_fn(model, run_args) + run_fn(model, *run_args) convert(model, mapping, inplace=True) return model @@ -302,6 +339,7 @@ def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, with which the submodule needs to be replaced """ + torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic") if qconfig_spec is None: if dtype == torch.qint8: qconfig_spec = { @@ -323,7 +361,7 @@ def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, } elif dtype == torch.quint8: qconfig_spec = { - nn.EmbeddingBag : float_qparams_dynamic_qconfig, + nn.EmbeddingBag : float_qparams_weight_only_qconfig, } else: raise ValueError( @@ -334,13 +372,13 @@ def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, elif dtype is torch.float16: default_qconfig = float16_dynamic_qconfig elif dtype is torch.quint8: - default_qconfig = float_qparams_dynamic_qconfig + default_qconfig = float_qparams_weight_only_qconfig else: raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype)) qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig))) if mapping is None: - mapping = get_dynamic_quant_module_mappings() + mapping = get_default_dynamic_quant_module_mappings() if not inplace: model = copy.deepcopy(model) @@ -364,8 +402,10 @@ def prepare_qat(model, mapping=None, inplace=False): inplace: carry out model transformations in-place, the original module is mutated """ + torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat") if mapping is None: - mapping = get_qat_module_mappings() + mapping = get_default_qat_module_mappings() + if not inplace: model = copy.deepcopy(model) @@ -387,36 +427,57 @@ def quantize_qat(model, run_fn, run_args, inplace=False): Return: Quantized model. """ + torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat") if not inplace: model = copy.deepcopy(model) model.train() prepare_qat(model, inplace=True) - run_fn(model, run_args) + run_fn(model, *run_args) convert(model, inplace=True) return model -def convert(module, mapping=None, inplace=False, remove_qconfig=True): +def convert( + module, mapping=None, inplace=False, remove_qconfig=True, + convert_custom_config_dict=None): r"""Converts submodules in input module to a different module according to `mapping` by calling `from_float` method on the target module class. And remove qconfig at the end if remove_qconfig is set to True. Args: - module: input module - mapping: a dictionary that maps from source module type to target - module type, can be overwritten to allow swapping user defined - Modules - inplace: carry out model transformations in-place, the original module - is mutated + `module`: prepared and calibrated module + `mapping`: a dictionary that maps from source module type to target + module type, can be overwritten to allow swapping user defined + Modules + `inplace`: carry out model transformations in-place, the original module + is mutated + `convert_custom_config_dict`: custom configuration dictionary for convert function + + .. code-block:: python + + # Example of convert_custom_config_dict: + convert_custom_config_dict = { + # user will manually define the corresponding quantized + # module class which has a from_observed class method that converts + # observed custom module to quantized custom module + "observed_to_quantized_custom_module_class": { + ObservedCustomModule: QuantizedCustomModule + } + } """ + torch._C._log_api_usage_once("quantization_api.quantize.convert") if not inplace: module = copy.deepcopy(module) - _convert(module, mapping, inplace=True) + _convert( + module, mapping, inplace=True, + convert_custom_config_dict=convert_custom_config_dict) if remove_qconfig: _remove_qconfig(module) return module -def _convert(module, mapping=None, inplace=False): +def _convert( + module, mapping=None, inplace=False, + convert_custom_config_dict=None): r"""Converts submodules in input module to a different module according to `mapping` by calling `from_float` method on the target module class @@ -430,40 +491,29 @@ def _convert(module, mapping=None, inplace=False): """ if mapping is None: - mapping = get_static_quant_module_mappings() + mapping = get_default_static_quant_module_mappings() + if convert_custom_config_dict is None: + convert_custom_config_dict = {} + custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {}) + if not inplace: module = copy.deepcopy(module) reassign = {} - # TODO(jerryzh): remove after deciding on the impl of intrinsic modules - # This is required because intrinsic modules right now are implemented as - # nn.Sequential and we don't want to swap their constituents - SWAPPABLE_MODULES = (nni.ConvBn2d, - nni.ConvBnReLU2d, - nni.LinearReLU, - nni.BNReLU2d, - nni.BNReLU3d, - nni.ConvBn1d, - nni.ConvReLU1d, - nni.ConvBnReLU1d, - nni.ConvReLU2d, - nni.ConvReLU3d, - nniqat.ConvBn2d, - nniqat.ConvBnReLU2d) - for name, mod in module.named_children(): - # both swappable modules and observed custom modules are + # both fused modules and observed custom modules are # swapped as one unit - if type(mod) not in SWAPPABLE_MODULES and \ - not is_observed_custom_module(mod): - _convert(mod, mapping, inplace=True) - reassign[name] = swap_module(mod, mapping) + if not isinstance(mod, _FusedModule) and \ + type(mod) not in custom_module_class_mapping: + _convert(mod, mapping, True, # inplace + custom_module_class_mapping) + reassign[name] = swap_module(mod, mapping, custom_module_class_mapping) for key, value in reassign.items(): module._modules[key] = value return module -def swap_module(mod, mapping): +def swap_module(mod, mapping, custom_module_class_mapping): r"""Swaps the module if it has a quantized counterpart and it has an `observer` attached. @@ -475,11 +525,10 @@ def swap_module(mod, mapping): The corresponding quantized module of `mod` """ new_mod = mod - # Always replace dequantstub with dequantize - if hasattr(mod, 'qconfig') and mod.qconfig is not None or type(mod) == DeQuantStub: + if hasattr(mod, 'qconfig') and mod.qconfig is not None: swapped = False - if is_observed_custom_module(mod): - new_mod = get_quantized_custom_module_class(mod._FLOAT_MODULE).from_observed(mod) + if type(mod) in custom_module_class_mapping: + new_mod = custom_module_class_mapping[type(mod)].from_observed(mod) swapped = True elif type(mod) in mapping: new_mod = mapping[type(mod)].from_float(mod) diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index 0f68f2e0e9e9c..4ad27d7407967 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -1,164 +1,365 @@ +import torch +from torch.fx import GraphModule # type: ignore +from torch.fx.symbolic_trace import Tracer # type: ignore +from torch.fx.node import Target, Node, Argument # type: ignore from .fx import Fuser # noqa: F401 from .fx import Quantizer # noqa: F401 -from torch.fx import GraphModule # type: ignore from .fx.utils import graph_pretty_str # noqa: F401 +from .fx.utils import get_custom_module_class_keys # noqa: F401 +from torch.nn.intrinsic import _FusedModule +from typing import Dict, Any, List, Callable, Tuple, Optional -def _check_is_graph_module(model): +def _check_is_graph_module(model: torch.nn.Module) -> None: if not isinstance(model, GraphModule): raise ValueError( 'input model must be a GraphModule, ' + - 'please run torch.fx.symbolic_trace on your model before using ' + - 'quantize_fx. Got type:' + str(type(model))) + 'Got type:' + str(type(model)) + ' Please make ' + + 'sure to follow the tutorials.') -def fuse_fx(graph_module, inplace=False): - r""" Fuse modules in preparation for quantization +def _swap_ff_with_fxff(model: torch.nn.Module) -> None: + r""" Swap FloatFunctional with FXFloatFunctional + """ + modules_to_swap = [] + for name, module in model.named_children(): + if isinstance(module, torch.nn.quantized.FloatFunctional): + modules_to_swap.append(name) + else: + _swap_ff_with_fxff(module) + + for name in modules_to_swap: + del model._modules[name] + model._modules[name] = torch.nn.quantized.FXFloatFunctional() + +def _fuse_fx( + graph_module: GraphModule, + fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule: + r""" Internal helper function to fuse modules in preparation for quantization Args: graph_module: GraphModule object from symbolic tracing (torch.fx.symbolic_trace) """ _check_is_graph_module(graph_module) fuser = Fuser() - return fuser.fuse(graph_module, inplace) + return fuser.fuse(graph_module, fuse_custom_config_dict) + +class Scope(object): + """ Scope object that records the module path and the module type + of a module. Scope is used to track the information of the module + that contains a Node in a Graph of GraphModule. For example: + class Sub(torch.nn.Module): + def forward(self, x): + # This will be a call_method Node in GraphModule, + # scope for this would be (module_path="sub", module_type=Sub) + return x.transpose(1, 2) + + class M(torch.nn.Module): + def __init__(self): + self.sub = Sub() + + def forward(self, x): + # This will be a call_method Node as well, + # scope for this would be (module_path="", None) + x = x.transpose(1, 2) + x = self.sub(x) + return x -def _prepare_fx(graph_module, qconfig_dict, inplace, is_dynamic_quant): - _check_is_graph_module(graph_module) - graph_module = fuse_fx(graph_module, inplace) + """ + def __init__(self, module_path: str, module_type: Any): + super().__init__() + self.module_path = module_path + self.module_type = module_type + +class ScopeContextManager(object): + """ A context manager to track the Scope of Node during symbolic + tracing. + When entering a forward function of a Module, we'll update the scope information of + the current module, and when we exit, we'll restore the previous scope information. + """ + def __init__( + self, + scope: Scope, + current_module: torch.nn.Module, + current_module_path: str): + super().__init__() + self.prev_module_type = scope.module_type + self.prev_module_path = scope.module_path + self.scope = scope + self.scope.module_path = current_module_path + self.scope.module_type = type(current_module) + + def __enter__(self): + return + + def __exit__(self, *args): + self.scope.module_path = self.prev_module_path + self.scope.module_type = self.prev_module_type + return + + +class QuantizationTracer(Tracer): + def __init__( + self, + skipped_module_names: List[str], + skipped_module_classes: List[Callable]): + super().__init__() + self.skipped_module_names = skipped_module_names + self.skipped_module_classes = skipped_module_classes + # NB: initialized the module_type of top level module to None + # we are assuming people won't configure the model with the type of top level + # module here, since people can use "" for global config + # We can change this if there is a use case that configures + # qconfig using top level module type + self.scope = Scope("", None) + self.node_name_to_scope : Dict[str, Tuple[str, type]] = {} + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: + return (m.__module__.startswith("torch.nn") and + not isinstance(m, torch.nn.Sequential)) or \ + module_qualified_name in self.skipped_module_names or \ + type(m) in self.skipped_module_classes or \ + isinstance(m, _FusedModule) + + def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any: + module_qualified_name = self.path_of_module(m) + # Creating scope with information of current module + # scope will be restored automatically upon exit + with ScopeContextManager(self.scope, m, module_qualified_name): + return super().call_module(m, forward, args, kwargs) + + def create_node(self, kind : str, target : Target, + args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, + type_expr : Optional[Any] = None) -> Node: + node = super().create_node(kind, target, args, kwargs, name, type_expr) + if kind == "call_method": + self.node_name_to_scope[node.name] = (self.scope.module_path, self.scope.module_type) + return node + +def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any, + prepare_custom_config_dict: Dict[str, Any] = None, + is_standalone_module: bool = False) -> GraphModule: + r""" Internal helper function for prepare_fx + Args: + `model`, `qconfig_dict`, `prepare_custom_config_dict`: see docs for :func:`~torch.quantization.prepare_fx` + `is_standalone_module`: a boolean flag indicates whether we are + quantizing a standalone module or not, a standalone module + is a submodule of the parent module that is not inlined in the +forward graph of the parent module, + the way we quantize standalone module is described in: + :func:`~torch.quantization._prepare_standalone_module_fx` + """ + if prepare_custom_config_dict is None: + prepare_custom_config_dict = {} + + skipped_module_names = prepare_custom_config_dict.get("non_traceable_module_name", []) + skipped_module_classes = prepare_custom_config_dict.get("non_traceable_module_class", []) + + # swap FloatFunctional with FXFloatFunctional + _swap_ff_with_fxff(model) + + # symbolically trace the model + if not is_standalone_module: + # standalone module and custom module config are applied in top level module + standalone_module_name_configs = prepare_custom_config_dict.get("standalone_module_name", []) + skipped_module_names += [config[0] for config in standalone_module_name_configs] + + standalone_module_class_configs = prepare_custom_config_dict.get("standalone_module_class", []) + skipped_module_classes += [config[0] for config in standalone_module_class_configs] + float_custom_module_classes = get_custom_module_class_keys( + prepare_custom_config_dict, "float_to_observed_custom_module_class") + skipped_module_classes += float_custom_module_classes + tracer = QuantizationTracer( + skipped_module_names, skipped_module_classes) + graph_module = GraphModule(model, tracer.trace(model)) + graph_module = _fuse_fx(graph_module, prepare_custom_config_dict) quantizer = Quantizer() - prepare = quantizer.prepare_dynamic if is_dynamic_quant else quantizer.prepare - prepared = prepare(graph_module, qconfig_dict, inplace=True) + prepared = quantizer.prepare( + graph_module, + qconfig_dict, + tracer.node_name_to_scope, + prepare_custom_config_dict=prepare_custom_config_dict, + is_standalone_module=is_standalone_module) + + preserved_attributes = prepare_custom_config_dict.get("preserved_attributes", []) + for attr_name in preserved_attributes: + setattr(prepared, attr_name, getattr(model, attr_name)) return prepared -def prepare_fx(graph_module, qconfig_dict, inplace=False): - r""" Prepare a model for post training static quantization or - qantization aware training, not for public use. - - Args: - graph_module: model from symbolic_tracing (torch.fx.symbolic_trace), must be - an eval model - qconfig_dict: see :func:`~torch.quantization.quantize_fx` - - Return: - A GraphModule with observer or fake quant modules, ready for - calibration or quantization aware training +def _prepare_standalone_module_fx( + model: torch.nn.Module, + qconfig_dict: Any, + prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule: + r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the + parent module. + standalone_module means it a submodule that is not inlined in parent module, + and will be quantized separately as one unit. + + How the standalone module is observed is specified by `input_quantized_idxs` and + `output_quantized_idxs` in the prepare_custom_config for the standalone module + + Returns: + model(GraphModule): prepared standalone module + attributes: + _standalone_module_input_quantized_idxs(List[Int]): a list of + indexes for the graph input that is expected to be quantized, + same as input_quantized_idxs configuration provided + for the standalone module + _standalone_module_output_quantized_idxs(List[Int]): a list of + indexs for the graph output that is quantized + same as input_quantized_idxs configuration provided + for the standalone module """ - return _prepare_fx(graph_module, qconfig_dict, inplace, is_dynamic_quant=False) + return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True) -def prepare_static_fx(graph_module, qconfig_dict, inplace=False): - assert not graph_module.training, 'prepare_static_fx only works for models in ' + \ - 'eval mode' - return prepare_fx(graph_module, qconfig_dict, inplace) - -def prepare_qat_fx(graph_module, qconfig_dict, inplace=False): - r""" Prepare a model for quantization aware training +def fuse_fx(model: torch.nn.Module, + fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule: + r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode. + Fusion rules are defined in torch.quantization.fx.fusion_pattern.py Args: - graph_module: model from symbolic_tracing (torch.fx.symbolic_trace), must be - a train model - qconfig_dict: see :func:`~torch.quantization.quantize_fx` - - Return: - A GraphModule with observer or fake quant modules, ready for - calibration or quantization aware training - """ - assert graph_module.training, 'prepare_qat_fx only works for models in ' + \ - 'train mode' - return prepare_fx(graph_module, qconfig_dict, inplace) + `model`: a torch.nn.Module model + `fuse_custom_config_dict`: Dictionary for custom configurations for fuse_fx, e.g. + fuse_custom_config_dict = { + "additional_fuser_method_mapping": { + (Module1, Module2): fuse_module1_module2 + } + } -def prepare_dynamic_fx(graph_module, qconfig_dict, inplace=False): - r""" Prepare a model for post training dynamic quantization + Example: + ```python + from torch.quantization import fuse_fx + m = Model().eval() + m = fuse_fx(m) + ``` """ - return _prepare_fx(graph_module, qconfig_dict, inplace, True) + torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx") + assert not model.training, 'fuse_fx only works on models in eval mode' + graph_module = torch.fx.symbolic_trace(model) # type: ignore + return _fuse_fx(graph_module, fuse_custom_config_dict) -def _convert_fx(graph_module, inplace, debug, is_dynamic_quant): - _check_is_graph_module(graph_module) - quantizer = Quantizer() - return quantizer.convert(graph_module, inplace, debug, is_dynamic_quant) - -def convert_fx(graph_module, inplace=False, debug=False): - r""" Convert a calibrated or trained model to a quantized model - """ - return _convert_fx(graph_module, inplace, debug, is_dynamic_quant=False) +def prepare_fx( + model: torch.nn.Module, qconfig_dict: Any, + prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule: + r""" Prepare a model for post training static quantization -convert_static_fx = convert_fx -convert_qat_fx = convert_fx + Args: + `model`: torch.nn.Module model, must be in eval mode + `qconfig_dict`: qconfig_dict is a dictionary with the following configurations: + qconfig_dict = { + # optional, global config + "": qconfig?, + + # optional, used for module and function types + # could also be split into module_types and function_types if we prefer + "object_type": [ + (torch.nn.Conv2d, qconfig?), + (torch.nn.functional.add, qconfig?), + ..., + ], + + # optional, used for module names + "module_name": [ + ("foo.bar", qconfig?) + ..., + ], + + # optional, matched in order, first match takes precedence + "module_name_regex": [ + ("foo.*bar.*conv[0-9]+", qconfig?) + ..., + ], + # priority (in increasing order): global, object_type, module_name_regex, module_name + # qconfig == None means fusion and quantization should be skipped for anything + # matching the rule + } + `prepare_custom_config_dict`: customization configuration dictionary for + quantization tool: + prepare_custom_config_dict = { + # optional: specify the path for standalone modules + # These modules are symbolically traced and quantized as one unit + "standalone_module_name": [ + # module_name, qconfig_dict, prepare_custom_config_dict + ("submodule.standalone", + None, # qconfig_dict for the prepare function called in the submodule, + # None means use qconfig from parent qconfig_dict + {"input_quantized_idxs": [], "output_quantized_idxs": []}) # prepare_custom_config_dict + ], -def convert_dynamic_fx(graph_module, inplace=False, debug=False): - return _convert_fx(graph_module, inplace, debug, is_dynamic_quant=True) + "standalone_module_class": [ + # module_class, qconfig_dict, prepare_custom_config_dict + (StandaloneModule, + None, # qconfig_dict for the prepare function called in the submodule, + # None means use qconfig from parent qconfig_dict + {"input_quantized_idxs": [0], "output_quantized_idxs": [0]}) # prepare_custom_config_dict + ], -def _quantize_fx(model, qconfig_dict, run_fn=None, run_args=None, inplace=False, - debug=False, is_dynamic_quant=False): - assert not model.training, 'quantize_fx is only used for post training ' + \ - 'quantization(eval mode), for quantization aware training please use ' + \ - 'prepare_qat_fx and convert_qat_fx.' + # user will manually define the corresponding observed + # module class which has a from_float class method that converts + # float custom module to observed custom module + # (only needed for static quantization) + "float_to_observed_custom_module_class": { + "static": { + CustomModule: ObservedCustomModule + } + }, + + # the qualified names for the submodule that are not symbolically traceable + "non_traceable_module_name": [ + "non_traceable_module" + ], - if is_dynamic_quant: - model = prepare_dynamic_fx(model, qconfig_dict, inplace) - # inplace is True since the inplace option is already applied in previous step - model = convert_dynamic_fx(model, inplace=True, debug=debug) - else: - assert run_fn, "Must provide calibration function for post training static quantization" - assert run_args, "Must provide calibration dataset for post training static quantization" - model = prepare_fx(model, qconfig_dict, inplace) - run_fn(model, *run_args) - # inplace is True since the inplace option is already applied in previous step - model = convert_fx(model, inplace=True, debug=debug) + # the module classes that are not symbolically traceable + # we'll also put dynamic/weight_only custom module here + "non_traceable_module_class": [ + NonTraceableModule + ], - return model + # Additional fuser_method mapping + "additional_fuser_method_mapping": { + (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn + }, + + # Additioanl module mapping for qat + "additional_qat_module_mapping": { + torch.nn.intrinsic.ConvBn2d: torch.nn.qat.ConvBn2d + }, + + # Additional fusion patterns + "additional_fusion_pattern": { + (torch.nn.BatchNorm2d, torch.nn.Conv2d): ConvReluFusionhandler + }, + + # Additional quantization patterns + "additional_quant_pattern": { + torch.nn.Conv2d: ConvReluQuantizeHandler, + (torch.nn.ReLU, torch.nn.Conv2d): ConvReluQuantizeHandler, + } + # By default, inputs and outputs of the graph are assumed to be in + # fp32. Providing `input_quantized_idxs` will set the inputs with the + # corresponding indices to be quantized. Providing + # `output_quantized_idxs` will set the outputs with the corresponding + # indices to be quantized. + "input_quantized_idxs": [0], + "output_quantized_idxs": [0], -def quantize_static_fx(model, qconfig_dict, run_fn, run_args, inplace=False, debug=False): - r"""Quantize the input float symbolically traced GraphModule model with - post training static quantization + # Attributes that are not used in forward function will + # be removed when constructing GraphModule, this is a list of attributes + # to preserve as an attribute of the GraphModule even when they are + # not used in the code + "preserved_attributes": ["preserved_attr"], + } - First it will prepare the model for calibration, then it calls - `run_fn` which will run the calibration step, after that we will - convert the model to a quantized model. - - Args: - `model`: input float TorchScript model - `qconfig_dict`: qconfig_dict is a dictionary with the following configurations: - qconfig_dict = { - # optional, global config - "": qconfig?, - - # optional, used for module and function types - # could also be split into module_types and function_types if we prefer - "object_type": [ - (torch.nn.Conv2d, qconfig?), - (torch.nn.functional.add, qconfig?), - ..., - ], - - # optional, used for module names - "module_name": [ - ("foo.bar", qconfig?) - ..., - ], - - # optional, matched in order, first match takes precedence - "module_name_regex": [ - ("foo.*bar.*conv[0-9]+", qconfig?) - ..., - ] - # priority (in increasing order): global, object_type, module_name_regex, module_name - # qconfig == None means fusion and quantization should be skipped for anything - # matching the rule - } - `run_fn`: a calibration function for calibrating the prepared model - `run_args`: positional arguments for `run_fn` - `inplace`: carry out model transformations in-place, the original module is - mutated - `debug`: flag for producing a debug friendly model (preserve weight attribute) Return: - Quantized TorchSciprt model. + A GraphModule with observer (configured by qconfig_dict), ready for calibration Example: ```python import torch from torch.quantization import get_default_qconfig - from torch.quantization import quantize_fx + from torch.quantization import prepare_fx - graph_module = torch.fx.symbolic_trace(float_model.eval()) + float_model.eval() + graph_module = torch.fx.symbolic_trace(float_model) qconfig = get_default_qconfig('fbgemm') def calibrate(model, data_loader): model.eval() @@ -166,53 +367,139 @@ def calibrate(model, data_loader): for image, target in data_loader: model(image) - quantized_model = quantize_fx( - graph_module, - {'': qconfig}, - calibrate, - [data_loader_test]) + qconfig_dict = {"": qconfig} + prepared_model = prepare_fx(graph_module, qconfig_dict) + # Run calibration + calibrate(prepared_model, sample_inference_data) ``` """ - return _quantize_fx( - model, qconfig_dict, run_fn, run_args, inplace, debug, is_dynamic_quant=False) - -def quantize_dynamic_fx(model, qconfig_dict, inplace=False, debug=False): - r"""Quantize the input float symbolically traced GraphModule model with - post training dynamic quantization. - Currently only qint8 quantization of torch.nn.Linear is supported. + torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx") + assert not model.training, 'prepare_fx only works for models in ' + \ + 'eval mode' + return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict) +def prepare_qat_fx( + model: torch.nn.Module, qconfig_dict: Any, + prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule: + r""" Prepare a model for quantization aware training Args: - `model`: input float TorchScript model - `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and - qconfig for that module as value, please see detailed - descriptions in :func:`~torch.quantization.quantize_fx` - `inplace`: carry out model transformations in-place, the original module is - mutated - `debug`: flag for producing a debug friendly model (preserve weight attribute) + `model`: torch.nn.Module model, must be in train mode + `qconfig_dict`: see :func:`~torch.quantization.prepare_fx` + `prepare_custom_config_dict`: see :func:`~torch.quantization.prepare_fx` Return: - Quantized TorchSciprt model. + A GraphModule with fake quant modules (configured by qconfig_dict), ready for + quantization aware training Example: ```python import torch - from torch.quantization import per_channel_dynamic_qconfig - from torch.quantization import quantize_dynmiac_fx + from torch.quantization import get_default_qat_qconfig + from torch.quantization import prepare_fx + + qconfig = get_default_qat_qconfig('fbgemm') + def train_loop(model, train_data): + model.train() + for image, target in data_loader: + ... + + float_model.train() + qconfig_dict = {"": qconfig} + prepared_model = prepare_fx(float_model, qconfig_dict) + # Run calibration + train_loop(prepared_model, train_loop) + ``` + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx") + assert model.training, 'prepare_qat_fx only works for models in ' + \ + 'train mode' + return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict) - graph_module = torch.fx.symbolic_trace(float_model.eval()) - qconfig = get_default_qconfig('fbgemm') - def calibrate(model, data_loader): - model.eval() - with torch.no_grad(): - for image, target in data_loader: - model(image) +def _convert_fx( + graph_module: GraphModule, debug: bool, + convert_custom_config_dict: Dict[str, Any] = None, + is_standalone_module: bool = False) -> GraphModule: + """ `is_standalone_module`: see docs in :func:`~torch.quantization.prepare_standalone_module_fx` + """ + if convert_custom_config_dict is None: + convert_custom_config_dict = {} - quantized_model = quantize_dynamic_fx( - graph_module, - {'': qconfig}, - calibrate, - [data_loader_test]) + _check_is_graph_module(graph_module) + + quantizer = Quantizer() + quantized = quantizer.convert(graph_module, debug, convert_custom_config_dict, is_standalone_module) + + preserved_attributes = convert_custom_config_dict.get("preserved_attributes", []) + for attr_name in preserved_attributes: + setattr(quantized, attr_name, getattr(graph_module, attr_name)) + return quantized + +def convert_fx( + graph_module: GraphModule, debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> GraphModule: + r""" Convert a calibrated or trained model to a quantized model + Args: + `graph_module`: A prepared and calibrated/trained model (GraphModule) + `debug`: flag for producing a debug friendly model (preserve weight attribute) + `convert_custom_config_dict`: dictionary for custom configurations for convert function: + convert_custom_config_dict = { + + # addtional object (module/operator) mappings that will overwrite the default + # module mappingn + "additional_object_mapping": { + "static": { + FloatModule: QuantizedModule, + float_op: quantized_op + }, + "dynamic": { + FloatModule: DynamicallyQuantizedModule, + float_op: dynamically_quantized_op + }, + }, + + # user will manually define the corresponding quantized + # module class which has a from_observed class method that converts + # observed custom module to quantized custom module + "observed_to_quantized_custom_module_class": { + "static": { + ObservedCustomModule: QuantizedCustomModule + }, + "dynamic": { + ObservedCustomModule: QuantizedCustomModule + }, + "weight_only": { + ObservedCustomModule: QuantizedCustomModule + } + }, + + # Attributes that are not used in forward function will + # be removed when constructing GraphModule, this is a list of attributes + # to preserve as an attribute of the GraphModule even when they are + # not used in the code + "preserved_attributes": ["preserved_attr"], + } + + Return: + A quantized model (GraphModule) + + Example: + ```python + # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training + quantized_model = convert_fx(prepared_model) ``` """ - return _quantize_fx( - model, qconfig_dict, inplace=inplace, debug=debug, is_dynamic_quant=True) + torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx") + return _convert_fx(graph_module, debug, convert_custom_config_dict) + +def _convert_standalone_module_fx( + graph_module: GraphModule, debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> GraphModule: + r""" [Internal use only] Convert a model produced by :func:`~torch.quantization.prepare_standalone_module_fx` + and convert it to a quantized model + + Returns a quantized standalone module, whether input/output is quantized is + specified by prepare_custom_config_dict, with + input_quantized_idxs, output_quantized_idxs, please + see docs for prepare_fx for details + """ + return _convert_fx(graph_module, debug, convert_custom_config_dict, is_standalone_module=True) diff --git a/torch/quantization/quantize_jit.py b/torch/quantization/quantize_jit.py index ef6792d521f6b..39df93730d06d 100644 --- a/torch/quantization/quantize_jit.py +++ b/torch/quantization/quantize_jit.py @@ -35,6 +35,7 @@ def fuse_conv_bn_jit(model, inplace=False): Args: model: TorchScript model from scripting or tracing """ + torch._C._log_api_usage_once("quantization_api.quantize_jit.fuse_conv_bn_jit") model_c = model._c model_c = torch._C._jit_pass_fold_convbn(model_c) if inplace: @@ -62,9 +63,11 @@ def _prepare_jit(model, qconfig_dict, inplace=False, quant_type=QuantType.STATIC return model def prepare_jit(model, qconfig_dict, inplace=False): + torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_jit") return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.STATIC) def prepare_dynamic_jit(model, qconfig_dict, inplace=False): + torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_dynamic_jit") return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.DYNAMIC) def _convert_jit(model, inplace=False, debug=False, quant_type=QuantType.STATIC, @@ -84,12 +87,15 @@ def _convert_jit(model, inplace=False, debug=False, quant_type=QuantType.STATIC, model._reconstruct(model_c) else: model = wrap_cpp_module(model_c) + torch._C._jit_pass_constant_propagation(model.graph) return model def convert_jit(model, inplace=False, debug=False, preserved_attrs=None): + torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_jit") return _convert_jit(model, inplace, debug, quant_type=QuantType.STATIC, preserved_attrs=preserved_attrs) def convert_dynamic_jit(model, inplace=False, debug=False, preserved_attrs=None): + torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_dynamic_jit") return _convert_jit(model, inplace, debug, quant_type=QuantType.DYNAMIC, preserved_attrs=preserved_attrs) def _quantize_jit(model, qconfig_dict, run_fn=None, run_args=None, inplace=False, debug=False, quant_type=QuantType.STATIC): @@ -105,6 +111,7 @@ def _quantize_jit(model, qconfig_dict, run_fn=None, run_args=None, inplace=False run_fn(model, *run_args) model = convert_jit(model, True, debug) + torch._C._jit_pass_constant_propagation(model.graph) return model def quantize_jit(model, qconfig_dict, run_fn, run_args, inplace=False, debug=False): @@ -157,6 +164,7 @@ def calibrate(model, data_loader): [data_loader_test]) ``` """ + torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_jit") return _quantize_jit(model, qconfig_dict, run_fn, run_args, inplace, debug, quant_type=QuantType.STATIC) def quantize_dynamic_jit(model, qconfig_dict, inplace=False, debug=False): @@ -197,4 +205,5 @@ def calibrate(model, data_loader): [data_loader_test]) ``` """ + torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_dynamic_jit") return _quantize_jit(model, qconfig_dict, inplace=inplace, debug=debug, quant_type=QuantType.DYNAMIC) diff --git a/torch/quantization/stubs.py b/torch/quantization/stubs.py index 7018ef1b836da..1f4c462e56e2c 100644 --- a/torch/quantization/stubs.py +++ b/torch/quantization/stubs.py @@ -40,6 +40,10 @@ class QuantWrapper(nn.Module): will be swapped to `nnq.Quantize` which does actual quantization. Similarly for `DeQuantStub`. """ + quant: QuantStub + dequant: DeQuantStub + module: nn.Module + def __init__(self, module): super(QuantWrapper, self).__init__() qconfig = module.qconfig if hasattr(module, 'qconfig') else None diff --git a/torch/quantization/utils.py b/torch/quantization/utils.py new file mode 100644 index 0000000000000..f5732b1fc24f1 --- /dev/null +++ b/torch/quantization/utils.py @@ -0,0 +1,86 @@ +""" +Utils shared by different modes of quantization (eager/graph) +""" +import torch +from .quant_type import QuantType, quant_type_to_str + +def get_combined_dict(default_dict, additional_dict): + d = default_dict.copy() + d.update(additional_dict) + return d + +def is_per_tensor(qscheme): + return qscheme == torch.per_tensor_affine or \ + qscheme == torch.per_tensor_symmetric + +def is_per_channel(qscheme): + return qscheme in [torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + torch.per_channel_symmetric] + +def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig): + """ Get the observed/quantized custom module class that we need + to swap `custom_module` to + Input: + custom_module: input, can be an instance of either a float or observed custom module + custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping + qconfig: qconfig configured for the custom module + + Output: + corresponding observed/quantized custom module class for input custom module instance + """ + quant_type = get_quant_type(qconfig) + quant_type_str = quant_type_to_str(quant_type) + class_mapping = custom_module_class_mapping.get(quant_type_str, {}) + assert type(custom_module) in class_mapping, "did not find corresponding observed " \ + "module class for {} in mapping: {}".format(type(custom_module), class_mapping) + return class_mapping[type(custom_module)] + +def activation_is_statically_quantized(qconfig): + """ Given a qconfig, decide if the activation needs to be + statically quantized or not + """ + assert qconfig is not None + activation = qconfig.activation() + return activation.dtype in [torch.quint8, torch.qint8] + +def weight_dtype(qconfig): + assert qconfig is not None + weight = qconfig.weight() + return weight.dtype + +def weight_is_statically_quantized(qconfig): + """ Given a qconfig, decide if the weight needs to be + quantized or not + """ + return weight_dtype(qconfig) in [torch.quint8, torch.qint8] + +def get_qconfig_dtypes(qconfig): + r""" returns the qconfig tuple for qconfig: + (activation_dtype, weight_dtype, activation_compute_dtype) + """ + assert qconfig is not None + activation = qconfig.activation() + weight = qconfig.weight() + compute_dtype = activation.compute_dtype if hasattr(activation, 'compute_dtype') else None + return (activation.dtype, weight.dtype, compute_dtype) + +def get_quant_type(qconfig): + assert qconfig is not None + activation = qconfig.activation() + weight = qconfig.weight() + static_dtypes = [torch.quint8, torch.qint8] + if weight.dtype in static_dtypes: + if activation.dtype in static_dtypes: + return QuantType.STATIC + elif hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes: + return QuantType.DYNAMIC + else: + return QuantType.WEIGHT_ONLY + + if weight.dtype == torch.float16: + if activation.dtype == torch.float: + return QuantType.DYNAMIC + + raise Exception("Unrecognized dtype combination in get_quant_type: activation({})," + "weight({})".format(activation.dtype, weight.dtype)) diff --git a/torch/quasirandom.py b/torch/quasirandom.py index c4b75f4cd5cec..b738dcf7b60c6 100644 --- a/torch/quasirandom.py +++ b/torch/quasirandom.py @@ -1,4 +1,5 @@ import torch +from typing import Optional class SobolEngine(object): @@ -57,11 +58,10 @@ def __init__(self, dimension, scramble=False, seed=None): torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension) if self.scramble: + g: Optional[torch.Generator] = None if self.seed is not None: g = torch.Generator() g.manual_seed(self.seed) - else: - g = None shift_ints = torch.randint(2, (self.dimension, self.MAXBIT), device=cpu, generator=g) self.shift = torch.mv(shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu))) diff --git a/torch/random.py b/torch/random.py index 30af86065907e..31e2643845d27 100644 --- a/torch/random.py +++ b/torch/random.py @@ -67,7 +67,7 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="device Forks the RNG, so that when you return, the RNG is reset to the state that it was previously in. - Arguments: + Args: devices (iterable of CUDA IDs): CUDA devices for which to fork the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates on all devices, but will emit a warning if your machine has a lot diff --git a/torch/serialization.py b/torch/serialization.py index 1c05767922a85..3b6f5828d8583 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -13,7 +13,7 @@ from ._six import string_classes as _string_classes from torch._utils_internal import get_source_lines_and_file from torch.types import Storage -from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union +from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO import copyreg import pickle import pathlib @@ -192,7 +192,7 @@ def storage_to_tensor_type(storage): def _is_path(name_or_buffer): return isinstance(name_or_buffer, str) or \ - (sys.version_info[0] == 3 and isinstance(name_or_buffer, pathlib.Path)) + isinstance(name_or_buffer, pathlib.Path) class _opener(object): @@ -330,7 +330,7 @@ def _check_dill_version(pickle_module) -> None: pickle_module.__version__ )) -def save(obj, f: Union[str, os.PathLike, BinaryIO], +def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]], pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None: """Saves an object to a disk file. @@ -481,16 +481,14 @@ def persistent_id(obj): for key in sorted(serialized_storages.keys()): name = f'data/{key}' storage = serialized_storages[key] - if storage.device.type == 'cpu': - # If it's on the CPU we can directly copy it into the zip file - num_bytes = storage.size() * storage.element_size() - zip_file.write_record(name, storage.data_ptr(), num_bytes) - else: - # Copy to a buffer, then serialize that - buf = io.BytesIO() - storage._write_file(buf, _should_read_directly(buf)) - buf_value = buf.getvalue() - zip_file.write_record(name, buf_value, len(buf_value)) + # given that we copy things around anyway, we might use storage.cpu() + # this means to that to get tensors serialized, you need to implement + # .cpu() on the underlying Storage + if storage.device.type != 'cpu': + storage = storage.cpu() + # Now that it is on the CPU we can directly copy it into the zip file + num_bytes = storage.size() * storage.element_size() + zip_file.write_record(name, storage.data_ptr(), num_bytes) def load(f, map_location=None, pickle_module=pickle, **pickle_load_args): @@ -526,7 +524,7 @@ def load(f, map_location=None, pickle_module=pickle, **pickle_load_args): deserialization methods using :func:`torch.serialization.register_package`. Args: - f: a file-like object (has to implement :meth:`read`, :meth`readline`, :meth`tell`, and :meth`seek`), + f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`), or a string or os.PathLike object containing a file name map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage locations diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index 3795051062f45..9ed1b0dc02acc 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -34,8 +34,8 @@ def addmm(mat: Tensor, mat1: Tensor, mat2: Tensor, Args: mat (Tensor): a dense matrix to be added - mat1 (SparseTensor): a sparse matrix to be multiplied - mat2 (Tensor): a dense matrix be multiplied + mat1 (Tensor): a sparse matrix to be multiplied + mat2 (Tensor): a dense matrix to be multiplied beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) """ @@ -45,15 +45,20 @@ def addmm(mat: Tensor, mat1: Tensor, mat2: Tensor, def mm(mat1: Tensor, mat2: Tensor) -> Tensor: r""" Performs a matrix multiplication of the sparse matrix :attr:`mat1` - and dense matrix :attr:`mat2`. Similar to :func:`torch.mm`, If :attr:`mat1` is a + and the (sparse or strided) matrix :attr:`mat2`. Similar to :func:`torch.mm`, If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a :math:`(m \times p)` tensor, out will be a - :math:`(n \times p)` dense tensor. :attr:`mat1` need to have `sparse_dim = 2`. + :math:`(n \times p)` tensor. :attr:`mat1` need to have `sparse_dim = 2`. This function also supports backward for both matrices. Note that the gradients of :attr:`mat1` is a coalesced sparse tensor. Args: mat1 (SparseTensor): the first sparse matrix to be multiplied - mat2 (Tensor): the second dense matrix to be multiplied + mat2 (Tensor): the second matrix to be multiplied, which could be sparse or dense + + Shape: + The format of the output tensor of this function follows: + - sparse x sparse -> sparse + - sparse x dense -> dense Example:: @@ -81,16 +86,18 @@ def mm(mat1: Tensor, mat2: Tensor) -> Tensor: values=tensor([ 0.1394, -0.6415, -2.1639, 0.1394, -0.6415, -2.1639]), size=(2, 3), nnz=6, layout=torch.sparse_coo) """ + if mat1.is_sparse and mat2.is_sparse: + return torch._sparse_sparse_matmul(mat1, mat2) return torch._sparse_mm(mat1, mat2) def sum(input: Tensor, dim: DimOrDims = None, dtype: Optional[DType] = None) -> Tensor: r""" - Returns the sum of each row of SparseTensor :attr:`input` in the given + Returns the sum of each row of the sparse tensor :attr:`input` in the given dimensions :attr:`dim`. If :attr:`dim` is a list of dimensions, reduce over all of them. When sum over all ``sparse_dim``, this method - returns a Tensor instead of SparseTensor. + returns a dense tensor instead of a sparse tensor. All summed :attr:`dim` are squeezed (see :func:`torch.squeeze`), resulting an output tensor having :attr:`dim` fewer dimensions than :attr:`input`. @@ -99,7 +106,7 @@ def sum(input: Tensor, dim: DimOrDims = None, will propagate back. Note that the gradients of :attr:`input` is coalesced. Args: - input (Tensor): the input SparseTensor + input (Tensor): the input sparse tensor dim (int or tuple of ints): a dimension or a list of dimensions to reduce. Default: reduce over all dims. dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. @@ -127,7 +134,7 @@ def sum(input: Tensor, dim: DimOrDims = None, [-1.9682, -0.5340, 0.7483]]]), size=(5, 5, 2, 3), nnz=3, layout=torch.sparse_coo) - # when sum over only part of sparse_dims, return a SparseTensor + # when sum over only part of sparse_dims, return a sparse tensor >>> torch.sparse.sum(S, [1, 3]) tensor(indices=tensor([[0, 2, 3]]), values=tensor([[-1.4512, 0.4073], @@ -135,7 +142,7 @@ def sum(input: Tensor, dim: DimOrDims = None, [-0.3183, -1.7539]]), size=(5, 2), nnz=3, layout=torch.sparse_coo) - # when sum over all sparse dim, return a dense Tensor + # when sum over all sparse dim, return a dense tensor # with summed dims squeezed >>> torch.sparse.sum(S, [0, 1, 3]) tensor([-2.6596, -1.1450]) @@ -161,13 +168,13 @@ def softmax(input: Tensor, dim: int, dtype: Optional[DType] = None) -> Tensor: where :math:`i, j` run over sparse tensor indices and unspecified entries are ignores. This is equivalent to defining unspecified - entries as negative infinity so that :max:`exp(x_k) = 0` when the + entries as negative infinity so that :math:`exp(x_k) = 0` when the entry with index :math:`k` has not specified. It is applied to all slices along `dim`, and will re-scale them so that the elements lie in the range `[0, 1]` and sum to 1. - Arguments: + Args: input (Tensor): input dim (int): A dimension along which softmax will be computed. dtype (:class:`torch.dtype`, optional): the desired data type @@ -184,7 +191,7 @@ def log_softmax(input: Tensor, dim: int, dtype: Optional[DType] = None) -> Tenso See :class:`~torch.sparse.softmax` for more details. - Arguments: + Args: input (Tensor): input dim (int): A dimension along which softmax will be computed. dtype (:class:`torch.dtype`, optional): the desired data type diff --git a/torch/storage.py b/torch/storage.py index c06c42f7a2868..1e06f3fee95b6 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -2,11 +2,31 @@ import torch from ._utils import _type, _cuda +from typing import Any, TypeVar, Type - +T = TypeVar('T', bound='_StorageBase') class _StorageBase(object): - is_cuda = False - is_sparse = False + _cdata: Any + is_cuda: bool = False + is_sparse: bool = False + + def __init__(self, *args, **kwargs): ... # noqa: E704 + def __len__(self) -> int: ... # noqa: E704 + def __getitem__(self, idx): ... # noqa: E704 + def copy_(self, source: T) -> T: ... # noqa: E704 + def size(self) -> int: ... # noqa: E704 + def type(self, dtype: str = None, non_blocking: bool = False) -> T: ... # noqa: E704 + def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # noqa: E704 + def element_size(self) -> int: ... # noqa: E704 + def get_device(self) -> int: ... # noqa: E704 + + # Defined in torch/csrc/generic/StorageSharing.cpp + def _share_filename_(self): ... # noqa: E704 + def _share_fd_(self): ... # noqa: E704 + @classmethod + def _new_using_filename(cls: Type[T], size: int) -> T: ... # noqa: E704 + @classmethod + def _new_using_fd(cls: Type[T], size: int) -> T: ... # noqa: E704 def __str__(self): content = ' ' + '\n '.join(str(self[i]) for i in range(len(self))) @@ -104,7 +124,7 @@ def pin_memory(self): if self.is_cuda: raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned") import torch.cuda - allocator = torch.cuda._host_allocator() + allocator = torch.cuda._host_allocator() # type: ignore[attr-defined] return type(self)(self.size(), allocator=allocator).copy_(self) def share_memory_(self): @@ -141,5 +161,5 @@ def _load_from_bytes(b): return torch.load(io.BytesIO(b)) -_StorageBase.type = _type -_StorageBase.cuda = _cuda +_StorageBase.type = _type # type: ignore[assignment] +_StorageBase.cuda = _cuda # type: ignore[assignment] diff --git a/torch/tensor.py b/torch/tensor.py index 3eadb4667e879..eedffc7b6a97e 100644 --- a/torch/tensor.py +++ b/torch/tensor.py @@ -1,16 +1,19 @@ -import torch -import torch._C as _C -from torch._namedtensor_internals import update_names, check_serializing_named_tensor, resolve_ellipsis -from torch._namedtensor_internals import unzip_namedshape, single_ellipsis_index, is_ellipsis from collections import OrderedDict -import torch.utils.hooks as hooks +import functools +from numbers import Number +from typing import Any, Dict, Optional, Tuple, Union import warnings import weakref -from torch._C import _add_docstr -from typing import Any, Dict, Tuple, Union -from numbers import Number -import functools -from typing import Optional + +import torch +import torch._C as _C +from torch._namedtensor_internals import ( + update_names, check_serializing_named_tensor, resolve_ellipsis, + unzip_namedshape, single_ellipsis_index, is_ellipsis) +from torch.overrides import ( + has_torch_function, has_torch_function_unary, has_torch_function_variadic, + handle_torch_function) +import torch.utils.hooks as hooks def _wrap_type_error_to_not_implemented(f): @@ -20,8 +23,7 @@ def _wrap_type_error_to_not_implemented(f): @functools.wraps(f, assigned=assigned) def wrapped(*args, **kwargs): - from torch.overrides import has_torch_function, handle_torch_function - if not all(type(t) is Tensor for t in args) and has_torch_function(args): + if has_torch_function(args): return handle_torch_function(wrapped, args, *args, **kwargs) try: return f(*args, **kwargs) @@ -39,10 +41,8 @@ def wrapped(*args, **kwargs): # otherwise, it will not show up in autocomplete. class Tensor(torch._C._TensorBase): def __deepcopy__(self, memo): - from torch.overrides import has_torch_function, handle_torch_function - relevant_args = (self,) - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.__deepcopy__, relevant_args, self, memo) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo) if not self.is_leaf: raise RuntimeError("Only Tensors created explicitly by the user " "(graph leaves) support the deepcopy protocol at the moment") @@ -81,10 +81,8 @@ def __deepcopy__(self, memo): return new_tensor def __reduce_ex__(self, proto): - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.__reduce_ex__, relevant_args, self, proto) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__reduce_ex__, (self,), self, proto) check_serializing_named_tensor(self) # See Note [Don't serialize hooks] torch.utils.hooks.warn_if_has_hooks(self) @@ -150,10 +148,8 @@ def __reduce_ex__(self, proto): return (torch._utils._rebuild_tensor_v2, args) def __setstate__(self, state): - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.__setstate__, relevant_args, self, state) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__setstate__, (self,), self, state) # Warning: this method is NOT called when you torch.load() a tensor; # that is managed by _rebuild_tensor_v2 if not self.is_leaf: @@ -171,14 +167,12 @@ def __setstate__(self, state): self.requires_grad, _, self._backward_hooks = state def __repr__(self): - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.__repr__, relevant_args, self) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__repr__, (self,), self) # All strings are unicode in Python 3. return torch._tensor_str._str(self) - def backward(self, gradient=None, retain_graph=None, create_graph=False): + def backward(self, gradient=None, retain_graph=None, create_graph=False, inputs=None): r"""Computes the gradient of current tensor w.r.t. graph leaves. The graph is differentiated using the chain rule. If the tensor is @@ -192,7 +186,13 @@ def backward(self, gradient=None, retain_graph=None, create_graph=False): See :ref:`Default gradient layouts` for details on the memory layout of accumulated gradients. - Arguments: + .. note:: + + If you run any forward ops, create ``gradient``, and/or call ``backward`` + in a user-specified CUDA stream context, see + :ref:`Stream semantics of backward passes`. + + Args: gradient (Tensor or None): Gradient w.r.t. the tensor. If it is a tensor, it will be automatically converted to a Tensor that does not require grad unless ``create_graph`` is True. @@ -207,18 +207,22 @@ def backward(self, gradient=None, retain_graph=None, create_graph=False): create_graph (bool, optional): If ``True``, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to ``False``. + inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be + accumulated into ``.grad``. All other Tensors will be ignored. If not + provided, the gradient is accumulated into all the leaf Tensors that were + used to compute the attr::tensors. All the provided inputs must be leaf + Tensors. """ - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): + if has_torch_function_unary(self): return handle_torch_function( Tensor.backward, - relevant_args, + (self,), self, gradient=gradient, retain_graph=retain_graph, - create_graph=create_graph) - torch.autograd.backward(self, gradient, retain_graph, create_graph) + create_graph=create_graph, + inputs=inputs) + torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs) def register_hook(self, hook): r"""Registers a backward hook. @@ -249,10 +253,8 @@ def register_hook(self, hook): >>> h.remove() # removes the hook """ - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.register_hook, relevant_args, self, hook) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.register_hook, (self,), self, hook) if not self.requires_grad: raise RuntimeError("cannot register a hook on a tensor that " "doesn't require gradient") @@ -291,7 +293,7 @@ def trim(str): loss.backward() """)) - detach = _add_docstr(_C._TensorBase.detach, r""" + detach = _C._add_docstr(_C._TensorBase.detach, r""" Returns a new Tensor, detached from the current graph. The result will never require gradient. @@ -311,17 +313,15 @@ def trim(str): trigger an error. """) - detach_ = _add_docstr(_C._TensorBase.detach_, r""" + detach_ = _C._add_docstr(_C._TensorBase.detach_, r""" Detaches the Tensor from the graph that created it, making it a leaf. Views cannot be detached in-place. """) def retain_grad(self): r"""Enables .grad attribute for non-leaf Tensors.""" - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.retain_grad, relevant_args, self) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.retain_grad, (self,), self) if not self.requires_grad: raise RuntimeError("can't retain_grad on Tensor that has requires_grad=False") if self.is_leaf: # no-op for leaves @@ -350,10 +350,8 @@ def is_shared(self): This is always ``True`` for CUDA tensors. """ - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.is_shared, relevant_args, self) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.is_shared, (self,), self) return self.storage().is_shared() def share_memory_(self): @@ -362,19 +360,15 @@ def share_memory_(self): This is a no-op if the underlying storage is already in shared memory and for CUDA tensors. Tensors in shared memory cannot be resized. """ - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.share_memory_, relevant_args, self) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.share_memory_, (self,), self) self.storage().share_memory_() return self def __reversed__(self): r"""Reverses the tensor along dimension 0.""" - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.__reversed__, relevant_args, self) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__reversed__, (self,), self) if self.dim() == 0: return self else: @@ -382,19 +376,38 @@ def __reversed__(self): def norm(self, p="fro", dim=None, keepdim=False, dtype=None): r"""See :func:`torch.norm`""" - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.norm, relevant_args, self, p=p, dim=dim, keepdim=keepdim, dtype=dtype) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.norm, (self,), self, p=p, dim=dim, keepdim=keepdim, dtype=dtype) return torch.norm(self, p, dim, keepdim, dtype=dtype) def lu(self, pivot=True, get_infos=False): r"""See :func:`torch.lu`""" # If get_infos is True, then we don't need to check for errors and vice versa - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.lu, relevant_args, self, pivot=pivot, get_infos=get_infos) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos) + + if not torch._jit_internal.is_scripting(): + if self.requires_grad: + if not (self.size(-2) == self.size(-1) and self.dtype.is_floating_point): + raise ValueError( + 'lu.backward works only with batches of squared full-rank matrices' + ' of floating types.' + ) + + from torch._autograd_functions import _LU + LU, pivots, infos = _LU.apply(self, pivot, get_infos) + if get_infos: + return LU, pivots, infos + else: + return LU, pivots + else: + if self.requires_grad: + raise RuntimeError( + 'Script and require gradients is not supported at the moment.' + 'If you just want to do the forward, use .detach()' + 'on the input before calling the function.' + ) + LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos)) if get_infos: return LU, pivots, infos @@ -411,11 +424,9 @@ def stft(self, n_fft: int, hop_length: Optional[int] = None, This function changed signature at version 0.4.1. Calling with the previous signature may cause error or return incorrect result. """ - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): + if has_torch_function_unary(self): return handle_torch_function( - Tensor.stft, relevant_args, self, n_fft, hop_length=hop_length, + Tensor.stft, (self,), self, n_fft, hop_length=hop_length, win_length=win_length, window=window, center=center, pad_mode=pad_mode, normalized=normalized, onesided=onesided, return_complex=return_complex ) @@ -428,11 +439,9 @@ def istft(self, n_fft: int, hop_length: Optional[int] = None, onesided: Optional[bool] = None, length: Optional[int] = None, return_complex: bool = False): r"""See :func:`torch.istft`""" - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): + if has_torch_function_unary(self): return handle_torch_function( - Tensor.istft, relevant_args, self, n_fft, hop_length=hop_length, win_length=win_length, + Tensor.istft, (self,), self, n_fft, hop_length=hop_length, win_length=win_length, window=window, center=center, normalized=normalized, onesided=onesided, length=length, return_complex=return_complex ) @@ -440,19 +449,15 @@ def istft(self, n_fft: int, hop_length: Optional[int] = None, normalized, onesided, length, return_complex=return_complex) def resize(self, *sizes): - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.resize, relevant_args, self, *sizes) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.resize, (self,), self, *sizes) warnings.warn("non-inplace resize is deprecated") from torch.autograd._functions import Resize return Resize.apply(self, sizes) def resize_as(self, tensor): - relevant_args = (self, tensor) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and type(tensor) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.resize_as, relevant_args, self, tensor) + if has_torch_function_variadic(self, tensor): + return handle_torch_function(Tensor.resize_as, (self, tensor), self, tensor) warnings.warn("non-inplace resize_as is deprecated") from torch.autograd._functions import Resize return Resize.apply(self, tensor.size()) @@ -460,10 +465,8 @@ def resize_as(self, tensor): def split(self, split_size, dim=0): r"""See :func:`torch.split` """ - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.split, relevant_args, self, split_size, dim=dim) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.split, (self,), self, split_size, dim=dim) if isinstance(split_size, int): return super(Tensor, self).split(split_size, dim) elif isinstance(split_size, Tensor): @@ -480,11 +483,9 @@ def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=Non See :func:`torch.unique` """ - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): + if has_torch_function_unary(self): return handle_torch_function( - Tensor.unique, relevant_args, self, sorted=sorted, return_inverse=return_inverse, + Tensor.unique, (self,), self, sorted=sorted, return_inverse=return_inverse, return_counts=return_counts, dim=dim ) return torch.unique(self, sorted=sorted, return_inverse=return_inverse, return_counts=return_counts, dim=dim) @@ -494,31 +495,22 @@ def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None See :func:`torch.unique_consecutive` """ - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): + if has_torch_function_unary(self): return handle_torch_function( - Tensor.unique_consecutive, relevant_args, self, return_inverse=return_inverse, + Tensor.unique_consecutive, (self,), self, return_inverse=return_inverse, return_counts=return_counts, dim=dim ) return torch.unique_consecutive(self, return_inverse=return_inverse, return_counts=return_counts, dim=dim) def __rsub__(self, other): - relevant_args = (self, other) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.__rsub__, relevant_args, self, other) + if has_torch_function_variadic(self, other): + return handle_torch_function(Tensor.__rsub__, (self, other), self, other) return _C._VariableFunctions.rsub(self, other) def __rdiv__(self, other): - relevant_args = (self, other) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.__rdiv__, relevant_args, self, other) - if self.dtype.is_floating_point or self.dtype.is_complex: - return self.reciprocal() * other - else: - return (self.double().reciprocal() * other).type_as(self) + if has_torch_function_variadic(self, other): + return handle_torch_function(Tensor.__rdiv__, (self, other), self, other) + return self.reciprocal() * other __rtruediv__ = __rdiv__ __itruediv__ = _C._TensorBase.__idiv__ @@ -526,19 +518,15 @@ def __rdiv__(self, other): __pow__ = _C._TensorBase.pow def __format__(self, format_spec): - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.__format__, relevant_args, self, format_spec) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__format__, (self,), self, format_spec) if self.dim() == 0: return self.item().__format__(format_spec) return object.__format__(self, format_spec) def __ipow__(self, other): # type: ignore[misc] - relevant_args = (self, other) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.__ipow__, relevant_args, self, other) + if has_torch_function_variadic(self, other): + return handle_torch_function(Tensor.__ipow__, (self, other), self, other) return NotImplemented @_wrap_type_error_to_not_implemented @@ -552,26 +540,14 @@ def __floordiv__(self, other): @_wrap_type_error_to_not_implemented def __rfloordiv__(self, other): - result = other / self - if result.dtype.is_floating_point: - result = result.trunc() - return result + return torch.floor_divide(other, self) __neg__ = _C._TensorBase.neg - - __eq__ = _wrap_type_error_to_not_implemented(_C._TensorBase.eq) - __ne__ = _wrap_type_error_to_not_implemented(_C._TensorBase.ne) - __lt__ = _wrap_type_error_to_not_implemented(_C._TensorBase.lt) - __le__ = _wrap_type_error_to_not_implemented(_C._TensorBase.le) - __gt__ = _wrap_type_error_to_not_implemented(_C._TensorBase.gt) - __ge__ = _wrap_type_error_to_not_implemented(_C._TensorBase.ge) __abs__ = _C._TensorBase.abs def __len__(self): - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.__len__, relevant_args, self) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__len__, (self,), self) if self.dim() == 0: raise TypeError("len() of a 0-d tensor") return self.shape[0] @@ -583,10 +559,8 @@ def __iter__(self): # (e.g., if you zip(*hiddens), the eager map will force all the # indexes of hiddens[0] before hiddens[1], while the generator # map will interleave them.) - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.__iter__, relevant_args, self) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__iter__, (self,), self) if self.dim() == 0: raise TypeError('iteration over a 0-d tensor') if torch._C._get_tracing_state(): @@ -597,17 +571,13 @@ def __iter__(self): return iter(self.unbind(0)) def __hash__(self): - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.__hash__, relevant_args, self) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__hash__, (self,), self) return id(self) def __dir__(self): - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.__dir__, relevant_args, self) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__dir__, (self,), self) if self.is_quantized: warnings.warn('Only a small subset of methods are supported for quantized tensors.') tensor_methods = dir(self.__class__) @@ -625,10 +595,8 @@ def __dir__(self): __array_priority__ = 1000 # prefer Tensor ops over numpy ones def __array__(self, dtype=None): - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.__array__, relevant_args, self, dtype=dtype) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype) if dtype is None: return self.numpy() else: @@ -637,10 +605,8 @@ def __array__(self, dtype=None): # Wrap Numpy array again in a suitable tensor when done, to support e.g. # `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor` def __array_wrap__(self, array): - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.__array_wrap__, relevant_args, self, array=array) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__array_wrap__, (self,), self, array=array) if array.dtype == bool: # Workaround, torch has no built-in bool tensor array = array.astype('uint8') @@ -649,14 +615,12 @@ def __array_wrap__(self, array): def __contains__(self, element): r"""Check if `element` is present in tensor - Arguments: + Args: element (Tensor or scalar): element to be checked for presence in current tensor" """ - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.__contains__, relevant_args, self, element) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__contains__, (self,), self, element) if isinstance(element, (torch.Tensor, Number)): # type hint doesn't understand the __contains__ result array return (element == self).any().item() # type: ignore[union-attr] @@ -673,11 +637,9 @@ def __cuda_array_interface__(self): See: https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html """ - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): + if has_torch_function_unary(self): # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185 - return handle_torch_function(Tensor.__cuda_array_interface__.__get__, relevant_args, self) # type: ignore[attr-defined] + return handle_torch_function(Tensor.__cuda_array_interface__.__get__, (self,), self) # type: ignore[attr-defined] # raise AttributeError for unsupported tensors, so that # hasattr(cpu_tensor, "__cuda_array_interface__") is False. @@ -750,7 +712,7 @@ def refine_names(self, *names): Python 2 does not support Ellipsis but one may use a string literal instead (``'...'``). - Arguments: + Args: names (iterable of str): The desired names of the output tensor. May contain up to one Ellipsis. @@ -770,10 +732,8 @@ def refine_names(self, *names): The named tensor API is experimental and subject to change. """ - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.refine_names, relevant_args, self, *names) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.refine_names, (self,), self, *names) names = resolve_ellipsis(names, self.names, 'refine_names') return super(Tensor, self).refine_names(names) @@ -796,7 +756,7 @@ def align_to(self, *names): Python 2 does not support Ellipsis but one may use a string literal instead (``'...'``). - Arguments: + Args: names (iterable of str): The desired dimension ordering of the output tensor. May contain up to one Ellipsis that is expanded to all unmentioned dim names of :attr:`self`. @@ -813,10 +773,8 @@ def align_to(self, *names): The named tensor API is experimental and subject to change. """ - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.align_to, relevant_args, self, *names) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.align_to, (self,), self, *names) ellipsis_idx = single_ellipsis_index(names, 'align_to') if ellipsis_idx is None: return super(Tensor, self).align_to(names) @@ -833,7 +791,7 @@ def unflatten(self, dim, sizes): if :attr:`self` is a `NamedTensor`. The total number of elements in sizes must match the number of elements in the original dim being unflattened. - Arguments: + Args: dim (Union[int, str]): Dimension to unflatten sizes (Union[Tuple[int] or torch.Size, Tuple[Tuple[str, int]]]): New shape of the unflattened dimension @@ -850,10 +808,8 @@ def unflatten(self, dim, sizes): .. warning:: The named tensor API is experimental and subject to change. """ - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.unflatten, relevant_args, self, dim, sizes) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.unflatten, (self,), self, dim, sizes) if not sizes: raise RuntimeError("unflatten: sizes must be non-empty") @@ -867,10 +823,8 @@ def unflatten(self, dim, sizes): def rename_(self, *names, **rename_map): """In-place version of :meth:`~Tensor.rename`.""" - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.rename_, relevant_args, self, *names, **rename_map) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.rename_, (self,), self, *names, **rename_map) # Note [rename_ / rename API] # The Python API for these is different from the C++ API. In Python: @@ -913,19 +867,15 @@ def rename(self, *names, **rename_map): The named tensor API is experimental and subject to change. """ - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor.rename, relevant_args, self, *names, **rename_map) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.rename, (self,), self, *names, **rename_map) # See Note [rename_ / rename API] return update_names(self, names, rename_map, inplace=False) def _update_names(self, names, inplace): - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): - return handle_torch_function(Tensor._update_names, relevant_args, self, names, inplace) + if has_torch_function_unary(self): + return handle_torch_function(Tensor._update_names, (self,), self, names, inplace) # See Note [rename_ / rename API] if inplace: @@ -941,11 +891,9 @@ def grad(self): The attribute will then contain the gradients computed and future calls to :func:`backward` will accumulate (add) gradients into it. """ - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): + if has_torch_function_unary(self): # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185 - return handle_torch_function(Tensor.grad.__get__, relevant_args, self) # type: ignore[attr-defined] + return handle_torch_function(Tensor.grad.__get__, (self,), self) # type: ignore[attr-defined] if self.requires_grad and not hasattr(self, "retains_grad") and not self.is_leaf and self._grad is None: warnings.warn("The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad " @@ -957,20 +905,16 @@ def grad(self): @grad.setter def grad(self, new_grad): - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): + if has_torch_function_unary(self): # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185 - return handle_torch_function(Tensor.grad.__set__, relevant_args, self, new_grad) # type: ignore[attr-defined] + return handle_torch_function(Tensor.grad.__set__, (self,), self, new_grad) # type: ignore[attr-defined] self._grad = new_grad @grad.deleter def grad(self): - relevant_args = (self,) - from torch.overrides import has_torch_function, handle_torch_function - if type(self) is not Tensor and has_torch_function(relevant_args): + if has_torch_function_unary(self): # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185 - return handle_torch_function(Tensor.grad.__delete__, relevant_args, self) # type: ignore[attr-defined] + return handle_torch_function(Tensor.grad.__delete__, (self,), self) # type: ignore[attr-defined] del self._grad @classmethod @@ -1008,7 +952,8 @@ def _convert(ret, cls): if isinstance(ret, Tensor): ret = ret.as_subclass(cls) - if isinstance(ret, tuple): - ret = tuple(_convert(r, cls) for r in ret) + if isinstance(ret, (tuple, list)): + # Also handles things like namedtuples + ret = type(ret)(_convert(r, cls) for r in ret) return ret diff --git a/torch/testing/__init__.py b/torch/testing/__init__.py index 396d0718efbcf..c26556f4d70a9 100644 --- a/torch/testing/__init__.py +++ b/torch/testing/__init__.py @@ -6,6 +6,7 @@ import random import math from typing import cast, List, Optional, Tuple, Union +from .check_kernel_launches import check_cuda_kernel_launches, check_code_for_cuda_kernel_launches FileCheck = torch._C.FileCheck @@ -24,6 +25,9 @@ def is_integral(dtype: torch.dtype) -> bool: dtypes = [x for x in get_all_dtypes() if x not in get_all_complex_dtypes()] return dtype in dtypes and not dtype.is_floating_point +def is_quantized(dtype: torch.dtype) -> bool: + return dtype in (torch.quint8, torch.qint8, torch.qint32, torch.quint4x2) + # Helper function that maps a flattened index back into the given shape # TODO: consider adding torch.unravel_index def _unravel_index(flat_index, shape): @@ -70,7 +74,11 @@ def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, e debug_msg : Optional[str] # Integer (including bool) comparisons are identity comparisons # when rtol is zero and atol is less than one - if (is_integral(a.dtype) and rtol == 0 and atol < 1) or a.dtype is torch.bool: + if ( + (is_integral(a.dtype) and rtol == 0 and atol < 1) + or a.dtype is torch.bool + or is_quantized(a.dtype) + ): if (a == b).all().item(): return (True, None) @@ -204,7 +212,8 @@ def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg= if not isinstance(expected, torch.Tensor): expected = torch.tensor(expected, dtype=actual.dtype) if expected.shape != actual.shape: - expected = expected.expand_as(actual) + raise AssertionError("expected tensor shape {0} doesn't match with actual tensor " + "shape {1}!".format(expected.shape, actual.shape)) if rtol is None or atol is None: if rtol is not None or atol is not None: raise ValueError("rtol and atol must both be specified or both be unspecified") diff --git a/torch/testing/_internal/codegen/random_topo_test.py b/torch/testing/_internal/codegen/random_topo_test.py index e2823a97f10b1..cf27fadff314c 100644 --- a/torch/testing/_internal/codegen/random_topo_test.py +++ b/torch/testing/_internal/codegen/random_topo_test.py @@ -2,6 +2,8 @@ import numpy as np import argparse +from typing import Dict + # debug print DEBUG_PRINT = False @@ -71,7 +73,7 @@ def get_root(x, dependency_map): return get_root(dependency_map[x], dependency_map) else: return x - d_map = {} + d_map: Dict[int, int] = {} num_sets = num_tensor candidate = list(range(num_tensor)) @@ -283,7 +285,7 @@ def runDefaultTestWithSeed(seed): jit_o = traced_model(seed_tensor, *tensor_list) validate_o = zip(o, jit_o) for oo, jit_oo in validate_o: - if not oo.allclose(jit_oo, equal_nan=True): + if not oo.allclose(jit_oo, atol=1e-5, equal_nan=True): return False return True diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 2904e4c824ad8..7172da19a4b7b 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -3,13 +3,16 @@ import inspect import runpy import threading +from enum import Enum from functools import wraps -from typing import List, Any, ClassVar +from typing import List, Any, ClassVar, Optional, Sequence import unittest import os import torch from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \ - skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN + skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \ + IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, DeterministicGuard +from torch.testing._internal.common_cuda import _get_torch_cuda_version from torch.testing import \ (get_all_dtypes) @@ -165,13 +168,12 @@ # See below for how this list is populated. If you're adding a device type # you should check if it's available and (if it is) add it to this list. -# set type to List[Any] due to mypy list-of-union issue: -# https://github.com/python/mypy/issues/3351 -device_type_test_bases: List[Any] = list() def _construct_test_name(test_name, op, device_type, dtype): if op is not None: - test_name += "_" + op.name + test_name += "_" + op.name.replace('.', '_') + if op.variant_test_name: + test_name += "_" + op.variant_test_name test_name += "_" + device_type @@ -187,6 +189,9 @@ def _construct_test_name(test_name, op, device_type, dtype): class DeviceTypeTestBase(TestCase): device_type: str = 'generic_device_type' + # Flag to disable test suite early due to unrecoverable error such as CUDA error. + _stop_test_suite = False + # Precision is a thread-local setting since it may be overridden per test _tls = threading.local() _tls.precision = TestCase._precision @@ -227,6 +232,9 @@ def _get_precision_override(self, test, dtype): return self.precision return test.precision_overrides.get(dtype, self.precision) + def _should_stop_test_suite(self, rte): + return False + # Creates device-specific tests. @classmethod def instantiate_test(cls, name, test, *, generic_cls=None): @@ -271,6 +279,11 @@ def instantiated_test(self, name=name, test=test_fn, dtype=dtype, op=op): self.precision = self._get_precision_override(test_fn, dtype) args = (arg for arg in (device_arg, dtype, op) if arg is not None) result = test_fn(self, *args) + except RuntimeError as rte: + # check if rte should stop entire test suite. + self._stop_test_suite = self._should_stop_test_suite(rte) + # raise the runtime error as is for the test suite to record. + raise rte finally: self.precision = guard_precision @@ -285,21 +298,21 @@ def instantiated_test(self, name=name, test=test_fn, dtype=dtype, op=op): # Acquires dtypes, using the op data if unspecified dtypes = cls._get_dtypes(test) if dtypes is None: - if cls.device_type == 'cpu' and op.dtypesIfCPU is not None: - dtypes = op.dtypesIfCPU - elif (cls.device_type == 'cuda' and not TEST_WITH_ROCM - and op.dtypesIfCUDA is not None): - dtypes = op.dtypesIfCUDA - elif (cls.device_type == 'cuda' and TEST_WITH_ROCM - and op.dtypesIfROCM is not None): - dtypes = op.dtypesIfROCM + if test.opinfo_dtypes == OpDTypes.unsupported: + dtypes = set(get_all_dtypes()).difference(op.supported_dtypes(cls.device_type)) + elif test.opinfo_dtypes == OpDTypes.supported: + dtypes = op.supported_dtypes(cls.device_type) + elif test.opinfo_dtypes == OpDTypes.basic: + dtypes = op.default_test_dtypes(cls.device_type) else: - dtypes = op.dtypes + raise RuntimeError(f"Unknown OpDType: {test.opinfo_dtypes}") + + if test.allowed_dtypes is not None: + dtypes = dtypes.intersection(test.allowed_dtypes) + else: + assert test.allowed_dtypes is None, "ops(allowed_dtypes=[...]) and the dtypes decorator are incompatible" + assert test.opinfo_dtypes == OpDTypes.basic, "ops(dtypes=...) and the dtypes decorator are incompatible" - # Inverts dtypes if the function wants unsupported dtypes - if test.unsupported_dtypes_only is True: - dtypes = [d for d in get_all_dtypes() if d not in dtypes] - dtypes = dtypes if dtypes is not None else (None,) for dtype in dtypes: instantiate_test_helper(cls, name, @@ -313,6 +326,12 @@ def instantiated_test(self, name=name, test=test_fn, dtype=dtype, op=op): for dtype in dtypes: instantiate_test_helper(cls, name, test=test, dtype=dtype, op=None) + def run(self, result=None): + super().run(result=result) + # Early terminate test if _stop_test_suite is set. + if self._stop_test_suite: + result.stop() + class CPUTestBase(DeviceTypeTestBase): device_type = 'cpu' @@ -327,6 +346,14 @@ class CUDATestBase(DeviceTypeTestBase): no_magma: ClassVar[bool] no_cudnn: ClassVar[bool] + def _should_stop_test_suite(self, rte): + # CUDA device side error will cause subsequence test cases to fail. + # stop entire test suite if catches RuntimeError during torch.cuda.synchronize(). + try: + torch.cuda.synchronize() + except RuntimeError as rte: + return True + return False def has_cudnn(self): return not self.no_cudnn @@ -360,9 +387,27 @@ def setUpClass(cls): # Adds available device-type-specific test base classes -device_type_test_bases.append(CPUTestBase) -if torch.cuda.is_available(): - device_type_test_bases.append(CUDATestBase) +def get_device_type_test_bases(): + # set type to List[Any] due to mypy list-of-union issue: + # https://github.com/python/mypy/issues/3351 + test_bases: List[Any] = list() + + if IS_SANDCASTLE or IS_FBCODE: + if IS_REMOTE_GPU: + # skip if sanitizer is enabled + if not TEST_WITH_ASAN and not TEST_WITH_TSAN and not TEST_WITH_UBSAN: + test_bases.append(CUDATestBase) + else: + test_bases.append(CPUTestBase) + else: + test_bases.append(CPUTestBase) + if torch.cuda.is_available(): + test_bases.append(CUDATestBase) + + return test_bases + + +device_type_test_bases = get_device_type_test_bases() # Note [How to extend DeviceTypeTestBase to add new test device] @@ -465,6 +510,22 @@ def instantiate_device_type_tests(generic_test_class, scope, except_for=None, on scope[class_name] = device_type_test_class +# Category of dtypes to run an OpInfo-based test for +# Example use: @ops(dtype=OpDTypes.supported) +# +# There are 3 categories: supported, unsupported and basic. +# - basic: The dtypes the operator wants to be tested on by default. This will be +# a subset of the types supported by the operator. +# - supported: Every dtype supported by the operator. Use for exhaustive +# testing of all dtypes. +# - unsupported: Run tests on dtypes not supported by the operator. e.g. for +# testing the operator raises an error and doesn't crash. +class OpDTypes(Enum): + basic = 0 # Test the basic set of dtypes (default) + supported = 1 # Test all supported dtypes + unsupported = 2 # Test only unsupported dtypes + + # Decorator that defines the ops a test should be run with # The test signature must be: # (self, device, dtype, op) @@ -473,13 +534,16 @@ def instantiate_device_type_tests(generic_test_class, scope, except_for=None, on # test_numerics(self, device, dtype, op): # class ops(object): - def __init__(self, op_list, *, unsupported_dtypes_only=False): + def __init__(self, op_list, *, dtypes: OpDTypes = OpDTypes.basic, + allowed_dtypes: Optional[Sequence[torch.dtype]] = None): self.op_list = op_list - self.unsupported_dtypes_only = unsupported_dtypes_only + self.opinfo_dtypes = dtypes + self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None def __call__(self, fn): fn.op_list = self.op_list - fn.unsupported_dtypes_only = self.unsupported_dtypes_only + fn.allowed_dtypes = self.allowed_dtypes + fn.opinfo_dtypes = self.opinfo_dtypes return fn # Decorator that skips a test if the given condition is true. @@ -522,20 +586,14 @@ class skipCUDAIf(skipIf): def __init__(self, dep, reason): super().__init__(dep, reason, device_type='cuda') - -# Only runs on cuda, and only run when there is enough GPU RAM -def largeCUDATensorTest(size): - if isinstance(size, str): - assert size.endswith("GB") or size.endswith("gb"), "only bytes or GB supported" - size = 1024 ** 3 * int(size[:-2]) - valid = torch.cuda.is_available() and torch.cuda.get_device_properties(0).total_memory >= size - return unittest.skipIf(not valid, "No CUDA or Has CUDA but GPU RAM is not large enough") - - def _has_sufficient_memory(device, size): - if device.startswith('cuda'): - return (torch.cuda.is_available() and - torch.cuda.get_device_properties(0).total_memory >= size) + if torch.device(device).type == 'cuda': + if not torch.cuda.is_available(): + return False + gc.collect() + torch.cuda.empty_cache() + return torch.cuda.get_device_properties(device).total_memory - torch.cuda.memory_allocated(device) >= size + if device == 'xla': raise unittest.SkipTest('TODO: Memory availability checks for XLA?') @@ -557,10 +615,14 @@ def _has_sufficient_memory(device, size): return psutil.virtual_memory().available >= effective_size -def largeTensorTest(size): +def largeTensorTest(size, device=None): """Skip test if the device has insufficient memory to run the test size may be a number of bytes, a string of the form "N GB", or a callable + + If the test is a device generic test, available memory on the primary device will be checked. + It can also be overriden by the optional `device=` argument. + In other tests, the `device=` argument needs to be specified. """ if isinstance(size, str): assert size.endswith("GB") or size.endswith("gb"), "only bytes or GB supported" @@ -570,8 +632,9 @@ def inner(fn): @wraps(fn) def dep_fn(self, *args, **kwargs): size_bytes = size(self, *args, **kwargs) if callable(size) else size - if not _has_sufficient_memory(self.device_type, size_bytes): - raise unittest.SkipTest('Insufficient {} memory'.format(self.device_type)) + _device = device if device is not None else self.get_primary_device() + if not _has_sufficient_memory(_device, size_bytes): + raise unittest.SkipTest('Insufficient {} memory'.format(_device)) return fn(self, *args, **kwargs) return dep_fn @@ -754,24 +817,21 @@ def __call__(self, fn): @wraps(fn) def efail_fn(slf, device, *args, **kwargs): if self.device_type is None or self.device_type == slf.device_type: - deterministic_restore = torch.is_deterministic() - torch.set_deterministic(True) - try: - if self.fn_has_device_arg: - fn(slf, device, *args, **kwargs) + with DeterministicGuard(True): + try: + if self.fn_has_device_arg: + fn(slf, device, *args, **kwargs) + else: + fn(slf, *args, **kwargs) + except RuntimeError as e: + if self.error_message not in str(e): + slf.fail( + 'expected non-deterministic error message to start with "' + + self.error_message + + '" but got this instead: "' + str(e) + '"') + return else: - fn(slf, *args, **kwargs) - except RuntimeError as e: - torch.set_deterministic(deterministic_restore) - if self.error_message not in str(e): - slf.fail( - 'expected non-deterministic error message to start with "' - + self.error_message - + '" but got this instead: "' + str(e) + '"') - return - else: - torch.set_deterministic(deterministic_restore) - slf.fail('expected a non-deterministic error, but it was not raised') + slf.fail('expected a non-deterministic error, but it was not raised') if self.fn_has_device_arg: return fn(slf, device, *args, **kwargs) @@ -801,6 +861,13 @@ def skipCPUIfNoMkl(fn): def skipCUDAIfNoMagma(fn): return skipCUDAIf('no_magma', "no MAGMA library detected")(skipCUDANonDefaultStreamIf(True)(fn)) +def skipCUDAIfNoMagmaAndNoCusolver(fn): + version = _get_torch_cuda_version() + if version >= [10, 2]: + return fn + else: + # cuSolver is disabled on cuda < 10.1.243, tests depend on MAGMA + return skipCUDAIfNoMagma(fn) # Skips a test on CUDA when using ROCm. def skipCUDAIfRocm(fn): diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index f8e5b4822bd85..23a4e7a362702 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -1,5 +1,8 @@ from multiprocessing import Manager +from contextlib import contextmanager +from io import StringIO +import itertools import os import sys import tempfile @@ -16,7 +19,7 @@ import torch.distributed as c10d from functools import partial, reduce -from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM +from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, FILE_SCHEMA class TestSkip(NamedTuple): exit_code: int @@ -29,9 +32,28 @@ class TestSkip(NamedTuple): "no_cuda": TestSkip(74, "CUDA is not available."), "multi-gpu": TestSkip(75, "Need at least 2 CUDA devices"), "nccl": TestSkip(76, "c10d not compiled with NCCL support"), - "skipIfRocm": TestSkip(78, "Test skipped for ROCm") + "skipIfRocm": TestSkip(78, "Test skipped for ROCm"), + "no_peer_access": TestSkip(79, "Test skipped because no GPU peer access"), } + +# FIXME: this should be removed when TensorPipe can detect availability of peer access +def skip_if_no_peer_access(func): + """TensorPipe same-machine GPU-to-GPU comm requires peer access""" + @wraps(func) + def wrapper(*args, **kwargs): + if not torch.cuda.is_available(): + sys.exit(TEST_SKIPS["no_cuda"].exit_code) + n = torch.cuda.device_count() + for i, j in itertools.product(range(n), range(n)): + if i != j and not torch.cuda.can_device_access_peer(i, j): + sys.exit(TEST_SKIPS["no_peer_access"].exit_code) + + return func(*args, **kwargs) + + return wrapper + + def skip_if_no_gpu(func): """ Nccl multigpu tests require at least 2 GPUS. Skip if this is not met""" @wraps(func) @@ -61,15 +83,17 @@ def wrapper(*args, **kwargs): def skip_if_not_multigpu(func): """Multi-GPU tests requires at least 2 GPUS. Skip if this is not met.""" - @wraps(func) - def wrapper(*args, **kwargs): - if torch.cuda.is_available() and torch.cuda.device_count() >= 2: - return func(*args, **kwargs) - message = "Need at least {} CUDA devices".format(2) - TEST_SKIPS["multi-gpu"] = TestSkip(75, message) - sys.exit(TEST_SKIPS['multi-gpu'].exit_code) + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if torch.cuda.is_available() and torch.cuda.device_count() >= 2: + return func(*args, **kwargs) + message = "Need at least {} CUDA devices".format(2) + TEST_SKIPS["multi-gpu"] = TestSkip(75, message) + sys.exit(TEST_SKIPS['multi-gpu'].exit_code) + return wrapper - return wrapper + return decorator def require_n_gpus_for_nccl_backend(n, backend): def decorator(func): @@ -130,6 +154,17 @@ def requires_mpi(): "c10d was not compiled with the MPI backend", ) +def skip_if_rocm_single_process(func): + """Skips a test for ROCm in a single process environment""" + func.skip_if_rocm = True + + @wraps(func) + def wrapper(*args, **kwargs): + if not TEST_WITH_ROCM: + return func(*args, **kwargs) + raise unittest.SkipTest("Test skipped for ROCm") + + return wrapper def skip_if_rocm(func): """Skips a test for ROCm""" @@ -143,13 +178,35 @@ def wrapper(*args, **kwargs): return wrapper +def skip_if_win32(): + return unittest.skipIf( + sys.platform == 'win32', + "This unit test case is not supportted on Windows platform", + ) + TIMEOUT_DEFAULT = 100 TIMEOUT_OVERRIDE = {"test_ddp_uneven_inputs": 400} +def create_device(interface=None): + if sys.platform == 'win32' or interface is None: + return c10d.ProcessGroupGloo.create_device(hostname="127.0.0.1") + else: + return c10d.ProcessGroupGloo.create_device(interface=interface) + + def get_timeout(test_id): return TIMEOUT_OVERRIDE.get(test_id.split('.')[-1], TIMEOUT_DEFAULT) +@contextmanager +def captured_output(): + new_out, new_err = StringIO(), StringIO() + old_out, old_err = sys.stdout, sys.stderr + try: + sys.stdout, sys.stderr = new_out, new_err + yield sys.stdout, sys.stderr + finally: + sys.stdout, sys.stderr = old_out, old_err def simple_sparse_reduce_tests(rank, world_size, num_inputs=1): """ @@ -206,7 +263,7 @@ def initialize_temp_directories(init_method=None): if init_method is not None: os.environ["INIT_METHOD"] = init_method else: - os.environ["INIT_METHOD"] = "file://" + os.path.join( + os.environ["INIT_METHOD"] = FILE_SCHEMA + os.path.join( init_dir_path, "shared_init_file" ) diff --git a/torch/testing/_internal/common_jit.py b/torch/testing/_internal/common_jit.py new file mode 100644 index 0000000000000..d5dda32b46867 --- /dev/null +++ b/torch/testing/_internal/common_jit.py @@ -0,0 +1,220 @@ +# Torch +import torch +import torch.cuda +import torch.jit +import torch.jit._logging +import torch.jit.frontend +import torch.jit.quantized + +# Testing utils +from torch.testing import floating_and_complex_types_and +from torch.testing._internal.common_utils import TestCase, \ + freeze_rng_state, TemporaryFileName, enable_profiling_mode_for_profiling_tests +from torch.testing._internal.common_utils import enable_profiling_mode # noqa: F401 + +# Standard library +from itertools import chain + +import io + +def check_output_types(self, func, ref_outputs, args, kwargs): + graph = getattr(func, 'last_graph', None) + types = [o.type() for o in graph.outputs()] + self.assertTrue(len(types) == 1) + t = types[0] + torch._C._jit_assert_is_instance(ref_outputs, t) + +# Test names in this set are only checked for a single derivative +nn_functional_single_grad = frozenset('test_nn_' + name for name in [ + 'pdist', + 'multilabel_margin_loss', + 'max_unpool3d', + 'multi_margin_loss', + 'binary_cross_entropy', + 'binary_cross_entropy_size_average', + 'ctc_loss', + 'grid_sample', +]) + +def check_against_reference(self, func, reference_func, args, kwargs=None, + allow_unused=True, check_types=True, no_grad=False): + kwargs = kwargs if kwargs else {} + + def allSum(vs): + if isinstance(vs, torch.Tensor): + vs = (vs,) + return sum((i + 1) * v.sum() + for i, v in enumerate(vs) + if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16)) + + def clone_inputs(requires_grad): + inputs = [ + arg.detach().clone().requires_grad_(requires_grad and arg.requires_grad) + if isinstance(arg, torch.Tensor) else arg for arg in args + ] + return inputs, [input for input in inputs if isinstance(input, torch.Tensor) and input.requires_grad] + + nograd_inputs, nograd_tensors = clone_inputs(False) + recording_inputs, recording_tensors = clone_inputs(True) + + # test no gradients case + outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs) + with enable_profiling_mode_for_profiling_tests(): + outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs) + self.assertEqual(outputs, outputs_test) + + if check_types: + check_output_types(self, func, outputs_test, nograd_inputs, kwargs) + + if no_grad: + # skip grad tests + return + + with enable_profiling_mode_for_profiling_tests(): + # test single grad case + outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs) + grads = torch.autograd.grad(allSum(outputs), recording_tensors, + allow_unused=allow_unused) + outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs) + grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors, + allow_unused=allow_unused) + self.assertEqual(outputs, outputs_test) + self.assertEqual(grads, grads_test) + # test the grad grad case + if self._testMethodName in nn_functional_single_grad: + return + + outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs) + l1 = allSum(outputs) + grads = torch.autograd.grad(l1, recording_tensors, create_graph=True, + allow_unused=allow_unused) + + l2 = (allSum(grads) * l1) + grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused) + recording_inputs, recording_tensors = clone_inputs(True) + outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs) + l1_test = allSum(outputs_test) + grads_test = torch.autograd.grad( + l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused) + + l2_test = (allSum(grads_test) * l1_test) + grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused) + + self.assertEqual(outputs, outputs_test) + self.assertEqual(grads, grads_test) + for g2, g2_test in zip(grads2, grads2_test): + if g2 is None and g2_test is None: + continue + self.assertTrue(torch.allclose(g2, g2_test, atol=5e-4, rtol=1e-4)) + + +class JitCommonTestCase(TestCase): + def createFunctionFromGraph(self, trace): + graph = trace if isinstance(trace, torch._C.Graph) else trace.graph() + return torch._C._create_function_from_graph("forward", graph) + + def assertExportImport(self, trace, inputs): + m = self.createFunctionFromGraph(trace) + self.assertExportImportModule(m, inputs) + + def assertExportImportModule(self, m, inputs): + m_import = self.getExportImportCopy(m) + a = self.runAndSaveRNG(m, inputs) + b = self.runAndSaveRNG(m_import, inputs) + self.assertEqual(a, b, "Results of original model and " + "exported/imported version of model differed") + + def runAndSaveRNG(self, func, inputs, kwargs=None): + kwargs = kwargs if kwargs else {} + with freeze_rng_state(): + results = func(*inputs, **kwargs) + return results + + def getExportImportCopy(self, m, also_test_file=True, map_location=None): + buffer = io.BytesIO() + torch.jit.save(m, buffer) + buffer.seek(0) + imported = torch.jit.load(buffer, map_location=map_location) + + if not also_test_file: + return imported + + with TemporaryFileName() as fname: + torch.jit.save(imported, fname) + return torch.jit.load(fname, map_location=map_location) + + def autoDiffErrorMessage(self, should_autodiff_node, nodes_not_in_diff_graph, + fusion_nodes_not_found, non_fusible_nodes_being_fused, + fusion_nodes_found, nodes_in_diff_graph): + err_msg = "\nFailure in testing nodes' autodifferentiation, " + if should_autodiff_node: + err_msg += "one or more nodes were expected to be autodiffed, " \ + "but were not found in specified fusible/nonfusible " \ + "DifferentiableGraph groups. \nSpecifically:" + for node in nodes_not_in_diff_graph: + err_msg += f"\n {node} was not in one of the DifferentiableGraphs " \ + "when it was expected to be. Did you intend for this node to be " \ + "autodiffed? If not, remove it from the list of nonfusible nodes." + if node in non_fusible_nodes_being_fused: + err_msg += "Additionally, This node was found in a FusionGroup " \ + "in a DifferentiableGraph. If that was intended, " \ + "reclassify this node as a fusible node. If not, your " \ + "autodifferention logic might be wrong." + for node in fusion_nodes_not_found: + err_msg += f"\n {node} was not in one of the DifferentiableGraphs' " \ + "fusion groups when it was expected to be. " \ + "Did you intend for this node to be fused? If not, you should " \ + "move this node into the test's non-fusible nodes." + else: + err_msg += "one or more nodes were not expected to be autodiffed, " \ + "but were found in a fused/nonfused DifferentiableGraph group. " \ + "Did you intend for these nodes to be autodiffed? " \ + "If so, change this test to expect autodifferention. " \ + "\nSpecifically:" + for node in fusion_nodes_found: + err_msg += f"\n {node} was not expected to in one of the " \ + "DifferentiableGraph's fusion groups but was. " + for node in nodes_in_diff_graph: + err_msg += f"\n {node} was not expected to be in a " \ + "DifferentiableGraph but was." + return err_msg + + def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes): + diff_nodes = graph.findAllNodes('prim::DifferentiableGraph') + diff_subgraphs = [node.g('Subgraph') for node in diff_nodes] + + # Note: currently no tests have fusible_nodes + fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs])) + fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes] + + # For any non-fusible node, it must show up in one of the DifferentiableGraphs. + nodes_in_diff_graph = [] + nodes_not_in_diff_graph = [] + non_fusible_nodes_being_fused = [] + for node in nonfusible_nodes: + if any(g.findNode(node) is not None for g in diff_subgraphs): + nodes_in_diff_graph.append(node) + else: + nodes_not_in_diff_graph.append(node) + if any(g.findNode(node) is not None for g in fusion_subgraphs): + non_fusible_nodes_being_fused.append(node) + found_all_nonfusible_nodes = len(nodes_in_diff_graph) == len(nonfusible_nodes) + + # For any fusible node, it must show up in one of the FusionGroups in one of the DifferentiableGraphs. + fusion_nodes_found = [] + fusion_nodes_not_found = [] + for node in fusible_nodes: + if any(g.findNode(node) is not None for g in fusion_subgraphs): + fusion_nodes_found.append(node) + else: + fusion_nodes_not_found.append(node) + found_all_fusible_nodes = len(fusion_nodes_found) == len(fusible_nodes) + + err_msg = self.autoDiffErrorMessage(should_autodiff_node, + nodes_not_in_diff_graph, + fusion_nodes_not_found, + non_fusible_nodes_being_fused, + fusion_nodes_found, + nodes_in_diff_graph) + self.assertEqual(should_autodiff_node, + found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index a434d69a8654a..0e7163ce3e0d1 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1,6 +1,8 @@ -from functools import reduce +from functools import reduce, wraps +from itertools import product from operator import mul, itemgetter import collections +import operator import torch import numpy as np @@ -10,19 +12,25 @@ from typing import List, Tuple, Dict, Any from torch.testing import \ - (make_non_contiguous, _dispatch_dtypes, - floating_types, floating_types_and, floating_and_complex_types, - floating_and_complex_types_and, all_types_and_complex_and) + (make_non_contiguous, _dispatch_dtypes, floating_types, floating_types_and, + floating_and_complex_types, floating_and_complex_types_and, + all_types_and_complex_and, all_types_and) from torch.testing._internal.common_device_type import \ - (skipCUDAIfNoMagma, skipCPUIfNoLapack, expectedFailureCUDA, - expectedAlertNondeterministic, precisionOverride) + (skipCUDAIfNoMagma, skipCPUIfNoLapack, skipCPUIfNoMkl, skipCUDAIfRocm, + expectedAlertNondeterministic, precisionOverride, onlyCPU) +from torch.testing._internal.common_cuda import tf32_is_not_fp32 from torch.testing._internal.common_utils import \ (prod_single_zero, random_square_matrix_of_rank, random_symmetric_matrix, random_symmetric_psd_matrix, random_symmetric_pd_matrix, make_nonzero_det, random_fullrank_matrix_distinct_singular_value, set_rng_seed, - TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, make_tensor) + TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, make_tensor, TEST_SCIPY, + torch_to_numpy_dtype_dict, TEST_WITH_SLOW) +from distutils.version import LooseVersion + +if TEST_SCIPY: + import scipy.special class SkipInfo(object): """Describes which test, or type of tests, should be skipped when testing @@ -42,12 +50,32 @@ def __init__(self, cls_name=None, test_name=None, *, class SampleInput(object): """Represents sample inputs to a function.""" - __slots__ = ['input', 'args', 'kwargs'] + # output_process_fn_grad is a function that modifies the output of op compatible with input + __slots__ = ['input', 'args', 'kwargs', 'output_process_fn_grad'] - def __init__(self, input, *, args=tuple(), kwargs=None): - self.input = input + def __init__(self, input, *, args=tuple(), kwargs=None, output_process_fn_grad=None): + # test_ops.py expects input to be a tuple + self.input = input if isinstance(input, tuple) else (input,) self.args = args self.kwargs = kwargs if kwargs is not None else {} + self.output_process_fn_grad = output_process_fn_grad + + +_NOTHING = object() # Unique value to distinguish default from anything else + + +# Extension of getattr to support qualified names +# e.g. _getattr_qual(torch, 'linalg.norm') -> torch.linalg.norm +def _getattr_qual(obj, name, default=_NOTHING): + try: + for path in name.split('.'): + obj = getattr(obj, path) + return obj + except AttributeError: + if default is not _NOTHING: + return default + else: + raise # Classes and methods for the operator database @@ -62,32 +90,69 @@ def __init__(self, dtypesIfCPU=None, # dtypes this function is expected to work with on CPU dtypesIfCUDA=None, # dtypes this function is expected to work with on CUDA dtypesIfROCM=None, # dtypes this function is expected to work with on ROCM + default_test_dtypes=None, # dtypes to test with by default. Gets intersected + # with the dtypes support on the tested device test_inplace_grad=True, # whether to gradcheck and gradgradcheck the inplace variant + test_complex_grad=True, # whether to gradcheck and gradgradcheck for complex dtypes + skip_bfloat16_grad=False, # whether to skip grad and gradgradcheck for bfloat16 dtype + assert_autodiffed=False, # if a op's aten::node is expected to be symbolically autodiffed + autodiff_nonfusible_nodes=None, # a list of strings with node names that are expected to be in a + # DifferentiableGraph when autodiffed. Ex: ['aten::add', 'aten::mm'], + # default is populated to be ['aten::(name of Python operator)'] + autodiff_fusible_nodes=None, # a list of strings with node names that are expected to be in FusionGroups + # inside of DifferentiableGraphs when this operation is autodiffed. + # Ex: ['aten::add', 'aten::mm'], defaults to an empty list + # Note: currently no ops use fusible nodes + output_func=lambda x: x, # fn mapping output to part that should be gradcheck'ed + supports_tensor_out=True, # whether the op supports the out kwarg, returning a Tensor skips=tuple(), # information about which tests to skip - decorators=None): # decorators to apply to generated tests + decorators=None, # decorators to apply to generated tests + promotes_integers_to_float=False, # whether op promotes unary output to float or not + sample_inputs_func=None, # function to generate sample inputs + aten_name=None, # name of the corresponding aten:: operator + variant_test_name='', # additional string to include in the test name + supports_sparse=False # supported for sparse + ): + # Validates the dtypes are generated from the dispatch-related functions for dtype_list in (dtypes, dtypesIfCPU, dtypesIfCUDA, dtypesIfROCM): assert isinstance(dtype_list, (_dispatch_dtypes, type(None))) self.name = name + self.aten_name = aten_name if aten_name is not None else name + self.variant_test_name = variant_test_name - self.dtypes = dtypes - self.dtypesIfCPU = dtypesIfCPU if dtypesIfCPU is not None else dtypes - self.dtypesIfCUDA = dtypesIfCUDA if dtypesIfCUDA is not None else dtypes - self.dtypesIfROCM = dtypesIfROCM if dtypesIfROCM is not None else dtypes + self.dtypes = set(dtypes) + self.dtypesIfCPU = set(dtypesIfCPU) if dtypesIfCPU is not None else self.dtypes + self.dtypesIfCUDA = set(dtypesIfCUDA) if dtypesIfCUDA is not None else self.dtypes + self.dtypesIfROCM = set(dtypesIfROCM) if dtypesIfROCM is not None else self.dtypes + self._default_test_dtypes = set(default_test_dtypes) if default_test_dtypes is not None else None # NOTE: if the op is unspecified it is assumed to be under the torch namespace - if op is None: - assert hasattr(torch, self.name) - self.op = op if op else getattr(torch, self.name) - self.method_variant = getattr(torch.Tensor, name) if hasattr(torch.Tensor, name) else None + self.op = op if op else _getattr_qual(torch, self.name) + self.method_variant = getattr(torch.Tensor, name, None) inplace_name = name + "_" - self.inplace_variant = getattr(torch.Tensor, inplace_name) if hasattr(torch.Tensor, name) else None + self.inplace_variant = getattr(torch.Tensor, inplace_name, None) + self.operator_variant = getattr(operator, name, None) + self.skip_bfloat16_grad = skip_bfloat16_grad self.test_inplace_grad = test_inplace_grad + self.test_complex_grad = test_complex_grad + self.supports_tensor_out = supports_tensor_out + self.promotes_integers_to_float = promotes_integers_to_float self.skips = skips self.decorators = decorators + self.output_func = output_func + self.sample_inputs_func = sample_inputs_func + + self.assert_autodiffed = assert_autodiffed + self.autodiff_fusible_nodes = autodiff_fusible_nodes if autodiff_fusible_nodes else [] + if autodiff_nonfusible_nodes is None: + self.autodiff_nonfusible_nodes = ['aten::' + self.name] + else: + self.autodiff_nonfusible_nodes = autodiff_nonfusible_nodes + self.supports_sparse = supports_sparse def __call__(self, *args, **kwargs): """Calls the function variant of the operator.""" @@ -109,9 +174,19 @@ def get_inplace(self): """ return self.inplace_variant + def get_operator_variant(self): + """Returns operator variant of the operator, e.g. operator.neg + Returns None if the operator has no operator variant. + """ + return self.operator_variant + def sample_inputs(self, device, dtype, requires_grad=False): - """Returns an iterable of SampleInputs.""" - return tuple() + """Returns an iterable of SampleInputs. + + These samples should be sufficient to test the function works correctly + with autograd, TorchScript, etc. + """ + return self.sample_inputs_func(self, device, dtype, requires_grad) # Returns True if the test should be skipped and False otherwise def should_skip(self, cls_name, test_name, device_type, dtype): @@ -128,15 +203,27 @@ def should_skip(self, cls_name, test_name, device_type, dtype): return False - def supports_dtype(self, dtype, device_type): + def supported_dtypes(self, device_type): if device_type == 'cpu': - return dtype in self.dtypesIfCPU + return self.dtypesIfCPU if device_type == 'cuda': - if TEST_WITH_ROCM: - return dtype in self.dtypesIfROCM - return dtype in self.dtypesIfCUDA + return self.dtypesIfROCM if TEST_WITH_ROCM else self.dtypesIfCUDA + else: + return self.dtypes - return dtype in self.dtypes + + def supports_dtype(self, dtype, device_type): + return dtype in self.supported_dtypes(device_type) + + def default_test_dtypes(self, device_type): + """Returns the default dtypes used to test this operator on the device. + + Equal to the operator's default_test_dtypes filtered to remove dtypes + not supported by the device. + """ + supported = self.supported_dtypes(device_type) + return (supported if self._default_test_dtypes is None + else supported.intersection(self._default_test_dtypes)) L = 20 @@ -144,6 +231,18 @@ def supports_dtype(self, dtype, device_type): S = 5 +def sample_inputs_unary(op_info, device, dtype, requires_grad): + low, high = op_info.domain + low = low if low is None else low + op_info._domain_eps + high = high if high is None else high - op_info._domain_eps + + return (SampleInput(make_tensor((L,), device, dtype, + low=low, high=high, + requires_grad=requires_grad)), + SampleInput(make_tensor((), device, dtype, + low=low, high=high, + requires_grad=requires_grad))) + # Metadata class for unary "universal functions (ufuncs)" that accept a single # tensor and have common properties like: class UnaryUfuncInfo(OpInfo): @@ -171,44 +270,576 @@ def __init__(self, handles_large_floats=True, # whether the op correctly handles large float values (like 1e20) handles_extremals=True, # whether the op correctly handles extremal values (like inf) handles_complex_extremals=True, # whether the op correct handles complex extremals (like inf -infj) + supports_complex_to_float=False, # op supports casting from complex input to real output safely eg. angle + sample_inputs_func=sample_inputs_unary, + supports_sparse=False, **kwargs): super(UnaryUfuncInfo, self).__init__(name, dtypes=dtypes, dtypesIfCPU=dtypesIfCPU, dtypesIfCUDA=dtypesIfCUDA, dtypesIfROCM=dtypesIfROCM, + sample_inputs_func=sample_inputs_func, + supports_sparse=supports_sparse, **kwargs) self.ref = ref self.domain = domain self.handles_large_floats = handles_large_floats self.handles_extremals = handles_extremals self.handles_complex_extremals = handles_complex_extremals + self.supports_complex_to_float = supports_complex_to_float # Epsilon to ensure grad and gradgrad checks don't test values # outside a function's domain. self._domain_eps = 1e-5 +def sample_inputs_tensor_split(op_info, device, dtype, requires_grad): + return (SampleInput(make_tensor((S, S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + args=(torch.tensor([1, 2, 3]),),), + SampleInput(make_tensor((S, S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + args=(torch.tensor(1),),), + SampleInput(make_tensor((S, S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + args=(torch.tensor([1, 2, 3]),), + kwargs=dict(dim=1)),) + +def sample_inputs_slogdet(op_info, device, dtype, requires_grad): + # original test cases from 'method_tests' have too many test_inputs + # we don't actually need all of them to check the autograd and jit correctness + # sample inputs with shapes 0x0, 0xSxS, 2x0x0 are added + test_inputs = ( + torch.randn(0, 0, dtype=dtype, device=device), # '0x0' + torch.randn(S, S, dtype=dtype, device=device), # 'SxS' + torch.randn(0, S, S, dtype=dtype, device=device), # 'zero_batched_SxS' + torch.randn(2, 0, 0, dtype=dtype, device=device), # 'batched_0x0' + torch.randn(2, S, S, dtype=dtype, device=device), # 'batched_SxS' + ) + out = [] + for a in test_inputs: + a.requires_grad = requires_grad + out.append(SampleInput(a)) + return out + +def sample_inputs_addmm(op_info, device, dtype, requires_grad): + input = SampleInput((make_tensor((S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + make_tensor((S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + make_tensor((S, S), device, dtype, + low=None, high=None, + requires_grad=False))) + if dtype.is_complex: + another_input = SampleInput((make_tensor((S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + make_tensor((S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + make_tensor((S, S), device, dtype, + low=None, high=None, + requires_grad=False)), + kwargs=dict(beta=1 + 2j, alpha=2 + 3j)) + return (input, another_input) + else: + return (input, ) + + +def sample_inputs_xlogy(self, device, dtype, requires_grad): + return (SampleInput((make_tensor((S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + make_tensor((S, S), device, dtype, + low=0, high=None, + requires_grad=requires_grad))),) + +def sample_inputs_linalg_inv(op_info, device, dtype, requires_grad=False): + """ + This function generates always invertible input for torch.linalg.inv using + random_fullrank_matrix_distinct_singular_value. + The input is generated as the itertools.product of 'batches' and 'ns'. + In total this function generates 8 SampleInputs + 'batches' cases include: + () - single input, + (0,) - zero batched dimension, + (2,) - batch of two matrices, + (2, 3) - 2x3 batch of matrices + 'ns' gives 0x0 and 5x5 matrices. + Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes. + """ + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + + batches = [(), (0, ), (2, ), (2, 3)] + ns = [0, 5] + out = [] + for batch, n in product(batches, ns): + a = random_fullrank_matrix_distinct_singular_value(n, *batch, dtype=dtype).to(device) + a.requires_grad = requires_grad + out.append(SampleInput(a)) + return out + +def np_sinc_with_fp16_as_fp32(x): + # Wraps numpy's sinc function so that fp16 values are promoted to fp32 + # before sinc is invoked. Context: numpy's sinc returns NaN when evaluated + # at 0 for fp16. + if x.dtype == np.float16: + return np.sinc(x.astype(np.float32)) + else: + return np.sinc(x) + +def sample_inputs_broadcast_to(op_info, device, dtype, requires_grad): + test_cases = ( + ((S, 1, 1), (S, S, S)), + ((S, 1, S), (S, S, S)), + ((S, 1), (S, S, S)), + ((1,), (S, S, S)), + ((1, S), (1, 1, S)), + ((), ()), + ((), (1, 3, 2)), + ) + + return tuple(SampleInput((make_tensor(size, device, dtype, + low=None, high=None, + requires_grad=requires_grad), shape)) + for size, shape in test_cases) + +def sample_inputs_stack(op_info, device, dtype, requires_grad): + return (SampleInput((make_tensor((S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + make_tensor((S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + make_tensor((S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad)), kwargs=dict(idx=0)),) + +def sample_inputs_hstack_dstack_vstack(op_info, device, dtype, requires_grad): + return (SampleInput((make_tensor((S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + make_tensor((S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + make_tensor((S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad))),) + +def sample_inputs_gather(op_info, device, dtype, requires_grad): + return (SampleInput((make_tensor((M, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + 0, gather_variable((S, S), 1, M, True, device=device))), + SampleInput((make_tensor((M, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + 1, gather_variable((M, S // 2), 0, S, True, device=device))), + SampleInput((make_tensor((), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + 0, torch.tensor([0], dtype=torch.int64, device=device))), + SampleInput((make_tensor((S,), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + 0, torch.tensor(0, dtype=torch.int64, device=device))), + SampleInput((make_tensor((), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + 0, torch.tensor(0, dtype=torch.int64, device=device))), + ) + + +def sample_inputs_index_select(op_info, device, dtype, requires_grad): + return (SampleInput((make_tensor((S, S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + 0, index_variable(2, S, device=device))), + SampleInput((make_tensor((), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + 0, torch.tensor([0], dtype=torch.int64, device=device))), + SampleInput((make_tensor((), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + 0, torch.tensor(0, dtype=torch.int64, device=device))), + ) + +def sample_movedim_moveaxis(op_info, device, dtype, requires_grad): + return (SampleInput((make_tensor((4, 3, 2, 1), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + (0, 1, 2, 3), (3, 2, 1, 0))), + SampleInput((make_tensor((4, 3, 2, 1), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + (0, -1, -2, -3), (-3, -2, -1, -0)))) + + +def sample_repeat_tile(op_info, device, dtype, requires_grad): + rep_dims = ((), (0, ), (1, ), (0, 2), (1, 1), (2, 3), (2, 3, 2), (0, 2, 3), (2, 1, 1, 1),) + shapes = ((), (0,), (2,), (3, 0), (3, 2), (3, 0, 1)) + + if requires_grad: + # Tests for variant_consistency_jit, grad, gradgrad + # are slower. Use smaller bags of `rep_dims` and `shapes` + # in this case. + rep_dims = ((), (0, ), (0, 2), (1, 1), (2, 3), (1, 3, 2), (3, 1, 1)) # type: ignore + shapes = ((), (0,), (2,), (3, 2)) # type: ignore + + tensors = [make_tensor(shape, device, dtype, + low=None, high=None, + requires_grad=requires_grad) for shape in shapes] + + samples = [] + for rep_dim, tensor in product(rep_dims, tensors): + for t in (tensor, tensor.T): + if op_info.name == 'repeat' and len(rep_dim) >= t.dim(): + # `torch.repeat` errors for `len(rep_dims) < t.dim()`, + # so we filter such combinations. + samples.append(SampleInput((t, rep_dim),)) + elif op_info.name == 'tile': + samples.append(SampleInput((t, rep_dim),)) + + return samples + +def np_unary_ufunc_integer_promotion_wrapper(fn): + # Wrapper that passes PyTorch's default scalar + # type as an argument to the wrapped NumPy + # unary ufunc when given an integer input. + # This mimicks PyTorch's integer->floating point + # type promotion. + # + # This is necessary when NumPy promotes + # integer types to double, since PyTorch promotes + # integer types to the default scalar type. + + # Helper to determine if promotion is needed + def is_integral(dtype): + return dtype in [np.bool, np.uint8, np.int8, np.int16, np.int32, np.int64] + + # NOTE: Promotion in PyTorch is from integer types to the default dtype + np_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()] + + @wraps(fn) + def wrapped_fn(x): + if is_integral(x.dtype): + return fn(x, dtype=np_dtype) + return fn(x) + + return wrapped_fn + + +# Metadata class for Fast Fourier Transforms in torch.fft. +class SpectralFuncInfo(OpInfo): + """Operator information for torch.fft transforms. """ + + def __init__(self, + name, # the string name of the function + *, + ref=None, # Reference implementation (probably in np.fft namespace) + dtypes=floating_and_complex_types(), + ndimensional: bool, # Whether dim argument can be a tuple + skips=None, + decorators=None, + **kwargs): + skips = skips if skips is not None else [] + + # gradgrad is quite slow + if not TEST_WITH_SLOW: + skips.append(SkipInfo('TestGradients', 'test_fn_gradgrad')) + + decorators = decorators if decorators is not None else [] + decorators += [skipCPUIfNoMkl, skipCUDAIfRocm] + + super().__init__(name=name, + dtypes=dtypes, + skips=skips, + decorators=decorators, + **kwargs) + self.ref = ref if ref is not None else _getattr_qual(np, name) + self.ndimensional = ndimensional + + def sample_inputs(self, device, dtype, requires_grad=False): - low, high = self.domain - low = low if low is None else low + self._domain_eps - high = high if high is None else high - self._domain_eps + nd_tensor = make_tensor((S, S + 1, S + 2), device, dtype, low=None, high=None, + requires_grad=requires_grad) + tensor = make_tensor((31,), device, dtype, low=None, high=None, + requires_grad=requires_grad) + + if self.ndimensional: + return [ + SampleInput(nd_tensor, kwargs=dict(s=(3, 10), dim=(1, 2), norm='ortho')), + SampleInput(nd_tensor, kwargs=dict(norm='ortho')), + SampleInput(nd_tensor, kwargs=dict(s=(8,))), + SampleInput(tensor), + + *(SampleInput(nd_tensor, kwargs=dict(dim=dim)) + for dim in [-1, -2, -3, (0, -1)]), + ] + else: + return [ + SampleInput(nd_tensor, kwargs=dict(n=10, dim=1, norm='ortho')), + SampleInput(nd_tensor, kwargs=dict(norm='ortho')), + SampleInput(nd_tensor, kwargs=dict(n=7)), + SampleInput(tensor), + + *(SampleInput(nd_tensor, kwargs=dict(dim=dim)) + for dim in [-1, -2, -3]), + ] + + +class ShapeFuncInfo(OpInfo): + """Early version of a specialized OpInfo for Shape manipulating operations like tile and roll""" + def __init__(self, + name, # the string name of the function + *, + ref, # a reference function + dtypes=floating_types(), + dtypesIfCPU=None, + dtypesIfCUDA=None, + dtypesIfROCM=None, + sample_inputs_func=None, + **kwargs): + super(ShapeFuncInfo, self).__init__(name, + dtypes=dtypes, + dtypesIfCPU=dtypesIfCPU, + dtypesIfCUDA=dtypesIfCUDA, + dtypesIfROCM=dtypesIfROCM, + sample_inputs_func=sample_inputs_func, + **kwargs) + self.ref = ref + + +class HermitianOpInfo(OpInfo): + """Operator information for Hermitian functions + These are functions that take Hermitian matrices as input. + They require a modified function to be tested for gradcheck, because the finite-difference algorithm + for calculating derivatives does not preserve the Hermitian property of the input and returning incorrect results. + """ + + def get_op(self): + """ + Returns the function variant of the operator, torch., + compatible with gradcheck for Hermitian functions. + It works only for single input argument. + """ + def hermitian_func(non_hermitian_input, **kwargs): + hermitian_input = non_hermitian_input + non_hermitian_input.conj().transpose(-2, -1) + return self.op(hermitian_input, **kwargs) + + return hermitian_func + + +def sample_inputs_linalg_pinv(op_info, device, dtype, requires_grad=False): + """ + This function generates input for torch.linalg.pinv with distinct singular values so that autograd is always stable + Implementation of torch.linalg.pinv depends on torch.svd and torch.linalg.eigh, therefore it's sufficient to + check only square S x S matrix and the batched (3 x S x S) input. + """ + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + + test_cases = ( + random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device), # single matrix + random_fullrank_matrix_distinct_singular_value(S, 3, dtype=dtype).to(device), # batch of matrices + ) + + out = [] + for a in test_cases: + a.requires_grad = requires_grad + out.append(SampleInput(a)) + return out + + +def sample_inputs_linalg_pinv_hermitian(op_info, device, dtype, requires_grad=False): + """ + This function generates input for torch.linalg.pinv with hermitian=True keyword argument. + """ + out = sample_inputs_linalg_pinv(op_info, device, dtype, requires_grad) + for o in out: + o.kwargs = {"hermitian": True} + return out + +def sample_inputs_linalg_solve(op_info, device, dtype, requires_grad=False): + """ + This function generates always solvable input for torch.linalg.solve + Using random_fullrank_matrix_distinct_singular_value gives a non-singular (=invertible, =solvable) matrices 'a'. + The first input to torch.linalg.solve is generated as the itertools.product of 'batches' and 'ns'. + The second input is generated as the product of 'batches', 'ns' and 'nrhs'. + In total this function generates 18 SampleInputs + 'batches' cases include: + () - single input, + (0,) - zero batched dimension, + (2,) - batch of two matrices. + 'ns' gives 0x0 and 5x5 matrices. + and 'nrhs' controls the number of vectors to solve for: + () - using 1 as the number of vectors implicitly + (1,) - same as () but explicit + (3,) - solve for 3 vectors. + Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes. + """ + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + + batches = [(), (0, ), (2, )] + ns = [0, 5] + nrhs = [(), (1, ), (3, )] + out = [] + for n, batch, rhs in product(ns, batches, nrhs): + a = random_fullrank_matrix_distinct_singular_value(n, *batch, dtype=dtype).to(device) + a.requires_grad = requires_grad + b = torch.randn(*batch, n, *rhs, dtype=dtype, device=device) + b.requires_grad = requires_grad + out.append(SampleInput((a, b))) + return out + + +def _sample_inputs_svd(op_info, device, dtype, requires_grad=False, is_linalg_svd=False): + """ + This function generates input for torch.svd with distinct singular values so that autograd is always stable. + Matrices of different size: + square matrix - S x S size + tall marix - S x (S-2) + wide matrix - (S-2) x S + and batched variants of above are generated. + Each SampleInput has a function 'output_process_fn_grad' attached to it that is applied on the output of torch.svd + It is needed for autograd checks, because backward of svd doesn't work for an arbitrary loss function. + """ + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + + # svd and linalg.svd returns V and V.T, respectively. So we need to slice + # along different dimensions when needed (this is used by + # test_cases2:wide_all and wide_all_batched below) + if is_linalg_svd: + def slice_V(v): + return v[..., :(S - 2), :] + else: + def slice_V(v): + return v[..., :, :(S - 2)] + + test_cases1 = ( # some=True (default) + # loss functions for complex-valued svd have to be "gauge invariant", + # i.e. loss functions shouldn't change when sigh of the singular vectors change. + # the simplest choice to satisfy this requirement is to apply 'abs'. + (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device), + lambda usv: usv[1]), # 'check_grad_s' + (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device), + lambda usv: abs(usv[0])), # 'check_grad_u' + (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device), + lambda usv: abs(usv[2])), # 'check_grad_v' + # TODO: replace lambda usv: usv[0][0, 0] * usv[2][0, 0] with lambda usv: usv[0][0, 0] * usv[2][0, 0].conj() + # once https://github.com/pytorch/pytorch/issues/45821 is resolved + # this test is important as it checks the additional term that is non-zero only for complex-valued inputs + # and when the loss function depends both on 'u' and 'v' + (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device), + lambda usv: usv[0][0, 0] * usv[2][0, 0]), # 'check_grad_uv' + (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:(S - 2)], + lambda usv: (abs(usv[0]), usv[1], abs(usv[2][..., :, :(S - 2)]))), # 'wide' + (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:, :(S - 2)], + lambda usv: (abs(usv[0]), usv[1], abs(usv[2]))), # 'tall' + (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device), + lambda usv: (abs(usv[0]), usv[1], abs(usv[2]))), # 'batched' + (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device)[..., :(S - 2), :], + lambda usv: (abs(usv[0]), usv[1], abs(usv[2]))), # 'wide_batched' + (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device)[..., :, :(S - 2)], + lambda usv: (abs(usv[0]), usv[1], abs(usv[2]))), # 'tall_batched' + ) + test_cases2 = ( # some=False + (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:(S - 2)], + lambda usv: (abs(usv[0]), usv[1], abs(slice_V(usv[2])))), # 'wide_all' + (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:, :(S - 2)], + lambda usv: (abs(usv[0][:, :(S - 2)]), usv[1], abs(usv[2]))), # 'tall_all' + (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device)[..., :(S - 2), :], + lambda usv: (abs(usv[0]), usv[1], abs(slice_V(usv[2])))), # 'wide_all_batched' + (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device)[..., :, :(S - 2)], + lambda usv: (abs(usv[0][..., :, :(S - 2)]), usv[1], abs(usv[2]))), # 'tall_all_batched' + ) + + out = [] + for a, out_fn in test_cases1: + a.requires_grad = requires_grad + if is_linalg_svd: + kwargs = {'full_matrices': False} + else: + kwargs = {'some': True} + out.append(SampleInput(a, kwargs=kwargs, output_process_fn_grad=out_fn)) + + for a, out_fn in test_cases2: + a.requires_grad = requires_grad + if is_linalg_svd: + kwargs = {'full_matrices': True} + else: + kwargs = {'some': False} + out.append(SampleInput(a, kwargs=kwargs, output_process_fn_grad=out_fn)) + + return out + +def sample_inputs_svd(op_info, device, dtype, requires_grad=False): + return _sample_inputs_svd(op_info, device, dtype, requires_grad, is_linalg_svd=False) + +def sample_inputs_linalg_svd(op_info, device, dtype, requires_grad=False): + return _sample_inputs_svd(op_info, device, dtype, requires_grad, is_linalg_svd=True) + +def sample_inputs_pinverse(op_info, device, dtype, requires_grad=False): + """ + This function generates input for torch.pinverse with distinct singular values so that autograd is always stable. + Implementation of torch.pinverse depends on torch.svd, therefore it's sufficient to check only square S x S matrix + and the batched (3 x S x S) input. + """ + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + + test_cases = ( + random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device), # pinverse + random_fullrank_matrix_distinct_singular_value(S, 3, dtype=dtype).to(device), # pinverse 'batched' + ) + + out = [] + for a in test_cases: + a.requires_grad = requires_grad + out.append(SampleInput(a)) + return out - return (SampleInput(make_tensor((L,), device, dtype, - low=low, high=high, - requires_grad=requires_grad)),) +def sample_inputs_flip(op_info, device, dtype, requires_grad): + tensors = ( + make_tensor((S, M, S), device, dtype, low=None, high=None, requires_grad=requires_grad), + make_tensor((S, 0, M), device, dtype, low=None, high=None, requires_grad=requires_grad) + ) + dims = ((0, 1, 2), (0,), (0, 2), (-1,), ()) + + samples = [SampleInput(tensor, kwargs={'dims': dim}) for tensor, dim in product(tensors, dims)] + + return samples + +def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad): + tensors = ( + make_tensor((S, M, S), device, dtype, low=None, high=None, requires_grad=requires_grad), + make_tensor((S, 0, M), device, dtype, low=None, high=None, requires_grad=requires_grad) + ) + return [SampleInput(tensor) for tensor in tensors] # Operator database (sorted alphabetically) -op_db = [ +op_db: List[OpInfo] = [ # NOTE: CPU complex acos produces incorrect outputs (https://github.com/pytorch/pytorch/issues/42952) UnaryUfuncInfo('acos', ref=np.arccos, domain=(-1, 1), handles_complex_extremals=False, + dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), + default_test_dtypes=[torch.long, torch.half, torch.bfloat16, torch.float32, torch.cfloat], + skip_bfloat16_grad=True, + assert_autodiffed=True, decorators=(precisionOverride({torch.float16: 1e-2, torch.bfloat16: 1e-1, torch.complex64: 1e-2}),), + promotes_integers_to_float=True, skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), @@ -228,31 +859,85 @@ def sample_inputs(self, device, dtype, requires_grad=False): UnaryUfuncInfo('acosh', ref=np.arccosh, domain=(1, float('inf')), - dtypesIfCPU=floating_types(), - dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCPU=all_types_and_complex_and(torch.bool), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + promotes_integers_to_float=True, decorators=(precisionOverride({torch.bfloat16: 5e-2}),), - test_inplace_grad=False), + test_inplace_grad=False, + skips=( + # RuntimeError: "rsqrt_cuda" not implemented for 'BFloat16' + SkipInfo('TestCommon', 'test_variant_consistency_jit', + device_type='cuda', dtypes=[torch.bfloat16]), + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + # Reference: https://github.com/pytorch/pytorch/issues/50692 + SkipInfo('TestGradients', 'test_fn_grad', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + SkipInfo('TestGradients', 'test_method_grad', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + )), + OpInfo('addmm', + dtypes=floating_types(), + dtypesIfCPU=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128, + *[torch.bfloat16] if tf32_is_not_fp32() else []), + dtypesIfROCM=floating_types_and(torch.half), + assert_autodiffed=True, + autodiff_nonfusible_nodes=['aten::add', 'aten::mm'], + skips=( + SkipInfo('TestCommon', 'test_variant_consistency_jit', + dtypes=[torch.bfloat16, torch.float16, torch.cfloat, torch.cdouble]),), + sample_inputs_func=sample_inputs_addmm), UnaryUfuncInfo('asin', ref=np.arcsin, domain=(-1, 1), + supports_sparse=True, decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + promotes_integers_to_float=True, + dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), + assert_autodiffed=True, + skip_bfloat16_grad=True, skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], - active_if=IS_WINDOWS), + active_if=IS_WINDOWS) )), # NOTE: derivative for inplace asinh is not implemented UnaryUfuncInfo('asinh', ref=np.arcsinh, - dtypesIfCPU=floating_types(), - dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCPU=all_types_and_complex_and(torch.bool), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + promotes_integers_to_float=True, decorators=(precisionOverride({torch.bfloat16: 5e-2}),), - test_inplace_grad=False), + test_inplace_grad=False, + skips=( + # RuntimeError: "rsqrt_cuda" not implemented for 'BFloat16' + SkipInfo('TestCommon', 'test_variant_consistency_jit', + device_type='cuda', dtypes=[torch.bfloat16]), + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + )), UnaryUfuncInfo('atan', ref=np.arctan, + dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), + assert_autodiffed=True, + skip_bfloat16_grad=True, decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + promotes_integers_to_float=True, skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), @@ -263,14 +948,33 @@ def sample_inputs(self, device, dtype, requires_grad=False): UnaryUfuncInfo('atanh', ref=np.arctanh, domain=(-1, 1), - dtypesIfCPU=floating_types(), - dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCPU=all_types_and_complex_and(torch.bool), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + promotes_integers_to_float=True, decorators=(precisionOverride({torch.bfloat16: 1e-2}),), - test_inplace_grad=False), + test_inplace_grad=False, + skips=( + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + )), + OpInfo('broadcast_to', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_tensor_out=False, + test_inplace_grad=False, + sample_inputs_func=sample_inputs_broadcast_to), UnaryUfuncInfo('cos', ref=np.cos, - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + skip_bfloat16_grad=True, handles_large_floats=False, + promotes_integers_to_float=True, decorators=(precisionOverride({torch.bfloat16: 1e-2}),), skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', @@ -281,18 +985,159 @@ def sample_inputs(self, device, dtype, requires_grad=False): dtypes=[torch.float], active_if=TEST_WITH_ROCM), )), UnaryUfuncInfo('cosh', - ref=np.cosh, - dtypesIfCPU=floating_and_complex_types(), + ref=np_unary_ufunc_integer_promotion_wrapper(np.cosh), + dtypesIfCPU=all_types_and_complex_and(torch.bool), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), + promotes_integers_to_float=True, + assert_autodiffed=True, skips=( + # Reference: https://github.com/pytorch/pytorch/issues/48641 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cpu', dtypes=[torch.int8]), SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + SkipInfo('TestCommon', 'test_variant_consistency_jit', + device_type='cuda', dtypes=[torch.float16]), )), + UnaryUfuncInfo('exp', + ref=np_unary_ufunc_integer_promotion_wrapper(np.exp), + dtypes=all_types_and_complex_and(torch.bool, torch.half), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/50093#pullrequestreview-561791547 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', dtypes=[torch.bfloat16]), + # Reference: https://github.com/pytorch/pytorch/issues/48010 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + ), + assert_autodiffed=True, + promotes_integers_to_float=True), + SpectralFuncInfo('fft.fft', + aten_name='fft_fft', + ref=np.fft.fft, + ndimensional=False, + dtypes=all_types_and_complex_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=True, + test_inplace_grad=False,), + SpectralFuncInfo('fft.fftn', + aten_name='fft_fftn', + ref=np.fft.fftn, + ndimensional=True, + dtypes=all_types_and_complex_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=True, + test_inplace_grad=False, + decorators=[precisionOverride( + {torch.float: 1e-4, torch.cfloat: 1e-4})],), + SpectralFuncInfo('fft.hfft', + aten_name='fft_hfft', + ref=np.fft.hfft, + ndimensional=False, + dtypes=all_types_and_complex_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=True, + test_inplace_grad=False,), + SpectralFuncInfo('fft.rfft', + aten_name='fft_rfft', + ref=np.fft.rfft, + ndimensional=False, + dtypes=all_types_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=True, + test_inplace_grad=False,), + SpectralFuncInfo('fft.rfftn', + aten_name='fft_rfftn', + ref=np.fft.rfftn, + ndimensional=True, + dtypes=all_types_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=True, + test_inplace_grad=False, + decorators=[precisionOverride({torch.float: 1e-4})],), + SpectralFuncInfo('fft.ifft', + aten_name='fft_ifft', + ref=np.fft.ifft, + ndimensional=False, + dtypes=all_types_and_complex_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=True, + test_inplace_grad=False,), + SpectralFuncInfo('fft.ifftn', + aten_name='fft_ifftn', + ref=np.fft.ifftn, + ndimensional=True, + dtypes=all_types_and_complex_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=True, + test_inplace_grad=False,), + SpectralFuncInfo('fft.ihfft', + aten_name='fft_ihfft', + ref=np.fft.ihfft, + ndimensional=False, + dtypes=all_types_and(torch.bool), + default_test_dtypes=floating_types(), + supports_tensor_out=True, + test_inplace_grad=False,), + SpectralFuncInfo('fft.irfft', + aten_name='fft_irfft', + ref=np.fft.irfft, + ndimensional=False, + dtypes=all_types_and_complex_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=True, + test_inplace_grad=False,), + SpectralFuncInfo('fft.irfftn', + aten_name='fft_irfftn', + ref=np.fft.irfftn, + ndimensional=True, + dtypes=all_types_and_complex_and(torch.bool), + default_test_dtypes=floating_and_complex_types(), + supports_tensor_out=True, + test_inplace_grad=False,), + OpInfo('flip', + op=torch.flip, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_flip, + test_inplace_grad=False, + supports_tensor_out=False), + OpInfo('fliplr', + op=torch.fliplr, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_fliplr_flipud, + test_inplace_grad=False, + supports_tensor_out=False), + OpInfo('flipud', + op=torch.flipud, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_fliplr_flipud, + test_inplace_grad=False, + supports_tensor_out=False), + OpInfo('linalg.slogdet', + aten_name='linalg_slogdet', + op=torch.linalg.slogdet, + dtypes=floating_and_complex_types(), + test_inplace_grad=False, + supports_tensor_out=False, + sample_inputs_func=sample_inputs_slogdet, + output_func=itemgetter(1), + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], + skips=( + # These tests do not work with output_func=itemgetter(1) + # TODO: remove this once https://github.com/pytorch/pytorch/issues/49326 is resolved + SkipInfo('TestCommon', 'test_variant_consistency_jit'),)), UnaryUfuncInfo('log', ref=np.log, domain=(0, float('inf')), - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + skip_bfloat16_grad=True, + promotes_integers_to_float=True, decorators=(precisionOverride({torch.bfloat16: 5e-2}),), skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', @@ -307,7 +1152,12 @@ def sample_inputs(self, device, dtype, requires_grad=False): ref=np.log10, domain=(0, float('inf')), decorators=(precisionOverride({torch.bfloat16: 5e-2}),), - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + assert_autodiffed=True, + skip_bfloat16_grad=True, + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + promotes_integers_to_float=True, skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]), @@ -318,13 +1168,21 @@ def sample_inputs(self, device, dtype, requires_grad=False): UnaryUfuncInfo('log1p', ref=np.log1p, domain=(-1, float('inf')), - dtypesIfCPU=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), - decorators=(precisionOverride({torch.bfloat16: 1e-1}),)), + dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + decorators=(precisionOverride({torch.bfloat16: 1e-1}),), + promotes_integers_to_float=True, + assert_autodiffed=True, + skip_bfloat16_grad=True), UnaryUfuncInfo('log2', ref=np.log2, domain=(0, float('inf')), - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + skip_bfloat16_grad=True, + promotes_integers_to_float=True, decorators=(precisionOverride({torch.bfloat16: 1e-1}),), skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', @@ -334,13 +1192,21 @@ def sample_inputs(self, device, dtype, requires_grad=False): )), UnaryUfuncInfo('neg', ref=np.negative, + skip_bfloat16_grad=True, dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), dtypesIfCPU=all_types_and_complex_and(torch.half, torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.half)), + dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), + assert_autodiffed=True,), UnaryUfuncInfo('sin', ref=np.sin, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), + assert_autodiffed=True, + skip_bfloat16_grad=True, handles_large_floats=False, handles_complex_extremals=False, + promotes_integers_to_float=True, decorators=(precisionOverride({torch.bfloat16: 1e-2}),), skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', @@ -348,9 +1214,32 @@ def sample_inputs(self, device, dtype, requires_grad=False): SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', dtypes=[torch.float], active_if=TEST_WITH_ROCM), )), + UnaryUfuncInfo('sinc', + ref=np_sinc_with_fp16_as_fp32, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), + skip_bfloat16_grad=True, + handles_large_floats=False, + handles_complex_extremals=False, + promotes_integers_to_float=True, + decorators=(precisionOverride({torch.bfloat16: 1e-2, + torch.float16: 1e-2}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/49133 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + dtypes=[torch.cfloat]), + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + dtypes=[torch.float], active_if=TEST_WITH_ROCM), + )), UnaryUfuncInfo('sinh', - ref=np.sinh, - dtypesIfCPU=floating_and_complex_types(), + ref=np_unary_ufunc_integer_promotion_wrapper(np.sinh), + dtypesIfCPU=all_types_and_complex_and(torch.bool), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), + promotes_integers_to_float=True, + assert_autodiffed=True, decorators=(precisionOverride({torch.float16: 1e-2}),), skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', @@ -359,9 +1248,20 @@ def sample_inputs(self, device, dtype, requires_grad=False): SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + # Reference: https://github.com/pytorch/pytorch/issues/48641 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cpu', dtypes=[torch.int8]), + SkipInfo('TestCommon', 'test_variant_consistency_jit', + device_type='cuda', dtypes=[torch.float16]), )), UnaryUfuncInfo('tan', ref=np.tan, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), + assert_autodiffed=True, + skip_bfloat16_grad=True, + promotes_integers_to_float=True, skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]), @@ -370,11 +1270,19 @@ def sample_inputs(self, device, dtype, requires_grad=False): SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cuda', dtypes=[torch.float64], + active_if=TEST_WITH_ROCM), )), UnaryUfuncInfo('tanh', ref=np.tanh, decorators=(precisionOverride({torch.bfloat16: 1e-2}),), - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + skip_bfloat16_grad=True, + promotes_integers_to_float=True, skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]), @@ -382,20 +1290,357 @@ def sample_inputs(self, device, dtype, requires_grad=False): device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), )), + OpInfo('tensor_split', + dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + supports_tensor_out=False, + test_inplace_grad=False, + sample_inputs_func=sample_inputs_tensor_split,), UnaryUfuncInfo('exp2', - ref=np.exp2, - dtypes=floating_types_and(torch.half), + ref=np_unary_ufunc_integer_promotion_wrapper(np.exp2), + dtypes=all_types_and(torch.bool, torch.half), + dtypesIfCPU=all_types_and(torch.bool, torch.half), + dtypesIfCUDA=all_types_and(torch.bool, torch.half), + promotes_integers_to_float=True), + UnaryUfuncInfo('expm1', + ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1), + dtypes=all_types_and(torch.bool, torch.half), + dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half), + promotes_integers_to_float=True, + assert_autodiffed=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/48926#issuecomment-739734774 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cpu', dtypes=[torch.bfloat16]), + )), + UnaryUfuncInfo('nan_to_num', + ref=np.nan_to_num, + dtypes=all_types_and(torch.half, torch.bool), dtypesIfCPU=None, - dtypesIfCUDA=None) + dtypesIfCUDA=None), + UnaryUfuncInfo('reciprocal', + ref=np_unary_ufunc_integer_promotion_wrapper(np.reciprocal), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCPU=None, + dtypesIfCUDA=None, + assert_autodiffed=True, + skip_bfloat16_grad=True, + promotes_integers_to_float=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/45690 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + dtypes=[torch.cfloat, torch.cdouble]), + # Reference: https://github.com/pytorch/pytorch/pull/49102#issuecomment-744604601 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + dtypes=[torch.bfloat16]), + )), + UnaryUfuncInfo('rsqrt', + ref=lambda x: np.reciprocal(np.sqrt(x)), + domain=(0, float('inf')), + dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCPU=all_types_and_complex_and(torch.bool), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), + decorators=(precisionOverride({torch.half: 5e-2}),), + promotes_integers_to_float=True, + assert_autodiffed=True, + handles_complex_extremals=False), + UnaryUfuncInfo('sqrt', + ref=np.sqrt, + domain=(0, float('inf')), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + skip_bfloat16_grad=True, + decorators=(precisionOverride({torch.bfloat16: 7e-2}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/47358 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_MACOS), + # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + dtypes=[torch.bfloat16])), + promotes_integers_to_float=True, + handles_complex_extremals=False), + OpInfo('linalg.inv', + aten_name='linalg_inv', + op=torch.linalg.inv, + dtypes=floating_and_complex_types(), + test_inplace_grad=False, + supports_tensor_out=True, + sample_inputs_func=sample_inputs_linalg_inv, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]), + UnaryUfuncInfo('angle', + ref=np.angle, + dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool), + dtypesIfROCM=all_types_and_complex_and(torch.bool), + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2}),), + promotes_integers_to_float=True, + supports_complex_to_float=True, + test_inplace_grad=False), + OpInfo('linalg.solve', + aten_name='linalg_solve', + op=torch.linalg.solve, + dtypes=floating_and_complex_types(), + test_inplace_grad=False, + supports_tensor_out=True, + sample_inputs_func=sample_inputs_linalg_solve, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]), + OpInfo('linalg.pinv', + aten_name='linalg_pinv', + op=torch.linalg.pinv, + dtypes=floating_and_complex_types(), + test_inplace_grad=False, + supports_tensor_out=False, + sample_inputs_func=sample_inputs_linalg_pinv, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]), + HermitianOpInfo('linalg.pinv', + variant_test_name='hermitian', + aten_name='linalg_pinv', + op=torch.linalg.pinv, + dtypes=floating_and_complex_types(), + test_inplace_grad=False, + supports_tensor_out=False, + sample_inputs_func=sample_inputs_linalg_pinv_hermitian, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], + skips=( + # These tests do not take into account custom op.get_op() + SkipInfo('TestCommon', 'test_variant_consistency_jit'),) + ), + OpInfo('svd', + op=torch.svd, + dtypes=floating_and_complex_types(), + test_inplace_grad=False, + supports_tensor_out=False, + sample_inputs_func=sample_inputs_svd, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], + skips=( + # gradgrad checks are slow + SkipInfo('TestGradients', 'test_fn_gradgrad', active_if=(not TEST_WITH_SLOW)), + # cuda gradchecks are very slow + # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775 + SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))), + OpInfo('linalg.svd', + op=torch.linalg.svd, + aten_name='linalg_svd', + dtypes=floating_and_complex_types(), + test_inplace_grad=False, + supports_tensor_out=False, + sample_inputs_func=sample_inputs_linalg_svd, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], + skips=( + # gradgrad checks are slow + SkipInfo('TestGradients', 'test_fn_gradgrad', active_if=(not TEST_WITH_SLOW)), + # cuda gradchecks are very slow + # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775 + SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))), + OpInfo('pinverse', + op=torch.pinverse, + dtypes=floating_and_complex_types(), + test_inplace_grad=False, + supports_tensor_out=False, + sample_inputs_func=sample_inputs_linalg_pinv, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]), + OpInfo('gather', + dtypes=all_types_and_complex_and(torch.bool, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + test_inplace_grad=False, + sample_inputs_func=sample_inputs_gather), + OpInfo('index_select', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + test_inplace_grad=False, + skips=( + # https://github.com/pytorch/pytorch/issues/49707 + SkipInfo('TestCommon', 'test_variant_consistency_eager', + dtypes=[torch.float16, torch.bfloat16]), + SkipInfo('TestCommon', 'test_variant_consistency_jit', dtypes=[torch.float16, torch.bfloat16]), + ), + sample_inputs_func=sample_inputs_index_select), + OpInfo('stack', + # gradcheck expects the input arguments as a flat list + op=lambda *args, idx: torch.stack([*args], idx), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + test_inplace_grad=False, + supports_tensor_out=False, + skips=( + SkipInfo('TestCommon', 'test_variant_consistency_jit', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16)), + ), + sample_inputs_func=sample_inputs_stack), + OpInfo('hstack', + # gradcheck expects the input arguments as a flat list + op=lambda *args: torch.hstack([*args]), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + test_inplace_grad=False, + supports_tensor_out=False, + skips=( + SkipInfo('TestCommon', 'test_variant_consistency_jit', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16)), + ), + sample_inputs_func=sample_inputs_hstack_dstack_vstack), + OpInfo('vstack', + # gradcheck expects the input arguments as a flat list + op=lambda *args: torch.vstack([*args]), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + test_inplace_grad=False, + supports_tensor_out=False, + skips=( + SkipInfo('TestCommon', 'test_variant_consistency_jit', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16)), + ), + sample_inputs_func=sample_inputs_hstack_dstack_vstack), + OpInfo('dstack', + # gradcheck expects the input arguments as a flat list + op=lambda *args: torch.dstack([*args]), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + test_inplace_grad=False, + supports_tensor_out=False, + skips=( + SkipInfo('TestCommon', 'test_variant_consistency_jit', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16)), + ), + sample_inputs_func=sample_inputs_hstack_dstack_vstack), + OpInfo('movedim', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + test_inplace_grad=False, + supports_tensor_out=False, + sample_inputs_func=sample_movedim_moveaxis), + OpInfo('moveaxis', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + test_inplace_grad=False, + supports_tensor_out=False, + sample_inputs_func=sample_movedim_moveaxis), + ShapeFuncInfo('repeat', + op=lambda x, dims: x.repeat(dims), + ref=np.tile, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_tensor_out=False, + test_inplace_grad=False, + skips=( + # torch.repeat does not exist so we get a RuntimeError. + SkipInfo('TestCommon', 'test_variant_consistency_jit', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16)), + ), + sample_inputs_func=sample_repeat_tile), + ShapeFuncInfo('tile', + ref=np.tile, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_tensor_out=False, + test_inplace_grad=False, + sample_inputs_func=sample_repeat_tile), ] +if TEST_SCIPY: + def reference_sigmoid(x): + # 'scipy.special.expit' not supported for the input types + if x.dtype in [np.complex64, np.complex128]: + return (1 / (1 + np.exp(-x))) + return scipy.special.expit(x) + + op_db_scipy_reference: List[OpInfo] = [ + UnaryUfuncInfo('sigmoid', + ref=reference_sigmoid, + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2}),), + skips=( + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + # RuntimeError: sigmoid does not support automatic differentiation for outputs with complex dtype. + SkipInfo('TestCommon', 'test_variant_consistency_jit', + dtypes=[torch.complex64, torch.complex128]), + SkipInfo('TestCommon', 'test_variant_consistency_eager', + dtypes=[torch.complex64, torch.complex128]),), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + promotes_integers_to_float=True, + assert_autodiffed=True, + test_complex_grad=False), # Reference: https://github.com/pytorch/pytorch/issues/48552 + UnaryUfuncInfo('digamma', + ref=scipy.special.digamma, + decorators=(precisionOverride({torch.float16: 5e-1}),), + dtypes=all_types_and(torch.bool), + dtypesIfCPU=all_types_and(torch.bool), + dtypesIfCUDA=all_types_and(torch.bool, torch.half), + skips=( + # In some cases, output is NaN (for input close to + # negative integers) especially due to reduced precision + # in float16 and NaN's can't be tested for equality. + SkipInfo('TestCommon', 'test_variant_consistency_jit', + device_type='cuda', dtypes=[torch.float16]),), + promotes_integers_to_float=True), + UnaryUfuncInfo('erf', + ref=scipy.special.erf, + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2}),), + dtypes=all_types_and(torch.bool), + dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + skips=( + # RuntimeError: "pow" not implemented for 'BFloat16' + SkipInfo('TestCommon', 'test_variant_consistency_jit', + dtypes=[torch.bfloat16]),), + assert_autodiffed=True, + promotes_integers_to_float=True), + UnaryUfuncInfo('erfc', + ref=scipy.special.erfc, + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2}),), + dtypes=all_types_and(torch.bool), + dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half), + skips=( + # RuntimeError: "pow" not implemented for 'BFloat16' + SkipInfo('TestCommon', 'test_variant_consistency_jit', + dtypes=[torch.bfloat16]),), + assert_autodiffed=True, + promotes_integers_to_float=True), + UnaryUfuncInfo('erfinv', + ref=scipy.special.erfinv, + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2, + torch.float32: 1e-4}),), + dtypes=all_types_and(torch.bool), + dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half), + promotes_integers_to_float=True, + domain=(-1, 1), + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + active_if=LooseVersion(scipy.__version__) < "1.4.0"), + # RuntimeError: "pow" not implemented for 'BFloat16' + SkipInfo('TestCommon', 'test_variant_consistency_jit', + dtypes=[torch.bfloat16]), + ) + ), + OpInfo('xlogy', + dtypes=all_types_and(torch.bool), + dtypesIfCPU=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + test_inplace_grad=True, + supports_tensor_out=True, + promotes_integers_to_float=True, + sample_inputs_func=sample_inputs_xlogy), + ] + op_db = op_db + op_db_scipy_reference + # Common operator groupings unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo)] +spectral_funcs = [op for op in op_db if isinstance(op, SpectralFuncInfo)] +sparse_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse is True] +shape_funcs = [op for op in op_db if isinstance(op, ShapeFuncInfo)] -def index_variable(shape, max_indices): +def index_variable(shape, max_indices, device=torch.device('cpu')): if not isinstance(shape, tuple): shape = (shape,) - index = torch.rand(*shape).mul_(max_indices).floor_().long() + index = torch.rand(*shape, device=device).mul_(max_indices).floor_().long() return index @@ -407,14 +1652,14 @@ def index_perm_variable(shape, max_indices): return index -def gather_variable(shape, index_dim, max_indices, duplicate=False): +def gather_variable(shape, index_dim, max_indices, duplicate=False, device=torch.device('cpu')): assert len(shape) == 2 assert index_dim < 2 batch_dim = 1 - index_dim - index = torch.LongTensor(*shape) + index = torch.zeros(*shape, dtype=torch.long, device=device) for i in range(shape[index_dim]): index.select(index_dim, i).copy_( - torch.randperm(max_indices)[:shape[batch_dim]]) + torch.randperm(max_indices, device=device)[:shape[batch_dim]]) if duplicate: index.select(batch_dim, 0).copy_(index.select(batch_dim, 1)) return index @@ -494,8 +1739,6 @@ def ident(x): def method_tests(): set_rng_seed(0) return [ - ('acosh', torch.rand(S, S, S).add(1), NO_ARGS, ''), - ('acosh', torch.rand(tuple()).add(1), NO_ARGS, 'scalar'), ('add', (S, S, S), ((S, S, S),), '', (True,)), ('add', (S, S, S), ((S, S),), 'broadcast_rhs', (True,)), ('add', (S, S), ((S, S, S),), 'broadcast_lhs', (True,)), @@ -505,10 +1748,7 @@ def method_tests(): ('add', (), ((S, S, S),), 'scalar_broadcast_lhs', (True,)), ('add', (S, S, S), (3.14,), 'constant', (True,)), ('add', (), (3.14,), 'scalar_constant', (True,)), - ('asinh', (S, S, S), NO_ARGS, ''), - ('asinh', (), NO_ARGS, 'scalar'), - ('atanh', torch.rand(S, S, S), NO_ARGS, ''), - ('atanh', torch.rand(tuple()), NO_ARGS, 'scalar'), + ('add', (S, S, S), (3.14j,), 'complex_scalar_constant', (True,)), ('__radd__', (S, S, S), (3.14,), 'constant', (True, 'aten::add')), ('__radd__', (), (3.14,), 'scalar_constant', (True, 'aten::add')), ('sub', (S, S, S), ((S, S, S),), '', (True,)), @@ -519,6 +1759,7 @@ def method_tests(): ('sub', (), ((S, S, S),), 'scalar_broadcast_lhs', (True,)), ('sub', (S, S, S), (3.14,), 'constant', (True,)), ('sub', (), (3.14,), 'scalar_constant', (True,)), + ('sub', (S, S, S), (3.14j,), 'complex_scalar_constant', (True,)), ('__rsub__', (S, S, S), (3.14,), 'constant', (True, 'aten::rsub')), ('__rsub__', (), (3.14,), 'scalar_constant', (True, 'aten::rsub')), ('mul', (S, S, S), ((S, S, S),), '', (True,)), @@ -557,6 +1798,19 @@ def method_tests(): (True, [], ['aten::mul', 'aten::reciprocal'])), ('__rdiv__', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant', (True, [], ['aten::mul', 'aten::reciprocal'])), + ('__rdiv__', torch.rand(S, S, S, dtype=torch.cdouble) + 1e-1, (3.14j,), 'complex_constant', + (True, [], ['aten::mul', 'aten::reciprocal'])), + ('__rdiv__', uniform_scalar(1e-1 * (1 + 1j), requires_grad=True), (3.14j,), 'complex_scalar_constant', + (True, [], ['aten::mul', 'aten::reciprocal'])), + ('div', (S, S, S), (torch.rand(S, S, S, dtype=torch.cdouble) + 0.1,), 'complex', (True,)), + ('div', (S, S, S), (torch.rand(S, S, dtype=torch.cdouble) + 0.1,), 'complex_broadcast_rhs', (True,)), + ('div', (S, S), (torch.rand(S, S, S, dtype=torch.cdouble) + 0.1,), 'complex_broadcast_lhs', (True,)), + ('div', (S, 1, S), (torch.rand(M, S, dtype=torch.cdouble) + 0.1,), 'complex_broadcast_all', (True,)), + ('div', (), (uniform_scalar(0.1j),), 'complex_scalar', (True,)), + ('div', (S, S, S), (uniform_scalar(0.1j),), 'complex_scalar_broadcast_rhs', (True,)), + ('div', (), (uniform_scalar(0.1j),), 'complex_scalar_broadcast_lhs', (True,)), + ('div', torch.rand(S, S, S, dtype=torch.cdouble) + 1e-1, (3.14j,), 'complex_constant', (True,)), + ('div', uniform_scalar(1e-1j, requires_grad=True), (3.14j,), 'complex_scalar_constant', (True,)), ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,), '', (True,)), ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(1,) + 0.1,), 'broadcast_rhs', (True,)), ('pow', torch.rand(1,) + 1e-3, (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs', (True,)), @@ -565,35 +1819,50 @@ def method_tests(): ('pow', torch.rand(S, S, S) + 1e-3, (uniform_scalar(0.1),), 'scalar_broadcast_rhs', (True,)), ('pow', uniform_scalar(1e-3, requires_grad=True), (torch.rand(S, S, S) + 0.1,), 'scalar_broadcast_lhs', (True,)), ('pow', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant', (True,)), + ('pow', torch.rand(S, S, S, dtype=torch.cdouble) + 1e-3 * (1 + 1j), (3.14,), 'complex_constant', (True,)), ('__rpow__', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant', (True, 'aten::pow')), ('pow', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True,)), + ('pow', uniform_scalar(1e-3 * (1 + 1j), requires_grad=True), (3.14,), 'complex_scalar_constant', (True,)), + ('pow', uniform_scalar(1e-3 * (1 + 1j), requires_grad=True), (3.14j,), 'complex_imaginary_exponent', (True,)), ('__rpow__', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True, 'aten::pow')), + ('float_power', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,), ''), + ('float_power', torch.rand(S, S, S) + 1e-3, (torch.rand(1,) + 0.1,), 'broadcast_rhs'), + ('float_power', torch.rand(1,) + 1e-3, (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs'), + ('float_power', torch.rand(S, 1, S) + 1e-3, (torch.rand(1, S, 1) + 0.1,), 'broadcast_all'), + ('float_power', uniform_scalar(1e-3, requires_grad=True), (uniform_scalar(0.1),), 'scalar'), + ('float_power', torch.rand(S, S, S) + 1e-3, (uniform_scalar(0.1),), 'scalar_broadcast_rhs'), + ('float_power', uniform_scalar(1e-3, requires_grad=True), (torch.rand(S, S, S) + 0.1,), 'scalar_broadcast_lhs'), + ('float_power', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant'), ('transpose', (1, 2, 3), (1, 2), 'dim', (False,), [0, 1]), ('transpose', (), (0, 0), 'scalar', (False,)), ('transpose', (1,), (0, 0), '1d', (False,)), ('transpose', (L, L), (0, 1), '2d', (False,)), ('transpose', (S, S, S), (2, 0), '3d', (False,)), + ('swapdims', (1, 2, 3), (1, 2), 'dim', (False,), [0, 1]), + ('swapdims', (), (0, 0), 'scalar', (False,)), + ('swapdims', (1,), (0, 0), '1d', (False,)), + ('swapdims', (L, L), (0, 1), '2d', (False,)), + ('swapdims', (S, S, S), (2, 0), '3d', (False,)), + ('swapaxes', (1, 2, 3), (1, 2), 'dim', (False,), [0, 1]), + ('swapaxes', (), (0, 0), 'scalar', (False,)), + ('swapaxes', (1,), (0, 0), '1d', (False,)), + ('swapaxes', (L, L), (0, 1), '2d', (False,)), + ('swapaxes', (S, S, S), (2, 0), '3d', (False,)), ('t', (1, 2), NO_ARGS, '', (False,)), ('view', (S, S, S), (S * S, S), '', (False,)), - ('view', (S, S, S), (torch.Size([S * S, S]),), 'size', (False,)), + ('view', (torch.Size([S * S, S]),), (S, S, S), 'size', (False,)), ('view', (S,), (S,), '1d', (False,)), ('view', (), (dont_convert(()),), 'scalar_to_scalar', (False,)), ('view', (), (1,), 'scalar_to_1d', (False,)), + ('ravel', (S, S, S), NO_ARGS, '', (False,)), ('reshape', (S, S, S), (S * S, S), '', (False,)), - ('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size', (False,)), + ('reshape', (torch.Size([S * S, S]),), (S, S, S), 'size', (False,)), ('reshape', (S,), (S,), '1d', (False,)), ('reshape', (), (dont_convert(()),), 'scalar_to_scalar', (False,)), ('reshape', (), (1,), 'scalar_to_1d', (False,)), ('reshape_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)), ('reshape_as', (), (non_differentiable(torch.tensor(42.)),), 'scalar'), ('reshape_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'), - ('flip', (S, S, S), ([0],), 'd0'), - ('flip', (S, S, S), ([0, 1, 2],), 'd012'), - ('flip', (S, S, S), ([0, 2],), 'd02'), - ('flip', (S, S, S), ([2, 0],), 'd20'), - ('flip', (S, S, S), ([-1],), 'neg_d'), - ('fliplr', (S, S, S), ()), - ('flipud', (S, S, S), ()), ('roll', (S, S, S), (0, 0), 'd0'), ('roll', (S, S, S), (1, 2), 'd12'), ('roll', (S, S, S), (0, 2,), 'd02'), @@ -617,46 +1886,20 @@ def method_tests(): ('expand', (), (dont_convert(()),), 'scalar_to_scalar'), ('expand', (), (1, 3, 2), 'scalar_to_dims', (False,)), ('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (False,)), - ('exp', (S, S, S), NO_ARGS, '', (True,)), - ('exp', (), NO_ARGS, 'scalar', (True,)), - ('exp2', (S, S, S), NO_ARGS, '', (False,)), - ('exp2', (), NO_ARGS, 'scalar', (False,)), - ('expm1', (S, S, S), NO_ARGS, '', (True,)), - ('expm1', (), NO_ARGS, 'scalar', (True,)), - ('erf', torch.rand(S, S, S), NO_ARGS, '', (True,)), - ('erf', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar', (True,)), - ('erfc', torch.rand(S, S, S), NO_ARGS, '', (True,)), - ('erfc', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar', (True,)), - ('erfinv', torch.rand(S, S, S).clamp(-0.9, 0.9), NO_ARGS), - ('erfinv', normal_scalar_clamp(-0.9, 0.9, requires_grad=True), NO_ARGS, 'scalar'), - ('log', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)), - ('log', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)), - ('log10', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)), - ('log10', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)), - ('log1p', torch.rand(S, S, S), NO_ARGS, '', (True,)), - ('log1p', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar', (True,)), - ('log2', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)), - ('log2', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)), - # TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition. - # ('log', torch.randn(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)), - # ('log', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)), - # ('log10', torch.randn(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)), - # ('log10', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)), - # ('log2', torch.randn(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)), - # ('log2', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)), - ('tanh', (S, S, S), NO_ARGS, '', (True,)), - ('tanh', (), NO_ARGS, 'scalar', (True,)), - ('sigmoid', (S, S, S), NO_ARGS, '', (True,)), - ('sigmoid', (), NO_ARGS, 'scalar', (True,)), ('logit', torch.randn(S, S, S).clamp(0.1, 0.9).requires_grad_(True), NO_ARGS, ''), ('logit', torch.randn(S, S, S).clamp(0.1, 0.9).requires_grad_(True), (0.2,), 'eps'), ('logit', uniform_scalar().clamp(0.1, 0.9).requires_grad_(True), NO_ARGS, 'scalar'), ('logit', uniform_scalar().clamp(0.1, 0.9).requires_grad_(True), (0.2,), 'scalar_eps'), - ('sinh', (S, S, S), NO_ARGS, '', (True,)), - ('sinh', (), NO_ARGS, 'scalar', (True,)), - ('cosh', (S, S, S), NO_ARGS, '', (True,)), - ('cosh', (), NO_ARGS, 'scalar', (True,)), ('conj', (S, S, S), NO_ARGS), + ('copysign', (S, S, S), ((S, S, S),), '', (False,)), + ('copysign', (S, S, S), ((S, S),), 'broadcast_rhs', (False,)), + ('copysign', (S, S), ((S, S, S),), 'broadcast_lhs', (False,)), + ('copysign', (S, 1, S), ((M, S),), 'broadcast_all', (False,)), + ('copysign', (S, S), (3.14,), 'scalar', (False,)), + ('copysign', (S, S), (0.0,), 'scalar_pos_zero', (False,)), + # TorchScript does not recognize -0.0: Issue #46848 + # https://github.com/pytorch/pytorch/issues/46848 + # ('copysign', (S, S), (-0.0,), 'scalar_neg_zero', (False,)), ('real', (S, S, S), NO_ARGS, 'complex'), ('imag', (S, S, S), NO_ARGS, 'complex'), ('view_as_real', (S, S, S), NO_ARGS, 'complex'), @@ -671,29 +1914,11 @@ def method_tests(): ('clamp', (), (None, 0.5), 'min_scalar', (True,)), ('clamp', (), (0.5, None), 'max_scalar', (True,)), ('clamp', (S, S), (), 'max_scalar_kwarg', (True,), (), (), ident, {'max': 1}), - ('sqrt', torch.rand(S, S, S) + 5e-4, NO_ARGS, '', (True,)), - ('sqrt', uniform_scalar(5e-4, requires_grad=True), NO_ARGS, 'scalar', (True,)), - ('sin', (S, S, S), NO_ARGS, '', (True,)), - ('sin', (), NO_ARGS, 'scalar', (True,)), - ('cos', (S, S, S), NO_ARGS, '', (True,)), - ('cos', (), NO_ARGS, 'scalar', (True,)), - ('tan', torch.randn(S, S, S).clamp(-1, 1), NO_ARGS, '', (True,)), - # TODO(@anjali411): add the commented test back after updating the formula based on tensorflow definition. - # ('tan', (S, S, S), NO_ARGS, 'complex', (True,)), - ('asin', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS, '', (True,)), - ('acos', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS, '', (True,)), - ('atan', (S, S, S), NO_ARGS, '', (True,)), - ('atan', (), NO_ARGS, 'scalar', (True,)), ('atan2', (S, S, S), ((S, S, S),)), ('atan2', (), ((),), 'scalar'), ('atan2', (S, S, S), ((S,),), 'broadcast_rhs'), ('atan2', (S,), ((S, S, S),), 'broadcast_lhs'), ('atan2', (S, 1, S), ((S, S),), 'broadcast_all'), - ('reciprocal', torch.rand(S, S, S) + 0.1, NO_ARGS, '', (True,)), - ('reciprocal', uniform_scalar(0.1, requires_grad=True), NO_ARGS, 'scalar', (True,)), - # TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition. - # ('reciprocal', torch.randn(S, S, S, dtype=torch.cdouble) + 0.1, NO_ARGS, 'complex', (True,)), - # ('reciprocal', uniform_scalar(0.1j), NO_ARGS, 'complex_scalar', (True,)), ('round', (S, S, S), NO_ARGS, '', (True,)), ('round', (), NO_ARGS, 'scalar', (True,)), ('sign', (S, S, S), NO_ARGS), @@ -708,8 +1933,14 @@ def method_tests(): ('ceil', (), NO_ARGS, 'scalar', (True,)), ('rad2deg', (S, S, S), NO_ARGS), ('deg2rad', (S, S, S), NO_ARGS), + # Removing the 'rsqrt' entries leads to failure in + # test_index_fill_variable_dim_* + # TODO: Remove when fixed. + # Reference: https://github.com/pytorch/pytorch/issues/48230 ('rsqrt', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)), ('rsqrt', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)), + ('rsqrt', torch.rand(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)), + ('rsqrt', uniform_scalar(1e-2 * (1 + 1j), requires_grad=True), NO_ARGS, 'complex_scalar', (True,)), ('frac', (S, S, S), NO_ARGS, '', (True,)), ('frac', (), NO_ARGS, 'scalar', (True,)), ('fmod', (S, S, S), (1.5,), '', (True,)), @@ -785,14 +2016,14 @@ def method_tests(): ('mean', (S, S, S), (), 'dtype', (True,), (), (), ident, {'dtype': torch.float64}), ('kthvalue', (S, S, S), (2,)), ('kthvalue', (S, S, S), (2, 1,), 'dim', (), [1]), + ('kthvalue', (S, S, S), (2, 1,), 'dim_alert_nondeterministic', (), [1], + [expectedAlertNondeterministic('kthvalue CUDA', 'cuda')]), ('kthvalue', (S, S, S), (2, 1, True,), 'keepdim_dim', (), [1]), ('kthvalue', (S,), (2, 0,), 'dim_1d', (), [1]), ('kthvalue', (S,), (2, 0, True,), 'keepdim_dim_1d', (), [1]), - # TODO: https://github.com/pytorch/pytorch/issues/30818 - ('kthvalue', (), (1,), 'scalar', (), (), [expectedFailureCUDA]), - ('kthvalue', (), (1, 0,), 'scalar_dim', (), [1], [expectedFailureCUDA]), - ('kthvalue', (), (1, 0, True), 'scalar_keepdim_dim', (), [1], [expectedFailureCUDA]), - # END TODO + ('kthvalue', (), (1,), 'scalar', (), ()), + ('kthvalue', (), (1, 0,), 'scalar_dim', (), [1]), + ('kthvalue', (), (1, 0, True), 'scalar_keepdim_dim', (), [1]), ('quantile', (S, S, S), (0.5,)), ('quantile', (S, S, S), (0.5, 0), 'dim', (), [1]), ('quantile', (S, S, S), (0.5, None, True), 'keepdim'), @@ -805,12 +2036,18 @@ def method_tests(): ('nanquantile', (), (0.5,), 'scalar'), ('median', (S, S, S), NO_ARGS), ('median', (S, S, S), (1,), 'dim', (), [0]), + ('median', (S, S, S), (1,), 'dim_alert_nondeterministic', (), [0], + [expectedAlertNondeterministic('median CUDA with indices output', 'cuda')]), ('median', (S, S, S), (1, True,), 'keepdim_dim', (), [0]), ('median', (), NO_ARGS, 'scalar'), - # TODO: https://github.com/pytorch/pytorch/issues/30818 - ('median', (), (0,), 'scalar_dim', (), [0], [expectedFailureCUDA]), - ('median', (), (0, True,), 'scalar_keepdim_dim', (), [0], [expectedFailureCUDA]), - # END TODO + ('median', (), (0,), 'scalar_dim', (), [0]), + ('median', (), (0, True,), 'scalar_keepdim_dim', (), [0]), + ('nanmedian', (S, S, S), NO_ARGS), + ('nanmedian', (S, S, S), (1,), 'dim', (), [0]), + ('nanmedian', (S, S, S), (1, True,), 'keepdim_dim', (), [0]), + ('nanmedian', (), NO_ARGS, 'scalar'), + ('nanmedian', (), (0,), 'scalar_dim', (), [0]), + ('nanmedian', (), (0, True,), 'scalar_keepdim_dim', (), [0]), ('mode', (S, S, S), NO_ARGS), ('mode', (S, S, S), (1,), 'dim', (), [0]), ('mode', (S, S, S), (1, True,), 'keepdim_dim', (), [0]), @@ -875,12 +2112,6 @@ def method_tests(): ('renorm', (S, S, S), (2, 1, 0.5), 'dim', (), [1]), ('renorm', (S, S, S), (1, 2, 3), 'norm_1'), ('renorm', (S, S, S), (inf, 2, 0.5), 'norm_inf'), - ('repeat', (S,), (2,), 'single_number'), - ('repeat', (), (2, 3), 'scalar'), - ('repeat', (2, 2), (3, 2)), - ('repeat', (2, 2), (1, 3, 1, 2), 'unsqueeze'), - ('repeat', (S,), (0, ), 'zero_dim'), - ('repeat', (S,), (0, 2), 'zero_dim_multi'), ('logcumsumexp', (S, S, S), (0,), 'dim0', (), [0]), ('logcumsumexp', (S, S, S), (1,), 'dim1', (), [0]), ('logcumsumexp', (), (0,), 'dim0_scalar', (), [0]), @@ -958,10 +2189,13 @@ def method_tests(): ('addr', (S, M), ((S,), (M,)), 'coef', (), (), (), ident, {'beta': 0.2, 'alpha': 0.6}), ('addr', (), ((S,), (M,)), 'broadcast_lhs_coef', (), (), (), ident, {'beta': 0.2, 'alpha': 0.6}), ('dot', (L,), ((L,),), '', (True,)), + ('vdot', (L,), ((L,),),), ('mm', (S, M), ((M, S),), '', (True,)), ('bmm', (M, S, M), ((M, M, S),), '', (True,)), ('mv', (S, M), ((M,),), '', (True,)), ('ger', (S,), ((M,),)), + ('inner', (S,), ((S,),), "1d_1d", (False,)), + ('inner', (), ((S, S),), "scalar_2d", (False,)), ('matmul', (L,), ((L,),), '', (True,)), ('matmul', (S, M), ((M,),), "2d_1d", (True,)), ('matmul', (M,), ((M, S),), "1d_2d", (True,)), @@ -977,11 +2211,11 @@ def method_tests(): ('matrix_power', (S, S, S), [3], "n=3"), ('matrix_power', (S, S, S), [1], "n=1"), ('matrix_power', (S, S, S), [0], "n=0"), - ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-1], "n=-1", (), + ('matrix_power', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S), [-1], "n=-1", (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-3], "n=-3", (), + ('matrix_power', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S), [-3], "n=-3", (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S, S), [-2], "n=-2", (), + ('matrix_power', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, S), [-2], "n=-2", (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('matrix_exp', (S, S), NO_ARGS, "single_matrix"), ('matrix_exp', (S, S, S), NO_ARGS, "batch_of_matrices"), @@ -1107,15 +2341,14 @@ def method_tests(): ('trace', (M, M), NO_ARGS), ('cross', (S, 3), ((S, 3),)), ('cross', (S, 3, S), ((S, 3, S), 1), 'dim'), - ('index_select', (S, S, S), (0, index_variable(2, S)), 'dim', (), [0]), - ('index_select', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_mixed_dim', (), [0]), - ('index_select', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_dim', (), [0]), ('index_add', (S, S), (0, index_variable(2, S), (2, S)), 'dim', (), [0]), ('index_add', (), (0, torch.tensor([0], dtype=torch.int64), (1,)), 'scalar_input_dim', (), [0]), ('index_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim', (), [0]), ('index_add', (S, S), (0, index_variable(2, S), (2, S)), 'alert_nondeterministic', (), [0], [expectedAlertNondeterministic('index_add_cuda_', 'cuda')]), ('index_copy', (S, S), (0, index_perm_variable(2, S), (2, S)), 'dim', (), [0]), + ('index_copy', (S, S), (0, index_perm_variable(2, S), (2, S)), 'dim_alert_nondeterministic', (), [0], + [expectedAlertNondeterministic('index_copy')]), ('index_copy', (), (0, torch.tensor([0], dtype=torch.int64), (1,)), 'scalar_input_dim', (), [0]), ('index_copy', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim', (), [0]), ('index_fill', (S, S), (0, index_variable(2, S), 2), 'dim', (), [0]), @@ -1123,106 +2356,61 @@ def method_tests(): ('index_fill', (S, S), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_index_dim', (), [0]), ('index_fill', (), (0, torch.tensor([0], dtype=torch.int64), 2), 'scalar_input_dim', (), [0]), ('index_fill', (), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_both_dim', (), [0]), - ('inverse', lambda: random_fullrank_matrix_distinct_singular_value(S), + ('inverse', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device), NO_ARGS, '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('inverse', lambda: random_fullrank_matrix_distinct_singular_value(S, 2, 3), + ('inverse', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, 2, 3, dtype=dtype).to(device), NO_ARGS, 'batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('det', (S, S), NO_ARGS, '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('det', (1, 1), NO_ARGS, '1x1', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_symmetric_matrix(S), NO_ARGS, 'symmetric', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_symmetric_psd_matrix(S), + ('det', lambda dtype, device: random_symmetric_matrix(S), NO_ARGS, 'symmetric', (), + NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('det', lambda dtype, device: random_symmetric_psd_matrix(S), NO_ARGS, 'symmetric_psd', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_symmetric_pd_matrix(S), + ('det', lambda dtype, device: random_symmetric_pd_matrix(S), NO_ARGS, 'symmetric_pd', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_square_matrix_of_rank(S, S - 2), + ('det', lambda dtype, device: random_square_matrix_of_rank(S, S - 2), NO_ARGS, 'dim2_null', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_square_matrix_of_rank(S, 1), NO_ARGS, 'rank1', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_square_matrix_of_rank(S, 2), NO_ARGS, 'rank2', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, + ('det', lambda dtype, device: random_square_matrix_of_rank(S, 1), NO_ARGS, 'rank1', (), + NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('det', lambda dtype, device: random_square_matrix_of_rank(S, 2), NO_ARGS, 'rank2', (), + NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('det', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, 'distinct_singular_values', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('det', (3, 3, S, S), NO_ARGS, 'batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('det', (3, 3, 1, 1), NO_ARGS, 'batched_1x1', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_symmetric_matrix(S, 3), + ('det', lambda dtype, device: random_symmetric_matrix(S, 3), NO_ARGS, 'batched_symmetric', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_symmetric_psd_matrix(S, 3), + ('det', lambda dtype, device: random_symmetric_psd_matrix(S, 3), NO_ARGS, 'batched_symmetric_psd', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_symmetric_pd_matrix(S, 3), + ('det', lambda dtype, device: random_symmetric_pd_matrix(S, 3), NO_ARGS, 'batched_symmetric_pd', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('det', lambda: random_fullrank_matrix_distinct_singular_value(S, 3, 3), NO_ARGS, + ('det', lambda dtype, device: random_fullrank_matrix_distinct_singular_value(S, 3, 3), NO_ARGS, 'batched_distinct_singular_values', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - # For `logdet` and `slogdet`, the function at det=0 is not smooth. + # For `logdet` the function at det=0 is not smooth. # We need to exclude tests with det=0 (e.g. dim2_null, rank1, rank2) and use # `make_nonzero_det` to make the random matrices have nonzero det. For # `logdet`, we also set `make_nonzero_det(matrix, sign=1)` to make the # matrix have positive det. - ('logdet', lambda: make_nonzero_det(torch.randn(S, S), 1), + ('logdet', lambda dtype, device: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS, '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(torch.randn(1, 1), 1), + ('logdet', lambda dtype, device: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS, '1x1', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(random_symmetric_matrix(S), 1), NO_ARGS, + ('logdet', lambda dtype, device: make_nonzero_det(random_symmetric_matrix(S), 1), NO_ARGS, 'symmetric', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(random_symmetric_pd_matrix(S), 1), NO_ARGS, + ('logdet', lambda dtype, device: make_nonzero_det(random_symmetric_pd_matrix(S), 1), NO_ARGS, 'symmetric_pd', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(random_fullrank_matrix_distinct_singular_value(S), 1, 0), NO_ARGS, + ('logdet', lambda dtype, device: make_nonzero_det(random_fullrank_matrix_distinct_singular_value(S), 1, 0), NO_ARGS, 'distinct_singular_values', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(torch.randn(3, 3, S, S), 1), + ('logdet', lambda dtype, device: make_nonzero_det(torch.randn(3, 3, S, S), 1), NO_ARGS, 'batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(torch.randn(3, 3, 1, 1), 1), + ('logdet', lambda dtype, device: make_nonzero_det(torch.randn(3, 3, 1, 1), 1), NO_ARGS, 'batched_1x1', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(random_symmetric_matrix(S, 3), 1), NO_ARGS, + ('logdet', lambda dtype, device: make_nonzero_det(random_symmetric_matrix(S, 3), 1), NO_ARGS, 'batched_symmetric', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(random_symmetric_pd_matrix(S, 3), 1), NO_ARGS, + ('logdet', lambda dtype, device: make_nonzero_det(random_symmetric_pd_matrix(S, 3), 1), NO_ARGS, 'batched_symmetric_pd', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('logdet', lambda: make_nonzero_det(random_fullrank_matrix_distinct_singular_value(S, 3), 1, 0), NO_ARGS, + ('logdet', lambda dtype, device: make_nonzero_det(random_fullrank_matrix_distinct_singular_value(S, 3), 1, 0), NO_ARGS, 'batched_distinct_singular_values', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('slogdet', lambda: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS, - '1x1_pos_det', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: make_nonzero_det(torch.randn(1, 1), -1), NO_ARGS, - '1x1_neg_det', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS, - 'pos_det', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: make_nonzero_det(torch.randn(S, S), -1), NO_ARGS, - 'neg_det', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: make_nonzero_det(random_symmetric_matrix(S)), NO_ARGS, - 'symmetric', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: random_symmetric_pd_matrix(S), NO_ARGS, - 'symmetric_pd', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, - 'distinct_singular_values', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: make_nonzero_det(torch.randn(3, 3, 1, 1), -1), NO_ARGS, - 'batched_1x1_neg_det', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: make_nonzero_det(torch.randn(3, 3, S, S), 1), NO_ARGS, - 'batched_pos_det', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: make_nonzero_det(random_symmetric_matrix(S, 3)), NO_ARGS, - 'batched_symmetric', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: random_symmetric_pd_matrix(S, 3), NO_ARGS, - 'batched_symmetric_pd', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('slogdet', lambda: random_fullrank_matrix_distinct_singular_value(S, 3), NO_ARGS, - 'batched_distinct_singular_values', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], itemgetter(1)), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S), - NO_ARGS, '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], NO_ARGS, - 'wide', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], NO_ARGS, - 'tall', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], (False,), - 'wide_all', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], lambda usv: (usv[0], usv[1], usv[2][:, :(S - 2)])), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], (False,), - 'tall_all', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], lambda usv: (usv[0][:, :(S - 2)], usv[1], usv[2])), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(M), NO_ARGS, - 'large', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S, 3), NO_ARGS, - 'batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S, 3)[..., :(S - 2), :], NO_ARGS, - 'wide_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S, 3)[..., :, :(S - 2)], NO_ARGS, - 'tall_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S, 3, 3)[..., :(S - 2), :], (False,), - 'wide_all_batched', (), NO_ARGS, - [skipCPUIfNoLapack, skipCUDAIfNoMagma], lambda usv: (usv[0], usv[1], usv[2][..., :, :(S - 2)])), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S, 3, 3)[..., :, :(S - 2)], (False,), - 'tall_all_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma], - lambda usv: (usv[0][..., :, :(S - 2)], usv[1], usv[2])), ('qr', (S, S), (False,), 'square_single', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('qr', (S, S - 2), (True,), 'tall_single' , (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('qr', (S - 2, S), (False,), 'wide_single' , (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), @@ -1232,15 +2420,29 @@ def method_tests(): ('qr', (3, 2, S, S), (False,), 'square_many_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('qr', (3, 2, S, S - 2), (True,), 'tall_many_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('qr', (3, 2, S - 2, S), (True,), 'wide_many_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('solve', (S, S), (random_fullrank_matrix_distinct_singular_value( - S, silent=True),), '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('solve', (S, S, S), (random_fullrank_matrix_distinct_singular_value(S, S, silent=True),), + ('lu', (S, S), (True, False), 'square_single_no_info', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('lu', (S, S), (True, True), 'square_single_with_info', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('lu', (3, S, S), (True, False), 'square_batch_no_info', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('lu', (3, S, S), (True, True), 'square_batch_with_info', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('lu', (3, 3, S, S), (True, False), 'square_many_batches_no_info', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('lu', (3, 3, S, S), (True, True), 'square_many_batches_with_info', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('solve', (S, S), (lambda dtype, device: random_fullrank_matrix_distinct_singular_value( + S, silent=True, dtype=dtype, device=device),), '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('solve', (S, S, S), + (lambda dtype, device: + random_fullrank_matrix_distinct_singular_value(S, S, silent=True, dtype=dtype, device=device),), 'batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('solve', (2, 3, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 3, silent=True),), + ('solve', (2, 3, S, S), + (lambda dtype, device: + random_fullrank_matrix_distinct_singular_value(S, 2, 3, silent=True, dtype=dtype, device=device),), 'batched_dims', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('solve', (2, 2, S, S), (random_fullrank_matrix_distinct_singular_value(S, 1, silent=True),), + ('solve', (2, 2, S, S), + (lambda dtype, device: + random_fullrank_matrix_distinct_singular_value(S, 1, silent=True, dtype=dtype, device=device),), 'batched_broadcast_A', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('solve', (1, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 2, silent=True),), + ('solve', (1, S, S), + (lambda dtype, device: + random_fullrank_matrix_distinct_singular_value(S, 2, 2, silent=True, dtype=dtype, device=device),), 'batched_broadcast_b', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('fill_', (S, S, S), (1,), 'number'), ('fill_', (), (1,), 'number_scalar'), @@ -1309,11 +2511,10 @@ def method_tests(): ('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), '', (True,)), ('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3), 0],), 'size_0', (True, )), ('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), 'dim', (True, ), [1]), - ('gather', (M, S), (0, gather_variable((S, S), 1, M, True)), 'dim0', (), [0]), - ('gather', (M, S), (1, gather_variable((M, S // 2), 0, S, True)), 'dim1', (), [0]), - ('gather', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_input', (), [0]), - ('gather', (S,), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_index', (), [0]), - ('gather', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_both', (), [0]), + ('tensor_split', (S, S, S), (3,), 'sections', (False,)), + ('tensor_split', (S, S, S), (3, 1), 'sections_dim', (False,), [1]), + ('tensor_split', (S, S, S), ([2, 4],), 'indices', (False,)), + ('tensor_split', (S, S, S), ([2, 4], 1), 'indices_dim', (False,), [1]), ('scatter', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', (), [0]), ('scatter', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', (), [0]), ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalartensor_all_dim0', (), [0]), @@ -1356,12 +2557,20 @@ def method_tests(): ('resize_as_', (), (non_differentiable(torch.tensor(5.)),), 'scalar'), ('resize_as_', (), (non_differentiable(torch.randn((1, 1, 1))),), 'scalar_to_dims'), ('resize_as_', (S, S, S), (non_differentiable(torch.randn(S * S, S)),)), + # TODO(@nikitaved) enable stable sort tests for CUDA once they are implemented ('sort', (S, M, S), NO_ARGS), + ('sort', (S, M, S), (0, False, True), 'stable', (), NO_ARGS, [onlyCPU]), ('sort', (S, M, S), (1,), 'dim'), + ('sort', (S, M, S), (1, False, True), 'dim_stable', (), NO_ARGS, [onlyCPU]), ('sort', (S, M, S), (1, True), 'dim_desc'), + ('sort', (S, M, S), (1, True, True), 'dim_desc_stable', (), NO_ARGS, [onlyCPU]), ('sort', (), NO_ARGS, 'scalar'), + ('sort', (), (0, False, True), 'scalar_stable', (), NO_ARGS, [onlyCPU]), ('sort', (), (0,), 'dim_scalar'), + ('sort', (), (0, False, True), 'dim_scalar_stable', (), NO_ARGS, [onlyCPU]), ('sort', (), (0, True), 'dim_desc_scalar'), + ('sort', (), (0, True, True), 'dim_desc_scalar_stable', (), NO_ARGS, [onlyCPU]), + ('msort', (S, M, S), NO_ARGS), ('topk', (S, M, S), (3,)), ('topk', (S, M, S), (3, 1), 'dim', (), [1]), ('topk', (S, M, S), (3, 1, True), 'dim_desc', (), [1]), @@ -1394,6 +2603,8 @@ def method_tests(): ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])]),), 'adv_index_var'), ('to_sparse', (S, S), (), '', (), (), [], lambda x: x.to_dense()), + ('triangular_solve', (S, M), ((S, S), ), '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('kron', (S, S), ((M, L),)) ] def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwargs=None, dtype=torch.double, device=None): @@ -1430,7 +2641,7 @@ def maybe_non_contig(tensor): v.requires_grad = requires_grad and (v.is_floating_point() or v.is_complex()) return v elif callable(arg): - return map_arg(arg()) + return map_arg(arg(dtype=dtype, device=device)) else: return arg args_out = tuple(map_arg(arg) for arg in call_args) @@ -1635,11 +2846,6 @@ def unpack_variables(args): 'test_logdet_batched', 'test_logdet_batched_1x1', 'test_logdet_batched_symmetric', - 'test_slogdet_1x1_neg_det', - 'test_slogdet_neg_det', - 'test_slogdet_symmetric', - 'test_slogdet_batched_1x1_neg_det', - 'test_slogdet_batched_symmetric', 'test_cdist', } @@ -1647,10 +2853,6 @@ def unpack_variables(args): def exclude_tensor_method(name, test_name): # there are no tensor equivalents for these (inplace or out) exclude_all_tensor_method_by_test_name = { - 'test_clamp_min', - 'test_clamp_max', - 'test_clamp_min_scalar', - 'test_clamp_max_scalar', 'test_slice', 'test_where', 'test_where_broadcast_all', diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index 2de86795cda76..26bab3a67c10f 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -1,5 +1,5 @@ +from abc import abstractmethod import math -import sys import tempfile import unittest @@ -14,21 +14,20 @@ import torch.cuda import torch.nn as nn import torch.nn.functional as F -from torch.nn.functional import _Reduction +from torch.nn import _reduction as _Reduction from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \ - TEST_WITH_ROCM, _assertGradAndGradgradChecks + TEST_WITH_ROCM from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_device_type import expectedAlertNondeterministic from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors, \ gradcheck, gradgradcheck from torch.autograd import Variable +from torch.types import _TensorOrTensors import torch.backends.cudnn -# tarfile module tries to obtain a file object name in python 3.3 -if sys.version_info[:2] == (3, 3): - TemporaryFile = tempfile.NamedTemporaryFile -else: - TemporaryFile = tempfile.TemporaryFile +from typing import Dict, Callable, Tuple, List, Sequence, Union, Any + +TemporaryFile = tempfile.TemporaryFile PRECISION = 1e-5 @@ -597,6 +596,19 @@ def l1loss_no_reduce_test(): pickle=False) +def l1loss_no_reduce_complex_test(): + t = torch.randn(2, 3, 4, dtype=torch.cdouble) + return dict( + fullname='L1Loss_no_reduce_complex', + constructor=wrap_functional( + lambda i: F.l1_loss(i, t.type_as(i), reduction='none')), + cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))', + input_fn=lambda: torch.randn(2, 3, 4, dtype=torch.cdouble), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: (i - t.type_as(i)).abs(), + pickle=False) + + def l1loss_no_reduce_scalar_test(): t = torch.randn(()) return dict( @@ -644,7 +656,7 @@ def nllloss_no_reduce_test(): return dict( fullname='NLLLoss_no_reduce', constructor=wrap_functional( - lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)), + lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])), cpp_function_call='''F::nll_loss( i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.rand(15, 10).log(), @@ -656,11 +668,12 @@ def nllloss_no_reduce_test(): def nllloss_no_reduce_ignore_index_test(): t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long()) - kwargs = {'ignore_index': 2, 'reduction': 'none'} + kwargs: Dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'} return dict( fullname='NLLLoss_no_reduce_ignore_index', constructor=wrap_functional( - lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)), + lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']), + reduction=str(kwargs['reduction']))), cpp_function_call='''F::nll_loss( i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(2).reduction(torch::kNone))''', input_fn=lambda: torch.rand(15, 10).log(), @@ -741,7 +754,7 @@ def nllloss2d_no_reduce_test(): return dict( fullname='NLLLoss2d_no_reduce', constructor=wrap_functional( - lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)), + lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])), cpp_function_call='''F::nll_loss( i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.rand(2, 3, 5, 5).log(), @@ -753,11 +766,12 @@ def nllloss2d_no_reduce_test(): def nllloss2d_no_reduce_ignore_index_test(): t = Variable(torch.rand(2, 5, 5).mul(3).floor().long()) - kwargs = {'ignore_index': 1, 'reduction': 'none'} + kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'} return dict( fullname='NLLLoss2d_no_reduce_ignore_index', constructor=wrap_functional( - lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)), + lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']), + reduction=str(kwargs['reduction']))), cpp_function_call='''F::nll_loss( i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''', input_fn=lambda: torch.rand(2, 3, 5, 5).log(), @@ -794,7 +808,7 @@ def nlllossNd_no_reduce_test(): return dict( fullname='NLLLossNd_no_reduce', constructor=wrap_functional( - lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)), + lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])), cpp_function_call='''F::nll_loss( i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(), @@ -806,11 +820,12 @@ def nlllossNd_no_reduce_test(): def nlllossNd_no_reduce_ignore_index_test(): t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long()) - kwargs = {'ignore_index': 1, 'reduction': 'none'} + kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'} return dict( fullname='NLLLossNd_no_reduce_ignore_index', constructor=wrap_functional( - lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)), + lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']), + reduction=str(kwargs['reduction']))), cpp_function_call='''F::nll_loss( i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''', input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(), @@ -871,6 +886,36 @@ def smoothl1loss_no_reduce_scalar_test(): pickle=False) +def smoothl1loss_beta_test(): + t = torch.randn(2, 3, 4) + return dict( + fullname='SmoothL1Loss_beta', + constructor=wrap_functional( + lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0.5)), + cpp_function_call='''F::smooth_l1_loss( + i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0.5)''', + input_fn=lambda: torch.randn(2, 3, 4), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0.5), + pickle=False) + + +def smoothl1loss_zero_beta_test(): + t = torch.randn(2, 3, 4) + return dict( + fullname='SmoothL1Loss_zero_beta', + constructor=wrap_functional( + lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0)), + cpp_function_call='''F::smooth_l1_loss( + i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0)''', + input_fn=lambda: torch.randn(2, 3, 4), + cpp_var_map={'i': '_get_input()', 't': t}, + reference_fn=lambda i, *_: + loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0), + pickle=False) + + def multilabelmarginloss_0d_no_reduce_test(): t = torch.zeros(()).long() return dict( @@ -1228,6 +1273,7 @@ def fractional_max_pool3d_test(test_case): kldivloss_no_reduce_log_target_test(), kldivloss_no_reduce_scalar_log_target_test(), l1loss_no_reduce_test(), + l1loss_no_reduce_complex_test(), l1loss_no_reduce_scalar_test(), mseloss_no_reduce_test(), mseloss_no_reduce_scalar_test(), @@ -1244,6 +1290,8 @@ def fractional_max_pool3d_test(test_case): nlllossNd_no_reduce_ignore_index_test(), smoothl1loss_no_reduce_test(), smoothl1loss_no_reduce_scalar_test(), + smoothl1loss_beta_test(), + smoothl1loss_zero_beta_test(), multilabelmarginloss_0d_no_reduce_test(), multilabelmarginloss_1d_no_reduce_test(), multilabelmarginloss_index_neg_test(), @@ -1547,8 +1595,30 @@ def fractional_max_pool3d_test(test_case): input_size=(4, 6, 5), cudnn=True, check_eval=True, + check_bfloat16=True, desc='1d_affine', ), + dict( + module_name='GroupNorm', + constructor_args=(3, 12, 1e-3), + cpp_constructor_args='torch::nn::GroupNormOptions(3, 12).eps(1e-3)', + input_size=(4, 12), + cudnn=True, + check_eval=True, + check_bfloat16=True, + desc='1d_affine_GN', + ), + dict( + module_name='GroupNorm', + constructor_args=(1, 6, 1e-3), + cpp_constructor_args='torch::nn::GroupNormOptions(1, 6).eps(1e-3)', + input_size=(150, 6), + cudnn=True, + check_eval=True, + desc='1d_affine_large_batch', # For large batch_size + check_bfloat16=True, + test_cpu=False, + ), dict( module_name='GroupNorm', constructor_args=(5, 5, 1e-3, False), @@ -1556,15 +1626,17 @@ def fractional_max_pool3d_test(test_case): input_size=(4, 5, 5), cudnn=True, check_eval=True, + check_bfloat16=True, desc='1d_no_affine_IN', # this setting is equivalent with InstanceNormi ), dict( module_name='GroupNorm', - constructor_args=(1, 5, 1e-3, False), - cpp_constructor_args='torch::nn::GroupNormOptions(1, 5).eps(1e-3).affine(false)', - input_size=(4, 5, 5), + constructor_args=(1, 10, 1e-3, False), + cpp_constructor_args='torch::nn::GroupNormOptions(1, 10).eps(1e-3).affine(false)', + input_size=(4, 10), cudnn=True, check_eval=True, + check_bfloat16=True, desc='1d_no_affine_LN', # this setting is equivalent with LayerNorm ), dict( @@ -1574,8 +1646,20 @@ def fractional_max_pool3d_test(test_case): input_size=(4, 6, 2, 3), cudnn=True, check_eval=True, + check_bfloat16=True, desc='2d_affine', ), + dict( + module_name='GroupNorm', + constructor_args=(3, 6, 1e-3), + cpp_constructor_args='torch::nn::GroupNormOptions(3, 6).eps(1e-3)', + input_size=(4, 6, 28, 28), + cudnn=True, + check_eval=True, + check_bfloat16=True, + desc='2d_affine_large_feature', + test_cpu=False, + ), dict( module_name='GroupNorm', constructor_args=(3, 3, 1e-3, False), @@ -1583,6 +1667,7 @@ def fractional_max_pool3d_test(test_case): input_size=(4, 3, 2, 3), cudnn=True, check_eval=True, + check_bfloat16=True, desc='2d_no_affine_IN', # this setting is equivalent with InstanceNorm ), dict( @@ -1592,6 +1677,7 @@ def fractional_max_pool3d_test(test_case): input_size=(4, 3, 2, 3), cudnn=True, check_eval=True, + check_bfloat16=True, desc='2d_no_affine_LN', # this setting is equivalent with LayerNorm ), dict( @@ -1660,7 +1746,6 @@ def fractional_max_pool3d_test(test_case): input_size=(0, 4, 10), cudnn=True, desc='zero_batch', - test_cuda=(not TEST_WITH_ROCM), with_tf32=True, tf32_precision=0.005, ), @@ -1775,6 +1860,7 @@ def fractional_max_pool3d_test(test_case): desc='dilated', check_with_long_tensor=True, with_tf32=True, + tf32_precision=0.005, ), dict( module_name='Conv2d', @@ -1795,7 +1881,6 @@ def fractional_max_pool3d_test(test_case): cudnn=True, desc='zero_batch', check_with_long_tensor=True, - test_cuda=(not TEST_WITH_ROCM), with_tf32=True, ), dict( @@ -1824,7 +1909,7 @@ def fractional_max_pool3d_test(test_case): input_size=(1, 3, 7, 6), check_with_long_tensor=True, with_tf32=True, - tf32_precision=0.005, + tf32_precision=0.01, ), dict( module_name='ConvTranspose2d', @@ -2140,6 +2225,18 @@ def fractional_max_pool3d_test(test_case): with_tf32=True, tf32_precision=0.05, ), + dict( + module_name='Conv3d', + constructor_args=(2, 3, (1, 1, 1), 1, 0, 1, 1, False), + cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4}) + .stride(1).padding(0).dilation(1).groups(1).bias(false)''', + input_size=(1, 2, 3, 4, 5), + cudnn=True, + desc='1x1x1_no_bias', + check_with_long_tensor=False, + with_tf32=False, + tf32_precision=0.05, + ), dict( module_name='Conv3d', constructor_args=(3, 4, 2, 2), @@ -2170,7 +2267,6 @@ def fractional_max_pool3d_test(test_case): cudnn=True, check_with_long_tensor=True, desc='zero_batch', - test_cuda=(not TEST_WITH_ROCM), with_tf32=True, ), dict( @@ -2374,7 +2470,6 @@ def fractional_max_pool3d_test(test_case): constructor_args=(4, 3), cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)', input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), - jacobian_input=False, check_gradgrad=False, ), dict( @@ -2382,7 +2477,6 @@ def fractional_max_pool3d_test(test_case): constructor_args=(4, 3), cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)', input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), - jacobian_input=False, check_gradgrad=False, desc='mean', ), @@ -2391,7 +2485,6 @@ def fractional_max_pool3d_test(test_case): constructor_args=(4, 3), cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)', input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), - jacobian_input=False, check_gradgrad=False, desc='alert_nondeterministic', test_cpu=False, @@ -2403,7 +2496,6 @@ def fractional_max_pool3d_test(test_case): cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)''', input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), - jacobian_input=False, check_gradgrad=False, desc='sum', ), @@ -2413,7 +2505,6 @@ def fractional_max_pool3d_test(test_case): cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax)''', input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), - jacobian_input=False, check_gradgrad=False, desc='max', ), @@ -2422,16 +2513,16 @@ def fractional_max_pool3d_test(test_case): constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True), cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).sparse(true)', input_fn=lambda: torch.randperm(2).repeat(1, 2), - jacobian_input=False, check_gradgrad=False, + has_sparse_gradients=True, ), dict( constructor=lambda: nn.Embedding(4, 3, sparse=True), cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)', input_fn=lambda: torch.randperm(2).repeat(1, 2), - jacobian_input=False, fullname='Embedding_sparse', check_gradgrad=False, + has_sparse_gradients=True, ), dict( module_name='PixelShuffle', @@ -2439,6 +2530,12 @@ def fractional_max_pool3d_test(test_case): cpp_constructor_args='torch::nn::PixelShuffleOptions(3)', input_size=(1, 9, 4, 4), ), + dict( + module_name='PixelUnshuffle', + constructor_args=(3,), + cpp_constructor_args='torch::nn::PixelUnshuffleOptions(3)', + input_size=(1, 1, 12, 12), + ), dict( constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), cpp_options_args='''F::InterpolateFuncOptions() @@ -2905,7 +3002,7 @@ def fractional_max_pool3d_test(test_case): .scale_factor(std::vector({3., 3., 3.})) .mode(torch::kTrilinear) .align_corners(false)''', - input_size=(1, 2, 3, 4, 4), + input_size=(1, 2, 3, 4, 5), fullname='interpolate_trilinear_scale_3d', # See https://github.com/pytorch/pytorch/issues/5006 precision=3e-4, @@ -3501,6 +3598,72 @@ def fractional_max_pool3d_test(test_case): skip_double=TEST_WITH_ROCM, pickle=False, ), + dict( + module_name='TransformerEncoderLayer', + constructor_args=(4, 2, 16, 0.0), + cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2) + .dim_feedforward(16) + .dropout(0.0)''', + input_size=(2, 3, 4), + desc='relu_activation', + with_tf32=True, + tf32_precision=0.1, + ), + dict( + module_name='TransformerEncoderLayer', + constructor_args=(4, 2, 8, 0.0, 'gelu'), + cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2) + .dim_feedforward(8) + .dropout(0.0) + .activation(torch::kGELU)''', + input_size=(2, 3, 4), + check_gradgrad=False, + desc='gelu_activation', + with_tf32=True, + tf32_precision=0.01, + ), + dict( + module_name='TransformerDecoderLayer', + constructor_args=(4, 2, 8, 0.0), + cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2) + .dim_feedforward(8) + .dropout(0.0)''', + input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)), + check_gradgrad=False, + desc='relu_activation', + with_tf32=True, + tf32_precision=0.01, + ), + dict( + module_name='TransformerDecoderLayer', + constructor_args=(4, 2, 8, 0.0, 'gelu'), + cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2) + .dim_feedforward(8) + .dropout(0.0) + .activation(torch::kGELU)''', + input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)), + check_gradgrad=False, + desc='gelu_activation', + with_tf32=True, + tf32_precision=0.01, + ), + dict( + module_name='Transformer', + constructor_args=(4, 2, 2, 2, 8, 0.0, "relu"), + cpp_constructor_args='''torch::nn::TransformerOptions() + .d_model(4) + .nhead(2) + .num_encoder_layers(2) + .num_decoder_layers(2) + .dim_feedforward(8) + .dropout(0.0) + .activation(torch::kReLU)''', + input_fn=lambda:(torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)), + check_gradgrad=False, + desc='multilayer_coder', + with_tf32=True, + tf32_precision=0.01, + ) ] # add conv padding mode tests: @@ -3549,7 +3712,7 @@ def kldivloss_reference(input, target, reduction='mean'): return result.mean() elif reduction == 'sum': return result.sum() - elif reduction == 'batchmean' and results.dim() != 0: + elif reduction == 'batchmean' and result.dim() != 0: return result.sum() / result.size(0) return result @@ -3559,7 +3722,7 @@ def kldivloss_log_target_reference(input, target, reduction='mean'): return result.mean() elif reduction == 'sum': return result.sum() - elif reduction == 'batchmean' and results.dim() != 0: + elif reduction == 'batchmean' and result.dim() != 0: return result.sum() / result.size(0) return result @@ -3612,11 +3775,15 @@ def nll_loss_helper(input, target, weight, ignore_index): return losses_tensor -def smoothl1loss_reference(input, target, reduction='mean'): +def smoothl1loss_reference(input, target, reduction='mean', beta=1.0): abs_diff = (input - target).abs() - ge_one_mask = (abs_diff >= 1).type_as(abs_diff) - lt_one_mask = (abs_diff < 1).type_as(abs_diff) - output = ge_one_mask * (abs_diff - 0.5) + lt_one_mask * 0.5 * (abs_diff ** 2) + ge_beta_mask = (abs_diff >= beta).type_as(abs_diff) + lt_beta_mask = (abs_diff < beta).type_as(abs_diff) + # when beta <= 0 we should just use l1_loss + if beta == 0: + output = abs_diff + else: + output = ge_beta_mask * (abs_diff - 0.5 * beta) + lt_beta_mask * 0.5 * (abs_diff ** 2) / beta if reduction == 'mean': return output.mean() elif reduction == 'sum': @@ -3871,7 +4038,7 @@ def padding3d_circular(input, pad): return torch.cat([input[:, :, :, :, -pad[0]:], input, input[:, :, :, :, 0:pad[1]]], dim=4) -loss_reference_fns = { +loss_reference_fns: Dict['str', Callable] = { 'KLDivLoss': kldivloss_reference, 'KLDivLoss_log_target': kldivloss_log_target_reference, 'NLLLoss': nllloss_reference, @@ -3895,6 +4062,7 @@ def padding3d_circular(input, pad): target_fn=lambda: torch.randn((2, 3, 4), requires_grad=True), reference_fn=lambda i, t, _: 1. / i.numel() * sum((a - b).abs().sum() for a, b in zip(i, t)), + check_complex=True, ), dict( module_name='NLLLoss', @@ -4114,8 +4282,8 @@ def padding3d_circular(input, pad): input_size=(5, 10), target_fn=lambda: torch.randn((5, 10), requires_grad=True), check_sum_reduction=True, - reference_fn=lambda i, t, m: - smoothl1loss_reference(i, t, reduction=get_reduction(m)), + reference_fn=lambda i, t, m, b=1.0: + smoothl1loss_reference(i, t, reduction=get_reduction(m), beta=b), ), dict( module_name='SoftMarginLoss', @@ -4291,6 +4459,7 @@ def padding3d_circular(input, pad): target_fn=lambda: torch.randn((), requires_grad=True), reference_fn=lambda i, t, _: 1. / i.numel() * (i - t).abs().sum(), desc='scalar', + check_complex=True, ), dict( module_name='KLDivLoss', @@ -4355,8 +4524,8 @@ def padding3d_circular(input, pad): input_size=(), target_fn=lambda: torch.randn((), requires_grad=True), check_sum_reduction=True, - reference_fn=lambda i, t, m: - smoothl1loss_reference(i, t, reduction=get_reduction(m)), + reference_fn=lambda i, t, m, b=1.0: + smoothl1loss_reference(i, t, reduction=get_reduction(m), beta=b), desc='scalar', ), dict( @@ -4446,7 +4615,6 @@ def padding3d_circular(input, pad): check_sum_reduction=True, check_gradgrad=False, check_half=False, - convert_target=False, # `CTCLoss` in C++ frontend doesn't accept integer list for `input_lengths` or `target_lengths` test_cpp_api_parity=False, check_jit=False, @@ -4465,7 +4633,6 @@ def padding3d_circular(input, pad): check_sum_reduction=True, check_gradgrad=False, check_half=False, - convert_target=False, ), dict( module_name='CTCLoss', @@ -4481,13 +4648,32 @@ def padding3d_circular(input, pad): check_sum_reduction=True, check_gradgrad=False, check_half=False, - convert_target=False, ), ] class NNTestCase(TestCase): + # _forward is defined in classes inheriting from NNTestCase + @abstractmethod + def _forward(self, *args, **kwargs): + raise NotImplementedError + + @abstractmethod + def _get_parameters(self, module: nn.Module) -> Tuple[List[nn.Parameter], List[nn.Parameter]]: + raise NotImplementedError + + @abstractmethod + def _zero_grad_parameters(self, module: nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def _backward(self, module: nn.Module, + input: _TensorOrTensors, output: torch.Tensor, + grad_output: Union[torch.Tensor, Sequence[torch.Tensor]], + create_graph: bool = False): + raise NotImplementedError + def _jacobian(self, input, num_out): if isinstance(input, tuple): return tuple(self._jacobian(elem, num_out) for elem in input) @@ -4514,7 +4700,7 @@ def _zero_grad_input(self, input): for i in input: self._zero_grad_input(i) - def _analytical_jacobian(self, module, input, jacobian_input=True, jacobian_parameters=True): + def _analytical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True): output = self._forward(module, input) output_size = output.nelement() @@ -4548,7 +4734,7 @@ def _analytical_jacobian(self, module, input, jacobian_input=True, jacobian_para if jacobian_parameters: jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0) - res = tuple() + res: Tuple[torch.Tensor, ...] = tuple() if jacobian_input: res += jacobian_inp, if jacobian_parameters: @@ -4556,11 +4742,11 @@ def _analytical_jacobian(self, module, input, jacobian_input=True, jacobian_para return res - def _numerical_jacobian(self, module, input, jacobian_input=True, jacobian_parameters=True): + def _numerical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True): def fw(input): return self._forward(module, input).detach() - res = tuple() + res: Tuple[torch.Tensor, ...] = tuple() if jacobian_input: res += get_numerical_jacobian(fw, input, eps=1e-6), if jacobian_parameters: @@ -4568,19 +4754,20 @@ def fw(input): res += torch.cat([get_numerical_jacobian(fw, input, p, eps=1e-6) for p in param], 0), return res - def check_jacobian(self, module, input, jacobian_input=True): + def check_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True): jacobian_parameters = bool(self._get_parameters(module)[0]) analytical = self._analytical_jacobian(module, input, jacobian_input, jacobian_parameters) numerical = self._numerical_jacobian(module, input, jacobian_input, jacobian_parameters) analytical_t = list(iter_tensors(analytical)) numerical_t = list(iter_tensors(numerical)) - # TODO: compare structure - if input.numel() != 0: - self.assertLessEqual( - max(a.add(n, alpha=-1).abs().max() for a, n in zip(analytical_t, numerical_t)), - PRECISION - ) + differences = [] + for a, n in zip(analytical_t, numerical_t): + if a.numel() != 0: + differences.append(a.add(n, alpha=-1).abs().max()) + # TODO: compare structure (ensure analytic jacobian has correct shape) + if len(differences) > 0: + self.assertLessEqual(max(differences), PRECISION) # type: ignore[type-var] class TestBase(object): @@ -4663,6 +4850,10 @@ def __call__(self, test_case): class ModuleTest(TestBase): + @abstractmethod + def _do_test(self, test_case: Any, module: nn.Module, input: Any) -> Any: + raise NotImplementedError + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.jacobian_input = kwargs.get('jacobian_input', True) @@ -4691,7 +4882,7 @@ def __call__(self, test_case): if self.should_test_pickle: # TODO: do this with in-memory files as soon as torch.save will support it - with TemporaryFile() as f: + with tempfile.TemporaryFile() as f: test_case._forward(module, input) torch.save(module, f) f.seek(0) @@ -4703,6 +4894,8 @@ def __call__(self, test_case): def noncontiguize(self, obj): if isinstance(obj, list): return [self.noncontiguize(o) for o in obj] + elif isinstance(obj, tuple): + return tuple(self.noncontiguize(o) for o in obj) tensor = obj ndim = tensor.dim() # Always making only the last dimension noncontiguous is easy to hide @@ -4757,8 +4950,9 @@ def test_cuda(self, test_case): raise unittest.SkipTest('Excluded from CUDA tests') cpu_input = self._get_input() - type_map = {'torch.DoubleTensor': torch.cuda.FloatTensor} - gpu_input = to_gpu(cpu_input, type_map=type_map) + type_map = {torch.double: torch.float} + cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,) + gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map) cpu_module = self.constructor(*self.constructor_args) gpu_module = self.constructor(*self.constructor_args).float().cuda() @@ -4767,12 +4961,12 @@ def test_cuda(self, test_case): for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0]): gpu_p.data.copy_(cpu_p) - test_case._zero_grad_input(cpu_input) - test_case._zero_grad_input(gpu_input) + test_case._zero_grad_input(cpu_input_tuple) + test_case._zero_grad_input(gpu_input_tuple) test_case._zero_grad_parameters(cpu_module) test_case._zero_grad_parameters(gpu_module) - cpu_output = test_case._forward(cpu_module, cpu_input) - gpu_output = test_case._forward(gpu_module, gpu_input) + cpu_output = test_case._forward(cpu_module, cpu_input_tuple) + gpu_output = test_case._forward(gpu_module, gpu_input_tuple) # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 test_case.assertEqualIgnoreType(cpu_output, gpu_output, atol=self.precision, rtol=0) @@ -4780,8 +4974,8 @@ def test_cuda(self, test_case): for _ in range(5): cpu_gradOutput = cpu_output.clone().normal_() gpu_gradOutput = cpu_gradOutput.type('torch.cuda.FloatTensor') - cpu_gradInput = test_case._backward(cpu_module, cpu_input, cpu_output, cpu_gradOutput) - gpu_gradInput = test_case._backward(gpu_module, gpu_input, gpu_output, gpu_gradOutput) + cpu_gradInput = test_case._backward(cpu_module, cpu_input_tuple, cpu_output, cpu_gradOutput) + gpu_gradInput = test_case._backward(gpu_module, gpu_input_tuple, gpu_output, gpu_gradOutput) # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 test_case.assertEqualIgnoreType(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0) for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1]): @@ -4789,8 +4983,8 @@ def test_cuda(self, test_case): # Run double-backwards on CPU and GPU and compare results if self.check_gradgrad and not self.FIXME_no_cuda_gradgrad_comparison: - cpu_output = cpu_module(cpu_input) - gpu_output = gpu_module(gpu_input) + cpu_output = cpu_module(*cpu_input_tuple) + gpu_output = gpu_module(*gpu_input_tuple) cpu_gradOutput = torch.randn_like(cpu_output, requires_grad=True) gpu_gradOutput = cpu_gradOutput.type_as(gpu_output).detach() @@ -4798,12 +4992,12 @@ def test_cuda(self, test_case): cpu_gradInputs = torch.autograd.grad( cpu_output, - (cpu_input,) + tuple(cpu_module.parameters()), + cpu_input_tuple + tuple(cpu_module.parameters()), cpu_gradOutput, create_graph=True) gpu_gradInputs = torch.autograd.grad( gpu_output, - (gpu_input,) + tuple(gpu_module.parameters()), + gpu_input_tuple + tuple(gpu_module.parameters()), gpu_gradOutput, create_graph=True) @@ -4816,12 +5010,12 @@ def test_cuda(self, test_case): # are unreachable (which can happen if you differentiate # only on the gradient. cpu_gg = torch.autograd.grad( - cpu_output.sum() + sum(map(lambda x: x.sum(), cpu_gradInputs)), - (cpu_input, cpu_gradOutput) + tuple(cpu_module.parameters()), + cpu_output.sum() + sum(x.sum() for x in cpu_gradInputs), + cpu_input_tuple + (cpu_gradOutput,) + tuple(cpu_module.parameters()), retain_graph=True) gpu_gg = torch.autograd.grad( - gpu_output.sum() + sum(map(lambda x: x.sum(), gpu_gradInputs)), - (gpu_input, gpu_gradOutput) + tuple(gpu_module.parameters()), + gpu_output.sum() + sum(x.sum() for x in gpu_gradInputs), + gpu_input_tuple + (gpu_gradOutput,) + tuple(gpu_module.parameters()), retain_graph=True) # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 test_case.assertEqualIgnoreType(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0) @@ -4829,16 +5023,15 @@ def test_cuda(self, test_case): # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 test_case.assertEqualIgnoreType(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0) - self.test_noncontig(test_case, gpu_module, gpu_input) - + self.test_noncontig(test_case, gpu_module, gpu_input_tuple) class InputVariableMixin(object): def _get_input(self): - input = TestBase._get_input(self, False) + input = TestBase._get_input(self, False) # type: ignore[arg-type] def map_variables(i): if isinstance(i, torch.Tensor): - if i.is_floating_point(): + if i.is_floating_point() or i.is_complex(): i.requires_grad = True return i else: @@ -4847,7 +5040,7 @@ def map_variables(i): return map_variables(input) -class NewModuleTest(InputVariableMixin, ModuleTest): +class NewModuleTest(InputVariableMixin, ModuleTest): # type: ignore[misc] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.cudnn = kwargs.get('cudnn', False) @@ -4857,16 +5050,36 @@ def __init__(self, *args, **kwargs): self.with_tf32 = kwargs.get('with_tf32', False) self.tf32_precision = kwargs.get('tf32_precision', 0.001) self.test_cpu = kwargs.get('test_cpu', True) + self.has_sparse_gradients = kwargs.get('has_sparse_gradients', False) + + def _check_gradients(self, test_case, module, input_tuple): + params = tuple(x for x in module.parameters()) + num_inputs = len(input_tuple) + + def fn_to_gradcheck(*inputs_and_params, **kwargs): + assert not kwargs + return test_case._forward(module, inputs_and_params[:num_inputs]) + + # gradcheck doesn't support operators that take in dense inputs but + # return sparse parameters. This only happens in the case of nn.Embedding + # and nn.EmbeddingBag. Instead, we call `self.check_jacobian`, which + # is a slightly different version of gradcheck that can handle this. + if self.has_sparse_gradients: + assert num_inputs == 1 + test_input_jacobian = torch.is_floating_point(input_tuple[0]) + test_case.check_jacobian(module, input_tuple[0], test_input_jacobian) + else: + test_case.assertTrue(gradcheck(fn_to_gradcheck, input_tuple + params)) + + if self.check_gradgrad: + test_case.assertTrue(gradgradcheck(fn_to_gradcheck, input_tuple + params)) def _do_test(self, test_case, module, input): num_threads = torch.get_num_threads() torch.set_num_threads(1) - test_case.check_jacobian(module, input, self.jacobian_input) - if self.check_gradgrad: - # could probably unify check_jacobian above with this. - params = tuple(x for x in module.parameters()) - _assertGradAndGradgradChecks(test_case, - lambda x, *args, **kw: test_case._forward(module, x), (input,) + params) + input_tuple = input if isinstance(input, tuple) else (input,) + + self._check_gradients(test_case, module, input_tuple) # check if module can be printed module.__repr__() @@ -4875,6 +5088,11 @@ def _do_test(self, test_case, module, input): # check if the inplace variant of the module gives the same result # as the out-of-place + # check_inplace doesn't support multiple input tensors, since we don't have any modules + # that modify the inputs in-place and that accept more than one input + assert len(input_tuple) == 1 + input = input_tuple[0] + module_ip = self.constructor(*self.constructor_args, inplace=True) input_version = input._version @@ -4889,111 +5107,100 @@ def _do_test(self, test_case, module, input): test_case.assertNotEqual(input_ip_clone._version, input_version) test_case.assertEqual(output, output_ip) grad = output.data.clone().normal_() - input.grad.data.zero_() + if input.grad is not None: + with torch.no_grad(): + input.grad.zero_() output.backward(grad) output_ip.backward(grad) test_case.assertEqual(input.grad, input_ip.grad) - if isinstance(input, torch.LongTensor) and TEST_CUDA: + def assert_module_parameters_are(tensor_type, device_id=None): + for p in module.parameters(): + test_case.assertIsInstance(p, tensor_type) + if device_id is not None: + test_case.assertEqual(p.get_device(), device_id) + + if all(isinstance(t, torch.LongTensor) for t in input_tuple) and TEST_CUDA: # check that cuda() moves module parameters to correct GPU device, # and that float() casts parameters correctly - - input = input.cuda() + input_tuple = tuple(t.cuda() for t in input_tuple) module.float().cuda() - module(input) - for p in module.parameters(): - test_case.assertIsInstance(p, torch.cuda.FloatTensor) - test_case.assertEqual(p.get_device(), 0) + module(*input_tuple) + assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined] if torch.cuda.device_count() > 1: - input = input.cuda(1) + input_tuple = tuple(t.cuda(1) for t in input_tuple) module.cuda(1) with torch.cuda.device(1): - module(input) - for p in module.parameters(): - test_case.assertIsInstance(p, torch.cuda.FloatTensor) - test_case.assertEqual(p.get_device(), 1) + module(*input_tuple) + assert_module_parameters_are(torch.cuda.FloatTensor, 1) # type: ignore[attr-defined] else: # check that float()/double() casters work correctly # to float - if not isinstance(input, torch.LongTensor): - input = input.float() + input_tuple = tuple(t.float() if not isinstance(t, torch.LongTensor) else t for t in input_tuple) module.float() - module(input) - for p in module.parameters(): - test_case.assertIsInstance(p, torch.FloatTensor) + module(*input_tuple) + assert_module_parameters_are(torch.FloatTensor) # and back to double - if not isinstance(input, torch.LongTensor): - input = input.double() + input_tuple = tuple(t.double() if not isinstance(t, torch.LongTensor) else t for t in input_tuple) module.double() - module(input) - for p in module.parameters(): - test_case.assertIsInstance(p, torch.DoubleTensor) + module(*input_tuple) + assert_module_parameters_are(torch.DoubleTensor) if TEST_CUDA and self.should_test_cuda: # check that cuda() moves module parameters to correct GPU device, # and that float() casts parameters correctly # to GPU0 - input = input.float().cuda() + input_tuple = tuple( + t.float().cuda() if not isinstance(t, torch.LongTensor) else t.cuda() for t in input_tuple) module.float().cuda() - module(input) - for p in module.parameters(): - test_case.assertIsInstance(p, torch.cuda.FloatTensor) - test_case.assertEqual(p.get_device(), 0) + module(*input_tuple) + assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined] # to CPU - input = input.cpu() + input_tuple = tuple(t.cpu() for t in input_tuple) module.cpu() - module(input) - for p in module.parameters(): - test_case.assertIsInstance(p, torch.FloatTensor) + module(*input_tuple) + assert_module_parameters_are(torch.FloatTensor) # back to GPU0 - input = input.cuda() + input_tuple = tuple(t.cuda() for t in input_tuple) module.cuda() - module(input) - for p in module.parameters(): - test_case.assertIsInstance(p, torch.cuda.FloatTensor) - test_case.assertEqual(p.get_device(), 0) + module(*input_tuple) + assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined] # test that forwards of module runs correctly without cuDNN if self.cudnn: with torch.backends.cudnn.flags(enabled=False): - module(input) - for p in module.parameters(): - test_case.assertIsInstance(p, torch.cuda.FloatTensor) - test_case.assertEqual(p.get_device(), 0) + module(*input_tuple) + assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined] if torch.cuda.device_count() >= 2: # test cross-GPU transfer works # to GPU1 - input = input.cuda(1) + input_tuple = tuple(t.cuda(1) for t in input_tuple) module.cuda(1) with torch.cuda.device(1): - module(input) - for p in module.parameters(): - test_case.assertIsInstance(p, torch.cuda.FloatTensor) - test_case.assertEqual(p.get_device(), 1) + module(*input_tuple) + assert_module_parameters_are(torch.cuda.FloatTensor, 1) # type: ignore[attr-defined] if not self.skip_double: # test double() - input = input.double().cuda() + input_tuple = tuple( + t.double().cuda() if not isinstance(t, torch.LongTensor) else t.cuda() for t in input_tuple) module.double().cuda() - module(input) - for p in module.parameters(): - test_case.assertIsInstance(p, torch.cuda.DoubleTensor) - test_case.assertEqual(p.get_device(), 0) + module(*input_tuple) + assert_module_parameters_are(torch.cuda.DoubleTensor, 0) # type: ignore[attr-defined] # test half() - input = input.half().cuda() + input_tuple = tuple( + t.half().cuda() if not isinstance(t, torch.LongTensor) else t.cuda() for t in input_tuple) module.half().cuda() - module(input) - for p in module.parameters(): - test_case.assertIsInstance(p, torch.cuda.HalfTensor) - test_case.assertEqual(p.get_device(), 0) + module(*input_tuple) + assert_module_parameters_are(torch.cuda.HalfTensor, 0) # type: ignore[attr-defined] torch.set_num_threads(num_threads) def _get_target(self): @@ -5004,7 +5211,7 @@ def constructor_args(self): return self._get_arg('constructor_args', False) -class CriterionTest(InputVariableMixin, TestBase): +class CriterionTest(InputVariableMixin, TestBase): # type: ignore[misc] # TODO: check that criterions don't ignore grad_output _required_arg_names = TestBase._required_arg_names.union({'target'}) @@ -5016,7 +5223,7 @@ def __init__(self, *args, **kwargs): self.check_gradgrad = kwargs.get('check_gradgrad', True) self.check_half = kwargs.get('check_half', True) self.check_bfloat16 = kwargs.get('check_bfloat16', False) - self.convert_target = kwargs.get('convert_target', True) + self.check_complex = kwargs.get('check_complex', False) self.test_cpu = kwargs.get('test_cpu', True) self.with_tf32 = kwargs.get('with_tf32', True) self.tf32_precision = kwargs.get('tf32_precision', 0.001) @@ -5049,7 +5256,7 @@ def apply_fn(input, target, *params): else: inputs = input + params + (target,) - def apply_fn(input1, input2, target, *params): + def apply_fn(input1, input2, target, *params): # type: ignore[misc] return module(input1, input2, target) gradcheck(apply_fn, inputs) @@ -5057,12 +5264,10 @@ def apply_fn(input1, input2, target, *params): if self.check_gradgrad: gradgradcheck(apply_fn, inputs) - def test_cuda(self, test_case, dtype=None, extra_args=None): + def test_cuda(self, test_case, dtype, extra_args=None): def convert_dtype(obj, dtype, requires_grad=False): if isinstance(obj, torch.Tensor): return obj.detach().to(dtype=dtype).requires_grad_(requires_grad) - elif isinstance(obj, torch.Tensor): - return obj.to(dtype) elif isinstance(obj, tuple): return tuple(convert_dtype(o, dtype, requires_grad) for o in obj) else: @@ -5077,13 +5282,11 @@ def convert_dtype(obj, dtype, requires_grad=False): gpu_module = self.constructor(*self.constructor_args) # Convert input, target and module parameters to dtype - if dtype is not None: - cpu_input = convert_dtype(cpu_input, dtype, True) - # NLLLoss requires target to be LongTensor - if not isinstance(cpu_target, torch.LongTensor) and self.convert_target: - cpu_target = convert_dtype(cpu_target, dtype) - cpu_module.type(dtype) - gpu_module.type(dtype) + cpu_input = convert_dtype(cpu_input, dtype, True) + if cpu_target.is_floating_point() or cpu_target.is_complex(): + cpu_target = convert_dtype(cpu_target, dtype) + cpu_module.type(dtype) + gpu_module.type(dtype) # GPU setup gpu_input = to_gpu(cpu_input) @@ -5099,13 +5302,14 @@ def convert_dtype(obj, dtype, requires_grad=False): cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target, extra_args=extra_args) gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target, extra_args=extra_args) - # dtype can be None, so set precision in this way instead of a precision map + # dtype used to be able to be None, so set precision in this way instead of a precision map # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 test_case.assertEqualIgnoreType(cpu_output, gpu_output, atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0) - cpu_gradInput = test_case._backward_criterion(cpu_module, cpu_input, cpu_target, extra_args=extra_args) - gpu_gradInput = test_case._backward_criterion(gpu_module, gpu_input, gpu_target, extra_args=extra_args) + cpu_gradInput = test_case._backward_criterion(cpu_module, cpu_input, cpu_output, cpu_target, extra_args=extra_args) + gpu_gradInput = test_case._backward_criterion(gpu_module, gpu_input, gpu_output, gpu_target, extra_args=extra_args) + # dtype used to be able to be None, so set precision in this way instead of a precision map # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 test_case.assertEqualIgnoreType(cpu_gradInput, gpu_gradInput, atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0) diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 4031b2fdd0ded..eef9381d79d94 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -6,34 +6,31 @@ import torch.nn as nn import torch.nn.quantized as nnq import torch.nn.quantized.dynamic as nnqd +from torch.nn.intrinsic import _FusedModule import torch.distributed as dist from torch.testing._internal.common_utils import TestCase from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \ default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \ - propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_dynamic_qconfig, \ - get_default_qat_qconfig -from torch.quantization import ( - is_custom_module_class, - is_observed_custom_module, -) + propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_weight_only_qconfig, \ + get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic, QuantType from torch.quantization.quantization_mappings import ( - get_dynamic_quant_module_mappings, - get_qconfig_propagation_list, - get_qat_module_mappings, -) -# symbolic trace -from torch.fx import symbolic_trace - -# graph mode quantization based on fx -from torch.quantization import ( - QuantType, - prepare_fx, - prepare_dynamic_fx, - convert_fx, - convert_dynamic_fx, + get_default_dynamic_quant_module_mappings, + get_default_qconfig_propagation_list, + get_default_qat_module_mappings, ) +try: + # graph mode quantization based on fx + from torch.quantization.quantize_fx import ( + prepare_fx, + prepare_qat_fx, + convert_fx, + ) + HAS_FX = True +except ImportError: + HAS_FX = False + import copy import io import functools @@ -191,7 +188,7 @@ def run_ddp(rank, world_size, prepared): def convert_dynamic(module): - convert(module, get_dynamic_quant_module_mappings(), inplace=True) + convert(module, get_default_dynamic_quant_module_mappings(), inplace=True) def prepare_dynamic(model, qconfig_dict=None): propagate_qconfig_(model, qconfig_dict) @@ -342,12 +339,15 @@ def checkHasPrepModules(self, module): self.assertTrue(hasattr(module, 'quant')) self.assertTrue(hasattr(module, 'dequant')) - def checkObservers(self, module, propagate_qconfig_list=None): + def checkObservers(self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None): r"""Checks the module or module's leaf descendants have observers in preperation for quantization """ if propagate_qconfig_list is None: - propagate_qconfig_list = get_qconfig_propagation_list() + propagate_qconfig_list = get_default_qconfig_propagation_list() + if prepare_custom_config_dict is None: + prepare_custom_config_dict = {} + float_to_observed_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {}) # check if a module is a leaf module, ignoring activation_post_process attribute def is_leaf_module(module): @@ -357,18 +357,20 @@ def is_leaf_module(module): submodule_name_count += 1 return submodule_name_count == 0 - if (hasattr(module, 'qconfig') and module.qconfig is not None and - is_leaf_module(module) and not isinstance(module, torch.nn.Sequential) - and type(module) in propagate_qconfig_list) or \ - is_custom_module_class(type(module)): + if hasattr(module, 'qconfig') and module.qconfig is not None and \ + ((is_leaf_module(module) and not isinstance(module, torch.nn.Sequential) + and type(module) in propagate_qconfig_list) or + type(module) in float_to_observed_module_class_mapping.keys()) and \ + not isinstance(module, torch.quantization.DeQuantStub): self.assertTrue(hasattr(module, 'activation_post_process'), 'module: ' + str(type(module)) + ' do not have observer') # we don't need to check observers for child modules of the # qat modules - if type(module) not in get_qat_module_mappings().values() and \ - not is_observed_custom_module(module): + if type(module) not in get_default_qat_module_mappings().values() and \ + type(module) not in float_to_observed_module_class_mapping.values() and \ + not isinstance(module, _FusedModule): for child in module.children(): - self.checkObservers(child) + self.checkObservers(child, propagate_qconfig_list, prepare_custom_config_dict) def checkQuantDequant(self, mod): r"""Checks that mod has nn.Quantize and @@ -601,75 +603,105 @@ def printGraphModule(self, graph_module, print_str=True): print(str_to_print) return str_to_print - def checkGraphModeFxOp(self, model, inputs, quant_type, - expected_node=None, - expected_node_occurrence=None, - expected_node_list=None, - debug=False, - print_debug_info=False): - """ Quantizes model with graph mode quantization on fx and check if the - quantized model contains the quantized_node + if HAS_FX: + def checkGraphModeFxOp(self, model, inputs, quant_type, + expected_node=None, + expected_node_occurrence=None, + expected_node_list=None, + debug=False, + print_debug_info=False, + custom_qconfig=None, + prepare_expected_node=None, + prepare_expected_node_occurrence=None, + prepare_expected_node_list=None, + prepare_custom_config_dict=None): + """ Quantizes model with graph mode quantization on fx and check if the + quantized model contains the quantized_node + + Args: + model: floating point torch.nn.Module + inputs: one positional sample input arguments for model + expected_node: NodeSpec + e.g. NodeSpec.call_function(torch.quantize_per_tensor) + expected_node_occurrence: a dict from NodeSpec to + expected number of occurences (int) + e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1, + NodeSpec.call_method('dequantize'): 1} + expected_node_list: a list of NodeSpec, used to check the order + of the occurrence of Node + e.g. [NodeSpec.call_function(torch.quantize_per_tensor), + NodeSpec.call_module(nnq.Conv2d), + NodeSpec.call_function(F.hardtanh_), + NodeSpec.call_method('dequantize')] + debug: if True, enables debug mode + print_debug_info: if True, prints debug info + custom_qconfig: overrides default qconfig + prepare_expected_node: same as expected_node, but for prepare + prepare_expected_node_occurrence: same as + expected_node_occurrence, but for prepare + prepare_expected_node_list: same as expected_node_list, but + for prepare + """ + # TODO: make img_data a single example instead of a list + if type(inputs) == list: + inputs = inputs[0] + + if quant_type == QuantType.QAT: + qconfig = get_default_qat_qconfig(torch.backends.quantized.engine) + model.train() + elif quant_type == QuantType.STATIC: + qconfig = get_default_qconfig(torch.backends.quantized.engine) + model.eval() + else: + qconfig = default_dynamic_qconfig + model.eval() - Args: - model: floating point torch.nn.Module - inputs: one positional sample input arguments for model - expected_node: NodeSpec - e.g. NodeSpec.call_function(torch.quantize_per_tensor) - expected_node_occurrence: a dict from NodeSpec to - expected number of occurences (int) - e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1, - NodeSpec.call_method('dequantize'): 1} - expected_node_list: a list of NodeSpec, used to check the order - of the occurrence of Node - e.g. [NodeSpec.call_function(torch.quantize_per_tensor), - NodeSpec.call_module(nnq.Conv2d), - NodeSpec.call_function(F.hardtanh_), - NodeSpec.call_method('dequantize')] - """ - # TODO: make img_data a single example instead of a list - if type(inputs) == list: - inputs = inputs[0] - if quant_type == QuantType.QAT: - qconfig_dict = {'': get_default_qat_qconfig(torch.backends.quantized.engine)} - model.train() - else: - qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)} - model.eval() - original = symbolic_trace(model) + # overwrite qconfig with custom_qconfig + if custom_qconfig is not None: + qconfig = custom_qconfig - if quant_type == QuantType.DYNAMIC: - prepare = prepare_dynamic_fx - convert = convert_dynamic_fx - else: - prepare = prepare_fx - convert = convert_fx - - prepared = prepare(original, qconfig_dict) - prepared(*inputs) - qgraph = convert(prepared) - qgraph_debug = convert(prepared, debug=True) - - result = qgraph(*inputs) - result_debug = qgraph_debug(*inputs) - - self.assertEqual((result - result_debug).abs().max(), 0), \ - 'Expecting debug and non-debug option to produce identical result' - - if print_debug_info: - print() - print('quant type:', quant_type) - print('origianl graph module:', type(model)) - self.printGraphModule(original) - print() - print('quantized graph module:', type(qgraph)) - self.printGraphModule(qgraph) - print() - qgraph_to_check = qgraph_debug if debug else qgraph - self.checkGraphModuleNodes( - qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list) - - - def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indices, offsets, set_qconfig, is_emb_bag): + if quant_type == QuantType.QAT: + prepare = prepare_qat_fx + else: + prepare = prepare_fx + + qconfig_dict = {'': qconfig} + prepared = prepare( + model, qconfig_dict, + prepare_custom_config_dict=prepare_custom_config_dict) + if not quant_type == QuantType.DYNAMIC: + prepared(*inputs) + + if print_debug_info: + print() + print('quant type:\n', quant_type) + print('original model:\n', model) + print() + print('prepared model:\n', prepared) + + self.checkGraphModuleNodes( + prepared, prepare_expected_node, + prepare_expected_node_occurrence, prepare_expected_node_list) + + prepared_copy = copy.deepcopy(prepared) + qgraph = convert_fx(prepared) + qgraph_debug = convert_fx(prepared_copy, debug=True) + result = qgraph(*inputs) + result_debug = qgraph_debug(*inputs) + + qgraph_to_check = qgraph_debug if debug else qgraph + if print_debug_info: + print() + print('quantized model:\n', qgraph_to_check) + self.printGraphModule(qgraph_to_check) + print() + self.checkGraphModuleNodes( + qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list) + return result + + + def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indices, offsets, + set_qconfig, is_emb_bag, dtype=torch.quint8): # Test serialization of dynamic EmbeddingBag module using state_dict if is_emb_bag: inputs = [indices, offsets] @@ -692,9 +724,9 @@ def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indic # Check state dict serialization and torch.save APIs if is_emb_bag: loaded_qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim, - include_last_offset=True, mode='sum') + include_last_offset=True, mode='sum', dtype=dtype) else: - loaded_qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + loaded_qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype) self.check_eager_serialization(qemb, loaded_qemb, inputs) loaded_qemb.load_state_dict(loaded_dict) @@ -713,7 +745,11 @@ def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indic float_embedding = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) if set_qconfig: - float_embedding.qconfig = float_qparams_dynamic_qconfig + float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype, + qscheme=torch.per_channel_affine_float_qparams, + ch_axis=0) + float_embedding.qconfig = QConfigDynamic(activation=default_dynamic_quant_observer, + weight=float_qparams_observer) prepare_dynamic(float_embedding) @@ -730,8 +766,8 @@ def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indic self.assertTrue(expected_name in str(q_embeddingbag)) -# Below are a series of neural net models to use in testing quantization -# Single layer models +# Below are a series of toy models to use in testing quantization + class SingleLayerLinearModel(torch.nn.Module): def __init__(self): super().__init__() @@ -810,6 +846,15 @@ def forward(self, x): x = self.conv(x) return x +class ConvTransposeModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float) + + def forward(self, x): + x = self.conv(x) + return x + class AnnotatedConvModel(torch.nn.Module): def __init__(self, qengine): super().__init__() @@ -824,6 +869,20 @@ def forward(self, x): x = self.dequant(x) return x +class AnnotatedConvTransposeModel(torch.nn.Module): + def __init__(self, qengine): + super().__init__() + self.qconfig = torch.quantization.get_default_qconfig(qengine) + self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.conv(x) + x = self.dequant(x) + return x + class ConvBnModel(torch.nn.Module): def __init__(self): super().__init__() @@ -883,6 +942,17 @@ def forward(self, x): x = self.fc2(x) return x +class LinearModelWithSubmodule(nn.Module): + def __init__(self): + super(LinearModelWithSubmodule, self).__init__() + self.subm = TwoLayerLinearModel() + self.fc = nn.Linear(5, 5) + + def forward(self, x): + x = self.subm(x) + x = self.fc(x) + return x + class AnnotatedTwoLayerLinearModel(torch.nn.Module): def __init__(self): super().__init__() @@ -1304,7 +1374,7 @@ def __init__(self): self.downsample = torch.nn.Identity() self.myop = nn.quantized.FloatFunctional() self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - + self.fc = torch.nn.Linear(inplanes, 1) def forward(self, x): out = self.conv1(x) @@ -1314,8 +1384,13 @@ def forward(self, x): out = self.myop.add(out, identity) out = self.relu2(out) out = self.avgpool(out) + out = torch.flatten(out, 1) + out = self.fc(out) return out + def fuse_model(self): + torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True) + class ModelMultipleOps(torch.nn.Module): def __init__(self): super().__init__() @@ -1403,7 +1478,7 @@ def __init__(self): super().__init__() self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) self.fc = torch.nn.Linear(5, 5) - self.emb.qconfig = float_qparams_dynamic_qconfig + self.emb.qconfig = float_qparams_weight_only_qconfig self.qconfig = default_qconfig def forward(self, indices, linear_in): diff --git a/torch/testing/_internal/common_quantized.py b/torch/testing/_internal/common_quantized.py index 243cd964b96d3..f145565971286 100644 --- a/torch/testing/_internal/common_quantized.py +++ b/torch/testing/_internal/common_quantized.py @@ -102,6 +102,35 @@ def _calculate_dynamic_per_channel_qparams(X, dtype): return scale, zero_point +def _snr(x, x_hat): + """Calculates the signal to noise ratio and returns the signal and noise + power, as well as the SNR in dB. + If the input is a list/tuple this function is called recursively on each + element. The result will have the same nested structure as the inputs. + + Args: + x, x_hat: Either a tensor or a nested list/tuple of tensors. + Returns: + signal, noise, SNR(in dB): Either floats or a nested list of floats + """ + if isinstance(x, (list, tuple)): + assert(len(x) == len(x_hat)) + res = [] + for idx in range(len(x)): + res.append(_snr(x[idx], x_hat[idx])) + return res + if x_hat.is_quantized: + x_hat = x_hat.dequantize() + if x.is_quantized: + x = x.dequantize() + noise = (x - x_hat).norm() + if noise == 0: + return 0.0, float('inf'), float('inf') + signal = x.norm() + snr = signal / noise + snr_db = 20 * snr.log10() + return signal, noise, snr_db + @contextmanager def override_quantized_engine(qengine): previous = torch.backends.quantized.engine diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 9959551031ff3..236e1817465a2 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -15,15 +15,19 @@ from functools import partial import inspect import io +import copy +import operator import argparse import unittest import warnings import random import contextlib +import shutil import socket import subprocess import time from collections import OrderedDict +from collections.abc import Sequence from contextlib import contextmanager from functools import wraps from itertools import product @@ -32,9 +36,9 @@ import tempfile import json from urllib.request import urlopen -import __main__ +import __main__ # type: ignore[import] import errno -from typing import cast, Any, Iterable, Optional +from typing import cast, Any, Dict, Iterable, Iterator, Optional from torch.testing._internal import expecttest from torch.testing import \ @@ -53,7 +57,13 @@ torch.backends.disable_global_flags() +FILE_SCHEMA = "file://" +if sys.platform == 'win32': + FILE_SCHEMA = "file:///" + IS_SANDCASTLE = os.getenv('SANDCASTLE') == '1' or os.getenv('TW_JOB_USER') == 'sandcastle' +IS_FBCODE = os.getenv('PYTORCH_TEST_FBCODE') == '1' +IS_REMOTE_GPU = os.getenv('PYTORCH_TEST_REMOTE_GPU') == '1' class ProfilingMode(Enum): LEGACY = 1 @@ -123,8 +133,9 @@ def prof_func_call(*args, **kwargs): def prof_meth_call(*args, **kwargs): return prof_callable(meth_call, *args, **kwargs) -torch._C.ScriptFunction.__call__ = prof_func_call -torch._C.ScriptMethod.__call__ = prof_meth_call +# TODO fix when https://github.com/python/mypy/issues/2427 is address +torch._C.ScriptFunction.__call__ = prof_func_call # type: ignore[assignment] +torch._C.ScriptMethod.__call__ = prof_meth_call # type: ignore[assignment] def _get_test_report_path(): # allow users to override the test file location. We need this @@ -140,22 +151,22 @@ def _get_test_report_path(): help='whether to run each test in a subprocess') parser.add_argument('--seed', type=int, default=1234) parser.add_argument('--accept', action='store_true') -parser.add_argument('--ge_config', type=str) +parser.add_argument('--jit_executor', type=str) parser.add_argument('--repeat', type=int, default=1) parser.add_argument('--test_bailouts', action='store_true') parser.add_argument('--save-xml', nargs='?', type=str, const=_get_test_report_path(), - default=_get_test_report_path() if bool(os.environ.get('IN_CIRCLECI')) else None) + default=_get_test_report_path() if bool(os.environ.get('IN_CI')) else None) parser.add_argument('--discover-tests', action='store_true') parser.add_argument('--log-suffix', type=str, default="") parser.add_argument('--run-parallel', type=int, default=1) args, remaining = parser.parse_known_args() -if args.ge_config == 'legacy': +if args.jit_executor == 'legacy': GRAPH_EXECUTOR = ProfilingMode.LEGACY -elif args.ge_config == 'profiling': +elif args.jit_executor == 'profiling': GRAPH_EXECUTOR = ProfilingMode.PROFILING -elif args.ge_config == 'simple': +elif args.jit_executor == 'simple': GRAPH_EXECUTOR = ProfilingMode.SIMPLE else: # infer flags based on the default settings @@ -272,7 +283,7 @@ def run_tests(argv=UNITTEST_ARGS): assert not failed, "Some test shards have failed" elif TEST_SAVE_XML is not None: # import here so that non-CI doesn't need xmlrunner installed - import xmlrunner + import xmlrunner # type: ignore[import] test_report_path = TEST_SAVE_XML + LOG_SUFFIX os.makedirs(test_report_path, exist_ok=True) verbose = '--verbose' in argv or '-v' in argv @@ -292,11 +303,16 @@ def run_tests(argv=UNITTEST_ARGS): if IS_WINDOWS: @contextmanager - def TemporaryFileName(): + def TemporaryFileName(*args, **kwargs): # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile # opens the file, and it cannot be opened multiple times in Windows. To support Windows, # close the file after creation and try to remove it manually - f = tempfile.NamedTemporaryFile(delete=False) + if 'delete' in kwargs: + if kwargs['delete'] is not False: + raise UserWarning("only TemporaryFileName with delete=False is supported on Windows.") + else: + kwargs['delete'] = False + f = tempfile.NamedTemporaryFile(*args, **kwargs) try: f.close() yield f.name @@ -304,10 +320,27 @@ def TemporaryFileName(): os.unlink(f.name) else: @contextmanager # noqa: T484 - def TemporaryFileName(): - with tempfile.NamedTemporaryFile() as f: + def TemporaryFileName(*args, **kwargs): + with tempfile.NamedTemporaryFile(*args, **kwargs) as f: yield f.name +if IS_WINDOWS: + @contextmanager + def TemporaryDirectoryName(suffix=None): + # On Windows the directory created by TemporaryDirectory is likely to be removed prematurely, + # so we first create the directory using mkdtemp and then remove it manually + try: + dir_name = tempfile.mkdtemp(suffix=suffix) + yield dir_name + finally: + shutil.rmtree(dir_name) +else: + @contextmanager # noqa: T484 + def TemporaryDirectoryName(suffix=None): + with tempfile.TemporaryDirectory(suffix=suffix) as d: + yield d + +IS_FILESYSTEM_UTF8_ENCODING = sys.getfilesystemencoding() == 'utf-8' def _check_module_exists(name): r"""Returns if a top-level module with :attr:`name` exists *without** @@ -316,7 +349,6 @@ def _check_module_exists(name): our tests, e.g., setting multiprocessing start method when imported (see librosa/#747, torchvision/#544). """ - import importlib import importlib.util spec = importlib.util.find_spec(name) return spec is not None @@ -390,43 +422,73 @@ def wrapper(*args, **kwargs): fn(*args, **kwargs) return wrapper +# Context manager for setting deterministic flag and automatically +# resetting it to its original value +class DeterministicGuard: + def __init__(self, deterministic): + self.deterministic = deterministic + + def __enter__(self): + self.deterministic_restore = torch.is_deterministic() + torch.set_deterministic(self.deterministic) + + def __exit__(self, exception_type, exception_value, traceback): + torch.set_deterministic(self.deterministic_restore) + # This decorator can be used for API tests that call torch.set_deterministic(). # When the test is finished, it will restore the previous deterministic flag -# setting. Also, if CUDA >= 10.2, this will set the environment variable -# CUBLAS_WORKSPACE_CONFIG=:4096:8 so that the error associated with that setting -# is not thrown during the test unless the test changes that variable on purpose. -# The previous CUBLAS_WORKSPACE_CONFIG setting will also be restored once the -# test is finished. +# setting. +# +# If CUDA >= 10.2, this will set the environment variable +# CUBLAS_WORKSPACE_CONFIG=:4096:8 so that the error associated with that +# setting is not thrown during the test unless the test changes that variable +# on purpose. The previous CUBLAS_WORKSPACE_CONFIG setting will also be +# restored once the test is finished. +# +# Note that if a test requires CUDA to actually register the changed +# CUBLAS_WORKSPACE_CONFIG variable, a new subprocess must be created, because +# CUDA only checks the variable when the runtime initializes. Tests can be +# run inside a subprocess like so: +# +# import subprocess, sys, os +# script = ''' +# # Test code should go here +# ''' +# try: +# subprocess.check_output( +# [sys.executable, '-c', script], +# stderr=subprocess.STDOUT, +# cwd=os.path.dirname(os.path.realpath(__file__)), +# env=os.environ.copy()) +# except subprocess.CalledProcessError as e: +# error_message = e.output.decode('utf-8') +# # Handle exceptions raised by the subprocess here +# def wrapDeterministicFlagAPITest(fn): @wraps(fn) def wrapper(*args, **kwargs): - deterministic_restore = torch.is_deterministic() - - is_cuda10_2_or_higher = ( - (torch.version.cuda is not None) - and ([int(x) for x in torch.version.cuda.split(".")] >= [10, 2])) - - if is_cuda10_2_or_higher: - cublas_var_name = 'CUBLAS_WORKSPACE_CONFIG' - cublas_config_restore = os.environ.get(cublas_var_name) - os.environ[cublas_var_name] = ':4096:8' - - def restore(): - torch.set_deterministic(deterministic_restore) - if is_cuda10_2_or_higher: - cur_cublas_config = os.environ.get(cublas_var_name) - if cublas_config_restore is None: - if cur_cublas_config is not None: - del os.environ[cublas_var_name] - else: - os.environ[cublas_var_name] = cublas_config_restore - try: - fn(*args, **kwargs) - except RuntimeError: - restore() - raise - else: - restore() + with DeterministicGuard(torch.is_deterministic()): + class CuBLASConfigGuard: + cublas_var_name = 'CUBLAS_WORKSPACE_CONFIG' + + def __enter__(self): + self.is_cuda10_2_or_higher = ( + (torch.version.cuda is not None) + and ([int(x) for x in torch.version.cuda.split(".")] >= [10, 2])) + if self.is_cuda10_2_or_higher: + self.cublas_config_restore = os.environ.get(self.cublas_var_name) + os.environ[self.cublas_var_name] = ':4096:8' + + def __exit__(self, exception_type, exception_value, traceback): + if self.is_cuda10_2_or_higher: + cur_cublas_config = os.environ.get(self.cublas_var_name) + if self.cublas_config_restore is None: + if cur_cublas_config is not None: + del os.environ[self.cublas_var_name] + else: + os.environ[self.cublas_var_name] = self.cublas_config_restore + with CuBLASConfigGuard(): + fn(*args, **kwargs) return wrapper def skipIfCompiledWithoutNumpy(fn): @@ -506,6 +568,11 @@ def wrapper(*args, **kwargs): return wrapper +def slowAwareTest(fn): + fn.__dict__['slow_test'] = True + return fn + + def skipCUDAMemoryLeakCheckIf(condition): def dec(fn): if getattr(fn, '_do_cuda_memory_leak_check', True): # if current True @@ -529,28 +596,14 @@ def wrapper(*args, **kwargs): return wrapper -def get_cpu_type(type_name): - module, name = type_name.rsplit('.', 1) - assert module == 'torch.cuda' - return getattr(torch, name) - - -def get_gpu_type(type_name): - if isinstance(type_name, type): - type_name = '{}.{}'.format(type_name.__module__, type_name.__name__) - module, name = type_name.rsplit('.', 1) - assert module == 'torch' - return getattr(torch.cuda, name) - - def to_gpu(obj, type_map=None): if type_map is None: type_map = {} if isinstance(obj, torch.Tensor): assert obj.is_leaf - t = type_map.get(obj.type(), get_gpu_type(obj.type())) + t = type_map.get(obj.dtype, obj.dtype) with torch.no_grad(): - res = obj.clone().type(t) + res = obj.clone().to(dtype=t, device="cuda") res.requires_grad = obj.requires_grad return res elif torch.is_storage(obj): @@ -701,11 +754,11 @@ def settings(*args, **kwargs): except ImportError: print('Fail to import hypothesis in common_utils, tests are not derandomized') -disabled_test_from_issues = None +disabled_test_from_issues: Optional[Dict[str, Any]] = None def check_disabled(test_name): global disabled_test_from_issues if disabled_test_from_issues is None: - disabled_test_from_issues = {} + _disabled_test_from_issues: Dict = {} def read_and_process(): url = 'https://raw.githubusercontent.com/zdevito/pytorch_disabled_tests/master/result.json' @@ -716,18 +769,21 @@ def read_and_process(): key = 'DISABLED ' if title.startswith(key): test_name = title[len(key):].strip() - disabled_test_from_issues[test_name] = item['html_url'] + _disabled_test_from_issues[test_name] = item['html_url'] if not IS_SANDCASTLE and os.getenv("PYTORCH_RUN_DISABLED_TESTS", "0") != "1": try: read_and_process() + disabled_test_from_issues = _disabled_test_from_issues except Exception: print("Couldn't download test skip set, leaving all tests enabled...") + disabled_test_from_issues = {} - if test_name in disabled_test_from_issues: - raise unittest.SkipTest( - "Test is disabled because an issue exists disabling it: {}".format(disabled_test_from_issues[test_name]) + - " To enable set the environment variable PYTORCH_RUN_DISABLED_TESTS=1") + if disabled_test_from_issues is not None: + if test_name in disabled_test_from_issues: + raise unittest.SkipTest( + "Test is disabled because an issue exists disabling it: {}".format(disabled_test_from_issues[test_name]) + + " To enable set the environment variable PYTORCH_RUN_DISABLED_TESTS=1") # Acquires the comparison dtype, required since isclose # requires both inputs have the same dtype, and isclose is not supported @@ -785,7 +841,7 @@ def __init__(self, method_name='runTest'): # Wraps the tested method if we should enforce non default CUDA stream. self._do_cuda_non_default_stream &= getattr(test_method, '_do_cuda_non_default_stream', True) - if self._do_cuda_non_default_stream and not IS_WINDOWS and not TEST_WITH_ROCM: + if self._do_cuda_non_default_stream and not IS_WINDOWS: self.wrap_with_cuda_policy(method_name, self.enforceNonDefaultStream) def assertLeaksNoCudaTensors(self, name=None): @@ -845,7 +901,6 @@ def genSparseTensor(self, size, sparse_dim, nnz, is_uncoalesced, device='cpu'): if is_uncoalesced: v = torch.cat([v, torch.randn_like(v)], 0) i = torch.cat([i, i], 1) - x = torch.sparse_coo_tensor(i, v, torch.Size(size)) if not is_uncoalesced: @@ -861,58 +916,16 @@ def genSparseTensor(self, size, sparse_dim, nnz, is_uncoalesced, device='cpu'): return x, x._indices().clone(), x._values().clone() def safeToDense(self, t): - r = self.safeCoalesce(t) - return r.to_dense() - - def safeCoalesce(self, t): - tc = t.coalesce() - self.assertEqual(tc.to_dense(), t.to_dense()) - self.assertTrue(tc.is_coalesced()) - - # Our code below doesn't work when nnz is 0, because - # then it's a 0D tensor, not a 2D tensor. - if t._nnz() == 0: - self.assertEqual(t._indices(), tc._indices()) - self.assertEqual(t._values(), tc._values()) - return tc - - value_map = {} - for idx, val in zip(t._indices().t(), t._values()): - idx_tup = tuple(idx.tolist()) - if idx_tup in value_map: - value_map[idx_tup] += val - else: - value_map[idx_tup] = val.clone() if isinstance(val, torch.Tensor) else val - - new_indices = sorted(list(value_map.keys())) - new_values = [value_map[idx] for idx in new_indices] - if t._values().ndimension() < 2: - new_values = t._values().new(new_values) - else: - new_values = torch.stack(new_values) - - new_indices = t._indices().new(new_indices).t() - tg = t.new(new_indices, new_values, t.size()) - - self.assertEqual(tc._indices(), tg._indices()) - self.assertEqual(tc._values(), tg._values()) - - if t.is_coalesced(): - self.assertEqual(tc._indices(), t._indices()) - self.assertEqual(tc._values(), t._values()) - - return tg + return t.coalesce().to_dense() # Compares the given Torch and NumPy functions on the given tensor-like object. # NOTE: both torch_fn and np_fn should be functions that take a single # tensor (array). If the torch and/or NumPy function require additional # arguments then wrap the function in a lambda or pass a partial function. - # TODO: support bfloat16 comparisons # TODO: add args/kwargs for passing to assertEqual (e.g. rtol, atol) def compare_with_numpy(self, torch_fn, np_fn, tensor_like, device=None, dtype=None, **kwargs): assert TEST_NUMPY - assert dtype is not torch.bfloat16 if isinstance(tensor_like, torch.Tensor): assert device is None @@ -920,7 +933,9 @@ def compare_with_numpy(self, torch_fn, np_fn, tensor_like, a = tensor_like.detach().cpu().numpy() t = tensor_like else: - a = np.array(tensor_like, dtype=torch_to_numpy_dtype_dict[dtype]) + d = copy.copy(torch_to_numpy_dtype_dict) + d[torch.bfloat16] = np.float32 + a = np.array(tensor_like, dtype=d[dtype]) t = torch.tensor(tensor_like, device=device, dtype=dtype) np_result = np_fn(a) @@ -934,6 +949,8 @@ def compare_with_numpy(self, torch_fn, np_fn, tensor_like, # NOTE: copying an array before conversion is necessary when, # for example, the array has negative strides. np_result = torch.from_numpy(np_result.copy()) + if dtype is torch.bfloat16 and torch_result.dtype is torch.bfloat16 and np_result.dtype is torch.float: + torch_result = torch_result.to(torch.float) self.assertEqual(np_result, torch_result, **kwargs) @@ -1027,9 +1044,19 @@ def _compareScalars(self, a, b, *, rtol, atol = self._getDefaultRtolAndAtol(torch.float32, torch.float32) else: rtol, atol = 0, 0 + rtol = cast(float, rtol) + atol = cast(float, atol) + assert atol is not None atol = max(atol, self.precision) - return _compare_scalars_internal(a, b, rtol=cast(float, rtol), atol=cast(float, atol), equal_nan=equal_nan) + return _compare_scalars_internal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + + # Construct assert messages basd on internal debug message and user provided message. + def _get_assert_msg(self, msg, debug_msg=None): + if msg is None: + return debug_msg + else: + return f"\n{msg}" if debug_msg is None else f"{debug_msg}\n{msg}" def assertEqualIgnoreType(self, *args, **kwargs) -> None: # If you are seeing this function used, that means test is written wrongly @@ -1041,7 +1068,8 @@ def assertEqualIgnoreType(self, *args, **kwargs) -> None: def assertEqual(self, x, y, msg: Optional[str] = None, *, atol: Optional[float] = None, rtol: Optional[float] = None, equal_nan=True, exact_dtype=True, exact_device=False) -> None: - assert (atol is None) == (rtol is None), "If one of atol or rtol is specified the other must be, too" + assert (atol is None) == (rtol is None), "If one of atol or rtol is specified, then the other must be too" + debug_msg: Optional[str] = None # Tensor x Number and Number x Tensor comparisons if isinstance(x, torch.Tensor) and isinstance(y, Number): @@ -1057,32 +1085,42 @@ def assertEqual(self, x, y, msg: Optional[str] = None, *, elif isinstance(y, torch.Tensor) and isinstance(x, np.bool_): self.assertEqual(x, y.item(), atol=atol, rtol=rtol, msg=msg, exact_dtype=exact_dtype, exact_device=exact_device) + # Tensor x Tensor elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor): - super().assertEqual(x.is_sparse, y.is_sparse, msg=msg) - super().assertEqual(x.is_quantized, y.is_quantized, msg=msg) + debug_msg = ("Attempted to compare with different is_sparse settings: " + f"Expected: {x.is_sparse}; Actual: {y.is_sparse}.") + super().assertEqual(x.is_sparse, y.is_sparse, msg=self._get_assert_msg(msg=msg, debug_msg=debug_msg)) + debug_msg = ("Attempted to compare with different is_quantized settings: " + f"Expected: {x.is_quantized}; Actual: {y.is_quantized}.") + super().assertEqual(x.is_quantized, y.is_quantized, msg=self._get_assert_msg(msg=msg, debug_msg=debug_msg)) if x.is_sparse: - x = self.safeCoalesce(x) - y = self.safeCoalesce(y) - indices_result, debug_msg = self._compareTensors(x._indices(), y._indices(), - rtol=rtol, atol=atol, - equal_nan=equal_nan, exact_dtype=exact_dtype, - exact_device=exact_device) - - if not indices_result and msg is None: - assert debug_msg is not None - msg = "Sparse tensor indices failed to compare as equal! " + debug_msg - self.assertTrue(indices_result, msg=msg) - - values_result, debug_msg = self._compareTensors(x._values(), y._values(), - rtol=rtol, atol=atol, - equal_nan=equal_nan, exact_dtype=exact_dtype, - exact_device=exact_device) - - if not values_result and msg is None: - assert debug_msg is not None - msg = "Sparse tensor values failed to compare as equal! " + debug_msg - self.assertTrue(values_result, msg=msg) + if x.size() != y.size(): + debug_msg_sparse = ("Attempted to compare equality of tensors with different sizes: " + f"Expected: {x.size()}; Actual: {y.size()}.") + super().assertTrue(False, msg=self._get_assert_msg(msg=msg, debug_msg=debug_msg_sparse)) + + x = x.coalesce() + y = y.coalesce() + indices_result, debug_msg_indices = self._compareTensors(x._indices(), y._indices(), + rtol=rtol, atol=atol, + equal_nan=equal_nan, exact_dtype=exact_dtype, + exact_device=exact_device) + + if not indices_result: + assert debug_msg_indices is not None + debug_msg = "Sparse tensor indices failed to compare as equal! " + debug_msg_indices + super().assertTrue(indices_result, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) + + values_result, debug_msg_values = self._compareTensors(x._values(), y._values(), + rtol=rtol, atol=atol, + equal_nan=equal_nan, exact_dtype=exact_dtype, + exact_device=exact_device) + + if not values_result: + assert debug_msg_values is not None + debug_msg = "Sparse tensor values failed to compare as equal! " + debug_msg_values + super().assertTrue(values_result, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) elif x.is_quantized and y.is_quantized: self.assertEqual(x.qscheme(), y.qscheme(), atol=atol, rtol=rtol, msg=msg, exact_dtype=exact_dtype, @@ -1106,29 +1144,33 @@ def assertEqual(self, x, y, msg: Optional[str] = None, *, atol=atol, rtol=rtol, msg=msg, exact_dtype=exact_dtype, exact_device=exact_device) - result, debug_msg = self._compareTensors(x.int_repr().to(torch.int32), - y.int_repr().to(torch.int32), - atol=atol, rtol=rtol, - exact_dtype=exact_dtype, - exact_device=exact_device) + result, debug_msg_compare = self._compareTensors(x.int_repr().to(torch.int32), + y.int_repr().to(torch.int32), + atol=atol, rtol=rtol, + exact_dtype=exact_dtype, + exact_device=exact_device) - if not result and msg is None: - assert debug_msg is not None - msg = "Quantized representations failed to compare as equal! " + debug_msg - self.assertTrue(result, msg=msg) + if not result: + assert debug_msg_compare is not None + debug_msg = "Quantized representations failed to compare as equal! " + debug_msg_compare + super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) else: - result, debug_msg = self._compareTensors(x, y, rtol=rtol, atol=atol, - equal_nan=equal_nan, exact_dtype=exact_dtype, - exact_device=exact_device) - - if not result and msg is None: - assert debug_msg is not None - msg = "Tensors failed to compare as equal! " + debug_msg - self.assertTrue(result, msg=msg) + result, debug_msg_generic = self._compareTensors(x, y, rtol=rtol, atol=atol, + equal_nan=equal_nan, exact_dtype=exact_dtype, + exact_device=exact_device) + + if not result: + assert debug_msg_generic is not None + debug_msg = "Tensors failed to compare as equal!" + debug_msg_generic + super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) elif isinstance(x, string_classes) and isinstance(y, string_classes): - super().assertEqual(x, y, msg=msg) + debug_msg = ("Attempted to compare [string] types: " + f"Expected: {repr(x)}; Actual: {repr(y)}.") + super().assertEqual(x, y, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) elif type(x) == set and type(y) == set: - super().assertEqual(x, y, msg=msg) + debug_msg = ("Attempted to compare [set] types: " + f"Expected: {x}; Actual: {y}.") + super().assertEqual(x, y, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) elif isinstance(x, dict) and isinstance(y, dict): if isinstance(x, OrderedDict) and isinstance(y, OrderedDict): self.assertEqual(x.items(), y.items(), atol=atol, rtol=rtol, @@ -1145,28 +1187,44 @@ def assertEqual(self, x, y, msg: Optional[str] = None, *, exact_dtype=exact_dtype, exact_device=exact_device) elif isinstance(x, type) and isinstance(y, type): # See TestTorch.test_assert_equal_generic_meta - super().assertEqual(x, y, msg=msg) + debug_msg = ("Attempted to compare [type] types: " + f"Expected: {x}; Actual: {y}.") + super().assertEqual(x, y, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) elif is_iterable(x) and is_iterable(y): - super().assertEqual(len(x), len(y), msg=msg) + debug_msg = ("Attempted to compare the lengths of [iterable] types: " + f"Expected: {len(x)}; Actual: {len(y)}.") + super().assertEqual(len(x), len(y), msg=self._get_assert_msg(msg, debug_msg=debug_msg)) for x_, y_ in zip(x, y): self.assertEqual(x_, y_, atol=atol, rtol=rtol, msg=msg, exact_dtype=exact_dtype, exact_device=exact_device) elif isinstance(x, bool) and isinstance(y, bool): - self.assertTrue(x == y, msg=msg) + super().assertTrue(x == y, msg=msg) # Scalar x Scalar elif isinstance(x, Number) and isinstance(y, Number): - result, debug_msg = self._compareScalars(x, y, rtol=rtol, atol=atol, - equal_nan=equal_nan) - if not result and msg is None: - assert debug_msg is not None - msg = "Scalars failed to compare as equal! " + debug_msg - self.assertTrue(result, msg=msg) + result, debug_msg_scalars = self._compareScalars(x, y, rtol=rtol, atol=atol, + equal_nan=equal_nan) + if not result: + assert debug_msg_scalars is not None + debug_msg = "Scalars failed to compare as equal! " + debug_msg_scalars + super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) + # Tensor x Numpy array + elif isinstance(x, torch.Tensor) and isinstance(y, np.ndarray): + self.assertEqual(x, torch.from_numpy(y), atol=atol, rtol=rtol, msg=msg, + exact_dtype=exact_dtype, exact_device=exact_device) + # Numpy array x Tensor + elif isinstance(x, np.ndarray) and isinstance(y, torch.Tensor): + self.assertEqual(torch.from_numpy(x), y, atol=atol, rtol=rtol, msg=msg, + exact_dtype=exact_dtype, exact_device=exact_device) + # Numpy array x Numpy array + elif isinstance(x, np.ndarray) and isinstance(y, np.ndarray): + self.assertEqual(torch.from_numpy(x), torch.from_numpy(y), atol=atol, rtol=rtol, msg=msg, + exact_dtype=exact_dtype, exact_device=exact_device) else: super().assertEqual(x, y, msg=msg) - def assertNotEqual(self, x, y, msg: Optional[str] = None, *, - atol: Optional[float] = None, rtol: Optional[float] = None, **kwargs) -> None: + def assertNotEqual(self, x, y, msg: Optional[str] = None, *, # type: ignore[override] + atol: Optional[float] = None, rtol: Optional[float] = None, **kwargs) -> None: # type: ignore[override] with self.assertRaises(AssertionError, msg=msg): self.assertEqual(x, y, msg, atol=atol, rtol=rtol, **kwargs) @@ -1225,7 +1283,7 @@ def maybeWarnsRegex(self, category, regex=''): msg = 'Caught unexpected warnings:\n' for w in ws: msg += warnings.formatwarning( - w.message, w.category, w.filename, w.lineno, w.line) + str(w.message), w.category, w.filename, w.lineno, w.line) msg += '\n' self.fail(msg) @@ -1313,29 +1371,30 @@ def assertExpectedStripMangled(self, s, subname=None): s = re.sub(r'__torch__[^ ]+', '', s) self.assertExpected(s, subname) - # returns captured stderr + # run code in subprocess and capture exceptions. @staticmethod - def runWithPytorchAPIUsageStderr(code): + def run_process_no_exception(code, env=None): import subprocess - env = os.environ.copy() - env["PYTORCH_API_USAGE_STDERR"] = "1" - pipes = subprocess.Popen( + popen = subprocess.Popen( [sys.executable, '-c', code], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) - return pipes.communicate()[1].decode('ascii') - - if sys.version_info < (3, 2): - # assertRegexpMatches renamed to assertRegex in 3.2 - assertRegex = unittest.TestCase.assertRegexpMatches - # assertRaisesRegexp renamed to assertRaisesRegex in 3.2 - assertRaisesRegex = unittest.TestCase.assertRaisesRegexp + (stdout, stderr) = popen.communicate() + return (stdout, stderr) - if sys.version_info < (3, 5): - # assertNotRegexpMatches renamed to assertNotRegex in 3.5 - assertNotRegex = unittest.TestCase.assertNotRegexpMatches + # returns captured stderr + @staticmethod + def runWithPytorchAPIUsageStderr(code): + env = os.environ.copy() + env["PYTORCH_API_USAGE_STDERR"] = "1" + # remove IN_CI flag since this is a wrapped test process. + # IN_CI flag should be set in the parent process only. + if "IN_CI" in env.keys(): + del env["IN_CI"] + (stdout, stderr) = TestCase.run_process_no_exception(code, env=env) + return stderr.decode('ascii') def download_file(url, binary=True): @@ -1483,7 +1542,7 @@ def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'): s[i] = 0 elif s[i] == 0: s[i] = 1 - return u.mm(torch.diag(s)).mm(v.transpose(0, 1)) + return u.mm(torch.diag(s).to(dtype)).mm(v.transpose(0, 1)) def random_symmetric_matrix(l, *batches, **kwargs): @@ -1494,6 +1553,14 @@ def random_symmetric_matrix(l, *batches, **kwargs): return A +def random_hermitian_matrix(l, *batches, **kwargs): + dtype = kwargs.get('dtype', torch.double) + device = kwargs.get('device', 'cpu') + A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device) + A = (A + A.transpose(-2, -1).conj()).div_(2) + return A + + def random_symmetric_psd_matrix(l, *batches, **kwargs): dtype = kwargs.get('dtype', torch.double) device = kwargs.get('device', 'cpu') @@ -1501,6 +1568,17 @@ def random_symmetric_psd_matrix(l, *batches, **kwargs): return torch.matmul(A, A.transpose(-2, -1)) +def random_hermitian_psd_matrix(matrix_size, *batch_dims, dtype=torch.double, device='cpu'): + """ + Returns a batch of random Hermitian semi-positive-definite matrices. + The shape of the result is batch_dims + (matrix_size, matrix_size) + The following example creates a tensor of size 2 x 4 x 3 x 3 + >>> matrices = random_hermitian_psd_matrix(3, 2, 4, dtype=dtype, device=device) + """ + A = torch.randn(*(batch_dims + (matrix_size, matrix_size)), dtype=dtype, device=device) + return torch.matmul(A, A.conj().transpose(-2, -1)) + + def random_symmetric_pd_matrix(matrix_size, *batch_dims, **kwargs): dtype = kwargs.get('dtype', torch.double) device = kwargs.get('device', 'cpu') @@ -1510,6 +1588,19 @@ def random_symmetric_pd_matrix(matrix_size, *batch_dims, **kwargs): + torch.eye(matrix_size, dtype=dtype, device=device) * 1e-5 +def random_hermitian_pd_matrix(matrix_size, *batch_dims, dtype, device): + """ + Returns a batch of random Hermitian positive-definite matrices. + The shape of the result is batch_dims + (matrix_size, matrix_size) + The following example creates a tensor of size 2 x 4 x 3 x 3 + >>> matrices = random_hermitian_pd_matrix(3, 2, 4, dtype=dtype, device=device) + """ + A = torch.randn(*(batch_dims + (matrix_size, matrix_size)), + dtype=dtype, device=device) + return torch.matmul(A, A.transpose(-2, -1).conj()) \ + + torch.eye(matrix_size, dtype=dtype, device=device) + + def make_nonzero_det(A, sign=None, min_singular_value=0.1): u, s, v = A.svd() s.clamp_(min=min_singular_value) @@ -1538,8 +1629,9 @@ def random_fullrank_matrix_distinct_singular_value(matrix_size, *batch_dims, A = torch.randn(batch_dims + (matrix_size, matrix_size), dtype=dtype, device=device) u, _, v = A.svd() - s = torch.arange(1., matrix_size + 1, dtype=dtype, device=device).mul_(1.0 / (matrix_size + 1)).diag() - return u.matmul(s.expand(batch_dims + (matrix_size, matrix_size)).matmul(v.transpose(-2, -1))) + real_dtype = A.real.dtype if A.dtype.is_complex else A.dtype + s = torch.arange(1., matrix_size + 1, dtype=real_dtype, device=device).mul_(1.0 / (matrix_size + 1)).diag() + return u.matmul(s.expand(batch_dims + (matrix_size, matrix_size)).to(A.dtype).matmul(v.transpose(-2, -1))) def random_matrix(rows, columns, *batch_dims, **kwargs): @@ -1604,7 +1696,8 @@ def random_sparse_matrix(rows, columns, density=0.01, **kwargs): values = torch.randn(nonzero_elements, dtype=dtype, device=device) # ensure that the diagonal dominates values *= torch.tensor([-float(i - j)**2 for i, j in zip(*indices)], dtype=dtype, device=device).exp() - A = torch.sparse_coo_tensor(indices, values, (rows, columns), device=device) + indices_tensor = torch.tensor(indices) + A = torch.sparse_coo_tensor(indices_tensor, values, (rows, columns), device=device) return A.coalesce() @@ -1661,8 +1754,8 @@ def multiply(data, N, i, j, cs, sn, left=True): icoords.append(i) jcoords.append(j) values.append(v) - indices = [icoords, jcoords] - return torch.sparse_coo_tensor(indices, values, (matrix_size, matrix_size), dtype=dtype, device=device) + indices_tensor = torch.tensor([icoords, jcoords]) + return torch.sparse_coo_tensor(indices_tensor, values, (matrix_size, matrix_size), dtype=dtype, device=device) def do_test_dtypes(self, dtypes, layout, device): @@ -1727,8 +1820,16 @@ def get_int64_dtype(dtype): dtype=int64_dtype, layout=layout, device=device, requires_grad=False), int64_dtype, layout, device, fv + 5, False) +# this helper method is to recursively +# clone the tensor-type input of operators tested by OpInfo +def clone_input_helper(input): + if isinstance(input, torch.Tensor): + return torch.clone(input) + if isinstance(input, Sequence): + return tuple(map(clone_input_helper, input)) + return input THESE_TAKE_WAY_TOO_LONG = { 'test_Conv3d_groups', @@ -1800,6 +1901,16 @@ def _assertGradAndGradgradChecks(test_case, apply_fn, inputs): test_case.assertTrue(gradgradcheck(apply_fn, inputs)) +@contextmanager +def set_cwd(path: str) -> Iterator[None]: + old_cwd = os.getcwd() + try: + os.chdir(path) + yield + finally: + os.chdir(old_cwd) + + # Using @precisionOverride specific to your test is the recommended way # of doing this. These are just some values that worked for test_nn. dtype2prec_DONTUSE = {torch.float: 1e-5, diff --git a/torch/testing/_internal/dist_utils.py b/torch/testing/_internal/dist_utils.py index b88765211df19..18d7a0417eac2 100644 --- a/torch/testing/_internal/dist_utils.py +++ b/torch/testing/_internal/dist_utils.py @@ -7,6 +7,7 @@ import torch.distributed as dist import torch.distributed.rpc as rpc from torch.distributed.rpc import _rref_context_get_debug_info # type: ignore[attr-defined] +from torch.testing._internal.common_utils import FILE_SCHEMA if not dist.is_available(): @@ -14,7 +15,26 @@ sys.exit(0) -INIT_METHOD_TEMPLATE = "file://{file_name}" +INIT_METHOD_TEMPLATE = FILE_SCHEMA + "{file_name}" + + +def single_threaded_process_group_agent(f): + """ + Forces ProcessGroupAgent to use only a single thread in the ThreadPool for + sending and processing requests. + """ + @wraps(f) + def wrapper(self, *args, **kwargs): + backend_type = self.rpc_backend + if backend_type == rpc.backend_registry.BackendType["PROCESS_GROUP"]: + self.rpc_backend_options = rpc.backend_registry.construct_rpc_backend_options( + self.rpc_backend, + init_method=self.init_method, + num_send_recv_threads=1, + ) + return_value = f(self, *args, **kwargs) + return return_value + return wrapper def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True, diff --git a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py index 1b1f755ed4cc5..84768496b5ff9 100644 --- a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py +++ b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py @@ -20,7 +20,7 @@ skip_if_lt_x_gpu, skip_if_rocm, ) -from torch.testing._internal.dist_utils import dist_init +from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( RpcAgentTestFixture, ) @@ -325,14 +325,19 @@ def trainer_name(self, rank): # The name has to be consistent with that in 'dist_init' decorator. return f"worker{rank}" - def _remote_worker_process(self): + def _remote_worker_process(self, ddp_mode): gLogger.info("The remote worker is running.") dist.init_process_group( backend="gloo", - init_method="file://{}".format(self.file_name), + init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), world_size=self.world_size, rank=self.rank, ) + + if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE): + # new_group needs to be called on ranks. + dist.new_group(TRAINER_RANKS) + global shutdown_signal with shutdown_signal: shutdown_signal.wait() @@ -346,7 +351,7 @@ def _trainer_process(self, rank: int): ) dist.init_process_group( backend="gloo", - init_method="file://{}".format(self.file_name), + init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), world_size=self.world_size, rank=self.rank, ) @@ -363,10 +368,11 @@ def _master_process(self, ddp_mode: DdpMode, simulate_uneven_inputs: bool): gLogger.info("Running the master process...") dist.init_process_group( backend="gloo", - init_method="file://{}".format(self.file_name), + init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), world_size=self.world_size, rank=self.rank, ) + remote_em_rref = rpc.remote( self.remote_worker_name(), RemoteEM, args=(NUM_EM_ROW, D_SPARSE) ) @@ -401,6 +407,10 @@ def do_test_on_master( ) ) + if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE): + # new_group needs to be called on ranks. + dist.new_group(TRAINER_RANKS) + training_examples = get_training_examples() for _ in range(3): futures = [] @@ -455,7 +465,7 @@ def _do_test(self, ddp_mode, simulate_uneven_inputs=False): if self.rank == MASTER_RANK: self._master_process(ddp_mode, simulate_uneven_inputs) elif self.rank == REMOTE_WORKER_RANK: - self._remote_worker_process() + self._remote_worker_process(ddp_mode) elif self.rank in TRAINER_RANKS: self._trainer_process(self.rank) else: @@ -500,7 +510,7 @@ def _run_test_ddp_comparision(self, simulate_uneven_inputs=False): torch.manual_seed(self.rank) dist.init_process_group( backend="gloo", - init_method="file://{}".format(self.file_name), + init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), world_size=self.world_size, rank=self.rank, ) @@ -567,7 +577,7 @@ def test_ddp_dist_autograd_sparse_grads(self): torch.manual_seed(self.rank) dist.init_process_group( backend="gloo", - init_method="file://{}".format(self.file_name), + init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), world_size=self.world_size, rank=self.rank, ) @@ -604,40 +614,43 @@ def test_ddp_dist_autograd_local_vs_remote(self): torch.manual_seed(self.rank) dist.init_process_group( backend="gloo", - init_method="file://{}".format(self.file_name), + init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), world_size=self.world_size, rank=self.rank, ) - remote_layer1 = RemoteModule( - "worker0", device="cpu", module_cls=nn.Linear, args=(10, 5, False) - ) - layer1 = nn.Linear(10, 5, False) - # Start with the same parameters for remote and local - layer1.weight = remote_layer1.module_rref.to_here().weight - - # Run local case. - layer2 = nn.Linear(5, 1) - inputs = torch.rand((10, 10)) - ddp_model = DistributedDataParallel(layer2) - loss = ddp_model(layer1(inputs)).sum() - loss.backward() - - # Run remote case. - with dist_autograd.context() as context_id: - loss = ddp_model(remote_layer1(inputs)).sum() - dist_autograd.backward(context_id, [loss]) - grads_dict = dist_autograd.get_gradients(context_id) - dist.barrier() - self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight]) - self.assertEqual( - layer1.weight.grad, - rpc.rpc_sync( - "worker0", - DdpComparisonTest.get_remote_grads, - args=(remote_layer1.module_rref, context_id), - ), + # Use two different remote device input string, w/ and w/o the default + # device string "cpu", respectively. + for remote_device in ["worker0/cpu", "worker0"]: + remote_layer1 = RemoteModule( + remote_device=remote_device, module_cls=nn.Linear, args=(10, 5, False) ) + layer1 = nn.Linear(10, 5, False) + # Start with the same parameters for remote and local + layer1.weight = remote_layer1.module_rref.to_here().weight + + # Run local case. + layer2 = nn.Linear(5, 1) + inputs = torch.rand((10, 10)) + ddp_model = DistributedDataParallel(layer2) + loss = ddp_model(layer1(inputs)).sum() + loss.backward() + + # Run remote case. + with dist_autograd.context() as context_id: + loss = ddp_model(remote_layer1(inputs)).sum() + dist_autograd.backward(context_id, [loss]) + grads_dict = dist_autograd.get_gradients(context_id) + dist.barrier() + self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight]) + self.assertEqual( + layer1.weight.grad, + rpc.rpc_sync( + "worker0", + DdpComparisonTest.get_remote_grads, + args=(remote_layer1.module_rref, context_id), + ), + ) @skip_if_lt_x_gpu(NUM_TRAINERS) @requires_nccl() @@ -651,13 +664,13 @@ def test_ddp_dist_autograd_local_vs_remote_gpu(self): torch.manual_seed(self.rank) dist.init_process_group( backend="gloo", - init_method="file://{}".format(self.file_name), + init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), world_size=self.world_size, rank=self.rank, ) remote_layer1 = RemoteModule( - "worker0", device="cpu", module_cls=nn.Linear, args=(10, 7, False) + remote_device="worker0/cpu", module_cls=nn.Linear, args=(10, 7, False) ) layer1 = nn.Linear(10, 7, False) # Start with the same parameters for remote and local @@ -667,7 +680,7 @@ def test_ddp_dist_autograd_local_vs_remote_gpu(self): ddp_layer2 = DistributedDataParallel(layer2, device_ids=[self.rank]) remote_layer3 = RemoteModule( - "worker0", device="cpu", module_cls=nn.Linear, args=(5, 3, False) + remote_device="worker0/cpu", module_cls=nn.Linear, args=(5, 3, False) ) layer3 = nn.Linear(5, 3, False) # Start with the same parameters for remote and local diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index f6f2b9a6fbfbc..3b9882d1376b8 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -1,5 +1,5 @@ import copy -import fcntl +from collections import namedtuple import itertools import random import math @@ -11,17 +11,18 @@ from contextlib import contextmanager, suppress from datetime import timedelta from functools import reduce -from io import StringIO from typing import Union, NamedTuple import torch import torch.cuda import torch.distributed as dist +import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD from torch.utils.data.distributed import DistributedSampler from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars import torch.nn as nn import torch.nn.functional as F from torch.distributed.distributed_c10d import _get_default_group, AllreduceOptions, GroupMember +from torch.testing._internal.common_utils import FILE_SCHEMA from torch.testing._internal.common_distributed import ( MultiProcessTestCase, TEST_SKIPS, @@ -33,6 +34,8 @@ skip_if_lt_x_gpu, skip_if_no_gpu, require_n_gpus_for_nccl_backend, + requires_nccl_version, + captured_output, ) from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT @@ -43,6 +46,10 @@ except ImportError: HAS_TORCHVISION = False +if sys.platform == 'win32': + import msvcrt +else: + import fcntl class Foo: def __init__(self, x): @@ -61,6 +68,24 @@ def __eq__(self, other): [1, 2, True, "string", [4, 5, "nested"]], ] +# Allowlist of distributed backends where profiling collectives is supported. +PROFILING_SUPPORTED_BACKENDS = [ + dist.Backend.NCCL, + dist.Backend.GLOO, +] + +# Allowlist of distributed backends where profiling is supported with use_cuda=True +CUDA_PROFILING_SUPPORTED_BACKENDS = [ + dist.Backend.GLOO +] + +# Dummy NamedTuple data structures to test DDP support for NamedTuple types. +EXPECTED_FIELDS = ("a", "b") +TestNamedTupleInput_0 = namedtuple("NamedTuple", EXPECTED_FIELDS) + +class TestNamedTupleInput_1(NamedTuple): + a: torch.tensor + b: torch.tensor skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") @@ -127,18 +152,6 @@ def forward(self, x): BN_NET = BatchNormNet() ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99) - -@contextmanager -def _captured_output(): - new_out, new_err = StringIO(), StringIO() - old_out, old_err = sys.stdout, sys.stderr - try: - sys.stdout, sys.stderr = new_out, new_err - yield sys.stdout, sys.stderr - finally: - sys.stdout, sys.stderr = old_out, old_err - - def get_timeout(test_id): test_name = test_id.split(".")[-1] if test_name in CUSTOMIZED_TIMEOUT: @@ -162,8 +175,7 @@ def check(backend): if backend == dist.Backend.MPI: return dist.is_mpi_available() return False - backends = map(lambda b: dist.Backend(b), backends) - if not all(map(check, backends)): + if not all(check(dist.Backend(backend)) for backend in backends): return unittest.skip( "Test requires backends to be available %s" % backends) return lambda func: func @@ -191,23 +203,33 @@ def _lock(): lockfile = os.path.join(TEMP_DIR, "lockfile") with open(lockfile, "w") as lf: try: - fcntl.flock(lf.fileno(), fcntl.LOCK_EX) - yield + if sys.platform == 'win32': + msvcrt.locking(lf.fileno(), msvcrt.LK_RLCK, 1) + yield + else: + fcntl.flock(lf.fileno(), fcntl.LOCK_EX) + yield finally: - fcntl.flock(lf.fileno(), fcntl.LOCK_UN) + if sys.platform == 'win32': + msvcrt.locking(lf.fileno(), msvcrt.LK_UNLCK, 1) + else: + fcntl.flock(lf.fileno(), fcntl.LOCK_UN) lf.close() -def _build_tensor(size, value=None, dtype=torch.float): +def _build_tensor(size, value=None, dtype=torch.float, device_id=None): if value is None: value = size - return torch.empty(size, size, size, dtype=dtype).fill_(value) + if device_id is None: + return torch.empty(size, size, size, dtype=dtype).fill_(value) + else: + return torch.empty(size, size, size, dtype=dtype).fill_(value).cuda(device_id) -def _build_multidim_tensor(dim, dim_size, value=None): +def _build_multidim_tensor(dim, dim_size, value=None, dtype=torch.float): if value is None: value = size - return torch.FloatTensor(size=[dim_size for _ in range(dim)]).fill_(value) + return torch.empty(size=[dim_size for _ in range(dim)], dtype=dtype).fill_(value) class Barrier(object): @@ -270,19 +292,26 @@ def tearDown(self): @property def init_method(self): - return "file://{file_name}".format(file_name=self.file_name) + return "{}{file_name}".format(FILE_SCHEMA, file_name=self.file_name) @classmethod def _run(cls, rank, test_name, file_name): + if BACKEND == 'nccl' and not torch.cuda.is_available(): + sys.exit(TEST_SKIPS['no_cuda'].exit_code) self = cls(test_name) self.rank = rank self.file_name = file_name + + if torch.cuda.is_available() and torch.cuda.device_count() < int(self.world_size): + sys.exit(TEST_SKIPS['multi-gpu'].exit_code) try: + timeout = timedelta(seconds=60) dist.init_process_group( init_method=self.init_method, backend=BACKEND, world_size=int(self.world_size), rank=self.rank, + timeout=timeout, ) except RuntimeError as e: if "recompile" in e.args[0]: @@ -349,7 +378,11 @@ def _init_multigpu_helper(self): if BACKEND == "nccl": apply_hack_for_nccl() - nGPUs_per_process = nGPUs // world_size + # If rank is lesser than or equal to number of available GPU's + # then each rank can be mapped to corresponding GPU. + nGPUs_per_process = 1 + if world_size > nGPUs: + nGPUs_per_process = nGPUs // world_size rank_to_GPU = { i: list( visible_devices[i * nGPUs_per_process: (i + 1) * nGPUs_per_process] @@ -359,7 +392,7 @@ def _init_multigpu_helper(self): return rank_to_GPU def test_dump_DDP_relevant_env_vars(self): - with _captured_output() as (out, err): + with captured_output() as (out, _): _dump_DDP_relevant_env_vars() lines = out.getvalue().splitlines() @@ -571,6 +604,234 @@ def test_backend_group(self): def test_backend_full_group(self): self._test_group_override_backend(self._init_full_group_test) + # NCCL Batch SEND RECV + @skip_if_no_gpu + @unittest.skipIf(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_nccl(self): + self._barrier() + rank = dist.get_rank() + rank_to_GPU = self._init_multigpu_helper() + device_id = rank_to_GPU[rank][0] + torch.cuda.set_device(device_id) + p2p_op_list = [] + + for val in ["1", "0"]: + os.environ["NCCL_BLOCKING_WAIT"] = val + for src in range(0, dist.get_world_size()): + send_tensor = _build_tensor(rank + 1, device_id=device_id) + recv_tensor = _build_tensor(src + 1, value=-1, device_id=device_id) + recv_op = dist.P2POp(dist.irecv, recv_tensor, src) + p2p_op_list.append(recv_op) + send_op = dist.P2POp(dist.isend, send_tensor, src) + p2p_op_list.append(send_op) + + reqs = dist.batch_isend_irecv(p2p_op_list) + for req in reqs: + req.wait() + + self._barrier() + + @skip_if_no_gpu + @unittest.skipIf(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_self_nccl(self): + self._barrier() + rank = dist.get_rank() + rank_to_GPU = self._init_multigpu_helper() + device_id = rank_to_GPU[rank][0] + p2p_op_list = [] + + if rank == 0: + send_tensor = _build_tensor(rank + 1, device_id=device_id) + recv_tensor = _build_tensor(rank + 1, value=-1, device_id=device_id) + recv_op = dist.P2POp(dist.irecv, recv_tensor, 0) + p2p_op_list.append(recv_op) + send_op = dist.P2POp(dist.isend, send_tensor, 0) + p2p_op_list.append(send_op) + + reqs = dist.batch_isend_irecv(p2p_op_list) + for req in reqs: + req.wait() + + self._barrier() + + @skip_if_no_gpu + @skip_if_small_worldsize + @unittest.skipIf(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_no_rank_zero_nccl(self): + self._barrier() + rank = dist.get_rank() + rank_to_GPU = self._init_multigpu_helper() + device_id = rank_to_GPU[rank][0] + torch.cuda.set_device(device_id) + p2p_op_list = [] + + if rank == 1: + peer = 2 + elif rank == 2: + peer = 1 + + if rank in [1, 2]: + send_tensor = _build_tensor(rank + 1, device_id=device_id) + recv_tensor = _build_tensor(peer + 1, value=-1, device_id=device_id) + recv_op = dist.P2POp(dist.irecv, recv_tensor, peer) + p2p_op_list.append(recv_op) + send_op = dist.P2POp(dist.isend, send_tensor, peer) + p2p_op_list.append(send_op) + + reqs = dist.batch_isend_irecv(p2p_op_list) + for req in reqs: + req.wait() + + + self._barrier() + + # GLOO Batch SEND RECV CPU + @unittest.skipIf(BACKEND != "gloo", "GLOO Batch Send Recv CPU") + def test_batch_isend_irecv_gloo(self): + self._barrier() + rank = dist.get_rank() + p2p_op_list = [] + + for src in range(0, dist.get_world_size()): + if src == rank: + continue + send_tensor = _build_tensor(rank + 1) + recv_tensor = _build_tensor(src + 1, value=-1) + recv_op = dist.P2POp(dist.irecv, recv_tensor, src) + p2p_op_list.append(recv_op) + send_op = dist.P2POp(dist.isend, send_tensor, src) + p2p_op_list.append(send_op) + + reqs = dist.batch_isend_irecv(p2p_op_list) + for req in reqs: + req.wait() + + self._barrier() + + # GLOO Batch SEND RECV CPU with provided tags + @unittest.skipIf(BACKEND != "gloo", "GLOO Batch Send Recv CPU") + def test_batch_isend_irecv_gloo_tags(self): + self._barrier() + rank = dist.get_rank() + p2p_op_list = [] + + for src in range(0, dist.get_world_size()): + if src == rank: + continue + send_tensor = _build_tensor(rank + 1) + recv_tensor = _build_tensor(src + 1, value=-1) + recv_op = dist.P2POp(dist.irecv, recv_tensor, src, tag=src) + p2p_op_list.append(recv_op) + send_op = dist.P2POp(dist.isend, send_tensor, src, tag=rank) + p2p_op_list.append(send_op) + + reqs = dist.batch_isend_irecv(p2p_op_list) + for req in reqs: + req.wait() + + self._barrier() + + # NCCL Batch SEND RECV Tensor Error + @unittest.skipIf(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_tensor_err(self): + self._barrier() + rank = dist.get_rank() + if rank == 0: + rank_to_GPU = self._init_multigpu_helper() + device_id = rank_to_GPU[rank][0] + with self.assertRaisesRegex( + RuntimeError, "Tensors must be CUDA and dense" + ): + send_tensor = _build_tensor(rank + 1) + send_op = dist.P2POp(dist.isend, send_tensor, 1) + req = dist.batch_isend_irecv([send_op]) + req.wait() + + # NCCL Batch SEND RECV Op Error + @unittest.skipIf(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_op_err(self): + self._barrier() + rank = dist.get_rank() + if rank == 0: + rank_to_GPU = self._init_multigpu_helper() + device_id = rank_to_GPU[rank][0] + with self.assertRaisesRegex( + RuntimeError, "^Invalid ``op``" + ): + send_tensor = _build_tensor(rank + 1, device_id=device_id) + send_op = dist.P2POp(dist.broadcast, send_tensor, 1) + req = dist.batch_isend_irecv([send_op]) + req.wait() + + # NCCL Batch SEND RECV p2p_op_list Error + @unittest.skipIf(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_op_list_err(self): + self._barrier() + rank = dist.get_rank() + if rank == 0: + rank_to_GPU = self._init_multigpu_helper() + device_id = rank_to_GPU[rank][0] + with self.assertRaisesRegex( + RuntimeError, "^Invalid ``p2p_op_list``" + ): + send_tensor = _build_tensor(rank + 1) + req = dist.batch_isend_irecv([1, 2]) + req.wait() + + # NCCL Batch SEND RECV Mixed Backend Error + @unittest.skipIf(BACKEND != "nccl", "NCCL Batch Send Recv Only") + @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") + def test_batch_isend_irecv_mixed_backend_err(self): + self._barrier() + rank = dist.get_rank() + rank_to_GPU = self._init_multigpu_helper() + device_id = rank_to_GPU[rank][0] + group_gloo = dist.new_group(ranks=[0, 1], backend="gloo") + group_nccl = dist.new_group(ranks=[0, 1], backend="nccl") + if rank == 0: + with self.assertRaisesRegex( + RuntimeError, "All groups need to use the same backend" + ): + send_tensor = _build_tensor(rank + 1) + send_op_gloo = dist.P2POp(dist.isend, send_tensor, 1, group_gloo) + send_op_nccl = dist.P2POp(dist.isend, send_tensor, 1, group_nccl) + req = dist.batch_isend_irecv([send_op_gloo, send_op_nccl]) + req.wait() + + # NCCL SEND RECV + @skip_if_no_gpu + @unittest.skipIf(BACKEND != "nccl", "NCCL Send Recv Only") + @requires_nccl_version(2700, "Need NCCL 2.7+ for send/recv") + def test_send_recv_nccl(self): + rank = dist.get_rank() + rank_to_GPU = self._init_multigpu_helper() + device_id = rank_to_GPU[rank][0] + torch.cuda.set_device(device_id) + + tensor = _build_tensor(rank + 1, device_id=device_id) + + for src in range(0, dist.get_world_size()): + if src == rank: + # Send mode + for dst in range(0, dist.get_world_size()): + if dst == rank: + continue + dist.send(tensor, dst) + else: + # Recv mode + expected_tensor = _build_tensor(src + 1) + output_tensor = _build_tensor(src + 1, value=-1, device_id=device_id) + dist.recv(output_tensor, src) + self.assertEqual(output_tensor, expected_tensor) + + self._barrier() + # SEND RECV @unittest.skipIf(BACKEND == "nccl", "Nccl does not support send/recv") def test_send_recv(self): @@ -600,7 +861,8 @@ def test_send_recv(self): def test_send_recv_any_source(self): rank = dist.get_rank() tensor = _build_tensor(10, value=rank) - recv_ranks = set() + recv_ranks = list() + irecv_ranks = list() for dst in range(0, dist.get_world_size()): if dst == rank: @@ -608,19 +870,42 @@ def test_send_recv_any_source(self): for dst in range(0, dist.get_world_size()): if dst == rank: continue - output_tensor = _build_tensor(10, value=-1) - sender = dist.recv(output_tensor) - # Assert the scalar value "sender" that should be - # equal to the rank of the sender is equal to all - # values in the received tensor. - self.assertTrue(output_tensor.eq(sender).all()) - recv_ranks.add(sender) + for recv in ["recv", "irecv"]: + output_tensor = _build_tensor(10, value=-1) + + if recv == "recv": + sender = dist.recv(output_tensor) + recv_ranks.append(sender) + elif recv == "irecv": + work = dist.irecv(output_tensor) + work.wait() + sender = work._source_rank() + irecv_ranks.append(sender) + + # Assert the scalar value "sender" that should be + # equal to the rank of the sender is equal to all + # values in the received tensor. + self.assertTrue(output_tensor.eq(sender).all()) else: # Send mode - dist.send(tensor, dst) - - self.assertEqual(len(recv_ranks), dist.get_world_size() - 1) + dist.send(tensor, dst) # recv + dist.send(tensor, dst) # irecv + + # Each rank would have 2 * (world_size - 1) sends, verify that + # globally we receive the same amount on the other end. + recv_ranks_tensor = torch.cat((torch.tensor(recv_ranks), torch.tensor(irecv_ranks)), 0) + global_recv_ranks = [torch.empty_like(recv_ranks_tensor) for _ in range(dist.get_world_size())] + dist.all_gather(global_recv_ranks, recv_ranks_tensor) + global_recv_ranks_list = [] + for tensor in global_recv_ranks: + global_recv_ranks_list += tensor.tolist() + + from itertools import groupby + global_recv_ranks_list.sort() + frequency = [len(list(group)) for key, group in groupby(global_recv_ranks_list)] + self.assertEqual(dist.get_world_size(), len(frequency)) + self.assertEqual([2 * (dist.get_world_size() - 1)] * dist.get_world_size(), frequency) self._barrier() # SEND RECV WITH TAG @@ -711,9 +996,9 @@ def _test_broadcast_helper( opts = dist.BroadcastOptions() opts.rootTensor = 0 opts.rootRank = src - group_id.broadcast([expected_tensor], opts).wait() + self.call_dist_op(":broadcast", True, group_id.broadcast, [expected_tensor], opts) else: - dist.broadcast(expected_tensor, src, group_id) + self.call_dist_op(":broadcast", False, dist.broadcast, expected_tensor, src, group_id) else: tensor = _build_tensor(src + 1, -1, dtype) if cuda: @@ -722,9 +1007,9 @@ def _test_broadcast_helper( opts = dist.BroadcastOptions() opts.rootTensor = 0 opts.rootRank = src - group_id.broadcast([tensor], opts).wait() + self.call_dist_op(":broadcast", True, group_id.broadcast, [tensor], opts) else: - dist.broadcast(tensor, src, group_id) + self.call_dist_op(":broadcast", False, dist.broadcast, tensor, src, group_id) self.assertEqual(tensor.size(), expected_tensor.size()) self.assertEqual(tensor.ne(expected_tensor).max(), torch.tensor(False)) @@ -761,7 +1046,6 @@ def test_broadcast_full_group(self): "Only NCCL backend supports high priority stream", ) @skip_if_no_gpu - @skip_if_rocm def test_nccl_high_priority_stream(self): group, _, rank = self._init_global_test() rank_to_GPU = self._init_multigpu_helper() @@ -792,17 +1076,12 @@ def _test_reduce_helper( rank_to_GPU=None, ): for src in group: + tensor = _build_tensor(src + 1).fill_(master_value if rank == src else worker_value) + if cuda: + tensor = tensor.cuda(rank_to_GPU[rank][0]) + self.call_dist_op(":reduce", False, dist.reduce, tensor, src, op, group_id) if rank == src: - tensor = _build_tensor(src + 1).fill_(master_value) - if cuda: - tensor = tensor.cuda(rank_to_GPU[rank][0]) - dist.reduce(tensor, src, op, group_id) self.assertEqual(tensor, _build_tensor(src + 1, expected_value)) - else: - tensor = _build_tensor(src + 1).fill_(worker_value) - if cuda: - tensor = tensor.cuda(rank_to_GPU[rank][0]) - dist.reduce(tensor, src, op, group_id) self._barrier() @@ -936,6 +1215,64 @@ def test_reduce_full_group_max(self): group, group_id, rank = self._init_full_group_test() self._test_reduce_helper(group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10) + # REDUCE TWICE + def _test_reduce_twice_helper( + self, + group, + group_id, + rank, + op, + master_value, + worker_value, + expected_value, + cuda=False, + rank_to_GPU=None, + ): + for src in group: + tensors = [_build_tensor(src + 1).fill_(master_value if rank == src else worker_value) for i in range(2)] + if cuda: + for i in range(2): + tensors[i] = tensors[i].cuda(rank_to_GPU[rank][0]) + self.call_dist_op(":reduce", False, dist.reduce, tensors[0], src, op, group_id, + secondary_op_call=lambda: dist.reduce(tensors[1], src, op, group_id)) + if rank == src: + for tensor in tensors: + self.assertEqual(tensor, _build_tensor(src + 1, expected_value)) + + self._barrier() + + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") + def test_reduce_sum_twice(self): + group, group_id, rank = self._init_global_test() + self._test_reduce_twice_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + (10 * (len(group) - 1)), + ) + + @unittest.skipIf(BACKEND != "nccl", "Only Nccl supports CUDA reduce") + @skip_if_no_gpu + @skip_if_rocm + def test_reduce_sum_cuda_twice(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = self._init_multigpu_helper() + self._test_reduce_twice_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + 10 * (len(group) - 1), + True, + rank_to_GPU, + ) + + @skip_if_no_gpu @require_backend({"gloo", "nccl"}) def test_all_reduce_result_cuda(self): @@ -982,6 +1319,37 @@ def test_all_reduce_result_cuda(self): self.assertEqual(result, [_build_tensor(src + 1, expected_value)]) self._barrier() + def call_dist_op( + self, + profiling_title_postfix, + is_async, + op, + *args, + expect_event=True, + secondary_op_call=None, + profile_cuda=False, + **kwargs, + ): + op_calls = [lambda: op(*args, **kwargs)] + if secondary_op_call is not None: + op_calls.append(secondary_op_call) + + with torch.autograd.profiler.profile(use_cuda=profile_cuda) as prof: + works = [op_call() for op_call in op_calls] + if is_async: + for work in works: + work.wait() + + def get_event(postfix): + return [event for event in prof.function_events if event.name.endswith(postfix)] + + if expect_event and dist.get_backend() in PROFILING_SUPPORTED_BACKENDS: + events = get_event(profiling_title_postfix) + self.assertEqual(len(events), len(op_calls)) + for e in events: + self.assertEqual(e.count, 1) + self.assertGreaterEqual(e.cpu_time, 0) + # ALL REDUCE def _test_all_reduce_helper( self, @@ -994,20 +1362,34 @@ def _test_all_reduce_helper( expected_value, cuda=False, rank_to_GPU=None, + dtype=torch.float, + async_op=False, ): for src in group: - if rank == src: - tensor = _build_tensor(src + 1).fill_(master_value) - if cuda: - tensor = tensor.cuda(rank_to_GPU[rank][0]) - dist.all_reduce(tensor, op, group_id) - self.assertEqual(tensor, _build_tensor(src + 1, expected_value)) - else: - tensor = _build_tensor(src + 1).fill_(worker_value) - if cuda: - tensor = tensor.cuda(rank_to_GPU[rank][0]) - dist.all_reduce(tensor, op, group_id) - self.assertEqual(tensor, _build_tensor(src + 1, expected_value)) + curr_value = master_value if rank == src else worker_value + + tensor = _build_tensor(src + 1, dtype=dtype).fill_(curr_value) + if cuda: + tensor = tensor.cuda(rank_to_GPU[rank][0]) + self.call_dist_op(":all_reduce", async_op, dist.all_reduce, tensor, op, group_id, async_op=async_op) + # Currently, only Gloo backend has profiling tested with CUDA enabled. + # Only run cuda profiling test for one rank to speed up since + # running with different src_rank does not affect the correctness. + if ( + src == 0 + and cuda + and dist.get_backend() in CUDA_PROFILING_SUPPORTED_BACKENDS + ): + self.call_dist_op( + ":all_reduce", + async_op, + dist.all_reduce, + tensor, + op, + group_id, + async_op=async_op, + profile_cuda=True, + ) self._barrier() @@ -1024,9 +1406,23 @@ def test_all_reduce_sum(self): 2 + (10 * (len(group) - 1)), ) + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") + def test_all_reduce_sum_async(self): + group, group_id, rank = self._init_global_test() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + (10 * (len(group) - 1)), + async_op=True + ) + @unittest.skipIf( - BACKEND != "gloo", - "Only Gloo backend will have CUDA allReduce tested", + BACKEND != "gloo" and BACKEND != "nccl", + "Only Gloo and NCCL backends will have CUDA allReduce tested", ) @skip_if_no_gpu def test_all_reduce_sum_cuda(self): @@ -1044,6 +1440,71 @@ def test_all_reduce_sum_cuda(self): rank_to_GPU, ) + @unittest.skipIf( + BACKEND != "gloo" and BACKEND != "nccl", + "Only Gloo and NCCL backends will have CUDA allReduce tested", + ) + @skip_if_no_gpu + def test_all_reduce_sum_cuda_async(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = self._init_multigpu_helper() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + 2, + 10, + 2 + (10 * (len(group) - 1)), + True, + rank_to_GPU, + async_op=True + ) + + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") + def test_all_reduce_sum_complex(self): + group, group_id, rank = self._init_global_test() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + complex(2, 3), + complex(10, 11), + complex(2, 3) + (complex(10, 11) * (len(group) - 1)), + dtype=torch.cfloat, + ) + + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") + def test_all_reduce_complex_unsupported_ops(self): + unsupported_ops = [dist.ReduceOp.MAX, dist.ReduceOp.MIN, dist.ReduceOp.PRODUCT, + dist.ReduceOp.BAND, dist.ReduceOp.BOR, dist.ReduceOp.BXOR] + group, group_id, rank = self._init_global_test() + for unsupported_op in unsupported_ops: + with self.assertRaisesRegex(RuntimeError, "all_reduce does not support"): + dist.all_reduce(_build_tensor(1, dtype=torch.cfloat), unsupported_op, group_id) + + @unittest.skipIf( + BACKEND != "gloo" and BACKEND != "nccl", + "Only Gloo and NCCL backends will have CUDA allReduce tested", + ) + @skip_if_no_gpu + def test_all_reduce_sum_cuda_complex(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = self._init_multigpu_helper() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + complex(2, 3), + complex(10, 11), + complex(2, 3) + (complex(10, 11) * (len(group) - 1)), + True, + rank_to_GPU, + dtype=torch.cfloat, + ) + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") def test_all_reduce_product(self): group, group_id, rank = self._init_global_test() @@ -1182,9 +1643,10 @@ def test_sparse_all_reduce_sum_cuda(self): @staticmethod def _all_reduce_coalesced_sum_test_cases(group_size): return ( - [2, 3], - [10, 11], - [2 + 10 * (group_size - 1), 3 + 11 * (group_size - 1)] + [2, 3, complex(2, 3)], + [10, 11, complex(10, 11)], + [2 + 10 * (group_size - 1), 3 + 11 * (group_size - 1), complex(2, 3) + complex(10, 11) * (group_size - 1)], + [torch.float, torch.float, torch.cfloat], ) @staticmethod @@ -1192,7 +1654,8 @@ def _all_reduce_coalesced_product_test_cases(group_size): return ( [1, 2], [3, 4], - [1 * 3 ** (group_size - 1), 2 * 4 ** (group_size - 1)] + [1 * 3 ** (group_size - 1), 2 * 4 ** (group_size - 1)], + [torch.float, torch.float], ) @staticmethod @@ -1200,7 +1663,8 @@ def _all_reduce_coalesced_min_test_cases(group_size): return ( [1, 4], [2, 3], - [1, 3] + [1, 3], + [torch.float, torch.float], ) @staticmethod @@ -1208,9 +1672,16 @@ def _all_reduce_coalesced_max_test_cases(group_size): return ( [1, 4], [2, 3], - [2, 4] + [2, 4], + [torch.float, torch.float], ) + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") + def test_all_reduce_coalesced_max_complex_unsupported(self): + group, group_id, rank = self._init_global_test() + with self.assertRaisesRegex(RuntimeError, "all_reduce does not support"): + dist.all_reduce_coalesced([_build_tensor(1, dtype=torch.cfloat)], dist.ReduceOp.MAX, group_id) + def _test_all_reduce_coalesced_helper( self, group, @@ -1227,22 +1698,24 @@ def _test_all_reduce_coalesced_helper( dist.ReduceOp.MAX: self._all_reduce_coalesced_max_test_cases }[op] - master_values, worker_values, expected_values = test_case_func(len(group)) + master_values, worker_values, expected_values, dtypes = test_case_func(len(group)) for src in group: + curr_values = master_values if rank == src else worker_values tensors = [ - _build_tensor(src + 1, val) - for val in (master_values if rank == src else worker_values) + _build_tensor(src + 1, val, dtype=dtype) + for dtype, val in zip(dtypes, curr_values) ] if cuda: - tensors = list(map(tensors, lambda t: t.cuda(rank_to_GPU[rank][0]))) - dist.all_reduce_coalesced(tensors, op, group_id) + tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors] + self.call_dist_op(":all_reduce", False, dist.all_reduce_coalesced, tensors, op, group_id) + expected_tensors = [ + _build_tensor(src + 1, expected_value, dtype=dtype) + for dtype, expected_value in zip(dtypes, expected_values) + ] self.assertEqual( tensors, - [ - _build_tensor(src + 1, expected_value) - for expected_value in expected_values - ] + expected_tensors ) self._barrier() @@ -1403,7 +1876,7 @@ def _test_scatter_helper(self, group, group_id, rank): tensors = ( [_build_tensor(dest + 1, i) for i in group] if rank == dest else [] ) - dist.scatter(tensor, src=dest, scatter_list=tensors, group=group_id) + self.call_dist_op(":scatter", False, dist.scatter, tensor, src=dest, scatter_list=tensors, group=group_id) self.assertEqual(tensor, expected_tensor) self._barrier() @@ -1454,7 +1927,7 @@ def _test_gather_helper(self, group, group_id, rank): tensors = ( [_build_tensor(dest + 1, -1) for i in group] if rank == dest else [] ) - dist.gather(tensor, dst=dest, gather_list=tensors, group=group_id) + self.call_dist_op(":gather", False, dist.gather, tensor, dst=dest, gather_list=tensors, group=group_id) if rank == dest: expected_tensors = [_build_tensor(dest + 1, i) for i in group] for t1, t2 in zip(tensors, expected_tensors): @@ -1503,17 +1976,18 @@ def test_gather_full_group(self): # ALL GATHER def _test_all_gather_helper( - self, group, group_id, rank, cuda=False, rank_to_GPU=None + self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float ): + for dest in group: - tensor = _build_tensor(dest + 1, rank) - tensors = [_build_tensor(dest + 1, -1) for i in group] + tensor = _build_tensor(dest + 1, rank, dtype=dtype) + tensors = [_build_tensor(dest + 1, -1, dtype=dtype) for i in group] if cuda: tensor = tensor.cuda(rank_to_GPU[rank][0]) tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors] - dist.all_gather(tensors, tensor, group_id) + self.call_dist_op(":all_gather", False, dist.all_gather, tensors, tensor, group_id) - expected_tensors = [_build_tensor(dest + 1, i) for i in group] + expected_tensors = [_build_tensor(dest + 1, i, dtype=dtype) for i in group] for t1, t2 in zip(tensors, expected_tensors): self.assertEqual(t1, t2) @@ -1532,6 +2006,19 @@ def test_all_gather_cuda(self): rank_to_GPU = self._init_multigpu_helper() self._test_all_gather_helper(group, group_id, rank, True, rank_to_GPU) + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") + def test_all_gather_complex(self): + group, group_id, rank = self._init_global_test() + self._test_all_gather_helper(group, group_id, rank, dtype=torch.cfloat) + + @unittest.skipIf(BACKEND != "nccl", "Only Nccl supports CUDA all gather") + @unittest.skipIf(BACKEND == "nccl", "CUDA all gather skipped for NCCL") + @skip_if_no_gpu + def test_all_gather_cuda_complex(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = self._init_multigpu_helper() + self._test_all_gather_helper(group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat) + @skip_if_small_worldsize @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") def test_all_gather_group(self): @@ -1550,8 +2037,8 @@ def _run_all_gather_coalesced_and_verify( Helper that runs all_gather_coalesced and returns true if output matches expectations. """ - dist.all_gather_coalesced( - output_tensor_lists, input_tensors, group_id) + self.call_dist_op(":all_gather", False, dist.all_gather_coalesced, + output_tensor_lists, input_tensors, group_id) for l1, l2 in zip(output_tensor_lists, expected_tensors): for t1, t2 in zip(l1, l2): @@ -1560,7 +2047,7 @@ def _run_all_gather_coalesced_and_verify( return True def _test_all_gather_coalesced_helper( - self, group, group_id, rank + self, group, group_id, rank, dtype=torch.float ): # TODO: Instead we should probably go through _rank_not_in_group # mechanism to disable sending tensors @@ -1570,13 +2057,16 @@ def _test_all_gather_coalesced_helper( # [1], [2x2], [3x3x3] ... to be sent in one batch input_tensors = [ _build_multidim_tensor( - tensor_id, tensor_id, rank + tensor_id) for tensor_id in range( + tensor_id, + tensor_id, + rank + tensor_id, + dtype=dtype) for tensor_id in range( 1, test_case_id) ] output_tensor_lists = [ [ _build_multidim_tensor( - tensor_id, tensor_id, -1) for tensor_id in range( + tensor_id, tensor_id, -1, dtype=dtype) for tensor_id in range( 1, test_case_id) ] for _ in group ] @@ -1585,7 +2075,8 @@ def _test_all_gather_coalesced_helper( _build_multidim_tensor( tensor_id, tensor_id, - rank_iter + tensor_id) for tensor_id in range( + rank_iter + tensor_id, + dtype=dtype) for tensor_id in range( 1, test_case_id) ] for rank_iter in group ] @@ -1602,6 +2093,12 @@ def test_all_gather_coalesced_simple(self): group, group_id, rank = self._init_global_test() self._test_all_gather_coalesced_helper(group, group_id, rank) + @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL") + @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI") + def test_all_gather_coalesced_complex(self): + group, group_id, rank = self._init_global_test() + self._test_all_gather_coalesced_helper(group, group_id, rank, dtype=torch.cfloat) + @skip_if_small_worldsize @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL") @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI") @@ -1666,7 +2163,7 @@ def _test_all_to_all_single_equal_split_helper( in_tensor = in_tensor.cuda(rank_to_GPU[rank][0]) expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0]) out_tensor = out_tensor.cuda(rank_to_GPU[rank][0]) - dist.all_to_all_single(out_tensor, in_tensor, group=group_id) + self.call_dist_op(":all_to_all", False, dist.all_to_all_single, out_tensor, in_tensor, group=group_id) self.assertEqual(out_tensor, expected_tensor) self._barrier() @@ -1974,21 +2471,16 @@ def _test_all_reduce_multigpu_helper( master_value, worker_value, expected_value, + dtype=torch.float, ): for src in group: - if rank == src: - tensors = [ - _build_tensor(src + 1, master_value).cuda(device=i) - for i in rank_to_GPU[rank] - ] - else: - tensors = [ - _build_tensor(src + 1, worker_value).cuda(device=i) - for i in rank_to_GPU[rank] - ] - - dist.all_reduce_multigpu(tensors, op, group_id) - expected_tensor = _build_tensor(src + 1, expected_value) + curr_value = master_value if rank == src else worker_value + tensors = [ + _build_tensor(src + 1, curr_value, dtype=dtype).cuda(device=i) + for i in rank_to_GPU[rank] + ] + self.call_dist_op(":all_reduce", False, dist.all_reduce_multigpu, tensors, op, group_id) + expected_tensor = _build_tensor(src + 1, expected_value, dtype=dtype) for tensor in tensors: self.assertEqual(tensor, expected_tensor) @@ -2011,6 +2503,24 @@ def test_all_reduce_multigpu(self): (2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]), ) + @unittest.skipIf(BACKEND == "mpi", "MPI doesn't support broadcast multigpu") + @unittest.skipIf(BACKEND == "nccl", "CUDA all_reduce multigpu skipped for NCCL") + @skip_if_no_gpu + def test_all_reduce_multigpu_complex(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = self._init_multigpu_helper() + self._test_all_reduce_multigpu_helper( + group, + group_id, + rank, + rank_to_GPU, + dist.ReduceOp.SUM, + complex(2, 3), + complex(10, 11), + (complex(2, 3) + complex(10, 11) * (len(group) - 1)) * len(rank_to_GPU[0]), + dtype=torch.cfloat, + ) + def _test_reduce_multigpu_helper( self, group, @@ -2023,20 +2533,17 @@ def _test_reduce_multigpu_helper( expected_value, ): for src in group: + tensor_value = master_value if rank == src else worker_value + tensors = [ + _build_tensor(src + 1, tensor_value).cuda(device=i) + for i in rank_to_GPU[rank] + ] + self.call_dist_op( + "reduce", False, dist.reduce_multigpu, tensors, src, op, group_id, + expect_event=len(tensors) == 1) if rank == src: - tensors = [ - _build_tensor(src + 1, master_value).cuda(device=i) - for i in rank_to_GPU[rank] - ] - dist.reduce_multigpu(tensors, src, op, group_id) expected_tensor = _build_tensor(src + 1, expected_value) self.assertEqual(tensors[0], expected_tensor) - else: - tensors = [ - _build_tensor(src + 1, worker_value).cuda(device=i) - for i in rank_to_GPU[rank] - ] - dist.reduce_multigpu(tensors, src, op, group_id) self._barrier() @@ -2057,10 +2564,10 @@ def test_reduce_multigpu(self): (2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]), ) - def _test_all_gather_multigpu_helper(self, group, group_id, rank, rank_to_GPU): + def _test_all_gather_multigpu_helper(self, group, group_id, rank, rank_to_GPU, dtype=torch.float): for dest in group: tensors = [ - _build_tensor(dest + 1).cuda(device=i) for i in rank_to_GPU[rank] + _build_tensor(dest + 1, dtype=dtype).cuda(device=i) for i in rank_to_GPU[rank] ] # construct expected output along with @@ -2068,16 +2575,18 @@ def _test_all_gather_multigpu_helper(self, group, group_id, rank, rank_to_GPU): output_tensors = [] expected_output = [] output_per_gpu = ( - [_build_tensor(dest + 1, -1)] * len(rank_to_GPU[0]) * len(group) + [_build_tensor(dest + 1, -1, dtype=dtype)] * len(rank_to_GPU[0]) * len(group) ) expected_per_gpu = ( - [_build_tensor(dest + 1)] * len(rank_to_GPU[0]) * len(group) + [_build_tensor(dest + 1, dtype=dtype)] * len(rank_to_GPU[0]) * len(group) ) for gpu in rank_to_GPU[rank]: output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu]) expected_output.append([t.cuda(device=gpu) for t in expected_per_gpu]) - - dist.all_gather_multigpu(output_tensors, tensors, group_id) + self.call_dist_op( + "all_gather", False, + dist.all_gather_multigpu, output_tensors, tensors, group_id, + expect_event=len(expected_output) == 1) self.assertEqual(output_tensors, expected_output) self._barrier() @@ -2089,6 +2598,13 @@ def test_all_gather_multigpu(self): rank_to_GPU = self._init_multigpu_helper() self._test_all_gather_multigpu_helper(group, group_id, rank, rank_to_GPU) + @unittest.skipIf(BACKEND != "nccl", "Only Nccl backend supports allgather multigpu") + @skip_if_no_gpu + def test_all_gather_multigpu_complex(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = self._init_multigpu_helper() + self._test_all_gather_multigpu_helper(group, group_id, rank, rank_to_GPU, dtype=torch.cfloat) + def _model_step(self, model): for param in model.parameters(): if param.grad is not None: @@ -2162,8 +2678,13 @@ def _test_DDP_5iter( # save the model in the middle and reload if test_save and idx == 2 and INIT_METHOD.startswith("file://"): with tempfile.NamedTemporaryFile() as tmp: - torch.save(model_DDP, tmp.name) - model_DDP = torch.load(tmp.name) + if sys.platform == 'win32': + torch.save(model_DDP, tmp) + tmp.seek(0) + model_DDP = torch.load(tmp) + else: + torch.save(model_DDP, tmp.name) + model_DDP = torch.load(tmp.name) with tempfile.TemporaryFile() as tmp_file: torch.save(model_DDP, tmp_file) @@ -2192,8 +2713,13 @@ def _test_DistributedDataParallel(self, gpu_subset, rank, output_device=None, gr # test serializable/unserializable with tempfile.NamedTemporaryFile() as tmp: - torch.save(model_DDP, tmp.name) - model_DDP = torch.load(tmp.name) + if sys.platform == 'win32': + torch.save(model_DDP, tmp) + tmp.seek(0) + model_DDP = torch.load(tmp) + else: + torch.save(model_DDP, tmp.name) + model_DDP = torch.load(tmp.name) # dummy data initialization local_bs = len(gpu_subset) @@ -2262,7 +2788,7 @@ def test_DistributedDataParallel_requires_grad(self): @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) @skip_if_rocm def test_DistributedDataParallel_non_default_stream(self): - stream = torch.cuda.Stream() + stream = torch.cuda.Stream(self.rank) rank = self.rank with torch.cuda.stream(stream): net = torch.nn.parallel.DistributedDataParallel( @@ -2294,6 +2820,54 @@ def test_DistributedDataParallel_non_default_stream(self): msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}", ) + @unittest.skipIf( + BACKEND != "nccl", + "Only NCCL backend supports DDP communication hook", + ) + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + @skip_if_rocm + def test_DistributedDataParallel_powerSGD_ddp_comm_hook(self): + stream = torch.cuda.Stream(self.rank) + rank = self.rank + rank_to_GPU = self._init_multigpu_helper() + gpus = list(rank_to_GPU[rank]) + with torch.cuda.stream(stream): + net = torch.nn.parallel.DistributedDataParallel( + torch.nn.Linear(1, 5).to(rank), device_ids=[rank] + ) + process_group = torch.distributed.new_group(gpus) + state = powerSGD.PowerSGDState( + process_group=process_group, matrix_approximation_rank=1 + ) + net.register_comm_hook(state=state, hook=powerSGD.powerSGD_hook) + # NOTE: batched_powerSGD_hook cannot pass the following test, because it has a lower accuracy. + for i in range(1000): + # Clear gradients manually. + grad = net.module.weight.grad + if grad is not None: + grad.requires_grad_(False) + grad.zero_() + # Forward + BW + batch = torch.tensor([rank]).float().cuda(rank) + loss = net(batch).sum() + loss.backward() + # For each worker, the gradient on the weight should be worker_rank. + grad = net.module.weight.grad + avg = grad.clone() + # All-reducing the gradient averages should give us the gradient + # average. If not, then one of the workers has not correctly + # written back the averaged gradient before this all-reduce call. + dist.all_reduce(avg) + world_size = int(os.environ["WORLD_SIZE"]) + avg.div_(world_size) + expected_grad = sum(i for i in range(world_size)) / world_size + self.assertEqual( + avg[0, 0], + expected_grad, + msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}", + ) + + @unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo', "Only Nccl & Gloo backend support DistributedDataParallel") @skip_if_no_gpu @@ -2308,7 +2882,7 @@ def test_DistributedDataParallel(self): self._test_DistributedDataParallel(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda')) # test device_ids - gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus)) + gpus = [torch.device('cuda:' + str(i)) for i in gpus] self._test_DistributedDataParallel(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda')) @unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo', @@ -2326,7 +2900,7 @@ def test_DistributedDataParallel_with_grad_is_view(self): gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'), gradient_as_bucket_view=True) # test device_ids - gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus)) + gpus = [torch.device('cuda:' + str(i)) for i in gpus] self._test_DistributedDataParallel( gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'), gradient_as_bucket_view=True) @@ -2350,8 +2924,13 @@ def _test_DistributedDataParallel_SyncBatchNorm(self, gpu_subset, rank, local_bs # test serializable/unserializable with tempfile.NamedTemporaryFile() as tmp: - torch.save(model_DDP, tmp.name) - model_DDP = torch.load(tmp.name) + if sys.platform == 'win32': + torch.save(model_DDP, tmp) + tmp.seek(0) + model_DDP = torch.load(tmp) + else: + torch.save(model_DDP, tmp.name) + model_DDP = torch.load(tmp.name) # data initialization input_cpu = torch.randn(global_bs, 2) @@ -2406,7 +2985,7 @@ def test_DistributedDataParallel_SyncBatchNorm(self): output_device=torch.device('cuda')) # test device_ids - gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus)) + gpus = [torch.device('cuda:' + str(i)) for i in gpus] self._test_DistributedDataParallel_SyncBatchNorm( gpu_subset=gpus, rank=rank, @@ -2709,9 +3288,15 @@ def test_nccl_backend_bool_broadcast(self): def test_DistributedSampler_padding(self): # Tests padding of distributed sampler. world_size = dist.get_world_size() + + # Simulates the 'casual' dataset size dataset_size = 100 + world_size + 1 dataset = [torch.ones(1).to(self.rank) * i for i in range(dataset_size)] + # Simulates the 'tiny' dataset size + dataset_tiny_size = max(world_size // 2 - 1, 1) + dataset_tiny = [torch.ones(1).to(self.rank) * i for i in range(dataset_tiny_size)] + # Specifying drop_last=True will cause the tail of the data to be dropped. dist_sampler = DistributedSampler(dataset=dataset, drop_last=True) local_num_samples, local_dataset_size = ( @@ -2761,9 +3346,32 @@ def validate_global_samples(local_num_samples): # Ensure that each rank processes the same number of samples. validate_global_samples(local_num_samples) + # Ensure additional samples are padded even when + # the extremely small dataset is given. + dist_sampler_added_samples_tiny = DistributedSampler(dataset=dataset_tiny) + local_num_samples, local_dataset_size = ( + dist_sampler_added_samples_tiny.num_samples, + dist_sampler_added_samples_tiny.total_size, + ) + self.assertEqual( + local_num_samples, math.ceil(dataset_tiny_size / world_size) + ) + self.assertEqual(local_dataset_size, local_num_samples * world_size) + indices_list = list(iter(dist_sampler_added_samples_tiny)) + self.assertEqual(len(indices_list), local_num_samples) + validate_global_samples(local_num_samples) + + @require_backend({"nccl", "gloo"}) @require_n_gpus_for_nccl_backend(int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]) def test_allgather_object(self): + # Only set device for NCCL backend since it must use GPUs. + backend = os.environ["BACKEND"] + if backend == "nccl": + # Case where rank != GPU device. + next_rank = (self.rank + 1) % int(self.world_size) + torch.cuda.set_device(next_rank) + gather_objects = collectives_object_test_list output_gathered = [None for _ in range(dist.get_world_size())] dist.all_gather_object( @@ -2818,7 +3426,10 @@ class Bar: def test_nccl_gather_object_err(self): output_gathered = [None for _ in range(dist.get_world_size())] gather_on_rank = 0 + # Case where rank != GPU device. my_rank = dist.get_rank() + next_rank = (my_rank + 1) % dist.get_world_size() + torch.cuda.set_device(next_rank) with self.assertRaisesRegex( RuntimeError, "ProcessGroupNCCL does not support gather" ): @@ -2994,7 +3605,7 @@ def _run_uneven_inputs_test( rank = self.rank sync_interval = test_case.sync_interval # Ensure all outsanding GPU work is comlete so this test runs independently. - torch.cuda.synchronize() + dist.barrier() # Bucket_cap_mb is intentionally low to test allreduce scheduling when # there are many buckets. net = torch.nn.parallel.DistributedDataParallel( @@ -3231,7 +3842,8 @@ def test_ddp_uneven_input_join_disable(self): net.module.weight.grad.item(), expected_grad ) - self.assertFalse(net.ddp_join_enabled) + join_config = net.ddp_uneven_inputs_config + self.assertFalse(join_config.ddp_join_enabled) self.validate_net_equivalence(net) @require_backend({"gloo", "nccl"}) @@ -3289,6 +3901,13 @@ def test_ddp_uneven_inputs_replicated_error(self): @require_backend({"nccl", "gloo"}) @require_n_gpus_for_nccl_backend(int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]) def test_broadcast_object_list(self): + # Only set device for NCCL backend since it must use GPUs. + backend = os.environ["BACKEND"] + if backend == "nccl": + # Case where rank != GPU device. + next_rank = (self.rank + 1) % int(self.world_size) + torch.cuda.set_device(next_rank) + src_rank = 0 objects = collectives_object_test_list if self.rank == src_rank else [None for _ in collectives_object_test_list] @@ -3304,3 +3923,469 @@ def test_broadcast_object_list(self): self.assertNotEqual(objects, collectives_object_test_list) dist.broadcast_object_list(objects, src=0) self.assertEqual(objects, collectives_object_test_list) + + @require_backend({"gloo", "nccl"}) + @require_backends_available({"gloo", "nccl"}) + @skip_if_lt_x_gpu(2) + @skip_if_rocm + def test_ddp_ignore_params_arg(self): + class TestModel(nn.Module): + def __init__(self, rank): + self.rank = rank + super(TestModel, self).__init__() + self.fc1 = nn.Linear(1, 1, bias=False) + # Proxy that will be materialized to another architecture later. + # (after wrapping model with DDP) + if self.rank == 0: + self.fc2 = nn.Linear(1, 10, bias=False) + else: + self.fc2 = nn.Linear(10, 10, bias=False) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + device_id = self.rank + # Ensure the test works for both find_unused_parameter and broadcast_buffer settings. + for (find_unused, broadcast_buffers) in itertools.product([False, True], [False, True]): + model = TestModel(self.rank).float().to(device_id) + # Note that the model can have different shape buffers if we pass + # them in to be ignored as well. + model.fc2.register_buffer( + "ignore_buffer", torch.zeros(5 + self.rank, device=self.rank) + ) + proxy_params = list(model.fc2.parameters()) + proxy_buffers = list(model.fc2.buffers()) + model_fc2_name = [ + module_name + for module_name, module in model.named_modules() + if module is model.fc2 + ][0] + proxy_param_names = [ + f"{model_fc2_name}.{param_name}" + for param_name, _ in model.fc2.named_parameters() + ] + proxy_buffer_names = [ + f"{model_fc2_name}.{buf_name}" + for buf_name, _ in model.fc2.named_buffers() + ] + # Specify that we should ignore proxy_params since it will be + # materialized later. + torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( + model, proxy_param_names + proxy_buffer_names + ) + ddp = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[device_id], + find_unused_parameters=find_unused, + broadcast_buffers=broadcast_buffers, + ) + # Materialize new params. These are not registered in DDP and thus + # don't have autograd hooks installed on them. + ddp.module.fc2 = nn.Linear(1, 1, bias=False).to(device_id) + # local model with the new materialized parameters. + local_model = copy.deepcopy(ddp.module).cuda(self.rank) + + inp = torch.ones(1, dtype=torch.float).to(device_id) * (self.rank + 1) + for i in range(6): + ddp(inp).sum().backward() + local_model(inp).sum().backward() + # materialized param grad is not touched by DDP, so its grad should + # be the same as if running locally. + for materialized_param, local_param in zip( + ddp.module.fc2.parameters(), local_model.fc2.parameters() + ): + self.assertEqual(materialized_param.grad, local_param.grad) + + # fc1 parameter grad should still be different, due to allreduce. + for synced_param, local_param in zip( + ddp.module.fc1.parameters(), local_model.fc1.parameters() + ): + self.assertFalse(synced_param.grad == local_param.grad) + + # Proxy module grad should not be touched + for proxy_param in proxy_params: + self.assertTrue(proxy_param.grad is None) + + # Synchronize since we run multiple iterations of this test, to + # isolate failure hangs. + torch.cuda.synchronize(device=self.rank) + + @require_backend({"gloo", "nccl"}) + @require_backends_available({"gloo", "nccl"}) + @skip_if_lt_x_gpu(2) + @skip_if_rocm + def test_ddp_unused_params_rebuild_buckets_exception(self): + class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + self.net1 = nn.Linear(10, 10, bias=False) + self.net2 = nn.Linear(10, 10, bias=False) + + def forward(self, x): + return self.net1(x) + + ddp = torch.nn.parallel.DistributedDataParallel( + ToyModel().cuda(self.rank), device_ids=[self.rank] + ) + for i in range(2): + inp = torch.rand(1, 10) + if i > 0: + # On 2nd iteration, this will fail during rebuild_buckets, + # but we should report an error regarding unused parameters + # since that is the underlying root cause. + with self.assertRaisesRegex( + RuntimeError, + "Expected to have finished reduction in the prior iteration", + ): + ddp(inp).sum().backward() + else: + ddp(inp).sum().backward() + + @require_backend({"gloo", "nccl"}) + @require_backends_available({"gloo", "nccl"}) + @skip_if_lt_x_gpu(2) + @skip_if_rocm + def test_ddp_shared_grad_acc_unused_params(self): + # When find_unused_parameters=True, ensure we mark unused parameters + # even if they share gradient accumulators. + class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + # net1, bias, and net1.bias are all unused params. + self.net1 = nn.Linear(10, 5, bias=False) + self.bias = nn.Parameter(torch.zeros(5)) + # net1.bias and self.bias are names for the same underlying + # parameter, so they share the same grad acc. This caused + # the bug reported in https://github.com/pytorch/pytorch/issues/41324. + self.net1.bias = self.bias + self.net2 = nn.Linear(10, 5) + + def forward(self, x): + return self.net2(x) + + torch.cuda.set_device(self.rank) + model = ToyModel().to(torch.cuda.current_device()) + ddp_model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[self.rank], find_unused_parameters=True + ) + inp = torch.randn(20, 10, device=self.rank) + for i in range(6): + out = ddp_model(inp) + loss = out.sum() + loss.backward() + + @require_backend({"gloo", "nccl"}) + @require_backends_available({"gloo", "nccl"}) + @skip_if_lt_x_gpu(2) + @skip_if_rocm + def test_ddp_device(self): + m = nn.Linear(10, 10).to(self.rank) + expected_len = 2 + + class TensorWrapper: + __slots__ = ['t', 'moved_to_gpu'] + + def __init__(self, t): + self.t = t + self.moved_to_gpu = False + + # Handlers for specific types of validation we want to do based on + # the input type. + + def tuple_and_list_validator(x): + self.assertTrue(len(x), expected_len) + self.assertEqual(1, len(set(t.device for t in x))) + self.assertEqual(x[0].device.index, self.rank) + return x[0] + x[1] + + def namedtuple_validator(x): + self.assertEqual(x._fields, EXPECTED_FIELDS) + self.assertEqual(x.a.device.index, x.b.device.index) + self.assertEqual(x.a.device.index, self.rank) + return x.a + x.b + + def custom_type_validator(x): + self.assertTrue(x.moved_to_gpu or (str(x.t.device) == "cpu")) + x.t = x.t.to(self.rank) + x.moved_to_gpu = True + return x.t + + def dict_validator(x): + self.assertTrue(EXPECTED_FIELDS[0] in x.keys()) + self.assertTrue(EXPECTED_FIELDS[1] in x.keys()) + self.assertEqual(1, len(set(t.device for t in x.values()))) + self.assertEqual(x[EXPECTED_FIELDS[0]].device.index, self.rank) + return x[EXPECTED_FIELDS[0]] + x[EXPECTED_FIELDS[1]] + + validators = { + TensorWrapper: custom_type_validator, + tuple: tuple_and_list_validator, + list: tuple_and_list_validator, + TestNamedTupleInput_0: namedtuple_validator, + TestNamedTupleInput_1: namedtuple_validator, + dict: dict_validator, + } + + class ToyModel(torch.nn.Module): + def __init__(_self): # noqa: B902 + super().__init__() + _self.lin = nn.Linear(10, 10, bias=False) + + def forward(_self, x, expected_type): # noqa: B902 + # Similar to scatter, the recursive to in the single-device + # case does not move tensors if they are in a custom type. + self.assertTrue(isinstance(x, expected_type)) + fwd_tensor = validators[expected_type](x) + return _self.lin(fwd_tensor) + + model = torch.nn.parallel.DistributedDataParallel( + ToyModel().to(self.rank), device_ids=[self.rank] + ) + + def train_iter(inp, input_type): + for _ in range(4): + out = model(inp, input_type) + out.sum().backward() + + # CPU tuple input, should be moved to the proper device before call + # to forward. + inp = tuple(torch.randn(10, 10) for _ in range(expected_len)) + train_iter(inp, tuple) + + # List CPU input, should be moved to proper device before call to + # forward. + inp = [torch.randn(10, 10) for _ in range(expected_len)] + train_iter(inp, list) + # Custom type containing tensor. The type is maintained, but the + # device is not propagated (which is what happens with scatter too) + inp = TensorWrapper(torch.randn(10, 10)) + train_iter(inp, TensorWrapper) + # NamedTuple input. The type should be maintained and tensor inputs + # should be moved to the correct device as in scatter. + batch = 5 + dim = 10 + a = torch.rand(batch, dim) + b = torch.rand(batch, dim) + + inp = TestNamedTupleInput_0(a, b) + train_iter(inp, type(inp)) + + inp = TestNamedTupleInput_1(a, b) + train_iter(inp, type(inp)) + + # dictionary input. + inp = { + EXPECTED_FIELDS[0]: a, + EXPECTED_FIELDS[1]: b, + } + train_iter(inp, type(inp)) + + @require_backend({"gloo", "nccl"}) + @require_backends_available({"gloo", "nccl"}) + @skip_if_lt_x_gpu(2) + @skip_if_rocm + def test_ddp_namedtuple(self): + batch = 5 + dim = 10 + + a = torch.rand(batch, dim, device=self.rank) + b = torch.rand(batch, dim, device=self.rank) + + class NamedTupleModule(torch.nn.Module): + def __init__(_self): # noqa + super().__init__() + _self.lin = nn.Linear(10, 1) + + def forward(_self, input, expected_type): # noqa + # Without NamedTuple support, this would be of type tuple. + self.assertTrue( + isinstance(input, expected_type), + f"Expected type {expected_type} but got {type(input)}", + ) + self.assertEqual(input._fields, EXPECTED_FIELDS) + self.assertEqual(a, input.a) + self.assertEqual(b, input.b) + return _self.lin(torch.mul(input.a, input.b)) + + model = torch.nn.parallel.DistributedDataParallel( + NamedTupleModule().cuda(self.rank), device_ids=[self.rank] + ) + inp = TestNamedTupleInput_0(a, b) + # The following would fail if DDP does not propagate NamedTuples correctly. + model(inp, type(inp)) + + inp = TestNamedTupleInput_1(a, b) + model(inp, type(inp)) + + @require_backend({"gloo", "nccl"}) + @require_backends_available({"gloo", "nccl"}) + @skip_if_lt_x_gpu(2) + @skip_if_rocm + def test_ddp_control_flow_same_across_ranks(self): + # Control flow that is the same across ranks. + batch = 20 + dim = 10 + + class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + self.lin1 = nn.Linear(10, 10, bias=False) + self.lin2 = nn.Linear(10, 10, bias=False) + + def forward(self, x): + # Second layer is used dependent on input x. + use_second_layer = torch.equal( + x, torch.ones(batch, dim, device=x.device) + ) + if use_second_layer: + return self.lin2(F.relu(self.lin1(x))) + else: + return F.relu(self.lin1(x)) + + world_size = dist.get_world_size() + torch.cuda.set_device(self.rank) + model = torch.nn.parallel.DistributedDataParallel( + ToyModel().cuda(self.rank), + device_ids=[self.rank], + find_unused_parameters=True, + ) + random_input = torch.randn(batch, dim, device=self.rank) + ones_input = torch.ones(batch, dim, device=self.rank) + for i in range(6): + if i % 2 == 0: + out = model(random_input) + else: + out = model(ones_input) + loss = out.sum() + loss.backward() + # On even iterations, 2nd param goes unused, on odd iterations, + # it is used. + local_used_maps = model.reducer._get_local_used_maps() + if i % 2 == 0: + expected = torch.tensor([world_size, 0], device=self.rank, dtype=torch.int32) + else: + expected = torch.tensor([world_size, world_size], device=self.rank, dtype=torch.int32) + + # Validate parameter usage. + variable_usage_tensor = local_used_maps[0] + self.assertEqual(variable_usage_tensor, expected) + + # Validate appropriate error message when DDP is used with + # find_unused_parameters=False. + model = torch.nn.parallel.DistributedDataParallel( + ToyModel().cuda(self.rank), + device_ids=[self.rank], + find_unused_parameters=False, + ) + for i in range(2): + with self.assertRaisesRegex( + RuntimeError, + "Expected to have finished reduction in the prior iteration before starting a new one", + ) if i == 1 else suppress(): + loss = model(random_input).sum() + loss.backward() + + @require_backend({"gloo", "nccl"}) + @require_backends_available({"gloo", "nccl"}) + @skip_if_lt_x_gpu(2) + @skip_if_rocm + def test_ddp_control_flow_different_across_ranks(self): + # Control flow that is different across ranks. + batch = 20 + dim = 10 + + class ToyModel(nn.Module): + def __init__(self, rank): + super(ToyModel, self).__init__() + self.lin1 = nn.Linear(10, 10, bias=False) + self.lin2 = nn.Linear(10, 10, bias=False) + self.rank = rank + + def forward(self, x): + # Control-flow that is rank and input dependent for the + # model. + use_second_layer = ( + torch.equal(x, torch.ones(batch, dim, device=x.device)) + and self.rank == 1 + ) + + if use_second_layer: + return self.lin2(F.relu(self.lin1(x))) + else: + return F.relu(self.lin1(x)) + + world_size = dist.get_world_size() + torch.cuda.set_device(self.rank) + model = torch.nn.parallel.DistributedDataParallel( + ToyModel(self.rank).cuda(self.rank), + device_ids=[self.rank], + find_unused_parameters=True, + ) + random_input = torch.randn(batch, dim, device=self.rank) + ones_input = torch.ones(batch, dim, device=self.rank) + for i in range(6): + if i % 2 == 0: + out = model(random_input) + else: + out = model(ones_input) + loss = out.sum() + loss.backward() + # On even iterations, 2nd param goes unused, on odd iterations, + # it is used only on rank 1. + local_used_maps = model.reducer._get_local_used_maps() + + if i % 2 == 0: + expected = torch.tensor([world_size, 0], device=self.rank, dtype=torch.int32) + else: + expected = torch.tensor([world_size, 1], device=self.rank, dtype=torch.int32) + + variable_usage_tensor = local_used_maps[0] + # Validate parameter usage. On odd iterations, 2nd param is only + # used on rank 1. + self.assertEqual(variable_usage_tensor, expected) + + # Validate appropriate error message when DDP is used with + # find_unused_parameters=False. + model = torch.nn.parallel.DistributedDataParallel( + ToyModel(self.rank).cuda(self.rank), + device_ids=[self.rank], + find_unused_parameters=False, + ) + for i in range(2): + with self.assertRaisesRegex( + RuntimeError, + "Expected to have finished reduction in the prior iteration before starting a new one", + ) if i == 1 else suppress(): + loss = model(random_input).sum() + loss.backward() + + @require_backend({"gloo"}) + @unittest.skipIf(BACKEND == "nccl", "NCCL does not support scatter") + def test_scatter_object_list(self): + src_rank = 0 + scatter_list = ( + collectives_object_test_list + if self.rank == src_rank + else [None for _ in collectives_object_test_list] + ) + world_size = dist.get_world_size() + scatter_list = scatter_list[: world_size] + i = 0 + while len(scatter_list) < world_size: + scatter_list.append(scatter_list[i]) + i += 1 + + output_obj_list = [None] + dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank) + self.assertEqual( + output_obj_list[0], + collectives_object_test_list[self.rank % len(collectives_object_test_list)], + ) + # Ensure errors are raised upon incorrect arguments. + with self.assertRaisesRegex( + RuntimeError, + "Expected argument scatter_object_output_list to be a list of size at least 1.", + ): + dist.scatter_object_list([], scatter_list, src=src_rank) diff --git a/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/torch/testing/_internal/distributed/nn/api/remote_module_test.py index 1b453bc12a06d..4f14584af3b13 100644 --- a/torch/testing/_internal/distributed/nn/api/remote_module_test.py +++ b/torch/testing/_internal/distributed/nn/api/remote_module_test.py @@ -78,7 +78,7 @@ def world_size(self): # Override setting in RpcAgentTestFixture return 2 @staticmethod - def _create_remote_module_iter(dst_worker_name, device="cpu", modes=None): + def _create_remote_module_iter(remote_device, modes=None): if modes is None: modes = ModuleCreationMode.__members__.values() @@ -86,15 +86,12 @@ def _create_remote_module_iter(dst_worker_name, device="cpu", modes=None): kwargs = dict(first_kwarg=2) if ModuleCreationMode.MODULE_CTOR in modes: - remote_module = RemoteModule( - dst_worker_name, device, MyModule, args, kwargs - ) + remote_module = RemoteModule(remote_device, MyModule, args, kwargs) yield remote_module if ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE in modes: remote_module = _RemoteModule( - dst_worker_name, - device, + remote_device, create_scripted_module, args, kwargs, @@ -108,6 +105,7 @@ def test_bad_module(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + remote_device = "{}/cpu".format(dst_worker_name) args = (1,) kwargs = dict(first_kwarg=2) @@ -115,13 +113,13 @@ def test_bad_module(self): ValueError, r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of ,", ): - RemoteModule(dst_worker_name, "cpu", BadModule, args, kwargs) + RemoteModule(remote_device, BadModule, args, kwargs) with self.assertRaisesRegex( ValueError, r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of ,", ): - RemoteModule(dst_worker_name, "cpu", BadModule, args, kwargs) + RemoteModule(remote_device, BadModule, args, kwargs) @dist_utils.dist_init def test_forward_async(self): @@ -219,6 +217,21 @@ def test_remote_parameters(self): self.assertEqual(len(param_rrefs), 1) self.assertTrue(torch.equal(param_rrefs[0].to_here(), _PARAM_VAL)) + @dist_utils.dist_init + def test_get_module_rref(self): + if self.rank != 0: + return + dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + + # Only test Python nn.Module, because script module methods don't support ``get_module_rref``. + for remote_module in self._create_remote_module_iter( + dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] + ): + rref = remote_module.get_module_rref() + self.assertEqual(rref, remote_module.module_rref) + for param in rref.to_here().parameters(): + self.assertTrue(torch.equal(param, _PARAM_VAL)) + @skip_if_lt_x_gpu(1) @dist_utils.dist_init def test_valid_device(self): @@ -227,7 +240,7 @@ def test_valid_device(self): dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) for remote_module in self._create_remote_module_iter( - dst_worker_name, device="cuda:0", modes=[ModuleCreationMode.MODULE_CTOR] + "{}/cuda:0".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR] ): device = rpc.rpc_sync( dst_worker_name, remote_device, (remote_module.module_rref,) @@ -244,12 +257,12 @@ def test_invalid_devices(self): with self.assertRaisesRegex( RuntimeError, - r"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla device type at start of device string", + r"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan" + " device type at start of device string", ): list( self._create_remote_module_iter( - dst_worker_name, - device="foo", + "{}/foo".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], ) ) @@ -259,8 +272,7 @@ def test_invalid_devices(self): ): list( self._create_remote_module_iter( - dst_worker_name, - device="cuda:100", + "{}/cuda:100".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], ) ) @@ -268,19 +280,48 @@ def test_invalid_devices(self): with self.assertRaisesRegex(RuntimeError, r"Invalid device string: 'cpu2'"): list( self._create_remote_module_iter( - dst_worker_name, + "{}/cpu2".format(dst_worker_name), + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ) + + with self.assertRaisesRegex(RuntimeError, r"Device string must not be empty"): + list( + self._create_remote_module_iter( + "{}/".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], - device="cpu2", ) ) with self.assertRaisesRegex( - RuntimeError, r"CPU device index must be -1 or zero, got 2" + RuntimeError, + r"Could not parse remote_device: worker1/cuda:0/cuda:1. The valid format is '/'", + ): + list( + self._create_remote_module_iter( + "{}/cuda:0/cuda:1".format(dst_worker_name), + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ) + + with self.assertRaisesRegex( + RuntimeError, + r"The workername in remote_device '/' cannot be empty. The valid format is '/'", + ): + list( + self._create_remote_module_iter( + "/", + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ) + + with self.assertRaisesRegex( + RuntimeError, + r"The workername in remote_device '/cuda:0' cannot be empty. The valid format is '/'", ): list( self._create_remote_module_iter( - dst_worker_name, - device="cpu:2", + "/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR], ) ) diff --git a/torch/testing/_internal/distributed/pipe_with_ddp_test.py b/torch/testing/_internal/distributed/pipe_with_ddp_test.py new file mode 100644 index 0000000000000..5362c7ae9bad5 --- /dev/null +++ b/torch/testing/_internal/distributed/pipe_with_ddp_test.py @@ -0,0 +1,134 @@ +import torch +import torch.distributed as dist +import unittest + +from torch import nn +from torch.nn.parallel import DistributedDataParallel +from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( + RpcAgentTestFixture, +) +from torch.testing._internal.common_distributed import ( + requires_gloo, + requires_nccl, + skip_if_lt_x_gpu, + skip_if_rocm, +) +from torch.distributed.pipeline.sync import Pipe + +class PipeWithDDPTest(RpcAgentTestFixture): + @property + def world_size(self) -> int: + return 2 + + @skip_if_lt_x_gpu(4) + @requires_nccl() + @dist_init + @skip_if_rocm + def test_basic_nccl_ckpt_never(self): + self._run_basic_test("nccl", "never") + + @skip_if_lt_x_gpu(4) + @requires_nccl() + @dist_init + @skip_if_rocm + def test_basic_nccl_ckpt_never_find_unused(self): + self._run_basic_test("nccl", "never", find_unused_parameters=True) + + @skip_if_lt_x_gpu(4) + @requires_nccl() + @dist_init + @skip_if_rocm + @unittest.skip("DDP doesn't work with checkpointing") + def test_basic_nccl_ckpt_always(self): + self._run_basic_test("nccl", "always") + + @skip_if_lt_x_gpu(4) + @requires_nccl() + @dist_init + @skip_if_rocm + @unittest.skip("DDP doesn't work with checkpointing") + def test_basic_nccl_ckpt_except_last(self): + self._run_basic_test("nccl", "except_last") + + @skip_if_lt_x_gpu(4) + @requires_gloo() + @dist_init + @skip_if_rocm + def test_basic_gloo_ckpt_never(self): + self._run_basic_test("gloo", "never") + + @skip_if_lt_x_gpu(4) + @requires_gloo() + @dist_init + @skip_if_rocm + def test_basic_gloo_ckpt_never_find_unused(self): + self._run_basic_test("gloo", "never", find_unused_parameters=True) + + @skip_if_lt_x_gpu(4) + @requires_gloo() + @dist_init + @skip_if_rocm + @unittest.skip("DDP doesn't work with checkpointing") + def test_basic_gloo_ckpt_always(self): + self._run_basic_test("gloo", "always") + + @skip_if_lt_x_gpu(4) + @requires_gloo() + @dist_init + @skip_if_rocm + @unittest.skip("DDP doesn't work with checkpointing") + def test_basic_gloo_ckpt_except_last(self): + self._run_basic_test("gloo", "except_last") + + def _run_basic_test(self, backend, checkpoint, find_unused_parameters=False): + dist.init_process_group( + backend="nccl", + init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), + world_size=self.world_size, + rank=self.rank, + ) + + # Use 4 GPUs, two replicas of a pipe across GPU 0 and 1 and another + # pipe between GPU 2 and 3. Both replicas are replicated via DDP. + fc1 = nn.Linear(16, 8, bias=False).cuda(2 * self.rank) + + class MyModule(nn.Module): + def __init__(self, device): + super(MyModule, self).__init__() + self.fc2 = nn.Linear(8, 4, bias=False).cuda(device) + self.fc3 = nn.Linear(4, 2, bias=False).cuda(device) + + def forward(self, inp): + if find_unused_parameters: + return self.fc2(inp) + else: + return self.fc3(self.fc2(inp)) + + layer2 = MyModule(2 * self.rank + 1) + model = nn.Sequential( + fc1, + layer2 + ) + model = Pipe(model, chunks=2, checkpoint=checkpoint) + model = DistributedDataParallel(model, find_unused_parameters=find_unused_parameters) + out = model(torch.rand(16, 16).cuda(2 * self.rank)).local_value() + out.sum().backward() + + # Run forward again for find_unused_parameters to trigger any potential errors. + if find_unused_parameters: + model(torch.rand(16, 16).cuda(2 * self.rank)) + + # Check grads + output = [torch.empty_like(fc1.weight.grad), torch.empty_like(fc1.weight.grad)] + dist.all_gather(output, fc1.weight.grad) + self.assertEqual(output[0], output[1]) + + output = [torch.empty_like(layer2.fc2.weight.grad), torch.empty_like(layer2.fc2.weight.grad)] + dist.all_gather(output, layer2.fc2.weight.grad) + self.assertEqual(output[0], output[1]) + + if not find_unused_parameters: + output = [torch.empty_like(layer2.fc3.weight.grad), torch.empty_like(layer2.fc3.weight.grad)] + dist.all_gather(output, layer2.fc3.weight.grad) + self.assertEqual(output[0], output[1]) diff --git a/torch/testing/_internal/distributed/pipeline/__init__.py b/torch/testing/_internal/distributed/pipeline/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/testing/_internal/distributed/pipeline/utils.py b/torch/testing/_internal/distributed/pipeline/utils.py new file mode 100644 index 0000000000000..2bf4829b82232 --- /dev/null +++ b/torch/testing/_internal/distributed/pipeline/utils.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from torch import nn +from typing import List + +def convert_to_balance(pipe: nn.Sequential, balance: List[int]): + device_idx = 0 + pipe_idx = 0 + balanced_pipe = [] + for num_layers in balance: + layers = [] + for i in range(num_layers): + layers.append(pipe[pipe_idx]) + pipe_idx += 1 + balanced_pipe.append(nn.Sequential(*layers).to(device_idx)) + device_idx += 1 + + return nn.Sequential(*balanced_pipe) diff --git a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py index dd7481b2d80fc..15d5cfeca214f 100644 --- a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py +++ b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py @@ -11,6 +11,7 @@ import torch.testing._internal.dist_utils from torch.autograd import Function from torch.autograd.function import once_differentiable +from torch.distributed.rpc import RRef from torch.testing._internal.common_utils import IS_MACOS from torch.testing._internal.dist_utils import ( dist_init, @@ -70,8 +71,7 @@ def create_tensor(): @torch.jit.script -def create_torchscript_tensor(): - # type: () -> Tensor +def create_torchscript_tensor() -> torch.Tensor: return torch.ones((3, 3)).requires_grad_() @@ -94,8 +94,7 @@ def my_script_add(t1, t2): @torch.jit.script -def my_script_ref_add(ref_t1, t2): - # type: (RRef[Tensor], Tensor) -> Tensor +def my_script_ref_add(ref_t1: RRef[torch.Tensor], t2: torch.Tensor) -> torch.Tensor: t1 = ref_t1.to_here() return torch.add(t1, t2) @@ -355,7 +354,7 @@ def _verify_graph_for_nested_rpc_call(self, ctx): def _test_graph(self, fn, exec_mode): dst_rank = (self.rank + 1) % self.world_size - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) with dist_autograd.context() as context_id: t1 = torch.ones(3, 3, requires_grad=True) @@ -427,7 +426,7 @@ def test_graph_for_python_remote_call(self): def _test_graph_for_py_nested_call(self, exec_mode): dst_rank = (self.rank + 1) % self.world_size - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) with dist_autograd.context() as context_id: t1 = torch.ones(3, 3, requires_grad=True) @@ -514,7 +513,7 @@ def test_graph_for_py_nested_remote_call(self): def _test_graph_for_py_nested_call_itself(self, exec_mode): dst_rank = (self.rank + 1) % self.world_size - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) with dist_autograd.context() as context_id: t1 = torch.ones(3, 3, requires_grad=True) @@ -590,7 +589,7 @@ def test_graph_for_py_nested_remote_call_itself(self): self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE) def _test_no_graph_with_tensors_not_require_grad(self, exec_mode): - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) dst_rank = (self.rank + 1) % self.world_size with dist_autograd.context() as context_id: t1 = torch.ones(3, 3, requires_grad=False) @@ -636,7 +635,7 @@ def test_no_graph_with_tensors_not_require_grad_remote(self): self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE) def _test_grad_only_on_return_value(self, exec_mode): - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) dst_rank = (self.rank + 1) % self.world_size with dist_autograd.context() as context_id: if ExecMode.RPC_SYNC == exec_mode: @@ -719,7 +718,7 @@ def test_remote_complex_args(self): self._test_rpc_complex_args(ExecMode.REMOTE) def context_cleanup_test_helper(self, rpc_args, func, nested=False): - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) # test that in dist autograd, in the case that tensors communicated over RPC do # NOT require grad, we still cleanup the dist autograd contexts created @@ -1206,7 +1205,7 @@ def test_backward_autograd_engine_error(self): ) def test_backward_node_failure(self): rpc._set_rpc_timeout(5) # 5 seconds - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) with dist_autograd.context() as context_id: t1 = torch.rand((3, 3), requires_grad=True) @@ -1429,7 +1428,7 @@ def test_backward_python_udf_error(self): def test_backward_node_failure_python_udf(self): # Set a short timeout to quickly time out failed RPCs. rpc._set_rpc_timeout(5) # 5 seconds - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) with dist_autograd.context() as context_id: t1 = torch.rand((3, 3), requires_grad=True) @@ -1534,7 +1533,7 @@ def test_clean_context_during_backward(self): It is fine for the 'backward' call to throw an exception in this test, but the process should not crash. """ - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) context = dist_autograd._new_context() context_id = context._context_id() @@ -1682,7 +1681,7 @@ def backward(ctx, input): @dist_init def test_debug_info(self): - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) t1 = torch.rand((3, 3), requires_grad=True) t2 = torch.rand((3, 3), requires_grad=True) @@ -1758,7 +1757,7 @@ def test_async_dist_autograd(self): hammering a single node with a lot of backward() calls. """ - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) if self.rank != 0: # All other ranks schedule work on rank 0. threads = [] @@ -1832,7 +1831,7 @@ def test_multiple_backward(self): @dist_init(clean_shutdown=False) def test_multiple_backward_with_errors(self): - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) t1 = torch.rand((3, 3), requires_grad=True) t2 = torch.rand((3, 3), requires_grad=True) with dist_autograd.context() as context_id: @@ -2197,7 +2196,7 @@ class FaultyAgentDistAutogradTest(RpcAgentTestFixture): # Reusing a simplified helper function from DistAutogradTest to ensure # autograd context is successfully cleaned up even when RPCs are failing. def context_cleanup_test_helper(self, rpc_args, func): - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) # test that in dist autograd, in the case that tensors communicated over RPC do # NOT require grad, we still cleanup the dist autograd contexts created diff --git a/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py b/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py index 3754aa014ad22..b111ff6146087 100644 --- a/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py +++ b/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py @@ -198,3 +198,66 @@ def test_dist_optim(self): # ensure local equals remote self.assertEqual(new_w1, module1.get_w()) self.assertEqual(new_w2, module2.get_w()) + + + @dist_init + def test_dist_optim_functional(self): + # local version + module1 = MyModule() + module2 = MyModule() + params = [module1.get_w(), module2.get_w()] + local_optim = optim.Adagrad(params, lr=0.05) + + old_w1 = module1.w.clone().detach() + old_w2 = module2.w.clone().detach() + + g_cpu = torch.Generator() + g_cpu.manual_seed(0) + t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) + t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) + output1 = module1.forward(t2) + output2 = module2.forward(output1) + loss = torch.add(output2, t1).sum() + + loss.backward() + local_optim.step() + + # distributed version + owner1 = "worker%d" % ((self.rank + 1) % self.world_size) + owner2 = "worker%d" % ((self.rank + 2) % self.world_size) + + remote_module1 = rpc.remote(owner1, MyModule) + remote_module2 = rpc.remote(owner2, MyModule) + remote_param1 = remote_method(MyModule.get_w, remote_module1) + remote_param2 = remote_method(MyModule.get_w, remote_module2) + + old_w1_remote = remote_param1.to_here() + + # sanity check: local and remote initial weights should match + self.assertEqual(old_w1, remote_param1.to_here()) + self.assertEqual(old_w2, remote_param2.to_here()) + + dist_optim = DistributedOptimizer( + optim.Adagrad, [remote_param1, remote_param2], lr=0.05 + ) + + with dist_autograd.context() as context_id: + g_cpu.manual_seed(0) + t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) + t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) + output1 = rpc_async_method(MyModule.forward, remote_module1, t2) + output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait()) + loss = torch.add(output2.wait(), t1) + + dist_autograd.backward(context_id, [loss.sum()]) + dist_optim.step(context_id) + + new_w1 = rpc_async_method(MyModule.get_w, remote_module1).wait() + new_w2 = rpc_async_method(MyModule.get_w, remote_module2).wait() + + # ensure optimizer changed weights + self.assertNotEqual(old_w1, new_w1) + self.assertNotEqual(old_w2, new_w2) + # ensure local equals remote + self.assertEqual(new_w1, module1.get_w()) + self.assertEqual(new_w2, module2.get_w()) diff --git a/torch/testing/_internal/distributed/rpc/examples/__init__.py b/torch/testing/_internal/distributed/rpc/examples/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py b/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py new file mode 100644 index 0000000000000..414e079b86d36 --- /dev/null +++ b/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py @@ -0,0 +1,139 @@ +# If you need to modify this file to make this test pass, please also apply same edits accordingly to +# https://github.com/pytorch/examples/blob/master/distributed/rpc/batch/parameter_server.py +# and https://pytorch.org/tutorials/intermediate/rpc_async_execution.html#batch-updating-parameter-server + +import threading +from datetime import datetime +from time import perf_counter + +import torch +import torch.distributed.rpc as rpc +import torch.nn as nn +from torch import optim + +from torch.testing._internal.dist_utils import ( + dist_init, + worker_name, +) +from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import RpcAgentTestFixture + +batch_size = 20 +in_features = 100 +out_features = 30 +num_batches = 4 + + +def timed_log(text): + print(f"{datetime.now().strftime('%H:%M:%S')} {text}") + + +class BatchUpdateParameterServer(object): + + def __init__(self, batch_update_size): + self.model = nn.Linear(in_features, out_features) + self.lock = threading.Lock() + self.future_model = torch.futures.Future() + self.batch_update_size = batch_update_size + self.curr_update_size = 0 + self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9) + for p in self.model.parameters(): + p.grad = torch.zeros_like(p) + + def get_model(self): + return self.model + + @staticmethod + @rpc.functions.async_execution + def update_and_fetch_model(ps_rref, grads): + self = ps_rref.local_value() + for p, g in zip(self.model.parameters(), grads): + p.grad += g + with self.lock: + timed_log(f"PS got {self.curr_update_size}/{self.batch_update_size} updates") + self.curr_update_size += 1 + fut = self.future_model + + if self.curr_update_size >= self.batch_update_size: + for p in self.model.parameters(): + p.grad /= self.batch_update_size + self.curr_update_size = 0 + self.optimizer.step() + self.optimizer.zero_grad() + fut.set_result(self.model) + timed_log("PS updated model") + self.future_model = torch.futures.Future() + + return fut + + +class Trainer(object): + + def __init__(self, ps_rref): + self.ps_rref = ps_rref + self.loss_fn = nn.L1Loss() + + def get_next_batch(self): + for _ in range(num_batches): + inputs = torch.randn(batch_size, in_features) + labels = torch.zeros(batch_size, out_features) + yield inputs, labels + + def train(self): + name = rpc.get_worker_info().name + m = self.ps_rref.rpc_sync().get_model() + for inputs, labels in self.get_next_batch(): + timed_log(f"{name} processing one batch") + self.loss_fn(m(inputs), labels).backward() + timed_log(f"{name} reporting grads") + m = rpc.rpc_sync( + self.ps_rref.owner(), + BatchUpdateParameterServer.update_and_fetch_model, + args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]), + ) + timed_log(f"{name} got updated model") + + +def run_trainer(ps_rref): + trainer = Trainer(ps_rref) + trainer.train() + + +def run_ps(trainers): + timed_log("Start training") + start = perf_counter() + ps_rref = rpc.RRef(BatchUpdateParameterServer(len(trainers))) + futs = [] + for trainer in trainers: + futs.append( + rpc.rpc_async(trainer, run_trainer, args=(ps_rref,)) + ) + + torch.futures.wait_all(futs) + stop = perf_counter() + timed_log("Finish training") + timed_log(f"Time spent training: {stop-start}s") + +class ParameterServerTest(RpcAgentTestFixture): + + @dist_init(setup_rpc=False) + def test_batch_updating_parameter_server(self): + + if self.rank != 0: + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + else: + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + run_ps([f"{worker_name(r)}" for r in range(1, self.world_size)]) + + rpc.shutdown() diff --git a/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py b/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py index ee3ebdb33eff9..c4887f8c23f31 100644 --- a/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py +++ b/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py @@ -1,4 +1,4 @@ -from typing import Tuple, Dict +from typing import Dict, Tuple import torch import torch.distributed.autograd as dist_autograd @@ -34,8 +34,7 @@ def test_get_gradients(self): dst_rank = self.rank @torch.jit.script - def dist_get_gradients(context_id): - # type: (int) -> (Dict[Tensor, Tensor]) + def dist_get_gradients(context_id: int) -> (Dict[Tensor, Tensor]): return dist_autograd.get_gradients(context_id) FileCheck().check("get_gradients").run(str(dist_get_gradients.graph)) diff --git a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py index a43d0f4fb4003..2a0b114f2b8a0 100644 --- a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py @@ -21,6 +21,8 @@ RpcAgentTestFixture, ) +def rref_isinstance(rref, cls_to_check): + return isinstance(rref.local_value(), cls_to_check) def sleep(t): time.sleep(t) @@ -357,6 +359,10 @@ def __init__(self, rank): def forward(self) -> Tensor: return self.a + @torch.jit.script_method + def custom_func(self) -> Tensor: + return self.a + def owner_create_rref_my_script_class(a): return rpc.RRef(MyScriptClass(a)) @@ -894,17 +900,14 @@ def test_torchscript_functions_not_supported(self): # wait for local MyScriptModule instantiation to finish, # otherwise it could instantiate MyScriptModule in parallel with # server thread in the below - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) dist.barrier() # rpc_sync still accepts script class and run it in # the same code path as python call. ret = rpc.rpc_sync(dst_worker_name, MyScriptClass, args=(self.rank,)) - # rpc_sync does not accept script module and script module method. - with self.assertRaisesRegex(RuntimeError, "ScriptModules cannot be deepcopied"): - ret = rpc.rpc_sync(dst_worker_name, MyScriptModule, args=(self.rank,)) - + # rpc_sync does not accept script module method. # Python 3.5 and Python 3.6 throw different error message, the only # common word can be greped is "pickle". with self.assertRaisesRegex(TypeError, "pickle"): @@ -949,6 +952,42 @@ def test_remote_script_module(self): args=(remote_ref, torch.ones(self.rank)), ) + @dist_init + def test_create_script_module_on_remote(self): + dst_name = worker_name((self.rank + 1) % self.world_size) + # Construct on remote end with rpc_sync + created_script_module = rpc.rpc_sync( + dst_name, MyScriptModule, args=(self.rank,) + ) + # Forward should output a ones tensor of self.rank. + self.assertTrue(isinstance(created_script_module, torch.jit.ScriptModule)) + rank_ones_tensor = created_script_module() + self.assertEqual(torch.ones(self.rank), rank_ones_tensor) + + # Construct ScriptModule with rpc.remote. + remote_script_module = rpc.remote(dst_name, MyScriptModule, args=(self.rank,)) + # Verify it is an instance of ScriptModule on remote end. + remote_end_is_script = rpc.rpc_sync( + remote_script_module.owner(), + rref_isinstance, + args=(remote_script_module, torch.jit.ScriptModule), + ) + self.assertTrue(remote_end_is_script) + # Run forward pass remotely. + remote_forward_output = remote_script_module.rpc_sync().forward() + self.assertEqual(remote_forward_output, torch.ones(self.rank)) + # Run function defined on ScriptModule remotely. + remote_func_output = remote_script_module.rpc_sync().custom_func() + self.assertEqual(remote_func_output, torch.ones(self.rank)) + # Ensure we can transfer ScriptModule RRef to this rank and run + # forward pass. + local_script_module = remote_script_module.to_here() + self.assertTrue(isinstance(local_script_module, torch.jit.ScriptModule)) + rank_ones_tensor = local_script_module() + self.assertEqual(rank_ones_tensor, torch.ones(self.rank)) + local_script_func_output = local_script_module.custom_func() + self.assertEqual(local_script_func_output, torch.ones(self.rank)) + @dist_init def test_load_script_module_with_pickled_rref(self): dst_name = worker_name((self.rank + 1) % self.world_size) @@ -1038,6 +1077,30 @@ def callback(fut): self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs) + @dist_init + def test_add_done_callback(self): + callback_called = None + + def callback(fut): + nonlocal callback_called + callback_called = fut.wait() * 2 + + future = rpc.rpc_async( + worker_name((self.rank + 1) % self.world_size), + script_fork_wait_udf, + args=(torch.ones(2),), + ) + + future.add_done_callback(callback) + future_then = future.then(lambda _: True) + + self.assertEqual(future.wait(), torch.ones(2) * 2) + + # We have no guarantee that the add_done_callback fn will execute before the test finishes. + # Adding a 'then' callback that runs afterwards to guarantee we wait for the first callback + future_then.wait() + self.assertEqual(callback_called, torch.ones(2) * 4) + @dist_init def test_async_script_throw(self): future = rpc.rpc_async( diff --git a/torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py b/torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py index 656f25322274b..96ede7231a972 100644 --- a/torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py +++ b/torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py @@ -3,6 +3,7 @@ import torch import torch.distributed.rpc as rpc from torch import Tensor +from torch.distributed.rpc import RRef from torch.testing._internal.dist_utils import ( dist_init, worker_name, @@ -63,18 +64,15 @@ def rpc_async_call_future_ret( return fut @torch.jit.script -def rref_to_here(rref_var): - # type: (RRef[Tensor]) -> Tensor +def rref_to_here(rref_var: RRef[Tensor]) -> Tensor: return rref_var.to_here() @torch.jit.script -def rref_to_here_with_timeout(rref_var, timeout): - # type: (RRef[Tensor], float) -> Tensor +def rref_to_here_with_timeout(rref_var: RRef[Tensor], timeout: float) -> Tensor: return rref_var.to_here(timeout) @torch.jit.script -def rpc_async_with_rref_arg(dst_worker_name, args): - # type: (str, Tuple[RRef[Tensor]]) -> Tensor +def rpc_async_with_rref_arg(dst_worker_name: str, args: Tuple[RRef[Tensor]]) -> Tensor: fut = rpc.rpc_async(dst_worker_name, rref_to_here, args) ret = fut.wait() return ret diff --git a/torch/testing/_internal/distributed/rpc/process_group_agent_test_fixture.py b/torch/testing/_internal/distributed/rpc/process_group_agent_test_fixture.py index 893e5b8e17b0c..3a49304763455 100644 --- a/torch/testing/_internal/distributed/rpc/process_group_agent_test_fixture.py +++ b/torch/testing/_internal/distributed/rpc/process_group_agent_test_fixture.py @@ -13,12 +13,19 @@ def rpc_backend(self): @property def rpc_backend_options(self): - return rpc.backend_registry.construct_rpc_backend_options( - self.rpc_backend, - init_method=self.init_method, - # Some tests need additional threads (ex: test_trainer_ps) - num_send_recv_threads=8, - ) + try: + return self._rpc_backend_options + except AttributeError: + return rpc.backend_registry.construct_rpc_backend_options( + self.rpc_backend, + init_method=self.init_method, + # Some tests need additional threads (ex: test_trainer_ps) + num_send_recv_threads=8, + ) + + @rpc_backend_options.setter + def rpc_backend_options(self, new_rpc_backend_options): + self._rpc_backend_options = new_rpc_backend_options def get_shutdown_error_regex(self): error_regexes = [ diff --git a/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py b/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py index ed0b1cb2fe6cf..967539d056c05 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py +++ b/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py @@ -1,3 +1,4 @@ +import os from abc import ABC, abstractmethod import torch.testing._internal.dist_utils @@ -10,6 +11,16 @@ def world_size(self): @property def init_method(self): + use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None) + if use_tcp_init == "1": + master_addr = os.environ["MASTER_ADDR"] + master_port = os.environ["MASTER_PORT"] + return f"tcp://{master_addr}:{master_port}" + else: + return self.file_init_method + + @property + def file_init_method(self): return torch.testing._internal.dist_utils.INIT_METHOD_TEMPLATE.format( file_name=self.file_name ) diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index 163a772628a56..de988eb871410 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -2,6 +2,7 @@ import contextlib import json import logging +import os import sys from threading import Lock import time @@ -15,14 +16,18 @@ import torch.distributed.rpc as rpc import torch.distributed.autograd as dist_autograd from torch.distributed.rpc import RRef, _get_debug_info, _rref_context_get_debug_info -from torch.distributed.rpc.api import _delete_all_user_and_unforked_owner_rrefs, _use_rpc_pickler +from torch.distributed.rpc.api import _delete_all_user_and_unforked_owner_rrefs, _use_rpc_pickler, _thread_local_var, _wait_all from torch.distributed.rpc.internal import ( PythonUDF, RPCExecMode, _internal_rpc_pickler, _build_rpc_profiling_key, ) -from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_distributed import ( + skip_if_lt_x_gpu, + skip_if_no_peer_access, + captured_output, +) from torch.testing._internal.common_utils import IS_MACOS, load_tests from torch.testing._internal.dist_utils import ( dist_init, @@ -32,6 +37,7 @@ wait_until_pending_futures_and_users_flushed, wait_until_owners_and_forks_on_rank, worker_name, + single_threaded_process_group_agent, ) from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( RpcAgentTestFixture, @@ -66,13 +72,16 @@ def udf_with_torch_ops(device=-1, use_record_function=False): "aten::relu", "aten::threshold", "aten::sigmoid", - "aten::sigmoid", ] +# Remote operations are prefixed with the following string for RPC profiling. +REMOTE_OP_STR = "#remote_op: " + VALUE_FUTURE = concurrent.futures.Future() DONE_FUTURE = concurrent.futures.Future() +FIFTY_MIL_CYCLES = 50000000 class StubRpcAgent: def __init__(self, world_size): @@ -145,8 +154,11 @@ def __setstate__(self, obj): class MyClass: - def __init__(self, a): + def __init__(self, a, delay=False): self.a = a + # delay initialization to simulate errors if specified + if delay: + time.sleep(2) def my_instance_method(self, b): return self.a + b @@ -165,6 +177,10 @@ def increment_value(self, increment): def get_value(self): return self.a + def my_slow_method(self, my_tensor_arg): + time.sleep(5) + return torch.add(self.a, my_tensor_arg) + def _call_method_on_rref(method, rref, *args, **kwargs): return method(rref.local_value(), *args, **kwargs) @@ -309,8 +325,13 @@ def my_script_func(tensor): return torch.add(tensor, tensor) +expected_err = "Expected error" def raise_func(): - raise ValueError("Expected error") + raise ValueError(expected_err) + +expected_err_escape = "\nFirst line of error \n next line of error \n last line of error" +def raise_func_escape(): + raise ValueError(expected_err_escape) global_rref = None @@ -386,6 +407,19 @@ def async_wrong_type(): def async_add(to, x, y): return rpc.rpc_async(to, torch.add, args=(x, y)) + +def slow_add(x, y, device="cpu"): + time.sleep(1) + x = x.to(device) + y = y.to(device) + return torch.add(x, y).cpu() + + +@rpc.functions.async_execution +def slow_async_add(to, x, y, device="cpu"): + return rpc.rpc_async(to, slow_add, args=(x, y, device)) + + @rpc.functions.async_execution def async_add_with_future_ctor(to, x, y, z): fut = torch.futures.Future() @@ -523,6 +557,31 @@ def test_self_add(self): self.assertEqual(fut.wait(), torch.ones(2, 2) + 1) self.assertEqual(ret, torch.ones(2, 2) + 1) + @dist_init + def test_send_to_rank(self): + dst_rank = (self.rank + 1) % self.world_size + + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) + self.assertEqual(ret, torch.ones(2, 2) + 1) + + # Test invalid ranks + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + with self.assertRaises(RuntimeError): + self._run_func_in_mode(self.world_size + 1, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) + + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + with self.assertRaises(RuntimeError): + self._run_func_in_mode(-1, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) + + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + with self.assertRaises(ValueError): + self._run_func_in_mode(dst_rank + 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) + + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + with self.assertRaises(ValueError): + self._run_func_in_mode(dst_rank - 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) + @dist_init def test_self_py_udf_remote(self): self_worker_info = rpc.get_worker_info() @@ -563,14 +622,14 @@ def test_self_remote_rref_as_remote_arg(self): def test_rref_proxy_non_exist(self): dst = worker_name((self.rank + 1) % self.world_size) rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3)) - msg = "non_exist is not an attribute of type" - with self.assertRaisesRegex(ValueError, msg): + msg = "has no attribute \'non_exist\'" + with self.assertRaisesRegex(AttributeError, msg): rref.rpc_sync().non_exist() - with self.assertRaisesRegex(ValueError, msg): + with self.assertRaisesRegex(AttributeError, msg): rref.rpc_async().non_exist() - with self.assertRaisesRegex(ValueError, msg): + with self.assertRaisesRegex(AttributeError, msg): rref.remote().non_exist() def _test_rref_proxy_tensor(self, dst): @@ -739,11 +798,19 @@ def test_reinit(self): rpc_backend_options=self.rpc_backend_options, ) - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) # Wait for all init to complete. dist.barrier() - with self.assertRaisesRegex(RuntimeError, "is already initialized"): + # TODO: with TCP init, rank 0 raises Address already in use because + # rank 0 is the start daemon and the store is created before checking if + # RPC is already initialized in init_rpc. + if os.environ.get("RPC_INIT_WITH_TCP", None) == "1" and self.rank == 0: + expected_reinit_err = "Address already in use" + else: + expected_reinit_err = "is already initialized" + + with self.assertRaisesRegex(RuntimeError, expected_reinit_err): rpc.init_rpc( name=worker_name(self.rank), backend=self.rpc_backend, @@ -955,7 +1022,8 @@ def test_all_gather_timeout(self): ): rpc.api._all_gather(SlowPickleClass(0.5)) else: - with self.assertRaisesRegex(RuntimeError, "timeout.*100 ms"): + expected_error = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_error): rpc.api._all_gather(SlowPickleClass(0.5)) @dist_init @@ -1039,6 +1107,33 @@ def check_profiling_info(self, self_worker_name, dst_worker_name, func, rpc_even self.assertTrue(rpc_exec_mode.value in rpc_event.name) self.assertEqual(rpc_event.count, 1) + @dist_init + def test_profiler_rpc_record_shapes(self): + if self.rank != 1: + return + dst = (self.rank + 1) % self.world_size + dst_worker = worker_name(dst) + t1, t2 = torch.ones(100), torch.ones(100) + with torch.autograd.profiler.profile(record_shapes=True) as prof: + rpc.rpc_sync(dst_worker, torch.add, args=(t1, t2)) + + function_events = prof.function_events + remote_events = [event for event in function_events if event.is_remote] + remote_add_event = [ + event for event in remote_events if "aten::add" in event.name + ][0] + remote_add_input_shapes = remote_add_event.input_shapes + # Run profiler on equivalent local op and validate shapes are the same. + with torch.autograd.profiler.profile(record_shapes=True) as prof: + torch.add(t1, t2) + + local_function_events = prof.function_events + local_add_event = [ + event for event in local_function_events if "aten::add" in event.name + ][0] + local_add_input_shapes = local_add_event.input_shapes + self.assertEqual(remote_add_input_shapes, local_add_input_shapes) + @dist_init def test_profiler_rpc_memory(self): if self.rank != 1: @@ -1080,6 +1175,9 @@ def test_profiler_remote_cuda(self): fut1.wait() fut2.wait() + def get_name(event): + return event.name[event.name.find(REMOTE_OP_STR) + len(REMOTE_OP_STR):] + function_events = p.function_events for event in function_events: if event.is_async: @@ -1090,23 +1188,18 @@ def test_profiler_remote_cuda(self): if event.node_id == 1: continue self.assertTrue(event.node_id in [dst_cuda_0, dst_cuda_1]) - self.assertGreater(event.cuda_time_total, 0) - self.assertEqual(1, len(event.kernels)) - kernel = event.kernels[0] - if event.node_id == dst_cuda_0: - self.assertEqual(kernel.device, 0) - if event.node_id == dst_cuda_1: - self.assertEqual(kernel.device, 1) - - self.assertGreater(event.cuda_time, 0) + if get_name(event) in EXPECTED_REMOTE_EVENTS: + self.assertGreater(event.cuda_time_total, 0) + self.assertEqual(1, len(event.kernels)) + kernel = event.kernels[0] + if event.node_id == dst_cuda_0: + self.assertEqual(kernel.device, 0) + if event.node_id == dst_cuda_1: + self.assertEqual(kernel.device, 1) + self.assertGreater(event.cuda_time, 0) # Validate that EXPECTED_REMOTE_EVENTS is a subset of remotely profiled # events. - REMOTE_OP_STR = "#remote_op: " - - def get_name(event): - return event.name[event.name.find(REMOTE_OP_STR) + len(REMOTE_OP_STR):] - remote_events = [event for event in function_events if event.is_remote] remote_event_names = [get_name(event) for event in remote_events if get_name(event) in EXPECTED_REMOTE_EVENTS] self.assertEqual(set(remote_event_names), set(EXPECTED_REMOTE_EVENTS)) @@ -1201,8 +1294,7 @@ def rpc_with_profiling(dst_worker): for fut in futs: fut.result() - @dist_init - def test_profiler_remote_events_profiled(self): + def _run_test_profiler_remote_events_profiled(self): # Tests that we can successfully invoke the profiler on a remote node, # and collect the remote events back in the local profiler. if self.rank != 1: @@ -1235,7 +1327,6 @@ def test_profiler_remote_events_profiled(self): ) for expected_remote_event_name in EXPECTED_REMOTE_EVENTS: - REMOTE_OP_STR = "#remote_op: " expected_key = rpc_profiling_key + REMOTE_OP_STR + expected_remote_event_name self.assertTrue(expected_key in remote_events) remote_event = remote_events[expected_key] @@ -1256,7 +1347,20 @@ def convert_remote_to_local(event_name): for event in events if convert_remote_to_local(event.name) in EXPECTED_REMOTE_EVENTS ] - self.assertEqual(remote_events_list, EXPECTED_REMOTE_EVENTS) + self.assertEqual( + set(remote_events_list), + set(EXPECTED_REMOTE_EVENTS), + f"Mismatch between profiled events: {set(remote_events_list)} and expected events: {set(EXPECTED_REMOTE_EVENTS)}", + ) + + @dist_init + def test_profiler_remote_events_profiled(self): + self._run_test_profiler_remote_events_profiled() + + @single_threaded_process_group_agent + @dist_init + def test_profiler_remote_events_profiled_single_threaded(self): + self._run_test_profiler_remote_events_profiled() def run_profiling_workload(self, dst): fut = rpc.rpc_async( @@ -1269,6 +1373,67 @@ def run_profiling_workload(self, dst): ) fut.wait() + def _run_rpc_profiling_async_function(self, device="cpu"): + if self.rank != 1: + return + + dst1 = worker_name((self.rank + 1) % self.world_size) + dst2 = worker_name((self.rank + 2) % self.world_size) + x = torch.ones(2) + y = torch.ones(2) + with torch.autograd.profiler.profile() as prof: + ret = rpc.rpc_async( + dst1, slow_async_add, args=(dst2, x, y, device), timeout=20 + ) + out = ret.wait() + + function_events = prof.function_events + # slow_async_add resulted in an RPC from dst1 -> dst2, so this should be + # recorded. + key_prefix = _build_rpc_profiling_key( + RPCExecMode.ASYNC, slow_async_add.__qualname__, worker_name(self.rank), dst1 + ) + + nested_rpc_key_prefix = _build_rpc_profiling_key( + RPCExecMode.ASYNC, slow_add.__qualname__, dst1, dst2 + ) + expected_key = key_prefix + REMOTE_OP_STR + nested_rpc_key_prefix + remote_events = [event for event in function_events if event.is_remote] + rpc_remote_event = [ + event for event in remote_events if event.name == expected_key + ] + self.assertEqual(1, len(rpc_remote_event)) + rpc_remote_event = rpc_remote_event[0] + self.assertEqual(rpc_remote_event.node_id, (self.rank + 1) % self.world_size) + # slow_async_add's RPC does an add on dst2, which should be reflected as well. + remote_add_key = ( + expected_key + REMOTE_OP_STR + torch.jit._builtins._find_builtin(torch.add) + ) + remote_add_event = [ + event for event in remote_events if event.name == remote_add_key + ] + self.assertEqual(1, len(remote_add_event)) + remote_add_event = remote_add_event[0] + # Validate that node_id is dst2. + self.assertEqual(remote_add_event.node_id, (self.rank + 2) % self.world_size) + + @dist_init + def test_rpc_profiling_async_function(self): + initialize_pg(self.file_init_method, self.rank, self.world_size) + self._run_rpc_profiling_async_function() + if torch.cuda.is_available(): + dist.barrier() + self._run_rpc_profiling_async_function(device="cuda:0") + + @single_threaded_process_group_agent + @dist_init + def test_rpc_profiling_async_function_single_threaded(self): + initialize_pg(self.file_init_method, self.rank, self.world_size) + self._run_rpc_profiling_async_function() + if torch.cuda.is_available(): + dist.barrier() + self._run_rpc_profiling_async_function(device="cuda:0") + @dist_init def test_rpc_profiling_remote_record_function(self): # test that functions run over RPC with record_function show the expected @@ -1339,11 +1504,9 @@ def get_cpu_children(event): ) def validate_profiling_workload(self, dst, prof): - REMOTE_OP_STR = "#remote_op: " def convert_remote_to_local(event_name): - remote_op_key = REMOTE_OP_STR - return event_name[event_name.find(remote_op_key) + len(remote_op_key) :] + return event_name[event_name.find(REMOTE_OP_STR) + len(REMOTE_OP_STR) :] events = prof.function_events remote_events = { @@ -1362,8 +1525,7 @@ def convert_remote_to_local(event_name): RPCExecMode.ASYNC, ) - @dist_init - def test_profiler_with_autograd_context(self): + def _run_test_profiler_with_autograd_context(self): dst = (self.rank + 1) % self.world_size if self.rank == 1: # Cases where we can double wrap messages with profiling information and autograd info. @@ -1381,6 +1543,15 @@ def test_profiler_with_autograd_context(self): self.validate_profiling_workload(dst, prof) + @single_threaded_process_group_agent + @dist_init + def test_profiler_with_autograd_context_single_threaded(self): + self._run_test_profiler_with_autograd_context() + + @dist_init + def test_profiler_with_autograd_context(self): + self._run_test_profiler_with_autograd_context() + def _profiler_test_with_rpc(self, rpc_exec_mode, func, args, use_record_function=False, dst=None): dst = dst if dst is not None else (self.rank + 1) % self.world_size @@ -1424,8 +1595,8 @@ def _profiler_test_with_rpc(self, rpc_exec_mode, func, args, use_record_function scope_event = get_function_event(events, "foo") # Since RPC call is within the scope, its CPU interval should be # contained within foo's interval. - self.assertTrue(scope_event.cpu_interval.start < rpc_event.cpu_interval.start) - self.assertTrue(scope_event.cpu_interval.end > rpc_event.cpu_interval.end) + self.assertLessEqual(scope_event.time_range.start, rpc_event.time_range.start) + self.assertGreaterEqual(scope_event.time_range.end, rpc_event.time_range.end) # the sender, dest worker, function run, and type of RPC should all # be recorded. self_worker_name = worker_name(self.rank) @@ -1438,14 +1609,21 @@ def _profiler_test_with_rpc(self, rpc_exec_mode, func, args, use_record_function rpc_event_idx = next(i for i, event in enumerate(events) if rpc_exec_mode.value in event.name) self.assertLess(foo_event_ix, rpc_event_idx) - @dist_init - def test_profiler_with_sync_rpc_udf(self): + def _run_test_profiler_with_sync_rpc_udf(self): self._profiler_test_with_rpc(RPCExecMode.SYNC, my_sleep_func, args=(1,)) self._profiler_test_with_rpc(RPCExecMode.SYNC, my_sleep_func, args=(1,), use_record_function=True) @dist_init - def test_profiler_with_sync_rpc_builtin(self): + def test_profiler_with_sync_rpc_udf(self): + self._run_test_profiler_with_sync_rpc_udf() + + @single_threaded_process_group_agent + @dist_init + def test_profiler_with_sync_rpc_udf_single_threaded(self): + self._run_test_profiler_with_sync_rpc_udf() + + def _run_test_profiler_with_sync_rpc_builtin(self): self._profiler_test_with_rpc( RPCExecMode.SYNC, torch.mul, args=(torch.ones(1), torch.ones(1)) ) @@ -1455,13 +1633,29 @@ def test_profiler_with_sync_rpc_builtin(self): ) @dist_init - def test_profiler_with_async_rpc_udf(self): + def test_profiler_with_sync_rpc_builtin(self): + self._run_test_profiler_with_sync_rpc_builtin() + + @single_threaded_process_group_agent + @dist_init + def test_profiler_with_sync_rpc_builtin_single_threaded(self): + self._run_test_profiler_with_sync_rpc_builtin() + + def _run_test_profiler_with_async_rpc_udf(self): self._profiler_test_with_rpc(RPCExecMode.ASYNC, my_sleep_func, args=(1,)) self._profiler_test_with_rpc(RPCExecMode.ASYNC, my_sleep_func, args=(1,), use_record_function=True) @dist_init - def test_profiler_with_async_rpc_builtin(self): + def test_profiler_with_async_rpc_udf(self): + self._run_test_profiler_with_async_rpc_udf() + + @single_threaded_process_group_agent + @dist_init + def test_profiler_with_async_rpc_udf_single_threaded(self): + self._run_test_profiler_with_async_rpc_udf() + + def _run_test_profiler_with_async_rpc_builtin(self): self._profiler_test_with_rpc( RPCExecMode.ASYNC, torch.mul, args=(torch.ones(1), torch.ones(1)) ) @@ -1471,7 +1665,15 @@ def test_profiler_with_async_rpc_builtin(self): ) @dist_init - def test_profiler_with_remote_udf(self): + def test_profiler_with_async_rpc_builtin(self): + self._run_test_profiler_with_async_rpc_builtin() + + @single_threaded_process_group_agent + @dist_init + def test_profiler_with_async_rpc_builtin_single_threaded(self): + self._run_test_profiler_with_async_rpc_builtin() + + def _run_test_profiler_with_remote_udf(self): self._profiler_test_with_rpc(RPCExecMode.REMOTE, my_sleep_func, args=(1,)) self._profiler_test_with_rpc( RPCExecMode.REMOTE, my_sleep_func, args=(1,), use_record_function=True @@ -1482,7 +1684,15 @@ def test_profiler_with_remote_udf(self): ) @dist_init - def test_profiler_with_remote_builtin(self): + def test_profiler_with_remote_udf(self): + self._run_test_profiler_with_remote_udf() + + @single_threaded_process_group_agent + @dist_init + def test_profiler_with_remote_udf_single_threaded(self): + self._run_test_profiler_with_remote_udf() + + def _run_test_profiler_with_remote_builtin(self): self._profiler_test_with_rpc( RPCExecMode.REMOTE, torch.mul, args=(torch.ones(1), torch.ones(1)) ) @@ -1499,7 +1709,15 @@ def test_profiler_with_remote_builtin(self): ) @dist_init - def test_profiler_with_script_async_rpc(self): + def test_profiler_with_remote_builtin(self): + self._run_test_profiler_with_remote_builtin() + + @single_threaded_process_group_agent + @dist_init + def test_profiler_with_remote_builtin_single_threaded(self): + self._run_test_profiler_with_remote_builtin() + + def _run_test_profiler_with_script_async_rpc(self): self._profiler_test_with_rpc( RPCExecMode.ASYNC, my_script_func, args=(torch.tensor(1),) ) @@ -1511,7 +1729,15 @@ def test_profiler_with_script_async_rpc(self): ) @dist_init - def test_profiler_with_script_sync_rpc(self): + def test_profiler_with_script_async_rpc(self): + self._run_test_profiler_with_script_async_rpc() + + @single_threaded_process_group_agent + @dist_init + def test_profiler_with_script_async_rpc_single_threaded(self): + self._run_test_profiler_with_script_async_rpc() + + def _run_test_profiler_with_script_sync_rpc(self): self._profiler_test_with_rpc( RPCExecMode.SYNC, my_script_func, args=(torch.tensor(1),) ) @@ -1523,7 +1749,15 @@ def test_profiler_with_script_sync_rpc(self): ) @dist_init - def test_profiler_with_script_remote_rpc(self): + def test_profiler_with_script_sync_rpc(self): + self._run_test_profiler_with_script_sync_rpc() + + @single_threaded_process_group_agent + @dist_init + def test_profiler_with_script_sync_rpc_single_threaded(self): + self._run_test_profiler_with_script_sync_rpc() + + def _run_test_profiler_with_script_remote_rpc(self): self._profiler_test_with_rpc( RPCExecMode.REMOTE, my_script_func, args=(torch.tensor(1),) ) @@ -1538,6 +1772,14 @@ def test_profiler_with_script_remote_rpc(self): RPCExecMode.REMOTE, my_script_func, args=(torch.tensor(1),), dst=self.rank ) + @dist_init + def test_profiler_with_script_remote_rpc(self): + self._run_test_profiler_with_script_remote_rpc() + + @single_threaded_process_group_agent + @dist_init + def test_profiler_with_script_remote_rpc_single_threaded(self): + self._run_test_profiler_with_script_remote_rpc() def _assert_top_level_events(self, process_global_events, expected_top_level_event_names): top_level_event_names = [] @@ -1546,11 +1788,17 @@ def _assert_top_level_events(self, process_global_events, expected_top_level_eve last_end_time = 0 for event in thread_local_events: event_name = event.name - cpu_interval = event.cpu_interval - if cpu_interval.start > last_end_time: + time_range = event.time_range + if time_range.start > last_end_time: top_level_event_names.append(event_name) - last_end_time = cpu_interval.end - self.assertEqual(sorted(top_level_event_names), sorted(expected_top_level_event_names)) + last_end_time = time_range.end + top_level_event_names = sorted(top_level_event_names) + expected_top_level_event_names = sorted(expected_top_level_event_names) + self.assertEqual( + top_level_event_names, + expected_top_level_event_names, + f"Expected events {expected_top_level_event_names}, but got {top_level_event_names}", + ) @dist_init def test_server_process_global_profiler(self): @@ -1573,9 +1821,12 @@ def test_server_process_global_profiler(self): outer_profile_rref.rpc_sync().__exit__(None, None, None) inner_events = rpc.rpc_sync(dst_worker_name, get_events_from_profile, (inner_profile_rref,)) - self._assert_top_level_events(inner_events, ['aten::sub']) + expected_inner_events = ['aten::sub'] + expected_outer_events = expected_inner_events + ['aten::add'] + + self._assert_top_level_events(inner_events, expected_inner_events) outer_events = rpc.rpc_sync(dst_worker_name, get_events_from_profile, (outer_profile_rref,)) - self._assert_top_level_events(outer_events, ['aten::add', 'aten::sub']) + self._assert_top_level_events(outer_events, expected_outer_events) inner_profile_rref.rpc_sync().key_averages() outer_profile_rref.rpc_sync().key_averages() @@ -1742,11 +1993,37 @@ def test_py_function_exception(self): @dist_init def test_py_raise_in_user_func(self): + with captured_output() as (_, err): + # This barrier prevents a race condition where the main thread has + # not entered the context manager when the remote function runs. + initialize_pg(self.file_init_method, self.rank, self.world_size) + dist.barrier() + n = self.rank + 1 + dst_rank = n % self.world_size + fut = rpc.rpc_async(worker_name(dst_rank), raise_func) + with self.assertRaisesRegex(ValueError, expected_err): + fut.wait() + # This barrier prevents a race condition where the main thread exits + # context manager before the remote function has ran. + dist.barrier() + + # Validate that trainers log errors when running functions. + stderr_lines = err.getvalue() + self.assertTrue(expected_err in stderr_lines) + + @dist_init + def test_py_raise_in_user_func_escaped_str(self): n = self.rank + 1 dst_rank = n % self.world_size - fut = rpc.rpc_async(worker_name(dst_rank), raise_func) - with self.assertRaises(ValueError): + fut = rpc.rpc_async(worker_name(dst_rank), raise_func_escape) + try: fut.wait() + except ValueError as e: + msg = str(e) + # Ensure newlines are unescaped to provide a better repr of error. + self.assertEqual(msg, msg.encode("utf-8").decode("unicode_escape")) + else: + self.assertTrue(False, "expected raise_func_escape to raise ValueError.") @dist_init def test_nested_rpc(self): @@ -2233,7 +2510,7 @@ def _test_rref_leak(self, _mock_delete_all_user_and_unforked_owner_rrefs, ignore rpc_backend_options=self.rpc_backend_options, ) - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) # Wait for all init to complete. dist.barrier() @@ -2311,7 +2588,7 @@ def test_rref_context_debug_info(self): # The barrier before the check makes sure that all previous states are # cleared globally, the barrier after ensures that no following states # change gets into the current check. - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) # Check 1: local RRef does not update owners_ map or add a pending user. ################################################# @@ -2375,7 +2652,7 @@ def test_rref_context_debug_info(self): @dist_init def test_disable_gil_profiling(self): - # test that rpc.enable_gil_profilig(false) will result in + # test that rpc.enable_gil_profiling(false) will result in # GIL wait time not being recorded. # GIL profiling should be disabled by default. @@ -2448,7 +2725,7 @@ def test_handle_send_exceptions(self): rpc._set_rpc_timeout(10) # This barrier is needed to ensure that some workers do not exit before # others have been brought up, for non ProcessGroupAgent backends. - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) dist.barrier() if self.rank == 1: dst_rank = (self.rank + 1) % self.world_size @@ -2477,7 +2754,7 @@ def test_deadlock(self): if not dist_initialized: dist.init_process_group( backend="gloo", - init_method=self.init_method, + init_method=self.file_init_method, rank=self.rank, world_size=self.world_size, ) @@ -2502,7 +2779,7 @@ def test_local_shutdown_with_rpc(self): # A barrier is needed to ensure that all RPCs are processed. # Otherwise, some RPCs can timeout since the receiving end # has terminated. - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) dist.barrier() # pass in graceful=False to ensure that we don't wait for other workers. rpc.shutdown(graceful=False) @@ -2630,6 +2907,58 @@ class TestPickler: torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler ) + @dist_init + def test_wait_all(self): + with _wait_all(): + self.assertTrue(_thread_local_var.future_list == []) + dst = worker_name((self.rank + 1) % self.world_size) + fut = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) + self.assertTrue(len(_thread_local_var.future_list) == 1) + self.assertTrue(isinstance(_thread_local_var.future_list[0], torch._C.Future)) + self.assertTrue(fut.done()) + self.assertEqual(fut.wait(), torch.ones(2, 2) + 1) + self.assertFalse(hasattr(_thread_local_var, "future_list")) + + @dist_init + def test_wait_all_multiple_call(self): + with _wait_all(): + self.assertTrue(_thread_local_var.future_list == []) + dst = worker_name((self.rank + 1) % self.world_size) + for i in range(20): + fut = rpc.rpc_async(dst, torch.add, (torch.ones(i, i), 1)) + res = rpc.rpc_sync(dst, torch.add, (torch.ones(i, i), 1)) + self.assertEqual(res, torch.ones(i, i) + 1) + self.assertEqual(fut.wait(), torch.ones(i, i) + 1) + self.assertTrue(len(_thread_local_var.future_list) == 20) + self.assertFalse(hasattr(_thread_local_var, "future_list")) + + @dist_init + def test_wait_all_timeout(self): + expected_error = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_error): + with _wait_all(): + self.assertTrue(_thread_local_var.future_list == []) + dst = worker_name((self.rank + 1) % self.world_size) + timeout = 0.1 # 100 ms + fut = rpc.rpc_async(dst, my_sleep_func, args=(1,), timeout=timeout) + self.assertFalse(hasattr(_thread_local_var, "future_list")) + + @dist_init + def test_wait_all_raise_in_user_func(self): + with self.assertRaises(ValueError): + with _wait_all(): + self.assertTrue(_thread_local_var.future_list == []) + dst = worker_name((self.rank + 1) % self.world_size) + fut = rpc.rpc_async(dst, raise_func) + self.assertFalse(hasattr(_thread_local_var, "future_list")) + + @dist_init + def test_wait_all_raise_in_body(self): + with self.assertRaises(ValueError): + with _wait_all(): + raise_func() + self.assertFalse(hasattr(_thread_local_var, "future_list")) + @dist_init def test_function_not_on_callee(self): # test that if a function does not exist on a callee, we don't crash, @@ -2940,6 +3269,33 @@ def test_callback_none(self): ): rpc.rpc_async(dst, raise_func).then(None) + @dist_init + def test_add_done_callback(self): + set_by_cb = False + n = self.rank + 1 + + def callback(fut): + nonlocal set_by_cb + fut.wait() + set_by_cb = True + + fut = rpc.rpc_async( + worker_name(n % self.world_size), + torch.add, + args=(torch.ones(n, n), torch.ones(n, n)) + ) + + fut.add_done_callback(callback) + fut_then = fut.then(lambda _: True) + + self.assertEqual(fut.wait(), torch.ones(n, n) * 2) + + # We have no guarantee that the add_done_callback fn will execute before the test finishes. + # Adding a 'then' callback that runs afterwards to guarantee we wait for the first callback + fut_then.wait() + self.assertTrue(set_by_cb) + self.assertEqual(fut.wait(), torch.ones(n, n) * 2) + @dist_init def test_mark_future_twice(self): fut = rpc.rpc_async( @@ -3249,6 +3605,10 @@ def test_rref_timeout(self): wait_until_owners_and_forks_on_rank(1, 1, rank=1) @dist_init(setup_rpc=False) + @unittest.skipIf( + os.environ.get("RPC_INIT_WITH_TCP", None) == "1", + "init_pg_then_rpc does not work with TCP init, see https://github.com/pytorch/pytorch/issues/41614." + ) def test_init_pg_then_rpc(self): dist.init_process_group( backend="gloo", @@ -3276,6 +3636,10 @@ def test_init_pg_then_rpc(self): rpc.shutdown() @dist_init(setup_rpc=False) + @unittest.skipIf( + os.environ.get("RPC_INIT_WITH_TCP", None) == "1", + "init_rpc_then_pg does not work with TCP init, see https://github.com/pytorch/pytorch/issues/41614." + ) def test_init_rpc_then_pg(self): rpc.init_rpc( name=worker_name(self.rank), @@ -3325,8 +3689,12 @@ def test_wait_all_with_partial_exception(self): ret = torch.futures.wait_all(futs) @dist_init(setup_rpc=False) + @unittest.skipIf( + os.environ.get("RPC_INIT_WITH_TCP", None) == "1", + "Test does not work with TCP init, see https://github.com/pytorch/pytorch/issues/46491", + ) def test_init_rpc_twice(self): - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) rpc.init_rpc( name=worker_name(self.rank), @@ -3395,6 +3763,67 @@ def test_cannot_infer_backend_from_options(self): rpc_backend_options=rpc_backend_options, ) + @dist_init + def test_owner_rref_backward(self): + dst = worker_name((self.rank + 1) % self.world_size) + t1 = torch.rand(10, 10, requires_grad=True) + rref = rpc.RRef(t1.sum() + t1.sum()) + rref.backward() + expected_grad = torch.ones_like(t1) * 2 + self.assertEqual(expected_grad, t1.grad) + + with dist_autograd.context() as context_id: + t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1)) + rref = rpc.RRef(t2.sum()) + rref.backward(context_id) + self.assertEqual(expected_grad, dist_autograd.get_gradients(context_id)[t1]) + + # Double backward. + with dist_autograd.context() as context_id: + t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1)) + rref = rpc.RRef(t2.sum()) + rref.backward(context_id, retain_graph=True) + rref.backward(context_id) + self.assertEqual(expected_grad * 2, dist_autograd.get_gradients(context_id)[t1]) + + # Test errors. + with self.assertRaisesRegex(RuntimeError, "tensors does not require grad and does not have a grad_fn"): + rpc.RRef(torch.rand(10)).backward() + + with self.assertRaisesRegex(RuntimeError, "grad can be implicitly created only for scalar outputs"): + rpc.RRef(torch.rand(10, requires_grad=True)).backward() + + with self.assertRaisesRegex(RuntimeError, "Could not find autograd context with id: 100"): + rpc.RRef(torch.rand(10, requires_grad=True).sum()).backward(100) + + with self.assertRaisesRegex(RuntimeError, "RRef should contain a tensor for .backward()"): + rpc.RRef("foo").backward() + + @staticmethod + def _sum(x): + return x.sum() + + @staticmethod + def _identity(x): + return x + + @dist_init + def test_user_rref_backward(self): + dst = worker_name((self.rank + 1) % self.world_size) + t = torch.rand(10, requires_grad=True) + with dist_autograd.context() as context_id: + rref = rpc.remote(dst, RpcTest._sum, args=(t,)) + rref.backward(context_id, retain_graph=True) + rref.backward(context_id) + self.assertEqual(torch.ones_like(t) * 2, dist_autograd.get_gradients(context_id)[t]) + + with dist_autograd.context() as context_id: + rref = rpc.remote(dst, RpcTest._identity, args=("foo",)) + with self.assertRaisesRegex(RuntimeError, "RRef should contain a tensor for .backward()"): + rref.backward(context_id) + + with self.assertRaisesRegex(RuntimeError, "User RRefs require 'dist_autograd_ctx_id' to be specified"): + rref.backward() class ProcessGroupAgentRpcTest(RpcAgentTestFixture): @@ -3483,12 +3912,7 @@ def test_cuda(self): def test_single_threaded_rref_owner(self): # We need a process group in order to perform a barrier at the end. - dist.init_process_group( - backend="gloo", - init_method=self.init_method, - rank=self.rank, - world_size=self.world_size, - ) + initialize_pg(self.file_init_method, self.rank, self.world_size) # This test aims to verify if the server can handle all internal RPC # messages using just one thread. @@ -3554,12 +3978,7 @@ def test_single_threaded_rref_owner(self): def test_single_threaded_rref_to_here(self): # We need a process group in order to perform a barrier at the end. - dist.init_process_group( - backend="gloo", - init_method=self.init_method, - rank=self.rank, - world_size=self.world_size, - ) + initialize_pg(self.file_init_method, self.rank, self.world_size) # This test aims to verify if the server can handle all internal RPC # messages using just one thread. @@ -3618,7 +4037,7 @@ def test_single_threaded_rref_to_here(self): @dist_init def test_process_group_debug_info(self): rpc.enable_gil_profiling(True) - initialize_pg(self.init_method, self.rank, self.world_size) + initialize_pg(self.file_init_method, self.rank, self.world_size) NUM_THREAD = self.rpc_backend_options.num_send_recv_threads info = rpc.api._get_current_rpc_agent().get_debug_info() @@ -4203,6 +4622,7 @@ def _gpu_add(x, y): else: raise ValueError("Wrong device affinity") + @skip_if_no_peer_access @skip_if_lt_x_gpu(2) def test_device_maps_gpu(self): options = self.rpc_backend_options @@ -4226,17 +4646,270 @@ def test_device_maps_gpu(self): self.assertEqual(ret, (torch.zeros(2) + torch.ones(2)).to(1)) rpc.shutdown() + @staticmethod + def _gpu_add_given_gpus(x, y, x_to, y_to, z_to): + if all([ + x.is_cuda, + x.device.index == x_to, + y.is_cuda, + y.device.index == y_to + ]): + return x.to(z_to) + y.to(z_to) + else: + raise ValueError("Wrong device affinity") + + def _test_device_maps_gpu(self, x_from, y_from, z_to, device_map, dst=None): + x_to = device_map[x_from] + y_to = device_map[y_from] + + options = self.rpc_backend_options + dst = worker_name((self.rank + 1) % self.world_size) if dst is None else dst + options.set_device_map(dst, device_map) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + x = torch.zeros(2).to(x_from) + y = torch.ones(2).to(y_from) + + ret = rpc.rpc_sync( + dst, + TensorPipeAgentRpcTest._gpu_add_given_gpus, + args=(x, y, x_to, y_to, z_to) + ) + + reverse_device_map = {device_map[k] : k for k in device_map} + z_from = reverse_device_map[z_to] + + self.assertEqual(ret.device.index, z_from) + self.assertEqual(ret, torch.ones(2).to(z_from)) + + rpc.shutdown() + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_default(self): + self._test_device_maps_gpu( + x_from=0, + y_from=0, + z_to=0, + device_map={0 : 0} + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_non_default(self): + self._test_device_maps_gpu( + x_from=1, + y_from=1, + z_to=1, + device_map={1 : 1} + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_default_to_non_default(self): + self._test_device_maps_gpu( + x_from=0, + y_from=0, + z_to=1, + device_map={0 : 1} + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_non_default_to_default(self): + self._test_device_maps_gpu( + x_from=1, + y_from=1, + z_to=0, + device_map={1 : 0} + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_1(self): + self._test_device_maps_gpu( + x_from=0, + y_from=1, + z_to=0, + device_map={0 : 0, 1 : 1} + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_2(self): + self._test_device_maps_gpu( + x_from=0, + y_from=1, + z_to=1, + device_map={0 : 0, 1 : 1} + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_3(self): + self._test_device_maps_gpu( + x_from=1, + y_from=0, + z_to=0, + device_map={0 : 0, 1 : 1} + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_4(self): + self._test_device_maps_gpu( + x_from=1, + y_from=0, + z_to=1, + device_map={0 : 0, 1 : 1} + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_5(self): + self._test_device_maps_gpu( + x_from=0, + y_from=1, + z_to=0, + device_map={0 : 1, 1 : 0} + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_6(self): + self._test_device_maps_gpu( + x_from=0, + y_from=1, + z_to=1, + device_map={0 : 1, 1 : 0} + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_7(self): + self._test_device_maps_gpu( + x_from=1, + y_from=0, + z_to=0, + device_map={0 : 1, 1 : 0} + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_8(self): + self._test_device_maps_gpu( + x_from=1, + y_from=0, + z_to=1, + device_map={0 : 1, 1 : 0} + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_self_1(self): + self._test_device_maps_gpu( + x_from=0, + y_from=1, + z_to=0, + device_map={0 : 0, 1 : 1}, + dst=worker_name(self.rank) + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_self_2(self): + self._test_device_maps_gpu( + x_from=0, + y_from=1, + z_to=1, + device_map={0 : 0, 1 : 1}, + dst=worker_name(self.rank) + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_self_3(self): + self._test_device_maps_gpu( + x_from=1, + y_from=0, + z_to=0, + device_map={0 : 0, 1 : 1}, + dst=worker_name(self.rank) + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_self_4(self): + self._test_device_maps_gpu( + x_from=1, + y_from=0, + z_to=1, + device_map={0 : 0, 1 : 1}, + dst=worker_name(self.rank) + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_self_5(self): + self._test_device_maps_gpu( + x_from=0, + y_from=1, + z_to=0, + device_map={0 : 1, 1 : 0}, + dst=worker_name(self.rank) + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_self_6(self): + self._test_device_maps_gpu( + x_from=0, + y_from=1, + z_to=1, + device_map={0 : 1, 1 : 0}, + dst=worker_name(self.rank) + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_self_7(self): + self._test_device_maps_gpu( + x_from=1, + y_from=0, + z_to=0, + device_map={0 : 1, 1 : 0}, + dst=worker_name(self.rank) + ) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_device_map_gpu_mixed_self_8(self): + self._test_device_maps_gpu( + x_from=1, + y_from=0, + z_to=1, + device_map={0 : 1, 1 : 0}, + dst=worker_name(self.rank) + ) + @staticmethod def _gpu_add_multi_gpu(x, y): - if all([x.is_cuda, x.device.index == 0, y.is_cuda, y.device.index == 1]): - return x + y.to(0), x.to(1) - y + if all([x.is_cuda, x.device.index == 1, y.is_cuda, y.device.index == 0]): + return x.to(0) + y, x - y.to(1) else: raise ValueError("Wrong device affinity") def _test_device_maps_multi_gpu(self, dst): options = self.rpc_backend_options - options.set_device_map(dst, {1: 0}) options.set_device_map(dst, {0: 1}) + options.set_device_map(dst, {1: 0}) rpc.init_rpc( name=worker_name(self.rank), @@ -4246,22 +4919,27 @@ def _test_device_maps_multi_gpu(self, dst): rpc_backend_options=options, ) + x = torch.zeros(2).to(0) + y = torch.ones(2).to(1) rets = rpc.rpc_sync( dst, TensorPipeAgentRpcTest._gpu_add_multi_gpu, - args=(torch.zeros(2).to(1), torch.ones(2).to(0)) + args=(x, y) ) + self.assertEqual(rets[0].device, torch.device(1)) self.assertEqual(rets[1].device, torch.device(0)) self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(1)) self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0)) rpc.shutdown() + @skip_if_no_peer_access @skip_if_lt_x_gpu(2) def test_device_maps_multi_gpu(self): dst = worker_name((self.rank + 1) % self.world_size) self._test_device_maps_multi_gpu(dst) + @skip_if_no_peer_access @skip_if_lt_x_gpu(2) def test_device_maps_multi_gpu_self(self): dst = worker_name(self.rank) @@ -4274,6 +4952,7 @@ def _gpu_add_return_to_gpu(x, y): else: raise ValueError("Wrong device affinity") + @skip_if_no_peer_access @skip_if_lt_x_gpu(2) def test_device_maps_in_options(self): dst = worker_name((self.rank + 1) % self.world_size) @@ -4294,7 +4973,7 @@ def test_device_maps_in_options(self): rets = rpc.rpc_sync( dst, TensorPipeAgentRpcTest._gpu_add_multi_gpu, - args=(torch.zeros(2).to(1), torch.ones(2).to(0)) + args=(torch.zeros(2).to(0), torch.ones(2).to(1)) ) self.assertEqual(rets[0].device, torch.device(1)) self.assertEqual(rets[1].device, torch.device(0)) @@ -4331,11 +5010,13 @@ def _test_device_maps_return_to_gpu(self, dst): self.assertEqual(rets[3], (torch.zeros(2) / torch.ones(2)).to(2)) rpc.shutdown() + @skip_if_no_peer_access @skip_if_lt_x_gpu(4) def test_device_maps_return_to_gpu(self): dst = worker_name((self.rank + 1) % self.world_size) self._test_device_maps_return_to_gpu(dst) + @skip_if_no_peer_access @skip_if_lt_x_gpu(4) def test_device_maps_return_to_gpu_self(self): dst = worker_name(self.rank) @@ -4348,7 +5029,7 @@ def _add_to_gpu(x, y): def _test_device_maps_missing_config(self, mode): dst = worker_name((self.rank + 1) % self.world_size) errMsg = ( - "TensorPipeAgent only supports CPU tensors by default.*" + "TensorPipe RPC backend only supports CPU tensors by default.*" "`set_device_map` on `TensorPipeRpcBackendOptions`" ) @@ -4388,38 +5069,68 @@ def _test_device_maps_missing_config_response(self, mode): ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1)) self.assertEqual(ret, torch.ones(2) + 1) + @skip_if_no_peer_access @skip_if_lt_x_gpu(1) @dist_init def test_device_maps_missing_config(self): self._test_device_maps_missing_config(RPCExecMode.SYNC) + @skip_if_no_peer_access + @skip_if_lt_x_gpu(1) + def test_device_maps_missing_config_not_timeout(self): + dst = worker_name((self.rank + 1) % self.world_size) + options = self.rpc_backend_options + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options + ) + + timeout = rpc.get_rpc_timeout() + + tik = time.time() + self._test_device_maps_missing_config(RPCExecMode.SYNC) + rpc.shutdown() + tok = time.time() + + self.assertTrue(tok - tik < timeout) + + @skip_if_no_peer_access @skip_if_lt_x_gpu(1) @dist_init def test_device_maps_missing_config_loop(self): for _ in range(self.rpc_backend_options.num_worker_threads + 5): self._test_device_maps_missing_config(RPCExecMode.SYNC) + @skip_if_no_peer_access @skip_if_lt_x_gpu(1) @dist_init def test_device_maps_missing_config_response(self): self._test_device_maps_missing_config_response(RPCExecMode.SYNC) + @skip_if_no_peer_access @skip_if_lt_x_gpu(1) @dist_init def test_device_maps_missing_config_response_loop(self): for _ in range(self.rpc_backend_options.num_worker_threads + 5): self._test_device_maps_missing_config_response(RPCExecMode.SYNC) + @skip_if_no_peer_access @skip_if_lt_x_gpu(1) @dist_init def test_device_maps_missing_config_remote(self): self._test_device_maps_missing_config(RPCExecMode.REMOTE) + @skip_if_no_peer_access @skip_if_lt_x_gpu(1) @dist_init def test_device_maps_missing_config_remote_response(self): self._test_device_maps_missing_config_response(RPCExecMode.REMOTE) + @skip_if_no_peer_access @skip_if_lt_x_gpu(2) def test_device_maps_remote(self): options = self.rpc_backend_options @@ -4440,6 +5151,190 @@ def test_device_maps_remote(self): args=(torch.zeros(2), 1) ) + self.assertEqual(rref.to_here().device.index, 1) self.assertEqual(rref.to_here(), torch.ones(2).to(1)) rpc.shutdown() + + @staticmethod + def _slow_add_on_user_stream(x, y): + s0 = torch.cuda.current_stream(x.device) + s1 = torch.cuda.Stream(device=x.device) + with torch.cuda.stream(s1): + torch.cuda._sleep(10 * FIFTY_MIL_CYCLES) + s1.wait_stream(s0) + z = x + y + event = torch.cuda.Event() + event.record(s1) + event.wait(s0) + return z + + def _test_custom_stream(self, fn, device_map): + options = self.rpc_backend_options + dst = worker_name((self.rank + 1) % self.world_size) + options.set_device_map(dst, device_map) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + fn(dst) + + rpc.shutdown() + + def _test_stream_sync(self, dst): + x = torch.ones(2, 2).to(0) + ret = rpc.rpc_sync( + dst, + TensorPipeAgentRpcTest._slow_add_on_user_stream, + args=(x, x) + ) + self.assertEqual(ret, 2 * x) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_custom_stream(self): + self._test_custom_stream(self._test_stream_sync, {"cuda:0": "cuda:1"}) + + def _test_stream_multi_async(self, dst): + futs = [] + for i in range(20): + x = torch.ones(2, 2).to(0) * i + futs.append( + rpc.rpc_async( + dst, + TensorPipeAgentRpcTest._slow_add_on_user_stream, + args=(x, x) + ) + ) + + for i in range(20): + self.assertEqual(futs[i].wait(), 2 * torch.ones(2, 2).to(0) * i) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_custom_stream_multi(self): + self._test_custom_stream( + self._test_stream_multi_async, + {"cuda:0": "cuda:1"} + ) + + @staticmethod + def _nested_slow_add_on_user_stream(dst, x, y, z): + ret = rpc.rpc_sync( + dst, + TensorPipeAgentRpcTest._slow_add_on_user_stream, + args=(x, y) + ) + + return TensorPipeAgentRpcTest._slow_add_on_user_stream(ret, z) + + def _test_stream_nested_sync(self, dst): + x = torch.ones(2, 2).to(0) + y = torch.ones(2, 2).to(0) * 2 + z = torch.ones(2, 2).to(0) * 3 + nested_dst = worker_name((self.rank + 2) % self.world_size) + ret = rpc.rpc_sync( + dst, + TensorPipeAgentRpcTest._nested_slow_add_on_user_stream, + args=(nested_dst, x, y, z) + ) + self.assertEqual(ret, 6 * x) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_custom_stream_nested(self): + self._test_custom_stream( + self._test_stream_nested_sync, + {"cuda:0": "cuda:1", "cuda:1": "cuda:0"} + ) + + def _test_stream_nested_multi_async(self, dst): + if self.rank == 0: + futs = [] + n = 5 + xs, ys, zs = [], [], [] + for i in range(n): + x = torch.ones(2, 2).to(0) * (i - 1) + y = torch.ones(2, 2).to(0) * i + z = torch.ones(2, 2).to(0) * (i + 1) + xs.append(x) + ys.append(y) + zs.append(z) + nested_dst = worker_name((self.rank + 2) % self.world_size) + futs.append( + rpc.rpc_async( + dst, + TensorPipeAgentRpcTest._nested_slow_add_on_user_stream, + args=(nested_dst, x, y, z) + ) + ) + + for i in range(n): + self.assertEqual(futs[i].wait(), xs[i] + ys[i] + zs[i]) + + @skip_if_no_peer_access + @skip_if_lt_x_gpu(2) + def test_custom_stream_nested_multi(self): + self._test_custom_stream( + self._test_stream_nested_multi_async, + {"cuda:0": "cuda:1", "cuda:1": "cuda:0"} + ) + + @dist_init + def test_rref_get_type_timeout(self): + # Test where we try to get the type of a RRef from an owner, but RRef + # creation is slower than timeout passed into _get_type. + dst_rank = (self.rank + 1) % self.world_size + dst = worker_name(dst_rank) + slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True)) + timeout = 0.5 + expected_err = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_err): + slow_rref._get_type(timeout=timeout) + + @dist_init + def test_op_with_invalid_args(self): + dst = worker_name((self.rank + 1) % self.world_size) + with self.assertRaisesRegex( + RuntimeError, "Overloaded torch operator invoked from Python failed to many any schema" + ): + rpc.rpc_sync(dst, torch.add, args=()) + + def _test_rref_proxy_timeout(self, rref_proxy_api): + dst_rank = (self.rank + 1) % self.world_size + dst = worker_name(dst_rank) + rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), )) + # Ensure RRef is created on remote node. + rref.to_here() + rref_api = getattr(rref, rref_proxy_api) + self.assertTrue(rref_api is not None, f"Failed to get RRef proxy api: {rref_proxy_api}") + expected_error = self.get_timeout_error_regex() + timeout = 2 + with self.assertRaisesRegex(RuntimeError, expected_error): + result = rref_api(timeout=timeout).my_slow_method(torch.ones(2, 2)) + if rref_api == rref.rpc_async: + result.wait() + elif rref_api == rref.remote: + result._get_future().wait() + + # Case where rpc.remote() is stuck and exceeds timeout + slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True)) + timeout = 0.01 + rref_api = getattr(slow_rref, rref_proxy_api) + # Note that even when we call rref.rpc_async() in this case, we + # time out in future creation, not waiting for future. This is because + # rref proxy function calls rref._get_type before returning future, + # which blocks on the RRef being created on owner node, until the + # specified timeout. + with self.assertRaisesRegex(RuntimeError, expected_error): + rref_api(timeout=timeout).my_instance_method(torch.ones(2, 2)) + + @dist_init + def test_rref_proxy_timeout(self): + for rpc_api in ["rpc_sync", "rpc_async", "remote"]: + self._test_rref_proxy_timeout(rpc_api) diff --git a/torch/testing/_internal/distributed/rpc_utils.py b/torch/testing/_internal/distributed/rpc_utils.py index a039947342a6d..d35f3da5d2c2c 100644 --- a/torch/testing/_internal/distributed/rpc_utils.py +++ b/torch/testing/_internal/distributed/rpc_utils.py @@ -1,14 +1,22 @@ #!/usr/bin/env python3 +import os import unittest from enum import Flag, auto from typing import Dict, List, Type from torch.testing._internal.common_distributed import MultiProcessTestCase -from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_TSAN +from torch.testing._internal.common_utils import ( + TEST_WITH_ASAN, + TEST_WITH_TSAN, + find_free_port, +) from torch.testing._internal.distributed.ddp_under_dist_autograd_test import ( DdpComparisonTest, DdpUnderDistAutogradTest, ) +from torch.testing._internal.distributed.pipe_with_ddp_test import ( + PipeWithDDPTest, +) from torch.testing._internal.distributed.nn.api.remote_module_test import ( RemoteModuleTest, ) @@ -35,7 +43,22 @@ RpcTest, TensorPipeAgentRpcTest, ) - +from torch.testing._internal.distributed.rpc.examples.parameter_server_test import ParameterServerTest + +def _check_and_set_tcp_init(): + # if we are running with TCP init, set main address and port + # before spawning subprocesses, since different processes could find + # different ports. + use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None) + if use_tcp_init == "1": + os.environ["MASTER_ADDR"] = '127.0.0.1' + os.environ["MASTER_PORT"] = str(find_free_port()) + +def _check_and_unset_tcp_init(): + use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None) + if use_tcp_init == "1": + del os.environ["MASTER_ADDR"] + del os.environ["MASTER_PORT"] # The tests for the RPC module need to cover multiple possible combinations: # - different aspects of the API, each one having its own suite of tests; @@ -53,8 +76,12 @@ class ForkHelper(MultiProcessTestCase): def setUp(self): super().setUp() + _check_and_set_tcp_init() self._fork_processes() + def tearDown(self): + _check_and_unset_tcp_init() + super().tearDown() @unittest.skipIf( TEST_WITH_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues" @@ -62,8 +89,13 @@ def setUp(self): class SpawnHelper(MultiProcessTestCase): def setUp(self): super().setUp() + _check_and_set_tcp_init() self._spawn_processes() + def tearDown(self): + _check_and_unset_tcp_init() + super().tearDown() + class MultiProcess(Flag): FORK = auto() @@ -83,6 +115,7 @@ class MultiProcess(Flag): # for each agent (except the faulty agent, which is special). GENERIC_TESTS = [ RpcTest, + ParameterServerTest, DistAutogradTest, DistOptimizerTest, JitRpcTest, @@ -90,6 +123,7 @@ class MultiProcess(Flag): RemoteModuleTest, DdpUnderDistAutogradTest, DdpComparisonTest, + PipeWithDDPTest, ] diff --git a/torch/testing/_internal/expecttest.py b/torch/testing/_internal/expecttest.py index 0f8f4619c6554..4dae7ebf03dcc 100644 --- a/torch/testing/_internal/expecttest.py +++ b/torch/testing/_internal/expecttest.py @@ -3,6 +3,50 @@ import traceback import os import string +from typing import Tuple + + +# This file implements expect tests (also known as "golden" tests). +# Expect tests are a method of writing tests where instead of +# hard-coding the expected output of a test, you instead run the test to +# get the output, and the test framework automatically populates the +# expected output. If the output of the test changes, you can rerun the +# test with EXPECTTEST_ACCEPT=1 environment variable to automatically +# update the expected output. +# +# Somewhat unusually, this file implements *inline* expect tests: that +# is to say, the expected output isn't save to an external file, it is +# saved directly in the Python file (and we modify your Python the file +# when updating the expect test.) +# +# The general recipe for how to use this is as follows: +# +# 1. Write your test and use assertExpectedInline() instead of +# a normal assertEqual. Leave the expected argument blank +# with an empty string: +# +# self.assertExpectedInline(some_func(), "") +# +# 2. Run your test. It should fail, and you get an error message +# about accepting the output with EXPECTTEST_ACCEPT=1 +# +# 3. Rerun the test with EXPECTTEST_ACCEPT=1. Now the previously +# blank string literal will now contain the expected value of +# the test. +# +# self.assertExpectedInline(some_func(), "my_value") +# +# Some tips and tricks: +# +# - Often, you will want to expect test on a multiline string. This +# framework understands triple-quoted strings, so you can just +# write """my_value""" and it will turn into triple-quoted +# strings. +# +# - Take some time thinking about how exactly you want to design +# the output format of the expect test. It is often profitable +# to design an output representation specifically for expect tests. +# ACCEPT = os.getenv('EXPECTTEST_ACCEPT') @@ -96,7 +140,8 @@ def ok_for_raw_triple_quoted_string(s, quote): r"(?Pr?)", re.DOTALL) -def replace_string_literal(src, lineno, new_string): +def replace_string_literal(src : str, lineno : int, + new_string : str) -> Tuple[str, int]: r""" Replace a triple quoted string literal with new contents. Only handles printable ASCII correctly at the moment. This @@ -161,6 +206,17 @@ class TestCase(unittest.TestCase): longMessage = True def assertExpectedInline(self, actual, expect, skip=0): + """ + Assert that actual is equal to expect. The expect argument + MUST be a string literal (triple-quoted strings OK), and will + get updated directly in source when you run the test suite + with EXPECTTEST_ACCEPT=1. + + If you want to write a helper function that makes use of + assertExpectedInline (e.g., expect is not a string literal), + set the skip argument to how many function calls we should + skip to find the string literal to update. + """ if ACCEPT: if actual != expect: # current frame and parent frame, plus any requested skip @@ -197,6 +253,11 @@ def assertExpectedInline(self, actual, expect, skip=0): self.assertEqual(expect, actual, msg=help_text) def assertExpectedRaisesInline(self, exc_type, callable, expect, *args, **kwargs): + """ + Like assertExpectedInline, but tests the str() representation of + the raised exception from callable. The raised exeption must + be exc_type. + """ try: callable(*args, **kwargs) except exc_type as e: diff --git a/torch/testing/_internal/jit_metaprogramming_utils.py b/torch/testing/_internal/jit_metaprogramming_utils.py index 0f7370ba23b96..2acc380579e5f 100644 --- a/torch/testing/_internal/jit_metaprogramming_utils.py +++ b/torch/testing/_internal/jit_metaprogramming_utils.py @@ -16,6 +16,8 @@ # Testing utils from torch._six import inf + +# TODO: include files like this should not set the default dtype torch.set_default_dtype(torch.double) L = 20 @@ -138,6 +140,7 @@ ('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),), ('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),), ('pixel_shuffle', (1, 9, 4, 4), (3,),), + ('pixel_unshuffle', (1, 1, 12, 12), (3,),), ('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),), ('pad', (3, 3, 4, 2), ([1, 1],),), ('pairwise_distance', (S, S), ((S, S),),), @@ -227,8 +230,15 @@ def the_method({}): return {} ''' +def value_to_literal(value): + if isinstance(value, str): + # Quotes string and escapes special characters + return ascii(value) + else: + return str(value) + def get_call(method_name, func_type, args, kwargs): - kwargs_str = ', '.join([k + '=' + str(v) for k, v in kwargs.items()]) + kwargs_str = ', '.join([k + '=' + value_to_literal(v) for k, v in kwargs.items()]) self_arg = args[0] if(func_type == 'method'): args = args[1:] @@ -237,7 +247,7 @@ def get_call(method_name, func_type, args, kwargs): argument_str += ', ' if len(args) and len(kwargs) else '' argument_str += kwargs_str - if func_type == 'functional': + if func_type == 'functional' or func_type == 'function': call = 'torch.{}({})'.format(method_name, argument_str) elif func_type == 'method': call = '{}.{}({})'.format(self_arg, method_name, argument_str) @@ -459,6 +469,13 @@ def make_module(script): return module return script_module +def check_alias_annotation(method_name, args, kwargs, *, aten_name, func_type='method'): + formals, tensors, actuals = get_script_args(args) + call = get_call(method_name, func_type, actuals, kwargs) + script = script_template.format(', '.join(formals), call) + CU = torch.jit.CompilationUnit(script) + torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), aten_name) + def get_nn_module_name_from_kwargs(**kwargs): if 'module_name' in kwargs: return kwargs['module_name'] diff --git a/torch/testing/_internal/jit_utils.py b/torch/testing/_internal/jit_utils.py index 732260573ecc5..261a9ce9dfdcc 100644 --- a/torch/testing/_internal/jit_utils.py +++ b/torch/testing/_internal/jit_utils.py @@ -15,16 +15,16 @@ # Testing utils from torch.testing import FileCheck -from torch.testing._internal.common_utils import TestCase, IS_WINDOWS, \ - freeze_rng_state, TemporaryFileName, enable_profiling_mode_for_profiling_tests, ProfilingMode, TEST_BAILOUTS +from torch.testing._internal.common_utils import IS_WINDOWS, \ + freeze_rng_state, enable_profiling_mode_for_profiling_tests, ProfilingMode, TEST_BAILOUTS +from torch.testing._internal.common_jit import JitCommonTestCase from torch.testing._internal.common_utils import enable_profiling_mode # noqa: F401 # Standard library from contextlib import contextmanager from functools import reduce -from itertools import chain from torch._six import StringIO -from typing import Any, Dict +from collections import defaultdict import inspect import io @@ -34,6 +34,7 @@ import sys import tempfile import textwrap +from typing import Any, Dict, List RUN_CUDA = torch.cuda.is_available() RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1 @@ -55,6 +56,7 @@ def do_input_map(fn, input): def clear_class_registry(): torch._C._jit_clear_class_registry() torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore() + torch.jit._state._script_classes.clear() def get_execution_plan(graph_executor_state): execution_plans = list(graph_executor_state.execution_plans.values()) @@ -89,8 +91,9 @@ def __exit__(self, type, value, traceback): return True +FUSION_GROUP = "prim::TensorExprGroup" -class JitTestCase(TestCase): +class JitTestCase(JitCommonTestCase): _do_cuda_memory_leak_check = True _restored_warnings = False @@ -132,6 +135,36 @@ def tearDown(self): self.clearHooks() clear_class_registry() + def assertAllFused(self, graph, except_for=()): + + # note this helper collects nodes on 'fast path' only + # i.e. the true blocks of specialized checks + def get_nodes_and_parents_recursively(block, kind, acc): + for node in block.nodes(): + if node.kind() == kind: + acc[block].append(node) + elif node.kind() == 'prim::DifferentiableGraph': + get_nodes_and_parents_recursively(node.g('Subgraph'), kind, acc) + elif node.kind() == 'prim::If' and (node.inputs().__next__().node().kind() == 'aten::all' or + node.inputs().__next__().node().kind() == 'prim::TypeCheck' or + node.inputs().__next__().node().kind() == 'prim::RequiresGradCheck'): + get_nodes_and_parents_recursively(node.blocks().__next__(), kind, acc) + else: + for inner_block in node.blocks(): + get_nodes_and_parents_recursively(inner_block, kind, acc) + + allowed_nodes = {'prim::Constant', FUSION_GROUP, 'prim::BailoutTemplate', + 'prim::TupleConstruct', 'prim::If', 'prim::TypeCheck', 'prim::RequiresGradCheck'} | set(except_for) + + fusion_groups : Dict[torch._C.Block, List[torch._C.Node]] = defaultdict(list) + get_nodes_and_parents_recursively(graph, FUSION_GROUP, fusion_groups) + self.assertTrue(len(fusion_groups) == 1, 'got {}'.format(graph)) + (graph, fusion_nodes) = list(fusion_groups.items())[0] + # the block contains one FUSION_GROUP and the rest of nodes are `allowed_nodes` + self.assertTrue(len(fusion_nodes) == 1, 'got {}'.format(graph)) + self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()), + 'got {}'.format(graph)) + def _isHookExceptionOk(self, e): se = str(e) allowed = ("Could not export Python function", @@ -150,13 +183,13 @@ def extract_files(buffer): files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist())) # unwrap all the code files into strings code_files_str = filter(lambda x: x.endswith('.py'), files) - code_files_stream = map(lambda f: archive.open(f), code_files_str) - code_files = map(lambda file: "".join([line.decode() for line in file]), code_files_stream) + code_files_stream = (archive.open(f) for f in code_files_str) + code_files = ("".join([line.decode() for line in file]) for file in code_files_stream) # unpickled all the debug files debug_files_str = filter(lambda f: f.endswith('.debug_pkl'), files) - debug_files_stream = map(lambda f: archive.open(f), debug_files_str) - debug_files = map(lambda f: pickle.load(f), debug_files_stream) + debug_files_stream = (archive.open(f) for f in debug_files_str) + debug_files = (pickle.load(f) for f in debug_files_stream) return code_files, debug_files # disable the hook while we parse code, otherwise we will re-enter the hook @@ -213,19 +246,6 @@ def emitModuleHook(self, module): self._compared_saved_loaded(module) - def getExportImportCopy(self, m, also_test_file=True, map_location=None): - buffer = io.BytesIO() - torch.jit.save(m, buffer) - buffer.seek(0) - imported = torch.jit.load(buffer, map_location=map_location) - - if not also_test_file: - return imported - - with TemporaryFileName() as fname: - torch.jit.save(imported, fname) - return torch.jit.load(fname, map_location=map_location) - def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None): buffer = io.BytesIO() m.apply(lambda s: s._pack() if s._c._has_method('_pack') else None) @@ -301,22 +321,6 @@ def assertExpectedGraph(self, trace, *args, **kwargs): torch._C._jit_pass_lint(graph) self.assertExpected(str(graph), *args, **kwargs) - def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes): - diff_nodes = graph.findAllNodes('prim::DifferentiableGraph') - diff_subgraphs = [node.g('Subgraph') for node in diff_nodes] - - # For any non-fusible node, it must show up in one of the DifferentiableGraph. - found_all_nonfusible_nodes = (len(diff_subgraphs) == 0 and len(nonfusible_nodes) == 0)\ - or all([any(g.findNode(n) is not None for g in diff_subgraphs) for n in nonfusible_nodes]) - - # For any fusible node, it must show up in one of the FusionGroup in the DifferentiableGraph. - fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs])) - fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes] - found_all_fusible_nodes = (len(fusion_nodes) == 0 and len(fusible_nodes) == 0)\ - or all([any(g.findNode(n) is not None for g in fusion_subgraphs) for n in fusible_nodes]) - - self.assertEqual(should_autodiff_node, found_all_nonfusible_nodes and found_all_fusible_nodes) - def run_pass(self, name, trace): if isinstance(trace, torch._C.Graph): graph = trace @@ -568,26 +572,6 @@ def input_reduce(input, fn, acc): return ge - def createFunctionFromGraph(self, trace): - graph = trace if isinstance(trace, torch._C.Graph) else trace.graph() - return torch._C._create_function_from_graph("forward", graph) - - def assertExportImport(self, trace, inputs): - m = self.createFunctionFromGraph(trace) - self.assertExportImportModule(m, inputs) - - def assertExportImportModule(self, m, inputs): - m_import = self.getExportImportCopy(m) - a = self.runAndSaveRNG(m, inputs) - b = self.runAndSaveRNG(m_import, inputs) - self.assertEqual(a, b) - - def runAndSaveRNG(self, func, inputs, kwargs=None): - kwargs = kwargs if kwargs else {} - with freeze_rng_state(): - results = func(*inputs, **kwargs) - return results - def checkModule(self, nn_module, args): """ Check that a nn.Module's results in Script mode match eager and that it @@ -615,6 +599,14 @@ def inline_everything_mode(should_inline): finally: torch._C._jit_set_inline_everything_mode(old) +@contextmanager +def set_fusion_group_inlining(inlining): + old = torch._C._debug_get_fusion_group_inlining() + torch._C._debug_set_fusion_group_inlining(inlining) + try: + yield + finally: + torch._C._debug_set_fusion_group_inlining(old) # note: not re-entrant, use unnested only @contextmanager diff --git a/torch/testing/_internal/test_module/future_div.py b/torch/testing/_internal/test_module/future_div.py index 2cdbf9ca129bd..3f042188490c0 100644 --- a/torch/testing/_internal/test_module/future_div.py +++ b/torch/testing/_internal/test_module/future_div.py @@ -1,5 +1,4 @@ from __future__ import division -import torch # noqa: F401 def div_int_future(): diff --git a/torch/testing/check_kernel_launches.py b/torch/testing/check_kernel_launches.py new file mode 100644 index 0000000000000..c274316b54fe3 --- /dev/null +++ b/torch/testing/check_kernel_launches.py @@ -0,0 +1,119 @@ +import os +import re +import sys + + +# Regular expression identifies a kernel launch indicator by +# finding something approximating the pattern ">>>(arguments);" +# It then requires that `C10_CUDA_KERNEL_LAUNCH_CHECK` be +# the next command. +# It allows a single backslash `\` between the end of the launch +# command and the beginning of the kernel check. This handles +# cases where the kernel launch is in a multiline preprocessor +# definition. +# +# There are various ways this can fail: +# * If the semicolon is in a string for some reason +# * If there's a triply-nested template +# But this should be sufficient to detect and fix most problem +# instances and can be refined before the test is made binding +kernel_launch_regex = re.compile(r""" + ^.*>>> # Identifies kernel launch + \s* # Maybe some whitespace (includes newlines) + \([^;]+\); # And then arguments in parens and semi-colon + (?! # Negative lookahead: we trigger if we don't find the launch guard + \s* # Maybe some whitespace (includes newlines) + \\? # 0 or 1 backslashes (for launches in preprocessor macros) + \s* # Maybe some whitespace (includes newlines) + (?:[0-9]+: )? # Detects and ignores a line numbering, if present + \s* # Maybe some whitespace (includes newlines) + C10_CUDA_KERNEL_LAUNCH_CHECK\(\); # Kernel launch guard! + ) # End negative lookahead +""", flags=re.MULTILINE | re.VERBOSE) + + +def check_code_for_cuda_kernel_launches(code, filename=None): + """Checks code for CUDA kernel launches without cuda error checks. + + Args: + filename - Filename of file containing the code. Used only for display + purposes, so you can put anything here. + code - The code to check + + Returns: + The number of unsafe kernel launches in the code + """ + if filename is None: + filename = "##Python Function Call##" + + # We break the code apart and put it back together to add + # helpful line numberings for identifying problem areas + code = enumerate(code.split("\n")) # Split by line breaks + code = [f"{lineno}: {linecode}" for lineno, linecode in code] # Number the lines + code = '\n'.join(code) # Put it back together + + results = kernel_launch_regex.findall(code) # Search for bad launches + for r in results: + print(f"Missing C10_CUDA_KERNEL_LAUNCH_CHECK in '{filename}'. Context:\n{r}", file=sys.stderr) + return len(results) + + +def check_file(filename): + """Checks a file for CUDA kernel launches without cuda error checks + + Args: + filename - File to check + + Returns: + The number of unsafe kernel launches in the file + """ + if not (filename.endswith(".cu") or filename.endswith(".cuh")): + return 0 + contents = open(filename, "r").read() + return check_code_for_cuda_kernel_launches(contents, filename) + + +def check_cuda_kernel_launches(): + """Checks all pytorch code for CUDA kernel launches without cuda error checks + + Returns: + The number of unsafe kernel launches in the codebase + """ + torch_dir = os.path.dirname(os.path.realpath(__file__)) + torch_dir = os.path.dirname(torch_dir) # Go up to parent torch + torch_dir = os.path.dirname(torch_dir) # Go up to parent caffe2 + + kernels_without_checks = 0 + files_without_checks = [] + for root, dirnames, filenames in os.walk(torch_dir): + # `$BASE/build` and `$BASE/torch/include` are generated + # so we don't want to flag their contents + if root == os.path.join(torch_dir, "build") or root == os.path.join(torch_dir, "torch/include"): + # Curtail search by modifying dirnames and filenames in place + # Yes, this is the way to do this, see `help(os.walk)` + dirnames[:] = [] + continue + + for x in filenames: + filename = os.path.join(root, x) + file_result = check_file(filename) + if file_result > 0: + kernels_without_checks += file_result + files_without_checks.append(filename) + + if kernels_without_checks > 0: + count_str = f"Found {kernels_without_checks} instances in " \ + f"{len(files_without_checks)} files where kernel " \ + "launches didn't have checks." + print(count_str, file=sys.stderr) + print("Files without checks:", file=sys.stderr) + for x in files_without_checks: + print(f"\t{x}", file=sys.stderr) + print(count_str, file=sys.stderr) + + return kernels_without_checks + + +if __name__ == "__main__": + unsafe_launches = check_cuda_kernel_launches() + sys.exit(0) diff --git a/torch/utils/_benchmark/__init__.py b/torch/utils/_benchmark/__init__.py deleted file mode 100644 index 30a9543e45443..0000000000000 --- a/torch/utils/_benchmark/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from torch.utils._benchmark.utils.common import * -from torch.utils._benchmark.utils.timer import * -from torch.utils._benchmark.utils.compare import * -from torch.utils._benchmark.utils.fuzzer import * diff --git a/torch/utils/_benchmark/utils/common.py b/torch/utils/_benchmark/utils/common.py deleted file mode 100644 index 57000cc148974..0000000000000 --- a/torch/utils/_benchmark/utils/common.py +++ /dev/null @@ -1,268 +0,0 @@ -"""Base shared classes and utilities.""" - -import collections -import contextlib -import logging -from typing import Any, Dict, List, Optional - -import numpy as np -import torch - - -__all__ = ["Measurement"] - - -_MAX_SIGNIFICANT_FIGURES = 4 -_MIN_CONFIDENCE_INTERVAL = 25e-9 # 25 ns - -# Measurement will include a warning if the distribution is suspect. All -# runs are expected to have some variation; these parameters set the -# thresholds. -_IQR_WARN_THRESHOLD = 0.1 -_IQR_GROSS_WARN_THRESHOLD = 0.25 - - -class Measurement: - """The result of a Timer measurement. - - This class stores one or more measurements of a given statement. It is - serializable and provides several convenience methods - (including a detailed __repr__) for downstream consumers. - """ - def __init__( - self, - number_per_run: int, - times: List[float], - num_threads: int, - label: Optional[str], - sub_label: Optional[str], - description: Optional[str], - env: Optional[str], - stmt: Optional[str], - metadata: Optional[dict] = None, - ): - self.number_per_run = number_per_run - self.times = times - self.label = label - self.sub_label = sub_label - self.description = description - self._env = env - self.num_threads = num_threads - self.stmt = stmt - self.metadata = metadata - - # Derived attributes - self._sorted_times = sorted([t / number_per_run for t in times]) - self._median = np.median(self._sorted_times) - self._bottom_quartile = np.percentile(self._sorted_times, 25) - self._top_quartile = np.percentile(self._sorted_times, 75) - self._iqr = self._top_quartile - self._bottom_quartile - self._warnings = self._populate_warnings() - - # Pickle support. - def __getstate__(self): - return { - "label": self.label, - "sub_label": self.sub_label, - "description": self.description, - "env": self._env, - "num_threads": self.num_threads, - "number_per_run": self.number_per_run, - "times": self.times, - "stmt": self.stmt, - "metadata": self.metadata, - } - - def __setstate__(self, state: Dict[str, Any]): - self.__init__(**state) # type: ignore - - def meets_confidence(self, threshold=_IQR_WARN_THRESHOLD): - return self._iqr / self._median < threshold - - def _populate_warnings(self): - warnings, rel_iqr = [], self._iqr / self._median * 100 - - def add_warning(msg): - warnings.append( - f" WARNING: Interquartile range is {rel_iqr:.1f}% " - f"of the median measurement.\n {msg}" - ) - - if self._iqr / self._median > _IQR_GROSS_WARN_THRESHOLD: - add_warning("This suggests significant environmental influence.") - elif not self.meets_confidence(): - add_warning("This could indicate system fluctuation.") - return warnings - - @property - def median(self) -> float: - return self._median - - @property - def significant_figures(self) -> int: - """Approximate significant figure estimate. - - This property is intended to give a convenient way to estimate the - precision of a measurement. It only uses the interquartile region to - estimate statistics to try to mitigate skew from the tails, and - uses a static z value of 1.645 since it is not expected to be used - for small values of `n`, so z can approximate `t`. - - The significant figure estimation used in conjunction with the - `trim_sigfig` method to provide a more human interpretable data - summary. __repr__ does not use this method; it simply displays raw - values. Significant figure estimation is intended for `Compare`. - """ - n_total = len(self._sorted_times) - lower_bound = int(n_total // 4) - upper_bound = int(np.ceil(3 * n_total / 4)) - interquartile_points = self._sorted_times[lower_bound:upper_bound] - std = np.std(interquartile_points) - sqrt_n = np.sqrt(len(interquartile_points)) - - # Rough estimates. These are by no means statistically rigorous. - confidence_interval = max(1.645 * std / sqrt_n, _MIN_CONFIDENCE_INTERVAL) - relative_ci = np.log10(self._median / confidence_interval) - num_significant_figures = int(np.floor(relative_ci)) - return min(max(num_significant_figures, 1), _MAX_SIGNIFICANT_FIGURES) - - @property - def title(self) -> str: - """Best effort attempt at a string label for the measurement.""" - if self.label is not None: - label = self.label - elif isinstance(self.stmt, str): - label = self.stmt - else: - label = "[Missing primary label]" - - return label + (f": {self.sub_label}" if self.sub_label else "") - - @property - def env(self) -> str: - return "Unspecified env" if self._env is None else self._env - - @property - def as_row_name(self) -> str: - return self.sub_label or self.stmt or "[Unknown]" - - @property - def has_warnings(self): - return bool(self._warnings) - - def __repr__(self): - """ - Example repr: - - Broadcasting add (4x8) - Median: 5.73 us - IQR: 2.25 us (4.01 to 6.26) - 372 measurements, 100 runs per measurement, 1 thread - WARNING: Interquartile range is 39.4% of the median measurement. - This suggests significant environmental influence. - """ - repr = [super().__repr__(), "\n", self.title, "\n"] - if self.description: - repr.extend([self.description, "\n"]) - n = len(self._sorted_times) - - time_unit, time_scale = select_unit(self.median) - repr.extend([ - f" {'Median: ' if n > 1 else ''}" - f"{self._median / time_scale:.2f} {time_unit}\n" - ]) - if n >= 4: - repr.extend( - [ - f" IQR: {self._iqr / time_scale:.2f} {time_unit} " - f"({self._bottom_quartile / time_scale:.2f} to " - f"{self._top_quartile / time_scale:.2f})\n", - ] - ) - repr.extend( - [ - f" {len(self.times)} measurement{'s' if n > 1 else ''}, " - f"{self.number_per_run} runs {'per measurement,' if n > 1 else ','} " - f"{self.num_threads} thread{'s' if self.num_threads > 1 else ''}\n" - ] - ) - repr.extend(self._warnings) - - return "".join(repr).strip() - - -def select_unit(t: float): - """Determine how to scale times for O(1) magnitude. - - This utility is used to format numbers for human consumption. - """ - time_unit = {-3: "ns", -2: "us", -1: "ms"}.get(int(np.log10(t) // 3), "s") - time_scale = {"ns": 1e-9, "us": 1e-6, "ms": 1e-3, "s": 1}[time_unit] - return time_unit, time_scale - - -def unit_to_english(u: str) -> str: - return { - "ns": "nanosecond", - "us": "microsecond", - "ms": "millisecond", - "s": "second", - }[u] - - -def trim_sigfig(x: float, n: int) -> float: - """Trim `x` to `n` significant figures. (e.g. 3.14159, 2 -> 3.10000)""" - assert n == int(n) - magnitude = int(np.ceil(np.log10(np.abs(x)))) - scale = 10 ** (magnitude - n) - return np.round(x / scale) * scale - - -def ordered_unique(elements): - return list(collections.OrderedDict({i: None for i in elements}).keys()) - - -def merge_measurements(measurements: List[Measurement]): - grouped_measurements = collections.defaultdict(list) - for m in measurements: - key = (m.label, m.sub_label, m.description, m.env, m.num_threads) - grouped_measurements[key].append(m) - - def merge_group(label, sub_label, description, env, num_threads, group): - times = [] - for m in group: - # Different measurements could have different `number_per_run`. - times.extend([t / m.number_per_run for t in m.times]) - unique_stmts = {m.stmt for m in group} - if len(unique_stmts) != 1: - logging.warning( - "Merged Examples with identical `label`, `sub_label`,\n" - "`description`, `env`, and `num_threads`, but different" - "`stmt`s:\n " + "\n ".join(unique_stmts) - ) - return Measurement( - number_per_run=1, - times=times, - num_threads=num_threads, - label=label, - sub_label=sub_label, - description=description, - env=env, - stmt=unique_stmts.pop(), - metadata=None, - ) - - return [ - merge_group(*(key + (group,))) - for key, group in grouped_measurements.items() - ] - - -@contextlib.contextmanager -def set_torch_threads(n: int): - prior_num_threads = torch.get_num_threads() - try: - torch.set_num_threads(n) - yield - finally: - torch.set_num_threads(prior_num_threads) diff --git a/torch/utils/_benchmark/utils/timer.py b/torch/utils/_benchmark/utils/timer.py deleted file mode 100644 index c78db2740c2fa..0000000000000 --- a/torch/utils/_benchmark/utils/timer.py +++ /dev/null @@ -1,138 +0,0 @@ -"""Timer class based on the timeit.Timer class, but torch aware.""" - -import timeit -from typing import List, Optional - -import numpy as np -import torch -from torch.utils._benchmark.utils import common - - -__all__ = ["Timer"] - - -if torch.has_cuda and torch.cuda.is_available(): - def timer(): - torch.cuda.synchronize() - return timeit.default_timer() -else: - timer = timeit.default_timer - - -class Timer(object): - _timer_cls = timeit.Timer - - def __init__( - self, - stmt="pass", - setup="pass", - timer=timer, - globals: Optional[dict] = None, - label: Optional[str] = None, - sub_label: Optional[str] = None, - description: Optional[str] = None, - env: Optional[str] = None, - num_threads=1, - ): - if not isinstance(stmt, str): - raise ValueError("Currently only a `str` stmt is supported.") - - # We copy `globals` to prevent mutations from leaking, (for instance, - # `eval` adds the `__builtins__` key) and include `torch` if not - # specified as a convenience feature. - globals = dict(globals or {}) - globals.setdefault("torch", torch) - - self._stmt = stmt - self._label = label - self._sub_label = sub_label - self._description = description - self._env = env - self._num_threads = num_threads - self._timer = self._timer_cls(stmt=stmt, setup=setup, timer=timer, globals=globals) - - def _construct_measurement(self, number_per_run: int, times: List[float]): - return common.Measurement( - number_per_run=number_per_run, - times=times, - num_threads=self._num_threads, - label=self._label, - sub_label=self._sub_label, - description=self._description, - env=self._env, - stmt=self._stmt, - ) - - def timeit(self, number=1000000): - # Warmup - self._timer.timeit(number=max(int(number // 100), 1)) - with common.set_torch_threads(self._num_threads): - return self._construct_measurement( - number_per_run=number, times=[self._timer.timeit(number=number)] - ) - - def repeat(self, repeat=-1, number=-1): - raise NotImplementedError("See `Timer.blocked_autorange.`") - - def autorange(self, callback=None): - raise NotImplementedError("See `Timer.blocked_autorange.`") - - def _threaded_measurement_loop(self, number, time_hook, stop_hook, min_run_time: float, - max_run_time: Optional[float] = None, callback=None): - total_time = 0.0 - can_stop = False - times = [] - with common.set_torch_threads(self._num_threads): - while (total_time < min_run_time) or (not can_stop): - time_spent = time_hook() - times.append(time_spent) - total_time += time_spent - if callback: - callback(number, time_spent) - can_stop = stop_hook(times) - if max_run_time and total_time > max_run_time: - break - return times - - def adaptive_autorange(self, threshold=0.1, max_run_time=10, callback=None, min_run_time=0.01): - number = self._estimate_block_size(min_run_time=0.05) - - def time_hook(): - return self._timer.timeit(number) - - def stop_hook(times): - if len(times) > 3: - measure = self._construct_measurement(number, times) - return measure.meets_confidence(threshold=threshold) - return False - times = self._threaded_measurement_loop(number, time_hook, stop_hook, min_run_time, max_run_time, callback=callback) - measure = self._construct_measurement(number, times) - return measure - - def _estimate_block_size(self, min_run_time): - with common.set_torch_threads(self._num_threads): - # Estimate the block size needed for measurement to be negligible - # compared to the inner loop. This also serves as a warmup. - overhead = np.median([self._timer.timeit(0) for _ in range(5)]) - number = 1 - while True: - time_taken = self._timer.timeit(number) - relative_overhead = overhead / time_taken - if relative_overhead <= 1e-4 and time_taken >= min_run_time / 1000: - break - if time_taken > min_run_time: - break - number *= 10 - return number - - def blocked_autorange(self, callback=None, min_run_time=0.2): - number = self._estimate_block_size(min_run_time) - - def time_hook(): - return self._timer.timeit(number) - - def stop_hook(times): - return True - times = self._threaded_measurement_loop(number, time_hook, stop_hook, min_run_time=min_run_time, - callback=callback) - return self._construct_measurement(number_per_run=number, times=times) diff --git a/torch/utils/_cpp_extension_versioner.py b/torch/utils/_cpp_extension_versioner.py index cb778ab8923d7..958d34ecc71af 100644 --- a/torch/utils/_cpp_extension_versioner.py +++ b/torch/utils/_cpp_extension_versioner.py @@ -38,12 +38,16 @@ def bump_version_if_changed(self, source_files, build_arguments, build_directory, - with_cuda): + with_cuda, + is_python_module, + is_standalone): hash_value = 0 hash_value = hash_source_files(hash_value, source_files) hash_value = hash_build_arguments(hash_value, build_arguments) hash_value = update_hash(hash_value, build_directory) hash_value = update_hash(hash_value, with_cuda) + hash_value = update_hash(hash_value, is_python_module) + hash_value = update_hash(hash_value, is_standalone) entry = self.entries.get(name) if entry is None: diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py new file mode 100644 index 0000000000000..86a7c54c4a1e4 --- /dev/null +++ b/torch/utils/_pytree.py @@ -0,0 +1,189 @@ +from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, cast, Optional + +""" +Contains utility functions for working with nested python data structures. + +A *pytree* is Python nested data structure. It is a tree in the sense that +nodes are Python collections (e.g., list, tuple, dict) and the leaves are +Python values. Furthermore, a pytree should not contain reference cycles. + +pytrees are useful for working with nested collections of Tensors. For example, +one can use `tree_map` to map a function over all Tensors inside some nested +collection of Tensors and `tree_unflatten` to get a flat list of all Tensors +inside some nested collection. pytrees are helpful for implementing nested +collection support for PyTorch APIs. + +This pytree implementation is not very performant due to Python overhead +To improve the performance we can move parts of the implementation to C++. +""" + +# A NodeDef holds two callables: +# - flatten_fn should take the collection and return a flat list of values. +# It can also return some context that is used in reconstructing the +# collection. +# - unflatten_fn should take a flat list of values and some context +# (returned by flatten_fn). It returns the collection by reconstructing +# it from the list and the context. +Context = Any +PyTree = Any +FlattenFunc = Callable[[PyTree], Tuple[List, Context]] +UnflattenFunc = Callable[[List, Context], PyTree] + +class NodeDef(NamedTuple): + flatten_fn: FlattenFunc + unflatten_fn: UnflattenFunc + +SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} + +def _register_pytree_node(typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc) -> None: + SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn) + +def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: + return list(d.values()), list(d.keys()) + +def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: + return {key: value for key, value in zip(context, values)} + +def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: + return d, None + +def _list_unflatten(values: List[Any], context: Context) -> List[Any]: + return list(values) + +def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]: + return list(d), None + +def _tuple_unflatten(values: List[Any], context: Context) -> Tuple[Any, ...]: + return tuple(values) + +_register_pytree_node(dict, _dict_flatten, _dict_unflatten) +_register_pytree_node(list, _list_flatten, _list_unflatten) +_register_pytree_node(tuple, _tuple_flatten, _tuple_unflatten) + + +# A leaf is defined as anything that is not a Node. +def _is_leaf(pytree: PyTree) -> bool: + return type(pytree) not in SUPPORTED_NODES.keys() + + +# A TreeSpec represents the structure of a pytree. It holds: +# "type": the type of root Node of the pytree +# context: some context that is useful in unflattening the pytree +# children_specs: specs for each child of the root Node +# num_leaves: the number of leaves +class TreeSpec: + def __init__(self, typ: Any, context: Context, children_specs: List['TreeSpec']) -> None: + self.type = typ + self.context = context + self.children_specs = children_specs + self.num_leaves: int = sum([spec.num_leaves for spec in children_specs]) + + def __repr__(self) -> str: + return f'TreeSpec({self.type.__name__}, {self.context}, {self.children_specs})' + + def __eq__(self, other: Any) -> bool: + result = self.type == other.type and self.context == other.context \ + and self.children_specs == other.children_specs \ + and self.num_leaves == other.num_leaves + # This should really not be necessary, but mypy errors out without it. + return cast(bool, result) + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + +class LeafSpec(TreeSpec): + def __init__(self) -> None: + super().__init__(None, None, []) + self.num_leaves = 1 + + def __repr__(self) -> str: + return '*' + + +def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: + """Flattens a pytree into a list of values and a TreeSpec that can be used + to reconstruct the pytree. + """ + if _is_leaf(pytree): + return [pytree], LeafSpec() + + flatten_fn = SUPPORTED_NODES[type(pytree)].flatten_fn + child_pytrees, context = flatten_fn(pytree) + + # Recursively flatten the children + result : List[Any] = [] + children_specs : List['TreeSpec'] = [] + for child in child_pytrees: + flat, child_spec = tree_flatten(child) + result += flat + children_specs.append(child_spec) + + return result, TreeSpec(type(pytree), context, children_specs) + + +def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree: + """Given a list of values and a TreeSpec, builds a pytree. + This is the inverse operation of `tree_flatten`. + """ + if not isinstance(spec, TreeSpec): + raise ValueError( + f'tree_unflatten(values, spec): Expected `spec` to be instance of ' + f'TreeSpec but got item of type {type(spec)}.') + if len(values) != spec.num_leaves: + raise ValueError( + f'tree_unflatten(values, spec): `values` has length {len(values)} ' + f'but the spec refers to a pytree that holds {spec.num_leaves} ' + f'items ({spec}).') + if isinstance(spec, LeafSpec): + return values[0] + + unflatten_fn = SUPPORTED_NODES[spec.type].unflatten_fn + + # Recursively unflatten the children + start = 0 + end = 0 + child_pytrees = [] + for child_spec in spec.children_specs: + end += child_spec.num_leaves + child_pytrees.append(tree_unflatten(values[start:end], child_spec)) + start = end + + return unflatten_fn(child_pytrees, spec.context) + + +# Broadcasts a pytree to the provided TreeSpec and returns the flattened +# values. If this is not possible, then this function returns None. +# +# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]), +# would return [0, 0]. This is useful for part of the vmap implementation: +# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be +# broadcastable to the tree structure of `inputs` and we use +# _broadcast_to_and_flatten to check this. +def _broadcast_to_and_flatten(pytree: PyTree, spec: TreeSpec) -> Optional[List[Any]]: + assert isinstance(spec, TreeSpec) + + if _is_leaf(pytree): + return [pytree] * spec.num_leaves + if isinstance(spec, LeafSpec): + return None + if type(pytree) != spec.type: + return None + + flatten_fn = SUPPORTED_NODES[type(pytree)].flatten_fn + child_pytrees, ctx = flatten_fn(pytree) + + # Check if the Node is different from the spec + if len(child_pytrees) != len(spec.children_specs) or ctx != spec.context: + return None + + # Recursively flatten the children + result : List[Any] = [] + for child, child_spec in zip(child_pytrees, spec.children_specs): + flat = _broadcast_to_and_flatten(child, child_spec) + if flat is not None: + result += flat + else: + return None + + return result diff --git a/torch/utils/_benchmark/README.md b/torch/utils/benchmark/README.md similarity index 98% rename from torch/utils/_benchmark/README.md rename to torch/utils/benchmark/README.md index e432a553e6ba6..4a64b778181f8 100644 --- a/torch/utils/_benchmark/README.md +++ b/torch/utils/benchmark/README.md @@ -15,8 +15,8 @@ into two broad categories: ### Integration and better measurement: - `Timer`, while modeled after the `timit` analog, uses a slightly different - API from `timit.Timer`. + `Timer`, while modeled after the `timeit` analog, uses a slightly different + API from `timeit.Timer`. * The constructor accepts additional metadata and timing methods return a `Measurement` class rather than a float. This `Measurement` class is diff --git a/torch/utils/benchmark/__init__.py b/torch/utils/benchmark/__init__.py new file mode 100644 index 0000000000000..e018eef3d39f3 --- /dev/null +++ b/torch/utils/benchmark/__init__.py @@ -0,0 +1,5 @@ +from torch.utils.benchmark.utils.common import * +from torch.utils.benchmark.utils.timer import * +from torch.utils.benchmark.utils.compare import * +from torch.utils.benchmark.utils.fuzzer import * +from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import * diff --git a/torch/utils/benchmark/examples/__init__.py b/torch/utils/benchmark/examples/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/utils/benchmark/examples/blas_compare.py b/torch/utils/benchmark/examples/blas_compare.py new file mode 100644 index 0000000000000..68c7c15a02b13 --- /dev/null +++ b/torch/utils/benchmark/examples/blas_compare.py @@ -0,0 +1,230 @@ +import argparse +import datetime +import itertools as it +import multiprocessing +import multiprocessing.dummy +import os +import queue +import pickle +import shutil +import subprocess +import sys +import tempfile +import threading +import time + +import blas_compare_setup + + +MIN_RUN_TIME = 1 +NUM_REPLICATES = 20 +NUM_THREAD_SETTINGS = (1, 2, 4) +RESULT_FILE = os.path.join(blas_compare_setup.WORKING_ROOT, "blas_results.pkl") +SCRATCH_DIR = os.path.join(blas_compare_setup.WORKING_ROOT, "scratch") + + +BLAS_CONFIGS = ( + ("MKL (2020.3)", blas_compare_setup.MKL_2020_3, None), + ("MKL (2020.0)", blas_compare_setup.MKL_2020_0, None), + ("OpenBLAS", blas_compare_setup.OPEN_BLAS, None) +) + + +_RESULT_FILE_LOCK = threading.Lock() +_WORKER_POOL = queue.Queue() +def clear_worker_pool(): + while not _WORKER_POOL.empty(): + _, result_file, _ = _WORKER_POOL.get_nowait() + os.remove(result_file) + + if os.path.exists(SCRATCH_DIR): + shutil.rmtree(SCRATCH_DIR) + + +def fill_core_pool(n: int): + clear_worker_pool() + os.makedirs(SCRATCH_DIR) + + # Reserve two cores so that bookkeeping does not interfere with runs. + cpu_count = multiprocessing.cpu_count() - 2 + + # Adjacent cores sometimes share cache, so we space out single core runs. + step = max(n, 2) + for i in range(0, cpu_count, step): + core_str = f"{i}" if n == 1 else f"{i},{i + n - 1}" + _, result_file = tempfile.mkstemp(suffix=".pkl", prefix=SCRATCH_DIR) + _WORKER_POOL.put((core_str, result_file, n)) + + +def _subprocess_main(seed=0, num_threads=1, sub_label="N/A", result_file=None, env=None): + import torch + from torch.utils.benchmark import Timer + + conda_prefix = os.getenv("CONDA_PREFIX") + assert conda_prefix + if not torch.__file__.startswith(conda_prefix): + raise ValueError( + f"PyTorch mismatch: `import torch` resolved to `{torch.__file__}`, " + f"which is not in the correct conda env: {conda_prefix}" + ) + + torch.manual_seed(seed) + results = [] + for n in [4, 8, 16, 32, 64, 128, 256, 512, 1024, 7, 96, 150, 225]: + dtypes = (("Single", torch.float32), ("Double", torch.float64)) + shapes = ( + # Square MatMul + ((n, n), (n, n), "(n x n) x (n x n)", "Matrix-Matrix Product"), + + # Matrix-Vector product + ((n, n), (n, 1), "(n x n) x (n x 1)", "Matrix-Vector Product"), + ) + for (dtype_name, dtype), (x_shape, y_shape, shape_str, blas_type) in it.product(dtypes, shapes): + t = Timer( + stmt="torch.mm(x, y)", + label=f"torch.mm {shape_str} {blas_type} ({dtype_name})", + sub_label=sub_label, + description=f"n = {n}", + env=os.path.split(env or "")[1] or None, + globals={ + "x": torch.rand(x_shape, dtype=dtype), + "y": torch.rand(y_shape, dtype=dtype), + }, + num_threads=num_threads, + ).blocked_autorange(min_run_time=MIN_RUN_TIME) + results.append(t) + + if result_file is not None: + with open(result_file, "wb") as f: + pickle.dump(results, f) + + +def run_subprocess(args): + seed, env, sub_label, extra_env_vars = args + core_str = None + try: + core_str, result_file, num_threads = _WORKER_POOL.get() + with open(result_file, "wb"): + pass + + env_vars = { + "PATH": os.getenv("PATH"), + "PYTHONPATH": os.getenv("PYTHONPATH") or "", + + # NumPy + "OMP_NUM_THREADS": str(num_threads), + "MKL_NUM_THREADS": str(num_threads), + "NUMEXPR_NUM_THREADS": str(num_threads), + } + env_vars.update(extra_env_vars or {}) + + subprocess.run( + f"source activate {env} && " + f"taskset --cpu-list {core_str} " + f"python {os.path.abspath(__file__)} " + "--DETAIL_in_subprocess " + f"--DETAIL_seed {seed} " + f"--DETAIL_num_threads {num_threads} " + f"--DETAIL_sub_label '{sub_label}' " + f"--DETAIL_result_file {result_file} " + f"--DETAIL_env {env}", + env=env_vars, + stdout=subprocess.PIPE, + shell=True + ) + + with open(result_file, "rb") as f: + result_bytes = f.read() + + with _RESULT_FILE_LOCK, \ + open(RESULT_FILE, "ab") as f: + f.write(result_bytes) + + except KeyboardInterrupt: + pass # Handle ctrl-c gracefully. + + finally: + if core_str is not None: + _WORKER_POOL.put((core_str, result_file, num_threads)) + + +def _compare_main(): + results = [] + with open(RESULT_FILE, "rb") as f: + while True: + try: + results.extend(pickle.load(f)) + except EOFError: + break + + from torch.utils.benchmark import Compare + + comparison = Compare(results) + comparison.trim_significant_figures() + comparison.colorize() + comparison.print() + + +def main(): + with open(RESULT_FILE, "wb"): + pass + + for num_threads in NUM_THREAD_SETTINGS: + fill_core_pool(num_threads) + workers = _WORKER_POOL.qsize() + + trials = [] + for seed in range(NUM_REPLICATES): + for sub_label, env, extra_env_vars in BLAS_CONFIGS: + env_path = os.path.join(blas_compare_setup.WORKING_ROOT, env) + trials.append((seed, env_path, sub_label, extra_env_vars)) + + n = len(trials) + with multiprocessing.dummy.Pool(workers) as pool: + start_time = time.time() + for i, r in enumerate(pool.imap(run_subprocess, trials)): + n_trials_done = i + 1 + time_per_result = (time.time() - start_time) / n_trials_done + eta = int((n - n_trials_done) * time_per_result) + print(f"\r{i + 1} / {n} ETA:{datetime.timedelta(seconds=eta)}".ljust(80), end="") + sys.stdout.flush() + print(f"\r{n} / {n} Total time: {datetime.timedelta(seconds=int(time.time() - start_time))}") + print() + + # Any env will do, it just needs to have torch for benchmark utils. + env_path = os.path.join(blas_compare_setup.WORKING_ROOT, BLAS_CONFIGS[0][1]) + subprocess.run( + f"source activate {env_path} && " + f"python {os.path.abspath(__file__)} " + "--DETAIL_in_compare", + shell=True + ) + + +if __name__ == "__main__": + # These flags are for subprocess control, not controlling the main loop. + parser = argparse.ArgumentParser() + parser.add_argument("--DETAIL_in_subprocess", action="store_true") + parser.add_argument("--DETAIL_in_compare", action="store_true") + parser.add_argument("--DETAIL_seed", type=int, default=None) + parser.add_argument("--DETAIL_num_threads", type=int, default=None) + parser.add_argument("--DETAIL_sub_label", type=str, default="N/A") + parser.add_argument("--DETAIL_result_file", type=str, default=None) + parser.add_argument("--DETAIL_env", type=str, default=None) + args = parser.parse_args() + + if args.DETAIL_in_subprocess: + try: + _subprocess_main( + args.DETAIL_seed, + args.DETAIL_num_threads, + args.DETAIL_sub_label, + args.DETAIL_result_file, + args.DETAIL_env, + ) + except KeyboardInterrupt: + pass # Handle ctrl-c gracefully. + elif args.DETAIL_in_compare: + _compare_main() + else: + main() diff --git a/torch/utils/benchmark/examples/blas_compare_setup.py b/torch/utils/benchmark/examples/blas_compare_setup.py new file mode 100644 index 0000000000000..0ef0d3a4015ab --- /dev/null +++ b/torch/utils/benchmark/examples/blas_compare_setup.py @@ -0,0 +1,221 @@ +import collections +import os +import shutil +import subprocess + +try: + import conda.cli.python_api + from conda.cli.python_api import Commands as conda_commands +except ImportError: + # blas_compare.py will fail to import these when it's inside a conda env, + # but that's fine as it only wants the constants. + pass + + +WORKING_ROOT = "/tmp/pytorch_blas_compare_environments" +MKL_2020_3 = "mkl_2020_3" +MKL_2020_0 = "mkl_2020_0" +OPEN_BLAS = "open_blas" +EIGEN = "eigen" + + +GENERIC_ENV_VARS = ("USE_CUDA=0", "USE_ROCM=0") +BASE_PKG_DEPS = ( + "cffi", + "cmake", + "hypothesis", + "ninja", + "numpy", + "pyyaml", + "setuptools", + "typing_extensions", +) + + +SubEnvSpec = collections.namedtuple( + "SubEnvSpec", ( + "generic_installs", + "special_installs", + "environment_variables", + + # Validate install. + "expected_blas_symbols", + "expected_mkl_version", + )) + + +SUB_ENVS = { + MKL_2020_3: SubEnvSpec( + generic_installs=(), + special_installs=("intel", ("mkl=2020.3", "mkl-include=2020.3")), + environment_variables=("BLAS=MKL",) + GENERIC_ENV_VARS, + expected_blas_symbols=("mkl_blas_sgemm",), + expected_mkl_version="2020.0.3", + ), + + MKL_2020_0: SubEnvSpec( + generic_installs=(), + special_installs=("intel", ("mkl=2020.0", "mkl-include=2020.0")), + environment_variables=("BLAS=MKL",) + GENERIC_ENV_VARS, + expected_blas_symbols=("mkl_blas_sgemm",), + expected_mkl_version="2020.0.0", + ), + + OPEN_BLAS: SubEnvSpec( + generic_installs=("openblas",), + special_installs=(), + environment_variables=("BLAS=OpenBLAS",) + GENERIC_ENV_VARS, + expected_blas_symbols=("exec_blas",), + expected_mkl_version=None, + ), + + # EIGEN: SubEnvSpec( + # generic_installs=(), + # special_installs=(), + # environment_variables=("BLAS=Eigen",) + GENERIC_ENV_VARS, + # expected_blas_symbols=(), + # ), +} + + +def conda_run(*args): + """Convenience method.""" + stdout, stderr, retcode = conda.cli.python_api.run_command(*args) + if retcode: + raise OSError(f"conda error: {str(args)} retcode: {retcode}\n{stderr}") + + return stdout + + +def main(): + if os.path.exists(WORKING_ROOT): + print("Cleaning: removing old working root.") + shutil.rmtree(WORKING_ROOT) + os.makedirs(WORKING_ROOT) + + git_root = subprocess.check_output( + "git rev-parse --show-toplevel", + shell=True, + cwd=os.path.dirname(os.path.realpath(__file__)) + ).decode("utf-8").strip() + + for env_name, env_spec in SUB_ENVS.items(): + env_path = os.path.join(WORKING_ROOT, env_name) + print(f"Creating env: {env_name}: ({env_path})") + conda_run( + conda_commands.CREATE, + "--no-default-packages", + "--prefix", env_path, + "python=3", + ) + + print("Testing that env can be activated:") + base_source = subprocess.run( + f"source activate {env_path}", + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + if base_source.returncode: + raise OSError( + "Failed to source base environment:\n" + f" stdout: {base_source.stdout.decode('utf-8')}\n" + f" stderr: {base_source.stderr.decode('utf-8')}" + ) + + print("Installing packages:") + conda_run( + conda_commands.INSTALL, + "--prefix", env_path, + *(BASE_PKG_DEPS + env_spec.generic_installs) + ) + + if env_spec.special_installs: + channel, channel_deps = env_spec.special_installs + print(f"Installing packages from channel: {channel}") + conda_run( + conda_commands.INSTALL, + "--prefix", env_path, + "-c", channel, *channel_deps + ) + + if env_spec.environment_variables: + print("Setting environment variables.") + + # This does not appear to be possible using the python API. + env_set = subprocess.run( + f"source activate {env_path} && " + f"conda env config vars set {' '.join(env_spec.environment_variables)}", + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + if env_set.returncode: + raise OSError( + "Failed to set environment variables:\n" + f" stdout: {env_set.stdout.decode('utf-8')}\n" + f" stderr: {env_set.stderr.decode('utf-8')}" + ) + + # Check that they were actually set correctly. + actual_env_vars = subprocess.run( + f"source activate {env_path} && env", + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ).stdout.decode("utf-8").strip().splitlines() + for e in env_spec.environment_variables: + assert e in actual_env_vars, f"{e} not in envs" + + print(f"Building PyTorch for env: `{env_name}`") + # We have to re-run during each build to pick up the new + # build config settings. + build_run = subprocess.run( + f"source activate {env_path} && " + f"cd {git_root} && " + "python setup.py install --cmake", + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + print("Checking configuration:") + check_run = subprocess.run( + # Shameless abuse of `python -c ...` + f"source activate {env_path} && " + "python -c \"" + "import torch;" + "from torch.utils.benchmark import Timer;" + "print(torch.__config__.show());" + "setup = 'x=torch.ones((128, 128));y=torch.ones((128, 128))';" + "counts = Timer('torch.mm(x, y)', setup).collect_callgrind(collect_baseline=False);" + "stats = counts.as_standardized().stats(inclusive=True);" + "print(stats.filter(lambda l: 'blas' in l.lower()))\"", + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + if check_run.returncode: + raise OSError( + "Failed to set environment variables:\n" + f" stdout: {check_run.stdout.decode('utf-8')}\n" + f" stderr: {check_run.stderr.decode('utf-8')}" + ) + check_run_stdout = check_run.stdout.decode('utf-8') + print(check_run_stdout) + + for e in env_spec.environment_variables: + if "BLAS" in e: + assert e in check_run_stdout, f"PyTorch build did not respect `BLAS=...`: {e}" + + for s in env_spec.expected_blas_symbols: + assert s in check_run_stdout + + if env_spec.expected_mkl_version is not None: + assert f"- Intel(R) Math Kernel Library Version {env_spec.expected_mkl_version}" in check_run_stdout + + print(f"Build complete: {env_name}") + + +if __name__ == "__main__": + main() diff --git a/torch/utils/_benchmark/examples/compare.py b/torch/utils/benchmark/examples/compare.py similarity index 98% rename from torch/utils/_benchmark/examples/compare.py rename to torch/utils/benchmark/examples/compare.py index 3373149c7039e..f1688976af378 100644 --- a/torch/utils/_benchmark/examples/compare.py +++ b/torch/utils/benchmark/examples/compare.py @@ -9,7 +9,7 @@ import torch -import torch.utils._benchmark as benchmark_utils +import torch.utils.benchmark as benchmark_utils class FauxTorch(object): diff --git a/torch/utils/_benchmark/examples/end_to_end.py b/torch/utils/benchmark/examples/end_to_end.py similarity index 99% rename from torch/utils/_benchmark/examples/end_to_end.py rename to torch/utils/benchmark/examples/end_to_end.py index b275b9a076a27..942c20e541734 100644 --- a/torch/utils/_benchmark/examples/end_to_end.py +++ b/torch/utils/benchmark/examples/end_to_end.py @@ -26,8 +26,8 @@ import numpy as np import torch -from torch.utils._benchmark.op_fuzzers import unary -from torch.utils._benchmark import Timer, Measurement +from torch.utils.benchmark.op_fuzzers import unary +from torch.utils.benchmark import Timer, Measurement from typing import Dict, Tuple, List diff --git a/torch/utils/_benchmark/examples/fuzzer.py b/torch/utils/benchmark/examples/fuzzer.py similarity index 98% rename from torch/utils/_benchmark/examples/fuzzer.py rename to torch/utils/benchmark/examples/fuzzer.py index 157782de4ccd9..4446e2d85c0a2 100644 --- a/torch/utils/_benchmark/examples/fuzzer.py +++ b/torch/utils/benchmark/examples/fuzzer.py @@ -5,7 +5,7 @@ import sys -import torch.utils._benchmark as benchmark_utils +import torch.utils.benchmark as benchmark_utils def main(): diff --git a/torch/utils/_benchmark/examples/op_benchmark.py b/torch/utils/benchmark/examples/op_benchmark.py similarity index 95% rename from torch/utils/_benchmark/examples/op_benchmark.py rename to torch/utils/benchmark/examples/op_benchmark.py index 1d3cc618fa352..65b69d84b41f4 100644 --- a/torch/utils/_benchmark/examples/op_benchmark.py +++ b/torch/utils/benchmark/examples/op_benchmark.py @@ -6,9 +6,9 @@ import numpy as np import torch -from torch.utils._benchmark import Timer -from torch.utils._benchmark.op_fuzzers.binary import BinaryOpFuzzer -from torch.utils._benchmark.op_fuzzers.unary import UnaryOpFuzzer +from torch.utils.benchmark import Timer +from torch.utils.benchmark.op_fuzzers.binary import BinaryOpFuzzer +from torch.utils.benchmark.op_fuzzers.unary import UnaryOpFuzzer _MEASURE_TIME = 1.0 diff --git a/torch/utils/_benchmark/examples/prepare_e2e.sh b/torch/utils/benchmark/examples/prepare_e2e.sh similarity index 100% rename from torch/utils/_benchmark/examples/prepare_e2e.sh rename to torch/utils/benchmark/examples/prepare_e2e.sh diff --git a/torch/utils/_benchmark/examples/simple_timeit.py b/torch/utils/benchmark/examples/simple_timeit.py similarity index 90% rename from torch/utils/_benchmark/examples/simple_timeit.py rename to torch/utils/benchmark/examples/simple_timeit.py index 4bd76ce4ccebf..81aaa6dee9817 100644 --- a/torch/utils/_benchmark/examples/simple_timeit.py +++ b/torch/utils/benchmark/examples/simple_timeit.py @@ -5,7 +5,7 @@ import torch -import torch.utils._benchmark as benchmark_utils +import torch.utils.benchmark as benchmark_utils def main(): diff --git a/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py b/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py new file mode 100644 index 0000000000000..d73dcd490b824 --- /dev/null +++ b/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py @@ -0,0 +1,113 @@ +"""Microbenchmarks for the torch.fft module""" +from argparse import ArgumentParser +from collections import namedtuple +from collections.abc import Iterable + +import torch +import torch.fft +from torch.utils import benchmark +from torch.utils.benchmark.op_fuzzers.spectral import SpectralOpFuzzer + + +def _dim_options(ndim): + if ndim == 1: + return [None] + elif ndim == 2: + return [0, 1, None] + elif ndim == 3: + return [0, 1, 2, (0, 1), (0, 2), None] + raise ValueError(f"Expected ndim in range 1-3, got {ndim}") + + +def run_benchmark(name: str, function: object, dtype: torch.dtype, seed: int, device: str, samples: int, + probability_regular: float): + cuda = device == 'cuda' + spectral_fuzzer = SpectralOpFuzzer(seed=seed, dtype=dtype, cuda=cuda, + probability_regular=probability_regular) + results = [] + for tensors, tensor_params, params in spectral_fuzzer.take(samples): + shape = [params['k0'], params['k1'], params['k2']][:params['ndim']] + str_shape = ' x '.join(["{:<4}".format(s) for s in shape]) + sub_label = f"{str_shape} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}" + for dim in _dim_options(params['ndim']): + for nthreads in (1, 4, 16) if not cuda else (1,): + measurement = benchmark.Timer( + stmt='func(x, dim=dim)', + globals={'func': function, 'x': tensors['x'], 'dim': dim}, + label=f"{name}_{device}", + sub_label=sub_label, + description=f"dim={dim}", + num_threads=nthreads, + ).blocked_autorange(min_run_time=1) + measurement.metadata = { + 'name': name, + 'device': device, + 'dim': dim, + 'shape': shape, + } + measurement.metadata.update(tensor_params['x']) + results.append(measurement) + return results + + +Benchmark = namedtuple('Benchmark', ['name', 'function', 'dtype']) +BENCHMARKS = [ + Benchmark('fft_real', torch.fft.fftn, torch.float32), # type: ignore + Benchmark('fft_complex', torch.fft.fftn, torch.complex64), # type: ignore + Benchmark('ifft', torch.fft.ifftn, torch.complex64), # type: ignore + Benchmark('rfft', torch.fft.rfftn, torch.float32), # type: ignore + Benchmark('irfft', torch.fft.irfftn, torch.complex64), # type: ignore +] +BENCHMARK_MAP = {b.name: b for b in BENCHMARKS} +BENCHMARK_NAMES = [b.name for b in BENCHMARKS] +DEVICE_NAMES = ['cpu', 'cuda'] + +def _output_csv(file, results): + file.write('benchmark,device,num_threads,numel,shape,contiguous,dim,mean (us),median (us),iqr (us)\n') + for measurement in results: + metadata = measurement.metadata + device, dim, shape, name, numel, contiguous = ( + metadata['device'], metadata['dim'], metadata['shape'], + metadata['name'], metadata['numel'], metadata['is_contiguous']) + + if isinstance(dim, Iterable): + dim_str = '-'.join(str(d) for d in dim) + else: + dim_str = str(dim) + shape_str = 'x'.join(str(s) for s in shape) + + print(name, device, measurement.task_spec.num_threads, numel, shape_str, contiguous, dim_str, + measurement.mean * 1e6, measurement.median * 1e6, measurement.iqr * 1e6, + sep=',', file=file) + + +if __name__ == '__main__': + parser = ArgumentParser(description=__doc__) + parser.add_argument('--device', type=str, choices=DEVICE_NAMES, nargs='+', default=DEVICE_NAMES) + parser.add_argument('--bench', type=str, choices=BENCHMARK_NAMES, nargs='+', default=BENCHMARK_NAMES) + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--samples', type=int, default=10) + parser.add_argument('--probability_regular', type=float, default=1.0) + parser.add_argument('-o', '--output', type=str) + args = parser.parse_args() + + num_benchmarks = len(args.device) * len(args.bench) + i = 0 + results = [] + for device in args.device: + for bench in (BENCHMARK_MAP[b] for b in args.bench): + results += run_benchmark( + name=bench.name, function=bench.function, dtype=bench.dtype, + seed=args.seed, device=device, samples=args.samples, + probability_regular=args.probability_regular) + i += 1 + print(f'Completed {bench.name} benchmark on {device} ({i} of {num_benchmarks})') + + if args.output is not None: + with open(args.output, 'w') as f: + _output_csv(f, results) + + compare = benchmark.Compare(results) + compare.trim_significant_figures() + compare.colorize() + compare.print() diff --git a/torch/utils/benchmark/op_fuzzers/__init__.py b/torch/utils/benchmark/op_fuzzers/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/utils/_benchmark/op_fuzzers/binary.py b/torch/utils/benchmark/op_fuzzers/binary.py similarity index 97% rename from torch/utils/_benchmark/op_fuzzers/binary.py rename to torch/utils/benchmark/op_fuzzers/binary.py index 848cc7c36875a..91289d88db8ad 100644 --- a/torch/utils/_benchmark/op_fuzzers/binary.py +++ b/torch/utils/benchmark/op_fuzzers/binary.py @@ -1,7 +1,7 @@ import numpy as np import torch -from torch.utils._benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedTensor +from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedTensor _MIN_DIM_SIZE = 16 diff --git a/torch/utils/benchmark/op_fuzzers/spectral.py b/torch/utils/benchmark/op_fuzzers/spectral.py new file mode 100644 index 0000000000000..29359ba3edb6f --- /dev/null +++ b/torch/utils/benchmark/op_fuzzers/spectral.py @@ -0,0 +1,93 @@ +import math + +import torch +from torch.utils import benchmark +from torch.utils.benchmark import FuzzedParameter, FuzzedTensor, ParameterAlias + + +__all__ = ['SpectralOpFuzzer'] + +MIN_DIM_SIZE = 16 +MAX_DIM_SIZE = 16 * 1024 + +def power_range(upper_bound, base): + return (base ** i for i in range(int(math.log(upper_bound, base)) + 1)) + +# List of regular numbers from MIN_DIM_SIZE to MAX_DIM_SIZE +# These numbers factorize into multiples of prime factors 2, 3, and 5 only +# and are usually the fastest in FFT implementations. +REGULAR_SIZES = [] +for i in power_range(MAX_DIM_SIZE, 2): + for j in power_range(MAX_DIM_SIZE // i, 3): + ij = i * j + for k in power_range(MAX_DIM_SIZE // ij, 5): + ijk = ij * k + if ijk > MIN_DIM_SIZE: + REGULAR_SIZES.append(ijk) +REGULAR_SIZES.sort() + +class SpectralOpFuzzer(benchmark.Fuzzer): + def __init__(self, *, seed: int, dtype=torch.float64, + cuda: bool = False, probability_regular: float = 1.0): + super().__init__( + parameters=[ + # Dimensionality of x. (e.g. 1D, 2D, or 3D.) + FuzzedParameter("ndim", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True), + + # Shapes for `x`. + # It is important to test all shapes, however + # regular sizes are especially important to the FFT and therefore + # warrant special attention. This is done by generating + # both a value drawn from all integers between the min and + # max allowed values, and another from only the regular numbers + # (both distributions are loguniform) and then randomly + # selecting between the two. + [ + FuzzedParameter( + name=f"k_any_{i}", + minval=MIN_DIM_SIZE, + maxval=MAX_DIM_SIZE, + distribution="loguniform", + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k_regular_{i}", + distribution={size: 1. / len(REGULAR_SIZES) for size in REGULAR_SIZES} + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k{i}", + distribution={ + ParameterAlias(f"k_regular_{i}"): probability_regular, + ParameterAlias(f"k_any_{i}"): 1 - probability_regular, + }, + strict=True, + ) for i in range(3) + ], + + # Steps for `x`. (Benchmarks strided memory access.) + [ + FuzzedParameter( + name=f"step_{i}", + distribution={1: 0.8, 2: 0.06, 4: 0.06, 8: 0.04, 16: 0.04}, + ) for i in range(3) + ], + ], + tensors=[ + FuzzedTensor( + name="x", + size=("k0", "k1", "k2"), + steps=("step_0", "step_1", "step_2"), + probability_contiguous=0.75, + min_elements=4 * 1024, + max_elements=32 * 1024 ** 2, + max_allocation_bytes=2 * 1024**3, # 2 GB + dim_parameter="ndim", + dtype=dtype, + cuda=cuda, + ), + ], + seed=seed, + ) diff --git a/torch/utils/_benchmark/op_fuzzers/unary.py b/torch/utils/benchmark/op_fuzzers/unary.py similarity index 97% rename from torch/utils/_benchmark/op_fuzzers/unary.py rename to torch/utils/benchmark/op_fuzzers/unary.py index 10cee4316c1c8..a0f810d0b9fad 100644 --- a/torch/utils/_benchmark/op_fuzzers/unary.py +++ b/torch/utils/benchmark/op_fuzzers/unary.py @@ -1,7 +1,7 @@ import numpy as np import torch -from torch.utils._benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedTensor +from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedTensor _MIN_DIM_SIZE = 16 diff --git a/torch/utils/benchmark/utils/__init__.py b/torch/utils/benchmark/utils/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/utils/benchmark/utils/_stubs.py b/torch/utils/benchmark/utils/_stubs.py new file mode 100644 index 0000000000000..e2ab6ec086a49 --- /dev/null +++ b/torch/utils/benchmark/utils/_stubs.py @@ -0,0 +1,46 @@ +import sys +from typing import Any, Callable, Dict, TYPE_CHECKING + + +if TYPE_CHECKING or sys.version_info >= (3, 8): + from typing import runtime_checkable, Protocol +else: + from typing_extensions import runtime_checkable, Protocol + + +class TimerClass(Protocol): + """This is the portion of the `timeit.Timer` API used by benchmark utils.""" + def __init__( + self, + stmt: str, + setup: str, + timer: Callable[[], float], + globals: Dict[str, Any] + ) -> None: + ... + + def timeit(self, number: int) -> float: + ... + + +@runtime_checkable +class TimeitModuleType(Protocol): + """Modules generated from `timeit_template.cpp`.""" + def timeit(self, number: int) -> float: + ... + + +class CallgrindModuleType(Protocol): + """Replicates the valgrind endpoints in `torch._C`. + + These bindings are used to collect Callgrind profiles on earlier versions + of PyTorch and will eventually be removed. + """ + __file__: str + __name__: str + + def _valgrind_supported_platform(self) -> bool: + ... + + def _valgrind_toggle(self) -> None: + ... diff --git a/torch/utils/benchmark/utils/common.py b/torch/utils/benchmark/utils/common.py new file mode 100644 index 0000000000000..1cbed2df51c22 --- /dev/null +++ b/torch/utils/benchmark/utils/common.py @@ -0,0 +1,288 @@ +"""Base shared classes and utilities.""" + +import collections +import contextlib +import dataclasses +import textwrap +from typing import cast, Any, DefaultDict, Dict, Iterable, Iterator, List, Optional, Tuple + +import numpy as np +import torch + + +__all__ = ["TaskSpec", "Measurement"] + + +_MAX_SIGNIFICANT_FIGURES = 4 +_MIN_CONFIDENCE_INTERVAL = 25e-9 # 25 ns + +# Measurement will include a warning if the distribution is suspect. All +# runs are expected to have some variation; these parameters set the +# thresholds. +_IQR_WARN_THRESHOLD = 0.1 +_IQR_GROSS_WARN_THRESHOLD = 0.25 + + +@dataclasses.dataclass(init=True, repr=False, eq=True, frozen=True) +class TaskSpec: + """Container for information used to define a Timer. (except globals)""" + stmt: str + setup: str + label: Optional[str] = None + sub_label: Optional[str] = None + description: Optional[str] = None + env: Optional[str] = None + num_threads: int = 1 + + @property + def title(self) -> str: + """Best effort attempt at a string label for the measurement.""" + if self.label is not None: + return self.label + (f": {self.sub_label}" if self.sub_label else "") + elif "\n" not in self.stmt: + return self.stmt + (f": {self.sub_label}" if self.sub_label else "") + return ( + f"stmt:{f' ({self.sub_label})' if self.sub_label else ''}\n" + f"{textwrap.indent(self.stmt, ' ')}" + ) + + def setup_str(self) -> str: + return ( + "" if (self.setup == "pass" or not self.setup) + else f"setup:\n{textwrap.indent(self.setup, ' ')}" if "\n" in self.setup + else f"setup: {self.setup}" + ) + + def summarize(self) -> str: + """Build TaskSpec portion of repr string for other containers.""" + sections = [ + self.title, + self.description or "", + self.setup_str(), + ] + return "\n".join([f"{i}\n" if "\n" in i else i for i in sections if i]) + +_TASKSPEC_FIELDS = tuple(i.name for i in dataclasses.fields(TaskSpec)) + + +@dataclasses.dataclass(init=True, repr=False) +class Measurement: + """The result of a Timer measurement. + + This class stores one or more measurements of a given statement. It is + serializable and provides several convenience methods + (including a detailed __repr__) for downstream consumers. + """ + number_per_run: int + raw_times: List[float] + task_spec: TaskSpec + metadata: Optional[Dict[Any, Any]] = None # Reserved for user payloads. + + def __post_init__(self) -> None: + self._sorted_times: Tuple[float, ...] = () + self._warnings: Tuple[str, ...] = () + self._median: float = -1.0 + self._mean: float = -1.0 + self._p25: float = -1.0 + self._p75: float = -1.0 + + def __getattr__(self, name: str) -> Any: + # Forward TaskSpec fields for convenience. + if name in _TASKSPEC_FIELDS: + return getattr(self.task_spec, name) + return super().__getattribute__(name) + + # ========================================================================= + # == Convenience methods for statistics =================================== + # ========================================================================= + # + # These methods use raw time divided by number_per_run; this is an + # extrapolation and hides the fact that different number_per_run will + # result in different amortization of overheads, however if Timer has + # selected an appropriate number_per_run then this is a non-issue, and + # forcing users to handle that division would result in a poor experience. + @property + def times(self) -> List[float]: + return [t / self.number_per_run for t in self.raw_times] + + @property + def median(self) -> float: + self._lazy_init() + return self._median + + @property + def mean(self) -> float: + self._lazy_init() + return self._mean + + @property + def iqr(self) -> float: + self._lazy_init() + return self._p75 - self._p25 + + @property + def significant_figures(self) -> int: + """Approximate significant figure estimate. + + This property is intended to give a convenient way to estimate the + precision of a measurement. It only uses the interquartile region to + estimate statistics to try to mitigate skew from the tails, and + uses a static z value of 1.645 since it is not expected to be used + for small values of `n`, so z can approximate `t`. + + The significant figure estimation used in conjunction with the + `trim_sigfig` method to provide a more human interpretable data + summary. __repr__ does not use this method; it simply displays raw + values. Significant figure estimation is intended for `Compare`. + """ + self._lazy_init() + n_total = len(self._sorted_times) + lower_bound = int(n_total // 4) + upper_bound = int(np.ceil(3 * n_total / 4)) + interquartile_points: Tuple[float, ...] = self._sorted_times[lower_bound:upper_bound] + std = np.std(interquartile_points) + sqrt_n = np.sqrt(len(interquartile_points)) + + # Rough estimates. These are by no means statistically rigorous. + confidence_interval = max(1.645 * std / sqrt_n, _MIN_CONFIDENCE_INTERVAL) + relative_ci = np.log10(self._median / confidence_interval) + num_significant_figures = int(np.floor(relative_ci)) + return min(max(num_significant_figures, 1), _MAX_SIGNIFICANT_FIGURES) + + @property + def has_warnings(self) -> bool: + self._lazy_init() + return bool(self._warnings) + + def _lazy_init(self) -> None: + if self.raw_times and not self._sorted_times: + self._sorted_times = tuple(sorted(self.times)) + self._median = float(np.median(self._sorted_times)) + self._mean = float(np.mean(self._sorted_times)) + self._p25 = float(np.percentile(self._sorted_times, 25)) + self._p75 = float(np.percentile(self._sorted_times, 75)) + + def add_warning(msg: str) -> None: + rel_iqr = self.iqr / self.median * 100 + self._warnings += ( + f" WARNING: Interquartile range is {rel_iqr:.1f}% " + f"of the median measurement.\n {msg}", + ) + + if not self.meets_confidence(_IQR_GROSS_WARN_THRESHOLD): + add_warning("This suggests significant environmental influence.") + elif not self.meets_confidence(_IQR_WARN_THRESHOLD): + add_warning("This could indicate system fluctuation.") + + + def meets_confidence(self, threshold: float = _IQR_WARN_THRESHOLD) -> bool: + return self.iqr / self.median < threshold + + @property + def title(self) -> str: + return self.task_spec.title + + @property + def env(self) -> str: + return ( + "Unspecified env" if self.taskspec.env is None + else cast(str, self.taskspec.env) + ) + + @property + def as_row_name(self) -> str: + return self.sub_label or self.stmt or "[Unknown]" + + def __repr__(self) -> str: + """ + Example repr: + + Broadcasting add (4x8) + Median: 5.73 us + IQR: 2.25 us (4.01 to 6.26) + 372 measurements, 100 runs per measurement, 1 thread + WARNING: Interquartile range is 39.4% of the median measurement. + This suggests significant environmental influence. + """ + self._lazy_init() + skip_line, newline = "MEASUREMENT_REPR_SKIP_LINE", "\n" + n = len(self._sorted_times) + time_unit, time_scale = select_unit(self._median) + iqr_filter = '' if n >= 4 else skip_line + + repr_str = f""" +{super().__repr__()} +{self.task_spec.summarize()} + {'Median: ' if n > 1 else ''}{self._median / time_scale:.2f} {time_unit} + {iqr_filter}IQR: {self.iqr / time_scale:.2f} {time_unit} ({self._p25 / time_scale:.2f} to {self._p75 / time_scale:.2f}) + {n} measurement{'s' if n > 1 else ''}, {self.number_per_run} runs {'per measurement,' if n > 1 else ','} {self.num_threads} thread{'s' if self.num_threads > 1 else ''} +{newline.join(self._warnings)}""".strip() # noqa + + return "\n".join(l for l in repr_str.splitlines(keepends=False) if skip_line not in l) + + @staticmethod + def merge(measurements): # type: (Iterable[Measurement]) -> List[Measurement] + """Convenience method for merging replicates. + NB: merge will extrapolate times to `number_per_run=1` and will not + transfer any metadata (since it might differ between replicates) + """ + grouped_measurements: DefaultDict[TaskSpec, List[Measurement]] = collections.defaultdict(list) + for m in measurements: + grouped_measurements[m.task_spec].append(m) + + def merge_group(task_spec: TaskSpec, group: List[Measurement]) -> Measurement: + times: List[float] = [] + for m in group: + # Different measurements could have different `number_per_run`, + # so we call `.times` which normalizes the results. + times.extend(m.times) + + return Measurement( + number_per_run=1, + raw_times=times, + task_spec=task_spec, + metadata=None, + ) + + return [merge_group(t, g) for t, g in grouped_measurements.items()] + + +def select_unit(t: float) -> Tuple[str, float]: + """Determine how to scale times for O(1) magnitude. + + This utility is used to format numbers for human consumption. + """ + time_unit = {-3: "ns", -2: "us", -1: "ms"}.get(int(np.log10(t) // 3), "s") + time_scale = {"ns": 1e-9, "us": 1e-6, "ms": 1e-3, "s": 1}[time_unit] + return time_unit, time_scale + + +def unit_to_english(u: str) -> str: + return { + "ns": "nanosecond", + "us": "microsecond", + "ms": "millisecond", + "s": "second", + }[u] + + +def trim_sigfig(x: float, n: int) -> float: + """Trim `x` to `n` significant figures. (e.g. 3.14159, 2 -> 3.10000)""" + assert n == int(n) + magnitude = int(np.ceil(np.log10(np.abs(x)))) + scale = 10 ** (magnitude - n) + return float(np.round(x / scale) * scale) + + +def ordered_unique(elements: Iterable[Any]) -> List[Any]: + return list(collections.OrderedDict({i: None for i in elements}).keys()) + + +@contextlib.contextmanager +def set_torch_threads(n: int) -> Iterator[None]: + prior_num_threads = torch.get_num_threads() + try: + torch.set_num_threads(n) + yield + finally: + torch.set_num_threads(prior_num_threads) diff --git a/torch/utils/_benchmark/utils/compare.py b/torch/utils/benchmark/utils/compare.py similarity index 63% rename from torch/utils/_benchmark/utils/compare.py rename to torch/utils/benchmark/utils/compare.py index 6f29b67fef8c7..a35f25a43774a 100644 --- a/torch/utils/_benchmark/utils/compare.py +++ b/torch/utils/benchmark/utils/compare.py @@ -1,11 +1,12 @@ """Display class to aggregate and print the results of many measurements.""" import collections +import enum import itertools as it -from typing import cast, List, Optional, Tuple +from typing import DefaultDict, List, Optional, Tuple import numpy as np -from torch.utils._benchmark.utils import common +from torch.utils.benchmark.utils import common __all__ = ["Compare"] @@ -17,11 +18,17 @@ TERMINATE = "\033[0m" +class Colorize(enum.Enum): + NONE = "none" + COLUMNWISE = "columnwise" + ROWWISE = "rowwise" + + # Classes to separate internal bookkeeping from what is rendered. class _Column(object): def __init__( self, - grouped_results: List[Tuple[common.Measurement, ...]], + grouped_results: List[Tuple[Optional[common.Measurement], ...]], time_scale: float, time_unit: str, trim_significant_figures: bool, @@ -32,15 +39,19 @@ def __init__( self._time_scale = time_scale self._time_unit = time_unit self._trim_significant_figures = trim_significant_figures - self._highlight_warnings = highlight_warnings and any(r.has_warnings for r in self._flat_results) + self._highlight_warnings = ( + highlight_warnings + and any(r.has_warnings for r in self._flat_results if r) + ) leading_digits = [ - int(np.ceil(np.log10(r.median / self._time_scale))) + int(np.ceil(np.log10(r.median / self._time_scale))) if r else None for r in self._flat_results ] - unit_digits = max(leading_digits) + unit_digits = max(d for d in leading_digits if d is not None) decimal_digits = min( max(m.significant_figures - digits, 0) for digits, m in zip(leading_digits, self._flat_results) + if (m is not None) and (digits is not None) ) if self._trim_significant_figures else 1 length = unit_digits + decimal_digits + (1 if decimal_digits else 0) self._template = f"{{:>{length}.{decimal_digits}f}}{{:>{7 if self._highlight_warnings else 0}}}" @@ -48,12 +59,21 @@ def __init__( def get_results_for(self, group): return self._grouped_results[group] - def num_to_str(self, value: float, estimated_sigfigs: int, spread: Optional[float]): + def num_to_str(self, value: Optional[float], estimated_sigfigs: int, spread: Optional[float]): + if value is None: + return " " * len(self.num_to_str(1, estimated_sigfigs, None)) + if self._trim_significant_figures: value = common.trim_sigfig(value, estimated_sigfigs) + return self._template.format( value, - f" (! {spread:.0f}%)" if self._highlight_warnings and spread is not None else "") + f" (! {spread * 100:.0f}%)" if self._highlight_warnings and spread is not None else "") + + +def optional_min(seq): + l = list(seq) + return None if len(l) == 0 else min(l) class _Row(object): @@ -67,27 +87,30 @@ def __init__(self, results, row_group, render_env, env_str_len, self._row_name_str_len = row_name_str_len self._time_scale = time_scale self._colorize = colorize - self._columns = None + self._columns: Tuple[_Column, ...] = () self._num_threads = num_threads def register_columns(self, columns: Tuple[_Column, ...]): self._columns = columns def as_column_strings(self): - env = f"({self._results[0].env})" if self._render_env else "" + concrete_results = [r for r in self._results if r is not None] + env = f"({concrete_results[0].env})" if self._render_env else "" env = env.ljust(self._env_str_len + 4) - output = [" " + env + self._results[0].as_row_name] + output = [" " + env + concrete_results[0].as_row_name] for m, col in zip(self._results, self._columns or ()): - output.append(col.num_to_str( - m.median / self._time_scale, - m.significant_figures, - m.median / m._iqr if m.has_warnings else None - )) + if m is None: + output.append(col.num_to_str(None, 1, None)) + else: + output.append(col.num_to_str( + m.median / self._time_scale, + m.significant_figures, + m.iqr / m.median if m.has_warnings else None + )) return output @staticmethod - def color_segment(segment, value, group_values): - best_value = min(group_values) + def color_segment(segment, value, best_value): if value <= best_value * 1.01 or value <= best_value + 100e-9: return BEST + BOLD + segment + TERMINATE * 2 if value <= best_value * 1.1: @@ -106,19 +129,33 @@ def row_separator(self, overall_width): ) def finalize_column_strings(self, column_strings, col_widths): + best_values = [-1 for _ in column_strings] + if self._colorize == Colorize.ROWWISE: + row_min = min(r.median for r in self._results if r is not None) + best_values = [row_min for _ in column_strings] + elif self._colorize == Colorize.COLUMNWISE: + best_values = [ + optional_min(r.median for r in column.get_results_for(self._row_group) if r is not None) + for column in (self._columns or ()) + ] + row_contents = [column_strings[0].ljust(col_widths[0])] - for col_str, width, result, column in zip(column_strings[1:], col_widths[1:], self._results, self._columns or ()): + for col_str, width, result, best_value in zip(column_strings[1:], col_widths[1:], self._results, best_values): col_str = col_str.center(width) - if self._colorize: - group_medians = [r.median for r in column.get_results_for(self._row_group)] - col_str = self.color_segment(col_str, result.median, group_medians) + if self._colorize != Colorize.NONE and result is not None and best_value is not None: + col_str = self.color_segment(col_str, result.median, best_value) row_contents.append(col_str) return row_contents class Table(object): - def __init__(self, results: List[common.Measurement], colorize: bool, - trim_significant_figures: bool, highlight_warnings: bool): + def __init__( + self, + results: List[common.Measurement], + colorize: Colorize, + trim_significant_figures: bool, + highlight_warnings: bool + ): assert len(set(r.label for r in results)) == 1 self.results = results @@ -136,17 +173,20 @@ def __init__(self, results: List[common.Measurement], colorize: bool, self.rows, self.columns = self.populate_rows_and_columns() @staticmethod - def row_fn(m: common.Measurement): + def row_fn(m: common.Measurement) -> Tuple[int, Optional[str], str]: return m.num_threads, m.env, m.as_row_name @staticmethod - def col_fn(m: common.Measurement): + def col_fn(m: common.Measurement) -> Optional[str]: return m.description - def populate_rows_and_columns(self): - rows, columns = [], [] - - ordered_results: List[List[Optional[common.Measurement]]] = [[None for _ in self.column_keys] for _ in self.row_keys] + def populate_rows_and_columns(self) -> Tuple[Tuple[_Row, ...], Tuple[_Column, ...]]: + rows: List[_Row] = [] + columns: List[_Column] = [] + ordered_results: List[List[Optional[common.Measurement]]] = [ + [None for _ in self.column_keys] + for _ in self.row_keys + ] row_position = {key: i for i, key in enumerate(self.row_keys)} col_position = {key: i for i, key in enumerate(self.column_keys)} for r in self.results: @@ -187,10 +227,7 @@ def populate_rows_and_columns(self): prior_env = env for i in range(len(self.column_keys)): - grouped_results = cast( - List[Tuple[common.Measurement, ...]], # All Nones should be gone. - [tuple(row[i] for row in g) for g in rows_by_group], - ) + grouped_results = [tuple(row[i] for row in g) for g in rows_by_group] column = _Column( grouped_results=grouped_results, time_scale=self.time_scale, @@ -204,13 +241,13 @@ def populate_rows_and_columns(self): ri.register_columns(columns_tuple) return rows_tuple, columns_tuple - def render(self): + def render(self) -> str: string_rows = [[""] + self.column_keys] for r in self.rows: string_rows.append(r.as_column_strings()) num_cols = max(len(i) for i in string_rows) - for r in string_rows: - r.extend(["" for _ in range(num_cols - len(r))]) + for sr in string_rows: + sr.extend(["" for _ in range(num_cols - len(sr))]) col_widths = [max(len(j) for j in i) for i in zip(*string_rows)] finalized_columns = [" | ".join(i.center(w) for i, w in zip(string_rows[0], col_widths))] @@ -218,12 +255,15 @@ def render(self): for string_row, row in zip(string_rows[1:], self.rows): finalized_columns.extend(row.row_separator(overall_width)) finalized_columns.append(" | ".join(row.finalize_column_strings(string_row, col_widths))) - print("[" + (" " + (self.label or "") + " ").center(overall_width - 2, "-") + "]") - print("\n".join(finalized_columns)) - print(f"\nTimes are in {common.unit_to_english(self.time_unit)}s ({self.time_unit}).") - if self._highlight_warnings and any(r.has_warnings for r in self.results): - print("(! XX%) Measurement has high variance, where XX is the median / IQR * 100.") - print("\n") + + newline = "\n" + has_warnings = self._highlight_warnings and any(ri.has_warnings for ri in self.results) + return f""" +[{(' ' + (self.label or '') + ' ').center(overall_width - 2, '-')}] +{newline.join(finalized_columns)} + +Times are in {common.unit_to_english(self.time_unit)}s ({self.time_unit}). +{'(! XX%) Measurement has high variance, where XX is the IQR / median * 100.' + newline if has_warnings else ""}"""[1:] class Compare(object): @@ -231,9 +271,12 @@ def __init__(self, results: List[common.Measurement]): self._results: List[common.Measurement] = [] self.extend_results(results) self._trim_significant_figures = False - self._colorize = False + self._colorize = Colorize.NONE self._highlight_warnings = False + def __str__(self): + return "\n".join(self._render()) + def extend_results(self, results): for r in results: if not isinstance(r, common.Measurement): @@ -245,28 +288,34 @@ def extend_results(self, results): def trim_significant_figures(self): self._trim_significant_figures = True - def colorize(self): - self._colorize = True + def colorize(self, rowwise=False): + self._colorize = Colorize.ROWWISE if rowwise else Colorize.COLUMNWISE def highlight_warnings(self): self._highlight_warnings = True def print(self): - self._render() + print(str(self)) def _render(self): - results = common.merge_measurements(self._results) - results = self._group_by_label(results) - for group in results.values(): - self._layout(group) + results = common.Measurement.merge(self._results) + grouped_results = self._group_by_label(results) + output = [] + for group in grouped_results.values(): + output.append(self._layout(group)) + return output - def _group_by_label(self, results): - grouped_results = collections.defaultdict(list) + def _group_by_label(self, results: List[common.Measurement]): + grouped_results: DefaultDict[str, List[common.Measurement]] = collections.defaultdict(list) for r in results: grouped_results[r.label].append(r) return grouped_results def _layout(self, results: List[common.Measurement]): - table = Table(results, self._colorize, self._trim_significant_figures, - self._highlight_warnings) - table.render() + table = Table( + results, + self._colorize, + self._trim_significant_figures, + self._highlight_warnings + ) + return table.render() diff --git a/torch/utils/benchmark/utils/cpp_jit.py b/torch/utils/benchmark/utils/cpp_jit.py new file mode 100644 index 0000000000000..3ec3b6c097d82 --- /dev/null +++ b/torch/utils/benchmark/utils/cpp_jit.py @@ -0,0 +1,154 @@ +"""JIT C++ strings into executables.""" +import atexit +import os +import re +import shutil +import tempfile +import textwrap +import threading +import uuid +from typing import Any, List, Optional + +import torch +from torch.utils.benchmark.utils._stubs import CallgrindModuleType, TimeitModuleType +from torch.utils import cpp_extension + + +LOCK = threading.Lock() +SOURCE_ROOT = os.path.split(os.path.abspath(__file__))[0] + +# We calculate uuid once at import time so that separate processes will have +# separate build roots, but threads will share the same build root. +# `cpp_extension` uses build root as part of the cache key, so per-invocation +# uuid's (e.g. different build root per _compile_template call) would lead to +# a 0% cache hit rate and spurious recompilation. Consider the following: +# ``` +# setup = "auto x = torch::ones({1024, 1024});" +# stmt = "torch::mm(x, x);" +# for num_threads in [1, 2, 4, 8]: +# print(Timer(stmt, setup, num_threads=num_threads, language="c++").blocked_autorange()) +# ```` +# `setup` and `stmt` do not change, so we can reuse the executable from the +# first pass through the loop. +BUILD_ROOT = os.path.join( + tempfile.gettempdir(), + f"benchmark_utils_jit_build_{uuid.uuid4()}".replace("-", "") +) + +# BACK_TESTING_NOTE: +# There are two workflows where this code could be used. One is the obvious +# case where someone simply builds or installs PyTorch and uses Timer. +# The other is that the entire `torch/utils/benchmark` folder from a CURRENT +# PyTorch checkout is copy-pasted into a much OLDER version of the PyTorch +# source code. This is what we refer to here as "back testing". The rationale +# is that we might want to use current tooling to study some aspect of an +# earlier version of PyTorch. (e.g. a regression.) +# +# The problem is that Timer relies on several aspects of core PyTorch, namely +# some binding functions for Valgrind symbols in `torch._C` and the +# `torch.__config__._cxx_flags()` method. If we were to naively copy code +# around this wouldn't work as the symbols of interest aren't present in +# earlier versions of PyTorch. In order to work around this, we must add back +# testing shims. These shims will never activate during normal use, but will +# allow Timer to function outside of the "correct" version of PyTorch by +# emulating functionality that was added later. +# +# These shims are temporary, and as Timer becomes more integrated with +# PyTorch the cost and complexity of such shims will increase. Once back +# testing is no longer required (which is to say we have done enough historic +# analysis and the shims no longer justify their maintenance and code +# complexity costs) back testing paths will be removed. + +if hasattr(torch.__config__, "_cxx_flags"): + CXX_FLAGS = torch.__config__._cxx_flags().strip().split() + if "-g" not in CXX_FLAGS: + CXX_FLAGS.append("-g") +else: + # FIXME: Remove when back testing is no longer required. + CXX_FLAGS = ["-O2", "-fPIC", "-g"] + +EXTRA_INCLUDE_PATHS: List[str] = [os.path.join(SOURCE_ROOT, "valgrind_wrapper")] +CONDA_PREFIX = os.getenv("CONDA_PREFIX") +if CONDA_PREFIX is not None: + # Load will automatically search /usr/include, but not conda include. + EXTRA_INCLUDE_PATHS.append(os.path.join(CONDA_PREFIX, "include")) + + +COMPAT_CALLGRIND_BINDINGS: Optional[CallgrindModuleType] = None +def get_compat_bindings() -> CallgrindModuleType: + with LOCK: + global COMPAT_CALLGRIND_BINDINGS + if COMPAT_CALLGRIND_BINDINGS is None: + COMPAT_CALLGRIND_BINDINGS = cpp_extension.load( + name="callgrind_bindings", + sources=[os.path.join( + SOURCE_ROOT, + "valgrind_wrapper", + "compat_bindings.cpp" + )], + extra_cflags=CXX_FLAGS, + extra_include_paths=EXTRA_INCLUDE_PATHS, + ) + return COMPAT_CALLGRIND_BINDINGS + + +def _compile_template(stmt: str, setup: str, src: str, is_standalone: bool) -> Any: + for before, after, indentation in ( + ("// SETUP_TEMPLATE_LOCATION", setup, 4), + ("// STMT_TEMPLATE_LOCATION", stmt, 8) + ): + # C++ doesn't care about indentation so this code isn't load + # bearing the way it is with Python, but this makes the source + # look nicer if a human has to look at it. + src = re.sub( + before, + textwrap.indent(after, " " * indentation)[indentation:], + src + ) + + # We want to isolate different Timers. However `cpp_extension` will + # cache builds which will significantly reduce the cost of repeated + # invocations. + with LOCK: + if not os.path.exists(BUILD_ROOT): + os.makedirs(BUILD_ROOT) + atexit.register(shutil.rmtree, BUILD_ROOT) + + name = f"timer_cpp_{abs(hash(src))}" + build_dir = os.path.join(BUILD_ROOT, name) + os.makedirs(build_dir, exist_ok=True) + + src_path = os.path.join(build_dir, "timer_src.cpp") + with open(src_path, "wt") as f: + f.write(src) + + # `cpp_extension` has its own locking scheme, so we don't need our lock. + return cpp_extension.load( + name=name, + sources=[src_path], + build_directory=build_dir, + extra_cflags=CXX_FLAGS, + extra_include_paths=EXTRA_INCLUDE_PATHS, + is_python_module=not is_standalone, + is_standalone=is_standalone, + ) + + +def compile_timeit_template(stmt: str, setup: str) -> TimeitModuleType: + template_path: str = os.path.join(SOURCE_ROOT, "timeit_template.cpp") + with open(template_path, "rt") as f: + src: str = f.read() + + module = _compile_template(stmt, setup, src, is_standalone=False) + assert isinstance(module, TimeitModuleType) + return module + + +def compile_callgrind_template(stmt: str, setup: str) -> str: + template_path: str = os.path.join(SOURCE_ROOT, "valgrind_wrapper", "timer_callgrind_template.cpp") + with open(template_path, "rt") as f: + src: str = f.read() + + target = _compile_template(stmt, setup, src, is_standalone=True) + assert isinstance(target, str) + return target diff --git a/torch/utils/_benchmark/utils/fuzzer.py b/torch/utils/benchmark/utils/fuzzer.py similarity index 99% rename from torch/utils/_benchmark/utils/fuzzer.py rename to torch/utils/benchmark/utils/fuzzer.py index a355beb0ed49b..ac813bb42393f 100644 --- a/torch/utils/_benchmark/utils/fuzzer.py +++ b/torch/utils/benchmark/utils/fuzzer.py @@ -161,7 +161,7 @@ def __repr__(self): def dtype_size(dtype): if dtype == torch.bool: return 1 - if dtype.is_floating_point: + if dtype.is_floating_point or dtype.is_complex: return int(torch.finfo(dtype).bits / 8) return int(torch.iinfo(dtype).bits / 8) @@ -259,7 +259,7 @@ def name(self): @staticmethod def default_tensor_constructor(size, dtype, **kwargs): - if dtype.is_floating_point: + if dtype.is_floating_point or dtype.is_complex: return torch.rand(size=size, dtype=dtype, device="cpu") else: return torch.randint(1, 127, size=size, dtype=dtype, device="cpu") diff --git a/torch/utils/benchmark/utils/timeit_template.cpp b/torch/utils/benchmark/utils/timeit_template.cpp new file mode 100644 index 0000000000000..01d62efdb1610 --- /dev/null +++ b/torch/utils/benchmark/utils/timeit_template.cpp @@ -0,0 +1,36 @@ +/* C++ template for Timer.timeit + +This template will be consumed by `cpp_jit.py`, and will replace: + `SETUP_TEMPLATE_LOCATION` + and + `STMT_TEMPLATE_LOCATION` +sections with user provided statements. +*/ +#include + +#include +#include + + +double timeit(int n) { + // Setup + // SETUP_TEMPLATE_LOCATION + + { + // Warmup + // STMT_TEMPLATE_LOCATION + } + + // Main loop + auto start_time = std::chrono::high_resolution_clock::now(); + for (int loop_idx = 0; loop_idx < n; loop_idx++) { + // STMT_TEMPLATE_LOCATION + } + auto end_time = std::chrono::high_resolution_clock::now(); + return std::chrono::duration(end_time - start_time).count(); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("timeit", &timeit); +} diff --git a/torch/utils/benchmark/utils/timer.py b/torch/utils/benchmark/utils/timer.py new file mode 100644 index 0000000000000..f7001b697392c --- /dev/null +++ b/torch/utils/benchmark/utils/timer.py @@ -0,0 +1,438 @@ +"""Timer class based on the timeit.Timer class, but torch aware.""" +import enum +import timeit +import textwrap +from typing import Any, Callable, Dict, List, NoReturn, Optional, Type, Union + +import numpy as np +import torch +from torch.utils.benchmark.utils import common, cpp_jit +from torch.utils.benchmark.utils._stubs import TimerClass, TimeitModuleType +from torch.utils.benchmark.utils.valgrind_wrapper import timer_interface as valgrind_timer_interface + + +__all__ = ["Timer", "timer", "Language"] + + +if torch.has_cuda and torch.cuda.is_available(): + def timer() -> float: + torch.cuda.synchronize() + return timeit.default_timer() +else: + timer = timeit.default_timer + + +class Language(enum.Enum): + PYTHON = 0 + CPP = 1 + + +class CPPTimer: + def __init__( + self, + stmt: str, + setup: str, + timer: Callable[[], float], + globals: Dict[str, Any], + ) -> None: + if timer is not timeit.default_timer: + raise NotImplementedError( + "PyTorch was built with CUDA and a GPU is present; however " + "Timer does not yet support GPU measurements. If your " + "code is CPU only, pass `timer=timeit.default_timer` to the " + "Timer's constructor to indicate this. (Note that this will " + "produce incorrect results if the GPU is in fact used, as " + "Timer will not synchronize CUDA.)" + ) + + if globals: + raise ValueError("C++ timing does not support globals.") + + self._stmt: str = textwrap.dedent(stmt) + self._setup: str = textwrap.dedent(setup) + self._timeit_module: Optional[TimeitModuleType] = None + + def timeit(self, number: int) -> float: + if self._timeit_module is None: + self._timeit_module = cpp_jit.compile_timeit_template( + self._stmt, + self._setup, + ) + + return self._timeit_module.timeit(number) + + +class Timer(object): + """Helper class for measuring execution time of PyTorch statements. + + For a full tutorial on how to use this class, see: + https://pytorch.org/tutorials/recipes/recipes/benchmark.html + + The PyTorch Timer is based on `timeit.Timer` (and in fact uses + `timeit.Timer` internally), but with several key differences: + + 1) Runtime aware: + Timer will perform warmups (important as some elements of PyTorch are + lazily initialized), set threadpool size so that comparisons are + apples-to-apples, and synchronize asynchronous CUDA functions when + necessary. + + 2) Focus on replicates: + When measuring code, and particularly complex kernels / models, + run-to-run variation is a significant confounding factor. It is + expected that all measurements should include replicates to quantify + noise and allow median computation, which is more robust than mean. + To that effect, this class deviates from the `timeit` API by + conceptually merging `timeit.Timer.repeat` and `timeit.Timer.autorange`. + (Exact algorithms are discussed in method docstrings.) The `timeit` + method is replicated for cases where an adaptive strategy is not + desired. + + 3) Optional metadata: + When defining a Timer, one can optionally specify `label`, `sub_label`, + `description`, and `env`. (Defined later) These fields are included in + the representation of result object and by the `Compare` class to group + and display results for comparison. + + 4) Instruction counts + In addition to wall times, Timer can run a statement under Callgrind + and report instructions executed. + + Directly analogous to `timeit.Timer` constructor arguments: + + `stmt`, `setup`, `timer`, `globals` + + PyTorch Timer specific constructor arguments: + + `label`, `sub_label`, `description`, `env`, `num_threads` + + Args: + stmt: Code snippet to be run in a loop and timed. + + setup: Optional setup code. Used to define variables used in `stmt` + + timer: + Callable which returns the current time. If PyTorch was built + without CUDA or there is no GPU present, this defaults to + `timeit.default_timer`; otherwise it will synchronize CUDA before + measuring the time. + + globals: + A dict which defines the global variables when `stmt` is being + executed. This is the other method for providing variables which + `stmt` needs. + + label: + String which summarizes `stmt`. For instance, if `stmt` is + "torch.nn.functional.relu(torch.add(x, 1, out=out))" + one might set label to "ReLU(x + 1)" to improve readability. + + sub_label: + Provide supplemental information to disambiguate measurements + with identical stmt or label. For instance, in our example + above sub_label might be "float" or "int", so that it is easy + to differentiate: + "ReLU(x + 1): (float)" + + "ReLU(x + 1): (int)" + when printing Measurements or summarizing using `Compare`. + + description: + String to distinguish measurements with identical label and + sub_label. The principal use of `description` is to signal to + `Compare` the columns of data. For instance one might set it + based on the input size to create a table of the form: :: + + | n=1 | n=4 | ... + ------------- ... + ReLU(x + 1): (float) | ... | ... | ... + ReLU(x + 1): (int) | ... | ... | ... + + + using `Compare`. It is also included when printing a Measurement. + + env: + This tag indicates that otherwise identical tasks were run in + different environments, and are therefore not equivilent, for + instance when A/B testing a change to a kernel. `Compare` will + treat Measurements with different `env` specification as distinct + when merging replicate runs. + + num_threads: + The size of the PyTorch threadpool when executing `stmt`. Single + threaded performace is important as both a key inference workload + and a good indicator of intrinsic algorithmic efficiency, so the + default is set to one. This is in contrast to the default PyTorch + threadpool size which tries to utilize all cores. + """ + + _timer_cls: Type[TimerClass] = timeit.Timer + + def __init__( + self, + stmt: str = "pass", + setup: str = "pass", + timer: Callable[[], float] = timer, + globals: Optional[Dict[str, Any]] = None, + label: Optional[str] = None, + sub_label: Optional[str] = None, + description: Optional[str] = None, + env: Optional[str] = None, + num_threads: int = 1, + language: Union[Language, str] = Language.PYTHON, + ): + if not isinstance(stmt, str): + raise ValueError("Currently only a `str` stmt is supported.") + + # We copy `globals` to prevent mutations from leaking. + # (For instance, `eval` adds the `__builtins__` key) + self._globals = dict(globals or {}) + if language in (Language.PYTHON, "py", "python"): + # Include `torch` if not specified as a convenience feature. + self._globals.setdefault("torch", torch) + self._language: Language = Language.PYTHON + + elif language in (Language.CPP, "cpp", "c++"): + assert self._timer_cls is timeit.Timer, "_timer_cls has already been swapped." + self._timer_cls = CPPTimer + setup = ("" if setup == "pass" else setup) + self._language = Language.CPP + + else: + raise ValueError(f"Invalid language `{language}`.") + + # Convenience adjustment so that multi-line code snippets defined in + # functions do not IndentationError (Python) or look odd (C++). The + # leading newline removal is for the initial newline that appears when + # defining block strings. For instance: + # textwrap.dedent(""" + # print("This is a stmt") + # """) + # produces '\nprint("This is a stmt")\n'. + # + # Stripping this down to 'print("This is a stmt")' doesn't change + # what gets executed, but it makes __repr__'s nicer. + stmt = textwrap.dedent(stmt) + stmt = (stmt[1:] if stmt and stmt[0] == "\n" else stmt).rstrip() + setup = textwrap.dedent(setup) + setup = (setup[1:] if setup and setup[0] == "\n" else setup).rstrip() + + self._timer = self._timer_cls( + stmt=stmt, + setup=setup, + timer=timer, + globals=valgrind_timer_interface.CopyIfCallgrind.unwrap_all(self._globals), + ) + self._task_spec = common.TaskSpec( + stmt=stmt, + setup=setup, + label=label, + sub_label=sub_label, + description=description, + env=env, + num_threads=num_threads, + ) + + def timeit(self, number: int = 1000000) -> common.Measurement: + """Mirrors the semantics of timeit.Timer.timeit(). + + Execute the main statement (`stmt`) `number` times. + https://docs.python.org/3/library/timeit.html#timeit.Timer.timeit + """ + with common.set_torch_threads(self._task_spec.num_threads): + # Warmup + self._timer.timeit(number=max(int(number // 100), 1)) + + return common.Measurement( + number_per_run=number, + raw_times=[self._timer.timeit(number=number)], + task_spec=self._task_spec + ) + + def repeat(self, repeat: int = -1, number: int = -1) -> None: + raise NotImplementedError("See `Timer.blocked_autorange.`") + + def autorange(self, callback: Optional[Callable[[int, float], NoReturn]] = None) -> None: + raise NotImplementedError("See `Timer.blocked_autorange.`") + + def _threaded_measurement_loop( + self, + number: int, + time_hook: Callable[[], float], + stop_hook: Callable[[List[float]], bool], + min_run_time: float, + max_run_time: Optional[float] = None, + callback: Optional[Callable[[int, float], NoReturn]] = None + ) -> List[float]: + total_time = 0.0 + can_stop = False + times: List[float] = [] + with common.set_torch_threads(self._task_spec.num_threads): + while (total_time < min_run_time) or (not can_stop): + time_spent = time_hook() + times.append(time_spent) + total_time += time_spent + if callback: + callback(number, time_spent) + can_stop = stop_hook(times) + if max_run_time and total_time > max_run_time: + break + return times + + def _estimate_block_size(self, min_run_time: float) -> int: + with common.set_torch_threads(self._task_spec.num_threads): + # Estimate the block size needed for measurement to be negligible + # compared to the inner loop. This also serves as a warmup. + overhead = np.median([self._timer.timeit(0) for _ in range(5)]) + number = 1 + while True: + time_taken = self._timer.timeit(number) + relative_overhead = overhead / time_taken + if relative_overhead <= 1e-4 and time_taken >= min_run_time / 1000: + break + if time_taken > min_run_time: + break + number *= 10 + return number + + def adaptive_autorange( + self, + threshold: float = 0.1, + *, + min_run_time: float = 0.01, + max_run_time: float = 10.0, + callback: Optional[Callable[[int, float], NoReturn]] = None, + ) -> common.Measurement: + number = self._estimate_block_size(min_run_time=0.05) + + def time_hook() -> float: + return self._timer.timeit(number) + + def stop_hook(times: List[float]) -> bool: + if len(times) > 3: + return common.Measurement( + number_per_run=number, + raw_times=times, + task_spec=self._task_spec + ).meets_confidence(threshold=threshold) + return False + times = self._threaded_measurement_loop( + number, time_hook, stop_hook, min_run_time, max_run_time, callback=callback) + + return common.Measurement( + number_per_run=number, + raw_times=times, + task_spec=self._task_spec + ) + + def blocked_autorange( + self, + callback: Optional[Callable[[int, float], NoReturn]] = None, + min_run_time: float = 0.2, + ) -> common.Measurement: + """Measure many replicates while keeping timer overhead to a minimum. + + At a high level, blocked_autorange executes the following pseudo-code:: + + `setup` + + total_time = 0 + while total_time < min_run_time + start = timer() + for _ in range(block_size): + `stmt` + total_time += (timer() - start) + + Note the variable `block_size` in the inner loop. The choice of block + size is important to measurement quality, and must balance two + competing objectives: + + 1) A small block size results in more replicates and generally + better statistics. + + 2) A large block size better amortizes the cost of `timer` + invocation, and results in a less biased measurement. This is + important because CUDA syncronization time is non-trivial + (order single to low double digit microseconds) and would + otherwise bias the measurement. + + blocked_autorange sets block_size by running a warmup period, + increasing block size until timer overhead is less than 0.1% of + the overall computation. This value is then used for the main + measurement loop. + + Returns: + A `Measurement` object that contains measured runtimes and + repetition counts, and can be used to compute statistics. + (mean, median, etc.) + """ + number = self._estimate_block_size(min_run_time) + + def time_hook() -> float: + return self._timer.timeit(number) + + def stop_hook(times: List[float]) -> bool: + return True + + times = self._threaded_measurement_loop( + number, time_hook, stop_hook, + min_run_time=min_run_time, + callback=callback) + + return common.Measurement( + number_per_run=number, + raw_times=times, + task_spec=self._task_spec + ) + + def collect_callgrind( + self, + number: int = 100, + collect_baseline: bool = True + ) -> valgrind_timer_interface.CallgrindStats: + """Collect instruction counts using Callgrind. + + Unlike wall times, instruction counts are deterministic + (modulo non-determinism in the program itself and small amounts of + jitter from the Python interpreter.) This makes them ideal for detailed + performance analysis. This method runs `stmt` in a separate process + so that Valgrind can instrument the program. Performance is severely + degraded due to the instrumentation, howevever this is ameliorated by + the fact that a small number of iterations is generally sufficient to + obtain good measurements. + + In order to to use this method `valgrind`, `callgrind_control`, and + `callgrind_annotate` must be installed. + + Because there is a process boundary between the caller (this process) + and the `stmt` execution, `globals` cannot contain arbitrary in-memory + data structures. (Unlike timing methods) Instead, globals are + restricted to builtins, `nn.Modules`'s, and TorchScripted functions/modules + to reduce the surprise factor from serialization and subsequent + deserialization. The `GlobalsBridge` class provides more detail on this + subject. Take particular care with nn.Modules: they rely on pickle and + you may need to add an import to `setup` for them to transfer properly. + + By default, a profile for an empty statement will be collected and + cached to indicate how many instructions are from the Python loop which + drives `stmt`. + + Returns: + A `CallgrindStats` object which provides instruction counts and + some basic facilities for analyzing and manipulating results. + """ + if not isinstance(self._task_spec.stmt, str): + raise ValueError("`collect_callgrind` currently only supports string `stmt`") + + # Check that the statement is valid. It doesn't guarantee success, but it's much + # simpler and quicker to raise an exception for a faulty `stmt` or `setup` in + # the parent process rather than the valgrind subprocess. + self._timer.timeit(1) + is_python = (self._language == Language.PYTHON) + assert is_python or not self._globals + return valgrind_timer_interface.wrapper_singleton().collect_callgrind( + task_spec=self._task_spec, + globals=self._globals, + number=number, + collect_baseline=collect_baseline and is_python, + is_python=is_python) diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/__init__.py b/torch/utils/benchmark/utils/valgrind_wrapper/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp b/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp new file mode 100644 index 0000000000000..b52626fe76fd1 --- /dev/null +++ b/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp @@ -0,0 +1,25 @@ +/* Used to collect profiles of old versions of PyTorch. */ +#include +#include + + +bool _valgrind_supported_platform() { + #if defined(NVALGRIND) + return false; + #else + return true; + #endif +} + +void _valgrind_toggle() { + #if defined(NVALGRIND) + TORCH_CHECK(false, "Valgrind is not supported."); + #else + CALLGRIND_TOGGLE_COLLECT; + #endif +} + +PYBIND11_MODULE(callgrind_bindings, m) { + m.def("_valgrind_supported_platform", &_valgrind_supported_platform); + m.def("_valgrind_toggle", &_valgrind_toggle); +} diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp b/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp new file mode 100644 index 0000000000000..a64484f709249 --- /dev/null +++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp @@ -0,0 +1,47 @@ +/* C++ template for Timer.collect_callgrind + +This template will be consumed by `cpp_jit.py`, and will replace: + `SETUP_TEMPLATE_LOCATION` + and + `STMT_TEMPLATE_LOCATION` +sections with user provided statements. +*/ + +#include + +#include +#include + +#if defined(NVALGRIND) +static_assert(false); +#endif + +int main(int argc, char* argv[]) { + // This file should only be called inside of `Timer`, so we can adopt a + // very simple and rigid argument parsing scheme. + TORCH_CHECK(argc == 7); + TORCH_CHECK(std::string(argv[1]) == "--number"); + auto number = std::stoi(argv[2]); + + TORCH_CHECK(std::string(argv[3]) == "--number_warmup"); + auto number_warmup = std::stoi(argv[4]); + + TORCH_CHECK(std::string(argv[5]) == "--number_threads"); + auto number_threads = std::stoi(argv[6]); + torch::set_num_threads(number_threads); + + // Setup + // SETUP_TEMPLATE_LOCATION + + // Warmup + for (int i = 0; i < number_warmup; i++) { + // STMT_TEMPLATE_LOCATION + } + + // Main loop + CALLGRIND_TOGGLE_COLLECT; + for (int i = 0; i < number; i++) { + // STMT_TEMPLATE_LOCATION + } + CALLGRIND_TOGGLE_COLLECT; +} diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py new file mode 100644 index 0000000000000..64c5d9793e80f --- /dev/null +++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py @@ -0,0 +1,791 @@ +"""Intermediate layer between `Timer` and `valgrind`.""" +import collections +import enum +import dataclasses +import itertools as it +import os +import pickle +import re +import shutil +import subprocess +import sys +import tempfile +import textwrap +from typing import ( + cast, Any, Callable, DefaultDict, Dict, Generator, List, NamedTuple, + Optional, Tuple, Union, TYPE_CHECKING) + +import torch +from torch.utils.benchmark.utils import common, cpp_jit +from torch.utils.benchmark.utils._stubs import CallgrindModuleType + + +__all__ = ["FunctionCount", "FunctionCounts", "CallgrindStats", "CopyIfCallgrind"] + + +if TYPE_CHECKING: + CompletedProcessType = subprocess.CompletedProcess[str] +else: + CompletedProcessType = subprocess.CompletedProcess + + +FunctionCount = NamedTuple("FunctionCount", [("count", int), ("function", str)]) + + +@dataclasses.dataclass(repr=False, eq=False, frozen=True) +class FunctionCounts(object): + _data: Tuple[FunctionCount, ...] + inclusive: bool + + # For normal use, torch._tensor_str.PRINT_OPTS.linewidth determines + # the print settings. This is simply to allow hermetic unit tests. + _linewidth: Optional[int] = None + + def __iter__(self) -> Generator[FunctionCount, None, None]: + for i in self._data: + yield i + + def __len__(self) -> int: + return len(self._data) + + def __getitem__(self, item: Any) -> "Union[FunctionCount, FunctionCounts]": + data: Union[FunctionCount, Tuple[FunctionCount, ...]] = self._data[item] + return ( + FunctionCounts(cast(Tuple[FunctionCount, ...], data), self.inclusive) + if isinstance(data, tuple) else data + ) + + def __repr__(self) -> str: + count_len = 0 + for c, _ in self: + # Account for sign in string length. + count_len = max(count_len, len(str(c)) + int(c < 0)) + + lines = [] + linewidth = self._linewidth or torch._tensor_str.PRINT_OPTS.linewidth + fn_str_len = max(linewidth - count_len - 4, 40) + for c, fn in self: + if len(fn) > fn_str_len: + left_len = int((fn_str_len - 5) // 2) + fn = fn[:left_len] + " ... " + fn[-(fn_str_len - left_len - 5):] + lines.append(f" {c:>{count_len}} {fn}") + + if len(lines) > 18: + lines = lines[:9] + ["...".rjust(count_len + 2)] + lines[-9:] + + if not self.inclusive: + lines.extend(["", f"Total: {self.sum()}"]) + + return "\n".join([super().__repr__()] + lines) + + def __add__( + self, + other, # type: FunctionCounts + ) -> "FunctionCounts": + return self._merge(other, lambda c: c) + + def __sub__( + self, + other, # type: FunctionCounts + ) -> "FunctionCounts": + return self._merge(other, lambda c: -c) + + def transform(self, map_fn: Callable[[str], str]) -> "FunctionCounts": + counts: DefaultDict[str, int] = collections.defaultdict(int) + for c, fn in self._data: + counts[map_fn(fn)] += c + + return self._from_dict(counts, self.inclusive) + + def filter(self, filter_fn: Callable[[str], bool]) -> "FunctionCounts": + return FunctionCounts(tuple(i for i in self if filter_fn(i.function)), self.inclusive) + + def sum(self) -> int: + return sum(c for c, _ in self) + + def denoise(self) -> "FunctionCounts": + """Remove known noisy instructions. + + Several instructions in the CPython interpreter are rather noisy. These + instructions involve unicode to dictionary lookups which Python uses to + map variable names. FunctionCounts is generally a content agnostic + container, however this is sufficiently important for obtaining + reliable results to warrant an exception.""" + return self.filter(lambda fn: "dictobject.c:lookdict_unicode" not in fn) + + def _merge( + self, + second, # type: FunctionCounts + merge_fn: Callable[[int], int] + ) -> "FunctionCounts": + assert self.inclusive == second.inclusive, "Cannot merge inclusive and exclusive counts." + counts: DefaultDict[str, int] = collections.defaultdict(int) + for c, fn in self: + counts[fn] += c + + for c, fn in second: + counts[fn] += merge_fn(c) + + return self._from_dict(counts, self.inclusive) + + @staticmethod + def _from_dict(counts: Dict[str, int], inclusive: bool) -> "FunctionCounts": + flat_counts = (FunctionCount(c, fn) for fn, c in counts.items() if c) + return FunctionCounts(tuple(sorted(flat_counts, reverse=True)), inclusive) + + +@dataclasses.dataclass(repr=False, eq=False, frozen=True) +class CallgrindStats(object): + task_spec: common.TaskSpec + number_per_run: int + built_with_debug_symbols: bool + baseline_inclusive_stats: FunctionCounts + baseline_exclusive_stats: FunctionCounts + stmt_inclusive_stats: FunctionCounts + stmt_exclusive_stats: FunctionCounts + + def __repr__(self) -> str: + newline = "\n" # `\` cannot appear in fstring code section. + base_stats = self.baseline_exclusive_stats + output = f""" +{super().__repr__()} +{self.task_spec.summarize()} + {'':>25}All{'':>10}Noisy symbols removed + Instructions: {self.counts(denoise=False):>12}{'':>15}{self.counts(denoise=True):>12} + Baseline: {base_stats.sum():>12}{'':>15}{base_stats.denoise().sum():>12} +{self.number_per_run} runs per measurement, {self.task_spec.num_threads} thread{'s' if self.task_spec.num_threads > 1 else ''} +""".strip() + if not self.built_with_debug_symbols: + output += textwrap.dedent(""" + Warning: PyTorch was not built with debug symbols. + Source information may be limited. Rebuild with + REL_WITH_DEB_INFO=1 for more detailed results.""") + return output + + def stats(self, inclusive: bool = False) -> FunctionCounts: + """Returns stats as a tuple of (count, function) + + `inclusive` matches the semantics of callgrind. If True, the counts + include instructions executed by children. `inclusive=True` is useful + for identifying hot spots in code; `inclusive=False` is useful for + identifying reducing noise when diffing counts from two different + runs. (See CallgrindStats.delta(...) for more details) + """ + if inclusive: + return self.stmt_inclusive_stats - self.baseline_inclusive_stats + return self.stmt_exclusive_stats - self.baseline_exclusive_stats + + def counts(self, *, denoise: bool = False) -> int: + """Returns the total number of instructions executed. + + See `FunctionCounts.denoise()` for an explation of the `denoise` arg. + """ + stats = self.stmt_exclusive_stats + return (stats.denoise() if denoise else stats).sum() + + # FIXME: Once 3.7 is the minimum version, type annotate `other` per PEP 563 + def delta( + self, + other, # type: CallgrindStats + inclusive: bool = False, + subtract_baselines: bool = True + ) -> FunctionCounts: + """Diff two sets of counts. + + One common reason to collect instruction counts is to determine the + the effect that a particular change will have on the number of instructions + needed to perform some unit of work. If a change increases that number, the + next logical question is "why". This generally involves looking at what part + if the code increased in instruction count. This function automates that + process so that one can easily diff counts on both an inclusive and + exclusive basis. The `subtract_baselines` argument allows one to disable + baseline correction, though in most cases it shouldn't matter as the + baselines are expected to more or less cancel out. + """ + if subtract_baselines: + return self.stats(inclusive=inclusive) - other.stats(inclusive=inclusive) + elif inclusive: + return self.stmt_inclusive_stats - other.stmt_inclusive_stats + return self.stmt_exclusive_stats - other.stmt_exclusive_stats + + def as_standardized(self) -> "CallgrindStats": + """Strip library names and some prefixes from function strings. + + When comparing two different sets of instruction counts, on stumbling + block can be path prefixes. Callgrind includes the full filepath + when reporting a function (as it should). However, this can cause + issues when diffing profiles. If a key component such as Python + or PyTorch was built in separate locations in the two profiles, which + can result in something resembling: + 23234231 /tmp/first_build_dir/thing.c:foo(...) + 9823794 /tmp/first_build_dir/thing.c:bar(...) + ... + 53453 .../aten/src/Aten/...:function_that_actually_changed(...) + ... + -9823794 /tmp/second_build_dir/thing.c:bar(...) + -23234231 /tmp/second_build_dir/thing.c:foo(...) + + Stripping prefixes can ameliorate this issue by regularizing the + strings and causing better cancellation of equivilent call sites + when diffing. + """ + def strip(stats: FunctionCounts) -> FunctionCounts: + transforms = ( + # PyTorch may have been built in different locations. + (r"^.+build/\.\./", "build/../"), + (r"^.+/" + re.escape("build/aten/"), "build/aten/"), + + # "Python" and "Objects" come from CPython. + (r"^.+/" + re.escape("Python/"), "Python/"), + (r"^.+/" + re.escape("Objects/"), "Objects/"), + + # Strip library name. e.g. `libtorch.so` + (r"\s\[.+\]$", ""), + ) + + for before, after in transforms: + stats = stats.transform(lambda fn: re.sub(before, after, fn)) + + return stats + + return CallgrindStats( + task_spec=self.task_spec, + number_per_run=self.number_per_run, + built_with_debug_symbols=self.built_with_debug_symbols, + baseline_inclusive_stats=strip(self.baseline_inclusive_stats), + baseline_exclusive_stats=strip(self.baseline_exclusive_stats), + stmt_inclusive_stats=strip(self.stmt_inclusive_stats), + stmt_exclusive_stats=strip(self.stmt_exclusive_stats), + ) + + +class Serialization(enum.Enum): + PICKLE = 0 + TORCH = 1 + TORCH_JIT = 2 + + +_GLOBALS_ALLOWED_TYPES: Dict[Serialization, Tuple[Any, ...]] = { + Serialization.PICKLE: (str, bytes, bool, int, float, complex), + Serialization.TORCH_JIT: (torch.jit.ScriptFunction, torch.jit.ScriptModule), + Serialization.TORCH: (torch.nn.Module,), +} + + +class CopyIfCallgrind: + """Signal that a global may be replaced with a deserialized copy. + + See `GlobalsBridge` for why this matters. + """ + def __init__(self, value: Any, *, setup: Optional[str] = None): + for method, supported_types in _GLOBALS_ALLOWED_TYPES.items(): + if any(isinstance(value, t) for t in supported_types): + self._value: Any = value + self._setup: Optional[str] = setup + self._serialization: Serialization = method + break + else: + supported_str = "\n".join([ + getattr(t, "__name__", repr(t)) + for t in it.chain(_GLOBALS_ALLOWED_TYPES.values())]) + + raise ValueError( + f"Unsupported type: {type(value)}\n" + f"`collect_callgrind` restricts globals to the following types:\n" + f"{textwrap.indent(supported_str, ' ')}" + ) + + @property + def value(self) -> Any: + return self._value + + @property + def setup(self) -> Optional[str]: + return self._setup + + @property + def serialization(self) -> Serialization: + return self._serialization + + @staticmethod + def unwrap_all(globals: Dict[str, Any]) -> Dict[str, Any]: + return { + k: (v.value if isinstance(v, CopyIfCallgrind) else v) + for k, v in globals.items() + } + + +class GlobalsBridge: + """Handle the transfer of (certain) globals when collecting Callgrind statistics. + + Key takeaway: Any globals passed must be wrapped in `CopyIfCallgrind` to + work with `Timer.collect_callgrind`. + + Consider the following code snippet: + ``` + import pickle + import timeit + + class Counter: + value = 0 + + def __call__(self): + self.value += 1 + + counter = Counter() + timeit.Timer("counter()", globals={"counter": counter}).timeit(10) + print(counter.value) # 10 + + timeit.Timer( + "counter()", + globals={"counter": pickle.loads(pickle.dumps(counter))} + ).timeit(20) + print(counter.value) # Still 10 + ``` + + In the first case, `stmt` is executed using the objects in `globals`; + however, the addition of serialization and deserialization changes the + semantics and may meaningfully change behavior. + + This is a practical consideration when collecting Callgrind statistics. + Unlike `exec` based execution (which `timeit` uses under the hood) which + can share in-memory data structures with the caller, Callgrind collection + requires an entirely new process in order to run under Valgrind. This means + that any data structures used for statement execution will have to be + serialized and deserialized in the subprocess. + + In order to avoid surprising semantics from (user invisible) process + boundaries, what can be passed through `globals` is severely restricted + for `Timer.collect_callgrind`. It is expected that most setup should be + achievable (albeit perhaps less ergonomically) by passing a `setup` + string. + + There are, however, exceptions. One such class are TorchScripted functions. + Because they require a concrete file with source code it is not possible + to define them using a `setup` string. Another group are torch.nn.Modules, + whose construction can be complex and prohibitively cumbersome to coerce + into a `setup` string. Finally, most builtin types are sufficiently well + behaved and sufficiently common to warrant allowing as well. (e.g. + `globals={"n": 1}` is very convenient.) + + Fortunately, all have well defined serialization semantics. This class + is responsible for enabling the Valgrind subprocess to use elements in + `globals` so long as they are an allowed type. + + Caveats: + The user is required to acknowledge this serialization by wrapping + elements in `globals` with `CopyIfCallgrind`. + + While ScriptFunction and ScriptModule are expected to save and load + quite robustly, it is up to the user to ensure that an nn.Module can + un-pickle successfully. + + `torch.Tensor` and `np.ndarray` are deliberately excluded. The + serialization/deserialization process perturbs the representation of a + tensor in ways that could result in incorrect measurements. For example, + if a tensor lives in pinned CPU memory, this fact would not be preserved + by a dump, and that will in turn change the performance of certain CUDA + operations. + """ + + def __init__(self, globals: Dict[str, Any], data_dir: str) -> None: + self._globals: Dict[str, CopyIfCallgrind] = {} + self._data_dir = data_dir + if not os.path.exists(data_dir): + os.mkdir(data_dir) + + if globals.get("torch", torch) is not torch: + raise ValueError("`collect_callgrind` does not support mocking out `torch`.") + + for name, value in globals.items(): + if name in ("torch", "__builtins__"): + # Torch will be imported by the collection script, and + # __builtins__ is added by Timer. + continue + + if not isinstance(value, CopyIfCallgrind): + raise ValueError( + "`collect_callgrind` requires that globals be wrapped in " + "`CopyIfCallgrind` so that serialization is explicit." + ) + + self._globals[name] = value + + def construct(self) -> str: + load_lines = [] + for name, wrapped_value in self._globals.items(): + if wrapped_value.setup is not None: + load_lines.append(textwrap.dedent(wrapped_value.setup)) + + if wrapped_value.serialization == Serialization.PICKLE: + path = os.path.join(self._data_dir, f"{name}.pkl") + load_lines.append( + f"with open({repr(path)}, 'rb') as f:\n {name} = pickle.load(f)") + with open(path, "wb") as f: + pickle.dump(wrapped_value.value, f) + + elif wrapped_value.serialization == Serialization.TORCH: + path = os.path.join(self._data_dir, f"{name}.pt") + load_lines.append(f"{name} = torch.load({repr(path)})") + torch.save(wrapped_value.value, path) + + elif wrapped_value.serialization == Serialization.TORCH_JIT: + path = os.path.join(self._data_dir, f"{name}.pt") + load_lines.append(f"{name} = torch.jit.load({repr(path)})") + with open(path, "wb") as f: + torch.jit.save(wrapped_value.value, f) + + else: + raise NotImplementedError( + f"Unknown serialization method: {wrapped_value.serialization}") + + return "\n".join(load_lines) + + +class _ValgrindWrapper(object): + def __init__(self) -> None: + self._bindings_module: Optional[CallgrindModuleType] = None + if hasattr(torch._C, "_valgrind_supported_platform"): + self._supported_platform: bool = torch._C._valgrind_supported_platform() + + else: + print("Callgrind bindings are not present in `torch._C`. JIT-ing bindings.") + self._bindings_module = cpp_jit.get_compat_bindings() + self._supported_platform = self._bindings_module._valgrind_supported_platform() + + self._commands_available: Dict[str, bool] = {} + if self._supported_platform: + # Only bother checking on supported platforms. + for cmd in ("valgrind", "callgrind_control", "callgrind_annotate"): + self._commands_available[cmd] = not subprocess.run( + ["which", cmd], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ).returncode + + self._build_type: Optional[str] = None + build_search = re.search("BUILD_TYPE=(.+),", torch.__config__.show()) + if build_search is not None: + self._build_type = build_search.groups()[0].split(",")[0] + + self._baseline_cache: Dict[Tuple[int, int], Tuple[FunctionCounts, FunctionCounts]] = {} + + def _validate(self) -> None: + if not self._supported_platform: + raise OSError("Valgrind is not supported on this platform.") + + missing_cmds = [cmd for cmd, available in self._commands_available.items() if not available] + if missing_cmds: + raise OSError("Missing: " + ", ".join(missing_cmds)) + + def collect_callgrind( + self, + task_spec: common.TaskSpec, + globals: Dict[str, Any], + number: int, + collect_baseline: bool, + is_python: bool, + ) -> CallgrindStats: + """Collect stats, and attach a reference run which can be used to filter interpreter overhead.""" + self._validate() + assert is_python or not collect_baseline + + baseline_inclusive_stats = FunctionCounts((), inclusive=True) + baseline_exclusive_stats = FunctionCounts((), inclusive=False) + if collect_baseline: + cache_key = (number, task_spec.num_threads) + if cache_key not in self._baseline_cache: + self._baseline_cache[cache_key] = self._invoke( + common.TaskSpec( + stmt="pass", + setup="pass", + num_threads=task_spec.num_threads, + ), + globals={}, + number=number, + is_python=True, + ) + baseline_inclusive_stats, baseline_exclusive_stats = \ + self._baseline_cache[cache_key] + + stmt_inclusive_stats, stmt_exclusive_stats = self._invoke( + task_spec, globals, number, is_python) + return CallgrindStats( + task_spec=task_spec, + number_per_run=number, + built_with_debug_symbols=self._build_type == "RelWithDebInfo", + baseline_inclusive_stats=baseline_inclusive_stats, + baseline_exclusive_stats=baseline_exclusive_stats, + stmt_inclusive_stats=stmt_inclusive_stats, + stmt_exclusive_stats=stmt_exclusive_stats, + ) + + def _invoke( + self, + task_spec: common.TaskSpec, + globals: Dict[str, Any], + number: int, + is_python: bool, + ) -> Tuple[FunctionCounts, FunctionCounts]: + """Core invocation method for Callgrind collection. + + Valgrind operates by effectively replacing the CPU with an emulated + version which allows it to instrument any code at the cost of severe + performance degradation. This has the practical effect that in order + to collect Callgrind statistics, a new process has to be created + running under `valgrind`. The steps for this process are: + + 1) Create a scratch directory. + 2) Codegen a run script. (_ValgrindWrapper._construct_script) + Inside the run script: + * Validate that Python and torch match the parent process + * Validate that it is indeed running under valgrind + * Execute `setup` and warm up `stmt` + * Begin collecting stats + * Run the `stmt` loop + * Stop collecting stats + 3) Parse the run results. + 4) Cleanup the scratch directory. + """ + working_dir = tempfile.mkdtemp() + data_dir = os.path.join(working_dir, "data") + script_file = os.path.join(working_dir, "timer_callgrind.py") + callgrind_out = os.path.join(working_dir, "callgrind.out") + error_log = os.path.join(working_dir, "error.txt") + stat_log = os.path.join(working_dir, "callgrind_stat.txt") + stdout_stderr_log = os.path.join(working_dir, "stdout_stderr.log") + + def run(args: List[str], **kwargs: Any) -> Tuple[CompletedProcessType, str]: + # https://thraxil.org/users/anders/posts/2008/03/13/Subprocess-Hanging-PIPE-is-your-enemy/ + f_stdout_stderr = open(stdout_stderr_log, "wb") + try: + invocation = subprocess.run( + args, + stdout=f_stdout_stderr, + stderr=subprocess.STDOUT, + **kwargs, + ) + with open(stdout_stderr_log, "rt") as f: + return invocation, f.read() + finally: + f_stdout_stderr.close() + + try: + if is_python: + if self._bindings_module is not None: + shutil.copy( + self._bindings_module.__file__, + os.path.join(working_dir, os.path.split(self._bindings_module.__file__)[1]) + ) + + script_file = os.path.join(working_dir, "timer_callgrind.py") + with open(script_file, "wt") as f: + f.write(self._construct_script( + task_spec, + globals=GlobalsBridge(globals, data_dir), + number=number, + error_log=error_log, + stat_log=stat_log, + bindings=self._bindings_module)) + run_loop_cmd = ["python", script_file] + else: + run_loop_exec = cpp_jit.compile_callgrind_template( + task_spec.stmt, + task_spec.setup, + ) + run_loop_cmd = [ + run_loop_exec, + "--number", str(number), + "--number_warmup", str(min(number, 10)), + "--number_threads", str(task_spec.num_threads), + ] + + valgrind_invocation, valgrind_invocation_output = run([ + "valgrind", + "--tool=callgrind", + f"--callgrind-out-file={callgrind_out}", + "--dump-line=yes", + "--dump-instr=yes", + "--instr-atstart=yes", + "--collect-atstart=no", + ] + run_loop_cmd) + + if valgrind_invocation.returncode: + error_report = "" + if os.path.exists(error_log): + with open(error_log, "rt") as f: + error_report = f.read() + if not error_report: + error_report = "Unknown error.\n" + valgrind_invocation_output + + raise OSError(f"Failed to collect callgrind profile:\n{error_report}") + + def parse_output(inclusive: bool) -> FunctionCounts: + annotate_invocation, annotate_invocation_output = run([ + "callgrind_annotate", + f"--inclusive={'yes' if inclusive else 'no'}", + callgrind_out + ], check=True) + + begin_collecting = False + fn_counts = [] + for l in annotate_invocation_output.splitlines(keepends=False): + if not begin_collecting and re.match(r"Ir\s+file:function", l): + begin_collecting = True + continue + + count_match = re.match(r"^\s*([0-9,]+)\s+(.+:.+)$", l) + if count_match: + ir_str, file_function = count_match.groups() + ir = int(ir_str.replace(",", "")) + fn_counts.append(FunctionCount(ir, file_function)) + continue + + if begin_collecting and re.match(r"-+", l): + continue + + begin_collecting = False + + return FunctionCounts(tuple(sorted(fn_counts, reverse=True)), inclusive=inclusive) + return parse_output(inclusive=True), parse_output(inclusive=False) + finally: + shutil.rmtree(working_dir) + + @staticmethod + def _construct_script( + task_spec: common.TaskSpec, + globals: GlobalsBridge, + number: int, + error_log: str, + stat_log: str, + bindings: Optional[CallgrindModuleType], + ) -> str: + # The naive template looks something like: + # "for _ in range({number}): {stmt}" + # However a loop in Python is surprisingly expensive, and significantly + # increases the number of background Python instructions. So instead we + # partially unroll the loops, with a block size of 100 chosen to keep + # the instruction overhead from `range` low while also not ballooning + # the size of the generated file. + block_size = 100 + loop_count = number // block_size + remainder = number - block_size * loop_count + blocked_stmt = "" + if loop_count: + unrolled_stmts = textwrap.indent("\n".join([task_spec.stmt] * block_size), " " * 4) + blocked_stmt += f"for _ in range({loop_count}):\n{unrolled_stmts}\n" + if remainder: + blocked_stmt += "\n".join([task_spec.stmt] * remainder) + + return textwrap.dedent(r""" + import gc + import os + import pickle + import subprocess + import sys + import time + + # Mitigate https://github.com/pytorch/pytorch/issues/37377 + # which can sometimes cause the subprocess call to fail. + import numpy as np + + import torch + torch.set_num_threads({num_threads}) + + {bindings_import} + + PID = os.getpid() + + def log_failure(msg): + with open({error_log_repr}, "wt") as f: + f.write(msg) + sys.exit(1) + + def check_result(completed_process): + if completed_process.returncode: + log_failure(f"Command failed: {{' '.join(completed_process.args)}}") + return completed_process + + # ============================================================================= + # == Check that subprocess matches parent ===================================== + # ============================================================================= + if sys.executable != "{parent_interpreter}": + log_failure( + "Interpreter mismatch:\n" + f" {{sys.executable}}\n vs.\n {parent_interpreter}" + ) + + if torch.__file__ != "{torch_file}": + log_failure( + "PyTorch does not match expected file:\n" + f" {{torch.__file__}}\n vs.\n {torch_file}" + ) + + # ============================================================================= + # == User specified setup ===================================================== + # ============================================================================= + # Load serialized globals + {load_globals} + + # User setup str + {setup} + + for _ in range({warmup_number}): + {indented_stmt} + + # ============================================================================= + # == Callgrind management ===================================================== + # ============================================================================= + with open("{stat_log}", "wb") as stat_file: + # If many instances of callgrind are running at once, the output of + # `callgrind_control` may exceed 16kb which would cause `subprocess.PIPE` + # to deadlock. So instead we use a file. + callgrind_stat = check_result(subprocess.run( + ["callgrind_control", "--stat"], + stdout=stat_file, + stderr=subprocess.STDOUT, + )) + + with open("{stat_log}", "rt") as stat_file: + stat_lines = stat_file.read().splitlines() + + if f"PID {{PID}}: python {{__file__}}" not in stat_lines: + log_failure("Process does not appear to be running callgrind.") + + gc.collect() + time.sleep(0.01) + + # ============================================================================= + # == User code block ========================================================== + # ============================================================================= + callgrind_bindings._valgrind_toggle() + {blocked_stmt} + + # Sleep is to allow the interpreter to catch up before we stop collecting in + # order to reduce jitter. + time.sleep(0.01) + callgrind_bindings._valgrind_toggle() + """).strip().format( + indented_stmt=textwrap.indent(task_spec.stmt, " " * 4), + blocked_stmt=blocked_stmt, + number=number, + load_globals=globals.construct(), + setup=task_spec.setup, + warmup_number=min(number, 10), + num_threads=task_spec.num_threads, + error_log_repr=repr(error_log), + stat_log=stat_log, + parent_interpreter=sys.executable, + torch_file=torch.__file__, + bindings_import=( + "import torch._C as callgrind_bindings" if bindings is None + else f"import {bindings.__name__} as callgrind_bindings"), + ) + + +CALLGRIND_SINGLETON: Optional[_ValgrindWrapper] = None +def wrapper_singleton() -> _ValgrindWrapper: + global CALLGRIND_SINGLETON + if CALLGRIND_SINGLETON is None: + CALLGRIND_SINGLETON = _ValgrindWrapper() + return CALLGRIND_SINGLETON diff --git a/torch/utils/bundled_inputs.py b/torch/utils/bundled_inputs.py index c5d603885e4a2..741c0841778a6 100644 --- a/torch/utils/bundled_inputs.py +++ b/torch/utils/bundled_inputs.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -from typing import Any, TypeVar, Optional, Tuple, List, NamedTuple, Union +from typing import Any, TypeVar, Optional, Tuple, List, NamedTuple, Union, Sequence import textwrap import torch from torch._C import TupleType, OptionalType, ListType @@ -17,7 +17,7 @@ class InflatableArg(NamedTuple): def augment_model_with_bundled_inputs( model: torch.jit.ScriptModule, - inputs: Optional[List[Tuple[Any, ...]]] = None, + inputs: Optional[Sequence[Tuple[Any, ...]]] = None, _receive_inflate_expr: Optional[List[str]] = None, # For debugging. ) -> None: """Add bundled sample inputs to a model. diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index 9bf664cd9e7c4..a31a15907a338 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -59,6 +59,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): check_backward_validity(args) ctx.run_function = run_function ctx.preserve_rng_state = preserve_rng_state + ctx.had_autocast_in_fwd = torch.is_autocast_enabled() if preserve_rng_state: ctx.fwd_cpu_state = torch.get_rng_state() # Don't eagerly initialize the cuda context by accident. @@ -91,12 +92,24 @@ def backward(ctx, *args): if ctx.had_cuda_in_fwd: set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) detached_inputs = detach_variable(inputs) - with torch.enable_grad(): + with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd): outputs = ctx.run_function(*detached_inputs) if isinstance(outputs, torch.Tensor): outputs = (outputs,) - torch.autograd.backward(outputs, args) + + # run backward() with only tensor that requires grad + outputs_with_grad = [] + args_with_grad = [] + for i in range(len(outputs)): + if outputs[i].requires_grad: + outputs_with_grad.append(outputs[i]) + args_with_grad.append(args[i]) + if len(outputs_with_grad) == 0: + raise RuntimeError( + "none of output has requires_grad=True," + " this checkpoint() is not necessary") + torch.autograd.backward(outputs_with_grad, args_with_grad) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs) return (None, None) + grads @@ -132,15 +145,16 @@ def checkpoint(function, *args, **kwargs): .. warning:: If checkpointed segment contains tensors detached from the computational graph by `detach()` or `torch.no_grad()`, the backward pass will raise an - error. This is because `checkpoint` makes all the outputs require - gradients which causes issues when a tensor is defined to have no - gradient in the model. To circumvent this, detach the tensors outside of + error. This is because `checkpoint` makes all the outputs require + gradients which causes issues when a tensor is defined to have no + gradient in the model. To circumvent this, detach the tensors outside of the `checkpoint` function. .. warning: At least one of the inputs needs to have :code:`requires_grad=True` if grads are needed for model inputs, otherwise the checkpointed part of the - model won't have gradients. + model won't have gradients. At least one of the outputs needs to have + :code:`requires_grad=True` as well. Args: function: describes what to run in the forward pass of the model or diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py index 57fe37581e71b..5b91c7a9a0faa 100644 --- a/torch/utils/collect_env.py +++ b/torch/utils/collect_env.py @@ -10,7 +10,7 @@ try: import torch TORCH_AVAILABLE = True -except (ImportError, NameError, AttributeError): +except (ImportError, NameError, AttributeError, OSError): TORCH_AVAILABLE = False # System Environment Information @@ -43,7 +43,10 @@ def run(command): stderr=subprocess.PIPE, shell=True) raw_output, raw_err = p.communicate() rc = p.returncode - enc = locale.getpreferredencoding() + if get_platform() == 'win32': + enc = 'oem' + else: + enc = locale.getpreferredencoding() output = raw_output.decode(enc) err = raw_err.decode(enc) return rc, output.strip(), err.strip() @@ -70,7 +73,7 @@ def run_and_parse_first_match(run_lambda, command, regex): def get_conda_packages(run_lambda): if get_platform() == 'win32': - system_root = os.environ.get('SystemRoot', 'C:\\Windows') + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') findstr_cmd = os.path.join(system_root, 'System32', 'findstr') grep_cmd = r'{} /R "torch numpy cudatoolkit soumith mkl magma"'.format(findstr_cmd) else: @@ -105,7 +108,7 @@ def get_nvidia_driver_version(run_lambda): def get_gpu_info(run_lambda): - if get_platform() == 'darwin' or torch.version.hip is not None: + if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(torch.version, 'hip') and torch.version.hip is not None): if TORCH_AVAILABLE and torch.cuda.is_available(): return torch.cuda.get_device_name(None) return None @@ -125,7 +128,7 @@ def get_running_cuda_version(run_lambda): def get_cudnn_version(run_lambda): """This will return a list of libcudnn.so; it's hard to tell which one is being used""" if get_platform() == 'win32': - system_root = os.environ.get('SystemRoot', 'C:\\Windows') + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") where_cmd = os.path.join(system_root, 'System32', 'where') cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) @@ -163,7 +166,15 @@ def get_nvidia_smi(): # Note: nvidia-smi is currently available only on Windows and Linux smi = 'nvidia-smi' if get_platform() == 'win32': - smi = '"C:\\Program Files\\NVIDIA Corporation\\NVSMI\\%s"' % smi + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files') + legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi) + new_path = os.path.join(system_root, 'System32', smi) + smis = [new_path, legacy_path] + for candidate_smi in smis: + if os.path.exists(candidate_smi): + smi = f'"{candidate_smi}"' + break return smi @@ -185,7 +196,7 @@ def get_mac_version(run_lambda): def get_windows_version(run_lambda): - system_root = os.environ.get('SystemRoot', 'C:\\Windows') + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic') findstr_cmd = os.path.join(system_root, 'System32', 'findstr') return run_and_read_all(run_lambda, '{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd)) @@ -211,7 +222,7 @@ def get_os(run_lambda): version = get_mac_version(run_lambda) if version is None: return None - return 'Mac OSX {} ({})'.format(version, machine()) + return 'macOS {} ({})'.format(version, machine()) if platform == 'linux': # Ubuntu/Debian based @@ -236,7 +247,7 @@ def get_pip_packages(run_lambda): # People generally have `pip` as `pip` or `pip3` def run_with_pip(pip): if get_platform() == 'win32': - system_root = os.environ.get('SystemRoot', 'C:\\Windows') + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') findstr_cmd = os.path.join(system_root, 'System32', 'findstr') grep_cmd = r'{} /R "numpy torch"'.format(findstr_cmd) else: @@ -270,41 +281,31 @@ def get_env_info(): debug_mode_str = str(torch.version.debug) cuda_available_str = str(torch.cuda.is_available()) cuda_version_str = torch.version.cuda + if not hasattr(torch.version, 'hip') or torch.version.hip is None: # cuda version + hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + else: # HIP version + cfg = torch._C._show_config().split('\n') + hip_runtime_version = [s.rsplit(None, 1)[-1] for s in cfg if 'HIP Runtime' in s][0] + miopen_runtime_version = [s.rsplit(None, 1)[-1] for s in cfg if 'MIOpen' in s][0] + cuda_version_str = 'N/A' + hip_compiled_version = torch.version.hip else: version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A' - - if torch.version.hip is None: # cuda version - gpu_info = dict( - is_cuda_available=cuda_available_str, - cuda_compiled_version=cuda_version_str, - cuda_runtime_version=get_running_cuda_version(run_lambda), - nvidia_gpu_models=get_gpu_info(run_lambda), - nvidia_driver_version=get_nvidia_driver_version(run_lambda), - cudnn_version=get_cudnn_version(run_lambda), - hip_compiled_version='N/A', - hip_runtime_version='N/A', - miopen_runtime_version='N/A', - ) - else: # HIP version - cfg = torch._C._show_config().split('\n') - hip_runtime_version = [s.rsplit(None, 1)[-1] for s in cfg if 'HIP Runtime' in s][0] - miopen_runtime_version = [s.rsplit(None, 1)[-1] for s in cfg if 'MIOpen' in s][0] - gpu_info = dict( - is_cuda_available=cuda_available_str, - cuda_compiled_version='N/A', - hip_compiled_version=torch.version.hip, - hip_runtime_version=hip_runtime_version, - miopen_runtime_version=miopen_runtime_version, - cuda_runtime_version='N/A', - nvidia_gpu_models=get_gpu_info(run_lambda), - nvidia_driver_version=get_nvidia_driver_version(run_lambda), - cudnn_version='N/A', - ) + hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' return SystemEnv( torch_version=version_str, is_debug_build=debug_mode_str, python_version='{}.{} ({}-bit runtime)'.format(sys.version_info[0], sys.version_info[1], sys.maxsize.bit_length() + 1), + is_cuda_available=cuda_available_str, + cuda_compiled_version=cuda_version_str, + cuda_runtime_version=get_running_cuda_version(run_lambda), + nvidia_gpu_models=get_gpu_info(run_lambda), + nvidia_driver_version=get_nvidia_driver_version(run_lambda), + cudnn_version=get_cudnn_version(run_lambda), + hip_compiled_version=hip_compiled_version, + hip_runtime_version=hip_runtime_version, + miopen_runtime_version=miopen_runtime_version, pip_version=pip_version, pip_packages=pip_list_output, conda_packages=get_conda_packages(run_lambda), @@ -312,7 +313,6 @@ def get_env_info(): gcc_version=get_gcc_version(run_lambda), clang_version=get_clang_version(run_lambda), cmake_version=get_cmake_version(run_lambda), - **gpu_info ) env_info_fmt = """ diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 1b3d50a8244d0..93a8f403bc4a1 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -17,12 +17,34 @@ from ._cpp_extension_versioner import ExtensionVersioner from .hipify import hipify_python from .hipify.hipify_python import get_hip_file_path, GeneratedFileCleaner -from typing import List, Optional +from typing import List, Optional, Union from setuptools.command.build_ext import build_ext +from pkg_resources import packaging # type: ignore IS_WINDOWS = sys.platform == 'win32' +LIB_EXT = '.pyd' if IS_WINDOWS else '.so' +EXEC_EXT = '.exe' if IS_WINDOWS else '' +SHARED_FLAG = '/DLL' if IS_WINDOWS else '-shared' + +_HERE = os.path.abspath(__file__) +_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE)) +TORCH_LIB_PATH = os.path.join(_TORCH_PATH, 'lib') + + +# Taken directly from python stdlib < 3.9 +# See https://github.com/pytorch/pytorch/issues/48617 +def _nt_quote_args(args: Optional[List[str]]) -> List[str]: + """Quote command-line arguments for DOS/Windows conventions. + + Just wraps every argument which contains blanks in double quotes, and + returns a new argument list. + """ + # Cover None-type + if not args: + return [] + return [f'"{arg}"' if ' ' in arg else arg for arg in args] def _find_cuda_home() -> Optional[str]: r'''Finds the CUDA install path.''' @@ -50,7 +72,7 @@ def _find_cuda_home() -> Optional[str]: if not os.path.exists(cuda_home): cuda_home = None if cuda_home and not torch.cuda.is_available(): - print("No CUDA runtime is found, using CUDA_HOME='{}'".format(cuda_home)) + print(f"No CUDA runtime is found, using CUDA_HOME='{cuda_home}'") return cuda_home def _find_rocm_home() -> Optional[str]: @@ -72,7 +94,7 @@ def _find_rocm_home() -> Optional[str]: if not os.path.exists(rocm_home): rocm_home = None if rocm_home and torch.version.hip is None: - print("No ROCm runtime is found, using ROCM_HOME='{}'".format(rocm_home)) + print(f"No ROCm runtime is found, using ROCM_HOME='{rocm_home}'") return rocm_home @@ -152,13 +174,17 @@ def _join_rocm_home(*paths) -> str: COMMON_NVCC_FLAGS = [ '-D__CUDA_NO_HALF_OPERATORS__', '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_BFLOAT16_CONVERSIONS__', '-D__CUDA_NO_HALF2_OPERATORS__', '--expt-relaxed-constexpr' ] -COMMON_HIPCC_FLAGS = [ +COMMON_HIP_FLAGS = [ '-fPIC', '-D__HIP_PLATFORM_HCC__=1', +] + +COMMON_HIPCC_FLAGS = [ '-DCUDA_HAS_FP16=1', '-D__HIP_NO_HALF_OPERATORS__=1', '-D__HIP_NO_HALF_CONVERSIONS__=1', @@ -200,7 +226,7 @@ def check_compiler_ok_for_platform(compiler: str) -> bool: r''' Verifies that the compiler is the expected one for the current platform. - Arguments: + Args: compiler (str): The compiler executable to check. Returns: @@ -235,7 +261,7 @@ def check_compiler_abi_compatibility(compiler) -> bool: r''' Verifies that the given compiler is ABI-compatible with PyTorch. - Arguments: + Args: compiler (str): The compiler executable name to check (e.g. ``g++``). Must be executable in a shell process. @@ -271,13 +297,13 @@ def check_compiler_abi_compatibility(compiler) -> bool: version = (0, 0, 0) if match is None else match.groups() except Exception: _, error, _ = sys.exc_info() - warnings.warn('Error checking compiler version for {}: {}'.format(compiler, error)) + warnings.warn(f'Error checking compiler version for {compiler}: {error}') return False if tuple(map(int, version)) >= minimum_required_version: return True - compiler = '{} {}'.format(compiler, ".".join(version)) + compiler = f'{compiler} {".".join(version)}' warnings.warn(ABI_INCOMPATIBILITY_WARNING.format(compiler)) return False @@ -347,7 +373,24 @@ def finalize_options(self) -> None: def build_extensions(self) -> None: self._check_abi() for extension in self.extensions: + # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when + # extra_compile_args is a dict. Otherwise, default torch flags do + # not get passed. Necessary when only one of 'cxx' and 'nvcc' is + # passed to extra_compile_args in CUDAExtension, i.e. + # CUDAExtension(..., extra_compile_args={'cxx': [...]}) + # or + # CUDAExtension(..., extra_compile_args={'nvcc': [...]}) + if isinstance(extension.extra_compile_args, dict): + for ext in ['cxx', 'nvcc']: + if ext not in extension.extra_compile_args: + extension.extra_compile_args[ext] = [] + self._add_compile_flag(extension, '-DTORCH_API_INCLUDE_EXTENSION_H') + # See note [Pybind11 ABI constants] + for name in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]: + val = getattr(torch._C, f"_PYBIND11_{name}") + if val is not None and not IS_WINDOWS: + self._add_compile_flag(extension, f'-DPYBIND11_{name}="{val}"') self._define_torch_extension_name(extension) self._add_gnu_cpp_abi_flag(extension) @@ -371,11 +414,20 @@ def append_std14_if_no_std_present(cflags) -> None: cflags.append(cpp_flag) def unix_cuda_flags(cflags): + cflags = (COMMON_NVCC_FLAGS + + ['--compiler-options', "'-fPIC'"] + + cflags + _get_cuda_arch_flags(cflags)) + + # NVCC does not allow multiple -ccbin/--compiler-bindir to be passed, so we avoid + # overriding the option if the user explicitly passed it. _ccbin = os.getenv("CC") - return (COMMON_NVCC_FLAGS + - ['--compiler-options', "'-fPIC'"] + - cflags + _get_cuda_arch_flags(cflags) + - (['-ccbin', _ccbin] if _ccbin is not None else [])) + if ( + _ccbin is not None + and not any([flag.startswith('-ccbin') or flag.startswith('--compiler-bindir') for flag in cflags]) + ): + cflags.extend(['-ccbin', _ccbin]) + + return cflags def convert_to_absolute_paths_inplace(paths): # Helper function. See Note [Absolute include_dirs] @@ -395,13 +447,13 @@ def unix_wrap_single_compile(obj, src, ext, cc_args, extra_postargs, pp_opts) -> if isinstance(cflags, dict): cflags = cflags['nvcc'] if IS_HIP_EXTENSION: - cflags = cflags + _get_rocm_arch_flags(cflags) + cflags = COMMON_HIPCC_FLAGS + cflags + _get_rocm_arch_flags(cflags) else: cflags = unix_cuda_flags(cflags) elif isinstance(cflags, dict): cflags = cflags['cxx'] if IS_HIP_EXTENSION: - cflags = cflags + COMMON_HIPCC_FLAGS + cflags = COMMON_HIP_FLAGS + cflags append_std14_if_no_std_present(cflags) original_compile(obj, src, ext, cc_args, cflags, pp_opts) @@ -450,7 +502,7 @@ def unix_wrap_ninja_compile(sources, else: post_cflags = list(extra_postargs) if IS_HIP_EXTENSION: - post_cflags += COMMON_HIPCC_FLAGS + post_cflags = COMMON_HIP_FLAGS + post_cflags append_std14_if_no_std_present(post_cflags) cuda_post_cflags = None @@ -463,7 +515,7 @@ def unix_wrap_ninja_compile(sources, cuda_post_cflags = list(extra_postargs) if IS_HIP_EXTENSION: cuda_post_cflags = cuda_post_cflags + _get_rocm_arch_flags(cuda_post_cflags) - cuda_post_cflags = cuda_post_cflags + COMMON_HIPCC_FLAGS + cuda_post_cflags = COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS + cuda_post_cflags else: cuda_post_cflags = unix_cuda_flags(cuda_post_cflags) append_std14_if_no_std_present(cuda_post_cflags) @@ -532,7 +584,7 @@ def spawn(cmd): else: cflags = [] - cflags = win_cuda_flags(cflags) + cflags = win_cuda_flags(cflags) + ['--use-local-env'] for flag in COMMON_MSVC_FLAGS: cflags = ['-Xcompiler', flag] + cflags for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS: @@ -602,7 +654,7 @@ def win_wrap_ninja_compile(sources, cuda_post_cflags = None cuda_cflags = None if with_cuda: - cuda_cflags = [] + cuda_cflags = ['--use-local-env'] for common_cflag in common_cflags: cuda_cflags.append('-Xcompiler') cuda_cflags.append(common_cflag) @@ -616,7 +668,6 @@ def win_wrap_ninja_compile(sources, cuda_post_cflags = list(extra_postargs) cuda_post_cflags = win_cuda_flags(cuda_post_cflags) - from distutils.spawn import _nt_quote_args # type: ignore cflags = _nt_quote_args(cflags) post_cflags = _nt_quote_args(post_cflags) if with_cuda: @@ -699,7 +750,7 @@ def _define_torch_extension_name(self, extension): # as the library name names = extension.name.split('.') name = names[-1] - define = '-DTORCH_EXTENSION_NAME={}'.format(name) + define = f'-DTORCH_EXTENSION_NAME={name}' self._add_compile_flag(extension, define) def _add_gnu_cpp_abi_flag(self, extension): @@ -778,6 +829,35 @@ def CUDAExtension(name, sources, *args, **kwargs): cmdclass={ 'build_ext': BuildExtension }) + + Compute capabilities: + + By default the extension will be compiled to run on all archs of the cards visible during the + building process of the extension, plus PTX. If down the road a new card is installed the + extension may need to be recompiled. If a visible card has a compute capability (CC) that's + newer than the newest version for which your nvcc can build fully-compiled binaries, Pytorch + will make nvcc fall back to building kernels with the newest version of PTX your nvcc does + support (see below for details on PTX). + + You can override the default behavior using `TORCH_CUDA_ARCH_LIST` to explicitly specify which + CCs you want the extension to support: + + TORCH_CUDA_ARCH_LIST="6.1 8.6" python build_my_extension.py + TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" python build_my_extension.py + + The +PTX option causes extension kernel binaries to include PTX instructions for the specified + CC. PTX is an intermediate representation that allows kernels to runtime-compile for any CC >= + the specified CC (for example, 8.6+PTX generates PTX that can runtime-compile for any GPU with + CC >= 8.6). This improves your binary's forward compatibility. However, relying on older PTX to + provide forward compat by runtime-compiling for newer CCs can modestly reduce performance on + those newer CCs. If you know exact CC(s) of the GPUs you want to target, you're always better + off specifying them individually. For example, if you want your extension to run on 8.0 and 8.6, + "8.0+PTX" would work functionally because it includes PTX that can runtime-compile for 8.6, but + "8.0 8.6" would be better. + + Note that while it's possible to include all supported archs, the more archs get included the + slower the building process will be, as it will build a separate kernel image for each arch. + ''' library_dirs = kwargs.get('library_dirs', []) library_dirs += library_paths(cuda=True) @@ -800,6 +880,27 @@ def CUDAExtension(name, sources, *args, **kwargs): kwargs['libraries'] = libraries include_dirs = kwargs.get('include_dirs', []) + + if IS_HIP_EXTENSION: + build_dir = os.getcwd() + if not include_dirs: + include_dirs = ['*'] + hipify_result = hipify_python.hipify( + project_directory=build_dir, + output_directory=build_dir, + includes=[os.path.join(os.path.relpath(include_dir, build_dir), '*') for include_dir in include_dirs], + extra_files=[os.path.abspath(s) for s in sources], + show_detailed=True, + is_pytorch_extension=True, + ) + + hipified_sources = set() + for source in sources: + s_abs = os.path.abspath(source) + hipified_sources.add(hipify_result[s_abs]["hipified_path"] if s_abs in hipify_result else s_abs) + + sources = list(hipified_sources) + include_dirs += include_paths(cuda=True) kwargs['include_dirs'] = include_dirs @@ -818,9 +919,7 @@ def include_paths(cuda: bool = False) -> List[str]: Returns: A list of include path strings. ''' - here = os.path.abspath(__file__) - torch_path = os.path.dirname(os.path.dirname(here)) - lib_include = os.path.join(torch_path, 'include') + lib_include = os.path.join(_TORCH_PATH, 'include') paths = [ lib_include, # Remove this once torch/torch.h is officially no longer supported for C++ extensions. @@ -856,13 +955,8 @@ def library_paths(cuda: bool = False) -> List[str]: Returns: A list of library path strings. ''' - paths = [] - # We need to link against libtorch.so - here = os.path.abspath(__file__) - torch_path = os.path.dirname(os.path.dirname(here)) - lib_path = os.path.join(torch_path, 'lib') - paths.append(lib_path) + paths = [TORCH_LIB_PATH] if cuda and IS_HIP_EXTENSION: lib_dir = 'lib' @@ -886,7 +980,7 @@ def library_paths(cuda: bool = False) -> List[str]: def load(name, - sources: List[str], + sources: Union[str, List[str]], extra_cflags=None, extra_cuda_cflags=None, extra_ldflags=None, @@ -895,6 +989,7 @@ def load(name, verbose=False, with_cuda: Optional[bool] = None, is_python_module=True, + is_standalone=False, keep_intermediates=True): r''' Loads a PyTorch C++ extension just-in-time (JIT). @@ -949,14 +1044,23 @@ def load(name, ``.cuh`` in ``sources``. Set it to `True`` to force CUDA headers and libraries to be included. is_python_module: If ``True`` (default), imports the produced shared - library as a Python module. If ``False``, loads it into the process - as a plain dynamic library. + library as a Python module. If ``False``, behavior depends on + ``is_standalone``. + is_standalone: If ``False`` (default) loads the constructed extension + into the process as a plain dynamic library. If ``True``, build a + standalone executable. Returns: - If ``is_python_module`` is ``True``, returns the loaded PyTorch - extension as a Python module. If ``is_python_module`` is ``False`` - returns nothing (the shared library is loaded into the process as a side - effect). + If ``is_python_module`` is ``True``: + Returns the loaded PyTorch extension as a Python module. + + If ``is_python_module`` is ``False`` and ``is_standalone`` is ``False``: + Returns nothing. (The shared library is loaded into the process as + a side effect.) + + If ``is_standalone`` is ``True``. + Return the path to the executable. (On Windows, TORCH_LIB_PATH is + added to the PATH environment variable as a side effect.) Example: >>> from torch.utils.cpp_extension import load @@ -977,6 +1081,7 @@ def load(name, verbose, with_cuda, is_python_module, + is_standalone, keep_intermediates=keep_intermediates) @@ -1086,9 +1191,7 @@ def load_inline(name, # Make the function docstring the same as the function name. functions = dict((f, f) for f in functions) elif not isinstance(functions, dict): - raise ValueError( - "Expected 'functions' to be a list or dict, but was {}".format( - type(functions))) + raise ValueError(f"Expected 'functions' to be a list or dict, but was {type(functions)}") for function_name, docstring in functions.items(): if with_pytorch_error_handling: module_def.append( @@ -1127,6 +1230,7 @@ def load_inline(name, verbose, with_cuda, is_python_module, + is_standalone=False, keep_intermediates=keep_intermediates) @@ -1140,7 +1244,11 @@ def _jit_compile(name, verbose: bool, with_cuda: Optional[bool], is_python_module, + is_standalone, keep_intermediates=True) -> None: + if is_python_module and is_standalone: + raise ValueError("`is_python_module` and `is_standalone` are mutually exclusive.") + if with_cuda is None: with_cuda = any(map(_is_cuda_file, sources)) with_cudnn = any(['cudnn' in f for f in extra_ldflags or []]) @@ -1150,13 +1258,15 @@ def _jit_compile(name, sources, build_arguments=[extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths], build_directory=build_directory, - with_cuda=with_cuda + with_cuda=with_cuda, + is_python_module=is_python_module, + is_standalone=is_standalone, ) if version > 0: if version != old_version and verbose: - print('The input conditions for extension module {} have changed. '.format(name) + - 'Bumping to version {0} and re-building as {1}_v{0}...'.format(version, name)) - name = '{}_v{}'.format(name, version) + print(f'The input conditions for extension module {name} have changed. ' + + f'Bumping to version {version} and re-building as {name}_v{version}...') + name = f'{name}_v{version}' if version != old_version: baton = FileBaton(os.path.join(build_directory, 'lock')) @@ -1182,17 +1292,22 @@ def _jit_compile(name, extra_include_paths=extra_include_paths or [], build_directory=build_directory, verbose=verbose, - with_cuda=with_cuda) + with_cuda=with_cuda, + is_standalone=is_standalone) finally: baton.release() else: baton.wait() elif verbose: print('No modifications detected for re-loaded extension ' - 'module {}, skipping build step...'.format(name)) + f'module {name}, skipping build step...') if verbose: print(f'Loading extension module {name}...') + + if is_standalone: + return _get_exec_path(name, build_directory) + return _import_module_from_library(name, build_directory, is_python_module) @@ -1247,7 +1362,8 @@ def _write_ninja_file_and_build_library( extra_include_paths, build_directory: str, verbose: bool, - with_cuda: Optional[bool]) -> None: + with_cuda: Optional[bool], + is_standalone: bool = False) -> None: verify_ninja_availability() if IS_WINDOWS: compiler = os.environ.get('CXX', 'cl') @@ -1259,7 +1375,8 @@ def _write_ninja_file_and_build_library( extra_ldflags = _prepare_ldflags( extra_ldflags or [], with_cuda, - verbose) + verbose, + is_standalone) build_file_path = os.path.join(build_directory, 'build.ninja') if verbose: print(f'Emitting ninja build file {build_file_path}...') @@ -1273,14 +1390,15 @@ def _write_ninja_file_and_build_library( extra_cuda_cflags=extra_cuda_cflags or [], extra_ldflags=extra_ldflags or [], extra_include_paths=extra_include_paths or [], - with_cuda=with_cuda) + with_cuda=with_cuda, + is_standalone=is_standalone) if verbose: - print('Building extension module {}...'.format(name)) + print(f'Building extension module {name}...') _run_ninja_build( build_directory, verbose, - error_prefix="Error building extension '{}'".format(name)) + error_prefix=f"Error building extension '{name}'") def is_ninja_available(): @@ -1288,13 +1406,12 @@ def is_ninja_available(): Returns ``True`` if the `ninja `_ build system is available on the system, ``False`` otherwise. ''' - with open(os.devnull, 'wb') as devnull: - try: - subprocess.check_call('ninja --version'.split(), stdout=devnull) - except OSError: - return False - else: - return True + try: + subprocess.check_output('ninja --version'.split()) + except Exception: + return False + else: + return True def verify_ninja_availability(): @@ -1306,11 +1423,7 @@ def verify_ninja_availability(): raise RuntimeError("Ninja is required to load C++ extensions") -def _prepare_ldflags(extra_ldflags, with_cuda, verbose): - here = os.path.abspath(__file__) - torch_path = os.path.dirname(os.path.dirname(here)) - lib_path = os.path.join(torch_path, 'lib') - +def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone): if IS_WINDOWS: python_path = os.path.dirname(sys.executable) python_lib_path = os.path.join(python_path, 'libs') @@ -1325,11 +1438,13 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose): # Related issue: https://github.com/pytorch/pytorch/issues/31611 extra_ldflags.append('-INCLUDE:?warp_size@cuda@at@@YAHXZ') extra_ldflags.append('torch.lib') - extra_ldflags.append('torch_python.lib') - extra_ldflags.append('/LIBPATH:{}'.format(python_lib_path)) - extra_ldflags.append('/LIBPATH:{}'.format(lib_path)) + extra_ldflags.append(f'/LIBPATH:{TORCH_LIB_PATH}') + if not is_standalone: + extra_ldflags.append('torch_python.lib') + extra_ldflags.append(f'/LIBPATH:{python_lib_path}') + else: - extra_ldflags.append('-L{}'.format(lib_path)) + extra_ldflags.append(f'-L{TORCH_LIB_PATH}') extra_ldflags.append('-lc10') if with_cuda: extra_ldflags.append('-lc10_hip' if IS_HIP_EXTENSION else '-lc10_cuda') @@ -1337,25 +1452,31 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose): if with_cuda: extra_ldflags.append('-ltorch_hip' if IS_HIP_EXTENSION else '-ltorch_cuda') extra_ldflags.append('-ltorch') - extra_ldflags.append('-ltorch_python') + if not is_standalone: + extra_ldflags.append('-ltorch_python') + + if is_standalone and "TBB" in torch.__config__.parallel_info(): + extra_ldflags.append('-ltbb') + + if is_standalone: + extra_ldflags.append(f"-Wl,-rpath,{TORCH_LIB_PATH}") if with_cuda: if verbose: print('Detected CUDA files, patching ldflags') if IS_WINDOWS: - extra_ldflags.append('/LIBPATH:{}'.format( - _join_cuda_home('lib/x64'))) + extra_ldflags.append(f'/LIBPATH:{_join_cuda_home("lib/x64")}') extra_ldflags.append('cudart.lib') if CUDNN_HOME is not None: extra_ldflags.append(os.path.join(CUDNN_HOME, 'lib/x64')) elif not IS_HIP_EXTENSION: - extra_ldflags.append('-L{}'.format(_join_cuda_home('lib64'))) + extra_ldflags.append(f'-L{_join_cuda_home("lib64")}') extra_ldflags.append('-lcudart') if CUDNN_HOME is not None: - extra_ldflags.append('-L{}'.format(os.path.join(CUDNN_HOME, 'lib64'))) + extra_ldflags.append(f'-L{os.path.join(CUDNN_HOME, "lib64")}') elif IS_HIP_EXTENSION: assert ROCM_VERSION is not None - extra_ldflags.append('-L{}'.format(_join_rocm_home('lib'))) + extra_ldflags.append(f'-L{_join_rocm_home("lib")}') extra_ldflags.append('-lamdhip64' if ROCM_VERSION >= (3, 5) else '-lhip_hcc') return extra_ldflags @@ -1389,11 +1510,11 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]: ('Pascal', '6.0;6.1+PTX'), ('Volta', '7.0+PTX'), ('Turing', '7.5+PTX'), - ('Ampere', '8.0+PTX'), + ('Ampere', '8.0;8.6+PTX'), ]) supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2', - '7.0', '7.2', '7.5', '8.0'] + '7.0', '7.2', '7.5', '8.0', '8.6'] valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches] # The default is sm_30 for CUDA 9.x and 10.x @@ -1402,10 +1523,26 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]: # See cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake _arch_list = os.environ.get('TORCH_CUDA_ARCH_LIST', None) - # If not given, determine what's needed for the GPU that can be found + # If not given, determine what's best for the GPU / CUDA version that can be found if not _arch_list: - capability = torch.cuda.get_device_capability() - arch_list = ['{}.{}'.format(capability[0], capability[1])] + arch_list = [] + # the assumption is that the extension should run on any of the currently visible cards, + # which could be of different types - therefore all archs for visible cards should be included + for i in range(torch.cuda.device_count()): + capability = torch.cuda.get_device_capability(i) + supported_sm = [int(arch.split('_')[1]) + for arch in torch.cuda.get_arch_list() if 'sm_' in arch] + max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm) + # Capability of the device may be higher than what's supported by the user's + # NVCC, causing compilation error. User's NVCC is expected to match the one + # used to build pytorch, so we use the maximum supported capability of pytorch + # to clamp the capability. + capability = min(max_supported_sm, capability) + arch = f'{capability[0]}.{capability[1]}' + if arch not in arch_list: + arch_list.append(arch) + arch_list = sorted(arch_list) + arch_list[-1] += '+PTX' else: # Deal with lists that are ' ' separated (only deal with ';' after) _arch_list = _arch_list.replace(' ', ';') @@ -1418,14 +1555,14 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]: flags = [] for arch in arch_list: if arch not in valid_arch_strings: - raise ValueError("Unknown CUDA arch ({}) or GPU not supported".format(arch)) + raise ValueError(f"Unknown CUDA arch ({arch}) or GPU not supported") else: num = arch[0] + arch[2] - flags.append('-gencode=arch=compute_{},code=sm_{}'.format(num, num)) + flags.append(f'-gencode=arch=compute_{num},code=sm_{num}') if arch.endswith('+PTX'): - flags.append('-gencode=arch=compute_{},code=compute_{}'.format(num, num)) + flags.append(f'-gencode=arch=compute_{num},code=compute_{num}') - return list(set(flags)) + return sorted(list(set(flags))) def _get_rocm_arch_flags(cflags: Optional[List[str]] = None) -> List[str]: @@ -1450,8 +1587,7 @@ def _get_build_directory(name: str, verbose: bool) -> str: root_extensions_directory = get_default_build_root() if verbose: - print('Using {} as PyTorch extensions root...'.format( - root_extensions_directory)) + print(f'Using {root_extensions_directory} as PyTorch extensions root...') build_directory = os.path.join(root_extensions_directory, name) if not os.path.exists(build_directory): @@ -1467,7 +1603,7 @@ def _get_num_workers(verbose: bool) -> Optional[int]: max_jobs = os.environ.get('MAX_JOBS') if max_jobs is not None and max_jobs.isdigit(): if verbose: - print('Using envvar MAX_JOBS ({}) as the number of workers...'.format(max_jobs)) + print(f'Using envvar MAX_JOBS ({max_jobs}) as the number of workers...') return int(max_jobs) if verbose: print('Allowing ninja to set a default number of workers... ' @@ -1499,33 +1635,26 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> try: sys.stdout.flush() sys.stderr.flush() - if sys.version_info >= (3, 5): - # Warning: don't pass stdout=None to subprocess.run to get output. - # subprocess.run assumes that sys.__stdout__ has not been modified and - # attempts to write to it by default. However, when we call _run_ninja_build - # from ahead-of-time cpp extensions, the following happens: - # 1) If the stdout encoding is not utf-8, setuptools detachs __stdout__. - # https://github.com/pypa/setuptools/blob/7e97def47723303fafabe48b22168bbc11bb4821/setuptools/dist.py#L1110 - # (it probably shouldn't do this) - # 2) subprocess.run (on POSIX, with no stdout override) relies on - # __stdout__ not being detached: - # https://github.com/python/cpython/blob/c352e6c7446c894b13643f538db312092b351789/Lib/subprocess.py#L1214 - # To work around this, we pass in the fileno directly and hope that - # it is valid. - stdout_fileno = 1 - subprocess.run( - command, - stdout=stdout_fileno if verbose else subprocess.PIPE, - stderr=subprocess.STDOUT, - cwd=build_directory, - check=True, - env=env) - else: - subprocess.check_output( - command, - stderr=subprocess.STDOUT, - cwd=build_directory, - env=env) + # Warning: don't pass stdout=None to subprocess.run to get output. + # subprocess.run assumes that sys.__stdout__ has not been modified and + # attempts to write to it by default. However, when we call _run_ninja_build + # from ahead-of-time cpp extensions, the following happens: + # 1) If the stdout encoding is not utf-8, setuptools detachs __stdout__. + # https://github.com/pypa/setuptools/blob/7e97def47723303fafabe48b22168bbc11bb4821/setuptools/dist.py#L1110 + # (it probably shouldn't do this) + # 2) subprocess.run (on POSIX, with no stdout override) relies on + # __stdout__ not being detached: + # https://github.com/python/cpython/blob/c352e6c7446c894b13643f538db312092b351789/Lib/subprocess.py#L1214 + # To work around this, we pass in the fileno directly and hope that + # it is valid. + stdout_fileno = 1 + subprocess.run( + command, + stdout=stdout_fileno if verbose else subprocess.PIPE, + stderr=subprocess.STDOUT, + cwd=build_directory, + check=True, + env=env) except subprocess.CalledProcessError as e: # Python 2 and 3 compatible way of getting the error object. _, error, _ = sys.exc_info() @@ -1534,17 +1663,28 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> # `error` is a CalledProcessError (which has an `ouput`) attribute, but # mypy thinks it's Optional[BaseException] and doesn't narrow if hasattr(error, 'output') and error.output: # type: ignore - message += ": {}".format(error.output.decode()) # type: ignore + message += f": {error.output.decode()}" # type: ignore raise RuntimeError(message) from e +def _get_exec_path(module_name, path): + if IS_WINDOWS and TORCH_LIB_PATH not in os.getenv('PATH', '').split(';'): + torch_lib_in_path = any( + os.path.exists(p) and os.path.samefile(p, TORCH_LIB_PATH) + for p in os.getenv('PATH', '').split(';') + ) + if not torch_lib_in_path: + os.environ['PATH'] = f"{TORCH_LIB_PATH};{os.getenv('PATH', '')}" + return os.path.join(path, f'{module_name}{EXEC_EXT}') + + def _import_module_from_library(module_name, path, is_python_module): # https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path file, path, description = imp.find_module(module_name, [path]) # Close the .so file after load. with file: if is_python_module: - return imp.load_module(module_name, file, path, description) + return imp.load_module(module_name, file, path, description) # type: ignore else: torch.ops.load_library(path) @@ -1556,7 +1696,8 @@ def _write_ninja_file_to_build_library(path, extra_cuda_cflags, extra_ldflags, extra_include_paths, - with_cuda) -> None: + with_cuda, + is_standalone) -> None: extra_cflags = [flag.strip() for flag in extra_cflags] extra_cuda_cflags = [flag.strip() for flag in extra_cuda_cflags] extra_ldflags = [flag.strip() for flag in extra_ldflags] @@ -1576,27 +1717,46 @@ def _write_ninja_file_to_build_library(path, user_includes += system_includes system_includes.clear() - common_cflags = ['-DTORCH_EXTENSION_NAME={}'.format(name)] - common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H') - common_cflags += ['-I{}'.format(include) for include in user_includes] - common_cflags += ['-isystem {}'.format(include) for include in system_includes] + common_cflags = [] + if not is_standalone: + common_cflags.append(f'-DTORCH_EXTENSION_NAME={name}') + common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H') + + # Note [Pybind11 ABI constants] + # + # Pybind11 before 2.4 used to build an ABI strings using the following pattern: + # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_BUILD_TYPE}__" + # Since 2.4 compier type, stdlib and build abi parameters are also encoded like this: + # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_COMPILER_TYPE}{PYBIND11_STDLIB}{PYBIND11_BUILD_ABI}{PYBIND11_BUILD_TYPE}__" + # + # This was done in order to further narrow down the chances of compiler ABI incompatibility + # that can cause a hard to debug segfaults. + # For PyTorch extensions we want to relax those restrictions and pass compiler, stdlib and abi properties + # captured during PyTorch native library compilation in torch/csrc/Module.cpp + + for pname in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]: + pval = getattr(torch._C, f"_PYBIND11_{pname}") + if pval is not None and not IS_WINDOWS: + common_cflags.append(f'-DPYBIND11_{pname}=\\"{pval}\\"') + + common_cflags += [f'-I{include}' for include in user_includes] + common_cflags += [f'-isystem {include}' for include in system_includes] common_cflags += ['-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))] if IS_WINDOWS: cflags = common_cflags + COMMON_MSVC_FLAGS + extra_cflags - from distutils.spawn import _nt_quote_args # type: ignore cflags = _nt_quote_args(cflags) else: cflags = common_cflags + ['-fPIC', '-std=c++14'] + extra_cflags if with_cuda and IS_HIP_EXTENSION: - cuda_flags = ['-DWITH_HIP'] + cflags + COMMON_HIPCC_FLAGS + cuda_flags = ['-DWITH_HIP'] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS cuda_flags += extra_cuda_cflags cuda_flags += _get_rocm_arch_flags(cuda_flags) sources = [s if not _is_cuda_file(s) else os.path.abspath(os.path.join( - path, get_hip_file_path(os.path.relpath(s, path)))) + path, get_hip_file_path(os.path.relpath(s, path), is_pytorch_extension=True))) for s in sources] elif with_cuda: cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags() @@ -1623,25 +1783,22 @@ def object_file_path(source_file: str) -> str: if _is_cuda_file(source_file) and with_cuda: # Use a different object filename in case a C++ and CUDA file have # the same filename but different extension (.cpp vs. .cu). - target = '{}.cuda.o'.format(file_name) + target = f'{file_name}.cuda.o' else: - target = '{}.o'.format(file_name) + target = f'{file_name}.o' return target - objects = list(map(object_file_path, sources)) + objects = [object_file_path(src) for src in sources] + ldflags = ([] if is_standalone else [SHARED_FLAG]) + extra_ldflags - if IS_WINDOWS: - ldflags = ['/DLL'] + extra_ldflags - else: - ldflags = ['-shared'] + extra_ldflags # The darwin linker needs explicit consent to ignore unresolved symbols. if sys.platform.startswith('darwin'): ldflags.append('-undefined dynamic_lookup') elif IS_WINDOWS: ldflags = _nt_quote_args(ldflags) - ext = 'pyd' if IS_WINDOWS else 'so' - library_target = '{}.{}'.format(name, ext) + ext = EXEC_EXT if is_standalone else LIB_EXT + library_target = f'{name}{ext}' _write_ninja_file( path=path, @@ -1703,20 +1860,20 @@ def sanitize_flags(flags): # Version 1.3 is required for the `deps` directive. config = ['ninja_required_version = 1.3'] - config.append('cxx = {}'.format(compiler)) + config.append(f'cxx = {compiler}') if with_cuda: if IS_HIP_EXTENSION: nvcc = _join_rocm_home('bin', 'hipcc') else: nvcc = _join_cuda_home('bin', 'nvcc') - config.append('nvcc = {}'.format(nvcc)) + config.append(f'nvcc = {nvcc}') - flags = ['cflags = {}'.format(' '.join(cflags))] - flags.append('post_cflags = {}'.format(' '.join(post_cflags))) + flags = [f'cflags = {" ".join(cflags)}'] + flags.append(f'post_cflags = {" ".join(post_cflags)}') if with_cuda: - flags.append('cuda_cflags = {}'.format(' '.join(cuda_cflags))) - flags.append('cuda_post_cflags = {}'.format(' '.join(cuda_post_cflags))) - flags.append('ldflags = {}'.format(' '.join(ldflags))) + flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}') + flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}') + flags.append(f'ldflags = {" ".join(ldflags)}') # Turn into absolute paths so we can emit them into the ninja build # file wherever it is. @@ -1736,8 +1893,21 @@ def sanitize_flags(flags): if with_cuda: cuda_compile_rule = ['rule cuda_compile'] + nvcc_gendeps = '' + # --generate-dependencies-with-compile was added in CUDA 10.2. + # Compilation will work on earlier CUDA versions but header file + # dependencies are not correctly computed. + required_cuda_version = packaging.version.parse('10.2') + has_cuda_version = torch.version.cuda is not None + if has_cuda_version and packaging.version.parse(torch.version.cuda) >= required_cuda_version: + cuda_compile_rule.append(' depfile = $out.d') + cuda_compile_rule.append(' deps = gcc') + # Note: non-system deps with nvcc are only supported + # on Linux so use --generate-dependencies-with-compile + # to make this work on Windows too. + nvcc_gendeps = '--generate-dependencies-with-compile --dependency-output $out.d' cuda_compile_rule.append( - ' command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags') + f' command = $nvcc {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags') # Emit one build rule per source to enable incremental build. build = [] @@ -1749,7 +1919,7 @@ def sanitize_flags(flags): object_file = object_file.replace(':', '$:') source_file = source_file.replace(" ", "$ ") object_file = object_file.replace(" ", "$ ") - build.append('build {}: {} {}'.format(object_file, rule, source_file)) + build.append(f'build {object_file}: {rule} {source_file}') if library_target is not None: link_rule = ['rule link'] @@ -1760,15 +1930,13 @@ def sanitize_flags(flags): cl_path = os.path.dirname(cl_paths[0]).replace(':', '$:') else: raise RuntimeError("MSVC is required to load C++ extensions") - link_rule.append( - ' command = "{}/link.exe" $in /nologo $ldflags /out:$out'.format( - cl_path)) + link_rule.append(f' command = "{cl_path}/link.exe" $in /nologo $ldflags /out:$out') else: link_rule.append(' command = $cxx $in $ldflags -o $out') - link = ['build {}: link {}'.format(library_target, ' '.join(objects))] + link = [f'build {library_target}: link {" ".join(objects)}'] - default = ['default {}'.format(library_target)] + default = [f'default {library_target}'] else: link_rule, link, default = [], [], [] @@ -1780,7 +1948,7 @@ def sanitize_flags(flags): with open(path, 'w') as build_file: for block in blocks: lines = '\n'.join(block) - build_file.write('{}\n\n'.format(lines)) + build_file.write(f'{lines}\n\n') def _join_cuda_home(*paths) -> str: diff --git a/torch/utils/data/__init__.py b/torch/utils/data/__init__.py index 92f019bae1b9e..6aac75611a6fc 100644 --- a/torch/utils/data/__init__.py +++ b/torch/utils/data/__init__.py @@ -1,11 +1,13 @@ from .sampler import Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler, BatchSampler -from .dataset import Dataset, IterableDataset, TensorDataset, ConcatDataset, ChainDataset, Subset, random_split +from .dataset import (Dataset, IterableDataset, TensorDataset, ConcatDataset, ChainDataset, BufferedShuffleDataset, + Subset, random_split) from .distributed import DistributedSampler from .dataloader import DataLoader, _DatasetKind, get_worker_info - +from .datasets import (BatchIterableDataset, CollateIterableDataset, SamplerIterableDataset) __all__ = ['Sampler', 'SequentialSampler', 'RandomSampler', - 'SubsetRandomSampler', 'WeightedRandomSampler', 'BatchSampler' - 'DistributedSampler' 'Dataset', 'IterableDataset', 'TensorDataset', - 'ConcatDataset', 'ChainDataset', 'Subset', 'random_split' - 'DataLoader', '_DatasetKind', 'get_worker_info'] + 'SubsetRandomSampler', 'WeightedRandomSampler', 'BatchSampler', + 'DistributedSampler', 'Dataset', 'IterableDataset', 'TensorDataset', + 'ConcatDataset', 'ChainDataset', 'BufferedShuffleDataset', 'Subset', + 'random_split', 'DataLoader', '_DatasetKind', 'get_worker_info', + 'BatchIterableDataset', 'CollateIterableDataset', 'SamplerIterableDataset'] diff --git a/torch/utils/data/_utils/worker.py b/torch/utils/data/_utils/worker.py index 508c03950852c..a8ca66057b6bb 100644 --- a/torch/utils/data/_utils/worker.py +++ b/torch/utils/data/_utils/worker.py @@ -7,9 +7,10 @@ import torch import random import os -from collections import namedtuple +from dataclasses import dataclass from torch._six import queue from torch._utils import ExceptionWrapper +from typing import Union from . import signal_handling, MP_STATUS_CHECK_INTERVAL, IS_WINDOWS if IS_WINDOWS: @@ -23,7 +24,8 @@ class ManagerWatchdog(object): def __init__(self): self.manager_pid = os.getppid() - self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True) + # mypy cannot detect this code is windows only + self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True) # type: ignore self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) self.kernel32.OpenProcess.restype = HANDLE self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD) @@ -34,7 +36,7 @@ def __init__(self): self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid) if not self.manager_handle: - raise ctypes.WinError(ctypes.get_last_error()) + raise ctypes.WinError(ctypes.get_last_error()) # type: ignore self.manager_dead = False @@ -108,10 +110,14 @@ def get_worker_info(): r"""Dummy class used to signal the end of an IterableDataset""" -_IterableDatasetStopIteration = namedtuple('_IterableDatasetStopIteration', ['worker_id']) +@dataclass(frozen=True) +class _IterableDatasetStopIteration(object): + worker_id: int r"""Dummy class used to resume the fetching when worker reuse is enabled""" -_ResumeIteration = namedtuple('_ResumeIteration', []) +@dataclass(frozen=True) +class _ResumeIteration(object): + pass def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event, auto_collation, collate_fn, drop_last, seed, init_fn, worker_id, @@ -171,7 +177,7 @@ def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event, continue if isinstance(r, _ResumeIteration): # Acknowledge the main process - data_queue.put(r) + data_queue.put((r, None)) iteration_end = False # Recreate the fetcher for worker-reuse policy fetcher = _DatasetKind.create_fetcher( @@ -187,6 +193,7 @@ def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event, # processing steps. continue idx, index = r + data: Union[_IterableDatasetStopIteration, ExceptionWrapper] if init_exception is not None: data = init_exception init_exception = None diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index f784817a20e6f..a5eeeec671e39 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -5,6 +5,7 @@ in `./_utils/worker.py`. """ +import os import threading import itertools import warnings @@ -54,7 +55,7 @@ class _InfiniteConstantSampler(Sampler): r"""Analogous to ``itertools.repeat(None, None)``. Used as sampler for :class:`~torch.utils.data.IterableDataset`. - Arguments: + Args: data_source (Dataset): dataset to sample from """ @@ -77,7 +78,7 @@ class DataLoader(Generic[T_co]): See :py:mod:`torch.utils.data` documentation page for more details. - Arguments: + Args: dataset (Dataset): dataset from which to load the data. batch_size (int, optional): how many samples per batch to load (default: ``1``). @@ -109,11 +110,11 @@ class DataLoader(Generic[T_co]): worker_init_fn (callable, optional): If not ``None``, this will be called on each worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading. (default: ``None``) - prefetch_factor (int, optional, keyword-only arg): Number of sample loaded + prefetch_factor (int, optional, keyword-only arg): Number of samples loaded in advance by each worker. ``2`` means there will be a total of 2 * num_workers samples prefetched across all workers. (default: ``2``) persistent_workers (bool, optional): If ``True``, the data loader will not shutdown - the worker processes after a dataset has been consumed once. This allows to + the worker processes after a dataset has been consumed once. This allows to maintain the workers `Dataset` instances alive. (default: ``False``) @@ -139,6 +140,9 @@ class DataLoader(Generic[T_co]): See `Dataset Types`_ for more details on these two types of datasets and how :class:`~torch.utils.data.IterableDataset` interacts with `Multi-process data loading`_. + + .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and + :ref:`data-loading-randomness` notes for random seed related questions. """ dataset: Dataset[T_co] batch_size: Optional[int] @@ -287,10 +291,13 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1, self._iterator = None + self.check_worker_number_rationality() + def _get_iterator(self) -> '_BaseDataLoaderIter': if self.num_workers == 0: return _SingleProcessDataLoaderIter(self) else: + self.check_worker_number_rationality() return _MultiProcessingDataLoaderIter(self) @property @@ -301,10 +308,6 @@ def multiprocessing_context(self): def multiprocessing_context(self, multiprocessing_context): if multiprocessing_context is not None: if self.num_workers > 0: - if not multiprocessing._supports_context: - raise ValueError('multiprocessing_context relies on Python >= 3.4, with ' - 'support for different start methods') - if isinstance(multiprocessing_context, string_classes): valid_start_methods = multiprocessing.get_all_start_methods() if multiprocessing_context not in valid_start_methods: @@ -396,6 +399,83 @@ def __len__(self) -> int: else: return len(self._index_sampler) + def check_worker_number_rationality(self): + # This function check whether the dataloader's worker number is rational based on + # current system's resource. Current rule is that if the number of workers this + # Dataloader will create is bigger than the number of logical cpus that is allowed to + # use, than we will pop up a warning to let user pay attention. + # + # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2 + # threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current + # DataLoader process can use half of them which is 32, then the rational max number of + # worker that initiated from this process is 32. + # Now, let's say the created DataLoader has num_works = 40, which is bigger than 32. + # So the warning message is triggered to notify the user to lower the worker number if + # necessary. + # + # + # [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is + # available (available in most of Linux system, but not OSX and Windows). + # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but + # it doesn't repect cpuset. + # We don't take threading into account since each worker process is single threaded + # at this time. + # + # We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc) + # other than `torch.set_num_threads` to 1 in the worker process, if the passing + # in functions use 3rd party modules that rely on those threading flags to determine + # how many thread to create (eg. numpy, etc), then it is caller's responsibility to + # set those flags correctly. + def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked): + + suggested_max_worker_msg = (( + "Our suggested max number of worker in current system is {}{}, which is smaller " + "than what this DataLoader is going to create.").format( + num_worker_suggest, + ("" if cpuset_checked else " (`cpuset` is not taken into account)")) + ) if num_worker_suggest is not None else ( + "DataLoader is not able to compute a suggested max number of worker in current system.") + + warn_msg = ( + "This DataLoader will create {} worker processes in total. {} " + "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, " + "lower the worker number to avoid potential slowness/freeze if necessary.").format( + num_worker_created, + suggested_max_worker_msg) + return warn_msg + + if not self.num_workers or self.num_workers == 0: + return + + # try to compute a suggested max number of worker based on system's resource + max_num_worker_suggest = None + cpuset_checked = False + if hasattr(os, 'sched_getaffinity'): + try: + max_num_worker_suggest = len(os.sched_getaffinity(0)) + cpuset_checked = True + except Exception: + pass + if max_num_worker_suggest is None: + # os.cpu_count() could return Optional[int] + # get cpu count first and check None in order to satify mypy check + cpu_count = os.cpu_count() + if cpu_count is not None: + max_num_worker_suggest = cpu_count + + if max_num_worker_suggest is None: + warnings.warn(_create_warning_msg( + max_num_worker_suggest, + self.num_workers, + cpuset_checked)) + return + + if self.num_workers > max_num_worker_suggest: + warnings.warn(_create_warning_msg( + max_num_worker_suggest, + self.num_workers, + cpuset_checked)) + class _BaseDataLoaderIter(object): def __init__(self, loader: DataLoader) -> None: @@ -414,6 +494,7 @@ def __init__(self, loader: DataLoader) -> None: self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item() self._persistent_workers = loader.persistent_workers self._num_yielded = 0 + self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__) def __iter__(self) -> '_BaseDataLoaderIter': return self @@ -430,22 +511,23 @@ def _next_data(self): raise NotImplementedError def __next__(self) -> Any: - if self._sampler_iter is None: - self._reset() - data = self._next_data() - self._num_yielded += 1 - if self._dataset_kind == _DatasetKind.Iterable and \ - self._IterableDataset_len_called is not None and \ - self._num_yielded > self._IterableDataset_len_called: - warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} " - "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called, - self._num_yielded) - if self._num_workers > 0: - warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the " - "IterableDataset replica at each worker. Please see " - "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.") - warnings.warn(warn_msg) - return data + with torch.autograd.profiler.record_function(self._profile_name): + if self._sampler_iter is None: + self._reset() + data = self._next_data() + self._num_yielded += 1 + if self._dataset_kind == _DatasetKind.Iterable and \ + self._IterableDataset_len_called is not None and \ + self._num_yielded > self._IterableDataset_len_called: + warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} " + "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called, + self._num_yielded) + if self._num_workers > 0: + warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the " + "IterableDataset replica at each worker. Please see " + "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.") + warnings.warn(warn_msg) + return data next = __next__ # Python 2 compatibility @@ -534,46 +616,72 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): # simple things like acquiring an internal lock of a queue may hang. # Therefore, in this case, we actually need to prevent `__del__` from # being executed, and rely on the automatic termination of daemonic - # children. Thus, we register an `atexit` hook that sets a global flag + # children. + # + # Thus, we register an `atexit` hook that sets a global flag # `_utils.python_exit_status`. Since `atexit` hooks are executed in the # reverse order of registration, we are guaranteed that this flag is - # set before library resources we use are freed. (Hooks freeing those - # resources are registered at importing the Python core libraries at - # the top of this file.) So in `__del__`, we check if - # `_utils.python_exit_status` is set or `None` (freed), and perform - # no-op if so. + # set before library resources we use are freed (which, at least in + # CPython, is done via an `atexit` handler defined in + # `multiprocessing/util.py` + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362 + # registered when an object requiring this mechanism is first + # created, e.g., `mp.Queue` + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103 + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29 + # ) + # + # So in `__del__`, we check if `_utils.python_exit_status` is set or + # `None` (freed), and perform no-op if so. + # + # However, simply letting library clean-up codes run can also be bad, + # because such codes (i.e., `multiprocessing.util._exit_function()`) + # include join putting threads for `mp.Queue`, which can be blocking. + # Hence, the main process putting threads are called with + # `cancel_join_thread` at creation. See later section + # [ 3b. A process won't hang when putting into a queue; ] + # for more details. + # + # Here are two example cases where library clean-up codes can run + # before `__del__` is called: # - # Another problem with `__del__` is also related to the library cleanup - # calls. When a process ends, it shuts the all its daemonic children - # down with a SIGTERM (instead of joining them without a timeout). - # Simiarly for threads, but by a different mechanism. This fact, - # together with a few implementation details of multiprocessing, forces - # us to make workers daemonic. All of our problems arise when a - # DataLoader is used in a subprocess, and are caused by multiprocessing - # code which looks more or less like this: + # 1. If we hold onto a reference to the iterator, it more often + # than not tries to do `multiprocessing` library cleaning before + # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666) + # and thus prevents our cleaning-up code to run first. # - # try: - # your_function_using_a_dataloader() - # finally: - # multiprocessing.util._exit_function() + # 2. A similar issue araises when a `DataLoader` is used in a subprocess. + # When a process ends, it shuts the all its daemonic children + # down with a SIGTERM (instead of joining them without a timeout). + # Simiarly for threads, but by a different mechanism. This fact, + # together with a few implementation details of multiprocessing, forces + # us to make workers daemonic. All of our problems arise when a + # DataLoader is used in a subprocess, and are caused by multiprocessing + # code which looks more or less like this: # - # The joining/termination mentioned above happens inside - # `_exit_function()`. Now, if `your_function_using_a_dataloader()` - # throws, the stack trace stored in the exception will prevent the - # frame which uses `DataLoaderIter` to be freed. If the frame has any - # reference to the `DataLoaderIter` (e.g., in a method of the iter), - # its `__del__`, which starts the shutdown procedure, will not be - # called. That, in turn, means that workers aren't notified. Attempting - # to join in `_exit_function` will then result in a hang. + # try: + # your_function_using_a_dataloader() + # finally: + # multiprocessing.util._exit_function() # - # For context, `_exit_function` is also registered as an `atexit` call. - # So it is unclear to me (@ssnl) why this is needed in a finally block. - # The code dates back to 2008 and there is no comment on the original - # PEP 371 or patch https://bugs.python.org/issue3050 (containing both - # the finally block and the `atexit` registration) that explains this. + # The joining/termination mentioned above happens inside + # `_exit_function()`. Now, if `your_function_using_a_dataloader()` + # throws, the stack trace stored in the exception will prevent the + # frame which uses `DataLoaderIter` to be freed. If the frame has any + # reference to the `DataLoaderIter` (e.g., in a method of the iter), + # its `__del__`, which starts the shutdown procedure, will not be + # called. That, in turn, means that workers aren't notified. Attempting + # to join in `_exit_function` will then result in a hang. # - # Another choice is to just shutdown workers with logic in 1 above - # whenever we see an error in `next`. This isn't ideal because + # For context, `_exit_function` is also registered as an `atexit` call. + # So it is unclear to me (@ssnl) why this is needed in a finally block. + # The code dates back to 2008 and there is no comment on the original + # PEP 371 or patch https://bugs.python.org/issue3050 (containing both + # the finally block and the `atexit` registration) that explains this. + # + # + # Finally, another choice is to just shutdown workers with logic in 1 + # above whenever we see an error in `next`. This isn't ideal because # a. It prevents users from using try-catch to resume data loading. # b. It doesn't prevent hanging if users have references to the # iterator. @@ -621,30 +729,33 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): # We use `mp.Queue` which has a separate background thread to put # objects from an unbounded buffer array. The background thread is # daemonic and usually automatically joined when the process - # exits. - # - # However, in case that the receiver has ended abruptly while - # reading from the pipe, the join will hang forever. Therefore, - # for both `worker_result_queue` (worker -> main process/pin_memory_thread) - # and each `index_queue` (main process -> worker), we use - # `q.cancel_join_thread()` in sender process before any `q.put` to - # prevent this automatic join. + # *exits*. # - # Moreover, having all queues called `cancel_join_thread` makes - # implementing graceful shutdown logic in `__del__` much easier. - # It won't need to get from any queue, which would also need to be - # guarded by periodic status checks. + # In case that the receiver has ended abruptly while + # reading from the pipe, the join will hang forever. The usual + # solution for this in Python is calling `q.cancel_join_thread`, + # which prevents automatically joining it when finalizing + # (exiting). # # Nonetheless, `cancel_join_thread` must only be called when the # queue is **not** going to be read from or write into by another # process, because it may hold onto a lock or leave corrupted data # in the queue, leading other readers/writers to hang. # - # `pin_memory_thread`'s `data_queue` is a `queue.Queue` that does - # a blocking `put` if the queue is full. So there is no above - # problem, but we do need to wrap the `put` in a loop that breaks - # not only upon success, but also when the main process stops - # reading, i.e., is shutting down. + # Hence, + # + For worker processes, we only do so (for their output + # queues, i.e., `worker_result_queue`) before exiting. + # + For `pin_memory_thread`, its output queue `data_queue` is a + # `queue.Queue` that does blocking `put` if the queue is full. + # So there is no above problem, but as a result, in + # `_pin_memory_loop`, we do need to wrap the `put` in a loop + # that breaks not only upon success, but also when the main + # process stops reading, i.e., is shutting down. + # + For loader process, we `cancel_join_thread()` for all + # `_index_queues` because the whole purpose of workers and + # `pin_memory_thread` is to serve the loader process. If + # loader process is already exiting, we don't really care if + # the queues are corrupted. # # # Now let's get back to 1: @@ -783,7 +894,9 @@ def __init__(self, loader): for i in range(self._num_workers): # No certainty which module multiprocessing_context is index_queue = multiprocessing_context.Queue() # type: ignore - # index_queue.cancel_join_thread() + # Need to `cancel_join_thread` here! + # See sections (2) and (3b) above. + index_queue.cancel_join_thread() w = multiprocessing_context.Process( target=_utils.worker._worker_loop, args=(self._dataset_kind, self._dataset, index_queue, @@ -840,7 +953,7 @@ def _reset(self, loader, first_iter=False): # contains all `True`s if not using an iterable-style dataset # (i.e., if kind != Iterable). # Not that this indicates that a worker still has work to do *for this epoch*. - # It does not mean that a worker is dead. In case of `_persistent_workers`, + # It does not mean that a worker is dead. In case of `_persistent_workers`, # the worker will be reset to available in the next epoch. self._workers_status = [True for i in range(self._num_workers)] # We resume the prefetching in case it was enabled @@ -849,8 +962,9 @@ def _reset(self, loader, first_iter=False): self._index_queues[idx].put(_utils.worker._ResumeIteration()) resume_iteration_cnt = self._num_workers while resume_iteration_cnt > 0: - data = self._get_data() - if isinstance(data, _utils.worker._ResumeIteration): + return_idx, return_data = self._get_data() + if isinstance(return_idx, _utils.worker._ResumeIteration): + assert return_data is None resume_iteration_cnt -= 1 # prime the prefetch loop for _ in range(self._prefetch_factor * self._num_workers): @@ -1150,6 +1264,9 @@ def _shutdown_workers(self): if not self._shutdown: self._shutdown = True try: + # Normal exit when last reference is gone / iterator is depleted. + # See (1) and the second half of the note. + # Exit `pin_memory_thread` first because exiting workers may leave # corrupted data in `worker_result_queue` which `pin_memory_thread` # reads from. @@ -1174,13 +1291,10 @@ def _shutdown_workers(self): if self._persistent_workers or self._workers_status[worker_id]: self._mark_worker_as_unavailable(worker_id, shutdown=True) for w in self._workers: + # We should be able to join here, but in case anything went + # wrong, we set a timeout and if the workers fail to join, + # they are killed in the `finally` block. w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL) - if w.is_alive(): - # Existing mechanisms try to make the workers exit - # peacefully, but in case that we unfortunately reach - # here, which we shouldn't, (e.g., pytorch/pytorch#39570), - # we kill the worker. - w.terminate() for q in self._index_queues: q.cancel_join_thread() q.close() @@ -1198,6 +1312,13 @@ def _shutdown_workers(self): if self._worker_pids_set: _utils.signal_handling._remove_worker_pids(id(self)) self._worker_pids_set = False + for w in self._workers: + if w.is_alive(): + # Existing mechanisms try to make the workers exit + # peacefully, but in case that we unfortunately reach + # here, which we shouldn't, (e.g., pytorch/pytorch#39570), + # we kill the worker. + w.terminate() def __del__(self): self._shutdown_workers() diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py index 7f466d18dcc85..0bef57c97629d 100644 --- a/torch/utils/data/dataset.py +++ b/torch/utils/data/dataset.py @@ -1,4 +1,5 @@ import bisect +import random import warnings from torch._utils import _accumulate @@ -157,13 +158,13 @@ class TensorDataset(Dataset[Tuple[Tensor, ...]]): Each sample will be retrieved by indexing tensors along the first dimension. - Arguments: + Args: *tensors (Tensor): tensors that have the same size of the first dimension. """ tensors: Tuple[Tensor, ...] def __init__(self, *tensors: Tensor) -> None: - assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors) + assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors" self.tensors = tensors def __getitem__(self, index): @@ -178,7 +179,7 @@ class ConcatDataset(Dataset[T_co]): This class is useful to assemble different existing datasets. - Arguments: + Args: datasets (sequence): List of datasets to be concatenated """ datasets: List[Dataset[T_co]] @@ -231,7 +232,7 @@ class ChainDataset(IterableDataset): chainning operation is done on-the-fly, so concatenating large-scale datasets with this class will be efficient. - Arguments: + Args: datasets (iterable of IterableDataset): datasets to be chained together """ def __init__(self, datasets: Iterable[Dataset]) -> None: @@ -253,11 +254,68 @@ def __len__(self): return total +class BufferedShuffleDataset(IterableDataset[T_co]): + r"""Dataset shuffled from the original dataset. + + This class is useful to shuffle an existing instance of an IterableDataset. + The buffer with `buffer_size` is filled with the items from the dataset first. Then, + each item will be yielded from the buffer by reservoir sampling via iterator. + + `buffer_size` is required to be larger than 0. For `buffer_size == 1`, the + dataset is not shuffled. In order to fully shuffle the whole dataset, `buffer_size` + is required to be greater than or equal to the size of dataset. + + When it is used with :class:`~torch.utils.data.DataLoader`, each item in the + dataset will be yielded from the :class:`~torch.utils.data.DataLoader` iterator. + And, the method to set up a random seed is different based on :attr:`num_workers`. + + For single-process mode (:attr:`num_workers == 0`), the random seed is required to + be set before the :class:`~torch.utils.data.DataLoader` in the main process. + + >>> ds = BufferedShuffleDataset(dataset) + >>> random.seed(...) + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) + + For multi-process mode (:attr:`num_workers > 0`), the random seed is set by a callable + function in each worker. + + >>> ds = BufferedShuffleDataset(dataset) + >>> def init_fn(worker_id): + ... random.seed(...) + >>> print(list(torch.utils.data.DataLoader(ds, ..., num_workers=n, worker_init_fn=init_fn))) + + Args: + dataset (IterableDataset): The original IterableDataset. + buffer_size (int): The buffer size for shuffling. + """ + dataset: IterableDataset[T_co] + buffer_size: int + + def __init__(self, dataset: IterableDataset[T_co], buffer_size: int) -> None: + super(BufferedShuffleDataset, self).__init__() + assert buffer_size > 0, "buffer_size should be larger than 0" + self.dataset = dataset + self.buffer_size = buffer_size + + def __iter__(self) -> Iterator[T_co]: + buf: List[T_co] = [] + for x in self.dataset: + if len(buf) == self.buffer_size: + idx = random.randint(0, self.buffer_size - 1) + yield buf[idx] + buf[idx] = x + else: + buf.append(x) + random.shuffle(buf) + while buf: + yield buf.pop() + + class Subset(Dataset[T_co]): r""" Subset of a dataset at specified indices. - Arguments: + Args: dataset (Dataset): The whole Dataset indices (sequence): Indices in the whole set selected for subset """ @@ -283,7 +341,7 @@ def random_split(dataset: Dataset[T], lengths: Sequence[int], >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42)) - Arguments: + Args: dataset (Dataset): Dataset to be split lengths (sequence): lengths of splits to be produced generator (Generator): Generator used for the random permutation. diff --git a/torch/utils/data/datasets/__init__.py b/torch/utils/data/datasets/__init__.py new file mode 100644 index 0000000000000..6011fcf9e5256 --- /dev/null +++ b/torch/utils/data/datasets/__init__.py @@ -0,0 +1,8 @@ +from .batchdataset import BatchIterableDataset +from .collatedataset import CollateIterableDataset +from .samplerdataset import SamplerIterableDataset +from .listdirfilesdataset import ListDirFilesIterableDataset +from .loadfilesfromdiskdataset import LoadFilesFromDiskIterableDataset + +__all__ = ['BatchIterableDataset', 'CollateIterableDataset', 'ListDirFilesIterableDataset', + 'LoadFilesFromDiskIterableDataset', 'SamplerIterableDataset'] diff --git a/torch/utils/data/datasets/batchdataset.py b/torch/utils/data/datasets/batchdataset.py new file mode 100644 index 0000000000000..fe725379bac76 --- /dev/null +++ b/torch/utils/data/datasets/batchdataset.py @@ -0,0 +1,57 @@ +from torch.utils.data import IterableDataset +from typing import TypeVar, Optional, Iterator, List, Sized + +T_co = TypeVar('T_co', covariant=True) + + +class BatchIterableDataset(IterableDataset[List[T_co]]): + r""" :class:`BatchIterableDataset`. + + IterableDataset to create mini-batches of data. An outer dimension will be added as + `batch_size` if `drop_last` is set to `True`, or `length % batch_size` for the + last batch if `drop_last` is set to `False`. + args: + dataset: IterableDataset being batched + batch_size: The size of each batch + drop_last: Option to drop the last batch if it's not full + """ + dataset: IterableDataset[T_co] + batch_size: int + drop_last: bool + length: Optional[int] + + def __init__(self, + dataset: IterableDataset[T_co], + *, + batch_size: int, + drop_last: bool = False, + ) -> None: + assert batch_size > 0, "Batch size is required to be larger than 0!" + super(BatchIterableDataset, self).__init__() + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + self.length = None + + def __iter__(self) -> Iterator[List[T_co]]: + batch: List[T_co] = [] + for x in self.dataset: + batch.append(x) + if len(batch) == self.batch_size: + yield batch + batch.clear() + if len(batch) > 0: + if not self.drop_last: + yield batch + batch.clear() + + def __len__(self) -> int: + if self.length is not None: + return self.length + if isinstance(self.dataset, Sized) and len(self.dataset) >= 0: + if self.drop_last: + self.length = len(self.dataset) // self.batch_size + else: + self.length = (len(self.dataset) + self.batch_size - 1) // self.batch_size + return self.length + raise NotImplementedError diff --git a/torch/utils/data/datasets/collatedataset.py b/torch/utils/data/datasets/collatedataset.py new file mode 100644 index 0000000000000..0ea5df0667f51 --- /dev/null +++ b/torch/utils/data/datasets/collatedataset.py @@ -0,0 +1,61 @@ +from torch.utils.data import IterableDataset, _utils +from typing import TypeVar, Callable, Iterator, Sized + +T_co = TypeVar('T_co', covariant=True) +S_co = TypeVar('S_co', covariant=True) + + +class CollateIterableDataset(IterableDataset[T_co]): + r""" :class:`CollateIterableDataset`. + + IterableDataset to collate samples from dataset to Tensor(s) by `util_.collate.default_collate`, + or customized Data Structure by collate_fn. + args: + dataset: IterableDataset being collated + collate_fn: Customized collate function to collect and combine data or a batch of data. + Default function collates to Tensor(s) based on data type. + + Example: Convert integer data to float Tensor + >>> class MyIterableDataset(torch.utils.data.IterableDataset): + ... def __init__(self, start, end): + ... super(MyIterableDataset).__init__() + ... assert end > start, "this example code only works with end >= start" + ... self.start = start + ... self.end = end + ... + ... def __iter__(self): + ... return iter(range(self.start, self.end)) + ... + ... def __len__(self): + ... return self.end - self.start + ... + >>> ds = MyIterableDataset(start=3, end=7) + >>> print(list(ds)) + [3, 4, 5, 6] + + >>> def collate_fn(batch): + ... return torch.tensor(batch, dtype=torch.float) + ... + >>> collated_ds = CollateIterableDataset(ds, collate_fn=collate_fn) + >>> print(list(collated_ds)) + [tensor(3.), tensor(4.), tensor(5.), tensor(6.)] + """ + def __init__(self, + dataset: IterableDataset[S_co], + *, + collate_fn: Callable[[S_co], T_co] = _utils.collate.default_collate, + ) -> None: + super(CollateIterableDataset, self).__init__() + self.dataset = dataset + self.collate_fn = collate_fn + + def __iter__(self) -> Iterator[T_co]: + for data in self.dataset: + yield self.collate_fn(data) + + # `__len__` is attached to class not instance + # Assume dataset has implemented `__len__` or raise NotImplementedError + def __len__(self) -> int: + if isinstance(self.dataset, Sized) and len(self.dataset) >= 0: + return len(self.dataset) + raise NotImplementedError diff --git a/torch/utils/data/datasets/common.py b/torch/utils/data/datasets/common.py new file mode 100644 index 0000000000000..c28e01eb3a836 --- /dev/null +++ b/torch/utils/data/datasets/common.py @@ -0,0 +1,53 @@ +import os +import fnmatch +import warnings +from typing import List, Union, Iterable + + +def match_masks(name : str, masks : Union[str, List[str]]) -> bool: + # empty mask matches any input name + if not masks: + return True + + if isinstance(masks, str): + return fnmatch.fnmatch(name, masks) + + for mask in masks: + if fnmatch.fnmatch(name, mask): + return True + return False + + +def get_file_pathnames_from_root( + root: str, + masks: Union[str, List[str]], + recursive: bool = False, + abspath: bool = False) -> Iterable[str]: + + # print out an error message and raise the error out + def onerror(err : OSError): + warnings.warn(err.filename + " : " + err.strerror) + raise err + + for path, dirs, files in os.walk(root, onerror=onerror): + if abspath: + path = os.path.abspath(path) + for f in files: + if match_masks(f, masks): + yield os.path.join(path, f) + if not recursive: + break + + +def get_file_binaries_from_pathnames(pathnames : Iterable): + + if not isinstance(pathnames, Iterable): + warnings.warn("get_file_binaries_from_pathnames needs the input be an Iterable") + raise TypeError + + for pathname in pathnames: + if not isinstance(pathname, str): + warnings.warn("file pathname must be string type, but got {}".format(type(pathname))) + raise TypeError + + yield (pathname, open(pathname, 'rb')) diff --git a/torch/utils/data/datasets/listdirfilesdataset.py b/torch/utils/data/datasets/listdirfilesdataset.py new file mode 100644 index 0000000000000..376971cc1adc6 --- /dev/null +++ b/torch/utils/data/datasets/listdirfilesdataset.py @@ -0,0 +1,36 @@ +from torch.utils.data.dataset import IterableDataset +from torch.utils.data.datasets.common import get_file_pathnames_from_root + +from typing import List, Union, Iterator + +class ListDirFilesIterableDataset(IterableDataset): + r""" :class:`ListDirFilesIterableDataset` + + IterableDataset to load file pathname(s) (path + filename), yield pathname from given disk root dir. + args: + root : root dir + mask : a unix style filter string or string list for filtering file name(s) + abspath : whether to return relative pathname or absolute pathname + length : a nominal length of the dataset + """ + + def __init__( + self, + root: str = '.', + masks: Union[str, List[str]] = '*.tar', + *, + abspath: bool = False, + length: int = -1): + super().__init__() + self.root : str = root + self.masks : Union[str, List[str]] = masks + self.abspath : bool = abspath + self.length : int = length + + def __iter__(self) -> Iterator[str] : + yield from get_file_pathnames_from_root(self.root, self.masks, self.abspath) + + def __len__(self): + if self.length == -1: + raise NotImplementedError + return self.length diff --git a/torch/utils/data/datasets/loadfilesfromdiskdataset.py b/torch/utils/data/datasets/loadfilesfromdiskdataset.py new file mode 100644 index 0000000000000..fdf8acb07ca1a --- /dev/null +++ b/torch/utils/data/datasets/loadfilesfromdiskdataset.py @@ -0,0 +1,30 @@ +from torch.utils.data.dataset import IterableDataset +from torch.utils.data.datasets.common import get_file_binaries_from_pathnames + +from typing import Iterable, Iterator + +class LoadFilesFromDiskIterableDataset(IterableDataset): + r""" :class:`LoadFilesFromDiskIterableDataset`. + + IterableDataset to load file binary streams from given pathnames, + yield pathname and binary stream in a tuple. + args: + dataset: Iterable dataset that provides pathnames + length: a nominal length of the dataset + """ + + def __init__( + self, + dataset : Iterable, + length : int = -1): + super().__init__() + self.dataset : Iterable = dataset + self.length : int = length + + def __iter__(self) -> Iterator[tuple] : + yield from get_file_binaries_from_pathnames(self.dataset) + + def __len__(self): + if self.length == -1: + raise NotImplementedError + return self.length diff --git a/torch/utils/data/datasets/samplerdataset.py b/torch/utils/data/datasets/samplerdataset.py new file mode 100644 index 0000000000000..449755be0d80e --- /dev/null +++ b/torch/utils/data/datasets/samplerdataset.py @@ -0,0 +1,38 @@ +from torch.utils.data import IterableDataset, Sampler, SequentialSampler +from typing import TypeVar, Type, Iterator, Sized + +T_co = TypeVar('T_co', covariant=True) + + +class SamplerIterableDataset(IterableDataset[T_co]): + r""" :class:`SamplerIterableDataset`. + + IterableDataset to generate sample elements. + args: + dataset: IterableDataset sampled from + sampler: Sampler class to genereate sample elements from input dataset. + Default is :class:`SequentialSampler` for IterableDataset + """ + dataset: IterableDataset + sampler: Sampler + + def __init__(self, + dataset: IterableDataset, + *, + sampler: Type[Sampler] = SequentialSampler, + **kwargs + ) -> None: + assert isinstance(dataset, Sized), \ + "Sampler class requires input dataset implemented `__len__`" + self.dataset = dataset + # https://github.com/python/mypy/pull/9629 will solve + self.sampler = sampler(data_source=self.dataset, **kwargs) # type: ignore + + def __iter__(self) -> Iterator[T_co]: + return iter(self.sampler) + + def __len__(self) -> int: + # Dataset has been tested as `Sized` + if isinstance(self.sampler, Sized) and len(self.sampler) >= 0: + return len(self.sampler) + raise NotImplementedError diff --git a/torch/utils/data/distributed.py b/torch/utils/data/distributed.py index b4f5ac3995046..8ce61d9948c57 100644 --- a/torch/utils/data/distributed.py +++ b/torch/utils/data/distributed.py @@ -21,10 +21,10 @@ class DistributedSampler(Sampler[T_co]): .. note:: Dataset is assumed to be of constant size. - Arguments: + Args: dataset: Dataset used for sampling. num_replicas (int, optional): Number of processes participating in - distributed training. By default, :attr:`rank` is retrieved from the + distributed training. By default, :attr:`world_size` is retrieved from the current distributed group. rank (int, optional): Rank of the current process within :attr:`num_replicas`. By default, :attr:`rank` is retrieved from the current distributed @@ -67,6 +67,10 @@ def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + "Invalid rank {}, rank should be in the interval" + " [0, {}]".format(rank, num_replicas - 1)) self.dataset = dataset self.num_replicas = num_replicas self.rank = rank @@ -74,15 +78,17 @@ def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, self.drop_last = drop_last # If the dataset length is evenly divisible by # of replicas, then there # is no need to drop any data, since the dataset will be split equally. - if self.drop_last and len(self.dataset) % self.num_replicas != 0: + if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore # Split to nearest available length that is evenly divisible. # This is to ensure each rank receives the same amount of data when # using this Sampler. self.num_samples = math.ceil( - (len(self.dataset) - self.num_replicas) / self.num_replicas + # `type:ignore` is required because Dataset cannot provide a default __len__ + # see NOTE in pytorch/torch/utils/data/sampler.py + (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore ) else: - self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle self.seed = seed @@ -92,13 +98,17 @@ def __iter__(self) -> Iterator[T_co]: # deterministically shuffle based on epoch and seed g = torch.Generator() g.manual_seed(self.seed + self.epoch) - indices = torch.randperm(len(self.dataset), generator=g).tolist() + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore else: - indices = list(range(len(self.dataset))) + indices = list(range(len(self.dataset))) # type: ignore if not self.drop_last: # add extra samples to make it evenly divisible - indices += indices[:(self.total_size - len(indices))] + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] else: # remove tail of data to make it evenly divisible. indices = indices[:self.total_size] @@ -119,7 +129,7 @@ def set_epoch(self, epoch: int) -> None: use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering. - Arguments: + Args: epoch (int): Epoch number. """ self.epoch = epoch diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index 8cc650e899488..e48ad64fdc9b2 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -55,7 +55,7 @@ def __iter__(self) -> Iterator[T_co]: class SequentialSampler(Sampler[int]): r"""Samples elements sequentially, always in the same order. - Arguments: + Args: data_source (Dataset): dataset to sample from """ data_source: Sized @@ -74,7 +74,7 @@ class RandomSampler(Sampler[int]): r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. If with replacement, then user can specify :attr:`num_samples` to draw. - Arguments: + Args: data_source (Dataset): dataset to sample from replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False`` num_samples (int): number of samples to draw, default=`len(dataset)`. This argument @@ -131,7 +131,7 @@ def __len__(self): class SubsetRandomSampler(Sampler[int]): r"""Samples elements randomly from a given list of indices, without replacement. - Arguments: + Args: indices (sequence): a sequence of indices generator (Generator): Generator used in sampling. """ diff --git a/torch/utils/file_baton.py b/torch/utils/file_baton.py index 7b15ceb70f41e..83cb85699c6d5 100644 --- a/torch/utils/file_baton.py +++ b/torch/utils/file_baton.py @@ -1,12 +1,6 @@ import os -import sys import time -if sys.version < '3.3': - # Note(jiayq): in Python 2, FileExistsError is not defined and the - # error manifests it as OSError. - FileExistsError = OSError - class FileBaton: '''A primitive, file-based synchronization utility.''' diff --git a/torch/utils/hipify/__init__.py b/torch/utils/hipify/__init__.py index e69de29bb2d1d..58f3ace6c03d0 100644 --- a/torch/utils/hipify/__init__.py +++ b/torch/utils/hipify/__init__.py @@ -0,0 +1 @@ +from .version import __version__ diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index c6fdcf49e46c7..25069b584b114 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -552,26 +552,26 @@ ("vector_types.h", ("hip/hip_vector_types.h", CONV_INCLUDE, API_RUNTIME)), ("cublas.h", ("rocblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)), ("cublas_v2.h", ("rocblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)), - ("curand.h", ("hiprand.h", CONV_INCLUDE_CUDA_MAIN_H, API_RAND)), - ("curand_kernel.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)), - ("curand_discrete.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)), - ("curand_discrete2.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)), - ("curand_globals.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)), - ("curand_lognormal.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)), - ("curand_mrg32k3a.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)), - ("curand_mtgp32.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)), - ("curand_mtgp32_host.h", ("hiprand_mtgp32_host.h", CONV_INCLUDE, API_RAND)), - ("curand_mtgp32_kernel.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)), + ("curand.h", ("hiprand/hiprand.h", CONV_INCLUDE_CUDA_MAIN_H, API_RAND)), + ("curand_kernel.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), + ("curand_discrete.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), + ("curand_discrete2.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), + ("curand_globals.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), + ("curand_lognormal.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), + ("curand_mrg32k3a.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), + ("curand_mtgp32.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), + ("curand_mtgp32_host.h", ("hiprand/hiprand_mtgp32_host.h", CONV_INCLUDE, API_RAND)), + ("curand_mtgp32_kernel.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), ( "curand_mtgp32dc_p_11213.h", - ("rocrand_mtgp32_11213.h", CONV_INCLUDE, API_RAND), - ), - ("curand_normal.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)), - ("curand_normal_static.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)), - ("curand_philox4x32_x.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)), - ("curand_poisson.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)), - ("curand_precalc.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)), - ("curand_uniform.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)), + ("rocrand/rocrand_mtgp32_11213.h", CONV_INCLUDE, API_RAND), + ), + ("curand_normal.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), + ("curand_normal_static.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), + ("curand_philox4x32_x.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), + ("curand_poisson.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), + ("curand_precalc.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), + ("curand_uniform.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), ("cusparse.h", ("hipsparse.h", CONV_INCLUDE, API_RAND)), ("cufft.h", ("hipfft.h", CONV_INCLUDE, API_BLAS)), ("cufftXt.h", ("hipfft.h", CONV_INCLUDE, API_BLAS)), @@ -586,7 +586,7 @@ ("cub/device/device_radix_sort.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), ("cub/device/device_reduce.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), ("cub/device/device_scan.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), - ("nvToolsExt.h", ("roctx.h", CONV_INCLUDE, API_ROCTX)), + ("nvToolsExt.h", ("roctracer/roctx.h", CONV_INCLUDE, API_ROCTX)), ] ) @@ -7746,6 +7746,10 @@ [ ("cusparseStatus_t", ("hipsparseStatus_t", CONV_MATH_FUNC, API_SPARSE)), ("cusparseHandle_t", ("hipsparseHandle_t", CONV_MATH_FUNC, API_SPARSE)), + ( + "CUSPARSE_POINTER_MODE_HOST", + ("HIPSPARSE_POINTER_MODE_HOST", CONV_NUMERIC_LITERAL, API_SPARSE), + ), ("cusparseOperation_t", ("hipsparseOperation_t", CONV_TYPE, API_SPARSE)), ( "cusparseCreateMatDescr", @@ -7767,6 +7771,17 @@ "cusparseXcsrsort_bufferSizeExt", ("hipsparseXcsrsort_bufferSizeExt", CONV_MATH_FUNC, API_SPARSE), ), + ("cusparseCreateCsrgemm2Info", ("hipsparseCreateCsrgemm2Info", CONV_MATH_FUNC, API_SPARSE)), + ( + "cusparseDestroyCsrgemm2Info", + ("hipsparseDestroyCsrgemm2Info", CONV_MATH_FUNC, API_SPARSE), + ), + ("cusparseXcsrgemm2Nnz", ("hipsparseXcsrgemm2Nnz", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseDcsrgemm2_bufferSizeExt", ("hipsparseDcsrgemm2_bufferSizeExt", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseScsrgemm2_bufferSizeExt", ("hipsparseScsrgemm2_bufferSizeExt", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseDcsrgemm2", ("hipsparseDcsrgemm2", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseScsrgemm2", ("hipsparseScsrgemm2", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseSetPointerMode", ("hipsparseSetPointerMode", CONV_MATH_FUNC, API_SPARSE)), ("cusparseXcsrsort", ("hipsparseXcsrsort", CONV_MATH_FUNC, API_SPARSE)), ( "cusparseXcoosort_bufferSizeExt", @@ -8103,6 +8118,7 @@ ("setCurrentCUDAStream", ("setCurrentHIPStream", API_C10)), ("cuda::CUDACachingAllocator", ("hip::HIPCachingAllocator", API_C10)), ("CUDACachingAllocator", ("HIPCachingAllocator", API_C10)), + ("C10_CUDA_KERNEL_LAUNCH_CHECK", ("C10_HIP_KERNEL_LAUNCH_CHECK", API_C10)) ] ) diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py index 52aad5ea1d69c..898db0c0c15b9 100755 --- a/torch/utils/hipify/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -34,11 +34,17 @@ from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS from .cuda_to_hip_mappings import MATH_TRANSPILATIONS +from typing import Dict, List, Iterator, Optional +from collections.abc import Mapping, Iterable +HipifyResult = Dict[str, Optional[str]] +HipifyFinalResult = Dict[str, HipifyResult] +HIPIFY_C_BREADCRUMB = "// !!! This is a file automatically generated by hipify!!!\n" +HIPIFY_FINAL_RESULT: HipifyFinalResult = {} + # Hardcode the PyTorch template map """This dictionary provides the mapping from PyTorch kernel template types to their actual types.""" PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"} -CAFFE2_TEMPLATE_MAP = {} class InputError(Exception): @@ -107,14 +113,20 @@ def __exit__(self, type, value, traceback): for d in self.dirs_to_clean[::-1]: os.rmdir(d) -def matched_files_iter(root_path, includes=('*',), ignores=(), extensions=(), out_of_place_only=False, is_pytorch_extension=False): +def match_extensions(filename: str, extensions: Iterable) -> bool: + """Helper method to see if filename ends with certain extension""" + return any(filename.endswith(e) for e in extensions) + +def matched_files_iter( + root_path: str, + includes: Iterable = ('*',), + ignores: Iterable = (), + extensions: Iterable = (), + out_of_place_only: bool = False, + is_pytorch_extension: bool = False) -> Iterator[str]: def _fnmatch(filepath, patterns): return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns) - def match_extensions(filename): - """Helper method to see if filename ends with certain extension""" - return any(filename.endswith(e) for e in extensions) - exact_matches = set(includes) # This is a very rough heuristic; really, we want to avoid scanning @@ -139,7 +151,7 @@ def match_extensions(filename): if ( _fnmatch(filepath, includes) and (not _fnmatch(filepath, ignores)) - and (match_extensions(filepath) or filepath in exact_matches) + and (match_extensions(filepath, extensions) or filepath in exact_matches) ): if not is_pytorch_extension: # for pytorch extensions, consider all files if not is_pytorch_file(filepath) and not is_caffe2_gpu_file(filepath): @@ -149,14 +161,39 @@ def match_extensions(filename): yield filepath +def preprocess_file_and_save_result( + output_directory: str, + filepath: str, + all_files: Iterable, + includes: Iterable, + stats: Dict[str, List], + hip_clang_launch: bool, + is_pytorch_extension: bool, + clean_ctx: GeneratedFileCleaner, + show_progress: bool) -> None: + result = preprocessor(output_directory, filepath, all_files, includes, stats, + hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress) + + fin_path = os.path.join(output_directory, filepath) + # Show what happened + if show_progress: + print( + fin_path, "->", + result["hipified_path"], result["status"]) + + if result["hipified_path"] is not None: + HIPIFY_FINAL_RESULT[fin_path] = result + + def preprocess( - output_directory, - all_files, - show_detailed=False, - show_progress=True, - hip_clang_launch=False, - is_pytorch_extension=False, - clean_ctx=None): + output_directory: str, + all_files: Iterable, + includes: Iterable, + show_detailed: bool = False, + show_progress: bool = True, + hip_clang_launch: bool = False, + is_pytorch_extension: bool = False, + clean_ctx: GeneratedFileCleaner = None) -> HipifyFinalResult: """ Call preprocessor on selected files. @@ -168,16 +205,11 @@ def preprocess( clean_ctx = GeneratedFileCleaner(keep_intermediates=True) # Preprocessing statistics. - stats = {"unsupported_calls": [], "kernel_launches": []} + stats: Dict[str, List] = {"unsupported_calls": [], "kernel_launches": []} for filepath in all_files: - result = preprocessor(output_directory, filepath, stats, hip_clang_launch, is_pytorch_extension, clean_ctx) - - # Show what happened - if show_progress: - print( - filepath, "->", - get_hip_file_path(filepath), result) + preprocess_file_and_save_result(output_directory, filepath, all_files, includes, stats, + hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress) print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr) @@ -185,6 +217,8 @@ def preprocess( if show_detailed: compute_stats(stats) + return HIPIFY_FINAL_RESULT + def compute_stats(stats): unsupported_calls = {cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"]} @@ -204,7 +238,7 @@ def add_dim3(kernel_string, cuda_kernel): count = 0 closure = 0 kernel_string = kernel_string.replace("<<<", "").replace(">>>", "") - arg_locs = [{} for _ in range(2)] + arg_locs: List[Dict[str, int]] = [{} for _ in range(2)] arg_locs[count]['start'] = 0 for ind, c in enumerate(kernel_string): if count > 1: @@ -425,7 +459,7 @@ def replace_math_functions(input_string): return output_string -RE_SYNCTHREADS = re.compile(r"[:]?[:]?\b(__syncthreads)\b(\w*\()") +RE_SYNCTHREADS = re.compile(r":?:?\b(__syncthreads)\b(\w*\()") def hip_header_magic(input_string): @@ -444,6 +478,7 @@ def hip_header_magic(input_string): return output_string # Rough logic to detect if we're inside device code + hasDeviceLogic: int hasDeviceLogic = "hipLaunchKernelGGL" in output_string hasDeviceLogic += "__global__" in output_string hasDeviceLogic += "__shared__" in output_string @@ -474,13 +509,13 @@ def replace_extern_shared(input_string): return output_string -def get_hip_file_path(filepath): +def get_hip_file_path(filepath, is_pytorch_extension=False): """ Returns the new name of the hipified file """ - # At the moment, some files are HIPified in place. The predicate + # At the moment, some PyTorch source files are HIPified in place. The predicate # is_out_of_place tells us if this is the case or not. - if not is_out_of_place(filepath): + if not is_pytorch_extension and not is_out_of_place(filepath): return filepath dirpath, filename = os.path.split(filepath) @@ -489,10 +524,8 @@ def get_hip_file_path(filepath): # Here's the plan: # # In general, we need to disambiguate the HIPified filename so that - # it gets a different name from the original Caffe2 filename, so - # that we don't overwrite the original file. (Additionally, - # hcc historically had a bug where if you had two files with - # the same basename, they would clobber each other.) + # it gets a different name from the original filename, so + # that we don't overwrite the original file # # There's a lot of different naming conventions across PyTorch # and Caffe2, but the general recipe is to convert occurrences @@ -506,12 +539,18 @@ def get_hip_file_path(filepath): # # - If the file name contains "CUDA", replace it with "HIP", AND # - # If NONE of the above occurred, then insert "hip" in the file path - # as the direct parent folder of the file + # - ALWAYS replace '.cu' with '.hip', because those files + # contain CUDA kernels that needs to be hipified and processed with + # hip compiler + # + # - If we are not hipifying a PyTorch extension, and the parent + # directory name did not change as a result of the above + # transformations, insert "hip" in the file path + # as the direct parent folder of the file # - # Furthermore, ALWAYS replace '.cu' with '.hip', because those files - # contain CUDA kernels that needs to be hipified and processed with - # hcc compiler + # - If we are hipifying a PyTorch extension, and the parent directory + # name as well as the filename (incl. extension) did not change as + # a result of the above transformations, insert "_hip" in the filename # # This isn't set in stone; we might adjust this to support other # naming conventions. @@ -519,6 +558,7 @@ def get_hip_file_path(filepath): if ext == '.cu': ext = '.hip' + orig_filename = filename orig_dirpath = dirpath dirpath = dirpath.replace('cuda', 'hip') @@ -530,9 +570,12 @@ def get_hip_file_path(filepath): if dirpath != "caffe2/core": root = root.replace('THC', 'THH') - if dirpath == orig_dirpath: + if not is_pytorch_extension and dirpath == orig_dirpath: dirpath = os.path.join(dirpath, 'hip') + if is_pytorch_extension and dirpath == orig_dirpath and (root + ext) == orig_filename: + root = root + "_hip" + return os.path.join(dirpath, root + ext) @@ -632,6 +675,7 @@ def pattern(self): PYTORCH_TRIE = Trie() PYTORCH_MAP = {} for mapping in CUDA_TO_HIP_MAPPINGS: + assert isinstance(mapping, Mapping) for src, value in mapping.items(): dst = value[0] meta_data = value[1:] @@ -649,13 +693,35 @@ def pattern(self): RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"') RE_CU_SUFFIX = re.compile(r'\.cu\b') # be careful not to pick up .cuh -def preprocessor(output_directory, filepath, stats, hip_clang_launch, is_pytorch_extension, clean_ctx): +""" +Returns a dict with the following keys: + "hipified_path" : absolute path of hipified source file + "status" : "ok" if hipified file was written out + "skipped" if an identical hipified file already existed + "ignored" if the source file was a hipified file itself +""" +def preprocessor( + output_directory: str, + filepath: str, + all_files: Iterable, + includes: Iterable, + stats: Dict[str, List], + hip_clang_launch: bool, + is_pytorch_extension: bool, + clean_ctx: GeneratedFileCleaner, + show_progress: bool) -> HipifyResult: """ Executes the CUDA -> HIP conversion on the specified file. """ fin_path = os.path.join(output_directory, filepath) + with open(fin_path, 'r', encoding='utf-8') as fin: + if fin.readline() == HIPIFY_C_BREADCRUMB: + return {"hipified_path": None, "status": "ignored"} + fin.seek(0) output_source = fin.read() - fout_path = os.path.join(output_directory, get_hip_file_path(filepath)) + orig_output_source = output_source + + fout_path = os.path.join(output_directory, get_hip_file_path(filepath, is_pytorch_extension)) if not os.path.exists(os.path.dirname(fout_path)): clean_ctx.makedirs(os.path.dirname(fout_path)) @@ -674,9 +740,10 @@ def c2_repl(m): output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source) # Header rewrites - def mk_repl(templ): + def mk_repl(templ, include_current_dir=True): def repl(m): f = m.group(1) + dirpath, filename = os.path.split(f) if ( f.startswith("ATen/cuda") or f.startswith("ATen/native/cuda") @@ -686,11 +753,43 @@ def repl(m): or f.startswith("THCUNN/") or (f.startswith("THC") and not f.startswith("THCP")) ): - return templ.format(get_hip_file_path(m.group(1))) + return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension)) + # if filename is one of the files being hipified for this extension + if (is_pytorch_extension and any(s.endswith(filename) for s in all_files)): + header_dir = None + header_filepath = None + # If include_current_dir True, look first in same dir as the including source file + if include_current_dir: + header_dir_to_check = os.path.dirname(fin_path) + header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f)) + if os.path.exists(header_path_to_check): + header_dir = header_dir_to_check + header_filepath = header_path_to_check + # If not found, look in include dirs one by one and first match wins + if header_filepath is None: + for include in includes: + header_dir_to_check = os.path.join(output_directory, os.path.dirname(include)) + header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f)) + if os.path.exists(header_path_to_check): + header_dir = header_dir_to_check + header_filepath = header_path_to_check + # If header file not found, keep as is + if header_filepath is None: + return m.group(0) + # Hipify header file first if needed + if header_filepath not in HIPIFY_FINAL_RESULT: + preprocess_file_and_save_result(output_directory, + os.path.relpath(header_filepath, output_directory), + all_files, includes, stats, hip_clang_launch, is_pytorch_extension, + clean_ctx, show_progress) + value = HIPIFY_FINAL_RESULT[header_filepath]["hipified_path"] + assert value is not None + return templ.format(os.path.relpath(value, header_dir)) + return m.group(0) return repl - output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"'), output_source) - output_source = RE_ANGLE_HEADER.sub(mk_repl('#include <{0}>'), output_source) + output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"', True), output_source) + output_source = RE_ANGLE_HEADER.sub(mk_repl('#include <{0}>', False), output_source) output_source = RE_THC_GENERIC_FILE.sub(mk_repl('#define THC_GENERIC_FILE "{0}"'), output_source) # CMakeLists.txt rewrites @@ -713,6 +812,18 @@ def repl(m): # Replace the extern __shared__ output_source = replace_extern_shared(output_source) + # Don't write out identical hipified files for extensions if dirpath has not changed + if ( + is_pytorch_extension + and orig_output_source == output_source + and os.path.dirname(fin_path) == os.path.dirname(fout_path) + ): + return {"hipified_path": fin_path, "status": "ok"} + + # Add hipify breadcrumb for C-style files to avoid re-hipification + if fin_path != fout_path and match_extensions(fin_path, (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".hpp")): + output_source = HIPIFY_C_BREADCRUMB + output_source + do_write = True if os.path.exists(fout_path): with open(fout_path, 'r', encoding='utf-8') as fout_old: @@ -720,9 +831,9 @@ def repl(m): if do_write: with clean_ctx.open(fout_path, 'w', encoding='utf-8') as fout: fout.write(output_source) - return "ok" + return {"hipified_path": fout_path, "status": "ok"} else: - return "skipped" + return {"hipified_path": fout_path, "status": "skipped"} def file_specific_replacement(filepath, search_string, replace_string, strict=False): with openf(filepath, "r+") as f: @@ -814,19 +925,19 @@ def str2bool(v): def hipify( - project_directory, - show_detailed=False, - extensions=(".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"), - output_directory="", - includes=(), - extra_files=(), - out_of_place_only=False, - ignores=(), - show_progress=True, - hip_clang_launch=False, - is_pytorch_extension=False, - clean_ctx=None -): + project_directory: str, + show_detailed: bool = False, + extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"), + output_directory: str = "", + includes: Iterable = (), + extra_files: Iterable = (), + out_of_place_only: bool = False, + ignores: Iterable = (), + show_progress: bool = True, + hip_clang_launch: bool = False, + is_pytorch_extension: bool = False, + clean_ctx: GeneratedFileCleaner = None +) -> HipifyFinalResult: if project_directory == "": project_directory = os.getcwd() @@ -849,12 +960,17 @@ def hipify( out_of_place_only=out_of_place_only, is_pytorch_extension=is_pytorch_extension)) all_files_set = set(all_files) - all_files += [f for f in extra_files if f not in all_files_set] + # Convert extra_files to relative paths since all_files has all relative paths + for f in extra_files: + f_rel = os.path.relpath(f, output_directory) + if f_rel not in all_files_set: + all_files.append(f_rel) # Start Preprocessor - preprocess( + return preprocess( output_directory, all_files, + includes, show_detailed=show_detailed, show_progress=show_progress, hip_clang_launch=hip_clang_launch, diff --git a/torch/utils/hipify/version.py b/torch/utils/hipify/version.py new file mode 100644 index 0000000000000..1f356cc57bfa0 --- /dev/null +++ b/torch/utils/hipify/version.py @@ -0,0 +1 @@ +__version__ = '1.0.0' diff --git a/torch/utils/hooks.py b/torch/utils/hooks.py index ac13023c81393..38ab2339e5b4f 100644 --- a/torch/utils/hooks.py +++ b/torch/utils/hooks.py @@ -1,6 +1,9 @@ +from __future__ import absolute_import, division, print_function, unicode_literals +import torch from collections import OrderedDict import weakref import warnings +import functools from typing import Any @@ -58,3 +61,122 @@ def warn_if_has_hooks(tensor): "serialized. If this is expected, you can " "decorate the function with @torch.utils.hooks.unserializable_hook " "to suppress this warning".format(repr(hook))) + +class BackwardHook(object): + """ + A wrapper class to implement nn.Module backward hooks. + It handles: + - Ignoring non-Tensor inputs and replacing them by None before calling the user hook + - Generating the proper Node to capture a set of Tensor's gradients + - Linking the gradients captures for the outputs with the gradients captured for the input + - Calling the user hook once both output and input gradients are available + """ + + def __init__(self, module, user_hooks): + self.user_hooks = user_hooks + self.module = module + + self.grad_outputs = None + self.n_outputs = -1 + self.output_tensors_index = None + self.n_inputs = -1 + self.input_tensors_index = None + + def _pack_with_none(self, indices, values, size): + res = [None] * size + for idx, val in zip(indices, values): + res[idx] = val + + return tuple(res) + + def _unpack_none(self, indices, values): + res = [] + for idx in indices: + res.append(values[idx]) + + return tuple(res) + + def _set_user_hook(self, grad_fn, user_hook): + @functools.wraps(user_hook) + def hook(grad_input, _): + if self.grad_outputs is None: + raise RuntimeError("Module backward hook for grad_input is called before " + "the grad_output one. This happens because the gradient " + "in your nn.Module flows to the Module's input without " + "passing through the Module's output. Make sure that the " + "output depends on the input and that the loss is computed " + "based on the output.") + + grad_input = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs) + res = user_hook(self.module, grad_input, self.grad_outputs) + if res is None: + return res + + if len(res) != len(grad_input): + raise RuntimeError("Backward hook returned an invalid number of grad_input, " + "got {}, but expected {}".format(len(res), len(grad_input))) + return self._unpack_none(self.input_tensors_index, res) + grad_fn.register_hook(hook) + + def _apply_on_tensors(self, fn, args): + # Can be used to apply the given function to the tensors contained in the + # args. Will return updated args and the tensors indices + tensors_idx = [] + tensors = [] + + requires_grad = False + for i, arg in enumerate(args): + if isinstance(arg, torch.Tensor): + tensors_idx.append(i) + tensors.append(arg) + requires_grad |= arg.requires_grad + + if not requires_grad: + return args, None + + new_tensors = torch.nn.modules._functions.BackwardHookFunction.apply(*tensors) + if len(new_tensors) == 0: + raise RuntimeError("Cannot set Module backward hook for a Module with no input Tensors.") + grad_fn = new_tensors[0].grad_fn + if not grad_fn.name() == "BackwardHookFunctionBackward": + raise RuntimeError("Error while setting up backward hooks. Please open " + "an issue with a code sample to reproduce this.") + + fn(grad_fn) + + arg_list = list(args) + for idx, val in zip(tensors_idx, new_tensors): + arg_list[idx] = val + + return tuple(arg_list), tensors_idx + + def setup_input_hook(self, args): + def fn(grad_fn): + for hook in self.user_hooks: + self._set_user_hook(grad_fn, hook) + + res, input_idx = self._apply_on_tensors(fn, args) + self.n_inputs = len(args) + self.input_tensors_index = input_idx + return res + + def setup_output_hook(self, args): + def fn(grad_fn): + def hook(_, grad_output): + self.grad_outputs = self._pack_with_none(self.output_tensors_index, + grad_output, + self.n_outputs) + grad_fn.register_hook(hook) + + is_tuple = True + if not isinstance(args, tuple): + args = (args,) + is_tuple = False + + res, output_idx = self._apply_on_tensors(fn, args) + self.n_outputs = len(args) + self.output_tensors_index = output_idx + + if not is_tuple: + res = res[0] + return res diff --git a/torch/utils/mobile_optimizer.py b/torch/utils/mobile_optimizer.py index 16d49195ce162..caa0a9c9671dd 100644 --- a/torch/utils/mobile_optimizer.py +++ b/torch/utils/mobile_optimizer.py @@ -25,7 +25,7 @@ def optimize_for_mobile( optimization method will run all the optimizer pass; otherwise, optimizer method will run the optimization pass that is not included inside optimization_blocklist. perserved_methods: A list of methods that needed to be preserved when freeze_module pass is invoked - backend: Device type to use for running the result model ('CPU'(default) or 'Vulkan'). + backend: Device type to use for running the result model ('CPU'(default), 'Vulkan' or 'Metal'). Returns: A new optimized torch script module """ @@ -39,12 +39,36 @@ def optimize_for_mobile( if preserved_methods is None: preserved_methods = [] - if backend == 'CPU': - optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(script_module._c, optimization_blocklist, preserved_methods) - elif backend == 'Vulkan': - optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(script_module._c, preserved_methods) + # Convert potential byte arrays into strings (if there is any) to pass type checking + # Here we use a new name as assigning it back to preserved_methods will invoke + # mypy errors (i.e. List[AnyStr] = List[str]) + preserved_methods_str: List[str] = [str(method) for method in preserved_methods] + + bundled_inputs_methods = ['get_all_bundled_inputs', 'get_num_bundled_inputs', 'run_on_bundled_input'] + if all([hasattr(script_module, method) for method in bundled_inputs_methods]): + preserved_methods_str = list(set(preserved_methods_str + bundled_inputs_methods)) + + non_exist_methods = [] + for method in preserved_methods_str: + if not hasattr(script_module, method): + non_exist_methods.append(method) + if non_exist_methods: + raise AttributeError( + 'The following methods to preserve do not exist in script_module: {}' + .format(', '.join(non_exist_methods))) + + backend = backend.lower() + if backend == 'cpu': + optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile( + script_module._c, + optimization_blocklist, + preserved_methods_str) + elif backend == 'vulkan': + optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(script_module._c, preserved_methods_str) + elif backend == 'metal': + optimized_cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(script_module._c, preserved_methods_str) else: - raise TypeError("Unknown backend, must be one of 'CPU', 'Vulkan'") + raise TypeError("Unknown backend, must be one of 'CPU', 'Vulkan' or 'Metal'") return torch.jit._recursive.wrap_cpp_module(optimized_cpp_module) diff --git a/torch/utils/show_pickle.py b/torch/utils/show_pickle.py index 51d524d63acf8..9e55ebff48b95 100644 --- a/torch/utils/show_pickle.py +++ b/torch/utils/show_pickle.py @@ -4,6 +4,7 @@ import pprint import zipfile import fnmatch +from typing import IO, BinaryIO, Union class FakeObject(object): @@ -44,7 +45,7 @@ class FakeClass(object): def __init__(self, module, name): self.module = module self.name = name - self.__new__ = self.fake_new + self.__new__ = self.fake_new # type: ignore def __repr__(self): return f"{self.module}.{self.name}" @@ -56,7 +57,7 @@ def fake_new(self, *args): return FakeObject(self.module, self.name, args[1:]) -class DumpUnpickler(pickle._Unpickler): +class DumpUnpickler(pickle._Unpickler): # type: ignore def find_class(self, module, name): return FakeClass(module, name) @@ -67,6 +68,7 @@ def persistent_load(self, pid): def dump(cls, in_stream, out_stream): value = cls(in_stream).load() pprint.pprint(value, stream=out_stream) + return value def main(argv, output_stream=None): @@ -84,6 +86,7 @@ def main(argv, output_stream=None): return 2 fname = argv[1] + handle: Union[IO[bytes], BinaryIO] if "@" not in fname: with open(fname, "rb") as handle: DumpUnpickler.dump(handle, output_stream) @@ -110,6 +113,6 @@ def main(argv, output_stream=None): # I've tested on the following versions: # 3.7.4 if True: - pprint.PrettyPrinter._dispatch[FakeObject.__repr__] = FakeObject.pp_format + pprint.PrettyPrinter._dispatch[FakeObject.__repr__] = FakeObject.pp_format # type: ignore sys.exit(main(sys.argv)) diff --git a/torch/utils/tensorboard/_caffe2_graph.py b/torch/utils/tensorboard/_caffe2_graph.py index 64f9f11f06d9f..3cd3a3608fed6 100644 --- a/torch/utils/tensorboard/_caffe2_graph.py +++ b/torch/utils/tensorboard/_caffe2_graph.py @@ -2,7 +2,6 @@ import logging import os import re -import six from tensorboard.compat.proto.graph_pb2 import GraphDef from tensorboard.compat.proto.node_def_pb2 import NodeDef @@ -12,8 +11,10 @@ from caffe2.proto import caffe2_pb2 from caffe2.python import core, workspace +from typing import Set, Dict, Tuple, List -def _make_unique_name(seen, name, min_version=0): + +def _make_unique_name(seen: Set[str], name: str, min_version: int = 0): ''' Make the name unique by appending a unique number to the name. Used for SSA. @@ -91,12 +92,12 @@ def _convert_to_ssa(shapes, blob_name_tracker, ops): None. Modifies blob_name_tracker and ops in-place. ''' ir = core.IR(ops) - seen = set() - versioned = {} + seen: Set[str] = set() + versioned: Dict[Tuple[str, int], int] = {} new_shapes = {} new_blob_name_tracker = {} - def ssa_name(name, versions): + def ssa_name(name: str, versions: Dict[str, int]) -> int: assert name in versions version = versions[name] if (name, version) in versioned: @@ -160,7 +161,7 @@ def _remap_keys(old_dict, rename_fn): None. Modifies old_dict in-place. ''' new_dict = {rename_fn(key): value for key, - value in six.iteritems(old_dict)} + value in old_dict.items()} old_dict.clear() old_dict.update(new_dict) @@ -180,8 +181,8 @@ def _rename_all(shapes, blob_name_tracker, ops, rename_fn): None. Modifies shapes, blob_name_tracker and ops in-place using the specified 'rename_fn'. ''' - seen = set() - renamed = {} + seen: Set[str] = set() + renamed: Dict[Tuple[str, int], int] = {} def g(name): """ Collision-free version of f. @@ -683,7 +684,7 @@ def _operators_to_graph_def( _fill_missing_operator_names(ops) if show_simplified: # use_tensorflow_naming _rename_tensorflow_style(shapes, blob_name_tracker, ops) - producing_ops = {} + producing_ops: Dict[caffe2_pb2.OperatorDef, List] = {} blobs = set() input_blobs, inter_blobs, _ = _compute_in_out(ops) current_graph = GraphDef() diff --git a/torch/utils/tensorboard/_convert_np.py b/torch/utils/tensorboard/_convert_np.py index 465eba41e859f..0e8fd663f106f 100644 --- a/torch/utils/tensorboard/_convert_np.py +++ b/torch/utils/tensorboard/_convert_np.py @@ -3,7 +3,6 @@ """ import numpy as np import torch -import six def make_np(x): @@ -16,7 +15,7 @@ def make_np(x): """ if isinstance(x, np.ndarray): return x - if isinstance(x, six.string_types): # Caffe2 will pass name of blob(s) to fetch + if isinstance(x, str): # Caffe2 will pass name of blob(s) to fetch return _prepare_caffe2(x) if np.isscalar(x): return np.array([x]) diff --git a/torch/utils/tensorboard/summary.py b/torch/utils/tensorboard/summary.py index 8afd94febf4bd..c52776a894483 100644 --- a/torch/utils/tensorboard/summary.py +++ b/torch/utils/tensorboard/summary.py @@ -489,27 +489,22 @@ def audio(tag, tensor, sample_rate=44100): print('warning: audio amplitude out of range, auto clipped.') tensor = tensor.clip(-1, 1) assert(tensor.ndim == 1), 'input tensor should be 1 dimensional.' + tensor = (tensor * np.iinfo(np.int16).max).astype('